# `sdpa_ctx_prefer_flash` and `AsciiCodec` are referenced below; they are assumed
# to be defined elsewhere in this codebase and in scope here.
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        assert dim % n_heads == 0
        self.nh = n_heads
        self.hd = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.attn_dropout = attn_dropout

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.nh, self.hd).transpose(1, 2)
        k = k.view(B, T, self.nh, self.hd).transpose(1, 2)
        v = v.view(B, T, self.nh, self.hd).transpose(1, 2)
        if x.is_cuda:
            # Fused SDPA path (FlashAttention preferred when available).
            with sdpa_ctx_prefer_flash():
                y = F.scaled_dot_product_attention(
                    q, k, v, is_causal=True,
                    dropout_p=self.attn_dropout if self.training else 0.0,
                )
        else:
            # Explicit causally-masked attention fallback for CPU.
            scale = 1.0 / math.sqrt(self.hd)
            att = (q @ k.transpose(-2, -1)) * scale
            mask = torch.full((1, 1, T, T), float("-inf"), device=x.device)
            mask = torch.triu(mask, diagonal=1)
            att = (att + mask).softmax(dim=-1)
            y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)


def _normalize_cell(X):
    # X: [V, D]. Center the cell and scale it to unit RMS radius.
    Xc = X - X.mean(dim=0, keepdim=True)
    r = Xc.pow(2).sum(dim=1).mean().sqrt().clamp_min(1e-6)
    return Xc / r


class CrystalBank(nn.Module):
    def __init__(self, regions: int, dim: int):
        super().__init__()
        pts = torch.randn(regions, 5, dim) / math.sqrt(dim)
        with torch.no_grad():
            for i in range(regions):
                pts[i] = _normalize_cell(pts[i])
        self.anchors = nn.Parameter(pts)  # [C, 5, D]


class GeometricGate(nn.Module):
    def __init__(self, dim: int, regions: int, tau: float = 0.08):
        super().__init__()
        self.bank = CrystalBank(regions, dim)
        self.nav = nn.Linear(dim, dim, bias=False)
        self.tau = tau
        self.scale = nn.Parameter(torch.tensor(1.0))  # residual mix scaler

    def forward(self, h: torch.Tensor, punct_mask: Optional[torch.Tensor] = None,
                alpha_gate: float = 1.0, hard_mask_gate: bool = False):
        B, T, D = h.shape
        C = self.bank.anchors.size(0)
        H = self.nav(h).reshape(B * T, D)        # [BT, D]
        A = self.bank.anchors.reshape(C * 5, D)  # [C*5, D]
        # Squared distances via the expansion ||x - a||^2 = ||x||^2 + ||a||^2 - 2<x, a>.
        x2 = (H * H).sum(dim=-1, keepdim=True)   # [BT, 1]
        a2 = (A * A).sum(dim=-1).unsqueeze(0)    # [1, C*5]
        xa = H @ A.T                             # [BT, C*5]
        d2 = (x2 + a2 - 2 * xa).clamp_min(0.0).view(B * T, C, 5)
        # Soft-min over each region's 5 anchors, then softmax routing over regions.
        s = -torch.logsumexp(-d2 / max(1e-6, self.tau), dim=-1)        # [BT, C]
        w = F.softmax(-s / max(1e-6, self.tau), dim=-1).view(B, T, C)  # region weights
        centroids = self.bank.anchors.mean(dim=1)  # [C, D]
        g = w @ centroids                          # [B, T, D]
        if punct_mask is not None:
            # Alpha soft mask reduces the gate on punctuation tokens.
            pm = punct_mask.float().unsqueeze(-1)  # [B, T, 1], 1 on punctuation
            if hard_mask_gate:
                g = g * (1.0 - pm)  # zero the gate on punctuation
            else:
                g = g * (1.0 - pm * (1.0 - alpha_gate))
        return w, self.scale * g


class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x, approximate="tanh")
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class GatedBlock(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio, dropout, regions):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = CausalSelfAttention(dim, n_heads, attn_dropout=dropout)
        self.gate = GeometricGate(dim, regions=regions, tau=0.08)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio, dropout=dropout)
        self.mix_sdpa = nn.Parameter(torch.tensor(1.0))  # stage-adjustable

    def forward(self, x, punct_mask=None, return_gate=False,
                alpha_gate=1.0, hard_mask_gate=False):
        h = self.norm1(x)
        att = self.attn(h) * self.mix_sdpa
        w, g = self.gate(h, punct_mask=punct_mask,
                         alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
        # Attention and the geometric gate share one residual add; the MLP gets its own.
        x = x + att + g
        x = x + self.mlp(self.norm2(x))
        return (x, w) if return_gate else x


class CrystalBeeper(nn.Module):
    def __init__(self, model_cfg: dict):
        super().__init__()
        self.cfg = model_cfg
        D, L, H = model_cfg["dim"], model_cfg["n_layers"], model_cfg["n_heads"]
        ctx = model_cfg["context"]
        # Ingress
        self.use_ascii = bool(model_cfg.get("use_ascii", True))
        self.codec = AsciiCodec()
        V_ascii = self.codec.vocab_size
        self.token_emb = nn.Embedding(V_ascii, D)
        self.pos_emb = nn.Parameter(torch.zeros(1, ctx, D))
        self.drop = nn.Dropout(model_cfg.get("resid_dropout", 0.1))
        regions = [int(model_cfg.get("regions_per_block", 64)) for _ in range(L)]
        self.blocks = nn.ModuleList([
            GatedBlock(D, H, model_cfg["mlp_ratio"], model_cfg["dropout"], regions[i])
            for i in range(L)
        ])
        self.norm = nn.LayerNorm(D)
        # Output heads
        self.ascii_head = nn.Linear(D, V_ascii, bias=False)
        self.bpe_head = nn.Linear(D, model_cfg.get("vocab_size", 8192), bias=False)  # optional
        # Global tug (Rose)
        self.rose_proj = nn.Linear(D, D, bias=False)
        self.rose_anchors = nn.Parameter(torch.randn(3, D) / math.sqrt(D))
        self.apply(self._init)

    @staticmethod
    def _init(m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)

    def backbone(self, idx, punct_mask=None, return_routes=False,
                 alpha_gate=1.0, hard_mask_gate=False):
        B, T = idx.shape
        x = self.token_emb(idx) + self.pos_emb[:, :T, :]
        x = self.drop(x)
        routes = []
        for blk in self.blocks:
            x, w = blk(x, punct_mask=punct_mask, return_gate=True,
                       alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
            if return_routes:
                routes.append(w)
        x = self.norm(x)
        return (x, routes) if return_routes else (x, None)

    def forward(self, idx, punct_mask=None, head="ascii", return_routes=False,
                alpha_gate=1.0, hard_mask_gate=False):
        h, routes = self.backbone(idx, punct_mask=punct_mask, return_routes=return_routes,
                                  alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
        if head == "ascii":
            logits = self.ascii_head(h)
        elif head == "bpe":
            logits = self.bpe_head(h)
        else:
            raise ValueError("head must be 'ascii' or 'bpe'")
        return logits, routes
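# --- Illustrative smoke test (not part of the model definition) ---
# A minimal sketch of the expected tensor shapes through one GatedBlock; the
# dimensions, seed, and mask below are arbitrary choices for this example only.
if __name__ == "__main__":
    torch.manual_seed(0)
    block = GatedBlock(dim=128, n_heads=4, mlp_ratio=4.0, dropout=0.1, regions=16)
    block.eval()  # disable dropout for a deterministic pass
    x = torch.randn(2, 32, 128)                   # [B, T, D] residual stream
    punct = torch.zeros(2, 32, dtype=torch.bool)  # toy punctuation mask (no punctuation)
    out, w = block(x, punct_mask=punct, return_gate=True)
    print(out.shape)  # torch.Size([2, 32, 128]) -- residual stream, shape preserved
    print(w.shape)    # torch.Size([2, 32, 16])  -- per-token routing weights over regions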