如何使用FastAPI和Redis快取構建毫秒級機器學習預測API

如何使用FastAPI和Redis快取構建毫秒級機器學習預測API配圖

你是否曾經等待模型返回預測結果的時間過長?我們都經歷過。機器學習模型,尤其是大型複雜的模型,在即時執行中速度非常慢。而使用者則期望即時反饋。這時,延遲就成了一個真正的問題。從技術角度來看,最大的問題之一是當相同的輸入反覆觸發相同的緩慢過程時產生的冗餘計算。在本篇文章中,我將向你展示如何解決這個問題。我們將構建一個基於 FastAPI 的機器學習服務,並整合 Redis 快取,以便在幾毫秒內返回重複的預測結果。

什麼是FastAPI?

FastAPI 是一個用於使用 Python 構建 API 的現代高效能 Web 框架。它使用 Python 的型別提示進行資料驗證,並使用 Swagger UI 和 ReDoc 自動生成互動式 API 文件。FastAPI 構建於 Starlette 和 Pydantic 之上,支援非同步程式設計,其效能堪比 Node.js 和 Go。它的設計有利於快速開發健壯、可用於生產的 API,使其成為將機器學習模型部署為可擴充套件 RESTful 服務的絕佳選擇。

什麼是Redis?

Redis(遠端字典伺服器)是一個開源的記憶體資料結構儲存,可用作資料庫、快取和訊息代理。透過將資料儲存在記憶體中,Redis 為讀寫操作提供了超低延遲,使其成為快取頻繁或計算密集型任務(例如機器學習模型預測)的理想選擇。它支援各種資料結構,包括字串、列表、集合和雜湊,並提供金鑰過期時間 (TTL) 等功能,以實現高效的快取管理。

為什麼要結合使用FastAPI和Redis?

 

將 FastAPI 與 Redis 整合,即可建立一個響應迅速且高效的系統。FastAPI 提供快速可靠的 API 請求處理介面,而 Redis 則充當快取層,用於儲存先前計算的結果。當再次收到相同的輸入時,可以立即從 Redis 中檢索結果,無需重新計算。這種方法可以減少延遲、降低計算負載並增強應用程式的可擴充套件性。在分散式環境中,Redis 充當可供多個 FastAPI 例項訪問的集中式快取,非常適合生產級機器學習部署。

現在,讓我們逐步瞭解如何實現一個使用 Redis 快取提供機器學習模型預測的 FastAPI 應用程式。此設定可確保從快取中快速處理具有相同輸入的重複請求,從而減少計算時間並提高響應時間。步驟如下:

  1. 載入預訓練模型
  2. 建立用於預測的 FastAPI 端點
  3. 設定 Redis 快取
  4. 衡量效能提升

現在,讓我們更詳細地瞭解這些步驟。

步驟 1:載入預訓練模型

首先,假設您已經擁有一個訓練好的機器學習模型,可以隨時部署。實際上,大多數模型都是離線訓練的(例如 scikit-learn 模型、TensorFlow/Pytorch 模型等),然後儲存到磁碟,再載入到服務應用中。在本例中,我們將建立一個簡單的 scikit-learn 分類器,該分類器將在著名的鳶尾花資料集上進行訓練,並使用 joblib 儲存。如果您已經擁有儲存的模型檔案,則可以跳過訓練部分,直接載入即可。以下是訓練模型並將其載入到服務應用中的方法:

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import joblib
# Load example dataset and train a simple model (Iris classification)
X, y = load_iris(return_X_y=True)
# Train the model
model = RandomForestClassifier().fit(X, y)
# Save the trained model to disk
joblib.dump(model, "model.joblib")
# Load the pre-trained model from disk (using the saved file)
model = joblib.load("model.joblib")
print("Model loaded and ready to serve predictions.")
from sklearn.datasets import load_iris from sklearn.ensemble import RandomForestClassifier import joblib # Load example dataset and train a simple model (Iris classification) X, y = load_iris(return_X_y=True) # Train the model model = RandomForestClassifier().fit(X, y) # Save the trained model to disk joblib.dump(model, "model.joblib") # Load the pre-trained model from disk (using the saved file) model = joblib.load("model.joblib") print("Model loaded and ready to serve predictions.")
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import joblib
# Load example dataset and train a simple model (Iris classification)
X, y = load_iris(return_X_y=True)
# Train the model
model = RandomForestClassifier().fit(X, y)
# Save the trained model to disk
joblib.dump(model, "model.joblib")
# Load the pre-trained model from disk (using the saved file)
model = joblib.load("model.joblib")
print("Model loaded and ready to serve predictions.")

