OBS-Diff Structured Pruning for Stable Diffusion XL Base 1.0
OBS-Diff: Accurate Pruning for Diffusion Models in One-Shot
Junhan Zhu, Hesong Wang, Mingluo Su, Zefang Wang, Huan Wang*
OBS-Diff is the first training-free, one-shot pruning framework for diffusion models, supporting diverse architectures and pruning granularities. It uses Optimal Brain Surgeon (OBS) to achieve state-of-the-art (SOTA) compression while preserving high generative quality.
OBS-Diff-SDXL provides a collection of structured-pruned checkpoints for the Stable Diffusion XL (SDXL) base model, compressed with the OBS-Diff framework. Using an efficient one-shot pruning algorithm, it reduces the UNet parameter count from 2.57B (dense) to as few as 1.91B while maintaining high-fidelity image generation. The provided variants cover sparsity levels from 10% to 30%, offering a trade-off between model size and performance.
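The classic layer-wise form of OBS can be sketched for a single linear layer `y = W @ x`. Given calibration inputs `X` and `G = X @ X.T` (proportional to the Hessian of the layer's reconstruction error), removing input channel `q` (one column of `W`, i.e. a structured group) while optimally adjusting the remaining weights costs `sum_r W[r, q]**2 / Ginv[q, q]`, where `Ginv` is the inverse of `G`. The NumPy sketch below applies this greedily. It is our own simplified illustration, not the OBS-Diff implementation; the function name `obs_prune_input_channels`, the greedy removal order, and the dampening constant `damp` are hypothetical choices.

```python
import numpy as np


def obs_prune_input_channels(W, X, num_prune, damp=1e-4):
    """Greedily remove `num_prune` input channels (columns of W) from a linear
    layer y = W @ x using layer-wise OBS saliency and weight compensation.

    Hypothetical helper for illustration only.
    W: (out_features, in_features) weight matrix
    X: (in_features, n_samples) calibration inputs
    Returns the compensated weights (pruned columns set to zero) and the
    pruned column indices in the order they were removed.
    """
    W = np.asarray(W, dtype=np.float64).copy()
    X = np.asarray(X, dtype=np.float64)
    d = W.shape[1]
    # G = X X^T is half the Hessian of ||W X - W_hat X||^2 w.r.t. each row of W_hat
    G = X @ X.T
    # Dampening keeps G invertible when calibration data is scarce
    G += damp * np.mean(np.diag(G)) * np.eye(d)
    Ginv = np.linalg.inv(G)
    alive = np.ones(d, dtype=bool)
    pruned = []
    for _ in range(num_prune):
        diag = np.diag(Ginv)
        # Error added by removing channel q: sum_r W[r, q]^2 / [G^-1]_qq
        scores = np.full(d, np.inf)
        scores[alive] = (W[:, alive] ** 2).sum(axis=0) / diag[alive]
        q = int(np.argmin(scores))
        # OBS compensation: shift the remaining weights to absorb channel q
        W -= np.outer(W[:, q] / Ginv[q, q], Ginv[q, :])
        W[:, q] = 0.0  # already ~0 after the update; set exactly
        # Downdate the inverse so it matches the reduced set of channels
        Ginv -= np.outer(Ginv[:, q], Ginv[q, :]) / Ginv[q, q]
        alive[q] = False
        pruned.append(q)
    return W, pruned
```

In an actual network, `X` would come from activations recorded on calibration prompts, and a pruned channel would be physically sliced out of the layer (along with the matching output channel of the layer feeding it) rather than left as a zero column; both steps are omitted from this sketch.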

Pruned UNet Variants
| Sparsity (%) | 0 (Dense) | 10 | 15 | 20 | 25 | 30 |
|---|---|---|---|---|---|---|
| UNet params (B) | 2.57 | 2.35 | 2.24 | 2.13 | 2.02 | 1.91 |

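The sparsity level is not the same number as the fraction of UNet parameters removed. The short Python sketch below, using a hypothetical helper `param_reduction`, recomputes the actual reduction from the parameter counts in the table: it ranges from about 8.6% at 10% sparsity to about 25.7% at 30% sparsity. The table alone does not state which parameters the sparsity percentage is measured over.

```python
# UNet parameter counts in billions, copied from the table above (0 = dense)
PARAMS_B = {0: 2.57, 10: 2.35, 15: 2.24, 20: 2.13, 25: 2.02, 30: 1.91}


def param_reduction(params_b):
    """Percent of UNet parameters removed relative to the dense (0%) entry."""
    dense = params_b[0]
    return {s: 100 * (dense - p) / dense for s, p in params_b.items() if s != 0}


for sparsity, pct in param_reduction(PARAMS_B).items():
    print(f"{sparsity}% sparsity -> {pct:.1f}% fewer UNet parameters")
```
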
How to use the pruned model
1. Download the base model (SDXL) from Hugging Face or ModelScope.
2. Download the pruned weights (`.pth` files) and use `torch.load` to replace the original UNet in the pipeline.
3. Run inference using the code below.
```python
import torch
from diffusers import DiffusionPipeline

# 1. Load the base SDXL model
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# 2. Swap the original UNet with the pruned UNet checkpoint
# Note: Ensure the path points to your downloaded .pth file
pruned_unet_path = "/path/to/sparsity_30/unet_pruned.pth"
# The .pth file stores the whole pruned module (its layer shapes differ from the
# dense UNet), so weights_only=False is required. This unpickles arbitrary
# objects: only load checkpoints from a source you trust.
pruned_unet = torch.load(pruned_unet_path, map_location="cpu", weights_only=False)
# Match the fp16 dtype of the rest of the pipeline (no-op if already fp16)
pipe.unet = pruned_unet.to(dtype=torch.float16)
pipe = pipe.to("cuda")

# 3. Report the size of the pruned UNet
total_params = sum(p.numel() for p in pipe.unet.parameters())
print(f"Total UNet parameters: {total_params / 1e6:.2f} M")

# 4. Generate an image with a fixed seed
image = pipe(
    prompt="A ship sailing through a sea of clouds, golden hour, impasto oil painting, brush strokes visible, dreamlike atmosphere.",
    negative_prompt=None,
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=7.0,
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("output_pruned.png")
```
Citation
If you find this work useful, please consider citing:
```bibtex
@article{zhu2025obs,
  title={OBS-Diff: Accurate Pruning For Diffusion Models in One-Shot},
  author={Zhu, Junhan and Wang, Hesong and Su, Mingluo and Wang, Zefang and Wang, Huan},
  journal={arXiv preprint arXiv:2510.06751},
  year={2025}
}
```