import gc
import os
import random
import shutil
import time
from typing import List, Optional, Union

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image

from src.attn_utils.attn_utils import AttentionAdapter, AttnCollector
from src.attn_utils.flux_attn_processor import NewFluxAttnProcessor2_0
from src.attn_utils.seq_aligner import get_refinement_mapper
from src.callback.callback_fn import CallbackAll
from src.inversion.inverse import get_inversed_latent_list
from src.inversion.scheduling_flow_inverse import \
    FlowMatchEulerDiscreteForwardScheduler
from src.pipeline.flux_pipeline import NewFluxPipeline
from src.transformer_utils.transformer_utils import (FeatureCollector,
                                                     FeatureReplace)
from src.utils import (find_token_id_differences, find_word_token_indices,
                       get_flux_pipeline, mask_decode, mask_interpolate)

pipe = get_flux_pipeline(pipeline_class=NewFluxPipeline)
pipe = pipe.to("cuda")


def fix_seed(random_seed):
    """Fix all random seeds so that results are reproducible across runs."""
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)


@spaces.GPU
def infer(
    input_image: Union[str, Image.Image],        # ⬅️ Main UI (uploaded image)
    target_prompt: Union[str, List[str]] = '',   # ⬅️ Main UI (text prompt)
    source_prompt: Union[str, List[str]] = '',   # ⬅️ Advanced accordion
    seed: int = 0,                               # ⬅️ Advanced accordion
    ca_steps: int = 10,                          # ⬅️ Advanced accordion
    sa_steps: int = 7,                           # ⬅️ Advanced accordion
    feature_steps: int = 5,                      # ⬅️ Advanced accordion
    attn_topk: int = 20,                         # ⬅️ Advanced accordion
    mask_image: Optional[Image.Image] = None,    # ⬅️ Advanced (optional upload)
    # Everything below is backend-related or defaults, not exposed in the UI.
    blend_word: str = '',
    results_dir: str = 'results',
    model: str = 'flux',
    ca_attn_layer_from: int = 13,
    ca_attn_layer_to: int = 45,
    sa_attn_layer_from: int = 20,
    sa_attn_layer_to: int = 45,
    feature_layer_from: int = 13,
    feature_layer_to: int = 20,
    flow_steps: int = 7,
    step_start: int = 0,
    num_inference_steps: int = 28,
    guidance_scale: float = 3.5,
    text_scale: float = 4.0,
    mid_step_index: int = 14,
    use_mask: bool = True,
    use_ca_mask: bool = True,
    mask_steps: int = 18,
    mask_dilation: int = 3,
    mask_nbins: int = 128,
):
    if isinstance(mask_image, Image.Image):
        # Ensure the mask is single channel.
        if mask_image.mode != "L":
            mask_image = mask_image.convert("L")

    fix_seed(seed)
    device = torch.device('cuda')

    layer_order = range(57)  # FLUX has 57 transformer blocks (19 double-stream + 38 single-stream).
    ca_layer_list = layer_order[ca_attn_layer_from:ca_attn_layer_to]
    sa_layer_list = layer_order[feature_layer_to:sa_attn_layer_to]
    feature_layer_list = layer_order[feature_layer_from:feature_layer_to]

    source_img = input_image.resize((1024, 1024)).convert("RGB")
    result_img_dir = f"{results_dir}/seed_{seed}/{target_prompt}"

    prompts = [source_prompt, target_prompt]
    print(prompts)

    mask = None
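    # Mask selection, in order of priority:
    #   1. an uploaded mask image (used directly for latent blending),
    #   2. a mask derived from the cross-attention (CA) maps of the edited tokens,
    #   3. no mask.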
    if use_mask:
        if mask_image is not None:
            mask = torch.tensor(np.array(mask_image)).bool()
            mask = mask.to(device)
            # Increase the latent blending steps when a ground-truth mask is provided.
            mask_steps = int(num_inference_steps * 0.9)
            source_ca_index = None
            target_ca_index = None
            use_ca_mask = False
        elif use_ca_mask and source_prompt:
            mask = None
            if blend_word and blend_word in source_prompt:
                editing_source_token_index = find_word_token_indices(source_prompt, blend_word, pipe.tokenizer_2)
                editing_target_token_index = None
            else:
                editing_tokens_info = find_token_id_differences(*prompts, pipe.tokenizer_2)
                editing_source_token_index = editing_tokens_info['prompt_1']['index']
                editing_target_token_index = editing_tokens_info['prompt_2']['index']

            use_ca_mask = True
            if editing_source_token_index:
                source_ca_index = editing_source_token_index
                target_ca_index = None
            elif editing_target_token_index:
                source_ca_index = None
                target_ca_index = editing_target_token_index
            else:
                source_ca_index = None
                target_ca_index = None
                use_ca_mask = False
        else:
            source_ca_index = None
            target_ca_index = None
            use_ca_mask = False
    else:
        use_ca_mask = False
        source_ca_index = None
        target_ca_index = None

    if source_prompt:
        # Use I2T-CA (image-to-text cross-attention) injection.
        mappers, alphas = get_refinement_mapper(prompts, pipe.tokenizer_2, max_len=512)
        mappers = mappers.to(device=device)
        alphas = alphas.to(device=device, dtype=pipe.dtype)
        alphas = alphas[:, None, None, :]
        attn_adj_from = 1
    else:
        # No I2T-CA injection without a source prompt.
        mappers = None
        alphas = None
        ca_steps = 0
        attn_adj_from = 3

    attn_controller = AttentionAdapter(
        ca_layer_list=ca_layer_list,
        sa_layer_list=sa_layer_list,
        ca_steps=ca_steps,
        sa_steps=sa_steps,
        method='replace_topk',
        topk=attn_topk,
        text_scale=text_scale,
        mappers=mappers,
        alphas=alphas,
        attn_adj_from=attn_adj_from,
        save_source_ca=source_ca_index is not None,
        save_target_ca=target_ca_index is not None,
    )
    attn_collector = AttnCollector(
        transformer=pipe.transformer,
        controller=attn_controller,
        attn_processor_class=NewFluxAttnProcessor2_0,
    )
    feature_controller = FeatureReplace(
        layer_list=feature_layer_list,
        feature_steps=feature_steps,
    )
    feature_collector = FeatureCollector(
        transformer=pipe.transformer,
        controller=feature_controller,
    )

    num_prompts = len(prompts)
    shape = (1, 16, 128, 128)
    generator = torch.Generator(device=device).manual_seed(seed)
    latents = randn_tensor(shape, device=device, generator=generator)
    latents = pipe._pack_latents(latents, *latents.shape)

    attn_collector.restore_orig_attention()
    feature_collector.restore_orig_transformer()

    t0 = time.perf_counter()
    # Invert the source image to noise (ODE inversion without a prompt).
    inv_latents = get_inversed_latent_list(
        pipe,
        source_img,
        random_noise=latents,
        num_inference_steps=num_inference_steps,
        backward_method="ode",
        use_prompt_for_inversion=False,
        guidance_scale_for_inversion=0,
        prompt_for_inversion='',
        flow_steps=flow_steps,
    )
    source_latents = inv_latents[::-1]
    target_latents = inv_latents[::-1]

    attn_collector.register_attention_control()
    feature_collector.register_transformer_control()

    callback_fn = CallbackAll(
        latents=source_latents,
        attn_collector=attn_collector,
        feature_collector=feature_collector,
        feature_inject_steps=feature_steps,
        mid_step_index=mid_step_index,
        step_start=step_start,
        use_mask=use_mask,
        use_ca_mask=use_ca_mask,
        source_ca_index=source_ca_index,
        target_ca_index=target_ca_index,
        mask_kwargs={'dilation': mask_dilation},
        mask_steps=mask_steps,
        mask=mask,
    )

    init_latent = target_latents[step_start]
    init_latent = init_latent.repeat(num_prompts, 1, 1)
    init_latent[0] = source_latents[mid_step_index]
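    # Denoise the [source, target] batch. The forward scheduler starts at
    # `step_start`, and `callback_fn` performs the attention/feature injection
    # and masked latent blending at each step.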
    os.makedirs(result_img_dir, exist_ok=True)

    pipe.scheduler = FlowMatchEulerDiscreteForwardScheduler.from_config(
        pipe.scheduler.config,
        step_start=step_start,
        margin_index_from_image=0,
    )

    attn_controller.reset()
    feature_controller.reset()
    attn_controller.text_scale = text_scale
    attn_controller.cur_step = step_start
    feature_controller.cur_step = step_start

    with torch.no_grad():
        images = pipe(
            prompts,
            latents=init_latent,
            num_images_per_prompt=1,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            callback_on_step_end=callback_fn,
            mid_step_index=mid_step_index,
            step_start=step_start,
            callback_on_step_end_tensor_inputs=['latents'],
        ).images
    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.1f}s.")

    # Save the source image and the edited result(s).
    source_img_path = os.path.join(result_img_dir, "source.png")
    source_img.save(source_img_path)
    final_image = input_image
    for i, img in enumerate(images[1:]):
        target_img_path = os.path.join(result_img_dir, f"target_{i}.png")
        img.save(target_img_path)
        final_image = img

    target_text_path = os.path.join(result_img_dir, "target_prompts.txt")
    with open(target_text_path, 'w') as file:
        file.write(target_prompt + '\n')
    source_text_path = os.path.join(result_img_dir, "source_prompt.txt")
    with open(source_text_path, 'w') as file:
        file.write(source_prompt + '\n')

    # Side-by-side overview: input | source reconstruction | edited result.
    images = [source_img] + images
    fs = 3
    n = len(images)
    fig, ax = plt.subplots(1, n, figsize=(n * fs, fs))
    for i, img in enumerate(images):
        ax[i].imshow(img)
    ax[0].set_title('source')
    ax[1].set_title(source_prompt, fontsize=7)
    ax[2].set_title(target_prompt, fontsize=7)
    overall_img_path = os.path.join(result_img_dir, "overall.png")
    plt.savefig(overall_img_path, bbox_inches='tight')
    plt.close()

    mask_save_dir = os.path.join(result_img_dir, "mask")
    os.makedirs(mask_save_dir, exist_ok=True)
    if use_ca_mask:
        ca_mask_path = os.path.join(mask_save_dir, "mask_ca.png")
        mask_img = Image.fromarray(
            (callback_fn.mask.cpu().float().numpy() * 255).astype(np.uint8)
        ).convert('L')
        mask_img.save(ca_mask_path)

    del inv_latents
    del init_latent
    gc.collect()
    torch.cuda.empty_cache()

    # Clean up all temporary outputs (result_img_dir lives inside results_dir).
    shutil.rmtree(results_dir)

    return final_image, seed, gr.Button(visible=True)


MAX_SEED = np.iinfo(np.int32).max


@spaces.GPU
def infer_example(input_image, target_prompt, source_prompt, seed, ca_steps,
                  sa_steps, feature_steps, attn_topk, mask_image=None):
    img, seed, _ = infer(
        input_image=input_image,
        target_prompt=target_prompt,
        source_prompt=source_prompt,
        seed=seed,
        ca_steps=ca_steps,
        sa_steps=sa_steps,
        feature_steps=feature_steps,
        attn_topk=attn_topk,
        mask_image=mask_image,
    )
    return img, seed


with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# ReFlex
Text-Guided Editing of Real Images in Rectified Flow via Mid-Step Feature Extraction and Attention Adaptation

[[blog]](https://wlaud1001.github.io/ReFlex/) | [[Github]](https://github.com/wlaud1001/ReFlex)
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Upload the image for editing", type="pil")
                mask_image = gr.Image(label="Upload an optional mask", type="pil")
                with gr.Row():
                    target_prompt = gr.Text(
                        label="Target Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Describe the edited image",
                        container=False,
                    )
                with gr.Column():
                    source_prompt = gr.Text(
                        label="Source Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter the source prompt (optional): describe the input image",
                        container=False,
                    )
                run_button = gr.Button("Run", scale=10)
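                # The advanced controls mirror the tunable arguments of infer();
                # the slider defaults match the function's default values.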
Settings", open=False): seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, ) ca_steps = gr.Slider( label="Cross-Attn (CA) Steps", minimum=0, maximum=20, step=1, value=10 ) sa_steps = gr.Slider( label="Self-Attn (SA) Steps", minimum=0, maximum=20, step=1, value=7 ) feature_steps = gr.Slider( label="Feature Injection Steps", minimum=0, maximum=20, step=1, value=5 ) attn_topk = gr.Slider( label="Attention Top-K", minimum=1, maximum=64, step=1, value=20 ) with gr.Column(): result = gr.Image(label="Result", show_label=False, interactive=False) reuse_button = gr.Button("Reuse this image", visible=False) examples = gr.Examples( examples=[ # 2. Without mask [ "data/images/bear.jpeg", "an image of Paddington the bear", "", 0, 0, 12, 7, 20, None ], # 3. Without mask [ "data/images/bird_painting.jpg", "a photo of an eagle in the sky", "", 0, 0, 12, 7, 20, None ], [ "data/images/dancing.jpeg", "a couple of silver robots dancing in the garden", "", 0, 0, 12, 7, 20, None ], [ "data/images/real_karate.jpeg", "a silver robot in the snow", "", 0, 0, 12, 7, 20, None ], [ "data/images/woman_book.jpg", "a woman sitting in the grass with a laptop", "a woman sitting in the grass with a book", 0, 10, 7, 5, 20, None ], [ "data/images/statue.jpg", "photo of a statue in side view", "photo of a statue in front view", 0, 10, 7, 5, 60, None ], [ "data/images/tennis.jpg", "a iron woman robot in a black tank top and pink shorts is about to hit a tennis ball", "a woman in a black tank top and pink shorts is about to hit a tennis ball", 0, 10, 7, 5, 20, None ], [ "data/images/owl_heart.jpg", "a cartoon painting of a cute owl with a circle on its body", "a cartoon painting of a cute owl with a heart on its body", 0, 10, 7, 5, 20, None ], [ "data/images/girl_mountain.jpg", "a woman with her arms outstretched in front of the NewYork", "a woman with her arms outstretched on top of a mountain", 0, 10, 7, 5, 20, "data/masks/girl_mountain.jpg" ], [ "data/images/santa.jpg", "the christmas illustration of a santa's angry face", "the christmas illustration of a santa's laughing face", 0, 10, 7, 5, 20, "data/masks/santa.jpg" ], [ "data/images/cat_mirror.jpg", "a tiger sitting next to a mirror", "a cat sitting next to a mirror", 0, 10, 7, 5, 20, "data/masks/cat_mirror.jpg" ], ], inputs=[ input_image, target_prompt, source_prompt, seed, ca_steps, sa_steps, feature_steps, attn_topk, mask_image ], outputs=[result, seed], fn=infer_example, cache_examples="lazy" ) gr.on( triggers=[run_button.click, target_prompt.submit], fn=infer, inputs=[ input_image, target_prompt, source_prompt, seed, ca_steps, sa_steps, feature_steps, attn_topk, mask_image ], outputs=[result, seed, reuse_button] ) reuse_button.click( fn=lambda image: image, inputs=[result], outputs=[input_image] ) demo.launch(share=True, debug=True)