""" app.py An agent with access to a hybrid search tool and a large language model. The search tool has access to a collection of documents from the OECD related to international tax crimes. Agentic framework: - smolagents Retrieval model: - LanceDB: support for hybrid search search with reranking of results. - Full text search (lexical): BM25 - Vector search (semantic dense vectors): BAAI/bge-m3 Rerankers: - ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI Generation: - Mistral :author: Didier Guillevic :date: 2025-01-05 """ import gradio as gr import lancedb import smolagents import os import logging logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) # # LanceDB with the indexed documents # # Connect to the database lance_db = lancedb.connect("lance.db") lance_tbl = lance_db.open_table("documents") # Document schema class Document(lancedb.pydantic.LanceModel): text: str vector: lancedb.pydantic.Vector(1024) file_name: str num_pages: int creation_date: str modification_date: str # # Retrieval: query types and reranker types # query_types = { 'lexical': 'fts', 'semantic': 'vector', 'hybrid': 'hybrid', } # Define a few rerankers colbert_reranker = lancedb.rerankers.ColbertReranker(column='text') answerai_reranker = lancedb.rerankers.AnswerdotaiRerankers(column='text') crossencoder_reranker = lancedb.rerankers.CrossEncoderReranker(column='text') reciprocal_rank_fusion_reranker = lancedb.rerankers.RRFReranker() # hybrid search only reranker_types = { 'ColBERT': colbert_reranker, 'cross encoder': crossencoder_reranker, 'AnswerAI': answerai_reranker, 'Reciprocal Rank Fusion': reciprocal_rank_fusion_reranker } def search_table( table: lancedb.table, query: str, query_type: str='hybrid', reranker_name: str='cross encoder', filter_year: int=2000, top_k: int=5, overfetch_factor: int=2 ): # Get the instance of reranker reranker = reranker_types.get(reranker_name) if reranker is None: logger.error(f"Invalid reranker name: {reranker_name}") raise ValueError(f"Invalid reranker selected: {reranker_name}") if query_type in ["vector", "fts"]: if reranker == reciprocal_rank_fusion_reranker: # reciprocal is for 'hybrid' search type only reranker = crossencoder_reranker results = ( table.search(query, query_type=query_type) .where(f"creation_date >= '{filter_year}'", prefilter=True) .rerank(reranker=reranker) .limit(top_k * overfetch_factor) .to_pydantic(Document) ) elif query_type == "hybrid": results = ( table.search(query, query_type=query_type) .where(f"creation_date >= '{filter_year}'", prefilter=True) .rerank(reranker=reranker) .limit(top_k) .to_pydantic(Document) ) return results[:top_k] # # Define a retriever tool # class RetrieverTool(smolagents.Tool): name = "retriever" description = "Uses hybrid search to retrieve snippets from OECD documents that could be most relevant to answer your query." inputs = { "query": { "type": "string", "description": "The query to perform. This should be semantically close to your target documents. 
#
# Define a retriever tool
#
class RetrieverTool(smolagents.Tool):
    name = "retriever"
    description = (
        "Uses hybrid search to retrieve snippets from OECD documents "
        "that could be most relevant to answer your query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": (
                "The query to perform. This should be semantically close to your "
                "target documents. Use the affirmative form rather than a question."
            ),
        }
    }
    output_type = "string"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        results = search_table(table=lance_tbl, query=query)

        return "\nRetrieved documents:\n" + "".join(
            [
                f"\n\n===== Document {str(i)} =====\n" + result.text
                for i, result in enumerate(results)
            ]
        )


retriever_tool = RetrieverTool()

#
# Define a language model
#
mistral_api_key = os.environ["MISTRAL_API_KEY"]
mistral_model_id = "mistral/mistral-large-latest"  # 128k context window
#mistral_model_id = "mistral/codestral-latest"
mistral_model = smolagents.LiteLLMModel(
    model_id=mistral_model_id, api_key=mistral_api_key)

#
# Define an agent with access to tool(s) and language model.
#
agent = smolagents.CodeAgent(
    tools=[retriever_tool],
    model=mistral_model,
    max_iterations=4,
    verbose=True
)


#
# app
#
def generate_response(query: str) -> str:
    """Generate a response to the given query.

    Args:
        query: the question to answer.

    Returns:
        The response from the agent, which has access to a retriever tool over
        a collection of documents and a large language model.
    """
    agent_output = agent.run(query)
    return agent_output


#
# User interface
#
with gr.Blocks() as demo:
    gr.Markdown("""
        # Agentic Hybrid Search

        Document collection: OECD documents on international tax crimes.
    """)

    # Inputs: question
    question = gr.Textbox(
        label="Question to answer",
        placeholder=""
    )

    # Response / references / snippets
    response = gr.Textbox(
        label="Response",
        placeholder=""
    )

    # Buttons
    with gr.Row():
        response_button = gr.Button("Submit", variant='primary')
        clear_button = gr.Button("Clear", variant='secondary')

    # Example questions for the provided document collection
    with gr.Accordion("Sample questions", open=False):
        gr.Examples(
            [
                ["What is the OECD's role in combating offshore tax evasion?",],
                ["What are the key tools used in fighting offshore tax evasion?",],
                ['What are "High Net Worth Individuals" (HNWIs) and how do they relate to tax compliance efforts?',],
                ["What is the significance of international financial centers (IFCs) in the context of tax evasion?",],
                ["What is being done to address the role of professional enablers in facilitating tax evasion?",],
                ["How does the OECD measure the effectiveness of international efforts to fight offshore tax evasion?",],
                ['What are the "Ten Global Principles" for fighting tax crime?',],
                ["What are some recent developments in the fight against offshore tax evasion?",],
            ],
            inputs=[question,],
            outputs=[response,],
            fn=generate_response,
            cache_examples=False,
            label="Sample questions"
        )

    # Documentation
    with gr.Accordion("Documentation", open=False):
        gr.Markdown("""
            - Agentic framework
                - Hugging Face's smolagents
            - Retrieval model
                - LanceDB: support for hybrid search with reranking of results.
                - Full text search (lexical): BM25
                - Vector search (semantic dense vectors): BAAI/bge-m3
            - Rerankers
                - ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI
            - Generation
                - Mistral
            - Examples
                - Generated using Google NotebookLM
        """)

    # Click actions
    response_button.click(
        fn=generate_response,
        inputs=[question,],
        outputs=[response,]
    )
    clear_button.click(
        fn=lambda: ('', ''),
        inputs=[],
        outputs=[question, response]
    )

demo.launch(show_api=False)
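
# To run the app locally (assuming the MISTRAL_API_KEY environment variable is
# set and lance.db has been built with an indexed "documents" table):
#   python app.py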