Spaces:
Running
Running
# app/services/vlm_services.py | |
from __future__ import annotations | |
from abc import ABC, abstractmethod | |
from typing import Dict, Any, Optional, List | |
import logging | |
from enum import Enum | |
logger = logging.getLogger(__name__) | |
class ModelType(Enum):
    """Identifiers for the kinds of vision-language model a service can wrap.

    Each member's value is the lowercase string reported under the ``"type"``
    key of ``VLMService.get_model_info()``.
    """

    # Named model families.
    GPT4V = "gpt4v"
    CLAUDE_3_5_SONNET = "claude_3_5_sonnet"
    GEMINI_PRO_VISION = "gemini_pro_vision"
    LLAMA_VISION = "llama_vision"
    # Catch-all for any model that is not one of the named families above.
    CUSTOM = "custom"
class ServiceStatus(Enum):
    """Lifecycle/health state of a registered VLM service."""

    # Probe or initialisation succeeded.
    READY = "ready"
    # Registered, but the probe failed or has not been run yet.
    DEGRADED = "degraded"
    # Service cannot be used.
    UNAVAILABLE = "unavailable"
class VLMService(ABC):
    """Base class that every vision-language model (VLM) backend derives from.

    The base implementation performs no I/O: ``probe`` and ``ensure_ready``
    succeed unconditionally and ``generate_caption`` returns ``None``.
    Providers are expected to override these hooks.
    """

    def __init__(self, model_name: str, model_type: ModelType, provider: str = "custom", lazy_init: bool = True):
        """Store the service identity and set the initial lifecycle state.

        Args:
            model_name: Name the service is known by.
            model_type: Model family this service belongs to.
            provider: Free-form provider label.
            lazy_init: Whether ``ensure_ready`` should be deferred to first use.
        """
        self.model_name = model_name
        self.model_type = model_type
        self.provider = provider
        self.lazy_init = lazy_init
        # Quick flag consulted by the manager when selecting a service.
        self.is_available = True
        # A new service stays DEGRADED until a probe or ensure_ready succeeds.
        self.status = ServiceStatus.DEGRADED
        self._initialized = False

    async def probe(self) -> bool:
        """Cheap reachability/metadata check; providers should override.

        Implementations must be quick (<5s) and must NEVER raise.

        Returns:
            True when the backend is reachable/ok. The base class always
            returns True.
        """
        return True

    async def ensure_ready(self) -> bool:
        """One-time setup hook, called before the first use of the service.

        Providers may override this to open clients or warm caches.
        Implementations must set ``_initialized`` to True and return True on
        success, and must NEVER raise.

        Returns:
            True on success. The base class always succeeds.
        """
        self.status = ServiceStatus.READY
        self._initialized = True
        return True

    async def generate_caption(self, image_bytes: bytes, prompt: str, metadata_instructions: str = "") -> Dict[str, Any]:
        """Generate a caption for a single image.

        The base class has no implementation and returns ``None``; providers
        override this method to return a result dict.
        """
        return None

    # Optional for multi-image models; override in providers that support it.
    async def generate_multi_image_caption(self, image_bytes_list: List[bytes], prompt: str, metadata_instructions: str = "") -> Dict[str, Any]:
        """Generate one caption for several images.

        Raises:
            NotImplementedError: Always, unless a provider overrides this.
        """
        raise NotImplementedError("Multi-image caption not implemented for this service")

    def get_model_info(self) -> Dict[str, Any]:
        """Return a plain-dict summary of this service's identity and state."""
        info: Dict[str, Any] = {}
        info["name"] = self.model_name
        info["type"] = self.model_type.value
        info["provider"] = self.provider
        info["available"] = self.is_available
        info["status"] = self.status.value
        info["lazy_init"] = self.lazy_init
        return info
