# IC-Custom: ic_custom/pipelines/ic_custom_pipeline.py
import os
import re
from typing import List, Optional, Union

import numpy as np
import PIL
import torch
from diffusers.image_processor import VaeImageProcessor
from einops import rearrange
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import Tensor

from ..modules.layers import (
SingleStreamBlockLoraProcessor,
DoubleStreamBlockLoraProcessor,
)
from ..pipelines.sampling import denoise, prepare_image_cond, get_noise, get_schedule, prepare, prepare_with_redux, unpack
from ..utils.model_utils import (
load_ae,
load_clip,
load_ic_custom,
load_t5,
load_redux,
resolve_model_path
)
PipelineImageInput = Union[
PIL.Image.Image,
np.ndarray,
torch.Tensor,
List[PIL.Image.Image],
List[np.ndarray],
List[torch.Tensor],
]
class ICCustomPipeline:
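    """Image-customization pipeline built on FLUX.1 Fill dev.

    Combines T5/CLIP text conditioning with a Redux (SigLIP) image encoder and a DiT
    patched with IC-Custom LoRA weights, boundary embeddings, and task register
    embeddings. Generates target images conditioned on a reference image, a target
    canvas, and a target mask.
    """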
def __init__(
self,
clip_path: str = "clip-vit-large-patch14",
t5_path: str = "t5-v1_1-xxl",
siglip_path: str = "siglip-so400m-patch14-384",
ae_path: str = "flux-fill-dev-ae",
dit_path: str = "flux-fill-dev-dit",
redux_path: str = "flux1-redux-dev",
lora_path: str = "dit_lora_0x1561",
img_txt_in_path: str = "dit_txt_img_in_0x1561",
boundary_embeddings_path: str = "dit_boundary_embeddings_0x1561",
task_register_embeddings_path: str = "dit_task_register_embeddings_0x1561",
        network_alpha: Optional[int] = None,
        double_blocks_idx: Optional[str] = None,
        single_blocks_idx: Optional[str] = None,
device: torch.device = torch.device("cuda"),
offload: bool = False,
weight_dtype: torch.dtype = torch.bfloat16,
show_progress: bool = False,
use_flash_attention: bool = False,
):
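        """Build the full pipeline: text encoders (CLIP, T5), the SigLIP-based Redux
        image encoder, the FLUX autoencoder and DiT, plus the IC-Custom LoRA and
        learned embeddings. With `offload=True`, weights stay on CPU and are moved
        to the device only while they are needed.
        """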
self.device = device
self.offload = offload
self.weight_dtype = weight_dtype
self.clip = load_clip(clip_path, self.device if not offload else "cpu", dtype=self.weight_dtype).eval()
self.t5 = load_t5(t5_path, self.device if not offload else "cpu", max_length=512, dtype=self.weight_dtype).eval()
self.ae = load_ae(ae_path, device="cpu" if offload else self.device).eval()
self.model = load_ic_custom(dit_path, device="cpu" if offload else self.device, dtype=self.weight_dtype).eval()
self.image_encoder = load_redux(redux_path, siglip_path, device="cpu" if offload else self.device, dtype=self.weight_dtype).eval()
self.vae_scale_factor = 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.mask_processor = VaeImageProcessor(resample="nearest", do_normalize=False)
self.set_lora(lora_path, network_alpha, double_blocks_idx, single_blocks_idx)
self.set_img_txt_in(img_txt_in_path)
self.set_boundary_embeddings(boundary_embeddings_path)
self.set_task_register_embeddings(task_register_embeddings_path)
self.show_progress = show_progress
self.use_flash_attention = use_flash_attention
def set_show_progress(self, show_progress: bool):
self.show_progress = show_progress
def set_use_flash_attention(self, use_flash_attention: bool):
self.use_flash_attention = use_flash_attention
def set_pipeline_offload(self, offload: bool):
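        """Move every submodule to CPU (`offload=True`) or back to the configured device."""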
self.ae = self.ae.to("cpu" if offload else self.device)
self.model = self.model.to("cpu" if offload else self.device)
self.image_encoder = self.image_encoder.to("cpu" if offload else self.device)
self.clip = self.clip.to("cpu" if offload else self.device)
self.t5 = self.t5.to("cpu" if offload else self.device)
self.offload = offload
def set_pipeline_gradient_checkpointing(self, enable: bool):
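        """Recursively toggle gradient checkpointing on the DiT and all of its submodules."""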
def _recursive_set_gradient_checkpointing(module):
self.model._set_gradient_checkpointing(module, enable)
for child in module.children():
_recursive_set_gradient_checkpointing(child)
_recursive_set_gradient_checkpointing(self.model)
    def get_lora_rank(self, weights):
        """Infer the LoRA rank from a state dict: `down` projection weights have shape (rank, in_features)."""
        for k in weights.keys():
            if k.endswith(".down.weight"):
                return weights[k].shape[0]
        raise ValueError("Could not infer LoRA rank: no key ending in '.down.weight' was found.")
def load_model_weights(self, weights: dict, strict: bool = False):
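        """Load a (partial) state dict into the DiT; every key must exist in the model."""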
model_state_dict = self.model.state_dict()
update_dict = {k: v for k, v in weights.items() if k in model_state_dict}
        unexpected_keys = [k for k in weights if k not in model_state_dict]
        assert len(unexpected_keys) == 0, f"Some keys in the checkpoint are not found in the model: {unexpected_keys}"
self.model.load_state_dict(update_dict, strict=strict)
def set_lora(
self,
        lora_path: Optional[str] = None,
        network_alpha: Optional[int] = None,
        double_blocks_idx: Optional[str] = None,
        single_blocks_idx: Optional[str] = None,
):
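        """Attach LoRA attention processors to the DiT and load the LoRA weights,
        falling back to the default checkpoint name (and a Hugging Face Hub download
        via `resolve_model_path`) when the given path does not exist locally.
        """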
        if lora_path is None or not os.path.exists(lora_path):
            lora_path = "dit_lora_0x1561"
lora_path = resolve_model_path(
name=lora_path,
repo_id_field="repo_id",
filename_field="filename",
ckpt_path_field="ckpt_path",
hf_download=True,
)
weights = load_sft(lora_path)
self.update_model_with_lora(weights, network_alpha, double_blocks_idx, single_blocks_idx)
def update_model_with_lora(
self,
weights,
network_alpha,
double_blocks_idx,
single_blocks_idx,
):
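        """Swap in LoRA attention processors for the selected blocks and load their weights.

        `double_blocks_idx` / `single_blocks_idx` are comma-separated index strings; by
        default LoRA is applied to all 38 single-stream blocks and to none of the
        double-stream blocks.
        """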
rank = self.get_lora_rank(weights)
network_alpha = network_alpha if network_alpha is not None else rank
lora_attn_procs = {}
if double_blocks_idx is None:
double_blocks_idx = []
else:
double_blocks_idx = [int(idx) for idx in double_blocks_idx.split(",")]
if single_blocks_idx is None:
single_blocks_idx = list(range(38))
else:
single_blocks_idx = [int(idx) for idx in single_blocks_idx.split(",")]
for name, attn_processor in self.model.attn_processors.items():
match = re.search(r'\.(\d+)\.', name)
if match:
layer_index = int(match.group(1))
if name.startswith("double_blocks") and layer_index in double_blocks_idx:
lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
dim=3072, rank=rank, network_alpha=network_alpha
)
elif name.startswith("single_blocks") and layer_index in single_blocks_idx:
lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
dim=3072, rank=rank, network_alpha=network_alpha
)
else:
lora_attn_procs[name] = attn_processor
self.model.set_attn_processor(lora_attn_procs)
self.load_model_weights(weights, strict=False)
def set_img_txt_in(self, img_txt_in_path: str):
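        """Load the patched img_in/txt_in input-projection weights into the DiT."""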
if not os.path.exists(img_txt_in_path):
img_txt_in_path = "dit_txt_img_in_0x1561"
img_txt_in_path = resolve_model_path(
name=img_txt_in_path,
repo_id_field="repo_id",
filename_field="filename",
ckpt_path_field="ckpt_path",
hf_download=True,
)
weights = load_sft(img_txt_in_path)
self.load_model_weights(weights, strict=False)
def set_boundary_embeddings(self, boundary_embeddings_path: str):
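        """Load the learned boundary embedding weights into the DiT."""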
if not os.path.exists(boundary_embeddings_path):
boundary_embeddings_path = "dit_boundary_embeddings_0x1561"
boundary_embeddings_path = resolve_model_path(
name=boundary_embeddings_path,
repo_id_field="repo_id",
filename_field="filename",
ckpt_path_field="ckpt_path",
hf_download=True,
)
weights = load_sft(boundary_embeddings_path)
self.load_model_weights(weights, strict=False)
def set_task_register_embeddings(self, task_register_embeddings_path: str):
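        """Load the learned task register embedding weights into the DiT."""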
if not os.path.exists(task_register_embeddings_path):
task_register_embeddings_path = "dit_task_register_embeddings_0x1561"
task_register_embeddings_path = resolve_model_path(
name=task_register_embeddings_path,
repo_id_field="repo_id",
filename_field="filename",
ckpt_path_field="ckpt_path",
hf_download=True,
)
weights = load_sft(task_register_embeddings_path)
self.load_model_weights(weights, strict=False)
def offload_model_to_cpu(self, *models):
for model in models:
if model is not None:
model.to("cpu")
def prepare_image(
self,
image,
device,
dtype,
        width: Optional[int] = None,
        height: Optional[int] = None,
):
        if not isinstance(image, torch.Tensor):
            image = self.image_processor.preprocess(image, height=height, width=width)
image = image.to(device=device, dtype=dtype)
return image
def prepare_mask(
self,
mask,
device,
dtype,
        width: Optional[int] = None,
        height: Optional[int] = None,
):
        if not isinstance(mask, torch.Tensor):
            mask = self.mask_processor.preprocess(mask, height=height, width=width)
mask = mask.to(device=device, dtype=dtype)
return mask
def __call__(
self,
prompt: Union[str, List[str], None],
width: int = 512,
height: int = 512,
guidance: float = 4,
num_steps: int = 50,
seed: int = 123456789,
true_gs: float = 1,
        neg_prompt: Optional[Union[str, List[str]]] = None,
timestep_to_start_cfg: int = 0,
img_ref: Optional[PipelineImageInput] = None,
img_target: Optional[PipelineImageInput] = None,
mask_target: Optional[PipelineImageInput] = None,
img_ip: Optional[PipelineImageInput] = None,
cond_w_regions: Optional[Union[List[int], int]] = None,
mask_type_ids: Optional[Union[Tensor, int]] = None,
use_background_preservation: bool = False,
use_progressive_background_preservation: bool = True,
background_blend_threshold: float = 0.8,
num_images_per_prompt: int = 1,
gradio_progress=None,
):
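        """Generate the customized target image(s).

        `width`/`height` are rounded down to multiples of 16 to fit the packed latent
        grid. True classifier-free guidance is applied when `true_gs` > 1 and a
        `neg_prompt` is given, starting at `timestep_to_start_cfg`. Returns a list of
        PIL images cropped to the target region.
        """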
width = 16 * (width // 16)
height = 16 * (height // 16)
        batch_size = len(prompt) if isinstance(prompt, list) else 1
img_ref = self.prepare_image(
img_ref,
self.device,
self.weight_dtype,
)
img_target = self.prepare_image(
img_target,
self.device,
self.weight_dtype,
)
mask_target = self.prepare_mask(
mask_target,
self.device,
self.weight_dtype,
)
        if num_images_per_prompt > 1 and torch.is_tensor(mask_type_ids):
            mask_type_ids = mask_type_ids.repeat_interleave(num_images_per_prompt, dim=0)
return self.forward(
batch_size,
num_images_per_prompt,
prompt,
width,
height,
guidance,
num_steps,
seed,
timestep_to_start_cfg=timestep_to_start_cfg,
true_gs=true_gs,
neg_prompt=neg_prompt,
img_ref=img_ref,
img_target=img_target,
mask_target=mask_target,
img_ip=img_ip,
cond_w_regions=cond_w_regions,
mask_type_ids=mask_type_ids,
use_background_preservation=use_background_preservation,
use_progressive_background_preservation=use_progressive_background_preservation,
background_blend_threshold=background_blend_threshold,
gradio_progress=gradio_progress,
)
def forward(
self,
batch_size,
num_images_per_prompt,
prompt,
width,
height,
guidance,
num_steps,
seed,
timestep_to_start_cfg,
true_gs,
neg_prompt,
img_ref,
img_target,
mask_target,
img_ip,
cond_w_regions,
mask_type_ids,
use_background_preservation,
use_progressive_background_preservation,
background_blend_threshold,
gradio_progress=None,
):
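        """Run the sampling loop: encode prompts (with Redux image conditioning when
        the image encoder is available), prepare the latent image/mask conditioning,
        denoise, decode with the VAE, and crop the target region from the decoded canvas.
        """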
has_neg_prompt = neg_prompt is not None
do_true_cfg = true_gs > 1 and has_neg_prompt
x = get_noise(
batch_size * num_images_per_prompt, height, width, device=self.device,
dtype=self.weight_dtype, seed=seed
)
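        # Latents are packed into 2x2 patches: with a VAE downsampling factor of 8,
        # the token sequence length is (height / 16) * (width / 16).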
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
timesteps = get_schedule(
num_steps,
image_seq_len,
shift=True,
)
with torch.no_grad():
            self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
            if self.image_encoder is not None:
                self.image_encoder = self.image_encoder.to(self.device)
if self.image_encoder is not None:
inp_cond = prepare_with_redux(t5=self.t5, clip=self.clip, image_encoder=self.image_encoder, img=x, img_ip=img_ip, prompt=prompt, num_images_per_prompt=num_images_per_prompt)
else:
inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt, num_images_per_prompt=num_images_per_prompt)
neg_inp_cond = None
if do_true_cfg:
if self.image_encoder is not None:
neg_inp_cond = prepare_with_redux(t5=self.t5, clip=self.clip, image_encoder=self.image_encoder, img=x, img_ip=img_ip, prompt=neg_prompt, num_images_per_prompt=num_images_per_prompt)
else:
neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt, num_images_per_prompt=num_images_per_prompt)
if self.offload:
self.offload_model_to_cpu(self.t5, self.clip, self.image_encoder)
self.model = self.model.to(self.device)
self.ae.encoder = self.ae.encoder.to(self.device)
inp_img_cond = prepare_image_cond(
ae=self.ae,
img_ref=img_ref,
img_target=img_target,
mask_target=mask_target,
dtype=self.weight_dtype,
device=self.device,
num_images_per_prompt=num_images_per_prompt,
)
x = denoise(
self.model,
img=inp_cond['img'],
img_ids=inp_cond['img_ids'],
txt=inp_cond['txt'],
txt_ids=inp_cond['txt_ids'],
txt_vec=inp_cond['txt_vec'],
timesteps=timesteps,
guidance=guidance,
img_cond=inp_img_cond['img_cond'],
mask_cond=inp_img_cond['mask_cond'],
img_latent=inp_img_cond['img_latent'],
cond_w_regions=cond_w_regions,
mask_type_ids=mask_type_ids,
height=height,
width=width,
use_background_preservation=use_background_preservation,
use_progressive_background_preservation=use_progressive_background_preservation,
background_blend_threshold=background_blend_threshold,
true_gs=true_gs,
timestep_to_start_cfg=timestep_to_start_cfg,
neg_txt=neg_inp_cond['txt'] if neg_inp_cond is not None else None,
neg_txt_ids=neg_inp_cond['txt_ids'] if neg_inp_cond is not None else None,
neg_txt_vec=neg_inp_cond['txt_vec'] if neg_inp_cond is not None else None,
show_progress=self.show_progress,
use_flash_attention=self.use_flash_attention,
gradio_progress=gradio_progress,
)
if self.offload:
self.offload_model_to_cpu(self.model, self.ae.encoder)
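        # Unpack the token sequence back into a spatial latent and decode it with the VAE.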
x = unpack(x.float(), height, width)
self.ae.decoder = self.ae.decoder.to(x.device)
x = self.ae.decode(x)
if self.offload:
self.offload_model_to_cpu(self.ae.decoder)
        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1, "b c h w -> b h w c")
        # The target occupies the bottom-right corner of the decoded canvas, so crop
        # each output back to the target image's size.
        img_target_height, img_target_width = img_target.shape[2], img_target.shape[3]
        output_imgs_target = []
        for i in range(x1.shape[0]):
            # Map [-1, 1] floats to [0, 255] uint8 for PIL.
            output_img = Image.fromarray((127.5 * (x1[i] + 1.0)).cpu().byte().numpy())
            output_img_target = output_img.crop((
                output_img.width - img_target_width,
                output_img.height - img_target_height,
                output_img.width,
                output_img.height,
            ))
            output_imgs_target.append(output_img_target)
return output_imgs_target
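

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the pipeline). It assumes the
# default checkpoint names resolve through `resolve_model_path`, that the image
# paths below exist (they are hypothetical placeholders), and that the canvas
# spans reference + target side by side; reusing the reference as `img_ip` is
# likewise an assumption.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipeline = ICCustomPipeline(device=torch.device("cuda"), offload=True, show_progress=True)
    img_ref = Image.open("reference.png").convert("RGB")      # hypothetical input
    img_target = Image.open("target.png").convert("RGB")      # hypothetical input
    mask_target = Image.open("target_mask.png").convert("L")  # hypothetical input
    images = pipeline(
        prompt="a photo of the subject placed in the masked region",
        width=img_ref.width + img_target.width,  # assumption: side-by-side canvas
        height=img_target.height,
        num_steps=25,
        seed=42,
        img_ref=img_ref,
        img_target=img_target,
        mask_target=mask_target,
        img_ip=img_ref,  # assumption: reference doubles as the Redux conditioning image
        cond_w_regions=[img_ref.width],
        mask_type_ids=torch.tensor([0]),
    )
    images[0].save("result.png")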