"""Gradio Space: text-to-image with FLUX.1-dev using Tencent's SRPO transformer weights."""

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Exclusive upper bound for randomly drawn seeds.
MAX_SEED = 2**32

# Load the base FLUX.1-dev pipeline in bfloat16 and move it to the GPU.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
).to("cuda")

# Replace the transformer weights with the SRPO checkpoint.
srpo_path = hf_hub_download(
    repo_id="tencent/SRPO",
    filename="diffusion_pytorch_model.safetensors",
)
state_dict = load_file(srpo_path)
# load_state_dict is strict by default, so a key mismatch raises instead of
# silently leaving base weights in place. Values are copied into the existing
# parameters, which keep the pipeline's dtype and device.
pipe.transformer.load_state_dict(state_dict)
# The loaded checkpoint copy is no longer referenced after the copy above;
# drop it so it does not stay resident for the lifetime of the process.
del state_dict


def _resolve_seed(seed):
    """Return a concrete integer seed.

    A missing value (None, e.g. the user cleared the Seed box) or any negative
    value means "pick a random seed"; anything else is used as-is.
    """
    if seed is None or seed < 0:
        return int(torch.randint(0, MAX_SEED, (1,)).item())
    # gr.Number / sliders may hand back floats; manual_seed needs an int.
    return int(seed)


@spaces.GPU(duration=120)
def generate_image(
    prompt,
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    seed=-1,
):
    """Generate one image from ``prompt`` with the SRPO-weighted FLUX pipeline.

    Args:
        prompt: Text prompt for the image.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        seed: RNG seed; -1 (or any negative value, or None) picks a random one.

    Returns:
        A ``(image, seed)`` tuple: the first image from the pipeline output and
        the integer seed actually used, so the UI can show it for reproduction.
    """
    seed = _resolve_seed(seed)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Coerce UI values to the numeric types the pipeline expects.
    image = pipe(
        prompt=prompt,
        guidance_scale=float(guidance_scale),
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=generator,
    ).images[0]
    return image, seed


with gr.Blocks(
    title="FLUX SRPO Text-to-Image",
    theme=gr.themes.Soft(
        primary_hue="blue", secondary_hue="gray", neutral_hue="slate"
    ),
) as demo:
    gr.Markdown("# Flux SRPO")
    gr.Markdown("Generate images using FLUX model enhanced with Tencent's SRPO technique")
    gr.Markdown("Built with [AnyCoder](https://huggingface.co/spaces/akhaliq/anycoder)")

    output_image = gr.Image(label="Generated Image", type="pil")

    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Describe the image you want to generate...",
        lines=3,
    )

    generate_btn = gr.Button("Generate Image", variant="primary", size="lg")

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            # step=64 keeps dimensions multiples of 64.
            width = gr.Slider(
                minimum=256, maximum=2048, value=1024, step=64, label="Width"
            )
            height = gr.Slider(
                minimum=256, maximum=2048, value=1024, step=64, label="Height"
            )
        with gr.Row():
            guidance_scale = gr.Slider(
                minimum=1.0,
                maximum=20.0,
                value=3.5,
                step=0.5,
                label="Guidance Scale",
            )
            num_inference_steps = gr.Slider(
                minimum=10,
                maximum=100,
                value=50,
                step=5,
                label="Inference Steps",
            )
        seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)

    used_seed = gr.Number(label="Seed Used", precision=0)

    gr.Examples(
        examples=[
            ["The Death of Ophelia by John Everett Millais, Pre-Raphaelite painting, Ophelia floating in a river surrounded by flowers, detailed natural elements, melancholic and tragic atmosphere"],
            ["A serene Japanese garden with cherry blossoms, koi pond, traditional wooden bridge, soft morning light, photorealistic"],
            ["Cyberpunk cityscape at night, neon lights, flying cars, rain-slicked streets, blade runner aesthetic, highly detailed"],
            ["Portrait of a majestic lion in golden hour light, detailed fur texture, intense gaze, African savanna background"],
            ["Abstract colorful explosion of paint in water, high speed photography, vibrant colors mixing, dramatic lighting"],
        ],
        inputs=prompt,
        label="Example Prompts",
    )

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, width, height, guidance_scale, num_inference_steps, seed],
        outputs=[output_image, used_seed],
    )

demo.launch()