"""Benchmark Flux.1-dev on CPU with an ahead-of-time (AOTI) compiled transformer.

Loads the pipeline in float32 on CPU, captures one call to the transformer,
exports + AOTI-compiles it, swaps the compiled artifact in, then times a
single generation and saves the image.
"""

from time import perf_counter

import torch
from diffusers import DiffusionPipeline
import spaces
from spaces.zero.torch.aoti import aoti_capture, aoti_compile, aoti_apply

# NOTE(review): the Hub repo id is case-sensitive in practice; the official
# repository is "black-forest-labs/FLUX.1-dev" (was "Flux.1-Dev").
CKPT_ID = "black-forest-labs/FLUX.1-dev"

# -----------------------------
# Pipeline arguments
# -----------------------------
PIPE_KWARGS = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 256,  # very small to reduce memory
    "width": 256,
    "guidance_scale": 3.5,
    "num_inference_steps": 25,  # fewer steps
    "generator": torch.manual_seed(0),
}


# -----------------------------
# Load pipeline
# -----------------------------
def load_pipe():
    """Load the Flux pipeline in float32 and place it on CPU.

    Returns:
        The loaded ``DiffusionPipeline`` with its progress bar disabled.
    """
    # ``device_map="cpu"`` is not a valid pipeline device map in diffusers
    # (only "balanced" is supported); use .to("cpu") for explicit placement.
    pipe = DiffusionPipeline.from_pretrained(
        CKPT_ID,
        torch_dtype=torch.float32,
    ).to("cpu")
    pipe.set_progress_bar_config(disable=True)
    return pipe


# -----------------------------
# Compile transformer using aoti (lightweight)
# -----------------------------
@torch.no_grad()
def compile_pipe(pipe):
    """Capture one transformer call, export it, AOTI-compile, and apply in place.

    Args:
        pipe: A loaded ``DiffusionPipeline`` whose ``transformer`` will be
            replaced by the compiled artifact.

    Returns:
        The same pipeline, with its transformer AOTI-compiled.
    """
    with torch._inductor.utils.fresh_inductor_cache():
        # Capture + compile transformer once. Use a minimal 1-step,
        # low-resolution call: the capture only needs representative
        # transformer inputs, and pipeline defaults (~50 steps, full
        # resolution) would be prohibitively slow on CPU.
        with aoti_capture(pipe.transformer) as call:
            pipe(prompt="dummy", num_inference_steps=1, height=256, width=256)

        exported = torch.export.export(
            pipe.transformer, args=call.args, kwargs=call.kwargs
        )
        compiled = aoti_compile(exported)
        aoti_apply(compiled, pipe.transformer)

        # Free the export/compile intermediates before the timed run to
        # keep peak memory down.
        del exported, compiled, call

    return pipe


# -----------------------------
# Measure runtime
# -----------------------------
@torch.no_grad()
def run_pipe(pipe):
    """Run one generation with ``PIPE_KWARGS`` and time it.

    Returns:
        A ``(latency_seconds, image)`` tuple.
    """
    start = perf_counter()
    image = pipe(**PIPE_KWARGS).images[0]
    end = perf_counter()
    return end - start, image


# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    pipe = load_pipe()
    pipe = compile_pipe(pipe)  # light aoti compile
    latency, image = run_pipe(pipe)
    print(f"Lightweight CPU + aoti latency: {latency:.2f}s")
    image.save("cpu_lightweight.png")