在上面的程式碼中,我們使用了 scikit-learn 內建的 Iris 資料集,在其上訓練了一個隨機森林分類器,然後將該模型儲存到名為 model.joblib 的檔案中。之後,我們使用 joblib.load 將其重新載入。joblib 庫在儲存 scikit-learn 模型時非常常用,主要是因為它擅長處理模型內部的 NumPy 陣列。完成此步驟後,我們就有了一個可以對新資料進行預測的模型物件。不過需要注意的是,您可以在此處使用任何預先訓練好的模型,使用 FastAPI 提供該模型的方式以及快取的結果大致相同。唯一的問題是,模型應該有一個預測方法,該方法接受一些輸入並生成結果。此外,請確保每次輸入相同的輸入時,模型的預測都保持不變(因此它是確定性的)。如果不是,那麼對於非確定性模型,快取將會出現問題,因為它會返回不正確的結果。

步驟 2:建立FastAPI預測端點

現在我們有了模型,讓我們透過 API 使用它。我們將使用 FASTAPI 建立一個處理預測請求的 Web 伺服器。FASTAPI 可以輕鬆定義端點並將請求引數對映到 Python 函式引數。在我們的示例中,我們假設模型接受四個特徵。我們將建立一個 GET 端點 /predict,它接受這些特徵作為查詢引數並返回模型的預測結果。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
from fastapi import FastAPI
import joblib
app = FastAPI()
# Load the trained model at startup (to avoid re-loading on every request)
model = joblib.load("model.joblib") # Ensure this file exists from the training step
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
""" Predict the Iris flower species from input measurements. """
# Prepare the features for the model as a 2D list (model expects shape [n_samples, n_features])
features = [[sepal_length, sepal_width, petal_length, petal_width]]
# Get the prediction (in the iris dataset, prediction is an integer class label 0,1,2 representing the species)
prediction = model.predict(features)[0] # Get the first (only) prediction
return {"prediction": str(prediction)}
from fastapi import FastAPI import joblib app = FastAPI() # Load the trained model at startup (to avoid re-loading on every request) model = joblib.load("model.joblib") # Ensure this file exists from the training step @app.get("/predict") def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float): """ Predict the Iris flower species from input measurements. """ # Prepare the features for the model as a 2D list (model expects shape [n_samples, n_features]) features = [[sepal_length, sepal_width, petal_length, petal_width]] # Get the prediction (in the iris dataset, prediction is an integer class label 0,1,2 representing the species) prediction = model.predict(features)[0] # Get the first (only) prediction return {"prediction": str(prediction)}
from fastapi import FastAPI
import joblib
app = FastAPI()
# Load the trained model at startup (to avoid re-loading on every request)
model = joblib.load("model.joblib")  # Ensure this file exists from the training step
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
""" Predict the Iris flower species from input measurements. """
# Prepare the features for the model as a 2D list (model expects shape [n_samples, n_features])
features = [[sepal_length, sepal_width, petal_length, petal_width]]
# Get the prediction (in the iris dataset, prediction is an integer class label 0,1,2 representing the species)
prediction = model.predict(features)[0]  # Get the first (only) prediction
return {"prediction": str(prediction)}

在上面的程式碼中,我們建立了一個 FastAPI 應用,在執行檔案時,它會啟動 API 伺服器。FastAPI 對於 Python 來說速度超快,因此它可以輕鬆處理大量請求。然後我們在開始時載入模型,因為在每個請求上反覆執行此操作會很慢,所以我們將它儲存在記憶體中,隨時可以使用。我們使用 @app.get 建立了一個 /predict 端點,GET 使測試變得簡單,因為我們只需在 URL 中傳遞內容,但在實際專案中,您可能需要使用 POST,尤其是在傳送大型或複雜輸入(如影像或 JSON)時。此函式接受 4 個輸入: sepal_lengthsepal_widthpetal_lengthpetal_width,FastAPI 會自動從 URL 中讀取它們。在函式內部,我們將所有輸入放入一個二維列表中(因為 scikit-learn 只接受二維陣列),然後呼叫 model.predict(),它會給我們一個列表。然後我們將其以 JSON 格式返回,例如 {"prediction": "..."}

現在它可以正常工作了,您可以使用 uvicorn main:app --reload、hit /predict、endpoint 執行它並獲取結果。即使您再次傳送相同的輸入,它仍然會再次執行模型,這不太好,所以下一步是新增 Redis 來快取之前的結果,避免重複執行。

