LayerNorm vs RMSNorm:从几何自由度到 Triton kernel 成本
RMSNorm 和 LayerNorm 都是在做归一化,但它们保留和丢弃的信息不同。几何上看,这个差异非常直观。
RMSNorm
RMSNorm 的约束可以理解为把向量缩放到一个固定半径的超球面上。
在三维例子里,它要求:
x^2 + y^2 + z^2 = 3
它去掉的是向量长度,但保留方向和平移基准。
LayerNorm
LayerNorm 不仅缩放长度,还会去均值。三维里可以写成:
x^2 + y^2 + z^2 = 3
x + y + z = 0
也就是说,它要求数据同时落在球面上和过球心的平面上。两者交集是一条圆。推广到 M 维,RMSNorm 的自由度是 M-1,LayerNorm 的自由度是 M-2。
Kernel 成本
这个几何差异会落到实现成本上。
RMSNorm 只需要维护平方和累加器。LayerNorm 需要维护均值和方差,通常需要更多 reduction 和中间值。
从 Triton kernel 的角度,RMSNorm 更轻,不是因为概念更简单,而是因为它少去掉一个统计量。
结论
RMSNorm 是“只缩放长度”。LayerNorm 是“去均值 + 缩放长度”。在大模型推理里,这种少一个统计量的差异会变成真实的 kernel 成本差异。
知识补全:为什么 RMSNorm 常见于大模型
大模型中 RMSNorm 常见,不只是因为它计算少一点,还因为它保留了均值方向的信息。LayerNorm 会去掉每个 token 表示的均值分量,RMSNorm 则只按均方根缩放。
从实现上看,RMSNorm 通常只需要:
rms = sqrt(mean(x^2) + eps)
y = x / rms * weight
LayerNorm 需要:
mean = mean(x)
var = mean((x - mean)^2)
y = (x - mean) / sqrt(var + eps) * weight + bias
多出来的均值和方差会增加 reduction 和中间计算。
学习检查清单
比较两个归一化层时,可以看:
- 它去掉了哪些信息。
- 它需要几个 reduction。
- 是否有 bias。
- 是否适合 fused residual add。
- 推理时瓶颈是计算还是访存。
- Triton kernel 里需要维护几个 accumulator。
这样看归一化,就能从数学定义走到实际 kernel 成本。