import enum
import functools
import os
import subprocess
import sys

import torch

from .numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5


def assert_equal(ref, tri):
    """Assert exact equality between reference and obtained values (tensors or scalars)."""
    if isinstance(ref, torch.Tensor):
        assert torch.all(ref == tri)
    else:
        assert ref == tri


def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
    """
    Compare reference values against obtained values, checking both the maximum and the
    RMS relative error against `maxtol` and `rmstol` respectively.
    """
    # 1-byte (float8) outputs: compare bit-exactly if ref already has the same dtype,
    # otherwise cast ref and fall through to the tolerance-based comparison below.
    if tri.dtype.itemsize == 1:
        ref_as_type = ref.to(tri.dtype)
        if ref.dtype == tri.dtype:
            assert torch.all(ref_as_type == tri)
            return
        ref = ref_as_type
    if maxtol is None:
        maxtol = 2e-2
    if rmstol is None:
        rmstol = 4e-3
    # cast to float32:
    ref = ref.to(torch.float32).detach()
    tri = tri.to(torch.float32).detach()
    assert ref.shape == tri.shape, f"Tensors must have the same shape {ref.shape=} {tri.shape=}"
    # deal with infinite elements:
    inf_mask_ref = torch.isinf(ref)
    inf_mask_tri = torch.isinf(tri)
    assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensors must have the same infinite elements"
    refn = torch.where(inf_mask_ref, 0, ref)
    trin = torch.where(inf_mask_tri, 0, tri)
    # normalise so that the RMS calculation doesn't overflow:
    eps = 1.0e-30
    multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
    refn *= multiplier
    trin *= multiplier
    ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
    rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
    max_err = torch.max(rel_err).item()
    rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
    if verbose:
        print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol))
        print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol))
    if max_err > maxtol:
        # report up to 1000 offending coordinates and their values before failing:
        bad_idxs = torch.nonzero(rel_err > maxtol)
        num_nonzero = bad_idxs.size(0)
        bad_idxs = bad_idxs[:1000]
        print("%d / %d mismatched elements (shape = %s) at coords %s" %
              (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist()))
        bad_idxs = bad_idxs.unbind(-1)
        print("ref values: ", ref[tuple(bad_idxs)].cpu())
        print("tri values: ", tri[tuple(bad_idxs)].cpu())
    assert max_err <= maxtol
    assert rms_err <= rmstol
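
# Example usage (a minimal sketch, not part of the test suite; `ref_out` and `tri_out`
# are hypothetical tensors, e.g. a PyTorch reference result and a Triton kernel result):
#
#   ref_out = torch.randn(1024, 1024, device="cuda")
#   tri_out = ref_out + 1e-4 * torch.randn_like(ref_out)
#   assert_close(ref_out, tri_out, maxtol=2e-2, rmstol=4e-3, description="example")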


class ComputeSanitizerTool(enum.Enum):
    MEMCHECK = "memcheck"
    RACECHECK = "racecheck"
    SYNCCHECK = "synccheck"
    INITCHECK = "initcheck"


def compute_sanitizer(**target_kwargs):
    """
    Decorator to run a test under compute-sanitizer with the PyTorch caching allocator
    disabled, to expose potential memory access errors.
    The decorated test must have a (possibly unused) `request_fixture` argument exposing
    pytest's `request` fixture.
    If a `run_sanitizer` argument is present and set to False, the sanitizer is not run.
    Running tests under compute-sanitizer requires launching a subprocess and is slow,
    so use it sparingly.
    """

    def decorator(test_fn):

        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
                test_fn(*args, **kwargs)
                return

            import psutil

            if target_kwargs.pop("clear_torch_cache", False):
                # If we don't pop clear_torch_cache, it won't pass the
                # target_kwargs.items() <= kwargs.items() condition below.
                torch.cuda.empty_cache()
            tools_to_check = target_kwargs.pop("tools_to_check", [ComputeSanitizerTool.MEMCHECK])
            assert isinstance(tools_to_check, list), f"{tools_to_check=}"
            assert all(tool in ComputeSanitizerTool for tool in tools_to_check), (
                f"{[tool for tool in tools_to_check if tool not in ComputeSanitizerTool]=}")

            # Only re-launch under compute-sanitizer if the decorator kwargs match the
            # test kwargs and we are not already running inside compute-sanitizer.
            ppid_name = psutil.Process(os.getppid()).exe()
            run_compute_sanitizer = target_kwargs.items() <= kwargs.items()
            if "run_sanitizer" in kwargs:
                run_compute_sanitizer &= kwargs["run_sanitizer"]
            if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
                for tool in tools_to_check:
                    # get the path of the file that defines the test:
                    path = os.path.realpath(test_fn.__globals__["__file__"])
                    env = {
                        "PATH": os.environ["PATH"],
                        "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
                        "TORCH_SHOW_CPP_STACKTRACES": "1",
                        "CUDA_LAUNCH_BLOCKING": "1",
                    }
                    if "CUDA_VISIBLE_DEVICES" in os.environ:
                        env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
                    assert "request_fixture" in kwargs, (
                        "memcheck'ed test must have a (possibly unused) `request` fixture")
                    test_id = kwargs["request_fixture"].node.callspec.id
                    cmd = f"{path}::{test_fn.__name__}[{test_id}]"
                    cmd = [
                        "compute-sanitizer",
                        "--target-processes=application-only",
                        "--destroy-on-device-error=context",
                        f"--tool={tool.value}",
                        sys.executable,
                        "-m",
                        "pytest",
                        "-vsx",
                        cmd,
                    ]
                    for opt in ["--update_checksum", "--ignore_checksum_error"]:
                        if opt in sys.argv:
                            cmd.append(opt)
                    out = subprocess.run(
                        cmd,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.STDOUT,
                        env=env,
                    )
                    sanitizer_ok = ("ERROR SUMMARY: 0 errors" in str(out.stdout)
                                    or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout))
                    test_output = out.stdout
                    if isinstance(test_output, bytes):
                        test_output = test_output.decode()
                    fail = False
                    if not sanitizer_ok:
                        print("compute-sanitizer returned an error")
                        fail = True
                    elif out.returncode != 0:
                        print("The test failed due to some other reason: "
                              "consider running without compute-sanitizer to verify.")
                        print(f"{out.returncode=}")
                        fail = True
                    if fail:
                        print("*****************************************************")
                        print("******************** TEST OUTPUT ********************")
                        print("*****************************************************")
                        print(test_output)
                        print("*****************************************************")
                        print("****************** TEST OUTPUT END ******************")
                        print("*****************************************************")
                        assert False, "test failed under compute-sanitizer"
            else:
                test_fn(*args, **kwargs)

        return wrapper

    return decorator
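
# Example usage (a minimal sketch, not part of this module; it assumes the test suite
# provides a `request_fixture` fixture exposing pytest's `request` object, and that the
# decorator's keyword arguments are a subset of the test's keyword arguments):
#
#   @compute_sanitizer(tools_to_check=[ComputeSanitizerTool.MEMCHECK, ComputeSanitizerTool.RACECHECK])
#   def test_my_kernel(request_fixture):
#       ...  # launch the kernel under test here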


def compute_actual_scale(x, dtype):
    """Return the scale that maps the max-magnitude element of `x` onto the largest finite value of `dtype`."""
    max_finite = {
        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
    }[dtype]
    return x.abs().max() / max_finite
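
# Example usage (a minimal sketch; `x` is a hypothetical tensor to be quantized to float8):
#
#   x = torch.randn(256, 256, device="cuda")
#   scale = compute_actual_scale(x, torch.float8_e4m3fn)
#   x_fp8 = (x / scale).to(torch.float8_e4m3fn)  # now spans the full finite float8 range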