from collections import defaultdict
from typing import List, Optional, Union

import dspy

try:
    import marqo
except ImportError as err:
    # Chain the original error so the underlying import failure stays visible.
    raise ImportError(
        "The 'marqo' extra is required to use MarqoRM. Install it with `pip install dspy-ai[marqo]`"
    ) from err


class MarqoRM(dspy.Retrieve):
    """
    A retrieval module that uses Marqo to return the top passages for a given query.

    Assumes that a Marqo index has been created and populated with the following payload:
        - document: The text of the passage

    Args:
        marqo_index_name (str): The name of the marqo index.
        marqo_client (marqo.client.Client): A marqo client instance.
        k (int, optional): The number of top passages to retrieve. Defaults to 3.

    Examples:
        Below is a code snippet that shows how to use Marqo as the default retriever:
        ```python
        import marqo
        marqo_client = marqo.Client(url="http://0.0.0.0:8882")

        llm = dspy.OpenAI(model="gpt-3.5-turbo")
        retriever_model = MarqoRM("my_index_name", marqo_client=marqo_client)
        dspy.settings.configure(lm=llm, rm=retriever_model)
        ```

        Below is a code snippet that shows how to use Marqo in the forward() function of a module
        ```python
        self.retrieve = MarqoRM("my_index_name", marqo_client=marqo_client, k=num_passages)
        ```
    """

    def __init__(
        self,
        marqo_index_name: str,
        marqo_client: marqo.client.Client,
        k: int = 3,
    ):
        self._marqo_index_name = marqo_index_name
        self._marqo_client = marqo_client

        super().__init__(k=k)

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        k: Optional[int] = None,
    ) -> dspy.Prediction:
        """Search with Marqo for the top passages for one or more queries.

        Each non-empty query is searched separately. When the same passage
        text is returned for several queries, its scores are summed, and the
        passages with the highest combined scores are returned.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            k (Optional[int]): The number of passages to retrieve for this call.
                Defaults to None, in which case ``self.k`` is used.

        Returns:
            dspy.Prediction: An object containing the retrieved passages.
        """
        limit = self.k if k is None else k

        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        # Empty / falsy queries are skipped rather than sent to Marqo.
        queries = [q for q in queries if q]

        index = self._marqo_client.index(self._marqo_index_name)
        all_query_results = [index.search(q=query, limit=limit) for query in queries]

        # Accumulate scores per passage text; every hit is read through its
        # 'document' and '_score' keys.
        passages = defaultdict(float)
        for result_dict in all_query_results:
            for result in result_dict['hits']:
                passages[result['document']] += result['_score']

        sorted_passages = sorted(
            passages.items(), key=lambda x: x[1], reverse=True
        )[:limit]
        return dspy.Prediction(passages=[passage for passage, _ in sorted_passages])