File size: 7,323 Bytes
19aaa42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import argparse
import logging
import os
from pathlib import Path
import shutil
import torch
from groq import Groq
from nougat import NougatModel
from nougat.utils.device import move_to_device
from nougat.postprocessing import markdown_compatible
from pypdf import PdfReader
from tqdm import tqdm
from dotenv import load_dotenv
import pypdfium2 as pdfium
from torchvision.transforms.functional import to_tensor

# Configure basic logging: INFO and above go both to the console and to
# "pdf_processing.log" (a path relative to the current working directory).
# The file handler is opened as UTF-8 so that log lines containing non-ASCII
# PDF names do not fail to encode under a non-UTF-8 locale encoding.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(), logging.FileHandler("pdf_processing.log", encoding="utf-8")],
)

class NougatPDFProcessor:
    """
    Processes PDFs using the Nougat model to generate high-quality Markdown,
    and prepends an AMA citation generated by a Groq LLM.

    Each ``<input_dir>/<name>.pdf`` is written to ``<output_dir>/<name>.md``.
    The Markdown starts with a "## Citation" section built from the text of the
    first page, followed by Nougat's transcription of every page.
    """

    # Resolution at which PDF pages are rasterized before being fed to Nougat.
    # NOTE(review): 96 DPI mirrors Nougat's own prediction script — confirm
    # against the installed nougat version.
    RENDER_DPI = 96

    # Maximum number of first-page characters sent to the LLM for the citation.
    CITATION_TEXT_LIMIT = 4000

    def __init__(self, input_dir: str, output_dir: str, model_name: str = "facebook/nougat-small"):
        """
        Set up directories, the Groq client and the Nougat model.

        Args:
            input_dir: Directory scanned for PDF files.
            output_dir: Directory that receives the generated ``.md`` files; it is
                created (including any missing parent directories) if absent.
            model_name: Checkpoint passed to ``NougatModel.from_pretrained``.

        Raises:
            ValueError: If ``GROQ_API_KEY`` is not set in the environment / ``.env``.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        # NOTE(review): temp_dir is created but not otherwise used in this class.
        self.temp_dir = self.output_dir / "temp_nougat_output"

        # parents=True: the default output dir ("src/processed_markdown") is nested,
        # and mkdir without it raises FileNotFoundError when "src" does not exist.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        load_dotenv()
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            raise ValueError("GROQ_API_KEY not found in .env file")
        self.groq_client = Groq(api_key=groq_api_key)

        # Load Nougat in bfloat16, move it to the available device and switch to
        # eval mode for inference.
        # NOTE(review): nougat's loader is normally given a local checkpoint dir;
        # confirm this hub id loads correctly with NougatModel.
        self.model = NougatModel.from_pretrained(model_name).to(torch.bfloat16)
        self.model = move_to_device(self.model)
        self.model.eval()

    def _get_first_page_text(self, pdf_path: Path) -> str:
        """Extracts text from the first page of a PDF; returns "" on any failure."""
        try:
            reader = PdfReader(pdf_path)
            first_page = reader.pages[0]
            return first_page.extract_text() or ""
        except Exception as e:
            logging.error(f"Could not extract text from first page of '{pdf_path.name}': {e}")
            return ""

    @staticmethod
    def _format_citation_section(body: str) -> str:
        """Wrap *body* in the "## Citation" Markdown section followed by a rule."""
        return f"## Citation\n\n{body}\n\n---\n\n"

    def _generate_ama_citation(self, text: str) -> str:
        """
        Generates an AMA citation using the Groq API.

        Always returns a complete "## Citation" Markdown section ending with a
        horizontal rule, so it can be prepended directly to the document body —
        including when there is no usable text or the API call fails.
        """
        if not text or not text.strip():
            # Keep the section wrapper so this message does not run straight
            # into the Nougat output that follows it.
            return self._format_citation_section(
                "Citation could not be generated: No text found on the first page."
            )

        prompt = (
            "Based on the following text from the first page of a medical document, "
            "please generate a concise AMA (American Medical Association) style citation. "
            "Include authors, title, journal/source, year, and volume/page numbers if available. "
            "If some information is missing, create the best citation possible with the available data. "
            "Output only the citation itself, with no additional text or labels.\n\n"
            f"--- DOCUMENT TEXT ---\n{text[:self.CITATION_TEXT_LIMIT]}\n\n--- END DOCUMENT TEXT ---\n\nAMA Citation:"
        )
        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model="llama3-8b-8192",
                temperature=0,
                max_tokens=200,
            )
            citation = chat_completion.choices[0].message.content.strip()
        except Exception as e:
            logging.error(f"Groq API call failed for citation generation: {e}")
            return self._format_citation_section("Citation could not be generated due to an error.")
        return self._format_citation_section(citation)

    def _run_nougat(self, pdf_path: Path) -> str:
        """
        Transcribe every page of *pdf_path* with Nougat and return the joined Markdown.

        Each page is rendered at ``RENDER_DPI`` and preprocessed by the model's own
        ``encoder.prepare_input`` instead of a plain resize + ``to_tensor``. The plain
        version distorted the page's aspect ratio and skipped input normalization.
        Pages are joined with a blank line.
        """
        pdf = pdfium.PdfDocument(pdf_path)
        try:
            page_markdown = []
            for page_index in range(len(pdf)):
                page = pdf[page_index]
                # PDF user space is 72 units per inch, so scale = DPI / 72.
                bitmap = page.render(scale=self.RENDER_DPI / 72)
                image = bitmap.to_pil()
                tensor = self.model.encoder.prepare_input(image, random_padding=False)
                if tensor is None:
                    # NOTE(review): prepare_input appears to return None for empty or
                    # unreadable page images — confirm; such pages are skipped.
                    logging.warning(f"Skipping unreadable page {page_index + 1} of '{pdf_path.name}'.")
                    continue
                # Inference only: no autograd graph is needed for the encoder pass.
                with torch.no_grad():
                    predictions = self.model.inference(
                        image_tensors=tensor.unsqueeze(0).to(torch.bfloat16)
                    )
                # The output for a single image is in predictions['predictions'][0];
                # markdown_compatible fixes common markdown issues.
                page_markdown.append(markdown_compatible(predictions["predictions"][0]))
            return "\n\n".join(page_markdown)
        finally:
            # Release the pdfium document (and its open file) even if inference fails.
            pdf.close()

    def process_single_pdf(self, pdf_path: Path):
        """
        Processes a single PDF with Nougat and adds a citation.

        Writes ``<output_dir>/<pdf stem>.md``. If processing fails, an error note is
        written to the same path instead, so ``process_all_pdfs`` skips the file on
        later runs rather than retrying it.
        """
        logging.info(f"Processing '{pdf_path.name}'...")
        final_md_path = self.output_dir / f"{pdf_path.stem}.md"

        # 1. Generate Citation
        logging.info(f"Generating AMA citation for '{pdf_path.name}'...")
        first_page_text = self._get_first_page_text(pdf_path)
        citation_md = self._generate_ama_citation(first_page_text)
        logging.info(f"Citation generated for '{pdf_path.name}'.")

        # 2. Process with Nougat, then 3. combine and save.
        logging.info(f"Processing PDF '{pdf_path.name}' with Nougat...")
        try:
            nougat_markdown = self._run_nougat(pdf_path)
            logging.info(f"Successfully processed '{pdf_path.name}' with Nougat.")

            final_md_path.write_text(citation_md + nougat_markdown, encoding="utf-8")
            logging.info(f"Successfully saved final markdown to '{final_md_path}'.")
        except Exception as e:
            # Broad catch keeps a batch run going; logging.exception keeps the traceback.
            logging.exception(f"Failed to process '{pdf_path.name}' with Nougat: {e}")
            # Create an error file to avoid reprocessing
            final_md_path.write_text(f"Failed to process this document with Nougat.\n\nError: {e}", encoding="utf-8")

    def process_all_pdfs(self):
        """
        Processes all PDF files in the input directory.

        Files are handled in sorted order. Any PDF whose ``<stem>.md`` already exists
        in the output directory is skipped, so an interrupted run can be resumed.
        """
        # Match the extension case-insensitively so "*.PDF" files are not missed on
        # case-sensitive filesystems; glob on a missing directory yields nothing.
        pdf_files = sorted(
            path
            for path in self.input_dir.glob("*")
            if path.suffix.lower() == ".pdf" and path.is_file()
        )
        if not pdf_files:
            logging.warning(f"No PDF files found in {self.input_dir}")
            return

        logging.info(f"Found {len(pdf_files)} PDF(s) to process.")

        for pdf_path in tqdm(pdf_files, desc="Processing PDFs with Nougat"):
            final_md_path = self.output_dir / f"{pdf_path.stem}.md"
            if final_md_path.exists():
                logging.info(f"Skipping '{pdf_path.name}' as it has already been processed.")
                continue

            self.process_single_pdf(pdf_path)

def main():
    """
    Main function to run the PDF processing script.

    Command-line options:
        --input-dir   Directory containing source PDFs (default: "Obs").
        --output-dir  Directory for generated Markdown (default: "src/processed_markdown").
        --file        Process only this PDF, resolved relative to --input-dir. It is
                      reprocessed even if its Markdown output already exists.

    Exits with status 1 if --file does not name an existing file.
    """
    parser = argparse.ArgumentParser(description="PDF to Markdown Converter using Nougat with AMA Citations.")
    parser.add_argument("--input-dir", type=str, default="Obs", help="Directory containing source PDF files.")
    parser.add_argument("--output-dir", type=str, default="src/processed_markdown", help="Directory to save final Markdown files.")
    parser.add_argument("--file", type=str, help="Process a single PDF file by name (e.g., 'my_doc.pdf').")
    args = parser.parse_args()

    pdf_to_process = None
    if args.file:
        pdf_to_process = Path(args.input_dir) / args.file
        # Validate before constructing the processor, so a typo is reported without
        # first loading the Nougat model. is_file() also rejects directories.
        if not pdf_to_process.is_file():
            logging.error(f"Specified file not found: {pdf_to_process}")
            raise SystemExit(1)

    processor = NougatPDFProcessor(input_dir=args.input_dir, output_dir=args.output_dir)

    if pdf_to_process is not None:
        processor.process_single_pdf(pdf_to_process)
    else:
        processor.process_all_pdfs()

if __name__ == "__main__":
    main()