# Provenance: WCNegentropy — "Refined BitTransformerLM: Organized codebase
# with best practices" (commit 2f70b79, verified).
"""Command-line interface entry points for BitTransformerLM."""
import sys
import logging
from pathlib import Path
from typing import Optional
import torch
from .cli_standards import create_training_parser, create_inference_parser, BitTransformerCLI
from .config import (
ExperimentConfig,
ModelConfig,
TrainingConfig,
SafetyConfig,
DataConfig,
get_small_config,
get_medium_config,
get_large_config,
)
from .model import BitTransformerLM, diffusion_inference
from .training import train_loop
from .bit_io import text_to_bits, bits_to_text, infer_text
from .utils import save_model, load_model
from .dashboard_app import run_dashboard
def setup_logging(level: str = "INFO") -> None:
    """Configure root logging to stream to stdout at the given level.

    Args:
        level: Logging level name (case-insensitive), e.g. "INFO" or "DEBUG".
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=getattr(logging, level.upper()),
        handlers=[stdout_handler],
    )
def train_cli() -> None:
    """CLI entry point for training BitTransformerLM models.

    Builds an ExperimentConfig from the selected size preset (if any),
    overrides it with command-line arguments, trains the model on a
    synthetic random-bit dataset, and saves the final weights under the
    configured output directory.

    Exits with status 1 if training or saving fails.
    """
    parser = create_training_parser()
    args = parser.parse_args()
    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    # Start from a preset configuration when one was requested.
    preset_factories = {
        "small": get_small_config,
        "medium": get_medium_config,
        "large": get_large_config,
    }
    config = preset_factories.get(args.model_size, ExperimentConfig)()

    # Override preset/default values with explicit command-line arguments.
    config.model.d_model = args.d_model
    config.model.nhead = args.num_heads
    config.model.num_layers = args.num_layers
    config.model.max_seq_len = args.max_seq_len
    config.training.epochs = args.epochs
    config.training.batch_size = args.batch_size
    config.training.learning_rate = args.learning_rate
    config.training.weight_decay = args.weight_decay
    config.training.gradient_clip_val = args.grad_clip
    config.training.warmup_steps = args.warmup_steps
    config.training.amp = args.use_amp
    config.training.compile_model = args.compile_model
    config.safety.k_threshold = args.min_negentropy
    config.safety.c_threshold = args.max_complexity
    config.safety.s_threshold = args.min_symbiosis
    config.safety.enable_safety = args.enable_safety_gates
    config.data.dataset_path = Path(args.input_path) if args.input_path else None
    config.data.max_sequence_length = args.seq_length
    config.data.num_workers = args.num_workers
    config.output_dir = Path(args.output_path)
    config.seed = args.seed

    # Prefer GPU when available.
    config.device = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info("Starting training with config: %s", config.experiment_name)
    logger.info(
        "Model: %dd, %dL, %dH",
        config.model.d_model,
        config.model.num_layers,
        config.model.nhead,
    )
    logger.info("Device: %s", config.device)

    # Create the model and move it to the selected device.
    model = BitTransformerLM(**config.model.to_dict())
    model = model.to(config.device)

    # NOTE(review): training data is synthetic random bits, regardless of
    # config.data.dataset_path — presumably a demonstration placeholder.
    logger.info("Creating synthetic training data...")
    torch.manual_seed(config.seed)
    data = torch.randint(0, 2, (args.dataset_size, config.data.max_sequence_length))

    logger.info("Starting training...")
    try:
        train_loop(
            model,
            data,
            epochs=config.training.epochs,
            batch_size=config.training.batch_size,
            amp=config.training.amp,
            compile_model=config.training.compile_model,
            log=True,
        )
        # Fix: ensure the output directory exists before saving —
        # previously save_model could fail on a missing directory.
        config.output_dir.mkdir(parents=True, exist_ok=True)
        save_path = config.output_dir / "model_final.pt"
        save_model(model, save_path)
        logger.info("Model saved to %s", save_path)
    except Exception as e:
        # logger.exception records the full traceback, unlike logger.error.
        logger.exception("Training failed: %s", e)
        sys.exit(1)
def infer_cli() -> None:
    """CLI entry point for BitTransformerLM inference.

    Loads trained weights, then generates text from --prompt using either
    diffusion inference (--use-diffusion) or autoregressive sampling,
    optionally guarded by safety-metric floors.

    Exits with status 1 if the weights file is missing or generation fails.
    """
    parser = create_inference_parser()
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument("--max-tokens", type=int, default=50, help="Maximum tokens to generate")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    parser.add_argument("--use-diffusion", action="store_true", help="Use diffusion mode")
    args = parser.parse_args()
    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    # Fail fast if the checkpoint is missing.
    if not Path(args.weights_path).exists():
        logger.error("Model weights not found at %s", args.weights_path)
        sys.exit(1)

    logger.info("Loading model from %s", args.weights_path)
    model = load_model(args.weights_path)
    model.eval()

    # Prefer GPU when available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    logger.info("Model loaded on %s", device)
    logger.info("Prompt: %s", args.prompt)

    try:
        if args.use_diffusion:
            logger.info("Using diffusion inference mode")
            prompt_bits = text_to_bits(args.prompt)
            # 9 bits per generated token is an approximation of the
            # bit-level encoding width.
            length = len(prompt_bits) + args.max_tokens * 9
            generated_bits = diffusion_inference(
                model,
                length=length,
                steps=args.diffusion_steps,
                schedule=args.noise_schedule,
            )
            result = bits_to_text(generated_bits[0].tolist())
        elif args.enable_safety_gates:
            # Autoregressive inference with safety-metric floors.
            # NOTE(review): --max-tokens and --temperature are not forwarded
            # here — confirm against infer_text's signature.
            result = infer_text(
                model,
                args.prompt,
                c_floor=args.max_complexity,
                s_floor=args.min_symbiosis,
            )
        else:
            # Plain sampling without safety gates.
            from .bit_io import sample_text
            result = sample_text(
                model,
                args.prompt,
                max_new_tokens=args.max_tokens,
                temperature=args.temperature,
            )
        print(f"\nGenerated text:\n{result}")
    except Exception as e:
        # logger.exception records the full traceback, unlike logger.error.
        logger.exception("Inference failed: %s", e)
        sys.exit(1)
def dashboard_cli() -> None:
    """CLI entry point for BitTransformerLM dashboard.

    Parses host/port/share options and launches the interactive dashboard.
    Exits with status 1 if the dashboard fails to start.
    """
    parser = BitTransformerCLI.create_standard_parser(
        "BitTransformerLM Dashboard",
        ["io"]
    )
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Dashboard host")
    parser.add_argument("--port", type=int, default=7860, help="Dashboard port")
    parser.add_argument("--share", action="store_true", help="Create public link")
    args = parser.parse_args()
    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    logger.info("Starting BitTransformerLM dashboard on %s:%s", args.host, args.port)
    try:
        run_dashboard(
            host=args.host,
            port=args.port,
            share=args.share,
        )
    except Exception as e:
        # logger.exception records the full traceback, unlike logger.error.
        logger.exception("Dashboard failed to start: %s", e)
        sys.exit(1)
if __name__ == "__main__":
    # Dispatch to the matching CLI based on the invoked script's name.
    import os

    invoked = os.path.basename(sys.argv[0])
    dispatch = (
        ("train", train_cli),
        ("infer", infer_cli),
        ("dashboard", dashboard_cli),
    )
    for keyword, entry_point in dispatch:
        if keyword in invoked:
            entry_point()
            break
    else:
        # No keyword matched: show usage and exit with an error status.
        print("Available commands:")
        print(" bit-transformer-train - Train a BitTransformerLM model")
        print(" bit-transformer-infer - Run inference with a trained model")
        print(" bit-transformer-dashboard - Launch interactive dashboard")
        sys.exit(1)