什麼是蒸餾模型(Distilled Model)?即更小更高效的模型

什麼是蒸餾模型(Distilled Model)?即更小更高效的模型

文章目录

  • 蒸餾模型的優勢
  • 如何引入蒸餾模型?
  • 雙重損失
  • Softmax函式和Temperature
  • 用Temperature調整損失
  • 概述
  • DistilBERT
  • DistillGPT2
  • 實施LLM蒸餾
  • 框架和庫
  • 相關步驟
  • 瞭解蒸餾模型
  • 模型蒸餾的關鍵要素
  • 挑戰與侷限
  • 模型蒸餾的未來方向
  • 實際應用
  • 小結

什麼是蒸餾模型(Distilled Model)?即更小更高效的模型

你不可能沒聽說過 Deepseek,但您是否也在 Ollama 上看到過 Deepseek 的蒸餾模型?或者,如果您嘗試過 Groq Cloud,也可能看到過類似的模型。但這些“蒸餾”模型到底是什麼呢?在這裡,“蒸餾”指的是組織釋出的原始模型的蒸餾版本。蒸餾模型基本上是更小更高效的模型,旨在複製大型模型的行為,同時降低資源需求。

Deepseek

蒸餾模型的優勢

  • 減少記憶體佔用和計算需求
  • 降低推理和訓練過程中的能耗
  • 更快的處理時間

如何引入蒸餾模型?

這一過程旨在保持效能,同時減少記憶體佔用和計算需求。這是 Geoffrey Hinton 在其 2015 年的論文“Distilling the Knowledge in a Neural Network”中提出的一種模型壓縮形式

Hinton 提出了一個問題:是否有可能訓練一個大型神經網路,然後將其知識壓縮到一個較小的網路中?他認為,較小的網路充當學生,而較大的網路充當老師。目標是讓學生複製老師學到的關鍵權重。

Machine Learning

Source: Heroes of Machine Learning

通過分析教師的行為和預測,辛頓和他的同事們設計出了一種訓練方法,可以讓一個較小的(學生)網路有效地學習其權重。其核心思想是儘量減小學生輸出與兩類目標之間的誤差:實際地面實況(硬目標)和教師預測(軟目標)。

雙重損失

  • 硬損失:這是根據真實(地面實況)標籤測量的誤差。這是您在標準訓練中通常要優化的,以確保模型學習到正確的輸出。
  • 軟損失:這是根據教師預測測量的誤差。雖然教師可能並不完美,但它的預測包含了輸出類別相對概率的寶貴資訊,可以指導學生模型實現更好的泛化。

訓練目標是最小化這兩種損失的加權和。軟損失的權重用 λ 表示:

雙重損失公式

在這個公式中,引數 λ(軟權重)決定了從實際標籤學習與模仿教師輸出之間的平衡。儘管有人可能會說,真實標籤應該足以滿足訓練的需要,但結合教師的預測(軟損失)實際上有助於加快訓練速度,並通過細微的資訊指導學生提高成績。

Softmax函式和Temperature

該方法的一個關鍵組成部分是通過一個名為 Temperature (T) 的引數來修改 softmax 函式。softmax 函式也稱為歸一化指數函式,它將神經網路的原始輸出分數(logits)轉換為概率。對於具有 y_i 值的節點 i,標準 softmax 的定義如下:

Softmax函式

Hinton 引入了包含 Temperature 引數的新版 softmax 函式:

新版 softmax 函式

  • 當 T=1 時:函式表現與標準 softmax 相似。
  • 當 T>1 時:指數值變得不那麼極端,從而產生一種“softer”的類概率分佈。換句話說,概率分佈會變得更均勻,從而揭示出更多關於每個類別相對可能性的資訊。

用Temperature調整損失

由於較高的 temperature 會產生較柔和的分佈,因此在訓練過程中會有效地縮放梯度。為了糾正這一點並保持從軟目標中有效學習,軟損失乘以 T^2。更新後的整體損失函式變為

用Temperature調整損失

這種表述方式確保了硬損失(來自實際標籤)和經過Temperature調整的軟損失(來自教師的預測)都能為學生模型的訓練做出適當的貢獻。

概述

  • 師生動態(Teacher-Student Dynamics):學生模型通過最小化與真實標籤(硬損失)和教師預測(軟損失)的誤差來學習。
  • 加權損失函式(Weighted Loss Function):總體訓練損失是硬損失和軟損失的加權和,由引數 λ 控制。
  • 經Temperature調整Softmax(Temperature-Adjusted Softmax): 在 Softmax 函式中引入Temperature(T)會軟化概率分佈,將軟損失乘以 T^2 可以在訓練過程中補償這種影響。

將這些元素結合在一起,就能高效地訓練出經過提煉的網路,既能利用硬標籤的精確性,又能利用教師預測提供的更豐富、更翔實的指導。這一過程不僅能加快訓練速度,還能幫助較小的網路接近較大網路的效能。

DistilBERT

DistilBERT 對 Hinton 的蒸餾方法稍作修改,增加了餘弦嵌入損失,以測量學生和教師嵌入向量之間的距離。下面是一個快速比較:

  • DistilBERT:6 層,6600 萬個引數
  • BERT-base:12 層,1.1 億個引數

