# Spaces: Sleeping
import argparse | |
import logging | |
import os | |
from pathlib import Path | |
import shutil | |
import torch | |
from groq import Groq | |
from nougat import NougatModel | |
from nougat.utils.device import move_to_device | |
from nougat.postprocessing import markdown_compatible | |
from pypdf import PdfReader | |
from tqdm import tqdm | |
from dotenv import load_dotenv | |
import pypdfium2 as pdfium | |
from torchvision.transforms.functional import to_tensor | |
# Logging setup: INFO-and-above records go to the console and to
# "pdf_processing.log" in the current working directory.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
_console_handler = logging.StreamHandler()
_file_handler = logging.FileHandler("pdf_processing.log")
logging.basicConfig(
    handlers=[_console_handler, _file_handler],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
class NougatPDFProcessor:
    """Convert PDFs to Markdown with Nougat and prepend an AMA citation.

    For every PDF the text of the first page is sent to a Groq-hosted LLM to
    draft an AMA-style citation, the first page is rendered and transcribed by
    the Nougat model, and both parts are written to ``<output_dir>/<stem>.md``.

    NOTE: only the first page of each PDF is rendered and passed to Nougat;
    later pages are not transcribed.
    """

    # Groq chat model used to draft the citation.
    CITATION_MODEL = "llama3-8b-8192"
    # Checkpoint identifier handed to ``NougatModel.from_pretrained``.
    NOUGAT_CHECKPOINT = "facebook/nougat-small"
    # Maximum number of first-page characters embedded in the citation prompt.
    MAX_PROMPT_CHARS = 4000

    def __init__(self, input_dir: str, output_dir: str):
        """Create the output directories, the Groq client and the Nougat model.

        Args:
            input_dir: Directory that contains the source PDF files.
            output_dir: Directory that receives the generated Markdown files.

        Raises:
            ValueError: If ``GROQ_API_KEY`` is not set after loading ``.env``.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        # Not used by any method of this class; kept because it is a public
        # attribute that existing callers may rely on.
        self.temp_dir = self.output_dir / "temp_nougat_output"

        # parents=True: the output directory may be nested (the CLI default is
        # "src/processed_markdown") and its parent may not exist yet.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        load_dotenv()
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            raise ValueError("GROQ_API_KEY not found in .env file")
        self.groq_client = Groq(api_key=groq_api_key)

        # Load Nougat in bfloat16, move it to the selected device and switch
        # to evaluation mode.
        self.model = NougatModel.from_pretrained(self.NOUGAT_CHECKPOINT).to(torch.bfloat16)
        self.model = move_to_device(self.model)
        self.model.eval()

    @staticmethod
    def _format_citation_section(body: str) -> str:
        """Wrap *body* in the Markdown citation header and trailing rule."""
        return f"## Citation\n\n{body}\n\n---\n\n"

    def _get_first_page_text(self, pdf_path: Path) -> str:
        """Return the extracted text of the first page of *pdf_path*.

        Any failure (unreadable file, no pages, extraction error) is logged
        and reported as an empty string, so citation generation can degrade
        gracefully instead of aborting the document.
        """
        try:
            reader = PdfReader(pdf_path)
            first_page = reader.pages[0]
            return first_page.extract_text() or ""
        except Exception as e:
            logging.error(f"Could not extract text from first page of '{pdf_path.name}': {e}")
            return ""

    def _generate_ama_citation(self, text: str) -> str:
        """Return a Markdown citation section generated from *text*.

        Args:
            text: Text of the document's first page; only the first
                ``MAX_PROMPT_CHARS`` characters are sent to the LLM.

        Returns:
            A ``## Citation`` section terminated by a horizontal rule. When
            *text* is empty, the API call fails, or the model returns nothing,
            the section carries an explanatory message instead, so the output
            always has the same shape and never runs into the Nougat text.
        """
        if not text:
            return self._format_citation_section(
                "Citation could not be generated: No text found on the first page."
            )

        prompt = (
            "Based on the following text from the first page of a medical document, "
            "please generate a concise AMA (American Medical Association) style citation. "
            "Include authors, title, journal/source, year, and volume/page numbers if available. "
            "If some information is missing, create the best citation possible with the available data. "
            "Output only the citation itself, with no additional text or labels.\n\n"
            f"--- DOCUMENT TEXT ---\n{text[:self.MAX_PROMPT_CHARS]}\n\n--- END DOCUMENT TEXT ---\n\nAMA Citation:"
        )
        failure_section = self._format_citation_section(
            "Citation could not be generated due to an error."
        )
        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.CITATION_MODEL,
                temperature=0,
                max_tokens=200,
            )
            # The message content may be None/empty; treat that as a failure
            # rather than emitting an empty citation section.
            citation = (chat_completion.choices[0].message.content or "").strip()
        except Exception as e:
            logging.error(f"Groq API call failed for citation generation: {e}")
            return failure_section

        if not citation:
            logging.error("Groq API call failed for citation generation: empty response")
            return failure_section
        return self._format_citation_section(citation)

    def _render_first_page_tensor(self, pdf_path: Path):
        """Render page 1 of *pdf_path* and return a batched Nougat input tensor.

        The page image goes through the model's own ``encoder.prepare_input``
        (margin crop, aspect-preserving resize/pad, normalization) instead of
        a fixed-size resize, so the input matches what the model expects.

        Raises:
            ValueError: If Nougat cannot prepare the rendered page.
            Exception: Whatever pypdfium2 raises for unreadable/empty PDFs.
        """
        pdf = pdfium.PdfDocument(pdf_path)
        page = None
        bitmap = None
        try:
            page = pdf[0]
            bitmap = page.render(scale=1)  # scale=1 renders at 72 DPI
            # The PIL image may share memory with the bitmap, so it must be
            # fully converted to a tensor before the bitmap is closed.
            image = bitmap.to_pil()
            tensor = self.model.encoder.prepare_input(image, random_padding=False)
        finally:
            # Release pdfium handles deterministically, children first.
            for handle in (bitmap, page, pdf):
                if handle is not None:
                    handle.close()

        if tensor is None:
            raise ValueError("Nougat could not prepare the rendered page image.")
        # Add the batch dimension and match the dtype the model was loaded in.
        return tensor.unsqueeze(0).to(torch.bfloat16)

    def process_single_pdf(self, pdf_path: Path):
        """Create ``<output_dir>/<stem>.md`` for *pdf_path*.

        The file holds the citation section followed by Nougat's Markdown for
        the first page. If Nougat processing fails, an error note is written
        to the same path instead, so batch runs do not retry the document.

        Args:
            pdf_path: Path of the PDF to convert.
        """
        logging.info(f"Processing '{pdf_path.name}'...")
        final_md_path = self.output_dir / f"{pdf_path.stem}.md"

        # 1. Generate citation
        logging.info(f"Generating AMA citation for '{pdf_path.name}'...")
        first_page_text = self._get_first_page_text(pdf_path)
        citation_md = self._generate_ama_citation(first_page_text)
        logging.info(f"Citation generated for '{pdf_path.name}'.")

        # 2. Process with Nougat
        logging.info(f"Processing PDF '{pdf_path.name}' with Nougat...")
        try:
            batch = self._render_first_page_tensor(pdf_path)
            # Inference only: no autograd bookkeeping needed.
            with torch.no_grad():
                predictions = self.model.inference(image_tensors=batch)
            # One page was submitted, so its text is the first prediction.
            nougat_markdown = predictions['predictions'][0]
            # Post-processing to fix common markdown issues
            nougat_markdown = markdown_compatible(nougat_markdown)
            logging.info(f"Successfully processed '{pdf_path.name}' with Nougat.")

            # 3. Combine and save
            final_content = citation_md + nougat_markdown
            final_md_path.write_text(final_content, encoding="utf-8")
            logging.info(f"Successfully saved final markdown to '{final_md_path}'.")
        except Exception as e:
            logging.error(f"Failed to process '{pdf_path.name}' with Nougat: {e}")
            # Create an error file to avoid reprocessing
            final_md_path.write_text(
                f"Failed to process this document with Nougat.\n\nError: {e}",
                encoding="utf-8",
            )

    def process_all_pdfs(self):
        """Process every ``*.pdf`` in the input directory, in sorted order.

        PDFs whose Markdown output already exists are skipped, which makes
        repeated runs resumable.
        """
        pdf_files = sorted(self.input_dir.glob("*.pdf"))
        if not pdf_files:
            logging.warning(f"No PDF files found in {self.input_dir}")
            return

        logging.info(f"Found {len(pdf_files)} PDF(s) to process.")
        for pdf_path in tqdm(pdf_files, desc="Processing PDFs with Nougat"):
            final_md_path = self.output_dir / f"{pdf_path.stem}.md"
            if final_md_path.exists():
                logging.info(f"Skipping '{pdf_path.name}' as it has already been processed.")
                continue
            self.process_single_pdf(pdf_path)
def main():
    """Parse command-line arguments and run the PDF-to-Markdown conversion.

    With ``--file`` a single PDF inside ``--input-dir`` is converted;
    otherwise every PDF in ``--input-dir`` is processed.

    Raises:
        SystemExit: With status 1 when ``--file`` names a file that does not
            exist, so shell callers can detect the failure.
    """
    parser = argparse.ArgumentParser(description="PDF to Markdown Converter using Nougat with AMA Citations.")
    parser.add_argument("--input-dir", type=str, default="Obs", help="Directory containing source PDF files.")
    parser.add_argument("--output-dir", type=str, default="src/processed_markdown", help="Directory to save final Markdown files.")
    parser.add_argument("--file", type=str, help="Process a single PDF file by name (e.g., 'my_doc.pdf').")
    args = parser.parse_args()

    pdf_to_process = None
    if args.file:
        pdf_to_process = Path(args.input_dir) / args.file
        # Check the path before building the processor: construction loads
        # the Nougat model and creates the Groq client, which is wasted work
        # when the requested file is missing.
        if not pdf_to_process.is_file():
            logging.error(f"Specified file not found: {pdf_to_process}")
            raise SystemExit(1)

    processor = NougatPDFProcessor(input_dir=args.input_dir, output_dir=args.output_dir)
    if pdf_to_process is not None:
        processor.process_single_pdf(pdf_to_process)
    else:
        processor.process_all_pdfs()


# Run the CLI only when the file is executed as a script.
if __name__ == "__main__":
    main()