from typing import Optional, Tuple, Union

import torch
from torch import Tensor, nn, sigmoid, tanh

|
class ConvGate(nn.Module):
    """Fused convolution producing the four ConvLSTM gate pre-activations.

    Both convolutions emit 4 * hidden_channels feature maps: one block of
    hidden_channels per gate (input, forget, cell, output).
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: Union[Tuple[int, int], int],
        padding: Union[Tuple[int, int], int],
        stride: Union[Tuple[int, int], int],
        bias: bool,
    ):
        super().__init__()
        self.conv_x = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_channels * 4,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        self.conv_h = nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=hidden_channels * 4,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        self.bn2d = nn.BatchNorm2d(hidden_channels * 4)

    def forward(self, x: Tensor, hidden_state: Tensor) -> Tensor:
        # Sum the input and recurrent contributions, then batch-normalise.
        gated = self.conv_x(x) + self.conv_h(hidden_state)
        return self.bn2d(gated)
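
# Shape note (a sketch, not from the original source): the recurrence feeds
# h_t back into conv_h at the next step, so the convolutions must preserve
# spatial size (e.g. stride=1 with "same" padding); otherwise the hidden
# state shrinks every time step and the gate sum no longer matches.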
|
|
class ConvLSTMCell(nn.Module):
    """A single ConvLSTM cell: one step of the recurrence
    h_t, c_t = f(x_t, h_{t-1}, c_{t-1})."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: Union[Tuple[int, int], int],
        padding: Union[Tuple[int, int], int],
        stride: Union[Tuple[int, int], int],
        bias: bool,
    ):
        super().__init__()
        self.gate = ConvGate(
            in_channels, hidden_channels, kernel_size, padding, stride, bias
        )

    def forward(
        self, x: Tensor, hidden_state: Tensor, cell_state: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # One fused convolution yields all four gate pre-activations,
        # split along the channel axis.
        gated = self.gate(x, hidden_state)
        i_gated, f_gated, c_gated, o_gated = gated.chunk(4, dim=1)

        i_gated = sigmoid(i_gated)
        f_gated = sigmoid(f_gated)
        o_gated = sigmoid(o_gated)

        # c_t = f * c_{t-1} + i * tanh(g);  h_t = o * tanh(c_t)
        cell_state = f_gated.mul(cell_state) + i_gated.mul(tanh(c_gated))
        hidden_state = o_gated.mul(tanh(cell_state))

        return hidden_state, cell_state
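
# Usage sketch (illustrative, not part of the original module): with x_t of
# shape (B, C_in, H, W) and states of shape (B, C_hidden, H, W),
#     h, c = cell(x_t, h, c)
# advances the recurrence by one time step; ConvLSTM._forward below does
# exactly this over a whole sequence.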
|
|
class ConvLSTM(nn.Module):
    """Convolutional LSTM over a sequence of image-like tensors, with
    optional bidirectional processing."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: Union[Tuple[int, int], int],
        padding: Union[Tuple[int, int], int],
        stride: Union[Tuple[int, int], int],
        bias: bool,
        batch_first: bool,
        bidirectional: bool,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.bidirectional = bidirectional
        self.batch_first = batch_first

        # One cell for the forward direction; a second, independent cell
        # is appended for the backward direction when bidirectional.
        self.conv_lstm_cells = nn.ModuleList(
            [
                ConvLSTMCell(
                    in_channels, hidden_channels, kernel_size, padding, stride, bias
                )
            ]
        )
        if self.bidirectional:
            self.conv_lstm_cells.append(
                ConvLSTMCell(
                    in_channels, hidden_channels, kernel_size, padding, stride, bias
                )
            )

        self.batch_size = None
        self.seq_len = None
        self.height = None
        self.width = None
|
    def forward(
        self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        x = self._check_shape(x)
        hidden_state, cell_state, backward_hidden_state, backward_cell_state = (
            self.init_state(x, state)
        )

        output, hidden_state, cell_state = self._forward(
            self.conv_lstm_cells[0], x, hidden_state, cell_state
        )

        if self.bidirectional:
            # Run the second cell over the time-reversed sequence.
            x = torch.flip(x, [1])
            backward_output, backward_hidden_state, backward_cell_state = self._forward(
                self.conv_lstm_cells[1], x, backward_hidden_state, backward_cell_state
            )
            # Flip the backward outputs back so each time step aligns with
            # the forward pass, then concatenate: outputs along the channel
            # axis, states along the width axis.
            backward_output = torch.flip(backward_output, [1])
            output = torch.cat([output, backward_output], dim=-3)
            hidden_state = torch.cat([hidden_state, backward_hidden_state], dim=-1)
            cell_state = torch.cat([cell_state, backward_cell_state], dim=-1)

        return output, (hidden_state, cell_state)
|
    def _forward(
        self, lstm_cell: ConvLSTMCell, x: Tensor, hidden_state: Tensor, cell_state: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        outputs = []
        for time_step in range(self.seq_len):
            x_t = x[:, time_step, :, :, :]
            hidden_state, cell_state = lstm_cell(x_t, hidden_state, cell_state)
            # Keep the computation graph intact; detaching here would
            # silently break backpropagation through time.
            outputs.append(hidden_state)
        output = torch.stack(outputs, dim=1)
        return output, hidden_state, cell_state
|
    def _check_shape(self, x: Tensor) -> Tensor:
        if not self.batch_first:
            # Convert (seq, batch, ...) to (batch, seq, ...).
            x = x.transpose(0, 1)
        self.seq_len = x.shape[1]
        self.height, self.width = x.shape[-2], x.shape[-1]

        if x.dim() == 4:
            # (batch, seq, height, width): add a singleton channel axis.
            x = x.unsqueeze(dim=2)
        elif x.dim() < 4:
            raise ValueError(
                f"Got a {x.dim()}-dimensional tensor; expected a 4D or 5D input"
            )
        return x.contiguous()
|
    def init_state(
        self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        backward_hidden_state, backward_cell_state = None, None

        if state is None:
            self.batch_size = x.shape[0]
            hidden_state, cell_state = self._init_state(x.dtype, x.device)
            if self.bidirectional:
                backward_hidden_state, backward_cell_state = self._init_state(
                    x.dtype, x.device
                )
        else:
            if self.bidirectional:
                # Split the concatenated bidirectional states back into
                # their forward and backward halves.
                hidden_state, backward_hidden_state = state[0].chunk(2, dim=-1)
                cell_state, backward_cell_state = state[1].chunk(2, dim=-1)
            else:
                hidden_state, cell_state = state

        return hidden_state, cell_state, backward_hidden_state, backward_cell_state
|
    def _init_state(
        self, dtype: torch.dtype, device: torch.device
    ) -> Tuple[Tensor, Tensor]:
        # Fresh zero states; the singleton batch dimension broadcasts
        # against the input batch on the first time step. Plain tensors are
        # used rather than registered buffers, since these states depend on
        # the input shape and should not persist in the state dict.
        shape = (1, self.hidden_channels, self.height, self.width)
        hidden_state = torch.zeros(shape, dtype=dtype, device=device)
        cell_state = torch.zeros(shape, dtype=dtype, device=device)
        return hidden_state, cell_state
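

# A minimal smoke test (a sketch, not part of the original module). The sizes
# below are illustrative; kernel_size=3 with padding=1 and stride=1 keeps the
# spatial dimensions unchanged, which the recurrence requires.
if __name__ == "__main__":
    model = ConvLSTM(
        in_channels=3,
        hidden_channels=8,
        kernel_size=3,
        padding=1,
        stride=1,
        bias=True,
        batch_first=True,
        bidirectional=True,
    )
    x = torch.randn(2, 5, 3, 16, 16)  # (batch, seq, channels, height, width)
    output, (h, c) = model(x)
    print(output.shape)  # torch.Size([2, 5, 16, 16, 16]): directions concatenated on channels
    print(h.shape, c.shape)  # torch.Size([2, 8, 16, 32]) each: states concatenated on width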
|
|