import spaces
from io import BytesIO
import torch
from PIL import Image
import base64
import json
import re
import logging
from transformers import AutoModel, AutoTokenizer, AutoProcessor, set_seed
# set_seed(42)
logger = logging.getLogger(__name__)


class ModelMiniCPMV4_5:
    def __init__(self, path) -> None:
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, attn_implementation='sdpa',
            torch_dtype=torch.bfloat16, device_map="auto")
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(
            path, trust_remote_code=True)
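
    # Expected `input_data` fields (inferred from the handler below; the names
    # are descriptive, not authoritative):
    #   "question":     JSON-encoded list of chat messages; each message holds a
    #                   'content' (or legacy 'contents') list of {"type", "pairs"} items
    #   "image":        optional base64-encoded image
    #   "params":       optional JSON-encoded dict of generation parameters,
    #                   e.g. max_new_tokens, enable_thinking, stream
    #   "temporal_ids": optional JSON-encoded temporal ids used for video inputs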
    def __call__(self, input_data):
        image = None
        if "image" in input_data and len(input_data["image"]) > 10:
            image = Image.open(BytesIO(base64.b64decode(
                input_data["image"]))).convert('RGB')

        msgs = input_data["question"]
        params = input_data.get("params", "{}")
        params = json.loads(params)
        msgs = json.loads(msgs)

        temporal_ids = input_data.get("temporal_ids", None)
        if temporal_ids:
            temporal_ids = json.loads(temporal_ids)

        if params.get("max_new_tokens", 0) > 16384:
            logger.info("capping max_new_tokens at 16384 to save memory")
            params["max_new_tokens"] = 16384
        if params.get("max_inp_length", 0) > 2048 * 10:
            logger.info(f"capping max_inp_length at {2048 * 10} (kept high for video processing)")
            params["max_inp_length"] = 2048 * 10
        # Normalize message contents: text items become plain strings, image
        # items are decoded from base64 into PIL images for model.chat().
        for msg in msgs:
            if 'content' in msg:
                contents = msg['content']
            else:
                contents = msg.pop('contents')
            new_cnts = []
            for c in contents:
                if isinstance(c, dict):
                    if c['type'] == 'text':
                        c = c['pairs']
                    elif c['type'] == 'image':
                        c = Image.open(
                            BytesIO(base64.b64decode(c["pairs"]))).convert('RGB')
                    else:
                        raise ValueError(
                            "content type only supports 'text' and 'image'.")
                new_cnts.append(c)
            msg['content'] = new_cnts
        logger.info(f'msgs: {str(msgs)}')
        enable_thinking = params.pop('enable_thinking', True)
        is_streaming = params.pop('stream', False)

        if is_streaming:
            return self._stream_chat(image, msgs, enable_thinking, params, temporal_ids)
        else:
            chat_kwargs = {
                "image": image,
                "msgs": msgs,
                "tokenizer": self.tokenizer,
                "processor": self.processor,
                "enable_thinking": enable_thinking,
                **params
            }
            if temporal_ids is not None:
                chat_kwargs["temporal_ids"] = temporal_ids

            answer = self.model.chat(**chat_kwargs)
            # Strip grounding markup from the answer. The angle-bracket tag
            # names were lost in the source text; <box>/</box> are assumed.
            res = re.sub(r'(<box>.*</box>)', '', answer)
            res = res.replace('[', '')
            res = res.replace(']', '')
            res = res.replace('<box>', '')
            answer = res.replace('</box>', '')
            if not enable_thinking:
                logger.info(f"enable_thinking: {enable_thinking}")
                # Drop the empty reasoning block; the exact literal was lost in
                # the source, so this marker is an assumption.
                answer = answer.replace('<think>\n\n</think>\n\n', '')

            oids = self.tokenizer.encode(answer)
            output_tokens = len(oids)
            return answer, output_tokens
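
    # Streaming variant: yields cleaned text chunks as model.chat(stream=True)
    # produces them; if the model returns a plain string instead of a
    # generator, the cleaned answer is yielded character by character.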
    def _stream_chat(self, image, msgs, enable_thinking, params, temporal_ids=None):
        try:
            params['stream'] = True
            chat_kwargs = {
                "image": image,
                "msgs": msgs,
                "tokenizer": self.tokenizer,
                "processor": self.processor,
                "enable_thinking": enable_thinking,
                **params
            }
            if temporal_ids is not None:
                chat_kwargs["temporal_ids"] = temporal_ids

            answer_generator = self.model.chat(**chat_kwargs)

            if isinstance(answer_generator, str):
                # Full-string fallback (strings are iterable, so the original
                # hasattr(..., '__iter__') check could never take this branch).
                answer = answer_generator
                # Same markup cleanup as the non-streaming path; the
                # <box>/</box> tag names are assumed.
                res = re.sub(r'(<box>.*</box>)', '', answer)
                res = res.replace('[', '')
                res = res.replace(']', '')
                res = res.replace('<box>', '')
                answer = res.replace('</box>', '')
                if not enable_thinking:
                    # Assumed marker for the empty reasoning block.
                    answer = answer.replace('<think>\n\n</think>\n\n', '')

                char_count = 0
                for char in answer:
                    yield char
                    char_count += 1
            else:
                full_answer = ""
                chunk_count = 0
                char_count = 0
                for chunk in answer_generator:
                    if isinstance(chunk, str):
                        # Per-chunk markup cleanup; tag names assumed as above.
                        clean_chunk = re.sub(r'(<box>.*</box>)', '', chunk)
                        clean_chunk = clean_chunk.replace('[', '')
                        clean_chunk = clean_chunk.replace(']', '')
                        clean_chunk = clean_chunk.replace('<box>', '')
                        clean_chunk = clean_chunk.replace('</box>', '')
                        if not enable_thinking:
                            clean_chunk = clean_chunk.replace('<think>\n\n</think>\n\n', '')
                        full_answer += chunk
                        char_count += len(clean_chunk)
                        chunk_count += 1
                        yield clean_chunk
                    else:
                        full_answer += str(chunk)
                        char_count += len(str(chunk))
                        chunk_count += 1
                        yield str(chunk)
        except Exception as e:
            logger.error(f"Stream chat error: {e}")
            yield f"Error: {str(e)}"