Transformer 模型效率:通过改进注意力机制降低机器学习成本
Transformer 架构由 Vaswani 等人在 2017 年发表的里程碑式论文《Attention Is All You Need》中首次提出,如今已被广泛认为是过去十年间最具开创性的科学突破之一。注意力机制是 Transformer 的核心创新,它为人工智能模型提供了一种全新的方法,使模型能够根据具体任务的需求,灵活地聚焦输入序列的不同部分,从而更深入地理解复杂的语言和结构。
最初在自然语言处理领域崭露头角,Transformer 架构的卓越性能很快吸引了跨学科的关注,其应用迅速扩展到语音识别、计算机视觉、强化学习、生物信息学等多个前沿领域,展现出令人瞩目的学科交叉潜力。然而与其革命性突破同时,注意力层的高计算复杂度也逐渐成为制约其进一步发展的瓶颈。随着模型规模的持续增长,注意力层的计算资源需求呈指数级上升,训练和部署成本也随之攀高。
寻找降低注意力层计算开销的有效策略,在提高基于 Transformer 的人工智能模型效率和可扩展性方面至关重要。本文将深入探讨在 PyTorch 生态系统中优化注意力层的多种技术路径,并将重点聚焦于那些在降低计算成本的同时能够保持注意力层精度的创新方法。这些方法包括 PyTorch SDPA、FlashAttention、TransformerEngine Attention、FlexAttention 以及 xFormer attention。
本文将排除通过近似注意力计算来减少计算成本的其他方法(如 DeepSpeed 的 Sparse Attention、Longformer、Linformer 等),同时也不会详细讨论通用的优化技术,尽管这些技术对注意力性能亦有积极影响,但它们并不专门针对注意力计算本身。
值得强调的是,注意力优化是一个极其活跃且快速发展的研究领域,新的方法和突破不断涌现。本文的目标并非提供一个详尽无遗的技术指南,而是希望通过梳理当前主流的优化路径,为读者提供一个清晰的技术概述,并为后续的深入探索和实践铺平道路。
实验模型
为了便于讨论,我们使用流行的 Python 包 timm(版本 0.9.7)构建一个基于 Vision Transformer(ViT)的分类模型,以演示各种注意力内核对性能的影响。
首先,定义一个简化的 Transformer 块,允许通过将注意力函数传递给其构造函数来参数化。由于不同的注意力实现可能假设特定的输入张量格式,还包括一个选项来控制格式,以确保与所选的注意力内核兼容。
# 通用导入
import os, time, functools
# torch 导入
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
# timm 导入
import timm
from timm.models.vision_transformer import VisionTransformer
from timm.layers import Mlp
IMG_SIZE = 224
BATCH_SIZE = 128
# 定义 ViT 设置
NUM_HEADS = 16
HEAD_DIM = 64
DEPTH = 24
PATCH_SIZE = 16
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 196
class MyAttentionBlock(nn.Module):
def __init__(
self,
attn_fn,
format = None,
dim: int = 768,
num_heads: int = 12,
**kwargs
) -> None:
super().__init__()
self.attn_fn = attn_fn
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.proj = nn.Linear(dim, dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=dim * 4,
)
permute = (2, 0, 3, 1, 4)
self.permute_attn = functools.partial(torch.transpose,dim0=1,dim1=2)
if format == 'bshd':
permute = (2, 0, 1, 3, 4)
self.permute_attn = nn.Identity()
self.permute_qkv = functools.partial(torch.permute,dims=permute)
def forward(self, x_in: torch.Tensor) -> torch.Tensor:
x = self.norm1(x_in)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
# 根据指定格式置换张量
qkv = self.permute_qkv(qkv)
q, k, v = qkv.unbind(0)
# 使用用户指定的注意力函数
x = self.attn_fn(q, k, v)
# 根据指定格式置换输出
x = self.permute_attn(x).reshape(B, N, C)
x = self.proj(x)
x = x + x_in
x = x + self.mlp(self.norm2(x))
return x
我们定义一个随机生成的数据集,我们将在训练期间用它来训练
# 使用随机数据
class FakeDataset(Dataset):
def __len__(self):
return 1000000
def __getitem__(self, index):
rand_image = torch.randn([3, IMG_SIZE, IMG_SIZE],
dtype=torch.float32)
label = torch.tensor(data=index % 1000, dtype=torch.int64)
return rand_image, label
接下来定义 ViT 训练函数。虽然我们的示例侧重于演示训练工作负载,但必须强调的是,在模型推理期间优化注意力层同样重要,甚至更为重要。
定义的训练函数接受自定义的 Transformer 块和一个控制使用 torch.compile 的标志。
def train_fn(block_fn, compile):
torch.random.manual_seed(0)
device = torch.device("cuda:0")
torch.set_float32_matmul_precision("high")
# 创建数据集和数据加载器
train_set = FakeDataset()
train_loader = DataLoader(
train_set, batch_size=BATCH_SIZE,
num_workers=12, pin_memory=True, drop_last=True)
model = VisionTransformer(
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=NUM_HEADS*HEAD_DIM,
depth=DEPTH,
num_heads=NUM_HEADS,
class_token=False,
global_pool="avg",
block_fn=block_fn
).to(device)
if compile:
model = torch.compile(model)
# 定义损失和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters())
model.train()
t0 = time.perf_counter()
summ = 0
count = 0
for step, data in enumerate(train_loader):
# 将数据复制到 GPU
inputs = data[0].to(device=device, non_blocking=True)
label = data[1].to(device=device, non_blocking=True)
with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
outputs = model(inputs)
loss = criterion(outputs, label)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
# 捕获步骤时间
batch_time = time.perf_counter() - t0
if step > 20: # 跳过前几步
summ += batch_time
count += 1
t0 = time.perf_counter()
if step > 100:
break
print(f'average step time: {summ / count}')
# 定义编译和未编译的训练函数变体
train = functools.partial(train_fn, compile=False)
train_compile = functools.partial(train_fn, compile=True)
下面的代码块中定义了一个 PyTorch 原生的注意力函数,并使用它来训练我们的 ViT 模型:
def attn_fn(q, k, v):
scale = HEAD_DIM ** -0.5
q = q * scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
return x
block_fn = functools.partial(MyAttentionBlock, attn_fn=attn_fn)
print('Default Attention')
train(block_fn)
print('Compiled Default Attention')
train_compile(block_fn)
在 NVIDIA H100 上运行了这个模型,使用 CUDA 12.4 和 PyTorch 2.5.1。未编译的变体平均步时间为 370 毫秒(ms),而编译的变体改进为 242 ms。我们将使用这些结果作为基准,比较其他执行注意力计算的解决方案。
PyTorch SDPA
在 PyTorch 中提升注意力层性能的最简单方法之一是使用 scaled_dot_product_attention (SDPA) 函数。目前处于测试阶段的 PyTorch SDPA 整合了多个内核级优化,并根据输入的属性动态选择最有效的一个。支持的后端(截至目前)包括:FlashAttention-2、Memory-Efficient Attention、基于 C++ 的 Math Attention 和 CuDNN。这些后端将高级操作融合在一起,同时采用 GPU 级优化以提高计算效率和内存利用率。
SDPA 不断发展,定期引入新的和改进的后端实现。例如PyTorch 2.5 引入了一个更新的 CuDNN 后端,具有专门为 NVIDIA Hopper 架构 GPU 训练量身定制的 SDPA 原语。
在下面的代码块中,遍历支持的后端列表,并评估每个后端的训练运行时性能。我们使用一个辅助函数 set_sdpa_backend 来编程 SDPA 后端:
from torch.nn.functional import scaled_dot_product_attention as sdpa
def set_sdpa_backend(backend):
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)
torch.backends.cuda.enable_cudnn_sdp(False)
if backend in ['flash_sdp','all']:
torch.backends.cuda.enable_flash_sdp(True)
if backend in ['mem_efficient_sdp','all']:
torch.backends.cuda.enable_mem_efficient_sdp(True)
if backend in ['math_sdp','all']:
torch.backends.cuda.enable_math_sdp(True)
if backend in ['cudnn_sdp','all']:
torch.backends.cuda.enable_cudnn_sdp(True)
for backend in ['flash_sdp', 'mem_efficient_sdp',
'math_sdp', 'cudnn_sdp']:
set_sdpa_backend(backend)
block_fn = functools.partial(MyAttentionBlock,
attn_fn=sdpa)
print(f'PyTorch SDPA - {backend}')
train(block_fn)
print(f'Compiled PyTorch SDPA - {backend}')
train_compile(block_fn)
结果如下
SDPA 后端的选择对性能有明显影响,而模型编译执行的优化似乎掩盖了注意力内核之间的差异。需要注意的是不要从这些结果中得出任何结论,因为不同注意力函数对性能的影响可能因具体模型和用例而异。
第三方注意力内核
虽然 PyTorch SDPA 是一个很好的起点,但使用第三方注意力内核可以进一步加速机器学习工作负载。这些替代方案通常具有更大的灵活性,提供更广泛的注意力配置选项。有些还包括针对特定硬件加速器或更新 GPU 架构的优化。
我们将探讨一些可用的第三方注意力内核,并评估它们对运行时性能的潜在影响。
FlashAttention-3
虽然 Pytorch SDPA 支持 FlashAttention 后端,但更高级的 FlashAttention 实现可以在 flash-attn 库中找到。这里将探讨 FlashAttention-3 测试版,它的速度比 FlashAttention-2 快多达 2 倍。鉴于其开发的早期阶段,FlashAttention-3 只能直接从 GitHub 仓库安装,并且其使用仅限于某些头部维度。并且它尚不支持模型编译。在下面的代码块中,我们将 Transformer 块配置为使用 flash-attn-3,同时将注意力输入格式设置为“bshd”(批次、序列、头、深度),以满足库的期望。
# flash attention 3
from flash_attn_interface import flash_attn_func as fa3
attn_fn = lambda q,k,v: fa3(q,k,v)[0]
block_fn = functools.partial(MyAttentionBlock,
attn_fn=attn_fn,
format='bshd')
print(f'Flash Attention 3')
train(block_fn)
结果步时间为 240 ms,比 SDPA flash-attn 快 5%。
Transformer Engine
Transformer Engine (TE) 是一个专门设计用于加速 NVIDIA GPU 上 Transformer 模型的库。TE 定期更新,利用最新的 NVIDIA 硬件和软件产品的功能进行优化,使用户能够在这些优化集成到通用框架(如 PyTorch)之前很长时间内访问专用内核。
在下面的代码块中使用 TE 版本 1.11.0 的 DotProductAttention。与 PyTorch SDPA 类似,TE 支持通过环境变量控制的多个后端。这里我们演示使用 NVTE_FUSED_ATTN 后端。
def set_te_backend(backend):
# 必须在第一次使用 transformer_engine.pytorch.attention 之前应用
os.environ["NVTE_FLASH_ATTN"] = '0'
os.environ["NVTE_FUSED_ATTN"] = '0'
os.environ["NVTE_UNFUSED_ATTN"] = '0'
if backend == 'flash':
os.environ["NVTE_FLASH_ATTN"] = '1'
if backend == 'fused':
os.environ["NVTE_FUSED_ATTN"] = '1'
if backend == 'unfused':
os.environ["NVTE_UNFUSED_ATTN"] = '1'
from transformer_engine.pytorch.attention import DotProductAttention
set_te_backend('fused')
attn_fn = DotProductAttention(NUM_HEADS, HEAD_DIM, NUM_HEADS,
qkv_format='bshd',
# 禁用掩码(默认是因果掩码)
attn_mask_type='no_mask')
block_fn = functools.partial(MyAttentionBlock,
attn_fn=attn_fn,
format='bshd')
print(f'Transformer Engine Attention')
train(block_fn)
print(f'Compiled Transformer Engine Attention')
train_compile(block_fn)
TE 注意力在模型变体中的平均步时间分别为 243 ms 和 204 ms。
XFormer Attention
PyTorch SDPA 的内存高效后端的底层是由 xFormers 库提供的注意力内核。我们可以直接使用源代码,在下面的代码块中,使用 xFormers 版本 0.0.28 的 memory_efficient_attention 操作符。
# xformer memory efficient attention
from xformers.ops import memory_efficient_attention as mea
block_fn = functools.partial(MyAttentionBlock,
attn_fn=mea,
format='bshd')
print(f'xFormer Attention ')
train(block_fn)
print(f'Compiled xFormer Attention ')
train_compile(block_fn)
平均时间为 246 ms,比 SDPA 内存高效内核快 10.5%
结果
eager 模型的赢家是 flash-attn-3,平均步时间比基线模型快 54%。这相当于训练时间减少了 54%。在编译模式下,优化内核的性能大致相同,最快的实现达到了 202 ms,比基线提高了 20%。
为了进行更广泛的评估,我们增加注意力序列长度到 3136 个标记重新运行实验。
IMG_SIZE = 224
BATCH_SIZE = 8
# 定义 ViT 设置
NUM_HEADS = 12
HEAD_DIM = 64
DEPTH = 6
PATCH_SIZE = 4
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 3136
结果如下:
当序列长度更大时,注意力内核的性能影响更加明显。flash-attn-3 依旧领先——这次性能提高了约 5 倍,超过了 PyTorch 原生函数。对于编译模型 TE 内核脱颖而出,总体最佳步时间为 53 ms。
使用 FlexAttention 自定义注意力
在探讨标准注意力函数的优化过程中,我们常常遇到需要对注意力计算进行定制化修改的场景。这些修改可能涉及屏蔽中间张量的特定值或对其执行特定操作。这类定制需求可能会与我们此前讨论的优化注意力块的实现方案产生冲突。针对这一问题,我们将探讨几种可行的解决策略:
利用高级内核 API 在着手开发自定义解决方案之前,建议首先全面评估现有优化注意力内核提供的 API。许多先进的注意力内核已经提供了丰富且灵活的接口,支持高度定制化的注意力计算。通过仔细研究这些 API很可能找到已经满足需求的现成功能,从而避免重复造轮子。
实现自定义内核 若现有 API 无法完全满足特定需求,创建自定义注意力实现可能是唯一的解决方案。但这是一条充满挑战的技术路径。正如我们在之前的讨论中阐明的,自定义内核开发存在显著的技术复杂性和性能优化难题。对于追求极致性能的开发者而言,最有效的策略是在现有(最优)内核的基础上进行微小而精准的改动,而非全盘重构。
引入 FlexAttention PyTorch 最新引入的 FlexAttention 为解决注意力计算的定制化需求提供了一个突破性的解决方案。这一创新性特性使开发者能够在不牺牲性能的前提下,灵活实现各种注意力变体。
通过将查询和键标记的点积结果抽象为 score,FlexAttention 提供了两个关键的定制机制:
- score_mod 函数:允许对注意力分数进行编程级别的精细调整
- block_mask 掩码:可以自动应用于 score 张量,实现更复杂的注意力模式
FlexAttention 的技术创新体现在:
- 将 score_mod 操作符直接编译到注意力操作符中,形成单一的融合内核
- 利用 block_masks 的稀疏性特性,智能地避免不必要的计算开销
根据 FlexAttention 官方文档报告的基准测试,该方案在各类使用场景中均展现出显著的性能提升。下面我们将通过具体实例,深入探讨 score_mod 和 block_mask 的实际应用。
Score Mod 示例——使用 Tanh 进行softcap
softcap是一种常用的技术,以下代码块扩展了我们的 PyTorch 原生注意力内核,添加了softcap:
def softcap_attn(q, k, v):
scale = HEAD_DIM ** -0.5
q = q * scale
attn = q @ k.transpose(-2, -1)
# 应用软封顶
attn = 30 * torch.tanh(attn/30)
attn = attn.softmax(dim=-1)
x = attn @ v
return x
首先使用 PyTorch 原生内核训练模型,然后使用优化的 Flex Attention API 进行训练。这些实验是在 3136 长度序列设置下运行的。
# flex attention 导入
from torch.nn.attention.flex_attention import (
create_block_mask,
create_mask,
flex_attention
)
compiled_flex = torch.compile(flex_attention)
# score_mod 定义
def tanh_softcap(score, b, h, q_idx, kv_idx):
return 30 * torch.tanh(score/30)
block_fn = functools.partial(MyAttentionBlock, attn_fn=softcap_attn)
print(f'Attention with Softcap')
train(block_fn)
print(f'Compiled Attention with Softcap')
train_compile(block_fn)
flex_fn = functools.partial(flex_attention, score_mod=tanh_softcap)
compiled_flex_fn = functools.partial(compiled_flex, score_mod=tanh_softcap)
block_fn = functools.partial(MyAttentionBlock,
attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
attn_fn=compiled_flex_fn)
print(f'Flex Attention with Softcap')
train(compiled_block_fn)
print(f'Compiled Flex Attention with Softcap')
train_compile(block_fn)
结果如下
Flash Attention 内核的影响提供了约 3.5 倍的性能提升。
Mask Mod 示例——邻域掩码
还可以通过将稀疏掩码应用于注意力 score 来评估 mask_mod 。我们序列中的每个标记代表 2D 输入图像中的一个补丁。可以修改内核,使每个标记仅关注 2D 标记数组中相应位置的 5x5 窗口内的其他标记。
# 将标记 ID 转换为 2D 索引
def seq_indx_to_2d(idx):
n_row_patches = IMG_SIZE // PATCH_SIZE
r_ind = idx // n_row_patches
c_ind = idx % n_row_patches
return r_ind, c_ind
# 仅关注 2D 标记数组中 5x5 窗口内的标记
def mask_mod(b, h, q_idx, kv_idx):
q_r, q_c = seq_indx_to_2d(q_idx)
kv_r, kv_c = seq_indx_to_2d(kv_idx)
return torch.logical_and(torch.abs(q_r-kv_r)<5, torch.abs(q_c-kv_c)<5)
我们使用支持传递注意力掩码的 PyTorch SDPA。以下代码块包括带掩码的 SDPA 实验,随后是 Flex Attention 实现:
# 物化掩码以用于 SDPA
mask = create_mask(mask_mod, 1, 1, SEQ_LEN, SEQ_LEN, device='cuda')
set_sdpa_backend('all')
masked_sdpa = functools.partial(sdpa, attn_mask=mask)
block_fn = functools.partial(MyAttentionBlock,
attn_fn=masked_sdpa)
print(f'Masked SDPA Attention')
train(block_fn)
print(f'Compiled Masked SDPA Attention')
train_compile(block_fn)
block_mask = create_block_mask(mask_mod, None, None, SEQ_LEN, SEQ_LEN)
flex_fn = functools.partial(flex_attention, block_mask=block_mask)
compiled_flex_fn = functools.partial(compiled_flex, block_mask=block_mask)
block_fn = functools.partial(MyAttentionBlock,
attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
attn_fn=compiled_flex_fn)
print(f'Masked Flex Attention')
train(compiled_block_fn)
print(f'Compiled Masked Flex Attention')
train_compile(block_fn)
Flex Attention 提供了显著的性能提升,提升了 2.19 倍和 2.59 倍。
Flex Attention 限制
尽管我们成功地展示了 Flex Attention 的强大和潜力,但仍有一些限制需要注意:
使用 Flex Attention,(在撰写本文时)只能修改注意力分数(查询和键标记的点积结果)。它不支持在注意力计算的其他阶段进行更改。
由于依赖于 torch.compile,必须小心避免过度重新编译,这可能会极大地降低运行时性能。例如虽然对文档掩码的支持非常吸引人,但只有在所有文档的长度总和保持固定时,它才能按预期执行。
目前(在撰写本文时)Flex Attention 不支持包含 可训练 参数的 score_mod 实现。虽然文档中强调支持相对位置编码,但这些通常是用 可训练 参数(而不是固定值)实现的,目前无法使用。
总结
随着 ML 模型中对 Transformer 架构和注意力层的依赖增加,对优化这些组件的工具和技术的需求也在增加。在本文中,我们探讨了许多注意力内核变体,每个都有其独特的属性、功能和限制。重要的是,没有一种方法适用于所有情况——不同的模型和用例将需要使用不同的内核和不同的优化策略。这强调了拥有多种工具和技术来优化注意力层的重要性。
作者:Chaim Rand