"""Text-to-image inference for FLUX.1-schnell that loads the pipeline in stages
(text encoders -> transformer -> VAE) and frees GPU memory between stages, so the
full model never has to be resident on the device at once."""

import gc

import torch
from torch import Generator
from PIL.Image import Image

from diffusers import AutoencoderKL, FluxPipeline
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from pipelines.models import TextToImageRequest

# No persistent pipeline object is kept; every request reloads only the components it needs.
Pipeline = None

CHECKPOINT = "black-forest-labs/FLUX.1-schnell"


def empty_cache():
    """Release cached CUDA memory and reset the allocator's peak-usage counters."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def load_pipeline() -> Pipeline:
    # Run one warm-up request so the model weights are downloaded and cached
    # before any timed inference happens.
    infer(TextToImageRequest(prompt=""), Pipeline)

    return Pipeline


def encode_prompt(prompt: str):
    """Stage 1: load only the tokenizers and text encoders and return the prompt embeddings."""
    text_encoder = CLIPTextModel.from_pretrained(
        CHECKPOINT,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        CHECKPOINT,
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = CLIPTokenizer.from_pretrained(CHECKPOINT, subfolder="tokenizer")
    tokenizer_2 = T5TokenizerFast.from_pretrained(CHECKPOINT, subfolder="tokenizer_2")

    # Build the pipeline without the transformer and VAE so only the text stack sits on the GPU.
    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=None,
        vae=None,
    ).to("cuda")

    with torch.no_grad():
        return pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            max_sequence_length=256,
        )


def infer_latents(prompt_embeds, pooled_prompt_embeds, width: int | None, height: int | None, seed: int | None):
    """Stage 2: load only the transformer and run the denoising loop, returning packed latents."""
    pipeline = FluxPipeline.from_pretrained(
        CHECKPOINT,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        vae=None,
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    if seed is None:
        generator = None
    else:
        generator = Generator(pipeline.device).manual_seed(seed)

    # FLUX.1-schnell is distilled for few-step sampling, so 4 steps without CFG is sufficient.
    return pipeline(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=4,
        guidance_scale=0.0,
        width=width,
        height=height,
        generator=generator,
        output_type="latent",
    ).images


def infer(request: TextToImageRequest, _pipeline: Pipeline) -> Image:
    empty_cache()

    # Stage 1: text encoding (CLIP + T5).
    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(request.prompt)

    empty_cache()

    # Stage 2: latent generation with the transformer.
    latents = infer_latents(prompt_embeds, pooled_prompt_embeds, request.width, request.height, request.seed)

    empty_cache()

    # Stage 3: decode the latents with the VAE.
    vae = AutoencoderKL.from_pretrained(
        CHECKPOINT,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    vae_scale_factor = 2 ** len(vae.config.block_out_channels)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

    # Fall back to the pipeline's default resolution when the request leaves width/height unset.
    height = request.height or 64 * vae_scale_factor
    width = request.width or 64 * vae_scale_factor

    with torch.no_grad():
        # Undo the 2x2 latent patch packing, then invert the VAE's scale/shift before decoding.
        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

        image = vae.decode(latents, return_dict=False)[0]

    return image_processor.postprocess(image, output_type="pil")[0]
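

# Usage sketch (not part of the original module): how the entry points above would
# typically be driven. Assumptions: TextToImageRequest accepts width, height, and seed
# as keyword arguments (only prompt is shown constructed in the original), and the
# returned PIL image is saved to a hypothetical output.png path.
if __name__ == "__main__":
    pipeline = load_pipeline()  # warm-up pass; returns None because no pipeline object is kept

    request = TextToImageRequest(
        prompt="a watercolor painting of a lighthouse at dusk",
        width=1024,
        height=1024,
        seed=42,
    )
    image = infer(request, pipeline)
    image.save("output.png")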