File size: 4,039 Bytes
4ea3271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec97d21
4ea3271
 
 
 
b013be0
 
 
 
 
 
 
 
 
 
 
 
 
 
20243e5
b013be0
 
 
 
 
 
 
 
 
 
 
 
 
4ea3271
b013be0
4ea3271
 
087a885
 
8e69c0a
e63447a
 
 
 
 
 
 
 
 
8e69c0a
e63447a
 
4ea3271
e63447a
4ea3271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#6
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
import os
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only

# Process-wide runtime knobs for the CUDA allocator and the tokenizers library.
# NOTE(review): these are written after `import torch`; the allocator setting
# presumably only takes effect if no CUDA allocation has happened yet — confirm.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TOKENIZERS_PARALLELISM": "True",
    }
)
# Ask TorchDynamo (torch.compile) to suppress compilation errors rather than
# raise them, so a failed graph capture does not abort the process.
torch._dynamo.config.suppress_errors = True

# NOTE(review): not referenced by any live code visible in this section.
text = "manbeast3b/flux-text-encoder"
# Placeholder name used only in function annotations (e.g. `-> Pipeline`);
# its runtime value is None, so those annotations carry no real type.
Pipeline = None
# Hugging Face repo id and pinned commit of the Flux checkpoint that is loaded.
ids = "slobers/Flux.1.Schnella"
Revision = "e34d670e44cecbbc90e4962e7aada2ac5ce8b55b"

def load_traced_clip_text_model(model_path, config_path, tokenizer_path, device="cpu"):
    """
    Load a TorchScript-traced CLIP text encoder with its config and tokenizer.

    Args:
        model_path: Directory containing the traced model, stored as
            ``pytorch_model.bin`` and read with ``torch.jit.load`` (so it must
            be a TorchScript archive, not a plain state dict).
        config_path: Directory containing the ``config.json`` for the model.
        tokenizer_path: Directory containing the tokenizer files.
        device: Device the traced model is mapped onto (e.g. "cpu" or "cuda").

    Returns:
        A ``(model, config, tokenizer)`` tuple: the traced module in eval mode,
        the ``CLIPTextConfig`` read from ``config_path``, and the
        ``CLIPTokenizer`` read from ``tokenizer_path``.
    """
    # CLIPTextConfig is not part of the module-level transformers import, so
    # bring it into scope here.
    from transformers import CLIPTextConfig

    # Load the traced model and switch it to inference mode.
    model = torch.jit.load(os.path.join(model_path, "pytorch_model.bin"), map_location=device)
    model.eval()

    # Only the config is needed by callers, so return it directly instead of
    # instantiating a full, randomly initialised CLIPTextModel just to read
    # `.config` back from it.
    # NOTE(review): a config attached to an instantiated model may carry extra
    # resolved fields (e.g. attention implementation); confirm no caller
    # relied on those.
    config = CLIPTextConfig.from_pretrained(config_path)

    tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

    return model, config, tokenizer

def load_pipeline() -> Pipeline:
    """Build the Flux pipeline on the GPU, compile its transformer and warm it up.

    The transformer is read straight from the local Hugging Face cache snapshot
    of ``ids`` at ``Revision`` (``use_safetensors=False``, i.e. ``.bin``
    weights). The remaining pipeline components are loaded from the same repo
    and revision with ``local_files_only=True``, so nothing is fetched from the
    network here. Everything is loaded in ``torch.bfloat16``.

    The transformer is wrapped in ``torch.compile(mode="max-autotune")``, and
    the pipeline is then run three times at 1024x1024 so compilation and
    autotuning happen here rather than on the first real request.

    Returns:
        The ``FluxPipeline`` moved to ``"cuda"`` with a compiled transformer.
    """
    device = "cuda"

    # Hugging Face cache layout: models--{org}--{name}/snapshots/{revision}/...
    # Derived from the module constants so it cannot drift from `ids`/`Revision`.
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--" + ids.replace("/", "--"),
        "snapshots",
        Revision,
        "transformer",
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False,
    )

    pipeline = FluxPipeline.from_pretrained(
        ids,
        revision=Revision,
        transformer=transformer,
        local_files_only=True,
        torch_dtype=torch.bfloat16,
    )
    pipeline.to(device)
    pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")

    # Warm-up passes so torch.compile does its work before serving requests.
    # NOTE(review): requests at other resolutions may trigger recompilation.
    warmup_prompt = (
        "insensible, timbale, pothery, electrovital, actinogram, taxis, "
        "intracerebellar, centrodesmus"
    )
    for _ in range(3):
        pipeline(
            prompt=warmup_prompt,
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
        )
    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
    """Generate a single image for ``request`` using ``pipeline``.

    A ``torch.Generator`` on the pipeline's device is seeded with
    ``request.seed`` so results are reproducible for a given seed. Sampling
    uses 4 inference steps, ``guidance_scale=0.0`` and a maximum prompt
    sequence length of 256, at the request's height and width. Runs under
    ``torch.no_grad()``.

    Returns:
        The first image in the pipeline output's ``images`` list.
    """
    rng = Generator(pipeline.device)
    rng.manual_seed(request.seed)

    result = pipeline(
        request.prompt,
        width=request.width,
        height=request.height,
        num_inference_steps=4,
        guidance_scale=0.0,
        max_sequence_length=256,
        generator=rng,
    )
    return result.images[0]