|
import numpy as np
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mean(signals, win_length=9):
    """Average filtering for signals containing nan values

    Each output sample is the mean of the non-nan input samples inside a
    window of ``win_length`` samples around it. Windows that contain no
    valid (non-nan) samples produce nan; a genuine mean of 0 is kept as 0.
    For even ``win_length`` the extra sample of the window is taken from
    the left, so the output always has the same length as the input.

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter
        win_length
            The size of the analysis window

    Returns
        filtered (torch.tensor (shape=(batch, time)))
    """
    assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"

    # Add a channel dimension for conv1d: (batch, 1, time)
    signals = signals.unsqueeze(1)

    # Zero out nans so they contribute nothing to the windowed sums, and keep
    # a 0/1 indicator of which samples are valid
    mask = ~torch.isnan(signals)
    masked_x = torch.where(mask, signals, torch.zeros_like(signals))
    valid = mask.to(signals.dtype)

    # Pad so that the output length equals the input length for both odd and
    # even windows (for odd windows this is the usual symmetric padding).
    # Padded samples are zero in both tensors, so they count as invalid.
    pad = (win_length // 2, (win_length - 1) // 2)
    masked_x = F.pad(masked_x, pad)
    valid = F.pad(valid, pad)

    # Summing kernel in the input's dtype so e.g. float64 input works
    ones_kernel = torch.ones(1, 1, win_length, dtype=signals.dtype, device=signals.device)

    # Windowed sum of valid values and windowed count of valid samples
    sum_pooled = F.conv1d(masked_x, ones_kernel, stride=1)
    valid_count = F.conv1d(valid, ones_kernel, stride=1)

    # Clamp avoids 0 / 0; those windows are overwritten with nan below
    avg_pooled = sum_pooled / valid_count.clamp(min=1)

    # Only windows without any valid sample are undefined. Counts are whole
    # numbers, so compare against 0.5 to be safe against convolution rounding.
    avg_pooled = avg_pooled.masked_fill(valid_count < 0.5, float("nan"))

    return avg_pooled.squeeze(1)
|
|
|
|
|
|
def median(signals, win_length):
    """Median filtering for signals containing nan values

    Each output sample is the median of the non-nan input samples inside a
    window of ``win_length`` samples around it. When the number of valid
    samples is even, the lower of the two middle values is used. Windows with
    no valid samples produce nan. For even ``win_length`` the extra sample of
    the window is taken from the left, so the output always has the same
    length as the input.

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter
        win_length
            The size of the analysis window

    Returns
        filtered (torch.tensor (shape=(batch, time)))
    """
    assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)"

    # Add a channel dimension: (batch, 1, time)
    signals = signals.unsqueeze(1)

    mask = ~torch.isnan(signals)
    masked_x = torch.where(mask, signals, torch.zeros_like(signals))

    # Padded samples are excluded through the mask, so their value is
    # irrelevant. Constant padding (unlike reflect) also works when the signal
    # is shorter than the window.
    pad = (win_length // 2, (win_length - 1) // 2)
    x = F.pad(masked_x, pad, mode="constant", value=0)
    mask = F.pad(mask.float(), pad, mode="constant", value=0)

    # Sliding windows: (batch, 1, time, win_length)
    x = x.unfold(2, win_length, 1)
    mask = mask.unfold(2, win_length, 1)

    # Push invalid samples past every real value; masked_fill keeps the input
    # dtype, so float64 signals are not rounded through float32
    x_masked = x.masked_fill(~mask.bool(), float("inf"))
    x_sorted, _ = torch.sort(x_masked, dim=-1)

    # The first valid_count sorted entries are the valid samples; take the
    # lower median among them (index 0 for empty windows, fixed up below)
    valid_count = mask.sum(dim=-1)
    median_idx = ((valid_count - 1) // 2).clamp(min=0)
    median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1)

    # Windows with no valid samples are undefined
    median_pooled = median_pooled.masked_fill(valid_count == 0, float("nan"))

    return median_pooled.squeeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def nanfilter(signals, win_length, filter_fn):
    """Filters a sequence, ignoring nan values

    Slides a window along the time axis and calls ``filter_fn`` on every
    slice. The slice for frame ``i`` runs from ``i - win_length // 2`` through
    ``i + win_length // 2`` (inclusive), truncated at the signal edges. An
    even ``win_length`` therefore covers ``win_length + 1`` frames away from
    the edges.

    Arguments
        signals (torch.tensor (shape=(batch, time)))
            The signals to filter
        win_length
            The size of the analysis window
        filter_fn (function)
            Called on each (batch, window) slice; its result is written into
            one time column of the output

    Returns
        filtered (torch.tensor (shape=(batch, time)))
    """
    half = win_length // 2
    num_frames = signals.size(1)
    output = torch.empty_like(signals)

    for frame in range(num_frames):
        # Window bounds, clipped to the valid frame range
        lo = frame - half
        if lo < 0:
            lo = 0
        hi = min(frame + half + 1, num_frames)

        output[:, frame] = filter_fn(signals[:, lo:hi])

    return output
|
|
|
|
|
|
def nanmean(signals):
    """Computes the mean over time, ignoring nans

    Arguments
        signals (torch.tensor [shape=(batch, time)])
            The signals to average

    Returns
        means (torch.tensor [shape=(batch,)])
            Mean of the non-nan values of each row; nan for rows that
            contain no valid values (0 / 0)
    """
    nans = torch.isnan(signals)

    # Replace nans out-of-place so the caller's tensor is left untouched
    total = torch.where(nans, torch.zeros_like(signals), signals).sum(dim=1)
    count = (~nans).float().sum(dim=1)

    return total / count
|
|
|
|
|
|
def nanmedian(signals):
    """Computes the median over time, ignoring nans

    Like ``torch.median``, the lower of the two middle values is returned
    when a row has an even number of valid values.

    Arguments
        signals (torch.tensor [shape=(batch, time)])
            The signals to take the median of

    Returns
        medians (torch.tensor [shape=(batch,)])
            Median of the non-nan values of each row; nan for rows that
            contain no valid values
    """
    nans = torch.isnan(signals)
    valid_count = (~nans).sum(dim=1)

    if signals.size(1) == 0:
        # Every row is empty: nothing to sort or gather from
        return torch.full(
            (signals.size(0),), float("nan"), dtype=signals.dtype, device=signals.device)

    # Push nans past every real value so that, after sorting, the first
    # valid_count entries of each row are exactly its valid values.
    # Non-float tensors cannot hold nan (or inf), so they need no fill.
    values = signals.masked_fill(nans, float("inf")) if signals.is_floating_point() else signals
    ordered, _ = torch.sort(values, dim=1)

    # Lower median index among the valid values (0 for empty rows, fixed below)
    median_idx = ((valid_count - 1) // 2).clamp(min=0)
    medians = ordered.gather(1, median_idx.unsqueeze(1)).squeeze(1)

    if signals.is_floating_point():
        medians = medians.masked_fill(valid_count == 0, float("nan"))

    return medians
|
|
|
|
|
|
def nanmedian1d(signal):
    """Computes the median. If signal is empty, returns nan

    For an even number of elements this is the lower of the two middle
    values, as returned by ``torch.median``. nan values are not removed here.

    Arguments
        signal (torch.tensor [shape=(time,)])
            The values to take the median of

    Returns
        median (torch.tensor [shape=()] or float)
            A 0-d tensor, or ``np.nan`` (a Python float) when ``signal``
            has no elements
    """
    if signal.numel() == 0:
        return np.nan
    return torch.median(signal)
|
|
|