步驟 3:新增Redis快取用於預測

為了快取模型輸​​出,我們將使用 Redis。首先,確保 Redis 伺服器正在執行。您可以本地安裝,也可以執行 Docker 容器;它通常預設在 6379 埠執行。我們將使用 Python redis 庫與伺服器通訊。

思路很簡單:當請求到來時,建立一個代表輸入的唯一鍵。然後檢查該鍵是否存在於 Redis 中;如果該鍵已經存在,則意味著我們之前已經快取過,因此我們只需返回儲存的結果,無需再次呼叫模型。如果不存在,我們執行 model.predict,獲取輸出,將其儲存到 Redis 中,並返回預測結果。

現在,讓我們更新 FastAPI 應用以新增此快取邏輯。

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
!pip install redis
import redis # New import to use Redis
# Connect to a local Redis server (adjust host/port if needed)
cache = redis.Redis(host="localhost", port=6379, db=0)
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
"""
Predict the species, with caching to speed up repeated predictions.
"""
# 1. Create a unique cache key from input parameters
cache_key = f"{sepal_length}:{sepal_width}:{petal_length}:{petal_width}"
# 2. Check if the result is already cached in Redis
cached_val = cache.get(cache_key)
if cached_val:
# If cache hit, decode the bytes to a string and return the cached prediction
return {"prediction": cached_val.decode("utf-8")}
# 3. If not cached, compute the prediction using the model
features = [[sepal_length, sepal_width, petal_length, petal_width]]
prediction = model.predict(features)[0]
# 4. Store the result in Redis for next time (as a string)
cache.set(cache_key, str(prediction))
# 5. Return the freshly computed prediction
return {"prediction": str(prediction)}
!pip install redis import redis # New import to use Redis # Connect to a local Redis server (adjust host/port if needed) cache = redis.Redis(host="localhost", port=6379, db=0) @app.get("/predict") def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float): """ Predict the species, with caching to speed up repeated predictions. """ # 1. Create a unique cache key from input parameters cache_key = f"{sepal_length}:{sepal_width}:{petal_length}:{petal_width}" # 2. Check if the result is already cached in Redis cached_val = cache.get(cache_key) if cached_val: # If cache hit, decode the bytes to a string and return the cached prediction return {"prediction": cached_val.decode("utf-8")} # 3. If not cached, compute the prediction using the model features = [[sepal_length, sepal_width, petal_length, petal_width]] prediction = model.predict(features)[0] # 4. Store the result in Redis for next time (as a string) cache.set(cache_key, str(prediction)) # 5. Return the freshly computed prediction return {"prediction": str(prediction)}
!pip install redis
import redis  # New import to use Redis
# Connect to a local Redis server (adjust host/port if needed)
cache = redis.Redis(host="localhost", port=6379, db=0)
@app.get("/predict")
def predict(sepal_length: float, sepal_width: float, petal_length: float, petal_width: float):
"""
Predict the species, with caching to speed up repeated predictions.
"""
# 1. Create a unique cache key from input parameters
cache_key = f"{sepal_length}:{sepal_width}:{petal_length}:{petal_width}"
# 2. Check if the result is already cached in Redis
cached_val = cache.get(cache_key)
if cached_val:
# If cache hit, decode the bytes to a string and return the cached prediction
return {"prediction": cached_val.decode("utf-8")}
# 3. If not cached, compute the prediction using the model
features = [[sepal_length, sepal_width, petal_length, petal_width]]
prediction = model.predict(features)[0]
# 4. Store the result in Redis for next time (as a string)
cache.set(cache_key, str(prediction))
# 5. Return the freshly computed prediction
return {"prediction": str(prediction)}

在上面的程式碼中,我們現在新增了 Redis。首先,我們使用 redis.Redis() 建立了一個客戶端。它連線到 Redis 伺服器。預設情況下使用 db=0。然後,我們透過連線輸入值來建立快取鍵。這裡這種方法有效是因為輸入是簡單的數字,但對於複雜的數字,最好使用雜湊值或 JSON 字串。每個輸入的鍵必須是唯一的。我們使用了 cache.get(cache_key)。如果找到相同的鍵,它會返回該鍵,這使得速度更快,並且無需重新執行模型。但如果在快取中找不到,我們需要執行模型並獲取預測結果。最後,使用 cache.set() 將結果儲存到 Redis 中。這樣,下次相同的輸入到來時,它就已經存在了,快取速度會更快。

