# Kontextremix / app.py
# Author: wanghaofan
# Last commit: d57e3a0 "Update app.py" (verified)
import os
import gradio as gr
import json
import logging
import torch
from PIL import Image
import spaces
from diffusers import FluxKontextPipeline
from huggingface_hub import HfFileSystem, ModelCard
import copy
import random
import time
import subprocess
def _clear_directory_contents(path):
    """Best-effort removal of every non-hidden entry directly inside *path*.

    In-process equivalent of ``rm -rf path/*``. Like the shell glob ``*``,
    ``glob.glob`` skips names starting with ".". A missing directory or an
    entry that cannot be removed is skipped silently. The shell command's
    exit status was never checked, so this is deliberately best-effort.
    """
    import contextlib
    import glob
    import shutil

    for entry in glob.glob(os.path.join(path, "*")):
        # A symlink to a directory is removed as a link, not followed (as rm -rf does).
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry, ignore_errors=True)
        else:
            with contextlib.suppress(OSError):
                os.remove(entry)


# Clear the ZeroGPU offload scratch directory without spawning a shell.
_clear_directory_contents("/data-nvme/zerogpu-offload")

from huggingface_hub import login

# Token used to access the gated base model.
hf_token = os.environ.get("HF_TOKEN_GATED")
# login(token=None) falls back to an interactive prompt, which cannot be
# answered in a headless Space, so only log in when a token is configured.
if hf_token:
    login(token=hf_token)
# Load the LoRA catalogue shown in the gallery. Entries are read elsewhere in
# this file for keys such as "title", "repo", "image", "prompt",
# "trigger_word", "lora_scale" and, optionally, "weights" / "trigger_position".
# Decode explicitly as UTF-8 so loading does not depend on the host locale.
with open("loras.json", "r", encoding="utf-8") as f:
    loras = json.load(f)

# Initialize the base model once at import time; LoRA weights are swapped per request.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-Kontext-dev"
pipe = FluxKontextPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)

