TeEFusion/pipelines/sd3_teefusion_pipeline.py
# Copyright (C) 2025 AIDC-AI
# This project is licensed under the Attribution-NonCommercial 4.0 International
# License (SPDX-License-Identifier: CC-BY-NC-4.0).
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
from PIL import Image

from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import logging
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer

logger = logging.get_logger(__name__)


def get_noise(
num_samples: int,
channel: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
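    """Sample seeded Gaussian noise in the latent space (spatial size is the image size divided by 8)."""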
return torch.randn(
num_samples,
channel,
height // 8,
width // 8,
device=device,
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
def get_clip_prompt_embeds(
clip_tokenizers,
clip_text_encoders,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
clip_skip: Optional[int] = None,
clip_model_index: int = 0,
):
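    """Encode `prompt` with the CLIP text encoder selected by `clip_model_index`, returning the
    penultimate-layer (or `clip_skip`-adjusted) token embeddings and the pooled embedding."""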
tokenizer_max_length = 77
tokenizer = clip_tokenizers[clip_model_index]
text_encoder = clip_text_encoders[clip_model_index]
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {tokenizer_max_length} tokens: {removed_text}"
        )
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
def get_t5_prompt_embeds(
tokenizer_3,
text_encoder_3,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
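    """Encode `prompt` with the T5 encoder (up to `max_sequence_length` tokens) and tile the result
    for `num_images_per_prompt`."""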
tokenizer_max_length = 77
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
text_inputs = tokenizer_3(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer_3.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to"
            f" {max_sequence_length} tokens: {removed_text}"
        )
prompt_embeds = text_encoder_3(text_input_ids.to(device))[0]
dtype = text_encoder_3.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using an mps-friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
@torch.no_grad()
def encode_text(clip_tokenizers, clip_text_encoders, tokenizer_3, text_encoder_3, prompt, device, max_sequence_length=256):
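    """Build SD3 conditioning: CLIP-L/CLIP-G token embeddings are concatenated, zero-padded to the T5
    hidden size and stacked with the T5 token embeddings along the sequence axis; the pooled CLIP
    embeddings are concatenated and returned separately."""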
prompt_embed, pooled_prompt_embed = get_clip_prompt_embeds(clip_tokenizers, clip_text_encoders, prompt=prompt, device=device, clip_model_index=0)
prompt_2_embed, pooled_prompt_2_embed = get_clip_prompt_embeds(clip_tokenizers, clip_text_encoders, prompt=prompt, device=device, clip_model_index=1)
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
t5_prompt_embed = get_t5_prompt_embeds(tokenizer_3, text_encoder_3, prompt=prompt, max_sequence_length=max_sequence_length, device=device)
clip_prompt_embeds = torch.nn.functional.pad(clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]))
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
return prompt_embeds, pooled_prompt_embeds
class TeEFusionSD3Pipeline(DiffusionPipeline, ConfigMixin):
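    """Stable Diffusion 3 pipeline for TeEFusion-style sampling.

    The guidance signal is consumed directly by the transformer (via `txt_align_guidance` and
    `txt_align_vec`) instead of being applied through a second, unconditional forward pass as in
    classifier-free guidance.
    """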
@register_to_config
def __init__(
self,
transformer: nn.Module,
text_encoder: CLIPTextModelWithProjection,
text_encoder_2: CLIPTextModelWithProjection,
text_encoder_3: T5EncoderModel,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
tokenizer_3: T5Tokenizer,
vae: AutoencoderKL,
scheduler: Any
):
super().__init__()
self.register_modules(
transformer=transformer,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
text_encoder_3=text_encoder_3,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
tokenizer_3=tokenizer_3,
vae=vae,
scheduler=scheduler
)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs,
) -> "TeEFusionSD3Pipeline":
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
super().save_pretrained(save_directory)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
        latents: Optional[torch.Tensor] = None,
height: int = 1024,
width: int = 1024,
seed: int = 0,
):
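        """Generate images for `prompt`.

        Guidance is applied inside the transformer through the text-alignment inputs rather than by
        running a separate unconditional pass. Returns a list of PIL images.
        """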
if isinstance(prompt, str):
prompt = [prompt]
device = self.transformer.device
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
prompt_embeds, pooled_prompt_embeds = encode_text(clip_tokenizers, clip_text_encoders, self.tokenizer_3, self.text_encoder_3, prompt, device)
_, negative_pooled_prompt_embeds = encode_text(clip_tokenizers, clip_text_encoders, self.tokenizer_3, self.text_encoder_3, [''], device)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
bs = len(prompt)
channels = self.transformer.config.in_channels
height = 16 * (height // 16)
width = 16 * (width // 16)
# prepare input
if latents is None:
latents = get_noise(
bs,
channels,
height,
width,
device=device,
dtype=self.transformer.dtype,
seed=seed,
)
for i, t in enumerate(timesteps):
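            # TeEFusion: the guidance scale (rescaled by 1000, presumably to match the 0-1000 timestep
            # scale) and the difference between the conditional and unconditional pooled embeddings are
            # fed straight to the transformer, so no extra classifier-free-guidance pass is run.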
noise_pred = self.transformer(
hidden_states=latents,
timestep=t.reshape(1),
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
return_dict=False,
txt_align_guidance=torch.tensor(data=(guidance_scale,), dtype=self.transformer.dtype, device=self.transformer.device) * 1000.,
txt_align_vec=pooled_prompt_embeds - negative_pooled_prompt_embeds
)[0]
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
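        # Decode the final latents back to pixel space with the VAE (cast to float32 for stability).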
x = latents.float()
        with torch.autocast(device_type=device.type, dtype=torch.float32):
            if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor is not None:
                x = x / self.vae.config.scaling_factor
            if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor is not None:
                x = x + self.vae.config.shift_factor
            x = self.vae.decode(x, return_dict=False)[0]
        # convert to PIL images
x = (x / 2 + 0.5).clamp(0, 1)
x = x.cpu().permute(0, 2, 3, 1).float().numpy()
images = (x * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
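

# --- Usage sketch (illustrative only, not part of the pipeline) ---
# A minimal, hedged example of how this pipeline might be driven. The checkpoint path, dtype and
# device below are placeholder assumptions; `from_pretrained` only resolves the custom transformer
# class for a checkpoint that was saved with `TeEFusionSD3Pipeline.save_pretrained`.
if __name__ == "__main__":
    pipe = TeEFusionSD3Pipeline.from_pretrained(
        "path/to/teefusion-sd3-checkpoint",  # hypothetical local path or hub id
        torch_dtype=torch.bfloat16,          # assumed dtype; use float32 on CPU
    )
    pipe = pipe.to("cuda")                   # assumes a CUDA device is available

    images = pipe(
        "a photo of an astronaut riding a horse on mars",
        num_inference_steps=28,
        guidance_scale=4.5,
        height=1024,
        width=1024,
        seed=0,
    )
    images[0].save("sample.png")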