# Source: vedaMD / src / nougat_pdf_processor.py
# Last commit by sniro23 (19aaa42): "Initial commit without binary files"
import argparse
import logging
import os
from pathlib import Path
import shutil
import torch
from groq import Groq
from nougat import NougatModel
from nougat.utils.device import move_to_device
from nougat.postprocessing import markdown_compatible
from pypdf import PdfReader
from tqdm import tqdm
from dotenv import load_dotenv
import pypdfium2 as pdfium
from torchvision.transforms.functional import to_tensor
# Root logger setup: INFO and above go to both the console and a log file
# ("pdf_processing.log", created relative to the current working directory).
logging.basicConfig(
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("pdf_processing.log"),
    ],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
class NougatPDFProcessor:
    """
    Converts PDFs to Markdown with the Nougat model and prepends an AMA
    citation generated by a Groq LLM.

    For every input PDF a ``<stem>.md`` file is written to ``output_dir``.
    The file starts with a ``## Citation`` section followed by the Nougat
    output. If Nougat processing fails, an error note is written to the same
    path so the batch run does not retry the document.

    NOTE(review): only the first page of each PDF is rendered and passed to
    Nougat; later pages are not converted. Confirm this is intended.
    """

    # Target (width, height) handed to PIL's resize before inference.
    _NOUGAT_IMAGE_SIZE = (672, 896)

    def __init__(self, input_dir: str, output_dir: str):
        """
        Prepare output directories, the Groq client and the Nougat model.

        Args:
            input_dir: Directory that holds the source PDF files.
            output_dir: Directory that receives the generated Markdown files.

        Raises:
            ValueError: If ``GROQ_API_KEY`` is not set after loading ``.env``.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.temp_dir = self.output_dir / "temp_nougat_output"
        # parents=True: a nested output path (e.g. the CLI default
        # "src/processed_markdown") must work even when "src" does not exist.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        load_dotenv()
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            raise ValueError("GROQ_API_KEY not found in .env file")
        self.groq_client = Groq(api_key=groq_api_key)

        # Initialize Nougat model in bfloat16 and switch to evaluation mode.
        self.model = NougatModel.from_pretrained("facebook/nougat-small").to(torch.bfloat16)
        self.model = move_to_device(self.model)
        self.model.eval()

    @staticmethod
    def _wrap_citation(body: str) -> str:
        """Wrap *body* in the Markdown citation section placed atop each file."""
        return f"## Citation\n\n{body}\n\n---\n\n"

    def _get_first_page_text(self, pdf_path: Path) -> str:
        """
        Extract the text of the first page of a PDF.

        Returns an empty string (and logs the error) when the PDF cannot be
        read or has no pages, so citation generation can degrade gracefully.
        """
        try:
            reader = PdfReader(pdf_path)
            first_page = reader.pages[0]
            return first_page.extract_text() or ""
        except Exception as e:
            logging.error(f"Could not extract text from first page of '{pdf_path.name}': {e}")
            return ""

    def _generate_ama_citation(self, text: str) -> str:
        """
        Generate an AMA citation section using the Groq API.

        Args:
            text: Text of the document's first page; only the first 4000
                characters are sent to the LLM.

        Returns:
            A Markdown block (``## Citation`` heading, body, ``---`` rule).
            When *text* is empty/blank or the API call fails, the body is an
            explanatory message instead of a citation, but the block has the
            same shape so it can always be prepended to the document.
        """
        # Every return path uses the same wrapper; otherwise the message would
        # be glued directly onto the first line of the Nougat output.
        if not text or not text.strip():
            return self._wrap_citation(
                "Citation could not be generated: No text found on the first page."
            )

        prompt = (
            "Based on the following text from the first page of a medical document, "
            "please generate a concise AMA (American Medical Association) style citation. "
            "Include authors, title, journal/source, year, and volume/page numbers if available. "
            "If some information is missing, create the best citation possible with the available data. "
            "Output only the citation itself, with no additional text or labels.\n\n"
            f"--- DOCUMENT TEXT ---\n{text[:4000]}\n\n--- END DOCUMENT TEXT ---\n\nAMA Citation:"
        )
        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-8b-8192",
                temperature=0,
                max_tokens=200,
            )
            citation = chat_completion.choices[0].message.content.strip()
            return self._wrap_citation(citation)
        except Exception as e:
            # Best effort: a failed citation must not block the conversion.
            logging.error(f"Groq API call failed for citation generation: {e}")
            return self._wrap_citation("Citation could not be generated due to an error.")

    def _render_first_page(self, pdf_path: Path):
        """
        Render the first page of *pdf_path* and return it as a resized PIL image.

        The pdfium page and document handles are closed before returning, so
        batch runs do not accumulate open documents.
        """
        pdf = pdfium.PdfDocument(pdf_path)
        try:
            page = pdf[0]  # Get the first page
            try:
                bitmap = page.render(scale=1)  # Render at 72 DPI
                image = bitmap.to_pil()  # Convert to a PIL Image
                # resize() returns a new image, so the result does not depend
                # on the bitmap after the handles are released.
                # NOTE(review): aspect ratio is not preserved here — confirm
                # this matches the preprocessing Nougat expects.
                return image.resize(self._NOUGAT_IMAGE_SIZE)
            finally:
                page.close()
        finally:
            pdf.close()

    def process_single_pdf(self, pdf_path: Path):
        """
        Process a single PDF with Nougat and write ``<stem>.md`` with a citation.

        Failures during Nougat processing or saving are logged and recorded in
        the output file instead of being raised.
        """
        logging.info(f"Processing '{pdf_path.name}'...")
        final_md_path = self.output_dir / f"{pdf_path.stem}.md"

        # 1. Generate Citation
        logging.info(f"Generating AMA citation for '{pdf_path.name}'...")
        first_page_text = self._get_first_page_text(pdf_path)
        citation_md = self._generate_ama_citation(first_page_text)
        logging.info(f"Citation generated for '{pdf_path.name}'.")

        # 2. Process with Nougat
        logging.info(f"Processing PDF '{pdf_path.name}' with Nougat...")
        try:
            image = self._render_first_page(pdf_path)
            # Convert PIL image to a bfloat16 tensor
            tensor = to_tensor(image).to(torch.bfloat16)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                predictions = self.model.inference(image_tensors=tensor.unsqueeze(0))
            # The output for a single image is in predictions['predictions'][0]
            nougat_markdown = predictions['predictions'][0]
            # Post-processing to fix common markdown issues
            nougat_markdown = markdown_compatible(nougat_markdown)
            logging.info(f"Successfully processed '{pdf_path.name}' with Nougat.")

            # 3. Combine and Save
            final_content = citation_md + nougat_markdown
            final_md_path.write_text(final_content, encoding="utf-8")
            logging.info(f"Successfully saved final markdown to '{final_md_path}'.")
        except Exception as e:
            logging.error(f"Failed to process '{pdf_path.name}' with Nougat: {e}")
            # Create an error file to avoid reprocessing
            final_md_path.write_text(f"Failed to process this document with Nougat.\n\nError: {e}", encoding="utf-8")

    def process_all_pdfs(self):
        """
        Process every ``*.pdf`` file in the input directory, in sorted order.

        PDFs whose ``<stem>.md`` output already exists are skipped.
        """
        pdf_files = sorted(self.input_dir.glob("*.pdf"))
        if not pdf_files:
            logging.warning(f"No PDF files found in {self.input_dir}")
            return

        logging.info(f"Found {len(pdf_files)} PDF(s) to process.")
        for pdf_path in tqdm(pdf_files, desc="Processing PDFs with Nougat"):
            final_md_path = self.output_dir / f"{pdf_path.stem}.md"
            if final_md_path.exists():
                logging.info(f"Skipping '{pdf_path.name}' as it has already been processed.")
                continue
            self.process_single_pdf(pdf_path)
def main():
    """Parse the command line and convert either one PDF or a whole directory."""
    cli = argparse.ArgumentParser(description="PDF to Markdown Converter using Nougat with AMA Citations.")
    cli.add_argument("--input-dir", type=str, default="Obs", help="Directory containing source PDF files.")
    cli.add_argument("--output-dir", type=str, default="src/processed_markdown", help="Directory to save final Markdown files.")
    cli.add_argument("--file", type=str, help="Process a single PDF file by name (e.g., 'my_doc.pdf').")
    options = cli.parse_args()

    processor = NougatPDFProcessor(input_dir=options.input_dir, output_dir=options.output_dir)

    # Batch mode: no individual file was requested.
    if not options.file:
        processor.process_all_pdfs()
        return

    # Single-file mode: the name is resolved relative to the input directory.
    target = Path(options.input_dir) / options.file
    if not target.exists():
        logging.error(f"Specified file not found: {target}")
        return
    processor.process_single_pdf(target)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()