tf.keras.models.load_model() 添加custom_objects参数仍然出错ValueError: Unknown loss func

论坛 期权论坛 编程之家     
选择匿名的用户   2021-6-1 11:53   5215   0

tf.keras.models.load_model() 加载模型时,有一个参数compile,默认是True,会自动compile。但用户自定义的loss或者metric无法被识别。这时候需要把自定义的函数通过custom_objects传进去。

但是,使用custom_objects参数传入自定义函数,可以解决'.h5'格式的模型。

不能解决savedmodel格式的模型。

可以关注tensorflow github的issue,目前好像还没解决。

https://github.com/tensorflow/tensorflow/pull/34048

fails if the loaded model is a SavedModel (saved with format="tf") 
# this fails if the loaded model is a SavedModel (saved with format="tf") 
from tensorflow.keras import models
model = models.load_model("/path/to/tf_model", custom_objects={"custom_loss": custom_loss})

解决方法:把参数compile设置为False,手动compile。

def mloss(a,b):
    return a-b

def mMetric(a,b):
    return a-b

model_dir = 'path/to/your/tf-format/model'
model = tf.keras.models.load_model(model_dir, compile=False)
model.compile(optimizer = tf.keras.optimizers.Adam(lr = 1e-4), 
            loss = mloss, 
            metrics = ['accuracy', mMetric])

分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:3875789
帖子:775174
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP