binary-transformers / bit_trainer.py
OpenTransformer's picture
Upload folder using huggingface_hub
9d43dda verified
#!/usr/bin/env python3
"""
BIT-LEVEL TRANSFORMER - The Ultimate Zero-Overhead Model
Vocab = 2 (just 0 and 1)
No tokenization. No bytes. Pure binary.
Each byte becomes 8 tokens (bits).
Model learns ALL structure from raw bits.
"""
import sys
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# Permit TF32 for float32 CUDA matmuls (trades a little precision for speed).
torch.backends.cuda.matmul.allow_tf32 = True

# Hyper-parameters of the bit-level model.
CONFIG = dict(
    d=768,       # model width (GPT-2 small size)
    layers=12,   # number of transformer blocks
    heads=12,    # attention heads per block
    vocab=2,     # alphabet size: the two bit values 0 and 1
    ctx=4096,    # context length in bits (= 512 bytes)
)
LR = 3e-4               # AdamW learning rate used by BitTrainer
UPDATE_EVERY = 2048     # bits between optimizer steps (256 bytes' worth)
PRINT_EVERY = 100000    # bits between progress reports
class BitAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and an optional additive mask.

    Args:
        d: model width; split evenly across ``h`` heads of size ``d // h``.
        h: number of attention heads.
    """

    def __init__(self, d, h):
        super().__init__()
        self.h = h
        self.dk = d // h
        # Construction order (qkv, then proj) fixes parameter init order.
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, N, D); ``mask`` is added to the raw scores."""
        batch, seq, width = x.shape
        packed = self.qkv(x).view(batch, seq, 3, self.h, self.dk)
        # -> (3, B, heads, N, dk), then split into q, k, v.
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dk)
        if mask is not None:
            scores = scores + mask
        weights = scores.softmax(dim=-1)
        # Merge heads back to (B, N, D) before the output projection.
        mixed = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq, width)
        return self.proj(mixed)
class BitBlock(nn.Module):
    """Pre-norm transformer block: attention then a 4x-wide GELU MLP, each in a residual.

    Args:
        d: model width.
        h: number of attention heads.
    """

    def __init__(self, d, h):
        super().__init__()
        # Keep this construction order: it determines parameter init order.
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        self.attn = BitAttention(d, h)
        hidden = 4 * d
        self.ff = nn.Sequential(nn.Linear(d, hidden), nn.GELU(), nn.Linear(hidden, d))

    def forward(self, x, mask):
        """Apply the block; ``mask`` is passed straight through to attention."""
        attended = self.attn(self.ln1(x), mask)
        x = x + attended
        fed_forward = self.ff(self.ln2(x))
        return x + fed_forward
