from diffusers import DiffusionPipeline, AutoencoderKL
from transformers import T5EncoderModel
import torch
import gc
from PIL import Image
from pipelines.models import TextToImageRequest  # request schema: prompt, seed, width, height
import os
from torch import Generator

# DiffusionPipeline.from_pretrained resolves FLUX.1-schnell to a FluxPipeline
# at runtime, so DiffusionPipeline serves as the alias for the type hints below.
Pipeline = DiffusionPipeline

# Cap the CUDA allocator's split size to reduce fragmentation on
# memory-constrained GPUs; must be set before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

ckpt_id = "black-forest-labs/FLUX.1-schnell"
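# FLUX.1-schnell is the timestep-distilled FLUX variant, designed to sample in
# ~4 steps without classifier-free guidance; hence num_inference_steps=4 and
# guidance_scale=0.0 everywhere below.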

def load_pipeline() -> Pipeline:
    # Release leftover host and device memory before loading the checkpoints.
    gc.collect()
    torch.cuda.empty_cache()

    dtype = torch.bfloat16
    # Standalone bf16 copy of the T5-XXL text encoder (city96 mirror).
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16", torch_dtype=dtype
    ).to(memory_format=torch.channels_last)
    vae = AutoencoderKL.from_pretrained(
        ckpt_id, subfolder="vae", torch_dtype=dtype
    ).to(memory_format=torch.channels_last)


    # Build the pipeline on CPU; enable_sequential_cpu_offload() below owns
    # device placement, so the usual .to("cuda") is intentionally skipped.
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        vae=vae,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    )
    pipeline.transformer.to(memory_format=torch.channels_last)

    # Compile the VAE decode path. Compilation is lazy, so the cost is paid in
    # the warmup passes below. The pipeline calls vae.decode() directly, so
    # compiling that method (as in the diffusers docs) is what speeds up
    # decoding; wrapping the whole module would only compile forward().
    pipeline.vae.decode = torch.compile(pipeline.vae.decode)

    # Keep the VAE resident on the GPU while everything else is sequentially
    # offloaded to fit limited VRAM; whether offloading is a net win is
    # workload-dependent, so test with and without it.
    pipeline._exclude_from_cpu_offload = ["vae"]
    pipeline.enable_sequential_cpu_offload()

    # Warmup on GPU: the first pass triggers torch.compile; the second runs warm.
    for _ in range(2):
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    return pipeline



@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image.Image:
    torch.cuda.reset_peak_memory_stats()
    # Seed a CUDA generator so results are reproducible per request.
    generator = Generator("cuda").manual_seed(request.seed)
    image = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
    return image
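

# Minimal local-run sketch. Assumption: TextToImageRequest can be constructed
# with the prompt/seed/width/height fields that infer() reads above; the real
# schema lives in pipelines.models.
if __name__ == "__main__":
    pipe = load_pipeline()
    req = TextToImageRequest(
        prompt="a watercolor fox in a snowy birch forest",
        seed=42,
        width=1024,
        height=1024,
    )
    infer(req, pipe).save("output.png")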