import math
import sys
from typing import Callable, List, Optional, Union

import torch
from einops import rearrange, repeat
from torch import Tensor

from ..models.model import Flux
from ..modules.conditioner import HFEmbedder
from ..modules.image_embedders import ReduxImageEncoder

# -------------------------------------------------------------------------
# Progress bar
# -------------------------------------------------------------------------
import time

TGT_PREFIX = "[TARGET-SCENE]"

def print_progress_bar(iteration, total, prefix='', suffix='', length=30, fill='█'):
    """
    Simple progress bar for console output, with elapsed and estimated remaining time.

    Args:
        iteration: Current iteration (Int)
        total: Total iterations (Int)
        prefix: Prefix string (Str)
        suffix: Suffix string (Str)
        length: Bar length (Int)
        fill: Bar fill character (Str)
    """
    # Static variable to store start time
    if not hasattr(print_progress_bar, "_start_time") or iteration == 0:
        print_progress_bar._start_time = time.time()

    percent = f"{100 * (iteration / float(total)):.1f}%"
    filled_length = int(length * iteration // total)
    bar = fill * filled_length + '-' * (length - filled_length)

    elapsed = time.time() - print_progress_bar._start_time
    elapsed_str = time.strftime("%H:%M:%S", time.gmtime(elapsed))
    if iteration > 0:
        avg_time_per_iter = elapsed / iteration
        remaining = avg_time_per_iter * (total - iteration)
    else:
        remaining = 0
    remaining_str = time.strftime("%H:%M:%S", time.gmtime(remaining))
    time_info = f"Elapsed: {elapsed_str} | ETA: {remaining_str}"

    sys.stdout.write(f'\r{prefix} |{bar}| {percent} {suffix} {time_info}')
    sys.stdout.flush()
    if iteration == total:
        sys.stdout.write('\n')
        sys.stdout.flush()
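
# Usage sketch (illustrative only): drive the bar from any loop; the sleep
# below just stands in for real per-step work.
#
#   for i in range(51):
#       print_progress_bar(i, 50, prefix='Working:', suffix=f'Step {i}/50')
#       time.sleep(0.05)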

# -------------------------------------------------------------------------
# 1) sampling func
# -------------------------------------------------------------------------
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
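
# Shape sketch (assuming the standard 16-channel latent of an 8x-downsampling
# autoencoder): for a 1024x1024 image the packed sequence holds 64 * 64 = 4096
# tokens of 16 * 2 * 2 = 64 channels, and unpack restores the spatial latent.
#
#   x = torch.randn(1, 4096, 64)
#   unpack(x, height=1024, width=1024).shape  # -> torch.Size([1, 16, 128, 128])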

def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
):
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
):
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu by linear interpolation between two reference points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
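
# Quick sanity sketch (values illustrative): for a 1024x1024 image the packed
# sequence length is 4096, so mu sits at max_shift and the schedule is pushed
# toward high noise levels relative to a plain linear ramp.
#
#   sched = get_schedule(num_steps=4, image_seq_len=4096)
#   len(sched)            # -> 5 (includes the trailing 0.0)
#   sched[0], sched[-1]   # -> 1.0, 0.0; intermediate values lie above the linear ramp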

def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    noise = torch.cat(
        [
            torch.randn(
                1,
                16,
                # allow for packing
                2 * math.ceil(height / 16),
                2 * math.ceil(width / 16),
                device=device,
                dtype=dtype,
                generator=torch.Generator(device=device).manual_seed(seed + i),
            )
            for i in range(num_samples)
        ],
        dim=0,
    )
    return noise
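
# Minimal sketch: each sample gets its own generator seeded with seed + i, so
# batched results are reproducible per sample. For 1024x1024 the latent grid is
# 128x128.
#
#   noise = get_noise(2, 1024, 1024, device=torch.device("cpu"),
#                     dtype=torch.float32, seed=0)
#   noise.shape  # -> torch.Size([2, 16, 128, 128])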

# -------------------------------------------------------------------------
# prepare input func
# -------------------------------------------------------------------------
def _get_batch_size_and_prompt(prompt, img_shape):
    """
    Helper to determine batch size, prompt, prompt-is-None flag, and latent height/width.
    """
    bs, c, h, w = img_shape
    is_prompt_none = prompt is None
    return bs, prompt, is_prompt_none, h, w


def _make_img_ids(bs, h, w, device=None, dtype=None):
    """
    Helper to create positional ids for the packed image tokens.
    """
    img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
    img_ids[..., 1] = torch.arange(h // 2)[:, None]
    img_ids[..., 2] = torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
    return img_ids
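
# Layout sketch: img_ids has shape (bs, h//2 * w//2, 3); channel 0 stays 0,
# channel 1 is the patch row and channel 2 the patch column.
#
#   ids = _make_img_ids(bs=1, h=4, w=4)
#   ids.shape      # -> torch.Size([1, 4, 3])
#   ids[0, :, 1:]  # -> [[0, 0], [0, 1], [1, 0], [1, 1]] (as floats)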

def prepare(
    t5: HFEmbedder,
    clip: HFEmbedder,
    img: Tensor,
    prompt: Union[str, List[str], None],
    num_images_per_prompt: int = 1,
):
    """
    Prepare the regular input for the Diffusion Transformer.
    """
    img_bs, prompt, is_prompt_none, h, w = _get_batch_size_and_prompt(prompt, img.shape)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    img_ids = _make_img_ids(img_bs, h, w, device=img.device, dtype=img.dtype)

    if isinstance(prompt, str):
        prompt = [prompt]
    # Fall back to the image batch size when no prompt is given
    txt_bs = img_bs if is_prompt_none else len(prompt)

    if not is_prompt_none:
        prompt = [TGT_PREFIX + p for p in prompt]
        txt = t5(prompt)
        txt_ids = torch.zeros(txt_bs, txt.shape[1], 3, device=img.device, dtype=img.dtype)
        txt_vec = clip(prompt)
    else:
        txt = torch.zeros(txt_bs, 512, 4096, device=img.device, dtype=img.dtype)
        txt_ids = torch.zeros(txt_bs, 512, 3, device=img.device, dtype=img.dtype)
        txt_vec = torch.zeros(txt_bs, 768, device=img.device, dtype=img.dtype)

    if num_images_per_prompt > 1:
        txt = txt.repeat_interleave(num_images_per_prompt, dim=0)
        txt_ids = txt_ids.repeat_interleave(num_images_per_prompt, dim=0)
        txt_vec = txt_vec.repeat_interleave(num_images_per_prompt, dim=0)

    return {
        "img": img.to(img.device),
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "txt_vec": txt_vec.to(img.device),
    }
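
# Flow sketch (shapes assume a T5 max length of 512 and a CLIP pooled dim of
# 768; t5 and clip are HFEmbedder instances loaded elsewhere):
#
#   latent = get_noise(1, 1024, 1024, device, torch.bfloat16, seed=0)
#   inp = prepare(t5, clip, latent, prompt="a red chair")
#   inp["img"].shape      # -> (1, 4096, 64)   packed latent tokens
#   inp["img_ids"].shape  # -> (1, 4096, 3)    positional ids
#   inp["txt"].shape      # -> (1, 512, 4096)  T5 hidden states
#   inp["txt_vec"].shape  # -> (1, 768)        pooled CLIP vector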

def prepare_with_redux(
    t5: HFEmbedder,
    clip: HFEmbedder,
    image_encoder: ReduxImageEncoder,
    img: Tensor,
    img_ip: Tensor,
    prompt: Union[str, List[str], None],
    num_images_per_prompt: int = 1,
):
    """
    Prepare the input for the Diffusion Transformer, appending Redux image
    embeddings of `img_ip` to the T5 text tokens.
    """
    img_bs, prompt, is_prompt_none, h, w = _get_batch_size_and_prompt(prompt, img.shape)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    img_ids = _make_img_ids(img_bs, h, w, device=img.device, dtype=img.dtype)

    if isinstance(prompt, str):
        prompt = [prompt]
    # Fall back to the image batch size when no prompt is given
    txt_bs = img_bs if is_prompt_none else len(prompt)

    if not is_prompt_none:
        prompt = [TGT_PREFIX + p for p in prompt]
        txt = torch.cat((t5(prompt), image_encoder(img_ip)), dim=1)
        txt_ids = torch.zeros(txt_bs, txt.shape[1], 3, device=img.device, dtype=img.dtype)
        txt_vec = clip(prompt)
    else:
        txt = torch.zeros(txt_bs, 512, 4096, device=img.device, dtype=img.dtype)
        txt_ids = torch.zeros(txt_bs, 512, 3, device=img.device, dtype=img.dtype)
        txt_vec = torch.zeros(txt_bs, 768, device=img.device, dtype=img.dtype)

    if num_images_per_prompt > 1:
        txt = txt.repeat_interleave(num_images_per_prompt, dim=0)
        txt_ids = txt_ids.repeat_interleave(num_images_per_prompt, dim=0)
        txt_vec = txt_vec.repeat_interleave(num_images_per_prompt, dim=0)

    return {
        "img": img.to(img.device),
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "txt_vec": txt_vec.to(img.device),
    }

def prepare_image_cond(
    ae,
    img_ref,
    img_target,
    mask_target,
    dtype,
    device,
    num_images_per_prompt: int = 1,
):
    batch_size, _, _, _ = img_target.shape

    # Apply mask to target image
    mask_targeted_img = img_target * mask_target
    if mask_target.shape[1] == 3:
        mask_target = mask_target[:, 0:1, :, :]

    with torch.no_grad():
        autoencoder_dtype = next(ae.parameters()).dtype
        # Encode masked target image to latent space
        mask_targeted_latent = ae.encode(mask_targeted_img.to(autoencoder_dtype)).to(dtype)
        # Encode reference image to latent space
        reference_latent = ae.encode(img_ref.to(autoencoder_dtype)).to(dtype)

    # Repeat reference latent if batch size > 1
    if reference_latent.shape[0] == 1 and batch_size > 1:
        reference_latent = repeat(reference_latent, "1 ... -> bs ...", bs=batch_size)

    # Concatenate reference and target latents along the width dimension
    latent_concat = torch.cat((reference_latent, mask_targeted_latent), dim=-1)
    # Pack latents into 2x2 patches
    latent_packed = rearrange(latent_concat, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    # Create reference mask (all ones)
    reference_mask = torch.ones_like(img_ref)
    if reference_mask.shape[1] == 3:
        reference_mask = reference_mask[:, 0:1, :, :]
    # Concatenate reference and target masks along the width dimension
    mask_concat = torch.cat((reference_mask, mask_target), dim=-1)
    # Pack masks into 16x16 patches for image conditioning
    mask_16x16 = rearrange(mask_concat, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=16, pw=16)

    # Interpolate masks to the packed-latent resolution
    mask_latent = torch.nn.functional.interpolate(
        mask_concat,
        size=(latent_concat.shape[2] // 2, latent_concat.shape[3] // 2),
        mode='nearest',
    )
    # Pack interpolated masks into 1x1 patches for mask conditioning
    mask_cond = rearrange(mask_latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=1, pw=1)

    # Combine packed latents and masks for image conditioning
    img_cond = torch.cat((latent_packed, mask_16x16), dim=-1)

    if num_images_per_prompt > 1:
        img_cond = img_cond.repeat_interleave(num_images_per_prompt, dim=0)
        mask_cond = mask_cond.repeat_interleave(num_images_per_prompt, dim=0)
        latent_packed = latent_packed.repeat_interleave(num_images_per_prompt, dim=0)

    return {
        "img_cond": img_cond.to(device).to(dtype),
        "mask_cond": mask_cond.to(device).to(dtype),
        "img_latent": latent_packed.to(device).to(dtype),
    }
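
# Shape sketch (assuming an 8x-downsampling, 16-channel autoencoder and
# 1024x1024 reference/target): the reference and masked-target latents sit side
# by side, giving 64 x 128 = 8192 tokens. Per token, img_cond carries
# 16*2*2 = 64 packed latent channels plus 1*16*16 = 256 packed pixel-mask values
# (320 total); mask_cond holds one mask value per token; img_latent keeps the 64
# packed latent channels used as the background-preservation target in denoise().
#
#   cond = prepare_image_cond(ae, img_ref, img_target, mask_target,
#                             dtype=torch.bfloat16, device=device)
#   cond["img_cond"].shape    # -> (1, 8192, 320)
#   cond["mask_cond"].shape   # -> (1, 8192, 1)
#   cond["img_latent"].shape  # -> (1, 8192, 64)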

# -------------------------------------------------------------------------
# 2) denoise func
# -------------------------------------------------------------------------
def is_even_step(step: int) -> bool:
    """Check if the current step is even."""
    return step % 2 == 0

def denoise(
    model,
    img,
    img_ids,
    txt,
    txt_ids,
    txt_vec,
    timesteps,
    guidance: float = 4.0,
    img_cond: Optional[Tensor] = None,
    mask_cond: Optional[Tensor] = None,
    img_latent: Optional[Tensor] = None,
    cond_w_regions: Optional[Union[List[int], int]] = None,
    mask_type_ids: Optional[Union[Tensor, int]] = None,
    height: int = 1024,
    width: int = 1024,
    use_background_preservation: bool = False,
    use_progressive_background_preservation: bool = True,
    background_blend_threshold: float = 0.8,
    true_gs: float = 1,
    timestep_to_start_cfg: int = 0,
    neg_txt: Optional[Tensor] = None,
    neg_txt_ids: Optional[Tensor] = None,
    neg_txt_vec: Optional[Tensor] = None,
    show_progress: bool = False,
    use_flash_attention: bool = False,
    gradio_progress=None,
):
    # True classifier-free guidance needs a negative prompt embedding
    do_true_cfg = true_gs > 1 and neg_txt is not None
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    # Ground-truth velocity toward the conditioning latent, used for background preservation
    v_gt = img - img_latent if img_latent is not None else None
    num_steps = len(timesteps[:-1])

    for step, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        if show_progress:
            print_progress_bar(step, num_steps, prefix='Denoising:', suffix=f'Step {step+1}/{num_steps}')

        # Update Gradio progress if available
        if gradio_progress is not None:
            # Report the raw denoising fraction (0-1)
            progress_value = step / num_steps
            gradio_progress(progress_value, desc=f"Denoising step {step+1}/{num_steps}")

        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        model_dtype = list(model.parameters())[0].dtype

        pred = model(
            img=torch.cat((img.to(model_dtype), img_cond.to(model_dtype)), dim=-1) if img_cond is not None else img.to(model_dtype),
            img_ids=img_ids.to(model_dtype),
            txt=txt.to(model_dtype),
            txt_ids=txt_ids.to(model_dtype),
            txt_vec=txt_vec.to(model_dtype),
            timesteps=t_vec.to(model_dtype),
            guidance=guidance_vec.to(model_dtype),
            cond_w_regions=cond_w_regions,
            mask_type_ids=mask_type_ids,
            height=height,
            width=width,
            use_flash_attention=use_flash_attention,
        )

        if do_true_cfg and step >= timestep_to_start_cfg:
            neg_pred = model(
                img=torch.cat((img.to(model_dtype), img_cond.to(model_dtype)), dim=-1) if img_cond is not None else img.to(model_dtype),
                img_ids=img_ids.to(model_dtype),
                txt=neg_txt.to(model_dtype),
                txt_ids=neg_txt_ids.to(model_dtype),
                txt_vec=neg_txt_vec.to(model_dtype),
                timesteps=t_vec.to(model_dtype),
                guidance=guidance_vec.to(model_dtype),
                cond_w_regions=cond_w_regions,
                mask_type_ids=mask_type_ids,
                height=height,
                width=width,
                use_flash_attention=use_flash_attention,
            )
            pred = neg_pred + true_gs * (pred - neg_pred)

        if use_background_preservation:
            is_early_phase = step <= num_steps * background_blend_threshold
            if is_early_phase:
                if use_progressive_background_preservation:
                    if is_even_step(step):
                        # Early phase, even step: blend the ground-truth velocity where mask_cond is active
                        masked_latent = pred * (1 - mask_cond) + v_gt * mask_cond
                    else:
                        # Early phase, odd step: use the prediction directly
                        masked_latent = pred
                else:
                    # Non-progressive mode: always blend during the early phase
                    masked_latent = pred * (1 - mask_cond) + v_gt * mask_cond
            else:
                # Late phase: use the prediction directly
                masked_latent = pred
            img = img + (t_prev - t_curr) * masked_latent
        else:
            img = img + (t_prev - t_curr) * pred

    if show_progress:
        print_progress_bar(num_steps, num_steps, prefix='Denoising:', suffix='Complete')

    return img
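
# End-to-end sketch of how these pieces compose (model/ae/t5/clip are loaded
# elsewhere; the side-by-side reference|target canvas, the variable names, and
# the final ae.decode call are illustrative assumptions, not a fixed API):
#
#   x = get_noise(1, height, 2 * width, device, torch.bfloat16, seed=0)
#   inp = prepare(t5, clip, x, prompt="replace the sofa with a wooden bench")
#   cond = prepare_image_cond(ae, img_ref, img_target, mask_target,
#                             dtype=x.dtype, device=x.device)
#   schedule = get_schedule(num_steps=30, image_seq_len=inp["img"].shape[1])
#   x = denoise(model, **inp, timesteps=schedule, guidance=4.0,
#               img_cond=cond["img_cond"], mask_cond=cond["mask_cond"],
#               img_latent=cond["img_latent"], height=height, width=2 * width,
#               use_background_preservation=True, show_progress=True)
#   out = ae.decode(unpack(x.float(), height, 2 * width))
#   # The right half of `out` corresponds to the edited target; cropping is
#   # assumed to happen downstream.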