Spaces:
Running
Running
File size: 2,262 Bytes
ef088c2 a0e37e2 ef088c2 a0e37e2 ef088c2 a0e37e2 ef088c2 a0e37e2 ef088c2 a0e37e2 c751e97 ef088c2 c751e97 a0e37e2 ef088c2 a0e37e2 ef088c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
from typing import TypedDict, Literal, Any
from collections.abc import Iterator
from dataclasses import asdict
import logging
import json
from langchain_core.messages.tool import ToolMessage
from gradio import ChatMessage
# Module logger pinned to INFO, so this module's INFO records (such as the
# "Called tool" entries logged by format_tool_response) pass the logger-level
# check instead of falling back to the root logger's default WARNING level.
logger = logging.getLogger(__name__)
logger.setLevel("INFO")

# Give the root handler a "[LEVEL] (timestamp) :: message" layout.
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
class ToolInput(TypedDict):
    """A single tool call listed in a ``CalledTool`` payload's ``input``.

    The key set (name / args / id / type) appears to mirror LangChain's
    ``ToolCall`` dict; NOTE(review): verify against whatever produces it.
    """

    name: str  # tool name; format_tool_call puts it in the message title
    args: dict[str, Any]  # call arguments; format_tool_call JSON-encodes them
    id: str
    type: Literal["tool_call"]
class CalledTool(TypedDict):
    """Payload for a task of the graph node named ``"tools"``, before its results.

    Consumed by ``format_tool_call``. The field set (id / name / input /
    triggers) presumably matches a LangGraph debug-stream ``task`` event
    payload; NOTE(review): confirm against the streaming code.
    """

    id: str  # format_tool_call uses it as both metadata "id" and "parent_id"
    name: Literal["tools"]  # only the node named "tools" is modeled here
    input: list[ToolInput]  # format_tool_call yields one message per entry
    triggers: tuple[str, ...]
class ToolResult(TypedDict):
    """Payload carrying the outputs of a ``"tools"`` node task.

    Consumed by ``format_tool_response``. The field set presumably matches a
    LangGraph debug-stream ``task_result`` event payload; NOTE(review):
    confirm against the streaming code, including the real type of ``error``.
    """

    id: str  # format_tool_response uses it as every message's "parent_id"
    name: Literal["tools"]
    error: bool | None  # not read by format_tool_response
    # Pairs whose first element format_tool_response ignores; the second is a
    # list of ToolMessage objects, each turned into one chat message.
    result: list[tuple[str, list[ToolMessage]]]
    interrupts: list[Any]  # not read by format_tool_response
def convert_history_for_graph_agent(
    history: list[dict | ChatMessage],
    *,
    skip_tool_calls: bool = False,
) -> list[dict]:
    """Normalize Gradio chat history into plain message dicts for the graph agent.

    ``ChatMessage`` entries are converted with ``dataclasses.asdict``; ``dict``
    entries are passed through as-is (not copied). Entries whose ``content``
    is falsy (missing, ``None``, empty) are dropped.

    Args:
        history: Chat history mixing ``dict`` and ``ChatMessage`` entries.
        skip_tool_calls: When true, also drop entries with truthy
            ``metadata``, such as the messages built by ``format_tool_call``
            and ``format_tool_response``. Defaults to ``False``, which keeps
            them (the original behavior).

    Returns:
        A new list of message dicts in the original order.
    """
    converted: list[dict] = []
    for message in history:
        if isinstance(message, ChatMessage):
            message = asdict(message)
        if not message.get("content"):
            continue
        # Opt-in replacement for the previously commented-out filter.
        if skip_tool_calls and message.get("metadata"):
            continue
        converted.append(message)
    return converted
def format_tool_call(input_chunk: CalledTool) -> Iterator[ChatMessage]:
    """Yield one assistant message per tool call listed in ``input_chunk``.

    Args:
        input_chunk: Task payload whose ``input`` lists the tool calls.

    Yields:
        A ``ChatMessage`` whose content is the call's ``args`` as JSON and
        whose metadata carries a "Using tool" title, status ``"done"``, and
        the chunk ``id``. ``format_tool_response`` uses its own chunk ``id``
        as ``parent_id``, presumably so results nest under this message;
        NOTE(review): confirm the caller passes matching ids.
    """
    task_id = input_chunk["id"]
    for graph_input in input_chunk["input"]:
        yield ChatMessage(
            role="assistant",
            # ensure_ascii=False keeps non-ASCII argument values readable in
            # the chat UI (and in history sent back to the agent) instead of
            # rendering them as \uXXXX escapes.
            content=json.dumps(graph_input["args"], ensure_ascii=False),
            metadata={
                "title": f"Using tool `{graph_input.get('name')}`",
                "status": "done",
                "id": task_id,
                # NOTE(review): this message names itself as its own parent,
                # and every call in the same chunk shares the same id. Confirm
                # how the Gradio chatbot resolves that before changing it.
                "parent_id": task_id,
            },
        )
def format_tool_response(result_chunk: ToolResult) -> Iterator[ChatMessage]:
    """Yield one assistant message per ``ToolMessage`` found in ``result_chunk``.

    Each ``result`` entry is a pair; its first element is ignored and its
    second is iterated as ``ToolMessage`` objects. Every message is logged at
    INFO, then surfaced with its ``content`` as the body, its ``artifact``
    under ``metadata["documents"]``, and the chunk ``id`` as ``parent_id``.
    """
    messages = (
        message
        for _channel, batch in result_chunk["result"]
        for message in batch
    )
    for message in messages:
        logger.info("Called tool `%s`", message.name)
        body = message.content
        # "tool_name" and "documents" are extra keys beyond Gradio's metadata
        # schema, hence the pyright suppression at the call site.
        details = {
            "title": f"Results from tool `{message.name}`",
            "tool_name": message.name,
            "documents": message.artifact,
            "status": "done",
            "parent_id": result_chunk["id"],
        }
        yield ChatMessage(
            role="assistant",
            content=body,
            metadata=details,  # pyright: ignore[reportArgumentType]
        )
|