# models/model_loader.py

from functools import lru_cache
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from .logging_config import logger
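
# `.logging_config` is a sibling module not shown in this file; the import above assumes
# it exposes a configured `logger`. A minimal sketch of such a module (an assumption,
# not necessarily the project's actual implementation) might look like:
#
#     # models/logging_config.py
#     import logging
#
#     logging.basicConfig(
#         level=logging.INFO,
#         format="%(asctime)s %(levelname)s %(name)s: %(message)s",
#     )
#     logger = logging.getLogger("models")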

MODEL_MAPPING = {
    "zero-shot-classification": {
        "primary": "distilbert-base-uncased",  # Much smaller than BART
        "fallback": "microsoft/DialoGPT-small",  # Very small
        "local_fallback": "distilbert-base-uncased"
    },
    "summarization": {
        "primary": "sshleifer/distilbart-cnn-6-6",  # Already small
        "fallback": "t5-small",  # Very small
        "local_fallback": "t5-small"
    },
    "text-classification": {
        "primary": "distilbert-base-uncased",  # Already small
        "fallback": "distilbert-base-uncased",
        "local_fallback": "distilbert-base-uncased"
    },
    # Use a much smaller model for text generation
    "text-generation": {
        "primary": "distilgpt2",  # Much smaller than TinyLlama
        "fallback": "gpt2"  # Small fallback
    }
}

# Secondary cache keyed by "{task}_{model_name}"; load_model's lru_cache is the first caching layer
_model_cache = {}

@lru_cache(maxsize=2)
def load_model(task, model_name=None):
    """Load a transformers pipeline for the given task, trying progressively smaller fallbacks."""
    try:
        if task == "text-generation":
            model_name = "distilgpt2"  # Use distilgpt2 instead of TinyLlama
        elif model_name is None:
            model_name = MODEL_MAPPING.get(task, {}).get("primary", "distilbert-base-uncased")
        cache_key = f"{task}_{model_name}"
        if cache_key in _model_cache:
            logger.info(f"Using cached model: {model_name} for task: {task}")
            return _model_cache[cache_key]
        logger.info(f"Loading model: {model_name} for task: {task}")
        model_kwargs = {"device": -1, "truncation": True}
        if task == "zero-shot-classification":
            model_kwargs.update({"max_length": 256, "truncation": True})  # Reduced max_length
        elif task == "summarization":
            model_kwargs.update({"max_length": 100, "min_length": 20, "do_sample": False, "num_beams": 1, "truncation": True})  # Reduced lengths
        elif task == "text-generation":
            model_kwargs.update({"max_length": 256, "do_sample": True, "temperature": 0.7, "top_p": 0.9, "repetition_penalty": 1.1, "truncation": True})  # Reduced max_length
        try:
            if task == "text-generation":
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForCausalLM.from_pretrained(model_name)
                pad_token_id = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
                # Apply the generation defaults collected in model_kwargs above
                gen_kwargs = {k: v for k, v in model_kwargs.items() if k not in ("device", "truncation")}
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=-1,
                    pad_token_id=pad_token_id,
                    truncation=True,
                    **gen_kwargs
                )
                pipe.fallback_used = False
                _model_cache[cache_key] = pipe
                logger.info(f"Successfully loaded text-generation model: {model_name}")
                return pipe
            else:
                model = pipeline(task, model=model_name, **model_kwargs)
                model.fallback_used = False
                _model_cache[cache_key] = model
                logger.info(f"Successfully loaded model: {model_name}")
                return model
        except Exception as e:
            logger.warning(f"Failed to load primary model {model_name} for {task}: {str(e)}")
            # Try fallback and local_fallback
            model_config = MODEL_MAPPING.get(task, {})
            for fallback_key in ["fallback", "local_fallback"]:
                fallback_model = model_config.get(fallback_key)
                if fallback_model and fallback_model != model_name:  # Don't try the same model again
                    try:
                        logger.info(f"Trying fallback model: {fallback_model} for {task}")
                        model = pipeline(task, model=fallback_model, device=-1, truncation=True)
                        model.fallback_used = True
                        model.fallback_model = fallback_model
                        _model_cache[f"{task}_{fallback_model}"] = model
                        logger.info(f"Loaded fallback model: {fallback_model} for {task}")
                        return model
                    except Exception as e2:
                        logger.warning(f"Failed to load fallback model {fallback_model} for {task}: {str(e2)}")
            logger.error(f"All model loading failed for {task}, using static fallback.")
            return create_text_fallback(task)
    except Exception as e:
        logger.error(f"Error in load_model: {str(e)}")
        return create_text_fallback(task)

def create_text_fallback(task):
    """Return a static stand-in object that mimics pipeline output when no model can be loaded."""
    class TextFallback:
        def __init__(self, task_type):
            self.task_type = task_type
            self.fallback_used = True
            self.fallback_model = "static_fallback"
        def __call__(self, text, *args, **kwargs):
            if self.task_type == "text-generation":
                return [{"generated_text": "Summary unavailable: Unable to load TinyLlama model. Please check system memory or model availability."}]
            elif self.task_type == "zero-shot-classification":
                text_lower = text.lower()
                labels = args[0] if args else ["positive", "negative"]
                scores = []
                for label in labels:
                    if label.lower() in text_lower:
                        scores.append(0.8)
                    else:
                        scores.append(0.2)
                return {"labels": labels, "scores": scores}
            elif self.task_type == "summarization":
                sentences = text.split('.')
                if len(sentences) > 3:
                    summary = '. '.join(sentences[:2]) + '.'
                else:
                    summary = text[:200] + ('...' if len(text) > 200 else '')
                return [{"summary_text": summary}]
            else:
                return {"result": "Model unavailable, using fallback"}
    return TextFallback(task)

def clear_model_cache():
    """Drop every cached pipeline, including entries memoized by load_model's lru_cache."""
    _model_cache.clear()
    load_model.cache_clear()
    logger.info("Model cache cleared")

def get_available_models():
    """Return the task-to-model mapping of primary and fallback checkpoints."""
    return MODEL_MAPPING
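
# Minimal usage sketch: shows how a caller might load a pipeline, check whether a
# fallback was used, and clear the caches. The example text and candidate labels are
# illustrative assumptions; running this downloads the configured checkpoints on first use.
if __name__ == "__main__":
    classifier = load_model("zero-shot-classification")
    result = classifier(
        "The new release cut startup time in half.",
        candidate_labels=["performance", "security", "documentation"],
    )
    if getattr(classifier, "fallback_used", False):
        logger.warning(f"Classifier is using fallback model: {getattr(classifier, 'fallback_model', 'unknown')}")
    print(result)

    summarizer = load_model("summarization")
    article = (
        "Transformers pipelines wrap a tokenizer and a model behind a single callable. "
        "This loader adds caching and progressively smaller fallback models on top of that."
    )
    print(summarizer(article)[0])

    clear_model_cache()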