import os
from typing import Dict, List

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer
from huggingface_hub import snapshot_download


def crop_arr(pil_image, max_image_size):
    # Repeatedly halve the image while its short side is at least twice max_image_size.
    while min(*pil_image.size) >= 2 * max_image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Scale down so the long side fits within max_image_size.
    if max(*pil_image.size) > max_image_size:
        scale = max_image_size / max(*pil_image.size)
        pil_image = pil_image.resize(
            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
        )

    # Scale up tiny images so the short side is at least 16 pixels.
    if min(*pil_image.size) < 16:
        scale = 16 / min(*pil_image.size)
        pil_image = pil_image.resize(
            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
        )

    # Center-crop so both dimensions are multiples of 16.
    arr = np.array(pil_image)
    crop_y1 = (arr.shape[0] % 16) // 2
    crop_y2 = arr.shape[0] % 16 - crop_y1
    crop_x1 = (arr.shape[1] % 16) // 2
    crop_x2 = arr.shape[1] % 16 - crop_x1
    arr = arr[crop_y1:arr.shape[0] - crop_y2, crop_x1:arr.shape[1] - crop_x2]
    return Image.fromarray(arr)
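

# Hypothetical usage sketch (not part of the original Space code): illustrates how
# crop_arr rescales an image so the long side fits max_image_size and then
# center-crops both dimensions to multiples of 16.
def _demo_crop_arr():
    img = Image.new("RGB", (1000, 600))
    out = crop_arr(img, max_image_size=512)
    # 1000x600 -> scaled to 512x307 -> cropped to 512x304 (both divisible by 16)
    assert out.size == (512, 304)
    return out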


class OmniGenProcessor:
    def __init__(self, max_image_size: int = 1024):
        self.max_image_size = max_image_size

        self.image_transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        self.collator = OmniGenCollator()
        self.separate_collator = OmniGenSeparateCollator()
    @classmethod
    def from_pretrained(cls, model_name):
        if not os.path.exists(model_name):
            cache_folder = os.getenv('HF_HUB_CACHE')
            model_name = snapshot_download(repo_id=model_name,
                                           cache_dir=cache_folder,
                                           allow_patterns="*.json")
        # The tokenizer is loaded from the checkpoint, but this processor's __init__
        # only takes max_image_size, so it is not passed to the constructor.
        AutoTokenizer.from_pretrained(model_name)
        return cls()
    def process_image(self, image):
        image = Image.open(image).convert('RGB')
        return self.image_transform(image)
    def __call__(self,
                 context_hidden_state: List[torch.Tensor],
                 neg_context_hidden_state: List[torch.Tensor],
                 height: int = 1024,
                 width: int = 1024,
                 separate_cfg_input: bool = False,
                 ) -> Dict:
        input_data = []
        for i in range(len(context_hidden_state)):
            cur_context_hidden_state = context_hidden_state[i]
            cur_neg_context_hidden_state = neg_context_hidden_state[i]
            input_data.append((cur_context_hidden_state, cur_neg_context_hidden_state, [height, width]))

        if separate_cfg_input:
            return self.separate_collator(input_data)
        return self.collator(input_data)
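

# Hypothetical usage sketch (assumed shapes, not part of the original Space code):
# the processor takes per-sample MLLM hidden states of shape (1, seq_len, hidden_size)
# plus the target image size and packs them into collated model inputs. With
# separate_cfg_input=True the positive and negative (CFG) contexts are collated
# separately via OmniGenSeparateCollator.
def _demo_processor():
    processor = OmniGenProcessor(max_image_size=1024)
    hidden = [torch.randn(1, 32, 3072)]      # positive-prompt context
    neg_hidden = [torch.randn(1, 24, 3072)]  # negative-prompt context for CFG
    inputs = processor(
        context_hidden_state=hidden,
        neg_context_hidden_state=neg_hidden,
        height=1024,
        width=1024,
        separate_cfg_input=True,
    )
    # inputs holds: context_hidden_state, connector_attention_mask,
    # connector_position_ids, llm_attention_mask, llm_position_ids
    return inputs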


class OmniGenCollator:
    def __init__(self, pad_token_id=2, llm_pad_token_id=151643, hidden_size=3072):
        self.llm_pad_token_id = llm_pad_token_id
        self.pad_token_id = pad_token_id
        self.hidden_size = hidden_size

    def create_position(self, attention_mask, num_tokens_for_output_images):
        position_ids = []
        text_length = attention_mask.size(-1)
        img_length = max(num_tokens_for_output_images)
        for mask in attention_mask:
            temp_l = int(torch.sum(mask))
            # we add a time embedding into the sequence, so add one more token
            temp_position = [0] * (text_length - temp_l) + list(range(temp_l + img_length + 1))
            position_ids.append(temp_position)
        return torch.LongTensor(position_ids)

    def create_connector_position(self, llm_2d_attention_mask):
        position_ids = []
        text_length = llm_2d_attention_mask.size(-1)
        for mask in llm_2d_attention_mask:
            temp_l = int(torch.sum(mask))
            # only the condition tokens of an MLLM such as Qwen
            temp_position = [0] * (text_length - temp_l) + list(range(temp_l))
            position_ids.append(temp_position)
        return torch.LongTensor(position_ids)
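
    # Illustrative note (assumed values): for a left-padded mask row [0, 0, 1, 1, 1]
    # with img_length = 4, create_connector_position yields [0, 0, 0, 1, 2]
    # (padding pinned to position 0), while create_position appends the image and
    # time-embedding positions: [0, 0, 0, 1, 2, 3, 4, 5, 6, 7].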

    def create_mask(self, attention_mask, num_tokens_for_output_images):
        extended_mask = []
        padding_images = []
        text_length = attention_mask.size(-1)
        img_length = max(num_tokens_for_output_images)
        seq_len = text_length + img_length + 1  # we add a time embedding into the sequence, so add one more token
        for inx, mask in enumerate(attention_mask):
            temp_l = int(torch.sum(mask))
            pad_l = text_length - temp_l

            # text (and time) tokens attend causally to each other, not to image tokens
            temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))
            image_mask = torch.zeros(size=(temp_l + 1, img_length))
            temp_mask = torch.cat([temp_mask, image_mask], dim=-1)

            # image tokens attend to all text, time, and image tokens
            image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
            temp_mask = torch.cat([temp_mask, image_mask], dim=0)

            # left-pad shorter sequences and mask out the padding columns
            if pad_l > 0:
                pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
                temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)

                pad_mask = torch.ones(size=(pad_l, seq_len))
                temp_mask = torch.cat([pad_mask, temp_mask], dim=0)

            # zero out image positions beyond this sample's true image length
            true_img_length = num_tokens_for_output_images[inx]
            pad_img_length = img_length - true_img_length
            if pad_img_length > 0:
                temp_mask[:, -pad_img_length:] = 0
                temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
            else:
                temp_padding_imgs = None

            extended_mask.append(temp_mask.unsqueeze(0))
            padding_images.append(temp_padding_imgs)
        return torch.cat(extended_mask, dim=0), padding_images
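
    # Sketch of the resulting layout (assumed small sizes): with text_length = 5,
    # temp_l = 3, img_length = 4, seq_len = 10, rows/columns are ordered
    # [pad(2) | text+time(4) | image(4)]. Text and time positions attend causally to
    # themselves, image positions attend to all non-padding positions, and padding
    # columns are zeroed for every non-padding row.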

    def adjust_attention_for_input_images(self, attention_mask, image_sizes):
        # let tokens belonging to the same input image attend to each other bidirectionally
        for b_inx in image_sizes.keys():
            for start_inx, end_inx in image_sizes[b_inx]:
                attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
        return attention_mask

    def pad_input(self, context_hidden_state):
        # pad_token_id = self.llm_pad_token_id  # 151642 <|endoftext|> in qwen2.5vl
        # left-pad all contexts to the longest sequence and build a 0/1 attention mask
        max_l = max([x.shape[1] for x in context_hidden_state])
        attention_mask = []
        for i in range(len(context_hidden_state)):
            temp_hidden = context_hidden_state[i]
            temp_l = temp_hidden.shape[1]
            pad_l = max_l - temp_l
            if pad_l == 0:
                attention_mask.append([1] * max_l)
            else:
                attention_mask.append([0] * pad_l + [1] * temp_l)
        return torch.LongTensor(attention_mask)

    def process_mllm_input(self, context_hidden_state, target_img_size):
        # each 16x16-pixel region of the target image corresponds to one output image token
        num_tokens_for_output_images = []
        for img_size in target_img_size:
            num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)

        llm_2d_attention_mask = self.pad_input(context_hidden_state)
        connector_position_ids = self.create_connector_position(llm_2d_attention_mask)
        llm_position_ids = self.create_position(llm_2d_attention_mask, num_tokens_for_output_images)
        llm_attention_mask, _ = self.create_mask(llm_2d_attention_mask, num_tokens_for_output_images)
        return llm_2d_attention_mask, connector_position_ids, llm_attention_mask, llm_position_ids
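

# Hypothetical sketch (assumed hidden_size=3072, not part of the original Space code):
# process_mllm_input left-pads two contexts of different lengths to a common length;
# the number of output-image tokens is (height // 16) * (width // 16).
def _demo_process_mllm_input():
    collator = OmniGenCollator()
    contexts = [torch.randn(1, 7, 3072), torch.randn(1, 5, 3072)]
    target_sizes = [[256, 256], [256, 256]]  # -> 256 output-image tokens per sample
    mask_2d, connector_pos, llm_mask, llm_pos = collator.process_mllm_input(contexts, target_sizes)
    # mask_2d: (2, 7) left-padded 0/1 mask; llm_mask: (2, 7 + 256 + 1, 7 + 256 + 1)
    return mask_2d, connector_pos, llm_mask, llm_pos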


class OmniGenSeparateCollator(OmniGenCollator):
    def __call__(self, features):
        context_hidden_state = [f[0] for f in features]
        neg_context_hidden_state = [f[1] for f in features]
        target_img_size = [f[2] for f in features]

        all_context_hidden_state, all_connector_attention_mask, all_connector_position_ids, all_llm_attention_mask, all_llm_position_ids = [], [], [], [], []

        # collate the positive (conditional) context
        connector_attention_mask, connector_position_ids, llm_attention_mask, llm_position_ids = self.process_mllm_input(context_hidden_state, target_img_size)
        all_context_hidden_state.append(context_hidden_state[0])
        all_connector_attention_mask.append(connector_attention_mask)
        all_connector_position_ids.append(connector_position_ids)
        all_llm_attention_mask.append(llm_attention_mask)
        all_llm_position_ids.append(llm_position_ids)

        # collate the negative (CFG) context separately, if provided
        if neg_context_hidden_state[0] is not None:
            connector_attention_mask, connector_position_ids, llm_attention_mask, llm_position_ids = self.process_mllm_input(neg_context_hidden_state, target_img_size)
            all_context_hidden_state.append(neg_context_hidden_state[0])
            all_connector_attention_mask.append(connector_attention_mask)
            all_connector_position_ids.append(connector_position_ids)
            all_llm_attention_mask.append(llm_attention_mask)
            all_llm_position_ids.append(llm_position_ids)

        data = {
            "context_hidden_state": all_context_hidden_state,
            "connector_attention_mask": all_connector_attention_mask,
            "connector_position_ids": all_connector_position_ids,
            "llm_attention_mask": all_llm_attention_mask,
            "llm_position_ids": all_llm_position_ids,
        }
        return data
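

# Minimal smoke test for the hypothetical demo sketches above (assumed entry point,
# not part of the original Space code).
if __name__ == "__main__":
    _demo_crop_arr()
    _demo_processor()
    _demo_process_mllm_input()
    print("demo sketches ran without errors")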