Spaces:
Sleeping
Sleeping
File size: 14,084 Bytes
9256d25 2481aa9 9256d25 d4809b8 2481aa9 9256d25 2481aa9 9256d25 d4809b8 9256d25 2481aa9 d4809b8 2481aa9 9256d25 2481aa9 9256d25 2481aa9 9256d25 2481aa9 9256d25 2481aa9 9256d25 d4809b8 9256d25 d4809b8 9256d25 2481aa9 9256d25 2481aa9 9256d25 2481aa9 9256d25 2481aa9 9256d25 2481aa9 9256d25 2481aa9 9256d25 2481aa9 9256d25 2481aa9 9256d25 2481aa9 9256d25 6175992 9256d25 2481aa9 d4809b8 9256d25 d4809b8 9256d25 d4809b8 9256d25 d4809b8 9256d25 d4809b8 9256d25 d4809b8 9256d25 d4809b8 9256d25 d4809b8 9256d25 d4809b8 9256d25 2481aa9 9256d25 d4809b8 2481aa9 d4809b8 9256d25 d4809b8 9256d25 d4809b8 9256d25 2481aa9 d4809b8 9256d25 d4809b8 9256d25 2481aa9 9256d25 d4809b8 9256d25 2481aa9 d4809b8 9256d25 2481aa9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 |
import os, re, time, json, urllib.parse
import gradio as gr
import torch
import torch.nn.functional as F
# Optional robust domain parsing; code falls back if missing.
try:
import tldextract
except Exception:
tldextract = None
# Default TOKENIZERS_PARALLELISM to "false"; setdefault leaves any value
# already present in the environment untouched.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Identifier handed to both from_pretrained() calls in _load_model().
URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
# Force readable labels regardless of model config
# (analyze() looks up each logit index in this mapping).
ID2LABEL = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}
# Verbose, case-insensitive pattern: a run that starts with http://, https://
# or www. and continues until whitespace or one of < > " ' ( ).
URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[^\s<>"'()]+""")
# Substrings that heuristic_features() searches for in the host and the path.
KEYWORDS = {
"phish","login","verify","account","secure","update","bank","wallet",
"password","invoice","pay","reset","support","unlock","confirm"
}
# heuristic_features() compares the last dot-separated label of the TLD/suffix
# against this set.
SUSPICIOUS_TLDS = {
"zip","mov","lol","xyz","top","country","link","click","cam","help",
"gq","cf","tk","work","rest","monster","quest","live","io","ly"
}
# heuristic_features() compares the lower-cased registered domain against
# this set.
URL_SHORTENERS = {
"bit.ly","tinyurl.com","t.co","goo.gl","is.gd","buff.ly","ow.ly","rebrand.ly","cutt.ly"
}
# Tokenizer / model singletons; both stay None until _load_model() fills them.
_tok = None
_mdl = None
# ---------- utils ----------
def _extract_urls(text: str):
    """Return the distinct URLs found in *text*, sorted lexicographically.

    Every URL_RE match is whitespace-stripped and then has trailing
    punctuation/closing-bracket characters removed. ``None`` or an empty
    string yields an empty list.
    """
    trailing_chars = ").,;:!?•]}>\"'"
    unique_urls = {
        match.group(0).strip().rstrip(trailing_chars)
        for match in URL_RE.finditer(text or "")
    }
    return sorted(unique_urls)
def _load_model():
    """Return the ``(tokenizer, model)`` pair, loading it on first use.

    The pair is stored in the module globals ``_tok`` / ``_mdl`` so later
    calls reuse it. ``transformers`` is imported only when a load is needed.
    """
    global _tok, _mdl
    if _tok is None or _mdl is None:
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        _tok = AutoTokenizer.from_pretrained(URL_MODEL_ID)
        _mdl = AutoModelForSequenceClassification.from_pretrained(URL_MODEL_ID)
        _mdl.eval()
    return _tok, _mdl
