"""Factory functions for common PyTorch activations, resolvable by name.

Each factory returns a fresh ``nn.Module``. ``get`` maps ``None``, callables,
or string names to the corresponding factory, and ``register_activation``
adds custom factories at runtime.
"""
from torch import nn


def linear():
    return nn.Identity()


def relu():
    return nn.ReLU()


def prelu():
    return nn.PReLU()


def leaky_relu():
    return nn.LeakyReLU()


def sigmoid():
    return nn.Sigmoid()


def softmax(dim=None):
    """Return ``nn.Softmax``; unlike the other factories this takes an
    argument, the reduction dimension ``dim`` (e.g. ``dim=-1``)."""
    return nn.Softmax(dim=dim)


def tanh():
    return nn.Tanh()


def gelu():
    return nn.GELU()
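
# Each factory above returns a freshly constructed ``nn.Module``. Consumers
# resolve a name to a factory and then instantiate it, e.g. ``get("relu")()``
# (a usage sketch; ``get`` is defined below).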


def register_activation(custom_act):
    """Register a custom activation factory so ``get`` can resolve it by name.

    Args:
        custom_act: Callable returning an ``nn.Module``. Its ``__name__`` is
            used as the lookup key and must not clash with an existing one.
    """
    if (
        custom_act.__name__ in globals()
        or custom_act.__name__.lower() in globals()
    ):
        raise ValueError(
            f"Activation {custom_act.__name__} already exists. Choose another name."
        )
    globals()[custom_act.__name__] = custom_act
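
# Hedged registration sketch (``my_swish`` is an illustrative name, not
# something defined in this file):
#
#     def my_swish():
#         return nn.SiLU()
#
#     register_activation(my_swish)
#     assert get("my_swish") is my_swish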


def get(identifier):
    """Resolve an activation identifier to a callable.

    ``identifier`` may be ``None`` (returned as ``None``), an already-callable
    object (returned unchanged), or the string name of a factory defined or
    registered in this module. Raises ``ValueError`` otherwise.
    """
    if identifier is None:
        return None
    elif callable(identifier):
        return identifier
    elif isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is None:
            raise ValueError(
                f"Could not interpret activation identifier: {identifier}"
            )
        return cls
    else:
        raise ValueError(
            f"Could not interpret activation identifier: {identifier}"
        )


if __name__ == "__main__":
    # Minimal smoke test: list module-level names and resolve one by string.
    print(globals().keys())
    print(get("tanh"))
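
    # Hedged end-to-end sketch: register a hypothetical custom activation and
    # resolve it by name. ``swish`` is an illustrative name, not part of the
    # original module; ``nn.SiLU`` is PyTorch's built-in swish implementation.
    def swish():
        return nn.SiLU()

    register_activation(swish)
    print(get("swish")())  # prints: SiLU()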