import gradio as gr
from PIL import Image
import json
from io import BytesIO
import base64
import torch
from tempfile import gettempdir
from os import path, makedirs, remove
import models
import time


def get_safe_cache_dir():
    """Return a directory to use as the model cache directory.

    Tries ``~/.cache/huggingface`` first: creates it if needed and probes that
    it is writable by writing and deleting a small test file. If any of those
    filesystem operations fails with ``OSError``, falls back to
    ``<system temp dir>/huggingface``. The fallback path is only returned,
    not created here.
    """
    try:
        # Try writing to ~/.cache/huggingface (if available).
        default_cache = path.expanduser("~/.cache/huggingface")
        makedirs(default_cache, exist_ok=True)
        test_file = path.join(default_cache, "test_write.txt")
        with open(test_file, "w") as f:
            f.write("ok")
        remove(test_file)
        return default_cache
    except OSError:
        # Not writable (e.g. on HuggingFace Spaces): use the temp dir instead.
        return path.join(gettempdir(), "huggingface")


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CACHE_DIR = get_safe_cache_dir()

# Display name (shown in the dropdown) -> {"id": model repository id passed to
# the wrapper, "type": which wrapper class in `models` load_model instantiates}.
AVAILABLE_MODELS = {
    # "TrOCR (Base Printed)": {
    #     "id": "microsoft/trocr-base-printed",
    #     "type": "trocr"
    # },
    "EraX (VL-2B-V1.5)": {
        "id": "erax-ai/EraX-VL-2B-V1.5",
        "type": "erax"
    }
}

# Loaded model wrappers, keyed by the AVAILABLE_MODELS display name.
_model_cache = {}

print("Using device:", DEVICE)
print("Cache directory:", CACHE_DIR)


def load_model(model_key):
    """Return the model wrapper for ``model_key``, loading it on first use.

    Args:
        model_key: A key of ``AVAILABLE_MODELS`` (the dropdown display name).

    Returns:
        A ``models.TrOCRModel`` or ``models.EraXModel`` instance. Instances are
        memoized in ``_model_cache`` so each model is constructed only once
        per process.

    Raises:
        KeyError: if ``model_key`` is not a key of ``AVAILABLE_MODELS``.
        ValueError: if the entry's ``"type"`` is not a supported wrapper type.
    """
    print("Processing image with model:", model_key)
    model_info = AVAILABLE_MODELS[model_key]
    model_id = model_info["id"]
    model_type = model_info["type"]
    print("Model ID:", model_id, "Type:", model_type)

    # The cache is keyed by the display name (model_key), so look it up by
    # that key; a lookup by model_id would never hit and would rebuild the
    # model on every request.
    if model_key in _model_cache:
        return _model_cache[model_key]

    if model_type == "trocr":
        model = models.TrOCRModel(model_id, cache_dir=CACHE_DIR, device=DEVICE)
    elif model_type == "erax":
        model = models.EraXModel(model_id, cache_dir=CACHE_DIR, device=DEVICE)
    else:
        raise ValueError("Unknown model")

    _model_cache[model_key] = model
    print('Load model:', model_id, ' successfully!')
    return model


# Process an input image.
def gradio_process(image: Image.Image, model_key: str):
    """Run the selected model on ``image`` and report its output.

    Args:
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.
        model_key: The dropdown selection; expected to be a key of
            ``AVAILABLE_MODELS`` (may be ``None`` if the dropdown is cleared).

    Returns:
        On success, a JSON string (indent=4) with the model output under
        ``"texts"`` plus the image width/height and PIL mode. On invalid input,
        a dict ``{"error": ...}``. The ``gr.JSON`` output accepts either form.
    """
    if image is None:
        return {"error": "No image provided"}
    if model_key not in AVAILABLE_MODELS:
        # Report a missing/unknown selection instead of raising KeyError
        # inside load_model.
        return {"error": f"Unknown model: {model_key}"}

    print('Received image size:', image.size)
    start = time.perf_counter()
    model = load_model(model_key)
    result = model.predict(image)
    print('Model predicted successfully!')
    print('Result:', result)
    print('Time taken for prediction:', time.perf_counter() - start)
    return json.dumps({
        "texts": result,
        "image_size": {
            "width": image.width,
            "height": image.height
        },
        "mode": image.mode,
    }, indent=4)


# Gradio interface
demo = gr.Interface(
    fn=gradio_process,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        # Label means "Choose model". Default to the first configured model so
        # the initial value is always one of the available choices.
        gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            label="Chọn mô hình",
            value=next(iter(AVAILABLE_MODELS)),
        ),
        # gr.Textbox(label="Prompt (chỉ dùng cho EraX)", placeholder="Ảnh này có gì?")
    ],
    outputs=gr.JSON(label="Output (Text/JSON Extract)"),
    title="Image to Text/JSON Extractor",
    description="Upload an image and extract structured text using OCR."
)

if __name__ == "__main__":
    demo.launch()