Spaces:
Running
Running
File size: 1,283 Bytes
406f22d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
from torch import nn
def linear():
    """Build an identity module, used as the "no activation" choice."""
    identity = nn.Identity()
    return identity
def relu():
    """Build a fresh ``nn.ReLU`` activation module."""
    module = nn.ReLU()
    return module
def prelu(num_parameters=1, init=0.25):
    """Build a fresh ``nn.PReLU`` activation module.

    Args:
        num_parameters: number of learnable slopes (1 = shared across channels).
        init: initial value of the learnable slope(s).

    The defaults match ``nn.PReLU``'s own defaults, so ``prelu()`` builds the
    same module as before.
    """
    return nn.PReLU(num_parameters=num_parameters, init=init)
def leaky_relu(negative_slope=0.01):
    """Build a fresh ``nn.LeakyReLU`` activation module.

    Args:
        negative_slope: slope applied to negative inputs. The default matches
            ``nn.LeakyReLU``'s default, so ``leaky_relu()`` is unchanged.
    """
    return nn.LeakyReLU(negative_slope=negative_slope)
def sigmoid():
    """Build a fresh ``nn.Sigmoid`` activation module."""
    module = nn.Sigmoid()
    return module
def softmax(dim=None):
    """Build a fresh ``nn.Softmax`` module that normalises along ``dim``.

    NOTE(review): with ``dim=None`` PyTorch picks the dimension implicitly
    (a deprecated behaviour that warns when the module is called); callers
    should prefer passing ``dim`` explicitly.
    """
    module = nn.Softmax(dim=dim)
    return module
def tanh():
    """Build a fresh ``nn.Tanh`` activation module."""
    module = nn.Tanh()
    return module
def gelu():
    """Build a fresh ``nn.GELU`` activation module."""
    module = nn.GELU()
    return module
def register_activation(custom_act):
    """Register ``custom_act`` in this module so ``get`` can find it by name.

    The object is stored in this module's global namespace under
    ``custom_act.__name__``.

    Args:
        custom_act: a function or class with a ``__name__`` attribute; presumably
            calling it builds an activation module (TODO confirm with callers).

    Returns:
        ``custom_act`` itself, so this function can also be used as a
        decorator (``@register_activation``).

    Raises:
        ValueError: if the name, or its lowercase form, is already defined in
            this module. This also prevents overwriting existing module-level
            names such as imports or the ``get`` helper.
    """
    name = custom_act.__name__
    namespace = globals()
    if name in namespace or name.lower() in namespace:
        raise ValueError(f"Activation {name} already exists. Choose another name.")
    namespace[name] = custom_act
    return custom_act
def get(identifier):
    """Resolve ``identifier`` to an activation factory.

    Args:
        identifier: ``None``, a callable, or the name of a callable defined in
            (or registered into) this module.

    Returns:
        ``None`` when ``identifier`` is ``None``; ``identifier`` itself when it
        is callable; otherwise the module-level callable bound to that name.

    Raises:
        ValueError: if ``identifier`` is a string that does not name a callable
            in this module, or is of any other type.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        cls = globals().get(identifier)
        # Module globals also hold imported modules and dunder values such as
        # ``__name__``; only callables can serve as activation factories.
        if callable(cls):
            return cls
    raise ValueError("Could not interpret activation identifier: " + str(identifier))
if __name__ == "__main__":
    # Debug aid: show every module-level name, then the object bound to "tanh".
    # A single print call avoids binding any new global before the keys are shown.
    print(globals().keys(), globals().get("tanh"), sep="\n")
|