File size: 4,533 Bytes
3a235a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from typing import AsyncGenerator
import traceback
import logging
import os
import json
import sys
import uuid
import time

logger = logging.getLogger(__name__)


class GaiaAgentServer:
    """OpenAI-compatible server facade around the Gaia agent runner.

    Model settings are read from ``LLM_*_GAIA`` environment variables;
    responses are streamed as OpenAI ``chat.completion.chunk`` objects.
    """

    def _get_model_config(self):
        """Read the LLM configuration from the environment.

        Returns:
            dict: keys ``provider``, ``model``, ``api_key``, ``base_url``
            (each ``None`` when unset) and ``temperature`` (float).

        Raises:
            Exception: re-raised after logging, e.g. when the temperature
                value is not parseable as a float.
        """
        try:
            llm_provider = os.getenv("LLM_PROVIDER_GAIA")
            llm_model_name = os.getenv("LLM_MODEL_NAME_GAIA")
            llm_api_key = os.getenv("LLM_API_KEY_GAIA")
            llm_base_url = os.getenv("LLM_BASE_URL_GAIA")
            # Environment values are strings; normalize to float so the
            # config carries a consistent type whether or not the var is set.
            llm_temperature = float(os.getenv("LLM_TEMPERATURE_GAIA", 0.0))
            return {
                "provider": llm_provider,
                "model": llm_model_name,
                "api_key": llm_api_key,
                "base_url": llm_base_url,
                "temperature": llm_temperature,
            }
        except Exception:
            logger.warning(
                ">>> Gaia Agent: GAIA_MODEL_CONFIG is not configured, using LLM"
            )
            # Bare raise preserves the original traceback.
            raise

    def models(self):
        """Return the single advertised model in OpenAI list-models shape."""
        model = self._get_model_config()

        return [
            {
                "id": f"{model['provider']}/{model['model']}",
                "name": f"gaia_agent@{model['provider']}/{model['model']}",
            }
        ]

    async def chat_completions(self, body: dict) -> AsyncGenerator[dict, None]:
        """Run the Gaia agent on the latest message and stream result chunks.

        Args:
            body: OpenAI-style request body with ``messages`` and ``model``.

        Yields:
            dict: ``chat.completion.chunk`` objects; on failure a single
            chunk carrying the formatted traceback is yielded instead of
            raising out of the generator.
        """

        def response_line(line: str, model: str) -> dict:
            # Wrap one piece of agent output as an OpenAI streaming chunk.
            return {
                "object": "chat.completion.chunk",
                "id": str(uuid.uuid4()).replace("-", ""),
                "choices": [
                    {"index": 0, "delta": {"content": line, "role": "assistant"}}
                ],
                "created": int(time.time()),
                "model": model,
            }

        # Pre-set so the error path can still build a chunk even when the
        # failure happens before the request body has been parsed.
        model = "unknown"
        try:
            logger.info(f">>> Gaia Agent: body={body}")

            prompt = body["messages"][-1]["content"]
            model = body["model"].replace("gaia_agent.", "")

            logger.info(f">>> Gaia Agent: prompt={prompt}, model={model}")

            selected_model = self._get_model_config()

            logger.info(f">>> Gaia Agent: Using model configuration: {selected_model}")

            logger.info(f">>> Gaia Agent Python Path: sys.path={sys.path}")

            llm_provider = selected_model.get("provider")
            llm_model_name = selected_model.get("model")
            llm_api_key = selected_model.get("api_key")
            llm_base_url = selected_model.get("base_url")
            llm_temperature = selected_model.get("temperature", 0.0)

            # Imported lazily so the server module loads even when the
            # runner's own dependencies are heavy.
            from examples.gaia.gaia_agent_runner import GaiaAgentRunner

            # mcp.json lives next to this file.
            mcp_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "mcp.json"
            )
            with open(mcp_path, "r") as f:
                mcp_config = json.load(f)

            runner = GaiaAgentRunner(
                llm_provider=llm_provider,
                llm_model_name=llm_model_name,
                llm_base_url=llm_base_url,
                llm_api_key=llm_api_key,
                llm_temperature=llm_temperature,
                mcp_config=mcp_config,
            )

            logger.info(f">>> Gaia Agent: prompt={prompt}, runner={runner}")

            async for i in runner.run(prompt):
                line = response_line(i, model)
                logger.info(f">>> Gaia Agent Line: {line}")
                yield line

        except Exception:
            # Surface the traceback to the client as a normal chunk rather
            # than breaking the stream.
            emsg = traceback.format_exc()
            logger.error(f">>> Gaia Agent Error: exception {emsg}")
            yield response_line(f"Gaia Agent Error: {emsg}", model)

        finally:
            logger.info(">>> Gaia Agent Done")


import fastapi
from fastapi.responses import StreamingResponse

app = fastapi.FastAPI()

# NOTE: GaiaAgentServer is defined above in this same module. Re-importing it
# via `from examples.gaia.gaia_agent_server import GaiaAgentServer` would load
# a second copy of this module (duplicate side effects) when the file is run
# as `gaia_agent_server:app`, so the locally defined class is used directly.
agent_server = GaiaAgentServer()


@app.get("/v1/models")
async def models():
    """Return the models advertised by the agent server (OpenAI list shape)."""
    available = agent_server.models()
    return available


@app.post("/v1/chat/completions")
async def chat_completions(request: fastapi.Request):
    """Forward a chat-completion request to the agent and stream SSE chunks."""
    form_data = await request.json()
    logger.info(f">>> Gaia Agent Server: form_data={form_data}")

    async def sse_stream():
        # SSE framing: each event is "data: <json>" followed by a blank line.
        async for piece in agent_server.chat_completions(form_data):
            yield f"data: {json.dumps(piece, ensure_ascii=False)}\n\n"

    return StreamingResponse(sse_stream(), media_type="text/event-stream")


if __name__ == "__main__":
    # Imported here so uvicorn is only required when running as a script.
    import uvicorn

    # Serve the FastAPI app on all interfaces, port 8888. The string target
    # assumes this file is importable as the module "gaia_agent_server".
    uvicorn.run("gaia_agent_server:app", host="0.0.0.0", port=8888)