File size: 15,314 Bytes
2d67aa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
"""
DFlash evaluation script: measure acceptance length (tau) with optional multi-step denoising.
Supports multi-GPU data parallelism via torchrun.

Usage:
    # Single GPU
    python scripts/eval_dflash.py \
        --target-model-path /workspace/models/Qwen3-8B \
        --draft-model-path /workspace/models/Qwen3-8B-DFlash-b16 \
        --dataset math500 --max-samples 10 --num-denoise-steps 1

    # 8 GPU (data parallel, each GPU runs a subset of samples)
    torchrun --standalone --nproc_per_node 8 scripts/eval_dflash.py \
        --target-model-path /workspace/models/Qwen3-8B \
        --draft-model-path /workspace/models/Qwen3-8B-DFlash-b16 \
        --dataset math500 --max-samples 500 --num-denoise-steps 2
"""

import argparse
import json
import os
import sys
import time
from statistics import mean, median

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Put the parent of this script's directory at the front of sys.path so the
# project import below is resolved from this checkout when the script is run
# directly (presumably the `specforge` package lives at that root -- verify).
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(SCRIPT_DIR)
sys.path.insert(0, ROOT_DIR)

from specforge.modeling.draft.dflash import DFlashDraftModel


def parse_args():
    """Define the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace holding the target/draft model paths, dataset
        selection, sample and token limits, number of denoising steps,
        sampling temperature, dtype name and the optional output file.
    """
    cli = argparse.ArgumentParser(description="DFlash evaluation: acceptance length")

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ("--target-model-path", dict(
            type=str, required=True,
            help="Path to target model (e.g. /workspace/models/Qwen3-8B)")),
        ("--draft-model-path", dict(
            type=str, required=True,
            help="Path to DFlash draft model (e.g. /workspace/models/Qwen3-8B-DFlash-b16)")),
        ("--dataset", dict(
            type=str, default="math500",
            choices=["math500", "gsm8k", "custom"],
            help="Evaluation dataset")),
        ("--custom-data-path", dict(
            type=str, default=None,
            help="Path to custom jsonl data (when --dataset=custom)")),
        ("--max-samples", dict(
            type=int, default=10,
            help="Max number of evaluation samples (total across all GPUs)")),
        ("--max-new-tokens", dict(
            type=int, default=512,
            help="Max new tokens per sample")),
        ("--num-denoise-steps", dict(
            type=int, default=1,
            help="Number of denoising steps (1=baseline, 2/3=multi-step)")),
        ("--temperature", dict(
            type=float, default=0.0,
            help="Sampling temperature (0.0=greedy)")),
        ("--dtype", dict(
            type=str, default="bfloat16",
            choices=["float16", "bfloat16", "float32"])),
        ("--output-file", dict(
            type=str, default=None,
            help="Save results to JSON file")),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def get_rank_and_world():
    """Read the launcher-provided rank and world size from the environment.

    The rank is taken from ``LOCAL_RANK`` when that variable is set, otherwise
    from ``RANK``, otherwise it is 0. The world size comes from ``WORLD_SIZE``
    and defaults to 1, so a plain single-process run yields ``(0, 1)``.

    Returns:
        tuple[int, int]: ``(rank, world_size)``.
    """
    env = os.environ

    rank_value = env.get("LOCAL_RANK")
    if rank_value is None:
        rank_value = env.get("RANK", 0)

    world_value = env.get("WORLD_SIZE", 1)
    return int(rank_value), int(world_value)