# Upper bound for random seeds and for the seed slider (2**32 - 1).
MAX_SEED = 2**32 - 1
class calculateDuration:
    """Context manager that prints how long its body took to run.

    On exit it sets ``start_time``, ``end_time`` and ``elapsed_time``
    (seconds, from ``time.perf_counter``) on the instance and prints a
    one-line timing message. Exceptions raised in the body are not suppressed.
    """

    def __init__(self, activity_name=""):
        # Label included in the printed message; empty string means "no label".
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed (or go negative) by system clock changes.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        # Implicitly returns None, so any exception from the body propagates.
def update_selection(evt: gr.SelectData, default_scale, lora_scale):
    """Handle a click on the LoRA gallery.

    Args:
        evt: Gradio selection event; ``evt.index`` is the position of the
            clicked item in ``loras``. The ``gr.SelectData`` annotation is what
            tells Gradio to inject the event, so it must stay.
        default_scale: When truthy, replace ``lora_scale`` with the entry's own
            ``"lora_scale"`` value, if the entry defines one.
        lora_scale: Current value of the LoRA-scale slider.

    Returns:
        Tuple ``(prompt, selected_info_markdown, selected_index, lora_scale)``.
    """
    selected_lora = loras[evt.index]
    lora_repo = selected_lora["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
    if default_scale:
        # Keep the slider's current value for entries without a default scale.
        lora_scale = selected_lora.get("lora_scale", lora_scale)
    return (
        selected_lora["prompt"],
        updated_text,
        evt.index,
        lora_scale,
    )
@spaces.GPU
def generate_image(input_image, prompt_mash, steps, seed, cfg_scale, lora_scale, progress):
    """Run the Kontext pipeline once and yield each resulting PIL image.

    NOTE: not currently called; ``run_lora`` invokes ``pipe`` directly.

    Args:
        input_image: Image to edit, passed as ``image=`` to the pipeline.
        prompt_mash: Final prompt with the LoRA trigger word already merged in.
        steps: Number of inference steps.
        seed: Seed for the CUDA ``torch.Generator``.
        cfg_scale: Guidance scale.
        lora_scale: LoRA strength, forwarded via ``joint_attention_kwargs``.
        progress: Unused; kept so the signature stays compatible.

    Yields:
        The PIL images in the pipeline output's ``images`` list.
    """
    pipe.to("cuda")
    # manual_seed requires an int; UI sliders may hand back floats.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    with calculateDuration("Generating image"):
        output = pipe(
            image=input_image,
            prompt=prompt_mash,
            num_inference_steps=int(steps),
            guidance_scale=cfg_scale,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
            output_type="pil",
        )
    # pipe(...) returns a single dict-like output object holding an ``images``
    # list, not a per-step iterator. Iterating the object itself yields its
    # keys, so yield the decoded images explicitly.
    yield from output.images
def _compose_prompt(prompt, trigger_word, trigger_position="prepend"):
    """Merge a LoRA's trigger word into the user's prompt.

    The trigger word is prepended when ``trigger_position == "prepend"`` and
    appended for any other position value. It is not added when it is empty
    or when the prompt already equals it (avoids "word word").
    """
    if not trigger_word or trigger_word == prompt:
        return prompt
    if trigger_position == "prepend":
        return f"{trigger_word} {prompt}"
    return f"{prompt} {trigger_word}"


def _swap_lora(selected_lora):
    """Unload whatever LoRA is on ``pipe`` and load the weights of *selected_lora*."""
    with calculateDuration("Unloading LoRA"):
        pipe.unload_lora_weights()
    with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
        # "weights" names a specific file inside the repo when the repo has several.
        extra = {"weight_name": selected_lora["weights"]} if "weights" in selected_lora else {}
        pipe.load_lora_weights(selected_lora["repo"], **extra)


@spaces.GPU
def run_lora(input_image, prompt, cfg_scale, steps, selected_index, randomize_seed, seed, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Apply the selected LoRA and edit ``input_image`` according to ``prompt``.

    Gradio event handler written as a generator. It yields once with
    ``(image, seed_used, progress_bar_update)``. ``progress`` is injected by
    Gradio and is not referenced directly.

    Raises:
        gr.Error: if no LoRA has been selected in the gallery.
    """
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")
    selected_lora = loras[selected_index]

    # Entries without "trigger_position" get their trigger word prepended.
    prompt_mash = _compose_prompt(
        prompt,
        selected_lora.get("trigger_word"),
        selected_lora.get("trigger_position", "prepend"),
    )
    _swap_lora(selected_lora)

    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
    # torch.Generator.manual_seed requires an int; slider values may be floats.
    seed = int(seed)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    with calculateDuration("Generating image"):
        final_image = pipe(
            image=input_image,
            prompt=prompt_mash,
            num_inference_steps=int(steps),
            guidance_scale=cfg_scale,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
    yield final_image, seed, gr.update(visible=True)
# Stylesheet passed to gr.Blocks(css=...). One CSS rule per entry; the empty
# first and last entries reproduce the leading and trailing newline.
css = "\n".join((
    "",
    "#gen_btn{height: 100%}",
    "#title{text-align: center}",
    "#title h1{font-size: 3em; display:inline-flex; align-items:center}",
    "#title img{width: 100px; margin-right: 0.5em}",
    "#gallery .grid-wrap{height: 10vh}",
    "#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}",
    ".card_internal{display: flex;height: 100px;margin-top: .5em}",
    ".card_internal img{margin-right: 1em}",
    ".styler{--form-gap-width: 0px !important}",
    "#progress{height:30px}",
    "#progress .generating{display:none}",
    ".progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}",
    ".progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}",
    "",
))
# Build the Gradio UI (prompt row, LoRA gallery, input/output images, advanced
# settings) and wire its events to update_selection and run_lora.
# NOTE(review): the original indentation was lost; the Row/Column nesting below
# is a reconstruction. Confirm it against the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-collections/resolve/main/logo.png" alt="LoRA"> FLUX Kontext LoRA Gallery from Shakker AI</h1>""",
        elem_id="title",
    )
    # Index into `loras` of the last clicked gallery item; None until a LoRA is
    # selected (run_lora rejects None).
    selected_index = gr.State(None)
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Please select a LoRA by clicking")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
    with gr.Row():
        with gr.Column():
            # Shows a link to the selected LoRA repo; filled in by update_selection.
            selected_info = gr.Markdown("")
            # One (image, title) thumbnail per loras.json entry.
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery"
            )
    with gr.Row():
        with gr.Column():
            image_in = gr.Image(label="Upload the image for editing", type="pil")
        with gr.Column():
            # Starts hidden; run_lora's third output makes it visible.
            progress_bar = gr.Markdown(elem_id="progress",visible=False)
            result = gr.Image(label="Generated Image",show_label=False,interactive=False)
    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=2.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    # When checked, selecting a LoRA overwrites the scale slider
                    # with that entry's "lora_scale" value (see update_selection).
                    default_scale = gr.Checkbox(True, label="Use default LoRA scale")
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=1.0)
    # Gallery click: update the prompt, info text, selected index and LoRA scale.
    gallery.select(
        update_selection,
        inputs=[default_scale, lora_scale],
        outputs=[prompt, selected_info, selected_index, lora_scale]
    )
    # Generate on button click or when Enter is pressed in the prompt box; the
    # seed actually used is written back to the seed slider.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[image_in, prompt, cfg_scale, steps, selected_index, randomize_seed, seed, lora_scale],
        outputs=[result, seed, progress_bar]
    )
# Enable request queuing and start the server only when run as a script
# (e.g. `python app.py`), so importing this module does not launch the UI.
if __name__ == "__main__":
    app.queue()
    app.launch()