基于知识蒸馏的多阶段脉冲神经网络训练方法及装置

专利2025-06-02  42


本发明涉及深度学习,尤其涉及一种基于知识蒸馏的多阶段脉冲神经网络训练方法及装置。


背景技术:

1、脉冲神经网络(snn)作为一种新兴的神经网络,以神经元脉冲信号传递信息,具有接近生物神经系统的独特特性,脉冲神经网络利用代表神经元放电的脉冲离散事件来处理和传输信息,并且还具有高效的事件驱动处理和低功耗的优势。在生物神经系统中,神经元之间的信息传递是通过离散的脉冲信号进行的。相比之下,传统的人工神经网络(ann)采用连续的激活函数进行信息传递,忽略了脉冲的时态编码方式。而研究脉冲神经网络可以更好地模拟和理解生物神经系统的工作原理,从而提供对大脑运行机制的深入理解,因此如何有效地训练脉冲神经网络目前是一个非常值得研究的问题。

2、现阶段主流的构建脉冲神经网络方法主要有三种:第一种是采用尖峰时间依赖可塑性(stdp)学习规则,它是一种与生物神经系统中的可塑性机制有关的hebbian型学习规则,它可以用于模拟神经元之间突触权重的调整,但是不涉及传统的神经网络的权重更新方法。因此缺乏全局指导导致用stdp规则训练的脉冲神经网络不能扩展到具有良好性能的深层网络,并且被限制在用于解决简单任务的浅层结构,第二种是用设计的连续函数代替脉冲神经网络中的实际梯度,以近似尖峰神经元的不连续导数。该近似对于网络规模较小的脉冲神经网络是足够的,但是如果脉冲神经网络采用深度结构以及解决具有挑战性的任务的网络则难以胜任。此外,使用替代梯度方法训练的脉冲神经网络无法获得与其对应的人工神经网络相当的性能。第三种是ann-to-snn转换方法,其将预训练的人工神经网络的参数转换为具有相同结构的脉冲神经网络。尽管这种方法即使在大型网络中也能实现几乎等效的表示,但ann转换的脉冲神经网络通常需要大量的时间步长,时间步长达到数百甚至数千,从而导致大量的功耗,这和脉冲神经网络的初衷背道而驰。

3、因此,更高效的训练脉冲神经网络,缩小其与人工神经网络之间精度的差距,是一个非常值得研究的问题,所以需要有一种更创新且有效的训练方法,使得训练的模型,不仅能够拥有更接近人工神经网络的精度,也能使得训练过程更加的高效且低耗。


技术实现思路

1、本发明的目的在于针对现有技术的不足,提供一种基于知识蒸馏的多阶段脉冲神经网络训练方法及装置。

2、本发明的目的是通过以下技术方案来实现的:本发明实施例第一方面提供了一种基于知识蒸馏的多阶段脉冲神经网络训练方法,包括以下步骤:

3、(1)获取已知图像数据以构建数据集;

4、(2)选取人工神经网络作为知识蒸馏中的教师模型,选取脉冲神经网络作为知识蒸馏中的学生模型,并预设脉冲神经网络中lif神经元的参数;

5、(3)将作为学生模型的脉冲神经网络和作为教师模型的人工神经网络按照模型结构划分为k个不同的阶段;

6、(4)在使用数据集对脉冲神经网络进行训练的过程中,根据人工神经网络每个阶段输出的特征和脉冲神经网络每个阶段输出的特征计算该阶段基于特征的样本间损失;

7、(5)在使用数据集对脉冲神经网络进行训练的过程中,将脉冲神经网络每个阶段输出的特征通过全连接层映射到概率分布空间,再通过softmax函数转换为概率输出,并以人工神经网络最终输出的概率作为基础知识,计算该阶段基于置信度的损失;

8、(6)根据k个阶段的基于特征的样本间损失和基于置信度的损失计算总损失函数,以最小化总损失函数为优化目标,采用反向传播法更新脉冲神经网络的网络参数,以获取训练好的脉冲神经网络,用于执行图像分类任务。

9、进一步地,所述步骤(4)具体包括:对于第i个阶段的基于特征的样本间损失,其计算方法具体包括:

10、首先假设人工神经网络在第i个阶段输出的特征ft映射为gram矩阵,其维度为b*b,其中b表示批次大小;随后对于脉冲神经网络,在第i个阶段的输出,则有每个时间步特征将其映射为gram矩阵,其维度为t*b*b,其中t表示时间步长;再在时间维度上对所有时间步特征对应的gram矩阵进行平均处理,得到平均处理后的gram矩阵;其次,分别对特征ft对应的gram矩阵和平均处理后的gram矩阵进行l2正则化处理,分别得到人工神经网络和脉冲神经网络对应的正则化后的gram矩阵;最后根据人工神经网络和脉冲神经网络对应的正则化后的gram矩阵计算基于特征的样本间损失linter。

11、进一步地,所述基于特征的样本间损失linter的计算公式为:

12、

13、其中,β为调节因子,rt表示人工神经网络对应的正则化后的gram矩阵,rs表示脉冲神经网络对应的正则化后的gram矩阵,其表达式分别为:

14、rt=exp||rel(ft)||2

15、

