File size: 5,471 Bytes
72ae6c2
0ca60ea
72ae6c2
 
998b0a3
0ca60ea
72ae6c2
 
 
68d7874
 
 
f85ee30
d4fee03
e30d1ca
5485ad4
 
68d7874
 
 
 
 
 
0ca60ea
68d7874
 
 
72ae6c2
 
0ca60ea
72ae6c2
0ca60ea
68d7874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ca60ea
72ae6c2
 
0ca60ea
 
 
72ae6c2
0ca60ea
 
 
72ae6c2
0ca60ea
 
 
72ae6c2
0ca60ea
 
72ae6c2
 
 
0ca60ea
 
 
 
 
 
 
 
 
 
 
 
72ae6c2
 
 
 
0ca60ea
 
 
998b0a3
72ae6c2
 
0ca60ea
72ae6c2
0ca60ea
 
 
998b0a3
72ae6c2
 
0ca60ea
72ae6c2
0ca60ea
72ae6c2
0ca60ea
 
 
 
 
 
 
 
 
998b0a3
68d7874
 
 
 
 
 
 
 
 
 
 
 
 
 
72ae6c2
0ca60ea
 
 
 
 
11eeb6f
 
0ca60ea
 
 
 
68d7874
0ca60ea
998b0a3
 
0ca60ea
1c817ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import gradio as gr
import torch
import numpy as np
import random
import spaces
from diffusers import FluxPipeline

# Upper bound for seeds, both for the Seed slider and for random draws:
# the largest signed 32-bit integer.
MAX_SEED: int = np.iinfo(np.int32).max

# Dropdown label -> LoRA source for the pipeline. The "None" entry maps to
# None, meaning no LoRA is applied. The remaining values are either Hub repo
# ids or direct huggingface.co file URLs. Insertion order is the order shown
# in the dropdown.
LORA_OPTIONS = {
    "None": None,
    # Hub repository ids
    "Add Details": "Shakker-Labs/FLUX.1-dev-LoRA-add-details",
    "Merlin Turbo Alpha": "its-magick/merlin-turbo-alpha",
    "Flux Realism": "its-magick/flux-realism",
    # Direct .safetensors file URLs
    "Perfection Style v1": "https://huggingface.co/its-magick/merlin-test-loras/resolve/main/perfection%20style%20v1.safetensors",
    "Canopus Face Realism": "https://huggingface.co/its-magick/merlin-test-loras/resolve/main/Canopus-LoRA-Flux-FaceRealism.safetensors",
}

# Mutable module-level state that generate_image updates: which LoRA source
# is currently attached to the pipeline (None = none) and its strength.
current_lora, current_lora_strength = None, 0.8

def _resolve_lora_source(source):
    """Split a LoRA source into the arguments ``pipe.load_lora_weights`` expects.

    ``load_lora_weights`` accepts a Hub repo id, a local path or a state dict;
    it does not fetch an arbitrary file URL. ``LORA_OPTIONS`` contains
    ``https://huggingface.co/<owner>/<repo>/resolve/<revision>/<file>`` links,
    so those are translated into the repo id plus ``weight_name`` (and
    ``revision`` / ``subfolder``) keyword arguments. Any other value is
    returned unchanged with no extra keyword arguments.

    Args:
        source: A non-None value from ``LORA_OPTIONS``.

    Returns:
        ``(repo_id_or_path, load_kwargs)`` for
        ``pipe.load_lora_weights(repo_id_or_path, **load_kwargs)``.
    """
    from urllib.parse import unquote, urlparse

    parsed = urlparse(source)
    if parsed.scheme in ("http", "https") and parsed.netloc == "huggingface.co":
        # Expected path layout: /<owner>/<repo>/resolve/<revision>/<dirs...>/<file>
        parts = [unquote(part) for part in parsed.path.split("/") if part]
        if len(parts) >= 5 and parts[2] == "resolve":
            owner, repo, _, revision, *file_parts = parts
            *subdirs, filename = file_parts
            load_kwargs = {"weight_name": filename, "revision": revision}
            if subdirs:
                load_kwargs["subfolder"] = "/".join(subdirs)
            return f"{owner}/{repo}", load_kwargs
    return source, {}


