spikingjelly.activation_based.cuda_kernel.auto_cuda.cfunction 源代码

from typing import Optional


[文档] def wrap_return_codes(y: Optional[str], codes: str): if y is None: return f"({codes})" else: return f"{y} = {codes};"
[文档] def float2half2(y: Optional[str], x: str): codes = f"__float2half2_rn({x})" return wrap_return_codes(y, codes)
[文档] def constant(y: Optional[str], x: float, dtype: str): if dtype == "float": codes = f"{x}f" elif dtype == "half2": codes = f"__float2half2_rn({x}f)" else: raise NotImplementedError(dtype) return wrap_return_codes(y, codes)
[文档] def abs(y: Optional[str], x: str, dtype: str): if dtype == "float": codes = f"fabsf({x})" elif dtype == "half2": codes = f"__habs2({x})" else: raise NotImplementedError(dtype) return wrap_return_codes(y, codes)
[文档] def power(z: Optional[str], x: str, y: str, dtype: str): if dtype == "float": codes = f"__powf({x, y})" elif dtype == "half2": # CUDA FP16 does not provide powf function. We use z = 2 ** (log2(x) * y) codes = f"h2exp(__hmul2(h2log2({x}), {y}))" else: raise NotImplementedError(dtype) return wrap_return_codes(z, codes)
[文档] def if_else(z: Optional[str], x: str, y: str, mask: str, dtype: str): # z = x * mask + y * (1. - mask) if dtype == "float": codes = f"{x} * {mask} + {y} * (1.0f - {mask})" elif dtype == "half2": codes = f"__hfma2({x}, {mask}, __hmul2({y}, __hsub2(__float2half2_rn(1.0f), {mask})))" else: raise NotImplementedError(dtype) return wrap_return_codes(z, codes)
[文档] def if_else_else( w: Optional[str], x: str, y: str, z: str, mask_x: str, mask_y: str, dtype: str ): # w = mask_x * x + mask_y * y + (1. - mask_x * mask_y) * z if dtype == "float": codes = f"{mask_x} * {x} + {mask_y} * {y} + (1. - {mask_x} * {mask_y}) * {z}" else: codes = f"__hadd2(__hadd2(__hmul2({mask_x}, {x}), __hmul2({mask_y}, {y})), __hmul2({z}, __hsub2(__float2half_rn(1.0f), __hmul2({mask_x}, {mask_y}))))" return wrap_return_codes(w, codes)
[文档] def greater_equal(z: Optional[str], x: str, y: str, dtype: str): if dtype == "float": codes = f"(float) ({x} >= {y})" elif dtype == "half2": codes = f"__hgeu2({x}, {y})" else: raise NotImplementedError(dtype) return wrap_return_codes(z, codes)
[文档] def greater_than(z: Optional[str], x: str, y: str, dtype: str): if dtype == "float": codes = f"(float) ({x} > {y})" elif dtype == "half2": codes = f"__hgtu2({x}, {y})" else: raise NotImplementedError(dtype) return wrap_return_codes(z, codes)
[文档] def minimal(z: Optional[str], x: str, y: str, dtype: str): if dtype == "float": codes = f"min({x}, {y})" elif dtype == "half2": codes = f"__hmin2({x}, {y})" else: raise NotImplementedError(dtype) return wrap_return_codes(z, codes)
[文档] def maximum(z: Optional[str], x: str, y: str, dtype: str): if dtype == "float": codes = f"max({x}, {y})" elif dtype == "half2": codes = f"__hmax2({x}, {y})" else: raise NotImplementedError(dtype) return wrap_return_codes(z, codes)
[文档] def add(z: Optional[str], x: str, y: str, dtype: str): if dtype == "float": if x == "0.0f": codes = f"{y}" elif y == "0.0f": codes = f"{x}" else: codes = f"{x} + {y}" elif dtype == "half2": if x == "__float2half2_rn(0.0f)": codes = f"{y}" elif y == "__float2half2_rn(0.0f)": codes = f"{x}" else: codes = f"__hadd2({x}, {y})" else: raise NotImplementedError(dtype) return wrap_return_codes(z, codes)
[文档] def sub(z: Optional[str], x: str, y: str, dtype: str): if dtype == "float": if y == "0.0f": codes = f"{x}" else: codes = f"{x} - {y}" elif dtype == "half2": if y == "__float2half2_rn(0.0f)": codes = f"{x}" else: codes = f"__hsub2({x}, {y})" else: raise NotImplementedError(dtype) return wrap_return_codes(z, codes)
[文档] def mul(z: Optional[str], x: str, y: str, dtype: str): if dtype == "float": if x == "1.0f": codes = f"{y}" elif y == "1.0f": codes = f"{x}" else: codes = f"{x} * {y}" elif dtype == "half2": if x == "__float2half2_rn(1.0f)": codes = f"{y}" elif y == "__float2half2_rn(1.0f)": codes = f"{x}" else: codes = f"__hmul2({x}, {y})" else: raise NotImplementedError(dtype) return wrap_return_codes(z, codes)
[文档] def div(z: Optional[str], x: str, y: str, dtype: str): if dtype == "float": if y == "1.0f": codes = f"{x}" else: codes = f"{x} / {y}" elif dtype == "half2": if y == "__float2half2_rn(1.0f)": codes = f"{x}" else: codes = f"__h2div({x}, {y})" else: raise NotImplementedError(dtype) return wrap_return_codes(z, codes)
[文档] def neg(y: Optional[str], x: str, dtype: str): if dtype == "float": codes = f"- {x}" elif dtype == "half2": codes = f"__hneg2({x})" else: raise NotImplementedError(dtype) return wrap_return_codes(y, codes)
[文档] def heaviside(y: Optional[str], x: str, dtype: str): if dtype == "float": codes = f"{x} >= 0.0f ? 1.0f: 0.0f" elif dtype == "half2": codes = f"__hgeu2({x}, __float2half2_rn(0.0f))" else: raise NotImplementedError(dtype) return wrap_return_codes(y, codes)
[文档] def exp(y: Optional[str], x: str, dtype: str): if dtype == "float": codes = f"expf({x})" elif dtype == "half2": codes = f"h2exp({x})" else: raise NotImplementedError(dtype) return wrap_return_codes(y, codes)
[文档] def sigmoid(y: Optional[str], x: str, alpha: float, dtype: str): alpha = constant(None, alpha, dtype) if dtype == "float": codes = f"1.0f / (1.0f + expf(- {alpha} * {x}))" elif dtype == "half2": codes = f"__h2div(__float2half2_rn(1.0f), __hadd2(__float2half2_rn(1.0f), h2exp(__hneg2(__hmul2({alpha}, {x})))))" else: raise NotImplementedError(dtype) return wrap_return_codes(y, codes)
[文档] def sigmoid_backward(y: str, x: str, alpha: float, dtype: str): assert y is not None codes = ( sigmoid( y=f"const {dtype} sigmoid_backward__sigmoid_ax", x=x, alpha=alpha, dtype=dtype, ) + "\n" ) alpha = constant(None, alpha, dtype) if dtype == "float": codes += f"{y} = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * {alpha};" return codes elif dtype == "half2": codes += f"{y} = __hmul2(__hmul2(__hsub2(__float2half2_rn(1.0f), sigmoid_backward__sigmoid_ax), sigmoid_backward__sigmoid_ax), {alpha});" return codes else: raise NotImplementedError(dtype)
[文档] def atan_backward(y: str, x: str, alpha: float, dtype: str): assert y is not None alpha = constant(None, alpha, dtype) if dtype == "float": codes = f"const float atan_backward__alpha_x = ((float) 1.57079632679489661923) * {alpha} * {x};" codes += f"{y} = {alpha} / 2.0f / (1.0f + atan_backward__alpha_x * atan_backward__alpha_x);" return codes elif dtype == "half2": codes = f"const half2 atan_backward__alpha_x = __hmul2(__hmul2(__float2half2_rn((float) 1.57079632679489661923), {alpha}), {x});" codes += f"{y} = __h2div({alpha}, __hmul2(__float2half2_rn(2.0f), __hfma2(atan_backward__alpha_x, atan_backward__alpha_x, __float2half2_rn(1.0f))));" return codes else: raise NotImplementedError(dtype)
[文档] def piecewise_leaky_relu_backward(y: str, x: str, w: float, c: float, dtype: str): assert y is not None w_inv = constant(None, 1.0 / w, dtype) w = constant(None, w, dtype) c = constant(None, c, dtype) codes = greater_equal( z=f"const {dtype} piecewise_leaky_relu_backward__mask", x=w, y=abs(y=None, x=x, dtype=dtype), dtype=dtype, ) codes += if_else( z=y, x=w_inv, y=c, mask="piecewise_leaky_relu_backward__mask", dtype=dtype ) return codes
[文档] def s2nn_backward(y: str, x: str, alpha: float, beta: float, dtype: str): assert y is not None codes = sigmoid_backward( y=f"const {dtype} s2nn_backward__sgax", x=x, alpha=alpha, dtype=dtype ) codes += greater_than( z=f"const {dtype} s2nn_backward__mask", x=constant(None, 0.0, dtype), y=x, dtype=dtype, ) codes += if_else( z=y, x="s2nn_backward__sgax", y=div( z=None, x=constant(None, beta, dtype), y=add(z=None, x=x, y=constant(None, 1.0, dtype), dtype=dtype), dtype=dtype, ), mask="s2nn_backward__mask", dtype=dtype, ) return codes
[文档] def q_pseudo_spike_backward(y: str, x: str, alpha: float, dtype: str): assert y is not None alpha = constant(None, alpha, dtype) if dtype == "float": return f"{y} = __powf(2.0f * fabsf({x}) / ({alpha} - 1.0f) + 1.0f, - {alpha});" elif dtype == "half2": return power( z=y, x=f"__hadd2(__h2div(__hmul2(__float2half2_rn(2.0f), __habs2({x})), __hsub2({alpha}, __float2half2_rn(1.0f))), __float2half2_rn(1.0f))", y=f"__hneg2({alpha})", dtype=dtype, )
[文档] def leaky_k_relu_backward(y: str, x: str, leak: float, k: float, dtype: str): assert y is not None leak = constant(None, leak, dtype) k = constant(None, k, dtype) codes = greater_equal( z=f"const {dtype} leaky_k_relu_backward__mask", x=x, y=constant(None, 0.0, dtype), dtype=dtype, ) codes += if_else(z=y, x=k, y=leak, mask="leaky_k_relu_backward__mask", dtype=dtype) return codes
[文档] def fake_numerical_gradient_backward(y: str, x: str, alpha: float, dtype: str): assert y is not None alpha = constant(None, alpha, dtype) codes = greater_equal( z=f"{dtype} fake_numerical_gradient_backward__mask", x=x, y=constant(None, 0.0, dtype), dtype=dtype, ) codes += mul( z="fake_numerical_gradient_backward__mask", x="fake_numerical_gradient_backward__mask", y=constant(None, 2.0, dtype), dtype=dtype, ) codes += sub( z="fake_numerical_gradient_backward__mask", x="fake_numerical_gradient_backward__mask", y=constant(None, 1.0, dtype), dtype=dtype, ) codes += div( z="fake_numerical_gradient_backward__mask", x="fake_numerical_gradient_backward__mask", y=x, dtype=dtype, ) codes += minimal( z=y, x="fake_numerical_gradient_backward__mask", y=alpha, dtype=dtype ) return codes
[文档] def log_tailed_relu_backward(y: str, x: str, alpha: float, dtype: str): alpha = constant(None, alpha, dtype) codes = greater_equal( z=f"const {dtype} log_tailed_relu_backward__mask_le0", x=constant(None, 0.0, dtype), y=x, dtype=dtype, ) codes += greater_than( z=f"const {dtype} log_tailed_relu_backward__mask_gt1", x=x, y=constant(None, 1, dtype), dtype=dtype, ) codes += if_else_else( w=y, x=alpha, y=div(z=None, x=constant(None, 1.0, dtype), y=x, dtype=dtype), z=constant(None, 1.0, dtype), mask_x=f"const {dtype} log_tailed_relu_backward__mask_le0", mask_y=f"const {dtype} log_tailed_relu_backward__mask_gt1", dtype=dtype, ) return codes