16、其中,exp表示指数函数,||·||2表示l2正则化,frel(ft)表示特征ft对应的gram矩阵,表示时间步特征对应的gram矩阵,其表达式分别为:

17、frel(ft)=ft·tt

18、

19、其中,ftt表示特征ft的转置,表示时间步特征的转置。

20、进一步地,所述步骤(5)具体包括:对于第i个阶段的基于置信度的损失lconf,其计算方法具体包括:

21、首先将脉冲神经网络在第i个阶段输出的特征通过全连接层映射到概率分布空间,再通过softmax函数将其转换为概率输出psc,具体通过在指定的时间窗口内对脉冲尖峰进行计数并将其归一化处理得到;然后,根据人工神经网络最终输出的概率ptc和第i个阶段脉冲神经网络的概率输出psc计算基于置信度的损失lconf。

22、进一步地,所述概率输出psc的计算公式为:

23、

24、其中,psc表示脉冲神经网络的输出为类别c的概率输出,nc和nj分别表示时间窗口t内类别c和类别j的峰值计数;

25、所述基于置信度的损失lconf的计算公式为:

26、

27、其中,ptc表示人工神经网络的最终输出为类别c的概率输出,c表示总的类别数量,γ是一个可调的超参数。

28、进一步地,所述总损失函数的计算公式为:

29、

30、其中,ltotal表示总损失函数,λ是用于平衡基于特征的样本间损失linter和基于置信度的损失lconf的超参数,i表示第i个阶段。

31、本发明实施例第二方面提供了一种基于知识蒸馏的多阶段脉冲神经网络训练装置,包括一个或多个处理器和存储器,所述存储器与所述处理器耦接;其中,所述存储器用于存储程序数据,所述处理器用于执行所述程序数据以实现上述的基于知识蒸馏的多阶段脉冲神经网络训练方法。

32、本发明实施例第三方面提供了一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,用于实现上述的基于知识蒸馏的多阶段脉冲神经网络训练方法。

33、本发明的有益效果为,本发明通过分阶段将脉冲神经网络的中间输出映射到logits空间中,与和人工神经网络的最终结果分别计算基于置信度的损失和基于特征的样本间损失,最后把各阶段的基于置信度的损失和基于特征的样本间损失整合计算得到总损失函数,从而训练更新脉冲神经网络的网络参数,用于执行图像分类任务;本发明所述的训练方法将人工神经网络和脉冲神经网络的输出映射到统一的logits空间,并且考虑了基于置信度和特征的样本间的联合损失,制定了一个能够让脉冲神经网络更高效从人工神经网络中提取知识的方案,从而进一步缩小了人工神经网络和脉冲神经网络之间的精度的差距。


技术特征:

1.一种基于知识蒸馏的多阶段脉冲神经网络训练方法,其特征在于,包括以下步骤:

2.根据权利要求1所述的基于知识蒸馏的多阶段脉冲神经网络训练方法,其特征在于,所述步骤(4)具体包括:对于第i个阶段的基于特征的样本间损失,其计算方法具体包括:

3.根据权利要求2所述的基于知识蒸馏的多阶段脉冲神经网络训练方法,其特征在于,所述基于特征的样本间损失lintet的计算公式为:

4.根据权利要求1所述的基于知识蒸馏的多阶段脉冲神经网络训练方法,其特征在于,所述步骤(5)具体包括:对于第i个阶段的基于置信度的损失lconf,其计算方法具体包括:

5.根据权利要求4所述的基于知识蒸馏的多阶段脉冲神经网络训练方法,其特征在于,所述概率输出psc的计算公式为:

6.根据权利要求1所述的基于知识蒸馏的多阶段脉冲神经网络训练方法,其特征在于,所述总损失函数的计算公式为:

7.一种基于知识蒸馏的多阶段脉冲神经网络训练装置,包括一个或多个处理器和存储器,其特征在于,所述存储器与所述处理器耦接;其中,所述存储器用于存储程序数据,所述处理器用于执行所述程序数据以实现权利要求1-6中任一项所述的基于知识蒸馏的多阶段脉冲神经网络训练方法。

8.一种计算机可读存储介质,其特征在于,其上存储有程序,该程序被处理器执行时,用于实现权利要求1-6中任一项所述的基于知识蒸馏的多阶段脉冲神经网络训练方法。


技术总结
本发明公开了一种基于知识蒸馏的多阶段脉冲神经网络训练方法及装置,该方法通过分阶段将脉冲神经网络的中间输出映射到logits空间中,与和人工神经网络的最终结果分别计算基于置信度的损失和基于特征的样本间损失,最后把各阶段的基于置信度的损失和基于特征的样本间损失整合计算得到总损失函数,从而训练更新脉冲神经网络的网络参数,用于执行图像分类任务。本发明所述的训练方法将人工神经网络和脉冲神经网络的输出映射到统一的logits空间,并且考虑了基于置信度和特征的样本间的联合损失,制定了一个能够让脉冲神经网络更高效从人工神经网络中提取知识的方案,从而进一步缩小了人工神经网络和脉冲神经网络之间的精度的差距。

技术研发人员:梁秀波,魏铭远
受保护的技术使用者:浙江大学
技术研发日:
技术公布日:2024/12/17
转载请注明原文地址:https://xbbs.6miu.com/read-24925.html