// Declares the `adam_atan2_cuda_impl_` operator in the extension's op library
// and binds its CUDA implementation.
//
// NOTE(review): TORCH_LIBRARY_EXPAND is defined outside this view; presumably it
// forwards to TORCH_LIBRARY after macro-expanding TORCH_EXTENSION_NAME — confirm.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Operator schema. The five tensor-list arguments each carry their own alias
  // set with a `!` marker (`Tensor(a!)[]` ... `Tensor(e!)[]`), i.e. they are
  // declared as written to by the op; the remaining four arguments are `float`
  // scalars and the op returns nothing (`-> ()`).
  // The adjacent string literals are concatenated by the compiler into a
  // single schema string, so the trailing ", " inside each piece is significant.
  ops.def(
      "adam_atan2_cuda_impl_("
      "Tensor(a!)[] params, "
      "Tensor(b!)[] grads, "
      "Tensor(c!)[] exp_avgs, "
      "Tensor(d!)[] exp_avg_sqs, "
      "Tensor(e!)[] state_steps, "
      "float lr, "
      "float beta1, "
      "float beta2, "
      "float weight_decay) -> ()");

  // Register the implementation under the CUDA dispatch key only; no other
  // backend is registered in this block.
  ops.impl("adam_atan2_cuda_impl_", torch::kCUDA, &adam_atan2::adam_atan2_cuda_impl_);
}
// NOTE(review): REGISTER_EXTENSION is defined outside this view; presumably it
// emits the Python module init entry point for TORCH_EXTENSION_NAME so the
// shared library can be imported — confirm against the macro's definition.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)