Spaces:
Sleeping
Sleeping
| from typing import Optional | |
| import torch | |
| from torch import nn | |
| from torch.distributions.transforms import TanhTransform | |
class NonegativeParameter(nn.Module):
    """
    Overview:
        This module will output a non-negative parameter during the forward process.
        The value is stored in log space (``log_data``) and mapped back through ``exp`` in
        ``forward``, so the output is always positive regardless of the raw parameter value.
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True, delta: float = 1e-8):
        """
        Overview:
            Initialize the NonegativeParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to ``None``,
              the default value is 0.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
            - delta (:obj:`float`): The offset added to ``data`` before taking the log, so that a value of
              exactly 0 does not map to ``-inf``.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        # Remember the offset so that ``set_data`` applies the same one as the constructor.
        self.delta = delta
        self.log_data = nn.Parameter(torch.log(data + delta), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the non-negative parameter during the forward process.
        Returns:
            parameter (:obj:`torch.Tensor`): The generated parameter, i.e. ``exp(log_data)``.
        """
        return torch.exp(self.log_data)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the non-negative parameter. The ``requires_grad`` flag of the current
            parameter is kept.
        Arguments:
            data (:obj:`torch.Tensor`): The new value of the non-negative parameter.

        .. note::
            ``log_data`` is replaced by a newly created ``nn.Parameter`` object, so anything that holds a
            reference to the previous parameter object (for example an optimizer) keeps pointing at the
            old one.
        """
        keep_grad = self.log_data.requires_grad
        # Use the delta given at construction time rather than a hard-coded constant.
        self.log_data = nn.Parameter(torch.log(data + self.delta), requires_grad=keep_grad)
class TanhParameter(nn.Module):
    """
    Overview:
        This module will output a tanh parameter during the forward process.
        The value is stored as its inverse-tanh image (``data_inv``) and mapped back through
        ``TanhTransform`` in ``forward``.
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True):
        """
        Overview:
            Initialize the TanhParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to ``None``,
              the default value is 0.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        # The transform cache is keyed on object identity. ``forward`` always passes the same
        # ``nn.Parameter`` object, so with caching enabled every call after the first would return the
        # previously computed output even after the parameter was updated in place. Disable it.
        self.transform = TanhTransform(cache_size=0)
        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the tanh parameter during the forward process.
        Returns:
            parameter (:obj:`torch.Tensor`): The generated parameter, i.e. the tanh transform applied to
            ``data_inv``, recomputed on every call.
        """
        return self.transform(self.data_inv)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the tanh parameter. The ``requires_grad`` flag of the current parameter
            is kept.
        Arguments:
            data (:obj:`torch.Tensor`): The new value of the tanh parameter.

        .. note::
            ``data_inv`` is replaced by a newly created ``nn.Parameter`` object, so anything that holds a
            reference to the previous parameter object (for example an optimizer) keeps pointing at the
            old one.
        """
        keep_grad = self.data_inv.requires_grad
        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=keep_grad)