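# Download NLTK resources (used by the unstructured loader for text partitioning); failures are non-fatal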
import nltk
try:
    nltk.download('averaged_perceptron_tagger_eng', quiet=True)
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception as e:
    print(f"Warning: NLTK download failed: {e}")

import gradio as gr
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import UnstructuredFileLoader, PyPDFLoader
from langchain.vectorstores.faiss import FAISS
from langchain.vectorstores.utils import DistanceStrategy
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.chains import RetrievalQA
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores.base import VectorStoreRetriever
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
import os

# Prompt template optimized for Flan-T5
template = """Answer the question based on the context below.
Context: {context}
Question: {question}
Answer:"""
QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])

# Load the Flan-T5 model from the Hugging Face Hub - well suited to CPU-only Q&A tasks
# Alternative popular CPU-friendly models you can try:
# - "google/flan-t5-small" (faster, smaller)
# - "google/flan-t5-large" (better quality, slower)
# - "microsoft/DialoGPT-medium" (conversational)
model_id = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id, torch_dtype=torch.float32
)

# Sentence-transformers model used to embed document chunks for the vector store
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/msmarco-distilbert-base-v4",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": False},
)

def clean_response(text):
    """Clean up the generated response"""
    # Remove excessive whitespace and newlines
    text = ' '.join(text.split())

    # Remove repetitive patterns
    words = text.split()
    cleaned_words = []
    for word in words:
        # Skip if the same word appears too many times consecutively
        if len(cleaned_words) >= 3 and all(w == word for w in cleaned_words[-3:]):
            continue
        cleaned_words.append(word)
    cleaned_text = ' '.join(cleaned_words)

    # Truncate at natural stopping points
    sentences = cleaned_text.split('.')
    if len(sentences) > 1:
        # Keep complete sentences
        good_sentences = []
        for sentence in sentences[:-1]:  # Exclude last potentially incomplete sentence
            if len(sentence.strip()) > 5:  # Avoid very short fragments
                good_sentences.append(sentence.strip())
        if good_sentences:
            return '. '.join(good_sentences) + '.'

    return cleaned_text[:500]  # Fallback: truncate to reasonable length

# Returns a FAISS vector store retriever and the underlying vector store for a txt or pdf file
def prepare_vector_store_retriever(filename):
    # Load data based on file extension
    if filename.lower().endswith('.pdf'):
        loader = PyPDFLoader(filename)
    else:
        loader = UnstructuredFileLoader(filename)
    raw_documents = loader.load()

    # Split the text
    text_splitter = CharacterTextSplitter(
        separator="\n\n", chunk_size=800, chunk_overlap=0, length_function=len
    )
    documents = text_splitter.split_documents(raw_documents)

    # Create the vector store
    vectorstore = FAISS.from_documents(
        documents, embeddings, distance_strategy=DistanceStrategy.DOT_PRODUCT
    )

    return VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": 2}), vectorstore

# Retrieval QA chain
def get_retrieval_qa_chain(text_file, hf_model):
    retriever = default_retriever
    vectorstore = default_vectorstore
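    # Rebuild the retriever when the loaded file differs from the default (or when no default exists)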
    if text_file != default_text_file or default_text_file is None:
        if text_file is not None and os.path.exists(text_file):
            retriever, vectorstore = prepare_vector_store_retriever(text_file)
        else:
            # Create a dummy retriever if no file is available
            dummy_doc = Document(page_content="No document loaded. Please upload a file to get started.")
            dummy_vectorstore = FAISS.from_documents([dummy_doc], embeddings)
            retriever = VectorStoreRetriever(vectorstore=dummy_vectorstore, search_kwargs={"k": 1})
            vectorstore = dummy_vectorstore

    chain = RetrievalQA.from_chain_type(
        llm=hf_model,
        retriever=retriever,
        chain_type_kwargs={"prompt": QA_PROMPT},
    )
    return chain, vectorstore

# Generates a response using the question answering chain defined earlier
def generate(question, answer, text_file, max_new_tokens):
    if not question.strip():
        yield "Please enter a question."
        return

    try:
        # Create a pipeline for text2text generation (Flan-T5)
        flan_t5_pipeline = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        hf_model = HuggingFacePipeline(pipeline=flan_t5_pipeline)
        qa_chain, vectorstore = get_retrieval_qa_chain(text_file, hf_model)

        query = f"{question}"
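        # Reject questions that already exceed the 512-token budget the app allows for the model input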
        if len(tokenizer.tokenize(query)) >= 512:
            yield "Your question is too long! Please shorten it."
            return

        # Get the response directly without streaming
        try:
            result = qa_chain.invoke({"query": query})

            # Extract the answer from the result
            if isinstance(result, dict):
                response = result.get('result', str(result))
            else:
                response = str(result)

            # Clean the response
            cleaned_response = clean_response(response)
            yield cleaned_response
        except Exception as e:
            yield f"Error during generation: {str(e)}"
            return
    except Exception as e:
        yield f"Error: {str(e)}"

# Updates the displayed file name and the stored file path whenever a new file is uploaded;
# the retriever itself is rebuilt on the next query in get_retrieval_qa_chain
def upload_file(file):
    if file is not None:
        # In Gradio, file is already a path to the uploaded file
        file_path = file.name if hasattr(file, 'name') else file
        filename = os.path.basename(file_path)
        return filename, file_path
    return None, None

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Retrieval Augmented Generation with Flan-T5: Question Answering demo
        ### This demo uses Google's Flan-T5 language model with Retrieval Augmented Generation (RAG). It allows you to upload a txt or PDF file and ask the model questions about the content of that file.
        ### Features:
        - Support for both PDF and text files
        - Retrieval-based question answering using document context
        - Optimized for CPU performance using the Flan-T5-Base model
        ### To get started, upload a text (.txt) or PDF (.pdf) file using the upload button below.
        The Flan-T5 model is efficient and works well on CPU, which makes it a good fit for document Q&A tasks.
        Retrieval Augmented Generation (RAG) retrieves just the few small chunks of the document that are relevant to your query and injects them into the prompt.
        The model can then answer questions by incorporating knowledge from the newly provided document.
        """
    )

    default_text_file = "Oppenheimer-movie-wiki.txt"
    # Check if the default file exists; if not, fall back to None
    if not os.path.exists(default_text_file):
        default_text_file = None
        default_retriever = None
        default_vectorstore = None
        initial_file_display = "No default file found - please upload a file"
    else:
        default_retriever, default_vectorstore = prepare_vector_store_retriever(default_text_file)
        initial_file_display = default_text_file

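    # Gradio session state holding the path of the currently loaded document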
    text_file = gr.State(default_text_file)

    gr.Markdown(
        "## Upload a txt or PDF file to get started"
    )
    file_name = gr.Textbox(
        label="Loaded file", value=initial_file_display, lines=1, interactive=False
    )
    upload_button = gr.UploadButton(
        label="Click to upload a text or PDF file", file_types=[".txt", ".pdf"], file_count="single"
    )
    upload_button.upload(upload_file, upload_button, [file_name, text_file])

    gr.Markdown("## Enter your question")
    tokens_slider = gr.Slider(
        8,
        256,
        value=64,
        label="Maximum new tokens",
        info="A larger `max_new_tokens` parameter value gives you longer text responses but at the cost of a slower response time.",
    )

    with gr.Row():
        with gr.Column():
            ques = gr.Textbox(label="Question", placeholder="Enter text here", lines=3)
        with gr.Column():
            ans = gr.Textbox(label="Answer", lines=4, interactive=False)

    with gr.Row():
        with gr.Column():
            btn = gr.Button("Submit")
        with gr.Column():
            clear = gr.ClearButton([ques, ans])

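    # Wire the Submit button to generate(); each yielded string updates the Answer box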
    btn.click(fn=generate, inputs=[ques, ans, text_file, tokens_slider], outputs=[ans])

    examples = gr.Examples(
        examples=[
            "Who portrayed J. Robert Oppenheimer in the new Oppenheimer movie?",
            "In the plot of the movie, why did Lewis Strauss resent Robert Oppenheimer?",
            "How much money did the Oppenheimer movie make at the US and global box office?",
            "What score did the Oppenheimer movie get on Rotten Tomatoes and Metacritic?",
        ],
        inputs=[ques],
    )

demo.queue().launch()