本发明涉及联邦学习领域,具体而言,涉及一种基于联邦大模型的反转知识蒸馏方法和系统。
背景技术:
1、伴随着大模型的出现与逐渐成熟,大模型对于更大规模训练数据的依赖以及数据隐私催生了一个新的问题:即如何在不侵犯数据隐私法律条款的基础上,利用私人领域的孤立数据联合训练一个大规模模型。而一种基于联邦学习的大模型的训练架构与方法应运而生,解决分布式的大模型训练问题。联邦学习通过分布式的方式训练模型,不需要将数据集中到一个中心位置,从而有效地保护了用户的隐私。而知识蒸馏则通过将复杂模型的知识传递给较小模型,实现了模型的轻量化和高效推理。在边缘计算场景中,结合联邦学习和知识蒸馏技术,可以在保障数据隐私的前提下,提升模型在边缘设备上的性能,使得智能应用更加高效和便捷。联邦知识蒸馏是一种结合联邦学习和知识蒸馏的先进技术,通过在分布式环境中协作训练模型来提升整体性能和准确性。在这种框架下,不同的设备和节点可以共享它们的局部模型知识,而无需直接共享原始数据,从而保护数据隐私。同时,通过知识蒸馏技术,将复杂模型的知识传递给更轻量的模型,优化模型的推理效率和资源消耗。这种方法特别适用于对计算效率要求高的场景,如医疗健康和金融行业,实现高效且准确的智能应用。
2、在实现本发明的过程中,申请人发现:传统的知识蒸馏策略是由一个经过训练的大模型来将知识传授给一个或未经过训练的小模型中,往往忽略了具备个性化数据的小模型对大模型的贡献。这些传统策略无法让大模型在某些垂直领域的特定数据集上充分学习,导致性能和个性化的用户体验不理想。此前的研究仍存在一个明显的限制,大多数研究都局限于为同一个架构的模型进行联邦知识蒸馏,并未考虑计算资源受限且异构的客户端设备之间的协作训练。
3、因此,如何提升服务器端大参数模型的训练效率以及性能表象成为需要解决的技术问题。
技术实现思路
1、本发明旨在至少解决现有技术或相关技术中存在的技术问题之一,公开了一种基于联邦大模型的反转知识蒸馏方法和系统,通过多个小参数模型对大参数模型进行知识传递,能够减少网络资源消耗,提高准确度,降低训练轮次以及提升用户体验。
2、术语解释:联邦学习(federated learning)是一种分布式机器学习技术,它允许多个数据拥有方在不共享数据的情况下建立机器学习模型。联邦学习的目标是在保证数据隐私安全及合法合规的基础上,实现共同建模,提升 ai 模型的效果。知识蒸馏(knowledgedistillation)是一种模型压缩技术,它通过训练一个小型的神经网络(称为“学生模型”)来模仿一个大型的、预训练的神经网络(称为“教师模型”)。这种方法可以看作是知识从教师模型向学生模型的转移,因此得名“知识蒸馏”。本发明的第一方面公开了一种基于联邦大模型的反转知识蒸馏方法,包括:服务器向客户本地的客户端下发小参数模型,以便客户端使用私有数据集对小参数模型的权重参数进行更新,得到客户端模型;客户端模型使用公共数据集输出客户端模型软逻辑,并将客户端模型软逻辑发送至服务器;服务器的大参数模型使用公共数据集输出服务器模型软逻辑;根据服务器模型软逻辑与多个客户端模型软逻辑间的相关性以及每个客户端模型对公共数据集的准确率来计算每一个客户端的权重;根据客户端权重与客户端模型软逻辑计算加权客户软逻辑;使用服务器模型软逻辑与加权客户软逻辑进行知识蒸馏,以对服务器大参数模型的权重参数进行更新。
3、在该技术方案中,公开了一种多个客户端小参数模型将知识传递给单个服务器大参数模型的联邦知识蒸馏方法,本发明将此技术框架定义为反转知识蒸馏。由于客户计算资源的限制,服务器向客户端下发的是大参数模型的压缩版本,即小参数模型。客户接受来自服务器的小参数模型并使用私有数据集对其进行训练,客户端模型软逻辑输出是使用训练后客户端模型对公共图像数据集输出一个软逻辑,客户权重计算是计算每一个客户端在联邦知识蒸馏的过程中所占到的权重,客户的权重之和为1,服务器模型知识蒸馏是指服务器获得客户端加权软逻辑和服务器软逻辑,对服务器模型进行知识蒸馏过程对服务器模型进行更新。
4、根据本发明公开的基于联邦大模型的反转知识蒸馏方法,优选地,知识蒸馏具体包括:计算服务器模型软逻辑与加权客户软逻辑间的蒸馏损失;蒸馏损失与服务器模型的任务损失进行加权求和,得到服务器模型总损失;根据服务器模型总损失对服务器模型进行反向传播,从而依照客户端模型更新的方式来对服务器模型进行更新。服务器模型的任务损失是指服务器模型针对数据集输出自己的预测与数据集的真实预测之间的差距,通常使用交叉熵损失函数进行计算。根据本发明公开的基于联邦大模型的反转知识蒸馏方法,优选地,小参数模型是由服务器的大参数模型压缩而得。
5、根据本发明公开的基于联邦大模型的反转知识蒸馏方法,优选地,公共数据集由各个客户端所提供的数据组成。
6、根据本发明公开的基于联邦大模型的反转知识蒸馏方法,优选地,客户端模型软逻辑具体包括:使用softmax归一化函数作为客户模型的软逻辑输出函数。
7、根据本发明公开的基于联邦大模型的反转知识蒸馏方法,优选地,相关性是指余弦距离。
8、根据本发明公开的基于联邦大模型的反转知识蒸馏方法,优选地,服务器的大参数模型为resnet38模型,客户端模型为resnet14模型。
9、在该技术方案中,resnet(残差网络)是由微软研究院在2015年提出的深度学习模型,其核心思想是引入了残差学习框架,通过跳跃连接(skip connections)来解决深度神经网络训练中的梯度消失和梯度爆炸问题。resnet模型的核心是resnet模型中特有的残差块(residual blocks),resnet模型常见的有resnet-14、resnet-38、resnet-50等,其中14、38、50等数字代表了模型中残差块的数量,例如resnet-38包含了38个残差块,resnet-14包含了14个残差块。在resnet模型中,残差块分类两类:basicblock和bottleblock,若resnet模型采用basicblock,则resnet模型中残差块basicblock的数量x应满足等式(x-2)% 6=0;若resnet模型采用bottleblock,则resnet模型中残差块bottleblock的数量x应满足等式(x-2)% 9=0。
10、根据本发明公开的基于联邦大模型的反转知识蒸馏方法,优选地,客户端为个人计算机或手持移动设备,服务器为基站或边缘服务器。
11、本发明的第二方面公开了一种基于联邦大模型的反转知识蒸馏系统,包括:存储器,用于存储程序指令;处理器,用于调用存储器中存储的程序指令以实现如上述任一技术方案的基于联邦大模型的反转知识蒸馏方法。
12、本发明提供的技术方案的实际应用场景:随着人工智能技术的发展,多数摄像头已具备了图像识别功能。但某些情况下,摄像头针对一些从未见过的物体无法做出有效识别。通过本发明提供的方法,对于某些从未见过的物体,只需部分摄像头进行本地更新后来对一个服务器内的模型进行蒸馏过程。这样当某些摄像头无法识别某些物体或有新接入摄像头时,就可以向服务器发送请求来获得新模型,而不需要经过多次的本地更新,降低能源消耗的同时也可以降低摄像头时延,提升表现。
13、本发明的有益效果至少包括:本发明考虑到了传统联邦学习中模型聚合过程可能带来的用户隐私泄露问题,而删去了模型聚合过程,提高了网络资源的利用率。本发明提供的技术方案将联邦学习与知识蒸馏联合的概念扩展到了传统网络环境中,以一种独特的多个小参数模型对单个大参数模型进行知识蒸馏的方法,为服务器内大参数模型的更新与迭代提供了一种全新的方法,利用联邦学习与知识蒸馏来处理服务器内大参数模型对于多种数据类型的要求。能够让一个表现良好的模型在客户本地运行的同时提升服务器模型在某些未知数据集上的性能,同时还可以通过减少服务器模型的训练轮次来节省服务资源的消耗。
1.一种基于联邦大模型的反转知识蒸馏方法,其特征在于,包括:
2.根据权利要求1所述的基于联邦大模型的反转知识蒸馏方法,其特征在于,所述知识蒸馏具体包括:
3.根据权利要求1所述的基于联邦大模型的反转知识蒸馏方法,其特征在于,所述小参数模型是由所述服务器的大参数模型压缩而得。
4.根据权利要求1所述的基于联邦大模型的反转知识蒸馏方法,其特征在于,所述公共数据集由各个客户端所提供的数据组成。
5.根据权利要求1所述的基于联邦大模型的反转知识蒸馏方法,其特征在于,所述客户端模型软逻辑具体包括:使用softmax归一化函数作为客户模型的软逻辑输出函数。
6.根据权利要求1所述的基于联邦大模型的反转知识蒸馏方法,其特征在于,所述相关性是指余弦距离。
7.根据权利要求1至6中任一项所述的基于联邦大模型的反转知识蒸馏方法,其特征在于,所述服务器的大参数模型为resnet38模型,所述客户端模型为resnet14模型。
8.根据权利要求1至6中任一项所述的基于联邦大模型的反转知识蒸馏方法,其特征在于,所述客户端为个人计算机或手持移动设备,所述服务器为基站或边缘服务器。
9.一种基于联邦大模型的反转知识蒸馏系统,其特征在于,包括:
