OBS-Diff Structured Pruning for Stable Diffusion XL (stable-diffusion-xl-base-1.0)


โœ‚๏ธ OBS-Diff: Accurate Pruning for Diffusion Models in One-Shot

Junhan Zhu, Hesong Wang, Mingluo Su, Zefang Wang, Huan Wang*

OBS-Diff is the first training-free, one-shot pruning framework for diffusion models, supporting diverse architectures and pruning granularities. It uses Optimal Brain Surgeon (OBS) to achieve state-of-the-art compression with high generative quality.

OBS-Diff-SDXL provides a collection of structured-pruned checkpoints for the Stable Diffusion XL (SDXL) base model, compressed using the OBS-Diff framework. By leveraging an efficient one-shot pruning algorithm, this model significantly reduces the parameter count of the UNet while maintaining high-fidelity image generation capabilities. The provided variants cover a sparsity range from 10% to 30%, offering a trade-off between model size and performance.
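For background, OBS-style pruning scores each weight by how much removing it would increase a layer-wise reconstruction loss, then updates the surviving weights to compensate. The classic OBS criterion is sketched below; OBS-Diff's diffusion-specific adaptations are described in the paper.

$$
L_q = \frac{w_q^2}{2\,[\mathbf{H}^{-1}]_{qq}}, \qquad
\delta\mathbf{w} = -\frac{w_q}{[\mathbf{H}^{-1}]_{qq}}\,\mathbf{H}^{-1}\mathbf{e}_q,
$$

where $\mathbf{H}$ is the Hessian of the layer's reconstruction error, $\mathbf{e}_q$ is the unit vector selecting weight $w_q$, $L_q$ is the saliency (loss increase) of removing $w_q$, and $\delta\mathbf{w}$ is the compensating update applied to the remaining weights.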

Pruned UNet Variants

| Sparsity (%) | 0 (Dense) | 10   | 15   | 20   | 25   | 30   |
|--------------|-----------|------|------|------|------|------|
| Params (B)   | 2.57      | 2.35 | 2.24 | 2.13 | 2.02 | 1.91 |
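
As a quick sanity check, the relative parameter reduction implied by the table can be computed directly from those numbers (a small illustrative script, not part of the released code):

# Total UNet parameter counts (in billions), taken from the table above.
dense_params = 2.57
pruned_params = {10: 2.35, 15: 2.24, 20: 2.13, 25: 2.02, 30: 1.91}

for sparsity, params in pruned_params.items():
    reduction = (dense_params - params) / dense_params * 100
    print(f"{sparsity}% sparsity: {params:.2f} B params ({reduction:.1f}% smaller than dense)")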

How to use the pruned model

  1. Download the base model (SDXL) from Hugging Face or ModelScope.

  2. Download the pruned weights (.pth files) and use torch.load to replace the original UNet in the pipeline.

  3. Run inference using the code below.

import torch
from diffusers import DiffusionPipeline


# 1. Load the base SDXL model
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

# 2. Swap the original UNet with the pruned UNet checkpoint
# Note: Ensure the path points to your downloaded .pth file
pruned_unet_path = "/path/to/sparsity_30/unet_pruned.pth"
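# torch.load with weights_only=False deserializes the full UNet module object stored in the
# checkpoint (not just a state_dict), so only load .pth files from a source you trust.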
pipe.unet = torch.load(pruned_unet_path, weights_only=False)
pipe = pipe.to("cuda")

total_params = sum(p.numel() for p in pipe.unet.parameters())
print(f"Total UNet parameters: {total_params / 1e6:.2f} M")

image = pipe(
    prompt="A ship sailing through a sea of clouds, golden hour, impasto oil painting, brush strokes visible, dreamlike atmosphere.",
    negative_prompt=None,
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=7.0,
    generator=torch.Generator("cuda").manual_seed(42)
).images[0]

image.save("output_pruned.png")
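
To see the practical effect of the smaller UNet on your hardware, a rough latency and peak-memory check can be run after the swap (an optional, illustrative addition; the numbers depend on GPU, resolution, and step count):

import time

torch.cuda.reset_peak_memory_stats()
start = time.time()
_ = pipe(prompt="A lighthouse on a cliff at dawn", height=1024, width=1024, num_inference_steps=30).images[0]
torch.cuda.synchronize()
print(f"Latency: {time.time() - start:.1f} s")
print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")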

Citation

If you find this work useful, please consider citing:

@article{zhu2025obs,
  title={OBS-Diff: Accurate Pruning For Diffusion Models in One-Shot},
  author={Zhu, Junhan and Wang, Hesong and Su, Mingluo and Wang, Zefang and Wang, Huan},
  journal={arXiv preprint arXiv:2510.06751},
  year={2025}
}