FlexSN#

Author: Yifan Huang (AllenYolk), wei.fang

中文版: FlexSN

This tutorial focuses on FlexSN. FlexSN can generate high-performance multi-step kernels from a user-defined single-step neuronal dynamics function core. If you have not read the Triton backend basics yet, it is recommended to read Triton Backend first to understand the usage and constraints of predefined Triton neuron kernels.

Using FlexSN to Customize Triton Neuron Kernels#

Describing Neuronal Dynamics with Functions#

The discrete-time dynamics of most spiking neuron models can be described as:

\[Y_1[t], Y_2[t], \dots, V_1[t], V_2[t], \dots = f_s\left( X_1[t], X_2[t], \dots, V_1[t-1], V_2[t-1], \dots \right)\]

where \(Y_i\) denotes outputs, \(V_i\) denotes state variables, and \(X_i\) denotes inputs. This equation can be described using a PyTorch function:

def single_step_inference(x1, x2, ..., v1, v2, ...):
    ...
    return y1, y2, ..., v1_updated, v2_updated, ...

Here, x1, x2, ... represent inputs, v1, v2, ... represent state variables, y1, y2, ... represent outputs, and v1_updated, v2_updated, ... represent the updated state variables (corresponding to v1, v2, ...). For example, a soft-reset LIF neuron with non-decaying input (tau=2, v_th=1.0, sigmoid surrogate function) can be described as:

from spikingjelly.activation_based import surrogate

tau = 2.0 # time constant
v_th = 1.0 # threshold
spike_fn = surrogate.Sigmoid()

def lif_single_step_inference(x, v):
    h = (1 - 1/tau) * v + x
    s = spike_fn(h - v_th)
    v = h - s * v_th
    return s, v

In this example, there is only one input, one output, and one state variable. For more complex neuron models, however, there may be multiple inputs, outputs, and state variables. In addition, the model hyperparameters tau, v_th, and spike_fn here are fixed global variables. To flexibly configure hyperparameters, function closures can be used:

from spikingjelly.activation_based import surrogate

def lif_single_step_inference_closure(tau=2., v_th=1., spike_fn=surrogate.Sigmoid()):
    def lif_single_step_inference(x, v):
        h = (1 - 1/tau) * v + x
        s = spike_fn(h - v_th)
        v = h - s * v_th
        return s, v
    return lif_single_step_inference

f = lif_single_step_inference_closure(tau=99., v_th=0.5)

FlexSN Workflow#

Take the following customized spiking neuron as an example:

import torch
from spikingjelly.activation_based import surrogate

def complicated_lif_core_generator(beta: float, gamma: float, spike_fn=surrogate.ATan()):
    def complicated_lif_core(
        x: torch.Tensor, y: torch.Tensor, v: torch.Tensor, rho: torch.Tensor
    ):
        h = beta*v + x
        s1 = spike_fn(h - (rho+1.)) # spike, with threshold adaptation
        s2 = spike_fn(h - 1.)       # spike, without threshold adaptation
        rho = gamma*rho + s1        # adaptation variable update
        v1 = h * (1.-s1)            # hard reset
        v2 = h - s2                 # soft reset
        yy = torch.sigmoid(y)       # modulation factor
        v = v1*yy + v2 * (1.-yy)    # modulated reset
        return s1, s2, v, rho
    return complicated_lif_core

This model has two inputs x, y, two outputs s1, s2, and two state variables v, rho. The dependency relationships among different variables are illustrated in the figure below:

../../_images/neuron.png

To generate a multi-step Triton kernel, use FlexSN:

from spikingjelly.activation_based import neuron

f = neuron.FlexSN(
    core=complicated_lif_core_generator(beta=0.5, gamma=0.9),
    num_inputs=2,
    num_states=2,
    num_outputs=2,
    example_inputs=(
        torch.zeros([1], device="cuda"), torch.zeros([1], device="cuda"),
        torch.zeros([1], device="cuda"), torch.zeros([1], device="cuda"),
    ),
    requires_grad=(True, True, True, True),
    step_mode="m",
    backend="inductor",
    store_state_seqs=True,
)

