Original author: 章立. First published on this site with the author's authorization.
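Checkpoints serve two main purposes:
- If something unexpected happens during training, a checkpoint lets you recover quickly.
- Checkpoints make it possible to stop early.
In Keras, checkpoints are written with Model.save_weights. A checkpoint produced by Model.save_weights must be loaded with Model.load_weights; it cannot be loaded with tf.train.Checkpoint.restore.
The API documentation recommends using tf.train.Checkpoint to:
- Define the checkpoint creation policy
- Manage checkpoint restoration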
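The two points above can be sketched with tf.train.CheckpointManager, which wraps a tf.train.Checkpoint. This is only a minimal illustration, not the article's original code: the variables step and weight, the temporary directory, and the choice of keeping three checkpoints are all assumptions made for the example.

```python
import tempfile

import tensorflow as tf

# Hypothetical state to track: a step counter and one weight variable.
step = tf.Variable(0, dtype=tf.int64)
weight = tf.Variable([1.0, 2.0])
ckpt = tf.train.Checkpoint(step=step, weight=weight)

# Creation policy: write into a temporary directory, keep only the 3 newest.
ckpt_dir = tempfile.mkdtemp()
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

for _ in range(5):
    step.assign_add(1)
    weight.assign_add([0.5, 0.5])
    manager.save(checkpoint_number=int(step.numpy()))

kept = manager.checkpoints          # paths of the checkpoints still on disk
latest = manager.latest_checkpoint  # path of the newest checkpoint

# Recovery: clobber the live values, then restore the newest checkpoint.
step.assign(0)
weight.assign([0.0, 0.0])
ckpt.restore(latest)
```

After the restore, step and weight hold the values from the fifth save again, and only the three most recent checkpoint files remain in the directory.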
Before using checkpoints, we first need to define a simple network and a simple input, built the same way as introduced in quickstart2.
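The article's own network code is not reproduced here, so the following is only a sketch of what such a setup might look like. The names Net and toy_dataset, the single Dense(5) layer, and the synthetic data are assumptions chosen to match the variable shapes printed later in the article.

```python
import tensorflow as tf


class Net(tf.keras.Model):
    """A one-layer linear model (hypothetical stand-in for the article's network)."""

    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)


def toy_dataset():
    """Yields batches of 2 examples: x has shape (2, 1), y has shape (2, 5)."""
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    ds = tf.data.Dataset.from_tensor_slices({"x": inputs, "y": labels})
    return ds.repeat().batch(2)


net = Net()
batch = next(iter(toy_dataset()))
prediction = net(batch["x"])  # the first call creates the Dense layer's variables
```

Note that the Dense layer has no variables until net is called for the first time, which matters for the delayed-restoration behaviour discussed in Q3.1.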
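The first thing to be clear about is that TensorFlow's main objects are things like tf.Variable, which raises three questions:
- How do you make an object checkpointable?
- What kinds of objects can be checkpointed?
- How do you restore from a checkpoint into objects?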
Class constructor: __init__(**kwargs)
loss 2.08
loss 0.98
loss 2.09
loss 3.38
loss 4.45
loss 2.53
loss 1.36
loss 1.28
loss 2.03
loss 2.66
loss 2.78
loss 1.96
loss 1.24
loss 0.82
loss 0.78
loss 2.79
loss 2.07
loss 1.34
loss 0.85
loss 0.68
loss 2.61
loss 1.78
loss 1.08
loss 0.75
loss 1.13
loss 2.37
loss 1.46
loss 0.68
loss 0.55
loss 1.33
loss 2.20
loss 1.37
loss 0.59
loss 0.49
loss 1.13
loss 2.07
loss 1.30
loss 0.54
loss 0.42
loss 1.05
loss 1.90
loss 1.20
loss 0.59
loss 0.32
loss 0.83
loss 1.76
loss 1.16
loss 0.60
loss 0.26
loss 0.71
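Q1 How do you make an object checkpointable?
Pass the object to the Checkpoint constructor; objects can also be passed via listed or mapped.
Q2 What kinds of objects can be checkpointed?
Only objects derived from a trackable type can be checkpointed; passing anything else raises the error shown here: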
ValueError Traceback (most recent call last)
<ipython-input-4-ce172f4ab95a> in <module>()
4 step = tf.Variable(1), # tracks the iteration count
5 optimizer = opt, # tracks the optimizer state
----> 6 net = net # tracks the network state
7 )
~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py in __init__(self, **kwargs)
1777 "object should be trackable (i.e. it is part of the "
1778 "TensorFlow Python API and manages state), please open an issue.")
-> 1779 % (v,))
1780 setattr(self, k, v)
1781 self._save_counter = None # Created lazily for restore-on-create.
ValueError: `Checkpoint` was expecting a trackable object (an object derived from `TrackableBase`), got []. If you believe this object should be trackable (i.e. it is part of the TensorFlow Python API and manages state), please open an issue.
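The trackable requirement can be reproduced in a few lines. This is a minimal sketch, not the article's code: the traceback above shows an empty list being rejected in the author's TensorFlow version, and since the handling of lists may differ between versions, the example uses a plain Python integer as the non-trackable value.

```python
import tensorflow as tf

# Variables are trackable, so they are accepted as keyword arguments.
ok = tf.train.Checkpoint(step=tf.Variable(1), weight=tf.Variable([1.0, 2.0]))

# A plain Python integer holds no TensorFlow state and is rejected.
try:
    tf.train.Checkpoint(step=1)
    rejected = False
    message = ""
except ValueError as err:
    rejected = True
    message = str(err)
```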
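Q3 How do you restore from a checkpoint into objects?
Define objects with the same parameters, build a Checkpoint from them, then call its restore method to restore from the given path.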
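The restore steps just described might look like the sketch below. The variable names, values, and the temporary save location are assumptions for illustration; a real model would pass its network and optimizer instead of bare variables.

```python
import os
import tempfile

import tensorflow as tf

# Save: a hypothetical pair of variables stands in for a trained model.
saved = tf.train.Checkpoint(step=tf.Variable(7), weight=tf.Variable([1.5, 2.5]))
path = saved.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# Restore: build new objects with the same names, shapes, and dtypes ...
new_step = tf.Variable(0)
new_weight = tf.Variable([0.0, 0.0])
restored = tf.train.Checkpoint(step=new_step, weight=new_weight)

# ... then restore from the path returned by save().
status = restored.restore(path)
status.assert_consumed()  # raises if part of the checkpoint was left unmatched
```

The keyword names passed to the second Checkpoint (step, weight) have to match the ones used when saving, because objects are matched by their position in the object graph rather than by variable name.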
Q3.1 Why does it look as if net3 was not loaded successfully? Calling trainable_variables shows different results.
[<tf.Variable 'net/dense/kernel:0' shape=(1, 5) dtype=float32, numpy=
array([[4.503945 , 4.5462313, 4.852895 , 4.7684402, 4.9965386]],
<tf.Variable 'net/dense/bias:0' shape=(5,) dtype=float32, numpy=
array([3.375819 , 3.8449383, 2.7898471, 4.4520674, 4.1617193],
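The reason is that restore is lazy (delayed restorations): a Layer object defers creating its internal Variables until it is first called, so the checkpointed values are only assigned once those variables exist.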
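Delayed restoration can be observed without a real Layer. In this sketch, nested tf.train.Checkpoint objects stand in for a layer whose kernel has not been created yet; the object names and values are assumptions for illustration only.

```python
import os
import tempfile

import tensorflow as tf

# Save an object graph of the form root -> layer -> {kernel, bias}.
saved_layer = tf.train.Checkpoint(
    kernel=tf.Variable([[1.0, 2.0, 3.0]]),
    bias=tf.Variable([4.0, 5.0, 6.0]),
)
prefix = os.path.join(tempfile.mkdtemp(), "ckpt")
path = tf.train.Checkpoint(layer=saved_layer).save(prefix)

# Restore into a structure where only the bias exists so far,
# mimicking a layer whose kernel is created later.
bias = tf.Variable(tf.zeros([3]))
lazy_layer = tf.train.Checkpoint(bias=bias)
root = tf.train.Checkpoint(layer=lazy_layer)
root.restore(path)  # bias is restored now; the kernel restore is queued

late_kernel = tf.Variable(tf.zeros([1, 3]))
before = late_kernel.numpy().tolist()  # still zeros: not attached yet

lazy_layer.kernel = late_kernel        # attaching applies the queued restore
after = late_kernel.numpy().tolist()
```

The same thing happens with a Keras model: variables that do not exist at restore time keep their fresh initial values until the model is called once, at which point the checkpointed values are filled in.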
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x1837233860>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./tf_estimator_example/model.ckpt-10
WARNING:tensorflow:From /Users/ki/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py:1069: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:loss = 3.5265698, step = 11
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 12 vs previous value: 12. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 15 vs previous value: 15. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 18 vs previous value: 18. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:Saving checkpoints for 20 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Loss for final step: 33.14527.
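- Original author: 知识铺
- Original link: https://index.zshipu.com/geek/post/%E4%BA%92%E8%81%94%E7%BD%91/%E7%9A%84%E6%95%99%E7%A8%8B/
- Copyright notice: This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License. For non-commercial reposting, credit the source (author and original link); for commercial reposting, contact the author for permission.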