Pytorch实战Transformer算法之注意力机制

注意力机制是 Transformer 模型中最核心的部分。概念上,它是通过一个查询(query)与一组键值对(key-values)进行运算,最终生成一个输出。通过矩阵运算,可以同时让多个查询与一组键值对进行并行计算,从而最大化并行效率。

无论是查询(query)、键(keys)、值(values)还是输出,它们都是向量(vectors),查询和键值对是通过训练得到的权重矩阵计算得到的。

在 Transformer 中,注意力机制主要分为两种:

  • • 自注意力机制:这种机制在编码器或解码器中单独使用。在这种情况下,查询、键和值都来自同一个序列。
  • • 交叉注意力机制:这种机制在编码器-解码器模型中使用。例如,解码器的查询会与编码器的键和值进行注意力计算。
def scaled_dot_product_attention(q, k, v, mask=None):
    
    matmul_qk = torch.matmul(q, k.transpose(-2, -1))  # (..., seq_len_q, seq_len_k)

    dk = torch.tensor(k.size(-1), dtype=torch.float32)  
    # (..., seq_len_q, seq_len_k)
    scaled_attention_logits = matmul_qk / torch.sqrt(dk)  

    if mask isnotNone:
        scaled_attention_logits += (mask * -1e9)

    #  (..., seq_len_q, seq_len_k)
    attention_weights = F.softmax(scaled_attention_logits, dim=-1)  

    # (..., seq_len_q, seq_len_k) @ (..., seq_len_v, depth_v) -> (..., seq_len_q, depth_v)
    output = torch.matmul(attention_weights, v)

    return output, attention_weights

先看参数,注意维度:

  • • q: query shape == (..., seq_len_q, depth)
  • • k: key shape == (..., seq_len_k, depth)
  • • v: value shape == (..., seq_len_v, depth_v)
  • • mask:(..., seq_len_q, seq_len_k)

无论是自注意力机制还是交叉注意力机制,seq_len_k 和 seq_len_v 的序列长度必须相同,因为它们本质上是从同一个序列计算出来的。

另外 query 的 depth 和 key 的 depth 必须相同,因为后续的矩阵运算需要。depth 和 value 的 depth_v 可以不一致,但为了方便运算,通常 q、k、v 的最后一个维度默认与 d_model 相同。

而 mask 的维度取决于是自注意力机制还是交叉注意力机制,如果是自注意力机制 seq_len_q = seq_len_k,如果是交叉注意力机制,两者可以不一样。

scaled_attention_logits 是注意力机制中的关键步骤,用于计算输入序列中不同位置之间的相关性或重要性。具体来说,它表示查询向量和键向量之间的相似度。它为什么要除以 torch.sqrt(dk)?可以有效地稳定梯度、提高数值稳定性,并使得注意力分数的分布更加均匀,从而提高模型的表现。

接下去处理Mask部分,让 mask = 1 的 logits 变得非常小,经过 Softmax 归一化后,这些位置的注意力权重将接近于 0。

attention_weights 就是 softmax 处理后得到的注意力权重,如果是自注意力机制,则 seq_len_q 其实就是 seq_len_k,表示每个单词之间的权重,如果是交叉注意力机制,则表示 q 和 k 单词之间的权重。

而最重要的输出是权重和v之间的加权求和,也是一个矩阵操作,也间接说明 depth_v 可以不同于 depth,最后的维度 (..., seq_len_q, depth_v),其中 seq_len_q 是查询序列的长度,seq_len_k 是键序列的长度。

最终的输出是注意力权重与值向量,也是一个矩阵操作,输出的维度为 (..., seq_len_q, depth_v),这表明输出序列的长度与查询序列的长度相同,而每个位置的输出向量的维度为 depth_v,这一步骤间接说明 depth_v 可以不同于 depth,只要注意力权重的维度与值向量的维度匹配即可。

1:注意力机制和 padding mask

不管哪一种注意力机制,都会用到 padding mask,在这个例子中为了理解的更透彻,使用交叉注意力介绍。

先看 q、k、v:

q = emb_inp # [2, 5, 200]
k = emb_tar # [2, 6, 200]
v = torch.where(torch.rand(k.shape) > 0.5, torch.ones_like(k), torch.zeros_like(k)).float()

v = v.repeat(1, 1, 2)  #[2, 6, 400]

v.repeat 是为了演示 depth 和 depth_ 可以不一样。

接下去计算 padding mask,cross attention 时需要对 tar_mask 做 mask,它是目标序列的 mask:

tar_mask = create_padding_mask(tar)
mask = tar_mask.squeeze(1) # [2, 1, 6]
print(mask.shape)  # torch.Size([2, 1,6])

接下去计算交叉注意力权重,仔细体会维度:

_, _attention_weights = scaled_dot_product_attention(q, k, v, mask)
print("attention:", _attention_weights.shape)  
attention: torch.Size([2, 5, 6])

2:自注意力机制和 look ahead mask

temp_q = temp_k = emb_inp
look_ahead_mask = create_look_ahead_mask(temp_q.shape[1])
temp_v = torch.where(torch.rand(temp_k.shape) > 0.5, torch.ones_like(temp_k), torch.zeros_like(temp_k)).float()

# 右上角的三角形遮罩都是 0,不会看到未来的信息
_, _attention_weights = scaled_dot_product_attention(temp_q, temp_k, temp_v, look_ahead_mask)

print("attention_weights:", _attention_weights)

# 每个一序列的第一个词只会关注自己
print(_attention_weights[:, 0, :]) 

这样理解是不是就比较清楚了!

来源:aigcrepo

THE END