Spaces:
Sleeping
Sleeping
import gradio as gr | |
import joblib | |
import json | |
import numpy as np | |
import re | |
from urllib.parse import urlparse | |
import os | |
from huggingface_hub import hf_hub_download | |
# Hugging Face Hub coordinates of the published model.
# MODEL_NAME is a display label; the download code uses literal file names.
MODEL_NAME = "XGBoost"
HF_USERNAME = "Devishetty100"
CUSTOM_MODEL_NAME = "NeoGuardianAI"
# Hub repository ids are "<owner>/<repo>"; the repo name is the lower-cased model name.
REPO_ID = "/".join((HF_USERNAME, CUSTOM_MODEL_NAME.lower()))
# Allow-list of domains that are always reported as safe. A URL whose host
# equals one of these, or is a subdomain of one, is answered without
# consulting the model. Order is preserved from the original list.
TRUSTED_DOMAINS = """
    huggingface.co  github.com    google.com    microsoft.com
    apple.com       amazon.com    facebook.com  twitter.com
    linkedin.com    youtube.com   wikipedia.org
""".split()
# Load model files (either from the Hugging Face Hub or from local files)

# Artifact file names, identical on the Hub and in the local fallback.
_MODEL_FILE = "xgboost_model.joblib"
_SCALER_FILE = "scaler.joblib"
_FEATURE_NAMES_FILE = "feature_names.json"


def _load_components(model_path, scaler_path, feature_names_path):
    """Deserialize the classifier, the scaler and the feature-name list.

    Returns:
        (model, scaler, feature_names) tuple.
    """
    # NOTE(security): joblib.load unpickles arbitrary Python objects, so these
    # files must come from a repository / directory you trust.
    model = joblib.load(model_path)
    scaler = joblib.load(scaler_path)
    with open(feature_names_path, "r", encoding="utf-8") as f:
        feature_names = json.load(f)
    return model, scaler, feature_names


def load_model_files():
    """Load the model, scaler and feature names.

    First tries to download the artifacts from the Hugging Face Hub repo
    ``REPO_ID``. If that fails for any reason, it falls back to files with the
    same names in the current working directory.

    Returns:
        (model, scaler, feature_names) tuple.

    Raises:
        RuntimeError: if neither the Hub nor the local files could be loaded.
            The local-load error is chained as ``__cause__``.
    """
    try:
        print(f"Attempting to download model from Hugging Face Hub: {REPO_ID}")
        # Best-effort diagnostic only: listing failures must not abort the download.
        try:
            from huggingface_hub import list_repo_files
            files = list_repo_files(repo_id=REPO_ID)
            print(f"Files available in the repository: {files}")
        except Exception as list_error:
            print(f"Error listing repository files: {list_error}")
        model_path = hf_hub_download(repo_id=REPO_ID, filename=_MODEL_FILE)
        print(f"Downloaded model file to: {model_path}")
        scaler_path = hf_hub_download(repo_id=REPO_ID, filename=_SCALER_FILE)
        feature_names_path = hf_hub_download(repo_id=REPO_ID, filename=_FEATURE_NAMES_FILE)
        components = _load_components(model_path, scaler_path, feature_names_path)
        print("Successfully downloaded model from Hugging Face Hub.")
        return components
    except Exception as hub_error:
        print(f"Error downloading from Hugging Face Hub: {hub_error}")

    # Hub path failed: fall back to artifacts in the working directory.
    try:
        print("Attempting to load model from local files...")
        components = _load_components(_MODEL_FILE, _SCALER_FILE, _FEATURE_NAMES_FILE)
    except Exception as local_error:
        print(f"Error loading from local files: {local_error}")
        raise RuntimeError(
            "Failed to load model from both Hugging Face Hub and local files."
        ) from local_error
    print("Successfully loaded model from local files.")
    return components
# Extract features from URL

# Matches an explicit "scheme://" prefix (RFC 3986 scheme characters).
_SCHEME_RE = re.compile(r'^[A-Za-z][A-Za-z0-9+.-]*://')

