如何在 Python 中保存梯度提升模型和 XGBoost

XGBoost 可以用来利用梯度提升算法为表格数据创建性能最佳的模型。

一旦训练完成,将模型保存到文件以供将来在新测试和验证数据集以及全新数据上进行预测,这通常是一种好的做法。

在本帖中,您将了解如何使用标准的 Python pickle API 将 XGBoost 模型保存到文件。

完成本教程后,您将了解:

  • 如何使用 pickle 保存并稍后加载训练好的 XGBoost 模型。
  • 如何使用 joblib 保存并稍后加载训练好的 XGBoost 模型。

通过我的新书《XGBoost With Python启动您的项目,其中包括所有示例的分步教程Python 源代码文件。

让我们开始吧。

  • 2017 年 1 月更新:已更新以反映 scikit-learn API 0.18.1 版本中的更改。
  • **2018 年 3 月更新**:添加了下载数据集的备用链接,因为原始链接似乎已被删除。
  • 更新 2019 年 10 月:已更新为直接使用 Joblib API。
How to Save Gradient Boosting Models with XGBoost in Python

如何在 Python 中保存梯度提升模型和 XGBoost
照片由 Keoni Cabral 提供,保留部分权利。

在 Python 中使用 XGBoost 需要帮助吗?

参加我的免费 7 天电子邮件课程,探索 xgboost(含示例代码)。

立即点击注册,还将免费获得本课程的 PDF 电子书版本。

使用 Pickle 序列化您的 XGBoost 模型

Pickle 是 Python 中序列化对象的标准方法。

您可以使用 Python pickle API 序列化您的机器学习算法,并将序列化格式保存到文件,例如

稍后,您可以加载此文件以反序列化您的模型,并使用它来进行新的预测,例如

下面的示例演示了如何在 Pima Indians 糖尿病发病率数据集上训练 XGBoost 模型,将模型保存到文件,然后稍后加载它来进行预测。

下载数据集并将其保存在当前工作目录中。

完整的代码清单如下。

运行此示例会将您训练好的 XGBoost 模型保存到当前工作目录中的 pima.pickle.dat pickle 文件中。

注意:您的 结果可能因算法的随机性或评估程序的差异,或数值精度的差异而异。请考虑运行示例几次并比较平均结果。

加载模型并在训练数据集上进行预测后,将打印模型的准确率。

使用 joblib 序列化 XGBoost 模型

Joblib 是 SciPy 生态系统的一部分,提供用于 Python 作业管道的实用程序。

Joblib API 提供高效保存和加载使用 NumPy 数据结构的 Python 对象的实用程序。对于非常大的模型,这可能是一种更快的用法。

该 API 看起来很像 pickle API,例如,您可以如下保存训练好的模型

稍后,您可以从文件加载模型并使用它来进行预测,如下所示

下面的示例演示了如何为 Pima Indians 糖尿病发病率数据集训练一个分类 XGBoost 模型,使用 Joblib 将模型保存到文件,然后在稍后加载它以进行预测。

运行此示例会将模型保存到当前工作目录中的文件 pima.joblib.dat,还会为模型中的每个 NumPy 数组创建一个文件(在本例中是另外两个文件)。

注意:您的 结果可能因算法的随机性或评估程序的差异,或数值精度的差异而异。请考虑运行示例几次并比较平均结果。

加载模型后,将在训练数据集上对其进行评估,并打印预测的准确率。

总结

在本帖中,您了解了如何序列化训练好的 XGBoost 模型,并在稍后加载它们以进行预测。

具体来说,你学到了:

  • 如何使用 pickle API 序列化并稍后加载训练好的 XGBoost 模型。
  • 如何使用 joblib API 序列化并稍后加载训练好的 XGBoost 模型。

您对序列化 XGBoost 模型或本文有任何疑问吗?在评论中提出您的问题,我将尽力回答。

发现赢得竞赛的算法!

XGBoost With Python

在几分钟内开发您自己的 XGBoost 模型

...只需几行 Python 代码

在我的新电子书中探索如何实现
使用 Python 实现 XGBoost

它涵盖了自学教程,例如:
算法基础缩放超参数等等……

将 XGBoost 的强大功能带入您自己的项目

跳过学术理论。只看结果。

查看内容

