import os
import torch
import nltk
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import UnstructuredFileLoader, PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.schema import Document
from langchain.chains import RetrievalQA
from langchain.prompts.prompt import PromptTemplate

# Download the NLTK resources required by the Unstructured file loader
try:
    nltk.download("averaged_perceptron_tagger_eng", quiet=True)
    nltk.download("punkt", quiet=True)
    nltk.download("punkt_tab", quiet=True)
except Exception as e:
    print(f"Warning: NLTK download failed: {e}")

# Prompt template suited to Flan-T5's instruction-following format
template = """Answer the question based on the context below.

Context: {context}

Question: {question}

Answer:"""

QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])

# Load Flan-T5 from the Hugging Face Hub - a CPU-friendly model well suited to Q&A tasks.
# Alternative CPU-friendly models you can try:
# - "google/flan-t5-small" (faster, smaller)
# - "google/flan-t5-large" (better quality, slower)
# - "microsoft/DialoGPT-medium" (conversational)
model_id = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.float32)

# Sentence-transformers embedding model used by the FAISS vector store
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/msmarco-distilbert-base-v4",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": False},
)


def clean_response(text):
    """Clean up the generated response."""
    # Collapse excessive whitespace and newlines
    text = " ".join(text.split())

    # Drop a word once it has already appeared three times in a row
    words = text.split()
    cleaned_words = []
    for word in words:
        if len(cleaned_words) >= 3 and all(w == word for w in cleaned_words[-3:]):
            continue
        cleaned_words.append(word)
    cleaned_text = " ".join(cleaned_words)

    # Truncate at natural stopping points: keep only complete sentences
    sentences = cleaned_text.split(".")
    if len(sentences) > 1:
        good_sentences = []
        for sentence in sentences[:-1]:  # Exclude the last, potentially incomplete sentence
            if len(sentence.strip()) > 5:  # Skip very short fragments
                good_sentences.append(sentence.strip())
        if good_sentences:
            return ". ".join(good_sentences) + "."

    return cleaned_text[:500]  # Fallback: truncate to a reasonable length
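
# Optional sanity check - an illustrative sketch, never called by the app itself.
# It assumes the models above loaded successfully and simply exercises the
# embedding model and Flan-T5 once each before any document is indexed.
def _smoke_test():
    # msmarco-distilbert-base-v4 produces 768-dimensional embeddings
    vector = embeddings.embed_query("What is this document about?")
    print(f"Embedding dimension: {len(vector)}")

    # One seq2seq generation pass through Flan-T5
    inputs = tokenizer("Answer the question: What color is the sky?", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))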

# Build a FAISS vector store retriever from a .txt or .pdf file
def prepare_vector_store_retriever(filename):
    # Pick a loader based on the file extension
    if filename.lower().endswith(".pdf"):
        loader = PyPDFLoader(filename)
    else:
        loader = UnstructuredFileLoader(filename)
    raw_documents = loader.load()

    # Split the text into chunks
    text_splitter = CharacterTextSplitter(
        separator="\n\n",
        chunk_size=800,
        chunk_overlap=0,
        length_function=len,
    )
    documents = text_splitter.split_documents(raw_documents)

    # Create the vector store
    vectorstore = FAISS.from_documents(
        documents,
        embeddings,
        distance_strategy=DistanceStrategy.DOT_PRODUCT,
    )

    return VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": 2}), vectorstore


# Retrieval QA chain
def get_retrieval_qa_chain(text_file, hf_model):
    retriever = default_retriever
    vectorstore = default_vectorstore

    if text_file != default_text_file or default_text_file is None:
        if text_file is not None and os.path.exists(text_file):
            retriever, vectorstore = prepare_vector_store_retriever(text_file)
        else:
            # Fall back to a placeholder retriever if no file is available
            dummy_doc = Document(page_content="No document loaded. Please upload a file to get started.")
            dummy_vectorstore = FAISS.from_documents([dummy_doc], embeddings)
            retriever = VectorStoreRetriever(vectorstore=dummy_vectorstore, search_kwargs={"k": 1})
            vectorstore = dummy_vectorstore

    chain = RetrievalQA.from_chain_type(
        llm=hf_model,
        retriever=retriever,
        chain_type_kwargs={"prompt": QA_PROMPT},
    )

    return chain, vectorstore


# Generate a response using the retrieval QA chain defined above
def generate(question, answer, text_file, max_new_tokens):
    if not question.strip():
        yield "Please enter a question."
        return

    try:
        # Flan-T5 is a seq2seq model, so use the text2text-generation pipeline
        flan_t5_pipeline = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        hf_model = HuggingFacePipeline(pipeline=flan_t5_pipeline)

        qa_chain, _ = get_retrieval_qa_chain(text_file, hf_model)

        query = question
        if len(tokenizer.tokenize(query)) >= 512:
            yield "Your question is too long! Please shorten it."
            return

        try:
            result = qa_chain.invoke({"query": query})
            # Extract the answer text from the chain output
            if isinstance(result, dict):
                response = result.get("result", str(result))
            else:
                response = str(result)
            yield clean_response(response)
        except Exception as e:
            yield f"Error during generation: {str(e)}"
            return
    except Exception as e:
        yield f"Error: {str(e)}"


# Update the loaded-file display and state whenever a new file is uploaded;
# the retriever itself is rebuilt inside get_retrieval_qa_chain at query time
def upload_file(file):
    if file is not None:
        # Gradio passes either a path string or an object with a .name attribute
        file_path = file.name if hasattr(file, "name") else file
        filename = os.path.basename(file_path)
        return filename, file_path
    return None, None
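
# Illustrative end-to-end usage sketch (not executed by the app). The path
# "example.txt" and the helper name _example_query are hypothetical; call it
# only after the UI block below has run, since get_retrieval_qa_chain reads
# the module-level defaults defined there.
def _example_query(path="example.txt", question="What is this document about?"):
    qa_pipeline = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=64,
        do_sample=False,
    )
    llm = HuggingFacePipeline(pipeline=qa_pipeline)
    chain, _ = get_retrieval_qa_chain(path, llm)
    return chain.invoke({"query": question})["result"]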

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Retrieval Augmented Generation with Flan-T5: Question Answering demo
        ### This demo uses Google's Flan-T5 language model and Retrieval Augmented Generation (RAG). It allows you to upload a txt or PDF file and ask the model questions about the content of that file.
        ### Features:
        - Support for both PDF and text files
        - Retrieval-based question answering using document context
        - Optimized for CPU performance using the Flan-T5-Base model
        ### To get started, upload a text (.txt) or PDF (.pdf) file using the upload button below.
        The Flan-T5 model is efficient and works well on CPU, making it well suited to document Q&A tasks. RAG retrieves just the few small chunks of the document that are relevant to your query and injects them into the prompt, so the model can answer questions by incorporating knowledge from the newly provided document.
        """
    )

    default_text_file = "Oppenheimer-movie-wiki.txt"

    # Fall back gracefully if the default file is missing
    if not os.path.exists(default_text_file):
        default_text_file = None
        default_retriever = None
        default_vectorstore = None
        initial_file_display = "No default file found - please upload a file"
    else:
        default_retriever, default_vectorstore = prepare_vector_store_retriever(default_text_file)
        initial_file_display = default_text_file

    text_file = gr.State(default_text_file)

    gr.Markdown("## Upload a txt or PDF file to get started")

    file_name = gr.Textbox(
        label="Loaded file",
        value=initial_file_display,
        lines=1,
        interactive=False,
    )
    upload_button = gr.UploadButton(
        label="Click to upload a text or PDF file",
        file_types=[".txt", ".pdf"],
        file_count="single",
    )
    upload_button.upload(upload_file, upload_button, [file_name, text_file])

    gr.Markdown("## Enter your question")

    tokens_slider = gr.Slider(
        8,
        256,
        value=64,
        label="Maximum new tokens",
        info="A larger `max_new_tokens` value gives you longer text responses but at the cost of a slower response time.",
    )

    with gr.Row():
        with gr.Column():
            ques = gr.Textbox(label="Question", placeholder="Enter text here", lines=3)
        with gr.Column():
            ans = gr.Textbox(label="Answer", lines=4, interactive=False)

    with gr.Row():
        with gr.Column():
            btn = gr.Button("Submit")
        with gr.Column():
            clear = gr.ClearButton([ques, ans])

    btn.click(fn=generate, inputs=[ques, ans, text_file, tokens_slider], outputs=[ans])

    examples = gr.Examples(
        examples=[
            "Who portrayed J. Robert Oppenheimer in the new Oppenheimer movie?",
            "In the plot of the movie, why did Lewis Strauss resent Robert Oppenheimer?",
            "How much money did the Oppenheimer movie make at the US and global box office?",
            "What score did the Oppenheimer movie get on Rotten Tomatoes and Metacritic?",
        ],
        inputs=[ques],
    )

demo.queue().launch()
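
# To run this demo locally (assuming the file is saved as app.py):
#   pip install gradio langchain langchain-community langchain-core faiss-cpu \
#       sentence-transformers transformers torch nltk unstructured pypdf
#   python app.py
# The first launch downloads Flan-T5 and the embedding model from the
# Hugging Face Hub, so it can take a few minutes.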