# nomri/torch_harmonics_local/convolution.py (samaonline, commit 1b34a12: "init")
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import abc
import math
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from .quadrature import _precompute_grid, _precompute_latitudes
# The Triton implementations of the DISCO contractions are only imported when
# CUDA is available; on CPU-only machines these names stay undefined, so they
# must only be referenced on a CUDA code path.
if torch.cuda.is_available():
    from ._disco_convolution import (
        _disco_s2_contraction_triton,
        _disco_s2_transpose_contraction_triton,
    )
def _compute_support_vals_isotropic(
r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float, norm: str = "s2"
):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# compute the support
dr = (r_cutoff - 0.0) / nr
ikernel = torch.arange(nr).reshape(-1, 1, 1)
ir = ikernel * dr
if norm == "none":
norm_factor = 1.0
elif norm == "2d":
norm_factor = (
math.pi * (r_cutoff * nr / (nr + 1)) ** 2
+ math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
)
elif norm == "s2":
norm_factor = (
2
* math.pi
* (
1
- math.cos(r_cutoff - dr)
+ math.cos(r_cutoff - dr)
+ (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr
)
)
else:
raise ValueError(f"Unknown normalization mode {norm}.")
# find the indices where the rotated position falls into the support of the kernel
iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))
vals = (
1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
) / norm_factor
return iidx, vals
def _compute_support_vals_anisotropic(
r: torch.Tensor,
phi: torch.Tensor,
nr: int,
nphi: int,
r_cutoff: float,
norm: str = "s2",
):
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
"""
# compute the support
dr = (r_cutoff - 0.0) / nr
dphi = 2.0 * math.pi / nphi
kernel_size = (nr - 1) * nphi + 1
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
ir = ((ikernel - 1) // nphi + 1) * dr
iphi = ((ikernel - 1) % nphi) * dphi
if norm == "none":
norm_factor = 1.0
elif norm == "2d":
norm_factor = (
math.pi * (r_cutoff * nr / (nr + 1)) ** 2
+ math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
)
elif norm == "s2":
norm_factor = (
2
* math.pi
* (
1
- math.cos(r_cutoff - dr)
+ math.cos(r_cutoff - dr)
+ (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr
)
)
else:
raise ValueError(f"Unknown normalization mode {norm}.")
# find the indices where the rotated position falls into the support of the kernel
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = (
(ikernel == 0)
| ((phi - iphi).abs() <= dphi)
| ((2 * math.pi - (phi - iphi).abs()) <= dphi)
)
iidx = torch.argwhere(cond_r & cond_phi)
vals = (
1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
) / norm_factor
vals *= torch.where(
iidx[:, 0] > 0,
(
1
- torch.minimum(
(phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(),
(
2 * math.pi
- (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
),
)
/ dphi
),
1.0,
)
return iidx, vals
def _precompute_convolution_tensor_s2(
in_shape,
out_shape,
kernel_shape,
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).
The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
$$
Y(\alpha) Z(\beta) Y(\gamma) n =
{\begin{bmatrix}
\cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
\sin(\beta)\sin(\gamma) \\
\cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
\end{bmatrix}}
$$
"""
assert len(in_shape) == 2
assert len(out_shape) == 2
if len(kernel_shape) == 1:
kernel_handle = partial(
_compute_support_vals_isotropic,
nr=kernel_shape[0],
r_cutoff=theta_cutoff,
norm="s2",
)
elif len(kernel_shape) == 2:
kernel_handle = partial(
_compute_support_vals_anisotropic,
nr=kernel_shape[0],
nphi=kernel_shape[1],
r_cutoff=theta_cutoff,
norm="s2",
)
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()
# array for accumulating non-zero indices
out_idx = torch.empty([3, 0], dtype=torch.long)
out_vals = torch.empty([0], dtype=torch.long)
# compute the phi differences
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
for t in range(nlat_out):
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
alpha = -lats_out[t]
beta = lons_in
gamma = lats_in.reshape(-1, 1)
# compute cartesian coordinates of the rotated position
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and therefore applied with a negative sign
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(
alpha
) * torch.cos(gamma)
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(
gamma
) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma)
# normalization is emportant to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution
norm = torch.sqrt(x * x + y * y + z * z)
x = x / norm
y = y / norm
z = z / norm
# compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
theta = torch.arccos(z)
phi = torch.arctan2(y, x) + torch.pi
# find the indices where the rotated position falls into the support of the kernel
iidx, vals = kernel_handle(theta, phi)
# add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
idx = torch.stack(
[
iidx[:, 0],
t * torch.ones_like(iidx[:, 0]),
iidx[:, 1] * nlon_in + iidx[:, 2],
],
dim=0,
)
# append indices and values to the COO datastructure
out_idx = torch.cat([out_idx, idx], dim=-1)
out_vals = torch.cat([out_vals, vals], dim=-1)
return out_idx, out_vals
def _precompute_convolution_tensor_2d(
grid_in, grid_out, kernel_shape, radius_cutoff=0.01, periodic=False
):
"""
Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i \nu$. Similar to the S2 routine,
only that it assumes a non-periodic subset of the euclidean plane
"""
# check that input arrays are valid point clouds in 2D
assert len(grid_in) == 2
assert len(grid_out) == 2
assert grid_in.shape[0] == 2
assert grid_out.shape[0] == 2
n_in = grid_in.shape[-1]
n_out = grid_out.shape[-1]
if len(kernel_shape) == 1:
kernel_handle = partial(
_compute_support_vals_isotropic,
nr=kernel_shape[0],
r_cutoff=radius_cutoff,
norm="2d",
)
elif len(kernel_shape) == 2:
kernel_handle = partial(
_compute_support_vals_anisotropic,
nr=kernel_shape[0],
nphi=kernel_shape[1],
r_cutoff=radius_cutoff,
norm="2d",
)
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
grid_in = grid_in.reshape(2, 1, n_in)
grid_out = grid_out.reshape(2, n_out, 1)
diffs = grid_in - grid_out
if periodic:
periodic_diffs = torch.where(diffs > 0.0, diffs - 1, diffs + 1)
diffs = torch.where(diffs.abs() < periodic_diffs.abs(), diffs, periodic_diffs)
r = torch.sqrt(diffs[0] ** 2 + diffs[1] ** 2)
phi = torch.arctan2(diffs[1], diffs[0]) + torch.pi
idx, vals = kernel_handle(r, phi)
idx = idx.permute(1, 0)
return idx, vals
class DiscreteContinuousConv(nn.Module, abc.ABC):
    """
    Abstract base class for DISCO convolutions.

    Normalizes ``kernel_shape`` to a list, derives the number of kernel basis functions
    (``kernel_size``), validates the channel grouping and allocates the learnable weight of
    shape (out_channels, in_channels // groups, kernel_size) together with an optional bias
    of shape (out_channels,). Subclasses implement ``forward``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        # an int means an isotropic kernel with that many radial basis functions
        self.kernel_shape = (
            [kernel_shape] if isinstance(kernel_shape, int) else kernel_shape
        )

        n_dims = len(self.kernel_shape)
        if n_dims == 1:
            self.kernel_size = self.kernel_shape[0]
        elif n_dims == 2:
            # one central basis function plus n_angles angular basis functions on each of the
            # remaining n_rings - 1 rings (same count as _compute_support_vals_anisotropic)
            n_rings, n_angles = self.kernel_shape[0], self.kernel_shape[1]
            self.kernel_size = (n_rings - 1) * n_angles + 1
        else:
            raise ValueError("kernel_shape should be either one- or two-dimensional.")

        self.groups = groups

        # both channel counts have to split evenly into the groups
        if in_channels % self.groups != 0:
            raise ValueError(
                "Error, the number of input channels has to be an integer multiple of the group size"
            )
        if out_channels % self.groups != 0:
            raise ValueError(
                "Error, the number of output channels has to be an integer multiple of the group size"
            )

        # number of input channels seen by each output channel
        self.groupsize = in_channels // self.groups

        # Gaussian init scaled by 1 / sqrt(fan-in per group)
        init_scale = math.sqrt(1.0 / self.groupsize)
        self.weight = nn.Parameter(
            init_scale * torch.randn(out_channels, self.groupsize, self.kernel_size)
        )

        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError
def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
"""
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in Triton.
"""
assert len(psi.shape) == 3
assert len(x.shape) == 4
psi = psi.to(x.device)
batch_size, n_chans, nlat_in, nlon_in = x.shape
kernel_size, nlat_out, _ = psi.shape
assert psi.shape[-1] == nlat_in * nlon_in
assert nlon_in % nlon_out == 0
assert nlon_in >= nlat_out
pscale = nlon_in // nlon_out
# add a dummy dimension for nkernel and move the batch and channel dims to the end
x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
x = x.expand(kernel_size, -1, -1, -1)
y = torch.zeros(
nlon_out,
kernel_size,
nlat_out,
batch_size * n_chans,
device=x.device,
dtype=x.dtype,
)
for pout in range(nlon_out):
# sparse contraction with psi
y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
# we need to repeatedly roll the input tensor to faciliate the shifted multiplication
x = torch.roll(x, -pscale, dims=2)
# reshape y back to expose the correct dimensions
y = y.permute(3, 1, 2, 0).reshape(
batch_size, n_chans, kernel_size, nlat_out, nlon_out
)
return y
def _disco_s2_transpose_contraction_torch(
x: torch.Tensor, psi: torch.Tensor, nlon_out: int
):
"""
Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in Triton.
"""
assert len(psi.shape) == 3
assert len(x.shape) == 5
psi = psi.to(x.device)
batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
kernel_size, _, n_out = psi.shape
assert psi.shape[-2] == nlat_in
assert n_out % nlon_out == 0
nlat_out = n_out // nlon_out
assert nlon_out >= nlat_in
pscale = nlon_out // nlon_in
# we do a semi-transposition to faciliate the computation
inz = psi.indices()
tout = inz[2] // nlon_out
pout = inz[2] % nlon_out
# flip the axis of longitudes
pout = nlon_out - 1 - pout
tin = inz[1]
inz = torch.stack([inz[0], tout, tin * nlon_out + pout], dim=0)
psi_mod = torch.sparse_coo_tensor(
inz, psi.values(), size=(kernel_size, nlat_out, nlat_in * nlon_out)
)
# interleave zeros along the longitude dimension to allow for fractional offsets to be considered
x_ext = torch.zeros(
kernel_size,
nlat_in,
nlon_out,
batch_size * n_chans,
device=x.device,
dtype=x.dtype,
)
x_ext[:, :, ::pscale, :] = x.reshape(
batch_size * n_chans, kernel_size, nlat_in, nlon_in
).permute(1, 2, 3, 0)
# we need to go backwards through the vector, so we flip the axis
x_ext = x_ext.contiguous()
y = torch.zeros(
kernel_size,
nlon_out,
nlat_out,
batch_size * n_chans,
device=x.device,
dtype=x.dtype,
)
for pout in range(nlon_out):
# we need to repeatedly roll the input tensor to faciliate the shifted multiplication
# TODO: double-check why this has to happen first
x_ext = torch.roll(x_ext, -1, dims=2)
# sparse contraction with the modified psi
y[:, pout, :, :] = torch.bmm(
psi_mod, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1)
)
# sum over the kernel dimension and reshape to the correct output size
y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)
return y
class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Args:
            in_shape: (nlat_in, nlon_in) of the input grid.
            out_shape: (nlat_out, nlon_out) of the output grid.
            kernel_shape: nr (isotropic) or [nr, nphi] (anisotropic).
            grid_in, grid_out: grid names forwarded to _precompute_latitudes.
            theta_cutoff: angular radius of the filter support; defaults to
                (kernel_shape[0] + 1) * pi / (nlat_in - 1).

        Raises:
            ValueError: if theta_cutoff is not positive.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # default support radius tied to the latitude resolution of the input field
        if theta_cutoff is None:
            theta_cutoff = (
                (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
            )
        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # per-latitude quadrature weights of shape (nlat_in, 1), shared evenly by all longitudes
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        weights = (
            2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        )
        self.register_buffer("quad_weights", weights, persistent=False)

        psi_idx, psi_vals = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.kernel_shape,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
        )
        self.register_buffer("psi_idx", psi_idx, persistent=False)
        self.register_buffer("psi_vals", psi_vals, persistent=False)

    def get_psi(self):
        """Assemble the coalesced sparse filter tensor of shape (kernel_size, nlat_out, nlat_in * nlon_in)."""
        size = (self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
        return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=size).coalesce()

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        """
        Args:
            x: input of shape (batch, in_channels, nlat_in, nlon_in).
            use_triton_kernel: use the Triton contraction for CUDA inputs.

        Returns:
            Tensor of shape (batch, out_channels, nlat_out, nlon_out).
        """
        # quadrature: weight every input sample by its latitude weight
        x = self.quad_weights * x
        psi = self.get_psi()

        contract = (
            _disco_s2_contraction_triton
            if (x.is_cuda and use_triton_kernel)
            else _disco_s2_contraction_torch
        )
        x = contract(x, psi, self.nlon_out)

        # (B, C, K, H, W) -> split channels into groups
        batch, _, n_kernel, n_lat, n_lon = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, n_kernel, n_lat, n_lon)

        # mix channels and kernel basis functions with the learnable weights
        grouped_weight = self.weight.reshape(
            self.groups, -1, self.weight.shape[1], self.weight.shape[2]
        )
        out = torch.einsum("bgckxy,gock->bgoxy", x, grouped_weight)
        out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Args:
            in_shape: (nlat_in, nlon_in) of the input grid.
            out_shape: (nlat_out, nlon_out) of the output grid.
            kernel_shape: nr (isotropic) or [nr, nphi] (anisotropic).
            grid_in, grid_out: grid names forwarded to _precompute_latitudes.
            theta_cutoff: angular radius of the filter support; defaults to
                (kernel_shape[0] + 1) * pi / (nlat_in - 1).

        Raises:
            ValueError: if theta_cutoff is not positive.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # bandlimit: default support radius tied to the input latitude resolution
        if theta_cutoff is None:
            theta_cutoff = (
                (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
            )
        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # per-latitude quadrature weights of the input grid, shape (nlat_in, 1)
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        weights = (
            2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        )
        self.register_buffer("quad_weights", weights, persistent=False)

        # the filters are precomputed as for a forward convolution from the output grid to the
        # input grid (shapes and grids swapped); the contraction then applies their transpose
        psi_idx, psi_vals = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.kernel_shape,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
        )
        self.register_buffer("psi_idx", psi_idx, persistent=False)
        self.register_buffer("psi_vals", psi_vals, persistent=False)

    def get_psi(self):
        """Assemble the coalesced sparse filter tensor of shape (kernel_size, nlat_in, nlat_out * nlon_out)."""
        size = (self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
        return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=size).coalesce()

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        """
        Args:
            x: input of shape (batch, in_channels, nlat_in, nlon_in).
            use_triton_kernel: use the Triton contraction for CUDA inputs.

        Returns:
            Tensor of shape (batch, out_channels, nlat_out, nlon_out).
        """
        batch, _, n_lat, n_lon = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, n_lat, n_lon)

        # expand every channel group into (out_channels, kernel_size) with the learnable weights
        grouped_weight = self.weight.reshape(
            self.groups, -1, self.weight.shape[1], self.weight.shape[2]
        )
        x = torch.einsum("bgcxy,gock->bgokxy", x, grouped_weight)
        x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])

        # quadrature: weight every input sample by its latitude weight
        x = self.quad_weights * x
        psi = self.get_psi()

        if x.is_cuda and use_triton_kernel:
            out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
        else:
            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out
class DiscreteContinuousConv2d(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on arbitrary 2d grids.

    The input and output grids are either point clouds, given as tensors of shape (2, n_points),
    or grid names passed to ``_precompute_grid``. In the latter case ``n_in`` / ``n_out`` hold the
    number of points along each of the two axes, and a tensor-product grid is built from them.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        grid_in: Union[torch.Tensor, str],
        grid_out: Union[torch.Tensor, str],
        kernel_shape: Union[int, List[int]],
        n_in: Optional[Tuple[int]] = None,
        n_out: Optional[Tuple[int]] = None,
        quad_weights: Optional[torch.Tensor] = None,
        periodic: Optional[bool] = False,
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
        radius_cutoff: Optional[float] = None,
    ):
        """
        Args:
            grid_in: input point cloud of shape (2, n_in), or a grid name for _precompute_grid.
            grid_out: output point cloud of shape (2, n_out), or a grid name for _precompute_grid.
            kernel_shape: nr (isotropic) or [nr, nphi] (anisotropic).
            n_in: (nx, ny) points per axis; required when grid_in is a name.
            n_out: (nx, ny) points per axis; required when grid_out is a name.
            quad_weights: 1-D quadrature weights of the input points. Required when grid_in is a
                tensor; replaced by the product weights when grid_in is a name.
            periodic: only allowed with named grids; also enables unit-period wrapping of the
                coordinate differences when the filters are precomputed.
            radius_cutoff: radius of the filter support; defaults to
                2 * (kernel_shape[0] + 1) / (sqrt(n_in) - 1).

        Raises:
            ValueError: for unsupported grid argument types or a non-positive radius_cutoff.

        NOTE(review): named grids come from torch.from_numpy and keep NumPy's dtype (presumably
        float64). If so, cast the module (e.g. .float()) before feeding float32 inputs, since the
        einsum in forward does not mix dtypes. Confirm.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        # the instantiator supports convenience constructors for the input and output grids
        if isinstance(grid_in, torch.Tensor):
            assert isinstance(quad_weights, torch.Tensor)
            assert not periodic
        elif isinstance(grid_in, str):
            assert n_in is not None
            assert len(n_in) == 2
            x, wx = _precompute_grid(n_in[0], grid=grid_in, periodic=periodic)
            y, wy = _precompute_grid(n_in[1], grid=grid_in, periodic=periodic)
            # explicit "ij" indexing (the first axis runs over n_in[0]); relying on the default
            # is deprecated and would silently reorder the flattened grid if it changes
            x, y = torch.meshgrid(
                torch.from_numpy(x), torch.from_numpy(y), indexing="ij"
            )
            wx, wy = torch.meshgrid(
                torch.from_numpy(wx), torch.from_numpy(wy), indexing="ij"
            )
            grid_in = torch.stack([x.reshape(-1), y.reshape(-1)])
            quad_weights = (wx * wy).reshape(-1)
        else:
            raise ValueError(f"Unknown grid input type of type {type(grid_in)}")

        if isinstance(grid_out, torch.Tensor):
            pass
        elif isinstance(grid_out, str):
            assert n_out is not None
            assert len(n_out) == 2
            x, wx = _precompute_grid(n_out[0], grid=grid_out, periodic=periodic)
            y, wy = _precompute_grid(n_out[1], grid=grid_out, periodic=periodic)
            x, y = torch.meshgrid(
                torch.from_numpy(x), torch.from_numpy(y), indexing="ij"
            )
            grid_out = torch.stack([x.reshape(-1), y.reshape(-1)])
        else:
            raise ValueError(f"Unknown grid output type of type {type(grid_out)}")

        # check that input arrays are valid point clouds in 2D
        assert len(grid_in.shape) == 2
        assert len(grid_out.shape) == 2
        assert len(quad_weights.shape) == 1
        assert grid_in.shape[0] == 2
        assert grid_out.shape[0] == 2

        self.n_in = grid_in.shape[-1]
        self.n_out = grid_out.shape[-1]

        # compute the cutoff radius based on the bandlimit of the input field
        # TODO: this heuristic is ad-hoc! Verify that we do the right one
        if radius_cutoff is None:
            radius_cutoff = (
                2 * (self.kernel_shape[0] + 1) / float(math.sqrt(self.n_in) - 1)
            )

        if radius_cutoff <= 0.0:
            raise ValueError("Error, radius_cutoff has to be positive.")

        # integration weights
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        idx, vals = _precompute_convolution_tensor_2d(
            grid_in,
            grid_out,
            self.kernel_shape,
            radius_cutoff=radius_cutoff,
            periodic=periodic,
        )

        # to improve performance, we make psi a matrix by merging the first two dimensions
        # (row = kernel * n_out + output point); this is undone in the forward pass
        idx = torch.stack([idx[0] * self.n_out + idx[1], idx[2]], dim=0)

        self.register_buffer("psi_idx", idx.contiguous(), persistent=False)
        self.register_buffer("psi_vals", vals.contiguous(), persistent=False)

    def get_psi(self):
        """Assemble the sparse filter matrix of shape (kernel_size * n_out, n_in)."""
        psi = torch.sparse_coo_tensor(
            self.psi_idx, self.psi_vals, size=(self.kernel_size * self.n_out, self.n_in)
        )
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input of shape (batch, in_channels, n_in).

        Returns:
            Tensor of shape (batch, out_channels, n_out).
        """
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        # extract shape
        B, C, _ = x.shape

        # bring into the right shape for the sparse matmul and perform it
        x = x.reshape(B * C, self.n_in).permute(1, 0).contiguous()
        x = torch.mm(psi, x)
        x = x.permute(1, 0).reshape(B, C, self.kernel_size, self.n_out)
        x = x.reshape(B, self.groups, self.groupsize, self.kernel_size, self.n_out)

        # do weight multiplication
        out = torch.einsum(
            "bgckx,gock->bgox",
            x,
            self.weight.reshape(
                self.groups, -1, self.weight.shape[1], self.weight.shape[2]
            ),
        )
        out = out.reshape(out.shape[0], -1, out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1)

        return out
class DiscreteContinuousConvTranspose2d(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on arbitrary 2d grids.

    The input and output grids are either point clouds, given as tensors of shape (2, n_points),
    or grid names passed to ``_precompute_grid``. In the latter case ``n_in`` / ``n_out`` hold the
    number of points along each of the two axes, and a tensor-product grid is built from them.
    The filters are precomputed as for a convolution from the output grid to the input grid,
    and their transpose is applied.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        grid_in: Union[torch.Tensor, str],
        grid_out: Union[torch.Tensor, str],
        kernel_shape: Union[int, List[int]],
        n_in: Optional[Tuple[int]] = None,
        n_out: Optional[Tuple[int]] = None,
        quad_weights: Optional[torch.Tensor] = None,
        periodic: Optional[bool] = False,
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
        radius_cutoff: Optional[float] = None,
    ):
        """
        Args:
            grid_in: input point cloud of shape (2, n_in), or a grid name for _precompute_grid.
            grid_out: output point cloud of shape (2, n_out), or a grid name for _precompute_grid.
            kernel_shape: nr (isotropic) or [nr, nphi] (anisotropic).
            n_in: (nx, ny) points per axis; required when grid_in is a name.
            n_out: (nx, ny) points per axis; required when grid_out is a name.
            quad_weights: 1-D quadrature weights of the input points. Required when grid_in is a
                tensor; replaced by the product weights when grid_in is a name.
            periodic: only allowed with named grids; also enables unit-period wrapping of the
                coordinate differences when the filters are precomputed.
            radius_cutoff: radius of the filter support; defaults to
                2 * (kernel_shape[0] + 1) / (sqrt(n_in) - 1).

        Raises:
            ValueError: for unsupported grid argument types or a non-positive radius_cutoff.

        NOTE(review): named grids come from torch.from_numpy and keep NumPy's dtype (presumably
        float64). If so, cast the module (e.g. .float()) before feeding float32 inputs. Confirm.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        # the instantiator supports convenience constructors for the input and output grids
        if isinstance(grid_in, torch.Tensor):
            assert isinstance(quad_weights, torch.Tensor)
            assert not periodic
        elif isinstance(grid_in, str):
            assert n_in is not None
            assert len(n_in) == 2
            x, wx = _precompute_grid(n_in[0], grid=grid_in, periodic=periodic)
            y, wy = _precompute_grid(n_in[1], grid=grid_in, periodic=periodic)
            # explicit "ij" indexing (the first axis runs over n_in[0]); relying on the default
            # is deprecated and would silently reorder the flattened grid if it changes
            x, y = torch.meshgrid(
                torch.from_numpy(x), torch.from_numpy(y), indexing="ij"
            )
            wx, wy = torch.meshgrid(
                torch.from_numpy(wx), torch.from_numpy(wy), indexing="ij"
            )
            grid_in = torch.stack([x.reshape(-1), y.reshape(-1)])
            quad_weights = (wx * wy).reshape(-1)
        else:
            raise ValueError(f"Unknown grid input type of type {type(grid_in)}")

        if isinstance(grid_out, torch.Tensor):
            pass
        elif isinstance(grid_out, str):
            assert n_out is not None
            assert len(n_out) == 2
            x, wx = _precompute_grid(n_out[0], grid=grid_out, periodic=periodic)
            y, wy = _precompute_grid(n_out[1], grid=grid_out, periodic=periodic)
            x, y = torch.meshgrid(
                torch.from_numpy(x), torch.from_numpy(y), indexing="ij"
            )
            grid_out = torch.stack([x.reshape(-1), y.reshape(-1)])
        else:
            raise ValueError(f"Unknown grid output type of type {type(grid_out)}")

        # check that input arrays are valid point clouds in 2D
        assert len(grid_in.shape) == 2
        assert len(grid_out.shape) == 2
        assert len(quad_weights.shape) == 1
        assert grid_in.shape[0] == 2
        assert grid_out.shape[0] == 2

        self.n_in = grid_in.shape[-1]
        self.n_out = grid_out.shape[-1]

        # compute the cutoff radius based on the bandlimit of the input field
        # TODO: this heuristic is ad-hoc! Verify that we do the right one
        if radius_cutoff is None:
            radius_cutoff = (
                2 * (self.kernel_shape[0] + 1) / float(math.sqrt(self.n_in) - 1)
            )

        if radius_cutoff <= 0.0:
            raise ValueError("Error, radius_cutoff has to be positive.")

        # integration weights
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        # precompute the transposed tensor: grids are swapped, so the returned rows are
        # (kernel, input point of this layer, output point of this layer)
        idx, vals = _precompute_convolution_tensor_2d(
            grid_out,
            grid_in,
            self.kernel_shape,
            radius_cutoff=radius_cutoff,
            periodic=periodic,
        )

        # to improve performance, we make psi a matrix by merging the kernel and output-point
        # dimensions (row = kernel * n_out + output point); this is undone in the forward pass
        idx = torch.stack([idx[0] * self.n_out + idx[2], idx[1]], dim=0)

        self.register_buffer("psi_idx", idx.contiguous(), persistent=False)
        self.register_buffer("psi_vals", vals.contiguous(), persistent=False)

    def get_psi(self):
        """Assemble the sparse filter matrix of shape (kernel_size * n_out, n_in)."""
        psi = torch.sparse_coo_tensor(
            self.psi_idx, self.psi_vals, size=(self.kernel_size * self.n_out, self.n_in)
        )
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input of shape (batch, in_channels, n_in).

        Returns:
            Tensor of shape (batch, out_channels, n_out).
        """
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        # extract shape
        B, C, _ = x.shape

        # bring into the right shape for the sparse matmul and perform it
        x = x.reshape(B * C, self.n_in).permute(1, 0).contiguous()
        x = torch.mm(psi, x)
        x = x.permute(1, 0).reshape(B, C, self.kernel_size, self.n_out)
        x = x.reshape(B, self.groups, self.groupsize, self.kernel_size, self.n_out)

        # do weight multiplication
        out = torch.einsum(
            "bgckx,gock->bgox",
            x,
            self.weight.reshape(
                self.groups, -1, self.weight.shape[1], self.weight.shape[2]
            ),
        )
        out = out.reshape(out.shape[0], -1, out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1)

        return out
class EquidistantDiscreteContinuousConv2d(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on equidistant 2d grids.

    The DISCO basis functions are evaluated once on a small local stencil. At run time they are
    contracted with the learnable weights into a dense kernel, which is applied with
    ``F.conv2d`` after explicit padding.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        in_shape: Tuple[int],
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
        radius_cutoff: Optional[float] = None,
        padding_mode: str = "circular",
        use_min_dim: bool = True,
        **kwargs,
    ):
        """
        in_shape (tuple): (H, W) of the input field, assumed to discretize [-1, 1]^2.
        radius_cutoff (float, optional): radius of the filter support; defaults to
            2 * kernel_shape[0] / reference dimension.
        padding_mode (str, optional): mode passed to F.pad. Defaults to "circular".
        use_min_dim (bool, optional): Use the minimum dimension of the input
            shape to compute the cutoff radius. Otherwise use the maximum
            dimension. Defaults to True.
        **kwargs: accepted and ignored.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.padding_mode = padding_mode

        # compute the cutoff radius based on the assumption that the grid is [-1, 1]^2
        # this still assumes a quadratic domain
        ref_dim = min(*in_shape) if use_min_dim else max(*in_shape)

        if radius_cutoff is None:
            radius_cutoff = 2 * (self.kernel_shape[0]) / float(ref_dim)

        # number of local stencil points per axis covering [-radius_cutoff, radius_cutoff]
        self.psi_local_size = math.floor(2 * radius_cutoff * ref_dim / 2) + 1

        # psi_local is essentially the support of the hat functions evaluated locally
        x = torch.linspace(-radius_cutoff, radius_cutoff, self.psi_local_size)
        x, y = torch.meshgrid(x, x, indexing="ij")
        grid_in = torch.stack([x.reshape(-1), y.reshape(-1)])
        # a single output point at the origin
        grid_out = torch.zeros(2, 1)

        idx, vals = _precompute_convolution_tensor_2d(
            grid_in,
            grid_out,
            self.kernel_shape,
            radius_cutoff=radius_cutoff,
            periodic=False,
        )

        # Scatter the sparse COO entries into a dense (kernel_size, psi_local_size**2) table.
        # With a single output point each (kernel, input point) pair occurs at most once, so a
        # vectorized assignment is equivalent to assigning the entries one by one.
        psi_loc = torch.zeros(
            self.kernel_size, self.psi_local_size * self.psi_local_size
        )
        psi_loc[idx[0], idx[2]] = vals.to(psi_loc.dtype)

        # compute local version of the filter matrix
        psi_loc = psi_loc.reshape(
            self.kernel_size, self.psi_local_size, self.psi_local_size
        )

        # normalization by the quadrature weights (area 4 of [-1, 1]^2 over H * W points)
        psi_loc = 4.0 * psi_loc / float(in_shape[0] * in_shape[1])

        self.register_buffer("psi_loc", psi_loc, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input, expected of shape (batch, in_channels, H, W).

        Returns:
            Tensor of shape (batch, out_channels, H, W). The input is padded by a total of
            psi_local_size - 1 points per axis, so the spatial size is preserved.
        """
        # dense kernel of shape (out_channels, in_channels // groups, size, size)
        kernel = torch.einsum("kxy,ogk->ogxy", self.psi_loc, self.weight)

        left_pad = self.psi_local_size // 2
        right_pad = (self.psi_local_size + 1) // 2 - 1
        x = F.pad(x, (left_pad, right_pad, left_pad, right_pad), mode=self.padding_mode)

        out = F.conv2d(
            x, kernel, self.bias, stride=1, dilation=1, padding=0, groups=self.groups
        )

        return out