import torch
from PIL.Image import Image
from diffusers import DDIMScheduler, StableDiffusionXLPipeline
from torch import Generator

from pipelines.models import TextToImageRequest
from loss import SchedulerWrapper, LoadSDXLQuantization
from onediffx import compile_pipe, save_pipe, load_pipe


def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    # At ~78% of the denoising steps, drop the negative-prompt half of the
    # batched embeddings and push the guidance scale below 1, which disables
    # classifier-free guidance for the remaining steps and halves UNet work.
    if step_index == int(pipe.num_timesteps * 0.78):
        callback_kwargs['prompt_embeds'] = callback_kwargs['prompt_embeds'].chunk(2)[-1]
        callback_kwargs['add_text_embeds'] = callback_kwargs['add_text_embeds'].chunk(2)[-1]
        callback_kwargs['add_time_ids'] = callback_kwargs['add_time_ids'].chunk(2)[-1]
        pipe._guidance_scale = 0.1
    return callback_kwargs


def load_pipeline(pipeline=None) -> StableDiffusionXLPipeline:
    if not pipeline:
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "stablediffusionapi/newdream-sdxl-20",
            torch_dtype=torch.float16,
        )

    # Swap in a wrapped DDIM scheduler and quantize the UNet before compilation.
    pipeline.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipeline.scheduler.config))
    quantizer = LoadSDXLQuantization(pipeline.unet)
    quantizer.load_model()

    pipeline.to("cuda")

    # Compile the pipeline with onediffx and restore a previously cached
    # compiled graph instead of recompiling from scratch.
    pipeline = compile_pipe(pipeline)
    load_pipe(
        pipeline,
        dir="/home/sandbox/.cache/huggingface/hub/models--RobertML--cached-pipe-02/snapshots/58d70deae87034cce351b780b48841f9746d4ad7",
    )

    # Warm-up runs: one pass before and one after prepare_loss() to trigger
    # compilation/graph capture and prime the scheduler wrapper.
    warmup_prompt = "polypterid, fattenable, geoparallelotropic, Galeus, galipine, peritoneum, malappropriate, Sekar"
    for _ in range(1):
        deepcache_output = pipeline(prompt=warmup_prompt, output_type="pil", num_inference_steps=20)

    pipeline.scheduler.prepare_loss()

    for _ in range(1):
        pipeline(prompt=warmup_prompt, output_type="pil", num_inference_steps=20)

    return pipeline


def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    if request.seed is None:
        generator = None
    else:
        generator = Generator(pipeline.device).manual_seed(request.seed)

    return pipeline(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        width=request.width,
        height=request.height,
        generator=generator,
        num_inference_steps=13,
        # DeepCache parameters (UNet feature caching) passed through to the pipeline.
        cache_interval=1,
        cache_layer_id=1,
        cache_block_id=0,
        eta=1.0,
        guidance_scale=5.0,
        guidance_rescale=0.0,
        # Disable CFG partway through denoising (see callback_dynamic_cfg above).
        callback_on_step_end=callback_dynamic_cfg,
        callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
    ).images[0]