使用自定義損失函式和校準指標評估深度學習模型

使用自定義損失函式和校準指標評估深度學習模型

文章目录

  • 傳統深度學習模型評估
  • 問題
  • 什麼是自定義損失函式?
  • 為什麼要構建自定義損失函式?
  • 如何實現自定義損失函式?
  • 為什麼模型校準如此重要?
  • 校準誤差
  • 校準指標
  • PyTorch案例研究
  • 關鍵步驟
  • 小結

使用自定義損失函式和校準指標評估深度學習模型

評估深度學習模型是模型生命週期管理的重要組成部分。雖然傳統模型擅長快速提供模型效能基準,但它們往往無法捕捉實際應用的細微目標。例如,欺詐檢測系統可能優先考慮最小化假陰性而不是假陽性,而醫療診斷模型可能更看重召回率而不是準確率。在這種情況下,僅僅依賴傳統指標可能會導致模型行為不理想。這時,自定義損失函式和定製評估指標就派上用場了。

傳統深度學習模型評估

評估分類結果的傳統指標包括準確率、召回率、F1 分數等。交叉熵損失是分類的首選損失函式。這些典型的分類指標僅評估預測是否正確,而忽略了不確定性。

一個模型可能擁有很高的準確率,但機率估計卻很差。現代深度網路過於自信,即使錯誤,返回的機率也約為 0 或 1。

問題

Guo 等人的研究顯示,即使高度準確的深度模型,其校準也可能存在問題。同樣,一個模型可能擁有很高的 F1 分數,但其不確定性估計仍然可能存在校準誤差。最佳化目標函式(例如準確率或對數損失函式)也可能導致機率校準誤差,因為傳統的評估指標無法評估模型的置信度是否與現實相符。例如,肺炎檢測 AI 可能會根據在無害條件下也會發生的模式輸出 99.9% 的機率,從而導致過度自信。諸如溫度縮放之類的校準方法可以調整這些分數,使其更好地反映真實的可能性。

什麼是自定義損失函式?

自定義損失函式或目標函式是您為表達特定目標而發明的任何訓練損失函式(除了交叉熵和 MSE 等標準損失函式之外)。當更通用的損失函式無法滿足您的業務需求時,您可以自行開發一個。

例如,您可以使用一個損失函式,該函式對假陰性、漏報欺詐的懲罰力度要大於對假陽性的懲罰力度。這讓您可以處理不均衡的懲罰或目標,例如最大化 F1 值,而不僅僅是準確率。損失函式只是一個平滑的數學公式,用於比較預測值與標籤值,因此您可以設計任何公式來精確模擬您想要的指標或成本。

為什麼要構建自定義損失函式?

有時,預設損失函式會在重要案例(例如,稀有類別)上訓練不足,或者無法反映您的效用。自定義損失函式使您能夠:

  • 與業務邏輯保持一致:例如,對某種疾病的漏檢懲罰是誤報的 5 倍。
  • 處理不均衡:降低多數類別的權重,或關注少數類別。
  • 編碼領域啟發式演算法:例如,要求預測遵循單調性或排序規則。
  • 針對特定指標進行最佳化:近似 F1 值/準確率/召回率,或特定領域的投資回報率 (ROI)。

如何實現自定義損失函式?

在本節中,我們將使用 PyTorch 的 nn.Module 函式實現自定義損失函式。以下是其關鍵點:

    • 可微分性:確保損失函式對於模型輸出可微分。
  • 數值穩定性:在 PyTorch 中使用對數和指數函式或穩定函式(F.log_softmaxF.cross_entropy 等)。例如,可以以相同的方式使用 F.cross_entropy(包含 softmax 和 log),但隨後乘以 (1−𝑝𝑡)𝛾 來編寫 Focal Loss。這種方法避免了在單獨的 softmax 中計算機率,從而避免了下溢。
  • 程式碼示例:為了演示這個想法,以下是在 PyTorch 中定義自定義 Focal Loss 的方法:
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
   def __init__(self, gamma=2.0, weight=None):
       super(FocalLoss, self).__init__()
       self.gamma = gamma
       self.weight = weight  # weight tensor for classes (optional)
   def forward(self, logits, targets):
       # Compute standard cross entropy loss per-sample
       ce_loss = F.cross_entropy(logits, targets, weight=self.weight, reduction='none')
       p_t = torch.exp(-ce_loss)          # The model's estimated probability for true class
       loss = ((1 - p_t) ** self.gamma) * ce_loss
       return loss.mean()

這裡,γ 調整了我們對困難樣本的關注程度。γ 越高,關注度越高,這意味著權重可以處理類別不平衡問題。

我們使用 Focal Loss 作為損失函式,因為它旨在解決物體檢測和其他機器學習任務中的類別不平衡問題,尤其是在處理大量易分類樣本(例如物體檢測中的背景)時。這使得它非常適合我們的任務。

