test_q1 / src /pipeline.py
manbeast3b's picture
Update src/pipeline.py
e63447a verified
#6
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
import os
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
from diffusers import FluxTransformer2DModel, DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
torch._dynamo.config.suppress_errors = True
text = "manbeast3b/flux-text-encoder"
Pipeline = None
ids = "slobers/Flux.1.Schnella"
Revision = "e34d670e44cecbbc90e4962e7aada2ac5ce8b55b"
def load_traced_clip_text_model(model_path, config_path, tokenizer_path, device="cpu"):
"""
Loads a traced CLIPTextModel.
Args:
model_path: Path to the traced model file (pytorch_model.bin).
config_path: Path to the directory containing the config.json file.
tokenizer_path: Path to the directory containing the tokenizer files.
device: Device to load the model onto (e.g., "cpu" or "cuda").
Returns:
The loaded traced model and tokenizer.
"""
# Load the traced model
model = torch.jit.load(os.path.join(model_path, "pytorch_model.bin"), map_location=device)
model.eval()
# Load the config
config = CLIPTextConfig.from_pretrained(config_path)
# Create a dummy CLIPTextModel (we only need it for the config)
dummy_model = CLIPTextModel(config)#.to(device)
# Load the tokenizer
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
return model, dummy_model.config, tokenizer
def load_pipeline() -> Pipeline:
device = "cuda"
path = os.path.join(HF_HUB_CACHE, "models--slobers--Flux.1.Schnella/snapshots/e34d670e44cecbbc90e4962e7aada2ac5ce8b55b/transformer")
transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
# text = os.path.join(HF_HUB_CACHE, "models--manbeast3b--flux-text-encoder/snapshots/8f16aae56e82e1b2530e931a2c9932a7099c8b3f/")
# model, config, tokenizer = load_traced_clip_text_model(text, text, text, device)
# text = os.path.join(HF_HUB_CACHE, "models--manbeast3b--flux_te1/snapshots/13c9a50bd895859518720b4fdd021c747ecf7dbc/")
# # Load the model (you need to specify the original model architecture)
# model = CLIPTextModel.from_pretrained("slobers/Flux.1.Schnella", revision="e34d670e44cecbbc90e4962e7aada2ac5ce8b55b", subfolder="text_encoder")
# # Load the tokenizer
# tokenizer = CLIPTokenizer.from_pretrained(text) # Load the tokenizer from your repo
# # Load the quantized state_dict
# state_dict = torch.load(f"{text}/pytorch_model.bin") # Replace with the path to your local file or from your repo
# # Load the state_dict into the model
# model.load_state_dict(state_dict)
# model = CLIPTextModel.from_pretrained(text, torch_dtype=torch.bfloat16)
pipeline = FluxPipeline.from_pretrained(ids, revision=Revision, transformer=transformer, local_files_only=True, torch_dtype=torch.bfloat16,) # text_encoder= model,
pipeline.to("cuda")
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
for _ in range(3):
pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
generator = Generator(pipeline.device).manual_seed(request.seed)
return pipeline(
request.prompt,
generator=generator,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256,
height=request.height,
width=request.width,
).images[0]