# pixe-3.5 / main_gr.py
# phxdev, commit 11eeb6f:
#   Fix TypeError: Replace gr.Examples with Markdown to avoid schema issues
import gradio as gr
import torch
import numpy as np
import random
import spaces
from diffusers import FluxPipeline
# Largest seed offered in the UI and drawn when randomizing: the int32 maximum.
MAX_SEED = np.iinfo(np.int32).max

# Available LoRAs: dropdown label -> identifier handed to
# pipe.load_lora_weights(). "None" maps to None, meaning "run without a LoRA".
LORA_OPTIONS = {
    "None": None,
    "Add Details": "Shakker-Labs/FLUX.1-dev-LoRA-add-details",
    "Merlin Turbo Alpha": "its-magick/merlin-turbo-alpha",
    "Flux Realism": "its-magick/flux-realism",
    "Perfection Style v1": (
        "https://huggingface.co/its-magick/merlin-test-loras/resolve/main/"
        "perfection%20style%20v1.safetensors"
    ),
    "Canopus Face Realism": (
        "https://huggingface.co/its-magick/merlin-test-loras/resolve/main/"
        "Canopus-LoRA-Flux-FaceRealism.safetensors"
    ),
}

# Global variables to track current LoRA.
# generate_image() updates both so it can tell when the adapter must change.
current_lora = None
current_lora_strength = 0.8
@spaces.GPU(duration=60)
def generate_image(prompt, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, lora_choice, lora_strength):
    """Generate one image with the module-level FLUX pipeline.

    Args:
        prompt: Text description of the image.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, ignore ``seed`` and draw a fresh one from
            ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.
        lora_choice: Key into ``LORA_OPTIONS``; unknown keys and ``"None"``
            both mean "no LoRA".
        lora_strength: LoRA scale applied for this call when a LoRA is active.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed
        that was actually used.
    """
    global current_lora, current_lora_strength

    # Slider values can arrive as floats, but manual_seed() needs an int.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    generator = torch.Generator().manual_seed(seed)

    # Swap adapters only when the selected LoRA itself changes. The strength
    # is applied per call below, so a strength-only change needs no reload.
    selected_lora = LORA_OPTIONS.get(lora_choice)
    if selected_lora != current_lora:
        if current_lora is not None:
            pipe.unload_lora_weights()
            # Clear the tracker first so a failed load below cannot leave it
            # pointing at weights that are no longer attached.
            current_lora = None
        if selected_lora is not None:
            pipe.load_lora_weights(selected_lora)
            current_lora = selected_lora

    # Record the strength in effect; fall back to the default without a LoRA.
    current_lora_strength = lora_strength if current_lora is not None else 0.8

    pipe_kwargs = {
        "prompt": prompt,
        "width": int(width),
        "height": int(height),
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": guidance_scale,
        "generator": generator,
    }
    if current_lora is not None:
        # FluxPipeline takes the LoRA scale via joint_attention_kwargs; it has
        # no cross_attention_kwargs parameter.
        pipe_kwargs["joint_attention_kwargs"] = {"scale": lora_strength}

    # .images is a list; return the single image so gr.Image can render it.
    image = pipe(**pipe_kwargs).images[0]
    return image, seed
# Load model: fetch FLUX.1-schnell in bfloat16 and move it onto the GPU.
# This runs once at import time; generate_image() uses the resulting `pipe`.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")
# Gradio interface
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been lost. Confirm it against the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# FLUX.1 Schnell Image Generator")

    with gr.Row():
        # Left column: prompt, generate button and the tunable settings.
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your image description...", lines=3)

            with gr.Row():
                generate_btn = gr.Button("Generate Image", variant="primary")

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                width = gr.Slider(label="Width", minimum=256, maximum=1024, step=8, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=1024, step=8, value=1024)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=4, step=1, value=4)
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=3.5, step=0.1, value=0.0)
                lora_choice = gr.Dropdown(label="LoRA Model", choices=list(LORA_OPTIONS.keys()), value="None")
                lora_strength = gr.Slider(label="LoRA Strength", minimum=0.0, maximum=2.0, step=0.1, value=0.8)

        # Right column: the generated image and the seed that produced it.
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image")
            output_seed = gr.Number(label="Used Seed")

    # Examples
    with gr.Row():
        gr.Markdown("**Example prompts:** a tiny astronaut hatching from an egg on the moon • a cat holding a sign that says hello world • an anime illustration of a wiener schnitzel")

    # Connect the generate button.
    # Inputs are passed positionally, so their order must match the
    # parameters of generate_image.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            num_inference_steps,
            guidance_scale,
            lora_choice,
            lora_strength,
        ],
        outputs=[output_image, output_seed],
    )
# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch(share=True)