深入理解 Transformer 架构:RNN 和 CNN 无法解决的问题
引言
Transformer 是 2017 年 Google 在论文《Attention Is All You Need》中提出的深度学习架构,核心完全基于自注意力机制(Self-Attention),彻底摒弃了 RNN/LSTM 等序列模型。
核心创新:用注意力机制替代序列传递,让任意两个位置之间可以直接建立关联,不受距离限制。
整体架构
Transformer 由编码器和解码器组成:
- Encoder:用于编码输入(如 BERT)
- Decoder:用于生成输出(如 GPT)
每个 Encoder/Decoder Layer 内部包含:
- Multi-Head Self-Attention
- Feed Forward Network
- 残差连接
- LayerNorm
Transformer 整体架构图
┌─────────────────────────────────────────────────────────────────────────────┐ │ TRANSFORMER (Encoder-Decoder) │ └─────────────────────────────────────────────────────────────────────────────┘
输入序列 输出序列 │ ▲ ▼ │ ┌─────────┐ ┌─────────────────────────────────────────┐ ┌─────────┐ │ Input │ │ ENCODER STACK │ │ Output │ │ Embedding│ │ (N × Encoder) │ │ (N × Decoder) │ └────┬────┘ │ │ └────┬────┘ │ │ ┌─────────────────────────────────┐ │ │ │ │ │ Encoder Layer │ │ │ │ │ │ │ │ │ │ │ │ ┌───────────────────────────┐ │ │ │ │ │ │ │ Multi-Head Self-Attention │ │ │ │ │ │ │ │ (MHA) │ │ │ │ │ │ │ └───────────────────────────┘ │ │ │ │ │ │ │ │ │ │ │ │ │ Add & LayerNorm │ │ │ │ │ │ │ │ │ │ │ │ │ ┌───────────────────────────┐ │ │ │ │ │ │ │ Feed-Forward (FFN) │ │ │ │ │ │ │ │ (FC + ReLU + FC) │ │ │ │ │ │ │ └───────────────────────────┘ │ │ │ │ │ │ │ │ │ │ │ │ │ Add & LayerNorm │ │ │ │ │ └───────────────┼─────────────────┘ │ │ │ │ │ │ │ │ │ ┌───────────────┴─────────────────┐ │ │ │ │ │ Encoder Layer N │ │ │ │ │ └───────────────┬─────────────────┘ │ │ │ └──────────────────┼──────────────────────┘ │ │ │ │ │ ▼ │ │ ┌─────────────────┐ │ │ │ Encoder │ │ │ │ Output │ │ │ │ (K, V for Dec) │ │ │ └────────┬────────┘ │ │ │ │ └───────────────────────────┼──────────────────────────────┘ │ │ ┌───────────────────────────┼───────────────────────────────┐ │ │ │ ▼ ▼ │ ┌─────────┐ ┌─────────────────────────────────────────┐ │ Output │ │ DECODER STACK │ │ Prob │ │ (N × Decoder) │ └─────────┘ │ │ ▲ │ ┌─────────────────────────────────┐ │ │ │ │ Decoder Layer │ │ │ │ │ │ │ │ │ │ ┌───────────────────────────┐ │ │ │ │ │ │ Masked Multi-Head │ │ │ │ │ │ │ Self-Attention (MHA) │ │ │ │ │ │ │ (防止看到未来位置) │ │ │ │ │ │ └───────────────────────────┘ │ │ │ │ │ │ │ │ │ │ │ Add & LayerNorm │ │ │ │ │ │ │ │ │ │ │ ┌───────────────────────────┐ │ │ │ │ │ │ Multi-Head Cross- │ │ │ │ │ │ │ Attention (MHA) │ │ │ │ │ │ │ (Encoder-Decoder Attn) │ │ │ │ │ │ └───────────────────────────┘ │ │ │ │ │ │ │ │ │ │ │ Add & LayerNorm │ │ │ │ │ │ │ │ │ │ │ ┌───────────────────────────┐ │ │ │ │ │ │ Feed-Forward (FFN) │ │ │ │ │ │ └───────────────────────────┘ │ │ │ │ │ │ │ │ │ │ │ Add & LayerNorm │ │ │ │ └───────────────┼─────────────────┘ │ │ │ │ │ │ │ ┌───────────────┴─────────────────┐ │ │ │ │ Decoder Layer N │ │ │ │ └───────────────┬─────────────────┘ │ │ └──────────────────┼──────────────────────┘ │ │ │ │ ▼ │ ┌─────────┐ │ │ Positional │ │ │ Encoding │ │ └────┬────┘ │ │ │ ▼ ▼ ┌─────────────────────────────────────────────────────────────┐ │ INPUT EMBEDDINGS │ │ (Token Embeddings + Positional Encoding) │ └─────────────────────────────────────────────────────────────┘核心组件详解
┌────────────────────────────────────────────────────────────────┐ │ Multi-Head Attention │ │ │ │ 输入 Q, K, V │ │ │ │ │ ▼ │ │ ┌─────┐ ┌─────┐ ┌─────┐ │ │ │ W_Q │ │ W_K │ │ W_V │ (可学习的投影矩阵) │ │ └──┬──┘ └──┬──┘ └──┬──┘ │ │ │ │ │ │ │ ▼ ▼ ▼ │ │ ┌──────┐ ┌──────┐ ┌──────┐ │ │ │Head 1│ │Head 2│ ... │Head h│ (h 个注意力头并行计算) │ │ │Attn │ │Attn │ │Attn │ │ │ └──┬───┘ └──┬───┘ └──┬───┘ │ │ │ │ │ │ │ └──────────┼──────────┘ │ │ ▼ │ │ Concat [h₁, h₂, ..., hₕ] │ │ │ │ │ ▼ │ │ Linear (W_O) │ │ │ │ │ ▼ │ │ 输出 │ └────────────────────────────────────────────────────────────────┘
┌────────────────────────────────────────────────────────────────┐ │ Feed-Forward Network (FFN) │ │ │ │ 输入 x (d_model) │ │ │ │ │ ▼ │ │ Linear₁ (W₁: d_model → d_ff, 通常 d_ff = 4×d_model) │ │ │ │ │ ▼ │ │ ReLU / GELU / Swish (非线性激活) │ │ │ │ │ ▼ │ │ Linear₂ (W₂: d_ff → d_model) │ │ │ │ │ ▼ │ │ 输出 x' (d_model) │ └────────────────────────────────────────────────────────────────┘数据流总结
输入: "The cat sat on the mat" │ ├─ Token Embeddings ──────────────────┐ │ │ ├─ Positional Encoding ───────────────┤ │ │ ▼ ▼ ┌──────────────────────────────────────────┐ │ Encoder Stack │ │ │ │ Input: [seq_len, d_model] │ │ ↓ │ │ MHA + Add&Norm │ │ ↓ │ │ FFN + Add&Norm │ │ ↓ │ │ Output: [seq_len, d_model] │ └──────────────────────────────────────────┘ │ │ (作为 K, V 传给 Decoder) ▼ ┌──────────────────────────────────────────┐ │ Decoder Stack │ │ │ │ 1. Masked MHA (防止看到未来) │ │ 2. Cross MHA (看 Encoder 的 K, V) │ │ 3. FFN + Add&Norm │ └──────────────────────────────────────────┘ │ ▼ Linear + Softmax │ ▼ 输出下一个词的概率关键参数
┌──────────────────┬───────────────────────────────────────┐ │ 组件 │ 常见配置 │ ├──────────────────┼───────────────────────────────────────┤ │ d_model │ 512 (BERT), 768 (GPT-2), 1024 (GPT-3) │ ├──────────────────┼───────────────────────────────────────┤ │ N (层数) │ 6 (原版), 12 (BERT-Large), 96 (GPT-3) │ ├──────────────────┼───────────────────────────────────────┤ │ MHA heads │ 8 (原版), 12/16 (BERT/GPT) │ ├──────────────────┼───────────────────────────────────────┤ │ d_ff │ 4 × d_model = 2048 (原版) │ ├──────────────────┼───────────────────────────────────────┤ │ FFN 参数量 │ 约占总参数的 60-70% │ ├──────────────────┼───────────────────────────────────────┤ │ Attention 参数量 │ 约占总参数的 20-30% │ └──────────────────┴───────────────────────────────────────┘Encoder 与 Decoder 结构对比
| 组件 | Encoder | Decoder |
|---|---|---|
| Self-Attention | 无掩码,双向可见 | 有掩码(Masked),只看左侧 |
| Cross-Attention | 无 | 有,K/V 来自 Encoder |
| 作用 | 理解完整输入 | 生成输出序列 |
| 代表模型 | BERT | GPT 系列 |
Self-Attention(自注意力机制)
为什么需要注意力机制?
RNN 的问题:序列从头传到尾,信息会”稀释”,长序列依赖难以建立。
具体计算步骤
-
生成 Q、K、V:
Q = X·W_QK = X·W_KV = X·W_V -
计算注意力分数:
Score(qᵢ, kⱼ) = qᵢ · kⱼ / √d_k -
Softmax 归一化:
αᵢⱼ = softmax(Score) -
加权求和:
Attention(Q,K,V) = Σⱼ αᵢⱼ · vⱼ
通俗理解
每个位置看完全部位置后,给相关位置加权,信息传递一步到位。
Multi-Head Attention(多头注意力)
单头注意力的局限
单头注意力使用一组 Q/K/V 变换,所有位置之间的关联强度都在同一个注意力矩阵中竞争。这意味着:
- 不同的关联模式(如主语-动词、代词-先行词、语义相似)必须共享同一组权重
- 模型需要在这多种关联模式之间做权衡
- 无法同时专注于多种不同类型的关联
多头注意力的核心价值
多头注意力通过并行多组独立的 Q/K/V 投影来解决这个问题。每组投影构成一个”头”,各头独立计算注意力,各自捕捉不同的关联模式。
投影过程(PyTorch 代码):
# 输入: X (batch, seq_len, d_model) e.g. (1, 512, 512)batch, seq_len, d_model = x.shapeh = 8 # 8 个头d_k = d_model // h # 64
# 线性投影到 Q/K/V 空间,参数: 4 × d_model × d_modelW_qkv = nn.Linear(d_model, 3 * d_model)qkv = W_qkv(x) # (batch, seq_len, 3 * d_model)
# 分割成 Q, K, Vqkv = qkv.reshape(batch, seq_len, 3, h, d_k).permute(2, 0, 3, 1, 4)# qkv[0]=Q, qkv[1]=K, qkv[2]=V 各自 shape: (batch, h, seq_len, d_k)
Q, K, V = qkv[0], qkv[1], qkv[2]
# 计算注意力分数 (batch, h, seq_len, seq_len)scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)attn_weights = F.softmax(scores, dim=-1)
# 加权求和 (batch, h, seq_len, d_k)context = torch.matmul(attn_weights, V)
# 拼接所有头的结果 (batch, seq_len, h * d_k) = (batch, seq_len, d_model)context = context.transpose(1, 2).reshape(batch, seq_len, d_model)
# 最终线性投影output = nn.Linear(h * d_k, d_model)(context)关键参数说明:
- 无论头数多少,Q/K/V 的投影矩阵总参数量为
4 × d_model²(不随 h 增加) - 每增加一个头,只是把 d_model 维度切分成更多子空间,并行计算后再拼接
- 参数量不变,计算量随头数线性增加
典型设置:
- h = 8 个头
- 每个头维度 d_k = d_model / h = 64
不同头可以分别关注不同类型的依赖:主语-动词关系、代词-先行词、语义相似等。
两种投影方案的本质区别
理解 Multi-Head Attention 时,一个常见的混淆点是:投影矩阵的参数到底是如何计算的? 这里存在两种理解方式,只有一种是符合原始论文的。
方案 A(错误理解):每个注意力头有独立的 W_Q、W_K、W_V 投影矩阵,每个头有独立的 d_model × d_k 权重。如果真是这样实现,参数量会随头数增加。
方案 B(正确理解,原始论文):单一投影到 3×d_model 维度,然后 reshape 并分割为 h 个头:
W_qkv: d_model → 3 × d_model │ ▼ [Q, K, V] 各自 shape: (batch, seq_len, d_model) │ ▼ reshape & split h 组 Q^i, K^i, V^i 各自 shape: (batch, h, seq_len, d_k)关键洞察:投影到 3×d_model 后,分割成 h 个头只是 reshape 操作,不产生新参数。
参数量的精确计算
以 d_model = 512, h = 8 为例:
W_qkv 投影层:
输入维度: d_model = 512输出维度: 3 × d_model = 1536参数量: 512 × 1536 = 786,432最终输出投影 W_O:
输入维度: h × d_k = 8 × 64 = 512 = d_model输出维度: d_model = 512参数量: 512 × 512 = 262,144Multi-Head Attention 总参数量:
W_qkv: 512 × 1536 = 786,432W_O: 512 × 512 = 262,144─────────────────────────────总计: 1,048,576 = 4 × 512² ≈ 4 × d_model²无论头数 h 是 8、16 还是 32,参数量始终固定为 4 × d_model²。
为什么”更多头 = 更多参数”是错误的?
这个误解源于没有区分参数存储和计算量:
| 指标 | 与头数 h 的关系 | 原因 |
|---|---|---|
| 参数量 | 固定 | W_qkv 输出固定 3×d_model,W_O 输入固定 h×d_k = d_model |
| 计算量 | 线性增长 | h 个头各自做注意力运算 |
核心要点:增加头数只是把 d_model 维度切分得更细,每条切分线上绑定的参数不变,总参数量不变。
PyTorch 实现验证
import torchimport torch.nn as nn
d_model = 512
# 验证: 无论 h 是多少,参数量不变for h in [2, 4, 8, 16, 32]: d_k = d_model // h W_qkv = nn.Linear(d_model, 3 * d_model) W_O = nn.Linear(h * d_k, d_model) total = sum(p.numel() for p in W_qkv.parameters()) + \ sum(p.numel() for p in W_O.parameters()) print(f"h={h:2d}, d_k={d_k:3d}, total_params={total:,}")输出验证:
h= 2, d_k=256, total_params=1,048,576h= 4, d_k=128, total_params=1,048,576h= 8, d_k= 64, total_params=1,048,576h=16, d_k= 32, total_params=1,048,576h=32, d_k= 16, total_params=1,048,576结论:参数量与 h 完全无关,始终是 4 × d_model²。
为什么还要用多个头?
如果参数量不变,那多头的意义何在?
每个头在独立的子空间运作:
d_k = d_model / h,头数越多,每个头处理的维度越低。这意味着:
- 更多独立的学习信号:8 个头可以学习 8 种不同的关联模式
- 更细粒度的注意力:每个头可以专注于捕捉某一类关联
单头 (d_k = 512): 注意力矩阵捕捉 "整体关联强度"
8 头 (d_k = 64): 头1: 主语-动词关系 头2: 代词-先行词关系 头3: 语义相似度 头4: 句法结构 头5: 语义角色 头6: 词汇搭配 头7: 指代关系 头8: 跨句关联每个头在低维子空间中学到的关联模式可以捕捉不同类型的依赖关系。
位置编码(Positional Encoding)
Attention 本身对位置不敏感。解决方案:在输入嵌入上加位置信息。
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))Feed Forward Network (FFN)
每个注意力层后面接一个 FFN:
FFN(x) = max(0, x·W₁ + b₁) · W₂ + b₂- 注意力负责”交互”
- FFN 负责”变换”
Layer Normalization + 残差连接
每个子层输出后:
Output = LayerNorm(x + Sublayer(x))- 残差连接让梯度直接回传
- LayerNorm 稳定训练
Transformer 变体
| 架构 | 注意力类型 | 适用任务 | 代表模型 |
|---|---|---|---|
| Encoder Only | 双向 Self-Attention | 理解任务:分类、序列标注、MLM | BERT, RoBERTa |
| Decoder Only | 单向(Masked)Self-Attention | 自回归生成:文本补全、对话 | GPT-2/3/4, LLaMA |
| Encoder-Decoder | Encoder 双向 + Decoder 单向 + Cross-Attention | 序列到序列:翻译、摘要、问答 | T5, BART, FLAN-T5 |
| ViT | Encoder + Patch Embedding | 将图像切成 patch 序列 | ViT, DeiT |
Encoder-Decoder 设计哲学:为什么这样设计?
核心问题
在理解 Transformer 的 Encoder-Decoder 架构之前,我们需要先问一个根本性的问题:为什么需要将架构分为编码器和解码器两个部分?
答案在于两类不同的任务本质:
- 理解任务(如分类、情感分析、实体识别):输入是完整的,模型需要”看到”整个输入才能理解
- 生成任务(如机器翻译、文本摘要、问答):输出是逐步生成的,在生成第 t 个 token 时,只能基于已生成的前 t-1 个 token
Encoder:双向理解的力量
Encoder 的职责是理解完整输入。
为什么 Encoder 不需要 Masked Attention?
Encoder 处理输入时,整个输入序列已经全部存在。在编码某个位置时,它可以看向所有其他位置,形成完整的上下文表示。这种双向注意力让每个位置都能建立与所有其他位置的关联。
BERT 的例子:
BERT 使用 Encoder 架构进行预训练,最著名的任务是 Masked Language Model (MLM)。因为 Encoder 能看到完整上下文,所以能够准确预测被 mask 的 token。
Decoder:单向生成的约束
Decoder 的职责是逐个生成输出 token。
为什么 Decoder 必须使用 Masked Attention?
当 Decoder 生成第 t 个 token 时,只能基于前 t-1 个已生成的 token。Masked Attention 通过将”未来”位置 mask 掉,确保模型只能基于已生成的内容预测下一个 token。
# Masked Attention 的实现示意mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()scores.masked_fill_(mask, float('-inf'))attn_weights = F.softmax(scores, dim=-1)Cross-Attention:连接编码与解码的桥梁
Cross-Attention 是 Encoder-Decoder 架构中最关键的设计。
工作原理:
- Q(来自 Decoder):代表当前位置的”查询”——“我正在尝试生成什么?”
- K(来自 Encoder):代表输入序列的”键”——“输入的每个位置讲了什么?”
- V(来自 Encoder):代表输入序列的”值”——“输入每个位置的具体信息是什么?”
通过 Q×K 的相似度计算,Decoder 知道”当前生成”与”输入”的每个位置有多相关,然后结合 V 获得最终的 attended 信息。
Decoder 为何需要输入?
Decoder 并非被动组件,而是主动逐个生成输出序列。其输入机制在训练和推理阶段有本质不同。
训练时:Teacher Forcing 与右移目标序列
训练时,Decoder 接收的是右移一位的目标序列:
输入: [<bos>] 我 爱 学 习输出: 我 爱 学 习 [<eos>]为何要右移? 考虑不右移会发生什么:
不右移:输入: 我 爱 学 习输出: 爱 学 习 [?] ← 模型可以直接复制输入!
右移后:输入: [<bos>] 我 爱 学 习输出: 我 爱 学 习 [<eos>]
预测"爱"时: 输入只有[<bos> 我],看不到"爱"本身预测"学"时: 输入只有[<bos> 我 爱],看不到"学"本身右移机制迫使模型学习前后 token 之间的关系,而非简单的复制操作。
推理时:自回归生成
推理时,Decoder 逐个生成 token,并将已生成的 token 作为下一步的输入:
Step 1: 输入 [<bos>] → 输出: 我Step 2: 输入 [<bos> 我] → 输出: 爱Step 3: 输入 [<bos> 我 爱] → 输出: 学Step 4: 输入 [<bos> 我 爱 学] → 输出: 习Step 5: 输入 [<bos> 我 爱 学 习] → 输出: <eos>Masked Self-Attention vs Cross-Attention
Decoder 包含两个截然不同的注意力机制,各司其职。
Masked Self-Attention(掩码自注意力)
Q、K、V 均来自 Decoder 自身输入:
- 来源:Decoder 输入的 token 序列
- 作用:确保语法连贯性
- 掩码:因果掩码,防止看到未来 token
Decoder 输入序列 │ ┌─────────┼─────────┐ ▼ ▼ ▼ Q K V └─────────┼─────────┘ ▼ Attention(Q,K,V) (含因果掩码)掩码原理:计算位置 i 的注意力时,将所有 j > i 的注意力分数设为 -∞,确保每个 token 只能看到自己和之前的 token。
Cross-Attention(交叉注意力)
Q 来自 Decoder,K 和 V 来自 Encoder 输出:
- Q 来源:Decoder
- K/V 来源:Encoder 输出(源序列的表示)
- 作用:确保语义准确性,参考原始输入
- 掩码:无掩码,Decoder 可以看源序列的所有位置
Decoder 输入 Encoder 输出 │ │ ▼ ▼ Q K / V │ │ └─────────┬──────────────┘ ▼ Attention(Q, K, V) (无因果掩码)对比总结
| 方面 | Masked Self-Attention | Cross-Attention |
|---|---|---|
| Q 来源 | Decoder 输入 | Decoder |
| K/V 来源 | Decoder 输入 | Encoder 输出 |
| 注意力范围 | 仅左侧(因果) | 全部源序列 |
| 作用 | 语法连贯性 | 语义准确性 |
| 掩码 | 因果掩码(不看未来) | 无掩码 |
| 类比 | 回顾前文 | 回看原文献 |
| 示例 | ”主语和动词要一致…" | "现在翻译 ‘love’ → ‘爱‘“ |
为何要分开两个注意力机制?
这种分离并非任意为之,而是 seq2seq 任务中职责分工的体现:
-
Masked Self-Attention 确保语法连贯
- 在已生成的序列内部,每个 token 必须与前文语法一致
- 这是单语问题——在目标语言中维持正确结构
- 例:生成 “loving” 而非 “learning” 违反语法,但可能仍捕捉了”love”的语义意图
-
Cross-Attention 确保语义准确
- 生成的 token 必须准确表示源序列的含义
- 这是跨语问题——在语言间对齐含义
- 例:翻译 “love” 时,模型通过 Cross-Attention 确认理解了”love”,而非与其他词混淆
类比:写作(Masked Self-Attention)与参照原文翻译(Cross-Attention)。两种能力缺一不可——不回顾前文无法写连贯,不参考原文无法准确翻译。
两者配合,使 Decoder 能够生成语法正确且语义忠于输入的输出。
为什么 seq2seq 任务需要 Encoder-Decoder?
对于机器翻译、文本摘要等 序列到序列(Seq2Seq) 任务:
- Encoder 充分理解输入:完整看输入,建立全面表示,不受生成顺序约束
- Decoder 负责生成输出:严格遵守生成顺序,每步只能基于历史内容预测
- Cross-Attention 建立输入输出关联:Decoder 每步生成时,可以回溯输入的任意位置
计算复杂度对比
| 架构 | Self-Attention 复杂度 | 总复杂度 |
|---|---|---|
| Encoder Only | O(n²·d) | O(L·n²·d) |
| Decoder Only | O(m²·d) | O(L·m²·d) |
| Encoder-Decoder | O(n²·d) + O(n·m·d) | O(L·(n² + n·m)·d) |
其中:n = 输入序列长度,m = 输出序列长度,d = 嵌入维度,L = 层数
设计哲学总结
Encoder-Decoder 架构将理解与生成解耦:
- Encoder:充分理解完整输入 → 双向注意力,无掩码 → 最大化信息利用
- Decoder:严格遵守生成顺序 → 单向注意力,有掩码 → 保持自回归属性
- Cross-Attention:桥接输入与输出 → Q 来自 Decoder,K/V 来自 Encoder → 显式建模输入-输出对齐
为什么 Transformer 效果这么好?
- 全局感受野:一层内就能建立任意位置关联
- 并行化:Attention 计算可并行,训练效率远高于 RNN
- 可扩展:残差连接使深层训练稳定,scaling law 效果显著
- 预训练+微调范式:大规模无监督预训练 + 任务微调
计算复杂度
| 组件 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| Self-Attention | O(n² · d) | O(n²) |
| RNN | O(n · d · h) | O(d) |
| FFN | O(n · d²) | O(d²) |
其中 n = 序列长度,d = 嵌入维度。
Attention 的问题是序列长度的二次方,这也是后续优化(FlashAttention、Linear Attention、Mamba 等)的出发点。
Transformer 解决了 RNN 和 CNN 的哪些问题
长距离依赖问题
- RNN:O(n) 路径长度,信息指数衰减,超过 20-30 个词的依赖基本无法有效传递
- CNN:O(log_k(n)) 路径,需要多层堆叠,卷积核大小受限
- Transformer:O(1) 路径,任意位置直接关联
并行训练问题
- RNN:必须序列化,必须等前一步计算完才能计算下一步,GPU 利用率极低
- CNN:核内并行,核间有限
- Transformer:完全并行,矩阵运算,GPU 利用率接近 100%
可扩展性(深层次数)
- RNN:3-4 层以上基本训不动,梯度消失/爆炸
- CNN:可以深,但感受野问题依然在
- Transformer:残差连接使深层训练稳定,已训 100+ 层
显式对齐建模
- RNN:隐式对齐,扩散在 hidden state 中
- CNN:隐式对齐,扩散在卷积核中
- Transformer:Attention Score 就是对齐强度,可直接可视化
Self-Attention 的局限性
尽管 Transformer 通过自注意力机制解决了 RNN 和 CNN 的诸多问题,Self-Attention 本身也存在不容忽视的局限性。
O(n²) 复杂度问题
Self-Attention 的核心计算是生成注意力矩阵:
- 时间复杂度:O(n² · d) — 序列长度 n 的二次方
- 空间复杂度:O(n²) — 需要存储完整的注意力权重矩阵
以 GPT-2 为例:
- 1024 token 上下文:需存储约 4M 个注意力权重
- 2048 token 上下文:需存储约 16M 个注意力权重(4 倍)
- 4096 token 上下文:需存储约 64M 个注意力权重(16 倍)
n² 复杂度是 Transformer 扩展到超长序列的主要瓶颈,这也是后续大量优化工作的出发点。
静态位置编码的局限
Transformer 使用固定的位置编码(正弦/余弦函数),这些编码在预训练时就确定,不随输入动态调整。这导致:
- 缺乏相对位置感知:位置编码是绝对的,对”距离”的学习是间接的
- 无法处理超长序列:当序列超过预训练长度时,位置编码失效
- 对局部结构的忽视:文本的局部结构(短语、句子段落)没有特殊建模
解决方案:RoPE(旋转位置编码)、ALiBi(相对位置偏置)等现代位置编码方法部分缓解了这些问题。
信息瓶颈
自注意力的输出是所有位置的加权求和,d_model 维度固定。当序列很长时:
- 每个位置需要压缩来自全序列的信息
- 某些信息必然在加权求和中被稀释或丢失
- 模型需要更深的层来恢复和精炼信息
局部结构建模不足
Self-Attention 的注意力矩阵是全连接的,每个位置可以看到所有其他位置。这在理论上很完美,但实际中:
- 局部信息(如短语、搭配)被平等对待:没有对局部窗口的特殊偏好
- 计算浪费:很多位置之间的关联强度接近零,但仍需计算
- 对局部特征的捕捉不如 CNN:CNN 的局部卷积天然具有局部感受野
优化方向总览
| 优化方向 | 代表工作 | 核心思想 |
|---|---|---|
| Sparse Attention | Longformer, BigBird | 允许每个位置只关注部分其他位置,降低 O(n²) |
| Linear Attention | Performer, Reformer | 将 softmax 注意力分解为线性形式,降至 O(n) |
| State Space Models | Mamba, S4, S5 | 用状态空间模型替代注意力,O(n) 复杂度 |
| Local + Global | Swin Transformer | 分层设计,局部窗口内计算注意力 |
| Flash Attention | FlashAttention v1/2/3 | IO-aware 优化,利用 GPU 内存层次加速 |
为什么这些优化重要:大语言模型(LLM)需要处理超长上下文(100K+ token),O(n²) 复杂度在工程上变得不可接受。State Space Models(如 Mamba)更是因为其线性复杂度和高质量的序列建模引起了广泛关注。
各问题对比总结
| 问题 | RNN | CNN | Transformer |
|---|---|---|---|
| 长距离依赖 | ❌ 稀释/遗忘 | ❌ 受限于卷积核大小 | ✅ 全局关联一步到位 |
| 并行训练 | ❌ 必须逐时刻序列化 | ✅ 可并行 | ✅ 完全并行 |
| 方向性约束 | ✅ 单向/双向 | ❌ 局部窗口 | ✅ 无方向性 |
| 可扩展性 | ⚠️ 训练不稳定 | ✅ | ✅ scaling 效果好 |
| 显式对齐 | ❌ 隐式学习 | ❌ 隐式学习 | ✅ 可显式建模 |
结语
Transformer 通过自注意力机制从根本上解决了 RNN 的序列传递问题和 CNN 的局部感受野限制,使得模型能够:
- 在单层内建立任意距离的依赖关系
- 利用 GPU 进行完全并行的计算
- 通过残差连接训练深层网络
- 显式建模 token 之间的对齐关系
理解与生成的解耦设计使得不同任务可以选择最适合的架构:Encoder Only 适合理解任务,Decoder Only 适合生成任务,而 Encoder-Decoder 适合 seq2seq 任务。这些优势使得 Transformer 成为大语言模型、多模态模型等现代 AI 系统的基础架构。