#!/usr/bin/env python3
"""
Mass evaluation script that runs a set of predefined prompts through every
checkpoint of a model. The approach is deliberately simple and minimal, with
results logged as a readable Markdown report.
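
Usage (mirrors the CLI defined in main()):
    python benchmark.py <model_name> [--output DIR] [--prompts FILE]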
"""

import os
import sys
import glob
import time
import json
import argparse
from datetime import datetime
from typing import List, Optional

# Add the parent directory to the path so we can import the inference module,
# independent of the working directory the script is launched from
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def load_prompts(prompts_file: str = "prompts.json") -> List[str]:
    """
    Load benchmark prompts from JSON file.
    
    Args:
        prompts_file: Path to the prompts JSON file
        
    Returns:
        List of prompt strings
    """
    # Get the directory of this script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    prompts_path = os.path.join(script_dir, prompts_file)
    
    if not os.path.exists(prompts_path):
        print(f"⚠️  Prompts file not found: {prompts_path}")
        print("Using default fallback prompts...")
        # Fallback prompts if file doesn't exist
        return ["Hello, how are you?"]
    
    try:
        with open(prompts_path, 'r') as f:
            prompts = json.load(f)
        
        # Handle both old format (dict with benchmark_prompts) and new format (simple list)
        if isinstance(prompts, dict) and "benchmark_prompts" in prompts:
            # Old format - extract text field
            prompts = [p.get("text", str(p)) for p in prompts["benchmark_prompts"]]
        elif isinstance(prompts, list):
            # New simple format - already a list of strings
            pass
        else:
            print("⚠️  Invalid prompts format, using fallback")
            return ["Hello, how are you?"]
        
        print(f"πŸ“ Loaded {len(prompts)} prompts from {prompts_file}")
        return prompts
        
    except json.JSONDecodeError as e:
        print(f"❌ Error parsing prompts file: {e}")
        print("Using default fallback prompts...")
        return ["Hello, how are you?"]
    except Exception as e:
        print(f"❌ Error loading prompts file: {e}")
        print("Using default fallback prompts...")
        return ["Hello, how are you?"]


def discover_checkpoints(model_name: str, base_dir: str = "../pico-train/runs") -> List[str]:
    """
    Discover all available checkpoints for a given model.
    
    Args:
        model_name: Name of the model
        base_dir: Base directory for model runs (relative to the current working directory)
        
    Returns:
        List of checkpoint paths sorted by step number
    """
    model_path = os.path.join(base_dir, model_name, "checkpoints")
    
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model directory not found: {model_path}")
    
    # Find all step_* directories
    pattern = os.path.join(model_path, "step_*")
    checkpoint_dirs = glob.glob(pattern)
    
    # Filter out non-directories and extract step numbers for sorting
    valid_checkpoints = []
    for checkpoint_dir in checkpoint_dirs:
        if os.path.isdir(checkpoint_dir):
            try:
                step_num = int(os.path.basename(checkpoint_dir).split('_')[1])
                valid_checkpoints.append((step_num, checkpoint_dir))
            except (IndexError, ValueError):
                continue
    
    # Sort by step number and return paths
    valid_checkpoints.sort(key=lambda x: x[0])
    return [checkpoint_path for _, checkpoint_path in valid_checkpoints]


def run_benchmark(model_name: str, output_dir: str = "results", prompts_file: str = "prompts.json") -> Optional[str]:
    """
    Run benchmark evaluation on all checkpoints of a model.
    
    Args:
        model_name: Name of the model to benchmark
        output_dir: Directory to save results
        prompts_file: Path to the prompts JSON file
        
    Returns:
        Path to the generated report file, or None if the benchmark could not run
    """
    print(f"πŸš€ Starting benchmark for model: {model_name}")
    
    # Load prompts
    benchmark_prompts = load_prompts(prompts_file)
    if not benchmark_prompts:
        print("❌ No prompts loaded")
        return None
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Discover checkpoints
    try:
        checkpoints = discover_checkpoints(model_name)
        print(f"πŸ“Š Found {len(checkpoints)} checkpoints")
    except FileNotFoundError as e:
        print(f"❌ Error: {e}")
        return None
    
    if not checkpoints:
        print("❌ No valid checkpoints found")
        return None
    
    # Generate report filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_file = os.path.join(output_dir, f"{model_name}_benchmark_{timestamp}.md")
    
    # Import inference module
    try:
        from inference import PicoLMInference
    except ImportError as e:
        print(f"❌ Failed to import inference module: {e}")
        return None
    
    # Start writing report
    with open(report_file, 'w') as f:
        f.write(f"# Benchmark Report: {model_name}\n\n")
        f.write(f"**Generated**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"**Total Checkpoints**: {len(checkpoints)}\n")
        f.write(f"**Total Prompts**: {len(benchmark_prompts)}\n\n")
        f.write("---\n\n")
        
        # Process each checkpoint
        for i, checkpoint_path in enumerate(checkpoints, 1):
            checkpoint_name = os.path.basename(checkpoint_path)
            print(f"πŸ“ Processing {checkpoint_name} ({i}/{len(checkpoints)})")
            
            f.write(f"## Checkpoint: {checkpoint_name}\n\n")
            f.write(f"**Path**: `{checkpoint_path}`\n\n")
            
            try:
                # Load model for this checkpoint
                start_time = time.time()
                inference = PicoLMInference(checkpoint_path=checkpoint_path, device="cuda")
                load_time = time.time() - start_time
                
                f.write(f"**Load Time**: {load_time:.2f}s\n\n")
                
                # Run all prompts
                for j, prompt_text in enumerate(benchmark_prompts, 1):
                    print(f"  └─ Prompt {j}/{len(benchmark_prompts)}: {prompt_text[:30]}...")
                    
                    f.write(f"### Prompt {j}: \"{prompt_text}\"\n\n")
                    
                    try:
                        # Generate response with default parameters
                        gen_start = time.time()
                        response = inference.generate_completion(
                            prompt=prompt_text,
                            max_length=100,
                            temperature=0.7
                        )
                        gen_time = time.time() - gen_start
                        
                        f.write(f"**Response**:\n```\n{response}\n```\n\n")
                        f.write(f"**Metadata**: max_length=100, temperature=0.7, time={gen_time:.2f}s\n\n")
                        
                    except Exception as e:
                        f.write(f"**Error**: {str(e)}\n\n")
                        print(f"    ⚠️  Error on prompt {j}: {e}")
                
            except Exception as e:
                f.write(f"**Checkpoint Error**: {str(e)}\n\n")
                print(f"  ❌ Failed to load checkpoint: {e}")
            
            f.write("---\n\n")
    
    print(f"βœ… Benchmark complete! Report saved to: {report_file}")
    return report_file


def main() -> int:
    """Command-line entry point; returns a process exit code."""
    parser = argparse.ArgumentParser(
        description="Run benchmark evaluation on all checkpoints of a model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python benchmark.py pico-decoder-tiny-dolma5M-v1
  python benchmark.py pico-decoder-tiny-dolma29k-v3 --output results/
        """
    )
    
    parser.add_argument("model_name", type=str, 
                       help="Model name (e.g., 'pico-decoder-tiny-dolma5M-v1')")
    parser.add_argument("--output", "-o", type=str, default="results",
                       help="Output directory for results (default: results)")
    parser.add_argument("--prompts", "-p", type=str, default="prompts.json",
                       help="Prompts JSON file (default: prompts.json)")
    
    args = parser.parse_args()
    
    try:
        report_file = run_benchmark(args.model_name, args.output, args.prompts)
        if report_file:
            print(f"\nπŸ“„ Report available at: {report_file}")
            return 0
        else:
            return 1
            
    except KeyboardInterrupt:
        print("\n⏹️  Benchmark interrupted by user")
        return 1
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())