import numpy as np
import torch
from .... import configure
from .. import cuda_utils, tensor_cache
from ..cuda_utils import resolve_python_object
from .common import (
_CapturedAutogradCtx,
_decode_v_reset,
_resolve_sg_cuda_code_fun,
_should_stash_capture_ctx,
_sg_obj_id,
_stash_capture_ctx,
_take_capture_ctx,
cupy,
)
# Public names of this module: the two CUDA kernel builders and the
# user-facing multi-step Izhikevich entry point.
__all__ = ["create_fptt_kernel", "create_bptt_kernel", "multistep_izhikevich_ptt"]
def create_fptt_kernel(hard_reset: bool, dtype: str):
    r"""
    Build the forward-through-time CUDA kernel for the Izhikevich neuron.

    The generated kernel guards on ``index < neuron_num`` and, for each such
    index, walks the flattened sequence front to back in strides of
    ``neuron_num``. For every step ``t`` it

    * writes the pre-reset potential
      ``h = v + reciprocal_tau * (x + a0 * (v - v_rest) * (v - v_c) - w)``
      into ``h_seq[t]``;
    * computes ``z = w + reciprocal_tau_w * (a * (h - v_rest) - w)``;
    * writes ``spike_seq[t] = 1`` when ``h >= v_threshold`` and ``0`` otherwise;
    * writes the next potential into ``v_v_seq[t + neuron_num]``
      (``v_reset`` for hard reset, ``h - v_threshold`` for soft reset, ``h``
      when no spike was emitted);
    * writes the next recovery value ``z + b * spike`` into
      ``w_w_seq[t + neuron_num]``.

    :param hard_reset: ``True`` to emit the hard-reset variant (the kernel then
        takes an extra ``v_reset`` argument), ``False`` for soft reset
    :type hard_reset: bool
    :param dtype: Data type tag; only ``"fp32"`` is implemented
    :type dtype: str
    :return: Kernel built from the generated source with the compiler options
        and backend taken from ``configure``
    :rtype: cupy.RawKernel
    :raises TypeError: if ``dtype`` is not ``"fp32"``
    """
    if dtype != "fp32":
        raise TypeError(
            f"Unsupported dtype {dtype!r}: only 'fp32' forward kernels are implemented."
        )

    kernel_name = f"IzhikevichNode_fptt_{'hard' if hard_reset else 'soft'}Reset_{dtype}"
    # Only the hard-reset kernel receives the reset potential as an argument.
    v_reset_param = "const float & v_reset," if hard_reset else ""

    signature = rf"""
extern "C" __global__
void {kernel_name}(const float* x_seq, float* v_v_seq, float* h_seq, float* w_w_seq, float* spike_seq,
const float & reciprocal_tau,
const float & a0,
const float & v_c,
const float & v_threshold,
const float & v_rest,
const float & reciprocal_tau_w,
const float & a,
const float & b, {v_reset_param}
const int & neuron_num, const int & numel)
"""
    charge_and_fire = r"""
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < neuron_num)
{
const int dt = neuron_num;
for(int mem_offset = 0; mem_offset < numel; mem_offset += neuron_num)
{
const int t = index + mem_offset;
h_seq[t] = v_v_seq[t] + reciprocal_tau * (x_seq[t] + a0 * (v_v_seq[t] - v_rest) * (v_v_seq[t] - v_c) - w_w_seq[t]);
const float z = w_w_seq[t] + reciprocal_tau_w * (a * (h_seq[t] - v_rest) - w_w_seq[t]);
if (h_seq[t] >= v_threshold)
{
spike_seq[t] = 1.0f;
"""
    if hard_reset:
        reset_on_spike = r"""
v_v_seq[t + dt] = v_reset;
"""
    else:
        reset_on_spike = r"""
v_v_seq[t + dt] = h_seq[t] - v_threshold;
"""
    no_spike_and_recovery = r"""
}
else
{
spike_seq[t] = 0.0f;
v_v_seq[t + dt] = h_seq[t];
}
w_w_seq[t + dt] = z + b * spike_seq[t];
}
}
}
"""
    code = signature + charge_and_fire + reset_on_spike + no_spike_and_recovery

    return cupy.RawKernel(
        code,
        kernel_name,
        options=configure.cuda_compiler_options,
        backend=configure.cuda_compiler_backend,
    )
def create_bptt_kernel(
    sg_cuda_code_fun, hard_reset: bool, detach_reset: bool, dtype: str
):
    r"""
    Build the backward-through-time CUDA kernel for the Izhikevich neuron.

    The generated kernel guards on ``index < neuron_num`` and, for each such
    index, walks the flattened sequence from the last step to the first in
    strides of ``neuron_num``, carrying ``grad_h`` and ``grad_w`` across
    iterations. For every step it writes ``grad_x_seq[t]``; after the loop it
    writes ``grad_v_init[index]`` and ``grad_w_init[index]``.

    The derivative of the spike with respect to ``h`` is supplied by
    ``sg_cuda_code_fun``, which is called as
    ``sg_cuda_code_fun(x="over_th", y="grad_s_to_h", dtype=dtype)`` and whose
    returned CUDA snippet is spliced into the loop body.

    :param sg_cuda_code_fun: Callable returning the surrogate-gradient CUDA
        snippet; it is invoked before ``dtype`` is validated
    :type sg_cuda_code_fun: ``Callable``
    :param hard_reset: ``True`` to emit the hard-reset variant (the kernel then
        takes an extra ``v_reset`` argument), ``False`` for soft reset
    :type hard_reset: bool
    :param detach_reset: ``True`` to drop the surrogate-gradient term from
        ``grad_v_to_h`` (the reset is treated as constant in backward)
    :type detach_reset: bool
    :param dtype: Data type tag; only ``"fp32"`` is implemented
    :type dtype: str
    :return: Kernel built from the generated source with the compiler options
        and backend taken from ``configure``
    :rtype: cupy.RawKernel
    :raises TypeError: if ``dtype`` is not ``"fp32"``
    """
    kernel_name = f"IzhikevichNode_bptt_{'hard' if hard_reset else 'soft'}Reset_{'detachReset' if detach_reset else ''}_{dtype}"
    code_grad_s_to_h = sg_cuda_code_fun(x="over_th", y="grad_s_to_h", dtype=dtype)
    if dtype != "fp32":
        raise TypeError(
            f"Unsupported dtype {dtype!r}: only 'fp32' backward kernels are implemented."
        )

    # Only the hard-reset kernel receives the reset potential as an argument.
    v_reset_param = "const float & v_reset," if hard_reset else ""

    # NOTE(review): grad_w_seq is declared in the signature below, but the
    # kernel body never reads it -- confirm that upstream gradients w.r.t. the
    # recovery variable are intentionally ignored.
    signature = rf"""
extern "C" __global__
void {kernel_name}(
const float* grad_spike_seq, const float* grad_v_seq,
const float* grad_w_seq, const float* h_seq,
const float* spike_seq, const float* v_v_seq,
float* grad_x_seq, float* grad_v_init, float* grad_w_init,
const float & reciprocal_tau, const float & one_sub_reciprocal_tau_w,
const float & a_over_tau_w, const float & a0_over_tau,
const float & b, const float & neg_sum_v_rest_v_c,
const float & v_threshold, {v_reset_param}
const int & neuron_num, const int & numel)
"""
    loop_head = r"""
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < neuron_num)
{
float grad_h = 0.0f; // grad_h will be used recursively
float grad_w = 0.0f; // grad_w will be used recursively
for(int mem_offset = numel - neuron_num; mem_offset >= 0; mem_offset -= neuron_num)
{
const int t = index + mem_offset;
const float over_th = h_seq[t] - v_threshold;
"""

    # d v[t] / d h[t], keyed by (detach_reset, hard_reset).
    grad_v_to_h_variants = {
        (True, True): r"""
const float grad_v_to_h = 1.0f - spike_seq[t];
""",
        (True, False): r"""
const float grad_v_to_h = 1.0f;
""",
        (False, True): r"""
const float grad_v_to_h = 1.0f - spike_seq[t] + (v_reset - h_seq[t]) * grad_s_to_h;
""",
        (False, False): r"""
const float grad_v_to_h = 1.0f - v_threshold * grad_s_to_h;
""",
    }
    code_grad_v_to_h = grad_v_to_h_variants[(bool(detach_reset), bool(hard_reset))]

    # NOTE(review): grad_v_init is computed from grad_x_seq[index], which is
    # grad_h * reciprocal_tau, rather than from grad_h itself -- confirm the
    # extra reciprocal_tau factor is intended.
    loop_tail = r"""
grad_w = -reciprocal_tau * grad_h + one_sub_reciprocal_tau_w * grad_w;
grad_h = grad_w * (a_over_tau_w + b * grad_s_to_h) + ((1 + a0_over_tau * (2.0f * v_v_seq[t + neuron_num] + neg_sum_v_rest_v_c)) * grad_h + grad_v_seq[t]) * grad_v_to_h + grad_spike_seq[t] * grad_s_to_h;
grad_x_seq[t] = grad_h * reciprocal_tau;
}
grad_v_init[index] = grad_x_seq[index] * (1.0f + a0_over_tau * (2.0f * v_v_seq[index] + neg_sum_v_rest_v_c));
grad_w_init[index] = -reciprocal_tau * grad_h + one_sub_reciprocal_tau_w * grad_w;
}
}
"""
    code = signature + loop_head + code_grad_s_to_h + code_grad_v_to_h + loop_tail

    return cupy.RawKernel(
        code,
        kernel_name,
        options=configure.cuda_compiler_options,
        backend=configure.cuda_compiler_backend,
    )
def _iz_forward(
ctx,
x_seq: torch.Tensor,
v_init: torch.Tensor,
w_init: torch.Tensor,
tau: float,
v_threshold: float,
v_reset: float,
v_rest: float,
a: float,
b: float,
tau_w: float,
v_c: float,
a0: float,
detach_reset: bool,
sg_cuda_code_fun,
):
requires_grad = x_seq.requires_grad or v_init.requires_grad
device = x_seq.get_device()
if x_seq.dtype == torch.float32:
dtype = "fp32"
cp_dtype = np.float32
else:
raise NotImplementedError
zero_shape = list(x_seq.shape)
zero_shape[0] *= 4
v_seq, h_seq, w_seq, spike_seq = torch.split(
torch.zeros(zero_shape, device=x_seq.device, dtype=x_seq.dtype),
x_seq.shape[0],
)
v_v_seq = torch.cat((v_init.unsqueeze(0), v_seq))
w_w_seq = torch.cat((w_init.unsqueeze(0), w_seq))
with cuda_utils.DeviceEnvironment(device):
numel = x_seq.numel()
neuron_num = numel // x_seq.shape[0]
threads = configure.cuda_threads
blocks = cuda_utils.cal_blocks(neuron_num)
cp_numel = cupy.asarray(numel)
cp_neuron_num = cupy.asarray(neuron_num)
cp_v_threshold = cupy.asarray(v_threshold, dtype=cp_dtype)
cp_v_rest = cupy.asarray(v_rest, dtype=cp_dtype)
cp_v_c = cupy.asarray(v_c, dtype=cp_dtype)
cp_a0 = cupy.asarray(a0, dtype=cp_dtype)
cp_a = cupy.asarray(a, dtype=cp_dtype)
cp_b = cupy.asarray(b, dtype=cp_dtype)
cp_reciprocal_tau = cupy.asarray(1.0 / tau, dtype=cp_dtype)
cp_reciprocal_tau_w = cupy.asarray(1.0 / tau_w, dtype=cp_dtype)
cp_a0_over_tau = cupy.asarray(a0 / tau, dtype=cp_dtype)
cp_a_over_tau_w = cupy.asarray(a / tau_w, dtype=cp_dtype)
cp_one_sub_reciprocal_tau_w = cupy.asarray(1.0 - 1.0 / tau_w, dtype=cp_dtype)
cp_neg_sum_v_rest_v_c = cupy.asarray(-v_rest - v_c, dtype=cp_dtype)
if v_reset is None:
cp_v_reset = None
hard_reset = False
(
x_seq,
v_v_seq,
h_seq,
w_w_seq,
spike_seq,
cp_reciprocal_tau,
cp_a0,
cp_v_c,
cp_v_threshold,
cp_v_rest,
cp_reciprocal_tau_w,
cp_a,
cp_b,
cp_neuron_num,
cp_numel,
) = cuda_utils.get_contiguous(
x_seq,
v_v_seq,
h_seq,
w_w_seq,
spike_seq,
cp_reciprocal_tau,
cp_a0,
cp_v_c,
cp_v_threshold,
cp_v_rest,
cp_reciprocal_tau_w,
cp_a,
cp_b,
cp_neuron_num,
cp_numel,
)
kernel_args = [
x_seq,
v_v_seq,
h_seq,
w_w_seq,
spike_seq,
cp_reciprocal_tau,
cp_a0,
cp_v_c,
cp_v_threshold,
cp_v_rest,
cp_reciprocal_tau_w,
cp_a,
cp_b,
cp_neuron_num,
cp_numel,
]
else:
cp_v_reset = cupy.asarray(v_reset, dtype=cp_dtype)
hard_reset = True
(
x_seq,
v_v_seq,
h_seq,
w_w_seq,
spike_seq,
cp_reciprocal_tau,
cp_a0,
cp_v_c,
cp_v_threshold,
cp_v_rest,
cp_reciprocal_tau_w,
cp_a,
cp_b,
cp_v_reset,
cp_neuron_num,
cp_numel,
) = cuda_utils.get_contiguous(
x_seq,
v_v_seq,
h_seq,
w_w_seq,
spike_seq,
cp_reciprocal_tau,
cp_a0,
cp_v_c,
cp_v_threshold,
cp_v_rest,
cp_reciprocal_tau_w,
cp_a,
cp_b,
cp_v_reset,
cp_neuron_num,
cp_numel,
)
kernel_args = [
x_seq,
v_v_seq,
h_seq,
w_w_seq,
spike_seq,
cp_reciprocal_tau,
cp_a0,
cp_v_c,
cp_v_threshold,
cp_v_rest,
cp_reciprocal_tau_w,
cp_a,
cp_b,
cp_v_reset,
cp_neuron_num,
cp_numel,
]
kernel = create_fptt_kernel(hard_reset, dtype)
kernel(
(blocks,),
(threads,),
cuda_utils.wrap_args_to_raw_kernel(device, *kernel_args),
)
if requires_grad:
if configure.save_spike_as_bool_in_neuron_kernel:
ctx.s_shape = spike_seq.shape
ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike_seq)
ctx.save_for_backward(h_seq, v_v_seq)
else:
ctx.save_for_backward(h_seq, spike_seq, v_v_seq)
ctx.blocks = blocks
ctx.threads = threads
ctx.cp_numel = cp_numel
ctx.cp_neuron_num = cp_neuron_num
ctx.cp_reciprocal_tau = cp_reciprocal_tau
ctx.cp_one_sub_reciprocal_tau_w = cp_one_sub_reciprocal_tau_w
ctx.cp_a_over_tau_w = cp_a_over_tau_w
ctx.cp_a0_over_tau = cp_a0_over_tau
ctx.cp_b = cp_b
ctx.cp_neg_sum_v_rest_v_c = cp_neg_sum_v_rest_v_c
ctx.cp_v_threshold = cp_v_threshold
ctx.cp_v_reset = cp_v_reset
ctx.detach_reset = detach_reset
ctx.sg_cuda_code_fun = sg_cuda_code_fun
return spike_seq, v_v_seq[1:,], w_w_seq[1:,]
def _iz_backward(ctx, grad_spike_seq, grad_v_seq, grad_w_seq):
    """Run the backward kernel and return gradients for every forward input.

    Restores ``h_seq``, ``spike_seq`` and ``v_v_seq`` from ``ctx`` (spikes come
    from the bool tensor cache when that option is enabled), allocates one
    zero buffer that holds ``grad_x_seq`` followed by ``grad_v_init`` and
    ``grad_w_init``, and launches the kernel from :func:`create_bptt_kernel`
    with the launch configuration recorded during forward.

    :param ctx: Context populated by the forward pass
    :param grad_spike_seq: Gradient w.r.t. the spike output
    :param grad_v_seq: Gradient w.r.t. the membrane-potential output
    :param grad_w_seq: Gradient w.r.t. the recovery-variable output
    :return: ``(grad_x_seq, grad_v_init, grad_w_init)`` followed by eleven
        ``None`` entries for the non-tensor forward arguments
    :raises NotImplementedError: if the incoming gradient is not ``torch.float32``
    """
    device = grad_spike_seq.get_device()

    if configure.save_spike_as_bool_in_neuron_kernel:
        spike_seq = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
        h_seq, v_v_seq = ctx.saved_tensors
    else:
        h_seq, spike_seq, v_v_seq = ctx.saved_tensors

    # Single buffer: all but the last two slices are grad_x_seq, then
    # grad_v_init and grad_w_init.
    buffer_shape = list(grad_spike_seq.shape)
    buffer_shape[0] += 2
    grad_buffer = torch.zeros(
        buffer_shape, device=grad_spike_seq.device, dtype=grad_spike_seq.dtype
    )
    grad_x_seq = grad_buffer[0:-2]
    grad_v_init = grad_buffer[-2]
    grad_w_init = grad_buffer[-1]

    # Forward stores None for cp_v_reset when soft reset was used.
    hard_reset = ctx.cp_v_reset is not None

    if grad_spike_seq.dtype != torch.float32:
        raise NotImplementedError
    dtype = "fp32"

    kernel = create_bptt_kernel(
        ctx.sg_cuda_code_fun, hard_reset, ctx.detach_reset, dtype
    )

    # ctx attributes passed to the kernel, in kernel-signature order.
    scalar_attrs = [
        "cp_reciprocal_tau",
        "cp_one_sub_reciprocal_tau_w",
        "cp_a_over_tau_w",
        "cp_a0_over_tau",
        "cp_b",
        "cp_neg_sum_v_rest_v_c",
        "cp_v_threshold",
    ]
    if hard_reset:
        scalar_attrs.append("cp_v_reset")
    scalar_attrs.extend(["cp_neuron_num", "cp_numel"])

    with cuda_utils.DeviceEnvironment(device):
        prepared = list(
            cuda_utils.get_contiguous(
                grad_spike_seq,
                grad_v_seq,
                grad_w_seq,
                h_seq,
                spike_seq,
                v_v_seq,
                grad_x_seq,
                grad_v_init,
                grad_w_init,
                *[getattr(ctx, attr) for attr in scalar_attrs],
            )
        )
        (
            grad_spike_seq,
            grad_v_seq,
            grad_w_seq,
            h_seq,
            spike_seq,
            v_v_seq,
            grad_x_seq,
            grad_v_init,
            grad_w_init,
        ) = prepared[:9]
        # The contiguous scalars are written back onto ctx, as before.
        for attr, value in zip(scalar_attrs, prepared[9:]):
            setattr(ctx, attr, value)

        kernel_args = [
            grad_spike_seq,
            grad_v_seq,
            grad_w_seq,
            h_seq,
            spike_seq,
            v_v_seq,
            grad_x_seq,
            grad_v_init,
            grad_w_init,
        ]
        kernel_args.extend(getattr(ctx, attr) for attr in scalar_attrs)

        kernel(
            (ctx.blocks,),
            (ctx.threads,),
            cuda_utils.wrap_args_to_raw_kernel(device, *kernel_args),
        )

    # Three tensor gradients, then None for the eleven non-tensor arguments.
    return (grad_x_seq, grad_v_init, grad_w_init) + (None,) * 11
@torch.library.custom_op("sj::cupy_multistep_izhikevich_forward", mutates_args=())
def cupy_multistep_izhikevich_forward(
    x_seq: torch.Tensor,
    v_init: torch.Tensor,
    w_init: torch.Tensor,
    tau: float,
    v_threshold: float,
    v_reset: float,
    v_rest: float,
    a: float,
    b: float,
    tau_w: float,
    v_c: float,
    a0: float,
    detach_reset: bool,
    sg_id: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom-op wrapper around :func:`_iz_forward`.

    The surrogate function is looked up from ``sg_id`` and ``v_reset`` is
    passed through ``_decode_v_reset`` before the forward pass runs against a
    fresh ``_CapturedAutogradCtx``. The returned tuple is the three forward
    outputs plus a 0-dim ``int64`` token on ``x_seq``'s device: the id under
    which the captured context was stashed, or ``-1`` when nothing was stashed.
    """
    surrogate = resolve_python_object(sg_id)
    recorder = _CapturedAutogradCtx()

    decoded_v_reset = _decode_v_reset(v_reset)
    sg_code_fun = _resolve_sg_cuda_code_fun(surrogate)
    spike_seq, v_seq, w_seq = _iz_forward(
        recorder,
        x_seq,
        v_init,
        w_init,
        tau,
        v_threshold,
        decoded_v_reset,
        v_rest,
        a,
        b,
        tau_w,
        v_c,
        a0,
        detach_reset,
        sg_code_fun,
    )

    if _should_stash_capture_ctx((x_seq, v_init, w_init)):
        capture_id = _stash_capture_ctx(recorder)
    else:
        capture_id = -1

    capture_token = torch.tensor(capture_id, device=x_seq.device, dtype=torch.int64)
    return spike_seq, v_seq, w_seq, capture_token
@torch.library.register_fake("sj::cupy_multistep_izhikevich_forward")
def _cupy_multistep_izhikevich_forward_fake(*args):
    """Fake implementation: three outputs shaped like ``x_seq`` plus a 0-dim int64 token."""
    x_seq = args[0]
    shape = x_seq.shape
    spike_seq, v_seq, w_seq = (x_seq.new_empty(shape) for _ in range(3))
    capture_token = x_seq.new_empty((), dtype=torch.int64)
    return (spike_seq, v_seq, w_seq, capture_token)
def _setup_ctx(ctx, inputs, output):
capture_token = output[-1]
if capture_token.is_meta:
ctx.captured = None
return
capture_id = int(capture_token.item())
if capture_id < 0:
ctx.captured = None
return
ctx.captured = _take_capture_ctx(capture_id)
def _bw(ctx, *grad_outputs):
if ctx.captured is None:
raise RuntimeError("Missing captured context for backward.")
grads = _iz_backward(ctx.captured, *grad_outputs[:-1])
return (
grads[0],
grads[1],
grads[2],
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
# Wire the backward rule (_bw) and the context setup hook (_setup_ctx) to the
# "sj::cupy_multistep_izhikevich_forward" custom op.
torch.library.register_autograd(
    "sj::cupy_multistep_izhikevich_forward", _bw, setup_context=_setup_ctx
)
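# NOTE: a stray "[docs]" source-link marker copied from the rendered documentation page was here; it is not code.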
def multistep_izhikevich_ptt(
    x_seq,
    v_init,
    w_init,
    tau,
    v_threshold,
    v_reset,
    v_rest,
    a,
    b,
    tau_w,
    v_c,
    a0,
    detach_reset,
    surrogate_function,
):
    """Multi-step Izhikevich neuron forward pass via the CuPy custom op.

    Per step the forward kernel computes
    ``h = v + (x + a0 * (v - v_rest) * (v - v_c) - w) / tau`` and
    ``z = w + (a * (h - v_rest) - w) / tau_w``, emits a spike where
    ``h >= v_threshold``, resets ``v`` accordingly and sets the next recovery
    value to ``z + b * spike``.

    :param x_seq: Input sequence; dim 0 is the step dimension
    :type x_seq: ``torch.Tensor``
    :param v_init: Initial membrane potential
    :type v_init: ``torch.Tensor``
    :param w_init: Initial recovery variable
    :type w_init: ``torch.Tensor``
    :param tau: Membrane time constant (the kernel uses ``1 / tau``)
    :type tau: float
    :param v_threshold: Firing threshold compared against ``h``
    :type v_threshold: float
    :param v_reset: Reset potential; ``None`` selects soft reset and is sent
        to the op encoded as NaN
    :type v_reset: Optional[float]
    :param v_rest: Resting potential
    :type v_rest: float
    :param a: Coefficient of ``(h - v_rest)`` in the recovery-variable update
    :type a: float
    :param b: Amount added to the recovery variable when a spike is emitted
    :type b: float
    :param tau_w: Time constant of the recovery variable (the kernel uses
        ``1 / tau_w``)
    :type tau_w: float
    :param v_c: Second root of the quadratic term ``(v - v_rest) * (v - v_c)``
    :type v_c: float
    :param a0: Coefficient of the quadratic term ``(v - v_rest) * (v - v_c)``
    :type a0: float
    :param detach_reset: Whether to detach the reset term in backward
    :type detach_reset: bool
    :param surrogate_function: Surrogate gradient function; it is passed to
        the op as the integer id returned by ``_sg_obj_id``
    :type surrogate_function: ``surrogate.SurrogateFunctionBase``
    :return: Tuple of (spike_seq, v_seq, w_seq)
    :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    """
    sg_id = _sg_obj_id(surrogate_function)

    # The op schema has no optional float, so "no reset value" travels as NaN.
    if v_reset is None:
        encoded_v_reset = float("nan")
    else:
        encoded_v_reset = float(v_reset)

    op_outputs = cupy_multistep_izhikevich_forward(
        x_seq,
        v_init,
        w_init,
        tau,
        v_threshold,
        encoded_v_reset,
        v_rest,
        a,
        b,
        tau_w,
        v_c,
        a0,
        detach_reset,
        sg_id,
    )
    # Drop the trailing capture token; callers only see the three sequences.
    return op_outputs[:-1]