import gc
import os

import torch
from diffusers import AutoencoderTiny, DiffusionPipeline, FluxTransformer2DModel
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from torchao.quantization import int8_weight_only, quantize_
from transformers import CLIPTextModel, T5EncoderModel

# Opt the CUDA caching allocator into expandable segments to reduce
# fragmentation; this must be set before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Type alias for the pipeline object passed between load_pipeline() and infer().
Pipeline = DiffusionPipeline

ckpt_id = "black-forest-labs/FLUX.1-schnell"


def empty_cache():
    # Drop Python-level references, return cached CUDA blocks to the driver,
    # and reset the allocator's peak-memory counters.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def load_pipeline() -> Pipeline:
    empty_cache()

    dtype, device = torch.bfloat16, "cuda"

    ############ VAE ############
    # Distilled tiny autoencoder in place of the full FLUX VAE, shrunk further
    # with int8 weight-only quantization via torchao.
    vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_e3m2", torch_dtype=dtype)
    quantize_(vae, int8_weight_only())

    ############ Text Encoder ############
    text_encoder = CLIPTextModel.from_pretrained(
        ckpt_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
    )
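    # Pre-quantized 4-bit T5 encoder; the FLUX.1-dev checkpoint's text_encoder_2
    # weights are reused here for the schnell pipeline.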
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "HighCWu/FLUX.1-dev-4bit",
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    )

    ############ Transformer ############
    # int8 weight-only quantized FLUX transformer, loaded from a pinned local
    # Hugging Face cache snapshot of RobertML/FLUX.1-schnell-int8wo.
    model = FluxTransformer2DModel.from_pretrained(
        "/root/.cache/huggingface/hub/models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a",
        torch_dtype=dtype,
        use_safetensors=False,
    )
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        transformer=model,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=dtype,
    ).to(device)

    # Warm-up: run two throwaway generations so kernels are compiled/autotuned
    # and allocator pools are grown before the first real request.
    for _ in range(2):
        pipeline(
            prompt="onomancy, aftergo, spirantic, Platyhelmia, modificator, drupaceous, jobbernowl, hereness",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )

    empty_cache()
    return pipeline


@torch.inference_mode()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    # Seed the generator on the pipeline's device for reproducible sampling.
    generator = Generator(pipeline.device).manual_seed(request.seed)
    image = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
    return image
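

if __name__ == "__main__":
    # Minimal local smoke test, not part of the serving harness. Assumes
    # TextToImageRequest accepts these keyword fields (they match the
    # attributes used in infer() above); adjust to the harness's actual model.
    pipe = load_pipeline()
    request = TextToImageRequest(prompt="a red bicycle", width=1024, height=1024, seed=0)
    infer(request, pipe).save("smoke_test.png")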