import torch
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF

from utils import bert_model


class Bert_BiLSTM_CRF(nn.Module):
    """Sequence tagger: frozen BERT encoder -> 2-layer BiLSTM -> linear -> CRF.

    BERT is run under ``torch.no_grad()``, so only the LSTM, the linear
    projection and the CRF receive gradients.

    Args:
        tag_to_ix: Sized collection of tags (only ``len()`` is used here);
            its length becomes the number of CRF states.
        embedding_dim: Width of the BERT hidden states fed to the LSTM.
            Must match the hidden size of the checkpoint named by
            ``utils.bert_model``.
        hidden_dim: Total BiLSTM output width (both directions
            concatenated). Must be even.

    Raises:
        ValueError: If ``hidden_dim`` is odd.
    """

    def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=256):
        super().__init__()
        # Each direction gets hidden_dim // 2 units; an odd value would make
        # the concatenated LSTM output one unit narrower than the linear
        # layer expects and only fail later, inside forward().
        if hidden_dim % 2 != 0:
            raise ValueError(
                f"hidden_dim must be even for a bidirectional LSTM, got {hidden_dim}"
            )

        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.bert = BertModel.from_pretrained(bert_model)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim // 2,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(hidden_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size, batch_first=True)

    def _get_features(self, sentence, mask=None):
        """Compute per-token emission scores for the CRF.

        Args:
            sentence: Token-id tensor passed to BERT as its first positional
                argument; batch-first, since the LSTM and CRF are built with
                ``batch_first=True``.
            mask: Optional padding mask with the same leading shape as
                ``sentence`` (truthy = real token). Forwarded to BERT as
                ``attention_mask`` so padding positions are not attended to.
                ``None`` keeps BERT's default of attending to every position.

        Returns:
            Tensor of emission scores whose last dimension is
            ``self.tagset_size``.
        """
        # BERT is used as a fixed feature extractor: no gradients flow into it.
        # NOTE(review): no_grad() does not put BERT in eval mode; its internal
        # dropout stays active whenever this module is in train mode. Confirm
        # that is intended.
        with torch.no_grad():
            embeds, _ = self.bert(
                sentence, attention_mask=mask, return_dict=False
            )
        enc, _ = self.lstm(embeds)
        enc = self.dropout(enc)
        return self.linear(enc)

    def forward(self, sentence, tags=None, mask=None, is_test=False):
        """Return the training loss, or the decoded tag paths when testing.

        Args:
            sentence: Batch-first token-id tensor.
            tags: Gold tag indices aligned with ``sentence``. Required when
                ``is_test`` is False.
            mask: Optional padding mask (truthy = real token). Used both as
                BERT's attention mask and as the CRF mask.
            is_test: When False, return the CRF negative log-likelihood
                (mean-reduced). When True, return the decoded tag sequences.

        Returns:
            A scalar loss tensor in training mode, otherwise the result of
            ``CRF.decode`` (one tag-index list per batch element).

        Raises:
            ValueError: If ``is_test`` is False and ``tags`` is None.
        """
        # Fail before the expensive BERT pass, with a message that names the
        # actual problem instead of an error from inside the CRF.
        if not is_test and tags is None:
            raise ValueError("tags must be provided when is_test is False")

        emissions = self._get_features(sentence, mask)

        if is_test:
            return self.crf.decode(emissions, mask)

        # The CRF returns a log-likelihood; negate it to get a loss to minimise.
        return -self.crf(emissions, tags, mask, reduction='mean')