import gradio as gr
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from langchain.schema import Document
from typing import List, Dict, Any, Optional, Tuple
import logging
import re
import base64
import mimetypes
from datasets import Dataset
from huggingface_hub import HfApi, get_token
import huggingface_hub
import os
from mistralai import Mistral
import gradio_client.utils as client_utils

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Patch Gradio's get_type function to handle boolean schemas ---
def patched_get_type(schema: Any) -> str:
    """Patched version of get_type to handle boolean schemas."""
    if isinstance(schema, bool):
        return "bool"
    if "const" in schema:
        return f"Literal[{repr(schema['const'])}]"
    if "enum" in schema:
        return f"Literal[{', '.join(repr(v) for v in schema['enum'])}]"
    if "type" not in schema:
        return "Any"
    type_ = schema["type"]
    if isinstance(type_, list):
        return f"Union[{', '.join(t for t in type_ if t != 'null')}]"
    if type_ == "array":
        items = schema.get("items", {})
        return f"List[{patched_json_schema_to_python_type(items, schema.get('$defs'))}]"
    if type_ == "object":
        return "Dict[str, Any]"
    if type_ == "null":
        return "None"
    if type_ == "integer":
        return "int"
    if type_ == "number":
        return "float"
    if type_ == "boolean":
        return "bool"
    if type_ == "string":
        return "str"
    return type_

def patched_json_schema_to_python_type(schema: Any, defs: Optional[Dict[str, Any]] = None) -> str:
    """Patched version of json_schema_to_python_type to use patched_get_type."""
    if isinstance(schema, bool):
        # Boolean schemas (true/false) are valid JSON Schema; without this guard,
        # the `in` checks below would raise TypeError on a bare bool.
        return patched_get_type(schema)
    defs = defs or {}
    if not schema:
        return "Any"
    if "$ref" in schema:
        ref = schema["$ref"].split("/")[-1]
        return patched_json_schema_to_python_type(defs.get(ref, {}), defs)
    if "anyOf" in schema:
        types = [
            patched_json_schema_to_python_type(s, defs) for s in schema["anyOf"]
        ]
        return f"Union[{', '.join(t for t in types if t != 'None')}]"
    if "type" in schema and schema["type"] == "array":
        items = schema.get("items", {})
        elements = patched_json_schema_to_python_type(items, defs)
        return f"List[{elements}]"
    if "type" in schema and schema["type"] == "object":
        if "properties" in schema:
            des = [
                f"{n}: {patched_json_schema_to_python_type(v, defs)}{client_utils.get_desc(v)}"
                for n, v in schema["properties"].items()
            ]
            return f"Dict[str, Union[{', '.join(des)}]]"
        if "additionalProperties" in schema:
            return f"Dict[str, {patched_json_schema_to_python_type(schema['additionalProperties'], defs)}]"
        return "Dict[str, Any]"
    return patched_get_type(schema)

# Override Gradio's json_schema_to_python_type
client_utils.json_schema_to_python_type = patched_json_schema_to_python_type
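
# Illustrative case this patch targets (a sketch, not taken from the Gradio
# source): JSON Schema permits bare booleans as schemas, e.g.
#   {"type": "object", "additionalProperties": True}
# where True means "any value is allowed". Unpatched gradio_client versions
# can fail on such schemas with "TypeError: argument of type 'bool' is not
# iterable" when they probe them with checks like `"const" in schema`.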

# --- Mistral OCR Setup ---
api_key = os.environ.get("MISTRAL_API_KEY")
hf_token_global = None
client = None
if not api_key:
    logger.warning("MISTRAL_API_KEY not set. Attempting to use Hugging Face token.")
    api_key = get_token()
    if api_key:
        logger.info("Using Hugging Face token as MISTRAL_API_KEY.")
    else:
        logger.warning("No API key found.")
if api_key:
    try:
        client = Mistral(api_key=api_key)
        logger.info("Mistral client initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize Mistral client: {e}", exc_info=True)
        raise RuntimeError(f"Failed to initialize Mistral client: {e}")
else:
    logger.error("Mistral API key not available. OCR will fail.")

# --- Helper Functions ---
def encode_image_bytes(image_bytes: bytes) -> str:
    """Encodes image bytes to a base64 string."""
    return base64.b64encode(image_bytes).decode('utf-8')
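
# Example (illustrative): encode_image_bytes(b"\x89PNG\r\n") == "iVBORw0K",
# the familiar prefix of base64-encoded PNG data.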

def extract_images_from_markdown(markdown_text: str) -> Dict[str, str]:
    """
    Extracts base64 image data URIs from markdown and maps them to reference IDs.
    Returns a dictionary mapping reference IDs to base64 data URIs.
    """
    image_map = {}
    img_refs = re.findall(r"!\[.*?\]\((data:image/[a-zA-Z+]+;base64,[A-Za-z0-9+/=]+)\)", markdown_text)
    for idx, img_uri in enumerate(img_refs):
        ref_id = f"img_ref_{idx+1}"
        image_map[ref_id] = img_uri
    return image_map
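
# Example (illustrative values): for markdown containing
#   ![figure](data:image/png;base64,iVBORw0KGgo...)
# extract_images_from_markdown returns
#   {"img_ref_1": "data:image/png;base64,iVBORw0KGgo..."}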

def replace_image_references(markdown_text: str, image_map: Dict[str, str]) -> str:
    """
    Replaces base64 image data URIs in markdown with reference IDs (e.g., img_ref_1).
    """
    updated_markdown = markdown_text
    for ref_id, img_uri in image_map.items():
        escaped_uri = re.escape(img_uri)
        pattern = r"(!\[.*?\]\()" + escaped_uri + r"(\))"
        updated_markdown = re.sub(pattern, f"\\1{ref_id}\\2", updated_markdown)
    return updated_markdown
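
# Usage sketch: the two helpers above round-trip a document's inline images.
#   image_map = extract_images_from_markdown(md)             # data URIs keyed by ref ID
#   md_with_refs = replace_image_references(md, image_map)   # now shows ![figure](img_ref_1)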

def get_combined_markdown(ocr_response: Any) -> Tuple[str, str, Dict[str, str]]:
    """Combines markdown from OCR pages, replacing image IDs with base64 data URIs."""
    processed_markdowns = []
    raw_markdowns = []
    image_data_map = {}
    if not hasattr(ocr_response, 'pages') or not ocr_response.pages:
        logger.warning("OCR response has no 'pages' attribute or pages list is empty.")
        return "", "", {}
    try:
        for page_idx, page in enumerate(ocr_response.pages):
            if hasattr(page, 'images') and page.images:
                logger.info(f"Page {page_idx}: Found {len(page.images)} images.")
                for img in page.images:
                    if hasattr(img, 'id') and hasattr(img, 'image_base64') and img.image_base64:
                        image_data_map[img.id] = img.image_base64
                        logger.debug(f"Page {page_idx}: Image ID {img.id} added to image_data_map.")
                    else:
                        logger.warning(f"Page {page_idx}: Image object lacks 'id' or valid 'image_base64'. Image: {img}")
            else:
                logger.info(f"Page {page_idx}: No images found.")
            if not hasattr(page, 'markdown'):
                logger.warning(f"Page {page_idx} lacks 'markdown' attribute. Skipping.")
                continue
            current_raw_markdown = page.markdown if page.markdown else ""
            raw_markdowns.append(current_raw_markdown)
            current_processed_markdown = current_raw_markdown
            img_refs = re.findall(r"!\[.*?\]\((.*?)\)", current_processed_markdown)
            logger.debug(f"Page {page_idx}: Found {len(img_refs)} image references in markdown.")
            for img_id in img_refs:
                if img_id in image_data_map:
                    base64_data_uri = image_data_map[img_id]
                    escaped_img_id = re.escape(img_id)
                    pattern = r"(!\[.*?\]\()" + escaped_img_id + r"(\))"
                    if re.search(pattern, current_processed_markdown):
                        current_processed_markdown = re.sub(
                            pattern,
                            r"\1" + base64_data_uri + r"\2",
                            current_processed_markdown
                        )
                        logger.debug(f"Page {page_idx}: Replaced image ID {img_id} with base64 data URI.")
                elif not img_id.startswith(('http:', 'https:', 'data:')):
                    logger.warning(f"Page {page_idx}: Image ID '{img_id}' not in image data.")
            processed_markdowns.append(current_processed_markdown)
        logger.info(f"Processed {len(processed_markdowns)} pages with {len(image_data_map)} images.")
        return "\n\n".join(processed_markdowns), "\n\n".join(raw_markdowns), image_data_map
    except Exception as e:
        logger.error(f"Error processing OCR response markdown: {e}", exc_info=True)
        raise
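
# Return shape of get_combined_markdown: (markdown with base64 images inlined,
# raw markdown as returned by the OCR API, {image_id: base64 data URI}).
# Pages are joined with blank lines in both markdown strings.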

def perform_ocr_file(file_obj: Any) -> Tuple[str, str, Dict[str, str]]:
    """Performs OCR on an uploaded file using Mistral API."""
    if not client:
        return "Error: Mistral client not initialized.", "", {}
    if not file_obj:
        return "Error: No file provided.", "", {}
    try:
        # gr.File(type="filepath") passes a plain string path; older Gradio
        # versions pass a tempfile wrapper with a .name attribute.
        file_path = file_obj if isinstance(file_obj, str) else file_obj.name
        file_name = getattr(file_obj, 'orig_name', os.path.basename(file_path))
        logger.info(f"Performing OCR on file: {file_name}")
        file_ext = os.path.splitext(file_name)[1].lower()
        ocr_response = None
        uploaded_file_id = None
        if file_ext == '.pdf':
            try:
                with open(file_path, "rb") as f:
                    file_content = f.read()
                logger.info(f"Uploading PDF {file_name} to Mistral...")
                uploaded_pdf = client.files.upload(
                    file={
                        "file_name": file_name,
                        "content": file_content,
                    },
                    purpose="ocr"
                )
                uploaded_file_id = uploaded_pdf.id
                logger.info(f"PDF uploaded successfully. File ID: {uploaded_file_id}")
                signed_url_response = client.files.get_signed_url(file_id=uploaded_file_id)
                ocr_response = client.ocr.process(
                    model="mistral-ocr-latest",
                    document={"type": "document_url", "document_url": signed_url_response.url},
                    include_image_base64=True
                )
                logger.info(f"OCR response received: {ocr_response}")
            finally:
                if uploaded_file_id:
                    try:
                        client.files.delete(file_id=uploaded_file_id)
                    except Exception as delete_err:
                        logger.warning(f"Failed to delete temporary file {uploaded_file_id}: {delete_err}")
        elif file_ext in ['.png', '.jpg', '.jpeg', '.webp', '.bmp']:
            with open(file_path, "rb") as f:
                image_bytes = f.read()
            if not image_bytes:
                return f"Error: Uploaded image file '{file_name}' is empty.", "", {}
            base64_encoded_image = encode_image_bytes(image_bytes)
            mime_type, _ = mimetypes.guess_type(file_path)
            mime_type = mime_type or 'image/jpeg'
            data_uri = f"data:{mime_type};base64,{base64_encoded_image}"
            ocr_response = client.ocr.process(
                model="mistral-ocr-latest",
                document={"type": "image_url", "image_url": data_uri},
                include_image_base64=True
            )
            logger.info(f"OCR response received: {ocr_response}")
        else:
            return f"Unsupported file type: '{file_name}'.", "", {}
        if ocr_response:
            processed_md, raw_md, img_map = get_combined_markdown(ocr_response)
            logger.info(f"Processed markdown length: {len(processed_md)}")
            return processed_md, raw_md, img_map
        return f"Error: OCR failed for '{file_name}'.", "", {}
    except Exception as e:
        logger.error(f"Error during OCR: {e}", exc_info=True)
        return f"Error during OCR: {str(e)}", "", {}

def chunk_markdown(
    markdown_text_with_images: str,
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    strip_headers: bool = True
) -> List[Document]:
    """Chunks markdown text, preserving headers in metadata and extracting images."""
    if not markdown_text_with_images or not markdown_text_with_images.strip():
        logger.warning("chunk_markdown received empty input.")
        return []
    # Extract images and replace with reference IDs
    image_map = extract_images_from_markdown(markdown_text_with_images)
    updated_markdown = replace_image_references(markdown_text_with_images, image_map)
    logger.info(f"Extracted {len(image_map)} images from markdown.")
    headers_to_split_on = [
        ("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3"),
        ("####", "Header 4"), ("#####", "Header 5"), ("######", "Header 6"),
    ]
    markdown_splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=headers_to_split_on, strip_headers=strip_headers
    )
    header_chunks = markdown_splitter.split_text(updated_markdown)
    if not header_chunks:
        logger.warning("No header chunks created. Treating entire text as one chunk.")
        return [Document(page_content=updated_markdown, metadata={"images_base64": list(image_map.values())})]
    final_chunks = []
    if chunk_size > 0:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len,
            # The lookbehind separators are regexes; without is_separator_regex=True
            # the splitter escapes them and would only match the patterns literally.
            separators=["\n\n", "\n", r"(?<=\. )", r"(?<=\? )", r"(?<=! )", ", ", "; ", " ", ""],
            is_separator_regex=True,
            add_start_index=True
        )
        for i, header_chunk in enumerate(header_chunks):
            if header_chunk.page_content:
                sub_chunks = text_splitter.split_documents([header_chunk])
                final_chunks.extend(sub_chunks)
                logger.debug(f"Header chunk {i}: Split into {len(sub_chunks)} sub-chunks.")
            else:
                logger.debug(f"Header chunk {i}: Empty, skipping.")
    else:
        final_chunks = [chunk for chunk in header_chunks if chunk.page_content]
    # Add image references to metadata for each chunk
    for chunk in final_chunks:
        if not hasattr(chunk, 'metadata'):
            chunk.metadata = {}
        # Find image references in this chunk
        chunk_img_refs = re.findall(r"!\[.*?\]\((img_ref_\d+)\)", chunk.page_content)
        chunk_images = [image_map[ref_id] for ref_id in chunk_img_refs if ref_id in image_map]
        chunk.metadata["images_base64"] = chunk_images
        chunk.metadata["image_references"] = chunk_img_refs
        logger.debug(f"Chunk {chunk.metadata.get('start_index', 'unknown')}: Found {len(chunk_images)} images.")
    logger.info(f"Created {len(final_chunks)} final chunks.")
    return final_chunks
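
# Usage sketch (`processed_md` is a hypothetical result of perform_ocr_file):
#   chunks = chunk_markdown(processed_md, chunk_size=1000, chunk_overlap=200)
#   for c in chunks:
#       print(c.metadata.get("Header 1"), len(c.page_content),
#             len(c.metadata["images_base64"]))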

def get_hf_token(explicit_token: Optional[str] = None) -> Optional[str]:
    """Retrieve Hugging Face token with fallback mechanisms."""
    global hf_token_global
    if explicit_token and explicit_token.strip().startswith('hf_'):
        return explicit_token.strip()
    if hf_token_global:
        return hf_token_global
    env_token = os.environ.get("HF_TOKEN")
    if env_token and env_token.startswith('hf_'):
        hf_token_global = env_token
        return env_token
    try:
        stored_token = huggingface_hub.get_token()
        if stored_token:
            hf_token_global = stored_token
            return stored_token
    except Exception as e:
        logger.warning(f"Could not retrieve token from Hugging Face config: {e}")
    return None
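
# Token resolution order: explicit UI input -> cached global -> HF_TOKEN
# environment variable -> token stored by `huggingface-cli login`.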

def process_file_and_save(
    file_objs: Any, chunk_size: int, chunk_overlap: int,
    strip_headers: bool, hf_token: str, repo_name: str
) -> str:
    """Orchestrates OCR, chunking, and saving to Hugging Face for multiple files."""
    # Handle case where file_objs is a single file or None
    if not file_objs:
        return "Error: No files uploaded."
    if not isinstance(file_objs, list):
        file_objs = [file_objs]
    if not repo_name or '/' not in repo_name:
        return "Error: Invalid repository name (use 'username/dataset-name')."
    if chunk_size < 0:
        chunk_size = 0
    if chunk_overlap < 0:
        chunk_overlap = 0
    if chunk_size > 0 and chunk_overlap >= chunk_size:
        chunk_overlap = min(200, chunk_size // 2)
    effective_hf_token = get_hf_token(hf_token)
    if not effective_hf_token:
        return """Error: No valid Hugging Face token found.
Please either:
1. Provide a token in the input field (starts with 'hf_')
2. Set HF_TOKEN environment variable
3. Run `huggingface-cli login` in your terminal"""
    try:
        all_data = {
            "chunk_id": [],
            "text": [],
            "metadata": [],
            "source_filename": []
        }
        total_chunks = 0
        files_processed = 0
        error_messages = []
        for file_idx, file_obj in enumerate(file_objs, 1):
            # As in perform_ocr_file, accept both string paths and file-like objects.
            file_path = file_obj if isinstance(file_obj, str) else file_obj.name
            source_filename = getattr(file_obj, 'orig_name', os.path.basename(file_path))
            logger.info(f"--- Processing file {file_idx}/{len(file_objs)}: {source_filename} ---")
            processed_markdown, raw_markdown, img_map = perform_ocr_file(file_obj)
            if processed_markdown.startswith("Error:"):
                error_messages.append(f"File '{source_filename}': {processed_markdown}")
                logger.error(f"Failed to process file {source_filename}: {processed_markdown}")
                continue
            chunks = chunk_markdown(processed_markdown, chunk_size, chunk_overlap, strip_headers)
            if not chunks:
                error_messages.append(f"File '{source_filename}': Failed to chunk the document.")
                logger.error(f"Failed to chunk file {source_filename}")
                continue
            all_data["chunk_id"].extend([f"{source_filename}_chunk_{i}" for i in range(len(chunks))])
            all_data["text"].extend([chunk.page_content or "" for chunk in chunks])
            all_data["metadata"].extend([chunk.metadata for chunk in chunks])
            all_data["source_filename"].extend([source_filename] * len(chunks))
            total_chunks += len(chunks)
            files_processed += 1
            logger.info(f"File {source_filename}: Added {len(chunks)} chunks. Total chunks: {total_chunks}")
        if not all_data["chunk_id"]:
            return "Error: No valid data processed from any files.\n" + "\n".join(error_messages)
        dataset = Dataset.from_dict(all_data)
        api = HfApi(token=effective_hf_token)
        try:
            user_info = api.whoami()
            logger.info(f"Authenticated as: {user_info['name']}")
        except Exception as auth_err:
            return f"Error: Invalid HF token - authentication failed: {auth_err}"
        try:
            api.repo_info(repo_id=repo_name, repo_type="dataset")
            logger.info(f"Repository '{repo_name}' exists.")
        except huggingface_hub.utils.RepositoryNotFoundError:
            api.create_repo(repo_id=repo_name, repo_type="dataset", private=False)
            logger.info(f"Created repository '{repo_name}'.")
        dataset.push_to_hub(repo_name, token=effective_hf_token,
                            commit_message=f"Add OCR data from {files_processed} files")
        repo_url = f"https://huggingface.co/datasets/{repo_name}"
        result = f"Success! Dataset with {total_chunks} chunks from {files_processed}/{len(file_objs)} files saved to: {repo_url}"
        if error_messages:
            result += "\n\nErrors encountered:\n" + "\n".join(error_messages)
        return result
    except huggingface_hub.utils.HfHubHTTPError as hf_http_err:
        status = getattr(hf_http_err.response, 'status_code', 'Unknown')
        if status == 401:
            return "Error: Invalid or unauthorized Hugging Face token."
        elif status == 403:
            return "Error: Token lacks write permission."
        return f"Error: Hugging Face Hub Error (Status {status}): {hf_http_err}"
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        return f"Unexpected error: {str(e)}\n" + "\n".join(error_messages)

# --- Gradio Interface ---
with gr.Blocks(title="Mistral OCR & Dataset Creator",
               theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as demo:
    gr.Markdown("# Mistral OCR, Markdown Chunking, and Hugging Face Dataset Creator")
    gr.Markdown(
        """
        Upload one or more PDF or image files. The application will:
        1. Extract text and images using Mistral OCR for each file
        2. Embed images as base64 data URIs in markdown
        3. Chunk markdown by headers and optionally character count
        4. Store embedded images in chunk metadata
        5. Create/update a Hugging Face Dataset with all processed data
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload PDF or Image Files",
                file_types=['.pdf', '.png', '.jpg', '.jpeg', '.webp', '.bmp'],
                type="filepath",
                file_count="multiple"  # Allow multiple file uploads
            )
            gr.Markdown("## Chunking Options")
            chunk_size = gr.Slider(minimum=0, maximum=8000, value=1000, step=100,
                                   label="Max Chunk Size (Characters)")
            chunk_overlap = gr.Slider(minimum=0, maximum=1000, value=200, step=50,
                                      label="Chunk Overlap (Characters)")
            strip_headers = gr.Checkbox(label="Strip Headers from Content", value=True)
            gr.Markdown("## Hugging Face Output Options")
            repo_name = gr.Textbox(label="HF Dataset Repository",
                                   placeholder="your-username/your-dataset-name")
            hf_token = gr.Textbox(label="Hugging Face Token", type="password",
                                  placeholder="hf_...")
            submit_btn = gr.Button("Process and Save", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Result Status", lines=20, interactive=False)
    submit_btn.click(
        fn=process_file_and_save,
        inputs=[file_input, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name],
        outputs=output
    )
    gr.Examples(
        examples=[
            [None, 1000, 200, True, "", "hf-username/my-first-ocr-dataset"],
            [None, 2000, 400, True, "", "hf-username/large-chunk-ocr-data"],
            [None, 0, 0, False, "", "hf-username/header-only-ocr-data"],
        ],
        inputs=[file_input, chunk_size, chunk_overlap, strip_headers, hf_token, repo_name],
        outputs=output,
        fn=process_file_and_save,
        cache_examples=False
    )
| gr.Markdown("*Requires MISTRAL_API_KEY or HF token*") | |
| if __name__ == "__main__": | |
| import gradio | |
| logger.info(f"Using Gradio version: {gradio.__version__}") | |
| if not gradio.__version__.startswith("4."): | |
| logger.warning("Gradio version is not 4.x. Updating to the latest version is recommended.") | |
| print("Consider running: pip install --upgrade gradio") | |
| initial_token = get_hf_token() | |
| if not initial_token and not client: | |
| print("\nWARNING: Neither Mistral API key nor HF token found.") | |
| print("Set MISTRAL_API_KEY and/or HF_TOKEN, or use `huggingface-cli login`") | |
| demo.launch( | |
| share=os.getenv('GRADIO_SHARE', 'False').lower() == 'true', | |
| debug=True, | |
| auth_message="Provide a valid Hugging Face token if prompted" | |
| ) | |