"""
U-shaped DISCO Neural Operator
"""
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch_harmonics_local.convolution import (
EquidistantDiscreteContinuousConv2d as DISCO2d,
)
class UDNO(nn.Module):
"""
U-shaped DISCO Neural Operator in PyTorch
"""
def __init__(
self,
in_chans: int,
out_chans: int,
radius_cutoff: float,
chans: int = 32,
num_pool_layers: int = 4,
drop_prob: float = 0.0,
in_shape: Tuple[int, int] = (320, 320),
kernel_shape: Tuple[int, int] = (3, 4),
):
"""
Parameters
----------
        in_chans : int
            Number of channels in the input to the UDNO.
        out_chans : int
            Number of channels in the output of the UDNO.
        radius_cutoff : float
            Controls the effective radius of the DISCO kernel. Values lie
            between 0.0 and 1.0. The radius_cutoff is expressed as a proportion
            of the normalized input space, so that kernels are resolution
            invariant.
chans : int, optional
Number of output channels of the first DISCO layer. Default is 32.
num_pool_layers : int, optional
Number of down-sampling and up-sampling layers. Default is 4.
drop_prob : float, optional
Dropout probability. Default is 0.0.
        in_shape : Tuple[int, int], optional
            Unbatched spatial shape of the input to the UDNO. Default is
            (320, 320). Required to dynamically compile DISCO kernels for
            resolution invariance.
kernel_shape : Tuple[int, int], optional
Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3
rings and 4 anisotropic basis functions. Under the hood, each DISCO
kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard
3x3 convolution kernel.
Note: This is NOT kernel_size, as under the DISCO framework,
kernels are dynamically compiled to support resolution invariance.
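        Examples
        --------
        A minimal usage sketch; the argument values below are illustrative
        and not tuned for any particular task:

        >>> import torch
        >>> model = UDNO(
        ...     in_chans=1, out_chans=1, radius_cutoff=0.02,
        ...     chans=32, num_pool_layers=4, in_shape=(320, 320),
        ... )
        >>> x = torch.randn(1, 1, 320, 320)
        >>> out = model(x)  # shape (1, 1, 320, 320)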
"""
super().__init__()
assert len(in_shape) == 2, "Input shape must be 2D"
self.in_chans = in_chans
self.out_chans = out_chans
self.chans = chans
self.num_pool_layers = num_pool_layers
self.drop_prob = drop_prob
self.in_shape = in_shape
self.kernel_shape = kernel_shape
self.down_sample_layers = nn.ModuleList(
[
DISCOBlock(
in_chans,
chans,
radius_cutoff,
drop_prob,
in_shape,
kernel_shape,
)
]
)
ch = chans
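        # After each 2x pooling step the grid is halved, so the cutoff (a
        # fraction of the normalized input space) is doubled to keep the
        # kernel's footprint roughly constant in grid points.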
shape = (in_shape[0] // 2, in_shape[1] // 2)
radius_cutoff = radius_cutoff * 2
for _ in range(num_pool_layers - 1):
self.down_sample_layers.append(
DISCOBlock(
ch,
ch * 2,
radius_cutoff,
drop_prob,
in_shape=shape,
kernel_shape=kernel_shape,
)
)
ch *= 2
shape = (shape[0] // 2, shape[1] // 2)
radius_cutoff *= 2
self.bottleneck = DISCOBlock(
ch,
ch * 2,
radius_cutoff,
drop_prob,
in_shape=shape,
kernel_shape=kernel_shape,
)
self.up = nn.ModuleList()
self.up_transpose = nn.ModuleList()
for _ in range(num_pool_layers - 1):
self.up_transpose.append(
TransposeDISCOBlock(
ch * 2,
ch,
radius_cutoff,
in_shape=shape,
kernel_shape=kernel_shape,
)
)
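            # The decoder mirrors the encoder bookkeeping: the spatial shape
            # doubles and the cutoff halves at every up-sampling level.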
shape = (shape[0] * 2, shape[1] * 2)
radius_cutoff /= 2
self.up.append(
DISCOBlock(
ch * 2,
ch,
radius_cutoff,
drop_prob,
in_shape=shape,
kernel_shape=kernel_shape,
)
)
ch //= 2
self.up_transpose.append(
TransposeDISCOBlock(
ch * 2,
ch,
radius_cutoff,
in_shape=shape,
kernel_shape=kernel_shape,
)
)
shape = (shape[0] * 2, shape[1] * 2)
radius_cutoff /= 2
self.up.append(
nn.Sequential(
DISCOBlock(
ch * 2,
ch,
radius_cutoff,
drop_prob,
in_shape=shape,
kernel_shape=kernel_shape,
),
nn.Conv2d(
ch, self.out_chans, kernel_size=1, stride=1
                ),  # a 1x1 convolution is resolution-invariant (a pixel-wise channel transformation)
)
)
def forward(self, image: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
image : torch.Tensor
Input 4D tensor of shape `(N, in_chans, H, W)`.
Returns
-------
torch.Tensor
Output tensor of shape `(N, out_chans, H, W)`.
"""
stack = []
output = image
# apply down-sampling layers
for layer in self.down_sample_layers:
output = layer(output)
stack.append(output)
output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
output = self.bottleneck(output)
# apply up-sampling layers
for transpose, disco in zip(self.up_transpose, self.up):
downsample_layer = stack.pop()
output = transpose(output)
            # reflect-pad on the right/bottom if needed to handle odd input dimensions
padding = [0, 0, 0, 0]
if output.shape[-1] != downsample_layer.shape[-1]:
padding[1] = 1 # padding right
if output.shape[-2] != downsample_layer.shape[-2]:
padding[3] = 1 # padding bottom
            if sum(padding) != 0:
output = F.pad(output, padding, "reflect")
output = torch.cat([output, downsample_layer], dim=1)
output = disco(output)
return output
class DISCOBlock(nn.Module):
"""
A DISCO Block that consists of two DISCO layers each followed by
instance normalization, LeakyReLU activation and dropout.
"""
def __init__(
self,
in_chans: int,
out_chans: int,
radius_cutoff: float,
drop_prob: float,
in_shape: Tuple[int, int],
kernel_shape: Tuple[int, int] = (3, 4),
):
"""
Parameters
----------
in_chans : int
Number of channels in the input.
out_chans : int
Number of channels in the output.
        radius_cutoff : float
            Controls the effective radius of the DISCO kernel. Values lie
            between 0.0 and 1.0. The radius_cutoff is expressed as a proportion
            of the normalized input space, so that kernels are resolution
            invariant.
        drop_prob : float
            Dropout probability.
        in_shape : Tuple[int, int]
            Unbatched spatial 2D shape of the input to this block.
            Required to dynamically compile DISCO kernels for resolution
            invariance.
        kernel_shape : Tuple[int, int], optional
            Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3
            rings and 4 anisotropic basis functions. Under the hood, each DISCO
            kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard
            3x3 convolution kernel.
            Note: This is NOT kernel_size, as under the DISCO framework,
            kernels are dynamically compiled to support resolution invariance.
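        Examples
        --------
        An illustrative sketch (argument values chosen for demonstration
        only):

        >>> import torch
        >>> block = DISCOBlock(
        ...     in_chans=16, out_chans=32, radius_cutoff=0.05,
        ...     drop_prob=0.0, in_shape=(64, 64),
        ... )
        >>> x = torch.randn(1, 16, 64, 64)
        >>> y = block(x)  # shape (1, 32, 64, 64); spatial size is preserved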
"""
super().__init__()
self.in_chans = in_chans
self.out_chans = out_chans
self.drop_prob = drop_prob
self.layers = nn.Sequential(
DISCO2d(
in_chans,
out_chans,
kernel_shape=kernel_shape,
in_shape=in_shape,
bias=False,
radius_cutoff=radius_cutoff,
padding_mode="constant",
),
nn.InstanceNorm2d(out_chans),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Dropout2d(drop_prob),
DISCO2d(
out_chans,
out_chans,
kernel_shape=kernel_shape,
in_shape=in_shape,
bias=False,
radius_cutoff=radius_cutoff,
padding_mode="constant",
),
nn.InstanceNorm2d(out_chans),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Dropout2d(drop_prob),
)
def forward(self, image: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
        image : torch.Tensor
Input 4D tensor of shape `(N, in_chans, H, W)`.
Returns
-------
        torch.Tensor
Output tensor of shape `(N, out_chans, H, W)`.
"""
return self.layers(image)
class TransposeDISCOBlock(nn.Module):
"""
A transpose DISCO Block that consists of an up-sampling layer followed by a
DISCO layer, instance normalization, and LeakyReLU activation.
"""
def __init__(
self,
in_chans: int,
out_chans: int,
radius_cutoff: float,
in_shape: Tuple[int, int],
kernel_shape: Tuple[int, int] = (3, 4),
):
"""
Parameters
----------
in_chans : int
Number of channels in the input.
out_chans : int
Number of channels in the output.
        radius_cutoff : float
            Controls the effective radius of the DISCO kernel. Values lie
            between 0.0 and 1.0. The radius_cutoff is expressed as a proportion
            of the normalized input space, so that kernels are resolution
            invariant.
        in_shape : Tuple[int, int]
            Unbatched spatial 2D shape of the input to this block.
            Required to dynamically compile DISCO kernels for resolution
            invariance.
kernel_shape : Tuple[int, int], optional
Shape of the DISCO kernel. Default is (3, 4). This corresponds to 3
rings and 4 anisotropic basis functions. Under the hood, each DISCO
kernel has (3 - 1) * 4 + 1 = 9 parameters, equivalent to a standard
3x3 convolution kernel.
Note: This is NOT kernel_size, as under the DISCO framework,
            kernels are dynamically compiled to support resolution invariance.
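        Examples
        --------
        An illustrative sketch (argument values chosen for demonstration
        only):

        >>> import torch
        >>> up = TransposeDISCOBlock(
        ...     in_chans=32, out_chans=16, radius_cutoff=0.1, in_shape=(32, 32),
        ... )
        >>> x = torch.randn(1, 32, 32, 32)
        >>> y = up(x)  # shape (1, 16, 64, 64); spatial dimensions are doubled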
"""
super().__init__()
self.in_chans = in_chans
self.out_chans = out_chans
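        # Bilinear upsampling doubles the grid, so the DISCO layer is compiled
        # for the doubled in_shape with a halved cutoff, keeping the kernel's
        # footprint consistent across resolutions.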
self.layers = nn.Sequential(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
DISCO2d(
in_chans,
out_chans,
kernel_shape=kernel_shape,
in_shape=(2 * in_shape[0], 2 * in_shape[1]),
bias=False,
radius_cutoff=(radius_cutoff / 2),
padding_mode="constant",
),
nn.InstanceNorm2d(out_chans),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
)
def forward(self, image: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
image : torch.Tensor
Input 4D tensor of shape `(N, in_chans, H, W)`.
Returns
-------
torch.Tensor
Output tensor of shape `(N, out_chans, H*2, W*2)`.
"""
return self.layers(image)