# Hosting-page header captured with the file ("Spaces: Sleeping"); not part of the program.
import os, re, time, json, urllib.parse | |
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
# Optional dependency: when tldextract imports cleanly it is used for domain
# parsing; otherwise the name is bound to None and callers use a fallback.
try:
    import tldextract
except Exception:
    tldextract = None

# Default TOKENIZERS_PARALLELISM to "false" without overriding a value that is
# already present in the environment.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Hugging Face model identifier used for URL classification.
URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"

# Force readable labels regardless of model config
ID2LABEL = {
    0: "benign",
    1: "defacement",
    2: "malware",
    3: "phishing",
}

# Matches http(s):// or www. prefixed runs up to whitespace, quotes, angle
# brackets or parentheses (case-insensitive, verbose mode).
URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[^\s<>"'()]+""")

# Substrings looked for in the host and path of a URL.
KEYWORDS = {
    "phish",
    "login",
    "verify",
    "account",
    "secure",
    "update",
    "bank",
    "wallet",
    "password",
    "invoice",
    "pay",
    "reset",
    "support",
    "unlock",
    "confirm",
}

# Top-level domains that add to the heuristic risk score.
SUSPICIOUS_TLDS = {
    "zip",
    "mov",
    "lol",
    "xyz",
    "top",
    "country",
    "link",
    "click",
    "cam",
    "help",
    "gq",
    "cf",
    "tk",
    "work",
    "rest",
    "monster",
    "quest",
    "live",
    "io",
    "ly",
}

# Registered domains treated as URL shorteners.
URL_SHORTENERS = {
    "bit.ly",
    "tinyurl.com",
    "t.co",
    "goo.gl",
    "is.gd",
    "buff.ly",
    "ow.ly",
    "rebrand.ly",
    "cutt.ly",
}

# Lazily filled cache slots for the tokenizer and the model.
_tok = None
_mdl = None
# ---------- utils ---------- | |
def _extract_urls(text: str):
    """Return the distinct URLs matched by URL_RE in *text*, sorted.

    Each match is stripped of surrounding whitespace and of trailing
    punctuation/bracket characters that commonly follow a link in prose.
    ``None`` or empty input yields an empty list.
    """
    trailing_junk = ").,;:!?•]}>\"'"
    unique = {
        match.group(0).strip().rstrip(trailing_junk)
        for match in URL_RE.finditer(text or "")
    }
    return sorted(unique)
def _load_model():
    """Return the ``(tokenizer, model)`` pair, loading it on first use.

    Both objects are created from ``URL_MODEL_ID`` and stored in the
    module-level ``_tok`` / ``_mdl`` slots; later calls return the stored pair.
    The model is switched to eval mode right after loading.
    """
    global _tok, _mdl
    if _tok is None or _mdl is None:
        # Imported here so that transformers is only needed once a model is
        # actually requested.
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        _tok = AutoTokenizer.from_pretrained(URL_MODEL_ID)
        _mdl = AutoModelForSequenceClassification.from_pretrained(URL_MODEL_ID)
        _mdl.eval()
    return _tok, _mdl
