Spaces:
Sleeping
Sleeping
| from typing import Union, List | |
| import torch | |
def is_differentiable(
        loss: torch.Tensor, model: Union[torch.nn.Module, List[torch.nn.Module]], print_instead: bool = False
) -> None:
    """
    Overview:
        Judge whether the model/models are differentiable. First check that every parameter's grad is None,
        then run ``loss.backward()``, finally check that every parameter's grad is a ``torch.Tensor``.
    Arguments:
        - loss (:obj:`torch.Tensor`): loss tensor of the model. ``backward()`` is called on it with no \
            arguments, so it is presumably a scalar.
        - model (:obj:`Union[torch.nn.Module, List[torch.nn.Module]]`): model or models to be checked. \
            A tuple of modules is accepted as well as a list.
        - print_instead (:obj:`bool`): If ``True``, print the name and grad of every parameter whose grad \
            is not a tensor after back propagation, instead of asserting. Default set to ``False``.
    Raises:
        - AssertionError: if ``loss`` is not a tensor; if an element of ``model`` is not an ``nn.Module``; \
            if a parameter already has a grad before back propagation (message: parameter name); or, when \
            ``print_instead`` is ``False``, if a parameter has no tensor grad afterwards (message: parameter name).
        - TypeError: if ``model`` is neither a list/tuple nor an ``nn.Module``.
    Note:
        Parameters with ``requires_grad=False`` never receive a grad, so they are reported as non-differentiable.
    """
    assert isinstance(loss, torch.Tensor)
    # Normalize the single-module and multi-module cases so both go through the same checks.
    if isinstance(model, (list, tuple)):
        modules = list(model)
    elif isinstance(model, torch.nn.Module):
        modules = [model]
    else:
        raise TypeError('model must be list or nn.Module')

    # Grads must be empty before back propagation; otherwise a stale grad could make
    # the post-backward check pass even though this loss does not reach the parameter.
    for m in modules:
        assert isinstance(m, torch.nn.Module)
        for k, p in m.named_parameters():
            assert p.grad is None, k

    loss.backward()

    # After back propagation every parameter reached by the loss holds a tensor grad.
    for m in modules:
        for k, p in m.named_parameters():
            has_grad = isinstance(p.grad, torch.Tensor)
            if print_instead:
                if not has_grad:
                    print(k, "grad is:", p.grad)
            else:
                assert has_grad, k