為什麼模型校準如此重要?

校準描述了預測機率與真實世界頻率的對應程度。如果在所有將機率 p 分配給正類的例項中,大約有 p 部分為正類,則該模型校準良好。換句話說,“置信度 = 準確率(confidence = accuracy)”。例如,如果一個模型在 100 個測試用例上預測為 0.8,我們預計大約 80 個是正確的。在使用機率進行決策(例如風險評分;成本效益分析)時,校準非常重要。形式上,這意味著對於具有機率輸出𝑝^的分類器,校準是:

二元分類器

二元分類器的完美校準條件

校準誤差

校準誤差分為兩類:

  1. 過度自信:指模型的預測機率系統性地高於真實機率(例如,預測結果為 90%,但 80% 的時間都是正確的)。深度神經網路往往過於自信,尤其是在引數過度的情況下。過度自信的模型可能很危險;它們通常會做出過強的預測,並且在錯誤分類時會誤導我們。
  2. 欠自信:欠自信在深度網路中並不常見。這與過度自信相反,指的是模型的置信度過低(例如,預測結果為 60%,但 80% 的時間都是正確的)。雖然欠自信通常會使模型在預測時處於更安全的位置,但它可能看起來不夠確定,因此不太實用。

在實踐中,現代深度神經網路通常都過於自信。Guo 等人發現,具有批次規範、更深層等特徵的較新的深度網路,即使在誤分類的情況下,也會在某一類別中出現尖峰後驗分佈,機率非常高。當我們出現這些校準誤差時,認識到這些誤差對於我們做出可靠的預測至關重要。

校準指標

  • 信度圖:校準曲線。信度圖通常稱為信度圖,它還會根據預測的置信度得分將預測的成功結果分入不同的箱體。對於每個箱體,它繪製了正樣本的比例(y 軸)與平均預測機率(x 軸)的關係。

信度圖 校準曲線

Source: IQ

  • 預期校準誤差 (ECE):它概括了準確度和置信度之間的絕對差異,並根據 bin 的大小進行加權。形式上,acc(b)conf(b) 分別是準確度和 bin 大小的平均置信度。提醒一下,ECE 值越低越好(0=完美校準)。ECE 是平均校準誤差的量度。

預期校準誤差 (ECE)

預期校準誤差公式

  • 最大校準誤差 (MCE):所有區間中的最大差距:

最大校準誤差 (MCE)

最大校準誤差公式

  • Brier 分數:Brier 分數是預測機率與實際結果之間的均方誤差,其值為 0 或 1。這是一個合理的評分規則,可以同時反映校準性和準確率。然而,Brier 分數較低並不意味著預測校準良好。它兼具校準性和判別力。

PyTorch案例研究

本部分,我們將使用 BigMart Sales 資料集來演示自定義損失函式和校準矩陣如何幫助預測目標列 OutletSales

我們透過設定中位數閾值,將連續型 OutletSales 轉換為二元“高 vs 低”類別。然後,我們使用產品可見性等特徵在 PyTorch 中擬合一個簡單的分類器,並應用自定義損失函式和校準矩陣。

關鍵步驟

資料準備和預處理:在本部分中,我們將匯入庫、載入資料,以及最重要的資料預處理步驟。例如,缺失值處理、使分類列統一(“低脂”、“低脂”,如果所有列都相同,則它們將變為“低脂”)、為目標變數設定閾值、對分類變數執行獨熱編碼 (OHE) 以及拆分特徵。

import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
# ----- missing-value handling -----
df['Weight'].fillna(df['Weight'].mean(), inplace=True)
df['OutletSize'].fillna(df['OutletSize'].mode()[0], inplace=True)
# ----- categorical cleaning -----
df['FatContent'].replace(
{'low fat': 'Low Fat', 'LF': 'Low Fat', 'reg': 'Regular'},
inplace=True
)
# ----- classification target -----
threshold = df['OutletSales'].median()
df['SalesCategory'] = (df['OutletSales'] > threshold).astype(int)
# ----- one-hot encode categoricals -----
cat_cols = [
'FatContent', 'ProductType', 'OutletID',
'OutletSize', 'LocationType', 'OutletType'
]
df = pd.get_dummies(df, columns=cat_cols, drop_first=True)
# ----- split features / labels -----
X = df.drop(['ProductID', 'OutletSales', 'SalesCategory'], axis=1).values
y = df['SalesCategory'].values
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=SEED, stratify=y
)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# create torch tensors
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.long)
X_test_t = torch.tensor(X_test, dtype=torch.float32)
y_test_t = torch.tensor(y_test, dtype=torch.long)
# split train into train/val (80/20 of original train)
val_frac = 0.2
val_size = int(len(X_train_t) * val_frac)
train_size = len(X_train_t) - val_size
train_ds, val_ds = random_split(
TensorDataset(X_train_t, y_train_t),
[train_size, val_size],
generator=torch.Generator().manual_seed(SEED)
)
train_loader = DataLoader(
train_ds, batch_size=64, shuffle=True, drop_last=True
)
val_loader = DataLoader(
val_ds, batch_size=256, shuffle=False
)

