3751 words
19 minutes
Transformer 架构详解

Transformer 架构详解#

1. FFN 与普通全连接的区别#

结构对比#

特性普通全连接层FFN
层数单层两层
激活函数可选必须(非线性)
维度变换输入→输出d_model → d_ff → d_model
位置交互全连接Position-wise(逐位置)

数学表达#

普通全连接:

y=Wx+by = W \cdot x + b

FFN(Transformer):

FFN(x)=W2σ(W1x+b1)+b2FFN(x) = W_2 \cdot \sigma(W_1 \cdot x + b_1) + b_2

其中:

  • W_1 ∈ ℝ^{d_model × d_ff}:第一层权重
  • W_2 ∈ ℝ^{d_ff × d_model}:第二层权重
  • σ:非线性激活函数(通常为 ReLU 或 GELU)

2. FFN 的结构#

架构图#

FFN 架构
输入 x (d_model)
┌─────────────┐
│ Linear W1 │ d_model → d_ff (通常是 4×d_model)
│ (W_1) │
└─────────────┘
┌─────────────┐
│ ReLU/ │ 非线性激活
│ GELU │
└─────────────┘
┌─────────────┐
│ Linear W2 │ d_ff → d_model
│ (W_2) │
└─────────────┘
输出 (d_model)

维度扩展#

d_model = 512
d_ff = 2048 (4 × 512)
d_model = 512

典型配置(以 BERT-base 为例):

d_model = 768
d_ff = 3072 # 4 × 768
d_head = 64 # 用于多头注意力
num_heads = 12

3. 常见错误理解及纠正#

错误理解#

FFN 的作用是将 token 内部来自多个不同子空间的注意力信息拼接到同一维。

正确理解#

这是完全错误的! 这种描述混淆了 FFN 和多头注意力(MHA)的功能。

操作执行者作用
拼接多个子空间信息多头注意力 (MHA)扩展模型的表示能力,让不同头学习不同模式
维度映射与非线性变换FFN对每个位置的表示进行深层变换

FFN 实际做什么#

FFN 是 position-wise(逐位置)的操作,每个 token 独立处理:

Token 1: [x₁] ──→ FFN ──→ [y₁] (独立计算)
Token 2: [x₂] ──→ FFN ──→ [y₂] (独立计算)
Token 3: [x₃] ──→ FFN ──→ [y₃] (独立计算)
...
Token n: [xₙ] ──→ FFN ──→ [yₙ] (独立计算)

FFN 不会跨位置交互信息,跨位置交互是注意力机制的工作。


4. Scaled Dot-Product Attention:为什么需要 Scaling?#

常见错误理解#

“d_k scaling 用来防止梯度爆炸”

正确机制#

当维度 d_k 较大时,点积的方差会随 d_k 增长(d_k × σ²)。这会导致:

  • exp() 函数放大数值差异
  • softmax 饱和到接近 one-hot 分布
  • 某个 token “独占” 注意力,其他 token 梯度几乎为 0

Scaling 的真正作用是保持方差恒定,让 softmax 输出始终分布在合理范围内。

# 直觉理解
d_k = 64 → dot product 方差 ≈ 64 × var(q) × var(k)
d_k = 512 → dot product 方差 ≈ 512 × var(q) × var(k) # 大了 8 倍!
# 除以 √(d_k) 后,方差恢复为 1

完整公式#

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V

5. Multi-Head Attention:不是更多参数,而是更多子空间#

参数分布#

d_model = 512, h = 8 时:
├── 每个 head 的 d_k = d_v = 64
├── 总参数量与单头注意力相同(4 × d_model × d_model 量级)
└── 但能在多个语义子空间并行学习不同的注意力模式

W^O 的作用#

Multi-Head Attention 输出
┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
│head1│ │head2│ │ ... │ │head8│ (h × d_k 维)
└─────┘ └─────┘ └─────┘ └─────┘
↓ concat
(8 × 64 = 512 维)
↓ W^O (512 × 512)
映射回 d_model 维统一表征空间

各 head 来自不同的 Q/K/V 投影子空间,分布不同,必须通过 W^O 对齐到统一表征空间。


6. Layer Norm vs Batch Norm:最容易混淆的点#

错误理解#

Layer Norm 在 token 之间做归一化

正确理解#

LN 在每个 token 内部独立归一化

