# Spaces: Sleeping
import nltk

# sent_tokenize (used below) needs the Punkt sentence-tokenizer data; fetch both
# the classic 'punkt' package and 'punkt_tab' (loaded by recent NLTK versions).
nltk.download(info_or_id='punkt')
nltk.download(info_or_id='punkt_tab')
from io import BytesIO | |
from PyPDF2 import PdfReader, utils | |
import fitz | |
from typing import List | |
import google.generativeai as genai | |
import gradio as gr | |
from nltk.tokenize import sent_tokenize | |
from fastembed import TextEmbedding | |
import numpy as np | |
from qdrant_client import QdrantClient | |
from qdrant_client.http.models import Distance, VectorParams | |
from qdrant_client.models import PointStruct | |
import os | |
from dotenv import load_dotenv, find_dotenv | |
# Populate os.environ from the nearest .env file (if one is found), then read the
# service credentials; a key that is not set anywhere comes back as None.
load_dotenv(find_dotenv())
QDRANT_API_KEY = os.environ.get('QDRANT_API_KEY')
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
# Optional sanity check: try to open one known PDF with PyPDF2.
# Nothing else in this script uses `input_pdf` (text is extracted from every
# PDF in the directory with PyMuPDF further down), so a missing sample file is
# reported and skipped instead of aborting the whole app at start-up.
input_path = './repaired-www-foxweather-com.pdf'
input_buffer = None
input_pdf = None
try:
    with open(input_path, 'rb') as input_file:
        input_buffer = BytesIO(input_file.read())
except FileNotFoundError:
    print(f"Sample PDF not found at {input_path}; skipping PyPDF2 check.")
