File size: 3,782 Bytes
a3e9331 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
import mteb
from sqlitedict import SqliteDict
from pylate import indexes, models, retrieve
# Root logger setup: emit INFO and above, with each record showing the
# timestamp, logger name, level and message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
class IndexType(Enum):
    """Kinds of index a configuration can refer to.

    Each member's value is the lowercase string form of its name, so a
    member can be looked up from that string, e.g. ``IndexType("local")``.
    """

    PREBUILT = "prebuilt"
    LOCAL = "local"
@dataclass
class IndexConfig:
    """Describe a single search index.

    Attributes:
        name: Name of the index.
        type: Which ``IndexType`` the index is.
        path: Location of the index, as a string.
        description: Optional free-text description; ``None`` when omitted.
    """

    name: str
    type: IndexType
    path: str
    description: Optional[str] = None
class MCPyLate:
    """Main server class that manages PyLate indexes and search operations.

    Construction loads the ColBERT model, opens the PLAID index for the
    dataset, opens the SQLite-backed ``id -> document text`` mapping and,
    when a (re)build is requested or no index directory exists yet, encodes
    the corpus and fills both the index and the mapping.
    """

    def __init__(self, override: bool = False):
        """Load the model and open the index, building it when required.

        Args:
            override: Force the index to be rebuilt. A rebuild is also
                triggered when the index directory does not exist.
        """
        self.logger = logging.getLogger(__name__)
        dataset_name = "leetcode"
        model_name = "lightonai/Reason-ModernColBERT"

        # Compute the index name once so the existence check, the PLAID
        # index and the SQLite file can never drift apart.
        index_name = f"{dataset_name}_{model_name.split('/')[-1]}"
        # NOTE(review): assumes PLAID stores its data under "indexes/" by
        # default, since only the index name is passed to it -- confirm.
        index_dir = f"indexes/{index_name}"
        override = override or not os.path.exists(index_dir)

        self.model = models.ColBERT(
            model_name_or_path=model_name,
        )
        self.index = indexes.PLAID(
            override=override,
            index_name=index_name,
        )
        self.id_to_doc = SqliteDict(
            f"./{index_dir}/id_to_doc.sqlite",
            outer_stack=False,
        )
        self.retriever = retrieve.ColBERT(index=self.index)

        if override:
            task = mteb.get_tasks(tasks=["BrightRetrieval"])[0]
            task.load_data()
            corpus = task.corpus[dataset_name]["standard"]
            # Materialise ids and texts once; both the SQLite mapping and
            # the index are built from these same two aligned lists.
            documents_ids = list(corpus.keys())
            documents = list(corpus.values())

            for doc_id, doc in zip(documents_ids, documents):
                self.id_to_doc[doc_id] = doc
            # Writes are only persisted once they are committed.
            self.id_to_doc.commit()

            documents_embeddings = self.model.encode(
                sentences=documents,
                batch_size=100,
                is_query=False,
                show_progress_bar=True,
            )
            self.index.add_documents(
                documents_ids=documents_ids,
                documents_embeddings=documents_embeddings,
            )
        self.logger.info("Created PyLate MCP Server")

    def get_document(
        self,
        docid: str,
    ) -> Optional[Dict[str, Any]]:
        """Retrieve full document by document ID.

        Args:
            docid: Identifier of the document to fetch.

        Returns:
            ``{"docid": ..., "text": ...}`` for a known id, or ``None`` when
            the id is not present in the document store.
        """
        try:
            text = self.id_to_doc[docid]
        except KeyError:
            # Honour the Optional return type instead of leaking KeyError.
            return None
        return {"docid": docid, "text": text}

    def search(self, query: str, k: int = 10) -> List[Dict[str, Any]]:
        """Perform multi-vector search on the index.

        Args:
            query: Query text to encode and search with.
            k: Number of results requested from the retriever.

        Returns:
            One dict per hit, in the order the retriever returned them, with
            keys ``docid``, ``score`` (rounded to 5 decimals) and ``text``
            (the full stored document).

        Raises:
            RuntimeError: If encoding, retrieval or document lookup fails;
                the original exception is attached as ``__cause__``.
        """
        try:
            query_embeddings = self.model.encode(
                sentences=[query],
                is_query=True,
                show_progress_bar=True,
                batch_size=32,
            )
            # Pass the caller's k through; it used to be hard-coded to 20.
            scores = self.retriever.retrieve(
                queries_embeddings=query_embeddings, k=k
            )
            # A single query was encoded, so only scores[0] is relevant.
            return [
                {
                    "docid": hit["id"],
                    "score": round(hit["score"], 5),
                    "text": self.id_to_doc[hit["id"]],
                }
                for hit in scores[0]
            ]
        except Exception as e:
            self.logger.exception(f"Search failed: {e}")
            raise RuntimeError(f"Search operation failed: {e}") from e
|