# RAG-Chatbot / app.py
import nltk
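# Fetch the NLTK resources that the Unstructured loaders rely on for sentence/word
# tokenization; failures are treated as non-fatal so the app can still start.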
try:
    nltk.download('averaged_perceptron_tagger_eng', quiet=True)
    nltk.download("punkt", quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception as e:
    print(f"Warning: NLTK download failed: {e}")
import gradio as gr
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import UnstructuredFileLoader, PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.chains import RetrievalQA
from langchain.prompts.prompt import PromptTemplate
from langchain_core.vectorstores import VectorStoreRetriever
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
import os
# Prompt template optimized for Flan-T5
template = """Answer the question based on the context below.
Context: {context}
Question: {question}
Answer:"""
QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
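# The RetrievalQA chain fills {context} with the retrieved document chunks and
# {question} with the user's query before the prompt is sent to Flan-T5.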
# Load the Flan-T5 model from the Hugging Face Hub - a good fit for CPU inference and Q&A tasks
# Alternative popular CPU-friendly models you can try:
# - "google/flan-t5-small" (faster, smaller)
# - "google/flan-t5-large" (better quality, slower)
# - "microsoft/DialoGPT-medium" (conversational)
model_id = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id, torch_dtype=torch.float32
)
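# flan-t5-base is roughly 250M parameters, so float32 inference stays practical on CPU.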
# Sentence-Transformers embedding model used by the FAISS vector store
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/msmarco-distilbert-base-v4",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": False},
)
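# Note: with DistanceStrategy.DOT_PRODUCT below and normalize_embeddings=False,
# retrieval scores are raw inner products rather than cosine similarities.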
def clean_response(text):
    """Clean up the generated response"""
    # Remove excessive whitespace and newlines
    text = ' '.join(text.split())
    # Remove repetitive patterns
    words = text.split()
    cleaned_words = []
    for word in words:
        # Skip if the same word appears too many times consecutively
        if len(cleaned_words) >= 3 and all(w == word for w in cleaned_words[-3:]):
            continue
        cleaned_words.append(word)
    cleaned_text = ' '.join(cleaned_words)
    # Truncate at natural stopping points
    sentences = cleaned_text.split('.')
    if len(sentences) > 1:
        # Keep complete sentences
        good_sentences = []
        for sentence in sentences[:-1]:  # Exclude last potentially incomplete sentence
            if len(sentence.strip()) > 5:  # Avoid very short fragments
                good_sentences.append(sentence.strip())
        if good_sentences:
            return '. '.join(good_sentences) + '.'
    return cleaned_text[:500]  # Fallback: truncate to reasonable length
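# Example: clean_response("very very very very good") -> "very very very good"
# (any word repeated more than three times in a row is trimmed to three).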
# Returns a FAISS vector store retriever and the vector store for a given txt or pdf file
def prepare_vector_store_retriever(filename):
    # Load data based on file extension
    if filename.lower().endswith('.pdf'):
        loader = PyPDFLoader(filename)
    else:
        loader = UnstructuredFileLoader(filename)
    raw_documents = loader.load()
    # Split the text
    text_splitter = CharacterTextSplitter(
        separator="\n\n", chunk_size=800, chunk_overlap=0, length_function=len
    )
    documents = text_splitter.split_documents(raw_documents)
    # Creating a vectorstore
    vectorstore = FAISS.from_documents(
        documents, embeddings, distance_strategy=DistanceStrategy.DOT_PRODUCT
    )
    return VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": 2}), vectorstore
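# Illustrative usage (hypothetical file name):
#   retriever, vs = prepare_vector_store_retriever("my_notes.txt")
#   top_chunks = retriever.invoke("What is this document about?")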
# Retrieval QA chain
def get_retrieval_qa_chain(text_file, hf_model):
    retriever = default_retriever
    vectorstore = default_vectorstore
    if text_file != default_text_file or default_text_file is None:
        if text_file is not None and os.path.exists(text_file):
            retriever, vectorstore = prepare_vector_store_retriever(text_file)
        else:
            # Create a dummy retriever if no file is available
            dummy_doc = Document(page_content="No document loaded. Please upload a file to get started.")
            dummy_vectorstore = FAISS.from_documents([dummy_doc], embeddings)
            retriever = VectorStoreRetriever(vectorstore=dummy_vectorstore, search_kwargs={"k": 1})
            vectorstore = dummy_vectorstore
    chain = RetrievalQA.from_chain_type(
        llm=hf_model,
        retriever=retriever,
        chain_type_kwargs={"prompt": QA_PROMPT},
    )
    return chain, vectorstore
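# RetrievalQA.invoke() returns a dict; the generated text lives under the "result" key
# (see generate() below).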
# Builds the QA chain for the current file and generates an answer to the question
def generate(question, answer, text_file, max_new_tokens):
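    # Note: 'answer' is unused here; it is passed in because the Answer textbox is wired
    # as an input to btn.click below.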
    if not question.strip():
        yield "Please enter a question."
        return
    try:
        # Create a text2text-generation pipeline for Flan-T5
        flan_t5_pipeline = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        hf_model = HuggingFacePipeline(pipeline=flan_t5_pipeline)
        qa_chain, vectorstore = get_retrieval_qa_chain(text_file, hf_model)
        query = question.strip()
        if len(tokenizer.tokenize(query)) >= 512:
            yield "Your question is too long! Please shorten it."
            return
        # Run the chain and return the full response (no streaming)
        try:
            result = qa_chain.invoke({"query": query})
            # Extract the answer from the result
            if isinstance(result, dict):
                response = result.get('result', str(result))
            else:
                response = str(result)
            # Clean the response
            cleaned_response = clean_response(response)
            yield cleaned_response
        except Exception as e:
            yield f"Error during generation: {str(e)}"
            return
    except Exception as e:
        yield f"Error: {str(e)}"
# Handles a new file upload: returns the name to display and the path stored in the text_file state
def upload_file(file):
    if file is not None:
        # In Gradio, file is already a path to the uploaded file
        file_path = file.name if hasattr(file, 'name') else file
        filename = os.path.basename(file_path)
        return filename, file_path
    return None, None
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Retrieval-Augmented Generation with Flan-T5: Question-Answering Demo
        ### This demo uses Google's Flan-T5 language model with Retrieval-Augmented Generation (RAG). It lets you upload a .txt or PDF file and ask the model questions about the content of that file.
        ### Features:
        - Support for both PDF and text files
        - Retrieval-based question answering using document context
        - Optimized for CPU performance using the Flan-T5-Base model
        ### To get started, upload a text (.txt) or PDF (.pdf) file using the upload button below.
        The Flan-T5 model is efficient and works well on CPU, making it well suited to document Q&A tasks.
        Retrieval-Augmented Generation (RAG) retrieves just the few small chunks of the document that are relevant to your query and injects them into the prompt.
        The model can then answer questions by incorporating knowledge from the newly provided document.
        """
    )
    default_text_file = "Oppenheimer-movie-wiki.txt"
    # Check if default file exists, if not, set to None
    if not os.path.exists(default_text_file):
        default_text_file = None
        default_retriever = None
        default_vectorstore = None
        initial_file_display = "No default file found - please upload a file"
    else:
        default_retriever, default_vectorstore = prepare_vector_store_retriever(default_text_file)
        initial_file_display = default_text_file
    text_file = gr.State(default_text_file)
    gr.Markdown(
        "## Upload a txt or PDF file to get started"
    )
    file_name = gr.Textbox(
        label="Loaded file", value=initial_file_display, lines=1, interactive=False
    )
    upload_button = gr.UploadButton(
        label="Click to upload a text or PDF file", file_types=[".txt", ".pdf"], file_count="single"
    )
    upload_button.upload(upload_file, upload_button, [file_name, text_file])
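    # upload_file receives the UploadButton payload and updates both the visible
    # "Loaded file" textbox and the text_file state used at generation time.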
    gr.Markdown("## Enter your question")
    tokens_slider = gr.Slider(
        8,
        256,
        value=64,
        label="Maximum new tokens",
        info="A larger `max_new_tokens` value gives longer responses at the cost of a slower response time.",
    )
    with gr.Row():
        with gr.Column():
            ques = gr.Textbox(label="Question", placeholder="Enter text here", lines=3)
        with gr.Column():
            ans = gr.Textbox(label="Answer", lines=4, interactive=False)
    with gr.Row():
        with gr.Column():
            btn = gr.Button("Submit")
        with gr.Column():
            clear = gr.ClearButton([ques, ans])
    btn.click(fn=generate, inputs=[ques, ans, text_file, tokens_slider], outputs=[ans])
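    # generate() is a generator, so Gradio streams each yielded value into the Answer box.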
    examples = gr.Examples(
        examples=[
            "Who portrayed J. Robert Oppenheimer in the new Oppenheimer movie?",
            "In the plot of the movie, why did Lewis Strauss resent Robert Oppenheimer?",
            "How much money did the Oppenheimer movie make at the US and global box office?",
            "What score did the Oppenheimer movie get on Rotten Tomatoes and Metacritic?",
        ],
        inputs=[ques],
    )
demo.queue().launch()