# Hanrui / idea1 / scripts / eval_dflash.py
# Uploaded by Lekr0 via the upload-large-folder tool (commit 2d67aa6, verified).
"""
DFlash evaluation script: measure acceptance length (tau) with optional multi-step denoising.
Supports multi-GPU data parallelism via torchrun.
Usage:
# Single GPU
python scripts/eval_dflash.py \
--target-model-path /workspace/models/Qwen3-8B \
--draft-model-path /workspace/models/Qwen3-8B-DFlash-b16 \
--dataset math500 --max-samples 10 --num-denoise-steps 1
# 8 GPU (data parallel, each GPU runs a subset of samples)
torchrun --standalone --nproc_per_node 8 scripts/eval_dflash.py \
--target-model-path /workspace/models/Qwen3-8B \
--draft-model-path /workspace/models/Qwen3-8B-DFlash-b16 \
--dataset math500 --max-samples 500 --num-denoise-steps 2
"""
import argparse
import json
import os
import sys
import time
from statistics import mean, median
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
# Add project root to path
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(SCRIPT_DIR)
sys.path.insert(0, ROOT_DIR)
from specforge.modeling.draft.dflash import DFlashDraftModel
def parse_args():
    """Parse command-line options for a DFlash acceptance-length evaluation run."""
    parser = argparse.ArgumentParser(description="DFlash evaluation: acceptance length")
    # (flag, kwargs) table keeps the option definitions compact and easy to scan.
    option_table = [
        ("--target-model-path", dict(
            type=str, required=True,
            help="Path to target model (e.g. /workspace/models/Qwen3-8B)")),
        ("--draft-model-path", dict(
            type=str, required=True,
            help="Path to DFlash draft model (e.g. /workspace/models/Qwen3-8B-DFlash-b16)")),
        ("--dataset", dict(
            type=str, default="math500",
            choices=["math500", "gsm8k", "custom"],
            help="Evaluation dataset")),
        ("--custom-data-path", dict(
            type=str, default=None,
            help="Path to custom jsonl data (when --dataset=custom)")),
        ("--max-samples", dict(
            type=int, default=10,
            help="Max number of evaluation samples (total across all GPUs)")),
        ("--max-new-tokens", dict(
            type=int, default=512,
            help="Max new tokens per sample")),
        ("--num-denoise-steps", dict(
            type=int, default=1,
            help="Number of denoising steps (1=baseline, 2/3=multi-step)")),
        ("--temperature", dict(
            type=float, default=0.0,
            help="Sampling temperature (0.0=greedy)")),
        ("--dtype", dict(
            type=str, default="bfloat16",
            choices=["float16", "bfloat16", "float32"])),
        ("--output-file", dict(
            type=str, default=None,
            help="Save results to JSON file")),
    ]
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def get_rank_and_world():
    """Return (rank, world_size) from torchrun env vars; (0, 1) when single-GPU."""
    env = os.environ
    # LOCAL_RANK wins when set (torchrun); plain RANK is the fallback.
    fallback_rank = env.get("RANK", 0)
    rank = int(env.get("LOCAL_RANK", fallback_rank))
    world = int(env.get("WORLD_SIZE", 1))
    return rank, world
def load_eval_data(args):
    """Return up to args.max_samples evaluation prompts as a list of strings.

    math500/gsm8k are read from a local HF datasets cache (offline); "custom"
    reads a jsonl file whose rows carry either a "prompt" string or a
    "conversations" list (first user message is used).
    """
    # Local cache directory so no network access is needed.
    DATASET_CACHE = "/workspace/hanrui/datasets"
    limit = args.max_samples
    prompts = []
    if args.dataset == "math500":
        split = load_dataset("HuggingFaceH4/MATH-500", cache_dir=DATASET_CACHE)["test"]
        for count, item in enumerate(split):
            if count >= limit:
                break
            prompts.append(item["problem"])
    elif args.dataset == "gsm8k":
        split = load_dataset("openai/gsm8k", "main", cache_dir=DATASET_CACHE)["test"]
        for count, item in enumerate(split):
            if count >= limit:
                break
            prompts.append(item["question"])
    elif args.dataset == "custom":
        assert args.custom_data_path is not None, "Need --custom-data-path for custom dataset"
        with open(args.custom_data_path, "r") as f:
            # NOTE: the limit counts input *lines*, not extracted prompts —
            # a line with neither key consumes a slot but yields nothing.
            for line_no, raw_line in enumerate(f):
                if line_no >= limit:
                    break
                record = json.loads(raw_line.strip())
                if "prompt" in record:
                    prompts.append(record["prompt"])
                elif "conversations" in record:
                    first_user = next(
                        (m["content"] for m in record["conversations"]
                         if m["role"] == "user"),
                        None,
                    )
                    if first_user is not None:
                        prompts.append(first_user)
    return prompts