自定義損失:在第二步中,首先,我們將建立一個自定義的 SalesClassifier。假設我們應用焦點損失來更加重視少數類。然後,我們將重新調整模型以最大化焦點損失而不是交叉熵損失。在許多情況下,焦點損失會增加對少數類的召回率,但可能會降低原始準確率。之後,我們將在自定義 SoftF1Loss 的幫助下訓練我們的銷售分類器超過 100 個 epoch,並將最佳模型儲存為 best_model.pt

class SalesClassifier(nn.Module):
   def __init__(self, input_dim):
       super().__init__()
       self.net = nn.Sequential(
           nn.Linear(input_dim, 128),
           nn.BatchNorm1d(128),
           nn.ReLU(inplace=True),
           nn.Dropout(0.5),
           nn.Linear(128, 64),
           nn.ReLU(inplace=True),
           nn.Dropout(0.25),
           nn.Linear(64, 2)          # logits for 2 classes
       )
   def forward(self, x):
       return self.net(x)
# class-weighted CrossEntropy to fight imbalance
class_weights = compute_class_weight('balanced',
                                    classes=np.unique(y_train),
                                    y=y_train)
class_weights = torch.tensor(class_weights, dtype=torch.float32,
                            device=device)
ce_loss = nn.CrossEntropyLoss(weight=class_weights)

在這裡,我們將使用一個名為 SoftF1Loss 的自定義損失函式。因此,這裡的 SoftF1Loss 是一個自定義損失函式,它以可區分的方式直接最佳化 F1 分數,使其適合基於梯度的訓練。它不使用硬 0/1 預測,而是使用來自模型輸出( torch.softmax )的軟機率,因此損失會隨著預測的變化而平滑變化。它使用這些機率和真實標籤計算軟真正例(TP)、假正例(FP)和假負例(FN),然後計算軟精度和召回率。由此,它得出一個“軟”F1 分數並返回 1 – F1,以便最小化損失將最大化 F1 分數。這在處理不平衡資料集時特別有用,因為準確率並不是衡量效能的良好指標。

# Differentiable Custom Loss Function Soft-F1 loss
class SoftF1Loss(nn.Module):
   def forward(self, logits, labels):
       probs = torch.softmax(logits, dim=1)[:, 1]   # positive-class prob
       labels = labels.float()
       tp = (probs * labels).sum()
       fp = (probs * (1 - labels)).sum()
       fn = ((1 - probs) * labels).sum()
       precision = tp / (tp + fp + 1e-7)
       recall    = tp / (tp + fn + 1e-7)
       f1 = 2 * precision * recall / (precision + recall + 1e-7)
       return 1 - f1
f1_loss = SoftF1Loss()
def total_loss(logits, targets, alpha=0.5):
   return alpha * ce_loss(logits, targets) + (1 - alpha) * f1_loss(logits, targets)
