File size: 910 Bytes
96257b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import ipdb

class ModulatedRMSNorm(torch.nn.Module):
    """RMS-normalize over the last dimension, then modulate by ``(1 + scale)``.

    The module holds no parameters; ``scale`` and ``eps`` are supplied on
    every call.
    """

    def forward(
        self, x: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6
    ) -> torch.Tensor:
        """Return ``x * rsqrt(mean(x**2, dim=-1) + eps) * (1 + scale)``.

        Args:
            x: Input tensor; the mean square is taken over its last dimension.
            scale: Modulation tensor applied as ``(1 + scale)``. It must
                broadcast against ``x`` as-is (no dimensions are inserted).
            eps: Added to the mean square before ``rsqrt`` to avoid division
                by zero. Defaults to ``1e-6``.

        Returns:
            Tensor with the broadcast shape of ``x`` and ``(1 + scale)``,
            cast back to ``x``'s dtype.
        """
        # Compute in at least fp32: fp16/bf16 inputs are upcast for accuracy,
        # while fp64 inputs stay fp64 instead of being truncated to fp32.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_c = x.to(compute_dtype)
        scale_c = scale.to(compute_dtype)

        # Root-mean-square over the last dimension (kept for broadcasting).
        mean_square = x_c.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + eps)

        x_normed = x_c * inv_rms
        # NOTE(review): an earlier variant used ``scale.unsqueeze(1)``, which
        # suggests callers may pass scale as (batch, dim) with x as
        # (batch, seq, dim). As written, scale must already broadcast against
        # x's trailing dimensions — confirm against callers.
        x_modulated = x_normed * (1 + scale_c)
        return x_modulated.type_as(x)


# def modulated_rmsnorm(x, scale, eps=1e-6):
#     return ModulatedRMSNorm.apply(x, scale, eps)

def modulated_rmsnorm(x, scale, eps=1e-6):
    """Functional entry point for ``ModulatedRMSNorm``.

    Creates a fresh ``ModulatedRMSNorm`` instance and invokes its ``forward``
    method directly with the given arguments.

    Args:
        x: Input tensor to normalize.
        scale: Modulation tensor forwarded unchanged to the module.
        eps: Numerical-stability term forwarded unchanged; defaults to 1e-6.

    Returns:
        The result of ``ModulatedRMSNorm.forward(x, scale, eps)``.
    """
    return ModulatedRMSNorm().forward(x, scale, eps)