def load_eval_data(args):
    """Load up to ``args.max_samples`` evaluation prompts as a list of strings.

    Supported sources (selected by ``args.dataset``):

    * ``"math500"`` -- the ``problem`` field of the MATH-500 ``test`` split.
    * ``"gsm8k"``   -- the ``question`` field of the GSM8K ``main``/``test`` split.
    * ``"custom"``  -- a JSONL file at ``args.custom_data_path``. Each record
      contributes either its ``"prompt"`` value or the content of the first
      message whose ``"role"`` is ``"user"`` in ``"conversations"``. Records
      with neither are skipped and do not count towards the sample limit;
      blank lines are ignored.

    Hub datasets are loaded with a fixed ``cache_dir``; the original intent is
    to read them from that local cache without network access.

    Args:
        args: Namespace with ``dataset``, ``max_samples`` and, for the custom
            dataset, ``custom_data_path``.

    Returns:
        list: The collected prompts, at most ``args.max_samples`` of them.

    Raises:
        ValueError: If ``dataset == "custom"`` but no path was given.
        json.JSONDecodeError: If a non-blank line of the custom file is not
            valid JSON.
    """
    prompts = []

    # Use local cached datasets (no network access needed)
    DATASET_CACHE = "/workspace/hanrui/datasets"

    # Hub-backed datasets: (load_dataset positional args, field holding the prompt).
    hub_sources = {
        "math500": (("HuggingFaceH4/MATH-500",), "problem"),
        "gsm8k": (("openai/gsm8k", "main"), "question"),
    }

    if args.dataset in hub_sources:
        hub_args, field = hub_sources[args.dataset]
        dataset = load_dataset(*hub_args, cache_dir=DATASET_CACHE)["test"]
        for idx, item in enumerate(dataset):
            if idx >= args.max_samples:
                break
            prompts.append(item[field])

    elif args.dataset == "custom":
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if args.custom_data_path is None:
            raise ValueError("Need --custom-data-path for custom dataset")

        with open(args.custom_data_path, "r", encoding="utf-8") as f:
            for line in f:
                # Limit on prompts actually collected, not on lines read, so
                # unusable records do not silently shrink the sample set.
                if len(prompts) >= args.max_samples:
                    break
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines in the JSONL file.
                    continue
                data = json.loads(line)
                if "prompt" in data:
                    prompts.append(data["prompt"])
                elif "conversations" in data:
                    # Only the first user turn is used as the prompt.
                    for msg in data["conversations"]:
                        if msg["role"] == "user":
                            prompts.append(msg["content"])
                            break

    return prompts


def load_draft_model(draft_model_path, torch_dtype, device):
    """Build a ``DFlashDraftModel`` from a checkpoint directory and load its weights.

    The model is constructed from the directory's config, cast to
    ``torch_dtype``, and populated from every ``*.safetensors`` file in the
    directory (in sorted order). If there are none, every ``*.bin`` file is
    used instead. Weights are loaded non-strictly; missing and unexpected
    keys are printed rather than raised.

    Args:
        draft_model_path: Directory containing the config and weight files.
        torch_dtype: dtype the freshly built model is cast to before loading.
        device: Device the loaded model is moved to (also used as the label
            in the printed key reports).

    Returns:
        The draft model on ``device`` in eval mode.

    Raises:
        FileNotFoundError: If the directory has neither safetensors nor bin files.
    """
    from transformers import AutoConfig
    from safetensors.torch import load_file as read_safetensors
    import glob as glob_module

    config = AutoConfig.from_pretrained(draft_model_path, trust_remote_code=True)
    model = DFlashDraftModel(config).to(torch_dtype)

    def find_shards(pattern):
        # Sorted so shards are merged in a deterministic order.
        return sorted(glob_module.glob(os.path.join(draft_model_path, pattern)))

    safetensor_shards = find_shards("*.safetensors")
    bin_shards = find_shards("*.bin")

    if not safetensor_shards and not bin_shards:
        raise FileNotFoundError(f"No safetensors or bin files found in {draft_model_path}")

    # safetensors take precedence; .bin files are only a fallback.
    weights = {}
    if safetensor_shards:
        for shard in safetensor_shards:
            weights.update(read_safetensors(shard, device="cpu"))
    else:
        for shard in bin_shards:
            weights.update(torch.load(shard, map_location="cpu", weights_only=True))

    load_report = model.load_state_dict(weights, strict=False)
    missing, unexpected = load_report
    if missing:
        print(f"  [rank {device}] WARNING: missing keys: {missing}")
    if unexpected:
        print(f"  [rank {device}] INFO: unexpected keys (ignored): {unexpected}")

    model = model.to(device)
    return model.eval()


