```python
# model.py
import torch
from transformers import AutoModel, AutoTokenizer, GenerationConfig


class OCRModel:
    def __init__(
        self,
        model_id: str = "5CD-AI/Vintern-1B-v3_5",
        allow_flash_attn: bool = False,
        prefer_bfloat16: bool = False,
    ):
        self.model_id = model_id
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Pick the widest dtype the hardware supports: bf16 > fp16 on CUDA, fp32 on CPU.
        if self.device.type == "cuda":
            if prefer_bfloat16 and torch.cuda.is_bf16_supported():
                self.dtype = torch.bfloat16
            else:
                self.dtype = torch.float16
        else:
            self.dtype = torch.float32
        self.allow_flash_attn = bool(allow_flash_attn and self.device.type == "cuda")
        self.model = None
        self.tokenizer = None
        self.is_loaded = False

    def on_cuda(self) -> bool:
        return self.device.type == "cuda"

    def device_str(self) -> str:
        return f"{self.device} ({self.dtype})"

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        # Prefer the new API (dtype=); fall back to torch_dtype= on older transformers versions.
        try:
            self.model = AutoModel.from_pretrained(
                self.model_id, dtype=self.dtype, trust_remote_code=True
            )
        except TypeError:
            self.model = AutoModel.from_pretrained(
                self.model_id, torch_dtype=self.dtype, trust_remote_code=True
            )
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()
        if getattr(self.model, "generation_config", None) is None:
            self.model.generation_config = GenerationConfig()
        self.is_loaded = True

    def _build_gen_dict(self, **gen_kwargs) -> dict:
        """
        Return the generation config as a plain DICT, as InternVLChatModel.chat()
        expects, and DROP keys that chat() may also pass to .generate(...) itself
        (they would otherwise raise "got multiple values for argument" errors).
        """
        # Start from the model's existing GenerationConfig, if any.
        if getattr(self.model, "generation_config", None) is not None:
            try:
                base = self.model.generation_config.to_dict()
            except Exception:
                base = {}
        else:
            base = {}
        # Merge parameters coming from the UI.
        base.update(gen_kwargs or {})
        # Fill in token ids if they are missing.
        if "eos_token_id" not in base and hasattr(self.tokenizer, "eos_token_id"):
            base["eos_token_id"] = self.tokenizer.eos_token_id
        if "pad_token_id" not in base:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            base["pad_token_id"] = pad_id if pad_id is not None else base.get("eos_token_id")
        if "bos_token_id" not in base and hasattr(self.tokenizer, "bos_token_id"):
            base["bos_token_id"] = self.tokenizer.bos_token_id
        # Coerce *_token_id values to int.
        for key in ("eos_token_id", "pad_token_id", "bos_token_id"):
            if base.get(key) is not None:
                try:
                    base[key] = int(base[key])
                except (TypeError, ValueError):
                    pass
        # Drop keys prone to "multiple values" collisions inside .generate(...).
        for bad in ("use_cache", "output_attentions", "output_hidden_states",
                    "return_dict_in_generate", "synced_gpus"):
            base.pop(bad, None)
        return base

    def chat(self, pixel_values: torch.Tensor, question: str, **gen_kwargs) -> str:
        if not self.is_loaded:
            self.load()
        # Match the input's dtype/device to the model's.
        model_dtype = next(self.model.parameters()).dtype
        pixel_values = pixel_values.to(device=self.device, dtype=model_dtype)
        # Build a clean dict for generation_config.
        gen_dict = self._build_gen_dict(**gen_kwargs)
        # chat() requires the tokenizer plus the generation_config dict.
        out = self.model.chat(
            pixel_values=pixel_values,
            question=question,
            tokenizer=self.tokenizer,
            generation_config=gen_dict,
        )
        # chat() may return a (response, history) pair; keep only the response.
        if isinstance(out, (list, tuple)) and len(out) >= 1:
            return out[0]
        return out
```
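A minimal usage sketch, assuming the standard InternVL-style single-tile preprocessing (a 448×448 resize with ImageNet normalization) that Vintern checkpoints share; `receipt.png` and the prompt are placeholders, and the model card's dynamic-tiling preprocessing would be a better fit for high-resolution documents:

```python
# demo.py: usage sketch; "receipt.png" and the prompt text are placeholders.
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

from model import OCRModel

# Single 448x448 tile with ImageNet normalization (InternVL-style input).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
transform = T.Compose([
    T.Lambda(lambda img: img.convert("RGB")),
    T.Resize((448, 448), interpolation=InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

model = OCRModel(prefer_bfloat16=True)
pixel_values = transform(Image.open("receipt.png")).unsqueeze(0)  # (1, 3, 448, 448)
answer = model.chat(
    pixel_values,
    "<image>\nExtract all the text in this image.",
    max_new_tokens=512,
    do_sample=False,
)
print(answer)
```

Any keyword arguments passed to `chat()` (here `max_new_tokens` and `do_sample`) flow through `_build_gen_dict()` and reach the model as part of the sanitized `generation_config` dict.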