# model.py
import torch
from transformers import AutoModel, AutoTokenizer, GenerationConfig


class OCRModel:
    def __init__(
        self,
        model_id: str = "5CD-AI/Vintern-1B-v3_5",
        allow_flash_attn: bool = False,
        prefer_bfloat16: bool = False,
    ):
        self.model_id = model_id
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Pick the widest dtype the hardware supports: bf16 > fp16 on CUDA, fp32 on CPU.
        if self.device.type == "cuda":
            if prefer_bfloat16 and torch.cuda.is_bf16_supported():
                self.dtype = torch.bfloat16
            else:
                self.dtype = torch.float16
        else:
            self.dtype = torch.float32
        # Flash attention is only meaningful on CUDA.
        self.allow_flash_attn = bool(allow_flash_attn and self.device.type == "cuda")
        self.model = None
        self.tokenizer = None
        self.is_loaded = False

    @property
    def on_cuda(self) -> bool:
        return self.device.type == "cuda"

    @property
    def device_str(self) -> str:
        return f"{self.device} ({self.dtype})"

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        # Prefer the newer `dtype=` API; fall back to `torch_dtype=` on older transformers.
        # `use_flash_attn` is forwarded to the InternVL-style remote code (see the model card).
        try:
            self.model = AutoModel.from_pretrained(
                self.model_id,
                dtype=self.dtype,
                trust_remote_code=True,
                use_flash_attn=self.allow_flash_attn,
            )
        except TypeError:
            self.model = AutoModel.from_pretrained(
                self.model_id,
                torch_dtype=self.dtype,
                trust_remote_code=True,
                use_flash_attn=self.allow_flash_attn,
            )
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()
        if getattr(self.model, "generation_config", None) is None:
            self.model.generation_config = GenerationConfig()
        self.is_loaded = True

    def _build_gen_dict(self, **gen_kwargs) -> dict:
        """
        Return the generation config as a plain dict, as expected by
        InternVLChatModel.chat(), and strip keys that chat() may pass to
        .generate(...) itself (they would otherwise collide as duplicates).
        """
        # Start from the model's existing GenerationConfig, if any.
        if getattr(self.model, "generation_config", None) is not None:
            try:
                base = self.model.generation_config.to_dict()
            except Exception:
                base = {}
        else:
            base = {}
        # Merge the parameters coming from the UI.
        base.update(gen_kwargs)
        # Fill in token ids if missing.
        if "eos_token_id" not in base and hasattr(self.tokenizer, "eos_token_id"):
            base["eos_token_id"] = self.tokenizer.eos_token_id
        if "pad_token_id" not in base:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            base["pad_token_id"] = pad_id if pad_id is not None else base.get("eos_token_id")
        if "bos_token_id" not in base and hasattr(self.tokenizer, "bos_token_id"):
            base["bos_token_id"] = self.tokenizer.bos_token_id
        # Coerce *_token_id values to int.
        for key in ("eos_token_id", "pad_token_id", "bos_token_id"):
            if base.get(key) is not None:
                try:
                    base[key] = int(base[key])
                except (TypeError, ValueError):
                    pass
        # Drop keys prone to "got multiple values for keyword argument" errors.
        for bad in (
            "use_cache",
            "output_attentions",
            "output_hidden_states",
            "return_dict_in_generate",
            "synced_gpus",
        ):
            base.pop(bad, None)
        return base

    def chat(self, pixel_values: torch.Tensor, question: str, **gen_kwargs) -> str:
        if not self.is_loaded:
            self.load()
        # Match the input's device/dtype to the model's actual parameters.
        model_dtype = next(self.model.parameters()).dtype
        pixel_values = pixel_values.to(device=self.device, dtype=model_dtype)
        # Build a clean dict for generation_config.
        gen_dict = self._build_gen_dict(**gen_kwargs)
        # chat() expects the tokenizer plus the generation config as a dict.
        out = self.model.chat(
            pixel_values=pixel_values,
            question=question,
            tokenizer=self.tokenizer,
            generation_config=gen_dict,
        )
        # Some chat() variants return (response, history); keep only the response.
        if isinstance(out, (list, tuple)) and len(out) >= 1:
            return out[0]
        return out
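

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumption: real usage would build
    # `pixel_values` with the InternVL-style image preprocessing from the
    # Vintern model card (448x448 tiles, ImageNet normalization); the random
    # tensor below only exercises the plumbing, not actual OCR quality.
    ocr = OCRModel(prefer_bfloat16=True)
    ocr.load()
    print(ocr.device_str)
    pixel_values = torch.randn(1, 3, 448, 448)  # one stand-in image tile
    answer = ocr.chat(
        pixel_values,
        "<image>\nExtract all text from the image.",
        max_new_tokens=64,
        do_sample=False,
    )
    print(answer)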