pradeepsengarr committed on
Commit 2074ed8 · verified · 1 Parent(s): 759650f

Update app.py

Files changed (1)
  1. app.py +21 -30
app.py CHANGED
@@ -9,31 +9,31 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from huggingface_hub import login
 
-# 1. Authenticate HuggingFace
+# Authenticate with Hugging Face
 hf_token = os.environ.get("HUGGINGFACE_TOKEN")
 if not hf_token:
     raise ValueError("⚠️ Please set the HUGGINGFACE_TOKEN environment variable.")
 login(token=hf_token)
 
-# 2. Load embedding model
+# Load embedding model
 embed_model = SentenceTransformer("BAAI/bge-base-en-v1.5")
 
-# 3. Load LLM (Mistral 7B Instruct with 4-bit quantization)
+# Load Mistral without 4-bit quantization (CPU-friendly)
 model_id = "mistralai/Mistral-7B-Instruct-v0.1"
-tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_token)
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    device_map="auto",
-    load_in_4bit=True,
-    use_auth_token=hf_token
+    device_map={"": "cpu"},  # force CPU
+    torch_dtype="auto",      # safe for CPU
+    token=hf_token
 )
 llm = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
-# 4. Globals
+# Globals
 index = None
 doc_texts = []
 
-# 5. Extract text from uploaded file
+# Extract text from PDF or TXT
 def extract_text(file):
     text = ""
     file_bytes = file.read()
@@ -45,69 +45,60 @@ def extract_text(file):
     elif file.name.endswith(".txt"):
         text = file_bytes.decode("utf-8")
     else:
-        return "❌ Unsupported file type. Only PDF and TXT are allowed."
+        return "❌ Unsupported file type."
     return text
 
-# 6. Process the file: split text, create embeddings, build FAISS index
+# Process the file, build FAISS index
 def process_file(file):
     global index, doc_texts
     text = extract_text(file)
     if text.startswith("❌"):
         return text
 
-    # Split text
     splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)
     doc_texts = splitter.split_text(text)
 
-    # Create embeddings
     embeddings = embed_model.encode(doc_texts, convert_to_numpy=True)
-
-    # Build FAISS index
     dim = embeddings.shape[1]
     index = faiss.IndexFlatL2(dim)
     index.add(embeddings)
 
-    return "✅ File processed successfully. You can now ask questions!"
+    return "✅ File processed! You can now ask questions."
 
-# 7. Generate answer based on question + retrieved context
+# Generate answer using context + LLM
 def generate_answer(question):
     global index, doc_texts
     if index is None or not doc_texts:
         return "⚠️ Please upload and process a file first."
 
-    # Embed the question
     question_emb = embed_model.encode([question], convert_to_numpy=True)
     _, I = index.search(question_emb, k=3)
-
-    # Build context
     context = "\n".join([doc_texts[i] for i in I[0]])
 
-    # Prompt
-    prompt = f"""[System: You are a helpful assistant. Answer strictly based on the context. Do not hallucinate.]
+    prompt = f"""[System: You are a helpful assistant. Answer based on the context.]
+
 Context:
 {context}
 
 Question: {question}
 Answer:"""
 
-    # Generate response
     response = llm(prompt, max_new_tokens=300, do_sample=True, temperature=0.7)
     return response[0]["generated_text"].split("Answer:")[-1].strip()
 
-# 8. Gradio UI
-with gr.Blocks(title="🧠 RAG Chatbot") as demo:
-    gr.Markdown("## 📚 Retrieval-Augmented Generation Chatbot\nUpload a `.pdf` or `.txt` and ask questions from the content.")
+# Gradio UI
+with gr.Blocks(title="RAG Chatbot (CPU Compatible)") as demo:
+    gr.Markdown("## 📚 Upload PDF/TXT and Ask Questions using Mistral-7B")
 
     with gr.Row():
-        file_input = gr.File(label="📁 Upload PDF/TXT", file_types=[".pdf", ".txt"])
-        upload_status = gr.Textbox(label="📥 Upload Status", interactive=False)
+        file_input = gr.File(label="📁 Upload File (.pdf or .txt)", file_types=[".pdf", ".txt"])
+        upload_status = gr.Textbox(label="Upload Status", interactive=False)
 
     with gr.Row():
-        question_box = gr.Textbox(label="❓ Ask a Question", placeholder="Type your question here...")
+        question_box = gr.Textbox(label="❓ Ask a Question", placeholder="What would you like to know?")
         answer_box = gr.Textbox(label="💬 Answer", interactive=False)
 
     file_input.change(fn=process_file, inputs=file_input, outputs=upload_status)
     question_box.submit(fn=generate_answer, inputs=question_box, outputs=answer_box)
 
-# 9. Launch the app
 demo.launch()
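Review note: dropping `load_in_4bit=True` is the right call for a CPU-only Space, since 4-bit loading relies on the CUDA-only bitsandbytes backend, and recent transformers releases expect the flag wrapped in a `BitsAndBytesConfig` rather than passed directly to `from_pretrained`. A minimal sketch of the GPU path this commit removes, assuming bitsandbytes and accelerate are installed:

```python
# Sketch only: the 4-bit load this commit removes, restated with the
# BitsAndBytesConfig API. Assumes a CUDA GPU plus the bitsandbytes and
# accelerate packages; hf_token is the same variable app.py defines.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # compute dtype for 4-bit layers
)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    device_map="auto",               # places layers on available GPUs
    quantization_config=bnb_config,
    token=hf_token,
)
```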
 
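The retrieval side is untouched by this commit: chunks are ranked by `IndexFlatL2`, i.e. Euclidean distance. BGE embeddings are normally compared by cosine similarity, which equals inner product once the vectors are L2-normalized, so an `IndexFlatIP` over normalized embeddings is a common variant. A sketch reusing `embed_model` and `doc_texts` from app.py:

```python
# Sketch only: cosine-similarity retrieval via an inner-product index
# over L2-normalized vectors (cosine == dot product on unit vectors).
import faiss

embeddings = embed_model.encode(
    doc_texts, convert_to_numpy=True, normalize_embeddings=True
)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
# Query vectors must be normalized the same way before index.search().
```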
 
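The prompt keeps its free-form `[System: ...]` prefix, but Mistral-7B-Instruct was fine-tuned on `[INST] ... [/INST]` markers, so routing the prompt through the tokenizer's chat template tends to keep the model closer to the retrieved context. A sketch reusing `tokenizer`, `llm`, `context`, and `question` from `generate_answer` (the v0.1 template has no system role, so the instruction rides in the user turn):

```python
# Sketch only: format the RAG prompt with the model's own chat template
# instead of an ad-hoc "[System: ...]" prefix.
messages = [{
    "role": "user",
    "content": f"Answer strictly from the context below.\n\n"
               f"Context:\n{context}\n\nQuestion: {question}",
}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
response = llm(prompt, max_new_tokens=300, do_sample=True, temperature=0.7)
```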
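One caveat the diff leaves alone: `extract_text` calls `file.read()`, but what `gr.File` passes to the callback varies by Gradio version (a tempfile wrapper in older releases, a plain filepath string in newer ones), so `.read()` can break on upgrade. Opening the upload by its path is the more portable pattern; a hedged sketch with a hypothetical helper (the PDF branch is outside the shown hunks, so only raw bytes are handled here):

```python
# Sketch only: version-tolerant access to a gr.File upload.
# read_upload is a hypothetical helper, not part of app.py.
def read_upload(file) -> bytes:
    # Newer Gradio passes a filepath string; older versions pass a
    # tempfile wrapper whose path is exposed as .name.
    path = file if isinstance(file, str) else file.name
    with open(path, "rb") as fh:
        return fh.read()
```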