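"""Utilities for collecting and printing diagnostics (statistics of module
outputs, gradients and parameters) for a torch.nn.Module."""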
import logging |
|
import random |
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
from torch import Tensor, nn |
|
|
|
|
|
class TensorDiagnosticOptions(object): |
|
"""Options object for tensor diagnostics: |
|
|
|
Args: |
|
max_eig_dim: |
|
The maximum dimension for which we print out eigenvalues |
|
(limited for speed reasons). |
|
""" |
|
|
|
def __init__(self, max_eig_dim: int = 512): |
|
self.max_eig_dim = max_eig_dim |
|
|
|
def dim_is_summarized(self, size: int): |
|
return size > 10 and size != 31 |
|
|
|
|
|
def get_tensor_stats( |
|
x: Tensor, |
|
dim: int, |
|
stats_type: str, |
|
) -> Tuple[Tensor, int]: |
|
""" |
|
Returns the specified transformation of the Tensor (either x or x.abs() |
|
or (x > 0), summed over all but the index `dim`. |
|
|
|
Args: |
|
x: |
|
Tensor, tensor to be analyzed |
|
dim: |
|
Dimension with 0 <= dim < x.ndim |
|
stats_type: |
|
The stats_type includes several types: |
|
"abs" -> take abs() before summing |
|
"positive" -> take (x > 0) before summing |
|
"rms" -> square before summing, we'll take sqrt later |
|
"value" -> just sum x itself |
|
"max", "min" -> take the maximum or minimum [over all other dims but dim] |
|
instead of summing |
|
"rms-sort" -> this is a bit different than the others, it's based on computing |
|
the rms over the specified dim and returning percentiles of the result |
|
(11 of them). |
|
Returns: |
|
stats: a Tensor of shape (x.shape[dim],). |
|
count: an integer saying how many items were counted in each element |
|
of stats. |
|
""" |
|
|
|
if stats_type == "rms-sort": |
|
rms = (x**2).mean(dim=dim).sqrt() |
|
rms = rms.flatten() |
|
rms = rms.sort()[0] |
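        # pick 11 evenly spaced order statistics: the 0th, 10th, ..., 100th
        # percentiles of the rms values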
|
rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)] |
|
        count = 1
|
return rms, count |
|
|
|
count = x.numel() // x.shape[dim] |
|
|
|
if stats_type == "eigs": |
|
x = x.transpose(dim, -1) |
|
x = x.reshape(-1, x.shape[-1]) |
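        # return the (s, s) uncentered covariance (Gram) matrix, where s is the
        # size of dim `dim`; eigenvalues are extracted later, in
        # print_diagnostics()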
|
|
|
|
|
return torch.matmul(x.transpose(0, 1), x), count |
|
elif stats_type == "abs": |
|
x = x.abs() |
|
elif stats_type == "rms": |
|
x = x**2 |
|
elif stats_type == "positive": |
|
x = (x > 0).to(dtype=torch.float) |
|
else: |
|
assert stats_type in ["value", "max", "min"] |
|
|
|
sum_dims = [d for d in range(x.ndim) if d != dim] |
|
if len(sum_dims) > 0: |
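        # reduce over the other dims one at a time, starting from the last, so
        # that the dim indexes in sum_dims remain valid as x loses dimensions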
|
if stats_type == "max": |
|
for dim in reversed(sum_dims): |
|
x = torch.max(x, dim=dim)[0] |
|
elif stats_type == "min": |
|
for dim in reversed(sum_dims): |
|
x = torch.min(x, dim=dim)[0] |
|
else: |
|
x = torch.sum(x, dim=sum_dims) |
|
x = x.flatten().clone() |
|
return x, count |
|
|
|
|
|
@dataclass |
|
class TensorAndCount: |
|
tensor: Tensor |
|
count: int |
|
|
|
|
|
class TensorDiagnostic(object): |
|
"""This class is not directly used by the user, it is responsible for |
|
collecting diagnostics for a module or parameter tensor of a torch.nn.Module. |
|
|
|
Args: |
|
opts: |
|
Options object. |
|
name: |
|
        The name associated with this diagnostics object; will probably be
        {module_name}.X where X is "output" or "grad", or {parameter_name}.Y
        where Y is param_value or param_grad.
|
""" |
|
|
|
def __init__(self, opts: TensorDiagnosticOptions, name: str): |
|
self.opts = opts |
|
self.name = name |
|
self.class_name = None |
|
|
|
self.stats = None |
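        # self.stats will be a list indexed by dim: for each dim, a dict mapping
        # a stats_type (e.g. "abs", "eigs") to a list of TensorAndCount objects
        # (a list, because the same dim may have different sizes across batches),
        # or to None if that stats_type has been disabled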
|
|
|
|
|
|
|
|
|
|
|
|
|
self.scalar_stats = None |
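        # reserved for scalar statistics; not currently used by this class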
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def accumulate(self, x, class_name: Optional[str] = None): |
|
""" |
|
Accumulate tensors. |
|
""" |
|
if class_name is not None: |
|
self.class_name = class_name |
|
        if isinstance(x, tuple):
|
x = x[0] |
|
if not isinstance(x, Tensor): |
|
return |
|
if x.numel() == 0: |
|
return |
|
x = x.detach().clone() |
|
if x.ndim == 0: |
|
x = x.unsqueeze(0) |
|
ndim = x.ndim |
|
if self.stats is None: |
|
self.stats = [dict() for _ in range(ndim)] |
|
|
|
for dim in range(ndim): |
|
this_dim_stats = self.stats[dim] |
|
if ndim > 1: |
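                # for tensors with more than one dim we collect the full set of
                # stats; "eigs" only if this dim is small enough that the
                # eigendecomposition stays affordable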
|
|
|
|
|
stats_types = [ |
|
"abs", |
|
"max", |
|
"min", |
|
"positive", |
|
"value", |
|
"rms", |
|
"rms-sort", |
|
] |
|
if x.shape[dim] <= self.opts.max_eig_dim: |
|
stats_types.append("eigs") |
|
else: |
|
stats_types = ["value", "abs", "max", "min"] |
|
|
|
for stats_type in stats_types: |
|
stats, count = get_tensor_stats(x, dim, stats_type) |
|
if stats_type not in this_dim_stats: |
|
this_dim_stats[stats_type] = [] |
|
|
|
done = False |
|
if this_dim_stats[stats_type] is None: |
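                    # accumulation of this stats_type was disabled (below we set
                    # "eigs" to None when more than one size is seen on this dim)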
|
|
|
|
|
|
|
continue |
|
for s in this_dim_stats[stats_type]: |
|
if s.tensor.shape == stats.shape: |
|
if stats_type == "max": |
|
s.tensor = torch.maximum(s.tensor, stats) |
|
|
|
elif stats_type == "min": |
|
s.tensor = torch.minimum(s.tensor, stats) |
|
                        else:
                            s.tensor += stats
|
s.count += count |
|
done = True |
|
break |
|
if not done: |
|
if this_dim_stats[stats_type] != [] and stats_type == "eigs": |
|
|
|
|
|
|
|
this_dim_stats[stats_type] = None |
|
else: |
|
this_dim_stats[stats_type].append(TensorAndCount(stats, count)) |
|
|
|
def print_diagnostics(self): |
|
"""Print diagnostics for each dimension of the tensor.""" |
|
if self.stats is None: |
|
print(f"Warning: the stats of {self.name} is None.") |
|
return |
|
for dim, this_dim_stats in enumerate(self.stats): |
|
if "rms" in this_dim_stats and "value" in this_dim_stats: |
|
|
|
rms_stats_list = this_dim_stats["rms"] |
|
value_stats_list = this_dim_stats["value"] |
|
if len(rms_stats_list) == len(value_stats_list): |
|
stddev_stats_list = [] |
|
for r, v in zip(rms_stats_list, value_stats_list): |
|
stddev_stats_list.append( |
|
|
|
|
|
TensorAndCount( |
|
r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20), |
|
r.count, |
|
) |
|
) |
|
this_dim_stats["stddev"] = stddev_stats_list |
|
|
|
for stats_type, stats_list in this_dim_stats.items(): |
|
|
|
|
|
|
|
if stats_list is None: |
|
assert stats_type == "eigs" |
|
continue |
|
|
|
def get_count(count): |
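                    # "max"/"min" stats aggregate by elementwise max/min rather
                    # than summing, so they must not be divided by the count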
|
return 1 if stats_type in ["max", "min"] else count |
|
|
|
if len(stats_list) == 1: |
|
stats = stats_list[0].tensor / get_count(stats_list[0].count) |
|
else: |
|
|
|
|
|
stats = torch.cat( |
|
[x.tensor / get_count(x.count) for x in stats_list], dim=0 |
|
) |
|
|
|
if stats_type == "eigs": |
|
try: |
|
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"): |
|
eigs, _ = torch.linalg.eigh(stats) |
|
else: |
|
eigs, _ = torch.symeig(stats) |
|
stats = eigs.abs().sqrt() |
|
                    except Exception:
|
print("Error getting eigenvalues, trying another method.") |
|
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"): |
|
eigs, _ = torch.linalg.eig(stats) |
|
eigs = eigs.abs() |
|
else: |
|
eigs, _ = torch.eig(stats) |
|
eigs = eigs.norm(dim=1) |
|
stats = eigs.sqrt() |
|
|
|
|
|
if stats_type in ["rms", "stddev"]: |
|
|
|
stats = stats.sqrt() |
|
|
|
|
|
|
|
summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized( |
|
stats.numel() |
|
) |
|
if summarize: |
|
|
|
stats = stats.sort()[0] |
|
num_percentiles = 10 |
|
size = stats.numel() |
|
percentiles = [] |
|
for i in range(num_percentiles + 1): |
|
index = (i * (size - 1)) // num_percentiles |
|
percentiles.append(stats[index].item()) |
|
percentiles = ["%.2g" % x for x in percentiles] |
|
percentiles = " ".join(percentiles) |
|
ans = f"percentiles: [{percentiles}]" |
|
else: |
|
ans = stats.tolist() |
|
ans = ["%.2g" % x for x in ans] |
|
ans = "[" + " ".join(ans) + "]" |
|
if stats_type in ["value", "rms", "stddev", "eigs"]: |
|
|
|
|
|
|
|
|
|
norm = (stats**2).sum().sqrt().item() |
|
ans += f", norm={norm:.2g}" |
|
mean = stats.mean().item() |
|
rms = (stats**2).mean().sqrt().item() |
|
ans += f", mean={mean:.3g}, rms={rms:.3g}" |
|
|
|
|
|
|
|
|
|
|
|
|
|
sizes = [x.tensor.shape[0] for x in stats_list] |
|
size_str = ( |
|
f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}" |
|
) |
|
maybe_class_name = ( |
|
f" type={self.class_name}," if self.class_name is not None else "" |
|
) |
|
print( |
|
f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, " |
|
f"{stats_type} {ans}" |
|
) |
|
|
|
|
|
class ScalarDiagnostic(object): |
|
"""This class is not directly used by the user, it is responsible for |
|
collecting diagnostics for a single module (subclass of torch.nn.Module) that |
|
represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc. |
|
""" |
|
|
|
def __init__(self, opts: TensorDiagnosticOptions, name: str): |
|
self.opts = opts |
|
self.name = name |
|
self.class_name = None |
|
self.is_forward_pass = True |
|
|
|
self.tick_scale = None |
|
|
|
self.saved_inputs = [] |
|
self.is_ok = True |
|
|
|
self.counts = None |
|
self.sum_grad = None |
|
self.sum_gradsq = None |
|
self.sum_abs_grad = None |
|
|
|
def accumulate_input(self, x: Tensor, class_name: Optional[str] = None): |
|
""" |
|
Called in forward pass. |
|
""" |
|
if not self.is_forward_pass: |
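            # a backward pass just finished; this is the start of a new forward
            # pass, so discard any inputs left over from the previous one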
|
|
|
self.saved_inputs = [] |
|
self.is_forward_pass = True |
|
|
|
if class_name is not None: |
|
self.class_name = class_name |
|
if not self.is_ok: |
|
return |
|
|
|
limit = 10 |
|
if len(self.saved_inputs) > limit: |
|
print( |
|
f"ERROR: forward pass called for this module over {limit} times " |
|
f"with no backward pass. Will not accumulate scalar stats." |
|
) |
|
self.is_ok = False |
|
return |
|
self.saved_inputs.append(x) |
|
|
|
def accumulate_output_grad(self, grad: Tensor): |
|
if not self.is_ok: |
|
return |
|
if self.is_forward_pass: |
|
self.is_forward_pass = False |
|
|
|
last_shape = ( |
|
"n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape |
|
) |
|
if len(self.saved_inputs) == 0 or grad.shape != last_shape: |
|
print( |
|
f"ERROR: shape mismatch or no forward activation present when backward " |
|
f"pass called: grad shape ={tuple(grad.shape)}" |
|
f", num-saved-inputs={len(self.saved_inputs)}" |
|
f", shape-of-last-saved-input={last_shape}" |
|
) |
|
self.is_ok = False |
|
return |
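        # backward hooks fire in reverse order of the forward calls, so the most
        # recently saved input corresponds to this gradient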
|
|
|
x = self.saved_inputs.pop() |
|
self.process_input_and_grad(x, grad) |
|
|
|
def process_input_and_grad(self, x: Tensor, grad: Tensor): |
|
assert x.shape == grad.shape |
|
x = x.flatten() |
|
grad = grad.flatten() |
|
|
|
num_ticks_per_side = 256 |
|
|
|
if self.tick_scale is None: |
|
x_abs_sorted = x.abs().sort()[0] |
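            # choose the scale so that the 98th percentile of |x| lands at the
            # outermost tick; larger values get clamped into the edge ticks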
|
|
|
index = int(x.numel() * 0.98) |
|
self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side) |
|
|
|
|
|
self.counts = torch.zeros( |
|
2 * num_ticks_per_side, dtype=torch.long, device=x.device |
|
) |
|
self.sum_grad = torch.zeros( |
|
2 * num_ticks_per_side, dtype=torch.double, device=x.device |
|
) |
|
|
|
self.sum_gradsq = torch.zeros( |
|
2 * num_ticks_per_side, dtype=torch.double, device=x.device |
|
) |
|
self.sum_abs_grad = torch.zeros( |
|
2 * num_ticks_per_side, dtype=torch.double, device=x.device |
|
) |
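        # quantize each input value to an integer tick index in
        # [0, 2 * num_ticks_per_side)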
|
|
|
|
|
x = (x / self.tick_scale).to(torch.long) |
|
x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1) |
|
x = x + num_ticks_per_side |
|
|
|
self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x)) |
|
self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double)) |
|
self.sum_gradsq.index_add_( |
|
dim=0, index=x, source=(grad * grad).to(torch.double) |
|
) |
|
self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double)) |
|
|
|
def print_diagnostics(self): |
|
"""Print diagnostics.""" |
|
        if not self.is_ok or self.counts is None:
|
print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}") |
|
return |
|
|
|
counts = self.counts.to("cpu") |
|
sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32) |
|
sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32) |
|
sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32) |
|
|
|
counts_cumsum = counts.cumsum(dim=0) |
|
counts_tot = counts_cumsum[-1] |
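        # divide the tick histogram into num_bins bins containing roughly equal
        # numbers of samples, by bucketing the cumulative counts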
|
|
|
|
|
|
|
|
|
num_bins = 20 |
|
|
|
|
|
counts_per_bin = (counts_tot // num_bins) + 1 |
|
bin_indexes = counts_cumsum // counts_per_bin |
|
        bin_indexes = bin_indexes.clamp(min=0, max=num_bins - 1).to(torch.long)
|
|
|
bin_counts = torch.zeros(num_bins, dtype=torch.long) |
|
bin_counts.index_add_(dim=0, index=bin_indexes, source=counts) |
|
bin_grad = torch.zeros(num_bins) |
|
bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad) |
|
bin_gradsq = torch.zeros(num_bins) |
|
bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq) |
|
bin_abs_grad = torch.zeros(num_bins) |
|
bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad) |
|
|
|
bin_boundary_counts = ( |
|
torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin |
|
) |
|
bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts) |
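        # bin_tick_indexes holds the tick position of each bin boundary; convert
        # back to input units (tick index num_ticks_per_side corresponds to an
        # input value of 0)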
|
|
|
|
|
num_ticks_per_side = counts.numel() // 2 |
|
bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale |
|
|
|
bin_grad = bin_grad / (bin_counts + 1) |
|
        bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1)
|
|
|
|
|
bin_abs_grad = bin_abs_grad / (bin_counts + 1) |
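        # rel_grad measures the sign-consistency of the gradient within each bin
        # (close to +-1.0 if it almost always has the same sign); grad_conf is
        # roughly how many standard errors the mean gradient is from zero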
|
|
|
bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20) |
|
bin_conf = bin_grad / (bin_conf_interval + 1.0e-20) |
|
|
|
def tensor_to_str(x: Tensor): |
|
x = ["%.2g" % f for f in x] |
|
x = "[" + " ".join(x) + "]" |
|
return x |
|
|
|
maybe_class_name = ( |
|
f" type={self.class_name}," if self.class_name is not None else "" |
|
) |
|
|
|
print( |
|
f"module={self.name},{maybe_class_name} " |
|
f"bin-boundaries={tensor_to_str(bin_boundaries)}, " |
|
f"rel_grad={tensor_to_str(bin_rel_grad)}, " |
|
f"grad_conf={tensor_to_str(bin_conf)}" |
|
) |
|
|
|
|
|
class ModelDiagnostic(object): |
|
"""This class stores diagnostics for all tensors in the torch.nn.Module. |
|
|
|
Args: |
|
opts: |
|
Options object. |
|
""" |
|
|
|
def __init__(self, opts: Optional[TensorDiagnosticOptions] = None): |
|
|
|
|
|
if opts is None: |
|
self.opts = TensorDiagnosticOptions() |
|
else: |
|
self.opts = opts |
|
self.diagnostics = dict() |
|
|
|
def __getitem__(self, name: str): |
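        # names ending in ".scalar" get a ScalarDiagnostic; everything else gets
        # a TensorDiagnostic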
|
        T = ScalarDiagnostic if name.endswith(".scalar") else TensorDiagnostic
|
if name not in self.diagnostics: |
|
self.diagnostics[name] = T(self.opts, name) |
|
return self.diagnostics[name] |
|
|
|
def print_diagnostics(self): |
|
"""Print diagnostics for each tensor.""" |
|
for k in sorted(self.diagnostics.keys()): |
|
self.diagnostics[k].print_diagnostics() |
|
|
|
|
|
def get_class_name(module: nn.Module): |
|
ans = type(module).__name__ |
|
|
|
|
|
if ans == "Balancer" or ans == "ActivationBalancer": |
|
try: |
|
ans += f"[{float(module.min_positive)},{float(module.max_positive)}," |
|
f"{float(module.min_abs)},{float(module.max_abs)}]" |
|
        except Exception:
|
pass |
|
elif ans == "AbsValuePenalizer": |
|
try: |
|
ans += f"[{module.limit}]" |
|
        except Exception:
|
pass |
|
return ans |
|
|
|
|
|
def attach_diagnostics( |
|
model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None |
|
) -> ModelDiagnostic: |
|
"""Attach a ModelDiagnostic object to the model by |
|
1) registering forward hook and backward hook on each module, to accumulate |
|
its output tensors and gradient tensors, respectively; |
|
2) registering backward hook on each module parameter, to accumulate its |
|
values and gradients. |
|
|
|
Args: |
|
model: |
|
the model to be analyzed. |
|
opts: |
|
Options object. |
|
|
|
Returns: |
|
The ModelDiagnostic object attached to the model. |
|
""" |
|
|
|
ans = ModelDiagnostic(opts) |
|
for name, module in model.named_modules(): |
|
if name == "": |
|
name = "<top-level>" |
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): |
|
if isinstance(_output, tuple) and len(_output) == 1: |
|
_output = _output[0] |
|
|
|
if isinstance(_output, Tensor) and _output.dtype in ( |
|
torch.float32, |
|
torch.float16, |
|
torch.float64, |
|
): |
|
_model_diagnostic[f"{_name}.output"].accumulate( |
|
_output, class_name=get_class_name(_module) |
|
) |
|
elif isinstance(_output, tuple): |
|
for i, o in enumerate(_output): |
|
if isinstance(o, Tensor) and o.dtype in ( |
|
torch.float32, |
|
torch.float16, |
|
torch.float64, |
|
): |
|
_model_diagnostic[f"{_name}.output[{i}]"].accumulate( |
|
o, class_name=get_class_name(_module) |
|
) |
|
|
|
def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): |
|
if isinstance(_output, tuple) and len(_output) == 1: |
|
_output = _output[0] |
|
if isinstance(_output, Tensor) and _output.dtype in ( |
|
torch.float32, |
|
torch.float16, |
|
torch.float64, |
|
): |
|
_model_diagnostic[f"{_name}.grad"].accumulate( |
|
_output, class_name=get_class_name(_module) |
|
) |
|
elif isinstance(_output, tuple): |
|
for i, o in enumerate(_output): |
|
if isinstance(o, Tensor) and o.dtype in ( |
|
torch.float32, |
|
torch.float16, |
|
torch.float64, |
|
): |
|
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate( |
|
o, class_name=get_class_name(_module) |
|
) |
|
|
|
module.register_forward_hook(forward_hook) |
|
module.register_backward_hook(backward_hook) |
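        # note: register_backward_hook is deprecated in recent PyTorch in favor
        # of register_full_backward_hook; it is kept here for compatibility with
        # older versions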
|
|
|
if type(module).__name__ in [ |
|
"Sigmoid", |
|
"Tanh", |
|
"ReLU", |
|
"TanSwish", |
|
"Swish", |
|
"DoubleSwish", |
|
"Swoosh", |
|
]: |
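            # for common elementwise nonlinearities, additionally accumulate
            # scalar stats relating each input value to the gradient w.r.t.
            # the output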
|
|
|
|
|
|
|
|
|
def scalar_forward_hook( |
|
_module, _input, _output, _model_diagnostic=ans, _name=name |
|
): |
|
if isinstance(_input, tuple): |
|
(_input,) = _input |
|
assert isinstance(_input, Tensor) |
|
_model_diagnostic[f"{_name}.scalar"].accumulate_input( |
|
_input, class_name=get_class_name(_module) |
|
) |
|
|
|
def scalar_backward_hook( |
|
_module, _input, _output, _model_diagnostic=ans, _name=name |
|
): |
|
if isinstance(_output, tuple): |
|
(_output,) = _output |
|
assert isinstance(_output, Tensor) |
|
_model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output) |
|
|
|
module.register_forward_hook(scalar_forward_hook) |
|
module.register_backward_hook(scalar_backward_hook) |
|
|
|
for name, parameter in model.named_parameters(): |
|
|
|
def param_backward_hook( |
|
grad, _parameter=parameter, _model_diagnostic=ans, _name=name |
|
): |
|
_model_diagnostic[f"{_name}.param_value"].accumulate(_parameter) |
|
_model_diagnostic[f"{_name}.param_grad"].accumulate(grad) |
|
|
|
try: |
|
parameter.register_hook(param_backward_hook) |
|
        except Exception:
|
logging.warning( |
|
f"Warning: could not register backward hook for parameter {name}, " |
|
f"it might not be differentiable." |
|
) |
|
|
|
return ans |
|
|
|
|
|
def _test_tensor_diagnostic(): |
|
opts = TensorDiagnosticOptions(512) |
|
|
|
diagnostic = TensorDiagnostic(opts, "foo") |
|
|
|
for _ in range(10): |
|
diagnostic.accumulate(torch.randn(50, 100) * 10.0) |
|
|
|
diagnostic.print_diagnostics() |
|
|
|
model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80)) |
|
|
|
diagnostic = attach_diagnostics(model, opts) |
|
for _ in range(10): |
|
T = random.randint(200, 300) |
|
x = torch.randn(T, 100) |
|
y = model(x) |
|
y.sum().backward() |
|
|
|
diagnostic.print_diagnostics() |
|
|
|
|
|
if __name__ == "__main__": |
|
_test_tensor_diagnostic() |
|
|