步驟 4:測試和衡量效能提升

現在我們的 FastAPI 應用已執行並連線到 Redis,是時候測試快取如何改善響應時間了。在這裡,我將演示如何使用 Python 的請求庫以相同的輸入呼叫 API 兩次,並測量每次呼叫所花費的時間。另外,請確保在執行測試程式碼之前啟動 FastAPI:

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
import requests, time
# Sample input to predict (same input will be used twice to test caching)
params = {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
# First request (expected to be a cache miss, will run the model)
start = time.time()
response1 = requests.get("http://localhost:8000/predict", params=params)
elapsed1 = time.time() - start
print("First response:", response1.json(), f"(Time: {elapsed1:.4f} seconds)")
import requests, time # Sample input to predict (same input will be used twice to test caching) params = { "sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2 } # First request (expected to be a cache miss, will run the model) start = time.time() response1 = requests.get("http://localhost:8000/predict", params=params) elapsed1 = time.time() - start print("First response:", response1.json(), f"(Time: {elapsed1:.4f} seconds)")
import requests, time
# Sample input to predict (same input will be used twice to test caching)
params = {
"sepal_length": 5.1,
"sepal_width": 3.5,
"petal_length": 1.4,
"petal_width": 0.2
}
# First request (expected to be a cache miss, will run the model)
start = time.time()
response1 = requests.get("http://localhost:8000/predict", params=params)
elapsed1 = time.time() - start
print("First response:", response1.json(), f"(Time: {elapsed1:.4f} seconds)")

測試和衡量效能提升

Plain text
Copy to clipboard
Open code in new window
EnlighterJS 3 Syntax Highlighter
# Second request (same params, expected cache hit, no model computation)
start = time.time()
response2 = requests.get("http://localhost:8000/predict", params=params)
elapsed2 = time.time() - start
print("Second response:", response2.json(), f"(Time: {elapsed2:.6f}seconds)")
# Second request (same params, expected cache hit, no model computation) start = time.time() response2 = requests.get("http://localhost:8000/predict", params=params) elapsed2 = time.time() - start print("Second response:", response2.json(), f"(Time: {elapsed2:.6f}seconds)")
# Second request (same params, expected cache hit, no model computation)
start = time.time()
response2 = requests.get("http://localhost:8000/predict", params=params)
elapsed2 = time.time() - start
print("Second response:", response2.json(), f"(Time: {elapsed2:.6f}seconds)")

測試和衡量效能提升

執行此程式碼時,您應該會看到第一個請求返回結果。第二個請求返回相同的結果,但速度明顯更快。例如,您可能會發現第一次呼叫耗時大約幾十毫秒(取決於模型複雜度),而第二次呼叫可能只需幾毫秒甚至更短。在我們基於輕量級模型的簡單演示中,差異可能很小(因為模型本身速度很快),但對於更重的模型,效果會非常顯著。

比較

為了更好地理解這一點,讓我們考慮一下我們實現的效果:

  • 不使用快取:每個請求,即使是相同的請求,都會命中模型。如果模型每次預測需要 100 毫秒,那麼 10 個相同的請求總共仍需要大約 1000 毫秒。
  • 使用快取:第一個請求會命中全部資料(100 毫秒),但接下來的 9 個相同的請求可能每個只需要 1-2 毫秒(僅僅是 Redis 查詢和返回資料)。因此,這 10 個請求的總耗時可能約為 120 毫秒,而不是 1000 毫秒,在這種情況下速度可提升約 8 倍。

在實際實驗中,快取可以帶來數量級的提升。例如,在電子商務中,使用 Redis 意味著可以在幾微秒內為重複請求返回建議,而無需使用完整的模型服務流水線重新計算它們。效能提升取決於模型推理的成本。模型越複雜,快取重複呼叫帶來的收益就越大。這還取決於請求模式:如果每個請求都是唯一的,快取將無濟於事(無需從記憶體中提供重複請求),但許多應用程式確實會看到重疊請求(例如,熱門搜尋查詢、推薦商品等)。

您也可以直接檢查 Redis 快取,以驗證它是否儲存了鍵。​​

小結

在本文章中,我們演示了 FastAPI 和 Redis 如何協同工作以加速機器學習模型服務。 FastAPI 提供了一個快速且易於構建的 API 層來提供預測服務,而 Redis 則新增了一個快取層,可顯著降低重複計算的延遲和 CPU 負載。透過避免重複的模型呼叫,我們提高了響應速度,並使系統能夠使用相同的資源處理更多請求。

評論留言