# models/model_loader.py
import os
from functools import lru_cache

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

from .logging_config import logger

MODEL_MAPPING = {
    "zero-shot-classification": {
        "primary": "distilbert-base-uncased",  # Much smaller than BART
        "fallback": "microsoft/DialoGPT-small",  # Very small
        "local_fallback": "distilbert-base-uncased"
    },
    "summarization": {
        "primary": "sshleifer/distilbart-cnn-6-6",  # Already small
        "fallback": "t5-small",  # Very small
        "local_fallback": "t5-small"
    },
    "text-classification": {
        "primary": "distilbert-base-uncased",  # Already small
        "fallback": "distilbert-base-uncased",
        "local_fallback": "distilbert-base-uncased"
    },
    # Use a much smaller model for text generation
    "text-generation": {
        "primary": "distilgpt2",  # Much smaller than TinyLlama
        "fallback": "gpt2"  # Small fallback
    }
}

_model_cache = {}


@lru_cache(maxsize=2)
def load_model(task, model_name=None):
    """Load a transformers pipeline for `task`, trying the primary model first,
    then the configured fallbacks, and finally a static stub."""
    try:
        if task == "text-generation":
            model_name = "distilgpt2"  # Use distilgpt2 instead of TinyLlama
        elif model_name is None:
            model_config = MODEL_MAPPING.get(task, {})
            model_name = model_config.get("primary", "distilbert-base-uncased")

        cache_key = f"{task}_{model_name}"
        if cache_key in _model_cache:
            logger.info(f"Using cached model: {model_name} for task: {task}")
            return _model_cache[cache_key]

        logger.info(f"Loading model: {model_name} for task: {task}")
        model_kwargs = {"device": -1, "truncation": True}
        if task == "zero-shot-classification":
            model_kwargs.update({"max_length": 256, "truncation": True})  # Reduced max_length
        elif task == "summarization":
            model_kwargs.update({"max_length": 100, "min_length": 20, "do_sample": False,
                                 "num_beams": 1, "truncation": True})  # Reduced lengths
        elif task == "text-generation":
            model_kwargs.update({"max_length": 256, "do_sample": True, "temperature": 0.7,
                                 "top_p": 0.9, "repetition_penalty": 1.1, "truncation": True})  # Reduced max_length

        try:
            if task == "text-generation":
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForCausalLM.from_pretrained(model_name)
                pad_token_id = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=-1,
                    pad_token_id=pad_token_id,
                    truncation=True
                )
                pipe.fallback_used = False
                _model_cache[cache_key] = pipe
                logger.info(f"Successfully loaded text-generation model: {model_name}")
                return pipe
            else:
                model = pipeline(task, model=model_name, **model_kwargs)
                model.fallback_used = False
                _model_cache[cache_key] = model
                logger.info(f"Successfully loaded model: {model_name}")
                return model
        except Exception as e:
            logger.warning(f"Failed to load primary model {model_name} for {task}: {str(e)}")
            # Try fallback and local_fallback
            model_config = MODEL_MAPPING.get(task, {})
            for fallback_key in ["fallback", "local_fallback"]:
                fallback_model = model_config.get(fallback_key)
                if fallback_model and fallback_model != model_name:  # Don't try the same model again
                    try:
                        logger.info(f"Trying fallback model: {fallback_model} for {task}")
                        model = pipeline(task, model=fallback_model, device=-1, truncation=True)
                        model.fallback_used = True
                        model.fallback_model = fallback_model
                        _model_cache[f"{task}_{fallback_model}"] = model
                        logger.info(f"Loaded fallback model: {fallback_model} for {task}")
                        return model
                    except Exception as e2:
                        logger.warning(f"Failed to load fallback model {fallback_model} for {task}: {str(e2)}")
            logger.error(f"All model loading failed for {task}, using static fallback.")
            return create_text_fallback(task)
    except Exception as e:
        logger.error(f"Error in load_model: {str(e)}")
        return create_text_fallback(task)


def create_text_fallback(task):
    """Return a lightweight stand-in object that mimics a pipeline's call
    interface when no model could be loaded."""
    class TextFallback:
        def __init__(self, task_type):
            self.task_type = task_type
            self.fallback_used = True
            self.fallback_model = "static_fallback"

        def __call__(self, text, *args, **kwargs):
            if self.task_type == "text-generation":
                return [{"generated_text": "Summary unavailable: unable to load the text-generation model. Please check system memory or model availability."}]
            elif self.task_type == "zero-shot-classification":
                text_lower = text.lower()
                # The real pipeline receives labels via `candidate_labels`; accept both forms.
                labels = kwargs.get("candidate_labels") or (args[0] if args else ["positive", "negative"])
                scores = [0.8 if label.lower() in text_lower else 0.2 for label in labels]
                return {"labels": labels, "scores": scores}
            elif self.task_type == "summarization":
                sentences = text.split('.')
                if len(sentences) > 3:
                    summary = '. '.join(sentences[:2]) + '.'
                else:
                    summary = text[:200] + ('...' if len(text) > 200 else '')
                return [{"summary_text": summary}]
            else:
                return {"result": "Model unavailable, using fallback"}

    return TextFallback(task)


def clear_model_cache():
    global _model_cache
    _model_cache.clear()
    load_model.cache_clear()  # Also reset the lru_cache wrapper so stale pipelines are not returned
    logger.info("Model cache cleared")


def get_available_models():
    return MODEL_MAPPING
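

# Illustrative usage sketch: exercises the loader end to end on CPU.
# Assumes this file lives in a `models` package (per the path comment above),
# so it should be run as `python -m models.model_loader` for the relative
# import of `logger` to resolve, and that the hub models configured in
# MODEL_MAPPING can be downloaded or are already cached locally.
if __name__ == "__main__":
    print("Available model mapping:", get_available_models())

    # Load a summarizer; on failure this falls back to t5-small or the static stub.
    summarizer = load_model("summarization")
    print("Fallback used:", getattr(summarizer, "fallback_used", False))
    result = summarizer(
        "Transformers provides thousands of pretrained models that perform tasks "
        "on text, vision, and audio. The library exposes a simple pipeline API."
    )
    print(result[0]["summary_text"])

    # Release cached pipelines when memory is tight.
    clear_model_cache()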