senseicashpls1_pr3 / src /pipeline.py
Manoj Bhat
Adding init
dd6a86f
import torch
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline
from pipelines.models import TextToImageRequest
from diffusers import DDIMScheduler
from torch import Generator
from loss import SchedulerWrapper
from onediffx import compile_pipe, save_pipe, load_pipe
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    """Step-end callback that lowers guidance part-way through denoising.

    On the single step whose index equals ``int(pipe.num_timesteps * 0.77)``
    the three conditioning tensors in ``callback_kwargs`` are each replaced by
    the last of their two chunks, and ``pipe._guidance_scale`` is set to 0.1.
    On every other step nothing is modified.

    Args:
        pipe: Pipeline object; ``num_timesteps`` is read and
            ``_guidance_scale`` is written.
        step_index: Index of the step that has just finished.
        timestep: Unused; present to satisfy the callback signature.
        callback_kwargs: Mapping holding ``prompt_embeds``,
            ``add_text_embeds`` and ``add_time_ids``; updated in place.

    Returns:
        The same ``callback_kwargs`` mapping that was passed in.
    """
    trigger_step = int(pipe.num_timesteps * 0.77)
    if step_index != trigger_step:
        return callback_kwargs

    # Keep only the last of two chunks -- presumably the conditional half of a
    # classifier-free-guidance batch; confirm against the pipeline in use.
    for key in ('prompt_embeds', 'add_text_embeds', 'add_time_ids'):
        callback_kwargs[key] = callback_kwargs[key].chunk(2)[-1]

    pipe._guidance_scale = 0.1
    return callback_kwargs
def load_pipeline(pipeline=None) -> StableDiffusionXLPipeline:
    """Build, prune, compile and warm up the SDXL pipeline used for inference.

    Args:
        pipeline: An already-constructed pipeline to reuse. When falsy, the
            "stablediffusionapi/newdream-sdxl-20" checkpoint is loaded in
            float16 and moved to CUDA.

    Returns:
        The pipeline returned by ``compile_pipe``, with its scheduler replaced
        by a ``SchedulerWrapper`` around a ``DDIMScheduler`` and all warm-up
        generations already executed.
    """
    # `prune` was referenced here without ever being imported, so the function
    # raised NameError on every call. Import it locally to keep this block
    # self-contained.
    from torch.nn.utils import prune

    warmup_prompt = "telestereography, unstrengthen, preadministrator, copatroness, hyperpersonal, paramountness, paranoid, guaniferous"

    if not pipeline:
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "stablediffusionapi/newdream-sdxl-20",
            torch_dtype=torch.float16,
        ).to("cuda")

    # Prune the individual models: apply L1 unstructured pruning (amount=0.2)
    # to the weight of every Linear layer in each of the three sub-models.
    for component in (pipeline.text_encoder, pipeline.unet, pipeline.vae):
        for module in component.modules():
            if isinstance(module, torch.nn.Linear):
                prune.l1_unstructured(module, 'weight', amount=0.2)

    pipeline.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipeline.scheduler.config))
    pipeline = compile_pipe(pipeline)
    # NOTE(review): the snapshot path is machine-specific; presumably it holds
    # a previously saved compiled pipeline -- confirm it matches the pruned model.
    load_pipe(pipeline, dir="/home/sandbox/.cache/huggingface/hub/models--RobertML--cached-pipe-02/snapshots/58d70deae87034cce351b780b48841f9746d4ad7")

    # First warm-up phase: two throwaway generations before prepare_loss().
    for _ in range(2):
        pipeline(prompt=warmup_prompt, output_type="pil", num_inference_steps=20)

    # NOTE(review): presumably finalizes whatever the scheduler wrapper
    # recorded during the runs above -- confirm in the `loss` module.
    pipeline.scheduler.prepare_loss()

    # Second warm-up phase: four more throwaway generations.
    for _ in range(4):
        pipeline(prompt=warmup_prompt, output_type="pil", num_inference_steps=20)

    return pipeline
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for ``request`` with the prepared pipeline.

    Args:
        request: Supplies ``prompt``, ``negative_prompt``, ``width``,
            ``height`` and an optional ``seed``.
        pipeline: The callable pipeline; its ``device`` is used to create the
            random generator when a seed is given.

    Returns:
        The first entry of the ``images`` attribute on the pipeline's result.
    """
    # A seeded generator is only created when the request carries a seed.
    generator = (
        Generator(pipeline.device).manual_seed(request.seed)
        if request.seed is not None
        else None
    )

    # The cache_* arguments are forwarded as-is; presumably consumed by the
    # compiled pipeline -- confirm against the onediffx configuration.
    result = pipeline(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        width=request.width,
        height=request.height,
        generator=generator,
        num_inference_steps=13,
        cache_interval=1,
        cache_layer_id=1,
        cache_block_id=0,
        eta=1.0,
        guidance_scale=3.0,
        guidance_rescale=0.0,
        callback_on_step_end=callback_dynamic_cfg,
        callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
    )
    return result.images[0]