def load_draft_model(draft_model_path, torch_dtype, device):
    """Build OUR DFlashDraftModel (multi-step denoising) and load checkpoint weights.

    Prefers *.safetensors shards, falling back to *.bin; loading is non-strict
    so mismatched keys are reported rather than fatal.
    """
    import glob
    from safetensors.torch import load_file as load_safetensors
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(draft_model_path, trust_remote_code=True)
    model = DFlashDraftModel(config).to(torch_dtype)

    # Pick the shard format and a matching per-file loader.
    shard_paths = sorted(glob.glob(os.path.join(draft_model_path, "*.safetensors")))
    if shard_paths:
        def read_shard(path):
            return load_safetensors(path, device="cpu")
    else:
        shard_paths = sorted(glob.glob(os.path.join(draft_model_path, "*.bin")))
        def read_shard(path):
            return torch.load(path, map_location="cpu", weights_only=True)
    if not shard_paths:
        raise FileNotFoundError(f"No safetensors or bin files found in {draft_model_path}")

    state_dict = {}
    for shard in shard_paths:
        state_dict.update(read_shard(shard))

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f" [rank {device}] WARNING: missing keys: {missing}")
    if unexpected:
        print(f" [rank {device}] INFO: unexpected keys (ignored): {unexpected}")
    return model.to(device).eval()
def run_eval_on_samples(draft, target, tokenizer, prompts, args, device, rank):
    """Run speculative generation over *prompts*; collect tau and timing stats.

    Returns a dict of flat lists: per-block acceptance lengths, per-sample
    mean tau, per-sample elapsed seconds, and per-sample generated-token counts.
    """
    stop_token_ids = [tokenizer.eos_token_id]
    if hasattr(tokenizer, "additional_special_tokens_ids"):
        stop_token_ids.extend(tokenizer.additional_special_tokens_ids[:3])

    acceptance_lengths = []
    sample_taus = []
    sample_times = []
    sample_token_counts = []

    for idx, prompt in enumerate(prompts):
        # Prefer the model's chat template; fall back to the raw prompt when
        # the tokenizer has no template (or templating fails for any reason).
        try:
            chat_text = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            input_ids = tokenizer(chat_text, return_tensors="pt").input_ids.to(device)
        except Exception:
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        prompt_len = input_ids.shape[1]

        # Speculative decoding with optional multi-step denoising.
        start = time.time()
        output_ids, block_acceptances = draft.spec_generate(
            target=target,
            input_ids=input_ids,
            max_new_tokens=args.max_new_tokens,
            stop_token_ids=stop_token_ids,
            temperature=args.temperature,
            num_denoise_steps=args.num_denoise_steps,
        )
        elapsed = time.time() - start

        new_tokens = output_ids.shape[1] - prompt_len
        tau = mean(block_acceptances) if block_acceptances else 0
        acceptance_lengths.extend(block_acceptances)
        sample_taus.append(tau)
        sample_times.append(elapsed)
        sample_token_counts.append(new_tokens)

        decoded = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
        preview = decoded[:80].replace("\n", " ")
        print(f" [GPU {rank}] Sample {idx+1}/{len(prompts)} | "
              f"tokens={new_tokens} | tau={tau:.2f} | "
              f"time={elapsed:.1f}s | {preview}...")

    return {
        "acceptance_lengths": acceptance_lengths,
        "per_sample_taus": sample_taus,
        "times": sample_times,
        "generated_tokens": sample_token_counts,
    }
def _aggregate_rank_files(gather_dir, world_size):
    """Rank-0 helper: merge per-rank JSON result files and delete them.

    Returns (merged_results, per_rank_total_times), where the second item is
    each rank's summed generation time — used for the wall-clock estimate.
    """
    merged = {
        "acceptance_lengths": [],
        "per_sample_taus": [],
        "times": [],
        "generated_tokens": [],
    }
    per_rank_total_times = []
    for r in range(world_size):
        rank_file = os.path.join(gather_dir, f"rank{r}.json")
        with open(rank_file, "r") as f:
            rank_results = json.load(f)
        for key in merged:
            merged[key].extend(rank_results[key])
        per_rank_total_times.append(sum(rank_results["times"]))
        os.remove(rank_file)
    return merged, per_rank_total_times