def _sync_lora(selected_lora):
    """Attach ``selected_lora`` to ``pipe``, or detach any LoRA when it is ``None``.

    Does nothing when ``selected_lora`` is already the one recorded in
    ``current_lora``. The LoRA strength is deliberately not part of this
    state: it is applied per call at inference time, so changing it needs no
    reload.
    """
    global current_lora

    if selected_lora == current_lora:
        return

    if current_lora is not None:
        pipe.unload_lora_weights()
        # Record the unload immediately: if the load below raises, the tracker
        # must not keep claiming the old adapter is still attached.
        current_lora = None

    if selected_lora is not None:
        repo_or_path, load_kwargs = _resolve_lora_source(selected_lora)
        pipe.load_lora_weights(repo_or_path, **load_kwargs)
        current_lora = selected_lora


@spaces.GPU(duration=60)
def generate_image(prompt, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, lora_choice, lora_strength):
    """Generate one image with the FLUX pipeline, optionally through a LoRA.

    Args:
        prompt: Text prompt for the pipeline.
        seed: Seed for the torch generator. Ignored when ``randomize_seed`` is set.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Passed straight through to the pipeline.
        lora_choice: Key into ``LORA_OPTIONS``. Unknown keys behave like "None".
        lora_strength: LoRA scale applied at inference time.

    Returns:
        ``(image, seed)``: the generated image and the seed actually used.
    """
    global current_lora_strength

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver floats; manual_seed and the pipeline's size/step
    # arguments need ints.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    _sync_lora(LORA_OPTIONS.get(lora_choice))

    pipe_kwargs = {
        "prompt": prompt,
        "width": int(width),
        "height": int(height),
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": guidance_scale,
        "generator": generator,
    }
    if current_lora is not None:
        # FluxPipeline takes the LoRA scale via joint_attention_kwargs; it has
        # no cross_attention_kwargs parameter (passing that raised TypeError).
        pipe_kwargs["joint_attention_kwargs"] = {"scale": lora_strength}
        current_lora_strength = lora_strength
    else:
        current_lora_strength = 0.8

    # `.images` holds one entry per generated image. Take the single image so
    # gr.Image receives an image rather than a list.
    image = pipe(**pipe_kwargs).images[0]
    return image, seed

# Load the FLUX.1-schnell pipeline with bfloat16 weights and move it to the
# GPU. The pipeline's .to() hands back the pipeline itself, so it can be
# chained onto the constructor.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
).to("cuda")

# --- Gradio interface -------------------------------------------------------
# Two equal-width columns: prompt and settings on the left, results on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# FLUX.1 Schnell Image Generator")

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                lines=3,
                label="Prompt",
                placeholder="Enter your image description...",
            )

            with gr.Row():
                generate_btn = gr.Button("Generate Image", variant="primary")

            # Collapsed by default; holds every tunable generation setting.
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="Seed")
                randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
                width = gr.Slider(minimum=256, maximum=1024, step=8, value=1024, label="Width")
                height = gr.Slider(minimum=256, maximum=1024, step=8, value=1024, label="Height")
                num_inference_steps = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Inference Steps")
                guidance_scale = gr.Slider(minimum=0.0, maximum=3.5, step=0.1, value=0.0, label="Guidance Scale")
                lora_choice = gr.Dropdown(choices=list(LORA_OPTIONS), value="None", label="LoRA Model")
                lora_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=0.8, label="LoRA Strength")

        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image")
            output_seed = gr.Number(label="Used Seed")

    # Sample prompts shown as plain text (not clickable examples).
    with gr.Row():
        gr.Markdown("**Example prompts:** a tiny astronaut hatching from an egg on the moon • a cat holding a sign that says hello world • an anime illustration of a wiener schnitzel")

    # Component order must match generate_image's positional parameters.
    generation_inputs = [
        prompt,
        seed,
        randomize_seed,
        width,
        height,
        num_inference_steps,
        guidance_scale,
        lora_choice,
        lora_strength,
    ]
    generate_btn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=[output_image, output_seed],
    )


if __name__ == "__main__":
    # Running the file directly starts the server; share=True also requests a
    # public Gradio share link.
    demo.launch(share=True)