Adding optional news data source
Optional Candid news data source with sparse encoding for quasi-semantic searching
- app.py +1 -1
- ask_candid/agents/elastic.py +380 -23
- ask_candid/base/config/connections.py +11 -0
- ask_candid/base/config/constants.py +1 -1
- ask_candid/base/config/data.py +3 -2
- ask_candid/chat.py +4 -1
- ask_candid/graph.py +4 -3
- ask_candid/retrieval/elastic.py +566 -500
- ask_candid/retrieval/sparse_lexical.py +29 -0
- ask_candid/tools/elastic/index_details_tool.py +4 -2
- ask_candid/tools/elastic/index_search_tool.py +24 -2
- ask_candid/tools/org_seach.py +1 -1
- ask_candid/tools/question_reformulation.py +39 -19
- ask_candid/tools/recommendation.py +189 -140
- ask_candid/utils.py +17 -25
- requirements.txt +3 -1
app.py
CHANGED
@@ -113,7 +113,7 @@ def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
     with gr.Accordion(label="Advanced settings", open=False):
         es_indices = gr.CheckboxGroup(
             choices=list(ALL_INDICES),
-            value=
+            value=[idx for idx in ALL_INDICES if "news" not in idx],
             label="Sources to include",
             interactive=True,
         )
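With the news index added to `ALL_INDICES`, the accordion now pre-selects every source except news, so the new data source stays opt-in. A minimal sketch of the same filtering, using the tuple from `ask_candid/base/config/data.py`:

```python
ALL_INDICES = ("issuelab", "youtube", "candid_blog", "candid_learning", "candid_help", "news")

# Pre-select everything except the news index; users opt in explicitly.
default_sources = [idx for idx in ALL_INDICES if "news" not in idx]
assert "news" not in default_sources and len(default_sources) == 5
```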
ask_candid/agents/elastic.py
CHANGED
@@ -1,32 +1,27 @@
-from typing import TypedDict
+from typing import TypedDict, List
 from functools import partial
 import json
 import ast
-
 from pydantic import BaseModel, Field
 
-from langchain_openai import ChatOpenAI
-
 from langchain_core.runnables import RunnableSequence
 from langchain_core.language_models.llms import LLM
-
 from langchain.agents.openai_functions_agent.base import create_openai_functions_agent
 from langchain.agents.agent import AgentExecutor
 from langchain.agents.agent_types import AgentType
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.prompts import ChatPromptTemplate, PromptTemplate, MessagesPlaceholder
 from langchain.output_parsers import PydanticOutputParser
 from langchain.schema import BaseMessage
+from langchain.agents import create_tool_calling_agent, AgentExecutor
+from langchain_core.tools import Tool
 
 from langgraph.graph import StateGraph, END
 
-from ask_candid.tools.elastic.list_indices_tool import ListIndicesTool
 from ask_candid.tools.elastic.index_data_tool import IndexShowDataTool
 from ask_candid.tools.elastic.index_details_tool import IndexDetailsTool
 from ask_candid.tools.elastic.index_search_tool import create_search_tool
-from ask_candid.base.config.rest import OPENAI
 
 tools = [
-    ListIndicesTool(),
     IndexShowDataTool(),
     IndexDetailsTool(),
     create_search_tool(),
@@ -58,7 +53,7 @@ class AnalysisResult(BaseModel):
     category: str = Field(..., description="Either 'general' or 'Database'")
 
 
-def agent_factory() -> AgentExecutor:
+def agent_factory(llm: LLM) -> AgentExecutor:
     """
     Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
 
@@ -72,9 +67,9 @@ def agent_factory(llm: LLM) -> AgentExecutor:
     providing detailed intermediate steps for transparency.
     """
 
-    llm = ChatOpenAI(
-        model="gpt-4o", temperature=0, api_key=OPENAI["key"], streaming=False
-    )
+    # llm = ChatOpenAI(
+    #     model="gpt-4o", temperature=0, api_key=OPENAI["key"], streaming=False
+    # )
 
     tags_ = []
     agent = AgentType.OPENAI_FUNCTIONS
@@ -101,6 +96,45 @@ def agent_factory(llm: LLM) -> AgentExecutor:
     )
 
 
+def agent_factory_claude(llm: LLM) -> AgentExecutor:
+    """
+    Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
+
+    This function initializes an OpenAI GPT-4-based LLM with specific parameters,
+    constructs a prompt tailored for Elasticsearch assistance, and integrates the
+    agent with a set of tools to handle user queries. The agent is designed to work
+    with OpenAI functions for enhanced capabilities.
+
+    Returns:
+        AgentExecutor: Configured agent ready to execute tasks with specified tools,
+        providing detailed intermediate steps for transparency.
+    """
+
+    # llm = ChatOpenAI(
+    #     model="gpt-4o", temperature=0, api_key=OPENAI["key"], streaming=False
+    # )
+
+    # tags_ = []
+    # agent = AgentType.OPENAI_FUNCTIONS
+    # tags_.append(agent.value if isinstance(agent, AgentType) else agent)
+    # Create the prompt
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", "You are a helpful elasticsearch assistant"),
+            MessagesPlaceholder(variable_name="chat_history", optional=True),
+            ("human", "{input}"),
+            MessagesPlaceholder(variable_name="agent_scratchpad"),
+        ]
+    )
+
+    agent = create_tool_calling_agent(llm, tools, prompt)
+    agent_executor = AgentExecutor.from_agent_and_tools(
+        agent=agent, tools=tools, verbose=True, return_intermediate_steps=True
+    )
+    # Create the agent
+    return agent_executor
+
+
 # define graph node functions
 def general_query(state: GraphState, llm: LLM) -> GraphState:
     """
@@ -126,7 +160,7 @@ def general_query(state: GraphState, llm: LLM) -> GraphState:
     return state
 
 
-def database_agent(state: GraphState) -> GraphState:
+def database_agent(state: GraphState, llm: LLM) -> GraphState:
     """
     Executes a database query using an Elasticsearch agent and updates the graph state.
 
@@ -144,22 +178,28 @@ def database_agent(state: GraphState, llm: LLM) -> GraphState:
     print("> database agent")
     input_data = {
         "input": f"""
-        Make sure that after querying the indices you query the field names.
-        To answer the question choose ```organization_dev_2``` index
+        You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
 
+        1. Understand the user query to determine the required information.
+        2. Query the indices in the Elasticsearch database.
+        3. Retrieve the mappings and field names relevant to the query.
+        4. Use the organization_dev_2 index to extract the necessary data.
+        5. Present the response in a clear and natural language format, addressing the user's question directly.
+
+        User's query:
+        ```{state["query"]}```
         """
     }
-    agent_exec = agent_factory()
+    agent_exec = agent_factory_claude(llm)
     res = agent_exec.invoke(input_data)
     state["agent_out"] = res["output"]
 
     es_queries, es_results = {}, {}
     for i, action in enumerate(res.get("intermediate_steps", []), start=1):
         if action[0].tool == "elastic_index_search_tool":
-            es_queries[f"query_{i}"] = json.loads(action[0].tool_input.get("query") or "{}")
+            es_queries[f"query_{i}"] = json.loads(
+                action[0].tool_input.get("query") or "{}"
+            )
             es_results[f"query_{i}"] = ast.literal_eval(action[-1] or "{}")
 
     # if len(res["intermediate_steps"]) > 1:
@@ -239,7 +279,7 @@ def final_answer(state: GraphState, llm: LLM) -> GraphState:
 
     print("> Final Answer")
     prompt_template = """
-
+    You are a chat agent that takes outputs generated by Elasticsearch and presents them in a conversational, natural language format, as if responding to a user's query.
 
     Query: ```{query}```
 
@@ -272,7 +312,7 @@ def build_compute_graph(llm: LLM) -> StateGraph:
     # Add nodes
     workflow.add_node("analyse", partial(analyse_query, llm=llm))
     workflow.add_node("general_query", partial(general_query, llm=llm))
-    workflow.add_node("es_database_agent", database_agent)
+    workflow.add_node("es_database_agent", partial(database_agent, llm=llm))
     workflow.add_node("final_answer", partial(final_answer, llm=llm))
 
     # Set entry point
@@ -291,3 +331,320 @@ def build_compute_graph(llm: LLM) -> StateGraph:
     workflow.add_edge("final_answer", END)
 
     return workflow
+
+
+class ElasticGraph(StateGraph):
+    """Elastic Search Agent State Graph"""
+
+    llm: LLM
+    tools: List[Tool]
+
+    def __init__(self, llm: LLM, tools: List[Tool]):
+        super().__init__(GraphState)
+        self.llm = llm
+        self.tools = tools
+        self.construct_graph()
+
+    def agent_factory(self) -> AgentExecutor:
+        """
+        Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
+
+        This function initializes an OpenAI GPT-4-based LLM with specific parameters,
+        constructs a prompt tailored for Elasticsearch assistance, and integrates the
+        agent with a set of tools to handle user queries. The agent is designed to work
+        with OpenAI functions for enhanced capabilities.
+
+        Returns:
+            AgentExecutor: Configured agent ready to execute tasks with specified tools,
+            providing detailed intermediate steps for transparency.
+        """
+
+        # llm = ChatOpenAI(
+        #     model="gpt-4o", temperature=0, api_key=OPENAI["key"], streaming=False
+        # )
+
+        tags_ = []
+        agent = AgentType.OPENAI_FUNCTIONS
+        tags_.append(agent.value if isinstance(agent, AgentType) else agent)
+        # Create the prompt
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", "You are a helpful elasticsearch assistant"),
+                MessagesPlaceholder(variable_name="chat_history", optional=True),
+                ("human", "{input}"),
+                MessagesPlaceholder(variable_name="agent_scratchpad"),
+            ]
+        )
+
+        # Create the agent
+        agent_obj = create_openai_functions_agent(self.llm, tools, prompt)
+
+        return AgentExecutor.from_agent_and_tools(
+            agent=agent_obj,
+            tools=tools,
+            tags=tags_,
+            verbose=True,
+            return_intermediate_steps=True,
+        )
+
+    def agent_factory_claude(self) -> AgentExecutor:
+        """
+        Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
+
+        This function constructs a prompt tailored for Elasticsearch DSL generation and
+        integrates the agent with a set of tools to handle user queries. The agent is
+        built with the generic tool-calling interface, so any tool-calling chat model works.
+
+        Returns:
+            AgentExecutor: Configured agent ready to execute tasks with specified tools,
+            providing detailed intermediate steps for transparency.
+        """
+        prefix = """
+        You are an intelligent agent tasked with generating accurate Elasticsearch DSL queries.
+        Analyze the intent behind the query and determine the appropriate Elasticsearch operations required.
+        Guidelines for generating the right Elasticsearch query:
+        1. Automatically determine whether to return document hits or aggregation results based on the query structure.
+        2. Use keyword fields instead of text fields for aggregations and sorting to avoid fielddata errors.
+        3. Avoid using field.keyword if a keyword field is already present to prevent redundant queries.
+        4. Ensure efficient query execution by selecting appropriate query types for filtering, searching, and aggregating.
+        """
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", f"You are a helpful elasticsearch assistant. {prefix}"),
+                MessagesPlaceholder(variable_name="chat_history", optional=True),
+                ("human", "{input}"),
+                MessagesPlaceholder(variable_name="agent_scratchpad"),
+            ]
+        )
+
+        agent = create_tool_calling_agent(self.llm, self.tools, prompt)
+        agent_executor = AgentExecutor.from_agent_and_tools(
+            agent=agent, tools=self.tools, verbose=True, return_intermediate_steps=True
+        )
+        # Create the agent
+        return agent_executor
+
+    def analyse_query(self, state: GraphState) -> GraphState:
+        """
+        Analyzes the user's query to classify it as either grant- or organization-specific
+        and determines the next processing step.
+
+        Args:
+            state (GraphState): Current graph state containing the user's query.
+
+        Returns:
+            GraphState: Updated state with the classification result and the
+            next processing step in "next_step".
+        """
+
+        print("> analyse query")
+        prompt_template = """Your task is to analyze the query ```{query}``` and classify it in:
+        grant: Grant Index - A query where users seek information about grants, funding opportunities, and grantmakers. This includes inquiries about the purpose of funding, eligibility criteria, application processes, grant recipients, funding amounts, deadlines, and how grants can be used for specific projects or initiatives. Users may also request grants tailored to their unique needs, industries, or social impact goals
+
+        org: Org Index - Query which asks specific details about the organizations, their mission statement, where they are located
+        Output format:
+        {{"category": "<your_classification>"}}
+        """
+        parser = PydanticOutputParser(pydantic_object=AnalysisResult)
+
+        # Create the prompt
+        prompt = PromptTemplate(
+            template=prompt_template,
+            input_variables=["query"],
+            partial_variables={"format_instructions": parser.get_format_instructions()},
+        )
+        # Create the chain
+        chain = RunnableSequence(prompt, self.llm, parser)
+        # Invoke the chain with the query
+        response = chain.invoke({"query": state["query"]})
+        if response.category == "grant":
+            state["next_step"] = "grant-index"
+        else:
+            state["next_step"] = "org-index"
+        return state
+
+    def grant_index_agent(self, state: GraphState) -> GraphState:
+        print("> Grant Index Agent")
+        input_data = {
+            "input": f"""
+            You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
+
+            1. Understand the user query to determine the required information.
+            2. Query the indices in the Elasticsearch database.
+            3. Retrieve the mappings and field names relevant to the query.
+            4. Use the ``grants_qa_1`` index to extract the necessary data.
+            5. Ensure that you correctly identify the grantmaker (funder) or recipient (funded entity) if mentioned in the query.
+            Users may not always provide the exact name, so the Elasticsearch query should accommodate partial or incomplete names
+            by searching for relevant keywords.
+            6. Present the response in a clear and natural language format, addressing the user's question directly.
+
+
+            Description of some of the fields in the index (the remaining fields should be self-explanatory):
+            fiscal_year: Year when grantmaker allocates budget for funding and grants. format YYYY
+            text: Objectives, mission, program and funding related information
+            Program_area: program area where organization is working on
+            Title: the title of the funding
+            pcs_v3: PCS is taxonomy, describing the work of grantmakers, recipient organizations and the philanthropic transactions between those entities.
+            The facets of the PCS illuminate the work and answer the following questions about philanthropy:
+            Who? = Population Served
+            What? = Subject and Organization Type
+            How? = Support Strategy and Transaction Type
+            the Facets:
+            Subjects: Describes WHAT is being supported. Example: Elementary education or Clean water supply.
+            Populations: Describes WHO is being supported. Example: Girls or People with disabilities.
+            Organization Type: Describes WHAT type of organization is providing or receiving support.
+            Transaction Type: Describes HOW support is being provided.
+            Support Strategies: Describes HOW activities are being implemented.
+
+            pcs_v3 itself is in a json format:
+            key - subject
+            value: it is a list of dictionaries, so you might need to loop over it to find the particular aspect
+            hierarchy: (it is a list having subject name)
+            [
+                {{
+                    'name':
+                }},
+                {{
+                    'name':
+                }}
+            ]
+            Before writing the Elasticsearch query, think through which field to use
+
+            Note: first you should focus on querying `text`, then look into pcs_v3. Make sure you pick the right size for the query
+
+            User's query:
+            ```{state["query"]}```
+            """
+        }
+        agent_exec = self.agent_factory_claude()
+        res = agent_exec.invoke(input_data)
+        state["agent_out"] = res["output"]
+
+        es_queries, es_results = {}, {}
+        for i, action in enumerate(res.get("intermediate_steps", []), start=1):
+            if action[0].tool == "elastic_index_search_tool":
+                es_queries[f"query_{i}"] = json.loads(
+                    action[0].tool_input.get("query") or "{}"
+                )
+                es_results[f"query_{i}"] = ast.literal_eval(action[-1] or "{}")
+
+        state["es_query"] = es_queries
+        state["es_result"] = es_results
+        return state
+
+    def org_index_agent(self, state: GraphState) -> GraphState:
+        """
+        Executes a database query using an Elasticsearch agent and updates the graph state.
+
+        The agent queries indices and field names in the Elasticsearch database,
+        selects the appropriate index (`organization_qa_2`), and answers the user's question.
+
+        Args:
+            state (GraphState): Current graph state containing the user's query.
+
+        Returns:
+            GraphState: Updated state with the agent's output in "agent_out" and
+            the Elasticsearch query in "es_query".
+        """
+
+        print("> Org Index Agent")
+        input_data = {
+            "input": f"""
+            You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
+
+            1. Understand the user query to determine the required information.
+            2. Query the indices in the Elasticsearch database.
+            3. Retrieve the mappings and field names relevant to the query.
+            4. Use the `organization_qa_2` index to extract the necessary data.
+            5. Present the response in a clear and natural language format, addressing the user's question directly.
+
+            User's query:
+            ```{state["query"]}```
+            """
+        }
+        agent_exec = self.agent_factory_claude()
+        res = agent_exec.invoke(input_data)
+        state["agent_out"] = res["output"]
+
+        es_queries, es_results = {}, {}
+        for i, action in enumerate(res.get("intermediate_steps", []), start=1):
+            if action[0].tool == "elastic_index_search_tool":
+                es_queries[f"query_{i}"] = json.loads(
+                    action[0].tool_input.get("query") or "{}"
+                )
+                es_results[f"query_{i}"] = ast.literal_eval(action[-1] or "{}")
+
+        state["es_query"] = es_queries
+        state["es_result"] = es_results
+        return state
+
+    def final_answer(self, state: GraphState) -> GraphState:
+        """
+        Generates and presents the final response based on the user's query and the AI's output.
+
+        Args:
+            state (GraphState): Current graph state containing the query and AI output.
+
+        Returns:
+            GraphState: Updated state with the formatted final answer in "agent_out".
+        """
+
+        print("> Final Answer")
+        prompt_template = """
+        You are a chat agent that takes outputs generated by Elasticsearch and presents them in a conversational, natural language format, as if responding to a user's query.
+
+        Query: ```{query}```
+
+        AI Output:
+        ```{output}```
+        """
+        prompt = ChatPromptTemplate.from_template(prompt_template)
+        chain = RunnableSequence(prompt, self.llm)
+        response = chain.invoke({"query": state["query"], "output": state["agent_out"]})
+
+        return {"agent_out": response.content}
+
+    def construct_graph(self) -> StateGraph:
+        """
+        Constructs a compute graph for processing user queries using a defined workflow.
+
+        The workflow includes nodes for query analysis, handling grant- or organization-specific
+        queries, and generating the final response. Conditional logic determines the path based
+        on query type.
+
+        Returns:
+            StateGraph: Configured compute graph ready for execution.
+        """
+
+        # Add nodes
+        self.add_node("analyse", self.analyse_query)
+        self.add_node("grant-index", self.grant_index_agent)
+        self.add_node("org-index", self.org_index_agent)
+        self.add_node("final_answer", self.final_answer)
+
+        # Set entry point
+        self.set_entry_point("analyse")
+
+        # Add conditional edges
+        self.add_conditional_edges(
+            "analyse",
+            lambda x: x["next_step"],  # Use the return value of analyse_query directly
+            {"org-index": "org-index", "grant-index": "grant-index"},
+        )
+
+        # Add edges to end the workflow
+        self.add_edge("org-index", "final_answer")
+        self.add_edge("grant-index", "final_answer")
+        self.add_edge("final_answer", END)
+
+
+def build_elastic_graph(llm: LLM, tools: List[Tool]):
+    """Compile Elastic Agent Graph"""
+    elastic_graph = ElasticGraph(llm=llm, tools=tools)
+    graph = elastic_graph.compile()
+    return graph
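A usage sketch for the new `ElasticGraph` wrapper. The Anthropic model and import are assumptions (the commit only requires a LangChain chat model that supports tool calling); `tools` is the module-level list defined at the top of the file:

```python
from langchain_anthropic import ChatAnthropic  # assumed; any tool-calling chat model works

from ask_candid.agents.elastic import build_elastic_graph, tools

llm = ChatAnthropic(model="claude-3-5-sonnet-20241022", temperature=0)
graph = build_elastic_graph(llm=llm, tools=tools)

# analyse_query routes grant questions to grants_qa_1 and organization questions
# to organization_qa_2, then final_answer rephrases the raw Elasticsearch output.
result = graph.invoke({"query": "Which funders support clean water projects?"})
print(result["agent_out"])
```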
ask_candid/base/config/connections.py
CHANGED
@@ -32,3 +32,14 @@ SEMANTIC_ELASTIC_QA = BaseElasticAPIKeyCredential(
     cloud_id=_load_value("SEMANTIC_ELASTIC_CLOUD_ID"),
     api_key=_load_value("SEMANTIC_ELASTIC_API_KEY"),
 )
+
+SEMANTIC_ELASTIC_QA_WRITER = BaseElasticAPIKeyCredential(
+    cloud_id=_load_value("SEMANTIC_ELASTIC_WRITER_CLOUD_ID"),
+    api_key=_load_value("SEMANTIC_ELASTIC_WRITER_API_KEY"),
+)
+
+NEWS_ELASTIC = BaseElasticSearchConnection(
+    url=_load_value("NEWS_URL"),
+    username=_load_value("NEWS_UID"),
+    password=_load_value("NEWS_PWD")
+)
ask_candid/base/config/constants.py
CHANGED
@@ -1,4 +1,4 @@
 START_SYSTEM_PROMPT = (
     "You are a Candid subject matter expert on the social sector and philanthropy. "
     "You should address the user's queries and stay on topic."
-)
+)
ask_candid/base/config/data.py
CHANGED
@@ -16,5 +16,6 @@ ALL_INDICES = (
     "youtube",
     "candid_blog",
     "candid_learning",
-    "candid_help"
-)
+    "candid_help",
+    "news"
+)
ask_candid/chat.py
CHANGED
@@ -15,8 +15,10 @@ def run_chat(
     history: List[Dict],
     llm: LLM,
     indices: Optional[List[str]] = None,
-
+    premium_features: Optional[List[str]] = None,
 ) -> Tuple[gr.MultimodalTextbox, List[Dict[str, Any]], str]:
+    if premium_features is None:
+        premium_features = []
     if len(history) == 0:
         history.append({"role": "system", "content": START_SYSTEM_PROMPT})
 
@@ -26,6 +28,7 @@ def run_chat(
     thread_id = get_session_id(thread_id)
     config = {"configurable": {"thread_id": thread_id}}
 
+    enable_recommendations = "Recommendation" in premium_features
     workflow = build_compute_graph(llm=llm, indices=indices, enable_recommendations=enable_recommendations)
 
     memory = MemorySaver()  # TODO: don't use for Prod
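`run_chat` now takes a list of premium feature flags instead of a boolean; recommendations are on only when the literal flag "Recommendation" is present. The gating logic in isolation:

```python
from typing import List, Optional

def recommendations_enabled(premium_features: Optional[List[str]] = None) -> bool:
    # Mirrors run_chat: None is normalized to an empty list, and the
    # feature is keyed by the exact string "Recommendation".
    if premium_features is None:
        premium_features = []
    return "Recommendation" in premium_features

assert recommendations_enabled(["Recommendation"])
assert not recommendations_enabled()
```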
ask_candid/graph.py
CHANGED
@@ -80,6 +80,7 @@ def generate_with_context(state: AgentState, llm: LLM) -> AgentState:
 
 def add_recommendations_pipeline_(
     G: StateGraph,
+    llm: LLM,
     reformulation_node_name: str = "reformulate",
     search_node_name: str = "search_agent"
 ) -> None:
@@ -96,7 +97,7 @@ def add_recommendations_pipeline_(
     """
 
     # Nodes for recommendation functionalities
-    G.add_node("detect_intent_with_llm", detect_intent_with_llm)
+    G.add_node("detect_intent_with_llm", partial(detect_intent_with_llm, llm=llm))
     G.add_node("determine_context", determine_context)
     G.add_node("make_recommendation", make_recommendation)
 
@@ -142,7 +143,7 @@ def build_compute_graph(
 
     G = StateGraph(AgentState)
 
-    G.add_node("reformulate", partial(reformulate_question_using_history, llm=llm))
+    G.add_node("reformulate", partial(reformulate_question_using_history, llm=llm, focus_on_recommendations=enable_recommendations))
     G.add_node("search_agent", partial(search_agent, llm=llm, tools=tools))
     G.add_node("retrieve", retrieve)
     G.add_node("generate_with_context", partial(generate_with_context, llm=llm))
@@ -150,7 +151,7 @@ def build_compute_graph(
     G.add_node("insert_org_link", insert_org_link)
 
     if enable_recommendations:
-        add_recommendations_pipeline_(G, reformulation_node_name="reformulate", search_node_name="search_agent")
+        add_recommendations_pipeline_(G, llm=llm, reformulation_node_name="reformulate", search_node_name="search_agent")
     else:
         G.add_edge("reformulate", "search_agent")
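The recommendation nodes need the LLM at call time, but LangGraph invokes a node with the state only, so the change binds `llm` up front with `functools.partial`. A reduced sketch of the pattern (the node body is hypothetical):

```python
from functools import partial

def detect_intent_with_llm(state, llm):
    # Node body elided; what matters is the two-argument signature.
    return state

# The graph calls node_fn(state); llm is already bound.
node_fn = partial(detect_intent_with_llm, llm="<any LangChain LLM>")
print(node_fn({"user_input": "recommend funders"}))
```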
ask_candid/retrieval/elastic.py
CHANGED
@@ -1,500 +1,566 @@
|
|
1 |
-
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
|
2 |
-
from dataclasses import dataclass
|
3 |
-
from functools import partial
|
4 |
-
from itertools import groupby
|
5 |
-
|
6 |
-
from torch.nn import functional as F
|
7 |
-
|
8 |
-
from pydantic import BaseModel, Field
|
9 |
-
from langchain_core.documents import Document
|
10 |
-
from langchain_core.tools import Tool
|
11 |
-
|
12 |
-
from elasticsearch import Elasticsearch
|
13 |
-
|
14 |
-
from ask_candid.
|
15 |
-
from ask_candid.
|
16 |
-
from ask_candid.base.config.
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
"
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
"""
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
q = build_sparse_vector_query(
|
135 |
-
query=query,
|
136 |
-
fields=("content", "
|
137 |
-
)
|
138 |
-
q["_source"] = {"excludes": ["embeddings"]}
|
139 |
-
q["size"] =
|
140 |
-
queries.extend([{"index": ElasticIndexMapping.
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
return
|
176 |
-
|
177 |
-
|
178 |
-
def
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
for
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
#
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
"
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from functools import partial
|
4 |
+
from itertools import groupby
|
5 |
+
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from pydantic import BaseModel, Field
|
9 |
+
from langchain_core.documents import Document
|
10 |
+
from langchain_core.tools import Tool
|
11 |
+
|
12 |
+
from elasticsearch import Elasticsearch
|
13 |
+
|
14 |
+
from ask_candid.retrieval.sparse_lexical import SpladeEncoder
|
15 |
+
from ask_candid.services.small_lm import CandidSLM
|
16 |
+
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
|
17 |
+
from ask_candid.base.config.data import ElasticIndexMapping, ALL_INDICES
|
18 |
+
|
19 |
+
encoder = SpladeEncoder()
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class ElasticHitsResult:
|
24 |
+
"""Dataclass for Elasticsearch hits results
|
25 |
+
"""
|
26 |
+
index: str
|
27 |
+
id: Any
|
28 |
+
score: float
|
29 |
+
source: Dict[str, Any]
|
30 |
+
inner_hits: Dict[str, Any]
|
31 |
+
|
32 |
+
|
33 |
+
class RetrieverInput(BaseModel):
|
34 |
+
"""Input to the Elasticsearch retriever."""
|
35 |
+
user_input: str = Field(description="query to look up in retriever")
|
36 |
+
|
37 |
+
|
38 |
+
def build_sparse_vector_query(
|
39 |
+
query: str,
|
40 |
+
fields: Tuple[str],
|
41 |
+
inference_id: str = ".elser-2-elasticsearch"
|
42 |
+
) -> Dict[str, Any]:
|
43 |
+
"""Builds a valid Elasticsearch text expansion query payload
|
44 |
+
|
45 |
+
Parameters
|
46 |
+
----------
|
47 |
+
query : str
|
48 |
+
Search context string
|
49 |
+
fields : Tuple[str]
|
50 |
+
Semantic text field names
|
51 |
+
inference_id : str, optional
|
52 |
+
ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"
|
53 |
+
|
54 |
+
Returns
|
55 |
+
-------
|
56 |
+
Dict[str, Any]
|
57 |
+
"""
|
58 |
+
|
59 |
+
output = []
|
60 |
+
|
61 |
+
for f in fields:
|
62 |
+
output.append({
|
63 |
+
"nested": {
|
64 |
+
"path": f"embeddings.{f}.chunks",
|
65 |
+
"query": {
|
66 |
+
"sparse_vector": {
|
67 |
+
"field": f"embeddings.{f}.chunks.vector",
|
68 |
+
"inference_id": inference_id,
|
69 |
+
"prune": True,
|
70 |
+
"query": query,
|
71 |
+
"boost": 1 / len(fields)
|
72 |
+
}
|
73 |
+
},
|
74 |
+
"inner_hits": {
|
75 |
+
"_source": False,
|
76 |
+
"size": 2,
|
77 |
+
"fields": [f"embeddings.{f}.chunks.chunk"]
|
78 |
+
}
|
79 |
+
}
|
80 |
+
})
|
81 |
+
return {"query": {"bool": {"should": output}}}
|
82 |
+
|
83 |
+
|
84 |
+
def news_query_builder(query: str) -> Dict[str, Any]:
|
85 |
+
tokens = encoder.token_expand(query)
|
86 |
+
|
87 |
+
query = {
|
88 |
+
"_source": ["id", "link", "title", "content"],
|
89 |
+
"query": {
|
90 |
+
"bool": {
|
91 |
+
"filter": [
|
92 |
+
{"range": {"event_date": {"gte": "now-60d/d"}}},
|
93 |
+
{"range": {"insert_date": {"gte": "now-60d/d"}}},
|
94 |
+
{"range": {"article_trust_worthiness": {"gt": 0.8}}}
|
95 |
+
],
|
96 |
+
"should": []
|
97 |
+
}
|
98 |
+
}
|
99 |
+
}
|
100 |
+
|
101 |
+
for token, score in tokens.items():
|
102 |
+
if score > 0.4:
|
103 |
+
query["query"]["bool"]["should"].append({
|
104 |
+
"multi_match": {
|
105 |
+
"query": token,
|
106 |
+
"fields": ["title", "content"],
|
107 |
+
"boost": score
|
108 |
+
}
|
109 |
+
})
|
110 |
+
return query
|
111 |
+
|
112 |
+
|
113 |
+
def query_builder(query: str, indices: List[str]) -> List[Dict[str, Any]]:
|
114 |
+
"""Builds Elasticsearch multi-search query payload
|
115 |
+
|
116 |
+
Parameters
|
117 |
+
----------
|
118 |
+
query : str
|
119 |
+
Search context string
|
120 |
+
indices : List[str]
|
121 |
+
Semantic index names to search over
|
122 |
+
|
123 |
+
Returns
|
124 |
+
-------
|
125 |
+
List[Dict[str, Any]]
|
126 |
+
"""
|
127 |
+
|
128 |
+
queries = []
|
129 |
+
if indices is None:
|
130 |
+
indices = list(ALL_INDICES)
|
131 |
+
|
132 |
+
for index in indices:
|
133 |
+
if index == "issuelab":
|
134 |
+
q = build_sparse_vector_query(
|
135 |
+
query=query,
|
136 |
+
fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
|
137 |
+
)
|
138 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
139 |
+
q["size"] = 1
|
140 |
+
queries.extend([{"index": ElasticIndexMapping.ISSUELAB_INDEX_ELSER}, q])
|
141 |
+
elif index == "youtube":
|
142 |
+
q = build_sparse_vector_query(
|
143 |
+
query=query,
|
144 |
+
fields=("captions_cleaned", "description_cleaned", "title")
|
145 |
+
)
|
146 |
+
# text_cleaned duplicates captions_cleaned
|
147 |
+
q["_source"] = {"excludes": ["embeddings", "captions", "description", "text_cleaned"]}
|
148 |
+
q["size"] = 2
|
149 |
+
queries.extend([{"index": ElasticIndexMapping.YOUTUBE_INDEX_ELSER}, q])
|
150 |
+
elif index == "candid_blog":
|
151 |
+
q = build_sparse_vector_query(
|
152 |
+
query=query,
|
153 |
+
fields=("content", "authors_text", "title_summary_tags")
|
154 |
+
)
|
155 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
156 |
+
q["size"] = 2
|
157 |
+
queries.extend([{"index": ElasticIndexMapping.CANDID_BLOG_INDEX_ELSER}, q])
|
158 |
+
elif index == "candid_learning":
|
159 |
+
q = build_sparse_vector_query(
|
160 |
+
query=query,
|
161 |
+
fields=("content", "title", "training_topics", "staff_recommendations")
|
162 |
+
)
|
163 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
164 |
+
q["size"] = 2
|
165 |
+
queries.extend([{"index": ElasticIndexMapping.CANDID_LEARNING_INDEX_ELSER}, q])
|
166 |
+
elif index == "candid_help":
|
167 |
+
q = build_sparse_vector_query(
|
168 |
+
query=query,
|
169 |
+
fields=("content", "combined_article_description")
|
170 |
+
)
|
171 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
172 |
+
q["size"] = 2
|
173 |
+
queries.extend([{"index": ElasticIndexMapping.CANDID_HELP_INDEX_ELSER}, q])
|
174 |
+
|
175 |
+
return queries
|
176 |
+
|
177 |
+
|
178 |
+
def multi_search(
|
179 |
+
queries: List[Dict[str, Any]],
|
180 |
+
news_query: Optional[Dict[str, Any]] = None
|
181 |
+
) -> List[ElasticHitsResult]:
|
182 |
+
"""Runs multi-search query
|
183 |
+
|
184 |
+
Parameters
|
185 |
+
----------
|
186 |
+
queries : List[Dict[str, Any]]
|
187 |
+
Pre-built multi-search query payload
|
188 |
+
|
189 |
+
Returns
|
190 |
+
-------
|
191 |
+
List[ElasticHitsResult]
|
192 |
+
"""
|
193 |
+
|
194 |
+
results = []
|
195 |
+
|
196 |
+
if len(queries) > 0:
|
197 |
+
with Elasticsearch(
|
198 |
+
cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
|
199 |
+
api_key=SEMANTIC_ELASTIC_QA.api_key,
|
200 |
+
verify_certs=False,
|
201 |
+
request_timeout=60 * 3
|
202 |
+
) as es:
|
203 |
+
for query_group in es.msearch(body=queries).get("responses", []):
|
204 |
+
for hit in query_group.get("hits", {}).get("hits", []):
|
205 |
+
hit = ElasticHitsResult(
|
206 |
+
index=hit["_index"],
|
207 |
+
id=hit["_id"],
|
208 |
+
score=hit["_score"],
|
209 |
+
source=hit["_source"],
|
210 |
+
inner_hits=hit.get("inner_hits", {})
|
211 |
+
)
|
212 |
+
results.append(hit)
|
213 |
+
|
214 |
+
if news_query is not None:
|
215 |
+
with Elasticsearch(
|
216 |
+
NEWS_ELASTIC.url,
|
217 |
+
http_auth=(NEWS_ELASTIC.username, NEWS_ELASTIC.password),
|
218 |
+
timeout=60
|
219 |
+
) as es:
|
220 |
+
for hit in es.search(body=news_query, index="news_1").get("hits", {}).get("hits") or []:
|
221 |
+
hit = ElasticHitsResult(
|
222 |
+
index=hit["_index"],
|
223 |
+
id=hit["_id"],
|
224 |
+
score=hit["_score"],
|
225 |
+
source=hit["_source"],
|
226 |
+
inner_hits=hit.get("inner_hits", {})
|
227 |
+
)
|
228 |
+
results.append(hit)
|
229 |
+
return results
|
230 |
+
|
231 |
+
|
232 |
+
def get_query_results(search_text: str, indices: Optional[List[str]] = None) -> List[ElasticHitsResult]:
|
233 |
+
"""Builds and executes Elasticsearch data queries from a search string.
|
234 |
+
|
235 |
+
Parameters
|
236 |
+
----------
|
237 |
+
search_text : str
|
238 |
+
Search context string
|
239 |
+
indices : Optional[List[str]], optional
|
240 |
+
Semantic index names to search over, by default None
|
241 |
+
|
242 |
+
Returns
|
243 |
+
-------
|
244 |
+
List[ElasticHitsResult]
|
245 |
+
"""
|
246 |
+
|
247 |
+
queries = query_builder(query=search_text, indices=indices)
|
248 |
+
news_q = news_query_builder(query=search_text)
|
249 |
+
return multi_search(queries, news_query=news_q)
|
250 |
+
|
251 |
+
|
252 |
+
def retrieved_text(hits: Dict[str, Any]) -> str:
|
253 |
+
"""Extracts retrieved sub-texts from documents which are strong hits from semantic queries for the purpose of
|
254 |
+
re-scoring by a secondary language model.
|
255 |
+
|
256 |
+
Parameters
|
257 |
+
----------
|
258 |
+
hits : Dict[str, Any]
|
259 |
+
|
260 |
+
Returns
|
261 |
+
-------
|
262 |
+
str
|
263 |
+
"""
|
264 |
+
|
265 |
+
text = []
|
266 |
+
for _, v in hits.items():
|
267 |
+
for h in (v.get("hits", {}).get("hits") or []):
|
268 |
+
for _, field in h.get("fields", {}).items():
|
269 |
+
for chunk in field:
|
270 |
+
if chunk.get("chunk"):
|
271 |
+
text.extend(chunk["chunk"])
|
272 |
+
return '\n'.join(text)
|
273 |
+
|
274 |
+
|
275 |
+
def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
|
276 |
+
"""Computes cosine scores between retrieved contexts and the original query to re-score results based on overall
|
277 |
+
relevance to the original query.
|
278 |
+
|
279 |
+
Parameters
|
280 |
+
----------
|
281 |
+
query : str
|
282 |
+
Search context string
|
283 |
+
contexts : List[str]
|
284 |
+
Semantic field sub-texts, order is by document retrieved from the original multi-search query.
|
285 |
+
|
286 |
+
Returns
|
287 |
+
-------
|
288 |
+
List[float]
|
289 |
+
Scores in the same order as the input document contexts
|
290 |
+
"""
|
291 |
+
|
292 |
+
nlp = CandidSLM()
|
293 |
+
X = nlp.encode([query, *contexts]).vectors
|
294 |
+
X = F.normalize(X, dim=-1, p=2.)
|
295 |
+
cosine = X[1:] @ X[:1].T
|
296 |
+
return cosine.flatten().cpu().numpy().tolist()
|
297 |
+
|
298 |
+
|
299 |
+
def reranker(
|
300 |
+
query_results: Iterable[ElasticHitsResult],
|
301 |
+
search_text: Optional[str] = None
|
302 |
+
) -> Iterator[ElasticHitsResult]:
|
303 |
+
"""Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
|
304 |
+
This will shuffle results
|
305 |
+
|
306 |
+
Parameters
|
307 |
+
----------
|
308 |
+
query_results : Iterable[ElasticHitsResult]
|
309 |
+
|
310 |
+
Yields
|
311 |
+
------
|
312 |
+
Iterator[ElasticHitsResult]
|
313 |
+
"""
|
314 |
+
|
315 |
+
results: List[ElasticHitsResult] = []
|
316 |
+
texts: List[str] = []
|
317 |
+
for _, data in groupby(query_results, key=lambda x: x.index):
|
318 |
+
data = list(data)
|
319 |
+
max_score = max(data, key=lambda x: x.score).score
|
320 |
+
min_score = min(data, key=lambda x: x.score).score
|
321 |
+
|
322 |
+
for d in data:
|
323 |
+
d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
|
324 |
+
results.append(d)
|
325 |
+
|
326 |
+
if search_text:
|
327 |
+
text = retrieved_text(d.inner_hits)
|
328 |
+
texts.append(text)
|
329 |
+
|
330 |
+
# if search_text and len(texts) == len(results):
|
331 |
+
# scores = cosine_rescore(search_text, texts)
|
332 |
+
# for r, s in zip(results, scores):
|
333 |
+
# r.score = s
|
334 |
+
|
335 |
+
yield from sorted(results, key=lambda x: x.score, reverse=True)
|
336 |
+
|
337 |
+
|
338 |
+
def get_results(user_input: str, indices: List[str]) -> Tuple[str, List[Document]]:
|
339 |
+
"""End-to-end search and re-rank function.
|
340 |
+
|
341 |
+
Parameters
|
342 |
+
----------
|
343 |
+
user_input : str
|
344 |
+
Search context string
|
345 |
+
indices : List[str]
|
346 |
+
Semantic index names to search over
|
347 |
+
|
348 |
+
Returns
|
349 |
+
-------
|
350 |
+
Tuple[str, List[Document]]
|
351 |
+
(concatenated text from search results, documents list)
|
352 |
+
"""
|
353 |
+
|
354 |
+
output = ["Search didn't return any Candid sources"]
|
355 |
+
page_content = []
|
356 |
+
content = "Search didn't return any Candid sources"
|
357 |
+
results = get_query_results(search_text=user_input, indices=indices)
|
358 |
+
if results:
|
359 |
+
output = get_reranked_results(results, search_text=user_input)
|
360 |
+
for doc in output:
|
361 |
+
page_content.append(doc.page_content)
|
362 |
+
content = "\n\n".join(page_content)
|
363 |
+
|
364 |
+
# for the tool we need to return a tuple for content_and_artifact type
|
365 |
+
return content, output
|
366 |
+
|
367 |
+
|
368 |
+
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
|
369 |
+
"""Pads the relevant chunk of text with context before and after
|
370 |
+
|
371 |
+
Parameters
|
372 |
+
----------
|
373 |
+
field_name : str
|
374 |
+
a field with the long text that was chunked into pieces
|
375 |
+
hit : ElasticHitsResult
|
376 |
+
context_length : int, optional
|
377 |
+
length of text to add before and after the chunk, by default 1024
|
378 |
+
|
379 |
+
Returns
|
380 |
+
-------
|
381 |
+
str
|
382 |
+
longer chunks stuffed together
|
383 |
+
"""
|
384 |
+
|
385 |
+
chunks = []
|
386 |
+
# NOTE chunks have tokens, long text is a normal text, but may contain html that also gets weird after tokenization
|
387 |
+
long_text = hit.source.get(f"{field_name}", "")
|
388 |
+
long_text = long_text.lower()
|
389 |
+
inner_hits_field = f"embeddings.{field_name}.chunks"
|
390 |
+
found_chunks = hit.inner_hits.get(inner_hits_field, {})
|
391 |
+
if found_chunks:
|
392 |
+
hits = found_chunks.get("hits", {}).get("hits", [])
|
393 |
+
for h in hits:
|
394 |
+
chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
|
395 |
+
|
396 |
+
# cutting the middle because we may have tokenizing artifacts there
|
397 |
+
chunk = chunk[3: -3]
|
398 |
+
|
399 |
+
if add_context:
|
400 |
+
# Find the start and end indices of the chunk in the large text
|
401 |
+
start_index = long_text.find(chunk[:20])
|
402 |
+
|
403 |
+
# Chunk is found
|
404 |
+
if start_index != -1:
|
405 |
+
end_index = start_index + len(chunk)
|
406 |
+
pre_start_index = max(0, start_index - context_length)
|
407 |
+
post_end_index = min(len(long_text), end_index + context_length)
|
408 |
+
chunks.append(long_text[pre_start_index:post_end_index])
|
409 |
+
else:
|
410 |
+
chunks.append(chunk)
|
411 |
+
return '\n\n'.join(chunks)
|
412 |
+
|
413 |
+
|
414 |
+
def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
|
415 |
+
"""Parse Elasticsearch hit results into data structures handled by the RAG pipeline.
|
416 |
+
|
+    Parameters
+    ----------
+    hit : ElasticHitsResult
+
+    Returns
+    -------
+    Union[Document, None]
+    """
+
+    if "issuelab-elser" in hit.index:
+        combined_item_description = hit.source.get("combined_item_description", "")  # title inside
+        description = hit.source.get("description", "")
+        combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
+        # we only need to process long texts
+        chunks_with_context_txt = get_context("content", hit, context_length=12)
+        doc = Document(
+            page_content='\n\n'.join([
+                combined_item_description,
+                combined_issuelab_findings,
+                description,
+                chunks_with_context_txt
+            ]),
+            metadata={
+                "title": hit.source["title"],
+                "source": "IssueLab",
+                "source_id": hit.source["resource_id"],
+                "url": hit.source.get("permalink", "")
+            }
+        )
+    elif "youtube" in hit.index:
+        title = hit.source.get("title", "")
+        # we only need to process long texts
+        description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
+        captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
+        doc = Document(
+            page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
+            metadata={
+                "title": title,
+                "source": "Candid YouTube",
+                "source_id": hit.source['video_id'],
+                "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
+            }
+        )
+    elif "candid-blog" in hit.index:
+        excerpt = hit.source.get("excerpt", "")
+        title = hit.source.get("title", "")
+        # we only need to process long text
+        content_with_context_txt = get_context("content", hit, context_length=12, add_context=False)
+        authors = get_context("authors_text", hit, context_length=12, add_context=False)
+        tags = hit.source.get("title_summary_tags", "")
+        doc = Document(
+            page_content='\n\n'.join([title, excerpt, content_with_context_txt, authors, tags]),
+            metadata={
+                "title": title,
+                "source": "Candid Blog",
+                "source_id": hit.source["id"],
+                "url": hit.source["link"]
+            }
+        )
+    elif "candid-learning" in hit.index:
+        title = hit.source.get("title", "")
+        content_with_context_txt = get_context("content", hit, context_length=12)
+        training_topics = hit.source.get("training_topics", "")
+        staff_recommendations = hit.source.get("staff_recommendations", "")
+
+        doc = Document(
+            page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
+            metadata={
+                "title": hit.source["title"],
+                "source": "Candid Learning",
+                "source_id": hit.source["post_id"],
+                "url": hit.source.get("url", "")
+            }
+        )
+    elif "candid-help" in hit.index:
+        title = hit.source.get("title", "")
+        content_with_context_txt = get_context("content", hit, context_length=12)
+        combined_article_description = hit.source.get("combined_article_description", "")
+
+        doc = Document(
+            page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
+            metadata={
+                "title": title,
+                "source": "Candid Help",
+                "source_id": hit.source["id"],
+                "url": hit.source.get("link", "")
+            }
+        )
+    elif "news" in hit.index:
+        doc = Document(
+            page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
+            metadata={
+                "title": hit.source.get("title", ""),
+                "source": "Candid News",
+                "source_id": hit.source["id"],
+                "url": hit.source.get("link", "")
+            }
+        )
+    else:
+        doc = None
+    return doc
+
+
+def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional[str] = None) -> List[Document]:
+    """Run data re-ranking and document building for tool usage.
+
+    Parameters
+    ----------
+    results : List[ElasticHitsResult]
+    search_text : Optional[str], optional
+        Search context string, by default None
+
+    Returns
+    -------
+    List[Document]
+    """
+
+    output = []
+    for r in reranker(results, search_text=search_text):
+        hit = process_hit(r)
+        if hit is not None:
+            output.append(hit)
+    return output
+
+
+def retriever_tool(indices: List[str]) -> Tool:
+    """Tool component for use in conditional edge building for RAG execution graph.
+    Cannot use `create_retriever_tool` because it only provides content, losing all metadata on the way
+    https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution
+
+    Parameters
+    ----------
+    indices : List[str]
+        Semantic index names to search over
+
+    Returns
+    -------
+    Tool
+    """
+
+    return Tool(
+        name="retrieve_social_sector_information",
+        func=partial(get_results, indices=indices),
+        description=(
+            "Return additional information about social and philanthropic sector, "
+            "including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
+        ),
+        args_schema=RetrieverInput,
+        response_format="content_and_artifact"
+    )
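Because the tool above is declared with response_format="content_and_artifact", LangChain expects the wrapped function to return a (content, artifact) pair, which is what lets downstream nodes keep the full Document metadata. A minimal usage sketch, assuming get_results follows that contract; the query text and index names are illustrative, not part of the commit:

    # Hypothetical usage of retriever_tool; a sketch, not part of this diff.
    tool = retriever_tool(indices=["issuelab-elser", "candid-blog"])
    content, documents = tool.func("capacity building grants for rural nonprofits")
    for doc in documents:
        # Metadata survives because the artifact carries the Document objects.
        print(doc.metadata["source"], "->", doc.metadata["url"])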
ask_candid/retrieval/sparse_lexical.py
ADDED
@@ -0,0 +1,29 @@
+from typing import Dict
+
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+import torch
+
+
+class SpladeEncoder:
+
+    def __init__(self):
+        model_id = "naver/splade-v3"
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.model = AutoModelForMaskedLM.from_pretrained(model_id)
+        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}
+
+    @torch.no_grad()
+    def token_expand(self, query: str) -> Dict[str, float]:
+        tokens = self.tokenizer(query, return_tensors='pt')
+        output = self.model(**tokens)
+
+        vec = torch.max(
+            torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
+            dim=1
+        )[0].squeeze()
+        cols = vec.nonzero().squeeze().cpu().tolist()
+        weights = vec[cols].cpu().tolist()
+
+        sparse_dict_tokens = {self.idx2token[idx]: round(weight, 3) for idx, weight in zip(cols, weights) if weight > 0}
+        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
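token_expand maps a query onto weighted vocabulary terms (the SPLADE expansion), which is what enables the "quasi-semantic" lexical search named in the commit message. A sketch of one way the expansion could be consumed; the index and field names are assumptions, not part of this diff:

    # Hypothetical: turn the sparse expansion into a boosted lexical query.
    encoder = SpladeEncoder()
    expansion = encoder.token_expand("funding for rural libraries")

    # Each expanded term becomes a match clause weighted by its SPLADE score.
    body = {
        "query": {
            "bool": {
                "should": [
                    {"match": {"content": {"query": term, "boost": weight}}}
                    for term, weight in expansion.items()
                ]
            }
        }
    }
    # es.search(index="news", body=body)  # assumed news index and "content" field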
ask_candid/tools/elastic/index_details_tool.py
CHANGED
@@ -19,7 +19,7 @@ es = Elasticsearch(
     cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
     api_key=SEMANTIC_ELASTIC_QA.api_key,
     verify_certs=True,
-    request_timeout=60 * 3
+    request_timeout=60 * 3,
 )
 
 
@@ -62,7 +62,9 @@ class IndexDetailsTool(BaseTool):
                 }
             )
         except Exception as e:
-            logger.exception(
+            logger.exception(
+                "Could not fetch index information for %s: %s", index_name, e
+            )
             return ""
 
     async def _arun(
ask_candid/tools/elastic/index_search_tool.py
CHANGED
@@ -3,6 +3,7 @@ import json
 
 import tiktoken
 from elasticsearch import Elasticsearch
+
 # from pydantic.v1 import BaseModel, Field  # <-- Uses v1 namespace
 from pydantic import BaseModel, Field
 from langchain.tools import StructuredTool
@@ -15,7 +16,7 @@ es = Elasticsearch(
     cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
     api_key=SEMANTIC_ELASTIC_QA.api_key,
     verify_certs=True,
-    request_timeout=60 * 3
+    request_timeout=60 * 3,
 )
 
 
@@ -81,8 +82,18 @@ def elastic_search(
     if query_dict is None and aggs_dict is not None:
         # When a result has aggregations, just return that and ignore the rest
        final_res = str(res["aggregations"])
+    elif query_dict is not None and aggs_dict is not None:
+        # Return both hits and aggregations
+        final_res = str(
+            {
+                "hits": res.get("hits", {}),
+                "aggregations": res.get("aggregations", {}),
+            }
+        )
+
     else:
         final_res = str(res["hits"])
+
     tokens = encoding.encode(final_res)
     retries += 1
     if len(tokens) > 6000:
@@ -98,5 +109,16 @@ def elastic_search(
 
 def create_search_tool():
     return StructuredTool.from_function(
-        elastic_search,
+        elastic_search,
+        name="elastic_index_search_tool",
+        description=(
+            """This tool allows executing queries on an Elasticsearch index efficiently. Provide:
+            1. index_name (string): The target Elasticsearch index.
+            2. query (dictionary): Defines the query structure, supporting:
+                a. Filters: For precise data retrieval (e.g., match, term, range).
+                b. Aggregations: For statistical summaries and grouping (e.g., sum, average, histogram).
+                c. Full-text search: For analyzing and ranking text-based results (e.g., match, multi-match, query_string).
+            """
+        ),
+        args_schema=SearchToolInput,
     )
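With the new elif branch, a single call can now return hits and aggregations together instead of discarding one of them. A sketch of a request that exercises it, assuming the tool splits the query dictionary into its query and aggs parts; the index and field names are illustrative:

    # Hypothetical call combining a filter with an aggregation.
    result = elastic_search(
        index_name="organizations",
        query={
            "query": {"match": {"state": "NY"}},
            "aggs": {"by_city": {"terms": {"field": "city.keyword"}}}
        }
    )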
ask_candid/tools/org_seach.py
CHANGED
@@ -2,7 +2,7 @@ from typing import List
 import logging
 import re
 
-from
+from thefuzz import fuzz
 
 from langchain.output_parsers.openai_tools import JsonOutputToolsParser
 # from langchain_openai.chat_models import ChatOpenAI
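For context, thefuzz (the maintained continuation of fuzzywuzzy) provides Levenshtein-based similarity scores in the 0 to 100 range, which suits matching user-typed organization names against canonical ones. A minimal illustration of the API this import brings in:

    from thefuzz import fuzz

    # Token-order-insensitive comparison; returns an integer score, 0-100.
    score = fuzz.token_sort_ratio("Gates Foundation", "foundation gates")
    print(score)  # 100, since the sorted token sets are identical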
ask_candid/tools/question_reformulation.py
CHANGED
@@ -2,13 +2,17 @@ from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 
 
-def reformulate_question_using_history(state, llm):
+def reformulate_question_using_history(state, llm, focus_on_recommendations=False):
     """
-    Transform the query to produce a better query with details from previous messages
+    Transform the query to produce a better query with details from previous messages and emphasize
+    aspects important for recommendations if needed.
 
     Args:
-        state (
-        llm: LLM to use
+        state (dict): The current state containing messages.
+        llm: LLM to use for generating the reformulation.
+        focus_on_recommendations (bool): Flag to determine if the reformulation should emphasize
+                                         recommendation-relevant aspects such as geographies,
+                                         cause areas, etc.
     Returns:
         dict: The updated state with re-phrased question and original user_input for UI
     """
@@ -17,23 +21,39 @@ def reformulate_question_using_history(state, llm):
     question = messages[-1].content
 
     if len(messages) > 1:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if focus_on_recommendations:
+            prompt_text = """Given a chat history and the latest user input \
+which might reference context in the chat history, \
+especially geographic locations, cause areas and/or population groups, \
+formulate a standalone input which can be understood without the chat history.
+Chat history:
+\n ------- \n
+{chat_history}
+\n ------- \n
+User input:
+\n ------- \n
+{question}
+\n ------- \n
+Reformulate the question without adding implications or assumptions about the user's needs or intentions.
+Focus solely on clarifying any contextual details present in the original input."""
+        else:
+            prompt_text = """Given a chat history and the latest user input \
+which might reference context in the chat history, formulate a standalone input \
+which can be understood without the chat history.
+Chat history:
+\n ------- \n
+{chat_history}
+\n ------- \n
+User input:
+\n ------- \n
+{question}
+\n ------- \n
+Do NOT answer the question, \
+just reformulate it if needed and otherwise return it as is.
+"""
 
     contextualize_q_prompt = ChatPromptTemplate([
-        ("system",
+        ("system", prompt_text),
         ("human", question),
     ])
 
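A usage sketch for the new flag; the messages and the model choice are illustrative, and any LangChain chat model can serve as llm:

    # Hypothetical second turn whose "there" should resolve to New Mexico.
    from langchain_core.messages import HumanMessage, AIMessage
    from langchain_openai import ChatOpenAI

    state = {"messages": [
        HumanMessage("We run food banks in New Mexico."),
        AIMessage("Great - what would you like to find?"),
        HumanMessage("Find funders there."),
    ]}
    # With focus_on_recommendations=True the prompt preserves geography,
    # cause-area, and population details during reformulation.
    state = reformulate_question_using_history(state, ChatOpenAI(), focus_on_recommendations=True)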
ask_candid/tools/recommendation.py
CHANGED
@@ -1,87 +1,155 @@
-import logging
 import os
 
 from openai import OpenAI
+from langchain_core.prompts import ChatPromptTemplate
 import requests
 
 from ask_candid.agents.schema import AgentState, Context
-from ask_candid.base.
+from ask_candid.base.api_base import BaseAPI
+
+
+class AutocodingAPI(BaseAPI):
+    def __init__(self):
+        super().__init__(
+            url=os.getenv("AUTOCODING_API_URL"),
+            headers={
+                'x-api-key': os.getenv("AUTOCODING_API_KEY"),
+                'Content-Type': 'application/json'
+            }
+        )
+
+    def __call__(self, text: str, taxonomy: str = 'pcs-v3'):
+        params = {
+            'text': text,
+            'taxonomy': taxonomy
+        }
+        return self.get(**params)
+
+
+class GeoAPI(BaseAPI):
+    def __init__(self):
+        super().__init__(
+            url=os.getenv("GEO_API_URL"),
+            headers={
+                'x-api-key': os.getenv("GEO_API_KEY"),
+                'Content-Type': 'application/json'
+            }
+        )
+
+    def __call__(self, text: str):
+        payload = {
+            'text': text
+        }
+        return self.post(payload=payload)
+
+
+class EntitiesAPI(BaseAPI):
+    def __init__(self):
+        super().__init__(
+            url=f'{os.getenv("DOCUMENT_API_URL")}/entities',
+            headers={
+                'x-api-key': os.getenv("DOCUMENT_API_KEY"),
+                'Content-Type': 'application/json'
+            }
+        )
+
+    def __call__(self, text: str):
+        payload = {
+            'text': text
+        }
+        return self.post(payload=payload)
+
+
+
+class FunderRecommendationAPI(BaseAPI):
+    def __init__(self):
+        super().__init__(
+            url=os.getenv("FUNDER_REC_API_URL"),
+            headers={"x-api-key": os.getenv("FUNDER_REC_API_KEY")}
+        )
+
+    def __call__(self, subjects, populations, geos):
+        params = {
+            "subjects": subjects,
+            "populations": populations,
+            "geos": geos
+        }
+        return self.get(**params)
+
+
+class RFPRecommendationAPI(BaseAPI):
+    def __init__(self):
+        super().__init__(
+            url=f'{os.getenv("FUNDER_REC_API_URL")}/rfp',
+            headers={"x-api-key": os.getenv("FUNDER_REC_API_KEY")}
+        )
+
+    def __call__(self, org_id, subjects, populations, geos):
+        params = {
+            "candid_entity_id": org_id,
+            "subjects": subjects,
+            "populations": populations,
+            "geos": geos
+        }
+        return self.get(**params)
 
-logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
 
 
-
-
-
-
-
+
+def detect_intent_with_llm(state: AgentState, llm) -> AgentState:
+    """Detect query intent (which type of recommendation) and update the state using the specified LLM."""
+    print("running detect intent")
+
     query = state["messages"][-1].content
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    prompt_template = ChatPromptTemplate.from_messages(
+        [
+            ("system", """
+            Please classify the following query by stating ONLY the category name: 'none', 'funder', or 'rfp'.
+            Please answer WITHOUT any reasoning.
+            - 'none': The query does not ask for any recommendations.
+            - 'funder': The query asks for recommendations about funders, such as foundations or donors.
+            - 'rfp': The query asks for recommendations about specific Requests for Proposals (RFPs).
+
+            Consider:
+            - If the query seeks broad, long-term funding sources or organizations, classify as 'funder'.
+            - If the query seeks specific, time-bound funding opportunities with a deadline, classify as 'rfp'.
+            - If the query does not seek any recommendations, classify as 'none'.
+
+            Query: """),
+            ("human", f"{query}")
+        ]
     )
 
-
-
+    chain = prompt_template | llm
+    response = chain.invoke({"query": query})
+
+    intent = response.content.strip().lower()
+    state["intent"] = intent.strip("'").strip('"')  # Remove extra quotes if necessary
+    print(state["intent"])
     return state
 
 
 def determine_context(state: AgentState) -> AgentState:
-    "
-    logger.info("---GETTING RECOMMENDATION CONTEXT---")
+    print("running context")
     query = state["messages"][-1].content
 
+    autocoding_api = AutocodingAPI()
+    entities_api = EntitiesAPI()
+
     subject_codes, population_codes, geo_ids = [], [], []
 
-
-
-
-        'Content-Type': 'application/json'
-    }
-    autocoding_params = {
-        'text': query,
-        'taxonomy': 'pcs-v3'
-    }
-    autocoding_response = requests.get(
-        os.getenv("AUTOCODING_API_URL"),
-        headers=autocoding_headers,
-        params=autocoding_params,
-        timeout=30
-    )
-    if autocoding_response.status_code == 200:
-        returned_pcs = autocoding_response.json()["data"]
+    try:
+        autocoding_response = autocoding_api(text=query)
+        returned_pcs = autocoding_response.get("data", {})
         population_codes = [item['full_code'] for item in returned_pcs.get("population", [])]
         subject_codes = [item['full_code'] for item in returned_pcs.get("subject", [])]
+    except Exception as e:
+        print(f"Failed to retrieve autocoding data: {e}")
 
-
-
-    '
-    '
-
-
-
-    }
-    geo_response = requests.post(os.getenv("GEO_API_URL"), headers=geo_headers, json=geo_data, timeout=30)
-    if geo_response.status_code == 200:
-        entities = geo_response.json()['data']['entities']
-        geo_ids = [entity['geo']['id'] for entity in entities if 'id' in entity['geo']]
+    try:
+        geo_response = entities_api(text=query)
+        entities = geo_response.get('entities', [])
+        geo_ids = [match['geonames_id'] for entity in entities if entity['type'] == 'geo' and 'match' in entity
+                   for match in entity['match'] if 'geonames_id' in match]
+    except Exception as e:
+        print(f"Failed to retrieve geographic data: {e}")
 
     state["context"] = Context(
         subject=subject_codes,
@@ -91,99 +159,80 @@ def determine_context(state: AgentState) -> AgentState:
     return state
 
 
+def format_recommendations(intent, data):
+    if 'recommendations' not in data:
+        return "No recommendations available."
+
+    recommendations = data['recommendations']
+    if not recommendations:
+        return "No recommendations found."
+
+    recommendation_texts = []
+    if intent == "funder":
+        for rec in recommendations:
+            main_sort_name = rec['funder_data']['main_sort_name']
+            profile_url = f"https://app.candid.org/profile/{rec['funder_id']}"
+            recommendation_texts.append(f"{main_sort_name} - Profile: {profile_url}")
+    elif intent == "rfp":
+        for rec in recommendations:
+            title = rec.get('title', 'N/A')
+            funder_name = rec.get('funder_name', 'N/A')
+            amount = rec.get('amount', 'Not specified')
+            description = rec.get('description', 'No description available')
+            deadline = rec.get('deadline', 'No deadline provided')
+            application_url = rec.get('application_url', 'No URL available')
+            text = (f"Title: {title}\n"
+                    f"Funder: {funder_name}\n"
+                    f"Amount: {amount}\n"
+                    f"Description: {description}\n"
+                    f"Deadline: {deadline}\n"
+                    f"Application URL: {application_url}\n")
+            recommendation_texts.append(text)
+    else:
+        return "Only funder recommendation or RFP recommendation are supported."
+
+    return "\n".join(recommendation_texts)
+
+
 def make_recommendation(state: AgentState) -> AgentState:
-    "
-
-    logger.info("---RECOMMENDING---")
-    org_id = "6908122"
+    print("running recommendation")
+    org_id = "6908122"  # Example organization ID (Candid)
     funder_or_rfp = state["intent"]
 
-    # Extract context
     contexts = state["context"]
-    subject_codes = contexts.get("subject", [])
-    population_codes = contexts.get("population", [])
-    geo_ids = contexts.get("geography", [])
-
-    # Prepare parameters
-    params = {
-        "subjects": ",".join(subject_codes),
-        "geos": ",".join([str(geo) for geo in geo_ids]),
-        "populations": ",".join(population_codes)
-    }
-    headers = {"x-api-key": os.getenv("FUNDER_REC_API_KEY")}
-    base_url = os.getenv("FUNDER_REC_API_URL")
-
-    # Initialize response
-    response = None
+    subject_codes = ",".join(contexts.get("subject", []))
+    population_codes = ",".join(contexts.get("population", []))
+    geo_ids = ",".join([str(geo) for geo in contexts.get("geography", [])])
 
     recommendation_display_text = ""
 
     try:
-        # Make the API call based on intent
         if funder_or_rfp == "funder":
-
+            funder_api = FunderRecommendationAPI()
+            recommendations = funder_api(subject_codes, population_codes, geo_ids)
         elif funder_or_rfp == "rfp":
-
-
+            rfp_api = RFPRecommendationAPI()
+            recommendations = rfp_api(org_id, subject_codes, population_codes, geo_ids)
        else:
-
-            state["recommendation"] =
+            recommendation_display_text = "Unknown intent. Intent 'funder' or 'rfp' expected."
+            state["recommendation"] = recommendation_display_text
            return state
 
-
-
-        recommendations = response.json().get("recommendations", [])
-        if recommendations:
-            if funder_or_rfp == "funder":
-                # Format recommendations
-                recommendation_display_text = "Here are the top 10 recommendations. Click their profiles to learn more:\n" + "\n".join([
-                    f"{recommendation['funder_data']['main_sort_name']} - Profile: https://app.candid.org/profile/{recommendation['funder_id']}"
-                    for recommendation in recommendations
-                ])
-            elif funder_or_rfp == "rfp":
-                recommendation_display_text = "Here are the top recommendations:\n" + "\n".join([
-                    f"Title: {rec['title']}\n"
-                    f"Funder: {rec['funder_name']}\n"
-                    f"Amount: {rec.get('amount', 'Not specified')}\n"
-                    f"Description: {rec.get('description', 'No description available')}\n"
-                    f"Deadline: {rec.get('deadline', 'No deadline provided')}\n"
-                    f"Application URL: {rec.get('application_url', 'No URL available')}\n"
                    for rec in recommendations
-                ])
-            else:
-                # No recommendations found
-                recommendation_display_text = "No recommendations were found for your query. Please try refining your search criteria."
-        elif response and response.status_code == 400:
-            # Handle bad request
-            error_details = response.json()
-            recommendation_display_text = (
-                "An error occurred while processing your request. "
-                f"Details: {error_details.get('message', 'Unknown error.')}"
-            )
-        elif response:
-            # Handle other unexpected status codes
-            recommendation_display_text = (
-                f"An unexpected error occurred (Status Code: {response.status_code}). "
-                "Please try again later or contact support if the problem persists."
-            )
+        if recommendations:
+            recommendation_display_text = format_recommendations(funder_or_rfp, recommendations)
         else:
-
-
-
-
-
-            recommendation_display_text =
-                "A network error occurred while trying to connect to the recommendation service. "
-                f"Details: {str(e)}"
-            )
+            recommendation_display_text = "No recommendations were found for your query. Please try refining your search criteria."
+
+    except requests.exceptions.HTTPError as e:
+        # Handle HTTP errors raised by raise_for_status()
+        print(f"HTTP error occurred: {e.response.status_code} - {e.response.reason}")
+        recommendation_display_text = "HTTP error occurred, please report this to datascience@candid.org"
     except Exception as e:
-        #
-        print(
-        recommendation_display_text =
-            "An unexpected error occurred while processing your request. "
-            f"Details: {str(e)}"
-        )
+        # Catch-all for any other exceptions that are not HTTP errors
+        print(f"An unexpected error occurred: {str(e)}")
+        recommendation_display_text = "Unexpected error occurred, please report this to datascience@candid.org"
 
-    # Update state with recommendations or error messages
     state["recommendation"] = recommendation_display_text
-    return state
+    return state
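The wrapper classes above lean on a thin BaseAPI contract: it holds a URL and headers, and its get/post helpers issue the request, raise for non-2xx statuses (hence the requests.exceptions.HTTPError handler in make_recommendation), and return parsed JSON. A sketch of that assumed contract, since api_base.py itself is not part of this diff:

    # Assumed shape of ask_candid.base.api_base.BaseAPI; illustrative only.
    from typing import Dict, Optional
    import requests

    class BaseAPI:
        def __init__(self, url: str, headers: Optional[Dict[str, str]] = None):
            self.url = url
            self.headers = headers or {}

        def get(self, **params):
            response = requests.get(self.url, headers=self.headers, params=params, timeout=30)
            response.raise_for_status()  # surfaces HTTPError to callers
            return response.json()

        def post(self, payload: dict):
            response = requests.post(self.url, headers=self.headers, json=payload, timeout=30)
            response.raise_for_status()
            return response.json()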
ask_candid/utils.py
CHANGED
@@ -1,6 +1,8 @@
 from typing import List, Dict, Union, Any
 from uuid import uuid4
 
+from langchain_core.documents import Document
+
 from ask_candid.retrieval.sources import (
     candid_blog,
     candid_help,
@@ -10,11 +12,6 @@ from ask_candid.retrieval.sources import (
 )
 
 
-def filter_messages(messages, k=10):
-    # TODO summarize messages instead
-    return messages[-k:]
-
-
 def html_format_doc(doc: Dict[str, Any], source: str, show_chunks=False) -> str:
     height_px = 200
     html = ""
@@ -23,10 +20,8 @@ def html_format_doc(doc: Dict[str, Any], source: str, show_chunks=False) -> str:
         # html = news.article_card_html(doc, height_px, show_chunks)
         pass
     elif source == "transactions":
-        # html = cds.transaction_card_html(doc, height_px, show_chunks)
         pass
     elif source == "organizations":
-        # html = up_orgs.organization_card_html(doc, 400, show_chunks)
         pass
     elif source == "issuelab":
         html = issuelab.issuelab_card_html(doc, height_px, show_chunks)
@@ -41,10 +36,20 @@ def html_format_doc(doc: Dict[str, Any], source: str, show_chunks=False) -> str:
     return html
 
 
-def html_format_docs_chat(docs):
-    """
-
+def html_format_docs_chat(docs: List[Document]) -> str:
+    """Formats Candid sources
+
+    Parameters
+    ----------
+    docs : List[Document]
+        Retrieved documents for context
+
+    Returns
+    -------
+    str
+        Formatted HTML
     """
+
    html = ""
    if docs:
        docs_html = []
@@ -54,8 +59,8 @@ def html_format_docs_chat(docs):
 
         s_html = (
             "<span class='source-item'>"
-            f"<a href={s_url} target='_blank' rel='noreferrer' class='ssearch-source'>"
-            f"{doc.metadata['title']}
+            f"<a href='{s_url}' target='_blank' rel='noreferrer' class='ssearch-source'>"
+            f"{doc.metadata['title']} | {s_name}</a></span>"
         )
 
         docs_html.append(s_html)
@@ -64,19 +69,6 @@ def html_format_docs_chat(docs):
     return html
 
 
-def format_chat_response(chatbot: List[Any]) -> List[Any]:
-    """We have sources appended as one more tuple. Here we concatinate HTML of sources
-    with the AI response
-    Returns:
-        _type_: updated chatbot message as HTML
-    """
-    if chatbot:
-        sources = chatbot[-1][1]
-        chatbot.pop(-1)
-        chatbot[-1][1] = chatbot[-1][1] + sources
-    return chatbot
-
-
 def format_chat_ag_response(chatbot: List[Any]) -> List[Any]:
     """If we called retriever, we appended sources as as one more message. Here we concatinate HTML of sources
     with the AI response
requirements.txt
CHANGED
@@ -1,13 +1,15 @@
 boto3
 elasticsearch==7.17.6
-
+thefuzz
 gradio
 langchain
 langchain-aws
 langchain-openai
 langgraph
 pydantic
+pyopenssl>22.0.0
 python-dotenv
+transformers
 
 --find-links https://download.pytorch.org/whl/cpu
 torch