from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from typing import List
import torch
from functools import lru_cache
import logging
from datetime import datetime
from collections import defaultdict
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI()
logger.info("Starting FastAPI application")
# Load SentenceTransformer dense models
logger.info("Loading BGE small model...")
bge_small_model = SentenceTransformer('BAAI/bge-small-en-v1.5', device="cpu")
logger.info("Loaded BGE small model")
logger.info("Loading All-MPNet model...")
all_mp_net_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device="cpu")
logger.info("Loaded All-MPNet model")
# Load SPLADE sparse model
logger.info("Loading SPLADE model...")
SPLADE_MODEL = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil", trust_remote_code=True)
SPLADE_TOKENIZER = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
SPLADE_MODEL.eval()  # inference mode: disables dropout
logger.info("Loaded SPLADE model")
# Request and response models
class TextInput(BaseModel):
    text: List[str]   # texts to embed
    model_name: str   # "BM", "BG", or "SPLADE" (case-insensitive)

class SparseVector(BaseModel):
    indices: List[int]
    values: List[float]
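# Example request body (illustrative values):
#   {"text": ["hello world", "another sentence"], "model_name": "SPLADE"}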
# LRU-cached encoders (keyed on the exact argument values)
@lru_cache(maxsize=1000)
def encode_dense_cached(model_name: str, text: str):
    logger.info(f"Encoding dense text with model {model_name}: {text}")
    # "BM" selects all-mpnet-base-v2; anything else falls back to bge-small.
    if model_name == "BM":
        embedding = all_mp_net_model.encode([text])[0].tolist()
    else:
        embedding = bge_small_model.encode([text])[0].tolist()
    logger.info("Finished encoding dense text")
    return embedding
@lru_cache(maxsize=1000)
def encode_splade_cached(text: str) -> SparseVector:
    logger.info(f"Encoding SPLADE sparse vector: {text}")
    inputs = SPLADE_TOKENIZER(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = SPLADE_MODEL(**inputs)
    logits = outputs.logits[0]  # shape: (seq_len, vocab_size)
    # SPLADE term weighting: log(1 + ReLU(logits))
    relu_log = torch.log1p(torch.relu(logits))
    nonzero = relu_log.nonzero(as_tuple=False)
    if nonzero.shape[0] == 0:
        logger.info("No non-zero values found in SPLADE output")
        return SparseVector(indices=[], values=[])
    vocab_indices = nonzero[:, 1]
    values = relu_log[nonzero[:, 0], nonzero[:, 1]]
    vocab_indices_list = vocab_indices.cpu().numpy().tolist()
    values_list = values.cpu().numpy().tolist()
    # Pool over token positions by summing weights for duplicate vocab ids.
    # (The SPLADE papers describe both sum and max pooling over positions.)
    index_to_value = defaultdict(float)
    for idx, val in zip(vocab_indices_list, values_list):
        index_to_value[idx] += val
    deduped_indices = list(index_to_value.keys())
    deduped_values = list(index_to_value.values())
    logger.info(f"SPLADE encoding complete with {len(deduped_indices)} dimensions")
    return SparseVector(indices=deduped_indices, values=deduped_values)
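# Illustrative result shape (token ids and weights are made-up examples):
#   SparseVector(indices=[2023, 7592], values=[0.91, 1.37])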
# Main endpoint
@app.post("/get-embedding/")
async def get_embedding(input: TextInput):
    logger.info(f"Received request with model: {input.model_name}, texts: {input.text}")
    model_key = input.model_name.upper()
    if model_key in {"BM", "BG"}:
        embeddings = [encode_dense_cached(model_key, t) for t in input.text]
        logger.info(f"Returning dense embeddings for {len(embeddings)} texts")
        return {"type": "dense", "embeddings": embeddings}
    elif model_key == "SPLADE":
        sparse_vecs = [encode_splade_cached(t).model_dump() for t in input.text]
        logger.info(f"Returning sparse embeddings for {len(sparse_vecs)} texts")
        return {"type": "sparse", "embeddings": sparse_vecs}
    else:
        # Unknown model name: fall back to the BGE small dense model (uncached).
        embeddings = bge_small_model.encode(input.text)
        return {"embeddings": embeddings.tolist()}
@app.get("/status")
async def status():
    logger.info(f"Status API: Server is up and running at {datetime.now()}")
    return {"status": "Server is up and running"}
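# --- Local usage sketch (an assumption, not part of the original file) ---
# If this module is saved as app.py, it can be served with uvicorn
# (7860 is the conventional Hugging Face Spaces port; adjust as needed):
#   uvicorn app:app --host 0.0.0.0 --port 7860
# and queried with, e.g.:
#   curl -X POST http://localhost:7860/get-embedding/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": ["hello world"], "model_name": "BG"}'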