import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

import mteb
from sqlitedict import SqliteDict

from pylate import indexes, models, retrieve

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)


class IndexType(Enum):
    """Supported index types."""

    PREBUILT = "prebuilt"
    LOCAL = "local"


@dataclass
class IndexConfig:
    """Configuration for a search index."""

    name: str
    type: IndexType
    path: str
    description: Optional[str] = None


class MCPyLate:
    """Main server class that manages PyLate indexes and search operations."""

    def __init__(self, override: bool = False):
        self.logger = logging.getLogger(__name__)
        dataset_name = "leetcode"
        model_name = "lightonai/Reason-ModernColBERT"
        index_name = f"{dataset_name}_{model_name.split('/')[-1]}"

        # Rebuild the index when explicitly requested or when it does not exist yet.
        override = override or not os.path.exists(f"indexes/{index_name}")

        self.model = models.ColBERT(model_name_or_path=model_name)
        self.index = indexes.PLAID(
            override=override,
            index_name=index_name,
        )
        # Persistent docid -> document mapping, stored alongside the index files.
        self.id_to_doc = SqliteDict(
            f"./indexes/{index_name}/id_to_doc.sqlite",
            outer_stack=False,
        )
        self.retriever = retrieve.ColBERT(index=self.index)

        if override:
            # Load the BRIGHT "leetcode" corpus from MTEB and (re)build the index.
            tasks = mteb.get_tasks(tasks=["BrightRetrieval"])
            tasks[0].load_data()
            corpus = tasks[0].corpus[dataset_name]["standard"]

            for doc_id, doc in corpus.items():
                self.id_to_doc[doc_id] = doc
            self.id_to_doc.commit()

            documents_embeddings = self.model.encode(
                sentences=list(corpus.values()),
                batch_size=100,
                is_query=False,
                show_progress_bar=True,
            )
            self.index.add_documents(
                documents_ids=list(corpus.keys()),
                documents_embeddings=documents_embeddings,
            )

        self.logger.info("Created PyLate MCP Server")

    def get_document(self, docid: str) -> Optional[Dict[str, Any]]:
        """Retrieve the full document for a document ID, or None if it is unknown."""
        if docid not in self.id_to_doc:
            return None
        return {"docid": docid, "text": self.id_to_doc[docid]}

    def search(self, query: str, k: int = 10) -> List[Dict[str, Any]]:
        """Perform multi-vector search against the loaded index."""
        try:
            query_embeddings = self.model.encode(
                sentences=[query],
                is_query=True,
                show_progress_bar=True,
                batch_size=32,
            )
            # Retrieve the top-k candidates for the single query.
            scores = self.retriever.retrieve(queries_embeddings=query_embeddings, k=k)
            results = []
            for score in scores[0]:
                results.append(
                    {
                        "docid": score["id"],
                        "score": round(score["score"], 5),
                        "text": self.id_to_doc[score["id"]],
                    }
                )
            return results
        except Exception as e:
            self.logger.error(f"Search failed: {e}")
            raise RuntimeError(f"Search operation failed: {e}") from e
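

# --- Illustrative usage sketch (an assumption, not part of the original module) ---
# Assumes pylate, mteb, and sqlitedict are installed; the first run downloads the
# model and builds the PLAID index for the BRIGHT "leetcode" subset, which is slow.
if __name__ == "__main__":
    server = MCPyLate()

    # Run an example query and print the top hits with their scores.
    hits = server.search("find two numbers that sum to a target", k=5)
    for hit in hits:
        print(hit["docid"], hit["score"])

    # Fetch the full text of the best hit, if any results came back.
    if hits:
        doc = server.get_document(hits[0]["docid"])
        if doc is not None:
            print(doc["text"][:200])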