import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from ragatouille import RAGPretrainedModel
from groq import Groq
import json
from PyPDF2 import PdfReader  # needed by extract_raw_text_from_pdf below
import re
# --- Configuration (can be overridden by the calling app) ---
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TOP_K_CHUNKS = 7
GROQ_MODEL_NAME = "llama3-8b-8192"
# --- Helper Functions ---
def extract_raw_text_from_pdf(pdf_path: str) -> str:
"""
Extracts raw text from a PDF file using PyPDF2.
This is a simpler text extraction compared to LLMSherpa, suitable for manual sectioning.
"""
try:
reader = PdfReader(pdf_path)
full_text = ""
for page in reader.pages:
full_text += (page.extract_text() or "") + "\n" # Guard against pages with no extractable text; add newline between pages
print(f"Extracted raw text from PDF: {len(full_text)} characters.")
return full_text
except Exception as e:
print(f"Error extracting raw text from PDF: {e}")
return ""
def process_markdown_with_manual_sections(
md_file_path: str,
headings_json: dict,
chunk_size: int,
chunk_overlap: int
):
"""
Processes a markdown document from a file path by segmenting it based on
provided section headings, and then recursively chunking each segment.
Each chunk receives the corresponding section heading as metadata.
Args:
md_file_path (str): The path to the input markdown (.md) file.
headings_json (dict): A JSON object with schema: {"headings": ["Your Heading 1", "Your Heading 2"]}
This contains the major section headings to split by.
chunk_size (int): The maximum size of each text chunk.
chunk_overlap (int): The number of characters to overlap between consecutive chunks.
Returns:
tuple[list[Document], list[dict]]: A tuple containing:
- list[Document]: A list of LangChain Document objects, each containing
a text chunk and its associated metadata.
- list[dict]: A list of dictionaries, each with {"section_heading", "section_text"}
representing the segmented sections for evaluation.
"""
all_chunks_with_metadata = []
full_text = ""
# Check if the file exists and read its content
if not os.path.exists(md_file_path):
print(f"Error: File not found at '{md_file_path}'")
return [], []
if not os.path.isfile(md_file_path):
print(f"Error: Path '{md_file_path}' is not a file.")
return [], []
if not md_file_path.lower().endswith(".md"):
print(f"Warning: File '{md_file_path}' does not have a .md extension.")
try:
with open(md_file_path, 'r', encoding='utf-8') as f:
full_text = f.read()
except Exception as e:
print(f"Error reading file '{md_file_path}': {e}")
return [], []
if not full_text:
print("Input markdown file is empty.")
return [], []
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
is_separator_regex=False,
)
# Extract heading texts from the 'headings' key
heading_texts = headings_json.get("headings", [])
print(f"Identified headings for segmentation: {heading_texts}")
# Find start indices of all headings in the full text using regex
heading_positions = []
for heading in heading_texts:
# Create a regex pattern to match the heading, ignoring extra whitespace and making it case-insensitive
# re.escape() escapes special characters in the heading string
# \s* matches zero or more whitespace characters
# re.IGNORECASE makes the match case-insensitive
pattern = re.compile(r'\s*'.join(re.escape(word) for word in heading.split()), re.IGNORECASE)
match = pattern.search(full_text)
if match:
heading_positions.append({"heading_text": heading, "start_index": match.start()})
else:
print(f"Warning: Heading '{heading}' not found in the markdown text using regex. This section might be missed.")
# Sort heading positions by their start index
heading_positions.sort(key=lambda x: x["start_index"])
# Segment the text based on heading positions
segments_with_headings = []
# Handle preface (text before the first heading)
if heading_positions and heading_positions[0]["start_index"] > 0:
preface_text = full_text[:heading_positions[0]["start_index"]].strip()
if preface_text:
segments_with_headings.append({
"section_heading": "Document Start/Preface",
"section_text": preface_text
})
# Iterate through heading positions to define sections
for i, current_heading_info in enumerate(heading_positions):
start_index = current_heading_info["start_index"]
heading_text = current_heading_info["heading_text"]
# Determine the end index for the current section
end_index = len(full_text)
if i + 1 < len(heading_positions):
end_index = heading_positions[i+1]["start_index"]
# Extract section content (from current heading's start to next heading's start)
# We include the heading text itself in the section_text
section_content = full_text[start_index:end_index].strip()
if section_content:
segments_with_headings.append({
"section_heading": heading_text,
"section_text": section_content
})
print(f"Created {len(segments_with_headings)} segments based on provided headings.")
# Chunk each segment and attach metadata
for segment in segments_with_headings:
section_heading = segment["section_heading"]
section_text = segment["section_text"]
if section_text:
chunks = text_splitter.split_text(section_text)
for chunk in chunks:
metadata = {
"document_part": "Section", # All these are now considered 'Section'
"section_heading": section_heading,
}
all_chunks_with_metadata.append(Document(page_content=chunk, metadata=metadata))
print(f"Created {len(all_chunks_with_metadata)} chunks with metadata from segmented sections.")
with open("output.json", 'w', encoding='utf-8') as f:
json.dump(segments_with_headings, f, indent=4, ensure_ascii=False)
return all_chunks_with_metadata, segments_with_headings
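# Usage sketch for the segmenter above. The markdown path and headings are
# illustrative assumptions; in practice the headings list is supplied by the
# calling app.
# headings_json = {"headings": ["1. Definitions", "2. Coverage", "3. Exclusions"]}
# chunks, sections = process_markdown_with_manual_sections(
#     "data/policy.md", headings_json, CHUNK_SIZE, CHUNK_OVERLAP
# )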
def perform_vector_search(documents: list[Document], query: str, top_k: int, rag_model_instance=None) -> list[dict]:
"""
Performs vector search using Ragatouille's ColBERT implementation
to retrieve the top k relevant chunks, preserving metadata.
Args:
documents (list[Document]): The list of LangChain Document objects to index and search.
query (str): The search query.
top_k (int): The number of top relevant chunks to retrieve.
rag_model_instance: An optional pre-loaded Ragatouille model instance.
If None, a new one will be loaded.
Returns:
list[dict]: A list of dictionaries, each containing 'content' and 'document_metadata'
from the Ragatouille search results.
"""
if rag_model_instance is None:
print("Initializing Ragatouille ColBERT model...")
rag = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
else:
rag = rag_model_instance
# Separate content and metadata for indexing
collection_texts = [doc.page_content for doc in documents]
collection_metadatas = [doc.metadata for doc in documents]
index_name = "custom_chunks_index" # Changed index name
print("Indexing chunks with Ragatouille (this may take a while for large datasets)...")
rag.index(
collection=collection_texts,
document_metadatas=collection_metadatas,
index_name=index_name,
overwrite_index=True
)
print("Indexing complete.")
print(f"Performing vector search for query: '{query}' (top_k={top_k})...")
results = rag.search(query=query, k=top_k)
print(f"Retrieved {len(results)} top chunks.")
return results
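# Note: loading colbert-ir/colbertv2.0 and rebuilding the index are the slow
# steps in this function. Callers issuing several queries over the same
# documents may want to load the model once and pass it in (sketch; variable
# names are assumptions):
# rag_model = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
# results = perform_vector_search(chunks, "What is the grace period?", TOP_K_CHUNKS, rag_model)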
def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
"""
Generates an answer using the Groq API based on the query and retrieved chunks' content.
Includes metadata in the prompt for better context.
Args:
query (str): The original user query.
retrieved_results (list[dict]): A list of dictionaries from Ragatouille search,
each with 'content' and 'document_metadata'.
groq_api_key (str): The Groq API key.
Returns:
str: The generated answer.
"""
if not groq_api_key:
return "Error: Groq API key is not set. Cannot generate answer."
print("Generating answer with Groq API...")
client = Groq(api_key=groq_api_key)
context_parts = []
for i, res in enumerate(retrieved_results):
content = res.get("content", "")
metadata = res.get("document_metadata", {})
section_heading = metadata.get("section_heading", "N/A")
document_part = metadata.get("document_part", "N/A") # New metadata field
context_parts.append(
f"--- Context Chunk {i+1} ---\n"
f"Document Part: {document_part}\n"
f"Section Heading: {section_heading}\n"
f"Content: {content}\n"
f"-------------------------"
)
context = "\n\n".join(context_parts)
prompt = (
f"You are a specialized document analyzer assistant. Your task is to answer the user's question "
f"solely based on the provided context. Pay close attention to the section heading and document part "
f"for each context chunk. Ensure your answer incorporates all relevant details, including any legal nuances "
f"and conditions found in the context, and is concise, limited to one or two sentences. "
f"Do not explicitly mention the retrieved chunks. If the answer cannot be found in the provided context, "
f"clearly state that you do not have enough information.\n\n"
f"Context:\n{context}\n\n"
f"Question: {query}\n\n"
f"Answer:"
)
try:
chat_completion = client.chat.completions.create(
messages=[
{
"role": "user",
"content": prompt,
}
],
model=GROQ_MODEL_NAME,
temperature=0.7,
max_tokens=500,
)
answer = chat_completion.choices[0].message.content
print("Answer generated successfully.")
return answer
except Exception as e:
print(f"An error occurred during Groq API call: {e}")
return "Could not generate an answer due to an API error." |