Luigi Piccinelli
init demo
1ea89dd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from unik3d.utils.constants import VERBOSE
from unik3d.utils.geometric import downsample, erode
from unik3d.utils.misc import profile_method
from .utils import (FNS, REGRESSION_DICT, ind2sub, masked_mean,
masked_quantile, ssi, ssi_nd)
def sample_strong_edges(edges_img, quantile=0.95, reshape=8, min_edges: int = 10):
    """Sample pixel coordinates of the strongest edges in each image of a batch.

    The edge map is downsampled by ``reshape``. Pixels whose value is strictly
    above the per-image ``quantile`` count as strong edges. Images with fewer
    than ``min_edges`` strong edges are topped up with random pixels, each kept
    with probability ``1 - quantile``. Every image then contributes the same
    number of coordinates: the minimum count over the batch, taking the first
    hits in flattened order. This keeps the result stackable.

    Args:
        edges_img: edge-strength map, presumably (B, 1, H, W) — TODO confirm.
            The flat index is converted back using only the last (width)
            dimension, so a single channel is assumed.
        quantile: per-image quantile above which a pixel is a strong edge.
        reshape: downsampling factor. Returned coordinates are multiplied by
            it to map back to the input resolution.
        min_edges: images with fewer strong edges than this receive random
            top-up pixels (default 10, matching the previous hard-coded value).

    Returns:
        Tensor of shape (B, N, 2), holding the two components produced by
        ``ind2sub`` (presumably (row, col)), scaled by ``reshape``.
    """
    # Downsample and flatten per image: (B, C*h*w).
    edges_img = F.interpolate(
        edges_img, scale_factor=1 / reshape, mode="bilinear", align_corners=False
    )
    edges_img_flat = edges_img.flatten(1)

    # Strong edges: strictly above the per-image quantile.
    edges_mask = edges_img_flat > torch.quantile(
        edges_img_flat, quantile, dim=-1, keepdim=True
    )

    # Top up images that have too few strong edges with random pixels.
    sparse = edges_mask.sum(dim=-1) < min_edges
    if sparse.any():
        random_mask = torch.rand_like(edges_img_flat[sparse]) > quantile
        edges_mask[sparse] = torch.logical_or(edges_mask[sparse], random_mask)

    # Same number of samples per image so the per-image results can be stacked.
    min_samples = int(edges_mask.sum(dim=-1).min())
    # as_tuple=True yields 1-D index tensors, so each row keeps shape
    # (min_samples,) even when min_samples == 1. A bare `.squeeze()` would
    # collapse that case to a 0-d tensor.
    edges_coords = torch.stack(
        [torch.nonzero(row, as_tuple=True)[0][:min_samples] for row in edges_mask]
    )
    edges_coords = (
        torch.stack(ind2sub(edges_coords, edges_img.shape[-1]), dim=-1) * reshape
    )
    return edges_coords
@torch.jit.script
def extract_patches(
    tensor: torch.Tensor,
    sample_coords: torch.Tensor,
    patch_size: tuple[int, int] = (32, 32),
):
    """
    Extract fixed-size patches around the given coordinates, with zero padding.

    Parameters:
    - tensor: tensor to gather patches from, shape (B, C, H, W). Callers are
      documented as passing C == 1.
    - sample_coords: (B, num_coords, 2) coordinates ordered (y, x), in pixels
      of `tensor`. Values are truncated to integers. Coordinates must lie
      inside the image; otherwise the slice is truncated and stacking fails.
    - patch_size: (width, height) of every patch.

    Returns:
    - Tensor of shape (B, num_coords * C, patch_height * patch_width). Each row
      is one flattened patch. For odd sizes the coordinate is the patch center.
      For even sizes it sits at index (patch_height // 2, patch_width // 2)
      inside the patch.
    """
    N = tensor.shape[0]
    device = tensor.device
    dtype = tensor.dtype
    patch_width, patch_height = patch_size
    pad_width = patch_width // 2
    pad_height = patch_height // 2

    # Zero-pad so windows around border coordinates stay in bounds.
    tensor_padded = F.pad(
        tensor,
        (pad_width, pad_width, pad_height, pad_height),
        mode="constant",
        value=0.0,
    )

    # Shift the (y, x) coordinates into the padded frame.
    sample_coords_padded = sample_coords + torch.tensor(
        [pad_height, pad_width], dtype=dtype, device=device
    ).reshape(1, 1, 2)
    x_centers = sample_coords_padded[:, :, 1].int()
    y_centers = sample_coords_padded[:, :, 0].int()

    all_patches = []
    for tensor_i, x_centers_i, y_centers_i in zip(tensor_padded, x_centers, y_centers):
        patches = []
        for x_center, y_center in zip(x_centers_i, y_centers_i):
            # Slice exactly patch_height x patch_width. An end of
            # `center + pad + 1` would give size + 1 pixels for even sizes and
            # break the final reshape.
            y_start = int(y_center) - pad_height
            x_start = int(x_center) - pad_width
            patch = tensor_i[
                ..., y_start : y_start + patch_height, x_start : x_start + patch_width
            ]
            patches.append(patch)
        all_patches.append(torch.stack(patches, dim=0))
    return torch.stack(all_patches, dim=0).reshape(N, -1, patch_height * patch_width)
