Spaces:
Runtime error
Runtime error
File size: 6,523 Bytes
cc0fe43 19be4eb 91d99a8 2685d15 91d99a8 2685d15 51468b8 2685d15 91d99a8 cc0fe43 2685d15 91d99a8 2685d15 91d99a8 51468b8 2685d15 91d99a8 2685d15 91d99a8 92a7021 2685d15 91d99a8 2685d15 91d99a8 2685d15 91d99a8 2685d15 91d99a8 fd4cd12 91d99a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import os
import random
import re
import requests
import torch
import numpy as np
import gradio as gr
import spaces
from diffusers import FluxPipeline
from translatepy import Translator
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
# Central settings for the Space: model / LoRA identifiers, the CSS passed to
# gr.Blocks, and the default values used by the UI sliders and by
# generate_image(). Other code reads these as config["<key>"].
config = dict(
    model_id="black-forest-labs/FLUX.1-dev",
    # LoRA loaded at startup, and the weight file to pick inside that repo.
    default_lora="nftnik/BR_ohwx_V1",
    default_weight_name="BR_ohwx.safetensors",
    # Upper bound for random seeds: the largest signed 32-bit integer.
    max_seed=int(np.iinfo(np.int32).max),
    # Injected into gr.Blocks; hides the page footer.
    css="footer { visibility: hidden; }",
    default_width=896,
    default_height=1152,
    default_guidance_scale=3.5,
    default_steps=35,
    # NOTE: the mixed-case key is referenced by name elsewhere; keep spelling.
    default_loRa_scale=1.0,
)
# -----------------------------------------------------------------------------
# Environment and device setup
# -----------------------------------------------------------------------------
# NOTE(review): huggingface_hub is already imported at this point (via
# diffusers above) and appears to read HF_HUB_ENABLE_HF_TRANSFER at import
# time, so setting it here may have no effect. If fast downloads are wanted
# (and `hf_transfer` is installed), move this above the imports — confirm.
os.environ.update(HF_HUB_ENABLE_HF_TRANSFER="1")
# Shared translator; generate_image() uses it to turn prompts into English.
translator = Translator()
# None when the secret is not configured; not referenced elsewhere in the
# visible code.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Pick the GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using " + device.upper())
# -----------------------------------------------------------------------------
# Initialize the Flux pipeline and load default LoRA
# -----------------------------------------------------------------------------
# Load the FLUX base model in bfloat16 and move it to the selected device.
pipe = FluxPipeline.from_pretrained(config["model_id"], torch_dtype=torch.bfloat16)
pipe = pipe.to(device)
# Attach the default LoRA right away so the app is usable before anyone
# presses "Load LoRA"; its repo needs an explicit weight file name.
pipe.load_lora_weights(
    config["default_lora"],
    weight_name=config["default_weight_name"],
)
# -----------------------------------------------------------------------------
# Function to load a new LoRA model
# -----------------------------------------------------------------------------
def enable_lora(lora_add: str):
    """Replace the active LoRA with the one named in the "Flux LoRA" textbox.

    Any currently loaded LoRA weights are unloaded first. An empty (or
    whitespace-only) repo id just clears the LoRA and empties the textbox.

    Args:
        lora_add: Hugging Face repo id of the LoRA to load
            (for example ``config["default_lora"]``).

    Returns:
        A ``gr.update`` for the textbox. The textbox is cleared when nothing
        was requested; otherwise it is relabelled to confirm the load.

    Raises:
        gr.Error: if ``pipe.load_lora_weights`` fails. The pipeline is then
            left with no LoRA loaded, because the old one was already unloaded.
    """
    pipe.unload_lora_weights()
    repo_id = (lora_add or "").strip()
    if not repo_id:
        return gr.update(value="")
    # The default LoRA is loaded at startup with an explicit weight file name;
    # reuse it so re-loading the default picks the same file. Other repos
    # fall back to diffusers' own weight-file lookup (weight_name=None).
    weight_name = (
        config["default_weight_name"] if repo_id == config["default_lora"] else None
    )
    try:
        pipe.load_lora_weights(repo_id, weight_name=weight_name)
    except Exception as e:
        # Surface the failure in the UI and keep the original cause chained.
        raise gr.Error(f"Failed to load {repo_id}: {e}") from e
    return gr.update(label="LoRA Loaded Now")
