# voting.py

import numpy as np
from collections import defaultdict
from sklearn.metrics import classification_report
import os
import pickle

from config import LABEL_COLUMNS, PREDICTIONS_SAVE_DIR

def save_predictions(model_name, all_probs, true_labels):
    """
    Saves the prediction probabilities and true labels for each target field
    from a specific model. This data is then used by the voting ensemble.

    Args:
        model_name (str): Unique identifier for the model (e.g., "BERT", "TF_IDF_LR").
        all_probs (list): A list where each element is a NumPy array of probabilities
                          for a corresponding label column (shape: num_samples, num_classes).
        true_labels (list): A list where each element is a NumPy array of true labels
                            for a corresponding label column (shape: num_samples,).
    """
    model_preds_dir = os.path.join(PREDICTIONS_SAVE_DIR, model_name)
    os.makedirs(model_preds_dir, exist_ok=True) # Ensure the model-specific directory exists

    for i, col in enumerate(LABEL_COLUMNS):
        # Define file paths for probabilities and true labels for the current field
        prob_file = os.path.join(model_preds_dir, f"{col}_probs.pkl")
        true_file = os.path.join(model_preds_dir, f"{col}_true.pkl")

        # Save the probability array and the true-label array for this field
        with open(prob_file, 'wb') as f:
            pickle.dump(all_probs[i], f)
        with open(true_file, 'wb') as f:
            pickle.dump(true_labels[i], f)
    print(f"Predictions for {model_name} saved to {model_preds_dir}")

def load_predictions(model_name):
    """
    Loads saved prediction probabilities and true labels for a given model.

    Args:
        model_name (str): Unique identifier for the model.

    Returns:
        tuple: A tuple containing:
            - all_probs (list): List of NumPy arrays of probabilities for each label column.
            - true_labels (list): List of NumPy arrays of true labels for each label column.
            Returns (None, None) if any of the expected files are missing.
    """
    model_preds_dir = os.path.join(PREDICTIONS_SAVE_DIR, model_name)
    all_probs = [[] for _ in range(len(LABEL_COLUMNS))]
    true_labels = [[] for _ in range(len(LABEL_COLUMNS))]

    found_all_files = True
    for i, col in enumerate(LABEL_COLUMNS):
        prob_file = os.path.join(model_preds_dir, f"{col}_probs.pkl")
        true_file = os.path.join(model_preds_dir, f"{col}_true.pkl")
        if os.path.exists(prob_file) and os.path.exists(true_file):
            with open(prob_file, 'rb') as f:
                all_probs[i] = pickle.load(f)
            with open(true_file, 'rb') as f:
                true_labels[i] = pickle.load(f)
        else:
            print(f"Warning: Prediction files not found for {model_name} - {col}. This model might be excluded for this label in ensemble.")
            found_all_files = False # Mark that not all files were found

    if not found_all_files:
        return None, None # Indicate that this model's predictions couldn't be fully loaded

    # Defensively convert each entry to a NumPy array so the per-model
    # probability arrays stack cleanly in the ensemble step.
    all_probs = [np.array(p) for p in all_probs]
    true_labels = [np.array(t) for t in true_labels]

    return all_probs, true_labels
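
# Usage sketch (illustrative; "BERT" is just the example name from the docstring):
#     probs, trues = load_predictions("BERT")
#     if probs is not None:
#         assert len(probs) == len(LABEL_COLUMNS)
#         assert probs[0].shape[0] == trues[0].shape[0]  # same number of samples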

def perform_voting_ensemble(model_names_to_ensemble):
    """
    Performs a soft voting ensemble for each label across a list of specified
    models by summing their predicted probabilities (equivalent to averaging,
    since only the argmax is taken).

    Args:
        model_names_to_ensemble (list): A list of string names of the models
                                        whose predictions should be ensembled.
                                        These names should match the directory names
                                        under `PREDICTIONS_SAVE_DIR`.

    Returns:
        tuple: A tuple containing:
            - ensemble_reports (dict): Classification reports for the ensemble predictions.
            - all_true_labels_for_ensemble (list): List of true labels used for evaluation.
            - ensemble_predictions (list): List of predicted class indices from the ensemble.
    """
    print("\n--- Performing Voting Ensemble ---")
    # For each label column, collect one probability array per contributing model
    all_models_probs = defaultdict(list)
    # True labels for evaluation, stored once per label (they should be identical across models)
    all_true_labels_for_ensemble = [None for _ in range(len(LABEL_COLUMNS))]

    # Load probabilities from all specified models
    for model_name in model_names_to_ensemble:
        print(f"Loading predictions for {model_name}...")
        probs_per_label, true_labels_per_label = load_predictions(model_name)

        if probs_per_label is None: # Skip this model if loading failed
            continue

        for i, col in enumerate(LABEL_COLUMNS):
            if len(probs_per_label[i]) > 0: # Ensure probabilities were actually loaded for this label
                all_models_probs[col].append(probs_per_label[i])
                if all_true_labels_for_ensemble[i] is None: # Store true labels only once (they should be identical)
                    all_true_labels_for_ensemble[i] = true_labels_per_label[i]

    ensemble_predictions = [[] for _ in range(len(LABEL_COLUMNS))]
    ensemble_reports = {}

    for i, col in enumerate(LABEL_COLUMNS):
        if not all_models_probs[col]: # If no models provided predictions for this label
            print(f"No valid predictions available for {col} to ensemble. Skipping.")
            ensemble_reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
            continue

        # Stack probabilities for the current label from all models that had them.
        # `stacked_probs` will have shape: (num_contributing_models, num_samples, num_classes)
        stacked_probs = np.stack(all_models_probs[col], axis=0)

        # Perform soft voting by summing probabilities across models
        # (for the final argmax this is equivalent to averaging).
        # `summed_probs` will have shape: (num_samples, num_classes)
        summed_probs = np.sum(stacked_probs, axis=0)

        # Get the final predicted class by taking the argmax of the summed probabilities.
        final_preds = np.argmax(summed_probs, axis=1) # (num_samples,)

        ensemble_predictions[i] = final_preds.tolist()

        # Evaluate ensemble predictions
        y_true_ensemble = all_true_labels_for_ensemble[i]
        if y_true_ensemble is not None: # Ensure true labels are available
            try:
                report = classification_report(y_true_ensemble, final_preds, output_dict=True, zero_division=0)
                ensemble_reports[col] = report
            except ValueError:
                print(f"Warning: Could not generate ensemble classification report for {col}. Skipping.")
                ensemble_reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}
        else:
            print(f"Warning: True labels not found for {col}, cannot evaluate ensemble.")
            ensemble_reports[col] = {'accuracy': 0, 'weighted avg': {'precision': 0, 'recall': 0, 'f1-score': 0, 'support': 0}}


    return ensemble_reports, all_true_labels_for_ensemble, ensemble_predictions
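

# Minimal end-to-end smoke test (illustrative sketch, not part of the original
# pipeline): the model names "model_a" / "model_b" and the random data below are
# placeholders. It assumes `config.LABEL_COLUMNS` and `config.PREDICTIONS_SAVE_DIR`
# are defined as used throughout this module.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    num_samples, num_classes = 20, 3

    # Shared ground truth: one true-label array per label column
    demo_true = [rng.integers(0, num_classes, size=num_samples) for _ in LABEL_COLUMNS]

    for demo_model in ["model_a", "model_b"]:
        # One (num_samples, num_classes) probability array per label column
        demo_probs = [rng.random((num_samples, num_classes)) for _ in LABEL_COLUMNS]
        save_predictions(demo_model, demo_probs, demo_true)

    reports, _, _ = perform_voting_ensemble(["model_a", "model_b"])
    for col in LABEL_COLUMNS:
        print(f"{col}: ensemble accuracy = {reports[col].get('accuracy')}")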