Spaces:
Running
Running
from fastapi import FastAPI | |
from pydantic import BaseModel | |
from sentence_transformers import SentenceTransformer | |
from transformers import AutoTokenizer, AutoModelForMaskedLM | |
from typing import List | |
import torch | |
from functools import lru_cache | |
import logging | |
from datetime import datetime | |
from collections import defaultdict | |
# Configure logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the FastAPI app.
app = FastAPI()
logger.info("Starting FastAPI application")

# Load the dense SentenceTransformer encoders once at import time, on CPU.
logger.info("Loading BGE small model...")
bge_small_model = SentenceTransformer('BAAI/bge-small-en-v1.5', device="cpu")
logger.info("Loaded BGE small model")

logger.info("Loading All-MPNet model...")
all_mp_net_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device="cpu")
logger.info("Loaded All-MPNet model")

# Load the SPLADE masked-LM and its tokenizer (used for sparse vectors).
logger.info("Loading SPLADE model...")
# trust_remote_code is deliberately NOT passed: the checkpoint is loaded through
# the stock AutoModelForMaskedLM class, so opting in to executing Python code
# shipped inside the model repository is unnecessary attack surface.
SPLADE_MODEL = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil")
SPLADE_TOKENIZER = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
# Inference only: switch off dropout and other training-time behavior.
SPLADE_MODEL.eval()
logger.info("Loaded SPLADE model")
# Request and response models
class TextInput(BaseModel):
    """Request payload: a batch of texts plus the name of the encoder to use."""

    # Texts to embed; the endpoint produces one embedding per entry.
    text: List[str]
    # Encoder selector. get_embedding upper-cases it and recognises
    # "BM", "BG" and "SPLADE"; anything else takes its fallback branch.
    model_name: str
class SparseVector(BaseModel):
    """Sparse embedding stored as two parallel lists."""

    # Vocabulary ids of the active dimensions (as filled by encode_splade_cached).
    indices: List[int]
    # Weight of each active dimension; values[i] belongs to indices[i].
    values: List[float]
# LRU cacheable versions
@lru_cache(maxsize=1024)
def _encode_dense_tuple(model_name: str, text: str) -> tuple:
    """Run the dense encoder once per distinct (model_name, text) pair.

    The result is stored as a tuple so the cached value cannot be mutated
    by a caller. The cache is bounded because the set of texts is unbounded.
    """
    # "BM" selects All-MPNet; every other key falls through to BGE small.
    if model_name == "BM":
        model = all_mp_net_model
    else:
        model = bge_small_model
    return tuple(model.encode([text])[0].tolist())


def encode_dense_cached(model_name: str, text: str):
    """Return the dense embedding of ``text`` as a list of floats.

    Args:
        model_name: ``"BM"`` uses the All-MPNet model; any other value uses
            the BGE small model.
        text: The single text to embed.

    Returns:
        A fresh list holding the embedding. Repeated calls with the same
        arguments are served from an LRU cache instead of re-running the
        encoder, and each call gets its own list so callers may mutate it.
    """
    logger.info(f"Encoding dense text with model {model_name}: {text}")
    embedding = list(_encode_dense_tuple(model_name, text))
    logger.info("Finished encoding dense text")
    return embedding
@lru_cache(maxsize=1024)
def _splade_terms(text: str) -> tuple:
    """Compute SPLADE term weights for ``text`` once per distinct text.

    Returns a pair ``(indices, values)`` of equal-length tuples; both are
    empty when no activation is positive. Tuples keep the cached result
    immutable, and the cache is bounded because the set of texts is unbounded.
    """
    inputs = SPLADE_TOKENIZER(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = SPLADE_MODEL(**inputs)

    # First (only) batch item; indexed below as [token position, vocabulary id].
    logits = outputs.logits[0]
    relu_log = torch.log1p(torch.relu(logits))

    # Each row of `nonzero` is a (token position, vocabulary id) pair.
    nonzero = relu_log.nonzero(as_tuple=False)
    if nonzero.shape[0] == 0:
        return (), ()

    vocab_indices_list = nonzero[:, 1].tolist()
    values_list = relu_log[nonzero[:, 0], nonzero[:, 1]].tolist()

    # Collapse duplicates: a vocabulary id active at several token positions
    # gets the SUM of its per-position weights.
    # NOTE(review): SPLADE checkpoints are commonly pooled with max over
    # positions rather than sum - confirm which aggregation the downstream
    # index expects before changing this; it alters every returned value.
    index_to_value = defaultdict(float)
    for idx, val in zip(vocab_indices_list, values_list):
        index_to_value[idx] += val

    return tuple(index_to_value.keys()), tuple(index_to_value.values())


def encode_splade_cached(text: str) -> SparseVector:
    """Return the SPLADE sparse vector for ``text``.

    Args:
        text: The single text to encode (truncated by the tokenizer).

    Returns:
        A new ``SparseVector`` whose ``indices`` are vocabulary ids and whose
        ``values`` are the matching weights; both are empty when nothing is
        active. Repeated calls with the same text reuse an LRU-cached result,
        and each call builds a fresh model instance from it.
    """
    logger.info(f"Encoding SPLADE sparse vector: {text}")
    deduped_indices, deduped_values = _splade_terms(text)

    if not deduped_indices:
        logger.info("No non-zero values found in SPLADE output")
        return SparseVector(indices=[], values=[])

    logger.info(f"SPLADE encoding complete with {len(deduped_indices)} dimensions")
    return SparseVector(
        indices=list(deduped_indices),
        values=list(deduped_values),
    )
# Main endpoint
async def get_embedding(input: TextInput):
    """Embed every text in the request with the encoder named by ``model_name``.

    The model name is matched case-insensitively:

    * ``"BM"`` / ``"BG"`` -> ``{"type": "dense", "embeddings": [...]}``
    * ``"SPLADE"``        -> ``{"type": "sparse", "embeddings": [...]}``
    * anything else       -> ``{"embeddings": [...]}`` from the BGE small
      model, encoded as one batch and without a ``"type"`` key.

    The encoders are called synchronously from inside this coroutine.

    NOTE(review): no route decorator is visible on this function in this
    view - confirm it is registered on ``app``.
    """
    logger.info(f"Received request with model: {input.model_name}, texts: {input.text}")
    requested = input.model_name.upper()

    # Sparse path: one SPLADE vector per text, dumped to plain dicts.
    if requested == "SPLADE":
        sparse_vecs = []
        for t in input.text:
            sparse_vecs.append(encode_splade_cached(t).model_dump())
        logger.info(f"Returning sparse embeddings for {len(sparse_vecs)} texts")
        return {"type": "sparse", "embeddings": sparse_vecs}

    # Dense path: one embedding per text via the per-text encoder.
    if requested in ("BM", "BG"):
        embeddings = []
        for t in input.text:
            embeddings.append(encode_dense_cached(requested, t))
        logger.info(f"Returning dense embeddings for {len(embeddings)} texts")
        return {"type": "dense", "embeddings": embeddings}

    # Unrecognised model name: fall back to a direct batch encode with BGE small.
    return {"embeddings": bge_small_model.encode(input.text).tolist()}
async def status():
    """Log the current time and report that the server is alive.

    NOTE(review): no route decorator is visible on this function in this
    view - confirm it is registered on ``app``.
    """
    logger.info(f"Status API: Server is up and running at {datetime.now()}")
    payload = {"status": "Server is up and running"}
    return payload