Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,149 Bytes
1b34a12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import numpy as np
import torch
def complex_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Multiply two complex tensors stored in real ``(..., 2)`` form.

    Each operand keeps its real part at index 0 and its imaginary part at
    index 1 of the final dimension. Leading dimensions are combined using
    ordinary PyTorch broadcasting.

    Parameters
    ----------
    x : torch.Tensor
        First factor; its last dimension must have size 2.
    y : torch.Tensor
        Second factor; its last dimension must have size 2.

    Returns
    -------
    torch.Tensor
        The product ``(a + ib)(c + id) = (ac - bd) + i(ad + bc)``, stored
        with a trailing dimension of size 2.

    Raises
    ------
    ValueError
        If the last dimension of either input is not of size 2.
    """
    # Read both trailing sizes up front so they are evaluated in a fixed order.
    x_last, y_last = x.shape[-1], y.shape[-1]
    if x_last != 2 or y_last != 2:
        raise ValueError("Tensors do not have separate complex dim.")

    a, b = x[..., 0], x[..., 1]
    c, d = y[..., 0], y[..., 1]

    real_part = a * c - b * d
    imag_part = a * d + b * c
    return torch.stack([real_part, imag_part], dim=-1)
def complex_conj(x: torch.Tensor) -> torch.Tensor:
    """
    Return the complex conjugate of a tensor stored in real ``(..., 2)`` form.

    The real part (index 0 of the last dimension) is kept as-is. The
    imaginary part (index 1) is negated.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor; its last dimension must have size 2.

    Returns
    -------
    torch.Tensor
        A new tensor of the same shape holding ``a - ib`` for each ``a + ib``.

    Raises
    ------
    ValueError
        If the last dimension of ``x`` is not of size 2.
    """
    if x.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")

    real, imag = x.unbind(dim=-1)
    return torch.stack((real, imag.neg()), dim=-1)
def complex_abs(data: torch.Tensor) -> torch.Tensor:
    """
    Compute the magnitude of a complex tensor stored in real ``(..., 2)`` form.

    Parameters
    ----------
    data : torch.Tensor
        Input tensor; its last dimension must have size 2, holding the real
        and imaginary parts.

    Returns
    -------
    torch.Tensor
        ``sqrt(re**2 + im**2)`` for each element, with the last dimension
        reduced away.

    Raises
    ------
    ValueError
        If the last dimension of ``data`` is not of size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")

    squared_magnitude = data.pow(2).sum(dim=-1)
    return torch.sqrt(squared_magnitude)
def complex_abs_sq(data: torch.Tensor) -> torch.Tensor:
    """
    Compute the squared magnitude of a complex tensor in real ``(..., 2)`` form.

    Parameters
    ----------
    data : torch.Tensor
        Input tensor; its last dimension must have size 2, holding the real
        and imaginary parts.

    Returns
    -------
    torch.Tensor
        ``re**2 + im**2`` for each element, with the last dimension reduced
        away.

    Raises
    ------
    ValueError
        If the last dimension of ``data`` is not of size 2.
    """
    if data.shape[-1] != 2:
        raise ValueError("Tensor does not have separate complex dim.")

    return torch.sum(data.square(), dim=-1)
def tensor_to_complex_np(data: torch.Tensor) -> np.ndarray:
    """
    Convert a complex PyTorch tensor to a NumPy array.

    The input stores complex numbers in real form: the real part is at
    index 0 and the imaginary part at index 1 of the last dimension.

    Parameters
    ----------
    data : torch.Tensor
        Real-valued tensor whose last dimension has size 2. It may require
        gradients, live on a non-CPU device, or be non-contiguous.

    Returns
    -------
    np.ndarray
        Complex NumPy array with the trailing size-2 dimension removed.
        When ``data`` is already a contiguous CPU tensor, the array shares
        memory with ``data``. Otherwise it is backed by a copy.

    Raises
    ------
    ValueError
        If the last dimension of ``data`` is not of size 2.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")
    # ``Tensor.numpy()`` refuses tensors that require grad or are not on the
    # CPU, and ``torch.view_as_complex`` needs a unit stride on the last
    # dimension. Each call below returns the same storage when its
    # requirement is already met, so the common case stays zero-copy.
    data = data.detach().cpu().contiguous()
    return torch.view_as_complex(data).numpy()
|