# SubhaL's picture
# Updated app.py
# 5286d26 verified
import gradio as gr
from transformers import pipeline, AutoTokenizer
# Token threshold above which an email is split into chunks before classification.
MAX_TOKENS = 512

# Checkpoint name shared by the classification pipeline and the tokenizer below.
model_name = "ealvaradob/bert-finetuned-phishing"

# The pipeline performs the classification; the tokenizer is loaded on its own
# as well so token counts can be computed outside the pipeline.
classifier = pipeline(task="text-classification", model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def count_tokens(text):
    """Return the number of token ids the module-level tokenizer yields for ``text``.

    Truncation is explicitly disabled so the full, untruncated length is
    measured; ``add_special_tokens`` is left at the tokenizer's default.
    """
    token_ids = tokenizer.encode(text, truncation=False)
    return len(token_ids)
def _token_length(piece):
    """Return the token count of ``piece`` without any special tokens."""
    return len(tokenizer.encode(piece, add_special_tokens=False))


def _split_to_fit(word, budget):
    """Bisect ``word`` into ordered ``(piece, token_count)`` pairs within ``budget``."""
    fitted = []
    pending = [word]
    while pending:
        piece = pending.pop()
        length = _token_length(piece)
        # A single character cannot be split any further, so accept it as-is.
        if length <= budget or len(piece) <= 1:
            fitted.append((piece, length))
        else:
            middle = len(piece) // 2
            # Push the right half first so the left half is handled next,
            # which keeps the pieces in their original order.
            pending.append(piece[middle:])
            pending.append(piece[:middle])
    return fitted


def chunk_text(text, max_tokens=MAX_TOKENS):
    """Split ``text`` into space-joined chunks of at most ``max_tokens`` tokens.

    The text is split on whitespace and words are packed greedily into chunks.
    ``max_tokens`` is treated as the total size of the encoded chunk, so room
    is reserved for the special tokens the tokenizer adds around a single
    sequence. A word that is too long to fit on its own (for example a very
    long URL) is bisected into smaller pieces instead of producing an
    oversized chunk.

    Args:
        text: The text to split.
        max_tokens: Maximum number of tokens per chunk, special tokens included.

    Returns:
        A list of non-empty chunk strings; empty when ``text`` has no words.

    Raises:
        ValueError: If ``max_tokens`` leaves no room for content once the
            special tokens are reserved.
    """
    budget = max_tokens - tokenizer.num_special_tokens_to_add(pair=False)
    if budget < 1:
        raise ValueError(
            "max_tokens must be larger than the number of special tokens"
        )

    chunks = []
    current_chunk = []
    current_length = 0
    for word in text.split():
        # NOTE: lengths are summed per word/piece; this assumes tokenizing the
        # space-joined chunk yields the same tokens as tokenizing each part.
        for piece, piece_length in _split_to_fit(word, budget):
            # Only flush a non-empty chunk so no empty string is ever emitted.
            if current_chunk and current_length + piece_length > budget:
                chunks.append(" ".join(current_chunk))
                current_chunk = []
                current_length = 0
            current_chunk.append(piece)
            current_length += piece_length
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
def process_chunks(chunks):
    """Classify every chunk and combine the results by majority vote.

    Each chunk is passed to the module-level ``classifier``. A result whose
    label equals ``"phishing"`` (case-insensitive) counts as a phishing vote;
    any other label counts as a legitimate vote. The final label is
    ``"Phishing"`` only when phishing votes strictly outnumber legitimate
    votes, so ties resolve to ``"Legitimate"``. The reported average is the
    mean of the per-chunk scores regardless of which label each chunk received.

    Args:
        chunks: A list of text chunks to classify.

    Returns:
        A two-line string with the final prediction and the average
        confidence. For an empty list a placeholder result is returned
        instead of dividing by zero.
    """
    if not chunks:
        return "Prediction: Unknown\nAverage Confidence: N/A"

    phishing_count = 0
    total_score = 0
    for chunk in chunks:
        # Truncate defensively so a chunk that still exceeds the limit is
        # shortened rather than crashing the model.
        result = classifier(chunk, truncation=True, max_length=MAX_TOKENS)[0]
        total_score += result["score"]
        if result["label"].lower() == "phishing":
            phishing_count += 1

    legitimate_count = len(chunks) - phishing_count
    final_label = "Phishing" if phishing_count > legitimate_count else "Legitimate"
    average_confidence = total_score / len(chunks)
    return f"Prediction: {final_label}\nAverage Confidence: {average_confidence:.2%}"
def detect_phishing(input_text):
    """Classify an email as phishing or legitimate and format the result.

    Text whose token count is within ``MAX_TOKENS`` is classified in a single
    call; longer text is split with ``chunk_text`` and the per-chunk results
    are combined by ``process_chunks``.

    Args:
        input_text: The email content to analyze.

    Returns:
        A human-readable string with the prediction and its confidence, or a
        short notice when no text was provided.
    """
    # Guard: there is nothing meaningful to classify in missing or blank input.
    if input_text is None or not input_text.strip():
        return "No email content provided. Paste the email text to analyze."

    if count_tokens(input_text) <= MAX_TOKENS:
        result = classifier(input_text)[0]
        label = "Phishing" if result["label"].lower() == "phishing" else "Legitimate"
        return f"Prediction: {label}\nConfidence: {result['score']:.2%}"

    return process_chunks(chunk_text(input_text))
# Web UI: one multi-line text box in, the plain-text verdict from
# detect_phishing out.
demo = gr.Interface(
    title="Phishing Email Detector",
    description=(
        "Uses a fine-tuned BERT model to classify whether the email is "
        "phishing or legitimate. Handles long emails by chunking."
    ),
    fn=detect_phishing,
    inputs=gr.Textbox(placeholder="Paste email content here...", lines=8),
    outputs="text",
)

# Start serving the interface.
demo.launch()