Transformer模型的多头注意力机制,通俗讲解

Transformer模型最核心的内容就是多头自注意力机制

今天的这篇文章,我希望能帮助同学们:

在没有学习过Transformer的情况下,快速的学习它的核心算法:

多头自注意力机制;

Multi Head Self Attention。

很多时候,如果我们想快速的学习一个特别复杂的模型,比如Transformer。

最好一开始,就能对它的核心结构,有一个基本的“认知”。

然后再去深入的探索这个模型的算法设计和代码实现。

这样能够让我们更快的学习和理解这个复杂的深度学习模型。

学习的过程,也会变得事半功倍!

“多头自注意力机制”,就是Transformer的核心结构。

观察Transformer模型,会看到其中包含了3个用橙色标记的“多头自注意力”。

图片

这些橙色标记的“多头自注意力”的内部结构,长成下面这个样子:

图片

接下来,我们就开始深入、详细的讨论,Transformer的“多头自注意力机制”。

讲解包括以下3个部分:

1.多头自注意力机制,到底是什么?

2.输入给多头自注意力的Q、K、V,是什么?

3.如何使用Q、K、V,计算多头自注意力?

1.多头自注意力机制,到底是什么?

我们需要快速对“多头自注意力”,有一个基本认识。

简单来说,它就是一个“超级复杂”的线性层组合体。

在这个超级复杂的线性层中,包含了许多组子线性层。

实际上,我们一开始去理解它,就把它当做是一个线性层就好了。

因为它最主要的功能,与线性层(nn.linear)一样,就是:

对输入的张量x,进行特征变换。

下面这张图,就表示了“多头自注意力机制”的计算过程。

为了观察的方便,我将多头自注意力的结构图“横着”摆。

图片

观察设计图,可以看到:

多头注意力机制,包括了很多组“子线性层”;

图中标注的深浅不一的“Linear”,都是它的“子线性层”。

在这些子线性层中,包含了数量极为庞大的w和b参数。

这些w和b参数,就是用于对输入x,进行特征变换的!

它们的工作方式,与nn.Linear中的w和b的工作方式,并没有本质的不同。

当然了,这么多子线性层的计算,并且计算过程中还包含了:

Scaled Dot Product Attention(缩放点积注意力)和concat(多头结果合并)

这肯定会比普通的线性层计算更复杂。

这里这样解释,只是为了让大家对“多头自注意力”有基本的“认知”。

“多头自注意力”整体的功能和目标,仍然是“特征变换”。

当我们将输入序列的张量x,输入到“多头自注意力”后;

就会使用这些w和b参数,对x进行一系列的“复杂计算”。

最终输入x就会被“多头自注意力”,变换为一个新的张量x’。

图片

此时我们会发现x和x’,它们具有完全一样的尺寸!

也就是如果x是3×5×4大小的张量,那么x’也是3×5×4大小的张量。

这时可能有同学可能会困惑,为什么要整这么复杂呢?

将x变换到x’,而x’似乎什么都没变?(至少尺寸完全没变)。

简单来说,特征变换后,x’会包含更多的信息,更能体现x本身的语义信息。

也就张量x’里面的那堆浮点数,更能表达原始输入序列的意义。

上述的解释都是为了表达:

经过自注意力的计算后,x’就包含了x中的“上下文”信息。

同学们此时如果不太明白也也没关系,只要知道:

“多头自注意力”的核心作用,就是对“输入张量x”进行特征变换。

而这个特征变换后的x’,会带有“全局上下文信息”;

对原始的输入序列,有着更好的“理解”和“表示”,就可以了。

2.输入给多头自注意力的Q、K、V,是什么?

在Transformer的编码器中,有一个“多头自注意力”。

在Transformer的解码器中,有两个“多头自注意力”。

图片

这三个“多头自注意力”的内部结构和代码实现是完全相同的。

不同的地方是“输入数据”和对数据的“Masked掩码方式”不同。

备注:这块如果详细说,比较耗时,所以我就不深入探讨了。

同学们只要知道,这三个橙色的方块,都是“多头自注意力”;

其中的内部计算过程,是完全相同的,就可以了!

在讨论这个橙色方块内部的计算过程之前,我们要讨论它的输入数据。

我们以“编码器”中的橙色方块为例来说明。

备注:“编码器”中的多头自注意力和“解码器”中的内部结构和计算过程都一样,

但是由于“编码器”的输入更简单,所以更好讲。

图片

上图的左侧是编码器,其中有一个橙色方块代表多头自注意力。

右侧是橙色方块的内部结构。

观察“多头自注意力”内部结构,可以看到,它包括了3个输入和一个输出。

这3个输入被称为,查询(Query)、键(Key)、和值(Value)。

Q、K、V这三组张量,其中保存的数据可以是相同的,也可以是不同的。

比如输入给解码器的“Q、K、V”就是相同的。

在最初理解多头自注意力时,其实并不需要关注Q、K、V所代表的具体含义。

什么是“查询(Query)”?

什么是“键(Key)”?

什么又是“值(Value)”?

这些根本就不重要,一上来去了解这么多概念,反而把简单的事弄得复杂。

我们在最初理解时,就把Q、K、V当做是三组输入张量,就可以了!

并且对于输入给解码器的多头自注意力的Q、K、V,还都来源于一个张量x!

这个张量x,就是Transformer左下方编码器的输入Inputs。

图片

将Inputs序列输入至词嵌入层Embedding与位置编码Positional Encoding两个结构进行计算。

计算得到的x,就是Q、K、V!

也就是说,将x复制3份,变成Q、K、V,形成3个分支;

接着输入到橙色的多头注意力机制,再进行计算。

3.如何使用Q、K、V,计算多头自注意力?

多头,顾名思义,就是同时通过多个“自注意力机制”进行特征提取。

图片

示意图中的不同颜色,就代表了“多头”。

简单来说,我们可以把“多头”理解成“多组”。

分成“多组”计算自注意力的原因是:

捕捉输入序列x,中不同子空间的特征。

如果你对这个描述不理解,那就别纠结了。

因为你完全可以将“多头”当做“单头”来理解“注意力机制”的计算,

并不影响对模型整体的理解。

备注:也就是把n_head这个参数当做是1就好了!无所谓的。

 

下面我们一步一步的来描述“多头自注意力机制”的计算过程。

第1步:对Q、K、V线性变换

图片

上方的示意图,左侧表示了一个头的结构和右侧表示了其中的计算过程。

在每个“头”中,都有三组Linear线性层。

它们用于对输入张量Q、K、V进行线性变换。

所谓线性变换,就是将张量Q输入到qnet,计算出一个新的Q。

将张量K输入到knet,计算出一个新的K。

将张量V输入到vnet,计算出一个新的V。

第2步:计算缩放点积注意力

完成对Q、K、V的线性变换后,用新的Q、K、V这三组张量,计算缩放点积注意力。

也就是基于下方的公式,计算Scaled Dot Product Attention:

图片

备注:对于这个公式的详细解释,这里就不说了;

因为想解释清楚,得用超级多的篇幅来说。

同学们可以直接将它就当做是一个张量的计算公式就好了。

简单来说,这个Attention计算的作用就是,将Q、K、V中信息进行“融合”。

融合后的输出,就有了全局信息。

第3步:融合多头计算的结果

每个头都会基于一组Q、K、V ,计算缩放点积注意力结果;

我们需要将这些不同头的计算结果拼接(concat)起来:

图片

然后通过一个线性层进行融合,得到最终的输出张量。

这个输出张量的尺寸会和Q、K、V中的V,是一样的。

由于我们讨论的是解码器中的多头自注意力;

因此最终的输出结果,和最开始输入Q、K、V,都是一样的。

举例说明

下面我们基于一个具体例子,说明多头注意力的计算过程。

观察下方的示意图:

图片

首先,输入的数据是,经过位置编码后的“You are welcome PAD”;

它对应最下方的黄色词向量矩阵。

第1步:对Q、K、V线性变换

4×6的输入张量,首先被复制为Query、Key和Value。

图片

然后分别和3个Linear线性层做线性变换:

生成三组新的Q、K和V。

这里使用浅橙色、红色和深橙色来表示,尺寸都是4×6。

这个过程就是最基础的线性层计算。

第2步:计算缩放点积注意力

接着将q、k和v,三个计算结果,先split为多头的形式。

备注:split就是一种张量尺寸变换,可以忽略。

然后带入到Attention的计算公式中:

图片

计算注意力机制的输出。

经过Attention的计算后,就会融合输入序列中的全局信息。

第3步:融合多头计算的结果

在示意图的最上方,Attention Score的输出是多组结果。

图片

我们会使用一个线性层,将这些结果进行“Merge”,输出最终的一个结果。

我们会看到:

最下方的输入张量尺寸是4×6;

最上方的输出张量尺寸也是4×6;

备注:实际上中间的Q、K、V不一定是4×6大小的,示意图画的是4×6。

因此总结来说,整个多头注意力机制的作用,就是对输入张量进行特征变换!

变换后的输出张量,就带有了全局信息!

到这里,这篇“多头自注意力机制”的讲解,就说完了。

来源:小黑黑讲AI

THE END