Yongkang ZOU committed on
Commit
d9c23d1
·
1 Parent(s): 81917a3

update agent

Browse files
Files changed (4) hide show
  1. agent.py +213 -0
  2. metadata.jsonl +0 -0
  3. requirements.txt +189 -2
  4. system_prompt.txt +5 -0
agent.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
+ from langchain_core.messages import SystemMessage, HumanMessage
15
+ from langchain_core.tools import tool
16
+ from langchain.tools.retriever import create_retriever_tool
17
+ # from supabase.client import Client, create_client
18
+
19
+ load_dotenv()
20
+
21
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two integers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
29
+
30
@tool
def add(a: int, b: int) -> int:
    """Return the sum of two integers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
39
+
40
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference of two integers (a minus b).

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
49
+
50
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Returns:
        The true-division quotient ``a / b``. Note this is a float even for
        integer inputs — the original ``-> int`` annotation was wrong.

    Raises:
        ValueError: If ``b`` is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
61
+
62
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of dividing a by b.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
71
+
72
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        A dict with key ``"wiki_results"`` mapping to the formatted documents.
        (The original ``-> str`` annotation did not match the dict return.)
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Use .get for metadata lookups so a document missing "source"/"page"
    # degrades to an empty attribute instead of raising KeyError. Also fix the
    # malformed self-closing "/>" that was paired with a "</Document>" closer.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}">\n'
        f"{doc.page_content}\n</Document>"
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
85
+
86
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        A dict with key ``"web_results"`` mapping to the formatted results.
        (The original ``-> str`` annotation did not match the dict return.)
    """
    # Bug fixes: .invoke takes the tool input positionally, not as a
    # ``query=`` keyword, and TavilySearchResults yields plain dicts (with
    # "url"/"content" keys), not Document objects — the original
    # ``doc.metadata`` access would raise AttributeError.
    search_results = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{result.get("url", "")}">\n'
        f"{result.get('content', '')}\n</Document>"
        for result in search_results
    )
    return {"web_results": formatted_search_docs}
99
+
100
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        A dict with key ``"arvix_results"`` mapping to the formatted documents
        (content truncated to 1000 chars each). The original ``-> str``
        annotation did not match the dict return.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # ArxivLoader metadata does not carry a "source" key (it uses fields like
    # "Title"/"Published"), so the original ["source"] lookup raised KeyError.
    # Fall back through .get; TODO confirm the preferred metadata field.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" '
        f'page="{doc.metadata.get("page", "")}">\n'
        f"{doc.page_content[:1000]}\n</Document>"
        for doc in search_docs
    )
    return {"arvix_results": formatted_search_docs}
113
+
114
+
115
+
116
# Read the system prompt that seeds every conversation from disk.
with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()

# Wrap the raw prompt text as a SystemMessage for use in the graph.
sys_msg = SystemMessage(content=system_prompt)

# # build a retriever (disabled — requires the supabase client import above)
# embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
# supabase: Client = create_client(
#     os.environ.get("SUPABASE_URL"),
#     os.environ.get("SUPABASE_SERVICE_KEY"))
# vector_store = SupabaseVectorStore(
#     client=supabase,
#     embedding=embeddings,
#     table_name="documents",
#     query_name="match_documents_langchain",
# )
# create_retriever_tool = create_retriever_tool(
#     retriever=vector_store.as_retriever(),
#     name="Question Search",
#     description="A tool to retrieve similar questions from a vector store.",
# )
139
+
140
+
141
+
142
# Every tool exposed to the LLM: the five arithmetic helpers plus the three
# search backends (Wikipedia, Tavily web search, Arxiv) defined above.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
]
152
+
153
# Build graph function
def build_graph(provider: str = "groq"):
    """Build the agent graph: an assistant node bound to the tools, with
    conditional routing into a ToolNode and back.

    Args:
        provider: Chat-model backend — "google", "groq", or "huggingface".

    Returns:
        A compiled LangGraph runnable.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional: qwen-qwq-32b gemma2-9b-it
    elif provider == "huggingface":
        # NOTE(review): HuggingFaceEndpoint normally takes `endpoint_url` or
        # `repo_id`, not `url` — confirm this branch actually works before
        # relying on it.
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    # Bind tools to LLM so it can emit tool calls.
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: prepend the system prompt and call the tool-bound LLM."""
        # Bug fix: sys_msg was defined at module level but never injected, so
        # the model never saw the system prompt.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    # Bug fix: the original wired START -> "retriever" -> "assistant", but the
    # "retriever" node is commented out, so compile() failed on an unknown
    # node. Route START straight to the assistant instead.
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
203
+
204
# Quick smoke test: build the graph and push one question through it.
if __name__ == "__main__":
    question = (
        "When was a picture of St. Thomas Aquinas first added to the "
        "Wikipedia page on the Principle of double effect?"
    )
    # Build the graph with the default Groq backend.
    graph = build_graph(provider="groq")
    # Run the graph and print every message in the final state.
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in result["messages"]:
        message.pretty_print()
