ReFlex / src /inversion /inverse.py
SahilCarterr's picture
Upload 77 files
f056744 verified
import gc
import numpy as np
import torch
from diffusers.pipelines.flux.pipeline_flux import calculate_shift
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image
from torchvision import transforms
from ..callback.callback_fn import CallbackLatentStore
from .scheduling_flow_inverse import (FlowMatchEulerDiscreteBackwardScheduler,
FlowMatchEulerDiscreteForwardScheduler)
@torch.no_grad()
def img_to_latent(img, vae):
    """Encode a PIL image into the (shifted and scaled) VAE latent space.

    The image is converted with ``transforms.ToTensor()`` and then normalized
    with mean 0.5 / std 0.5 per channel. A leading batch dimension of size 1
    is added before the image is moved to the VAE's device and dtype.

    The deterministic mean of the VAE posterior is used (no sampling). It is
    mapped into the latent space used by the pipeline via
    ``(mean - vae.config.shift_factor) * vae.config.scaling_factor``.

    Args:
        img: input PIL image.
        vae: VAE exposing ``encode``, ``dtype``, ``device`` and ``config``
            with ``shift_factor`` / ``scaling_factor``.

    Returns:
        The latent tensor produced from the posterior mean.
    """
    to_normalized_tensor = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
    )
    # unsqueeze(0) adds the batch dimension expected by vae.encode.
    batch = to_normalized_tensor(img).unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
    posterior_mean = vae.encode(batch).latent_dist.mean
    cfg = vae.config
    return (posterior_mean - cfg.shift_factor) * cfg.scaling_factor
@torch.no_grad()
def get_inversed_latent_list(
    pipe,
    image: "Image.Image",
    random_noise=None,
    num_inference_steps: int = 28,
    backward_method: str = 'ode',
    model_name: str = 'flux',
    res=(1024, 1024),
    use_prompt_for_inversion=False,
    guidance_scale_for_inversion=0,
    prompt_for_inversion=None,
    seed=0,
    flow_steps=1,
    ode_steps=1,
    intermediate_steps=None
):
    """Build the list of latents linking a clean image to noise.

    The image is resized to ``res``, encoded with ``pipe.vae`` (through
    ``img_to_latent``) and then carried toward noise with one of two methods:

    * ``'flow'``: closed-form interpolation
      ``(1 - sigma) * image_latent + sigma * noise`` for every sigma of the
      backward scheduler.
    * ``'ode'``: ``pipe.inversion(...)``, collecting the per-step latents
      through a ``CallbackLatentStore`` passed as ``callback_on_step_end``.

    Side effect: ``pipe.scheduler`` is replaced by a
    ``FlowMatchEulerDiscreteBackwardScheduler`` built from its config.

    Args:
        pipe: diffusion pipeline. It must expose ``vae``, ``scheduler``,
            ``dtype`` and ``inversion``, plus ``_pack_latents`` when
            ``model_name == 'flux'``.
        image: input PIL image.
        random_noise: optional noise tensor. When ``None`` it is sampled with
            a generator seeded by ``seed``. For flux, a 4-D tensor is packed
            the same way as the image latent; other ranks are used as-is
            (assumed already packed).
        num_inference_steps: number of scheduler steps.
        backward_method: ``'flow'`` or ``'ode'``.
        model_name: ``'flux'`` enables latent packing and shifted sigmas;
            any other value uses the scheduler defaults.
        res: ``(width, height)`` the image is resized to. It is also passed
            to ``pipe.inversion`` as ``width``/``height``.
        use_prompt_for_inversion: if true, ``prompt_for_inversion`` is used
            for the inversion call; otherwise the empty prompt is used.
        guidance_scale_for_inversion: guidance scale for ``'ode'`` inversion.
        prompt_for_inversion: prompt used when ``use_prompt_for_inversion``.
        seed: seed for the noise generator.
        flow_steps: forwarded to ``pipe.inversion`` (``'ode'`` only).
        ode_steps: not used by this function; kept for interface compatibility.
        intermediate_steps: forwarded to the backward scheduler.

    Returns:
        A list whose first element is the clean image latent (packed for
        flux). For ``'flow'`` it is followed by one interpolated latent per
        scheduler sigma. For ``'ode'`` it is followed by the latents stored
        by the callback.

    Raises:
        ValueError: if ``backward_method`` is not ``'flow'`` or ``'ode'``.
    """
    # Validate before any expensive work. Previously an unknown method ran
    # the VAE encode and then failed at ``return`` with UnboundLocalError.
    if backward_method not in ('flow', 'ode'):
        raise ValueError(
            f"backward_method must be 'flow' or 'ode', got {backward_method!r}"
        )

    # Encode the *resized* image so the latent matches ``res``. Previously
    # the resized copy was computed but the original image was encoded.
    resized = image.resize(res)
    img_latent = img_to_latent(resized, pipe.vae)
    device = img_latent.device

    if random_noise is None:
        generator = torch.Generator(device=device).manual_seed(seed)
        random_noise = randn_tensor(img_latent.shape, device=device, generator=generator)

    if model_name == 'flux':
        # _pack_latents is called as (latents, *shape), i.e. with a 4-D
        # tensor. Only pack noise that has not been packed already.
        if random_noise.ndim == 4:
            random_noise = pipe._pack_latents(random_noise, *random_noise.shape)
        img_latent = pipe._pack_latents(img_latent, *img_latent.shape)

    pipe.scheduler = FlowMatchEulerDiscreteBackwardScheduler.from_config(
        pipe.scheduler.config,
        margin_index_from_noise=0,
        margin_index_from_image=0,
        intermediate_steps=intermediate_steps
    )

    if model_name == 'flux':
        # For packed flux latents, dim 1 is the sequence length used to
        # compute the resolution-dependent timestep shift.
        image_seq_len = img_latent.shape[1]
        mu = calculate_shift(
            image_seq_len,
            pipe.scheduler.config.base_image_seq_len,
            pipe.scheduler.config.max_image_seq_len,
            pipe.scheduler.config.base_shift,
            pipe.scheduler.config.max_shift,
        )
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    else:
        mu = None
        sigmas = None
    pipe.scheduler.set_timesteps(num_inference_steps=num_inference_steps, mu=mu, sigmas=sigmas)
    sigmas = pipe.scheduler.sigmas

    if backward_method == 'flow':
        inv_latents = [img_latent]
        inv_latents.extend(
            (1 - sigma) * img_latent + sigma * random_noise for sigma in sigmas
        )
    else:  # 'ode' (validated above)
        callback_fn = CallbackLatentStore()
        # The return value is not needed; the per-step latents are captured
        # by the callback.
        pipe.inversion(
            latents=img_latent.to(pipe.dtype),
            rand_latents=random_noise.to(pipe.dtype),
            flow_steps=flow_steps,
            prompt=prompt_for_inversion if use_prompt_for_inversion else '',
            num_images_per_prompt=1,
            output_type='latent',
            width=res[0], height=res[1],
            guidance_scale=guidance_scale_for_inversion,
            num_inference_steps=num_inference_steps,
            callback_on_step_end=callback_fn
        )
        inv_latents = [img_latent] + callback_fn.latents

    # Release intermediates created during inversion.
    gc.collect()
    torch.cuda.empty_cache()
    return inv_latents