from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from typing import List
import torch
from functools import lru_cache
import logging
from datetime import datetime
from collections import defaultdict

# 🔧 Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 🚀 Initialize FastAPI app
app = FastAPI()
logger.info("Starting FastAPI application")

# 🔌 Load SentenceTransformer models
logger.info("Loading BGE small model...")
bge_small_model = SentenceTransformer('BAAI/bge-small-en-v1.5', device="cpu")
logger.info("Loaded BGE small model")

logger.info("Loading All-MPNet model...")
all_mp_net_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device="cpu")
logger.info("Loaded All-MPNet model")

# 🔌 Load SPLADE model
logger.info("Loading SPLADE model...")
SPLADE_MODEL = AutoModelForMaskedLM.from_pretrained(
    "naver/splade-cocondenser-ensembledistil", trust_remote_code=True
)
SPLADE_TOKENIZER = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
SPLADE_MODEL.eval()
logger.info("Loaded SPLADE model")

# 📦 Request and response models
class TextInput(BaseModel):
    text: List[str]
    model_name: str

class SparseVector(BaseModel):
    indices: List[int]
    values: List[float]

# 🧠 LRU-cacheable versions
@lru_cache(maxsize=1000)
def encode_dense_cached(model_name: str, text: str):
    logger.info(f"Encoding dense text with model {model_name}: {text}")
    if model_name == "BM":
        embedding = all_mp_net_model.encode([text])[0].tolist()
    else:
        embedding = bge_small_model.encode([text])[0].tolist()
    logger.info("Finished encoding dense text")
    return embedding

@lru_cache(maxsize=1000)
def encode_splade_cached(text: str) -> SparseVector:
    logger.info(f"Encoding SPLADE sparse vector: {text}")
    inputs = SPLADE_TOKENIZER(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = SPLADE_MODEL(**inputs)
    logits = outputs.logits[0]  # (seq_len, vocab_size)
    # SPLADE activation: log-saturated ReLU over the MLM logits
    relu_log = torch.log1p(torch.relu(logits))
    nonzero = relu_log.nonzero(as_tuple=False)
    if nonzero.shape[0] == 0:
        logger.info("No non-zero values found in SPLADE output")
        return SparseVector(indices=[], values=[])
    vocab_indices = nonzero[:, 1]
    values = relu_log[nonzero[:, 0], nonzero[:, 1]]
    vocab_indices_list = vocab_indices.cpu().numpy().tolist()
    values_list = values.cpu().numpy().tolist()
    # Collapse per-token activations into one weight per vocabulary index.
    # Weights are summed here; note the canonical SPLADE formulation
    # max-pools over token positions instead.
    index_to_value = defaultdict(float)
    for idx, val in zip(vocab_indices_list, values_list):
        index_to_value[idx] += val
    deduped_indices = list(index_to_value.keys())
    deduped_values = list(index_to_value.values())
    logger.info(f"SPLADE encoding complete with {len(deduped_indices)} dimensions")
    return SparseVector(indices=deduped_indices, values=deduped_values)

# 🚀 Main endpoint
@app.post("/get-embedding/")
async def get_embedding(input: TextInput):
    logger.info(f"Received request with model: {input.model_name}, texts: {input.text}")
    model_key = input.model_name.upper()
    if model_key in {"BM", "BG"}:
        embeddings = [encode_dense_cached(model_key, t) for t in input.text]
        logger.info(f"Returning dense embeddings for {len(embeddings)} texts")
        return {"type": "dense", "embeddings": embeddings}
    elif model_key == "SPLADE":
        sparse_vecs = [encode_splade_cached(t).model_dump() for t in input.text]
        logger.info(f"Returning sparse embeddings for {len(sparse_vecs)} texts")
        return {"type": "sparse", "embeddings": sparse_vecs}
    else:
        # Fall back to BGE small for unrecognized model names
        embeddings = bge_small_model.encode(input.text)
        return {"embeddings": embeddings.tolist()}
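# Illustrative request/response shapes for /get-embedding/ (a sketch based on
# the models loaded above; the payload values are assumptions, not captured
# output — bge-small-en-v1.5 produces 384-dim vectors, all-mpnet-base-v2 768-dim):
#
#   POST /get-embedding/  {"text": ["hello world"], "model_name": "SPLADE"}
#   -> {"type": "sparse", "embeddings": [{"indices": [...], "values": [...]}]}
#
#   POST /get-embedding/  {"text": ["hello world"], "model_name": "BG"}
#   -> {"type": "dense", "embeddings": [[0.01, -0.02, ...]]}  # 384 dims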
@app.get("/status") async def status(): logger.info(f"Status API: Server is up and running at {datetime.now()}") return {"status": "Server is up and running"}