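"""MiniCPM-V 4.5 inference wrapper for this Hugging Face Space.

Exposes a single callable that decodes a JSON/base64 request payload and
runs model.chat(), with optional streaming output.
"""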
import spaces  # required for ZeroGPU Spaces, even if not referenced directly in this file
from io import BytesIO
import torch
from PIL import Image
import base64
import json
import re
import logging
from transformers import AutoModel, AutoTokenizer, AutoProcessor, set_seed
# set_seed(42)
logger = logging.getLogger(__name__)
class ModelMiniCPMV4_5:
    def __init__(self, path) -> None:
        # Load the checkpoint in bfloat16 with SDPA attention;
        # device_map="auto" places the weights on the available GPU.
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, attn_implementation='sdpa',
            torch_dtype=torch.bfloat16, device_map="auto")
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(
            path, trust_remote_code=True)
    def __call__(self, input_data):
        """Run one chat request.

        input_data keys: "question" (JSON-encoded message list), optional
        base64 "image", optional JSON-encoded "params" and "temporal_ids".
        """
        # Decode the optional top-level image (a bare base64 string).
        image = None
        if "image" in input_data and len(input_data["image"]) > 10:
            image = Image.open(BytesIO(base64.b64decode(
                input_data["image"]))).convert('RGB')

        msgs = json.loads(input_data["question"])
        params = json.loads(input_data.get("params", "{}"))

        # temporal_ids group video frames for the model's temporal encoding;
        # the client sends them JSON-encoded.
        temporal_ids = input_data.get("temporal_ids", None)
        if temporal_ids:
            temporal_ids = json.loads(temporal_ids)

        # Cap generation length to save memory, but keep a high input-length
        # limit so video inputs still fit.
        if params.get("max_new_tokens", 0) > 16384:
            logger.info("capping max_new_tokens at 16384 to save memory")
            params["max_new_tokens"] = 16384
        if params.get("max_inp_length", 0) > 2048 * 10:
            logger.info(f"capping max_inp_length at {2048 * 10}, a high limit kept for video processing")
            params["max_inp_length"] = 2048 * 10
        # Normalise message contents: each entry is either a plain string or
        # a dict with "type" ("text" | "image") and a "pairs" payload.
        for msg in msgs:
            if 'content' in msg:
                contents = msg['content']
            else:
                contents = msg.pop('contents')
            new_cnts = []
            for c in contents:
                if isinstance(c, dict):
                    if c['type'] == 'text':
                        c = c['pairs']
                    elif c['type'] == 'image':
                        c = Image.open(
                            BytesIO(base64.b64decode(c["pairs"]))).convert('RGB')
                    else:
                        raise ValueError(
                            "content type only supports text and image.")
                new_cnts.append(c)
            msg['content'] = new_cnts
        logger.info(f'msgs: {str(msgs)}')

        enable_thinking = params.pop('enable_thinking', True)
        is_streaming = params.pop('stream', False)
        if is_streaming:
            return self._stream_chat(image, msgs, enable_thinking, params, temporal_ids)

        chat_kwargs = {
            "image": image,
            "msgs": msgs,
            "tokenizer": self.tokenizer,
            "processor": self.processor,
            "enable_thinking": enable_thinking,
            **params
        }
        if temporal_ids is not None:
            chat_kwargs["temporal_ids"] = temporal_ids
        answer = self.model.chat(**chat_kwargs)

        # Strip grounding markup (<box>/<ref> spans) from the answer.
        answer = re.sub(r'<box>.*</box>', '', answer)
        for tag in ('<ref>', '</ref>', '<box>', '</box>'):
            answer = answer.replace(tag, '')
        if not enable_thinking:
            logger.info(f"enable_thinking: {enable_thinking}")
            answer = answer.replace('</think>', '')

        # Return the output token count alongside the text.
        output_tokens = len(self.tokenizer.encode(answer))
        return answer, output_tokens
    def _stream_chat(self, image, msgs, enable_thinking, params, temporal_ids=None):
        try:
            params['stream'] = True
            chat_kwargs = {
                "image": image,
                "msgs": msgs,
                "tokenizer": self.tokenizer,
                "processor": self.processor,
                "enable_thinking": enable_thinking,
                **params
            }
            if temporal_ids is not None:
                chat_kwargs["temporal_ids"] = temporal_ids
            answer_generator = self.model.chat(**chat_kwargs)

            if isinstance(answer_generator, str):
                # The model returned a complete string despite stream=True.
                # (Strings are themselves iterable, so a hasattr(...,
                # '__iter__') check would misclassify them.) Clean it once,
                # then emit it character by character.
                answer = re.sub(r'<box>.*</box>', '', answer_generator)
                for tag in ('<ref>', '</ref>', '<box>', '</box>'):
                    answer = answer.replace(tag, '')
                if not enable_thinking:
                    answer = answer.replace('</think>', '')
                for char in answer:
                    yield char
            else:
                # Genuine generator: clean each chunk as it arrives. Note
                # this only strips tags fully contained within a chunk; a
                # tag split across chunk boundaries passes through.
                for chunk in answer_generator:
                    chunk = chunk if isinstance(chunk, str) else str(chunk)
                    clean_chunk = re.sub(r'<box>.*</box>', '', chunk)
                    for tag in ('<ref>', '</ref>', '<box>', '</box>'):
                        clean_chunk = clean_chunk.replace(tag, '')
                    if not enable_thinking:
                        clean_chunk = clean_chunk.replace('</think>', '')
                    yield clean_chunk
        except Exception as e:
            logger.error(f"Stream chat error: {e}")
            yield f"Error: {str(e)}"
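

# --- Usage sketch (illustrative only, not part of the Space's serving
# path). The checkpoint id and image path below are assumptions for this
# example, not taken from the file above.
if __name__ == "__main__":
    model = ModelMiniCPMV4_5("openbmb/MiniCPM-V-4_5")  # assumed checkpoint id

    with open("example.jpg", "rb") as f:  # hypothetical local image
        img_b64 = base64.b64encode(f.read()).decode()

    # Messages follow the schema __call__ expects: content entries are
    # dicts with "type" ("text" | "image") and a "pairs" payload.
    msgs = [{"role": "user",
             "content": [{"type": "image", "pairs": img_b64},
                         {"type": "text", "pairs": "Describe this image."}]}]

    answer, output_tokens = model({
        "question": json.dumps(msgs),
        "params": json.dumps({"max_new_tokens": 512}),
    })
    print(answer, output_tokens)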