import numpy as np
import torch
import os

from diffusers_helper.models.mag_cache_ratios import MAG_RATIOS_DB


class MagCache:
    """
    Implements the MagCache algorithm for skipping transformer steps during video generation.
    MagCache: Fast Video Generation with Magnitude-Aware Cache
    Zehong Ma, Longhui Wei, Feng Wang, Shiliang Zhang, Qi Tian
    https://arxiv.org/abs/2506.09045
    https://github.com/Zehong-Ma/MagCache
    PR Demo defaults were threshold=0.1, max_consectutive_skips=3, retention_ratio=0.2.
    Defaults changed to threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25 for a better quality vs. speed tradeoff.
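
    Typical per-step usage inside the transformer forward pass (an illustrative sketch;
    run_transformer_blocks is a hypothetical stand-in for the model's denoising layers):

        if mag_cache.should_skip(hidden_states):
            hidden_states = mag_cache.estimate_predicted_hidden_states()
        else:
            hidden_states = run_transformer_blocks(hidden_states)
            mag_cache.update_hidden_states(hidden_states)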
    """

    def __init__(self, model_family, height, width, num_steps, is_enabled=True, is_calibrating=False, threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25):
        self.model_family = model_family
        self.height = height
        self.width = width
        self.num_steps = num_steps

        self.is_enabled = is_enabled
        self.is_calibrating = is_calibrating

        self.threshold = threshold
        self.max_consectutive_skips = max_consectutive_skips
        self.retention_ratio = retention_ratio

        # total cache statistics for all sections in the entire generation
        self.total_cache_requests = 0
        self.total_cache_hits = 0

        self.mag_ratios = self._determine_mag_ratios()

        self._init_for_every_section()


    def _init_for_every_section(self):
        self.step_index = 0
        self.steps_skipped_list = []
        # Error accumulation state
        self.accumulated_ratio = 1.0
        self.accumulated_steps = 0
        self.accumulated_err = 0
        # Statistics for calibration
        self.norm_ratio, self.norm_std, self.cos_dis = [], [], []

        self.hidden_states = None
        self.previous_residual = None

        if self.is_calibrating and self.total_cache_requests > 0:
            print('WARNING: Resetting MagCache calibration stats for new section. Typically you only want one section per calibration job. Discarding calibration from previous section.')

    def should_skip(self, hidden_states):
        """
        Expected to be called once per step during the forward pass, for the configured number of steps.
        Determines if the current step should be skipped based on estimated accumulated error.
        If the step is skipped, the hidden_states should be replaced with the output of estimate_predicted_hidden_states().
        
        Args:
            hidden_states: The current hidden states tensor from the transformer model.
        Returns:
            True if the step should be skipped, False otherwise
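
        Example with hypothetical numbers: with threshold=0.1, max_consectutive_skips=2 and
        per-step mag_ratios near 0.97, two consecutive steps can be skipped
        (accumulated_err ~= |1 - 0.97| + |1 - 0.97**2| ~= 0.089 <= 0.1); a third skip would
        exceed both the error budget and max_consectutive_skips, so the full forward pass
        runs and the accumulators reset.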
        """
        if self.step_index == 0 or self.step_index >= self.num_steps:
            self._init_for_every_section()
        self.total_cache_requests += 1
        self.hidden_states = hidden_states.clone() # Is clone needed?

        if self.is_calibrating:
            print('######################### Calibrating MagCache #########################')
            return False
        
        should_skip_forward = False
        if self.step_index>=int(self.retention_ratio*self.num_steps) and self.step_index>=1: # keep first retention_ratio steps
            cur_mag_ratio = self.mag_ratios[self.step_index]
            self.accumulated_ratio = self.accumulated_ratio*cur_mag_ratio
            cur_skip_err = np.abs(1-self.accumulated_ratio)
            self.accumulated_err += cur_skip_err
            self.accumulated_steps += 1
            # RT_BORG: Per my conversation with Zehong Ma, this 0.06 could potentially be exposed as another tunable param.
            if self.accumulated_err<=self.threshold and self.accumulated_steps<=self.max_consectutive_skips and np.abs(1-cur_mag_ratio)<=0.06:
                should_skip_forward = True
            else:
                self.accumulated_ratio = 1.0
                self.accumulated_steps = 0
                self.accumulated_err = 0

        if should_skip_forward:
            self.total_cache_hits += 1
            self.steps_skipped_list.append(self.step_index)
        # Increment for next step
        self.step_index += 1
        if self.step_index == self.num_steps:
            self.step_index = 0

        return should_skip_forward
    
    def estimate_predicted_hidden_states(self):
        """
        Should be called if and only if should_skip() returned True for the current step.
        Estimates the hidden states for the current step based on the previous hidden states and residual.
        
        Returns:
            The estimated hidden states tensor.
        """
        return self.hidden_states + self.previous_residual
    
    def update_hidden_states(self, model_prediction_hidden_states):
        """
        If and only if should_skip() returned False for the current step, the denoising layers should have been run,
        and this function should be called to compute and store the residual for future steps.

        Args:
            model_prediction_hidden_states: The hidden states tensor output from running the denoising layers.
        """

        current_residual = model_prediction_hidden_states - self.hidden_states
        if self.is_calibrating:
            self._update_calibration_stats(current_residual)

        self.previous_residual = current_residual

    def _update_calibration_stats(self, current_residual):        
        if self.step_index >= 1:
            norm_ratio = ((current_residual.norm(dim=-1)/self.previous_residual.norm(dim=-1)).mean()).item()
            norm_std = (current_residual.norm(dim=-1)/self.previous_residual.norm(dim=-1)).std().item()
            cos_dis = (1-torch.nn.functional.cosine_similarity(current_residual, self.previous_residual, dim=-1, eps=1e-8)).mean().item()
            self.norm_ratio.append(round(norm_ratio, 5))
            self.norm_std.append(round(norm_std, 5))
            self.cos_dis.append(round(cos_dis, 5))
            # print(f"time: {self.step_index}, norm_ratio: {norm_ratio}, norm_std: {norm_std}, cos_dis: {cos_dis}")
        
        self.step_index += 1
        if self.step_index == self.num_steps:
            print("norm ratio")
            print(self.norm_ratio)
            print("norm std")
            print(self.norm_std)
            print("cos_dis")
            print(self.cos_dis)
            self.step_index = 0

    def _determine_mag_ratios(self):
        """
        Determines the magnitude ratios by finding the closest resolution and step count
        in the pre-calibrated database.
        
        Returns:
            A numpy array of magnitude ratios for the specified configuration, or None if not found.
        """
        if self.is_calibrating:
            return None
        try:
            # Find the closest available resolution group for the given model family
            resolution_groups = MAG_RATIOS_DB[self.model_family]
            available_resolutions = list(resolution_groups.keys())
            if not available_resolutions:
                raise ValueError("No resolutions defined for this model family.")

            avg_resolution = (self.height + self.width) / 2.0
            closest_resolution_key = min(available_resolutions, key=lambda r: abs(r - avg_resolution))

            # Find the closest available step count for the given model/resolution
            steps_group = resolution_groups[closest_resolution_key]
            available_steps = list(steps_group.keys())
            if not available_steps:
                raise ValueError(f"No step counts defined for resolution {closest_resolution_key}.")
            closest_steps = min(available_steps, key=lambda x: abs(x - self.num_steps))
            base_ratios = steps_group[closest_steps]
            if closest_steps == self.num_steps:
                print(f"MagCache: Found ratios for {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {self.num_steps} steps.")
                return base_ratios
            print(f"MagCache: Using ratios from {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {closest_steps} steps and interpolating to {self.num_steps} steps.")
            return self._nearest_step_interpolation(base_ratios, self.num_steps)
        except KeyError:
            # This will catch if model_family is not in MAG_RATIOS_DB
            print(f"Warning: MagCache not calibrated for model family '{self.model_family}'. MagCache will not be used.")
            self.is_enabled = False
        except (ValueError, TypeError) as e:
            # This will catch errors if resolution keys or step keys are not numbers, or if groups are empty.
            print(f"Warning: Error processing MagCache DB for model family '{self.model_family}': {e}. MagCache will not be used.")
            self.is_enabled = False
        return None
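
    # Assumed nested layout of MAG_RATIOS_DB, as inferred from the lookups above
    # (the family name and numeric values below are illustrative only):
    #
    #     MAG_RATIOS_DB = {
    #         "example_family": {                      # model family
    #             640: {                               # resolution group ~ (height + width) / 2
    #                 25: np.array([1.0, 0.98, ...]),  # per-step magnitude ratios for 25 steps
    #             },
    #         },
    #     }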

    @staticmethod
    def _nearest_step_interpolation(src_array, target_length):
        """
        Nearest-neighbor interpolation of MagCache mag_ratios: resamples src_array to
        target_length entries by mapping each target step index to the closest source step index.
        """
        src_length = len(src_array)
        if target_length == 1:
            return np.array([src_array[-1]])

        scale = (src_length - 1) / (target_length - 1)
        mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
        return src_array[mapped_indices]
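
    # Illustrative example of the nearest-step interpolation above (values are made up):
    #     _nearest_step_interpolation(np.array([1.0, 0.9, 0.8]), 5)
    #     -> array([1.0, 1.0, 0.9, 0.8, 0.8])
    # np.round rounds halves to the nearest even index, which decides the tie-breaks here.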

    def append_calibration_to_file(self, output_file):
        """
        Appends a tab-delimited calibration entry (model_family, width, height, num_steps, ratio table) to output_file.
        """
        if not self.is_calibrating or not self.norm_ratio:
            print("Calibration data can only be appended after calibration.")
            return False
        try:
            with open(output_file, "a") as f:
                # Format the data as a string
                calibration_set = f"{self.model_family}\t{self.width}\t{self.height}\t{self.num_steps}"
                # data_string = f"{calibration_set}\t{self.norm_ratio}"
                entry_string = f"{calibration_set}\t{self.num_steps}: np.array([1.0] + {self.norm_ratio}),"
                # Append the data to the file
                f.write(entry_string + "\n")
            print(f"Calibration data appended to {output_file}")
            return True
        except Exception as e:
            print(f"Error appending calibration data: {e}")
            return False
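

# Illustrative calibration workflow (a sketch; the model family string, the resolution, and the
# surrounding generation pipeline are assumptions, not part of this module):
#
#     mag_cache = MagCache("example_family", height=480, width=640, num_steps=25,
#                          is_calibrating=True)
#     ... run a single-section generation so every step flows through should_skip() and
#         update_hidden_states() ...
#     mag_cache.append_calibration_to_file("magcache_calibration.txt")
#
# The final tab-separated field of each appended line is formatted as a
# "num_steps: np.array([1.0] + [...])," entry, intended to be pasted into the MAG_RATIOS_DB
# table in diffusers_helper/models/mag_cache_ratios.py.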