""" |
|
|
PURE BINARY TRANSFORMER - BITS ALL THE WAY DOWN |
|
|
- Vocab = 2 (0 and 1) |
|
|
- Weights = binary (-1 or +1, stored as bits) |
|
|
- Activations = binary where possible |
|
|
|
|
|
Uses Straight-Through Estimator (STE) for gradients. |
|
|
XNOR + popcount for matmul = insanely fast on hardware. |
|
|
""" |

import sys
import math
import time
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
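

# The docstring's XNOR + popcount claim, illustrated. A hedged, pure-Python sketch that is
# not used by the training loop below: encode a {-1, +1} vector as a bit mask (bit i = 1
# means element i is +1). Positions where two vectors agree contribute +1 to the dot
# product and disagreements contribute -1, so dot(a, b) = n - 2 * popcount(a XOR b),
# i.e. a binary matmul reduces to XNOR plus a popcount per output element.
def xnor_popcount_dot(a_bits: int, b_bits: int, n: int) -> int:
    """Dot product of two n-dim {-1,+1} vectors packed as integer bit masks."""
    differing = bin((a_bits ^ b_bits) & ((1 << n) - 1)).count("1")
    return n - 2 * differing


# Sanity check with assumed example values:
# a = [+1, -1, +1, +1] -> 0b1101 (element i in bit i), b = [+1, +1, -1, +1] -> 0b1011
# float dot = 1 - 1 - 1 + 1 = 0, and xnor_popcount_dot(0b1101, 0b1011, 4) == 0.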


CONFIG = {
    "d": 256,
    "layers": 6,
    "heads": 8,
    "vocab": 2,
    "ctx": 2048,
}
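
# With this config the model comes out at roughly 4.7M parameters, ~99.8% of which live in
# BinaryLinear layers (main() prints the exact counts at startup).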

LR = 1e-3
UPDATE_EVERY = 256   # gradient step every 256 new bits (once the buffer has warmed up)
PRINT_EVERY = 50000  # stats-printing interval, in bits (checked once per byte)


class BinarySign(torch.autograd.Function):
    """Binarize to -1/+1 with a straight-through estimator (STE)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # sign() alone would map exact zeros to 0; force them to +1 so outputs are strictly -1/+1.
        return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # STE: pass the gradient straight through, but zero it where |x| > 1
        # (outside the linear region of the hard-tanh surrogate).
        grad_input = grad_output.clone()
        grad_input[x.abs() > 1] = 0
        return grad_input


def binarize(x):
    return BinarySign.apply(x)
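
# Quick illustration of the STE behaviour (hedged example, not executed by this script):
# >>> x = torch.tensor([0.5, -2.0], requires_grad=True)
# >>> binarize(x).sum().backward()
# >>> x.grad   # tensor([1., 0.]) -- gradient passes where |x| <= 1 and is zeroed elsewhere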


class BinaryLinear(nn.Module):
    """Linear layer with binary weights (-1/+1)"""

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Full-precision "shadow" weights; they are binarized on every forward pass.
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None

    def forward(self, x):
        binary_weight = binarize(self.weight)
        # Per-layer scaling factor (mean |w|, as in XNOR-Net) so the binarized weights
        # keep roughly the same magnitude as the shadow weights.
        alpha = self.weight.abs().mean()
        out = F.linear(x, binary_weight * alpha, self.bias)
        return out
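

# The docstring promises weights "stored as bits"; during training they live as float shadow
# parameters. A hedged sketch (pack_signs is not called anywhere in this script) of how a
# trained BinaryLinear weight could be packed to one bit per value for deployment:
def pack_signs(weight: torch.Tensor) -> torch.Tensor:
    """Pack the sign bits of a weight tensor into uint8 (bit = 1 for w >= 0, i.e. +1)."""
    bits = (weight >= 0).to(torch.uint8).flatten()
    pad = (-bits.numel()) % 8  # pad to a multiple of 8 bits
    if pad:
        bits = torch.cat([bits, bits.new_zeros(pad)])
    powers = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8, device=bits.device)
    return (bits.view(-1, 8) * powers).sum(dim=1).to(torch.uint8)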


class BinaryAttention(nn.Module):
    """Multi-head attention with binary QKV and output projections"""

    def __init__(self, d, h):
        super().__init__()
        self.h, self.dk = h, d // h
        self.q_proj = BinaryLinear(d, d)
        self.k_proj = BinaryLinear(d, d)
        self.v_proj = BinaryLinear(d, d)
        self.out_proj = BinaryLinear(d, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # (B, N, D) -> (B, heads, N, dk)
        q = self.q_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k = self.k_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)
        v = self.v_proj(x).view(B, N, self.h, self.dk).transpose(1, 2)

        # Standard scaled dot-product attention; the attention math itself stays in float.
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            att = att + mask
        att = F.softmax(att, dim=-1)

        out = (att @ v).transpose(1, 2).reshape(B, N, D)
        return self.out_proj(out)


class BinaryMLP(nn.Module):
    """MLP with binary weights"""

    def __init__(self, d):
        super().__init__()
        self.fc1 = BinaryLinear(d, d * 4)
        self.fc2 = BinaryLinear(d * 4, d)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        return self.fc2(x)


class BinaryBlock(nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention and LayerNorm -> MLP, each with a residual."""

    def __init__(self, d, h):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = BinaryAttention(d, h)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = BinaryMLP(d)

    def forward(self, x, mask):
        x = x + self.attn(self.ln1(x), mask)
        return x + self.mlp(self.ln2(x))


class PureBinaryTransformer(nn.Module):
    """
    Transformer where:
    - Input vocab = 2 (bits)
    - All linear weights are binary (-1/+1)
    """

    def __init__(self, cfg):
        super().__init__()
        d, L, h = cfg["d"], cfg["layers"], cfg["heads"]

        # Only two token embeddings: one for bit 0, one for bit 1.
        self.emb = nn.Embedding(2, d)

        self.blocks = nn.ModuleList([BinaryBlock(d, h) for _ in range(L)])

        self.ln = nn.LayerNorm(d)
        self.head = BinaryLinear(d, 2)

    def forward(self, x):
        B, N = x.shape
        # Causal mask: large negative values above the diagonal keep each bit from attending to the future.
        mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9

        h = self.emb(x)
        for block in self.blocks:
            h = block(h, mask)

        return self.head(self.ln(h))

    def count_params(self):
        return sum(p.numel() for p in self.parameters())

    def count_binary_params(self):
        """Count params that are binarized"""
        count = 0
        for module in self.modules():
            if isinstance(module, BinaryLinear):
                count += module.weight.numel()
        return count
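

# Hedged smoke test (assumes the CONFIG above; not executed by this script):
# >>> m = PureBinaryTransformer(CONFIG)
# >>> bits = torch.randint(0, 2, (1, 16))
# >>> m(bits).shape   # torch.Size([1, 16, 2]) -- one pair of logits per input bit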


def byte_to_bits(byte_val):
    """Split a byte (0-255) into its 8 bits, most significant bit first."""
    return [(byte_val >> (7 - i)) & 1 for i in range(8)]
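
# e.g. byte_to_bits(0b10110001) == [1, 0, 1, 1, 0, 0, 0, 1]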


class BinaryTrainer:
    def __init__(self, model, lr=LR):
        self.model = model.to(DEVICE)
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
        self.ctx_size = CONFIG["ctx"]
        # Sliding window over the most recent ctx+1 bits (inputs plus their next-bit targets).
        self.buffer = deque(maxlen=self.ctx_size + 1)

        self.bits_seen = 0
        self.bytes_seen = 0
        self.total_loss = 0.0
        self.updates = 0
        self.start_time = time.time()

    def ingest_byte(self, byte_val):
        bits = byte_to_bits(byte_val)
        for bit in bits:
            self.buffer.append(bit)
            self.bits_seen += 1

            if len(self.buffer) >= UPDATE_EVERY + 1 and self.bits_seen % UPDATE_EVERY == 0:
                self._update()

        self.bytes_seen += 1

        if self.bits_seen % PRINT_EVERY == 0:
            self._print_stats()

        if self.bytes_seen % 500000 == 0 and self.bytes_seen > 0:
            self._save()

    def _update(self):
        tokens = list(self.buffer)
        x = torch.tensor(tokens[:-1], device=DEVICE, dtype=torch.long).unsqueeze(0)
        y = torch.tensor(tokens[1:], device=DEVICE, dtype=torch.long).unsqueeze(0)

        self.model.train()
        logits = self.model(x)
        # Only the newest UPDATE_EVERY positions contribute to the loss;
        # the older part of the window just provides context.
        loss = F.cross_entropy(
            logits[:, -UPDATE_EVERY:].reshape(-1, 2),
            y[:, -UPDATE_EVERY:].reshape(-1),
        )

        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()

        self.total_loss += loss.item()
        self.updates += 1

    def _print_stats(self):
        elapsed = time.time() - self.start_time
        bytes_per_sec = self.bytes_seen / elapsed if elapsed > 0 else 0
        avg_loss = self.total_loss / max(1, self.updates)

        # Cross-entropy is in nats; dividing by ln(2) converts it to bits per input bit.
        # "compression" is how far below 1 bit/bit the model sits, e.g. 0.45 nats
        # -> ~0.65 bits/bit -> ~35% compression.
        entropy = avg_loss / math.log(2)
        compression = (1.0 - entropy) * 100

        print(f"[{elapsed:.0f}s] {self.bytes_seen/1000:.1f}KB | {bytes_per_sec/1000:.2f} KB/s | "
              f"loss={avg_loss:.4f} | entropy={entropy:.3f} | compression={compression:.1f}%", flush=True)

    def _save(self):
        avg_loss = self.total_loss / max(1, self.updates)
        kb = self.bytes_seen // 1000
        ckpt = {
            "model": self.model.state_dict(),
            "bits": self.bits_seen,
            "bytes": self.bytes_seen,
            "loss": avg_loss,
        }
        torch.save(ckpt, f"/workspace/purebit_ckpt_{kb}kb.pt")
        print(f"[SAVED] purebit_ckpt_{kb}kb.pt", flush=True)


def main():
    print("PURE BINARY TRANSFORMER - BITS ALL THE WAY DOWN", flush=True)
    print(f"Config: {CONFIG}", flush=True)
    print(f"Device: {DEVICE}", flush=True)

    model = PureBinaryTransformer(CONFIG)
    total_params = model.count_params()
    binary_params = model.count_binary_params()

    print(f"Total Parameters: {total_params:,} ({total_params/1e6:.2f}M)", flush=True)
    print(f"Binary Parameters: {binary_params:,} ({binary_params/total_params*100:.1f}%)", flush=True)
    print("Vocab: 2 (input bits)", flush=True)
    print("Weights: BINARY (-1/+1)", flush=True)
    print("", flush=True)
    print("🔥 BITS IN, BITS WEIGHTS, BITS OUT 🔥", flush=True)

    trainer = BinaryTrainer(model)

    print("Listening for bytes...", flush=True)

    # Stream bytes from stdin one at a time; training happens online inside ingest_byte.
    while True:
        byte = sys.stdin.buffer.read(1)
        if not byte:
            break
        trainer.ingest_byte(byte[0])

    print(f"Done. {trainer.bytes_seen:,} bytes = {trainer.bits_seen:,} bits", flush=True)


if __name__ == "__main__":
    main()