File size: 1,857 Bytes
c1ce505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch


def eval_decorator(fn):
    """Decorate ``fn`` so it runs with its first argument in eval mode.

    The wrapped function receives ``model`` as its first positional argument.
    The wrapper records ``model.training``, calls ``model.eval()``, runs ``fn``
    and then restores the previous mode via ``model.train(was_training)``.
    Presumably ``model`` is a ``torch.nn.Module``; only the ``training``
    attribute and the ``eval()``/``train(mode)`` methods are used.

    Args:
        fn: Callable whose first positional argument is the model.

    Returns:
        A wrapper with the same call signature and metadata as ``fn``.
    """
    import functools

    @functools.wraps(fn)  # keep fn's __name__/__doc__ for logging and introspection
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # Restore the caller's mode even when fn raises; otherwise the model
            # would silently stay in eval mode for subsequent training steps.
            model.train(was_training)
    return inner


def linear(a, b, x, min_x, max_x):
    """Clamped linear interpolation from ``a`` to ``b`` as ``x`` moves from ``min_x`` to ``max_x``.

    ::

        b              ___________
                      /|
                     / |
        a  _________/  |
                   |   |
                min_x  max_x

    For ``min_x < max_x``, this returns ``a`` when ``x <= min_x`` and ``b`` when
    ``x >= max_x``. In between, it returns ``a + (x - min_x) / (max_x - min_x) * (b - a)``.

    When ``min_x == max_x`` the ramp has zero width and degenerates into a step.
    The result is ``a`` for ``x < max_x`` and ``b`` for ``x >= max_x``, instead of
    raising ``ZeroDivisionError``.
    """
    span = max_x - min_x
    if span == 0:
        # Zero-width ramp: the limit of the clamped ramp is a step at max_x.
        return b if x >= max_x else a
    # Fraction of the way through [min_x, max_x], clamped to [0, 1].
    t = min(max((x - min_x) / span, 0), 1)
    return a + t * (b - a)


def batchify(data, device):
    """Lazily prepend a size-1 batch axis to each tensor and move it to ``device``.

    Args:
        data: Iterable of tensors.
        device: Target passed to ``Tensor.to``.

    Returns:
        A generator yielding ``t.unsqueeze(0).to(device)`` for each ``t`` in ``data``.
    """
    def _as_single_batch(tensor):
        return tensor.unsqueeze(0).to(device)

    return (_as_single_batch(tensor) for tensor in data)


def _make_seq_first(*args):
    # N, G, S, ... -> S, G, N, ...
    if len(args) == 1:
        arg, = args
        return arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None
    return (*(arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None for arg in args),)


def _make_batch_first(*args):
    # S, G, N, ... -> N, G, S, ...
    if len(args) == 1:
        arg, = args
        return arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None
    return (*(arg.permute(2, 1, 0, *range(3, arg.dim())) if arg is not None else None for arg in args),)


def _pack_group_batch(*args):
    # S, G, N, ... -> S, G * N, ...
    if len(args) == 1:
        arg, = args
        return arg.reshape(arg.size(0), arg.size(1) * arg.size(2), *arg.shape[3:]) if arg is not None else None
    return (*(arg.reshape(arg.size(0), arg.size(1) * arg.size(2), *arg.shape[3:]) if arg is not None else None for arg in args),)


def _unpack_group_batch(N, *args):
    # S, G * N, ... -> S, G, N, ...
    if len(args) == 1:
        arg, = args
        return arg.reshape(arg.size(0), -1, N, *arg.shape[2:]) if arg is not None else None
    return (*(arg.reshape(arg.size(0), -1, N, *arg.shape[2:]) if arg is not None else None for arg in args),)