Spaces:
Running
Running
v2 of public chat
#1
by
brainsqueeze
- opened
- README.md +2 -2
- ask_candid/base/api_base.py +3 -3
- ask_candid/base/api_base_async.py +3 -3
- ask_candid/base/config/base.py +10 -0
- ask_candid/base/config/connections.py +6 -14
- ask_candid/base/config/models.py +1 -0
- ask_candid/base/config/rest.py +49 -10
- ask_candid/base/lambda_base.py +3 -3
- ask_candid/base/retrieval/__init__.py +0 -0
- ask_candid/base/retrieval/elastic.py +205 -0
- ask_candid/base/retrieval/knowledge_base.py +362 -0
- ask_candid/base/retrieval/schemas.py +23 -0
- ask_candid/base/retrieval/sources.py +40 -0
- ask_candid/base/retrieval/sparse_lexical.py +98 -0
- ask_candid/base/utils.py +52 -0
- ask_candid/chat.py +68 -55
- ask_candid/services/small_lm.py +32 -6
- ask_candid/tools/general.py +17 -0
- ask_candid/tools/org_search.py +182 -0
- ask_candid/tools/search.py +56 -111
- ask_candid/tools/utils.py +14 -0
- chat_v2.py +265 -0
- requirements.txt +5 -5
README.md
CHANGED
@@ -6,8 +6,8 @@ colorFrom: blue
|
|
6 |
colorTo: purple
|
7 |
python_version: 3.12
|
8 |
sdk: gradio
|
9 |
-
sdk_version: 5.
|
10 |
-
app_file:
|
11 |
pinned: true
|
12 |
license: mit
|
13 |
---
|
|
|
6 |
colorTo: purple
|
7 |
python_version: 3.12
|
8 |
sdk: gradio
|
9 |
+
sdk_version: 5.42.0
|
10 |
+
app_file: chat_v2.py
|
11 |
pinned: true
|
12 |
license: mit
|
13 |
---
|
ask_candid/base/api_base.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import
|
2 |
|
3 |
from urllib3.util.retry import Retry
|
4 |
from requests.adapters import HTTPAdapter
|
@@ -10,7 +10,7 @@ class BaseAPI:
|
|
10 |
def __init__(
|
11 |
self,
|
12 |
url: str,
|
13 |
-
headers:
|
14 |
total_retries: int = 3,
|
15 |
backoff_factor: int = 2
|
16 |
) -> None:
|
@@ -36,7 +36,7 @@ class BaseAPI:
|
|
36 |
r.raise_for_status()
|
37 |
return r.json()
|
38 |
|
39 |
-
def post(self, payload:
|
40 |
r = self.session.post(url=self.__url, headers=self.__headers, json=payload, timeout=30)
|
41 |
r.raise_for_status()
|
42 |
return r.json()
|
|
|
1 |
+
from typing import Any
|
2 |
|
3 |
from urllib3.util.retry import Retry
|
4 |
from requests.adapters import HTTPAdapter
|
|
|
10 |
def __init__(
|
11 |
self,
|
12 |
url: str,
|
13 |
+
headers: dict[str, Any] | None = None,
|
14 |
total_retries: int = 3,
|
15 |
backoff_factor: int = 2
|
16 |
) -> None:
|
|
|
36 |
r.raise_for_status()
|
37 |
return r.json()
|
38 |
|
39 |
+
def post(self, payload: dict[str, Any]):
|
40 |
r = self.session.post(url=self.__url, headers=self.__headers, json=payload, timeout=30)
|
41 |
r.raise_for_status()
|
42 |
return r.json()
|
ask_candid/base/api_base_async.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import
|
2 |
import json
|
3 |
|
4 |
import aiohttp
|
@@ -6,7 +6,7 @@ import aiohttp
|
|
6 |
|
7 |
class BaseAsyncAPI:
|
8 |
|
9 |
-
def __init__(self, url: str, headers:
|
10 |
self.__url = url
|
11 |
self.__headers = headers
|
12 |
self.__retries = max(retries, 5)
|
@@ -29,7 +29,7 @@ class BaseAsyncAPI:
|
|
29 |
break
|
30 |
return output
|
31 |
|
32 |
-
async def post(self, payload:
|
33 |
session_timeout = aiohttp.ClientTimeout(total=30)
|
34 |
async with aiohttp.ClientSession(headers=self.__headers, timeout=session_timeout) as session:
|
35 |
output = {}
|
|
|
1 |
+
from typing import Any
|
2 |
import json
|
3 |
|
4 |
import aiohttp
|
|
|
6 |
|
7 |
class BaseAsyncAPI:
|
8 |
|
9 |
+
def __init__(self, url: str, headers: dict[str, Any] | None = None, retries: int = 3) -> None:
|
10 |
self.__url = url
|
11 |
self.__headers = headers
|
12 |
self.__retries = max(retries, 5)
|
|
|
29 |
break
|
30 |
return output
|
31 |
|
32 |
+
async def post(self, payload: dict[str, Any]):
|
33 |
session_timeout = aiohttp.ClientTimeout(total=30)
|
34 |
async with aiohttp.ClientSession(headers=self.__headers, timeout=session_timeout) as session:
|
35 |
output = {}
|
ask_candid/base/config/base.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from dotenv import dotenv_values, find_dotenv
|
4 |
+
|
5 |
+
__env_values__ = dotenv_values(
|
6 |
+
dotenv_path=find_dotenv(".env", raise_error_if_not_found=False)
|
7 |
+
)
|
8 |
+
|
9 |
+
def _load_value(key: str):
|
10 |
+
return __env_values__.get(key) or os.getenv(key)
|
ask_candid/base/config/connections.py
CHANGED
@@ -1,33 +1,25 @@
|
|
1 |
from dataclasses import dataclass, field
|
2 |
-
import os
|
3 |
|
4 |
-
from
|
5 |
|
6 |
|
7 |
@dataclass
|
8 |
class BaseElasticSearchConnection:
|
9 |
"""Elasticsearch connection dataclass
|
10 |
"""
|
11 |
-
url: str = field(default_factory=str)
|
12 |
-
username: str = field(default_factory=str)
|
13 |
-
password: str = field(default_factory=str)
|
14 |
|
15 |
|
16 |
@dataclass
|
17 |
class BaseElasticAPIKeyCredential:
|
18 |
"""Cloud ID/API key data class
|
19 |
"""
|
20 |
-
cloud_id: str = field(default_factory=str)
|
21 |
-
api_key: str = field(default_factory=str)
|
22 |
|
23 |
|
24 |
-
__env_values__ = dotenv_values(
|
25 |
-
dotenv_path=find_dotenv(".env", raise_error_if_not_found=False)
|
26 |
-
)
|
27 |
-
|
28 |
-
def _load_value(key: str):
|
29 |
-
return __env_values__.get(key) or os.getenv(key)
|
30 |
-
|
31 |
SEMANTIC_ELASTIC_QA = BaseElasticAPIKeyCredential(
|
32 |
cloud_id=_load_value("SEMANTIC_ELASTIC_CLOUD_ID"),
|
33 |
api_key=_load_value("SEMANTIC_ELASTIC_API_KEY"),
|
|
|
1 |
from dataclasses import dataclass, field
|
|
|
2 |
|
3 |
+
from ask_candid.base.config.base import _load_value
|
4 |
|
5 |
|
6 |
@dataclass
|
7 |
class BaseElasticSearchConnection:
|
8 |
"""Elasticsearch connection dataclass
|
9 |
"""
|
10 |
+
url: str | None = field(default_factory=str)
|
11 |
+
username: str | None = field(default_factory=str)
|
12 |
+
password: str | None = field(default_factory=str)
|
13 |
|
14 |
|
15 |
@dataclass
|
16 |
class BaseElasticAPIKeyCredential:
|
17 |
"""Cloud ID/API key data class
|
18 |
"""
|
19 |
+
cloud_id: str | None = field(default_factory=str)
|
20 |
+
api_key: str | None = field(default_factory=str)
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
SEMANTIC_ELASTIC_QA = BaseElasticAPIKeyCredential(
|
24 |
cloud_id=_load_value("SEMANTIC_ELASTIC_CLOUD_ID"),
|
25 |
api_key=_load_value("SEMANTIC_ELASTIC_API_KEY"),
|
ask_candid/base/config/models.py
CHANGED
@@ -3,6 +3,7 @@ from types import MappingProxyType
|
|
3 |
Name2Endpoint = MappingProxyType({
|
4 |
"gpt-4o": "gpt-4o",
|
5 |
"claude-3.5-haiku": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
|
|
6 |
# "llama-3.1-70b-instruct": "us.meta.llama3-1-70b-instruct-v1:0",
|
7 |
# "mistral-large": "mistral.mistral-large-2402-v1:0",
|
8 |
# "mixtral-8x7B": "mistral.mixtral-8x7b-instruct-v0:1",
|
|
|
3 |
Name2Endpoint = MappingProxyType({
|
4 |
"gpt-4o": "gpt-4o",
|
5 |
"claude-3.5-haiku": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
6 |
+
"claude-4-sonnet": "us.anthropic.claude-sonnet-4-20250514-v1:0",
|
7 |
# "llama-3.1-70b-instruct": "us.meta.llama3-1-70b-instruct-v1:0",
|
8 |
# "mistral-large": "mistral.mistral-large-2402-v1:0",
|
9 |
# "mixtral-8x7B": "mistral.mixtral-8x7b-instruct-v0:1",
|
ask_candid/base/config/rest.py
CHANGED
@@ -1,25 +1,64 @@
|
|
1 |
-
from typing import TypedDict
|
2 |
-
import os
|
3 |
|
4 |
-
from
|
5 |
|
6 |
|
7 |
class Api(TypedDict):
|
8 |
"""REST API configuration template
|
9 |
"""
|
10 |
-
url: str
|
11 |
-
key: str
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
-
def _load_value(key: str):
|
18 |
-
return __env_values__.get(key) or os.getenv(key)
|
19 |
|
20 |
CDS_API = Api(
|
21 |
url=_load_value("CDS_API_URL"),
|
22 |
key=_load_value("CDS_API_KEY")
|
23 |
)
|
24 |
|
|
|
|
|
|
|
|
|
|
|
25 |
OPENAI = Api(url=None, key=_load_value("OPENAI_API_KEY"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypedDict, NamedTuple
|
|
|
2 |
|
3 |
+
from ask_candid.base.config.base import _load_value
|
4 |
|
5 |
|
6 |
class Api(TypedDict):
|
7 |
"""REST API configuration template
|
8 |
"""
|
9 |
+
url: str | None
|
10 |
+
key: str | None
|
11 |
|
12 |
+
class ApiConfig(NamedTuple):
|
13 |
+
url: str | None
|
14 |
+
key: str | None
|
15 |
+
|
16 |
+
@property
|
17 |
+
def header(self) -> dict[str, str | None]:
|
18 |
+
return {"x-api-key": self.key}
|
19 |
+
|
20 |
+
def endpoint(self, route: str):
|
21 |
+
return f"{self.url}/{route}"
|
22 |
|
|
|
|
|
23 |
|
24 |
CDS_API = Api(
|
25 |
url=_load_value("CDS_API_URL"),
|
26 |
key=_load_value("CDS_API_KEY")
|
27 |
)
|
28 |
|
29 |
+
CANDID_SEARCH_API = Api(
|
30 |
+
url=_load_value("CANDID_SEARCH_API_URL"),
|
31 |
+
key=_load_value("CANDID_SEARCH_API_KEY")
|
32 |
+
)
|
33 |
+
|
34 |
OPENAI = Api(url=None, key=_load_value("OPENAI_API_KEY"))
|
35 |
+
|
36 |
+
SEARCH = ApiConfig(
|
37 |
+
url="https://ajr9jccwf0.execute-api.us-east-1.amazonaws.com/Prod",
|
38 |
+
key=_load_value("SEARCH_API_KEY")
|
39 |
+
)
|
40 |
+
|
41 |
+
AUTOCODING = ApiConfig(
|
42 |
+
url="https://auto-coding-api.candid.org",
|
43 |
+
key=_load_value("AUTOCODING_API_KEY")
|
44 |
+
)
|
45 |
+
|
46 |
+
DOCUMENT = ApiConfig(
|
47 |
+
url="https://dtntz2p635.execute-api.us-east-1.amazonaws.com/Prod",
|
48 |
+
key=_load_value("GEOCODING_API_KEY")
|
49 |
+
)
|
50 |
+
|
51 |
+
FUNDER_RECOMMENDATION = ApiConfig(
|
52 |
+
url="https://r6g59fxbie.execute-api.us-east-1.amazonaws.com/Prod",
|
53 |
+
key=_load_value("FUNDER_RECS_API_KEY")
|
54 |
+
)
|
55 |
+
|
56 |
+
LOI_WRITER = ApiConfig(
|
57 |
+
url="https://tc2ir1o7ne.execute-api.us-east-1.amazonaws.com/Prod",
|
58 |
+
key=_load_value("LOI_WRITER_API_KEY")
|
59 |
+
)
|
60 |
+
|
61 |
+
GOLDEN_ORG = ApiConfig(
|
62 |
+
url="https://qfdur742ih.execute-api.us-east-1.amazonaws.com/Prod",
|
63 |
+
key=_load_value("GOLDEN_RECORD_API_KEY")
|
64 |
+
)
|
ask_candid/base/lambda_base.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import
|
2 |
from time import sleep
|
3 |
import json
|
4 |
|
@@ -25,7 +25,7 @@ class LambdaInvokeBase:
|
|
25 |
|
26 |
def __init__(
|
27 |
self, function_name: str,
|
28 |
-
access_key:
|
29 |
) -> None:
|
30 |
if access_key is not None and secret_key is not None:
|
31 |
self._client = boto3.client(
|
@@ -39,7 +39,7 @@ class LambdaInvokeBase:
|
|
39 |
|
40 |
self.function_name = function_name
|
41 |
|
42 |
-
def _submit_request(self, payload:
|
43 |
response = self._client.invoke(
|
44 |
FunctionName=self.function_name,
|
45 |
InvocationType="RequestResponse",
|
|
|
1 |
+
from typing import Any
|
2 |
from time import sleep
|
3 |
import json
|
4 |
|
|
|
25 |
|
26 |
def __init__(
|
27 |
self, function_name: str,
|
28 |
+
access_key: str | None = None, secret_key: str | None = None,
|
29 |
) -> None:
|
30 |
if access_key is not None and secret_key is not None:
|
31 |
self._client = boto3.client(
|
|
|
39 |
|
40 |
self.function_name = function_name
|
41 |
|
42 |
+
def _submit_request(self, payload: dict[str, Any]) -> dict[str, Any] | list[Any]:
|
43 |
response = self._client.invoke(
|
44 |
FunctionName=self.function_name,
|
45 |
InvocationType="RequestResponse",
|
ask_candid/base/retrieval/__init__.py
ADDED
File without changes
|
ask_candid/base/retrieval/elastic.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
from collections.abc import Iterator
|
3 |
+
|
4 |
+
from elasticsearch import Elasticsearch
|
5 |
+
|
6 |
+
from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
|
7 |
+
from ask_candid.base.config.connections import BaseElasticAPIKeyCredential, BaseElasticSearchConnection
|
8 |
+
|
9 |
+
NEWS_TRUST_SCORE_THRESHOLD = 0.8
|
10 |
+
SPARSE_ENCODING_SCORE_THRESHOLD = 0.4
|
11 |
+
|
12 |
+
|
13 |
+
def build_sparse_vector_query(
|
14 |
+
query: str,
|
15 |
+
fields: tuple[str, ...],
|
16 |
+
inference_id: str = ".elser-2-elasticsearch"
|
17 |
+
) -> dict[str, Any]:
|
18 |
+
"""Builds a valid Elasticsearch text expansion query payload
|
19 |
+
|
20 |
+
Parameters
|
21 |
+
----------
|
22 |
+
query : str
|
23 |
+
Search context string
|
24 |
+
fields : Tuple[str, ...]
|
25 |
+
Semantic text field names
|
26 |
+
inference_id : str, optional
|
27 |
+
ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"
|
28 |
+
|
29 |
+
Returns
|
30 |
+
-------
|
31 |
+
Dict[str, Any]
|
32 |
+
"""
|
33 |
+
|
34 |
+
output = []
|
35 |
+
|
36 |
+
for f in fields:
|
37 |
+
output.append({
|
38 |
+
"nested": {
|
39 |
+
"path": f"embeddings.{f}.chunks",
|
40 |
+
"query": {
|
41 |
+
"sparse_vector": {
|
42 |
+
"field": f"embeddings.{f}.chunks.vector",
|
43 |
+
"inference_id": inference_id,
|
44 |
+
"prune": True,
|
45 |
+
"query": query,
|
46 |
+
# "boost": 1 / len(fields)
|
47 |
+
}
|
48 |
+
},
|
49 |
+
"inner_hits": {
|
50 |
+
"_source": False,
|
51 |
+
"size": 2,
|
52 |
+
"fields": [f"embeddings.{f}.chunks.chunk"]
|
53 |
+
}
|
54 |
+
}
|
55 |
+
})
|
56 |
+
return {"query": {"bool": {"should": output}}}
|
57 |
+
|
58 |
+
|
59 |
+
def build_sparse_vector_and_text_query(
|
60 |
+
query: str,
|
61 |
+
semantic_fields: tuple[str, ...],
|
62 |
+
text_fields: tuple[str, ...] | None,
|
63 |
+
highlight_fields: tuple[str, ...] | None,
|
64 |
+
excluded_fields: tuple[str, ...] | None,
|
65 |
+
inference_id: str = ".elser-2-elasticsearch"
|
66 |
+
) -> dict[str, Any]:
|
67 |
+
"""Builds Elasticsearch sparse vector and text query payload
|
68 |
+
|
69 |
+
Parameters
|
70 |
+
----------
|
71 |
+
query : str
|
72 |
+
Search context string
|
73 |
+
semantic_fields : Tuple[str]
|
74 |
+
Semantic text field names
|
75 |
+
highlight_fields: Tuple[str]
|
76 |
+
Fields which relevant chunks will be helpful for the agent to read
|
77 |
+
text_fields : Tuple[str]
|
78 |
+
Regular text fields
|
79 |
+
excluded_fields : Tuple[str]
|
80 |
+
Fields to exclude from the source
|
81 |
+
inference_id : str, optional
|
82 |
+
ID of model deployed in Elasticsearch, by default ".elser-2-elasticsearch"
|
83 |
+
|
84 |
+
Returns
|
85 |
+
-------
|
86 |
+
Dict[str, Any]
|
87 |
+
"""
|
88 |
+
|
89 |
+
output = []
|
90 |
+
final_query = {}
|
91 |
+
|
92 |
+
for f in semantic_fields:
|
93 |
+
output.append({
|
94 |
+
"sparse_vector": {
|
95 |
+
"field": f"{f}",
|
96 |
+
"inference_id": inference_id,
|
97 |
+
"query": query,
|
98 |
+
"boost": 1,
|
99 |
+
"prune": True # doesn't seem it changes anything if we use text queries additionally
|
100 |
+
}
|
101 |
+
})
|
102 |
+
|
103 |
+
if text_fields:
|
104 |
+
output.append({
|
105 |
+
"multi_match": {
|
106 |
+
"fields": text_fields,
|
107 |
+
"query": query,
|
108 |
+
"boost": 3
|
109 |
+
}
|
110 |
+
})
|
111 |
+
|
112 |
+
|
113 |
+
final_query = {
|
114 |
+
"track_total_hits": False,
|
115 |
+
"query": {
|
116 |
+
"bool": {"should": output}
|
117 |
+
}
|
118 |
+
}
|
119 |
+
|
120 |
+
if highlight_fields:
|
121 |
+
final_query["highlight"] = {
|
122 |
+
"fields": {
|
123 |
+
f"{f}": {
|
124 |
+
"type": "semantic", # ensures that highlighting is applied exclusively to semantic_text fields.
|
125 |
+
"number_of_fragments": 2, # number of chunks
|
126 |
+
"order": "none" # can be "score", but we have only two and hope for context
|
127 |
+
}
|
128 |
+
for f in highlight_fields
|
129 |
+
}
|
130 |
+
}
|
131 |
+
|
132 |
+
if excluded_fields:
|
133 |
+
final_query["_source"] = {"excludes": list(excluded_fields)}
|
134 |
+
return final_query
|
135 |
+
|
136 |
+
|
137 |
+
def news_query_builder(
|
138 |
+
query: str,
|
139 |
+
fields: tuple[str, ...],
|
140 |
+
encoder: SpladeEncoder,
|
141 |
+
days_ago: int = 60,
|
142 |
+
) -> dict[str, Any]:
|
143 |
+
"""Builds a valid Elasticsearch query against Candid news, simulating a token expansion.
|
144 |
+
|
145 |
+
Parameters
|
146 |
+
----------
|
147 |
+
query : str
|
148 |
+
Search context string
|
149 |
+
|
150 |
+
Returns
|
151 |
+
-------
|
152 |
+
Dict[str, Any]
|
153 |
+
"""
|
154 |
+
|
155 |
+
tokens = encoder.token_expand(query)
|
156 |
+
|
157 |
+
elastic_query = {
|
158 |
+
"_source": ["id", "link", "title", "content", "site_name"],
|
159 |
+
"query": {
|
160 |
+
"bool": {
|
161 |
+
"filter": [
|
162 |
+
{"range": {"event_date": {"gte": f"now-{days_ago}d/d"}}},
|
163 |
+
{"range": {"insert_date": {"gte": f"now-{days_ago}d/d"}}},
|
164 |
+
{"range": {"article_trust_worthiness": {"gt": NEWS_TRUST_SCORE_THRESHOLD}}}
|
165 |
+
],
|
166 |
+
"should": []
|
167 |
+
}
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
for token, score in tokens.items():
|
172 |
+
if score > SPARSE_ENCODING_SCORE_THRESHOLD:
|
173 |
+
elastic_query["query"]["bool"]["should"].append({
|
174 |
+
"multi_match": {
|
175 |
+
"query": token,
|
176 |
+
"fields": fields,
|
177 |
+
"boost": score
|
178 |
+
}
|
179 |
+
})
|
180 |
+
return elastic_query
|
181 |
+
|
182 |
+
|
183 |
+
def multi_search_base(
|
184 |
+
queries: list[dict[str, Any]],
|
185 |
+
credentials: BaseElasticSearchConnection | BaseElasticAPIKeyCredential,
|
186 |
+
timeout: int = 180
|
187 |
+
) -> Iterator[dict[str, Any]]:
|
188 |
+
if isinstance(credentials, BaseElasticAPIKeyCredential):
|
189 |
+
es = Elasticsearch(
|
190 |
+
cloud_id=credentials.cloud_id,
|
191 |
+
api_key=credentials.api_key,
|
192 |
+
verify_certs=False,
|
193 |
+
request_timeout=timeout
|
194 |
+
)
|
195 |
+
elif isinstance(credentials, BaseElasticSearchConnection):
|
196 |
+
es = Elasticsearch(
|
197 |
+
credentials.url,
|
198 |
+
http_auth=(credentials.username, credentials.password),
|
199 |
+
timeout=timeout
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
raise TypeError(f"Invalid credentials of type `{type(credentials)}")
|
203 |
+
|
204 |
+
yield from es.msearch(body=queries).get("responses", [])
|
205 |
+
es.close()
|
ask_candid/base/retrieval/knowledge_base.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal, Any
|
2 |
+
from collections.abc import Iterator, Iterable
|
3 |
+
from itertools import groupby
|
4 |
+
import logging
|
5 |
+
|
6 |
+
from langchain_core.documents import Document
|
7 |
+
|
8 |
+
from ask_candid.base.retrieval.elastic import (
|
9 |
+
build_sparse_vector_query,
|
10 |
+
build_sparse_vector_and_text_query,
|
11 |
+
news_query_builder,
|
12 |
+
multi_search_base
|
13 |
+
)
|
14 |
+
from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
|
15 |
+
from ask_candid.base.retrieval.schemas import ElasticHitsResult
|
16 |
+
import ask_candid.base.retrieval.sources as S
|
17 |
+
from ask_candid.services.small_lm import CandidSLM
|
18 |
+
|
19 |
+
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
|
20 |
+
|
21 |
+
SourceNames = Literal[
|
22 |
+
"Candid Blog",
|
23 |
+
"Candid Help",
|
24 |
+
"Candid Learning",
|
25 |
+
"Candid News",
|
26 |
+
"IssueLab Research Reports",
|
27 |
+
"YouTube Training"
|
28 |
+
]
|
29 |
+
sparse_encoder = SpladeEncoder()
|
30 |
+
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
logger.setLevel(logging.INFO)
|
33 |
+
|
34 |
+
|
35 |
+
# TODO remove
|
36 |
+
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
|
37 |
+
"""Pads the relevant chunk of text with context before and after
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
field_name : str
|
42 |
+
a field with the long text that was chunked into pieces
|
43 |
+
hit : ElasticHitsResult
|
44 |
+
context_length : int, optional
|
45 |
+
length of text to add before and after the chunk, by default 1024
|
46 |
+
add_context : bool, optional
|
47 |
+
Set to `False` to expand the text context by searching for the Elastic inner hit inside the larger document
|
48 |
+
, by default True
|
49 |
+
|
50 |
+
Returns
|
51 |
+
-------
|
52 |
+
str
|
53 |
+
longer chunks stuffed together
|
54 |
+
"""
|
55 |
+
|
56 |
+
chunks = []
|
57 |
+
# NOTE chunks have tokens, long text is a string, but may contain html which affects tokenization
|
58 |
+
long_text = hit.source.get(field_name) or ""
|
59 |
+
long_text = long_text.lower()
|
60 |
+
|
61 |
+
inner_hits_field = f"embeddings.{field_name}.chunks"
|
62 |
+
found_chunks = hit.inner_hits.get(inner_hits_field, {}) if hit.inner_hits else None
|
63 |
+
if found_chunks:
|
64 |
+
for h in found_chunks.get("hits", {}).get("hits") or []:
|
65 |
+
chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
|
66 |
+
|
67 |
+
# cutting the middle because we may have tokenizing artifacts there
|
68 |
+
chunk = chunk[3: -3]
|
69 |
+
|
70 |
+
if add_context:
|
71 |
+
# Find the start and end indices of the chunk in the large text
|
72 |
+
start_index = long_text.find(chunk[:20])
|
73 |
+
|
74 |
+
# Chunk is found
|
75 |
+
if start_index != -1:
|
76 |
+
end_index = start_index + len(chunk)
|
77 |
+
pre_start_index = max(0, start_index - context_length)
|
78 |
+
post_end_index = min(len(long_text), end_index + context_length)
|
79 |
+
chunks.append(long_text[pre_start_index:post_end_index])
|
80 |
+
else:
|
81 |
+
chunks.append(chunk)
|
82 |
+
return '\n\n'.join(chunks)
|
83 |
+
|
84 |
+
|
85 |
+
def generate_queries(
|
86 |
+
query: str,
|
87 |
+
sources: list[SourceNames],
|
88 |
+
news_days_ago: int = 60
|
89 |
+
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
90 |
+
"""Builds Elastic queries against indices which do or do not support sparse vector queries.
|
91 |
+
|
92 |
+
Parameters
|
93 |
+
----------
|
94 |
+
query : str
|
95 |
+
Text describing a user's question or a description of investigative work which requires support from Candid's
|
96 |
+
knowledge base
|
97 |
+
sources : list[SourceNames]
|
98 |
+
One or more sources of knowledge from different areas at Candid.
|
99 |
+
* Candid Blog: Blog posts from Candid staff and trusted partners intended to help those in the sector or
|
100 |
+
illuminate ongoing work
|
101 |
+
* Candid Help: Candid FAQs to help user's get started with Candid's product platform and learning resources
|
102 |
+
* Candid Learning: Training documents from Candid's subject matter experts
|
103 |
+
* Candid News: News articles and press releases about real-time activity in the philanthropic sector
|
104 |
+
* IssueLab Research Reports: Academic research reports about the social/philanthropic sector
|
105 |
+
* YouTube Training: Transcripts from video-based training seminars from Candid's subject matter experts
|
106 |
+
news_days_ago : int, optional
|
107 |
+
How many days in the past to search for news articles, if a user is asking for recent trends then this value
|
108 |
+
should be set lower >~ 10, by default 60
|
109 |
+
|
110 |
+
Returns
|
111 |
+
-------
|
112 |
+
tuple[list[dict[str, Any]], list[dict[str, Any]]]
|
113 |
+
(sparse vector queries, queries for indices which do not support sparse vectors)
|
114 |
+
"""
|
115 |
+
|
116 |
+
vector_queries = []
|
117 |
+
quasi_vector_queries = []
|
118 |
+
|
119 |
+
for source_name in sources:
|
120 |
+
if source_name == "Candid Blog":
|
121 |
+
q = build_sparse_vector_query(query=query, fields=S.CandidBlogConfig.semantic_fields)
|
122 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
123 |
+
q["size"] = 5
|
124 |
+
vector_queries.extend([{"index": S.CandidBlogConfig.index_name}, q])
|
125 |
+
elif source_name == "Candid Help":
|
126 |
+
q = build_sparse_vector_query(query=query, fields=S.CandidHelpConfig.semantic_fields)
|
127 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
128 |
+
q["size"] = 5
|
129 |
+
vector_queries.extend([{"index": S.CandidHelpConfig.index_name}, q])
|
130 |
+
elif source_name == "Candid Learning":
|
131 |
+
q = build_sparse_vector_query(query=query, fields=S.CandidLearningConfig.semantic_fields)
|
132 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
133 |
+
q["size"] = 5
|
134 |
+
vector_queries.extend([{"index": S.CandidLearningConfig.index_name}, q])
|
135 |
+
elif source_name == "Candid News":
|
136 |
+
q = news_query_builder(
|
137 |
+
query=query,
|
138 |
+
fields=S.CandidNewsConfig.semantic_fields,
|
139 |
+
encoder=sparse_encoder,
|
140 |
+
days_ago=news_days_ago
|
141 |
+
)
|
142 |
+
q["size"] = 5
|
143 |
+
quasi_vector_queries.extend([{"index": S.CandidNewsConfig.index_name}, q])
|
144 |
+
elif source_name == "IssueLab Research Reports":
|
145 |
+
q = build_sparse_vector_query(query=query, fields=S.IssueLabConfig.semantic_fields)
|
146 |
+
q["_source"] = {"excludes": ["embeddings"]}
|
147 |
+
q["size"] = 1
|
148 |
+
vector_queries.extend([{"index": S.IssueLabConfig.index_name}, q])
|
149 |
+
elif source_name == "YouTube Training":
|
150 |
+
q = build_sparse_vector_and_text_query(
|
151 |
+
query=query,
|
152 |
+
semantic_fields=S.YoutubeConfig.semantic_fields,
|
153 |
+
text_fields=S.YoutubeConfig.text_fields,
|
154 |
+
highlight_fields=S.YoutubeConfig.highlight_fields,
|
155 |
+
excluded_fields=S.YoutubeConfig.excluded_fields
|
156 |
+
)
|
157 |
+
q["size"] = 5
|
158 |
+
vector_queries.extend([{"index": S.YoutubeConfig.index_name}, q])
|
159 |
+
|
160 |
+
return vector_queries, quasi_vector_queries
|
161 |
+
|
162 |
+
|
163 |
+
def run_search(
|
164 |
+
vector_searches: list[dict[str, Any]] | None = None,
|
165 |
+
non_vector_searches: list[dict[str, Any]] | None = None,
|
166 |
+
) -> list[ElasticHitsResult]:
|
167 |
+
def _msearch_response_generator(responses: Iterable[dict[str, Any]]) -> Iterator[ElasticHitsResult]:
|
168 |
+
for query_group in responses:
|
169 |
+
for h in query_group.get("hits", {}).get("hits", []):
|
170 |
+
inner_hits = h.get("inner_hits", {})
|
171 |
+
|
172 |
+
if not inner_hits and "news" in h.get("_index"):
|
173 |
+
inner_hits = {"text": h.get("_source", {}).get("content")}
|
174 |
+
|
175 |
+
yield ElasticHitsResult(
|
176 |
+
index=h["_index"],
|
177 |
+
id=h["_id"],
|
178 |
+
score=h["_score"],
|
179 |
+
source=h["_source"],
|
180 |
+
inner_hits=inner_hits,
|
181 |
+
highlight=h.get("highlight", {})
|
182 |
+
)
|
183 |
+
|
184 |
+
results = []
|
185 |
+
if vector_searches is not None and len(vector_searches) > 0:
|
186 |
+
hits = multi_search_base(queries=vector_searches, credentials=SEMANTIC_ELASTIC_QA)
|
187 |
+
for hit in _msearch_response_generator(responses=hits):
|
188 |
+
results.append(hit)
|
189 |
+
if non_vector_searches is not None and len(non_vector_searches) > 0:
|
190 |
+
hits = multi_search_base(queries=non_vector_searches, credentials=NEWS_ELASTIC)
|
191 |
+
for hit in _msearch_response_generator(responses=hits):
|
192 |
+
results.append(hit)
|
193 |
+
return results
|
194 |
+
|
195 |
+
|
196 |
+
def retrieved_text(hits: dict[str, Any]) -> str:
|
197 |
+
"""Extracts retrieved sub-texts from documents which are strong hits from semantic queries for the purpose of
|
198 |
+
re-scoring by a secondary language model.
|
199 |
+
|
200 |
+
Parameters
|
201 |
+
----------
|
202 |
+
hits : Dict[str, Any]
|
203 |
+
|
204 |
+
Returns
|
205 |
+
-------
|
206 |
+
str
|
207 |
+
"""
|
208 |
+
|
209 |
+
nlp = CandidSLM()
|
210 |
+
|
211 |
+
text = []
|
212 |
+
for _, v in hits.items():
|
213 |
+
if _ == "text":
|
214 |
+
s = nlp.summarize(v, top_k=3)
|
215 |
+
text.append(s.summary)
|
216 |
+
# text.append(v)
|
217 |
+
continue
|
218 |
+
|
219 |
+
for h in (v.get("hits", {}).get("hits") or []):
|
220 |
+
for _, field in h.get("fields", {}).items():
|
221 |
+
for chunk in field:
|
222 |
+
if chunk.get("chunk"):
|
223 |
+
text.extend(chunk["chunk"])
|
224 |
+
return '\n'.join(text)
|
225 |
+
|
226 |
+
|
227 |
+
def reranker(
|
228 |
+
query_results: Iterable[ElasticHitsResult],
|
229 |
+
search_text: str | None = None,
|
230 |
+
max_num_results: int = 5
|
231 |
+
) -> Iterator[ElasticHitsResult]:
|
232 |
+
"""Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
|
233 |
+
This will shuffle results
|
234 |
+
|
235 |
+
Parameters
|
236 |
+
----------
|
237 |
+
query_results : Iterable[ElasticHitsResult]
|
238 |
+
|
239 |
+
Yields
|
240 |
+
------
|
241 |
+
Iterator[ElasticHitsResult]
|
242 |
+
"""
|
243 |
+
|
244 |
+
results: list[ElasticHitsResult] = []
|
245 |
+
texts: list[str] = []
|
246 |
+
for _, data in groupby(query_results, key=lambda x: x.index):
|
247 |
+
data = list(data) # noqa: PLW2901
|
248 |
+
max_score = max(data, key=lambda x: x.score).score
|
249 |
+
min_score = min(data, key=lambda x: x.score).score
|
250 |
+
|
251 |
+
for d in data:
|
252 |
+
d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
|
253 |
+
results.append(d)
|
254 |
+
|
255 |
+
if search_text:
|
256 |
+
if d.inner_hits:
|
257 |
+
text = retrieved_text(d.inner_hits)
|
258 |
+
if d.highlight:
|
259 |
+
highlight_texts = []
|
260 |
+
for k,v in d.highlight.items():
|
261 |
+
v_text = '\n'.join(v)
|
262 |
+
highlight_texts.append(v_text)
|
263 |
+
text = '\n'.join(highlight_texts)
|
264 |
+
texts.append(text)
|
265 |
+
|
266 |
+
if search_text and len(texts) == len(results) and len(texts) > 1:
|
267 |
+
logger.info("Re-ranking %d retrieval results", len(results))
|
268 |
+
scores = sparse_encoder.query_reranking(query=search_text, documents=texts)
|
269 |
+
for r, s in zip(results, scores):
|
270 |
+
r.score = s
|
271 |
+
|
272 |
+
yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
|
273 |
+
|
274 |
+
|
275 |
+
def process_hit(hit: ElasticHitsResult) -> Document:
    """Convert a raw Elasticsearch hit into a LangChain ``Document``.

    Dispatches on the hit's index name; each branch assembles ``page_content`` from
    index-specific fields and attaches citation metadata (``title`` / ``source`` label /
    ``source_id`` / ``url``) used downstream when rendering sources.

    Raises
    ------
    ValueError
        If the hit's index does not match any known knowledge source.
    """
    # IssueLab research reports
    if "issuelab-elser" in hit.index:
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("combined_item_description", ""),
                hit.source.get("description", ""),
                hit.source.get("combined_issuelab_findings", ""),
                get_context("content", hit, context_length=12)
            ]),
            metadata={
                "title": hit.source["title"],
                "source": "IssueLab",
                "source_id": hit.source["resource_id"],
                "url": hit.source.get("permalink", "")
            }
        )
    # Candid YouTube videos: caption highlights are joined into the content
    elif "youtube" in hit.index:
        highlight = hit.highlight or {}
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("title", ""),
                hit.source.get("semantic_description", ""),
                ' '.join(highlight.get("semantic_cc_text", []))
            ]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": "Candid YouTube",
                "source_id": hit.source['video_id'],
                # build a watch URL directly from the video ID
                "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
            }
        )
    # Candid Blog posts: context windows are pulled without surrounding context
    elif "candid-blog" in hit.index:
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("title", ""),
                hit.source.get("excerpt", ""),
                get_context("content", hit, context_length=12, add_context=False),
                get_context("authors_text", hit, context_length=12, add_context=False),
                hit.source.get("title_summary_tags", "")
            ]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": "Candid Blog",
                "source_id": hit.source["id"],
                "url": hit.source["link"]
            }
        )
    # Candid Learning training materials
    elif "candid-learning" in hit.index:
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("title", ""),
                hit.source.get("staff_recommendations", ""),
                hit.source.get("training_topics", ""),
                get_context("content", hit, context_length=12)
            ]),
            metadata={
                "title": hit.source["title"],
                "source": "Candid Learning",
                "source_id": hit.source["post_id"],
                "url": hit.source.get("url", "")
            }
        )
    # Candid Help knowledge-base articles
    elif "candid-help" in hit.index:
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("combined_article_description", ""),
                get_context("content", hit, context_length=12)
            ]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": "Candid Help",
                "source_id": hit.source["id"],
                "url": hit.source.get("link", "")
            }
        )
    # Philanthropy news articles; fall back to a generic label when the site name is absent
    elif "news" in hit.index:
        doc = Document(
            page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": hit.source.get("site_name") or "Candid News",
                "source_id": hit.source["id"],
                "url": hit.source.get("link", "")
            }
        )
    else:
        raise ValueError(f"Unknown source result from index {hit.index}")
    return doc
|
ask_candid/base/retrieval/schemas.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
class ElasticSourceConfig:
    """Per-index retrieval configuration for one Elasticsearch knowledge source."""
    # name of the Elasticsearch index to query
    index_name: str
    # fields carrying semantic (e.g. ELSER) embeddings used for semantic retrieval
    semantic_fields: tuple[str,...] = field(default_factory=tuple)
    # plain-text fields used for lexical matching (optional)
    text_fields: tuple[str,...] | None = field(default_factory=tuple)
    # fields to request highlight fragments for (optional)
    highlight_fields: tuple[str,...] | None = field(default_factory=tuple)
    # fields excluded from the returned _source payload, e.g. large raw text (optional)
    excluded_fields: tuple[str,...] | None = field(default_factory=tuple)
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
class ElasticHitsResult:
    """Dataclass for Elasticsearch hits results
    """
    # index the hit came from; used to dispatch per-source processing
    index: str
    # the document's _id
    id: Any
    # relevance score; may be re-scaled/overwritten by downstream re-ranking
    score: float
    # the hit's _source payload
    source: dict[str, Any]
    # nested/inner-hit chunks, when the query requested them
    inner_hits: dict[str, Any] | None
    # highlighted fragments keyed by field name
    highlight: dict[str, list[str]] | None
|
ask_candid/base/retrieval/sources.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ask_candid.base.retrieval.schemas import ElasticSourceConfig
|
2 |
+
|
3 |
+
|
4 |
+
# Static retrieval configurations for each Elasticsearch-backed knowledge source.
# Only `semantic_fields` is set for most sources; YouTube additionally opts into
# plain-text fields, highlights, and source-payload exclusions.

# Candid Blog posts
CandidBlogConfig = ElasticSourceConfig(
    index_name="search-semantic-candid-blog",
    semantic_fields=("content", "authors_text", "title_summary_tags")
)


# Candid Help knowledge-base articles
CandidHelpConfig = ElasticSourceConfig(
    index_name="search-semantic-candid-help-elser_ve1",
    semantic_fields=("content", "combined_article_description")
)


# Candid Learning training materials
CandidLearningConfig = ElasticSourceConfig(
    index_name="search-semantic-candid-learning_ve1",
    semantic_fields=("content", "title", "training_topics", "staff_recommendations")
)


# Philanthropy news articles
CandidNewsConfig = ElasticSourceConfig(
    index_name="news_1",
    semantic_fields=("title", "content")
)


# IssueLab research reports
IssueLabConfig = ElasticSourceConfig(
    index_name="search-semantic-issuelab-elser_ve2",
    semantic_fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
)


# Candid YouTube videos: raw/semantic caption text is excluded from _source to keep
# payloads small; caption highlights are requested instead.
YoutubeConfig = ElasticSourceConfig(
    index_name="search-semantic-youtube",
    semantic_fields=("semantic_title", "semantic_description","semantic_cc_text"),
    text_fields=("title", "description", "cc_text"),
    highlight_fields=("semantic_cc_text",),
    excluded_fields=("cc_text", "semantic_cc_text", "semantic_title")
)
|
ask_candid/base/retrieval/sparse_lexical.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm.auto import tqdm
|
2 |
+
|
3 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
4 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import Tensor
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
class SpladeEncoder:
    """SPLADE sparse lexical encoder (``naver/splade-v3``) for re-ranking and query term expansion.

    The model and tokenizer are loaded eagerly at construction time and moved to the best
    available device (CUDA > MPS > CPU).
    """
    # mini-batch size for encoding; kept small since inputs may be long documents
    batch_size = 8
    model_id = "naver/splade-v3"

    def __init__(self):

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        # from_pretrained returns the model in eval mode, so no explicit .eval() is needed here
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        # reverse vocabulary lookup (vocab index -> token string), used by `token_expand`
        self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        self.model.to(self.device)

    @torch.no_grad()
    def forward(self, inputs: BatchEncoding) -> Tensor:
        # SPLADE pooling: log(1 + ReLU(logits)) masked by attention, then max over the
        # sequence dimension -> one |vocab|-sized sparse vector per input.
        output = self.model(**inputs.to(self.device))

        logits: Tensor = output.logits
        mask: Tensor = inputs.attention_mask

        vec = (logits.relu() + 1).log() * mask.unsqueeze(dim=-1)
        # NOTE(review): `.squeeze()` drops the batch dim for single-item batches; `torch.vstack`
        # in `encode` re-promotes 1-D rows, so downstream shapes still line up.
        return vec.max(dim=1)[0].squeeze()

    def encode(self, texts: list[str]) -> Tensor:
        """Forward pass to get dense vectors

        Parameters
        ----------
        texts : list[str]

        Returns
        -------
        torch.Tensor
            Dense vectors
        """

        vectors = []
        # DataLoader is used purely for convenient, order-preserving mini-batching of strings
        for batch in tqdm(DataLoader(dataset=texts, shuffle=False, batch_size=self.batch_size), desc="Encoding"): # type: ignore
            tokens = self.tokenizer(batch, return_tensors='pt', truncation=True, padding=True)
            vec = self.forward(inputs=tokens)
            vectors.append(vec)
        return torch.vstack(vectors)

    def query_reranking(self, query: str, documents: list[str]) -> list[float]:
        """Cosine similarity re-ranking.

        Parameters
        ----------
        query : str
            Retrieval query
        documents : list[str]
            Retrieved documents

        Returns
        -------
        list[float]
            Cosine values
        """

        # encode the query together with the documents in one pass, then split: row 0 is the
        # query, the rest are the documents
        vec = self.encode([query, *documents])
        xQ = F.normalize(vec[:1], dim=-1, p=2.)
        xD = F.normalize(vec[1:], dim=-1, p=2.)
        return (xQ * xD).sum(dim=-1).cpu().tolist()

    def token_expand(self, query: str) -> dict[str, float]:
        """Sparse lexical token expansion.

        Parameters
        ----------
        query : str
            Retrieval query

        Returns
        -------
        dict[str, float]
            Token -> weight, sorted by descending weight
        """

        vec = self.encode([query]).squeeze()
        # NOTE(review): if the vector has exactly one nonzero entry, `.squeeze()` yields a
        # scalar and `vec[cols]` would fail — in practice SPLADE vectors have many nonzeros.
        cols = vec.nonzero().squeeze().cpu().tolist()
        weights = vec[cols].cpu().tolist()

        sparse_dict_tokens = {self.idx2token[idx]: round(weight, 3) for idx, weight in zip(cols, weights) if weight > 0}
        return dict(sorted(sparse_dict_tokens.items(), key=lambda item: item[1], reverse=True))
|
ask_candid/base/utils.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
import asyncio
|
2 |
|
3 |
|
@@ -12,3 +15,52 @@ def async_tasks(*tasks):
|
|
12 |
loop.stop()
|
13 |
loop.close()
|
14 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections.abc import Callable
|
2 |
+
from functools import wraps
|
3 |
+
from time import sleep
|
4 |
import asyncio
|
5 |
|
6 |
|
|
|
15 |
loop.stop()
|
16 |
loop.close()
|
17 |
return results
|
18 |
+
|
19 |
+
|
20 |
+
def retry_on_status(
|
21 |
+
num_retries: int = 3,
|
22 |
+
backoff_factor: float = 0.5,
|
23 |
+
max_backoff: float | None = None,
|
24 |
+
retry_statuses: tuple[int, ...] = (501, 503)
|
25 |
+
):
|
26 |
+
"""
|
27 |
+
Retry decorator for functions making httpx requests.
|
28 |
+
Retries on specific HTTP status codes with exponential backoff.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
num_retries (int): Max number of retries.
|
32 |
+
backoff_factor (float): Multiplier for delay (e.g., 0.5, 1, etc.).
|
33 |
+
max_backoff (float, optional): Cap on the backoff delay in seconds.
|
34 |
+
retry_statuses (tuple): HTTP status codes to retry on.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def decorator(func: Callable):
|
38 |
+
|
39 |
+
if asyncio.iscoroutinefunction(func):
|
40 |
+
# Async version
|
41 |
+
@wraps(func)
|
42 |
+
async def async_wrapper(*args, **kwargs):
|
43 |
+
for attempt in range(num_retries + 1):
|
44 |
+
response = await func(*args, **kwargs)
|
45 |
+
if response.status_code not in retry_statuses:
|
46 |
+
return response
|
47 |
+
if attempt < num_retries:
|
48 |
+
delay = min(backoff_factor * (2 ** attempt), max_backoff or float('inf'))
|
49 |
+
await asyncio.sleep(delay)
|
50 |
+
return response
|
51 |
+
return async_wrapper
|
52 |
+
|
53 |
+
# Sync version
|
54 |
+
@wraps(func)
|
55 |
+
def sync_wrapper(*args, **kwargs):
|
56 |
+
for attempt in range(num_retries + 1):
|
57 |
+
response = func(*args, **kwargs)
|
58 |
+
if response.status_code not in retry_statuses:
|
59 |
+
return response
|
60 |
+
if attempt < num_retries:
|
61 |
+
delay = min(backoff_factor * (2 ** attempt), max_backoff or float('inf'))
|
62 |
+
sleep(delay)
|
63 |
+
return response
|
64 |
+
return sync_wrapper
|
65 |
+
|
66 |
+
return decorator
|
ask_candid/chat.py
CHANGED
@@ -1,66 +1,79 @@
|
|
1 |
-
from typing import
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
from
|
5 |
-
from langgraph.checkpoint.memory import MemorySaver
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
indices: Optional[List[str]] = None,
|
18 |
-
premium_features: Optional[List[str]] = None,
|
19 |
-
) -> Tuple[gr.MultimodalTextbox, List[Dict[str, Any]], str]:
|
20 |
-
if premium_features is None:
|
21 |
-
premium_features = []
|
22 |
-
if len(history) == 0:
|
23 |
-
history.append({"role": "system", "content": START_SYSTEM_PROMPT})
|
24 |
|
25 |
-
history.append({"role": "user", "content": user_input["text"]})
|
26 |
-
inputs = {"messages": history}
|
27 |
-
# thread_id can be an email https://github.com/yurisasc/memory-enhanced-ai-assistant/blob/main/assistant.py
|
28 |
-
thread_id = get_session_id(thread_id)
|
29 |
-
config = {"configurable": {"thread_id": thread_id}}
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
enable_recommendations=enable_recommendations
|
37 |
-
)
|
38 |
|
39 |
-
memory = MemorySaver() # TODO: don't use for Prod
|
40 |
-
graph = workflow.compile(checkpointer=memory)
|
41 |
-
response = graph.invoke(inputs, config=config)
|
42 |
-
messages = response["messages"]
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
last_message = messages[-1]
|
51 |
-
ai_answer = last_message.content
|
52 |
|
53 |
-
sources_html = ""
|
54 |
-
for message in messages[-2:]:
|
55 |
-
if message.type == "HTML":
|
56 |
-
sources_html = message.content
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
"metadata": {"title": "Sources HTML"},
|
64 |
-
})
|
65 |
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypedDict, Literal, Any
|
2 |
+
from collections.abc import Iterator
|
3 |
+
from dataclasses import asdict
|
4 |
+
import logging
|
5 |
+
import json
|
6 |
|
7 |
+
from langchain_core.messages.tool import ToolMessage
|
8 |
+
from gradio import ChatMessage
|
|
|
9 |
|
10 |
+
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
logger.setLevel(logging.INFO)
|
13 |
|
14 |
|
15 |
+
class ToolInput(TypedDict):
    """Shape of a single tool call emitted by the LangGraph agent."""
    # name of the tool being invoked
    name: str
    # keyword arguments passed to the tool
    args: dict[str, Any]
    # unique tool-call ID, used to pair the call with its result
    id: str
    type: Literal["tool_call"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
+
class CalledTool(TypedDict):
    """Graph stream chunk describing tool invocations requested by the agent."""
    # chunk ID; also used as the parent ID when rendering the call in the chat UI
    id: str
    name: Literal["tools"]
    # one entry per requested tool call
    input: list[ToolInput]
    # graph-event trigger names associated with this chunk
    triggers: tuple[str, ...]
|
|
|
|
|
27 |
|
|
|
|
|
|
|
|
|
28 |
|
29 |
+
class ToolResult(TypedDict):
    """Graph stream chunk carrying the outputs of executed tools."""
    # chunk ID; used as the parent ID when rendering results in the chat UI
    id: str
    name: Literal["tools"]
    # whether the tool execution errored (None when unknown)
    error: bool | None
    # (graph node name, tool messages) pairs produced by the tool node
    result: list[tuple[str, list[ToolMessage]]]
    # pending graph interrupts, if any
    interrupts: list
|
|
|
|
|
35 |
|
|
|
|
|
|
|
|
|
36 |
|
37 |
+
def convert_history_for_graph_agent(history: list[dict | ChatMessage]) -> list[dict]:
    """Normalize a Gradio chat history for the graph agent.

    Dataclass ``ChatMessage`` entries are converted to plain dicts, and messages with
    empty or missing content are dropped.
    """
    def _as_dict(message: dict | ChatMessage) -> dict:
        return asdict(message) if isinstance(message, ChatMessage) else message

    return [m for m in map(_as_dict, history) if m.get("content")]
|
49 |
+
|
50 |
+
|
51 |
+
def format_tool_call(input_chunk: CalledTool) -> Iterator[ChatMessage]:
    """Render each tool invocation in a graph chunk as an assistant message.

    The message content is the JSON-serialized tool arguments; the tool name goes in the
    metadata title so the UI can show it as a collapsed step.
    """
    chunk_id = input_chunk["id"]
    for call in input_chunk["input"]:
        metadata = {
            "title": f"Using tool `{call.get('name')}`",
            "status": "done",
            "id": chunk_id,
            "parent_id": chunk_id
        }
        yield ChatMessage(role="assistant", content=json.dumps(call["args"]), metadata=metadata)
|
63 |
+
|
64 |
+
|
65 |
+
def format_tool_response(result_chunk: ToolResult) -> Iterator[ChatMessage]:
    """Render executed tool outputs as assistant messages.

    Each tool message's content becomes the chat content; the tool name and any retrieved
    documents (artifacts) travel in the message metadata for downstream rendering.
    """
    parent = result_chunk["id"]
    for _, tool_messages in result_chunk["result"]:
        for message in tool_messages:
            logger.info("Called tool `%s`", message.name)
            meta = {
                "title": f"Results from tool `{message.name}`",
                "tool_name": message.name,
                "documents": message.artifact,
                "status": "done",
                "parent_id": parent
            }
            yield ChatMessage(
                role="assistant",
                content=message.content,
                metadata=meta  # pyright: ignore[reportArgumentType]
            )
|
ask_candid/services/small_lm.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
from typing import List, Optional
|
2 |
from dataclasses import dataclass
|
3 |
from enum import Enum
|
4 |
|
@@ -9,10 +8,26 @@ from ask_candid.base.lambda_base import LambdaInvokeBase
|
|
9 |
|
10 |
@dataclass(slots=True)
|
11 |
class Encoding:
|
12 |
-
inputs:
|
13 |
vectors: torch.Tensor
|
14 |
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
class CandidSLM(LambdaInvokeBase):
|
17 |
"""Wrapper around Candid's custom small language model.
|
18 |
For more details see https://dev.azure.com/guidestar/DataScience/_git/graph-ai?path=/releases/language.
|
@@ -35,7 +50,7 @@ class CandidSLM(LambdaInvokeBase):
|
|
35 |
DOCUMENT_NER_SALIENCE = "/document/entitySalience"
|
36 |
|
37 |
def __init__(
|
38 |
-
self, access_key:
|
39 |
) -> None:
|
40 |
super().__init__(
|
41 |
function_name="small-lm",
|
@@ -43,11 +58,22 @@ class CandidSLM(LambdaInvokeBase):
|
|
43 |
secret_key=secret_key
|
44 |
)
|
45 |
|
46 |
-
def encode(self, text:
|
47 |
response = self._submit_request({"text": text, "path": self.Tasks.ENCODE.value})
|
|
|
48 |
|
49 |
-
|
50 |
inputs=(response.get("inputs") or []),
|
51 |
vectors=torch.tensor((response.get("vectors") or []), dtype=torch.float32)
|
52 |
)
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from dataclasses import dataclass
|
2 |
from enum import Enum
|
3 |
|
|
|
8 |
|
9 |
@dataclass(slots=True)
|
10 |
class Encoding:
|
11 |
+
inputs: list[str]
|
12 |
vectors: torch.Tensor
|
13 |
|
14 |
|
15 |
+
@dataclass(slots=True)
|
16 |
+
class SummaryItem:
|
17 |
+
rank: int
|
18 |
+
score: float
|
19 |
+
text: str
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass(slots=True)
|
23 |
+
class TextSummary:
|
24 |
+
snippets: list[SummaryItem]
|
25 |
+
|
26 |
+
@property
|
27 |
+
def summary(self) -> str:
|
28 |
+
return ' '.join([_.text for _ in self.snippets])
|
29 |
+
|
30 |
+
|
31 |
class CandidSLM(LambdaInvokeBase):
|
32 |
"""Wrapper around Candid's custom small language model.
|
33 |
For more details see https://dev.azure.com/guidestar/DataScience/_git/graph-ai?path=/releases/language.
|
|
|
50 |
DOCUMENT_NER_SALIENCE = "/document/entitySalience"
|
51 |
|
52 |
def __init__(
|
53 |
+
self, access_key: str | None = None, secret_key: str | None = None
|
54 |
) -> None:
|
55 |
super().__init__(
|
56 |
function_name="small-lm",
|
|
|
58 |
secret_key=secret_key
|
59 |
)
|
60 |
|
61 |
+
def encode(self, text: list[str]) -> Encoding:
|
62 |
response = self._submit_request({"text": text, "path": self.Tasks.ENCODE.value})
|
63 |
+
assert isinstance(response, dict)
|
64 |
|
65 |
+
return Encoding(
|
66 |
inputs=(response.get("inputs") or []),
|
67 |
vectors=torch.tensor((response.get("vectors") or []), dtype=torch.float32)
|
68 |
)
|
69 |
+
|
70 |
+
    def summarize(self, text: list[str], top_k: int) -> TextSummary:
        """Run extractive summarization over `text` via the small-LM Lambda.

        NOTE(review): `top_k` only truncates the returned snippets client-side; the service
        computes the full ranked summary regardless — confirm this is the intended contract.
        """
        response = self._submit_request({"text": text, "path": self.Tasks.DOCUMENT_SUMMARIZE.value})
        # the Lambda wrapper may return other payload shapes; summarization must yield a JSON object
        assert isinstance(response, dict)

        return TextSummary(
            snippets=[
                SummaryItem(rank=item["rank"], score=item["score"], text=item["value"])
                for item in (response.get("summary") or [])[:top_k]
            ]
        )
|
ask_candid/tools/general.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import date
|
2 |
+
|
3 |
+
from langchain_core.tools import tool
|
4 |
+
|
5 |
+
|
6 |
+
@tool
def get_current_day() -> date:
    """Get the current day to reference for any time-sensitive data requests. This might be useful for information
    searches through news data, where more current articles may be more relevant.

    Returns
    -------
    date
        Today's date
    """

    # naive local (server-timezone) date; day-level granularity is all the agent needs
    return date.today()
|
ask_candid/tools/org_search.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
|
3 |
+
from pydantic import BaseModel, Field
|
4 |
+
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
5 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
6 |
+
from langchain_core.runnables import RunnableSequence
|
7 |
+
from langchain_core.prompts import PromptTemplate
|
8 |
+
from langchain_core.tools import tool, BaseTool
|
9 |
+
|
10 |
+
from thefuzz import fuzz
|
11 |
+
|
12 |
+
from ask_candid.tools.utils import format_candid_profile_link
|
13 |
+
from ask_candid.base.api_base import BaseAPI
|
14 |
+
from ask_candid.base.config.rest import CANDID_SEARCH_API
|
15 |
+
|
16 |
+
|
17 |
+
class OrganizationNames(BaseModel):
    """List of names of social-sector organizations, such as nonprofits and foundations."""
    # structured output target for the extraction chain in `OrganizationIdentifier`
    orgnames: list[str] = Field(..., description="List of organization names.")
|
20 |
+
|
21 |
+
|
22 |
+
class OrganizationIdentifierArgs(BaseModel):
    """Argument schema for the `organization-identifier` tool."""
    text: str = Field(..., description="Chat model response text which contains named organizations.")
|
24 |
+
|
25 |
+
|
26 |
+
class OrganizationIdentifier(BaseTool):
    """LangChain tool that extracts organization names from chat-model output via an LLM chain."""
    # chat model used to run the extraction prompt
    llm: BaseChatModel
    # annotation fix: this is a parser *instance*, not the class (was `type[PydanticOutputParser]`)
    parser: PydanticOutputParser = PydanticOutputParser(pydantic_object=OrganizationNames)
    template: str = """Extract only the names of officially recognized organizations, foundations, and government
    entities from the text below. Do not include any entries that contain descriptions, regional identifiers, or
    explanations within parentheses or following the name. Strictly exclude databases, resources, crowdfunding
    platforms, and general terms. Provide the output only in the specified JSON format.

    input text: ```{chatbot_output}```
    output format: ```{format_instructions}```
    """

    name: str = "organization-identifier"
    description: str = """
    Identify the names of nonprofits and foundations from chat model responses. If it is likely that a response contains
    proper names then it should be processed through this tool.

    Examples
    --------
    >>> `organization_identifier('My Favorite Foundation awarded a grant to My Favorite Nonprofit.')`
    >>> `organization_identifier('The LoremIpsum Nonprofit will be running a community event this Thursday')`
    """
    args_schema: type[OrganizationIdentifierArgs] = OrganizationIdentifierArgs

    def _build_pipeline(self):
        # prompt -> LLM -> structured parser; format instructions are baked into the prompt
        prompt = PromptTemplate(
            template=self.template,
            input_variables=["chatbot_output"],
            partial_variables={"format_instructions": self.parser.get_format_instructions()}
        )
        return RunnableSequence(prompt, self.llm, self.parser)

    # annotation fix: returns the extracted name list, not a string (was `-> str`)
    def _run(self, text: str) -> list[str]:
        chain = self._build_pipeline()
        result: OrganizationNames = chain.invoke({"chatbot_output": text})
        return result.orgnames

    # async variant of `_run`; annotation fix as above
    async def _arun(self, text: str) -> list[str]:
        chain = self._build_pipeline()
        result: OrganizationNames = await chain.ainvoke({"chatbot_output": text})
        return result.orgnames
|
67 |
+
|
68 |
+
|
69 |
+
def name_search(name: str) -> list[dict[str, Any]]:
    """Look up organizations by name through the Candid Search API (organization-only mode).

    Returns up to 5 candidate organization records; an empty list when nothing matches.
    """
    client = BaseAPI(
        url=f'{CANDID_SEARCH_API["url"]}/v1/search',
        headers={"x-api-key": CANDID_SEARCH_API["key"]}
    )
    payload = client.get(
        query=f"'{name}'",
        searchMode="organization_only",
        rowCount=5
    )
    return payload.get("returnedOrgs") or []
|
80 |
+
|
81 |
+
|
82 |
+
def find_similar(name: str, potential_matches: list[dict[str, Any]], threshold: int = 80):
    """Yield ``(organization, similarity)`` pairs whose best fuzzy name match meets `threshold`.

    Compares the target name against each record's legal, AKA, and DBA names and keeps the
    highest fuzz ratio.

    Parameters
    ----------
    name : str
        Organization name to match
    potential_matches : list[dict[str, Any]]
        Candidate organization records from the Candid Search API
    threshold : int, optional
        Minimum fuzz ratio (0-100) for a record to be yielded, by default 80

    Yields
    ------
    tuple[dict[str, Any], int]
    """

    # hoist: the target only needs lower-casing once
    target = name.lower()
    for org in potential_matches:
        # BUG FIX: use .get() so records missing a name variant no longer raise KeyError;
        # `or ""` additionally guards explicit nulls as before
        similarity = max(
            fuzz.ratio(target, (org.get("orgName") or "").lower()),
            fuzz.ratio(target, (org.get("akaName") or "").lower()),
            fuzz.ratio(target, (org.get("dbaName") or "").lower()),
        )
        if similarity >= threshold:
            yield org, similarity
|
91 |
+
|
92 |
+
|
93 |
+
@tool(response_format="content_and_artifact")
|
94 |
+
def find_mentioned_organizations(organizations: list[str]) -> tuple[str, dict[str, str]]:
|
95 |
+
"""Match organization names found in a chat response to official organizations tracked by Candid. This involves
|
96 |
+
using the Candid Search API in a lookup mode, and then finding the best result(s) using a heuristic string
|
97 |
+
similarity search.
|
98 |
+
|
99 |
+
This tool is focused on getting links to the organization's Candid profile for the user to click and explore in
|
100 |
+
more detail.
|
101 |
+
|
102 |
+
Use the URLs here to replace organization names in the chat response with links to the organization's profile. Links
|
103 |
+
to Candid profiles **MUST** be used to do the following:
|
104 |
+
1. Generate direct links to Candid organization profiles
|
105 |
+
2. Provide a mechanism for users to easily access detailed organizational information
|
106 |
+
3. Enhance responses with authoritative source links
|
107 |
+
|
108 |
+
Key Usage Requirements:
|
109 |
+
- Always incorporate returned profile URLs directly into the response text
|
110 |
+
- Replace organization name mentions with hyperlinked Candid profile URLs
|
111 |
+
- Prioritize creating a seamless user experience by making URLs contextually relevant
|
112 |
+
|
113 |
+
Example Desired Output:
|
114 |
+
Instead of: 'The Gates Foundation does impressive work.'
|
115 |
+
Use: 'The [Gates Foundation](https://app.candid.org/profile/XXXXX) does impressive work.'
|
116 |
+
|
117 |
+
The function returns a tuple with:
|
118 |
+
- A link information text (optional)
|
119 |
+
- A dictionary mapping input names to their best Candid Search profile URL
|
120 |
+
|
121 |
+
Failure to integrate the URLs into the response is considered an incomplete implementation.",
|
122 |
+
|
123 |
+
Examples
|
124 |
+
--------
|
125 |
+
>>> find_mentioned_organizations(organizations=['Gates Foundation', 'Candid'])
|
126 |
+
|
127 |
+
Parameters
|
128 |
+
----------
|
129 |
+
organizations : list[str]
|
130 |
+
A list of organization name strings found in a chat response message which need to be matches
|
131 |
+
|
132 |
+
Returns
|
133 |
+
-------
|
134 |
+
tuple[str, dict[str, str]]
|
135 |
+
(Link information text, mapping input name --> Candid Search profile URL of the best potential match)
|
136 |
+
"""
|
137 |
+
|
138 |
+
output = {}
|
139 |
+
for name in organizations:
|
140 |
+
search_results = name_search(name)
|
141 |
+
try:
|
142 |
+
best_result, _ = max(find_similar(name=name, potential_matches=search_results), key=lambda x: x[-1])
|
143 |
+
except ValueError:
|
144 |
+
# no similar organizations could be found for this one, keep going
|
145 |
+
continue
|
146 |
+
output[name] = format_candid_profile_link(best_result["candidEntityID"])
|
147 |
+
|
148 |
+
response = [f"The Candid profile link for {name} is {url}" for name, url in output.items()]
|
149 |
+
return '. '.join(response), output
|
150 |
+
|
151 |
+
|
152 |
+
@tool
def find_mentioned_organizations_detailed(organizations: list[str]) -> dict[str, dict[str, Any]]:
    """Match organization names found in a chat response to official organizations tracked by Candid. This involves
    using the Candid Search API in a lookup mode, and then finding the best result(s) using a heuristic string
    similarity search.

    Examples
    --------
    >>> find_mentioned_organizations(organizations=['Gates Foundation', 'Candid'])

    Parameters
    ----------
    organizations : list[str]
        A list of organization name strings found in a chat response message which need to be matches

    Returns
    -------
    dict[str, dict[str, Any]]
        Mapping from the input name(s) to the best potential match.
    """

    output = {}
    for name in organizations:
        search_results = name_search(name)
        try:
            # max() raises ValueError on an empty generator, i.e. when no candidate
            # cleared the similarity threshold
            best_result, _ = max(find_similar(name=name, potential_matches=search_results), key=lambda x: x[-1])
        except ValueError:
            # no similar organizations could be found for this one, keep going
            continue
        # unlike `find_mentioned_organizations`, return the full record rather than a profile link
        output[name] = best_result
    return output
|
ask_candid/tools/search.py
CHANGED
@@ -1,122 +1,67 @@
|
|
1 |
-
from typing import List, Tuple, Callable, Optional, Any
|
2 |
-
from functools import partial
|
3 |
-
import logging
|
4 |
-
|
5 |
-
from pydantic import BaseModel, Field
|
6 |
-
from langchain_core.language_models.llms import LLM
|
7 |
from langchain_core.documents import Document
|
8 |
-
from langchain_core.tools import
|
9 |
-
|
10 |
-
from ask_candid.retrieval.
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
Parameters
|
32 |
----------
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
Returns
|
41 |
-------
|
42 |
-
|
43 |
-
|
44 |
"""
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
logger.warning("User callback was passed in but failed: %s", ex)
|
51 |
-
|
52 |
-
output = ["Search didn't return any Candid sources"]
|
53 |
-
page_content = []
|
54 |
-
content = "Search didn't return any Candid sources"
|
55 |
-
results = get_query_results(search_text=user_input, indices=indices)
|
56 |
-
if results:
|
57 |
-
output = get_reranked_results(results, search_text=user_input)
|
58 |
-
for doc in output:
|
59 |
-
page_content.append(doc.page_content)
|
60 |
-
content = "\n\n".join(page_content)
|
61 |
-
|
62 |
-
# for the tool we need to return a tuple for content_and_artifact type
|
63 |
-
return content, output
|
64 |
-
|
65 |
-
|
66 |
-
def retriever_tool(
|
67 |
-
indices: List[DataIndices],
|
68 |
-
user_callback: Optional[Callable[[str], Any]] = None
|
69 |
-
) -> Tool:
|
70 |
-
"""Tool component for use in conditional edge building for RAG execution graph.
|
71 |
-
Cannot use `create_retriever_tool` because it only provides content losing all metadata on the way
|
72 |
-
https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution
|
73 |
-
|
74 |
-
Parameters
|
75 |
-
----------
|
76 |
-
indices : List[DataIndices]
|
77 |
-
Semantic index names to search over
|
78 |
-
user_callback : Optional[Callable[[str], Any]], optional
|
79 |
-
Optional UI callback to inform the user of apps states, by default None
|
80 |
-
|
81 |
-
Returns
|
82 |
-
-------
|
83 |
-
Tool
|
84 |
-
"""
|
85 |
-
|
86 |
-
return Tool(
|
87 |
-
name="retrieve_social_sector_information",
|
88 |
-
func=partial(get_search_results, indices=indices, user_callback=user_callback),
|
89 |
-
description=(
|
90 |
-
"Return additional information about social and philanthropic sector, "
|
91 |
-
"including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
|
92 |
-
),
|
93 |
-
args_schema=RetrieverInput,
|
94 |
-
response_format="content_and_artifact"
|
95 |
)
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
state : _type_
|
105 |
-
The current state
|
106 |
-
llm : LLM
|
107 |
-
tools : List[Tool]
|
108 |
-
|
109 |
-
Returns
|
110 |
-
-------
|
111 |
-
AgentState
|
112 |
-
The updated state with the agent response appended to messages
|
113 |
-
"""
|
114 |
-
|
115 |
-
logger.info("---SEARCH AGENT---")
|
116 |
-
messages = state["messages"]
|
117 |
-
question = messages[-1].content
|
118 |
-
|
119 |
-
model = llm.bind_tools(tools)
|
120 |
-
response = model.invoke(messages)
|
121 |
-
# return a list, because this will get added to the existing list
|
122 |
-
return {"messages": [response], "user_input": question}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from langchain_core.documents import Document
|
2 |
+
from langchain_core.tools import tool
|
3 |
+
|
4 |
+
from ask_candid.base.retrieval.knowledge_base import (
|
5 |
+
SourceNames,
|
6 |
+
generate_queries,
|
7 |
+
run_search,
|
8 |
+
reranker,
|
9 |
+
process_hit
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
@tool(response_format="content_and_artifact")
|
14 |
+
def search_candid_knowledge_base(
|
15 |
+
query: str,
|
16 |
+
sources: list[SourceNames],
|
17 |
+
news_days_ago: int = 60
|
18 |
+
) -> tuple[str, list[Document]]:
|
19 |
+
"""Search Candid's subject matter expert knowledge base to find answers about the social and philanthropic sector.
|
20 |
+
This knowledge includes help articles and video training sessions from Candid's subject matter experts, blog posts
|
21 |
+
about the sector from Candid staff and trusted partner authors, research documents about the sector and news
|
22 |
+
articles curated about activity happening in the sector around the world.
|
23 |
+
|
24 |
+
Searches are performed through a combination of vector and keyword searching. Results are then re-ranked against
|
25 |
+
the original query to get the best results.
|
26 |
+
|
27 |
+
Search results often come back with specific organizations named, especially if referencing the news. In these cases
|
28 |
+
the organizations should be identified in Candid's data and links to their profiles **MUST** be included in final
|
29 |
+
chat response to the user.
|
30 |
|
31 |
Parameters
|
32 |
----------
|
33 |
+
query : str
|
34 |
+
Text describing a user's question or a description of investigative work which requires support from Candid's
|
35 |
+
knowledge base
|
36 |
+
sources : list[SourceNames]
|
37 |
+
One or more sources of knowledge from different areas at Candid.
|
38 |
+
* Candid Blog: Blog posts from Candid staff and trusted partners intended to help those in the sector or
|
39 |
+
illuminate ongoing work
|
40 |
+
* Candid Help: Candid FAQs to help user's get started with Candid's product platform and learning resources
|
41 |
+
* Candid Learning: Training documents from Candid's subject matter experts
|
42 |
+
* Candid News: News articles and press releases about real-time activity in the philanthropic sector
|
43 |
+
* IssueLab Research Reports: Academic research reports about the social/philanthropic sector
|
44 |
+
* YouTube Training: Transcripts from video-based training seminars from Candid's subject matter experts
|
45 |
+
news_days_ago : int, optional
|
46 |
+
How many days in the past to search for news articles, if a user is asking for recent trends then this value
|
47 |
+
should be set lower >~ 10, by default 60
|
48 |
|
49 |
Returns
|
50 |
-------
|
51 |
+
str
|
52 |
+
Re-ranked document text
|
53 |
"""
|
54 |
|
55 |
+
vector_queries, quasi_vector_queries = generate_queries(
|
56 |
+
query=query,
|
57 |
+
sources=sources,
|
58 |
+
news_days_ago=news_days_ago
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
)
|
60 |
|
61 |
+
results = run_search(vector_searches=vector_queries, non_vector_searches=quasi_vector_queries)
|
62 |
+
text_response = []
|
63 |
+
response_sources = []
|
64 |
+
for hit in map(process_hit, reranker(results, search_text=query)):
|
65 |
+
text_response.append(hit.page_content)
|
66 |
+
response_sources.append(hit)
|
67 |
+
return '\n\n'.join(text_response), response_sources
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ask_candid/tools/utils.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def format_candid_profile_link(candid_entity_id: int | str) -> str:
|
2 |
+
"""Format the Candid Search organization profile link.
|
3 |
+
|
4 |
+
Parameters
|
5 |
+
----------
|
6 |
+
candid_entity_id : int | str
|
7 |
+
|
8 |
+
Returns
|
9 |
+
-------
|
10 |
+
str
|
11 |
+
URL
|
12 |
+
"""
|
13 |
+
|
14 |
+
return f"https://app.candid.org/profile/{candid_entity_id}"
|
chat_v2.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypedDict, Any
|
2 |
+
from collections.abc import Iterator, AsyncIterator
|
3 |
+
import os
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
from langgraph.graph.state import CompiledStateGraph
|
8 |
+
from langgraph.prebuilt import create_react_agent
|
9 |
+
from langchain_aws import ChatBedrock
|
10 |
+
import boto3
|
11 |
+
|
12 |
+
from ask_candid.tools.org_search import OrganizationIdentifier, find_mentioned_organizations
|
13 |
+
from ask_candid.tools.search import search_candid_knowledge_base
|
14 |
+
from ask_candid.tools.general import get_current_day
|
15 |
+
from ask_candid.utils import html_format_docs_chat
|
16 |
+
from ask_candid.base.config.constants import START_SYSTEM_PROMPT
|
17 |
+
from ask_candid.base.config.models import Name2Endpoint
|
18 |
+
from ask_candid.chat import convert_history_for_graph_agent, format_tool_call, format_tool_response
|
19 |
+
|
20 |
+
try:
|
21 |
+
from feedback import FeedbackApi
|
22 |
+
ROOT = "."
|
23 |
+
except ImportError:
|
24 |
+
from demos.feedback import FeedbackApi
|
25 |
+
ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")
|
26 |
+
|
27 |
+
BOT_LOGO = os.path.join(ROOT, "static", "candid_logo_yellow.png")
|
28 |
+
if not os.path.isfile(BOT_LOGO):
|
29 |
+
BOT_LOGO = os.path.join(ROOT, "..", "..", "static", "candid_logo_yellow.png")
|
30 |
+
|
31 |
+
|
32 |
+
class LoggedComponents(TypedDict):
|
33 |
+
context: list[gr.Component]
|
34 |
+
found_helpful: gr.Component
|
35 |
+
will_recommend: gr.Component
|
36 |
+
comments: gr.Component
|
37 |
+
email: gr.Component
|
38 |
+
|
39 |
+
|
40 |
+
def build_execution_graph() -> CompiledStateGraph:
|
41 |
+
llm = ChatBedrock(
|
42 |
+
client=boto3.client("bedrock-runtime", region_name="us-east-1"),
|
43 |
+
model=Name2Endpoint["claude-3.5-haiku"]
|
44 |
+
)
|
45 |
+
org_name_recognition = OrganizationIdentifier(llm=llm) # bind the main chat model to the tool
|
46 |
+
return create_react_agent(
|
47 |
+
model=llm,
|
48 |
+
tools=[
|
49 |
+
get_current_day,
|
50 |
+
org_name_recognition,
|
51 |
+
find_mentioned_organizations,
|
52 |
+
search_candid_knowledge_base
|
53 |
+
],
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
def generate_postscript_messages(history: list[gr.ChatMessage]) -> Iterator[gr.ChatMessage]:
|
58 |
+
for record in history:
|
59 |
+
title = record.metadata.get("tool_name")
|
60 |
+
if title == search_candid_knowledge_base.name:
|
61 |
+
yield gr.ChatMessage(
|
62 |
+
role="assistant",
|
63 |
+
content=html_format_docs_chat(record.metadata.get("documents")),
|
64 |
+
metadata={
|
65 |
+
"title": "Source citations",
|
66 |
+
}
|
67 |
+
)
|
68 |
+
elif title == find_mentioned_organizations.name:
|
69 |
+
pass
|
70 |
+
|
71 |
+
|
72 |
+
async def execute(
|
73 |
+
user_input: dict[str, Any],
|
74 |
+
history: list[gr.ChatMessage]
|
75 |
+
) -> AsyncIterator[tuple[gr.Component, list[gr.ChatMessage]]]:
|
76 |
+
if len(history) == 0:
|
77 |
+
history.append(gr.ChatMessage(role="system", content=START_SYSTEM_PROMPT))
|
78 |
+
|
79 |
+
history.append(gr.ChatMessage(role="user", content=user_input["text"]))
|
80 |
+
for fname in user_input.get("files") or []:
|
81 |
+
fname: str
|
82 |
+
if fname.endswith('.txt'):
|
83 |
+
with open(fname, 'r', encoding='utf8') as f:
|
84 |
+
history.append(gr.ChatMessage(role="user", content=f.read()))
|
85 |
+
yield gr.MultimodalTextbox(value=None, interactive=True), history
|
86 |
+
|
87 |
+
horizon = len(history)
|
88 |
+
inputs = {"messages": convert_history_for_graph_agent(history)}
|
89 |
+
|
90 |
+
graph = build_execution_graph()
|
91 |
+
|
92 |
+
history.append(gr.ChatMessage(role="assistant", content=""))
|
93 |
+
async for stream_mode, chunk in graph.astream(inputs, stream_mode=["messages", "tasks"]):
|
94 |
+
if stream_mode == "messages" and chunk[0].content:
|
95 |
+
for msg in chunk[0].content:
|
96 |
+
if 'text' in msg:
|
97 |
+
history[-1].content += msg["text"]
|
98 |
+
yield gr.MultimodalTextbox(value=None, interactive=True), history
|
99 |
+
|
100 |
+
elif stream_mode == "tasks" and chunk.get("name") == "tools" and chunk.get("error") is None:
|
101 |
+
if "input" in chunk:
|
102 |
+
for msg in format_tool_call(chunk):
|
103 |
+
history.append(msg)
|
104 |
+
yield gr.MultimodalTextbox(value=None, interactive=True), history
|
105 |
+
elif "result" in chunk:
|
106 |
+
for msg in format_tool_response(chunk):
|
107 |
+
history.append(msg)
|
108 |
+
yield gr.MultimodalTextbox(value=None, interactive=True), history
|
109 |
+
history.append(gr.ChatMessage(role="assistant", content=""))
|
110 |
+
|
111 |
+
for post_msg in generate_postscript_messages(history=history[horizon:]):
|
112 |
+
history.append(post_msg)
|
113 |
+
yield gr.MultimodalTextbox(value=None, interactive=True), history
|
114 |
+
|
115 |
+
|
116 |
+
def send_feedback(
|
117 |
+
chat_context,
|
118 |
+
found_helpful,
|
119 |
+
will_recommend,
|
120 |
+
comments,
|
121 |
+
email
|
122 |
+
):
|
123 |
+
api = FeedbackApi()
|
124 |
+
total_submissions = 0
|
125 |
+
|
126 |
+
try:
|
127 |
+
response = api(
|
128 |
+
context=chat_context,
|
129 |
+
found_helpful=found_helpful,
|
130 |
+
will_recommend=will_recommend,
|
131 |
+
comments=comments,
|
132 |
+
email=email
|
133 |
+
)
|
134 |
+
total_submissions = response.get("response", 0)
|
135 |
+
gr.Info("Thank you for submitting feedback")
|
136 |
+
except Exception as ex:
|
137 |
+
raise gr.Error(f"Error submitting feedback: {ex}")
|
138 |
+
return total_submissions
|
139 |
+
|
140 |
+
|
141 |
+
def build_chat_app():
|
142 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="Chat") as demo:
|
143 |
+
|
144 |
+
gr.Markdown(
|
145 |
+
"""
|
146 |
+
<h1>Candid's AI assistant</h1>
|
147 |
+
|
148 |
+
<p>
|
149 |
+
Please read the <a
|
150 |
+
href='https://info.candid.org/chatbot-reference-guide'
|
151 |
+
target="_blank"
|
152 |
+
rel="noopener noreferrer"
|
153 |
+
>guide</a> to get started.
|
154 |
+
</p>
|
155 |
+
<hr>
|
156 |
+
"""
|
157 |
+
)
|
158 |
+
|
159 |
+
with gr.Column():
|
160 |
+
chatbot = gr.Chatbot(
|
161 |
+
label="AskCandid",
|
162 |
+
elem_id="chatbot",
|
163 |
+
editable="user",
|
164 |
+
avatar_images=(
|
165 |
+
None, # user
|
166 |
+
BOT_LOGO, # bot
|
167 |
+
),
|
168 |
+
height="60vh",
|
169 |
+
type="messages",
|
170 |
+
show_label=False,
|
171 |
+
show_copy_button=True,
|
172 |
+
autoscroll=True,
|
173 |
+
layout="panel",
|
174 |
+
)
|
175 |
+
msg = gr.MultimodalTextbox(label="Your message", interactive=True)
|
176 |
+
gr.ClearButton(components=[msg, chatbot], size="sm")
|
177 |
+
|
178 |
+
# pylint: disable=no-member
|
179 |
+
# chatbot.like(fn=like_callback, inputs=chatbot, outputs=None)
|
180 |
+
msg.submit(
|
181 |
+
fn=execute,
|
182 |
+
inputs=[msg, chatbot],
|
183 |
+
outputs=[msg, chatbot],
|
184 |
+
show_api=False
|
185 |
+
)
|
186 |
+
logged = LoggedComponents(context=chatbot)
|
187 |
+
|
188 |
+
return demo, logged
|
189 |
+
|
190 |
+
|
191 |
+
def build_feedback(components: LoggedComponents) -> gr.Blocks:
|
192 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="Candid AI demo") as demo:
|
193 |
+
gr.Markdown("<h1>Help us improve this tool with your valuable feedback</h1>")
|
194 |
+
|
195 |
+
with gr.Row():
|
196 |
+
with gr.Column():
|
197 |
+
found_helpful = gr.Radio(
|
198 |
+
[True, False], label="Did you find what you were looking for?"
|
199 |
+
)
|
200 |
+
will_recommend = gr.Radio(
|
201 |
+
[True, False],
|
202 |
+
label="Will you recommend this Chatbot to others?",
|
203 |
+
)
|
204 |
+
comment = gr.Textbox(label="Additional comments (optional)", lines=4)
|
205 |
+
email = gr.Textbox(label="Your email (optional)", lines=1)
|
206 |
+
submit = gr.Button("Submit Feedback")
|
207 |
+
|
208 |
+
components["found_helpful"] = found_helpful
|
209 |
+
components["will_recommend"] = will_recommend
|
210 |
+
components["comments"] = comment
|
211 |
+
components["email"] = email
|
212 |
+
|
213 |
+
# pylint: disable=no-member
|
214 |
+
submit.click(
|
215 |
+
fn=send_feedback,
|
216 |
+
inputs=[
|
217 |
+
components["context"],
|
218 |
+
components["found_helpful"],
|
219 |
+
components["will_recommend"],
|
220 |
+
components["comments"],
|
221 |
+
components["email"]
|
222 |
+
],
|
223 |
+
outputs=None,
|
224 |
+
show_api=False,
|
225 |
+
api_name=False,
|
226 |
+
preprocess=False,
|
227 |
+
)
|
228 |
+
|
229 |
+
return demo
|
230 |
+
|
231 |
+
|
232 |
+
def build_app():
|
233 |
+
candid_chat, logger = build_chat_app()
|
234 |
+
feedback = build_feedback(logger)
|
235 |
+
|
236 |
+
with open(os.path.join(ROOT, "static", "chatStyle.css"), "r", encoding="utf8") as f:
|
237 |
+
css_chat = f.read()
|
238 |
+
|
239 |
+
demo = gr.TabbedInterface(
|
240 |
+
interface_list=[
|
241 |
+
candid_chat,
|
242 |
+
feedback
|
243 |
+
],
|
244 |
+
tab_names=[
|
245 |
+
"Candid's AI assistant",
|
246 |
+
"Feedback"
|
247 |
+
],
|
248 |
+
title="Candid's AI assistant",
|
249 |
+
theme=gr.themes.Soft(),
|
250 |
+
css=css_chat,
|
251 |
+
)
|
252 |
+
return demo
|
253 |
+
|
254 |
+
|
255 |
+
if __name__ == "__main__":
|
256 |
+
app = build_app()
|
257 |
+
app.queue(max_size=5).launch(
|
258 |
+
show_api=False,
|
259 |
+
mcp_server=False,
|
260 |
+
auth=[
|
261 |
+
(os.getenv("APP_USERNAME"), os.getenv("APP_PASSWORD")),
|
262 |
+
(os.getenv("APP_PUBLIC_USERNAME"), os.getenv("APP_PUBLIC_PASSWORD")),
|
263 |
+
],
|
264 |
+
auth_message="Login to Candid's AI assistant",
|
265 |
+
)
|
requirements.txt
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
boto3
|
2 |
elasticsearch==7.17.6
|
3 |
thefuzz
|
4 |
-
gradio==5.
|
5 |
-
langchain
|
6 |
-
langchain-aws
|
7 |
-
|
8 |
-
langgraph
|
9 |
pydantic==2.10.6
|
10 |
pyopenssl>22.0.0
|
11 |
python-dotenv
|
|
|
1 |
boto3
|
2 |
elasticsearch==7.17.6
|
3 |
thefuzz
|
4 |
+
gradio==5.42.0
|
5 |
+
langchain==0.3.27
|
6 |
+
langchain-aws==0.2.30
|
7 |
+
langgraph==0.6.5
|
8 |
+
langgraph-prebuilt==0.6.4
|
9 |
pydantic==2.10.6
|
10 |
pyopenssl>22.0.0
|
11 |
python-dotenv
|