Yongkang ZOU committed on
Commit
b5faafa
·
1 Parent(s): 3fba19d

update agent

Browse files
Files changed (1) hide show
  1. agent.py +49 -31
agent.py CHANGED
@@ -1,22 +1,25 @@
1
  import os
2
  from dotenv import load_dotenv
3
- from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition, ToolNode
5
  from langchain_google_genai import ChatGoogleGenerativeAI
6
  from langchain_groq import ChatGroq
7
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
8
  from langchain_community.tools.tavily_search import TavilySearchResults
9
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
- from langchain_core.messages import SystemMessage, HumanMessage
11
  from langchain_core.tools import tool
12
  from langchain_groq import ChatGroq
 
 
 
 
 
 
13
 
14
  load_dotenv()
15
 
16
-
17
-
18
  # ------------------- TOOL DEFINITIONS -------------------
19
-
20
  @tool
21
  def multiply(a: int, b: int) -> int:
22
  """Multiply two numbers."""
@@ -29,19 +32,19 @@ def add(a: int, b: int) -> int:
29
 
30
  @tool
31
  def subtract(a: int, b: int) -> int:
32
- """Subtract two numbers."""
33
  return a - b
34
 
35
  @tool
36
  def divide(a: int, b: int) -> float:
37
- """Divide two numbers."""
38
  if b == 0:
39
  raise ValueError("Cannot divide by zero.")
40
  return a / b
41
 
42
  @tool
43
  def modulus(a: int, b: int) -> int:
44
- """Get the modulus of two numbers."""
45
  return a % b
46
 
47
  @tool
@@ -54,23 +57,19 @@ def wiki_search(query: str) -> str:
54
  def web_search(query: str) -> str:
55
  """Search the web using Tavily (max 3 results)."""
56
  results = TavilySearchResults(max_results=3).invoke(query)
57
- texts = []
58
- for doc in results:
59
- if isinstance(doc, dict):
60
- texts.append(doc.get("content", "") or doc.get("text", ""))
61
  return "\n\n".join(texts)
62
 
63
-
64
  @tool
65
  def arvix_search(query: str) -> str:
66
- """Search Arxiv for academic papers (max 3)."""
67
  docs = ArxivLoader(query=query, load_max_docs=3).load()
68
  return "\n\n".join([doc.page_content[:1000] for doc in docs])
69
 
 
70
  tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
71
 
72
  # ------------------- SYSTEM PROMPT -------------------
73
-
74
  system_prompt_path = "system_prompt.txt"
75
  if os.path.exists(system_prompt_path):
76
  with open(system_prompt_path, "r", encoding="utf-8") as f:
@@ -83,12 +82,7 @@ else:
83
  sys_msg = SystemMessage(content=system_prompt)
84
 
85
  # ------------------- GRAPH CONSTRUCTION -------------------
86
-
87
- from langchain_openai import ChatOpenAI # ✅ 新增导入
88
-
89
  def build_graph(provider: str = "groq"):
90
- """Build LangGraph agent with QA retriever and tool-use fallback."""
91
- # 初始化 LLM
92
  if provider == "google":
93
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
94
  elif provider == "groq":
@@ -111,13 +105,11 @@ def build_graph(provider: str = "groq"):
111
  else:
112
  raise ValueError("Invalid provider")
113
 
114
- # 工具绑定
115
  llm_with_tools = llm.bind_tools(tools)
116
 
117
  def assistant(state: MessagesState):
118
  return {"messages": [sys_msg] + [llm_with_tools.invoke(state["messages"])]}
119
 
120
- # ✅ 初始化 Supabase Retriever
121
  SUPABASE_URL = os.getenv("SUPABASE_URL")
122
  SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY")
123
  supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
@@ -130,7 +122,38 @@ def build_graph(provider: str = "groq"):
130
  )
131
  retriever = vectorstore.as_retriever(search_kwargs={"k": 1})
132
 
