nomri / fastmri /math_utils.py
samaonline
init
1b34a12
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import numpy as np
import torch
def complex_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Complex multiplication.
Multiplies two complex tensors assuming that they are both stored as
real arrays with the last dimension being the complex dimension.
Parameters
----------
x : torch.Tensor
A PyTorch tensor with the last dimension of size 2.
y : torch.Tensor
A PyTorch tensor with the last dimension of size 2.
Returns
-------
torch.Tensor
A PyTorch tensor with the last dimension of size 2, representing
the result of the complex multiplication.
"""
if not x.shape[-1] == y.shape[-1] == 2:
raise ValueError("Tensors do not have separate complex dim.")
re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
return torch.stack((re, im), dim=-1)
def complex_conj(x: torch.Tensor) -> torch.Tensor:
"""
Complex conjugate.
Applies the complex conjugate assuming that the input array has the
last dimension as the complex dimension.
Parameters
----------
x : torch.Tensor
A PyTorch tensor with the last dimension of size 2.
Returns
-------
torch.Tensor
A PyTorch tensor with the last dimension of size 2, representing
the complex conjugate of the input tensor.
"""
if not x.shape[-1] == 2:
raise ValueError("Tensor does not have separate complex dim.")
return torch.stack((x[..., 0], -x[..., 1]), dim=-1)
def complex_abs(data: torch.Tensor) -> torch.Tensor:
"""
Compute the absolute value of a complex-valued input tensor.
Parameters
----------
data : torch.Tensor
A complex-valued tensor, where the size of the final dimension
should be 2.
Returns
-------
torch.Tensor
Absolute value of the input tensor.
"""
if not data.shape[-1] == 2:
raise ValueError("Tensor does not have separate complex dim.")
return (data**2).sum(dim=-1).sqrt()
def complex_abs_sq(data: torch.Tensor) -> torch.Tensor:
"""
Compute the squared absolute value of a complex tensor.
Parameters
----------
data : torch.Tensor
A complex-valued tensor, where the size of the final dimension
should be 2.
Returns
-------
torch.Tensor
Squared absolute value of the input tensor.
"""
if not data.shape[-1] == 2:
raise ValueError("Tensor does not have separate complex dim.")
return (data**2).sum(dim=-1)
def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray:
"""
Convert a complex PyTorch tensor to a NumPy array.
Parameters
----------
data : torch.Tensor
Input data to be converted to a NumPy array.
Returns
-------
np.ndarray
A complex NumPy array version of the input tensor.
"""
return torch.view_as_complex(data).numpy()