PyTorch Lightning 使用 Optuna 进行超参数优化

PyTorch Lightning Hyperparameter Optimization with Optuna

PyTorch Lightning 使用 Optuna 进行超参数优化
作者 | Ideogram 提供图片

PyTorch Lightning 是近年来发布的一个高级替代方案,用于替代经典的 PyTorch 深度学习建模库。它简化了模型的训练、验证和部署过程。在进行超参数优化时,即寻找能最大化给定任务性能的最佳模型参数集,Optuna 可以是一个与 PyTorch Lightning 结合的强大工具,因为它易于集成,并提供了高效的搜索算法,可以在大量可能的配置中找到模型的最佳设置。

本文将介绍如何联合使用 PyTorch Lightning 和 Optuna 来指导深度学习模型的超参数优化过程。建议您具备实用的神经网络构建和训练基础知识,最好是使用 PyTorch。

分步流程

过程首先需要安装和导入一系列必要的库和模块,包括 PyTorch Lightning 和 Optuna。初始安装过程可能需要一些时间来完成。

接下来,进行大量导入

在使用 PyTorch Lightning 构建神经网络模型时,设置随机种子以保证结果的可复现性是一种常见做法。您可以在代码开头,导入语句之后,通过添加 pl.seed_everything(42) 来实现这一点。

接下来,我们通过创建一个继承自 pl.LightningModule 的类来定义神经网络模型架构:这是 Lightning 中与 PyTorch 的 Module 类相对应的部分。

我们将类命名为 MNISTClassifier,因为它是一个简单的前馈神经网络分类器,我们将使用 MNIST 数据集对其进行训练,以进行低分辨率图像分类。它包含少量线性层,前面有一个展平二维图像数据的输入层,中间是 ReLU 激活函数。该类还定义了用于前向传播的 forward() 方法,以及模拟训练、测试和验证步骤的其他方法。

在新建的类之外,下面的函数也将有助于计算一组多项预测的平均准确率

以下函数建立了一个数据准备管道,我们稍后将将其应用于 MNIST 数据集。它将数据集转换为张量,根据该数据集的先验已知统计数据对其进行归一化,下载并将训练数据分割为训练集和验证集,并为三个数据子集中的每个子集创建一个 DataLoader 对象。

objective() 函数是 Optuna 为定义超参数优化框架提供的核心元素。我们通过首先定义搜索空间来定义它:即需要尝试的超参数及其可能的值。这可以包括与神经网络层相关的架构超参数,也可以包括与训练算法相关的超参数,例如学习率和用于对抗过拟合等问题的 dropout 率。

在此函数中,我们尝试为每个可能的超参数设置初始化和训练模型,并包含一个用于提前停止的回调,以防模型过早稳定。与标准的 PyTorch 一样,Trainer 用于模拟训练过程。

另一个用于管理优化过程的 Optuna 关键函数是 run_optimization,在该函数中,我们根据前面优化函数中定义的规范,指示要运行的随机试验次数。

在完成超参数优化过程并确定最佳模型配置后,需要另一个函数来接收这些结果,并在测试集上评估该模型的性能以进行最终验证。

现在我们有了所有需要的类和函数,最后通过一个将它们整合在一起的演示来完成。

以下是工作流程的概要总结

  1. 运行一个 study(一组 Optuna 实验),指定试验次数
  2. 对训练过程中的优化过程进行一系列可视化
  3. 找到最佳模型后,将其暴露给测试数据以进一步评估

总结

本文介绍了如何将 PyTorch Lightning 和 Optuna 结合使用,为神经网络模型执行高效且有效的超参数优化。Optuna 提供了增强的模型调优算法,而 PyTorch Lightning 则构建在 PyTorch 之上,以更高级的方式进一步简化神经网络建模过程。

暂无评论。

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。