import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from unik3d.utils.constants import VERBOSE
from unik3d.utils.geometric import downsample, erode
from unik3d.utils.misc import profile_method

from .utils import (FNS, REGRESSION_DICT, ind2sub, masked_mean,
                    masked_quantile, ssi, ssi_nd)


def sample_strong_edges(edges_img, quantile=0.95, reshape=8):
    # downsample the edge map and flatten it
    edges_img = F.interpolate(
        edges_img, scale_factor=1 / reshape, mode="bilinear", align_corners=False
    )
    edges_img_flat = edges_img.flatten(1)

    # find strong edges
    edges_mask = edges_img_flat > torch.quantile(
        edges_img_flat, quantile, dim=-1, keepdim=True
    )
    num_samples = edges_mask.sum(dim=-1)
    if (num_samples < 10).any():
        # sample random edges where num_samples < 10
        random = torch.rand_like(edges_img_flat[num_samples < 10, :]) > quantile
        edges_mask[num_samples < 10, :] = torch.logical_or(
            edges_mask[num_samples < 10, :], random
        )
        num_samples = edges_mask.sum(dim=-1)

    min_samples = num_samples.min()
    # compute the coordinates of the strong edges as (B, N, 2)
    edges_coords = torch.stack(
        [torch.nonzero(x, as_tuple=False)[:min_samples].squeeze() for x in edges_mask]
    )
    edges_coords = (
        torch.stack(ind2sub(edges_coords, edges_img.shape[-1]), dim=-1) * reshape
    )
    return edges_coords
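

# Illustrative usage sketch, not part of the library API: drive sample_strong_edges
# with a synthetic edge-strength map. Shapes and the quantile below are assumptions
# made only for this example.
def _demo_sample_strong_edges():
    B, H, W = 2, 256, 320
    edges_img = torch.rand(B, 1, H, W)  # synthetic edge strengths in [0, 1)
    # coordinates of the ~5% strongest responses, mapped back to full resolution
    coords = sample_strong_edges(edges_img, quantile=0.95, reshape=8)
    assert coords.shape[0] == B and coords.shape[-1] == 2  # (B, N, 2)
    return coords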


def extract_patches(tensor, sample_coords, patch_size: tuple[int, int] = (32, 32)):
    """
    Extracts patches around the specified edge coordinates, with zero padding.

    Parameters:
    - tensor: tensor to be gathered based on the sampled coordinates (B, 1, H, W).
    - sample_coords: batch of edge coordinates as a PyTorch tensor of shape (B, num_coords, 2).
    - patch_size: tuple (width, height) with the size of the patches.

    Returns:
    - Patches as a PyTorch tensor of shape (B, num_coords, patch_height * patch_width).
    """
    N, _, H, W = tensor.shape
    device = tensor.device
    dtype = tensor.dtype
    patch_width, patch_height = patch_size
    pad_width = patch_width // 2
    pad_height = patch_height // 2

    # zero-pad the tensor on both sides so that border patches stay in bounds
    tensor_padded = F.pad(
        tensor,
        (pad_width, pad_width, pad_height, pad_height),
        mode="constant",
        value=0.0,
    )

    # adjust the edge coordinates to account for the padding
    sample_coords_padded = sample_coords + torch.tensor(
        [pad_height, pad_width], dtype=dtype, device=device
    ).reshape(1, 1, 2)

    # calculate the center indices for the gather operation
    x_centers = sample_coords_padded[:, :, 1].int()
    y_centers = sample_coords_padded[:, :, 0].int()

    all_patches = []
    for tensor_i, x_centers_i, y_centers_i in zip(tensor_padded, x_centers, y_centers):
        patches = []
        for x_center, y_center in zip(x_centers_i, y_centers_i):
            y_start, y_end = y_center - pad_height, y_center + pad_height + 1
            x_start, x_end = x_center - pad_width, x_center + pad_width + 1
            patches.append(tensor_i[..., y_start:y_end, x_start:x_end])
        all_patches.append(torch.stack(patches, dim=0))
    return torch.stack(all_patches, dim=0).reshape(N, -1, patch_height * patch_width)
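

# Illustrative usage sketch, not part of the library API: gather local patches of a
# depth map around the sampled edge coordinates. An odd patch size is assumed here so
# that the centered (2 * pad + 1)-sized crops match patch_height * patch_width in the
# final reshape; all sizes are example values only.
def _demo_extract_patches():
    B, H, W = 2, 256, 320
    depth = torch.rand(B, 1, H, W)
    coords = sample_strong_edges(torch.rand(B, 1, H, W), quantile=0.95, reshape=8)
    patches = extract_patches(depth, coords, patch_size=(31, 31))
    return patches  # shape (B, num_coords, 31 * 31)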


class LocalSSI(nn.Module):
    def __init__(
        self,
        weight: float,
        output_fn: str = "sqrt",
        patch_size: tuple[int, int] = (32, 32),
        min_samples: int = 4,
        num_levels: int = 4,
        fn: str = "l1",
        rescale_fn: str = "ssi",
        input_fn: str = "linear",
        quantile: float = 0.1,
        gamma: float = 1.0,
        alpha: float = 1.0,
        relative: bool = False,
        eps: float = 1e-5,
    ):
        super(LocalSSI, self).__init__()
        self.name: str = self.__class__.__name__
        self.weight = weight
        self.output_fn = FNS[output_fn]
        self.input_fn = FNS[input_fn]
        self.fn = REGRESSION_DICT[fn]
        self.min_samples = min_samples
        self.eps = eps

        # num_levels log2-spaced intervals between the two patch-size bounds
        patch_logrange = np.linspace(
            start=np.log2(min(patch_size)),
            stop=np.log2(max(patch_size)),
            endpoint=True,
            num=num_levels + 1,
        )
        self.patch_logrange = [
            (x, y) for x, y in zip(patch_logrange[:-1], patch_logrange[1:])
        ]
        # resolve the rescaling function by name (ssi or ssi_nd imported from .utils)
        self.rescale_fn = eval(rescale_fn)
        self.quantile = quantile
        self.gamma = gamma
        self.alpha = alpha
        self.relative = relative
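
    # Worked example with hypothetical values: reading patch_size as a fraction of the
    # smaller image side (as the kernel-size computation in forward implies),
    # patch_size=(0.125, 0.5) and num_levels=2 give log2 bounds -3 and -1, so
    # patch_logrange = [(-3.0, -2.0), (-2.0, -1.0)]. At train time each level samples
    # one log-kernel uniformly from its interval; at eval time the midpoint is used.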

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor,
        quality: torch.Tensor = None,
        down_ratio: int = 1,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        mask = mask.bool()
        if down_ratio > 1:
            input = downsample(input, down_ratio)
            target = downsample(target, down_ratio)
            # downsample ignores 0s when taking the patch "min": if any pixel in the
            # patch is valid, mark the downsampled mask as valid there
            mask = downsample(mask.float(), down_ratio).bool()

        input = self.input_fn(input.float())
        target = self.input_fn(target.float())
        B, C, H, W = input.shape
        total_errors = []
        # save = random() < -0.001 and is_main_process()

        for ii, patch_logrange in enumerate(self.patch_logrange):
            log_kernel = (
                np.random.uniform(*patch_logrange)
                if self.training
                else np.mean(patch_logrange)
            )
            kernel_size = int(
                (2**log_kernel) * min(input.shape[-2:])
            )  # always smaller than the smaller spatial dim
            kernel_size = (kernel_size, kernel_size)
            stride = (int(kernel_size[0] * 0.9), int(kernel_size[1] * 0.9))
            # stride = kernel_size

            # unfold always leaves a remainder at the right/bottom; roll the image by a
            # non-positive offset to bring those regions back into the unfolding windows
            max_roll = (
                (W - kernel_size[1]) % stride[1],
                (H - kernel_size[0]) % stride[0],
            )
            roll_x, roll_y = np.random.randint(-max_roll[0], 1), np.random.randint(
                -max_roll[1], 1
            )
            input_fold = torch.roll(input, shifts=(roll_y, roll_x), dims=(2, 3))
            target_fold = torch.roll(target, shifts=(roll_y, roll_x), dims=(2, 3))
            mask_fold = torch.roll(mask.float(), shifts=(roll_y, roll_x), dims=(2, 3))

            # unfold into patches
            input_fold = F.unfold(
                input_fold, kernel_size=kernel_size, stride=stride
            ).permute(
                0, 2, 1
            )  # B N C*H_p*W_p
            target_fold = F.unfold(
                target_fold, kernel_size=kernel_size, stride=stride
            ).permute(0, 2, 1)
            mask_fold = (
                F.unfold(mask_fold, kernel_size=kernel_size, stride=stride)
                .bool()
                .permute(0, 2, 1)
            )

            # compute the error patchwise, average it within each patch, then over the
            # image, keeping only patches with enough valid samples
            input_fold, target_fold, _ = self.rescale_fn(
                input_fold, target_fold, mask_fold, dim=(-1,)
            )
            error = self.fn(
                input_fold - target_fold, alpha=self.alpha, gamma=self.gamma
            )

            # discard elements above the (1 - quantile) error quantile, with the
            # quantile scaled by the per-sample quality level
            if quality is not None:
                N_patches = mask_fold.shape[1]
                for quality_level in [1, 2]:
                    current_quality = quality == quality_level
                    if current_quality.sum() > 0:
                        error_qtl = error[current_quality].detach()
                        mask_qtl = error_qtl < masked_quantile(
                            error_qtl,
                            mask_fold[current_quality],
                            dims=[2],
                            q=1 - self.quantile * quality_level,
                        ).view(-1, N_patches, 1)
                        mask_fold[current_quality] = (
                            mask_fold[current_quality] & mask_qtl
                        )
            else:
                error_qtl = error.detach()
                mask_fold = mask_fold & (
                    error_qtl
                    < masked_quantile(
                        error_qtl, mask_fold, dims=[2], q=1 - self.quantile
                    ).view(B, -1, 1)
                )

            valid_patches = mask_fold.sum(dim=-1) >= self.min_samples
            error_mean_patch = masked_mean(error, mask_fold, dim=(-1,)).squeeze(-1)
            error_mean_image = self.output_fn(error_mean_patch.clamp(min=self.eps))
            error_mean_image = masked_mean(
                error_mean_image, mask=valid_patches, dim=(-1,)
            )
            total_errors.append(error_mean_image.squeeze(-1))

        # global term over the whole image
        input_rescale = input.reshape(B, C, -1).clone()
        target_rescale = target.reshape(B, C, -1)
        mask = mask.reshape(B, 1, -1).clone()
        input, target, _ = self.rescale_fn(
            input_rescale,
            target_rescale,
            mask,
            dim=(-1,),
            target_info=target_rescale.norm(dim=1, keepdim=True),
            input_info=input_rescale.norm(dim=1, keepdim=True),
        )
        error = input - target
        error = error.norm(dim=1) if C > 1 else error.squeeze(1)
        if self.relative:
            error = error * torch.log(
                1.0 + 10.0 / target_rescale.norm(dim=1).clip(min=0.01)
            )
        error = self.fn(error, alpha=self.alpha, gamma=self.gamma).squeeze(1)

        mask = mask.squeeze(1)
        valid_patches = mask.sum(dim=-1) >= 3 * self.min_samples  # at least 3 * min_samples valid pixels per image
        if quality is not None:
            for quality_level in [1, 2]:
                current_quality = quality == quality_level
                if current_quality.sum() > 0:
                    error_qtl = error[current_quality].detach()
                    mask_qtl = error_qtl < masked_quantile(
                        error_qtl,
                        mask[current_quality],
                        dims=[1],
                        q=1 - self.quantile * quality_level,
                    ).view(-1, 1)
                    mask[current_quality] = mask[current_quality] & mask_qtl
        else:
            error_qtl = error.detach()
            mask = mask & (
                error_qtl
                < masked_quantile(error_qtl, mask, dims=[1], q=1 - self.quantile).view(
                    -1, 1
                )
            )

        error_mean_image = masked_mean(error, mask, dim=(-1,)).squeeze(-1)
        error_mean_image = (
            self.output_fn(error_mean_image.clamp(min=self.eps)) * valid_patches.float()
        )
        total_errors.append(error_mean_image)

        errors = torch.stack(total_errors).mean(dim=0)
        return errors

    @classmethod
    def build(cls, config):
        obj = cls(
            weight=config["weight"],
            patch_size=config["patch_size"],
            output_fn=config["output_fn"],
            min_samples=config["min_samples"],
            num_levels=config["num_levels"],
            input_fn=config["input_fn"],
            quantile=config["quantile"],
            gamma=config["gamma"],
            alpha=config["alpha"],
            rescale_fn=config["rescale_fn"],
            fn=config["fn"],
            relative=config["relative"],
        )
        return obj
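

# Illustrative sketch, not part of the library: build the loss from a hypothetical
# config and run it on random depth maps. The values below are assumptions for this
# example (patch_size is read as a fraction of the smaller image side, as forward's
# kernel-size computation implies), not the configuration used to train UniK3D.
if __name__ == "__main__":
    config = {
        "weight": 1.0,
        "patch_size": (0.125, 0.5),
        "output_fn": "sqrt",
        "min_samples": 4,
        "num_levels": 2,
        "input_fn": "linear",
        "quantile": 0.1,
        "gamma": 1.0,
        "alpha": 1.0,
        "rescale_fn": "ssi",
        "fn": "l1",
        "relative": False,
    }
    loss = LocalSSI.build(config)
    pred = torch.rand(2, 1, 224, 224) + 0.1  # predicted depth
    gt = torch.rand(2, 1, 224, 224) + 0.1  # ground-truth depth
    valid = torch.ones(2, 1, 224, 224, dtype=torch.bool)  # validity mask
    print(loss(pred, gt, valid))  # per-image loss, shape (2,)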