# PartCrafter/src/models/attention_processor.py
from typing import Optional, Union
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.utils import logging
from einops import rearrange
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FlashTripo2AttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the Tripo2DiT model. It applies a normalization layer and rotary embeddings to the query and key vectors.
"""
def __init__(self, topk=True):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.topk = topk
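        # `self.topk` selects the key/value pruning mode used in `qkv` below:
        #   True  -> rank keys globally against a strided subset of queries and keep the top-k
        #   False -> fall back to dense F.scaled_dot_product_attention
        #   tuple -> (idx, counts): rank keys independently per query chunk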
def qkv(self, attn, q, k, v, attn_mask, dropout_p, is_causal):
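        # FlashVDM-style sparse attention: keep only the top-k key/value tokens, ranked by their
        # mean similarity to a strided subset of the queries, then run dense attention on that
        # subset. `attn`, `attn_mask`, `dropout_p` and `is_causal` are accepted for interface
        # parity with F.scaled_dot_product_attention but are currently ignored.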
if k.shape[-2] == 3072:
topk = 1024
elif k.shape[-2] == 512:
topk = 256
else:
topk = k.shape[-2] // 3
if self.topk is True:
q1 = q[:, :, ::100, :]
sim = q1 @ k.transpose(-1, -2)
sim = torch.mean(sim, -2)
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
v0 = torch.gather(v, dim=-2, index=topk_ind)
k0 = torch.gather(k, dim=-2, index=topk_ind)
out = F.scaled_dot_product_attention(q, k0, v0)
elif self.topk is False:
out = F.scaled_dot_product_attention(q, k, v)
else:
idx, counts = self.topk
start = 0
outs = []
for grid_coord, count in zip(idx, counts):
end = start + count
q_chunk = q[:, :, start:end, :]
q1 = q_chunk[:, :, ::50, :]
sim = q1 @ k.transpose(-1, -2)
sim = torch.mean(sim, -2)
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
v0 = torch.gather(v, dim=-2, index=topk_ind)
k0 = torch.gather(k, dim=-2, index=topk_ind)
out = F.scaled_dot_product_attention(q_chunk, k0, v0)
outs.append(out)
start += count
out = torch.cat(outs, dim=-2)
self.topk = False
return out
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from diffusers.models.embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
        # NOTE: Tripo2 splits heads first and then splits qkv (or kv), i.e. .view(..., attn.heads, 3, dim)
        # instead of .view(..., 3, attn.heads, dim), so we need to re-split here.
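        # Shape walkthrough for the self-attention branch below: to_q/to_k/to_v each return
        # [batch, seq, heads * head_dim]; concatenating gives [batch, seq, 3 * heads * head_dim],
        # which is viewed as [batch, seq, heads, 3 * head_dim] and split back into three
        # [batch, seq, heads, head_dim] tensors, matching the head-first layout of the
        # pre-trained weights. The cross-attention branch does the same with key/value only.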
if not attn.is_cross_attention:
qkv = torch.cat((query, key, value), dim=-1)
split_size = qkv.shape[-1] // attn.heads // 3
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
kv = torch.cat((key, value), dim=-1)
split_size = kv.shape[-1] // attn.heads // 2
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
key, value = torch.split(kv, split_size, dim=-1)
head_dim = key.shape[-1]
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# flashvdm topk
hidden_states = self.qkv(attn, query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class TripoSGAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the TripoSG model. It applies a normalization layer and rotary embeddings to the query and key vectors.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from diffusers.models.embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
        # NOTE: the pre-trained models split heads first and then split qkv (or kv), i.e. .view(..., attn.heads, 3, dim)
        # instead of .view(..., 3, attn.heads, dim), so we need to re-split here.
if not attn.is_cross_attention:
qkv = torch.cat((query, key, value), dim=-1)
split_size = qkv.shape[-1] // attn.heads // 3
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
kv = torch.cat((key, value), dim=-1)
split_size = kv.shape[-1] // attn.heads // 2
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
key, value = torch.split(kv, split_size, dim=-1)
head_dim = key.shape[-1]
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class FusedTripoSGAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
    projection layers. This is used in the TripoSG model. It applies a normalization layer and rotary embeddings to
    the query and key vectors.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from diffusers.models.embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)
        # NOTE: the pre-trained models split heads first and then split qkv, so the fused projection is re-split below.
if encoder_hidden_states is None:
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // attn.heads // 3
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states
)
query = attn.to_q(hidden_states)
kv = attn.to_kv(encoder_hidden_states)
split_size = kv.shape[-1] // attn.heads // 2
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
key, value = torch.split(kv, split_size, dim=-1)
head_dim = key.shape[-1]
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
# Modified from https://github.com/VAST-AI-Research/MIDI-3D/blob/main/midi/models/attention_processor.py#L264
class PartCrafterAttnProcessor:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the PartCrafter model. It applies a normalization layer and rotary embeddings to the query and key vectors.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
num_parts: Optional[Union[int, torch.Tensor]] = None,
) -> torch.Tensor:
from diffusers.models.embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
        # NOTE: the pre-trained models split heads first and then split qkv (or kv), i.e. .view(..., attn.heads, 3, dim)
        # instead of .view(..., 3, attn.heads, dim), so we need to re-split here.
if not attn.is_cross_attention:
qkv = torch.cat((query, key, value), dim=-1)
split_size = qkv.shape[-1] // attn.heads // 3
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
kv = torch.cat((key, value), dim=-1)
split_size = kv.shape[-1] // attn.heads // 2
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
key, value = torch.split(kv, split_size, dim=-1)
head_dim = key.shape[-1]
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
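        # The per-part tensors enter as [b * num_parts, heads, tokens, head_dim]; the rearranges
        # below fold the part axis into the token axis ([b, heads, num_parts * tokens, head_dim])
        # for self-attention, while cross-attention keeps only one copy of the conditioning
        # tokens, which are repeated num_parts times along the batch dimension.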
if isinstance(num_parts, torch.Tensor):
            # Training path: num_parts is a tensor of per-scene part counts; classifier-free guidance is not handled here.
idx = 0
hidden_states_list = []
for n_p in num_parts:
k = key[idx : idx + n_p]
v = value[idx : idx + n_p]
q = query[idx : idx + n_p]
idx += n_p
if k.shape[2] == q.shape[2]:
# Assuming self-attention
# Here 'b' is always 1
k = rearrange(
k, "(b ni) h nt c -> b h (ni nt) c", ni=n_p
) # [b, h, ni*nt, c]
v = rearrange(
v, "(b ni) h nt c -> b h (ni nt) c", ni=n_p
) # [b, h, ni*nt, c]
else:
# Assuming cross-attention
# Here 'b' is always 1
k = k[::n_p] # [b, h, nt, c]
v = v[::n_p] # [b, h, nt, c]
# Here 'b' is always 1
q = rearrange(
q, "(b ni) h nt c -> b h (ni nt) c", ni=n_p
) # [b, h, ni*nt, c]
# the output of sdp = (batch, num_heads, seq_len, head_dim)
h_s = F.scaled_dot_product_attention(
q, k, v,
dropout_p=0.0,
is_causal=False,
)
h_s = h_s.transpose(1, 2).reshape(
n_p, -1, attn.heads * head_dim
)
h_s = h_s.to(query.dtype)
hidden_states_list.append(h_s)
hidden_states = torch.cat(hidden_states_list, dim=0)
elif isinstance(num_parts, int):
            # Inference path: a single scene with num_parts parts; the batch may carry classifier-free guidance duplicates.
if key.shape[2] == query.shape[2]:
# Assuming self-attention
# Here we need 'b' when using classifier-free guidance
key = rearrange(
key, "(b ni) h nt c -> b h (ni nt) c", ni=num_parts
) # [b, h, ni*nt, c]
value = rearrange(
value, "(b ni) h nt c -> b h (ni nt) c", ni=num_parts
) # [b, h, ni*nt, c]
else:
# Assuming cross-attention
# Here we need 'b' when using classifier-free guidance
# Control signal is repeated ni times within each (b, ni)
# We select only the first instance per group
key = key[::num_parts] # [b, h, nt, c]
value = value[::num_parts] # [b, h, nt, c]
query = rearrange(
query, "(b ni) h nt c -> b h (ni nt) c", ni=num_parts
) # [b, h, ni*nt, c]
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(
query,
key,
value,
dropout_p=0.0,
is_causal=False,
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
else:
raise ValueError(
"num_parts must be a torch.Tensor or int, but got {}".format(type(num_parts))
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
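# Minimal wiring sketch (assumptions: `model` is a transformer whose attention layers are
# diffusers `Attention` modules; the names below are illustrative, not from this repo):
#
#   from diffusers.models.attention_processor import Attention
#
#   for module in model.modules():
#       if isinstance(module, Attention):
#           module.set_processor(PartCrafterAttnProcessor())
#
# The surrounding transformer block is then expected to pass `num_parts` (an int at inference,
# or a tensor of per-scene part counts during training) through to the processor call.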