import os
import re
from typing import List, Optional, Union

import PIL
from PIL import Image
from einops import rearrange
from torch import Tensor
import numpy as np
import torch
from safetensors.torch import load_file as load_sft
from diffusers.image_processor import VaeImageProcessor

from ..modules.layers import (
    SingleStreamBlockLoraProcessor,
    DoubleStreamBlockLoraProcessor,
)
from ..pipelines.sampling import denoise, prepare_image_cond, get_noise, get_schedule, prepare, prepare_with_redux, unpack
from ..utils.model_utils import (
    load_ae,
    load_clip,
    load_ic_custom,
    load_t5,
    load_redux,
    resolve_model_path,
)

PipelineImageInput = Union[
    PIL.Image.Image,
    np.ndarray,
    torch.Tensor,
    List[PIL.Image.Image],
    List[np.ndarray],
    List[torch.Tensor],
]


class ICCustomPipeline:
    """Image-customization pipeline combining a FLUX Fill DiT with CLIP/T5 text encoders,
    a Redux (SigLIP) image encoder, and IC-Custom LoRA / embedding weights."""

    def __init__(
        self,
        clip_path: str = "clip-vit-large-patch14",
        t5_path: str = "t5-v1_1-xxl",
        siglip_path: str = "siglip-so400m-patch14-384",
        ae_path: str = "flux-fill-dev-ae",
        dit_path: str = "flux-fill-dev-dit",
        redux_path: str = "flux1-redux-dev",
        lora_path: str = "dit_lora_0x1561",
        img_txt_in_path: str = "dit_txt_img_in_0x1561",
        boundary_embeddings_path: str = "dit_boundary_embeddings_0x1561",
        task_register_embeddings_path: str = "dit_task_register_embeddings_0x1561",
        network_alpha: Optional[int] = None,
        double_blocks_idx: Optional[str] = None,
        single_blocks_idx: Optional[str] = None,
        device: torch.device = torch.device("cuda"),
        offload: bool = False,
        weight_dtype: torch.dtype = torch.bfloat16,
        show_progress: bool = False,
        use_flash_attention: bool = False,
    ):
        self.device = device
        self.offload = offload
        self.weight_dtype = weight_dtype

        # When offloading is enabled, submodels start on CPU and are moved to the
        # target device only while they are actually needed.
        self.clip = load_clip(clip_path, self.device if not offload else "cpu", dtype=self.weight_dtype).eval()
        self.t5 = load_t5(t5_path, self.device if not offload else "cpu", max_length=512, dtype=self.weight_dtype).eval()
        self.ae = load_ae(ae_path, device="cpu" if offload else self.device).eval()
        self.model = load_ic_custom(dit_path, device="cpu" if offload else self.device, dtype=self.weight_dtype).eval()
        self.image_encoder = load_redux(redux_path, siglip_path, device="cpu" if offload else self.device, dtype=self.weight_dtype).eval()

        self.vae_scale_factor = 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.mask_processor = VaeImageProcessor(resample="nearest", do_normalize=False)

        # Load the IC-Custom specific weights on top of the base DiT.
        self.set_lora(lora_path, network_alpha, double_blocks_idx, single_blocks_idx)
        self.set_img_txt_in(img_txt_in_path)
        self.set_boundary_embeddings(boundary_embeddings_path)
        self.set_task_register_embeddings(task_register_embeddings_path)

        self.show_progress = show_progress
        self.use_flash_attention = use_flash_attention

    def set_show_progress(self, show_progress: bool):
        self.show_progress = show_progress

    def set_use_flash_attention(self, use_flash_attention: bool):
        self.use_flash_attention = use_flash_attention

    def set_pipeline_offload(self, offload: bool):
        self.ae = self.ae.to("cpu" if offload else self.device)
        self.model = self.model.to("cpu" if offload else self.device)
        self.image_encoder = self.image_encoder.to("cpu" if offload else self.device)
        self.clip = self.clip.to("cpu" if offload else self.device)
        self.t5 = self.t5.to("cpu" if offload else self.device)
        self.offload = offload

    def set_pipeline_gradient_checkpointing(self, enable: bool):
        def _recursive_set_gradient_checkpointing(module):
            self.model._set_gradient_checkpointing(module, enable)
            for child in module.children():
                _recursive_set_gradient_checkpointing(child)

        _recursive_set_gradient_checkpointing(self.model)

    def get_lora_rank(self, weights):
        # Infer the LoRA rank from the shape of any down-projection weight.
        for k in weights.keys():
            if k.endswith(".down.weight"):
                return weights[k].shape[0]

    def load_model_weights(self, weights: dict, strict: bool = False):
        model_state_dict = self.model.state_dict()
        update_dict = {k: v for k, v in weights.items() if k in model_state_dict}
        unexpected_keys = [k for k in weights if k not in model_state_dict]
        assert len(unexpected_keys) == 0, f"Some keys in the file are not found in the model: {unexpected_keys}"
        self.model.load_state_dict(update_dict, strict=strict)

    def set_lora(
        self,
        lora_path: Optional[str] = None,
        network_alpha: Optional[int] = None,
        double_blocks_idx: Optional[str] = None,
        single_blocks_idx: Optional[str] = None,
    ):
        # Fall back to the default checkpoint name if the given path does not exist locally.
        if not os.path.exists(lora_path):
            lora_path = "dit_lora_0x1561"
        lora_path = resolve_model_path(
            name=lora_path,
            repo_id_field="repo_id",
            filename_field="filename",
            ckpt_path_field="ckpt_path",
            hf_download=True,
        )
        weights = load_sft(lora_path)
        self.update_model_with_lora(weights, network_alpha, double_blocks_idx, single_blocks_idx)

    def update_model_with_lora(
        self,
        weights,
        network_alpha,
        double_blocks_idx,
        single_blocks_idx,
    ):
        rank = self.get_lora_rank(weights)
        network_alpha = network_alpha if network_alpha is not None else rank

        lora_attn_procs = {}
        # Block indices are given as comma-separated strings, e.g. "0,1,2".
        if double_blocks_idx is None:
            double_blocks_idx = []
        else:
            double_blocks_idx = [int(idx) for idx in double_blocks_idx.split(",")]
        if single_blocks_idx is None:
            single_blocks_idx = list(range(38))
        else:
            single_blocks_idx = [int(idx) for idx in single_blocks_idx.split(",")]

        for name, attn_processor in self.model.attn_processors.items():
            match = re.search(r'\.(\d+)\.', name)
            if match:
                layer_index = int(match.group(1))
                if name.startswith("double_blocks") and layer_index in double_blocks_idx:
                    lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(
                        dim=3072, rank=rank, network_alpha=network_alpha
                    )
                elif name.startswith("single_blocks") and layer_index in single_blocks_idx:
                    lora_attn_procs[name] = SingleStreamBlockLoraProcessor(
                        dim=3072, rank=rank, network_alpha=network_alpha
                    )
                else:
                    lora_attn_procs[name] = attn_processor

        self.model.set_attn_processor(lora_attn_procs)
        self.load_model_weights(weights, strict=False)

    def set_img_txt_in(self, img_txt_in_path: str):
        if not os.path.exists(img_txt_in_path):
            img_txt_in_path = "dit_txt_img_in_0x1561"
        img_txt_in_path = resolve_model_path(
            name=img_txt_in_path,
            repo_id_field="repo_id",
            filename_field="filename",
            ckpt_path_field="ckpt_path",
            hf_download=True,
        )
        weights = load_sft(img_txt_in_path)
        self.load_model_weights(weights, strict=False)

    def set_boundary_embeddings(self, boundary_embeddings_path: str):
        if not os.path.exists(boundary_embeddings_path):
            boundary_embeddings_path = "dit_boundary_embeddings_0x1561"
        boundary_embeddings_path = resolve_model_path(
            name=boundary_embeddings_path,
            repo_id_field="repo_id",
            filename_field="filename",
            ckpt_path_field="ckpt_path",
            hf_download=True,
        )
        weights = load_sft(boundary_embeddings_path)
        self.load_model_weights(weights, strict=False)

    def set_task_register_embeddings(self, task_register_embeddings_path: str):
        if not os.path.exists(task_register_embeddings_path):
            task_register_embeddings_path = "dit_task_register_embeddings_0x1561"
        task_register_embeddings_path = resolve_model_path(
            name=task_register_embeddings_path,
            repo_id_field="repo_id",
            filename_field="filename",
            ckpt_path_field="ckpt_path",
            hf_download=True,
        )
        weights = load_sft(task_register_embeddings_path)
        self.load_model_weights(weights, strict=False)

    def offload_model_to_cpu(self, *models):
        for model in models:
            if model is not None:
                model.to("cpu")
    def prepare_image(
        self,
        image,
        device,
        dtype,
        width=None,
        height=None,
    ):
        # Tensors are assumed to be preprocessed already; PIL images and arrays
        # go through the VAE image processor.
        if not isinstance(image, torch.Tensor):
            image = self.image_processor.preprocess(image, height=height, width=width)
        image = image.to(device=device, dtype=dtype)
        return image

    def prepare_mask(
        self,
        mask,
        device,
        dtype,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ):
        if not isinstance(mask, torch.Tensor):
            mask = self.mask_processor.preprocess(mask, height=height, width=width)
        mask = mask.to(device=device, dtype=dtype)
        return mask

    def __call__(
        self,
        prompt: Union[str, List[str], None],
        width: int = 512,
        height: int = 512,
        guidance: float = 4,
        num_steps: int = 50,
        seed: int = 123456789,
        true_gs: float = 1,
        neg_prompt: Optional[Union[str, List[str]]] = None,
        timestep_to_start_cfg: int = 0,
        img_ref: Optional[PipelineImageInput] = None,
        img_target: Optional[PipelineImageInput] = None,
        mask_target: Optional[PipelineImageInput] = None,
        img_ip: Optional[PipelineImageInput] = None,
        cond_w_regions: Optional[Union[List[int], int]] = None,
        mask_type_ids: Optional[Union[Tensor, int]] = None,
        use_background_preservation: bool = False,
        use_progressive_background_preservation: bool = True,
        background_blend_threshold: float = 0.8,
        num_images_per_prompt: int = 1,
        gradio_progress=None,
    ):
        # Snap the output resolution to a multiple of 16.
        width = 16 * (width // 16)
        height = 16 * (height // 16)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = 1

        img_ref = self.prepare_image(img_ref, self.device, self.weight_dtype)
        img_target = self.prepare_image(img_target, self.device, self.weight_dtype)
        mask_target = self.prepare_mask(mask_target, self.device, self.weight_dtype)

        if num_images_per_prompt > 1:
            mask_type_ids = mask_type_ids.repeat_interleave(num_images_per_prompt, dim=0)

        return self.forward(
            batch_size,
            num_images_per_prompt,
            prompt,
            width,
            height,
            guidance,
            num_steps,
            seed,
            timestep_to_start_cfg=timestep_to_start_cfg,
            true_gs=true_gs,
            neg_prompt=neg_prompt,
            img_ref=img_ref,
            img_target=img_target,
            mask_target=mask_target,
            img_ip=img_ip,
            cond_w_regions=cond_w_regions,
            mask_type_ids=mask_type_ids,
            use_background_preservation=use_background_preservation,
            use_progressive_background_preservation=use_progressive_background_preservation,
            background_blend_threshold=background_blend_threshold,
            gradio_progress=gradio_progress,
        )

    def forward(
        self,
        batch_size,
        num_images_per_prompt,
        prompt,
        width,
        height,
        guidance,
        num_steps,
        seed,
        timestep_to_start_cfg,
        true_gs,
        neg_prompt,
        img_ref,
        img_target,
        mask_target,
        img_ip,
        cond_w_regions,
        mask_type_ids,
        use_background_preservation,
        use_progressive_background_preservation,
        background_blend_threshold,
        gradio_progress=None,
    ):
        has_neg_prompt = neg_prompt is not None
        do_true_cfg = true_gs > 1 and has_neg_prompt

        x = get_noise(
            batch_size * num_images_per_prompt, height, width,
            device=self.device, dtype=self.weight_dtype, seed=seed,
        )
        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
        timesteps = get_schedule(
            num_steps,
            image_seq_len,
            shift=True,
        )

        with torch.no_grad():
            # Text (and optional Redux image) conditioning.
            self.t5, self.clip, self.image_encoder = (
                self.t5.to(self.device),
                self.clip.to(self.device),
                self.image_encoder.to(self.device),
            )
            if self.image_encoder is not None:
                inp_cond = prepare_with_redux(
                    t5=self.t5, clip=self.clip, image_encoder=self.image_encoder,
                    img=x, img_ip=img_ip, prompt=prompt,
                    num_images_per_prompt=num_images_per_prompt,
                )
            else:
                inp_cond = prepare(
                    t5=self.t5, clip=self.clip, img=x, prompt=prompt,
                    num_images_per_prompt=num_images_per_prompt,
                )
            neg_inp_cond = None
            if do_true_cfg:
                if self.image_encoder is not None:
                    neg_inp_cond = prepare_with_redux(
                        t5=self.t5, clip=self.clip, image_encoder=self.image_encoder,
                        img=x, img_ip=img_ip, prompt=neg_prompt,
                        num_images_per_prompt=num_images_per_prompt,
                    )
                else:
                    neg_inp_cond = prepare(
                        t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt,
                        num_images_per_prompt=num_images_per_prompt,
                    )

            if self.offload:
                self.offload_model_to_cpu(self.t5, self.clip, self.image_encoder)
                self.model = self.model.to(self.device)

            # Encode the reference/target images and the target mask into latent conditioning.
            self.ae.encoder = self.ae.encoder.to(self.device)
            inp_img_cond = prepare_image_cond(
                ae=self.ae,
                img_ref=img_ref,
                img_target=img_target,
                mask_target=mask_target,
                dtype=self.weight_dtype,
                device=self.device,
                num_images_per_prompt=num_images_per_prompt,
            )

            x = denoise(
                self.model,
                img=inp_cond['img'],
                img_ids=inp_cond['img_ids'],
                txt=inp_cond['txt'],
                txt_ids=inp_cond['txt_ids'],
                txt_vec=inp_cond['txt_vec'],
                timesteps=timesteps,
                guidance=guidance,
                img_cond=inp_img_cond['img_cond'],
                mask_cond=inp_img_cond['mask_cond'],
                img_latent=inp_img_cond['img_latent'],
                cond_w_regions=cond_w_regions,
                mask_type_ids=mask_type_ids,
                height=height,
                width=width,
                use_background_preservation=use_background_preservation,
                use_progressive_background_preservation=use_progressive_background_preservation,
                background_blend_threshold=background_blend_threshold,
                true_gs=true_gs,
                timestep_to_start_cfg=timestep_to_start_cfg,
                neg_txt=neg_inp_cond['txt'] if neg_inp_cond is not None else None,
                neg_txt_ids=neg_inp_cond['txt_ids'] if neg_inp_cond is not None else None,
                neg_txt_vec=neg_inp_cond['txt_vec'] if neg_inp_cond is not None else None,
                show_progress=self.show_progress,
                use_flash_attention=self.use_flash_attention,
                gradio_progress=gradio_progress,
            )
            if self.offload:
                self.offload_model_to_cpu(self.model, self.ae.encoder)

            # Decode the latents back to image space.
            x = unpack(x.float(), height, width)
            self.ae.decoder = self.ae.decoder.to(x.device)
            x = self.ae.decode(x)
            if self.offload:
                self.offload_model_to_cpu(self.ae.decoder)

        # Convert to PIL and crop out the target region from the bottom-right
        # corner of the generated canvas.
        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1, "b c h w -> b h w c")
        output_imgs_target = []
        for i in range(x1.shape[0]):
            output_img = Image.fromarray((127.5 * (x1[i] + 1.0)).cpu().byte().numpy())
            img_target_height, img_target_width = img_target.shape[2], img_target.shape[3]
            output_img_target = output_img.crop((
                output_img.width - img_target_width,
                output_img.height - img_target_height,
                output_img.width,
                output_img.height,
            ))
            output_imgs_target.append(output_img_target)
        return output_imgs_target
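
# Example usage (a minimal sketch, kept as comments because this module relies on
# relative imports and is not meant to be run directly). The checkpoint names are
# the constructor defaults; the image paths below are illustrative assumptions, and
# `img_ip`, `cond_w_regions`, and `mask_type_ids` may also need to be supplied
# depending on the task setup.
#
#   import torch
#   from PIL import Image
#
#   pipe = ICCustomPipeline(device=torch.device("cuda"), offload=True)
#   pipe.set_show_progress(True)
#   images = pipe(
#       prompt="a product photo of the reference object on a marble table",
#       img_ref=Image.open("ref.png").convert("RGB"),        # assumed local file
#       img_target=Image.open("target.png").convert("RGB"),  # assumed local file
#       mask_target=Image.open("mask.png").convert("L"),     # assumed local file
#       width=512,
#       height=512,
#       num_steps=50,
#       seed=42,
#   )
#   images[0].save("output.png")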