from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForMaskedLM
from typing import List
import torch
from functools import lru_cache
import logging
from datetime import datetime
from collections import defaultdict


# πŸ”§ Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# πŸš€ Initialize FastAPI app
app = FastAPI()
logger.info("Starting FastAPI application")

# πŸ”Œ Load SentenceTransformer models
logger.info("Loading BGE small model...")
bge_small_model = SentenceTransformer('BAAI/bge-small-en-v1.5', device="cpu")
logger.info("Loaded BGE small model")

logger.info("Loading All-MPNet model...")
all_mp_net_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device="cpu")
logger.info("Loaded All-MPNet model")

# πŸ”Œ Load SPLADE model
logger.info("Loading SPLADE model...")
SPLADE_MODEL = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-ensembledistil", trust_remote_code=True)
SPLADE_TOKENIZER = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
SPLADE_MODEL.eval()
logger.info("Loaded SPLADE model")

# πŸ“¦ Request and response models
class TextInput(BaseModel):
    text: List[str]
    model_name: str

class SparseVector(BaseModel):
    indices: List[int]
    values: List[float]
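
# Example payload accepted by /get-embedding/ (see endpoint below):
#   {"text": ["hello world"], "model_name": "SPLADE"}
# model_name is matched case-insensitively against "BM", "BG", and "SPLADE".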

# 🧠 LRU-cached encoding helpers
@lru_cache(maxsize=1000)
def encode_dense_cached(model_name: str, text: str):
    logger.info(f"Encoding dense text with model {model_name}: {text}")
    if model_name == "BM":
        # "BM" selects the all-mpnet-base-v2 model
        embedding = all_mp_net_model.encode([text])[0].tolist()
    else:
        # any other key (in practice "BG") selects the BGE small model
        embedding = bge_small_model.encode([text])[0].tolist()
    logger.info("Finished encoding dense text")
    return embedding
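
# Note: lru_cache keys on the exact (model_name, text) argument pair, which is why
# these helpers take a single string rather than the List[str] carried by TextInput.
# Entries can be evicted manually with encode_dense_cached.cache_clear() if needed.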

@lru_cache(maxsize=1000)
def encode_splade_cached(text: str) -> SparseVector:
    logger.info(f"Encoding SPLADE sparse vector: {text}")
    inputs = SPLADE_TOKENIZER(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        outputs = SPLADE_MODEL(**inputs)

    logits = outputs.logits[0]                      # (seq_len, vocab_size)
    # SPLADE activation: log(1 + ReLU(logits)) turns MLM logits into sparse term weights
    relu_log = torch.log1p(torch.relu(logits))
    nonzero = relu_log.nonzero(as_tuple=False)      # (position, vocab_id) pairs with weight > 0

    if nonzero.shape[0] == 0:
        logger.info("No non-zero values found in SPLADE output")
        return SparseVector(indices=[], values=[])

    vocab_indices = nonzero[:, 1]
    values = relu_log[nonzero[:, 0], nonzero[:, 1]]

    vocab_indices_list = vocab_indices.cpu().numpy().tolist()
    values_list = values.cpu().numpy().tolist()

    # A vocabulary id can fire at several token positions; sum its contributions
    # so each index appears exactly once in the sparse vector
    index_to_value = defaultdict(float)
    for idx, val in zip(vocab_indices_list, values_list):
        index_to_value[idx] += val

    deduped_indices = list(index_to_value.keys())
    deduped_values = list(index_to_value.values())

    logger.info(f"SPLADE encoding complete with {len(deduped_indices)} dimensions")
    return SparseVector(
        indices=deduped_indices,
        values=deduped_values
    )
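

# Optional debugging aid (a minimal sketch, not part of the original API): map sparse
# indices back to readable wordpiece tokens to inspect what SPLADE activated for a text.
def splade_top_terms(vec: SparseVector, top_k: int = 10) -> List[tuple]:
    """Return the top_k (token, weight) pairs of a SPLADE sparse vector."""
    pairs = sorted(zip(vec.indices, vec.values), key=lambda p: p[1], reverse=True)[:top_k]
    tokens = SPLADE_TOKENIZER.convert_ids_to_tokens([i for i, _ in pairs])
    return [(tok, round(val, 4)) for tok, (_, val) in zip(tokens, pairs)]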


# πŸš€ Main endpoint
@app.post("/get-embedding/")
async def get_embedding(input: TextInput):
    logger.info(f"Received request with model: {input.model_name}, texts: {input.text}")

    model_key = input.model_name.upper()
    if model_key in {"BM", "BG"}:
        embeddings = [encode_dense_cached(model_key, t) for t in input.text]
        logger.info(f"Returning dense embeddings for {len(embeddings)} texts")
        return {"type": "dense", "embeddings": embeddings}
    elif model_key == "SPLADE":
        sparse_vecs = [encode_splade_cached(t).model_dump() for t in input.text]
        logger.info(f"Returning sparse embeddings for {len(sparse_vecs)} texts")
        return {"type": "sparse", "embeddings": sparse_vecs}
    else:
        # Fallback: any unrecognized model name uses the BGE small model, uncached
        embeddings = bge_small_model.encode(input.text)
        return {"embeddings": embeddings.tolist()}


@app.get("/status")
async def status():
    logger.info(f"Status API: Server is up and running at {datetime.now()}")
    return {"status": "Server is up and running"}