from transformers import pipeline, AutoTokenizer, VisionEncoderDecoderModel, AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
from io import BytesIO
import base64
import json
import torch
from qwen_vl_utils import process_vision_info
import prompt
# Convert a PIL image to a base64 string (optional, e.g. for display or export)
def pil_to_base64(image: Image.Image, format="PNG") -> str:
    buffered = BytesIO()
    image.save(buffered, format=format)
    buffered.seek(0)
    return base64.b64encode(buffered.read()).decode("utf-8")
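
# Usage sketch for pil_to_base64 (illustrative round-trip check):
#   img = Image.new("RGB", (8, 8), "white")
#   b64 = pil_to_base64(img)
#   Image.open(BytesIO(base64.b64decode(b64))).size  # -> (8, 8)
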
def parse_to_json(result_text):
    """
    If the output consists of 'key: value' lines, parse them into a dict.
    Otherwise, wrap the raw text in a 'text' field.
    """
    data = {}
    lines = [line.strip() for line in result_text.splitlines() if line.strip()]
    for line in lines:
        if ":" in line:
            key, val = line.split(":", 1)
            data[key.strip()] = val.strip()
        else:
            # Lines without a separator are collected into a shared list
            data.setdefault("text", []).append(line)
    # If only the 'text' list is present, join it back into a single string
    if set(data.keys()) == {"text"}:
        data = {"text": "\n".join(data["text"])}
    return data
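
# Usage sketch for parse_to_json (illustrative values):
#   parse_to_json("Name: Nguyen Van A\nID: 012345678901")
#     -> {"Name": "Nguyen Van A", "ID": "012345678901"}
#   parse_to_json("free-form line one\nfree-form line two")
#     -> {"text": "free-form line one\nfree-form line two"}
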
# Earlier implementation using TrOCRProcessor + VisionEncoderDecoderModel directly,
# kept for reference (note: TrOCRProcessor would also need to be imported):
# class TrOCRModel:
#     def __init__(self, model_id="microsoft/trocr-base-printed", cache_dir=None, device=None):
#         self.model_id = model_id
#         self.cache_dir = cache_dir
#         self.device = device
#         self.processor = TrOCRProcessor.from_pretrained(self.model_id, cache_dir=self.cache_dir)
#         self.model = VisionEncoderDecoderModel.from_pretrained(self.model_id, cache_dir=self.cache_dir)
#         self.model.to(self.device)
#
#     def predict(self, image: Image.Image) -> str:
#         if image is None:
#             raise ValueError("No image provided")
#         image = image.convert("RGB")
#         pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
#         with torch.no_grad():
#             generated_ids = self.model.generate(pixel_values)
#         generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
#         return generated_text
class TrOCRModel:
    """OCR via the Hugging Face image-to-text pipeline wrapping TrOCR."""

    def __init__(self, model_id="microsoft/trocr-base-printed", cache_dir=None, device=None):
        # cache_dir is forwarded to from_pretrained via model_kwargs
        self.pipe = pipeline(
            "image-to-text",
            model=model_id,
            device=device,
            model_kwargs={"cache_dir": cache_dir},
        )

    def predict(self, image: Image.Image) -> str:
        if image is None:
            raise ValueError("No image provided")
        image = image.convert("RGB")
        result = self.pipe(image)
        return result[0]["generated_text"] if result else ""
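
# Usage sketch ("sample.png" is an illustrative path):
#   ocr = TrOCRModel(device=0 if torch.cuda.is_available() else -1)
#   print(ocr.predict(Image.open("sample.png")))
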
class EraXModel:
    def __init__(self, model_id="erax-ai/EraX-VL-2B-V1.5", cache_dir=None, device="auto"):
        # Pixel-count bounds for the Qwen2-VL image processor
        size = {
            "shortest_edge": 56 * 56,        # detailed enough; commonly used with ViT/TrOCR
            "longest_edge": 1280 * 28 * 28,  # caps very tall or very wide images
        }
        # with open(config_json_path, 'r', encoding='utf-8') as f:
        #     self.json_template = json.dumps(json.load(f), ensure_ascii=False)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            cache_dir=cache_dir,
            torch_dtype=torch.bfloat16,
            attn_implementation="eager",  # replace with "flash_attention_2" if your GPU is Ampere architecture
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            size=size,
            cache_dir=cache_dir,
        )

        # Generation configs
        self.generation_config = self.model.generation_config
        self.generation_config.do_sample = True
        self.generation_config.temperature = 1.0
        self.generation_config.top_k = 1
        self.generation_config.top_p = 0.9
        self.generation_config.min_p = 0.1
        self.generation_config.best_of = 5
        self.generation_config.max_new_tokens = 784
        self.generation_config.repetition_penalty = 1.06
    def predict(self, image: Image.Image) -> str:
        if image is None:
            raise ValueError("No image provided")

        # Alternative: read and encode an image file from disk.
        # image_path = "image.png"
        # with open(image_path, "rb") as f:
        #     encoded_image = base64.b64encode(f.read())
        # decoded_image_text = encoded_image.decode('utf-8')
        # base64_data = f"data:image;base64,{decoded_image_text}"
        decoded_image_text = pil_to_base64(image)
        base64_data = f"data:image;base64,{decoded_image_text}"

        # Prepare messages in the Qwen2-VL chat format
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": base64_data,
                    },
                    {
                        "type": "text",
                        "text": prompt.CCCD_BOTH_SIDE_PROMPT,
                    },
                ],
            }
        ]

        # Render the chat template into a prompt string
        tokenized_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print("Tokenized text")
        # process_vision_info decodes the base64 image from `messages`; its outputs
        # are unused below because the PIL image is passed to the processor directly.
        image_inputs, video_inputs = process_vision_info(messages)
        print("Processed vision info done")
        inputs = self.processor(
            text=[tokenized_text],
            # images=image_inputs,
            images=[image],
            # videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)
        print("Inputs prepared")

        # Inference: generate, then strip the prompt tokens from each output
        print("Generating text...")
        generated_ids = self.model.generate(**inputs, generation_config=self.generation_config)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]
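
# Minimal end-to-end sketch, guarded so importing this module has no side effects.
# "cccd_front.png" is an illustrative path; EraXModel downloads the ~2B-parameter
# checkpoint on first run.
if __name__ == "__main__":
    demo_image = Image.open("cccd_front.png").convert("RGB")
    model = EraXModel()
    raw_text = model.predict(demo_image)
    print(json.dumps(parse_to_json(raw_text), ensure_ascii=False, indent=2))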