Prathamesh1420 committed on
Commit 7d48d44 · verified · 1 Parent(s): 3618f3d

Update app.py

Files changed (1):
  1. app.py +210 -36
app.py CHANGED
@@ -1,43 +1,217 @@
  import streamlit as st
- import asyncio
- import websockets
-
- st.markdown('<h1 style="color: darkblue;">AI Voice Assistant</h1>', unsafe_allow_html=True)
-
- # JavaScript for real-time voice streaming
- audio_recorder_js = """
- <script>
- let mediaRecorder;
- let ws;
-
- function startRecording() {
-     navigator.mediaDevices.getUserMedia({ audio: true }).then(stream => {
-         ws = new WebSocket("ws://localhost:8765"); // Replace with your server's WebSocket URL
-         mediaRecorder = new MediaRecorder(stream);
-         mediaRecorder.start();
-
-         mediaRecorder.ondataavailable = event => {
-             ws.send(event.data);
-         };
-
-         ws.onmessage = function(event) {
-             document.getElementById("response").innerHTML += "<br><b>AI:</b> " + event.data;
-         };
-     });
- }
-
- function stopRecording() {
-     mediaRecorder.stop();
-     ws.close();
- }
- </script>
- """
-
- # Display buttons for real-time recording
- st.components.v1.html(
-     audio_recorder_js + """
-     <button onclick="startRecording()">🎤 Start Talking</button>
-     <button onclick="stopRecording()">🛑 Stop</button>
-     <div id="response" style="margin-top: 10px; padding: 10px; border: 1px solid #ccc;"></div>
-     """, height=200
- )
+ from langchain.chains import RetrievalQA
+ from langchain.vectorstores import Milvus
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from transformers import AutoTokenizer
+ from langchain_groq import ChatGroq
+ import os
+ from docling.document_converter import DocumentConverter, PdfFormatOption
+ from docling.datamodel.base_models import InputFormat
+ from docling.datamodel.pipeline_options import PdfPipelineOptions
+ from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+ from docling_core.types.doc.document import TableItem
+ from langchain_core.documents import Document
+ import itertools
+ from docling_core.types.doc.labels import DocItemLabel
+ import google.generativeai as genai
+ from PIL import Image
+ import base64
+ import io
+
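+ # Assumed pip dependencies, inferred from the imports above (verify the exact
+ # package names and versions against your environment): streamlit, langchain,
+ # langchain-groq, langchain-core, transformers, sentence-transformers,
+ # docling, docling-core, google-generativeai, pymilvus, pillow
+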
+ # Initialize components (similar to your notebook)
+ @st.cache_resource
+ def initialize_components():
+     # Initialize embeddings
+     embeddings_model_path = "ibm-granite/granite-embedding-30m-english"
+     embeddings_model = HuggingFaceEmbeddings(model_name=embeddings_model_path)
+     embeddings_tokenizer = AutoTokenizer.from_pretrained(embeddings_model_path)
+
+     # Initialize language model (the API key is read from the environment
+     # rather than hardcoded in the source)
+     GROQ_API_KEY = os.environ["GROQ_API_KEY"]
+     model = ChatGroq(model_name="llama3-70b-8192", api_key=GROQ_API_KEY)
+
+     # Initialize vision model
+     GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
+     genai.configure(api_key=GOOGLE_API_KEY)
+     vision_model = genai.GenerativeModel(model_name="gemini-1.5-flash")
+
+     return embeddings_model, embeddings_tokenizer, model, vision_model
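+
+ # Both API keys are expected in the environment before the app starts, e.g.
+ # (shell, assumed setup):
+ #   export GROQ_API_KEY=...
+ #   export GOOGLE_API_KEY=...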
+
+ def process_pdf(file_path, embeddings_tokenizer, vision_model):
+     # PDF processing (similar to your notebook)
+     pdf_pipeline_options = PdfPipelineOptions(
+         do_ocr=True,
+         generate_picture_images=True
+     )
+
+     format_options = {
+         InputFormat.PDF: PdfFormatOption(pipeline_options=pdf_pipeline_options),
+     }
+
+     converter = DocumentConverter(format_options=format_options)
+     sources = [file_path]
+     conversions = {
+         source: converter.convert(source=source).document for source in sources
+     }
+
+     # Process text chunks
+     doc_id = 0
+     texts = []
+
+     for source, docling_document in conversions.items():
+         chunker = HybridChunker(tokenizer=embeddings_tokenizer)
+
+         for chunk in chunker.chunk(docling_document):
+             items = chunk.meta.doc_items
+
+             # Table-only chunks are skipped here; tables get their own pass below
+             if len(items) == 1 and isinstance(items[0], TableItem):
+                 continue
+
+             refs = "".join(item.get_ref().cref for item in items)
+             text = chunk.text
+
+             document = Document(
+                 page_content=text,
+                 metadata={
+                     "doc_id": (doc_id := doc_id + 1),
+                     "source": source,
+                     "ref": refs,
+                 }
+             )
+             texts.append(document)
+
+     # Process tables (if any)
+     tables = []
+     for source, docling_document in conversions.items():
+         for table in docling_document.tables:
+             if table.label == DocItemLabel.TABLE:
+                 ref = table.get_ref().cref
+                 text = table.export_to_markdown()
+
+                 document = Document(
+                     page_content=text,
+                     metadata={
+                         "doc_id": (doc_id := doc_id + 1),
+                         "source": source,
+                         "ref": ref,
+                     },
+                 )
+                 tables.append(document)
+
+     # Process images (if any)
+     pictures = []
+
+     for source, docling_document in conversions.items():
+         if hasattr(docling_document, 'pictures') and docling_document.pictures:
+             for picture in docling_document.pictures:
+                 try:
+                     ref = picture.get_ref().cref
+                     image = picture.get_image(docling_document)
+
+                     if image:
+                         # Describe each picture with the vision model so images
+                         # become searchable text in the vector store
+                         response = vision_model.generate_content([
+                             "Extract all text and describe key visual elements in this image. "
+                             "Include any numbers, labels, or important details.",
+                             image
+                         ])
+
+                         document = Document(
+                             page_content=response.text,
+                             metadata={
+                                 # Increment before use so the first picture does
+                                 # not reuse the doc_id of the last table
+                                 "doc_id": (doc_id := doc_id + 1),
+                                 "source": source,
+                                 "ref": ref,
+                             }
+                         )
+                         pictures.append(document)
+                 except Exception as e:
+                     print(f"Error processing image: {str(e)}")
+
+     return texts + tables + pictures
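+
+ # Standalone usage sketch (hypothetical file name; the tokenizer and vision
+ # model come from initialize_components()):
+ #   docs = process_pdf("manual.pdf", embeddings_tokenizer, vision_model)
+ #   print(len(docs), docs[0].metadata)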
+
+ def create_vector_store(docs, embeddings_model):
+     # Create vector store (using Milvus as in your notebook)
+     # Note: You'll need to have Milvus running
+     vector_store = Milvus.from_documents(
+         docs,
+         embeddings_model,
+         connection_args={"host": "127.0.0.1", "port": "19530"},
+         collection_name="pdf_manual"
+     )
+     return vector_store
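+
+ # Server-free alternative (assumed API; verify against your installed
+ # version): the newer langchain_milvus package can embed Milvus Lite in a
+ # local file instead of requiring a standalone server:
+ #   from langchain_milvus import Milvus
+ #   vector_store = Milvus.from_documents(
+ #       docs, embeddings_model,
+ #       connection_args={"uri": "./milvus_demo.db"},
+ #       collection_name="pdf_manual",
+ #   )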
+
+ def main():
+     st.title("PDF Manual Chatbot")
+
+     # Initialize components
+     embeddings_model, embeddings_tokenizer, model, vision_model = initialize_components()
+
+     # File upload
+     uploaded_file = st.file_uploader("Upload a PDF manual", type="pdf")
+
+     if uploaded_file is not None:
+         # Process each upload only once per session: Streamlit reruns this
+         # script on every chat message, and re-running Milvus.from_documents
+         # would re-embed the PDF and insert duplicate rows into the collection
+         if st.session_state.get("processed_file") != uploaded_file.name:
+             # Save the uploaded file
+             file_path = os.path.join("temp", uploaded_file.name)
+             os.makedirs("temp", exist_ok=True)
+             with open(file_path, "wb") as f:
+                 f.write(uploaded_file.getbuffer())
+
+             # Process the PDF
+             with st.spinner("Processing PDF..."):
+                 docs = process_pdf(file_path, embeddings_tokenizer, vision_model)
+                 st.session_state.vector_store = create_vector_store(docs, embeddings_model)
+                 st.session_state.processed_file = uploaded_file.name
+
+             st.success("PDF processed successfully!")
+
+         vector_store = st.session_state.vector_store
+
+         # Initialize chat history
+         if "messages" not in st.session_state:
+             st.session_state.messages = []
+
+         # Display chat messages from history on app rerun
+         for message in st.session_state.messages:
+             with st.chat_message(message["role"]):
+                 st.markdown(message["content"])
+
+         # Accept user input
+         if prompt := st.chat_input("Ask a question about the manual"):
+             # Add user message to chat history
+             st.session_state.messages.append({"role": "user", "content": prompt})
+
+             # Display user message in chat message container
+             with st.chat_message("user"):
+                 st.markdown(prompt)
+
+             # Create QA chain
+             qa_chain = RetrievalQA.from_chain_type(
+                 llm=model,
+                 chain_type="stuff",
+                 retriever=vector_store.as_retriever(),
+                 return_source_documents=True
+             )
+
+             # Get response
+             with st.spinner("Thinking..."):
+                 result = qa_chain({"query": prompt})
+                 response = result["result"]
+                 source_docs = result["source_documents"]
+
+             # Display assistant response in chat message container
+             with st.chat_message("assistant"):
+                 st.markdown(response)
+
+                 # Show sources if available
+                 if source_docs:
+                     with st.expander("Source Documents"):
+                         for i, doc in enumerate(source_docs):
+                             st.write(f"Source {i+1}:")
+                             st.write(doc.page_content)
+                             st.write(f"Metadata: {doc.metadata}")
+                             st.write("---")
+
+             # Add assistant response to chat history
+             st.session_state.messages.append({"role": "assistant", "content": response})
+
+ if __name__ == "__main__":
+     main()
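+
+ # To run (assumed workflow): start Milvus on localhost:19530, export
+ # GROQ_API_KEY and GOOGLE_API_KEY, then:
+ #   streamlit run app.py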