# rag_system.py
import os
import re
import glob
import time
import argparse
from collections import defaultdict
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# PyMuPDF library
try:
import fitz # PyMuPDF
PYMUPDF_AVAILABLE = True
print("PyMuPDF library available")
except ImportError:
PYMUPDF_AVAILABLE = False
print("PyMuPDF library is not installed. Install with: pip install PyMuPDF")
# PDF processing utilities
import pytesseract
from PIL import Image
from pdf2image import convert_from_path
import pdfplumber
from pymupdf4llm import LlamaMarkdownReader
# --------------------------------
# Log Output
# --------------------------------
def log(msg):
print(f"[{time.strftime('%H:%M:%S')}] {msg}")
# --------------------------------
# Text Cleaning Function
# --------------------------------
def clean_text(text):
return re.sub(r"[^\uAC00-\uD7A3\u1100-\u11FF\u3130-\u318F\w\s.,!?\"'()$:\-]", "", text)
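# apply_corrections() maps frequently seen mojibake sequences (bytes decoded with the
# wrong codec) to plain replacements before the text is cleaned and indexed.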
def apply_corrections(text):
corrections = {
'º©': 'info', 'Ì': 'of', '½': 'operation', 'Ã': '', '©': '',
'’': "'", '“': '"', 'â€': '"'
}
for k, v in corrections.items():
text = text.replace(k, v)
return text
# --------------------------------
# HWPX Processing (Section-wise Processing Only)
# --------------------------------
def load_hwpx(file_path):
"""Loading HWPX file (using XML parsing method only)"""
import zipfile
import xml.etree.ElementTree as ET
import chardet
log(f"Starting HWPX section-wise processing: {file_path}")
start = time.time()
documents = []
try:
with zipfile.ZipFile(file_path, 'r') as zip_ref:
file_list = zip_ref.namelist()
section_files = [f for f in file_list
if f.startswith('Contents/section') and f.endswith('.xml')]
section_files.sort() # Sort by section0.xml, section1.xml order
log(f"Found section files: {len(section_files)} files")
for section_idx, section_file in enumerate(section_files):
with zip_ref.open(section_file) as xml_file:
raw = xml_file.read()
encoding = chardet.detect(raw)['encoding'] or 'utf-8'
try:
text = raw.decode(encoding)
except UnicodeDecodeError:
text = raw.decode("cp949", errors="replace")
tree = ET.ElementTree(ET.fromstring(text))
root = tree.getroot()
# Find text without namespace
t_elements = [elem for elem in root.iter() if elem.tag.endswith('}t') or elem.tag == 't']
body_text = ""
for elem in t_elements:
if elem.text:
body_text += clean_text(elem.text) + " "
# Set page metadata to empty
page_value = ""
if body_text.strip():
documents.append(Document(
page_content=apply_corrections(body_text),
metadata={
"source": file_path,
"filename": os.path.basename(file_path),
"type": "hwpx_body",
"page": page_value,
"total_sections": len(section_files)
}
))
log(f"Section text extraction complete (chars: {len(body_text)})")
# Find tables
table_elements = [elem for elem in root.iter() if elem.tag.endswith('}table') or elem.tag == 'table']
if table_elements:
table_text = ""
for table_idx, table in enumerate(table_elements):
table_text += f"[Table {table_idx + 1}]\n"
rows = [elem for elem in table.iter() if elem.tag.endswith('}tr') or elem.tag == 'tr']
for row in rows:
row_text = []
cells = [elem for elem in row.iter() if elem.tag.endswith('}tc') or elem.tag == 'tc']
for cell in cells:
cell_texts = []
for t_elem in cell.iter():
if (t_elem.tag.endswith('}t') or t_elem.tag == 't') and t_elem.text:
cell_texts.append(clean_text(t_elem.text))
row_text.append(" ".join(cell_texts))
if row_text:
table_text += "\t".join(row_text) + "\n"
if table_text.strip():
documents.append(Document(
page_content=apply_corrections(table_text),
metadata={
"source": file_path,
"filename": os.path.basename(file_path),
"type": "hwpx_table",
"page": page_value,
"total_sections": len(section_files)
}
))
log(f"Table extraction complete")
# Find images
if [elem for elem in root.iter() if elem.tag.endswith('}picture') or elem.tag == 'picture']:
documents.append(Document(
page_content="[Image included]",
metadata={
"source": file_path,
"filename": os.path.basename(file_path),
"type": "hwpx_image",
"page": page_value,
"total_sections": len(section_files)
}
))
log(f"Image found")
except Exception as e:
log(f"HWPX processing error: {e}")
duration = time.time() - start
# Print summary of document information
if documents:
log(f"Number of extracted documents: {len(documents)}")
log(f"HWPX processing complete: {file_path} ⏱️ {duration:.2f}s, total {len(documents)} documents")
return documents
# --------------------------------
# PDF Processing Functions (same as before)
# --------------------------------
def run_ocr_on_image(image: Image.Image, lang='kor+eng'):
return pytesseract.image_to_string(image, lang=lang)
def extract_images_with_ocr(pdf_path, lang='kor+eng'):
try:
images = convert_from_path(pdf_path)
page_ocr_data = {}
for idx, img in enumerate(images):
page_num = idx + 1
text = run_ocr_on_image(img, lang=lang)
if text.strip():
page_ocr_data[page_num] = text.strip()
return page_ocr_data
except Exception as e:
print(f"Image OCR failed: {e}")
return {}
def extract_tables_with_pdfplumber(pdf_path):
page_table_data = {}
try:
with pdfplumber.open(pdf_path) as pdf:
for i, page in enumerate(pdf.pages):
page_num = i + 1
tables = page.extract_tables()
table_text = ""
for t_index, table in enumerate(tables):
if table:
table_text += f"[Table {t_index+1}]\n"
for row in table:
row_text = "\t".join(cell if cell else "" for cell in row)
table_text += row_text + "\n"
if table_text.strip():
page_table_data[page_num] = table_text.strip()
return page_table_data
except Exception as e:
print(f"Table extraction failed: {e}")
return {}
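# LlamaMarkdownReader returns the whole document as markdown without reliable page
# boundaries, so extract_body_text_with_pages() re-slices the combined text into
# ~2000-character pseudo-pages with a 100-character overlap between slices.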
def extract_body_text_with_pages(pdf_path):
page_body_data = {}
try:
pdf_processor = LlamaMarkdownReader()
docs = pdf_processor.load_data(file_path=pdf_path)
combined_text = ""
for d in docs:
if isinstance(d, dict) and "text" in d:
combined_text += d["text"]
elif hasattr(d, "text"):
combined_text += d.text
if combined_text.strip():
chars_per_page = 2000
start = 0
page_num = 1
while start < len(combined_text):
end = start + chars_per_page
if end > len(combined_text):
end = len(combined_text)
page_text = combined_text[start:end]
if page_text.strip():
page_body_data[page_num] = page_text.strip()
page_num += 1
if end == len(combined_text):
break
start = end - 100
except Exception as e:
print(f"Body extraction failed: {e}")
return page_body_data
def load_pdf_with_metadata(pdf_path):
"""Extracts page-specific information from a PDF file"""
log(f"Starting PDF page-wise processing: {pdf_path}")
start = time.time()
# First, check the actual number of pages using PyPDFLoader
try:
from langchain_community.document_loaders import PyPDFLoader
loader = PyPDFLoader(pdf_path)
pdf_pages = loader.load()
actual_total_pages = len(pdf_pages)
log(f"Actual page count as verified by PyPDFLoader: {actual_total_pages}")
except Exception as e:
log(f"PyPDFLoader page count verification failed: {e}")
actual_total_pages = 1
try:
page_tables = extract_tables_with_pdfplumber(pdf_path)
except Exception as e:
page_tables = {}
print(f"Table extraction failed: {e}")
try:
page_ocr = extract_images_with_ocr(pdf_path)
except Exception as e:
page_ocr = {}
print(f"Image OCR failed: {e}")
try:
page_body = extract_body_text_with_pages(pdf_path)
except Exception as e:
page_body = {}
print(f"Body extraction failed: {e}")
duration = time.time() - start
log(f"PDF page-wise processing complete: {pdf_path} ⏱️ {duration:.2f}s")
# Set the total number of pages based on the actual number of pages
all_pages = set(page_tables.keys()) | set(page_ocr.keys()) | set(page_body.keys())
if all_pages:
max_extracted_page = max(all_pages)
# Use the greater of the actual and extracted page numbers
total_pages = max(actual_total_pages, max_extracted_page)
else:
total_pages = actual_total_pages
log(f"Final total page count set to: {total_pages}")
docs = []
for page_num in sorted(all_pages):
if page_num in page_tables and page_tables[page_num].strip():
docs.append(Document(
page_content=clean_text(apply_corrections(page_tables[page_num])),
metadata={
"source": pdf_path,
"filename": os.path.basename(pdf_path),
"type": "table",
"page": page_num,
"total_pages": total_pages
}
))
log(f"Page {page_num}: Table extraction complete")
if page_num in page_body and page_body[page_num].strip():
docs.append(Document(
page_content=clean_text(apply_corrections(page_body[page_num])),
metadata={
"source": pdf_path,
"filename": os.path.basename(pdf_path),
"type": "body",
"page": page_num,
"total_pages": total_pages
}
))
log(f"Page {page_num}: Body extraction complete")
if page_num in page_ocr and page_ocr[page_num].strip():
docs.append(Document(
page_content=clean_text(apply_corrections(page_ocr[page_num])),
metadata={
"source": pdf_path,
"filename": os.path.basename(pdf_path),
"type": "ocr",
"page": page_num,
"total_pages": total_pages
}
))
log(f"Page {page_num}: OCR extraction complete")
if not docs:
docs.append(Document(
page_content="[Content extraction failed]",
metadata={
"source": pdf_path,
"filename": os.path.basename(pdf_path),
"type": "error",
"page": 1,
"total_pages": total_pages
}
))
# Print summary of page information
if docs:
page_numbers = [doc.metadata.get('page', 0) for doc in docs if doc.metadata.get('page')]
if page_numbers:
log(f"Extracted page range: {min(page_numbers)} ~ {max(page_numbers)}")
log(f"PDF documents with extracted pages: {len(docs)} documents (total {total_pages} pages)")
return docs
# --------------------------------
# Document Loading and Splitting
# --------------------------------
def load_documents(folder_path):
documents = []
for file in glob.glob(os.path.join(folder_path, "*.hwpx")):
log(f"HWPX file found: {file}")
docs = load_hwpx(file)
documents.extend(docs)
for file in glob.glob(os.path.join(folder_path, "*.pdf")):
log(f"PDF file found: {file}")
documents.extend(load_pdf_with_metadata(file))
log(f"Document loading complete! Total documents: {len(documents)}")
return documents
def split_documents(documents, chunk_size=800, chunk_overlap=100):
log("Starting chunk splitting")
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len
)
chunks = []
for doc in documents:
split = splitter.split_text(doc.page_content)
for i, chunk in enumerate(split):
enriched_chunk = f"passage: {chunk}"
chunks.append(Document(
page_content=enriched_chunk,
metadata={**doc.metadata, "chunk_index": i}
))
log(f"Chunk splitting complete: Created {len(chunks)} chunks")
return chunks
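# --------------------------------
# Embeddings and Vector Store
# --------------------------------
# The __main__ block below calls get_embeddings() and load_vector_store(), which are not
# defined in this file. The functions here are minimal sketches of what they could look
# like; the embedding model name is an assumption (the "passage: " prefix added in
# split_documents follows the E5 convention) and is not confirmed by this file.
def get_embeddings(device="cuda"):
    """Sketch: create a HuggingFace embedding model on the requested device."""
    return HuggingFaceEmbeddings(
        model_name="intfloat/multilingual-e5-large",  # assumed model
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": True}
    )

def build_vector_store(chunks, embeddings, save_path="vector_db"):
    """Sketch: build a FAISS index from the chunks and persist it to disk."""
    vectorstore = FAISS.from_documents(chunks, embeddings)
    vectorstore.save_local(save_path)
    return vectorstore

def load_vector_store(embeddings, load_path="vector_db"):
    """Sketch: load a FAISS index previously saved with save_local()."""
    # allow_dangerous_deserialization is required by recent LangChain releases
    # because the index metadata is stored with pickle.
    return FAISS.load_local(load_path, embeddings, allow_dangerous_deserialization=True)

# Example (assumed workflow): build the index once, then query it via __main__.
#   docs = load_documents("./documents")
#   chunks = split_documents(docs)
#   build_vector_store(chunks, get_embeddings(device="cuda"), save_path="vector_db")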
# --------------------------------
# Main Execution
# --------------------------------
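# build_rag_chain() below expects create_refine_prompts_with_pages(), which is not defined
# in this file. The sketch here assumes the standard refine-chain variables (context_str,
# question, existing_answer) and ends each template with "A:" to match the answer parsing
# in ask_question_with_pages(); the exact wording is an assumption, and Korean ("ko")
# templates are omitted.
def create_refine_prompts_with_pages(language="en"):
    """Sketch: return (question_prompt, refine_prompt) for a refine-type RetrievalQA chain."""
    from langchain.prompts import PromptTemplate
    question_prompt = PromptTemplate(
        input_variables=["context_str", "question"],
        template=(
            "Answer the question using only the context below. "
            "Mention the filename and page of the passages you use.\n\n"
            "{context_str}\n\nQ: {question}\nA:"
        )
    )
    refine_prompt = PromptTemplate(
        input_variables=["existing_answer", "context_str", "question"],
        template=(
            "The existing answer is:\n{existing_answer}\n\n"
            "Refine it with the additional context below if it is relevant; "
            "otherwise keep the existing answer.\n\n"
            "{context_str}\n\nQ: {question}\nA:"
        )
    )
    return question_prompt, refine_prompt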
def build_rag_chain(llm, vectorstore, language="en", k=7):
"""Build RAG Chain"""
question_prompt, refine_prompt = create_refine_prompts_with_pages(language)
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="refine",
retriever=vectorstore.as_retriever(search_kwargs={"k": k}),
chain_type_kwargs={
"question_prompt": question_prompt,
"refine_prompt": refine_prompt
},
return_source_documents=True
)
return qa_chain
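# The __main__ block also calls load_llama_model(), which is not defined in this file.
# This is a minimal sketch that wraps the model named in the --model default in a
# transformers pipeline; the generation settings are assumptions, not taken from this file.
def load_llama_model(model_id="LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct"):
    """Sketch: load a causal LM and expose it to LangChain as an LLM."""
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
    from langchain_community.llms import HuggingFacePipeline
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype="auto",
        trust_remote_code=True
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,      # assumed generation budget
        return_full_text=False    # return only the newly generated answer
    )
    return HuggingFacePipeline(pipeline=pipe)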
def ask_question_with_pages(qa_chain, question):
"""Process questions"""
result = qa_chain({"query": question})
# Extract only the text after A: from the result
answer = result['result']
final_answer = answer.split("A:")[-1].strip() if "A:" in answer else answer.strip()
print(f"\nQuestion: {question}")
print(f"\nFinal Answer: {final_answer}")
# Metadata debugging info (disabled)
# debug_metadata_info(result["source_documents"])
# Organize reference documents by page
print("\nReference Document Summary:")
source_info = {}
for doc in result["source_documents"]:
source = doc.metadata.get('source', 'unknown')
page = doc.metadata.get('page', 'unknown')
doc_type = doc.metadata.get('type', 'unknown')
section = doc.metadata.get('section', None)
total_pages = doc.metadata.get('total_pages', None)
filename = doc.metadata.get('filename', 'unknown')
if filename == 'unknown':
filename = os.path.basename(source) if source != 'unknown' else 'unknown'
if filename not in source_info:
source_info[filename] = {
'pages': set(),
'sections': set(),
'types': set(),
'total_pages': total_pages
}
if page != 'unknown':
if isinstance(page, str) and page.startswith('section'):
source_info[filename]['sections'].add(page)
else:
source_info[filename]['pages'].add(page)
if section is not None:
source_info[filename]['sections'].add(f"section {section}")
source_info[filename]['types'].add(doc_type)
# Result output
total_chunks = len(result["source_documents"])
print(f"Total chunks used: {total_chunks}")
for filename, info in source_info.items():
print(f"\n- {filename}")
# Total page count information
if info['total_pages']:
print(f" Total page count: {info['total_pages']}")
# Page information output
if info['pages']:
pages_list = list(info['pages'])
print(f" Pages: {', '.join(map(str, pages_list))}")
# Section information output
if info['sections']:
sections_list = sorted(list(info['sections']))
print(f" Sections: {', '.join(sections_list)}")
# If no pages or sections are present
if not info['pages'] and not info['sections']:
print(f" Pages: No information")
# Output document type
types_str = ', '.join(sorted(info['types']))
print(f" Type: {types_str}")
return result
# Existing ask_question function is replaced with ask_question_with_pages
def ask_question(qa_chain, question):
"""Wrapper function for compatibility"""
return ask_question_with_pages(qa_chain, question)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="RAG refine system (supports page numbers)")
parser.add_argument("--vector_store", type=str, default="vector_db", help="Vector store path")
parser.add_argument("--model", type=str, default="LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct", help="LLM model ID")
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to use")
parser.add_argument("--k", type=int, default=7, help="Number of documents to retrieve")
parser.add_argument("--language", type=str, default="en", choices=["ko", "en"], help="Language to use")
parser.add_argument("--query", type=str, help="Question (runs interactive mode if not provided)")
args = parser.parse_args()
embeddings = get_embeddings(device=args.device)
vectorstore = load_vector_store(embeddings, load_path=args.vector_store)
llm = load_llama_model()
from rag_system import build_rag_chain, ask_question_with_pages  # Added so the new ask_question_with_pages code can be used from the console.
qa_chain = build_rag_chain(llm, vectorstore, language=args.language, k=args.k)
print("RAG system with page number support ready!")
if args.query:
ask_question_with_pages(qa_chain, args.query)
else:
print("Starting interactive mode (enter 'exit', 'quit' to finish)")
while True:
try:
query = input("Question: ").strip()
if query.lower() in ["exit", "quit"]:
break
if query: # Prevent empty input
ask_question_with_pages(qa_chain, query)
except KeyboardInterrupt:
print("\n\nExiting program.")
break
except Exception as e:
print(f"Error occurred: {e}\nPlease try again.")