model = SalesClassifier(X_train.shape[1]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
best_val = float('inf'); patience=10; patience_cnt=0
for epoch in range(1, 101):
   model.train()
   train_losses = []
   for xb, yb in train_loader:
       xb, yb = xb.to(device), yb.to(device)
       optimizer.zero_grad()
       logits = model(xb)
       loss = total_loss(logits, yb)
       loss.backward()
       optimizer.step()
       train_losses.append(loss.item())
   # ----- validation -----
   model.eval()
   with torch.no_grad():
       val_losses = []
       for xb, yb in val_loader:
           xb, yb = xb.to(device), yb.to(device)
           val_losses.append(total_loss(model(xb), yb).item())
       val_loss = np.mean(val_losses)
   if epoch % 10 == 0:
       print(f'Epoch {epoch:3d} | TrainLoss {np.mean(train_losses):.4f}'
             f' | ValLoss {val_loss:.4f}')
   # ----- early stopping -----
   if val_loss < best_val - 1e-4:
       best_val = val_loss
       patience_cnt = 0
       torch.save(model.state_dict(), 'best_model.pt')
   else:
       patience_cnt += 1
       if patience_cnt >= patience:
           print('Early stopping!')
           break

自定義損失函式

# load best weights
model.load_state_dict(torch.load('best_model.pt'))

校準前/後:在此流程中,我們可能發現基線模型的 ECE 較高,表明該模型過於自信。因此,基線模型的預期校準誤差 (ECE) 可能偏高/偏低,表明該模型過於自信/不足。

現在,我們可以使用溫度縮放來校準模型,然後重複該過程以計算新的 ECE 並繪製新的可靠性曲線。我們將看到,在溫度縮放之後,可靠性曲線可能會更接近對角線。

class ModelWithTemperature(nn.Module):
   def __init__(self, model):
       super().__init__()
       self.model = model
       self.temperature = nn.Parameter(torch.ones(1) * 1.5)
   def forward(self, x):
       logits = self.model(x)
       return logits / self.temperature
model_ts = ModelWithTemperature(model).to(device)
optim_ts = torch.optim.LBFGS([model_ts.temperature], lr=0.01, max_iter=50)
def _nll():
   optim_ts.zero_grad()
   logits = model_ts(X_val := X_test_t.to(device))   # use test set to fit T
   loss = ce_loss(logits, y_test_t.to(device))
   loss.backward()
   return loss
optim_ts.step(_nll)
print('Optimal temperature:', model_ts.temperature.item())
Optimal temperature: 1.585491418838501

視覺化:在本節中,我們將繪製校準“前”和“後”的可靠性圖表。這些圖表直觀地表示了改進後的比對效果。

@torch.no_grad()
def get_probs(mdl, X):
   mdl.eval()
   logits = mdl(X.to(device))
   return F.softmax(logits, dim=1).cpu()
def ece(probs, labels, n_bins=10):
   conf, preds = probs.max(1)
   accs = preds.eq(labels)
   bins = torch.linspace(0,1,n_bins+1)
   ece_val = torch.zeros(1)
   for lo, hi in zip(bins[:-1], bins[1:]):
       mask = (conf>lo) & (conf<=hi)
       if mask.any():
           ece_val += torch.abs(accs[mask].float().mean() - conf[mask].mean()) \
                      * mask.float().mean()
   return ece_val.item()
def plot_reliability(ax, probs, labels, n_bins=10, label='Model'):
   conf, preds = probs.max(1)
   accs = preds.eq(labels)
   bins = torch.linspace(0,1,n_bins+1)
   bin_acc, bin_conf = [], []
   for i in range(n_bins):
       mask = (conf>bins[i]) & (conf<=bins[i+1])
       if mask.any():
           bin_acc.append(accs[mask].float().mean().item())
           bin_conf.append(conf[mask].mean().item())
   ax.plot(bin_conf, bin_acc, marker='o', label=label)
   ax.plot([0,1],[0,1],'--',color='orange')
   ax.set_xlabel('Confidence'); ax.set_ylabel('Accuracy')
   ax.set_title('Reliability Diagram'); ax.grid(); ax.legend()
probs_before = get_probs(model   , X_test_t)
probs_after  = get_probs(model_ts, X_test_t)
print('\nClassification report (calibrated logits):')
print(classification_report(y_test, probs_after.argmax(1)))

圖表

print('ECE before T-scaling :', ece(probs_before, y_test_t))
print('ECE after  T-scaling :', ece(probs_after , y_test_t))
#----------------------------------------
# ECE before T-scaling : 0.05823298543691635
# ECE after  T-scaling : 0.02461853437125683
# ----------------------------------------------
# reliability plot
fig, ax = plt.subplots(figsize=(6,5))
plot_reliability(ax, probs_before, y_test_t, label='Before T-Scaling')
plot_reliability(ax, probs_after , y_test_t, label='After  T-Scaling')
plt.show()

繪製校準“前”和“後”的可靠性圖表

此圖表顯示了在溫度縮放之前(藍色)和之後(橙色)置信度得分與實際值的匹配程度。x 軸表示其平均置信度,y 軸表示這些預測被正確評分的頻率。虛線對角線表示與此線重合的完美校準點,表示……

例如,置信度為 70% 的預測,其正確評分率為 70%。縮放後,橙線比藍線更緊密地貼合這條對角線。尤其是在置信度為 0.6 到 0.9 的置信度空間的“中間”位置,並且幾乎與理想點 (1.0, 1.0) 相交。換句話說,溫度縮放可以降低模型過度或不足置信的傾向,從而使其機率的點估計更加準確。

點選此處檢視完整 notebook

小結

在現實世界的 AI 應用中,有效性和校準同等重要。一個模型可能具有很高的效度,但如果模型的置信度不準確,那麼更高的效度也毫無意義。因此,在訓練過程中根據您的問題陳述開發自定義損失函式可以符合我們的真實目標,並且我們會評估校準,以便能夠恰當地解釋預測機率。

因此,完整的評估策略會同時考慮兩者:我們首先允許自定義損失函式充分最佳化模型以適應任務,然後我們會有意識地校準和驗證機率輸出。現在,我們可以建立一個決策支援工具,其中“90% 的置信度”實際上就是“90% 的可能性”,這對於任何實際應用都至關重要。

評論留言