
为 MLOps 保护 FastAPI 端点:身份验证指南
作者提供图片
引言
在当今的AI领域,数据科学家不仅专注于训练和优化机器学习模型。公司越来越青睐那些在机器学习操作(MLOps)方面具备技能的数据科学家,这包括为模型推理构建REST API以及将模型部署到云端。虽然创建简单的API对于测试目的可能有效,但在生产环境中部署模型需要更 robust 的方法,尤其是在安全性方面。
在本教程中,我们将使用FastAPI构建一个简单的机器学习应用程序。然后,我们将指导您如何为该应用程序设置认证,确保只有拥有正确令牌的用户才能访问模型以生成预测。
1. 设置项目
我们将构建一个“葡萄酒分类器”,并开始创建 Python 虚拟环境,安装训练和提供模型所需的 Python 库。
1 2 3 |
python -m venv venv source venv/bin/activate # Windows: venv\Scripts\activate pip install fastapi uvicorn scikit-learn pandas joblib python-dotenv |
接下来,我们将创建一个train_model.py
文件,并编写一个训练脚本来加载Scikit-learn的玩具数据集,使用随机森林分类器进行训练,并将训练好的模型保存在根目录下。
1 2 3 4 5 6 7 |
from sklearn.datasets import load_wine from sklearn.ensemble import RandomForestClassifier import joblib X, y = load_wine(return_X_y=True, as_frame=False) model = RandomForestClassifier(n_estimators=200, random_state=42).fit(X, y) joblib.dump(model, "wine_clf.joblib") |
运行训练脚本
1 |
python train_model.py |
2. 构建简单的FastAPI应用程序
现在,我们将创建一个main.py
文件来构建一个用于模型推理的REST API。该应用程序将加载训练好的模型,定义一个/predict
端点,并处理用于预测的传入请求。/predict
端点接受用户输入,通过模型处理,并返回预测的类别名称。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import os from typing import List, Optional import joblib import uvicorn from dotenv import load_dotenv from fastapi import Depends, FastAPI, HTTPException, Security, status from fastapi.security.api_key import APIKeyHeader from pydantic import BaseModel app = FastAPI(title="Secured Wine Classifier") MODEL = joblib.load("wine_clf.joblib") CLASS_NAMES = ["Cultivar-0", "Cultivar-1", "Cultivar-2"] class WineRequest(BaseModel): data: List[List[float]] # each inner list: 13 numeric features class WineResponse(BaseModel): predictions: List[str] @app.post("/predict", response_model=WineResponse) async def predict(payload: WineRequest): preds = MODEL.predict(payload.data) labels = [CLASS_NAMES[i] for i in preds] return WineResponse(predictions=labels) if __name__ == "__main__": uvicorn.run("main:app", host="localhost", port=8000, reload=True) |
运行FastAPI应用程序
1 |
python main.py |
您现在可以使用curl
命令测试/predict
端点
1 2 3 |
curl -X POST http://:8000/predict \ -H "Content-Type: application/json" \ -d '{"data": [[14.23,1.71,2.43,15.6,127,2.80,3.06,0.28,2.29,5.64,1.04,3.92,1065]]}' |
响应
1 |
{"predictions":["Cultivar-0"]} |
正如您所见,该端点目前是未受保护的,这意味着任何人都可以访问它,这对于生产环境来说不是理想的。
3. 设置API密钥和自定义标头
为了保护API,我们将通过API密钥实现认证。首先,创建一个.env
文件并添加API密钥
1 |
API_KEY=abid1234 |
接下来,更新main.py
文件以包含API密钥逻辑。在初始化CLASS_NAMES
变量后添加以下代码
1 2 3 4 |
load_dotenv() API_KEY = os.getenv("API_KEY") API_KEY_NAME = "X-API-Key" api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) |
4. 实现认证依赖
在此步骤中,我们将实现一个认证依赖项来验证客户端提供的API密钥。这确保了只有具有有效API密钥的授权用户才能访问端点。
1 2 3 4 5 6 7 8 |
async def get_api_key(api_key: Optional[str] = Security(api_key_header)): if api_key == API_KEY: return api_key raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key", headers={"WWW-Authenticate": "Bearer"}, ) |
5. 使用认证保护端点
定义好认证依赖后,我们可以使用它来保护/predict
端点。通过将依赖项添加到端点,我们可以确保只有具有有效API密钥的请求才能访问预测服务。
以下是更新后的带有认证依赖项的/predict
端点
1 2 3 4 5 6 7 8 |
@app.post("/predict", response_model=WineResponse, dependencies=[Depends(get_api_key)]) async def predict(payload: WineRequest): preds = MODEL.predict(payload.data) labels = [CLASS_NAMES[i] for i in preds] return WineResponse(predictions=labels) if __name__ == "__main__": uvicorn.run("main:app", host="localhost", port=8000, reload=True) |
更新端点后,再次运行应用程序
1 |
python main.py |
您应该在终端中看到以下输出
1 2 3 4 5 6 |
INFO: Will watch for changes in these directories: ['C:\\Repository\\GitHub\\securing-fastapi-endpoints'] INFO: Uvicorn running on http://:8000 (Press CTRL+C to quit) INFO: Started reloader process [8372] using StatReload INFO: Started server process [19020] INFO: Waiting for application startup. INFO: Application startup complete. |
Swagger UI 是由FastAPI自动生成的,提供了一个交互式界面来浏览和测试您的API端点。一旦您的FastAPI应用程序正在运行,您可以通过在浏览器中导航到以下URL来访问Swagger UI:https://:8000/docs

