Spaces:
Paused
Paused
File size: 6,989 Bytes
96257b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
def modulate(x, shift, scale):
    """Apply the affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are combined with ``x`` using ordinary
    broadcasting; no dimensions are inserted here.

    NOTE(review): an earlier variant unsqueezed ``shift``/``scale`` at dim 1
    before applying them; callers presumably now pass tensors that already
    broadcast against ``x`` -- TODO confirm.
    """
    gain = 1 + scale
    return x * gain + shift
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
    """
    Masked mean-pool of tokens along the sequence dimension.

    NOTE: We assume x does not require gradients.

    Args:
        x: (B, L, D) tensor of tokens.
        mask: (B, L) boolean tensor indicating which tokens are not padding.
        keepdim: If True, keep the (size-1) sequence dimension in the output.

    Returns:
        pooled: (B, D) tensor of pooled tokens, or (B, 1, D) when ``keepdim``
            is True. Rows whose mask is entirely False pool to zeros.
    """
    assert mask.ndim == 2, f"Expected a 2-D (B, L) mask, got shape {tuple(mask.shape)}."
    assert x.size(1) == mask.size(1), (
        f"Expected mask to have same length as tokens, "
        f"got mask {tuple(mask.shape)} vs tokens {tuple(x.shape)}."
    )
    assert x.size(0) == mask.size(0), (
        f"Expected mask to have same batch size as tokens, "
        f"got mask {tuple(mask.shape)} vs tokens {tuple(x.shape)}."
    )
    # Per-token weights 1/count for valid tokens and 0 for padding. Clamping the
    # count to >= 1 avoids division by zero on fully padded rows, which then
    # pool to zeros instead of NaN.
    weights = mask[:, :, None].to(dtype=x.dtype)
    weights = weights / weights.sum(dim=1, keepdim=True).clamp(min=1)
    return (x * weights).sum(dim=1, keepdim=keepdim)
class AttentionPool(nn.Module):
    """
    Pool a padded token sequence into a single vector with multi-head attention.

    The masked mean of the input tokens is prepended to the sequence and is the
    only position used to build a query; keys and values are computed for the
    mean token and every input token, with padding positions masked out.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        output_dim: Optional[int] = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            embed_dim (int): Dimensionality of input tokens.
            num_heads (int): Number of attention heads. Must be positive and divide embed_dim.
            output_dim (int, optional): Dimensionality of output tokens. Defaults to embed_dim
                (the fallback uses ``or``, so 0 also falls back to embed_dim).
            device (torch.device, optional): Device on which the projection layers are created.

        Raises:
            ValueError: If num_heads is not positive or does not divide embed_dim.
        """
        super().__init__()
        # forward() splits the channel dim into (num_heads, embed_dim // num_heads),
        # which only works when the split is exact; fail early with a clear message.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive num_heads ({num_heads})."
            )
        self.num_heads = num_heads
        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
        self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
        self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)

    def forward(self, x, mask):
        """
        Args:
            x (torch.Tensor): (B, L, D) tensor of input tokens.
            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.

        NOTE: We assume x does not require gradients.

        Returns:
            x (torch.Tensor): (B, output_dim) tensor of pooled tokens
                (output_dim defaults to D).
        """
        D = x.size(2)
        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
        # True means "may attend". The prepended mean-token slot is always True,
        # so no row is ever fully masked.
        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).

        # Average non-padding token features. These will be used as the query.
        x_pool = pool_tokens(x, mask, keepdim=True)  # (B, 1, D)

        # Concat pooled features to input sequence.
        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)

        # Compute queries, keys, values. Only the mean token is used to create a query.
        kv = self.to_kv(x)  # (B, L+1, 2 * D)
        q = self.to_q(x[:, 0])  # (B, D)

        # Extract heads.
        head_dim = D // self.num_heads
        kv = kv.unflatten(2, (2, self.num_heads, head_dim)).contiguous()  # (B, 1+L, 2, H, head_dim)
        kv = kv.transpose(1, 3).contiguous()  # (B, H, 2, 1+L, head_dim)
        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
        q = q.unflatten(1, (self.num_heads, head_dim)).contiguous()  # (B, H, head_dim)
        q = q.unsqueeze(2)  # (B, H, 1, head_dim)

        # Compute attention.
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)  # (B, H, 1, head_dim)

        # Concatenate heads and run output.
        x = x.squeeze(2).flatten(1, 2).contiguous()  # (B, D = H * head_dim)
        x = self.to_out(x)
        return x
