什么是蒸馏模型(Distilled Model)?即更小更高效的模型

什么是蒸馏模型(Distilled Model)?即更小更高效的模型

你不可能没听说过 Deepseek,但您是否也在 Ollama 上看到过 Deepseek 的蒸馏模型?或者,如果您尝试过 Groq Cloud,也可能看到过类似的模型。但这些“蒸馏”模型到底是什么呢?在这里,“蒸馏”指的是组织发布的原始模型的蒸馏版本。蒸馏模型基本上是更小更高效的模型,旨在复制大型模型的行为,同时降低资源需求。

Deepseek

蒸馏模型的优势

  • 减少内存占用和计算需求
  • 降低推理和训练过程中的能耗
  • 更快的处理时间

如何引入蒸馏模型?

这一过程旨在保持性能,同时减少内存占用和计算需求。这是 Geoffrey Hinton 在其 2015 年的论文“Distilling the Knowledge in a Neural Network”中提出的一种模型压缩形式

Hinton 提出了一个问题:是否有可能训练一个大型神经网络,然后将其知识压缩到一个较小的网络中?他认为,较小的网络充当学生,而较大的网络充当老师。目标是让学生复制老师学到的关键权重。

Machine Learning

Source: Heroes of Machine Learning

通过分析教师的行为和预测,辛顿和他的同事们设计出了一种训练方法,可以让一个较小的(学生)网络有效地学习其权重。其核心思想是尽量减小学生输出与两类目标之间的误差:实际地面实况(硬目标)和教师预测(软目标)。

双重损失

  • 硬损失:这是根据真实(地面实况)标签测量的误差。这是您在标准训练中通常要优化的,以确保模型学习到正确的输出。
  • 软损失:这是根据教师预测测量的误差。虽然教师可能并不完美,但它的预测包含了输出类别相对概率的宝贵信息,可以指导学生模型实现更好的泛化。

训练目标是最小化这两种损失的加权和。软损失的权重用 λ 表示:

双重损失公式

在这个公式中,参数 λ(软权重)决定了从实际标签学习与模仿教师输出之间的平衡。尽管有人可能会说,真实标签应该足以满足训练的需要,但结合教师的预测(软损失)实际上有助于加快训练速度,并通过细微的信息指导学生提高成绩。

Softmax函数和Temperature

该方法的一个关键组成部分是通过一个名为 Temperature (T) 的参数来修改 softmax 函数。softmax 函数也称为归一化指数函数,它将神经网络的原始输出分数(logits)转换为概率。对于具有 y_i 值的节点 i,标准 softmax 的定义如下:

Softmax函数

Hinton 引入了包含 Temperature 参数的新版 softmax 函数:

新版 softmax 函数

  • 当 T=1 时:函数表现与标准 softmax 相似。
  • 当 T>1 时:指数值变得不那么极端,从而产生一种“softer”的类概率分布。换句话说,概率分布会变得更均匀,从而揭示出更多关于每个类别相对可能性的信息。

用Temperature调整损失

由于较高的 temperature 会产生较柔和的分布,因此在训练过程中会有效地缩放梯度。为了纠正这一点并保持从软目标中有效学习,软损失乘以 T^2。更新后的整体损失函数变为

用Temperature调整损失

这种表述方式确保了硬损失(来自实际标签)和经过Temperature调整的软损失(来自教师的预测)都能为学生模型的训练做出适当的贡献。

概述

  • 师生动态(Teacher-Student Dynamics):学生模型通过最小化与真实标签(硬损失)和教师预测(软损失)的误差来学习。
  • 加权损失函数(Weighted Loss Function):总体训练损失是硬损失和软损失的加权和,由参数 λ 控制。
  • 经Temperature调整Softmax(Temperature-Adjusted Softmax): 在 Softmax 函数中引入Temperature(T)会软化概率分布,将软损失乘以 T^2 可以在训练过程中补偿这种影响。

将这些元素结合在一起,就能高效地训练出经过提炼的网络,既能利用硬标签的精确性,又能利用教师预测提供的更丰富、更翔实的指导。这一过程不仅能加快训练速度,还能帮助较小的网络接近较大网络的性能。

DistilBERT

DistilBERT 对 Hinton 的蒸馏方法稍作修改,增加了余弦嵌入损失,以测量学生和教师嵌入向量之间的距离。下面是一个快速比较:

  • DistilBERT:6 层,6600 万个参数
  • BERT-base:12 层,1.1 亿个参数

两个模型都在相同的数据集(英语维基百科和多伦多图书语料库)上进行了再训练。关于评估任务

  • GLUE 任务:BERT-base 的平均准确率为 79.5%,而 DistilBERT 为 77%。
  • SQuAD 数据集:BERT-base 的 F1 准确率为 88.5%,而 DistilBERT 为 86%。

