#!/usr/bin/env python3
"""
Mass evaluation script for running predefined prompts through all checkpoints of a model.
Simple, clean, and minimal approach with readable markdown logging.
"""
import os
import sys
import glob
import time
import json
import argparse
from datetime import datetime
from typing import List, Dict, Any, Optional

# Add the parent directory to path so we can import inference
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
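# Note: load_prompts() below accepts two prompts.json layouts (the prompt strings
# here are illustrative examples, not shipped data):
#   new format:  ["Hello, how are you?", "Summarize this paragraph."]
#   old format:  {"benchmark_prompts": [{"text": "Hello, how are you?"}]}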
def load_prompts(prompts_file: str = "prompts.json") -> List[str]:
    """
    Load benchmark prompts from JSON file.

    Args:
        prompts_file: Path to the prompts JSON file

    Returns:
        List of prompt strings
    """
    # Get the directory of this script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    prompts_path = os.path.join(script_dir, prompts_file)

    if not os.path.exists(prompts_path):
        print(f"⚠️ Prompts file not found: {prompts_path}")
        print("Using default fallback prompts...")
        # Fallback prompts if file doesn't exist
        return ["Hello, how are you?"]

    try:
        with open(prompts_path, 'r') as f:
            prompts = json.load(f)

        # Handle both old format (dict with benchmark_prompts) and new format (simple list)
        if isinstance(prompts, dict) and "benchmark_prompts" in prompts:
            # Old format - extract text field
            prompts = [p.get("text", str(p)) for p in prompts["benchmark_prompts"]]
        elif isinstance(prompts, list):
            # New simple format - already a list of strings
            pass
        else:
            print("⚠️ Invalid prompts format, using fallback")
            return ["Hello, how are you?"]

        print(f"📄 Loaded {len(prompts)} prompts from {prompts_file}")
        return prompts

    except json.JSONDecodeError as e:
        print(f"❌ Error parsing prompts file: {e}")
        print("Using default fallback prompts...")
        return ["Hello, how are you?"]
    except Exception as e:
        print(f"❌ Error loading prompts file: {e}")
        print("Using default fallback prompts...")
        return ["Hello, how are you?"]
def discover_checkpoints(model_name: str, base_dir: str = "../pico-train/runs") -> List[str]:
    """
    Discover all available checkpoints for a given model.

    Args:
        model_name: Name of the model
        base_dir: Base directory for model runs

    Returns:
        List of checkpoint paths sorted by step number
    """
    model_path = os.path.join(base_dir, model_name, "checkpoints")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model directory not found: {model_path}")

    # Find all step_* directories
    pattern = os.path.join(model_path, "step_*")
    checkpoint_dirs = glob.glob(pattern)

    # Filter out non-directories and extract step numbers for sorting
    valid_checkpoints = []
    for checkpoint_dir in checkpoint_dirs:
        if os.path.isdir(checkpoint_dir):
            try:
                step_num = int(os.path.basename(checkpoint_dir).split('_')[1])
                valid_checkpoints.append((step_num, checkpoint_dir))
            except (IndexError, ValueError):
                continue

    # Sort by step number and return paths
    valid_checkpoints.sort(key=lambda x: x[0])
    return [checkpoint_path for _, checkpoint_path in valid_checkpoints]
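# run_benchmark() below writes one markdown report per run:
#   <output_dir>/<model_name>_benchmark_<timestamp>.md
# with a "## Checkpoint: step_N" section per checkpoint and a "### Prompt j" entry
# (response, generation time) per prompt. Per-prompt and per-checkpoint errors are
# recorded in the report instead of aborting the run.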
def run_benchmark(model_name: str, output_dir: str = "results", prompts_file: str = "prompts.json") -> Optional[str]:
    """
    Run benchmark evaluation on all checkpoints of a model.

    Args:
        model_name: Name of the model to benchmark
        output_dir: Directory to save results
        prompts_file: Path to the prompts JSON file

    Returns:
        Path to the generated report file, or None on failure
    """
    print(f"🚀 Starting benchmark for model: {model_name}")

    # Load prompts
    benchmark_prompts = load_prompts(prompts_file)
    if not benchmark_prompts:
        print("❌ No prompts loaded")
        return None

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Discover checkpoints
    try:
        checkpoints = discover_checkpoints(model_name)
        print(f"📁 Found {len(checkpoints)} checkpoints")
    except FileNotFoundError as e:
        print(f"❌ Error: {e}")
        return None

    if not checkpoints:
        print("❌ No valid checkpoints found")
        return None

    # Generate report filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_file = os.path.join(output_dir, f"{model_name}_benchmark_{timestamp}.md")

    # Import inference module
    try:
        from inference import PicoLMInference
    except ImportError as e:
        print(f"❌ Failed to import inference module: {e}")
        return None

    # Start writing report
    with open(report_file, 'w') as f:
        f.write(f"# Benchmark Report: {model_name}\n\n")
        f.write(f"**Generated**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"**Total Checkpoints**: {len(checkpoints)}\n")
        f.write(f"**Total Prompts**: {len(benchmark_prompts)}\n\n")
        f.write("---\n\n")

        # Process each checkpoint
        for i, checkpoint_path in enumerate(checkpoints, 1):
            checkpoint_name = os.path.basename(checkpoint_path)
            print(f"🔄 Processing {checkpoint_name} ({i}/{len(checkpoints)})")

            f.write(f"## Checkpoint: {checkpoint_name}\n\n")
            f.write(f"**Path**: `{checkpoint_path}`\n\n")

            try:
                # Load model for this checkpoint
                start_time = time.time()
                inference = PicoLMInference(checkpoint_path=checkpoint_path, device="cuda")
                load_time = time.time() - start_time
                f.write(f"**Load Time**: {load_time:.2f}s\n\n")

                # Run all prompts
                for j, prompt_text in enumerate(benchmark_prompts, 1):
                    print(f"  └─ Prompt {j}/{len(benchmark_prompts)}: {prompt_text[:30]}...")
                    f.write(f"### Prompt {j}: \"{prompt_text}\"\n\n")

                    try:
                        # Generate response with default parameters
                        gen_start = time.time()
                        response = inference.generate_completion(
                            prompt=prompt_text,
                            max_length=100,
                            temperature=0.7
                        )
                        gen_time = time.time() - gen_start

                        f.write(f"**Response**:\n```\n{response}\n```\n\n")
                        f.write(f"**Metadata**: max_length=100, temperature=0.7, time={gen_time:.2f}s\n\n")
                    except Exception as e:
                        f.write(f"**Error**: {str(e)}\n\n")
                        print(f"  ⚠️ Error on prompt {j}: {e}")
            except Exception as e:
                f.write(f"**Checkpoint Error**: {str(e)}\n\n")
                print(f"  ❌ Failed to load checkpoint: {e}")

            f.write("---\n\n")

    print(f"✅ Benchmark complete! Report saved to: {report_file}")
    return report_file
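# main() returns shell-style exit codes: 0 when a report was written, 1 on any failure
# or Ctrl-C, so the script can be chained in automation, e.g. (hypothetical shell usage):
#   python benchmark.py pico-decoder-tiny-dolma5M-v1 && echo "benchmark done"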
def main():
    """Main function with command-line interface."""
    parser = argparse.ArgumentParser(
        description="Run benchmark evaluation on all checkpoints of a model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python benchmark.py pico-decoder-tiny-dolma5M-v1
  python benchmark.py pico-decoder-tiny-dolma29k-v3 --output results/
        """
    )
    parser.add_argument("model_name", type=str,
                        help="Model name (e.g., 'pico-decoder-tiny-dolma5M-v1')")
    parser.add_argument("--output", "-o", type=str, default="results",
                        help="Output directory for results (default: results)")
    parser.add_argument("--prompts", "-p", type=str, default="prompts.json",
                        help="Prompts JSON file (default: prompts.json)")

    args = parser.parse_args()

    try:
        report_file = run_benchmark(args.model_name, args.output, args.prompts)
        if report_file:
            print(f"\n📊 Report available at: {report_file}")
            return 0
        else:
            return 1
    except KeyboardInterrupt:
        print("\n⏹️ Benchmark interrupted by user")
        return 1
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    sys.exit(main())