rahul7star's picture
Migrated from GitHub
96257b2 verified
import torch
import ipdb
class ModulatedRMSNorm(torch.nn.Module):
    """RMS normalization over the last dimension, followed by a (1 + scale) gain.

    The module holds no parameters or buffers; everything it needs is passed
    to ``forward``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, scale, eps):
        """Normalize ``x`` by its root-mean-square and modulate it with ``scale``.

        Args:
            x: Input tensor; the mean of squares is taken over its last axis.
            scale: Modulation tensor. It is combined elementwise as
                ``1 + scale``, so it must broadcast against ``x``.
            eps: Value added to the mean square before the reciprocal square
                root, for numerical stability.

        Returns:
            The modulated tensor, cast back to the type of ``x``.
        """
        # All arithmetic is done in float32, regardless of the input dtype.
        x32 = x.float()
        gain = 1 + scale.float()

        # Reciprocal RMS along the last axis (kept for broadcasting).
        inv_rms = torch.rsqrt(torch.mean(x32.pow(2), dim=-1, keepdim=True) + eps)

        # NOTE(review): ``scale`` is used as given, without any unsqueeze;
        # callers must supply a broadcast-compatible shape -- TODO confirm
        # the intended shape contract.
        out = x32 * inv_rms * gain
        return out.type_as(x)
# def modulated_rmsnorm(x, scale, eps=1e-6):
# return ModulatedRMSNorm.apply(x, scale, eps)
def modulated_rmsnorm(x, scale, eps=1e-6):
    """Functional wrapper around ``ModulatedRMSNorm``.

    Builds a fresh ``ModulatedRMSNorm`` instance on every call and returns
    the result of its ``forward`` method.

    Args:
        x: Input tensor, passed through to ``forward``.
        scale: Modulation tensor, passed through to ``forward``.
        eps: Stability constant passed through to ``forward``; defaults to 1e-6.

    Returns:
        Whatever ``ModulatedRMSNorm.forward(x, scale, eps)`` returns.
    """
    # ``forward`` is invoked directly rather than through ``__call__``.
    return ModulatedRMSNorm().forward(x, scale, eps)