File size: 1,061 Bytes
59f9023
 
 
 
 
 
 
8c67124
59f9023
8c67124
 
59f9023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c67124
59f9023
8c67124
59f9023
 
68b9b1e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import logging
import re

import torch
import torch.nn.functional as F
from fastapi import FastAPI, Depends
from fastapi import FastAPI, Query
from fastapi import HTTPException
from fastapi_health import health
from pinecone_text.sparse import SpladeEncoder
from pydantic import BaseModel

class TextPayload(BaseModel):
    """Request body for the ``/embed-text-sparse/`` endpoint."""

    text: str  # the raw text to be sparse-encoded

def get_session():
    """Provide the stand-in "session" consumed by the health-check dependency.

    No real connection is opened; the value handed back is always ``True``.
    """
    session_ok = True
    return session_ok

def is_database_online(session: bool = Depends(get_session)) -> bool:
    """Health-check condition used by the ``/healthz`` route.

    ``session`` is injected by FastAPI via ``Depends(get_session)`` and is
    returned unchanged. Because ``get_session`` always returns ``True``, this
    check currently always reports healthy.
    NOTE(review): placeholder — this service has no database visible here;
    replace with a real readiness probe (e.g. model loaded) if one is needed.
    """
    return session

# ASGI application; routes below and the embedding endpoint attach to it.
app = FastAPI()
# Liveness/readiness endpoint built by fastapi_health from the listed conditions.
app.add_api_route(
    path="/healthz",
    endpoint=health([is_database_online]),
)

class Load_EmbeddingModels:
    """Own a SPLADE sparse encoder and expose a helper to encode text with it."""

    def __init__(self):
        """Pick the compute device and construct the ``SpladeEncoder`` on it."""
        # Use the GPU whenever torch reports CUDA as available, else the CPU.
        if torch.cuda.is_available():
            chosen = "cuda"
        else:
            chosen = "cpu"
        self.device = chosen
        self.sparse_model = SpladeEncoder(device=chosen)

    def get_single_sparse_text_embedding(self, df_chunk):
        """Encode ``df_chunk`` with the encoder's ``encode_documents``.

        Despite the parameter name, the ``/embed-text-sparse/`` handler passes a
        plain ``str`` here. The encoder's result is returned unchanged
        (presumably a dict of token indices and weights — confirm against the
        pinecone_text version in use).
        """
        encoder = self.sparse_model
        return encoder.encode_documents(df_chunk)


# Created once at import time and shared by every request served by this
# process; construction builds the SpladeEncoder on the selected device.
model = Load_EmbeddingModels()

@app.post("/embed-text-sparse/")
async def embed_text(payload: TextPayload):
    """Return the sparse embedding of ``payload.text``.

    The encoder's result is returned as-is and serialized by FastAPI.

    Raises:
        HTTPException: status 500 when encoding fails. The original handler
            only printed the error and returned ``None``, so clients saw a
            200 response with a ``null`` body.
    """
    # NOTE(review): the encode runs synchronously inside an ``async def``
    # handler, so it blocks the event loop (including /healthz) while it
    # runs. Offloading to a threadpool would also make encodes concurrent;
    # confirm the encoder/tokenizer is thread-safe before changing that.
    try:
        embeddings = model.get_single_sparse_text_embedding(payload.text)
    except Exception as e:
        # Request boundary: keep the full traceback in server logs and report
        # a real error status instead of a misleading success.
        logging.getLogger(__name__).exception("Error: %s", e)
        raise HTTPException(
            status_code=500,
            detail="Failed to compute sparse embedding",
        ) from e
    return embeddings