Spectral Basis Adapter (SBA): Dynamic Efficient Fine-Tuning for Diffusion Models

Author: YSNRFD
Project Page: GitHub | HuggingFace | Civitai


Abstract

Fine-tuning large diffusion models like Stable Diffusion XL (SDXL) typically requires substantial computational resources. While Low-Rank Adaptation (LoRA) has become the standard for efficient fine-tuning, it relies on static weight updates. The Spectral Basis Adapter (SBA) introduces a novel approach: a dynamic, LoRA-inspired mechanism that replaces static adaptations with a learnable mixture of orthogonal basis matrices.

This article details the architecture, implementation, and practical application of SBA. We explore how it enables conditional adaptation using timestep and context embeddings while maintaining a low parameter footprint and minimal VRAM usage (under 11GB for SDXL training).


1. Introduction

The rapid evolution of generative AI has led to increasingly large U-Net architectures. Training these models from scratch is prohibitive for most researchers, which has made Parameter-Efficient Fine-Tuning (PEFT) methods like LoRA popular. However, standard LoRA applies a fixed low-rank update $\Delta W = BA$ regardless of the input conditions.

SBA addresses this by making the adaptation dynamic. Instead of a single static low-rank update, SBA utilizes a "Basis Bank" of orthogonal matrices. A lightweight gating mechanism combines these bases dynamically based on the current timestep ($t$) and text conditioning ($c$). This allows the model to adapt its behavior specifically to the semantic and temporal context of the generation step.

Key Features

  • Dynamic Adaptation: Weights change based on $t$ and $c$ embeddings via a learned gating mechanism.
  • Orthogonal Basis Bank: Uses a mixture of identity and random orthogonal matrices to preserve manifold geometry.
  • VRAM Efficient: Optimizations in the gate architecture reduce optimizer state overhead, enabling training on consumer GPUs (e.g., 16GB VRAM).
  • Seamless Integration: Injects directly into Hugging Face Diffusers UNet models without altering the base model weights.

2. Theoretical Foundation

The core operation of SBA can be defined by the following equation:

$$y = B \cdot \text{SiLU}\big(M(t, c) \cdot A \cdot x\big)$$

Where:

  • $x$: Input tensor to the linear layer.
  • $A$: Down-projection matrix (reduces dimension to rank).
  • $B$: Up-projection matrix (projects back to output dimension).
  • $M(t, c)$: The Mixing Matrix. This is the heart of SBA. It is computed dynamically for every forward pass.
  • $y$: The residual output added to the original layer's output.
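To make the shapes concrete, here is a minimal functional sketch of the equation above in PyTorch. It assumes row-vector inputs of shape (batch, in_features) and a single mixing matrix M shared across the batch; the function name sba_residual and all sizes are illustrative, not taken from the SBA code.

```python
import torch
import torch.nn.functional as F


def sba_residual(x: torch.Tensor, A: torch.Tensor, M: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Compute y = B · SiLU(M · A · x) for a batch of row vectors.

    x: (batch, in_features)   input to the linear layer
    A: (rank, in_features)    down-projection
    M: (rank, rank)           mixing matrix M(t, c)
    B: (out_features, rank)   up-projection
    returns: (batch, out_features)
    """
    h = x @ A.T       # project down to the rank space: (batch, rank)
    h = h @ M.T       # mix inside the rank space: (batch, rank)
    h = F.silu(h)     # nonlinearity in the bottleneck
    return h @ B.T    # project back up: (batch, out_features)


# Illustrative sizes: a 640-wide projection adapted at rank 16.
torch.manual_seed(0)
x = torch.randn(2, 640)
A = torch.randn(16, 640) / 640 ** 0.5
B = torch.randn(640, 16) / 16 ** 0.5
M = torch.eye(16)     # identity mixing: reduces to B · SiLU(A · x)
y = sba_residual(x, A, M, B)
print(y.shape)        # torch.Size([2, 640])
```

With M fixed to the identity the expression collapses to a LoRA-style bottleneck with a SiLU in the middle; SBA's contribution is that M is recomputed from (t, c) on every forward pass.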

The Mixing Matrix $M(t, c)$

The mixing matrix is not a static parameter but a function of the timestep embedding ($t_{emb}$) and the context embedding ($c_{emb}$).

$$M(t, c) = \sum_{i=0}^{N-1} \alpha_i(t, c) \cdot \text{Basis}_i$$

  • $\text{Basis}_i$: A set of $N$ orthogonal matrices of size $\text{rank} \times \text{rank}$.
  • $\alpha_i(t, c)$: Scalar coefficients computed by a gating network (Softmax output) that determines how much each basis contributes to the final transformation.

This formulation allows SBA to switch between different behaviors encoded in the bases, depending on whether the model is establishing coarse structure at high noise levels (early sampling steps) or refining fine details at low noise levels (late sampling steps).
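The mixing step can be sketched as follows. This assumes the gate is a single linear layer over the concatenated timestep and context embeddings (as described in Section 3.1), SDXL-like embedding widths of 1280 each, and a separate M per sample in the batch; mixing_matrix is a hypothetical helper name. It also illustrates two properties of the formula: because the Softmax weights are non-negative and sum to 1, a mixture of orthogonal bases has spectral norm at most 1 (by the triangle inequality), and it equals the identity when all weight sits on Basis 0.

```python
import torch
import torch.nn as nn


def mixing_matrix(gate: nn.Linear, basis_bank: torch.Tensor,
                  t_emb: torch.Tensor, c_emb: torch.Tensor) -> torch.Tensor:
    """M(t, c) = sum_i alpha_i(t, c) * Basis_i, computed per sample.

    basis_bank: (num_bases, rank, rank)
    t_emb, c_emb: (batch, t_dim), (batch, c_dim)
    returns: (batch, rank, rank)
    """
    logits = gate(torch.cat([t_emb, c_emb], dim=-1))  # (batch, num_bases)
    alpha = torch.softmax(logits, dim=-1)              # non-negative, sums to 1
    return torch.einsum("bn,nij->bij", alpha, basis_bank)


torch.manual_seed(0)
rank, num_bases = 16, 4
# Orthogonal bases: identity plus Q factors of random Gaussian matrices.
bank = torch.stack([torch.eye(rank)] +
                   [torch.linalg.qr(torch.randn(rank, rank)).Q for _ in range(num_bases - 1)])
gate = nn.Linear(1280 + 1280, num_bases)  # hypothetical SDXL-sized gate input
t_emb, c_emb = torch.randn(2, 1280), torch.randn(2, 1280)

M = mixing_matrix(gate, bank, t_emb, c_emb)
print(M.shape)                             # torch.Size([2, 16, 16])
print(torch.linalg.matrix_norm(M, ord=2))  # spectral norm of each M, at most 1
```

Note that a convex combination of orthogonal matrices is generally not orthogonal itself; it is a contraction, which keeps the mixed bottleneck from amplifying activations at initialization.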


3. Architecture & Implementation

The implementation is divided into three main components: the SpectralBasisAdapter, the LinearWithSBA wrapper, and the SBA Injector.

3.1 SpectralBasisAdapter (sba.py)

This module defines the core logic.

1. Parameter Initialization

Besides the gate (item 3), the adapter consists of three learnable parameter groups:

  • lora_A: Rank-reduction projection.
  • lora_B: Rank-restoration projection.
  • basis_bank: A 3D tensor of shape (num_bases, rank, rank).

2. The Orthogonal Basis Bank

To ensure training stability and preserve the information manifold, the basis matrices are initialized to be orthogonal.

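  • Basis 0: Initialized as an Identity matrix ($I$). This gives the adapter a LoRA-like starting point: whenever the gate places its weight on Basis 0, $M(t, c) = I$ and the adapter reduces to $B \cdot \text{SiLU}(A \cdot x)$, which acts roughly like a standard LoRA (apart from the SiLU in the bottleneck).
  • Bases 1 to $N-1$: Initialized via QR decomposition of random Gaussian matrices.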
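A minimal sketch of this initialization, assuming torch.linalg.qr is applied to square Gaussian matrices; the helper name init_basis_bank is hypothetical, and the sign correction is an optional refinement not mentioned in the source. The final loop checks that every basis satisfies $Q^\top Q = I$.

```python
import torch


def init_basis_bank(num_bases: int, rank: int) -> torch.Tensor:
    """Return a (num_bases, rank, rank) tensor of orthogonal matrices.

    Basis 0 is the identity; Bases 1..num_bases-1 are the Q factors of
    QR decompositions of random Gaussian matrices.
    """
    bases = [torch.eye(rank)]
    for _ in range(num_bases - 1):
        q, r = torch.linalg.qr(torch.randn(rank, rank))
        # Optional: flip column signs by sign(diag(R)) so Q is uniformly (Haar)
        # distributed; multiplying columns by +/-1 keeps Q orthogonal.
        q = q * torch.sign(torch.diagonal(r)).unsqueeze(0)
        bases.append(q)
    return torch.stack(bases)


bank = torch.nn.Parameter(init_basis_bank(num_bases=4, rank=16))
eye = torch.eye(16)
for i, Q in enumerate(bank.detach()):
    # Orthogonality check: Q^T Q should equal the identity.
    print(i, torch.allclose(Q.T @ Q, eye, atol=1e-5))
```

Wrapping the stacked tensor in nn.Parameter makes the bases trainable; orthogonality is only guaranteed at initialization unless it is enforced during training.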

3. The Memory-Efficient Gate

A critical optimization in the code is the simplification of the gating network.

  • Previous approach: A multi-layer perceptron (MLP).
  • Current approach: A single nn.Linear layer.
  • Reasoning: AdamW keeps two extra states per parameter (first and second moments), which take 8 bytes per parameter in FP32. Reducing the gate from ~400k to ~10k parameters per layer cuts the total gate parameter count across all injected layers from ~240M to ~6M, so the gates' optimizer state shrinks from roughly 1.9 GB to under 50 MB, significantly reducing the overall memory footprint.
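The arithmetic behind these figures can be checked directly. The sketch assumes FP32 optimizer states, a single-linear gate over a 2560-dimensional input (a 1280-dim timestep embedding concatenated with a 1280-dim pooled text embedding, which reproduces the ~10k figure for 4 bases), and about 600 injected layers, the count implied by 240M / 400k. These dimensions are inferences, not values read from the SBA code.

```python
BYTES_PER_PARAM_ADAMW = 2 * 4  # exp_avg + exp_avg_sq, both FP32


def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameter count of an nn.Linear layer."""
    return in_features * out_features + (out_features if bias else 0)


def adamw_state_mb(num_params: int) -> float:
    """AdamW optimizer-state memory in megabytes (10**6 bytes)."""
    return num_params * BYTES_PER_PARAM_ADAMW / 1e6


# Assumed gate input: t_emb (1280) concatenated with c_emb (1280); 4 bases.
gate_per_layer = linear_params(1280 + 1280, 4)
num_layers = 600  # implied by ~240M / ~400k; illustrative only

print(gate_per_layer)                               # 10244
print(adamw_state_mb(400_000 * num_layers))         # 1920.0 (MLP gates, ~240M params)
print(adamw_state_mb(gate_per_layer * num_layers))  # 49.1712 (linear gates, ~6M params)
```

The trainable weights and their gradients add their own memory on top of this; the saving shown here is only the optimizer-state part.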

3.2 LinearWithSBA (sba.py)

This is a wrapper module that replaces standard nn.Linear layers in the UNet.

  • It freezes the original layer weights (requires_grad=False).
  • It initializes the SpectralBasisAdapter.
  • During the forward pass, it returns Original_Output + SBA_Output.
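The adapter and its wrapper can be sketched together as PyTorch modules. This is a minimal illustration under stated assumptions, not the code in sba.py: the embeddings are read from a hypothetical module-level _SBA_CONTEXT dict (mirroring the injector's global context), inputs are token sequences of shape (batch, seq, features), and lora_B is zero-initialized as in standard LoRA, so the wrapped layer starts out identical to the frozen one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in for the injector's global context.
_SBA_CONTEXT = {"t_emb": None, "c_emb": None}


class SpectralBasisAdapter(nn.Module):
    """y = B · SiLU(M(t, c) · A · x) with M a softmax-weighted mix of bases."""

    def __init__(self, in_features, out_features, cond_dim, rank=16, num_bases=4):
        super().__init__()
        self.lora_A = nn.Linear(in_features, rank, bias=False)   # down-projection
        self.lora_B = nn.Linear(rank, out_features, bias=False)  # up-projection
        nn.init.zeros_(self.lora_B.weight)  # assumption: adapter starts as a no-op
        bases = [torch.eye(rank)] + [torch.linalg.qr(torch.randn(rank, rank)).Q
                                     for _ in range(num_bases - 1)]
        self.basis_bank = nn.Parameter(torch.stack(bases))       # (N, rank, rank)
        self.gate = nn.Linear(cond_dim, num_bases)               # single-layer gate

    def forward(self, x, t_emb, c_emb):
        # x: (batch, seq, in_features); t_emb, c_emb: (batch, dim)
        alpha = torch.softmax(self.gate(torch.cat([t_emb, c_emb], dim=-1)), dim=-1)
        M = torch.einsum("bn,nij->bij", alpha, self.basis_bank)  # (batch, rank, rank)
        h = self.lora_A(x)                                       # (batch, seq, rank)
        h = torch.einsum("bij,bsj->bsi", M, h)                   # apply M per sample
        return self.lora_B(F.silu(h))                            # (batch, seq, out)


class LinearWithSBA(nn.Module):
    def __init__(self, base: nn.Linear, cond_dim: int, **sba_kwargs):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # freeze the original weights
        self.sba = SpectralBasisAdapter(base.in_features, base.out_features,
                                        cond_dim, **sba_kwargs)

    def forward(self, x):
        out = self.base(x)
        t_emb, c_emb = _SBA_CONTEXT["t_emb"], _SBA_CONTEXT["c_emb"]
        if t_emb is None or c_emb is None:
            return out  # no conditioning set: behave like the frozen layer
        return out + self.sba(x, t_emb, c_emb)  # Original_Output + SBA_Output


# Usage: wrap a 640 -> 640 projection; 1280-dim t/c embeddings are assumed.
torch.manual_seed(0)
layer = LinearWithSBA(nn.Linear(640, 640), cond_dim=2560, rank=16, num_bases=4)
_SBA_CONTEXT["t_emb"], _SBA_CONTEXT["c_emb"] = torch.randn(2, 1280), torch.randn(2, 1280)
x = torch.randn(2, 77, 640)
y = layer(x)
y.sum().backward()
print(y.shape)  # torch.Size([2, 77, 640])
```

Because lora_B starts at zero, only lora_B receives a nonzero gradient on the first step; the other adapter parameters start moving once lora_B is no longer zero, exactly as in standard LoRA.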

3.3 SBA Injector (sba_injector.py)

The injector is responsible for recursively patching the UNet architecture.

Mechanism:

  1. Global Context: It defines a global storage _SBA_CONTEXT to pass timestep and context embeddings ($t_{emb}$, $c_{emb}$) to all layers without altering the function signatures of every underlying method.
  2. Monkey Patching: It overrides UNet2DConditionModel.forward to intercept added_cond_kwargs. This allows users to pass SBA embeddings using standard Diffusers arguments.
  3. Recursive Traversal: It iterates through Down, Mid, and Up blocks of the UNet.
  4. Targeted Injection:
    • Transformers: Injects into QKV projections (to_q, to_k, to_v) and output projections (to_out).
    • FFN (Optional): Can inject into FeedForward networks (GEGLU layers).
    • ResNet (Optional): Can inject into ResNet time embeddings.

VRAM Controls: The injector includes flags to skip ResNet (inject_into_resnet=False) and FFN (inject_into_ffn=False) injection. This is crucial for fitting SDXL training into limited VRAM.
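The targeted-injection step can be sketched as a recursive replacement of nn.Linear submodules whose dotted names match a target list. This is not the code in sba_injector.py: the helper inject_linear_adapters, the make_wrapper callback, the placeholder Tagged wrapper, and the toy model are all hypothetical. In Diffusers' attention modules the output projection is to_out[0] inside a ModuleList, so the target list uses "to_out.0"; FFN or ResNet targets would be added to the list when those flags are enabled.

```python
import torch.nn as nn

ATTN_TARGETS = ("to_q", "to_k", "to_v", "to_out.0")


def inject_linear_adapters(model: nn.Module, make_wrapper, targets=ATTN_TARGETS) -> list:
    """Replace nn.Linear submodules whose dotted name ends with a target suffix.

    make_wrapper(linear) -> nn.Module builds the replacement (e.g. a LinearWithSBA).
    Returns the dotted names of the replaced layers.
    """
    # Collect matches first, then mutate, so the traversal is not modified mid-walk.
    matches = [name for name, mod in model.named_modules()
               if isinstance(mod, nn.Linear)
               and any(name == t or name.endswith("." + t) for t in targets)]
    for name in matches:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, make_wrapper(getattr(parent, child_name)))
    return matches


class ToyAttention(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.ModuleList([nn.Linear(dim, dim), nn.Dropout(0.0)])


class ToyBlock(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.attn1 = ToyAttention(dim)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))


class Tagged(nn.Module):
    """Placeholder wrapper; a real run would build the SBA wrapper here."""

    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, x):
        return self.base(x)


model = nn.ModuleDict({"down": ToyBlock(), "up": ToyBlock()})
replaced = inject_linear_adapters(model, Tagged)
print(replaced)
```

The FFN layers (ff.0, ff.2) are left untouched because they match no target, which mirrors running with inject_into_ffn=False.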


4. Training Workflow

The training script (train_sba.py) demonstrates a standard PyTorch loop integrated with Hugging Face Diffusers.

Step 1: Model Loading & Injection

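from diffusers import UNet2DConditionModel
from sba_injector import inject_sba_into_diffusers_unet  # assumes sba_injector.py is on the import path

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

inject_sba_into_diffusers_unet(
    unet,
    rank=16,                   # Rank for low-rank adaptation
    num_bases=4,               # Number of spectral bases
    inject_into_ffn=False,     # Skip FeedForward (GEGLU) layers to save VRAM
    inject_into_resnet=False,  # Skip ResNet layers to save VRAM
)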
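After injection, one way to confirm that the base UNet is frozen and only the adapter trains is to count parameters by requires_grad. The helper name count_parameters is hypothetical, and it is demonstrated on a small stand-in model rather than on SDXL.

```python
from typing import Tuple

import torch.nn as nn


def count_parameters(model: nn.Module) -> Tuple[int, int]:
    """Return (trainable, total) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total


# Stand-in: a frozen "base" layer plus a small trainable adapter projection.
model = nn.ModuleDict({"base": nn.Linear(640, 640), "adapter": nn.Linear(640, 16, bias=False)})
model["base"].requires_grad_(False)

trainable, total = count_parameters(model)
print(f"trainable: {trainable:,} / total: {total:,} ({100 * trainable / total:.2f}%)")
```

Run on the injected UNet, the same count separates the SBA parameters from the frozen base weights, which is the split reported in Section 5.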

Step 2: Optimizer Configuration

The optimizer is configured to treat gate parameters differently from basis/LoRA parameters: the gates get a higher learning rate (lr_gate=5e-4 versus lr=1e-4) so that they converge faster.

optimizer = configure_sba_optimizer(unet, lr=1e-4, lr_gate=5e-4)
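The body of configure_sba_optimizer is not shown here. A plausible sketch uses AdamW parameter groups, assuming gate parameters can be recognized by "gate" appearing in their parameter names (an assumption about naming, not confirmed by the source); the function configure_sba_optimizer_sketch is a hypothetical stand-in.

```python
import torch
import torch.nn as nn


def configure_sba_optimizer_sketch(model: nn.Module, lr: float = 1e-4, lr_gate: float = 5e-4):
    """Hypothetical: AdamW with a separate, higher learning rate for gate parameters."""
    gate_params, other_params = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue  # frozen base weights get no optimizer state at all
        (gate_params if "gate" in name else other_params).append(p)
    return torch.optim.AdamW([
        {"params": other_params, "lr": lr},     # lora_A, lora_B, basis_bank
        {"params": gate_params, "lr": lr_gate},  # gating layers
    ])


# Toy model with a frozen base layer, adapter projections, and a gate.
model = nn.ModuleDict({
    "base": nn.Linear(64, 64),
    "lora_A": nn.Linear(64, 8, bias=False),
    "lora_B": nn.Linear(8, 64, bias=False),
    "gate": nn.Linear(32, 4),
})
model["base"].requires_grad_(False)

opt = configure_sba_optimizer_sketch(model)
for group in opt.param_groups:
    print(group["lr"], sum(p.numel() for p in group["params"]))  # 0.0001 1024, then 0.0005 132
```

Skipping frozen parameters matters as much as the learning-rate split: anything passed to AdamW gets its two state tensors allocated on the first step.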

Step 3: The Forward Pass

The forward pass utilizes the UNet's added_cond_kwargs to smuggle the SBA embeddings into the global context.

import torch

# Prepare embeddings (pooled_text_embeds, added_time_ids, time_emb, ... are computed earlier in the script)
added_cond_kwargs = {
    "text_embeds": pooled_text_embeds,
    "time_ids": added_time_ids,
    "sba_t_emb": time_emb,            # Specific to SBA: timestep embedding
    "sba_c_emb": pooled_text_embeds,  # Specific to SBA: context embedding
}

# Mixed precision forward pass; .sample extracts the predicted-noise tensor
with torch.autocast(device_type="cuda", dtype=torch.float16):
    noise_pred = unet(
        dummy_latents,
        dummy_timesteps,
        encoder_hidden_states=encoder_hidden_states,
        added_cond_kwargs=added_cond_kwargs,
    ).sample

Step 4: Gradient Checkpointing

To further reduce memory, gradient checkpointing is enabled:

unet.enable_gradient_checkpointing()

This trades compute for memory by recalculating activations during the backward pass rather than storing them.
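The trade can be seen in isolation with torch.utils.checkpoint.checkpoint: the checkpointed segment's intermediate activations are not stored, the segment is re-run during backward, and the resulting gradients match a plain forward pass. This toy block only illustrates the mechanism that enable_gradient_checkpointing() applies inside the UNet's blocks; it does not measure memory.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(64, 256), nn.SiLU(), nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)

# Plain forward: activations inside `block` are kept for the backward pass.
block(x).pow(2).sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed forward: only the input is saved; `block` is re-run in backward.
checkpoint(block, x, use_reentrant=False).pow(2).sum().backward()
print(torch.allclose(grad_plain, x.grad))  # True
```

Recomputation is why anything a layer reads from outside its arguments, such as the SBA embeddings in the global context, must still hold the same values when backward runs.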


5. Performance & Diagnostics

Based on the author's execution logs, we can analyze the performance of the SBA integration on SDXL.

Configuration:

  • Model: Stable Diffusion XL Base 1.0
  • Rank: 16
  • Injection: Transformer blocks only (ResNet/FFN skipped)
  • Hardware: CUDA GPU (Mixed Precision FP16)

Results:

  • Trainable Parameters: ~30.5 Million.
    • Note: This includes only the SBA parameters. The base UNet (approx 2.6B params) remains frozen.
  • VRAM Allocation: ~10.05 GB.
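    • This is highly efficient for adapting every attention projection in SDXL's transformer blocks, fitting comfortably within consumer 12GB-16GB cards.
  • Graph Integrity: The output carries a grad_fn (a ConvolutionBackward0 object, from the UNet's final convolution) and the backward pass succeeds. Since the base weights are frozen, the output can only be attached to the autograd graph through the SBA parameters, which confirms that gradients flow through the dynamic mixing matrix.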
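A grad_fn on the output shows the output is attached to the autograd graph; a stricter check is that every trainable parameter actually receives a gradient after backward(). A generic sketch on a toy frozen-base/adapter model (the helper name params_without_grad is hypothetical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def params_without_grad(model: nn.Module) -> list:
    """Names of trainable parameters whose .grad is still None after backward."""
    return [n for n, p in model.named_parameters() if p.requires_grad and p.grad is None]


torch.manual_seed(0)
base = nn.Linear(32, 32).requires_grad_(False)  # frozen base layer
down = nn.Linear(32, 4, bias=False)             # trainable adapter parts
up = nn.Linear(4, 32, bias=False)
unused = nn.Linear(32, 32)                      # trainable but never used in forward
model = nn.ModuleDict({"base": base, "down": down, "up": up, "unused": unused})

x = torch.randn(2, 32)
out = base(x) + up(F.silu(down(x)))
print(out.grad_fn is not None)     # True: output is attached to the graph
out.sum().backward()
print(params_without_grad(model))  # ['unused.weight', 'unused.bias']
```

On the injected UNet, an empty list after the first backward pass would confirm that no adapter was silently skipped, for example because its layer never received the SBA embeddings.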

6. Conclusion

The Spectral Basis Adapter (SBA) represents a significant step forward in parameter-efficient fine-tuning for diffusion models. By moving from static weight updates to dynamic, condition-dependent mixtures of orthogonal bases, SBA offers a more expressive adaptation mechanism.

The implementation provided demonstrates that this expressivity does not come at the cost of usability or memory. Through clever optimizations like the simplified gate and optional injection targets, SBA makes advanced, dynamic fine-tuning of massive models like SDXL accessible on standard hardware.

Recommended Environment:

  • PyTorch >= 2.1
  • diffusers >= 0.26
  • CUDA-enabled GPU

For collaboration, questions, or access to the latest codebases, please refer to the author's profiles on GitHub or HuggingFace.