133
- # ✅ Retriever 节点
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def qa_retriever_node(state: MessagesState):
135
  user_question = state["messages"][-1].content
136
  docs = retriever.invoke(user_question)
@@ -139,12 +162,8 @@ def build_graph(provider: str = "groq"):
139
  "messages": state["messages"] + [AIMessage(content=docs[0].page_content)],
140
  "__condition__": "complete"
141
  }
142
- return {
143
- "messages": state["messages"],
144
- "__condition__": "default"
145
- }
146
 
147
- # 构建图结构
148
  builder = StateGraph(MessagesState)
149
  builder.add_node("retriever", qa_retriever_node)
150
  builder.add_node("assistant", assistant)
@@ -152,8 +171,8 @@ def build_graph(provider: str = "groq"):
152
 
153
  builder.add_edge(START, "retriever")
154
  builder.add_conditional_edges("retriever", {
155
- "default": "assistant",
156
- "complete": None
157
  })
158
  builder.add_conditional_edges("assistant", tools_condition)
159
  builder.add_edge("tools", "assistant")
@@ -161,7 +180,6 @@ def build_graph(provider: str = "groq"):
161
  return builder.compile()
162
 
163
  # ------------------- LOCAL TEST -------------------
164
-
165
  if __name__ == "__main__":
166
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
167
  graph = build_graph(provider="openai")
 
1
  import os
2
  from dotenv import load_dotenv
3
+ from langgraph.graph import START, StateGraph, MessagesState, END
4
  from langgraph.prebuilt import tools_condition, ToolNode
5
  from langchain_google_genai import ChatGoogleGenerativeAI
6
  from langchain_groq import ChatGroq
7
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
8
  from langchain_community.tools.tavily_search import TavilySearchResults
9
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
11
  from langchain_core.tools import tool
12
  from langchain_groq import ChatGroq
13
+ from supabase import create_client
14
+ from langchain_huggingface import HuggingFaceEmbeddings
15
+ from langchain_community.vectorstores import SupabaseVectorStore
16
+ from langchain_openai import ChatOpenAI
17
+ from langchain_core.documents import Document
18
+ import json
19
 
20
  # Load variables from a local .env file into the process environment before
  # anything below reads configuration through os.getenv.
  load_dotenv()
21
 
 
 
22
  # ------------------- TOOL DEFINITIONS -------------------
 
23
  @tool
24
  def multiply(a: int, b: int) -> int:
25
  """Multiply two numbers."""
 
32
 
33
  @tool
34
  def subtract(a: int, b: int) -> int:
35
+ """Subtract b from a."""
36
  return a - b
37
 
38
  @tool
39
  def divide(a: int, b: int) -> float:
40
+ """Divide a by b. Raise error if b is zero."""
41
  if b == 0:
42
  raise ValueError("Cannot divide by zero.")
43
  return a / b
44
 
45
  @tool
46
  def modulus(a: int, b: int) -> int:
47
+ """Get remainder of a divided by b."""
48
  return a % b
49
 
50
  @tool
 
57
  def web_search(query: str) -> str:
58
  """Search the web using Tavily (max 3 results)."""
59
  results = TavilySearchResults(max_results=3).invoke(query)
60
+ texts = [doc.get("content", "") or doc.get("text", "") for doc in results if isinstance(doc, dict)]
 
 
 
61
  return "\n\n".join(texts)
62
 
 
63
  @tool
64
  def arvix_search(query: str) -> str:
65
+ """Search Arxiv for academic papers (max 3 results, truncated to 1000 characters each)."""
66
  docs = ArxivLoader(query=query, load_max_docs=3).load()
67
  return "\n\n".join([doc.page_content[:1000] for doc in docs])
68
 
69
+
70
  # Every tool offered to the model: build_graph binds this list to the LLM
  # with llm.bind_tools(tools).
  tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
71
 
72
  # ------------------- SYSTEM PROMPT -------------------
 
