import argparse
import gc
import os
import random
import re
import shutil
import time
from distutils.util import strtobool
from typing import Any, Callable, Dict, List, Optional, Union

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import spaces
import torch
import yaml
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image

from src.attn_utils.attn_utils import AttentionAdapter, AttnCollector
from src.attn_utils.flux_attn_processor import NewFluxAttnProcessor2_0
from src.attn_utils.seq_aligner import get_refinement_mapper
from src.callback.callback_fn import CallbackAll
from src.inversion.inverse import get_inversed_latent_list
from src.inversion.scheduling_flow_inverse import \
    FlowMatchEulerDiscreteForwardScheduler
from src.pipeline.flux_pipeline import NewFluxPipeline
from src.transformer_utils.transformer_utils import (FeatureCollector,
                                                     FeatureReplace)
from src.utils import (find_token_id_differences, find_word_token_indices,
                       get_flux_pipeline, mask_decode, mask_interpolate)

pipe = get_flux_pipeline(pipeline_class=NewFluxPipeline)
pipe = pipe.to("cuda")

def fix_seed(random_seed):
    """
    Fix every source of randomness so runs are reproducible
    (stabilizes results across experiments).
    """
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # in case of multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)
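
# Assumption: this Space targets ZeroGPU (the `spaces` package is imported but otherwise
# unused), so the editing entry point is decorated with `spaces.GPU` to request a GPU for
# each call. Drop the decorator if the Space runs on dedicated hardware instead.
@spaces.GPU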
def infer(
    input_image: Union[str, Image.Image],        # ⬅️ Main UI (uploaded image)
    target_prompt: Union[str, List[str]] = '',   # ⬅️ Main UI (text prompt)
    source_prompt: Union[str, List[str]] = '',   # ⬅️ Advanced accordion
    seed: int = 0,                                # ⬅️ Advanced accordion
    ca_steps: int = 10,                           # ⬅️ Advanced accordion
    sa_steps: int = 7,                            # ⬅️ Advanced accordion
    feature_steps: int = 5,                       # ⬅️ Advanced accordion
    attn_topk: int = 20,                          # ⬅️ Advanced accordion
    mask_image: Optional[Image.Image] = None,     # ⬅️ Advanced (optional upload)
    # Everything below is backend-related or defaults, not exposed in the UI
    blend_word: str = '',
    results_dir: str = 'results',
    model: str = 'flux',
    ca_attn_layer_from: int = 13,
    ca_attn_layer_to: int = 45,
    sa_attn_layer_from: int = 20,
    sa_attn_layer_to: int = 45,
    feature_layer_from: int = 13,
    feature_layer_to: int = 20,
    flow_steps: int = 7,
    step_start: int = 0,
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
    text_scale: float = 4.0,
    mid_step_index: int = 14,
    use_mask: bool = True,
    use_ca_mask: bool = True,
    mask_steps: int = 18,
    mask_dilation: int = 3,
    mask_nbins: int = 128,
):
    if isinstance(input_image, str):
        # The UI passes a PIL image, but the type hint also allows a file path;
        # opening it here is an added convenience, not original behavior.
        input_image = Image.open(input_image)
    if isinstance(mask_image, Image.Image):
        # Ensure the mask is single channel
        if mask_image.mode != "L":
            mask_image = mask_image.convert("L")

    fix_seed(seed)
    device = torch.device('cuda')

    attn_proc = NewFluxAttnProcessor2_0
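
    # The FLUX transformer stacks 57 attention blocks (19 double-stream + 38 single-stream);
    # cross-attention injection, self-attention injection, and feature injection each act on
    # their own sub-range of those layers.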
    layer_order = range(57)
    ca_layer_list = layer_order[ca_attn_layer_from:ca_attn_layer_to]
    sa_layer_list = layer_order[feature_layer_to:sa_attn_layer_to]
    feature_layer_list = layer_order[feature_layer_from:feature_layer_to]

    source_img = input_image.resize((1024, 1024)).convert("RGB")
    result_img_dir = f"{results_dir}/seed_{seed}/{target_prompt}"

    prompts = [source_prompt, target_prompt]
    mask_path = mask_image
    print(prompts)

    mask = None
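
    # Mask selection: a user-provided mask is used directly for latent blending; otherwise,
    # when a source prompt is given, a mask is derived from the cross-attention maps of the
    # tokens that differ between the prompts (or of blend_word, if it is supplied).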
    if use_mask:
        if mask_path is not None:
            mask = torch.tensor(np.array(mask_path)).bool()
            mask = mask.to(device)
            # Increase the latent blending steps when a ground-truth mask is used.
            mask_steps = int(num_inference_steps * 0.9)
            source_ca_index = None
            target_ca_index = None
            use_ca_mask = False
        elif use_ca_mask and source_prompt:
            mask = None
            if blend_word and blend_word in source_prompt:
                editing_source_token_index = find_word_token_indices(source_prompt, blend_word, pipe.tokenizer_2)
                editing_target_token_index = None
            else:
                editing_tokens_info = find_token_id_differences(*prompts, pipe.tokenizer_2)
                editing_source_token_index = editing_tokens_info['prompt_1']['index']
                editing_target_token_index = editing_tokens_info['prompt_2']['index']

            use_ca_mask = True
            if editing_source_token_index:
                source_ca_index = editing_source_token_index
                target_ca_index = None
            elif editing_target_token_index:
                source_ca_index = None
                target_ca_index = editing_target_token_index
            else:
                source_ca_index = None
                target_ca_index = None
                use_ca_mask = False
        else:
            source_ca_index = None
            target_ca_index = None
            use_ca_mask = False
    else:
        use_mask = False
        use_ca_mask = False
        source_ca_index = None
        target_ca_index = None
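
    # get_refinement_mapper (Prompt-to-Prompt-style sequence alignment) maps target tokens to
    # their source counterparts so I2T-CA maps can be injected at aligned positions; alphas
    # presumably flag which target tokens have a source counterpart.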
    if source_prompt:
        # Use I2T-CA injection
        mappers, alphas = get_refinement_mapper(prompts, pipe.tokenizer_2, max_len=512)
        mappers = mappers.to(device=device)
        alphas = alphas.to(device=device, dtype=pipe.dtype)
        alphas = alphas[:, None, None, :]
        attn_adj_from = 1
    else:
        # Skip I2T-CA injection when no source prompt is given
        mappers = None
        alphas = None
        ca_steps = 0
        attn_adj_from = 3
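
    # AttnCollector swaps the transformer's attention processors for NewFluxAttnProcessor2_0
    # so AttentionAdapter can record source attention and inject it into the target branch
    # (top-k self-attention replacement for sa_steps, scaled cross-attention for ca_steps).
    # FeatureCollector/FeatureReplace do the same for block features during feature_steps.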
    attn_controller = AttentionAdapter(
        ca_layer_list=ca_layer_list,
        sa_layer_list=sa_layer_list,
        ca_steps=ca_steps,
        sa_steps=sa_steps,
        method='replace_topk',
        topk=attn_topk,
        text_scale=text_scale,
        mappers=mappers,
        alphas=alphas,
        attn_adj_from=attn_adj_from,
        save_source_ca=source_ca_index is not None,
        save_target_ca=target_ca_index is not None,
    )
    attn_collector = AttnCollector(
        transformer=pipe.transformer,
        controller=attn_controller,
        attn_processor_class=attn_proc,
    )

    feature_controller = FeatureReplace(
        layer_list=feature_layer_list,
        feature_steps=feature_steps,
    )
    feature_collector = FeatureCollector(
        transformer=pipe.transformer,
        controller=feature_controller,
    )

    num_prompts = len(prompts)
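
    # The initial noise lives in FLUX's latent space: 16 channels at 128x128, i.e. a
    # 1024x1024 image downsampled 8x by the VAE. _pack_latents then folds 2x2 patches
    # into the token sequence the transformer expects.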
    shape = (1, 16, 128, 128)
    generator = torch.Generator(device=device).manual_seed(seed)
    latents = randn_tensor(shape, device=device, generator=generator)
    latents = pipe._pack_latents(latents, *latents.shape)

    attn_collector.restore_orig_attention()
    feature_collector.restore_orig_transformer()

    t0 = time.perf_counter()
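
    # Invert the source image toward noise along the rectified-flow ODE with an empty prompt
    # and zero guidance; the returned list is assumed to be the full latent trajectory, which
    # is reversed below so that index 0 is the noisiest latent.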
    inv_latents = get_inversed_latent_list(
        pipe,
        source_img,
        random_noise=latents,
        num_inference_steps=num_inference_steps,
        backward_method="ode",
        use_prompt_for_inversion=False,
        guidance_scale_for_inversion=0,
        prompt_for_inversion='',
        flow_steps=flow_steps,
    )

    source_latents = inv_latents[::-1]
    target_latents = inv_latents[::-1]

    attn_collector.register_attention_control()
    feature_collector.register_transformer_control()
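
    # CallbackAll is invoked after each denoising step; given the arguments below, it
    # presumably keeps the source branch on the inverted-latent trajectory, advances the
    # attention/feature controllers, and blends latents with the (CA-derived or provided)
    # mask for the first mask_steps steps.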
    callback_fn = CallbackAll(
        latents=source_latents,
        attn_collector=attn_collector,
        feature_collector=feature_collector,
        feature_inject_steps=feature_steps,
        mid_step_index=mid_step_index,
        step_start=step_start,
        use_mask=use_mask,
        use_ca_mask=use_ca_mask,
        source_ca_index=source_ca_index,
        target_ca_index=target_ca_index,
        mask_kwargs={'dilation': mask_dilation},
        mask_steps=mask_steps,
        mask=mask,
    )

    init_latent = target_latents[step_start]
    init_latent = init_latent.repeat(num_prompts, 1, 1)
    init_latent[0] = source_latents[mid_step_index]

    os.makedirs(result_img_dir, exist_ok=True)

    pipe.scheduler = FlowMatchEulerDiscreteForwardScheduler.from_config(
        pipe.scheduler.config,
        step_start=step_start,
        margin_index_from_image=0
    )

    attn_controller.reset()
    feature_controller.reset()
    attn_controller.text_scale = text_scale
    attn_controller.cur_step = step_start
    feature_controller.cur_step = step_start
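
    # Denoise both prompts jointly as a batch of two: row 0 reconstructs the source image
    # (it starts from the mid-step inverted latent), while row 1 produces the edit and
    # receives the attention maps and features injected from the source branch.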
    with torch.no_grad():
        images = pipe(
            prompts,
            latents=init_latent,
            num_images_per_prompt=1,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            callback_on_step_end=callback_fn,
            mid_step_index=mid_step_index,
            step_start=step_start,
            callback_on_step_end_tensor_inputs=['latents'],
        ).images

    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.1f}s.")
    source_img_path = os.path.join(result_img_dir, "source.png")
    source_img.save(source_img_path)

    final_image = input_image
    for i, img in enumerate(images[1:]):
        target_img_path = os.path.join(result_img_dir, f"target_{i}.png")
        img.save(target_img_path)
        final_image = img

    target_text_path = os.path.join(result_img_dir, "target_prompts.txt")
    with open(target_text_path, 'w') as file:
        file.write(target_prompt + '\n')

    source_text_path = os.path.join(result_img_dir, "source_prompt.txt")
    with open(source_text_path, 'w') as file:
        file.write(source_prompt + '\n')

    images = [source_img] + images

    # Save a side-by-side overview of the source image and both generated outputs.
    fs = 3
    n = len(images)
    fig, ax = plt.subplots(1, n, figsize=(n * fs, 1 * fs))
    for i, img in enumerate(images):
        ax[i].imshow(img)
    ax[0].set_title('source')
    ax[1].set_title(source_prompt, fontsize=7)
    ax[2].set_title(target_prompt, fontsize=7)

    overall_img_path = os.path.join(result_img_dir, "overall.png")
    plt.savefig(overall_img_path, bbox_inches='tight')
    plt.close()

    mask_save_dir = os.path.join(result_img_dir, "mask")
    os.makedirs(mask_save_dir, exist_ok=True)
    if use_ca_mask:
        ca_mask_path = os.path.join(mask_save_dir, "mask_ca.png")
        mask_img = Image.fromarray((callback_fn.mask.cpu().float().numpy() * 255).astype(np.uint8)).convert('L')
        mask_img.save(ca_mask_path)

    del inv_latents
    del init_latent
    gc.collect()
    torch.cuda.empty_cache()

    # The files above are only written as a side effect; clean them up before returning.
    shutil.rmtree(results_dir)

    return final_image, seed, gr.Button(visible=True)

MAX_SEED = np.iinfo(np.int32).max
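
# Thin wrapper used by gr.Examples: cached examples return only (image, seed) and drop the
# reuse-button update that infer() returns for interactive runs.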
def infer_example(input_image, target_prompt, source_prompt, seed, ca_steps, sa_steps, feature_steps, attn_topk, mask_image=None):
    img, seed, _ = infer(
        input_image=input_image,
        target_prompt=target_prompt,
        source_prompt=source_prompt,
        seed=seed,
        ca_steps=ca_steps,
        sa_steps=sa_steps,
        feature_steps=feature_steps,
        attn_topk=attn_topk,
        mask_image=mask_image,
    )
    return img, seed

with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# ReFlex
Text-Guided Editing of Real Images in Rectified Flow via Mid-Step Feature Extraction and Attention Adaptation

[[blog]](https://wlaud1001.github.io/ReFlex/) | [[Github]](https://github.com/wlaud1001/ReFlex)
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Upload the image for editing", type="pil")
                mask_image = gr.Image(label="Upload an optional mask", type="pil")

                with gr.Row():
                    target_prompt = gr.Text(
                        label="Target Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Describe the edited image",
                        container=False,
                    )
                    with gr.Column():
                        source_prompt = gr.Text(
                            label="Source Prompt",
                            show_label=False,
                            max_lines=1,
                            placeholder="Enter a source prompt (optional): describe the input image",
                            container=False,
                        )
                    run_button = gr.Button("Run", scale=10)

                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    ca_steps = gr.Slider(
                        label="Cross-Attn (CA) Steps",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=10,
                    )
                    sa_steps = gr.Slider(
                        label="Self-Attn (SA) Steps",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=7,
                    )
                    feature_steps = gr.Slider(
                        label="Feature Injection Steps",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=5,
                    )
                    attn_topk = gr.Slider(
                        label="Attention Top-K",
                        minimum=1,
                        maximum=64,
                        step=1,
                        value=20,
                    )

            with gr.Column():
                result = gr.Image(label="Result", show_label=False, interactive=False)
                reuse_button = gr.Button("Reuse this image", visible=False)
        examples = gr.Examples(
            examples=[
                # Examples without a source prompt or a mask
                [
                    "data/images/bear.jpeg",
                    "an image of Paddington the bear",
                    "",
                    0, 0, 12, 7, 20,
                    None
                ],
                [
                    "data/images/bird_painting.jpg",
                    "a photo of an eagle in the sky",
                    "",
                    0, 0, 12, 7, 20,
                    None
                ],
                [
                    "data/images/dancing.jpeg",
                    "a couple of silver robots dancing in the garden",
                    "",
                    0, 0, 12, 7, 20,
                    None
                ],
                [
                    "data/images/real_karate.jpeg",
                    "a silver robot in the snow",
                    "",
                    0, 0, 12, 7, 20,
                    None
                ],
                # Examples with a source prompt
                [
                    "data/images/woman_book.jpg",
                    "a woman sitting in the grass with a laptop",
                    "a woman sitting in the grass with a book",
                    0, 10, 7, 5, 20,
                    None
                ],
                [
                    "data/images/statue.jpg",
                    "photo of a statue in side view",
                    "photo of a statue in front view",
                    0, 10, 7, 5, 60,
                    None
                ],
                [
                    "data/images/tennis.jpg",
                    "an iron woman robot in a black tank top and pink shorts is about to hit a tennis ball",
                    "a woman in a black tank top and pink shorts is about to hit a tennis ball",
                    0, 10, 7, 5, 20,
                    None
                ],
                [
                    "data/images/owl_heart.jpg",
                    "a cartoon painting of a cute owl with a circle on its body",
                    "a cartoon painting of a cute owl with a heart on its body",
                    0, 10, 7, 5, 20,
                    None
                ],
                # Examples with a source prompt and a ground-truth mask
                [
                    "data/images/girl_mountain.jpg",
                    "a woman with her arms outstretched in front of New York",
                    "a woman with her arms outstretched on top of a mountain",
                    0, 10, 7, 5, 20,
                    "data/masks/girl_mountain.jpg"
                ],
                [
                    "data/images/santa.jpg",
                    "the Christmas illustration of a Santa's angry face",
                    "the Christmas illustration of a Santa's laughing face",
                    0, 10, 7, 5, 20,
                    "data/masks/santa.jpg"
                ],
                [
                    "data/images/cat_mirror.jpg",
                    "a tiger sitting next to a mirror",
                    "a cat sitting next to a mirror",
                    0, 10, 7, 5, 20,
                    "data/masks/cat_mirror.jpg"
                ],
            ],
            inputs=[
                input_image,
                target_prompt,
                source_prompt,
                seed,
                ca_steps,
                sa_steps,
                feature_steps,
                attn_topk,
                mask_image,
            ],
            outputs=[result, seed],
            fn=infer_example,
            cache_examples="lazy",
        )
    gr.on(
        triggers=[run_button.click, target_prompt.submit],
        fn=infer,
        inputs=[
            input_image,
            target_prompt,
            source_prompt,
            seed,
            ca_steps,
            sa_steps,
            feature_steps,
            attn_topk,
            mask_image,
        ],
        outputs=[result, seed, reuse_button],
    )

    reuse_button.click(
        fn=lambda image: image,
        inputs=[result],
        outputs=[input_image],
    )

demo.launch(share=True, debug=True)