import os

from dotenv import load_dotenv
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages.base import BaseMessage

from basic_chain import basic_chain, get_model
from remote_loader import load_wiki_articles
from splitter import split_documents
from vector_store import create_vector_db
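
# basic_chain, remote_loader, splitter, and vector_store are local helper
# modules from this project. For readers without them, minimal stand-ins for
# two of the helpers (assumptions -- the real implementations may differ) are:
#
#   from langchain_community.document_loaders import WikipediaLoader
#   from langchain_community.vectorstores import FAISS
#   from langchain_community.embeddings import HuggingFaceEmbeddings
#
#   def load_wiki_articles(query, load_max_docs=5):
#       # Fetch Wikipedia articles matching the query as LangChain Documents.
#       return WikipediaLoader(query=query, load_max_docs=load_max_docs).load()
#
#   def create_vector_db(texts):
#       # Embed the split documents and index them in an in-memory FAISS store.
#       return FAISS.from_documents(texts, HuggingFaceEmbeddings())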

def find_similar(vs, query):
    """Return the documents in the vector store most similar to the query."""
    docs = vs.similarity_search(query)
    return docs


def format_docs(docs):
    """Join the retrieved documents into a single context string for the prompt."""
    return "\n\n".join(doc.page_content for doc in docs)

def get_question(input):
    """Extract the question text from any of the supported RAG chain input types."""
    if not input:
        return None
    elif isinstance(input, str):
        return input
    elif isinstance(input, dict) and 'question' in input:
        return input['question']
    elif isinstance(input, BaseMessage):
        return input.content
    else:
        raise TypeError("string or dict with 'question' key expected as RAG chain input.")
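
# For illustration, each of these inputs yields the same question string
# (HumanMessage, from langchain_core.messages, is a BaseMessage subclass):
#   get_question("What is logicism?")
#   get_question({"question": "What is logicism?"})
#   get_question(HumanMessage(content="What is logicism?"))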

def make_rag_chain(model, retriever, rag_prompt=None):
    # We will use a prompt template from the LangChain hub.
    if not rag_prompt:
        rag_prompt = hub.pull("rlm/rag-prompt")

    # And we will use RunnablePassthrough and RunnableLambda to add some custom
    # processing into our chain: the retrieved documents are formatted into the
    # "context" value, while the original input passes through as the "question".
    rag_chain = (
        {
            "context": RunnableLambda(get_question) | retriever | format_docs,
            "question": RunnablePassthrough()
        }
        | rag_prompt
        | model
    )
    return rag_chain
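
# Usage sketch: the hub prompt "rlm/rag-prompt" expects the variables "context"
# and "question", which is why the dict above maps the retrieved, formatted
# documents to "context" and passes the raw input through as "question".
# Assuming a model and retriever are already built:
#   chain = make_rag_chain(model, retriever)
#   answer = chain.invoke("What did Russell contribute to logic?")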

def main():
    load_dotenv()
    model = get_model("ChatGPT")
    docs = load_wiki_articles(query="Bertrand Russell", load_max_docs=5)
    texts = split_documents(docs)
    vs = create_vector_db(texts)
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a professor who teaches philosophical concepts to beginners."),
        ("user", "{input}")
    ])

    # Besides similarity search, you can also use maximal marginal relevance (MMR)
    # for selecting results.
    # retriever = vs.as_retriever(search_type="mmr")
    retriever = vs.as_retriever()
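
    # The default retriever returns the top-k most similar chunks (k=4 for most
    # LangChain vector stores); this can be tuned via search_kwargs, e.g.:
    #   retriever = vs.as_retriever(search_kwargs={"k": 8})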

    output_parser = StrOutputParser()
    chain = basic_chain(model, prompt)
    base_chain = chain | output_parser
    rag_chain = make_rag_chain(model, retriever) | output_parser

    questions = [
        "What were the most important contributions of Bertrand Russell to philosophy?",
        "What was the first book Bertrand Russell published?",
        "What was most notable about \"An Essay on the Foundations of Geometry\"?",
    ]
    for q in questions:
        print("\n--- QUESTION: ", q)
        print("* BASE:\n", base_chain.invoke({"input": q}))
        print("* RAG:\n", rag_chain.invoke(q))

if __name__ == '__main__':
    # This is to quiet the parallel tokenizers warning.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    main()
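
# Running this script (an assumption based on load_dotenv and get_model("ChatGPT")):
# put an OpenAI API key in a local .env file, e.g. OPENAI_API_KEY=..., then
# execute the file directly with Python.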