# rag_utils.py
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from ragatouille import RAGPretrainedModel
from groq import Groq
import json
from PyPDF2 import PdfReader  # Raw PDF text extraction (used by extract_raw_text_from_pdf)
import re
# --- Configuration (can be overridden by the calling app) ---
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TOP_K_CHUNKS = 7
GROQ_MODEL_NAME = "llama3-8b-8192"
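# These module-level defaults can be overridden by the calling app, e.g. (sketch):
#   import rag_utils
#   rag_utils.TOP_K_CHUNKS = 5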
# --- Helper Functions ---
def extract_raw_text_from_pdf(pdf_path: str) -> str:
    """
    Extracts raw text from a PDF file using PyPDF2.

    This is a simpler extraction than LLMSherpa and is suitable for manual sectioning.
    """
    try:
        reader = PdfReader(pdf_path)
        full_text = ""
        for page in reader.pages:
            # extract_text() can return None for pages with no extractable text
            full_text += (page.extract_text() or "") + "\n"  # Newline between pages
        print(f"Extracted raw text from PDF: {len(full_text)} characters.")
        return full_text
    except Exception as e:
        print(f"Error extracting raw text from PDF: {e}")
        return ""
def process_markdown_with_manual_sections(
    md_file_path: str,
    headings_json: dict,
    chunk_size: int,
    chunk_overlap: int,
):
    """
    Processes a markdown document from a file path by segmenting it based on
    provided section headings, then recursively chunking each segment.
    Each chunk receives the corresponding section heading as metadata.

    Args:
        md_file_path (str): Path to the input markdown (.md) file.
        headings_json (dict): A JSON object with schema {"headings": ["Heading 1", "Heading 2"]}
            listing the major section headings to split by.
        chunk_size (int): Maximum size of each text chunk, in characters.
        chunk_overlap (int): Number of characters to overlap between consecutive chunks.

    Returns:
        tuple[list[Document], list[dict]]: A tuple containing:
            - list[Document]: LangChain Document objects, each holding a text chunk
              and its associated metadata.
            - list[dict]: Dictionaries with keys "section_heading" and "section_text",
              representing the segmented sections for evaluation.
    """
    all_chunks_with_metadata = []
    full_text = ""

    # Validate the input path and read the file content
    if not os.path.exists(md_file_path):
        print(f"Error: File not found at '{md_file_path}'")
        return [], []
    if not os.path.isfile(md_file_path):
        print(f"Error: Path '{md_file_path}' is not a file.")
        return [], []
    if not md_file_path.lower().endswith(".md"):
        print(f"Warning: File '{md_file_path}' does not have a .md extension.")

    try:
        with open(md_file_path, "r", encoding="utf-8") as f:
            full_text = f.read()
    except Exception as e:
        print(f"Error reading file '{md_file_path}': {e}")
        return [], []

    if not full_text:
        print("Input markdown file is empty.")
        return [], []
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )

    # Extract heading texts from the 'headings' key
    heading_texts = headings_json.get("headings", [])
    print(f"Identified headings for segmentation: {heading_texts}")

    # Find the start index of each heading in the full text using a regex that
    # escapes special characters (re.escape) and tolerates variable whitespace
    # between words (\s*); re.IGNORECASE makes the match case-insensitive.
    heading_positions = []
    for heading in heading_texts:
        pattern = re.compile(
            r"\s*".join(re.escape(word) for word in heading.split()),
            re.IGNORECASE,
        )
        match = pattern.search(full_text)
        if match:
            heading_positions.append({"heading_text": heading, "start_index": match.start()})
        else:
            print(
                f"Warning: Heading '{heading}' not found in the markdown text using regex. "
                f"This section might be missed."
            )

    # Sort heading positions by their start index
    heading_positions.sort(key=lambda x: x["start_index"])
    # Segment the text based on heading positions
    segments_with_headings = []

    # Handle preface (text before the first heading)
    if heading_positions and heading_positions[0]["start_index"] > 0:
        preface_text = full_text[:heading_positions[0]["start_index"]].strip()
        if preface_text:
            segments_with_headings.append({
                "section_heading": "Document Start/Preface",
                "section_text": preface_text,
            })

    # Iterate through heading positions to define sections
    for i, current_heading_info in enumerate(heading_positions):
        start_index = current_heading_info["start_index"]
        heading_text = current_heading_info["heading_text"]

        # The current section runs from this heading's start to the next heading's
        # start (or to the end of the document); the heading text itself is included.
        end_index = len(full_text)
        if i + 1 < len(heading_positions):
            end_index = heading_positions[i + 1]["start_index"]

        section_content = full_text[start_index:end_index].strip()
        if section_content:
            segments_with_headings.append({
                "section_heading": heading_text,
                "section_text": section_content,
            })

    print(f"Created {len(segments_with_headings)} segments based on provided headings.")
    # Chunk each segment and attach metadata
    for segment in segments_with_headings:
        section_heading = segment["section_heading"]
        section_text = segment["section_text"]
        if section_text:
            chunks = text_splitter.split_text(section_text)
            for chunk in chunks:
                metadata = {
                    "document_part": "Section",  # All segments are treated as 'Section'
                    "section_heading": section_heading,
                }
                all_chunks_with_metadata.append(Document(page_content=chunk, metadata=metadata))

    print(f"Created {len(all_chunks_with_metadata)} chunks with metadata from segmented sections.")

    # Persist the segmented sections for inspection/evaluation
    with open("output.json", "w", encoding="utf-8") as f:
        json.dump(segments_with_headings, f, indent=4, ensure_ascii=False)

    return all_chunks_with_metadata, segments_with_headings
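# Example usage (sketch; "policy.md" and the headings list are illustrative, not part of this repo):
#   headings = {"headings": ["Definitions", "Coverage", "Exclusions"]}
#   chunks, sections = process_markdown_with_manual_sections("policy.md", headings, CHUNK_SIZE, CHUNK_OVERLAP)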
def perform_vector_search(documents: list[Document], query: str, top_k: int, rag_model_instance=None) -> list[dict]:
    """
    Performs vector search using RAGatouille's ColBERT implementation to retrieve
    the top-k relevant chunks, preserving metadata.

    Args:
        documents (list[Document]): The LangChain Document objects to index and search.
        query (str): The search query.
        top_k (int): The number of top relevant chunks to retrieve.
        rag_model_instance: An optional pre-loaded RAGatouille model instance.
            If None, a new one is loaded.

    Returns:
        list[dict]: Dictionaries, each containing 'content' and 'document_metadata'
        from the RAGatouille search results.
    """
    if rag_model_instance is None:
        print("Initializing RAGatouille ColBERT model...")
        rag = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
    else:
        rag = rag_model_instance

    # Separate content and metadata for indexing
    collection_texts = [doc.page_content for doc in documents]
    collection_metadatas = [doc.metadata for doc in documents]

    index_name = "custom_chunks_index"
    print("Indexing chunks with RAGatouille (this may take a while for large datasets)...")
    rag.index(
        collection=collection_texts,
        document_metadatas=collection_metadatas,
        index_name=index_name,
        overwrite_index=True,
    )
    print("Indexing complete.")

    print(f"Performing vector search for query: '{query}' (top_k={top_k})...")
    results = rag.search(query=query, k=top_k)
    print(f"Retrieved {len(results)} top chunks.")
    return results
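# Example usage (sketch; assumes `docs` comes from process_markdown_with_manual_sections):
#   hits = perform_vector_search(docs, "What is the waiting period?", TOP_K_CHUNKS)
#   for hit in hits:
#       print(hit["content"][:80], hit["document_metadata"].get("section_heading"))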
def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
    """
    Generates an answer with the Groq API based on the query and the retrieved chunks' content.
    Includes chunk metadata in the prompt for better context.

    Args:
        query (str): The original user query.
        retrieved_results (list[dict]): Dictionaries from the RAGatouille search,
            each with 'content' and 'document_metadata'.
        groq_api_key (str): The Groq API key.

    Returns:
        str: The generated answer.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set. Cannot generate answer."
print("Generating answer with Groq API...")
client = Groq(api_key=groq_api_key)
context_parts = []
for i, res in enumerate(retrieved_results):
content = res.get("content", "")
metadata = res.get("document_metadata", {})
section_heading = metadata.get("section_heading", "N/A")
document_part = metadata.get("document_part", "N/A") # New metadata field
context_parts.append(
f"--- Context Chunk {i+1} ---\n"
f"Document Part: {document_part}\n"
f"Section Heading: {section_heading}\n"
f"Content: {content}\n"
f"-------------------------"
)
context = "\n\n".join(context_parts)
prompt = (
f"You are a specialized document analyzer assistant. Your task is to answer the user's question "
f"solely based on the provided context. Pay close attention to the section heading and document part "
f"for each context chunk. Ensure your answer incorporates all relevant details, including any legal nuances "
f"and conditions found in the context, and is concise, limited to one or two sentences. "
f"Do not explicitly mention the retrieved chunks. If the answer cannot be found in the provided context, "
f"clearly state that you do not have enough information.\n\n"
f"Context:\n{context}\n\n"
f"Question: {query}\n\n"
f"Answer:"
)
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=GROQ_MODEL_NAME,
            temperature=0.7,
            max_tokens=500,
        )
        answer = chat_completion.choices[0].message.content
        print("Answer generated successfully.")
        return answer
    except Exception as e:
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."