File size: 14,535 Bytes
9acd97f 075a559 9acd97f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 |
# Modifications Copyright (c) 2025 Advanced Micro Devices, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import JointTransformerBlock
from diffusers.models.embeddings import PatchEmbed, TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers import SD3Transformer2DModel
from diffusers.utils import is_torch_version
def random_masking(x, mask_ratio):
    """
    Randomly keep a subset of tokens from every sequence in the batch.

    One uniform noise value is drawn per (sample, position). The positions with
    the ``len_keep`` smallest noise values are kept, re-sorted into ascending
    (original) order, and gathered from ``x``.

    Args:
        x: Tensor of shape ``[N, L, D]`` (batch, length, dim).
        mask_ratio: Fraction of positions to drop. The number kept is
            ``int(L * (1 - mask_ratio))``.

    Returns:
        Tuple ``(x_masked, ids_keep, len_keep)``:
            x_masked: kept tokens, shape ``[N, len_keep, D]``.
            ids_keep: ascending indices of the kept positions, shape ``[N, len_keep]``.
            len_keep: number of positions kept per sample.
    """
    batch, length, dim = x.shape
    len_keep = int(length * (1 - mask_ratio))
    # Ranking i.i.d. uniform noise yields an independent random permutation per sample.
    scores = torch.rand(batch, length, device=x.device)
    shuffled = torch.argsort(scores, dim=1)
    # Restore positional order among the survivors so downstream ops see tokens in sequence order.
    ids_keep = torch.sort(shuffled[:, :len_keep], dim=1).values
    gather_index = ids_keep.unsqueeze(-1).repeat(1, 1, dim)
    x_masked = torch.gather(x, dim=1, index=gather_index)
    return x_masked, ids_keep, len_keep
def build_projector(hidden_size, projector_dim, z_dim):
    """
    Build a three-layer projection head mapping ``hidden_size`` -> ``z_dim``.

    Layout: Linear(hidden_size, projector_dim) -> SiLU -> Linear(projector_dim,
    projector_dim) -> SiLU -> Linear(projector_dim, z_dim), wrapped in an
    ``nn.Sequential`` (so parameter keys are ``0.*``, ``2.*`` and ``4.*``).
    """
    layers = [
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    ]
    return nn.Sequential(*layers)
# Source: https://github.com/NVlabs/Sana/blob/70459f414474c10c509e8b58f3f9442738f85577/diffusion/model/norms.py#L183
class RMSNorm(torch.nn.Module):
    """
    Root-mean-square layer normalization with a learnable per-channel gain.

    The statistic is computed over the last dimension in float32, and the result
    is cast back to the input's dtype.
    """

    def __init__(self, dim: int, scale_factor=1.0, eps: float = 1e-6):
        """
        Create the RMSNorm layer.

        Args:
            dim (int): Size of the last (normalized) dimension of the input.
            scale_factor (float, optional): Initial value of every entry of the
                learnable gain ``weight``. Default is 1.0.
            eps (float, optional): Added to the mean square before the reciprocal
                square root, for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): The stability constant passed in.
            weight (nn.Parameter): Learnable gain of shape ``(dim,)``.
        """
        super().__init__()
        initial_gain = torch.ones(dim) * scale_factor
        self.eps = eps
        self.weight = torch.nn.Parameter(initial_gain)

    def _norm(self, x):
        """
        Divide ``x`` by its root-mean-square over the last dimension.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: ``x * rsqrt(mean(x**2, dim=-1) + eps)``.
        """
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        """
        Normalize ``x`` in float32, apply the gain, and restore the input dtype.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: Normalized and scaled tensor with the same dtype as ``x``.
        """
        normalized = self._norm(x.float())
        return (self.weight * normalized).type_as(x)
