Mojo3 committed
Commit de3b7ed · verified · 1 Parent(s): 8c75d77

Update app.py

Files changed (1)
  1. app.py +222 -1
app.py CHANGED
@@ -1,6 +1,227 @@
+ import streamlit as st
+ from docx import Document
+ import os
+ from langchain_core.prompts import PromptTemplate
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ import time
+ from sentence_transformers import SentenceTransformer
+ from langchain_community.vectorstores import Chroma
+ from langchain.docstore.document import Document as Document2
  from langchain_community.embeddings import HuggingFaceEmbeddings

+ import cohere
+
+ # Load the Hugging Face token from the environment
+ token = os.getenv("HF_TOKEN")
+
+ # Confirm the token is present without printing the secret itself
+ print("HF_TOKEN set:", token is not None)
+
+ docs_folder = "./converted_docs"
+
+
+ # Load every .docx file from the docs folder (the name is a holdover from an
+ # earlier Google Drive workflow; it reads a local directory)
+ def load_docx_files_from_drive(drive_folder):
+     docx_files = [f for f in os.listdir(drive_folder) if f.endswith(".docx")]
+     documents = []
+
+     for file_name in docx_files:
+         file_path = os.path.join(drive_folder, file_name)
+         doc = Document(file_path)
+         # Join the non-empty paragraphs into one text block per file
+         content = "\n".join([p.text for p in doc.paragraphs if p.text.strip()])
+         documents.append(content)
+
+     return documents
+
+
+ # Load .docx files from the docs folder
+ documents = load_docx_files_from_drive(docs_folder)
+
+
+ def split_extracted_text_into_chunks(documents):
+     print("Splitting text into chunks")
+     # List to hold all chunks
+     chunks = []
+
+     for doc_text in documents:
+         # Split the document text into lines
+         lines = doc_text.splitlines()
+
+         # Initialize variables for splitting
+         current_chunk = []
+         for line in lines:
+             # A line starting with "File Name:" marks the start of a new chunk
+             if line.startswith("File Name:"):
+                 # If there's a current chunk, save it before starting a new one
+                 if current_chunk:
+                     chunks.append("\n".join(current_chunk))
+                     current_chunk = []  # Reset the current chunk
+
+             # Add the line to the current chunk
+             current_chunk.append(line)
+
+         # Add the last chunk for the current document
+         if current_chunk:
+             chunks.append("\n".join(current_chunk))
+
+     return chunks
+
+
+ # Split the extracted documents into chunks
+ chunks = split_extracted_text_into_chunks(documents)
+
+
+ def save_chunks_to_file(chunks, output_file_path):
+     print("Saving chunks to file")
+     # Open the file in write mode
+     with open(output_file_path, "w", encoding="utf-8") as file:
+         for i, chunk in enumerate(chunks, start=1):
+             # Write each chunk with a header for easy identification
+             file.write(f"Chunk {i}:\n")
+             file.write(chunk)
+             file.write("\n" + "=" * 50 + "\n")
+
+
+ # Path to save the chunks file
+ output_file_path = "./chunks_output.txt"
+
+ # Save the chunks to the file (the documents were already chunked above)
+ save_chunks_to_file(chunks, output_file_path)
+
+
+ # Step 1: Load the embedding model through LangChain's wrapper
  embedding_model = HuggingFaceEmbeddings(
      model_name="Omartificial-Intelligence-Space/Arabic-Triplet-Matryoshka-V2"
  )
- print("Finished fucking")
+ print("Embedding model loaded")
+
+
+ # Step 2: Embed the chunks, reporting progress in the Streamlit UI
+ def embed_chunks(chunks):
+     status_text = st.empty()
+     progress_bar = st.progress(0)
+     results = []
+
+     total_chunks = len(chunks)
+
+     for i, chunk in enumerate(chunks):
+         result = {
+             "chunk": chunk,
+             "embedding": embedding_model.embed_query(chunk),
+         }
+         results.append(result)
+
+         progress = (i + 1) / total_chunks
+         progress_bar.progress(progress)
+         status_text.text(f"Processed {i+1}/{total_chunks} chunks ({progress:.0%})")
+
+     progress_bar.progress(1.0)
+     status_text.text("Embedding complete!")
+     return results
+
+
+ embeddings = embed_chunks(chunks)
+ print("Chunks embedded")
+
+
+ # Step 3: Wrap each chunk in a LangChain Document for Chroma
+ def prepare_documents_for_chroma(embeddings):
+     print("Preparing documents for Chroma")
+     return [
+         Document2(page_content=entry["chunk"], metadata={"chunk_index": i})
+         for i, entry in enumerate(embeddings, start=1)
+     ]
+
+
+ documents = prepare_documents_for_chroma(embeddings)
+ print("Creating the vector store")
+ # Step 4: Create the Chroma store
+ vectorstore = Chroma.from_documents(
+     documents=documents,
+     embedding=embedding_model,  # Proper embedding object
+     persist_directory="./chroma_db",  # Optional persistence
+ )
+
+
+ class RAGPipeline:
+     def __init__(self, vectorstore, api_key, model_name="c4ai-aya-expanse-8b", k=3):
+         print("Initializing RAG Pipeline")
+         self.vectorstore = vectorstore
+         self.model_name = model_name
+         self.k = k
+         self.api_key = api_key
+         self.client = cohere.Client(api_key)  # Initialize the Cohere client
+         self.retriever = self.vectorstore.as_retriever(
+             search_type="mmr", search_kwargs={"k": self.k}
+         )
+         self.prompt_template = PromptTemplate.from_template(self._get_template())
+
+     def _get_template(self):
+         # Arabic system prompt; roughly: "You are a helpful assistant that
+         # answers in Arabic based on the provided context. Answer only in
+         # Arabic; if the answer is not in the context, say you don't know;
+         # be precise and clear; answer strictly from the context."
+         return """<s>[INST] <<SYS>>
+ أنت مساعد مفيد يقدم إجابات باللغة العربية بناءً على السياق المقدم.
+ - أجب فقط باللغة العربية
+ - إذا لم تجد إجابة في السياق، قل أنك لا تعرف
+ - كن دقيقاً وواضحاً في إجاباتك
+ - جاوب من السياق حصريا
+ <</SYS>>
+
+ السياق: {context}
+
+ السؤال: {question}
+ الإجابة: [/INST]
+ """
+
+     def generate_response(self, question):
+         retrieved_docs = self._retrieve_documents(question)
+         prompt = self._create_prompt(retrieved_docs, question)
+         response = self._generate_response_cohere(prompt)
+         return response
+
+     def _retrieve_documents(self, question):
+         retrieved_docs = self.retriever.invoke(question)
+         # Optional debug output of the retrieved documents:
+         # print("\n=== Retrieved documents ===")
+         # for i, doc in enumerate(retrieved_docs):
+         #     print(f"Document {i+1}: {doc.page_content}")
+         # print("==========================\n")
+
+         # Merge the retrieved texts into a single context
+         return " ".join([doc.page_content for doc in retrieved_docs])
+
+     def _create_prompt(self, docs, question):
+         return self.prompt_template.format(context=docs, question=question)
+
+     def _generate_response_cohere(self, prompt):
+         # Call Cohere's generate API
+         response = self.client.generate(
+             model=self.model_name,
+             prompt=prompt,
+             max_tokens=2000,  # Adjust token limit based on requirements
+             temperature=0.3,  # Control creativity
+             stop_sequences=None,
+         )
+
+         if response.generations:
+             return response.generations[0].text.strip()
+         else:
+             raise Exception("No response generated by Cohere API.")
+
+
+ st.title("Simple Text Generator")
+ api_key = os.getenv("API_KEY")
+ # Confirm the Cohere key is present without leaking any part of it
+ print("API_KEY set:", api_key is not None)
+ rag_pipeline = RAGPipeline(vectorstore=vectorstore, api_key=api_key)
+ question = st.text_input("أدخل سؤالك هنا")  # "Enter your question here"
+ if st.button("Generate Answer"):
+     response = rag_pipeline.generate_response(question)
+     st.write(response)
+     print("Question: ", question)
+     print("Response: ", response)