voting-ensemble / voting.py
namanpenguin's picture
Upload 8 files
b44c3a2 verified
# voting.py
import numpy as np
import pandas as pd
from collections import defaultdict
from sklearn.metrics import classification_report
import os
import pickle
from config import LABEL_COLUMNS, PREDICTIONS_SAVE_DIR
def save_predictions(model_name, all_probs, true_labels):
    """
    Saves the prediction probabilities and true labels for each target field
    from a specific model. This data is then used by the voting ensemble.

    For every column in LABEL_COLUMNS, two pickle files are written under
    PREDICTIONS_SAVE_DIR/<model_name>/: "<col>_probs.pkl" holding that
    column's probability array and "<col>_true.pkl" holding its true labels.

    Args:
        model_name (str): Unique identifier for the model (e.g., "BERT", "TF_IDF_LR").
        all_probs (list): A list where each element is a NumPy array of probabilities
            for a corresponding label column (shape: num_samples, num_classes).
        true_labels (list): A list where each element is a NumPy array of true labels
            for a corresponding label column (shape: num_samples,).

    Raises:
        ValueError: If all_probs or true_labels has fewer entries than
            LABEL_COLUMNS. The check happens before anything is written, so a
            bad call cannot leave a mix of freshly written and stale files in
            the model's directory.
    """
    n_labels = len(LABEL_COLUMNS)
    if len(all_probs) < n_labels or len(true_labels) < n_labels:
        raise ValueError(
            f"Expected at least {n_labels} entries (one per label column) in "
            f"all_probs and true_labels, got {len(all_probs)} and {len(true_labels)}."
        )

    model_preds_dir = os.path.join(PREDICTIONS_SAVE_DIR, model_name)
    os.makedirs(model_preds_dir, exist_ok=True)  # Ensure the model-specific directory exists

    # zip stops at LABEL_COLUMNS; the length check above guarantees the other
    # two sequences are at least that long.
    for col, probs, labels in zip(LABEL_COLUMNS, all_probs, true_labels):
        prob_file = os.path.join(model_preds_dir, f"{col}_probs.pkl")
        true_file = os.path.join(model_preds_dir, f"{col}_true.pkl")
        # Each file holds a single object: this column's probability array
        # or its true-label array.
        with open(prob_file, 'wb') as f:
            pickle.dump(probs, f)
        with open(true_file, 'wb') as f:
            pickle.dump(labels, f)
    print(f"Predictions for {model_name} saved to {model_preds_dir}")
def load_predictions(model_name):
    """
    Loads saved prediction probabilities and true labels for a given model.

    Reads the "<col>_probs.pkl" / "<col>_true.pkl" pairs under
    PREDICTIONS_SAVE_DIR/<model_name>/ for every column in LABEL_COLUMNS.
    Loading is all-or-nothing: if either file is missing for any column, a
    warning is printed for each missing column and (None, None) is returned.

    Security note: pickle.load can execute arbitrary code, so these files must
    only ever be ones written by this project (e.g. by save_predictions), not
    files obtained from an untrusted source.

    Args:
        model_name (str): Unique identifier for the model.

    Returns:
        tuple: A tuple containing:
            - all_probs (list): List of NumPy arrays of probabilities for each label column.
            - true_labels (list): List of NumPy arrays of true labels for each label column.
        Returns (None, None) if any file is not found.
    """
    model_preds_dir = os.path.join(PREDICTIONS_SAVE_DIR, model_name)
    all_probs = []
    true_labels = []
    found_all_files = True
    for col in LABEL_COLUMNS:
        prob_file = os.path.join(model_preds_dir, f"{col}_probs.pkl")
        true_file = os.path.join(model_preds_dir, f"{col}_true.pkl")
        if not (os.path.exists(prob_file) and os.path.exists(true_file)):
            # Keep scanning so every missing column is reported, not just the first.
            print(f"Warning: Prediction files not found for {model_name} - {col}. This model will be excluded from the ensemble.")
            found_all_files = False
            continue
        if not found_all_files:
            continue  # The result will be discarded anyway; skip unpickling.
        # np.asarray gives a consistent ndarray type for stacking later without
        # copying objects that are already arrays.
        with open(prob_file, 'rb') as f:
            all_probs.append(np.asarray(pickle.load(f)))
        with open(true_file, 'rb') as f:
            true_labels.append(np.asarray(pickle.load(f)))
    if not found_all_files:
        return None, None  # This model's predictions couldn't be fully loaded
    return all_probs, true_labels
