import argparse
import pickle
import subprocess
import timeit

import torch
from huggingface_hub import login
from transformers import GPTNeoXForCausalLM, AutoTokenizer

from tracing.utils.evaluate import prepare_hf_dataset, prepare_hf_dataloader, evaluate
from tracing.utils.utils import output_hook, get_submodule

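# Command-line settings: which Pythia checkpoint and layer to trace, which
# dataset to stream through it, and where to save the results.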
parser = argparse.ArgumentParser(description="Experiment Settings")

parser.add_argument("--model_id", default="EleutherAI/pythia-1.4b-deduped", type=str)
parser.add_argument("--step", default=0, type=int)
parser.add_argument("--layer", default=10, type=int)

parser.add_argument("--dataset_id", default="dlwh/wikitext_103_detokenized", type=str)
parser.add_argument("--block_size", default=512, type=int)
parser.add_argument("--batch_size", default=6, type=int)

parser.add_argument("--save", default="results.p", type=str)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--token", default="", type=str)

args = parser.parse_args()

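# Log in to the Hugging Face Hub with the provided token (needed for gated
# or private repositories).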
login(token=args.token)

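# Time the full run, from model loading through evaluation.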
start = timeit.default_timer()

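# Record run metadata for reproducibility: the CLI arguments and the current
# git commit of this repository.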
results = {}
results["args"] = args
results["commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()

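# Seed torch so any stochastic operations are reproducible.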
torch.manual_seed(args.seed)

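# Load the model and tokenizer at the requested pretraining step; Pythia
# publishes intermediate checkpoints as "step{N}" branch revisions.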
model = GPTNeoXForCausalLM.from_pretrained(
    args.model_id,
    revision=f"step{args.step}",
)
tokenizer = AutoTokenizer.from_pretrained(
    args.model_id,
    revision=f"step{args.step}",
)

print("model loaded")

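# Tokenize the corpus into fixed-length blocks and wrap it in a dataloader
# (helpers from tracing.utils.evaluate).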
dataset = prepare_hf_dataset(args.dataset_id, args.block_size, tokenizer)
dataloader = prepare_hf_dataloader(dataset, args.batch_size)

print("dataset loaded")

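# The transformer block whose sublayer activations will be captured.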
block = get_submodule(model, f"gpt_neox.layers.{args.layer}")

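# Register a forward hook on each sublayer of the block. output_hook (from
# tracing.utils) is expected to record each module's output in feats under
# its layer name. Binding layer=layer (and feats=feats) as lambda defaults
# avoids Python's late-binding closure pitfall, so every hook keeps the
# layer name it was created with.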
feats, hooks = {}, {}
for layer in [
    "input_layernorm",
    "post_attention_layernorm",
    "mlp.dense_h_to_4h",
    "mlp.dense_4h_to_h",
]:
    hooks[layer] = lambda m, inp, op, layer=layer, feats=feats: output_hook(
        m, inp, op, layer, feats
    )
    get_submodule(block, layer).register_forward_hook(hooks[layer])

print("hooks created")

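# One evaluation pass over the dataloader; the hooks populate feats as a
# side effect of each forward call.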
evaluate(model, dataloader)

print("model evaluated")

end = timeit.default_timer()
results["time"] = end - start

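# Save the block's weights together with the captured activations.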
results["weights"] = block.state_dict()
results["feats"] = feats

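# Dump everything (metadata, timing, weights, activations) to disk.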
print(results)
with open(args.save, "wb") as f:
    pickle.dump(results, f)