spikingjelly.activation_based.triton_kernel.neuron_kernel package
This package contains multi-step neuron kernels implemented with Triton.
torch.compile Compatibility
The Triton neuron backend is now compatible with torch.compile for the IF, LIF, and PLIF multi-step kernels.
Compatibility conditions:
- torch>=2.6.0 with Triton installed.
- A CUDA device is required (the Triton backend does not run on CPU).
- Neuron modules must be used with step_mode='m' and backend='triton'.
- The Triton backend supports only the Sigmoid and ATan surrogate types.
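The conditions above can be checked before switching a model to backend='triton'. The sketch below is a hypothetical helper (unmet_triton_requirements is not part of SpikingJelly) that takes the relevant facts as plain arguments, so it runs without torch or Triton. Its version parsing is a simplified assumption that only looks at the leading major.minor.patch numbers.

```python
import re
from typing import List, Tuple

# Surrogate class names accepted by the Triton backend, per the list above.
SUPPORTED_TRITON_SURROGATES = ("Sigmoid", "ATan")


def parse_version(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings such as '2.6.0+cu124'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"cannot parse version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def unmet_triton_requirements(
    torch_version: str,
    triton_installed: bool,
    cuda_available: bool,
    step_mode: str,
    backend: str,
    surrogate_name: str,
) -> List[str]:
    """Return one message per unmet condition (an empty list means all hold)."""
    problems: List[str] = []
    if parse_version(torch_version) < (2, 6, 0):
        problems.append(f"torch>=2.6.0 required, found {torch_version}")
    if not triton_installed:
        problems.append("Triton is not installed")
    if not cuda_available:
        problems.append("a CUDA device is required")
    if step_mode != "m" or backend != "triton":
        problems.append("use step_mode='m' and backend='triton'")
    if surrogate_name not in SUPPORTED_TRITON_SURROGATES:
        # The Triton path raises NotImplementedError for these.
        problems.append(f"surrogate {surrogate_name!r} is not supported")
    return problems
```

In a real environment the arguments could come from torch.__version__, importlib.util.find_spec("triton") is not None, torch.cuda.is_available(), and type(surrogate_function).__name__.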
Current limitations and notes:
- Unsupported surrogate functions raise NotImplementedError in the Triton path.
- torch.library.triton_op is used when available; otherwise the kernels fall back to custom_op at runtime.
- Known problematic compile configurations from current validation:
  - torch.compile(..., mode="reduce-overhead") may trigger CUDAGraph output-overwrite runtime errors.
  - fullgraph=True can trigger backend compiler exceptions (observed on PLIF).
  - mode and options cannot be used together on some PyTorch versions.
Recommended: compile with backend="inductor" and an explicit options dict (rather than mode), tuning CUDAGraph-related options if needed.
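Following that recommendation, the keyword arguments for torch.compile can be assembled in one place. The helper name triton_neuron_compile_kwargs is hypothetical. "triton.cudagraphs" and "max_autotune" are Inductor option keys accepted in the options dict of torch.compile; CUDA graphs are left off by default here on the assumption that they (which mode="reduce-overhead" turns on) are what trigger the output-overwrite errors noted above.

```python
from typing import Any, Dict, Optional


def triton_neuron_compile_kwargs(
    use_cudagraphs: bool = False,
    extra_options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build torch.compile keyword arguments that avoid the configurations listed above."""
    # Explicit Inductor options replace mode="reduce-overhead"; mode is never set,
    # because mode and options cannot be combined on some PyTorch versions.
    options: Dict[str, Any] = {"triton.cudagraphs": use_cudagraphs}
    if extra_options:
        options.update(extra_options)
    return {
        "backend": "inductor",
        "options": options,
        # fullgraph=True has raised backend compiler exceptions (observed on PLIF).
        "fullgraph": False,
    }
```

Usage would look like compiled_net = torch.compile(net, **triton_neuron_compile_kwargs()), where net is a model built from neurons with step_mode='m' and backend='triton'.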
IF
- spikingjelly.activation_based.triton_kernel.neuron_kernel.integrate_and_fire.multistep_if(x_seq: Tensor, v_init: Tensor, v_threshold: float, v_reset: float | None, detach_reset: bool, surrogate_function) → tuple[Tensor, Tensor]
Multi-step IF neuron forward pass via a Triton kernel.
- Parameters:
  - x_seq (torch.Tensor) -- Input sequence of shape [T, N, *]
  - v_init (torch.Tensor) -- Initial membrane potential
  - v_threshold (float) -- Threshold voltage
  - v_reset (Optional[float]) -- Reset voltage (None for soft reset)
  - detach_reset (bool) -- Whether to detach the reset term in backward
  - surrogate_function (surrogate.SurrogateFunctionBase) -- Surrogate gradient function
- Returns: Tuple of (spike_seq, v_seq)
- Return type: tuple[Tensor, Tensor]
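To make the semantics of multistep_if concrete, the sketch below reproduces the forward dynamics for a single neuron in plain Python. It assumes the standard SpikingJelly IF update: charge H[t] = V[t-1] + X[t], fire S[t] = Heaviside(H[t] - v_threshold), then either a hard reset to v_reset or a soft reset that subtracts v_threshold. It is a reference for checking small cases, not the Triton kernel: it handles one neuron instead of a [T, N, *] tensor, omits detach_reset and surrogate_function (they only affect the backward pass), and returning the post-reset potential as v_seq is an assumption.

```python
from typing import List, Optional, Tuple


def multistep_if_reference(
    x_seq: List[float],
    v_init: float,
    v_threshold: float,
    v_reset: Optional[float],
) -> Tuple[List[float], List[float]]:
    """Forward pass of one IF neuron over T time steps (pure-Python reference)."""
    spike_seq: List[float] = []
    v_seq: List[float] = []
    v = v_init
    for x in x_seq:
        h = v + x  # charge: integrate the input with no leak
        s = 1.0 if h >= v_threshold else 0.0  # fire: Heaviside(h - v_threshold)
        if v_reset is None:
            v = h - s * v_threshold  # soft reset: subtract the threshold
        else:
            v = h * (1.0 - s) + s * v_reset  # hard reset: jump to v_reset
        spike_seq.append(s)
        v_seq.append(v)
    return spike_seq, v_seq


# Constant input of 0.75 with threshold 1.0: hard vs. soft reset.
print(multistep_if_reference([0.75] * 3, 0.0, 1.0, 0.0))   # ([0.0, 1.0, 0.0], [0.75, 0.0, 0.75])
print(multistep_if_reference([0.75] * 3, 0.0, 1.0, None))  # ([0.0, 1.0, 1.0], [0.75, 0.5, 0.25])
```

Soft reset keeps the residual charge above threshold, so the neuron fires again one step earlier than with a hard reset.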
LIF
- spikingjelly.activation_based.triton_kernel.neuron_kernel.lif.multistep_lif(x_seq: Tensor, v_init: Tensor, decay_input: bool, tau: float, v_threshold: float, v_reset: float | None, detach_reset: bool, surrogate_function) → tuple[Tensor, Tensor]
Multi-step LIF neuron forward pass via a Triton kernel.
- Parameters:
  - x_seq (torch.Tensor) -- Input sequence of shape [T, N, *]
  - v_init (torch.Tensor) -- Initial membrane potential
  - decay_input (bool) -- Whether the input participates in the decay
  - tau (float) -- Membrane time constant
  - v_threshold (float) -- Threshold voltage
  - v_reset (Optional[float]) -- Reset voltage (None for soft reset)
  - detach_reset (bool) -- Whether to detach the reset term in backward
  - surrogate_function (surrogate.SurrogateFunctionBase) -- Surrogate gradient function
- Returns: Tuple of (spike_seq, v_seq)
- Return type: tuple[Tensor, Tensor]
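A plain-Python sketch of the multistep_lif forward pass for one neuron, assuming SpikingJelly's usual LIF charge equations: with decay_input=True, H[t] = V[t-1] + (X[t] - (V[t-1] - v_reset)) / tau; with decay_input=False, H[t] = V[t-1] - (V[t-1] - v_reset) / tau + X[t]. For soft reset (v_reset=None) the sketch assumes the leak pulls toward 0. Firing is Heaviside(H[t] - v_threshold) followed by a hard or soft reset; detach_reset and surrogate_function only matter in the backward pass and are omitted.

```python
from typing import List, Optional, Tuple


def multistep_lif_reference(
    x_seq: List[float],
    v_init: float,
    decay_input: bool,
    tau: float,
    v_threshold: float,
    v_reset: Optional[float],
) -> Tuple[List[float], List[float]]:
    """Forward pass of one LIF neuron over T time steps (pure-Python reference)."""
    # Assumption: with soft reset (v_reset=None) the leak pulls toward 0.
    v_rest = 0.0 if v_reset is None else v_reset
    spike_seq: List[float] = []
    v_seq: List[float] = []
    v = v_init
    for x in x_seq:
        if decay_input:
            h = v + (x - (v - v_rest)) / tau  # input is scaled by 1/tau as well
        else:
            h = v - (v - v_rest) / tau + x  # only the leak is scaled by 1/tau
        s = 1.0 if h >= v_threshold else 0.0  # fire: Heaviside(h - v_threshold)
        if v_reset is None:
            v = h - s * v_threshold  # soft reset: subtract the threshold
        else:
            v = h * (1.0 - s) + s * v_reset  # hard reset: jump to v_reset
        spike_seq.append(s)
        v_seq.append(v)
    return spike_seq, v_seq
```

Algebraically, decay_input=True with input x is the same update as decay_input=False with input x / tau. For example, with tau=2.0, threshold 1.0, and hard reset to 0.0, a constant input of 1.5 with decay_input=True and a constant input of 0.75 with decay_input=False both yield spikes [0.0, 1.0, 0.0].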
PLIF
- spikingjelly.activation_based.triton_kernel.neuron_kernel.plif.multistep_plif(x_seq: Tensor, v_init: Tensor, r_tau: Tensor, decay_input: bool, v_threshold: float, v_reset: float | None, detach_reset: bool, surrogate_function) → tuple[Tensor, Tensor]
Multi-step Parametric LIF (PLIF) neuron forward pass via a Triton kernel.
- Parameters:
  - x_seq (torch.Tensor) -- Input sequence of shape [T, N, *]
  - v_init (torch.Tensor) -- Initial membrane potential
  - r_tau (torch.Tensor) -- Reciprocal of the learnable membrane time constant
  - decay_input (bool) -- Whether the input participates in the decay
  - v_threshold (float) -- Threshold voltage
  - v_reset (Optional[float]) -- Reset voltage (None for soft reset)
  - detach_reset (bool) -- Whether to detach the reset term in backward
  - surrogate_function (surrogate.SurrogateFunctionBase) -- Surrogate gradient function
- Returns: Tuple of (spike_seq, v_seq)
- Return type: tuple[Tensor, Tensor]
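For multistep_plif the leak factor is r_tau = 1/tau, passed as a tensor rather than derived from a fixed float tau, so the charge step multiplies by r_tau instead of dividing by tau. The sketch below models one neuron with r_tau as a plain float. It also assumes, as in SpikingJelly's ParametricLIFNode, that r_tau is produced as sigmoid(w) of a learnable scalar w, which keeps it in (0, 1); r_tau_from_w is a hypothetical helper illustrating that parameterization, not a library function. Soft reset is assumed to leak toward 0, and the backward-only arguments are omitted.

```python
import math
from typing import List, Optional, Tuple


def r_tau_from_w(w: float) -> float:
    """Assumed PLIF parameterization: 1/tau = sigmoid(w), which lies in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-w))


def multistep_plif_reference(
    x_seq: List[float],
    v_init: float,
    r_tau: float,
    decay_input: bool,
    v_threshold: float,
    v_reset: Optional[float],
) -> Tuple[List[float], List[float]]:
    """Forward pass of one PLIF neuron over T time steps (pure-Python reference)."""
    v_rest = 0.0 if v_reset is None else v_reset  # soft reset leaks toward 0 (assumption)
    spike_seq: List[float] = []
    v_seq: List[float] = []
    v = v_init
    for x in x_seq:
        if decay_input:
            h = v + (x - (v - v_rest)) * r_tau  # multiply by r_tau instead of dividing by tau
        else:
            h = v - (v - v_rest) * r_tau + x
        s = 1.0 if h >= v_threshold else 0.0  # fire: Heaviside(h - v_threshold)
        # Soft reset subtracts the threshold; hard reset jumps to v_reset.
        v = h - s * v_threshold if v_reset is None else h * (1.0 - s) + s * v_reset
        spike_seq.append(s)
        v_seq.append(v)
    return spike_seq, v_seq


# w = 0.0 gives r_tau = 0.5, i.e. the same dynamics as a LIF neuron with tau = 2.0.
r_tau = r_tau_from_w(0.0)
print(r_tau)  # 0.5
print(multistep_plif_reference([1.5] * 3, 0.0, r_tau, True, 1.0, 0.0))  # ([0.0, 1.0, 0.0], [0.75, 0.0, 0.75])
```

Learning w rather than tau directly keeps the leak factor bounded during training, since sigmoid never leaves (0, 1).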