import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever

# Load the English Wikipedia dump. split="train" is required so that iterating
# over knowledge_base yields article records; without it, load_dataset returns
# a DatasetDict and iteration yields split names instead.
knowledge_base = datasets.load_dataset("wikimedia/wikipedia", "20231101.en", split="train")
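
# Optional (an assumption, not part of the original setup): the full
# 20231101.en dump contains millions of articles, so while prototyping you
# may want to index only a small slice. The 5,000-article figure below is
# an arbitrary choice; drop this line to index all of English Wikipedia.
knowledge_base = knowledge_base.select(range(5000))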

# Wrap each article in a LangChain Document, keeping its title as metadata
source_docs = [
    Document(page_content=doc["text"], metadata={"title": doc["title"]})
    for doc in knowledge_base
]

# Split articles into 500-character chunks with 50 characters of overlap;
# the separator hierarchy prefers paragraph, then line, then sentence breaks
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)
docs_processed = text_splitter.split_documents(source_docs)

print(f"Knowledge base prepared with {len(docs_processed)} document chunks")

from smolagents import Tool


class RetrieverTool(Tool):
    name = "retriever"
    description = "Faster than WikipediaSearchTool, it uses BM25 lexical search to retrieve the Wikipedia article chunks most relevant to your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. It should be lexically close to your target documents, since BM25 matches on shared terms. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        super().__init__(**kwargs)
        # BM25 is a classical lexical ranker: no embedding model is needed,
        # which is what makes this tool fast. k=10 returns the top 10 chunks.
        self.retriever = BM25Retriever.from_documents(docs, k=10)

    def forward(self, query: str) -> str:
        """Execute the retrieval based on the provided query."""
        assert isinstance(query, str), "Your search query must be a string"

        docs = self.retriever.invoke(query)
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n" + doc.page_content
            for i, doc in enumerate(docs)
        )


retriever_tool = RetrieverTool(docs_processed)
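
# Smoke test (illustrative; the query below is an arbitrary example):
# smolagents Tool instances are callable, so the tool can be exercised
# directly before handing it to an agent.
print(retriever_tool(query="Eiffel Tower construction history"))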