# ==========================================
# image_processing_gpu.py - ZeroGPU version with switchable OCR models
# ==========================================

"""
GPU-oriented image processing module for ZeroGPU HuggingFace Spaces.

Holds a single TrOCR processor/model pair in module-level globals, lets the
caller switch between the checkpoints registered in ``AVAILABLE_OCR_MODELS``,
and exposes OCR entry points that read a handwritten number from an image.
"""

import time
from typing import Any

import torch
import spaces
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from utils import (
    optimize_image_for_ocr,
    prepare_image_for_dataset,
    create_thumbnail_fast,
    create_white_canvas,
    log_memory_usage,
    cleanup_memory,
    validate_ocr_result
)

# ==========================================
# Registry of the available OCR models
# ==========================================

# Keys are the exact HuggingFace checkpoint names passed to from_pretrained().
AVAILABLE_OCR_MODELS = {
    "microsoft/trocr-base-handwritten": {
        "description": "Modèle de base Microsoft pour écriture manuscrite",
        "display_name": "TrOCR Base Handwritten (Microsoft)",
        "optimized_for": "general_handwriting"
    },
    "hoololi/trocr-base-handwritten-calctrainer": {
        "description": "Modèle fine tuné pour les nombres entiers",
        "display_name": "TrOCR CalcTrainer (Hoololi)",
        "optimized_for": "mathematical_numbers"
    }
}

# Default model
DEFAULT_OCR_MODEL = "hoololi/trocr-base-handwritten-calctrainer"

# Name of the model that is currently selected (not necessarily loaded yet).
current_ocr_model_name = DEFAULT_OCR_MODEL

# Global OCR state: the loaded processor/model pair and the checkpoint name
# they were loaded from (None while nothing is loaded).
processor = None
model = None
current_loaded_model = None


def get_available_models() -> dict:
    """Return the registry of available OCR models.

    Returns:
        The module-level ``AVAILABLE_OCR_MODELS`` dict itself (not a copy),
        keyed by checkpoint name.
    """
    return AVAILABLE_OCR_MODELS


def get_current_model_info() -> dict:
    """Return a description of the currently selected OCR model.

    If the selected name is not in the registry, the metadata of
    ``DEFAULT_OCR_MODEL`` is used for the descriptive fields.

    Returns:
        A dict with the selected model name, its display metadata, the name
        of the checkpoint actually loaded (or None), the device label, the
        GPU name, and an ``is_loaded`` flag. ``version`` duplicates
        ``model_name`` for compatibility with older callers.
    """
    model_config = AVAILABLE_OCR_MODELS.get(
        current_ocr_model_name, AVAILABLE_OCR_MODELS[DEFAULT_OCR_MODEL]
    )

    if torch.cuda.is_available():
        device = "ZeroGPU"
        gpu_name = torch.cuda.get_device_name()
    else:
        device = "CPU"
        gpu_name = "N/A"

    return {
        "model_name": current_ocr_model_name,
        "display_name": model_config["display_name"],
        "description": model_config["description"],
        "current_loaded": current_loaded_model,
        "device": device,
        "gpu_name": gpu_name,
        "framework": "HuggingFace-Transformers-ZeroGPU",
        "optimized_for": model_config["optimized_for"],
        "is_loaded": processor is not None and model is not None,
        # Compatibility with the old code
        "version": current_ocr_model_name
    }


def set_ocr_model(model_name: str) -> bool:
    """Switch the active OCR model.

    The previous model is released before the new one is loaded, so both are
    never held in memory at the same time. If loading fails, the new name
    stays selected and no model is loaded.

    Args:
        model_name: Exact model name (e.g. "microsoft/trocr-base-handwritten").

    Returns:
        True if the requested model is loaded when the call returns.
    """
    global current_ocr_model_name

    if model_name not in AVAILABLE_OCR_MODELS:
        print(f"❌ Modèle '{model_name}' non disponible. Modèles disponibles: {list(AVAILABLE_OCR_MODELS.keys())}")
        return False

    if model_name == current_ocr_model_name and processor is not None and model is not None:
        print(f"✅ Modèle '{model_name}' déjà chargé")
        return True

    model_config = AVAILABLE_OCR_MODELS[model_name]
    print(f"🔄 Changement vers le modèle: {model_config['display_name']}")

    current_ocr_model_name = model_name

    # Release the previous model first
    cleanup_current_model()

    # Load the new model
    return init_ocr_model()


def cleanup_current_model():
    """Drop the currently loaded processor/model to free memory."""
    global processor, model, current_loaded_model

    # Rebinding the globals to None releases this module's references.
    model = None
    processor = None
    current_loaded_model = None

    # Release cached GPU memory when CUDA is available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("🧹 Modèle précédent nettoyé")


def init_ocr_model(model_name: str | None = None) -> bool:
    """Load a TrOCR processor/model pair and make it the active one.

    The pair is loaded into local variables and the module globals are only
    updated once everything succeeded, so a failed load never leaves a
    processor and a model from two different checkpoints behind.

    Args:
        model_name: Exact name of the model to load (optional; defaults to
            ``current_ocr_model_name``).

    Returns:
        True if the model was loaded, False if the name is not registered or
        loading raised.
    """
    global processor, model, current_ocr_model_name, current_loaded_model

    target_name = current_ocr_model_name if model_name is None else model_name
    if target_name not in AVAILABLE_OCR_MODELS:
        print(f"❌ Modèle '{target_name}' non disponible")
        return False

    model_config = AVAILABLE_OCR_MODELS[target_name]

    try:
        print(f"🔄 Chargement {model_config['display_name']} (ZeroGPU)...")
        print(f" 📍 Modèle: {target_name}")

        new_processor = TrOCRProcessor.from_pretrained(target_name)
        new_model = VisionEncoderDecoderModel.from_pretrained(target_name)

        # Inference only
        new_model.eval()

        if torch.cuda.is_available():
            new_model = new_model.cuda()
            device_info = f"GPU ({torch.cuda.get_device_name()})"
            status = f"✅ {model_config['display_name']} prêt sur {device_info} !"
        else:
            device_info = "CPU (ZeroGPU pas encore alloué)"
            status = f"⚠️ {model_config['display_name']} sur CPU - {device_info}"
    except Exception as e:
        print(f"❌ Erreur lors du chargement {model_config['display_name']}: {e}")
        return False

    # Commit the new state in one go, only after every step succeeded.
    processor = new_processor
    model = new_model
    current_ocr_model_name = target_name
    current_loaded_model = target_name

    print(status)
    return True


# Alias kept for compatibility with the old code
def get_ocr_model_info() -> dict:
    """Alias for get_current_model_info() (backward compatibility)."""
    return get_current_model_info()


@spaces.GPU
def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, Any, dict | None]:
    """Run TrOCR on an image and return the recognized text plus image data.

    Args:
        image_dict: Input image payload, passed as-is to
            ``optimize_image_for_ocr``; None is treated as a missing image.
        debug: When True, print progress and timing information.

    Returns:
        A ``(text, optimized_image, dataset_image_data)`` tuple. On any
        failure (missing image, failed optimization, model not initialized,
        or an exception) the tuple is ``("0", None, None)``.
    """
    if image_dict is None:
        if debug:
            print(" ❌ Image manquante")
        return "0", None, None

    try:
        start_time = time.time()

        if debug:
            model_info = get_current_model_info()
            print(f" 🔄 Début OCR {model_info['display_name']} ZeroGPU...")

        # Optimize the image
        optimized_image = optimize_image_for_ocr(image_dict, max_size=384)
        if optimized_image is None:
            if debug:
                print(" ❌ Échec optimisation image")
            return "0", None, None

        # TrOCR must have been initialized beforehand
        if processor is None or model is None:
            if debug:
                print(" ❌ TrOCR non initialisé")
            return "0", None, None

        if debug:
            print(" 🤖 Lancement TrOCR ZeroGPU...")

        with torch.no_grad():
            # Preprocessing
            pixel_values = processor(images=optimized_image, return_tensors="pt").pixel_values

            # Put the inputs on whatever device the model lives on. Moving
            # them to CUDA just because CUDA is available breaks when the
            # model was loaded on CPU.
            run_device = model.device
            pixel_values = pixel_values.to(run_device)

            # Greedy generation
            generated_ids = model.generate(
                pixel_values,
                max_length=4,
                num_beams=1,
                do_sample=False,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id
            )

            # Decoding
            result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        final_result = validate_ocr_result(result, max_length=4)

        # Prepare the image for the dataset
        dataset_image_data = prepare_image_for_dataset(optimized_image)

        if debug:
            total_time = time.time() - start_time
            device = "ZeroGPU" if run_device.type == "cuda" else "CPU"
            model_name = get_current_model_info()['display_name']
            print(f" ✅ {model_name} ({device}) terminé en {total_time:.1f}s → '{final_result}'")
            if dataset_image_data:
                print(f" 🖼️ Image dataset: {type(dataset_image_data.get('handwriting_image', 'None'))}")

        return final_result, optimized_image, dataset_image_data

    except Exception as e:
        # Best-effort boundary: report the error and fall back to "0".
        print(f"❌ Erreur OCR TrOCR ZeroGPU: {e}")
        return "0", None, None


def recognize_number_fast(image_dict) -> tuple[str, Any]:
    """Return the recognized text and the optimized image (no dataset data)."""
    result, optimized_image, _ = recognize_number_fast_with_image(image_dict)
    return result, optimized_image


def recognize_number(image_dict) -> str:
    """Return only the recognized text for the given image."""
    result, _ = recognize_number_fast(image_dict)
    return result