# v2 of public chat — brainsqueeze (commit ef088c2, verified)
# NOTE: header reconstructed from Hugging Face file-viewer chrome
# ("raw / history / blame / 2.26 kB") that was captured with the source.
from typing import TypedDict, Literal, Any
from collections.abc import Iterator
from dataclasses import asdict
import logging
import json
from langchain_core.messages.tool import ToolMessage
from gradio import ChatMessage
# Configure root-handler formatting once at import time and expose a
# module-level logger (INFO and above) for this file.
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class ToolInput(TypedDict):
    """Schema of a single tool invocation requested by the agent graph."""

    # Name of the tool to invoke.
    name: str
    # Keyword arguments passed to the tool.
    args: dict[str, Any]
    # Unique identifier of this tool call.
    id: str
    # Discriminator tag; always the literal string "tool_call".
    type: Literal["tool_call"]
class CalledTool(TypedDict):
    """Streamed chunk emitted when the graph's tool node is invoked.

    Consumed by :func:`format_tool_call`.
    """

    # Chunk identifier; format_tool_call reuses it as both "id" and
    # "parent_id" on the yielded messages.
    id: str
    # Emitting node name; always "tools".
    name: Literal["tools"]
    # The individual tool calls requested in this chunk.
    input: list[ToolInput]
    # NOTE(review): presumably the event names that triggered the tool
    # node — not used in this file; confirm against the graph runtime.
    triggers: tuple[str, ...]
class ToolResult(TypedDict):
    """Streamed chunk emitted when tool execution completes.

    Consumed by :func:`format_tool_response`.
    """

    # Chunk identifier; used as "parent_id" on the yielded messages.
    id: str
    # Emitting node name; always "tools".
    name: Literal["tools"]
    # NOTE(review): error flag semantics not visible here — not read in
    # this file; confirm producer contract.
    error: bool | None
    # Pairs of (label, tool messages); only the messages are consumed
    # by format_tool_response, the first element is ignored.
    result: list[tuple[str, list[ToolMessage]]]
    # NOTE(review): element type unspecified in SOURCE — not used here.
    interrupts: list
def convert_history_for_graph_agent(history: list[dict | ChatMessage]) -> list[dict]:
    """Normalize a Gradio chat history into plain dicts for the graph agent.

    ``ChatMessage`` dataclasses are converted with :func:`dataclasses.asdict`;
    entries whose ``content`` is falsy (missing, ``None``, or ``""``) are
    dropped.

    Args:
        history: Mixed list of message dicts and Gradio ``ChatMessage`` objects.

    Returns:
        A new list of message dicts, preserving input order.
    """
    # Removed a stale commented-out branch that skipped tool-call messages
    # (dead code); behavior is unchanged from the original.
    as_dicts = (asdict(m) if isinstance(m, ChatMessage) else m for m in history)
    return [m for m in as_dicts if m.get("content")]
def format_tool_call(input_chunk: CalledTool) -> Iterator[ChatMessage]:
    """Yield one assistant ``ChatMessage`` per tool invocation in *input_chunk*.

    Each message carries the JSON-encoded tool arguments as content; the
    chunk id is used for both ``id`` and ``parent_id`` in the metadata.
    """
    chunk_id = input_chunk["id"]
    for tool_call in input_chunk["input"]:
        meta = {
            "title": f"Using tool `{tool_call.get('name')}`",
            "status": "done",
            "id": chunk_id,
            "parent_id": chunk_id,
        }
        yield ChatMessage(
            role="assistant",
            content=json.dumps(tool_call["args"]),
            metadata=meta,
        )
def format_tool_response(result_chunk: ToolResult) -> Iterator[ChatMessage]:
    """Yield an assistant ``ChatMessage`` for every tool message in *result_chunk*.

    Logs each called tool at INFO level; the chunk id becomes ``parent_id``
    in each message's metadata.
    """
    parent = result_chunk["id"]
    for _, tool_messages in result_chunk["result"]:
        for msg in tool_messages:
            logger.info("Called tool `%s`", msg.name)
            meta = {
                "title": f"Results from tool `{msg.name}`",
                "tool_name": msg.name,
                "documents": msg.artifact,
                "status": "done",
                "parent_id": parent,
            }
            yield ChatMessage(
                role="assistant",
                content=msg.content,
                metadata=meta,  # pyright: ignore[reportArgumentType]
            )