# ReFlex demo (app.py)
import gc
import os
import random
import shutil
import time
from typing import List, Optional, Union

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image

from src.attn_utils.attn_utils import AttentionAdapter, AttnCollector
from src.attn_utils.flux_attn_processor import NewFluxAttnProcessor2_0
from src.attn_utils.seq_aligner import get_refinement_mapper
from src.callback.callback_fn import CallbackAll
from src.inversion.inverse import get_inversed_latent_list
from src.inversion.scheduling_flow_inverse import \
    FlowMatchEulerDiscreteForwardScheduler
from src.pipeline.flux_pipeline import NewFluxPipeline
from src.transformer_utils.transformer_utils import (FeatureCollector,
                                                     FeatureReplace)
from src.utils import (find_token_id_differences, find_word_token_indices,
                       get_flux_pipeline, mask_decode, mask_interpolate)
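
# Build the FLUX editing pipeline once at import time and keep it on the GPU,
# so every request served by infer() reuses the same weights.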
pipe = get_flux_pipeline(pipeline_class=NewFluxPipeline)
pipe = pipe.to("cuda")
def fix_seed(random_seed):
"""
fix seed to control any randomness from a code
(enable stability of the experiments' results.)
"""
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # covers multi-GPU setups
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)
@spaces.GPU
def infer(
input_image: Union[str, Image.Image], # ⬅️ Main UI (uploaded image)
target_prompt: Union[str, List[str]] = '', # ⬅️ Main UI (text prompt)
source_prompt: Union[str, List[str]] = '', # ⬅️ Advanced accordion
seed: int = 0, # ⬅️ Advanced accordion
ca_steps: int = 10, # ⬅️ Advanced accordion
sa_steps: int = 7, # ⬅️ Advanced accordion
feature_steps: int = 5, # ⬅️ Advanced accordion
attn_topk: int = 20, # ⬅️ Advanced accordion
mask_image: Optional[Image.Image] = None, # ⬅️ Advanced (optional upload)
# Everything below is backend-related or defaults, not exposed in UI
blend_word: str = '',
results_dir: str = 'results',
model: str = 'flux',
ca_attn_layer_from: int = 13,
ca_attn_layer_to: int = 45,
sa_attn_layer_from: int = 20,
sa_attn_layer_to: int = 45,
feature_layer_from: int = 13,
feature_layer_to: int = 20,
flow_steps: int = 7,
step_start: int = 0,
num_inference_steps: int = 28,
guidance_scale: float = 3.5,
text_scale: float = 4.0,
mid_step_index: int = 14,
use_mask: bool = True,
use_ca_mask: bool = True,
mask_steps: int = 18,
mask_dilation: int = 3,
mask_nbins: int = 128
):
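    """
    Edit `input_image` so that it matches `target_prompt` (ReFlex editing).

    The source image is inverted along the rectified-flow ODE, then the source and
    target prompts are denoised jointly while attention maps and intermediate features
    from the source branch are injected into the target branch for the first
    `ca_steps` / `sa_steps` / `feature_steps` denoising steps. An optional `mask_image`
    (or a mask derived from cross-attention) restricts where the image may change.
    Returns the edited image, the seed, and a visibility update for the reuse button.
    """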
if isinstance(mask_image, Image.Image):
# Ensure mask is single channel
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
fix_seed(seed)
device = torch.device('cuda')
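
    # FLUX has 57 transformer blocks (19 double-stream + 38 single-stream); pick which
    # block ranges receive cross-attention injection, self-attention injection, and
    # residual feature replacement.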
layer_order = range(57)
ca_layer_list = layer_order[ca_attn_layer_from:ca_attn_layer_to]
sa_layer_list = layer_order[feature_layer_to:sa_attn_layer_to]
feature_layer_list = layer_order[feature_layer_from:feature_layer_to]
    source_img = input_image.resize((1024, 1024)).convert("RGB")
    result_img_dir = f"{results_dir}/seed_{seed}/{target_prompt}"
    prompts = [source_prompt, target_prompt]
    print(prompts)
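
    # Choose how the editing mask is obtained: a user-supplied mask takes priority,
    # otherwise (when a source prompt is available) a mask is derived from the
    # cross-attention maps of the tokens that differ between the two prompts.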
    mask = None
    if use_mask:
        if mask_image is not None:
            mask = torch.tensor(np.array(mask_image)).bool().to(device)
# Increase the latent blending steps if the ground truth mask is used.
mask_steps = int(num_inference_steps * 0.9)
source_ca_index = None
target_ca_index = None
use_ca_mask = False
elif use_ca_mask and source_prompt:
mask = None
if blend_word and blend_word in source_prompt:
editing_source_token_index = find_word_token_indices(source_prompt, blend_word, pipe.tokenizer_2)
editing_target_token_index = None
else:
editing_tokens_info = find_token_id_differences(*prompts, pipe.tokenizer_2)
editing_source_token_index = editing_tokens_info['prompt_1']['index']
editing_target_token_index = editing_tokens_info['prompt_2']['index']
use_ca_mask = True
if editing_source_token_index:
source_ca_index = editing_source_token_index
target_ca_index = None
elif editing_target_token_index:
source_ca_index = None
target_ca_index = editing_target_token_index
else:
source_ca_index = None
target_ca_index = None
use_ca_mask = False
else:
source_ca_index = None
target_ca_index = None
use_ca_mask = False
else:
use_mask = False
use_ca_mask = False
source_ca_index = None
target_ca_index = None
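
    # With a source prompt, align its tokens to the target prompt so that
    # cross-attention (I2T-CA) maps can be injected token-by-token; without one,
    # cross-attention injection is disabled (ca_steps = 0).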
if source_prompt:
# Use I2T-CA injection
mappers, alphas = get_refinement_mapper(prompts, pipe.tokenizer_2, max_len=512)
mappers = mappers.to(device=device)
alphas = alphas.to(device=device, dtype=pipe.dtype)
alphas = alphas[:, None, None, :]
attn_adj_from = 1
else:
# Not use I2T-CA injection
mappers = None
alphas = None
ca_steps = 0
        attn_adj_from = 3
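
    # Hook the transformer: AttentionAdapter controls attention-map injection
    # (top-k replacement, text_scale re-weighting) for ca_steps / sa_steps steps,
    # and FeatureReplace swaps intermediate features for the first feature_steps steps.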
attn_controller = AttentionAdapter(
ca_layer_list=ca_layer_list,
sa_layer_list=sa_layer_list,
ca_steps=ca_steps,
sa_steps=sa_steps,
method='replace_topk',
topk=attn_topk,
text_scale=text_scale,
mappers=mappers,
alphas=alphas,
attn_adj_from=attn_adj_from,
save_source_ca=source_ca_index is not None,
save_target_ca=target_ca_index is not None,
)
attn_collector = AttnCollector(
transformer=pipe.transformer,
controller=attn_controller,
attn_processor_class=NewFluxAttnProcessor2_0,
)
feature_controller = FeatureReplace(
layer_list=feature_layer_list,
feature_steps=feature_steps,
)
feature_collector = FeatureCollector(
transformer=pipe.transformer,
controller=feature_controller,
)
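
    # Sample the Gaussian noise used for inversion and pack it into FLUX's
    # (batch, sequence, channels) latent layout.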
    num_prompts = len(prompts)
shape = (1, 16, 128, 128)
generator = torch.Generator(device=device).manual_seed(seed)
latents = randn_tensor(shape, device=device, generator=generator)
latents = pipe._pack_latents(latents, *latents.shape)
attn_collector.restore_orig_attention()
feature_collector.restore_orig_transformer()
t0 = time.perf_counter()
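    # Invert the source image back toward noise along the deterministic ODE
    # (no prompt, no guidance); the returned list holds a latent for every timestep.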
inv_latents = get_inversed_latent_list(
pipe,
source_img,
random_noise=latents,
num_inference_steps=num_inference_steps,
backward_method="ode",
use_prompt_for_inversion=False,
guidance_scale_for_inversion=0,
prompt_for_inversion='',
flow_steps=flow_steps,
)
source_latents = inv_latents[::-1]
target_latents = inv_latents[::-1]
attn_collector.register_attention_control()
feature_collector.register_transformer_control()
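
    # Per-step callback: tracks the stored source latents, advances the attention and
    # feature controllers, and applies mask-based latent blending (given or CA-derived
    # mask) for the first mask_steps steps.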
callback_fn = CallbackAll(
latents=source_latents,
attn_collector=attn_collector,
feature_collector=feature_collector,
feature_inject_steps=feature_steps,
mid_step_index=mid_step_index,
step_start=step_start,
use_mask=use_mask,
use_ca_mask=use_ca_mask,
source_ca_index=source_ca_index,
target_ca_index=target_ca_index,
mask_kwargs={'dilation': mask_dilation},
mask_steps=mask_steps,
mask=mask,
)
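
    # Start the batched sampling from the inverted latents: the source branch (index 0)
    # resumes from the mid-step latent, the target branch from the step_start latent.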
init_latent = target_latents[step_start]
init_latent = init_latent.repeat(num_prompts, 1, 1)
init_latent[0] = source_latents[mid_step_index]
os.makedirs(result_img_dir, exist_ok=True)
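    # Use the forward scheduler variant so denoising can begin mid-trajectory
    # (at step_start) rather than from pure noise.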
pipe.scheduler = FlowMatchEulerDiscreteForwardScheduler.from_config(
pipe.scheduler.config,
step_start=step_start,
margin_index_from_image=0
)
attn_controller.reset()
feature_controller.reset()
attn_controller.text_scale = text_scale
attn_controller.cur_step = step_start
feature_controller.cur_step = step_start
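
    # One batched generation pass over [source_prompt, target_prompt]; the callback and
    # controllers transfer attention maps and features from the source branch to the
    # target branch at every step.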
with torch.no_grad():
images = pipe(
prompts,
latents=init_latent,
num_images_per_prompt=1,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
callback_on_step_end=callback_fn,
mid_step_index=mid_step_index,
step_start=step_start,
callback_on_step_end_tensor_inputs=['latents'],
).images
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s.")
    source_img_path = os.path.join(result_img_dir, "source.png")
    source_img.save(source_img_path)
    final_image = input_image
    for i, img in enumerate(images[1:]):
        target_img_path = os.path.join(result_img_dir, f"target_{i}.png")
        img.save(target_img_path)
        final_image = img
    target_text_path = os.path.join(result_img_dir, "target_prompts.txt")
    with open(target_text_path, 'w') as file:
        file.write(target_prompt + '\n')
    source_text_path = os.path.join(result_img_dir, "source_prompt.txt")
    with open(source_text_path, 'w') as file:
        file.write(source_prompt + '\n')
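
    # Overview figure: source image, source-branch reconstruction, edited result.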
    images = [source_img] + images
    fs = 3
    n = len(images)
    fig, ax = plt.subplots(1, n, figsize=(n * fs, fs))
for i, img in enumerate(images):
ax[i].imshow(img)
ax[0].set_title('source')
ax[1].set_title(source_prompt, fontsize=7)
ax[2].set_title(target_prompt, fontsize=7)
    overall_img_path = os.path.join(result_img_dir, "overall.png")
    plt.savefig(overall_img_path, bbox_inches='tight')
    plt.close()
    mask_save_dir = os.path.join(result_img_dir, "mask")
    os.makedirs(mask_save_dir, exist_ok=True)
    if use_ca_mask:
        ca_mask_path = os.path.join(mask_save_dir, "mask_ca.png")
mask_img = Image.fromarray((callback_fn.mask.cpu().float().numpy() * 255).astype(np.uint8)).convert('L')
mask_img.save(ca_mask_path)
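
    # Free the large intermediates and clear the CUDA cache before returning.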
del inv_latents
del init_latent
gc.collect()
torch.cuda.empty_cache()
    # Remove the temporary results directory (result_img_dir lives inside it).
    shutil.rmtree(results_dir)
return final_image, seed, gr.Button(visible=True)

MAX_SEED = np.iinfo(np.int32).max
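
# Wrapper used by gr.Examples: identical to infer() but drops the reuse-button update
# so that its outputs match [result, seed].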
@spaces.GPU
def infer_example(input_image, target_prompt, source_prompt, seed, ca_steps, sa_steps, feature_steps, attn_topk, mask_image=None):
img, seed, _ = infer(
input_image=input_image,
target_prompt=target_prompt,
source_prompt=source_prompt,
seed=seed,
ca_steps=ca_steps,
sa_steps=sa_steps,
feature_steps=feature_steps,
attn_topk=attn_topk,
mask_image=mask_image
)
return img, seed
with gr.Blocks() as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("""# ReFlex
Text-Guided Editing of Real Images in Rectified Flow via Mid-Step Feature Extraction and Attention Adaptation
[[blog]](https://wlaud1001.github.io/ReFlex/) | [[Github]](https://github.com/wlaud1001/ReFlex)
""")
with gr.Row():
with gr.Column():
input_image = gr.Image(label="Upload the image for editing", type="pil")
mask_image = gr.Image(label="Upload optional mask", type="pil")
with gr.Row():
target_prompt = gr.Text(
label="Target Prompt",
show_label=False,
max_lines=1,
placeholder="Describe the Edited Image",
container=False,
)
with gr.Column():
source_prompt = gr.Text(
label="Source Prompt",
show_label=False,
max_lines=1,
placeholder="Enter source prompt (optional) : Describe the Input Image",
container=False,
)
run_button = gr.Button("Run", scale=10)
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
ca_steps = gr.Slider(
label="Cross-Attn (CA) Steps",
minimum=0,
maximum=20,
step=1,
value=10
)
sa_steps = gr.Slider(
label="Self-Attn (SA) Steps",
minimum=0,
maximum=20,
step=1,
value=7
)
feature_steps = gr.Slider(
label="Feature Injection Steps",
minimum=0,
maximum=20,
step=1,
value=5
)
attn_topk = gr.Slider(
label="Attention Top-K",
minimum=1,
maximum=64,
step=1,
value=20
)
with gr.Column():
result = gr.Image(label="Result", show_label=False, interactive=False)
reuse_button = gr.Button("Reuse this image", visible=False)
examples = gr.Examples(
examples=[
            # Examples with only a target prompt (no source prompt, no mask)
[
"data/images/bear.jpeg",
"an image of Paddington the bear",
"",
0, 0, 12, 7, 20,
None
],
[
"data/images/bird_painting.jpg",
"a photo of an eagle in the sky",
"",
0, 0, 12, 7, 20,
None
],
[
"data/images/dancing.jpeg",
"a couple of silver robots dancing in the garden",
"",
0, 0, 12, 7, 20,
None
],
[
"data/images/real_karate.jpeg",
"a silver robot in the snow",
"",
0, 0, 12, 7, 20,
None
],
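            # Examples that also provide a source prompt (enables I2T-CA injection)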
[
"data/images/woman_book.jpg",
"a woman sitting in the grass with a laptop",
"a woman sitting in the grass with a book",
0, 10, 7, 5, 20,
None
],
[
"data/images/statue.jpg",
"photo of a statue in side view",
"photo of a statue in front view",
0, 10, 7, 5, 60,
None
],
[
"data/images/tennis.jpg",
"a iron woman robot in a black tank top and pink shorts is about to hit a tennis ball",
"a woman in a black tank top and pink shorts is about to hit a tennis ball",
0, 10, 7, 5, 20,
None
],
[
"data/images/owl_heart.jpg",
"a cartoon painting of a cute owl with a circle on its body",
"a cartoon painting of a cute owl with a heart on its body",
0, 10, 7, 5, 20,
None
],
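            # Examples that additionally provide a ground-truth mask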
[
"data/images/girl_mountain.jpg",
"a woman with her arms outstretched in front of the NewYork",
"a woman with her arms outstretched on top of a mountain",
0, 10, 7, 5, 20,
"data/masks/girl_mountain.jpg"
],
[
"data/images/santa.jpg",
"the christmas illustration of a santa's angry face",
"the christmas illustration of a santa's laughing face",
0, 10, 7, 5, 20,
"data/masks/santa.jpg"
],
[
"data/images/cat_mirror.jpg",
"a tiger sitting next to a mirror",
"a cat sitting next to a mirror",
0, 10, 7, 5, 20,
"data/masks/cat_mirror.jpg"
],
],
inputs=[
input_image,
target_prompt,
source_prompt,
seed,
ca_steps,
sa_steps,
feature_steps,
attn_topk,
mask_image
],
outputs=[result, seed],
fn=infer_example,
cache_examples="lazy"
)
gr.on(
triggers=[run_button.click, target_prompt.submit],
fn=infer,
inputs=[
input_image,
target_prompt,
source_prompt,
seed,
ca_steps,
sa_steps,
feature_steps,
attn_topk,
mask_image
],
outputs=[result, seed, reuse_button]
)
reuse_button.click(
fn=lambda image: image,
inputs=[result],
outputs=[input_image]
)
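
# share=True asks Gradio to create a public share link; debug=True keeps the process
# attached and prints errors to the console.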
demo.launch(share=True, debug=True)