import nltk
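# NLTK resources used by the unstructured file loader for sentence tokenization and POS tagging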
try:
    nltk.download('averaged_perceptron_tagger_eng', quiet=True)
    nltk.download("punkt", quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception as e:
    print(f"Warning: NLTK download failed: {e}")

import gradio as gr
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import UnstructuredFileLoader, PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document

from langchain.chains import RetrievalQA
from langchain.prompts.prompt import PromptTemplate
from langchain_core.vectorstores import VectorStoreRetriever

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

import os


# Prompt template optimized for Flan-T5
template = """Answer the question based on the context below.

Context: {context}

Question: {question}

Answer:"""
QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])


# Load the Flan-T5 model from the Hugging Face Hub - a good fit for CPU inference and Q&A tasks
# Alternative CPU-friendly seq2seq models you can try:
# - "google/flan-t5-small" (faster, smaller)
# - "google/flan-t5-large" (better quality, slower)
# Note: a decoder-only model such as "microsoft/DialoGPT-medium" would also need
# AutoModelForCausalLM and the "text-generation" pipeline instead of the seq2seq setup below.
model_id = "google/flan-t5-base"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id, torch_dtype=torch.float32
)

# Sentence-transformers embeddings used by the FAISS vector store
# (left unnormalized to pair with the dot-product distance strategy below)
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/msmarco-distilbert-base-v4",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": False},
)


def clean_response(text):
    """Clean up the generated response"""
    # Remove excessive whitespace and newlines
    text = ' '.join(text.split())
    
    # Remove repetitive patterns
    words = text.split()
    cleaned_words = []
    
    for word in words:
        # Skip if the same word appears too many times consecutively
        if len(cleaned_words) >= 3 and all(w == word for w in cleaned_words[-3:]):
            continue
        cleaned_words.append(word)
    
    cleaned_text = ' '.join(cleaned_words)
    
    # Truncate at natural stopping points
    sentences = cleaned_text.split('.')
    if len(sentences) > 1:
        # Keep complete sentences
        good_sentences = []
        for sentence in sentences[:-1]:  # Exclude last potentially incomplete sentence
            if len(sentence.strip()) > 5:  # Avoid very short fragments
                good_sentences.append(sentence.strip())
        
        if good_sentences:
            return '. '.join(good_sentences) + '.'
    
    return cleaned_text[:500]  # Fallback: truncate to reasonable length


# Builds a FAISS vector store from a txt or PDF file and returns (retriever, vectorstore)
def prepare_vector_store_retriever(filename):
    # Load data based on file extension
    if filename.lower().endswith('.pdf'):
        loader = PyPDFLoader(filename)
    else:
        loader = UnstructuredFileLoader(filename)
    
    raw_documents = loader.load()

    # Split the text
    text_splitter = CharacterTextSplitter(
        separator="\n\n", chunk_size=800, chunk_overlap=0, length_function=len
    )

    documents = text_splitter.split_documents(raw_documents)

    # Creating a vectorstore
    vectorstore = FAISS.from_documents(
        documents, embeddings, distance_strategy=DistanceStrategy.DOT_PRODUCT
    )

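    # k=2: the retriever returns the two most similar chunks for each query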
    return VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": 2}), vectorstore


# Retrieval QA chain
def get_retrieval_qa_chain(text_file, hf_model):
    retriever = default_retriever
    vectorstore = default_vectorstore
    
    if text_file != default_text_file or default_text_file is None:
        if text_file is not None and os.path.exists(text_file):
            retriever, vectorstore = prepare_vector_store_retriever(text_file)
        else:
            # Create a dummy retriever if no file is available
            dummy_doc = Document(page_content="No document loaded. Please upload a file to get started.")
            dummy_vectorstore = FAISS.from_documents([dummy_doc], embeddings)
            retriever = VectorStoreRetriever(vectorstore=dummy_vectorstore, search_kwargs={"k": 1})
            vectorstore = dummy_vectorstore

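    # Uses the default "stuff" chain type: retrieved chunks are inserted directly into QA_PROMPT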
    chain = RetrievalQA.from_chain_type(
        llm=hf_model,
        retriever=retriever,
        chain_type_kwargs={"prompt": QA_PROMPT},
    )
    return chain, vectorstore


# Generates a response using the retrieval QA chain
# (the `answer` argument receives the current Answer box value from Gradio and is unused)
def generate(question, answer, text_file, max_new_tokens):
    if not question.strip():
        yield "Please enter a question."
        return
    
    try:
        # Build a text2text-generation pipeline for the seq2seq Flan-T5 model
        flan_t5_pipeline = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

        hf_model = HuggingFacePipeline(pipeline=flan_t5_pipeline)
        qa_chain, vectorstore = get_retrieval_qa_chain(text_file, hf_model)

        query = question

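        # Flan-T5 was pretrained with 512-token inputs, so reject questions that already fill the window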
        if len(tokenizer.tokenize(query)) >= 512:
            yield "Your question is too long! Please shorten it."
            return

        # Run the QA chain synchronously and yield the cleaned answer (no token streaming)
        try:
            result = qa_chain.invoke({"query": query})
            
            # Extract the answer from the result
            if isinstance(result, dict):
                response = result.get('result', str(result))
            else:
                response = str(result)
                
            # Clean the response
            cleaned_response = clean_response(response)
            yield cleaned_response
            
        except Exception as e:
            yield f"Error during generation: {str(e)}"
            return
                
    except Exception as e:
        yield f"Error: {str(e)}"


# Stores the uploaded file's name and path; the QA chain's retriever is rebuilt on the next query
def upload_file(file):
    if file is not None:
        # In Gradio, file is already a path to the uploaded file
        file_path = file.name if hasattr(file, 'name') else file
        filename = os.path.basename(file_path)
        return filename, file_path
    return None, None


with gr.Blocks() as demo:
    gr.Markdown(
        """
  # Retrieval Augmented Generation with Flan-T5: Question Answering demo
  ### This demo uses Google's Flan-T5 language model and Retrieval Augmented Generation (RAG). It allows you to upload a txt or PDF file and ask the model questions related to the content of that file.
  ### Features:
  - Support for both PDF and text files
  - Retrieval-based question answering using document context
  - Optimized for CPU performance using the Flan-T5-Base model
  ### To get started, upload a text (.txt) or PDF (.pdf) file using the upload button below.
  The Flan-T5 model is efficient on CPU, making it a good fit for document Q&A tasks.
  Retrieval Augmented Generation (RAG) retrieves only the few small document chunks that are relevant to your query and injects them into the prompt.
  The model is then able to answer questions by incorporating knowledge from the newly provided document.
  """
    )

    default_text_file = "Oppenheimer-movie-wiki.txt"
    
    # Check if default file exists, if not, set to None
    if not os.path.exists(default_text_file):
        default_text_file = None
        default_retriever = None
        default_vectorstore = None
        initial_file_display = "No default file found - please upload a file"
    else:
        default_retriever, default_vectorstore = prepare_vector_store_retriever(default_text_file)
        initial_file_display = default_text_file

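    # Session state holding the path of the currently loaded document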
    text_file = gr.State(default_text_file)

    gr.Markdown(
        "## Upload a txt or PDF file to get started"
    )

    file_name = gr.Textbox(
        label="Loaded file", value=initial_file_display, lines=1, interactive=False
    )
    upload_button = gr.UploadButton(
        label="Click to upload a text or PDF file", file_types=[".txt", ".pdf"], file_count="single"
    )
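    # On upload, update both the displayed filename and the stored file path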
    upload_button.upload(upload_file, upload_button, [file_name, text_file])

    gr.Markdown("## Enter your question")
    tokens_slider = gr.Slider(
        8,
        256,
        value=64,
        label="Maximum new tokens",
        info="A larger `max_new_tokens` parameter value gives you longer text responses but at the cost of a slower response time.",
    )

    with gr.Row():
        with gr.Column():
            ques = gr.Textbox(label="Question", placeholder="Enter text here", lines=3)
        with gr.Column():
            ans = gr.Textbox(label="Answer", lines=4, interactive=False)
    with gr.Row():
        with gr.Column():
            btn = gr.Button("Submit")
        with gr.Column():
            clear = gr.ClearButton([ques, ans])

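    # Gradio streams each value yielded by `generate` into the Answer textbox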
    btn.click(fn=generate, inputs=[ques, ans, text_file, tokens_slider], outputs=[ans])
    examples = gr.Examples(
        examples=[
            "Who portrayed J. Robert Oppenheimer in the new Oppenheimer movie?",
            "In the plot of the movie, why did Lewis Strauss resent Robert Oppenheimer?",
            "How much money did the Oppenheimer movie make at the US and global box office?",
            "What score did the Oppenheimer movie get on Rotten Tomatoes and Metacritic?",
        ],
        inputs=[ques],
    )

demo.queue().launch()