"""Chain for generating question/answer pairs from an input text."""
from __future__ import annotations

import json
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
class QAGenerationChain(Chain):
    """Base class for question-answer generation chains.

    Splits the input text into chunks with ``text_splitter`` and asks the
    LLM to produce a JSON-encoded question/answer pair for each chunk.
    """

    llm_chain: LLMChain
    """LLM Chain that generates responses from user input and context."""
    text_splitter: TextSplitter = Field(
        default=RecursiveCharacterTextSplitter(chunk_overlap=500)
    )
    """Text splitter that splits the input into chunks."""
    input_key: str = "text"
    """Key of the input to the chain."""
    output_key: str = "questions"
    """Key of the output of the chain."""
    k: Optional[int] = None
    """Number of questions to generate."""

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> QAGenerationChain:
        """
        Create a QAGenerationChain from a language model.

        Args:
            llm: a language model
            prompt: a prompt template; if omitted, one is selected for
                ``llm`` via ``PROMPT_SELECTOR``
            **kwargs: additional arguments forwarded to the chain constructor

        Returns:
            a QAGenerationChain instance
        """
        # BUG FIX: was missing @classmethod — calling
        # QAGenerationChain.from_llm(llm) bound the model to `cls`.
        _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
        chain = LLMChain(llm=llm, prompt=_prompt)
        return cls(llm_chain=chain, **kwargs)

    @property
    def _chain_type(self) -> str:
        # Serialization is not supported for this chain.
        raise NotImplementedError

    @property
    def input_keys(self) -> List[str]:
        # BUG FIX: Chain declares input_keys as an abstract property;
        # without @property this shadowed it with a plain method.
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        # BUG FIX: same as input_keys — must be a property to satisfy Chain.
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, List]:
        """Split the input text and generate a QA pair per chunk.

        Args:
            inputs: mapping containing the text under ``self.input_key``
            run_manager: optional callback manager forwarded to the LLM chain

        Returns:
            mapping from ``self.output_key`` to a list of parsed QA dicts

        Raises:
            json.JSONDecodeError: if an LLM generation is not valid JSON.
        """
        docs = self.text_splitter.create_documents([inputs[self.input_key]])
        results = self.llm_chain.generate(
            [{"text": d.page_content} for d in docs], run_manager=run_manager
        )
        # Each generation is expected to be a single JSON-encoded QA pair;
        # only the first candidate per chunk is used.
        qa = [json.loads(res[0].text) for res in results.generations]
        return {self.output_key: qa}