class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask."""

    def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        assert dim % n_heads == 0, "dim must be divisible by n_heads"
        self.nh = n_heads
        self.hd = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.attn_dropout = attn_dropout

    def forward(self, x):
        B, T, C = x.shape
        # Project to queries/keys/values and split heads: (B, nh, T, hd).
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, self.nh, self.hd).transpose(1, 2)
        k = k.view(B, T, self.nh, self.hd).transpose(1, 2)
        v = v.view(B, T, self.nh, self.hd).transpose(1, 2)
        if x.is_cuda:
            # Fused kernel path (flash attention preferred where available).
            with sdpa_ctx_prefer_flash():
                y = F.scaled_dot_product_attention(
                    q, k, v, is_causal=True,
                    dropout_p=self.attn_dropout if self.training else 0.0,
                )
        else:
            # Explicit fallback: masked softmax attention on CPU.
            scale = 1.0 / math.sqrt(self.hd)
            att = (q @ k.transpose(-2, -1)) * scale
            mask = torch.full((1, 1, T, T), float("-inf"), device=x.device)
            mask = torch.triu(mask, diagonal=1)  # -inf strictly above the diagonal
            att = (att + mask).softmax(dim=-1)
            # Match the SDPA path: attention dropout only while training.
            att = F.dropout(att, p=self.attn_dropout, training=self.training)
            y = att @ v
        # Merge heads back to (B, T, C) and project.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)

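# Illustrative shape sketch for the block above (commented so it does not run at
# import time; the dimensions are arbitrary examples, not values used elsewhere):
#   attn = CausalSelfAttention(dim=256, n_heads=8)
#   y = attn(torch.randn(2, 16, 256))   # -> (2, 16, 256), causal over the T axis
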
def _normalize_cell(X):
    # Center a (points, dim) cell and rescale so the RMS point norm is 1.
    Xc = X - X.mean(dim=0, keepdim=True)
    r = Xc.pow(2).sum(dim=1).mean().sqrt().clamp_min(1e-6)
    return Xc / r

class CrystalBank(nn.Module):
    """Learnable bank of `regions` cells, each holding 5 anchor points in R^dim."""

    def __init__(self, regions: int, dim: int):
        super().__init__()
        pts = torch.randn(regions, 5, dim) / math.sqrt(dim)
        with torch.no_grad():
            # Center and RMS-normalize each cell at initialization.
            for i in range(regions):
                pts[i] = _normalize_cell(pts[i])
        self.anchors = nn.Parameter(pts)

class GeometricGate(nn.Module):
    """Routes each token to anchor regions by distance and emits a gated update."""

    def __init__(self, dim: int, regions: int, tau: float = 0.08):
        super().__init__()
        self.bank = CrystalBank(regions, dim)
        self.nav = nn.Linear(dim, dim, bias=False)
        self.tau = tau
        self.scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, h: torch.Tensor, punct_mask: Optional[torch.Tensor] = None,
                alpha_gate: float = 1.0, hard_mask_gate: bool = False):
        B, T, D = h.shape
        C = self.bank.anchors.size(0)
        H = self.nav(h).reshape(B * T, D)          # navigated tokens
        A = self.bank.anchors.reshape(C * 5, D)    # all anchors, flattened

        # Squared Euclidean distances via ||x||^2 + ||a||^2 - 2<x, a>.
        x2 = (H * H).sum(dim=-1, keepdim=True)
        a2 = (A * A).sum(dim=-1).unsqueeze(0)
        xa = H @ A.T
        d2 = (x2 + a2 - 2 * xa).clamp_min(0.0).view(B * T, C, 5)
        # Soft-min over the 5 anchors of each region, then softmax routing over regions.
        s = -torch.logsumexp(-d2 / max(1e-6, self.tau), dim=-1)
        w = F.softmax(-s / max(1e-6, self.tau), dim=-1).view(B, T, C)

        # Gate vector: routing-weighted mixture of region centroids.
        centroids = self.bank.anchors.mean(dim=1)  # (C, D)
        g = w @ centroids                          # (B, T, D)

        if punct_mask is not None:
            # Attenuate (or zero out) the gate at punctuation positions.
            pm = punct_mask.float().unsqueeze(-1)
            if hard_mask_gate:
                g = g * (1.0 - pm)
            else:
                g = g * (1.0 - pm * (1.0 - alpha_gate))

        return w, self.scale * g

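# Routing math used by GeometricGate above, written out for reference (tau is
# floored at 1e-6): for a navigated token x and region c with anchors {a_{c,j}}_{j=1..5},
#   s_c(x) = -log( sum_j exp( -||x - a_{c,j}||^2 / tau ) )   # soft-min of squared distances
#   w_c(x) = softmax_c( -s_c(x) / tau )                      # routing weights over regions
#   g(x)   = scale * sum_c w_c(x) * mean_j a_{c,j}           # gated update vector
# Illustrative shapes (values here are arbitrary examples):
#   gate = GeometricGate(dim=256, regions=64)
#   w, g = gate(torch.randn(2, 16, 256))   # w: (2, 16, 64), g: (2, 16, 256)
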
class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x, approximate="tanh")
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class GatedBlock(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio, dropout, regions):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = CausalSelfAttention(dim, n_heads, attn_dropout=dropout)
        self.gate = GeometricGate(dim, regions=regions, tau=0.08)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio, dropout=dropout)
        # Learnable scalar that weights the attention branch against the gate branch.
        self.mix_sdpa = nn.Parameter(torch.tensor(1.0))

    def forward(self, x, punct_mask=None, return_gate=False, alpha_gate=1.0, hard_mask_gate=False):
        h = self.norm1(x)
        att = self.attn(h) * self.mix_sdpa
        w, g = self.gate(h, punct_mask=punct_mask, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
        # Residual update combines the attention output and the geometric gate.
        x = x + att + g
        x = x + self.mlp(self.norm2(x))
        return (x, w) if return_gate else x

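# Update rule applied by GatedBlock above, for reference (LN = LayerNorm):
#   h       = LN1(x)
#   (w, g)  = Gate(h)            # g already carries the gate's learnable scale
#   x      <- x + mix_sdpa * Attn(h) + g
#   x      <- x + MLP(LN2(x))
# Attention and the geometric gate both read the same pre-normalized h, so they act
# as parallel residual branches rather than sequential ones.
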
class CrystalBeeper(nn.Module):
    def __init__(self, model_cfg: dict):
        super().__init__()
        self.cfg = model_cfg
        D, L, H = model_cfg["dim"], model_cfg["n_layers"], model_cfg["n_heads"]
        ctx = model_cfg["context"]

        self.use_ascii = bool(model_cfg.get("use_ascii", True))
        self.codec = AsciiCodec()
        V_ascii = self.codec.vocab_size

        self.token_emb = nn.Embedding(V_ascii, D)
        self.pos_emb = nn.Parameter(torch.zeros(1, ctx, D))
        self.drop = nn.Dropout(model_cfg.get("resid_dropout", 0.1))

        regions = [int(model_cfg.get("regions_per_block", 64)) for _ in range(L)]
        self.blocks = nn.ModuleList([
            GatedBlock(D, H, model_cfg["mlp_ratio"], model_cfg["dropout"], regions[i])
            for i in range(L)
        ])
        self.norm = nn.LayerNorm(D)

        # Output heads: ASCII head sized to the codec vocabulary, plus a BPE head.
        self.ascii_head = nn.Linear(D, V_ascii, bias=False)
        self.bpe_head = nn.Linear(D, model_cfg.get("vocab_size", 8192), bias=False)

        # Rose projection and anchors; not referenced in forward().
        self.rose_proj = nn.Linear(D, D, bias=False)
        self.rose_anchors = nn.Parameter(torch.randn(3, D) / math.sqrt(D))

        self.apply(self._init)

    @staticmethod
    def _init(m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)

    def backbone(self, idx, punct_mask=None, return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
        B, T = idx.shape
        x = self.token_emb(idx) + self.pos_emb[:, :T, :]
        x = self.drop(x)
        routes = []
        for blk in self.blocks:
            x, w = blk(x, punct_mask=punct_mask, return_gate=True,
                       alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
            if return_routes:
                routes.append(w)
        x = self.norm(x)
        return (x, routes) if return_routes else (x, None)

    def forward(self, idx, punct_mask=None, head="ascii", return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
        h, routes = self.backbone(idx, punct_mask=punct_mask, return_routes=return_routes,
                                  alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
        if head == "ascii":
            logits = self.ascii_head(h)
        elif head == "bpe":
            logits = self.bpe_head(h)
        else:
            raise ValueError("head must be 'ascii' or 'bpe'")
        return logits, routes

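# Minimal CPU smoke test, guarded so it only runs when this file is executed directly.
# It assumes AsciiCodec (defined elsewhere in this module) exposes `vocab_size`; the
# config values below are arbitrary illustrations, not defaults used in training.
if __name__ == "__main__":
    cfg = {
        "dim": 64, "n_layers": 2, "n_heads": 4, "context": 32,
        "mlp_ratio": 4.0, "dropout": 0.1, "resid_dropout": 0.1,
        "regions_per_block": 8, "vocab_size": 8192, "use_ascii": True,
    }
    model = CrystalBeeper(cfg)
    idx = torch.zeros(2, 16, dtype=torch.long)      # token id 0 is valid in any vocab
    punct = torch.zeros(2, 16, dtype=torch.bool)    # no punctuation positions
    logits, routes = model(idx, punct_mask=punct, head="ascii", return_routes=True)
    print(logits.shape)                  # (2, 16, codec.vocab_size)
    print(len(routes), routes[0].shape)  # n_layers entries, each (2, 16, regions_per_block)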