|
import os |
|
import streamlit as st |
|
import PyPDF2 |
|
import torch |
|
import weaviate |
|
from transformers import AutoTokenizer, AutoModel |
|
from weaviate.classes.init import Auth |
|
import cohere |
|
|
|
|
|
def read_setting(name, default=""):
    """Return environment variable *name* with surrounding whitespace stripped.

    Args:
        name: environment variable to read.
        default: value returned when the variable is unset or blank.

    Returns:
        The stripped value, or *default* if that is empty.
    """
    value = os.environ.get(name, "").strip()
    return value or default


# Credentials come from the environment so they are never committed to source.
# NOTE(review): the API keys that used to be hard-coded in this file should be
# treated as leaked and rotated in the Weaviate and Cohere consoles.
WEAVIATE_URL = read_setting(
    "WEAVIATE_URL", "vgwhgmrlqrqqgnlb1avjaa.c0.us-west3.gcp.weaviate.cloud"
)
WEAVIATE_API_KEY = read_setting("WEAVIATE_API_KEY")
COHERE_API_KEY = read_setting("COHERE_API_KEY")
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def _connect_weaviate():
    """Open the Weaviate Cloud connection once per server process.

    Streamlit re-executes this script on every user interaction. Without
    caching, each rerun would open a new connection and never close it.
    show_spinner=False keeps this call from emitting a UI element before
    st.set_page_config runs.

    Returns:
        The connected Weaviate client, which is closed at interpreter exit.
    """
    import atexit

    connection = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
        # Presumably forwarded for Cohere-backed modules configured on the
        # server side — confirm against the collection's configuration.
        headers={"X-Cohere-Api-Key": COHERE_API_KEY},
    )
    atexit.register(connection.close)
    return connection


client = _connect_weaviate()
|
|
|
@st.cache_resource(show_spinner=False)
def _create_cohere_client():
    """Build the Cohere client once per server process instead of on every rerun."""
    return cohere.Client(COHERE_API_KEY)


cohere_client = _create_cohere_client()
|
|
|
|
|
# Vectors already stored in the "Document" collection were produced with this
# model; changing it makes new query vectors incomparable with stored ones.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"


@st.cache_resource(show_spinner=False)
def _load_embedding_model(model_name=EMBEDDING_MODEL_NAME):
    """Load the tokenizer and encoder for *model_name* once per server process.

    Streamlit reruns the whole script on every interaction. Caching avoids
    reloading the model weights each time.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModel.from_pretrained(model_name),
    )


tokenizer, model = _load_embedding_model()
|
|
|
def load_pdf(file):
    """Extract the text of every page of a PDF.

    Args:
        file: the PDF source, passed straight to ``PyPDF2.PdfReader`` (the UI
            passes the Streamlit upload object).

    Returns:
        The page texts joined with newlines. Pages whose ``extract_text()``
        returns None or an empty string are skipped. The result is "" when
        no page yields any text.
    """
    reader = PyPDF2.PdfReader(file)
    # Text extraction is the expensive step, so run it exactly once per page.
    page_texts = (page.extract_text() for page in reader.pages)
    # Newline separator keeps the last word of one page from being fused with
    # the first word of the next.
    return "\n".join(text for text in page_texts if text)
|
|
|
def get_embeddings(text):
    """Return the attention-masked mean-pooled embedding of *text*.

    Args:
        text: a string, or a list of strings (the tokenizer accepts both).

    Returns:
        A NumPy array. For a single string, ``squeeze()`` drops the batch
        dimension and the result is one 1-D vector.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        token_states = model(**inputs).last_hidden_state  # (batch, seq_len, hidden)
    # Average only over real tokens. When several texts are padded to a common
    # length, a plain mean would mix padding positions into the embedding.
    # For a single string the mask is all ones, which gives the plain mean.
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_states.dtype)
    summed = (token_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / counts).squeeze().cpu().numpy()
|
|
|
def upload_document_chunks(chunks):
    """Store document chunks and their embeddings in the "Document" collection.

    Each chunk is embedded with get_embeddings() and saved under the "content"
    property. Object IDs are a UUIDv5 of the chunk text, and objects go through
    the batch endpoint, so re-uploading a chunk overwrites the existing object
    instead of creating a duplicate. Duplicates would otherwise crowd the top-k
    search results with copies of the same text.

    Args:
        chunks: iterable of text chunks. Repeated chunks are stored once.

    Raises:
        RuntimeError: if Weaviate reports an error for any object in the batch.
    """
    from weaviate.classes.data import DataObject
    from weaviate.util import generate_uuid5

    # dict.fromkeys drops repeated chunks while keeping first-seen order, so one
    # request never carries two objects with the same ID.
    unique_chunks = list(dict.fromkeys(chunks))
    if not unique_chunks:
        return

    objects = [
        DataObject(
            properties={"content": chunk},
            vector=get_embeddings(chunk).tolist(),
            uuid=generate_uuid5(chunk),
        )
        for chunk in unique_chunks
    ]

    doc_collection = client.collections.get("Document")
    result = doc_collection.data.insert_many(objects)
    if result.has_errors:
        first_error = next(iter(result.errors.values()))
        raise RuntimeError(
            f"Weaviate rejected {len(result.errors)} of {len(objects)} chunk(s); "
            f"first error: {first_error.message}"
        )
|
|
|
def query_answer(query, limit=3):
    """Find the stored document chunks nearest to *query*.

    Args:
        query: the user's question. It is embedded with get_embeddings().
        limit: maximum number of objects to return (default 3).

    Returns:
        The ``objects`` list from Weaviate's ``near_vector`` search on the
        "Document" collection.
    """
    query_embedding = get_embeddings(query)
    results = client.collections.get("Document").query.near_vector(
        near_vector=query_embedding.tolist(),
        limit=limit,
    )
    return results.objects
