from collections import defaultdict
from typing import List, Optional, Union

import dspy
from dsp.utils import dotdict

try:
    import weaviate
except ImportError as err:
    raise ImportError(
        "The 'weaviate' extra is required to use WeaviateRM. Install it with `pip install dspy-ai[weaviate]`"
    ) from err


class WeaviateRM(dspy.Retrieve):
    """A retrieval module that uses Weaviate to return the top passages for a given query.

    Assumes that a Weaviate collection has been created and populated with objects whose
    passage text is stored under the property named by ``weaviate_collection_text_key``
    (``"content"`` by default). Retrieval uses Weaviate hybrid search.

    Args:
        weaviate_collection_name (str): The name of the Weaviate collection.
        weaviate_client (weaviate.Client): An instance of the Weaviate client.
        k (int, optional): The default number of top passages to retrieve. Defaults to 3.
        weaviate_collection_text_key (str, optional): The collection property that holds
            the passage text. Defaults to "content".

    Examples:
        Below is a code snippet that shows how to use Weaviate as the default retriever:
        ```python
        import weaviate

        llm = dspy.OpenAI(model="gpt-3.5-turbo")
        weaviate_client = weaviate.Client("your-path-here")
        retriever_model = WeaviateRM(weaviate_collection_name="my_collection_name",
                                     weaviate_collection_text_key="content",
                                     weaviate_client=weaviate_client)
        dspy.settings.configure(lm=llm, rm=retriever_model)
        ```

        Below is a code snippet that shows how to use Weaviate in the forward() function of a module
        ```python
        self.retrieve = WeaviateRM("my_collection_name", weaviate_client=weaviate_client, k=num_passages)
        ```
    """

    def __init__(
        self,
        weaviate_collection_name: str,
        weaviate_client: weaviate.Client,
        k: int = 3,
        weaviate_collection_text_key: str = "content",
    ):
        self._weaviate_collection_name = weaviate_collection_name
        self._weaviate_client = weaviate_client
        self._weaviate_collection_text_key = weaviate_collection_text_key
        super().__init__(k=k)

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        k: Optional[int] = None,
    ) -> List[dotdict]:
        """Search with Weaviate for the top ``k`` passages for each query.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
                Empty queries are skipped.
            k (Optional[int]): The number of top passages to retrieve per query.
                Defaults to self.k.

        Returns:
            List[dotdict]: One ``dotdict({"long_text": ...})`` per retrieved passage,
            grouped by query in input order, then in the order Weaviate returned them.
            With several queries the list may hold up to ``k`` passages per query.

        Raises:
            RuntimeError: If Weaviate returns no result list for the collection
                (e.g. the response carries GraphQL errors).
        """
        k = k if k is not None else self.k
        queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
        queries = [q for q in queries if q]

        passages = []
        for query in queries:
            passages.extend(dotdict({"long_text": text}) for text in self._hybrid_search(query, k))
        return passages

    def _hybrid_search(self, query: str, k: int) -> List[str]:
        """Run one hybrid query limited to ``k`` hits and return each hit's text property."""
        response = (
            self._weaviate_client.query
            .get(self._weaviate_collection_name, [self._weaviate_collection_text_key])
            .with_hybrid(query=query)
            .with_limit(k)
            .do()
        )
        # Weaviate's GraphQL endpoint can report failures in an "errors" field instead of
        # raising; the per-collection result is then missing/None. Surface the reported
        # errors rather than failing with an opaque KeyError/TypeError.
        hits = ((response.get("data") or {}).get("Get") or {}).get(self._weaviate_collection_name)
        if hits is None:
            raise RuntimeError(
                f"Weaviate query on collection {self._weaviate_collection_name!r} failed: "
                f"{response.get('errors')}"
            )
        return [hit[self._weaviate_collection_text_key] for hit in hits]