x = torch.randn([16, 3, 32, 32], device="cuda")
y = torch.randn([16, 3, 32, 32], device="cuda")
s1, s2 = f(x, y)
v, rho = f.state_seqs

print(s1.mean()) # tensor(0.0821, device='cuda:0', grad_fn=<MeanBackward0>)
print(s2.mean()) # tensor(0.1494, device='cuda:0', grad_fn=<MeanBackward0>)
print(v.mean()) # tensor(-0.2750, device='cuda:0', grad_fn=<MeanBackward0>)
print(rho.mean()) # tensor(0.4842, device='cuda:0', grad_fn=<MeanBackward0>)

The construction of FlexSN requires the following arguments:

  • core : a function that describes the single-step neuron dynamics, with the signature [*inputs, *states] -> [*outputs, *states].

  • num_inputs, num_states, num_outputs : the numbers of inputs, state variables, and outputs, which should be consistent with the signature of core.

  • example_inputs : example arguments for core. FlexSN will call core with these example inputs in order to capture the computation graph.

  • example_outputs : optional, per-step output templates for core. They are mainly used to determine output shapes and dtypes when the input sequence is empty (T == 0). On the "triton" / "inductor" path, if this argument is provided, each template tensor should match the first example_inputs tensor's per-step shape and dtype.

  • requires_grad : whether the arguments of core require gradients. The default value is None, which means that all arguments require gradients (i.e., equivalent to all True).

  • step_mode, backend : similar to other neuron modules, these two arguments determine the step mode and the backend. The "torch" backend is always available. The "triton", "inductor", and "hop" backends are only valid when step_mode="m". In FlexSN, "triton" and "inductor" are equivalent labels for the same Triton path, while "hop" uses the HOP/eager-scan path.

  • store_state_seqs : similar to store_v_seq in other neuron modules, this argument determines whether state sequences are stored. If True, the state sequences from the last run can be accessed via the state_seqs attribute. This attribute is a list, where each element corresponds to the sequence of a specific state variable.

FlexSN also supports backward propagation, as shown in the following code block:

n_inductor = neuron.FlexSN(
    core=complicated_lif_core_generator(beta=0.5, gamma=0.9),
    num_inputs=2,
    num_states=2,
    num_outputs=2,
    example_inputs=(
        torch.zeros([1], device="cuda"), torch.zeros([1], device="cuda"),
        torch.zeros([1], device="cuda"), torch.zeros([1], device="cuda"),
    ),
    requires_grad=(True, True, True, True),
    step_mode="m",
    backend="inductor",
    store_state_seqs=True,
)

n_torch = neuron.FlexSN(
    core=complicated_lif_core_generator(beta=0.5, gamma=0.9),
    num_inputs=2,
    num_states=2,
    num_outputs=2,
    example_inputs=(
        torch.zeros([1], device="cuda"), torch.zeros([1], device="cuda"),
        torch.zeros([1], device="cuda"), torch.zeros([1], device="cuda"),
    ),
    requires_grad=(True, True, True, True),
    step_mode="m",
    backend="torch",
    store_state_seqs=True,
)

x = torch.randn([16, 3, 32, 32], device="cuda")
y = torch.randn([16, 3, 32, 32], device="cuda")
x_inductor = x.clone().requires_grad_(True)
y_inductor = y.clone().requires_grad_(True)
x_torch = x.clone().requires_grad_(True)
y_torch = y.clone().requires_grad_(True)

s1_inductor, s2_inductor = n_inductor(x_inductor, y_inductor)
s1_torch, s2_torch = n_torch(x_torch, y_torch)
grad = torch.randn_like(s1_inductor)
s1_inductor.backward(grad)
s1_torch.backward(grad)

v_inductor, rho_inductor = n_inductor.state_seqs
v_torch, rho_torch = n_torch.state_seqs

