"""OCR handler that runs PaliGemma2 PEFT adapters on images and PDFs.

Images that look like a single line of text are sent straight to the model;
anything else is first split into text lines by ``TextDetector`` and each line
is recognised separately.
"""

import logging
import os
import tempfile

import numpy as np
import torch
from PIL import Image, ImageDraw
from peft import PeftModel, PeftConfig
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

from detector import TextDetector

logger = logging.getLogger(__name__)

# List of available models with their IDs and prompts
MODELS = {
    "Medium-14k, Single Line": {
        "id": "alakxender/paligemma2-qlora-dhivehi-ocr-224-sl-14k",
        "prompt": "What text is written in this image?",
    },
    "Medium-16k, Single Line": {
        "id": "alakxender/paligemma2-qlora-dhivehi-ocr-224-sl-md-16k",
        "prompt": "What text is written in this image?",
    },
    "Small, Single Line": {
        "id": "alakxender/paligemma2-qlora-vrd-dhivehi-ocr-224-sm",
        "prompt": "What text is written in this image?",
    },
}


class PaliGemma2Handler:
    """Lazy-loading wrapper around a PaliGemma2 base model plus a PEFT adapter.

    Attributes:
        model: The currently loaded ``PeftModel``, or ``None`` before the
            first load.
        processor: The ``AutoProcessor`` matching the loaded base model.
        current_model_name: Key into ``MODELS`` of the loaded model.
        detector: A ``TextDetector`` reused for PDF processing.
    """

    def __init__(self):
        self.model = None
        self.processor = None
        self.current_model_name = None
        self.detector = TextDetector()

    @staticmethod
    def _report_progress(progress, fraction, desc):
        """Invoke the optional progress callback without letting it break OCR.

        Progress reporting is best-effort: a failing callback must never abort
        the actual work, so ordinary exceptions are logged and dropped.
        ``KeyboardInterrupt`` / ``SystemExit`` are deliberately not caught.
        """
        if progress is None:
            return
        try:
            progress(fraction, desc=desc)
        except Exception:
            logger.debug("Progress callback failed", exc_info=True)

    def _ensure_model(self, model_name, progress=None):
        """Load ``model_name`` unless it is already the active model."""
        if model_name != self.current_model_name:
            self._report_progress(progress, 0, "Loading model...")
            self.load_model(model_name)

    def load_model(self, model_name):
        """Load the base model, the PEFT adapter and the processor.

        Args:
            model_name: A key of ``MODELS``.

        Raises:
            KeyError: If ``model_name`` is not a key of ``MODELS``.
        """
        # Look the ID up first so an unknown name fails before anything is unloaded.
        model_id = MODELS[model_name]["id"]

        if self.model is not None:
            # Drop the previous model before loading the next one so both sets
            # of weights do not have to be resident at the same time.
            self.model = None
            self.processor = None
            self.current_model_name = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Load the PEFT configuration to get the base model path
        peft_config = PeftConfig.from_pretrained(model_id)
        base_model_path = peft_config.base_model_name_or_path

        # Load the base model
        base_model = PaliGemmaForConditionalGeneration.from_pretrained(
            base_model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

        # Load the adapter on top of the base model
        self.model = PeftModel.from_pretrained(base_model, model_id)
        self.processor = AutoProcessor.from_pretrained(base_model_path)
        self.current_model_name = model_name

    def process_image(self, model_name, image, progress=None):
        """Run OCR on a single image.

        Args:
            model_name: A key of ``MODELS``; the model is (re)loaded if it
                differs from the active one.
            image: A PIL image, a NumPy array, or ``None``.
            progress: Optional callable invoked as ``progress(fraction, desc=...)``.

        Returns:
            A ``(text, images)`` tuple. ``images`` is a list holding either the
            input image (single-line path) or a copy annotated with the
            detected boxes (multi-line path). ``("", [])`` is returned for
            ``None`` or zero-sized input.
        """
        if image is None:
            return "", []

        # Load model if different model selected
        self._ensure_model(model_name, progress)

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        width, height = image.size
        print(f"Image dimensions: {width}x{height}")

        # A zero-sized image has nothing to read and would otherwise divide by zero below.
        if width == 0 or height == 0:
            return "", []

        # Treat the image as a single line when it is short, or when it is
        # much wider than tall (aspect ratio greater than 3).
        aspect_ratio = width / height
        if height <= 50 or aspect_ratio > 3:
            self._report_progress(progress, 0.5, "Processing single line...")
            result = self.process_single_line(image, model_name)
            self._report_progress(progress, 1.0, "Done!")
            return result, [image]

        return self.process_multi_line(image, model_name, progress)

    def process_single_line(self, image, model_name):
        """Recognise the text in an image that contains one line.

        Args:
            image: PIL image of the text line.
            model_name: A key of ``MODELS`` (used to pick the prompt). The
                model itself must already be loaded.

        Returns:
            The decoded text with the prompt removed and whitespace stripped.
        """
        prompt = MODELS[model_name]["prompt"]
        # NOTE(review): the prompt is passed without an explicit image token.
        # The previous comment mentioned adding one but the code never did;
        # confirm the processor inserts it as intended.

        # First prepare inputs on the CPU
        model_inputs = self.processor(text=prompt, images=image, return_tensors="pt")

        # Then move everything to the model's device (instead of assuming
        # "cuda") and convert only the image tensor to bfloat16.
        device = self.model.device
        model_inputs = {
            key: (
                value.to(torch.bfloat16).to(device)
                if key == "pixel_values"
                else value.to(device)
            )
            for key, value in model_inputs.items()
        }

        outputs = self.model.generate(
            **model_inputs,
            max_new_tokens=500,
            do_sample=False,
        )
        generated_text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]

        # Remove the prompt and any leading/trailing whitespace
        cleaned_text = generated_text.replace(prompt, "").strip()
        # Remove any remaining question marks or other artifacts
        cleaned_text = cleaned_text.lstrip("?").strip()
        # Remove the prompt text if it somehow appears in the output
        cleaned_text = cleaned_text.replace("What text is written in this image?", "").strip()
        return cleaned_text

    def process_multi_line(self, image, model_name, progress=None):
        """Detect text lines in an image and OCR each one.

        Args:
            image: PIL image containing several lines of text.
            model_name: A key of ``MODELS``; the model must already be loaded.
            progress: Optional callable invoked as ``progress(fraction, desc=...)``.

        Returns:
            A ``(text, images)`` tuple: the recognised lines joined with
            newlines (top to bottom) and a one-element list with a copy of the
            image annotated with the detected boxes. When nothing is detected
            a message string and an empty list are returned.
        """
        # Create temporary directory
        with tempfile.TemporaryDirectory() as temp_dir:
            # Save input image
            input_path = os.path.join(temp_dir, "input.png")
            image.save(input_path)

            # Initialize detector with temp directory
            # NOTE(review): process_pdf reuses self.detector instead of
            # building a new TextDetector; consider unifying the two.
            detector = TextDetector(output_dir=temp_dir)

            # Run text detection
            self._report_progress(progress, 0.1, "Detecting text regions...")
            results = detector.process_input(input_path, save_images=True)

            # Get text regions for the image
            regions = detector.get_text_regions(results, "input")
            if not regions:
                return "No text regions detected", []

            # Process each text region
            page_regions = regions[0]  # First page
            text_lines = page_regions.get('bboxes', [])
            if not text_lines:
                return "No text lines detected", []

            # Order lines top to bottom without mutating the detector's own list
            text_lines = sorted(text_lines, key=lambda entry: entry['bbox'][1])

            # Draw bounding boxes on a copy of the image
            bbox_image = self.draw_bboxes(image.copy(), text_lines)

            # Process each text line
            all_text = []
            total_lines = len(text_lines)
            for i, line in enumerate(text_lines):
                self._report_progress(
                    progress,
                    (i + 1) / total_lines,
                    f"Processing line {i+1}/{total_lines}",
                )

                # Extract text region using bbox
                x1, y1, x2, y2 = line['bbox']
                line_image = image.crop((x1, y1, x2, y2))

                # Process the line
                all_text.append(self.process_single_line(line_image, model_name))

            self._report_progress(progress, 1.0, "Done!")
            return "\n".join(all_text), [bbox_image]  # Return as list for gallery

    def process_pdf(self, pdf_path, model_name, progress=None):
        """Detect text lines in a PDF and OCR each one.

        Args:
            pdf_path: Path to the PDF file, or ``None``.
            model_name: A key of ``MODELS``; the model is (re)loaded if it
                differs from the active one.
            progress: Optional callable invoked as ``progress(fraction, desc=...)``.

        Returns:
            A ``(text, images)`` tuple: recognised lines joined with newlines
            and a list of per-page images annotated with the detected boxes.
            Pages whose rendered image cannot be found, or that contain no
            lines, contribute a diagnostic message to the text instead.
        """
        if pdf_path is None:
            return "", []

        # Load model if different model selected
        self._ensure_model(model_name, progress)

        # Get the base name of the PDF without extension
        pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]

        # Create temporary directory
        with tempfile.TemporaryDirectory() as temp_dir:
            # Point the shared detector at the temp directory
            self.detector.output_dir = temp_dir

            # Run text detection on the PDF.
            # NOTE(review): page_range="0" presumably limits detection to the
            # first page (an earlier comment claimed "first 2 pages");
            # confirm against TextDetector.
            self._report_progress(progress, 0.1, "Detecting text regions in PDF...")
            results = self.detector.process_input(pdf_path, save_images=True, page_range="0")

            # Get text regions for the PDF
            regions = self.detector.get_text_regions(results, pdf_name)
            if not regions:
                return "No text regions detected", []

            # Process each page
            all_text = []
            bbox_images = []
            page_count = len(regions)

            for page_num, page_regions in enumerate(regions):
                self._report_progress(
                    progress,
                    0.2 + (page_num/page_count)*0.3,
                    f"Processing page {page_num+1}/{page_count}...",
                )

                # Try different possible paths for the page image
                possible_paths = [
                    os.path.join(temp_dir, pdf_name, f"{pdf_name}_{page_num}_bbox.png"),  # Detector's actual path
                    os.path.join(temp_dir, pdf_name, f"page_{page_num}.png"),  # Original path
                    os.path.join(temp_dir, f"page_{page_num}.png"),  # Direct in output dir
                    os.path.join(temp_dir, f"{pdf_name}_page_{page_num}.png"),  # Alternative naming
                ]

                page_image = None
                for page_image_path in possible_paths:
                    if os.path.exists(page_image_path):
                        # Copy into memory and close the file so no handle is
                        # left open when the temporary directory is removed.
                        with Image.open(page_image_path) as opened:
                            page_image = opened.copy()
                        break

                if page_image is None:
                    all_text.append(f"\nPage {page_num+1}: Page image not found. Tried paths:\n" + "\n".join(f"- {path}" for path in possible_paths))
                    continue

                text_lines = page_regions.get('bboxes', [])
                if not text_lines:
                    all_text.append(f"\nPage {page_num+1}: No text lines detected")
                    continue

                # Order lines top to bottom without mutating the detector's own list
                text_lines = sorted(text_lines, key=lambda entry: entry['bbox'][1])

                # Draw bounding boxes on a copy of the page image
                bbox_images.append(self.draw_bboxes(page_image.copy(), text_lines))

                # Process each text line
                page_text = []
                total_lines = len(text_lines)
                for i, line in enumerate(text_lines):
                    self._report_progress(
                        progress,
                        0.5 + (page_num/page_count)*0.2 + (i/total_lines)*0.3,
                        f"Processing line {i+1}/{total_lines} on page {page_num+1}/{page_count}...",
                    )

                    # Extract text region using bbox
                    x1, y1, x2, y2 = line['bbox']
                    line_image = page_image.crop((x1, y1, x2, y2))

                    # Process the line
                    page_text.append(self.process_single_line(line_image, model_name))

                # Add page text without page number
                all_text.extend(page_text)

            self._report_progress(progress, 1.0, "Done!")
            return "\n".join(all_text), bbox_images  # Return list of bbox images

    @staticmethod
    def draw_bboxes(image, text_lines):
        """Draw each line's polygon, bounding box and confidence onto ``image``.

        Args:
            image: PIL image to draw on; it is modified in place.
            text_lines: Iterable of dicts with ``'polygon'`` (sequence of
                points), ``'bbox'`` (``x1, y1, x2, y2``) and ``'confidence'``.

        Returns:
            The same ``image`` object, for convenience.
        """
        draw = ImageDraw.Draw(image)
        for line in text_lines:
            # Draw polygon - flatten nested coordinates
            polygon = line['polygon']
            flat_polygon = [coord for point in polygon for coord in point]
            draw.polygon(flat_polygon, outline="red", width=2)

            # Draw bbox
            x1, y1, x2, y2 = line['bbox']
            draw.rectangle([x1, y1, x2, y2], outline="blue", width=1)

            # Draw confidence score just above the box
            draw.text((x1, y1 - 10), f"{line['confidence']:.2f}", fill="red")
        return image