import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from torchvision.transforms.functional import InterpolationMode

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    # Standard ImageNet-style preprocessing: force RGB, bicubic resize to a square,
    # convert to tensor, and normalize.
    return T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    # Choose the (columns, rows) tiling grid whose aspect ratio is closest to the
    # input image's; on ties, prefer the larger grid only if the image area is big
    # enough to fill it.
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    # Split the image into a grid of image_size x image_size tiles whose layout best
    # matches the original aspect ratio (InternVL-style dynamic tiling).
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Candidate (columns, rows) grids containing between min_num and max_num tiles.
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1)
        for j in range(1, n + 1) if min_num <= i * j <= max_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize to the target grid size, then crop out each tile left-to-right, top-to-bottom.
    resized_img = image.resize((target_width, target_height))
    processed_images = []

    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)

    # Optionally append a global thumbnail so the model also sees the whole image at once.
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images


def load_image(image_file, input_size=448, max_num=12):
    # Open the image, tile it dynamically, and stack the transformed tiles into a
    # single tensor of shape (num_tiles, 3, input_size, input_size).
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(im) for im in images]
    return torch.stack(pixel_values)


def load_model():
    # Load Vintern-1B-v3_5 in bfloat16 on the GPU. First try to disable FlashAttention
    # explicitly; fall back to the default attention implementation if the remote code
    # does not accept the use_flash_attn argument.
    model_name = "5CD-AI/Vintern-1B-v3_5"
    try:
        model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            use_flash_attn=False
        ).eval().cuda()
    except Exception:
        model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        ).eval().cuda()

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
    return model, tokenizer


def extract_info_from_image(image_path, model, tokenizer, max_num_blocks=6):
    pixel_values = load_image(image_path, max_num=max_num_blocks).to(torch.bfloat16).cuda()

    # Vietnamese prompt, roughly: "Extract the data in these columns: No., tax code,
    # taxpayer name, address, tax debt amount, enforcement measure. Try to read the
    # stamped-over digits or text clearly and return the result as markdown."
    question = "<image>\nTrích xuất dữ liệu các cột: STT, Mã số thuế, Tên người nộp thuế, Địa chỉ, Số tiền thuế nợ, Biện pháp cưỡng chế. Hãy cố gắng đọc rõ những con số hoặc chữ bị đóng dấu và trả về dạng markdown."

    # Deterministic beam search with a strong repetition penalty to keep long table
    # outputs from looping.
    generation_config = dict(
        max_new_tokens=2048,
        do_sample=False,
        num_beams=3,
        repetition_penalty=2.5
    )

    response = model.chat(tokenizer, pixel_values, question, generation_config)
    return response
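

# A minimal usage sketch, assuming this script is run directly on a machine with a
# CUDA GPU available; "tax_notice.png" is a hypothetical path standing in for a real
# scanned tax-debt notice.
if __name__ == "__main__":
    model, tokenizer = load_model()
    result = extract_info_from_image("tax_notice.png", model, tokenizer)
    print(result)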