assert torch.allclose(s1_inductor, s1_torch)
assert torch.allclose(s2_inductor, s2_torch)
assert torch.allclose(x_inductor.grad, x_torch.grad, atol=1e-6, rtol=1e-6)
assert torch.allclose(y_inductor.grad, y_torch.grad, atol=1e-6, rtol=1e-6)
assert torch.allclose(v_inductor, v_torch, atol=1e-6, rtol=1e-6)
assert torch.allclose(rho_inductor, rho_torch)
print(s1_inductor.mean())
print(s2_inductor.mean())
print(x_inductor.grad.mean())
print(y_inductor.grad.mean())
print(v_inductor.mean())
print(rho_inductor.mean())

All assert statements pass, and the outputs are shown below. This demonstrates that the Triton scan kernels used by FlexSN are equivalent to the original PyTorch function in both forward and backward propagation.

tensor(0.0821, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.1494, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0007, device='cuda:0')
tensor(6.2995e-05, device='cuda:0')
tensor(-0.2750, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.4842, device='cuda:0', grad_fn=<MeanBackward0>)

With the workflow described above, users can obtain Triton-accelerated neuron models with very little code. Compared with the former auto_cuda module (see Implement CUPY Neuron), FlexSN is more flexible and general.

Note

In the example above, the state variables v and rho use the default zero initialization. Users can override the init_states() method to change the state initialization rule. The original definition of this method is shown below, where *args represents the arguments of forward():

class FlexSN(base.MemoryModule):

    ...

    @staticmethod
    def init_states(num_states: int, step_mode: str, *args) -> List[torch.Tensor]:
        if step_mode == "s":
            return [torch.zeros_like(args[0]) for _ in range(num_states)]
        elif step_mode == "m":
            return [torch.zeros_like(args[0][0]) for _ in range(num_states)]
        else:
            raise ValueError(f"Unsupported step mode: {step_mode}")

See FlexSN.init_states for details.

Note

FlexSN implements most features of SpikingJelly's neuron modules. To call the generated kernels directly in a lightweight, transparent, and functional style, use FlexSNKernel.

Warning

When using FlexSN, please note the following:

  • The "torch" backend can run on CPU or GPU. The "triton" and "inductor" backends require a GPU, and the "triton", "inductor", and "hop" backends only support multi-step mode step_mode="m".

  • The PyTorch backend is implemented by repeatedly calling core.

  • In the design of FlexSN, compromises are made in efficiency in order to pursue generality. At present, IFNode, LIFNode, and PLIFNode are equipped with highly optimized predefined Triton kernels. Please use these predefined kernels whenever possible to obtain higher performance.

  • After completing a simulation with FlexSN, reset() must be called to reset the neuron states.

Compatibility with torch.compile#

FlexSN exposes two equivalent backend labels for the same Triton path: backend="triton" and backend="inductor". The latter is the custom-op-wrapped entry to the same maintained Triton execution path. In practice, choose whichever label is clearer in your codebase; behavior and kernel generation are aligned.

Key properties:

  • Integrates with torch.compile, enabling cross-layer fusion with surrounding modules.

  • Ships dedicated Triton kernels for both inference and training.

  • Provides specialized final-state fast paths and HOP/eager fallbacks when dedicated kernels are unavailable.

Important

  • This section is about the Triton path. The "triton" and "inductor" backends require CUDA.

  • Ops inside core must be in the FX_TO_TRITON table. Unsupported ops fall back to eager_scan with a WARNING log. See Op Coverage below for the full list.

  • For training, core should use a surrogate gradient (e.g. Sigmoid) instead of a hard threshold; hard thresholds yield zero gradients by design.

Minimal Examples#

Inference:

import torch
import torch.nn as nn
from spikingjelly.activation_based.neuron.flexsn import FlexSN

def lif_core(x: torch.Tensor, v: torch.Tensor):
    tau, v_th = 2.0, 1.0
    h = v + (x - v) / tau
    s = (h >= v_th).to(h.dtype)
    return s, h * (1.0 - s)

