#!/usr/bin/env python3
"""Script to analyze and compare training results from multiple model runs."""

import json
import os  # retained for backward compatibility with external users of this module
from pathlib import Path


def load_metadata(run_dir):
    """Load metadata from a training run directory.

    Args:
        run_dir: Path or str of the run directory.

    Returns:
        Parsed contents of ``metadata.json`` as a dict, or None if the
        file does not exist.
    """
    metadata_path = Path(run_dir) / "metadata.json"
    try:
        # EAFP: avoids the exists()/open() race and the redundant stat call.
        with open(metadata_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return None


def analyze_all_runs():
    """Analyze all training runs under runs/ and return a list of flat dicts."""
    runs_dir = Path("runs")
    results = []
    # Find all metadata files, one per run subdirectory.
    for run_dir in runs_dir.glob("*/"):
        if not run_dir.is_dir():
            continue
        metadata = load_metadata(run_dir)
        if not metadata:
            continue
        results.append({
            'run_id': run_dir.name,
            'model': metadata.get('classifier', 'Unknown'),
            # Dataset inferred from the config name; anything without
            # "VNTC" is assumed to be the banking dataset.
            'dataset': 'VNTC' if 'VNTC' in metadata.get('config_name', '') else 'UTS2017_Bank',
            'max_features': metadata.get('max_features', 0),
            'ngram_range': metadata.get('ngram_range', [1, 1]),
            'train_accuracy': metadata.get('train_accuracy', 0),
            'test_accuracy': metadata.get('test_accuracy', 0),
            'train_time': metadata.get('train_time', 0),
            'prediction_time': metadata.get('prediction_time', 0),
            'train_samples': metadata.get('train_samples', 0),
            'test_samples': metadata.get('test_samples', 0),
        })
    return results


def _print_dataset_table(dataset_results, title):
    """Print one dataset's comparison table, best model first.

    Sorts ``dataset_results`` in place by test accuracy (descending);
    callers pass freshly filtered lists, so input data is not mutated.
    """
    print(f"\n{title}:")
    print("-" * 120)
    print(f"{'Model':<20} {'Features':<10} {'N-gram':<10} {'Train Acc':<12} "
          f"{'Test Acc':<12} {'Train Time':<12} {'Pred Time':<12}")
    print("-" * 120)
    dataset_results.sort(key=lambda x: x['test_accuracy'], reverse=True)
    for result in dataset_results:
        model = result['model'][:18]
        # max_features == 0 means "not recorded" in metadata.
        features = f"{result['max_features'] // 1000}k" if result['max_features'] > 0 else "N/A"
        ngram = f"{result['ngram_range'][0]}-{result['ngram_range'][1]}"
        train_acc = f"{result['train_accuracy']:.4f}"
        test_acc = f"{result['test_accuracy']:.4f}"
        train_time = f"{result['train_time']:.1f}s"
        pred_time = f"{result['prediction_time']:.1f}s"
        print(f"{model:<20} {features:<10} {ngram:<10} {train_acc:<12} "
              f"{test_acc:<12} {train_time:<12} {pred_time:<12}")


def print_comparison_table(results):
    """Print formatted comparison tables for both datasets and the best models."""
    print("\n" + "=" * 120)
    print("VIETNAMESE TEXT CLASSIFICATION - MODEL COMPARISON RESULTS")
    print("=" * 120)

    # Filter for VNTC results (news classification).
    vntc_results = [r for r in results if r['dataset'] == 'VNTC']
    if vntc_results:
        _print_dataset_table(vntc_results,
                             "VNTC Dataset (Vietnamese News Classification)")

    # Filter for UTS2017_Bank results.
    bank_results = [r for r in results if r['dataset'] == 'UTS2017_Bank']
    if bank_results:
        _print_dataset_table(bank_results,
                             "UTS2017_Bank Dataset (Vietnamese Banking Text Classification)")

    print("=" * 120)

    if vntc_results:
        best_vntc = max(vntc_results, key=lambda x: x['test_accuracy'])
        print(f"\nBest VNTC model: {best_vntc['model']} "
              f"with {best_vntc['test_accuracy']:.4f} test accuracy")
    if bank_results:
        best_bank = max(bank_results, key=lambda x: x['test_accuracy'])
        print(f"Best UTS2017_Bank model: {best_bank['model']} "
              f"with {best_bank['test_accuracy']:.4f} test accuracy")


def main():
    """Main analysis function."""
    print("Analyzing Vietnamese Text Classification Training Results...")
    results = analyze_all_runs()
    if not results:
        print("No training results found in runs/ directory.")
        return

    print(f"Found {len(results)} training runs.")
    print_comparison_table(results)

    # Create summary statistics.
    vntc_results = [r for r in results if r['dataset'] == 'VNTC']
    bank_results = [r for r in results if r['dataset'] == 'UTS2017_Bank']
    print("\nSummary:")
    print(f"- VNTC runs: {len(vntc_results)}")
    print(f"- UTS2017_Bank runs: {len(bank_results)}")
    if vntc_results:
        avg_vntc_acc = sum(r['test_accuracy'] for r in vntc_results) / len(vntc_results)
        print(f"- Average VNTC test accuracy: {avg_vntc_acc:.4f}")
    if bank_results:
        avg_bank_acc = sum(r['test_accuracy'] for r in bank_results) / len(bank_results)
        print(f"- Average UTS2017_Bank test accuracy: {avg_bank_acc:.4f}")


if __name__ == "__main__":
    main()