#!/usr/bin/env python3
"""
PURE BINARY TRANSFORMER - BITS ALL THE WAY DOWN
- Vocab = 2 (0 and 1)
- Weights = binary (-1 or +1), binarized on the fly from real-valued master weights
- Activations = kept real in this version (binarizing them is a possible extension)
Uses a Straight-Through Estimator (STE) for gradients.
Binary weights let each matmul be done with XNOR + popcount, which is
extremely fast on suitable hardware (see the _xnor_dot_demo sketch below).
"""
import sys
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Config for pure binary transformer
CONFIG = {
"d": 256, # must be divisible by heads
"layers": 6,
"heads": 8,
"vocab": 2, # 0 and 1
"ctx": 2048,
}
LR = 1e-3
UPDATE_EVERY = 256
PRINT_EVERY = 50000
# ============== BINARY LAYERS ==============
class BinarySign(torch.autograd.Function):
"""Binarize to -1/+1 with straight-through estimator"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
        # Note: sign() maps exact zeros to 0 rather than -1/+1; with real-valued
        # weights this is vanishingly rare, so it is left as-is.
        return x.sign()
@staticmethod
def backward(ctx, grad_output):
x, = ctx.saved_tensors
# STE: pass gradient through if |x| <= 1
grad_input = grad_output.clone()
grad_input[x.abs() > 1] = 0
return grad_input
def binarize(x):
return BinarySign.apply(x)
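# Illustrative sketch (ours, not part of the original training code): a quick
# check of the straight-through estimator above. Forward returns sign(x);
# backward passes the upstream gradient only where |x| <= 1.
def _ste_demo():
    x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
    y = binarize(x)            # -> tensor([-1., -1.,  1.,  1.])
    y.sum().backward()         # upstream gradient of ones
    return y.detach(), x.grad  # x.grad -> tensor([0., 1., 1., 0.])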
class BinaryLinear(nn.Module):
"""Linear layer with binary weights (-1/+1)"""
def __init__(self, in_features, out_features, bias=False):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# Real-valued weights for training, binarized during forward
self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
if bias:
self.bias = nn.Parameter(torch.zeros(out_features))
else:
self.bias = None
def forward(self, x):
# Binarize weights to -1/+1
binary_weight = binarize(self.weight)
        # Scale the binarized weights by alpha = mean(|W|) to reduce the
        # quantization error (XNOR-Net; the paper uses one alpha per output
        # row, a single per-layer scalar is used here for simplicity)
        alpha = self.weight.abs().mean()
out = F.linear(x, binary_weight * alpha, self.bias)
return out
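# Illustrative sketch (ours): why +/-1 weights admit XNOR + popcount matmuls.
# Encode +1 as bit 1 and -1 as bit 0; then for two length-n vectors packed
# into ints, dot(a, b) = n - 2 * popcount(a XOR b). The trainer itself does
# not use this -- BinaryLinear.forward above sticks with F.linear on the
# binarized float weights -- it is only here to back up the claim in the
# module docstring.
def _xnor_dot_demo(a_bits, b_bits, n):
    """Dot product of two length-n +/-1 vectors packed MSB-first as ints."""
    xor = (a_bits ^ b_bits) & ((1 << n) - 1)
    return n - 2 * bin(xor).count("1")
# Example: a = [+1, -1, +1] -> 0b101, b = [+1, +1, -1] -> 0b110:
# _xnor_dot_demo(0b101, 0b110, 3) == -1, matching (+1)(+1) + (-1)(+1) + (+1)(-1).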
class BinaryAttention(nn.Module):
"""Attention with binary QKV projections"""
def __init__(self, d, h):
super().__init__()
self.h, self.dk = h, d // h
self.q_proj = BinaryLinear(d, d)
self.k_proj = BinaryLinear(d, d)
self.v_proj = BinaryLinear(d, d)
self.out_proj = BinaryLinear(d, d)
def forward(self, x, mask=None):
B, N, D = x.shape
q = self.q_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
k = self.k_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
v = self.v_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
# Standard attention (values stay real for now)
att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
if mask is not None:
att = att + mask
att = F.softmax(att, dim=-1)
out = (att @ v).transpose(1, 2).reshape(B, N, D)
return self.out_proj(out)
class BinaryMLP(nn.Module):
"""MLP with binary weights"""
def __init__(self, d):
super().__init__()
self.fc1 = BinaryLinear(d, d * 4)
self.fc2 = BinaryLinear(d * 4, d)
def forward(self, x):
        # Binary weights, but GELU activation (could binarize this too)
x = F.gelu(self.fc1(x))
return self.fc2(x)
class BinaryBlock(nn.Module):
def __init__(self, d, h):
super().__init__()
self.ln1 = nn.LayerNorm(d)
self.attn = BinaryAttention(d, h)
self.ln2 = nn.LayerNorm(d)
self.mlp = BinaryMLP(d)
def forward(self, x, mask):
x = x + self.attn(self.ln1(x), mask)
return x + self.mlp(self.ln2(x))
class PureBinaryTransformer(nn.Module):
"""
Transformer where:
- Input vocab = 2 (bits)
- All linear weights are binary (-1/+1)
"""
def __init__(self, cfg):
super().__init__()
d, L, h = cfg["d"], cfg["layers"], cfg["heads"]
# Embeddings stay real (only 2 of them anyway)
self.emb = nn.Embedding(2, d)
# Binary blocks
self.blocks = nn.ModuleList([BinaryBlock(d, h) for _ in range(L)])
self.ln = nn.LayerNorm(d)
self.head = BinaryLinear(d, 2) # Binary output projection too!
def forward(self, x):
B, N = x.shape
mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
h = self.emb(x)
for block in self.blocks:
h = block(h, mask)
return self.head(self.ln(h))
def count_params(self):
return sum(p.numel() for p in self.parameters())
def count_binary_params(self):
"""Count params that are binarized"""
count = 0
for name, module in self.named_modules():
if isinstance(module, BinaryLinear):
count += module.weight.numel()
return count
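# Back-of-envelope sketch (ours): if the binarized weights were bit-packed for
# inference, each BinaryLinear parameter would take 1 bit instead of 32, so the
# binary fraction of the model shrinks roughly 32x. Nothing in this trainer
# actually packs weights; this is only a size estimate.
def _packed_size_estimate(model):
    """Returns (fp32_bytes, bitpacked_bytes) for the binarized parameters."""
    n = model.count_binary_params()
    return n * 4, (n + 7) // 8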
def byte_to_bits(byte_val):
return [(byte_val >> (7 - i)) & 1 for i in range(8)]
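# Companion sketches (ours, unused by the trainer). bits_to_byte inverts
# byte_to_bits, handy for sanity checks and for turning sampled bits back into
# bytes; _sample_bits autoregressively draws bits from the model. Both names
# are illustrative additions, not part of the original script.
def bits_to_byte(bits):
    """Inverse of byte_to_bits: [1,0,1,0,1,0,1,0] -> 0xAA (expects 8 MSB-first bits)."""
    val = 0
    for bit in bits:
        val = (val << 1) | (bit & 1)
    return val
@torch.no_grad()
def _sample_bits(model, prompt_bits, n_new, ctx=CONFIG["ctx"]):
    """Sample n_new bits from the model given a non-empty list of prompt bits."""
    model.eval()
    bits = list(prompt_bits)
    for _ in range(n_new):
        x = torch.tensor(bits[-ctx:], device=DEVICE, dtype=torch.long).unsqueeze(0)
        probs = F.softmax(model(x)[0, -1], dim=-1)
        bits.append(torch.multinomial(probs, 1).item())
    return bits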
class BinaryTrainer:
def __init__(self, model, lr=LR):
self.model = model.to(DEVICE)
self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
self.ctx_size = CONFIG["ctx"]
self.buffer = deque(maxlen=self.ctx_size + 1)
self.bits_seen = 0
self.bytes_seen = 0
self.total_loss = 0.0
self.updates = 0
self.start_time = time.time()
def ingest_byte(self, byte_val):
bits = byte_to_bits(byte_val)
for bit in bits:
self.buffer.append(bit)
self.bits_seen += 1
if len(self.buffer) >= UPDATE_EVERY + 1 and self.bits_seen % UPDATE_EVERY == 0:
self._update()
self.bytes_seen += 1
if self.bits_seen % PRINT_EVERY == 0:
self._print_stats()
if self.bytes_seen % 500000 == 0 and self.bytes_seen > 0:
self._save()
def _update(self):
tokens = list(self.buffer)
x = torch.tensor(tokens[:-1], device=DEVICE, dtype=torch.long).unsqueeze(0)
y = torch.tensor(tokens[1:], device=DEVICE, dtype=torch.long).unsqueeze(0)
self.model.train()
logits = self.model(x)
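        # Train only on the newest UPDATE_EVERY targets: older positions in the
        # sliding window were already targets in previous updates and act purely
        # as context here.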
loss = F.cross_entropy(
logits[:, -UPDATE_EVERY:].reshape(-1, 2),
y[:, -UPDATE_EVERY:].reshape(-1)
)
self.opt.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.opt.step()
self.total_loss += loss.item()
self.updates += 1
def _print_stats(self):
elapsed = time.time() - self.start_time
bytes_per_sec = self.bytes_seen / elapsed if elapsed > 0 else 0
avg_loss = self.total_loss / max(1, self.updates)
        # avg_loss is cross-entropy in nats; dividing by ln(2) gives bits per
        # predicted bit. "compression" = percent saved vs. storing 1 bit/bit.
        entropy = avg_loss / math.log(2)
        compression = (1.0 - entropy) * 100
print(f"[{elapsed:.0f}s] {self.bytes_seen/1000:.1f}KB | {bytes_per_sec/1000:.2f} KB/s | "
f"loss={avg_loss:.4f} | entropy={entropy:.3f} | compression={compression:.1f}%", flush=True)
def _save(self):
avg_loss = self.total_loss / max(1, self.updates)
kb = self.bytes_seen // 1000
ckpt = {
"model": self.model.state_dict(),
"bits": self.bits_seen,
"bytes": self.bytes_seen,
"loss": avg_loss,
}
torch.save(ckpt, f"/workspace/purebit_ckpt_{kb}kb.pt")
print(f"[SAVED] purebit_ckpt_{kb}kb.pt", flush=True)
def main():
print(f"PURE BINARY TRANSFORMER - BITS ALL THE WAY DOWN", flush=True)
print(f"Config: {CONFIG}", flush=True)
print(f"Device: {DEVICE}", flush=True)
model = PureBinaryTransformer(CONFIG)
total_params = model.count_params()
binary_params = model.count_binary_params()
print(f"Total Parameters: {total_params:,} ({total_params/1e6:.2f}M)", flush=True)
print(f"Binary Parameters: {binary_params:,} ({binary_params/total_params*100:.1f}%)", flush=True)
print(f"Vocab: 2 (input bits)", flush=True)
print(f"Weights: BINARY (-1/+1)", flush=True)
print(f"", flush=True)
print(f"🔥 BITS IN, BITS WEIGHTS, BITS OUT 🔥", flush=True)
trainer = BinaryTrainer(model)
print(f"Listening for bytes...", flush=True)
while True:
byte = sys.stdin.buffer.read(1)
if not byte:
break
trainer.ingest_byte(byte[0])
print(f"Done. {trainer.bytes_seen:,} bytes = {trainer.bits_seen:,} bits", flush=True)
if __name__ == "__main__":
main()
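# Usage sketch (assumed invocation; not documented in the repo page):
#   cat some_corpus.bin | python purebit_trainer.py
# The trainer reads raw bytes from stdin, unpacks each into 8 MSB-first bits,
# runs an online update every UPDATE_EVERY bits, and writes checkpoints to
# /workspace every 500,000 bytes.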