如何使用FastAPI和Redis缓存构建毫秒级机器学习预测API

如何使用FastAPI和Redis缓存构建毫秒级机器学习预测API插图

你是否曾经等待模型返回预测结果的时间过长?我们都经历过。机器学习模型,尤其是大型复杂的模型,在实时运行中速度非常慢。而用户则期望即时反馈。这时,延迟就成了一个真正的问题。从技术角度来看,最大的问题之一是当相同的输入反复触发相同的缓慢过程时产生的冗余计算。在本篇文章中,我将向你展示如何解决这个问题。我们将构建一个基于 FastAPI 的机器学习服务,并集成 Redis 缓存,以便在几毫秒内返回重复的预测结果。

什么是FastAPI?

FastAPI 是一个用于使用 Python 构建 API 的现代高性能 Web 框架。它使用 Python 的类型提示进行数据验证,并使用 Swagger UI 和 ReDoc 自动生成交互式 API 文档。FastAPI 构建于 Starlette 和 Pydantic 之上,支持异步编程,其性能堪比 Node.js 和 Go。它的设计有利于快速开发健壮、可用于生产的 API,使其成为将机器学习模型部署为可扩展 RESTful 服务的绝佳选择。

什么是Redis?

Redis(远程字典服务器)是一个开源的内存数据结构存储,可用作数据库、缓存和消息代理。通过将数据存储在内存中,Redis 为读写操作提供了超低延迟,使其成为缓存频繁或计算密集型任务(例如机器学习模型预测)的理想选择。它支持各种数据结构,包括字符串、列表、集合和哈希,并提供密钥过期时间 (TTL) 等功能,以实现高效的缓存管理。

为什么要结合使用FastAPI和Redis?

 

将 FastAPI 与 Redis 集成,即可创建一个响应迅速且高效的系统。FastAPI 提供快速可靠的 API 请求处理接口,而 Redis 则充当缓存层,用于存储先前计算的结果。当再次收到相同的输入时,可以立即从 Redis 中检索结果,无需重新计算。这种方法可以减少延迟、降低计算负载并增强应用程序的可扩展性。在分布式环境中,Redis 充当可供多个 FastAPI 实例访问的集中式缓存,非常适合生产级机器学习部署。

现在,让我们逐步了解如何实现一个使用 Redis 缓存提供机器学习模型预测的 FastAPI 应用程序。此设置可确保从缓存中快速处理具有相同输入的重复请求,从而减少计算时间并提高响应时间。步骤如下:

  1. 加载预训练模型
  2. 创建用于预测的 FastAPI 端点
  3. 设置 Redis 缓存
  4. 衡量性能提升

现在,让我们更详细地了解这些步骤。

步骤 1:加载预训练模型

首先,假设您已经拥有一个训练好的机器学习模型,可以随时部署。实际上,大多数模型都是离线训练的(例如 scikit-learn 模型、TensorFlow/Pytorch 模型等),然后保存到磁盘,再加载到服务应用中。在本例中,我们将创建一个简单的 scikit-learn 分类器,该分类器将在著名的鸢尾花数据集上进行训练,并使用 joblib 保存。如果您已经拥有保存的模型文件,则可以跳过训练部分,直接加载即可。以下是训练模型并将其加载到服务应用中的方法:

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import joblib
# Load example dataset and train a simple model (Iris classification)
X, y = load_iris(return_X_y=True)
# Train the model
model = RandomForestClassifier().fit(X, y)
# Save the trained model to disk
joblib.dump(model, "model.joblib")
# Load the pre-trained model from disk (using the saved file)
model = joblib.load("model.joblib")
print("Model loaded and ready to serve predictions.")
from sklearn.datasets import load_iris from sklearn.ensemble import RandomForestClassifier import joblib # Load example dataset and train a simple model (Iris classification) X, y = load_iris(return_X_y=True) # Train the model model = RandomForestClassifier().fit(X, y) # Save the trained model to disk joblib.dump(model, "model.joblib") # Load the pre-trained model from disk (using the saved file) model = joblib.load("model.joblib") print("Model loaded and ready to serve predictions.")
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import joblib
# Load example dataset and train a simple model (Iris classification)
X, y = load_iris(return_X_y=True)
# Train the model
model = RandomForestClassifier().fit(X, y)
# Save the trained model to disk
joblib.dump(model, "model.joblib")
# Load the pre-trained model from disk (using the saved file)
model = joblib.load("model.joblib")
print("Model loaded and ready to serve predictions.")

