File size: 4,604 Bytes
7c34c28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from typing import Dict, List

import torch
import torch.nn as nn
import os
from functools import cache
from transformers import AutoTokenizer, AutoModel

from multi_token.modalities.base_modality import Modality
from multi_token.modalities.projectors import build_mlp_vector_projector

# Width of the pooled embedding vector; used by DocumentGTEModule.embedding_size.
GTE_EMBEDDING_SIZE = 1024
# Token budget: bounds chunk length in split_text_into_documents and is the
# tokenizer max_length in DocumentGTEModality.preprocess_rows.
GTE_CONTEXT_WINDOW = 512
# Model identifier used when the caller does not supply model_name_or_path.
GTE_DEFAULT_MODEL = "thenlper/gte-large"
# Name of the environment variable that, when present, keeps the GTE encoder on
# CPU (checked in DocumentGTEModality.to).
DOCUMENT_GTE_FORCE_CPU = "DOCUMENT_GTE_FORCE_CPU"


def average_pool(
    last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Mean-pool token embeddings over the sequence axis, skipping masked positions.

    Args:
        last_hidden_states: Token embeddings, reduced over dim 1 (the sequence
            axis) with the trailing axis kept as the feature axis.
        attention_mask: Mask aligned with the leading axes of
            ``last_hidden_states``; non-zero entries mark tokens to include.

    Returns:
        One pooled vector per batch row. A row whose mask is entirely zero
        yields a zero vector instead of NaN.
    """
    keep = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    # Clamp the denominator so a fully masked row divides by 1, not 0 (0 / 0 -> NaN).
    token_counts = attention_mask.sum(dim=1).clamp(min=1)
    return summed / token_counts[..., None]


@cache
def _get_tokenizer(model_name_or_path: str = GTE_DEFAULT_MODEL):
    """Return the tokenizer for ``model_name_or_path``.

    The result is memoized by ``functools.cache``, so repeated calls with the
    same argument reuse the already-loaded tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    return tokenizer


def split_text_into_documents(text: str) -> List[str]:
    """Split ``text`` into chunks that fit the GTE context window.

    The text is cut into sentences with NLTK and sentences are packed greedily:
    a new chunk is started when adding the next sentence would exceed the token
    budget. Each chunk is decoded back to a string.

    Args:
        text: Raw text to split.

    Returns:
        The decoded chunks, in order. Text with no sentences produces a single
        chunk decoded from an empty token list. A single sentence longer than
        the budget is kept whole in its own chunk (it is not split further).
    """
    from nltk.tokenize import sent_tokenize

    tokenizer = _get_tokenizer(GTE_DEFAULT_MODEL)
    # Sentences are encoded without special tokens, but the chunks are later
    # re-encoded with them and truncated to GTE_CONTEXT_WINDOW; reserve that
    # room so a full chunk does not lose its trailing tokens.
    max_content_tokens = GTE_CONTEXT_WINDOW - tokenizer.num_special_tokens_to_add(
        pair=False
    )

    documents: List[List[int]] = [[]]
    for sentence in sent_tokenize(text):
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        current = documents[-1]
        # Only open a new chunk when the current one already holds tokens;
        # otherwise an oversized sentence would leave an empty chunk behind.
        if current and len(current) + len(sentence_tokens) > max_content_tokens:
            documents.append([])
        documents[-1].extend(sentence_tokens)

    return [tokenizer.decode(doc) for doc in documents]


class DocumentGTEModule(nn.Module):
    """Frozen text encoder that turns tokenized documents into pooled embeddings."""

    def __init__(self, model_name_or_path: str):
        """Load the encoder and freeze its parameters.

        Args:
            model_name_or_path: Checkpoint name or path passed to
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        # Not read anywhere in this class; kept so existing attribute access works.
        self.feature_layer = -2
        self.model_name_or_path = model_name_or_path

        # Honor the requested checkpoint instead of a hard-coded name, so the
        # encoder matches a tokenizer loaded from the same model_name_or_path.
        self.model = AutoModel.from_pretrained(self.model_name_or_path)
        self.model.requires_grad_(False)

    @torch.no_grad()
    def forward(self, batch_dict) -> torch.Tensor:
        """Encode a tokenized batch and mean-pool over non-masked tokens.

        Args:
            batch_dict: Mapping of model inputs; it is unpacked into the model
                call and must contain ``"attention_mask"``.

        Returns:
            One pooled embedding per input row.
        """
        outputs = self.model(**batch_dict)
        embeddings = average_pool(
            outputs.last_hidden_state, batch_dict["attention_mask"]
        )
        return embeddings

    @property
    def embedding_size(self) -> int:
        """Width of the pooled embedding produced by ``forward``.

        Read from the loaded model's config so it follows the checkpoint;
        falls back to ``GTE_EMBEDDING_SIZE`` when the config has no
        ``hidden_size`` attribute.
        """
        return getattr(self.model.config, "hidden_size", GTE_EMBEDDING_SIZE)


class DocumentGTEModality(Modality):
    """Modality that embeds text documents with a GTE encoder.

    Each document becomes one pooled embedding, which the projector built by
    ``build_projector`` maps to ``num_tokens_output`` language-model tokens.
    """

    def __init__(
        self,
        model_name_or_path: str = GTE_DEFAULT_MODEL,
        num_projector_layers: int = 2,
        num_tokens_output: int = 4,
    ):
        """Create the encoder module and tokenizer and record projector settings.

        Args:
            model_name_or_path: Name passed to both the encoder module and the
                tokenizer loader.
            num_projector_layers: Layer count forwarded to the projector builder.
            num_tokens_output: Number of tokens the projector emits per document.
        """
        # Projector configuration.
        self.num_tokens_output = num_tokens_output
        self.num_projector_layers = num_projector_layers

        # Everything starts on CPU in float32 until ``to`` is called.
        self.document_gte_device = "cpu"
        self.device = "cpu"
        self.dtype = torch.float32

        # The encoder module and the tokenizer receive the same name argument.
        self.model_name_or_path = model_name_or_path
        self.module = DocumentGTEModule(model_name_or_path=model_name_or_path)
        self.tokenizer = _get_tokenizer(model_name_or_path)

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        """Build the MLP projector from embedding space to ``lm_hidden_size``."""
        projector_kwargs = dict(
            input_hidden_size=self.module.embedding_size,
            lm_hidden_size=lm_hidden_size,
            num_layers=self.num_projector_layers,
            num_tokens=self.num_tokens_output,
        )
        return build_mlp_vector_projector(**projector_kwargs)

    @property
    def name(self) -> str:
        """Identifier of this modality."""
        return "document_gte"

    @property
    def token(self) -> str:
        """Placeholder token string associated with this modality."""
        return "<document>"

    @property
    def data_key(self) -> str:
        """Row key under which the documents are stored."""
        return "documents"

    @property
    def token_width(self) -> int:
        """Number of tokens produced per document (``num_tokens_output``)."""
        return self.num_tokens_output

    def to(self, dtype: torch.dtype, device: torch.device) -> "DocumentGTEModality":
        """Record the output dtype/device and move the encoder module.

        When the ``DOCUMENT_GTE_FORCE_CPU`` environment variable is present the
        encoder's device is left unchanged; the switch exists because the
        encoder ran out of VRAM on a 24GB GPU.

        Returns:
            ``self``, to allow chaining.
        """
        self.dtype, self.device = dtype, device
        force_cpu = DOCUMENT_GTE_FORCE_CPU in os.environ
        if not force_cpu:
            self.document_gte_device = device
        self.module.to(device=self.document_gte_device)
        return self

    def preprocess_rows(self, rows: List[Dict]) -> List[Dict]:
        """Tokenize the documents of every row.

        Args:
            rows: Rows that each hold an iterable of documents under
                ``data_key``.

        Returns:
            One tokenizer output per row, padded and truncated to
            ``GTE_CONTEXT_WINDOW`` and returned as PyTorch tensors.
        """
        tokenizer_options = dict(
            max_length=GTE_CONTEXT_WINDOW,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        return [
            self.tokenizer(list(row[self.data_key]), **tokenizer_options)
            for row in rows
        ]

    @torch.no_grad()
    def forward(self, encoded_values: List[Dict]) -> List[torch.Tensor]:
        """Embed each row's tokenized documents.

        Args:
            encoded_values: Tokenizer outputs as produced by ``preprocess_rows``.

        Returns:
            One tensor per row, viewed as (num_items, 1, embedding_size) and
            placed on ``self.device`` with ``self.dtype``.
        """

        def embed(batch) -> torch.Tensor:
            # Run the encoder on its own device, then move the result to the
            # device/dtype requested through ``to``.
            on_encoder_device = batch.to(device=self.document_gte_device)
            pooled = self.module.forward(on_encoder_device)
            moved = pooled.to(device=self.device, dtype=self.dtype)
            return moved.view(-1, 1, self.module.embedding_size)

        # batch_size x num_items x 1 x embedding_size
        return [embed(batch) for batch in encoded_values]