# Bton's picture
# Update app.py
# d5d5f7f verified
import gradio as gr
import spaces
import os
import torch
import random
import uuid
from datetime import datetime
from diffusers import FluxTransformer2DModel, FluxPipeline, GGUFQuantizationConfig
from transformers import T5EncoderModel
from PIL import Image
import numpy as np
# Constants
SAVE_DIR = "saved_images"  # directory that receives every generated PNG
NUM_INFERENCE_STEPS = 8  # step count passed to the pipeline on each request
MAX_SEED = np.iinfo(np.int32).max  # largest signed 32-bit integer

# Create the output directory at import time; a no-op when it already exists.
os.makedirs(SAVE_DIR, exist_ok=True)
# Initialize device: use CUDA when torch reports it, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hub access token read from the environment; None when the variable is unset.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

# Load model
dtype = torch.bfloat16
base_model = "black-forest-labs/FLUX.1-dev"
gguf_file_url = "https://huggingface.co/gokaygokay/flux-game/resolve/main/hyperflux_00001_.q8_0.gguf"

# T5 text encoder taken from the base repository's "text_encoder_2" subfolder.
text_encoder_2 = T5EncoderModel.from_pretrained(
    base_model,
    subfolder="text_encoder_2",
    torch_dtype=dtype,
    token=huggingface_token,
)

# Transformer loaded from the single GGUF checkpoint file; the quantization
# config is told to compute in `dtype`.
transformer = FluxTransformer2DModel.from_single_file(
    gguf_file_url,
    quantization_config=GGUFQuantizationConfig(compute_dtype=dtype),
    torch_dtype=dtype,
    token=huggingface_token,
)

# Assemble the pipeline from the base repository, substituting the two
# components loaded above, then move it onto the selected device.
flux_pipeline = FluxPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype,
    token=huggingface_token,
)
flux_pipeline = flux_pipeline.to(device)
@spaces.GPU
def generate_flux_image(
    prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> Image.Image:
    """Run the Flux pipeline once, store the result on disk and return it.

    Args:
        prompt: User description; it is wrapped in a fixed prefix and suffix
            before being handed to the pipeline.
        seed: Seed for the torch generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, a fresh seed is drawn uniformly from
            ``[0, MAX_SEED]``.
        width: Width forwarded to the pipeline.
        height: Height forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        progress: Not referenced in the body. NOTE(review): presumably lets
            Gradio mirror tqdm progress in the UI -- confirm.

    Returns:
        The first image of the pipeline output. A PNG copy is also written
        to ``SAVE_DIR`` under a timestamp-plus-random-suffix file name.
    """
    effective_seed = random.randint(0, MAX_SEED) if randomize_seed else seed

    # manual_seed() returns the generator itself, so seeding in place is the
    # same as chaining the call.
    rng = torch.Generator(device=device)
    rng.manual_seed(effective_seed)

    full_prompt = "".join(["wbgmsst, ", prompt, ", 3D isometric, white background"])

    result = flux_pipeline(
        prompt=full_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=NUM_INFERENCE_STEPS,
        width=width,
        height=height,
        generator=rng,
    )
    picture = result.images[0]

    # Save the generated image; the 8-hex-digit suffix keeps names distinct
    # when two requests finish within the same second.
    out_name = f"{datetime.now():%Y%m%d_%H%M%S}_{uuid.uuid4().hex[:8]}.png"
    picture.save(os.path.join(SAVE_DIR, out_name))

    return picture
# Simple Gradio interface: the input components are listed in the same order
# as the leading parameters of generate_flux_image.
demo = gr.Interface(
    fn=generate_flux_image,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Enter your game asset description",
        ),
        gr.Slider(minimum=0, maximum=MAX_SEED, value=42, step=1, label="Seed"),
        gr.Checkbox(value=True, label="Randomize Seed"),
        gr.Slider(minimum=512, maximum=1024, value=1024, step=16, label="Width"),
        gr.Slider(minimum=512, maximum=1024, value=1024, step=16, label="Height"),
        gr.Slider(
            minimum=0.0,
            maximum=10.0,
            value=3.5,
            step=0.1,
            label="Guidance Scale",
        ),
    ],
    outputs=gr.Image(label="Generated Asset", type="pil"),
    title="Game Asset Generator",
    description="Generate game assets with FLUX",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()