File size: 6,523 Bytes
cc0fe43
19be4eb
91d99a8
 
2685d15
 
 
91d99a8
2685d15
 
51468b8
2685d15
 
 
91d99a8
 
 
 
 
 
 
 
 
 
 
 
cc0fe43
2685d15
91d99a8
2685d15
91d99a8
 
 
 
 
51468b8
2685d15
91d99a8
2685d15
91d99a8
 
 
 
92a7021
2685d15
91d99a8
2685d15
91d99a8
 
 
 
 
2685d15
91d99a8
 
2685d15
91d99a8
fd4cd12
91d99a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import random
import re
import requests
import torch
import numpy as np
import gradio as gr
import spaces
from diffusers import FluxPipeline
from translatepy import Translator

# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
# Central application settings: model / LoRA identifiers, seed range, page CSS
# and the default values used by the generation function and the UI sliders.
# NOTE: the key "default_loRa_scale" (mixed case) is looked up by that exact
# spelling elsewhere in this file, so it must not be renamed here alone.
config = dict(
    model_id="black-forest-labs/FLUX.1-dev",
    default_lora="nftnik/BR_ohwx_V1",
    default_weight_name="BR_ohwx.safetensors",
    # Largest signed 32-bit integer; upper bound for randomly drawn seeds.
    max_seed=int(np.iinfo(np.int32).max),
    # Custom CSS passed to gr.Blocks; hides <footer> elements.
    css="footer { visibility: hidden; }",
    default_width=896,
    default_height=1152,
    default_guidance_scale=3.5,
    default_steps=35,
    default_loRa_scale=1.0,
)

# -----------------------------------------------------------------------------
# Environment and device setup
# -----------------------------------------------------------------------------
# NOTE(review): huggingface_hub appears to read HF_HUB_ENABLE_HF_TRANSFER once,
# when it is first imported — and diffusers / gradio / spaces are imported above
# this line — so this assignment may have no effect. Confirm, and if fast
# transfers are wanted, set it before those imports (requires `hf_transfer`).
os.environ.update({"HF_HUB_ENABLE_HF_TRANSFER": "1"})

# Shared translator used to turn user prompts into English.
translator = Translator()

# Hugging Face token from the environment (None when the variable is unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Prefer the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using " + device.upper())

# -----------------------------------------------------------------------------
# Initialize the Flux pipeline and load default LoRA
# -----------------------------------------------------------------------------
# Build the base FLUX pipeline in bfloat16 and place it on the selected device.
pipe = FluxPipeline.from_pretrained(config["model_id"], torch_dtype=torch.bfloat16)
pipe = pipe.to(device)

# Attach the default LoRA at startup, naming its weight file explicitly.
pipe.load_lora_weights(
    config["default_lora"],
    weight_name=config["default_weight_name"],
)

# -----------------------------------------------------------------------------
# Function to load a new LoRA model
# -----------------------------------------------------------------------------
def enable_lora(lora_add: str):
    """Replace the pipeline's LoRA with the one named in ``lora_add``.

    Any currently loaded LoRA weights are unloaded first, so an empty (or
    whitespace-only) value leaves the base model running without a LoRA.

    Args:
        lora_add: Hugging Face repo id of the LoRA to load, e.g. "user/repo".

    Returns:
        A ``gr.update`` for the LoRA textbox: its value is cleared when no repo
        was given, otherwise its label is changed to confirm the load.

    Raises:
        gr.Error: If ``load_lora_weights`` fails. The previous LoRA has already
            been unloaded at that point, so the pipeline is left without one.
    """
    repo_id = (lora_add or "").strip()
    pipe.unload_lora_weights()
    if not repo_id:
        return gr.update(value="")

    # The default repo is loaded at startup with an explicit weight file; pass
    # the same file here so reloading the default behaves like the first load.
    load_kwargs = {}
    if repo_id == config["default_lora"]:
        load_kwargs["weight_name"] = config["default_weight_name"]

    try:
        pipe.load_lora_weights(repo_id, **load_kwargs)
    except Exception as e:
        raise gr.Error(f"Failed to load {repo_id}: {e}") from e
    return gr.update(label="LoRA Loaded Now")