在上面的代码中,我们使用了 scikit-learn 内置的 Iris 数据集,在其上训练了一个随机森林分类器,然后将该模型保存到名为 model.joblib 的文件中。之后,我们使用 joblib.load 将其重新加载。joblib 库在保存 scikit-learn 模型时非常常用,主要是因为它擅长处理模型内部的 NumPy 数组。完成此步骤后,我们就有了一个可以对新数据进行预测的模型对象。不过需要注意的是,您可以在此处使用任何预先训练好的模型,使用 FastAPI 提供该模型的方式以及缓存的结果大致相同。唯一的问题是,模型应该有一个预测方法,该方法接受一些输入并生成结果。此外,请确保每次输入相同的输入时,模型的预测都保持不变(因此它是确定性的)。如果不是,那么对于非确定性模型,缓存将会出现问题,因为它会返回不正确的结果。

步骤 2:创建FastAPI预测端点

现在我们有了模型,让我们通过 API 使用它。我们将使用 FASTAPI 创建一个处理预测请求的 Web 服务器。FASTAPI 可以轻松定义端点并将请求参数映射到 Python 函数参数。在我们的示例中,我们假设模型接受四个特征。我们将创建一个 GET 端点 /predict,它接受这些特征作为查询参数并返回模型的预测结果。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
from fastapi import FastAPI
import joblib
app = FastAPI()
# Load the trained model at startup (to avoid re-loading on every request)
model = joblib.load("model.joblib") # Ensure this file exists from the training step
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
""" Predict the Iris flower species from input measurements. """
# Prepare the features for the model as a 2D list (model expects shape [n_samples, n_features])
features = [[sepal_length, sepal_width, petal_length, petal_width]]
# Get the prediction (in the iris dataset, prediction is an integer class label 0,1,2 representing the species)
prediction = model.predict(features)[0] # Get the first (only) prediction
return {"prediction": str(prediction)}
from fastapi import FastAPI import joblib app = FastAPI() # Load the trained model at startup (to avoid re-loading on every request) model = joblib.load("model.joblib") # Ensure this file exists from the training step @app.get("/predict") def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float): """ Predict the Iris flower species from input measurements. """ # Prepare the features for the model as a 2D list (model expects shape [n_samples, n_features]) features = [[sepal_length, sepal_width, petal_length, petal_width]] # Get the prediction (in the iris dataset, prediction is an integer class label 0,1,2 representing the species) prediction = model.predict(features)[0] # Get the first (only) prediction return {"prediction": str(prediction)}
from fastapi import FastAPI
import joblib
app = FastAPI()
# Load the trained model at startup (to avoid re-loading on every request)
model = joblib.load("model.joblib")  # Ensure this file exists from the training step
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
""" Predict the Iris flower species from input measurements. """
# Prepare the features for the model as a 2D list (model expects shape [n_samples, n_features])
features = [[sepal_length, sepal_width, petal_length, petal_width]]
# Get the prediction (in the iris dataset, prediction is an integer class label 0,1,2 representing the species)
prediction = model.predict(features)[0]  # Get the first (only) prediction
return {"prediction": str(prediction)}

在上面的代码中,我们创建了一个 FastAPI 应用,在执行文件时,它会启动 API 服务器。FastAPI 对于 Python 来说速度超快,因此它可以轻松处理大量请求。然后我们在开始时加载模型,因为在每个请求上反复执行此操作会很慢,所以我们将它保存在内存中,随时可以使用。我们使用 @app.get 创建了一个 /predict 端点,GET 使测试变得简单,因为我们只需在 URL 中传递内容,但在实际项目中,您可能需要使用 POST,尤其是在发送大型或复杂输入(如图像或 JSON)时。此函数接受 4 个输入: sepal_lengthsepal_widthpetal_lengthpetal_width,FastAPI 会自动从 URL 中读取它们。在函数内部,我们将所有输入放入一个二维列表中(因为 scikit-learn 只接受二维数组),然后调用 model.predict(),它会给我们一个列表。然后我们将其以 JSON 格式返回,例如 {"prediction": "..."}

