#!/usr/bin/env python3
"""
Demo script for RND1 generation.
"""

import argparse
import os
import random
import sys

import numpy as np
import torch
from transformers import AutoTokenizer

# Add RND1 module to path for local testing
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))


def set_seed(seed: int):
    """Set random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def demo_completion(
    model_path: str,
    checkpoint_path: str = None,
    device: str = "cuda:0",
    use_bfloat16: bool = True,
    show_visualization: bool = True,
    num_steps: int = 64,
    max_new_tokens: int = 256,
    custom_prompt: str = None,
    temperature: float = 1.0,
    top_k: int = None,
    top_p: float = None,
    mask_token_id: int = 151669,
    seed: int = 12345,
    moe_backend: str = "hf",
    mode: str = "task",
):
    """
    Demonstrate text completion using RND1.

    Args:
        model_path: Path to base model or HuggingFace model ID
        checkpoint_path: Path to custom checkpoint (if any)
        device: Device to run on (e.g., cuda:0, cpu)
        use_bfloat16: Whether to use bfloat16 precision
        show_visualization: Whether to show live visualization (requires rich)
        num_steps: Number of diffusion steps
        max_new_tokens: Maximum number of tokens to generate
        custom_prompt: Custom prompt to use instead of default examples
        temperature: Temperature for sampling (1.0 = greedy/deterministic)
        top_k: Top-k filtering for sampling (None = disabled)
        top_p: Top-p (nucleus) filtering for sampling (None = disabled)
        mask_token_id: Token ID for mask token
        seed: Random seed for reproducibility
        moe_backend: MoE backend to use ('hf', 'flashinfer', or 'sglang')
        mode: Generation mode ('task' for Q&A format, 'completion' for continuation)
    """
    set_seed(seed)

    from rnd.configuration_rnd import RND1Config
    from rnd.modeling_rnd import RND1LM

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    dtype = torch.bfloat16 if use_bfloat16 else torch.float32
    print(f"Using dtype: {dtype}")

    if moe_backend == "hf":
        print(
            "\n⚠️ Note: HuggingFace backend is slower. Consider using "
            "--moe_backend flashinfer or sglang for better performance.\n"
        )

    # Load from checkpoint if provided, otherwise from model_path
    load_path = checkpoint_path if checkpoint_path else model_path
    print(f"Loading model from {load_path}...")

    # Load config and set RND1-specific settings
    cfg = RND1Config.from_pretrained(load_path)
    cfg.model_type = "rnd1"
    cfg.attn_implementation = "sdpa"
    cfg.moe_backend = moe_backend

    # Load model with RND1LM
    model = RND1LM.from_pretrained(
        load_path,
        config=cfg,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda:0" else device,
        trust_remote_code=True,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    print("Model loaded")
    model = model.eval()

    if custom_prompt:
        prompts = [custom_prompt]
    else:
        # Default prompts based on mode
        if mode == "task":
            prompts = [
                "Write a Python function that finds the longest common subsequence "
                "of two strings. Include comments explaining the algorithm."
            ]
        else:
            prompts = ["The key to understanding quantum computing lies in"]

    greedy = (temperature == 1.0)

    generator = torch.Generator(device=device if device != "auto" else "cuda")
    generator.manual_seed(seed)

    for i, user_prompt in enumerate(prompts):
        print(f"\n{'='*60}")
        print(f"Mode: {mode.upper()}")
        print(f"Prompt {i+1}: {user_prompt[:100]}...")
        print(f"{'='*60}\n")

        if mode == "task":
            # Task mode: add "Question: " prefix if not already present
            if not user_prompt.strip().startswith("Question:"):
                prompt = f"Question: {user_prompt}\n"
            else:
                prompt = user_prompt
        else:
            # Completion mode: use prompt as-is for continuation
            prompt = user_prompt

        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(device if device != "auto" else "cuda")
        attention_mask = (
            inputs.attention_mask.to(device if device != "auto" else "cuda")
            if 'attention_mask' in inputs
            else None
        )

        print("Generation parameters:")
        print(f"  Prompt length: {input_ids.shape[1]} tokens")
        print(f"  Max new tokens: {max_new_tokens}")
        print(f"  Total sequence: {input_ids.shape[1] + max_new_tokens} tokens")
        print(f"  Diffusion steps: {num_steps}")
        print(f"  Temperature: {temperature}")
        print(f"  Greedy: {greedy}")
        if top_k:
            print(f"  Top-k: {top_k}")
        if top_p:
            print(f"  Top-p: {top_p}")
        print()

        # Create explicit generation config that takes priority over model defaults
        from rnd.generation_config import RND1GenerationConfig
        gen_config = RND1GenerationConfig(
            max_new_tokens=max_new_tokens,
            num_diffusion_steps=num_steps,
            mask_token_id=mask_token_id,
            temperature=temperature if not greedy else 1.0,
            top_k=top_k,
            top_p=top_p,
            greedy=greedy,
            eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else 151645,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
        )

        with torch.no_grad():
            if show_visualization and hasattr(model, 'generate_with_visualization'):
                # Use method with visualization support (requires tokenizer)
                output = model.generate_with_visualization(
                    tokenizer=tokenizer,
                    inputs=input_ids,
                    generation_config=gen_config,
                    generator=generator,
                )
            else:
                # Use standard generate method with explicit config
                output = model.generate(
                    inputs=input_ids,
                    generation_config=gen_config,
                    generator=generator,
                )

        generated_tokens = output[0][len(input_ids[0]):]
        generation = tokenizer.decode(
            generated_tokens.tolist(),
            skip_special_tokens=True
        )

        print("\nGenerated response:")
        print(generation)
        print(f"\n(Generation completed in {num_steps} diffusion steps)")


def main():
    parser = argparse.ArgumentParser(
        description="RND1 diffusion model demo with live visualization",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Model configuration
    model_group = parser.add_argument_group('Model Configuration')
    model_group.add_argument(
        "--model_path", type=str, default="radicalnumerics/RND1-Base-0910",
        help="Path to model or HuggingFace model ID"
    )
    model_group.add_argument(
        "--checkpoint", type=str, default=None,
        help="Path to custom checkpoint file or directory"
    )
    model_group.add_argument(
        "--device", type=str, default="cuda:0",
        help="Device to run on (e.g., cuda:0, cpu)"
    )
    model_group.add_argument(
        "--fp32", action="store_true",
        help="Use FP32 precision instead of BF16"
    )

    # Generation configuration
    gen_group = parser.add_argument_group('Generation Settings')
    gen_group.add_argument(
        "--num_steps", type=int, default=256,
        help="Number of diffusion steps"
    )
    gen_group.add_argument(
        "--max_new_tokens", type=int, default=256,
        help="Maximum number of tokens to generate"
    )
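
    # Prompt handling (see demo_completion above): --prompt overrides the built-in
    # example prompts, and --mode controls whether the text is wrapped as
    # "Question: ...\n" (task) or used verbatim as a continuation prefix (completion).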
    gen_group.add_argument(
        "--prompt", type=str, default=None,
        help="Custom prompt to use for generation"
    )
    gen_group.add_argument(
        "--mode", type=str, default="task", choices=["task", "completion"],
        help="Generation mode: 'task' (Q&A format for instructions) or 'completion' (text continuation)"
    )
    gen_group.add_argument(
        "--mask_token_id", type=int, default=151669,
        help="Token ID for mask token"
    )

    # Sampling configuration
    sampling_group = parser.add_argument_group('Sampling Parameters')
    sampling_group.add_argument(
        "--temperature", type=float, default=1.0,
        help="Temperature for sampling (1.0 = greedy/deterministic)"
    )
    sampling_group.add_argument(
        "--top_k", type=int, default=None,
        help="Top-k filtering: keep only k most likely tokens"
    )
    sampling_group.add_argument(
        "--top_p", type=float, default=None,
        help="Top-p (nucleus) filtering: keep tokens with cumulative probability <= p"
    )

    # Visualization
    viz_group = parser.add_argument_group('Visualization')
    viz_group.add_argument(
        "--no_viz", action="store_true",
        help="Disable live visualization during generation (requires rich library)"
    )

    # Other settings
    other_group = parser.add_argument_group('Other Settings')
    other_group.add_argument(
        "--seed", type=int, default=12345,
        help="Random seed for reproducibility"
    )

    moe_backend_group = parser.add_argument_group('MoE Backend')
    moe_backend_group.add_argument(
        "--moe_backend", type=str, default="hf", choices=["hf", "flashinfer", "sglang"],
        help="MoE backend to use for sparse mixture of experts layers"
    )

    args = parser.parse_args()

    if args.temperature < 0:
        parser.error("Temperature must be non-negative")
    if args.top_k is not None and args.top_k <= 0:
        parser.error("Top-k must be positive")
    if args.top_p is not None and (args.top_p <= 0 or args.top_p > 1):
        parser.error("Top-p must be between 0 and 1")

    print("\n" + "="*60)
    print("RND1 Diffusion Language Model Demo")
    print("="*60)
    print("Configuration:")
    print(f"  Model: {args.model_path}")
    if args.checkpoint:
        print(f"  Checkpoint: {args.checkpoint}")
    print(f"  Device: {args.device}")
    print(f"  Precision: {'FP32' if args.fp32 else 'BF16'}")
    print(f"  Mode: {args.mode.upper()} ({'Q&A format for instructions' if args.mode == 'task' else 'Text continuation'})")
    print(f"  Random seed: {args.seed}")
    print(f"  Diffusion steps: {args.num_steps}")
    print(f"  Max new tokens: {args.max_new_tokens}")
    print("  Algorithm: Entropy-based selection")
    print(f"  Temperature: {args.temperature}")
    if args.top_k:
        print(f"  Top-k: {args.top_k}")
    if args.top_p:
        print(f"  Top-p: {args.top_p}")
    print(f"  MoE Backend: {args.moe_backend}")
    print(f"  Visualization: {'Enabled' if not args.no_viz else 'Disabled'}")
    print("="*60 + "\n")

    demo_completion(
        model_path=args.model_path,
        checkpoint_path=args.checkpoint,
        device=args.device,
        use_bfloat16=not args.fp32,
        show_visualization=not args.no_viz,
        num_steps=args.num_steps,
        max_new_tokens=args.max_new_tokens,
        custom_prompt=args.prompt,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        mask_token_id=args.mask_token_id,
        seed=args.seed,
        moe_backend=args.moe_backend,
        mode=args.mode,
    )


if __name__ == "__main__":
    main()
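
# Example invocations (a sketch only: assumes this file is saved as demo.py and a
# CUDA device with enough memory for the default radicalnumerics/RND1-Base-0910
# checkpoint; every flag shown is defined in main() above):
#
#   python demo.py --prompt "Explain how diffusion language models generate text." \
#       --num_steps 128 --max_new_tokens 256
#
#   python demo.py --mode completion --temperature 0.8 --top_p 0.9 --no_viz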