"""Gradio demo: identity-conditioned image generation with PuLID on FLUX.1 [schnell].

All models are loaded once, at import time, into the module-level
``flux_generator``. ``generate_image`` is the Gradio callback. When offloading
is enabled it moves each sub-model onto the GPU only for the stage that uses it
and moves it back to the CPU afterwards.
"""

import spaces  # kept first, as in the original, so ZeroGPU's `spaces` is imported before torch

import os
import time

import gradio as gr
import torch
from einops import rearrange
from PIL import Image
from transformers import pipeline

from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import load_ae, load_clip, load_flow_model, load_t5
from pulid.pipeline_flux import PuLIDPipeline
from pulid.utils import resize_numpy_image_long

# Generated images whose "nsfw" classifier score is >= this value are withheld.
NSFW_THRESHOLD = 0.85


def get_models(name: str, device: torch.device, offload: bool):
    """Load the FLUX components and the NSFW image classifier.

    Args:
        name: FLUX model name passed to the ``flux.util`` loaders.
        device: Device for the T5/CLIP text encoders and the NSFW classifier.
        offload: If True, the flow model and autoencoder are loaded on the CPU
            instead of ``device``.

    Returns:
        Tuple ``(model, ae, t5, clip, nsfw_classifier)``.
    """
    t5 = load_t5(device, max_length=128)
    clip = load_clip(device)
    # With offloading, the flow model and autoencoder start on the CPU and are
    # moved to the GPU only while generate_image needs them.
    model_device = "cpu" if offload else device
    model = load_flow_model(name, device=model_device)
    model.eval()
    ae = load_ae(name, device=model_device)
    nsfw_classifier = pipeline(
        "image-classification",
        model="Falconsai/nsfw_image_detection",
        device=device,
    )
    return model, ae, t5, clip, nsfw_classifier


class FluxGenerator:
    """Container for every model the demo uses, loaded once at construction."""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.offload = True  # Enable offloading for free tier
        self.model_name = "flux-schnell"  # Use flux-schnell
        self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
            self.model_name,
            device=self.device,
            offload=self.offload,
        )
        # Use the same device as the rest of the models instead of a hard-coded
        # "cuda" (identical whenever CUDA is available).
        self.pulid_model = PuLIDPipeline(
            self.model, self.device.type, weight_dtype=torch.bfloat16
        )
        self.pulid_model.load_pretrain()


flux_generator = FluxGenerator()


def _parse_seed(seed) -> int:
    """Convert the seed textbox value to an int; -1 or blank draws a random seed.

    Raises:
        gr.Error: if the value is not an integer, so the UI shows a readable message.
    """
    text = str(seed).strip()
    try:
        value = int(text) if text else -1
    except ValueError as err:
        raise gr.Error("Seed must be an integer (-1 for random).") from err
    if value == -1:
        value = torch.Generator(device="cpu").seed()
    return value


def _nsfw_score(img) -> float:
    """Return the classifier's "nsfw" score for ``img``.

    If the classifier output has no "nsfw" entry, return 1.0, so a missing
    score is treated as unsafe (fail closed) instead of raising IndexError.
    """
    results = flux_generator.nsfw_classifier(img)
    return next((r["score"] for r in results if r["label"] == "nsfw"), 1.0)


