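"""Plotting utilities for model evaluation: training history, confusion matrix,
attention weights, model comparison, and feature importance."""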
import logging
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def plot_training_history(history: dict, save_path: Optional[Path] = None):
    """
    Plot training and validation metrics over epochs.

    Args:
        history: Dictionary containing training history
        save_path: Path to save the plot
    """
    plt.figure(figsize=(12, 5))

    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Plot validation metrics per epoch
    plt.subplot(1, 2, 2)
    metrics = ['accuracy', 'precision', 'recall', 'f1']
    for metric in metrics:
        values = [epoch_metrics[metric] for epoch_metrics in history['val_metrics']]
        plt.plot(values, label=metric.capitalize())
    plt.title('Validation Metrics')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        logger.info(f"Training history plot saved to {save_path}")
    plt.close()
def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_path: Optional[Path] = None):
    """
    Plot confusion matrix for model predictions.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        save_path: Path to save the plot
    """
    # Imported locally so the module does not require scikit-learn unless needed
    from sklearn.metrics import confusion_matrix

    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    if save_path:
        plt.savefig(save_path)
        logger.info(f"Confusion matrix plot saved to {save_path}")
    plt.close()
def plot_attention_weights(text: str, attention_weights: np.ndarray, save_path: Optional[Path] = None):
    """
    Plot attention weights for a given text.

    Args:
        text: Input text
        attention_weights: Attention weights for each token
        save_path: Path to save the plot
    """
    # Tokens come from a simple whitespace split; attention_weights is expected
    # to contain one weight per resulting token.
    tokens = text.split()
    plt.figure(figsize=(12, 4))

    # Plot one bar per token
    plt.bar(range(len(tokens)), attention_weights)
    plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
    plt.title('Attention Weights')
    plt.xlabel('Tokens')
    plt.ylabel('Attention Weight')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        logger.info(f"Attention weights plot saved to {save_path}")
    plt.close()
def plot_model_comparison(metrics: dict, save_path: Optional[Path] = None):
    """
    Plot comparison of different models' performance.

    Args:
        metrics: Dictionary containing model metrics
        save_path: Path to save the plot
    """
    models = list(metrics.keys())
    metric_names = ['accuracy', 'precision', 'recall', 'f1']

    plt.figure(figsize=(10, 6))
    x = np.arange(len(models))
    width = 0.2
    for i, metric in enumerate(metric_names):
        values = [metrics[model][metric] for model in models]
        plt.bar(x + i * width, values, width, label=metric.capitalize())
    plt.title('Model Performance Comparison')
    plt.xlabel('Models')
    plt.ylabel('Score')
    # Center the tick labels under each group of bars
    plt.xticks(x + width * (len(metric_names) - 1) / 2, models, rotation=45)
    plt.legend()
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        logger.info(f"Model comparison plot saved to {save_path}")
    plt.close()
def plot_feature_importance(feature_importance: dict, save_path: Optional[Path] = None):
    """
    Plot feature importance scores.

    Args:
        feature_importance: Dictionary containing feature importance scores
        save_path: Path to save the plot
    """
    features = list(feature_importance.keys())
    importance = list(feature_importance.values())

    # Sort features by importance so the most important bar appears at the top
    sorted_idx = np.argsort(importance)
    features = [features[i] for i in sorted_idx]
    importance = [importance[i] for i in sorted_idx]

    plt.figure(figsize=(10, 6))
    plt.barh(range(len(features)), importance)
    plt.yticks(range(len(features)), features)
    plt.title('Feature Importance')
    plt.xlabel('Importance Score')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        logger.info(f"Feature importance plot saved to {save_path}")
    plt.close()
def main():
    # Create visualization directory
    vis_dir = Path(__file__).parent.parent.parent / "visualizations"
    vis_dir.mkdir(parents=True, exist_ok=True)

    # Example training history
    history = {
        'train_loss': [0.5, 0.4, 0.3],
        'val_loss': [0.45, 0.35, 0.25],
        'val_metrics': [
            {'accuracy': 0.8, 'precision': 0.75, 'recall': 0.85, 'f1': 0.8},
            {'accuracy': 0.85, 'precision': 0.8, 'recall': 0.9, 'f1': 0.85},
            {'accuracy': 0.9, 'precision': 0.85, 'recall': 0.95, 'f1': 0.9}
        ]
    }
    plot_training_history(history, save_path=vis_dir / "training_history.png")

    # Example confusion matrix
    y_true = np.array([0, 1, 0, 1, 1, 0])
    y_pred = np.array([0, 1, 0, 0, 1, 0])
    plot_confusion_matrix(y_true, y_pred, save_path=vis_dir / "confusion_matrix.png")

    # Example model comparison
    metrics = {
        'BERT': {'accuracy': 0.85, 'precision': 0.82, 'recall': 0.88, 'f1': 0.85},
        'BiLSTM': {'accuracy': 0.78, 'precision': 0.75, 'recall': 0.81, 'f1': 0.78},
        'Hybrid': {'accuracy': 0.92, 'precision': 0.9, 'recall': 0.94, 'f1': 0.92}
    }
    plot_model_comparison(metrics, save_path=vis_dir / "model_comparison.png")

    # Example feature importance
    feature_importance = {
        'BERT': 0.4,
        'BiLSTM': 0.3,
        'Attention': 0.2,
        'TF-IDF': 0.1
    }
    plot_feature_importance(feature_importance, save_path=vis_dir / "feature_importance.png")
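    # Example attention weights (illustrative values only: one weight per
    # whitespace token, chosen here to sum to 1). plot_attention_weights is
    # not otherwise exercised above.
    example_text = "the model attends to informative tokens"
    example_weights = np.array([0.05, 0.3, 0.25, 0.1, 0.2, 0.1])
    plot_attention_weights(example_text, example_weights, save_path=vis_dir / "attention_weights.png")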
if __name__ == "__main__":
    main()