# Page banner captured along with the file ("Spaces: Sleeping"); not part of the program.
from fastapi import FastAPI | |
from pydantic import BaseModel | |
import joblib | |
import re | |
# Application object for the service. The keyword arguments carry the API's
# title, version and description, plus the paths passed as docs_url and
# redoc_url for the generated documentation pages.
app = FastAPI(
    title="Email Classification API",
    description="Classifies support emails into categories and masks personal information.",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Load the pre-trained classifier once, at import time; classify_email reads
# this module-level object on every call.
# NOTE: joblib.load unpickles the file, so "model.joblib" must be a trusted artifact.
model = joblib.load(filename="model.joblib")
# Request payload accepted by classify_email.
class EmailInput(BaseModel):
    # Raw email text; it is masked for PII and then passed to the model.
    input_email_body: str
# PII masking
# Regex for each PII label. Dict order is the matching priority: when two
# matches overlap, the label listed first keeps the span.
# NOTE(review): the "CVV" prefix in cvv_no is optional, so that pattern also
# matches any standalone 3-4 digit number that no earlier pattern claimed.
_PII_PATTERNS = {
    "email": r"\b[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+\b",
    "phone_number": r"(?<!\d)(\+?\d[\d\s\-]{7,14}\d)(?!\d)",
    "dob": r"\b\d{1,2}[\/\-\.\s]\d{1,2}[\/\-\.\s]\d{2,4}\b",
    "aadhar_num": r"\b\d{4}[ -]?\d{4}[ -]?\d{4}\b(?![\d])",
    "credit_debit_no": r"\b(?:\d[ -]*?){13,19}\b",
    "cvv_no": r"(?i)\b(?:CVV[:\s]*)?(\d{3,4})\b",
    "expiry_no": r"\b(0[1-9]|1[0-2])[\/\-]\d{2,4}\b",
}


def mask_and_store_all_pii(text):
    """Replace PII found in *text* with numbered placeholders.

    Every pattern is matched against the original, unmodified text, so all
    spans live in one coordinate system and a replacement can never be
    spliced in at a stale offset. Patterns are tried in the order of
    ``_PII_PATTERNS``; a match that overlaps an already accepted span is
    skipped.

    Args:
        text: Text to scan. It is converted with ``str()`` first.

    Returns:
        A tuple ``(masked_text, pii_map, entity_list)`` where

        * ``masked_text`` is the text with each accepted match replaced by a
          placeholder of the form ``[<label>_<NNN>]`` (``NNN`` is the
          zero-padded count of matches accepted before it);
        * ``pii_map`` maps each placeholder to the substring it replaced;
        * ``entity_list`` holds one dict per accepted match, in acceptance
          order, with keys ``"position"`` (``[start, end]`` offsets into the
          original text), ``"classification"`` (the label) and ``"entity"``
          (the matched substring).
    """
    text = str(text)
    pii_map = {}
    entity_list = []
    accepted = []  # (start, end, placeholder), offsets into the original text

    for label, pattern in _PII_PATTERNS.items():
        for match in re.finditer(pattern, text):
            start, end = match.span()
            # Half-open interval intersection; also catches a new span that
            # fully contains an earlier one.
            if any(start < taken_end and taken_start < end
                   for taken_start, taken_end, _ in accepted):
                continue
            original = match.group()
            placeholder = f"[{label}_{len(pii_map):03d}]"
            pii_map[placeholder] = original
            entity_list.append({
                "position": [start, end],
                "classification": label,
                "entity": original,
            })
            accepted.append((start, end, placeholder))

    # Assemble the masked text in a single left-to-right pass over the
    # original, so replacements of different lengths cannot shift each other.
    pieces = []
    cursor = 0
    for start, end, placeholder in sorted(accepted):
        pieces.append(text[cursor:start])
        pieces.append(placeholder)
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces), pii_map, entity_list
# Restore PII
def restore_pii(masked_text, pii_map):
    """Put the original values back in place of their placeholders.

    The substitution is done in a single pass over *masked_text*, so text
    that has just been restored is never scanned again. Chained
    ``str.replace`` calls would substitute a second time if a restored value
    happened to look like another placeholder.

    Args:
        masked_text: Text containing placeholders.
        pii_map: Mapping of placeholder -> original substring.

    Returns:
        *masked_text* with every occurrence of every placeholder in
        *pii_map* replaced by its original value. Returned unchanged when
        *pii_map* is empty.
    """
    if not pii_map:
        # An empty alternation would match the empty string at every position.
        return masked_text

    # Longest placeholders first, so a placeholder that is a prefix of
    # another one cannot win the alternation.
    alternation = "|".join(
        re.escape(placeholder)
        for placeholder in sorted(pii_map, key=len, reverse=True)
    )
    # A callable replacement inserts the value verbatim (no backslash or
    # group-reference processing of the original text).
    return re.sub(alternation, lambda found: pii_map[found.group(0)], masked_text)
# Classification endpoint
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is registered on `app`.
def classify_email(data: EmailInput):
    """Mask PII in the submitted email and predict its category.

    Args:
        data: Request payload; only ``input_email_body`` is read.

    Returns:
        A dict with the unmodified input (``input_email_body``), the entity
        list returned by ``mask_and_store_all_pii``
        (``list_of_masked_entities``), the masked text (``masked_email``)
        and the model's prediction (``category_of_the_email``).
    """
    email_body = data.input_email_body

    # The placeholder -> original map is not needed for the response.
    masked_body, _pii_lookup, entities = mask_and_store_all_pii(email_body)

    # The model is called with a one-element batch of the masked text; keep
    # its single prediction.
    category = model.predict([masked_body])[0]

    response = {
        "input_email_body": email_body,
        "list_of_masked_entities": entities,
        "masked_email": masked_body,
        "category_of_the_email": category,
    }
    return response
# Health check endpoint
# NOTE(review): no route decorator is visible on this function in this view;
# confirm it is registered on `app`.
def root():
    """Return a fixed status message saying the API is running."""
    status = {"message": "Email Classification API is running."}
    return status