import torch from functools import partial from typing import Callable def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): def decorator(*args, **kwargs): if cuda_amp_deprecated: kwargs["device_type"] = "cuda" return dec(*args, **kwargs) return decorator if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] deprecated = True from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] else: deprecated = False from torch.cuda.amp import custom_fwd, custom_bwd custom_fwd = custom_amp_decorator(custom_fwd, deprecated) custom_bwd = custom_amp_decorator(custom_bwd, deprecated)