# Spaces: Sleeping
from fastapi import FastAPI | |
from pydantic import BaseModel | |
import joblib | |
import pandas as pd | |
import re | |
from transformers import pipeline | |
# Initialize FastAPI app
app = FastAPI(
    title="Email Classification API",
    version="1.0.0",
    description="Classifies support emails into categories and masks personal information.",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Load the combined model pipeline (includes vectorizer).
# joblib.load unpickles the file, which can execute arbitrary code, so only
# ship a model.joblib you built or trust. The relative path is resolved
# against the process's current working directory.
model = joblib.load("model.joblib")

# Initialize NER pipeline.
# aggregation_strategy="simple" replaces the deprecated grouped_entities=True:
# transformers translates that flag to the "simple" strategy and emits a
# deprecation warning. Both produce grouped results carrying an
# 'entity_group' key.
ner = pipeline('ner', model='Davlan/xlm-roberta-base-ner-hrl', aggregation_strategy="simple")
# Input schemas
# Request body accepted by classify_email: the raw email text, whose PII is
# masked before the text is passed to the classifier.
class EmailInput(BaseModel):
    input_email_body: str  # raw, unmasked email body
# A labelled example: an email body together with its category label.
# NOTE(review): not referenced anywhere in the visible code; confirm it is
# still used (e.g. by a training endpoint elsewhere).
class TrainingExample(BaseModel):
    email_body: str  # email text
    label: str  # category label for the email
# NER entity-group label -> classification name used inside mask tokens.
# NOTE(review): the Davlan/xlm-roberta-base-ner-hrl model card lists only
# PER / ORG / LOC, so the EMAIL and DATE entries presumably never fire.
# Emails and dates are caught by the regexes below instead; confirm.
NER_TO_TOKEN = dict(
    PER='full_name',
    EMAIL='email',
    DATE='dob',
)

# Regex patterns for PII (raw strings, consumed by the `re` module).
# Email address: word chars / dots / dashes, '@', domain, dot, 2+ char suffix.
EMAIL_REGEX = r"\b[\w\.-]+@[\w\.-]+\.\w{2,}\b"
# Aadhaar: 12 digits in groups of 4, each gap an optional single whitespace.
AADHAAR_REGEX = r"\b\d{4}\s?\d{4}\s?\d{4}\b"
# Card number: 13-19 digits, each optionally followed by spaces or dashes.
CARD_REGEX = r"\b(?:\d[ -]*?){13,19}\b"
# CVV: optional case-insensitive "cvv" label, then 3-4 digits.
# Because the label is optional, this matches ANY standalone 3-4 digit number.
CVV_REGEX = r"(?i)\b(?:cvv[:\s\-]*)?(\d{3,4})\b"
# Card expiry: month 01-12, '/' or '-', then 2-4 digits.
EXPIRY_REGEX = r"\b(0[1-9]|1[0-2])[\/\-]\d{2,4}\b"
# Phone: optional '+', then 9-16 chars of digits/whitespace/dashes that start
# and end with a digit. No word boundaries are enforced.
PHONE_REGEX = r"\+?\d[\d\s\-]{7,14}\d"
# Date: 1-2 digits, 1-2 digits, 2-4 digits, separated by / - . or whitespace.
DOB_REGEX = r"\b\d{1,2}[\/\-\.\s]\d{1,2}[\/\-\.\s]\d{2,4}\b"
# Masking function
def mask_and_store_all_pii(text):
    """Replace personal information in ``text`` with numbered placeholder tokens.

    Two passes run over the same string. First the module-level ``ner``
    pipeline masks entity groups listed in ``NER_TO_TOKEN``. Then the PII
    regexes in ``regex_map`` run in the priority order listed there. Each
    detected value becomes a token such as ``[email_000]``. Numbering is kept
    per classification and follows detection order.

    Args:
        text: Email body to mask (coerced with ``str()``).

    Returns:
        ``(masked_text, mapping, entity_list)``, where:

        - ``masked_text`` is the input with every detected value replaced.
        - ``mapping`` maps each token to the original value it hides.
        - ``entity_list`` holds one dict per masked value, in detection order,
          with ``position`` (``[start, end]`` of the token inside
          ``masked_text``, end exclusive), ``classification`` (token type) and
          ``entity`` (the original value).
    """
    text = str(text)
    mapping = {}
    counter = {
        'full_name': 0, 'email': 0, 'phone_number': 0, 'dob': 0,
        'aadhar_num': 0, 'credit_debit_no': 0, 'cvv_no': 0, 'expiry_no': 0,
    }
    detected = []  # (token, classification, original value), in detection order

    def _issue_token(token_name, original):
        """Allocate the next token for ``token_name`` and record the value it hides."""
        token = f"[{token_name}_{counter[token_name]:03d}]"
        counter[token_name] += 1
        mapping[token] = original
        detected.append((token, token_name, original))
        return token

    # NER-based masking. ner() runs once, on the unmodified input.
    for ent in ner(text):
        token_name = NER_TO_TOKEN.get(ent['entity_group'])
        if token_name is None:
            continue
        original = ent['word'].replace('##', '').strip()
        if not original:
            # An empty string is "found" at index 0, which would inject a
            # token at the start of the text.
            continue
        # Require a non-word character (or the string edge) on both sides, so
        # a short name such as "Ann" cannot mask the front of "Annual".
        found = re.search(r'(?<!\w)' + re.escape(original) + r'(?!\w)', text)
        if found is None:
            continue
        text = text[:found.start()] + _issue_token(token_name, original) + text[found.end():]

    # Regex-based masking, most specific patterns first.
    # CVV_REGEX accepts any standalone 3-4 digit number, so it runs last.
    # When it ran before the email/date/expiry passes, it split values like
    # "raj.1990@example.com" (leaving the address unmasked), "15/08/1990"
    # and "12/2025".
    # DOB runs before EXPIRY because the "12/05" prefix of "12/05/1990" also
    # matches EXPIRY_REGEX.
    # NOTE(review): dates separated by spaces or dashes (e.g. "12-05-1990")
    # are still consumed by PHONE_REGEX first, as before.
    regex_map = [
        (EMAIL_REGEX, 'email'),
        (CARD_REGEX, 'credit_debit_no'),
        (AADHAAR_REGEX, 'aadhar_num'),
        (PHONE_REGEX, 'phone_number'),
        (DOB_REGEX, 'dob'),
        (EXPIRY_REGEX, 'expiry_no'),
        (CVV_REGEX, 'cvv_no'),
    ]
    for regex, token_name in regex_map:
        # re.sub replaces each match exactly where it was found. The old
        # approach searched for the first occurrence of the matched text,
        # which could hit an unrelated earlier copy (e.g. "123" inside
        # "91234") and leave the real value unmasked.
        text = re.sub(
            regex,
            lambda match, name=token_name: _issue_token(name, match.group(0)),
            text,
        )

    # Positions are taken from the final string. Every replacement shifts the
    # text after it, so offsets recorded mid-way would go stale. Tokens are
    # unique, and no pattern in regex_map can match inside one, so find()
    # locates each token exactly.
    entity_list = []
    for token, token_name, original in detected:
        start = text.find(token)
        entity_list.append({
            "position": [start, start + len(token)],
            "classification": token_name,
            "entity": original,
        })
    return text, mapping, entity_list
# Restore PII (optional use)
def restore_pii(masked_text, pii_map):
    """Put original values back in place of their placeholder tokens.

    Placeholders are processed one at a time in ``pii_map`` iteration order.
    Every occurrence of each placeholder is replaced before moving on to the
    next one.

    Args:
        masked_text: Text containing placeholder tokens.
        pii_map: Mapping of placeholder token -> original value.

    Returns:
        The text with all mapped placeholders substituted.
    """
    restored = masked_text
    for placeholder in pii_map:
        restored = restored.replace(placeholder, pii_map[placeholder])
    return restored
# Prediction endpoint
# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# function, so as shown it is not registered with `app`; confirm how it is
# wired up.
def classify_email(data: EmailInput):
    """Mask PII in the submitted email, then classify the masked text.

    Args:
        data: Request payload; ``data.input_email_body`` is the raw email.

    Returns:
        dict with the raw email, the list of masked entities, the masked email
        and the category predicted by ``model`` from the masked text.
    """
    body = data.input_email_body
    masked, _pii_map, entities = mask_and_store_all_pii(body)
    # The classifier only ever sees the masked text.
    category = model.predict([masked])[0]
    return dict(
        input_email_body=body,
        list_of_masked_entities=entities,
        masked_email=masked,
        category_of_the_email=category,
    )
# Health check
# NOTE(review): no route decorator is visible on this function either;
# confirm it is registered with `app`.
def root():
    """Report that the service is up.

    Returns:
        A single-key dict holding a fixed status message.
    """
    status_message = "Email Classification API is running."
    return dict(message=status_message)