TensorFlow 保存模型及重载模型

TensorFlow 保存模型及重载模型

保存模型及重载部分模型

# -*- coding=utf-8 -*-
import tensorflow as tf  
from tensorflow.python import pywrap_tensorflow

#--------------------------------------------- Save the model ---------------------------------------------
# Disabled: the save code below is wrapped in a bare ''' string literal, so it
# never executes. Remove the surrounding quotes and run it once to write the
# checkpoint that the restore section loads. Because of global_step=1,
# Saver.save writes the files under the prefix "checkpoint/model_test-1".
# v3 is deliberately left out (commented out inside the string), so the saved
# checkpoint contains only v1 and v2.
'''
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")  
v2= tf.Variable(tf.zeros([200]), name="v2")  
#v3= tf.Variable(tf.zeros([100]), name="v3") 
saver = tf.train.Saver()  
with tf.Session() as sess:  
     init_op = tf.global_variables_initializer()  
     sess.run(init_op)  
     saver.save(sess,"checkpoint/model_test",global_step=1)
'''

#--------------------------------------------- Restore part of the model ---------------------------------------------
# Build the current graph. v1 and v2 are also present in the checkpoint written
# by the save section; v3 is new here, so it cannot be restored and must be
# initialized instead.
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
# Alternative: pass an explicit list, e.g. tf.train.Saver([v1]), when the
# variables to restore are known in advance.

## Collect the tensor names stored in the saved checkpoint.
# Use the same directory the save section writes to ("checkpoint/model_test"),
# so the checkpoint inspected here is the very one restored below.
checkpoint_dir = "checkpoint"
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt is None or not ckpt.model_checkpoint_path:
    # get_checkpoint_state returns None when the directory has no checkpoint;
    # fail with a clear message instead of an AttributeError on None.
    raise IOError("No checkpoint found in directory %r" % checkpoint_dir)
checkpoint_path = ckpt.model_checkpoint_path
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
point_name_list = list(var_to_shape_map)
print(point_name_list)

## Restore a global variable only if the checkpoint holds a tensor of exactly
## the same name; every other variable is initialized from scratch.
var_to_restore = []    # graph variables found in the checkpoint
var_not_restored = []  # graph variables that must be initialized instead
for val in tf.global_variables():
    print(val.name)
    # val.name carries an output suffix ("v1:0") while checkpoint keys are op
    # names ("v1"). Compare op names exactly: a substring test would also
    # match e.g. "v10" or "scope/v1" and make restore() fail on missing keys.
    if val.op.name in var_to_shape_map:
        var_to_restore.append(val)
    else:
        var_not_restored.append(val)
print(var_to_restore)
saver1 = tf.train.Saver(var_to_restore)
with tf.Session() as sess:
    # Variables missing from the checkpoint (here v3) would otherwise stay
    # uninitialized and fail as soon as they are read or saved.
    sess.run(tf.variables_initializer(var_not_restored))
    saver1.restore(sess, checkpoint_path)

参考