| import torch |
| import torch.nn as nn |
| from . import SparseTensor |
|
|
# Public API of this module: SparseTensor-aware wrappers around torch activations.
__all__ = ['SparseReLU', 'SparseSiLU', 'SparseGELU', 'SparseActivation']
|
|
|
|
class SparseReLU(nn.ReLU):
    """``nn.ReLU`` applied to the feature tensor of a ``SparseTensor``.

    Constructor arguments are inherited unchanged from ``nn.ReLU``.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply ReLU to ``input.feats`` and wrap the result with ``input.replace``.

        Only the feature tensor is passed through the activation. Whatever else
        ``input`` carries is handled by ``SparseTensor.replace`` (presumably it
        keeps the coordinates — confirm against that implementation).
        """
        activated = super().forward(input.feats)
        return input.replace(activated)
| |
|
|
class SparseSiLU(nn.SiLU):
    """``nn.SiLU`` applied to the feature tensor of a ``SparseTensor``.

    Constructor arguments are inherited unchanged from ``nn.SiLU``.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply SiLU to ``input.feats`` and wrap the result with ``input.replace``.

        Only the feature tensor is passed through the activation. Whatever else
        ``input`` carries is handled by ``SparseTensor.replace`` (presumably it
        keeps the coordinates — confirm against that implementation).
        """
        activated = super().forward(input.feats)
        return input.replace(activated)
|
|
|
|
class SparseGELU(nn.GELU):
    """``nn.GELU`` applied to the feature tensor of a ``SparseTensor``.

    Constructor arguments are inherited unchanged from ``nn.GELU``.
    """

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply GELU to ``input.feats`` and wrap the result with ``input.replace``.

        Only the feature tensor is passed through the activation. Whatever else
        ``input`` carries is handled by ``SparseTensor.replace`` (presumably it
        keeps the coordinates — confirm against that implementation).
        """
        activated = super().forward(input.feats)
        return input.replace(activated)
|
|
|
|
class SparseActivation(nn.Module):
    """Lift an arbitrary dense activation to ``SparseTensor`` inputs.

    Args:
        activation: Callable applied to ``input.feats`` in ``forward``. It is
            stored as ``self.activation``. When it is an ``nn.Module``, torch's
            ``Module.__setattr__`` registers it as a child module.
    """

    def __init__(self, activation: nn.Module):
        super().__init__()
        self.activation = activation

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Run ``self.activation`` on ``input.feats`` and rebuild via ``input.replace``."""
        dense_out = self.activation(input.feats)
        return input.replace(dense_out)
| |
|
|