from collections import defaultdict from typing import List, Union import dspy from typing import Optional import json import os import requests from dsp.utils import dotdict START_SNIPPET = "<%START%>" END_SNIPPET = "<%END%>" def remove_snippet(s: str) -> str: return s.replace(START_SNIPPET, "").replace(END_SNIPPET, "") class VectaraRM(dspy.Retrieve): """ A retrieval module that uses Vectara to return the top passages for a given query. Assumes that a Vectara corpus has been created and populated with the following payload: - document: The text of the passage Args: vectara_customer_id (str): Vectara Customer ID. defaults to VECTARA_CUSTOMER_ID environment variable vectara_corpus_id (str): Vectara Corpus ID. defaults to VECTARA_CORPUS_ID environment variable vectara_api_key (str): Vectara API Key. defaults to VECTARA_API_KEY environment variable k (int, optional): The default number of top passages to retrieve. Defaults to 3. Examples: Below is a code snippet that shows how to use Vectara as the default retriver: ```python from vectara_client import vectaraClient llm = dspy.OpenAI(model="gpt-3.5-turbo") retriever_model = vectaraRM("", "", "") dspy.settings.configure(lm=llm, rm=retriever_model) ``` Below is a code snippet that shows how to use Vectara in the forward() function of a module ```python self.retrieve = vectaraRM("", "", "", k=num_passages) ``` """ def __init__( self, vectara_customer_id: Optional[str] = None, vectara_corpus_id: Optional[str] = None, vectara_api_key: Optional[str] = None, k: int = 5, ): if vectara_customer_id is None: vectara_customer_id = os.environ.get("VECTARA_CUSTOMER_ID", "") if vectara_corpus_id is None: vectara_corpus_id = os.environ.get("VECTARA_CORPUS_ID", "") if vectara_api_key is None: vectara_api_key = os.environ.get("VECTARA_API_KEY", "") self._vectara_customer_id = vectara_customer_id self._vectara_corpus_id = vectara_corpus_id self._vectara_api_key = vectara_api_key self._n_sentences_before = self._n_sentences_after = 2 self._vectara_timeout = 120 super().__init__(k=k) def _vectara_query( self, query: str, limit: int = 3, ) -> List[str]: """Query Vectara index to get for top k matching passages. Args: query: query string """ corpus_key = { "customerId": self._vectara_customer_id, "corpusId": self._vectara_corpus_id, "lexicalInterpolationConfig": {"lambda": 0.025 } } data = { "query": [ { "query": query, "start": 0, "numResults": limit, "contextConfig": { "sentencesBefore": self._n_sentences_before, "sentencesAfter": self._n_sentences_after, "startTag": START_SNIPPET, "endTag": END_SNIPPET, }, "corpusKey": [corpus_key], } ] } headers = { "x-api-key": self._vectara_api_key, "customer-id": self._vectara_customer_id, "Content-Type": "application/json", "X-Source": "dspy", } response = requests.post( headers=headers, url="https://api.vectara.io/v1/query", data=json.dumps(data), timeout=self._vectara_timeout, ) if response.status_code != 200: print( "Query failed %s", f"(code {response.status_code}, reason {response.reason}, details " f"{response.text})", ) return [] result = response.json() responses = result["responseSet"][0]["response"] res = [ { "text": remove_snippet(x["text"]), "score": x["score"] } for x in responses ] return res def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> dspy.Prediction: """Search with Vectara for self.k top passages for query Args: query_or_queries (Union[str, List[str]]): The query or queries to search for. k (Optional[int]): The number of top passages to retrieve. Defaults to self.k. Returns: dspy.Prediction: An object containing the retrieved passages. """ queries = ( [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries ) queries = [q for q in queries if q] # Filter empty queries k = k if k is not None else self.k all_res = [] limit = 3*k if len(queries) > 1 else k for query in queries: _res = self._vectara_query(query, limit=limit) all_res.append(_res) passages = defaultdict(float) for res_list in all_res: for res in res_list: passages[res["text"]] += res["score"] sorted_passages = sorted( passages.items(), key=lambda x: x[1], reverse=True)[:k] return [dotdict({"long_text": passage}) for passage, _ in sorted_passages]