本文最后更新于:星期二, 八月 2日 2022, 9:32 晚上
使用LSTM训练最简单的IMDB影评分类任务,总结文本分类任务常见流程。
1. 模型训练和保存
1.1 训练结束时保存
训练模型,使用fit函数。fit函数的参数如下。
fit(
x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,
validation_split=0.0, validation_data=None, shuffle=True, class_weight=None,
sample_weight=None, initial_epoch=0, steps_per_epoch=None,
validation_steps=None, validation_batch_size=None, validation_freq=1,
max_queue_size=10, workers=1, use_multiprocessing=False
)
x:训练数据
y:训练标签
batch_size:批次大小,默认为32
validation_data:在每个epoch结束之时计算loss等其他模型性能指标,不用做训练。
epoch:训练轮次
verbose:输出的详细程度,为1则输出进度条,表明每个epoch训练完成度;为0则什么也不输出,为2则很啰嗦地输出所有信息
最后保存模型用model.save('xxx.h5')
,这里模型格式为HDF5,因此结尾为h5。
model.fit(X_train, y_train, validation_data=(X_test, y_test), epoch=10, batch_size=64)
scores = model.evaluate(X_test, y_test, verbose=0)
print("Accuracy: %.2f%%" % (scores[1]*100))
model.save('models/sentiment-lstm.h5')
1.2 在训练期间保存模型(以 checkpoints 形式保存)
您可以使用训练好的模型而无需从头开始重新训练,或在您打断的地方开始训练,以防止训练过程没有保存。tf.keras.callbacks.ModelCheckpoint
允许在训练的过程中和结束时回调保存的模型。
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# 创建一个保存模型权重的回调函数
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
# 使用新的回调函数训练模型
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images,test_labels),
callbacks=[cp_callback]) # 通过回调训练
# 这可能会生成与保存优化程序状态相关的警告。
# 这些警告(以及整个笔记本中的类似警告)是防止过时使用,可以忽略。
这将创建一个 TensorFlow checkpoint 文件集合,这些文件在每个 epoch 结束时更新
cp.ckpt.data-00001-of-00002
cp.ckpt.data-00000-of-00002
cp.ckpt.index
默认的 tensorflow 格式仅保存最近的5个 checkpoint 。
1.3 手动保存权重
不必等待epoch结束,通过执行save_weights
就可以生成ckpt文件。
# 保存权重
model.save_weights('./checkpoints/my_checkpoint')
2. 模型加载
2.1 从h5文件中恢复
# 重新创建完全相同的模型
model=load_model('models/sentiment-lstm.h5')
# 加载后重新编译模型,否则您将失去优化器的状态
model.compile(loss='binary_crossentropy',optimizer='adam', metrics=['accuracy'])
model.summary()
加载模型的时候,损失函数等参数需要重新设置。
2.2 从ckpt文件中断点续训
仅恢复模型的权重时,必须具有与原始模型具有相同网络结构的模型。
# 这个模型与ckpt保存的一样架构,只不过没经过fit训练
model = create_model()
# 加载权重
model.load_weights(checkpoint_path)
我们可以对回调函数增加一些新的设置,之前的回调函数每个epoch都覆盖掉之前的ckpt,现在我们想每过5个epoch保存一个新的断点:
# 在文件名中包含 epoch (使用 `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# 创建一个回调,每 5 个 epochs 保存模型的权重
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
period=5)
利用新的回调训练,并随后选择最新的断点文件:
# 使用新的回调训练模型
model.fit(train_images,
train_labels,
epochs=50,
callbacks=[cp_callback],
validation_data=(test_images,test_labels),
verbose=0)
# 选择新的断点
latest = tf.train.latest_checkpoint(checkpoint_dir)
>>> 'training_2/cp-0050.ckpt'
# 加载以前保存的权重
model.load_weights(latest)
record Python TensorFlow SaveModel
本博客所有文章除特别声明外,均采用 CC BY-SA 3.0协议 。转载请注明出处!