# models/model_loader.py
"""Model loading utilities: task-to-model mapping, pipeline caching, and fallbacks."""

from functools import lru_cache

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from .logging_config import logger

# Small models are preferred throughout to keep memory usage low.
MODEL_MAPPING = {
    "zero-shot-classification": {
        "primary": "distilbert-base-uncased",    # Much smaller than BART
        "fallback": "microsoft/DialoGPT-small",  # Very small
        "local_fallback": "distilbert-base-uncased",
    },
    "summarization": {
        "primary": "sshleifer/distilbart-cnn-6-6",  # Already small
        "fallback": "t5-small",                     # Very small
        "local_fallback": "t5-small",
    },
    "text-classification": {
        "primary": "distilbert-base-uncased",  # Already small
        "fallback": "distilbert-base-uncased",
        "local_fallback": "distilbert-base-uncased",
    },
    # Use a much smaller model for text generation
    "text-generation": {
        "primary": "distilgpt2",  # Much smaller than TinyLlama
        "fallback": "gpt2",       # Small fallback
    },
}

_model_cache = {}


@lru_cache(maxsize=2)
def load_model(task, model_name=None):
    """Load a pipeline for `task`, falling back to smaller models or a static stub."""
    try:
        if task == "text-generation":
            model_name = "distilgpt2"  # Use distilgpt2 instead of TinyLlama
        elif model_name is None or model_name in MODEL_MAPPING.get(task, {}):
            model_config = MODEL_MAPPING.get(task, {})
            if model_name is None:
                model_name = model_config.get("primary", "distilbert-base-uncased")

        cache_key = f"{task}_{model_name}"
        if cache_key in _model_cache:
            logger.info(f"Using cached model: {model_name} for task: {task}")
            return _model_cache[cache_key]

        logger.info(f"Loading model: {model_name} for task: {task}")

        # CPU only, with conservative length limits per task.
        model_kwargs = {"device": -1, "truncation": True}
        if task == "zero-shot-classification":
            model_kwargs.update({"max_length": 256, "truncation": True})  # Reduced max_length
        elif task == "summarization":
            model_kwargs.update({
                "max_length": 100,
                "min_length": 20,
                "do_sample": False,
                "num_beams": 1,
                "truncation": True,
            })  # Reduced lengths
        elif task == "text-generation":
            model_kwargs.update({
                "max_length": 256,
                "do_sample": True,
                "temperature": 0.7,
                "top_p": 0.9,
                "repetition_penalty": 1.1,
                "truncation": True,
            })  # Reduced max_length

        try:
            if task == "text-generation":
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForCausalLM.from_pretrained(model_name)
                # GPT-2 style models have no pad token; fall back to the EOS token.
                pad_token_id = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=-1,
                    pad_token_id=pad_token_id,
                    truncation=True,
                )
                pipe.fallback_used = False
                _model_cache[cache_key] = pipe
                logger.info(f"Successfully loaded text-generation model: {model_name}")
                return pipe
            else:
                model = pipeline(task, model=model_name, **model_kwargs)
                model.fallback_used = False
                _model_cache[cache_key] = model
                logger.info(f"Successfully loaded model: {model_name}")
                return model
        except Exception as e:
            logger.warning(f"Failed to load primary model {model_name} for {task}: {str(e)}")
            # Try fallback and local_fallback
            model_config = MODEL_MAPPING.get(task, {})
            for fallback_key in ["fallback", "local_fallback"]:
                fallback_model = model_config.get(fallback_key)
                if fallback_model and fallback_model != model_name:  # Don't try the same model again
                    try:
                        logger.info(f"Trying fallback model: {fallback_model} for {task}")
                        model = pipeline(task, model=fallback_model, device=-1, truncation=True)
                        model.fallback_used = True
                        model.fallback_model = fallback_model
                        _model_cache[f"{task}_{fallback_model}"] = model
                        logger.info(f"Loaded fallback model: {fallback_model} for {task}")
                        return model
                    except Exception as e2:
                        logger.warning(f"Failed to load fallback model {fallback_model} for {task}: {str(e2)}")

            logger.error(f"All model loading failed for {task}, using static fallback.")
            return create_text_fallback(task)
    except Exception as e:
        logger.error(f"Error in load_model: {str(e)}")
        return create_text_fallback(task)


def create_text_fallback(task):
    """Return a lightweight stand-in object that mimics the pipeline call interface."""

    class TextFallback:
        def __init__(self, task_type):
            self.task_type = task_type
            self.fallback_used = True
            self.fallback_model = "static_fallback"

        def __call__(self, text, *args, **kwargs):
            if self.task_type == "text-generation":
                return [{"generated_text": "Summary unavailable: unable to load a text-generation model. Please check system memory or model availability."}]
            elif self.task_type == "zero-shot-classification":
                # Naive keyword match: labels that appear in the text score higher.
                text_lower = text.lower()
                labels = args[0] if args else ["positive", "negative"]
                scores = [0.8 if label.lower() in text_lower else 0.2 for label in labels]
                return {"labels": labels, "scores": scores}
            elif self.task_type == "summarization":
                # Crude extractive summary: first two sentences, or a 200-character prefix.
                sentences = text.split('.')
                if len(sentences) > 3:
                    summary = '. '.join(sentences[:2]) + '.'
                else:
                    summary = text[:200] + ('...' if len(text) > 200 else '')
                return [{"summary_text": summary}]
            else:
                return {"result": "Model unavailable, using fallback"}

    return TextFallback(task)


def clear_model_cache():
    global _model_cache
    _model_cache.clear()
    # Also clear the lru_cache wrapper so stale pipelines are not served back.
    load_model.cache_clear()
    logger.info("Model cache cleared")


def get_available_models():
    return MODEL_MAPPING