Spaces:
Running
Running
Upload 36 files
Browse filesProject re-structuring
- app.py +78 -45
- ask_candid/__init__.py +0 -0
- ask_candid/agents/__init__.py +0 -0
- ask_candid/agents/elastic.py +293 -0
- ask_candid/base/__init__.py +0 -0
- ask_candid/base/api_base.py +42 -0
- ask_candid/base/api_base_async.py +48 -0
- ask_candid/base/config/__init__.py +0 -0
- ask_candid/base/config/connections.py +36 -0
- ask_candid/base/config/data.py +20 -0
- ask_candid/base/config/models.py +9 -0
- ask_candid/base/config/rest.py +21 -0
- ask_candid/base/lambda_base.py +58 -0
- ask_candid/base/utils.py +14 -0
- ask_candid/chat.py +251 -0
- ask_candid/indexing/__init__.py +0 -0
- ask_candid/retrieval/__init__.py +0 -0
- ask_candid/retrieval/elastic.py +323 -0
- ask_candid/retrieval/sources/__init__.py +0 -0
- ask_candid/retrieval/sources/candid_blog.py +43 -0
- ask_candid/retrieval/sources/candid_help.py +41 -0
- ask_candid/retrieval/sources/candid_learning.py +41 -0
- ask_candid/retrieval/sources/issuelab.py +50 -0
- ask_candid/retrieval/sources/youtube.py +54 -0
- ask_candid/services/__init__.py +0 -0
- ask_candid/services/org_search.py +50 -0
- ask_candid/services/small_lm.py +53 -0
- ask_candid/tools/__init__.py +0 -0
- ask_candid/tools/elastic/__init__.py +0 -0
- ask_candid/tools/elastic/index_data_tool.py +59 -0
- ask_candid/tools/elastic/index_details_tool.py +73 -0
- ask_candid/tools/elastic/index_search_tool.py +102 -0
- ask_candid/tools/elastic/list_indices_tool.py +58 -0
- ask_candid/tools/org_seach.py +194 -0
- ask_candid/tools/question_reformulation.py +44 -0
- ask_candid/utils.py +103 -0
app.py
CHANGED
@@ -3,21 +3,21 @@ import os
|
|
3 |
|
4 |
import gradio as gr
|
5 |
|
|
|
|
|
6 |
from langchain_openai.chat_models import ChatOpenAI
|
|
|
|
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
except ImportError:
|
14 |
-
from .utils import format_chat_ag_response
|
15 |
-
from .retrieval.config import ALL_INDICES
|
16 |
-
from .static.css import css_chat
|
17 |
-
from .chat import run_chat
|
18 |
|
19 |
ROOT = os.path.dirname(os.path.abspath(__file__))
|
20 |
-
|
|
|
21 |
|
22 |
class LoggedComponents(TypedDict):
|
23 |
context: List[gr.components.Component]
|
@@ -27,32 +27,46 @@ class LoggedComponents(TypedDict):
|
|
27 |
email: gr.components.Component
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
def execute(
|
31 |
thread_id: str,
|
32 |
user_input: Dict[str, Any],
|
33 |
-
|
|
|
34 |
max_new_tokens: int,
|
35 |
indices: Optional[List[str]] = None,
|
36 |
):
|
37 |
-
llm = ChatOpenAI(
|
38 |
-
model_name="gpt-4o",
|
39 |
-
max_tokens=max_new_tokens,
|
40 |
-
api_key=os.getenv("OPENAI_API_KEY"),
|
41 |
-
temperature=0.0,
|
42 |
-
streaming=True
|
43 |
-
)
|
44 |
-
|
45 |
return run_chat(
|
46 |
thread_id=thread_id,
|
47 |
user_input=user_input,
|
48 |
-
|
49 |
-
llm=
|
50 |
indices=indices
|
51 |
)
|
52 |
|
53 |
|
54 |
-
def
|
55 |
-
with gr.Blocks(theme=gr.themes.Soft(), title="
|
56 |
|
57 |
gr.Markdown(
|
58 |
"""
|
@@ -74,28 +88,38 @@ def build_chat() -> Tuple[LoggedComponents, gr.Blocks]:
|
|
74 |
choices=list(ALL_INDICES),
|
75 |
value=list(ALL_INDICES),
|
76 |
label="Sources to include",
|
77 |
-
interactive=True
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
)
|
79 |
max_new_tokens = gr.Slider(
|
80 |
-
value=256 * 3,
|
81 |
-
|
|
|
|
|
|
|
|
|
82 |
)
|
83 |
|
84 |
with gr.Column():
|
85 |
chatbot = gr.Chatbot(
|
86 |
-
label="
|
87 |
elem_id="chatbot",
|
88 |
-
bubble_full_width=
|
89 |
avatar_images=(
|
90 |
None,
|
91 |
-
os.path.join(ROOT, "static", "candid_logo_yellow.png")
|
92 |
),
|
93 |
height="45vh",
|
94 |
type="messages",
|
95 |
show_label=False,
|
96 |
show_copy_button=True,
|
97 |
show_share_button=True,
|
98 |
-
show_copy_all_button=True
|
99 |
)
|
100 |
msg = gr.MultimodalTextbox(label="Your message", interactive=True)
|
101 |
thread_id = gr.Text(visible=False, value="", label="thread_id")
|
@@ -104,24 +128,33 @@ def build_chat() -> Tuple[LoggedComponents, gr.Blocks]:
|
|
104 |
# pylint: disable=no-member
|
105 |
chat_msg = msg.submit(
|
106 |
fn=execute,
|
107 |
-
inputs=[thread_id, msg, chatbot, max_new_tokens, es_indices],
|
108 |
-
outputs=[msg, chatbot, thread_id]
|
109 |
)
|
110 |
chat_msg.then(format_chat_ag_response, chatbot, chatbot, api_name="bot_response")
|
111 |
-
logged = LoggedComponents(
|
112 |
-
context=[thread_id, chatbot]
|
113 |
-
)
|
114 |
return logged, demo
|
115 |
|
116 |
|
117 |
-
|
118 |
-
_,
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
124 |
],
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
127 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
import gradio as gr
|
5 |
|
6 |
+
from langchain_core.language_models.llms import LLM
|
7 |
+
|
8 |
from langchain_openai.chat_models import ChatOpenAI
|
9 |
+
from langchain_aws import ChatBedrock
|
10 |
+
import boto3
|
11 |
|
12 |
+
from ask_candid.base.config.rest import OPENAI
|
13 |
+
from ask_candid.base.config.models import Name2Endpoint
|
14 |
+
from ask_candid.base.config.data import ALL_INDICES
|
15 |
+
from ask_candid.utils import format_chat_ag_response
|
16 |
+
from ask_candid.chat import run_chat
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
ROOT = os.path.dirname(os.path.abspath(__file__))
|
19 |
+
BUCKET = "candid-data-science-reporting"
|
20 |
+
PREFIX = "Assistant"
|
21 |
|
22 |
class LoggedComponents(TypedDict):
|
23 |
context: List[gr.components.Component]
|
|
|
27 |
email: gr.components.Component
|
28 |
|
29 |
|
30 |
+
def select_foundation_model(model_name: str, max_new_tokens: int) -> LLM:
|
31 |
+
if model_name == "gpt-4o":
|
32 |
+
llm = ChatOpenAI(
|
33 |
+
model_name=Name2Endpoint[model_name],
|
34 |
+
max_tokens=max_new_tokens,
|
35 |
+
api_key=OPENAI["key"],
|
36 |
+
temperature=0.0,
|
37 |
+
streaming=True,
|
38 |
+
)
|
39 |
+
elif model_name in {"claude-3.5-haiku", "llama-3.1-70b-instruct", "mistral-large", "mixtral-8x7B"}:
|
40 |
+
llm = ChatBedrock(
|
41 |
+
client=boto3.client("bedrock-runtime"),
|
42 |
+
model=Name2Endpoint[model_name],
|
43 |
+
max_tokens=max_new_tokens,
|
44 |
+
temperature=0.0
|
45 |
+
)
|
46 |
+
else:
|
47 |
+
raise gr.Error(f"Base model `{model_name}` is not supported")
|
48 |
+
return llm
|
49 |
+
|
50 |
+
|
51 |
def execute(
|
52 |
thread_id: str,
|
53 |
user_input: Dict[str, Any],
|
54 |
+
history: List[Dict],
|
55 |
+
model_name: str,
|
56 |
max_new_tokens: int,
|
57 |
indices: Optional[List[str]] = None,
|
58 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
return run_chat(
|
60 |
thread_id=thread_id,
|
61 |
user_input=user_input,
|
62 |
+
history=history,
|
63 |
+
llm=select_foundation_model(model_name=model_name, max_new_tokens=max_new_tokens),
|
64 |
indices=indices
|
65 |
)
|
66 |
|
67 |
|
68 |
+
def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
|
69 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="Chat") as demo:
|
70 |
|
71 |
gr.Markdown(
|
72 |
"""
|
|
|
88 |
choices=list(ALL_INDICES),
|
89 |
value=list(ALL_INDICES),
|
90 |
label="Sources to include",
|
91 |
+
interactive=True,
|
92 |
+
)
|
93 |
+
llmname = gr.Radio(
|
94 |
+
label="Language model",
|
95 |
+
value="gpt-4o",
|
96 |
+
choices=list(Name2Endpoint.keys()),
|
97 |
+
interactive=True,
|
98 |
)
|
99 |
max_new_tokens = gr.Slider(
|
100 |
+
value=256 * 3,
|
101 |
+
minimum=128,
|
102 |
+
maximum=2048,
|
103 |
+
step=128,
|
104 |
+
label="Max new tokens",
|
105 |
+
interactive=True,
|
106 |
)
|
107 |
|
108 |
with gr.Column():
|
109 |
chatbot = gr.Chatbot(
|
110 |
+
label="AskCandid",
|
111 |
elem_id="chatbot",
|
112 |
+
bubble_full_width=True,
|
113 |
avatar_images=(
|
114 |
None,
|
115 |
+
os.path.join(ROOT, "static", "candid_logo_yellow.png"),
|
116 |
),
|
117 |
height="45vh",
|
118 |
type="messages",
|
119 |
show_label=False,
|
120 |
show_copy_button=True,
|
121 |
show_share_button=True,
|
122 |
+
show_copy_all_button=True,
|
123 |
)
|
124 |
msg = gr.MultimodalTextbox(label="Your message", interactive=True)
|
125 |
thread_id = gr.Text(visible=False, value="", label="thread_id")
|
|
|
128 |
# pylint: disable=no-member
|
129 |
chat_msg = msg.submit(
|
130 |
fn=execute,
|
131 |
+
inputs=[thread_id, msg, chatbot, llmname, max_new_tokens, es_indices],
|
132 |
+
outputs=[msg, chatbot, thread_id],
|
133 |
)
|
134 |
chat_msg.then(format_chat_ag_response, chatbot, chatbot, api_name="bot_response")
|
135 |
+
logged = LoggedComponents(context=[thread_id, chatbot])
|
|
|
|
|
136 |
return logged, demo
|
137 |
|
138 |
|
139 |
+
def build_app():
|
140 |
+
_, candid_chat = build_rag_chat()
|
141 |
+
|
142 |
+
with open(os.path.join(ROOT, "static", "chatStyle.css"), "r", encoding="utf8") as f:
|
143 |
+
css_chat = f.read()
|
144 |
+
|
145 |
+
demo = gr.TabbedInterface(
|
146 |
+
interface_list=[
|
147 |
+
candid_chat,
|
148 |
],
|
149 |
+
tab_names=[
|
150 |
+
"AskCandid",
|
151 |
+
],
|
152 |
+
theme=gr.themes.Soft(),
|
153 |
+
css=css_chat,
|
154 |
)
|
155 |
+
return demo
|
156 |
+
|
157 |
+
|
158 |
+
if __name__ == "__main__":
|
159 |
+
app = build_app()
|
160 |
+
app.queue(max_size=5).launch(show_api=False)
|
ask_candid/__init__.py
ADDED
File without changes
|
ask_candid/agents/__init__.py
ADDED
File without changes
|
ask_candid/agents/elastic.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypedDict
|
2 |
+
from functools import partial
|
3 |
+
import json
|
4 |
+
import ast
|
5 |
+
|
6 |
+
from pydantic import BaseModel, Field
|
7 |
+
|
8 |
+
from langchain_openai import ChatOpenAI
|
9 |
+
|
10 |
+
from langchain_core.runnables import RunnableSequence
|
11 |
+
from langchain_core.language_models.llms import LLM
|
12 |
+
|
13 |
+
from langchain.agents.openai_functions_agent.base import create_openai_functions_agent
|
14 |
+
from langchain.agents.agent import AgentExecutor
|
15 |
+
from langchain.agents.agent_types import AgentType
|
16 |
+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
17 |
+
from langchain.output_parsers import PydanticOutputParser
|
18 |
+
from langchain.schema import BaseMessage
|
19 |
+
|
20 |
+
from langgraph.graph import StateGraph, END
|
21 |
+
|
22 |
+
from ask_candid.tools.elastic.list_indices_tool import ListIndicesTool
|
23 |
+
from ask_candid.tools.elastic.index_data_tool import IndexShowDataTool
|
24 |
+
from ask_candid.tools.elastic.index_details_tool import IndexDetailsTool
|
25 |
+
from ask_candid.tools.elastic.index_search_tool import create_search_tool
|
26 |
+
from ask_candid.base.config.rest import OPENAI
|
27 |
+
|
28 |
+
tools = [
|
29 |
+
ListIndicesTool(),
|
30 |
+
IndexShowDataTool(),
|
31 |
+
IndexDetailsTool(),
|
32 |
+
create_search_tool(),
|
33 |
+
]
|
34 |
+
|
35 |
+
|
36 |
+
class GraphState(TypedDict):
|
37 |
+
query: str = Field(
|
38 |
+
..., description="The user's query to be processed by the system."
|
39 |
+
)
|
40 |
+
agent_out: str = Field(
|
41 |
+
...,
|
42 |
+
description="The output generated by the AI agent after processing the query.",
|
43 |
+
)
|
44 |
+
next_step: str = Field(
|
45 |
+
..., description="The next step in the workflow, determined by query analysis."
|
46 |
+
)
|
47 |
+
es_query: dict = Field(
|
48 |
+
..., description="The Elasticsearch query generated or used by the agent."
|
49 |
+
)
|
50 |
+
|
51 |
+
es_result: dict = Field(
|
52 |
+
...,
|
53 |
+
description="The Elasticsearch query result generated or used by the agent.",
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
class AnalysisResult(BaseModel):
|
58 |
+
category: str = Field(..., description="Either 'general' or 'Database'")
|
59 |
+
|
60 |
+
|
61 |
+
def agent_factory() -> AgentExecutor:
|
62 |
+
"""
|
63 |
+
Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
|
64 |
+
|
65 |
+
This function initializes an OpenAI GPT-4-based LLM with specific parameters,
|
66 |
+
constructs a prompt tailored for Elasticsearch assistance, and integrates the
|
67 |
+
agent with a set of tools to handle user queries. The agent is designed to work
|
68 |
+
with OpenAI functions for enhanced capabilities.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
AgentExecutor: Configured agent ready to execute tasks with specified tools,
|
72 |
+
providing detailed intermediate steps for transparency.
|
73 |
+
"""
|
74 |
+
|
75 |
+
llm = ChatOpenAI(
|
76 |
+
model="gpt-4o", temperature=0, api_key=OPENAI["key"], streaming=False
|
77 |
+
)
|
78 |
+
|
79 |
+
tags_ = []
|
80 |
+
agent = AgentType.OPENAI_FUNCTIONS
|
81 |
+
tags_.append(agent.value if isinstance(agent, AgentType) else agent)
|
82 |
+
# Create the prompt
|
83 |
+
prompt = ChatPromptTemplate.from_messages(
|
84 |
+
[
|
85 |
+
("system", "You are a helpful elasticsearch assistant"),
|
86 |
+
MessagesPlaceholder(variable_name="chat_history", optional=True),
|
87 |
+
("human", "{input}"),
|
88 |
+
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
89 |
+
]
|
90 |
+
)
|
91 |
+
|
92 |
+
# Create the agent
|
93 |
+
agent_obj = create_openai_functions_agent(llm, tools, prompt)
|
94 |
+
|
95 |
+
return AgentExecutor.from_agent_and_tools(
|
96 |
+
agent=agent_obj,
|
97 |
+
tools=tools,
|
98 |
+
tags=tags_,
|
99 |
+
verbose=True,
|
100 |
+
return_intermediate_steps=True,
|
101 |
+
)
|
102 |
+
|
103 |
+
|
104 |
+
# define graph node functions
|
105 |
+
def general_query(state: GraphState, llm: LLM) -> GraphState:
|
106 |
+
"""
|
107 |
+
Processes a user query using an LLM and updates the graph state with the response.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
state (GraphState): Current graph state containing the user's query.
|
111 |
+
llm (LLM): Language model to process the query.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
GraphState: Updated state with the LLM's response in "agent_out".
|
115 |
+
"""
|
116 |
+
print("> General query")
|
117 |
+
prompt = ChatPromptTemplate.from_template(
|
118 |
+
"Answer based on the user's query: {query}"
|
119 |
+
)
|
120 |
+
chain = prompt | llm
|
121 |
+
response = chain.invoke({"query": state["query"]})
|
122 |
+
if isinstance(response, BaseMessage):
|
123 |
+
state["agent_out"] = response.content
|
124 |
+
else:
|
125 |
+
state["agent_out"] = str(response)
|
126 |
+
return state
|
127 |
+
|
128 |
+
|
129 |
+
def database_agent(state: GraphState) -> GraphState:
|
130 |
+
"""
|
131 |
+
Executes a database query using an Elasticsearch agent and updates the graph state.
|
132 |
+
|
133 |
+
The agent queries indices and field names in the Elasticsearch database,
|
134 |
+
selects the appropriate index (`organization_dev_2`), and answers the user's question.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
state (GraphState): Current graph state containing the user's query.
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
GraphState: Updated state with the agent's output in "agent_out" and
|
141 |
+
the Elasticsearch query in "es_query".
|
142 |
+
"""
|
143 |
+
|
144 |
+
print("> database agent")
|
145 |
+
input_data = {
|
146 |
+
"input": f"""
|
147 |
+
Make sure that you query first the indices in the ElasticSearch database.
|
148 |
+
Make sure that after querying the indices you query the field names.
|
149 |
+
To answer the question choose ```organization_dev_2``` index
|
150 |
+
|
151 |
+
Then answer this question:
|
152 |
+
{state["query"]}
|
153 |
+
"""
|
154 |
+
}
|
155 |
+
agent_exec = agent_factory()
|
156 |
+
res = agent_exec.invoke(input_data)
|
157 |
+
state["agent_out"] = res["output"]
|
158 |
+
|
159 |
+
es_queries, es_results = {}, {}
|
160 |
+
for i, action in enumerate(res.get("intermediate_steps", []), start=1):
|
161 |
+
if action[0].tool == "elastic_index_search_tool":
|
162 |
+
es_queries[f"query_{i}"] = json.loads(action[0].tool_input.get("query") or "{}")
|
163 |
+
es_results[f"query_{i}"] = ast.literal_eval(action[-1] or "{}")
|
164 |
+
|
165 |
+
# if len(res["intermediate_steps"]) > 1:
|
166 |
+
# es_queries = {
|
167 |
+
# f"query_{i}": action[0].tool_input.get("query", "")
|
168 |
+
# for i, action in enumerate(res.get("intermediate_steps", []), start=1)
|
169 |
+
# if action[0].tool == "elastic_index_search_tool"
|
170 |
+
# }
|
171 |
+
|
172 |
+
# es_results = {
|
173 |
+
# f"result_{i}": action[-1]
|
174 |
+
# for i, action in enumerate(res.get("intermediate_steps", []), start=1)
|
175 |
+
# if action[0].tool == "elastic_index_search_tool"
|
176 |
+
# }
|
177 |
+
|
178 |
+
# state["es_query"] = es_queries
|
179 |
+
# state["es_result"] = es_results
|
180 |
+
# else:
|
181 |
+
# state["es_query"] = res["intermediate_steps"][-1][0].tool_input["query"]
|
182 |
+
# state["es_result"] = {"result": res["intermediate_steps"][-2][-1]}
|
183 |
+
|
184 |
+
state["es_query"] = es_queries
|
185 |
+
state["es_result"] = es_results
|
186 |
+
return state
|
187 |
+
|
188 |
+
|
189 |
+
def analyse_query(state: GraphState, llm: LLM) -> GraphState:
|
190 |
+
"""
|
191 |
+
Analyzes the user's query to classify it as either general or database-specific
|
192 |
+
and determines the next processing step.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
state (GraphState): Current graph state containing the user's query.
|
196 |
+
llm (LLM): Language model used for query analysis.
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
GraphState: Updated state with the classification result and the
|
200 |
+
next processing step in "next_step".
|
201 |
+
"""
|
202 |
+
|
203 |
+
print("> analyse query")
|
204 |
+
prompt_template = """Your task is to analyze the query ```{query}``` and classify it in:
|
205 |
+
general: it's a basic general enquiry
|
206 |
+
Database: query which is complicated and would require to go into the database and extract specific information
|
207 |
+
Output format:
|
208 |
+
{{"category": "<your_classification>"}}
|
209 |
+
"""
|
210 |
+
|
211 |
+
# Create the prompt
|
212 |
+
prompt = ChatPromptTemplate.from_template(prompt_template)
|
213 |
+
|
214 |
+
# Define the parser
|
215 |
+
parser = PydanticOutputParser(pydantic_object=AnalysisResult)
|
216 |
+
|
217 |
+
# Create the chain
|
218 |
+
chain = RunnableSequence(prompt, llm)
|
219 |
+
# Invoke the chain with the query
|
220 |
+
response = chain.invoke({"query": state["query"]})
|
221 |
+
if "Database" in response.content:
|
222 |
+
state["next_step"] = "es_database_agent"
|
223 |
+
else:
|
224 |
+
state["next_step"] = "general_query"
|
225 |
+
return state
|
226 |
+
|
227 |
+
|
228 |
+
def final_answer(state: GraphState, llm: LLM) -> GraphState:
|
229 |
+
"""
|
230 |
+
Generates and presents the final response based on the user's query and the AI's output.
|
231 |
+
|
232 |
+
Args:
|
233 |
+
state (GraphState): Current graph state containing the query and AI output.
|
234 |
+
llm (LLM): Language model used to format the final response.
|
235 |
+
|
236 |
+
Returns:
|
237 |
+
GraphState: Updated state with the formatted final answer in "agent_out".
|
238 |
+
"""
|
239 |
+
|
240 |
+
print("> Final Answer")
|
241 |
+
prompt_template = """
|
242 |
+
Your task is to present the result based on the user's query:
|
243 |
+
|
244 |
+
Query: ```{query}```
|
245 |
+
|
246 |
+
AI Output:
|
247 |
+
```{output}```
|
248 |
+
"""
|
249 |
+
prompt = ChatPromptTemplate.from_template(prompt_template)
|
250 |
+
chain = RunnableSequence(prompt, llm)
|
251 |
+
response = chain.invoke({"query": state["query"], "output": state["agent_out"]})
|
252 |
+
|
253 |
+
return {"agent_out": response.content}
|
254 |
+
|
255 |
+
|
256 |
+
def build_compute_graph(llm: LLM) -> StateGraph:
|
257 |
+
"""
|
258 |
+
Constructs a compute graph for processing user queries using a defined workflow.
|
259 |
+
|
260 |
+
The workflow includes nodes for query analysis, handling general or database-specific queries,
|
261 |
+
and generating the final response. Conditional logic determines the path based on query type.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
llm (LLM): Language model to be used in various nodes for processing queries.
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
StateGraph: Configured compute graph ready for execution.
|
268 |
+
"""
|
269 |
+
# Create the workflow
|
270 |
+
workflow = StateGraph(GraphState)
|
271 |
+
|
272 |
+
# Add nodes
|
273 |
+
workflow.add_node("analyse", partial(analyse_query, llm=llm))
|
274 |
+
workflow.add_node("general_query", partial(general_query, llm=llm))
|
275 |
+
workflow.add_node("es_database_agent", database_agent)
|
276 |
+
workflow.add_node("final_answer", partial(final_answer, llm=llm))
|
277 |
+
|
278 |
+
# Set entry point
|
279 |
+
workflow.set_entry_point("analyse")
|
280 |
+
|
281 |
+
# Add conditional edges
|
282 |
+
workflow.add_conditional_edges(
|
283 |
+
"analyse",
|
284 |
+
lambda x: x["next_step"], # Use the return value of analyse_query directly
|
285 |
+
{"es_database_agent": "es_database_agent", "general_query": "general_query"},
|
286 |
+
)
|
287 |
+
|
288 |
+
# Add edges to end the workflow
|
289 |
+
workflow.add_edge("es_database_agent", "final_answer")
|
290 |
+
workflow.add_edge("general_query", "final_answer")
|
291 |
+
workflow.add_edge("final_answer", END)
|
292 |
+
|
293 |
+
return workflow
|
ask_candid/base/__init__.py
ADDED
File without changes
|
ask_candid/base/api_base.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Any
|
2 |
+
|
3 |
+
from urllib3.util.retry import Retry
|
4 |
+
from requests.adapters import HTTPAdapter
|
5 |
+
import requests
|
6 |
+
|
7 |
+
|
8 |
+
class BaseAPI:
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
url: str,
|
13 |
+
headers: Optional[Dict[str, Any]] = None,
|
14 |
+
total_retries: int = 3,
|
15 |
+
backoff_factor: int = 2
|
16 |
+
) -> None:
|
17 |
+
total_retries = max(total_retries, 10)
|
18 |
+
|
19 |
+
adapter = HTTPAdapter(
|
20 |
+
max_retries=Retry(
|
21 |
+
total=total_retries,
|
22 |
+
status_forcelist=[429, 500, 502, 503, 504],
|
23 |
+
allowed_methods=frozenset({"HEAD", "GET", "POST", "OPTIONS"}),
|
24 |
+
backoff_factor=backoff_factor,
|
25 |
+
)
|
26 |
+
)
|
27 |
+
self.session = requests.Session()
|
28 |
+
self.session.mount("https://", adapter)
|
29 |
+
self.session.mount("http://", adapter)
|
30 |
+
|
31 |
+
self.__url = url
|
32 |
+
self.__headers = headers
|
33 |
+
|
34 |
+
def get(self, **request_kwargs):
|
35 |
+
r = self.session.get(url=self.__url, headers=self.__headers, params=request_kwargs, timeout=30)
|
36 |
+
r.raise_for_status()
|
37 |
+
return r.json()
|
38 |
+
|
39 |
+
def post(self, payload: Dict[str, Any]):
|
40 |
+
r = self.session.post(url=self.__url, headers=self.__headers, json=payload, timeout=30)
|
41 |
+
r.raise_for_status()
|
42 |
+
return r.json()
|
ask_candid/base/api_base_async.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Any
|
2 |
+
import json
|
3 |
+
|
4 |
+
import aiohttp
|
5 |
+
|
6 |
+
|
7 |
+
class BaseAsyncAPI:
|
8 |
+
|
9 |
+
def __init__(self, url: str, headers: Optional[Dict[str, Any]] = None, retries: int = 3) -> None:
|
10 |
+
self.__url = url
|
11 |
+
self.__headers = headers
|
12 |
+
self.__retries = max(retries, 5)
|
13 |
+
|
14 |
+
async def get(self, **request_kwargs):
|
15 |
+
session_timeout = aiohttp.ClientTimeout(total=30)
|
16 |
+
async with aiohttp.ClientSession(headers=self.__headers, timeout=session_timeout) as session:
|
17 |
+
output = {}
|
18 |
+
count = 1
|
19 |
+
while True:
|
20 |
+
if count >= self.__retries:
|
21 |
+
break
|
22 |
+
async with session.get(url=self.__url, params=request_kwargs) as r:
|
23 |
+
if r.status in {429, 500, 502, 503, 504}:
|
24 |
+
count += 1
|
25 |
+
elif r.status == 200:
|
26 |
+
output = await r.json()
|
27 |
+
break
|
28 |
+
else:
|
29 |
+
break
|
30 |
+
return output
|
31 |
+
|
32 |
+
async def post(self, payload: Dict[str, Any]):
|
33 |
+
session_timeout = aiohttp.ClientTimeout(total=30)
|
34 |
+
async with aiohttp.ClientSession(headers=self.__headers, timeout=session_timeout) as session:
|
35 |
+
output = {}
|
36 |
+
count = 1
|
37 |
+
while True:
|
38 |
+
if count >= self.__retries:
|
39 |
+
break
|
40 |
+
async with session.post(url=self.__url, data=json.dumps(payload).encode('utf8')) as r:
|
41 |
+
if r.status in {429, 500, 502, 503, 504}:
|
42 |
+
count += 1
|
43 |
+
elif r.status == 200:
|
44 |
+
output = await r.json()
|
45 |
+
break
|
46 |
+
else:
|
47 |
+
break
|
48 |
+
return output
|
ask_candid/base/config/__init__.py
ADDED
File without changes
|
ask_candid/base/config/connections.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
from dotenv import dotenv_values, find_dotenv
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class BaseElasticSearchConnection:
|
8 |
+
"""Elasticsearch connection dataclass
|
9 |
+
"""
|
10 |
+
url: str = field(default_factory=str)
|
11 |
+
username: str = field(default_factory=str)
|
12 |
+
password: str = field(default_factory=str)
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class BaseElasticAPIKeyCredential:
|
17 |
+
"""Cloud ID/API key data class
|
18 |
+
"""
|
19 |
+
cloud_id: str = field(default_factory=str)
|
20 |
+
api_key: str = field(default_factory=str)
|
21 |
+
|
22 |
+
|
23 |
+
__env_values__ = dotenv_values(
|
24 |
+
dotenv_path=find_dotenv(".env", raise_error_if_not_found=True)
|
25 |
+
)
|
26 |
+
|
27 |
+
SEMANTIC_ELASTIC_QA = BaseElasticAPIKeyCredential(
|
28 |
+
cloud_id=__env_values__.get("SEMANTIC_ELASTIC_CLOUD_ID"),
|
29 |
+
api_key=__env_values__.get("SEMANTIC_ELASTIC_API_KEY"),
|
30 |
+
)
|
31 |
+
|
32 |
+
CDS_ELASTIC = BaseElasticSearchConnection(
|
33 |
+
url="https://cdses.candid.org:9200",
|
34 |
+
username=__env_values__.get("CDS_UID"),
|
35 |
+
password=__env_values__.get("CDS_PWD")
|
36 |
+
)
|
ask_candid/base/config/data.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ElasticIndexMapping:
|
2 |
+
"Mapping from plain name to Elasticsearch index name"
|
3 |
+
|
4 |
+
ISSUELAB_INDEX = "search-semantic-issuelab_v1"
|
5 |
+
ISSUELAB_INDEX_ELSER = "search-semantic-issuelab-elser_ve2"
|
6 |
+
YOUTUBE_INDEX = "search-semantic-youtube_v1"
|
7 |
+
YOUTUBE_INDEX_ELSER = "search-semantic-youtube-elser_ve1"
|
8 |
+
CANDID_BLOG_INDEX = "search-semantic-candid-blog_v1"
|
9 |
+
CANDID_BLOG_INDEX_ELSER = "search-semantic-candid-blog-elser_ve2"
|
10 |
+
CANDID_LEARNING_INDEX_ELSER = "search-semantic-candid-learning_ve1"
|
11 |
+
CANDID_HELP_INDEX_ELSER = "search-semantic-candid-help-elser_ve1"
|
12 |
+
|
13 |
+
|
14 |
+
ALL_INDICES = (
|
15 |
+
"issuelab",
|
16 |
+
"youtube",
|
17 |
+
"candid_blog",
|
18 |
+
"candid_learning",
|
19 |
+
"candid_help"
|
20 |
+
)
|
ask_candid/base/config/models.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from types import MappingProxyType
|
2 |
+
|
3 |
+
Name2Endpoint = MappingProxyType({
|
4 |
+
"gpt-4o": "gpt-4o",
|
5 |
+
"claude-3.5-haiku": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
6 |
+
# "llama-3.1-70b-instruct": "us.meta.llama3-1-70b-instruct-v1:0",
|
7 |
+
# "mistral-large": "mistral.mistral-large-2402-v1:0",
|
8 |
+
# "mixtral-8x7B": "mistral.mixtral-8x7b-instruct-v0:1",
|
9 |
+
})
|
ask_candid/base/config/rest.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypedDict
|
2 |
+
|
3 |
+
from dotenv import dotenv_values, find_dotenv
|
4 |
+
|
5 |
+
|
6 |
+
class Api(TypedDict):
|
7 |
+
"""REST API configuration template
|
8 |
+
"""
|
9 |
+
url: str
|
10 |
+
key: str
|
11 |
+
|
12 |
+
__env_values__ = dotenv_values(
|
13 |
+
dotenv_path=find_dotenv(".env", raise_error_if_not_found=True)
|
14 |
+
)
|
15 |
+
|
16 |
+
CDS_API = Api(
|
17 |
+
url=__env_values__.get("CDS_API_URL"),
|
18 |
+
key=__env_values__.get("CDS_API_KEY")
|
19 |
+
)
|
20 |
+
|
21 |
+
OPENAI = Api(url=None, key=__env_values__.get("OPENAI_API_KEY"))
|
ask_candid/base/lambda_base.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Union, Optional, Any
|
2 |
+
from time import sleep
|
3 |
+
import json
|
4 |
+
|
5 |
+
import boto3
|
6 |
+
|
7 |
+
|
8 |
+
class LambdaInvokeBase:
|
9 |
+
"""Base class for AWS Lambda direct-invocation based classes. Each class which inherits from this only serves a
|
10 |
+
single function.
|
11 |
+
|
12 |
+
Parameters
|
13 |
+
----------
|
14 |
+
function_name : str
|
15 |
+
Name of the Lambda function to invoke
|
16 |
+
access_key : Optional[str], optional
|
17 |
+
AWS access key, by default None
|
18 |
+
secret_key : Optional[str], optional
|
19 |
+
AWS secret key, by default None
|
20 |
+
"""
|
21 |
+
|
22 |
+
errors = frozenset([
|
23 |
+
"Unhandled"
|
24 |
+
])
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self, function_name: str,
|
28 |
+
access_key: Optional[str] = None, secret_key: Optional[str] = None,
|
29 |
+
) -> None:
|
30 |
+
if access_key is not None and secret_key is not None:
|
31 |
+
self._client = boto3.client(
|
32 |
+
"lambda",
|
33 |
+
aws_access_key_id=access_key,
|
34 |
+
aws_secret_access_key=secret_key,
|
35 |
+
region_name="us-east-1",
|
36 |
+
)
|
37 |
+
else:
|
38 |
+
self._client = boto3.client("lambda", region_name='us-east-1')
|
39 |
+
|
40 |
+
self.function_name = function_name
|
41 |
+
|
42 |
+
def _submit_request(self, payload: Dict[str, Any]) -> Union[Dict[str, Any], List[Any]]:
|
43 |
+
response = self._client.invoke(
|
44 |
+
FunctionName=self.function_name,
|
45 |
+
InvocationType="RequestResponse",
|
46 |
+
Payload=json.dumps(payload),
|
47 |
+
)
|
48 |
+
|
49 |
+
if response.get("FunctionError") in self.errors:
|
50 |
+
# could use recursion, but we need to keep track of number of function calls
|
51 |
+
sleep(1)
|
52 |
+
response = self._client.invoke(
|
53 |
+
FunctionName=self.function_name,
|
54 |
+
InvocationType="RequestResponse",
|
55 |
+
Payload=json.dumps(payload),
|
56 |
+
)
|
57 |
+
|
58 |
+
return json.loads(response["Payload"].read())
|
ask_candid/base/utils.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
|
3 |
+
|
4 |
+
def async_tasks(*tasks):
|
5 |
+
|
6 |
+
async def gather(*t):
|
7 |
+
t = [await _ for _ in t]
|
8 |
+
return await asyncio.gather(*t)
|
9 |
+
|
10 |
+
loop = asyncio.new_event_loop()
|
11 |
+
results = loop.run_until_complete(gather(*tasks))
|
12 |
+
loop.stop()
|
13 |
+
loop.close()
|
14 |
+
return results
|
ask_candid/chat.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Dict, Any, TypedDict, Annotated, Sequence
|
2 |
+
from functools import partial
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
from langchain_core.messages import AIMessage, BaseMessage
|
9 |
+
from langchain_core.output_parsers import StrOutputParser
|
10 |
+
from langchain_core.prompts import ChatPromptTemplate
|
11 |
+
from langchain_core.language_models.llms import LLM
|
12 |
+
|
13 |
+
from langgraph.prebuilt import tools_condition, ToolNode
|
14 |
+
from langgraph.checkpoint.memory import MemorySaver
|
15 |
+
from langgraph.graph.state import StateGraph
|
16 |
+
from langgraph.graph.message import add_messages
|
17 |
+
from langgraph.constants import START, END
|
18 |
+
|
19 |
+
from ask_candid.tools.org_seach import extract_org_links_from_chatbot, embed_org_links_in_text, generate_org_link_dict
|
20 |
+
from ask_candid.tools.question_reformulation import reformulate_question_using_history
|
21 |
+
from ask_candid.utils import html_format_docs_chat, get_session_id
|
22 |
+
from ask_candid.retrieval.elastic import retriever_tool
|
23 |
+
|
24 |
+
ROOT = os.path.dirname(os.path.abspath(__file__))
|
25 |
+
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
logger.setLevel(logging.INFO)
|
28 |
+
|
29 |
+
# TODO https://www.metadocs.co/2024/08/29/simple-domain-specific-corrective-rag-with-langchain-and-langgraph/
|
30 |
+
|
31 |
+
|
32 |
+
class AgentState(TypedDict):
|
33 |
+
# The add_messages function defines how an update should be processed
|
34 |
+
# Default is to replace. add_messages says "append"
|
35 |
+
messages: Annotated[Sequence[BaseMessage], add_messages]
|
36 |
+
user_input: str
|
37 |
+
org_dict: Dict
|
38 |
+
|
39 |
+
|
40 |
+
def search_agent(state, llm: LLM, tools) -> AgentState:
|
41 |
+
"""Invokes the agent model to generate a response based on the current state. Given
|
42 |
+
the question, it will decide to retrieve using the retriever tool, or simply end.
|
43 |
+
|
44 |
+
Parameters
|
45 |
+
----------
|
46 |
+
state : _type_
|
47 |
+
The current state
|
48 |
+
llm : LLM
|
49 |
+
tools : _type_
|
50 |
+
_description_
|
51 |
+
|
52 |
+
Returns
|
53 |
+
-------
|
54 |
+
AgentState
|
55 |
+
The updated state with the agent response appended to messages
|
56 |
+
"""
|
57 |
+
|
58 |
+
logger.info("---SEARCH AGENT---")
|
59 |
+
messages = state["messages"]
|
60 |
+
question = messages[-1].content
|
61 |
+
|
62 |
+
model = llm.bind_tools(tools)
|
63 |
+
response = model.invoke(messages)
|
64 |
+
# return a list, because this will get added to the existing list
|
65 |
+
return {"messages": [response], "user_input": question}
|
66 |
+
|
67 |
+
|
68 |
+
def generate_with_context(state, llm: LLM) -> AgentState:
|
69 |
+
"""Generate answer.
|
70 |
+
|
71 |
+
Parameters
|
72 |
+
----------
|
73 |
+
state : _type_
|
74 |
+
The current state
|
75 |
+
llm : LLM
|
76 |
+
tools : _type_
|
77 |
+
_description_
|
78 |
+
|
79 |
+
Returns
|
80 |
+
-------
|
81 |
+
AgentState
|
82 |
+
The updated state with the agent response appended to messages
|
83 |
+
"""
|
84 |
+
|
85 |
+
logger.info("---GENERATE ANSWER---")
|
86 |
+
messages = state["messages"]
|
87 |
+
question = state["user_input"]
|
88 |
+
last_message = messages[-1]
|
89 |
+
|
90 |
+
sources_str = last_message.content
|
91 |
+
sources_list = last_message.artifact # cannot use directly as list of Documents
|
92 |
+
# converting to html string
|
93 |
+
sources_html = html_format_docs_chat(sources_list)
|
94 |
+
if sources_list:
|
95 |
+
logger.info("---ADD SOURCES---")
|
96 |
+
state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
|
97 |
+
|
98 |
+
# Prompt
|
99 |
+
qa_system_prompt = """
|
100 |
+
You are an assistant for question-answering tasks in the social and philanthropic sector. \n
|
101 |
+
Use the following pieces of retrieved context to answer the question at the end. \n
|
102 |
+
If you don't know the answer, just say that you don't know. \n
|
103 |
+
Keep the response professional, friendly, and as concise as possible. \n
|
104 |
+
Question: {question}
|
105 |
+
Context: {context}
|
106 |
+
Answer:
|
107 |
+
"""
|
108 |
+
|
109 |
+
qa_prompt = ChatPromptTemplate(
|
110 |
+
[
|
111 |
+
("system", qa_system_prompt),
|
112 |
+
("human", question),
|
113 |
+
]
|
114 |
+
)
|
115 |
+
|
116 |
+
rag_chain = qa_prompt | llm | StrOutputParser()
|
117 |
+
response = rag_chain.invoke({"context": sources_str, "question": question})
|
118 |
+
# couldn't figure out why returning usual "response" was seen as HumanMessage
|
119 |
+
return {"messages": [AIMessage(content=response)], "user_input": question}
|
120 |
+
|
121 |
+
|
122 |
+
def has_org_name(state: AgentState) -> AgentState:
|
123 |
+
"""
|
124 |
+
Processes the latest message to extract organization links and determine the next step.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
state (AgentState): The current state of the agent, including a list of messages.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
dict: A dictionary with the next agent action and, if available, a dictionary of organization links.
|
131 |
+
"""
|
132 |
+
logger.info("---HAS ORG NAMES?---")
|
133 |
+
messages = state["messages"]
|
134 |
+
last_message = messages[-1].content
|
135 |
+
output_list = extract_org_links_from_chatbot(last_message)
|
136 |
+
link_dict = generate_org_link_dict(output_list) if output_list else {}
|
137 |
+
if link_dict:
|
138 |
+
logger.info("---FOUND ORG NAMES---")
|
139 |
+
return {"next": "insert_org_link", "org_dict": link_dict}
|
140 |
+
logger.info("---NO ORG NAMES FOUND---")
|
141 |
+
return {"next": END, "messages": messages}
|
142 |
+
|
143 |
+
|
144 |
+
def insert_org_link(state: AgentState) -> AgentState:
|
145 |
+
"""
|
146 |
+
Embeds organization links in the latest message content and returns it as an AI message.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
state (dict): The current state, including the organization links and latest message.
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
dict: A dictionary with the updated message content as an AIMessage.
|
153 |
+
"""
|
154 |
+
logger.info("---INSERT ORG LINKS---")
|
155 |
+
messages = state["messages"]
|
156 |
+
last_message = messages[-1].content
|
157 |
+
messages.pop(-1) # Deleting the original message because we will append the same one but with links
|
158 |
+
link_dict = state["org_dict"]
|
159 |
+
last_message = embed_org_links_in_text(last_message, link_dict)
|
160 |
+
return {"messages": [AIMessage(content=last_message)]}
|
161 |
+
|
162 |
+
|
163 |
+
def build_compute_graph(llm: LLM, indices: List[str]) -> StateGraph:
|
164 |
+
candid_retriever_tool = retriever_tool(indices=indices)
|
165 |
+
retrieve = ToolNode([candid_retriever_tool])
|
166 |
+
tools = [candid_retriever_tool]
|
167 |
+
|
168 |
+
G = StateGraph(AgentState)
|
169 |
+
# Add nodes
|
170 |
+
G.add_node("reformulate", partial(reformulate_question_using_history, llm=llm))
|
171 |
+
G.add_node("search_agent", partial(search_agent, llm=llm, tools=tools))
|
172 |
+
G.add_node("retrieve", retrieve)
|
173 |
+
G.add_node("generate_with_context", partial(generate_with_context, llm=llm))
|
174 |
+
G.add_node("has_org_name", has_org_name)
|
175 |
+
G.add_node("insert_org_link", insert_org_link)
|
176 |
+
|
177 |
+
# Add edges
|
178 |
+
G.add_edge(START, "reformulate")
|
179 |
+
G.add_edge("reformulate", "search_agent")
|
180 |
+
# Conditional edges from search_agent
|
181 |
+
G.add_conditional_edges(
|
182 |
+
source="search_agent",
|
183 |
+
path=tools_condition,
|
184 |
+
path_map={
|
185 |
+
"tools": "retrieve",
|
186 |
+
END: "has_org_name",
|
187 |
+
},
|
188 |
+
)
|
189 |
+
G.add_edge("retrieve", "generate_with_context")
|
190 |
+
|
191 |
+
# Add edges
|
192 |
+
G.add_edge("generate_with_context", "has_org_name")
|
193 |
+
# Use add_conditional_edges for has_org_name
|
194 |
+
G.add_conditional_edges(
|
195 |
+
"has_org_name",
|
196 |
+
lambda x: x["next"], # Now we're accessing the 'next' key from the dict
|
197 |
+
{"insert_org_link": "insert_org_link", END: END},
|
198 |
+
)
|
199 |
+
G.add_edge("insert_org_link", END)
|
200 |
+
|
201 |
+
return G
|
202 |
+
|
203 |
+
|
204 |
+
def run_chat(
|
205 |
+
thread_id: str,
|
206 |
+
user_input: Dict[str, Any],
|
207 |
+
history: List[Dict],
|
208 |
+
llm: LLM,
|
209 |
+
indices: Optional[List[str]] = None,
|
210 |
+
):
|
211 |
+
# https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/#graph
|
212 |
+
|
213 |
+
if len(history) == 0:
|
214 |
+
history.append({
|
215 |
+
"role": "system",
|
216 |
+
"content": (
|
217 |
+
"You are a Candid subject matter expert on the social sector and philanthropy. "
|
218 |
+
"You should address the user's queries and stay on topic."
|
219 |
+
)
|
220 |
+
})
|
221 |
+
|
222 |
+
history.append({"role": "user", "content": user_input["text"]})
|
223 |
+
inputs = {"messages": history}
|
224 |
+
# thread_id can be an email https://github.com/yurisasc/memory-enhanced-ai-assistant/blob/main/assistant.py
|
225 |
+
thread_id = get_session_id(thread_id)
|
226 |
+
config = {"configurable": {"thread_id": thread_id}}
|
227 |
+
|
228 |
+
workflow = build_compute_graph(llm=llm, indices=indices)
|
229 |
+
|
230 |
+
memory = MemorySaver() # TODO: don't use for Prod
|
231 |
+
graph = workflow.compile(checkpointer=memory)
|
232 |
+
response = graph.invoke(inputs, config=config)
|
233 |
+
messages = response["messages"]
|
234 |
+
last_message = messages[-1]
|
235 |
+
ai_answer = last_message.content
|
236 |
+
sources_html = ""
|
237 |
+
for message in messages[-2:]:
|
238 |
+
if message.type == "HTML":
|
239 |
+
sources_html = message.content
|
240 |
+
|
241 |
+
history.append({"role": "assistant", "content": ai_answer})
|
242 |
+
if sources_html:
|
243 |
+
history.append(
|
244 |
+
{
|
245 |
+
"role": "assistant",
|
246 |
+
"content": sources_html,
|
247 |
+
"metadata": {"title": "Sources HTML"},
|
248 |
+
}
|
249 |
+
)
|
250 |
+
|
251 |
+
return gr.MultimodalTextbox(value=None, interactive=True), history, thread_id
|
ask_candid/indexing/__init__.py
ADDED
File without changes
|
ask_candid/retrieval/__init__.py
ADDED
File without changes
|
ask_candid/retrieval/elastic.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Any
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from functools import partial
|
4 |
+
from itertools import groupby
|
5 |
+
|
6 |
+
from pydantic import BaseModel, Field
|
7 |
+
from langchain_core.documents import Document
|
8 |
+
from langchain_core.tools import Tool
|
9 |
+
|
10 |
+
from elasticsearch import Elasticsearch
|
11 |
+
|
12 |
+
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
|
13 |
+
from ask_candid.base.config.data import ElasticIndexMapping, ALL_INDICES
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class ElasticHitsResult:
|
18 |
+
"""Dataclass for Elasticsearch hits results
|
19 |
+
"""
|
20 |
+
index: str
|
21 |
+
id: Any
|
22 |
+
score: float
|
23 |
+
source: Dict[str, Any]
|
24 |
+
inner_hits: Dict[str, Any]
|
25 |
+
|
26 |
+
|
27 |
+
class RetrieverInput(BaseModel):
|
28 |
+
"""Input to the Elasticsearch retriever."""
|
29 |
+
user_input: str = Field(description="query to look up in retriever")
|
30 |
+
|
31 |
+
|
32 |
+
def build_text_expansion_query(
|
33 |
+
query: str,
|
34 |
+
fields: Tuple[str],
|
35 |
+
model_id: str = ".elser_model_2_linux-x86_64"
|
36 |
+
) -> Dict[str, Any]:
|
37 |
+
|
38 |
+
output = []
|
39 |
+
|
40 |
+
for f in fields:
|
41 |
+
output.append({
|
42 |
+
"nested": {
|
43 |
+
"path": f"embeddings.{f}.chunks",
|
44 |
+
"query": {
|
45 |
+
"text_expansion": {
|
46 |
+
f"embeddings.{f}.chunks.vector": {
|
47 |
+
"model_id": model_id,
|
48 |
+
"model_text": query,
|
49 |
+
"boost": 1 / len(fields)
|
50 |
+
}
|
51 |
+
}
|
52 |
+
},
|
53 |
+
"inner_hits": {
|
54 |
+
"_source": False,
|
55 |
+
"size": 2,
|
56 |
+
"fields": [f"embeddings.{f}.chunks.chunk"]
|
57 |
+
}
|
58 |
+
}
|
59 |
+
})
|
60 |
+
return {"query": {"bool": {"should": output}}}
|
61 |
+
|
62 |
+
|
63 |
+
def query_builder(query: str, indices: List[str]):
|
64 |
+
queries = []
|
65 |
+
if indices is None:
|
66 |
+
indices = list(ALL_INDICES)
|
67 |
+
|
68 |
+
for index in indices:
|
69 |
+
if index == "issuelab":
|
70 |
+
q = build_text_expansion_query(
|
71 |
+
query=query,
|
72 |
+
fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
|
73 |
+
)
|
74 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
75 |
+
q["size"] = 1
|
76 |
+
queries.extend([{"index": ElasticIndexMapping.ISSUELAB_INDEX_ELSER}, q])
|
77 |
+
elif index == "youtube":
|
78 |
+
q = build_text_expansion_query(
|
79 |
+
query=query,
|
80 |
+
fields=("captions_cleaned", "description_cleaned", "title")
|
81 |
+
)
|
82 |
+
# text_cleaned duplicates captions_cleaned
|
83 |
+
q["_source"] = {"excludes": ["embeddings", "captions", "description", "text_cleaned"]}
|
84 |
+
q["size"] = 2
|
85 |
+
queries.extend([{"index": ElasticIndexMapping.YOUTUBE_INDEX_ELSER}, q])
|
86 |
+
elif index == "candid_blog":
|
87 |
+
q = build_text_expansion_query(
|
88 |
+
query=query,
|
89 |
+
fields=("content", "title")
|
90 |
+
)
|
91 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
92 |
+
q["size"] = 2
|
93 |
+
queries.extend([{"index": ElasticIndexMapping.CANDID_BLOG_INDEX_ELSER}, q])
|
94 |
+
elif index == "candid_learning":
|
95 |
+
q = build_text_expansion_query(
|
96 |
+
query=query,
|
97 |
+
fields=("content", "title", "training_topics", "staff_recommendations")
|
98 |
+
)
|
99 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
100 |
+
q["size"] = 2
|
101 |
+
queries.extend([{"index": ElasticIndexMapping.CANDID_LEARNING_INDEX_ELSER}, q])
|
102 |
+
elif index == "candid_help":
|
103 |
+
q = build_text_expansion_query(
|
104 |
+
query=query,
|
105 |
+
fields=("content", "combined_article_description")
|
106 |
+
)
|
107 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
108 |
+
q["size"] = 2
|
109 |
+
queries.extend([{"index": ElasticIndexMapping.CANDID_HELP_INDEX_ELSER}, q])
|
110 |
+
|
111 |
+
return queries
|
112 |
+
|
113 |
+
|
114 |
+
def multi_search(queries: List[ElasticHitsResult]):
|
115 |
+
results = []
|
116 |
+
with Elasticsearch(
|
117 |
+
cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
|
118 |
+
api_key=SEMANTIC_ELASTIC_QA.api_key,
|
119 |
+
verify_certs=False,
|
120 |
+
request_timeout=60 * 3
|
121 |
+
) as es:
|
122 |
+
for query_group in es.msearch(body=queries).get("responses", []):
|
123 |
+
for hit in query_group.get("hits", {}).get("hits", []):
|
124 |
+
hit = ElasticHitsResult(
|
125 |
+
index=hit["_index"],
|
126 |
+
id=hit["_id"],
|
127 |
+
score=hit["_score"],
|
128 |
+
source=hit["_source"],
|
129 |
+
inner_hits=hit.get("inner_hits", {})
|
130 |
+
)
|
131 |
+
results.append(hit)
|
132 |
+
return results
|
133 |
+
|
134 |
+
|
135 |
+
def get_query_results(search_text: str, indices: Optional[List[str]] = None):
|
136 |
+
queries = query_builder(query=search_text, indices=indices)
|
137 |
+
return multi_search(queries)
|
138 |
+
|
139 |
+
|
140 |
+
def reranker(query_results: Iterable[ElasticHitsResult]) -> Iterator[ElasticHitsResult]:
|
141 |
+
"""Reranks Elasticsearch hits coming from multiple indicies/queries which may have scores on different scales.
|
142 |
+
This will shuffle results
|
143 |
+
|
144 |
+
Parameters
|
145 |
+
----------
|
146 |
+
query_results : Iterable[ElasticHitsResult]
|
147 |
+
|
148 |
+
Yields
|
149 |
+
------
|
150 |
+
Iterator[ElasticHitsResult]
|
151 |
+
"""
|
152 |
+
|
153 |
+
results: List[ElasticHitsResult] = []
|
154 |
+
for _, data in groupby(query_results, key=lambda x: x.index):
|
155 |
+
data = list(data)
|
156 |
+
max_score = max(data, key=lambda x: x.score).score
|
157 |
+
min_score = min(data, key=lambda x: x.score).score
|
158 |
+
|
159 |
+
for d in data:
|
160 |
+
d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
|
161 |
+
results.append(d)
|
162 |
+
|
163 |
+
yield from sorted(results, key=lambda x: x.score, reverse=True)
|
164 |
+
|
165 |
+
|
166 |
+
def get_results(user_input: str, indices: List[str]) -> List[ElasticHitsResult]:
|
167 |
+
output = ["Search didn't return any Candid sources"]
|
168 |
+
page_content=[]
|
169 |
+
content = "Search didn't return any Candid sources"
|
170 |
+
results = get_query_results(search_text=user_input, indices=indices)
|
171 |
+
if results:
|
172 |
+
output = get_reranked_results(results)
|
173 |
+
for doc in output:
|
174 |
+
page_content.append(doc.page_content)
|
175 |
+
content = "/n/n".join(page_content)
|
176 |
+
# for the tool we need to return a tuple for content_and_artifact type
|
177 |
+
return content, output
|
178 |
+
|
179 |
+
|
180 |
+
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024) -> str:
|
181 |
+
"""Pads the relevant chunk of text with context before and after
|
182 |
+
|
183 |
+
Parameters
|
184 |
+
----------
|
185 |
+
field_name : str
|
186 |
+
a field with the long text that was chunked into pieces
|
187 |
+
hit : ElasticHitsResult
|
188 |
+
context_length : int, optional
|
189 |
+
length of text to add before and after the chunk, by default 1024
|
190 |
+
|
191 |
+
Returns
|
192 |
+
-------
|
193 |
+
str
|
194 |
+
longer chunks stuffed together
|
195 |
+
"""
|
196 |
+
|
197 |
+
chunks_with_context = []
|
198 |
+
long_text = hit.source.get(f"{field_name}", "")
|
199 |
+
inner_hits_field = f"embeddings.{field_name}.chunks"
|
200 |
+
inner_hits = hit.inner_hits
|
201 |
+
found_chunks = inner_hits.get(inner_hits_field, {})
|
202 |
+
if found_chunks:
|
203 |
+
hits = found_chunks.get("hits", {}).get("hits", [])
|
204 |
+
for h in hits:
|
205 |
+
chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
|
206 |
+
chunk = chunk[3:-3] # cutting the middle because we may have tokenizing artefacts there
|
207 |
+
# Find the start and end indices of the chunk in the large text
|
208 |
+
start_index = long_text.find(chunk)
|
209 |
+
if start_index != -1: # Chunk is found
|
210 |
+
end_index = start_index + len(chunk)
|
211 |
+
pre_start_index = max(0, start_index - context_length)
|
212 |
+
post_end_index = min(len(long_text), end_index + context_length)
|
213 |
+
context = long_text[pre_start_index:post_end_index]
|
214 |
+
chunks_with_context.append(context)
|
215 |
+
chunks_with_context_txt = '\n\n'.join(chunks_with_context)
|
216 |
+
|
217 |
+
return chunks_with_context_txt
|
218 |
+
|
219 |
+
|
220 |
+
def process_hit(hit: ElasticHitsResult) -> Document | None:
|
221 |
+
if "issuelab-elser" in hit.index:
|
222 |
+
combined_item_description = hit.source.get("combined_item_description", "") # title inside
|
223 |
+
description = hit.source.get("description", "")
|
224 |
+
combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
|
225 |
+
# we only need to process long texts
|
226 |
+
chunks_with_context_txt = get_context("content", hit, context_length=12)
|
227 |
+
doc = Document(
|
228 |
+
page_content='\n\n'.join([
|
229 |
+
combined_item_description,
|
230 |
+
combined_issuelab_findings,
|
231 |
+
description,
|
232 |
+
chunks_with_context_txt
|
233 |
+
]),
|
234 |
+
metadata={
|
235 |
+
"title": hit.source["title"],
|
236 |
+
"source": "IssueLab",
|
237 |
+
"source_id": hit.source["resource_id"],
|
238 |
+
"url": hit.source.get("permalink", "")
|
239 |
+
}
|
240 |
+
)
|
241 |
+
elif "youtube" in hit.index:
|
242 |
+
title = hit.source.get("title", "")
|
243 |
+
# we only need to process long texts
|
244 |
+
description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
|
245 |
+
captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
|
246 |
+
doc = Document(
|
247 |
+
page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
|
248 |
+
metadata={
|
249 |
+
"title": title,
|
250 |
+
"source": "Candid YouTube",
|
251 |
+
"source_id": hit.source['video_id'],
|
252 |
+
"url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
|
253 |
+
}
|
254 |
+
)
|
255 |
+
elif "candid-blog" in hit.index:
|
256 |
+
excerpt = hit.source.get("excerpt", "")
|
257 |
+
title = hit.source.get("title", "")
|
258 |
+
# we only need to process long texts
|
259 |
+
content_with_context_txt = get_context("content", hit, context_length=12)
|
260 |
+
doc = Document(
|
261 |
+
page_content='\n\n'.join([title, excerpt, content_with_context_txt]),
|
262 |
+
metadata={
|
263 |
+
"title": title,
|
264 |
+
"source": "Candid Blog",
|
265 |
+
"source_id": hit.source["id"],
|
266 |
+
"url": hit.source["link"]
|
267 |
+
}
|
268 |
+
)
|
269 |
+
elif "candid-learning" in hit.index:
|
270 |
+
title = hit.source.get("title", "")
|
271 |
+
content_with_context_txt = get_context("content", hit, context_length=12)
|
272 |
+
training_topics = hit.source.get("training_topics", "")
|
273 |
+
staff_recommendations = hit.source.get("staff_recommendations", "")
|
274 |
+
|
275 |
+
doc = Document(
|
276 |
+
page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
|
277 |
+
metadata={
|
278 |
+
"title": hit.source["title"],
|
279 |
+
"source": "Candid Learning",
|
280 |
+
"source_id": hit.source["post_id"],
|
281 |
+
"url": hit.source.get("url", "")
|
282 |
+
}
|
283 |
+
)
|
284 |
+
elif "candid-help" in hit.index:
|
285 |
+
title = hit.source.get("title", "")
|
286 |
+
content_with_context_txt = get_context("content", hit, context_length=12)
|
287 |
+
combined_article_description = hit.source.get("combined_article_description", "")
|
288 |
+
|
289 |
+
doc = Document(
|
290 |
+
page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
|
291 |
+
metadata={
|
292 |
+
"title": title,
|
293 |
+
"source": "Candid Help",
|
294 |
+
"source_id": hit.source["id"],
|
295 |
+
"url": hit.source.get("link", "")
|
296 |
+
}
|
297 |
+
)
|
298 |
+
else:
|
299 |
+
doc = None
|
300 |
+
return doc
|
301 |
+
|
302 |
+
|
303 |
+
def get_reranked_results(results: List[ElasticHitsResult]) -> List[Document]:
|
304 |
+
output = []
|
305 |
+
for r in reranker(results):
|
306 |
+
hit = process_hit(r)
|
307 |
+
output.append(hit)
|
308 |
+
return output
|
309 |
+
|
310 |
+
|
311 |
+
def retriever_tool(indices: List[str]) -> Tool:
|
312 |
+
# cannot use create_retriever_tool because it only provides content losing all metadata on the way
|
313 |
+
# https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution
|
314 |
+
return Tool(
|
315 |
+
name="retrieve_social_sector_information",
|
316 |
+
func=partial(get_results, indices=indices),
|
317 |
+
description=(
|
318 |
+
"Return additional information about social and philanthropic sector, "
|
319 |
+
"including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
|
320 |
+
),
|
321 |
+
args_schema=RetrieverInput,
|
322 |
+
response_format="content_and_artifact"
|
323 |
+
)
|
ask_candid/retrieval/sources/__init__.py
ADDED
File without changes
|
ask_candid/retrieval/sources/candid_blog.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any
|
2 |
+
|
3 |
+
|
4 |
+
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
5 |
+
url = f"{doc['link']}"
|
6 |
+
fields = ["title", "excerpt"]
|
7 |
+
|
8 |
+
fields_dict = {}
|
9 |
+
fields_len = 0
|
10 |
+
for field in fields:
|
11 |
+
if doc.get(field, None) is not None:
|
12 |
+
fields_dict[field] = doc[field]
|
13 |
+
fields_dict[field + "_txt"] = f"<div>{doc[field]}</div>"
|
14 |
+
|
15 |
+
if (fields_len + len(doc[field])) > 999:
|
16 |
+
rest_text_len = 999 - fields_len
|
17 |
+
if rest_text_len > 0:
|
18 |
+
fields_dict[field + "_txt"] = f"<div>{doc[field][:rest_text_len] + '[...]'}</div>"
|
19 |
+
else: fields_dict[field + "_txt"] = f"<span>{'[...]'}</span>"
|
20 |
+
fields_len = fields_len + len(doc[field])
|
21 |
+
else:
|
22 |
+
fields_dict[field] = ""
|
23 |
+
fields_dict[field + "_txt"] = ""
|
24 |
+
html = f"""
|
25 |
+
<div style='height: {height_px}px; padding: 5px;'>
|
26 |
+
<div style='height: {height_px}px; border: 1px solid #febe10;'>
|
27 |
+
<span style='padding-left: 10px; display: inline-block; width: 100%;'>
|
28 |
+
<div>
|
29 |
+
<span>
|
30 |
+
<b>Candid blog post:</b>
|
31 |
+
<a href='{url}' target='_blank' style='text-decoration: none;'>
|
32 |
+
{doc['title']}
|
33 |
+
</a>
|
34 |
+
</span>
|
35 |
+
<br>
|
36 |
+
<br>
|
37 |
+
{fields_dict["excerpt_txt"]}
|
38 |
+
</div>
|
39 |
+
</span>
|
40 |
+
</div>
|
41 |
+
</div>
|
42 |
+
"""
|
43 |
+
return html
|
ask_candid/retrieval/sources/candid_help.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any
|
2 |
+
|
3 |
+
|
4 |
+
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
5 |
+
url = f"{doc['link']}"
|
6 |
+
fields = ["title", "summary"]
|
7 |
+
|
8 |
+
fields_dict = {}
|
9 |
+
fields_len = 0
|
10 |
+
for field in fields:
|
11 |
+
if doc.get(field, None) is not None:
|
12 |
+
fields_dict[field] = doc[field]
|
13 |
+
fields_dict[field + "_txt"] = f"<div>{doc[field]}</div>"
|
14 |
+
|
15 |
+
if (fields_len + len(doc[field])) > 999:
|
16 |
+
rest_text_len = 999 - fields_len
|
17 |
+
if rest_text_len > 0:
|
18 |
+
fields_dict[field + "_txt"] = f"<div>{doc[field][:rest_text_len] + '[...]'}</div>"
|
19 |
+
else: fields_dict[field + "_txt"] = f"<span>{'[...]'}</span>"
|
20 |
+
fields_len = fields_len + len(doc[field])
|
21 |
+
else:
|
22 |
+
fields_dict[field] = ""
|
23 |
+
fields_dict[field + "_txt"] = ""
|
24 |
+
html = f"""
|
25 |
+
<div style='height: {height_px}px; padding: 5px;'>
|
26 |
+
<div style='height: {height_px}px; border: 1px solid #febe10;'>
|
27 |
+
<span style='padding-left: 10px; display: inline-block; width: 100%;'>
|
28 |
+
<div>
|
29 |
+
<span>
|
30 |
+
<b>Candid help article:</b>
|
31 |
+
<a href='{url}' target='_blank' style='text-decoration: none;'>
|
32 |
+
{doc['title']}
|
33 |
+
</a>
|
34 |
+
</span>
|
35 |
+
<br>
|
36 |
+
</div>
|
37 |
+
</span>
|
38 |
+
</div>
|
39 |
+
</div>
|
40 |
+
"""
|
41 |
+
return html
|
ask_candid/retrieval/sources/candid_learning.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any
|
2 |
+
|
3 |
+
|
4 |
+
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
5 |
+
url = f"{doc['url']}"
|
6 |
+
fields = ["title", "excerpt"]
|
7 |
+
|
8 |
+
fields_dict = {}
|
9 |
+
fields_len = 0
|
10 |
+
for field in fields:
|
11 |
+
if doc.get(field, None) is not None:
|
12 |
+
fields_dict[field] = doc[field]
|
13 |
+
fields_dict[field + "_txt"] = f"<div>{doc[field]}</div>"
|
14 |
+
|
15 |
+
if (fields_len + len(doc[field])) > 999:
|
16 |
+
rest_text_len = 999 - fields_len
|
17 |
+
if rest_text_len > 0:
|
18 |
+
fields_dict[field + "_txt"] = f"<div>{doc[field][:rest_text_len] + '[...]'}</div>"
|
19 |
+
else: fields_dict[field + "_txt"] = f"<span>{'[...]'}</span>"
|
20 |
+
fields_len = fields_len + len(doc[field])
|
21 |
+
else:
|
22 |
+
fields_dict[field] = ""
|
23 |
+
fields_dict[field + "_txt"] = ""
|
24 |
+
html = f"""
|
25 |
+
<div style='height: {height_px}px; padding: 5px;'>
|
26 |
+
<div style='height: {height_px}px; border: 1px solid #febe10;'>
|
27 |
+
<span style='padding-left: 10px; display: inline-block; width: 100%;'>
|
28 |
+
<div>
|
29 |
+
<span>
|
30 |
+
<b>Candid Learning resource:</b>
|
31 |
+
<a href='{url}' target='_blank' style='text-decoration: none;'>
|
32 |
+
{doc['title']}
|
33 |
+
</a>
|
34 |
+
</span>
|
35 |
+
<br>
|
36 |
+
</div>
|
37 |
+
</span>
|
38 |
+
</div>
|
39 |
+
</div>
|
40 |
+
"""
|
41 |
+
return html
|
ask_candid/retrieval/sources/issuelab.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any
|
2 |
+
|
3 |
+
|
4 |
+
def issuelab_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
5 |
+
chunks_html = ""
|
6 |
+
if show_chunks:
|
7 |
+
cleaned_text = []
|
8 |
+
for k, v in doc["inner_hits"].items():
|
9 |
+
hits = v["hits"]["hits"]
|
10 |
+
for h in hits:
|
11 |
+
for k1, v1 in h["fields"].items():
|
12 |
+
# we don't want other chunks
|
13 |
+
if "content" in k1:
|
14 |
+
cleaned_text.append(f"<div><p>{v1[0]['chunk'][0]}</p></div>")
|
15 |
+
|
16 |
+
chunks_html ="<span><b>Relevant parts of the content:</b></span>" + "<br>".join(cleaned_text)
|
17 |
+
|
18 |
+
html = f"""
|
19 |
+
<div style='height: auto; padding: 5px;'>
|
20 |
+
<div style='border: 1px solid #febe10;'>
|
21 |
+
<span style='display: inline-block; height: {height_px - 10}px; padding: 5px; vertical-align: top;'>
|
22 |
+
<img
|
23 |
+
src='{doc['cover_graphic_small']}'
|
24 |
+
style='max-height: 100%; overflow: hidden; border-radius: 3%;'
|
25 |
+
>
|
26 |
+
</span>
|
27 |
+
|
28 |
+
<span style='padding: 10px; display: inline-block; width: 70%;'>
|
29 |
+
<div>
|
30 |
+
<span><b>Issuelab ID:</b> {doc['resource_id']}</span>
|
31 |
+
<br>
|
32 |
+
<span>
|
33 |
+
<a href='{doc['issuelab_url']}' target='_blank' style='text-decoration: none;'>
|
34 |
+
{doc['title']}
|
35 |
+
</a>
|
36 |
+
</span>
|
37 |
+
<br>
|
38 |
+
|
39 |
+
<span><b>Description:</b> {doc['description']}</span>
|
40 |
+
<br>
|
41 |
+
<div>{doc['combined_item_description']}</div>
|
42 |
+
<br>
|
43 |
+
<div>{chunks_html}</div>
|
44 |
+
|
45 |
+
</div>
|
46 |
+
</span>
|
47 |
+
</div>
|
48 |
+
</div>
|
49 |
+
"""
|
50 |
+
return html
|
ask_candid/retrieval/sources/youtube.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any
|
2 |
+
|
3 |
+
|
4 |
+
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
5 |
+
url = f"https://www.youtube.com/watch?v={doc['video_id']}"
|
6 |
+
fields = ["title", "description_cleaned"]
|
7 |
+
|
8 |
+
fields_dict = {}
|
9 |
+
fields_len = 0
|
10 |
+
for field in fields:
|
11 |
+
if doc.get(field, None) is not None:
|
12 |
+
fields_dict[field] = doc[field]
|
13 |
+
fields_dict[field + "_txt"] = f"<div>{doc[field]}</div>"
|
14 |
+
|
15 |
+
if (fields_len + len(doc[field])) > 999:
|
16 |
+
rest_text_len = 999 - fields_len
|
17 |
+
if rest_text_len > 0:
|
18 |
+
fields_dict[field + "_txt"] = f"<div>{doc[field][:rest_text_len] + '[...]'}</div>"
|
19 |
+
else: fields_dict[field + "_txt"] = f"<span>{'[...]'}</span>"
|
20 |
+
fields_len = fields_len + len(doc[field])
|
21 |
+
else:
|
22 |
+
fields_dict[field] = ""
|
23 |
+
fields_dict[field + "_txt"] = ""
|
24 |
+
html = f"""
|
25 |
+
<div style='height: {height_px}px; padding: 5px;'>
|
26 |
+
<div style='height: {height_px}px; border: 1px solid #febe10;'>
|
27 |
+
<span style='padding-left: 10px; display: inline-block; width: 100%;'>
|
28 |
+
<div>
|
29 |
+
<span>
|
30 |
+
<b>Candid Youtube video:</b>
|
31 |
+
<a href='{url}' target='_blank' style='text-decoration: none;'>
|
32 |
+
{doc['title']}
|
33 |
+
</a>
|
34 |
+
</span>
|
35 |
+
<iframe
|
36 |
+
width="426"
|
37 |
+
height="240"
|
38 |
+
src="https://www.youtube.com/embed/{doc['video_id']}?si=0-y6eRrOzXTUSBDY"
|
39 |
+
title="YouTube video player"
|
40 |
+
frameborder="0"
|
41 |
+
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
|
42 |
+
referrerpolicy="strict-origin-when-cross-origin"
|
43 |
+
allowfullscreen
|
44 |
+
style="display: inline-block; float: left;padding-right: 10px;padding-top: 5px;">
|
45 |
+
</iframe>
|
46 |
+
<br>
|
47 |
+
<br>
|
48 |
+
{fields_dict["description_cleaned_txt"]}
|
49 |
+
</div>
|
50 |
+
</span>
|
51 |
+
</div>
|
52 |
+
</div>
|
53 |
+
"""
|
54 |
+
return html
|
ask_candid/services/__init__.py
ADDED
File without changes
|
ask_candid/services/org_search.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ask_candid.base.api_base import BaseAPI
|
2 |
+
from ask_candid.base.config.rest import CDS_API
|
3 |
+
|
4 |
+
|
5 |
+
class OrgSearch(BaseAPI):
|
6 |
+
|
7 |
+
def __init__(self):
|
8 |
+
super().__init__(
|
9 |
+
url=f"{CDS_API['url']}/v1/organization/search",
|
10 |
+
headers={"x-api-key": CDS_API["key"]}
|
11 |
+
)
|
12 |
+
|
13 |
+
def __call__(self, name: str, name_only: bool = False, **kwargs):
|
14 |
+
is_valid = False
|
15 |
+
|
16 |
+
payload = {
|
17 |
+
"names": [{
|
18 |
+
"value": name,
|
19 |
+
"type": "main"
|
20 |
+
}],
|
21 |
+
"status": "authorized"
|
22 |
+
}
|
23 |
+
|
24 |
+
if name_only:
|
25 |
+
is_valid = True
|
26 |
+
else:
|
27 |
+
if kwargs.get("ein"):
|
28 |
+
ein = kwargs.get("ein")
|
29 |
+
if "-" not in ein:
|
30 |
+
ein = f"{ein[:2]}-{ein[2:]}"
|
31 |
+
payload["ids"] = [{
|
32 |
+
"value": ein,
|
33 |
+
"type": "ein"
|
34 |
+
}]
|
35 |
+
is_valid = True
|
36 |
+
|
37 |
+
if kwargs.get("street") or kwargs.get("city") or kwargs.get("state") or kwargs.get("postal_code"):
|
38 |
+
payload["addresses"] = [{
|
39 |
+
"street1": kwargs.get("street") or "",
|
40 |
+
"city": kwargs.get("city") or "",
|
41 |
+
"state": kwargs.get("state") or "",
|
42 |
+
"postal_code": kwargs.get("postal_code") or ""
|
43 |
+
}]
|
44 |
+
is_valid = True
|
45 |
+
|
46 |
+
if not is_valid:
|
47 |
+
return None
|
48 |
+
|
49 |
+
result = self.post(payload=payload)
|
50 |
+
return result.get("payload", [])
|
ask_candid/services/small_lm.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from enum import Enum
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from ask_candid.base.lambda_base import LambdaInvokeBase
|
8 |
+
|
9 |
+
|
10 |
+
@dataclass(slots=True)
|
11 |
+
class Encoding:
|
12 |
+
inputs: List[str]
|
13 |
+
vectors: torch.Tensor
|
14 |
+
|
15 |
+
|
16 |
+
class CandidSLM(LambdaInvokeBase):
|
17 |
+
"""Wrapper around Candid's custom small language model.
|
18 |
+
For more details see https://dev.azure.com/guidestar/DataScience/_git/graph-ai?path=/releases/language.
|
19 |
+
This services includes:
|
20 |
+
* text encoding
|
21 |
+
* document summarization
|
22 |
+
* entity salience estimation
|
23 |
+
|
24 |
+
Parameters
|
25 |
+
----------
|
26 |
+
access_key : Optional[str], optional
|
27 |
+
AWS access key, by default None
|
28 |
+
secret_key : Optional[str], optional
|
29 |
+
AWS secret key, by default None
|
30 |
+
"""
|
31 |
+
|
32 |
+
class Tasks(Enum): # pylint: disable=missing-class-docstring
|
33 |
+
ENCODE = "/encode"
|
34 |
+
DOCUMENT_SUMMARIZE = "/document/summarize"
|
35 |
+
DOCUMENT_NER_SALIENCE = "/document/entitySalience"
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self, access_key: Optional[str] = None, secret_key: Optional[str] = None
|
39 |
+
) -> None:
|
40 |
+
super().__init__(
|
41 |
+
function_name="small-lm",
|
42 |
+
access_key=access_key,
|
43 |
+
secret_key=secret_key
|
44 |
+
)
|
45 |
+
|
46 |
+
def encode(self, text: List[str]) -> Encoding:
|
47 |
+
response = self._submit_request({"text": text})
|
48 |
+
|
49 |
+
output = Encoding(
|
50 |
+
inputs=(response.get("inputs") or []),
|
51 |
+
vectors=torch.tensor((response.get("vectors") or []), dtype=torch.float32)
|
52 |
+
)
|
53 |
+
return output
|
ask_candid/tools/__init__.py
ADDED
File without changes
|
ask_candid/tools/elastic/__init__.py
ADDED
File without changes
|
ask_candid/tools/elastic/index_data_tool.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Type, Optional
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from pydantic import BaseModel, Field
|
5 |
+
|
6 |
+
from elasticsearch import Elasticsearch
|
7 |
+
|
8 |
+
from langchain.callbacks.manager import CallbackManagerForToolRun
|
9 |
+
from langchain.tools.base import BaseTool
|
10 |
+
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
|
11 |
+
|
12 |
+
logging.basicConfig(level="INFO")
|
13 |
+
logger = logging.getLogger("elasticsearch_playground")
|
14 |
+
es = Elasticsearch(
|
15 |
+
cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
|
16 |
+
api_key=SEMANTIC_ELASTIC_QA.api_key,
|
17 |
+
verify_certs=True,
|
18 |
+
request_timeout=60 * 3
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
class IndexShowDataInput(BaseModel):
|
23 |
+
"""Input for the index show data tool."""
|
24 |
+
|
25 |
+
index_name: str = Field(
|
26 |
+
..., description="The name of the index for which the data is to be retrieved"
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
class IndexShowDataTool(BaseTool):
|
31 |
+
"""Tool for getting a list of entries from an ElasticSearch index, helpful to figure out what data is available."""
|
32 |
+
|
33 |
+
name: str = "elastic_index_show_data" # Added type annotation
|
34 |
+
description: str = (
|
35 |
+
"Input is an index name, output is a JSON based string with an extract of the data of the index"
|
36 |
+
)
|
37 |
+
args_schema: Optional[Type[BaseModel]] = (
|
38 |
+
IndexShowDataInput # This should be placed before methods
|
39 |
+
)
|
40 |
+
|
41 |
+
def _run(
|
42 |
+
self,
|
43 |
+
index_name: str,
|
44 |
+
run_manager: Optional[CallbackManagerForToolRun] = None,
|
45 |
+
) -> str:
|
46 |
+
"""Get all indices in the Elasticsearch server, usually separated by a line break."""
|
47 |
+
try:
|
48 |
+
# Ensure `es` is properly initialized before this method is called
|
49 |
+
res = es.search(
|
50 |
+
index=index_name,
|
51 |
+
from_=0,
|
52 |
+
size=20,
|
53 |
+
query={"match_all": {}},
|
54 |
+
)
|
55 |
+
return str(res["hits"])
|
56 |
+
except Exception as e:
|
57 |
+
print(e)
|
58 |
+
logger.exception("Could not fetch index data for %s", index_name)
|
59 |
+
return ""
|
ask_candid/tools/elastic/index_details_tool.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Type, Optional
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from pydantic import BaseModel, Field
|
5 |
+
|
6 |
+
from elasticsearch import Elasticsearch
|
7 |
+
|
8 |
+
from langchain.callbacks.manager import (
|
9 |
+
AsyncCallbackManagerForToolRun,
|
10 |
+
CallbackManagerForToolRun,
|
11 |
+
)
|
12 |
+
from langchain.tools.base import BaseTool
|
13 |
+
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
|
14 |
+
|
15 |
+
|
16 |
+
logging.basicConfig(level="INFO")
|
17 |
+
logger = logging.getLogger("elasticsearch_playground")
|
18 |
+
es = Elasticsearch(
|
19 |
+
cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
|
20 |
+
api_key=SEMANTIC_ELASTIC_QA.api_key,
|
21 |
+
verify_certs=True,
|
22 |
+
request_timeout=60 * 3
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class IndexDetailsInput(BaseModel):
|
27 |
+
"""Input for the list index details tool."""
|
28 |
+
|
29 |
+
index_name: str = Field(
|
30 |
+
...,
|
31 |
+
description="The name of the index for which the details are to be retrieved",
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class IndexDetailsTool(BaseTool):
|
36 |
+
"""Tool for getting information about a single ElasticSearch index."""
|
37 |
+
|
38 |
+
name: str = "elastic_index_show_details" # Added type annotation
|
39 |
+
description: str = (
|
40 |
+
"Input is an index name, output is a JSON-based string with the aliases, mappings containing the field names, and settings of an index."
|
41 |
+
)
|
42 |
+
args_schema: Optional[Type[BaseModel]] = (
|
43 |
+
IndexDetailsInput # Ensure this is above the methods
|
44 |
+
)
|
45 |
+
|
46 |
+
def _run(
|
47 |
+
self,
|
48 |
+
index_name: str,
|
49 |
+
run_manager: Optional[CallbackManagerForToolRun] = None,
|
50 |
+
) -> str:
|
51 |
+
"""Get information about a single Elasticsearch index."""
|
52 |
+
try:
|
53 |
+
# Ensure that `es` is correctly initialized before calling this method
|
54 |
+
alias = es.indices.get_alias(index=index_name)
|
55 |
+
field_mappings = es.indices.get_field_mapping(index=index_name, fields="*")
|
56 |
+
field_settings = es.indices.get_settings(index=index_name)
|
57 |
+
return str(
|
58 |
+
{
|
59 |
+
"alias": alias[index_name],
|
60 |
+
"field_mappings": field_mappings[index_name],
|
61 |
+
"settings": field_settings[index_name],
|
62 |
+
}
|
63 |
+
)
|
64 |
+
except Exception as e:
|
65 |
+
logger.exception("Could not fetch index information for %s: %s", index_name, e)
|
66 |
+
return ""
|
67 |
+
|
68 |
+
async def _arun(
|
69 |
+
self,
|
70 |
+
index_name: str = "",
|
71 |
+
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
72 |
+
) -> str:
|
73 |
+
raise NotImplementedError("IndexDetailsTool does not support async operations")
|
ask_candid/tools/elastic/index_search_tool.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import json
|
3 |
+
|
4 |
+
import tiktoken
|
5 |
+
from elasticsearch import Elasticsearch
|
6 |
+
# from pydantic.v1 import BaseModel, Field # <-- Uses v1 namespace
|
7 |
+
from pydantic import BaseModel, Field
|
8 |
+
from langchain.tools import StructuredTool
|
9 |
+
|
10 |
+
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
|
11 |
+
|
12 |
+
logging.basicConfig(level="INFO")
|
13 |
+
logger = logging.getLogger("elasticsearch_playground")
|
14 |
+
es = Elasticsearch(
|
15 |
+
cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
|
16 |
+
api_key=SEMANTIC_ELASTIC_QA.api_key,
|
17 |
+
verify_certs=True,
|
18 |
+
request_timeout=60 * 3
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
class SearchToolInput(BaseModel):
|
23 |
+
"""Input for the index show data tool."""
|
24 |
+
|
25 |
+
index_name: str = Field(
|
26 |
+
..., description="The name of the index for which the data is to be retrieved"
|
27 |
+
)
|
28 |
+
query: str = Field(
|
29 |
+
...,
|
30 |
+
description="The ElasticSearch JSON query used to filter all hits. Should use the _source field if possible to specify required fields.",
|
31 |
+
)
|
32 |
+
from_: int = Field(
|
33 |
+
..., description="The record index from which the query will start"
|
34 |
+
)
|
35 |
+
size: int = Field(
|
36 |
+
...,
|
37 |
+
description="How many records will be retrieved from the ElasticSearch query",
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
def elastic_search(
|
42 |
+
index_name: str,
|
43 |
+
query: str,
|
44 |
+
from_: int = 0,
|
45 |
+
size: int = 20,
|
46 |
+
):
|
47 |
+
"""Executes a specific query on an ElasticSearch index and returns all hits or aggregation results"""
|
48 |
+
size = min(50, size)
|
49 |
+
encoding = tiktoken.encoding_for_model("gpt-4")
|
50 |
+
try:
|
51 |
+
full_dict: dict = json.loads(query)
|
52 |
+
query_dict = None
|
53 |
+
aggs_dict = None
|
54 |
+
sort_dict = None
|
55 |
+
if "query" in full_dict:
|
56 |
+
query_dict = full_dict["query"]
|
57 |
+
if "aggs" in full_dict:
|
58 |
+
aggs_dict = full_dict["aggs"]
|
59 |
+
if "sort" in full_dict:
|
60 |
+
sort_dict = full_dict["sort"]
|
61 |
+
if query_dict is None and aggs_dict is None and sort_dict is None:
|
62 |
+
# Assume that there is a query but that the query part was ommitted.
|
63 |
+
query_dict = full_dict
|
64 |
+
if query_dict is None and aggs_dict is not None:
|
65 |
+
# This is an aggregation query, therefore we suppress the hits here
|
66 |
+
size = 200
|
67 |
+
logger.info(query)
|
68 |
+
# Print the query
|
69 |
+
# print(f"Executing Elasticsearch Query: {query}")
|
70 |
+
final_res = ""
|
71 |
+
retries = 0
|
72 |
+
while retries < 100:
|
73 |
+
res = es.search(
|
74 |
+
index=index_name,
|
75 |
+
from_=from_,
|
76 |
+
size=size,
|
77 |
+
query=query_dict,
|
78 |
+
aggs=aggs_dict,
|
79 |
+
sort=sort_dict,
|
80 |
+
)
|
81 |
+
if query_dict is None and aggs_dict is not None:
|
82 |
+
# When a result has aggregations, just return that and ignore the rest
|
83 |
+
final_res = str(res["aggregations"])
|
84 |
+
else:
|
85 |
+
final_res = str(res["hits"])
|
86 |
+
tokens = encoding.encode(final_res)
|
87 |
+
retries += 1
|
88 |
+
if len(tokens) > 6000:
|
89 |
+
size -= 1
|
90 |
+
else:
|
91 |
+
return final_res
|
92 |
+
|
93 |
+
except Exception as e:
|
94 |
+
logger.exception("Could not execute query %s", query)
|
95 |
+
msg = str(e)
|
96 |
+
return msg
|
97 |
+
|
98 |
+
|
99 |
+
def create_search_tool():
|
100 |
+
return StructuredTool.from_function(
|
101 |
+
elastic_search, name="elastic_index_search_tool", args_schema=SearchToolInput
|
102 |
+
)
|
ask_candid/tools/elastic/list_indices_tool.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Type, Optional, List
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from pydantic import BaseModel, Field
|
5 |
+
|
6 |
+
from elasticsearch import Elasticsearch
|
7 |
+
|
8 |
+
from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
|
9 |
+
from langchain.tools.base import BaseTool
|
10 |
+
|
11 |
+
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
|
12 |
+
|
13 |
+
logging.basicConfig(level="INFO")
|
14 |
+
logger = logging.getLogger("elasticsearch_playground")
|
15 |
+
es = Elasticsearch(
|
16 |
+
cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
|
17 |
+
api_key=SEMANTIC_ELASTIC_QA.api_key,
|
18 |
+
verify_certs=True,
|
19 |
+
request_timeout=60 * 3
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
class ListIndicesInput(BaseModel):
|
24 |
+
"""Input for the list indices tool."""
|
25 |
+
|
26 |
+
separator: str = Field(..., description="Separator for the list of indices")
|
27 |
+
|
28 |
+
|
29 |
+
class ListIndicesTool(BaseTool):
|
30 |
+
"""Tool for getting all ElasticSearch indices."""
|
31 |
+
|
32 |
+
name: str = "elastic_list_indices" # Added type annotation
|
33 |
+
description: str = (
|
34 |
+
"Input is a delimiter like comma or new line. Output is a separated list of indices in the database. Always use this tool to get to know the indices in the ElasticSearch cluster."
|
35 |
+
)
|
36 |
+
args_schema: Optional[Type[BaseModel]] = (
|
37 |
+
ListIndicesInput # Define this before methods
|
38 |
+
)
|
39 |
+
|
40 |
+
def _run(self, separator: str) -> str:
|
41 |
+
"""Get all indices in the Elasticsearch server, usually separated by a line break."""
|
42 |
+
try:
|
43 |
+
# Ensure that `es` is correctly initialized before calling this method
|
44 |
+
indices: List[str] = es.cat.indices(h="index", s="index").split()
|
45 |
+
# Filter out hidden indices starting with a dot
|
46 |
+
return separator.join(
|
47 |
+
[index for index in indices if not index.startswith(".")]
|
48 |
+
)
|
49 |
+
except Exception as e:
|
50 |
+
logger.exception("Could not list indices: %s", e)
|
51 |
+
return ""
|
52 |
+
|
53 |
+
async def _arun(
|
54 |
+
self,
|
55 |
+
separator: str = "",
|
56 |
+
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
57 |
+
) -> str:
|
58 |
+
raise NotImplementedError("ListIndicesTool does not support async operations")
|
ask_candid/tools/org_seach.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import re
|
3 |
+
|
4 |
+
from fuzzywuzzy import fuzz
|
5 |
+
|
6 |
+
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
|
7 |
+
from langchain_openai.chat_models import ChatOpenAI
|
8 |
+
from langchain_core.runnables import RunnableSequence
|
9 |
+
from langchain_core.prompts import ChatPromptTemplate
|
10 |
+
from pydantic import BaseModel, Field
|
11 |
+
|
12 |
+
from ask_candid.services.org_search import OrgSearch
|
13 |
+
from ask_candid.base.config.rest import OPENAI
|
14 |
+
|
15 |
+
search = OrgSearch()
|
16 |
+
|
17 |
+
|
18 |
+
class OrganizationNames(BaseModel):
|
19 |
+
"""List of names of social-sector organizations, such as nonprofits and foundations."""
|
20 |
+
orgnames: List[str] = Field(description="List of organization names")
|
21 |
+
|
22 |
+
|
23 |
+
def extract_org_links_from_chatbot(chatbot_output: str):
|
24 |
+
"""
|
25 |
+
Extracts a list of organization names from the provided text.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
chatbot_output (str):The chatbot output containing organization names and other content.
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
list: A list of organization names extracted from the text.
|
32 |
+
|
33 |
+
Raises:
|
34 |
+
ValueError: If parsing fails or if an unexpected output format is received.
|
35 |
+
"""
|
36 |
+
prompt = """Extract only the names of officially recognized organizations, foundations, and government entities
|
37 |
+
from the text below. Do not include any entries that contain descriptions, regional identifiers, or explanations
|
38 |
+
within parentheses or following the name. Strictly exclude databases, resources, crowdfunding platforms, and general
|
39 |
+
terms. Provide the output only in the specified JSON format.
|
40 |
+
|
41 |
+
input text below:
|
42 |
+
|
43 |
+
```{chatbot_output}``
|
44 |
+
|
45 |
+
output format:
|
46 |
+
{{
|
47 |
+
'orgnames' : [list of organization names without any additional descriptions or identifiers]
|
48 |
+
}}
|
49 |
+
|
50 |
+
"""
|
51 |
+
|
52 |
+
try:
|
53 |
+
parser = JsonOutputToolsParser()
|
54 |
+
llm = ChatOpenAI(model="gpt-4o", api_key=OPENAI["key"]).bind_tools([OrganizationNames])
|
55 |
+
prompt = ChatPromptTemplate.from_template(prompt)
|
56 |
+
chain = RunnableSequence(prompt, llm, parser)
|
57 |
+
|
58 |
+
# Run the chain with the input data
|
59 |
+
result = chain.invoke({"chatbot_output": chatbot_output})
|
60 |
+
|
61 |
+
# Extract the organization names from the output
|
62 |
+
output_list = result[0]["args"].get("orgnames", [])
|
63 |
+
|
64 |
+
# Validate output format
|
65 |
+
if not isinstance(output_list, list):
|
66 |
+
raise ValueError("Unexpected output format: 'orgnames' should be a list")
|
67 |
+
|
68 |
+
return output_list
|
69 |
+
|
70 |
+
except Exception as e:
|
71 |
+
# Log or print the error as needed for debugging
|
72 |
+
print(f"text does not have any organization: {e}")
|
73 |
+
return []
|
74 |
+
|
75 |
+
|
76 |
+
def is_similar(name: str, list_of_dict: list, threshold: int = 80):
|
77 |
+
"""
|
78 |
+
Returns True if `name` is similar to any names in `list_of_dict` based on a similarity threshold.
|
79 |
+
"""
|
80 |
+
try:
|
81 |
+
for item in list_of_dict:
|
82 |
+
try:
|
83 |
+
# Attempt to calculate similarity score
|
84 |
+
similarity = fuzz.ratio(name.lower(), item["name"].lower())
|
85 |
+
if similarity >= threshold:
|
86 |
+
return True
|
87 |
+
except KeyError:
|
88 |
+
# Handle cases where 'name' key might be missing in dictionary
|
89 |
+
print(f"KeyError: Missing 'name' key in dictionary item {item}")
|
90 |
+
continue
|
91 |
+
except AttributeError:
|
92 |
+
# Handle non-string name values in dictionary items
|
93 |
+
print(f"AttributeError: Non-string 'name' in dictionary item {item}")
|
94 |
+
continue
|
95 |
+
except TypeError as e:
|
96 |
+
# Handle cases where input types are incorrect
|
97 |
+
print(f"TypeError: {e}")
|
98 |
+
return False
|
99 |
+
|
100 |
+
return False
|
101 |
+
|
102 |
+
|
103 |
+
def generate_org_link_dict(org_names_list: list):
|
104 |
+
"""
|
105 |
+
Maps organization names to their Candid profile URLs if available.
|
106 |
+
|
107 |
+
For each organization in `output_list`, this function attempts to retrieve a matching profile
|
108 |
+
using `search_org`. If a similar name is found and a Candid entity ID is available, it constructs
|
109 |
+
a profile URL. If no ID or similar match is found, or if an error occurs, it assigns an empty string.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
output_list (list): List of organization names (str) to retrieve Candid profile links for.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
dict: Dictionary with organization names as keys and Candid profile URLs or empty strings as values.
|
116 |
+
|
117 |
+
Example:
|
118 |
+
get_org_link(['New York-Presbyterian Hospital'])
|
119 |
+
# {'New York-Presbyterian Hospital': 'https://app.candid.org/profile/6915255'}
|
120 |
+
"""
|
121 |
+
link_dict = {}
|
122 |
+
|
123 |
+
for org in org_names_list:
|
124 |
+
try:
|
125 |
+
# Attempt to retrieve organization data
|
126 |
+
response = search(org, name_only=True)
|
127 |
+
|
128 |
+
# Check if there is a valid response and if names are similar
|
129 |
+
if response and is_similar(org, response[0].get("names", "")):
|
130 |
+
# Try to get the Candid entity ID and construct the URL
|
131 |
+
candid_entity_id = response[0].get("candid_entity_id")
|
132 |
+
if candid_entity_id:
|
133 |
+
link_dict[org] = (
|
134 |
+
f"https://app.candid.org/profile/{candid_entity_id}"
|
135 |
+
)
|
136 |
+
else:
|
137 |
+
link_dict[org] = "" # No ID found, set empty string
|
138 |
+
else:
|
139 |
+
link_dict[org] = "" # No similar match found
|
140 |
+
|
141 |
+
except KeyError as e:
|
142 |
+
# Handle missing keys in the response dictionary
|
143 |
+
print(f"KeyError encountered for organization '{org}': {e}")
|
144 |
+
link_dict[org] = ""
|
145 |
+
|
146 |
+
except Exception as e:
|
147 |
+
# Catch any other unexpected errors
|
148 |
+
|
149 |
+
print(f"An error occurred for organization '{org}': {e}")
|
150 |
+
link_dict[org] = ""
|
151 |
+
|
152 |
+
return link_dict
|
153 |
+
|
154 |
+
|
155 |
+
def embed_org_links_in_text(input_text: str, org_link_dict: dict):
|
156 |
+
"""
|
157 |
+
Replaces organization names in `text` with links from `link_dict` and appends a Candid info message.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
text (str): The text containing organization names.
|
161 |
+
link_dict (dict): Mapping of organization names to URLs.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
str: Updated text with linked organization names and an appended Candid message.
|
165 |
+
"""
|
166 |
+
try:
|
167 |
+
for org_name, url in org_link_dict.items():
|
168 |
+
if url: # Only proceed if the URL is not empty
|
169 |
+
regex_pattern = re.compile(re.escape(org_name))
|
170 |
+
input_text = regex_pattern.sub(
|
171 |
+
repl=f"<a href={url} target='_blank' rel='noreferrer' class='candid-org-link'>{org_name}</a>",
|
172 |
+
string=input_text
|
173 |
+
)
|
174 |
+
|
175 |
+
# Append Candid information message at the end
|
176 |
+
input_text += (
|
177 |
+
"<p class='candid-app-link'> "
|
178 |
+
"Visit <a href=https://app.candid.org/ target='_blank' rel='noreferrer' class='candid-org-link'>Candid</a> "
|
179 |
+
"to get nonprofit information you need.</p>"
|
180 |
+
)
|
181 |
+
|
182 |
+
except TypeError as e:
|
183 |
+
print(f"TypeError encountered: {e}")
|
184 |
+
return input_text
|
185 |
+
|
186 |
+
except re.error as e:
|
187 |
+
print(f"Regex error encountered for '{org_name}': {e}")
|
188 |
+
return input_text
|
189 |
+
|
190 |
+
except Exception as e:
|
191 |
+
print(f"Unexpected error: {e}")
|
192 |
+
return input_text
|
193 |
+
|
194 |
+
return input_text
|
ask_candid/tools/question_reformulation.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.prompts import ChatPromptTemplate
|
2 |
+
from langchain_core.output_parsers import StrOutputParser
|
3 |
+
|
4 |
+
|
5 |
+
def reformulate_question_using_history(state, llm):
|
6 |
+
"""
|
7 |
+
Transform the query to produce a better query with details from previous messages.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
state (messages): The current state
|
11 |
+
llm: LLM to use
|
12 |
+
Returns:
|
13 |
+
dict: The updated state with re-phrased question and original user_input for UI
|
14 |
+
"""
|
15 |
+
print("---REFORMULATE THE USER INPUT---")
|
16 |
+
messages = state["messages"]
|
17 |
+
question = messages[-1].content
|
18 |
+
|
19 |
+
if len(messages) > 1:
|
20 |
+
contextualize_q_system_prompt = """Given a chat history and the latest user input \
|
21 |
+
which might reference context in the chat history, formulate a standalone input \
|
22 |
+
which can be understood without the chat history.
|
23 |
+
Chat history:
|
24 |
+
\n ------- \n
|
25 |
+
{chat_history}
|
26 |
+
\n ------- \n
|
27 |
+
User input:
|
28 |
+
\n ------- \n
|
29 |
+
{question}
|
30 |
+
\n ------- \n
|
31 |
+
Do NOT answer the question, \
|
32 |
+
just reformulate it if needed and otherwise return it as is.
|
33 |
+
"""
|
34 |
+
|
35 |
+
contextualize_q_prompt = ChatPromptTemplate([
|
36 |
+
("system", contextualize_q_system_prompt),
|
37 |
+
("human", question),
|
38 |
+
])
|
39 |
+
|
40 |
+
rag_chain = contextualize_q_prompt | llm | StrOutputParser()
|
41 |
+
new_question = rag_chain.invoke({"chat_history": messages, "question": question})
|
42 |
+
print(f"user asked: '{question}', agent reformulated the question basing on the chat history: {new_question}")
|
43 |
+
return {"messages": [new_question], "user_input" : question}
|
44 |
+
return {"messages": [question], "user_input" : question}
|
ask_candid/utils.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Union, Any
|
2 |
+
from uuid import uuid4
|
3 |
+
|
4 |
+
from ask_candid.retrieval.sources import (
|
5 |
+
candid_blog,
|
6 |
+
candid_help,
|
7 |
+
candid_learning,
|
8 |
+
issuelab,
|
9 |
+
youtube
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
def filter_messages(messages, k=10):
|
14 |
+
# TODO summarize messages instead
|
15 |
+
return messages[-k:]
|
16 |
+
|
17 |
+
|
18 |
+
def html_format_doc(doc: Dict[str, Any], source: str, show_chunks=False) -> str:
|
19 |
+
height_px = 200
|
20 |
+
html = ""
|
21 |
+
|
22 |
+
if source == "news":
|
23 |
+
# html = news.article_card_html(doc, height_px, show_chunks)
|
24 |
+
pass
|
25 |
+
elif source == "transactions":
|
26 |
+
# html = cds.transaction_card_html(doc, height_px, show_chunks)
|
27 |
+
pass
|
28 |
+
elif source == "organizations":
|
29 |
+
# html = up_orgs.organization_card_html(doc, 400, show_chunks)
|
30 |
+
pass
|
31 |
+
elif source == "issuelab":
|
32 |
+
html = issuelab.issuelab_card_html(doc, height_px, show_chunks)
|
33 |
+
elif source == "youtube":
|
34 |
+
html = youtube.build_card_html(doc, 400, show_chunks)
|
35 |
+
elif source == "candid_blog":
|
36 |
+
html = candid_blog.build_card_html(doc, height_px, show_chunks)
|
37 |
+
elif source == "candid_learning":
|
38 |
+
html = candid_learning.build_card_html(doc, height_px, show_chunks)
|
39 |
+
elif source == "candid_help":
|
40 |
+
html = candid_help.build_card_html(doc, height_px, show_chunks)
|
41 |
+
return html
|
42 |
+
|
43 |
+
|
44 |
+
def html_format_docs_chat(docs):
|
45 |
+
"""
|
46 |
+
Formats Candid sources into a line of buttons
|
47 |
+
"""
|
48 |
+
html = ""
|
49 |
+
if docs:
|
50 |
+
docs_html = []
|
51 |
+
for doc in docs:
|
52 |
+
s_name = doc.metadata.get("source", "Source")
|
53 |
+
s_url = doc.metadata.get("url", "URL")
|
54 |
+
|
55 |
+
s_html = (
|
56 |
+
"<span class='source-item'>"
|
57 |
+
f"<a href={s_url} target='_blank' rel='noreferrer' class='ssearch-source'>"
|
58 |
+
f"{doc.metadata['title']} ({s_name})</a></span>"
|
59 |
+
)
|
60 |
+
|
61 |
+
docs_html.append(s_html)
|
62 |
+
|
63 |
+
html = f"<h2>Candid Resources</h2><div id='ssearch-sources'>{'<br>'.join(docs_html)}</div>"
|
64 |
+
return html
|
65 |
+
|
66 |
+
|
67 |
+
def format_chat_response(chatbot: List[Any]) -> List[Any]:
|
68 |
+
"""We have sources appended as one more tuple. Here we concatinate HTML of sources
|
69 |
+
with the AI response
|
70 |
+
Returns:
|
71 |
+
_type_: updated chatbot message as HTML
|
72 |
+
"""
|
73 |
+
if chatbot:
|
74 |
+
sources = chatbot[-1][1]
|
75 |
+
chatbot.pop(-1)
|
76 |
+
chatbot[-1][1] = chatbot[-1][1] + sources
|
77 |
+
return chatbot
|
78 |
+
|
79 |
+
|
80 |
+
def format_chat_ag_response(chatbot: List[Any]) -> List[Any]:
|
81 |
+
"""If we called retriever, we appended sources as as one more message. Here we concatinate HTML of sources
|
82 |
+
with the AI response
|
83 |
+
Returns:
|
84 |
+
_type_: updated chatbot message as HTML
|
85 |
+
"""
|
86 |
+
sources = ""
|
87 |
+
if chatbot:
|
88 |
+
title = chatbot[-1]["metadata"].get("title", None)
|
89 |
+
if title == "Sources HTML":
|
90 |
+
sources = chatbot[-1]["content"]
|
91 |
+
chatbot.pop(-1)
|
92 |
+
chatbot[-1]["content"] = chatbot[-1]["content"] + sources
|
93 |
+
return chatbot
|
94 |
+
|
95 |
+
|
96 |
+
def valid_inputs(*args) -> bool:
|
97 |
+
return any(a is not None or (isinstance(a, str) and a.strip() != '') for a in args)
|
98 |
+
|
99 |
+
|
100 |
+
def get_session_id(thread_id: Union[str, None]) -> str:
|
101 |
+
if not thread_id:
|
102 |
+
thread_id = uuid4().hex
|
103 |
+
return thread_id
|