neuron = FlexSN(core=lif_core, num_inputs=1, num_states=1,
                num_outputs=1, step_mode="m", backend="inductor").cuda()

x = torch.randn(8, 64, 512, device="cuda")
with torch.no_grad():
    out = neuron(x)   # no torch.compile required

# Optional: wrap with torch.compile for cross-layer fusion
model = nn.Sequential(nn.Linear(512, 512), neuron, nn.Linear(512, 512)).cuda()
model = torch.compile(model, fullgraph=True)
out = model(x)

Training:

import torch
from spikingjelly.activation_based import surrogate
from spikingjelly.activation_based.neuron.flexsn import FlexSN

sg = surrogate.Sigmoid(alpha=4.0)

def lif_core_sg(x: torch.Tensor, v: torch.Tensor):
    tau, v_th = 2.0, 1.0
    h = v + (x - v) / tau
    s = sg(h - v_th)        # Sigmoid surrogate gradient
    return s, h * (1.0 - s)

neuron = FlexSN(core=lif_core_sg, num_inputs=1, num_states=1,
                num_outputs=1, step_mode="m", backend="inductor").cuda()

x = torch.randn(8, 64, 512, device="cuda", requires_grad=True)
out = neuron(x)
out.sum().backward()        # BPTT via dedicated Triton fwd+bwd kernels
print(x.grad.shape)         # [8, 64, 512]

torch.compile is optional in both examples. Without it, the "triton" and "inductor" backends still use the same dedicated Triton scan kernels. With it, FlexSN stays in the compiled graph through the custom-op path, which is the mode to use when benchmarking cross-layer fusion with surrounding Linear / Conv modules.

Supported Backends#

Backend

Device

Execution path

Typical use / notes

"torch"

CPU / CUDA

Pure PyTorch multi-step loop

Reference implementation, debugging, CPU prototyping

"triton" / "inductor"

CUDA

Same Triton path

Primary high-performance path. inductor is custom-op-wrapped entry to same Triton kernels

"hop"

CPU / CUDA

HOP / eager scan path

Experimentation, scan tracing, fallback path

Runtime Behavior#

Under backend="triton" or backend="inductor":

  • Inference traces core with make_fx and emits one Triton scan kernel with tl.static_range(T). Each inference call launches exactly one kernel, independent of T.

  • Training traces both forward and backward and emits dedicated Triton forward/backward scan kernels. Without torch.compile, this is already the full Triton path.

  • With torch.compile, FlexSN still dispatches the same Triton kernels through opaque custom ops, while surrounding layers can be jointly compiled.

  • If core uses unsupported ops, or if Triton kernels cannot be built, FlexSN falls back to the HOP/eager-scan path.

Practical recommendation:

  • Use the "torch" backend for CPU work, debugging, or when you want simplest semantics.

  • Use the "triton" or "inductor" backend for actual high-performance GPU execution.

  • Add torch.compile only when you want cross-layer fusion with surrounding modules.

Op Coverage#

The FX_TO_TRITON table currently covers the following ATen ops (supported for both inference and training):

Category

Ops

Arithmetic

add, sub, mul, div, reciprocal, neg, rsub

Transcendentals

exp, log, log2, sqrt, rsqrt, tanh, sin, cos, erf

Rounding

floor, ceil, round

Activation / threshold

relu, sigmoid, sign / sgn, abs

Comparisons

eq, ne, ge, le, gt, lt

Logic / bitwise

logical_and / or / not, bitwise_and / or / not

Binary math

minimum, maximum, pow, fmod

Clamp

clamp, clamp_min, clamp_max

Type / construction

_to_copy (type cast), scalar_tensor, zeros_like, ones_like

Selection

where, masked_fill

Backward-only

sigmoid_backward, tanh_backward, threshold_backward

Misc

clone, detach, spike_fn

Ops not in this table (e.g. matrix ops, complex control flow) trigger eager_scan fallback with a WARNING log.