def _softmax(logits: torch.Tensor):
return F.softmax(logits, dim=-1).tolist()
def _results_table(rows):
lines = [
"| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Decision | Reasons |",
"|---|---|---:|---:|---:|:--:|---|",
]
for r in rows:
u, lbl, pct, h, fused, decision, reasons = r
lines.append(
f"| `{u}` | **{lbl}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {decision} | {reasons} |"
)
return "\n".join(lines)
def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
toks_prev = ", ".join(tokens[:64]) + (" …" if len(tokens) > 64 else "")
ids_prev = ", ".join(map(str, token_ids[:64])) + (" …" if len(token_ids) > 64 else "")
cls_dim = len(cls_vec)
cls_prev = ", ".join(f"{v:.4f}" for v in cls_vec[:16]) + (" …" if cls_dim > 16 else "")
l2 = (sum(v*v for v in cls_vec)) ** 0.5
md = []
md.append(f"### 🔍 Forensics for `{url}`\n")
md.append(f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**")
md.append(f"- inference time: **{elapsed_s:.2f}s**\n")
md.append("**Top-k scores**")
md.append("| Class | Prob (%) | Logit |\n|---|---:|---:|")
for s in scores_sorted:
md.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |")
md.append("\n**Token IDs (preview)**")
md.append("```txt\n" + ids_prev + "\n```")
md.append("**Tokens (preview)**")
md.append("```txt\n" + toks_prev + "\n```")
md.append("**[CLS] embedding (preview)**")
md.append(f"`dim={cls_dim}`, `L2={l2:.4f}`")
md.append("```txt\n" + cls_prev + "\n```")
return "\n".join(md)
# ---------- heuristics ----------
def _safe_parse(url: str):
if not re.match(r"^https?://", url, re.I):
url = "http://" + url
return urllib.parse.urlparse(url)
def _split_reg_domain(host: str):
parts = host.split(".")
if len(parts) >= 2:
return parts[-2] + "." + parts[-1]
return host
def _domain_parts(host: str):
    """Split *host* into ``(registered_domain, subdomain, core, tld)``.

    When ``tldextract`` was imported successfully it does the split;
    otherwise the last two labels of the host are treated as the registered
    domain. ``core`` is the label left of the TLD/suffix.
    """
    if not tldextract:
        # Fallback: naive "last two labels" split.
        regdom = _split_reg_domain(host)
        pieces = regdom.split(".")
        dotted = "." in regdom
        tld = pieces[-1] if dotted else ""
        sub = ""
        if host.endswith(regdom):
            sub = host[:-len(regdom)].rstrip(".")
        core = pieces[0] if dotted else regdom
        return regdom, sub, core, tld

    parts = tldextract.extract(host)  # subdomain, domain, suffix
    if parts.domain and parts.suffix:
        regdom = f"{parts.domain}.{parts.suffix}"
    else:
        regdom = host
    return regdom, parts.subdomain or "", parts.domain or "", parts.suffix or ""
def heuristic_features(u: str):
    """Derive lexical risk indicators from a single URL string.

    Args:
        u: the URL as extracted from the input text (a scheme is optional).

    Returns:
        A dict with the keys ``host``, ``path``, ``query``,
        ``registered_domain``, ``subdomain``, ``tld``, ``labels``,
        ``has_at``, ``has_port``, ``has_punycode``, ``len_url``,
        ``hyphen_in_regdom``, ``kw_in_path``, ``kw_in_host``,
        ``kw_in_subdomain_only``, ``suspicious_tld``, ``is_shortener``,
        ``query_ratio_alnum`` and ``parse_error`` (False).
        If anything raises while computing them, the result is just
        ``{"parse_error": True}``.
    """
    feats = {}
    try:
        p = _safe_parse(u)
        feats["host"] = p.hostname or ""
        feats["path"] = p.path or "/"
        feats["query"] = p.query or ""
        regdom, sub, core, tld = _domain_parts(feats["host"])
        feats["registered_domain"] = regdom
        feats["subdomain"] = sub
        feats["tld"] = tld
        # Number of dot-separated labels in the host (0 for an empty host).
        feats["labels"] = feats["host"].count(".") + (1 if feats["host"] else 0)
        feats["has_at"] = "@" in u
        # Drop any userinfo, then look for ":" only after a closing "]" so the
        # colons inside a bracketed IPv6 literal are not mistaken for a port.
        hostport = p.netloc.split("@")[-1]
        feats["has_port"] = ":" in hostport.rpartition("]")[2]
        feats["has_punycode"] = "xn--" in feats["host"]
        feats["len_url"] = len(u)
        feats["hyphen_in_regdom"] = "-" in (core or "")
        low_host = feats["host"].lower()
        low_path = feats["path"].lower()
        feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
        feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
        # Keyword somewhere in the host but not in the core label. bool() is
        # required: with an empty core the and-chain evaluates to "" and
        # int("") would raise, turning a valid URL into a parse error.
        feats["kw_in_subdomain_only"] = int(bool(
            feats["kw_in_host"]
            and core
            and not any(k in core.lower() for k in KEYWORDS)
        ))
        feats["suspicious_tld"] = int((feats["tld"].split(".")[-1] or "") in SUSPICIOUS_TLDS)
        feats["is_shortener"] = int(regdom.lower() in URL_SHORTENERS)
        alnum = sum(c.isalnum() for c in feats["query"])
        feats["query_ratio_alnum"] = (alnum / max(1, len(feats["query"]))) if feats["query"] else 0.0
        feats["parse_error"] = False
    except Exception:
        # Deliberate best effort: any failure is reported as a parse error,
        # which the scoring functions treat as a risk signal of its own.
        feats = {"parse_error": True}
    return feats
def heuristic_score(feats: dict) -> float:
    """Combine the heuristic features into a risk score in ``[0.0, 1.0]``.

    A feature dict flagged with ``parse_error`` scores a flat 0.80.
    Otherwise each triggered signal adds its fixed weight, a long and
    almost purely alphanumeric query string adds 0.10, and the total is
    clamped to the unit interval.
    """
    if feats.get("parse_error"):
        return 0.80

    weighted_signals = (
        (0.28, feats["kw_in_path"]),
        (0.24, feats["kw_in_subdomain_only"]),
        (0.10, feats["kw_in_host"]),
        (0.12, feats["hyphen_in_regdom"]),
        (0.10, feats["labels"] >= 4),
        (0.10, feats["has_punycode"]),
        (0.12, feats["suspicious_tld"]),
        (0.10, feats["is_shortener"]),
        (0.05, feats["has_at"]),
        (0.05, feats["has_port"]),
        (0.10, feats["len_url"] >= 100),
    )
    total = 0.0
    # Plain left-to-right accumulation keeps the float result reproducible.
    for weight, signal in weighted_signals:
        total += weight * signal

    query = feats.get("query")
    if query and len(query) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9:
        total += 0.10
    return max(0.0, min(1.0, total))
def heuristic_reasons(feats: dict) -> str:
    """Describe the triggered heuristics as a comma-separated string.

    Returns "parse error" for a feature dict flagged with ``parse_error``
    and "no heuristic triggers" when nothing fired.
    """
    if feats.get("parse_error"):
        return "parse error"

    query = feats.get("query")
    long_query_blob = (
        query
        and len(feats.get("query", "")) >= 40
        and feats.get("query_ratio_alnum", 0) > 0.9
    )
    # (condition, message) pairs in the order they appear in the output.
    checks = [
        (feats.get("is_shortener"), "URL shortener"),
        (feats.get("kw_in_path"), "keyword in path"),
        (feats.get("kw_in_subdomain_only"), "keyword in subdomain"),
        (
            feats.get("kw_in_host") and not feats.get("kw_in_subdomain_only"),
            "keyword in host",
        ),
        (feats.get("hyphen_in_regdom"), "hyphen in registered domain"),
        (feats.get("labels", 0) >= 4, "deep subdomain nesting"),
        (feats.get("has_punycode"), "punycode host"),
        (feats.get("suspicious_tld"), f"suspicious TLD: {feats.get('tld')}"),
        (feats.get("has_at"), "@ in URL"),
        (feats.get("has_port"), "explicit port"),
        (feats.get("len_url", 0) >= 100, "very long URL"),
        (long_query_blob, "long query blob"),
    ]
    triggered = [message for fired, message in checks if fired]
    if not triggered:
        return "no heuristic triggers"
    return ", ".join(triggered)
def heuristic_hard_flag(feats: dict) -> "tuple[bool, str]":
    """Decide whether a URL matches a pattern that is flagged outright.

    Args:
        feats: feature dict as produced by ``heuristic_features``.

    Returns:
        ``(flagged, reason)``. ``reason`` is an empty string when nothing
        matched. Rules are checked in order and the first match wins:
        parse error; keyword in subdomain together with keyword in path;
        then URL shortener, suspicious TLD, or four or more host labels,
        each combined with a keyword in the host or the path.
    """
    # The annotation is a string so it is a real type expression for type
    # checkers while staying valid on interpreters without builtin generics.
    if feats.get("parse_error"):
        return True, "unparsable URL"
    if feats.get("kw_in_subdomain_only") and feats.get("kw_in_path"):
        return True, "keyword in subdomain + keyword in path"

    has_keyword = bool(feats.get("kw_in_host") or feats.get("kw_in_path"))
    if feats.get("is_shortener") and has_keyword:
        return True, "URL shortener + keyword"
    if feats.get("suspicious_tld") and has_keyword:
        return True, "suspicious TLD + keyword"
    if feats.get("labels", 0) >= 4 and has_keyword:
        return True, "deep subdomain nesting + keyword"
    return False, ""
# ---------- core ----------
def _parse_allowlist(s: str):
items = re.split(r"[,\s]+", (s or "").strip())
return {x.strip().lower() for x in items if x.strip()}
def analyze(
    text: str,
    forensic: bool,
    show_json: bool,
    threshold: float,
    allowlist_txt: str,
    allowlist_override: bool
):
    """Score every URL in *text* and return a single Markdown report.

    Args:
        text: one bare URL, or free text (e.g. an email body) to extract
            URLs from.
        forensic: append a forensic section per URL.
        show_json: append the raw per-URL data as a JSON code block.
        threshold: fused-risk value at or above which a URL is UNSAFE.
        allowlist_txt: domains separated by commas and/or whitespace.
        allowlist_override: when true, a URL whose registered domain is in
            the allowlist is forced to SAFE.

    Returns:
        Markdown containing the overall verdict, the results table (model
        label, heuristic score, fused risk, decision, reasons) and the
        optional forensic / JSON sections. A short plain message is
        returned instead when the input is empty or contains no URL.
    """
    text = (text or "").strip()
    if not text:
        return "Paste an email body or a URL."

    # Treat the input as one bare URL only if it contains no whitespace at
    # all; checking just for " " would glue newline- or tab-separated URLs
    # into a single bogus URL.
    starts_like_url = text.lower().startswith(("http://", "https://", "www."))
    if starts_like_url and not any(ch.isspace() for ch in text):
        urls = [text]
    else:
        urls = _extract_urls(text)
    if not urls:
        return "No URLs detected in the text."

    allowset = _parse_allowlist(allowlist_txt)
    tok, mdl = _load_model()
    # Loop-invariant: cap the tokenised length at 512 or the model's own limit.
    max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)

    rows = []
    forensic_blocks = []
    export_data = {"model_id": URL_MODEL_ID, "items": []}
    any_unsafe = False

    for u in urls:
        # model forward
        enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
        token_ids = enc["input_ids"][0].tolist()
        tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
        # Heuristic: a sequence that fills max_len is reported as truncated
        # (an input of exactly max_len tokens is reported the same way).
        truncated = len(token_ids) >= max_len

        # perf_counter is monotonic, unlike the wall clock.
        t0 = time.perf_counter()
        with torch.no_grad():
            out = mdl(**enc, output_hidden_states=True)
        elapsed = time.perf_counter() - t0

        logits = out.logits.squeeze(0)
        probs = _softmax(logits)
        scores = [
            {"label": ID2LABEL[i], "prob": float(prob), "logit": float(logits[i])}
            for i, prob in enumerate(probs)
        ]
        scores_sorted = sorted(scores, key=lambda x: x["prob"], reverse=True)
        top = scores_sorted[0]

        # heuristics
        feats = heuristic_features(u)
        regdom = feats.get("registered_domain", "").lower()
        h_flag, h_reason = heuristic_hard_flag(feats)
        h_score = heuristic_score(feats)

        # Probability mass on every non-benign class, averaged with the
        # heuristic score.
        mdl_phish_like = sum(
            s["prob"] for s in scores_sorted
            if s["label"] in {"phishing", "malware", "defacement"}
        )
        fused = 0.50 * mdl_phish_like + 0.50 * h_score

        # allowlist override (domain-based)
        allow_hit = regdom in allowset if regdom else False
        reasons = (h_reason + (", " if h_reason else "") + heuristic_reasons(feats)).strip(", ")
        if allow_hit and allowlist_override:
            decision = "✅ SAFE"
            reasons = f"allowlisted domain ({regdom})"
            fused = min(fused, 0.01)  # clamp down the risk for display
        else:
            decision = "🛑 UNSAFE" if (h_flag or fused >= float(threshold)) else "✅ SAFE"
        if decision.startswith("🛑"):
            any_unsafe = True

        rows.append([u, top["label"], top["prob"]*100.0, h_score, fused, decision, reasons])

        # export + forensics
        hidden_states = out.hidden_states
        # First position of the last hidden state for the single batch item.
        cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
        export_data["items"].append({
            "url": u, "token_ids": token_ids, "tokens": tokens, "truncated": truncated,
            "logits": [float(x) for x in logits.cpu().tolist()], "probs": [float(p) for p in probs],
            "scores_sorted": scores_sorted, "cls_vector": cls_vec, "cls_dim": len(cls_vec),
            "elapsed_sec": elapsed, "heuristic": feats, "heuristic_score": h_score,
            "fused_risk": fused, "hard_flag": h_flag, "hard_reason": h_reason,
            "allowlisted": allow_hit
        })
        if forensic:
            forensic_blocks.append(
                _forensic_block(u, token_ids, tokens, scores_sorted, cls_vec, elapsed, truncated)
            )

    verdict = "🔴 **UNSAFE (at least one link flagged)**" if any_unsafe else "🟢 **SAFE (no link over threshold)**"
    body = verdict + "\n\n" + _results_table(rows)
    if forensic and forensic_blocks:
        body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
    if show_json:
        pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
        body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
        body += "```json\n" + pretty + "\n```"
    return body
