一种transformer解码器的计算优化方法、系统及应用与流程

专利2025-07-02  7


本发明属于大语言模型、人工智能,涉及一种transformer解码器的计算优化方法、系统及应用。


背景技术:

1、transformer decoder主要存在两个方面的耗时,一个方面是访存耗时,传统计算流程是通过从ddr内存中读取输入数据以及权重等信息,在片上sram中计算完成,再将计算结果写入ddr内存中,由于片上sram内存有限,因此会存在大量的重复访问ddr内存,增加transformer decoder耗时。另外一方面是计算耗时,transformer decoder存在很多权重参数很大的矩阵,自回归模式下,也会存在大量重复计算。

2、目前主流的优化方法有以下几种:

3、1.由于transformer attention在自回归模式下,q,k,v存在大量重复计算,因此提出了kv cache来缓存k,v,减少重复计算,从而减少计算耗时

4、2.引入flash attention,通过将输入分割成块,并在输入块上进行多次传递,小块的计算进行在sram上进行,每个小块单独计算softmax,减少ddr的读和写操作,最后进行补偿整体softmax计算,将结果写入ddr中

5、可以看出,flash attention可以减少很多不必要的ddr访存次数,减少计算耗时,是目前比较好的优化方法,但是由于flash attention只考虑优化了输入以及q,k,v的分块单独计算,没有考虑结合attention中kv cache/旋转编码,也没有考虑decoder中feedforward模块以及其余linear层的操作,因此在整个transformer decoder计算仍存在较大优化空间。


技术实现思路

1、为了解决现有技术存在的不足,本发明的目的是提供一种transformer解码器的计算优化方法。

2、本发明的计算优化方法通过改变flash attention/feedforward network以及其余线性linear层计算流程以及切分方式,在减少不必要的ddr访问次数的情况下,减少sram内存占用,优化transformer decoder整个流程的计算耗时,实现transformer解码器的计算优化。

3、所述flash attention是一种优化的注意力机制,通常用于加速transformer模型的计算过程;所述feedforward network是指前馈网络,其中连接不形成环,数据只在一个方向上传播。

4、具体地,本发明方法针对目前解码器decoder中依然存在多余的访存次数以及计算操作,融合kv cache,优化flash attention以及feedforward中的相关计算流程,并进行多次有效地切分计算,减少不必要的访存次数以及计算耗时。

5、具体地,本发明的计算优化方法主要包括如下步骤:

6、步骤一、attention模块优化:首先对输入、权重wq、wk、wv,键缓存k_cache以及值缓存v_cache的最后一维d进行切分,切分成i份,每次从ddr读一份,复用i次,减少i倍的内存占用;通过减少每次从ddr内存中读取的数据量,从而降低内存占用;

7、然后对输入序列长度seq_len进行切分,本发明中将所述输入序列长度切分成seq_len份,根据切的份数,每一份均通过矩阵乘计算得到q、k、v,并计算q、k的旋转编码,然后更新键缓存k_cache、值缓存v_cache,计算切分分块后的局部softmax,计算完所有切分的seq_len份数时,更新整体softmax值,得到softmax的输出结果;最后对权重wo的第一维切分成i份,对softmax的输出结果和每一份wo进行矩阵乘操作,执行i次,将所有份数的矩阵乘输出结果相加得到attention的输出结果;

8、wq:查询(query)权重,用于将输入转换为查询向量。

9、wk:键(key)权重,用于将输入转换为键向量。

10、wv:值(value)权重,用于将输入转换为值向量。

11、k_cache:键缓存,存储过去时间步骤中的键向量,用于加速连续解码步骤。

12、v_cache:值缓存,存储过去时间步骤中的值向量,同样用于提高解码效率。

13、分块操作通过将权重矩阵和输入序列分成多个小段,这样可以在不同时间步上分别计算并缓存结果,减少每次计算和内存访问的负载。

