File size: 6,067 Bytes
10eeb8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd4f6c
10eeb8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# shopping_agent.py

from agent_src.tools import search_tool, query_url_tool, run_code_tool
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langgraph.graph import StateGraph, END, START
from langchain_core.messages import ToolMessage, SystemMessage, HumanMessage, BaseMessage, AIMessage
from typing import Literal, List, Dict, Any, TypedDict, Annotated, Generator
from langchain_core.tools import tool
import operator
import yaml

# --- State Definition ---
class ShoppingState(TypedDict):
    """Shared graph state threaded through every LangGraph node."""
    # Conversation history. The operator.add annotation tells LangGraph to
    # APPEND each node's returned messages rather than overwrite the list.
    messages: Annotated[List[BaseMessage], operator.add]
    # Items collected so far; each entry is the raw tool-call args dict
    # (name/url/details) appended by _tool_node for add_to_list_tool calls.
    # No reducer annotation, so nodes replace this list wholesale.
    shopping_list: List[Dict[str, Any]]

# --- Custom Tools ---
# NOTE: the docstring below doubles as the tool description shown to the LLM
# (langchain's @tool derives the schema from it) — it is part of the runtime
# contract, not mere documentation, so do not edit it casually.
@tool
def add_to_list_tool(name: str, url: str, details: str) -> str:
    """
    Adds a specified item to the user's shopping list.

    Args:
        name: The name of the product.
        url: The URL of the product page.
        details: Any important details about the product (e.g., price, color, size).
    """
    # This tool's primary purpose is to signal the intent to add to the list.
    # The _tool_node will handle the actual state modification.
    return f"✅ Item '{name}' is ready to be added to your shopping list."