73
  system_prompt_path = "system_prompt.txt"
74
  if os.path.exists(system_prompt_path):
75
  with open(system_prompt_path, "r", encoding="utf-8") as f:
 
82
  sys_msg = SystemMessage(content=system_prompt)
83
 
84
  # ------------------- GRAPH CONSTRUCTION -------------------
 
 
 
85
  def build_graph(provider: str = "groq"):
 
 
86
  if provider == "google":
87
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
88
  elif provider == "groq":
 
105
  else:
106
  raise ValueError("Invalid provider")
107
 
 
108
  llm_with_tools = llm.bind_tools(tools)
109
 
110
  def assistant(state: MessagesState):
111
  return {"messages": [sys_msg] + [llm_with_tools.invoke(state["messages"])]}
112
 
 
113
  SUPABASE_URL = os.getenv("SUPABASE_URL")
114
  SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY")
115
  supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
 
122
  )
123
  retriever = vectorstore.as_retriever(search_kwargs={"k": 1})
124
 
125
+
126
+ # ✅ Replace the similarity_search_by_vector_with_relevance_scores method so it calls supabase.rpc directly
127
+ original_fn = vectorstore.similarity_search_by_vector_with_relevance_scores
128
+
129
+ # ✅ Override the vectorstore method
130
def patched_fn(embedding, k=4, filter=None, **kwargs):
    """Similarity search that calls the ``match_documents`` RPC directly.

    Args:
        embedding: query embedding, sent to the RPC as ``query_embedding``.
        k: number of matches requested, sent to the RPC as ``match_count``.
        filter: accepted for signature compatibility only; it is NOT
            forwarded to the RPC, so no metadata filtering is applied.
        **kwargs: accepted and ignored.

    Returns:
        A list of ``(Document, similarity)`` tuples, one per returned row.
    """
    response = supabase.rpc(
        "match_documents",
        {"query_embedding": embedding, "match_count": k},
    ).execute()

    documents = []
    # Guard against a missing/None data payload so "no rows" is an empty
    # result instead of a TypeError.
    for row in response.data or []:
        metadata = row.get("metadata")
        if isinstance(metadata, str):
            # Metadata may arrive JSON-encoded; decode it, and fall back to
            # empty metadata only for malformed JSON.
            try:
                metadata = json.loads(metadata)
            except json.JSONDecodeError:
                metadata = {}
        # Document metadata must be a dict: normalize None and any non-dict
        # value (e.g. a JSON list or scalar) to an empty dict.
        if not isinstance(metadata, dict):
            metadata = {}
        doc = Document(page_content=row["content"], metadata=metadata)
        documents.append((doc, row["similarity"]))
    return documents
153
+
154
+ # ✅ Override the vectorstore method
155
+ vectorstore.similarity_search_by_vector_with_relevance_scores = patched_fn
156
+
157
  def qa_retriever_node(state: MessagesState):
158
  user_question = state["messages"][-1].content
159
  docs = retriever.invoke(user_question)
 
162
  "messages": state["messages"] + [AIMessage(content=docs[0].page_content)],
163
  "__condition__": "complete"
164
  }
165
+ return {"messages": state["messages"], "__condition__": "default"}
 
 
 
166
 
 
167
  builder = StateGraph(MessagesState)
168
  builder.add_node("retriever", qa_retriever_node)
169
  builder.add_node("assistant", assistant)
 
171
 
172
  builder.add_edge(START, "retriever")
173
  builder.add_conditional_edges("retriever", {
174
+ "default": lambda x: "assistant",
175
+ "complete": lambda x: END,
176
  })
177
  builder.add_conditional_edges("assistant", tools_condition)
178
  builder.add_edge("tools", "assistant")
 
180
  return builder.compile()
181
 
182
  # ------------------- LOCAL TEST -------------------
 
183
  if __name__ == "__main__":
184
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
185
  graph = build_graph(provider="openai")