Shadow0704 committed
Commit b85866b · verified · 1 Parent(s): b133a37

Upload 5 files

Files changed (5)
  1. app.py +87 -0
  2. model.py +113 -0
  3. preprocess.py +37 -0
  4. requirements.txt +0 -0
  5. vintern_fast.py +201 -0
app.py ADDED
@@ -0,0 +1,87 @@
+ # app.py
+ import os
+ import time
+ from typing import Tuple
+
+ import gradio as gr
+ from PIL import Image
+ import torch
+
+ from model import OCRModel
+ from preprocess import crop_by_region, to_tensor_one_tile  # reuses your existing helpers
+
+ MODEL_ID = "5CD-AI/Vintern-1B-v3_5"
+
+ # CPU free tier -> allow_flash_attn=False; on a GPU (e.g. A10G) this can be True
+ ocr_model = OCRModel(model_id=MODEL_ID, allow_flash_attn=False)
+
+ DEFAULT_PROMPT = "Chỉ trả về đúng nội dung văn bản nhìn thấy trong ảnh (không thêm giải thích)."
+ REGIONS = ["full", "head", "body", "foot"]
+ PRESETS = ["fast", "quality"]
+
+ def ensure_model_loaded():
+     if not ocr_model.is_loaded:
+         ocr_model.load()
+
+ def run_ocr(
+     image: Image.Image,
+     region: str,
+     preset: str,
+     prompt: str,
+     max_new_tokens: int
+ ):
+     if image is None:
+         return "⚠️ Chưa chọn ảnh."
+
+     ensure_model_loaded()
+
+     # 1) Crop the requested region (same logic as the old Flask app)
+     pil = crop_by_region(image, region=region, head_ratio=0.28, foot_ratio=0.22)
+
+     # 2) Convert to a tensor (one 448px tile)
+     px = to_tensor_one_tile(pil, input_size=448)
+
+     # 3) Match the model's device & dtype (IMPORTANT to avoid float/half mismatch errors)
+     model_dtype = next(ocr_model.model.parameters()).dtype
+     px = px.to(device=ocr_model.device, dtype=model_dtype)
+
+     # 4) Text generation parameters
+     if preset == "fast":
+         gen = dict(max_new_tokens=min(512, max_new_tokens),
+                    do_sample=False, num_beams=1, repetition_penalty=1.05)
+     else:
+         gen = dict(max_new_tokens=max_new_tokens,
+                    do_sample=False, num_beams=1, repetition_penalty=1.10)
+
+     question = f"<image>\n{(prompt or DEFAULT_PROMPT).strip()}\n"
+
+     t0 = time.time()
+     text = ocr_model.chat(px, question, **gen)
+     dt = time.time() - t0
+
+     return f"{text}\n\n— elapsed: {dt:.2f}s | device: {ocr_model.device_str}"
+
+ with gr.Blocks(title="OCR Demo (Gradio)") as demo:
+     gr.Markdown(
+         "# OCR Demo (Gradio)\n"
+         "Upload ảnh giấy tờ → chọn **vùng** → bấm **Extract**.\n"
+         f"Model: `{MODEL_ID}`"
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             inp_img = gr.Image(type="pil", label="Ảnh", sources=["upload", "clipboard"])
+             region = gr.Radio(REGIONS, value="full", label="Vùng cắt")
+             preset = gr.Radio(PRESETS, value="fast", label="Chế độ")
+         with gr.Column(scale=1):
+             prompt = gr.Textbox(value=DEFAULT_PROMPT, label="Prompt", lines=3)
+             max_tokens = gr.Slider(16, 512, value=128, step=8, label="max_new_tokens")
+             btn = gr.Button("Extract nội dung", variant="primary")
+             out = gr.Textbox(label="Kết quả OCR", lines=18)
+
+     btn.click(run_ocr, [inp_img, region, preset, prompt, max_tokens], [out])
+
+ if __name__ == "__main__":
+     # Local: open http://127.0.0.1:7860
+     # On Hugging Face Spaces no change is needed; the platform binds the port itself
+     demo.launch()
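For a quick check of the handler without launching the UI, run_ocr can be called directly. A minimal sketch, assuming a test image sample.png on disk (the filename is hypothetical; importing app builds the Blocks UI but does not launch it, and the model loads lazily on first call):

# smoke_test.py - hypothetical local check, not part of this commit
from PIL import Image
from app import run_ocr

img = Image.open("sample.png")  # assumed test image
# an empty prompt falls back to DEFAULT_PROMPT inside run_ocr
print(run_ocr(img, region="full", preset="fast", prompt="", max_new_tokens=128))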
model.py ADDED
@@ -0,0 +1,113 @@
+ # model.py
+ import torch
+ from transformers import AutoModel, AutoTokenizer, GenerationConfig
+
+ class OCRModel:
+     def __init__(
+         self,
+         model_id: str = "5CD-AI/Vintern-1B-v3_5",
+         allow_flash_attn: bool = False,
+         prefer_bfloat16: bool = False,
+     ):
+         self.model_id = model_id
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         if self.device.type == "cuda":
+             if prefer_bfloat16 and torch.cuda.is_bf16_supported():
+                 self.dtype = torch.bfloat16
+             else:
+                 self.dtype = torch.float16
+         else:
+             self.dtype = torch.float32
+
+         self.allow_flash_attn = bool(allow_flash_attn and self.device.type == "cuda")
+         self.model = None
+         self.tokenizer = None
+         self.is_loaded = False
+
+     @property
+     def on_cuda(self): return self.device.type == "cuda"
+     @property
+     def device_str(self): return f"{self.device} ({str(self.dtype)})"
+
+     def load(self):
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
+         # prefer the new API (dtype=), fall back to torch_dtype if needed
+         try:
+             self.model = AutoModel.from_pretrained(
+                 self.model_id, dtype=self.dtype, trust_remote_code=True
+             )
+         except TypeError:
+             self.model = AutoModel.from_pretrained(
+                 self.model_id, torch_dtype=self.dtype, trust_remote_code=True
+             )
+         self.model.to(device=self.device, dtype=self.dtype)
+         self.model.eval()
+
+         if not hasattr(self.model, "generation_config") or self.model.generation_config is None:
+             self.model.generation_config = GenerationConfig()
+         self.is_loaded = True
+
+     def _build_gen_dict(self, **gen_kwargs) -> dict:
+         """
+         Return the generation config as a plain DICT, as InternVLChatModel.chat() expects,
+         and DROP keys that could otherwise be passed twice inside .generate(...).
+         """
+         # start from the model's existing GenerationConfig
+         if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
+             try:
+                 base = self.model.generation_config.to_dict()
+             except Exception:
+                 base = {}
+         else:
+             base = {}
+
+         # merge parameters coming from the UI
+         for k, v in (gen_kwargs or {}).items():
+             base[k] = v
+
+         # fill in missing token ids
+         if "eos_token_id" not in base and hasattr(self.tokenizer, "eos_token_id"):
+             base["eos_token_id"] = self.tokenizer.eos_token_id
+         if "pad_token_id" not in base:
+             pad_id = getattr(self.tokenizer, "pad_token_id", None)
+             base["pad_token_id"] = pad_id if pad_id is not None else base.get("eos_token_id", None)
+         if "bos_token_id" not in base and hasattr(self.tokenizer, "bos_token_id"):
+             base["bos_token_id"] = self.tokenizer.bos_token_id
+
+         # coerce *_token_id values to int
+         for key in ("eos_token_id", "pad_token_id", "bos_token_id"):
+             if key in base and base[key] is not None:
+                 try:
+                     base[key] = int(base[key])
+                 except Exception:
+                     pass
+
+         # 🚫 drop keys prone to "got multiple values" errors
+         for bad in ("use_cache", "output_attentions", "output_hidden_states",
+                     "return_dict_in_generate", "synced_gpus"):
+             base.pop(bad, None)
+
+         return base
+
+     def chat(self, pixel_values: torch.Tensor, question: str, **gen_kwargs) -> str:
+         if not self.is_loaded:
+             self.load()
+
+         # match the input's dtype/device to the model
+         model_dtype = next(self.model.parameters()).dtype
+         pixel_values = pixel_values.to(device=self.device, dtype=model_dtype)
+
+         # clean dict for generation_config
+         gen_dict = self._build_gen_dict(**gen_kwargs)
+
+         # call chat: requires tokenizer + generation_config (as a dict)
+         out = self.model.chat(
+             pixel_values=pixel_values,
+             question=question,
+             tokenizer=self.tokenizer,
+             generation_config=gen_dict,
+         )
+
+         if isinstance(out, (list, tuple)) and len(out) >= 1:
+             return out[0]
+         return out
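OCRModel can also be driven outside Gradio. A minimal sketch, assuming a local test image (sample.png is a placeholder name); note that chat() lazy-loads the weights and moves the pixel tensor to the model's device/dtype itself:

from PIL import Image
from model import OCRModel
from preprocess import to_tensor_one_tile

m = OCRModel(allow_flash_attn=False)               # no load yet; chat() loads on first call
px = to_tensor_one_tile(Image.open("sample.png"))  # (1, 3, 448, 448), float32 on CPU
print(m.chat(px, "<image>\nChỉ trả về nội dung văn bản trong ảnh.", max_new_tokens=64))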
preprocess.py ADDED
@@ -0,0 +1,37 @@
+ from PIL import Image
+ import torchvision.transforms as T
+ from torchvision.transforms.functional import InterpolationMode
+ import torch
+
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+ DEFAULT_INPUT_SIZE = 448
+
+ def build_transform(input_size: int) -> T.Compose:
+     return T.Compose([
+         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+         T.Resize((input_size, input_size), interpolation=InterpolationMode.BILINEAR),
+         T.ToTensor(),
+         T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+     ])
+
+ def crop_regions(pil_img: Image.Image, head_ratio=0.28, foot_ratio=0.22):
+     w, h = pil_img.size
+     head_h = int(h * head_ratio)
+     foot_h = int(h * foot_ratio)
+     head = pil_img.crop((0, 0, w, head_h))
+     foot = pil_img.crop((0, h - foot_h, w, h))
+     body = pil_img.crop((0, head_h, w, h - foot_h))
+     return head, body, foot
+
+ def crop_by_region(pil_img: Image.Image, region: str, head_ratio=0.28, foot_ratio=0.22) -> Image.Image:
+     r = (region or "full").lower()
+     if r == "full": return pil_img
+     head, body, foot = crop_regions(pil_img, head_ratio=head_ratio, foot_ratio=foot_ratio)
+     return {"head": head, "body": body, "foot": foot}.get(r, pil_img)
+
+ def to_tensor_one_tile(pil_img: Image.Image, input_size=DEFAULT_INPUT_SIZE, pin_memory=False) -> torch.Tensor:
+     transform = build_transform(input_size=input_size)
+     t = transform(pil_img).unsqueeze(0)
+     if pin_memory: t = t.pin_memory()
+     return t
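As a sanity check on the crop ratios and tile shape, a minimal sketch with a synthetic image (the sizes in the comments follow from head_ratio=0.28 and foot_ratio=0.22):

from PIL import Image
from preprocess import crop_by_region, to_tensor_one_tile

img = Image.new("RGB", (800, 1200))  # synthetic stand-in for a document photo
head = crop_by_region(img, "head")   # top 28%: 800 x 336
body = crop_by_region(img, "body")   # middle 50%: 800 x 600
px = to_tensor_one_tile(body)        # resized tile: torch.Size([1, 3, 448, 448])
print(head.size, body.size, px.shape)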
requirements.txt ADDED
Binary file (2.4 kB)
vintern_fast.py ADDED
@@ -0,0 +1,201 @@
+ import numpy as np
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from torchvision.transforms.functional import InterpolationMode
+ from transformers import AutoModel, AutoTokenizer
+ import time
+ import argparse
+ import sys
+ """
+ url: https://huggingface.co/5CD-AI/Vintern-1B-v3_5
+ """
+ # Ensure UTF-8 console output (fixes UnicodeEncodeError on Windows PowerShell)
+ try:
+     sys.stdout.reconfigure(encoding='utf-8')
+     sys.stderr.reconfigure(encoding='utf-8')
+ except Exception:
+     pass
+ # pip install ninja packaging wheel
+ # pip install flash-attn --no-build-isolation
+ # Start the timer
+ start_time = time.time()
+
+ # Pick the device (GPU if available)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Runtime backend optimizations
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ print("Using device:", device)
+
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
+ def build_transform(input_size):
+     return T.Compose([
+         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+         T.Resize((input_size, input_size), interpolation=InterpolationMode.BILINEAR),
+         T.ToTensor(),
+         T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
+     ])
+
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float('inf')
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     return best_ratio
+
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1)
+         for i in range(1, n + 1)
+         for j in range(1, n + 1)
+         if i * j <= max_num and i * j >= min_num
+     )
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images
+
+ def load_image(image_file, input_size=448, max_num=12, use_thumbnail=False, pin_memory=False):
+     image = Image.open(image_file).convert('RGB')
+     transform = build_transform(input_size=input_size)
+     # Fast path when using only one tile and no thumbnail
+     if max_num == 1 and not use_thumbnail:
+         pixel_values = transform(image).unsqueeze(0)
+     else:
+         images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=use_thumbnail, max_num=max_num)
+         pixel_values = [transform(img) for img in images]
+         pixel_values = torch.stack(pixel_values)
+     if pin_memory:
+         pixel_values = pixel_values.pin_memory()
+     return pixel_values
+
+ # Load the model onto the GPU
+ model_load_start = time.time()
+ model = AutoModel.from_pretrained(
+     "5CD-AI/Vintern-1B-v3_5",
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True,
+     trust_remote_code=True,
+     use_flash_attn=True,  # requires flash-attn to be installed; set to False otherwise
+ ).to(device).eval()
+ model_load_end = time.time()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     "5CD-AI/Vintern-1B-v3_5",
+     trust_remote_code=True,
+     use_fast=False
+ )
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--image', type=str, default=r'C:\Users\ADMIN\Downloads\vintern_api\imgs\6.TKngknhnCMC_00001.png')
+     parser.add_argument('--input_size', type=int, default=384)
+     parser.add_argument('--max_num', type=int, default=1)
+     parser.add_argument('--use_thumbnail', action='store_true', default=False)
+     parser.add_argument('--max_new_tokens', type=int, default=128)
+     parser.add_argument('--num_beams', type=int, default=1)
+     parser.add_argument('--do_sample', action='store_true', default=False)
+     parser.add_argument('--repetition_penalty', type=float, default=2.5)
+     parser.add_argument('--question', type=str, default='<image>\nTrích xuất thông tin chính trong ảnh và trả về dạng markdown.')
+     parser.add_argument('--compile', action='store_true', default=False)
+     args = parser.parse_args()
+
+     pin_mem = device.type == 'cuda'
+
+     # Validate the input size for this model family (fall back to 448 if incompatible)
+     valid_input_size = args.input_size
+     try:
+         # Many InternVL/Vintern checkpoints expect 448 per tile
+         if args.input_size != 448:
+             print(f"[warn] input_size {args.input_size} may be incompatible; falling back to 448 for stability.")
+             valid_input_size = 448
+     except Exception:
+         valid_input_size = 448
+
+     # Image preprocessing and non-blocking GPU transfer
+     pixel_values = load_image(
+         args.image,
+         input_size=valid_input_size,
+         max_num=args.max_num,
+         use_thumbnail=args.use_thumbnail,
+         pin_memory=pin_mem
+     )
+     pixel_values = pixel_values.contiguous(memory_format=torch.channels_last)
+     pixel_values = pixel_values.to(device=device, dtype=torch.float16, non_blocking=True)
+
+     # Optional compile for speedup (PyTorch 2.x). Falls back silently if unsupported.
+     if args.compile:
+         try:
+             model_forward = model.forward
+             model.forward = torch.compile(model_forward, mode='reduce-overhead', fullgraph=False)  # type: ignore
+         except Exception:
+             pass
+
+     generation_config = dict(
+         max_new_tokens=args.max_new_tokens,
+         do_sample=args.do_sample,
+         num_beams=args.num_beams,
+         repetition_penalty=args.repetition_penalty
+     )
+
+     with torch.inference_mode():
+         response, history = model.chat(
+             tokenizer,
+             pixel_values,
+             args.question,
+             generation_config,
+             history=None,
+             return_history=True
+         )
+
+     print(f'User: {args.question}\nAssistant: {response}')
+
+     end_time = time.time()
+     print(f'Model load: {model_load_end - model_load_start:.2f}s | Total: {end_time - start_time:.2f}s')
+
+     del pixel_values
+     if device.type == 'cuda':
+         torch.cuda.empty_cache()
+
+
+ if __name__ == '__main__':
+     main()
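A typical invocation (the image path below is a placeholder); with the default --max_num 1 and no --use_thumbnail, load_image takes the single-tile fast path:

python vintern_fast.py --image path/to/document.png --input_size 448 --max_new_tokens 128

Note that the module-level load casts both model and pixel tensor to float16, so this script effectively assumes a CUDA device; on CPU, the Gradio path via model.py (which picks float32) is the safer entry point.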