#include #include "registration.h" #include "torch_binding.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("adam_atan2_cuda_impl_(" "Tensor(a!)[] params, " "Tensor(b!)[] grads, " "Tensor(c!)[] exp_avgs, " "Tensor(d!)[] exp_avg_sqs, " "Tensor(e!)[] state_steps, " "float lr, " "float beta1, " "float beta2, " "float weight_decay) -> ()"); ops.impl("adam_atan2_cuda_impl_", torch::kCUDA, &adam_atan2::adam_atan2_cuda_impl_); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)