梯度替代

本教程作者: fangwei123456

神经元 中我们已经提到过,描述神经元放电过程的 \(S[t] = \Theta(H[t] - V_{threshold})\),使用了一个Heaviside阶跃函数:

\[\begin{split}\Theta(x) = \begin{cases} 1, & x \geq 0 \\ 0, & x < 0 \end{cases}\end{split}\]

按照定义,其导数为冲激函数:

\[\begin{split}\delta(x) = \begin{cases} +\infty, & x = 0 \\ 0, & x \neq 0 \end{cases}\end{split}\]

直接使用冲激函数进行梯度下降,显然会使得网络的训练及其不稳定。为了解决这一问题,各种梯度替代法(the surrogate gradient method)被相继提出,参见此综述 Surrogate Gradient Learning in Spiking Neural Networks

替代函数在神经元中被用于生成脉冲,查看 BaseNode.neuronal_fire 的源代码可以发现:

# spikingjelly.activation_based.neuron
class BaseNode(base.MemoryModule):
    def __init__(..., surrogate_function: Callable = surrogate.Sigmoid(), ...)
    # ...
    self.surrogate_function = surrogate_function
    # ...


    def neuronal_fire(self):
        return self.surrogate_function(self.v - self.v_threshold)

梯度替代法的原理是,在前向传播时使用 \(y = \Theta(x)\),而在反向传播时则使用 \(\frac{\mathrm{d}y}{\mathrm{d}x} = \sigma'(x)\),而非\(\frac{\mathrm{d}y}{\mathrm{d}x} = \Theta'(x)\),其中 \(\sigma(x)\) 即为替代函数。\(\sigma(x)\) 通常是一个形状与 \(\Theta(x)\) 类似,但光滑连续的函数。

spikingjelly.activation_based.surrogate 中提供了一些常用的替代函数,其中Sigmoid函数 \(\sigma(x, \alpha) = \frac{1}{1 + \exp(-\alpha x)}\)spikingjelly.activation_based.surrogate.Sigmoid,下图展示了原始的Heaviside阶跃函数 Heavisidealpha=5 时的Sigmoid原函数 Primitive 以及其梯度 Gradient

../_images/Sigmoid.svg

替代函数的使用比较简单,使用替代函数就像是使用函数一样:

import torch
from spikingjelly.activation_based import surrogate

sg = surrogate.Sigmoid(alpha=4.)

x = torch.rand([8]) - 0.5
x.requires_grad = True
y = sg(x)
y.sum().backward()
print(f'x={x}')
print(f'y={y}')
print(f'x.grad={x.grad}')

输出为:

x=tensor([-0.1303,  0.4976,  0.3364,  0.4296,  0.2779,  0.4580,  0.4447,  0.2466],
   requires_grad=True)
y=tensor([0., 1., 1., 1., 1., 1., 1., 1.], grad_fn=<sigmoidBackward>)
x.grad=tensor([0.9351, 0.4231, 0.6557, 0.5158, 0.7451, 0.4759, 0.4943, 0.7913])

每个替代函数,除了有形如 spikingjelly.activation_based.surrogate.Sigmoid 的模块风格API,也提供了形如 spikingjelly.activation_based.surrogate.sigmoid 函数风格的API。模块风格的API使用驼峰命名法,而函数风格的API使用下划线命名法,关系类似于 torch.nntorch.nn.functional,下面是几个示例:

模块

函数

Sigmoid

sigmoid

SoftSign

soft_sign

LeakyKReLU

leaky_k_relu

下面是函数风格API的用法示例:

import torch
from spikingjelly.activation_based import surrogate

alpha = 4.
x = torch.rand([8]) - 0.5
x.requires_grad = True
y = surrogate.sigmoid.apply(x, alpha)
y.sum().backward()
print(f'x={x}')
print(f'y={y}')
print(f'x.grad={x.grad}')

替代函数通常会有1个或多个控制形状的超参数,例如 spikingjelly.activation_based.surrogate.Sigmoid 中的 alpha。SpikingJelly中替代函数的形状参数,默认情况下是使得替代函数梯度最大值为1,这在一定程度上可以避免梯度累乘导致的梯度爆炸问题。