# sohamc10's picture
# gradio app
# 9b0d6c2
import torch.nn as nn
class BidirectionalGRU(nn.Module):
    """Bidirectional, batch-first GRU that exposes only its output sequence.

    Args:
        n_in: Size of each input feature vector.
        n_hidden: Hidden size of the GRU, per direction.
        dropout: Dropout value forwarded unchanged to ``nn.GRU``.
        num_layers: Number of stacked recurrent layers.
    """

    def __init__(self, n_in, n_hidden, dropout=0, num_layers=1):
        super().__init__()
        # Options that do not depend on the feature sizes are grouped together.
        gru_options = dict(
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )
        self.rnn = nn.GRU(n_in, n_hidden, **gru_options)

    def forward(self, input_feat):
        """Run the GRU on ``input_feat`` and return its per-step outputs.

        The final hidden state returned by the GRU is discarded.
        """
        outputs = self.rnn(input_feat)[0]
        return outputs
class BidirectionalLSTM(nn.Module):
    """Bidirectional, batch-first LSTM followed by a per-step linear projection.

    Args:
        nIn: Size of each input feature vector.
        nHidden: Hidden size of the LSTM, per direction.
        nOut: Size of each projected output vector.
        dropout: Dropout value forwarded unchanged to ``nn.LSTM``.
        num_layers: Number of stacked recurrent layers.
    """

    def __init__(self, nIn, nHidden, nOut, dropout=0, num_layers=1):
        super().__init__()
        self.rnn = nn.LSTM(
            nIn,
            nHidden,
            bidirectional=True,
            batch_first=True,
            dropout=dropout,
            num_layers=num_layers,
        )
        # Both directions are concatenated on the feature axis, so the
        # projection consumes 2 * nHidden features.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input_feat):
        """Return the LSTM outputs projected to ``nOut`` features per step.

        Args:
            input_feat: Tensor whose last dimension is ``nIn``; typically
                ``(batch, time, nIn)``. An unbatched ``(time, nIn)`` tensor
                is also accepted when the installed ``nn.LSTM`` supports it.

        Returns:
            Tensor with the same leading dimensions as the LSTM output and
            ``nOut`` as its last dimension. The LSTM's final hidden and cell
            states are discarded.
        """
        recurrent, _ = self.rnn(input_feat)
        # nn.Linear acts on the last dimension and keeps all leading ones,
        # so no manual flatten to (batch * time, features) is needed.
        return self.embedding(recurrent)