Source code for SpikingFlow.softbp.accelerating

import torch
import torch.nn as nn
import torch.nn.functional as F


class multiply_spike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, spike: torch.Tensor):
        # y = x * spike
        # Multiplying x by spike is equivalent to filling the positions of x where spike == 0 with 0
        assert x.shape == spike.shape, 'x.shape != spike.shape'  # disable broadcasting
        mask = torch.logical_not(spike.bool())
        if x.requires_grad and spike.requires_grad:
            ctx.save_for_backward(mask, x)
        elif x.requires_grad and not spike.requires_grad:
            ctx.save_for_backward(mask)
        elif not x.requires_grad and spike.requires_grad:
            ctx.save_for_backward(x)
        return x.masked_fill(mask, 0)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        grad_x = None
        grad_spike = None
        # grad_x = grad_output * spike
        # grad_spike = grad_output * x
        if ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
            grad_x = grad_output.masked_fill(ctx.saved_tensors[0], 0)
            grad_spike = grad_output * ctx.saved_tensors[1]
        elif ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]:
            grad_x = grad_output.masked_fill(ctx.saved_tensors[0], 0)
        elif not ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
            grad_spike = grad_output * ctx.saved_tensors[0]
        return grad_x, grad_spike
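# Illustrative check (not part of the original module): multiply_spike should agree with
# a plain elementwise product in both the output and the gradients. The helper name
# _check_multiply_spike is hypothetical and only serves as a usage sketch.
def _check_multiply_spike():
    x = torch.rand(4, requires_grad=True)
    spike = (torch.rand(4) > 0.5).float()
    spike.requires_grad_(True)
    grad = torch.rand(4)

    y1 = multiply_spike.apply(x, spike)
    y2 = x * spike
    assert torch.allclose(y1, y2)

    gx1, gs1 = torch.autograd.grad(y1, (x, spike), grad)
    gx2, gs2 = torch.autograd.grad(y2, (x, spike), grad)
    assert torch.allclose(gx1, gx2) and torch.allclose(gs1, gs2)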
class add_spike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, spike: torch.Tensor):
        # y = x + spike
        # Adding spike to x is equivalent to adding 1 at the positions of x where spike == 1
        assert x.shape == spike.shape, 'x.shape != spike.shape'  # disable broadcasting
        mask = spike.bool()
        y = x.clone()
        y[mask] += 1
        return y

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        grad_x = None
        grad_spike = None
        if ctx.needs_input_grad[0]:
            grad_x = grad_output
        if ctx.needs_input_grad[1]:
            grad_spike = grad_output
        return grad_x, grad_spike
class subtract_spike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, spike: torch.Tensor):
        # y = x - spike
        # Subtracting spike from x is equivalent to subtracting 1 at the positions of x where spike == 1
        assert x.shape == spike.shape, 'x.shape != spike.shape'  # disable broadcasting
        mask = spike.bool()
        y = x.clone()
        y[mask] -= 1
        return y

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        grad_x = None
        grad_spike = None
        if ctx.needs_input_grad[0]:
            grad_x = grad_output
        if ctx.needs_input_grad[1]:
            grad_spike = - grad_output
        return grad_x, grad_spike
def add(x: torch.Tensor, spike: torch.Tensor):
    '''
    :param x: any tensor
    :param spike: a spike tensor. Its elements must be 0 or 1, and spike.shape must be the same as x.shape
    :return: x + spike

    Addition specialized for spike data, which accelerates forward and backward propagation and keeps numerical stability.
    '''
    return add_spike.apply(x, spike)
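# Usage sketch (illustration only, not part of the original module): accumulating input
# spikes onto a membrane potential with the accelerated addition. The helper name
# _example_add is hypothetical.
def _example_add():
    v = torch.zeros(8, requires_grad=True)
    spike = (torch.rand(8) > 0.5).float()
    v_next = add(v, spike)  # same values as v + spike, but computed with a boolean mask
    v_next.sum().backward()
    # the backward of add_spike passes grad_output through unchanged, so d(v_next)/dv == 1
    assert torch.allclose(v.grad, torch.ones_like(v))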
def sub(x: torch.Tensor, spike: torch.Tensor):
    '''
    :param x: any tensor
    :param spike: a spike tensor. Its elements must be 0 or 1, and spike.shape must be the same as x.shape
    :return: x - spike

    Subtraction specialized for spike data, which accelerates forward and backward propagation and keeps numerical stability.
    '''
    return subtract_spike.apply(x, spike)
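# Usage sketch (illustration only, not part of the original module): the gradient of
# sub w.r.t. spike is the negated upstream gradient. The helper name _example_sub is hypothetical.
def _example_sub():
    x = torch.rand(8, requires_grad=True)
    spike = (torch.rand(8) > 0.5).float()
    spike.requires_grad_(True)
    y = sub(x, spike)  # same values as x - spike
    gx, gs = torch.autograd.grad(y.sum(), (x, spike))
    assert torch.allclose(gx, torch.ones_like(x))
    assert torch.allclose(gs, -torch.ones_like(spike))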
def mul(x: torch.Tensor, spike: torch.Tensor):
    '''
    :param x: any tensor
    :param spike: a spike tensor. Its elements must be 0 or 1, and spike.shape must be the same as x.shape
    :return: x * spike

    Multiplication specialized for spike data, which accelerates forward and backward propagation and keeps numerical stability.
    '''
    return multiply_spike.apply(x, spike)
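# Usage sketch (illustration only, not part of the original module): gating a current
# tensor by binary spikes. Internally the product becomes a masked_fill, and the gradient
# w.r.t. the current is zero wherever spike == 0. The helper name _example_mul is hypothetical.
def _example_mul():
    current = torch.rand(2, 4, requires_grad=True)
    spike = (torch.rand(2, 4) > 0.5).float()
    gated = mul(current, spike)  # same values as current * spike
    gated.sum().backward()
    assert torch.allclose(current.grad, spike)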
class soft_vlotage_transform_function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
        # v = v - spike * v_threshold
        mask = spike.bool()  # positions where a spike was fired
        if spike.requires_grad:
            ctx.v_threshold = v_threshold
        ret = v.clone()
        ret[mask] -= v_threshold
        return ret  # at positions that fired, v_threshold is subtracted from the voltage, out-of-place

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        grad_v = None
        grad_spike = None
        if ctx.needs_input_grad[0]:
            grad_v = grad_output  # the gradient of the output w.r.t. v is all ones
        if ctx.needs_input_grad[1]:
            grad_spike = - ctx.v_threshold * grad_output
        return grad_v, grad_spike, None
def soft_vlotage_transform(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
    '''
    :param v: voltage before reset
    :param spike: the fired spikes
    :param v_threshold: threshold voltage
    :return: voltage after reset

    Reset the voltage in the soft way according to the fired spikes, i.e., after a spike is fired, the threshold is
    subtracted from the voltage: :math:`v = v - s \\cdot v_{threshold}`.

    This function accelerates forward and backward propagation for spike data, saves memory, and keeps numerical stability.
    '''
    return soft_vlotage_transform_function.apply(v, spike, v_threshold)
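# Usage sketch (illustration only, not part of the original module): soft reset compared
# with the plain formula v - spike * v_threshold, including the gradients of both inputs.
# The helper name _example_soft_reset is hypothetical.
def _example_soft_reset():
    v = torch.rand(8, requires_grad=True)
    spike = (torch.rand(8) > 0.5).float()
    spike.requires_grad_(True)
    v_threshold = 1.0
    v_next = soft_vlotage_transform(v, spike, v_threshold)
    assert torch.allclose(v_next, v - spike * v_threshold)
    gv, gs = torch.autograd.grad(v_next.sum(), (v, spike))
    assert torch.allclose(gv, torch.ones_like(v))                      # d(v_next)/dv == 1
    assert torch.allclose(gs, -v_threshold * torch.ones_like(spike))   # d(v_next)/dspike == -v_threshold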
class hard_voltage_transform_function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v: torch.Tensor, spike: torch.Tensor, v_reset: float):
        # v = v * (1 - spike) + v_reset * spike
        mask = spike.bool()  # positions where a spike was fired
        if v.requires_grad and spike.requires_grad:
            ctx.save_for_backward(mask, v_reset - v)
        elif v.requires_grad and not spike.requires_grad:
            ctx.save_for_backward(mask)
        elif not v.requires_grad and spike.requires_grad:
            ctx.save_for_backward(v_reset - v)
        return v.masked_fill(mask, v_reset)  # at positions that fired, the voltage is set to v_reset, out-of-place

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        grad_v = None
        grad_spike = None
        if ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
            grad_v = grad_output.masked_fill(ctx.saved_tensors[0], 0)
            grad_spike = grad_output * ctx.saved_tensors[1]
        elif ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]:
            grad_v = grad_output.masked_fill(ctx.saved_tensors[0], 0)
        elif not ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
            grad_spike = grad_output * ctx.saved_tensors[0]
        return grad_v, grad_spike, None
def hard_voltage_transform(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
    '''
    :param v: voltage before reset
    :param spike: the fired spikes
    :param v_reset: reset voltage
    :return: voltage after reset

    Reset the voltage in the hard way according to the fired spikes, i.e., after a spike is fired, the voltage is
    directly set to the reset voltage: :math:`v = v \\cdot (1-s) + v_{reset} \\cdot s`.

    This function accelerates forward and backward propagation for spike data, saves memory, and keeps numerical stability.
    '''
    return hard_voltage_transform_function.apply(v, spike, v_reset)
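# Usage sketch (illustration only, not part of the original module): hard reset compared
# with the plain formula v * (1 - spike) + v_reset * spike. The gradient w.r.t. v is masked
# out at positions that fired, and the gradient w.r.t. spike is (v_reset - v).
# The helper name _example_hard_reset is hypothetical.
def _example_hard_reset():
    v = torch.rand(8, requires_grad=True)
    spike = (torch.rand(8) > 0.5).float()
    spike.requires_grad_(True)
    v_reset = 0.0
    v_next = hard_voltage_transform(v, spike, v_reset)
    assert torch.allclose(v_next, v * (1. - spike) + v_reset * spike)
    gv, gs = torch.autograd.grad(v_next.sum(), (v, spike))
    assert torch.allclose(gv, 1. - spike)
    assert torch.allclose(gs, v_reset - v)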