# (feature name, character) pairs counted over the whole URL.
_CHAR_COUNT_FEATURES = (
    ('nb_dots', '.'), ('nb_hyphens', '-'), ('nb_at', '@'), ('nb_qm', '?'),
    ('nb_and', '&'), ('nb_or', '|'), ('nb_eq', '='), ('nb_underscore', '_'),
    ('nb_tilde', '~'), ('nb_percent', '%'), ('nb_slash', '/'), ('nb_star', '*'),
    ('nb_colon', ':'), ('nb_comma', ','), ('nb_semicolumn', ';'),
    ('nb_dollar', '$'), ('nb_space', ' '),
)

_COMMON_TLDS = ('.com', '.org', '.net', '.edu', '.gov', '.mil', '.int')
# The same TLDs as bare labels ("com"), for comparing against host labels.
_TLD_LABELS = frozenset(tld.lstrip('.') for tld in _COMMON_TLDS)

# Known URL-shortening hosts (deduplicated).
_SHORTENING_SERVICES = frozenset([
    'bit.ly', 'goo.gl', 'tinyurl.com', 't.co', 'tr.im', 'is.gd', 'cli.gs',
    'ow.ly', 'yfrog.com', 'migre.me', 'ff.im', 'tiny.cc', 'url4.eu', 'twit.ac',
    'su.pr', 'twurl.nl', 'snipurl.com', 'short.to', 'budurl.com', 'ping.fm',
    'post.ly', 'just.as', 'bkite.com', 'snipr.com', 'fic.kr', 'loopt.us',
    'doiop.com', 'twitthis.com', 'htxt.it', 'ak.im', 'shar.es', 'kl.am',
    'wp.me', 'rubyurl.com', 'om.ly', 'to.ly', 'bit.do', 'lnkd.in', 'db.tt',
    'qr.ae', 'adf.ly', 'bitly.com', 'cur.lv', 'ity.im', 'q.gs', 'po.st',
    'bc.vc', 'u.to', 'j.mp', 'buzurl.com', 'cutt.us', 'u.bb', 'yourls.org',
    'x.co', 'prettylinkpro.com', 'scrnch.me', 'filoops.info', 'vzturl.com',
    'qr.net', '1url.com', 'tweez.me', 'v.gd', 'link.zip.net',
])

# Feature names that need page content, WHOIS or traffic data, which this
# URL-only extractor cannot compute. They are filled with 0.
_UNCOMPUTED_FEATURES = (
    'nb_redirection', 'nb_external_redirection', 'length_words_raw',
    'char_repeat', 'shortest_words_raw', 'shortest_word_host',
    'shortest_word_path', 'longest_words_raw', 'longest_word_host',
    'longest_word_path', 'avg_words_raw', 'avg_word_host',
    'avg_word_path', 'phish_hints', 'domain_in_brand',
    'brand_in_subdomain', 'brand_in_path', 'suspecious_tld',
    'statistical_report', 'nb_hyperlinks', 'ratio_intHyperlinks',
    'ratio_extHyperlinks', 'ratio_nullHyperlinks', 'nb_extCSS',
    'ratio_intRedirection', 'ratio_extRedirection', 'ratio_intErrors',
    'ratio_extErrors', 'login_form', 'external_favicon',
    'links_in_tags', 'submit_email', 'ratio_intMedia',
    'ratio_extMedia', 'sfh', 'iframe', 'popup_window',
    'safe_anchor', 'onmouseover', 'right_clic', 'empty_title',
    'domain_in_title', 'domain_with_copyright', 'whois_registered_domain',
    'domain_registration_length', 'domain_age', 'web_traffic',
    'dns_record', 'google_index', 'page_rank',
)


