spikingjelly.activation_based.triton_kernel.neuron_kernel package

This package contains multi-step neuron kernels implemented with Triton.

torch.compile Compatibility

The Triton neuron backend is now compatible with torch.compile for the IF, LIF, and PLIF multi-step kernels.

Compatibility conditions:

  • torch>=2.6.0 with Triton installed.

  • A CUDA device is required (the Triton backend does not run on CPU).

  • Use neuron modules with step_mode='m' and backend='triton'.

  • The Triton backend supports only the Sigmoid and ATan surrogate functions.
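The conditions above can be turned into a quick pre-flight check. The sketch below is a hypothetical helper, `check_triton_compat`, and is not part of SpikingJelly. It takes plain strings and booleans so it runs without a GPU; the caller is assumed to supply values such as `torch.__version__`, the input tensor's `device.type`, and the surrogate's class name.

```python
import re
from typing import List, Tuple

# Hypothetical helper, not part of SpikingJelly: encodes the Triton-backend
# compatibility conditions as a list of human-readable problems.
SUPPORTED_SURROGATES = {"Sigmoid", "ATan"}


def parse_version(version: str) -> Tuple[int, int, int]:
    """Parse the leading 'major.minor[.patch]' of e.g. '2.6.0+cu124'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"cannot parse version string: {version!r}")
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


def check_triton_compat(
    torch_version: str,
    triton_installed: bool,
    device_type: str,
    step_mode: str,
    backend: str,
    surrogate_name: str,
) -> List[str]:
    """Return the violated conditions; an empty list means all of them hold."""
    problems = []
    if parse_version(torch_version) < (2, 6, 0):
        problems.append("torch>=2.6.0 is required")
    if not triton_installed:
        problems.append("Triton must be installed")
    if device_type != "cuda":
        problems.append("a CUDA device is required")
    if step_mode != "m":
        problems.append("step_mode must be 'm'")
    if backend != "triton":
        problems.append("backend must be 'triton'")
    if surrogate_name not in SUPPORTED_SURROGATES:
        problems.append(f"surrogate {surrogate_name!r} is not supported")
    return problems
```

In practice the arguments might come from `torch.__version__`, `importlib.util.find_spec("triton") is not None`, `x_seq.device.type`, and `type(surrogate_function).__name__`; on a supported setup the result is an empty list.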

Current limitations and notes:

  • Unsupported surrogate functions raise NotImplementedError on the Triton path.

  • torch.library.triton_op is used when available; otherwise, the backend falls back to custom_op at runtime.

  • Compile configurations known to cause problems in current validation:

    • torch.compile(..., mode="reduce-overhead") may trigger CUDAGraph output-overwrite runtime errors.

    • fullgraph=True can trigger backend compiler exceptions (observed on PLIF).

    • mode and options cannot be used together on some PyTorch versions.

  • The recommended setup is backend="inductor" with explicit options, tuning CUDAGraph-related options if needed.
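A minimal sketch of that recommendation, expressed as keyword arguments for torch.compile. `recommended_compile_kwargs` is a hypothetical helper, not a SpikingJelly or PyTorch API. The one option it sets, "triton.cudagraphs", is an Inductor config key; keeping it off by default is an assumption based on the CUDAGraph issues listed above, not a validated setting.

```python
# Hypothetical helper, not part of SpikingJelly or PyTorch: builds kwargs for
# torch.compile(model, **kwargs) that follow the notes above.
def recommended_compile_kwargs(use_cudagraphs: bool = False) -> dict:
    # mode="reduce-overhead" would enable CUDA graphs implicitly; here the
    # choice is made explicitly through the Inductor option instead.
    options = {"triton.cudagraphs": use_cudagraphs}
    return {
        "backend": "inductor",  # explicit backend, as recommended
        "fullgraph": False,     # fullgraph=True has raised compiler errors with PLIF
        "options": options,     # passed instead of `mode`, never together with it
    }
```

Assuming `net` is a CUDA module built from multi-step neurons with backend='triton', it would be compiled as `torch.compile(net, **recommended_compile_kwargs())`. If CUDA graphs are re-enabled and the output-overwrite error appears, PyTorch's error message for that case suggests cloning outputs or calling `torch.compiler.cudagraph_mark_step_begin()` before each invocation.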

IF

spikingjelly.activation_based.triton_kernel.neuron_kernel.integrate_and_fire.multistep_if(x_seq: Tensor, v_init: Tensor, v_threshold: float, v_reset: float | None, detach_reset: bool, surrogate_function) → tuple[Tensor, Tensor]

Multi-step IF neuron forward pass via Triton kernel.

Parameters:
  • x_seq (torch.Tensor) -- Input sequence, shape [T, N, *]

  • v_init (torch.Tensor) -- Initial membrane potential

  • v_threshold (float) -- Threshold voltage

  • v_reset (Optional[float]) -- Reset voltage (None for soft reset)

  • detach_reset (bool) -- Whether to detach the reset term in the backward pass

  • surrogate_function (surrogate.SurrogateFunctionBase) -- Surrogate gradient function

Returns:

Tuple (spike_seq, v_seq) of the output spike sequence and the membrane potential sequence

Return type:

tuple[torch.Tensor, torch.Tensor]
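To cross-check the kernel numerically, a pure-Python reference of the forward recurrence can help. This is a sketch under assumptions: it follows the standard SpikingJelly IF dynamics (charge H[t] = V[t-1] + X[t]; spike when H[t] >= v_threshold; hard reset to v_reset, or soft reset by subtracting v_threshold when v_reset is None). It returns the post-reset potential as v_seq, which is an assumption about what the Triton kernel returns. Nested [T][N] lists stand in for tensors, and `multistep_if_reference` is a hypothetical name. detach_reset and surrogate_function only affect the backward pass, so they are omitted.

```python
from typing import List, Optional, Tuple

Seq = List[List[float]]  # [T][N] nested lists standing in for a [T, N] tensor


def multistep_if_reference(
    x_seq: Seq,
    v_init: List[float],
    v_threshold: float,
    v_reset: Optional[float],
) -> Tuple[Seq, Seq]:
    """Hypothetical pure-Python reference for the multi-step IF forward pass."""
    v = list(v_init)
    spike_seq, v_seq = [], []
    for x_t in x_seq:  # T time steps
        spikes, v_next = [], []
        for v_prev, x in zip(v, x_t):  # N neurons
            h = v_prev + x  # charge: H[t] = V[t-1] + X[t]
            s = 1.0 if h >= v_threshold else 0.0  # Heaviside(H - V_th), 1 at 0
            if v_reset is None:  # soft reset: subtract the threshold
                v_new = h - s * v_threshold
            else:  # hard reset: jump to v_reset
                v_new = h * (1.0 - s) + s * v_reset
            spikes.append(s)
            v_next.append(v_new)
        v = v_next
        spike_seq.append(spikes)
        v_seq.append(v_next)  # assumption: post-reset potential V[t]
    return spike_seq, v_seq
```

With a constant input of 0.75, v_threshold=1.0, and soft reset, the sketch fires at t = 1, 2, 3 with potentials 0.75, 0.5, 0.25, 0.0, because each spike removes only the threshold and the surplus carries over.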

LIF

spikingjelly.activation_based.triton_kernel.neuron_kernel.lif.multistep_lif(x_seq: Tensor, v_init: Tensor, decay_input: bool, tau: float, v_threshold: float, v_reset: float | None, detach_reset: bool, surrogate_function) → tuple[Tensor, Tensor]

Multi-step LIF neuron forward pass via Triton kernel.

Parameters:
  • x_seq (torch.Tensor) -- Input sequence, shape [T, N, *]

  • v_init (torch.Tensor) -- Initial membrane potential

  • decay_input (bool) -- Whether input participates in decay

  • tau (float) -- Membrane time constant

  • v_threshold (float) -- Threshold voltage

  • v_reset (Optional[float]) -- Reset voltage (None for soft reset)

  • detach_reset (bool) -- Whether to detach the reset term in the backward pass

  • surrogate_function (surrogate.SurrogateFunctionBase) -- Surrogate gradient function

Returns:

Tuple (spike_seq, v_seq) of the output spike sequence and the membrane potential sequence

Return type:

tuple[torch.Tensor, torch.Tensor]
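The role of decay_input and tau can be seen in a pure-Python reference of the LIF charge step. This sketch assumes SpikingJelly's LIF dynamics: with decay_input=True, H = V + (X - (V - V_reset)) / tau; with decay_input=False, H = V - (V - V_reset) / tau + X; for soft reset (v_reset None) the decay target is 0. As in the IF case, returning the post-reset potential as v_seq is an assumption, and `multistep_lif_reference` is a hypothetical name.

```python
from typing import List, Optional, Tuple

Seq = List[List[float]]  # [T][N] nested lists standing in for a [T, N] tensor


def multistep_lif_reference(
    x_seq: Seq,
    v_init: List[float],
    decay_input: bool,
    tau: float,
    v_threshold: float,
    v_reset: Optional[float],
) -> Tuple[Seq, Seq]:
    """Hypothetical pure-Python reference for the multi-step LIF forward pass."""
    # Soft reset (v_reset=None) decays toward 0 in the charge step.
    v_rest = 0.0 if v_reset is None else v_reset
    v = list(v_init)
    spike_seq, v_seq = [], []
    for x_t in x_seq:  # T time steps
        spikes, v_next = [], []
        for v_prev, x in zip(v, x_t):  # N neurons
            if decay_input:  # input is divided by tau along with the leak
                h = v_prev + (x - (v_prev - v_rest)) / tau
            else:  # only the leak is divided by tau
                h = v_prev - (v_prev - v_rest) / tau + x
            s = 1.0 if h >= v_threshold else 0.0  # Heaviside(H - V_th)
            if v_reset is None:  # soft reset
                v_new = h - s * v_threshold
            else:  # hard reset
                v_new = h * (1.0 - s) + s * v_reset
            spikes.append(s)
            v_next.append(v_new)
        v = v_next
        spike_seq.append(spikes)
        v_seq.append(v_next)  # assumption: post-reset potential V[t]
    return spike_seq, v_seq
```

With tau=2.0 and decay_input=True, a constant input of 1.5 only charges the neuron halfway toward 1.5 per step, so it needs two steps to cross v_threshold=1.0; with decay_input=False the full input is added each step and only the leak is scaled.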

PLIF

spikingjelly.activation_based.triton_kernel.neuron_kernel.plif.multistep_plif(x_seq: Tensor, v_init: Tensor, r_tau: Tensor, decay_input: bool, v_threshold: float, v_reset: float | None, detach_reset: bool, surrogate_function) → tuple[Tensor, Tensor]

Multi-step Parametric LIF neuron forward pass via Triton kernel.

Parameters:
  • x_seq (torch.Tensor) -- Input sequence, shape [T, N, *]

  • v_init (torch.Tensor) -- Initial membrane potential

  • r_tau (torch.Tensor) -- Reciprocal of the learnable membrane time constant

  • decay_input (bool) -- Whether input participates in decay

  • v_threshold (float) -- Threshold voltage

  • v_reset (Optional[float]) -- Reset voltage (None for soft reset)

  • detach_reset (bool) -- Whether to detach the reset term in the backward pass

  • surrogate_function (surrogate.SurrogateFunctionBase) -- Surrogate gradient function

Returns:

Tuple (spike_seq, v_seq) of the output spike sequence and the membrane potential sequence

Return type:

tuple[torch.Tensor, torch.Tensor]
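PLIF uses the LIF recurrence with 1/tau replaced by r_tau. The sketch below assumes that, and also assumes r_tau comes from SpikingJelly's ParametricLIFNode parametrization 1/tau = sigmoid(w), with w initialized to -log(init_tau - 1) so that training starts at tau = init_tau. Both `r_tau_from_w` and `multistep_plif_reference` are hypothetical names, r_tau is a plain float here rather than a tensor, and returning the post-reset potential as v_seq is again an assumption.

```python
import math
from typing import List, Optional, Tuple

Seq = List[List[float]]  # [T][N] nested lists standing in for a [T, N] tensor


def r_tau_from_w(w: float) -> float:
    """Assumed parametrization: 1/tau = sigmoid(w), which keeps tau > 1."""
    return 1.0 / (1.0 + math.exp(-w))


def multistep_plif_reference(
    x_seq: Seq,
    v_init: List[float],
    r_tau: float,
    decay_input: bool,
    v_threshold: float,
    v_reset: Optional[float],
) -> Tuple[Seq, Seq]:
    """Hypothetical pure-Python reference for the multi-step PLIF forward pass."""
    v_rest = 0.0 if v_reset is None else v_reset  # soft reset decays toward 0
    v = list(v_init)
    spike_seq, v_seq = [], []
    for x_t in x_seq:  # T time steps
        spikes, v_next = [], []
        for v_prev, x in zip(v, x_t):  # N neurons
            if decay_input:  # H = V + (X - (V - V_rest)) * r_tau
                h = v_prev + (x - (v_prev - v_rest)) * r_tau
            else:  # H = V - (V - V_rest) * r_tau + X
                h = v_prev - (v_prev - v_rest) * r_tau + x
            s = 1.0 if h >= v_threshold else 0.0  # Heaviside(H - V_th)
            v_new = h - s * v_threshold if v_reset is None else h * (1.0 - s) + s * v_reset
            spikes.append(s)
            v_next.append(v_new)
        v = v_next
        spike_seq.append(spikes)
        v_seq.append(v_next)  # assumption: post-reset potential V[t]
    return spike_seq, v_seq
```

With init_tau = 2.0, w starts at -log(1) = 0, so r_tau = sigmoid(0) = 0.5 and the PLIF reference reproduces the LIF reference with tau=2.0; during training only w (and hence r_tau) changes.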