import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel


class ShivikM1Config(PretrainedConfig):
    model_type = "shivik-m1"

    def __init__(self, **kwargs):
        # Architecture defaults; any of these can be overridden via kwargs.
        kwargs.setdefault("vocab_size", 49156)
        kwargs.setdefault("d_model", 2048)
        kwargs.setdefault("n_layers", 24)
        kwargs.setdefault("num_heads", 16)
        kwargs.setdefault("num_paths", 3)
        kwargs.setdefault("rotary_dim", 128)
        kwargs.setdefault("context_length", 4096)
        super().__init__(**kwargs)


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Scale by the reciprocal root-mean-square over the last dimension,
        # then apply the learned per-channel gain.
        norm = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(norm + self.eps)
        return x * self.weight


def apply_rope(x, cos, sin):
    # Rotary position embedding: rotate (even, odd) channel pairs by the
    # position-dependent angles in cos/sin. Applied identically to q and k,
    # the resulting attention scores depend only on relative position.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


class MultiPathAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_paths, rotary_dim):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_paths = num_paths
        self.head_dim = d_model // num_heads
        self.rotary_dim = rotary_dim
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        # Learned mixing weights over the attention paths.
        self.path_weights = nn.Parameter(torch.zeros(num_paths))

    def forward(self, x, cos, sin, mask, past_kv=None):
        B, T, C = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, self.num_heads, self.head_dim)
        k = k.view(B, T, self.num_heads, self.head_dim)
        v = v.view(B, T, self.num_heads, self.head_dim)

        if self.rotary_dim > 0:
            # Rotate only the first rotary_dim channels; rebuild the tensors instead
            # of writing in place so autograd does not see an in-place modification.
            q = torch.cat([apply_rope(q[..., :self.rotary_dim], cos, sin), q[..., self.rotary_dim:]], dim=-1)
            k = torch.cat([apply_rope(k[..., :self.rotary_dim], cos, sin), k[..., self.rotary_dim:]], dim=-1)

        # (B, T, H, D) -> (B, H, T, D) so the dot product runs over the sequence axis.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present = (k, v)

        # Each path computes scaled dot-product attention over the same q/k/v;
        # the softmaxed path_weights blend the path outputs.
        path_attn = []
        for _ in range(self.num_paths):
            scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores = scores + mask
            att = scores.softmax(dim=-1)
            path_attn.append(att @ v)

        weights = F.softmax(self.path_weights, dim=0)
        final = sum(w * p for w, p in zip(weights, path_attn))
        final = final.transpose(1, 2).contiguous().reshape(B, T, C)
        return self.o_proj(final), present


class SwiGLU(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.w2 = nn.Linear(dim, hidden_dim)
        self.w3 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        # Gate the w2 branch with silu(w1(x)), then project back down to dim.
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class AriesBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.norm1 = RMSNorm(cfg.d_model)
        self.attn = MultiPathAttention(cfg.d_model, cfg.num_heads, cfg.num_paths, cfg.rotary_dim)
        self.norm2 = RMSNorm(cfg.d_model)
        self.mlp = SwiGLU(cfg.d_model, 4 * cfg.d_model)

    def forward(self, x, cos, sin, mask, past_kv=None):
        # Pre-norm residual block: multi-path attention, then the SwiGLU MLP.
        h, present = self.attn(self.norm1(x), cos, sin, mask, past_kv)
        x = x + h
        x = x + self.mlp(self.norm2(x))
        return x, present


class ShivikM1Model(nn.Module):
    def __init__(self, cfg: ShivikM1Config):
        super().__init__()
        vocab_size = getattr(cfg, "vocab_size", 49156)
        d_model = getattr(cfg, "d_model", 2048)
        n_layers = getattr(cfg, "n_layers", 24)
        num_heads = getattr(cfg, "num_heads", 16)
        ctxt = getattr(cfg, "context_length", 4096)
        num_paths = getattr(cfg, "num_paths", 3)
        rotary_dim = getattr(cfg, "rotary_dim", 128)
        self.cfg = cfg

        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.zeros(1, ctxt, d_model))

        # Additive causal mask: 0 at allowed positions, -1e4 above the diagonal.
        mask = torch.tril(torch.ones(ctxt, ctxt)).unsqueeze(0).unsqueeze(0)
        mask = (mask == 0).float() * -1e4
        self.register_buffer("causal_mask", mask)

        # Precompute rotary angles for every position up to the context length.
        t = torch.arange(ctxt, dtype=torch.float32)
        freqs = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
        angles = torch.einsum("i,j->ij", t, freqs)
        self.register_buffer("rope_cos", angles.cos().unsqueeze(1))
        self.register_buffer("rope_sin", angles.sin().unsqueeze(1))

        self.blocks = nn.ModuleList([AriesBlock(cfg) for _ in range(n_layers)])
        self.norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Tie the output projection to the token embedding matrix.
        self.lm_head.weight = self.token_embed.weight

    def forward(self, input_ids, past_kvs=None):
        B, T = input_ids.shape
        if past_kvs is None:
            past_kvs = [None] * len(self.blocks)

        # With a KV cache, offset positions by the cached length so positional
        # embeddings, rotary angles, and causal-mask rows stay aligned.
        past_len = 0 if past_kvs[0] is None else past_kvs[0][0].shape[2]
        x = self.token_embed(input_ids) + self.pos_embed[:, past_len:past_len + T]
        mask = self.causal_mask[:, :, past_len:past_len + T, :past_len + T]
        cos = self.rope_cos[past_len:past_len + T]
        sin = self.rope_sin[past_len:past_len + T]

        presents = []
        for block, past_kv in zip(self.blocks, past_kvs):
            x, present = block(x, cos, sin, mask, past_kv)
            presents.append(present)

        x = self.norm(x)
        logits = self.lm_head(x)
        return {"logits": logits, "present_kvs": presents}


class ShivikM1ForCausalLM(PreTrainedModel):
    config_class = ShivikM1Config
    base_model_prefix = "shivik_m1"

    def __init__(self, config):
        super().__init__(config)
        self.model = ShivikM1Model(config)
        self.post_init()

    def forward(self, input_ids=None, past_kvs=None, **kwargs):
        # Extra generation kwargs (e.g. attention_mask) are accepted but unused.
        return self.model(input_ids, past_kvs=past_kvs)
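

# Minimal usage sketch: builds a deliberately small configuration by overriding the
# defaults above, runs a full forward pass, then decodes one extra token with the
# returned KV cache. The override values and tensor shapes below are illustrative
# assumptions, not a released or recommended configuration.
if __name__ == "__main__":
    cfg = ShivikM1Config(n_layers=2, d_model=256, num_heads=4, rotary_dim=64, context_length=128)
    model = ShivikM1ForCausalLM(cfg)

    input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids)
        print(out["logits"].shape)  # (1, 8, vocab_size)

        # Feed the cache back in to decode a single next token incrementally.
        next_token = out["logits"][:, -1:].argmax(-1)
        step = model(next_token, past_kvs=out["present_kvs"])
        print(step["logits"].shape)  # (1, 1, vocab_size)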