
图片作者 | Canva
机器学习模型只有在触达用户时才能实现真正的价值,而 API 就是实现这一点的桥梁。但仅仅暴露你的模型是不够的;你需要一个安全、可扩展且高效的 API 来确保其可靠性。在本指南中,我们将构建一个生产级的 ML API,并使用 FastAPI,加入身份验证、输入验证和速率限制。这样,你的模型不仅能工作,而且能安全地进行大规模工作。
在本指南中,我将引导你构建一个安全的机器学习 API。我们将涵盖:
- 使用 FastAPI 构建快速、高效的 API
- 使用 JWT(JSON Web Token)身份验证保护你的端点
- 确保模型的输入有效且安全
- 为 API 端点添加速率限制,以防止滥用或过载
- 使用 Docker 将所有内容整洁地打包,实现一致的部署
项目结构将大致如下所示:
1 2 3 4 5 6 7 8 9 10 11 |
secure-ml-API/ ├── app/ │ ├── main.py # FastAPI 入口点 │ ├── model.py # 模型训练和序列化 │ ├── predict.py # 预测逻辑 │ ├── jwt.py # JWT 身份验证逻辑 │ ├── rate_limit.py # 速率限制逻辑 │ ├── validation.py # 输入验证逻辑 ├── Dockerfile # Docker 设置 ├── requirements.txt # Python 依赖项 └── README.md # 项目文档 |
让我们一步一步来。
第一步:训练和序列化模型 (app/model.py)
为了保持简单,我们将使用 Iris 数据集上的 RandomForestClassifier。RandomForestClassifier 是一种对事物(例如花朵、电子邮件、客户)进行分类的机器学习模型。在 Iris 花卉数据集中:
- 输入:4 个数字(萼片和花瓣的长度/宽度)
- 输出:物种(0=setosa、1=versicolor 或 2=virginica)
RandomForest 使用许多决策树来检查输入数字的模式,并根据这些模式返回最可能的物种。
1 2 3 4 5 6 7 8 9 10 11 12 |
# 训练模型并将其保存为 pickle 文件 def train_model(): iris = load_iris() X, y = iris.data, iris.target clf = RandomForestClassifier() clf.fit(X, y) # 保存训练好的模型 with open("app/model.pkl", "wb") as f: pickle.dump(clf, f) if __name__ == "__main__": train_model() |
运行此脚本以生成 model.pkl 文件。
第二步:定义预测逻辑 (app/predict.py)
现在,我们创建一个助手来加载模型并根据输入数据进行预测。
1 2 3 4 5 6 7 8 9 |
import pickle import numpy as np # 加载模型 with open("app/model.pkl", "rb") as f: model = pickle.load(f) # 进行预测 def make_prediction(data): arr = np.array(data).reshape(1, -1) # 将输入重塑为二维 return int(model.predict(arr)[0]) #返回预测的花朵种类 |
该函数期望一个包含 4 个特征的列表(例如 [5.1, 3.5, 1.4, 0.2])。
第三步:验证输入 (app/validation.py)
FastAPI 使用 Pydantic 模型提供自动输入验证。此模型将验证传入的特征是否格式正确。它还在处理之前验证它们是否是适当范围内的数值。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from pydantic import BaseModel, field_validator from typing import List # 定义一个 Pydantic 模型 class PredictionInput(BaseModel): data: List[float] # 验证器,用于检查输入列表是否包含 4 个值 @field_validator("data") @classmethod def check_length(cls, v): if len(v) != 4: raise ValueError("data 必须包含恰好 4 个浮点数值") return v # 为文档提供示例 schema class Config: json_schema_extra = { "example": { "data": [5.1, 3.5, 1.4, 0.2], } } |
注意:步骤 4-5 是可选的,仅用于安全目的
第四步:添加 JWT 身份验证 (app/jwt.py)
JWT(JSON Web Token)提供比简单的基于令牌的身份验证更安全的身份验证。JWT 允许一个更强大的系统,其中声明(用户信息、过期时间等)嵌入在令牌中。使用共享密钥或公钥/私钥对进行验证。
我们将使用 pyjwt 库来处理 JWT。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import jwt import os from datetime import datetime, timedelta from fastapi import HTTPException, status from fastapi.security import OAuth2PasswordBearer from typing import Optional from fastapi import Depends SECRET_KEY = os.getenv("SECRET_KEY", "mysecretkey") ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): if expires_delta: expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(minutes=15) to_encode = data.copy() to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt def verify_token(token: str = Depends(oauth2_scheme)): try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) return payload except jwt.PyJWTError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的令牌", ) |
您需要创建一个路由来获取 JWT。
第五步:使用速率限制保护您的 API (app/rate_limit.py)
速率限制可以保护您的 API 免遭过度使用。它限制了每个 IP 在一分钟内发送请求的次数。我通过中间件添加了此功能。
RateLimitMiddleware 会检查每个请求的 IP,计算最后 60 秒内收到的请求数量,并在达到限制(默认为 60 次/分钟)时阻止其余请求。它也被称为节流率。如果有人超过限制,他们会收到“429 Too Many Requests”错误。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import time from fastapi import Request, HTTPException from starlette.middleware.base import BaseHTTPMiddleware import time from fastapi import Request, HTTPException from starlette.middleware.base import BaseHTTPMiddleware class RateLimitMiddleware(BaseHTTPMiddleware): def __init__(self, app, throttle_rate: int = 60): super().__init__(app) self.throttle_rate = throttle_rate self.request_log = {} # 按 IP 跟踪时间戳 async def dispatch(self, request: Request, call_next): client_ip = request.client.host now = time.time() # 清理超过 60 秒的旧请求日志 self.request_log = { ip: [ts for ts in times if ts > now - 60] for ip, times in self.request_log.items() } ip_history = self.request_log.get(client_ip, []) if len(ip_history) >= self.throttle_rate: raise HTTPException(status_code=429, detail="请求次数过多") ip_history.append(now) self.request_log[client_ip] = ip_history return await call_next(request) |
这是一个简单的、基于内存的方法,对于小型项目来说效果很好。
第六步:构建 FastAPI 应用程序
将所有组件组合到 FastAPI 主应用程序中。这将包括健康检查、令牌生成和预测的路由。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from fastapi import FastAPI, Depends from app.predict import make_prediction from app.jwt import verify_token, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES from app.rate_limit import RateLimitMiddleware from app.validation import PredictionInput from datetime import timedelta # 初始化 FastAPI 应用 app = FastAPI() # 如果你没有实现步骤 5,则跳过此路由 # 添加速率限制中间件,将请求限制为每分钟 5 次 app.add_middleware(RateLimitMiddleware, throttle_rate=5) # 根端点,用于确认 API 正在运行 @app.get("/") def root(): return {"message": "欢迎来到安全的机器学习 API"} # 如果你没有实现步骤 4,则跳过此路由 # 此端点在提供有效凭证时颁发令牌 @app.post("/token") def login(): # 定义令牌的过期时间(例如,30 分钟) access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) # 生成 JWT 令牌 access_token = create_access_token(data={"sub": "user"}, expires_delta=access_token_expires) return {"access_token": access_token, "token_type": "bearer"} # 预测端点,需要有效的 JWT 令牌进行身份验证 # 此外,输入数据会使用 PredictionInput 模型进行验证 @app.post("/predict") def predict(input_data: PredictionInput, token: str = Depends(verify_token)): prediction = make_prediction(input_data.data) return {"prediction": prediction} |
第七步:Docker 化应用程序
创建一个 Dockerfile 来打包应用程序和所有依赖项。
1 2 3 4 5 6 7 8 9 10 11 |
# 使用官方 Python 镜像 FROM python:3.10-slim # 设置工作目录 WORKDIR /app # 安装依赖项 COPY requirements.txt . RUN pip install --upgrade pip && pip install --no-cache-dir -r requirements.txt # 复制应用程序代码 COPY ./app ./app # 使用 Uvicorn 运行 FastAPI 应用 CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] |
以及一个简单的 requirements.txt 文件,内容如下:
1 2 3 4 5 6 7 8 9 10 11 |
scikit-learn numpy python-dotenv pyjwt aioredis fastapi-limiter redis pydantic fastapi uvicorn starlette |
第 8 步:构建和运行 Docker 容器
使用以下命令运行您的 API:
1 2 3 |
# 构建 Docker 镜像并运行它 docker build -t secure-ml-api . docker run -p 8000:8000 secure-ml-api |
现在您的机器学习 API 将可以通过 https://:8000 访问。
第 9 步:使用 Curl 测试您的 API
为此,请首先运行以下命令获取 JWT:
1 |
curl -X POST http://:8000/token |
复制访问令牌,然后运行以下命令:
1 2 3 4 |
curl -X POST http://:8000/predict \ -H "Content-Type: application/json" \ -H "Authorization: Bearer PASTE-TOKEN-HERE" \ -d '{"data": [1.5, 2.3, 3.1, 0.7]}' |
您应该会收到一个类似以下的预测结果:
1 |
{"prediction": 0} |
您可以尝试不同的输入来测试 API。
结论
将机器学习模型部署为安全 API 需要仔细关注身份验证、验证和可伸缩性。通过利用 FastAPI 的速度和简洁性以及 Docker 的可移植性,您可以创建健壮的端点,安全地公开模型的预测,同时防止滥用。这种方法确保了您的机器学习解决方案不仅准确,而且在实际应用中可靠且安全。
暂无评论。