tensorflow保存模型及重载模型
保存模型及重载部分模型
# -*- coding=utf-8 -*-
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
# --------------------------------------------- Save the model ---------------------------------------------
# The save step below is disabled: it is wrapped in a bare string literal, so
# none of it executes. As written it would create variables v1 and v2 (v3 is
# commented out), build a Saver over them and write a checkpoint to
# "checkpoint/model_test" with global_step=1.
# NOTE(review): the body of the `with` statement inside the string has lost its
# indentation and must be re-indented before this section is re-enabled.
'''
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
#v3= tf.Variable(tf.zeros([100]), name="v3")
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
saver.save(sess,"checkpoint/model_test",global_step=1)
'''
# --------------------------------------------- Reload part of the model ---------------------------------------------
# Build the variables of the current graph. v3 is commented out in the
# disabled save section above, so it is the variable expected to be skipped by
# a partial restore.
v1 = tf.Variable(
    initial_value=tf.random_normal(shape=[784, 200], stddev=0.35),
    name="v1",
)
v2 = tf.Variable(initial_value=tf.zeros(shape=[200]), name="v2")
v3 = tf.Variable(initial_value=tf.zeros(shape=[100]), name="v3")
# Alternative savers with hand-picked variable lists, kept for reference:
#   saver = tf.train.Saver()
#   saver1 = tf.train.Saver([v1])
#   saver2 = tf.train.Saver([v2] + [v3])
# Collect the names of the tensors stored in the saved checkpoint.
checkpoint_path = '/home/maozezhong/Desktop/bank_competition/cnn-text-classification-tf/checkpoint'
ckpt = tf.train.get_checkpoint_state(checkpoint_path)
# get_checkpoint_state() returns None when the directory holds no checkpoint
# state; fail with a clear message instead of an AttributeError on the next line.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError("no checkpoint found in %r" % checkpoint_path)
reader = pywrap_tensorflow.NewCheckpointReader(ckpt.model_checkpoint_path)
# Mapping from each saved tensor name to its shape.
var_to_shape_map = reader.get_variable_to_shape_map()
# Only the names (the mapping's keys) are needed to decide what to restore.
point_name_list = list(var_to_shape_map)
print(point_name_list)
# Pick the graph variables whose name exactly matches a tensor stored in the
# checkpoint. A substring test is not enough: a saved "v1" would also select
# "v10:0" or "v1/Adam:0", and restoring a variable that is missing from the
# checkpoint makes Saver.restore() fail.
saved_names = set(point_name_list)
var_to_restore = []  # variable objects (not names) to hand to the Saver
var = tf.global_variables()
for val in var:
    print(val.name)
    # Variable names have the form "<name>:<output index>" (e.g. "v1:0"),
    # while checkpoint keys carry no ":<index>" suffix, so strip it first.
    if val.name.rsplit(":", 1)[0] in saved_names:
        var_to_restore.append(val)
print(var_to_restore)
# Saver restricted to the variables that were matched against the checkpoint,
# so only that part of the model is reloaded.
saver1 = tf.train.Saver(var_to_restore)
with tf.Session() as sess:
    # NOTE(review): tf.global_variables_initializer() is deliberately not run
    # here (it was commented out in the original). Variables that are not in
    # var_to_restore are therefore presumably left uninitialized in this
    # session; initialize them before use if they are needed.
    # NOTE(review): this relative path is independent of `checkpoint_path`
    # used to list the saved tensor names; confirm both point at the same
    # checkpoint.
    saver1.restore(sess, "checkpoint/model_test-1")
    # To write the checkpoint again after restoring:
    #   saver.save(sess, "checkpoint/model_test", global_step=1)