import os
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
from PIL import Image

from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import logging
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer

logger = logging.get_logger(__name__)
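
# TeEFusion SD3 inference in brief: prompts are encoded with the two CLIP encoders and T5 as in
# standard Stable Diffusion 3, but classifier-free guidance is not run as a second unconditional
# forward pass. Instead, the guidance scale (`txt_align_guidance`) and the direction between the
# conditional and unconditional pooled text embeddings (`txt_align_vec`) are passed to the
# transformer at every denoising step.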
|
|
|
def get_noise(
    num_samples: int,
    channel: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Sample the initial latent noise; latents are 8x smaller than the image in each spatial dimension."""
    return torch.randn(
        num_samples,
        channel,
        height // 8,
        width // 8,
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )
|
|
|
def get_clip_prompt_embeds(
    clip_tokenizers,
    clip_text_encoders,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    clip_skip: Optional[int] = None,
    clip_model_index: int = 0,
):
    """Encode `prompt` with one of the two CLIP text encoders and return token and pooled embeddings."""
    tokenizer_max_length = 77
    tokenizer = clip_tokenizers[clip_model_index]
    text_encoder = clip_text_encoders[clip_model_index]

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer_max_length,
        truncation=True,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
        logger.warning(
            f"Part of the prompt was truncated because CLIP can only handle {tokenizer_max_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    # Index 0 of the CLIPTextModelWithProjection output is the projected (pooled) text embedding.
    pooled_prompt_embeds = prompt_embeds[0]

    if clip_skip is None:
        prompt_embeds = prompt_embeds.hidden_states[-2]
    else:
        prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Duplicate the embeddings for each requested image per prompt.
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds, pooled_prompt_embeds
|
|
|
def get_t5_prompt_embeds(
    tokenizer_3,
    text_encoder_3,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 256,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode `prompt` with the T5 text encoder and return per-token embeddings."""
    tokenizer_max_length = 77
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer_3(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer_3.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
        logger.warning(
            f"Part of the prompt was truncated because `max_sequence_length` is set to {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder_3(text_input_ids.to(device))[0]

    dtype = dtype if dtype is not None else text_encoder_3.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Duplicate the embeddings for each requested image per prompt.
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
|
|
|
|
|
@torch.no_grad()
def encode_text(clip_tokenizers, clip_text_encoders, tokenizer_3, text_encoder_3, prompt, device, max_sequence_length=256):
    """Build the SD3 conditioning: CLIP-L/CLIP-G token embeddings padded to the T5 width and stacked
    with the T5 token embeddings, plus the concatenated pooled CLIP embeddings."""
    prompt_embed, pooled_prompt_embed = get_clip_prompt_embeds(
        clip_tokenizers, clip_text_encoders, prompt=prompt, device=device, clip_model_index=0
    )
    prompt_2_embed, pooled_prompt_2_embed = get_clip_prompt_embeds(
        clip_tokenizers, clip_text_encoders, prompt=prompt, device=device, clip_model_index=1
    )
    clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

    t5_prompt_embed = get_t5_prompt_embeds(
        tokenizer_3, text_encoder_3, prompt=prompt, max_sequence_length=max_sequence_length, device=device
    )

    # Pad the CLIP features to the T5 feature dimension, then concatenate along the sequence axis.
    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
    pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

    return prompt_embeds, pooled_prompt_embeds
|
|
|
|
|
class TeEFusionSD3Pipeline(DiffusionPipeline, ConfigMixin):
    """Stable Diffusion 3 pipeline for a TeEFusion transformer.

    Guidance is not applied via a second, unconditional forward pass; instead the guidance scale and
    the direction between the conditional and unconditional pooled text embeddings are handed to the
    transformer directly (`txt_align_guidance`, `txt_align_vec`).
    """

    @register_to_config
    def __init__(
        self,
        transformer: nn.Module,
        text_encoder: CLIPTextModelWithProjection,
        text_encoder_2: CLIPTextModelWithProjection,
        text_encoder_3: T5EncoderModel,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        tokenizer_3: T5Tokenizer,
        vae: AutoencoderKL,
        scheduler: Any,
    ):
        super().__init__()

        self.register_modules(
            transformer=transformer,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            text_encoder_3=text_encoder_3,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            tokenizer_3=tokenizer_3,
            vae=vae,
            scheduler=scheduler,
        )
|
|
|
|
|
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ) -> "TeEFusionSD3Pipeline":
        return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
        super().save_pretrained(save_directory)
|
|
|
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        latents: Optional[torch.FloatTensor] = None,
        height: int = 1024,
        width: int = 1024,
        seed: int = 0,
    ):
        if isinstance(prompt, str):
            prompt = [prompt]

        device = self.transformer.device

        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]

        # Conditional embeddings for the prompt, plus the pooled embedding of the empty prompt,
        # which supplies the guidance direction for the TeEFusion transformer.
        prompt_embeds, pooled_prompt_embeds = encode_text(
            clip_tokenizers, clip_text_encoders, self.tokenizer_3, self.text_encoder_3, prompt, device
        )
        _, negative_pooled_prompt_embeds = encode_text(
            clip_tokenizers, clip_text_encoders, self.tokenizer_3, self.text_encoder_3, [''], device
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        bs = len(prompt)
        channels = self.transformer.config.in_channels
        # Snap the resolution to a multiple of 16 pixels (a multiple of 2 in latent space).
        height = 16 * (height // 16)
        width = 16 * (width // 16)

        if latents is None:
            latents = get_noise(
                bs,
                channels,
                height,
                width,
                device=device,
                dtype=self.transformer.dtype,
                seed=seed,
            )

        for t in timesteps:
            # A single forward pass per step: guidance is injected through the text-alignment inputs
            # instead of a separate unconditional pass.
            noise_pred = self.transformer(
                hidden_states=latents,
                timestep=t.reshape(1),
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                return_dict=False,
                txt_align_guidance=torch.tensor(
                    data=(guidance_scale,), dtype=self.transformer.dtype, device=self.transformer.device
                ) * 1000.,
                txt_align_vec=pooled_prompt_embeds - negative_pooled_prompt_embeds,
            )[0]

            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        # Decode the latents to pixels. SD3 VAEs use both a scaling and a shift factor.
        x = latents.float()
        if getattr(self.vae.config, "scaling_factor", None) is not None:
            x = x / self.vae.config.scaling_factor
        if getattr(self.vae.config, "shift_factor", None) is not None:
            x = x + self.vae.config.shift_factor
        x = self.vae.decode(x.to(self.vae.dtype), return_dict=False)[0]

        # Map from [-1, 1] to [0, 255] and convert to PIL images.
        x = (x / 2 + 0.5).clamp(0, 1)
        x = x.cpu().permute(0, 2, 3, 1).float().numpy()
        images = (x * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]

        return pil_images
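
# Example usage (a minimal sketch; the checkpoint path is a placeholder and must point to a directory
# saved with `TeEFusionSD3Pipeline.save_pretrained`; dtype and device should match your hardware):
#
#   pipe = TeEFusionSD3Pipeline.from_pretrained("path/to/teefusion-sd3", torch_dtype=torch.float16)
#   pipe = pipe.to("cuda")
#   image = pipe("a photo of an astronaut riding a horse", num_inference_steps=50, guidance_scale=7.5, seed=0)[0]
#   image.save("sample.png")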