# ReFlex / img_edit.py
# Provenance (Hugging Face upload page header): SahilCarterr — "Upload 77 files",
# commit f056744 (verified). These four lines were web-page residue, not Python
# code; they are preserved as comments so the file parses.
import argparse
import gc
import os
import random
import re
import time
from distutils.util import strtobool
import pandas as pd
# --- Command-line interface -------------------------------------------------
# Parsed before the heavy imports below so CUDA_VISIBLE_DEVICES can be set
# from --gpu prior to importing torch.


def _strtobool(value):
    """Parse a truthy/falsy CLI string into 1 or 0.

    Local replacement for the deprecated ``distutils.util.strtobool``
    (distutils was removed in Python 3.12): accepts the same spellings
    and returns the same int values, raising ValueError otherwise.
    """
    value = value.lower()
    if value in ('y', 'yes', 't', 'true', 'on', '1'):
        return 1
    if value in ('n', 'no', 'f', 'false', 'off', '0'):
        return 0
    raise ValueError(f"invalid truth value {value!r}")


parser = argparse.ArgumentParser()
# Input image, prompts, and output locations.
parser.add_argument("--img_path", type=str)
parser.add_argument("--target_prompt", type=str)
parser.add_argument("--source_prompt", type=str, default='')
parser.add_argument("--blend_word", type=str, default='')
parser.add_argument("--mask_path", type=str, default=None)
parser.add_argument("--gpu", type=str, default="0")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--results_dir", type=str, default='results')
parser.add_argument("--model", type=str, default='flux', choices=['flux'])
# Injection schedules, in denoising steps.
parser.add_argument(
    "--ca_steps",
    type=int,
    default=10,
    help="Number of steps to apply I2T-CA adaptation and injection.",
)
parser.add_argument(
    "--sa_steps",
    type=int,
    default=7,  # FIX: missing comma after the default was a SyntaxError
    help="Number of steps to apply I2I-SA adaptation and injection.",
)
parser.add_argument(
    "--feature_steps",
    type=int,
    default=5,  # FIX: missing comma after the default was a SyntaxError
    help="Number of steps to inject residual features.",
)
# Transformer layer ranges for each injection type (half-open [from, to)).
parser.add_argument(
    "--ca_attn_layer_from",
    type=int,
    default=13,
    help="Layers to apply I2T-CA adaptation and injection.",
)
parser.add_argument(
    "--ca_attn_layer_to",
    type=int,
    default=45,
    help="Layers to apply I2T-CA adaptation and injection.",
)
parser.add_argument(
    "--sa_attn_layer_from",
    type=int,
    default=20,
    help="Layers to apply I2I-SA adaptation and injection.",
)
parser.add_argument(
    "--sa_attn_layer_to",
    type=int,
    default=45,
    help="Layers to apply I2I-SA adaptation and injection.",
)
parser.add_argument(
    "--feature_layer_from",
    type=int,
    default=13,
    help="Layers to inject residual features.",
)
parser.add_argument(
    "--feature_layer_to",
    type=int,
    default=20,
    help="Layers to inject residual features.",
)
# Inversion / sampling schedule.
parser.add_argument(
    "--flow_steps",
    type=int,
    default=7,
    help="Steps to apply forward step before inversion",
)
parser.add_argument("--step_start", type=int, default=0)
parser.add_argument("--num_inference_steps", type=int, default=28)
parser.add_argument("--guidance_scale", type=float, default=3.5)
# Attention-adaptation hyperparameters.
parser.add_argument(
    "--attn_topk",
    type=int,
    default=20,
    help="Hyperparameter for I2I-SA adaptaion."
)
parser.add_argument(
    "--text_scale",
    type=float,
    default=4,
    help="Hyperparameter for I2T-CA adaptaion."
)
parser.add_argument(
    "--mid_step_index",
    type=int,
    default=14,
    help="Hyperparameter for mid-step feature extraction."
)
# Latent-blending mask options; _strtobool lets "true"/"false" be passed on the CLI.
parser.add_argument("--use_mask", type=_strtobool, default=True)
parser.add_argument("--use_ca_mask", type=_strtobool, default=True)
parser.add_argument(
    "--mask_steps",
    type=int,
    default=18,
    help="Steps to apply latent blending"
)
parser.add_argument("--mask_dilation", type=int, default=3)
parser.add_argument("--mask_nbins", type=int, default=128)
args = parser.parse_args()
# Must be set before torch is imported below, otherwise CUDA device
# selection via --gpu would have no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}"
import gc
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image
from src.attn_utils.attn_utils import AttentionAdapter, AttnCollector
from src.attn_utils.flux_attn_processor import NewFluxAttnProcessor2_0
from src.attn_utils.seq_aligner import get_refinement_mapper
from src.callback.callback_fn import CallbackAll
from src.inversion.inverse import get_inversed_latent_list
from src.inversion.scheduling_flow_inverse import \
FlowMatchEulerDiscreteForwardScheduler
from src.pipeline.flux_pipeline import NewFluxPipeline
from src.transformer_utils.transformer_utils import (FeatureCollector,
FeatureReplace)
from src.utils import (find_token_id_differences, find_word_token_indices,
get_flux_pipeline, mask_decode, mask_interpolate)
def fix_seed(random_seed):
    """Seed every RNG used by this script so runs are reproducible.

    Covers Python's ``random``, NumPy, and torch (CPU and all CUDA
    devices), and pins cuDNN to its deterministic, non-benchmarking mode.
    """
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # covers multi-GPU setups
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def main(args):
    """Run one image-editing job: invert the source image, then resample it
    under the target prompt with attention/feature injection.

    Reads all configuration from the argparse namespace ``args`` and writes
    the source/target images, the prompt text files, an overview figure, and
    (optionally) the estimated cross-attention mask under ``args.results_dir``.
    """
    fix_seed(args.seed)
    device = torch.device('cuda')
    pipe = get_flux_pipeline(pipeline_class=NewFluxPipeline)
    attn_proc = NewFluxAttnProcessor2_0  # NOTE(review): bound but never used; the class is passed directly to AttnCollector below
    pipe = pipe.to(device)
    # Slice the transformer's layer indices into the ranges each injection
    # type operates on. 57 presumably matches the Flux block count — TODO confirm.
    layer_order = range(57)
    ca_layer_list = layer_order[args.ca_attn_layer_from:args.ca_attn_layer_to]
    # NOTE(review): this slice starts at feature_layer_to rather than
    # sa_attn_layer_from (they coincide at the defaults, both 20) — confirm intentional.
    sa_layer_list = layer_order[args.feature_layer_to:args.sa_attn_layer_to]
    feature_layer_list = layer_order[args.feature_layer_from:args.feature_layer_to]
    img_path = args.img_path
    source_img = Image.open(img_path).resize((1024, 1024)).convert("RGB")
    img_base_name = os.path.splitext(img_path)[0].split('/')[-1]
    result_img_dir = f"{args.results_dir}/seed_{args.seed}/{args.target_prompt}"
    source_prompt = args.source_prompt
    target_prompt = args.target_prompt
    prompts = [source_prompt, target_prompt]
    print(prompts)
    mask = None
    if args.use_mask:
        use_mask = True
        if args.mask_path is not None:
            # A ground-truth mask was supplied: load it and skip CA-mask estimation.
            mask = Image.open(args.mask_path)
            mask = torch.tensor(np.array(mask)).bool()
            mask = mask.to(device)
            # Increase the latent blending steps if the ground truth mask is used.
            args.mask_steps = int(args.num_inference_steps * 0.9)
            source_ca_index = None
            target_ca_index = None
            use_ca_mask = False
        elif args.use_ca_mask and source_prompt:
            # Estimate the edit mask from cross-attention maps during sampling.
            mask = None
            if args.blend_word and args.blend_word in source_prompt:
                # The user named the word to edit: locate its token indices in the source prompt.
                editing_source_token_index = find_word_token_indices(source_prompt, args.blend_word, pipe.tokenizer_2)
                editing_target_token_index = None
            else:
                # Otherwise derive edited token indices from the diff between the two prompts.
                editing_tokens_info = find_token_id_differences(*prompts, pipe.tokenizer_2)
                editing_source_token_index = editing_tokens_info['prompt_1']['index']
                editing_target_token_index = editing_tokens_info['prompt_2']['index']
            use_ca_mask = True
            # Prefer source-side indices, fall back to target-side; if neither
            # prompt yields edit tokens, disable the CA mask again.
            if editing_source_token_index:
                source_ca_index = editing_source_token_index
                target_ca_index = None
            elif editing_target_token_index:
                source_ca_index = None
                target_ca_index = editing_target_token_index
            else:
                source_ca_index = None
                target_ca_index = None
                use_ca_mask = False
        else:
            source_ca_index = None
            target_ca_index = None
            use_ca_mask = False
    else:
        use_mask = False
        use_ca_mask = False
        source_ca_index = None
        target_ca_index = None
    if source_prompt:
        # Use I2T-CA injection
        mappers, alphas = get_refinement_mapper(prompts, pipe.tokenizer_2, max_len=512)
        mappers = mappers.to(device=device)
        alphas = alphas.to(device=device, dtype=pipe.dtype)
        alphas = alphas[:, None, None, :]
        ca_steps = args.ca_steps
        attn_adj_from = 1
    else:
        # Not use I2T-CA injection
        mappers = None
        alphas = None
        ca_steps = 0
        attn_adj_from=3
    sa_steps = args.sa_steps
    feature_steps = args.feature_steps
    # Controller that adapts/injects attention maps across the two prompts.
    attn_controller = AttentionAdapter(
        ca_layer_list=ca_layer_list,
        sa_layer_list=sa_layer_list,
        ca_steps=ca_steps,
        sa_steps=sa_steps,
        method='replace_topk',
        topk=args.attn_topk,
        text_scale=args.text_scale,
        mappers=mappers,
        alphas=alphas,
        attn_adj_from=attn_adj_from,
        save_source_ca=source_ca_index is not None,
        save_target_ca=target_ca_index is not None,
    )
    attn_collector = AttnCollector(
        transformer=pipe.transformer,
        controller=attn_controller,
        attn_processor_class=NewFluxAttnProcessor2_0,
    )
    # Controller that injects residual transformer features on selected layers.
    feature_controller = FeatureReplace(
        layer_list=feature_layer_list,
        feature_steps=feature_steps,
    )
    feature_collector = FeatureCollector(
        transformer=pipe.transformer,
        controller=feature_controller,
    )
    num_prompts=len(prompts)
    # Shared initial noise in unpacked latent shape, then packed for Flux.
    shape = (1, 16, 128, 128)
    generator = torch.Generator(device=device).manual_seed(args.seed)
    latents = randn_tensor(shape, device=device, generator=generator)
    latents = pipe._pack_latents(latents, *latents.shape)
    # Run inversion with the vanilla attention/transformer (controllers detached).
    attn_collector.restore_orig_attention()
    feature_collector.restore_orig_transformer()
    t0 = time.perf_counter()
    inv_latents = get_inversed_latent_list(
        pipe,
        source_img,
        random_noise=latents,
        num_inference_steps=args.num_inference_steps,
        backward_method="ode",
        use_prompt_for_inversion=False,
        guidance_scale_for_inversion=0,
        prompt_for_inversion='',
        flow_steps=args.flow_steps,
    )
    # Reverse the inversion trajectory into denoising order (noise -> image).
    source_latents = inv_latents[::-1]
    target_latents = inv_latents[::-1]
    # Re-attach the controllers for the edited sampling pass.
    attn_collector.register_attention_control()
    feature_collector.register_transformer_control()
    callback_fn = CallbackAll(
        latents=source_latents,
        attn_collector=attn_collector,
        feature_collector=feature_collector,
        feature_inject_steps=feature_steps,
        mid_step_index=args.mid_step_index,
        step_start=args.step_start,
        use_mask=use_mask,
        use_ca_mask=use_ca_mask,
        source_ca_index=source_ca_index,
        target_ca_index=target_ca_index,
        mask_kwargs={'dilation': args.mask_dilation},
        mask_steps=args.mask_steps,
        mask=mask,
    )
    # Batch layout: row 0 reconstructs the source (seeded from its mid-step
    # latent); the remaining rows are sampled for the target prompt.
    init_latent = target_latents[args.step_start]
    init_latent = init_latent.repeat(num_prompts, 1, 1)
    init_latent[0] = source_latents[args.mid_step_index]
    os.makedirs(result_img_dir, exist_ok=True)
    pipe.scheduler = FlowMatchEulerDiscreteForwardScheduler.from_config(
        pipe.scheduler.config,
        step_start=args.step_start,
        margin_index_from_image=0
    )
    # Reset controller state so per-step counters start at step_start for this run.
    attn_controller.reset()
    feature_controller.reset()
    attn_controller.text_scale = args.text_scale
    attn_controller.cur_step = args.step_start
    feature_controller.cur_step = args.step_start
    with torch.no_grad():
        images = pipe(
            prompts,
            latents=init_latent,
            num_images_per_prompt=1,
            guidance_scale=args.guidance_scale,
            num_inference_steps=args.num_inference_steps,
            generator=generator,
            callback_on_step_end=callback_fn,
            mid_step_index=args.mid_step_index,
            step_start=args.step_start,
            callback_on_step_end_tensor_inputs=['latents'],
        ).images
    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.1f}s.")
    # Persist inputs and outputs: source image, edited image(s), prompts.
    source_img_path = os.path.join(result_img_dir, f"source.png")
    source_img.save(source_img_path)
    for i, img in enumerate(images[1:]):
        target_img_path = os.path.join(result_img_dir, f"target_{i}.png")
        img.save(target_img_path)
    target_text_path = os.path.join(result_img_dir, f"target_prompts.txt")
    with open(target_text_path, 'w') as file:
        file.write(target_prompt + '\n')
    source_text_path = os.path.join(result_img_dir, f"source_prompt.txt")
    with open(source_text_path, 'w') as file:
        file.write(source_prompt + '\n')
    # Side-by-side overview figure: [original, reconstruction, edited].
    images = [source_img] + images
    fs=3
    n = len(images)
    fig, ax = plt.subplots(1, n, figsize=(n*fs, 1*fs))
    for i, img in enumerate(images):
        ax[i].imshow(img)
    ax[0].set_title('source')
    ax[1].set_title(source_prompt, fontsize=7)
    ax[2].set_title(target_prompt, fontsize=7)
    overall_img_path = os.path.join(result_img_dir, f"overall.png")
    plt.savefig(overall_img_path, bbox_inches='tight')
    plt.close()
    mask_save_dir = os.path.join(result_img_dir, f"mask")
    os.makedirs(mask_save_dir, exist_ok=True)
    if use_ca_mask:
        # Save the cross-attention-derived mask as an 8-bit grayscale image.
        ca_mask_path = os.path.join(mask_save_dir, f"mask_ca.png")
        mask_img = Image.fromarray((callback_fn.mask.cpu().float().numpy() * 255).astype(np.uint8)).convert('L')
        mask_img.save(ca_mask_path)
    # Free the large latent tensors before returning to keep GPU memory low.
    del inv_latents
    del init_latent
    gc.collect()
    torch.cuda.empty_cache()
# Script entry point; `args` is parsed at module import time above.
if __name__ == '__main__':
    main(args)