Layer NormBatch Norm
归一化维度每个 token 自己的 d_model 维跨 token(同一个维度)
batch 依赖有(训练/推理不一致)
变长序列天然适合需要 padding
# Layer Norm(Transformer 使用)
token1: [x1, x2, x3, ..., x512] → 独立归一化
token2: [y1, y2, y3, ..., y512] → 独立归一化
(每个 token 只跟自己比)
# Batch Norm(CNN 常用)
维度方向: [token1_d1, token2_d1, token3_d1, ...] → 一起归一化
(同一个维度上的所有 token 比)

残差连接的核心要求#

x + Sublayer(x) 要求维度完全相同,梯度才能直接回传。


7. Transformer 完整架构图#

Encoder-Decoder 架构#

┌──────────────────────────────────────────────────────────────────────────────┐
│ TRANSFORMER 完整架构 │
└──────────────────────────────────────────────────────────────────────────────┘
ENCODER DECODER
┌─────────────────────────────────────┐ ┌──────────────────────────────────────┐
│ │ │ │
│ ┌───────────────────────────┐ │ │ ┌───────────────────────────┐ │
│ │ Input Embedding │ │ │ │ Output Embedding │ │
│ │ + Positional │ │ │ │ + Positional │ │
│ │ Encoding │ │ │ │ Encoding │ │
│ └───────────┬───────────────┘ │ │ └───────────┬───────────────┘ │
│ │ │ │ │ │
│ ▼ │ │ ▼ │
│ ┌──────────────────────────────┐ │ │ ┌───────────────────────────┐ │
│ │ Multi-Head │ │ │ │ Multi-Head │ │
│ │ Attention │ │ │ │ Self-Attention │ │
│ │ │ │ │ │ │ │
│ │ ┌─────┐ ┌─────┐ ┌─────┐ │ │ │ │ ┌─────┐ ┌─────┐ ┌─────┐ │ │
│ │ │Head1│ │Head2│ │...H │ │ │ │ │ │Head1│ │Head2│ │...H │ │ │
│ │ └──┬──┘ └──┬──┘ └──┬──┘ │ │ │ │ └──┬──┘ └──┬──┘ └──┬──┘ │ │
│ │ └────────┼────────┘ │ │ │ │ └────────┼────────┘ │ │
│ │ ▼ │ │ │ │ ▼ │ │
│ │ ┌───────────────────┐ │ │ │ │ ┌───────────────────┐ │ │
│ │ │ Linear (W_O) │ │ │ │ │ │ Linear (W_O) │ │ │
│ │ │ concat → d_model │ │ │ │ │ │ concat → d_model │ │ │
│ │ └─────────┬─────────┘ │ │ │ │ └─────────┬─────────┘ │ │
│ └─────────────┼────────────────┘ │ └────────────┼──────────────┘ │
│ │ │ │ │ │
│ ▼ │ │ ▼ │
│ ┌───────────────────────────┐ │ │ ┌───────────────────────────┐ │
│ │ ADD │ │ │ │ ADD │ │
│ │ (Residual) │ │ │ │ (Residual) │ │
│ └───────────┬───────────────┘ │ │ └───────────┬───────────────┘ │
│ │ │ │ │ │
│ ▼ │ │ ▼ │
│ ┌───────────────────────────┐ │ │ ┌───────────────────────────┐ │
│ │ LN │ │ │ │ LN │ │
│ │ (Layer Norm) │ │ │ │ (Layer Norm) │ │
│ └───────────┬───────────────┘ │ │ └───────────┬───────────────┘ │
│ │ │ │ │ │
│ ▼ │ │ │ │
│ ┌───────────────────────────┐ │ │ │ │
│ │ FFN │ │ │ │ │
│ │ │ │ │ │ │
│ │ ┌─────────────────────┐ │ │ │ │ │
│ │ │ Linear (W_1) │ │ │ │ │ │
│ │ │ d_model → d_ff │ │ │ │ │ │
│ │ └──────────┬──────────┘ │ │ │ │ │
│ │ ▼ │ │ │ │ │
│ │ ┌─────────────────────┐ │ │ │ │ │
│ │ │ ReLU / GELU │ │ │ │ │ │
│ │ └──────────┬──────────┘ │ │ │ │ │
│ │ ▼ │ │ │ │ │
│ │ ┌─────────────────────┐ │ │ │ │ │
│ │ │ Linear (W_2) │ │ │ │ │ │
│ │ │ d_ff → d_model │ │ │ │ │ │
│ │ └─────────────────────┘ │ │ │ │ │
│ └───────────┬───────────────┘ │ │ │ │
│ │ │ │ │ │
│ ▼ │ │ ▼ │
│ ┌───────────────────────────┐ │ │ ┌───────────────────────────┐ │
│ │ ADD │ │ │ │ Cross Attention │ │
│ │ (Residual) │ │ │ │ │ │
│ └───────────┬───────────────┘ │ │ │ Query: Decoder │ │
│ │ │ │ │ Key/Value: Encoder │ │
│ ▼ │ │ └─────────┴─────────────────┘ │
│ ┌───────────────────────────┐ │ │ │ │
│ │ LN │ │ │ ▼ │
│ └───────────┬───────────────┘ │ │ ┌─────────────────────────┐ │
│ │ │ │ │ ADD + LN │ │
│ ┌───────────────────────────┐ │ │ └─────────┴───────────────┘ │
│ │ × N (layers) │ │ │ │ │
│ └───────────────────────────┘ │ │ ▼ │
│ │ │ ┌───────────────────────────┐ │
└─────────────────────────────────────┘ │ │ × N (layers) │ │
│ └───────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────┐ │
│ │ Linear │ │
│ │ d_model → vocab │ │
│ └───────────┬───────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────┐ │
│ │ Softmax │ │
│ └───────────┬───────────────┘ │
│ │ │
│ ▼ │
│ Output │
└──────────────────────────────────────┘

