File size: 2,262 Bytes
ef088c2
 
 
 
 
a0e37e2
ef088c2
 
a0e37e2
ef088c2
 
 
a0e37e2
 
ef088c2
 
 
 
 
a0e37e2
 
ef088c2
 
 
 
 
a0e37e2
c751e97
ef088c2
 
 
 
 
 
c751e97
a0e37e2
ef088c2
 
 
 
 
a0e37e2
ef088c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from typing import TypedDict, Literal, Any
from collections.abc import Iterator
from dataclasses import asdict
import logging
import json

from langchain_core.messages.tool import ToolMessage
from gradio import ChatMessage

_LOG_FORMAT = "[%(levelname)s] (%(asctime)s) :: %(message)s"

# Install a root handler using the format above. Per the stdlib docs this is a
# no-op when the root logger already has handlers.
logging.basicConfig(format=_LOG_FORMAT)

# Module-level logger, forced to INFO so the per-tool trace emitted by
# format_tool_response is visible.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class ToolInput(TypedDict):
    """One tool invocation listed in a ``tools`` task payload.

    Has the same keys as LangChain's ``ToolCall`` dict
    (``name`` / ``args`` / ``id`` / ``type``).
    """

    name: str  # name of the tool to invoke
    args: dict[str, Any]  # arguments passed to the tool; JSON-dumped by format_tool_call
    id: str  # identifier of this individual tool call
    type: Literal["tool_call"]  # discriminator, always "tool_call"


class CalledTool(TypedDict):
    """Payload describing a ``tools`` node task that is about to run.

    Presumably the payload of a LangGraph debug-stream ``task`` event;
    confirm against the caller that produces these chunks.
    """

    id: str  # task id; reused as the "id"/"parent_id" metadata by format_tool_call
    name: Literal["tools"]  # node name, always "tools"
    input: list[ToolInput]  # tool calls the task will execute, one ChatMessage each
    triggers: tuple[str, ...]  # not read in this module (presumably triggering channel names)


class ToolResult(TypedDict):
    """Payload describing the outcome of a ``tools`` node task.

    Presumably the payload of a LangGraph debug-stream ``task_result`` event;
    confirm against the caller that produces these chunks.
    """

    id: str  # task id; used as "parent_id" metadata by format_tool_response
    name: Literal["tools"]  # node name, always "tools"
    error: bool | None  # not read in this module
    # (key, messages) pairs; only the ToolMessage lists are consumed downstream
    result: list[tuple[str, list[ToolMessage]]]
    interrupts: list  # not read in this module


def convert_history_for_graph_agent(history: list[dict | ChatMessage]) -> list[dict]:
    """Normalize a Gradio chat history into plain dicts for the graph agent.

    ``ChatMessage`` dataclass instances are converted with
    :func:`dataclasses.asdict`. Entries that are already dicts are passed
    through as-is (not copied). Any entry whose ``"content"`` is missing or
    falsy (``None``, ``""``, an empty list, ...) is dropped.

    Messages that carry ``metadata`` (for example the tool-call and
    tool-result messages built by ``format_tool_call`` and
    ``format_tool_response``) are NOT filtered out. They are kept like any
    other message with non-empty content.

    Args:
        history: Chat messages, either as ``ChatMessage`` objects or as
            message dicts (presumably a Gradio ``Chatbot`` history in
            ``type="messages"`` mode; confirm against the caller).

    Returns:
        A new list with the retained messages as dicts, in original order.
    """
    as_dicts = (asdict(h) if isinstance(h, ChatMessage) else h for h in history)
    return [message for message in as_dicts if message.get("content")]


def format_tool_call(input_chunk: CalledTool) -> Iterator[ChatMessage]:
    """Render each tool invocation of a ``tools`` task as a chat message.

    One assistant ``ChatMessage`` is yielded per entry in
    ``input_chunk["input"]``. Its content is that call's ``args`` serialized
    as JSON, and its metadata title names the tool.

    Args:
        input_chunk: Task payload listing the tool calls about to execute.

    Yields:
        ChatMessage: one per tool call, in the order they appear in
        ``input_chunk["input"]``, with ``status`` set to ``"done"``.
    """
    for graph_input in input_chunk["input"]:
        yield ChatMessage(
            role="assistant",
            # ensure_ascii=False keeps non-ASCII argument values (accents, CJK, ...)
            # readable in the chat UI and in the history sent back to the agent,
            # instead of turning them into \uXXXX escape sequences.
            content=json.dumps(graph_input["args"], ensure_ascii=False),
            metadata={
                "title": f"Using tool `{graph_input.get('name')}`",
                "status": "done",
                # NOTE(review): every call in this task gets the same "id", and
                # "parent_id" equals the message's own "id". format_tool_response
                # sets "parent_id" to its chunk's task id, presumably the same id,
                # so tool results nest under these messages. Confirm the intended
                # Gradio nesting before changing either key.
                "id": input_chunk["id"],
                "parent_id": input_chunk["id"],
            },
        )


def format_tool_response(result_chunk: ToolResult) -> Iterator[ChatMessage]:
    """Turn every ToolMessage in a ``tools`` task result into a chat message.

    Each tool message is logged at INFO level. It is then yielded as an
    assistant ``ChatMessage`` whose metadata points (``parent_id``) at the
    task id of ``result_chunk``. Besides the title and status, the metadata
    carries two extra keys: ``tool_name`` and ``documents`` (the message's
    ``artifact``).

    Args:
        result_chunk: Result payload of the ``tools`` task.

    Yields:
        ChatMessage: one per ToolMessage, in the order they appear.
    """
    for _channel, tool_messages in result_chunk["result"]:
        for message in tool_messages:
            tool_name = message.name
            logger.info("Called tool `%s`", tool_name)
            metadata = {
                "title": f"Results from tool `{tool_name}`",
                "tool_name": tool_name,
                "documents": message.artifact,
                "status": "done",
                "parent_id": result_chunk["id"],
            }
            yield ChatMessage(
                role="assistant",
                content=message.content,
                # extra keys (tool_name, documents) are outside Gradio's metadata type
                metadata=metadata,  # pyright: ignore[reportArgumentType]
            )