2021. 4. 9. 11:42ㆍCode Functions
# Function
def get_variables_to_restore(variables, var_keep_dic):
    variables_to_restore = []
    for v in variables:
        # one can do include or exclude operations here.
        if v.name.split(':')[0] in var_keep_dic:
            print("Variables restored: %s" % v.name)
            variables_to_restore.append(v)
    return variables_to_restore
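
The name filter above can be tried without TensorFlow. The sketch below is only an illustration: the `SimpleNamespace` objects are stand-ins for `tf.Variable` (only their `.name` attribute is used), and the variable names and shapes are made up. It shows that the `:0` suffix is stripped before the lookup in `var_keep_dic`, so variables missing from the checkpoint are skipped.

```python
from types import SimpleNamespace


def get_variables_to_restore(variables, var_keep_dic):
    """Keep only the variables whose name (without ':0') is in var_keep_dic."""
    variables_to_restore = []
    for v in variables:
        if v.name.split(':')[0] in var_keep_dic:
            variables_to_restore.append(v)
    return variables_to_restore


# Hypothetical stand-ins for tf.Variable objects; only `.name` is needed here.
graph_variables = [
    SimpleNamespace(name="conv1/weights:0"),
    SimpleNamespace(name="conv1/biases:0"),
    SimpleNamespace(name="new_head/weights:0"),
]

# Hypothetical checkpoint contents: variable name -> shape.
var_keep_dic = {"conv1/weights": [3, 3, 3, 64], "conv1/biases": [64]}

restored = get_variables_to_restore(graph_variables, var_keep_dic)
restored_names = [v.name for v in restored]
print(restored_names)
```

Here `new_head/weights:0` is left out because its name is not a key in `var_keep_dic`, so it would keep its freshly initialized value.
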
import tensorflow as tf  # the calls below use the TensorFlow 1.x API

# Get the checkpoint of the previously trained model
ckpt = tf.train.get_checkpoint_state(dir_of_ckpt)
pretrained_model = ckpt.model_checkpoint_path
# Build the new graph and create a session, sess (not shown here)
# Initialize all variables
variables = tf.global_variables()
sess.run(tf.variables_initializer(variables, name='init'))
# Get the variables stored in the checkpoint
# (get_variable_in_checkpoint_file is a helper that is not defined in this post)
var_keep_dic = self.get_variable_in_checkpoint_file(pretrained_model)
variables_to_restore = get_variables_to_restore(variables, var_keep_dic)
restorer = tf.train.Saver(variables_to_restore)
restorer.restore(sess, pretrained_model)
print("loaded.")
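
The helper `get_variable_in_checkpoint_file` is not defined in this post. One possible version is sketched below, under assumptions: that `tf.train.load_checkpoint` is available in the installed TensorFlow version, and that a name-to-shape dictionary is what the caller expects. The `reader_factory` parameter, the `FakeReader` class, and the `"model.ckpt"` path are hypothetical additions so that the function can be exercised without a real checkpoint.

```python
def get_variable_in_checkpoint_file(file_name, reader_factory=None):
    """Return a dict mapping variable names in a checkpoint to their shapes."""
    if reader_factory is None:
        # Imported lazily so the rest of this sketch runs without TensorFlow.
        import tensorflow as tf
        reader_factory = tf.train.load_checkpoint
    reader = reader_factory(file_name)
    return reader.get_variable_to_shape_map()


class FakeReader:
    """Hypothetical stand-in for a checkpoint reader, for demonstration only."""

    def get_variable_to_shape_map(self):
        return {"conv1/weights": [3, 3, 3, 64], "conv1/biases": [64]}


# "model.ckpt" is a placeholder path; the fake reader ignores it.
var_keep_dic = get_variable_in_checkpoint_file(
    "model.ckpt", reader_factory=lambda path: FakeReader()
)
print(sorted(var_keep_dic))
```

Because the filter only checks `in var_keep_dic`, any container of variable names works; the shapes in the dictionary could additionally be used to skip variables whose shape changed in the new graph.
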
Source:
innerpeace-wu.github.io/2017/12/13/Tensorflow-Restore-partial-weights/