sanabanu31's picture
Update app.py
fbb3cc8 verified
raw
history blame
2.74 kB
from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import re
# Initialize FastAPI app
# ASGI application object. The route decorators further down register their
# handlers on it; interactive documentation is exposed at /docs and /redoc.
app = FastAPI(
    title="Email Classification API",
    version="1.0.0",
    description="Classifies support emails into categories and masks personal information.",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Load the pre-trained model once, at import time; the /classify handler reuses
# this module-level object for every request via model.predict([...]).
# The relative path is resolved against the process's current working directory.
# NOTE(review): joblib artifacts are pickle-based (per the joblib docs), so
# loading one can execute arbitrary code - only deploy a model.joblib whose
# provenance is trusted. TODO confirm where this artifact comes from.
model = joblib.load("model.joblib")
# Input schema
class EmailInput(BaseModel):
    """Request body accepted by the ``/classify`` endpoint."""

    # Raw text of the email to be masked and classified.
    input_email_body: str
# PII Masking Function
def mask_and_store_all_pii(text):
    """Replace PII found in *text* with placeholders and report what was masked.

    Each pattern is matched against the original, unmodified text, so every
    reported position is an offset into the input. Patterns are tried in the
    order they are declared; a candidate that overlaps a span already claimed
    by an earlier match is skipped, so earlier patterns take priority.

    Args:
        text: Value to scan. It is converted with ``str()`` before matching.

    Returns:
        A ``(masked_text, pii_map, entity_list)`` tuple where

        * ``masked_text`` is the input with each accepted match replaced by a
          placeholder of the form ``[<label>_<NNN>]``;
        * ``pii_map`` maps each placeholder to the text it replaced;
        * ``entity_list`` holds one dict per accepted match, in discovery
          order (pattern order, then position), with the keys ``position``
          (``[start, end]`` in the original text), ``classification`` and
          ``entity``.
    """
    text = str(text)

    patterns = {
        "email": r"\b[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+\b",
        "phone_number": r"(?<!\d)(\+?\d[\d\s\-]{7,14}\d)(?!\d)",
        "dob": r"\b\d{1,2}[\/\-\.\s]\d{1,2}[\/\-\.\s]\d{2,4}\b",
        "aadhar_num": r"\b\d{4}[ -]?\d{4}[ -]?\d{4}\b(?![\d])",
        "credit_debit_no": r"\b(?:\d[ -]*?){13,19}\b",
        # The "CVV" prefix is optional, so any standalone 3-4 digit number
        # matches; when the prefix is present it is masked along with the digits.
        "cvv_no": r"(?i)\b(?:CVV[:\s]*)?(\d{3,4})\b",
        "expiry_no": r"\b(0[1-9]|1[0-2])[\/\-]\d{2,4}\b",
    }

    pii_map = {}
    entity_list = []
    # Accepted matches as (start, end, placeholder), all in original-text
    # coordinates, so overlap checks and the final rewrite share one frame.
    claimed = []

    for label, pattern in patterns.items():
        for match in re.finditer(pattern, text):
            start, end = match.span()
            # Half-open interval intersection; also catches a candidate that
            # fully contains (or is contained in) an earlier span.
            if any(start < c_end and c_start < end for c_start, c_end, _ in claimed):
                continue

            original = match.group()
            # The counter is global across labels, so placeholders are unique.
            placeholder = f"[{label}_{len(pii_map):03d}]"
            pii_map[placeholder] = original
            entity_list.append({
                "position": [start, end],
                "classification": label,
                "entity": original,
            })
            claimed.append((start, end, placeholder))

    # Build the masked text in one left-to-right pass over the original string;
    # no offset is ever applied to an already modified string.
    pieces = []
    cursor = 0
    for start, end, placeholder in sorted(claimed):
        pieces.append(text[cursor:start])
        pieces.append(placeholder)
        cursor = end
    pieces.append(text[cursor:])

    return "".join(pieces), pii_map, entity_list
# Restore PII
def restore_pii(masked_text, pii_map):
    """Return *masked_text* with every placeholder swapped back for its original.

    Args:
        masked_text: Text containing placeholders.
        pii_map: Mapping of placeholder -> original text. Entries are applied
            one after another in the mapping's iteration order, and every
            occurrence of a placeholder is replaced.

    Returns:
        The text after all replacements have been applied.
    """
    result = masked_text
    for token in pii_map:
        result = result.replace(token, pii_map[token])
    return result
# Classification Endpoint
@app.post("/classify")
def classify_email(data: EmailInput):
    """Mask PII in the submitted email, classify the masked text, and report both.

    Args:
        data: Request body carrying the email text in ``input_email_body``.

    Returns:
        A dict with the keys ``input_email_body`` (the text as received),
        ``list_of_masked_entities`` (the entity list returned by
        ``mask_and_store_all_pii``), ``masked_email`` and
        ``category_of_the_email`` (first element of ``model.predict``).
    """
    # The placeholder -> original map is not needed for the response.
    masked_email, _pii_lookup, masked_entities = mask_and_store_all_pii(
        data.input_email_body
    )

    # The model is given a one-element batch containing only the masked text.
    # NOTE(review): presumably a scikit-learn style text pipeline - verify that
    # its label type is JSON-serialisable.
    predictions = model.predict([masked_email])

    return {
        "input_email_body": data.input_email_body,
        "list_of_masked_entities": masked_entities,
        "masked_email": masked_email,
        "category_of_the_email": predictions[0],
    }
# Health check endpoint
@app.get("/")
def root():
    """Health check: return a fixed message showing the service is up."""
    status = {"message": "Email Classification API is running."}
    return status