# -----------------------------------------------------------------------------
# Function to generate an image from a prompt
# -----------------------------------------------------------------------------
@spaces.GPU()
def generate_image(
    prompt: str, lora_word: str, lora_scale: float = config["default_loRa_scale"],
    width: int = config["default_width"], height: int = config["default_height"],
    guidance_scale: float = config["default_guidance_scale"], steps: int = config["default_steps"],
    seed: int = -1, nums: int = 1
):
    """Generate image(s) from ``prompt`` with the FLUX pipeline.

    The prompt is translated to English with translatepy and the LoRA trigger
    word is appended before the pipeline is called.

    Args:
        prompt: User prompt, in any language translatepy can handle.
        lora_word: LoRA trigger word appended to the prompt (may be empty).
        lora_scale: LoRA strength, forwarded as ``joint_attention_kwargs["scale"]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps.
        seed: RNG seed; ``-1`` draws a random seed in ``[0, config["max_seed"]]``.
        nums: Number of images to generate for the prompt.

    Returns:
        ``(images, seed)``: the list of PIL images and the seed actually used,
        which the UI writes back into the seed slider.
    """
    pipe.to(device)
    seed = random.randint(0, config["max_seed"]) if seed == -1 else int(seed)

    # translatepy may rely on external services; a translation failure should
    # not abort the whole generation, so fall back to the original prompt.
    try:
        prompt_english = str(translator.translate(prompt, "English"))
    except Exception as e:
        print(f"Prompt translation failed, using original prompt: {e}")
        prompt_english = prompt or ""

    # Join only the non-empty pieces so an empty trigger word adds no stray space.
    full_prompt = " ".join(
        part for part in (prompt_english.strip(), (lora_word or "").strip()) if part
    )
    # Seeded CPU generator so a given seed reproduces the same result.
    generator = torch.Generator().manual_seed(seed)

    # Gradio sliders may deliver floats; the step and image counts must be ints
    # for the pipeline (they feed array sizes / tensor repeats), so cast them.
    result = pipe(
        prompt=full_prompt, height=int(height), width=int(width),
        guidance_scale=float(guidance_scale), output_type="pil",
        num_inference_steps=int(steps), num_images_per_prompt=int(nums),
        generator=generator, joint_attention_kwargs={"scale": float(lora_scale)},
    )
    return result.images, seed

# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
# Example rows for gr.Examples: (prompt, LoRA trigger word, LoRA scale).
example_prompts = [
    ["Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.", "ohwx", 0.9],
    ["ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.", "ohwx", 0.9],
    ["full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.", "ohwx", 0.9],
    ["ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.", "ohwx", 0.9],
]

with gr.Blocks(css=config["css"]) as demo:
    gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")

    # Status line, switched to "Processing" before and "Done" after each run.
    processing_status = gr.Markdown("**🟢 Ready**", visible=True)

    with gr.Row():
        with gr.Column(scale=4):
            gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
            prompt_input = gr.Textbox(label="Enter Your Prompt", lines=2, placeholder="Enter prompt...")
            generate_btn = gr.Button(value="Generate", variant="primary")
        with gr.Accordion("Advanced Options", open=True):
            width_slider = gr.Slider(label="Width", minimum=512, maximum=1920, step=8, value=config["default_width"])
            height_slider = gr.Slider(label="Height", minimum=512, maximum=1920, step=8, value=config["default_height"])
            guidance_slider = gr.Slider(label="Guidance Scale", minimum=3.5, maximum=7, step=0.1, value=config["default_guidance_scale"])
            steps_slider = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=config["default_steps"])
            # -1 means "pick a random seed"; the seed actually used is written back here.
            seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=config["max_seed"], step=1, value=-1)
            nums_slider = gr.Slider(label="Image Count", minimum=1, maximum=2, step=1, value=1)
            lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.1, maximum=2.0, step=0.1, value=config["default_loRa_scale"])
            lora_add_text = gr.Textbox(label="Flux LoRA", lines=1, value=config["default_lora"])
            lora_word_text = gr.Textbox(label="Flux LoRA Trigger Word", lines=1, value="ohwx")
            load_lora_btn = gr.Button(value="Load LoRA", variant="secondary")

    gr.Examples(examples=example_prompts, inputs=[prompt_input, lora_word_text, lora_scale_slider], cache_examples=False, examples_per_page=4)

    def update_status():
        """Return the status text shown while a generation is running."""
        return "**⏳ Processing...**"

    # Chain: show "Processing", run the generation, then show "Done".
    # NOTE(review): `.then` runs even when generate_image raised, so "Done!" is
    # also shown after a failed run; consider `.success` if that is unwanted.
    generate_btn.click(fn=update_status, inputs=[], outputs=[processing_status]).then(
        fn=generate_image,
        inputs=[prompt_input, lora_word_text, lora_scale_slider, width_slider, height_slider, guidance_slider, steps_slider, seed_slider, nums_slider],
        outputs=[gallery, seed_slider]
    ).then(
        fn=lambda: "**✅ Done!**",
        inputs=[],
        outputs=[processing_status]
    )

    load_lora_btn.click(fn=enable_lora, inputs=[lora_add_text], outputs=lora_add_text)

if __name__ == "__main__":
    demo.queue().launch()