pulse_core_1 / analyze_results.py
#!/usr/bin/env python3
"""
Script to analyze and compare training results from multiple model runs.
"""
import json
from pathlib import Path

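# Each run directory under runs/ is expected to contain a metadata.json file
# written by the training script. A minimal sketch of the fields this script
# reads (names taken from the .get() calls below; the run id and all values
# are purely illustrative, and the real files may carry additional keys):
#
#   runs/svc_vntc_001/metadata.json
#   {
#       "classifier": "SVC",
#       "config_name": "VNTC_tfidf",
#       "max_features": 20000,
#       "ngram_range": [1, 2],
#       "train_accuracy": 0.99,
#       "test_accuracy": 0.92,
#       "train_time": 120.5,
#       "prediction_time": 3.4,
#       "train_samples": 33000,
#       "test_samples": 50000
#   }
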
def load_metadata(run_dir):
    """Load metadata from a training run directory"""
    metadata_path = Path(run_dir) / "metadata.json"
    if metadata_path.exists():
        with open(metadata_path, "r", encoding="utf-8") as f:
            return json.load(f)
    return None

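# Given the layout sketched above, load_metadata(Path("runs/svc_vntc_001"))
# returns the parsed dict; a run directory without metadata.json yields None
# and is skipped by analyze_all_runs() below.
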
def analyze_all_runs():
    """Analyze all training runs and build a comparison list"""
    runs_dir = Path("runs")
    results = []
    # Find all metadata files
    for run_dir in runs_dir.glob("*/"):
        if run_dir.is_dir():
            metadata = load_metadata(run_dir)
            if metadata:
                results.append({
                    'run_id': run_dir.name,
                    'model': metadata.get('classifier', 'Unknown'),
                    'dataset': 'VNTC' if 'VNTC' in metadata.get('config_name', '') else 'UTS2017_Bank',
                    'max_features': metadata.get('max_features', 0),
                    'ngram_range': metadata.get('ngram_range', [1, 1]),
                    'train_accuracy': metadata.get('train_accuracy', 0),
                    'test_accuracy': metadata.get('test_accuracy', 0),
                    'train_time': metadata.get('train_time', 0),
                    'prediction_time': metadata.get('prediction_time', 0),
                    'train_samples': metadata.get('train_samples', 0),
                    'test_samples': metadata.get('test_samples', 0),
                })
    return results

def print_dataset_table(title, dataset_results):
    """Print one dataset's results as a fixed-width table, sorted by test accuracy"""
    print(f"\n{title}:")
    print("-" * 120)
    print(f"{'Model':<20} {'Features':<10} {'N-gram':<10} {'Train Acc':<12} "
          f"{'Test Acc':<12} {'Train Time':<12} {'Pred Time':<12}")
    print("-" * 120)
    # Sort by test accuracy, best first
    dataset_results.sort(key=lambda x: x['test_accuracy'], reverse=True)
    for result in dataset_results:
        model = result['model'][:18]
        features = f"{result['max_features'] // 1000}k" if result['max_features'] > 0 else "N/A"
        ngram = f"{result['ngram_range'][0]}-{result['ngram_range'][1]}"
        train_acc = f"{result['train_accuracy']:.4f}"
        test_acc = f"{result['test_accuracy']:.4f}"
        train_time = f"{result['train_time']:.1f}s"
        pred_time = f"{result['prediction_time']:.1f}s"
        print(f"{model:<20} {features:<10} {ngram:<10} {train_acc:<12} "
              f"{test_acc:<12} {train_time:<12} {pred_time:<12}")


def print_comparison_table(results):
    """Print formatted comparison tables for the VNTC and UTS2017_Bank datasets"""
    print("\n" + "=" * 120)
    print("VIETNAMESE TEXT CLASSIFICATION - MODEL COMPARISON RESULTS")
    print("=" * 120)

    vntc_results = [r for r in results if r['dataset'] == 'VNTC']
    if vntc_results:
        print_dataset_table("VNTC Dataset (Vietnamese News Classification)", vntc_results)

    bank_results = [r for r in results if r['dataset'] == 'UTS2017_Bank']
    if bank_results:
        print_dataset_table("UTS2017_Bank Dataset (Vietnamese Banking Text Classification)",
                            bank_results)

    print("=" * 120)

    if vntc_results:
        best_vntc = max(vntc_results, key=lambda x: x['test_accuracy'])
        print(f"\nBest VNTC model: {best_vntc['model']} with "
              f"{best_vntc['test_accuracy']:.4f} test accuracy")
    if bank_results:
        best_bank = max(bank_results, key=lambda x: x['test_accuracy'])
        print(f"Best UTS2017_Bank model: {best_bank['model']} with "
              f"{best_bank['test_accuracy']:.4f} test accuracy")

def main():
    """Main analysis function"""
    print("Analyzing Vietnamese Text Classification Training Results...")
    results = analyze_all_runs()
    if not results:
        print("No training results found in the runs/ directory.")
        return
    print(f"Found {len(results)} training runs.")
    print_comparison_table(results)

    # Create summary statistics
    vntc_results = [r for r in results if r['dataset'] == 'VNTC']
    bank_results = [r for r in results if r['dataset'] == 'UTS2017_Bank']
    print("\nSummary:")
    print(f"- VNTC runs: {len(vntc_results)}")
    print(f"- UTS2017_Bank runs: {len(bank_results)}")
    if vntc_results:
        avg_vntc_acc = sum(r['test_accuracy'] for r in vntc_results) / len(vntc_results)
        print(f"- Average VNTC test accuracy: {avg_vntc_acc:.4f}")
    if bank_results:
        avg_bank_acc = sum(r['test_accuracy'] for r in bank_results) / len(bank_results)
        print(f"- Average UTS2017_Bank test accuracy: {avg_bank_acc:.4f}")


if __name__ == "__main__":
    main()
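
# Example invocation (assumes the script is run from the repository root,
# next to the runs/ directory produced by training):
#
#   $ python analyze_results.py
#
# The script prints a fixed-width comparison table per dataset, the best
# model per dataset, and the summary counts/averages defined in main().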