import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel


class TextEmbeddingModel(nn.Module):
    def __init__(self, model_name, output_hidden_states=False):
        super().__init__()
        self.model_name = model_name
        # trust_remote_code allows loading custom model code shipped with the checkpoint
        self.model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            output_hidden_states=output_hidden_states,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def pooling(self, model_output, attention_mask, use_pooling='average', hidden_states=False):
        if hidden_states:
            # model_output: (num_layers, batch, seq_len, hidden); zero out padded positions.
            # masked_fill is not in-place, so keep the returned tensor.
            model_output = model_output.masked_fill(~attention_mask[None, ..., None].bool(), 0.0)
            if use_pooling == "average":
                emb = model_output.sum(dim=2) / attention_mask.sum(dim=1)[..., None]
            else:
                emb = model_output[:, :, 0]
            # (num_layers, batch, hidden) -> (batch, num_layers, hidden)
            emb = emb.permute(1, 0, 2)
        else:
            # model_output: (batch, seq_len, hidden); zero out padded positions.
            model_output = model_output.masked_fill(~attention_mask[..., None].bool(), 0.0)
            if use_pooling == "average":
                emb = model_output.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
            elif use_pooling == "cls":
                emb = model_output[:, 0]
        return emb

    def forward(self, encoded_batch, use_pooling='average', hidden_states=False):
        if "t5" in self.model_name.lower():
            # Encoder-decoder models need decoder inputs; feed a single start token per example
            input_ids = encoded_batch['input_ids']
            decoder_input_ids = torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
            model_output = self.model(**encoded_batch, decoder_input_ids=decoder_input_ids)
        else:
            model_output = self.model(**encoded_batch)
        if 'bge' in self.model_name.lower() or 'mxbai' in self.model_name.lower():
            # BGE and mxbai models are trained for CLS-token pooling
            use_pooling = 'cls'
        if isinstance(model_output, tuple):
            model_output = model_output[0]
        if isinstance(model_output, dict):
            if hidden_states:
                # Stack per-layer hidden states into (num_layers, batch, seq_len, hidden)
                model_output = torch.stack(model_output["hidden_states"], dim=0)
            else:
                model_output = model_output["last_hidden_state"]
        emb = self.pooling(model_output, encoded_batch['attention_mask'], use_pooling, hidden_states)
        emb = torch.nn.functional.normalize(emb, dim=-1)
        return emb
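

# Minimal usage sketch (assumption: not part of the original file). The checkpoint
# name "BAAI/bge-small-en-v1.5" is only an illustrative choice; substitute whichever
# embedding model the surrounding project actually loads.
if __name__ == "__main__":
    model = TextEmbeddingModel("BAAI/bge-small-en-v1.5")
    model.eval()
    texts = ["a query about text embeddings", "an unrelated sentence"]
    encoded = model.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        embeddings = model(encoded)  # (batch, hidden), L2-normalized
    # With normalized embeddings, cosine similarity is just a dot product
    print(embeddings @ embeddings.T)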