26 条回复“如何在 Python 中使用 XGBoost 保存梯度提升模型”

  1. koji 2018 年 6 月 23 日凌晨 1:18 #

    嗨,Jason。感谢您分享您的知识,我喜欢阅读您的帖子。
    顺便问一下,与使用像 xgb.Booster(model_file=’model.model’) 相比,pickle XGBoost 模型有什么意义吗?

    这是我的实验。

    %timeit model = xgb.Booster(model_file=’model.model’)
    每次循环 118 微秒 ± 1.73 微秒(7 次运行的平均值 ± 标准差,每次 10000 次循环)

    pickle.dump(model, open(“model.pickle”, “wb”))
    %timeit loaded_model = pickle.load(open(“model.pickle”, “rb”))
    每次循环 139 微秒 ± 1.54 微秒(7 次运行的平均值 ± 标准差,每次 10000 次循环)

    我目前正在寻找更好的方法来在生产环境中使用 XGBoost 模型。我担心如果客户端有很多请求,读取文件可能会很慢。

    • Jason Brownlee 2018 年 6 月 23 日上午 6:20 #

      我认为不会,这确实取决于您的项目/代码。

  2. Ran Feldesh 2018 年 10 月 6 日上午 5:59 #

    也许是由于 sklearn 版本的原因,运行代码时会出错,“cross_validation”未找到。删除此项(同时保留“train_test_split”)并将相关的导入语句修改为“from sklearn.model_selection import train_test_split”,就像您其他的 XGBoost 教程一样,可以解决此问题。

    • Jason Brownlee 2018 年 10 月 6 日上午 11:40 #

      您必须使用 0.8 或更高版本的 scikit-learn。

  3. John 2018 年 11 月 4 日凌晨 1:56 #

    嗨,精彩的帖子,

    如果您能写一个关于如何将 xgboost 模型转换为 pmml 的教程,那就太好了。一些关于 PMMLPipeline 的解释以及如何正确使用它来生成 pmml 使用 sklearn2pmml 将会非常有帮助。

  4. Sophia Yue 2019 年 11 月 16 日上午 7:15 #

    嗨,Jason。我喜欢阅读您的帖子。根据我的理解,sklearn 中没有 cross_validation 对象。您可能应该从 cross_validation.train_test_split(X, Y, test_size=test_size, random_state=seed) 中删除它。

    XGBClassifier 对我来说是一个黑盒子。模型构建后,我可以看到特征重要性和准确性。我还需要了解什么?

    pickle.dumop 保存了哪些信息?是否有可能以有意义的方式查看 pima.pickle.dat 的内容?

    谢谢,
    索菲亚

  5. Xin 2020 年 1 月 14 日凌晨 2:41 #

    嘿 Jason,当我尝试通过 XGBost 算法 dump 训练好的模型时,我遇到了一个错误

    AttributeError: function ‘XGBoosterSerializeToBuffer’ not found

    遗憾的是,pickle 和 joblib 都无法工作,但对于回归算法(例如),dump 和 load 可以工作。
    您有什么建议吗?

    • Jason Brownlee 2020 年 1 月 14 日上午 7:24 #

      哦。我以前没见过这个,很抱歉。

      也许试试将代码/错误发布到 stackoverflow?

      • Xin 2020 年 1 月 15 日凌晨 4:07 #

        嘿,我通过将 xgboost 相关文件移动到正确路径解决了这个问题,这样它就可以调用正确的库了。但当虚拟环境改变时,它仍然不总是奏效。

  6. San 2020 年 2 月 4 日晚上 7:32 #

    你好,
    我有一个数据集,特别是处理多类分类问题。目标变量包含 22 个唯一类。我需要使用独热编码对目标变量进行编码,还是将其保留为标签编码就足够了?

    因为如果我使用独热编码对目标变量进行编码,它将导致很多新列。

  7. San 2020 年 2 月 14 日晚上 6:17 #

    我想将这个原始的多类分类问题分解为一组多类子问题。如果我事先知道如何将原始数据集中的类分类到更小的子问题中,手动决定类层次结构是否可以?

    因为我发现一些研究论文提出了通过各种方法(如相似性矩阵)自动推导多类分类问题的类层次结构的方法。

  8. Matan 2020 年 3 月 12 日凌晨 2:53 #

    有人知道这个方法是否已弃用吗?cross_validation.train_test_split

    我正在使用 sklearn 版本 0.22.1

    该方法出现在 0.16.1 版本的文档中,但我找不到关于它是否已弃用的任何信息。

  9. Isaac Tepatl 2020 年 6 月 26 日凌晨 2:42 #

    嗨,Jason,

    这个教程非常有帮助,您的教程很棒……我只是有一个快速的问题,我想知道是否可以从 XGBoost 模型中打印逻辑回归系数。我正在训练一个最终将部署到 SQL Server 的模型,但我需要实际的系数估计。

    非常感谢…

  10. Kartik Shenoy 2020 年 7 月 4 日凌晨 2:49 #

    嗨,Jason,

    我在 Kaggle notebook 上训练了一个 XGB 模型,并按照您提到的方式使用 pickle dump 了模型。

    但是当我尝试从我的笔记本电脑加载这个 pickle 模型时,我收到错误

    AttributeError: Can’t get attribute ‘XGBoostLabelEncoder’ on

    我用于加载模型的代码

    import pickle
    from xgboost import XGBClassifier
    import xgboost

    model = pickle.load(open(‘./filename.pkl’,’rb’))

    我似乎不明白我错过了什么。

    • Jason Brownlee 2020 年 7 月 4 日上午 6:04 #

      听到这个消息我很难过。

      也许确认一下您在工作站上使用的库版本与在 Kaggle 上使用的版本相同?

      • SUMRITI RANJAN PATRA 2020 年 11 月 8 日晚上 10:57 #

        嘿,有人能告诉我如何保存来自最佳迭代的模型(我是通过早期停止得到的)?通常它保存的是最后一次迭代,而不是最佳迭代,请有人告诉我。
        谢谢你。

        • Jason Brownlee 2020 年 11 月 9 日上午 6:12 #

          你确定吗?这听起来不对。

          早期停止会根据您的标准停止训练——训练停止时的模型就是“最佳”模型。

  11. William Smith 2020 年 9 月 13 日上午 11:48 #

    嗨 Jason
    文档现在说应该避免使用 pickle 和 joblib,因为如果您升级 xgboost 并更改了二进制格式,模型将无法加载。

    相反,它们说使用 save_model 和 load_model。

    您能否更新您的示例和您的书(我买了)?

    参见 https://docs.xgboost.com.cn/en/latest/tutorials/saving_model.html

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。