import torch
import torch.nn as nn
import torch.nn.functional as F
# NOTE: the ``checkpoint_sequential`` imported here is shadowed by the
# module-level definition further down; every call in this module (including
# the one in ``auto_grad_checkpoint``) resolves to the local version, whose
# ``step`` argument is a chunk size rather than torch's segment count.
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from collections.abc import Iterable
from itertools import repeat


def _ntuple(n):
    """Return a parser that broadcasts a scalar to an ``n``-tuple.

    Non-string iterables are returned unchanged (neither copied nor
    length-checked); any other value ``x`` becomes ``(x,) * n``.
    """
    def parse(x):
        if isinstance(x, Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)


def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
    """Turn on gradient-checkpointing flags for ``model`` and every submodule.

    ``model.apply`` visits ``model`` itself and all of its submodules, setting
    ``grad_checkpointing = True``, ``fp32_attention`` and
    ``grad_checkpointing_step`` on each. ``auto_grad_checkpoint`` reads the
    first and last of these; how ``fp32_attention`` is consumed is not visible
    in this module.

    Args:
        model (nn.Module): module tree to flag.
        use_fp32_attention (bool): value stored as ``fp32_attention``.
        gc_step (int): chunk size stored as ``grad_checkpointing_step``; used
            by ``checkpoint_sequential`` when checkpointing containers.
    """
    assert isinstance(model, nn.Module)

    def set_attr(module):
        module.grad_checkpointing = True
        module.fp32_attention = use_fp32_attention
        module.grad_checkpointing_step = gc_step

    model.apply(set_attr)


def auto_grad_checkpoint(module, *args, **kwargs):
    """Call ``module`` with activation checkpointing if it has been enabled on it.

    - ``module.grad_checkpointing`` missing or falsy: plain ``module(*args, **kwargs)``.
    - ``module`` iterable (e.g. ``nn.Sequential``/``nn.ModuleList``): chunked
      checkpointing through this module's ``checkpoint_sequential``, using the
      chunk size stored on the first child (``module[0].grad_checkpointing_step``).
    - otherwise: ``torch.utils.checkpoint.checkpoint`` with torch's default
      variant (``use_reentrant`` is not passed here).
    """
    if getattr(module, 'grad_checkpointing', False):
        if isinstance(module, Iterable):
            gc_step = module[0].grad_checkpointing_step
            return checkpoint_sequential(module, gc_step, *args, **kwargs)
        else:
            return checkpoint(module, *args, **kwargs)
    return module(*args, **kwargs)


def checkpoint_sequential(functions, step, input, *args, preserve_rng_state=True, **kwargs):
    """Apply ``functions`` in order, checkpointing fixed-size chunks of them.

    Chunks of ``step`` consecutive functions are run under ``checkpoint``,
    except for the tail: the last full chunk plus any remainder always runs
    un-checkpointed so its output keeps a normal autograd graph. Consequently
    nothing is checkpointed when ``len(functions) < 2 * step``.

    Args:
        functions: ``nn.Sequential`` (its children are used) or an indexable
            sequence of callables, each invoked as ``f(x, *args)``.
        step (int): number of functions per checkpointed chunk; must be >= 1.
        input: value fed to the first function.
        *args: extra positional arguments passed unchanged to every function.
        preserve_rng_state (bool): forwarded to ``checkpoint``.

    Returns:
        The output of the last function (or ``input`` if ``functions`` is empty).

    Raises:
        ValueError: on any other keyword argument, or if ``step < 1``
            (previously 0 raised ZeroDivisionError and negatives silently
            disabled checkpointing).
    """
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
    if step < 1:
        raise ValueError(f"step must be a positive integer, got {step!r}")

    def run_function(start, end, functions):
        # Closure applying functions[start..end] (inclusive) with the shared *args.
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input, *args)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    # The last chunk has to be non-volatile, so the loop stops one full chunk
    # early; ``end`` stays -1 when no chunk is checkpointed.
    end = -1
    segment = len(functions) // step
    for start in range(0, step * (segment - 1), step):
        end = start + step - 1
        input = checkpoint(
            run_function(start, end, functions), input, preserve_rng_state=preserve_rng_state
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)


def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions, of
        shape (q_size, k_size, C).
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Linearly resample the L axis to max_rel_dist entries:
        # (L, C) -> (1, C, L) -> interpolate -> (max_rel_dist, C).
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    # Index tensors are built on rel_pos's device to avoid an implicit
    # host->device copy when indexing below.
    device = rel_pos.device
    q_coords = torch.arange(q_size, device=device)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size, device=device)[None, :] * max(q_size / k_size, 1.0)
    # Shift so the most negative offset maps to index 0.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950

    Args:
        attn (Tensor): attention map; must be viewable as
            (B, q_h, q_w, k_h, k_w) where B is q's leading dimension.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings,
            shaped (B, q_h * q_w, k_h * k_w).
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)  # (q_h, k_h, C)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)  # (q_w, k_w, C)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    # Broadcast the height term over k_w and the width term over k_h.
    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn