File size: 14,084 Bytes
9256d25
 
 
 
 
2481aa9
9256d25
 
 
 
 
 
 
 
 
d4809b8
 
 
 
2481aa9
9256d25
 
 
 
 
 
 
 
2481aa9
 
 
9256d25
 
 
 
d4809b8
9256d25
2481aa9
 
 
d4809b8
2481aa9
 
9256d25
 
 
 
 
 
 
 
 
 
 
 
 
 
2481aa9
9256d25
2481aa9
 
9256d25
2481aa9
 
9256d25
2481aa9
9256d25
 
 
 
 
 
 
 
 
 
 
 
 
 
d4809b8
 
 
9256d25
 
 
 
 
 
 
 
 
d4809b8
9256d25
 
 
 
 
 
 
 
 
 
 
2481aa9
 
 
 
 
 
 
 
 
 
 
 
 
 
9256d25
 
 
 
 
 
 
2481aa9
9256d25
2481aa9
9256d25
 
 
 
 
 
2481aa9
9256d25
 
 
 
 
2481aa9
9256d25
 
2481aa9
9256d25
 
 
 
 
 
 
 
 
2481aa9
 
 
 
 
 
 
 
 
 
 
 
 
9256d25
2481aa9
 
9256d25
 
 
 
 
2481aa9
9256d25
 
 
 
 
 
 
 
 
6175992
9256d25
 
 
 
2481aa9
 
 
 
 
 
 
 
 
 
 
 
 
d4809b8
 
 
 
 
 
 
 
 
 
 
 
 
9256d25
 
d4809b8
 
 
9256d25
 
 
 
 
 
 
 
 
d4809b8
 
9256d25
d4809b8
 
9256d25
 
 
 
 
d4809b8
9256d25
 
 
 
 
 
 
 
 
 
 
 
 
d4809b8
 
 
 
9256d25
d4809b8
9256d25
d4809b8
 
9256d25
d4809b8
 
 
 
 
 
 
 
 
 
 
 
 
 
9256d25
2481aa9
9256d25
 
d4809b8
2481aa9
d4809b8
 
 
9256d25
d4809b8
 
 
 
 
 
9256d25
 
 
 
d4809b8
9256d25
 
2481aa9
d4809b8
9256d25
 
 
 
 
 
 
 
 
 
 
d4809b8
9256d25
 
 
2481aa9
9256d25
 
d4809b8
 
 
 
9256d25
 
 
 
2481aa9
d4809b8
9256d25
 
 
 
2481aa9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import os, re, time, json, urllib.parse
import gradio as gr
import torch
import torch.nn.functional as F

# Optional robust domain parsing; code falls back if missing.
try:
    import tldextract
except Exception:
    tldextract = None

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Model identifier handed to `from_pretrained` for both the tokenizer and the
# sequence-classification model (see `_load_model`).
URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"

# Force readable labels regardless of model config
# NOTE(review): assumes the model emits exactly these four classes in this
# index order — confirm against the model card.
ID2LABEL = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}

# Matches a run that starts with "http://", "https://" or "www." (case-insensitive)
# and continues until whitespace or one of < > " ' ( ).
URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[^\s<>"'()]+""")

# Substrings looked for in the lower-cased host and path of each URL.
KEYWORDS = {
    "phish","login","verify","account","secure","update","bank","wallet",
    "password","invoice","pay","reset","support","unlock","confirm"
}
# Compared against the last dot-separated label of the URL's TLD/suffix.
SUSPICIOUS_TLDS = {
    "zip","mov","lol","xyz","top","country","link","click","cam","help",
    "gq","cf","tk","work","rest","monster","quest","live","io","ly"
}
# Compared against the lower-cased registered domain of each URL.
URL_SHORTENERS = {
    "bit.ly","tinyurl.com","t.co","goo.gl","is.gd","buff.ly","ow.ly","rebrand.ly","cutt.ly"
}

# Lazily populated by `_load_model`; None until the first call.
_tok = None
_mdl = None

# ---------- utils ----------
def _extract_urls(text: str):
    """Return the distinct URLs found in *text*, sorted lexicographically.

    Every ``URL_RE`` match is stripped of surrounding whitespace and of
    trailing punctuation/bracket characters that commonly follow a link in
    running prose. ``None`` or empty input yields an empty list.
    """
    trailing_junk = ").,;:!?•]}>\"'"
    unique_urls = {
        match.group(0).strip().rstrip(trailing_junk)
        for match in URL_RE.finditer(text or "")
    }
    return sorted(unique_urls)

def _load_model():
    """Return the ``(tokenizer, model)`` pair, loading it on first use.

    The pair is stored in the module-level ``_tok`` / ``_mdl`` slots so that
    later calls return the same objects without loading again. The model is
    switched to eval mode right after loading.
    """
    global _tok, _mdl
    if _tok is None or _mdl is None:
        # Imported here rather than at module import time.
        from transformers import AutoTokenizer, AutoModelForSequenceClassification

        _tok = AutoTokenizer.from_pretrained(URL_MODEL_ID)
        _mdl = AutoModelForSequenceClassification.from_pretrained(URL_MODEL_ID)
        _mdl.eval()
    return _tok, _mdl

def _softmax(logits: torch.Tensor):
    return F.softmax(logits, dim=-1).tolist()

def _results_table(rows):
    lines = [
        "| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Decision | Reasons |",
        "|---|---|---:|---:|---:|:--:|---|",
    ]
    for r in rows:
        u, lbl, pct, h, fused, decision, reasons = r
        lines.append(
            f"| `{u}` | **{lbl}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {decision} | {reasons} |"
        )
    return "\n".join(lines)

def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
    toks_prev = ", ".join(tokens[:64]) + (" …" if len(tokens) > 64 else "")
    ids_prev  = ", ".join(map(str, token_ids[:64])) + (" …" if len(token_ids) > 64 else "")
    cls_dim = len(cls_vec)
    cls_prev = ", ".join(f"{v:.4f}" for v in cls_vec[:16]) + (" …" if cls_dim > 16 else "")
    l2 = (sum(v*v for v in cls_vec)) ** 0.5
    md = []
    md.append(f"### 🔍 Forensics for `{url}`\n")
    md.append(f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**")
    md.append(f"- inference time: **{elapsed_s:.2f}s**\n")
    md.append("**Top-k scores**")
    md.append("| Class | Prob (%) | Logit |\n|---|---:|---:|")
    for s in scores_sorted:
        md.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |")
    md.append("\n**Token IDs (preview)**")
    md.append("```txt\n" + ids_prev + "\n```")
    md.append("**Tokens (preview)**")
    md.append("```txt\n" + toks_prev + "\n```")
    md.append("**[CLS] embedding (preview)**")
    md.append(f"`dim={cls_dim}`, `L2={l2:.4f}`")
    md.append("```txt\n" + cls_prev + "\n```")
    return "\n".join(md)

# ---------- heuristics ----------
def _safe_parse(url: str):
    if not re.match(r"^https?://", url, re.I):
        url = "http://" + url
    return urllib.parse.urlparse(url)

def _split_reg_domain(host: str):
    parts = host.split(".")
    if len(parts) >= 2:
        return parts[-2] + "." + parts[-1]
    return host

def _domain_parts(host: str):
    """Split *host* into ``(registered_domain, subdomain, core, tld)``.

    When ``tldextract`` was importable it supplies the split; if it reports
    no domain or no suffix, the whole host is used as the registered domain.
    Otherwise a naive fallback treats the last two labels as the registered
    domain, the final label as the TLD and the first label of the registered
    domain as the core name.
    """
    if not tldextract:
        regdom = _split_reg_domain(host)
        dotted = "." in regdom
        tld = regdom.rpartition(".")[2] if dotted else ""
        core = regdom.partition(".")[0] if dotted else regdom
        if host.endswith(regdom):
            # Whatever precedes the registered domain, minus the joining dot.
            sub = host[:-len(regdom)].rstrip(".")
        else:
            sub = ""
        return regdom, sub, core, tld

    ext = tldextract.extract(host)  # exposes .subdomain, .domain, .suffix
    if ext.domain and ext.suffix:
        regdom = f"{ext.domain}.{ext.suffix}"
    else:
        regdom = host
    return regdom, ext.subdomain or "", ext.domain or "", ext.suffix or ""

def heuristic_features(u: str):
    """Extract lexical risk features from one URL string.

    Args:
        u: the URL as found in the input; a missing scheme is tolerated.

    Returns:
        A dict. On success it holds, in this order: ``host``, ``path``,
        ``query``, ``registered_domain``, ``subdomain``, ``tld``, ``labels``
        (number of dot-separated host labels), ``has_at``, ``has_port``,
        ``has_punycode``, ``len_url``, ``hyphen_in_regdom``, ``kw_in_path``,
        ``kw_in_host``, ``kw_in_subdomain_only``, ``suspicious_tld``,
        ``is_shortener``, ``query_ratio_alnum`` and ``parse_error=False``.
        If anything raises while parsing, the result is just
        ``{"parse_error": True}``.
    """
    feats = {}
    try:
        p = _safe_parse(u)
        feats["host"] = p.hostname or ""
        feats["path"] = p.path or "/"
        feats["query"] = p.query or ""
        regdom, sub, core, tld = _domain_parts(feats["host"])
        core = core or ""
        feats["registered_domain"] = regdom
        feats["subdomain"] = sub
        feats["tld"] = tld
        feats["labels"] = feats["host"].count(".") + (1 if feats["host"] else 0)
        feats["has_at"] = "@" in u

        # Drop any "user:pass@" prefix before looking for a port separator.
        hostport = p.netloc.rpartition("@")[2]
        if hostport.startswith("["):
            # Bracketed IPv6 literal: the colons inside the brackets belong to
            # the address; only a ":" directly after "]" introduces a port.
            feats["has_port"] = hostport.partition("]")[2].startswith(":")
        else:
            feats["has_port"] = ":" in hostport

        feats["has_punycode"] = "xn--" in feats["host"]
        feats["len_url"] = len(u)
        feats["hyphen_in_regdom"] = "-" in core
        low_host = feats["host"].lower()
        low_path = feats["path"].lower()
        feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
        feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
        # Keyword appears somewhere in the host but not in the core domain
        # label. bool() is required: with an empty core the and-chain yields
        # "" and int("") would raise, mislabelling the URL as unparsable.
        core_has_kw = any(k in core.lower() for k in KEYWORDS)
        feats["kw_in_subdomain_only"] = int(
            bool(feats["kw_in_host"] and core and not core_has_kw)
        )
        feats["suspicious_tld"] = int((feats["tld"].split(".")[-1] or "") in SUSPICIOUS_TLDS)
        feats["is_shortener"] = int(regdom.lower() in URL_SHORTENERS)
        alnum = sum(c.isalnum() for c in feats["query"])
        feats["query_ratio_alnum"] = (alnum / max(1, len(feats["query"]))) if feats["query"] else 0.0
        feats["parse_error"] = False
    except Exception:
        # Deliberately broad: any failure to parse is itself reported as a
        # signal instead of aborting the whole analysis.
        feats = {"parse_error": True}
    return feats

def heuristic_score(feats: dict) -> float:
    """Combine the heuristic features into a risk score clamped to [0.0, 1.0].

    A feature dict flagged with ``parse_error`` scores a fixed 0.80. Otherwise
    each triggered signal adds its weight, plus 0.10 for a long (>= 40 chars),
    almost entirely alphanumeric (> 0.9) query string.
    """
    if feats.get("parse_error"):
        return 0.80

    weighted_signals = (
        (0.28, feats["kw_in_path"]),
        (0.24, feats["kw_in_subdomain_only"]),
        (0.10, feats["kw_in_host"]),
        (0.12, feats["hyphen_in_regdom"]),
        (0.10, feats["labels"] >= 4),
        (0.10, feats["has_punycode"]),
        (0.12, feats["suspicious_tld"]),
        (0.10, feats["is_shortener"]),
        (0.05, feats["has_at"]),
        (0.05, feats["has_port"]),
        (0.10, feats["len_url"] >= 100),
    )

    # Accumulate one term at a time, in table order, so the floating-point
    # result does not depend on how a bulk summation would round.
    total = 0.0
    for weight, signal in weighted_signals:
        total += weight * signal

    query = feats.get("query")
    if query and len(query) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9:
        total += 0.10

    return max(0.0, min(1.0, total))

def heuristic_reasons(feats: dict) -> str:
    """Describe which heuristic signals fired, as a comma-separated string.

    Returns ``"parse error"`` for a dict flagged with ``parse_error`` and
    ``"no heuristic triggers"`` when nothing fired.
    """
    get = feats.get
    if get("parse_error"):
        return "parse error"

    query = get("query")
    long_query_blob = bool(query) and len(query) >= 40 and get("query_ratio_alnum", 0) > 0.9

    # (fired?, label) pairs, in the order they are reported.
    checks = (
        (get("is_shortener"), "URL shortener"),
        (get("kw_in_path"), "keyword in path"),
        (get("kw_in_subdomain_only"), "keyword in subdomain"),
        (get("kw_in_host") and not get("kw_in_subdomain_only"), "keyword in host"),
        (get("hyphen_in_regdom"), "hyphen in registered domain"),
        (get("labels", 0) >= 4, "deep subdomain nesting"),
        (get("has_punycode"), "punycode host"),
        (get("suspicious_tld"), f"suspicious TLD: {get('tld')}"),
        (get("has_at"), "@ in URL"),
        (get("has_port"), "explicit port"),
        (get("len_url", 0) >= 100, "very long URL"),
        (long_query_blob, "long query blob"),
    )
    fired = [label for hit, label in checks if hit]
    if not fired:
        return "no heuristic triggers"
    return ", ".join(fired)

def heuristic_hard_flag(feats: dict) -> (bool, str):
    """Check for signal combinations that flag a URL regardless of its score.

    Returns:
        ``(True, reason)`` for the first matching rule, else ``(False, "")``.
        Rules, in order: unparsable URL; keyword in subdomain together with
        keyword in path; then, whenever a keyword is present in the host or
        the path: URL shortener, suspicious TLD, or at least four host labels.
    """
    if feats.get("parse_error"):
        return True, "unparsable URL"

    kw_in_path = feats.get("kw_in_path")
    if feats.get("kw_in_subdomain_only") and kw_in_path:
        return True, "keyword in subdomain + keyword in path"

    # Every remaining rule needs a keyword somewhere in the host or path.
    if not (feats.get("kw_in_host") or kw_in_path):
        return False, ""

    if feats.get("is_shortener"):
        return True, "URL shortener + keyword"
    if feats.get("suspicious_tld"):
        return True, "suspicious TLD + keyword"
    if feats.get("labels", 0) >= 4:
        return True, "deep subdomain nesting + keyword"
    return False, ""

# ---------- core ----------
def _parse_allowlist(s: str):
    items = re.split(r"[,\s]+", (s or "").strip())
    return {x.strip().lower() for x in items if x.strip()}

def analyze(
    text: str,
    forensic: bool,
    show_json: bool,
    threshold: float,
    allowlist_txt: str,
    allowlist_override: bool
):
    """Score every URL found in *text* and return one Markdown report.

    Args:
        text: a bare URL or free text (e.g. an email body) to scan for URLs.
        forensic: when true, append a per-URL forensic section (tokens,
            scores, embedding preview).
        show_json: when true, append the raw per-URL data as a JSON block.
        threshold: fused-risk value at or above which a URL is UNSAFE.
        allowlist_txt: comma/whitespace-separated registered domains.
        allowlist_override: when true, a URL whose registered domain is in
            the allowlist is forced to SAFE.

    Returns:
        A Markdown string: overall verdict, results table (model, heuristic,
        fused risk, decision, reasons), then the optional forensic blocks and
        raw JSON. For empty input or input without URLs, a short hint message
        is returned instead.
    """
    text = (text or "").strip()
    if not text:
        return "Paste an email body or a URL."

    # Take the input as one bare URL only if it contains no whitespace at all.
    # Checking just for " " would treat a multi-line paste that starts with a
    # link (separated by newlines or tabs) as a single giant URL.
    looks_like_url = text.lower().startswith(("http://", "https://", "www."))
    if looks_like_url and not re.search(r"\s", text):
        urls = [text]
    else:
        urls = _extract_urls(text)
    if not urls:
        return "No URLs detected in the text."

    allowset = _parse_allowlist(allowlist_txt)

    tok, mdl = _load_model()

    # Same for every URL, so computed once: cap at 512 or the model's
    # position limit, whichever is smaller.
    max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)

    rows = []
    forensic_blocks = []
    export_data = {"model_id": URL_MODEL_ID, "items": []}
    any_unsafe = False

    for u in urls:
        # model forward
        enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
        token_ids = enc["input_ids"][0].tolist()
        tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
        # NOTE(review): an input that tokenizes to exactly max_len is also
        # reported as truncated — confirm whether that is acceptable.
        truncated = enc["input_ids"].shape[1] >= max_len and len(tokens) >= max_len

        # perf_counter is monotonic, so the interval cannot go negative if the
        # wall clock is adjusted mid-inference.
        t0 = time.perf_counter()
        with torch.no_grad():
            out = mdl(**enc, output_hidden_states=True)
        elapsed = time.perf_counter() - t0

        logits = out.logits.squeeze(0)
        probs = _softmax(logits)
        # Fall back to the numeric index if the model emits more classes than
        # ID2LABEL names, rather than failing with KeyError.
        scores = [
            {"label": ID2LABEL.get(i, str(i)), "prob": float(prob), "logit": float(logits[i])}
            for i, prob in enumerate(probs)
        ]
        scores_sorted = sorted(scores, key=lambda x: x["prob"], reverse=True)
        top = scores_sorted[0]

        # heuristics
        feats = heuristic_features(u)
        regdom = feats.get("registered_domain", "").lower()
        h_flag, h_reason = heuristic_hard_flag(feats)
        h_score = heuristic_score(feats)
        # Probability mass on every non-benign class.
        mdl_phish_like = sum(s["prob"] for s in scores_sorted if s["label"] in {"phishing","malware","defacement"})
        fused = 0.50 * mdl_phish_like + 0.50 * h_score

        # allowlist override (domain-based)
        allow_hit = regdom in allowset if regdom else False
        reasons = (h_reason + (", " if h_reason else "") + heuristic_reasons(feats)).strip(", ")

        if allow_hit and allowlist_override:
            decision = "✅ SAFE"
            reasons = f"allowlisted domain ({regdom})"
            fused = min(fused, 0.01)  # clamp down the risk for display
        else:
            # A hard flag short-circuits the threshold comparison.
            decision = "🛑 UNSAFE" if (h_flag or fused >= float(threshold)) else "✅ SAFE"

        if decision.startswith("🛑"):
            any_unsafe = True

        rows.append([u, top["label"], top["prob"]*100.0, h_score, fused, decision, reasons])

        # export + forensics
        hidden_states = out.hidden_states
        # First-position vector of the last hidden layer.
        cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
        export_data["items"].append({
            "url": u, "token_ids": token_ids, "tokens": tokens, "truncated": truncated,
            "logits": [float(x) for x in logits.cpu().tolist()], "probs": [float(p) for p in probs],
            "scores_sorted": scores_sorted, "cls_vector": cls_vec, "cls_dim": len(cls_vec),
            "elapsed_sec": elapsed, "heuristic": feats, "heuristic_score": h_score,
            "fused_risk": fused, "hard_flag": h_flag, "hard_reason": h_reason,
            "allowlisted": allow_hit
        })

        if forensic:
            forensic_blocks.append(
                _forensic_block(u, token_ids, tokens, scores_sorted, cls_vec, elapsed, truncated)
            )

    verdict = "🔴 **UNSAFE (at least one link flagged)**" if any_unsafe else "🟢 **SAFE (no link over threshold)**"
    body = verdict + "\n\n" + _results_table(rows)

    if forensic and forensic_blocks:
        body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)

    if show_json:
        pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
        body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
        body += "```json\n" + pretty + "\n```"

    return body

# ---------- UI ----------
# Gradio front end: the six inputs below are listed in the same order as the
# parameters of `analyze` (text, forensic, show_json, threshold,
# allowlist_txt, allowlist_override); its Markdown result goes to one panel.
demo = gr.Interface(
    fn=analyze,
    inputs=[
        gr.Textbox(lines=10, label="Email or URL", placeholder="Paste a URL or a full email…"),
        gr.Checkbox(label="Forensic mode (tokens, logits, [CLS])", value=True),
        gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
        gr.Slider(0.0, 1.0, value=0.40, step=0.01, label="Decision threshold (fused risk ≥ threshold → UNSAFE)"),
        gr.Textbox(lines=2, label="Allowlist (domains, comma/space/newline separated)",
                   placeholder="example.com, github.com  microsoft.com"),
        gr.Checkbox(label="Allowlist overrides (force SAFE if registered domain matches)", value=True),
    ],
    outputs=gr.Markdown(label="Results"),
    title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
    description=(
        "Extract links, score with a tiny HF URL model and transparent heuristics. "
        "Short-circuits for classic phishing patterns. Adjust the threshold, and allowlist trusted domains."
    ),
)

# Start the web server only when run as a script, not on import.
# NOTE(review): 0.0.0.0 presumably exposes the app on all network interfaces,
# as a container host would need — confirm this is intended outside of one.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)