14、步骤二、feedforward模块优化:首先对前馈网络feedforward的权重w1,w3的最后一维切分成j份,减少在一次计算中处理的数据量,从而减少内存占用和可能的计算时间;然后将attention的输出结果和w1进行矩阵乘,再通过silu激活函数得到o1,通过上述非线性变换过程,能够增加模型的表达能力;同时将attention的输出结果和w3进行矩阵乘,得到o3,将o1和o3进行逐元素乘,得到o2,通过合并由不同权重生成的两个中间结果,能够增强特征表达;最后将权重w2的第一维切分成j份,将o2和w2进行矩阵乘,执行j次,将所有份数的输出结果拼接得到feedforward模块的输出结果;

15、步骤三、重复以上步骤优化transformer编码器的每一层;

16、步骤四、线性层优化:由于线性层linear最后一维维度很大,因此对最后一维进行切分,切分成k份,将feedforward模块的输出结果和切分后的linear进行矩阵乘,执行k次,将所有份数的输出结果拼接得到最后一层linear的输出结果。

17、对于以上四个部分,计算流程是完全串行执行的,因此每部分执行完,片上sram可以完全被复用。

18、本发明还提供了上述计算优化方法在机器翻译、文本生成、语音识别等中的应用。

19、本发明还提供了一种实现上述方法的硬件系统,所述硬件系统包括:存储器和处理器;所述存储器上存储有计算机程序,当所述计算机程序被所述处理器执行时,实现上述计算优化方法。

20、本发明还提供了一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时,实现上述计算优化方法。

21、本发明的有益效果包括:和现有技术相比,本发明方法在flash attention和kvcache缓存技术的基础上再进行数据计算的切分,首先节省了自回归模式下的kv计算,另外通过对数据切分计算,节省了片上sram的空间,使得每次计算都可以在片上sram上进行,不需要再把中间结果写入ddr中,频繁进行读写,attention模块得到的计算结果会直接放在sram上,用于下一模块feedforward计算,将每个部分串行起来,节省每个部分输出结果写入ddr的时间,另外由于尽可能地节省了sram的空间,使得ddr访存次数可以达到理论的访存次数,不需要额外重复地读写操作,通过数据切分,分批从ddr读取数据进行计算,使得访存时间能够掩盖片上sram数据的计算时间。因此,本发明方法相比于现有方案,能实现transformer decoder计算时间近似于理论ddr访存次数所用时间。


技术特征:

1.一种transformer解码器的计算优化方法,其特征在于,所述方法通过改变flashattention/feedforward network以及线性层计算流程以及切分方式,减少不必要的ddr内存访问次数,减少sram内存占用,优化transformer解码器整个流程的计算耗时,实现transformer解码器的计算优化。

2.如权利要求1所述的计算优化方法,其特征在于,所述方法包括如下步骤:

3.如权利要求2所述的计算优化方法,其特征在于,步骤一包括如下子步骤:

4.如权利要求2所述的计算优化方法,其特征在于,步骤二包括如下子步骤:

5.如权利要求2所述的计算优化方法,其特征在于,步骤四包括如下子步骤:

6.如权利要求1-5之任一项所述的计算优化方法在机器翻译、文本生成、语音识别中的应用。

7.一种实现如权利要求1-5之任一项所述计算优化方法的硬件系统,其特征在于,所述硬件系统包括:存储器和处理器;所述存储器上存储有计算机程序,当所述计算机程序被所述处理器执行时,实现如权利要求1-5任一项所述的计算优化方法。

8.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时,实现如权利要求1-5任一项所述的计算优化方法。


技术总结
本发明公开了一种transformer解码器的计算优化方法,所述方法通过改变flash attention/feedforward network以及线性层计算流程以及切分方式,减少不必要的ddr内存访问次数,减少SRAM内存占用,优化transformer解码器整个流程的计算耗时,实现transformer解码器的计算优化。本发明还公开了上述计算优化方法在机器翻译、文本生成、语音识别等中的应用。本发明还公开了实现上述计算优化方法的硬件系统和计算机可读存储介质。

技术研发人员:方国浩
受保护的技术使用者:上海曲速超为技术有限公司
技术研发日:
技术公布日:2024/12/17
转载请注明原文地址:https://xbbs.6miu.com/read-25898.html