← 返回首页

Encoder-Decoder 架构详解

编码器解码器的设计原理与Cross-Attention交互机制

Encoder-Decoder 架构概述

Encoder-Decoder架构是序列到序列(Seq2Seq)任务的经典解决方案。编码器将输入序列编码为固定长度的表示,解码器基于这个表示生成目标序列。

🔷 Encoder (编码器)

输入序列 → 上下文向量

Self-Attention
Feed Forward
LayerNorm

🔶 Decoder (解码器)

上下文向量 → 输出序列

Masked Self-Attention
Cross-Attention
Feed Forward

核心特点

  • 双向编码:编码器可以看到整个输入序列
  • 自回归解码:解码器逐个生成输出tokens
  • 注意力机制:通过Cross-Attention连接编码器和解码器
  • 并行训练:训练时可以并行计算所有位置

编码器 (Encoder) 详解

编码器负责理解和编码输入序列,生成丰富的上下文表示供解码器使用。

编码器层结构

单个编码器层的计算流程:

  1. 多头自注意力:捕捉输入序列内部的依赖关系
  2. 残差连接 + 层归一化:稳定训练过程
  3. 前馈网络:非线性变换增强表达能力
  4. 残差连接 + 层归一化:再次规范化输出

编码器层的数学表示:

$$\text{Encoder}(X) = \text{LayerNorm}(X + \text{FFN}(\text{LayerNorm}(X + \text{SelfAttn}(X))))$$
编码器层 (EncoderLayer) 实现

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x
                    

解码器 (Decoder) 详解

解码器在编码器输出的基础上,自回归地生成目标序列。它包含三个主要的注意力机制。

解码器层结构

解码器的三重注意力机制:

1. Masked Self-Attention

防止看到未来信息

2. Cross-Attention

关注编码器输出

3. Feed-Forward

非线性变换

解码器层的数学表示:

$$\begin{align} Y_1 &= \text{LayerNorm}(Y + \text{MaskedSelfAttn}(Y)) \\ Y_2 &= \text{LayerNorm}(Y_1 + \text{CrossAttn}(Y_1, H, H)) \\ Y_3 &= \text{LayerNorm}(Y_2 + \text{FFN}(Y_2)) \end{align}$$

其中 $H$ 是编码器的输出,$Y$ 是解码器的输入

Cross-Attention 机制详解

Cross-Attention是连接编码器和解码器的关键机制,让解码器能够关注到输入序列的相关部分。

🔗 Cross-Attention 信息流

Query: 来自解码器 | Key & Value: 来自编码器

Cross-Attention 的计算过程

$$\text{CrossAttn}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$

其中:

  • $Q$ = 解码器输出 × $W_Q$ (Query矩阵)
  • $K$ = 编码器输出 × $W_K$ (Key矩阵)
  • $V$ = 编码器输出 × $W_V$ (Value矩阵)

Cross-Attention 可视化示例

翻译任务:"I love you" → "我爱你"

编码器输入 (源语言)
["I", "love", "you"]
0.1
0.8
0.1
0.2
0.1
0.7
0.7
0.1
0.2
解码器输出 (目标语言)
["我", "爱", "你"]

注意力权重矩阵显示了解码器每个位置对编码器各位置的关注程度

掩码机制 (Masking)

掩码机制是确保模型正确学习的重要组件,在编码器和解码器中都有应用。

填充掩码 (Padding Mask)

目的:忽略序列中的填充tokens,防止模型关注无意义的填充位置。

前瞻掩码 (Look-ahead Mask)

目的:在解码器的自注意力中,防止当前位置看到未来的信息,保持自回归特性。

掩码矩阵示例(序列长度=4)

填充掩码
1
1
1
0
1
1
1
0
1
1
1
0
0
0
0
0
前瞻掩码
1
0
0
0
1
1
0
0
1
1
1
0
1
1
1
1
掩码机制实现

def create_look_ahead_mask(seq_len):
    """创建前瞻掩码,防止看到未来信息"""
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask == 0  # 1表示可以关注,0表示掩码

def create_padding_mask(seq, pad_token_id=0):
    """创建填充掩码,忽略填充tokens"""
    return (seq != pad_token_id).unsqueeze(1).unsqueeze(2)
                    

架构变体与应用

架构类型 结构特点 应用场景 代表模型
Encoder-Only 仅编码器,双向注意力 理解任务,分类 BERT, RoBERTa
Decoder-Only 仅解码器,因果注意力 生成任务 GPT, LLaMA
Encoder-Decoder 编码器+解码器 序列转换 T5, BART

不同架构的优劣对比

Encoder-Decoder优势

  • 适合seq2seq任务
  • 编码器可以双向建模
  • 解码器保持因果性
  • 灵活的输入输出长度

Encoder-Decoder劣势

  • 结构复杂,参数多
  • 训练推理开销大
  • 需要更多训练数据
  • 调试难度较高

训练与推理过程

训练阶段 (Teacher Forcing)

Teacher Forcing:训练时将真实的目标序列作为解码器输入,而不是使用模型的预测结果。这样可以加速训练并提供更稳定的梯度。

训练阶段 - Teacher Forcing

# 训练时的前向传播
def forward_train(self, src, tgt):
    # 编码器处理源序列
    encoder_output = self.encoder(src)
    
    # 解码器使用目标序列(去掉最后一个token)
    decoder_input = tgt[:, :-1]  # Teacher forcing
    decoder_output = self.decoder(decoder_input, encoder_output)
    
    # 预测下一个token
    logits = self.output_projection(decoder_output)
    return logits
                    

推理阶段 (Auto-regressive Generation)

自回归生成:推理时逐个生成tokens,每次将之前生成的序列作为输入来预测下一个token。

推理阶段 - 自回归生成

def generate(self, src, max_length=100):
    encoder_output = self.encoder(src)
    
    # 从开始标记开始生成
    generated = [self.start_token_id]
    
    for _ in range(max_length):
        decoder_input = torch.tensor([generated]).to(src.device)
        decoder_output = self.decoder(decoder_input, encoder_output)
        
        # 预测下一个token
        next_token_logits = self.output_projection(decoder_output[:, -1, :])
        next_token = torch.argmax(next_token_logits, dim=-1)
        
        generated.append(next_token.item())
        
        # 检查结束条件
        if next_token.item() == self.end_token_id:
            break
    
    return generated
                    

实际应用案例

机器翻译 (Neural Machine Translation)

  • Google Translate:基于Transformer的多语言翻译
  • Facebook M2M-100:多对多语言翻译模型
  • mBART:多语言预训练的翻译模型

文本摘要 (Text Summarization)

  • BART:去噪自编码器用于摘要生成
  • T5:文本到文本的统一框架
  • Pegasus:专门针对摘要任务预训练

对话系统 (Dialogue Systems)

  • BlenderBot:开放域对话系统
  • DialoGPT:基于GPT的对话生成
  • Meena:Google的对话AI模型

🔧 交互式演示

体验不同的注意力机制: