Pytorch实战Transformer算法之注意力机制
注意力机制是 Transformer 模型中最核心的部分。概念上,它是通过一个查询(query)与一组键值对(key-values)进行运算,最终生成一个输出。通过矩阵运算,可以同时让多个查询与一组键值对进行并行计算,从而最大化并行效率。
无论是查询(query)、键(keys)、值(values)还是输出,它们都是向量(vectors),查询和键值对是通过训练得到的权重矩阵计算得到的。
在 Transformer 中,注意力机制主要分为两种:
-
• 自注意力机制:这种机制在编码器或解码器中单独使用。在这种情况下,查询、键和值都来自同一个序列。 -
• 交叉注意力机制:这种机制在编码器-解码器模型中使用。例如,解码器的查询会与编码器的键和值进行注意力计算。
def scaled_dot_product_attention(q, k, v, mask=None):
matmul_qk = torch.matmul(q, k.transpose(-2, -1)) # (..., seq_len_q, seq_len_k)
dk = torch.tensor(k.size(-1), dtype=torch.float32)
# (..., seq_len_q, seq_len_k)
scaled_attention_logits = matmul_qk / torch.sqrt(dk)
if mask isnotNone:
scaled_attention_logits += (mask * -1e9)
# (..., seq_len_q, seq_len_k)
attention_weights = F.softmax(scaled_attention_logits, dim=-1)
# (..., seq_len_q, seq_len_k) @ (..., seq_len_v, depth_v) -> (..., seq_len_q, depth_v)
output = torch.matmul(attention_weights, v)
return output, attention_weights
先看参数,注意维度:
-
• q: query shape == (..., seq_len_q, depth) -
• k: key shape == (..., seq_len_k, depth) -
• v: value shape == (..., seq_len_v, depth_v) -
• mask:(..., seq_len_q, seq_len_k)
无论是自注意力机制还是交叉注意力机制,seq_len_k 和 seq_len_v 的序列长度必须相同,因为它们本质上是从同一个序列计算出来的。
另外 query 的 depth 和 key 的 depth 必须相同,因为后续的矩阵运算需要。depth 和 value 的 depth_v 可以不一致,但为了方便运算,通常 q、k、v 的最后一个维度默认与 d_model 相同。
而 mask 的维度取决于是自注意力机制还是交叉注意力机制,如果是自注意力机制 seq_len_q = seq_len_k,如果是交叉注意力机制,两者可以不一样。
scaled_attention_logits 是注意力机制中的关键步骤,用于计算输入序列中不同位置之间的相关性或重要性。具体来说,它表示查询向量和键向量之间的相似度。它为什么要除以 torch.sqrt(dk)
?可以有效地稳定梯度、提高数值稳定性,并使得注意力分数的分布更加均匀,从而提高模型的表现。
接下去处理Mask部分,让 mask = 1 的 logits 变得非常小,经过 Softmax 归一化后,这些位置的注意力权重将接近于 0。
attention_weights 就是 softmax 处理后得到的注意力权重,如果是自注意力机制,则 seq_len_q 其实就是 seq_len_k,表示每个单词之间的权重,如果是交叉注意力机制,则表示 q 和 k 单词之间的权重。
而最重要的输出是权重和v之间的加权求和,也是一个矩阵操作,也间接说明 depth_v 可以不同于 depth,最后的维度 (..., seq_len_q, depth_v),其中 seq_len_q 是查询序列的长度,seq_len_k 是键序列的长度。
最终的输出是注意力权重与值向量,也是一个矩阵操作,输出的维度为 (..., seq_len_q, depth_v),这表明输出序列的长度与查询序列的长度相同,而每个位置的输出向量的维度为 depth_v,这一步骤间接说明 depth_v 可以不同于 depth,只要注意力权重的维度与值向量的维度匹配即可。
1:注意力机制和 padding mask
不管哪一种注意力机制,都会用到 padding mask,在这个例子中为了理解的更透彻,使用交叉注意力介绍。
先看 q、k、v:
q = emb_inp # [2, 5, 200]
k = emb_tar # [2, 6, 200]
v = torch.where(torch.rand(k.shape) > 0.5, torch.ones_like(k), torch.zeros_like(k)).float()
v = v.repeat(1, 1, 2) #[2, 6, 400]
v.repeat
是为了演示 depth 和 depth_ 可以不一样。
接下去计算 padding mask,cross attention 时需要对 tar_mask 做 mask,它是目标序列的 mask:
tar_mask = create_padding_mask(tar)
mask = tar_mask.squeeze(1) # [2, 1, 6]
print(mask.shape) # torch.Size([2, 1,6])
接下去计算交叉注意力权重,仔细体会维度:
_, _attention_weights = scaled_dot_product_attention(q, k, v, mask)
print("attention:", _attention_weights.shape)
attention: torch.Size([2, 5, 6])
2:自注意力机制和 look ahead mask
temp_q = temp_k = emb_inp
look_ahead_mask = create_look_ahead_mask(temp_q.shape[1])
temp_v = torch.where(torch.rand(temp_k.shape) > 0.5, torch.ones_like(temp_k), torch.zeros_like(temp_k)).float()
# 右上角的三角形遮罩都是 0,不会看到未来的信息
_, _attention_weights = scaled_dot_product_attention(temp_q, temp_k, temp_v, look_ahead_mask)
print("attention_weights:", _attention_weights)
# 每个一序列的第一个词只会关注自己
print(_attention_weights[:, 0, :])
这样理解是不是就比较清楚了!
来源:aigcrepo