""" 'Fast' Normalization Functions | |
For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. | |
Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) | |
Hacked together by / Copyright 2022 Ross Wightman | |
""" | |
from typing import List, Optional

import torch
from torch.nn import functional as F

try:
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
    has_apex = True
except ImportError:
    has_apex = False
# fast (i.e. lower precision) norm can be disabled with this flag if issues crop up
_USE_FAST_NORM = False  # defaulting to False for now


def is_fast_norm():
    return _USE_FAST_NORM


def set_fast_norm(enable=True):
    global _USE_FAST_NORM
    _USE_FAST_NORM = enable
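

# Illustrative usage sketch (not part of the original module): the flag is global, so it is
# typically toggled once, before model construction or the forward pass. The helper name
# `_example_toggle_fast_norm` is hypothetical.
def _example_toggle_fast_norm():
    prev = is_fast_norm()
    set_fast_norm(True)   # enable the low precision GN/LN paths under autocast
    # ... build and run a model here ...
    set_fast_norm(prev)   # restore the previous setting
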
def fast_group_norm(
    x: torch.Tensor,
    num_groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.group_norm(x, num_groups, weight, bias, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts GN inputs to float32
        # here we use the low precision autocast dtype instead
        # FIXME what to do re CPU autocast?
        dt = torch.get_autocast_gpu_dtype()
        x = x.to(dt)
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None

    with torch.cuda.amp.autocast(enabled=False):
        return F.group_norm(x, num_groups, weight, bias, eps)
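

# Illustrative usage sketch (not part of the original module): under CUDA autocast the fast
# path keeps GroupNorm in the low precision autocast dtype rather than upcasting to float32.
# The helper name and the shapes / num_groups below are arbitrary example values.
def _example_fast_group_norm():
    x = torch.randn(2, 32, 8, 8, device='cuda')
    weight = torch.ones(32, device='cuda')
    bias = torch.zeros(32, device='cuda')
    with torch.cuda.amp.autocast():
        y = fast_group_norm(x, num_groups=8, weight=weight, bias=bias)
    return y  # dtype follows the autocast dtype (e.g. torch.float16)
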
def fast_layer_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.layer_norm(x, normalized_shape, weight, bias, eps)

    if has_apex:
        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)

    if torch.is_autocast_enabled():
        # normally native AMP casts LN inputs to float32
        # APEX LN does not, so this path matches APEX behaviour
        dt = torch.get_autocast_gpu_dtype()
        # FIXME what to do re CPU autocast?
        x = x.to(dt)
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None

    with torch.cuda.amp.autocast(enabled=False):
        return F.layer_norm(x, normalized_shape, weight, bias, eps)
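

# Illustrative sketch (not part of the original module): how a norm layer might dispatch
# between the fast path and the standard implementation based on the module-level flag.
# The class name `FastLayerNormExample` is hypothetical.
class FastLayerNormExample(torch.nn.LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if is_fast_norm():
            return fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)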