# rahul7star's picture
# Update app.py
# 569bc6c verified
import torch
from diffusers import DiffusionPipeline
import spaces
from spaces.zero.torch.aoti import aoti_capture, aoti_compile, aoti_apply
from time import perf_counter
# Hugging Face Hub repo ID for the FLUX.1 [dev] checkpoint.
# NOTE: Hub repo IDs are case-sensitive — the official repo is
# "black-forest-labs/FLUX.1-dev"; "Flux.1-Dev" can fail to resolve.
CKPT_ID = "black-forest-labs/FLUX.1-dev"
# -----------------------------
# Pipeline arguments
# -----------------------------
# Shared call kwargs for every pipeline invocation; sized small so the demo
# fits in CPU memory.
PIPE_KWARGS = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 256,  # very small to reduce memory
    "width": 256,
    "guidance_scale": 3.5,
    "num_inference_steps": 25,  # fewer steps
    # Use a dedicated generator: torch.manual_seed(0) seeds (and returns) the
    # *global* default RNG, leaking state into unrelated code. A private
    # Generator seeded with 0 produces the identical random stream without
    # that side effect.
    "generator": torch.Generator("cpu").manual_seed(0),
}
# -----------------------------
# Load pipeline
# -----------------------------
def load_pipe():
    """Build the checkpoint's pipeline on CPU in float32, progress bar off."""
    load_kwargs = {
        "torch_dtype": torch.float32,
        # NOTE(review): some diffusers versions restrict which device_map
        # strategies from_pretrained accepts — confirm "cpu" is supported.
        "device_map": "cpu",
    }
    pipeline = DiffusionPipeline.from_pretrained(CKPT_ID, **load_kwargs)
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
# -----------------------------
# Compile transformer using aoti (lightweight)
# -----------------------------
@torch.no_grad()
def compile_pipe(pipe, capture_kwargs=None):
    """Ahead-of-time (AOTI) compile the pipeline's transformer in place.

    Runs the pipeline once under `aoti_capture` to record the transformer's
    real call arguments, exports it with `torch.export`, AOT-compiles the
    exported program, and swaps the compiled artifact back into the pipe.

    Args:
        pipe: loaded DiffusionPipeline whose `.transformer` is compiled.
        capture_kwargs: kwargs for the one capture call; defaults to
            PIPE_KWARGS so the traced shapes (height/width/steps) match the
            later timed run. The previous bare `pipe(prompt="dummy")` traced
            at the pipeline's *default* resolution and step count, producing
            shapes that do not match the 256x256 inference and paying for a
            full default-size generation.

    Returns:
        The same `pipe`, with its transformer replaced by the compiled one.
    """
    if capture_kwargs is None:
        capture_kwargs = PIPE_KWARGS
    with torch._inductor.utils.fresh_inductor_cache():
        # One real call so aoti_capture can record the transformer's
        # positional/keyword arguments.
        with aoti_capture(pipe.transformer) as call:
            pipe(**capture_kwargs)
        exported = torch.export.export(
            pipe.transformer, args=call.args, kwargs=call.kwargs
        )
        compiled = aoti_compile(exported)
        aoti_apply(compiled, pipe.transformer)
        # Drop export/compile intermediates promptly to keep peak memory down.
        del exported, compiled, call
    return pipe
# -----------------------------
# Measure runtime
# -----------------------------
@torch.no_grad()
def run_pipe(pipe):
    """Run one generation with PIPE_KWARGS; return (elapsed_seconds, image)."""
    t0 = perf_counter()
    result = pipe(**PIPE_KWARGS).images[0]
    elapsed = perf_counter() - t0
    return elapsed, result
# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    # Load -> AOTI-compile the transformer -> time one 256x256 CPU generation.
    pipeline = compile_pipe(load_pipe())  # light aoti compile
    elapsed, result = run_pipe(pipeline)
    print(f"Lightweight CPU + aoti latency: {elapsed:.2f}s")
    result.save("cpu_lightweight.png")