Yongkang ZOU commited on
Commit
3fba19d
·
1 Parent(s): 547bea9

add retriever

Browse files
Files changed (2) hide show
  1. agent.py +38 -4
  2. requirements.txt +1 -0
agent.py CHANGED
@@ -87,7 +87,8 @@ sys_msg = SystemMessage(content=system_prompt)
87
  from langchain_openai import ChatOpenAI # ✅ 新增导入
88
 
89
  def build_graph(provider: str = "groq"):
90
- """Build the LangGraph with tool-use."""
 
91
  if provider == "google":
92
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
93
  elif provider == "groq":
@@ -106,21 +107,54 @@ def build_graph(provider: str = "groq"):
106
  openai_key = os.getenv("OPENAI_API_KEY")
107
  if not openai_key:
108
  raise ValueError("OPENAI_API_KEY is not set.")
109
- llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=openai_key) # ✅ OpenAI GPT
110
  else:
111
  raise ValueError("Invalid provider")
112
 
 
113
  llm_with_tools = llm.bind_tools(tools)
114
 
115
  def assistant(state: MessagesState):
116
  return {"messages": [sys_msg] + [llm_with_tools.invoke(state["messages"])]}
117
 
118
- # Build the graph with assistant and tools
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  builder = StateGraph(MessagesState)
 
120
  builder.add_node("assistant", assistant)
121
  builder.add_node("tools", ToolNode(tools))
122
 
123
- builder.add_edge(START, "assistant")
 
 
 
 
124
  builder.add_conditional_edges("assistant", tools_condition)
125
  builder.add_edge("tools", "assistant")
126
 
 
87
  from langchain_openai import ChatOpenAI # ✅ 新增导入
88
 
89
  def build_graph(provider: str = "groq"):
90
+ """Build LangGraph agent with QA retriever and tool-use fallback."""
91
+ # 初始化 LLM
92
  if provider == "google":
93
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
94
  elif provider == "groq":
 
107
  openai_key = os.getenv("OPENAI_API_KEY")
108
  if not openai_key:
109
  raise ValueError("OPENAI_API_KEY is not set.")
110
+ llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=openai_key)
111
  else:
112
  raise ValueError("Invalid provider")
113
 
114
+ # 工具绑定
115
  llm_with_tools = llm.bind_tools(tools)
116
 
117
  def assistant(state: MessagesState):
118
  return {"messages": [sys_msg] + [llm_with_tools.invoke(state["messages"])]}
119
 
120
+ # 初始化 Supabase Retriever
121
+ SUPABASE_URL = os.getenv("SUPABASE_URL")
122
+ SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY")
123
+ supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
124
+
125
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
126
+ vectorstore = SupabaseVectorStore(
127
+ client=supabase,
128
+ embedding=embedding_model,
129
+ table_name="QA_db"
130
+ )
131
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 1})
132
+
133
+ # ✅ Retriever 节点
134
+ def qa_retriever_node(state: MessagesState):
135
+ user_question = state["messages"][-1].content
136
+ docs = retriever.invoke(user_question)
137
+ if docs:
138
+ return {
139
+ "messages": state["messages"] + [AIMessage(content=docs[0].page_content)],
140
+ "__condition__": "complete"
141
+ }
142
+ return {
143
+ "messages": state["messages"],
144
+ "__condition__": "default"
145
+ }
146
+
147
+ # 构建图结构
148
  builder = StateGraph(MessagesState)
149
+ builder.add_node("retriever", qa_retriever_node)
150
  builder.add_node("assistant", assistant)
151
  builder.add_node("tools", ToolNode(tools))
152
 
153
+ builder.add_edge(START, "retriever")
154
+ builder.add_conditional_edges("retriever", {
155
+ "default": "assistant",
156
+ "complete": None
157
+ })
158
  builder.add_conditional_edges("assistant", tools_condition)
159
  builder.add_edge("tools", "assistant")
160
 
requirements.txt CHANGED
@@ -17,3 +17,4 @@ pymupdf
17
  wikipedia
18
  pgvector
19
  python-dotenv
 
 
17
  wikipedia
18
  pgvector
19
  python-dotenv
20
+ supabase