从 CS336 作业理解 Transformer 训练基本功
训练系统的基础不是某个单独概念,而是一组张量形状、内存和数值表示之间的约定。
Activation checkpointing
Checkpointing 常被概括为“时间换空间”。前向时不保存所有中间激活,反向时重新计算一部分,从而减少显存占用。
它不是保存参数矩阵的 checkpoint,而是保存计算图中必要的边界状态。
Linear 与 einsum
线性层可以写成:
... d_in, d_out d_in -> ... d_out
这说明输入最后一维是 d_in,权重提供 d_out 个输出方向,结果最后一维变成 d_out。
理解这个 einsum,比记住某个框架 API 更可靠。
权重布局
PyTorch 的 Linear 权重常按 (out_features, in_features) 存。这会让初学者觉得和数学里的 xW 方向相反。
实际实现会根据内存布局和 GEMM 调用做选择。重要的是知道每一维代表输入还是输出,而不是死记矩阵在纸面上的方向。
BF16
BF16 适合 LLM 的关键原因是指数位和 FP32 一样多,动态范围大,但尾数更短。它牺牲精度,保留范围,因此比 FP16 更不容易溢出。
训练基础最终都要回到三个问题:张量形状是什么、内存保存什么、数值格式能不能承受当前计算。
知识补全:形状推理是训练系统的地基
Transformer 训练里的很多 bug 都不是公式错,而是形状理解错。batch、sequence、head、hidden、vocab、expert 这些维度在不同模块中不断重排。
例如 attention 常见形状是:
(batch, seq, hidden)
-> (batch, heads, seq, head_dim)
MLP 则常在 hidden 和 intermediate hidden 之间变换。MoE 又会多出 expert 维度和 token dispatch。
掌握形状推理后,einsum、reshape、transpose、contiguous、shard 都会更容易理解。
数值格式的直觉
FP16 尾数更多但指数范围小,容易溢出。BF16 尾数更少但范围大,因此在大模型中更稳。FP32 通常用于累加、优化器状态或敏感计算。
混合精度训练的核心不是“全都用低精度”,而是把不同计算放在合适精度上。
学习检查清单
读训练代码时,可以逐步标注:
- 每个张量的 shape。
- 每次 matmul 的输入输出维度。
- 哪些激活会被保存到反向。
- 哪些地方用了 checkpointing。
- 参数、梯度、优化器状态分别是什么 dtype。
- 是否存在隐式转置或 contiguous 拷贝。
这比单独背 Transformer 结构更接近工程实践。