File size: 2,435 Bytes
de071e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import matplotlib.pyplot as plt
import numpy as np
import random
import os
from scipy.stats import chi2


def manual_seed(seed, fix_cudnn=True):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random``, NumPy's global generator, and torch (default
    generator plus all CUDA devices), and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed handed to each generator.
        fix_cudnn: When true, also force cuDNN into deterministic mode and
            turn off its benchmark autotuner.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # NOTE(review): the interpreter reads PYTHONHASHSEED at startup, so this
    # presumably only influences processes spawned afterwards -- confirm.
    os.environ["PYTHONHASHSEED"] = str(seed)

    if not fix_cudnn:
        return

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # noqa
    cudnn.benchmark = False  # noqa


def spcor(x, y):
    """Compute ``1 - 6 * sum((x - y)^2) / (n * (n^2 - 1))`` with ``n = len(x)``.

    This is the rank-difference form of Spearman's correlation; the inputs
    are presumably already rank vectors (no ranking is done here -- confirm
    at the call site).

    Args:
        x: Tensor of values (assumed ranks) for the first variable.
        y: Tensor of values (assumed ranks) for the second variable.

    Returns:
        The coefficient as a tensor, computed with gradients disabled.
    """
    count = len(x)
    denominator = count * (count**2 - 1)

    with torch.no_grad():
        scaled_sq_diff = 6 * torch.square(x - y)
        return 1 - torch.sum(scaled_sq_diff) / denominator


def pdists(x, y, device=None):
    """Return the pairwise squared Euclidean distances between rows of x and y.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 * a.b`` so the whole
    matrix comes from a single matrix product. Values are not clamped, so
    floating-point rounding may leave tiny negative entries for
    near-identical rows.

    Args:
        x: Tensor of row vectors; assumed 2-D ``(n, d)``.
        y: Tensor of row vectors; assumed 2-D ``(m, d)``.
        device: Device to compute on. Defaults to ``"cuda"`` when a CUDA
            device is available and ``"cpu"`` otherwise, so the function no
            longer fails on machines without a GPU.

    Returns:
        A CPU tensor whose entry ``[i, j]`` is the squared distance between
        ``x[i]`` and ``y[j]``. Computed with gradients disabled.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
        x_sq_norms = torch.sum(torch.square(x), dim=-1)
        y_sq_norms = torch.sum(torch.square(y), dim=-1)

        # Broadcast a column of |x_i|^2 against a row of |y_j|^2.
        dists = x_sq_norms.view(-1, 1) + y_sq_norms.view(1, -1) - 2 * x @ y.T

    return dists.cpu()


def cossim(x, y, device=None):
    """Return the pairwise cosine similarities between rows of x and y.

    Args:
        x: Tensor of row vectors; assumed 2-D ``(n, d)``.
        y: Tensor of row vectors; assumed 2-D ``(m, d)``.
        device: Device to compute on. Defaults to ``"cuda"`` when a CUDA
            device is available and ``"cpu"`` otherwise, so the function no
            longer fails on machines without a GPU.

    Returns:
        A CPU tensor whose entry ``[i, j]`` is the dot product of ``x[i]``
        and ``y[j]`` divided by the product of their L2 norms. Rows with
        zero norm are not special-cased, so they divide by zero. Computed
        with gradients disabled.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
        dot_products = x @ y.T
        # Outer product of the row norms: column of |x_i| times row of |y_j|.
        norm_products = torch.linalg.norm(x, dim=-1).view(
            -1, 1
        ) * torch.linalg.norm(y, dim=-1).view(1, -1)
        similarities = dot_products / norm_products

    return similarities.cpu()


def fisher(p):
    """Combine p-values with Fisher's method, skipping NaN entries.

    The statistic is ``-2 * sum(log(p_i))`` over the non-NaN values, and it
    is evaluated against a chi-square distribution with ``2 * k`` degrees of
    freedom, where ``k`` is the number of values used.

    Args:
        p: Iterable of p-values; NaN entries are ignored.

    Returns:
        The chi-square survival function at the statistic (the combined
        p-value).
    """
    usable = [pvalue for pvalue in p if not np.isnan(pvalue)]

    statistic = 0
    for pvalue in usable:
        statistic -= 2 * np.log(pvalue)

    return chi2.sf(statistic, df=2 * len(usable))


