import os from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_core.documents import Document from ragatouille import RAGPretrainedModel from groq import Groq import json #from PyPDF2 import PdfReader # New import for raw PDF text extraction import re # --- Configuration (can be overridden by the calling app) --- CHUNK_SIZE = 1000 CHUNK_OVERLAP = 200 TOP_K_CHUNKS = 7 GROQ_MODEL_NAME = "llama3-8b-8192" # --- Helper Functions --- def extract_raw_text_from_pdf(pdf_path: str) -> str: """ Extracts raw text from a PDF file using PyPDF2. This is a simpler text extraction compared to LLMSherpa, suitable for manual sectioning. """ try: reader = PdfReader(pdf_path) full_text = "" for page in reader.pages: full_text += page.extract_text() + "\n" # Add newline between pages print(f"Extracted raw text from PDF: {len(full_text)} characters.") return full_text except Exception as e: print(f"Error extracting raw text from PDF: {e}") return "" def process_markdown_with_manual_sections( md_file_path: str, headings_json: dict, chunk_size: int, chunk_overlap: int ): """ Processes a markdown document from a file path by segmenting it based on provided section headings, and then recursively chunking each segment. Each chunk receives the corresponding section heading as metadata. Args: md_file_path (str): The path to the input markdown (.md) file. headings_json (dict): A JSON object with schema: {"headings": ["Your Heading 1", "Your Heading 2"]} This contains the major section headings to split by. chunk_size (int): The maximum size of each text chunk. chunk_overlap (int): The number of characters to overlap between consecutive chunks. Returns: tuple[list[Document], list[dict]]: A tuple containing: - list[Document]: A list of LangChain Document objects, each containing a text chunk and its associated metadata. - list[dict]: A list of dictionaries, each with {"section_heading", "section_text"} representing the segmented sections for evaluation. """ all_chunks_with_metadata = [] full_text = "" # Check if the file exists and read its content if not os.path.exists(md_file_path): print(f"Error: File not found at '{md_file_path}'") return [], [] if not os.path.isfile(md_file_path): print(f"Error: Path '{md_file_path}' is not a file.") return [], [] if not md_file_path.lower().endswith(".md"): print(f"Warning: File '{md_file_path}' does not have a .md extension.") try: with open(md_file_path, 'r', encoding='utf-8') as f: full_text = f.read() except Exception as e: print(f"Error reading file '{md_file_path}': {e}") return [], [] if not full_text: print("Input markdown file is empty.") return [], [] text_splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len, is_separator_regex=False, ) # Extract heading texts from the 'headings' key heading_texts = headings_json.get("headings", []) print(f"Identified headings for segmentation: {heading_texts}") # Find start indices of all headings in the full text using regex heading_positions = [] for heading in heading_texts: # Create a regex pattern to match the heading, ignoring extra whitespace and making it case-insensitive # re.escape() escapes special characters in the heading string # \s* matches zero or more whitespace characters # re.IGNORECASE makes the match case-insensitive pattern = re.compile(r'\s*'.join(re.escape(word) for word in heading.split())) match = pattern.search(full_text) if match: heading_positions.append({"heading_text": heading, "start_index": match.start()}) else: print(f"Warning: Heading '{heading}' not found in the markdown text using regex. This section might be missed.") # Sort heading positions by their start index heading_positions.sort(key=lambda x: x["start_index"]) # Segment the text based on heading positions segments_with_headings = [] # Handle preface (text before the first heading) if heading_positions and heading_positions[0]["start_index"] > 0: preface_text = full_text[:heading_positions[0]["start_index"]].strip() if preface_text: segments_with_headings.append({ "section_heading": "Document Start/Preface", "section_text": preface_text }) # Iterate through heading positions to define sections for i, current_heading_info in enumerate(heading_positions): start_index = current_heading_info["start_index"] heading_text = current_heading_info["heading_text"] # Determine the end index for the current section end_index = len(full_text) if i + 1 < len(heading_positions): end_index = heading_positions[i+1]["start_index"] # Extract section content (from current heading's start to next heading's start) # We include the heading text itself in the section_text section_content = full_text[start_index:end_index].strip() if section_content: segments_with_headings.append({ "section_heading": heading_text, "section_text": section_content }) print(f"Created {len(segments_with_headings)} segments based on provided headings.") # Chunk each segment and attach metadata for segment in segments_with_headings: section_heading = segment["section_heading"] section_text = segment["section_text"] if section_text: chunks = text_splitter.split_text(section_text) for chunk in chunks: metadata = { "document_part": "Section", # All these are now considered 'Section' "section_heading": section_heading, } all_chunks_with_metadata.append(Document(page_content=chunk, metadata=metadata)) print(f"Created {len(all_chunks_with_metadata)} chunks with metadata from segmented sections.") with open("output.json", 'w', encoding='utf-8') as f: json.dump(segments_with_headings, f, indent=4, ensure_ascii=False) return all_chunks_with_metadata def perform_vector_search(documents: list[Document], query: str, top_k: int, rag_model_instance=None) -> list[dict]: """ Performs vector search using Ragatouille's ColBERT implementation to retrieve the top k relevant chunks, preserving metadata. Args: documents (list[Document]): The list of LangChain Document objects to index and search. query (str): The search query. top_k (int): The number of top relevant chunks to retrieve. rag_model_instance: An optional pre-loaded Ragatouille model instance. If None, a new one will be loaded. Returns: list[dict]: A list of dictionaries, each containing 'content' and 'document_metadata' from the Ragatouille search results. """ if rag_model_instance is None: print("Initializing Ragatouille ColBERT model...") rag = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0") else: rag = rag_model_instance # Separate content and metadata for indexing collection_texts = [doc.page_content for doc in documents] collection_metadatas = [doc.metadata for doc in documents] index_name = "custom_chunks_index" # Changed index name print("Indexing chunks with Ragatouille (this may take a while for large datasets)...") rag.index( collection=collection_texts, document_metadatas=collection_metadatas, index_name=index_name, overwrite_index=True ) print("Indexing complete.") print(f"Performing vector search for query: '{query}' (top_k={top_k})...") results = rag.search(query=query, k=top_k) print(f"Retrieved {len(results)} top chunks.") return results def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str: """ Generates an answer using the Groq API based on the query and retrieved chunks' content. Includes metadata in the prompt for better context. Args: query (str): The original user query. retrieved_results (list[dict]): A list of dictionaries from Ragatouille search, each with 'content' and 'document_metadata'. groq_api_key (str): The Groq API key. Returns: str: The generated answer. """ if not groq_api_key: return "Error: Groq API key is not set. Cannot generate answer." print("Generating answer with Groq API...") client = Groq(api_key=groq_api_key) context_parts = [] for i, res in enumerate(retrieved_results): content = res.get("content", "") metadata = res.get("document_metadata", {}) section_heading = metadata.get("section_heading", "N/A") document_part = metadata.get("document_part", "N/A") # New metadata field context_parts.append( f"--- Context Chunk {i+1} ---\n" f"Document Part: {document_part}\n" f"Section Heading: {section_heading}\n" f"Content: {content}\n" f"-------------------------" ) context = "\n\n".join(context_parts) prompt = ( f"You are a specialized document analyzer assistant. Your task is to answer the user's question " f"solely based on the provided context. Pay close attention to the section heading and document part " f"for each context chunk. Ensure your answer incorporates all relevant details, including any legal nuances " f"and conditions found in the context, and is concise, limited to one or two sentences. " f"Do not explicitly mention the retrieved chunks. If the answer cannot be found in the provided context, " f"clearly state that you do not have enough information.\n\n" f"Context:\n{context}\n\n" f"Question: {query}\n\n" f"Answer:" ) try: chat_completion = client.chat.completions.create( messages=[ { "role": "user", "content": prompt, } ], model=GROQ_MODEL_NAME, temperature=0.7, max_tokens=500, ) answer = chat_completion.choices[0].message.content print("Answer generated successfully.") return answer except Exception as e: print(f"An error occurred during Groq API call: {e}") return "Could not generate an answer due to an API error."