|
#pragma once |
|
#include <ATen/cuda/CUDAContext.h> |
|
#include <torch/extension.h> |
|
|
|
// Validates that tensor `x` resides on a CUDA device.
// Uses TORCH_CHECK rather than the deprecated AT_ASSERT. AT_ASSERT forwards to
// TORCH_INTERNAL_ASSERT, which reports a bad *user* input as an internal
// PyTorch bug. TORCH_CHECK throws c10::Error (RuntimeError in Python) carrying
// exactly the message below.
// `x` is parenthesized so expression arguments (e.g. `*ptr`) expand correctly;
// `#x` still stringizes the argument as written.
#define CHECK_CUDA(x)                                        \
  do {                                                       \
    TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor"); \
  } while (0)
|
|
|
// Validates that tensor `x` is contiguous in memory (per `is_contiguous()`).
// Uses TORCH_CHECK rather than the deprecated AT_ASSERT. AT_ASSERT forwards to
// TORCH_INTERNAL_ASSERT, which reports a bad *user* input as an internal
// PyTorch bug. TORCH_CHECK throws c10::Error (RuntimeError in Python) carrying
// exactly the message below.
// `x` is parenthesized so expression arguments expand correctly.
#define CHECK_CONTIGUOUS(x)                                                 \
  do {                                                                      \
    TORCH_CHECK((x).is_contiguous(), #x " must be a contiguous tensor");    \
  } while (0)
|
|
|
// Validates that tensor `x` has dtype at::ScalarType::Int (32-bit signed int).
// Uses TORCH_CHECK rather than the deprecated AT_ASSERT. AT_ASSERT forwards to
// TORCH_INTERNAL_ASSERT, which reports a bad *user* input as an internal
// PyTorch bug. TORCH_CHECK throws c10::Error (RuntimeError in Python) carrying
// exactly the message below.
// `x` is parenthesized so expression arguments expand correctly.
#define CHECK_IS_INT(x)                                      \
  do {                                                       \
    TORCH_CHECK((x).scalar_type() == at::ScalarType::Int,    \
                #x " must be an int tensor");                \
  } while (0)
|
|
|
// Validates that tensor `x` has dtype at::ScalarType::Float (32-bit float).
// Uses TORCH_CHECK rather than the deprecated AT_ASSERT. AT_ASSERT forwards to
// TORCH_INTERNAL_ASSERT, which reports a bad *user* input as an internal
// PyTorch bug. TORCH_CHECK throws c10::Error (RuntimeError in Python) carrying
// exactly the message below.
// `x` is parenthesized so expression arguments expand correctly.
#define CHECK_IS_FLOAT(x)                                    \
  do {                                                       \
    TORCH_CHECK((x).scalar_type() == at::ScalarType::Float,  \
                #x " must be a float tensor");               \
  } while (0)
|
|