#!/usr/bin/env python
"""
list_modules.py
────────────────────────────────────────────────────────────
Print (and optionally save) the dotted names of **all** sub-modules
inside a PyTorch model.  Handy for locating the correct layer name
for Grad-CAM, feature hooks, etc.

USAGE
-----
edit MODEL_SOURCE and MODEL_TYPE below, then:

    python list_modules.py

Outputs:
    • console  – first `LIMIT` names (to keep logs short)
    • file     – full list written to `modules_<model>.txt`
"""

from __future__ import annotations
import pathlib

import torch
from transformers import AutoModel

# ──────────────  CONFIG  ───────────────────────────────────────────────
MODEL_SOURCE = "haywoodsloan/ai-image-detector-deploy"
MODEL_TYPE   = "huggingface"

LIMIT        = 2000                  # how many lines to print to stdout (None = all)
# ───────────────────────────────────────────────────────────────────────

def load_model(src: str, src_type: str):
    if src_type == "huggingface":
        model = AutoModel.from_pretrained(src)
    elif src_type == "torchscript":
        model = torch.jit.load(src)
    else:
        raise ValueError("MODEL_TYPE must be 'huggingface' or 'torchscript'")
    model.eval()
    return model

def dump_module_names(model: torch.nn.Module,
                      out_file: pathlib.Path,
                      limit: int | None = None):
    names = [n for n, _ in model.named_modules()]  # includes root '' at idx 0
    total = len(names)

    print(f"\n▶ total {total} sub-modules found\n")
    for idx, name in enumerate(names):
        if limit is None or idx < limit:
            print(f"{idx:4d}: {name}")

    out_file.write_text("\n".join(names), encoding="utf-8")
    print(f"\n▶ wrote full list to {out_file}")

def main():
    model = load_model(MODEL_SOURCE, MODEL_TYPE)
    txt_path = pathlib.Path(f"modules_{MODEL_SOURCE.split('/')[-1].replace('.','_')}.txt")
    dump_module_names(model, txt_path, LIMIT)

if __name__ == "__main__":
    main()
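
Once a dotted name is known, `Module.get_submodule(name)` fetches that layer directly, and a forward hook on it captures the activations needed for Grad-CAM or feature extraction. A minimal, self-contained sketch of that workflow; the tiny `Sequential` network below is an illustrative stand-in, not the model configured above:

```python
import torch

# Stand-in model; in practice you would use the model returned by load_model().
net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 4, kernel_size=3, padding=1),
)

# named_modules() yields the root '' plus the dotted names '0', '1', '2'.
target = net.get_submodule("2")          # a name copied from the dump

feats = {}
handle = target.register_forward_hook(
    lambda module, inputs, output: feats.update(act=output.detach())
)
with torch.no_grad():
    net(torch.randn(1, 3, 16, 16))       # forward pass fires the hook
handle.remove()

print(feats["act"].shape)                # torch.Size([1, 4, 16, 16])
```

The same pattern applies to any name printed by `dump_module_names`: pass it to `get_submodule`, hook it, run one forward pass, and the captured tensor is your feature map.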