3751 words
19 minutes
Transformer 架构详解
Transformer 架构详解
1. FFN 与普通全连接的区别
结构对比
| 特性 | 普通全连接层 | FFN |
|---|---|---|
| 层数 | 单层 | 两层 |
| 激活函数 | 可选 | 必须(非线性) |
| 维度变换 | 输入→输出 | d_model → d_ff → d_model |
| 位置交互 | 全连接 | Position-wise(逐位置) |
数学表达
普通全连接:
FFN(Transformer):
其中:
W_1 ∈ ℝ^{d_model × d_ff}:第一层权重W_2 ∈ ℝ^{d_ff × d_model}:第二层权重σ:非线性激活函数(通常为 ReLU 或 GELU)
2. FFN 的结构
架构图
FFN 架构
输入 x (d_model) │ ▼ ┌─────────────┐ │ Linear W1 │ d_model → d_ff (通常是 4×d_model) │ (W_1) │ └─────────────┘ │ ▼ ┌─────────────┐ │ ReLU/ │ 非线性激活 │ GELU │ └─────────────┘ │ ▼ ┌─────────────┐ │ Linear W2 │ d_ff → d_model │ (W_2) │ └─────────────┘ │ ▼ 输出 (d_model)维度扩展
d_model = 512 ↓d_ff = 2048 (4 × 512) ↓d_model = 512典型配置(以 BERT-base 为例):
d_model = 768d_ff = 3072 # 4 × 768d_head = 64 # 用于多头注意力num_heads = 123. 常见错误理解及纠正
错误理解
FFN 的作用是将 token 内部来自多个不同子空间的注意力信息拼接到同一维。
正确理解
这是完全错误的! 这种描述混淆了 FFN 和多头注意力(MHA)的功能。
| 操作 | 执行者 | 作用 |
|---|---|---|
| 拼接多个子空间信息 | 多头注意力 (MHA) | 扩展模型的表示能力,让不同头学习不同模式 |
| 维度映射与非线性变换 | FFN | 对每个位置的表示进行深层变换 |
FFN 实际做什么
FFN 是 position-wise(逐位置)的操作,每个 token 独立处理:
Token 1: [x₁] ──→ FFN ──→ [y₁] (独立计算)Token 2: [x₂] ──→ FFN ──→ [y₂] (独立计算)Token 3: [x₃] ──→ FFN ──→ [y₃] (独立计算) ...Token n: [xₙ] ──→ FFN ──→ [yₙ] (独立计算)FFN 不会跨位置交互信息,跨位置交互是注意力机制的工作。
4. Scaled Dot-Product Attention:为什么需要 Scaling?
常见错误理解
“d_k scaling 用来防止梯度爆炸”
正确机制
当维度 d_k 较大时,点积的方差会随 d_k 增长(d_k × σ²)。这会导致:
exp()函数放大数值差异- softmax 饱和到接近 one-hot 分布
- 某个 token “独占” 注意力,其他 token 梯度几乎为 0
Scaling 的真正作用是保持方差恒定,让 softmax 输出始终分布在合理范围内。
# 直觉理解d_k = 64 → dot product 方差 ≈ 64 × var(q) × var(k)d_k = 512 → dot product 方差 ≈ 512 × var(q) × var(k) # 大了 8 倍!
# 除以 √(d_k) 后,方差恢复为 1完整公式
5. Multi-Head Attention:不是更多参数,而是更多子空间
参数分布
d_model = 512, h = 8 时:├── 每个 head 的 d_k = d_v = 64├── 总参数量与单头注意力相同(4 × d_model × d_model 量级)└── 但能在多个语义子空间并行学习不同的注意力模式W^O 的作用
Multi-Head Attention 输出┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐│head1│ │head2│ │ ... │ │head8│ (h × d_k 维)└─────┘ └─────┘ └─────┘ └─────┘ ↓ concat (8 × 64 = 512 维) ↓ W^O (512 × 512) 映射回 d_model 维统一表征空间各 head 来自不同的 Q/K/V 投影子空间,分布不同,必须通过 W^O 对齐到统一表征空间。
6. Layer Norm vs Batch Norm:最容易混淆的点
错误理解
Layer Norm 在 token 之间做归一化
正确理解
LN 在每个 token 内部独立归一化
| Layer Norm | Batch Norm | |
|---|---|---|
| 归一化维度 | 每个 token 自己的 d_model 维 | 跨 token(同一个维度) |
| batch 依赖 | 无 | 有(训练/推理不一致) |
| 变长序列 | 天然适合 | 需要 padding |
# Layer Norm(Transformer 使用)token1: [x1, x2, x3, ..., x512] → 独立归一化token2: [y1, y2, y3, ..., y512] → 独立归一化(每个 token 只跟自己比)
# Batch Norm(CNN 常用)维度方向: [token1_d1, token2_d1, token3_d1, ...] → 一起归一化(同一个维度上的所有 token 比)残差连接的核心要求
x + Sublayer(x) 要求维度完全相同,梯度才能直接回传。
7. Transformer 完整架构图
Encoder-Decoder 架构
┌──────────────────────────────────────────────────────────────────────────────┐│ TRANSFORMER 完整架构 │└──────────────────────────────────────────────────────────────────────────────┘ ENCODER DECODER ┌─────────────────────────────────────┐ ┌──────────────────────────────────────┐ │ │ │ │ │ ┌───────────────────────────┐ │ │ ┌───────────────────────────┐ │ │ │ Input Embedding │ │ │ │ Output Embedding │ │ │ │ + Positional │ │ │ │ + Positional │ │ │ │ Encoding │ │ │ │ Encoding │ │ │ └───────────┬───────────────┘ │ │ └───────────┬───────────────┘ │ │ │ │ │ │ │ │ ▼ │ │ ▼ │ │ ┌──────────────────────────────┐ │ │ ┌───────────────────────────┐ │ │ │ Multi-Head │ │ │ │ Multi-Head │ │ │ │ Attention │ │ │ │ Self-Attention │ │ │ │ │ │ │ │ │ │ │ │ ┌─────┐ ┌─────┐ ┌─────┐ │ │ │ │ ┌─────┐ ┌─────┐ ┌─────┐ │ │ │ │ │Head1│ │Head2│ │...H │ │ │ │ │ │Head1│ │Head2│ │...H │ │ │ │ │ └──┬──┘ └──┬──┘ └──┬──┘ │ │ │ │ └──┬──┘ └──┬──┘ └──┬──┘ │ │ │ │ └────────┼────────┘ │ │ │ │ └────────┼────────┘ │ │ │ │ ▼ │ │ │ │ ▼ │ │ │ │ ┌───────────────────┐ │ │ │ │ ┌───────────────────┐ │ │ │ │ │ Linear (W_O) │ │ │ │ │ │ Linear (W_O) │ │ │ │ │ │ concat → d_model │ │ │ │ │ │ concat → d_model │ │ │ │ │ └─────────┬─────────┘ │ │ │ │ └─────────┬─────────┘ │ │ │ └─────────────┼────────────────┘ │ └────────────┼──────────────┘ │ │ │ │ │ │ │ │ ▼ │ │ ▼ │ │ ┌───────────────────────────┐ │ │ ┌───────────────────────────┐ │ │ │ ADD │ │ │ │ ADD │ │ │ │ (Residual) │ │ │ │ (Residual) │ │ │ └───────────┬───────────────┘ │ │ └───────────┬───────────────┘ │ │ │ │ │ │ │ │ ▼ │ │ ▼ │ │ ┌───────────────────────────┐ │ │ ┌───────────────────────────┐ │ │ │ LN │ │ │ │ LN │ │ │ │ (Layer Norm) │ │ │ │ (Layer Norm) │ │ │ └───────────┬───────────────┘ │ │ └───────────┬───────────────┘ │ │ │ │ │ │ │ │ ▼ │ │ │ │ │ ┌───────────────────────────┐ │ │ │ │ │ │ FFN │ │ │ │ │ │ │ │ │ │ │ │ │ │ ┌─────────────────────┐ │ │ │ │ │ │ │ │ Linear (W_1) │ │ │ │ │ │ │ │ │ d_model → d_ff │ │ │ │ │ │ │ │ └──────────┬──────────┘ │ │ │ │ │ │ │ ▼ │ │ │ │ │ │ │ ┌─────────────────────┐ │ │ │ │ │ │ │ │ ReLU / GELU │ │ │ │ │ │ │ │ └──────────┬──────────┘ │ │ │ │ │ │ │ ▼ │ │ │ │ │ │ │ ┌─────────────────────┐ │ │ │ │ │ │ │ │ Linear (W_2) │ │ │ │ │ │ │ │ │ d_ff → d_model │ │ │ │ │ │ │ │ └─────────────────────┘ │ │ │ │ │ │ └───────────┬───────────────┘ │ │ │ │ │ │ │ │ │ │ │ ▼ │ │ ▼ │ │ ┌───────────────────────────┐ │ │ ┌───────────────────────────┐ │ │ │ ADD │ │ │ │ Cross Attention │ │ │ │ (Residual) │ │ │ │ │ │ │ └───────────┬───────────────┘ │ │ │ Query: Decoder │ │ │ │ │ │ │ Key/Value: Encoder │ │ │ ▼ │ │ └─────────┴─────────────────┘ │ │ ┌───────────────────────────┐ │ │ │ │ │ │ LN │ │ │ ▼ │ │ └───────────┬───────────────┘ │ │ ┌─────────────────────────┐ │ │ │ │ │ │ ADD + LN │ │ │ ┌───────────────────────────┐ │ │ └─────────┴───────────────┘ │ │ │ × N (layers) │ │ │ │ │ │ └───────────────────────────┘ │ │ ▼ │ │ │ │ ┌───────────────────────────┐ │ └─────────────────────────────────────┘ │ │ × N (layers) │ │ │ └───────────────────────────┘ │ │ │ │ │ ▼ │ │ ┌───────────────────────────┐ │ │ │ Linear │ │ │ │ d_model → vocab │ │ │ └───────────┬───────────────┘ │ │ │ │ │ ▼ │ │ ┌───────────────────────────┐ │ │ │ Softmax │ │ │ └───────────┬───────────────┘ │ │ │ │ │ ▼ │ │ Output │ └──────────────────────────────────────┘单层 Encoder 详细结构
┌─────────────────────────────────────────────────────────────────┐│ Single Encoder Layer │└─────────────────────────────────────────────────────────────────┘
Input │ ▼┌─────────────────────────────────────────────────────────────────┐│ Multi-Head Self-Attention ││ ││ Query, Key, Value 均来自同一输入(Self) ││ ││ Input ──► W_Q, W_K, W_V ──► Split into H heads ──► Attention ││ │ │ ││ │ ┌──────────┴──────────┐ ││ │ │ 每个头独立计算 │ ││ │ │ Attention(Q,K,V) │ ││ │ └──────────┬──────────┘ ││ │ │ ││ ▼ ▼ ││ Concat heads ──► W_O ──► Output ││ │└───────────────────────────────┬─────────────────────────────────┘ │ ▼ ┌────────────┐ │ ADD │ Residual Connection │ (x + sub) │ └────────────┘ │ ▼ ┌────────────┐ │ LN │ LayerNorm └────────────┘ │ ▼ ┌────────────┐ │ FFN │ Position-wise FFN │ │ │ x' = FFN(x)│ └────────────┘ │ ▼ ┌────────────┐ │ ADD │ Residual Connection │ (x + x') │ └────────────┘ │ ▼ ┌────────────┐ │ LN │ LayerNorm └────────────┘ │ ▼ Output8. FFN 与 MHA 线性层的澄清
错误说法
FFN 和多头注意力机制的最后线性层都是针对同一 token 内部所作操作。
正确理解
这个说法不完全准确,两者有本质区别:
| 操作 | 作用域 | 说明 |
|---|---|---|
| MHA 最终线性层 (W_O) | 跨 token | 将 H 个头的输出拼接后线性变换,整合所有 token的信息 |
| FFN 线性层 (W_1, W_2) | token 内部 | 逐位置独立变换,不涉及 token 间交互 |
对比图示
MHA 最终线性层 (W_O): W_O (d_head×H → d_model) │ Head1 ──► [h₁] ─┐ │ Head2 ──► [h₂] ─┼──► Concat ────────┼──► Output ... ──► [...]─┘ [h₁⊕h₂⊕...] │ HeadH ──► [h_H] ─┘ │ │ ⚠️ 这里的拼接是多头拼接,但输出会受到所有 token 注意力分数的影响
FFN (两层线性):
Token 1: [x₁] ──► W₁ ──► σ ──► W₂ ──► [y₁] (独立) Token 2: [x₂] ──► W₁ ──► σ ──► W₂ ──► [y₂] (独立) ... Token n: [xₙ] ──► W₁ ──► σ ──► W₂ ──► [yₙ] (独立)
⚠️ 严格逐位置操作,无 token 间交互关键区别总结
┌─────────────────────────────────────────────────────────────────┐│ 核心区别 │├─────────────────────────────────────────────────────────────────┤│ ││ MHA (Multi-Head Attention): ││ ┌───────────────────────────────────────────────────────────┐ ││ │ 跨 token 操作:每个位置的输出受序列中所有位置影响 │ ││ │ 注意力分数决定了 token 间的信息流动 │ ││ └───────────────────────────────────────────────────────────┘ ││ ││ FFN (Feed-Forward Network): ││ ┌───────────────────────────────────────────────────────────┐ ││ │ token 内部操作:每个位置独立变换,无信息交互 │ ││ │ W₁, W₂ 仅负责将该位置的表示映射到更高维空间再映射回来 │ ││ └───────────────────────────────────────────────────────────┘ ││ │└─────────────────────────────────────────────────────────────────┘9. FFN 参数量分析
参数分布
在标准 Transformer 中,FFN 占据总参数的 60-70%:
┌─────────────────────────────────────────────────────────────────┐│ Transformer 参数分布(估算) │├─────────────────────────────────────────────────────────────────┤│ ││ Embedding + Positional Encoding ││ ████████████████████████████████████ 约 20-30% ││ ││ Multi-Head Attention ││ ████████████████████ 约 20-30% ││ ├── W_Q, W_K, W_V (d_model × d_model) ││ └── W_O (d_model × d_model) ││ ││ FFN ││ ████████████████████████████████████████████████ 约 60-70% ││ ├── W_1 (d_model × d_ff) = d_model × 4×d_model ││ └── W_2 (d_ff × d_model) = 4×d_model × d_model ││ │└─────────────────────────────────────────────────────────────────┘具体计算示例
以 BERT-base 为例:
# 模型配置d_model = 768d_ff = 3072 # 4 × 768num_heads = 12num_layers = 12vocab_size = 30522
# 参数量计算def count_parameters(): # Embedding embedding = vocab_size * d_model + 512 * d_model # + positional
# MHA per layer mha_per_layer = 3 * (d_model * d_model) + d_model * d_model # W_Q, W_K, W_V, W_O
# FFN per layer ffn_per_layer = 2 * (d_model * d_ff) # W_1, W_2
# Layer Norm per layer (2 per layer: after MHA and after FFN) ln_per_layer = 2 * (2 * d_model) # gain + bias per LN
# Total total = (embedding + num_layers * (mha_per_layer + ffn_per_layer + ln_per_layer))
ffn_ratio = (num_layers * ffn_per_layer) / total print(f"FFN parameters: {num_layers * ffn_per_layer:,}") print(f"Total parameters: {total:,}") print(f"FFN占比: {ffn_ratio:.1%}")
count_parameters()# FFN 占比约 66%为什么 FFN 参数量这么大?
# FFN 权重形状W_1.shape = (d_model, d_ff) # 768 × 3072 = 2,359,296W_2.shape = (d_ff, d_model) # 3072 × 768 = 2,359,296
# 对比 MHA 权重W_Q.shape = (d_model, d_model) # 768 × 768 = 589,824W_K.shape = (d_model, d_model) # 768 × 768 = 589,824W_V.shape = (d_model, d_model) # 768 × 768 = 589,824W_O.shape = (d_model, d_model) # 768 × 768 = 589,824
# FFN 单层参数量 ≈ 4.7M# MHA 单层参数量 ≈ 2.4M# FFN 是 MHA 的约 2 倍10. FFN 的作用总结
FFN 在 Transformer 中的真实角色
-
表示空间扩展
- 将 d_model 空间扩展到 d_ff (4×d_model) 空间
- 提供更丰富的表示能力
-
非线性变换
- 引入激活函数,增加模型的非线性表达能力
- 使模型能够学习更复杂的函数映射
-
Token 内部信息整合
- 虽然是 position-wise,但每个 token 内部进行了深层的非线性变换
- 可以视为对”每个 token 内部”表示的进一步加工
-
与注意力机制互补
- 注意力:跨 token 的信息交互
- FFN:token 内部的深度变换
- 两者缺一不可
┌─────────────────────────────────────────────────────────────────┐│ ││ 注意力机制 ──────► 跨 token 信息交互 ││ │ ││ │ 互补 ││ ▼ ││ FFN ──────────► Token 内部深层变换 ││ │└─────────────────────────────────────────────────────────────────┘11. BERT vs GPT:架构哲学的根本差异
架构对比
| BERT | GPT | |
|---|---|---|
| 架构 | Encoder-only | Decoder-only |
| 注意力 | 双向(无因果掩码) | 单向(因果掩码) |
| 训练目标 | MLM(Masked Language Model) | LM(Language Model) |
| 典型任务 | 理解(分类、NER、QA) | 生成(续写、对话) |
BERT 的 [MASK] 设计
训练时约:
- 80% 用 [MASK] 替换
- 10% 随机 token 替换
- 10% 保持不变
这减少了微调时与预训练分布的差异。
# BERT 的双向注意力[CLS] The cat sat on the [MASK] .
[MASK] 位置可以看到:The, cat, sat, on, the(双向)→ 约 15% 的 token 被 [MASK] 替换
# GPT 的单向注意力The cat sat on the mat. ↑ 因果掩码:只能看到左边的 token为什么需要 Layer Norm 而不是 Batch Norm?
- Layer Norm 不依赖 batch 维度,训练/推理行为一致
- Transformer 处理变长序列,Batch Norm 需要 padding
- 每个 token 独立归一化更符合自注意力的独立交互模式
12. 代码示例
PyTorch 实现
import torchimport torch.nn as nnimport torch.nn.functional as Fimport math
class TransformerFFN(nn.Module): """标准 FFN 实现"""
def __init__(self, d_model: int, d_ff: int, activation: str = "relu"): super().__init__() self.w_1 = nn.Linear(d_model, d_ff) self.w_2 = nn.Linear(d_ff, d_model) self.activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (batch, seq_len, d_model) x = self.w_1(x) x = F.relu(x) if self.activation == "relu" else F.gelu(x) x = self.w_2(x) return x
class MultiHeadAttention(nn.Module): """多头注意力(简化版)"""
def __init__(self, d_model: int, num_heads: int): super().__init__() self.d_model = d_model self.num_heads = num_heads self.d_head = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model) self.w_k = nn.Linear(d_model, d_model) self.w_v = nn.Linear(d_model, d_model) self.w_o = nn.Linear(d_model, d_model) # ← 这是跨 token 的线性层
def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor: batch, seq_len, _ = x.shape
# Q, K, V Q = self.w_q(x) K = self.w_k(x) V = self.w_v(x)
# 分头 Q = Q.view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2) K = K.view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2) V = V.view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
# Scaled Dot-Product Attention scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_head) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attn_weights = F.softmax(scores, dim=-1) attn_output = torch.matmul(attn_weights, V)
# 合并多头 attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(batch, seq_len, self.d_model)
# 最终线性层 W_O - 关键:这是注意力输出后的一次线性变换 output = self.w_o(attn_output)
return output
# 测试 FFN 的 position-wise 特性def test_position_wise(): d_model = 64 d_ff = 256 seq_len = 10 batch = 2
ffn = TransformerFFN(d_model, d_ff) x = torch.randn(batch, seq_len, d_model)
output = ffn(x)
print(f"输入形状: {x.shape}") print(f"输出形状: {output.shape}")
# 验证 position-wise:每个位置独立计算 # 修改一个位置,应该只有该位置输出变化 x_modified = x.clone() x_modified[:, 3, :] = 1000.0
output_modified = ffn(x_modified) diff = (output_modified - output).abs()
# 只有位置 3 的差异应该很大 print(f"位置 3 的差异: {diff[:, 3, :].mean():.6f}") print(f"其他位置的差异: {diff[:, :3, :].mean():.6f}, {diff[:, 4:, :].mean():.6f}") # 预期:位置 3 差异大,其他位置差异接近 0
test_position_wise()13. 常见问题
Q1: FFN 可以用卷积替代吗?
可以,但通常不这么做。FFN 的优势:
- 无需设计卷积核大小
- 适合任意序列长度
- 实现简单高效
Q2: FFN 的激活函数可以省略吗?
不能。激活函数提供非线性,是 FFN 表达能力的核心。
Q3: FFN 和 MoE(Mixture of Experts)有什么关系?
MoE 是对 FFN 的扩展:
标准 FFN: 每个 token 经过 1 个 FFNMoE FFN: 每个 token 经过 K 个 FFN 中的 1 个(由门控选择)总结
| 关键点 | 说明 |
|---|---|
| FFN 结构 | 两层线性层 + 非线性激活(ReLU/GELU) |
| 维度变化 | d_model → d_ff (4×d_model) → d_model |
| 操作方式 | Position-wise,每个 token 独立处理 |
| 参数量 | 占 Transformer 总参数的 60-70% |
| 与 MHA 关系 | MHA 处理 token 间交互,FFN 处理 token 内深度变换 |
| Scaling 作用 | 保持注意力方差恒定,防止 softmax 饱和 |
| W^O 的作用 | 将多头拼接结果映射回统一表征空间 |
| LN vs BN | LN 在每个 token 内部归一化,无 batch 依赖 |
| BERT vs GPT | Encoder vs Decoder,双向 vs 因果掩码 |
记住:FFN 不是拼接不同子空间信息的工具,那是 MHA 的职责。FFN 的核心作用是对每个 token 的表示进行深层非线性变换,两者共同构成 Transformer 的表示学习基础。
Transformer 架构详解
https://sgjki547.top/posts/transformer/