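"""PyLate MCP server: builds a PLAID index over the BRIGHT leetcode corpus with a
ColBERT model and serves multi-vector search and document lookup."""
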
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

import mteb
from sqlitedict import SqliteDict

from pylate import indexes, models, retrieve

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)


class IndexType(Enum):
    """Supported index types."""

    PREBUILT = "prebuilt"
    LOCAL = "local"


@dataclass
class IndexConfig:
    """Configuration for a search index."""

    name: str
    type: IndexType
    path: str
    description: Optional[str] = None


class MCPyLate:
    """Main server class that manages PyLate indexes and search operations."""

    def __init__(self, override: bool = False):
        self.logger = logging.getLogger(__name__)
        dataset_name = "leetcode"
        model_name = "lightonai/Reason-ModernColBERT"
        index_name = f"{dataset_name}_{model_name.split('/')[-1]}"

        # Rebuild the index when explicitly requested or when it does not exist yet.
        override = override or not os.path.exists(f"indexes/{index_name}")

        self.model = models.ColBERT(
            model_name_or_path=model_name,
        )
        self.index = indexes.PLAID(
            override=override,
            index_name=index_name,
        )
        # Maps document ids to their raw text, stored alongside the PLAID index.
        self.id_to_doc = SqliteDict(
            f"./indexes/{index_name}/id_to_doc.sqlite",
            outer_stack=False,
        )

        self.retriever = retrieve.ColBERT(index=self.index)
        if override:
            # Load the BRIGHT leetcode corpus and (re)build the index from scratch.
            tasks = mteb.get_tasks(tasks=["BrightRetrieval"])
            tasks[0].load_data()
            corpus = tasks[0].corpus[dataset_name]["standard"]
            for doc_id, doc in corpus.items():
                self.id_to_doc[doc_id] = doc
            self.id_to_doc.commit()  # Persist the id -> document mapping.
            documents_embeddings = self.model.encode(
                sentences=list(corpus.values()),
                batch_size=100,
                is_query=False,
                show_progress_bar=True,
            )

            self.index.add_documents(
                documents_ids=list(corpus.keys()),
                documents_embeddings=documents_embeddings,
            )
        self.logger.info("Created PyLate MCP Server")

    def get_document(
        self,
        docid: str,
    ) -> Optional[Dict[str, Any]]:
        """Retrieve the full document by document ID, or None if the ID is unknown."""
        if docid not in self.id_to_doc:
            return None
        return {"docid": docid, "text": self.id_to_doc[docid]}

    def search(self, query: str, k: int = 10) -> List[Dict[str, Any]]:
        """Perform multi-vector (late-interaction) search over the PLAID index."""
        try:
            query_embeddings = self.model.encode(
                sentences=[query],
                is_query=True,
                show_progress_bar=True,
                batch_size=32,
            )
            scores = self.retriever.retrieve(queries_embeddings=query_embeddings, k=k)
            results = []
            for score in scores[0]:
                results.append(
                    {
                        "docid": score["id"],
                        "score": round(score["score"], 5),
                        "text": self.id_to_doc[score["id"]],
                        # "text": self.id_to_doc[score["id"]][:200] + "…"
                        # if len(self.id_to_doc[score["id"]]) > 200
                        # else self.id_to_doc[score["id"]],
                    }
                )
            return results
        except Exception as e:
            self.logger.error(f"Search failed: {e}")
            raise RuntimeError(f"Search operation failed: {e}") from e
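

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module): build or
    # load the leetcode index, run one query, then fetch the top hit's full document.
    server = MCPyLate()
    hits = server.search("two sum problem using a hash map", k=5)
    for hit in hits:
        print(hit["docid"], hit["score"])
    if hits:
        top_doc = server.get_document(hits[0]["docid"])
        print(top_doc["text"][:200] if top_doc else "document not found")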