File size: 1,383 Bytes
8293a2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import logging
import json
from typing import Dict
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from aworld.cmd import AgentModel, ChatCompletionRequest
from aworld.cmd.utils import agent_loader, agent_executor
from aworld.cmd.web.web_server import get_user_id_from_jwt
import aworld.trace as trace

# URL prefix under which this router's endpoints are meant to be mounted.
prefix = "/api/agent"

router = APIRouter()  # collects the agent endpoints defined below

logger = logging.getLogger(name=__name__)


@router.get("/list")
@router.get("/models")
async def list_agents() -> Dict[str, AgentModel]:
    """Return the agents reported by ``agent_loader.list_agents()``.

    The same handler is registered on both ``/list`` and ``/models``. The
    loader's result is returned unchanged.
    """
    available_agents = agent_loader.list_agents()
    return available_agents


@router.post("/chat/completions")
async def chat_completion(
    form_data: ChatCompletionRequest, user_id: str = Depends(get_user_id_from_jwt)
) -> StreamingResponse:
    """Run the requested agent and stream its output as Server-Sent Events.

    Each chunk yielded by ``agent_executor.stream_run`` is sent to the client
    as one ``data: <json>\\n\\n`` event. All of this runs inside a
    ``/chat/chat_completion`` trace span tagged with the requested model.

    Args:
        form_data: The chat completion request. It is mutated in place: its
            ``user_id`` is set from the JWT and its ``trace_id`` is set from
            the trace span.
        user_id: Id of the caller, resolved by ``get_user_id_from_jwt``.

    Returns:
        A ``text/event-stream`` response whose body is produced lazily.
    """
    # Overwrite any user_id the request body may have carried with the JWT identity.
    form_data.user_id = user_id

    async def generate_stream():
        async with trace.span(
            "/chat/chat_completion", attributes={"model": form_data.model}
        ) as span:
            # Record the span's trace id on the request before the executor receives it.
            form_data.trace_id = span.get_trace_id()
            try:
                async for chunk in agent_executor.stream_run(form_data):
                    # mode="json" converts datetimes, UUIDs, Decimals, sets, bytes, etc.
                    # into JSON-compatible values. A plain model_dump() can leave objects
                    # that json.dumps rejects with TypeError in the middle of the stream.
                    payload = json.dumps(chunk.model_dump(mode="json"), ensure_ascii=False)
                    yield f"data: {payload}\n\n"
            except Exception:
                # The streaming response has already started when this generator runs,
                # so the client only sees a truncated stream. Log with request context,
                # then re-raise so the span and server still observe the failure.
                # (Client disconnects raise CancelledError, a BaseException, not caught here.)
                logger.exception(
                    "Agent stream failed for model=%s trace_id=%s",
                    form_data.model,
                    form_data.trace_id,
                )
                raise

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )