Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import asyncio | |
| from typing import AsyncIterator, Optional | |
| from .chat import ChatSession | |
| from .config import OLLAMA_HOST, MODEL_NAME, SYSTEM_PROMPT, JUNIOR_PROMPT | |
| from .tools import execute_terminal | |
| from .db import Message as DBMessage | |
| __all__ = [ | |
| "TeamChatSession", | |
| "send_to_junior", | |
| "send_to_junior_async", | |
| "set_team", | |
| ] | |
| _TEAM: Optional["TeamChatSession"] = None | |
| def set_team(team: "TeamChatSession" | None) -> None: | |
| global _TEAM | |
| _TEAM = team | |
| async def send_to_junior(message: str) -> str: | |
| """Forward ``message`` to the junior agent and await the response.""" | |
| if _TEAM is None: | |
| return "No active team" | |
| return await _TEAM.queue_message_to_junior(message, enqueue=False) | |
| # Backwards compatibility --------------------------------------------------- | |
| send_to_junior_async = send_to_junior | |
| class TeamChatSession: | |
| def __init__( | |
| self, | |
| user: str = "default", | |
| session: str = "default", | |
| host: str = OLLAMA_HOST, | |
| model: str = MODEL_NAME, | |
| ) -> None: | |
| self._to_junior: asyncio.Queue[tuple[str, asyncio.Future[str], bool]] = asyncio.Queue() | |
| self._to_senior: asyncio.Queue[str] = asyncio.Queue() | |
| self._junior_task: asyncio.Task | None = None | |
| self.senior = ChatSession( | |
| user=user, | |
| session=session, | |
| host=host, | |
| model=model, | |
| system_prompt=SYSTEM_PROMPT, | |
| tools=[execute_terminal, send_to_junior], | |
| ) | |
| self.junior = ChatSession( | |
| user=user, | |
| session=f"{session}-junior", | |
| host=host, | |
| model=model, | |
| system_prompt=JUNIOR_PROMPT, | |
| tools=[execute_terminal], | |
| ) | |
| async def __aenter__(self) -> "TeamChatSession": | |
| await self.senior.__aenter__() | |
| await self.junior.__aenter__() | |
| set_team(self) | |
| return self | |
| async def __aexit__(self, exc_type, exc, tb) -> None: | |
| set_team(None) | |
| await self.senior.__aexit__(exc_type, exc, tb) | |
| await self.junior.__aexit__(exc_type, exc, tb) | |
| def upload_document(self, file_path: str) -> str: | |
| return self.senior.upload_document(file_path) | |
| async def queue_message_to_junior( | |
| self, message: str, *, enqueue: bool = True | |
| ) -> str: | |
| """Send ``message`` to the junior agent and wait for the reply.""" | |
| loop = asyncio.get_running_loop() | |
| fut: asyncio.Future[str] = loop.create_future() | |
| await self._to_junior.put((message, fut, enqueue)) | |
| if not self._junior_task or self._junior_task.done(): | |
| self._junior_task = asyncio.create_task(self._process_junior()) | |
| return await fut | |
| async def _process_junior(self) -> None: | |
| try: | |
| while not self._to_junior.empty(): | |
| msg, fut, enqueue = await self._to_junior.get() | |
| self.junior._messages.append({"role": "tool", "name": "senior", "content": msg}) | |
| DBMessage.create(conversation=self.junior._conversation, role="tool", content=msg) | |
| parts: list[str] = [] | |
| async for part in self.junior.continue_stream(): | |
| if part: | |
| parts.append(part) | |
| result = "\n".join(parts) | |
| if enqueue and result.strip(): | |
| await self._to_senior.put(result) | |
| if not fut.done(): | |
| fut.set_result(result) | |
| if self.senior._state == "idle": | |
| await self._deliver_junior_messages() | |
| finally: | |
| self._junior_task = None | |
| async def _deliver_junior_messages(self) -> None: | |
| while not self._to_senior.empty(): | |
| msg = await self._to_senior.get() | |
| self.senior._messages.append({"role": "tool", "name": "junior", "content": msg}) | |
| DBMessage.create(conversation=self.senior._conversation, role="tool", content=msg) | |
| async def chat_stream(self, prompt: str) -> AsyncIterator[str]: | |
| await self._deliver_junior_messages() | |
| async for part in self.senior.chat_stream(prompt): | |
| yield part | |
| await self._deliver_junior_messages() | |