import os

from tinytroupe.utils import JsonSerializableRegistry
import tinytroupe.utils as utils
from tinytroupe.agent import logger

from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.readers.web import SimpleWebPageReader  # requires the llama-index-readers-web package
#######################################################################################################################
# Grounding connectors
#######################################################################################################################
class GroundingConnector(JsonSerializableRegistry):
    """
    An abstract class representing a grounding connector. A grounding connector is a component that allows an agent
    to ground its knowledge in external sources, such as files, web pages, databases, etc.
    """

    serializable_attributes = ["name"]

    def __init__(self, name: str) -> None:
        self.name = name

    def retrieve_relevant(self, relevance_target: str, source: str, top_k=20) -> list:
        raise NotImplementedError("Subclasses must implement this method.")

    def retrieve_by_name(self, name: str) -> str:
        raise NotImplementedError("Subclasses must implement this method.")

    def list_sources(self) -> list:
        raise NotImplementedError("Subclasses must implement this method.")
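

# For illustration, a minimal concrete connector could satisfy this contract as
# follows (a hypothetical sketch, not part of the library):
#
#   class StaticGroundingConnector(GroundingConnector):
#       def __init__(self, name, facts):
#           super().__init__(name)
#           self.facts = facts  # a plain list of strings
#
#       def retrieve_relevant(self, relevance_target, source, top_k=20):
#           return [f for f in self.facts if relevance_target.lower() in f.lower()][:top_k]
#
#       def retrieve_by_name(self, name):
#           return "\n".join(self.facts) if name == self.name else ""
#
#       def list_sources(self):
#           return [self.name]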
@utils.post_init
class BaseSemanticGroundingConnector(GroundingConnector):
    """
    A base class for semantic grounding connectors. A semantic grounding connector is a component that indexes and
    retrieves documents based on so-called "semantic search" (i.e., embeddings-based search). This specific
    implementation is based on the VectorStoreIndex class from the LlamaIndex library. Here, "documents" refer to
    LlamaIndex's data structure that stores a unit of content, not necessarily a file.
    """

    serializable_attributes = ["documents"]

    def __init__(self, name: str = "Semantic Grounding") -> None:
        super().__init__(name)

        self.documents = None
        self.name_to_document = None

        # @post_init ensures that _post_init is called after the __init__ method

    def _post_init(self):
        """
        This will run after __init__, since the class has the @post_init decorator.
        It is convenient to separate some of the initialization processes to make deserialization easier.
        """
        self.index = None

        if not hasattr(self, 'documents') or self.documents is None:
            self.documents = []

        if not hasattr(self, 'name_to_document') or self.name_to_document is None:
            self.name_to_document = {}

        # re-index any previously deserialized documents; reset the list first so that
        # add_documents (which extends self.documents) does not duplicate its entries
        documents_to_index = self.documents
        self.documents = []
        self.add_documents(documents_to_index)
    def retrieve_relevant(self, relevance_target: str, top_k=20) -> list:
        """
        Retrieves the indexed content that is most relevant to a given target.
        """
        if self.index is not None:
            retriever = self.index.as_retriever(similarity_top_k=top_k)
            nodes = retriever.retrieve(relevance_target)
        else:
            nodes = []

        retrieved = []
        for node in nodes:
            content = "SOURCE: " + node.metadata.get('file_name', '(unknown)')
            content += "\n" + "SIMILARITY SCORE: " + str(node.score)
            content += "\n" + "RELEVANT CONTENT: " + node.text
            retrieved.append(content)

            logger.debug(f"Content retrieved: {content[:200]}")

        return retrieved
    def retrieve_by_name(self, name: str) -> list:
        """
        Retrieves the content of a source by its name, one entry per page/document.
        """
        # TODO also optionally provide a relevance target?
        results = []
        if self.name_to_document is not None and name in self.name_to_document:
            docs = self.name_to_document[name]
            for i, doc in enumerate(docs):
                if doc is not None:
                    content = f"SOURCE: {name}\n"
                    content += f"PAGE: {i}\n"
                    content += "CONTENT: \n" + doc.text[:10000]  # TODO a more intelligent way to limit the content
                    results.append(content)

        return results

    def list_sources(self) -> list:
        """
        Lists the names of the available content sources.
        """
        if self.name_to_document is not None:
            return list(self.name_to_document.keys())
        else:
            return []
    def add_document(self, document, doc_to_name_func=None) -> None:
        """
        Indexes a single document for semantic retrieval.
        """
        self.add_documents([document], doc_to_name_func)

    def add_documents(self, new_documents, doc_to_name_func=None) -> list:
        """
        Indexes documents for semantic retrieval.
        """
        if len(new_documents) > 0:
            # add the new documents to the list of known documents
            self.documents += new_documents

            # process documents individually too
            for document in new_documents:
                # out of an abundance of caution, we sanitize the text
                document.text = utils.sanitize_raw_string(document.text)

                # index documents by name, if a naming function is given
                if doc_to_name_func is not None:
                    name = doc_to_name_func(document)

                    # self.name_to_document[name] contains a list, since each source file could be split into multiple pages
                    if name in self.name_to_document:
                        self.name_to_document[name].append(document)
                    else:
                        self.name_to_document[name] = [document]

            # index documents for semantic retrieval
            if self.index is None:
                self.index = VectorStoreIndex.from_documents(self.documents)
            else:
                self.index.refresh(self.documents)
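

# Example usage (a sketch; assumes an embedding model is configured for LlamaIndex,
# e.g. via OpenAI credentials, and builds the documents in memory):
#
#   from llama_index.core import Document
#
#   connector = BaseSemanticGroundingConnector(name="My Notes")
#   connector.add_documents(
#       [Document(text="TinyTroupe simulates personas using LLMs.")],
#       doc_to_name_func=lambda doc: "notes.txt")
#   connector.list_sources()                             # -> ['notes.txt']
#   connector.retrieve_relevant("persona simulation", top_k=1)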
@utils.post_init
class LocalFilesGroundingConnector(BaseSemanticGroundingConnector):
    """
    A semantic grounding connector that indexes the files found in local folders.
    """

    serializable_attributes = ["folders_paths"]

    def __init__(self, name: str = "Local Files", folders_paths: list = None) -> None:
        super().__init__(name)

        self.folders_paths = folders_paths

        # @post_init ensures that _post_init is called after the __init__ method

    def _post_init(self):
        """
        This will run after __init__, since the class has the @post_init decorator.
        It is convenient to separate some of the initialization processes to make deserialization easier.
        """
        self.loaded_folders_paths = []

        if not hasattr(self, 'folders_paths') or self.folders_paths is None:
            self.folders_paths = []

        self.add_folders(self.folders_paths)
    def add_folders(self, folders_paths: list) -> None:
        """
        Adds paths to folders with files used for grounding.
        """
        if folders_paths is not None:
            for folder_path in folders_paths:
                try:
                    logger.debug(f"Adding the following folder to grounding index: {folder_path}")
                    self.add_folder(folder_path)
                except (FileNotFoundError, ValueError) as e:
                    print(f"Error: {e}")
                    print(f"Current working directory: {os.getcwd()}")
                    print(f"Provided path: {folder_path}")
                    print("Please check if the path exists and is accessible.")
    def add_folder(self, folder_path: str) -> None:
        """
        Adds a path to a folder with files used for grounding.
        """
        if folder_path not in self.loaded_folders_paths:
            self._mark_folder_as_loaded(folder_path)

            # for PDF files, please note that the document will be split into pages: https://github.com/run-llama/llama_index/issues/15903
            new_files = SimpleDirectoryReader(folder_path).load_data()

            self.add_documents(new_files, lambda doc: doc.metadata["file_name"])

    def add_file_path(self, file_path: str) -> None:
        """
        Adds a path to a file used for grounding.
        """
        # a trick to make SimpleDirectoryReader work with a single file
        new_files = SimpleDirectoryReader(input_files=[file_path]).load_data()

        logger.debug(f"Adding the following file to grounding index: {new_files}")
        self.add_documents(new_files, lambda doc: doc.metadata["file_name"])

    def _mark_folder_as_loaded(self, folder_path: str) -> None:
        if folder_path not in self.loaded_folders_paths:
            self.loaded_folders_paths.append(folder_path)

        if folder_path not in self.folders_paths:
            self.folders_paths.append(folder_path)
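

# Example usage (a sketch; "./documents" is a hypothetical folder containing the
# files to be indexed):
#
#   connector = LocalFilesGroundingConnector(folders_paths=["./documents"])
#   connector.list_sources()                      # file names found in the folder
#   connector.retrieve_relevant("quarterly results", top_k=3)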
@utils.post_init
class WebPagesGroundingConnector(BaseSemanticGroundingConnector):
    """
    A semantic grounding connector that indexes the content of web pages.
    """

    serializable_attributes = ["web_urls"]

    def __init__(self, name: str = "Web Pages", web_urls: list = None) -> None:
        super().__init__(name)

        self.web_urls = web_urls

        # @post_init ensures that _post_init is called after the __init__ method

    def _post_init(self):
        """
        This will run after __init__, since the class has the @post_init decorator.
        It is convenient to separate some of the initialization processes to make deserialization easier.
        """
        self.loaded_web_urls = []

        if not hasattr(self, 'web_urls') or self.web_urls is None:
            self.web_urls = []

        # load web urls
        self.add_web_urls(self.web_urls)
    def add_web_urls(self, web_urls: list) -> None:
        """
        Adds the data retrieved from the specified URLs to grounding.
        """
        filtered_web_urls = [url for url in web_urls if url not in self.loaded_web_urls]
        for url in filtered_web_urls:
            self._mark_web_url_as_loaded(url)

        if len(filtered_web_urls) > 0:
            new_documents = SimpleWebPageReader(html_to_text=True).load_data(filtered_web_urls)
            self.add_documents(new_documents, lambda doc: doc.id_)

    def add_web_url(self, web_url: str) -> None:
        """
        Adds the data retrieved from the specified URL to grounding.
        """
        # implemented in terms of add_web_urls because that method could eventually
        # scrape multiple pages in parallel
        self.add_web_urls([web_url])

    def _mark_web_url_as_loaded(self, web_url: str) -> None:
        if web_url not in self.loaded_web_urls:
            self.loaded_web_urls.append(web_url)

        if web_url not in self.web_urls:
            self.web_urls.append(web_url)
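

# Example usage (a sketch; requires network access and the llama-index-readers-web
# package; the URL is merely illustrative):
#
#   connector = WebPagesGroundingConnector(web_urls=["https://example.com"])
#   connector.list_sources()                  # document ids of the fetched pages
#   connector.retrieve_relevant("example domain", top_k=1)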