梯度替代
本教程作者: fangwei123456
在 神经元 中我们已经提到过,描述神经元放电过程的 \(S[t] = \Theta(H[t] - V_{threshold})\),使用了一个Heaviside阶跃函数:
按照定义,其导数为冲激函数:
直接使用冲激函数进行梯度下降,显然会使得网络的训练及其不稳定。为了解决这一问题,各种梯度替代法(the surrogate gradient method)被相继提出,参见此综述 Surrogate Gradient Learning in Spiking Neural Networks。
替代函数在神经元中被用于生成脉冲,查看 BaseNode.neuronal_fire
的源代码可以发现:
# spikingjelly.activation_based.neuron
class BaseNode(base.MemoryModule):
def __init__(..., surrogate_function: Callable = surrogate.Sigmoid(), ...)
# ...
self.surrogate_function = surrogate_function
# ...
def neuronal_fire(self):
return self.surrogate_function(self.v - self.v_threshold)
梯度替代法的原理是,在前向传播时使用 \(y = \Theta(x)\),而在反向传播时则使用 \(\frac{\mathrm{d}y}{\mathrm{d}x} = \sigma'(x)\),而非\(\frac{\mathrm{d}y}{\mathrm{d}x} = \Theta'(x)\),其中 \(\sigma(x)\) 即为替代函数。\(\sigma(x)\) 通常是一个形状与 \(\Theta(x)\) 类似,但光滑连续的函数。
在 spikingjelly.activation_based.surrogate
中提供了一些常用的替代函数,其中Sigmoid函数 \(\sigma(x, \alpha) = \frac{1}{1 + \exp(-\alpha x)}\) 为 spikingjelly.activation_based.surrogate.Sigmoid
,下图展示了原始的Heaviside阶跃函数 Heaviside
、 alpha=5
时的Sigmoid原函数 Primitive
以及其梯度 Gradient
:
替代函数的使用比较简单,使用替代函数就像是使用函数一样:
import torch
from spikingjelly.activation_based import surrogate
sg = surrogate.Sigmoid(alpha=4.)
x = torch.rand([8]) - 0.5
x.requires_grad = True
y = sg(x)
y.sum().backward()
print(f'x={x}')
print(f'y={y}')
print(f'x.grad={x.grad}')
输出为:
x=tensor([-0.1303, 0.4976, 0.3364, 0.4296, 0.2779, 0.4580, 0.4447, 0.2466],
requires_grad=True)
y=tensor([0., 1., 1., 1., 1., 1., 1., 1.], grad_fn=<sigmoidBackward>)
x.grad=tensor([0.9351, 0.4231, 0.6557, 0.5158, 0.7451, 0.4759, 0.4943, 0.7913])
每个替代函数,除了有形如 spikingjelly.activation_based.surrogate.Sigmoid
的模块风格API,也提供了形如 spikingjelly.activation_based.surrogate.sigmoid
函数风格的API。模块风格的API使用驼峰命名法,而函数风格的API使用下划线命名法,关系类似于 torch.nn
和 torch.nn.functional
,下面是几个示例:
模块 |
函数 |
---|---|
|
|
|
|
|
|
下面是函数风格API的用法示例:
import torch
from spikingjelly.activation_based import surrogate
alpha = 4.
x = torch.rand([8]) - 0.5
x.requires_grad = True
y = surrogate.sigmoid.apply(x, alpha)
y.sum().backward()
print(f'x={x}')
print(f'y={y}')
print(f'x.grad={x.grad}')
替代函数通常会有1个或多个控制形状的超参数,例如 spikingjelly.activation_based.surrogate.Sigmoid
中的 alpha
。SpikingJelly中替代函数的形状参数,默认情况下是使得替代函数梯度最大值为1,这在一定程度上可以避免梯度累乘导致的梯度爆炸问题。