philincloud commited on
Commit
182160d
Β·
verified Β·
1 Parent(s): e29152c

Rename agent.py to langgraph_agent.py

Browse files
Files changed (2) hide show
  1. agent.py +0 -0
  2. langgraph_agent.py +115 -0
agent.py DELETED
File without changes
langgraph_agent.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langgraph.graph import START, StateGraph, MessagesState
3
+ from langgraph.prebuilt import tools_condition, ToolNode
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
6
+ from langchain_community.tools.tavily_search import TavilySearchResults
7
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
8
+ from langchain_core.messages import SystemMessage, HumanMessage
9
+ from langchain_core.tools import tool
10
+
11
+ #tools
12
+ @tool
13
+ def multiply(a: int, b: int) -> int:
14
+ return a * b
15
+
16
+ @tool
17
+ def add(a: int, b: int) -> int:
18
+ return a + b
19
+
20
+ @tool
21
+ def subtract(a: int, b: int) -> int:
22
+ return a - b
23
+
24
+ @tool
25
+ def divide(a: int, b: int) -> float:
26
+ if b == 0:
27
+ raise ValueError("Cannot divide by zero.")
28
+ return a / b
29
+
30
+ @tool
31
+ def modulus(a: int, b: int) -> int:
32
+ return a % b
33
+
34
+ @tool
35
+ def wiki_search(query: str) -> dict:
36
+ docs = WikipediaLoader(query=query, load_max_docs=2).load()
37
+ formatted = "\n\n---\n\n".join(
38
+ f'<Document source="{d.metadata["source"]}"/>\n{d.page_content}'
39
+ for d in docs
40
+ )
41
+ return {"wiki_results": formatted}
42
+
43
+ @tool
44
+ def web_search(query: str) -> dict:
45
+ docs = TavilySearchResults(max_results=3).invoke(query=query)
46
+ formatted = "\n\n---\n\n".join(
47
+ f'<Document source="{d.metadata["source"]}"/>\n{d.page_content}'
48
+ for d in docs
49
+ )
50
+ return {"web_results": formatted}
51
+
52
+ @tool
53
+ def arvix_search(query: str) -> dict:
54
+ docs = ArxivLoader(query=query, load_max_docs=3).load()
55
+ formatted = "\n\n---\n\n".join(
56
+ f'<Document source="{d.metadata["source"]}"/>\n{d.page_content[:1000]}'
57
+ for d in docs
58
+ )
59
+ return {"arvix_results": formatted}
60
+
61
+
62
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
63
+ HF_SPACE_TOKEN = os.getenv("HF_SPACE_TOKEN")
64
+
65
+ # ───────────────────────────────────────────────────────────────────────────────
66
+ # 4) Assemble tool list
67
+ tools = [
68
+ multiply, add, subtract, divide, modulus,
69
+ wiki_search, web_search, arvix_search,
70
+ ]
71
+
72
+ # ───────────────────────────────────────────────────────────────────────────────
73
+ # 5) Load your system prompt
74
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
75
+ system_prompt = f.read()
76
+ sys_msg = SystemMessage(content=system_prompt)
77
+
78
+ # ───────────────────────────────────────────────────────────────────────────────
79
+ def build_graph(provider: str = "openai"):
80
+ """Build the LangGraph agent with chosen LLM (default: OpenAI)."""
81
+ if provider == "openai":
82
+ llm = ChatOpenAI(
83
+ model_name="o4-mini-2025-04-16",
84
+ openai_api_key=OPENAI_API_KEY,
85
+ # no temperature override here
86
+ )
87
+ elif provider == "huggingface":
88
+ llm = ChatHuggingFace(
89
+ llm=HuggingFaceEndpoint(
90
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
91
+ ),
92
+ temperature=0,
93
+ )
94
+ else:
95
+ raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")
96
+
97
+ llm_with_tools = llm.bind_tools(tools)
98
+
99
+ def assistant(state: MessagesState):
100
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
101
+
102
+ builder = StateGraph(MessagesState)
103
+ builder.add_node("assistant", assistant)
104
+ builder.add_node("tools", ToolNode(tools))
105
+ builder.add_edge(START, "assistant")
106
+ builder.add_conditional_edges("assistant", tools_condition)
107
+ builder.add_edge("tools", "assistant")
108
+
109
+ return builder.compile()
110
+
111
+ if __name__ == "__main__":
112
+ graph = build_graph()
113
+ msgs = graph.invoke({"messages":[ HumanMessage(content="What’s the capital of France?") ]})
114
+ for m in msgs["messages"]:
115
+ m.pretty_print()