"""Email classification API.

Masks personally identifiable information (PII) in support emails and
predicts a category for the masked text. Also exposes an endpoint that
appends a labelled example to the training set and retrains the classifier.
"""

import os
import re
import threading
from typing import Any, Dict, List, Tuple

import joblib
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from transformers import pipeline

# Initialize FastAPI app
app = FastAPI(
    title="Email Classification API",
    version="1.0.0",
    description="Classifies support emails into categories and masks personal information.",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Locations of the persisted artifacts and of the growing training set.
MODEL_PATH = "model.joblib"
VECTORIZER_PATH = "vectorizer.joblib"
TRAINING_DATA_PATH = "training_data.csv"

# Load model and vectorizer
model = joblib.load(MODEL_PATH)
vectorizer = joblib.load(VECTORIZER_PATH)

# Guards the (model, vectorizer) pair so /classify never pairs a freshly
# trained vectorizer with the previous model while /train is swapping them.
_MODEL_LOCK = threading.Lock()

# Initialize NER pipeline
# NOTE(review): `grouped_entities` is documented as deprecated in favour of
# `aggregation_strategy` in recent transformers releases - confirm against the
# pinned version before changing it.
ner = pipeline('ner', model='Davlan/xlm-roberta-base-ner-hrl', grouped_entities=True)


# Input schemas
class EmailInput(BaseModel):
    """Request body of ``/classify``: the raw email text."""

    input_email_body: str


class TrainingExample(BaseModel):
    """Request body of ``/train``: one labelled email."""

    email_body: str
    label: str


# Map NER labels to types
NER_TO_TOKEN = {
    'PER': 'full_name',
    'EMAIL': 'email',
    'DATE': 'dob'
}

# Regex patterns for PII
EMAIL_REGEX = r'\b[\w\.-]+@[\w\.-]+\.\w{2,}\b'
AADHAAR_REGEX = r'\b\d{4}\s?\d{4}\s?\d{4}\b'
CARD_REGEX = r'\b(?:\d[ -]*?){13,19}\b'
CVV_REGEX = r'(?i)\b(?:cvv[:\s\-]*)?(\d{3,4})\b'
EXPIRY_REGEX = r'\b(0[1-9]|1[0-2])[\/\-]\d{2,4}\b'
PHONE_REGEX = r'\+?\d[\d\s\-]{7,14}\d'
DOB_REGEX = r'\b\d{1,2}[\/\-\.\s]\d{1,2}[\/\-\.\s]\d{2,4}\b'

# Regex passes in the order they are applied. Order matters: a span consumed
# by an earlier pattern is no longer visible to the later ones.
_REGEX_RULES = tuple(
    (re.compile(pattern), token_name)
    for pattern, token_name in (
        (CARD_REGEX, 'credit_debit_no'),
        (AADHAAR_REGEX, 'aadhar_num'),
        (PHONE_REGEX, 'phone_number'),
        (CVV_REGEX, 'cvv_no'),
        (EXPIRY_REGEX, 'expiry_no'),
        (EMAIL_REGEX, 'email'),
        (DOB_REGEX, 'dob'),
    )
)


def _mask_span(text, start, end, token_name, counter, mapping, entity_list):
    """Replace ``text[start:end]`` with a fresh placeholder and record it.

    Updates ``counter``, ``mapping`` and ``entity_list`` in place and returns
    the new text. Spans already recorded to the right of the replaced region
    are shifted so every recorded position stays valid for the returned text.
    """
    original = text[start:end]
    token = f"[{token_name}_{counter[token_name]:03d}]"
    counter[token_name] += 1

    delta = len(token) - len(original)
    if delta:
        for recorded in entity_list:
            span = recorded["position"]
            if span[0] >= end:
                span[0] += delta
                span[1] += delta

    mapping[token] = original
    entity_list.append({
        "position": [start, start + len(token)],
        "classification": token_name,
        "entity": original
    })
    return text[:start] + token + text[end:]


# Masking function
def mask_and_store_all_pii(text) -> Tuple[str, Dict[str, str], List[Dict[str, Any]]]:
    """Mask PII in *text* with numbered placeholders.

    Entities reported by the NER pipeline are masked first, then each regex in
    ``_REGEX_RULES`` is applied in order to the progressively masked text.

    Args:
        text: The email body. Coerced with ``str()`` so non-string values
            (e.g. cells read back from the training CSV) are accepted.

    Returns:
        A ``(masked_text, mapping, entity_list)`` tuple where ``mapping`` maps
        each placeholder to the text it replaced and every ``entity_list``
        item has ``position`` (the ``[start, end]`` span of the placeholder in
        ``masked_text``), ``classification`` and ``entity`` (the original
        text).
    """
    text = str(text)
    mapping = {}
    counter = {
        'full_name': 0,
        'email': 0,
        'phone_number': 0,
        'dob': 0,
        'aadhar_num': 0,
        'credit_debit_no': 0,
        'cvv_no': 0,
        'expiry_no': 0
    }
    entity_list = []

    # NER-based masking (skip the model call when there is nothing to scan).
    entities = ner(text) if text.strip() else []
    for ent in entities:
        token_name = NER_TO_TOKEN.get(ent['entity_group'])
        if token_name is None:
            continue
        original = ent['word'].replace('##', '')
        # An empty word would "match" at index 0 and insert a stray token.
        start = text.find(original) if original else -1
        if start < 0:
            continue
        text = _mask_span(
            text, start, start + len(original),
            token_name, counter, mapping, entity_list,
        )

    # Regex-based masking
    for pattern, token_name in _REGEX_RULES:
        snapshot = text
        # Matches are found in the snapshot; `offset` converts their indices
        # to the current text, which has changed length for every earlier
        # (left-hand) replacement made in this pass.
        offset = 0
        for match in pattern.finditer(snapshot):
            if match.start() == match.end():
                continue
            length_before = len(text)
            text = _mask_span(
                text, match.start() + offset, match.end() + offset,
                token_name, counter, mapping, entity_list,
            )
            offset += len(text) - length_before

    return text, mapping, entity_list


# Restore PII (optional use)
def restore_pii(masked_text, pii_map):
    """Return *masked_text* with every placeholder in *pii_map* substituted back.

    Args:
        masked_text: Text containing placeholders.
        pii_map: Mapping of placeholder -> original text, as returned by
            ``mask_and_store_all_pii``.
    """
    restored = masked_text
    for placeholder, original in pii_map.items():
        restored = restored.replace(placeholder, original)
    return restored


# Prediction endpoint
@app.post("/classify")
def classify_email(data: EmailInput):
    """Mask PII in the submitted email and predict its category."""
    raw_text = data.input_email_body
    masked_text, _pii_map, entity_list = mask_and_store_all_pii(raw_text)

    # Take a consistent snapshot of the pair; /train may replace both.
    with _MODEL_LOCK:
        current_vectorizer, current_model = vectorizer, model

    features = current_vectorizer.transform([masked_text])
    predicted_category = current_model.predict(features)[0]
    return {
        "input_email_body": raw_text,
        "list_of_masked_entities": entity_list,
        "masked_email": masked_text,
        "category_of_the_email": predicted_category
    }


# Retraining endpoint
@app.post("/train")
def train_model(new_example: TrainingExample):
    """Append one labelled example to the dataset and retrain from scratch.

    The whole dataset is re-masked and a new TF-IDF vectorizer and LinearSVC
    are fitted, saved to disk, and swapped in for subsequent ``/classify``
    calls in this process. Failures to store the example or an untrainable
    dataset are reported as an ``{"error": ...}`` payload.
    """
    global model, vectorizer

    df = pd.DataFrame([{"email_body": new_example.email_body, "label": new_example.label}])
    # Write the header only when the file is missing or still empty, otherwise
    # the first data row would later be parsed as the header.
    needs_header = not (
        os.path.exists(TRAINING_DATA_PATH) and os.path.getsize(TRAINING_DATA_PATH) > 0
    )
    try:
        df.to_csv(TRAINING_DATA_PATH, mode='a', header=needs_header, index=False)
    except Exception as e:  # endpoint boundary: report instead of a bare 500
        return {"error": f"Failed to append to dataset: {str(e)}"}

    # Load dataset
    full_df = pd.read_csv(TRAINING_DATA_PATH)

    # LinearSVC cannot be fitted on a single class; check before the costly
    # masking pass over the whole dataset.
    if full_df['label'].nunique() < 2:
        return {"error": "Need examples from at least two labels to retrain the model."}

    full_df['masked_text'] = full_df['email_body'].apply(lambda x: mask_and_store_all_pii(x)[0])

    # Vectorize and train
    new_vectorizer = TfidfVectorizer()
    X = new_vectorizer.fit_transform(full_df['masked_text'])
    y = full_df['label']
    new_model = LinearSVC()
    new_model.fit(X, y)

    # Save updated model and vectorizer
    joblib.dump(new_model, MODEL_PATH)
    joblib.dump(new_vectorizer, VECTORIZER_PATH)

    # Serve the retrained pair immediately instead of only after a restart.
    with _MODEL_LOCK:
        model, vectorizer = new_model, new_vectorizer

    return {"message": "Model retrained successfully with new example."}


# Health check
@app.get("/")
def root():
    """Report that the service is up."""
    return {"message": "Email Classification API is running."}