from typing import List | |
import torch | |
from ._ops import ops | |
def w8_a16_gemm(
    input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Forward ``input``, ``weight`` and ``scale`` to ``ops.w8_a16_gemm``.

    This is a thin Python wrapper; all computation happens in the compiled
    kernel exposed through ``ops``. The name suggests a GEMM with 8-bit
    weights and 16-bit activations, scaled by ``scale``.
    NOTE(review): that reading is inferred from the name only. Confirm the
    expected shapes, dtypes and weight layout against the kernel.

    Args:
        input: Left-hand operand, passed through unchanged.
        weight: Weight operand, passed through unchanged (presumably the
            output of ``preprocess_weights`` / ``quant_weights`` — verify).
        scale: Scale operand, passed through unchanged.

    Returns:
        The tensor produced by ``ops.w8_a16_gemm``.
    """
    kernel = ops.w8_a16_gemm
    return kernel(input, weight, scale)
def w8_a16_gemm_(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    output: torch.Tensor,
    m: int,
    n: int,
    k: int,
) -> torch.Tensor:
    """Forward to ``ops.w8_a16_gemm_`` with a caller-supplied ``output`` tensor.

    Arguments are passed positionally in the order
    ``(input, weight, scale, output, m, n, k)``.
    NOTE(review): the trailing underscore and the ``output`` parameter suggest
    the kernel writes its result into ``output`` in place. ``m``, ``n`` and
    ``k`` are presumably the GEMM problem dimensions. Neither point can be
    confirmed from this wrapper, so check both against the kernel.

    Args:
        input: Left-hand operand, passed through unchanged.
        weight: Weight operand, passed through unchanged.
        scale: Scale operand, passed through unchanged.
        output: Destination tensor handed to the kernel.
        m: Integer dimension forwarded to the kernel.
        n: Integer dimension forwarded to the kernel.
        k: Integer dimension forwarded to the kernel.

    Returns:
        The tensor returned by ``ops.w8_a16_gemm_``.
    """
    # Tensors first, then the integer sizes: the exact positional order the
    # kernel receives.
    operands = (input, weight, scale, output)
    dims = (m, n, k)
    return ops.w8_a16_gemm_(*operands, *dims)
def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor:
    """Run ``ops.preprocess_weights`` on ``origin_weight`` and return its result.

    Args:
        origin_weight: Weight tensor forwarded to the kernel.
        is_int4: Flag forwarded to the kernel.
            NOTE(review): presumably it selects 4-bit vs 8-bit handling —
            confirm.

    Returns:
        The tensor produced by ``ops.preprocess_weights``.
    """
    preprocessed = ops.preprocess_weights(origin_weight, is_int4)
    return preprocessed
def quant_weights(
    origin_weight: torch.Tensor,
    quant_type: torch.dtype,
    return_unprocessed_quantized_tensor: bool,
) -> List[torch.Tensor]:
    """Quantize ``origin_weight`` via ``ops.quant_weights``.

    Args:
        origin_weight: Weight tensor forwarded to the kernel.
        quant_type: Target ``torch.dtype`` forwarded to the kernel.
        return_unprocessed_quantized_tensor: Flag forwarded to the kernel.
            NOTE(review): the name suggests it controls whether an extra,
            non-preprocessed quantized tensor is included in the result —
            confirm against the kernel.

    Returns:
        The list of tensors returned by ``ops.quant_weights``. Its length and
        contents are determined by the kernel, not by this wrapper.
    """
    call_args = (origin_weight, quant_type, return_unprocessed_quantized_tensor)
    return ops.quant_weights(*call_args)