def extract_features(url):
    """Extract lexical features from a URL for model prediction.

    Only features computable from the URL string are derived. Every name in
    ``_UNCOMPUTED_FEATURES`` is set to 0.

    Args:
        url: URL text. A missing scheme (e.g. ``"example.com/login"``) is
            tolerated: the URL is parsed as if it started with ``//`` so the
            host is recognised instead of being read as a path.

    Returns:
        dict mapping feature name to an int or float.
    """
    features = {}

    # Length is always measured on the URL exactly as given.
    features['length_url'] = len(url)

    # urlparse puts a scheme-less "example.com/x" entirely into .path, which
    # would leave every host feature empty; "//" forces a network location.
    parsed_url = urlparse(url if _SCHEME_RE.match(url) else '//' + url)
    hostname = parsed_url.netloc  # raw netloc: may include userinfo and port
    path = parsed_url.path
    # Lower-cased host without userinfo/port, used for domain-name matching.
    host = parsed_url.hostname or ''
    host_labels = host.split('.')

    # Hostname features
    features['length_hostname'] = len(hostname)
    features['ip'] = 1 if re.match(r'\d+\.\d+\.\d+\.\d+', hostname) else 0

    # Count special characters
    for name, char in _CHAR_COUNT_FEATURES:
        features[name] = url.count(char)

    # Other URL features
    features['nb_www'] = 1 if 'www' in hostname else 0
    features['nb_com'] = 1 if '.com' in hostname else 0
    features['nb_dslash'] = url.count('//')
    features['http_in_path'] = 1 if 'http' in path else 0
    features['https_token'] = 1 if 'https' in url and 'http://' not in url else 0

    # Ratio features (guarded against empty strings)
    features['ratio_digits_url'] = sum(c.isdigit() for c in url) / len(url) if url else 0
    features['ratio_digits_host'] = (
        sum(c.isdigit() for c in hostname) / len(hostname) if hostname else 0
    )

    # Punycode
    features['punycode'] = 1 if 'xn--' in hostname else 0

    # Port: inspect only the "host[:port]" part, so digits in a
    # "user:password@" prefix or inside an IPv6 literal "[...]" are ignored.
    host_port = hostname.rpartition('@')[2].rpartition(']')[2]
    _, colon, port_text = host_port.partition(':')
    features['port'] = 1 if colon and any(c.isdigit() for c in port_text) else 0

    # TLD features
    features['tld_in_path'] = 1 if any(tld in path for tld in _COMMON_TLDS) else 0
    # A TLD label appearing in the subdomain part, e.g. "paypal.com.evil.net".
    # host_labels[:-2] is non-empty only when the host has more than one dot.
    features['tld_in_subdomain'] = (
        1 if any(label in _TLD_LABELS for label in host_labels[:-2]) else 0
    )

    # Subdomain features
    features['abnormal_subdomain'] = 1 if hostname.count('.') > 2 else 0
    features['nb_subdomains'] = hostname.count('.')

    # Other suspicious features
    features['prefix_suffix'] = 1 if '-' in hostname else 0
    features['random_domain'] = (
        1 if len(hostname) > 12 and sum(c.isdigit() for c in hostname) > 4 else 0
    )

    # Shortening service: the host must be the service or one of its
    # subdomains. Raw substring matching would flag "microsoft.com" via "t.co".
    features['shortening_service'] = 1 if any(
        '.'.join(host_labels[i:]) in _SHORTENING_SERVICES
        for i in range(len(host_labels))
    ) else 0

    # Path features
    features['path_extension'] = 1 if '.' in path.split('/')[-1] else 0

    # Remaining model features default to 0.
    for feature in _UNCOMPUTED_FEATURES:
        features.setdefault(feature, 0)
    return features
# Load model and components


class _UnavailableModel:
    """Stand-in classifier used when the real model could not be loaded.

    Every prediction raises, so a caller that catches exceptions reports an
    error. A fabricated "Legitimate" verdict could be mistaken by a user for
    a real safety assessment.
    """

    def __init__(self, reason):
        self.reason = reason

    def predict(self, features):
        raise RuntimeError(f"Model unavailable: {self.reason}")

    def predict_proba(self, features):
        raise RuntimeError(f"Model unavailable: {self.reason}")


class _IdentityScaler:
    """Pass-through scaler paired with _UnavailableModel."""

    def transform(self, features):
        return features


try:
    model, scaler, feature_names = load_model_files()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # Keep the app importable and the UI up, but make every prediction fail
    # loudly rather than label every URL as safe.
    print("Model unavailable; predictions will return an error.")
    model = _UnavailableModel(e)
    scaler = _IdentityScaler()
    feature_names = ['length_url', 'length_hostname']
