#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
IC-Custom Gradio Application

This module defines the UI and glue logic to run the IC-Custom pipeline
via Gradio. The code aims to keep UI text user-friendly while keeping the
implementation readable and maintainable.
"""
# Standard library imports
import os
import sys
import time

# Third-party imports
import numpy as np
import torch
import gradio as gr
import spaces
from PIL import Image

# Make the app/ directory importable BEFORE the local imports below.
sys.path.append(os.getcwd() + '/app')

# Local application imports (modular UI components, event wiring, logic)
from config import parse_args, load_config, setup_environment
from ui_components import (
    create_theme, create_css, create_header_section, create_customization_section,
    create_image_input_section, create_prompt_section, create_advanced_options_section,
    create_mask_operation_section, create_output_section, create_examples_section,
    create_citation_section
)
from event_handlers import setup_event_handlers
from business_logic import (
    init_image_target_1, init_image_target_2, init_image_reference,
    undo_seg_points, segmentation, get_point, get_brush,
    dilate_mask, erode_mask, bounding_box,
    change_input_mask_mode, change_custmization_mode, change_seg_ref_mode,
    vlm_auto_generate, vlm_auto_polish, save_results, set_mobile_predictor,
    set_ben2_model, set_vlm_processor, set_vlm_model,
)
from utils import (
    get_sam_predictor, get_vlm, get_ben2_model,
    prepare_input_images, get_mask_type_ids
)
from examples import GRADIO_EXAMPLES, MASK_TGT, IMG_GEN
from ic_custom.pipelines.ic_custom_pipeline import ICCustomPipeline
# Module-level handles injected at startup (see set_pipeline / set_assets_cache_dir).
PIPELINE = None
ASSETS_CACHE_DIR = None

# Hugging Face Hub configuration for Spaces deployments: force fresh model
# downloads, disable telemetry, and route caches into /tmp (the writable
# scratch area on Spaces).
os.environ["HF_HUB_FORCE_DOWNLOAD"] = "1"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"  # NOTE(review): deprecated in newer transformers; HF_HOME below also covers it
os.environ["HF_HOME"] = "/tmp/hf_home"  # Use temp directory for Spaces
# Keep Gradio's temp files next to the app so allowed_paths in main() can serve them.
os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))
def set_pipeline(pipeline):
    """Inject the IC-Custom pipeline into this module's global slot.

    Keeps `run_model`'s signature free of a pipeline parameter: event
    handlers call `run_model`, which reads the module-level PIPELINE.

    Args:
        pipeline: An initialized ICCustomPipeline instance (or None).
    """
    global PIPELINE
    PIPELINE = pipeline
def set_assets_cache_dir(assets_cache_dir):
    """Inject the assets cache directory into this module's global slot.

    Args:
        assets_cache_dir: Path where generated assets are cached/saved.
    """
    global ASSETS_CACHE_DIR
    ASSETS_CACHE_DIR = assets_cache_dir
def _ckpt_or_fallback(local_path, hub_id):
    """Return `local_path` if it exists on disk, else the hub identifier.

    On a fresh Spaces deployment the local checkpoints are typically absent,
    so the pipeline falls back to downloading by identifier.
    """
    return local_path if os.path.exists(local_path) else hub_id


def initialize_models(args, cfg, device, weight_dtype):
    """Initialize all required models.

    Args:
        args: Parsed CLI arguments (feature flags such as
            `enable_vlm_for_prompt` and `enable_ben2_for_mask_ref`).
        cfg: Loaded config with `checkpoint_config` / `model_config` sections.
        device: torch.device the models are placed on.
        weight_dtype: torch dtype for model weights (e.g. torch.bfloat16).

    Returns:
        tuple: (pipeline, mobile_predictor, vlm_processor, vlm_model,
        ben2_model). The VLM and BEN2 entries are None when their feature
        flags are disabled.
    """
    ckpt = cfg.checkpoint_config

    # Load IC-Custom pipeline; every checkpoint path falls back to a hub id.
    pipeline = ICCustomPipeline(
        clip_path=_ckpt_or_fallback(ckpt.clip_path, "clip-vit-large-patch14"),
        t5_path=_ckpt_or_fallback(ckpt.t5_path, "t5-v1_1-xxl"),
        siglip_path=_ckpt_or_fallback(ckpt.siglip_path, "siglip-so400m-patch14-384"),
        ae_path=_ckpt_or_fallback(ckpt.ae_path, "flux-fill-dev-ae"),
        dit_path=_ckpt_or_fallback(ckpt.dit_path, "flux-fill-dev-dit"),
        redux_path=_ckpt_or_fallback(ckpt.redux_path, "flux1-redux-dev"),
        lora_path=_ckpt_or_fallback(ckpt.lora_path, "dit_lora_0x1561"),
        img_txt_in_path=_ckpt_or_fallback(ckpt.img_txt_in_path, "dit_txt_img_in_0x1561"),
        boundary_embeddings_path=_ckpt_or_fallback(ckpt.boundary_embeddings_path, "dit_boundary_embeddings_0x1561"),
        task_register_embeddings_path=_ckpt_or_fallback(ckpt.task_register_embeddings_path, "dit_task_register_embeddings_0x1561"),
        network_alpha=cfg.model_config.network_alpha,
        double_blocks_idx=cfg.model_config.double_blocks,
        single_blocks_idx=cfg.model_config.single_blocks,
        device=device,
        weight_dtype=weight_dtype,
        offload=True,  # offload between stages to fit constrained GPU memory
    )
    pipeline.set_pipeline_offload(True)
    # pipeline.set_show_progress(True)

    # SAM predictor for interactive point-based target segmentation.
    mobile_predictor = get_sam_predictor(ckpt.sam_path, device)

    # Optional VLM used for prompt auto-generation/polishing.
    vlm_processor, vlm_model = None, None
    if args.enable_vlm_for_prompt:
        vlm_processor, vlm_model = get_vlm(
            ckpt.vlm_path,
            device=device,
            torch_dtype=weight_dtype,
        )

    # Optional BEN2 model used for reference-image background removal.
    ben2_model = None
    if args.enable_ben2_for_mask_ref:
        ben2_model = get_ben2_model(ckpt.ben2_path, device)

    return pipeline, mobile_predictor, vlm_processor, vlm_model, ben2_model
def run_model(
    image_target_state, mask_target_state, image_reference_ori_state,
    image_reference_rmbg_state, prompt, seed, guidance, true_gs, num_steps,
    num_images_per_prompt, use_background_preservation, background_blend_threshold,
    aspect_ratio, custmization_mode, seg_ref_mode, input_mask_mode,
    progress=gr.Progress()  # Gradio progress-tracking idiom; injected per call
):
    """Run IC-Custom pipeline with current UI state and return images.

    Returns:
        tuple: (generated images or None on validation failure, next seed
        value for the seed widget, gr.update for the prompt textbox).
    """
    start_ts = time.time()
    progress(0, desc="Starting generation...")

    # Pick the background-removed or original reference per the UI mode.
    if seg_ref_mode == "Masked Ref":
        image_reference_state = image_reference_rmbg_state
    else:
        image_reference_state = image_reference_ori_state

    # Input validation: warn and return early, keeping the UI responsive.
    if image_reference_state is None:
        gr.Warning('Please upload the reference image')
        return None, seed, gr.update(placeholder="Last Input: " + prompt, value="")
    if image_target_state is None and custmization_mode != "Position-free":
        gr.Warning('Please upload the target image and mask it')
        return None, seed, gr.update(placeholder="Last Input: " + prompt, value="")
    if custmization_mode == "Position-aware" and mask_target_state is None:
        gr.Warning('Please select/draw the target mask')
        # Consistency fix: use the same "Last Input: " placeholder format as
        # the other validation branches (previously bare `prompt`).
        return None, seed, gr.update(placeholder="Last Input: " + prompt, value="")

    mask_type_ids = get_mask_type_ids(custmization_mode, input_mask_mode)

    # Local import: constants lives beside this app module on sys.path.
    from constants import ASPECT_RATIO_TEMPLATE
    output_w, output_h = ASPECT_RATIO_TEMPLATE[aspect_ratio]

    image_reference, image_target, mask_target = prepare_input_images(
        image_reference_state, custmization_mode, image_target_state, mask_target_state,
        width=output_w, height=output_h,
        force_resize_long_edge="long edge" in aspect_ratio,
        return_type="pil"
    )
    gr.Info(f"Output WH resolution: {image_target.size[0]}px x {image_target.size[1]}px")

    # Seed handling: -1 is the "randomize" sentinel from the UI.
    if seed == -1:
        seed = torch.randint(0, 2147483647, (1,)).item()

    # The model sees reference and target side by side, so the working canvas
    # width is the sum of both widths at the target's height.
    width, height = image_target.size[0] + image_reference.size[0], image_target.size[1]

    with torch.no_grad():
        output_img = PIPELINE(
            prompt=prompt, width=width, height=height, guidance=guidance,
            num_steps=num_steps, seed=seed, img_ref=image_reference,
            img_target=image_target, mask_target=mask_target, img_ip=image_reference,
            cond_w_regions=[image_reference.size[0]], mask_type_ids=mask_type_ids,
            use_background_preservation=use_background_preservation,
            background_blend_threshold=background_blend_threshold, true_gs=true_gs,
            neg_prompt="worst quality, normal quality, low quality, low res, blurry,",
            num_images_per_prompt=num_images_per_prompt,
            gradio_progress=progress,
        )

    elapsed = time.time() - start_ts
    progress(1.0, desc=f"Completed in {elapsed:.2f}s!")
    gr.Info(f"Finished in {elapsed:.2f}s")
    # Return -1 so the next run re-randomizes unless the user pins a seed.
    return output_img, -1, gr.update(placeholder=f"Last Input ({elapsed:.2f}s): " + prompt, value="")
def example_pipeline(
    image_reference, image_target_1, image_target_2, custmization_mode,
    input_mask_mode, seg_ref_mode, prompt, seed, true_gs, eg_idx,
    num_steps, guidance
):
    """Handle example loading in the UI.

    Converts the example's PIL images into numpy state arrays, selects the
    correct reference/target slots for the active modes, and returns the
    precomputed mask/result galleries plus visibility updates for the two
    target-image widgets.
    """
    # Reference image: store it in the slot matching the segmentation mode;
    # the unused slot is reset to None.
    if seg_ref_mode == "Full Ref":
        image_reference_ori_state = np.array(image_reference.convert("RGB"))
        image_reference_rmbg_state = None
    else:
        image_reference_rmbg_state = np.array(image_reference.convert("RGB"))
        image_reference_ori_state = None

    # Target image and mask: source depends on customization / mask mode.
    if custmization_mode == "Position-aware":
        if input_mask_mode == "Precise mask":
            image_target_state = np.array(image_target_1.convert("RGB"))
        else:
            image_target_state = np.array(image_target_2['composite'].convert("RGB"))
        mask_target_state = np.array(Image.open(MASK_TGT[int(eg_idx)]))
    else:  # Position-free mode
        # For Position-free, use the target image from image_target_1 and
        # the corresponding precomputed mask.
        image_target_state = np.array(image_target_1.convert("RGB"))
        mask_target_state = np.array(Image.open(MASK_TGT[int(eg_idx)]))

    # Build the mask-preview gallery: the masked image plus the raw mask.
    # NOTE(review): assumes the mask array broadcasts against the RGB image
    # (e.g. 3-channel mask or compatible shape) — verify against the stored
    # example assets.
    mask_target_binary = mask_target_state / 255
    masked_img = image_target_state * mask_target_binary
    masked_img_pil = Image.fromarray(masked_img.astype("uint8"))
    output_mask_pil = Image.fromarray(mask_target_state.astype("uint8"))
    if custmization_mode == "Position-aware":
        mask_gallery = [masked_img_pil, output_mask_pil]
    else:
        mask_gallery = gr.skip()

    result_gallery = [Image.open(IMG_GEN[int(eg_idx)]).convert("RGB")]

    if custmization_mode == "Position-free":
        # Both target widgets hidden in Position-free mode.
        return (image_reference_ori_state, image_reference_rmbg_state, image_target_state,
                mask_target_state, mask_gallery, result_gallery,
                gr.update(visible=False), gr.update(visible=False))

    if input_mask_mode == "Precise mask":
        # Show the precise-mask widget, hide the brush editor.
        return (image_reference_ori_state, image_reference_rmbg_state, image_target_state,
                mask_target_state, mask_gallery, result_gallery,
                gr.update(visible=True), gr.update(visible=False))

    # Brush mode: ensure the ImageEditor has a proper background so brush +
    # undo work; fall back to the raw value if it is not an editor dict.
    try:
        bg_img = image_target_2.get('background') or image_target_2.get('composite')
    except Exception:
        bg_img = image_target_2
    return (
        image_reference_ori_state, image_reference_rmbg_state, image_target_state,
        mask_target_state, mask_gallery, result_gallery,
        gr.update(visible=False),
        gr.update(visible=True, value={"background": bg_img, "layers": [], "composite": bg_img}),
    )
def create_application():
    """Create the main Gradio application.

    Builds the full Blocks layout (mode selector on top, inputs on the left,
    outputs on the right, examples and citation below), wires up all event
    handlers, and returns the demo ready for `.launch()`.
    """
    # Create theme and CSS
    theme = create_theme()
    css = create_css()

    with gr.Blocks(theme=theme, css=css) as demo:
        with gr.Column(elem_id="global_glass_container"):
            create_header_section()

            # Hidden component tracking which example (if any) is loaded.
            eg_idx = gr.Textbox(label="eg_idx", visible=False, value="-1")

            # Per-session state shared across event handlers.
            image_target_state = gr.State(value=None)
            mask_target_state = gr.State(value=None)
            image_reference_ori_state = gr.State(value=None)
            image_reference_rmbg_state = gr.State(value=None)
            selected_points = gr.State(value=[])

            # Main UI content with optimized left-right layout.
            with gr.Column(elem_id="glass_card"):
                # Top section - mode selection (full width).
                custmization_mode, md_custmization_mode = create_customization_section()

                # Main layout: left for inputs, right for outputs.
                with gr.Row(equal_height=False):
                    # LEFT COLUMN - ALL INPUTS
                    with gr.Column(scale=3, min_width=400):
                        # Image input section.
                        (image_reference, input_mask_mode, image_target_1, image_target_2,
                         undo_target_seg_button, md_image_reference, md_input_mask_mode,
                         md_target_image) = create_image_input_section()
                        # Text prompt section.
                        prompt, vlm_generate_btn, vlm_polish_btn, md_prompt = create_prompt_section()
                        # Advanced options (collapsible).
                        (aspect_ratio, seg_ref_mode, move_to_center, use_background_preservation,
                         background_blend_threshold, seed, num_images_per_prompt, guidance,
                         num_steps, true_gs) = create_advanced_options_section()

                    # RIGHT COLUMN - ALL OUTPUTS
                    with gr.Column(scale=2, min_width=350):
                        # Mask preview and operations.
                        (mask_gallery, dilate_button, erode_button, bounding_box_button,
                         md_mask_operation) = create_mask_operation_section()
                        # Generation controls and results.
                        result_gallery, submit_button, clear_btn, md_submit = create_output_section()

            with gr.Row(elem_id="glass_card"):
                # Examples section.
                examples = create_examples_section(
                    GRADIO_EXAMPLES,
                    inputs=[
                        image_reference,
                        image_target_1,
                        image_target_2,
                        custmization_mode,
                        input_mask_mode,
                        seg_ref_mode,
                        prompt,
                        seed,
                        true_gs,
                        eg_idx,
                        num_steps,
                        guidance
                    ],
                    outputs=[
                        image_reference_ori_state,
                        image_reference_rmbg_state,
                        image_target_state,
                        mask_target_state,
                        mask_gallery,
                        result_gallery,
                        image_target_1,
                        image_target_2,
                    ],
                    fn=example_pipeline,
                )

            with gr.Row(elem_id="glass_card"):
                # Citation section.
                create_citation_section()

            # Wire all component events to their handlers in one place.
            setup_event_handlers(
                ## UI components
                input_mask_mode, image_target_1, image_target_2, undo_target_seg_button,
                custmization_mode, dilate_button, erode_button, bounding_box_button,
                mask_gallery, md_input_mask_mode, md_target_image, md_mask_operation,
                md_prompt, md_submit, result_gallery, image_target_state, mask_target_state,
                seg_ref_mode, image_reference_ori_state, move_to_center,
                image_reference, image_reference_rmbg_state,
                ## Functions
                change_input_mask_mode, change_custmization_mode,
                change_seg_ref_mode,
                init_image_target_1, init_image_target_2, init_image_reference,
                get_point, undo_seg_points,
                get_brush,
                # VLM buttons
                vlm_generate_btn, vlm_polish_btn,
                # VLM functions
                vlm_auto_generate,
                vlm_auto_polish,
                dilate_mask, erode_mask, bounding_box,
                run_model,
                ## Other components
                selected_points, prompt,
                use_background_preservation, background_blend_threshold, seed,
                num_images_per_prompt, guidance, true_gs, num_steps, aspect_ratio,
                submit_button,
                eg_idx,
            )

            # Register components the clear button should reset.
            clear_btn.add(
                [image_reference, image_target_1, image_target_2, mask_gallery, result_gallery,
                 selected_points, image_target_state, mask_target_state, prompt,
                 image_reference_ori_state, image_reference_rmbg_state]
            )

    return demo
def main():
    """Main entry point: parse config, initialize models, launch the app."""
    # Parse arguments and load config.
    args = parse_args()
    cfg = load_config(args.config)
    setup_environment(args)

    # Device and precision: bf16 weights, GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.bfloat16

    pipeline, mobile_predictor, vlm_processor, vlm_model, ben2_model = initialize_models(
        args, cfg, device, weight_dtype
    )

    # Inject shared handles into this module and the business-logic module so
    # event handlers can access them without lambda plumbing.
    set_pipeline(pipeline)
    set_assets_cache_dir(args.assets_cache_dir)
    set_mobile_predictor(mobile_predictor)
    set_ben2_model(ben2_model)
    set_vlm_processor(vlm_processor)
    set_vlm_model(vlm_model)

    # Create and launch the application; allowed_paths lets Gradio serve
    # files from the cache and results folders next to this module.
    demo = create_application()
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
        allowed_paths=[
            os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache")),
            os.path.abspath(os.path.join(os.path.dirname(__file__), "results")),
        ],
    )


if __name__ == "__main__":
    main()