"""Model loading for text objectivity analysis and multiple-choice tasks.

Exposes two loaders:

* ``load_model`` - sequence-classification models listed in ``MODELS``;
  loaded (model, tokenizer) pairs are cached in ``loaded_models``.
* ``load_mc_model`` - multiple-choice models; nothing is cached.

Importing this module loads the default BBQ multiple-choice model into
``bbq_model`` / ``bbq_tokenizer``.
"""

import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoModelForMultipleChoice,
)

from utils import load_device

# Device every model loaded by this module is moved to.
device = load_device()

# Text analysis models: accepted model name -> path handed to from_pretrained.
MODELS = {
    "Aubins/distil-bumble-bert": "Aubins/distil-bumble-bert",
}
# Label mappings passed to the sequence-classification model configuration.
id2label = {0: "BIASED", 1: "NEUTRAL"}
label2id = {"BIASED": 0, "NEUTRAL": 1}
# Cache of already loaded text analysis models: name -> (model, tokenizer).
loaded_models = {}

# BBQ multiple-choice model; the two globals are filled at the end of the module.
BBQ_MODEL = "euler03/bbq-distil_bumble_bert"
bbq_model = None
bbq_tokenizer = None


def _failure_message(model_name, exc):
    """Build the message reported when loading ``model_name`` failed with ``exc``.

    The checks run from most to least specific: ``OSError`` first, then
    ``torch.cuda.OutOfMemoryError``, then any other exception.
    """
    if isinstance(exc, OSError):
        prefix = "Error accessing model"
    elif isinstance(exc, torch.cuda.OutOfMemoryError):
        prefix = "Out of GPU memory error loading model"
    else:
        prefix = "Error loading model"
    return f"{prefix} {model_name}: {exc!s}"


def load_model(model_name: str):
    """
    Return the sequence classification model and tokenizer for ``model_name``.

    The pair is loaded on first request, moved to ``device`` and stored in
    ``loaded_models``; later requests for the same name return the stored pair.

    Args:
        model_name (str) : Name of the model to load; must be a key of ``MODELS``.

    Returns:
        tuple (model, tokenizer) : Loaded model and tokenizer.

    Raises:
        ValueError : If ``model_name`` is not a key of ``MODELS``.
        RuntimeError : If anything fails while loading; the original
            exception is attached as the cause.
    """
    if model_name not in MODELS:
        raise ValueError(f"Model '{model_name}' is not recognized. Available models: {list(MODELS.keys())}")

    # Serve from the cache when this model has been loaded before.
    try:
        return loaded_models[model_name]
    except KeyError:
        pass

    try:
        checkpoint = MODELS[model_name]
        print(f"[Checkpoint] Loading model {model_name} from {checkpoint}...")

        classifier = AutoModelForSequenceClassification.from_pretrained(
            checkpoint,
            num_labels=2,
            id2label=id2label,
            label2id=label2id,
        )
        classifier = classifier.to(device)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)

        pair = (classifier, tokenizer)
        loaded_models[model_name] = pair
        print(f"[Checkpoint] Model {model_name} loaded successfully.")
        return pair
    except Exception as exc:
        # Every failure is printed and re-raised as RuntimeError, keeping the cause.
        message = _failure_message(model_name, exc)
        print(message)
        raise RuntimeError(message) from exc


def load_mc_model(model_name: str = BBQ_MODEL):
    """
    Load a multiple-choice model and its tokenizer for scenario assessment.

    The tokenizer is loaded first, then the model, which is moved to
    ``device``. Each call loads again; results are not cached.

    Args:
        model_name (str): Name or path of the model to load

    Returns:
        tuple: (model, tokenizer) for multiple choice tasks

    Raises:
        ValueError: If ``model_name`` is empty or not a string.
        RuntimeError: If anything fails while loading; the original
            exception is attached as the cause.
    """
    if not (model_name and isinstance(model_name, str)):
        raise ValueError(f"Invalid model name: expected a non-empty string but got {type(model_name).__name__}")

    try:
        print(f"[Checkpoint] Loading model {model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        mc_model = AutoModelForMultipleChoice.from_pretrained(model_name)
        mc_model = mc_model.to(device)
        print(f"[Checkpoint] Model {model_name} loaded successfully.")
        return mc_model, tokenizer
    except Exception as exc:
        # Every failure is printed and re-raised as RuntimeError, keeping the cause.
        message = _failure_message(model_name, exc)
        print(message)
        raise RuntimeError(message) from exc


# Initialize BBQ model
bbq_model, bbq_tokenizer = load_mc_model(BBQ_MODEL)