import os
import json
import re

# PdfReader is used by extract_raw_text_from_pdf below.
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from ragatouille import RAGPretrainedModel
from groq import Groq

# Chunking, retrieval, and generation configuration.
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TOP_K_CHUNKS = 7
GROQ_MODEL_NAME = "llama3-8b-8192"


def extract_raw_text_from_pdf(pdf_path: str) -> str:
    """
    Extracts raw text from a PDF file using PyPDF2.

    This is a simpler text extraction than LLMSherpa, suitable for manual sectioning.
    """
    try:
        reader = PdfReader(pdf_path)
        full_text = ""
        for page in reader.pages:
            # extract_text() may return None for pages without extractable text.
            full_text += (page.extract_text() or "") + "\n"
        print(f"Extracted raw text from PDF: {len(full_text)} characters.")
        return full_text
    except Exception as e:
        print(f"Error extracting raw text from PDF: {e}")
        return ""


def process_markdown_with_manual_sections(
    md_file_path: str,
    headings_json: dict,
    chunk_size: int,
    chunk_overlap: int,
):
    """
    Processes a markdown document from a file path by segmenting it based on
    provided section headings, then recursively chunking each segment.
    Each chunk receives the corresponding section heading as metadata.

    Args:
        md_file_path (str): The path to the input markdown (.md) file.
        headings_json (dict): A JSON object with schema:
            {"headings": ["Your Heading 1", "Your Heading 2"]}
            This contains the major section headings to split by.
        chunk_size (int): The maximum size of each text chunk.
        chunk_overlap (int): The number of characters to overlap between consecutive chunks.

    Returns:
        tuple[list[Document], list[dict]]: A tuple containing:
            - list[Document]: LangChain Document objects, each containing a text
              chunk and its associated metadata.
            - list[dict]: Dictionaries with {"section_heading", "section_text"}
              representing the segmented sections, for evaluation.
    """
    all_chunks_with_metadata = []
    full_text = ""

    # Validate the input path before reading it.
    if not os.path.exists(md_file_path):
        print(f"Error: File not found at '{md_file_path}'")
        return [], []
    if not os.path.isfile(md_file_path):
        print(f"Error: Path '{md_file_path}' is not a file.")
        return [], []
    if not md_file_path.lower().endswith(".md"):
        print(f"Warning: File '{md_file_path}' does not have a .md extension.")

    try:
        with open(md_file_path, "r", encoding="utf-8") as f:
            full_text = f.read()
    except Exception as e:
        print(f"Error reading file '{md_file_path}': {e}")
        return [], []

    if not full_text:
        print("Input markdown file is empty.")
        return [], []

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )

    heading_texts = headings_json.get("headings", [])
    print(f"Identified headings for segmentation: {heading_texts}")

    # Locate each heading in the full text. The pattern joins the heading's words
    # with optional whitespace so extra spaces or line breaks inside a heading still match.
    heading_positions = []
    for heading in heading_texts:
        pattern = re.compile(r'\s*'.join(re.escape(word) for word in heading.split()))
        match = pattern.search(full_text)
        if match:
            heading_positions.append({"heading_text": heading, "start_index": match.start()})
        else:
            print(f"Warning: Heading '{heading}' not found in the markdown text using regex. This section might be missed.")

    # Process headings in the order they appear in the document.
    heading_positions.sort(key=lambda x: x["start_index"])

    segments_with_headings = []

    # Capture any text preceding the first heading as a preface segment.
    if heading_positions and heading_positions[0]["start_index"] > 0:
        preface_text = full_text[:heading_positions[0]["start_index"]].strip()
        if preface_text:
            segments_with_headings.append({
                "section_heading": "Document Start/Preface",
                "section_text": preface_text,
            })

    # Each section runs from its heading to the start of the next heading
    # (or to the end of the document for the last heading).
    for i, current_heading_info in enumerate(heading_positions):
        start_index = current_heading_info["start_index"]
        heading_text = current_heading_info["heading_text"]

        end_index = len(full_text)
        if i + 1 < len(heading_positions):
            end_index = heading_positions[i + 1]["start_index"]

        section_content = full_text[start_index:end_index].strip()

        if section_content:
            segments_with_headings.append({
                "section_heading": heading_text,
                "section_text": section_content,
            })

    print(f"Created {len(segments_with_headings)} segments based on provided headings.")

    # Recursively chunk each section and attach its heading as metadata.
    for segment in segments_with_headings:
        section_heading = segment["section_heading"]
        section_text = segment["section_text"]

        if section_text:
            chunks = text_splitter.split_text(section_text)
            for chunk in chunks:
                metadata = {
                    "document_part": "Section",
                    "section_heading": section_heading,
                }
                all_chunks_with_metadata.append(Document(page_content=chunk, metadata=metadata))

    print(f"Created {len(all_chunks_with_metadata)} chunks with metadata from segmented sections.")

    # Persist the segmented sections for later inspection/evaluation.
    with open("output.json", "w", encoding="utf-8") as f:
        json.dump(segments_with_headings, f, indent=4, ensure_ascii=False)

    return all_chunks_with_metadata, segments_with_headings


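# A minimal usage sketch for the segmentation step above. The file path and
# headings here are illustrative placeholders only, not values used elsewhere
# in this module.
#
#   headings = {"headings": ["Introduction", "Definitions", "Termination"]}
#   chunks, sections = process_markdown_with_manual_sections(
#       md_file_path="contract.md",   # hypothetical path
#       headings_json=headings,
#       chunk_size=CHUNK_SIZE,
#       chunk_overlap=CHUNK_OVERLAP,
#   )
#   # `chunks` feeds the retriever below; `sections` mirrors what is written to output.json.

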
def perform_vector_search(documents: list[Document], query: str, top_k: int, rag_model_instance=None) -> list[dict]:
    """
    Performs vector search using Ragatouille's ColBERT implementation
    to retrieve the top-k relevant chunks, preserving metadata.

    Args:
        documents (list[Document]): The LangChain Document objects to index and search.
        query (str): The search query.
        top_k (int): The number of top relevant chunks to retrieve.
        rag_model_instance: An optional pre-loaded Ragatouille model instance.
            If None, a new one will be loaded.

    Returns:
        list[dict]: A list of dictionaries, each containing 'content' and
            'document_metadata' from the Ragatouille search results.
    """
    if rag_model_instance is None:
        print("Initializing Ragatouille ColBERT model...")
        rag = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
    else:
        rag = rag_model_instance

    collection_texts = [doc.page_content for doc in documents]
    collection_metadatas = [doc.metadata for doc in documents]

    index_name = "custom_chunks_index"
    print("Indexing chunks with Ragatouille (this may take a while for large datasets)...")
    rag.index(
        collection=collection_texts,
        document_metadatas=collection_metadatas,
        index_name=index_name,
        overwrite_index=True,
    )
    print("Indexing complete.")

    print(f"Performing vector search for query: '{query}' (top_k={top_k})...")
    results = rag.search(query=query, k=top_k)

    print(f"Retrieved {len(results)} top chunks.")
    return results


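# Usage sketch (illustrative only): load the ColBERT model once and reuse it
# across calls via the rag_model_instance argument, so the checkpoint is not
# reloaded for every query. The query string is a hypothetical example.
#
#   rag = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
#   results = perform_vector_search(chunks, "What is the notice period?", TOP_K_CHUNKS, rag)
#   # Each result is a dict; this module relies only on its 'content' and
#   # 'document_metadata' keys.

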
def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
    """
    Generates an answer using the Groq API based on the query and retrieved chunks' content.
    Includes metadata in the prompt for better context.

    Args:
        query (str): The original user query.
        retrieved_results (list[dict]): A list of dictionaries from Ragatouille search,
            each with 'content' and 'document_metadata'.
        groq_api_key (str): The Groq API key.

    Returns:
        str: The generated answer.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set. Cannot generate answer."

    print("Generating answer with Groq API...")
    client = Groq(api_key=groq_api_key)

    context_parts = []
    for i, res in enumerate(retrieved_results):
        content = res.get("content", "")
        metadata = res.get("document_metadata", {})
        section_heading = metadata.get("section_heading", "N/A")
        document_part = metadata.get("document_part", "N/A")

        context_parts.append(
            f"--- Context Chunk {i+1} ---\n"
            f"Document Part: {document_part}\n"
            f"Section Heading: {section_heading}\n"
            f"Content: {content}\n"
            f"-------------------------"
        )
    context = "\n\n".join(context_parts)

    prompt = (
        "You are a specialized document analyzer assistant. Your task is to answer the user's question "
        "solely based on the provided context. Pay close attention to the section heading and document part "
        "for each context chunk. Ensure your answer incorporates all relevant details, including any legal nuances "
        "and conditions found in the context, and is concise, limited to one or two sentences. "
        "Do not explicitly mention the retrieved chunks. If the answer cannot be found in the provided context, "
        "clearly state that you do not have enough information.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )

    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=GROQ_MODEL_NAME,
            temperature=0.7,
            max_tokens=500,
        )
        answer = chat_completion.choices[0].message.content
        print("Answer generated successfully.")
        return answer
    except Exception as e:
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."