""" DFlash evaluation script: measure acceptance length (tau) with optional multi-step denoising. Supports multi-GPU data parallelism via torchrun. Usage: # Single GPU python scripts/eval_dflash.py \ --target-model-path /workspace/models/Qwen3-8B \ --draft-model-path /workspace/models/Qwen3-8B-DFlash-b16 \ --dataset math500 --max-samples 10 --num-denoise-steps 1 # 8 GPU (data parallel, each GPU runs a subset of samples) torchrun --standalone --nproc_per_node 8 scripts/eval_dflash.py \ --target-model-path /workspace/models/Qwen3-8B \ --draft-model-path /workspace/models/Qwen3-8B-DFlash-b16 \ --dataset math500 --max-samples 500 --num-denoise-steps 2 """ import argparse import json import os import sys import time from statistics import mean, median import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer # Add project root to path SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = os.path.dirname(SCRIPT_DIR) sys.path.insert(0, ROOT_DIR) from specforge.modeling.draft.dflash import DFlashDraftModel def parse_args(): parser = argparse.ArgumentParser(description="DFlash evaluation: acceptance length") parser.add_argument("--target-model-path", type=str, required=True, help="Path to target model (e.g. /workspace/models/Qwen3-8B)") parser.add_argument("--draft-model-path", type=str, required=True, help="Path to DFlash draft model (e.g. /workspace/models/Qwen3-8B-DFlash-b16)") parser.add_argument("--dataset", type=str, default="math500", choices=["math500", "gsm8k", "custom"], help="Evaluation dataset") parser.add_argument("--custom-data-path", type=str, default=None, help="Path to custom jsonl data (when --dataset=custom)") parser.add_argument("--max-samples", type=int, default=10, help="Max number of evaluation samples (total across all GPUs)") parser.add_argument("--max-new-tokens", type=int, default=512, help="Max new tokens per sample") parser.add_argument("--num-denoise-steps", type=int, default=1, help="Number of denoising steps (1=baseline, 2/3=multi-step)") parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature (0.0=greedy)") parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16", "float32"]) parser.add_argument("--output-file", type=str, default=None, help="Save results to JSON file") return parser.parse_args() def get_rank_and_world(): """Get distributed rank and world size, or (0, 1) for single GPU.""" rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", 0))) world = int(os.environ.get("WORLD_SIZE", 1)) return rank, world def load_eval_data(args): """Load evaluation prompts as list of strings (offline, from local cache).""" prompts = [] # Use local cached datasets (no network access needed) DATASET_CACHE = "/workspace/hanrui/datasets" if args.dataset == "math500": dataset = load_dataset("HuggingFaceH4/MATH-500", cache_dir=DATASET_CACHE)["test"] for idx, item in enumerate(dataset): if idx >= args.max_samples: break prompts.append(item["problem"]) elif args.dataset == "gsm8k": dataset = load_dataset("openai/gsm8k", "main", cache_dir=DATASET_CACHE)["test"] for idx, item in enumerate(dataset): if idx >= args.max_samples: break prompts.append(item["question"]) elif args.dataset == "custom": assert args.custom_data_path is not None, "Need --custom-data-path for custom dataset" with open(args.custom_data_path, "r") as f: for idx, line in enumerate(f): if idx >= args.max_samples: break data = json.loads(line.strip()) if "prompt" in data: prompts.append(data["prompt"]) elif "conversations" in data: for msg in data["conversations"]: if msg["role"] == "user": prompts.append(msg["content"]) break return prompts def load_draft_model(draft_model_path, torch_dtype, device): """Load draft model weights into OUR DFlashDraftModel (with multi-step denoising).""" from transformers import AutoConfig from safetensors.torch import load_file as load_safetensors import glob as glob_module draft_config = AutoConfig.from_pretrained(draft_model_path, trust_remote_code=True) draft = DFlashDraftModel(draft_config).to(torch_dtype) safetensors_files = sorted(glob_module.glob(os.path.join(draft_model_path, "*.safetensors"))) bin_files = sorted(glob_module.glob(os.path.join(draft_model_path, "*.bin"))) state_dict = {} if safetensors_files: for f in safetensors_files: state_dict.update(load_safetensors(f, device="cpu")) elif bin_files: for f in bin_files: state_dict.update(torch.load(f, map_location="cpu", weights_only=True)) else: raise FileNotFoundError(f"No safetensors or bin files found in {draft_model_path}") missing, unexpected = draft.load_state_dict(state_dict, strict=False) if missing: print(f" [rank {device}] WARNING: missing keys: {missing}") if unexpected: print(f" [rank {device}] INFO: unexpected keys (ignored): {unexpected}") return draft.to(device).eval() def run_eval_on_samples(draft, target, tokenizer, prompts, args, device, rank): """Run evaluation on a list of prompts, return results.""" stop_token_ids = [tokenizer.eos_token_id] if hasattr(tokenizer, "additional_special_tokens_ids"): stop_token_ids.extend(tokenizer.additional_special_tokens_ids[:3]) all_acceptance_lengths = [] all_taus = [] all_times = [] all_generated_tokens = [] for i, prompt in enumerate(prompts): # Tokenize with chat template messages = [{"role": "user", "content": prompt}] try: input_text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device) except Exception: input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) num_input_tokens = input_ids.shape[1] # Generate with speculative decoding t0 = time.time() output_ids, acceptance_lengths = draft.spec_generate( target=target, input_ids=input_ids, max_new_tokens=args.max_new_tokens, stop_token_ids=stop_token_ids, temperature=args.temperature, num_denoise_steps=args.num_denoise_steps, ) t1 = time.time() num_output_tokens = output_ids.shape[1] - num_input_tokens elapsed = t1 - t0 avg_tau = mean(acceptance_lengths) if acceptance_lengths else 0 all_acceptance_lengths.extend(acceptance_lengths) all_taus.append(avg_tau) all_times.append(elapsed) all_generated_tokens.append(num_output_tokens) output_text = tokenizer.decode(output_ids[0, num_input_tokens:], skip_special_tokens=True) preview = output_text[:80].replace("\n", " ") print(f" [GPU {rank}] Sample {i+1}/{len(prompts)} | " f"tokens={num_output_tokens} | tau={avg_tau:.2f} | " f"time={elapsed:.1f}s | {preview}...") return { "acceptance_lengths": all_acceptance_lengths, "per_sample_taus": all_taus, "times": all_times, "generated_tokens": all_generated_tokens, } def main(): args = parse_args() rank, world_size = get_rank_and_world() device = f"cuda:{rank}" torch.cuda.set_device(rank) dtype_map = { "float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32, } torch_dtype = dtype_map[args.dtype] if rank == 0: print("=" * 60) print("DFlash Evaluation (Multi-GPU Data Parallel)") print("=" * 60) print(f" Target model: {args.target_model_path}") print(f" Draft model: {args.draft_model_path}") print(f" Dataset: {args.dataset}") print(f" Max samples: {args.max_samples}") print(f" Max new tokens: {args.max_new_tokens}") print(f" Denoise steps: {args.num_denoise_steps}") print(f" Temperature: {args.temperature}") print(f" GPUs: {world_size}") print(f" Dtype: {args.dtype}") print("=" * 60) # ---- Load models (each GPU loads its own copy) ---- if rank == 0: print(f"\n[1/4] Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(args.target_model_path, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token if rank == 0: print(f"[2/4] Loading target model on {world_size} GPUs...") target = AutoModelForCausalLM.from_pretrained( args.target_model_path, torch_dtype=torch_dtype, trust_remote_code=True, ).to(device).eval() if rank == 0: print(f"[3/4] Loading draft model on {world_size} GPUs...") draft = load_draft_model(args.draft_model_path, torch_dtype, device) if rank == 0: print(f" Draft layers: {draft.config.num_hidden_layers}") print(f" Draft block_size: {draft.block_size}") print(f" Draft mask_token: {draft.mask_token_id}") print(f" Draft layer_ids: {draft.target_layer_ids}") # ---- Load and split data ---- if rank == 0: print(f"[4/4] Loading evaluation data...") all_prompts = load_eval_data(args) # Split prompts across GPUs my_prompts = all_prompts[rank::world_size] if rank == 0: print(f" Total prompts: {len(all_prompts)}, ~{len(my_prompts)} per GPU") # ---- Run evaluation ---- if rank == 0: print("\n" + "=" * 60) print("Running evaluation...") print("=" * 60) results = run_eval_on_samples(draft, target, tokenizer, my_prompts, args, device, rank) # ---- Gather results from all GPUs ---- if world_size > 1: import torch.distributed as dist if not dist.is_initialized(): dist.init_process_group(backend="nccl") # Save per-rank results to shared filesystem (not /tmp which may not be shared) gather_dir = "/workspace/hanrui/cache/eval_gather" os.makedirs(gather_dir, exist_ok=True) tmp_file = os.path.join(gather_dir, f"rank{rank}.json") with open(tmp_file, "w") as f: json.dump(results, f) dist.barrier(device_ids=[rank]) if rank == 0: # Aggregate all ranks all_acceptance_lengths = [] all_per_sample_taus = [] all_times = [] all_generated_tokens = [] for r in range(world_size): rf = os.path.join(gather_dir, f"rank{r}.json") with open(rf, "r") as f: rank_results = json.load(f) all_acceptance_lengths.extend(rank_results["acceptance_lengths"]) all_per_sample_taus.extend(rank_results["per_sample_taus"]) all_times.extend(rank_results["times"]) all_generated_tokens.extend(rank_results["generated_tokens"]) os.remove(rf) results = { "acceptance_lengths": all_acceptance_lengths, "per_sample_taus": all_per_sample_taus, "times": all_times, "generated_tokens": all_generated_tokens, } else: # Wait for rank 0 to finish reading before removing dist.barrier(device_ids=[rank]) if os.path.exists(tmp_file): os.remove(tmp_file) if rank == 0: dist.barrier(device_ids=[rank]) dist.destroy_process_group() # ---- Print summary (rank 0 only) ---- if rank == 0: acc_lens = results["acceptance_lengths"] per_sample_taus = results["per_sample_taus"] total_tokens = sum(results["generated_tokens"]) total_time = sum(results["times"]) # wall-clock time is max across GPUs (they run in parallel) wall_time = max(results["times"]) if results["times"] else 0 overall_avg_tau = mean(acc_lens) if acc_lens else 0 overall_median_tau = median(acc_lens) if acc_lens else 0 print("\n" + "=" * 60) print("RESULTS SUMMARY") print("=" * 60) print(f" Denoise steps: {args.num_denoise_steps}") print(f" GPUs used: {world_size}") print(f" Samples evaluated: {len(per_sample_taus)}") print(f" Total blocks: {len(acc_lens)}") print(f" Total generated tokens: {total_tokens}") print(f" Total GPU-time: {total_time:.2f}s") print(f" Wall-clock time (approx): {wall_time:.2f}s") print(f" ---") print(f" Avg acceptance length (tau): {overall_avg_tau:.2f}") print(f" Median acceptance length: {overall_median_tau:.1f}") print(f" Per-sample avg tau: {[f'{t:.2f}' for t in per_sample_taus]}") if per_sample_taus: print(f" Min per-sample tau: {min(per_sample_taus):.2f}") print(f" Max per-sample tau: {max(per_sample_taus):.2f}") print("=" * 60) # Save results if args.output_file: output = { "config": { "target_model": args.target_model_path, "draft_model": args.draft_model_path, "dataset": args.dataset, "max_samples": args.max_samples, "max_new_tokens": args.max_new_tokens, "num_denoise_steps": args.num_denoise_steps, "temperature": args.temperature, "num_gpus": world_size, }, "results": { "avg_tau": overall_avg_tau, "median_tau": overall_median_tau, "per_sample_tau": per_sample_taus, "total_blocks": len(acc_lens), "total_tokens": total_tokens, "total_gpu_time": total_time, "wall_clock_time": wall_time, "all_acceptance_lengths": acc_lens, }, } os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True) with open(args.output_file, "w") as f: json.dump(output, f, indent=2) print(f"\nResults saved to {args.output_file}") if __name__ == "__main__": main()