File size: 3,399 Bytes
0c0a4f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Mistral OCR Tool + Agent
Call `parse_passport(bytes) -> dict`
"""

import os, json, requests, logging
from typing import Dict, Any, List
from langchain.tools import StructuredTool
from langchain.agents import initialize_agent, AgentType
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
load_dotenv() 
# ------------------ raw tool ------------------
# Endpoint and credentials are read from the process environment.
# NOTE(review): confirm the default URL path against the current Mistral OCR
# API reference before relying on it.
MISTRAL_URL = os.environ.get(
    "MISTRAL_OCR_URL",
    "https://api.mistral.ai/ocr/v1/documents",
)
MISTRAL_KEY = os.environ.get("MISTRAL_OCR_KEY")  # set this in .env / HF secret

# Attach the bearer token only when a key is actually configured.
if MISTRAL_KEY:
    _HEADERS = {"Authorization": f"Bearer {MISTRAL_KEY}"}
else:
    _HEADERS = {}

def _mistral_ocr_call(file_bytes: bytes) -> Dict[str, Any]:
    """Send raw document bytes to the Mistral OCR endpoint and return its JSON.

    The bytes are base64-encoded and posted as ``{"content": <base64>}`` to
    ``MISTRAL_URL`` with the bearer header built from ``MISTRAL_KEY``.

    Args:
        file_bytes: Raw bytes of the document to OCR.

    Returns:
        The decoded JSON object on an HTTP 200 response. An empty dict when
        the API key is not configured, the request cannot be completed, the
        status is not 200, or the body is not a JSON object.
    """
    import base64

    if not MISTRAL_KEY:
        logging.warning("MISTRAL_OCR_KEY not set")
        return {}

    # The payload carries only the base64 content, so no MIME type or file
    # extension needs to be derived from the bytes.
    payload = {"content": base64.b64encode(file_bytes).decode("utf-8")}

    try:
        resp = requests.post(
            MISTRAL_URL,
            headers={**_HEADERS, "Content-Type": "application/json"},
            json=payload,
            timeout=60,
        )
    except requests.RequestException as exc:
        # Connection errors and timeouts are reported the same way as HTTP
        # errors: log and return an empty result instead of raising.
        logging.error("Mistral OCR request failed: %s", exc)
        return {}

    if resp.status_code != 200:
        logging.error("Mistral OCR error %s: %s", resp.status_code, resp.text[:200])
        return {}

    try:
        data = resp.json()
    except ValueError:
        # requests' JSON decode error is a ValueError subclass.
        logging.error("Mistral OCR returned a non-JSON body: %s", resp.text[:200])
        return {}

    if not isinstance(data, dict):
        logging.error("Mistral OCR returned unexpected JSON type: %s", type(data).__name__)
        return {}
    return data

# ------------------ LangChain Tool ------------------
def _extract_fields(raw_json: Dict[str, Any]) -> Dict[str, str]:
    """Map an OCR response onto the canonical passport field names.

    Each output field has an ordered list of accepted source keys; the first
    one present in ``raw_json`` supplies the value. Fields with no matching
    key are omitted from the result.
    """
    alias_table = (
        ("FirstName", ("first_name", "given_name")),
        ("LastName", ("last_name", "surname")),
        ("DateOfBirth", ("date_of_birth", "dob")),
        ("PassportNumber", ("passport_number", "document_number")),
    )
    extracted = {}
    for field, candidates in alias_table:
        # First alias found wins; later aliases are not consulted.
        hit = next((key for key in candidates if key in raw_json), None)
        if hit is not None:
            extracted[field] = raw_json[hit]
    return extracted

def mistral_ocr_tool(file_b64: str) -> Dict[str, str]:
    # Decode the base64 document, run it through the OCR call, and reduce
    # the response to the canonical passport fields.
    import base64

    document_bytes = base64.b64decode(file_b64)
    ocr_response = _mistral_ocr_call(document_bytes)
    return _extract_fields(ocr_response)

# Wrap the OCR helper as a LangChain structured tool named "mistral_ocr".
OCR_TOOL = StructuredTool.from_function(
    func=mistral_ocr_tool,
    name="mistral_ocr",
    description=(
        "Extracts key passport fields from a base64‑encoded PDF, JPEG, or PNG by "
        "calling the Mistral OCR service."
    ),
)

# ------------------ single‑turn agent ------------------
# Chat model (temperature 0) that drives the agent below.
_llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")

# Zero-shot ReAct agent whose only tool is the OCR tool.
_AGENT = initialize_agent(
    [OCR_TOOL],
    _llm,
    verbose=False,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)

def parse_passport(pdf_bytes: bytes) -> Dict[str, str]:
    """Extract the key passport fields from a document's raw bytes.

    Args:
        pdf_bytes: Raw bytes of the document to OCR.

    Returns:
        A dict containing whichever of ``FirstName``, ``LastName``,
        ``DateOfBirth`` and ``PassportNumber`` were found in the OCR
        response. An empty dict when ``MISTRAL_KEY`` is not configured or
        the OCR step fails for any reason.
    """
    if not MISTRAL_KEY:
        return {}
    try:
        # The OCR helpers are called directly instead of going through
        # ``_AGENT.run("...", input={...})``: ``Chain.run`` rejects a call that
        # mixes a positional prompt with keyword inputs, so that invocation
        # always raised and this function could only ever return {}.
        # Calling the helpers also avoids pushing the whole base64 document
        # through an LLM prompt.
        return _extract_fields(_mistral_ocr_call(pdf_bytes))
    except Exception as e:
        # Top-level boundary: callers get an empty result rather than an
        # exception, matching the behavior on a missing key.
        logging.error("Passport parsing failed: %s", e)
        return {}