Spaces:
Paused
Paused
File size: 910 Bytes
96257b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
import ipdb
class ModulatedRMSNorm(torch.nn.Module):
    """RMS-normalize a tensor over its last dimension and modulate by ``1 + scale``.

    The module holds no parameters or buffers; everything it needs is passed
    to :meth:`forward`.
    """

    def forward(
        self,
        x: torch.Tensor,
        scale: torch.Tensor,
        eps: float = 1e-6,
    ) -> torch.Tensor:
        """Return ``x / rms(x) * (1 + scale)`` in the dtype of ``x``.

        Args:
            x: Input tensor; the RMS is taken over its last dimension.
            scale: Modulation tensor. It is combined with the normalized
                input by ordinary broadcasting, so it must be
                broadcast-compatible with ``x``.
            eps: Value added to the mean square before the reciprocal
                square root, for numerical stability.

        Returns:
            The normalized and modulated tensor, cast back to ``x``'s type.
        """
        # Do the arithmetic in fp32 for precision, whatever the input dtype.
        x_fp32 = x.float()
        scale_fp32 = scale.float()

        # Reciprocal RMS over the last dimension; keepdim so it broadcasts
        # back against x.
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + eps)

        # NOTE(review): scale is applied with plain broadcasting. An earlier
        # revision used scale.unsqueeze(1) here, presumably for a per-sample
        # scale against a (batch, seq, dim) input -- callers with that layout
        # must unsqueeze scale themselves. TODO: confirm the intended shapes.
        x_modulated = x_fp32 * inv_rms * (1 + scale_fp32)

        return x_modulated.type_as(x)
# def modulated_rmsnorm(x, scale, eps=1e-6):
# return ModulatedRMSNorm.apply(x, scale, eps)
def modulated_rmsnorm(
    x: torch.Tensor,
    scale: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Functional wrapper around :class:`ModulatedRMSNorm`.

    Args:
        x: Input tensor to normalize.
        scale: Modulation tensor passed through to the module.
        eps: Stability term passed through to the module.

    Returns:
        Whatever ``ModulatedRMSNorm`` returns for these arguments.
    """
    norm = ModulatedRMSNorm()
    # Call the module itself rather than .forward() so that the standard
    # nn.Module call path (including any registered hooks) is used.
    return norm(x, scale, eps)