保存模型的参数Tensorflow/Keras

论坛 期权论坛 编程之家     
选择匿名的用户   2021-6-1 11:53   634   0

Tensorflow:

保存模型:

saver = tf.train.Saver()
saver.save(sess, "{}.ckpt")

加载模型:

saver = tf.train.import_meta_graph('{}.ckpt.meta')
saver.restore(sess, {}.ckpt)

修改模型某些参数

    with tf.Session(graph=tf.Graph()) as sess:  #添加了graph=tf.Graph()可以加载多个模型
        '''
        new_var_list=[] #新建一个空列表存储更新后的Variable变量
        for var_name, _ in tf.train.list_variables(args.checkpoint_path): #得到checkpoint文件中所有的参数(名字,形状)元组
            var = tf.train.load_variable(args.checkpoint_path, var_name) #得到上述参数的值
            new_var = var*change #修改参数值(var)
            print('change %s' % var_name)
            changed_var = tf.Variable(new_var, name=var_name) #使用加入前缀的新名称重新构造了参数
            new_var_list.append(changed_var) #把赋予新名称的参数加入空列表
        
        print('starting to write new checkpoint !')
        saver = tf.train.Saver(var_list=new_var_list) #构造一个保存器
        sess.run(tf.global_variables_initializer()) #初始化一下参数(这一步必做)
        model_name = 'mnist_mlp_chg' #构造一个保存的模型名称
        checkpoint_path = os.path.join(args.new_checkpoint_path, model_name) #构造一下保存路径
        saver.save(sess, checkpoint_path) #直接进行保存
        print("done !")
        '''
        saver = tf.train.import_meta_graph(ori_ckpt_path)
        saver.restore(sess, ori_ckpt_path)

        w2 = tf.get_default_graph().get_tensor_by_name("w2:0")
        b2 = tf.get_default_graph().get_tensor_by_name("b2:0")

        print("【Original W】", sess.run(w2))

        accuracy = tf.get_collection("accuracy")
        x = tf.get_default_graph().get_tensor_by_name("x:0")
        y_ = tf.get_default_graph().get_tensor_by_name("y_labels:0")  # 真实标签
        print("\n【Original testing accuracy:】", sess.run(accuracy, feed_dict={x: X_test,y_: y_test}))
        

        array_w2 = sess.run(w2)
        array_b2 = sess.run(b2)
        new_w2 = array_w2 * change_rate
        new_b2 = array_b2 * change_rate
        w2 = tf.assign(w2, new_w2)
        b2 = tf.assign(b2, new_b2)
        sess.run(w2)
        sess.run(b2)
        saver.save(sess, new_checkpoint_path)  # 直接进行保存

Keras:

保存模型:

model = Sequential()
model.add(...)
model.save('{}.h5')

加载模型:

from keras.models import load_model
model=load_model('{}.h5')

修改模型参数:

Keras的模型是用hdf5存储的,如果想要查看模型,keras提供了get_weights的函数可以查看:

for layer in model.layers:
    weights = layer.get_weights()  # list of numpy array

而通过hdf5模块也可以读取:hdf5的数据结构主要是File - Group - Dataset三级,具体操作API可以看官方文档。weights的tensor保存在Dataset的value中,而每一集都会有attrs保存各网络层的属性:

import h5py

def print_keras_wegiths(weight_file_path):
    f = h5py.File(weight_file_path)  # 读取weights h5文件返回File类
    try:
        if len(f.attrs.items()):
            print("{} contains: ".format(weight_file_path))
            print("Root attributes:")
        for key, value in f.attrs.items():
            print("  {}: {}".format(key, value))  # 输出储存在File类中的attrs信息,一般是各层的名称

        for layer, g in f.items():  # 读取各层的名称以及包含层信息的Group类
            print("  {}".format(layer))
            print("    Attributes:")
            for key, value in g.attrs.items(): # 输出储存在Group类中的attrs信息,一般是各层的weights和bias及他们的名称
                print("      {}: {}".format(key, value))  

            print("    Dataset:")
            for name, d in g.items(): # 读取各层储存具体信息的Dataset类
                print("      {}: {}".format(name, d.value.shape)) # 输出储存在Dataset中的层名称和权重,也可以打印dataset的attrs,但是keras中是空的
                print("      {}: {}".format(name. d.value))
    finally:
        f.close()

而如果想修改某个值,则需要通过新建File类,然后用create_group, create_dataset函数将信息重新写入,具体操作可以查看write hdf5 file

参考:

https://www.christopherlovell.co.uk/blog/2016/04/27/h5py-intro.html

分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:3875789
帖子:775174
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP