记录模块的简单使用模型存储与恢复
原文地址:[ https://www.cnblogs.com/mbcbyq-2137/p/10044837.html] ( https://www.cnblogs.com/mbcbyq-2137/p/10044837.html) [2018-11-30 16:15]
虽然说 TensorFlow 2.0 即将问世,但是有一些模块的内容却是不大变化的。其中就有 tf.saved_model 模块,主要用于模型的存储和恢复。为了防止学习记录文件丢失或者蠢笨的脑子直接遗忘掉这部分内容,在此做点简单的记录,以便将来查阅。
最近为了一个课程作业,不得已涉及到关于图像超分辨率恢复的内容,不得不准备随时存储训练的模型,只好再回过头来瞄一眼 TensorFlow 文档,真是太痛苦了。
tf.saved_model 模块下面有很多文件和函数,精力有限,只好选择于自己有用的东西来看,可能并不全面,望日后补上。
其中最重要的就是该模块下的一个类:tf.saved_model.builder.SavedModelBuilder
tf.saved_model.builder.SavedModelBuilder:
# 构造函数
.__init__(export_dir)
"""
作用:
创建一个保存模型的实例对象
参数:
export_dir: 模型导出路径,由于 TensorFlow 会在你指定的路径上创建文件夹和文件,所以指定的路径最后不需要带 /,
例如:export_dir='/home/***/saved_model' 即可,最后不需要加上 /
"""
# 方法
# 1
.add_meta_graph_and_variables(sess, tags, signature_def_map=None, assets_collection=None,
clear_devices=False, main_op=None, strip_default_attrs=False, saver=None)
"""
作用:
保存会话对象中的 graph 和所有变量,具体描述可参见文档
参数:
sess: TensorFlow 会话对象,用于保存元图和变量
tags: 用于保存元图的标记集(如果存在多个图对象,需要设置保证每个图标签不一样),是一个列表
signature_def_map: 一个字典,保存模型时传入的参数,key 可以是字符串,也可以是 tf.saved_model.signature_constants 文件下预定义的变量,
值为 signatureDef protobuf(protobuf 是一种结构化的数据存储格式)
assets_collection: 略
clear_devices: 如果需要清除默认图上的设备信息,则设置为 true
main_op: 这个参数包括后面一系列与其相关的东西没有弄明白
strip_default_attrs: 如果设置为 True,将从 NodeDefs 中删除默认值属性
saver: tf.train.Saver 的一个实例,用于导出元图并保存变量
"""
# 2
.add_meta_graph()
"""
作用:
其除了没有 sess 参数以外,其他参数和 .add_meta_graph_and_variables() 一模一样
调用此方法之前必须先调用 .add_meta_graph_and_variables() 方法
"""
# 3
.save(as_text=False)
"""
作用:
将内建的 savedModel protobuf 写入磁盘
"""
除了这个最重要的类以外,tf.saved_model 模块还提供了一些方便构建 builder 和加载模型的函数方法。
# 1
tf.saved_model.utils.build_tensor_info(tensor)
"""
作用:
构建 TensorInfo protobuf,根据输入的 tensor 构建相应的 protobuf,返回的 TensorInfo 中包含输入 tensor 的 name,shape,dtype 信息
参数:
tensor: Tensor 或 SparseTensor
"""
# 2
tf.saved_model.signature_def_utils.build_signature_def(inputs=None, outputs=None, method_name=None)
"""
作用:
构建 SignatureDef protobuf,并返回 SignatureDef protobuf
参数:
inputs: 一个字典,键为字符串类型,值为关于 tensor 的信息,也就是上述的 .build_tensor_info() 函数返回的 TensorInfo protobuf
outputs: 一个字典,同上
method_name: SignatureDef 名称
"""
# 3
tf.saved_model.utils.get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None)
"""
作用:
根据一个 TensorInfo protobuf 解析出一个 tensor
参数:
tensor_info: 一个 TensorInfo protobuf
graph: tensor 所存在的 graph,参数为 None 时,使用默认图
import_scope: 给 tensor 的 name 加上前缀
"""
# 4
tf.saved_model.loader.load(sess, tags, export_dir, import_scope=None, **saver_kwargs)
"""
作用:
加载已存储的模型
参数:
sess: 用于恢复模型的 tf.Session() 对象
tags: 用于标识 MetaGraphDef 的标记,应该和存储模型时使用的此参数完全一致
export_dir: 模型存储路径
import_scope: 加前缀
"""
除了这些以外,还有一些 TensorFlow 为了方便而预定义的一些变量,这些变量完全可以使用自定义字符串代替,不再赘述。详情: https://tensorflow.google.cn/api_docs/python/tf/saved_model
如果只看这些内容的话,确实会使人产生巨大的疑惑,下面是具体实践的例子:
import tensorflow as tf
from tensorflow import saved_model as sm
# 首先定义一个极其简单的计算图
X = tf.placeholder(tf.float32, shape=(3, ))
scale = tf.Variable([10, 11, 12], dtype=tf.float32)
y = tf.multiply(X, scale)
# 在会话中运行
with tf.Session() as sess:
sess.run(tf.initializers.global_variables())
value = sess.run(y, feed_dict={X: [1., 2., 3.]})
print(value)
# 准备存储模型
path = '/home/×××/tf_model/model_1'
builder = sm.builder.SavedModelBuilder(path)
# 构建需要在新会话中恢复的变量的 TensorInfo protobuf
X_TensorInfo = sm.utils.build_tensor_info(X)
scale_TensorInfo = sm.utils.build_tensor_info(scale)
y_TensorInfo = sm.utils.build_tensor_info(y)
# 构建 SignatureDef protobuf
SignatureDef = sm.signature_def_utils.build_signature_def(
inputs={'input_1': X_TensorInfo, 'input_2': scale_TensorInfo},
outputs={'output': y_TensorInfo},
method_name='what'
)
# 将 graph 和变量等信息写入 MetaGraphDef protobuf
# 这里的 tags 里面的参数和 signature_def_map 字典里面的键都可以是自定义字符串,TensorFlow 为了方便使用,不在新地方将自定义的字符串忘记,可以使用预定义的这些值
builder.add_meta_graph_and_variables(sess, tags=[sm.tag_constants.TRAINING],
signature_def_map={sm.signature_constants.CLASSIFY_INPUTS: SignatureDef}
)
# 将 MetaGraphDef 写入磁盘
builder.save()
这样我们就把模型整体存储到了磁盘中,而且我们将三个变量 X, scale, y 全部序列化后存储到了其中,所以恢复模型时便可以将他们完全解析出来:
import tensorflow as tf
from tensorflow import saved_model as sm
# 需要建立一个会话对象,将模型恢复到其中
with tf.Session() as sess:
path = '/home/×××/tf_model/model_1'
MetaGraphDef = sm.loader.load(sess, tags=[sm.tag_constants.TRAINING], export_dir=path)
# 解析得到 SignatureDef protobuf
SignatureDef_d = MetaGraphDef.signature_def
SignatureDef = SignatureDef_d[sm.signature_constants.CLASSIFY_INPUTS]
# 解析得到 3 个变量对应的 TensorInfo protobuf
X_TensorInfo = SignatureDef.inputs['input_1']
scale_TensorInfo = SignatureDef.inputs['input_2']
y_TensorInfo = SignatureDef.outputs['output']
# 解析得到具体 Tensor
# .get_tensor_from_tensor_info() 函数中可以不传入 graph 参数,TensorFlow 自动使用默认图
X = sm.utils.get_tensor_from_tensor_info(X_TensorInfo, sess.graph)
scale = sm.utils.get_tensor_from_tensor_info(scale_TensorInfo, sess.graph)
y = sm.utils.get_tensor_from_tensor_info(y_TensorInfo, sess.graph)
print(sess.run(scale))
print(sess.run(y, feed_dict={X: [3., 2., 1.]}))
# �
- 原文作者:知识铺
- 原文链接:https://index.zshipu.com/geek/post/%E4%BA%92%E8%81%94%E7%BD%91/%E8%AE%B0%E5%BD%95%E6%A8%A1%E5%9D%97%E7%9A%84%E7%AE%80%E5%8D%95%E4%BD%BF%E7%94%A8%E6%A8%A1%E5%9E%8B%E5%AD%98%E5%82%A8%E4%B8%8E%E6%81%A2%E5%A4%8D/
- 版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议进行许可,非商业转载请注明出处(作者,原文链接),商业转载请联系作者获得授权。
- 免责声明:本页面内容均来源于站内编辑发布,部分信息来源互联网,并不意味着本站赞同其观点或者证实其内容的真实性,如涉及版权等问题,请立即联系客服进行更改或删除,保证您的合法权益。转载请注明来源,欢迎对文章中的引用来源进行考证,欢迎指出任何有错误或不够清晰的表达。也可以邮件至 sblig@126.com