formpilot-demo / rag /ocr_mistral.py
afulara's picture
Auto‑deploy from GitHub
0c0a4f7 verified
# rag/ocr_mistral.py
# Passport OCR entry point.
#
# Exposes a single public callable, ``parse_passport``.  When the environment
# variable USE_MISTRAL is "true" it is the Mistral OCR implementation defined
# in this file; otherwise it is an alias of the Azure parser from ``.ocr_azure``.

# Standard-library imports come first: os.getenv() is needed for the very
# first statement below, so ``import os`` must precede it.
import base64
import logging
import os
import re
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

# Backend switch: only the literal string "true" (any letter case) enables Mistral.
USE_MISTRAL = os.getenv("USE_MISTRAL", "false").lower() == "true"

# Both names stay None when the Azure backend is selected.
MISTRAL_KEY: Optional[str] = None
_client: Any = None

if USE_MISTRAL:
    # Imported only on this branch so that an Azure-only deployment does not
    # need the Mistral SDK installed.
    from mistralai import Mistral

    # Configure logging
    logging.basicConfig(level=logging.INFO)

    # load your .env or rely on HF_SPACE secrets / os.environ
    MISTRAL_KEY = os.getenv("MISTRAL_OCR_KEY")
    if not MISTRAL_KEY:
        raise RuntimeError("Please set MISTRAL_OCR_KEY in your environment")

    # initialize once
    _client = Mistral(api_key=MISTRAL_KEY)

# 8-byte signature that opens every PNG file.
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def _compile_all(*patterns: str) -> tuple:
    """Compile each pattern once, case-insensitively, preserving order."""
    return tuple(re.compile(p, re.IGNORECASE) for p in patterns)


# Output key -> candidate patterns, tried in order; the first match wins.
# Dict order is also the order in which keys are inserted into the profile.
_FIELD_PATTERNS = {
    "FirstName": _compile_all(
        r"Given\s+Name[:\s]+([A-Za-z\-]+)",
        r"First\s+Name[:\s]+([A-Za-z\-]+)",
        r"Given\s+Names?[:\s]+([A-Za-z\-]+)",
        r"First\s+Names?[:\s]+([A-Za-z\-]+)",
    ),
    "LastName": _compile_all(
        r"Family\s+Name[:\s]+([A-Za-z\-]+)",
        r"Last\s+Name[:\s]+([A-Za-z\-]+)",
        r"Surname[:\s]+([A-Za-z\-]+)",
        r"Family\s+Names?[:\s]+([A-Za-z\-]+)",
    ),
    "DateOfBirth": _compile_all(
        r"Date\s+of\s+Birth[:\s]+(\d{2}/\d{2}/\d{4})",
        r"DOB[:\s]+(\d{2}/\d{2}/\d{4})",
        r"Birth\s+Date[:\s]+(\d{2}/\d{2}/\d{4})",
    ),
    "ANumber": _compile_all(
        r"A-Number[:\s]*(A\d{8,9})",
        r"A\s*Number[:\s]*(A\d{8,9})",
        r"Alien\s+Number[:\s]*(A\d{8,9})",
    ),
}


def _read_field(obj: Any, name: str, default: Any = None) -> Any:
    """Read ``name`` from a mapping or from an attribute-style object."""
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


def _page_text(page: Any) -> str:
    """Return the recognised text of one OCR page, or "" if it has none."""
    # NOTE(review): the Mistral OCR docs describe per-page output under
    # `markdown`; `text` is kept as a fallback for the shape the original
    # code assumed -- confirm against the installed SDK version.
    return _read_field(page, "markdown") or _read_field(page, "text") or ""


def _build_document(file_bytes: bytes) -> Dict[str, str]:
    """Wrap raw bytes in the data-URL ``document`` payload sent to the OCR call."""
    b64 = base64.b64encode(file_bytes).decode("utf-8")
    if file_bytes.startswith(b"%PDF"):
        return {
            "type": "document_url",
            "document_url": f"data:application/pdf;base64,{b64}",
        }
    # Anything that is neither PDF nor PNG is treated as JPEG.
    mime = "image/png" if file_bytes.startswith(_PNG_SIGNATURE) else "image/jpeg"
    return {"type": "image_url", "image_url": f"data:{mime};base64,{b64}"}


def _extract_profile(full_text: str) -> Dict[str, Any]:
    """Run the field regexes over ``full_text`` and collect the first hits."""
    profile: Dict[str, Any] = {}
    for field, patterns in _FIELD_PATTERNS.items():
        for pattern in patterns:
            if m := pattern.search(full_text):
                profile[field] = m.group(1)
                break
    # include raw text for debugging
    profile["_raw_text"] = full_text
    return profile


def _parse_passport_mistral(file_bytes: bytes) -> Dict[str, Any]:
    """
    Send PDF/JPEG/PNG bytes to Mistral OCR and extract passport fields.

    Args:
        file_bytes: Raw content of the uploaded document.

    Returns:
        On success, a dict holding whichever of ``FirstName``, ``LastName``,
        ``DateOfBirth`` and ``ANumber`` were found, plus ``_raw_text`` with
        the full OCR text.  On any failure, a dict with ``error`` (the
        exception message) and ``_raw_text`` (OCR text gathered so far,
        possibly empty).  This function does not raise.
    """
    # Bound before the try block so the error path can always report it.
    full_text = ""
    try:
        if not file_bytes:
            # Fail fast through the normal error path instead of sending an
            # empty data URL to the API.
            raise ValueError("No file content provided")
        if _client is None:
            raise RuntimeError(
                "Mistral OCR client is not initialised (USE_MISTRAL is not 'true')"
            )

        # Call Mistral OCR
        resp = _client.ocr.process(
            model="mistral-ocr-latest",
            document=_build_document(file_bytes),
            include_image_base64=False,
        )

        # Aggregate all text; tolerate both object- and dict-shaped responses.
        pages = _read_field(resp, "pages") or []
        full_text = "\n".join(_page_text(p) for p in pages)

        profile = _extract_profile(full_text)

        # Log only which fields were found: the values and the raw OCR text
        # are personal data and should not end up in application logs.
        logger.info(
            "Extracted profile fields: %s",
            [key for key in profile if not key.startswith("_")],
        )
        return profile
    except Exception as e:  # boundary: callers expect an error dict, not a raise
        logger.exception("Error processing passport: %s", e)
        return {
            "error": str(e),
            "_raw_text": full_text,
        }


if USE_MISTRAL:
    parse_passport = _parse_passport_mistral
else:
    # dev chose Azure, so expose a shim that imports azure parser.  A plain
    # conditional binding is used because raising SystemExit at import time
    # would terminate the importing application rather than skip this file.
    from .ocr_azure import parse_passport_azure as parse_passport  # type: ignore