# beeper-ascii-v1 / model.py
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: sdpa_ctx_prefer_flash() and AsciiCodec are referenced below but are assumed
# to be defined elsewhere in this file.

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention: SDPA (flash preferred) on CUDA, with an
    explicit masked-softmax fallback on CPU."""
    def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
super().__init__()
assert dim % n_heads == 0
self.nh = n_heads; self.hd = dim // n_heads
self.qkv = nn.Linear(dim, 3*dim, bias=False)
self.proj = nn.Linear(dim, dim, bias=False)
self.attn_dropout = attn_dropout
def forward(self, x):
B,T,C = x.shape
qkv = self.qkv(x); q,k,v = qkv.chunk(3, dim=-1)
q = q.view(B,T,self.nh,self.hd).transpose(1,2)
k = k.view(B,T,self.nh,self.hd).transpose(1,2)
v = v.view(B,T,self.nh,self.hd).transpose(1,2)
if x.is_cuda:
with sdpa_ctx_prefer_flash():
y = F.scaled_dot_product_attention(q,k,v,is_causal=True,
dropout_p=self.attn_dropout if self.training else 0.0)
        else:
            # CPU fallback: explicit causal mask + softmax, matching the SDPA path above.
            scale = 1.0 / math.sqrt(self.hd)
            att = (q @ k.transpose(-2,-1)) * scale
            mask = torch.full((1,1,T,T), float("-inf"), device=x.device)
            mask = torch.triu(mask, diagonal=1)
            att = (att + mask).softmax(dim=-1)
            if self.training and self.attn_dropout > 0.0:
                att = F.dropout(att, p=self.attn_dropout)  # parity with the SDPA branch
            y = att @ v
y = y.transpose(1,2).contiguous().view(B,T,C)
return self.proj(y)
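
# Illustrative sanity sketch (not used by training): on CPU the module takes the
# explicit masked-softmax branch, so in eval mode it should match a direct call to
# F.scaled_dot_product_attention with is_causal=True. Shapes below are arbitrary
# example values.
def _causal_attention_demo():
    torch.manual_seed(0)
    attn = CausalSelfAttention(dim=16, n_heads=4).eval()
    x = torch.randn(1, 6, 16)                               # [B, T, C], CPU tensor
    with torch.no_grad():
        y = attn(x)                                          # manual fallback path
        q, k, v = attn.qkv(x).chunk(3, dim=-1)
        q = q.view(1, 6, 4, 4).transpose(1, 2)
        k = k.view(1, 6, 4, 4).transpose(1, 2)
        v = v.view(1, 6, 4, 4).transpose(1, 2)
        ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        ref = attn.proj(ref.transpose(1, 2).contiguous().view(1, 6, 16))
    assert torch.allclose(y, ref, atol=1e-5)
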
def _normalize_cell(X):  # X: [V,D]
    """Center a cell of anchor points and rescale it to unit RMS radius."""
Xc = X - X.mean(dim=0, keepdim=True)
r = Xc.pow(2).sum(dim=1).mean().sqrt().clamp_min(1e-6)
return Xc / r
class CrystalBank(nn.Module):
def __init__(self, regions: int, dim: int):
super().__init__()
pts = torch.randn(regions, 5, dim) / math.sqrt(dim)
with torch.no_grad():
for i in range(regions): pts[i] = _normalize_cell(pts[i])
self.anchors = nn.Parameter(pts) # [C,5,D]
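
# Illustrative check (not used by training): after _normalize_cell, each region's
# anchor cell is centered and has unit RMS radius at initialization. Sizes are
# arbitrary example values.
def _crystal_bank_demo():
    torch.manual_seed(0)
    bank = CrystalBank(regions=4, dim=16)
    cell = bank.anchors[0]                                   # [5, D] anchors of region 0
    print(cell.mean(dim=0).abs().max().item())               # ~0.0 (centered)
    print(cell.pow(2).sum(dim=1).mean().sqrt().item())       # ~1.0 (unit RMS radius)
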
class GeometricGate(nn.Module):
    """Soft geometric routing: each token is scored by its soft-min distance to the
    5 anchors of every region; the resulting region weights mix region centroids
    into a scaled residual gate term."""
    def __init__(self, dim: int, regions: int, tau: float = 0.08):
super().__init__()
self.bank = CrystalBank(regions, dim)
self.nav = nn.Linear(dim, dim, bias=False)
self.tau = tau
self.scale = nn.Parameter(torch.tensor(1.0)) # residual mix scaler
def forward(self, h: torch.Tensor, punct_mask: Optional[torch.Tensor] = None, alpha_gate: float = 1.0, hard_mask_gate: bool=False):
B,T,D = h.shape
C = self.bank.anchors.size(0)
H = self.nav(h).reshape(B*T, D) # [BT,D]
A = self.bank.anchors.reshape(C*5, D) # [C*5,D]
# squared distances via expansion
x2 = (H*H).sum(dim=-1, keepdim=True) # [BT,1]
a2 = (A*A).sum(dim=-1).unsqueeze(0) # [1,C*5]
xa = H @ A.T # [BT,C*5]
d2 = (x2 + a2 - 2*xa).clamp_min(0.0).view(B*T, C, 5)
        s = -torch.logsumexp(-d2 / max(1e-6, self.tau), dim=-1)              # soft-min distance per region, [BT,C]
        w = F.softmax(-s / max(1e-6, self.tau), dim=-1).view(B,T,C)          # routing weights over regions
centroids = self.bank.anchors.mean(dim=1) # [C,D]
g = (w @ centroids) # [B,T,D]
if punct_mask is not None:
# alpha soft mask reduces gate on punctuation tokens
pm = punct_mask.float().unsqueeze(-1) # [B,T,1], 1 on punct
if hard_mask_gate:
g = g * (1.0 - pm) # zero gate on punct
else:
g = g * (1.0 - pm*(1.0 - alpha_gate))
return w, self.scale * g
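
# Illustrative routing sketch (not used by training): the gate returns per-token
# routing weights over regions (a softmax, so they sum to 1) plus a gated residual
# term built from region centroids; alpha_gate < 1 attenuates that term on
# punctuation tokens. Sizes are arbitrary example values.
def _geometric_gate_demo():
    torch.manual_seed(0)
    gate = GeometricGate(dim=32, regions=8, tau=0.08)
    h = torch.randn(2, 5, 32)                                # [B, T, D]
    punct = torch.zeros(2, 5, dtype=torch.bool)
    punct[:, -1] = True                                      # pretend the last token is punctuation
    w, g = gate(h, punct_mask=punct, alpha_gate=0.5)
    assert torch.allclose(w.sum(dim=-1), torch.ones(2, 5), atol=1e-5)  # weights sum to 1 per token
    print(w.shape, g.shape)                                  # [2, 5, 8], [2, 5, 32]
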
class MLP(nn.Module):
def __init__(self, dim, mlp_ratio=4.0, dropout=0.1):
super().__init__()
hidden = int(dim*mlp_ratio)
self.fc1 = nn.Linear(dim, hidden)
self.fc2 = nn.Linear(hidden, dim)
self.drop = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = F.gelu(x, approximate="tanh")
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class GatedBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention (scaled by mix_sdpa) plus the
    geometric gate contribution, followed by an MLP."""
    def __init__(self, dim, n_heads, mlp_ratio, dropout, regions):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = CausalSelfAttention(dim, n_heads, attn_dropout=dropout)
self.gate = GeometricGate(dim, regions=regions, tau=0.08)
self.norm2 = nn.LayerNorm(dim)
self.mlp = MLP(dim, mlp_ratio=mlp_ratio, dropout=dropout)
self.mix_sdpa = nn.Parameter(torch.tensor(1.0)) # stage-adjustable
def forward(self, x, punct_mask=None, return_gate=False, alpha_gate=1.0, hard_mask_gate=False):
h = self.norm1(x)
att = self.attn(h) * self.mix_sdpa
w, g = self.gate(h, punct_mask=punct_mask, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
x = x + att + g
x = x + self.mlp(self.norm2(x))
return (x, w) if return_gate else x
class CrystalBeeper(nn.Module):
    """ASCII-level causal language model built from GatedBlocks, with an optional BPE head."""
    def __init__(self, model_cfg: dict):
super().__init__()
self.cfg = model_cfg
D, L, H = model_cfg["dim"], model_cfg["n_layers"], model_cfg["n_heads"]
ctx = model_cfg["context"]
# Ingress
self.use_ascii = bool(model_cfg.get("use_ascii", True))
self.codec = AsciiCodec()
V_ascii = self.codec.vocab_size
self.token_emb = nn.Embedding(V_ascii, D)
self.pos_emb = nn.Parameter(torch.zeros(1, ctx, D))
self.drop = nn.Dropout(model_cfg.get("resid_dropout", 0.1))
regions = [int(model_cfg.get("regions_per_block", 64)) for _ in range(L)]
self.blocks = nn.ModuleList([
GatedBlock(D, H, model_cfg["mlp_ratio"], model_cfg["dropout"], regions[i])
for i in range(L)
])
self.norm = nn.LayerNorm(D)
# Output heads
self.ascii_head = nn.Linear(D, V_ascii, bias=False)
self.bpe_head = nn.Linear(D, model_cfg.get("vocab_size", 8192), bias=False) # optional
# Global tug (Rose)
self.rose_proj = nn.Linear(D, D, bias=False)
self.rose_anchors = nn.Parameter(torch.randn(3, D) / math.sqrt(D))
self.apply(self._init)
@staticmethod
def _init(m):
if isinstance(m, (nn.Linear, nn.Embedding)):
nn.init.normal_(m.weight, mean=0.0, std=0.02)
if getattr(m, "bias", None) is not None: nn.init.zeros_(m.bias)
def backbone(self, idx, punct_mask=None, return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
        B,T = idx.shape
        assert T <= self.pos_emb.size(1), "sequence length exceeds model context"
        x = self.token_emb(idx) + self.pos_emb[:, :T, :]
        x = self.drop(x)
routes = []
for blk in self.blocks:
x, w = blk(x, punct_mask=punct_mask, return_gate=True, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
if return_routes: routes.append(w)
x = self.norm(x)
return (x, routes) if return_routes else (x, None)
def forward(self, idx, punct_mask=None, head="ascii", return_routes=False, alpha_gate=1.0, hard_mask_gate=False):
h, routes = self.backbone(idx, punct_mask=punct_mask, return_routes=return_routes, alpha_gate=alpha_gate, hard_mask_gate=hard_mask_gate)
if head == "ascii": logits = self.ascii_head(h)
elif head == "bpe": logits = self.bpe_head(h)
else: raise ValueError("head must be 'ascii' or 'bpe'")
return logits, routes
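
# Hedged smoke test: the config keys below are the ones read in __init__; the exact
# values are illustrative. Constructing CrystalBeeper requires the AsciiCodec class
# (defined elsewhere in this file). Run with `python model.py`.
if __name__ == "__main__":
    cfg = {
        "dim": 64, "n_layers": 2, "n_heads": 4, "context": 32,
        "mlp_ratio": 4.0, "dropout": 0.1, "regions_per_block": 8,
    }
    model = CrystalBeeper(cfg).eval()
    idx = torch.randint(0, model.codec.vocab_size, (2, 16))  # [B, T] ASCII token ids
    with torch.no_grad():
        logits, routes = model(idx, return_routes=True)
    print(logits.shape)                      # [2, 16, ascii vocab size]
    print(len(routes), routes[0].shape)      # n_layers, [2, 16, regions_per_block]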