# BitTransformerLM/scripts/tools/integration_schedule.py
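"""Progressive scale-up schedule for BitTransformerLM.

Trains a small model on WikiText-2 encoded as bits and, whenever validation
loss stops improving, cycles through three scaling strategies (double layers,
double width, double context length), each followed by a short warm-up phase.
Optionally runs extra plateau epochs at the final size, quantizes the model
(dynamic or QAT), gates it through hil_safe_inference, collapses it into a
distilled submodel, and draws a diffusion sample.

Example (illustrative parameter values, not a prescribed configuration):

    results, collapsed = integration_schedule(
        steps=4, max_len=64, dataset_size=128, diffusion=False
    )
"""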
import os
import time
import math
from itertools import cycle
from typing import Optional
import torch
import torch.nn.functional as F
from bit_transformer import (
BitTransformerLM,
text_to_bits,
quantize_dynamic,
prepare_qat_fx,
convert_qat_fx,
hil_safe_inference,
collapse_submodel,
diffusion_inference,
TelemetrySynthesizer,
save_distilled_model,
)
from bit_transformer.training import train_loop as train
from bit_transformer.optimization import configure_optimizer, adjust_learning_rate
from bit_transformer.utils import save_model, load_model, set_dropout
from bit_transformer.torch_utils import cpu_autocast
def lines_to_tensor(lines, max_len):
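    """Encode each text line as bits, truncated or zero-padded to ``max_len``."""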
seqs = []
for text in lines:
bits = text_to_bits(text)[:max_len]
if len(bits) < max_len:
bits.extend([0] * (max_len - len(bits)))
seqs.append(bits)
return torch.tensor(seqs, dtype=torch.long)
def load_wikitext(dataset_size=128, max_len=64):
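    """Load WikiText-2 train/validation lines as fixed-length bit tensors.

    Falls back to random bits if the ``datasets`` package or the dataset
    itself is unavailable.
    """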
try:
from datasets import load_dataset
ds = load_dataset("wikitext", "wikitext-2-raw-v1")
train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
valid_split = max(1, dataset_size // 4)
valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_split]
train = lines_to_tensor(train_lines, max_len)
valid = lines_to_tensor(valid_lines, max_len)
return train, valid, train_lines
except Exception as e:
print("Dataset load failed, using random bits", e)
        train = torch.randint(0, 2, (dataset_size, max_len), dtype=torch.long)
        # Mirror the real validation split size in the random fallback.
        valid = torch.randint(0, 2, (max(1, dataset_size // 4), max_len), dtype=torch.long)
        return train, valid, ["" for _ in range(len(train))]
def _warmup(
model: BitTransformerLM,
data: torch.Tensor,
steps: int = 5,
freeze_old: bool = False,
old_layers: int = 0,
*,
diffusion: bool = False,
curriculum: bool = False,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Run a short warm-up loop after expansion."""
model.train()
set_dropout(model, 0.1)
if freeze_old:
for idx, layer in enumerate(model.layers):
if idx < old_layers:
for p in layer.parameters():
p.requires_grad_(False)
if optimizer is None or scheduler is None:
optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=steps)
it = iter(data.split(8))
for idx in range(steps):
try:
batch = next(it)
except StopIteration:
it = iter(data.split(8))
batch = next(it)
if diffusion:
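            # Denoising objective: flip each bit with probability p and predict the clean batch.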
p = 0.5 * (1 - idx / max(1, steps - 1)) if curriculum else 0.5
noise = (torch.rand_like(batch.float()) < p).long()
noisy = batch ^ noise
logits, _ = model(noisy, causal=False)
pred = logits.reshape(-1, 2)
target = batch.reshape(-1)
else:
logits, _ = model(batch)
pred = logits[:, :-1, :].reshape(-1, 2)
target = batch[:, 1:].reshape(-1)
loss = F.cross_entropy(pred, target)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
for p in model.parameters():
p.requires_grad_(True)
model.eval()
set_dropout(model, 0.0)
def integration_schedule(
steps: int = 10,
max_len: int = 64,
dataset_size: int = 128,
*,
weights_path: str = "weights/model.pt.gz",
plateau_steps: int = 0,
collapsed_path: str | None = None,
epochs_per_step: int = 2,
extra_steps: int = 3,
collapse: bool = True,
diffusion: bool = False,
noise_schedule: str = "linear",
diffusion_steps: int = 8,
diffusion_curriculum: bool = False,
use_checkpoint: bool = True,
reversible: bool = True,
improve_thresh: float = 0.01,
qat: bool = False,
):
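    """Run the progressive training/scaling schedule.

    Returns ``(results, collapsed)`` where ``results`` is a list of
    ``(step, val_loss, K, C, S)`` tuples and ``collapsed`` is the distilled
    submodel (``None`` unless ``collapse`` is True).
    """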
start = time.time()
train_bits, valid_bits, train_lines = load_wikitext(dataset_size, max_len)
    model_kwargs = dict(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=max_len,
        use_act=True,
        act_threshold=0.7,
        reversible=reversible,
        chunk_size=max_len,
        use_autocast=True,
        use_checkpoint=use_checkpoint,
    )
    if os.path.exists(weights_path):
        try:
            model = load_model(weights_path)
            print(f"Loaded model from {weights_path}")
        except Exception as e:
            print("Failed to load weights, initializing new model", e)
            model = BitTransformerLM(**model_kwargs)
    else:
        model = BitTransformerLM(**model_kwargs)
if qat:
model = prepare_qat_fx(model)
results = []
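    # Scaling strategies applied in round-robin order whenever validation loss plateaus.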
scale_cycle = cycle(["layers", "width", "context"])
base_lr = 1e-3
prev_val_loss: Optional[float] = None
for step in range(steps):
model.train()
set_dropout(model, 0.1)
opt, sched = configure_optimizer(
model, lr=base_lr, total_steps=epochs_per_step
)
train(
model,
train_bits,
epochs=epochs_per_step,
extra_steps=extra_steps,
compress_prob=0.0 if diffusion else 1.0,
log=True,
diffusion=diffusion,
diffusion_curriculum=diffusion_curriculum,
optimizer=opt,
scheduler=sched,
)
model.eval()
set_dropout(model, 0.0)
with torch.no_grad():
logits, telemetry = model(valid_bits, causal=not diffusion)
if diffusion:
pred = logits.reshape(-1, 2)
target = valid_bits.reshape(-1)
else:
pred = logits[:, :-1, :].reshape(-1, 2)
target = valid_bits[:, 1:].reshape(-1)
val_loss = F.cross_entropy(pred, target).item()
k = telemetry["negentropy_logits"].mean().item()
c = telemetry["lz_complexity_logits"].mean().item()
s = telemetry["symbiosis_score"].mean().item()
print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
results.append((step, val_loss, k, c, s))
if prev_val_loss is not None and prev_val_loss - val_loss < improve_thresh:
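            # Improvement fell below the threshold: decay the LR by 1/sqrt(2) and grow the model.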
strategy = next(scale_cycle)
base_lr = adjust_learning_rate(opt, 1 / math.sqrt(2))
if strategy == "layers":
old_layers = model.num_layers
model = model.double_layers()
warm_opt, warm_sched = configure_optimizer(
model, lr=base_lr, total_steps=100
)
_warmup(
model,
train_bits,
steps=100,
freeze_old=True,
old_layers=old_layers,
diffusion=diffusion,
curriculum=diffusion_curriculum,
optimizer=warm_opt,
scheduler=warm_sched,
)
elif strategy == "width":
model = model.double_width()
warm_opt, warm_sched = configure_optimizer(
model, lr=base_lr, total_steps=100
)
_warmup(
model,
train_bits,
steps=100,
diffusion=diffusion,
curriculum=diffusion_curriculum,
optimizer=warm_opt,
scheduler=warm_sched,
)
else:
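                # "context" strategy: double the maximum sequence length and reload the data.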
max_len *= 2
train_bits, valid_bits, train_lines = load_wikitext(
dataset_size, max_len
)
model = model.double_length()
warm_opt, warm_sched = configure_optimizer(
model, lr=base_lr, total_steps=100
)
_warmup(
model,
train_bits,
steps=100,
diffusion=diffusion,
curriculum=diffusion_curriculum,
optimizer=warm_opt,
scheduler=warm_sched,
)
prev_val_loss = val_loss
if time.time() - start > 8 * 60:
print("Time limit reached")
break
# optional plateau phase at final size
for p in range(plateau_steps):
model.train()
set_dropout(model, 0.1)
train(
model,
train_bits,
epochs=epochs_per_step,
extra_steps=extra_steps,
compress_prob=0.0 if diffusion else 1.0,
log=True,
diffusion=diffusion,
diffusion_curriculum=diffusion_curriculum,
)
model.eval()
set_dropout(model, 0.0)
with torch.no_grad():
logits, telemetry = model(valid_bits, causal=not diffusion)
if diffusion:
pred = logits.reshape(-1, 2)
target = valid_bits.reshape(-1)
else:
pred = logits[:, :-1, :].reshape(-1, 2)
target = valid_bits[:, 1:].reshape(-1)
val_loss = F.cross_entropy(pred, target).item()
k = telemetry["negentropy_logits"].mean().item()
c = telemetry["lz_complexity_logits"].mean().item()
s = telemetry["symbiosis_score"].mean().item()
idx = steps + p
print(
f"Plateau {p} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}"
)
results.append((idx, val_loss, k, c, s))
if time.time() - start > 8 * 60:
print("Time limit reached")
break
# final validation after last step
model.eval()
set_dropout(model, 0.0)
with torch.no_grad():
logits, telemetry = model(valid_bits, causal=not diffusion)
if diffusion:
pred = logits.reshape(-1, 2)
target = valid_bits.reshape(-1)
else:
pred = logits[:, :-1, :].reshape(-1, 2)
target = valid_bits[:, 1:].reshape(-1)
val_loss = F.cross_entropy(pred, target).item()
k = telemetry["negentropy_logits"].mean().item()
c = telemetry["lz_complexity_logits"].mean().item()
s = telemetry["symbiosis_score"].mean().item()
print(f"Final validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
results.append((steps + plateau_steps, val_loss, k, c, s))
# persist final model weights for future runs
save_model(model, weights_path)
input_bits = valid_bits[:1]
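    # Quantize for inference: convert the QAT graph if enabled, otherwise run one
    # forward pass under cpu_autocast and apply post-training dynamic quantization.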
if qat:
qmodel = convert_qat_fx(model)
else:
with cpu_autocast():
model(input_bits)
qmodel = quantize_dynamic(model)
qmodel.eval()
try:
hil_safe_inference(
qmodel,
input_bits,
c_floor=0.3,
s_floor=0.5,
causal=not diffusion,
strict=not diffusion,
)
except RuntimeError as e:
print("Safety gate triggered", e)
collapsed = None
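    # Optionally distill the trained model into a smaller submodel built from
    # telemetry-clustered representative sequences.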
if collapse:
synth = TelemetrySynthesizer(n_clusters=8)
reps = synth.cluster_sequences(model, train_bits[:64])
floors = {"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5}
collapsed, metrics = collapse_submodel(
reps,
target_params=dict(
d_model=16,
nhead=4,
num_layers=1,
dim_feedforward=32,
max_seq_len=max_len,
),
floors=floors,
)
collapsed.eval()
with torch.no_grad():
logits, _ = collapsed(valid_bits)
pred = logits[:, :-1, :].reshape(-1, 2)
target = valid_bits[:, 1:].reshape(-1)
c_loss = F.cross_entropy(pred, target).item()
print("Collapsed model validation loss:", c_loss)
if collapsed_path is not None:
save_distilled_model(
collapsed,
collapsed_path,
{**metrics, "val_loss": c_loss},
floors=floors,
)
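    # When trained with the diffusion objective, draw one sample via iterative denoising.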
if diffusion:
sample = diffusion_inference(
model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
)
print("Diffusion sample:", sample[0].tolist())
return results, collapsed
if __name__ == "__main__":
integration_schedule()