⛩️梯度检查点
2026-5-2
| 2026-5-7
字数 1504阅读时长 4 分钟
type
Post
status
Published
password
date
May 2, 2026
slug
summary
category
人工智能
URL
tags
强化学习
LLM
icon

原理

传统训练方式会在前向传播时保存所有中间层的激活值,用于反向传播计算梯度,这导致激活值内存占用随模型层数线性增长(通常占 GPU 显存的 60%-80%)。激活检查点采取 "抓大放小" 策略:
  • 前向传播:将网络划分为若干片段,仅保存关键 "检查点" 位置的激活值(如每个 Transformer 层的输入)
  • 反向传播:需要梯度时,从最近的检查点临时 重计算(recompute) 中间激活值。

效果

  • 内存节省:通常可减少 50%-70% 的激活值显存占用。
  • 计算代价:增加约20%-30% 的计算量(中间激活需要计算两次)

常见检查点

网络
选择
备注
Self-Attention 层
必选
• Q/K/V 矩阵 + 注意力分数 = O (batch × seq² × head) • 序列长度 2048 时,单层激活值可达 10GB+
Feed-Forward Network (FFN)
推荐
• 中间隐藏层通常是 embedding 维度的 4 倍 • 激活值内存仅次于注意力层
Embedding 层 & 输出层
避免
• 通常只计算一次,重算收益低 • Embedding 查表操作重算代价高

代码实战

pytorch

单层

  • 推荐使用:use_reentrant=False
    • 特性
      use_reentrant=True
      use_reentrant=False
      计算效率
      必须完整重算整个函数
      只重算到需要的位置,更快
      torch.autograd.grad 支持
      ❌ 不支持
      ✅ 完全支持
      关键字参数 **kwargs
      ❌ 不支持
      ✅ 支持
      嵌套结构张量(列表 / 字典)
      ❌ 忽略其中的张量
      ✅ 正确处理
      requires_grad 限制
      至少一个输入输出必须为 True
      无限制
      嵌套 checkpoint
      有各种限制
      ✅ 正常工作
      调试信息
      差,报错栈混乱
      ✅ 清晰的错误堆栈
      torch.compile 兼容
      ✅ 良好
完整代码

批量处理

  • 强化学习
  • LLM
  • 在 X86 设备上使用 Docker 构建 ARM 镜像二十二、强化学习-大模型
    Loading...