# model.py
import torch
from transformers import AutoModel, AutoTokenizer, GenerationConfig


class OCRModel:
    def __init__(
        self,
        model_id: str = "5CD-AI/Vintern-1B-v3_5",
        allow_flash_attn: bool = False,
        prefer_bfloat16: bool = False,
    ):
        self.model_id = model_id
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # pick the widest dtype the hardware supports: bf16 > fp16 on CUDA, fp32 on CPU
        if self.device.type == "cuda":
            if prefer_bfloat16 and torch.cuda.is_bf16_supported():
                self.dtype = torch.bfloat16
            else:
                self.dtype = torch.float16
        else:
            self.dtype = torch.float32
        self.allow_flash_attn = bool(allow_flash_attn and self.device.type == "cuda")
        self.model = None
        self.tokenizer = None
        self.is_loaded = False

    @property
    def on_cuda(self):
        return self.device.type == "cuda"

    @property
    def device_str(self):
        return f"{self.device} ({self.dtype})"

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        # prefer the newer from_pretrained API (dtype=); fall back to the
        # older torch_dtype= keyword if this transformers version lacks it
        try:
            self.model = AutoModel.from_pretrained(
                self.model_id, dtype=self.dtype, trust_remote_code=True
            )
        except TypeError:
            self.model = AutoModel.from_pretrained(
                self.model_id, torch_dtype=self.dtype, trust_remote_code=True
            )
        self.model.to(device=self.device, dtype=self.dtype)
        self.model.eval()
        # ensure a GenerationConfig exists so _build_gen_dict has a base to copy
        if not hasattr(self.model, "generation_config") or self.model.generation_config is None:
            self.model.generation_config = GenerationConfig()
        self.is_loaded = True

    def _build_gen_dict(self, **gen_kwargs) -> dict:
        """
        Return the generation_config as a DICT, which is what
        InternVLChatModel.chat() expects, and DROP keys that may also be
        passed to .generate(...) directly, to avoid duplicate-keyword errors.
        """
        # start from the model's existing GenerationConfig, if any
        if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
            try:
                base = self.model.generation_config.to_dict()
            except Exception:
                base = {}
        else:
            base = {}
        # merge in the parameters supplied by the UI
        for k, v in (gen_kwargs or {}).items():
            base[k] = v
        # fill in special-token ids from the tokenizer if they are missing
        if "eos_token_id" not in base and hasattr(self.tokenizer, "eos_token_id"):
            base["eos_token_id"] = self.tokenizer.eos_token_id
        if "pad_token_id" not in base:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            base["pad_token_id"] = pad_id if pad_id is not None else base.get("eos_token_id", None)
        if "bos_token_id" not in base and hasattr(self.tokenizer, "bos_token_id"):
            base["bos_token_id"] = self.tokenizer.bos_token_id
        # coerce *_token_id values to int
        for key in ("eos_token_id", "pad_token_id", "bos_token_id"):
            if key in base and base[key] is not None:
                try:
                    base[key] = int(base[key])
                except Exception:
                    pass
        # drop keys prone to "got multiple values for keyword argument" errors
        for bad in ("use_cache", "output_attentions", "output_hidden_states",
                    "return_dict_in_generate", "synced_gpus"):
            base.pop(bad, None)
        return base
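
    # Illustrative note (not from the original file): with defaults, a call like
    # _build_gen_dict(max_new_tokens=512, do_sample=False) would typically yield
    # something like
    #   {"max_new_tokens": 512, "do_sample": False,
    #    "eos_token_id": <int>, "pad_token_id": <int>, ...}
    # with use_cache / output_* / synced_gpus stripped, so InternVL's chat()
    # can forward it into .generate(**gen_dict) without duplicate-keyword errors.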

    def chat(self, pixel_values: torch.Tensor, question: str, **gen_kwargs) -> str:
        if not self.is_loaded:
            self.load()
        # move the input onto the model's device and match its dtype
        model_dtype = next(self.model.parameters()).dtype
        pixel_values = pixel_values.to(device=self.device, dtype=model_dtype)
        # build a clean DICT for generation_config
        gen_dict = self._build_gen_dict(**gen_kwargs)
        # InternVL's chat() requires the tokenizer plus generation_config as a DICT
        out = self.model.chat(
            pixel_values=pixel_values,
            question=question,
            tokenizer=self.tokenizer,
            generation_config=gen_dict,
        )
        # some revisions return (response, history); keep only the response text
        if isinstance(out, (list, tuple)) and len(out) >= 1:
            return out[0]
        return out
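

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how this class might be driven. It assumes a
# `load_image()` helper that preprocesses an image into the tiled 448x448
# pixel_values tensor InternVL-style models expect; that helper is
# hypothetical and not defined in this file.
if __name__ == "__main__":
    ocr = OCRModel()  # defaults to 5CD-AI/Vintern-1B-v3_5
    ocr.load()
    print(f"Loaded {ocr.model_id} on {ocr.device_str}")
    # pixel_values = load_image("receipt.png")  # hypothetical preprocessing helper
    # answer = ocr.chat(pixel_values, "<image>\nExtract all text in the image.",
    #                   max_new_tokens=512)
    # print(answer)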