File size: 12,456 Bytes
3cf9fa0
 
 
d7291ef
351d460
d7291ef
 
 
 
 
3cf9fa0
d7291ef
 
 
 
 
 
 
 
3cf9fa0
 
 
 
 
 
 
d7291ef
 
3cf9fa0
 
d7291ef
 
3cf9fa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7291ef
fe5d98f
d7291ef
3cf9fa0
 
 
 
 
 
d7291ef
 
 
 
 
3cf9fa0
d7291ef
3cf9fa0
 
d7291ef
 
3cf9fa0
d7291ef
 
3cf9fa0
d7291ef
 
 
3cf9fa0
d7291ef
3cf9fa0
 
 
 
d7291ef
 
 
3cf9fa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7291ef
 
 
3cf9fa0
d7291ef
 
3cf9fa0
 
65933cd
 
 
3cf9fa0
 
 
351d460
 
 
 
3cf9fa0
 
 
351d460
 
 
3cf9fa0
351d460
3cf9fa0
351d460
3cf9fa0
 
 
 
 
 
 
351d460
3cf9fa0
 
 
 
351d460
3cf9fa0
 
 
351d460
3cf9fa0
 
 
 
351d460
3cf9fa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351d460
 
 
 
 
3cf9fa0
351d460
3cf9fa0
ba5edb0
 
 
3cf9fa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351d460
 
 
 
 
3cf9fa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7291ef
3cf9fa0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
# app/services/vlm_services.py
from __future__ import annotations

import logging
import random
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class ModelType(Enum):
    """Closed set of supported VLM backends.

    The string values are stable identifiers used in configuration and
    service registration lookups.
    """

    GPT4V = "gpt4v"
    CLAUDE_3_5_SONNET = "claude_3_5_sonnet"
    GEMINI_PRO_VISION = "gemini_pro_vision"
    LLAMA_VISION = "llama_vision"
    CUSTOM = "custom"  # provider-specific / self-hosted models


class ServiceStatus(Enum):
    """Lifecycle state of a registered VLM service."""

    READY = "ready"              # probe or ensure_ready() succeeded
    DEGRADED = "degraded"        # registered but probe failed or not run
    UNAVAILABLE = "unavailable"  # explicitly marked unusable


class VLMService(ABC):
    """Abstract base for vision-language-model (VLM) caption services.

    Concrete providers must implement ``generate_caption`` and may override
    ``probe`` (cheap reachability check) and ``ensure_ready`` (one-time
    warm-up performed lazily before first use).
    """

    def __init__(self, model_name: str, model_type: ModelType, provider: str = "custom", lazy_init: bool = True):
        self.model_name = model_name
        self.model_type = model_type
        self.provider = provider
        self.lazy_init = lazy_init
        # Quick flag the manager consults when selecting a service.
        self.is_available = True
        # A service starts DEGRADED until a probe or ensure_ready() succeeds.
        self.status = ServiceStatus.DEGRADED
        self._initialized = False

    async def probe(self) -> bool:
        """Lightweight reachability/metadata check; providers should override.

        Must be quick (<5s) and NEVER raise. Returns True if reachable/ok.
        """
        return True

    async def ensure_ready(self) -> bool:
        """One-time warm-up hook invoked before first use.

        Providers may override to open clients or warm caches. Must set
        ``_initialized`` True and return True on success. NEVER raise.
        """
        self._initialized = True
        self.status = ServiceStatus.READY
        return True

    @abstractmethod
    async def generate_caption(self, image_bytes: bytes, prompt: str, metadata_instructions: str = "") -> Dict[str, Any]:
        """Generate a caption result dict for a single image."""
        ...

    async def generate_multi_image_caption(self, image_bytes_list: List[bytes], prompt: str, metadata_instructions: str = "") -> Dict[str, Any]:
        """Optional multi-image variant; providers that support it override this."""
        raise NotImplementedError("Multi-image caption not implemented for this service")

    def get_model_info(self) -> Dict[str, Any]:
        """Return a serializable snapshot of this service's metadata."""
        info: Dict[str, Any] = {
            "name": self.model_name,
            "type": self.model_type.value,
            "provider": self.provider,
            "available": self.is_available,
            "status": self.status.value,
            "lazy_init": self.lazy_init,
        }
        return info