class VLMServiceManager:
    """Registry and dispatcher for multiple VLM services.

    Services are stored by ``model_name``. Caption requests are routed to the
    requested service (or to a fallback pick) and, when that call raises, are
    retried on fallback services before giving up with ``RuntimeError``.
    """

    # Registry key of the service used as the last-resort fallback.
    _STUB_MODEL_NAME = "STUB_MODEL"

    def __init__(self):
        # model_name -> service, in registration order.
        self.services: Dict[str, VLMService] = {}
        # model_name of the first registered service; None until then.
        self.default_service: Optional[str] = None

    def register_service(self, service: VLMService):
        """Register a VLM service (NO network calls here).

        Probing happens later, asynchronously, so registration never blocks
        startup. The first service registered becomes the default. Registering
        a second service under an existing ``model_name`` replaces the first.
        """
        self.services[service.model_name] = service
        if not self.default_service:
            self.default_service = service.model_name
        logger.info("Registered VLM service: %s (%s)", service.model_name, service.provider)

    async def probe_all(self):
        """Run the lightweight probe of every registered service.

        Failures do not remove services; they stay DEGRADED and will lazy-init
        on first use.
        """
        for svc in self.services.values():
            try:
                ok = await svc.probe()
            except Exception as e:
                logger.warning("Probe failed for %s: %r", svc.model_name, e)
                svc.status = ServiceStatus.DEGRADED
                svc.is_available = bool(svc.lazy_init)
            else:
                svc.status = ServiceStatus.READY if ok else ServiceStatus.DEGRADED
                # If the probe fails but lazy_init is allowed, keep
                # is_available True so selection still works.
                svc.is_available = ok or svc.lazy_init
                logger.info("Probe %s -> %s", svc.model_name, svc.status.value)

    def get_service(self, model_name: str) -> Optional[VLMService]:
        """Return the service registered as ``model_name``, or None."""
        return self.services.get(model_name)

    def get_default_service(self) -> Optional[VLMService]:
        """Return the default service, or None when nothing is registered."""
        return self.services.get(self.default_service) if self.default_service else None

    def get_available_models(self) -> list:
        """Return the names of all registered services.

        Note: this does not filter on ``is_available``.
        """
        return list(self.services.keys())

    def _final_fallback(self) -> VLMService:
        """Return STUB_MODEL if registered, else the first registered service."""
        return self.services.get(self._STUB_MODEL_NAME) or next(iter(self.services.values()))

    def _pick_db_fallback(self, db_session) -> VLMService:
        """Choose a fallback service using the DB allowlist.

        The configured fallback model is used when the DB marks it available
        and the matching registered service is available too; otherwise the
        final fallback is returned. Import or DB errors propagate to the
        caller.
        """
        from .. import crud  # local import to avoid cycles at import time

        allowed = {
            m.m_code
            for m in crud.get_models(db_session)
            if getattr(m, "is_available", False)
        }
        configured = crud.get_fallback_model(db_session)
        if configured and configured in allowed:
            candidate = self.services.get(configured)
            if candidate and candidate.is_available:
                logger.info("Using configured fallback model: %s", configured)
                return candidate

        chosen = self._final_fallback()
        # Log the service actually chosen; it is only STUB_MODEL when that
        # service is registered.
        logger.info("Using %s as final fallback", chosen.model_name)
        return chosen

    async def _pick_service(self, model_name: Optional[str], db_session) -> VLMService:
        """Resolve the service that should handle a request.

        Args:
            model_name: Requested model; None or ``"random"`` requests a
                fallback pick.
            db_session: Optional DB session used for the allowlist and the
                configured fallback; falsy means random choice among
                available services.

        Returns:
            The chosen service, after a best-effort lazy initialisation.

        Raises:
            RuntimeError: If no service is registered.
        """
        # Function-scope import placed before every branch: binding it inside
        # a single branch would leave `random` an unbound local in the others.
        import random

        service = None
        if model_name and model_name != "random":
            service = self.services.get(model_name)
            if not service:
                logger.warning("Model '%s' not found; will pick fallback", model_name)

        if not service and self.services:
            if db_session:
                try:
                    service = self._pick_db_fallback(db_session)
                except Exception as e:
                    logger.warning("DB availability check failed: %r; using first available", e)
                    avail = [s for s in self.services.values() if s.is_available]
                    service = self.services.get(self._STUB_MODEL_NAME) or (
                        random.choice(avail) if avail else next(iter(self.services.values()))
                    )
            else:
                avail = [s for s in self.services.values() if s.is_available]
                service = random.choice(avail) if avail else self._final_fallback()

        if not service:
            raise RuntimeError("No VLM service available")

        # Lazy init on first use; a failure only downgrades the status, the
        # service is still returned.
        if service.lazy_init and not service._initialized:
            try:
                ok = await service.ensure_ready()
                service.status = ServiceStatus.READY if ok else ServiceStatus.DEGRADED
            except Exception as e:
                logger.warning("ensure_ready failed for %s: %r", service.model_name, e)
                service.status = ServiceStatus.DEGRADED
        return service

    def _configured_fallback(self, failed: VLMService, db_session) -> Optional[VLMService]:
        """Return the DB-configured fallback service, if it is usable.

        Returns None when there is no DB session, the lookup fails, the
        configured model is the one that just failed, or the matching service
        is not registered or not available.
        """
        if not db_session:
            return None
        try:
            from .. import crud  # local import to avoid cycles at import time

            name = crud.get_fallback_model(db_session)
        except Exception as db_error:
            logger.warning("Failed to get configured fallback: %r", db_error)
            return None
        if not name or name == failed.model_name:
            return None
        candidate = self.services.get(name)
        if candidate is not None and candidate.is_available:
            return candidate
        return None

    @staticmethod
    async def _run_fallback(fallback: VLMService, invoke, original: VLMService, error: Exception) -> Dict[str, Any]:
        """Lazy-init ``fallback``, run ``invoke(fallback)`` and tag the result.

        ``invoke`` is a callable returning the awaitable caption call. Any
        exception from initialisation or the call propagates to the caller.
        """
        if fallback.lazy_init and not fallback._initialized:
            await fallback.ensure_ready()
        res = await invoke(fallback)
        res.update({
            "model": fallback.model_name,
            "fallback_used": True,
            "original_model": original.model_name,
            "fallback_reason": str(error),
        })
        return res

    async def generate_caption(self, image_bytes: bytes, prompt: str, metadata_instructions: str = "", model_name: str | None = None, db_session=None) -> dict:
        """Generate a caption using the specified model or a fallback service.

        On failure of the picked service, the DB-configured fallback is tried
        first, then STUB_MODEL. Results produced by a fallback carry
        ``fallback_used``, ``original_model`` and ``fallback_reason`` keys.

        Raises:
            RuntimeError: If no service is registered or every attempt failed.
        """
        service = await self._pick_service(model_name, db_session)
        try:
            result = await service.generate_caption(image_bytes, prompt, metadata_instructions)
            result["model"] = service.model_name
            return result
        except Exception as e:
            logger.error("Error with %s: %r; trying fallbacks", service.model_name, e)

            def invoke(svc):
                return svc.generate_caption(image_bytes, prompt, metadata_instructions)

            # First, try the configured fallback model if available.
            configured = self._configured_fallback(service, db_session)
            if configured is not None:
                logger.info("Trying configured fallback model: %s", configured.model_name)
                try:
                    res = await self._run_fallback(configured, invoke, service, e)
                except Exception as fe:
                    logger.warning("Configured fallback service %s also failed: %r", configured.model_name, fe)
                else:
                    logger.info("Configured fallback model %s succeeded", configured.model_name)
                    return res

            # Then STUB_MODEL, unless it is the very service that just failed
            # (identity check: comparing the object to a name never matches).
            stub_service = self.services.get(self._STUB_MODEL_NAME)
            if stub_service is not None and stub_service is not service:
                logger.info("Trying STUB_MODEL as final fallback")
                try:
                    res = await self._run_fallback(stub_service, invoke, service, e)
                except Exception as fe:
                    logger.warning("STUB_MODEL also failed: %r", fe)
                else:
                    logger.info("STUB_MODEL succeeded as final fallback")
                    return res

            # All services failed.
            raise RuntimeError(f"All VLM services failed. Last error from {service.model_name}: {e}") from e

    async def generate_multi_image_caption(self, image_bytes_list: List[bytes], prompt: str, metadata_instructions: str = "", model_name: str | None = None, db_session=None) -> dict:
        """Multi-image version of ``generate_caption``, if a provider supports it.

        On failure of the picked service, every other registered service is
        tried in registration order; the first success is returned with the
        fallback keys added.

        Raises:
            RuntimeError: If no service is registered or every attempt failed.
        """
        service = await self._pick_service(model_name, db_session)
        try:
            result = await service.generate_multi_image_caption(image_bytes_list, prompt, metadata_instructions)
            result["model"] = service.model_name
            return result
        except Exception as e:
            logger.error("Error with %s (multi): %r; trying fallbacks", service.model_name, e)

            def invoke(svc):
                return svc.generate_multi_image_caption(image_bytes_list, prompt, metadata_instructions)

            for other in self.services.values():
                if other is service:
                    continue
                try:
                    return await self._run_fallback(other, invoke, service, e)
                except Exception as fe:
                    # Keep going, but leave a trace of why this one was skipped.
                    logger.warning("Fallback service %s also failed (multi): %r", other.model_name, fe)
            raise RuntimeError(f"All VLM services failed (multi). Last error from {service.model_name}: {e}") from e
# Module-level manager instance, created at import time with no services
# registered; services are added through `register_service`.
vlm_manager = VLMServiceManager()