# app.py import os import time from typing import Tuple import gradio as gr from PIL import Image import torch from model import OCRModel from preprocess import crop_by_region, to_tensor_one_tile # dùng hàm sẵn có của bạn MODEL_ID = "5CD-AI/Vintern-1B-v3_5" # CPU free-tier -> allow_flash_attn=False; GPU A10G có thể bật True ocr_model = OCRModel(model_id=MODEL_ID, allow_flash_attn=False) DEFAULT_PROMPT = "Chỉ trả về đúng nội dung văn bản nhìn thấy trong ảnh (không thêm giải thích)." REGIONS = ["full", "head", "body", "foot"] PRESETS = ["fast", "quality"] def ensure_model_loaded(): if not ocr_model.is_loaded: ocr_model.load() def run_ocr( image: Image.Image, region: str, preset: str, prompt: str, max_new_tokens: int ): if image is None: return "⚠️ Chưa chọn ảnh." ensure_model_loaded() # 1) Cắt vùng theo tham số (giống logic Flask cũ của bạn) pil = crop_by_region(image, region=region, head_ratio=0.28, foot_ratio=0.22) # 2) Đưa về tensor (1 tile / 448) px = to_tensor_one_tile(pil, input_size=448) # 3) Đồng bộ device & dtype với model (QUAN TRỌNG để tránh lỗi float/half) model_dtype = next(ocr_model.model.parameters()).dtype px = px.to(device=ocr_model.device, dtype=model_dtype) # 4) Tham số sinh text if preset == "fast": gen = dict(max_new_tokens=min(512, max_new_tokens), do_sample=False, num_beams=1, repetition_penalty=1.05) else: gen = dict(max_new_tokens=max_new_tokens, do_sample=False, num_beams=1, repetition_penalty=1.10) question = f"\n{(prompt or DEFAULT_PROMPT).strip()}\n" t0 = time.time() text = ocr_model.chat(px, question, **gen) dt = time.time() - t0 return f"{text}\n\n— elapsed: {dt:.2f}s | device: {ocr_model.device_str}" with gr.Blocks(title="OCR Demo (Gradio)") as demo: gr.Markdown( "# OCR Demo (Gradio)\n" "Upload ảnh giấy tờ → chọn **vùng** → bấm **Extract**.\n" f"Model: `{MODEL_ID}`" ) with gr.Row(): with gr.Column(scale=1): inp_img = gr.Image(type="pil", label="Ảnh", sources=["upload", "clipboard"]) region = gr.Radio(REGIONS, value="full", label="Vùng cắt") preset = gr.Radio(PRESETS, value="fast", label="Chế độ") with gr.Column(scale=1): prompt = gr.Textbox(value=DEFAULT_PROMPT, label="Prompt", lines=3) max_tokens = gr.Slider(16, 512, value=128, step=8, label="max_new_tokens") btn = gr.Button("Extract nội dung", variant="primary") out = gr.Textbox(label="Kết quả OCR", lines=18) btn.click(run_ocr, [inp_img, region, preset, prompt, max_tokens], [out]) if __name__ == "__main__": # Local: mở http://127.0.0.1:7860 # Trên Hugging Face: không cần chỉnh — Spaces sẽ tự bind PORT demo.launch()