Spaces:
Runtime error
Runtime error
# voting.py
import numpy as np | |
import pandas as pd | |
from collections import defaultdict | |
from sklearn.metrics import classification_report | |
import os | |
import pickle | |
from config import LABEL_COLUMNS, PREDICTIONS_SAVE_DIR | |
def save_predictions(model_name, all_probs, true_labels):
    """
    Persist one model's per-label probabilities and ground-truth labels.

    For every entry of `LABEL_COLUMNS`, two pickle files are written into
    `<PREDICTIONS_SAVE_DIR>/<model_name>/`: `<col>_probs.pkl` and
    `<col>_true.pkl`. The directory is created when it does not exist yet.
    These files are what the voting ensemble later reads back.

    Args:
        model_name (str): Unique identifier for the model (e.g., "BERT",
            "TF_IDF_LR"); also used as the sub-directory name.
        all_probs (list): Indexed in `LABEL_COLUMNS` order; each element is
            expected to be a NumPy array of probabilities for that label
            column (shape: num_samples, num_classes).
        true_labels (list): Indexed in `LABEL_COLUMNS` order; each element is
            expected to be a NumPy array of true labels for that label
            column (shape: num_samples,).
    """
    target_dir = os.path.join(PREDICTIONS_SAVE_DIR, model_name)
    os.makedirs(target_dir, exist_ok=True)

    for idx, label in enumerate(LABEL_COLUMNS):
        # Pair each output file name with the sequence it is taken from; the
        # element is looked up only once the file is open, probabilities first.
        outputs = (
            (f"{label}_probs.pkl", all_probs),
            (f"{label}_true.pkl", true_labels),
        )
        for file_name, per_label_values in outputs:
            with open(os.path.join(target_dir, file_name), 'wb') as handle:
                pickle.dump(per_label_values[idx], handle)

    print(f"Predictions for {model_name} saved to {target_dir}")
def load_predictions(model_name):
    """
    Load the pickled probabilities and true labels saved for one model.

    Reads `<col>_probs.pkl` and `<col>_true.pkl` from
    `<PREDICTIONS_SAVE_DIR>/<model_name>/` for every entry of `LABEL_COLUMNS`.
    A warning is printed for each label whose files are missing; all labels
    are still visited so every missing pair is reported.

    Args:
        model_name (str): Unique identifier for the model.

    Returns:
        tuple: A tuple containing:
            - all_probs (list): NumPy arrays of probabilities, one per label
              column, in `LABEL_COLUMNS` order.
            - true_labels (list): NumPy arrays of true labels, one per label
              column, in `LABEL_COLUMNS` order.
        Returns (None, None) if the files for any label are not found.
    """
    model_preds_dir = os.path.join(PREDICTIONS_SAVE_DIR, model_name)
    loaded_probs = []
    loaded_true = []
    found_all_files = True

    for col in LABEL_COLUMNS:
        prob_file = os.path.join(model_preds_dir, f"{col}_probs.pkl")
        true_file = os.path.join(model_preds_dir, f"{col}_true.pkl")
        try:
            # Open both files before reading either, so a label with only one
            # of its two files is treated as missing without loading anything.
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point PREDICTIONS_SAVE_DIR at files this project wrote itself.
            with open(prob_file, 'rb') as prob_handle, open(true_file, 'rb') as true_handle:
                probs = pickle.load(prob_handle)
                truths = pickle.load(true_handle)
        except FileNotFoundError:
            # Attempting the open (instead of checking os.path.exists first)
            # avoids a race if a file disappears between check and open.
            print(f"Warning: Prediction files not found for {model_name} - {col}. This model might be excluded for this label in ensemble.")
            found_all_files = False
            continue
        loaded_probs.append(probs)
        loaded_true.append(truths)

    if not found_all_files:
        # Signal that this model's predictions could not be fully loaded.
        return None, None

    # Normalise to NumPy arrays so they can be stacked later; asarray avoids
    # copying data that was already pickled as an ndarray.
    all_probs = [np.asarray(p) for p in loaded_probs]
    true_labels = [np.asarray(t) for t in loaded_true]
    return all_probs, true_labels
def _empty_ensemble_report():
    """Return a fresh zero-filled placeholder report for a label that cannot be evaluated."""
    return {
        'accuracy': 0,
        'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0},
    }