class VLMServiceManager:
    """Registry and router for multiple VLM services.

    Registration does no network I/O; reachability probes (``probe_all``)
    and heavy initialization (``ensure_ready``) are deferred so that
    service registration never blocks application startup.
    """

    def __init__(self):
        # model_name -> service instance
        self.services: Dict[str, VLMService] = {}
        # Name of the first-registered service; used when nothing is requested.
        self.default_service: Optional[str] = None

    def register_service(self, service: VLMService):
        """
        Register a VLM service (NO network calls here).
        We’ll probe later, asynchronously, so registration never blocks startup.
        The first registered service becomes the default.
        """
        self.services[service.model_name] = service
        if not self.default_service:
            self.default_service = service.model_name
        logger.info("Registered VLM service: %s (%s)", service.model_name, service.provider)

    async def probe_all(self):
        """
        Run lightweight probes for all registered services.
        Failures do not remove services; they stay DEGRADED and will lazy-init on first use.
        """
        for svc in self.services.values():
            try:
                ok = await svc.probe()
                svc.status = ServiceStatus.READY if ok else ServiceStatus.DEGRADED
                # If probe fails but lazy_init is allowed, keep is_available True so selection still works.
                svc.is_available = ok or svc.lazy_init
                logger.info("Probe %s -> %s", svc.model_name, svc.status.value)
            except Exception as e:
                logger.warning("Probe failed for %s: %r", svc.model_name, e)
                svc.status = ServiceStatus.DEGRADED
                svc.is_available = bool(svc.lazy_init)

    def get_service(self, model_name: str) -> Optional[VLMService]:
        """Get a specific VLM service by model name, or None if unknown."""
        return self.services.get(model_name)

    def get_default_service(self) -> Optional[VLMService]:
        """Get the default (first-registered) VLM service, if any."""
        return self.services.get(self.default_service) if self.default_service else None

    def get_available_models(self) -> List[str]:
        """Get the names of all registered models (regardless of status)."""
        return list(self.services.keys())

    async def _pick_service(self, model_name: Optional[str], db_session) -> VLMService:
        """Resolve which service should handle a request.

        Resolution order:
          1. The explicitly requested model, if registered.
          2. With a DB session: the configured fallback model (if allowed),
             then STUB_MODEL, then any registered service.
          3. Without a DB session: a random available service, then
             STUB_MODEL, then any registered service.

        Raises RuntimeError if no service is registered at all.
        """
        # Specific pick
        service: Optional[VLMService] = None
        if model_name and model_name != "random":
            service = self.services.get(model_name)
            if not service:
                logger.warning("Model '%s' not found; will pick fallback", model_name)

        # Fallback / random based on DB allowlist (is_available==True)
        if not service and self.services:
            if db_session:
                try:
                    from .. import crud  # local import to avoid cycles at import time
                    available_models = crud.get_models(db_session)
                    allowed = {m.m_code for m in available_models if getattr(m, "is_available", False)}

                    # Check for configured fallback model first
                    configured_fallback = crud.get_fallback_model(db_session)
                    if configured_fallback and configured_fallback in allowed:
                        fallback_service = self.services.get(configured_fallback)
                        if fallback_service and fallback_service.is_available:
                            logger.info("Using configured fallback model: %s", configured_fallback)
                            service = fallback_service

                    # If no configured fallback or it's not available, use STUB_MODEL as final fallback
                    if not service:
                        service = self.services.get("STUB_MODEL") or next(iter(self.services.values()))
                        logger.info("Using STUB_MODEL as final fallback")
                except Exception as e:
                    logger.warning("DB availability check failed: %r; using first available", e)
                    avail = [s for s in self.services.values() if s.is_available]
                    # BUG FIX: `random` was previously imported only in the
                    # no-DB branch below, so this path raised NameError; it is
                    # now a module-level import.
                    service = (self.services.get("STUB_MODEL") or (random.choice(avail) if avail else next(iter(self.services.values()))))
            else:
                avail = [s for s in self.services.values() if s.is_available]
                service = (random.choice(avail) if avail else (self.services.get("STUB_MODEL") or next(iter(self.services.values()))))

        if not service:
            raise RuntimeError("No VLM service available")

        # Lazy init on first use
        if service.lazy_init and not service._initialized:
            try:
                ok = await service.ensure_ready()
                service.status = ServiceStatus.READY if ok else ServiceStatus.DEGRADED
            except Exception as e:
                logger.warning("ensure_ready failed for %s: %r", service.model_name, e)
                service.status = ServiceStatus.DEGRADED

        return service

    async def _run_fallback(self, fallback: VLMService, image_bytes: bytes, prompt: str, metadata_instructions: str, failed: VLMService, reason: str) -> dict:
        """Run one fallback service and annotate its result with fallback metadata.

        Propagates any exception from the fallback so the caller can try the
        next option.
        """
        if fallback.lazy_init and not fallback._initialized:
            await fallback.ensure_ready()
        res = await fallback.generate_caption(image_bytes, prompt, metadata_instructions)
        res.update({
            "model": fallback.model_name,
            "fallback_used": True,
            "original_model": failed.model_name,
            "fallback_reason": reason,
        })
        return res

    async def generate_caption(self, image_bytes: bytes, prompt: str, metadata_instructions: str = "", model_name: str | None = None, db_session=None) -> dict:
        """Generate a caption using the specified model or fallback to available service.

        On failure of the picked service, tries the DB-configured fallback
        model (if any), then STUB_MODEL. Raises RuntimeError when every
        option fails.
        """
        service = await self._pick_service(model_name, db_session)
        try:
            result = await service.generate_caption(image_bytes, prompt, metadata_instructions)
            result["model"] = service.model_name
            return result
        except Exception as e:
            logger.error("Error with %s: %r; trying fallbacks", service.model_name, e)

            # First, try the configured fallback model if available
            if db_session:
                try:
                    from .. import crud
                    configured_fallback = crud.get_fallback_model(db_session)
                    if configured_fallback and configured_fallback != service.model_name:
                        fallback_service = self.services.get(configured_fallback)
                        if fallback_service and fallback_service.is_available:
                            logger.info("Trying configured fallback model: %s", configured_fallback)
                            try:
                                res = await self._run_fallback(fallback_service, image_bytes, prompt, metadata_instructions, service, str(e))
                                logger.info("Configured fallback model %s succeeded", configured_fallback)
                                return res
                            except Exception as fe:
                                logger.warning("Configured fallback service %s also failed: %r", configured_fallback, fe)
                except Exception as db_error:
                    logger.warning("Failed to get configured fallback: %r", db_error)

            # If configured fallback failed or not available, try STUB_MODEL.
            # BUG FIX: this previously compared a service OBJECT to a string
            # (`stub_service != service.model_name`), which was always True and
            # could retry the very STUB service that had just failed.
            stub_service = self.services.get("STUB_MODEL")
            if stub_service and stub_service is not service:
                logger.info("Trying STUB_MODEL as final fallback")
                try:
                    res = await self._run_fallback(stub_service, image_bytes, prompt, metadata_instructions, service, str(e))
                    logger.info("STUB_MODEL succeeded as final fallback")
                    return res
                except Exception as fe:
                    logger.warning("STUB_MODEL also failed: %r", fe)

            # All services failed
            raise RuntimeError(f"All VLM services failed. Last error from {service.model_name}: {e}")

    async def generate_multi_image_caption(self, image_bytes_list: List[bytes], prompt: str, metadata_instructions: str = "", model_name: str | None = None, db_session=None) -> dict:
        """Multi-image version if a provider supports it.

        Falls back across every other registered service (best effort) before
        raising RuntimeError.
        """
        service = await self._pick_service(model_name, db_session)
        try:
            result = await service.generate_multi_image_caption(image_bytes_list, prompt, metadata_instructions)
            result["model"] = service.model_name
            return result
        except Exception as e:
            logger.error("Error with %s (multi): %r; trying fallbacks", service.model_name, e)
            for other in self.services.values():
                if other is service:
                    continue
                try:
                    if other.lazy_init and not other._initialized:
                        await other.ensure_ready()
                    res = await other.generate_multi_image_caption(image_bytes_list, prompt, metadata_instructions)
                    res.update({
                        "model": other.model_name,
                        "fallback_used": True,
                        "original_model": service.model_name,
                        "fallback_reason": str(e),
                    })
                    return res
                except Exception:
                    continue
            raise RuntimeError(f"All VLM services failed (multi). Last error from {service.model_name}: {e}")


# Module-level singleton manager: providers register themselves against this
# instance at import/startup time, and request handlers route through it.
vlm_manager = VLMServiceManager()