Spaces:
Running
Running
UI callbacks and style changes
Browse files- app.py +2 -0
- ask_candid/base/config/data.py +5 -14
- ask_candid/chat.py +6 -1
- ask_candid/graph.py +53 -28
- ask_candid/retrieval/elastic.py +60 -117
- ask_candid/retrieval/sources/candid_blog.py +7 -0
- ask_candid/retrieval/sources/candid_help.py +7 -0
- ask_candid/retrieval/sources/candid_learning.py +7 -0
- ask_candid/retrieval/sources/candid_news.py +7 -0
- ask_candid/retrieval/sources/issuelab.py +7 -0
- ask_candid/retrieval/sources/schema.py +9 -0
- ask_candid/retrieval/sources/youtube.py +8 -0
- ask_candid/tools/elastic/list_indices_tool.py +2 -1
- ask_candid/tools/org_seach.py +25 -12
- ask_candid/tools/search.py +86 -2
- ask_candid/utils.py +1 -1
app.py
CHANGED
@@ -147,6 +147,8 @@ def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
|
|
147 |
show_copy_button=True,
|
148 |
show_share_button=None,
|
149 |
show_copy_all_button=False,
|
|
|
|
|
150 |
)
|
151 |
msg = gr.MultimodalTextbox(label="Your message", interactive=True)
|
152 |
thread_id = gr.Text(visible=False, value="", label="thread_id")
|
|
|
147 |
show_copy_button=True,
|
148 |
show_share_button=None,
|
149 |
show_copy_all_button=False,
|
150 |
+
autoscroll=True,
|
151 |
+
layout="panel",
|
152 |
)
|
153 |
msg = gr.MultimodalTextbox(label="Your message", interactive=True)
|
154 |
thread_id = gr.Text(visible=False, value="", label="thread_id")
|
ask_candid/base/config/data.py
CHANGED
@@ -1,21 +1,12 @@
|
|
1 |
-
|
2 |
-
"Mapping from plain name to Elasticsearch index name"
|
3 |
|
4 |
-
|
5 |
-
ISSUELAB_INDEX_ELSER = "search-semantic-issuelab-elser_ve2"
|
6 |
-
YOUTUBE_INDEX = "search-semantic-youtube_v1"
|
7 |
-
YOUTUBE_INDEX_ELSER = "search-semantic-youtube-elser_ve1"
|
8 |
-
CANDID_BLOG_INDEX = "search-semantic-candid-blog_v1"
|
9 |
-
CANDID_BLOG_INDEX_ELSER = "search-semantic-candid-blog"
|
10 |
-
CANDID_LEARNING_INDEX_ELSER = "search-semantic-candid-learning_ve1"
|
11 |
-
CANDID_HELP_INDEX_ELSER = "search-semantic-candid-help-elser_ve1"
|
12 |
-
|
13 |
-
|
14 |
-
ALL_INDICES = (
|
15 |
"issuelab",
|
16 |
"youtube",
|
17 |
"candid_blog",
|
18 |
"candid_learning",
|
19 |
"candid_help",
|
20 |
"news"
|
21 |
-
|
|
|
|
|
|
1 |
+
from typing import Literal, get_args
|
|
|
2 |
|
3 |
+
DataIndices = Literal[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
"issuelab",
|
5 |
"youtube",
|
6 |
"candid_blog",
|
7 |
"candid_learning",
|
8 |
"candid_help",
|
9 |
"news"
|
10 |
+
]
|
11 |
+
|
12 |
+
ALL_INDICES = get_args(DataIndices)
|
ask_candid/chat.py
CHANGED
@@ -29,7 +29,12 @@ def run_chat(
|
|
29 |
config = {"configurable": {"thread_id": thread_id}}
|
30 |
|
31 |
enable_recommendations = "Recommendation" in premium_features
|
32 |
-
workflow = build_compute_graph(
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
memory = MemorySaver() # TODO: don't use for Prod
|
35 |
graph = workflow.compile(checkpointer=memory)
|
|
|
29 |
config = {"configurable": {"thread_id": thread_id}}
|
30 |
|
31 |
enable_recommendations = "Recommendation" in premium_features
|
32 |
+
workflow = build_compute_graph(
|
33 |
+
llm=llm,
|
34 |
+
indices=indices,
|
35 |
+
user_callback=gr.Info,
|
36 |
+
enable_recommendations=enable_recommendations
|
37 |
+
)
|
38 |
|
39 |
memory = MemorySaver() # TODO: don't use for Prod
|
40 |
graph = workflow.compile(checkpointer=memory)
|
ask_candid/graph.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import List
|
2 |
from functools import partial
|
3 |
import logging
|
4 |
|
@@ -11,7 +11,6 @@ from langgraph.prebuilt import tools_condition, ToolNode
|
|
11 |
from langgraph.graph.state import StateGraph
|
12 |
from langgraph.constants import START, END
|
13 |
|
14 |
-
from ask_candid.retrieval.elastic import retriever_tool
|
15 |
from ask_candid.tools.recommendation import (
|
16 |
detect_intent_with_llm,
|
17 |
determine_context,
|
@@ -19,8 +18,9 @@ from ask_candid.tools.recommendation import (
|
|
19 |
)
|
20 |
from ask_candid.tools.question_reformulation import reformulate_question_using_history
|
21 |
from ask_candid.tools.org_seach import has_org_name, insert_org_link
|
22 |
-
from ask_candid.tools.search import search_agent
|
23 |
from ask_candid.agents.schema import AgentState
|
|
|
24 |
|
25 |
from ask_candid.utils import html_format_docs_chat
|
26 |
|
@@ -29,7 +29,11 @@ logger = logging.getLogger(__name__)
|
|
29 |
logger.setLevel(logging.INFO)
|
30 |
|
31 |
|
32 |
-
def generate_with_context(
|
|
|
|
|
|
|
|
|
33 |
"""Generate answer.
|
34 |
|
35 |
Parameters
|
@@ -37,6 +41,8 @@ def generate_with_context(state: AgentState, llm: LLM) -> AgentState:
|
|
37 |
state : AgentState
|
38 |
The current state
|
39 |
llm : LLM
|
|
|
|
|
40 |
|
41 |
Returns
|
42 |
-------
|
@@ -45,14 +51,20 @@ def generate_with_context(state: AgentState, llm: LLM) -> AgentState:
|
|
45 |
"""
|
46 |
|
47 |
logger.info("---GENERATE ANSWER---")
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
messages = state["messages"]
|
49 |
question = state["user_input"]
|
50 |
last_message = messages[-1]
|
51 |
|
52 |
sources_str = last_message.content
|
53 |
-
sources_list = last_message.artifact
|
54 |
-
# converting to html string
|
55 |
sources_html = html_format_docs_chat(sources_list)
|
|
|
56 |
if sources_list:
|
57 |
logger.info("---ADD SOURCES---")
|
58 |
state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
|
@@ -97,13 +109,13 @@ def add_recommendations_pipeline_(
|
|
97 |
"""
|
98 |
|
99 |
# Nodes for recommendation functionalities
|
100 |
-
G.add_node("detect_intent_with_llm", partial(detect_intent_with_llm, llm=llm))
|
101 |
-
G.add_node("determine_context", determine_context)
|
102 |
-
G.add_node("make_recommendation", make_recommendation)
|
103 |
|
104 |
# Check for recommendation query first
|
105 |
# Execute until reaching END if user asks for recommendation
|
106 |
-
G.add_edge(reformulation_node_name, "detect_intent_with_llm")
|
107 |
G.add_conditional_edges(
|
108 |
source="detect_intent_with_llm",
|
109 |
path=lambda state: "determine_context" if state["intent"] in ["rfp", "funder"] else search_node_name,
|
@@ -112,24 +124,27 @@ def add_recommendations_pipeline_(
|
|
112 |
search_node_name: search_node_name
|
113 |
},
|
114 |
)
|
115 |
-
G.add_edge("determine_context", "make_recommendation")
|
116 |
-
G.add_edge("make_recommendation", END)
|
117 |
|
118 |
|
119 |
def build_compute_graph(
|
120 |
llm: LLM,
|
121 |
-
indices: List[
|
122 |
-
enable_recommendations: bool = False
|
|
|
123 |
) -> StateGraph:
|
124 |
"""Execution graph builder, the output is the execution flow for an interaction with the assistant.
|
125 |
|
126 |
Parameters
|
127 |
----------
|
128 |
llm : LLM
|
129 |
-
indices : List[
|
130 |
Semantic index names to search over
|
131 |
enable_recommendations : bool, optional
|
132 |
Set to `True` to allow the flow to generate recommendations based on context, by default False
|
|
|
|
|
133 |
|
134 |
Returns
|
135 |
-------
|
@@ -137,25 +152,35 @@ def build_compute_graph(
|
|
137 |
Execution graph
|
138 |
"""
|
139 |
|
140 |
-
candid_retriever_tool = retriever_tool(indices=indices)
|
141 |
retrieve = ToolNode([candid_retriever_tool])
|
142 |
tools = [candid_retriever_tool]
|
143 |
|
144 |
G = StateGraph(AgentState)
|
145 |
|
146 |
-
G.add_node(
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
G.add_node("
|
151 |
-
G.add_node("
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
if enable_recommendations:
|
154 |
-
add_recommendations_pipeline_(
|
|
|
|
|
|
|
|
|
155 |
else:
|
156 |
-
G.add_edge("reformulate", "search_agent")
|
157 |
|
158 |
-
G.add_edge(START, "reformulate")
|
159 |
G.add_conditional_edges(
|
160 |
source="search_agent",
|
161 |
path=tools_condition,
|
@@ -164,8 +189,8 @@ def build_compute_graph(
|
|
164 |
END: "has_org_name",
|
165 |
},
|
166 |
)
|
167 |
-
G.add_edge("retrieve", "generate_with_context")
|
168 |
-
G.add_edge("generate_with_context", "has_org_name")
|
169 |
G.add_conditional_edges(
|
170 |
source="has_org_name",
|
171 |
path=lambda x: x["next"], # Now we're accessing the 'next' key from the dict
|
@@ -174,5 +199,5 @@ def build_compute_graph(
|
|
174 |
END: END
|
175 |
},
|
176 |
)
|
177 |
-
G.add_edge("insert_org_link", END)
|
178 |
return G
|
|
|
1 |
+
from typing import List, Optional, Callable, Any
|
2 |
from functools import partial
|
3 |
import logging
|
4 |
|
|
|
11 |
from langgraph.graph.state import StateGraph
|
12 |
from langgraph.constants import START, END
|
13 |
|
|
|
14 |
from ask_candid.tools.recommendation import (
|
15 |
detect_intent_with_llm,
|
16 |
determine_context,
|
|
|
18 |
)
|
19 |
from ask_candid.tools.question_reformulation import reformulate_question_using_history
|
20 |
from ask_candid.tools.org_seach import has_org_name, insert_org_link
|
21 |
+
from ask_candid.tools.search import search_agent, retriever_tool
|
22 |
from ask_candid.agents.schema import AgentState
|
23 |
+
from ask_candid.base.config.data import DataIndices
|
24 |
|
25 |
from ask_candid.utils import html_format_docs_chat
|
26 |
|
|
|
29 |
logger.setLevel(logging.INFO)
|
30 |
|
31 |
|
32 |
+
def generate_with_context(
|
33 |
+
state: AgentState,
|
34 |
+
llm: LLM,
|
35 |
+
user_callback: Optional[Callable[[str], Any]] = None
|
36 |
+
) -> AgentState:
|
37 |
"""Generate answer.
|
38 |
|
39 |
Parameters
|
|
|
41 |
state : AgentState
|
42 |
The current state
|
43 |
llm : LLM
|
44 |
+
user_callback : Optional[Callable[[str], Any]], optional
|
45 |
+
Optional UI callback to inform the user of apps states, by default None
|
46 |
|
47 |
Returns
|
48 |
-------
|
|
|
51 |
"""
|
52 |
|
53 |
logger.info("---GENERATE ANSWER---")
|
54 |
+
if user_callback is not None:
|
55 |
+
try:
|
56 |
+
user_callback("Writing a response...")
|
57 |
+
except Exception as ex:
|
58 |
+
logger.warning("User callback was passed in but failed: %s", ex)
|
59 |
+
|
60 |
messages = state["messages"]
|
61 |
question = state["user_input"]
|
62 |
last_message = messages[-1]
|
63 |
|
64 |
sources_str = last_message.content
|
65 |
+
sources_list = last_message.artifact
|
|
|
66 |
sources_html = html_format_docs_chat(sources_list)
|
67 |
+
|
68 |
if sources_list:
|
69 |
logger.info("---ADD SOURCES---")
|
70 |
state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
|
|
|
109 |
"""
|
110 |
|
111 |
# Nodes for recommendation functionalities
|
112 |
+
G.add_node(node="detect_intent_with_llm", action=partial(detect_intent_with_llm, llm=llm))
|
113 |
+
G.add_node(node="determine_context", action=determine_context)
|
114 |
+
G.add_node(node="make_recommendation", action=make_recommendation)
|
115 |
|
116 |
# Check for recommendation query first
|
117 |
# Execute until reaching END if user asks for recommendation
|
118 |
+
G.add_edge(start_key=reformulation_node_name, end_key="detect_intent_with_llm")
|
119 |
G.add_conditional_edges(
|
120 |
source="detect_intent_with_llm",
|
121 |
path=lambda state: "determine_context" if state["intent"] in ["rfp", "funder"] else search_node_name,
|
|
|
124 |
search_node_name: search_node_name
|
125 |
},
|
126 |
)
|
127 |
+
G.add_edge(start_key="determine_context", end_key="make_recommendation")
|
128 |
+
G.add_edge(start_key="make_recommendation", end_key=END)
|
129 |
|
130 |
|
131 |
def build_compute_graph(
|
132 |
llm: LLM,
|
133 |
+
indices: List[DataIndices],
|
134 |
+
enable_recommendations: bool = False,
|
135 |
+
user_callback: Optional[Callable[[str], Any]] = None
|
136 |
) -> StateGraph:
|
137 |
"""Execution graph builder, the output is the execution flow for an interaction with the assistant.
|
138 |
|
139 |
Parameters
|
140 |
----------
|
141 |
llm : LLM
|
142 |
+
indices : List[DataIndices]
|
143 |
Semantic index names to search over
|
144 |
enable_recommendations : bool, optional
|
145 |
Set to `True` to allow the flow to generate recommendations based on context, by default False
|
146 |
+
user_callback : Optional[Callable[[str], Any]], optional
|
147 |
+
Optional UI callback to inform the user of apps states, by default None
|
148 |
|
149 |
Returns
|
150 |
-------
|
|
|
152 |
Execution graph
|
153 |
"""
|
154 |
|
155 |
+
candid_retriever_tool = retriever_tool(indices=indices, user_callback=user_callback)
|
156 |
retrieve = ToolNode([candid_retriever_tool])
|
157 |
tools = [candid_retriever_tool]
|
158 |
|
159 |
G = StateGraph(AgentState)
|
160 |
|
161 |
+
G.add_node(
|
162 |
+
node="reformulate",
|
163 |
+
action=partial(reformulate_question_using_history, llm=llm, focus_on_recommendations=enable_recommendations)
|
164 |
+
)
|
165 |
+
G.add_node(node="search_agent", action=partial(search_agent, llm=llm, tools=tools))
|
166 |
+
G.add_node(node="retrieve", action=retrieve)
|
167 |
+
G.add_node(
|
168 |
+
node="generate_with_context",
|
169 |
+
action=partial(generate_with_context, llm=llm, user_callback=user_callback)
|
170 |
+
)
|
171 |
+
G.add_node(node="has_org_name", action=partial(has_org_name, llm=llm, user_callback=user_callback))
|
172 |
+
G.add_node(node="insert_org_link", action=insert_org_link)
|
173 |
|
174 |
if enable_recommendations:
|
175 |
+
add_recommendations_pipeline_(
|
176 |
+
G, llm=llm,
|
177 |
+
reformulation_node_name="reformulate",
|
178 |
+
search_node_name="search_agent"
|
179 |
+
)
|
180 |
else:
|
181 |
+
G.add_edge(start_key="reformulate", end_key="search_agent")
|
182 |
|
183 |
+
G.add_edge(start_key=START, end_key="reformulate")
|
184 |
G.add_conditional_edges(
|
185 |
source="search_agent",
|
186 |
path=tools_condition,
|
|
|
189 |
END: "has_org_name",
|
190 |
},
|
191 |
)
|
192 |
+
G.add_edge(start_key="retrieve", end_key="generate_with_context")
|
193 |
+
G.add_edge(start_key="generate_with_context", end_key="has_org_name")
|
194 |
G.add_conditional_edges(
|
195 |
source="has_org_name",
|
196 |
path=lambda x: x["next"], # Now we're accessing the 'next' key from the dict
|
|
|
199 |
END: END
|
200 |
},
|
201 |
)
|
202 |
+
G.add_edge(start_key="insert_org_link", end_key=END)
|
203 |
return G
|
ask_candid/retrieval/elastic.py
CHANGED
@@ -1,20 +1,24 @@
|
|
1 |
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
|
2 |
from dataclasses import dataclass
|
3 |
-
from functools import partial
|
4 |
from itertools import groupby
|
5 |
|
6 |
from torch.nn import functional as F
|
7 |
|
8 |
from pydantic import BaseModel, Field
|
9 |
from langchain_core.documents import Document
|
10 |
-
from langchain_core.tools import Tool
|
11 |
|
12 |
from elasticsearch import Elasticsearch
|
13 |
|
14 |
from ask_candid.retrieval.sparse_lexical import SpladeEncoder
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
from ask_candid.services.small_lm import CandidSLM
|
16 |
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
|
17 |
-
from ask_candid.base.config.data import
|
18 |
|
19 |
encoder = SpladeEncoder()
|
20 |
|
@@ -82,6 +86,18 @@ def build_sparse_vector_query(
|
|
82 |
|
83 |
|
84 |
def news_query_builder(query: str) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
tokens = encoder.token_expand(query)
|
86 |
|
87 |
query = {
|
@@ -103,81 +119,70 @@ def news_query_builder(query: str) -> Dict[str, Any]:
|
|
103 |
query["query"]["bool"]["should"].append({
|
104 |
"multi_match": {
|
105 |
"query": token,
|
106 |
-
"fields":
|
107 |
"boost": score
|
108 |
}
|
109 |
})
|
110 |
return query
|
111 |
|
112 |
|
113 |
-
def query_builder(query: str, indices: List[
|
114 |
"""Builds Elasticsearch multi-search query payload
|
115 |
|
116 |
Parameters
|
117 |
----------
|
118 |
query : str
|
119 |
Search context string
|
120 |
-
indices : List[
|
121 |
Semantic index names to search over
|
122 |
|
123 |
Returns
|
124 |
-------
|
125 |
-
List[Dict[str, Any]]
|
|
|
126 |
"""
|
127 |
|
128 |
-
queries = []
|
129 |
if indices is None:
|
130 |
indices = list(ALL_INDICES)
|
131 |
|
132 |
for index in indices:
|
133 |
if index == "issuelab":
|
134 |
-
q = build_sparse_vector_query(
|
135 |
-
query=query,
|
136 |
-
fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
|
137 |
-
)
|
138 |
q["_source"] = {"excludes": ["embeddings"]}
|
139 |
q["size"] = 1
|
140 |
-
queries.extend([{"index":
|
141 |
elif index == "youtube":
|
142 |
-
q = build_sparse_vector_query(
|
143 |
-
|
144 |
-
fields=("captions_cleaned", "description_cleaned", "title")
|
145 |
-
)
|
146 |
-
# text_cleaned duplicates captions_cleaned
|
147 |
-
q["_source"] = {"excludes": ["embeddings", "captions", "description", "text_cleaned"]}
|
148 |
q["size"] = 2
|
149 |
-
queries.extend([{"index":
|
150 |
elif index == "candid_blog":
|
151 |
-
q = build_sparse_vector_query(
|
152 |
-
query=query,
|
153 |
-
fields=("content", "authors_text", "title_summary_tags")
|
154 |
-
)
|
155 |
q["_source"] = {"excludes": ["embeddings"]}
|
156 |
q["size"] = 2
|
157 |
-
queries.extend([{"index":
|
158 |
elif index == "candid_learning":
|
159 |
-
q = build_sparse_vector_query(
|
160 |
-
query=query,
|
161 |
-
fields=("content", "title", "training_topics", "staff_recommendations")
|
162 |
-
)
|
163 |
q["_source"] = {"excludes": ["embeddings"]}
|
164 |
q["size"] = 2
|
165 |
-
queries.extend([{"index":
|
166 |
elif index == "candid_help":
|
167 |
-
q = build_sparse_vector_query(
|
168 |
-
query=query,
|
169 |
-
fields=("content", "combined_article_description")
|
170 |
-
)
|
171 |
q["_source"] = {"excludes": ["embeddings"]}
|
172 |
q["size"] = 2
|
173 |
-
queries.extend([{"index":
|
|
|
|
|
|
|
|
|
174 |
|
175 |
-
return queries
|
176 |
|
177 |
|
178 |
def multi_search(
|
179 |
queries: List[Dict[str, Any]],
|
180 |
-
|
181 |
) -> List[ElasticHitsResult]:
|
182 |
"""Runs multi-search query
|
183 |
|
@@ -191,6 +196,17 @@ def multi_search(
|
|
191 |
List[ElasticHitsResult]
|
192 |
"""
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
results = []
|
195 |
|
196 |
if len(queries) > 0:
|
@@ -200,31 +216,16 @@ def multi_search(
|
|
200 |
verify_certs=False,
|
201 |
request_timeout=60 * 3
|
202 |
) as es:
|
203 |
-
for
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
id=hit["_id"],
|
208 |
-
score=hit["_score"],
|
209 |
-
source=hit["_source"],
|
210 |
-
inner_hits=hit.get("inner_hits", {})
|
211 |
-
)
|
212 |
-
results.append(hit)
|
213 |
-
|
214 |
-
if news_query is not None:
|
215 |
with Elasticsearch(
|
216 |
NEWS_ELASTIC.url,
|
217 |
http_auth=(NEWS_ELASTIC.username, NEWS_ELASTIC.password),
|
218 |
timeout=60
|
219 |
) as es:
|
220 |
-
for hit in es.
|
221 |
-
hit = ElasticHitsResult(
|
222 |
-
index=hit["_index"],
|
223 |
-
id=hit["_id"],
|
224 |
-
score=hit["_score"],
|
225 |
-
source=hit["_source"],
|
226 |
-
inner_hits=hit.get("inner_hits", {})
|
227 |
-
)
|
228 |
results.append(hit)
|
229 |
return results
|
230 |
|
@@ -244,9 +245,8 @@ def get_query_results(search_text: str, indices: Optional[List[str]] = None) ->
|
|
244 |
List[ElasticHitsResult]
|
245 |
"""
|
246 |
|
247 |
-
queries = query_builder(query=search_text, indices=indices)
|
248 |
-
|
249 |
-
return multi_search(queries, news_query=news_q)
|
250 |
|
251 |
|
252 |
def retrieved_text(hits: Dict[str, Any]) -> str:
|
@@ -335,36 +335,6 @@ def reranker(
|
|
335 |
yield from sorted(results, key=lambda x: x.score, reverse=True)
|
336 |
|
337 |
|
338 |
-
def get_results(user_input: str, indices: List[str]) -> Tuple[str, List[Document]]:
|
339 |
-
"""End-to-end search and re-rank function.
|
340 |
-
|
341 |
-
Parameters
|
342 |
-
----------
|
343 |
-
user_input : str
|
344 |
-
Search context string
|
345 |
-
indices : List[str]
|
346 |
-
Semantic index names to search over
|
347 |
-
|
348 |
-
Returns
|
349 |
-
-------
|
350 |
-
Tuple[str, List[Document]]
|
351 |
-
(concatenated text from search results, documents list)
|
352 |
-
"""
|
353 |
-
|
354 |
-
output = ["Search didn't return any Candid sources"]
|
355 |
-
page_content = []
|
356 |
-
content = "Search didn't return any Candid sources"
|
357 |
-
results = get_query_results(search_text=user_input, indices=indices)
|
358 |
-
if results:
|
359 |
-
output = get_reranked_results(results, search_text=user_input)
|
360 |
-
for doc in output:
|
361 |
-
page_content.append(doc.page_content)
|
362 |
-
content = "\n\n".join(page_content)
|
363 |
-
|
364 |
-
# for the tool we need to return a tuple for content_and_artifact type
|
365 |
-
return content, output
|
366 |
-
|
367 |
-
|
368 |
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
|
369 |
"""Pads the relevant chunk of text with context before and after
|
370 |
|
@@ -537,30 +507,3 @@ def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional
|
|
537 |
if hit is not None:
|
538 |
output.append(hit)
|
539 |
return output
|
540 |
-
|
541 |
-
|
542 |
-
def retriever_tool(indices: List[str]) -> Tool:
|
543 |
-
"""Tool component for use in conditional edge building for RAG execution graph.
|
544 |
-
Cannot use `create_retriever_tool` because it only provides content losing all metadata on the way
|
545 |
-
https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution
|
546 |
-
|
547 |
-
Parameters
|
548 |
-
----------
|
549 |
-
indices : List[str]
|
550 |
-
Semantic index names to search over
|
551 |
-
|
552 |
-
Returns
|
553 |
-
-------
|
554 |
-
Tool
|
555 |
-
"""
|
556 |
-
|
557 |
-
return Tool(
|
558 |
-
name="retrieve_social_sector_information",
|
559 |
-
func=partial(get_results, indices=indices),
|
560 |
-
description=(
|
561 |
-
"Return additional information about social and philanthropic sector, "
|
562 |
-
"including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
|
563 |
-
),
|
564 |
-
args_schema=RetrieverInput,
|
565 |
-
response_format="content_and_artifact"
|
566 |
-
)
|
|
|
1 |
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
|
2 |
from dataclasses import dataclass
|
|
|
3 |
from itertools import groupby
|
4 |
|
5 |
from torch.nn import functional as F
|
6 |
|
7 |
from pydantic import BaseModel, Field
|
8 |
from langchain_core.documents import Document
|
|
|
9 |
|
10 |
from elasticsearch import Elasticsearch
|
11 |
|
12 |
from ask_candid.retrieval.sparse_lexical import SpladeEncoder
|
13 |
+
from ask_candid.retrieval.sources.issuelab import IssueLabConfig
|
14 |
+
from ask_candid.retrieval.sources.youtube import YoutubeConfig
|
15 |
+
from ask_candid.retrieval.sources.candid_blog import CandidBlogConfig
|
16 |
+
from ask_candid.retrieval.sources.candid_learning import CandidLearningConfig
|
17 |
+
from ask_candid.retrieval.sources.candid_help import CandidHelpConfig
|
18 |
+
from ask_candid.retrieval.sources.candid_news import CandidNewsConfig
|
19 |
from ask_candid.services.small_lm import CandidSLM
|
20 |
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
|
21 |
+
from ask_candid.base.config.data import DataIndices, ALL_INDICES
|
22 |
|
23 |
encoder = SpladeEncoder()
|
24 |
|
|
|
86 |
|
87 |
|
88 |
def news_query_builder(query: str) -> Dict[str, Any]:
|
89 |
+
"""Builds a valid Elasticsearch query against Candid news, simulating a token expansion.
|
90 |
+
|
91 |
+
Parameters
|
92 |
+
----------
|
93 |
+
query : str
|
94 |
+
Search context string
|
95 |
+
|
96 |
+
Returns
|
97 |
+
-------
|
98 |
+
Dict[str, Any]
|
99 |
+
"""
|
100 |
+
|
101 |
tokens = encoder.token_expand(query)
|
102 |
|
103 |
query = {
|
|
|
119 |
query["query"]["bool"]["should"].append({
|
120 |
"multi_match": {
|
121 |
"query": token,
|
122 |
+
"fields": CandidNewsConfig.text_fields,
|
123 |
"boost": score
|
124 |
}
|
125 |
})
|
126 |
return query
|
127 |
|
128 |
|
129 |
+
def query_builder(query: str, indices: List[DataIndices]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
130 |
"""Builds Elasticsearch multi-search query payload
|
131 |
|
132 |
Parameters
|
133 |
----------
|
134 |
query : str
|
135 |
Search context string
|
136 |
+
indices : List[DataIndices]
|
137 |
Semantic index names to search over
|
138 |
|
139 |
Returns
|
140 |
-------
|
141 |
+
Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]
|
142 |
+
(semantic index queries, news queries)
|
143 |
"""
|
144 |
|
145 |
+
queries, news_queries = [], []
|
146 |
if indices is None:
|
147 |
indices = list(ALL_INDICES)
|
148 |
|
149 |
for index in indices:
|
150 |
if index == "issuelab":
|
151 |
+
q = build_sparse_vector_query(query=query, fields=IssueLabConfig.text_fields)
|
|
|
|
|
|
|
152 |
q["_source"] = {"excludes": ["embeddings"]}
|
153 |
q["size"] = 1
|
154 |
+
queries.extend([{"index": IssueLabConfig.index_name}, q])
|
155 |
elif index == "youtube":
|
156 |
+
q = build_sparse_vector_query(query=query, fields=YoutubeConfig.text_fields)
|
157 |
+
q["_source"] = {"excludes": ["embeddings", *YoutubeConfig.excluded_fields]}
|
|
|
|
|
|
|
|
|
158 |
q["size"] = 2
|
159 |
+
queries.extend([{"index": YoutubeConfig.index_name}, q])
|
160 |
elif index == "candid_blog":
|
161 |
+
q = build_sparse_vector_query(query=query, fields=CandidBlogConfig.text_fields)
|
|
|
|
|
|
|
162 |
q["_source"] = {"excludes": ["embeddings"]}
|
163 |
q["size"] = 2
|
164 |
+
queries.extend([{"index": CandidBlogConfig.index_name}, q])
|
165 |
elif index == "candid_learning":
|
166 |
+
q = build_sparse_vector_query(query=query, fields=CandidLearningConfig.text_fields)
|
|
|
|
|
|
|
167 |
q["_source"] = {"excludes": ["embeddings"]}
|
168 |
q["size"] = 2
|
169 |
+
queries.extend([{"index": CandidLearningConfig.index_name}, q])
|
170 |
elif index == "candid_help":
|
171 |
+
q = build_sparse_vector_query(query=query, fields=CandidHelpConfig.text_fields)
|
|
|
|
|
|
|
172 |
q["_source"] = {"excludes": ["embeddings"]}
|
173 |
q["size"] = 2
|
174 |
+
queries.extend([{"index": CandidHelpConfig.index_name}, q])
|
175 |
+
elif index == "news":
|
176 |
+
q = news_query_builder(query=query)
|
177 |
+
q["size"] = 5
|
178 |
+
news_queries.extend([{"index": CandidNewsConfig.index_name}, q])
|
179 |
|
180 |
+
return queries, news_queries
|
181 |
|
182 |
|
183 |
def multi_search(
|
184 |
queries: List[Dict[str, Any]],
|
185 |
+
news_queries: Optional[List[Dict[str, Any]]] = None
|
186 |
) -> List[ElasticHitsResult]:
|
187 |
"""Runs multi-search query
|
188 |
|
|
|
196 |
List[ElasticHitsResult]
|
197 |
"""
|
198 |
|
199 |
+
def _msearch_response_generator(responses: List[Dict[str, Any]]) -> Iterator[ElasticHitsResult]:
|
200 |
+
for query_group in responses:
|
201 |
+
for h in query_group.get("hits", {}).get("hits", []):
|
202 |
+
yield ElasticHitsResult(
|
203 |
+
index=h["_index"],
|
204 |
+
id=h["_id"],
|
205 |
+
score=h["_score"],
|
206 |
+
source=h["_source"],
|
207 |
+
inner_hits=h.get("inner_hits", {})
|
208 |
+
)
|
209 |
+
|
210 |
results = []
|
211 |
|
212 |
if len(queries) > 0:
|
|
|
216 |
verify_certs=False,
|
217 |
request_timeout=60 * 3
|
218 |
) as es:
|
219 |
+
for hit in _msearch_response_generator(es.msearch(body=queries).get("responses", [])):
|
220 |
+
results.append(hit)
|
221 |
+
|
222 |
+
if news_queries is not None and len(news_queries):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
with Elasticsearch(
|
224 |
NEWS_ELASTIC.url,
|
225 |
http_auth=(NEWS_ELASTIC.username, NEWS_ELASTIC.password),
|
226 |
timeout=60
|
227 |
) as es:
|
228 |
+
for hit in _msearch_response_generator(es.msearch(body=news_queries).get("responses", [])):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
results.append(hit)
|
230 |
return results
|
231 |
|
|
|
245 |
List[ElasticHitsResult]
|
246 |
"""
|
247 |
|
248 |
+
queries, news_q = query_builder(query=search_text, indices=indices)
|
249 |
+
return multi_search(queries, news_queries=news_q)
|
|
|
250 |
|
251 |
|
252 |
def retrieved_text(hits: Dict[str, Any]) -> str:
|
|
|
335 |
yield from sorted(results, key=lambda x: x.score, reverse=True)
|
336 |
|
337 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
|
339 |
"""Pads the relevant chunk of text with context before and after
|
340 |
|
|
|
507 |
if hit is not None:
|
508 |
output.append(hit)
|
509 |
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ask_candid/retrieval/sources/candid_blog.py
CHANGED
@@ -1,4 +1,11 @@
|
|
1 |
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
|
|
1 |
from typing import Dict, Any
|
2 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
3 |
+
|
4 |
+
|
5 |
+
CandidBlogConfig = ElasticSourceConfig(
|
6 |
+
index_name="search-semantic-candid-blog",
|
7 |
+
text_fields=("content", "authors_text", "title_summary_tags")
|
8 |
+
)
|
9 |
|
10 |
|
11 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
ask_candid/retrieval/sources/candid_help.py
CHANGED
@@ -1,4 +1,11 @@
|
|
1 |
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
|
|
1 |
from typing import Dict, Any
|
2 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
3 |
+
|
4 |
+
|
5 |
+
CandidHelpConfig = ElasticSourceConfig(
|
6 |
+
index_name="search-semantic-candid-help-elser_ve1",
|
7 |
+
text_fields=("content", "combined_article_description")
|
8 |
+
)
|
9 |
|
10 |
|
11 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
ask_candid/retrieval/sources/candid_learning.py
CHANGED
@@ -1,4 +1,11 @@
|
|
1 |
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
|
|
1 |
from typing import Dict, Any
|
2 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
3 |
+
|
4 |
+
|
5 |
+
CandidLearningConfig = ElasticSourceConfig(
|
6 |
+
index_name="search-semantic-candid-learning_ve1",
|
7 |
+
text_fields=("content", "title", "training_topics", "staff_recommendations")
|
8 |
+
)
|
9 |
|
10 |
|
11 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
ask_candid/retrieval/sources/candid_news.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
2 |
+
|
3 |
+
|
4 |
+
CandidNewsConfig = ElasticSourceConfig(
|
5 |
+
index_name="news_1",
|
6 |
+
text_fields=("title", "content")
|
7 |
+
)
|
ask_candid/retrieval/sources/issuelab.py
CHANGED
@@ -1,4 +1,11 @@
|
|
1 |
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
def issuelab_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
|
|
1 |
from typing import Dict, Any
|
2 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
3 |
+
|
4 |
+
|
5 |
+
IssueLabConfig = ElasticSourceConfig(
|
6 |
+
index_name="search-semantic-issuelab-elser_ve2",
|
7 |
+
text_fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
|
8 |
+
)
|
9 |
|
10 |
|
11 |
def issuelab_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
ask_candid/retrieval/sources/schema.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Optional
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class ElasticSourceConfig:
|
7 |
+
index_name: str
|
8 |
+
text_fields: Tuple[str]
|
9 |
+
excluded_fields: Optional[Tuple[str]] = field(default_factory=tuple)
|
ask_candid/retrieval/sources/youtube.py
CHANGED
@@ -1,4 +1,12 @@
|
|
1 |
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
|
|
1 |
from typing import Dict, Any
|
2 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
3 |
+
|
4 |
+
|
5 |
+
YoutubeConfig = ElasticSourceConfig(
|
6 |
+
index_name="search-semantic-youtube-elser_ve1",
|
7 |
+
text_fields=("captions_cleaned", "description_cleaned", "title"),
|
8 |
+
excluded_fields=("captions", "description", "text_cleaned")
|
9 |
+
)
|
10 |
|
11 |
|
12 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
ask_candid/tools/elastic/list_indices_tool.py
CHANGED
@@ -31,7 +31,8 @@ class ListIndicesTool(BaseTool):
|
|
31 |
|
32 |
name: str = "elastic_list_indices" # Added type annotation
|
33 |
description: str = (
|
34 |
-
"Input is a delimiter like comma or new line. Output is a separated list of indices in the database.
|
|
|
35 |
)
|
36 |
args_schema: Optional[Type[BaseModel]] = (
|
37 |
ListIndicesInput # Define this before methods
|
|
|
31 |
|
32 |
name: str = "elastic_list_indices" # Added type annotation
|
33 |
description: str = (
|
34 |
+
"Input is a delimiter like comma or new line. Output is a separated list of indices in the database. "
|
35 |
+
"Always use this tool to get to know the indices in the ElasticSearch cluster."
|
36 |
)
|
37 |
args_schema: Optional[Type[BaseModel]] = (
|
38 |
ListIndicesInput # Define this before methods
|
ask_candid/tools/org_seach.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1 |
-
from typing import List
|
2 |
import logging
|
3 |
import re
|
4 |
|
5 |
from thefuzz import fuzz
|
6 |
|
7 |
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
|
8 |
-
# from langchain_openai.chat_models import ChatOpenAI
|
9 |
from langchain_core.runnables import RunnableSequence
|
10 |
from langchain_core.prompts import ChatPromptTemplate
|
11 |
from langchain_core.language_models.llms import LLM
|
@@ -15,7 +14,6 @@ from pydantic import BaseModel, Field
|
|
15 |
|
16 |
from ask_candid.agents.schema import AgentState
|
17 |
from ask_candid.services.org_search import OrgSearch
|
18 |
-
# from ask_candid.base.config.rest import OPENAI
|
19 |
|
20 |
search = OrgSearch()
|
21 |
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
@@ -59,7 +57,6 @@ def extract_org_links_from_chatbot(chatbot_output: str, llm: LLM):
|
|
59 |
|
60 |
try:
|
61 |
parser = JsonOutputToolsParser()
|
62 |
-
# llm = ChatOpenAI(model="gpt-4o", api_key=OPENAI["key"]).bind_tools([OrganizationNames])
|
63 |
model = llm.bind_tools([OrganizationNames])
|
64 |
prompt = ChatPromptTemplate.from_template(prompt)
|
65 |
chain = RunnableSequence(prompt, model, parser)
|
@@ -203,17 +200,33 @@ def embed_org_links_in_text(input_text: str, org_link_dict: dict):
|
|
203 |
return input_text
|
204 |
|
205 |
|
206 |
-
def has_org_name(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
"""
|
208 |
-
Processes the latest message to extract organization links and determine the next step.
|
209 |
|
210 |
-
Args:
|
211 |
-
state (AgentState): The current state of the agent, including a list of messages.
|
212 |
-
|
213 |
-
Returns:
|
214 |
-
dict: A dictionary with the next agent action and, if available, a dictionary of organization links.
|
215 |
-
"""
|
216 |
logger.info("---HAS ORG NAMES?---")
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
messages = state["messages"]
|
218 |
last_message = messages[-1].content
|
219 |
output_list = extract_org_links_from_chatbot(last_message, llm=llm)
|
|
|
1 |
+
from typing import List, Optional, Callable, Any
|
2 |
import logging
|
3 |
import re
|
4 |
|
5 |
from thefuzz import fuzz
|
6 |
|
7 |
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
|
|
|
8 |
from langchain_core.runnables import RunnableSequence
|
9 |
from langchain_core.prompts import ChatPromptTemplate
|
10 |
from langchain_core.language_models.llms import LLM
|
|
|
14 |
|
15 |
from ask_candid.agents.schema import AgentState
|
16 |
from ask_candid.services.org_search import OrgSearch
|
|
|
17 |
|
18 |
search = OrgSearch()
|
19 |
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
|
|
57 |
|
58 |
try:
|
59 |
parser = JsonOutputToolsParser()
|
|
|
60 |
model = llm.bind_tools([OrganizationNames])
|
61 |
prompt = ChatPromptTemplate.from_template(prompt)
|
62 |
chain = RunnableSequence(prompt, model, parser)
|
|
|
200 |
return input_text
|
201 |
|
202 |
|
203 |
+
def has_org_name(
|
204 |
+
state: AgentState,
|
205 |
+
llm: LLM,
|
206 |
+
user_callback: Optional[Callable[[str], Any]] = None
|
207 |
+
) -> AgentState:
|
208 |
+
"""Processes the latest message to extract organization links and determine the next step.
|
209 |
+
|
210 |
+
Parameters
|
211 |
+
----------
|
212 |
+
state : AgentState
|
213 |
+
The current state of the agent, including a list of messages.
|
214 |
+
llm : LLM
|
215 |
+
user_callback : Optional[Callable[[str], Any]], optional
|
216 |
+
Optional UI callback to inform the user of apps states, by default None
|
217 |
+
|
218 |
+
Returns
|
219 |
+
-------
|
220 |
+
AgentState
|
221 |
"""
|
|
|
222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
logger.info("---HAS ORG NAMES?---")
|
224 |
+
if user_callback is not None:
|
225 |
+
try:
|
226 |
+
user_callback("Checking for relevant organizations")
|
227 |
+
except Exception as ex:
|
228 |
+
logger.warning("User callback was passed in but failed: %s", ex)
|
229 |
+
|
230 |
messages = state["messages"]
|
231 |
last_message = messages[-1].content
|
232 |
output_list = extract_org_links_from_chatbot(last_message, llm=llm)
|
ask_candid/tools/search.py
CHANGED
@@ -1,9 +1,14 @@
|
|
1 |
-
from typing import List
|
|
|
2 |
import logging
|
3 |
|
|
|
4 |
from langchain_core.language_models.llms import LLM
|
|
|
5 |
from langchain_core.tools import Tool
|
6 |
|
|
|
|
|
7 |
from ask_candid.agents.schema import AgentState
|
8 |
|
9 |
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
@@ -11,7 +16,86 @@ logger = logging.getLogger(__name__)
|
|
11 |
logger.setLevel(logging.INFO)
|
12 |
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
"""Invokes the agent model to generate a response based on the current state. Given
|
16 |
the question, it will decide to retrieve using the retriever tool, or simply end.
|
17 |
|
|
|
1 |
+
from typing import List, Tuple, Callable, Optional, Any
|
2 |
+
from functools import partial
|
3 |
import logging
|
4 |
|
5 |
+
from pydantic import BaseModel, Field
|
6 |
from langchain_core.language_models.llms import LLM
|
7 |
+
from langchain_core.documents import Document
|
8 |
from langchain_core.tools import Tool
|
9 |
|
10 |
+
from ask_candid.retrieval.elastic import get_query_results, get_reranked_results
|
11 |
+
from ask_candid.base.config.data import DataIndices
|
12 |
from ask_candid.agents.schema import AgentState
|
13 |
|
14 |
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
|
|
16 |
logger.setLevel(logging.INFO)
|
17 |
|
18 |
|
19 |
+
class RetrieverInput(BaseModel):
|
20 |
+
"""Input to the Elasticsearch retriever."""
|
21 |
+
user_input: str = Field(description="query to look up in retriever")
|
22 |
+
|
23 |
+
|
24 |
+
def get_search_results(
|
25 |
+
user_input: str,
|
26 |
+
indices: List[DataIndices],
|
27 |
+
user_callback: Optional[Callable[[str], Any]] = None
|
28 |
+
) -> Tuple[str, List[Document]]:
|
29 |
+
"""End-to-end search and re-rank function.
|
30 |
+
|
31 |
+
Parameters
|
32 |
+
----------
|
33 |
+
user_input : str
|
34 |
+
Search context string
|
35 |
+
indices : List[DataIndices]
|
36 |
+
Semantic index names to search over
|
37 |
+
user_callback : Optional[Callable[[str], Any]], optional
|
38 |
+
Optional UI callback to inform the user of apps states, by default None
|
39 |
+
|
40 |
+
Returns
|
41 |
+
-------
|
42 |
+
Tuple[str, List[Document]]
|
43 |
+
(concatenated text from search results, documents list)
|
44 |
+
"""
|
45 |
+
|
46 |
+
if user_callback is not None:
|
47 |
+
try:
|
48 |
+
user_callback("Searching for relevant information")
|
49 |
+
except Exception as ex:
|
50 |
+
logger.warning("User callback was passed in but failed: %s", ex)
|
51 |
+
|
52 |
+
output = ["Search didn't return any Candid sources"]
|
53 |
+
page_content = []
|
54 |
+
content = "Search didn't return any Candid sources"
|
55 |
+
results = get_query_results(search_text=user_input, indices=indices)
|
56 |
+
if results:
|
57 |
+
output = get_reranked_results(results, search_text=user_input)
|
58 |
+
for doc in output:
|
59 |
+
page_content.append(doc.page_content)
|
60 |
+
content = "\n\n".join(page_content)
|
61 |
+
|
62 |
+
# for the tool we need to return a tuple for content_and_artifact type
|
63 |
+
return content, output
|
64 |
+
|
65 |
+
|
66 |
+
def retriever_tool(
|
67 |
+
indices: List[DataIndices],
|
68 |
+
user_callback: Optional[Callable[[str], Any]] = None
|
69 |
+
) -> Tool:
|
70 |
+
"""Tool component for use in conditional edge building for RAG execution graph.
|
71 |
+
Cannot use `create_retriever_tool` because it only provides content losing all metadata on the way
|
72 |
+
https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution
|
73 |
+
|
74 |
+
Parameters
|
75 |
+
----------
|
76 |
+
indices : List[DataIndices]
|
77 |
+
Semantic index names to search over
|
78 |
+
user_callback : Optional[Callable[[str], Any]], optional
|
79 |
+
Optional UI callback to inform the user of apps states, by default None
|
80 |
+
|
81 |
+
Returns
|
82 |
+
-------
|
83 |
+
Tool
|
84 |
+
"""
|
85 |
+
|
86 |
+
return Tool(
|
87 |
+
name="retrieve_social_sector_information",
|
88 |
+
func=partial(get_search_results, indices=indices, user_callback=user_callback),
|
89 |
+
description=(
|
90 |
+
"Return additional information about social and philanthropic sector, "
|
91 |
+
"including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
|
92 |
+
),
|
93 |
+
args_schema=RetrieverInput,
|
94 |
+
response_format="content_and_artifact"
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
def search_agent(state: AgentState, llm: LLM, tools: List[Tool]) -> AgentState:
|
99 |
"""Invokes the agent model to generate a response based on the current state. Given
|
100 |
the question, it will decide to retrieve using the retriever tool, or simply end.
|
101 |
|
ask_candid/utils.py
CHANGED
@@ -77,7 +77,7 @@ def format_chat_ag_response(chatbot: List[Any]) -> List[Any]:
|
|
77 |
"""
|
78 |
sources = ""
|
79 |
if chatbot:
|
80 |
-
title = chatbot[-1]
|
81 |
if title == "Sources HTML":
|
82 |
sources = chatbot[-1]["content"]
|
83 |
chatbot.pop(-1)
|
|
|
77 |
"""
|
78 |
sources = ""
|
79 |
if chatbot:
|
80 |
+
title = (chatbot[-1].get("metadata") or {}).get("title", None)
|
81 |
if title == "Sources HTML":
|
82 |
sources = chatbot[-1]["content"]
|
83 |
chatbot.pop(-1)
|