def perform_voting_ensemble(model_names_to_ensemble):
    """
    Performs a soft voting ensemble for each label across the given models.

    For every label column the per-model probability arrays are summed and
    the class with the largest total is chosen. Summing yields the same
    argmax as averaging, because every class total is divided by the same
    number of models.

    Models whose predictions cannot be loaded are skipped. A model whose
    probability array for a label has a different shape than the arrays
    already collected for that label is excluded for that label (with a
    warning) instead of aborting the whole ensemble. A warning is also printed
    when a model's true labels differ from the ones already recorded.

    Args:
        model_names_to_ensemble (list): A list of string names of the models
            whose predictions should be ensembled. These names should match
            the directory names under `PREDICTIONS_SAVE_DIR`.

    Returns:
        tuple: A tuple containing:
            - ensemble_reports (dict): Maps each label column to its
              classification report (as a dict), or to a zero-filled
              placeholder when the label could not be evaluated.
            - all_true_labels_for_ensemble (list): True labels per label
              column used for evaluation (None where none were loaded).
            - ensemble_predictions (list): Predicted class indices per label
              column (an empty list where no model contributed).
    """
    print("\n--- Performing Voting Ensemble ---")

    # Probability arrays collected per label column, one entry per contributing model.
    all_models_probs = defaultdict(list)
    # True labels are taken from the first model that contributes to a label.
    all_true_labels_for_ensemble = [None for _ in LABEL_COLUMNS]

    for model_name in model_names_to_ensemble:
        print(f"Loading predictions for {model_name}...")
        probs_per_label, true_labels_per_label = load_predictions(model_name)
        if probs_per_label is None:  # Loading failed; skip this model entirely
            continue

        for i, col in enumerate(LABEL_COLUMNS):
            if len(probs_per_label[i]) == 0:  # Nothing was loaded for this label
                continue

            probs = np.asarray(probs_per_label[i])
            collected = all_models_probs[col]
            if collected and probs.shape != collected[0].shape:
                # np.stack below requires identical shapes; one bad model must
                # not bring down the whole ensemble.
                print(
                    f"Warning: Probabilities from {model_name} for {col} have shape "
                    f"{probs.shape}, expected {collected[0].shape}. "
                    f"Excluding this model for this label."
                )
                continue
            collected.append(probs)

            if all_true_labels_for_ensemble[i] is None:
                all_true_labels_for_ensemble[i] = true_labels_per_label[i]
            elif not np.array_equal(all_true_labels_for_ensemble[i], true_labels_per_label[i]):
                # The ensemble is scored against the first model's labels; a
                # mismatch suggests the models were evaluated on different data.
                print(
                    f"Warning: True labels from {model_name} for {col} differ from "
                    f"those already loaded; ensemble metrics may be unreliable."
                )

    ensemble_predictions = [[] for _ in LABEL_COLUMNS]
    ensemble_reports = {}

    for i, col in enumerate(LABEL_COLUMNS):
        contributions = all_models_probs[col]
        if not contributions:  # No model provided predictions for this label
            print(f"No valid predictions available for {col} to ensemble. Skipping.")
            ensemble_reports[col] = _empty_ensemble_report()
            continue

        # (num_contributing_models, num_samples, num_classes) -> (num_samples, num_classes)
        stacked_probs = np.stack(contributions, axis=0)
        summed_probs = np.sum(stacked_probs, axis=0)
        # Final class per sample is the one with the highest summed probability.
        final_preds = np.argmax(summed_probs, axis=1)
        ensemble_predictions[i] = final_preds.tolist()

        y_true_ensemble = all_true_labels_for_ensemble[i]
        if y_true_ensemble is None:
            print(f"Warning: True labels not found for {col}, cannot evaluate ensemble.")
            ensemble_reports[col] = _empty_ensemble_report()
            continue

        try:
            ensemble_reports[col] = classification_report(
                y_true_ensemble, final_preds, output_dict=True, zero_division=0
            )
        except ValueError:
            print(f"Warning: Could not generate ensemble classification report for {col}. Skipping.")
            ensemble_reports[col] = _empty_ensemble_report()

    return ensemble_reports, all_true_labels_for_ensemble, ensemble_predictions