Pytorch实战Transformer算法之Embedding层和Positional编码

使用Pytorch实现一个Transformer模型,专注于了解其中的算法和原理,今天描述Embedding层和Positional encoding。

图片

1:Embedding层

在 Transformer 模型中,Embedding 层(嵌入层)的主要作用是将输入的离散数据(如单词或字符)转换为连续的向量表示。

这些向量表示不仅能捕捉 Token 之间的语义关系,还能有效地压缩维度,从而减少计算复杂度和内存占用。特别是在面对庞大的词汇表时,Embedding 层能够将高维的稀疏表示转换为低维的稠密向量,使得模型更易于处理和训练。

直接上代码:

import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size = 2008 # 词典大小
emsize = 200 #d_model
inp = torch.LongTensor([[1,2,4,5,0],[0,6,7,8,9]]) 
tar = torch.LongTensor([[2,4,5,0,0,6],[6,7,8,9,10,0]])

以机器翻译为例,inp 相当于源序列,tar 相当于目标序列。

Pytorch 中可以直接使用 nn.Embedding:

embedding_layer = nn.Embedding(vocab_size, emsize)
print(embedding_layer) 

emb_inp = embedding_layer(inp)
emb_tar = embedding_layer(tar)

print("维度",emb_inp.shape)
print(emb_inp)

Embedding(2008, 200)
维度 torch.Size([2, 5, 200])

通过 Embedding 处理后,序列的维度就从 [2, 5, 6] 变更为 [2, 5, 200]

特别要说到的是 d_model,即 dimension_number,它等于Transformer 的hidden_size的值,也是Transfomger的Embedding size,同样是Wond vectors size 值,也是 WQ、WK、WV三个大矩阵中的一个 size 值,后面会继续描述。

2:Positional encoding

在 Transformer 模型中,自注意力机制本身是对位置无感知的,因此,在进行Embedding处理之后,需要直接添加位置信息,它通过将位置信息加到嵌入向量上来实现,而不会改变嵌入向量的维度。

下面是一个绝对编码位置的解决方案:

def get_angles(pos, i, d_model):

    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates

defpositional_encoding(position, d_model):

    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                             np.arange(d_model)[np.newaxis, :],
                             d_model)

    sines = np.sin(angle_rads[:, 0::2])
    cosines = np.cos(angle_rads[:, 1::2])

    pos_encoding = np.concatenate([sines, cosines], axis=-1)
    pos_encoding = pos_encoding[np.newaxis, ...]
    return torch.tensor(pos_encoding, dtype=torch.float32)

接下去看看如何对源序列处理位置编码:

max_len_inp = emb_inp.shape[1]

pos_encoding = positional_encoding(max_len_inp, emsize)
print("pos_encoding:", pos_encoding)
print("pos_encoding.shape:", pos_encoding.shape) # (1, 5, 200)

pos_encoding.shape: torch.Size([1, 5, 200])

可以看到经过处理后,整个输入的维度没有变化。

接下去可视化位置编码:

import matplotlib.pyplot as plt
plt.pcolormesh(pos_encoding[0], cmap='RdBu')
plt.xlabel('d_model')
plt.xlim((0, 512))
plt.ylabel('Position')
plt.colorbar()
plt.show()

x 轴代表与词嵌入向量相同的维度 d_model,y 轴则代表序列中的每个位置,可以看到同一个序列中不同位置的词,以及同一个词在不同 d_model 维度,颜色是接近的。

来源:aigcrepo

THE END