class PadSplitXY(torch.nn.Module):
    """
    Scatter packed tokens back into a zero-padded (B, N + L, D) layout and
    split the result into visual and text tokens along the sequence length.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        xy: torch.Tensor,
        indices: torch.Tensor,
        B: int,
        N: int,
        L: int,
        dtype: torch.dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim).
            indices: Valid token indices out of unpacked tensor. Shape: (total,)
            B: Batch size.
            N: Number of visual tokens per sample.
            L: Number of text tokens per sample.
            dtype: dtype of the returned tensors; ``xy`` is cast to it if needed.

        Returns:
            x: Visual tokens. Shape: (B, N, num_heads * head_dim).
            y: Text tokens. Shape: (B, L, num_heads * head_dim).
            Positions not listed in ``indices`` are zero.
        """
        assert indices.ndim == 1, f"Expected 1-D indices, got shape {tuple(indices.shape)}."
        D = xy.size(1)

        # Pad sequences to (B, N + L, dim).
        output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
        index = indices.unsqueeze(1).expand(-1, D)  # (total,) -> (total, num_heads * head_dim)
        # scatter_ requires src and destination dtypes to match; casting makes
        # the `dtype` argument usable when it differs from xy.dtype (no-op otherwise).
        output.scatter_(0, index, xy.to(dtype))
        padded = output.view(B, N + L, D)

        # Split visual and text tokens along the sequence length.
        return torch.tensor_split(padded, (N,), dim=1)
# def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
# return PadSplitXY.apply(xy, indices, B, N, L, dtype)
def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
    """Functional wrapper around ``PadSplitXY``.

    Scatters the packed ``xy`` tokens into a zero-padded batch and returns the
    ``(visual, text)`` split. Calls ``forward`` directly on a fresh module.
    """
    return PadSplitXY().forward(xy, indices, B, N, L, dtype)
class UnifyStreams(torch.nn.Module):
    """Unify visual and text streams into a single packed q/k/v tensor."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: torch.Tensor,
        v_x: torch.Tensor,
        q_y: torch.Tensor,
        k_y: torch.Tensor,
        v_y: torch.Tensor,
        indices: torch.Tensor,
    ):
        """
        Concatenate visual (``*_x``) and text (``*_y``) tokens per sample, stack
        q/k/v, and keep only the rows selected by ``indices``.

        Args:
            q_x, k_x, v_x: (B, N, num_heads, head_dim) visual tensors.
            q_y, k_y, v_y: (B, L, num_heads, head_dim) text tensors.
            indices: (total <= B * (N + L)) row indices into the flattened
                (B * (N + L)) sequence.

        Returns:
            qkv: (total <= B * (N + L), 3, num_heads, head_dim)
        """
        batch, n_visual, n_heads, d_head = q_x.size()
        n_text = q_y.size(1)
        width = n_heads * d_head

        # One (B, N + L, num_heads, head_dim) tensor per projection, in q, k, v order.
        streams = [
            torch.cat([visual, text], dim=1)
            for visual, text in ((q_x, q_y), (k_x, k_y), (v_x, v_y))
        ]
        packed = torch.stack(streams, dim=2).view(batch * (n_visual + n_text), 3, width)

        # Select the requested rows; the index is broadcast across the q/k/v and channel dims.
        gather_index = indices[:, None, None].expand(-1, 3, width)
        selected = torch.gather(packed, 0, gather_index)  # (total, 3, num_heads * head_dim)
        return selected.unflatten(2, (n_heads, d_head)).contiguous()
# def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
# return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices)
def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
    """Functional wrapper around ``UnifyStreams``.

    Packs visual (``*_x``) and text (``*_y``) q/k/v into a single
    ``(total, 3, num_heads, head_dim)`` tensor. Calls ``forward`` directly on a
    fresh module.
    """
    return UnifyStreams().forward(q_x, k_x, v_x, q_y, k_y, v_y, indices)