def predict_url(url):
    """Predict whether a URL is phishing or legitimate.

    Hosts on (or under) ``TRUSTED_DOMAINS`` are reported safe without
    consulting the model. Any exception during prediction is turned into an
    error tuple instead of propagating.

    Args:
        url: URL text entered by the user. Surrounding whitespace is ignored,
            and a missing scheme (e.g. ``"google.com"``) is accepted.

    Returns:
        ``(prediction_text, confidence, status)``, matching the three Gradio
        output components. ``confidence`` is the probability of the predicted
        class.
    """
    if not url or not url.strip():
        return "Please enter a URL", 0.0, "N/A"
    url = url.strip()
    try:
        # Without "//", urlparse reads "google.com" as a path and the host is
        # empty; prefix it so scheme-less input still yields a host.
        to_parse = url if re.match(r'^[A-Za-z][A-Za-z0-9+.-]*://', url) else '//' + url
        # .hostname is lower-cased and excludes userinfo/port, so
        # "https://google.com@evil.com" is judged by "evil.com" and
        # "https://GOOGLE.COM:443" by "google.com".
        domain = (urlparse(to_parse).hostname or '').rstrip('.')
        if domain.startswith('www.'):
            domain = domain[4:]

        # Trusted if the host or any parent domain (excluding the bare TLD)
        # is on the allow-list.
        domain_parts = domain.split('.')
        is_trusted = any(
            '.'.join(domain_parts[i:]) in TRUSTED_DOMAINS
            for i in range(len(domain_parts) - 1)
        )
        if is_trusted:
            return "Legitimate (Trusted Domain)", 1.0, "✅ SAFE"

        # Order features as the model expects; missing ones default to 0.
        url_features = extract_features(url)
        features_array = [url_features.get(name, 0) for name in feature_names]

        scaled_features = scaler.transform([features_array])
        prediction = model.predict(scaled_features)[0]
        probability = model.predict_proba(scaled_features)[0][1]

        is_phishing = prediction == 1
        prediction_text = "Phishing" if is_phishing else "Legitimate"
        confidence = float(probability) if is_phishing else float(1 - probability)
        status = "⚠️ UNSAFE" if is_phishing else "✅ SAFE"
        return prediction_text, confidence, status
    except Exception as e:
        # UI boundary: surface the error in the output boxes rather than crash.
        error_msg = f"Error: {str(e)}"
        return error_msg, 0.0, "Error"
# Create Gradio interface

_HEADER_MARKDOWN = """
# NeoGuardianAI - URL Phishing Detection
This app uses a machine learning model to detect if a URL is legitimate or phishing.
Enter a URL below to check if it's safe or potentially malicious.
"""

# The embedded example must be valid, correctly indented Python because
# users copy it verbatim.
_FOOTER_MARKDOWN = """
## How it works

This model was trained on the [pirocheto/phishing-url](https://huggingface.co/datasets/pirocheto/phishing-url) dataset from Hugging Face.
The model extracts various features from the URL and uses a machine learning algorithm to classify it as legitimate or phishing.

**Note**: While this model is highly accurate, it's not perfect. Always exercise caution when visiting unfamiliar websites.

## API Usage

You can also use this model via the Hugging Face Inference API:

```python
import requests

API_URL = "https://api-inference.huggingface.co/models/Devishetty100/neoguardianai"
headers = {"Authorization": "Bearer YOUR_API_TOKEN"}

def query(url):
    payload = {"inputs": url}
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

# Example
result = query("https://example.com")
print(result)
```
"""


def create_interface():
    """Build the NeoGuardianAI Gradio UI.

    The "Check URL" button and pressing Enter in the URL box both run
    ``predict_url`` and fill the Prediction, Confidence and Status boxes.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="NeoGuardianAI - URL Phishing Detection", theme=gr.themes.Soft()) as demo:
        gr.Markdown(_HEADER_MARKDOWN)
        with gr.Row():
            url_input = gr.Textbox(label="Enter URL", placeholder="https://example.com")
            submit_btn = gr.Button("Check URL", variant="primary")
        with gr.Row():
            status_output = gr.Textbox(label="Status")
            prediction_output = gr.Textbox(label="Prediction")
            confidence_output = gr.Textbox(label="Confidence")

        # Order must match predict_url's return tuple:
        # (prediction_text, confidence, status).
        outputs = [prediction_output, confidence_output, status_output]
        submit_btn.click(fn=predict_url, inputs=url_input, outputs=outputs)
        # Enter in the textbox triggers the same check as the button.
        url_input.submit(fn=predict_url, inputs=url_input, outputs=outputs)

        gr.Markdown(_FOOTER_MARKDOWN)
    return demo
# Script entry point: build the UI and start the Gradio server only when this
# file is executed directly. The app object is kept in the module-level name
# `demo`.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()