from transformers import pipeline, AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
from io import BytesIO
import base64
import json
import torch
from qwen_vl_utils import process_vision_info
import prompt


# Convert a PIL image to a base64 string (optional, e.g. for display or export).
def pil_to_base64(image: Image.Image, format="PNG") -> str:
    buffered = BytesIO()
    image.save(buffered, format=format)
    buffered.seek(0)
    return base64.b64encode(buffered.read()).decode("utf-8")


def parse_to_json(result_text):
    """
    If the output consists of 'key: value' lines, parse them into a dict.
    Otherwise, wrap the raw text in a 'text' field.
    """
    data = {}
    lines = [line.strip() for line in result_text.splitlines() if line.strip()]
    for line in lines:
        if ":" in line:
            key, val = line.split(":", 1)
            data[key.strip()] = val.strip()
        else:
            # Lines that cannot be split are collected into a shared list.
            data.setdefault("text", []).append(line)
    # If only the 'text' list exists, join it back into a single string.
    if set(data.keys()) == {"text"}:
        data = {"text": "\n".join(data["text"])}
    return data


class TrOCRModel:
    def __init__(self, model_id="microsoft/trocr-base-printed", cache_dir=None, device=None):
        # The image-to-text pipeline wraps the TrOCR processor and model;
        # cache_dir is forwarded to the underlying model via model_kwargs.
        self.pipe = pipeline(
            "image-to-text",
            model=model_id,
            device=device,
            model_kwargs={"cache_dir": cache_dir},
        )

    def predict(self, image: Image.Image) -> str:
        if image is None:
            raise ValueError("No image provided")
        image = image.convert("RGB")
        result = self.pipe(image)
        return result[0]["generated_text"] if result else ""


class EraXModel:
    def __init__(self, model_id="erax-ai/EraX-VL-2B-V1.5", cache_dir=None, device="auto"):
        size = {
            "shortest_edge": 56 * 56,        # enough detail; commonly used in ViT/TrOCR
            "longest_edge": 1280 * 28 * 28,  # caps image size for very tall or very wide images
        }
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            cache_dir=cache_dir,
            torch_dtype=torch.bfloat16,
            attn_implementation="eager",  # replace with "flash_attention_2" on Ampere or newer GPUs
            device_map=device,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            size=size,
            cache_dir=cache_dir,
        )
        # Generation config. Note that top_k=1 makes sampling effectively greedy
        # even with do_sample=True.
        self.generation_config = self.model.generation_config
        self.generation_config.do_sample = True
        self.generation_config.temperature = 1.0
        self.generation_config.top_k = 1
        self.generation_config.top_p = 0.9
        self.generation_config.min_p = 0.1
        self.generation_config.max_new_tokens = 784
        self.generation_config.repetition_penalty = 1.06

    def predict(self, image: Image.Image) -> str:
        """Run the VLM with the CCCD prompt on a single image and return the raw generated text."""
        if image is None:
            raise ValueError("No image provided")
        # Encode the image as a base64 data URI for the chat message.
        decoded_image_text = pil_to_base64(image)
        base64_data = f"data:image;base64,{decoded_image_text}"
        # Prepare messages
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": base64_data,
                    },
                    {
                        "type": "text",
                        "text": prompt.CCCD_BOTH_SIDE_PROMPT,
                    },
                ],
            }
        ]
        # Render the chat template to text with vision placeholders.
        tokenized_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # process_vision_info decodes the base64 data URI back into PIL images.
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[tokenized_text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)
        # Inference: generate, then strip the prompt tokens from each sequence.
        generated_ids = self.model.generate(**inputs, generation_config=self.generation_config)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]
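

# Minimal usage sketch (assumptions: an input file at the hypothetical path
# "sample_cccd.png", and prompt.CCCD_BOTH_SIDE_PROMPT defined in prompt.py).
if __name__ == "__main__":
    img = Image.open("sample_cccd.png")  # hypothetical sample image
    ocr = EraXModel()
    raw = ocr.predict(img)
    # parse_to_json turns 'key: value' lines into a dict for downstream use.
    print(json.dumps(parse_to_json(raw), ensure_ascii=False, indent=2))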