import os
import re
from typing import Union

import torch
import torch.nn as nn
from diffusers.models import AutoencoderKL
from diffusers.utils import logging
from huggingface_hub import snapshot_download
from qwen_vl_utils import process_vision_info
from safetensors.torch import load_file
from transformers import AutoProcessor

from .image_decoder import OmniGen, Phi3DecoderLayer, ImageDecoderPipeline, OmniGenProcessor
from .mllm import MindOmniMLLM

logger = logging.get_logger(__name__)
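
# Lightweight bridge between the multimodal LLM and the OmniGen image decoder:
# a single Linear layer maps the MLLM hidden size (3584 for qwen2.5vl-7b) to the
# decoder's hidden size, followed by `layer_num` Phi3DecoderLayer blocks.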
class MindOmniConnector(nn.Module):
    def __init__(self, pre_config, post_config, layer_num: int = 2):
        super().__init__()
        connector_decoder = nn.ModuleList(
            [Phi3DecoderLayer(post_config, layer_idx) for layer_idx in range(layer_num)]
        )
        self.connector = nn.ModuleList(
            [nn.Linear(pre_config.hidden_size, post_config.hidden_size)]  # qwen2.5vl-7b: 3584
        )
        self.connector.extend(connector_decoder)
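
# Composite inference wrapper that ties together the multimodal LLM (prompt
# understanding / chain-of-thought), the connector, the OmniGen image decoder,
# the VAE, and the two processors. It is not an nn.Module itself; use `.to()`
# and `.eval()` below to move or freeze the underlying submodules.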
class MindOmni:
    def __init__(self, mllm, image_decoder, connector, vae, processor, mllm_processor, device: Union[str, torch.device] = None):
        self.mllm = mllm
        self.image_decoder = image_decoder
        self.connector = connector
        self.vae = vae
        self.processor = processor
        self.mllm_processor = mllm_processor
        self.vae.to(torch.float32)

        self.device = device
        if device is None:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            else:
                logger.info("No available GPU detected; falling back to CPU. Image generation may be very slow.")
                self.device = torch.device("cpu")
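
    # Download (or resolve from cache) a checkpoint laid out as
    # <repo>/{mllm, image_decoder, vae, connector.safetensors} and assemble the
    # wrapper. Only the ModuleList inside MindOmniConnector is kept as the connector.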
    @classmethod
    def from_pretrained(cls, model_path):
        model_path = snapshot_download(repo_id=model_path)
        mllm = MindOmniMLLM.from_pretrained(os.path.join(model_path, 'mllm'))
        image_decoder = OmniGen.from_pretrained(os.path.join(model_path, 'image_decoder'))
        connector = MindOmniConnector(mllm.config, image_decoder.llm.config, 2).connector
        connector_state = load_file(os.path.join(model_path, 'connector.safetensors'))
        connector.load_state_dict(connector_state)
        vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae'))
        processor = OmniGenProcessor.from_pretrained(os.path.join(model_path, 'image_decoder'))
        mllm_processor = AutoProcessor.from_pretrained(os.path.join(model_path, 'mllm'))
        logger.info("MindOmni has been loaded.")
        return cls(mllm, image_decoder, connector, vae, processor, mllm_processor)
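
    # Move every submodule to `device` and/or cast to `dtype`. Note that the
    # dtype cast skips the VAE, which __init__ pins to float32.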
    def to(self, device: Union[str, torch.device] = None, dtype: Union[str, torch.dtype] = None):
        if device is not None:
            if isinstance(device, str):
                device = torch.device(device)
            self.mllm.to(device)
            self.image_decoder.to(device)
            self.connector.to(device)
            self.vae.to(device)
            self.device = device
        if dtype is not None:
            self.mllm.to(dtype)
            self.image_decoder.to(dtype)
            self.connector.to(dtype)

    def eval(self):
        self.mllm.eval()
        self.image_decoder.eval()
        self.connector.eval()
        self.vae.eval()
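
    # Build a chat for the MLLM from the user prompt (plus optional reference
    # images referenced via <img><|image_k|></img> tags), then either generate a
    # chain-of-thought answer (use_cot=True) or run a single forward pass, and
    # return the updated messages, the decoded/used prompt text, and the last
    # hidden states that later condition the image decoder.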
    def get_mllm_hidden_state(self, user_input, input_images, do_sample, temperature, max_new_tokens, only_understand=False, use_cot=False):
        input_llm_images = input_images
        processor = self.mllm_processor
        model = self.mllm

        if only_understand or not use_cot:
            system_prompt = (
                "You are a helpful assistant."
            )
        else:
            system_prompt = (
                "You are a helpful assistant. When the user requests an image, the assistant "
                "first thinks about the reasoning process in the mind and then provides the user with concise prompt as the answer. "
                "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
                "<think> reasoning process here </think><answer> answer here </answer>."
            )

        messages = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": system_prompt},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Generate an image according to the following instructions\n"},
                    {"type": "text", "text": user_input},
                ],
            }
        ]

        if input_llm_images is not None:
            if only_understand:
                assert len(input_llm_images) == 1, "only support single image when multimodal understanding"
                messages[1]['content'][0] = {"type": "image", "image": input_llm_images[0]}
            else:
                user_input = f'<img><|image_1|></img> {user_input}'
                messages[1]['content'][1] = {"type": "text", "text": user_input}
                image_tags = re.findall(r'<\|image_\d+\|>', messages[1]['content'][1]['text'])
                image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
                pattern = r"<img><\|image_\d+\|></img>"
                prompt_chunks = [chunk for chunk in re.split(pattern, messages[1]['content'][1]['text'])]
                assert len(prompt_chunks) == len(input_llm_images) + 1
                new_content = []
                for idx, per_prompt in enumerate(prompt_chunks):
                    if idx != len(prompt_chunks) - 1:
                        item_text = {"type": "text", "text": per_prompt}
                        # resized_height, resized_width = input_images_shape[image_ids[idx] - 1]
                        image_path = input_llm_images[image_ids[idx] - 1]
                        # item_vit = {"type": "image", "image": image_path, "resized_height": resized_height, "resized_width": resized_width}
                        item_vit = {"type": "image", "image": image_path}
                        item_tag = {"type": "text", "text": f"<img>{image_tags[idx]}</img>"}
                        new_content.append(item_text)
                        new_content.append(item_vit)
                        new_content.append(item_tag)
                    else:
                        item_text = {"type": "text", "text": per_prompt}
                        new_content.append(item_text)
                messages[1]['content'] = messages[1]['content'][:1] + new_content

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors='pt',
        )
        inputs = inputs.to(self.device)

        if use_cot:
            # Inference: Generation of the output
            temperature = temperature if do_sample else None
            generated_dict = model.generate(**inputs, do_sample=do_sample, temperature=temperature, max_new_tokens=max_new_tokens, output_hidden_states=True, return_dict_in_generate=True)
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_dict.sequences)
            ]
            output_hidden_state = [hidden_state[-1] for hidden_state in generated_dict.hidden_states]
            context_hidden_state = torch.cat(output_hidden_state, dim=1)
            output_text = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            prompt_ = output_text[0]
            assistant_content = [
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": prompt_},
                    ],
                }
            ]
            messages += assistant_content
        else:
            prompt_ = user_input
            context_hidden_state = model(**inputs, output_hidden_states=True).hidden_states[-1]

        return messages, prompt_, context_hidden_state
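
    # Full text-to-image path: obtain conditioning hidden states from the MLLM
    # (optionally with chain-of-thought), build a second hidden state for the
    # negative prompt (classifier-free guidance), and run ImageDecoderPipeline.
    # With cascade_thinking=True, the extracted <answer> (or <think>) text is fed
    # through the MLLM once more and its hidden states replace the first pass.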
    def generate_image(self, height, width, guidance_scale, inference_steps, separate_cfg_infer, offload_model, seed, max_input_image_size,
                       text, NEGATIVE_PROMPT, input_llm_images, do_sample, temperature, max_new_tokens, only_understand, use_cot=False,
                       cascade_thinking=False):
        gen_pipe = ImageDecoderPipeline(self.vae, self.image_decoder, self.connector, self.processor)
        message, prompt_, context_hidden_state = self.get_mllm_hidden_state(text, input_llm_images, do_sample, temperature, max_new_tokens, only_understand, use_cot=use_cot)

        if cascade_thinking and use_cot:
            answer_text = re.search(r'<answer>(.*)', prompt_, re.S)
            think_text = re.search(r'<think>(.*?)</think>', prompt_, re.S)
            if answer_text:
                text = answer_text.group(1).strip()
            elif think_text:
                text = think_text.group(1).strip()
            else:
                text = prompt_.strip()
            if len(text) > 0:
                message, cascade_prompt_, context_hidden_state = self.get_mllm_hidden_state(
                    text, input_llm_images, do_sample, temperature, max_new_tokens, only_understand, use_cot=use_cot)
                prompt_ = f'{prompt_}\n\ncascade_thinking:{cascade_prompt_}'

        neg_message, neg_prompt_, neg_context_hidden_state = self.get_mllm_hidden_state(NEGATIVE_PROMPT, None, do_sample, temperature, max_new_tokens, only_understand, use_cot=False)
        print(message)
        output = gen_pipe(
            context_hidden_state=context_hidden_state,
            neg_context_hidden_state=neg_context_hidden_state,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=inference_steps,
            separate_cfg_infer=separate_cfg_infer,
            use_kv_cache=True,
            offload_kv_cache=True,
            offload_model=offload_model,
            seed=seed,
            max_input_image_size=max_input_image_size,
        )
        return output, prompt_
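
    # Understanding-only path: always runs with chain-of-thought enabled
    # (use_cot=True, only_understand=True) and returns just the decoded answer.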
    def generate_text(self, text, input_llm_images, do_sample, temperature, max_new_tokens, only_understand):
        _, answer, _ = self.get_mllm_hidden_state(text, input_llm_images, do_sample, temperature, max_new_tokens, only_understand=True, use_cot=True)
        return answer
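

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library. Assumptions: a MindOmni
    # checkpoint (placeholder path/repo id below) with the
    # mllm / image_decoder / vae / connector.safetensors layout expected by
    # `from_pretrained`, a CUDA device with enough memory, and bfloat16 weights;
    # the generation settings are illustrative demo values only.
    model = MindOmni.from_pretrained("path/or/repo-id/of/MindOmni")  # placeholder checkpoint
    model.to(device="cuda", dtype=torch.bfloat16)
    model.eval()

    output, final_prompt = model.generate_image(
        height=1024, width=1024,
        guidance_scale=3.0, inference_steps=50,
        separate_cfg_infer=True, offload_model=False,
        seed=0, max_input_image_size=1024,
        text="A corgi wearing sunglasses on a beach",
        NEGATIVE_PROMPT="blurry, low quality, distorted",
        input_llm_images=None,
        do_sample=False, temperature=None, max_new_tokens=512,
        only_understand=False, use_cot=True,
    )
    print(final_prompt)
    # `output` is whatever ImageDecoderPipeline returns; for OmniGen-style
    # pipelines this is typically a list of PIL images, e.g. output[0].save(...).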