# Migrated from GitHub (commit 406f22d).
import torch
from torch import nn
def linear():
    """Return an identity module (a no-op "linear" activation)."""
    identity = nn.Identity()
    return identity
def relu():
    """Return a ReLU activation module."""
    activation = nn.ReLU()
    return activation
def prelu():
    """Return a PReLU activation module (learnable negative slope)."""
    activation = nn.PReLU()
    return activation
def leaky_relu():
    """Return a LeakyReLU activation module (default negative slope)."""
    activation = nn.LeakyReLU()
    return activation
def sigmoid():
    """Return a Sigmoid activation module."""
    activation = nn.Sigmoid()
    return activation
def softmax(dim=None):
    """Return a Softmax activation module.

    Args:
        dim: Dimension along which softmax is computed. ``None`` keeps
            torch's deprecated implicit-dimension behavior.
    """
    activation = nn.Softmax(dim=dim)
    return activation
def tanh():
    """Return a Tanh activation module."""
    activation = nn.Tanh()
    return activation
def gelu():
    """Return a GELU activation module."""
    activation = nn.GELU()
    return activation
def register_activation(custom_act):
    """Register a custom activation in this module's namespace.

    After registration, ``custom_act`` can be resolved by name through
    ``get``.

    Args:
        custom_act: A callable whose ``__name__`` is used as the
            registration key.

    Returns:
        The ``custom_act`` callable itself, so this function can also be
        used as a decorator (backward-compatible: previous version
        returned ``None``).

    Raises:
        ValueError: If a name equal to ``custom_act.__name__`` (or its
            lowercase form) is already present in the module namespace.
    """
    name = custom_act.__name__
    # Membership test on the dict directly; .keys() is redundant.
    if name in globals() or name.lower() in globals():
        raise ValueError(
            f"Activation {name} already exists. Choose another name."
        )
    globals()[name] = custom_act
    return custom_act
def get(identifier):
    """Resolve an activation identifier to a callable.

    Args:
        identifier: ``None``, a callable, or the string name of a
            registered activation.

    Returns:
        ``None`` when *identifier* is ``None``; the callable itself when
        *identifier* is callable; otherwise the module-level object
        registered under that name.

    Raises:
        ValueError: If the identifier is an unknown string or an
            unsupported type.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        resolved = globals().get(identifier)
        if resolved is not None:
            return resolved
    # Unknown name or unsupported type: single shared failure path.
    raise ValueError(
        "Could not interpret activation identifier: " + str(identifier)
    )
if __name__ == "__main__":
print(globals().keys())
print(globals().get("tanh"))