Spaces:
Sleeping
Sleeping
import logging | |
import json | |
from typing import Dict | |
from fastapi import APIRouter, Depends | |
from fastapi.responses import StreamingResponse | |
from aworld.cmd import AgentModel, ChatCompletionRequest | |
from aworld.cmd.utils import agent_loader, agent_executor | |
from aworld.cmd.web.web_server import get_user_id_from_jwt | |
import aworld.trace as trace | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router object for the agent endpoints. No route decorators are visible in
# this view, so handlers are presumably registered elsewhere -- TODO confirm.
router = APIRouter()

# Path prefix string; presumably read by whatever code mounts `router`
# -- verify against the caller.
prefix = "/api/agent"
async def list_agents() -> Dict[str, AgentModel]:
    """Return the result of ``agent_loader.list_agents()`` unchanged.

    Returns:
        Whatever the loader reports; annotated as a mapping from string keys
        to ``AgentModel`` values.
    """
    available_agents = agent_loader.list_agents()
    return available_agents
async def chat_completion(
    form_data: ChatCompletionRequest,
    user_id: str = Depends(get_user_id_from_jwt),
) -> StreamingResponse:
    """Stream an agent run back to the client as ``text/event-stream``.

    Args:
        form_data: The chat completion request. It is mutated in place:
            ``user_id`` is overwritten with the injected value, and
            ``trace_id`` is set once streaming begins.
        user_id: Value injected through ``Depends(get_user_id_from_jwt)``.

    Returns:
        A ``StreamingResponse`` whose body emits one ``data: <json>`` event,
        terminated by a blank line, for every chunk produced by
        ``agent_executor.stream_run(form_data)``.
    """
    # The injected user id replaces whatever the request body carried.
    form_data.user_id = user_id

    async def sse_events():
        # The span is opened lazily, when the response body is first iterated,
        # and stays open for the whole duration of the stream.
        span_attributes = {"model": form_data.model}
        async with trace.span(
            "/chat/chat_completion", attributes=span_attributes
        ) as active_span:
            # Record the span's trace id on the request before running the agent.
            form_data.trace_id = active_span.get_trace_id()
            async for piece in agent_executor.stream_run(form_data):
                payload = json.dumps(piece.model_dump(), ensure_ascii=False)
                yield f"data: {payload}\n\n"

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
    }
    return StreamingResponse(
        sse_events(),
        media_type="text/event-stream",
        headers=sse_headers,
    )