metadata.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,2 +1,189 @@
1
- gradio
2
- requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.12.4
4
+ aiosignal==1.3.2
5
+ annotated-types==0.7.0
6
+ anyio==4.9.0
7
+ appnope==0.1.4
8
+ asgiref==3.8.1
9
+ asttokens==3.0.0
10
+ attrs==25.3.0
11
+ audioop-lts==0.2.1
12
+ backoff==1.11.1
13
+ bcrypt==4.3.0
14
+ build==1.2.2.post1
15
+ cachetools==5.5.2
16
+ certifi==2025.4.26
17
+ charset-normalizer==3.4.2
18
+ chromadb==1.0.11
19
+ click==8.2.1
20
+ coloredlogs==15.0.1
21
+ comm==0.2.2
22
+ dataclasses-json==0.6.7
23
+ debugpy==1.8.14
24
+ decorator==5.2.1
25
+ Deprecated==1.2.18
26
+ distro==1.9.0
27
+ durationpy==0.10
28
+ executing==2.2.0
29
+ fastapi==0.115.9
30
+ ffmpy==0.5.0
31
+ filelock==3.18.0
32
+ filetype==1.2.0
33
+ flatbuffers==25.2.10
34
+ frozenlist==1.6.0
35
+ fsspec==2025.5.1
36
+ google-ai-generativelanguage==0.6.18
37
+ google-api-core==2.24.2
38
+ google-auth==2.40.2
39
+ googleapis-common-protos==1.70.0
40
+ gradio==5.31.0
41
+ gradio_client==1.10.1
42
+ groovy==0.1.2
43
+ groq==0.26.0
44
+ grpcio==1.73.0rc1
45
+ grpcio-status==1.73.0rc1
46
+ h11==0.16.0
47
+ hf-xet==1.1.2
48
+ httpcore==1.0.9
49
+ httptools==0.6.4
50
+ httpx==0.28.1
51
+ httpx-sse==0.4.0
52
+ huggingface-hub==0.32.3
53
+ humanfriendly==10.0
54
+ idna==3.10
55
+ importlib_metadata==8.6.1
56
+ importlib_resources==6.5.2
57
+ ipykernel==6.29.5
58
+ ipython==9.2.0
59
+ ipython_pygments_lexers==1.1.1
60
+ jedi==0.19.2
61
+ Jinja2==3.1.6
62
+ joblib==1.5.1
63
+ jsonpatch==1.33
64
+ jsonpointer==3.0.0
65
+ jsonschema==4.24.0
66
+ jsonschema-specifications==2025.4.1
67
+ jupyter_client==8.6.3
68
+ jupyter_core==5.8.1
69
+ kubernetes==32.0.1
70
+ langchain==0.3.25
71
+ langchain-chroma==0.2.4
72
+ langchain-community==0.3.24
73
+ langchain-core==0.3.63
74
+ langchain-google-genai==2.1.5
75
+ langchain-groq==0.3.2
76
+ langchain-huggingface==0.2.0
77
+ langchain-tavily==0.2.0
78
+ langchain-text-splitters==0.3.8
79
+ langgraph==0.4.7
80
+ langgraph-checkpoint==2.0.26
81
+ langgraph-prebuilt==0.2.2
82
+ langgraph-sdk==0.1.70
83
+ langsmith==0.3.43
84
+ markdown-it-py==3.0.0
85
+ MarkupSafe==3.0.2
86
+ marshmallow==3.26.1
87
+ matplotlib-inline==0.1.7
88
+ mdurl==0.1.2
89
+ mmh3==5.1.0
90
+ mpmath==1.3.0
91
+ multidict==6.4.4
92
+ mypy==1.16.0
93
+ mypy_extensions==1.1.0
94
+ nest-asyncio==1.6.0
95
+ networkx==3.5
96
+ numpy==2.2.6
97
+ oauthlib==3.2.2
98
+ onnxruntime==1.22.0
99
+ opentelemetry-api==1.33.1
100
+ opentelemetry-exporter-otlp-proto-grpc==1.11.1
101
+ opentelemetry-instrumentation==0.54b1
102
+ opentelemetry-instrumentation-asgi==0.54b1
103
+ opentelemetry-instrumentation-fastapi==0.54b1
104
+ opentelemetry-proto==1.11.1
105
+ opentelemetry-sdk==1.33.1
106
+ opentelemetry-semantic-conventions==0.54b1
107
+ opentelemetry-util-http==0.54b1
108
+ orjson==3.10.18
109
+ ormsgpack==1.10.0
110
+ overrides==7.7.0
111
+ packaging==24.2
112
+ pandas==2.2.3
113
+ parso==0.8.4
114
+ pathspec==0.12.1
115
+ pexpect==4.9.0
116
+ pillow==11.2.1
117
+ platformdirs==4.3.8
118
+ posthog==4.2.0
119
+ prompt_toolkit==3.0.51
120
+ propcache==0.3.1
121
+ proto-plus==1.26.1
122
+ protobuf==6.31.1
123
+ psutil==7.0.0
124
+ ptyprocess==0.7.0
125
+ pure_eval==0.2.3
126
+ pyasn1==0.6.1
127
+ pyasn1_modules==0.4.2
128
+ pydantic==2.11.5
129
+ pydantic-settings==2.9.1
130
+ pydantic_core==2.33.2
131
+ pydub==0.25.1
132
+ Pygments==2.19.1
133
+ PyPika==0.48.9
134
+ pyproject_hooks==1.2.0
135
+ python-dateutil==2.9.0.post0
136
+ python-dotenv==1.1.0
137
+ python-multipart==0.0.20
138
+ pytz==2025.2
139
+ PyYAML==6.0.2
140
+ pyzmq==26.4.0
141
+ referencing==0.36.2
142
+ regex==2024.11.6
143
+ requests==2.32.3
144
+ requests-oauthlib==2.0.0
145
+ requests-toolbelt==1.0.0
146
+ rich==14.0.0
147
+ rpds-py==0.25.1
148
+ rsa==4.9.1
149
+ ruff==0.11.12
150
+ safehttpx==0.1.6
151
+ safetensors==0.5.3
152
+ scikit-learn==1.6.1
153
+ scipy==1.15.3
154
+ semantic-version==2.10.0
155
+ sentence-transformers==4.1.0
156
+ setuptools==80.9.0
157
+ shellingham==1.5.4
158
+ six==1.17.0
159
+ sniffio==1.3.1
160
+ SQLAlchemy==2.0.41
161
+ stack-data==0.6.3
162
+ starlette==0.45.3
163
+ sympy==1.14.0
164
+ tenacity==9.1.2
165
+ threadpoolctl==3.6.0
166
+ tokenizers==0.21.1
167
+ tomlkit==0.13.2
168
+ torch==2.7.0
169
+ tornado==6.5.1
170
+ tqdm==4.67.1
171
+ traitlets==5.14.3
172
+ transformers==4.52.4
173
+ typer==0.16.0
174
+ typing-inspect==0.9.0
175
+ typing-inspection==0.4.1
176
+ typing_extensions==4.13.2
177
+ tzdata==2025.2
178
+ urllib3==2.4.0
179
+ uvicorn==0.34.2
180
+ uvloop==0.21.0
181
+ watchfiles==1.0.5
182
+ wcwidth==0.2.13
183
+ websocket-client==1.8.0
184
+ websockets==15.0.1
185
+ wrapt==1.17.2
186
+ xxhash==3.5.0
187
+ yarl==1.20.0
188
+ zipp==3.22.0
189
+ zstandard==0.23.0
system_prompt.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ You are a helpful assistant tasked with answering questions using a set of tools.
2
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
3
+ FINAL ANSWER: [YOUR FINAL ANSWER].
4
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use commas in your number, nor units such as $ or percent signs, unless specified otherwise. If you are asked for a string, don't use articles nor abbreviations (e.g. for cities), and write digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether each element of the list is a number or a string.
5
+ Your answer should start with "FINAL ANSWER: ", followed by the answer.