DistillGPT2

针对 GPT-2,最初发布了四种尺寸:

  • 最小的 GPT-2 有 12 层,大约有 1.17 亿个参数(由于实现上的差异,有些报告指出有 1.24 亿个参数)。
  • DistillGPT2是经过提炼的版本,有 6 层和 8200 万个参数,同时保留了相同的嵌入大小(768)。

您可以在 Hugging Face 上探索该模型。

尽管 distillGPT2 的速度是 GPT-2 的两倍,但它在大型文本数据集上的困惑度却高出 5 个百分点。在 NLP 中,较低的困惑度表示较好的性能;因此,最小的 GPT-2 仍然优于其蒸馏后的同类产品。

实施LLM蒸馏

实施大型语言模型(LLM)蒸馏涉及多个步骤,需要使用专门的框架和库。下面概述了这一过程:

框架和库

  • Hugging Face 转换器:提供一个 Distiller 类,可简化从教师模型到学生模型的知识转移。
  • 其他库:
    • TensorFlow模型优化:提供模型剪枝、量化和蒸馏工具。
    • PyTorch Distiller:包含使用蒸馏技术压缩模型的实用工具。
    • DeepSpeed:由微软开发,包含模型训练和蒸馏功能。

相关步骤

  1. 数据准备:准备一个能代表目标任务的数据集。数据增强技术可进一步提高训练示例的多样性。
  2. 选择教师模型:选择一个性能良好、经过预先训练的教师模型。教师的质量直接影响学生的表现。
  3. 蒸馏过程
    • 训练设置:初始化学生模型并配置训练参数(如学习率、批量大小)。
    • 知识传输:使用教师模型生成软目标(概率分布)和硬目标(地面实况标签)。
    • 训练循环:训练学生模型,使其预测与软/硬目标之间的综合损失最小。
  4. 评估指标:评估提炼模型的常用指标包括
    • 准确度:预测正确率。
    • 推理速度:处理输入所需的时间。
    • 模型大小:规模缩小和计算效率。
    • 资源利用率:推理过程中计算资源消耗的效率。

了解蒸馏模型

了解蒸馏模型

Source: Knowledge Distillation

模型蒸馏的关键要素

选择教师和学生模型架构

学生模型可以是教师模型的简化版或量化版,也可以是完全不同的优化架构。选择取决于部署环境的具体要求。

选择教师和学生模型架构

Source: Relationship of teacher – student model

蒸馏过程解析

这一过程的核心是训练学生模型模仿教师的行为。这是通过最小化学生预测与教师输出之间的差异来实现的 – 这种监督学习方法构成了模型蒸馏的基础。

蒸馏过程解析

Source: Knowledge Distillation Core Concepts

挑战与局限

虽然蒸馏模型具有明显的优势,但也有一些挑战需要考虑:

  • 精度权衡:与大型模型相比,蒸馏模型的性能通常会略有下降。
  • 蒸馏过程的复杂性:配置合适的训练环境和微调超参数(如 λ 和 T)可能具有挑战性。
  • 领域适应性:蒸馏的有效性可能会因使用模型的特定领域或任务而有所不同。

模型蒸馏的未来方向

模型蒸馏领域发展迅速。一些前景广阔的领域包括

  • 蒸馏技术的进步:正在进行的研究旨在缩小教师模型和学生模型之间的性能差距。
  • 自动蒸馏过程:自动调整超参数的新方法不断涌现,使蒸馏过程更方便、更高效。
  • 更广泛的应用:除 NLP 外,模型蒸馏技术在计算机视觉、强化学习和其他领域的应用也越来越广泛,有可能改变资源受限环境下的部署方式。

实际应用

经过提炼的模型正在各行各业得到实际应用:

  • 移动和边缘计算:它们体积更小,非常适合部署在计算能力有限的设备上,确保移动应用程序和物联网设备的推理速度更快。
  • 能源效率:在云服务等大规模部署中,降低功耗至关重要。蒸馏模型有助于降低能耗。
  • 快速原型开发:对于初创企业和研究人员来说,蒸馏模型可在性能和资源效率之间取得平衡,从而加快开发周期。

小结

蒸馏模型在高性能和计算效率之间实现了微妙的平衡,从而改变了深度学习。虽然它们可能会因为较小的规模和对软损失训练的依赖而牺牲一些准确性,但其更快的处理速度和更低的资源需求使它们在资源有限的环境中尤为宝贵。

从本质上讲,蒸馏网络模仿了其较大对应网络的行为,但由于容量有限,其性能永远无法超越。当计算资源有限或其性能非常接近原始模型时,这种权衡使蒸馏模型成为明智的选择。相反,如果性能下降明显,或者通过并行化等方法可以随时获得计算能力,那么选择原始的大型模型可能是更好的选择。

评论留言