"""Minimal RAG demo: embed local documents and retrieve the chunks most
relevant to a user query, served through a Gradio interface."""

import glob
import os
import pickle
from pathlib import Path

import gradio as gr
import numpy as np
import spaces
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer

# Runtime configuration, overridable through environment variables.
model_name = os.environ.get("MODEL", "Snowflake/snowflake-arctic-embed-m")
chunk_size = int(os.environ.get("CHUNK_SIZE", 1000))
default_k = int(os.environ.get("DEFAULT_K", 5))

model = SentenceTransformer(model_name)

# Maps source filename -> {"chunks": list[str], "embeddings": np.ndarray}
docs = {}


def extract_text_from_pdf(reader: PdfReader) -> str:
    """Extract text from PDF pages

    Parameters
    ----------
    reader : PdfReader
        PDF reader

    Returns
    -------
    str
        Raw text
    """
    content = [page.extract_text().strip() for page in reader.pages]
    return "\n\n".join(content).strip()


def convert(filename: str) -> str:
    """Convert file content to raw text

    Parameters
    ----------
    filename : str
        The filename or path

    Returns
    -------
    str
        The raw text

    Raises
    ------
    ValueError
        If the file type is not supported.
    """
    plain_text_filetypes = [
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    ]
    # Already a plain text file that wouldn't benefit from pandoc so return the content
    if any(filename.endswith(ft) for ft in plain_text_filetypes):
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()

    if filename.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))

    # Bug fix: the original f-string had no placeholder, so the offending
    # filename was never reported.
    raise ValueError(f"Unsupported file type: {filename}")


def generate_chunks(text: str, max_length: int) -> list[str]:
    """Generate chunks from a file's raw text.

    Chunks are calculated based on the `max_length` parameter and the
    split character (.)

    Parameters
    ----------
    text : str
        The raw text
    max_length : int
        Maximum number of characters a chunk can have. Note that chunks
        may not have this exact length, as the sentence-level splitting
        is also involved in the process

    Returns
    -------
    list[str]
        A list of chunks/nodes
    """
    chunks: list[str] = []
    chunk = ""
    for segment in text.split("."):
        # Bug fix: re-append the "." that split() removed, otherwise
        # adjacent sentences are fused together with no separator and
        # the chunks lose their punctuation.
        if len(chunk) < max_length:
            chunk += segment + "."
        else:
            chunks.append(chunk)
            chunk = segment + "."
    if chunk:
        chunks.append(chunk)
    return chunks


@spaces.GPU
def predict(query: str, k: int = 5) -> str:
    """Find k most relevant chunks based on the given query

    Parameters
    ----------
    query : str
        The input query
    k : int, optional
        Number of relevant chunks to return, by default 5

    Returns
    -------
    str
        The k chunks concatenated together as a single string.

    Example
    -------
    If k=2, the returned string might look like:
    "CONTEXT:\n\nchunk-1\n\nchunk-2"
    """
    # Embed the query
    query_embedding = model.encode(query, prompt_name="query")

    # Collect (chunk, similarity) pairs across all documents
    all_chunks = []
    for doc in docs.values():
        # Cosine similarity between the query and each chunk embedding.
        # Bug fix: the original used np.linalg.norm over the whole
        # embedding matrix (Frobenius norm), which scales each document
        # differently and corrupts the cross-document ranking; the norm
        # must be taken per row (axis=1).
        similarities = np.dot(doc["embeddings"], query_embedding) / (
            np.linalg.norm(doc["embeddings"], axis=1)
            * np.linalg.norm(query_embedding)
        )
        all_chunks.extend(zip(doc["chunks"], similarities))

    # Most similar chunks first
    all_chunks.sort(key=lambda x: x[1], reverse=True)

    return "CONTEXT:\n\n" + "\n\n".join(chunk for chunk, _ in all_chunks[:k])


def init():
    """Init function

    It will load or calculate the embeddings
    """
    global docs  # pylint: disable=W0603
    embeddings_file = Path("embeddings.pickle")
    if embeddings_file.exists():
        # NOTE(security): pickle.load executes arbitrary code; only safe
        # because the file is produced locally by this same script.
        with open(embeddings_file, "rb") as embeddings_pickle:
            docs = pickle.load(embeddings_pickle)
    else:
        for filename in glob.glob("sources/*"):
            converted_doc = convert(filename)
            chunks = generate_chunks(converted_doc, chunk_size)
            embeddings = model.encode(chunks)
            docs[filename] = {
                "chunks": chunks,
                "embeddings": embeddings,
            }
        # Cache the embeddings so subsequent startups skip re-encoding
        with open(embeddings_file, "wb") as pickle_file:
            pickle.dump(docs, pickle_file)


init()

gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Query asked about the documents"),
        gr.Number(label="Number of relevant sources returned (k)", value=default_k),
    ],
    outputs=[gr.Text(label="Relevant chunks")],
    title="ContextQA tool - El Salvador",
    description="Forked and customized RAG tool working with law documents from El Salvador",
).launch()