"""Convert Gradio chat history and agent tool-call chunks into chat messages."""

import json
import logging
from collections.abc import Iterator
from dataclasses import asdict
from typing import Any, Literal, TypedDict

from gradio import ChatMessage
from langchain_core.messages.tool import ToolMessage

logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class ToolInput(TypedDict):
    """A single tool invocation requested by the agent."""

    name: str  # name of the tool to invoke
    args: dict[str, Any]  # arguments passed to the tool (rendered as JSON in the UI)
    id: str  # identifier of this individual tool call
    type: Literal["tool_call"]


class CalledTool(TypedDict):
    """Stream chunk listing the tool calls handed to the ``tools`` node.

    NOTE(review): shape inferred from how ``format_tool_call`` reads it;
    confirm against the producer of these chunks.
    """

    id: str  # chunk id; reused as the chat-message id for every call in ``input``
    name: Literal["tools"]
    input: list[ToolInput]  # one entry per requested tool call
    triggers: tuple[str, ...]


class ToolResult(TypedDict):
    """Stream chunk carrying the outputs produced by the ``tools`` node.

    NOTE(review): shape inferred from how ``format_tool_response`` reads it;
    confirm against the producer of these chunks.
    """

    id: str  # used as ``parent_id`` for the result messages
    name: Literal["tools"]
    error: bool | None
    # Each entry pairs a key (ignored here) with the ToolMessages it produced.
    result: list[tuple[str, list[ToolMessage]]]
    interrupts: list


def convert_history_for_graph_agent(history: list[dict | ChatMessage]) -> list[dict]:
    """Normalize Gradio chat history into plain dicts for the graph agent.

    ``ChatMessage`` instances are converted with :func:`dataclasses.asdict`.
    Dict entries are passed through as-is (not copied). Entries whose
    ``content`` is falsy (empty string, ``None``, or missing) are dropped.

    Args:
        history: Chat history as provided by Gradio, mixing dicts and
            ``ChatMessage`` objects.

    Returns:
        A new list containing only the messages that have content.
    """
    converted: list[dict] = []
    for message in history:
        entry = asdict(message) if isinstance(message, ChatMessage) else message
        # NOTE: messages carrying ``metadata`` (e.g. the tool-call entries built
        # below) are passed through as well; filter on ``entry.get("metadata")``
        # here if the agent should not see them.
        if entry.get("content"):
            converted.append(entry)
    return converted


def format_tool_call(input_chunk: CalledTool) -> Iterator[ChatMessage]:
    """Yield one assistant chat message per tool call in ``input_chunk``.

    Each message shows the call's arguments as JSON and is titled with the
    tool's name. Non-ASCII characters in the arguments are emitted verbatim
    instead of being escaped, so they stay readable in the chat UI.

    Args:
        input_chunk: Chunk describing the tool calls about to be executed.

    Yields:
        A ``ChatMessage`` per entry in ``input_chunk["input"]``.
    """
    chunk_id = input_chunk["id"]
    for graph_input in input_chunk["input"]:
        yield ChatMessage(
            role="assistant",
            content=json.dumps(graph_input["args"], ensure_ascii=False),
            metadata={
                "title": f"Using tool `{graph_input.get('name')}`",
                "status": "done",
                # NOTE(review): every call in the chunk gets the chunk id as both its
                # own id and its parent_id; result messages point back to a chunk id
                # via parent_id (see format_tool_response). Confirm this nesting is
                # intended, especially when a chunk contains several tool calls.
                "id": chunk_id,
                "parent_id": chunk_id,
            },
        )


def format_tool_response(result_chunk: ToolResult) -> Iterator[ChatMessage]:
    """Yield one assistant chat message per ToolMessage in ``result_chunk``.

    Each message carries the tool's content. Its metadata records the tool
    name and the ToolMessage ``artifact`` (stored under ``"documents"``), and
    is linked to ``result_chunk["id"]`` through ``parent_id``.

    Args:
        result_chunk: Chunk holding the outputs returned by the tools.

    Yields:
        A ``ChatMessage`` per ToolMessage across all entries of
        ``result_chunk["result"]``.
    """
    parent_id = result_chunk["id"]
    for _, outputs in result_chunk["result"]:
        for tool in outputs:
            logger.info("Called tool `%s`", tool.name)
            # NOTE(review): ToolMessage.content may be a list rather than a str;
            # it is passed to ChatMessage unchanged — confirm Gradio renders it.
            yield ChatMessage(
                role="assistant",
                content=tool.content,
                metadata={
                    "title": f"Results from tool `{tool.name}`",
                    "tool_name": tool.name,
                    "documents": tool.artifact,
                    "status": "done",
                    "parent_id": parent_id,
                }  # pyright: ignore[reportArgumentType]
            )