Transformer模型的多头注意力机制,通俗讲解
Transformer模型最核心的内容就是多头自注意力机制。
今天的这篇文章,我希望能帮助同学们:
在没有学习过Transformer的情况下,快速的学习它的核心算法:
多头自注意力机制;
Multi Head Self Attention。
很多时候,如果我们想快速的学习一个特别复杂的模型,比如Transformer。
最好一开始,就能对它的核心结构,有一个基本的“认知”。
然后再去深入的探索这个模型的算法设计和代码实现。
这样能够让我们更快的学习和理解这个复杂的深度学习模型。
学习的过程,也会变得事半功倍!
“多头自注意力机制”,就是Transformer的核心结构。
观察Transformer模型,会看到其中包含了3个用橙色标记的“多头自注意力”。
这些橙色标记的“多头自注意力”的内部结构,长成下面这个样子:
接下来,我们就开始深入、详细的讨论,Transformer的“多头自注意力机制”。
讲解包括以下3个部分:
1.多头自注意力机制,到底是什么?
2.输入给多头自注意力的Q、K、V,是什么?
3.如何使用Q、K、V,计算多头自注意力?
1.多头自注意力机制,到底是什么?
我们需要快速对“多头自注意力”,有一个基本认识。
简单来说,它就是一个“超级复杂”的线性层组合体。
在这个超级复杂的线性层中,包含了许多组子线性层。
实际上,我们一开始去理解它,就把它当做是一个线性层就好了。
因为它最主要的功能,与线性层(nn.linear)一样,就是:
对输入的张量x,进行特征变换。
下面这张图,就表示了“多头自注意力机制”的计算过程。
为了观察的方便,我将多头自注意力的结构图“横着”摆。
观察设计图,可以看到:
多头注意力机制,包括了很多组“子线性层”;
图中标注的深浅不一的“Linear”,都是它的“子线性层”。
在这些子线性层中,包含了数量极为庞大的w和b参数。
这些w和b参数,就是用于对输入x,进行特征变换的!
它们的工作方式,与nn.Linear中的w和b的工作方式,并没有本质的不同。
当然了,这么多子线性层的计算,并且计算过程中还包含了:
Scaled Dot Product Attention(缩放点积注意力)和concat(多头结果合并)
这肯定会比普通的线性层计算更复杂。
这里这样解释,只是为了让大家对“多头自注意力”有基本的“认知”。
“多头自注意力”整体的功能和目标,仍然是“特征变换”。
当我们将输入序列的张量x,输入到“多头自注意力”后;
就会使用这些w和b参数,对x进行一系列的“复杂计算”。
最终输入x就会被“多头自注意力”,变换为一个新的张量x’。
此时我们会发现x和x’,它们具有完全一样的尺寸!
也就是如果x是3×5×4大小的张量,那么x’也是3×5×4大小的张量。
这时可能有同学可能会困惑,为什么要整这么复杂呢?
将x变换到x’,而x’似乎什么都没变?(至少尺寸完全没变)。
简单来说,特征变换后,x’会包含更多的信息,更能体现x本身的语义信息。
也就张量x’里面的那堆浮点数,更能表达原始输入序列的意义。
上述的解释都是为了表达:
经过自注意力的计算后,x’就包含了x中的“上下文”信息。
同学们此时如果不太明白也也没关系,只要知道:
“多头自注意力”的核心作用,就是对“输入张量x”进行特征变换。
而这个特征变换后的x’,会带有“全局上下文信息”;
对原始的输入序列,有着更好的“理解”和“表示”,就可以了。
2.输入给多头自注意力的Q、K、V,是什么?
在Transformer的编码器中,有一个“多头自注意力”。
在Transformer的解码器中,有两个“多头自注意力”。
这三个“多头自注意力”的内部结构和代码实现是完全相同的。
不同的地方是“输入数据”和对数据的“Masked掩码方式”不同。
备注:这块如果详细说,比较耗时,所以我就不深入探讨了。
同学们只要知道,这三个橙色的方块,都是“多头自注意力”;
其中的内部计算过程,是完全相同的,就可以了!
在讨论这个橙色方块内部的计算过程之前,我们要讨论它的输入数据。
我们以“编码器”中的橙色方块为例来说明。
备注:“编码器”中的多头自注意力和“解码器”中的内部结构和计算过程都一样,
但是由于“编码器”的输入更简单,所以更好讲。
上图的左侧是编码器,其中有一个橙色方块代表多头自注意力。
右侧是橙色方块的内部结构。
观察“多头自注意力”内部结构,可以看到,它包括了3个输入和一个输出。
这3个输入被称为,查询(Query)、键(Key)、和值(Value)。
Q、K、V这三组张量,其中保存的数据可以是相同的,也可以是不同的。
比如输入给解码器的“Q、K、V”就是相同的。
在最初理解多头自注意力时,其实并不需要关注Q、K、V所代表的具体含义。
什么是“查询(Query)”?
什么是“键(Key)”?
什么又是“值(Value)”?
这些根本就不重要,一上来去了解这么多概念,反而把简单的事弄得复杂。
我们在最初理解时,就把Q、K、V当做是三组输入张量,就可以了!
并且对于输入给解码器的多头自注意力的Q、K、V,还都来源于一个张量x!
这个张量x,就是Transformer左下方编码器的输入Inputs。
将Inputs序列输入至词嵌入层Embedding与位置编码Positional Encoding两个结构进行计算。
计算得到的x,就是Q、K、V!
也就是说,将x复制3份,变成Q、K、V,形成3个分支;
接着输入到橙色的多头注意力机制,再进行计算。
3.如何使用Q、K、V,计算多头自注意力?
多头,顾名思义,就是同时通过多个“自注意力机制”进行特征提取。
示意图中的不同颜色,就代表了“多头”。
简单来说,我们可以把“多头”理解成“多组”。
分成“多组”计算自注意力的原因是:
捕捉输入序列x,中不同子空间的特征。
如果你对这个描述不理解,那就别纠结了。
因为你完全可以将“多头”当做“单头”来理解“注意力机制”的计算,
并不影响对模型整体的理解。
备注:也就是把n_head这个参数当做是1就好了!无所谓的。
下面我们一步一步的来描述“多头自注意力机制”的计算过程。
第1步:对Q、K、V线性变换
上方的示意图,左侧表示了一个头的结构和右侧表示了其中的计算过程。
在每个“头”中,都有三组Linear线性层。
它们用于对输入张量Q、K、V进行线性变换。
所谓线性变换,就是将张量Q输入到qnet,计算出一个新的Q。
将张量K输入到knet,计算出一个新的K。
将张量V输入到vnet,计算出一个新的V。
第2步:计算缩放点积注意力
完成对Q、K、V的线性变换后,用新的Q、K、V这三组张量,计算缩放点积注意力。
也就是基于下方的公式,计算Scaled Dot Product Attention:
备注:对于这个公式的详细解释,这里就不说了;
因为想解释清楚,得用超级多的篇幅来说。
同学们可以直接将它就当做是一个张量的计算公式就好了。
简单来说,这个Attention计算的作用就是,将Q、K、V中信息进行“融合”。
融合后的输出,就有了全局信息。
第3步:融合多头计算的结果
每个头都会基于一组Q、K、V ,计算缩放点积注意力结果;
我们需要将这些不同头的计算结果拼接(concat)起来:
然后通过一个线性层进行融合,得到最终的输出张量。
这个输出张量的尺寸会和Q、K、V中的V,是一样的。
由于我们讨论的是解码器中的多头自注意力;
因此最终的输出结果,和最开始输入Q、K、V,都是一样的。
举例说明
下面我们基于一个具体例子,说明多头注意力的计算过程。
观察下方的示意图:
首先,输入的数据是,经过位置编码后的“You are welcome PAD”;
它对应最下方的黄色词向量矩阵。
第1步:对Q、K、V线性变换
4×6的输入张量,首先被复制为Query、Key和Value。
然后分别和3个Linear线性层做线性变换:
生成三组新的Q、K和V。
这里使用浅橙色、红色和深橙色来表示,尺寸都是4×6。
这个过程就是最基础的线性层计算。
第2步:计算缩放点积注意力
接着将q、k和v,三个计算结果,先split为多头的形式。
备注:split就是一种张量尺寸变换,可以忽略。
然后带入到Attention的计算公式中:
计算注意力机制的输出。
经过Attention的计算后,就会融合输入序列中的全局信息。
第3步:融合多头计算的结果
在示意图的最上方,Attention Score的输出是多组结果。
我们会使用一个线性层,将这些结果进行“Merge”,输出最终的一个结果。
我们会看到:
最下方的输入张量尺寸是4×6;
最上方的输出张量尺寸也是4×6;
备注:实际上中间的Q、K、V不一定是4×6大小的,示意图画的是4×6。
因此总结来说,整个多头注意力机制的作用,就是对输入张量进行特征变换!
变换后的输出张量,就带有了全局信息!
到这里,这篇“多头自注意力机制”的讲解,就说完了。
来源:小黑黑讲AI