|
|
|
def generate_response(context, query, model_name="command", max_tokens=100):
    """Ask a Cohere model to answer *query* from the retrieved *context*.

    Args:
        context: retrieved document text, placed before the question.
        query: the user's question.
        model_name: Cohere model id passed to ``generate`` (default "command").
        max_tokens: maximum number of tokens to generate (default 100).

    Returns:
        The text of the first generation, with surrounding whitespace stripped.
    """
    # NOTE(review): Cohere has been moving users from the legacy Generate
    # endpoint and the "command" model to the Chat API; confirm both are still
    # served for this account.
    response = cohere_client.generate(
        model=model_name,
        prompt=f"Context: {context}\n\nQuestion: {query}\nAnswer:",
        max_tokens=max_tokens,
    )
    return response.generations[0].text.strip()
|
|
|
def qa_pipeline(pdf_file, query, chunk_size=500):
    """Answer *query* from *pdf_file*: parse, chunk, index, retrieve, generate.

    Args:
        pdf_file: the uploaded PDF, passed to load_pdf().
        query: the user's question.
        chunk_size: number of characters per indexed chunk (default 500).

    Returns:
        A ``(context, answer)`` tuple. ``context`` is the retrieved chunks
        joined with spaces. If the PDF has no extractable text, returns an
        empty context and an explanatory message instead.

    Raises:
        ValueError: if *chunk_size* is not positive.
    """
    import hashlib

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    document_text = load_pdf(pdf_file)
    if not document_text.strip():
        # e.g. a scanned, image-only PDF. Without this guard the search would
        # run against whatever unrelated documents are already stored.
        return "", "No extractable text was found in the uploaded PDF."

    document_chunks = [
        document_text[start:start + chunk_size]
        for start in range(0, len(document_text), chunk_size)
    ]

    # Streamlit reruns this for every question. Embed and upload a given
    # document (at a given chunk size) only once per browser session.
    digest = hashlib.sha256(
        document_text.encode("utf-8", errors="surrogatepass")
    ).hexdigest()
    upload_key = f"{chunk_size}:{digest}"
    uploaded_keys = st.session_state.setdefault("uploaded_document_keys", set())
    if upload_key not in uploaded_keys:
        upload_document_chunks(document_chunks)
        uploaded_keys.add(upload_key)

    top_docs = query_answer(query)
    context = " ".join(doc.properties["content"] for doc in top_docs)
    answer = generate_response(context, query)
    return context, answer
|
|
|
|
|
st.set_page_config(page_title="Interactive QA Bot", layout="wide")

# Page header rendered as raw HTML so it can be centred and coloured.
# st.markdown only renders the tags with unsafe_allow_html=True.
# The decorative glyphs around the title were mis-encoded (shown as "π") and
# have been dropped.
st.markdown(
    """
<div style="text-align: center; font-size: 28px; font-weight: bold; margin-bottom: 20px; color: #2D3748;">
    Interactive QA Bot
</div>
<p style="text-align: center; font-size: 16px; color: #4A5568;">
    Upload a PDF document, ask questions, and receive answers based on the document content.
</p>
<hr style="border: 1px solid #CBD5E0; margin: 20px 0;">
""",
    unsafe_allow_html=True,
)
|
|
|
# Two-column layout: inputs on the left (1 part), results on the right (2 parts).
# Labels previously carried mis-encoded emoji (rendered as "π"/"β"); those
# were removed.
col1, col2 = st.columns([1, 2])

with col1:
    pdf_file = st.file_uploader("Upload PDF", type=["pdf"])
    query = st.text_input("Ask a Question", placeholder="Enter your question here...")
    submit = st.button("Submit")

with col2:
    # Empty placeholders reserve the result slots in the right column. They are
    # filled in by the submit handler below.
    doc_segments_output = st.empty()
    answer_output = st.empty()
|
|
|
# Run the QA pipeline when the button is pressed and both inputs are present.
if submit:
    if not pdf_file:
        st.warning("Please upload a PDF file first.")
    elif not query.strip():
        st.warning("Please enter a question.")
    else:
        try:
            with st.spinner("Processing..."):
                context, answer = qa_pipeline(pdf_file, query)
        except Exception as exc:  # top-level UI boundary: report any backend failure
            st.error("Something went wrong while answering the question.")
            st.exception(exc)
        else:
            doc_segments_output.text_area("Retrieved Document Segments", context, height=200)
            answer_output.text_area("Answer", answer, height=80)
|
|
|
|
|
# Global style overrides injected as a raw <style> tag. st.markdown only emits
# the tag when unsafe_allow_html=True.
# NOTE(review): selectors such as ".stFileUploader > div > div > input" depend
# on Streamlit's generated DOM and may stop matching after an upgrade.
_APP_CSS = """
<style>
body {
    background-color: #EDF2F7;
}
.stFileUploader > div > div > input {
    background-color: #3182CE !important;
    color: white !important;
    padding: 8px !important;
    border-radius: 5px !important;
}
button {
    background-color: #3182CE !important;
    color: white !important;
    padding: 10px !important;
    font-size: 16px !important;
    border-radius: 5px !important;
    cursor: pointer;
    border: none !important;
}
button:hover {
    background-color: #2B6CB0 !important;
}
textarea {
    border: 2px solid #CBD5E0 !important;
    border-radius: 8px !important;
    padding: 10px !important;
    background-color: #FAFAFA !important;
}
</style>
"""

st.markdown(_APP_CSS, unsafe_allow_html=True)