def _empty_report():
    """Return a fresh all-zero report used when a label cannot be evaluated.

    A new dict is built on every call so reports for different labels never
    share (and accidentally co-mutate) the same nested dict.
    """
    return {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}


def perform_voting_ensemble(model_names_to_ensemble):
    """
    Performs a soft voting ensemble for each label across a list of
    specified models.

    For each label, the probability arrays of all contributing models are
    summed and the argmax class is taken. Summing has the same argmax as
    averaging, so this is equivalent to averaging probabilities. A model's
    contribution to a label is excluded, with a warning, if its probability
    array is not 2-D or its shape differs from the first model accepted for
    that label. Previously such a mismatch crashed the whole run inside
    np.stack or np.argmax.

    Args:
        model_names_to_ensemble (list): A list of string names of the models
            whose predictions should be ensembled. These names should match
            the directory names under `PREDICTIONS_SAVE_DIR`.

    Returns:
        tuple: A tuple containing:
            - ensemble_reports (dict): Classification reports for the ensemble
              predictions, keyed by label column. An all-zero placeholder is
              used when a label cannot be evaluated.
            - all_true_labels_for_ensemble (list): List of true labels used for
              evaluation. Each entry is taken from the first model accepted for
              that label, or None if no model was accepted.
            - ensemble_predictions (list): List of predicted class indices from
              the ensemble. Each entry is a Python list, or [] if the label was
              skipped.
    """
    print("\n--- Performing Voting Ensemble ---")
    # Probability arrays per label column, one entry per accepted model.
    all_models_probs = defaultdict(list)
    # True labels per label column; taken once from the first accepted model.
    all_true_labels_for_ensemble = [None for _ in range(len(LABEL_COLUMNS))]

    # Load probabilities from all specified models
    for model_name in model_names_to_ensemble:
        print(f"Loading predictions for {model_name}...")
        probs_per_label, true_labels_per_label = load_predictions(model_name)
        if probs_per_label is None:  # Skip this model if loading failed
            continue
        for i, col in enumerate(LABEL_COLUMNS):
            probs = np.asarray(probs_per_label[i])
            if probs.size == 0:  # Nothing was saved for this label
                continue
            if probs.ndim != 2:
                print(f"Warning: {model_name} probabilities for {col} have shape {probs.shape}; "
                      f"expected (num_samples, num_classes). Excluding it for this label.")
                continue
            accepted = all_models_probs[col]
            # Every model stacked for a label must agree on (num_samples, num_classes).
            if accepted and probs.shape != accepted[0].shape:
                print(f"Warning: {model_name} probabilities for {col} have shape {probs.shape}, "
                      f"but earlier models have {accepted[0].shape}. Excluding it for this label.")
                continue
            accepted.append(probs)
            if all_true_labels_for_ensemble[i] is None:  # Store true labels only once
                all_true_labels_for_ensemble[i] = true_labels_per_label[i]

    ensemble_predictions = [[] for _ in range(len(LABEL_COLUMNS))]
    ensemble_reports = {}
    for i, col in enumerate(LABEL_COLUMNS):
        model_probs = all_models_probs[col]
        if not model_probs:  # If no models provided predictions for this label
            print(f"No valid predictions available for {col} to ensemble. Skipping.")
            ensemble_reports[col] = _empty_report()
            continue

        # (num_models, num_samples, num_classes) -> (num_samples, num_classes);
        # shapes were validated above so np.stack cannot fail here.
        summed_probs = np.sum(np.stack(model_probs, axis=0), axis=0)
        final_preds = np.argmax(summed_probs, axis=1)  # (num_samples,)
        ensemble_predictions[i] = final_preds.tolist()

        # Evaluate ensemble predictions
        y_true_ensemble = all_true_labels_for_ensemble[i]
        if y_true_ensemble is None:
            print(f"Warning: True labels not found for {col}, cannot evaluate ensemble.")
            ensemble_reports[col] = _empty_report()
            continue
        try:
            ensemble_reports[col] = classification_report(
                y_true_ensemble, final_preds, output_dict=True, zero_division=0
            )
        except ValueError:
            # e.g. the true labels do not have the same number of samples as the predictions
            print(f"Warning: Could not generate ensemble classification report for {col}. Skipping.")
            ensemble_reports[col] = _empty_report()
    return ensemble_reports, all_true_labels_for_ensemble, ensemble_predictions