from typing import Dict

import torch
import torch.nn as nn
from transformers import PreTrainedModel, DPRQuestionEncoder, DPRContextEncoder

from .configuration_dpr import CustomDPRConfig


class OBSSDPRModel(PreTrainedModel):
    """Hugging Face wrapper so the DPR bi-encoder below can be saved and
    loaded through the standard `from_pretrained`/`save_pretrained` API."""

    config_class = CustomDPRConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = DPRModel()

    def forward(self, batch):
        return self.model(batch)
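

# A sketch of how this wrapper could be exposed through the Auto* API
# (assumption: the "custom_dpr" model_type string must match
# CustomDPRConfig.model_type):
#
#   from transformers import AutoConfig, AutoModel
#   AutoConfig.register("custom_dpr", CustomDPRConfig)
#   AutoModel.register(CustomDPRConfig, OBSSDPRModel)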


class DPRModel(nn.Module):
    def __init__(self,
                 question_model_name='facebook/dpr-question_encoder-single-nq-base',
                 context_model_name='facebook/dpr-ctx_encoder-single-nq-base',
                 freeze_params=0.5):
        super().__init__()
        # Fraction (0.0-1.0) of each encoder's parameter tensors to freeze,
        # counting from the embedding layer upward.
        self.freeze_params = freeze_params
        self.question_model = DPRQuestionEncoder.from_pretrained(question_model_name)
        self.context_model = DPRContextEncoder.from_pretrained(context_model_name)
        self.freeze_layers(freeze_params)

    def freeze_layers(self, freeze_params):
        """Freeze the first `freeze_params` fraction of parameter tensors in
        each encoder (counting from the embeddings up) and leave the rest
        trainable."""
        for model in (self.context_model, self.question_model):
            parameters = list(model.parameters())
            num_frozen = int(freeze_params * len(parameters))
            for parameter in parameters[:num_frozen]:
                parameter.requires_grad = False
            for parameter in parameters[num_frozen:]:
                parameter.requires_grad = True

    def batch_dot_product(self, context_output, question_output):
        """Row-wise dot product between paired question and context embeddings.

        Both inputs are (batch_size, embed_dim); the result is (batch_size,),
        one score per question/context pair (equivalent to
        (question_output * context_output).sum(dim=1)).
        """
        mat1 = torch.unsqueeze(question_output, dim=1)  # (batch, 1, dim)
        mat2 = torch.unsqueeze(context_output, dim=2)   # (batch, dim, 1)
        result = torch.bmm(mat1, mat2)                  # (batch, 1, 1)
        result = torch.squeeze(result, dim=1)
        result = torch.squeeze(result, dim=1)
        return result

    def forward(self, batch: Dict):
        context_tensor = batch['context_tensor']
        question_tensor = batch['question_tensor']
        context_model_output = self.context_model(input_ids=context_tensor['input_ids'],
                                                  attention_mask=context_tensor['attention_mask'])
        question_model_output = self.question_model(input_ids=question_tensor['input_ids'],
                                                    attention_mask=question_tensor['attention_mask'])
        # DPR encoders expose the final sentence embedding as `pooler_output`.
        embeddings_context = context_model_output['pooler_output']
        embeddings_question = question_model_output['pooler_output']

        scores = self.batch_dot_product(embeddings_context, embeddings_question)
        return scores
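

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the training code):
    # tokenize one question/context pair with the tokenizers matching the
    # default encoder checkpoints above, then score the pair with one forward
    # pass. The example strings are placeholders.
    from transformers import (DPRContextEncoderTokenizerFast,
                              DPRQuestionEncoderTokenizerFast)

    question_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained(
        'facebook/dpr-question_encoder-single-nq-base')
    context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
        'facebook/dpr-ctx_encoder-single-nq-base')

    batch = {
        'question_tensor': question_tokenizer(
            ['who wrote the iliad'], padding=True, return_tensors='pt'),
        'context_tensor': context_tokenizer(
            ['The Iliad is an ancient Greek epic poem attributed to Homer.'],
            padding=True, return_tensors='pt'),
    }

    model = DPRModel()
    model.eval()
    with torch.no_grad():
        scores = model(batch)  # shape: (1,)
    print(scores)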