import os
from typing import Union, Optional

import torch
import torch.nn as nn

from transformers.image_processing_utils import BaseImageProcessor


class SAFEReducerBlock(nn.Module):
    """
    Reduces the width and height of a feature map by half. It is designed to be
    iterative, so it is run multiple times to reduce an image to a desired
    dimension while carrying a shrunken residual along for the ride. This is
    done to preserve information.
    """

    def __init__(self, channels=512):
        super(SAFEReducerBlock, self).__init__()
        self.channels = channels
        activation = nn.GELU
        self.reducer = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            activation(),
            nn.BatchNorm2d(channels),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            activation(),
            nn.BatchNorm2d(channels),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # pool the input directly so it can be added back as a residual
        res = self.residual_shrink(x)
        reduced = self.reducer(x)
        return reduced + res
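

# Minimal usage sketch (the tensor shapes below are illustrative assumptions,
# not from the original source): one pass through the block halves both
# spatial dimensions while keeping the channel count fixed.
#   block = SAFEReducerBlock(channels=512)
#   feats = torch.randn(2, 512, 32, 32)
#   block(feats).shape  # torch.Size([2, 512, 16, 16])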


class SizeAgnosticFeatureEncoder(nn.Module):
    def __init__(
            self,
            in_channels=3,
            num_tokens=8,
            num_vectors=768,
            reducer_channels=512,
            channels=2048,
            downscale_factor: int = 8,
    ):
        super(SizeAgnosticFeatureEncoder, self).__init__()
        self.num_tokens = num_tokens
        self.num_vectors = num_vectors
        self.channels = channels
        self.reducer_channels = reducer_channels
        self.gradient_checkpointing = False

        # input is a minimum of (bs, 3, 256, 256)
        subpixel_channels = in_channels * downscale_factor ** 2
        # PixelUnshuffle(8):  (bs, 3, 256, 256) -> (bs, 192, 32, 32)
        # PixelUnshuffle(16): (bs, 3, 256, 256) -> (bs, 768, 16, 16)
        self.unshuffle = nn.PixelUnshuffle(downscale_factor)  # (bs, 3, 256, 256) -> (bs, 192, 32, 32)
        self.conv_in = nn.Conv2d(subpixel_channels, reducer_channels, kernel_size=3, padding=1)  # (bs, 192, 32, 32) -> (bs, 512, 32, 32)
        # run as many times as needed to get to a minimum feature size of 8 on the smallest dimension
        self.reducer = SAFEReducerBlock(reducer_channels)  # applied repeatedly: (bs, 512, 32, 32) -> (bs, 512, 8, 8)
        self.reduced_out = nn.Conv2d(
            reducer_channels, self.channels, kernel_size=3, padding=1
        )  # (bs, 512, 8, 8) -> (bs, 2048, 8, 8)

        self.block1 = SAFEReducerBlock(self.channels)  # (bs, 2048, 8, 8) -> (bs, 2048, 4, 4)
        self.block2 = SAFEReducerBlock(self.channels)  # (bs, 2048, 4, 4) -> (bs, 2048, 2, 2)

        # reduce mean over dims 2 and 3
        self.adaptive_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        # (bs, 2048)

        # linear layer to (bs, self.num_vectors * self.num_tokens)
        self.fc1 = nn.Linear(self.channels, self.num_vectors * self.num_tokens)
        # (bs, self.num_vectors * self.num_tokens) = (bs, 8 * 768) = (bs, 6144)

    def forward(self, x):
        x = self.unshuffle(x)
        x = self.conv_in(x)
        # reduce until we get as close to 8x8 as possible without going under
        while True:
            x = self.reducer(x)
            if x.shape[2] // 2 < 8 or x.shape[3] // 2 < 8:
                break
        x = self.reduced_out(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.adaptive_pool(x)
        x = self.fc1(x)
        # reshape to (bs, num_tokens, num_vectors)
        x = x.view(-1, self.num_tokens, self.num_vectors)
        return x
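

# Minimal sketch of the size-agnostic behavior (input sizes are illustrative
# assumptions, not from the original source): the reducer loop absorbs
# different input resolutions, so the token output shape is the same either
# way.
#   enc = SizeAgnosticFeatureEncoder()
#   enc(torch.randn(1, 3, 256, 256)).shape  # torch.Size([1, 8, 768])
#   enc(torch.randn(1, 3, 512, 384)).shape  # torch.Size([1, 8, 768])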


class SAFEIPReturn:
    # lightweight container mirroring the `pixel_values` attribute of
    # transformers image-processor outputs
    def __init__(self, pixel_values):
        self.pixel_values = pixel_values


class SAFEImageProcessor(BaseImageProcessor):
    def __init__(
            self,
            max_size=1024,
            min_size=256,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.max_size = max_size
        self.min_size = min_size

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: Union[str, os.PathLike],
            cache_dir: Optional[Union[str, os.PathLike]] = None,
            force_download: bool = False,
            local_files_only: bool = False,
            token: Optional[Union[str, bool]] = None,
            revision: str = "main",
            **kwargs,
    ):
        # no pretrained state to load; just build a fresh processor
        return cls(**kwargs)

    def __call__(
            self,
            images,
            **kwargs
    ):
        # TODO allow for random resizing
        # images come in the 0-1 range
        # if any side is smaller than min_size, resize up to min_size
        # if any side is larger than max_size, resize down to max_size
        # allow a small tolerance of 0.3 outside the 0-1 range before raising
        if images.min() < -0.3 or images.max() > 1.3:
            raise ValueError(
                "images fed into SAFEImageProcessor must have values between 0 and 1. Got min: {}, max: {}".format(
                    images.min(), images.max()
                ))
        # make sure we have (bs, 3, h, w)
        while len(images.shape) < 4:
            images = images.unsqueeze(0)
        # expand to 3 channels if we only have 1 channel
        if images.shape[1] == 1:
            images = torch.cat([images, images, images], dim=1)
        width = images.shape[3]
        height = images.shape[2]
        if width < self.min_size or height < self.min_size:
            # scale up so that the smallest side is min_size
            if width < height:
                new_width = self.min_size
                new_height = int(height * (self.min_size / width))
            else:
                new_height = self.min_size
                new_width = int(width * (self.min_size / height))
            images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
                                               align_corners=False)
        elif width > self.max_size or height > self.max_size:
            # scale down so that the largest side is max_size, but do not shrink the other side below min_size
            if width > height:
                new_width = self.max_size
                new_height = int(height * (self.max_size / width))
            else:
                new_height = self.max_size
                new_width = int(width * (self.max_size / height))
            if new_width < self.min_size:
                new_width = self.min_size
                new_height = int(height * (self.min_size / width))
            if new_height < self.min_size:
                new_height = self.min_size
                new_width = int(width * (self.min_size / height))
            images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
                                               align_corners=False)
        # if either side is not divisible by 16, mirror-pad to make it so
        if images.shape[2] % 16 != 0:
            pad = 16 - (images.shape[2] % 16)
            pad1 = pad // 2
            pad2 = pad - pad1
            images = nn.functional.pad(images, (0, 0, pad1, pad2), mode='reflect')
        if images.shape[3] % 16 != 0:
            pad = 16 - (images.shape[3] % 16)
            pad1 = pad // 2
            pad2 = pad - pad1
            images = nn.functional.pad(images, (pad1, pad2, 0, 0), mode='reflect')
        return SAFEIPReturn(images)
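

# Minimal sketch of the processor's behavior (the example shape is an
# assumption, not from the original source): a small image is scaled up so its
# shortest side hits min_size, keeping aspect ratio, then mirror-padded so
# both sides are divisible by 16.
#   processor = SAFEImageProcessor(max_size=1024, min_size=256)
#   out = processor(torch.rand(1, 3, 100, 200))
#   out.pixel_values.shape  # torch.Size([1, 3, 256, 512])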


class SAFEVMConfig:
    def __init__(
            self,
            in_channels=3,
            num_tokens=8,
            num_vectors=768,
            reducer_channels=512,
            channels=2048,
            downscale_factor: int = 8,
            **kwargs
    ):
        self.in_channels = in_channels
        self.num_tokens = num_tokens
        self.num_vectors = num_vectors
        self.reducer_channels = reducer_channels
        self.channels = channels
        self.downscale_factor = downscale_factor
        self.image_size = 224
        self.hidden_size = num_vectors
        self.projection_dim = num_vectors


class SAFEVMReturn:
    def __init__(self, output):
        self.output = output
        # TODO actually compute hidden states. This is just for code compatibility for now
        self.hidden_states = [output for _ in range(13)]


class SAFEVisionModel(SizeAgnosticFeatureEncoder):
    def __init__(self, **kwargs):
        self.config = SAFEVMConfig(**kwargs)
        self.image_size = None
        # pass only the args the encoder knows about, so extra kwargs absorbed
        # by SAFEVMConfig do not crash the encoder's __init__
        super(SAFEVisionModel, self).__init__(
            in_channels=self.config.in_channels,
            num_tokens=self.config.num_tokens,
            num_vectors=self.config.num_vectors,
            reducer_channels=self.config.reducer_channels,
            channels=self.config.channels,
            downscale_factor=self.config.downscale_factor,
        )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # no pretrained weights to load; just build a fresh model
        return cls(**kwargs)

    def forward(self, x, **kwargs):
        return SAFEVMReturn(super().forward(x))
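

# End-to-end smoke test (a minimal sketch, not part of the original source):
# wires the processor and the vision model together on a random image. The
# `torch.rand` input and the shapes shown are illustrative assumptions.
if __name__ == "__main__":
    processor = SAFEImageProcessor()
    model = SAFEVisionModel()
    model.eval()
    image = torch.rand(1, 3, 300, 500)  # values in the expected 0-1 range
    pixel_values = processor(image).pixel_values
    with torch.no_grad():
        out = model(pixel_values)
    print(pixel_values.shape)  # torch.Size([1, 3, 304, 512]) after padding
    print(out.output.shape)    # torch.Size([1, 8, 768])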