现在它可以正常工作了,您可以使用 uvicorn main:app --reload、hit /predict、endpoint 运行它并获取结果。即使您再次发送相同的输入,它仍然会再次运行模型,这不太好,所以下一步是添加 Redis 来缓存之前的结果,避免重复执行。

步骤 3:添加Redis缓存用于预测

为了缓存模型输​​出,我们将使用 Redis。首先,确保 Redis 服务器正在运行。您可以本地安装,也可以运行 Docker 容器;它通常默认在 6379 端口运行。我们将使用 Python redis 库与服务器通信。

思路很简单:当请求到来时,创建一个代表输入的唯一键。然后检查该键是否存在于 Redis 中;如果该键已经存在,则意味着我们之前已经缓存过,因此我们只需返回保存的结果,无需再次调用模型。如果不存在,我们执行 model.predict,获取输出,将其保存到 Redis 中,并返回预测结果。

现在,让我们更新 FastAPI 应用以添加此缓存逻辑。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
!pip install redis
import redis # New import to use Redis
# Connect to a local Redis server (adjust host/port if needed)
cache = redis.Redis(host="localhost", port=6379, db=0)
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
"""
Predict the species, with caching to speed up repeated predictions.
"""
# 1. Create a unique cache key from input parameters
cache_key = f"{sepal_length}:{sepal_width}:{petal_length}:{petal_width}"
# 2. Check if the result is already cached in Redis
cached_val = cache.get(cache_key)
if cached_val:
# If cache hit, decode the bytes to a string and return the cached prediction
return {"prediction": cached_val.decode("utf-8")}
# 3. If not cached, compute the prediction using the model
features = [[sepal_length, sepal_width, petal_length, petal_width]]
prediction = model.predict(features)[0]
# 4. Store the result in Redis for next time (as a string)
cache.set(cache_key, str(prediction))
# 5. Return the freshly computed prediction
return {"prediction": str(prediction)}
!pip install redis import redis # New import to use Redis # Connect to a local Redis server (adjust host/port if needed) cache = redis.Redis(host="localhost", port=6379, db=0) @app.get("/predict") def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float): """ Predict the species, with caching to speed up repeated predictions. """ # 1. Create a unique cache key from input parameters cache_key = f"{sepal_length}:{sepal_width}:{petal_length}:{petal_width}" # 2. Check if the result is already cached in Redis cached_val = cache.get(cache_key) if cached_val: # If cache hit, decode the bytes to a string and return the cached prediction return {"prediction": cached_val.decode("utf-8")} # 3. If not cached, compute the prediction using the model features = [[sepal_length, sepal_width, petal_length, petal_width]] prediction = model.predict(features)[0] # 4. Store the result in Redis for next time (as a string) cache.set(cache_key, str(prediction)) # 5. Return the freshly computed prediction return {"prediction": str(prediction)}
!pip install redis
import redis  # New import to use Redis
# Connect to a local Redis server (adjust host/port if needed)
cache = redis.Redis(host="localhost", port=6379, db=0)
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
"""
Predict the species, with caching to speed up repeated predictions.
"""
# 1. Create a unique cache key from input parameters
cache_key = f"{sepal_length}:{sepal_width}:{petal_length}:{petal_width}"
# 2. Check if the result is already cached in Redis
cached_val = cache.get(cache_key)
if cached_val:
# If cache hit, decode the bytes to a string and return the cached prediction
return {"prediction": cached_val.decode("utf-8")}
# 3. If not cached, compute the prediction using the model
features = [[sepal_length, sepal_width, petal_length, petal_width]]
prediction = model.predict(features)[0]
# 4. Store the result in Redis for next time (as a string)
cache.set(cache_key, str(prediction))
# 5. Return the freshly computed prediction
return {"prediction": str(prediction)}

在上面的代码中,我们现在添加了 Redis。首先,我们使用 redis.Redis() 创建了一个客户端。它连接到 Redis 服务器。默认情况下使用 db=0。然后,我们通过连接输入值来创建缓存键。这里这种方法有效是因为输入是简单的数字,但对于复杂的数字,最好使用哈希值或 JSON 字符串。每个输入的键必须是唯一的。我们使用了 cache.get(cache_key)。如果找到相同的键,它会返回该键,这使得速度更快,并且无需重新运行模型。但如果在缓存中找不到,我们需要运行模型并获取预测结果。最后,使用 cache.set() 将结果保存到 Redis 中。这样,下次相同的输入到来时,它就已经存在了,缓存速度会更快。

