# Copyright (C) 2025 AIDC-AI
# This project is licensed under the Attribution-NonCommercial 4.0 International
# License (SPDX-License-Identifier: CC-BY-NC-4.0).
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin, register_to_config
from PIL import Image
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer


def get_noise(
    num_samples: int,
    channel: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    # Latents are 8x smaller than the pixel resolution (SD3 VAE downsampling factor).
    return torch.randn(
        num_samples,
        channel,
        height // 8,
        width // 8,
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )


def get_clip_prompt_embeds(
    clip_tokenizers,
    clip_text_encoders,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    clip_skip: Optional[int] = None,
    clip_model_index: int = 0,
):
    tokenizer_max_length = 77
    tokenizer = clip_tokenizers[clip_model_index]
    text_encoder = clip_text_encoders[clip_model_index]

    batch_size = len(prompt)
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    # Decode any tokens dropped by truncation (kept for parity with diffusers; the result is unused here).
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
    pooled_prompt_embeds = prompt_embeds[0]

    if clip_skip is None:
        prompt_embeds = prompt_embeds.hidden_states[-2]
    else:
        prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # Duplicate text embeddings for each generation per prompt.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds, pooled_prompt_embeds


def get_t5_prompt_embeds(
    tokenizer_3,
    text_encoder_3,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 256,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    tokenizer_max_length = 77
    batch_size = len(prompt)

    text_inputs = tokenizer_3(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    untruncated_ids = tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer_3.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])

    prompt_embeds = text_encoder_3(text_input_ids.to(device))[0]

    dtype = text_encoder_3.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


@torch.no_grad()
def encode_text(clip_tokenizers, clip_text_encoders, tokenizer_3, text_encoder_3, prompt, device, max_sequence_length=256):
    prompt_embed, pooled_prompt_embed = get_clip_prompt_embeds(
        clip_tokenizers, clip_text_encoders, prompt=prompt, device=device, clip_model_index=0
    )
    prompt_2_embed, pooled_prompt_2_embed = get_clip_prompt_embeds(
        clip_tokenizers, clip_text_encoders, prompt=prompt, device=device, clip_model_index=1
    )
    clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)

    t5_prompt_embed = get_t5_prompt_embeds(
        tokenizer_3, text_encoder_3, prompt=prompt, max_sequence_length=max_sequence_length, device=device
    )

    # Pad the CLIP embeddings to the T5 hidden size, then concatenate along the sequence axis (SD3 convention).
    clip_prompt_embeds = torch.nn.functional.pad(
        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
    )
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
    pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)

    return prompt_embeds, pooled_prompt_embeds


class TeEFusionSD3Pipeline(DiffusionPipeline, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        transformer: nn.Module,
        text_encoder: CLIPTextModelWithProjection,
        text_encoder_2: CLIPTextModelWithProjection,
        text_encoder_3: T5EncoderModel,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        tokenizer_3: T5Tokenizer,
        vae: AutoencoderKL,
        scheduler: Any,
    ):
        super().__init__()
        self.register_modules(
            transformer=transformer,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            text_encoder_3=text_encoder_3,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            tokenizer_3=tokenizer_3,
            vae=vae,
            scheduler=scheduler,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ) -> "TeEFusionSD3Pipeline":
        return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
        super().save_pretrained(save_directory)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        latents: torch.FloatTensor = None,
        height: int = 1024,
        width: int = 1024,
        seed: int = 0,
    ):
        if isinstance(prompt, str):
            prompt = [prompt]

        device = self.transformer.device
        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]

        prompt_embeds, pooled_prompt_embeds = encode_text(
            clip_tokenizers, clip_text_encoders, self.tokenizer_3, self.text_encoder_3, prompt, device
        )
        # Only the pooled embedding of the empty prompt is needed: the guidance direction is passed
        # to the transformer below, so no separate unconditional forward pass is run per step.
        _, negative_pooled_prompt_embeds = encode_text(
            clip_tokenizers, clip_text_encoders, self.tokenizer_3, self.text_encoder_3, [''], device
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        bs = len(prompt)
        channels = self.transformer.config.in_channels
        height = 16 * (height // 16)
        width = 16 * (width // 16)

        # prepare input
        if latents is None:
            latents = get_noise(
                bs,
                channels,
                height,
                width,
                device=device,
                dtype=self.transformer.dtype,
                seed=seed,
            )

        for i, t in enumerate(timesteps):
            noise_pred = self.transformer(
                hidden_states=latents,
                timestep=t.reshape(1),
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                return_dict=False,
                txt_align_guidance=torch.tensor(
                    data=(guidance_scale,), dtype=self.transformer.dtype, device=self.transformer.device
                ) * 1000.,
                txt_align_vec=pooled_prompt_embeds - negative_pooled_prompt_embeds,
            )[0]
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        x = latents.float()
        with torch.no_grad():
            with torch.autocast(device_type=device.type, dtype=torch.float32):
                # Undo the latent scaling/shift applied by the SD3 VAE before decoding.
                if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor is not None:
                    x = x / self.vae.config.scaling_factor
                if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor is not None:
                    x = x + self.vae.config.shift_factor
                x = self.vae.decode(x, return_dict=False)[0]

        # bring into PIL format
        x = (x / 2 + 0.5).clamp(0, 1)
        x = x.cpu().permute(0, 2, 3, 1).float().numpy()
        images = (x * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]

        return pil_images
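

# --- Usage sketch (illustrative only, not shipped with the pipeline) ---
# A minimal example of loading the pipeline and generating one image. The checkpoint path,
# dtype, and output filename below are placeholders/assumptions; substitute a real directory
# saved with `TeEFusionSD3Pipeline.save_pretrained`.
if __name__ == "__main__":
    pipe = TeEFusionSD3Pipeline.from_pretrained(
        "/path/to/teefusion_sd3_checkpoint",  # placeholder checkpoint path
        torch_dtype=torch.float16,
    ).to("cuda")

    images = pipe(
        prompt="a watercolor painting of a lighthouse at sunset",
        num_inference_steps=28,
        guidance_scale=7.5,
        height=1024,
        width=1024,
        seed=0,
    )
    images[0].save("sample.png")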