import os
import random
import numpy as np
import spaces
import torch
from diffusers import FluxPipeline
import gradio as gr

# Access the Hugging Face token from environment variables
hf_token = os.getenv("HF_TOKEN")

if hf_token is None:
    raise ValueError("Hugging Face token is not set. Please set the HF_TOKEN environment variable.")
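
# The token is typically supplied as a Spaces secret, or locally via, e.g.:
#   export HF_TOKEN=<your-access-token>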

# Check if GPU is available
if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU")

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"

# Initialize the pipeline and download the model
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    token=hf_token  # "token" replaces the deprecated "use_auth_token" argument
)
pipe.to(device)

# Enable memory optimizations
pipe.enable_attention_slicing()
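
# Optional alternative (a sketch, not enabled here): on low-VRAM GPUs, model
# CPU offload trades speed for memory. It requires the `accelerate` package
# and would replace the pipe.to(device) call above:
# pipe.enable_model_cpu_offload()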

# Define the image generation function
@spaces.GPU(duration=180)
def generate_image(prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt, progress=gr.Progress(track_tqdm=True)):
    if seed == 0:
        seed = random.randint(1, MAX_SEED)

    # A CPU generator keeps seeded results reproducible across devices
    generator = torch.Generator().manual_seed(seed)
    
    with torch.inference_mode():
        output = pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            generator=generator,
            num_images_per_prompt=num_images_per_prompt
        ).images
    
    return output
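
# Hypothetical local smoke test (commented out; assumes the @spaces.GPU
# decorator is a no-op outside Hugging Face Spaces):
# if __name__ == "__main__":
#     preview = generate_image("A cat", 25, 512, 512, 3.5, 42, 1)
#     preview[0].save("preview.png")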

# Create the Gradio interface

examples = [
    # Each example row supplies one value per input component of gr.Examples below
    ["Black forest cake spelling out the words 'I love you', tasty, food photography, dynamic shot", 25, 1024, 1024, 3.5, 42, 2],
]

css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
'''
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(
                """
                <h1 style='text-align: center'>
                FLUX.1-dev Image Generator
                </h1>
                <p style='text-align: center; font-size: 18px; color: #333;'>
                Welcome to the FLUX.1-dev Image Generator! This tool transforms your creative ideas into stunning visual artwork using state-of-the-art AI technology. Simply describe the image you imagine, adjust the settings to your preference, and let our model bring your vision to life. Explore endless possibilities and let your creativity soar!
                </p>
                """
            )
    with gr.Group():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", info="Describe the image you want", placeholder="A cat...")
            run_button = gr.Button("Run")
        result = gr.Gallery(label="Generated AI Images", elem_id="gallery")
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            num_inference_steps = gr.Slider(label="Number of Inference Steps", info="The number of denoising steps. More steps usually yield a higher-quality image at the cost of slower inference.", minimum=1, maximum=50, value=25, step=1)
            guidance_scale = gr.Slider(label="Guidance Scale", info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.", minimum=0.0, maximum=7.0, value=3.5, step=0.1)
        with gr.Row():
            width = gr.Slider(label="Width", info="Width of the Image", minimum=256, maximum=1024, step=32, value=1024)
            height = gr.Slider(label="Height", info="Height of the Image", minimum=256, maximum=1024, step=32, value=1024)
        with gr.Row():
            seed = gr.Slider(value=42, minimum=0, maximum=MAX_SEED, step=1, label="Seed", info="Starting point for the generation process; set to 0 for a random seed.")
            num_images_per_prompt = gr.Slider(label="Images Per Prompt", info="Number of images to generate for each prompt.", minimum=1, maximum=4, step=1, value=2)

    gr.Examples(
        examples=examples,
        fn=generate_image,
        inputs=[prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt],
        outputs=[result],
        cache_examples=CACHE_EXAMPLES
    )

    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate_image,
        inputs=[prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt],
        outputs=[result],
    )

demo.queue().launch(share=False)