| import torch.nn as nn |
| import torch |
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Divides the input by the root mean square of its last dimension and
    multiplies by a learnable per-feature ``scale`` (optionally adding a
    learnable ``offset``). In partial mode the RMS statistic is estimated
    from only the first ``int(d * p)`` features, but the whole input is
    still normalized by it.
    """

    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        :param d: model size (length of the last dimension of the input)
        :param p: partial RMSNorm ratio. Values in (0, 1] enable partial
            mode and must satisfy ``int(d * p) >= 1``; any value outside
            [0, 1] disables it. Default -1.0 (disabled).
        :param eps: epsilon added to the RMS before dividing, default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        :raises ValueError: if ``p`` is within [0, 1] but selects zero
            features (e.g. ``p=0``).
        """
        super(RMSNorm, self).__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Fail fast: a partial ratio that selects no features would
        # otherwise only surface as a ZeroDivisionError inside forward().
        self._partial_size()

        # Assigning an nn.Parameter as a module attribute registers it, so
        # no explicit register_parameter() call is needed.
        self.scale = nn.Parameter(torch.ones(d))
        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def _partial_size(self):
        """Return the feature count used for the RMS, or None if partial mode is off."""
        if self.p < 0. or self.p > 1.:
            return None

        # Truncation (not rounding) is kept to match the established behavior.
        size = int(self.d * self.p)
        if size < 1:
            raise ValueError(
                "partial RMSNorm needs int(d * p) >= 1, got d={} and p={}; "
                "use p outside [0, 1] (e.g. -1) to disable partial mode".format(
                    self.d, self.p
                )
            )
        return size

    def extra_repr(self):
        """Show the configuration in the module's printed representation."""
        return "d={}, p={}, eps={}, bias={}".format(
            self.d, self.p, self.eps, self.bias
        )

    def forward(self, x):
        """Normalize ``x`` over its last dimension.

        :param x: tensor whose last dimension has size ``d``
        :return: tensor with the same shape as ``x``
        :raises ValueError: if partial mode is enabled but selects zero
            features (possible if ``p`` or ``d`` was changed after
            construction).
        """
        partial_size = self._partial_size()

        if partial_size is None:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            # torch.split also checks that the last dimension of x equals d.
            partial_x, _ = torch.split(
                x, [partial_size, self.d - partial_size], dim=-1
            )
            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        # RMS = ||x||_2 / sqrt(n), where n is the number of features measured.
        rms_x = norm_x * d_x ** (-1. / 2)
        # eps is added outside the square root, so an all-zero input maps to
        # zeros instead of NaN.
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset

        return self.scale * x_normed
|
|
|
|