#!/usr/bin/env python3
"""
PURE BINARY TRANSFORMER - BITS ALL THE WAY DOWN

- Vocab = 2 (0 and 1)
- Weights = binary (-1 or +1, stored as bits)
- Activations = binary where possible

Uses the Straight-Through Estimator (STE) for gradients.
XNOR + popcount for matmul = insanely fast on hardware.
"""
import sys
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Config for pure binary transformer
CONFIG = {
    "d": 256,       # must be divisible by heads
    "layers": 6,
    "heads": 8,
    "vocab": 2,     # 0 and 1
    "ctx": 2048,
}

LR = 1e-3
UPDATE_EVERY = 256
PRINT_EVERY = 50000


# ============== BINARY LAYERS ==============

class BinarySign(torch.autograd.Function):
    """Binarize to -1/+1 with straight-through estimator."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sign()

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # STE: pass the gradient through only where |x| <= 1
        grad_input = grad_output.clone()
        grad_input[x.abs() > 1] = 0
        return grad_input


def binarize(x):
    return BinarySign.apply(x)


class BinaryLinear(nn.Module):
    """Linear layer with binary weights (-1/+1)."""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Real-valued shadow weights for training; binarized in forward()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None

    def forward(self, x):
        # Binarize weights to -1/+1
        binary_weight = binarize(self.weight)
        # Scale by alpha = mean(|W|) (XNOR-Net style) to reduce binarization error
        alpha = self.weight.abs().mean()
        out = F.linear(x, binary_weight * alpha, self.bias)
        return out
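
# --- Illustrative sketch (not used by the training code above) --------------
# The docstring's "XNOR + popcount" claim works like this: encode each {-1,+1}
# vector as a bit string (+1 -> 1, -1 -> 0). Agreements minus disagreements
# gives dot(a, w) = n - 2 * popcount(a XOR w), equivalently
# 2 * popcount(XNOR(a, w)) - n. The helper names below (pack_bits,
# xnor_popcount_dot) are hypothetical and for illustration only; BinaryLinear
# above still computes a float F.linear.

def pack_bits(vec):
    """Pack a {-1,+1} vector into a Python int, one bit per element (+1 -> 1, -1 -> 0)."""
    packed = 0
    for v in vec:
        packed = (packed << 1) | (1 if v > 0 else 0)
    return packed


def xnor_popcount_dot(a_bits, w_bits, n):
    """Dot product of two packed {-1,+1} vectors of length n via XOR + popcount."""
    hamming = bin(a_bits ^ w_bits).count("1")
    return n - 2 * hamming

# e.g. xnor_popcount_dot(pack_bits([1, -1, 1]), pack_bits([1, 1, -1]), 3) == -1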

class BinaryAttention(nn.Module):
    """Attention with binary QKV projections."""

    def __init__(self, d, h):
        super().__init__()
        self.h, self.dk = h, d // h
        self.q_proj = BinaryLinear(d, d)
        self.k_proj = BinaryLinear(d, d)
        self.v_proj = BinaryLinear(d, d)
        self.out_proj = BinaryLinear(d, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        q = self.q_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k = self.k_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
        v = self.v_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
        # Standard attention (scores and values stay real for now)
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            att = att + mask
        att = F.softmax(att, dim=-1)
        out = (att @ v).transpose(1, 2).reshape(B, N, D)
        return self.out_proj(out)


class BinaryMLP(nn.Module):
    """MLP with binary weights."""

    def __init__(self, d):
        super().__init__()
        self.fc1 = BinaryLinear(d, d * 4)
        self.fc2 = BinaryLinear(d * 4, d)

    def forward(self, x):
        # Binary weights, but a real-valued GELU activation (could binarize this too)
        x = F.gelu(self.fc1(x))
        return self.fc2(x)


class BinaryBlock(nn.Module):
    def __init__(self, d, h):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = BinaryAttention(d, h)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = BinaryMLP(d)

    def forward(self, x, mask):
        x = x + self.attn(self.ln1(x), mask)
        return x + self.mlp(self.ln2(x))


class PureBinaryTransformer(nn.Module):
    """
    Transformer where:
    - Input vocab = 2 (bits)
    - All linear weights are binary (-1/+1)
    """

    def __init__(self, cfg):
        super().__init__()
        d, L, h = cfg["d"], cfg["layers"], cfg["heads"]
        # Embeddings stay real (there are only 2 of them anyway)
        self.emb = nn.Embedding(2, d)
        # Binary blocks
        self.blocks = nn.ModuleList([BinaryBlock(d, h) for _ in range(L)])
        self.ln = nn.LayerNorm(d)
        self.head = BinaryLinear(d, 2)  # Binary output projection too!

    def forward(self, x):
        B, N = x.shape
        mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
        h = self.emb(x)
        for block in self.blocks:
            h = block(h, mask)
        return self.head(self.ln(h))

    def count_params(self):
        return sum(p.numel() for p in self.parameters())

    def count_binary_params(self):
        """Count params that are binarized."""
        count = 0
        for module in self.modules():
            if isinstance(module, BinaryLinear):
                count += module.weight.numel()
        return count


def byte_to_bits(byte_val):
    return [(byte_val >> (7 - i)) & 1 for i in range(8)]


class BinaryTrainer:
    def __init__(self, model, lr=LR):
        self.model = model.to(DEVICE)
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
        self.ctx_size = CONFIG["ctx"]
        self.buffer = deque(maxlen=self.ctx_size + 1)
        self.bits_seen = 0
        self.bytes_seen = 0
        self.total_loss = 0.0
        self.updates = 0
        self.start_time = time.time()

    def ingest_byte(self, byte_val):
        bits = byte_to_bits(byte_val)
        for bit in bits:
            self.buffer.append(bit)
            self.bits_seen += 1
            if len(self.buffer) >= UPDATE_EVERY + 1 and self.bits_seen % UPDATE_EVERY == 0:
                self._update()
        self.bytes_seen += 1
        if self.bits_seen % PRINT_EVERY == 0:
            self._print_stats()
        if self.bytes_seen % 500000 == 0 and self.bytes_seen > 0:
            self._save()

    def _update(self):
        tokens = list(self.buffer)
        x = torch.tensor(tokens[:-1], device=DEVICE, dtype=torch.long).unsqueeze(0)
        y = torch.tensor(tokens[1:], device=DEVICE, dtype=torch.long).unsqueeze(0)
        self.model.train()
        logits = self.model(x)
        loss = F.cross_entropy(
            logits[:, -UPDATE_EVERY:].reshape(-1, 2),
            y[:, -UPDATE_EVERY:].reshape(-1),
        )
        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()
        self.total_loss += loss.item()
        self.updates += 1

    def _print_stats(self):
        elapsed = time.time() - self.start_time
        bytes_per_sec = self.bytes_seen / elapsed if elapsed > 0 else 0
        avg_loss = self.total_loss / max(1, self.updates)
        entropy = avg_loss / math.log(2)  # nats -> bits per input bit
        compression = (1.0 - entropy) * 100
        print(f"[{elapsed:.0f}s] {self.bytes_seen/1000:.1f}KB | {bytes_per_sec/1000:.2f} KB/s | "
              f"loss={avg_loss:.4f} | entropy={entropy:.3f} | compression={compression:.1f}%",
              flush=True)

    def _save(self):
        avg_loss = self.total_loss / max(1, self.updates)
        kb = self.bytes_seen // 1000
        ckpt = {
            "model": self.model.state_dict(),
            "bits": self.bits_seen,
            "bytes": self.bytes_seen,
            "loss": avg_loss,
        }
        torch.save(ckpt, f"/workspace/purebit_ckpt_{kb}kb.pt")
        print(f"[SAVED] purebit_ckpt_{kb}kb.pt", flush=True)


def main():
    print("PURE BINARY TRANSFORMER - BITS ALL THE WAY DOWN", flush=True)
    print(f"Config: {CONFIG}", flush=True)
    print(f"Device: {DEVICE}", flush=True)

    model = PureBinaryTransformer(CONFIG)
    total_params = model.count_params()
    binary_params = model.count_binary_params()
    print(f"Total Parameters: {total_params:,} ({total_params/1e6:.2f}M)", flush=True)
    print(f"Binary Parameters: {binary_params:,} ({binary_params/total_params*100:.1f}%)", flush=True)
    print("Vocab: 2 (input bits)", flush=True)
    print("Weights: BINARY (-1/+1)", flush=True)
    print("", flush=True)
    print("🔥 BITS IN, BITS WEIGHTS, BITS OUT 🔥", flush=True)

    trainer = BinaryTrainer(model)
    print("Listening for bytes...", flush=True)

    while True:
        byte = sys.stdin.buffer.read(1)
        if not byte:
            break
        trainer.ingest_byte(byte[0])

    print(f"Done. {trainer.bytes_seen:,} bytes = {trainer.bits_seen:,} bits", flush=True)


if __name__ == "__main__":
    main()
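
# --- Illustrative sketch (assumed helper, not part of the original script) ---
# Shows how a checkpoint written by BinaryTrainer._save() could be loaded to
# query the model for the next-bit probability. The function name and the
# example checkpoint path are hypothetical.

@torch.no_grad()
def predict_next_bit(ckpt_path, bits):
    """Return P(next bit = 1) given a list of 0/1 context bits and a saved checkpoint."""
    model = PureBinaryTransformer(CONFIG).to(DEVICE)
    ckpt = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model"])
    model.eval()
    x = torch.tensor(bits[-CONFIG["ctx"]:], device=DEVICE, dtype=torch.long).unsqueeze(0)
    logits = model(x)                         # shape (1, N, 2)
    probs = F.softmax(logits[0, -1], dim=-1)  # distribution over {0, 1}
    return probs[1].item()

# e.g. p_one = predict_next_bit("/workspace/purebit_ckpt_500kb.pt", byte_to_bits(ord("a")))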