import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from collections.abc import Iterable
from itertools import repeat
def _ntuple(n):
    def parse(x):
        if isinstance(x, Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
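
def _example_to_2tuple():
    # Illustrative usage sketch, not part of the original module: to_2tuple
    # broadcasts a scalar into a pair and passes non-string iterables through
    # unchanged, which is how patch/kernel sizes are typically normalized.
    assert to_2tuple(7) == (7, 7)
    assert to_2tuple((3, 5)) == (3, 5)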
def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
    assert isinstance(model, nn.Module)

    def set_attr(module):
        module.grad_checkpointing = True
        module.fp32_attention = use_fp32_attention
        module.grad_checkpointing_step = gc_step
    model.apply(set_attr)
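
def _example_set_grad_checkpoint():
    # Illustrative usage sketch, not part of the original module: flag every
    # submodule of a toy model for gradient checkpointing. The tiny
    # nn.Sequential stands in for a real transformer backbone.
    model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
    set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1)
    assert model[0].grad_checkpointing and model[0].grad_checkpointing_step == 1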
def auto_grad_checkpoint(module, *args, **kwargs):
    if getattr(module, 'grad_checkpointing', False):
        if isinstance(module, Iterable):
            gc_step = module[0].grad_checkpointing_step
            return checkpoint_sequential(module, gc_step, *args, **kwargs)
        else:
            return checkpoint(module, *args, **kwargs)
    return module(*args, **kwargs)
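
def _example_auto_grad_checkpoint():
    # Illustrative usage sketch, not part of the original module: once a block
    # is flagged via set_grad_checkpoint, auto_grad_checkpoint reruns its
    # forward during the backward pass instead of storing activations.
    block = nn.Linear(8, 8)
    set_grad_checkpoint(block)
    x = torch.randn(2, 8, requires_grad=True)
    y = auto_grad_checkpoint(block, x)  # checkpointed because the flag is set
    y.sum().backward()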
# Note: this local definition shadows the checkpoint_sequential imported from
# torch.utils.checkpoint above; unlike the stock version, it forwards extra
# positional args to every block in the sequence.
def checkpoint_sequential(functions, step, input, *args, **kwargs):
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input, *args)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    # the last chunk has to be non-volatile
    end = -1
    segment = len(functions) // step
    for start in range(0, step * (segment - 1), step):
        end = start + step - 1
        input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
    return run_function(end + 1, len(functions) - 1, functions)(input)
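
def _example_checkpoint_sequential():
    # Illustrative usage sketch, not part of the original module: run a stack
    # of blocks in checkpointed segments of `step` layers; intermediate
    # activations inside earlier segments are recomputed during backward,
    # while the final segment runs normally.
    blocks = nn.Sequential(*[nn.Linear(8, 8) for _ in range(4)])
    x = torch.randn(2, 8, requires_grad=True)
    out = checkpoint_sequential(blocks, 2, x)
    out.sum().backward()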
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]
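
def _example_get_rel_pos():
    # Illustrative usage sketch, not part of the original module: with
    # q_size == k_size == 4 the table needs 2 * 4 - 1 = 7 relative offsets, and
    # the lookup yields one embedding per (query, key) position pair, i.e.
    # shape (4, 4, C).
    rel_pos = torch.randn(7, 16)  # (L, C) with L == 2 * 4 - 1
    out = get_rel_pos(4, 4, rel_pos)
    assert out.shape == (4, 4, 16)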
def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn
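
def _example_add_decomposed_rel_pos():
    # Illustrative usage sketch, not part of the original module: add the
    # decomposed height/width relative-position bias to a raw attention map
    # for a 4x4 query/key grid with 16 channels.
    B, C, grid = 2, 16, (4, 4)
    q = torch.randn(B, grid[0] * grid[1], C)
    attn = torch.randn(B, grid[0] * grid[1], grid[0] * grid[1])
    rel_pos_h = torch.randn(2 * grid[0] - 1, C)
    rel_pos_w = torch.randn(2 * grid[1] - 1, C)
    attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, grid, grid)
    assert attn.shape == (B, grid[0] * grid[1], grid[0] * grid[1])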