单层 Encoder 详细结构#

┌─────────────────────────────────────────────────────────────────┐
│ Single Encoder Layer │
└─────────────────────────────────────────────────────────────────┘
Input
┌─────────────────────────────────────────────────────────────────┐
│ Multi-Head Self-Attention │
│ │
│ Query, Key, Value 均来自同一输入(Self) │
│ │
│ Input ──► W_Q, W_K, W_V ──► Split into H heads ──► Attention │
│ │ │ │
│ │ ┌──────────┴──────────┐ │
│ │ │ 每个头独立计算 │ │
│ │ │ Attention(Q,K,V) │ │
│ │ └──────────┬──────────┘ │
│ │ │ │
│ ▼ ▼ │
│ Concat heads ──► W_O ──► Output │
│ │
└───────────────────────────────┬─────────────────────────────────┘
┌────────────┐
│ ADD │ Residual Connection
│ (x + sub) │
└────────────┘
┌────────────┐
│ LN │ LayerNorm
└────────────┘
┌────────────┐
│ FFN │ Position-wise FFN
│ │
│ x' = FFN(x)│
└────────────┘
┌────────────┐
│ ADD │ Residual Connection
│ (x + x') │
└────────────┘
┌────────────┐
│ LN │ LayerNorm
└────────────┘
Output

8. FFN 与 MHA 线性层的澄清#

错误说法#

FFN 和多头注意力机制的最后线性层都是针对同一 token 内部所作操作。

正确理解#

这个说法不完全准确,两者有本质区别:

操作作用域说明
MHA 最终线性层 (W_O)跨 token将 H 个头的输出拼接后线性变换,整合所有 token的信息
FFN 线性层 (W_1, W_2)token 内部逐位置独立变换,不涉及 token 间交互

对比图示#

MHA 最终线性层 (W_O):
W_O (d_head×H → d_model)
Head1 ──► [h₁] ─┐ │
Head2 ──► [h₂] ─┼──► Concat ────────┼──► Output
... ──► [...]─┘ [h₁⊕h₂⊕...] │
HeadH ──► [h_H] ─┘ │
⚠️ 这里的拼接是多头拼接,但输出会受到所有 token 注意力分数的影响
FFN (两层线性):
Token 1: [x₁] ──► W₁ ──► σ ──► W₂ ──► [y₁] (独立)
Token 2: [x₂] ──► W₁ ──► σ ──► W₂ ──► [y₂] (独立)
...
Token n: [xₙ] ──► W₁ ──► σ ──► W₂ ──► [yₙ] (独立)
⚠️ 严格逐位置操作,无 token 间交互

关键区别总结#

┌─────────────────────────────────────────────────────────────────┐
│ 核心区别 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ MHA (Multi-Head Attention): │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ 跨 token 操作:每个位置的输出受序列中所有位置影响 │ │
│ │ 注意力分数决定了 token 间的信息流动 │ │
│ └───────────────────────────────────────────────────────────┘ │
│ │
│ FFN (Feed-Forward Network): │
│ ┌───────────────────────────────────────────────────────────┐ │
│ │ token 内部操作:每个位置独立变换,无信息交互 │ │
│ │ W₁, W₂ 仅负责将该位置的表示映射到更高维空间再映射回来 │ │
│ └───────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘

9. FFN 参数量分析#

参数分布#

在标准 Transformer 中,FFN 占据总参数的 60-70%

┌─────────────────────────────────────────────────────────────────┐
│ Transformer 参数分布(估算) │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Embedding + Positional Encoding │
│ ████████████████████████████████████ 约 20-30% │
│ │
│ Multi-Head Attention │
│ ████████████████████ 约 20-30% │
│ ├── W_Q, W_K, W_V (d_model × d_model) │
│ └── W_O (d_model × d_model) │
│ │
│ FFN │
│ ████████████████████████████████████████████████ 约 60-70% │
│ ├── W_1 (d_model × d_ff) = d_model × 4×d_model │
│ └── W_2 (d_ff × d_model) = 4×d_model × d_model │
│ │
└─────────────────────────────────────────────────────────────────┘

具体计算示例#

以 BERT-base 为例:

# 模型配置
d_model = 768
d_ff = 3072 # 4 × 768
num_heads = 12
num_layers = 12
vocab_size = 30522
# 参数量计算
def count_parameters():
# Embedding
embedding = vocab_size * d_model + 512 * d_model # + positional
# MHA per layer
mha_per_layer = 3 * (d_model * d_model) + d_model * d_model # W_Q, W_K, W_V, W_O
# FFN per layer
ffn_per_layer = 2 * (d_model * d_ff) # W_1, W_2
# Layer Norm per layer (2 per layer: after MHA and after FFN)
ln_per_layer = 2 * (2 * d_model) # gain + bias per LN
# Total
total = (embedding +
num_layers * (mha_per_layer + ffn_per_layer + ln_per_layer))
ffn_ratio = (num_layers * ffn_per_layer) / total
print(f"FFN parameters: {num_layers * ffn_per_layer:,}")
print(f"Total parameters: {total:,}")
print(f"FFN占比: {ffn_ratio:.1%}")
count_parameters()
# FFN 占比约 66%

为什么 FFN 参数量这么大?#

# FFN 权重形状
W_1.shape = (d_model, d_ff) # 768 × 3072 = 2,359,296
W_2.shape = (d_ff, d_model) # 3072 × 768 = 2,359,296
# 对比 MHA 权重
W_Q.shape = (d_model, d_model) # 768 × 768 = 589,824
W_K.shape = (d_model, d_model) # 768 × 768 = 589,824
W_V.shape = (d_model, d_model) # 768 × 768 = 589,824
W_O.shape = (d_model, d_model) # 768 × 768 = 589,824
# FFN 单层参数量 ≈ 4.7M
# MHA 单层参数量 ≈ 2.4M
# FFN 是 MHA 的约 2 倍

10. FFN 的作用总结#

FFN 在 Transformer 中的真实角色#

  1. 表示空间扩展

    • 将 d_model 空间扩展到 d_ff (4×d_model) 空间
    • 提供更丰富的表示能力
  2. 非线性变换

    • 引入激活函数,增加模型的非线性表达能力
    • 使模型能够学习更复杂的函数映射
  3. Token 内部信息整合

    • 虽然是 position-wise,但每个 token 内部进行了深层的非线性变换
    • 可以视为对”每个 token 内部”表示的进一步加工
  4. 与注意力机制互补

    • 注意力:跨 token 的信息交互
    • FFN:token 内部的深度变换
    • 两者缺一不可
┌─────────────────────────────────────────────────────────────────┐
│ │
│ 注意力机制 ──────► 跨 token 信息交互 │
│ │ │
│ │ 互补 │
│ ▼ │
│ FFN ──────────► Token 内部深层变换 │
│ │
└─────────────────────────────────────────────────────────────────┘

11. BERT vs GPT:架构哲学的根本差异#

架构对比#

BERTGPT
架构Encoder-onlyDecoder-only
注意力双向(无因果掩码)单向(因果掩码)
训练目标MLM(Masked Language Model)LM(Language Model)
典型任务理解(分类、NER、QA)生成(续写、对话)

BERT 的 [MASK] 设计#

训练时约:

  • 80% 用 [MASK] 替换
  • 10% 随机 token 替换
  • 10% 保持不变

这减少了微调时与预训练分布的差异。

# BERT 的双向注意力
[CLS] The cat sat on the [MASK] .
[MASK] 位置可以看到:The, cat, sat, on, the(双向)
→ 约 15% 的 token 被 [MASK] 替换
# GPT 的单向注意力
The cat sat on the mat.
因果掩码:只能看到左边的 token

为什么需要 Layer Norm 而不是 Batch Norm?#

  • Layer Norm 不依赖 batch 维度,训练/推理行为一致
  • Transformer 处理变长序列,Batch Norm 需要 padding
  • 每个 token 独立归一化更符合自注意力的独立交互模式

12. 代码示例#

PyTorch 实现#

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class TransformerFFN(nn.Module):
"""标准 FFN 实现"""
def __init__(self, d_model: int, d_ff: int, activation: str = "relu"):
super().__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.activation = activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (batch, seq_len, d_model)
x = self.w_1(x)
x = F.relu(x) if self.activation == "relu" else F.gelu(x)
x = self.w_2(x)
return x
class MultiHeadAttention(nn.Module):
"""多头注意力(简化版)"""
def __init__(self, d_model: int, num_heads: int):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_head = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model) # ← 这是跨 token 的线性层
def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
batch, seq_len, _ = x.shape
# Q, K, V
Q = self.w_q(x)
K = self.w_k(x)
V = self.w_v(x)
# 分头
Q = Q.view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
K = K.view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
V = V.view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)
# Scaled Dot-Product Attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_head)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# 合并多头
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch, seq_len, self.d_model)
# 最终线性层 W_O - 关键:这是注意力输出后的一次线性变换
output = self.w_o(attn_output)
return output
# 测试 FFN 的 position-wise 特性
def test_position_wise():
d_model = 64
d_ff = 256
seq_len = 10
batch = 2
ffn = TransformerFFN(d_model, d_ff)
x = torch.randn(batch, seq_len, d_model)
output = ffn(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
# 验证 position-wise:每个位置独立计算
# 修改一个位置,应该只有该位置输出变化
x_modified = x.clone()
x_modified[:, 3, :] = 1000.0
output_modified = ffn(x_modified)
diff = (output_modified - output).abs()
# 只有位置 3 的差异应该很大
print(f"位置 3 的差异: {diff[:, 3, :].mean():.6f}")
print(f"其他位置的差异: {diff[:, :3, :].mean():.6f}, {diff[:, 4:, :].mean():.6f}")
# 预期:位置 3 差异大,其他位置差异接近 0
test_position_wise()

13. 常见问题#

Q1: FFN 可以用卷积替代吗?#

可以,但通常不这么做。FFN 的优势:

  • 无需设计卷积核大小
  • 适合任意序列长度
  • 实现简单高效

Q2: FFN 的激活函数可以省略吗?#

不能。激活函数提供非线性,是 FFN 表达能力的核心。

Q3: FFN 和 MoE(Mixture of Experts)有什么关系?#

MoE 是对 FFN 的扩展:

标准 FFN: 每个 token 经过 1 个 FFN
MoE FFN: 每个 token 经过 K 个 FFN 中的 1 个(由门控选择)

总结#

关键点说明
FFN 结构两层线性层 + 非线性激活(ReLU/GELU)
维度变化d_model → d_ff (4×d_model) → d_model
操作方式Position-wise,每个 token 独立处理
参数量占 Transformer 总参数的 60-70%
与 MHA 关系MHA 处理 token 间交互,FFN 处理 token 内深度变换
Scaling 作用保持注意力方差恒定,防止 softmax 饱和
W^O 的作用将多头拼接结果映射回统一表征空间
LN vs BNLN 在每个 token 内部归一化,无 batch 依赖
BERT vs GPTEncoder vs Decoder,双向 vs 因果掩码

记住:FFN 不是拼接不同子空间信息的工具,那是 MHA 的职责。FFN 的核心作用是对每个 token 的表示进行深层非线性变换,两者共同构成 Transformer 的表示学习基础。

Transformer 架构详解
https://sgjki547.top/posts/transformer/
Author
SGJki
Published at
2026-04-08
License
CC BY-NC-SA 4.0