""" |
|
|
Script to analyze and compare training results from multiple model runs. |
|
|
""" |

import json
from pathlib import Path
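
# The loader below assumes each run saved a metadata.json shaped roughly like
# this sketch. Only the keys this script actually reads are listed, and the
# values are illustrative, not taken from any real run:
#
#   {
#       "classifier": "LinearSVC",
#       "config_name": "VNTC_tfidf",
#       "max_features": 50000,
#       "ngram_range": [1, 2],
#       "train_accuracy": 0.98,
#       "test_accuracy": 0.92,
#       "train_time": 12.3,
#       "prediction_time": 0.8,
#       "train_samples": 33759,
#       "test_samples": 10398
#   }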


def load_metadata(run_dir):
    """Load metadata.json from a run directory; return None if it is missing or unreadable."""
    metadata_path = run_dir / "metadata.json"
    try:
        with open(metadata_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        # A run without readable metadata is skipped rather than aborting the analysis.
        return None


def analyze_all_runs():
    """Collect metadata from every run directory under runs/ into a flat result list."""
    runs_dir = Path("runs")
    results = []

    for run_dir in runs_dir.glob("*/"):
        if not run_dir.is_dir():
            continue
        metadata = load_metadata(run_dir)
        if not metadata:
            continue
        results.append({
            'run_id': run_dir.name,
            'model': metadata.get('classifier', 'Unknown'),
            'dataset': 'VNTC' if 'VNTC' in metadata.get('config_name', '') else 'UTS2017_Bank',
            'max_features': metadata.get('max_features', 0),
            'ngram_range': metadata.get('ngram_range', [1, 1]),
            'train_accuracy': metadata.get('train_accuracy', 0),
            'test_accuracy': metadata.get('test_accuracy', 0),
            'train_time': metadata.get('train_time', 0),
            'prediction_time': metadata.get('prediction_time', 0),
            'train_samples': metadata.get('train_samples', 0),
            'test_samples': metadata.get('test_samples', 0),
        })

    return results


def print_dataset_table(results, title):
    """Print one dataset's runs as a fixed-width table, best test accuracy first."""
    print(f"\n{title}:")
    print("-" * 120)
    print(f"{'Model':<20} {'Features':<10} {'N-gram':<10} {'Train Acc':<12} "
          f"{'Test Acc':<12} {'Train Time':<12} {'Pred Time':<12}")
    print("-" * 120)

    for result in sorted(results, key=lambda x: x['test_accuracy'], reverse=True):
        model = result['model'][:18]
        features = f"{result['max_features'] // 1000}k" if result['max_features'] > 0 else "N/A"
        ngram = f"{result['ngram_range'][0]}-{result['ngram_range'][1]}"
        train_acc = f"{result['train_accuracy']:.4f}"
        test_acc = f"{result['test_accuracy']:.4f}"
        train_time = f"{result['train_time']:.1f}s"
        pred_time = f"{result['prediction_time']:.1f}s"
        print(f"{model:<20} {features:<10} {ngram:<10} {train_acc:<12} "
              f"{test_acc:<12} {train_time:<12} {pred_time:<12}")


def print_comparison_table(results):
    """Print formatted comparison tables and the best model per dataset."""
    print("\n" + "=" * 120)
    print("VIETNAMESE TEXT CLASSIFICATION - MODEL COMPARISON RESULTS")
    print("=" * 120)

    vntc_results = [r for r in results if r['dataset'] == 'VNTC']
    bank_results = [r for r in results if r['dataset'] == 'UTS2017_Bank']

    if vntc_results:
        print_dataset_table(vntc_results, "VNTC Dataset (Vietnamese News Classification)")
    if bank_results:
        print_dataset_table(bank_results, "UTS2017_Bank Dataset (Vietnamese Banking Text Classification)")

    print("=" * 120)

    if vntc_results:
        best_vntc = max(vntc_results, key=lambda x: x['test_accuracy'])
        print(f"\nBest VNTC model: {best_vntc['model']} with {best_vntc['test_accuracy']:.4f} test accuracy")
    if bank_results:
        best_bank = max(bank_results, key=lambda x: x['test_accuracy'])
        print(f"Best UTS2017_Bank model: {best_bank['model']} with {best_bank['test_accuracy']:.4f} test accuracy")


def main():
    """Run the full analysis and print a closing summary."""
    print("Analyzing Vietnamese Text Classification Training Results...")

    results = analyze_all_runs()
    if not results:
        print("No training results found in the runs/ directory.")
        return

    print(f"Found {len(results)} training runs.")
    print_comparison_table(results)

    vntc_results = [r for r in results if r['dataset'] == 'VNTC']
    bank_results = [r for r in results if r['dataset'] == 'UTS2017_Bank']

    print("\nSummary:")
    print(f"- VNTC runs: {len(vntc_results)}")
    print(f"- UTS2017_Bank runs: {len(bank_results)}")

    if vntc_results:
        avg_vntc_acc = sum(r['test_accuracy'] for r in vntc_results) / len(vntc_results)
        print(f"- Average VNTC test accuracy: {avg_vntc_acc:.4f}")
    if bank_results:
        avg_bank_acc = sum(r['test_accuracy'] for r in bank_results) / len(bank_results)
        print(f"- Average UTS2017_Bank test accuracy: {avg_bank_acc:.4f}")


if __name__ == "__main__":
    main()