from diffusers import DiffusionPipeline, AutoencoderKL
from transformers import T5EncoderModel
import torch
import gc
from PIL import Image
from pipelines.models import TextToImageRequest # Assuming this defines your request object
import os
from torch import Generator
# Placeholder type alias used only for the annotations below; the actual object
# returned by load_pipeline() is a diffusers DiffusionPipeline.
Pipeline = None

# Cap the CUDA allocator's split size to reduce memory fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
ckpt_id = "black-forest-labs/FLUX.1-schnell"
def load_pipeline() -> Pipeline:
    gc.collect()
    # torch.cuda.empty_cache()

    dtype = torch.bfloat16

    # Pre-converted bf16 T5 encoder keeps the text-encoder memory footprint small.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype
    ).to(memory_format=torch.channels_last)

    vae = AutoencoderKL.from_pretrained(
        ckpt_id, subfolder="vae", torch_dtype=dtype
    ).to(memory_format=torch.channels_last)

    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        vae=vae,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    )  # Kept on CPU: enable_sequential_cpu_offload() below moves modules to the GPU on demand.

    pipeline.transformer.to(memory_format=torch.channels_last)
    # pipeline.text_encoder.to(memory_format=torch.channels_last)

    # Compile the VAE; it is excluded from sequential offloading below, so it stays resident on the GPU.
    pipeline.vae = torch.compile(pipeline.vae)

    # It is unclear whether offloading helped in the original runs; test with and without it.
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()

    # Warm up so torch.compile graphs and CUDA kernels are built before timed inference.
    for _ in range(2):
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline
@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image.Image:
    torch.cuda.reset_peak_memory_stats()
    generator = Generator("cuda").manual_seed(request.seed)
    image = pipeline(
        request.prompt, generator=generator, guidance_scale=0.0, num_inference_steps=4,
        max_sequence_length=256, height=request.height, width=request.width, output_type="pil",
    ).images[0]
    return image
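

# A minimal local smoke test; a sketch only, assuming TextToImageRequest accepts
# `prompt`, `width`, `height`, and `seed` keyword arguments (the schema of
# pipelines.models is not shown here). The host framework presumably calls
# load_pipeline() and infer() directly, so this block is just for manual runs.
if __name__ == "__main__":
    _pipeline = load_pipeline()
    _request = TextToImageRequest(
        prompt="a photograph of a red fox in a snowy forest",  # hypothetical example prompt
        width=1024,
        height=1024,
        seed=0,
    )
    infer(_request, _pipeline).save("sample.png")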