#!/usr/bin/env python3
"""
BINARY TRANSFORMER - Raw network bytes → neural network

No tokenizer. No preprocessing. Just bytes.
Vocab = 256 (one token per byte value 0x00-0xFF)
Input: Raw bytes from network stream via stdin

Training is online: incoming bytes are appended to a sliding context window
and, every UPDATE_EVERY bytes, one optimizer step is taken on next-byte
prediction for the most recent UPDATE_EVERY positions of that window.
"""
import math
import sys
import time
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Permit TF32 for CUDA matmuls (faster on supporting GPUs, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True

# Binary model config - TINY for speed
CONFIG = {
    "d": 128,      # embedding / residual width (must be divisible by "heads")
    "layers": 3,   # number of transformer blocks
    "heads": 4,    # attention heads
    "vocab": 256,  # ONE TOKEN PER BYTE
    "ctx": 1024,   # context window in bytes (bytes are fine-grained)
}

LR = 3e-4
UPDATE_EVERY = 64      # bytes between optimizer steps
PRINT_EVERY = 50_000   # bytes between stats lines
SAVE_EVERY = 500_000   # bytes between checkpoints
READ_CHUNK = 65_536    # max bytes pulled from stdin per read call


class ByteAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        d: model width; must be divisible by ``h`` (the QKV output is viewed
            as ``h`` heads of ``d // h`` features each).
        h: number of attention heads.
    """

    def __init__(self, d, h):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, N, D) and return a (B, N, D) tensor.

        ``mask``, if given, is added to the (B, h, N, N) attention scores
        before the softmax (large negative entries block attention).
        """
        B, N, D = x.shape
        # (B, N, 3*D) -> (3, B, h, N, dk): split into q/k/v, then heads.
        qkv = self.qkv(x).view(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            att = att + mask
        out = F.softmax(att, -1) @ v  # (B, h, N, dk)
        # Merge heads back to (B, N, D) before the output projection.
        return self.proj(out.transpose(1, 2).reshape(B, N, D))


class ByteBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention then a 4x GELU MLP,
    each wrapped in a residual connection."""

    def __init__(self, d, h):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = ByteAttention(d, h)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x, mask):
        """Apply the block to ``x`` (B, N, D); ``mask`` is passed to attention."""
        x = x + self.attn(self.ln1(x), mask)
        return x + self.ff(self.ln2(x))


class BinaryTransformer(nn.Module):
    """Causal byte-level language model.

    Args:
        cfg: dict with keys ``"d"``, ``"layers"``, ``"heads"`` and ``"vocab"``
            (``"ctx"`` is not read by the model itself).
    """

    def __init__(self, cfg):
        super().__init__()
        d, L, h, V = cfg["d"], cfg["layers"], cfg["heads"], cfg["vocab"]
        self.emb = nn.Embedding(V, d)  # one embedding row per byte value
        self.blocks = nn.ModuleList([ByteBlock(d, h) for _ in range(L)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, V, bias=False)
        # Tie output projection to the input embedding (shared Parameter).
        self.head.weight = self.emb.weight

    def forward(self, x):
        """Map byte ids ``x`` of shape (B, N) to logits of shape (B, N, vocab)."""
        B, N = x.shape
        # Causal mask: -1e9 strictly above the diagonal, so position i
        # cannot attend to any later position j > i.
        mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
        h = self.emb(x)
        for block in self.blocks:
            h = block(h, mask)
        return self.head(self.ln(h))

    def count_params(self):
        """Return the number of parameters; the tied embedding/head weight is
        counted once because ``Module.parameters()`` de-duplicates it."""
        return sum(p.numel() for p in self.parameters())


def checkpoint_tag(n_bytes):
    """Return an exact decimal megabyte label for ``n_bytes`` (1 MB = 10**6 bytes).

    Whole megabytes render without a fraction ("1", "2"); other sizes keep
    every significant digit ("0.5", "1.5", "1.25"), so distinct byte counts
    never share a label.
    """
    whole, frac = divmod(int(n_bytes), 1_000_000)
    if frac == 0:
        return str(whole)
    return f"{whole}.{frac:06d}".rstrip("0")


class BinaryTrainer:
    """Online trainer that consumes one byte at a time.

    Keeps the last ``CONFIG["ctx"] + 1`` bytes in a ring buffer (inputs plus
    one shifted target), takes an AdamW step every UPDATE_EVERY bytes, prints
    stats every PRINT_EVERY bytes and checkpoints every SAVE_EVERY bytes.
    """

    def __init__(self, model, lr=LR):
        self.model = model.to(DEVICE)
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
        self.ctx_size = CONFIG["ctx"]
        # +1 so a full buffer yields ctx_size inputs and ctx_size targets.
        self.buffer = deque(maxlen=self.ctx_size + 1)
        self.bytes_seen = 0
        self.total_loss = 0.0
        self.updates = 0
        self.start_time = time.time()

    def ingest_byte(self, byte_val):
        """Absorb a single byte (0-255) and run any periodic work it triggers."""
        self.buffer.append(byte_val)
        self.bytes_seen += 1
        # Only train once there are at least UPDATE_EVERY targets available.
        if len(self.buffer) >= UPDATE_EVERY + 1 and self.bytes_seen % UPDATE_EVERY == 0:
            self._update()
        if self.bytes_seen % PRINT_EVERY == 0:
            self._print_stats()
        if self.bytes_seen % SAVE_EVERY == 0:
            self._save()

    def _update(self):
        """One optimizer step on next-byte prediction over the newest
        UPDATE_EVERY positions, using the whole buffer as context."""
        tokens = list(self.buffer)
        x = torch.tensor(tokens[:-1], device=DEVICE, dtype=torch.long).unsqueeze(0)
        y = torch.tensor(tokens[1:], device=DEVICE, dtype=torch.long).unsqueeze(0)
        self.model.train()
        logits = self.model(x)
        vocab = logits.size(-1)  # follow the model's vocab instead of a literal
        loss = F.cross_entropy(
            logits[:, -UPDATE_EVERY:].reshape(-1, vocab),
            y[:, -UPDATE_EVERY:].reshape(-1),
        )
        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()
        self.total_loss += loss.item()
        self.updates += 1

    def _print_stats(self):
        """Print throughput and the mean loss over all updates so far."""
        elapsed = time.time() - self.start_time
        rate = self.bytes_seen / elapsed if elapsed > 0 else 0
        # Cumulative mean since start (not a recent window), so it lags.
        avg_loss = self.total_loss / max(1, self.updates)
        mb = self.bytes_seen / 1_000_000
        # Bits per byte: cross-entropy in nats / ln 2. Uniform guessing over
        # 256 values gives log2(256) = 8; lower means the model is learning.
        bpb = avg_loss / math.log(2)
        print(f"[{elapsed:.0f}s] {mb:.2f}MB | {rate/1000:.1f} KB/s | "
              f"loss={avg_loss:.3f} | bpb={bpb:.2f} | updates={self.updates}", flush=True)

    def _save(self):
        """Write weights and progress metadata to ``byte_ckpt_<MB>mb.pt``."""
        avg_loss = self.total_loss / max(1, self.updates)
        # Exact MB label: flooring to whole MB would give the 1.0 MB and 1.5 MB
        # saves the same file name, so every other checkpoint would overwrite
        # its predecessor (and the first would be labelled "0mb").
        tag = checkpoint_tag(self.bytes_seen)
        ckpt = {
            "model": self.model.state_dict(),
            "bytes": self.bytes_seen,
            "loss": avg_loss,
        }
        torch.save(ckpt, f"byte_ckpt_{tag}mb.pt")
        print(f"[SAVED] {tag}MB checkpoint", flush=True)


def main():
    """Train a fresh model on raw bytes read from stdin until EOF."""
    print("BINARY TRANSFORMER - Raw bytes learning", flush=True)
    print(f"Config: {CONFIG}", flush=True)
    print(f"Device: {DEVICE}", flush=True)

    model = BinaryTransformer(CONFIG)
    params = model.count_params()
    print(f"Parameters: {params:,} ({params/1e6:.1f}M)", flush=True)
    print("Vocab: 256 (one per byte)", flush=True)

    trainer = BinaryTrainer(model)
    print("Listening for raw bytes on stdin...", flush=True)

    # Read raw bytes from stdin in chunks rather than one read(1) call per
    # byte. BufferedReader.read1() makes at most one raw read, so it returns
    # as soon as any data is available instead of waiting to fill the chunk.
    # NOTE(review): the .read fallback (non-buffered stdin replacements) may
    # block until READ_CHUNK bytes or EOF.
    stream = sys.stdin.buffer
    read_some = getattr(stream, "read1", stream.read)
    while True:
        chunk = read_some(READ_CHUNK)
        if not chunk:
            break
        for byte_val in chunk:  # iterating bytes yields ints 0-255
            trainer.ingest_byte(byte_val)

    print(f"Stream ended. Total bytes: {trainer.bytes_seen:,}", flush=True)


if __name__ == "__main__":
    main()