@spaces.GPU
@torch.inference_mode()
def generate_image(
    prompt,
    id_image,
    seed,
    width=512,  # Reduced for free tier
    height=512,  # Reduced for free tier
    num_steps=4,  # Optimized for schnell
    id_weight=1.0,
):
    """Generate one image from ``prompt``, optionally conditioned on a reference face.

    Args:
        prompt: Text prompt.
        id_image: Reference image from the Gradio Image component, or None to
            generate without identity conditioning.
        seed: Seed as entered in the UI (string or int); -1 or blank means random.
        width: Output width in pixels.
        height: Output height in pixels.
        num_steps: Number of denoising steps.
        id_weight: Strength of the identity conditioning.

    Returns:
        ``(image, seed_string)`` on success, or ``(None, message)`` when the
        NSFW classifier's score reaches ``NSFW_THRESHOLD``.
    """
    flux_generator.t5.max_length = 128

    # The seed arrives from a Textbox as a string, so parse it before checking
    # for the -1 "random" sentinel.
    seed = _parse_seed(seed)
    opts = SamplingOptions(
        prompt=prompt,
        width=int(width),
        height=int(height),
        num_steps=int(num_steps),
        guidance=0.0,  # No guidance for schnell
        seed=seed,
    )

    print(f"Generating '{opts.prompt}' with seed {opts.seed}")
    t0 = time.perf_counter()

    # Process ID image if provided
    if id_image is not None:
        id_image = resize_numpy_image_long(id_image, 512)  # Smaller size for memory
        id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(
            id_image, cal_uncond=False
        )
    else:
        id_embeddings = None
        uncond_id_embeddings = None

    # Prepare noise and schedule
    x = get_noise(
        1,
        opts.height,
        opts.width,
        device=flux_generator.device,
        dtype=torch.bfloat16,
        seed=opts.seed,
    )
    # The FLUX reference sampler disables the timestep shift for the
    # distilled schnell model; enable it only for other variants.
    timesteps = get_schedule(
        opts.num_steps,
        x.shape[-1] * x.shape[-2] // 4,
        shift=flux_generator.model_name != "flux-schnell",
    )

    # Text encoding: the encoders live on the GPU only for this step when offloading.
    if flux_generator.offload:
        flux_generator.t5 = flux_generator.t5.to(flux_generator.device)
        flux_generator.clip = flux_generator.clip.to(flux_generator.device)
    inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)

    if flux_generator.offload:
        flux_generator.t5 = flux_generator.t5.cpu()
        flux_generator.clip = flux_generator.clip.cpu()
        torch.cuda.empty_cache()
        flux_generator.model = flux_generator.model.to(flux_generator.device)

    # Denoise
    x = denoise(
        flux_generator.model,
        **inp,
        timesteps=timesteps,
        guidance=opts.guidance,
        id=id_embeddings,
        id_weight=id_weight,
        start_step=0,
        uncond_id=uncond_id_embeddings,
        true_cfg=1.0,  # No true CFG for schnell
    )

    if flux_generator.offload:
        flux_generator.model.cpu()
        torch.cuda.empty_cache()
        flux_generator.ae.decoder.to(x.device)

    # Decode latents back to pixel space.
    x = unpack(x.float(), opts.height, opts.width)
    with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
        x = flux_generator.ae.decode(x)

    if flux_generator.offload:
        flux_generator.ae.decoder.cpu()
        torch.cuda.empty_cache()

    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.1f}s.")

    # Convert to PIL: map the clamped [-1, 1] range to [0, 255] (HWC layout).
    x = x.clamp(-1, 1)
    x = rearrange(x[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

    # NSFW check
    nsfw_score = _nsfw_score(img)
    if nsfw_score < NSFW_THRESHOLD:
        return img, str(opts.seed)
    return None, f"Image may contain NSFW content (score: {nsfw_score})"


def create_demo():
    """Build the Gradio Blocks UI wired to ``generate_image``."""
    with gr.Blocks() as demo:
        gr.Markdown("# PuLID with FLUX.1 Schnell Demo")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="A person in a futuristic city")
                id_image = gr.Image(label="Reference Image (ID)")
                seed = gr.Textbox(label="Seed (-1 for random)", value="-1")
                width = gr.Slider(256, 1024, 512, step=16, label="Width")
                height = gr.Slider(256, 1024, 512, step=16, label="Height")
                num_steps = gr.Slider(1, 4, 4, step=1, label="Number of Steps")
                id_weight = gr.Slider(0.0, 2.0, 1.0, step=0.05, label="ID Weight")
                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                seed_output = gr.Textbox(label="Used Seed")

        generate_btn.click(
            fn=generate_image,
            inputs=[prompt, id_image, seed, width, height, num_steps, id_weight],
            outputs=[output_image, seed_output],
        )

    return demo


if __name__ == "__main__":
    import huggingface_hub

    # huggingface_hub.login(None) falls back to an interactive prompt, so only
    # log in when a token is actually configured. Note: the models above are
    # already loaded at import time, before this point.
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        huggingface_hub.login(hf_token)
    demo = create_demo()
    demo.launch()