RCaz committed on
Commit
3bc88d4
·
verified ·
1 Parent(s): ff70d70

Create retreiver.py

Browse files
Files changed (1) hide show
  1. retreiver.py +63 -0
retreiver.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ from langchain.docstore.document import Document
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain_community.retrievers import BM25Retriever
5
+
6
# Load the English Wikipedia snapshot.
# Without an explicit split, load_dataset returns a DatasetDict; iterating a
# DatasetDict yields split names (strings) rather than rows, which would make
# doc["text"] below fail. Request the "train" split directly instead.
knowledge_base = datasets.load_dataset(
    "wikimedia/wikipedia", "20231101.en", split="train"
)

# Convert dataset entries to Document objects, keeping the title as metadata.
# NOTE(review): this materializes every row of the split in memory - confirm
# that is acceptable for a dataset of this size.
source_docs = [
    Document(page_content=doc["text"], metadata={"title": doc["title"]})
    for doc in knowledge_base
]

# Split documents into smaller chunks for better retrieval
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,  # Characters per chunk
    chunk_overlap=50,  # Overlap between chunks to maintain context
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],  # Priority order for splitting
)
docs_processed = text_splitter.split_documents(source_docs)

print(f"Knowledge base prepared with {len(docs_processed)} document chunks")
26
+
27
+ from smolagents import Tool
28
+
29
class RetrieverTool(Tool):
    """Tool that returns the document chunks a BM25 retriever ranks highest.

    The retriever is built once, at construction time, from the documents
    passed to ``__init__``; ``forward`` then formats the retrieved chunks
    into a single string.
    """

    name = "retriever"
    # NOTE(review): the description advertises "semantic search" while the
    # retriever below is a BM25Retriever - confirm the wording is intended.
    description = "Uses semantic search to retrieve wikipedia article that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        """Build the BM25 retriever over ``docs``.

        Args:
            docs: Documents to search over, passed to
                ``BM25Retriever.from_documents``.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        # Initialize the retriever with our processed documents
        self.retriever = BM25Retriever.from_documents(
            docs, k=10  # Return top 10 most relevant documents
        )

    def forward(self, query: str) -> str:
        """Execute the retrieval based on the provided query.

        Args:
            query: The search query.

        Returns:
            A string starting with a "Retrieved documents:" header followed
            by one "===== Document i =====" section per retrieved chunk,
            numbered from 0.

        Raises:
            TypeError: If ``query`` is not a string.
        """
        # Explicit check instead of `assert`, which is stripped when Python
        # runs with -O and would let non-string queries through.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        # Retrieve relevant documents
        docs = self.retriever.invoke(query)

        # Format the retrieved documents for readability
        sections = [
            f"\n\n===== Document {i} =====\n" + doc.page_content
            for i, doc in enumerate(docs)
        ]
        return "\nRetrieved documents:\n" + "".join(sections)
61
+
62
+ # Create the module-level tool instance from the chunked documents; RetrieverTool.__init__ builds its BM25Retriever from them
63
+ retriever_tool = RetrieverTool(docs_processed)