"""Command-line interface entry points for BitTransformerLM.""" |
|
|
|
|
|
import sys |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
|
|
|
from .cli_standards import create_training_parser, create_inference_parser, BitTransformerCLI |
|
|
from .config import ( |
|
|
ExperimentConfig, |
|
|
ModelConfig, |
|
|
TrainingConfig, |
|
|
SafetyConfig, |
|
|
DataConfig, |
|
|
get_small_config, |
|
|
get_medium_config, |
|
|
get_large_config, |
|
|
) |
|
|
from .model import BitTransformerLM, diffusion_inference |
|
|
from .training import train_loop |
|
|
from .bit_io import text_to_bits, bits_to_text, infer_text |
|
|
from .utils import save_model, load_model |
|
|
from .dashboard_app import run_dashboard |
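
# This module backs three console scripts (see the __main__ dispatch at the
# bottom of this file):
#   bit-transformer-train      - train a BitTransformerLM model
#   bit-transformer-infer      - run inference with a trained model
#   bit-transformer-dashboard  - launch the interactive dashboard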


def setup_logging(level: str = "INFO") -> None:
    """Set up logging configuration."""
    logging.basicConfig(
        level=getattr(logging, level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[
            logging.StreamHandler(sys.stdout),
        ],
    )


def train_cli() -> None:
    """CLI entry point for training BitTransformerLM models."""
    parser = create_training_parser()
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    # Pick a preset configuration by model size, falling back to the defaults.
    if args.model_size == "small":
        config = get_small_config()
    elif args.model_size == "medium":
        config = get_medium_config()
    elif args.model_size == "large":
        config = get_large_config()
    else:
        config = ExperimentConfig()

    # Layer CLI overrides on top of the preset.
    config.model.d_model = args.d_model
    config.model.nhead = args.num_heads
    config.model.num_layers = args.num_layers
    config.model.max_seq_len = args.max_seq_len

    config.training.epochs = args.epochs
    config.training.batch_size = args.batch_size
    config.training.learning_rate = args.learning_rate
    config.training.weight_decay = args.weight_decay
    config.training.gradient_clip_val = args.grad_clip
    config.training.warmup_steps = args.warmup_steps
    config.training.amp = args.use_amp
    config.training.compile_model = args.compile_model

    # Safety telemetry thresholds: K (negentropy), C (complexity), S (symbiosis).
    config.safety.k_threshold = args.min_negentropy
    config.safety.c_threshold = args.max_complexity
    config.safety.s_threshold = args.min_symbiosis
    config.safety.enable_safety = args.enable_safety_gates

    config.data.dataset_path = Path(args.input_path) if args.input_path else None
    config.data.max_sequence_length = args.seq_length
    config.data.num_workers = args.num_workers

    config.output_dir = Path(args.output_path)
    config.seed = args.seed
    config.device = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info(f"Starting training with config: {config.experiment_name}")
    logger.info(f"Model: {config.model.d_model}d, {config.model.num_layers}L, {config.model.nhead}H")
    logger.info(f"Device: {config.device}")

    model = BitTransformerLM(**config.model.to_dict())
    model = model.to(config.device)

    # Synthetic training data: random bits of shape (dataset_size, max_sequence_length).
    logger.info("Creating synthetic training data...")
    torch.manual_seed(config.seed)
    data = torch.randint(0, 2, (args.dataset_size, config.data.max_sequence_length))

    logger.info("Starting training...")
    try:
        train_loop(
            model,
            data,
            epochs=config.training.epochs,
            batch_size=config.training.batch_size,
            amp=config.training.amp,
            compile_model=config.training.compile_model,
            log=True,
        )

        # Ensure the output directory exists before saving the final weights.
        config.output_dir.mkdir(parents=True, exist_ok=True)
        save_path = config.output_dir / "model_final.pt"
        save_model(model, save_path)
        logger.info(f"Model saved to {save_path}")

    except Exception as e:
        logger.error(f"Training failed: {e}")
        sys.exit(1)
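

# Example invocation (a sketch: flag names are inferred from the parser
# attributes used above; run `bit-transformer-train --help` for the real list):
#   bit-transformer-train --model-size small --epochs 5 --batch-size 32 \
#       --learning-rate 1e-3 --output-path ./runs/small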


def infer_cli() -> None:
    """CLI entry point for BitTransformerLM inference."""
    parser = create_inference_parser()
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument("--max-tokens", type=int, default=50, help="Maximum tokens to generate")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    parser.add_argument("--use-diffusion", action="store_true", help="Use diffusion mode")
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    if not Path(args.weights_path).exists():
        logger.error(f"Model weights not found at {args.weights_path}")
        sys.exit(1)

    logger.info(f"Loading model from {args.weights_path}")
    model = load_model(args.weights_path)
    model.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    logger.info(f"Model loaded on {device}")
    logger.info(f"Prompt: {args.prompt}")

    try:
        if args.use_diffusion:
            logger.info("Using diffusion inference mode")
            prompt_bits = text_to_bits(args.prompt)
            # Reserve 9 bits per requested token: bit_io encodes each byte as
            # 8 data bits plus a parity bit.
            length = len(prompt_bits) + args.max_tokens * 9

            generated_bits = diffusion_inference(
                model,
                length=length,
                steps=args.diffusion_steps,
                schedule=args.noise_schedule,
            )
            result = bits_to_text(generated_bits[0].tolist())
        else:
            if args.enable_safety_gates:
                result = infer_text(
                    model,
                    args.prompt,
                    c_floor=args.max_complexity,
                    s_floor=args.min_symbiosis,
                )
            else:
                # Imported lazily; sample_text is only needed on this path.
                from .bit_io import sample_text

                result = sample_text(
                    model,
                    args.prompt,
                    max_new_tokens=args.max_tokens,
                    temperature=args.temperature,
                )

        print(f"\nGenerated text:\n{result}")

    except Exception as e:
        logger.error(f"Inference failed: {e}")
        sys.exit(1)
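

# Example invocation (a sketch: --prompt, --max-tokens, --temperature, and
# --use-diffusion are defined above; --weights-path is assumed from the
# args.weights_path attribute supplied by create_inference_parser()):
#   bit-transformer-infer --weights-path ./runs/small/model_final.pt \
#       --prompt "Hello" --max-tokens 50 --temperature 0.8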


def dashboard_cli() -> None:
    """CLI entry point for the BitTransformerLM dashboard."""
    parser = BitTransformerCLI.create_standard_parser(
        "BitTransformerLM Dashboard",
        ["io"],
    )
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Dashboard host")
    parser.add_argument("--port", type=int, default=7860, help="Dashboard port")
    parser.add_argument("--share", action="store_true", help="Create public link")
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    logger.info(f"Starting BitTransformerLM dashboard on {args.host}:{args.port}")

    try:
        run_dashboard(
            host=args.host,
            port=args.port,
            share=args.share,
        )
    except Exception as e:
        logger.error(f"Dashboard failed to start: {e}")
        sys.exit(1)
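

# Example invocation (using the flags defined above):
#   bit-transformer-dashboard --host 0.0.0.0 --port 7860 --share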


if __name__ == "__main__":
    import os

    # Dispatch on the invoked script name so a single module can serve all
    # three console entry points.
    script_name = os.path.basename(sys.argv[0])

    if "train" in script_name:
        train_cli()
    elif "infer" in script_name:
        infer_cli()
    elif "dashboard" in script_name:
        dashboard_cli()
    else:
        print("Available commands:")
        print("  bit-transformer-train      - Train a BitTransformerLM model")
        print("  bit-transformer-infer      - Run inference with a trained model")
        print("  bit-transformer-dashboard  - Launch interactive dashboard")
        sys.exit(1)