File size: 6,874 Bytes
9d43dda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#!/usr/bin/env python3
"""
DIBIT TRANSFORMER - 2-bit tokens
Vocab = 4 (00, 01, 10, 11)
Each byte = 4 tokens (vs 8 for bit-level)
Better context efficiency while still pure binary!
"""

import sys
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# Allow TensorFloat-32 for CUDA matmuls (trades a little precision for speed).
torch.backends.cuda.matmul.allow_tf32 = True

# Model hyper-parameters for the 2-bit-token transformer.
CONFIG = dict(
    d=512,       # embedding / residual width (must be divisible by heads)
    layers=12,   # number of transformer blocks
    heads=8,     # attention heads
    vocab=4,     # one token per 2-bit value: 00, 01, 10, 11
    ctx=4096,    # context length in dibits (= 1024 bytes)
)

LR = 3e-4              # AdamW learning rate
UPDATE_EVERY = 512     # dibits between optimizer steps (128 bytes)
PRINT_EVERY = 50_000   # dibits between progress reports

class DibitAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        d: model width; must be an exact multiple of ``h``.
        h: number of attention heads (positive).

    Raises:
        ValueError: if ``h`` is not positive or ``d`` is not divisible by ``h``.
    """

    def __init__(self, d, h):
        super().__init__()
        # Fail fast: with d % h != 0, d // h truncates and the .view() in
        # forward() would later fail with an opaque shape error.
        if h <= 0 or d % h != 0:
            raise ValueError(
                f"d ({d}) must be divisible by heads ({h}), and heads must be positive"
            )
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, N, D) and return a tensor of the same shape.

        ``mask``, if given, is added to the raw (B, h, N, N) attention scores
        before the softmax, so it must broadcast to that shape (e.g. an (N, N)
        additive causal mask).
        """
        B, N, D = x.shape
        # (B, N, 3*D) -> (3, B, h, N, dk): split the fused projection per head.
        qkv = self.qkv(x).view(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            att = att + mask
        out = F.softmax(att, dim=-1) @ v
        # (B, h, N, dk) -> (B, N, D): merge heads back before the output projection.
        return self.proj(out.transpose(1, 2).reshape(B, N, D))

class DibitBlock(nn.Module):
    """Pre-LayerNorm transformer layer.

    Applies self-attention (using the caller-supplied additive ``mask``) and
    then a 4x-wide GELU feed-forward network, each on a LayerNorm'd input and
    each added back onto the residual stream.
    """

    def __init__(self, d, h):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        self.attn = DibitAttention(d, h)
        hidden = 4 * d
        self.ff = nn.Sequential(
            nn.Linear(d, hidden),
            nn.GELU(),
            nn.Linear(hidden, d),
        )

    def forward(self, x, mask):
        attended = x + self.attn(self.ln1(x), mask)
        return attended + self.ff(self.ln2(attended))

class DibitTransformer(nn.Module):
    """Decoder-only transformer over dibit tokens (vocab=4: 00, 01, 10, 11).

    Args:
        cfg: mapping with keys ``d`` (width), ``layers`` (block count),
            ``heads`` (attention heads) and optionally ``vocab`` (number of
            token ids, default 4).

    ``forward(x)`` takes a (B, N) tensor of token ids and returns (B, N, vocab)
    logits; a causal mask keeps position i from attending to positions > i.
    """

    def __init__(self, cfg):
        super().__init__()
        d, L, h = cfg["d"], cfg["layers"], cfg["heads"]
        # Size the embedding table and output head from cfg["vocab"] rather
        # than a hard-coded 4, so the config is the single source of truth.
        # 4 stays the default for configs that omit the key.
        try:
            vocab = cfg["vocab"]
        except KeyError:
            vocab = 4
        self.vocab = vocab
        self.emb = nn.Embedding(vocab, d)
        self.blocks = nn.ModuleList([DibitBlock(d, h) for _ in range(L)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, vocab, bias=False)

    def forward(self, x):
        _, N = x.shape
        # Additive causal mask: -1e9 strictly above the diagonal suppresses
        # attention to future positions; the rest adds zero to the scores.
        mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
        # NOTE(review): no positional embedding is added to the token
        # embeddings; order information reaches the model only via the causal
        # mask — confirm this is intentional.
        h = self.emb(x)
        for block in self.blocks:
            h = block(h, mask)
        return self.head(self.ln(h))

    def count_params(self):
        """Return the total number of parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())

def byte_to_dibits(byte_val):
    """Split one byte into four 2-bit tokens, most-significant pair first.

    >>> byte_to_dibits(0b11100100)  # 11 10 01 00
    [3, 2, 1, 0]
    """
    return [(byte_val >> shift) & 0b11 for shift in (6, 4, 2, 0)]

def dibits_to_byte(dibits):
    """Reassemble a byte from its first four 2-bit tokens (MSB pair first).

    Inverse of byte_to_dibits, e.g. [3, 2, 1, 0] -> 0b11100100.
    """
    acc = 0
    for position in range(4):
        acc = (acc << 2) | dibits[position]
    return acc

class DibitTrainer:
    """Online trainer that turns a byte stream into dibits and learns as it goes.

    A rolling buffer keeps the newest ``CONFIG["ctx"] + 1`` dibits. Every
    UPDATE_EVERY dibits, it takes one AdamW step whose loss covers only the
    newest UPDATE_EVERY next-token targets.

    Args:
        model: module mapping a (1, N) long tensor of dibits to (1, N, V) logits
            (a DibitTransformer in this script).
        lr: AdamW learning rate.
        ckpt_dir: directory checkpoints are written to; created if missing.
        save_every: checkpoint interval, in bytes ingested.
    """

    def __init__(self, model, lr=LR, *, ckpt_dir="/workspace", save_every=500000):
        self.model = model.to(DEVICE)
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr)
        self.ctx_size = CONFIG["ctx"]
        # One extra slot so an update can shift by one and still form
        # ctx_size (input, next-token target) pairs.
        self.buffer = deque(maxlen=self.ctx_size + 1)
        self.ckpt_dir = ckpt_dir
        self.save_every = save_every

        self.dibits_seen = 0
        self.bytes_seen = 0
        self.total_loss = 0.0
        self.updates = 0
        self.start_time = time.time()

    def ingest_byte(self, byte_val):
        """Buffer the four dibits of one byte and train/print/save on schedule."""
        dibits = byte_to_dibits(byte_val)
        for dibit in dibits:
            self.buffer.append(dibit)
            self.dibits_seen += 1

            # Need UPDATE_EVERY + 1 buffered dibits to form UPDATE_EVERY targets.
            if len(self.buffer) >= UPDATE_EVERY + 1 and self.dibits_seen % UPDATE_EVERY == 0:
                self._update()

        self.bytes_seen += 1

        # Report whenever a multiple of PRINT_EVERY was reached during this
        # byte. Testing `dibits_seen % PRINT_EVERY == 0` here (after a whole
        # byte) would only fire at multiples of lcm(PRINT_EVERY, 4).
        before = self.dibits_seen - len(dibits)
        if self.dibits_seen // PRINT_EVERY != before // PRINT_EVERY:
            self._print_stats()

        # bytes_seen >= 1 here, so no separate "> 0" guard is needed.
        if self.bytes_seen % self.save_every == 0:
            self._save()

    def _update(self):
        """Take one optimizer step on the buffered window.

        Inputs are buffer[:-1] and targets buffer[1:]; only the newest
        UPDATE_EVERY targets contribute to the loss.
        """
        # Single host->device copy, then shift by one for next-token targets.
        seq = torch.tensor(list(self.buffer), device=DEVICE, dtype=torch.long).unsqueeze(0)
        x, y = seq[:, :-1], seq[:, 1:]

        self.model.train()
        logits = self.model(x)
        loss = F.cross_entropy(
            logits[:, -UPDATE_EVERY:].reshape(-1, logits.size(-1)),
            y[:, -UPDATE_EVERY:].reshape(-1),
        )

        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()

        self.total_loss += loss.item()
        self.updates += 1

    def _print_stats(self):
        """Print throughput and the loss averaged over all updates so far."""
        elapsed = time.time() - self.start_time
        bytes_per_sec = self.bytes_seen / elapsed if elapsed > 0 else 0
        avg_loss = self.total_loss / max(1, self.updates)

        # Cross-entropy is in nats; convert to bits per dibit (uniform = 2.0),
        # then to bits per byte (4 dibits per byte; uniform = 8.0).
        entropy_per_dibit = avg_loss / math.log(2)
        bpb = entropy_per_dibit * 4
        # Savings relative to 8 bits/byte for uniformly random bytes.
        compression = (1.0 - bpb / 8) * 100

        print(f"[{elapsed:.0f}s] {self.bytes_seen/1000:.1f}KB | {bytes_per_sec/1000:.2f} KB/s | "
              f"loss={avg_loss:.4f} | bpb={bpb:.2f} | compression={compression:.1f}%", flush=True)

    def _save(self):
        """Write model weights plus stream counters to ``ckpt_dir``."""
        import os  # only needed when checkpointing

        avg_loss = self.total_loss / max(1, self.updates)
        kb = self.bytes_seen // 1000
        ckpt = {
            "model": self.model.state_dict(),
            "dibits": self.dibits_seen,
            "bytes": self.bytes_seen,
            "loss": avg_loss,
        }
        # Create the directory up front so a missing path doesn't abort a
        # long-running stream at its first checkpoint.
        os.makedirs(self.ckpt_dir, exist_ok=True)
        torch.save(ckpt, os.path.join(self.ckpt_dir, f"dibit_ckpt_{kb}kb.pt"))
        print(f"[SAVED] dibit_ckpt_{kb}kb.pt", flush=True)

def main():
    """Train a DibitTransformer online on raw bytes read from stdin until EOF."""
    print("DIBIT TRANSFORMER - Vocab = 4 (00, 01, 10, 11)", flush=True)
    print(f"Config: {CONFIG}", flush=True)
    print(f"Device: {DEVICE}", flush=True)

    model = DibitTransformer(CONFIG)
    params = model.count_params()
    print(f"Parameters: {params:,} ({params/1e6:.2f}M)", flush=True)
    print("Vocab: 4 (2-bit tokens: 00, 01, 10, 11)", flush=True)
    print("Each byte = 4 dibit tokens", flush=True)
    print(f"Context: {CONFIG['ctx']} dibits = {CONFIG['ctx']//4} bytes", flush=True)

    trainer = DibitTrainer(model)

    print("Listening for bytes (converting to dibits)...", flush=True)

    stream = sys.stdin.buffer
    # read1() returns whatever is already available (at most one raw read)
    # instead of blocking for a full chunk. A slow live stream is therefore
    # still consumed promptly, without one Python-level read() per byte.
    # Raw (unbuffered) streams lack read1, but their read(n) already returns
    # after a single system call.
    read_chunk = getattr(stream, "read1", stream.read)
    chunk_size = 1 << 16
    while True:
        chunk = read_chunk(chunk_size)
        if not chunk:
            break  # EOF
        for byte_val in chunk:
            trainer.ingest_byte(byte_val)

    print(f"Stream ended. Total: {trainer.bytes_seen:,} bytes = {trainer.dibits_seen:,} dibits", flush=True)


if __name__ == "__main__":
    main()