Spaces:
Running
Running
""" | |
Mistral OCR Tool + Agent | |
Call `parse_passport(bytes) -> dict` | |
""" | |
import os, json, requests, logging | |
from typing import Dict, Any, List | |
from langchain.tools import StructuredTool | |
from langchain.agents import initialize_agent, AgentType | |
from langchain_openai import ChatOpenAI | |
from dotenv import load_dotenv | |
load_dotenv()

# ------------------ raw tool: configuration ------------------
# Endpoint of the OCR service; override it with the MISTRAL_OCR_URL variable.
MISTRAL_URL = os.getenv("MISTRAL_OCR_URL", "https://api.mistral.ai/ocr/v1/documents")
# API key for the OCR service; set this in .env / HF secret.
MISTRAL_KEY = os.getenv("MISTRAL_OCR_KEY")
# Bearer-auth headers reused by every OCR request (empty when no key is configured).
_HEADERS = {} if not MISTRAL_KEY else {"Authorization": "Bearer " + MISTRAL_KEY}
def _mistral_ocr_call(file_bytes: bytes) -> Dict[str, Any]:
    """Send document bytes to the configured Mistral OCR endpoint.

    The bytes are base64-encoded and POSTed as ``{"content": <base64>}`` to
    ``MISTRAL_URL`` with the bearer-token headers from ``_HEADERS``.

    Args:
        file_bytes: Raw PDF, JPEG, or PNG bytes. The content is forwarded
            as-is and is not inspected here.

    Returns:
        The decoded JSON object returned by the service, or ``{}`` when any of
        the following happens: the API key is missing, the request fails
        (network error or timeout), the HTTP status is not 200, or the body is
        not a JSON object. Failures are logged rather than raised.
    """
    import base64

    if not MISTRAL_KEY:
        logging.warning("MISTRAL_OCR_KEY not set")
        return {}

    # NOTE(review): Mistral's public OCR API is documented as POST /v1/ocr
    # taking a `document` object (data URL carrying the MIME type). Confirm
    # that MISTRAL_URL points at a service accepting this `{"content": ...}`
    # shape.
    payload = {"content": base64.b64encode(file_bytes).decode("utf-8")}

    try:
        resp = requests.post(
            MISTRAL_URL,
            headers={**_HEADERS, "Content-Type": "application/json"},
            json=payload,
            timeout=60,
        )
    except requests.RequestException as err:
        # Timeouts and connection errors follow the same "log and return {}"
        # contract as HTTP errors instead of escaping to the caller.
        logging.error("Mistral OCR request failed: %s", err)
        return {}

    if resp.status_code != 200:
        logging.error("Mistral OCR error %s: %s", resp.status_code, resp.text[:200])
        return {}

    try:
        data = resp.json()
    except ValueError as err:  # requests' JSONDecodeError subclasses ValueError
        logging.error("Mistral OCR returned a non-JSON body: %s", err)
        return {}

    if not isinstance(data, dict):
        # Callers (and the return annotation) expect a JSON object.
        logging.error(
            "Mistral OCR returned %s instead of a JSON object", type(data).__name__
        )
        return {}
    return data
# ------------------ LangChain Tool ------------------ | |
def _extract_fields(raw_json: Dict[str, Any]) -> Dict[str, str]: | |
mapping = { | |
"FirstName": ["first_name", "given_name"], | |
"LastName": ["last_name", "surname"], | |
"DateOfBirth": ["date_of_birth", "dob"], | |
"PassportNumber": ["passport_number", "document_number"], | |
} | |
out = {} | |
for k, aliases in mapping.items(): | |
for a in aliases: | |
if a in raw_json: | |
out[k] = raw_json[a] | |
break | |
return out | |
def mistral_ocr_tool(file_b64: str) -> Dict[str, str]:
    """Decode a base64 document, OCR it, and return the passport fields found.

    Args:
        file_b64: Base64-encoded PDF, JPEG, or PNG bytes.

    Returns:
        Canonical field names (``FirstName``, ``LastName``, ``DateOfBirth``,
        ``PassportNumber``) mapped to the extracted values. Returns ``{}``
        when the input is not valid base64 text or the OCR call yields nothing.
    """
    import base64

    try:
        file_bytes = base64.b64decode(file_b64)
    except (ValueError, TypeError) as err:
        # binascii.Error (bad padding) subclasses ValueError; TypeError covers
        # non-string input. The caller may be an LLM agent passing a truncated
        # or malformed string, so report the problem instead of crashing.
        logging.error("mistral_ocr_tool: invalid base64 input: %s", err)
        return {}
    return _extract_fields(_mistral_ocr_call(file_bytes))
# LangChain wrapper that exposes mistral_ocr_tool under the tool name
# "mistral_ocr". The explicit description below is the text LangChain attaches
# to this tool.
OCR_TOOL = StructuredTool.from_function(
    func=mistral_ocr_tool,
    name="mistral_ocr",
    description=(
        "Extracts key passport fields from a base64‑encoded "
        "PDF, JPEG, or PNG by calling the Mistral OCR service."
    ),
)
# ------------------ single-turn ReAct agent ------------------
# Deterministic (temperature 0) chat model driving a zero-shot ReAct agent
# whose only tool is OCR_TOOL.
# NOTE(review): both objects are constructed at import time, so importing this
# module presumably needs OpenAI credentials (OPENAI_API_KEY) available —
# confirm. `initialize_agent` is also deprecated in recent LangChain releases;
# consider migrating.
_llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
_AGENT = initialize_agent(
    [OCR_TOOL],
    _llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=False,
)
def parse_passport(pdf_bytes: bytes) -> Dict[str, str]:
    """Extract passport fields from a document's raw bytes.

    Args:
        pdf_bytes: Raw document bytes. Despite the parameter name, the OCR
            helper is described as accepting PDF, JPEG, or PNG input.

    Returns:
        Canonical field names (``FirstName``, ``LastName``, ``DateOfBirth``,
        ``PassportNumber``) mapped to extracted values. Returns ``{}`` when the
        OCR key is not configured, the input is empty, or any step fails;
        ordinary errors are logged, not raised.
    """
    if not MISTRAL_KEY or not pdf_bytes:
        return {}
    # The OCR helpers are called directly rather than through _AGENT. The old
    # `_AGENT.run(prompt, input=...)` mixed positional and keyword arguments,
    # which `Chain.run` always rejects with ValueError, so every call returned
    # {}. Routing the whole base64 document through an LLM prompt would also
    # overflow the model's context, and its free-text answer is not JSON.
    try:
        return _extract_fields(_mistral_ocr_call(pdf_bytes))
    except Exception:  # API boundary: callers rely on this never raising
        logging.exception("Passport parsing failed")
        return {}