#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/native/cuda/Pow.cuh>

namespace adam_atan2 {

using at::native::kILP;

constexpr int kArgsDepth = 4;

// Slot indices into the tensor-list metadata / register arrays.
constexpr uint8_t kParamIdx = 0;
constexpr uint8_t kGradIdx = 1;
constexpr uint8_t kExpAvgIdx = 2;
constexpr uint8_t kExpAvgSqIdx = 3;

template <typename T>
__device__ __forceinline__ T lerp(const T v0, const T v1, const T t) {
  // NOTE(one): Identical to PyTorch when t < 0.5
  // https://github.com/pytorch/pytorch/blob/b7f25226929e70187a9f36c393665abad0b25190/aten/src/ATen/native/Lerp.h#L21
  return fma(t, v1, fma(-t, v0, v0));
}

template <typename scalar_type, typename opmath_t>
__device__ __forceinline__ void adam_math(
    scalar_type r_args[kArgsDepth][kILP],
    const opmath_t& step_size,
    const opmath_t& wd_alpha,
    const opmath_t& mbeta1,
    const opmath_t& mbeta2,
    const opmath_t& bias_correction2_sqrt) {
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    // Load values.
    opmath_t param = static_cast<opmath_t>(r_args[kParamIdx][ii]);
    const opmath_t grad = static_cast<opmath_t>(r_args[kGradIdx][ii]);
    opmath_t exp_avg = static_cast<opmath_t>(r_args[kExpAvgIdx][ii]);
    opmath_t exp_avg_sq = static_cast<opmath_t>(r_args[kExpAvgSqIdx][ii]);

    param *= wd_alpha;

    exp_avg = lerp(exp_avg, grad, mbeta1);
    exp_avg_sq = lerp(exp_avg_sq, grad * grad, mbeta2);

    const opmath_t denom = std::sqrt(exp_avg_sq) / bias_correction2_sqrt;
    param -= step_size * std::atan2(exp_avg, denom);

    // Store results.
    r_args[kParamIdx][ii] = param;
    r_args[kExpAvgIdx][ii] = exp_avg;
    r_args[kExpAvgSqIdx][ii] = exp_avg_sq;
  }
}

template <typename scalar_type>
struct FusedAdamMathFunctor {
  using opmath_t = at::opmath_type<scalar_type>;

  __device__ __forceinline__ void operator()(
      int chunk_size,
      at::native::FusedOptimizerTensorListMetadata<kArgsDepth>& tl,
      const double& lr,
      const double& beta1,
      const double& beta2,
      const double& weight_decay) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];

    // Per-tensor step coefficients, computed once from the step count.
    const auto [step_size, wd_alpha, bias_correction2_sqrt, mbeta1, mbeta2] =
        [&]() -> std::tuple<opmath_t, opmath_t, opmath_t, opmath_t, opmath_t> {
      const auto* step_count =
          reinterpret_cast<const float*>(tl.state_steps_addresses[tensor_loc]);
      const auto bias_correction1 = 1 - at::native::pow_(beta1, *step_count);
      const auto bias_correction2 = 1 - at::native::pow_(beta2, *step_count);
      const auto bias_correction2_sqrt = std::sqrt(bias_correction2);

      return {
          static_cast<opmath_t>(lr / bias_correction1),
          static_cast<opmath_t>(1 - lr * weight_decay),
          static_cast<opmath_t>(bias_correction2_sqrt),
          static_cast<opmath_t>(1 - beta1),
          static_cast<opmath_t>(1 - beta2)};
    }();

    scalar_type* args[kArgsDepth];
    scalar_type r_args[kArgsDepth][kILP];
    const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;

    const bool all_aligned{
        at::native::init_args<kArgsDepth>(args, tl, chunk_idx, chunk_size, tensor_loc)};
    if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) {
      // Fast path: fully aligned chunk, use vectorized loads/stores.
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (int i = 0; i < kArgsDepth; i++) {
          at::native::load_store(r_args[i], args[i], 0, i_start);
        }
        adam_math(
            r_args, step_size, wd_alpha, mbeta1, mbeta2, bias_correction2_sqrt);
#pragma unroll
        for (int i = 0; i < kArgsDepth; i++) {
          // Gradients are read-only; skip writing them back.
          if (i != kGradIdx) {
            at::native::load_store(args[i], r_args[i], i_start, 0);
          }
        }
      }
    } else {
      // General path: elementwise loads/stores with bounds handling.
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        at::native::load_args<kArgsDepth>(r_args, args, i_start, chunk_size, n);
        adam_math(
            r_args, step_size, wd_alpha, mbeta1, mbeta2, bias_correction2_sqrt);
#pragma unroll
        for (int i = 0; i < kArgsDepth; i++) {
          if (i != kGradIdx) {
            at::native::store_args(args[i], r_args[i], i_start, chunk_size, n);
          }
        }
      }
    }
  }
};
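// Reference semantics of the functor above (a comment-only summary derived
// from adam_math(); symbol names here are illustrative). For each parameter p
// with gradient g, first moment m, second moment v, and step count t:
//
//   p <- p * (1 - lr * weight_decay)                 // decoupled weight decay
//   m <- lerp(m, g,     1 - beta1)                   // = beta1*m + (1-beta1)*g
//   v <- lerp(v, g * g, 1 - beta2)                   // = beta2*v + (1-beta2)*g^2
//   p <- p - (lr / bc1) * atan2(m, sqrt(v) / sqrt(bc2))
//
// where bc1 = 1 - beta1^t and bc2 = 1 - beta2^t. Unlike stock Adam there is
// no eps hyperparameter: atan2(m, denom) replaces m / (denom + eps) and stays
// well-defined even when denom == 0.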
void adam_atan2_cuda_impl_(
    std::vector<at::Tensor> params,
    std::vector<at::Tensor> grads,
    std::vector<at::Tensor> exp_avgs,
    std::vector<at::Tensor> exp_avg_sqs,
    std::vector<at::Tensor> state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay) {
  std::vector<std::vector<at::Tensor>> tensor_lists{
      params, grads, exp_avgs, exp_avg_sqs};

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      params[0].scalar_type(),
      "adam_atan2_kernel_cuda",
      [&]() {
        at::native::multi_tensor_apply_for_fused_optimizer<kArgsDepth>(
            tensor_lists,
            state_steps,
            FusedAdamMathFunctor<scalar_t>(),
            lr,
            beta1,
            beta2,
            weight_decay);
      });
}

} // namespace adam_atan2
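// ---------------------------------------------------------------------------
// Usage sketch (an assumption, not part of the original kernel): a minimal
// pybind11 binding for building this file as a PyTorch C++/CUDA extension via
// torch.utils.cpp_extension. The exported name and docstring are hypothetical;
// the real project may register the op differently.
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "adam_atan2_cuda_impl_",
      &adam_atan2::adam_atan2_cuda_impl_,
      "Fused AdamATan2 optimizer step (CUDA, updates tensors in place)");
}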