def main():
    """Entry point: load models, run the DFlash evaluation, report/save results."""
    args = parse_args()
    rank, world_size = get_rank_and_world()
    device = f"cuda:{rank}"
    torch.cuda.set_device(rank)
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    if rank == 0:
        print("=" * 60)
        print("DFlash Evaluation (Multi-GPU Data Parallel)")
        print("=" * 60)
        print(f" Target model: {args.target_model_path}")
        print(f" Draft model: {args.draft_model_path}")
        print(f" Dataset: {args.dataset}")
        print(f" Max samples: {args.max_samples}")
        print(f" Max new tokens: {args.max_new_tokens}")
        print(f" Denoise steps: {args.num_denoise_steps}")
        print(f" Temperature: {args.temperature}")
        print(f" GPUs: {world_size}")
        print(f" Dtype: {args.dtype}")
        print("=" * 60)

    # ---- Load models (each GPU loads its own copy) ----
    if rank == 0:
        print(f"\n[1/4] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.target_model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if rank == 0:
        print(f"[2/4] Loading target model on {world_size} GPUs...")
    target = AutoModelForCausalLM.from_pretrained(
        args.target_model_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    ).to(device).eval()

    if rank == 0:
        print(f"[3/4] Loading draft model on {world_size} GPUs...")
    draft = load_draft_model(args.draft_model_path, torch_dtype, device)
    if rank == 0:
        print(f" Draft layers: {draft.config.num_hidden_layers}")
        print(f" Draft block_size: {draft.block_size}")
        print(f" Draft mask_token: {draft.mask_token_id}")
        print(f" Draft layer_ids: {draft.target_layer_ids}")

    # ---- Load and split data ----
    if rank == 0:
        print(f"[4/4] Loading evaluation data...")
    all_prompts = load_eval_data(args)
    # Round-robin split of prompts across GPUs (data parallelism).
    my_prompts = all_prompts[rank::world_size]
    if rank == 0:
        print(f" Total prompts: {len(all_prompts)}, ~{len(my_prompts)} per GPU")

    # ---- Run evaluation ----
    if rank == 0:
        print("\n" + "=" * 60)
        print("Running evaluation...")
        print("=" * 60)
    results = run_eval_on_samples(draft, target, tokenizer, my_prompts, args, device, rank)

    # ---- Gather results from all GPUs ----
    # Single-GPU default: one "rank" whose total time is this process's sum.
    per_rank_total_times = [sum(results["times"])]
    if world_size > 1:
        import torch.distributed as dist
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        # Save per-rank results to shared filesystem (not /tmp which may not be shared)
        gather_dir = "/workspace/hanrui/cache/eval_gather"
        os.makedirs(gather_dir, exist_ok=True)
        with open(os.path.join(gather_dir, f"rank{rank}.json"), "w") as f:
            json.dump(results, f)
        # Every rank must have written its file before rank 0 starts reading.
        dist.barrier(device_ids=[rank])
        if rank == 0:
            results, per_rank_total_times = _aggregate_rank_files(gather_dir, world_size)
        # Keep non-zero ranks alive until rank 0 has read (and removed) all files.
        dist.barrier(device_ids=[rank])
        # BUGFIX: tear down the process group on EVERY rank — the original
        # called destroy_process_group() on rank 0 only, leaking the NCCL
        # communicator on the other ranks.
        dist.destroy_process_group()

    # ---- Print summary (rank 0 only) ----
    if rank == 0:
        acc_lens = results["acceptance_lengths"]
        per_sample_taus = results["per_sample_taus"]
        total_tokens = sum(results["generated_tokens"])
        total_time = sum(results["times"])
        # BUGFIX: wall-clock is the max of each GPU's *total* time (ranks run
        # in parallel). The original took max(results["times"]) which is the
        # slowest single sample, not a wall-clock estimate.
        wall_time = max(per_rank_total_times) if per_rank_total_times else 0
        overall_avg_tau = mean(acc_lens) if acc_lens else 0
        overall_median_tau = median(acc_lens) if acc_lens else 0
        print("\n" + "=" * 60)
        print("RESULTS SUMMARY")
        print("=" * 60)
        print(f" Denoise steps: {args.num_denoise_steps}")
        print(f" GPUs used: {world_size}")
        print(f" Samples evaluated: {len(per_sample_taus)}")
        print(f" Total blocks: {len(acc_lens)}")
        print(f" Total generated tokens: {total_tokens}")
        print(f" Total GPU-time: {total_time:.2f}s")
        print(f" Wall-clock time (approx): {wall_time:.2f}s")
        print(f" ---")
        print(f" Avg acceptance length (tau): {overall_avg_tau:.2f}")
        print(f" Median acceptance length: {overall_median_tau:.1f}")
        print(f" Per-sample avg tau: {[f'{t:.2f}' for t in per_sample_taus]}")
        if per_sample_taus:
            print(f" Min per-sample tau: {min(per_sample_taus):.2f}")
            print(f" Max per-sample tau: {max(per_sample_taus):.2f}")
        print("=" * 60)

        # Save results
        if args.output_file:
            output = {
                "config": {
                    "target_model": args.target_model_path,
                    "draft_model": args.draft_model_path,
                    "dataset": args.dataset,
                    "max_samples": args.max_samples,
                    "max_new_tokens": args.max_new_tokens,
                    "num_denoise_steps": args.num_denoise_steps,
                    "temperature": args.temperature,
                    "num_gpus": world_size,
                },
                "results": {
                    "avg_tau": overall_avg_tau,
                    "median_tau": overall_median_tau,
                    "per_sample_tau": per_sample_taus,
                    "total_blocks": len(acc_lens),
                    "total_tokens": total_tokens,
                    "total_gpu_time": total_time,
                    "wall_clock_time": wall_time,
                    "all_acceptance_lengths": acc_lens,
                },
            }
            os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)
            with open(args.output_file, "w") as f:
                json.dump(output, f, indent=2)
            print(f"\nResults saved to {args.output_file}")


if __name__ == "__main__":
    main()