回到首页

tensorflow Checkpoint实践

import tensorflow as tf
import numpy as np
import sys

# 定义模型结构
class Net(tf.keras.Model):
	"""Minimal linear model: a single 5-unit Dense layer."""

	def __init__(self):
		super().__init__()
		# The attribute name `l1` matters: the checkpoint keys used later
		# (net/l1/kernel, net/l1/bias) are derived from it.
		self.l1 = tf.keras.layers.Dense(5)

	def call(self, x):
		"""Forward pass: apply the linear layer to `x`."""
		return self.l1(x)

# Instantiate the model and save its initial (untrained) weights.
net=Net()
net.save_weights('easy_checkpoint')

# Toy data: 10x1 inputs, 10x5 labels, served forever in batches of 2.
def toy_dataset():
	"""Return an endless `tf.data.Dataset` of {'x', 'y'} batches of size 2."""
	xs = tf.range(10.)[:, None]            # shape (10, 1)
	ys = xs * 5. + tf.range(5.)[None, :]   # shape (10, 5) via broadcasting
	ds = tf.data.Dataset.from_tensor_slices(dict(x=xs, y=ys))
	return ds.repeat().batch(2)

# One optimization step.
def train_step(net,example,optimizer):
	"""Run a single gradient update of `net` on `example`; return the loss.

	`example` is a dict with 'x' (inputs) and 'y' (targets); the loss is the
	mean absolute error between prediction and target.
	"""
	with tf.GradientTape() as tape:
		prediction = net(example['x'])
		loss = tf.reduce_mean(tf.abs(prediction - example['y']))
	params = net.trainable_variables
	grads = tape.gradient(loss, params)
	optimizer.apply_gradients(zip(grads, params))
	return loss

# 构建训练过程所需的参数,检查点和检查点管理器
opt=tf.keras.optimizers.Adam(0.1)
dataset=toy_dataset()
iterator=iter(dataset)
ckpt=tf.train.Checkpoint(step=tf.Variable(1),optimizer=opt,net=net,iterator=iterator)
manager=tf.train.CheckpointManager(ckpt,'./tf_ckpts',max_to_keep=3)

# Train and periodically checkpoint.
def train_and_checkpoint(net, manager, num_steps=50, save_every=10):
	"""Restore the latest checkpoint (if any), then train `net`.

	Args:
		net: model forwarded to `train_step`.
		manager: `tf.train.CheckpointManager` used for restore and save.
		num_steps: number of training steps to run (default 50, matching the
			original hard-coded loop).
		save_every: save a checkpoint every this many global steps (default 10).

	NOTE(review): this function also relies on the module-level `ckpt`,
	`iterator` and `opt` objects rather than deriving them from its arguments.
	"""
	# Restoring from a non-existent checkpoint is a no-op, so this is safe
	# on the very first run.
	ckpt.restore(manager.latest_checkpoint)
	if manager.latest_checkpoint:
		print('Restored from {}'.format(manager.latest_checkpoint))
	else:
		print('Initializing from scratch')
	for _ in range(num_steps):
		example = next(iterator)
		loss = train_step(net, example, opt)
		ckpt.step.assign_add(1)
		if int(ckpt.step) % save_every == 0:
			save_path = manager.save()
			print('Saved checkpoint for step {}: {}'.format(int(ckpt.step), save_path))
			print('loss {:1.2f}'.format(loss.numpy()))

# Kick off training; re-running the script resumes from the latest checkpoint.
train_and_checkpoint(net,manager)
以上代码演示了在训练过程中保存和使用检查点。下面介绍如何加载检查点并读取网络内容:
import tensorflow as tf

# Restore just the Dense layer's bias by rebuilding the matching dependency
# path (net -> l1 -> bias) out of bare Checkpoint wrappers.
to_restore=tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # all zeros before restoring
fake_layer=tf.train.Checkpoint(bias=to_restore)
fake_net=tf.train.Checkpoint(l1=fake_layer)
new_root=tf.train.Checkpoint(net=fake_net)
status=new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # now holds the checkpointed bias values
# Check that every object we declared found a match in the checkpoint.
status.assert_existing_objects_matched()
# assert_consumed() would additionally require the whole checkpoint to be
# used; it stays commented out because this partial graph does not consume it.
# status.assert_consumed()

# Deferred (lazy) restoration: attaching a new variable to an object graph
# that has already been restored triggers the load at assignment time.
deferred_restore=tf.Variable(tf.zeros([1,5]))
print(deferred_restore.numpy())  # zeros before attachment
fake_layer.kernel=deferred_restore  # assignment restores net/l1/kernel into it
print(deferred_restore.numpy())  # now the checkpointed kernel values

# Inspect the checkpoint's raw contents: every saved key, its shape and dtype.
reader=tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key=reader.get_variable_to_shape_map()
dtype_from_key=reader.get_variable_to_dtype_map()
print(sorted(shape_from_key.keys()))
print(sorted(dtype_from_key.keys()))

# Look up a single variable by its checkpoint key and dump its value.
key='net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print('Shape:',shape_from_key[key])
print('Dtype:',dtype_from_key[key].name)
print(reader.get_tensor(key))


# Object tracking: Checkpoint follows variables that are reachable through
# plain Python lists and dicts assigned to its attributes.
root = tf.train.Checkpoint()
root.listed = [tf.Variable(1.)]
root.listed.append(tf.Variable(2.))
root.mapped = {'one': root.listed[0]}
root.mapped['two'] = root.listed[1]
path = root.save('./tf_list_example')

loaded = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert v2.numpy() == 0.
loaded.mapped = {'two': v2}
loaded.restore(path)
assert v2.numpy() == 2.   # matched by key 'two', restored immediately

# Lazy restoration: the list is still empty here...
loaded.listed = []
print(loaded.listed)
v1 = tf.Variable(0.)
loaded.listed.append(v1)  # ...but appending triggers the deferred load
assert v1.numpy() == 1.

参考链接:Training checkpoints

本文创建于2022.10.5/0.8,修改于2022.10.5/0.8