# Hugging Face Spaces page header (Space status: Sleeping)
from typing import Any, Optional, Tuple, Union | |
import torch | |
import transformers | |
class DistilBertTransferLearningModel(torch.nn.Module):
    """Pretrained transformer encoder followed by a configurable classifier head.

    The encoder and its tokenizer are loaded through ``transformers``
    (``AutoModel`` / ``AutoTokenizer``). The head is a ``torch.nn.Sequential``
    assembled from a small declarative ``layers`` specification and is applied
    to the encoder output at sequence position 0.
    """

    # Head used when ``layers`` is omitted. Kept as an immutable tuple so the
    # default can never be shared-and-mutated between instances.
    _DEFAULT_LAYERS: Tuple[Any, ...] = (
        ('linear', ['in', 'out']),
        'softmax',
    )

    def __init__(
        self,
        pretrained_model: str = "distilbert-base-uncased",
        layers: Optional[list[Union[str, Tuple[str, Optional[list[Any]]]]]] = None,
        dim_out: int = 2,
        use_local_file: bool = False,
        device: str = 'cpu',
        state_dict: Optional[Union[str, dict]] = None,
    ):
        """Load the pretrained encoder/tokenizer and build the classifier head.

        Args:
            pretrained_model: Model name or path handed to ``from_pretrained``.
            layers: Ordered head specification. Each entry is either a bare
                layer-type string (``'softmax'``) or a ``(type, args)``
                tuple/list. ``'linear'`` takes ``[in_features, out_features]``
                where the symbol ``'in'`` resolves to the encoder hidden size
                and ``'out'`` to ``dim_out``. ``None`` selects
                ``linear(in -> out)`` followed by ``softmax``.
            dim_out: Value substituted for the ``'out'`` symbol.
            use_local_file: Forwarded as ``local_files_only`` to
                ``from_pretrained``.
            device: Used as ``map_location`` when ``state_dict`` is a file
                path. The module itself is NOT moved to this device here.
            state_dict: Either an already-loaded state dict or a path to a
                checkpoint file readable by ``torch.load``.

        Raises:
            ValueError: If ``layers`` contains an unsupported or malformed
                entry.
        """
        super().__init__()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model, local_files_only=use_local_file
        )
        self.base_model = transformers.AutoModel.from_pretrained(
            pretrained_model, local_files_only=use_local_file
        )

        if layers is None:
            layers = self._DEFAULT_LAYERS
        self.clf = self._build_classifier(
            layers, self.base_model.config.hidden_size, dim_out
        )

        if state_dict is not None:
            if isinstance(state_dict, str):
                # NOTE(security): torch.load unpickles the file; only point
                # this at checkpoints from a trusted source.
                state_dict = torch.load(state_dict, map_location=device)
            self.load_state_dict(state_dict)

    @staticmethod
    def _build_classifier(
        layers: Any, hidden_size: int, dim_out: int
    ) -> torch.nn.Sequential:
        """Translate a ``layers`` specification into a ``torch.nn.Sequential``."""

        def resolve(dim: Any) -> Any:
            # Symbolic sizes let a spec be written without knowing the encoder.
            if dim == 'in':
                return hidden_size
            if dim == 'out':
                return dim_out
            return dim

        modules = []
        for spec in layers:
            if isinstance(spec, str):
                layer_type, args = spec, None
            elif isinstance(spec, (tuple, list)) and spec:
                # Lists are accepted too so specs parsed from JSON/YAML work.
                layer_type = spec[0]
                args = spec[1] if len(spec) > 1 else None
            else:
                raise ValueError(f"Invalid layer specification: {spec!r}")

            if layer_type == 'linear':
                try:
                    if isinstance(args, str):
                        raise TypeError("dimensions must be a sequence, not str")
                    in_features, out_features = (resolve(dim) for dim in args)
                except (TypeError, ValueError) as err:
                    raise ValueError(
                        "'linear' layer requires [in_features, out_features], "
                        f"got {args!r}"
                    ) from err
                modules.append(torch.nn.Linear(in_features, out_features))
            elif layer_type == 'softmax':
                modules.append(torch.nn.Softmax(dim=-1))
            else:
                # Silently skipping an unknown layer would change the head's
                # output without any signal, so fail loudly instead.
                raise ValueError(f"Unsupported layer type: {layer_type!r}")
        return torch.nn.Sequential(*modules)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Run the encoder and apply the classifier head.

        Args:
            ids: Token ids passed positionally to the base model
                (presumably shaped ``(batch, seq_len)`` - confirm with caller).
            mask: Attention mask passed as ``attention_mask``.

        Returns:
            The classifier output for each batch row.
        """
        # Element 0 of the output tuple is indexed at sequence position 0
        # (the leading special token, e.g. [CLS]) for every batch row.
        pooled = self.base_model(ids, attention_mask=mask, return_dict=False)[0][:, 0]
        return self.clf(pooled)

    def predict(self, text: str, device: str, max_length: int = 512) -> torch.Tensor:
        """Tokenize ``text`` and return the model output without gradients.

        The module is not moved and its train/eval mode is not changed here;
        the caller is responsible for both.

        Args:
            text: Raw input text.
            device: Device the encoded tensors (and the result) are moved to.
            max_length: Length the input is padded/truncated to.

        Returns:
            The output of ``forward`` for the single encoded text.
        """
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        with torch.no_grad():
            ids = encoded['input_ids'].to(device)
            mask = encoded['attention_mask'].to(device)
            output = self.forward(ids, mask)
        return output.to(device)