本发明涉及大语言模型,尤其涉及一种大语言模型的蒸馏方法、装置、电子设备及可读存储介质。
背景技术:
1、模型蒸馏(model distillation)是一种模型压缩技术,旨在通过将一个复杂而大的模型(称为教师模型)的知识传递给一个较小而简单的模型(称为学生模型),从而使学生模型在保留高性能的同时,减少计算资源和存储空间的消耗。
2、模型蒸馏技术对于大语言模型(large language models, llms)具有重要意义。大语言模型通常包含数十亿甚至数千亿参数,在推理阶段需要消耗大量的计算资源和存储空间。
3、因此,如何对大语言模型进行蒸馏,就成为一个亟待解决的问题。
技术实现思路
1、本发明的目的在于提供一种大语言模型的蒸馏方法、装置、电子设备及可读存储介质。
2、为了实现上述发明目的之一,本发明一实施方式提供了一种大语言模型的蒸馏方法,包括以下步骤:获取若干已训练的教师大语言模型,获取包含有多个训练数据的第一数据集合;获取第一数据集合的子集为第二数据集合,基于若干教师大语言模型对第二数据集合中的每个训练数据均进行处理,并获取每个教师大语言模型所输出的文本序列,从而得到每个训练数据对应的若干文本序列;获取学生大语言模型,对第二数据集合中的每个训练数据均进行以下处理:获取所述训练数据对应的若干文本序列一一对应的若干token向量,将若干token向量排成一个队列,按照从队头朝向队尾的次序获取待训练token向量,基于所述待训练token向量训练所述学生大语言模型。
3、作为本发明一实施方式的进一步改进,所述基于若干教师大语言模型对第二数据集合中的每个训练数据均进行处理具体包括:基于若干教师大语言模型对第二数据集合中的每个训练数据均进行处理,且在每次进行处理,当输出的文本序列结束、或者输出的文本序列的长度大于预设阈值时,停止此次处理;当输出的文本序列的末尾呈现循环重复的字符串时,对所述文本序列的末尾进行裁剪,只保留一个循环重复的字符串,其中,所述字符串由若干token组成。
4、作为本发明一实施方式的进一步改进,所述获取所述训练数据对应的若干文本序列一一对应的若干token向量具体包括:获取所述训练数据对应的若干文本序列一一对应的若干token向量,当任一token向量中的token的数量len1大于预设长度值len2时,丢弃所述token向量中的后面的len1-len2个token。
5、作为本发明一实施方式的进一步改进,所述基于所述待训练token向量训练所述学生大语言模型具体包括:所述待训练token向量中的token的数量为len3,i=1,持续进行以下操作直至i=len3,所述操作包括:获取所述token向量中的前i个token,将前i个token输入到所述学生大语言模型进行推理,将推理得到的词表的概率分布和第i+1个token的独热编码的向量进行计算并得到交叉熵,将所述交叉熵作为损失函数,进行损失函数计算,梯度计算和反向传播,之后,i的值增加1。
6、作为本发明一实施方式的进一步改进,所述损失函数为交叉损失函数。
7、作为本发明一实施方式的进一步改进,len2=16。
8、本发明实施例还提供了一种大语言模型的蒸馏装置,包括以下模块:信息获取模块,用于获取若干已训练的教师大语言模型,获取包含有多个训练数据的第一数据集合;教师模块,用于获取第一数据集合的子集为第二数据集合,基于若干教师大语言模型对第二数据集合中的每个训练数据均进行处理,并获取每个教师大语言模型所输出的文本序列,从而得到每个训练数据对应的若干文本序列;学生模块,用于获取学生大语言模型,对第二数据集合中的每个训练数据均进行以下处理:获取所述训练数据对应的若干文本序列一一对应的若干token向量,将若干token向量排成一个队列,按照从队头朝向队尾的次序获取待训练token向量,基于所述待训练token向量训练所述学生大语言模型。
9、作为本发明一实施方式的进一步改进,所述教师模块还用于:基于若干教师大语言模型对第二数据集合中的每个训练数据均进行处理,且在每次进行处理,当输出的文本序列结束、或者输出的文本序列的长度大于预设阈值时,停止此次处理;当输出的文本序列的末尾呈现循环重复的字符串时,对所述文本序列的末尾进行裁剪,只保留一个循环重复的字符串,其中,所述字符串由若干token组成。
10、本发明实施例还提供了一种电子设备,包括:存储器,用于保存计算机程序;处理器,用于执行所述计算机程序以实现上述的蒸馏方法。
11、本发明实施例还提供了一种可读存储介质,用于保存计算机程序,所述计算机程序被处理器执行时实现上述的蒸馏方法。
12、相对于现有技术,本发明的技术效果在于:本发明实施例提供了一种大语言模型的蒸馏方法、装置、电子设备及可读存储介质,该蒸馏方法包括:获取若干已训练的教师大语言模型,获取包含有多个训练数据的第一数据集合;获取第一数据集合的子集为第二数据集合,基于若干大语言教师模型对第二数据集合中的每个训练数据均进行处理,从而得到第二数据集合中的每个训练数据对应的若干文本序列;获取大语言学生模型,对第二数据集合中的训练数据的文本序列对大语言学生模型进行训练。从而能够对大语言模型的进行蒸馏处理。
1.一种大语言模型的蒸馏方法,其特征在于,包括以下步骤:
2.根据权利要求1所述的蒸馏方法,其特征在于,所述基于若干教师大语言模型对第二数据集合中的每个训练数据均进行处理具体包括:
3.根据权利要求2所述的蒸馏方法,其特征在于,所述获取所述训练数据对应的若干文本序列一一对应的若干token向量具体包括:
4.根据权利要求3所述的蒸馏方法,其特征在于,所述基于所述待训练token向量训练所述学生大语言模型具体包括:
5.根据权利要求4所述的蒸馏方法,其特征在于,
6.根据权利要求4所述的蒸馏方法,其特征在于,
7.一种大语言模型的蒸馏装置,其特征在于,包括以下模块:
8.根据权利要求7所述的蒸馏装置,其特征在于,所述教师模块还用于:
9.一种电子设备,其特征在于,包括:
10.一种可读存储介质,其特征在于,用于保存计算机程序,所述计算机程序被处理器执行时实现如权利要求1至6任一项所述的蒸馏方法。