import os
from functools import cache
from typing import Dict, List

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import build_mlp_vector_projector

GTE_EMBEDDING_SIZE = 1024
GTE_CONTEXT_WINDOW = 512
GTE_DEFAULT_MODEL = "thenlper/gte-large"

DOCUMENT_GTE_FORCE_CPU = "DOCUMENT_GTE_FORCE_CPU"


def average_pool(
    last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    # Mean-pool token embeddings, zeroing out padding positions via the attention mask.
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0
    )
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


@cache
def _get_tokenizer(model_name_or_path: str = GTE_DEFAULT_MODEL):
    # Cache the tokenizer so repeated calls don't reload it from disk.
    return AutoTokenizer.from_pretrained(model_name_or_path)


def split_text_into_documents(text: str) -> List[str]:
    from nltk.tokenize import sent_tokenize

    tokenizer = _get_tokenizer(GTE_DEFAULT_MODEL)
    sentences = sent_tokenize(text)
    # Greedily pack whole sentences into chunks that fit within the GTE context window.
    documents = [[]]
    for sentence in sentences:
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        if len(documents[-1]) + len(sentence_tokens) > GTE_CONTEXT_WINDOW:
            documents.append([])
        documents[-1].extend(sentence_tokens)
    return [tokenizer.decode(doc) for doc in documents]


class DocumentGTEModule(nn.Module):
    def __init__(self, model_name_or_path: str):
        super().__init__()
        self.feature_layer = -2
        self.model_name_or_path = model_name_or_path
        # Load the GTE encoder and freeze its weights.
        self.model = AutoModel.from_pretrained(self.model_name_or_path)
        self.model.requires_grad_(False)

    def forward(self, batch_dict) -> torch.Tensor:
        outputs = self.model(**batch_dict)
        # Pool the final hidden states into one embedding per document.
        embeddings = average_pool(
            outputs.last_hidden_state, batch_dict["attention_mask"]
        )
        return embeddings

    @property
    def embedding_size(self):
        return GTE_EMBEDDING_SIZE


class DocumentGTEModality(Modality):
    def __init__(
        self,
        model_name_or_path: str = GTE_DEFAULT_MODEL,
        num_projector_layers: int = 2,
        num_tokens_output: int = 4,
    ):
        self.model_name_or_path = model_name_or_path
        self.module = DocumentGTEModule(model_name_or_path=self.model_name_or_path)
        self.tokenizer = _get_tokenizer(model_name_or_path)
        self.num_projector_layers = num_projector_layers
        self.num_tokens_output = num_tokens_output
        self.dtype = torch.float32
        self.device = "cpu"
        # The GTE encoder may stay on a separate device (see DOCUMENT_GTE_FORCE_CPU).
        self.document_gte_device = "cpu"

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        return build_mlp_vector_projector(
            input_hidden_size=self.module.embedding_size,
            lm_hidden_size=lm_hidden_size,
            num_layers=self.num_projector_layers,
            num_tokens=self.num_tokens_output,
        )

    @property
    def name(self) -> str:
        return "document_gte"

    @property
    def token(self) -> str:
        return "<document>"

    @property
    def data_key(self) -> str:
        return "documents"

    @property
    def token_width(self) -> int:
        return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "DocumentGTEModality":
        self.dtype = dtype
        self.device = device
        if DOCUMENT_GTE_FORCE_CPU not in os.environ:
            # Keep the GTE encoder on CPU (set DOCUMENT_GTE_FORCE_CPU) when
            # running out of VRAM on a 24GB GPU.
            self.document_gte_device = device
            self.module.to(device=self.document_gte_device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Dict]:
        row_values = []
        for row in rows:
            documents = []
            for doc in row[self.data_key]:
                documents.append(doc)
            # Tokenize each row's documents as one padded batch for the GTE encoder.
            documents_tokenized = self.tokenizer(
                documents,
                max_length=GTE_CONTEXT_WINDOW,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            row_values.append(documents_tokenized)
        return row_values

    def forward(self, encoded_values: List[Dict]) -> List[torch.Tensor]:
        outputs = []
        for val in encoded_values:
            # Embed on the GTE device, then move to the LM's device/dtype and
            # reshape to one vector per document.
            outputs.append(
                self.module.forward(val.to(device=self.document_gte_device))
                .to(device=self.device, dtype=self.dtype)
                .view(-1, 1, self.module.embedding_size)
            )
        # outputs is a list of length batch_size; each tensor is num_items x 1 x embedding_size.
        return outputs
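

# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal example of how this modality might be driven end to end: split raw
# text into GTE-sized chunks, tokenize them per row, and pool one embedding per
# chunk. It assumes the "thenlper/gte-large" weights and NLTK's "punkt"
# sentence tokenizer are available locally; the input text is made up.
if __name__ == "__main__":
    modality = DocumentGTEModality()
    docs = split_text_into_documents("First sentence. Second sentence.")
    rows = [{"documents": docs}]
    encoded = modality.preprocess_rows(rows)
    embeddings = modality.forward(encoded)
    # One tensor per row, shaped (num_documents, 1, GTE_EMBEDDING_SIZE).
    print([e.shape for e in embeddings])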