File size: 2,315 Bytes
27d4105
 
 
 
 
 
 
5286d26
27d4105
 
 
 
 
 
 
5286d26
 
 
 
27d4105
 
 
 
 
 
 
 
 
5286d26
27d4105
 
5286d26
27d4105
 
 
5286d26
 
 
 
27d4105
 
5286d26
 
27d4105
5286d26
27d4105
 
 
 
5286d26
27d4105
 
5286d26
 
27d4105
 
5286d26
 
 
27d4105
 
5286d26
27d4105
 
 
 
5286d26
27d4105
5286d26
 
27d4105
 
5286d26
27d4105
 
5286d26
 
ecf2409
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from transformers import pipeline, AutoTokenizer

# Load model and tokenizer
# Fine-tuned BERT checkpoint for binary phishing/legitimate classification.
model_name = "ealvaradob/bert-finetuned-phishing"
# Downloads weights on first run; loaded once at import time so each
# request only pays inference cost.
classifier = pipeline("text-classification", model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# BERT's maximum sequence length; inputs longer than this are chunked.
MAX_TOKENS = 512

def count_tokens(text):
    """Return the token count of *text*, special tokens included, with no truncation."""
    token_ids = tokenizer.encode(text, truncation=False)
    return len(token_ids)

def chunk_text(text, max_tokens=MAX_TOKENS):
    """Split *text* into whitespace-word chunks that each fit the model.

    Words are accumulated until adding the next one would push the chunk
    past the token budget, then a new chunk is started.

    Args:
        text: Raw input text.
        max_tokens: Model sequence limit (defaults to MAX_TOKENS).

    Returns:
        List of chunk strings. May be empty if *text* contains no words.
    """
    # Reserve room for the special tokens ([CLS]/[SEP]) the tokenizer adds
    # at classification time; without this a "full" chunk can exceed the
    # model's 512-token limit.
    budget = max_tokens - tokenizer.num_special_tokens_to_add()

    chunks = []
    current_chunk = []
    current_length = 0

    for word in text.split():
        word_length = len(tokenizer.encode(word, add_special_tokens=False))
        # Only flush a non-empty chunk: a single word longer than the budget
        # must not produce an empty-string chunk (original bug). Such a word
        # becomes its own chunk and relies on downstream truncation.
        if current_chunk and current_length + word_length > budget:
            chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = word_length
        else:
            current_chunk.append(word)
            current_length += word_length

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks

def process_chunks(chunks):
    """Classify each chunk and aggregate via majority vote.

    The reported confidence is averaged only over chunks that agree with
    the final verdict — averaging across both labels (the original
    behavior) produced a number that described neither prediction.

    Args:
        chunks: List of text chunks, each within the model's token limit.

    Returns:
        A display string with the majority-vote label and its average
        confidence. Ties break toward "Legitimate", matching the original
        strict-majority rule.
    """
    # Guard: an empty chunk list would otherwise divide by zero.
    if not chunks:
        return "Prediction: Legitimate\nAverage Confidence: 0.00%"

    phishing_scores = []
    legitimate_scores = []

    for chunk in chunks:
        result = classifier(chunk)[0]
        if result['label'].lower() == "phishing":
            phishing_scores.append(result['score'])
        else:
            legitimate_scores.append(result['score'])

    if len(phishing_scores) > len(legitimate_scores):
        final_label, winning_scores = "Phishing", phishing_scores
    else:
        final_label, winning_scores = "Legitimate", legitimate_scores

    average_confidence = sum(winning_scores) / len(winning_scores)

    return f"Prediction: {final_label}\nAverage Confidence: {average_confidence:.2%}"

def detect_phishing(input_text):
    """Classify *input_text* as phishing or legitimate.

    Short inputs (<= MAX_TOKENS tokens) are classified in a single pass;
    longer inputs are split into chunks and aggregated by majority vote.

    Returns:
        A display string with the prediction and confidence.
    """
    # Long input: split into model-sized pieces and vote over them.
    if count_tokens(input_text) > MAX_TOKENS:
        return process_chunks(chunk_text(input_text))

    # Short input: one direct classification.
    result = classifier(input_text)[0]
    verdict = "Phishing" if result['label'].lower() == "phishing" else "Legitimate"
    return f"Prediction: {verdict}\nConfidence: {result['score']:.2%}"

# Gradio interface: a single textbox in, a plain-text verdict out.
email_input = gr.Textbox(lines=8, placeholder="Paste email content here...")

demo = gr.Interface(
    fn=detect_phishing,
    inputs=email_input,
    outputs="text",
    title="Phishing Email Detector",
    description="Uses a fine-tuned BERT model to classify whether the email is phishing or legitimate. Handles long emails by chunking.",
)

demo.launch()