# rahul7star's picture
# Migrated from GitHub
# 05fcd0f verified
import numpy as np
import torch
import os
from diffusers_helper.models.mag_cache_ratios import MAG_RATIOS_DB
class MagCache:
    """
    Implements the MagCache algorithm for skipping transformer steps during video generation.

    MagCache: Fast Video Generation with Magnitude-Aware Cache
    Zehong Ma, Longhui Wei, Feng Wang, Shiliang Zhang, Qi Tian
    https://arxiv.org/abs/2506.09045
    https://github.com/Zehong-Ma/MagCache

    PR Demo defaults were threshold=0.1, max_consectutive_skips=3, retention_ratio=0.2
    Changing defaults to threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25 for quality vs speed tradeoff.

    Per-step protocol:
        1. Call should_skip(hidden_states) once per step.
        2. If it returned True, use estimate_predicted_hidden_states() instead of running the layers.
        3. If it returned False, run the layers and pass their output to update_hidden_states().
    """

    # Largest allowed |1 - mag_ratio| of a single step for that step to be skippable.
    # RT_BORG: Per my conversation with Zehong Ma, this 0.06 could potentially be exposed as another tunable param.
    _MAX_STEP_RATIO_DEVIATION = 0.06

    def __init__(self, model_family, height, width, num_steps, is_enabled=True, is_calibrating=False, threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25):
        """
        Args:
            model_family: Key into MAG_RATIOS_DB selecting the calibrated model.
            height: Generation height, used to pick the closest calibrated resolution group.
            width: Generation width, used to pick the closest calibrated resolution group.
            num_steps: Number of steps per section.
            is_enabled: Stored flag; set to False by this class when no usable calibration is found.
            is_calibrating: When True, no step is ever skipped and calibration statistics are collected.
            threshold: Upper bound on the accumulated skip error for a step to be skipped.
            max_consectutive_skips: Upper bound on the number of back-to-back skipped steps.
            retention_ratio: Fraction of the initial steps that are never skipped.
        """
        self.model_family = model_family
        self.height = height
        self.width = width
        self.num_steps = num_steps
        self.is_enabled = is_enabled
        self.is_calibrating = is_calibrating
        self.threshold = threshold
        self.max_consectutive_skips = max_consectutive_skips
        self.retention_ratio = retention_ratio

        # total cache statistics for all sections in the entire generation
        self.total_cache_requests = 0
        self.total_cache_hits = 0

        self.mag_ratios = self._determine_mag_ratios()
        self._init_for_every_section()

    def _init_for_every_section(self):
        """Reset all per-section state (step counter, error accumulators, calibration stats, cached tensors)."""
        self.step_index = 0
        self.steps_skipped_list = []

        # Error accumulation state
        self.accumulated_ratio = 1.0
        self.accumulated_steps = 0
        self.accumulated_err = 0

        # Statistics for calibration
        self.norm_ratio, self.norm_std, self.cos_dis = [], [], []

        self.hidden_states = None
        self.previous_residual = None

        # total_cache_requests > 0 means at least one step of an earlier section was already processed.
        if self.is_calibrating and self.total_cache_requests > 0:
            print('WARNING: Resetting MagCache calibration stats for new section. Typically you only want one section per calibration job. Discarding calibration from previsou section.')

    def should_skip(self, hidden_states):
        """
        Expected to be called once per step during the forward pass, for the number of initialized steps.
        Determines if the current step should be skipped based on estimated accumulated error.
        If the step is skipped, the hidden_states should be replaced with the output of estimate_predicted_hidden_states().

        When calibrating, this always returns False and the step counter is advanced by
        update_hidden_states() instead. When no magnitude ratios are available
        (self.mag_ratios is None), no step is ever skipped.

        Args:
            hidden_states: The current hidden states tensor from the transformer model.
        Returns:
            True if the step should be skipped, False otherwise
        """
        if self.step_index == 0 or self.step_index >= self.num_steps:
            self._init_for_every_section()
        self.total_cache_requests += 1
        self.hidden_states = hidden_states.clone()  # Is clone needed?

        if self.is_calibrating:
            print('######################### Calibrating MagCache #########################')
            return False

        should_skip_forward = False
        # keep first retention_ratio steps, and never skip step 0
        past_retention = self.step_index >= int(self.retention_ratio * self.num_steps) and self.step_index >= 1
        # Without ratios (uncalibrated model family / bad DB entry) there is nothing to
        # estimate the error from, so fall through to "do not skip" instead of crashing.
        if past_retention and self.mag_ratios is not None:
            cur_mag_ratio = self.mag_ratios[self.step_index]
            self.accumulated_ratio = self.accumulated_ratio * cur_mag_ratio
            cur_skip_err = np.abs(1 - self.accumulated_ratio)
            self.accumulated_err += cur_skip_err
            self.accumulated_steps += 1
            if (
                self.accumulated_err <= self.threshold
                and self.accumulated_steps <= self.max_consectutive_skips
                and np.abs(1 - cur_mag_ratio) <= self._MAX_STEP_RATIO_DEVIATION
            ):
                should_skip_forward = True
            else:
                self.accumulated_ratio = 1.0
                self.accumulated_steps = 0
                self.accumulated_err = 0

        if should_skip_forward:
            self.total_cache_hits += 1
            self.steps_skipped_list.append(self.step_index)

        # Increment for next step
        self.step_index += 1
        if self.step_index == self.num_steps:
            self.step_index = 0

        return should_skip_forward

    def estimate_predicted_hidden_states(self):
        """
        Should be called if and only if should_skip() returned True for the current step.
        Estimates the hidden states for the current step based on the previous hidden states and residual.

        Returns:
            The estimated hidden states tensor.
        """
        return self.hidden_states + self.previous_residual

    def update_hidden_states(self, model_prediction_hidden_states):
        """
        If and only if should_skip() returned False for the current step, the denoising layers should have been run,
        and this function should be called to compute and store the residual for future steps.

        Args:
            model_prediction_hidden_states: The hidden states tensor output from running the denoising layers.
        """
        current_residual = model_prediction_hidden_states - self.hidden_states
        if self.is_calibrating:
            # Must run before previous_residual is overwritten: the stats compare the two.
            self._update_calibration_stats(current_residual)
        self.previous_residual = current_residual

    def _update_calibration_stats(self, current_residual):
        """
        Record calibration statistics comparing current_residual to the previous residual,
        advance the step counter, and print the collected lists once num_steps is reached.

        Args:
            current_residual: Residual computed for the current step.
        """
        if self.step_index >= 1:
            # Norm ratio along the last dimension, computed once and reused for mean and std.
            magnitude_ratio = current_residual.norm(dim=-1) / self.previous_residual.norm(dim=-1)
            norm_ratio = magnitude_ratio.mean().item()
            norm_std = magnitude_ratio.std().item()
            cos_dis = (1 - torch.nn.functional.cosine_similarity(current_residual, self.previous_residual, dim=-1, eps=1e-8)).mean().item()
            self.norm_ratio.append(round(norm_ratio, 5))
            self.norm_std.append(round(norm_std, 5))
            self.cos_dis.append(round(cos_dis, 5))

        # In calibration mode should_skip() returns early, so the step counter is advanced here.
        self.step_index += 1
        if self.step_index == self.num_steps:
            print("norm ratio")
            print(self.norm_ratio)
            print("norm std")
            print(self.norm_std)
            print("cos_dis")
            print(self.cos_dis)
            self.step_index = 0

    def _determine_mag_ratios(self):
        """
        Determines the magnitude ratios by finding the closest resolution and step count
        in the pre-calibrated database.

        MAG_RATIOS_DB is read as model_family -> resolution key -> step count -> ratios.
        The resolution key closest to the average of height and width is used. If the
        closest step count differs from num_steps, the ratios are resampled to num_steps.

        Side effect: sets self.is_enabled = False when the lookup fails.

        Returns:
            The magnitude ratios for the specified configuration, or None if calibrating or not found.
        """
        if self.is_calibrating:
            return None
        try:
            # Find the closest available resolution group for the given model family
            resolution_groups = MAG_RATIOS_DB[self.model_family]
            available_resolutions = list(resolution_groups.keys())
            if not available_resolutions:
                raise ValueError("No resolutions defined for this model family.")

            avg_resolution = (self.height + self.width) / 2.0
            closest_resolution_key = min(available_resolutions, key=lambda r: abs(r - avg_resolution))

            # Find the closest available step count for the given model/resolution
            steps_group = resolution_groups[closest_resolution_key]
            available_steps = list(steps_group.keys())
            if not available_steps:
                raise ValueError(f"No step counts defined for resolution {closest_resolution_key}.")
            closest_steps = min(available_steps, key=lambda x: abs(x - self.num_steps))
            base_ratios = steps_group[closest_steps]

            if closest_steps == self.num_steps:
                print(f"MagCache: Found ratios for {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {self.num_steps} steps.")
                return base_ratios

            print(f"MagCache: Using ratios from {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {closest_steps} steps and interpolating to {self.num_steps} steps.")
            return self._nearest_step_interpolation(base_ratios, self.num_steps)
        except KeyError:
            # This will catch if model_family is not in MAG_RATIOS_DB
            print(f"Warning: MagCache not calibrated for model family '{self.model_family}'. MagCache will not be used.")
            self.is_enabled = False
        except (ValueError, TypeError) as e:
            # This will catch errors if resolution keys or step keys are not numbers, or if groups are empty.
            print(f"Warning: Error processing MagCache DB for model family '{self.model_family}': {e}. MagCache will not be used.")
            self.is_enabled = False
        return None

    # Nearest interpolation function for MagCache mag_ratios
    @staticmethod
    def _nearest_step_interpolation(src_array, target_length):
        """
        Resample src_array to target_length entries by nearest-index lookup.

        Args:
            src_array: Sequence of ratios (numpy array or plain list/tuple).
            target_length: Number of entries wanted.
        Returns:
            A numpy array with target_length entries; for target_length == 1 it holds
            only the last source entry.
        """
        # Coerce so that list/tuple inputs support the index-array lookup below.
        src_array = np.asarray(src_array)
        src_length = len(src_array)
        if target_length == 1:
            return np.array([src_array[-1]])

        # Map target positions 0..target_length-1 onto source positions 0..src_length-1.
        scale = (src_length - 1) / (target_length - 1)
        mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
        return src_array[mapped_indices]

    def append_calibration_to_file(self, output_file):
        """
        Appends one tab delimited calibration line to output_file:
        model_family, width, height, num_steps, then a "num_steps: np.array([1.0] + norm_ratio)," entry.

        Args:
            output_file: Path of the file to append to.
        Returns:
            True if the line was written; False if not calibrating, no ratios were
            collected yet, or writing failed.
        """
        if not self.is_calibrating or not self.norm_ratio:
            print("Calibration data can only be appended after calibration.")
            return False
        try:
            # Format the data as a string
            calibration_set = f"{self.model_family}\t{self.width}\t{self.height}\t{self.num_steps}"
            entry_string = f"{calibration_set}\t{self.num_steps}: np.array([1.0] + {self.norm_ratio}),"
            # Append the data to the file
            with open(output_file, "a", encoding="utf-8") as f:
                f.write(entry_string + "\n")
            print(f"Calibration data appended to {output_file}")
            return True
        except Exception as e:
            # Deliberately best-effort: report the failure and signal it via the return value.
            print(f"Error appending calibration data: {e}")
            return False