| | """Kernel test utils""" |
| |
|
import itertools
import random
import unittest.mock
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
                    Union)

import pytest
import torch
from torch._prims_common import TensorLikeType

# Names of the individual checks that torch.library.opcheck can run; the
# default set omits "test_aot_dispatch_dynamic".
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def rand_int8(shape: tuple, device: str = "cuda"):
    return to_int8(torch.rand(shape, device=device) * 255 - 128)
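

# Usage sketch (the shapes and CUDA device below are illustrative assumptions,
# not part of the utilities themselves): build random quantized operands for a
# kernel test.
#
#   a = to_fp8(torch.randn((16, 32), device="cuda"))
#   b = rand_int8((32, 64))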


def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose that compares in float64, so
    it also works for dtypes (such as fp8) that torch.isclose may not support
    directly.
    """
    torch._refs._check_close_args(name="torch.allclose",
                                  a=a,
                                  b=b,
                                  rtol=rtol,
                                  atol=atol)

    return bool(
        torch.all(
            torch.isclose(
                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
            )
        ).item()
    )
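
# Usage sketch (illustrative assumption): a quantized tensor compares equal to
# itself even when torch.isclose lacks fp8 kernels, because the comparison is
# done in float64.
#
#   x = to_fp8(torch.randn((8, 8), device="cuda"))
#   assert fp8_allclose(x, x)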


def opcheck(
    op: Union[
        torch._ops.OpOverload,
        torch._ops.OpOverloadPacket,
        torch._library.custom_ops.CustomOpDef,
    ],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True,
) -> Dict[str, str]:
    # Patch torch.allclose with the fp8-aware variant while the checks run;
    # return an empty dict without running anything when `cond` is False.
    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
        return (
            torch.library.opcheck(
                op, args, kwargs, test_utils=test_utils,
                raise_exception=raise_exception
            )
            if cond
            else {}
        )
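
# Usage sketch with a hypothetical registered custom op (the op name and the
# argument tuple are assumptions, for illustration only):
#
#   opcheck(torch.ops._C.my_op,
#           (out, a, b),
#           test_utils=DEFAULT_OPCHECK_TEST_UTILS)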


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
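    # group_broadcast extends elementwise broadcasting: a dimension whose
    # extent divides the target extent is stretched by repeating each element
    # target_extent // extent times along that dimension, e.g. for a target
    # shape of (2, 4):
    #
    #   [[1, 2],        [[1, 1, 2, 2],
    #    [3, 4]]   ->    [3, 3, 4, 4]]
    #
    # Dimensions of extent 1 are left untouched here and broadcast implicitly
    # in the multiplications below.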
    def group_broadcast(t, shape):
        for i, s in enumerate(shape):
            if t.shape[i] != s and t.shape[i] != 1:
                assert s % t.shape[i] == 0
                t = t.unsqueeze(i + 1)\
                    .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
                    .flatten(i, i + 1)
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

    output = torch.mm((scale_a * a.to(dtype=torch.float32)),
                      (scale_b * b.to(dtype=torch.float32))).to(out_dtype)

    if bias is not None:
        output = output + bias

    return output
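

# Usage sketch (shapes, scale granularities, and the CUDA device are
# illustrative assumptions): per-token scales for `a` and per-channel scales
# for `b` give a simple reference for a quantized GEMM.
#
#   a = rand_int8((16, 32))
#   b = rand_int8((32, 8))
#   scale_a = torch.rand((16, 1), device="cuda")  # one scale per row of a
#   scale_b = torch.rand((1, 8), device="cuda")   # one scale per column of b
#   ref = baseline_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)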