class BitTransformer(nn.Module):
    """Decoder-only transformer over a tiny alphabet (by default just the bits 0 and 1).

    Args:
        cfg: mapping with ``d`` (width), ``layers`` (number of blocks),
            ``heads`` (attention heads) and optionally ``vocab`` (alphabet
            size).  ``vocab`` defaults to 2 when the key is absent.

    NOTE(review): no positional embedding is used; the only position-dependent
    signal in ``forward`` is the causal mask.
    """

    def __init__(self, cfg):
        super().__init__()
        d, L, h = cfg["d"], cfg["layers"], cfg["heads"]
        # Honour the configured alphabet size instead of silently hard-coding 2.
        vocab = cfg.get("vocab", 2)
        self.emb = nn.Embedding(vocab, d)  # one learned vector per symbol
        self.blocks = nn.ModuleList([BitBlock(d, h) for _ in range(L)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, vocab, bias=False)  # next-symbol logits

    def forward(self, x):
        """Return logits of shape (B, N, vocab) for integer tokens ``x`` of shape (B, N)."""
        _, N = x.shape
        # Additive causal mask: -1e9 strictly above the diagonal, so position i
        # only attends to positions j <= i.
        mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
        h = self.emb(x)
        for block in self.blocks:
            h = block(h, mask)
        return self.head(self.ln(h))

    def count_params(self):
        """Return the total number of parameters in the model."""
        return sum(p.numel() for p in self.parameters())
def byte_to_bits(byte_val):
    """Return the bits of ``byte_val`` as a list of 8 ints, most significant bit first.

    Only bit positions 7 down to 0 are read; higher bits are ignored.
    """
    return [(byte_val >> shift) & 1 for shift in range(7, -1, -1)]
def bits_to_byte(bits):
    """Pack the first (up to) 8 bits, MSB first, back into an int.

    Shorter inputs fill the high positions (``[1]`` -> 128); anything past
    the eighth element is ignored.
    """
    packed = 0
    for shift, bit in zip(range(7, -1, -1), bits[:8]):
        packed |= bit << shift
    return packed
class BitTrainer:
    """Online trainer that expands a byte stream into bits and learns next-bit prediction.

    Bits are kept in a sliding window of ``CONFIG["ctx"] + 1`` entries.  Every
    ``UPDATE_EVERY`` bits (once the window holds more than ``UPDATE_EVERY``
    bits) one AdamW step is taken.  Progress is printed every ``PRINT_EVERY``
    bits and a checkpoint is written every ``SAVE_EVERY_BYTES`` bytes.

    Args:
        model: network to train; it is moved to ``DEVICE``.
        lr: AdamW learning rate.
        save_dir: directory that checkpoints are written to (created on demand).
    """

    SAVE_EVERY_BYTES = 500000

    def __init__(self, model, lr=LR, save_dir="/workspace"):
        self.model = model.to(DEVICE)
        self.opt = torch.optim.AdamW(self.model.parameters(), lr=lr)
        self.ctx_size = CONFIG["ctx"]
        # One extra slot so a full window yields ctx_size (input, target) pairs.
        self.buffer = deque(maxlen=self.ctx_size + 1)
        self.save_dir = save_dir
        self.bits_seen = 0
        self.bytes_seen = 0
        self.total_loss = 0.0  # sum of per-update losses (nats) since start
        self.updates = 0
        self.start_time = time.time()

    def ingest_byte(self, byte_val):
        """Split one byte into 8 bits (MSB first), buffer them, and train/report/save on schedule."""
        bits_before = self.bits_seen
        for bit in byte_to_bits(byte_val):
            self.buffer.append(bit)
            self.bits_seen += 1
            # Need more than UPDATE_EVERY bits to form UPDATE_EVERY (input, target) pairs.
            if len(self.buffer) > UPDATE_EVERY and self.bits_seen % UPDATE_EVERY == 0:
                self._update()
        self.bytes_seen += 1
        # Report whenever a PRINT_EVERY boundary was crossed during this byte.
        # Checking only `bits_seen % PRINT_EVERY == 0` here (at byte granularity)
        # mis-schedules reports when PRINT_EVERY is not a multiple of 8.
        if self.bits_seen // PRINT_EVERY != bits_before // PRINT_EVERY:
            self._print_stats()
        if self.bytes_seen % self.SAVE_EVERY_BYTES == 0:
            self._save()

    def _update(self):
        """Take one AdamW step on next-bit prediction over the buffered window.

        The model sees the whole window, but the loss only covers the newest
        ``UPDATE_EVERY`` targets (after the first step, exactly the bits that
        arrived since the previous step).
        """
        bits = list(self.buffer)
        x = torch.tensor(bits[:-1], device=DEVICE, dtype=torch.long).unsqueeze(0)
        y = torch.tensor(bits[1:], device=DEVICE, dtype=torch.long).unsqueeze(0)
        self.model.train()
        logits = self.model(x)
        loss = F.cross_entropy(
            logits[:, -UPDATE_EVERY:].reshape(-1, 2),
            y[:, -UPDATE_EVERY:].reshape(-1)
        )
        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()
        self.total_loss += loss.item()
        self.updates += 1

    def _print_stats(self):
        """Print throughput and the loss averaged over all updates since start."""
        elapsed = time.time() - self.start_time
        bytes_per_sec = self.bytes_seen / elapsed if elapsed > 0 else 0
        if self.updates == 0:
            # Without any step, a loss of 0 would misleadingly show 100% compression.
            print(f"[{elapsed:.0f}s] {self.bytes_seen/1000:.1f}KB | {bytes_per_sec/1000:.1f} KB/s | "
                  f"loss=n/a (no updates yet)", flush=True)
            return
        avg_loss = self.total_loss / self.updates
        # Cross-entropy is in nats; dividing by ln 2 gives bits per bit.
        # A fair coin scores 1.0, so lower values mean the model found structure.
        entropy = avg_loss / math.log(2)
        compression = (1.0 - entropy) * 100  # % below the 1 bit/bit of a coin flip
        print(f"[{elapsed:.0f}s] {self.bytes_seen/1000:.1f}KB | {bytes_per_sec/1000:.1f} KB/s | "
              f"loss={avg_loss:.4f} | entropy={entropy:.3f} bit/bit | "
              f"compression={compression:.1f}%", flush=True)

    def _save(self):
        """Write weights and counters to ``<save_dir>/bit_ckpt_<KB>kb.pt``."""
        import os

        avg_loss = self.total_loss / max(1, self.updates)
        kb = self.bytes_seen // 1000
        name = f"bit_ckpt_{kb}kb.pt"
        ckpt = {
            "model": self.model.state_dict(),
            "bits": self.bits_seen,
            "bytes": self.bytes_seen,
            "loss": avg_loss,
        }
        os.makedirs(self.save_dir, exist_ok=True)
        torch.save(ckpt, os.path.join(self.save_dir, name))
        print(f"[SAVED] {name}", flush=True)
def main():
    """Train a BitTransformer online on raw bytes read from stdin until EOF.

    Bytes are read in fixed-size chunks and fed one at a time to
    ``BitTrainer.ingest_byte``.  At EOF a final checkpoint is written, so
    training done after the last periodic checkpoint is not lost.
    """
    print("BIT-LEVEL TRANSFORMER - Vocab = 2 (just 0 and 1)", flush=True)
    print(f"Config: {CONFIG}", flush=True)
    print(f"Device: {DEVICE}", flush=True)
    model = BitTransformer(CONFIG)
    params = model.count_params()
    print(f"Parameters: {params:,} ({params/1e6:.2f}M)", flush=True)
    print("Vocab: 2 (literally just 0 and 1)", flush=True)
    print("Each byte = 8 bit tokens", flush=True)
    trainer = BitTrainer(model)
    print("Listening for bytes (FAST batch mode)...", flush=True)
    # Read in large chunks for speed: 8 KB = 65,536 bits per read.
    CHUNK_SIZE = 8192
    stream = sys.stdin.buffer
    while True:
        chunk = stream.read(CHUNK_SIZE)
        if not chunk:
            break
        for byte in chunk:
            trainer.ingest_byte(byte)
    print(f"Stream ended. Total: {trainer.bytes_seen:,} bytes = {trainer.bits_seen:,} bits", flush=True)
    if trainer.bytes_seen > 0:
        # Persist the final weights; otherwise everything learned since the
        # last periodic checkpoint is discarded when the process exits.
        trainer._save()
# Run the stdin training loop only when executed as a script, not on import.
if __name__ == "__main__":
    main()