步骤 4:测试和衡量性能提升

现在我们的 FastAPI 应用已运行并连接到 Redis,是时候测试缓存如何改善响应时间了。在这里,我将演示如何使用 Python 的请求库以相同的输入调用 API 两次,并测量每次调用所花费的时间。另外,请确保在运行测试代码之前启动 FastAPI:

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
import requests, time
# Sample input to predict (same input will be used twice to test caching)
params = {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
# First request (expected to be a cache miss, will run the model)
start = time.time()
response1 = requests.get("http://localhost:8000/predict", params=params)
elapsed1 = time.time() - start
print("First response:", response1.json(), f"(Time: {elapsed1:.4f} seconds)")
import requests, time # Sample input to predict (same input will be used twice to test caching) params = { "sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2 } # First request (expected to be a cache miss, will run the model) start = time.time() response1 = requests.get("http://localhost:8000/predict", params=params) elapsed1 = time.time() - start print("First response:", response1.json(), f"(Time: {elapsed1:.4f} seconds)")
import requests, time
# Sample input to predict (same input will be used twice to test caching)
params = {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
# First request (expected to be a cache miss, will run the model)
start = time.time()
response1 = requests.get("http://localhost:8000/predict", params=params)
elapsed1 = time.time() - start
print("First response:", response1.json(), f"(Time: {elapsed1:.4f} seconds)")

测试和衡量性能提升

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
# Second request (same params, expected cache hit, no model computation)
start = time.time()
response2 = requests.get("http://localhost:8000/predict", params=params)
elapsed2 = time.time() - start
print("Second response:", response2.json(), f"(Time: {elapsed2:.6f}seconds)")
# Second request (same params, expected cache hit, no model computation) start = time.time() response2 = requests.get("http://localhost:8000/predict", params=params) elapsed2 = time.time() - start print("Second response:", response2.json(), f"(Time: {elapsed2:.6f}seconds)")
# Second request (same params, expected cache hit, no model computation)
start = time.time()
response2 = requests.get("http://localhost:8000/predict", params=params)
elapsed2 = time.time() - start
print("Second response:", response2.json(), f"(Time: {elapsed2:.6f}seconds)")

测试和衡量性能提升

运行此代码时,您应该会看到第一个请求返回结果。第二个请求返回相同的结果,但速度明显更快。例如,您可能会发现第一次调用耗时大约几十毫秒(取决于模型复杂度),而第二次调用可能只需几毫秒甚至更短。在我们基于轻量级模型的简单演示中,差异可能很小(因为模型本身速度很快),但对于更重的模型,效果会非常显著。

比较

为了更好地理解这一点,让我们考虑一下我们实现的效果:

  • 不使用缓存:每个请求,即使是相同的请求,都会命中模型。如果模型每次预测需要 100 毫秒,那么 10 个相同的请求总共仍需要大约 1000 毫秒。
  • 使用缓存:第一个请求会命中全部数据(100 毫秒),但接下来的 9 个相同的请求可能每个只需要 1-2 毫秒(仅仅是 Redis 查找和返回数据)。因此,这 10 个请求的总耗时可能约为 120 毫秒,而不是 1000 毫秒,在这种情况下速度可提升约 8 倍。

在实际实验中,缓存可以带来数量级的提升。例如,在电子商务中,使用 Redis 意味着可以在几微秒内为重复请求返回建议,而无需使用完整的模型服务流水线重新计算它们。性能提升取决于模型推理的成本。模型越复杂,缓存重复调用带来的收益就越大。这还取决于请求模式:如果每个请求都是唯一的,缓存将无济于事(无需从内存中提供重复请求),但许多应用程序确实会看到重叠请求(例如,热门搜索查询、推荐商品等)。

您也可以直接检查 Redis 缓存,以验证它是否存储了键。​​

小结

在本文章中,我们演示了 FastAPI 和 Redis 如何协同工作以加速机器学习模型服务。 FastAPI 提供了一个快速且易于构建的 API 层来提供预测服务,而 Redis 则添加了一个缓存层,可显著降低重复计算的延迟和 CPU 负载。通过避免重复的模型调用,我们提高了响应速度,并使系统能够使用相同的资源处理更多请求。

评论留言