# --- ShoppingAgent Class ---
class ShoppingAgent:
    """Tool-calling shopping assistant built on a two-node LangGraph loop.

    The graph alternates between an LLM node (which plans and may emit tool
    calls) and a tool node (which executes them) until the LLM replies with
    no tool calls. Conversation history and the shopping list live in
    ``self.state`` and are kept in sync as ``run_agent`` streams.
    """

    def __init__(self, model: str = "qwen/qwen3-235b-a22b"):
        """Build the LLM, tool registry, compiled graph, and empty state.

        Args:
            model: NVIDIA NIM model identifier passed to ChatNVIDIA.
        """
        self.llm = ChatNVIDIA(model=model, max_tokens=8192)

        # Environment tools touch the outside world; state tools only signal
        # intent — _tool_node performs the actual state mutation for them.
        self.environment_tools = [search_tool, query_url_tool, run_code_tool]
        self.state_tools = [add_to_list_tool]
        self.tools = self.environment_tools + self.state_tools
        self.tools_by_name = {t.name: t for t in self.tools}

        self.llm_with_tools = self.llm.bind_tools(self.tools, parallel_tool_calls=True)
        self.agent = self._build_agent()
        self.state: ShoppingState = {"messages": [], "shopping_list": []}

    def _build_agent(self) -> StateGraph:
        """Compile the START -> llm_call <-> tool_node loop."""
        builder = StateGraph(ShoppingState)
        builder.add_node("llm_call", self._llm_call)
        builder.add_node("tool_node", self._tool_node)
        builder.add_edge(START, "llm_call")
        builder.add_conditional_edges(
            "llm_call", self._should_continue,
            {"Action": "tool_node", END: END}
        )
        builder.add_edge("tool_node", "llm_call")
        return builder.compile()

    def _load_system_prompt(self) -> SystemMessage:
        """Load the configured system prompt, falling back to a default.

        The config import is best-effort: if ``src.shopping_agent`` (or the
        ``system_prompt`` key inside its config) is unavailable, a generic
        prompt is used instead of failing the whole turn.
        """
        try:
            from src.shopping_agent import SHOPPING_AGENT_CONFIG
            prompt = SHOPPING_AGENT_CONFIG.get("system_prompt", "You are a helpful assistant.")
        except Exception:
            # Deliberate broad catch: any config problem degrades to the
            # built-in prompt rather than crashing the agent.
            prompt = "You are a helpful shopping assistant. Be friendly and efficient."
        return SystemMessage(content=prompt)

    def _llm_call(self, state: ShoppingState) -> Dict[str, List[BaseMessage]]:
        """Graph node: invoke the tool-bound LLM on system prompt + history."""
        messages = [self._load_system_prompt()] + state["messages"]
        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}

    def _should_continue(self, state: ShoppingState) -> Literal["Action", END]:
        """Route to the tool node iff the last message requests tool calls."""
        last_msg = state["messages"][-1]
        return "Action" if getattr(last_msg, "tool_calls", None) else END

    def _tool_node(self, state: ShoppingState) -> Dict[str, Any]:
        """Graph node: execute every tool call from the last AI message.

        Pure with respect to ``self.state``: returns the new messages and the
        (possibly extended) shopping list as a state update instead of
        mutating anything in place.
        """
        tool_results: List[BaseMessage] = []
        # Copy so the incoming state object is never mutated.
        updated_shopping_list = state.get("shopping_list", []).copy()

        last_msg = state["messages"][-1]
        if not isinstance(last_msg, AIMessage) or not last_msg.tool_calls:
            return {}

        for call in last_msg.tool_calls:
            tool_func = self.tools_by_name.get(call["name"])
            if not tool_func:
                obs = f"⚠️ Error: Tool '{call['name']}' not found."
            else:
                try:
                    # add_to_list_tool is state-modifying: record its args
                    # here instead of (only) invoking it.
                    if call["name"] == "add_to_list_tool":
                        updated_shopping_list.append(call["args"])
                        obs = f"✅ Added '{call['args'].get('name')}' to your shopping list."
                    else:
                        obs = tool_func.invoke(call["args"])
                except Exception as e:
                    # Report tool failures back to the LLM so it can recover.
                    obs = f"⚠️ Error running {call['name']}: {e}"

            tool_results.append(ToolMessage(content=str(obs), tool_call_id=call["id"]))

        return {
            "messages": tool_results,
            "shopping_list": updated_shopping_list,
        }

    def run_agent(self, user_input: str) -> Generator[str, None, None]:
        """Run the graph on *user_input*, yielding 'thought' strings for the UI.

        BUGFIX: streams with ``stream_mode="updates"`` (yields
        ``{node_name: state_delta}`` per step). The previous
        ``stream_mode="values"`` emitted full state snapshots keyed by state
        fields, so the ``"llm_call" in step_output`` / ``"tool_node" in
        step_output`` checks never fired and no updates were ever yielded;
        the tool branch also read ``step_output["llm_call"]`` which does not
        exist alongside a ``tool_node`` key.

        ``self.state`` is updated incrementally by mirroring each node's
        delta with the same semantics as the graph's reducers (append for
        ``messages``, replace for ``shopping_list``).
        """
        self.state["messages"].append(HumanMessage(content=user_input))

        for step_output in self.agent.stream(self.state, stream_mode="updates"):
            for node_name, update in step_output.items():
                update = update or {}

                # Mirror the graph reducers onto our persistent state.
                self.state["messages"] = self.state["messages"] + update.get("messages", [])
                if "shopping_list" in update:
                    self.state["shopping_list"] = update["shopping_list"]

                if node_name == "llm_call":
                    yield "🧠 **Planning:** Deciding next steps..."
                    last_ai = update["messages"][-1]
                    if isinstance(last_ai, AIMessage) and last_ai.tool_calls:
                        for tool_call in last_ai.tool_calls:
                            yield f"🛠️ **Tool Call:** `{tool_call['name']}`\n" \
                                  f"   - **Arguments:** `{tool_call['args']}`"

                elif node_name == "tool_node":
                    # Yield the results from the tool messages.
                    for tool_message in update.get("messages", []):
                        yield f"✔️ **Tool Result:**\n   - `{tool_message.content}`"