#!/usr/bin/env python3
# compression_eval_llm_template.py
import argparse, json, os, time, math
from typing import Dict, Any, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
try:
    from energy_logger_nvml import EnergyLogger
    _HAS_NVML = True
except Exception:
    _HAS_NVML = False
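
# The optional energy_logger_nvml module is assumed to expose a context manager
# roughly shaped like this sketch (illustrative only, not the real implementation):
#
#   class EnergyLogger:
#       def __init__(self, tag: str): ...
#       def __enter__(self): ...       # start NVML power sampling
#       def __exit__(self, *exc): ...  # stop sampling and fill self.summary
#
# with el.summary == {"energy_J": float, "avg_power_W": float} after exit.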

def model_bytes(model: torch.nn.Module) -> int:
    """Approximate in-memory model size: total bytes of all parameter tensors (buffers excluded)."""
    return sum(p.numel() * p.element_size() for p in model.parameters())
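# Quick sanity check (fp32): torch.nn.Linear(10, 10) has 110 parameters
# (100 weights + 10 biases), so model_bytes(torch.nn.Linear(10, 10)) == 110 * 4 == 440.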

def run_generation_bench(model, tokenizer, device, prompts: List[str], max_new_tokens=128):
    """Generate from each prompt and collect latency, throughput, and (if available) energy metrics."""
    tokens_generated = 0
    latencies = []

    def _generate_all():
        # Shared benchmark loop so the NVML and non-NVML paths stay identical.
        nonlocal tokens_generated
        for p in prompts:
            inputs = tokenizer(p, return_tensors="pt").to(device)
            t0 = time.time()
            with torch.no_grad():
                out = model.generate(**inputs, max_new_tokens=max_new_tokens)
            if device == "cuda":
                torch.cuda.synchronize()
            latencies.append(time.time() - t0)
            # Count tokens actually produced; generation can stop early at EOS.
            tokens_generated += out.shape[-1] - inputs["input_ids"].shape[-1]

    if _HAS_NVML:
        with EnergyLogger(tag="genbench") as el:
            _generate_all()
        energy_J = el.summary["energy_J"]
        avg_W = el.summary["avg_power_W"]
    else:
        _generate_all()
        energy_J = None
        avg_W = None

    total_s = sum(latencies)
    # Nearest-rank p95; clamp the index so short prompt lists stay in bounds.
    p95_idx = min(len(latencies) - 1, max(0, math.ceil(0.95 * len(latencies)) - 1))
    return {
        "tokens_generated": tokens_generated,
        "latency_ms_avg": 1000 * total_s / len(latencies),
        "latency_ms_p95": 1000 * sorted(latencies)[p95_idx],
        "tokens_per_s": tokens_generated / total_s,
        "energy_J": energy_J,
        "avg_power_W": avg_W,
        # e.g. 50 J spent on 1,000 tokens -> 50,000 J per 1M tokens
        "J_per_1M_tokens": None if energy_J is None else energy_J / max(1, tokens_generated) * 1_000_000,
    }

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, default="sshleifer/tiny-gpt2")
    ap.add_argument("--dtype", type=str, default="fp16", choices=["fp16","bf16","fp32"])
    ap.add_argument("--prompts_file", type=str, required=True)
    ap.add_argument("--max_new_tokens", type=int, default=64)
    ap.add_argument("--tag", type=str, default="baseline")
    ap.add_argument("--load_8bit", action="store_true")
    ap.add_argument("--load_4bit", action="store_true")
    args = ap.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[args.dtype]

    quant_args: Dict[str, Any] = {}
    if args.load_8bit:
        quant_args["load_in_8bit"] = True
        quant_args["device_map"] = "auto"
    elif args.load_4bit:
        quant_args["load_in_4bit"] = True
        quant_args["bnb_4bit_compute_dtype"] = dtype
        quant_args["device_map"] = "auto"

    tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype, **quant_args)
    model.eval()
    if not quant_args:
        # bitsandbytes-quantized models are already placed by device_map="auto"
        # and raise on .to(); only move the full-precision model explicitly.
        model.to(device)

    with open(args.prompts_file) as f:
        prompts = [json.loads(line)["text"] for line in f if line.strip()]
    size_bytes = model_bytes(model)
    bench = run_generation_bench(model, tok, device, prompts, max_new_tokens=args.max_new_tokens)

    out = {
        "model": args.model, "tag": args.tag, "dtype": args.dtype,
        "quant": "8bit" if args.load_8bit else ("4bit" if args.load_4bit else "full"),
        "size_bytes": int(size_bytes), **bench
    }
    os.makedirs("phase4_outputs", exist_ok=True)
    with open(f"phase4_outputs/llm_eval_{args.tag}.json", "w") as f:
        json.dump(out, f, indent=2)
    print(json.dumps(out, indent=2))

if __name__ == "__main__":
    main()
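
# Example invocation (hypothetical prompts file; any JSONL with a "text" field works):
#   printf '%s\n' '{"text": "The quick brown fox"}' > prompts.jsonl
#   python compression_eval_llm_template.py --model sshleifer/tiny-gpt2 \
#       --dtype fp32 --prompts_file prompts.jsonl --tag baseline
# Metrics are printed and written to phase4_outputs/llm_eval_baseline.json.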