# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

# | |
import abc | |
import math | |
from functools import partial | |
from typing import List, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .quadrature import _precompute_grid, _precompute_latitudes | |
if torch.cuda.is_available(): | |
from ._disco_convolution import ( | |
_disco_s2_contraction_triton, | |
_disco_s2_transpose_contraction_triton, | |
) | |
def _compute_support_vals_isotropic(
    r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float, norm: str = "s2"
):
    """
    Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
    """

    # compute the support
    dr = (r_cutoff - 0.0) / nr
    ikernel = torch.arange(nr).reshape(-1, 1, 1)
    ir = ikernel * dr

    if norm == "none":
        norm_factor = 1.0
    elif norm == "2d":
        norm_factor = (
            math.pi * (r_cutoff * nr / (nr + 1)) ** 2
            + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
        )
    elif norm == "s2":
        norm_factor = (
            2
            * math.pi
            * (
                1
                - math.cos(r_cutoff - dr)
                + math.cos(r_cutoff - dr)
                + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr
            )
        )
    else:
        raise ValueError(f"Unknown normalization mode {norm}.")

    # find the indices where the rotated position falls into the support of the kernel
    iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))
    vals = (
        1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
    ) / norm_factor
    return iidx, vals
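
# A small worked check of the hat-function values above (illustrative only;
# shapes and numbers are assumptions, not part of the library API). Each kernel
# ring i is centered at radius i * dr and carries a linear hat of half-width dr:
#
#   r = torch.tensor([[0.0, 0.05, 0.1]])  # hypothetical radii on a 1 x 3 grid
#   phi = torch.zeros_like(r)             # unused by the isotropic kernel
#   iidx, vals = _compute_support_vals_isotropic(r, phi, nr=2, r_cutoff=0.2, norm="none")
#   # with dr = 0.1: r=0.0 lies on ring 0 (value 1), r=0.05 lies halfway between
#   # rings 0 and 1 (value 0.5 for each), and r=0.1 lies on ring 1 (value 1)
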
def _compute_support_vals_anisotropic(
    r: torch.Tensor,
    phi: torch.Tensor,
    nr: int,
    nphi: int,
    r_cutoff: float,
    norm: str = "s2",
):
    """
    Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
    """

    # compute the support
    dr = (r_cutoff - 0.0) / nr
    dphi = 2.0 * math.pi / nphi
    kernel_size = (nr - 1) * nphi + 1
    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
    ir = ((ikernel - 1) // nphi + 1) * dr
    iphi = ((ikernel - 1) % nphi) * dphi

    if norm == "none":
        norm_factor = 1.0
    elif norm == "2d":
        norm_factor = (
            math.pi * (r_cutoff * nr / (nr + 1)) ** 2
            + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3
        )
    elif norm == "s2":
        norm_factor = (
            2
            * math.pi
            * (
                1
                - math.cos(r_cutoff - dr)
                + math.cos(r_cutoff - dr)
                + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr
            )
        )
    else:
        raise ValueError(f"Unknown normalization mode {norm}.")

    # find the indices where the rotated position falls into the support of the kernel
    cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
    cond_phi = (
        (ikernel == 0)
        | ((phi - iphi).abs() <= dphi)
        | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
    )
    iidx = torch.argwhere(cond_r & cond_phi)
    vals = (
        1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr
    ) / norm_factor
    vals *= torch.where(
        iidx[:, 0] > 0,
        (
            1
            - torch.minimum(
                (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(),
                (
                    2 * math.pi
                    - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
                ),
            )
            / dphi
        ),
        1.0,
    )
    return iidx, vals
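
# Layout note (illustrative, for nr=3, nphi=4): the anisotropic basis consists
# of one rotation-invariant center element (ikernel == 0) followed by (nr - 1)
# radial rings with nphi angular offsets each, so kernel_size = (nr - 1) * nphi + 1 = 9:
#
#   ikernel:  0 | 1    2     3      4      | 5    6     7      8
#   ir:       0 | dr   dr    dr     dr     | 2dr  2dr   2dr    2dr
#   iphi:     - | 0    dphi  2dphi  3dphi  | 0    dphi  2dphi  3dphi
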
def _precompute_convolution_tensor_s2(
    in_shape,
    out_shape,
    kernel_shape,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the north pole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        \begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma) - \cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}
    $$
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    if len(kernel_shape) == 1:
        kernel_handle = partial(
            _compute_support_vals_isotropic,
            nr=kernel_shape[0],
            r_cutoff=theta_cutoff,
            norm="s2",
        )
    elif len(kernel_shape) == 2:
        kernel_handle = partial(
            _compute_support_vals_anisotropic,
            nr=kernel_shape[0],
            nphi=kernel_shape[1],
            r_cutoff=theta_cutoff,
            norm="s2",
        )
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # arrays for accumulating the non-zero indices and values
    out_idx = torch.empty([3, 0], dtype=torch.long)
    out_vals = torch.empty([0], dtype=torch.float32)

    # compute the phi differences
    # it is important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)

        # compute Cartesian coordinates of the rotated position
        # this uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation
        # and is therefore applied with a negative sign
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(
            alpha
        ) * torch.cos(gamma)
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(
            gamma
        ) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x) + torch.pi

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = kernel_handle(theta, phi)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack(
            [
                iidx[:, 0],
                t * torch.ones_like(iidx[:, 0]),
                iidx[:, 1] * nlon_in + iidx[:, 2],
            ],
            dim=0,
        )

        # append indices and values to the COO datastructure
        out_idx = torch.cat([out_idx, idx], dim=-1)
        out_vals = torch.cat([out_vals, vals], dim=-1)

    return out_idx, out_vals
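
# A minimal sketch of assembling the returned COO data into the sparse filter
# tensor consumed by the contractions below (sizes are illustrative assumptions;
# get_psi() on the S2 conv classes does exactly this):
#
#   idx, vals = _precompute_convolution_tensor_s2(
#       (32, 64), (16, 32), [3], theta_cutoff=4 * math.pi / 31
#   )
#   psi = torch.sparse_coo_tensor(idx, vals, size=(3, 16, 32 * 64)).coalesce()
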
def _precompute_convolution_tensor_2d(
    grid_in, grid_out, kernel_shape, radius_cutoff=0.01, periodic=False
):
    r"""
    Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i \nu$. Similar to the S2 routine,
    except that it assumes a (by default non-periodic) subset of the Euclidean plane.
    """

    # check that input arrays are valid point clouds in 2D
    assert len(grid_in) == 2
    assert len(grid_out) == 2
    assert grid_in.shape[0] == 2
    assert grid_out.shape[0] == 2

    n_in = grid_in.shape[-1]
    n_out = grid_out.shape[-1]

    if len(kernel_shape) == 1:
        kernel_handle = partial(
            _compute_support_vals_isotropic,
            nr=kernel_shape[0],
            r_cutoff=radius_cutoff,
            norm="2d",
        )
    elif len(kernel_shape) == 2:
        kernel_handle = partial(
            _compute_support_vals_anisotropic,
            nr=kernel_shape[0],
            nphi=kernel_shape[1],
            r_cutoff=radius_cutoff,
            norm="2d",
        )
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    grid_in = grid_in.reshape(2, 1, n_in)
    grid_out = grid_out.reshape(2, n_out, 1)

    diffs = grid_in - grid_out
    if periodic:
        periodic_diffs = torch.where(diffs > 0.0, diffs - 1, diffs + 1)
        diffs = torch.where(diffs.abs() < periodic_diffs.abs(), diffs, periodic_diffs)

    r = torch.sqrt(diffs[0] ** 2 + diffs[1] ** 2)
    phi = torch.arctan2(diffs[1], diffs[0]) + torch.pi

    idx, vals = kernel_handle(r, phi)
    idx = idx.permute(1, 0)

    return idx, vals
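
# A minimal sketch of precomputing filters on a small Cartesian point cloud
# (grid values are illustrative assumptions):
#
#   xs = torch.linspace(0.0, 1.0, 8)
#   gx, gy = torch.meshgrid(xs, xs, indexing="ij")
#   grid = torch.stack([gx.reshape(-1), gy.reshape(-1)])  # shape (2, 64)
#   idx, vals = _precompute_convolution_tensor_2d(grid, grid, [2], radius_cutoff=0.3)
#   # idx has shape (3, nnz): kernel index, output point index, input point index
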
class DiscreteContinuousConv(nn.Module, abc.ABC):
    """
    Abstract base class for DISCO convolutions
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        if isinstance(kernel_shape, int):
            self.kernel_shape = [kernel_shape]
        else:
            self.kernel_shape = kernel_shape

        if len(self.kernel_shape) == 1:
            self.kernel_size = self.kernel_shape[0]
        elif len(self.kernel_shape) == 2:
            self.kernel_size = (self.kernel_shape[0] - 1) * self.kernel_shape[1] + 1
        else:
            raise ValueError("kernel_shape should be either one- or two-dimensional.")

        # groups
        self.groups = groups

        # weight tensor
        if in_channels % self.groups != 0:
            raise ValueError(
                "The number of input channels has to be an integer multiple of the number of groups."
            )
        if out_channels % self.groups != 0:
            raise ValueError(
                "The number of output channels has to be an integer multiple of the number of groups."
            )
        self.groupsize = in_channels // self.groups

        scale = math.sqrt(1.0 / self.groupsize)
        self.weight = nn.Parameter(
            scale * torch.randn(out_channels, self.groupsize, self.kernel_size)
        )

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    def forward(self, x: torch.Tensor):
        raise NotImplementedError

def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in Triton.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 4
    psi = psi.to(x.device)

    batch_size, n_chans, nlat_in, nlon_in = x.shape
    kernel_size, nlat_out, _ = psi.shape

    assert psi.shape[-1] == nlat_in * nlon_in
    assert nlon_in % nlon_out == 0
    assert nlon_in >= nlat_out
    pscale = nlon_in // nlon_out

    # add a dummy dimension for nkernel and move the batch and channel dims to the end
    x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
    x = x.expand(kernel_size, -1, -1, -1)

    y = torch.zeros(
        nlon_out,
        kernel_size,
        nlat_out,
        batch_size * n_chans,
        device=x.device,
        dtype=x.dtype,
    )

    for pout in range(nlon_out):
        # sparse contraction with psi
        y[pout] = torch.bmm(psi, x.reshape(kernel_size, nlat_in * nlon_in, -1))
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        x = torch.roll(x, -pscale, dims=2)

    # reshape y back to expose the correct dimensions
    y = y.permute(3, 1, 2, 0).reshape(
        batch_size, n_chans, kernel_size, nlat_out, nlon_out
    )

    return y
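
# Shape sketch for the reference contraction above (sizes are illustrative
# assumptions): with x of shape (B, C, nlat_in, nlon_in) = (2, 4, 32, 64) and
# psi of shape (kernel_size, nlat_out, nlat_in * nlon_in) = (3, 16, 2048),
#
#   y = _disco_s2_contraction_torch(x, psi, nlon_out=32)
#   assert y.shape == (2, 4, 3, 16, 32)
#
# The loop rolls x by pscale = nlon_in // nlon_out = 2 pixels per output
# longitude, so each output longitude sees the filter shifted accordingly.
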
def _disco_s2_transpose_contraction_torch(
    x: torch.Tensor, psi: torch.Tensor, nlon_out: int
):
    """
    Reference implementation of the custom contraction as described in [1]. This requires repeated
    shifting of the input tensor, which can potentially be costly. For an efficient implementation
    on GPU, make sure to use the custom kernel written in Triton.
    """
    assert len(psi.shape) == 3
    assert len(x.shape) == 5
    psi = psi.to(x.device)

    batch_size, n_chans, kernel_size, nlat_in, nlon_in = x.shape
    kernel_size, _, n_out = psi.shape

    assert psi.shape[-2] == nlat_in
    assert n_out % nlon_out == 0
    nlat_out = n_out // nlon_out
    assert nlon_out >= nlat_in
    pscale = nlon_out // nlon_in

    # we do a semi-transposition to facilitate the computation
    inz = psi.indices()
    tout = inz[2] // nlon_out
    pout = inz[2] % nlon_out
    # flip the axis of longitudes
    pout = nlon_out - 1 - pout
    tin = inz[1]
    inz = torch.stack([inz[0], tout, tin * nlon_out + pout], dim=0)
    psi_mod = torch.sparse_coo_tensor(
        inz, psi.values(), size=(kernel_size, nlat_out, nlat_in * nlon_out)
    )

    # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
    x_ext = torch.zeros(
        kernel_size,
        nlat_in,
        nlon_out,
        batch_size * n_chans,
        device=x.device,
        dtype=x.dtype,
    )
    x_ext[:, :, ::pscale, :] = x.reshape(
        batch_size * n_chans, kernel_size, nlat_in, nlon_in
    ).permute(1, 2, 3, 0)
    # the longitude axis was already flipped above via pout, so the rolled reads below
    # effectively traverse it backwards; make the tensor contiguous for the bmm
    x_ext = x_ext.contiguous()

    y = torch.zeros(
        kernel_size,
        nlon_out,
        nlat_out,
        batch_size * n_chans,
        device=x.device,
        dtype=x.dtype,
    )

    for pout in range(nlon_out):
        # we need to repeatedly roll the input tensor to facilitate the shifted multiplication
        # TODO: double-check why this has to happen first
        x_ext = torch.roll(x_ext, -1, dims=2)
        # sparse contraction with the modified psi
        y[:, pout, :, :] = torch.bmm(
            psi_mod, x_ext.reshape(kernel_size, nlat_in * nlon_out, -1)
        )

    # sum over the kernel dimension and reshape to the correct output size
    y = y.sum(dim=0).permute(2, 1, 0).reshape(batch_size, n_chans, nlat_out, nlon_out)

    return y
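
# Shape sketch for the transpose contraction (sizes are illustrative
# assumptions): with x of shape (B, C, kernel_size, nlat_in, nlon_in)
# = (2, 4, 3, 16, 32) and psi of shape (3, 16, 32 * 64) built on the
# swapped grids,
#
#   y = _disco_s2_transpose_contraction_torch(x, psi, nlon_out=64)
#   assert y.shape == (2, 4, 32, 64)
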
class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # compute theta cutoff based on the bandlimit of the input field
        if theta_cutoff is None:
            theta_cutoff = (
                (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
            )

        if theta_cutoff <= 0.0:
            raise ValueError("theta_cutoff has to be positive.")

        # integration weights
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        quad_weights = (
            2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        )
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        idx, vals = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.kernel_shape,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
        )

        self.register_buffer("psi_idx", idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def get_psi(self):
        psi = torch.sparse_coo_tensor(
            self.psi_idx,
            self.psi_vals,
            size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in),
        ).coalesce()
        return psi

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        if x.is_cuda and use_triton_kernel:
            x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
        else:
            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication
        out = torch.einsum(
            "bgckxy,gock->bgoxy",
            x,
            self.weight.reshape(
                self.groups, -1, self.weight.shape[1], self.weight.shape[2]
            ),
        )
        out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out
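
# A minimal usage sketch (all sizes are illustrative assumptions):
#
#   conv = DiscreteContinuousConvS2(
#       in_channels=4, out_channels=8, in_shape=(32, 64), out_shape=(16, 32),
#       kernel_shape=[2, 3],
#   )
#   x = torch.randn(2, 4, 32, 64)   # batch x channels x nlat x nlon
#   y = conv(x)                     # -> (2, 8, 16, 32)
#
# On CUDA tensors the Triton kernel is used unless use_triton_kernel=False.
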
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # compute theta cutoff based on the bandlimit of the input field
        if theta_cutoff is None:
            theta_cutoff = (
                (self.kernel_shape[0] + 1) * torch.pi / float(self.nlat_in - 1)
            )

        if theta_cutoff <= 0.0:
            raise ValueError("theta_cutoff has to be positive.")

        # integration weights
        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
        quad_weights = (
            2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
        )
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        # switch in_shape and out_shape since we want the transpose convolution
        idx, vals = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.kernel_shape,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
        )

        self.register_buffer("psi_idx", idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def get_psi(self):
        psi = torch.sparse_coo_tensor(
            self.psi_idx,
            self.psi_vals,
            size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out),
        ).coalesce()
        return psi

    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication
        x = torch.einsum(
            "bgcxy,gock->bgokxy",
            x,
            self.weight.reshape(
                self.groups, -1, self.weight.shape[1], self.weight.shape[2]
            ),
        )
        x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])

        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        if x.is_cuda and use_triton_kernel:
            out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
        else:
            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out
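
# The transpose variant upsamples; a minimal sketch (sizes are illustrative
# assumptions):
#
#   convT = DiscreteContinuousConvTransposeS2(
#       in_channels=8, out_channels=4, in_shape=(16, 32), out_shape=(32, 64),
#       kernel_shape=[2, 3],
#   )
#   y = convT(torch.randn(2, 8, 16, 32))  # -> (2, 4, 32, 64)
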
class DiscreteContinuousConv2d(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on arbitrary 2d grids as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        grid_in: torch.Tensor,
        grid_out: torch.Tensor,
        kernel_shape: Union[int, List[int]],
        n_in: Optional[Tuple[int]] = None,
        n_out: Optional[Tuple[int]] = None,
        quad_weights: Optional[torch.Tensor] = None,
        periodic: Optional[bool] = False,
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
        radius_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        # as a convenience, the grids can be passed either as explicit point clouds
        # or as strings naming a precomputed grid type
        if isinstance(grid_in, torch.Tensor):
            assert isinstance(quad_weights, torch.Tensor)
            assert not periodic
        elif isinstance(grid_in, str):
            assert n_in is not None
            assert len(n_in) == 2
            x, wx = _precompute_grid(n_in[0], grid=grid_in, periodic=periodic)
            y, wy = _precompute_grid(n_in[1], grid=grid_in, periodic=periodic)
            x, y = torch.meshgrid(
                torch.from_numpy(x), torch.from_numpy(y), indexing="ij"
            )
            wx, wy = torch.meshgrid(
                torch.from_numpy(wx), torch.from_numpy(wy), indexing="ij"
            )
            grid_in = torch.stack([x.reshape(-1), y.reshape(-1)]).float()
            quad_weights = (wx * wy).reshape(-1).float()
        else:
            raise ValueError(f"Unknown grid input type of type {type(grid_in)}")

        if isinstance(grid_out, torch.Tensor):
            pass
        elif isinstance(grid_out, str):
            assert n_out is not None
            assert len(n_out) == 2
            x, wx = _precompute_grid(n_out[0], grid=grid_out, periodic=periodic)
            y, wy = _precompute_grid(n_out[1], grid=grid_out, periodic=periodic)
            x, y = torch.meshgrid(
                torch.from_numpy(x), torch.from_numpy(y), indexing="ij"
            )
            grid_out = torch.stack([x.reshape(-1), y.reshape(-1)]).float()
        else:
            raise ValueError(f"Unknown grid output type of type {type(grid_out)}")

        # check that input arrays are valid point clouds in 2D
        assert len(grid_in.shape) == 2
        assert len(grid_out.shape) == 2
        assert len(quad_weights.shape) == 1
        assert grid_in.shape[0] == 2
        assert grid_out.shape[0] == 2

        self.n_in = grid_in.shape[-1]
        self.n_out = grid_out.shape[-1]

        # compute the cutoff radius based on the bandlimit of the input field
        # TODO: this heuristic is ad-hoc! Verify that we do the right one
        if radius_cutoff is None:
            radius_cutoff = (
                2 * (self.kernel_shape[0] + 1) / float(math.sqrt(self.n_in) - 1)
            )

        if radius_cutoff <= 0.0:
            raise ValueError("radius_cutoff has to be positive.")

        # integration weights
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        idx, vals = _precompute_convolution_tensor_2d(
            grid_in,
            grid_out,
            self.kernel_shape,
            radius_cutoff=radius_cutoff,
            periodic=periodic,
        )

        # to improve performance, we make psi a matrix by merging the first two dimensions
        # this has to be accounted for in the forward pass
        idx = torch.stack([idx[0] * self.n_out + idx[1], idx[2]], dim=0)

        self.register_buffer("psi_idx", idx.contiguous(), persistent=False)
        self.register_buffer("psi_vals", vals.contiguous(), persistent=False)

    def get_psi(self):
        psi = torch.sparse_coo_tensor(
            self.psi_idx, self.psi_vals, size=(self.kernel_size * self.n_out, self.n_in)
        )
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        # extract shape
        B, C, _ = x.shape

        # bring x into the right shape for the matrix multiplication and perform it
        x = x.reshape(B * C, self.n_in).permute(1, 0).contiguous()
        x = torch.mm(psi, x)
        x = x.permute(1, 0).reshape(B, C, self.kernel_size, self.n_out)
        x = x.reshape(B, self.groups, self.groupsize, self.kernel_size, self.n_out)

        # do weight multiplication
        out = torch.einsum(
            "bgckx,gock->bgox",
            x,
            self.weight.reshape(
                self.groups, -1, self.weight.shape[1], self.weight.shape[2]
            ),
        )
        out = out.reshape(out.shape[0], -1, out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1)

        return out
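
# A minimal usage sketch with string-based grids (sizes are illustrative and
# "equidistant" is assumed to be a grid type supported by _precompute_grid):
#
#   conv = DiscreteContinuousConv2d(
#       in_channels=3, out_channels=8,
#       grid_in="equidistant", grid_out="equidistant",
#       kernel_shape=[2, 4], n_in=(16, 16), n_out=(16, 16),
#   )
#   y = conv(torch.randn(2, 3, 16 * 16))  # -> (2, 8, 256)
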
class DiscreteContinuousConvTranspose2d(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on arbitrary 2d grids as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        grid_in: torch.Tensor,
        grid_out: torch.Tensor,
        kernel_shape: Union[int, List[int]],
        n_in: Optional[Tuple[int]] = None,
        n_out: Optional[Tuple[int]] = None,
        quad_weights: Optional[torch.Tensor] = None,
        periodic: Optional[bool] = False,
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
        radius_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        # as a convenience, the grids can be passed either as explicit point clouds
        # or as strings naming a precomputed grid type
        if isinstance(grid_in, torch.Tensor):
            assert isinstance(quad_weights, torch.Tensor)
            assert not periodic
        elif isinstance(grid_in, str):
            assert n_in is not None
            assert len(n_in) == 2
            x, wx = _precompute_grid(n_in[0], grid=grid_in, periodic=periodic)
            y, wy = _precompute_grid(n_in[1], grid=grid_in, periodic=periodic)
            x, y = torch.meshgrid(
                torch.from_numpy(x), torch.from_numpy(y), indexing="ij"
            )
            wx, wy = torch.meshgrid(
                torch.from_numpy(wx), torch.from_numpy(wy), indexing="ij"
            )
            grid_in = torch.stack([x.reshape(-1), y.reshape(-1)]).float()
            quad_weights = (wx * wy).reshape(-1).float()
        else:
            raise ValueError(f"Unknown grid input type of type {type(grid_in)}")

        if isinstance(grid_out, torch.Tensor):
            pass
        elif isinstance(grid_out, str):
            assert n_out is not None
            assert len(n_out) == 2
            x, wx = _precompute_grid(n_out[0], grid=grid_out, periodic=periodic)
            y, wy = _precompute_grid(n_out[1], grid=grid_out, periodic=periodic)
            x, y = torch.meshgrid(
                torch.from_numpy(x), torch.from_numpy(y), indexing="ij"
            )
            grid_out = torch.stack([x.reshape(-1), y.reshape(-1)]).float()
        else:
            raise ValueError(f"Unknown grid output type of type {type(grid_out)}")

        # check that input arrays are valid point clouds in 2D
        assert len(grid_in.shape) == 2
        assert len(grid_out.shape) == 2
        assert len(quad_weights.shape) == 1
        assert grid_in.shape[0] == 2
        assert grid_out.shape[0] == 2

        self.n_in = grid_in.shape[-1]
        self.n_out = grid_out.shape[-1]

        # compute the cutoff radius based on the bandlimit of the input field
        # TODO: this heuristic is ad-hoc! Verify that we do the right one
        if radius_cutoff is None:
            radius_cutoff = (
                2 * (self.kernel_shape[0] + 1) / float(math.sqrt(self.n_in) - 1)
            )

        if radius_cutoff <= 0.0:
            raise ValueError("radius_cutoff has to be positive.")

        # integration weights
        self.register_buffer("quad_weights", quad_weights, persistent=False)

        # precompute the transposed tensor by switching grid_in and grid_out
        idx, vals = _precompute_convolution_tensor_2d(
            grid_out,
            grid_in,
            self.kernel_shape,
            radius_cutoff=radius_cutoff,
            periodic=periodic,
        )

        # to improve performance, we make psi a matrix by merging the first two dimensions
        # this has to be accounted for in the forward pass
        idx = torch.stack([idx[0] * self.n_out + idx[2], idx[1]], dim=0)

        self.register_buffer("psi_idx", idx.contiguous(), persistent=False)
        self.register_buffer("psi_vals", vals.contiguous(), persistent=False)

    def get_psi(self):
        psi = torch.sparse_coo_tensor(
            self.psi_idx, self.psi_vals, size=(self.kernel_size * self.n_out, self.n_in)
        )
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pre-multiply x with the quadrature weights
        x = self.quad_weights * x

        psi = self.get_psi()

        # extract shape
        B, C, _ = x.shape

        # bring x into the right shape for the matrix multiplication and perform it
        x = x.reshape(B * C, self.n_in).permute(1, 0).contiguous()
        x = torch.mm(psi, x)
        x = x.permute(1, 0).reshape(B, C, self.kernel_size, self.n_out)
        x = x.reshape(B, self.groups, self.groupsize, self.kernel_size, self.n_out)

        # do weight multiplication
        out = torch.einsum(
            "bgckx,gock->bgox",
            x,
            self.weight.reshape(
                self.groups, -1, self.weight.shape[1], self.weight.shape[2]
            ),
        )
        out = out.reshape(out.shape[0], -1, out.shape[-1])

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1)

        return out

class EquidistantDiscreteContinuousConv2d(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on equidistant 2d grids, realized as a
    conventional conv2d with locally evaluated DISCO filters.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        in_shape: Tuple[int],
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
        radius_cutoff: Optional[float] = None,
        padding_mode: str = "circular",
        use_min_dim: bool = True,
        **kwargs,
    ):
        """
        use_min_dim (bool, optional): Use the minimum dimension of the input
            shape to compute the cutoff radius. Otherwise use the maximum
            dimension. Defaults to True.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.padding_mode = padding_mode

        # compute the cutoff radius based on the assumption that the grid is [-1, 1]^2
        # this still assumes a quadratic domain
        dim_fn = min if use_min_dim else max
        if radius_cutoff is None:
            radius_cutoff = 2 * (self.kernel_shape[0]) / float(dim_fn(*in_shape))

        # local filter support in pixels: the cutoff radius covers
        # radius_cutoff * dim / 2 pixels on either side of the center
        self.psi_local_size = math.floor(2 * radius_cutoff * dim_fn(*in_shape) / 2) + 1

        # psi_local is essentially the support of the hat functions evaluated locally
        x = torch.linspace(-radius_cutoff, radius_cutoff, self.psi_local_size)
        x, y = torch.meshgrid(x, x, indexing="ij")
        grid_in = torch.stack([x.reshape(-1), y.reshape(-1)])
        grid_out = torch.tensor([[0.0], [0.0]])

        idx, vals = _precompute_convolution_tensor_2d(
            grid_in,
            grid_out,
            self.kernel_shape,
            radius_cutoff=radius_cutoff,
            periodic=False,
        )

        psi_loc = torch.zeros(
            self.kernel_size, self.psi_local_size * self.psi_local_size
        )
        for ie in range(len(vals)):
            k = idx[0, ie]
            j = idx[2, ie]
            psi_loc[k, j] = vals[ie]

        # compute local version of the filter matrix
        psi_loc = psi_loc.reshape(
            self.kernel_size, self.psi_local_size, self.psi_local_size
        )

        # normalization by the quadrature weights
        psi_loc = 4.0 * psi_loc / float(in_shape[0] * in_shape[1])

        self.register_buffer("psi_loc", psi_loc, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        kernel = torch.einsum("kxy,ogk->ogxy", self.psi_loc, self.weight)
        left_pad = self.psi_local_size // 2
        right_pad = (self.psi_local_size + 1) // 2 - 1
        x = F.pad(x, (left_pad, right_pad, left_pad, right_pad), mode=self.padding_mode)
        out = F.conv2d(
            x, kernel, self.bias, stride=1, dilation=1, padding=0, groups=self.groups
        )
        return out
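
# A minimal smoke test (illustrative; shapes and values are assumptions, not
# part of the library's test suite):
if __name__ == "__main__":
    conv = EquidistantDiscreteContinuousConv2d(
        in_channels=3, out_channels=8, kernel_shape=[2, 4], in_shape=(32, 32)
    )
    x = torch.randn(1, 3, 32, 32)
    y = conv(x)
    # circular padding keeps the spatial resolution unchanged
    assert y.shape == (1, 8, 32, 32)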