class TimestepEmbeddings(nn.Module):
    """
    Embed diffusion timesteps into vectors of size ``embedding_dim``.

    The timestep is first projected by diffusers' ``Timesteps`` (256 channels,
    ``flip_sin_to_cos=True``, ``downscale_freq_shift=0``). The projection is
    cast to the requested dtype and passed through diffusers' ``TimestepEmbedding``.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        proj_channels = 256  # width of the intermediate timestep projection
        self.time_proj = Timesteps(num_channels=proj_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=proj_channels, time_embed_dim=embedding_dim)

    def forward(self, timestep, dtype):
        """Return the timestep embedding, computed after casting the projection to ``dtype``."""
        projected = self.time_proj(timestep).to(dtype=dtype)
        return self.timestep_embedder(projected)  # (N, D)
class NitroMMDiTModel(SD3Transformer2DModel):
    """
    MMDiT transformer derived from diffusers' ``SD3Transformer2DModel``.

    Differences from the parent that are visible in this class:
      * ``forward`` takes no pooled text projections. Conditioning comes from
        ``TimestepEmbeddings`` plus the caption tokens, which are RMS-normalized
        after ``context_embedder``.
      * Optional training-time random patch masking after a chosen block
        (``enable_masking`` / ``disable_masking``).
      * Optional REPA projection heads (``repa_depth`` >= 0) applied to the token
        sequence after block ``repa_depth`` during training.
      * Per-block gradient checkpointing (``enable_gradient_checkpointing``).

    In eval mode ``forward`` returns the unpatchified image. In training mode it
    returns ``(patch_sequence, ids_keep, zs)``; see ``forward``.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 24,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        caption_channels: int = 4096,
        caption_projection_dim: int = 1152,
        out_channels: int = 16,
        interpolation_scale: int = 1,
        pos_embed_max_size: int = 96,
        dual_attention_layers: Tuple[
            int, ...
        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
        qk_norm: Optional[str] = None,
        repa_depth: int = -1,
        projector_dim: int = 2048,
        z_dims: List[int] = [768],  # never mutated below; kept as a list so the registered config is unchanged
    ):
        """
        Args (beyond those of ``SD3Transformer2DModel``):
            caption_channels: Input width of the caption features fed to ``context_embedder``.
            interpolation_scale: Forwarded to ``PatchEmbed``.
            repa_depth: Index of the block after which REPA projections are computed
                in training mode; ``-1`` disables REPA (no projectors are built).
            projector_dim: Hidden width of each REPA projector.
            z_dims: Output width of each REPA projector; one projector per entry.
        """
        super().__init__(
            sample_size=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            num_layers=num_layers,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            caption_projection_dim=caption_projection_dim,
            out_channels=out_channels,
            pos_embed_max_size=pos_embed_max_size,
            dual_attention_layers=dual_attention_layers,
            qk_norm=qk_norm,
        )
        # Training-time patch masking state; see `enable_masking` / `disable_masking`.
        self.patch_mixer_depth = None  # None: no masking applied
        self.mask_ratio = 0

        self.out_channels = out_channels if out_channels is not None else in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        if repa_depth != -1:
            # Validate before allocating the projector weights.
            assert repa_depth >= 0 and repa_depth < num_layers
            self.projectors = nn.ModuleList(
                [build_projector(self.inner_dim, projector_dim, z_dim) for z_dim in z_dims]
            )
        self.repa_depth = repa_depth

        # The modules below are (re)assigned here, replacing any module the parent
        # constructor registered under the same attribute name.
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            interpolation_scale=self.config.interpolation_scale,
        )
        self.time_text_embed = TimestepEmbeddings(embedding_dim=self.inner_dim)
        self.context_embedder = nn.Linear(self.config.caption_channels, self.config.caption_projection_dim)
        # NOTE: this norm is sized by inner_dim but applied to context_embedder's output
        # (caption_projection_dim wide). The two must therefore be equal; they are with
        # the defaults (18 * 64 == 1152).
        self.text_embedding_norm = RMSNorm(self.inner_dim, scale_factor=0.01, eps=1e-5)
        # Only the final block is built with `context_pre_only=True`; blocks whose index
        # appears in `dual_attention_layers` get `use_dual_attention=True`.
        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=i in dual_attention_layers,
                )
                for i in range(self.config.num_layers)
            ]
        )
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``module.gradient_checkpointing = value`` if the module has that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
                Its dtype is also used for the timestep embedding.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple. Only honored in eval mode.
            **kwargs: Accepted and ignored.

        Returns:
            Eval mode (`not self.training`): if `return_dict` is True, a
            [`~models.transformer_2d.Transformer2DModelOutput`], otherwise a `tuple` whose first element is the
            sample tensor of shape `(batch, out_channels, height, width)`.

            Training mode: the tuple `(hidden_states, ids_keep, zs)`, where
              * `hidden_states` is the patch sequence after `proj_out`, with last dimension
                `patch_size**2 * out_channels`. If masking ran, only the kept tokens are present.
              * `ids_keep` is the kept-token index tensor from `random_masking`, or `None` if no masking ran.
              * `zs` is the list of REPA projector outputs, or `None` if REPA did not run.
        """
        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
        temb = self.time_text_embed(timestep, dtype=encoder_hidden_states.dtype)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
        encoder_hidden_states = self.text_embedding_norm(encoder_hidden_states)

        ids_keep = None
        zs = None

        # Checkpointing helpers are loop-invariant, so build them once per call.
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for index_block, block in enumerate(self.transformer_blocks):
            # Blocks only carry `gradient_checkpointing` once `enable_gradient_checkpointing`
            # has run, so default to False rather than raising AttributeError.
            use_checkpoint = (
                torch.is_grad_enabled()
                and self.gradient_checkpointing
                and getattr(block, "gradient_checkpointing", False)
            )
            if use_checkpoint:
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    joint_attention_kwargs,
                    **ckpt_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]

            # Patch masking (training only). A depth of None or -1 never equals a block index.
            if self.training and self.patch_mixer_depth == index_block:
                hidden_states, ids_keep, _ = random_masking(hidden_states, self.mask_ratio)

            # REPA (training only). A depth of -1 never equals a block index.
            if self.training and self.repa_depth == index_block:
                N, T, D = hidden_states.shape
                flat_tokens = hidden_states.reshape(-1, D)
                zs = [projector(flat_tokens).reshape(N, T, -1) for projector in self.projectors]

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        if self.training:
            # Training returns the patch sequence; masked tokens cannot be unpatchified.
            return hidden_states, ids_keep, zs

        # Inference: fold the patch sequence back into an image.
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size
        hidden_states = hidden_states.reshape(
            shape=(
                hidden_states.shape[0],
                height,
                width,
                patch_size,
                patch_size,
                self.out_channels,
            )
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(
                hidden_states.shape[0],
                self.out_channels,
                height * patch_size,
                width * patch_size,
            )
        )
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    def enable_masking(self, depth, mask_ratio):
        """
        Enable training-time random patch masking.

        Args:
            depth: Masking is applied right after block ``depth`` runs; must be in
                ``[0, len(self.transformer_blocks) - 1]``.
            mask_ratio: Fraction of tokens to drop, in ``[0, 1]``. NOTE(review): a ratio
                of 1 keeps zero tokens; confirm downstream blocks tolerate that.
        """
        assert depth >= 0 and depth < len(self.transformer_blocks)
        self.patch_mixer_depth = depth
        assert mask_ratio >= 0 and mask_ratio <= 1
        self.mask_ratio = mask_ratio

    def disable_masking(self):
        """Turn off training-time patch masking."""
        self.patch_mixer_depth = None

    def enable_gradient_checkpointing(self, nblocks_to_apply_grad_checkpointing=-1):
        """
        Enable gradient checkpointing on a subset of blocks, spread evenly over the stack.

        Args:
            nblocks_to_apply_grad_checkpointing: Number of blocks to checkpoint. ``-1``
                (the default, so the inherited no-argument call still works) means every
                block. Values larger than the block count are clamped, and values <= 0
                (other than -1) checkpoint no block.

        Sets ``self.gradient_checkpointing = True`` and a per-block
        ``gradient_checkpointing`` flag, printing each block's setting.
        """
        num_blocks = len(self.transformer_blocks)
        if nblocks_to_apply_grad_checkpointing == -1:
            nblocks_to_apply_grad_checkpointing = num_blocks
        nblocks_to_apply_grad_checkpointing = min(num_blocks, nblocks_to_apply_grad_checkpointing)
        # Apply to blocks evenly spaced out: the centre of each of n equal segments.
        step = num_blocks / nblocks_to_apply_grad_checkpointing if nblocks_to_apply_grad_checkpointing > 0 else 0
        indices = {int((i + 0.5) * step) for i in range(nblocks_to_apply_grad_checkpointing)}
        self.gradient_checkpointing = True
        for blk_ind, block in enumerate(self.transformer_blocks):
            block.gradient_checkpointing = blk_ind in indices
            print(f"Block {blk_ind} grad checkpointing set to {block.gradient_checkpointing}")
|