class LocalSSI(nn.Module):
    """Regression loss computed on multi-scale local patches plus one global term.

    For each of ``num_levels`` patch scales, the prediction and target are
    randomly rolled and unfolded into overlapping square patches (stride = 90%
    of the kernel). Each patch is then:

    1. aligned with ``rescale_fn`` (by default ``ssi`` from ``.utils``,
       presumably a scale-and-shift alignment — TODO confirm);
    2. compared with ``fn``;
    3. filtered of its largest errors (quantile masking);
    4. averaged;
    5. transformed with ``output_fn`` and averaged over patches that keep at
       least ``min_samples`` valid pixels.

    A final term applies the same steps to whole images. ``forward`` returns
    the mean of the ``num_levels`` patch terms and the global term.
    """

    def __init__(
        self,
        weight: float,
        output_fn: str = "sqrt",
        patch_size: tuple[int, int] = (32, 32),
        min_samples: int = 4,
        num_levels: int = 4,
        fn: str = "l1",
        rescale_fn: str = "ssi",
        input_fn: str = "linear",
        quantile: float = 0.1,
        gamma: float = 1.0,
        alpha: float = 1.0,
        relative: bool = False,
        eps: float = 1e-5,
    ):
        """Configure the loss.

        Args:
            weight: stored as ``self.weight``. It is not used inside this
                class (presumably applied by the caller).
            output_fn: key into ``FNS``. Applied to the per-patch and
                per-image mean errors.
            patch_size: (min, max) patch side as a *fraction* of ``min(H, W)``.
                The kernel side is ``2**log_kernel * min(H, W)``, with
                ``log_kernel`` in ``[log2(min), log2(max)]``.
                NOTE(review): the default (32, 32) would give kernels 32x the
                image side; configs presumably pass values <= 1 — confirm.
            min_samples: minimum valid pixels for a patch to count. The
                global term requires ``3 * min_samples`` per image.
            num_levels: number of patch-size ranges (levels).
            fn: key into ``REGRESSION_DICT``. Called as
                ``fn(diff, alpha=..., gamma=...)``.
            rescale_fn: name evaluated with ``eval`` (see note below).
            input_fn: key into ``FNS``. Applied to input and target first.
            quantile: fraction of largest errors to drop (scaled by quality level).
            gamma, alpha: forwarded to ``fn``.
            relative: if True, the global error is reweighted by
                ``log(1 + 10 / max(|target|, 0.01))``.
            eps: lower clamp applied before ``output_fn``.
        """
        super(LocalSSI, self).__init__()
        self.name: str = self.__class__.__name__
        self.weight = weight
        self.output_fn = FNS[output_fn]
        self.input_fn = FNS[input_fn]
        self.fn = REGRESSION_DICT[fn]
        self.min_samples = min_samples
        self.eps = eps
        # Split [log2(min(patch_size)), log2(max(patch_size))] into num_levels
        # consecutive (low, high) log2 ranges, one per patch level.
        patch_logrange = np.linspace(
            start=np.log2(min(patch_size)),
            stop=np.log2(max(patch_size)),
            endpoint=True,
            num=num_levels + 1,
        )
        self.patch_logrange = [
            (x, y) for x, y in zip(patch_logrange[:-1], patch_logrange[1:])
        ]
        # NOTE(review): eval() runs the config string as Python in the scope
        # of __init__ (module globals plus local names). The intended values
        # appear to be "ssi" / "ssi_nd" imported from .utils. Never build this
        # loss from an untrusted config.
        self.rescale_fn = eval(rescale_fn)
        self.quantile = quantile
        self.gamma = gamma
        self.alpha = alpha
        self.relative = relative

    @profile_method(verbose=VERBOSE)
    @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor,
        quality: torch.Tensor = None,
        down_ratio: int = 1,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the loss, always in float32 (CUDA autocast is disabled).

        Args:
            input: prediction, (B, C, H, W).
            target: ground truth with the same shape as ``input``.
            mask: valid-pixel mask. It is cast to bool and later reshaped to
                (B, 1, H*W), so presumably (B, 1, H, W).
            quality: optional per-sample tensor compared against 1 and 2.
                Samples at level ``q`` drop errors above the
                ``1 - quantile * q`` quantile. Samples with other values are
                not filtered. When None, every sample drops errors above the
                ``1 - quantile`` quantile.
            down_ratio: if > 1, input, target and mask are first passed
                through ``downsample``.

        Returns:
            Mean over the ``num_levels`` patch terms and the global term of
            the per-image losses (presumably shape (B,)).

        NOTE(review): the random roll (and, in training, the random kernel
        size) draws from numpy's global RNG. The roll is random even when
        ``self.training`` is False.
        """
        mask = mask.bool()
        if down_ratio > 1:
            input = downsample(input, down_ratio)
            target = downsample(target, down_ratio)
            # Author's note: downsample ignores 0s when taking the per-window
            # "min", so a window containing any 1 stays valid in the mask
            # (downsample's implementation is not visible here).
            mask = downsample(mask.float(), down_ratio).bool()
        input = self.input_fn(input.float())
        target = self.input_fn(target.float())
        B, C, H, W = input.shape
        total_errors = []
        for ii, patch_logrange in enumerate(self.patch_logrange):
            # Training: random scale within this level's log2 range.
            # Eval: midpoint of the range.
            log_kernel = (
                np.random.uniform(*patch_logrange)
                if self.training
                else np.mean(patch_logrange)
            )
            kernel_size = int(
                (2**log_kernel) * min(input.shape[-2:])
            )  # smaller than min(H, W) only when patch_size values are < 1
            kernel_size = (kernel_size, kernel_size)
            # 10% overlap between neighbouring patches.
            stride = (int(kernel_size[0] * 0.9), int(kernel_size[1] * 0.9))
            # unfold skips the right/bottom remainder that does not fit a full
            # stride. A random negative roll of up to that remainder changes
            # which pixels are left out (rolled pixels wrap to the far edge).
            max_roll = (
                (W - kernel_size[1]) % stride[1],
                (H - kernel_size[0]) % stride[0],
            )
            roll_x, roll_y = np.random.randint(-max_roll[0], 1), np.random.randint(
                -max_roll[1], 1
            )
            input_fold = torch.roll(input, shifts=(roll_y, roll_x), dims=(2, 3))
            target_fold = torch.roll(target, shifts=(roll_y, roll_x), dims=(2, 3))
            mask_fold = torch.roll(mask.float(), shifts=(roll_y, roll_x), dims=(2, 3))
            # Unfold into patches. Each row holds one patch flattened as
            # C*H_p*W_p; the mask rows hold (mask channels)*H_p*W_p, so this
            # presumably assumes C == 1 — TODO confirm.
            input_fold = F.unfold(
                input_fold, kernel_size=kernel_size, stride=stride
            ).permute(
                0, 2, 1
            )  # B N C*H_p*W_p
            target_fold = F.unfold(
                target_fold, kernel_size=kernel_size, stride=stride
            ).permute(0, 2, 1)
            mask_fold = (
                F.unfold(mask_fold, kernel_size=kernel_size, stride=stride)
                .bool()
                .permute(0, 2, 1)
            )
            # Align each patch independently, then compute elementwise errors.
            input_fold, target_fold, _ = self.rescale_fn(
                input_fold, target_fold, mask_fold, dim=(-1,)
            )
            error = self.fn(
                input_fold - target_fold, alpha=self.alpha, gamma=self.gamma
            )
            # Outlier rejection: within each patch, keep only elements whose
            # (detached) error is strictly below the masked
            # (1 - quantile * level) quantile. Only the upper tail is dropped.
            if quality is not None:
                N_patches = mask_fold.shape[1]
                for quality_level in [1, 2]:
                    current_quality = quality == quality_level
                    if current_quality.sum() > 0:
                        error_qtl = error[current_quality].detach()
                        mask_qtl = error_qtl < masked_quantile(
                            error_qtl,
                            mask_fold[current_quality],
                            dims=[2],
                            q=1 - self.quantile * quality_level,
                        ).view(-1, N_patches, 1)
                        mask_fold[current_quality] = (
                            mask_fold[current_quality] & mask_qtl
                        )
            else:
                error_qtl = error.detach()
                mask_fold = mask_fold & (
                    error_qtl
                    < masked_quantile(
                        error_qtl, mask_fold, dims=[2], q=1 - self.quantile
                    ).view(B, -1, 1)
                )
            # A patch counts only if it keeps >= min_samples pixels after masking.
            valid_patches = mask_fold.sum(dim=-1) >= self.min_samples
            error_mean_patch = masked_mean(error, mask_fold, dim=(-1,)).squeeze(-1)
            error_mean_image = self.output_fn(error_mean_patch.clamp(min=self.eps))
            error_mean_image = masked_mean(
                error_mean_image, mask=valid_patches, dim=(-1,)
            )
            total_errors.append(error_mean_image.squeeze(-1))

        # Global term: align and compare whole images, flattened to (B, C, H*W).
        input_rescale = input.reshape(B, C, -1).clone()
        target_rescale = target.reshape(B, C, -1)
        mask = mask.reshape(B, 1, -1).clone()
        input, target, _ = self.rescale_fn(
            input_rescale,
            target_rescale,
            mask,
            dim=(-1,),
            target_info=target_rescale.norm(dim=1, keepdim=True),
            input_info=input_rescale.norm(dim=1, keepdim=True),
        )
        error = input - target
        # Multi-channel: per-pixel L2 norm over channels. Single channel: drop the channel dim.
        error = error.norm(dim=1) if C > 1 else error.squeeze(1)
        if self.relative:
            # Up-weight errors where the target norm is small (norm clipped at 0.01).
            error = error * torch.log(
                1.0 + 10.0 / target_rescale.norm(dim=1).clip(min=0.01)
            )
        error = self.fn(error, alpha=self.alpha, gamma=self.gamma).squeeze(1)
        mask = mask.squeeze(1)
        # Per-image validity (despite the name): the image needs at least
        # 3 * min_samples valid pixels, counted before quantile masking
        # (12 with the default min_samples=4).
        valid_patches = mask.sum(dim=-1) >= 3 * self.min_samples
        # Same upper-tail outlier rejection as the patch branch, per image.
        if quality is not None:
            for quality_level in [1, 2]:
                current_quality = quality == quality_level
                if current_quality.sum() > 0:
                    error_qtl = error[current_quality].detach()
                    mask_qtl = error_qtl < masked_quantile(
                        error_qtl,
                        mask[current_quality],
                        dims=[1],
                        q=1 - self.quantile * quality_level,
                    ).view(-1, 1)
                    mask[current_quality] = mask[current_quality] & mask_qtl
        else:
            error_qtl = error.detach()
            mask = mask & (
                error_qtl
                < masked_quantile(error_qtl, mask, dims=[1], q=1 - self.quantile).view(
                    -1, 1
                )
            )
        error_mean_image = masked_mean(error, mask, dim=(-1,)).squeeze(-1)
        # Images with too few valid pixels contribute 0 to the global term.
        error_mean_image = (
            self.output_fn(error_mean_image.clamp(min=self.eps)) * valid_patches.float()
        )
        total_errors.append(error_mean_image)
        # Equal-weight mean over the num_levels patch terms and the global term.
        errors = torch.stack(total_errors).mean(dim=0)
        return errors

    @classmethod
    def build(cls, config):
        """Construct a LocalSSI from a config mapping.

        Reads the keys weight, patch_size, output_fn, min_samples, num_levels,
        input_fn, quantile, gamma, alpha, rescale_fn, fn and relative. All are
        required; for a plain dict a missing key raises KeyError. ``eps`` is
        not read from the config and keeps its default.
        """
        obj = cls(
            weight=config["weight"],
            patch_size=config["patch_size"],
            output_fn=config["output_fn"],
            min_samples=config["min_samples"],
            num_levels=config["num_levels"],
            input_fn=config["input_fn"],
            quantile=config["quantile"],
            gamma=config["gamma"],
            alpha=config["alpha"],
            rescale_fn=config["rescale_fn"],
            fn=config["fn"],
            relative=config["relative"],
        )
        return obj