Spaces:
Runtime error
Runtime error
tech-envision
commited on
Commit
·
49e2d72
1
Parent(s):
8114c3f
Return tool calls when present
Browse files- src/chat.py +35 -20
src/chat.py
CHANGED
|
@@ -146,12 +146,27 @@ class ChatSession:
|
|
| 146 |
messages.append({"role": "tool", "content": msg.content})
|
| 147 |
return messages
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
@staticmethod
|
| 150 |
def _store_assistant_message(conversation: Conversation, message: Message) -> None:
|
| 151 |
"""Persist assistant messages, storing tool calls when present."""
|
| 152 |
|
| 153 |
if message.tool_calls:
|
| 154 |
-
content =
|
| 155 |
else:
|
| 156 |
content = message.content or ""
|
| 157 |
|
|
@@ -304,7 +319,7 @@ class ChatSession:
|
|
| 304 |
final_resp = await self._handle_tool_calls(
|
| 305 |
self._messages, response, self._conversation
|
| 306 |
)
|
| 307 |
-
return final_resp.message
|
| 308 |
|
| 309 |
async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
|
| 310 |
async with self._lock:
|
|
@@ -329,10 +344,9 @@ class ChatSession:
|
|
| 329 |
async for resp in self._handle_tool_calls_stream(
|
| 330 |
self._messages, response, self._conversation
|
| 331 |
):
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
yield resp.message.content
|
| 336 |
|
| 337 |
async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
|
| 338 |
DBMessage.create(conversation=self._conversation, role="user", content=prompt)
|
|
@@ -365,23 +379,24 @@ class ChatSession:
|
|
| 365 |
nxt = await self.ask(self._messages, think=True)
|
| 366 |
self._store_assistant_message(self._conversation, nxt.message)
|
| 367 |
self._messages.append(nxt.message.model_dump())
|
| 368 |
-
|
| 369 |
-
|
|
|
|
| 370 |
async for part in self._handle_tool_calls_stream(
|
| 371 |
self._messages, nxt, self._conversation
|
| 372 |
):
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
yield part.message.content
|
| 377 |
else:
|
| 378 |
resp = await user_task
|
| 379 |
self._store_assistant_message(self._conversation, resp.message)
|
| 380 |
self._messages.append(resp.message.model_dump())
|
| 381 |
async with self._lock:
|
| 382 |
self._state = "awaiting_tool"
|
| 383 |
-
|
| 384 |
-
|
|
|
|
| 385 |
result = await exec_task
|
| 386 |
self._tool_task = None
|
| 387 |
self._messages.append(
|
|
@@ -395,12 +410,12 @@ class ChatSession:
|
|
| 395 |
nxt = await self.ask(self._messages, think=True)
|
| 396 |
self._store_assistant_message(self._conversation, nxt.message)
|
| 397 |
self._messages.append(nxt.message.model_dump())
|
| 398 |
-
|
| 399 |
-
|
|
|
|
| 400 |
async for part in self._handle_tool_calls_stream(
|
| 401 |
self._messages, nxt, self._conversation
|
| 402 |
):
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
yield part.message.content
|
|
|
|
| 146 |
messages.append({"role": "tool", "content": msg.content})
|
| 147 |
return messages
|
| 148 |
|
| 149 |
+
# ------------------------------------------------------------------
|
| 150 |
+
@staticmethod
|
| 151 |
+
def _serialize_tool_calls(calls: List[Message.ToolCall]) -> str:
|
| 152 |
+
"""Convert tool calls to a JSON string for storage or output."""
|
| 153 |
+
|
| 154 |
+
return json.dumps([c.model_dump() for c in calls])
|
| 155 |
+
|
| 156 |
+
@staticmethod
|
| 157 |
+
def _format_output(message: Message) -> str:
|
| 158 |
+
"""Return tool calls as JSON or message content if present."""
|
| 159 |
+
|
| 160 |
+
if message.tool_calls:
|
| 161 |
+
return ChatSession._serialize_tool_calls(message.tool_calls)
|
| 162 |
+
return message.content or ""
|
| 163 |
+
|
| 164 |
@staticmethod
|
| 165 |
def _store_assistant_message(conversation: Conversation, message: Message) -> None:
|
| 166 |
"""Persist assistant messages, storing tool calls when present."""
|
| 167 |
|
| 168 |
if message.tool_calls:
|
| 169 |
+
content = ChatSession._serialize_tool_calls(message.tool_calls)
|
| 170 |
else:
|
| 171 |
content = message.content or ""
|
| 172 |
|
|
|
|
| 319 |
final_resp = await self._handle_tool_calls(
|
| 320 |
self._messages, response, self._conversation
|
| 321 |
)
|
| 322 |
+
return self._format_output(final_resp.message)
|
| 323 |
|
| 324 |
async def chat_stream(self, prompt: str) -> AsyncIterator[str]:
|
| 325 |
async with self._lock:
|
|
|
|
| 344 |
async for resp in self._handle_tool_calls_stream(
|
| 345 |
self._messages, response, self._conversation
|
| 346 |
):
|
| 347 |
+
text = self._format_output(resp.message)
|
| 348 |
+
if text:
|
| 349 |
+
yield text
|
|
|
|
| 350 |
|
| 351 |
async def _chat_during_tool(self, prompt: str) -> AsyncIterator[str]:
|
| 352 |
DBMessage.create(conversation=self._conversation, role="user", content=prompt)
|
|
|
|
| 379 |
nxt = await self.ask(self._messages, think=True)
|
| 380 |
self._store_assistant_message(self._conversation, nxt.message)
|
| 381 |
self._messages.append(nxt.message.model_dump())
|
| 382 |
+
text = self._format_output(nxt.message)
|
| 383 |
+
if text:
|
| 384 |
+
yield text
|
| 385 |
async for part in self._handle_tool_calls_stream(
|
| 386 |
self._messages, nxt, self._conversation
|
| 387 |
):
|
| 388 |
+
text = self._format_output(part.message)
|
| 389 |
+
if text:
|
| 390 |
+
yield text
|
|
|
|
| 391 |
else:
|
| 392 |
resp = await user_task
|
| 393 |
self._store_assistant_message(self._conversation, resp.message)
|
| 394 |
self._messages.append(resp.message.model_dump())
|
| 395 |
async with self._lock:
|
| 396 |
self._state = "awaiting_tool"
|
| 397 |
+
text = self._format_output(resp.message)
|
| 398 |
+
if text:
|
| 399 |
+
yield text
|
| 400 |
result = await exec_task
|
| 401 |
self._tool_task = None
|
| 402 |
self._messages.append(
|
|
|
|
| 410 |
nxt = await self.ask(self._messages, think=True)
|
| 411 |
self._store_assistant_message(self._conversation, nxt.message)
|
| 412 |
self._messages.append(nxt.message.model_dump())
|
| 413 |
+
text = self._format_output(nxt.message)
|
| 414 |
+
if text:
|
| 415 |
+
yield text
|
| 416 |
async for part in self._handle_tool_calls_stream(
|
| 417 |
self._messages, nxt, self._conversation
|
| 418 |
):
|
| 419 |
+
text = self._format_output(part.message)
|
| 420 |
+
if text:
|
| 421 |
+
yield text
|
|
|