import logging import os from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Optional import mteb from sqlitedict import SqliteDict from pylate import indexes, models, retrieve # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) class IndexType(Enum): """Supported index types.""" PREBUILT = "prebuilt" LOCAL = "local" @dataclass class IndexConfig: """Configuration for a search index.""" name: str type: IndexType path: str description: Optional[str] = None class MCPyLate: """Main server class that manages PyLate indexes and search operations.""" def __init__(self, override: bool = False): self.logger = logging.getLogger(__name__) dataset_name = "leetcode" model_name = "lightonai/Reason-ModernColBERT" override = override or not os.path.exists( f"indexes/{dataset_name}_{model_name.split('/')[-1]}" ) self.model = models.ColBERT( model_name_or_path=model_name, ) self.index = indexes.PLAID( override=override, index_name=f"{dataset_name}_{model_name.split('/')[-1]}", ) self.id_to_doc = SqliteDict( f"./indexes/{dataset_name}_{model_name.split('/')[-1]}/id_to_doc.sqlite", outer_stack=False, ) self.retriever = retrieve.ColBERT(index=self.index) if override: tasks = mteb.get_tasks(tasks=["BrightRetrieval"]) tasks[0].load_data() for doc, doc_id in zip( list(tasks[0].corpus[dataset_name]["standard"].values()), list(tasks[0].corpus[dataset_name]["standard"].keys()), ): self.id_to_doc[doc_id] = doc self.id_to_doc.commit() # Don't forget to commit to save changes! documents_embeddings = self.model.encode( sentences=list(tasks[0].corpus[dataset_name]["standard"].values()), batch_size=100, is_query=False, show_progress_bar=True, ) self.index.add_documents( documents_ids=list(tasks[0].corpus[dataset_name]["standard"].keys()), documents_embeddings=documents_embeddings, ) self.logger.info("Created PyLate MCP Server") def get_document( self, docid: str, ) -> Optional[Dict[str, Any]]: """Retrieve full document by document ID.""" return {"docid": docid, "text": self.id_to_doc[docid]} def search(self, query: str, k: int = 10) -> List[Dict[str, Any]]: """Perform multi-vector search on specified index.""" try: query_embeddings = self.model.encode( sentences=[query], is_query=True, show_progress_bar=True, batch_size=32, ) scores = self.retriever.retrieve(queries_embeddings=query_embeddings, k=20) results = [] for score in scores[0]: results.append( { "docid": score["id"], "score": round(score["score"], 5), "text": self.id_to_doc[score["id"]], # "text": self.id_to_doc[score["id"]][:200] + "…" # if len(self.id_to_doc[score["id"]]) > 200 # else self.id_to_doc[score["id"]], } ) return results except Exception as e: self.logger.error(f"Search failed: {e}") raise RuntimeError(f"Search operation failed: {e}")