Encoder-Decoder 架构概述
Encoder-Decoder架构是序列到序列(Seq2Seq)任务的经典解决方案。编码器将输入序列编码为固定长度的表示,解码器基于这个表示生成目标序列。
🔷 Encoder (编码器)
输入序列 → 上下文向量
Self-Attention
Feed Forward
LayerNorm
→
🔶 Decoder (解码器)
上下文向量 → 输出序列
Masked Self-Attention
Cross-Attention
Feed Forward
核心特点
- 双向编码:编码器可以看到整个输入序列
- 自回归解码:解码器逐个生成输出tokens
- 注意力机制:通过Cross-Attention连接编码器和解码器
- 并行训练:训练时可以并行计算所有位置
编码器 (Encoder) 详解
编码器负责理解和编码输入序列,生成丰富的上下文表示供解码器使用。
编码器层结构
单个编码器层的计算流程:
- 多头自注意力:捕捉输入序列内部的依赖关系
- 残差连接 + 层归一化:稳定训练过程
- 前馈网络:非线性变换增强表达能力
- 残差连接 + 层归一化:再次规范化输出
编码器层的数学表示:
$$\text{Encoder}(X) = \text{LayerNorm}(X + \text{FFN}(\text{LayerNorm}(X + \text{SelfAttn}(X))))$$编码器层 (EncoderLayer) 实现
class EncoderLayer(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-attention with residual connection
attn_output = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# Feed-forward with residual connection
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
解码器 (Decoder) 详解
解码器在编码器输出的基础上,自回归地生成目标序列。它包含三个主要的注意力机制。
解码器层结构
解码器的三重注意力机制:
1. Masked Self-Attention
防止看到未来信息
2. Cross-Attention
关注编码器输出
3. Feed-Forward
非线性变换
解码器层的数学表示:
$$\begin{align} Y_1 &= \text{LayerNorm}(Y + \text{MaskedSelfAttn}(Y)) \\ Y_2 &= \text{LayerNorm}(Y_1 + \text{CrossAttn}(Y_1, H, H)) \\ Y_3 &= \text{LayerNorm}(Y_2 + \text{FFN}(Y_2)) \end{align}$$其中 $H$ 是编码器的输出,$Y$ 是解码器的输入
Cross-Attention 机制详解
Cross-Attention是连接编码器和解码器的关键机制,让解码器能够关注到输入序列的相关部分。
🔗 Cross-Attention 信息流
Query: 来自解码器 | Key & Value: 来自编码器
Cross-Attention 的计算过程
$$\text{CrossAttn}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$
其中:
- $Q$ = 解码器输出 × $W_Q$ (Query矩阵)
- $K$ = 编码器输出 × $W_K$ (Key矩阵)
- $V$ = 编码器输出 × $W_V$ (Value矩阵)
Cross-Attention 可视化示例
翻译任务:"I love you" → "我爱你"
编码器输入 (源语言)
["I", "love", "you"]
0.1
0.8
0.1
0.2
0.1
0.7
0.7
0.1
0.2
解码器输出 (目标语言)
["我", "爱", "你"]
注意力权重矩阵显示了解码器每个位置对编码器各位置的关注程度
掩码机制 (Masking)
掩码机制是确保模型正确学习的重要组件,在编码器和解码器中都有应用。
填充掩码 (Padding Mask)
目的:忽略序列中的填充tokens,防止模型关注无意义的填充位置。
前瞻掩码 (Look-ahead Mask)
目的:在解码器的自注意力中,防止当前位置看到未来的信息,保持自回归特性。
掩码矩阵示例(序列长度=4)
填充掩码
1
1
1
0
1
1
1
0
1
1
1
0
0
0
0
0
前瞻掩码
1
0
0
0
1
1
0
0
1
1
1
0
1
1
1
1
掩码机制实现
def create_look_ahead_mask(seq_len):
"""创建前瞻掩码,防止看到未来信息"""
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
return mask == 0 # 1表示可以关注,0表示掩码
def create_padding_mask(seq, pad_token_id=0):
"""创建填充掩码,忽略填充tokens"""
return (seq != pad_token_id).unsqueeze(1).unsqueeze(2)
架构变体与应用
| 架构类型 | 结构特点 | 应用场景 | 代表模型 |
|---|---|---|---|
| Encoder-Only | 仅编码器,双向注意力 | 理解任务,分类 | BERT, RoBERTa |
| Decoder-Only | 仅解码器,因果注意力 | 生成任务 | GPT, LLaMA |
| Encoder-Decoder | 编码器+解码器 | 序列转换 | T5, BART |
不同架构的优劣对比
Encoder-Decoder优势
- 适合seq2seq任务
- 编码器可以双向建模
- 解码器保持因果性
- 灵活的输入输出长度
Encoder-Decoder劣势
- 结构复杂,参数多
- 训练推理开销大
- 需要更多训练数据
- 调试难度较高
训练与推理过程
训练阶段 (Teacher Forcing)
Teacher Forcing:训练时将真实的目标序列作为解码器输入,而不是使用模型的预测结果。这样可以加速训练并提供更稳定的梯度。
训练阶段 - Teacher Forcing
# 训练时的前向传播
def forward_train(self, src, tgt):
# 编码器处理源序列
encoder_output = self.encoder(src)
# 解码器使用目标序列(去掉最后一个token)
decoder_input = tgt[:, :-1] # Teacher forcing
decoder_output = self.decoder(decoder_input, encoder_output)
# 预测下一个token
logits = self.output_projection(decoder_output)
return logits
推理阶段 (Auto-regressive Generation)
自回归生成:推理时逐个生成tokens,每次将之前生成的序列作为输入来预测下一个token。
推理阶段 - 自回归生成
def generate(self, src, max_length=100):
encoder_output = self.encoder(src)
# 从开始标记开始生成
generated = [self.start_token_id]
for _ in range(max_length):
decoder_input = torch.tensor([generated]).to(src.device)
decoder_output = self.decoder(decoder_input, encoder_output)
# 预测下一个token
next_token_logits = self.output_projection(decoder_output[:, -1, :])
next_token = torch.argmax(next_token_logits, dim=-1)
generated.append(next_token.item())
# 检查结束条件
if next_token.item() == self.end_token_id:
break
return generated
实际应用案例
机器翻译 (Neural Machine Translation)
- Google Translate:基于Transformer的多语言翻译
- Facebook M2M-100:多对多语言翻译模型
- mBART:多语言预训练的翻译模型
文本摘要 (Text Summarization)
- BART:去噪自编码器用于摘要生成
- T5:文本到文本的统一框架
- Pegasus:专门针对摘要任务预训练
对话系统 (Dialogue Systems)
- BlenderBot:开放域对话系统
- DialoGPT:基于GPT的对话生成
- Meena:Google的对话AI模型
🔧 交互式演示
体验不同的注意力机制: