# Gradio demo: run an OCR model on an uploaded image and return the extracted text as JSON.
import gradio as gr
from PIL import Image
import json
from io import BytesIO
import base64
import torch
from tempfile import gettempdir
from os import path, makedirs, remove
import models
import time

def get_safe_cache_dir():
    """Return a directory path to use as the Hugging Face cache.

    Prefers ``~/.cache/huggingface`` and checks that it is really writable by
    creating and then deleting a small probe file inside it. If the directory
    cannot be created or written to (e.g. on Hugging Face Spaces), falls back
    to a ``huggingface`` folder under the system temp directory.

    Returns:
        The chosen cache directory path as a string.
    """
    default_cache = path.expanduser("~/.cache/huggingface")
    probe_file = path.join(default_cache, "test_write.txt")
    try:
        makedirs(default_cache, exist_ok=True)
        with open(probe_file, "w") as f:
            f.write("ok")
        remove(probe_file)
    except OSError:
        # Only filesystem failures mean "not writable"; anything else is a
        # genuine bug and should surface instead of being silently swallowed.
        fallback = path.join(gettempdir(), "huggingface")
        try:
            # Create the fallback so callers receive a directory that exists.
            makedirs(fallback, exist_ok=True)
        except OSError:
            # Best effort: still return the path, as before; whoever writes
            # to it later will report the underlying error.
            pass
        return fallback
    return default_cache

# Use the GPU when CUDA is available; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
CACHE_DIR = get_safe_cache_dir()

# Registry of selectable models: display name -> model identifier ("id") and
# the wrapper family ("type") that load_model dispatches on.
AVAILABLE_MODELS = {
    # TrOCR entry, currently disabled:
    # "TrOCR (Base Printed)": {"id": "microsoft/trocr-base-printed", "type": "trocr"},
    "EraX (VL-2B-V1.5)": {"id": "erax-ai/EraX-VL-2B-V1.5", "type": "erax"},
}

# Model instances that have already been constructed, so they can be reused.
_model_cache = {}

print("Using device:", DEVICE)
print("Cache directory:", CACHE_DIR)

def load_model(model_key):
    """Return the model wrapper for ``model_key``, constructing it at most once.

    Args:
        model_key: Display name of the model; must be a key of
            ``AVAILABLE_MODELS``.

    Returns:
        The cached or newly constructed model wrapper
        (``models.TrOCRModel`` or ``models.EraXModel``).

    Raises:
        KeyError: If ``model_key`` is not registered in ``AVAILABLE_MODELS``.
        ValueError: If the registered model type is neither ``"trocr"`` nor
            ``"erax"``.
    """
    print("Processing image with model:", model_key)
    spec = AVAILABLE_MODELS[model_key]
    model_id = spec["id"]
    model_type = spec["type"]
    print("Model ID:", model_id, "Type:", model_type)

    # The cache is stored under model_key, so the lookup must use the same
    # key. Testing model_id here never matched, which rebuilt the model on
    # every call.
    if model_key in _model_cache:
        return _model_cache[model_key]

    if model_type == "trocr":
        model = models.TrOCRModel(model_id, cache_dir=CACHE_DIR, device=DEVICE)
    elif model_type == "erax":
        model = models.EraXModel(model_id, cache_dir=CACHE_DIR, device=DEVICE)
    else:
        raise ValueError("Unknown model")

    _model_cache[model_key] = model
    print('Load model:', model_id, ' successfully!')
    return model

# Handler for the uploaded input image
def gradio_process(image: Image.Image, model_key: str):
    """Run the selected model on ``image`` and describe the outcome.

    Args:
        image: The uploaded image, or ``None`` when nothing was provided.
        model_key: Display name of the model, passed on to ``load_model``.

    Returns:
        A dict ``{"error": "No image provided"}`` when ``image`` is ``None``;
        otherwise a JSON string (4-space indent) holding the prediction under
        ``"texts"``, the image width and height under ``"image_size"``, and
        the image mode under ``"mode"``.
    """
    if image is None:
        return {"error": "No image provided"}

    print('Received image size:', image.size)

    # The timer covers both model loading and prediction.
    started_at = time.time()
    texts = load_model(model_key).predict(image)
    print('Model predicted successfully!')
    print('Result:', texts)
    print('Time taken for prediction:', time.time() - started_at)

    payload = {
        "texts": texts,
        "image_size": {"width": image.width, "height": image.height},
        "mode": image.mode,
    }
    return json.dumps(payload, indent=4)

# Gradio interface: one image upload plus a model selector, wired to
# gradio_process, with the result shown in a JSON viewer.
demo = gr.Interface(
    fn=gradio_process,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            label="Chọn mô hình",
            # Default to the first registered model. The previous hard-coded
            # "TrOCR (Base Printed)" is not in AVAILABLE_MODELS, so the
            # default selection was not one of the offered choices.
            value=next(iter(AVAILABLE_MODELS), None),
        ),
        # A prompt textbox (EraX only) is currently disabled.
    ],
    outputs=gr.JSON(label="Output (Text/JSON Extract)"),
    title="Image to Text/JSON Extractor",
    description="Upload an image and extract structured text using OCR."
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()