def _softmax(logits: torch.Tensor): | |
return F.softmax(logits, dim=-1).tolist() | |
def _results_table(rows): | |
lines = [ | |
"| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Decision | Reasons |", | |
"|---|---|---:|---:|---:|:--:|---|", | |
] | |
for r in rows: | |
u, lbl, pct, h, fused, decision, reasons = r | |
lines.append( | |
f"| `{u}` | **{lbl}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {decision} | {reasons} |" | |
) | |
return "\n".join(lines) | |
def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated): | |
toks_prev = ", ".join(tokens[:64]) + (" …" if len(tokens) > 64 else "") | |
ids_prev = ", ".join(map(str, token_ids[:64])) + (" …" if len(token_ids) > 64 else "") | |
cls_dim = len(cls_vec) | |
cls_prev = ", ".join(f"{v:.4f}" for v in cls_vec[:16]) + (" …" if cls_dim > 16 else "") | |
l2 = (sum(v*v for v in cls_vec)) ** 0.5 | |
md = [] | |
md.append(f"### 🔍 Forensics for `{url}`\n") | |
md.append(f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**") | |
md.append(f"- inference time: **{elapsed_s:.2f}s**\n") | |
md.append("**Top-k scores**") | |
md.append("| Class | Prob (%) | Logit |\n|---|---:|---:|") | |
for s in scores_sorted: | |
md.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |") | |
md.append("\n**Token IDs (preview)**") | |
md.append("```txt\n" + ids_prev + "\n```") | |
md.append("**Tokens (preview)**") | |
md.append("```txt\n" + toks_prev + "\n```") | |
md.append("**[CLS] embedding (preview)**") | |
md.append(f"`dim={cls_dim}`, `L2={l2:.4f}`") | |
md.append("```txt\n" + cls_prev + "\n```") | |
return "\n".join(md) | |
# ---------- heuristics ---------- | |
def _safe_parse(url: str): | |
if not re.match(r"^https?://", url, re.I): | |
url = "http://" + url | |
return urllib.parse.urlparse(url) | |
def _split_reg_domain(host: str): | |
parts = host.split(".") | |
if len(parts) >= 2: | |
return parts[-2] + "." + parts[-1] | |
return host | |
def _domain_parts(host: str):
    """Split *host* into ``(registered_domain, subdomain, core, tld)``.

    Uses ``tldextract`` when that module is available; otherwise derives the
    parts from ``_split_reg_domain`` (last two labels). Missing parts are
    returned as empty strings.
    """
    if not tldextract:
        # Fallback: treat the last two labels as the registered domain.
        regdom = _split_reg_domain(host)
        dotted = "." in regdom
        tld = regdom.split(".")[-1] if dotted else ""
        core = regdom.split(".")[0] if dotted else regdom
        sub = host[:-len(regdom)].rstrip(".") if host.endswith(regdom) else ""
        return regdom, sub, core, tld

    ext = tldextract.extract(host)  # subdomain, domain, suffix
    if ext.domain and ext.suffix:
        regdom = f"{ext.domain}.{ext.suffix}"
    else:
        regdom = host
    return regdom, ext.subdomain or "", ext.domain or "", ext.suffix or ""
def heuristic_features(u: str):
    """Extract the lexical features of URL *u* used by the heuristic scorer.

    Returns:
        A dict with the keys ``host``, ``path``, ``query``,
        ``registered_domain``, ``subdomain``, ``tld``, ``labels``, ``has_at``,
        ``has_port``, ``has_punycode``, ``len_url``, ``hyphen_in_regdom``,
        ``kw_in_path``, ``kw_in_host``, ``kw_in_subdomain_only``,
        ``suspicious_tld``, ``is_shortener``, ``query_ratio_alnum`` and
        ``parse_error`` (False). If anything raises while the features are
        being computed, ``{"parse_error": True}`` is returned instead.
    """
    feats = {}
    try:
        p = _safe_parse(u)
        host = p.hostname or ""
        path = p.path or "/"
        query = p.query or ""
        regdom, sub, core, tld = _domain_parts(host)

        feats["host"] = host
        feats["path"] = path
        feats["query"] = query
        feats["registered_domain"] = regdom
        feats["subdomain"] = sub
        feats["tld"] = tld
        # Number of dot-separated labels in the host (0 for an empty host).
        feats["labels"] = host.count(".") + (1 if host else 0)
        feats["has_at"] = "@" in u
        # Drop any userinfo and any bracketed IPv6 literal first, so that the
        # colons inside "[::1]" are not mistaken for a port separator.
        hostport = p.netloc.split("@")[-1].rsplit("]", 1)[-1]
        feats["has_port"] = ":" in hostport
        feats["has_punycode"] = "xn--" in host
        feats["len_url"] = len(u)
        feats["hyphen_in_regdom"] = "-" in (core or "")

        low_host = host.lower()
        low_path = path.lower()
        low_core = (core or "").lower()
        feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
        feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
        # Explicit bool() on every operand: an empty `core` must give 0 here
        # rather than reach int("") and raise ValueError.
        feats["kw_in_subdomain_only"] = int(
            bool(feats["kw_in_host"])
            and bool(core)
            and not any(k in low_core for k in KEYWORDS)
        )
        # Only the last label of a multi-part suffix is checked.
        feats["suspicious_tld"] = int((tld.split(".")[-1] or "") in SUSPICIOUS_TLDS)
        feats["is_shortener"] = int(regdom.lower() in URL_SHORTENERS)

        alnum = sum(c.isalnum() for c in query)
        feats["query_ratio_alnum"] = (alnum / max(1, len(query))) if query else 0.0
        feats["parse_error"] = False
    except Exception:
        # Deliberate best-effort: any failure is reported as a parse error so
        # the caller can flag the URL instead of crashing.
        feats = {"parse_error": True}
    return feats
def heuristic_score(feats: dict) -> float:
    """Combine the heuristic features into a risk score in ``[0.0, 1.0]``.

    A feature dict marked ``parse_error`` scores a fixed 0.80. Otherwise each
    triggered signal adds its weight, a long and almost fully alphanumeric
    query string adds 0.10, and the total is clamped to the unit interval.
    """
    if feats.get("parse_error"):
        return 0.80

    # (weight, signal) pairs; signals are 0/1 ints or booleans.
    weighted_signals = (
        (0.28, feats["kw_in_path"]),
        (0.24, feats["kw_in_subdomain_only"]),
        (0.10, feats["kw_in_host"]),
        (0.12, feats["hyphen_in_regdom"]),
        (0.10, feats["labels"] >= 4),
        (0.10, feats["has_punycode"]),
        (0.12, feats["suspicious_tld"]),
        (0.10, feats["is_shortener"]),
        (0.05, feats["has_at"]),
        (0.05, feats["has_port"]),
        (0.10, feats["len_url"] >= 100),
    )
    total = 0.0
    for weight, signal in weighted_signals:
        total += weight * signal

    query = feats.get("query")
    if query and len(query) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9:
        total += 0.10

    return max(0.0, min(1.0, total))
def heuristic_reasons(feats: dict) -> str:
    """Describe the triggered heuristics as a comma-separated string.

    Returns ``"parse error"`` for a feature dict marked ``parse_error`` and
    ``"no heuristic triggers"`` when nothing fired.
    """
    if feats.get("parse_error"):
        return "parse error"

    flag = feats.get
    query = flag("query")
    # (condition, message) pairs, listed in the order they are reported.
    checks = (
        (flag("is_shortener"), "URL shortener"),
        (flag("kw_in_path"), "keyword in path"),
        (flag("kw_in_subdomain_only"), "keyword in subdomain"),
        (flag("kw_in_host") and not flag("kw_in_subdomain_only"), "keyword in host"),
        (flag("hyphen_in_regdom"), "hyphen in registered domain"),
        (flag("labels", 0) >= 4, "deep subdomain nesting"),
        (flag("has_punycode"), "punycode host"),
        (flag("suspicious_tld"), f"suspicious TLD: {flag('tld')}"),
        (flag("has_at"), "@ in URL"),
        (flag("has_port"), "explicit port"),
        (flag("len_url", 0) >= 100, "very long URL"),
        (
            query and len(query) >= 40 and flag("query_ratio_alnum", 0) > 0.9,
            "long query blob",
        ),
    )
    triggered = [message for hit, message in checks if hit]
    if not triggered:
        return "no heuristic triggers"
    return ", ".join(triggered)
def heuristic_hard_flag(feats: dict) -> "tuple[bool, str]":
    """Decide whether the features alone are enough to flag a URL.

    The return annotation is quoted so it is never evaluated when the
    function is defined.

    Returns:
        ``(True, reason)`` for the first matching rule, checked in this order:
        unparsable URL; keyword in subdomain together with keyword in path;
        then, when a keyword appears in the host or the path, URL shortener,
        suspicious TLD, or four or more host labels. ``(False, "")`` when no
        rule matches.
    """
    if feats.get("parse_error"):
        return True, "unparsable URL"

    kw_in_path = feats.get("kw_in_path")
    if feats.get("kw_in_subdomain_only") and kw_in_path:
        return True, "keyword in subdomain + keyword in path"

    # The remaining rules all require a keyword somewhere in host or path.
    if feats.get("kw_in_host") or kw_in_path:
        keyword_combos = (
            (feats.get("is_shortener"), "URL shortener + keyword"),
            (feats.get("suspicious_tld"), "suspicious TLD + keyword"),
            (feats.get("labels", 0) >= 4, "deep subdomain nesting + keyword"),
        )
        for triggered, reason in keyword_combos:
            if triggered:
                return True, reason

    return False, ""
# ---------- core ---------- | |
def _parse_allowlist(s: str): | |
items = re.split(r"[,\s]+", (s or "").strip()) | |
return {x.strip().lower() for x in items if x.strip()} | |
def analyze(
    text: str,
    forensic: bool,
    show_json: bool,
    threshold: float,
    allowlist_txt: str,
    allowlist_override: bool
):
    """
    Score every URL found in *text* and return one Markdown report:
    - verdict + table (model, heuristic, fused + decision + reasons)
    - optional forensic blocks
    - optional raw JSON

    Args:
        text: a bare URL, or free text (e.g. an email body) to scan for URLs.
        forensic: append a per-URL forensic block (tokens, logits, [CLS]).
        show_json: append the collected data as a JSON code block.
        threshold: fused-risk value at or above which a URL is UNSAFE.
        allowlist_txt: registered domains, comma/whitespace separated.
        allowlist_override: when true, a URL whose registered domain is in
            the allowlist is reported SAFE regardless of its scores.

    Returns:
        The Markdown report, or a short hint when *text* is empty or
        contains no URL.
    """
    text = (text or "").strip()
    if not text:
        return "Paste an email body or a URL."

    # The input is used verbatim as one URL only when it is a single token.
    # Checking for any whitespace (not just " ") keeps a newline- or
    # tab-separated list of links from being scored as one giant URL.
    looks_like_url = text.lower().startswith(("http://", "https://", "www."))
    is_single_token = len(text.split()) == 1
    urls = [text] if (looks_like_url and is_single_token) else _extract_urls(text)
    if not urls:
        return "No URLs detected in the text."

    allowset = _parse_allowlist(allowlist_txt)
    tok, mdl = _load_model()
    # Same for every URL, so computed once: cap inputs at 512 tokens or the
    # model's position limit, whichever is smaller.
    max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)

    rows = []
    forensic_blocks = []
    export_data = {"model_id": URL_MODEL_ID, "items": []}
    any_unsafe = False

    for u in urls:
        # model forward
        enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
        token_ids = enc["input_ids"][0].tolist()
        tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
        # Reported as truncated whenever the encoded length reaches max_len.
        truncated = enc["input_ids"].shape[1] >= max_len and len(tokens) >= max_len

        # perf_counter is the clock intended for measuring short intervals.
        t0 = time.perf_counter()
        with torch.no_grad():
            out = mdl(**enc, output_hidden_states=True)
        elapsed = time.perf_counter() - t0

        logits = out.logits.squeeze(0)
        probs = _softmax(logits)
        scores = [
            {"label": ID2LABEL[i], "prob": float(p), "logit": float(logits[i])}
            for i, p in enumerate(probs)
        ]
        scores_sorted = sorted(scores, key=lambda x: x["prob"], reverse=True)
        top = scores_sorted[0]

        # heuristics
        feats = heuristic_features(u)
        regdom = feats.get("registered_domain", "").lower()
        h_flag, h_reason = heuristic_hard_flag(feats)
        h_score = heuristic_score(feats)

        # Model risk = total probability of the non-benign classes; it is
        # averaged with the heuristic score.
        mdl_phish_like = sum(
            s["prob"] for s in scores_sorted if s["label"] in {"phishing", "malware", "defacement"}
        )
        fused = 0.50 * mdl_phish_like + 0.50 * h_score

        # allowlist override (domain-based)
        allow_hit = regdom in allowset if regdom else False
        reasons = (h_reason + (", " if h_reason else "") + heuristic_reasons(feats)).strip(", ")
        if allow_hit and allowlist_override:
            decision = "✅ SAFE"
            reasons = f"allowlisted domain ({regdom})"
            fused = min(fused, 0.01)  # clamp down the risk for display
        else:
            decision = "🛑 UNSAFE" if (h_flag or fused >= float(threshold)) else "✅ SAFE"
        if decision.startswith("🛑"):
            any_unsafe = True

        rows.append([u, top["label"], top["prob"] * 100.0, h_score, fused, decision, reasons])

        # export + forensics
        hidden_states = out.hidden_states
        # First position of the last hidden layer for the single input.
        cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
        export_data["items"].append({
            "url": u,
            "token_ids": token_ids,
            "tokens": tokens,
            "truncated": truncated,
            "logits": [float(x) for x in logits.cpu().tolist()],
            "probs": [float(p) for p in probs],
            "scores_sorted": scores_sorted,
            "cls_vector": cls_vec,
            "cls_dim": len(cls_vec),
            "elapsed_sec": elapsed,
            "heuristic": feats,
            "heuristic_score": h_score,
            "fused_risk": fused,
            "hard_flag": h_flag,
            "hard_reason": h_reason,
            "allowlisted": allow_hit,
        })
        if forensic:
            forensic_blocks.append(
                _forensic_block(u, token_ids, tokens, scores_sorted, cls_vec, elapsed, truncated)
            )

    verdict = "🔴 **UNSAFE (at least one link flagged)**" if any_unsafe else "🟢 **SAFE (no link over threshold)**"
    body = verdict + "\n\n" + _results_table(rows)
    if forensic and forensic_blocks:
        body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
    if show_json:
        pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
        body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
        body += "```json\n" + pretty + "\n```"
    return body
# ---------- UI ---------- | |
# Gradio front end: six inputs, passed positionally to analyze() in the order
# listed, and one Markdown output.
demo = gr.Interface(
    title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
    description=(
        "Extract links, score with a tiny HF URL model and transparent heuristics. "
        "Short-circuits for classic phishing patterns. Adjust the threshold, and allowlist trusted domains."
    ),
    fn=analyze,
    inputs=[
        # text
        gr.Textbox(
            label="Email or URL",
            placeholder="Paste a URL or a full email…",
            lines=10,
        ),
        # forensic
        gr.Checkbox(value=True, label="Forensic mode (tokens, logits, [CLS])"),
        # show_json
        gr.Checkbox(value=False, label="Show raw JSON at the end (copy/paste)"),
        # threshold
        gr.Slider(
            0.0,
            1.0,
            value=0.40,
            step=0.01,
            label="Decision threshold (fused risk ≥ threshold → UNSAFE)",
        ),
        # allowlist_txt
        gr.Textbox(
            label="Allowlist (domains, comma/space/newline separated)",
            placeholder="example.com, github.com microsoft.com",
            lines=2,
        ),
        # allowlist_override
        gr.Checkbox(
            value=True,
            label="Allowlist overrides (force SAFE if registered domain matches)",
        ),
    ],
    outputs=gr.Markdown(label="Results"),
)

if __name__ == "__main__":
    # Listen on all interfaces, port 7860, and surface errors in the UI.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)