from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from typing import List
import torch
from functools import lru_cache
import logging
from datetime import datetime
from collections import defaultdict
# 🔧 Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 🚀 Initialize FastAPI app
app = FastAPI()
logger.info("Starting FastAPI application")
# 🔌 Load SentenceTransformer models
logger.info("Loading BGE small model...")
bge_small_model = SentenceTransformer('BAAI/bge-small-en-v1.5', device="cpu")
logger.info("Loaded BGE small model")
logger.info("Loading All-MPNet model...")
all_mp_net_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device="cpu")
logger.info("Loaded All-MPNet model")
# 🔌 Load SPLADE model
logger.info("Loading SPLADE model...")
SPLADE_MODEL = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil", trust_remote_code=True)
SPLADE_TOKENIZER = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
SPLADE_MODEL.eval()
logger.info("Loaded SPLADE model")
# 📦 Request and response models
class TextInput(BaseModel):
    text: List[str]
    model_name: str

class SparseVector(BaseModel):
    indices: List[int]
    values: List[float]
# 🧠 LRU cacheable versions
@lru_cache(maxsize=1000)
def encode_dense_cached(model_name: str, text: str) -> List[float]:
    logger.info(f"Encoding dense text with model {model_name}: {text}")
    # "BM" selects all-mpnet-base-v2; anything else falls back to bge-small.
    if model_name == "BM":
        embedding = all_mp_net_model.encode([text])[0].tolist()
    else:
        embedding = bge_small_model.encode([text])[0].tolist()
    logger.info("Finished encoding dense text")
    return embedding
@lru_cache(maxsize=1000)
def encode_splade_cached(text: str) -> SparseVector:
    logger.info(f"Encoding SPLADE sparse vector: {text}")
    inputs = SPLADE_TOKENIZER(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = SPLADE_MODEL(**inputs)
    # MLM logits for the single input sequence: [seq_len, vocab_size]
    logits = outputs.logits[0]
    # SPLADE activation: log(1 + ReLU(logit)) keeps only positive evidence
    relu_log = torch.log1p(torch.relu(logits))
    nonzero = relu_log.nonzero(as_tuple=False)
    if nonzero.shape[0] == 0:
        logger.info("No non-zero values found in SPLADE output")
        return SparseVector(indices=[], values=[])
    vocab_indices = nonzero[:, 1]
    values = relu_log[nonzero[:, 0], nonzero[:, 1]]
    vocab_indices_list = vocab_indices.cpu().numpy().tolist()
    values_list = values.cpu().numpy().tolist()
    # Collapse token-level activations to one weight per vocab index by
    # summing across positions (note: the SPLADE papers max-pool over the
    # sequence instead, so weights here can differ from canonical SPLADE).
    index_to_value = defaultdict(float)
    for idx, val in zip(vocab_indices_list, values_list):
        index_to_value[idx] += val
    deduped_indices = list(index_to_value.keys())
    deduped_values = list(index_to_value.values())
    logger.info(f"SPLADE encoding complete with {len(deduped_indices)} dimensions")
    return SparseVector(
        indices=deduped_indices,
        values=deduped_values,
    )
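
# Shape illustration for the sparse output (hypothetical call; the actual
# indices and values depend on the model's vocabulary and weights):
#   vec = encode_splade_cached("sparse retrieval")
#   vec.indices -> vocab ids with non-zero activations (List[int])
#   vec.values  -> matching positive weights, same length as vec.indices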
# 🚀 Main endpoint
@app.post("/get-embedding/")
async def get_embedding(payload: TextInput):
    logger.info(f"Received request with model: {payload.model_name}, texts: {payload.text}")
    model_key = payload.model_name.upper()
    if model_key in {"BM", "BG"}:
        embeddings = [encode_dense_cached(model_key, t) for t in payload.text]
        logger.info(f"Returning dense embeddings for {len(embeddings)} texts")
        return {"type": "dense", "embeddings": embeddings}
    elif model_key == "SPLADE":
        sparse_vecs = [encode_splade_cached(t).model_dump() for t in payload.text]
        logger.info(f"Returning sparse embeddings for {len(sparse_vecs)} texts")
        return {"type": "sparse", "embeddings": sparse_vecs}
    else:
        # Unrecognized model names fall back to bge-small, batch-encoded and
        # uncached; "type" is included so the response shape stays consistent.
        embeddings = bge_small_model.encode(payload.text)
        return {"type": "dense", "embeddings": embeddings.tolist()}
@app.get("/status")
async def status():
    logger.info(f"Status API: Server is up and running at {datetime.now()}")
    return {"status": "Server is up and running"}