import logging

from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.retrievers.multi_query import MultiQueryRetriever

# Set logging for the queries: initialize() raises the multi-query logger to
# INFO so the generated query variants are visible.
logging.basicConfig()


class MultiQueryDocumentRetriever:
    """Retrieve documents via LangChain's MultiQueryRetriever over a vector store.

    The LLM used to generate query variants is a local HuggingFace
    ``bigscience/bloomz-1b7`` text-generation pipeline. It is loaded by
    :meth:`initialize`, not by ``__init__``, so construction is cheap.
    """

    def __init__(self, vector_store):
        """Store the vector store; nothing is loaded until initialize().

        Args:
            vector_store: object exposing a ``db`` attribute whose
                ``as_retriever(search_kwargs=...)`` returns a LangChain
                retriever (presumably a LangChain vector store - confirm
                against callers).
        """
        self.vector_store = vector_store
        self.retriever = None  # set by initialize()
        self.llm = None  # set by initialize()
        # Any API credentials must be read from the environment / a secret
        # store, never hard-coded in source.

    def initialize(self):
        """Load the LLM and build the multi-query retriever.

        Loads ``bigscience/bloomz-1b7`` through HuggingFacePipeline, enables
        INFO logging on ``langchain.retrievers.multi_query``, and wraps the
        vector store's retriever (``k=4``, ``fetch_k=40``) in a
        MultiQueryRetriever driven by that LLM.
        """
        self.llm = HuggingFacePipeline.from_model_id(
            model_id="bigscience/bloomz-1b7",
            task="text-generation",
            # NOTE(review): from_model_id forwards model_kwargs to
            # from_pretrained(), not to generation; verify these sampling
            # settings take effect, or move them into pipeline_kwargs.
            model_kwargs={"do_sample": True, "temperature": 0.7, "num_beams": 2},
            pipeline_kwargs={"max_new_tokens": 256, "min_new_tokens": 30},
        )
        logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
        self.retriever = MultiQueryRetriever.from_llm(
            # NOTE(review): fetch_k is only honoured by some search types
            # (e.g. MMR) - confirm it is meaningful for this store.
            retriever=self.vector_store.db.as_retriever(
                search_kwargs={"k": 4, "fetch_k": 40}
            ),
            llm=self.llm,
        )

    def retrieve(self, query: str, k: int = 4):
        """Return up to ``k`` documents relevant to ``query``.

        Args:
            query: natural-language question. MultiQueryRetriever asks the
                LLM for several rephrasings and returns the de-duplicated
                union of their hits.
            k: maximum number of documents to return; must be >= 1.

        Returns:
            A list of at most ``k`` documents.

        Raises:
            RuntimeError: if initialize() has not been called yet.
            ValueError: if ``k`` is less than 1.
        """
        if self.retriever is None:
            raise RuntimeError("initialize() must be called before retrieve()")
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        docs = self.retriever.get_relevant_documents(query)
        # The union over all query variants can exceed k (each variant
        # fetches its own k hits), so trim to the caller's limit.
        return docs[:k]