""" | |
File managing the loading of models for text analysis and multiple-choice tasks. | |
""" | |
import torch | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForMultipleChoice | |
from utils import load_device | |
# Device configuration | |
device = load_device() | |
# Text analysis models | |
MODELS = { | |
"Aubins/distil-bumble-bert": "Aubins/distil-bumble-bert", | |
} | |
id2label = {0: "BIASED", 1: "NEUTRAL"} | |
label2id = {"BIASED": 0, "NEUTRAL": 1} | |
loaded_models = {} | |
# BBQ multiple-choice models | |
BBQ_MODEL = "euler03/bbq-distil_bumble_bert" | |
bbq_model = None | |
bbq_tokenizer = None | |
def load_model(model_name: str):
    """
    Load and cache a sequence classification model for text objectivity analysis.

    Args:
        model_name (str): Name of the model to load.

    Returns:
        tuple: (model, tokenizer) for sequence classification.
    """
    if model_name not in MODELS:
        raise ValueError(
            f"Model '{model_name}' is not recognized. "
            f"Available models: {list(MODELS.keys())}"
        )
    if model_name in loaded_models:
        return loaded_models[model_name]

    try:
        model_path = MODELS[model_name]
        print(f"[Checkpoint] Loading model {model_name} from {model_path}...")
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            num_labels=2,
            id2label=id2label,
            label2id=label2id,
        ).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        loaded_models[model_name] = (model, tokenizer)
        print(f"[Checkpoint] Model {model_name} loaded successfully.")
        return model, tokenizer
    except OSError as e:
        error = f"Error accessing model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e
    except torch.cuda.OutOfMemoryError as e:
        error = f"GPU out-of-memory error while loading model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e
    except Exception as e:
        error = f"Error loading model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e
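

# A small convenience sketch, not part of the original module API: it shows how
# a model returned by load_model would typically be applied to a piece of text.
# The default model name simply reuses the entry registered in MODELS above.
def classify_text(text: str, model_name: str = "Aubins/distil-bumble-bert") -> str:
    """Classify a single text as BIASED or NEUTRAL using a cached model."""
    model, tokenizer = load_model(model_name)
    inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    return id2label[logits.argmax(dim=-1).item()]
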
def load_mc_model(model_name: str = BBQ_MODEL):
    """
    Load a multiple-choice model for scenario assessment.

    Args:
        model_name (str): Name or path of the model to load.

    Returns:
        tuple: (model, tokenizer) for multiple-choice tasks.
    """
    if not model_name or not isinstance(model_name, str):
        raise ValueError(
            f"Invalid model name: expected a non-empty string "
            f"but got {type(model_name).__name__}"
        )

    try:
        print(f"[Checkpoint] Loading model {model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForMultipleChoice.from_pretrained(model_name).to(device)
        print(f"[Checkpoint] Model {model_name} loaded successfully.")
        return model, tokenizer
    except OSError as e:
        error = f"Error accessing model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e
    except torch.cuda.OutOfMemoryError as e:
        error = f"GPU out-of-memory error while loading model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e
    except Exception as e:
        error = f"Error loading model {model_name}: {str(e)}"
        print(error)
        raise RuntimeError(error) from e

# Initialize BBQ model at import time
bbq_model, bbq_tokenizer = load_mc_model(BBQ_MODEL)
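

# A minimal end-to-end sketch for local testing. The example strings below are
# illustrative placeholders, and running this assumes the checkpoints above are
# reachable (e.g. on the Hugging Face Hub) from this machine.
if __name__ == "__main__":
    # Sequence classification: label a sentence as BIASED or NEUTRAL.
    print(classify_text("Example sentence to analyze."))

    # Multiple choice: pair one context with each candidate answer and pick
    # the highest-scoring choice. AutoModelForMultipleChoice expects inputs
    # of shape (batch_size, num_choices, seq_len), hence the unsqueeze below.
    context = "Example scenario text."
    choices = ["Example answer A", "Example answer B"]
    encoding = bbq_tokenizer(
        [context] * len(choices),
        choices,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    encoding = {k: v.unsqueeze(0).to(device) for k, v in encoding.items()}
    with torch.no_grad():
        mc_logits = bbq_model(**encoding).logits
    print(choices[mc_logits.argmax(dim=-1).item()])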