【Python深度学习系列】Keras回调函数Callbacks使用详解-训练过程可视化、早停、保存恢复(案例+源码)

这是我的第320篇原创文章。

一、引言

keras.callbacks

回调函数是一个函数的合集,会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。你可以传递一个列表的回调函数(作为 callbacks 关键字参数)到 Sequential 或 Model 类型的 .fit() 方法。在训练时,相应的回调函数的方法就会被在各自的阶段被调用。

这里有两个关键的点:

(1)状态和统计:其实就是我们希望模型在训练过程中需要从过程中获取什么信息,比如我的损失loss,准确率accuracy等信息就是训练过程中的状态与统计信息;再比如我希望每一个epoch结束之后打印一些相应的自定义提示信息,这也是状态信息。

(2)各自的阶段:模型的训练一般是分为多少个epoch,然后每一个epoch又分为多少个batch,所以这个阶段可以是在每一个epoch之后执行回调函数,也可以是在每一个batch之后执行回调函数。

虽然我们称之为回调“函数”,但事实上Keras的回调函数是一个类,回调函数只是习惯性称呼

keras.callbacks.Callback()

这是回调函数的抽象类,定义新的回调函数必须继承自该类。

系统预定义的回调函数:

图片

二、实现过程

2.1 History

History(训练可视化

keras.callbacks.History()

该回调函数在Keras模型上会被自动调用,History对象即为fit方法的返回值,可以使用history中的存储的acc和loss数据对训练过程进行可视化画图。

示例代码:

class PrintDot(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs):
        if epoch % 100 == 0:
            print('')
        print('.', end='')



EPOCHS = 1000
model = build_model()
history = model.fit(normed_train_data, train_labels,epochs=EPOCHS, validation_split=0.2, verbose=0,callbacks=[PrintDot()])
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
print('\n', hist.tail())
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch

plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Mean Abs Error [MPG]')
plt.plot(hist['epoch'], hist['mae'],
         label='Train Error')
plt.plot(hist['epoch'], hist['val_mae'],
         label='Val Error')
plt.ylim([0, 5])
plt.legend()

plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Mean Square Error [$MPG^2$]')
plt.plot(hist['epoch'], hist['mse'],
         label='Train Error')
plt.plot(hist['epoch'], hist['val_mse'],
         label='Val Error')
plt.ylim([0, 20])
plt.legend()
plt.show()

通过history绘制训练过程损失函数的变化:

图片

定义新的回调函数PrintDot,继承keras.callbacks.Callback,传递给fit函数中的callbacks,实现每个完成的时期打印一个点来显示训练进度:

图片

hist:

图片

2.2 EarlyStopping

EarlyStopping

keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=0, verbose=0, mode=’auto’)

当监测值不再改善时,该回调函数将中止训练。

定义回调函数EarlyStopping,传递给fit函数中的callbacks,实现训练早停,代码示例:

model = build_model()
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
history = model.fit(normed_train_data, train_labels, epochs=EPOCHS,validation_split=0.2, verbose=0, callbacks=[early_stop, PrintDot()])

2.3 ModelCheckpoint

ModelCheckpoint

keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_loss',
    verbose=0,
    save_best_only=True,
    save_weights_only=False,
    mode='auto',
    period=1
)

该回调函数将在每个epoch后保存模型到filepath。

定义回调函数ModelCheckpoint,传递给fit函数中的callbacks,实现模型的保存与恢复,代码示例:代码示例:

filepath = "model_{epoch:02d}-{val_mse:.2f}.h5"
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=filepath,
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
    save_weights_only=True,
    period=3
)
history = model.fit(normed_train_data, train_labels, epochs=EPOCHS,validation_split=0.2, verbose=0, callbacks=[checkpoint, PrintDot()])

结果:

图片

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-11 12:08:01       53 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-11 12:08:01       56 阅读
  3. 在Django里面运行非项目文件

    2024-07-11 12:08:01       46 阅读
  4. Python语言-面向对象

    2024-07-11 12:08:01       57 阅读

热门阅读

  1. 进阶版智能家居系统Demo[C#]:整合AI和自动化

    2024-07-11 12:08:01       20 阅读
  2. 【C语言】C语言可以做什么?

    2024-07-11 12:08:01       20 阅读
  3. Windows图形界面(GUI)-SDK-C/C++ - 按钮(button)

    2024-07-11 12:08:01       23 阅读
  4. [C++]继承

    2024-07-11 12:08:01       20 阅读
  5. 小笔记(1)

    2024-07-11 12:08:01       18 阅读
  6. Android知识收集

    2024-07-11 12:08:01       18 阅读
  7. 2024前端面试真题【Vue篇】

    2024-07-11 12:08:01       18 阅读