import os

# Ensure the HF cache is writable and does not use /data. Set these before
# importing transformers / huggingface_hub so both libraries pick them up.
os.environ.setdefault("HF_HOME", "/tmp/hf_home")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/hf_home")
# Avoid the deprecated TRANSFORMERS_CACHE variable, which may point to /data
if "TRANSFORMERS_CACHE" in os.environ:
    del os.environ["TRANSFORMERS_CACHE"]
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")

import io
import json
import logging
import pickle
import tempfile
import threading
import time
from typing import Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from huggingface_hub import snapshot_download
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

APP_START_TS = time.time()

# Model configuration
MODEL_ID = os.environ.get("MODEL_ID", "google/gemma-3n-E4B-it")
DEVICE_MAP = os.environ.get("DEVICE_MAP", "cpu")  # Force CPU on Hugging Face Spaces
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "256"))

# Cache file used to share model status between Streamlit and FastAPI
MODEL_CACHE_FILE = os.path.join(tempfile.gettempdir(), "agrilens_model_cache.pkl")
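# Note: the pickle file can only act as a marker that Streamlit has loaded the
# model; the model objects themselves cannot be shared across processes this way.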

def _get_dtype() -> torch.dtype:
    """Pick the dtype for the target hardware (float32 is forced for CPU Spaces)."""
    return torch.float32

def _build_prompt(culture: Optional[str], notes: Optional[str]) -> str:
    """Build the analysis prompt."""
    base = (
        "You are an agronomy assistant. Analyze the provided plant leaf image and identify the most likely disease. "
        "Return a concise diagnosis in French with: disease name, short explanation of symptoms, "
        "and 3 actionable treatment recommendations."
    )
    if culture:
        base += f"\nCulture: {culture}"
    if notes:
        base += f"\nNotes: {notes}"
    return base
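# Example (illustrative): _build_prompt("tomato", "yellow spots on lower leaves")
# returns the base instruction followed by "Culture: tomato" and
# "Notes: yellow spots on lower leaves" on separate lines.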

class SharedModelManager:
    """Model manager shared between Streamlit and FastAPI."""

    def __init__(self):
        self.model = None
        self.processor = None
        self.device_map = DEVICE_MAP
        self.dtype = _get_dtype()
        self._load_attempted = False
        self._loading = False
        self._load_error = None
        self._last_load_attempt = 0
        self._load_timeout = 300  # 5-minute timeout
        logger.info(f"Initializing ModelManager with device_map={self.device_map}, dtype={self.dtype}")
        # Try to recover from a previous state
        self._recover_state()

    def _recover_state(self):
        """Try to recover the model state from disk."""
        try:
            state_file = "/tmp/model_state.json"
            if os.path.exists(state_file):
                with open(state_file, 'r') as f:
                    state = json.load(f)
                # Only trust the state if it is recent (less than 1 hour old)
                if time.time() - state.get('timestamp', 0) < 3600:
                    logger.info("Previous state found, attempting recovery...")
                    # The model objects themselves cannot be reloaded from this
                    # file; we can only mark a load as already attempted.
                    self._load_attempted = True
                    self._last_load_attempt = state.get('timestamp', 0)
        except Exception as e:
            logger.warning(f"Could not recover state: {e}")

    def _save_state(self):
        """Save the current state to disk."""
        try:
            state_file = "/tmp/model_state.json"
            state = {
                'timestamp': time.time(),
                'model_loaded': self.model is not None,
                'processor_loaded': self.processor is not None,
                'load_attempted': self._load_attempted,
                'loading': self._loading,
                'error': self._load_error
            }
            with open(state_file, 'w') as f:
                json.dump(state, f)
        except Exception as e:
            logger.warning(f"Could not save state: {e}")

    def check_streamlit_model_cache(self):
        """Check whether the model is available in the Streamlit cache via the marker file."""
        try:
            # The cache file must exist and be recent (less than 1 hour old)
            if os.path.exists(MODEL_CACHE_FILE):
                file_age = time.time() - os.path.getmtime(MODEL_CACHE_FILE)
                if file_age < 3600:  # 1 hour
                    # Read the cache metadata
                    try:
                        with open(MODEL_CACHE_FILE, 'rb') as f:
                            cache_data = pickle.load(f)
                        logger.info(f"Streamlit cache found: {cache_data}")
                        return True
                    except Exception as e:
                        logger.error(f"Error reading the cache: {e}")
            return False
        except Exception as e:
            logger.error(f"Error checking the cache: {e}")
            return False

    def load_model_directly(self):
        """Robust model loading that tries several approaches to avoid permission issues."""
        try:
            self._loading = True
            self._load_attempted = True
            self._last_load_attempt = time.time()
            self._load_error = None
            # Try the approaches in order of preference: the default writable HF
            # cache first, then increasingly explicit /tmp locations. None of
            # them touch /data, which may be read-only on Spaces.
            approaches = [
                ("Direct HF Hub loading", self._try_direct_loading),
                ("Cache in /app/cache", self._try_app_cache),
                ("Cache in /tmp/hf_home", self._try_tmp_cache),
                ("Cache in /tmp/model_repo", self._try_tmp_repo),
            ]
            for approach_name, approach_func in approaches:
                try:
                    logger.info(f"Trying: {approach_name}")
                    success = approach_func()
                    if success:
                        self._loading = False
                        self._save_state()
                        logger.info(f"✅ Success with {approach_name}")
                        return True
                except Exception as e:
                    logger.warning(f"❌ {approach_name} failed: {e}")
                    continue
            # All approaches failed
            self._loading = False
            self._load_error = "All loading approaches failed"
            self._save_state()
            return False
        except Exception as e:
            logger.error(f"Critical loading error: {e}")
            self._loading = False
            self._load_error = str(e)
            self._save_state()
            return False

    def _try_direct_loading(self):
        """Load directly from the Hugging Face Hub, forcing cache_dir so /data is never used."""
        try:
            logger.info("Loading directly from the HF Hub...")
            writable_cache = os.environ.get("HF_HOME", "/tmp/hf_home")
            os.makedirs(writable_cache, exist_ok=True)
            # Load the processor with an explicit cache_dir
            self.processor = AutoProcessor.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                cache_dir=writable_cache,
                local_files_only=False,
            )
            logger.info("Processor loaded directly")
            # Load the model with an explicit cache_dir
            self.model = AutoModelForImageTextToText.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                cache_dir=writable_cache,
                local_files_only=False,
                low_cpu_mem_usage=True,
                device_map=self.device_map,
                torch_dtype=self.dtype,
            )
            if self.device_map == "cpu":
                self.model = self.model.to("cpu")
            logger.info("Model loaded directly from the HF Hub")
            return True
        except Exception as e:
            logger.error(f"Direct loading failed: {e}")
            return False

    def _try_app_cache(self):
        """Download a snapshot into /app/cache and load from there."""
        try:
            cache_dir = "/app/cache/huggingface"
            os.makedirs(cache_dir, exist_ok=True)
            logger.info(f"Snapshot to {cache_dir}")
            snapshot_download(
                repo_id=MODEL_ID,
                local_dir=cache_dir,
                local_dir_use_symlinks=False,
                resume_download=True,
                token=os.environ.get("HF_TOKEN", None),
            )
            # Load from the local snapshot
            self.processor = AutoProcessor.from_pretrained(
                cache_dir,
                trust_remote_code=True,
                local_files_only=True,
            )
            logger.info("Processor loaded from /app/cache")
            self.model = AutoModelForImageTextToText.from_pretrained(
                cache_dir,
                trust_remote_code=True,
                local_files_only=True,
                low_cpu_mem_usage=True,
                device_map=self.device_map,
                torch_dtype=self.dtype,
            )
            if self.device_map == "cpu":
                self.model = self.model.to("cpu")
            logger.info("Model loaded from /app/cache")
            return True
        except Exception as e:
            logger.error(f"/app cache failed: {e}")
            return False

    def _try_tmp_cache(self):
        """Download a snapshot into /tmp/hf_home and load from there."""
        try:
            cache_dir = "/tmp/hf_home"
            os.makedirs(cache_dir, exist_ok=True)
            logger.info(f"Snapshot to {cache_dir}")
            snapshot_download(
                repo_id=MODEL_ID,
                local_dir=cache_dir,
                local_dir_use_symlinks=False,
                resume_download=True,
                token=os.environ.get("HF_TOKEN", None),
            )
            # Load from the local snapshot
            self.processor = AutoProcessor.from_pretrained(
                cache_dir,
                trust_remote_code=True,
                local_files_only=True,
            )
            logger.info("Processor loaded from /tmp/hf_home")
            self.model = AutoModelForImageTextToText.from_pretrained(
                cache_dir,
                trust_remote_code=True,
                local_files_only=True,
                low_cpu_mem_usage=True,
                device_map=self.device_map,
                torch_dtype=self.dtype,
            )
            if self.device_map == "cpu":
                self.model = self.model.to("cpu")
            logger.info("Model loaded from /tmp/hf_home")
            return True
        except Exception as e:
            logger.error(f"/tmp/hf_home cache failed: {e}")
            return False

    def _try_tmp_repo(self):
        """Download a snapshot into /tmp/model_repo (original approach) and load from there."""
        try:
            repo_dir = "/tmp/model_repo"
            offload_dir = "/tmp/model_offload"
            os.makedirs(repo_dir, exist_ok=True)
            os.makedirs(offload_dir, exist_ok=True)
            logger.info(f"Snapshot to {repo_dir}")
            snapshot_download(
                repo_id=MODEL_ID,
                local_dir=repo_dir,
                local_dir_use_symlinks=False,
                resume_download=True,
                token=os.environ.get("HF_TOKEN", None),
            )
            # Load from the local snapshot
            self.processor = AutoProcessor.from_pretrained(
                repo_dir,
                trust_remote_code=True,
                local_files_only=True,
            )
            logger.info("Processor loaded from /tmp/model_repo")
            self.model = AutoModelForImageTextToText.from_pretrained(
                repo_dir,
                trust_remote_code=True,
                local_files_only=True,
                low_cpu_mem_usage=True,
                device_map=self.device_map,
                torch_dtype=self.dtype,
                offload_folder=offload_dir,
                # Only a CPU budget makes sense here; a GPU key would be
                # invalid on a CPU-only Space
                max_memory={"cpu": "8GB"} if self.device_map == "cpu" else None,
            )
            if self.device_map == "cpu":
                self.model = self.model.to("cpu")
            logger.info("Model loaded from /tmp/model_repo")
            return True
        except Exception as e:
            logger.error(f"/tmp/model_repo cache failed: {e}")
            return False

    def load_model_with_retry(self, max_retries=5, delay=60):
        """Load the model, retrying automatically on failure."""
        for attempt in range(max_retries):
            try:
                logger.info(f"Loading attempt {attempt + 1}/{max_retries}")
                success = self.load_model_directly()
                if success:
                    return True
                logger.warning(f"Attempt {attempt + 1} failed, waiting {delay}s...")
                if attempt < max_retries - 1:
                    time.sleep(delay)
            except Exception as e:
                logger.error(f"Error on attempt {attempt + 1}: {e}")
                if attempt < max_retries - 1:
                    time.sleep(delay)
        logger.error(f"All {max_retries} attempts failed")
        return False

    def ensure_model_loaded(self):
        """Make sure the model is loaded, loading it on demand at most once."""
        if self.model is not None and self.processor is not None:
            return True
        if not self._load_attempted:
            self._load_attempted = True
            # Load the model directly (lazy, on first demand)
            return self.load_model_directly()
        return False

    def get_load_status(self):
        """Return the current loading status."""
        return {
            "loaded": self.model is not None and self.processor is not None,
            "loading": self._loading,
            "error": self._load_error,
            "attempted": self._load_attempted
        }

    def _complete_partial_load(self):
        """Complete a partial load (processor loaded but model missing)."""
        try:
            logger.info("Attempting to complete a partial load...")
            if self.processor and not self.model:
                logger.info("Processor available, loading the model only...")
                try:
                    # Processor objects do not reliably expose the path they were
                    # loaded from, so reload the model from the configured repo id
                    model_path = MODEL_ID
                    logger.info(f"Loading model from {model_path}")
                    self.model = AutoModelForImageTextToText.from_pretrained(
                        model_path,
                        trust_remote_code=True,
                        low_cpu_mem_usage=True,
                        device_map=self.device_map,
                        torch_dtype=self.dtype,
                        offload_folder="/tmp/model_offload",
                        max_memory={"cpu": "8GB"} if self.device_map == "cpu" else None
                    )
                    if self.device_map == "cpu":
                        self.model = self.model.to("cpu")
                    logger.info("Model load completed successfully!")
                    self._loading = False
                    self._save_state()
                    return True
                except Exception as e:
                    logger.error(f"Completion failed: {e}")
                    # Fall back to a full reload
                    return self.load_model_directly()
            else:
                logger.info("No partial load to complete")
                return False
        except Exception as e:
            logger.error(f"Error completing the load: {e}")
            return False

# Global model manager instance
model_manager = SharedModelManager()

app = FastAPI(title="AgriLens AI FastAPI", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
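# Note: allow_origins=["*"] is convenient for a public demo Space, but should
# be narrowed to known origins in production.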

# Non-blocking warmup at startup: start loading without blocking the server
@app.on_event("startup")
async def _warmup_background():
    """Kick off model loading in the background without blocking startup."""
    logger.info("Starting model loading in the background...")
    # Run in a daemon thread and do not await it, so the load cannot be
    # cancelled along with the startup task
    try:
        thread = threading.Thread(target=model_manager.load_model_directly, daemon=True)
        thread.start()
        logger.info("Loading thread started")
    except Exception as e:
        logger.error(f"Error starting the loading thread: {e}")

# Alternative: also try to load on the first request if not already loaded
@app.middleware("http")
async def ensure_model_loaded_middleware(request, call_next):
    """Middleware that makes sure the model is loaded, with automatic recovery."""
    try:
        current_time = time.time()
        # Detect partial loads and trigger automatic recovery (rate-limited)
        if (model_manager.processor and not model_manager.model and
                not model_manager._loading and
                not hasattr(model_manager, '_middleware_recovery_triggered')):
            logger.info("🔧 Automatic recovery triggered via middleware")
            model_manager._middleware_recovery_triggered = current_time
            # Start recovery in the background
            thread = threading.Thread(target=model_manager._complete_partial_load, daemon=True)
            thread.start()
        # Otherwise check whether the model still needs loading (rate-limited)
        elif (not model_manager.model and not model_manager._loading and
                not hasattr(model_manager, '_middleware_load_triggered')):
            logger.info("Model not loaded, attempting to load...")
            model_manager._middleware_load_triggered = current_time
            # Start loading in the background
            thread = threading.Thread(target=model_manager.load_model_directly, daemon=True)
            thread.start()
        # Clean up stale triggers (older than 5 minutes)
        if hasattr(model_manager, '_middleware_recovery_triggered'):
            if current_time - model_manager._middleware_recovery_triggered > 300:
                delattr(model_manager, '_middleware_recovery_triggered')
        if hasattr(model_manager, '_middleware_load_triggered'):
            if current_time - model_manager._middleware_load_triggered > 300:
                delattr(model_manager, '_middleware_load_triggered')
    except Exception as e:
        logger.error(f"Error in middleware: {e}")
    response = await call_next(request)
    return response

# Background task that keeps trying to load the model until it succeeds
@app.on_event("startup")
async def _persistent_model_loader():
    """Persistent loader that retries until the model loads or attempts run out."""

    def _load_loop():
        """Retry loop for loading the model."""
        max_attempts = 5   # Maximum attempts before giving up
        attempt_count = 0
        last_attempt_time = 0
        cooldown = 60      # Wait 60s between attempts
        while attempt_count < max_attempts:
            try:
                current_time = time.time()
                # Only attempt when nothing is loaded/loading and the cooldown elapsed
                if (not model_manager.model and
                        not model_manager._loading and
                        current_time - last_attempt_time > cooldown):
                    logger.info(f"Persistent loader: attempt {attempt_count + 1}/{max_attempts}...")
                    last_attempt_time = current_time
                    attempt_count += 1
                    success = model_manager.load_model_directly()
                    if success:
                        logger.info("Persistent loader: model loaded successfully!")
                        break
                    logger.warning(f"Persistent loader: failure {attempt_count}/{max_attempts}, retrying in {cooldown}s...")
                    time.sleep(cooldown)
                else:
                    # The model is loading or already loaded; wait a bit
                    time.sleep(10)
            except Exception as e:
                logger.error(f"Persistent loader: error: {e}")
                attempt_count += 1
                time.sleep(cooldown)
        if attempt_count >= max_attempts:
            logger.warning("Persistent loader: maximum number of attempts reached, stopping")
        else:
            logger.info("Persistent loader: finished successfully")

    # Start the persistent loader in a daemon thread
    thread = threading.Thread(target=_load_loop, daemon=True)
    thread.start()
    logger.info("Persistent model loader started")

# Automated recovery system
@app.on_event("startup")
async def _automated_recovery():
    """Automated recovery system that detects and fixes partial loads."""

    def _recovery_loop():
        """Continuous monitoring and recovery loop."""
        last_recovery_attempt = 0
        recovery_cooldown = 60  # Wait 60s between recovery attempts
        while True:
            try:
                current_time = time.time()
                # Partial load: processor present but model missing
                if (model_manager.processor and not model_manager.model and
                        not model_manager._loading and
                        current_time - last_recovery_attempt > recovery_cooldown):
                    logger.info("🔧 Automatic recovery: processor loaded but model missing")
                    logger.info("🚀 Starting automatic recovery...")
                    last_recovery_attempt = current_time
                    # Try to complete the partial load
                    success = model_manager._complete_partial_load()
                    if success:
                        logger.info("✅ Automatic recovery succeeded!")
                        break  # Exit the loop on success
                    logger.warning("⚠️ Automatic recovery failed, retrying in 60s...")
                # Stuck loading state: loading for more than 5 minutes
                elif (model_manager._loading and
                        current_time - model_manager._last_load_attempt > 300):
                    logger.warning("⏰ Timeout detected, resetting loading state...")
                    model_manager._loading = False
                    model_manager._load_error = "Timeout - loading stuck"
                    model_manager._save_state()
                # Wait before the next check
                time.sleep(15)  # Check every 15 seconds
            except Exception as e:
                logger.error(f"Error in the recovery loop: {e}")
                time.sleep(30)

    # Start the automated recovery in a daemon thread
    thread = threading.Thread(target=_recovery_loop, daemon=True)
    thread.start()
    logger.info("🔧 Automated recovery system started")

# A more robust startup path that survives startup-task cancellation
@app.on_event("startup")
async def _robust_startup():
    """Robust startup in a dedicated thread to avoid CancelledError.

    Note: a separate *process* would not share the loaded model with the
    server process, so a daemon thread is used here instead.
    """
    # Only start if not already loading
    if model_manager._loading:
        logger.info("Robust startup: loading already in progress, skipping")
        return
    try:
        logger.info("Starting model loading in the background...")
        # Guard flag to prevent multiple concurrent startup loaders
        if hasattr(model_manager, '_startup_loader_running'):
            logger.info("Startup loader already running, skipping")
            return
        model_manager._startup_loader_running = True

        def _startup_load():
            """Load the model in the dedicated loader thread."""
            try:
                logger.info("Startup loader started")
                success = model_manager.load_model_directly()
                if success:
                    logger.info("Startup loader: loading succeeded")
                else:
                    logger.warning("Startup loader: loading failed")
            except Exception as e:
                logger.error(f"Startup loader: error: {e}")
            finally:
                # Clear the guard flag
                if hasattr(model_manager, '_startup_loader_running'):
                    delattr(model_manager, '_startup_loader_running')

        # Start the loader thread
        thread = threading.Thread(target=_startup_load, daemon=True)
        thread.start()
        logger.info("Startup loader thread started")
    except Exception as e:
        logger.error(f"Error starting the loader: {e}")
        if hasattr(model_manager, '_startup_loader_running'):
            delattr(model_manager, '_startup_loader_running')

# Health monitoring with automatic recovery
@app.get("/health")
def health():
    """Report application and model health, triggering automatic recovery if needed."""
    try:
        # Detect partial loads and trigger automatic recovery
        if model_manager.processor and not model_manager.model and not model_manager._loading:
            logger.info("🔧 Automatic recovery triggered via /health")
            # Start recovery in the background
            thread = threading.Thread(target=model_manager._complete_partial_load, daemon=True)
            thread.start()
        model_loaded = model_manager.ensure_model_loaded()
        streamlit_cache_available = model_manager.check_streamlit_model_cache()
        load_status = model_manager.get_load_status()
        return {
            "status": "ok" if model_loaded else "cold",
            "uptime_s": int(time.time() - APP_START_TS),
            "cuda": torch.cuda.is_available(),
            "device_map": model_manager.device_map,
            "dtype": str(model_manager.dtype),
            "model_id": MODEL_ID,
            "streamlit_cache_available": streamlit_cache_available,
            "model_loaded": model_loaded,
            "load_status": load_status,
            "auto_recovery": "active",
        }
    except Exception as e:
        logger.error(f"Error in health check: {e}")
        return {
            "status": "error",
            "error": str(e),
            "uptime_s": int(time.time() - APP_START_TS),
        }
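# Example /health response while the model is still loading (illustrative values):
#   {"status": "cold", "uptime_s": 42, "cuda": false, "device_map": "cpu",
#    "dtype": "torch.float32", "model_id": "google/gemma-3n-E4B-it",
#    "streamlit_cache_available": false, "model_loaded": false,
#    "load_status": {"loaded": false, "loading": true, "error": null, "attempted": true},
#    "auto_recovery": "active"}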

@app.get("/load")
def load():
    """Force the model to load."""
    try:
        success = model_manager.load_model_directly()
        load_status = model_manager.get_load_status()
        if success:
            return {"status": "success", "message": "Model loaded successfully", "load_status": load_status}
        return {
            "status": "error",
            "message": "Model loading failed",
            "load_status": load_status,
            "error": model_manager._load_error
        }
    except Exception as e:
        logger.error(f"Error during forced load: {e}")
        return {"status": "error", "message": f"Error: {str(e)}"}

@app.post("/diagnose")
async def diagnose(
    image: UploadFile = File(...),
    culture: Optional[str] = Form(None),
    notes: Optional[str] = Form(None)
):
    """Analyze a plant leaf image."""
    start_time = time.time()
    try:
        # Make sure the model is loaded
        if not model_manager.ensure_model_loaded():
            load_status = model_manager.get_load_status()
            if model_manager._loading:
                raise HTTPException(status_code=503, detail="Model is loading, please retry in a few seconds")
            raise HTTPException(
                status_code=500,
                detail=f"Model unavailable. Status: {load_status}"
            )
        # Read the image
        image_data = await image.read()
        pil_image = Image.open(io.BytesIO(image_data))
        # Build the prompt
        prompt = _build_prompt(culture, notes)
        # Prepare the model inputs
        inputs = model_manager.processor(
            images=pil_image,
            text=prompt,
            return_tensors="pt"
        )
        # Move everything to the right device
        if model_manager.device_map == "cpu":
            inputs = {k: v.to("cpu") for k, v in inputs.items()}
        # Generate the answer
        with torch.no_grad():
            outputs = model_manager.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=0.7,
                pad_token_id=model_manager.processor.tokenizer.eos_token_id
            )
        # Decode the answer
        response_text = model_manager.processor.tokenizer.decode(
            outputs[0],
            skip_special_tokens=True
        )
        # Keep only the generated part (after the prompt)
        if prompt in response_text:
            diagnosis = response_text.split(prompt)[-1].strip()
        else:
            diagnosis = response_text.strip()
        return {
            "diagnosis": diagnosis,
            "model_id": MODEL_ID,
            "culture": culture,
            "notes": notes,
            # Time spent on this request, not since app start
            "processing_time": time.time() - start_time
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error during diagnosis: {e}")
        raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}")

# Route path inferred from the function's purpose
@app.post("/recover")
def recover():
    """Try to recover a partially loaded model."""
    try:
        if model_manager.processor and not model_manager.model:
            logger.info("Recovering a partial load...")
            success = model_manager._complete_partial_load()
            if success:
                return {"status": "success", "message": "Model recovered successfully"}
            return {"status": "error", "message": "Recovery failed"}
        return {"status": "info", "message": "No partial load to recover"}
    except Exception as e:
        logger.error(f"Error during recovery: {e}")
        return {"status": "error", "message": f"Error: {str(e)}"}

@app.get("/status")
def detailed_status():
    """Detailed system status, including automatic-recovery information."""
    try:
        current_time = time.time()
        # Time since the last load attempt
        time_since_last_attempt = current_time - model_manager._last_load_attempt if model_manager._last_load_attempt > 0 else 0
        # Derive the interesting states
        partial_load_detected = model_manager.processor is not None and model_manager.model is None
        stuck_loading = model_manager._loading and time_since_last_attempt > 300
        recovery_needed = partial_load_detected or stuck_loading
        status_info = {
            "timestamp": current_time,
            "model_state": {
                "processor_loaded": model_manager.processor is not None,
                "model_loaded": model_manager.model is not None,
                "loading": model_manager._loading,
                "load_attempted": model_manager._load_attempted,
                "time_since_last_attempt": f"{time_since_last_attempt:.1f}s"
            },
            "auto_recovery": {
                "active": True,
                "partial_load_detected": partial_load_detected,
                "stuck_loading_detected": stuck_loading,
                "recovery_needed": recovery_needed,
                "check_interval": "15s"
            },
            "system": {
                "uptime_s": int(current_time - APP_START_TS),
                "device_map": model_manager.device_map,
                "dtype": str(model_manager.dtype),
                "model_id": MODEL_ID
            }
        }
        # If recovery is needed, trigger it automatically
        if recovery_needed:
            logger.info("🔧 Automatic recovery triggered via /status")
            if partial_load_detected:
                thread = threading.Thread(target=model_manager._complete_partial_load, daemon=True)
                thread.start()
            elif stuck_loading:
                model_manager._loading = False
                model_manager._load_error = "Timeout - loading stuck"
                model_manager._save_state()
        return status_info
    except Exception as e:
        logger.error(f"Error in detailed_status: {e}")
        return {
            "status": "error",
            "error": str(e),
            "timestamp": time.time()
        }

@app.get("/")
def root():
    """Landing page with basic API information."""
    return {
        "message": "AgriLens AI FastAPI",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "load": "/load",
            "diagnose": "/diagnose (POST)"
        },
        "model": MODEL_ID,
        "uptime_s": int(time.time() - APP_START_TS)
    }

# Correct launch for Hugging Face Spaces
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))  # Hugging Face provides this port
    uvicorn.run("app:app", host="0.0.0.0", port=port)
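
# Example usage (illustrative; assumes the module is served as app.py on port 7860):
#   curl http://localhost:7860/health
#   curl http://localhost:7860/load
#   curl -X POST http://localhost:7860/diagnose \
#        -F "image=@leaf.jpg" -F "culture=tomato" -F "notes=yellow spots"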