import tensorflow as tf
import numpy as np
import sys
# Model definition
class Net(tf.keras.Model):
    """Toy linear model: a single Dense layer with 5 output units."""

    def __init__(self):
        super().__init__()
        # Attribute name `l1` is load-bearing: the checkpoint key path
        # net/l1/... below depends on it.
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)
# Instantiate the model and save its weights in TF checkpoint format.
net=Net()
# Writes files with the prefix 'easy_checkpoint' into the current directory.
net.save_weights('easy_checkpoint')
# Dataset: 10x1 inputs and 10x5 labels, served as infinite batches of 2.
def toy_dataset():
    """Return an endless dataset of dicts {'x': (2, 1), 'y': (2, 5)}."""
    inputs = tf.range(10.)[:, None]                # shape (10, 1)
    labels = inputs * 5. + tf.range(5.)[None, :]   # shape (10, 5)
    ds = tf.data.Dataset.from_tensor_slices(dict(x=inputs, y=labels))
    return ds.repeat().batch(2)
# One optimization step
def train_step(net, example, optimizer):
    """Run a single gradient-descent step of `net` on `example`.

    `example` is a dict with keys 'x' (inputs) and 'y' (targets).
    Returns the mean-absolute-error loss for this batch.
    """
    with tf.GradientTape() as tape:
        prediction = net(example['x'])
        loss = tf.reduce_mean(tf.abs(prediction - example['y']))
    # Gradients are computed outside the tape context, after recording.
    trainable = net.trainable_variables
    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return loss
# Build the training state: optimizer, dataset iterator, checkpoint, manager.
opt=tf.keras.optimizers.Adam(0.1)
dataset=toy_dataset()
iterator=iter(dataset)
# The Checkpoint tracks the step counter, optimizer slots, model weights,
# and the iterator's position, so all of them survive a restart.
ckpt=tf.train.Checkpoint(step=tf.Variable(1),optimizer=opt,net=net,iterator=iterator)
# Keeps only the 3 most recent checkpoints under ./tf_ckpts.
manager=tf.train.CheckpointManager(ckpt,'./tf_ckpts',max_to_keep=3)
# Training loop with checkpoint restore/save
def train_and_checkpoint(net, manager):
    """Restore the latest checkpoint if one exists, then train for 50 steps,
    saving a checkpoint (and printing the loss) every 10 steps.

    NOTE(review): besides its parameters this relies on the module-level
    `ckpt`, `iterator` and `opt` defined above.
    """
    # restore(None) is a no-op, so this is safe on a fresh run.
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print('Restored from {}'.format(manager.latest_checkpoint))
    else:
        print('Initializing from scratch')
    for _ in range(50):
        batch = next(iterator)
        loss = train_step(net, batch, opt)
        ckpt.step.assign_add(1)
        if int(ckpt.step) % 10 == 0:
            path = manager.save()
            print('Saved checkpoint for step {}: {}'.format(int(ckpt.step), path))
            print('loss {:1.2f}'.format(loss.numpy()))

train_and_checkpoint(net, manager)
# Checkpoints are saved and consumed during training; below we load a
# checkpoint and inspect the network's contents.
import tensorflow as tf
# Recover the Dense layer's bias by rebuilding a matching dependency graph.
to_restore=tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # all zeros before the restore
# Mirror the path net -> l1 -> bias used when the checkpoint was written.
fake_layer=tf.train.Checkpoint(bias=to_restore)
fake_net=tf.train.Checkpoint(l1=fake_layer)
new_root=tf.train.Checkpoint(net=fake_net)
status=new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # now holds the restored bias values
# Check that every object we declared found a match in the checkpoint.
status.assert_existing_objects_matched()
# Stricter check: were ALL checkpointed values consumed? (Would fail here,
# since the checkpoint holds more than just this bias.)
# status.assert_consumed()
# Lazy (deferred) restoration: recover the Dense layer's kernel.
deferred_restore=tf.Variable(tf.zeros([1,5]))
print(deferred_restore.numpy())  # zeros: not yet attached to the tracked graph
# Attaching the variable to the tracked `fake_layer` triggers the pending restore.
fake_layer.kernel=deferred_restore
print(deferred_restore.numpy())  # now holds the restored kernel values
# Inspect the checkpoint's variable keys.
reader=tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key=reader.get_variable_to_shape_map()
dtype_from_key=reader.get_variable_to_dtype_map()
print(sorted(shape_from_key.keys()))
print(sorted(dtype_from_key.keys()))
# Look up one key's shape, dtype, and stored tensor value.
key='net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print('Shape:',shape_from_key[key])
print('Dtype:',dtype_from_key[key].name)
print(reader.get_tensor(key))
# Deferred loading in detail: Checkpoint also tracks Python lists and dicts.
save=tf.train.Checkpoint()
save.listed=[tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped={'one':save.listed[0]}
save.mapped['two']=save.listed[1]
save_path=save.save('./tf_list_example')
restore=tf.train.Checkpoint()
v2=tf.Variable(0.)
assert 0.==v2.numpy()
# Restoring through the dict path 'two' assigns the saved value into v2.
restore.mapped={'two':v2}
restore.restore(save_path)
assert 2.==v2.numpy()
# An empty list restores lazily: the value arrives when v1 is appended.
restore.listed=[]
print(restore.listed)
v1=tf.Variable(0.)
restore.listed.append(v1)
assert 1.==v1.numpy()
# Reference: TensorFlow guide "Training checkpoints"
# Created 2022.10.5/0.8, modified 2022.10.5/0.8