import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from ragatouille import RAGPretrainedModel
from groq import Groq
import json
from PyPDF2 import PdfReader  # Raw PDF text extraction, used by extract_raw_text_from_pdf
import re

# --- Configuration (can be overridden by the calling app) ---
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TOP_K_CHUNKS = 7
GROQ_MODEL_NAME = "llama3-8b-8192"
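
# A sketch of overriding these defaults from a calling app (the module name
# "rag_backend" is hypothetical; use whatever name this file is imported as):
#
#   import rag_backend
#   rag_backend.CHUNK_SIZE = 1500
#   rag_backend.TOP_K_CHUNKS = 5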

# --- Helper Functions ---

def extract_raw_text_from_pdf(pdf_path: str) -> str:
    """
    Extracts raw text from a PDF file using PyPDF2.
    This is a simpler text extraction compared to LLMSherpa, suitable for manual sectioning.
    """
    try:
        reader = PdfReader(pdf_path)
        full_text = ""
        for page in reader.pages:
            # extract_text() can return None for pages without a text layer
            full_text += (page.extract_text() or "") + "\n"  # Add newline between pages
        print(f"Extracted raw text from PDF: {len(full_text)} characters.")
        return full_text
    except Exception as e:
        print(f"Error extracting raw text from PDF: {e}")
        return ""

def process_markdown_with_manual_sections(
    md_file_path: str,
    headings_json: dict,
    chunk_size: int,
    chunk_overlap: int
):
    """
    Processes a markdown document from a file path by segmenting it based on
    provided section headings, and then recursively chunking each segment.
    Each chunk receives the corresponding section heading as metadata.

    Args:
        md_file_path (str): The path to the input markdown (.md) file.
        headings_json (dict): A JSON object with schema: {"headings": ["Your Heading 1", "Your Heading 2"]}
                              This contains the major section headings to split by.
        chunk_size (int): The maximum size of each text chunk.
        chunk_overlap (int): The number of characters to overlap between consecutive chunks.

    Returns:
        tuple[list[Document], list[dict]]: A tuple containing:
            - list[Document]: A list of LangChain Document objects, each containing
                              a text chunk and its associated metadata.
            - list[dict]: A list of dictionaries, each with {"section_heading", "section_text"}
                          representing the segmented sections for evaluation.
    """
    all_chunks_with_metadata = []
    full_text = ""

    # Check if the file exists and read its content
    if not os.path.exists(md_file_path):
        print(f"Error: File not found at '{md_file_path}'")
        return [], []
    if not os.path.isfile(md_file_path):
        print(f"Error: Path '{md_file_path}' is not a file.")
        return [], []
    if not md_file_path.lower().endswith(".md"):
        print(f"Warning: File '{md_file_path}' does not have a .md extension.")

    try:
        with open(md_file_path, 'r', encoding='utf-8') as f:
            full_text = f.read()
    except Exception as e:
        print(f"Error reading file '{md_file_path}': {e}")
        return [], []

    if not full_text:
        print("Input markdown file is empty.")
        return [], []

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )

    # Extract heading texts from the 'headings' key
    heading_texts = headings_json.get("headings", [])
    print(f"Identified headings for segmentation: {heading_texts}")

    # Find start indices of all headings in the full text using regex
    heading_positions = []
    for heading in heading_texts:
        # Build a regex that matches the heading while tolerating variable
        # whitespace (including newlines) between its words; re.escape()
        # neutralizes any regex metacharacters in the heading, and
        # re.IGNORECASE makes the match case-insensitive.
        pattern = re.compile(r'\s+'.join(re.escape(word) for word in heading.split()), re.IGNORECASE)
        
        match = pattern.search(full_text)
        if match:
            heading_positions.append({"heading_text": heading, "start_index": match.start()})
        else:
            print(f"Warning: Heading '{heading}' not found in the markdown text using regex. This section might be missed.")

    # Sort heading positions by their start index
    heading_positions.sort(key=lambda x: x["start_index"])

    # Segment the text based on heading positions
    segments_with_headings = []
    
    # Handle preface (text before the first heading)
    if heading_positions and heading_positions[0]["start_index"] > 0:
        preface_text = full_text[:heading_positions[0]["start_index"]].strip()
        if preface_text:
            segments_with_headings.append({
                "section_heading": "Document Start/Preface",
                "section_text": preface_text
            })

    # Iterate through heading positions to define sections
    for i, current_heading_info in enumerate(heading_positions):
        start_index = current_heading_info["start_index"]
        heading_text = current_heading_info["heading_text"]
        
        # Determine the end index for the current section
        end_index = len(full_text)
        if i + 1 < len(heading_positions):
            end_index = heading_positions[i+1]["start_index"]

        # Extract section content (from current heading's start to next heading's start)
        # We include the heading text itself in the section_text
        section_content = full_text[start_index:end_index].strip()
        
        if section_content:
            segments_with_headings.append({
                "section_heading": heading_text,
                "section_text": section_content
            })

    print(f"Created {len(segments_with_headings)} segments based on provided headings.")

    # Chunk each segment and attach metadata
    for segment in segments_with_headings:
        section_heading = segment["section_heading"]
        section_text = segment["section_text"]

        if section_text:
            chunks = text_splitter.split_text(section_text)
            for chunk in chunks:
                metadata = {
                    "document_part": "Section", # All these are now considered 'Section'
                    "section_heading": section_heading,
                }
                all_chunks_with_metadata.append(Document(page_content=chunk, metadata=metadata))
    
    print(f"Created {len(all_chunks_with_metadata)} chunks with metadata from segmented sections.")
    
    with open("output.json", 'w', encoding='utf-8') as f:
            json.dump(segments_with_headings, f, indent=4, ensure_ascii=False)
    return all_chunks_with_metadata
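
# Usage sketch (file path and headings are hypothetical; the returned sections
# list mirrors what gets written to output.json):
#
#   headings = {"headings": ["1. Definitions", "2. Coverage", "3. Exclusions"]}
#   chunks, sections = process_markdown_with_manual_sections(
#       "policy.md", headings, CHUNK_SIZE, CHUNK_OVERLAP
#   )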

def perform_vector_search(documents: list[Document], query: str, top_k: int, rag_model_instance=None) -> list[dict]:
    """
    Performs vector search using Ragatouille's ColBERT implementation
    to retrieve the top k relevant chunks, preserving metadata.

    Args:
        documents (list[Document]): The list of LangChain Document objects to index and search.
        query (str): The search query.
        top_k (int): The number of top relevant chunks to retrieve.
        rag_model_instance: An optional pre-loaded Ragatouille model instance.
                            If None, a new one will be loaded.

    Returns:
        list[dict]: A list of dictionaries, each containing 'content' and 'document_metadata'
                    from the Ragatouille search results.
    """
    if rag_model_instance is None:
        print("Initializing Ragatouille ColBERT model...")
        rag = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
    else:
        rag = rag_model_instance

    # Separate content and metadata for indexing
    collection_texts = [doc.page_content for doc in documents]
    collection_metadatas = [doc.metadata for doc in documents]

    index_name = "custom_chunks_index" # Changed index name
    print("Indexing chunks with Ragatouille (this may take a while for large datasets)...")
    rag.index(
        collection=collection_texts,
        document_metadatas=collection_metadatas,
        index_name=index_name,
        overwrite_index=True
    )
    print("Indexing complete.")

    print(f"Performing vector search for query: '{query}' (top_k={top_k})...")
    results = rag.search(query=query, k=top_k)

    print(f"Retrieved {len(results)} top chunks.")
    return results
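
# Usage sketch (assumes `chunks` from process_markdown_with_manual_sections;
# passing a pre-loaded model avoids reloading ColBERT weights on every call):
#
#   model = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
#   results = perform_vector_search(chunks, "What is the grace period?",
#                                   TOP_K_CHUNKS, rag_model_instance=model)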

def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
    """
    Generates an answer using the Groq API based on the query and retrieved chunks' content.
    Includes metadata in the prompt for better context.

    Args:
        query (str): The original user query.
        retrieved_results (list[dict]): A list of dictionaries from Ragatouille search,
                                        each with 'content' and 'document_metadata'.
        groq_api_key (str): The Groq API key.

    Returns:
        str: The generated answer.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set. Cannot generate answer."

    print("Generating answer with Groq API...")
    client = Groq(api_key=groq_api_key)

    context_parts = []
    for i, res in enumerate(retrieved_results):
        content = res.get("content", "")
        metadata = res.get("document_metadata", {})
        section_heading = metadata.get("section_heading", "N/A")
        document_part = metadata.get("document_part", "N/A")

        context_parts.append(
            f"--- Context Chunk {i+1} ---\n"
            f"Document Part: {document_part}\n"
            f"Section Heading: {section_heading}\n"
            f"Content: {content}\n"
            f"-------------------------"
        )
    context = "\n\n".join(context_parts)

    prompt = (
        f"You are a specialized document analyzer assistant. Your task is to answer the user's question "
        f"solely based on the provided context. Pay close attention to the section heading and document part "
        f"for each context chunk. Ensure your answer incorporates all relevant details, including any legal nuances "
        f"and conditions found in the context, and is concise, limited to one or two sentences. "
        f"Do not explicitly mention the retrieved chunks. If the answer cannot be found in the provided context, "
        f"clearly state that you do not have enough information.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )

    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=GROQ_MODEL_NAME,
            temperature=0.7,
            max_tokens=500,
        )
        answer = chat_completion.choices[0].message.content
        print("Answer generated successfully.")
        return answer
    except Exception as e:
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."