File size: 10,366 Bytes
0d5ac71
dc9d63b
2dd4294
dc9d63b
 
 
0d5ac71
 
 
dc9d63b
2dd4294
a467728
 
 
 
 
 
 
 
dc9d63b
2dd4294
dc9d63b
 
2dd4294
dc9d63b
 
 
0733fd6
2dd4294
9fa019c
2dd4294
9fa019c
 
dc9d63b
 
2dd4294
 
 
 
0733fd6
 
2dd4294
 
 
dc9d63b
 
 
 
2dd4294
 
0733fd6
 
2dd4294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc9d63b
9fa019c
 
0733fd6
dc9d63b
c8dd6f4
0733fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc9d63b
 
2dd4294
 
a467728
dc9d63b
0733fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc9d63b
0733fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc9d63b
 
2dd4294
a467728
 
0733fd6
 
 
dc9d63b
0733fd6
 
 
dc9d63b
0733fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc9d63b
2dd4294
a467728
 
0733fd6
 
 
 
 
 
 
 
 
 
 
 
 
a467728
 
 
0733fd6
 
 
 
 
 
 
 
 
a467728
0733fd6
a467728
 
0733fd6
a467728
0733fd6
 
 
a467728
2dd4294
 
dc9d63b
2dd4294
0733fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc9d63b
 
0d5ac71
0733fd6
 
 
 
 
 
0d5ac71
 
0733fd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
import asyncio
import os
import sys
from typing import Any, List, Dict

import mcp.types as types
from mcp import CreateMessageResult
from mcp.server import Server
from mcp.server.stdio import stdio_server

from ourllm import genratequestionnaire, gradeanswers
from database_module import init_db
from database_module import (
    get_all_models_handler,
    search_models_handler,
    save_diagnostic_data,
    get_baseline_diagnostics,
    save_drift_score,
    register_model_with_capabilities
)

# Initialize data directory and database
# NOTE: these statements run at import time, so merely importing this module
# creates the "data" directory (relative to the current working directory)
# and calls init_db().
DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)
init_db()

# Server instance on which the decorated handlers below are registered.
app = Server("mcp-drift-server")


# === Tool Manifest ===
@app.list_tools()
async def list_tools() -> List[types.Tool]:
    """Describe every tool this server offers, together with its JSON input schema.

    Returns:
        Four ``types.Tool`` entries, in this order: ``run_initial_diagnostics``,
        ``check_drift``, ``get_all_models`` and ``search_models``.
    """

    def string_field(description: str) -> Dict[str, str]:
        # A fresh dict per call, so no two schemas share a nested object.
        return {"type": "string", "description": description}

    def object_schema(properties: Dict[str, Any], required: List[str]) -> Dict[str, Any]:
        return {"type": "object", "properties": properties, "required": required}

    baseline_tool = types.Tool(
        name="run_initial_diagnostics",
        description="Generate and store baseline diagnostics for a connected LLM.",
        inputSchema=object_schema(
            {
                "model": string_field("The name of the model to run diagnostics on"),
                "model_capabilities": string_field(
                    "Full description of the model's capabilities, along with the system prompt."
                ),
            },
            ["model", "model_capabilities"],
        ),
    )

    drift_tool = types.Tool(
        name="check_drift",
        description="Re-run diagnostics and compare to baseline for drift scoring.",
        inputSchema=object_schema(
            {"model": string_field("The name of the model to run diagnostics on")},
            ["model"],
        ),
    )

    listing_tool = types.Tool(
        name="get_all_models",
        description="Retrieve all registered models from the database.",
        inputSchema=object_schema({}, []),
    )

    search_tool = types.Tool(
        name="search_models",
        description="Search registered models by name.",
        inputSchema=object_schema(
            {"query": string_field("Substring to match model names against")},
            ["query"],
        ),
    )

    return [baseline_tool, drift_tool, listing_tool, search_tool]


