File size: 2,608 Bytes
9ba0b20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca58d4e
 
 
 
 
 
 
 
328e24f
 
 
9ba0b20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e00bb1
9ba0b20
 
 
 
 
 
 
 
 
 
 
328e24f
9ba0b20
328e24f
9ba0b20
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from typing import Any, Dict, List


# Mean pooling, following the recipe given in the model card.
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, counting only positions the mask keeps.

    Args:
        model_output: model output whose first element (``model_output[0]``)
            holds the per-token embeddings.
        attention_mask: mask whose shape equals the token embeddings' shape
            without the last axis (required by the ``expand`` below);
            presumably 1 for real tokens and 0 for padding.

    Returns:
        Masked sum of the token embeddings over axis 1, divided by the
        per-row mask total (clamped to at least 1e-9 so fully-masked rows
        give zeros instead of NaNs).
    """
    embeddings = model_output[0]
    # Broadcast the mask over the hidden axis and make it floating point so it
    # can act as a per-position weight.
    weights = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = (embeddings * weights).sum(1)
    counts = weights.sum(1).clamp(min=1e-9)
    return summed / counts

class EndpointHandler:
    """Request handler that returns L2-normalised mean-pooled sentence embeddings.

    Presumably used as a Hugging Face Inference Endpoints custom handler
    (constructed with a model path, then called once per request payload).
    """

    def __init__(self, path="./"):
        """Load the tokenizer and a TorchScript-traced copy of the model.

        Args:
            path: location passed to ``from_pretrained`` for both the model
                and the tokenizer.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Put the eager model on the target device *before* tracing so the
        # dummy inputs, and anything the tracer records from them, live on the
        # same device the traced module will run on.
        model = AutoModel.from_pretrained(path, torchscript=True).to(self.device)
        model.eval()

        # Dummy (input_ids, attention_mask) pair used only to record the trace:
        # batch of 2, sequence length 128, values drawn from [0, 100).
        # NOTE(review): the Transformers TorchScript guide warns that a trace can
        # be tied to the dummy input's dimensions, while __call__ pads to the
        # longest input (up to max_length, default 512) -- confirm this model's
        # trace handles other batch sizes / sequence lengths.
        dummy_inputs = (
            torch.randint(0, 100, (2, 128), device=self.device),
            torch.randint(0, 100, (2, 128), device=self.device),
        )
        self.model = torch.jit.trace(model, dummy_inputs)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[List[float]]]:
        """Embed the text(s) in a request payload.

        Args:
            data: request payload. ``data["inputs"]`` is passed to the tokenizer
                (if that key is missing, ``data`` itself is forwarded). An
                optional ``data["parameters"]["max_length"]`` sets the
                truncation length (default 512). The payload is not modified.

        Returns:
            ``{"embeddings": [[float, ...], ...]}`` -- one L2-normalised,
            mean-pooled vector per tokenized input row.
        """
        # Read with .get() instead of .pop() so the caller's payload and its
        # nested "parameters" dict are left untouched.
        inputs = data.get("inputs", data)
        parameters = data.get("parameters") or {}
        max_length = parameters.get("max_length", 512)

        with torch.inference_mode():
            encoded = self.tokenizer(
                inputs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=max_length,
            ).to(self.device)

            # The traced module takes (input_ids, attention_mask) positionally,
            # in the same order as the dummy inputs used for tracing.
            model_output = self.model(encoded.input_ids, encoded.attention_mask)

            sentence_embeddings = mean_pooling(model_output, encoded.attention_mask)
            # Scale each row to unit L2 norm, so dot products between the
            # returned embeddings equal cosine similarities.
            sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        return {"embeddings": sentence_embeddings.cpu().tolist()}