import torch
from diffusers import DiffusionPipeline
import spaces
from spaces.zero.torch.aoti import aoti_capture, aoti_compile, aoti_apply
from time import perf_counter

CKPT_ID = "black-forest-labs/FLUX.1-dev"

# -----------------------------
# Pipeline arguments
# -----------------------------
PIPE_KWARGS = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 256,  # very small to reduce memory
    "width": 256,
    "guidance_scale": 3.5,
    "num_inference_steps": 25,  # fewer steps
    "generator": torch.manual_seed(0),
}

# -----------------------------
# Load pipeline
# -----------------------------
def load_pipe():
    pipe = DiffusionPipeline.from_pretrained(
        CKPT_ID,
        torch_dtype=torch.float32,
        device_map="cpu",
    )
    pipe.set_progress_bar_config(disable=True)
    return pipe

# -----------------------------
# Compile transformer using aoti (lightweight)
# -----------------------------
def compile_pipe(pipe):
    with torch._inductor.utils.fresh_inductor_cache():
        # Capture example transformer inputs, then export, compile, and swap in
        # the AoT-compiled module. The capture call uses the same arguments as
        # the real run so the captured (static) shapes match the final call.
        with aoti_capture(pipe.transformer) as call:
            pipe(**PIPE_KWARGS)
        exported = torch.export.export(pipe.transformer, args=call.args, kwargs=call.kwargs)
        compiled = aoti_compile(exported)
        aoti_apply(compiled, pipe.transformer)
        del exported, compiled, call
    return pipe

# -----------------------------
# Measure runtime
# -----------------------------
def run_pipe(pipe):
    start = perf_counter()
    image = pipe(**PIPE_KWARGS).images[0]
    end = perf_counter()
    return end - start, image

# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    pipe = load_pipe()
    pipe = compile_pipe(pipe)  # light aoti compile
    latency, image = run_pipe(pipe)
    print(f"Lightweight CPU + aoti latency: {latency:.2f}s")
    image.save("cpu_lightweight.png")
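
To check that the AoT compile actually changes anything, one could time the same pipeline before and after compilation; a minimal sketch using only the helpers already defined in this script:

# -----------------------------
# Optional: eager vs. AoT-compiled comparison (sketch)
# -----------------------------
def compare_eager_vs_aoti():
    pipe = load_pipe()
    eager_latency, _ = run_pipe(pipe)  # baseline, uncompiled transformer
    pipe = compile_pipe(pipe)
    aoti_latency, _ = run_pipe(pipe)   # same arguments, AoT-compiled transformer
    print(f"Eager: {eager_latency:.2f}s, AoT-compiled: {aoti_latency:.2f}s")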