def normalize_mc_midpoint(mid, base, ft):
    """Express ``mid`` relative to the halfway point of the line base -> ft.

    Subtracts half of ``ft - base`` and then ``base`` from ``mid``, i.e.
    removes the value the straight line from ``base`` to ``ft`` takes at its
    midpoint.

    The subtraction uses ``-=``, so a ``mid`` that supports in-place
    arithmetic (e.g. an array or tensor) is modified in place and the same
    object is returned.

    Args:
        mid: Value measured at the midpoint.
        base: Value at the start of the line.
        ft: Value at the end of the line.

    Returns:
        ``mid`` after both subtractions.
    """
    half_rise = (ft - base) * 0.5
    mid -= half_rise
    mid -= base
    return mid


def normalize_trace(trace, alphas):
    """Remove the straight line joining the trace's endpoints, in place.

    Each element becomes ``trace[i] - slope * alphas[i] - start``, where
    ``start`` is the original first element and ``slope`` is the original
    last element minus the first.

    All corrected values are computed before anything is written back. For
    tensors and other containers whose indexing returns views, ``trace[0]``
    aliases the first element; updating element by element would zero
    ``start`` on the first iteration and leave later elements without the
    baseline subtracted.

    Args:
        trace: Mutable indexable sequence of values (list, array, tensor).
        alphas: Indexable sequence giving the position of each trace element
            along the line; must have at least ``len(trace)`` entries.

    Returns:
        The same ``trace`` object, with its elements overwritten. An empty
        trace is returned unchanged.
    """
    if len(trace) == 0:
        return trace

    start = trace[0]
    slope = trace[-1] - start

    # Read phase: nothing in `trace` is modified here, so `start` stays valid
    # even when it is a view into `trace`.
    corrected = [
        trace[i] - slope * alphas[i] - start for i in range(len(trace))
    ]

    # Write phase: keep the in-place contract callers may rely on.
    for i, value in enumerate(corrected):
        trace[i] = value

    return trace


def output_hook(m, inp, op, name, feats):
    """Store a detached copy of ``op`` in ``feats`` under ``name``.

    The leading ``(m, inp, op)`` parameters look like a forward-hook
    signature (module, input, output), with ``name`` and ``feats`` presumably
    bound by the caller before registration -- confirm at the call site.

    Args:
        m: Unused.
        inp: Unused.
        op: Object exposing ``detach()`` (e.g. a tensor).
        name: Key under which the detached value is stored.
        feats: Mutable mapping that receives the value.
    """
    detached_output = op.detach()
    feats[name] = detached_output


def get_submodule(module, submodule_string):
    """Resolve a dotted attribute path starting from ``module``.

    Args:
        module: Object to start the lookup from.
        submodule_string: Dot-separated attribute names, e.g.
            ``"encoder.layer1"``. An empty (or ``None``) path refers to
            ``module`` itself, matching the convention of
            ``torch.nn.Module.get_submodule("")``.

    Returns:
        The object reached after following every attribute in the path.

    Raises:
        AttributeError: If any component of the path does not exist.
    """
    # Without this guard "".split(".") yields [""] and getattr(module, "")
    # raises, so the root could not be addressed.
    if not submodule_string:
        return module

    target = module
    for attr in submodule_string.split("."):
        target = getattr(target, attr)
    return target


def plot_trace(losses, alphas, normalize, model_a_name, model_b_name, plot_path):
    """Plot losses against alphas and save the figure as ``<plot_path>.png``.

    Args:
        losses: Sequence of loss values, one per alpha.
        alphas: Sequence of x-axis positions.
        normalize: When true, the losses are passed through
            ``normalize_trace`` before plotting.
        model_a_name: Name shown in the title for the left end of the axis.
        model_b_name: Name shown in the title for the right end of the axis.
        plot_path: Output path without extension; ``.png`` is appended.
    """
    if normalize:
        # normalize_trace rewrites its argument in place. Hand it a private
        # list so plotting does not silently change the caller's data.
        # Arrays and tensors are passed through unchanged and may still be
        # modified in place by normalize_trace.
        if isinstance(losses, (list, tuple)):
            losses = list(losses)
        losses = normalize_trace(losses, alphas)

    fig = plt.figure(figsize=(8, 6))
    try:
        plt.plot(alphas, losses, "o-")

        plt.xlabel("Alpha")
        plt.ylabel("Loss")
        plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
        plot_filename = f"{plot_path}.png"

        plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure; pyplot keeps figures alive until they
        # are closed, even if plotting or saving raised.
        plt.close(fig)