type
Post
status
Published
password
date
May 2, 2026
slug
summary
category
人工智能
URL
tags
强化学习
LLM
icon
原理
传统训练方式会在前向传播时保存所有中间层的激活值,用于反向传播计算梯度,这导致激活值内存占用随模型层数线性增长(通常占 GPU 显存的 60%-80%)。激活检查点采取 "抓大放小" 策略:
- 前向传播:将网络划分为若干片段,仅保存关键 "检查点" 位置的激活值(如每个 Transformer 层的输入)
- 反向传播:需要梯度时,从最近的检查点临时 重计算(recompute) 中间激活值。
效果
- 内存节省:通常可减少 50%-70% 的激活值显存占用。
- 计算代价:增加约20%-30% 的计算量(中间激活需要计算两次)
常见检查点
网络 | 选择 | 备注 |
Self-Attention 层 | 必选 | • Q/K/V 矩阵 + 注意力分数 = O (batch × seq² × head)
• 序列长度 2048 时,单层激活值可达 10GB+ |
Feed-Forward Network (FFN) | 推荐 | • 中间隐藏层通常是 embedding 维度的 4 倍
• 激活值内存仅次于注意力层 |
Embedding 层 & 输出层 | 避免 | • 通常只计算一次,重算收益低
• Embedding 查表操作重算代价高 |
代码实战
pytorch
单层
- 推荐使用:use_reentrant=False
特性 | use_reentrant=True | use_reentrant=False |
计算效率 | 必须完整重算整个函数 | 只重算到需要的位置,更快 |
torch.autograd.grad 支持 | ❌ 不支持 | ✅ 完全支持 |
关键字参数 **kwargs | ❌ 不支持 | ✅ 支持 |
嵌套结构张量(列表 / 字典) | ❌ 忽略其中的张量 | ✅ 正确处理 |
requires_grad 限制 | 至少一个输入输出必须为 True | 无限制 |
嵌套 checkpoint | 有各种限制 | ✅ 正常工作 |
调试信息 | 差,报错栈混乱 | ✅ 清晰的错误堆栈 |
torch.compile 兼容 | 差 | ✅ 良好 |