// adam-atan2/torch-ext/torch_binding.h
// drbh
// fix: refactors and adjust name
// 3683745
#pragma once
#include <torch/torch.h>
namespace adam_atan2 {

/// @brief Prototype of the `adam_atan2_cuda_impl_` entry point.
///
/// Only the declaration lives in this header; the definition is supplied by
/// another translation unit. The function returns nothing and receives every
/// argument by value (five tensor lists followed by four scalars).
///
/// NOTE(review): the names suggest a fused Adam-atan2 optimizer step executed
/// on CUDA that updates its tensors in place (trailing underscore) — confirm
/// against the definition.
/// NOTE(review): the five lists are presumably parallel, i.e. index i of each
/// list refers to the same parameter — confirm against the definition.
///
/// @param params       List of tensors; presumably the parameters to update.
/// @param grads        List of tensors; presumably the matching gradients.
/// @param exp_avgs     List of tensors; presumably first-moment running
///                     averages — TODO confirm.
/// @param exp_avg_sqs  List of tensors; presumably second-moment running
///                     averages — TODO confirm.
/// @param state_steps  List of tensors; presumably per-parameter step
///                     counters — TODO confirm.
/// @param lr           Scalar; presumably the learning rate.
/// @param beta1        Scalar; presumably the first-moment decay rate.
/// @param beta2        Scalar; presumably the second-moment decay rate.
/// @param weight_decay Scalar; presumably the weight-decay coefficient.
void adam_atan2_cuda_impl_(
    std::vector<at::Tensor> params,
    std::vector<at::Tensor> grads,
    std::vector<at::Tensor> exp_avgs,
    std::vector<at::Tensor> exp_avg_sqs,
    std::vector<at::Tensor> state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay);

}  // namespace adam_atan2