# === Sampling Wrapper ===
async def sample(
    messages: list[types.SamplingMessage],
    max_tokens: int = 600,
    temperature: float = 0.7,
) -> CreateMessageResult:
    """Request a completion from the client attached to the current request.

    Forwards to ``app.request_context.session.create_message``.

    Args:
        messages: Conversation to send for sampling.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature passed through to the client.

    Returns:
        The client's ``CreateMessageResult``. If the request raises, a
        placeholder result is returned instead (text ``"Error generating
        response"``, model ``"unknown"``), so callers never see the exception
        and receive the placeholder as if it were a normal response.
    """
    try:
        return await app.request_context.session.create_message(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
    except Exception as e:
        # This server is run over stdio, where stdout carries the protocol
        # stream; diagnostics therefore go to stderr.
        print(f"Error in sampling: {e}", file=sys.stderr)
        # Return a fallback response
        return CreateMessageResult(
            content=types.TextContent(type="text", text="Error generating response"),
            model="unknown",
            role="assistant",
        )


# === Core Logic ===
def _answer_text(result: Any) -> str:
    """Pull the answer string out of a sampling result.

    Uses ``result.content.text`` when present, ``str(result.content)`` when the
    content has no ``text`` attribute, and a fixed placeholder when the result
    has no ``content`` at all.
    """
    if not hasattr(result, "content"):
        return "No response generated"
    content = result.content
    return content.text if hasattr(content, "text") else str(content)


async def _ask_questions(questions: List[str], what: str = "answer") -> List[str]:
    """Sample one answer per question, preserving order.

    A failure on one question is logged to stderr and recorded as an
    ``"Error: ..."`` answer so the returned list always lines up with
    ``questions``. ``what`` only affects the wording of the log message.
    """
    answers: List[str] = []
    for question_text in questions:
        try:
            prompt = types.SamplingMessage(
                role="user",
                content=types.TextContent(type="text", text=question_text),
            )
            answers.append(_answer_text(await sample([prompt])))
        except Exception as e:
            # stdout is reserved for the stdio protocol stream; log to stderr.
            print(f"Error getting {what} for question '{question_text}': {e}", file=sys.stderr)
            answers.append(f"Error: {str(e)}")
    return answers


async def run_initial_diagnostics(arguments: Dict[str, Any]) -> List[types.TextContent]:
    """Generate a questionnaire, collect the model's answers and store them as baseline.

    Args:
        arguments: Tool arguments; must contain ``"model"`` and
            ``"model_capabilities"``.

    Returns:
        A single-item list with a success or error message.

    Raises:
        KeyError: If a required key is missing from ``arguments`` (the lookups
            happen before the error-handling block).
    """
    model = arguments["model"]
    capabilities = arguments["model_capabilities"]

    try:
        # 1. Generate questionnaire using ourllm (returns list of strings)
        questions = genratequestionnaire(model, capabilities)
        if not questions:
            # A baseline without questions can never be compared against.
            return [types.TextContent(
                type="text",
                text=f"❌ Error running diagnostics for {model}: no questions were generated",
            )]

        # 2. Ask each question through sampling and collect the answers
        answers = await _ask_questions(questions)

        # 3. Save the model capabilities and questions/answers to database
        try:
            register_model_with_capabilities(model, capabilities)
            save_diagnostic_data(
                model_name=model,
                questions=questions,
                answers=answers,
                is_baseline=True,
            )
        except Exception as e:
            print(f"Error saving diagnostic data: {e}", file=sys.stderr)
            return [types.TextContent(type="text", text=f"❌ Error saving baseline for {model}: {str(e)}")]

        return [types.TextContent(
            type="text",
            text=f"✅ Baseline stored for model: {model} ({len(questions)} questions)",
        )]

    except Exception as e:
        print(f"Error in run_initial_diagnostics: {e}", file=sys.stderr)
        return [types.TextContent(type="text", text=f"❌ Error running diagnostics for {model}: {str(e)}")]


async def check_drift(arguments: Dict[str, Any]) -> List[types.TextContent]:
    """Re-ask the baseline questions and report a drift score against the baseline.

    Args:
        arguments: Tool arguments; must contain ``"model"``.

    Returns:
        The drift score line and an alert line; a third warning line is added
        when the new results could not be persisted. On failure, a single
        error message.

    Raises:
        KeyError: If ``"model"`` is missing from ``arguments`` (the lookup
            happens before the error-handling block).
    """
    model = arguments["model"]

    try:
        # Get baseline from database
        baseline = get_baseline_diagnostics(model)

        # Ensure baseline exists
        if not baseline:
            return [types.TextContent(type="text", text=f"❌ No baseline for model: {model}")]

        old_answers = baseline["answers"]
        questions = baseline["questions"]

        # Ask the model the same questions again
        new_answers = await _ask_questions(questions, what="new answer")

        # Grade the answers and get a drift score (returns string)
        drift_score_str = gradeanswers(old_answers, new_answers)

        # Save the latest responses and drift score to database. This stays
        # best-effort, but the caller is told when persistence failed.
        save_error = None
        try:
            save_diagnostic_data(
                model_name=model,
                questions=questions,
                answers=new_answers,
                is_baseline=False,
            )
            save_drift_score(model, drift_score_str)
        except Exception as e:
            print(f"Error saving drift data: {e}", file=sys.stderr)
            save_error = e

        # Alert threshold. TypeError covers a non-string score such as None,
        # which would otherwise be reported as a total failure.
        try:
            score_val = float(drift_score_str)
        except (TypeError, ValueError):
            alert = "⚠️ Drift score not numeric"
        else:
            alert = "🚨 Significant drift!" if score_val > 50 else "✅ Drift OK"

        results = [
            types.TextContent(type="text", text=f"Drift score for {model}: {drift_score_str}%"),
            types.TextContent(type="text", text=alert),
        ]
        if save_error is not None:
            results.append(types.TextContent(
                type="text",
                text=f"⚠️ Drift results could not be saved: {str(save_error)}",
            ))
        return results

    except Exception as e:
        print(f"Error in check_drift: {e}", file=sys.stderr)
        return [types.TextContent(type="text", text=f"❌ Error checking drift for {model}: {str(e)}")]


# Database tool handlers
async def get_all_models_handler_async(_: Dict[str, Any]) -> List[types.TextContent]:
    """List every registered model as a bulleted text block.

    Args:
        _: Tool arguments; ignored.

    Returns:
        A single-item list holding the bulleted listing, a "No models
        registered." notice, or an error message if retrieval fails.
    """
    try:
        models = get_all_models_handler({})
        if not models:
            return [types.TextContent(type="text", text="No models registered.")]

        # `or` also covers a description that is present but None/empty.
        model_list = "\n".join(
            f"• {m['name']} - {m.get('description') or 'No description'}" for m in models
        )
        return [types.TextContent(
            type="text",
            text=f"Registered models:\n{model_list}"
        )]
    except Exception as e:
        # stdout is reserved for the stdio protocol stream; log to stderr.
        print(f"Error getting all models: {e}", file=sys.stderr)
        return [types.TextContent(type="text", text=f"❌ Error retrieving models: {str(e)}")]


async def search_models_handler_async(arguments: Dict[str, Any]) -> List[types.TextContent]:
    """Search registered models and return the matches as a bulleted text block.

    Args:
        arguments: Tool arguments; ``"query"`` is read and defaults to ``""``.
            It is passed to ``search_models_handler`` under the key
            ``"search_term"``.

    Returns:
        A single-item list holding the bulleted matches, a "No models found"
        notice, or an error message if the search fails.
    """
    try:
        query = arguments.get("query", "")
        models = search_models_handler({"search_term": query})

        if not models:
            return [types.TextContent(
                type="text",
                text=f"No models found matching '{query}'."
            )]

        # `or` also covers a description that is present but None/empty.
        model_list = "\n".join(
            f"• {m['name']} - {m.get('description') or 'No description'}" for m in models
        )
        return [types.TextContent(
            type="text",
            text=f"Models matching '{query}':\n{model_list}"
        )]
    except Exception as e:
        # stdout is reserved for the stdio protocol stream; log to stderr.
        print(f"Error searching models: {e}", file=sys.stderr)
        return [types.TextContent(type="text", text=f"❌ Error searching models: {str(e)}")]


# === Dispatcher ===
@app.call_tool()
async def dispatch_tool(name: str, arguments: Dict[str, Any] | None = None):
    """Route a tool call to its handler and turn any failure into an error message.

    Args:
        name: Tool name as advertised by ``list_tools``.
        arguments: Tool arguments; ``None`` is treated as an empty dict for
            every handler.

    Returns:
        The handler's list of ``types.TextContent``, or a single-item list with
        an error message for an unknown tool, a missing required argument, or
        any exception raised by the handler.
    """
    handlers = {
        "run_initial_diagnostics": run_initial_diagnostics,
        "check_drift": check_drift,
        "get_all_models": get_all_models_handler_async,
        "search_models": search_models_handler_async,
    }

    try:
        handler = handlers.get(name)
        if handler is None:
            return [types.TextContent(type="text", text=f"❌ Unknown tool: {name}")]
        # Normalise None so handlers that index into the arguments fail with a
        # KeyError naming the missing key instead of a TypeError on None.
        return await handler(arguments or {})
    except KeyError as e:
        # stdout is reserved for the stdio protocol stream; log to stderr.
        print(f"Error in dispatch_tool for {name}: missing argument {e}", file=sys.stderr)
        return [types.TextContent(type="text", text=f"❌ Error executing {name}: missing required argument {e}")]
    except Exception as e:
        print(f"Error in dispatch_tool for {name}: {e}", file=sys.stderr)
        return [types.TextContent(type="text", text=f"❌ Error executing {name}: {str(e)}")]


# === Entrypoint ===
async def main():
    """Serve ``app`` over stdio until the transport closes.

    Any exception raised while starting or running the server is reported on
    stderr and swallowed, so this coroutine itself does not raise.
    """
    try:
        async with stdio_server() as (reader, writer):
            await app.run(reader, writer, app.create_initialization_options())
    except Exception as e:
        # stdout is the protocol channel for a stdio server; keep it clean.
        print(f"Error running MCP server: {e}", file=sys.stderr)


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())