import glob
import os
import pickle
import re
from pathlib import Path

import gradio as gr
import spaces
import numpy as np
from pypdf import PdfReader
from transformers import AutoModel

chunk_size = int(os.environ.get("CHUNK_SIZE", 1000))
default_k = int(os.environ.get("DEFAULT_K", 5))

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-es", trust_remote_code=True)

docs = {}

def extract_text_from_pdf(reader: PdfReader) -> str:
    """Extract text from PDF pages

    Parameters
    ----------
    reader : PdfReader
        PDF reader

    Returns
    -------
    str
        Raw text
    """
    content = [page.extract_text().strip() for page in reader.pages]
    return "\n\n".join(content).strip()
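
# Illustrative usage (the filename is hypothetical):
#   reader = PdfReader("sources/example.pdf")
#   raw_text = extract_text_from_pdf(reader)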

def convert(filename: str) -> str:
    """Convert file content to raw text

    Parameters
    ----------
    filename : str
        The filename or path

    Returns
    -------
    str
        The raw text

    Raises
    ------
    ValueError
        If the file type is not supported.
    """
    plain_text_filetypes = [
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    ]
    # Already a plain text file that wouldn't benefit from pandoc, so return the content
    if any(filename.endswith(ft) for ft in plain_text_filetypes):
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()

    if filename.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))

    raise ValueError(f"Unsupported file type: {filename}")
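
# Illustrative dispatch (filenames are hypothetical):
#   convert("sources/ley.txt")   # read directly as plain text
#   convert("sources/ley.pdf")   # routed through PdfReader
#   convert("sources/ley.docx")  # raises ValueError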

def generate_chunks(text: str, max_length: int) -> list[str]:
    """Generate chunks from a file's raw text. Chunks are built by
    splitting the text on the sentence delimiter (.) and accumulating
    segments until `max_length` is reached.

    Parameters
    ----------
    text : str
        The raw text
    max_length : int
        Maximum number of characters a chunk can have. Note that chunks
        may not have this exact length, as splitting only happens at
        sentence boundaries.

    Returns
    -------
    list[str]
        A list of chunks/nodes
    """
    segments = text.split(".")
    chunks = []
    chunk = ""
    for current_segment in segments:
        # Normalize whitespace in the current segment
        current_segment = re.sub(r"\s+", " ", current_segment).strip()
        if not current_segment:
            continue
        if len(chunk) < max_length:
            # Keep accumulating sentences until the chunk is long enough;
            # avoid a leading ". " on the first sentence of a chunk
            chunk = f"{chunk}. {current_segment}" if chunk else current_segment
        else:
            chunks.append(chunk)
            chunk = current_segment
    if chunk:
        chunks.append(chunk)
    return chunks
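
# Illustrative behaviour of the chunker:
#   generate_chunks("Uno. Dos. Tres.", max_length=8)
#   -> ["Uno. Dos", "Tres"]
# A chunk keeps absorbing sentences until it reaches max_length, so a
# chunk may overshoot the limit by up to one sentence.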

def predict(query: str, k: int = 5) -> str:
    """Find the k most relevant chunks for the given query

    Parameters
    ----------
    query : str
        The input query
    k : int, optional
        Number of relevant chunks to return, by default 5

    Returns
    -------
    str
        The k chunks concatenated together as a single string.

    Example
    -------
    If k=2, the returned string might look like:
    "CONTEXT:\n\nfile-1: chunk-1\n\nfile-2: chunk-2"
    """
    # Embed the query
    query_embedding = model.encode(query)
    # Gradio's Number component delivers a float, but slicing needs an int
    k = int(k)
    # Collect (filename, chunk, similarity) triples across all documents
    all_chunks = []
    for filename, doc in docs.items():
        # Cosine similarity between the query and each chunk embedding;
        # chunk norms must be taken per row (axis=1), not over the whole
        # embeddings matrix
        similarities = np.dot(doc["embeddings"], query_embedding) / (
            np.linalg.norm(doc["embeddings"], axis=1) * np.linalg.norm(query_embedding)
        )
        all_chunks.extend([(filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities)])
    # Sort all chunks by similarity, most similar first
    all_chunks.sort(key=lambda x: x[2], reverse=True)
    return "CONTEXT:\n\n" + "\n\n".join(f"{filename}: {chunk}" for filename, chunk, _ in all_chunks[:k])

def init():
    """Init function

    It will load or calculate the embeddings
    """
    global docs  # pylint: disable=W0603
    embeddings_file = Path("embeddings-es.pickle")
    # Reuse cached embeddings if a previous run already computed them
    if embeddings_file.exists():
        with open(embeddings_file, "rb") as embeddings_pickle:
            docs = pickle.load(embeddings_pickle)
    else:
        for filename in glob.glob("sources/*"):
            converted_doc = convert(filename)
            chunks = generate_chunks(converted_doc, chunk_size)
            embeddings = model.encode(chunks)
            # get the base filename and slugify it
            docs[filename.rsplit("/", 1)[-1].lower().replace(" ", "-")] = {
                "chunks": chunks,
                "embeddings": embeddings,
            }
        with open(embeddings_file, "wb") as pickle_file:
            pickle.dump(docs, pickle_file)
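
# Resulting structure (illustrative; key and shape are assumptions, the
# 768-dim size matches the jina-embeddings-v2 base models):
#   docs = {
#       "codigo-penal.pdf": {
#           "chunks": ["...", ...],                      # list[str]
#           "embeddings": <np.ndarray, shape (n_chunks, 768)>,
#       },
#   }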

init()

gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Query asked about the documents"),
        gr.Number(label="Number of relevant sources returned (k)", value=default_k),
    ],
    outputs=[gr.Text(label="Relevant chunks")],
    title="ContextQA tool - El Salvador",
    description="Forked and customized RAG tool working with law documents from El Salvador",
).launch()