else:
    # Try reading the PDF directly
    try:
        input_pdf = PdfReader(input_buffer)
        print("PDF read successfully.")
    except utils.PdfReadError:
        # If direct reading fails, it might be a compression issue.
        print("Could not read PDF directly. Proceeding with original file.")
        # Reset buffer position for potential later use
        input_buffer.seek(0)
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of the PDF at *pdf_path*, concatenated.

    Pages are joined in document order with no separator, exactly as
    PyMuPDF's ``page.get_text()`` returns them. The document is closed
    before returning, even if text extraction raises.
    """
    # Context manager closes the document; join avoids quadratic += building.
    with fitz.open(pdf_path) as doc:
        return "".join(page.get_text() for page in doc)
def extract_text_from_pdfs_in_directory(directory):
    """Write a ``.txt`` file next to every ``.pdf`` file in *directory*.

    Each ``name.pdf`` produces ``name.txt`` (UTF-8) in the same directory,
    overwriting any existing file of that name. Entries whose name ends in
    ``.pdf`` but which are not regular files (e.g. sub-directories) are
    skipped, since they cannot be opened as documents.

    Returns:
        List of the paths of the ``.txt`` files written, in sorted
        filename order (empty if the directory holds no PDFs).
    """
    written = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".pdf"):
            continue
        pdf_path = os.path.join(directory, filename)
        if not os.path.isfile(pdf_path):
            continue
        extracted_text = extract_text_from_pdf(pdf_path)
        txt_filename = os.path.splitext(filename)[0] + ".txt"
        txt_filepath = os.path.join(directory, txt_filename)
        with open(txt_filepath, "w", encoding="utf-8") as txt_file:
            txt_file.write(extracted_text)
        written.append(txt_filepath)
    return written
# Specify the directory containing PDF files
directory_path = "./"
# Extract text from PDFs in the directory and save as text files
extract_text_from_pdfs_in_directory(directory_path)
# Read back only the .txt files generated from the PDFs above. Listing every
# *.txt in the directory would also pull unrelated files (for example a
# requirements.txt sitting next to the app) into the knowledge base.
txt_files = sorted(
    os.path.splitext(name)[0] + ".txt"
    for name in os.listdir(directory_path)
    if name.endswith(".pdf")
)
# List to store sentences from all files
all_sentences = []
# Read each text file, split into sentences, and store
for txt_file in txt_files:
    file_path = os.path.join(directory_path, txt_file)
    if not os.path.isfile(file_path):
        # No text file was produced for this entry (e.g. it was not a regular file).
        continue
    with open(file_path, "r", encoding="utf-8") as file:
        text = file.read()
    all_sentences.extend(sent_tokenize(text))
# Print the first few sentences as an example
print(all_sentences[:10])  # Print first 10 sentences
# Initialize the TextEmbedding model
embedding_model = TextEmbedding(model_name="BAAI/bge-base-en", cache_dir="./embeddings")


def embed_documents(documents):
    """Embed each document with the module-level FastEmbed model.

    Args:
        documents: Iterable of strings to embed.

    Returns:
        A list with one NumPy array per document, in input order. Each array
        is a one-row batch of shape ``(1, dim)``, where ``dim`` is the model's
        embedding size. Callers take ``[0]`` to get the vector itself.
    """
    docs = list(documents)
    if not docs:
        return []
    # One embed() call lets FastEmbed batch the inputs instead of running the
    # model once per sentence; each vector is re-wrapped as a one-row array to
    # keep the original per-document return shape.
    return [np.array([vector]) for vector in embedding_model.embed(docs)]
# The sentences are the documents; embed them all, then unwrap each one-row
# result to a plain 1-D vector so embeddings[i] lines up with documents[i].
documents = all_sentences
embeddings = [single[0] for single in embed_documents(documents)]
# Qdrant Cloud connection; the QDRANT_URL environment variable, when set,
# overrides the default cluster endpoint.
client = QdrantClient(
    url=os.getenv(
        "QDRANT_URL",
        "https://ec069eb8-1679-4f53-971c-8fef6fe7d057.us-west-2-0.aws.cloud.qdrant.io",
    ),
    api_key=QDRANT_API_KEY,
    https=True,
)
collection_name = 'fastembed_collection'
# Size the collection from the vectors actually produced so it always matches
# the embedding model; 768 is kept as the fallback when there is nothing to index.
vector_size = len(embeddings[0]) if embeddings else 768
# Drop and re-create the collection on every start so it mirrors the current
# documents (replaces the deprecated recreate_collection()).
if client.collection_exists(collection_name=collection_name):
    client.delete_collection(collection_name=collection_name)
client.create_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)
# Point id = position in `documents`; payload keeps the sentence for retrieval.
client.upload_points(
    collection_name=collection_name,
    points=[
        PointStruct(
            id=idx,
            vector=vector.tolist(),
            payload={"text": text},
        )
        for idx, (vector, text) in enumerate(zip(embeddings, documents))
    ],
)
genai.configure(api_key=GOOGLE_API_KEY)
model = genai.GenerativeModel('gemini-2.5-pro')


# Function to generate completion from prompt
def generate_completion(prompt):
    """Send *prompt* to the Gemini model and return the generated text.

    When the response carries no text, ``response.text`` raises
    ``ValueError`` (for example when the candidate was blocked). In that case
    a short explanatory message is returned, so the UI shows something
    readable instead of an error.
    """
    response = model.generate_content(prompt)
    try:
        return response.text
    except ValueError:
        return (
            "The model returned no text for this question "
            "(the response may have been blocked). Please try rephrasing it."
        )
# Function to embed Queries
def embed_query(Question):
    """Embed a single question string with the shared FastEmbed model.

    Returns a NumPy array with one row per embedded input; since exactly one
    string is passed in, it holds a single row.
    """
    query_batch = [Question]
    return np.array([*embedding_model.embed(query_batch)])
def generate_response(Question):
    """Answer *Question* with Gemini, grounded on the closest stored sentences.

    The question is embedded, the 50 nearest sentences are fetched from the
    module's Qdrant collection, and they are placed in the prompt as the only
    context the model may use. A blank (or missing) question is answered with
    a request to type one, without calling the embedder, Qdrant or Gemini.
    """
    if not Question or not Question.strip():
        return "Please enter a question."
    context_chunks = []
    # Retrieve all hits and collect their texts for a single prompt
    for query_embedding in embed_query(Question):
        hits = client.query_points(
            collection_name=collection_name,
            # .tolist() yields plain Python floats rather than np.float32 scalars
            query=query_embedding.tolist(),
            limit=50,
        ).points
        context_chunks.extend(hit.payload["text"] for hit in hits)
    all_text = "".join(chunk + "\n\n" for chunk in context_chunks)
    # Generate completion using all texts as a single prompt
    prompt = f"You are a helpful chatbot. Use only the following pieces of context to answer the question. Don't make up any new information:\n\n{all_text}\n\nQuestion:{Question}\n\nAnswer:"
    return generate_completion(prompt)
# Gradio UI: one question textbox in, one generated-answer textbox out,
# wired to the retrieval + generation pipeline above.
iface = gr.Interface(
    title="RAG with Qdrant, FastEmbed and Gemini",
    description="Enter a question and get a generated response based on the retrieved text.",
    fn=generate_response,
    inputs=[gr.Textbox(label="Question")],
    outputs=[gr.Textbox(label="Generated Response")],
)
# debug=True keeps the process attached and surfaces errors in the console;
# share=True asks Gradio for a public share link.
iface.launch(debug=True, share=True)