Monitor

Author: fangwei123456

spikingjelly.activation_based.monitor defines some commonly used monitors, with which users can record the data they are interested in. Now let us try these monitors.

Usage

All monitors have a similar usage. Let us take spikingjelly.activation_based.monitor.OutputMonitor as an example.

Firstly, let us build a simple multi-step network. To avoid the case where no spikes are fired at all, we set all weights to be positive:
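We stack two Linear layers, each followed by an IFNode, and switch the network to the multi-step mode (taking the absolute value of every parameter is one simple way to make the weights positive):

import torch
import torch.nn as nn
from spikingjelly.activation_based import monitor, neuron, functional, layer

net = nn.Sequential(
    layer.Linear(8, 4),
    neuron.IFNode(),
    layer.Linear(4, 2),
    neuron.IFNode()
)

# make all weights (and biases) positive so that the neurons can fire
for param in net.parameters():
    param.data.abs_()

# use the multi-step mode, matching step_mode=m in the printed network below
functional.set_step_mode(net, 'm')

After the network is built, we create an OutputMonitor on all neuron.IFNode layers to record their output spikes, and run the network once: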

spike_seq_monitor = monitor.OutputMonitor(net, neuron.IFNode)
T = 4
N = 1
x_seq = torch.rand([T, N, 8])

with torch.no_grad():
    net(x_seq)

The recorded data are stored in .records, which is a list. The data are appended in the order in which they are produced during the forward pass:

print(f'spike_seq_monitor.records=\n{spike_seq_monitor.records}')

The outputs are:

spike_seq_monitor.records=
[tensor([[[0., 0., 0., 0.]],

        [[1., 1., 1., 1.]],

        [[0., 0., 0., 0.]],

        [[1., 1., 1., 1.]]]), tensor([[[0., 0.]],

        [[1., 0.]],

        [[0., 1.]],

        [[1., 0.]]])]

We can also use an integer index to get the i-th recorded data:

print(f'spike_seq_monitor[0]={spike_seq_monitor[0]}')

The outputs are:

spike_seq_monitor[0]=tensor([[[0., 0., 0., 0.]],

        [[1., 1., 1., 1.]],

        [[0., 0., 0., 0.]],

        [[1., 1., 1., 1.]]])

The names of monitored layers are stored in .monitored_layers:

print(f'net={net}')
print(f'spike_seq_monitor.monitored_layers={spike_seq_monitor.monitored_layers}')

The outputs are:

net=Sequential(
  (0): Linear(in_features=8, out_features=4, bias=True)
  (1): IFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
    (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
  )
  (2): Linear(in_features=4, out_features=2, bias=True)
  (3): IFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
    (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
  )
)
spike_seq_monitor.monitored_layers=['1', '3']

We can also index by the layer name to get the recorded data of that layer, which are stored in a list:

print(f"spike_seq_monitor['1']={spike_seq_monitor['1']}")

The outputs are:

spike_seq_monitor['1']=[tensor([[[0., 0., 0., 0.]],

    [[1., 1., 1., 1.]],

    [[0., 0., 0., 0.]],

    [[1., 1., 1., 1.]]])]

We can call .clear_recorded_data() to clear the recorded data:

spike_seq_monitor.clear_recorded_data()
print(f'spike_seq_monitor.records={spike_seq_monitor.records}')
print(f"spike_seq_monitor['1']={spike_seq_monitor['1']}")

The outputs are:

spike_seq_monitor.records=[]
spike_seq_monitor['1']=[]

All monitors remove their hooks when they are deleted. However, Python does not guarantee that the monitor's __del__() function is called, even if we call del a_monitor manually:

del spike_seq_monitor
# the hooks may still be active on the network

Instead, we should call .remove_hooks() to remove all hooks explicitly:

spike_seq_monitor.remove_hooks()
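For example, here is a minimal sketch (not part of the original script; temp_monitor is a hypothetical name) that releases the hooks deterministically in a finally block instead of relying on __del__:

temp_monitor = monitor.OutputMonitor(net, neuron.IFNode)
try:
    with torch.no_grad():
        net(x_seq)
        functional.reset_net(net)
finally:
    # always release the hooks, even if the forward pass raises an exception
    temp_monitor.remove_hooks()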

OutputMonitor can also process the data while recording, via the function_on_output argument. Its default value is lambda x: x, which records the original data unchanged. If we want to record firing rates instead, we can define a function that computes them:

def cal_firing_rate(s_seq: torch.Tensor):
    # s_seq.shape = [T, N, *]
    # return the firing rate at each time-step, averaged over the batch and all neurons; shape = [T]
    return s_seq.flatten(1).mean(1)
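For instance, a small hand-made spike tensor with T=2 time-steps, N=1 and 3 neurons gives one firing rate per time-step:

s = torch.tensor([[[1., 0., 1.]],
                  [[0., 0., 1.]]])  # shape [T=2, N=1, 3]
print(cal_firing_rate(s))  # tensor([0.6667, 0.3333])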

Then, we can set this function as function_on_output to get a firing rates monitor:

fr_monitor = monitor.OutputMonitor(net, neuron.IFNode, cal_firing_rate)

.disable() pauses the monitor, and .enable() resumes it:

with torch.no_grad():
    fr_monitor.disable()
    net(x_seq)
    functional.reset_net(net)
    print(f'after call fr_monitor.disable(), fr_monitor.records=\n{fr_monitor.records}')

    fr_monitor.enable()
    net(x_seq)
    print(f'after call fr_monitor.enable(), fr_monitor.records=\n{fr_monitor.records}')
    functional.reset_net(net)
    del fr_monitor

The outputs are:

after call fr_monitor.disable(), fr_monitor.records=
[]
after call fr_monitor.enable(), fr_monitor.records=
[tensor([0.0000, 1.0000, 0.5000, 1.0000]), tensor([0., 1., 0., 1.])]

Record Attributes

To record the attributes of some modules, e.g., the membrane potential, we can use spikingjelly.activation_based.monitor.AttributeMonitor.
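AttributeMonitor reads a named attribute of every matching module in a forward hook. For example, a minimal sketch that records the attribute v (the membrane potential kept after the forward pass) of every IFNode in the small network above:

v_monitor = monitor.AttributeMonitor('v', pre_forward=False, net=net, instance=neuron.IFNode)
with torch.no_grad():
    net(x_seq)
    # one tensor per IFNode layer, holding the membrane potential after the final time-step
    print(v_monitor.records)
    functional.reset_net(net)
v_monitor.remove_hooks()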

store_v_seq: bool = False is the default argument in __init__ of spiking neurons, which means that only v at the last time-step is stored, while v_seq over all time-steps is not stored. To record all \(V[t]\), we set store_v_seq = True:

for m in net.modules():
    if isinstance(m, neuron.IFNode):
        m.store_v_seq = True

Then, we use spikingjelly.activation_based.monitor.AttributeMonitor to record v_seq:

v_seq_monitor = monitor.AttributeMonitor('v_seq', pre_forward=False, net=net, instance=neuron.IFNode)
with torch.no_grad():
    net(x_seq)
    print(f'v_seq_monitor.records=\n{v_seq_monitor.records}')
    functional.reset_net(net)
    del v_seq_monitor

The outputs are:

v_seq_monitor.records=
[tensor([[[0.8102, 0.8677, 0.8153, 0.9200]],

        [[0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.8129, 0.0000, 0.9263]],

        [[0.0000, 0.0000, 0.0000, 0.0000]]]), tensor([[[0.2480, 0.4848]],

        [[0.0000, 0.0000]],

        [[0.8546, 0.6674]],

        [[0.0000, 0.0000]]])]

Record Inputs

To record inputs, we can use spikingjelly.activation_based.monitor.InputMonitor, which is similar to spikingjelly.activation_based.monitor.OutputMonitor:

input_monitor = monitor.InputMonitor(net, neuron.IFNode)
with torch.no_grad():
    net(x_seq)
    print(f'input_monitor.records=\n{input_monitor.records}')
    functional.reset_net(net)
    del input_monitor

The outputs are:

input_monitor.records=
[tensor([[[1.1710, 0.7936, 0.9325, 0.8227]],

        [[1.4373, 0.7645, 1.2167, 1.3342]],

        [[1.6011, 0.9850, 1.2648, 1.2650]],

        [[0.9322, 0.6143, 0.7481, 0.9770]]]), tensor([[[0.8072, 0.7733]],

        [[1.1186, 1.2176]],

        [[1.0576, 1.0153]],

        [[0.4966, 0.6030]]])]

Record the Output Gradients \(\frac{\partial L}{\partial Y}\)

We can use spikingjelly.activation_based.monitor.GradOutputMonitor to record the output gradients \(\frac{\partial L}{\partial S}\) of each module, i.e., the gradients of the loss with respect to the output spikes \(S\):

spike_seq_grad_monitor = monitor.GradOutputMonitor(net, neuron.IFNode)
net(x_seq).sum().backward()
print(f'spike_seq_grad_monitor.records=\n{spike_seq_grad_monitor.records}')
functional.reset_net(net)
del spike_seq_grad_monitor

The outputs are:

spike_seq_grad_monitor.records=
[tensor([[[1., 1.]],

        [[1., 1.]],

        [[1., 1.]],

        [[1., 1.]]]), tensor([[[ 0.0803,  0.0383,  0.1035,  0.1177]],

        [[-0.1013, -0.1346, -0.0561, -0.0085]],

        [[ 0.5364,  0.6285,  0.3696,  0.1818]],

        [[ 0.3704,  0.4747,  0.2201,  0.0596]]])]

Note that the gradients with respect to the last layer's output spikes are all 1 because we use .sum().backward().
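As a quick sanity check (not part of the original script), the gradient of a sum with respect to its argument is a tensor of ones:

y = torch.rand([4, 1, 2], requires_grad=True)
y.sum().backward()
print(y.grad)  # a tensor of ones with the same shape as y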

Record the Input Gradients \(\frac{\partial L}{\partial X}\)

We can use spikingjelly.activation_based.monitor.GradInputMonitor to record the input gradients \(\frac{\partial L}{\partial X}\) of each module, i.e., the gradients of the loss with respect to the module inputs \(X\).

Let us build a deep SNN, tune alpha of the surrogate functions, and compare the effect on the gradients:

import torch
import torch.nn as nn
from spikingjelly.activation_based import monitor, neuron, functional, layer, surrogate

net = []
for i in range(10):
    net.append(layer.Linear(8, 8))
    net.append(neuron.IFNode())

net = nn.Sequential(*net)

functional.set_step_mode(net, 'm')

T = 4
N = 1
x_seq = torch.rand([T, N, 8])

input_grad_monitor = monitor.GradInputMonitor(net, neuron.IFNode, function_on_grad_input=torch.norm)

for alpha in [0.1, 0.5, 2, 4, 8]:
    for m in net.modules():
        if isinstance(m, surrogate.Sigmoid):
            m.alpha = alpha
    net(x_seq).sum().backward()
    print(f'alpha={alpha}, input_grad_monitor.records=\n{input_grad_monitor.records}\n')
    functional.reset_net(net)
    # zero the accumulated gradients before trying the next alpha
    for param in net.parameters():
        param.grad.zero_()

    input_grad_monitor.records.clear()

The outputs are:

alpha=0.1, input_grad_monitor.records=
[tensor(0.3868), tensor(0.0138), tensor(0.0003), tensor(9.1888e-06), tensor(1.0164e-07), tensor(1.9384e-09), tensor(4.0199e-11), tensor(8.6942e-13), tensor(1.3389e-14), tensor(2.7714e-16)]

alpha=0.5, input_grad_monitor.records=
[tensor(1.7575), tensor(0.2979), tensor(0.0344), tensor(0.0045), tensor(0.0002), tensor(1.5708e-05), tensor(1.6167e-06), tensor(1.6107e-07), tensor(1.1618e-08), tensor(1.1097e-09)]

alpha=2, input_grad_monitor.records=
[tensor(3.3033), tensor(1.2917), tensor(0.4673), tensor(0.1134), tensor(0.0238), tensor(0.0040), tensor(0.0008), tensor(0.0001), tensor(2.5466e-05), tensor(3.9537e-06)]

alpha=4, input_grad_monitor.records=
[tensor(3.5353), tensor(1.6377), tensor(0.7076), tensor(0.2143), tensor(0.0369), tensor(0.0069), tensor(0.0026), tensor(0.0006), tensor(0.0003), tensor(8.5736e-05)]

alpha=8, input_grad_monitor.records=
[tensor(4.3944), tensor(2.4396), tensor(0.8996), tensor(0.4376), tensor(0.0640), tensor(0.0122), tensor(0.0053), tensor(0.0016), tensor(0.0013), tensor(0.0005)]