sathish2352 committed
Commit df3b7e9 · verified · 1 Parent(s): 9572b93

Upload 4 files

Files changed (4)
  1. main.py +24 -0
  2. models.py +20 -0
  3. requirements.txt +7 -0
  4. utils.py +61 -0
main.py ADDED
@@ -0,0 +1,24 @@
+ # Entry point for FastAPI
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from models import load_model, classify_email
+ from utils import mask_pii_multilingual
+
+ app = FastAPI()
+ tokenizer, model, device = load_model()
+
+ class EmailInput(BaseModel):
+     input_email_body: str
+
+ @app.post("/classify")
+ async def classify_route(request: EmailInput):
+     text = request.input_email_body
+     masked_text, entities = mask_pii_multilingual(text)
+     category = classify_email(masked_text, tokenizer, model, device)
+     return {
+         "input_email_body": text,
+         "list_of_masked_entities": entities,
+         "masked_email": masked_text,
+         "category_of_the_email": category
+     }
+
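For context, a minimal sketch of exercising the /classify route once the app is running. It assumes the service is served with uvicorn main:app on localhost:8000 and that the requests package is installed; requests is not listed in requirements.txt, so this is an extra dependency for the example only.

# Hypothetical client call for the /classify endpoint above.
# Assumes `uvicorn main:app` is serving on port 8000 and that
# the `requests` package is installed (not in requirements.txt).
import requests

resp = requests.post(
    "http://localhost:8000/classify",
    json={"input_email_body": "Hi, I need access to the billing dashboard."},
)
print(resp.json()["category_of_the_email"])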
models.py ADDED
@@ -0,0 +1,20 @@
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+
+ def load_model():
+     model_path = "model"
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+     model = AutoModelForSequenceClassification.from_pretrained(model_path)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+     model.eval()
+     return tokenizer, model, device
+
+ def classify_email(text, tokenizer, model, device):
+     inputs = tokenizer(text, return_tensors="pt", max_length=256, padding="max_length", truncation=True)
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     with torch.no_grad():
+         logits = model(**inputs).logits
+     label_map = {0: "Incident", 1: "Request", 2: "Change", 3: "Problem"}
+     pred = torch.argmax(logits, dim=1).item()
+     return label_map[pred]
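For reference, a sketch of calling these helpers directly. It assumes a fine-tuned four-label sequence-classification checkpoint has actually been saved under the local "model" directory, which this commit does not include.

# Hypothetical direct use of models.py; assumes a 4-label
# checkpoint exists under ./model (not part of this commit).
from models import load_model, classify_email

tokenizer, model, device = load_model()
category = classify_email("The VPN has been down since this morning.", tokenizer, model, device)
print(category)  # one of: Incident, Request, Change, Problem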
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ fastapi
+ uvicorn
+ transformers
+ torch
+ pydantic
+ pandas
+ numpy
utils.py ADDED
@@ -0,0 +1,61 @@
+ import re
+ from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
+
+ # Load the multilingual NER model once at import time, not on every request
+ _NER_MODEL_NAME = "Davlan/xlm-roberta-base-ner-hrl"
+ _ner_pipe = pipeline(
+     "token-classification",
+     model=AutoModelForTokenClassification.from_pretrained(_NER_MODEL_NAME),
+     tokenizer=AutoTokenizer.from_pretrained(_NER_MODEL_NAME),
+     aggregation_strategy="simple",
+ )
+
+ REGEX_PATTERNS = {
+     "email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
+     "phone_number": r"(?:\+?\d{1,3})?[-.\s]?\(?\d{1,4}\)?[-.\s]?\d{2,4}[-.\s]?\d{2,4}[-.\s]?\d{2,4}",
+     "dob": r"\b(0?[1-9]|[12][0-9]|3[01])[-/](0?[1-9]|1[012])[-/](19[5-9]\d|20[0-3]\d)\b",
+     "aadhar_num": r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}\b",
+     "credit_debit_no": r"\b(?:\d{4}[\s-]?){3}\d{4}\b",
+     "cvv_no": r"\b\d{3,4}\b",
+     "expiry_no": r"\b(0[1-9]|1[0-2])[/-]?(?:\d{2}|\d{4})\b"
+ }
+
+ def mask_pii_multilingual(text: str):
+     entities = []
+     offsets = []
+
+     # Step 1: collect regex-based PII spans on the original text
+     for entity_type, pattern in REGEX_PATTERNS.items():
+         for match in re.finditer(pattern, text):
+             start, end = match.start(), match.end()
+             if any(start < e[1] and end > e[0] for e in offsets):
+                 continue  # skip spans overlapping an earlier match
+             offsets.append((start, end))
+             entities.append({
+                 "position": [start, end],
+                 "classification": entity_type,
+                 "entity": text[start:end]
+             })
+
+     # Step 2: collect person-name spans from NER, skipping regex overlaps
+     for ent in _ner_pipe(text):
+         if ent["entity_group"] != "PER":
+             continue
+         start, end = ent["start"], ent["end"]
+         if any(start < e[1] and end > e[0] for e in offsets):
+             continue
+         offsets.append((start, end))
+         entities.append({
+             "position": [start, end],
+             "classification": "full_name",
+             "entity": text[start:end]
+         })
+
+     # Step 3: splice in mask tokens right-to-left so earlier offsets stay valid
+     entities.sort(key=lambda x: x["position"][0])
+     masked_text = text
+     for ent in reversed(entities):
+         start, end = ent["position"]
+         masked_text = masked_text[:start] + f"[{ent['classification']}]" + masked_text[end:]
+
+     return masked_text, entities
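For illustration, a sketch of what the masking utility returns. The input string is hypothetical, and the exact person-name spans depend on what the Davlan/xlm-roberta-base-ner-hrl model detects.

# Illustrative call; person-name detection depends on the NER model,
# so the spans shown in the comments are not guaranteed.
from utils import mask_pii_multilingual

masked, ents = mask_pii_multilingual("Contact John Doe at john.doe@example.com")
print(masked)  # e.g. "Contact [full_name] at [email]"
print(ents)    # e.g. [{"position": [8, 16], "classification": "full_name", "entity": "John Doe"}, ...]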