# Spaces: Running
import torch | |
from torch import nn | |
def linear() -> nn.Identity:
    """Build the "linear" activation: a freshly constructed ``nn.Identity``."""
    identity_layer = nn.Identity()
    return identity_layer
def relu() -> nn.ReLU:
    """Build a freshly constructed ``nn.ReLU`` module (default arguments)."""
    relu_layer = nn.ReLU()
    return relu_layer
def prelu() -> nn.PReLU:
    """Build a freshly constructed ``nn.PReLU`` module (default arguments)."""
    prelu_layer = nn.PReLU()
    return prelu_layer
def leaky_relu() -> nn.LeakyReLU:
    """Build a freshly constructed ``nn.LeakyReLU`` module (default arguments)."""
    leaky_layer = nn.LeakyReLU()
    return leaky_layer
def sigmoid() -> nn.Sigmoid:
    """Build a freshly constructed ``nn.Sigmoid`` module."""
    sigmoid_layer = nn.Sigmoid()
    return sigmoid_layer
def softmax(dim=None) -> nn.Softmax:
    """Build a freshly constructed ``nn.Softmax`` module.

    Args:
        dim: Forwarded unchanged as the ``dim`` argument of ``nn.Softmax``.
            Defaults to ``None``.

    Returns:
        The new ``nn.Softmax`` instance.
    """
    softmax_layer = nn.Softmax(dim=dim)
    return softmax_layer
def tanh() -> nn.Tanh:
    """Build a freshly constructed ``nn.Tanh`` module."""
    tanh_layer = nn.Tanh()
    return tanh_layer
def gelu() -> nn.GELU:
    """Build a freshly constructed ``nn.GELU`` module (default arguments)."""
    gelu_layer = nn.GELU()
    return gelu_layer
def register_activation(custom_act):
    """Register a custom activation factory under its own ``__name__``.

    The callable is stored in this module's global namespace, which is the
    same namespace ``get`` consults for string identifiers.

    Args:
        custom_act: A callable exposing ``__name__`` (e.g. a function or a
            class). Its ``__name__`` becomes the lookup key.

    Returns:
        ``custom_act`` itself, so the function can also be applied as a
        decorator without rebinding the decorated name to ``None``.

    Raises:
        AttributeError: If ``custom_act`` has no ``__name__`` attribute.
        ValueError: If ``custom_act`` is not callable, or if its name (or the
            lower-cased form of its name) is already present in the module
            namespace.
    """
    name = custom_act.__name__
    namespace = globals()

    # Named but non-callable objects (e.g. modules) can never act as an
    # activation factory; refuse them instead of polluting the namespace.
    if not callable(custom_act):
        raise ValueError(
            f"Activation {name} is not callable and cannot be registered."
        )

    # Guard against shadowing anything already defined in this module,
    # including a lower-case spelling of the same name.
    if name in namespace or name.lower() in namespace:
        raise ValueError(
            f"Activation {name} already exists. Choose another name."
        )

    namespace[name] = custom_act
    return custom_act
def get(identifier):
    """Resolve an activation identifier to an activation factory.

    Args:
        identifier: ``None``, a callable, or the name (``str``) of an
            activation factory stored in this module's global namespace.

    Returns:
        ``None`` when ``identifier`` is ``None``; ``identifier`` itself when
        it is callable; otherwise the callable stored under that name in the
        module namespace.

    Raises:
        ValueError: If a string identifier does not name a callable
            activation factory, or if ``identifier`` is neither ``None``, a
            callable, nor a string.
    """
    # ``None`` and ready-made callables pass straight through.
    if identifier is None or callable(identifier):
        return identifier

    if isinstance(identifier, str):
        # The lookup table is the module namespace, which also contains
        # imports (``torch``, ``nn``), module metadata (``__name__`` etc.)
        # and this module's own helper functions. None of those are
        # activations, so only callables outside the helper names qualify.
        helper_names = ("get", "register_activation")
        factory = globals().get(identifier)
        if identifier not in helper_names and callable(factory):
            return factory

    raise ValueError(
        "Could not interpret activation identifier: " + str(identifier)
    )
if __name__ == "__main__":
    # Manual smoke check: show every name in the module namespace, then the
    # object stored under "tanh". No helper variable is bound here on
    # purpose, because any new module-level name would show up in the keys.
    print(globals().keys(), globals().get("tanh"), sep="\n")