|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
import torch |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
from .utils import split_chessboard, merge_chessboard, batched_forward |
|
|
|
def forward(model, input, scales=None, img_sizes=None, max_split_size=None, resize_output_to_idx=0, num_prefix_token=0,
            output_shape='bnc', split_forward=False):
    """Run ``model`` on several resized copies of ``input`` and concatenate the features.

    For every target size the input is bicubically resized, passed through
    ``split_chessboard``, fed to the model, passed back through
    ``merge_chessboard``, resized (``mode='area'``) to the spatial size of the
    scale selected by ``resize_output_to_idx``, and finally concatenated with
    the other scales along the channel dimension.

    Args:
        model: Callable applied to each prepared input.
        input: 4-D tensor (BxCxHxW) with H == W.
        scales: Multipliers of the input side length; used only when
            ``img_sizes`` is falsy.
        img_sizes: Explicit side lengths, one per scale. Takes precedence
            over ``scales``.
        max_split_size: Upper bound used to derive the number of splits per
            scale (``ceil(size / max_split_size)``). Falls back to the input
            side length when falsy.
        resize_output_to_idx: Index of the scale whose feature-map size all
            other scales are resized to.
        num_prefix_token: Number of leading tokens (dim 1) that are removed
            before the spatial processing and re-attached at the end.
        output_shape: ``'bnc'`` or ``'bchw'``; layout of the model output and
            of the returned tensor.
        split_forward: When true, call ``batched_forward(model, x, b)``
            instead of ``model(x)``.

    Returns:
        The concatenated multi-scale features, in the layout given by
        ``output_shape``. With prefix tokens, the per-scale prefix tokens are
        concatenated along the last dim and prepended along dim 1.

    Raises:
        AssertionError: If the input is not 4-D and square, ``output_shape``
            is unknown, prefix tokens are requested for ``'bchw'``, or neither
            ``scales`` nor ``img_sizes`` is given.
    """
    assert input.dim() == 4, "Input image must be in the shape of BxCxHxW."
    assert input.shape[2] == input.shape[3], "Currently only square images are supported."
    assert output_shape in ['bnc', 'bchw'], "Output shape should be either BxNxC (e.g., ViT) or BxCxHxW (e.g., ConvNet)."
    assert output_shape == 'bnc' or num_prefix_token == 0, "For ConvNet there shouldn't be any prefix token."

    b, _, input_size, _ = input.shape
    token_layout = output_shape == 'bnc'
    has_prefix = num_prefix_token > 0

    # Side length of the image at each scale.
    assert scales is not None or img_sizes is not None, "Please assign either scales or img_sizes."
    img_sizes = img_sizes or [int(input_size * scale) for scale in scales]

    # Number of splits per scale, bounded by the largest allowed split size.
    max_split_size = max_split_size or input_size
    num_splits = [math.ceil(size / max_split_size) for size in img_sizes]

    # Build the per-scale model inputs. Interpolation is done in float32 and
    # cast back to the caller's dtype afterwards.
    # NOTE(review): split_chessboard presumably tiles the image into
    # num_split x num_split crops stacked on the batch dim -- confirm in utils.
    input_multiscale = [
        split_chessboard(
            F.interpolate(input.to(torch.float32), size=size, mode='bicubic').to(input.dtype),
            num_split=num_split,
        )
        for size, num_split in zip(img_sizes, num_splits)
    ]

    # Forward pass for every scale.
    if split_forward:
        outs_multiscale = [batched_forward(model, x, b) for x in input_multiscale]
    else:
        outs_multiscale = [model(x) for x in input_multiscale]

    # Separate the prefix tokens from the remaining tokens.
    if has_prefix:
        outs_prefix_multiscale = [feat[:, :num_prefix_token] for feat in outs_multiscale]
        outs_multiscale = [feat[:, num_prefix_token:] for feat in outs_multiscale]

    # Token sequences are reshaped to a square feature map.
    if token_layout:
        feature_maps = []
        for tokens in outs_multiscale:
            side = int(tokens.shape[1] ** 0.5)
            feature_maps.append(rearrange(tokens, 'b (h w) c -> b c h w', h=side, w=side))
        outs_multiscale = feature_maps

    # Merge the splits of each scale separately.
    outs_multiscale = [
        merge_chessboard(feat, num_split=num_split)
        for num_split, feat in zip(num_splits, outs_multiscale)
    ]

    # Bring every scale to the reference size and concatenate along channels.
    output_size = outs_multiscale[resize_output_to_idx].shape[-2]
    resized = []
    for feat in outs_multiscale:
        scaled = F.interpolate(feat.to(torch.float32), size=output_size, mode='area')
        resized.append(scaled.to(feat.dtype))
    out = torch.cat(resized, dim=1)

    if token_layout:
        out = rearrange(out, 'b c h w -> b (h w) c')

    if has_prefix:
        # Average the prefix tokens over the chunks of size b along dim 0
        # for each scale, then join the scales along the last dim.
        averaged = [
            torch.stack(prefix.split(b, dim=0), dim=0).mean(dim=0)
            for prefix in outs_prefix_multiscale
        ]
        out_prefix_multiscale = torch.cat(averaged, dim=-1)
        out = torch.cat([out_prefix_multiscale, out], dim=1)

    return out
|
|