import logging
import random

import torch
from torch import Tensor, nn

|
def register_inf_check_hooks(model: nn.Module) -> None:
    """Register a forward hook and a backward hook on each module of
    ``model`` to check whether its output tensors and gradients are
    finite, logging a warning when they are not.

    Args:
      model:
        The model to be analyzed.
    """
|
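    # Attach a forward hook and a module-level backward hook to every
    # submodule, plus a gradient hook on every parameter.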
    for name, module in model.named_modules():
        if name == "":
            name = "<top-level>"
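
        # The ``_name=name`` default argument freezes the current loop value
        # of ``name``; without it, every hook would report the name of the
        # last module visited.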
        def forward_hook(_module, _input, _output, _name=name):
            if isinstance(_output, Tensor):
                try:
                    if not torch.isfinite(_output.to(torch.float32).sum()):
                        logging.warning(f"The sum of {_name}.output is not finite")
                except RuntimeError:
                    pass
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, tuple):
                        o = o[0]
                    if not isinstance(o, Tensor):
                        continue
                    try:
                        if not torch.isfinite(o.to(torch.float32).sum()):
                            logging.warning(
                                f"The sum of {_name}.output[{i}] is not finite"
                            )
                    except RuntimeError:
                        pass
|
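        # For module backward hooks the third argument is ``grad_output``,
        # i.e. the gradient of the loss w.r.t. this module's output.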
        def backward_hook(_module, _input, _output, _name=name):
            if isinstance(_output, Tensor):
                try:
                    if not torch.isfinite(_output.to(torch.float32).sum()):
                        logging.warning(f"The sum of {_name}.grad is not finite")
                except RuntimeError:
                    pass
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
                    if isinstance(o, tuple):
                        o = o[0]
                    if not isinstance(o, Tensor):
                        continue
                    try:
                        if not torch.isfinite(o.to(torch.float32).sum()):
                            logging.warning(
                                f"The sum of {_name}.grad[{i}] is not finite"
                            )
                    except RuntimeError:
                        pass
|
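        # ``register_backward_hook`` is deprecated in recent PyTorch releases
        # in favor of ``register_full_backward_hook``; it is kept here on the
        # assumption that compatibility with older releases is wanted.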
        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)
|
    for name, parameter in model.named_parameters():

        def param_backward_hook(grad, _name=name):
            if not torch.isfinite(grad.to(torch.float32).sum()):
                logging.warning(f"The sum of {_name}.param_grad is not finite")
|
        # ``Tensor.register_hook`` calls the hook with the parameter's
        # gradient each time it is computed during backward.
        try:
            parameter.register_hook(param_backward_hook)
        except Exception as e:
            logging.warning(
                f"Could not register backward hook for parameter {name}"
                f" with error {e}; it might not require a gradient."
            )
|

def _test_inf_check_hooks():
    model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))

    register_inf_check_hooks(model)
    for _ in range(10):
        T = random.randint(200, 300)
        x = torch.randn(T, 100)
        if T % 2 == 1:
            # Make roughly half of the inputs non-finite so that the
            # hooks have something to report.
            x[0, 0] = float("inf")
        y = model(x)
        y.sum().backward()


if __name__ == "__main__":
    _test_inf_check_hooks()