Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,101 Bytes
a249588 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
# Copyright (c) OpenMMLab. All rights reserved.
import functools
class OutputHook:
    """Record the outputs of named sub-layers of a module via forward hooks.

    Every name in ``outputs`` is resolved on ``module`` with ``rgetattr``
    (so dotted paths such as ``'backbone.layer1'`` work) and a forward hook
    is registered on the resolved layer. Each time that layer runs forward,
    its output is stored in ``layer_outputs[name]``, overwriting the
    previous value.

    The instance can be used as a context manager; leaving the ``with``
    block removes every registered hook.

    Args:
        module: Root object whose sub-layers are hooked. Presumably a
            ``torch.nn.Module`` -- the resolved layers must provide
            ``register_forward_hook``.
        outputs (list | tuple | None): Dotted attribute paths of the layers
            to record. Any other value (including ``None``) registers no
            hooks.
        as_tensor (bool): If True, store the raw output object. Otherwise
            store ``output.detach().cpu().numpy()`` (applied element-wise
            to list or tuple outputs, producing a list).
    """

    def __init__(self, module, outputs=None, as_tensor=False):
        self.outputs = outputs
        self.as_tensor = as_tensor
        # Maps each requested layer name to its most recently seen output.
        self.layer_outputs = {}
        self.register(module)

    def register(self, module):
        """Register a forward hook for every name in ``self.outputs``.

        Registration is all-or-nothing: if any name fails, the hooks already
        attached by this call are removed before the exception propagates,
        so nothing stays attached to the model.

        Args:
            module: Root object the names in ``self.outputs`` are resolved
                against.

        Raises:
            AttributeError: If a name cannot be resolved on ``module``.
        """

        def hook_wrapper(name):
            # Factory so every hook captures its own ``name`` rather than
            # the loop variable, which closures would bind late.
            def hook(model, input, output):
                if self.as_tensor:
                    self.layer_outputs[name] = output
                elif isinstance(output, (list, tuple)):
                    # Layers that return several tensors: convert each one.
                    self.layer_outputs[name] = [
                        out.detach().cpu().numpy() for out in output
                    ]
                else:
                    self.layer_outputs[name] = output.detach().cpu().numpy()

            return hook

        self.handles = []
        if not isinstance(self.outputs, (list, tuple)):
            return
        try:
            for name in self.outputs:
                try:
                    layer = rgetattr(module, name)
                except AttributeError as err:
                    # rgetattr signals a missing path component with
                    # AttributeError; keep that type for callers but name
                    # the requested layer in the message.
                    raise AttributeError(
                        f'Module {name} not found') from err
                except ModuleNotFoundError as err:
                    raise ModuleNotFoundError(
                        f'Module {name} not found') from err
                self.handles.append(
                    layer.register_forward_hook(hook_wrapper(name)))
        except BaseException:
            # A half-constructed hook object is never handed back to the
            # caller, so detach what was attached so far before re-raising.
            self.remove()
            raise

    def remove(self):
        """Remove all registered hooks; calling it again is a no-op."""
        for h in self.handles:
            h.remove()
        self.handles = []

    def __enter__(self):
        """Return this hook object for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove all hooks when leaving the ``with`` block."""
        self.remove()
# using wonder's beautiful simplification:
# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects
def rsetattr(obj, attr, val):
    """Assign ``val`` to the attribute at a dotted path below ``obj``.

    Everything before the last ``'.'`` in ``attr`` is resolved with
    ``rgetattr`` to find the owning object, and the final component is then
    set on that owner with ``setattr``. A path without any dot is set
    directly on ``obj``.

    Args:
        obj (object): Root object the path starts from.
        attr (str): Attribute path in dot notation (e.g. 'x.y.z').
        val (any): Value to store at the end of the path.

    Returns:
        None, the result of the underlying ``setattr`` call.
    """
    owner_path, _sep, leaf = attr.rpartition('.')
    # rpartition yields an empty owner_path when attr contains no dot.
    owner = obj if not owner_path else rgetattr(obj, owner_path)
    return setattr(owner, leaf, val)
def rgetattr(obj, attr, *args):
    """Recursively get a nested attribute of an object.

    The attribute path is split on ``'.'`` and each component is looked up
    on the result of the previous lookup, starting from ``obj``.

    Args:
        obj (object): The object the lookup starts from.
        attr (str): The attribute path in dot notation (e.g., 'x.y.z').
        *args (any): At most one optional default value. If given, it is
            returned as soon as any component of the path is missing,
            mirroring the built-in ``getattr(obj, name, default)``.

    Returns:
        The value at the end of the path, or the default if one was given
        and some component of the path does not exist.

    Raises:
        AttributeError: If a component is missing and no default was given.
        TypeError: If more than one default value is passed.
    """
    if len(args) > 1:
        raise TypeError(
            f'rgetattr expected at most 3 arguments, got {len(args) + 2}')
    current = obj
    for name in attr.split('.'):
        try:
            current = getattr(current, name)
        except AttributeError:
            if args:
                # Stop at the first missing component; continuing would look
                # up the remaining names on the default value itself.
                return args[0]
            raise
    return current
|