from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    Qwen3Config,
    Qwen3ForCausalLM,
    Qwen3Model,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from blip3o.model.blip3o_arch import blip3oMetaForCausalLM, blip3oMetaModel
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
from blip3o.utils import rank0_print
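
# blip3o_qwen wires the blip3o multimodal architecture onto a Qwen3 backbone.
# The forward pass computes a standard next-token cross-entropy loss and, when
# target images are given, an additional flow-matching diffusion loss through a
# SANA transformer conditioned on the hidden states of the image span.
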
class blip3oQwenConfig(Qwen3Config):
    model_type = "blip3o_qwen"


class blip3oQwenModel(blip3oMetaModel, Qwen3Model):
    config_class = blip3oQwenConfig

    def __init__(self, config: Qwen3Config):
        super(blip3oQwenModel, self).__init__(config)
class blip3oQwenForCausalLM(Qwen3ForCausalLM, blip3oMetaForCausalLM):
    config_class = blip3oQwenConfig

    def __init__(self, config):
        Qwen3ForCausalLM.__init__(self, config)
        config.model_type = "blip3o_qwen"
        config.rope_scaling = None

        self.model = blip3oQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model
    def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
        # Look up the scheduler sigma for each sampled timestep and reshape it so it
        # broadcasts against an n_dim latent tensor (e.g. [B, C, H, W]).
        sigmas = self.model.noise_scheduler.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.model.noise_scheduler.timesteps.to(device)
        timesteps = timesteps.to(device)

        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma
    def mask_drop(self, latents, drop_prob=0.1):
        # Randomly zero out the conditioning for a fraction of the batch
        # (classifier-free-guidance-style conditioning dropout).
        if drop_prob <= 0:
            return latents
        mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
        while len(mask.shape) < len(latents.shape):
            mask = mask.unsqueeze(-1)
        mask = 1 - mask  # flip so 1 keeps the sample and 0 drops it
        return latents * mask
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        target_images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = ["image"],
        dpo_forward: Optional[bool] = False,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes
            )
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard next-token prediction: shift logits/labels by one position.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        if target_images is not None:
            # Encode the target images into VAE latents and normalize them with the
            # VAE's shift/scaling factors.
            vae = self.model.get_sana_vae()
            latents = vae.encode(target_images).latent
            if "shift_factor" in vae.config and vae.config.shift_factor is not None:
                latents = latents - vae.config.shift_factor
            latents = latents * vae.config.scaling_factor

            # Sample noise and per-example timesteps, then form the flow-matching
            # interpolation between clean latents and noise.
            noise = torch.randn_like(latents, device=latents.device)
            weighting_scheme = "uniform"
            u = compute_density_for_timestep_sampling(
                weighting_scheme=weighting_scheme,
                batch_size=latents.shape[0],
                logit_mean=0.0,
                logit_std=1.0,
                mode_scale=1.29,
            )
            indices = (u * self.model.noise_scheduler.config.num_train_timesteps).long()
            timesteps = self.model.noise_scheduler.timesteps[indices].to(device=latents.device)
            sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)
            noisy_latents = (1.0 - sigmas) * latents + sigmas * noise

            sana = self.model.get_sana()

            # Locate the image span in the label sequence via the start/end tag ids.
            start_pos = (labels == self.config.image_start_tag_id).float().argmax(dim=1)
            end_pos = (labels == self.config.image_end_tag_id).float().argmax(dim=1)
            selected_hidden_states = []
            for b in range(hidden_states.size(0)):
                start = start_pos[b].item() + 1
                end = end_pos[b].item()
                hidden_states_filter = hidden_states[b, start:end, :]
                # Fall back to the last 730 positions if the tagged span does not
                # contain exactly 730 image tokens.
                if hidden_states_filter.size(0) != 730:
                    hidden_states_filter = hidden_states[b, -730:, :]
                selected_hidden_states.append(hidden_states_filter)
            selected_hidden_states = torch.stack(selected_hidden_states, dim=0)
            # Predict the flow-matching velocity with the SANA transformer, conditioned
            # on the (optionally dropped) projected hidden states.
            diffusion_pred = sana(
                hidden_states=noisy_latents,
                timestep=timesteps,
                encoder_hidden_states=self.model.diffusion_connector(self.mask_drop(selected_hidden_states)),
                encoder_attention_mask=None,
            ).sample

            # Flow-matching target is the velocity (noise - latents); weight per-example
            # losses according to the sampling scheme before averaging.
            target = noise - latents
            weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
            diff_loss = torch.mean(
                (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                1,
            )
            diff_loss = diff_loss.mean()
            rank0_print(f"Cross-entropy loss {loss}, Diffusion loss {diff_loss}")
            loss = diff_loss if loss is None else loss + diff_loss
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = ["image"],
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(
                inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs
AutoConfig.register("blip3o_qwen", blip3oQwenConfig)
AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForCausalLM)
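
# Example usage (a minimal sketch; the checkpoint path below is a placeholder,
# not an artifact shipped with this module):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("/path/to/blip3o_qwen_checkpoint")
#
# Because the config and model classes are registered above, from_pretrained
# resolves the "blip3o_qwen" model_type to blip3oQwenForCausalLM.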