# ---------- UI ----------
# Input widgets, in the positional order of analyze()'s parameters.
_INPUT_COMPONENTS = [
    gr.Textbox(lines=10, label="Email or URL", placeholder="Paste a URL or a full email…"),
    gr.Checkbox(label="Forensic mode (tokens, logits, [CLS])", value=True),
    gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
    gr.Slider(0.0, 1.0, value=0.40, step=0.01, label="Decision threshold (fused risk ≥ threshold → UNSAFE)"),
    gr.Textbox(
        lines=2,
        label="Allowlist (domains, comma/space/newline separated)",
        placeholder="example.com, github.com microsoft.com",
    ),
    gr.Checkbox(label="Allowlist overrides (force SAFE if registered domain matches)", value=True),
]

# Single-function UI: the widgets above feed analyze(), whose Markdown
# report is shown in one Markdown output.
demo = gr.Interface(
    title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
    description=(
        "Extract links, score with a tiny HF URL model and transparent heuristics. "
        "Short-circuits for classic phishing patterns. Adjust the threshold, and allowlist trusted domains."
    ),
    fn=analyze,
    inputs=_INPUT_COMPONENTS,
    outputs=gr.Markdown(label="Results"),
)

if __name__ == "__main__":
    # Only start the server when this file is run as a script.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|