jiuhai's picture
Upload 59 files
6858cdd verified
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import (
AutoConfig,
AutoModelForCausalLM,
Qwen3Config,
Qwen3ForCausalLM,
Qwen3Model,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from blip3o.model.blip3o_arch import blip3oMetaForCausalLM, blip3oMetaModel
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
from blip3o.utils import rank0_print
class blip3oQwenConfig(Qwen3Config):
model_type = "blip3o_qwen"
class blip3oQwenModel(blip3oMetaModel, Qwen3Model):
config_class = blip3oQwenConfig
def __init__(self, config: Qwen3Config):
super(blip3oQwenModel, self).__init__(config)
class blip3oQwenForCausalLM(Qwen3ForCausalLM, blip3oMetaForCausalLM):
config_class = blip3oQwenConfig
def __init__(self, config):
Qwen3ForCausalLM.__init__(self, config)
config.model_type = "blip3o_qwen"
config.rope_scaling = None
self.model = blip3oQwenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = self.model.noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.model.noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def mask_drop(self, latents, drop_prob=0.1):
if drop_prob <= 0:
return latents
mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
while len(mask.shape) < len(latents.shape):
mask = mask.unsqueeze(-1)
mask = 1 - mask # need to flip 0 <-> 1
return latents * mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
target_images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
modalities: Optional[List[str]] = ["image"],
dpo_forward: Optional[bool] = False,
cache_position=None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if target_images is not None:
vae = self.model.get_sana_vae()
latents = vae.encode(target_images).latent
if "shift_factor" in vae.config and vae.config.shift_factor is not None:
latents = latents - vae.config.shift_factor
latents = latents * vae.config.scaling_factor
noise = torch.randn_like(latents, device=latents.device)
weighting_scheme = "uniform"
u = compute_density_for_timestep_sampling(
weighting_scheme=weighting_scheme,
batch_size=latents.shape[0],
logit_mean=0.0,
logit_std=1.0,
mode_scale=1.29,
)
indices = (u * self.model.noise_scheduler.config.num_train_timesteps).long()
timesteps = self.model.noise_scheduler.timesteps[indices].to(device=latents.device)
sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
sana = self.model.get_sana()
start_pos = (labels == self.config.image_start_tag_id).float().argmax(dim=1)
end_pos = (labels == self.config.image_end_tag_id).float().argmax(dim=1)
breakpoint()
selected_hidden_states = []
for b in range(hidden_states.size(0)):
start = start_pos[b].item() + 1
end = end_pos[b].item()
hidden_states_filter = hidden_states[b, start:end, :]
if hidden_states_filter.size(1) != 730:
hidden_states_filter = hidden_states[b, -730:, :]
selected_hidden_states.append(hidden_states_filter)
selected_hidden_states = torch.stack(selected_hidden_states, dim=0)
diffusion_pred = sana(
hidden_states=noisy_latents,
timestep=timesteps,
encoder_hidden_states=self.model.diffusion_connector(self.mask_drop(selected_hidden_states)),
encoder_attention_mask=None,
).sample
target = noise - latents
weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
diff_loss = torch.mean(
(weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1,
)
diff_loss = diff_loss.mean()
rank0_print(f" Cross-entropy loss {loss}, Diffusion loss {diff_loss} ")
loss += diff_loss
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
modalities: Optional[List[str]] = ["image"],
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
AutoConfig.register("blip3o_qwen", blip3oQwenConfig)
AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForCausalLM)