File size: 328 Bytes
23d26f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import triton
import triton.language as tl
from packaging import version
# True when the installed Triton is version 3.0.0 or newer; used below to
# choose which implementation of softplus gets defined.
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
if TRITON3:
    @triton.jit
    def softplus(dt):
        """Return softplus(dt) = log(1 + exp(dt)), guarded against overflow.

        For large ``dt`` the intermediate ``exp(dt)`` overflows to ``inf`` in
        finite-precision floats, which would make the result ``inf`` as well.
        Above the cutoff of 20 the function is replaced by its asymptote
        ``dt``; the error there is at most log(1 + exp(-20)) ~= 2e-9, which is
        below float32 resolution at that magnitude.

        NOTE(review): ``dt`` is assumed to be a floating-point Triton scalar
        or block -- confirm against the calling kernels.
        """
        # tl.where evaluates both operands, so exp may still produce inf for
        # large dt, but that lane is discarded by the select.
        # NOTE(review): this branch spells log1p out as log(x + 1), presumably
        # because tl.math.log1p is not usable on Triton >= 3.0 -- confirm.
        return tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
else:
    @triton.jit
    def softplus(dt):
        """Return softplus(dt) = log(1 + exp(dt)), guarded against overflow.

        Variant for Triton older than 3.0.0, using ``tl.math.log1p``. As in
        the other variant, inputs above 20 return ``dt`` directly so that an
        overflowing ``exp(dt)`` cannot turn the result into ``inf``.

        NOTE(review): ``dt`` is assumed to be a floating-point Triton scalar
        or block -- confirm against the calling kernels.
        """
        # Both operands of tl.where are evaluated; the overflowed lane is
        # simply not selected.
        return tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)