# Copyright (c) 2022 Seyong Kim
from typing import Any, Optional, Tuple, Union
import torch
from torch import Tensor, nn, sigmoid, tanh


class ConvGate(nn.Module):
def __init__(
self,
in_channels: int,
hidden_channels: int,
kernel_size: Union[Tuple[int, int], int],
padding: Union[Tuple[int, int], int],
stride: Union[Tuple[int, int], int],
bias: bool,
):
        super().__init__()
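        # A single convolution produces the pre-activations of all four LSTM
        # gates (input, forget, cell candidate, output) at once, hence the
        # hidden_channels * 4 output channels below.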
self.conv_x = nn.Conv2d(
in_channels=in_channels,
out_channels=hidden_channels * 4,
kernel_size=kernel_size,
padding=padding,
stride=stride,
bias=bias,
)
self.conv_h = nn.Conv2d(
in_channels=hidden_channels,
out_channels=hidden_channels * 4,
kernel_size=kernel_size,
padding=padding,
stride=stride,
bias=bias,
)
self.bn2d = nn.BatchNorm2d(hidden_channels * 4)

    def forward(self, x: Tensor, hidden_state: Tensor) -> Tensor:
gated = self.conv_x(x) + self.conv_h(hidden_state)
return self.bn2d(gated)


class ConvLSTMCell(nn.Module):
def __init__(
self, in_channels, hidden_channels, kernel_size, padding, stride, bias
):
super().__init__()
        # Wrap the custom module in nn.ModuleList so that tools such as
        # torchinfo can inspect the model structure.
self.gates = nn.ModuleList(
[ConvGate(in_channels, hidden_channels, kernel_size, padding, stride, bias)]
)

    def forward(
self, x: Tensor, hidden_state: Tensor, cell_state: Tensor
) -> Tuple[Tensor, Tensor]:
gated = self.gates[0](x, hidden_state)
i_gated, f_gated, c_gated, o_gated = gated.chunk(4, dim=1)
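        # Standard ConvLSTM update: sigmoid on the input/forget/output gates,
        # tanh on the cell candidate:
        #   c_t = f * c_{t-1} + i * tanh(g),   h_t = o * tanh(c_t)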
i_gated = sigmoid(i_gated)
f_gated = sigmoid(f_gated)
o_gated = sigmoid(o_gated)
cell_state = f_gated.mul(cell_state) + i_gated.mul(tanh(c_gated))
hidden_state = o_gated.mul(tanh(cell_state))
return hidden_state, cell_state


class ConvLSTM(nn.Module):
"""ConvLSTM module"""
def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
padding,
stride,
bias,
batch_first,
bidirectional,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.bidirectional = bidirectional
self.batch_first = batch_first
        # Wrap the custom module in nn.ModuleList so that tools such as
        # torchinfo can inspect the model structure.
self.conv_lstm_cells = nn.ModuleList(
[
ConvLSTMCell(
in_channels, hidden_channels, kernel_size, padding, stride, bias
)
]
)
if self.bidirectional:
self.conv_lstm_cells.append(
ConvLSTMCell(
in_channels, hidden_channels, kernel_size, padding, stride, bias
)
)
self.batch_size = None
self.seq_len = None
self.height = None
self.width = None

    def forward(
self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
# size of x: B, T, C, H, W or T, B, C, H, W
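        # Returns the per-step hidden states stacked along the time dimension
        # together with the final (hidden_state, cell_state) pair; in the
        # bidirectional case, outputs are concatenated along the channel
        # dimension and states along the last dimension.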
x = self._check_shape(x)
hidden_state, cell_state, backward_hidden_state, backward_cell_state = (
self.init_state(x, state)
)
output, hidden_state, cell_state = self._forward(
self.conv_lstm_cells[0], x, hidden_state, cell_state
)
        if self.bidirectional:
            # Run the second cell over the time-reversed sequence.
            x = torch.flip(x, [1])
            backward_output, backward_hidden_state, backward_cell_state = self._forward(
                self.conv_lstm_cells[1], x, backward_hidden_state, backward_cell_state
            )
            # Flip the backward outputs back so both directions are aligned per
            # time step, then concatenate along the channel dimension; states are
            # concatenated along the last dimension, matching the split in init_state.
            backward_output = torch.flip(backward_output, [1])
            output = torch.cat([output, backward_output], dim=-3)
            hidden_state = torch.cat([hidden_state, backward_hidden_state], dim=-1)
            cell_state = torch.cat([cell_state, backward_cell_state], dim=-1)
return output, (hidden_state, cell_state)

    def _forward(self, lstm_cell, x, hidden_state, cell_state):
outputs = []
for time_step in range(self.seq_len):
x_t = x[:, time_step, :, :, :]
hidden_state, cell_state = lstm_cell(x_t, hidden_state, cell_state)
            # Keep the graph intact so gradients can flow through the per-step outputs.
            outputs.append(hidden_state)
output = torch.stack(outputs, dim=1)
return output, hidden_state, cell_state

    def _check_shape(self, x: Tensor) -> Tensor:
if self.batch_first:
batch_size, self.seq_len = x.shape[0], x.shape[1]
else:
batch_size, self.seq_len = x.shape[1], x.shape[0]
            # Move the batch dimension to the front: (T, B, ...) -> (B, T, ...).
            x = torch.swapaxes(x, 0, 1)
self.height = x.shape[-2]
self.width = x.shape[-1]
dim = len(x.shape)
        if dim == 4:
            # Input has no channel dimension (B, T, H, W); insert one.
            x = x.unsqueeze(dim=2).contiguous()
        elif dim <= 3:
            raise ValueError(
                f"Got a {dim}-dimensional tensor; expected a 4-D (B, T, H, W) "
                "or 5-D (B, T, C, H, W) input"
            )
return x

    def init_state(
self, x: Tensor, state: Optional[Tuple[Tensor, Tensor]]
) -> Tuple[Union[Tensor, Any], Union[Tensor, Any], Optional[Any], Optional[Any]]:
        # If no state is passed in, initialize both states to zeros.
backward_hidden_state, backward_cell_state = None, None
if state is None:
self.batch_size = x.shape[0]
hidden_state, cell_state = self._init_state(x.dtype, x.device)
if self.bidirectional:
backward_hidden_state, backward_cell_state = self._init_state(
x.dtype, x.device
)
        else:
            if self.bidirectional:
                # Split the concatenated forward/backward states back apart.
                hidden_state, backward_hidden_state = state[0].chunk(2, dim=-1)
                cell_state, backward_cell_state = state[1].chunk(2, dim=-1)
            else:
                hidden_state, cell_state = state
return hidden_state, cell_state, backward_hidden_state, backward_cell_state

    def _init_state(self, dtype, device):
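        # Zero states are created with a batch dimension of 1 and broadcast
        # against the per-step gate activations.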
self.register_buffer(
"hidden_state",
torch.zeros(
(1, self.hidden_channels, self.height, self.width),
dtype=dtype,
device=device,
),
)
self.register_buffer(
"cell_state",
torch.zeros(
(1, self.hidden_channels, self.height, self.width),
dtype=dtype,
device=device,
),
)
return self.hidden_state, self.cell_state
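

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: the
    # hyper-parameters and input shape below are illustrative assumptions only.
    model = ConvLSTM(
        in_channels=3,
        hidden_channels=16,
        kernel_size=3,
        padding=1,  # keep H and W unchanged so hidden states stay conv-compatible
        stride=1,
        bias=True,
        batch_first=True,
        bidirectional=True,
    )
    x = torch.randn(2, 5, 3, 32, 32)  # (B, T, C, H, W)
    output, (hidden_state, cell_state) = model(x)
    # output: (2, 5, 32, 32, 32) -- forward/backward channels concatenated
    print(output.shape, hidden_state.shape, cell_state.shape)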