本发明属于食谱检索方法,涉及基于模态交互的食谱检索方法。
背景技术:
1、随着食物对人们在健康和生活中的重要性逐渐增加,互联网和社交网络提供了丰富的食物图片和食谱资源。然而,这些资源之间往往缺乏有意义的连接,使得开发有效的食谱检索系统变得至关重要。
2、在这一背景下,跨模态食谱检索技术迅速崛起,实现了对给定食谱对应图像的检索。recipe1m数据集的推出提供了一个丰富的、大规模的跨模态食谱数据集,为算法的训练和评估提供了有力的支持。现有的跨模态食谱检索工作大多依赖于循环网络,如lstms或transformers来编码食谱数据;基础神经网络架构,如resnet-50或vit,则被广泛应用于处理和编码图像数据。然而,仅仅依赖全局相似性而忽略精细的模态交互在某种程度上限制了性能的发挥。但由于跨模态食谱检索任务存在不同模态之间的语义间隙问题,导致跨模态食谱检索在准确性上仍然存在不足,故研究准确的跨模态食谱检索有着重要意义。
技术实现思路
1、本发明的目的是提供基于模态交互的食谱检索方法,解决了现有技术中存在的由于跨模态食谱检索任务存在不同模态之间的语义间隙问题,导致跨模态食谱检索在准确性上存在不足的问题。
2、本发明所采用的技术方案是,基于模态交互的食谱检索方法,具体按照如下步骤实施:
3、步骤1,获取跨模态食谱检索数据集,并将分为训练集和测试集;
4、步骤2,构建基于细粒度模态间交互的跨模态食谱检索网络模型;
5、步骤3,使用训练集训练步骤2构建的基于细粒度模态间交互的跨模态食谱检索网络模型;
6、步骤4,使用测试集对步骤3训练好的网络模型进行测试,得到食谱文本检索图像结果,并对检索结果进行评价;
7、步骤5,在测试集上对步骤3训练好的网络模型进行测试,得到食谱图像检索文本结果,并对检索结果进行评价。
8、本发明的特征还在于,
9、步骤1具体为:
10、从官方网站上下载食谱数据集d={i,t},其中,i={i1,…,im,…,im},t={t1,…,tn,…,tn},im表示d中的第m幅图像,tn表示d中第n个食谱,1≤m≤m,1≤n≤n,m表示图像的总个数,n表示食谱的总个数,im∈rw×h×d,w,h,d分别对应im中图像的宽、高和通道个数,分别表示tn中的标题、成分和指令三部分表示,其中,n1,l1分别表示菜谱标题的句子个数和句子长度,菜谱标题的句子个数为1,故n1=1,n2,l2分别表示菜谱成分的句子个数和句子长度,n3,l3分别表示菜谱指令的句子个数和句子长度;将d按照a:b个数比例划分为训练集trd和测试集ted,trd={itrd,ttrd},x+y=a;
11、x′+y′=b,表示trd中第x幅图像,表示trd中第y个食谱,表示ted中第x'幅图像,表示ted中第y'个食谱。
12、步骤2中基于细粒度模态间交互的跨模态食谱检索网络模型包括食谱分级编码模块、基于文本上下文视觉增强的图像编码模块、配对模块;
13、基于细粒度模态间交互的跨模态食谱检索网络模型的工作过程为:
14、步骤2.1,食谱分级编码模块从trd中的ttrd提取第y个食谱其中,食谱将送入分级编码模块,输出ttly,ingy,insy的编码特征,分别为将经过一个矢量串联和一个全连接降维网络层,得到最终的食谱深度特征表示
15、步骤2.2,基于文本上下文视觉增强的图像编码模块从trd中的ttrd提取第x幅图像送入到基干网络resnet-50中,提取resnet-50网络的第k层的浅层特征将和步骤2.1得到的标题和成分编码特征进行特征增强得到再把再送入到基干网络resnet-50的第k+1层到分类的前一层的网络中提取深层特征再经过全连接层得到图像的深度特征
16、步骤2.3,配对模块计算图像深度特征和食谱深度特征的配对损失。
17、步骤2.1中的分级编码模块由嵌入子模块、一级注意力机制子模块和二级注意力机制子模块组成;
18、分级编码模块的工作过程为:
19、首先,将送入嵌入子模块输出得到erec,erec={ettl,eing,eins},其中其次,把erec送入一级注意力机制子模块,输出其中接着,把第和分别送入二级注意力机制子模块,输出和其中最后将在d1维度上进行拼接,将拼接后的特征输入到一个全连接层,输出降维后的数据其中
20、嵌入子模块的工作过程为:将作为输入,设定输出向量维度d1,调用开源的库pytorch文本单词数据映射函数nn.embedding函数,输出erec,erec={ettl,eing,eins},其中
21、一级注意力机制子模块由一个注意力层和一个全局平均池化层构成;
22、一级注意力机制子模块工作过程为:
23、将嵌入子模块的输出erec={ettl,eing,eins}作为注意力层的输入,注意力层定义查询向量值向量和键向量给查询向量、值向量和键向量赋值:求和的相关度和的相关度和的相关度求解公式如下:
24、
25、其中,softmax()为概率归一化激活函数;
26、然后计算注意力机制的输出其中
27、
28、最后,将分别送入全局平均池化层,得到其平均值作为菜谱标题特征、菜谱成分特征和菜谱指令特征的总体表示和其中
29、二级注意力机制子模块由一个注意力层和一个全局平均池化层构成;
30、二级注意力机制子模块的工作过程为:
31、将一级注意力机制子模块的输出数据作为二级注意力机制子模块的输入,输入到二级注意力机制子模块的注意力层,注意力层定义查询向量值向量和键向量给查询向量、值向量和键向量赋值:求和的相关度和的相关度
32、
33、其中,softmax()为概率归一化激活函数;
34、然后计算注意力机制的输出
35、
36、
37、最后,再将分别送入全局平均池化层,得到其平均值作为菜谱成分特征和菜谱指令特征的总体表示和其中
38、分级编码模块将一级注意力机制子模块输出的和二级注意力机制子模块输出的分别作为ttly,ingy,insy的编码特征输出。
39、步骤2.2中的基于文本上下文视觉增强的图像编码模块由图像浅层特征提取模块、文本-图像块相似度度量子模块、残差连接子模块以及图像深层特征提取子模块组成;
40、图像浅层特征提取模块的工作过程为:
41、从trd中的ttrd提取第x幅图像其中将作为图像浅层特征提取模块的输入,将送入resnet-50网络中提取浅层特征具体为:
42、
43、其中resnet-50sl()表示把送入resnet-50网络,经过第1层到第k层子网络,输出第k层的特征;
44、文本-图像块相似度度量子模块的工作过程为:
45、将步骤2.1中一级注意力机制子模块输出数据和图像浅层特征提取模块的输出数据作为文本-图像块相似度度量子模块的输入,将和进行拼接,得到其中cat()为串联函数:
46、
47、其中,cat()为串联函数;
48、以及,调用开源的库numpy中数组形状改变函数将改变成wk×hk个dk维度的特征
49、然后再计算和之间的点乘,调用概率归一化激活函数softmax(),得到文本和图像块之间的相似度分数其中,
50、
51、文本-图像块相似度度量子模块将作为输出;
52、残差连接子模块的工作过程为:
53、将文本-图像块相似度度量子模块的输出作为输入,根据得到的相似度分数对文本和图像残差连接得到
54、
55、图像深层特征提取子模块的工作过程为:
56、将残差连接子模块的输出作为图像深层特征提取子模块的输入,然后使用resnet-50网络提取的深层特征
57、
58、其中,resnet-50dl()表示把送入resnet-50网络,经过第k+1层到第k层,输出第k层的特征;
59、最后,将送入到全连接层中的到最终的图像特征
60、
61、其中,fc()表示将维度为dk的特征送入到全连接层中,输出维度为d2。
62、步骤3具体为:
63、步骤3.1,设置网络模型训练参数;即设置学习率变量lr、训练迭代最大次数变量max_iter、每批次配对数据大小变量batch_size_paired、训练迭代次数变量step,初始化step=1;
64、步骤3.2,将训练集trd输入到步骤2构建的基于细粒度模态间交互的跨模态食谱检索网络模型中,按照设置的训练参数进行网络训练,当目标损失函数变化小于minloss或者训练迭代次数变量step≥max_iter时,网络训练结束,保存网络模型model;否则step=step+1,使用adam优化器来反向修正网络模型中各网络层的权重系数,继续训练。
65、步骤3.2中的目标损失函数为配对损失函数lpair:
66、
67、lcos(xa,xp,xn)=max(0,c(xa,xn)-c(xa,xp)+m) (20)
68、
69、其中,配对损失lpair由一个平均双向三重损失lbi构成,其中lpair损失函数的输入数据分别是食谱分级编码模块的输出数据基于文本上下文视觉增强的图像编码模块的输出数据平均双向三重损失lbi指的是计算每个batch中的每个菜谱图像和文本对之间的双向三重损失l'bi的平均值,对于每个菜谱图像和文本之间的双向三重损失函数l'bi的输入数据为一个batch中的第i个和第j个菜谱图像文本对,双向三重损失由两个三重损失lcos组成,其中,当i=j时,δ(i,j)=0,当i≠j时,δ(i,j)=1;三重损失函数lcos(xa,xp,xn)的输入数据为一个三元组,其中xa表示训练样本,xp表示正样本,xn表示负样本,m是超参数,c()表示余弦距离,旨在减少训练样本xa和正样本xp之间的距离,扩大训练样本xa和负样本xn之间的距离;双向三重损失函数l'bi由两部分组成,和其中的输入数据代表训练样本中第i个菜谱图像文本经过基于文本上下文视觉增强的图像编码模块后的输出数据,代表训练样本中第i个菜谱图像文本经过食谱分级编码模块后的输出数据,代表了正样本,代表训练样本中第j个菜谱图像文本经过食谱分级编码模块后的输出数据,代表了负样本,函数旨在减小菜谱图像数据与同一菜谱图像文本对中的文本数据之间的余弦距离,扩大菜谱图像数据与不同菜谱图像文本对中的文本数据之间的余弦距离;的输入数据代表训练样本中第i个菜谱图像文本经过食谱分级编码模块后的输出数据,代表训练样本中第i个菜谱图像文本经过基于文本上下文视觉增强的图像编码模块后的输出数据,代表了正样本,代表训练样本中第j个菜谱图像文本经过基于文本上下文视觉增强的图像编码模块后的输出数据,代表了负样本,函数旨在减小菜谱文本数据与同一菜谱图像文本对中的图像数据之间的余弦距离,扩大菜谱文本数据与不同菜谱图像文本对中的图像数据之间的余弦距离。
70、步骤4具体为:
71、步骤4.1,定义文本特征索引y',初始化为1,定义图像特征索引x',初始化为1,定义食谱分级编码模块输出文本特征vrec,初始化为空,vrec=null,定义基于文本上下文化的视觉增强模块输出图像特征vimg,初始化为空,vimg=null,定义文本检索图像的余弦距离结果d,初始化d为空,即d=null,定义文本检索图像结果r2ited,初始化r2ited为空,即r2ited=null;
72、步骤4.2,从测试集ted={ited,tted}的ited中选取第x'个图像特征送入到训练好的基于文本上下文视觉增强的图像编码模块中,输出结果为并将追加到vimg中,判断x'是否小于等于x',如果x'≤x',则x'=x'+1,并继续重复步骤4.2;
73、步骤4.3,从测试集ted={ited,tted}的tted中选取第y'个文本特征送入到训练好的嵌入子模块,一级注意力机制子模块和二级注意力子模块,输出结果为
74、步骤4.4,使用开源的库pytorch中的torch.mm()函数计算和的余弦相似度余弦距离
75、
76、步骤4.5,使用python内置排序函数sprt()对进行排序,余弦距离越小,相似度越高,返回排序结果dy',并将排序后结果对应的索引号x'对应的图像追加到
77、
78、步骤4.6,判断y'是否大于y',如果y'≤y',则y'=y'+1,重复步4.3;
79、步骤4.7,根据最终输出的菜谱文本检索图像结果评价网络模型,计算网络模型输出检索结果中正确结果的中位数排名medr、计算网络模型的召回率recall@kr2i:
80、
81、其中,n表示选取的样本数,n表示选取的样本数中正确结果在前k个的个数。
82、步骤5具体为:
83、步骤5.1,定义文本特征索引y',初始化为1,定义图像特征索引x',初始化为1,定义菜谱编码模块输出文本特征vrec,初始化为空,vrec=null,定义基于文本上下文化的视觉增强模块输出图像特征vimg,初始化为空,vimg=null,定义图像检索文本的余弦距离结果d,初始化d为空,即d=null,定义图像检索文本结果i2rted,初始化i2rted为空,即i2rted=null;
84、步骤5.2,从测试集ted={ited,tted}的tted中选取第y'个文本特征送入到训练好的嵌入子模块,一级注意力机制子模块和二级注意力子模块,输出结果为追加到vrec中,判断y'是否小于等于y',如果y'≤y',则y'=y'+1,并继续重复步骤5.2;
85、步骤5.3,从测试集ted={ited,tted}中ited中选取第x'个图像特征送入到训练好的基于文本上下文化视觉增强的图像编码模块,输出结果为
86、步骤5.4,使用开源的库pytorch中的torch.mm()函数计算和的余弦相似度余弦距离
87、
88、步骤5.5,使用python内置排序函数sort()对进行排序,余弦距离越小,相似度越高,返回排序结果dx',并将排序后结果对应的索引号y'对应的文本追加到
89、
90、步骤5.6,判断x'是否大于x',如果x'≤x',则x'=x'+1,重复步骤5.3;
91、步骤5.7,根据最终输出的菜谱图像检索文本任务的结果评价网络模型model,计算模型model输出检索结果中正确结果的中位数排名medr、计算模型model的召回率recall@ki2r:
92、
93、其中,n表示选取的样本数,n表示选取的样本数中正确结果在前k个的个数。
94、本发明的有益效果是:
95、本发明基于文本上下文化的视觉增强机制,计算图像特征和文本特征之间的相关性,从而增强图像特征;本发明采用浅层的图像特征,更加注重于图像的局部特征,保留了原始图像的位置和细节信息,使得本发明的方法在跨模态食谱检索领域取得了较高的准确率。
1.基于模态交互的食谱检索方法,其特征在于,具体按照如下步骤实施:
2.根据权利要求1所述的基于模态交互的食谱检索方法,其特征在于,所述步骤1具体为:
3.根据权利要求2所述的基于模态交互的食谱检索方法,其特征在于,所述步骤2中基于细粒度模态间交互的跨模态食谱检索网络模型包括食谱分级编码模块、基于文本上下文视觉增强的图像编码模块、配对模块;
4.根据权利要求3所述的基于模态交互的食谱检索方法,其特征在于,所述步骤2.1中的分级编码模块由嵌入子模块、一级注意力机制子模块和二级注意力机制子模块组成;
5.根据权利要求4所述的基于模态交互的食谱检索方法,其特征在于,所述分级编码模块的一级注意力机制子模块由一个注意力层和一个全局平均池化层构成;
6.根据权利要求5所述的基于模态交互的食谱检索方法,其特征在于,所述步骤2.2中的基于文本上下文视觉增强的图像编码模块由图像浅层特征提取模块、文本-图像块相似度度量子模块、残差连接子模块以及图像深层特征提取子模块组成;
7.根据权利要求6所述的基于模态交互的食谱检索方法,其特征在于,所述步骤3具体为:
8.根据权利要求7所述的基于模态交互的食谱检索方法,其特征在于,所述步骤3.2中的目标损失函数为配对损失函数lpair:
9.根据权利要求8所述的基于模态交互的食谱检索方法,其特征在于,所述步骤4具体为:
10.根据权利要求9所述的基于模态交互的食谱检索方法,其特征在于,所述步骤5具体为: