| """ |
| DFlash evaluation script: measure acceptance length (tau) with optional multi-step denoising. |
| Supports multi-GPU data parallelism via torchrun. |
| |
| Usage: |
| # Single GPU |
| python scripts/eval_dflash.py \ |
| --target-model-path /workspace/models/Qwen3-8B \ |
| --draft-model-path /workspace/models/Qwen3-8B-DFlash-b16 \ |
| --dataset math500 --max-samples 10 --num-denoise-steps 1 |
| |
| # 8 GPU (data parallel, each GPU runs a subset of samples) |
| torchrun --standalone --nproc_per_node 8 scripts/eval_dflash.py \ |
| --target-model-path /workspace/models/Qwen3-8B \ |
| --draft-model-path /workspace/models/Qwen3-8B-DFlash-b16 \ |
| --dataset math500 --max-samples 500 --num-denoise-steps 2 |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| import time |
| from statistics import mean, median |
|
|
| import torch |
| from datasets import load_dataset |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| |
# Prepend the repository root (the parent of this script's directory) to
# sys.path so the local `specforge` package imported below resolves from this
# checkout when the script is launched directly.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.split(SCRIPT_DIR)[0]
sys.path[:0] = [ROOT_DIR]
|
|
| from specforge.modeling.draft.dflash import DFlashDraftModel |
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the DFlash acceptance-length evaluation.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (for tests or programmatic use).

    Returns:
        argparse.Namespace holding the parsed options.

    Exits through ``parser.error`` (SystemExit with status 2) when
    ``--dataset custom`` is given without ``--custom-data-path``, or when a
    numeric option is out of range. These checks run here so an invalid
    invocation fails before the target and draft models are loaded.
    """
    parser = argparse.ArgumentParser(description="DFlash evaluation: acceptance length")
    parser.add_argument("--target-model-path", type=str, required=True,
                        help="Path to target model (e.g. /workspace/models/Qwen3-8B)")
    parser.add_argument("--draft-model-path", type=str, required=True,
                        help="Path to DFlash draft model (e.g. /workspace/models/Qwen3-8B-DFlash-b16)")
    parser.add_argument("--dataset", type=str, default="math500",
                        choices=["math500", "gsm8k", "custom"],
                        help="Evaluation dataset")
    parser.add_argument("--custom-data-path", type=str, default=None,
                        help="Path to custom jsonl data (when --dataset=custom)")
    parser.add_argument("--max-samples", type=int, default=10,
                        help="Max number of evaluation samples (total across all GPUs)")
    parser.add_argument("--max-new-tokens", type=int, default=512,
                        help="Max new tokens per sample")
    parser.add_argument("--num-denoise-steps", type=int, default=1,
                        help="Number of denoising steps (1=baseline, 2/3=multi-step)")
    parser.add_argument("--temperature", type=float, default=0.0,
                        help="Sampling temperature (0.0=greedy)")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        choices=["float16", "bfloat16", "float32"])
    parser.add_argument("--output-file", type=str, default=None,
                        help="Save results to JSON file")
    args = parser.parse_args(argv)

    # Validate up front: a bad value should not cost two model loads first.
    if args.dataset == "custom" and args.custom_data_path is None:
        parser.error("--custom-data-path is required when --dataset=custom")
    for flag, value in (
        ("--max-samples", args.max_samples),
        ("--max-new-tokens", args.max_new_tokens),
        ("--num-denoise-steps", args.num_denoise_steps),
    ):
        if value < 1:
            parser.error(f"{flag} must be >= 1 (got {value})")
    if args.temperature < 0:
        parser.error(f"--temperature must be >= 0 (got {args.temperature})")
    return args
|
|
|
|
def get_rank_and_world():
    """Return ``(rank, world_size)`` from the launcher's environment variables.

    ``rank`` is taken from ``LOCAL_RANK`` when it is set, otherwise from
    ``RANK``, otherwise 0. ``world_size`` is taken from ``WORLD_SIZE`` and
    defaults to 1. Because ``LOCAL_RANK`` wins, a multi-node launch yields the
    per-node rank rather than the global one. With no launcher variables the
    result is ``(0, 1)``, i.e. single-GPU mode.
    """
    env = os.environ
    if "LOCAL_RANK" in env:
        raw_rank = env["LOCAL_RANK"]
    else:
        raw_rank = env.get("RANK", 0)
    raw_world = env.get("WORLD_SIZE", 1)
    return int(raw_rank), int(raw_world)
|
|
|
|
def _custom_record_prompt(record):
    """Return the prompt text carried by one custom-JSONL record, or None.

    A top-level ``"prompt"`` field wins; otherwise the ``content`` of the first
    message whose ``role`` is ``"user"`` in ``"conversations"`` is used.
    """
    if "prompt" in record:
        return record["prompt"]
    for msg in record.get("conversations") or ():
        if msg.get("role") == "user":
            return msg["content"]
    return None


def load_eval_data(args):
    """Load up to ``args.max_samples`` evaluation prompts as a list of strings.

    Sources, selected by ``args.dataset``:
      * ``math500`` -- ``problem`` field of HuggingFaceH4/MATH-500, test split.
      * ``gsm8k``   -- ``question`` field of openai/gsm8k ("main"), test split.
      * ``custom``  -- JSONL file at ``args.custom_data_path``. Each record
        contributes its ``prompt`` field, or else the first ``user`` message of
        ``conversations``. Blank lines and records with neither are skipped and
        do not count toward ``max_samples``.

    Hugging Face datasets are cached under ``$DFLASH_DATASET_CACHE`` (default
    ``/workspace/hanrui/datasets``). ``load_dataset`` only avoids the network
    when the data is already present in that cache.

    Raises:
        ValueError: ``custom`` dataset without ``custom_data_path``, or an
            unknown dataset name.
    """
    max_samples = args.max_samples
    # Overridable so the script works outside the original workspace layout.
    dataset_cache = os.environ.get("DFLASH_DATASET_CACHE", "/workspace/hanrui/datasets")

    if args.dataset == "math500":
        dataset = load_dataset("HuggingFaceH4/MATH-500", cache_dir=dataset_cache)["test"]
        # zip() against range() stops after max_samples items without reading more.
        return [item["problem"] for _, item in zip(range(max_samples), dataset)]

    if args.dataset == "gsm8k":
        dataset = load_dataset("openai/gsm8k", "main", cache_dir=dataset_cache)["test"]
        return [item["question"] for _, item in zip(range(max_samples), dataset)]

    if args.dataset == "custom":
        # Explicit raise (not assert) so the check survives `python -O`.
        if args.custom_data_path is None:
            raise ValueError("Need --custom-data-path for custom dataset")
        prompts = []
        with open(args.custom_data_path, "r", encoding="utf-8") as f:
            for line in f:
                # Count collected prompts, not file lines, so skipped records
                # do not shrink the evaluation set.
                if len(prompts) >= max_samples:
                    break
                line = line.strip()
                if not line:
                    continue
                prompt = _custom_record_prompt(json.loads(line))
                if prompt is not None:
                    prompts.append(prompt)
        return prompts

    raise ValueError(f"Unknown dataset: {args.dataset!r}")
|
|
|
|
def load_draft_model(draft_model_path, torch_dtype, device):
    """Build this repository's ``DFlashDraftModel`` and load checkpoint weights.

    The config comes from ``AutoConfig.from_pretrained(draft_model_path)``; the
    model class is always the locally imported ``DFlashDraftModel`` (the
    implementation this script relies on for multi-step denoising).

    Weights are read from every ``*.safetensors`` file in the directory. Only
    when there are none are ``*.bin`` files loaded, with
    ``weights_only=True``. Loading is non-strict: missing and unexpected keys
    are printed but do not raise.

    Args:
        draft_model_path: Checkpoint directory.
        torch_dtype: dtype the model is cast to before the weights are copied in.
        device: Target device string such as ``"cuda:0"``.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        FileNotFoundError: the directory holds neither safetensors nor bin files.
    """
    import glob as glob_module

    from safetensors.torch import load_file as load_safetensors
    from transformers import AutoConfig

    draft_config = AutoConfig.from_pretrained(draft_model_path, trust_remote_code=True)
    draft = DFlashDraftModel(draft_config).to(torch_dtype)

    safetensors_files = sorted(glob_module.glob(os.path.join(draft_model_path, "*.safetensors")))
    bin_files = sorted(glob_module.glob(os.path.join(draft_model_path, "*.bin")))

    # Shards are merged on CPU; safetensors take precedence over .bin files.
    state_dict = {}
    if safetensors_files:
        for path in safetensors_files:
            state_dict.update(load_safetensors(path, device="cpu"))
    elif bin_files:
        for path in bin_files:
            state_dict.update(torch.load(path, map_location="cpu", weights_only=True))
    else:
        raise FileNotFoundError(f"No safetensors or bin files found in {draft_model_path}")

    missing, unexpected = draft.load_state_dict(state_dict, strict=False)
    # Label with the device string (e.g. "cuda:0") so multi-GPU logs show
    # which copy of the draft model reported the problem.
    if missing:
        print(f" [{device}] WARNING: missing keys: {missing}")
    if unexpected:
        print(f" [{device}] INFO: unexpected keys (ignored): {unexpected}")

    return draft.to(device).eval()
|
|
|
|
def _collect_stop_token_ids(tokenizer):
    """Return the EOS id plus the first three additional special token ids.

    ``None`` entries (e.g. a tokenizer without an EOS token) are dropped so
    ``spec_generate`` only receives real token ids.
    """
    candidates = [tokenizer.eos_token_id]
    if hasattr(tokenizer, "additional_special_tokens_ids"):
        # NOTE(review): "first three additional special tokens" is a heuristic
        # for end-of-turn markers; confirm against the target tokenizer.
        candidates.extend(tokenizer.additional_special_tokens_ids[:3])
    return [token_id for token_id in candidates if token_id is not None]


def _encode_prompt(tokenizer, prompt, device):
    """Tokenize ``prompt`` as a single user turn and move the ids to ``device``.

    The chat-templated text already contains the template's special tokens, so
    it is tokenized with ``add_special_tokens=False`` to avoid a duplicated BOS
    on tokenizers that prepend one. If applying the chat template raises (for
    instance when no template is configured), the raw prompt is tokenized
    with default settings instead.
    """
    messages = [{"role": "user", "content": prompt}]
    try:
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Deliberate best-effort fallback: evaluate the raw prompt.
        return tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    return tokenizer(
        input_text, return_tensors="pt", add_special_tokens=False
    ).input_ids.to(device)


def run_eval_on_samples(draft, target, tokenizer, prompts, args, device, rank):
    """Run ``draft.spec_generate`` on each prompt and collect acceptance stats.

    Generation uses ``args.max_new_tokens``, ``args.temperature`` and
    ``args.num_denoise_steps``, with gradients disabled. Per-sample time is
    measured with ``time.perf_counter`` around ``spec_generate``. On CUDA
    devices the device is synchronized before and after, so queued kernels
    are counted.

    Args:
        rank: Only used to label the per-sample progress line.

    Returns:
        dict with
          * ``acceptance_lengths`` -- every per-block acceptance length, flattened;
          * ``per_sample_taus``    -- mean acceptance length per prompt (0 if none);
          * ``times``              -- seconds spent in ``spec_generate`` per prompt;
          * ``generated_tokens``   -- new tokens produced per prompt.
    """
    stop_token_ids = _collect_stop_token_ids(tokenizer)
    sync_cuda = torch.device(device).type == "cuda"

    all_acceptance_lengths = []
    all_taus = []
    all_times = []
    all_generated_tokens = []

    for i, prompt in enumerate(prompts):
        input_ids = _encode_prompt(tokenizer, prompt, device)
        num_input_tokens = input_ids.shape[1]

        # CUDA kernels run asynchronously; synchronize so the interval covers
        # the GPU work of this sample only.
        if sync_cuda:
            torch.cuda.synchronize(device)
        t0 = time.perf_counter()
        with torch.no_grad():
            output_ids, acceptance_lengths = draft.spec_generate(
                target=target,
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                stop_token_ids=stop_token_ids,
                temperature=args.temperature,
                num_denoise_steps=args.num_denoise_steps,
            )
        if sync_cuda:
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - t0

        num_output_tokens = output_ids.shape[1] - num_input_tokens
        avg_tau = mean(acceptance_lengths) if acceptance_lengths else 0

        all_acceptance_lengths.extend(acceptance_lengths)
        all_taus.append(avg_tau)
        all_times.append(elapsed)
        all_generated_tokens.append(num_output_tokens)

        output_text = tokenizer.decode(output_ids[0, num_input_tokens:], skip_special_tokens=True)
        preview = output_text[:80].replace("\n", " ")
        print(f" [GPU {rank}] Sample {i + 1}/{len(prompts)} | "
              f"tokens={num_output_tokens} | tau={avg_tau:.2f} | "
              f"time={elapsed:.1f}s | {preview}...")

    return {
        "acceptance_lengths": all_acceptance_lengths,
        "per_sample_taus": all_taus,
        "times": all_times,
        "generated_tokens": all_generated_tokens,
    }
|
|
|
|
def _gather_rank_results(results, world_size):
    """Collect every rank's result dict; returns a list indexed by rank.

    With a single process this is simply ``[results]``. Otherwise an NCCL
    process group is created and ``dist.all_gather_object`` exchanges the
    small result dicts in memory. No shared-filesystem scratch directory is
    used, so concurrent runs cannot clobber each other. NCCL object
    collectives use the current CUDA device, so the caller must have called
    ``torch.cuda.set_device`` for this rank (``main`` does). The process group
    is destroyed before returning.

    NOTE(review): the group is created only after evaluation, so ranks that
    finish early wait inside ``init_process_group``; very uneven shards could
    run into the rendezvous timeout.
    """
    if world_size <= 1:
        return [results]

    import torch.distributed as dist

    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    try:
        gathered = [None] * world_size
        dist.all_gather_object(gathered, results)
    finally:
        dist.destroy_process_group()
    return gathered


def _merge_rank_results(per_rank_results):
    """Merge result dicts from round-robin shards ``prompts[rank::world_size]``.

    ``acceptance_lengths`` are concatenated in rank order, since only their
    distribution is reported. The per-sample lists (``per_sample_taus``,
    ``times``, ``generated_tokens``) are interleaved back into dataset order:
    global sample ``g`` is item ``g // world_size`` of rank ``g % world_size``.
    """
    world_size = len(per_rank_results)
    merged = {
        "acceptance_lengths": [
            length for res in per_rank_results for length in res["acceptance_lengths"]
        ],
    }
    num_samples = sum(len(res["per_sample_taus"]) for res in per_rank_results)
    for key in ("per_sample_taus", "times", "generated_tokens"):
        merged[key] = [
            per_rank_results[g % world_size][key][g // world_size]
            for g in range(num_samples)
        ]
    return merged


def main():
    """Run the DFlash acceptance-length evaluation (single process or torchrun).

    Each process loads the tokenizer, target model and draft model onto
    ``cuda:<rank>``, evaluates its shard ``prompts[rank::world_size]`` and then
    exchanges results with the other ranks. Only rank 0 prints the headers,
    the final summary, and writes ``--output-file``.
    """
    args = parse_args()
    rank, world_size = get_rank_and_world()
    device = f"cuda:{rank}"
    torch.cuda.set_device(rank)
    is_main = rank == 0

    def log(*values):
        # Headers and progress messages come from rank 0 only.
        if is_main:
            print(*values)

    torch_dtype = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }[args.dtype]

    log("=" * 60)
    log("DFlash Evaluation (Multi-GPU Data Parallel)")
    log("=" * 60)
    log(f" Target model: {args.target_model_path}")
    log(f" Draft model: {args.draft_model_path}")
    log(f" Dataset: {args.dataset}")
    log(f" Max samples: {args.max_samples}")
    log(f" Max new tokens: {args.max_new_tokens}")
    log(f" Denoise steps: {args.num_denoise_steps}")
    log(f" Temperature: {args.temperature}")
    log(f" GPUs: {world_size}")
    log(f" Dtype: {args.dtype}")
    log("=" * 60)

    log("\n[1/4] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.target_model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Data parallel: every rank holds full copies of both models on its own GPU.
    log(f"[2/4] Loading target model on {world_size} GPUs...")
    target = AutoModelForCausalLM.from_pretrained(
        args.target_model_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    ).to(device).eval()

    log(f"[3/4] Loading draft model on {world_size} GPUs...")
    draft = load_draft_model(args.draft_model_path, torch_dtype, device)
    log(f" Draft layers: {draft.config.num_hidden_layers}")
    log(f" Draft block_size: {draft.block_size}")
    log(f" Draft mask_token: {draft.mask_token_id}")
    log(f" Draft layer_ids: {draft.target_layer_ids}")

    log("[4/4] Loading evaluation data...")
    all_prompts = load_eval_data(args)
    # Round-robin sharding: rank r evaluates prompts r, r + W, r + 2W, ...
    my_prompts = all_prompts[rank::world_size]
    log(f" Total prompts: {len(all_prompts)}, ~{len(my_prompts)} per GPU")

    log("\n" + "=" * 60)
    log("Running evaluation...")
    log("=" * 60)

    results = run_eval_on_samples(draft, target, tokenizer, my_prompts, args, device, rank)

    per_rank_results = _gather_rank_results(results, world_size)
    if not is_main:
        return

    merged = _merge_rank_results(per_rank_results)
    acc_lens = merged["acceptance_lengths"]
    per_sample_taus = merged["per_sample_taus"]
    total_tokens = sum(merged["generated_tokens"])
    total_time = sum(merged["times"])
    # Ranks work through their shards concurrently, so evaluation wall-clock
    # time is set by the slowest rank: the largest per-rank sum of sample
    # times (model loading and result exchange are not included).
    wall_time = max((sum(res["times"]) for res in per_rank_results), default=0)

    overall_avg_tau = mean(acc_lens) if acc_lens else 0
    overall_median_tau = median(acc_lens) if acc_lens else 0

    print("\n" + "=" * 60)
    print("RESULTS SUMMARY")
    print("=" * 60)
    print(f" Denoise steps: {args.num_denoise_steps}")
    print(f" GPUs used: {world_size}")
    print(f" Samples evaluated: {len(per_sample_taus)}")
    print(f" Total blocks: {len(acc_lens)}")
    print(f" Total generated tokens: {total_tokens}")
    print(f" Total GPU-time: {total_time:.2f}s")
    print(f" Wall-clock time (approx): {wall_time:.2f}s")
    print(" ---")
    print(f" Avg acceptance length (tau): {overall_avg_tau:.2f}")
    print(f" Median acceptance length: {overall_median_tau:.1f}")
    print(f" Per-sample avg tau: {[f'{t:.2f}' for t in per_sample_taus]}")
    if per_sample_taus:
        print(f" Min per-sample tau: {min(per_sample_taus):.2f}")
        print(f" Max per-sample tau: {max(per_sample_taus):.2f}")
    print("=" * 60)

    if not args.output_file:
        return

    output = {
        "config": {
            "target_model": args.target_model_path,
            "draft_model": args.draft_model_path,
            "dataset": args.dataset,
            "max_samples": args.max_samples,
            "max_new_tokens": args.max_new_tokens,
            "num_denoise_steps": args.num_denoise_steps,
            "temperature": args.temperature,
            "num_gpus": world_size,
        },
        "results": {
            "avg_tau": overall_avg_tau,
            "median_tau": overall_median_tau,
            "per_sample_tau": per_sample_taus,
            "total_blocks": len(acc_lens),
            "total_tokens": total_tokens,
            "total_gpu_time": total_time,
            "wall_clock_time": wall_time,
            "all_acceptance_lengths": acc_lens,
        },
    }
    os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nResults saved to {args.output_file}")
|
|
|
|
# Run the evaluation only when executed as a script (directly or through
# torchrun), not when this module is imported.
if __name__ == "__main__":
    main()
|
|