def run_eval_on_samples(draft, target, tokenizer, prompts, args, device, rank):
    """Run speculative decoding on each prompt and collect acceptance statistics.

    Every prompt is wrapped as a single user message and rendered with the
    tokenizer's chat template; if templating fails the raw prompt is tokenized
    instead (a warning is printed the first time this happens). Generation is
    delegated to ``draft.spec_generate`` and timed on the host.

    Args:
        draft: Draft model exposing ``spec_generate(...)`` that returns
            ``(output_ids, acceptance_lengths)``.
        target: Target model, passed through to ``draft.spec_generate``.
        tokenizer: Tokenizer used for templating, encoding and decoding.
        prompts: Sequence of prompt strings to evaluate.
        args: Namespace providing ``max_new_tokens``, ``temperature`` and
            ``num_denoise_steps``.
        device: Device the encoded prompt is moved to.
        rank: Only used to label the progress lines.

    Returns:
        dict with four lists:
            ``acceptance_lengths`` -- all per-block acceptance lengths, concatenated;
            ``per_sample_taus``    -- mean acceptance length per prompt (0 if none);
            ``times``              -- elapsed seconds per prompt;
            ``generated_tokens``   -- number of new tokens per prompt.
    """
    # Stop tokens: EOS plus the first three additional special tokens. Drop a
    # missing EOS (None) and duplicates, keeping first-seen order.
    candidate_ids = [tokenizer.eos_token_id]
    candidate_ids.extend((getattr(tokenizer, "additional_special_tokens_ids", None) or [])[:3])
    stop_token_ids = list(dict.fromkeys(t for t in candidate_ids if t is not None))

    all_acceptance_lengths = []
    all_taus = []
    all_times = []
    all_generated_tokens = []
    warned_no_template = False

    for i, prompt in enumerate(prompts):
        # Tokenize with chat template; fall back to the raw prompt (best effort).
        messages = [{"role": "user", "content": prompt}]
        try:
            input_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception as exc:
            if not warned_no_template:
                print(f"  [GPU {rank}] WARNING: chat template failed ({exc!r}); "
                      f"falling back to raw prompt")
                warned_no_template = True
            input_text = prompt
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

        num_input_tokens = input_ids.shape[1]

        # Generate with speculative decoding. Evaluation never needs gradients;
        # perf_counter is monotonic, unlike time.time().
        t0 = time.perf_counter()
        with torch.no_grad():
            output_ids, acceptance_lengths = draft.spec_generate(
                target=target,
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                stop_token_ids=stop_token_ids,
                temperature=args.temperature,
                num_denoise_steps=args.num_denoise_steps,
            )
        elapsed = time.perf_counter() - t0

        num_output_tokens = output_ids.shape[1] - num_input_tokens
        avg_tau = mean(acceptance_lengths) if acceptance_lengths else 0

        all_acceptance_lengths.extend(acceptance_lengths)
        all_taus.append(avg_tau)
        all_times.append(elapsed)
        all_generated_tokens.append(num_output_tokens)

        # Short single-line preview of the generated continuation for the log.
        output_text = tokenizer.decode(output_ids[0, num_input_tokens:], skip_special_tokens=True)
        preview = output_text[:80].replace("\n", " ")

        print(f"  [GPU {rank}] Sample {i+1}/{len(prompts)} | "
              f"tokens={num_output_tokens} | tau={avg_tau:.2f} | "
              f"time={elapsed:.1f}s | {preview}...")

    return {
        "acceptance_lengths": all_acceptance_lengths,
        "per_sample_taus": all_taus,
        "times": all_times,
        "generated_tokens": all_generated_tokens,
    }


def _merge_rank_results(per_rank_results):
    """Concatenate per-rank result dicts and derive the parallel wall-clock time.

    Args:
        per_rank_results: List with one ``run_eval_on_samples`` result dict
            per rank, in rank order.

    Returns:
        tuple: ``(merged, wall_time)`` where ``merged`` has the same four keys
        with every rank's lists concatenated, and ``wall_time`` is the largest
        per-rank sum of sample times (ranks run concurrently, so the slowest
        rank bounds the run); 0 when nothing was evaluated.
    """
    keys = ("acceptance_lengths", "per_sample_taus", "times", "generated_tokens")
    merged = {key: [] for key in keys}
    for rank_results in per_rank_results:
        for key in keys:
            merged[key].extend(rank_results[key])
    wall_time = max((sum(r["times"]) for r in per_rank_results), default=0)
    return merged, wall_time