兩個模型都在相同的資料集(英語維基百科和多倫多圖書語料庫)上進行了再訓練。關於評估任務

  • GLUE 任務:BERT-base 的平均準確率為 79.5%,而 DistilBERT 為 77%。
  • SQuAD 資料集:BERT-base 的 F1 準確率為 88.5%,而 DistilBERT 為 86%。

DistillGPT2

針對 GPT-2,最初發布了四種尺寸:

  • 最小的 GPT-2 有 12 層,大約有 1.17 億個引數(由於實現上的差異,有些報告指出有 1.24 億個引數)。
  • DistillGPT2是經過提煉的版本,有 6 層和 8200 萬個引數,同時保留了相同的嵌入大小(768)。

您可以在 Hugging Face 上探索該模型。

儘管 distillGPT2 的速度是 GPT-2 的兩倍,但它在大型文字資料集上的困惑度卻高出 5 個百分點。在 NLP 中,較低的困惑度表示較好的效能;因此,最小的 GPT-2 仍然優於其蒸餾後的同類產品。

實施LLM蒸餾

實施大型語言模型(LLM)蒸餾涉及多個步驟,需要使用專門的框架和庫。下面概述了這一過程:

框架和庫

  • Hugging Face 轉換器:提供一個 Distiller 類,可簡化從教師模型到學生模型的知識轉移。
  • 其他庫:
    • TensorFlow模型優化:提供模型剪枝、量化和蒸餾工具。
    • PyTorch Distiller:包含使用蒸餾技術壓縮模型的實用工具。
    • DeepSpeed:由微軟開發,包含模型訓練和蒸餾功能。

相關步驟

  1. 資料準備:準備一個能代表目標任務的資料集。資料增強技術可進一步提高訓練示例的多樣性。
  2. 選擇教師模型:選擇一個效能良好、經過預先訓練的教師模型。教師的質量直接影響學生的表現。
  3. 蒸餾過程
    • 訓練設定:初始化學生模型並配置訓練引數(如學習率、批量大小)。
    • 知識傳輸:使用教師模型生成軟目標(概率分佈)和硬目標(地面實況標籤)。
    • 訓練迴圈:訓練學生模型,使其預測與軟/硬目標之間的綜合損失最小。
  4. 評估指標:評估提煉模型的常用指標包括
    • 準確度:預測正確率。
    • 推理速度:處理輸入所需的時間。
    • 模型大小:規模縮小和計算效率。
    • 資源利用率:推理過程中計算資源消耗的效率。

瞭解蒸餾模型

瞭解蒸餾模型

Source: Knowledge Distillation

模型蒸餾的關鍵要素

選擇教師和學生模型架構

學生模型可以是教師模型的簡化版或量化版,也可以是完全不同的優化架構。選擇取決於部署環境的具體要求。

選擇教師和學生模型架構

Source: Relationship of teacher – student model

蒸餾過程解析

這一過程的核心是訓練學生模型模仿教師的行為。這是通過最小化學生預測與教師輸出之間的差異來實現的 – 這種監督學習方法構成了模型蒸餾的基礎。

蒸餾過程解析

Source: Knowledge Distillation Core Concepts

挑戰與侷限

雖然蒸餾模型具有明顯的優勢,但也有一些挑戰需要考慮:

  • 精度權衡:與大型模型相比,蒸餾模型的效能通常會略有下降。
  • 蒸餾過程的複雜性:配置合適的訓練環境和微調超引數(如 λ 和 T)可能具有挑戰性。
  • 領域適應性:蒸餾的有效性可能會因使用模型的特定領域或任務而有所不同。

模型蒸餾的未來方向

模型蒸餾領域發展迅速。一些前景廣闊的領域包括

  • 蒸餾技術的進步:正在進行的研究旨在縮小教師模型和學生模型之間的效能差距。
  • 自動蒸餾過程:自動調整超引數的新方法不斷湧現,使蒸餾過程更方便、更高效。
  • 更廣泛的應用:除 NLP 外,模型蒸餾技術在計算機視覺、強化學習和其他領域的應用也越來越廣泛,有可能改變資源受限環境下的部署方式。

實際應用

經過提煉的模型正在各行各業得到實際應用:

  • 移動和邊緣計算:它們體積更小,非常適合部署在計算能力有限的裝置上,確保移動應用程式和物聯網裝置的推理速度更快。
  • 能源效率:在雲服務等大規模部署中,降低功耗至關重要。蒸餾模型有助於降低能耗。
  • 快速原型開發:對於初創企業和研究人員來說,蒸餾模型可在效能和資源效率之間取得平衡,從而加快開發週期。

小結

蒸餾模型在高效能和計算效率之間實現了微妙的平衡,從而改變了深度學習。雖然它們可能會因為較小的規模和對軟損失訓練的依賴而犧牲一些準確性,但其更快的處理速度和更低的資源需求使它們在資源有限的環境中尤為寶貴。

從本質上講,蒸餾網路模仿了其較大對應網路的行為,但由於容量有限,其效能永遠無法超越。當計算資源有限或其效能非常接近原始模型時,這種權衡使蒸餾模型成為明智的選擇。相反,如果效能下降明顯,或者通過並行化等方法可以隨時獲得計算能力,那麼選擇原始的大型模型可能是更好的選擇。

評論留言