sanabanu31's picture
Update app.py
9b81b0a verified
raw
history blame
3.9 kB
from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import pandas as pd
import re
from transformers import pipeline
# Initialize FastAPI app
# The metadata below (title, version, description) is passed straight to the
# FastAPI constructor; docs_url and redoc_url set the paths of the two
# interactive documentation pages.
app = FastAPI(
    title="Email Classification API",
    version="1.0.0",
    description="Classifies support emails into categories and masks personal information.",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Load the combined model pipeline (includes vectorizer)
# The path is relative, so it resolves against the process's working directory.
# NOTE(review): joblib.load unpickles the file, and unpickling can execute
# arbitrary code -- only deploy a model.joblib that comes from a trusted source.
model = joblib.load("model.joblib")
# Initialize NER pipeline
# aggregation_strategy='simple' is the supported spelling of what
# grouped_entities=True used to select: sub-word tokens are merged into whole
# entities, each reported with an 'entity_group' label. grouped_entities is
# deprecated in transformers and scheduled for removal, so it is no longer
# passed here.
ner = pipeline(
    'ner',
    model='Davlan/xlm-roberta-base-ner-hrl',
    aggregation_strategy='simple',
)
# Input schemas
# EmailInput is the request body accepted by POST /classify.
class EmailInput(BaseModel):
    # Raw email text; the endpoint masks PII in it before classification.
    input_email_body: str
# One labelled example: an email body paired with its category label.
# NOTE(review): not referenced by any code visible in this file -- confirm it
# is still needed.
class TrainingExample(BaseModel):
    email_body: str
    label: str
# Map NER labels to types
# Keys are the 'entity_group' labels returned by the NER pipeline; values are
# the placeholder type names used to build masking tokens. Entities whose
# label is not listed here are left untouched by the NER pass.
# NOTE(review): confirm the configured model actually emits EMAIL and DATE
# labels -- if it does not, those two entries are never used.
NER_TO_TOKEN = {
    'PER': 'full_name',
    'EMAIL': 'email',
    'DATE': 'dob'
}
# Regex patterns for PII
# These are applied one after another by mask_and_store_all_pii, in the order
# given by its regex_map (not in the order they are declared here).

# word chars / dots / hyphens, '@', a domain, then a final label of 2+ word chars.
EMAIL_REGEX = r'\b[\w\.-]+@[\w\.-]+\.\w{2,}\b'
# 12 digits in three groups of four, each optionally separated by one whitespace char.
AADHAAR_REGEX = r'\b\d{4}\s?\d{4}\s?\d{4}\b'
# 13 to 19 digits, each optionally followed by spaces or hyphens.
CARD_REGEX = r'\b(?:\d[ -]*?){13,19}\b'
# Case-insensitive; an optional "cvv" label (plus ':', whitespace or '-')
# followed by 3 or 4 digits.
# NOTE(review): because the label is optional, ANY standalone 3-4 digit number
# (for example a year) matches this pattern -- confirm that is intended.
CVV_REGEX = r'(?i)\b(?:cvv[:\s\-]*)?(\d{3,4})\b'
# Month 01-12, then '/' or '-', then a 2 to 4 digit year.
EXPIRY_REGEX = r'\b(0[1-9]|1[0-2])[\/\-]\d{2,4}\b'
# Optional '+', a digit, 7 to 14 digits / whitespace / hyphens, and a final
# digit. There are no word-boundary anchors, so it can match inside a longer
# digit run.
PHONE_REGEX = r'\+?\d[\d\s\-]{7,14}\d'
# Three numeric groups (1-2, 1-2 and 2-4 digits) separated by '/', '-', '.' or
# whitespace. The values are not range-checked.
DOB_REGEX = r'\b\d{1,2}[\/\-\.\s]\d{1,2}[\/\-\.\s]\d{2,4}\b'
# Masking function
def _non_overlapping(spans):
    """Return ``spans`` sorted by start, dropping any that overlap a kept one."""
    kept = []
    last_end = 0
    for span in sorted(spans):
        if span[0] >= last_end:
            kept.append(span)
            last_end = span[1]
    return kept


def _apply_placeholders(text, spans, counter, mapping, records):
    """Replace each span of ``text`` with a freshly numbered placeholder.

    ``spans`` holds ``(start, end, token_name)`` tuples indexing into ``text``,
    sorted by start and non-overlapping. ``counter`` and ``mapping`` are
    updated in place. ``records`` (a list of
    ``[start, token, token_name, original]`` entries describing placeholders
    already present in ``text``) has its start offsets shifted to stay valid
    and gains one entry per new placeholder. Returns the rewritten text.
    """
    if not spans:
        return text

    pieces = []
    edits = []   # (end offset in the old text, length change) per replacement
    fresh = []
    cursor = 0
    shift = 0    # cumulative length change of everything replaced so far
    for start, end, token_name in spans:
        original = text[start:end]
        token = f"[{token_name}_{counter[token_name]:03d}]"
        counter[token_name] += 1
        mapping[token] = original
        pieces.append(text[cursor:start])
        pieces.append(token)
        fresh.append([start + shift, token, token_name, original])
        growth = len(token) - (end - start)
        edits.append((end, growth))
        shift += growth
        cursor = end
    pieces.append(text[cursor:])

    # Placeholders inserted by earlier passes move by the total length change
    # of every replacement that ends at or before them.
    for record in records:
        record[0] += sum(delta for edit_end, delta in edits if edit_end <= record[0])
    records.extend(fresh)
    return "".join(pieces)


def mask_and_store_all_pii(text):
    """Mask personal information in ``text`` with numbered placeholders.

    Masking runs in two stages: first the entities reported by the NER
    pipeline whose label appears in ``NER_TO_TOKEN``, then the regex patterns
    in a fixed priority order (card, Aadhaar, phone, CVV, expiry, email, date
    of birth). Each later stage only sees text that earlier stages left
    unmasked. Placeholders look like ``[<type>_<NNN>]``, numbered per type in
    the order they are created.

    Args:
        text: The email body; it is converted with ``str()`` first.

    Returns:
        A tuple ``(masked_text, mapping, entity_list)``:

        * ``masked_text`` -- the text with every detected value replaced.
        * ``mapping`` -- dict from placeholder to the original value.
        * ``entity_list`` -- one dict per masked value with keys
          ``"position"`` (``[start, end]`` of the placeholder in
          ``masked_text``), ``"classification"`` (the type name) and
          ``"entity"`` (the original value).
    """
    text = str(text)
    mapping = {}
    counter = {
        'full_name': 0, 'email': 0, 'phone_number': 0, 'dob': 0,
        'aadhar_num': 0, 'credit_debit_no': 0, 'cvv_no': 0, 'expiry_no': 0
    }
    # [start, token, token_name, original]; start always refers to the
    # current state of ``text``.
    records = []

    # NER-based masking
    ner_spans = []
    search_from = 0
    for ent in ner(text):
        token_name = NER_TO_TOKEN.get(ent['entity_group'])
        if token_name is None:
            continue
        start, end = ent.get('start'), ent.get('end')
        if start is None or end is None:
            # No character offsets available: fall back to locating the
            # decoded word, searching after the previous entity so repeated
            # names map to successive occurrences.
            word = ent['word'].replace('##', '')
            start = text.find(word, search_from)
            if start < 0:
                start = text.find(word)
            if start < 0:
                continue
            end = start + len(word)
        start, end = int(start), int(end)
        if not 0 <= start < end <= len(text):
            continue
        # Offsets may include surrounding whitespace; keep it out of the mask.
        while start < end and text[start].isspace():
            start += 1
        while end > start and text[end - 1].isspace():
            end -= 1
        if start < end:
            ner_spans.append((start, end, token_name))
            search_from = max(search_from, end)
    text = _apply_placeholders(
        text, _non_overlapping(ner_spans), counter, mapping, records
    )

    # Regex-based masking
    regex_map = [
        (CARD_REGEX, 'credit_debit_no'),
        (AADHAAR_REGEX, 'aadhar_num'),
        (PHONE_REGEX, 'phone_number'),
        (CVV_REGEX, 'cvv_no'),
        (EXPIRY_REGEX, 'expiry_no'),
        (EMAIL_REGEX, 'email'),
        (DOB_REGEX, 'dob')
    ]
    for regex, token_name in regex_map:
        # Use the match offsets so exactly the matched span is replaced, not
        # merely the first place the same characters happen to occur.
        spans = [
            (match.start(), match.end(), token_name)
            for match in re.finditer(regex, text)
            if match.end() > match.start()
        ]
        text = _apply_placeholders(text, spans, counter, mapping, records)

    entity_list = [
        {
            "position": [start, start + len(token)],
            "classification": token_name,
            "entity": original
        }
        for start, token, token_name, original in records
    ]
    return text, mapping, entity_list
# Restore PII (optional use)
def restore_pii(masked_text, pii_map):
    """Substitute every placeholder in ``masked_text`` with its original value.

    The entries of ``pii_map`` (placeholder -> original value) are applied one
    after another in the mapping's iteration order, and each one replaces all
    occurrences of its placeholder. The resulting string is returned.
    """
    restored = masked_text
    for token in pii_map:
        restored = restored.replace(token, pii_map[token])
    return restored
# Prediction endpoint
# POST /classify: masks PII in the submitted email, classifies the masked text
# and returns the original body, the masked entities, the masked body and the
# predicted category.
@app.post("/classify")
def classify_email(data: EmailInput):
    email_body = data.input_email_body
    # The placeholder -> original mapping is not part of the response.
    masked_email, _pii_map, masked_entities = mask_and_store_all_pii(email_body)
    # predict() is given a one-element batch; take its only prediction.
    category = model.predict([masked_email])[0]
    response = {
        "input_email_body": email_body,
        "list_of_masked_entities": masked_entities,
        "masked_email": masked_email,
        "category_of_the_email": category,
    }
    return response
# Health check
# GET / returns a fixed JSON message so callers can tell the service is up.
@app.get("/")
def root():
    status = {"message": "Email Classification API is running."}
    return status