import gc

import numpy as np
import torch
from diffusers.pipelines.flux.pipeline_flux import calculate_shift
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image
from torchvision import transforms

from ..callback.callback_fn import CallbackLatentStore
from .scheduling_flow_inverse import (FlowMatchEulerDiscreteBackwardScheduler,
                                      FlowMatchEulerDiscreteForwardScheduler)


def img_to_latent(img, vae):
    # Normalize the PIL image to [-1, 1], encode it with the VAE, and map the
    # posterior mean into the latent convention used by Flux:
    # (mean - shift_factor) * scaling_factor.
    normalize = transforms.Normalize(mean=[0.5], std=[0.5])
    trans = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    tensor_img = trans(img)[None, ...]
    tensor_img = tensor_img.to(dtype=vae.dtype, device=vae.device)
    posterior = vae.encode(tensor_img).latent_dist
    latents = (posterior.mean - vae.config.shift_factor) * vae.config.scaling_factor
    # latents = posterior.mean
    return latents
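

# Illustrative sketch, not part of the original module: the inverse of
# `img_to_latent` under the same convention, i.e. it undoes the
# (mean - shift_factor) * scaling_factor normalization before decoding.
# Assumes `latents` are unpacked VAE latents as produced by `img_to_latent`.
def latent_to_img(latents, vae):
    latents = latents.to(dtype=vae.dtype, device=vae.device)
    latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)  # undo the [-1, 1] normalization
    image = image[0].permute(1, 2, 0).float().cpu().numpy()
    return Image.fromarray((image * 255).round().astype('uint8'))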


def get_inversed_latent_list(
    pipe,
    image: Image.Image,
    random_noise=None,
    num_inference_steps: int = 28,
    backward_method: str = 'ode',
    model_name: str = 'flux',
    res=(1024, 1024),
    use_prompt_for_inversion=False,
    guidance_scale_for_inversion=0,
    prompt_for_inversion=None,
    seed=0,
    flow_steps=1,
    ode_steps=1,
    intermediate_steps=None
):
    """Encode `image` and return a list of inverted latents, from the clean image
    latent towards noise, using either direct flow interpolation ('flow') or the
    pipeline's ODE inversion ('ode')."""
    # Resize the input to the working resolution and encode it with the VAE.
    img = image.resize(res)
    img_latent = img_to_latent(img, pipe.vae)
    device = img_latent.device
    generator = torch.Generator(device=device).manual_seed(seed)
    if random_noise is None:
        random_noise = randn_tensor(img_latent.shape, device=device, generator=generator)
    if model_name == 'flux':
        # Flux operates on packed (patchified) latent sequences.
        random_noise = pipe._pack_latents(random_noise, *random_noise.shape)
        img_latent = pipe._pack_latents(img_latent, *img_latent.shape)
    # Swap in the backward (inversion) scheduler built from the current config.
    pipe.scheduler = FlowMatchEulerDiscreteBackwardScheduler.from_config(
        pipe.scheduler.config,
        margin_index_from_noise=0,
        margin_index_from_image=0,
        intermediate_steps=intermediate_steps
    )
    if model_name == 'flux':
        # Flux uses a resolution-dependent timestep shift (mu) for its
        # flow-matching schedule.
        image_seq_len = img_latent.shape[1]
        mu = calculate_shift(
            image_seq_len,
            pipe.scheduler.config.base_image_seq_len,
            pipe.scheduler.config.max_image_seq_len,
            pipe.scheduler.config.base_shift,
            pipe.scheduler.config.max_shift,
        )
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    else:
        mu = None
        sigmas = None
    pipe.scheduler.set_timesteps(num_inference_steps=num_inference_steps, mu=mu, sigmas=sigmas)
    sigmas = pipe.scheduler.sigmas
    timesteps = pipe.scheduler.timesteps
    if backward_method == 'flow':
        # Flow-matching forward process without the model: linearly interpolate
        # between the image latent and the sampled noise at each sigma.
        inv_latents = [img_latent]
        for sigma in sigmas:
            inv_latent = (1 - sigma) * img_latent + sigma * random_noise
            inv_latents.append(inv_latent)
    elif backward_method == 'ode':
        # ODE inversion: run the pipeline's inversion pass and collect the
        # intermediate latents recorded by the callback at each step.
        inv_latents = [img_latent]
        img_latent_new = img_latent.to(pipe.dtype)
        random_noise = random_noise.to(pipe.dtype)
        callback_fn = CallbackLatentStore()
        inv_latent = pipe.inversion(
            latents=img_latent_new,
            rand_latents=random_noise,
            flow_steps=flow_steps,
            prompt=prompt_for_inversion if use_prompt_for_inversion else '',
            num_images_per_prompt=1,
            output_type='latent',
            width=res[0], height=res[1],
            guidance_scale=guidance_scale_for_inversion,
            num_inference_steps=num_inference_steps,
            callback_on_step_end=callback_fn
        ).images
        inv_latents = inv_latents + callback_fn.latents

    del img_latent
    gc.collect()
    torch.cuda.empty_cache()
    return inv_latents
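

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). `load_pipe` is a
# hypothetical placeholder for however this project constructs its Flux
# pipeline; whatever it returns must expose the custom `inversion` method and
# `_pack_latents` used above.
#
#     pipe = load_pipe()  # hypothetical loader, defined elsewhere in the repo
#     image = Image.open('input.png').convert('RGB')
#     inv_latents = get_inversed_latent_list(
#         pipe,
#         image,
#         num_inference_steps=28,
#         backward_method='ode',
#         model_name='flux',
#         res=(1024, 1024),
#     )
#     # inv_latents[0] is the packed image latent; the remaining entries are the
#     # intermediate latents recorded by CallbackLatentStore during inversion.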