formpilot-demo / rag /ocr_mistral_old.py
afulara's picture
Auto‑deploy from GitHub
0c0a4f7 verified
"""
Mistral OCR Tool + Agent
Call `parse_passport(bytes) -> dict`
"""
import os, json, requests, logging
from typing import Dict, Any, List
from langchain.tools import StructuredTool
from langchain.agents import initialize_agent, AgentType
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
load_dotenv()
# ------------------ raw tool ------------------
# Endpoint and credentials are read from the environment once, when the module is imported.
MISTRAL_URL = os.environ.get(
    "MISTRAL_OCR_URL",
    "https://api.mistral.ai/ocr/v1/documents",
)
MISTRAL_KEY = os.environ.get("MISTRAL_OCR_KEY")  # set this in .env / HF secret

# The Authorization header is only present when a key is configured.
_HEADERS = {}
if MISTRAL_KEY:
    _HEADERS["Authorization"] = f"Bearer {MISTRAL_KEY}"
def _mistral_ocr_call(file_bytes: bytes) -> Dict[str, Any]:
    """Send raw document bytes to the Mistral OCR endpoint and return its JSON.

    The bytes are base64-encoded and POSTed to ``MISTRAL_URL`` as
    ``{"content": <base64>}``. Every failure is logged and reported as an
    empty dict, so callers never have to handle transport errors.

    Args:
        file_bytes: Raw bytes of the document to OCR (PDF, JPEG, or PNG).

    Returns:
        The decoded JSON body of a 200 response, or ``{}`` when the API key
        is not configured, the request itself fails, the status code is not
        200, or the body is not valid JSON.
    """
    if not MISTRAL_KEY:
        logging.warning("MISTRAL_OCR_KEY not set")
        return {}

    import base64

    # The request body carries only the base64 content: no MIME type or file
    # name is sent, so the bytes are not sniffed for their format here.
    payload = {"content": base64.b64encode(file_bytes).decode("utf-8")}

    try:
        resp = requests.post(
            MISTRAL_URL,
            headers={**_HEADERS, "Content-Type": "application/json"},
            json=payload,
            timeout=60,
        )
    except requests.RequestException as exc:
        # Timeouts and connection errors follow the same "log and return {}"
        # convention as the HTTP-error branch below.
        logging.error("Mistral OCR request failed: %s", exc)
        return {}

    if resp.status_code != 200:
        logging.error("Mistral OCR error %s: %s", resp.status_code, resp.text[:200])
        return {}

    try:
        return resp.json()
    except ValueError as exc:
        # A 200 response with a non-JSON body (e.g. an HTML error page).
        logging.error("Mistral OCR returned a non-JSON body: %s", exc)
        return {}
# ------------------ LangChain Tool ------------------
def _extract_fields(raw_json: Dict[str, Any]) -> Dict[str, str]:
    """Pick the known passport fields out of an OCR response.

    Each output key is looked up under several alias names at the top level
    of ``raw_json``; the first alias that holds a usable value wins. Values
    are passed through unchanged.

    Args:
        raw_json: Decoded OCR response. Anything that is not a dict is
            treated as "no fields found".

    Returns:
        A dict containing a subset of ``FirstName``, ``LastName``,
        ``DateOfBirth`` and ``PassportNumber``. Fields for which no alias
        has a usable value are omitted.
    """
    if not isinstance(raw_json, dict):
        return {}

    mapping = {
        "FirstName": ["first_name", "given_name"],
        "LastName": ["last_name", "surname"],
        "DateOfBirth": ["date_of_birth", "dob"],
        "PassportNumber": ["passport_number", "document_number"],
    }

    out = {}
    for field, aliases in mapping.items():
        for alias in aliases:
            value = raw_json.get(alias)
            # A missing, null or blank alias must not mask a later alias
            # that actually carries data.
            if value is None or (isinstance(value, str) and not value.strip()):
                continue
            out[field] = value
            break
    return out
def mistral_ocr_tool(file_b64: str) -> Dict[str, str]:
    """Run OCR on a base64-encoded document and return its passport fields.

    Args:
        file_b64: The document bytes, base64-encoded as text.

    Returns:
        The fields found by ``_extract_fields`` in the OCR response, or
        ``{}`` when ``file_b64`` is not valid base64.
    """
    import base64

    try:
        file_bytes = base64.b64decode(file_b64)
    except ValueError as exc:
        # binascii.Error (bad padding / bad length) is a ValueError subclass.
        # Malformed input is reported like every other failure in this
        # module: logged, with an empty result.
        logging.error("mistral_ocr received invalid base64 input: %s", exc)
        return {}
    return _extract_fields(_mistral_ocr_call(file_bytes))
# Expose the OCR helper to LangChain as a tool registered under the name "mistral_ocr".
OCR_TOOL = StructuredTool.from_function(
    func=mistral_ocr_tool,
    name="mistral_ocr",
    description="Extracts key passport fields from a base64‑encoded PDF, JPEG, or PNG by calling the Mistral OCR service.",
)
# ------------------ single‑turn agent ------------------
# Both objects are built once, at import time, and reused for every request.
# temperature=0 is passed to the chat model.
_llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")

# Zero-shot ReAct agent whose only available tool is the OCR tool.
_AGENT = initialize_agent(
    [OCR_TOOL],
    _llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=False,
)
def parse_passport(pdf_bytes: bytes) -> Dict[str, str]:
    """Extract passport fields from a document by running the OCR agent.

    Args:
        pdf_bytes: Raw bytes of the passport document.

    Returns:
        The fields reported by the agent as a dict, or ``{}`` when the
        Mistral API key is not configured, the agent raises, or its answer
        is not a JSON object.
    """
    import base64
    import json

    if not MISTRAL_KEY:
        return {}

    b64 = base64.b64encode(pdf_bytes).decode()

    # Chain.run() accepts either one positional input or keyword inputs,
    # never both (mixing them raises ValueError), so the document is embedded
    # in the single prompt string handed to the agent.
    # NOTE(review): the whole base64 document travels through the LLM
    # context; large files will likely exceed the model's limit. Consider
    # calling mistral_ocr_tool directly instead.
    prompt = (
        "Extract passport fields. Call the mistral_ocr tool with the base64 "
        "document below as its input, then answer with the tool's result as "
        "a JSON object and nothing else.\n"
        f"{b64}"
    )

    try:
        result = _AGENT.run(prompt)
        parsed = json.loads(result) if isinstance(result, str) else result
    except Exception as e:
        # Top-level boundary: any agent, LLM or parsing failure degrades to
        # an empty result instead of propagating to the caller.
        logging.error("Agent failed: %s", e)
        return {}

    if not isinstance(parsed, dict):
        # Valid JSON that is not an object (list, string, number) does not
        # satisfy the declared return shape.
        logging.error("Agent returned a non-object result of type %s", type(parsed).__name__)
        return {}
    return parsed