def main():
    """Evaluate a DFlash draft model against its target and report acceptance length.

    Flow: parse the CLI, load tokenizer / target / draft onto this process's
    GPU, shard the prompts round-robin across ranks, run speculative decoding
    on the local shard, gather every rank's results, and on global rank 0
    print a summary and optionally write it as JSON to ``--output-file``.

    Each process loads its own full copy of both models (data parallelism).
    """
    args = parse_args()
    local_rank, world_size = get_rank_and_world()
    # get_rank_and_world() prefers LOCAL_RANK, which is the right index for the
    # CUDA device but repeats on every node. Sharding and "rank 0 only" output
    # need the global RANK; fall back to the local value when it is not set.
    rank = int(os.environ.get("RANK", local_rank))
    is_main = rank == 0
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(local_rank)

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    if is_main:
        print("=" * 60)
        print("DFlash Evaluation (Multi-GPU Data Parallel)")
        print("=" * 60)
        print(f"  Target model:       {args.target_model_path}")
        print(f"  Draft model:        {args.draft_model_path}")
        print(f"  Dataset:            {args.dataset}")
        print(f"  Max samples:        {args.max_samples}")
        print(f"  Max new tokens:     {args.max_new_tokens}")
        print(f"  Denoise steps:      {args.num_denoise_steps}")
        print(f"  Temperature:        {args.temperature}")
        print(f"  GPUs:               {world_size}")
        print(f"  Dtype:              {args.dtype}")
        print("=" * 60)

    # ---- Load models (each GPU loads its own copy) ----
    if is_main:
        print("\n[1/4] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.target_model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    if is_main:
        print(f"[2/4] Loading target model on {world_size} GPUs...")
    target = AutoModelForCausalLM.from_pretrained(
        args.target_model_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    ).to(device).eval()

    if is_main:
        print(f"[3/4] Loading draft model on {world_size} GPUs...")
    draft = load_draft_model(args.draft_model_path, torch_dtype, device)

    if is_main:
        print(f"  Draft layers:       {draft.config.num_hidden_layers}")
        print(f"  Draft block_size:   {draft.block_size}")
        print(f"  Draft mask_token:   {draft.mask_token_id}")
        print(f"  Draft layer_ids:    {draft.target_layer_ids}")

    # ---- Load and split data ----
    if is_main:
        print("[4/4] Loading evaluation data...")
    all_prompts = load_eval_data(args)

    # Round-robin split: rank r takes prompts r, r + world_size, ...
    my_prompts = all_prompts[rank::world_size]
    if is_main:
        print(f"  Total prompts: {len(all_prompts)}, ~{len(my_prompts)} per GPU")

    # ---- Run evaluation ----
    if is_main:
        print("\n" + "=" * 60)
        print("Running evaluation...")
        print("=" * 60)

    results = run_eval_on_samples(draft, target, tokenizer, my_prompts, args, device, local_rank)

    # ---- Gather results from all GPUs ----
    per_rank_results = [results]
    if world_size > 1:
        import torch.distributed as dist
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")

        # Collective object gather: no shared filesystem or manual barriers
        # needed. With NCCL the objects travel via the current CUDA device,
        # which was selected with torch.cuda.set_device above.
        gathered = [None] * world_size
        dist.all_gather_object(gathered, results)
        per_rank_results = gathered

        dist.destroy_process_group()

    # ---- Print summary (rank 0 only) ----
    if is_main:
        results, wall_time = _merge_rank_results(per_rank_results)

        acc_lens = results["acceptance_lengths"]
        per_sample_taus = results["per_sample_taus"]
        total_tokens = sum(results["generated_tokens"])
        total_time = sum(results["times"])

        overall_avg_tau = mean(acc_lens) if acc_lens else 0
        overall_median_tau = median(acc_lens) if acc_lens else 0

        print("\n" + "=" * 60)
        print("RESULTS SUMMARY")
        print("=" * 60)
        print(f"  Denoise steps:              {args.num_denoise_steps}")
        print(f"  GPUs used:                  {world_size}")
        print(f"  Samples evaluated:          {len(per_sample_taus)}")
        print(f"  Total blocks:               {len(acc_lens)}")
        print(f"  Total generated tokens:     {total_tokens}")
        print(f"  Total GPU-time:             {total_time:.2f}s")
        print(f"  Wall-clock time (approx):   {wall_time:.2f}s")
        print("  ---")
        print(f"  Avg acceptance length (tau): {overall_avg_tau:.2f}")
        print(f"  Median acceptance length:    {overall_median_tau:.1f}")
        print(f"  Per-sample avg tau:          {[f'{t:.2f}' for t in per_sample_taus]}")
        if per_sample_taus:
            print(f"  Min per-sample tau:          {min(per_sample_taus):.2f}")
            print(f"  Max per-sample tau:          {max(per_sample_taus):.2f}")
        print("=" * 60)

        # Save results
        if args.output_file:
            output = {
                "config": {
                    "target_model": args.target_model_path,
                    "draft_model": args.draft_model_path,
                    "dataset": args.dataset,
                    "max_samples": args.max_samples,
                    "max_new_tokens": args.max_new_tokens,
                    "num_denoise_steps": args.num_denoise_steps,
                    "temperature": args.temperature,
                    "num_gpus": world_size,
                },
                "results": {
                    "avg_tau": overall_avg_tau,
                    "median_tau": overall_median_tau,
                    "per_sample_tau": per_sample_taus,
                    "total_blocks": len(acc_lens),
                    "total_tokens": total_tokens,
                    "total_gpu_time": total_time,
                    "wall_clock_time": wall_time,
                    "all_acceptance_lengths": acc_lens,
                },
            }
            # dirname is "" for a bare filename; fall back to the current directory.
            os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)
            with open(args.output_file, "w") as f:
                json.dump(output, f, indent=2)
            print(f"\nResults saved to {args.output_file}")


# Run the evaluation only when this file is executed as a script (directly or
# through torchrun), not when it is imported as a module.
if __name__ == "__main__":
    main()