from math import log, pi, prod from typing import Any, Dict, List, Optional, Tuple import torch FNS = { "sqrto": lambda x: torch.sqrt(x + 1), "sqrt": lambda x: torch.sqrt(x + 1e-4), "log": lambda x: torch.log(x + 1e-4), "log1": lambda x: torch.log(x + 1), # transition from log(1/x) to 1/x at x=100 # if x -> 0 : log(1/x), if x -> inf : log(1+1/x) -> 1/x + hot "log1i": lambda x: torch.log(1 + 50 / (1e-4 + x)), "log10": lambda x: torch.log10(1e-4 + x), "log2": lambda x: torch.log2(1e-4 + x), "linear": lambda x: x, "square": torch.square, "disp": lambda x: 1 / (x + 1e-4), "disp1": lambda x: 1 / (1 + x), } FNS_INV = { "sqrt": torch.square, "log": torch.exp, "log1": lambda x: torch.exp(x) - 1, "linear": lambda x: x, "square": torch.sqrt, "disp": lambda x: 1 / x, } def masked_mean_var( data: torch.Tensor, mask: torch.Tensor, dim: List[int], keepdim: bool = True ): if mask is None: return data.mean(dim=dim, keepdim=keepdim), data.var(dim=dim, keepdim=keepdim) # if data[mask].isnan().any(): # print("Warning: NaN in masked_mean_var, valid_pixels before and after", mask.sum(dim=dim).squeeze(), (mask & ~data.isnan()).sum(dim=dim).squeeze()) mask = (mask & ~data.isnan().any(dim=1, keepdim=True)).float() data = torch.nan_to_num(data, nan=0.0) mask_sum = torch.sum(mask, dim=dim, keepdim=True) mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( mask_sum, min=1.0 ) mask_var = torch.sum( mask * (data - mask_mean) ** 2, dim=dim, keepdim=True ) / torch.clamp(mask_sum, min=1.0) if not keepdim: mask_mean, mask_var = mask_mean.squeeze(dim), mask_var.squeeze(dim) return mask_mean, mask_var def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]): if mask is None: return data.mean(dim=dim, keepdim=True) mask = mask.float() mask_sum = torch.sum(mask, dim=dim, keepdim=True) mask_mean = torch.sum( torch.nan_to_num(data, nan=0.0) * mask, dim=dim, keepdim=True ) / mask_sum.clamp(min=1.0) return mask_mean def masked_quantile( data: torch.Tensor, mask: torch.Tensor | None, dims: List[int], q: float ): """ Compute the quantile of the data only where the mask is 1 along specified dimensions. Args: data (torch.Tensor): The input data tensor. mask (torch.Tensor): The mask tensor with the same shape as data, containing 1s where data should be considered. dims (list of int): The dimensions to compute the quantile over. q (float): The quantile to compute, must be between 0 and 1. Returns: torch.Tensor: The quantile computed over the specified dimensions, ignoring masked values. """ masked_data = data * mask if mask is not None else data # Get a list of all dimensions all_dims = list(range(masked_data.dim())) # Revert negative dimensions dims = [d % masked_data.dim() for d in dims] # Find the dimensions to keep (not included in the `dims` list) keep_dims = [d for d in all_dims if d not in dims] # Permute dimensions to bring `dims` to the front permute_order = dims + keep_dims permuted_data = masked_data.permute(permute_order) # Reshape into 2D: (-1, remaining_dims) collapsed_shape = ( -1, prod([permuted_data.size(d) for d in range(len(dims), permuted_data.dim())]), ) reshaped_data = permuted_data.reshape(collapsed_shape) if mask is None: return torch.quantile(reshaped_data, q, dim=0) permuted_mask = mask.permute(permute_order) reshaped_mask = permuted_mask.reshape(collapsed_shape) # Calculate quantile along the first dimension where mask is true quantiles = [] for i in range(reshaped_data.shape[1]): valid_data = reshaped_data[:, i][reshaped_mask[:, i]] if valid_data.numel() == 0: # print("Warning: No valid data found for quantile calculation.") quantiles.append(reshaped_data[:, i].min() * 0.99) else: quantiles.append(torch.quantile(valid_data, q, dim=0)) # Stack back into a tensor with reduced dimensions quantiles = torch.stack(quantiles) quantiles = quantiles.reshape( [permuted_data.size(d) for d in range(len(dims), permuted_data.dim())] ) return quantiles def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): ndim = data.ndim data = data.flatten(ndim - len(dim)) mask = mask.flatten(ndim - len(dim)) mask_median = torch.median(data[..., mask], dim=-1).values return mask_median def masked_median_mad(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): ndim = data.ndim data = data.flatten(ndim - len(dim)) mask = mask.flatten(ndim - len(dim)) mask_median = torch.median(data[mask], dim=-1, keepdim=True).values mask_mad = masked_mean((data - mask_median).abs(), mask, dim=(-1,)) return mask_median, mask_mad def masked_weighted_mean_var( data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...] ): if mask is None: return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) mask = mask.float() mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum( mask * weights, dim=dim, keepdim=True ).clamp(min=1.0) # V1**2 - V2, V1: sum w_i, V2: sum w_i**2 denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum( (mask * weights).square(), dim=dim, keepdim=True ) # correction is V1 / (V1**2 - V2), if w_i=1 => N/(N**2 - N) => 1/(N-1) (unbiased estimator of variance, cvd) correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp( min=1.0 ) mask_var = correction_factor * torch.sum( weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True ) return mask_mean, mask_var def stable_masked_mean_var( input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int] ): # recalculate mask with points in 95% confidence interval input_detach = input.detach() input_mean, input_var = masked_mean_var(input_detach, mask=mask, dim=dim) target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim) input_std = (input_var).clip(min=1e-6).sqrt() target_std = (target_var).clip(min=1e-6).sqrt() stable_points_input = torch.logical_and( input_detach > input_mean - 1.96 * input_std, input_detach < input_mean + 1.96 * input_std, ) stable_points_target = torch.logical_and( target > target_mean - 1.96 * target_std, target < target_mean + 1.96 * target_std, ) stable_mask = stable_points_target & stable_points_input & mask input_mean, input_var = masked_mean_var(input, mask=stable_mask, dim=dim) target_mean, target_var = masked_mean_var(target, mask=stable_mask, dim=dim) return input_mean, input_var, target_mean, target_var, stable_mask def ssi( input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int], *args, **kwargs, ) -> torch.Tensor: # recalculate mask with points in 95% confidence interval input_mean, input_var, target_mean, target_var, stable_mask = ( stable_masked_mean_var(input, target, mask, dim) ) # if target_var.min() < 1e-6: # print( # "Warning: target low", # list(zip(target_var.squeeze().cpu().numpy(), # target_mean.squeeze().cpu().numpy(), # mask.reshape(target_var.shape[0], -1).sum(dim=-1).squeeze().cpu().numpy(), # stable_mask.reshape(target_var.shape[0], -1).sum(dim=-1).squeeze().cpu().numpy())) # ) # if input_var.min() < 1e-6: # print("Warning: input variance is too low", input_var.squeeze(), input_mean.squeeze()) if input_var.isnan().any(): print("Warning: input variance is nan") if input_var.isinf().any(): print("Warning: input variance is isinf") if input_mean.isnan().any(): print("Warning: input m is nan") if input_mean.isinf().any(): print("Warning: input m is isinf") target_normalized = (target - target_mean) / FNS["sqrt"](target_var) input_normalized = (input - input_mean) / FNS["sqrt"](input_var) return input_normalized, target_normalized, stable_mask def ssi_nd( input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int], input_info: torch.Tensor, target_info: torch.Tensor, ) -> torch.Tensor: input_mean, input_var, target_mean, target_var, stable_mask = ( stable_masked_mean_var(input_info, target_info, mask, dim) ) if input_var.isnan().any(): print("Warning: input variance is nan") if input_var.isinf().any(): print("Warning: input variance is isinf") if input_mean.isnan().any(): print("Warning: input m is nan") if input_mean.isinf().any(): print("Warning: input m is isinf") target_normalized = (target - target_mean) / FNS["sqrt"](target_var) input_normalized = (input - input_mean) / FNS["sqrt"](input_var) return input_normalized, target_normalized, stable_mask def stable_ssi( input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int] ) -> torch.Tensor: input_mean, input_var = masked_mean_var(input, mask=mask, dim=dim) target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim) target_normalized = (target - target_mean) / torch.sqrt(target_var.clamp(min=1e-6)) input_normalized = (input - input_mean) / torch.sqrt(input_var.clamp(min=1e-6)) return input_normalized, target_normalized, mask def ind2sub(idx, cols): r = idx // cols c = idx % cols return r, c def sub2ind(r, c, cols): idx = r * cols + c return idx def l2(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor: return (input_tensor / gamma) ** 2 def l1(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor: return torch.abs(input_tensor) def charbonnier( input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs ) -> torch.Tensor: return gamma * torch.sqrt(torch.square(input_tensor / gamma) + 1) - 1 def cauchy( input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs ) -> torch.Tensor: return gamma * torch.log(torch.square(input_tensor / gamma) + 1) + log(gamma * pi) def geman_mcclure( input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs ) -> torch.Tensor: return gamma * torch.square(input_tensor) / (torch.square(input_tensor) + gamma) def robust_loss( input_tensor: torch.Tensor, alpha: float, gamma: float = 1.0, *args, **kwargs ) -> torch.Tensor: coeff = abs(alpha - 2) / alpha power = torch.square(input_tensor / gamma) / abs(alpha - 2) + 1 return ( gamma * coeff * (torch.pow(power, alpha / 2) - 1) ) # mult gamma to keep grad magnitude invariant wrt gamma REGRESSION_DICT = { "l2": l2, "l1": l1, "cauchy": cauchy, "charbonnier": charbonnier, "geman_mcclure": geman_mcclure, "robust_loss": robust_loss, }