# voting.py
import os
import pickle
from collections import defaultdict

import numpy as np
import pandas as pd
from sklearn.metrics import classification_report

from config import LABEL_COLUMNS, PREDICTIONS_SAVE_DIR


def save_predictions(model_name, all_probs, true_labels):
    """Save per-label prediction probabilities and true labels for one model.

    One pair of pickle files is written per entry of ``LABEL_COLUMNS``, under
    ``PREDICTIONS_SAVE_DIR/<model_name>/``: ``<label>_probs.pkl`` and
    ``<label>_true.pkl``. ``load_predictions`` reads them back for the
    voting ensemble.

    Args:
        model_name (str): Unique identifier for the model (e.g., "BERT",
            "TF_IDF_LR"). Used as the sub-directory name.
        all_probs (list): One probability array per label column, in
            ``LABEL_COLUMNS`` order (expected shape: num_samples, num_classes).
        true_labels (list): One true-label array per label column, in
            ``LABEL_COLUMNS`` order (expected shape: num_samples,).
    """
    model_preds_dir = os.path.join(PREDICTIONS_SAVE_DIR, model_name)
    os.makedirs(model_preds_dir, exist_ok=True)  # Ensure the model-specific directory exists

    for i, col in enumerate(LABEL_COLUMNS):
        prob_file = os.path.join(model_preds_dir, f"{col}_probs.pkl")
        true_file = os.path.join(model_preds_dir, f"{col}_true.pkl")
        with open(prob_file, 'wb') as f:
            pickle.dump(all_probs[i], f)
        with open(true_file, 'wb') as f:
            pickle.dump(true_labels[i], f)

    print(f"Predictions for {model_name} saved to {model_preds_dir}")


def load_predictions(model_name):
    """Load the prediction probabilities and true labels saved for one model.

    Args:
        model_name (str): Unique identifier for the model (the sub-directory
            name under ``PREDICTIONS_SAVE_DIR``).

    Returns:
        tuple: ``(all_probs, true_labels)``, each a list with one NumPy array
        per entry of ``LABEL_COLUMNS``. If ANY label's probability or
        true-label file is missing, ``(None, None)`` is returned and the
        model as a whole should be left out of the ensemble.
    """
    model_preds_dir = os.path.join(PREDICTIONS_SAVE_DIR, model_name)
    all_probs = [[] for _ in LABEL_COLUMNS]
    true_labels = [[] for _ in LABEL_COLUMNS]
    found_all_files = True

    for i, col in enumerate(LABEL_COLUMNS):
        prob_file = os.path.join(model_preds_dir, f"{col}_probs.pkl")
        true_file = os.path.join(model_preds_dir, f"{col}_true.pkl")
        if not (os.path.exists(prob_file) and os.path.exists(true_file)):
            print(f"Warning: Prediction files not found for {model_name} - {col}. "
                  f"This model will be excluded from the ensemble.")
            found_all_files = False
            # Keep scanning so every missing file is reported, not just the first.
            continue
        # NOTE: pickle.load can execute arbitrary code; only load files that
        # were written by save_predictions (i.e. trusted, locally produced).
        with open(prob_file, 'rb') as f:
            all_probs[i] = pickle.load(f)
        with open(true_file, 'rb') as f:
            true_labels[i] = pickle.load(f)

    if not found_all_files:
        return None, None

    # Normalise to NumPy arrays so the ensemble can compare shapes and stack them.
    all_probs = [np.asarray(p) for p in all_probs]
    true_labels = [np.asarray(t) for t in true_labels]
    return all_probs, true_labels


def _empty_report():
    """Return a fresh zero-valued report used when a label cannot be evaluated."""
    return {'accuracy': 0,
            'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}


def _alignment_problem(probs, labels, ref_probs, ref_labels):
    """Explain why a model's outputs cannot be voted with the reference, else None."""
    if probs.shape != ref_probs.shape:
        return f"probability shape {probs.shape} does not match {ref_probs.shape}"
    if not np.array_equal(labels, ref_labels):
        return "true labels differ from those of the first contributing model"
    return None


def perform_voting_ensemble(model_names_to_ensemble):
    """Soft-vote across models for each label and evaluate the result.

    For every label, the probability arrays of all contributing models are
    summed and the argmax is taken. This gives the same class as averaging,
    because every sample is divided by the same model count. The first model
    that contributes to a label sets the reference true labels and array
    shape. A later model whose probabilities have a different shape, or whose
    true labels differ, is skipped for that label with a warning. This
    replaces a crash, or voting over misaligned samples.

    Args:
        model_names_to_ensemble (list): Names of the models to ensemble. They
            must match the directory names under ``PREDICTIONS_SAVE_DIR``.

    Returns:
        tuple: A tuple containing:
            - ensemble_reports (dict): Per-label ``classification_report``
              dicts. A zero-valued placeholder is used when a label cannot
              be evaluated.
            - all_true_labels_for_ensemble (list): Reference true labels per
              label (``None`` where no model contributed).
            - ensemble_predictions (list): Predicted class indices per label
              (an empty list where no model contributed).
    """
    print("\n--- Performing Voting Ensemble ---")
    all_models_probs = defaultdict(list)  # label -> list of probability arrays
    all_true_labels_for_ensemble = [None for _ in LABEL_COLUMNS]

    for model_name in model_names_to_ensemble:
        print(f"Loading predictions for {model_name}...")
        probs_per_label, true_labels_per_label = load_predictions(model_name)
        if probs_per_label is None:  # Skip this model if loading failed
            continue

        for i, col in enumerate(LABEL_COLUMNS):
            probs = probs_per_label[i]
            if len(probs) == 0:  # Nothing was saved for this label
                continue
            labels = true_labels_per_label[i]
            ref_labels = all_true_labels_for_ensemble[i]
            if ref_labels is None:
                # First contributor defines the reference labels and shape.
                all_true_labels_for_ensemble[i] = labels
                all_models_probs[col].append(probs)
                continue
            problem = _alignment_problem(probs, labels, all_models_probs[col][0], ref_labels)
            if problem is not None:
                print(f"Warning: Skipping {model_name} for {col} in ensemble: {problem}.")
                continue
            all_models_probs[col].append(probs)

    ensemble_predictions = [[] for _ in LABEL_COLUMNS]
    ensemble_reports = {}

    for i, col in enumerate(LABEL_COLUMNS):
        model_probs = all_models_probs[col]
        if not model_probs:  # If no models provided predictions for this label
            print(f"No valid predictions available for {col} to ensemble. Skipping.")
            ensemble_reports[col] = _empty_report()
            continue

        # (num_models, num_samples, num_classes) summed over models
        # -> (num_samples, num_classes); all shapes were checked above.
        summed_probs = np.sum(np.stack(model_probs, axis=0), axis=0)
        final_preds = np.argmax(summed_probs, axis=1)  # (num_samples,)
        ensemble_predictions[i] = final_preds.tolist()

        y_true_ensemble = all_true_labels_for_ensemble[i]
        if y_true_ensemble is None:
            print(f"Warning: True labels not found for {col}, cannot evaluate ensemble.")
            ensemble_reports[col] = _empty_report()
            continue
        try:
            ensemble_reports[col] = classification_report(
                y_true_ensemble, final_preds, output_dict=True, zero_division=0)
        except ValueError:
            print(f"Warning: Could not generate ensemble classification report for {col}. Skipping.")
            ensemble_reports[col] = _empty_report()

    return ensemble_reports, all_true_labels_for_ensemble, ensemble_predictions