# -----------------------------------------------------------------------------
# Function to generate an image from a prompt
# -----------------------------------------------------------------------------
@spaces.GPU()
def generate_image(
    prompt: str,
    lora_word: str,
    lora_scale: float = config["default_loRa_scale"],
    width: int = config["default_width"],
    height: int = config["default_height"],
    guidance_scale: float = config["default_guidance_scale"],
    steps: int = config["default_steps"],
    seed: int = -1,
    nums: int = 1,
):
    """Generate images with the FLUX pipeline and the currently loaded LoRA.

    The prompt is translated to English and the LoRA trigger word is appended
    before running the pipeline.

    Args:
        prompt: User prompt, in any language.
        lora_word: LoRA trigger word appended to the prompt; may be empty.
        lora_scale: LoRA strength, passed via ``joint_attention_kwargs``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance scale.
        steps: Number of inference steps.
        seed: RNG seed; ``-1`` draws a random one in ``[0, config["max_seed"]]``.
        nums: Number of images to generate for the prompt.

    Returns:
        ``(images, seed)``: the PIL images and the seed actually used, so the
        UI can display (and reuse) it.
    """
    pipe.to(device)
    # Gradio sliders can deliver floats; the pipeline expects ints for these.
    width, height, steps, nums = int(width), int(height), int(steps), int(nums)
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, config["max_seed"])

    try:
        prompt_english = str(translator.translate(prompt, "English"))
    except Exception as e:
        # Translation presumably relies on an external service; if it fails,
        # generate from the untranslated prompt instead of aborting the run.
        print(f"Translation failed, using original prompt: {e}")
        prompt_english = prompt
    # Join non-empty parts only, so an empty trigger word adds no stray space.
    full_prompt = " ".join(
        part for part in (prompt_english.strip(), (lora_word or "").strip()) if part
    )

    # CPU generator, kept as in the original so a given seed reproduces the
    # same images as before.
    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        prompt=full_prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        output_type="pil",
        num_inference_steps=steps,
        num_images_per_prompt=nums,
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    )
    return result.images, seed
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
# Rows for gr.Examples, matching its inputs: [prompt, trigger word, LoRA scale].
# Every sample uses the "ohwx" trigger at LoRA scale 0.9.
example_prompts = [
    [text, "ohwx", 0.9]
    for text in (
        "Medium-shot portrait, ohwx blue alien, wearing black techwear with a high collar, standing inside a futuristic VR showroom.",
        "ohwx blue alien, wearing black techwear with a high collar, immersed in a digital cybernetic landscape.",
        "full-body shot, ohwx blue alien, wearing black techwear with a high collar, black cyber sneakers, running through a neon-lit cyberpunk alley at night.",
        "ohwx blue alien, wearing black techwear with a high collar, sitting inside a sleek, high-tech VR capsule, immersed in an augmented reality experience.",
    )
]
# Build the Gradio interface: output gallery, prompt box, advanced generation
# options, LoRA selection, examples, and the event wiring between them.
with gr.Blocks(css=config["css"]) as demo:
    gr.HTML("<h1><center>BR METAVERSO - Avatar Generator</center></h1>")
    # Status line, switched to "Processing" / "Done" around each generation.
    processing_status = gr.Markdown("**🟢 Ready**", visible=True)
    with gr.Row():
        with gr.Column(scale=4):
            gallery = gr.Gallery(label="Flux Generated Image", columns=1, preview=True, height=600)
            prompt_input = gr.Textbox(label="Enter Your Prompt", lines=2, placeholder="Enter prompt...")
            generate_btn = gr.Button(variant="primary")
            with gr.Accordion("Advanced Options", open=True):
                width_slider = gr.Slider(label="Width", minimum=512, maximum=1920, step=8, value=config["default_width"])
                height_slider = gr.Slider(label="Height", minimum=512, maximum=1920, step=8, value=config["default_height"])
                guidance_slider = gr.Slider(label="Guidance Scale", minimum=3.5, maximum=7, step=0.1, value=config["default_guidance_scale"])
                steps_slider = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=config["default_steps"])
                # -1 asks generate_image() for a random seed.
                seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=config["max_seed"], step=1, value=-1)
                nums_slider = gr.Slider(label="Image Count", minimum=1, maximum=2, step=1, value=1)
                lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.1, maximum=2.0, step=0.1, value=config["default_loRa_scale"])
                lora_add_text = gr.Textbox(label="Flux LoRA", lines=1, value=config["default_lora"])
                lora_word_text = gr.Textbox(label="Flux LoRA Trigger Word", lines=1, value="ohwx")
                load_lora_btn = gr.Button(value="Load LoRA", variant="secondary")
    gr.Examples(examples=example_prompts, inputs=[prompt_input, lora_word_text, lora_scale_slider], cache_examples=False, examples_per_page=4)

    def update_status():
        """Return the in-progress marker shown while an image is generated."""
        return "**⏳ Processing...**"

    # Show "Processing", run the generation (the seed actually used is written
    # back to the seed slider), then show "Done". `.then` runs the final step
    # whether or not generation succeeded.
    generate_btn.click(fn=update_status, inputs=[], outputs=[processing_status]).then(
        fn=generate_image,
        inputs=[prompt_input, lora_word_text, lora_scale_slider, width_slider, height_slider, guidance_slider, steps_slider, seed_slider, nums_slider],
        outputs=[gallery, seed_slider],
    ).then(
        fn=lambda: "**✅ Done!**",
        inputs=[],
        outputs=[processing_status],
    )
    load_lora_btn.click(fn=enable_lora, inputs=[lora_add_text], outputs=lora_add_text)

demo.queue().launch()