File size: 3,540 Bytes
04ca1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
File managing the loading of models for text analysis and multiple-choice tasks.
"""
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForMultipleChoice
from utils import load_device

# Device configuration — resolved once at import (project helper; presumably
# returns a torch.device such as "cuda" or "cpu" — confirm in utils.load_device).
device = load_device()

# Text analysis models: maps a public model name to its HuggingFace hub path.
# Currently the key and path are identical, but keeping the mapping lets a
# friendly alias point at a different checkpoint later without changing callers.
MODELS = {
    "Aubins/distil-bumble-bert": "Aubins/distil-bumble-bert",
}
# Label mappings for the 2-class objectivity classifier (index <-> string).
id2label = {0: "BIASED", 1: "NEUTRAL"}
label2id = {"BIASED": 0, "NEUTRAL": 1}
# Cache of already-loaded models: model name -> (model, tokenizer).
loaded_models = {}

# BBQ multiple-choice models
# Hub path of the multiple-choice checkpoint loaded at import time (see bottom of file).
BBQ_MODEL = "euler03/bbq-distil_bumble_bert"
# Populated at import by the module-level load_mc_model call below.
bbq_model = None
bbq_tokenizer = None


def load_model(model_name: str):
    """
    Load and cache a sequence classification model for text objectivity analysis.

    Args:
        model_name (str) : Name of the model to load.

    Returns:
        tuple (model, tokenizer) : Loaded model and tokenizer.

    Raises:
        ValueError: If model_name is not one of the keys in MODELS.
        RuntimeError: If the model cannot be fetched, does not fit in GPU
            memory, or fails to load for any other reason.
    """
    # Reject unknown names up front so the error lists the valid choices.
    if model_name not in MODELS:
        raise ValueError(f"Model '{model_name}' is not recognized. Available models: {list(MODELS.keys())}")

    # Serve repeat requests from the module-level cache.
    cached = loaded_models.get(model_name)
    if cached is not None:
        return cached

    model_path = MODELS[model_name]
    try:
        print(f"[Checkpoint] Loading model {model_name} from {model_path}...")

        # Two-class head with explicit label mappings, moved to the configured device.
        classifier = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            num_labels=2,
            id2label=id2label,
            label2id=label2id
        ).to(device)
        tok = AutoTokenizer.from_pretrained(model_path)

        loaded_models[model_name] = (classifier, tok)
        print(f"[Checkpoint] Model {model_name} loaded successfully.")
        return classifier, tok

    # Missing files / network / permission problems surface as OSError.
    except OSError as e:
        error = f"Error accessing model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e

    # The checkpoint is too large for the available GPU memory.
    except torch.cuda.OutOfMemoryError as e:
        error = f"Out of GPU memory error loading model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e

    # Anything else is wrapped uniformly so callers only handle RuntimeError.
    except Exception as e:
        error = f"Error loading model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e


def load_mc_model(model_name: str = BBQ_MODEL):
    """
    Load a multiple-choice model for scenario assessment.

    Results are cached per model name (mirroring load_model), so repeated
    calls return the same (model, tokenizer) pair without reloading.

    Args:
        model_name (str): Name or path of the model to load

    Returns:
        tuple: (model, tokenizer) for multiple choice tasks

    Raises:
        ValueError: If model_name is not a non-empty string.
        RuntimeError: If the model cannot be fetched, does not fit in GPU
            memory, or fails to load for any other reason.
    """
    if not model_name or not isinstance(model_name, str):
        raise ValueError(f"Invalid model name: expected a non-empty string but got {type(model_name).__name__}")

    # Function-attribute cache keeps this block self-contained while giving
    # the same reuse behavior load_model gets from the loaded_models dict.
    cache = getattr(load_mc_model, "_cache", None)
    if cache is None:
        cache = {}
        load_mc_model._cache = cache
    if model_name in cache:
        return cache[model_name]

    try:
        print(f"[Checkpoint] Loading model {model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForMultipleChoice.from_pretrained(model_name).to(device)

        cache[model_name] = (model, tokenizer)
        print(f"[Checkpoint] Model {model_name} loaded successfully.")
        return model, tokenizer

    # Missing files / network / permission problems surface as OSError.
    except OSError as e:
        error = f"Error accessing model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e

    # The checkpoint is too large for the available GPU memory.
    except torch.cuda.OutOfMemoryError as e:
        error = f"Out of GPU memory error loading model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e

    # Anything else is wrapped uniformly so callers only handle RuntimeError.
    except Exception as e:
        error = f"Error loading model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e


# Initialize BBQ model
# NOTE: this runs at import time — importing this module triggers a
# (potentially network-bound) model download/load and raises RuntimeError
# on failure, aborting the import.
bbq_model, bbq_tokenizer = load_mc_model(BBQ_MODEL)