引言
从 2017 年 Transformer 架构被提出以来,到 2025 已经 8 年过去了,Transformer 架构已经发生了很多变化。比如,现如今越来越多的大模型采用的是 RMSNorm1 而不是 LayerNorm。今天这篇文章就是对 RMSNorm 的一个简单介绍,在了解 RMSNorm 之前,我们不妨先回顾一下什么是 LayerNorm
LayerNorm 回顾
$\mathbf y=\frac{\mathbf x-E[\mathbf x]}{\sqrt{Var(\mathbf x)+\epsilon}}*\gamma+\beta$
上面是 LayerNorm 的公式,如果我们忽略放缩因子 $\gamma,\beta$ 不看,LayerNorm 做的事情很好理解:将每一个样本的特征向量 $x$ 转变为均值为 0,标准差为 1 的特征向量
为什么 LayerNorm 是有用的呢?之前流行的解释是
- re-centering:输入 $\mathbf x$ 总是会减去均值 $E[\mathbf x]$。好处是如果输入 $\mathbf x$ 发生了整体的偏移(Shift Noise)也没事,输入 $\mathbf x$ 始终会在 0 的附近
- re-scaling:减去均值之后总是会除以 $\sqrt{Var(\mathbf x)+\epsilon}$。好处是如果输入 $\mathbf x$ 被成比例放缩,也没有影响
可以写个简单的 PyTorch 代码验证一下
|
1 2 3 4 5 6 7 8 9 10 11 12 13 |
import torch def re_centering(x): return x - x.mean(dim=-1) def re_scaling(x): return x / (x.std(dim=-1) + 1e-5) x = torch.arange(4).float() print(x, re_centering(x + 10000)) # tensor([0., 1., 2., 3.]) tensor([-1.5000, -0.5000, 0.5000, 1.5000]) print(x, re_scaling(x * 10000)) # tensor([0., 1., 2., 3.]) tensor([0.0000, 0.7746, 1.5492, 2.3238]) |
RMSNorm
RMSNorm 认为 LayerNorm 的价值在于 re-scaling 特性,跟 re-centering 倒是关系不大1,所以在设计 RMSNorm 的时候作者只考虑如何做 re-scaling。下面是 RMSNorm 的公式
$ \mathbf y=\frac{\mathbf x}{\sqrt{\frac{1}{n}\sum_ix_i^2+\epsilon}}*\gamma $
和 LayerNorm 对比,主要的几个差异如下
- 分子不需要减去 $E[\mathbf x]$
- 分母从 $Var(\mathbf x)$ 变成了 $\frac{1}{n}\sum_ix_i^2$
- 只需要维护 $\gamma$ 参数,不需要维护 $\beta$
RMSNorm 的好处
通过上面观察到的几点差异,我们可以看出 RMSNorm 的一些显而易见的好处:
- 需要维护的参数更少了,只有 $\gamma$
- 计算量也减少了,因为不用计算输入 $\mathbf x$ 的均值 $E[\mathbf x]$(注意 $Var(\mathbf x)$ 的计算也需要均值)
当然,最重要的是,RMSNorm 的效果还真就挺好的,跟 LayerNorm 也差不了多少,具体的实验细节和结果可以参考原论文1
PyTorch API
PyTorch 提供的 nn.RMSNorm 实现有如下的几个参数
- normalized_shape:表示用于计算 RMS 基于的输入张量的末尾维度
- eps:为了数值稳定,加上的一个很小的值
- element_affine:是否要启用可学习参数 $\gamma$?
|
1 2 3 |
>>> rms_norm = nn.RMSNorm([2, 3]) >>> input = torch.randn(2, 2, 3) >>> rms_norm(input) |
RMSNorm from Scratch
手写 RMSNorm 的难度不是很大,下面我写的代码可以作为参考
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch import torch.nn as nn import torch.nn.functional as F class RMSNorm(nn.Module): def __init__( self, normalized_shape: list | tuple, eps: float = 1e-5, element_affine: bool = True, ): super().__init__() self.eps = eps self.element_affine = element_affine if self.element_affine: self.gamma = nn.Parameter(torch.ones(normalized_shape)) else: self.register_parameter("gamma", None) def forward(self, x: torch.Tensor): x = x * torch.rsqrt(self.eps + x.pow(2).mean(dim=-1, keepdim=True)) return x if self.gamma is None else x * self.gamma |
-
Zhang, Biao, and Rico Sennrich. “Root mean square layer normalization.” Advances in Neural Information Processing Systems 32 (2019). ↩︎ ↩︎ ↩︎