jeekzhang's picture
Upload models.py
869a613
import torch
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF
from utils import bert_model
class Bert_BiLSTM_CRF(nn.Module):
    """Sequence tagger: BERT encoder -> BiLSTM -> linear emissions -> CRF.

    BERT is run under ``torch.no_grad()`` in ``_get_features``, so it acts as
    a frozen feature extractor: only the BiLSTM, the linear layer and the CRF
    receive gradients from the loss returned by ``forward``.

    Args:
        tag_to_ix: Tag-to-index mapping; within this class only its length is
            used, as the number of CRF tags.
        embedding_dim: BiLSTM input size. Must equal the hidden size of the
            pretrained BERT named by ``utils.bert_model``.
        hidden_dim: Requested BiLSTM output width (both directions together);
            each direction gets ``hidden_dim // 2`` units.

    Raises:
        ValueError: If ``embedding_dim`` differs from BERT's hidden size.
    """

    def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=256):
        super().__init__()
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.bert = BertModel.from_pretrained(bert_model)

        # Fail fast with a clear message instead of an opaque LSTM shape
        # error on the first forward pass.
        bert_hidden = self.bert.config.hidden_size
        if bert_hidden != embedding_dim:
            raise ValueError(
                f"embedding_dim={embedding_dim} does not match the hidden size "
                f"({bert_hidden}) of pretrained model {bert_model!r}"
            )

        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim // 2,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(p=0.1)
        # A bidirectional LSTM concatenates both directions, so its real output
        # width is 2 * (hidden_dim // 2). That is hidden_dim - 1 when
        # hidden_dim is odd, so size the projection from the real width.
        lstm_out_dim = 2 * (hidden_dim // 2)
        self.linear = nn.Linear(lstm_out_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size, batch_first=True)

    def _get_features(self, sentence, attention_mask=None):
        """Compute per-token emission scores for the CRF.

        Args:
            sentence: Token-id tensor passed straight to BERT; presumably
                (batch, seq_len) given ``batch_first=True`` downstream —
                TODO confirm against the data loader.
            attention_mask: Optional BERT attention mask (1 = attend,
                0 = ignore). ``None`` keeps BERT's default of attending to
                every position.

        Returns:
            Emission tensor whose last dimension is ``self.tagset_size``.
        """
        # Frozen BERT: no autograd graph is built for it, so it is not
        # fine-tuned. NOTE(review): BERT's own dropout is still active while
        # the model is in train() mode.
        with torch.no_grad():
            embeds, _ = self.bert(
                sentence, attention_mask=attention_mask, return_dict=False
            )
        enc, _ = self.lstm(embeds)
        enc = self.dropout(enc)
        return self.linear(enc)

    def forward(self, sentence, tags=None, mask=None, is_test=False, *,
                attention_mask=None):
        """Return the CRF training loss, or the Viterbi decoding at test time.

        Args:
            sentence: Token ids fed to BERT (see ``_get_features``).
            tags: Gold tag indices; required when ``is_test`` is false.
            mask: Optional CRF mask marking valid time steps. It is converted
                to bool before being handed to the CRF.
            is_test: If true, decode instead of computing the loss.
            attention_mask: Optional keyword-only BERT attention mask;
                ``None`` keeps the original behavior (attend everywhere).

        Returns:
            Training: scalar negative log-likelihood averaged over the batch.
            Testing: the output of ``CRF.decode`` (best tag sequences).

        Raises:
            ValueError: If ``tags`` is missing while ``is_test`` is false.
        """
        # Validate before the (expensive) BERT pass.
        if not is_test and tags is None:
            raise ValueError("tags are required when is_test is False")
        if mask is not None:
            # Normalise integer/uint8 masks to bool (same 0/1 meaning);
            # avoids dtype complaints from the CRF's masked operations.
            mask = mask.bool()

        emissions = self._get_features(sentence, attention_mask=attention_mask)
        if is_test:
            return self.crf.decode(emissions, mask)
        # The CRF returns a log-likelihood; negate it to obtain a loss.
        # Calling the module (not .forward) lets registered hooks run.
        return -self.crf(emissions, tags, mask, reduction="mean")