sanabanu31's picture
Update app.py
d0995a7 verified
raw
history blame
5.05 kB
from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import pandas as pd
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from transformers import pipeline
# FastAPI application object; the generated API docs are exposed at /docs and /redoc.
app = FastAPI(
    title="Email Classification API",
    description="Classifies support emails into categories and masks personal information.",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Load the persisted classifier and its text vectorizer (the /train endpoint
# overwrites these same two files). joblib.load unpickles its input, so these
# files must come from a trusted source.
model, vectorizer = (joblib.load(path) for path in ("model.joblib", "vectorizer.joblib"))
# Initialize the NER pipeline used for name/entity masking.
# aggregation_strategy="simple" is the documented replacement for the deprecated
# grouped_entities=True: sub-word pieces are merged into one result per entity,
# each carrying "entity_group", "word" and (when the tokenizer provides offsets)
# "start"/"end" character positions in the input text.
ner = pipeline('ner', model='Davlan/xlm-roberta-base-ner-hrl', aggregation_strategy="simple")
# Input schemas
class EmailInput(BaseModel):
    """Request body accepted by POST /classify."""

    input_email_body: str  # raw email text to be masked and classified
class TrainingExample(BaseModel):
    """Request body accepted by POST /train: one labelled email to add to the dataset."""

    email_body: str  # raw email text; PII is masked before it is used for training
    label: str  # target category for this email
# Map NER entity groups (the pipeline's "entity_group" field) to placeholder names.
# NOTE(review): confirm the NER model actually emits EMAIL / DATE groups; if it only
# emits PER/ORG/LOC, those two entries never fire and the regexes below do that work.
NER_TO_TOKEN = {
    'PER': 'full_name',
    'EMAIL': 'email',
    'DATE': 'dob'
}
# Regex patterns for PII (kept as plain strings; compiled copies live in _PII_PATTERNS).
EMAIL_REGEX = r'\b[\w\.-]+@[\w\.-]+\.\w{2,}\b'
AADHAAR_REGEX = r'\b\d{4}\s?\d{4}\s?\d{4}\b'
CARD_REGEX = r'\b(?:\d[ -]*?){13,19}\b'
CVV_REGEX = r'(?i)\b(?:cvv[:\s\-]*)?(\d{3,4})\b'
EXPIRY_REGEX = r'\b(0[1-9]|1[0-2])[\/\-]\d{2,4}\b'
PHONE_REGEX = r'\+?\d[\d\s\-]{7,14}\d'
DOB_REGEX = r'\b\d{1,2}[\/\-\.\s]\d{1,2}[\/\-\.\s]\d{2,4}\b'

# Regex detectors in priority order: characters claimed by an earlier detector are
# invisible to later ones. E-mail (anchored on "@") goes first so digits inside an
# address are not taken as a phone/CVV. CVV (any 3-4 digit run) goes last so it
# cannot swallow the year of a date or of a card expiry before those patterns run.
_PII_PATTERNS = [
    (re.compile(EMAIL_REGEX), 'email'),
    (re.compile(CARD_REGEX), 'credit_debit_no'),
    (re.compile(AADHAAR_REGEX), 'aadhar_num'),
    (re.compile(PHONE_REGEX), 'phone_number'),
    (re.compile(DOB_REGEX), 'dob'),
    (re.compile(EXPIRY_REGEX), 'expiry_no'),
    (re.compile(CVV_REGEX), 'cvv_no'),
]

# Filler written over claimed characters in the scan copy of the text. It is not a
# word, digit, or whitespace character, so no pattern above can match across it.
_CLAIMED = '\x00'


# Masking function
def mask_and_store_all_pii(text, ner_fn=None):
    """Replace PII in *text* with numbered placeholders such as ``[email_000]``.

    Detection works on the original text: NER entities first (using their
    character offsets), then the regexes in ``_PII_PATTERNS`` order. Each
    detector only sees characters no earlier detector has claimed, so detected
    spans never overlap and every placeholder replaces exactly the characters
    that were detected, never some other occurrence of the same string.

    Args:
        text: Input to mask; coerced with ``str()``.
        ner_fn: Optional NER callable with the same call/return shape as the
            module-level ``ner`` pipeline (a list of dicts with
            ``entity_group``, ``word`` and normally ``start``/``end``).
            Defaults to ``ner``.

    Returns:
        ``(masked_text, mapping, entity_list)``. ``mapping`` maps each
        placeholder to the original substring. ``entity_list`` has one dict per
        placeholder, in text order: ``position`` is the ``[start, end)`` span of
        the placeholder inside ``masked_text``, plus ``classification`` (the
        placeholder type) and ``entity`` (the original substring).
    """
    text = str(text)
    if not text.strip():
        # Nothing to detect; also avoids calling the NER model on blank input.
        return text, {}, []
    detect = ner if ner_fn is None else ner_fn

    spans = []  # (start, end, token_name) in `text` coordinates
    scan = text  # `text` with already-claimed characters overwritten by _CLAIMED

    def claim(start, end, token_name):
        # Record a detected span and hide it from the detectors that run later.
        nonlocal scan
        spans.append((start, end, token_name))
        scan = scan[:start] + _CLAIMED * (end - start) + scan[end:]

    # NER-based detection
    for ent in detect(text):
        token_name = NER_TO_TOKEN.get(ent.get('entity_group'))
        if token_name is None:
            continue
        start, end = ent.get('start'), ent.get('end')
        if start is None or end is None:
            # No offsets available: locate the decoded word as a whole word among
            # the characters not yet claimed.
            word = str(ent.get('word') or '').replace('##', '').strip()
            found = re.search(rf'(?<!\w){re.escape(word)}(?!\w)', scan) if word else None
            if found is None:
                continue
            start, end = found.span()
        start, end = max(0, int(start)), min(len(text), int(end))
        # Offsets can include surrounding whitespace; keep only the entity itself.
        while start < end and text[start].isspace():
            start += 1
        while end > start and text[end - 1].isspace():
            end -= 1
        if start < end and _CLAIMED not in scan[start:end]:
            claim(start, end, token_name)

    # Regex-based detection (matches of one pattern never overlap each other, and
    # the scan copy keeps them off spans claimed by earlier detectors).
    for pattern, token_name in _PII_PATTERNS:
        for match in pattern.finditer(scan):
            claim(match.start(), match.end(), token_name)

    # Assemble the masked text in a single left-to-right pass, numbering the
    # placeholders of each type in the order they appear in the text.
    counter = {}
    mapping = {}
    entity_list = []
    pieces = []
    cursor = 0  # position in `text`
    out_pos = 0  # position in the masked output
    for start, end, token_name in sorted(spans):
        index = counter.get(token_name, 0)
        counter[token_name] = index + 1
        token = f"[{token_name}_{index:03d}]"
        original = text[start:end]
        pieces.append(text[cursor:start])
        out_pos += start - cursor
        pieces.append(token)
        mapping[token] = original
        entity_list.append({
            "position": [out_pos, out_pos + len(token)],
            "classification": token_name,
            "entity": original
        })
        out_pos += len(token)
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces), mapping, entity_list
# Restore PII (optional use)
def restore_pii(masked_text, pii_map):
    """Put the original values back in place of their placeholders.

    Placeholders are substituted one at a time, in ``pii_map`` iteration order,
    and every occurrence of each placeholder is replaced.
    """
    restored = masked_text
    for placeholder in pii_map:
        restored = restored.replace(placeholder, pii_map[placeholder])
    return restored
# Prediction endpoint
@app.post("/classify")
def classify_email(data: EmailInput):
    """Mask PII in the submitted email, then predict its category.

    The classifier only ever sees the masked text. The response echoes the raw
    input together with the masked entities, the masked text and the category.
    """
    email_text = data.input_email_body
    masked, _pii_map, entities = mask_and_store_all_pii(email_text)
    category = model.predict(vectorizer.transform([masked]))[0]
    response = {}
    response["input_email_body"] = email_text
    response["list_of_masked_entities"] = entities
    response["masked_email"] = masked
    response["category_of_the_email"] = category
    return response
# Retraining endpoint
@app.post("/train")
def train_model(new_example: TrainingExample):
    """Append one labelled email to ``training_data.csv`` and retrain from scratch.

    Every row of the CSV is PII-masked with ``mask_and_store_all_pii`` (the same
    preprocessing /classify applies) before a fresh ``TfidfVectorizer`` and
    ``LinearSVC`` are fitted. The results are saved over ``model.joblib`` /
    ``vectorizer.joblib`` and also rebound as the module-level ``model`` /
    ``vectorizer``, so /classify uses them without a restart.

    Returns:
        ``{"message": ...}`` on success, or ``{"error": ...}`` when the example
        cannot be appended or the dataset cannot train a classifier yet.
    """
    import os

    global model, vectorizer
    csv_path = "training_data.csv"
    df = pd.DataFrame([{"email_body": new_example.email_body, "label": new_example.label}])
    try:
        # Write the header row only when the file is being created.
        df.to_csv(csv_path, mode='a', header=not os.path.exists(csv_path), index=False)
    except Exception as e:
        return {"error": f"Failed to append to dataset: {str(e)}"}
    # Load dataset
    full_df = pd.read_csv(csv_path)
    full_df['masked_text'] = full_df['email_body'].apply(lambda x: mask_and_store_all_pii(x)[0])
    y = full_df['label']
    # LinearSVC.fit raises ValueError when fewer than two classes are present.
    if y.nunique() < 2:
        return {"error": "Need examples from at least two distinct labels to train; example saved."}
    # Vectorize and train
    new_vectorizer = TfidfVectorizer()
    X = new_vectorizer.fit_transform(full_df['masked_text'])
    new_model = LinearSVC()
    new_model.fit(X, y)
    # Save updated model and vectorizer
    joblib.dump(new_model, "model.joblib")
    joblib.dump(new_vectorizer, "vectorizer.joblib")
    # Swap the in-memory objects that /classify reads.
    model, vectorizer = new_model, new_vectorizer
    return {"message": "Model retrained successfully with new example."}
# Health check
@app.get("/")
def root():
    """Liveness check: report that the service is up."""
    status = "Email Classification API is running."
    return {"message": status}