6. 测试安全端点
在本节中,我们将测试/predict
端点,涉及各种情况,以验证API密钥认证是否正常工作。这包括测试缺失的API密钥、无效的API密钥和有效的API密钥。
未提供API密钥进行测试
在此测试中,我们将向/predict
端点发送一个请求,而不提供X-API-Key
标头。
1 2 3 |
curl -X POST http://:8000/predict \ -H "Content-Type: application/json" \ -d '{"data": [[14.23,1.71,2.43,15.6,127,2.80,3.06,0.28,2.29,5.64,1.04,3.92,1065]]}' |
响应
1 |
{"detail":"Invalid API Key"} |
这证实了当未提供API密钥时,端点正确地拒绝了访问。
使用错误的API密钥进行测试
接下来,我们将通过在X-API-Key
标头中提供错误的API密钥来测试端点。
1 2 3 4 |
curl -X POST http://:8000/predict \ -H "Content-Type: application/json" \ -H "X-API-Key: abid11111" \ -d '{"data": [[14.23,1.71,2.43,15.6,127,2.80,3.06,0.28,2.29,5.64,1.04,3.92,1065]]}' |
响应
1 |
{"detail":"Invalid API Key"} |
这证实了当提供无效API密钥时,端点正确地拒绝了访问。
使用正确的API密钥进行测试
最后,我们将通过在X-API-Key
标头中提供正确的API密钥来测试端点。
1 2 3 4 |
curl -X POST http://:8000/predict \ -H "Content-Type: application/json" \ -H "X-API-Key: abid1234" \ -d '{"data": [[14.23,1.71,2.43,15.6,127,2.80,3.06,0.28,2.29,5.64,1.04,3.92,1065]]}' |
响应
1 |
{"predictions":["Cultivar-0"]} |
这证实了当提供有效的API密钥时,端点正确地处理请求。
总结
我们已经通过创建简单的FastAPI应用程序成功训练了模型并提供了服务。此外,我们通过实现认证增强了应用程序,展示了如何将安全性集成到Web API中。
FastAPI还包含内置的安全功能,用于高效的用户管理和基于角色的OAuth2认证系统。其简洁性使其成为构建安全且可扩展的Web应用程序的绝佳选择。
暂无评论。