[Tensorflow] Restore partial weights (loading part of a network)

2021. 4. 9. 11:42 | Code Functions

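# Function

import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.compat.v1 in TF 2.x)


def get_variables_to_restore(variables, var_keep_dic):
    """Return the variables of the new graph whose names also appear in the checkpoint."""
    variables_to_restore = []
    for v in variables:
        # v.name looks like 'scope/weights:0'; checkpoint keys have no ':0' suffix.
        # One can add include/exclude rules here, e.g. to skip a layer that is being retrained.
        if v.name.split(':')[0] in var_keep_dic:
            print("Variables restored: %s" % v.name)
            variables_to_restore.append(v)
    return variables_to_restore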
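The comment inside the loop notes that include or exclude rules can be added there. Below is a minimal sketch of one way to do that, written in plain Python so it runs without TensorFlow. Everything in it is an assumption for illustration: `FakeVar` is a hypothetical stand-in for `tf.Variable` (only `.name` and `.shape` are read), the `include`/`exclude` name-prefix parameters are made up, and `ckpt_map` plays the role of `var_keep_dic` as a {name: shape} dict. It also skips variables whose shape differs from the checkpoint, because restoring a variable with a mismatched shape fails.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.Variable: only .name and .shape are used here.
FakeVar = namedtuple('FakeVar', ['name', 'shape'])


def get_variables_to_restore(variables, var_keep_dic, include=(), exclude=()):
    """Keep variables that are in the checkpoint, pass the (hypothetical)
    include/exclude name-prefix filters, and have the same shape."""
    variables_to_restore = []
    for v in variables:
        name = v.name.split(':')[0]          # strip the ':0' output suffix
        if name not in var_keep_dic:
            continue                         # new variable: keeps its initial value
        if include and not name.startswith(tuple(include)):
            continue                         # not in any included scope
        if exclude and name.startswith(tuple(exclude)):
            continue                         # explicitly excluded scope
        if list(v.shape) != list(var_keep_dic[name]):
            print("Shape mismatch, skipped: %s" % name)
            continue
        print("Variables restored: %s" % v.name)
        variables_to_restore.append(v)
    return variables_to_restore


# Hypothetical checkpoint contents {name: shape} and new-graph variables.
ckpt_map = {
    'conv1/weights': [3, 3, 3, 64],
    'conv1/biases': [64],
    'fc/weights': [4096, 1000],
}
variables = [
    FakeVar('conv1/weights:0', [3, 3, 3, 64]),
    FakeVar('conv1/biases:0', [64]),
    FakeVar('fc/weights:0', [4096, 10]),      # same name, new number of classes
    FakeVar('new_head/weights:0', [10, 2]),   # not in the checkpoint at all
]

restored = get_variables_to_restore(variables, ckpt_map)
print([v.name for v in restored])  # ['conv1/weights:0', 'conv1/biases:0']
```

With real TensorFlow variables, `v.shape.as_list()` gives the plain list to compare against the checkpoint shape.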

 

 

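def get_variable_in_checkpoint_file(file_name):
    """Return {variable name: shape} for every variable stored in the checkpoint."""
    reader = tf.train.NewCheckpointReader(file_name)
    return reader.get_variable_to_shape_map()


# Get the checkpoint of the previously trained model
# (dir_of_ckpt: the directory that holds the old checkpoint files)
ckpt = tf.train.get_checkpoint_state(dir_of_ckpt)
assert ckpt is not None, "no checkpoint found in %s" % dir_of_ckpt
pretrained_model = ckpt.model_checkpoint_path

# Build the new graph here, then create a session: sess
sess = tf.Session()

# Initialize all variables first; variables that are not in the
# checkpoint (e.g. newly added layers) keep these initial values
variables = tf.global_variables()
sess.run(tf.variables_initializer(variables, name='init'))

# Overwrite the variables that exist in the checkpoint with the trained values
var_keep_dic = get_variable_in_checkpoint_file(pretrained_model)
variables_to_restore = get_variables_to_restore(variables, var_keep_dic)
restorer = tf.train.Saver(variables_to_restore)
restorer.restore(sess, pretrained_model)
print("loaded.")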
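To check the whole flow end to end, here is a small self-contained sketch, assuming TensorFlow 1.15 or 2.x with the TF1-style API accessed through `tf.compat.v1`. It saves an "old" model with variables `a` and `b`, builds a "new" graph that adds a variable `c`, initializes everything, and then restores only the variables found in the checkpoint. The variable names, values, and the `partial_restore_demo` function are hypothetical; the two helpers follow the same logic as the functions in this post.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, available in TF 1.15 and 2.x


def get_variable_in_checkpoint_file(file_name):
    """Return {variable name: shape} for every variable stored in the checkpoint."""
    return tf1.train.NewCheckpointReader(file_name).get_variable_to_shape_map()


def get_variables_to_restore(variables, var_keep_dic):
    """Keep graph variables whose name (without ':0') exists in the checkpoint."""
    return [v for v in variables if v.name.split(':')[0] in var_keep_dic]


def partial_restore_demo(ckpt_dir):
    # 1) "Old" model with two variables, saved to a checkpoint in ckpt_dir.
    old_graph = tf.Graph()
    with old_graph.as_default():
        tf1.get_variable('a', initializer=tf.constant([1.0, 2.0]))
        tf1.get_variable('b', initializer=tf.constant(3.0))
        saver = tf1.train.Saver()
        with tf1.Session(graph=old_graph) as sess:
            sess.run(tf1.global_variables_initializer())
            saver.save(sess, os.path.join(ckpt_dir, 'model.ckpt'))

    pretrained_model = tf1.train.get_checkpoint_state(ckpt_dir).model_checkpoint_path

    # 2) "New" model: same a and b, plus a new variable c.
    new_graph = tf.Graph()
    with new_graph.as_default():
        tf1.get_variable('a', initializer=tf.zeros([2]))
        tf1.get_variable('b', initializer=tf.zeros([]))
        tf1.get_variable('c', initializer=tf.fill([3], 7.0))  # not in the checkpoint
        variables = tf1.global_variables()
        with tf1.Session(graph=new_graph) as sess:
            # Initialize everything, then overwrite what the checkpoint provides.
            sess.run(tf1.variables_initializer(variables, name='init'))
            var_keep_dic = get_variable_in_checkpoint_file(pretrained_model)
            variables_to_restore = get_variables_to_restore(variables, var_keep_dic)
            restorer = tf1.train.Saver(variables_to_restore)
            restorer.restore(sess, pretrained_model)
            return {v.name.split(':')[0]: sess.run(v) for v in variables}


with tempfile.TemporaryDirectory() as ckpt_dir:
    values = partial_restore_demo(ckpt_dir)
print(values)
```

After the restore, `a` and `b` hold the saved values while `c` keeps its initial value of 7, which is what initializing first and restoring second is meant to achieve.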

 

 

 

Source:

innerpeace-wu.github.io/2017/12/13/Tensorflow-Restore-partial-weights/
