tejastake commited on
Commit
59f9023
·
verified ·
1 Parent(s): 68b9b1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -35
app.py CHANGED
@@ -1,36 +1,36 @@
1
- import torch.nn.functional as F
2
- import torch
3
- from pinecone_text.sparse import SpladeEncoder
4
- import re
5
- from fastapi import FastAPI, Depends
6
- from fastapi_health import health
7
- from fastapi import FastAPI, Query
8
-
9
-
10
- def get_session():
11
- return True
12
-
13
- def is_database_online(session: bool = Depends(get_session)):
14
- return session
15
-
16
- app = FastAPI()
17
- app.add_api_route("/healthz", health([is_database_online]))
18
-
19
- class Load_EmbeddingModels:
20
- def __init__(self,model):
21
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
22
- self.sparse_model = SpladeEncoder(device=self.device)
23
-
24
- def get_single_sparse_text_embedding(self, df_chunk):
25
- return self.sparse_model.encode_documents(df_chunk)
26
-
27
-
28
- model = Load_EmbeddingModels()
29
-
30
- @app.post("/embed-text-sparse/")
31
- async def embed_text(text: str = Query(...)):
32
- try:
33
- embeddings = model.get_single_sparse_text_embedding(text)
34
- return embeddings
35
- except Exception as e:
36
  print(f'Error: {e}')
 
1
+ import torch.nn.functional as F
2
+ import torch
3
+ from pinecone_text.sparse import SpladeEncoder
4
+ import re
5
+ from fastapi import FastAPI, Depends
6
+ from fastapi_health import health
7
+ from fastapi import FastAPI, Query
8
+
9
+
10
+ def get_session():
11
+ return True
12
+
13
+ def is_database_online(session: bool = Depends(get_session)):
14
+ return session
15
+
16
+ app = FastAPI()
17
+ app.add_api_route("/healthz", health([is_database_online]))
18
+
19
+ class Load_EmbeddingModels:
20
+ def __init__(self):
21
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ self.sparse_model = SpladeEncoder(device=self.device)
23
+
24
+ def get_single_sparse_text_embedding(self, df_chunk):
25
+ return self.sparse_model.encode_documents(df_chunk)
26
+
27
+
28
+ model = Load_EmbeddingModels()
29
+
30
+ @app.post("/embed-text-sparse/")
31
+ async def embed_text(text: str = Query(...)):
32
+ try:
33
+ embeddings = model.get_single_sparse_text_embedding(text)
34
+ return embeddings
35
+ except Exception as e:
36
  print(f'Error: {e}')