|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer |
|
|
|
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor |
|
from diffusers.loaders import QwenImageLoraLoaderMixin |
|
from diffusers.models import AutoencoderKLQwenImage |
|
|
|
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
|
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
|
from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput |
|
|
|
from transformer_qwenimage import QwenImageTransformer2DModel |
|
from controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel |
|
|
|
# Optional TPU support: `xm.mark_step()` is called once per denoising step when torch_xla is present.
XLA_AVAILABLE = False
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True

logger = logging.get_logger(__name__)
|
|
|
# Usage example spliced into `QwenImageControlNetPipeline.__call__.__doc__` by `replace_example_docstring`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers.utils import load_image
        >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetPipeline

        >>> controlnet = QwenImageControlNetModel.from_pretrained("InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16)
        >>> pipe = QwenImageControlNetPipeline.from_pretrained("Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")
        >>> prompt = ""
        >>> negative_prompt = " "
        >>> # Replace CONDITION_IMAGE_PATH with the path or URL of your conditioning image.
        >>> control_image = load_image(CONDITION_IMAGE_PATH)
        >>> # Depending on the variant being used, the pipeline call will slightly vary.
        >>> # Refer to the pipeline documentation for more details.
        >>> image = pipe(prompt, negative_prompt=negative_prompt, control_image=control_image, controlnet_conditioning_scale=1.0, num_inference_steps=30, true_cfg_scale=4.0).images[0]
        >>> image.save("qwenimage_cn_union.png")
        ```
"""
|
|
|
|
|
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map an image token count to a scheduler shift value `mu`.

    The line passes through `(base_seq_len, base_shift)` and `(max_seq_len, max_shift)`; values outside that
    range are extrapolated, not clamped.

    Args:
        image_seq_len: Number of image tokens (packed latent sequence length).
        base_seq_len: Sequence length that maps to `base_shift`.
        max_seq_len: Sequence length that maps to `max_shift`.
        base_shift: Shift at `base_seq_len`.
        max_shift: Shift at `max_seq_len`.

    Returns:
        The interpolated shift `mu`.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
|
|
|
|
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull latents out of a VAE encoder output.

    Args:
        encoder_output: Object returned by `vae.encode(...)`. Either exposes a `latent_dist` distribution or a
            ready-made `latents` attribute.
        generator: Random generator forwarded to `latent_dist.sample` when sampling.
        sample_mode: `"sample"` draws from `latent_dist`, `"argmax"` takes its mode. Any other value falls back
            to the `latents` attribute.

    Returns:
        The selected latents.

    Raises:
        AttributeError: If no latents can be obtained from `encoder_output`.
    """
    has_distribution = hasattr(encoder_output, "latent_dist")
    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
|
|
|
|
|
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` via its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a step count, an explicit `timesteps` list, or an explicit `sigmas` list.
    Extra keyword arguments (e.g. `mu`) are forwarded to `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            Scheduler whose `set_timesteps` is called and whose `timesteps` attribute is read back.
        num_inference_steps (`int`, *optional*):
            Number of denoising steps. Used only when neither `timesteps` nor `sigmas` is given, and returned
            unchanged in that case.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. The scheduler's `set_timesteps` must accept a `timesteps` argument.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule. The scheduler's `set_timesteps` must accept a `sigmas` argument.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timesteps and the number of inference steps. For custom
        schedules the count is `len(timesteps)`.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or the scheduler does not support the requested
            custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Plain step-count mode: the caller's step count is returned as-is.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    supported_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
|
|
|
|
class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
    r"""
    The QwenImage pipeline for ControlNet-conditioned text-to-image generation.

    Args:
        transformer ([`QwenImageTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLQwenImage`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. It
            is also used to encode `control_image` into the latent space.
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), used to encode prompts.
            The last hidden layer is used as the prompt embedding.
        tokenizer ([`Qwen2Tokenizer`]):
            Tokenizer of class
            [Qwen2Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2Tokenizer).
        controlnet ([`QwenImageControlNetModel`]):
            ControlNet run on every denoising step. Its `controlnet_block_samples` output is passed to
            `transformer`. The guidance-window code also recognizes `QwenImageMultiControlNetModel`.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKLQwenImage,
        text_encoder: Qwen2_5_VLForConditionalGeneration,
        tokenizer: Qwen2Tokenizer,
        transformer: QwenImageTransformer2DModel,
        controlnet: QwenImageControlNetModel,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            controlnet=controlnet,
        )
        # `temperal_downsample` is the attribute name used by the VAE; fall back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
        # Latents are further packed into 2x2 patches (see `_pack_latents`), hence the extra factor of 2.
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.tokenizer_max_length = 1024
        self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        # Number of leading tokens dropped from each encoded prompt. Presumably this is the token length of the
        # template prefix before the user text (TODO confirm against the tokenizer).
        self.prompt_template_encode_start_idx = 34
        self.default_sample_size = 128

    def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
        """Split `hidden_states` into one tensor per batch row, keeping only positions where `mask` is nonzero.

        Returns a tuple of tensors of shape `(valid_len_i, hidden_dim)`.
        """
        bool_mask = mask.bool()
        valid_lengths = bool_mask.sum(dim=1)
        selected = hidden_states[bool_mask]
        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)

        return split_result

    def _get_qwen_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Encode prompts with the Qwen2.5-VL text encoder.

        Each prompt is wrapped in `prompt_template_encode` and tokenized (truncated at
        `tokenizer_max_length + prompt_template_encode_start_idx` tokens). The last hidden layer is then taken,
        and the first `prompt_template_encode_start_idx` valid tokens are dropped from every sequence. The
        sequences are right-padded with zeros to a common length.

        Returns:
            `(prompt_embeds, encoder_attention_mask)`: embeddings of shape `(batch, seq, dim)` cast to
            `dtype`/`device`, and a `torch.long` mask of shape `(batch, seq)` with 1 for real tokens.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt

        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in prompt]
        # Move tokens to the resolved execution device (not `self.device`), which is what offload hooks expect.
        txt_tokens = self.tokenizer(
            txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        encoder_hidden_states = self.text_encoder(
            input_ids=txt_tokens.input_ids,
            attention_mask=txt_tokens.attention_mask,
            output_hidden_states=True,
        )
        hidden_states = encoder_hidden_states.hidden_states[-1]
        split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
        max_seq_len = max([e.size(0) for e in split_hidden_states])
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
        )
        encoder_attention_mask = torch.stack(
            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
        )

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        return prompt_embeds, encoder_attention_mask

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        max_sequence_length: int = 1024,
    ):
        r"""
        Encode `prompt` (or reuse precomputed embeddings) and duplicate them for `num_images_per_prompt`.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_embeds_mask (`torch.Tensor`, *optional*):
                2-D `(batch, seq)` mask matching `prompt_embeds`. Required whenever `prompt_embeds` is given.
            max_sequence_length (`int`, defaults to 1024):
                NOTE(review): accepted for API compatibility but not applied here. Truncation happens in the
                tokenizer call inside `_get_qwen_prompt_embeds`.

        Returns:
            `(prompt_embeds, prompt_embeds_mask)` with `batch * num_images_per_prompt` rows. The copies of each
            prompt are adjacent (p0, p0, p1, p1, ...).
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        # The mask is 2-D, so repeat along the sequence dim. The view then interleaves rows exactly like the
        # embeddings above. The previous `repeat(1, n, 1)` tiled rows instead (p0, p1, p0, p1, ...), which
        # mismatched masks and embeddings when batch > 1 and num_images_per_prompt > 1.
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

        return prompt_embeds, prompt_embeds_mask

    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_embeds_mask=None,
        negative_prompt_embeds_mask=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        """Validate `__call__` arguments, raising `ValueError` on invalid combinations.

        A size that is not divisible by `vae_scale_factor * 2` only triggers a warning.
        """
        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
            logger.warning(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and prompt_embeds_mask is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
            )
        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        if max_sequence_length is not None and max_sequence_length > 1024:
            raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        """Pack latents into 2x2 patch tokens.

        Returns `(batch, (height // 2) * (width // 2), num_channels_latents * 4)`. Any singleton frame dimension in
        the input is absorbed by `view`.
        """
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        """Inverse of `_pack_latents`.

        Maps `(batch, num_patches, channels)` to `(batch, channels // 4, 1, latent_h, latent_w)`, where `height` and
        `width` are pixel sizes converted to latent sizes via `vae_scale_factor`.
        """
        batch_size, num_patches, channels = latents.shape

        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)

        return latents

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Return packed initial noise latents, or the user-supplied `latents` cast to `device`/`dtype`.

        User-supplied latents are returned as-is (not re-packed), so they must already be in packed form.
        """
        # Convert pixel sizes to latent sizes, rounded down to an even number so 2x2 packing works.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        shape = (batch_size, 1, num_channels_latents, height, width)

        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

        return latents

    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        """Preprocess a control image (unless it is already a tensor) and repeat it to the batch size.

        A single image is repeated `batch_size` times. A batch of images is repeated `num_images_per_prompt` times
        each. When `do_classifier_free_guidance` is set and `guess_mode` is not, the result is duplicated along
        the batch dimension.
        """
        if isinstance(image, torch.Tensor):
            pass
        else:
            image = self.image_processor.preprocess(image, height=height, width=width)

        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        true_cfg_scale: float = 4.0,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 1.0,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        control_image: PipelineImageInput = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
                not greater than `1`).
            true_cfg_scale (`float`, *optional*, defaults to 4.0):
                When > 1.0 and a negative prompt (or negative embeddings with mask) is provided, enables true
                classifier-free guidance. The combined prediction is rescaled to the norm of the conditional one.
            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image. With a single ControlNet, the control image is
                preprocessed at this size (unless it is already a tensor) and its resulting height is used.
            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image. With a single ControlNet, the control image is
                preprocessed at this size (unless it is already a tensor) and its resulting width is used.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, `np.linspace(1.0, 1 / num_inference_steps,
                num_inference_steps)` is used.
            guidance_scale (`float`, *optional*, defaults to 1.0):
                Embedded guidance value. Only forwarded to the transformer when
                `transformer.config.guidance_embeds` is set; otherwise it has no effect.
            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
                Fraction of the denoising steps after which each ControlNet starts being applied.
            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
                Fraction of the denoising steps after which each ControlNet stops being applied.
            control_image (`PipelineImageInput`):
                Conditioning image. It is preprocessed, encoded with the VAE, normalized with the VAE latent
                statistics and packed like the generation latents.
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                Scale passed to the ControlNet as `conditioning_scale`. It is multiplied by 0 outside the
                `[control_guidance_start, control_guidance_end]` window.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, in packed form, sampled from a Gaussian distribution, to be used as
                inputs for image generation. Can be used to tweak the same generation with different prompts. If not
                provided, a latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_embeds_mask (`torch.Tensor`, *optional*):
                2-D mask for `prompt_embeds`. Required when `prompt_embeds` is provided.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
                2-D mask for `negative_prompt_embeds`. Required when `negative_prompt_embeds` is provided.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. `"latent"` returns
                the packed latents without decoding.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. Validated
                to be at most 1024. NOTE(review): it is currently not used to truncate the embeddings.

        Examples:

        Returns:
            [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
            [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is a list with the generated images.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # Broadcast the control guidance window to equal-length lists (one entry per ControlNet).
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(self.controlnet.nets) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = (
                mult * [control_guidance_start],
                mult * [control_guidance_end],
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
        )
        if true_cfg_scale > 1 and not has_neg_prompt:
            # Without a negative prompt true CFG is skipped; tell the user instead of ignoring the setting silently.
            logger.warning(
                f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no"
                " negative_prompt is provided."
            )
        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
        prompt_embeds, prompt_embeds_mask = self.encode_prompt(
            prompt=prompt,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
        )
        if do_true_cfg:
            negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
                prompt=negative_prompt,
                prompt_embeds=negative_prompt_embeds,
                prompt_embeds_mask=negative_prompt_embeds_mask,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
            )

        # 3. Prepare control image
        num_channels_latents = self.transformer.config.in_channels // 4
        if isinstance(self.controlnet, QwenImageControlNetModel):
            control_image = self.prepare_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.vae.dtype,
            )
            # The generated image takes the size of the preprocessed control image.
            height, width = control_image.shape[-2:]

            # The VAE takes a frame dimension: (B, C, H, W) -> (B, C, 1, H, W).
            if control_image.ndim == 4:
                control_image = control_image.unsqueeze(2)

            # vae encode
            self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
            latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(device)
            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device)

            control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
            control_image = (control_image - latents_mean) * latents_std

            # Swap channel and frame dims before packing.
            control_image = control_image.permute(0, 2, 1, 3, 4)

            # pack
            control_image = self._pack_latents(
                control_image,
                batch_size=control_image.shape[0],
                num_channels_latents=num_channels_latents,
                height=control_image.shape[3],
                width=control_image.shape[4],
            )
        # NOTE(review): a `QwenImageMultiControlNetModel` is recognized above but its control images are not
        # prepared here. The raw `control_image` is passed to the denoising loop unchanged.

        # 4. Prepare latent variables
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # Per-step ControlNet on/off factor: 1.0 inside [start, end] of the schedule, 0.0 outside.
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        if self.attention_kwargs is None:
            self._attention_kwargs = {}

        # The step callback can only replace `latents` and `prompt_embeds`, never the masks, so the per-sample
        # text lengths are loop-invariant. Compute them once instead of on every step.
        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
        negative_txt_seq_lens = negative_prompt_embeds_mask.sum(dim=1).tolist() if do_true_cfg else None

        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                # controlnet
                controlnet_block_samples = self.controlnet(
                    hidden_states=latents,
                    controlnet_cond=control_image.to(dtype=latents.dtype, device=device),
                    conditioning_scale=cond_scale,
                    timestep=timestep / 1000,
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    return_dict=False,
                )

                with self.transformer.cache_context("cond"):
                    noise_pred = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        # Same guidance as the unconditional branch (None unless guidance_embeds is set).
                        guidance=guidance,
                        encoder_hidden_states=prompt_embeds,
                        encoder_hidden_states_mask=prompt_embeds_mask,
                        img_shapes=img_shapes,
                        txt_seq_lens=txt_seq_lens,
                        controlnet_block_samples=controlnet_block_samples,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
                    )[0]

                if do_true_cfg:
                    with self.transformer.cache_context("uncond"):
                        neg_noise_pred = self.transformer(
                            hidden_states=latents,
                            timestep=timestep / 1000,
                            guidance=guidance,
                            encoder_hidden_states_mask=negative_prompt_embeds_mask,
                            encoder_hidden_states=negative_prompt_embeds,
                            img_shapes=img_shapes,
                            txt_seq_lens=negative_txt_seq_lens,
                            controlnet_block_samples=controlnet_block_samples,
                            attention_kwargs=self.attention_kwargs,
                            return_dict=False,
                        )[0]
                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                    # Rescale the guided prediction to the per-token norm of the conditional prediction.
                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
                    noise_pred = comb_pred * (cond_norm / noise_norm)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None
        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = latents.to(self.vae.dtype)
            # Undo the latent normalization applied at encode time: x = z / (1 / std) + mean.
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
            # Drop the single frame dimension from the decoded output.
            image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return QwenImagePipelineOutput(images=image)
|
|