Tensorflow:
保存模型:
saver = tf.train.Saver()
saver.save(sess, "{}.ckpt")
加载模型:
saver = tf.train.import_meta_graph('{}.ckpt.meta')
saver.restore(sess, {}.ckpt)
修改模型某些参数
with tf.Session(graph=tf.Graph()) as sess: #添加了graph=tf.Graph()可以加载多个模型
'''
new_var_list=[] #新建一个空列表存储更新后的Variable变量
for var_name, _ in tf.train.list_variables(args.checkpoint_path): #得到checkpoint文件中所有的参数(名字,形状)元组
var = tf.train.load_variable(args.checkpoint_path, var_name) #得到上述参数的值
new_var = var*change #修改参数值(var)
print('change %s' % var_name)
changed_var = tf.Variable(new_var, name=var_name) #使用加入前缀的新名称重新构造了参数
new_var_list.append(changed_var) #把赋予新名称的参数加入空列表
print('starting to write new checkpoint !')
saver = tf.train.Saver(var_list=new_var_list) #构造一个保存器
sess.run(tf.global_variables_initializer()) #初始化一下参数(这一步必做)
model_name = 'mnist_mlp_chg' #构造一个保存的模型名称
checkpoint_path = os.path.join(args.new_checkpoint_path, model_name) #构造一下保存路径
saver.save(sess, checkpoint_path) #直接进行保存
print("done !")
'''
saver = tf.train.import_meta_graph(ori_ckpt_path)
saver.restore(sess, ori_ckpt_path)
w2 = tf.get_default_graph().get_tensor_by_name("w2:0")
b2 = tf.get_default_graph().get_tensor_by_name("b2:0")
print("【Original W】", sess.run(w2))
accuracy = tf.get_collection("accuracy")
x = tf.get_default_graph().get_tensor_by_name("x:0")
y_ = tf.get_default_graph().get_tensor_by_name("y_labels:0") # 真实标签
print("\n【Original testing accuracy:】", sess.run(accuracy, feed_dict={x: X_test,y_: y_test}))
array_w2 = sess.run(w2)
array_b2 = sess.run(b2)
new_w2 = array_w2 * change_rate
new_b2 = array_b2 * change_rate
w2 = tf.assign(w2, new_w2)
b2 = tf.assign(b2, new_b2)
sess.run(w2)
sess.run(b2)
saver.save(sess, new_checkpoint_path) # 直接进行保存
Keras:
保存模型:
model = Sequential()
model.add(...)
model.save('{}.h5')
加载模型:
from keras.models import load_model
model=load_model('{}.h5')
修改模型参数:
Keras的模型是用hdf5存储的,如果想要查看模型,keras提供了get_weights 的函数可以查看:
for layer in model.layers:
weights = layer.get_weights() # list of numpy array
而通过hdf5 模块也可以读取:hdf5的数据结构主要是File - Group - Dataset三级,具体操作API可以看官方文档。weights的tensor保存在Dataset的value中,而每一集都会有attrs保存各网络层的属性:
import h5py
def print_keras_wegiths(weight_file_path):
f = h5py.File(weight_file_path) # 读取weights h5文件返回File类
try:
if len(f.attrs.items()):
print("{} contains: ".format(weight_file_path))
print("Root attributes:")
for key, value in f.attrs.items():
print(" {}: {}".format(key, value)) # 输出储存在File类中的attrs信息,一般是各层的名称
for layer, g in f.items(): # 读取各层的名称以及包含层信息的Group类
print(" {}".format(layer))
print(" Attributes:")
for key, value in g.attrs.items(): # 输出储存在Group类中的attrs信息,一般是各层的weights和bias及他们的名称
print(" {}: {}".format(key, value))
print(" Dataset:")
for name, d in g.items(): # 读取各层储存具体信息的Dataset类
print(" {}: {}".format(name, d.value.shape)) # 输出储存在Dataset中的层名称和权重,也可以打印dataset的attrs,但是keras中是空的
print(" {}: {}".format(name. d.value))
finally:
f.close()
而如果想修改某个值,则需要通过新建File类,然后用create_group, create_dataset函数将信息重新写入,具体操作可以查看write hdf5 file
参考:
https://www.christopherlovell.co.uk/blog/2016/04/27/h5py-intro.html
|