import torch
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF

from utils import bert_model  # name/path of the pretrained BERT checkpoint


class Bert_BiLSTM_CRF(nn.Module):
    """BERT embeddings -> BiLSTM encoder -> linear emission layer -> CRF."""

    def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=256):
        super().__init__()
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.bert = BertModel.from_pretrained(bert_model)
        # hidden_dim is split across the two directions so the concatenated
        # BiLSTM output matches the linear layer's input size.
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim // 2,
                            num_layers=2, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(hidden_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size, batch_first=True)

    def _get_features(self, sentence):
        # BERT runs under no_grad, so its weights stay frozen; only the
        # BiLSTM, linear, and CRF layers receive gradients.
        with torch.no_grad():
            embeds, _ = self.bert(sentence, return_dict=False)
        enc, _ = self.lstm(embeds)
        enc = self.dropout(enc)
        feats = self.linear(enc)  # per-token emission scores for the CRF
        return feats

    def forward(self, sentence, tags=None, mask=None, is_test=False):
        # mask is a uint8/bool tensor marking non-padding positions; torchcrf
        # requires the first timestep of every sequence to be unmasked.
        emissions = self._get_features(sentence)
        if not is_test:
            # CRF.forward returns the log-likelihood; negate it for a loss.
            loss = -self.crf(emissions, tags, mask=mask, reduction='mean')
            return loss
        else:
            decode = self.crf.decode(emissions, mask=mask)
            return decode
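

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, not part of the original module: it assumes
# `utils.bert_model` names a valid checkpoint such as "bert-base-uncased"
# (the vocab size 30522 below is that model's, an assumption here), and uses
# a toy BIO tag set; real input ids would come from the matching
# BertTokenizer rather than randint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tag_to_ix = {"O": 0, "B-PER": 1, "I-PER": 2}  # toy BIO tag set
    model = Bert_BiLSTM_CRF(tag_to_ix)

    input_ids = torch.randint(0, 30522, (2, 16))        # fake token ids
    tags = torch.randint(0, len(tag_to_ix), (2, 16))    # fake gold labels
    mask = torch.ones(2, 16, dtype=torch.uint8)         # no padding in this toy batch

    loss = model(input_ids, tags, mask)                 # training path: mean NLL
    preds = model(input_ids, mask=mask, is_test=True)   # inference path: Viterbi decode
    print(loss.item(), preds[0][:5])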