import logging
try:
import cupy
except BaseException as e:
logging.info(f'spikingjelly.activation_based.neuron_kernel: {e}')
cupy = None
import torch
import torch.nn.functional as F
from . import cuda_utils, surrogate, tensor_cache
from .. import configure
import numpy as np
class MultiStepIFNodePTT(torch.autograd.Function):
    """Autograd function for a multi-step IF (integrate-and-fire) neuron on CUDA.

    ``forward`` launches one CuPy ``RawKernel`` that walks over the whole input
    sequence (FPTT); ``backward`` launches a second kernel that walks the
    sequence in reverse (BPTT). Both kernels are generated as CUDA source text
    specialised for the reset mode and for ``fp32`` / ``fp16``.
    """

    @staticmethod
    def create_fptt_kernel(hard_reset: bool, dtype: str):
        """Build the forward-through-time kernel.

        Per time step the kernel computes ``h = v + x``, emits a spike where
        ``h >= v_threshold`` and writes the next membrane potential one time
        step further into ``v_v_seq``.

        :param hard_reset: if ``True`` a spiking neuron is set to ``v_reset``
            (the kernel then takes an extra ``v_reset`` argument); otherwise
            ``v_threshold`` is subtracted from ``h``
        :param dtype: ``'fp32'`` or ``'fp16'``; the fp16 kernel handles two
            neurons per thread through ``half2``
        :return: the ``cupy.RawKernel`` wrapping the generated source
        :raises TypeError: if ``dtype`` is neither ``'fp32'`` nor ``'fp16'``
        """
        kernel_name = f'IFNode_fptt_{"hard" if hard_reset else "soft"}Reset_{dtype}'

        if dtype == 'fp32':
            v_reset_param = 'const float & v_reset,' if hard_reset else ''
            code = rf'''
            extern "C" __global__
            void {kernel_name}(const float* x_seq, float* v_v_seq, float* h_seq, float* spike_seq,
            const float & v_threshold, {v_reset_param}
            const int & neuron_num, const int & numel)
            '''
            code += r'''
            {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
            if (index < neuron_num)
            {
                const int dt = neuron_num;
                for(int mem_offset = 0; mem_offset < numel; mem_offset += neuron_num)
                {
                    const int t = index + mem_offset;
                    h_seq[t] = v_v_seq[t] + x_seq[t];
                    if (h_seq[t] >= v_threshold)
            '''
            if hard_reset:
                code += r'''
                    {
                        spike_seq[t] = 1.0f;
                        v_v_seq[t + dt] = v_reset;
                    }
                '''
            else:
                code += r'''
                    {
                        spike_seq[t] = 1.0f;
                        v_v_seq[t + dt] = h_seq[t] - v_threshold;
                    }
                '''
            code += r'''
                    else
                    {
                        spike_seq[t] = 0.0f;
                        v_v_seq[t + dt] = h_seq[t];
                    }
                }
            }
            }
            '''

        elif dtype == 'fp16':
            v_reset_param = 'const half & v_reset,' if hard_reset else ''
            code = rf'''
            #include <cuda_fp16.h>
            extern "C" __global__
            void {kernel_name}(const half2* x_seq, half2* v_v_seq, half2* h_seq, half2* spike_seq,
            const half & v_threshold, {v_reset_param}
            const int & neuron_num, const int & numel)
            '''
            code += r'''
            {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
            const int stride = neuron_num >> 1;
            if (index < stride)
            {
                const int numel_2 = numel >> 1;
                const half2 v_threshold_half2 = __half2half2(v_threshold);
            '''
            if hard_reset:
                code += r'''
                const half2 v_reset_half2 = __half2half2(v_reset);
                '''
            code += r'''
                for(int mem_offset = 0; mem_offset < numel_2; mem_offset += stride)
                {
                    const int t = index + mem_offset;
                    h_seq[t] = __hadd2(v_v_seq[t], x_seq[t]);
                    spike_seq[t] = __hgeu2(h_seq[t], v_threshold_half2);
            '''
            # branch-free reset: v = spike * v_after_spike + (1 - spike) * h
            if hard_reset:
                code += r'''
                    v_v_seq[t + stride] = __hadd2(__hmul2(spike_seq[t], v_reset_half2), __hmul2(__hsub2(__float2half2_rn(1.0f), spike_seq[t]), h_seq[t]));
                '''
            else:
                code += r'''
                    v_v_seq[t + stride] = __hadd2(__hmul2(spike_seq[t], __hsub2(h_seq[t], v_threshold_half2)), __hmul2(__hsub2(__float2half2_rn(1.0f), spike_seq[t]), h_seq[t]));
                '''
            code += r'''
                }
            }
            }
            '''
        else:
            raise TypeError(f"dtype must be 'fp32' or 'fp16', but got {dtype!r}")

        return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options,
                              backend=configure.cuda_compiler_backend)

    @staticmethod
    def create_bptt_kernel(sg_cuda_code_fun, hard_reset: bool, detach_reset: bool, dtype: str):
        """Build the backward-through-time kernel.

        The kernel iterates from the last time step to the first, carrying
        ``grad_h`` between steps, and finally stores it in ``grad_v_init``.

        :param sg_cuda_code_fun: callable invoked as
            ``sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)``; the
            CUDA text it returns is spliced into the loop body and is expected
            to declare ``grad_s_to_h`` from ``over_th``
        :param hard_reset: selects the hard-reset or soft-reset gradient of the
            membrane potential with respect to ``h``
        :param detach_reset: if ``True`` the reset path contributes no gradient
            through the spike
        :param dtype: ``'fp32'`` or ``'fp16'``
        :return: the ``cupy.RawKernel`` wrapping the generated source
        :raises TypeError: if ``dtype`` is neither ``'fp32'`` nor ``'fp16'``
        """
        kernel_name = f'IFNode_bptt_{"hard" if hard_reset else "soft"}Reset_{"detachReset" if detach_reset else ""}_{dtype}'
        code_grad_s_to_h = sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)

        if dtype == 'fp32':
            v_reset_param = 'const float & v_reset,' if hard_reset else ''
            code = rf'''
            extern "C" __global__
            void {kernel_name}(
            const float* grad_spike_seq, const float* grad_v_seq, const float* h_seq, const float* spike_seq,
            float* grad_x_seq, float* grad_v_init,
            const float & v_threshold, {v_reset_param}
            const int & neuron_num, const int & numel)
            '''
            code += r'''
            {
                const int index = blockIdx.x * blockDim.x + threadIdx.x;
                if (index < neuron_num)
                {
                    float grad_h = 0.0f;  // grad_h will be used recursively
                    for(int mem_offset = numel - neuron_num; mem_offset >= 0; mem_offset -= neuron_num)
                    {
                        const int t = index + mem_offset;
                        const float over_th = h_seq[t] - v_threshold;
            '''
            code += code_grad_s_to_h

            if detach_reset:
                if hard_reset:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - spike_seq[t];
                    '''
                else:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f;
                    '''
            else:
                if hard_reset:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - spike_seq[t] + (v_reset - h_seq[t]) * grad_s_to_h;
                    '''
                else:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - v_threshold * grad_s_to_h;
                    '''
            code += code_grad_v_to_h

            code += r'''
                        grad_h = grad_spike_seq[t] * grad_s_to_h + (grad_v_seq[t] + grad_h) * grad_v_to_h;
                        grad_x_seq[t] = grad_h;
                    }
                    grad_v_init[index] = grad_h;
                }
            }
            '''

        elif dtype == 'fp16':
            v_reset_param = 'const half & v_reset,' if hard_reset else ''
            code = rf'''
            #include <cuda_fp16.h>
            extern "C" __global__
            void {kernel_name}(
            const half2* grad_spike_seq, const half2* grad_v_seq, const half2* h_seq, const half2* spike_seq,
            half2* grad_x_seq, half2* grad_v_init,
            const half & v_threshold, {v_reset_param}
            const int & neuron_num, const int & numel)
            '''
            code += r'''
            {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
            const int stride = neuron_num >> 1;
            if (index < stride)
            {
                const half2 v_threshold_half2 = __half2half2(v_threshold);
            '''
            if hard_reset:
                code += r'''
                const half2 v_reset_half2 = __half2half2(v_reset);
                '''
            code += r'''
                half2 grad_h = __float2half2_rn(0.0f);  // grad_h will be used recursively
                for(int mem_offset = (numel >> 1) - stride; mem_offset >= 0; mem_offset -= stride)
                {
                    const int t = index + mem_offset;
                    const half2 over_th = __hsub2(h_seq[t], v_threshold_half2);
            '''
            code += code_grad_s_to_h

            if detach_reset:
                if hard_reset:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __hsub2(__float2half2_rn(1.0f), spike_seq[t]);
                    '''
                else:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __float2half2_rn(1.0f);
                    '''
            else:
                if hard_reset:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __hfma2(__hsub2(v_reset_half2, h_seq[t]), grad_s_to_h, __hsub2(__float2half2_rn(1.0f), spike_seq[t]));
                    '''
                else:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __hsub2(__float2half2_rn(1.0f), __hmul2(v_threshold_half2, grad_s_to_h));
                    '''
            code += code_grad_v_to_h

            code += r'''
                    grad_h = __hfma2(__hadd2(grad_v_seq[t], grad_h), grad_v_to_h, __hmul2(grad_spike_seq[t], grad_s_to_h));
                    grad_x_seq[t] = grad_h;
                }
                grad_v_init[index] = grad_h;
            }
            }
            '''
        else:
            raise TypeError(f"dtype must be 'fp32' or 'fp16', but got {dtype!r}")

        return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options,
                              backend=configure.cuda_compiler_backend)

    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, v_threshold: float, v_reset: float,
                detach_reset: bool, sg_cuda_code_fun):
        """Run the IF neuron over the whole sequence with the FPTT kernel.

        :param ctx: autograd context; filled only when an input requires grad
        :param x_seq: input sequence; dimension 0 is iterated as time
        :param v_init: initial membrane potential (one time step of ``x_seq``)
        :param v_threshold: firing threshold
        :param v_reset: reset value, or ``None`` to select soft reset
        :param detach_reset: forwarded to the backward kernel builder
        :param sg_cuda_code_fun: surrogate-gradient CUDA code generator,
            forwarded to the backward kernel builder
        :return: ``(spike_seq, v_seq)`` with the same shape as ``x_seq``
        :raises NotImplementedError: if ``x_seq`` is neither float32 nor float16
        """
        requires_grad = x_seq.requires_grad or v_init.requires_grad
        device = x_seq.get_device()

        if x_seq.dtype == torch.float32:
            dtype = 'fp32'
            cp_dtype = np.float32
        elif x_seq.dtype == torch.float16:
            dtype = 'fp16'
            cp_dtype = np.half
        else:
            raise NotImplementedError(
                f'only torch.float32 and torch.float16 are supported, but got {x_seq.dtype}')

        use_pad = False
        if dtype == 'fp16' and v_init.numel() % 2 != 0:
            # only fp16 needs an even numel because half2 packs two neurons;
            # when numel is odd, pad the last dimension by one element
            use_pad = True
            x_seq = F.pad(x_seq, (0, 1))  # [T, N] -> [T, N + 1]
            v_init = F.pad(v_init, (0, 1))  # [N] -> [N + 1]

        # one allocation, split into three equally sized buffers along time
        zero_shape = list(x_seq.shape)
        zero_shape[0] *= 3
        v_seq, h_seq, spike_seq = torch.split(
            torch.zeros(zero_shape, device=x_seq.device, dtype=x_seq.dtype), x_seq.shape[0])
        # slot 0 holds v_init; the kernel writes v[t] into slot t + 1
        v_v_seq = torch.cat((v_init.unsqueeze(0), v_seq))

        with cuda_utils.DeviceEnvironment(device):
            numel = x_seq.numel()
            neuron_num = numel // x_seq.shape[0]
            threads = configure.cuda_threads
            if dtype == 'fp16':
                assert neuron_num % 2 == 0
                # two neurons are computed as one half2 element
                blocks = cuda_utils.cal_blocks(neuron_num >> 1)
            else:
                blocks = cuda_utils.cal_blocks(neuron_num)

            cp_numel = cupy.asarray(numel)
            cp_neuron_num = cupy.asarray(neuron_num)
            cp_v_threshold = cupy.asarray(v_threshold, dtype=cp_dtype)
            hard_reset = v_reset is not None
            cp_v_reset = cupy.asarray(v_reset, dtype=cp_dtype) if hard_reset else None

            # argument order must match the generated kernel signature
            kernel_args = [x_seq, v_v_seq, h_seq, spike_seq, cp_v_threshold]
            if hard_reset:
                kernel_args.append(cp_v_reset)
            kernel_args += [cp_neuron_num, cp_numel]
            kernel_args = list(cuda_utils.get_contiguous(*kernel_args))

            # keep working with the (possibly copied) contiguous objects
            x_seq, v_v_seq, h_seq, spike_seq, cp_v_threshold = kernel_args[:5]
            if hard_reset:
                cp_v_reset = kernel_args[5]
            cp_neuron_num, cp_numel = kernel_args[-2:]

            kernel = MultiStepIFNodePTT.create_fptt_kernel(hard_reset, dtype)
            kernel(
                (blocks,), (threads,),
                cuda_utils.wrap_args_to_raw_kernel(device, *kernel_args)
            )

        if requires_grad:
            ctx.use_pad = use_pad
            # remember how the spikes were stored so that backward does not
            # depend on the global flag still having the same value
            ctx.spike_as_bool = bool(configure.save_spike_as_bool_in_neuron_kernel)
            if ctx.spike_as_bool:
                ctx.s_shape = spike_seq.shape
                ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike_seq)
                ctx.save_for_backward(h_seq)
            else:
                ctx.save_for_backward(h_seq, spike_seq)
            ctx.blocks = blocks
            ctx.threads = threads
            ctx.cp_numel = cp_numel
            ctx.cp_neuron_num = cp_neuron_num
            ctx.cp_v_threshold = cp_v_threshold
            ctx.cp_v_reset = cp_v_reset
            ctx.detach_reset = detach_reset
            ctx.sg_cuda_code_fun = sg_cuda_code_fun

        if use_pad:
            # drop the padding element again
            return spike_seq[..., :-1], v_v_seq[1:, ..., :-1]
        return spike_seq, v_v_seq[1:, ]

    @staticmethod
    def backward(ctx, grad_spike_seq, grad_v_seq):
        """Back-propagate through time with the BPTT kernel.

        :param ctx: autograd context filled by ``forward``
        :param grad_spike_seq: gradient w.r.t. the returned spike sequence
        :param grad_v_seq: gradient w.r.t. the returned membrane potentials
        :return: gradients for ``x_seq`` and ``v_init`` followed by ``None``
            for the four non-tensor arguments of ``forward``
        :raises NotImplementedError: if the gradients are neither float32 nor
            float16
        """
        if ctx.use_pad:
            # incoming grads are [T, N]; saved h_seq / spike_seq are [T, N + 1]
            grad_spike_seq = F.pad(grad_spike_seq, (0, 1))
            grad_v_seq = F.pad(grad_v_seq, (0, 1))

        device = grad_spike_seq.get_device()

        if ctx.spike_as_bool:
            h_seq = ctx.saved_tensors[0]
            spike_seq = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
        else:
            h_seq, spike_seq = ctx.saved_tensors

        # one allocation: T slices for grad_x_seq plus one for grad_v_init
        zero_shape = list(grad_spike_seq.shape)
        zero_shape[0] += 1
        zero_data = torch.zeros(zero_shape, device=grad_spike_seq.device, dtype=grad_spike_seq.dtype)
        grad_x_seq = zero_data[0: -1]
        grad_v_init = zero_data[-1]

        hard_reset = ctx.cp_v_reset is not None

        if grad_spike_seq.dtype == torch.float32:
            dtype = 'fp32'
        elif grad_spike_seq.dtype == torch.float16:
            dtype = 'fp16'
        else:
            raise NotImplementedError(
                f'only torch.float32 and torch.float16 are supported, but got {grad_spike_seq.dtype}')

        kernel = MultiStepIFNodePTT.create_bptt_kernel(ctx.sg_cuda_code_fun, hard_reset, ctx.detach_reset, dtype)

        with cuda_utils.DeviceEnvironment(device):
            # argument order must match the generated kernel signature
            kernel_args = [grad_spike_seq, grad_v_seq, h_seq, spike_seq, grad_x_seq, grad_v_init,
                           ctx.cp_v_threshold]
            if hard_reset:
                kernel_args.append(ctx.cp_v_reset)
            kernel_args += [ctx.cp_neuron_num, ctx.cp_numel]
            kernel_args = list(cuda_utils.get_contiguous(*kernel_args))
            # the kernel writes into these two objects, so return exactly them
            grad_x_seq, grad_v_init = kernel_args[4:6]

            kernel(
                (ctx.blocks,), (ctx.threads,),
                cuda_utils.wrap_args_to_raw_kernel(device, *kernel_args)
            )

        if ctx.use_pad:
            return grad_x_seq[..., :-1], grad_v_init[..., :-1], None, None, None, None
        return grad_x_seq, grad_v_init, None, None, None, None
class MultiStepLIFNodePTT(torch.autograd.Function):
    """Autograd function for a multi-step LIF (leaky integrate-and-fire) neuron on CUDA.

    ``forward`` launches one CuPy ``RawKernel`` that walks over the whole input
    sequence (FPTT); ``backward`` launches a second kernel that walks the
    sequence in reverse (BPTT). Both kernels are generated as CUDA source text
    specialised for input decay, reset mode and ``fp32`` / ``fp16``.
    """

    @staticmethod
    def create_fptt_kernel(decay_input: bool, hard_reset: bool, dtype: str, kernel_name_prefix: str = 'LIFNode'):
        """Build the forward-through-time kernel.

        Charging rule written by the kernel (``r`` is ``reciprocal_tau``):

        * hard reset, decayed input: ``h = v + r * (x - v + v_reset)``
        * hard reset, raw input: ``h = v - r * (v - v_reset) + x``
        * soft reset, decayed input: ``h = v + r * (x - v)``
        * soft reset, raw input: ``h = v * (1 - r) + x``

        A spike is emitted where ``h >= v_threshold``; the next membrane
        potential is written one time step further into ``v_v_seq``.

        :param decay_input: whether the input is scaled by ``reciprocal_tau``
        :param hard_reset: if ``True`` a spiking neuron is set to ``v_reset``
            (the kernel then takes an extra ``v_reset`` argument); otherwise
            ``v_threshold`` is subtracted from ``h``
        :param dtype: ``'fp32'`` or ``'fp16'``; the fp16 kernel handles two
            neurons per thread through ``half2``
        :param kernel_name_prefix: first component of the generated kernel name
        :return: the ``cupy.RawKernel`` wrapping the generated source
        :raises TypeError: if ``dtype`` is neither ``'fp32'`` nor ``'fp16'``
        """
        kernel_name = f'{kernel_name_prefix}_fptt_decayInput{decay_input}_{"hard" if hard_reset else "soft"}Reset_{dtype}'

        if dtype == 'fp32':
            v_reset_param = 'const float & v_reset,' if hard_reset else ''
            code = rf'''
            extern "C" __global__
            void {kernel_name}(const float* x_seq, float* v_v_seq, float* h_seq, float* spike_seq,
            const float & reciprocal_tau,
            const float & v_threshold, {v_reset_param}
            const int & neuron_num, const int & numel)
            '''
            code += r'''
            {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
            if (index < neuron_num)
            {
                const int dt = neuron_num;
                for(int mem_offset = 0; mem_offset < numel; mem_offset += neuron_num)
                {
                    const int t = index + mem_offset;
            '''
            if hard_reset:
                if decay_input:
                    code += r'''
                    h_seq[t] = v_v_seq[t] + reciprocal_tau * (x_seq[t] - v_v_seq[t] + v_reset);
                    '''
                else:
                    code += r'''
                    h_seq[t] = v_v_seq[t] - reciprocal_tau * (v_v_seq[t] - v_reset) + x_seq[t];
                    '''
                code += r'''
                    if (h_seq[t] >= v_threshold)
                    {
                        spike_seq[t] = 1.0f;
                        v_v_seq[t + dt] = v_reset;
                    }
                '''
            else:
                if decay_input:
                    code += r'''
                    h_seq[t] = v_v_seq[t] + reciprocal_tau * (x_seq[t] - v_v_seq[t]);
                    '''
                else:
                    code += r'''
                    h_seq[t] = v_v_seq[t] * (1.0f - reciprocal_tau) + x_seq[t];
                    '''
                code += r'''
                    if (h_seq[t] >= v_threshold)
                    {
                        spike_seq[t] = 1.0f;
                        v_v_seq[t + dt] = h_seq[t] - v_threshold;
                    }
                '''
            code += r'''
                    else
                    {
                        spike_seq[t] = 0.0f;
                        v_v_seq[t + dt] = h_seq[t];
                    }
                }
            }
            }
            '''

        elif dtype == 'fp16':
            v_reset_param = 'const half & v_reset,' if hard_reset else ''
            code = rf'''
            #include <cuda_fp16.h>
            extern "C" __global__
            void {kernel_name}(const half2* x_seq, half2* v_v_seq, half2* h_seq, half2* spike_seq,
            const half & reciprocal_tau,
            const half & v_threshold, {v_reset_param}
            const int & neuron_num, const int & numel)
            '''
            code += r'''
            {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
            const int stride = neuron_num >> 1;
            if (index < stride)
            {
                const int numel_2 = numel >> 1;
                const half2 reciprocal_tau_half2 = __half2half2(reciprocal_tau);
                const half2 v_threshold_half2 = __half2half2(v_threshold);
            '''
            if hard_reset:
                code += r'''
                const half2 v_reset_half2 = __half2half2(v_reset);
                '''
            code += r'''
                for(int mem_offset = 0; mem_offset < numel_2; mem_offset += stride)
                {
                    const int t = index + mem_offset;
            '''
            if hard_reset:
                if decay_input:
                    code += r'''
                    h_seq[t] = __hfma2(__hadd2(__hsub2(x_seq[t], v_v_seq[t]), v_reset_half2), reciprocal_tau_half2, v_v_seq[t]);
                    '''
                else:
                    code += r'''
                    // h_seq[t] = v_v_seq[t] - reciprocal_tau * (v_v_seq[t] - v_reset) + x_seq[t];
                    //          = reciprocal_tau * (v_reset - v_v_seq[t]) + v_v_seq[t] + x_seq[t];
                    h_seq[t] = __hadd2(__hfma2(__hsub2(v_reset_half2, v_v_seq[t]), reciprocal_tau_half2, v_v_seq[t]), x_seq[t]);
                    '''
                code += r'''
                    spike_seq[t] = __hgeu2(h_seq[t], v_threshold_half2);
                    v_v_seq[t + stride] = __hadd2(__hmul2(spike_seq[t], v_reset_half2), __hmul2(__hsub2(__float2half2_rn(1.0f), spike_seq[t]), h_seq[t]));
                '''
            else:
                if decay_input:
                    code += r'''
                    h_seq[t] = __hfma2(__hsub2(x_seq[t], v_v_seq[t]), reciprocal_tau_half2, v_v_seq[t]);
                    '''
                else:
                    code += r'''
                    // h_seq[t] = v_v_seq[t] * (1.0f - reciprocal_tau) + x_seq[t];
                    h_seq[t] = __hfma2(__hsub2(__float2half2_rn(1.0f), reciprocal_tau_half2), v_v_seq[t], x_seq[t]);
                    '''
                code += r'''
                    spike_seq[t] = __hgeu2(h_seq[t], v_threshold_half2);
                    v_v_seq[t + stride] = __hadd2(__hmul2(spike_seq[t], __hsub2(h_seq[t], v_threshold_half2)), __hmul2(__hsub2(__float2half2_rn(1.0f), spike_seq[t]), h_seq[t]));
                '''
            code += r'''
                }
            }
            }
            '''
        else:
            raise TypeError(f"dtype must be 'fp32' or 'fp16', but got {dtype!r}")

        return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options,
                              backend=configure.cuda_compiler_backend)

    @staticmethod
    def create_bptt_kernel(sg_cuda_code_fun, decay_input: bool, hard_reset: bool, detach_reset: bool, dtype: str):
        """Build the backward-through-time kernel.

        The kernel iterates from the last time step to the first, carrying
        ``grad_h`` (scaled by ``one_sub_reciprocal_tau`` between steps), and
        finally stores ``grad_h * one_sub_reciprocal_tau`` in ``grad_v_init``.

        :param sg_cuda_code_fun: callable invoked as
            ``sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)``; the
            CUDA text it returns is spliced into the loop body and is expected
            to declare ``grad_s_to_h`` from ``over_th``
        :param decay_input: if ``True`` ``grad_x_seq`` is ``grad_h`` scaled by
            ``reciprocal_tau``, otherwise it is ``grad_h`` itself
        :param hard_reset: selects the hard-reset or soft-reset gradient of the
            membrane potential with respect to ``h``
        :param detach_reset: if ``True`` the reset path contributes no gradient
            through the spike
        :param dtype: ``'fp32'`` or ``'fp16'``
        :return: the ``cupy.RawKernel`` wrapping the generated source
        :raises TypeError: if ``dtype`` is neither ``'fp32'`` nor ``'fp16'``
        """
        kernel_name = f'LIFNode_bptt_decayInput{decay_input}_{"hard" if hard_reset else "soft"}Reset_{"detachReset" if detach_reset else ""}_{dtype}'
        code_grad_s_to_h = sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)

        if dtype == 'fp32':
            v_reset_param = 'const float & v_reset,' if hard_reset else ''
            code = rf'''
            extern "C" __global__
            void {kernel_name}(
            const float* grad_spike_seq, const float* grad_v_seq, const float* h_seq, const float* spike_seq,
            float* grad_x_seq, float* grad_v_init,
            const float & reciprocal_tau, const float & one_sub_reciprocal_tau,
            const float & v_threshold, {v_reset_param}
            const int & neuron_num, const int & numel)
            '''
            code += r'''
            {
                const int index = blockIdx.x * blockDim.x + threadIdx.x;
                if (index < neuron_num)
                {
                    float grad_h = 0.0f;  // grad_h will be used recursively
                    for(int mem_offset = numel - neuron_num; mem_offset >= 0; mem_offset -= neuron_num)
                    {
                        const int t = index + mem_offset;
                        const float over_th = h_seq[t] - v_threshold;
            '''
            code += code_grad_s_to_h

            if detach_reset:
                if hard_reset:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - spike_seq[t];
                    '''
                else:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f;
                    '''
            else:
                if hard_reset:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - spike_seq[t] + (v_reset - h_seq[t]) * grad_s_to_h;
                    '''
                else:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - v_threshold * grad_s_to_h;
                    '''
            code += code_grad_v_to_h

            code += r'''
                        grad_h = grad_spike_seq[t] * grad_s_to_h + (grad_v_seq[t] + grad_h * one_sub_reciprocal_tau) * grad_v_to_h;
            '''
            if decay_input:
                code += r'''
                        grad_x_seq[t] = grad_h * reciprocal_tau;
                '''
            else:
                code += r'''
                        grad_x_seq[t] = grad_h;
                '''
            code += r'''
                    }
                    grad_v_init[index] = grad_h * one_sub_reciprocal_tau;
                }
            }
            '''

        elif dtype == 'fp16':
            v_reset_param = 'const half & v_reset,' if hard_reset else ''
            code = rf'''
            #include <cuda_fp16.h>
            extern "C" __global__
            void {kernel_name}(
            const half2* grad_spike_seq, const half2* grad_v_seq, const half2* h_seq, const half2* spike_seq,
            half2* grad_x_seq, half2* grad_v_init,
            const half & reciprocal_tau, const half & one_sub_reciprocal_tau,
            const half & v_threshold, {v_reset_param}
            const int & neuron_num, const int & numel)
            '''
            code += r'''
            {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
            const int stride = neuron_num >> 1;
            if (index < stride)
            {
                const half2 reciprocal_tau_half2 = __half2half2(reciprocal_tau);
                const half2 one_sub_reciprocal_tau_half2 = __half2half2(one_sub_reciprocal_tau);
                const half2 v_threshold_half2 = __half2half2(v_threshold);
            '''
            if hard_reset:
                code += r'''
                const half2 v_reset_half2 = __half2half2(v_reset);
                '''
            code += r'''
                half2 grad_h = __float2half2_rn(0.0f);  // grad_h will be used recursively
                for(int mem_offset = (numel >> 1) - stride; mem_offset >= 0; mem_offset -= stride)
                {
                    const int t = index + mem_offset;
                    const half2 over_th = __hsub2(h_seq[t], v_threshold_half2);
            '''
            code += code_grad_s_to_h

            if detach_reset:
                if hard_reset:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __hsub2(__float2half2_rn(1.0f), spike_seq[t]);
                    '''
                else:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __float2half2_rn(1.0f);
                    '''
            else:
                if hard_reset:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __hfma2(__hsub2(v_reset_half2, h_seq[t]), grad_s_to_h, __hsub2(__float2half2_rn(1.0f), spike_seq[t]));
                    '''
                else:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __hsub2(__float2half2_rn(1.0f), __hmul2(v_threshold_half2, grad_s_to_h));
                    '''
            code += code_grad_v_to_h

            code += r'''
                    grad_h = __hfma2(__hfma2(grad_h, one_sub_reciprocal_tau_half2, grad_v_seq[t]), grad_v_to_h, __hmul2(grad_spike_seq[t], grad_s_to_h));
            '''
            if decay_input:
                code += r'''
                    grad_x_seq[t] = __hmul2(grad_h, reciprocal_tau_half2);
                '''
            else:
                code += r'''
                    grad_x_seq[t] = grad_h;
                '''
            code += r'''
                }
                grad_v_init[index] = __hmul2(grad_h, one_sub_reciprocal_tau_half2);
            }
            }
            '''
        else:
            raise TypeError(f"dtype must be 'fp32' or 'fp16', but got {dtype!r}")

        return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options,
                              backend=configure.cuda_compiler_backend)

    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, decay_input: bool, tau: float, v_threshold: float, v_reset: float,
                detach_reset: bool, sg_cuda_code_fun):
        """Run the LIF neuron over the whole sequence with the FPTT kernel.

        :param ctx: autograd context; filled only when an input requires grad
        :param x_seq: input sequence; dimension 0 is iterated as time
        :param v_init: initial membrane potential (one time step of ``x_seq``)
        :param decay_input: whether the input is scaled by ``1 / tau``
        :param tau: membrane time constant; the kernels receive ``1 / tau``
            and ``1 - 1 / tau``
        :param v_threshold: firing threshold
        :param v_reset: reset value, or ``None`` to select soft reset
        :param detach_reset: forwarded to the backward kernel builder
        :param sg_cuda_code_fun: surrogate-gradient CUDA code generator,
            forwarded to the backward kernel builder
        :return: ``(spike_seq, v_seq)`` with the same shape as ``x_seq``
        :raises NotImplementedError: if ``x_seq`` is neither float32 nor float16
        """
        requires_grad = x_seq.requires_grad or v_init.requires_grad
        device = x_seq.get_device()

        if x_seq.dtype == torch.float32:
            dtype = 'fp32'
            cp_dtype = np.float32
        elif x_seq.dtype == torch.float16:
            dtype = 'fp16'
            cp_dtype = np.half
        else:
            raise NotImplementedError(
                f'only torch.float32 and torch.float16 are supported, but got {x_seq.dtype}')

        use_pad = False
        if dtype == 'fp16' and v_init.numel() % 2 != 0:
            # only fp16 needs an even numel because half2 packs two neurons;
            # when numel is odd, pad the last dimension by one element
            use_pad = True
            x_seq = F.pad(x_seq, (0, 1))  # [T, N] -> [T, N + 1]
            v_init = F.pad(v_init, (0, 1))  # [N] -> [N + 1]

        # one allocation, split into three equally sized buffers along time
        zero_shape = list(x_seq.shape)
        zero_shape[0] *= 3
        v_seq, h_seq, spike_seq = torch.split(
            torch.zeros(zero_shape, device=x_seq.device, dtype=x_seq.dtype), x_seq.shape[0])
        # slot 0 holds v_init; the kernel writes v[t] into slot t + 1
        v_v_seq = torch.cat((v_init.unsqueeze(0), v_seq))

        with cuda_utils.DeviceEnvironment(device):
            numel = x_seq.numel()
            neuron_num = numel // x_seq.shape[0]
            threads = configure.cuda_threads
            if dtype == 'fp16':
                assert neuron_num % 2 == 0
                # two neurons are computed as one half2 element
                blocks = cuda_utils.cal_blocks(neuron_num >> 1)
            else:
                blocks = cuda_utils.cal_blocks(neuron_num)

            cp_numel = cupy.asarray(numel)
            cp_neuron_num = cupy.asarray(neuron_num)
            cp_v_threshold = cupy.asarray(v_threshold, dtype=cp_dtype)
            cp_reciprocal_tau = cupy.asarray(1. / tau, dtype=cp_dtype)
            cp_one_sub_reciprocal_tau = cupy.asarray(1. - 1. / tau, dtype=cp_dtype)
            hard_reset = v_reset is not None
            cp_v_reset = cupy.asarray(v_reset, dtype=cp_dtype) if hard_reset else None

            # argument order must match the generated kernel signature
            kernel_args = [x_seq, v_v_seq, h_seq, spike_seq, cp_reciprocal_tau, cp_v_threshold]
            if hard_reset:
                kernel_args.append(cp_v_reset)
            kernel_args += [cp_neuron_num, cp_numel]
            kernel_args = list(cuda_utils.get_contiguous(*kernel_args))

            # keep working with the (possibly copied) contiguous objects
            x_seq, v_v_seq, h_seq, spike_seq, cp_reciprocal_tau, cp_v_threshold = kernel_args[:6]
            if hard_reset:
                cp_v_reset = kernel_args[6]
            cp_neuron_num, cp_numel = kernel_args[-2:]

            kernel = MultiStepLIFNodePTT.create_fptt_kernel(decay_input, hard_reset, dtype)
            kernel(
                (blocks,), (threads,),
                cuda_utils.wrap_args_to_raw_kernel(device, *kernel_args)
            )

        if requires_grad:
            ctx.decay_input = decay_input
            ctx.use_pad = use_pad
            # remember how the spikes were stored so that backward does not
            # depend on the global flag still having the same value
            ctx.spike_as_bool = bool(configure.save_spike_as_bool_in_neuron_kernel)
            if ctx.spike_as_bool:
                ctx.s_shape = spike_seq.shape
                ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike_seq)
                ctx.save_for_backward(h_seq)
            else:
                ctx.save_for_backward(h_seq, spike_seq)
            ctx.blocks = blocks
            ctx.threads = threads
            ctx.cp_numel = cp_numel
            ctx.cp_neuron_num = cp_neuron_num
            ctx.cp_reciprocal_tau = cp_reciprocal_tau
            ctx.cp_one_sub_reciprocal_tau = cp_one_sub_reciprocal_tau
            ctx.cp_v_threshold = cp_v_threshold
            ctx.cp_v_reset = cp_v_reset
            ctx.detach_reset = detach_reset
            ctx.sg_cuda_code_fun = sg_cuda_code_fun

        if use_pad:
            # drop the padding element again
            return spike_seq[..., :-1], v_v_seq[1:, ..., :-1]
        return spike_seq, v_v_seq[1:, ]

    @staticmethod
    def backward(ctx, grad_spike_seq, grad_v_seq):
        """Back-propagate through time with the BPTT kernel.

        :param ctx: autograd context filled by ``forward``
        :param grad_spike_seq: gradient w.r.t. the returned spike sequence
        :param grad_v_seq: gradient w.r.t. the returned membrane potentials
        :return: gradients for ``x_seq`` and ``v_init`` followed by ``None``
            for the six non-tensor arguments of ``forward``
        :raises NotImplementedError: if the gradients are neither float32 nor
            float16
        """
        if ctx.use_pad:
            # incoming grads are [T, N]; saved h_seq / spike_seq are [T, N + 1]
            grad_spike_seq = F.pad(grad_spike_seq, (0, 1))
            grad_v_seq = F.pad(grad_v_seq, (0, 1))

        device = grad_spike_seq.get_device()

        if ctx.spike_as_bool:
            h_seq = ctx.saved_tensors[0]
            spike_seq = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
        else:
            h_seq, spike_seq = ctx.saved_tensors

        # one allocation: T slices for grad_x_seq plus one for grad_v_init
        zero_shape = list(grad_spike_seq.shape)
        zero_shape[0] += 1
        zero_data = torch.zeros(zero_shape, device=grad_spike_seq.device, dtype=grad_spike_seq.dtype)
        grad_x_seq = zero_data[0: -1]
        grad_v_init = zero_data[-1]

        hard_reset = ctx.cp_v_reset is not None

        if grad_spike_seq.dtype == torch.float32:
            dtype = 'fp32'
        elif grad_spike_seq.dtype == torch.float16:
            dtype = 'fp16'
        else:
            raise NotImplementedError(
                f'only torch.float32 and torch.float16 are supported, but got {grad_spike_seq.dtype}')

        kernel = MultiStepLIFNodePTT.create_bptt_kernel(ctx.sg_cuda_code_fun, ctx.decay_input, hard_reset,
                                                        ctx.detach_reset, dtype)

        with cuda_utils.DeviceEnvironment(device):
            # argument order must match the generated kernel signature
            kernel_args = [grad_spike_seq, grad_v_seq, h_seq, spike_seq, grad_x_seq, grad_v_init,
                           ctx.cp_reciprocal_tau, ctx.cp_one_sub_reciprocal_tau, ctx.cp_v_threshold]
            if hard_reset:
                kernel_args.append(ctx.cp_v_reset)
            kernel_args += [ctx.cp_neuron_num, ctx.cp_numel]
            kernel_args = list(cuda_utils.get_contiguous(*kernel_args))
            # the kernel writes into these two objects, so return exactly them
            grad_x_seq, grad_v_init = kernel_args[4:6]

            kernel(
                (ctx.blocks,), (ctx.threads,),
                cuda_utils.wrap_args_to_raw_kernel(device, *kernel_args)
            )

        if ctx.use_pad:
            return grad_x_seq[..., :-1], grad_v_init[..., :-1], None, None, None, None, None, None
        return grad_x_seq, grad_v_init, None, None, None, None, None, None
[文档]class MultiStepParametricLIFNodePTT(torch.autograd.Function):
@staticmethod
def create_fptt_kernel(decay_input: bool, hard_reset: bool, dtype: str):
    """Build the forward-through-time kernel for the parametric LIF neuron.

    Delegates to ``MultiStepLIFNodePTT.create_fptt_kernel`` and only changes
    the kernel name prefix to ``'ParametricLIFNode'``.

    :param decay_input: forwarded unchanged
    :param hard_reset: forwarded unchanged
    :param dtype: forwarded unchanged
    :return: whatever ``MultiStepLIFNodePTT.create_fptt_kernel`` returns
    """
    return MultiStepLIFNodePTT.create_fptt_kernel(
        decay_input, hard_reset, dtype, kernel_name_prefix='ParametricLIFNode')
@staticmethod
def create_bptt_kernel(sg_cuda_code_fun, decay_input: bool, hard_reset: bool, detach_reset: bool, dtype: str):
    """Build the backward-through-time kernel for the parametric LIF neuron.

    In addition to ``grad_x_seq`` and ``grad_v_init`` the kernel accumulates
    the gradient of the learnable ``reciprocal_tau``: every thread sums its
    contribution over time into a ``__shared__`` float array sized
    ``configure.cuda_threads``, the block folds that array by repeatedly
    halving the active range, and thread 0 adds the block total to
    ``grad_reciprocal_tau`` with ``atomicAdd``.

    NOTE: the halving fold only visits every element when ``blockDim.x`` is a
    power of two, and the shared array requires ``blockDim.x`` to be at most
    ``configure.cuda_threads``.

    :param sg_cuda_code_fun: callable invoked as
        ``sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)``; the
        CUDA text it returns is spliced into the loop body and is expected to
        declare ``grad_s_to_h`` from ``over_th``
    :param decay_input: selects the ``grad_x_seq`` scaling and the formula for
        the ``reciprocal_tau`` contribution
    :param hard_reset: selects the hard-reset or soft-reset formulas (the
        kernel takes an extra ``v_reset`` argument when ``True``)
    :param detach_reset: if ``True`` the reset path contributes no gradient
        through the spike
    :param dtype: ``'fp32'`` or ``'fp16'``; ``grad_reciprocal_tau`` is a
        ``float*`` in both cases
    :return: the ``cupy.RawKernel`` wrapping the generated source
    :raises TypeError: if ``dtype`` is neither ``'fp32'`` nor ``'fp16'``
    """
    kernel_name = f'ParametricLIFNode_bptt_decayInput{decay_input}_{"hard" if hard_reset else "soft"}Reset_{"detachReset" if detach_reset else ""}_{dtype}'
    code_grad_s_to_h = sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)

    if dtype == 'fp32':
        v_reset_param = 'const float & v_reset,' if hard_reset else ''
        code = rf'''
        extern "C" __global__
        void {kernel_name}(
        const float* grad_spike_seq, const float* grad_v_seq, const float* h_seq, const float* spike_seq, const float* v_v_seq,
        float* grad_x_seq, float* grad_v_init, float* grad_reciprocal_tau,
        const float & reciprocal_tau, const float & one_sub_reciprocal_tau,
        const float & v_threshold, {v_reset_param}
        const int & neuron_num, const int & numel)
        '''
        code += r'''
        {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
        '''
        code += f'__shared__ float sdata[{configure.cuda_threads}];'
        code += r'''
            if (index < neuron_num)
            {
                float grad_h = 0.0f;  // grad_h will be used recursively
                sdata[threadIdx.x] = 0.0f;
                for(int mem_offset = numel - neuron_num; mem_offset >= 0; mem_offset -= neuron_num)
                {
                    const int t = index + mem_offset;
                    const float over_th = h_seq[t] - v_threshold;
        '''
        code += code_grad_s_to_h

        if detach_reset:
            if hard_reset:
                code_grad_v_to_h = r'''
                    const float grad_v_to_h = 1.0f - spike_seq[t];
                '''
            else:
                code_grad_v_to_h = r'''
                    const float grad_v_to_h = 1.0f;
                '''
        else:
            if hard_reset:
                code_grad_v_to_h = r'''
                    const float grad_v_to_h = 1.0f - spike_seq[t] + (v_reset - h_seq[t]) * grad_s_to_h;
                '''
            else:
                code_grad_v_to_h = r'''
                    const float grad_v_to_h = 1.0f - v_threshold * grad_s_to_h;
                '''
        code += code_grad_v_to_h

        code += r'''
                    grad_h = grad_spike_seq[t] * grad_s_to_h + (grad_v_seq[t] + grad_h * one_sub_reciprocal_tau) * grad_v_to_h;
        '''
        if decay_input:
            code += r'''
                    grad_x_seq[t] = grad_h * reciprocal_tau;
                    sdata[threadIdx.x] += grad_h * (h_seq[t] - v_v_seq[t]) / reciprocal_tau;
            '''
        else:
            if hard_reset:
                code += r'''
                    grad_x_seq[t] = grad_h;
                    sdata[threadIdx.x] += grad_h * (v_reset - v_v_seq[t]);
                '''
            else:
                code += r'''
                    grad_x_seq[t] = grad_h;
                    sdata[threadIdx.x] -= grad_h * v_v_seq[t];
                '''
        code += r'''
                }
                grad_v_init[index] = grad_h * one_sub_reciprocal_tau;
            }
            else
            {
                sdata[threadIdx.x] = 0.0f;
            }
            int threadx = blockDim.x;
            #pragma unroll
            for (int stride = threadx >> 1; stride > 0; stride = stride >> 1)
            {
                // Synchronize all threads before the next loop
                __syncthreads();
                if (threadIdx.x < stride)
                {
                    sdata[threadIdx.x] += sdata[threadIdx.x + stride];
                }
            }
            __syncthreads();
            if (threadIdx.x == 0)
            {
                atomicAdd(grad_reciprocal_tau, sdata[0]);
            }
        }
        '''

    elif dtype == 'fp16':
        v_reset_param = 'const half & v_reset,' if hard_reset else ''
        code = rf'''
        #include <cuda_fp16.h>
        extern "C" __global__
        void {kernel_name}(
        const half2* grad_spike_seq, const half2* grad_v_seq, const half2* h_seq, const half2* spike_seq, const half2* v_v_seq,
        half2* grad_x_seq, half2* grad_v_init, float* grad_reciprocal_tau,
        const half & reciprocal_tau, const half & one_sub_reciprocal_tau,
        const half & v_threshold, {v_reset_param}
        const int & neuron_num, const int & numel)
        // note that grad_reciprocal_tau is float to avoid overflow
        '''
        code += r'''
        {
        const int index = blockIdx.x * blockDim.x + threadIdx.x;
        const int stride = neuron_num >> 1;
        '''
        code += f'__shared__ float sdata[{configure.cuda_threads}];'
        code += r'''
        if (index < stride)
        {
            const half2 reciprocal_tau_half2 = __half2half2(reciprocal_tau);
            const half2 one_sub_reciprocal_tau_half2 = __half2half2(one_sub_reciprocal_tau);
            const half2 v_threshold_half2 = __half2half2(v_threshold);
        '''
        if hard_reset:
            code += r'''
            const half2 v_reset_half2 = __half2half2(v_reset);
            '''
        code += r'''
            half2 grad_h = __float2half2_rn(0.0f);  // grad_h will be used recursively
            sdata[threadIdx.x] = 0.0f;
            for(int mem_offset = (numel >> 1) - stride; mem_offset >= 0; mem_offset -= stride)
            {
                const int t = index + mem_offset;
                const half2 over_th = __hsub2(h_seq[t], v_threshold_half2);
        '''
        code += code_grad_s_to_h

        if detach_reset:
            if hard_reset:
                code_grad_v_to_h = r'''
                const half2 grad_v_to_h = __hsub2(__float2half2_rn(1.0f), spike_seq[t]);
                '''
            else:
                code_grad_v_to_h = r'''
                const half2 grad_v_to_h = __float2half2_rn(1.0f);
                '''
        else:
            if hard_reset:
                code_grad_v_to_h = r'''
                const half2 grad_v_to_h = __hfma2(__hsub2(v_reset_half2, h_seq[t]), grad_s_to_h, __hsub2(__float2half2_rn(1.0f), spike_seq[t]));
                '''
            else:
                code_grad_v_to_h = r'''
                const half2 grad_v_to_h = __hsub2(__float2half2_rn(1.0f), __hmul2(v_threshold_half2, grad_s_to_h));
                '''
        code += code_grad_v_to_h

        code += r'''
                grad_h = __hfma2(__hfma2(grad_h, one_sub_reciprocal_tau_half2, grad_v_seq[t]), grad_v_to_h, __hmul2(grad_spike_seq[t], grad_s_to_h));
        '''
        if decay_input:
            code += r'''
                grad_x_seq[t] = __hmul2(grad_h, reciprocal_tau_half2);
                half2 temp_sum = __h2div(__hmul2(grad_h, __hsub2(h_seq[t], v_v_seq[t])), reciprocal_tau_half2);
                sdata[threadIdx.x] += __half2float(__hadd(__low2half(temp_sum), __high2half(temp_sum)));
            '''
        else:
            if hard_reset:
                code += r'''
                grad_x_seq[t] = grad_h;
                half2 temp_sum = __hmul2(grad_h, __hsub2(v_reset_half2, v_v_seq[t]));
                sdata[threadIdx.x] += __half2float(__hadd(__low2half(temp_sum), __high2half(temp_sum)));
                '''
            else:
                code += r'''
                grad_x_seq[t] = grad_h;
                half2 temp_sum = __hmul2(grad_h, __hneg2(v_v_seq[t]));
                sdata[threadIdx.x] += __half2float(__hadd(__low2half(temp_sum), __high2half(temp_sum)));
                '''
        code += r'''
            }
            grad_v_init[index] = __hmul2(grad_h, one_sub_reciprocal_tau_half2);
        }
        else
        {
            sdata[threadIdx.x] = 0.0f;
        }
        int threadx = blockDim.x;
        #pragma unroll
        for (int i = threadx >> 1; i > 0; i = i >> 1)
        {
            // Synchronize all threads before the next loop
            __syncthreads();
            if (threadIdx.x < i)
            {
                sdata[threadIdx.x] += sdata[threadIdx.x + i];
            }
        }
        __syncthreads();
        if (threadIdx.x == 0)
        {
            /*
            The 32-bit floating-point version of atomicAdd() is only supported by devices of compute capability 2.x and higher.
            The 64-bit floating-point version of atomicAdd() is only supported by devices of compute capability 6.x and higher.
            The 32-bit __half2 floating-point version of atomicAdd() is only supported by devices of compute capability 6.x and higher.
            The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
            The float version is used here, so the accumulator is a float even in the fp16 kernel.
            */
            atomicAdd(grad_reciprocal_tau, sdata[0]);
        }
        }
        '''
    else:
        raise TypeError(f"dtype must be 'fp32' or 'fp16', but got {dtype!r}")

    return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options,
                          backend=configure.cuda_compiler_backend)
@staticmethod
def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, reciprocal_tau: torch.Tensor, decay_input: bool, v_threshold: float,
            v_reset: float, detach_reset: bool, sg_cuda_code_fun):
    """
    Run the forward kernel of the parametric LIF neuron over a whole input sequence.

    :param ctx: autograd context; filled with everything ``backward`` reads when a gradient is needed
    :param x_seq: input sequence with time on dim 0; ``float32`` or ``float16``
    :param v_init: initial membrane potential (one time step of ``x_seq``)
    :param reciprocal_tau: single-element tensor holding ``1 / tau``
    :param decay_input: forwarded to ``create_fptt_kernel`` / ``create_bptt_kernel``
    :param v_threshold: firing threshold
    :param v_reset: reset value; ``None`` selects the soft-reset kernel
    :param detach_reset: stored on ``ctx`` for ``backward``
    :param sg_cuda_code_fun: stored on ``ctx`` for ``backward``
    :return: ``(spike_seq, v_seq)``; the padding neuron is stripped again when padding was applied
    :raises NotImplementedError: if ``x_seq`` is neither ``float32`` nor ``float16``
    """
    # reciprocal_tau.dtype is float32 even when using amp
    # autograd calls backward() as soon as ANY tensor input requires grad, so reciprocal_tau must
    # take part in this test; otherwise ctx stays empty and backward() fails on ctx.use_pad.
    requires_grad = x_seq.requires_grad or v_init.requires_grad or reciprocal_tau.requires_grad
    device = x_seq.get_device()
    if x_seq.dtype == torch.float32:
        dtype = 'fp32'
        cp_dtype = np.float32
    elif x_seq.dtype == torch.float16:
        dtype = 'fp16'
        cp_dtype = np.half
    else:
        raise NotImplementedError

    use_pad = False
    if dtype == 'fp16' and v_init.numel() % 2 != 0:
        # only fp16 needs even numel because we use half2 to accelerate
        # when numel is odd, we will pad x_seq
        use_pad = True
        x_seq = F.pad(x_seq, (0, 1))  # [T, N] -> [T, N + 1]
        v_init = F.pad(v_init, (0, 1))  # [N] -> [N + 1]

    # one allocation, split into the three per-step buffers the kernel fills
    zero_shape = list(x_seq.shape)
    zero_shape[0] *= 3
    v_seq, h_seq, spike_seq = torch.split(torch.zeros(zero_shape, device=x_seq.device, dtype=x_seq.dtype), x_seq.shape[0])

    # row 0 is v_init, row t + 1 receives the potential after step t
    v_v_seq = torch.cat((v_init.unsqueeze(0), v_seq))
    tau = 1. / reciprocal_tau.item()

    with cuda_utils.DeviceEnvironment(device):
        numel = x_seq.numel()
        neuron_num = numel // x_seq.shape[0]

        threads = configure.cuda_threads
        if dtype == 'fp16':
            assert neuron_num % 2 == 0
            # we will take two neurons to calculate as one neuron in cuda half2
            blocks = cuda_utils.cal_blocks(neuron_num >> 1)
        else:
            blocks = cuda_utils.cal_blocks(neuron_num)

        cp_numel = cupy.asarray(numel)
        cp_neuron_num = cupy.asarray(neuron_num)
        cp_v_threshold = cupy.asarray(v_threshold, dtype=cp_dtype)
        cp_reciprocal_tau = cupy.asarray(1. / tau, dtype=cp_dtype)
        cp_one_sub_reciprocal_tau = cupy.asarray(1. - 1. / tau, dtype=cp_dtype)

        if v_reset is None:
            cp_v_reset = None
            hard_reset = False
            x_seq, v_v_seq, h_seq, spike_seq, cp_reciprocal_tau, cp_v_threshold, cp_neuron_num, cp_numel = cuda_utils.get_contiguous(
                x_seq, v_v_seq, h_seq, spike_seq, cp_reciprocal_tau, cp_v_threshold, cp_neuron_num, cp_numel)
            kernel_args = [x_seq, v_v_seq, h_seq, spike_seq, cp_reciprocal_tau, cp_v_threshold, cp_neuron_num,
                           cp_numel]
        else:
            cp_v_reset = cupy.asarray(v_reset, dtype=cp_dtype)
            hard_reset = True
            x_seq, v_v_seq, h_seq, spike_seq, cp_reciprocal_tau, cp_v_threshold, cp_v_reset, cp_neuron_num, cp_numel = cuda_utils.get_contiguous(
                x_seq, v_v_seq, h_seq, spike_seq, cp_reciprocal_tau, cp_v_threshold, cp_v_reset, cp_neuron_num,
                cp_numel)
            kernel_args = [x_seq, v_v_seq, h_seq, spike_seq, cp_reciprocal_tau, cp_v_threshold, cp_v_reset,
                           cp_neuron_num, cp_numel]

        kernel = MultiStepParametricLIFNodePTT.create_fptt_kernel(decay_input, hard_reset, dtype)

        kernel(
            (blocks,), (threads,),
            cuda_utils.wrap_args_to_raw_kernel(
                device,
                *kernel_args
            )
        )

    if requires_grad:
        ctx.decay_input = decay_input
        ctx.use_pad = use_pad
        if configure.save_spike_as_bool_in_neuron_kernel:
            # spikes go to the bool tensor cache instead of save_for_backward
            ctx.s_shape = spike_seq.shape
            ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike_seq)
            ctx.save_for_backward(h_seq, v_v_seq)
        else:
            ctx.save_for_backward(h_seq, spike_seq, v_v_seq)
        ctx.blocks = blocks
        ctx.threads = threads
        ctx.cp_numel = cp_numel
        ctx.cp_neuron_num = cp_neuron_num
        ctx.cp_reciprocal_tau = cp_reciprocal_tau
        ctx.cp_one_sub_reciprocal_tau = cp_one_sub_reciprocal_tau
        ctx.cp_v_threshold = cp_v_threshold
        ctx.cp_v_reset = cp_v_reset
        ctx.detach_reset = detach_reset
        ctx.sg_cuda_code_fun = sg_cuda_code_fun

    if use_pad:
        return spike_seq[..., :-1], v_v_seq[1:, ..., :-1]
    else:
        return spike_seq, v_v_seq[1:, ]
@staticmethod
def backward(ctx, grad_spike_seq, grad_v_seq):
    """
    Run the backward kernel of the parametric LIF neuron.

    :param ctx: autograd context populated by ``forward``
    :param grad_spike_seq: gradient w.r.t. the spike output of ``forward``
    :param grad_v_seq: gradient w.r.t. the membrane potential output of ``forward``
    :return: gradients for ``x_seq``, ``v_init`` and ``reciprocal_tau`` followed by ``None`` for the
        five remaining ``forward`` inputs
    :raises NotImplementedError: if the incoming gradient is neither ``float32`` nor ``float16``
    """
    if ctx.use_pad:
        # forward padded one extra neuron, so the saved tensors are [T, N + 1] while the incoming
        # gradients are [T, N]; pad the gradients to the same width
        grad_spike_seq = F.pad(grad_spike_seq, (0, 1))
        grad_v_seq = F.pad(grad_v_seq, (0, 1))

    device = grad_spike_seq.get_device()

    if configure.save_spike_as_bool_in_neuron_kernel:
        spike_seq = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
        h_seq, v_v_seq = ctx.saved_tensors
    else:
        h_seq, spike_seq, v_v_seq = ctx.saved_tensors

    # a single zero buffer: every row but the last is grad_x_seq, the last row is grad_v_init
    buffer_shape = list(grad_spike_seq.shape)
    buffer_shape[0] += 1
    grad_buffer = torch.zeros(buffer_shape, device=grad_spike_seq.device, dtype=grad_spike_seq.dtype)
    grad_x_seq = grad_buffer[0: -1]
    grad_v_init = grad_buffer[-1]
    grad_reciprocal_tau = torch.as_tensor(0., device=grad_spike_seq.device, dtype=torch.float32)

    hard_reset = ctx.cp_v_reset is not None

    dtype = {torch.float32: 'fp32', torch.float16: 'fp16'}.get(grad_spike_seq.dtype)
    if dtype is None:
        raise NotImplementedError

    kernel = MultiStepParametricLIFNodePTT.create_bptt_kernel(ctx.sg_cuda_code_fun, ctx.decay_input, hard_reset,
                                                              ctx.detach_reset, dtype)

    with cuda_utils.DeviceEnvironment(device):
        # scalar kernel arguments live on ctx; v_reset only exists for the hard-reset kernel
        scalar_names = ['cp_reciprocal_tau', 'cp_one_sub_reciprocal_tau', 'cp_v_threshold']
        if hard_reset:
            scalar_names.append('cp_v_reset')
        scalar_names.extend(('cp_neuron_num', 'cp_numel'))

        tensors = (grad_spike_seq, grad_v_seq, h_seq, spike_seq, v_v_seq, grad_x_seq, grad_v_init,
                   grad_reciprocal_tau)
        kernel_args = list(cuda_utils.get_contiguous(*tensors, *(getattr(ctx, name) for name in scalar_names)))

        # rebind the locals and the ctx attributes to the contiguous objects handed to the kernel
        (grad_spike_seq, grad_v_seq, h_seq, spike_seq, v_v_seq,
         grad_x_seq, grad_v_init, grad_reciprocal_tau) = kernel_args[:len(tensors)]
        for name, value in zip(scalar_names, kernel_args[len(tensors):]):
            setattr(ctx, name, value)

        kernel(
            (ctx.blocks,), (ctx.threads,),
            cuda_utils.wrap_args_to_raw_kernel(
                device,
                *kernel_args
            )
        )

    if ctx.use_pad:
        # drop the gradient of the padding neuron again
        grad_x_seq = grad_x_seq[..., :-1]
        grad_v_init = grad_v_init[..., :-1]
    return grad_x_seq, grad_v_init, grad_reciprocal_tau, None, None, None, None, None
def check_multi_step_neuron_output_and_grad(device, multi_step_neuron, shape=(65, 15, 511), *neu_args, **neu_kwargs):
    """
    Print the maximum absolute difference between the ``torch`` and ``cupy`` backends of a multi-step neuron.

    For every combination of hard/soft reset, ``detach_reset`` and ``fp32``/``fp16`` a neuron is built,
    run forward and backward once per backend on the same random input, and the outputs, the input
    gradient and the parameter gradients are compared.

    :param device: device on which the input and the neuron are placed
    :param multi_step_neuron: callable building the neuron; it is called with ``v_reset`` and
        ``detach_reset`` keyword arguments plus ``neu_args`` / ``neu_kwargs``
    :param shape: shape of the random input
    :param neu_args: extra positional arguments for ``multi_step_neuron``
    :param neu_kwargs: extra keyword arguments for ``multi_step_neuron``
    :return: ``None``; results are printed
    """
    @torch.no_grad()
    def max_error(x, y):
        return (x - y).abs().max().item()

    def fbptt(m, x: torch.Tensor):
        # work on a fresh leaf so each backend run gets its own x.grad
        x = x.detach()
        x.requires_grad_(True)
        spike_seq = m(x)
        (spike_seq * m.v_seq ** 2).sum().backward()
        ret = {
            'spike_seq': spike_seq.detach().clone(),
            'v_seq': m.v_seq.detach().clone(),
            'x.grad': x.grad.clone()
        }
        for i, param in enumerate(m.parameters()):
            if param.grad is None:
                # parameter did not receive a gradient (e.g. frozen); nothing to compare
                continue
            ret[f'param_{i}.grad'] = param.grad.detach().clone()
            # clear the accumulated gradient so the next backend run starts from zero
            param.grad.zero_()
        x.grad.zero_()
        m.reset()
        return ret

    for hard_reset in [True, False]:
        for detach_reset in [False, True]:
            for dtype in ['fp32', 'fp16']:
                x = (torch.rand(shape, device=device) - 0.5) * 3.
                if dtype == 'fp16':
                    x = x.half()
                print(f'hard_reset={hard_reset}, detach_reset={detach_reset}, dtype={dtype}')
                model = multi_step_neuron(v_reset=0. if hard_reset else None, detach_reset=detach_reset, *neu_args,
                                          **neu_kwargs)
                # print(model)
                model.to(device)
                if dtype == 'fp16':
                    model = model.half()
                model.backend = 'torch'
                y_torch = fbptt(model, x)

                model.backend = 'cupy'
                y_cupy = fbptt(model, x)

                for key in y_torch.keys():
                    me = max_error(y_torch[key], y_cupy[key])
                    print(key, 'max error', me)
                    if me > 0.5:
                        print(f'y_torch[{key}]={y_torch[key]}, y_cupy[{key}]={y_cupy[key]}')
                print('\n')
def check_single_step_neuron_output_and_grad(device, single_step_neuron, shape=(65, 15, 511), *neu_args, **neu_kwargs):
    """
    Print the maximum absolute difference between the ``torch`` and ``cupy`` backends of a single-step neuron.

    For every combination of hard/soft reset, ``detach_reset`` and ``fp32``/``fp16`` a neuron is built
    with ``step_mode='s'``, fed the random input one slice of dim 0 at a time, and the last spike, the
    final ``v``, the input gradient and the parameter gradients of both backends are compared.

    :param device: device on which the input and the neuron are placed
    :param single_step_neuron: callable building the neuron; it is called with ``v_reset``,
        ``detach_reset`` and ``step_mode`` keyword arguments plus ``neu_args`` / ``neu_kwargs``
    :param shape: shape of the random input; dim 0 is iterated step by step
    :param neu_args: extra positional arguments for ``single_step_neuron``
    :param neu_kwargs: extra keyword arguments for ``single_step_neuron``
    :return: ``None``; results are printed
    """
    @torch.no_grad()
    def max_error(x, y):
        return (x - y).abs().max().item()

    def fbp(m, x: torch.Tensor):
        # work on a fresh leaf so each backend run gets its own x.grad
        x = x.detach()
        x.requires_grad_(True)
        T = x.size(0)
        for i in range(T):
            spike = m(x[i])
        # backward runs once, after the last step: the state carried in m.v links all steps into one
        # graph, and backpropagating inside the loop would hit the already-freed graph of earlier steps
        (spike * m.v ** 2).sum().backward()
        ret = {
            'spike': spike.detach().clone(),
            'v': m.v.detach().clone(),
            'x.grad': x.grad.clone()
        }
        for i, param in enumerate(m.parameters()):
            if param.grad is None:
                # parameter did not receive a gradient (e.g. frozen); nothing to compare
                continue
            ret[f'param_{i}.grad'] = param.grad.detach().clone()
            # clear the accumulated gradient so the next backend run starts from zero
            param.grad.zero_()
        x.grad.zero_()
        m.reset()
        return ret

    for hard_reset in [True, False]:
        for detach_reset in [False, True]:
            for dtype in ['fp32', 'fp16']:
                x = (torch.rand(shape, device=device) - 0.5) * 3.
                if dtype == 'fp16':
                    x = x.half()
                print(f'hard_reset={hard_reset}, detach_reset={detach_reset}, dtype={dtype}')
                model = single_step_neuron(v_reset=0. if hard_reset else None, detach_reset=detach_reset, step_mode='s', *neu_args,
                                           **neu_kwargs)
                # print(model)
                model.to(device)
                if dtype == 'fp16':
                    model = model.half()
                model.backend = 'torch'
                y_torch = fbp(model, x)

                model.backend = 'cupy'
                y_cupy = fbp(model, x)

                for key in y_torch.keys():
                    me = max_error(y_torch[key], y_cupy[key])
                    print(key, 'max error', me)
                    if me > 0.5:
                        print(f'y_torch[{key}]={y_torch[key]}, y_cupy[{key}]={y_cupy[key]}')
                print('\n')
class MultiStepQIFNodePTT(torch.autograd.Function):
    """
    Autograd function running a QIF neuron over a whole input sequence with CuPy raw kernels.

    Per time step the forward kernel computes

        ``h = v + reciprocal_tau * (x + a0 * (v - v_rest) * (v - v_c))``

    emits a spike where ``h >= v_threshold`` and then either resets ``v`` to ``v_reset`` (hard reset)
    or subtracts ``v_threshold`` (soft reset). The backward kernel walks the sequence in reverse.
    """

    @staticmethod
    def create_fptt_kernel(hard_reset: bool, dtype: str):
        """
        Build the forward kernel.

        :param hard_reset: ``True`` adds the ``v_reset`` kernel argument and resets to it on a spike
        :param dtype: ``'fp32'`` or ``'fp16'`` (the fp16 kernel processes two neurons per ``half2``)
        :return: the ``cupy.RawKernel`` for this configuration
        :raises TypeError: for any other ``dtype``
        """
        kernel_name = f'QIFNode_fptt_{"hard" if hard_reset else "soft"}Reset_{dtype}'

        if dtype == 'fp32':
            code = rf'''
            extern "C" __global__
            void {kernel_name}(const float* x_seq, float* v_v_seq, float* h_seq, float* spike_seq,
            const float & reciprocal_tau,
            const float & v_c,
            const float & a0,
            const float & v_threshold,
            const float & v_rest, {'const float & v_reset,' if hard_reset else ''}
            const int & neuron_num, const int & numel)
            '''

            code += r'''
            {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
            if (index < neuron_num)
            {
                const int dt = neuron_num;
                for(int mem_offset = 0; mem_offset < numel; mem_offset += neuron_num)
                {
                    const int t = index + mem_offset;
                    h_seq[t] = v_v_seq[t] + reciprocal_tau * (x_seq[t] + a0 * (v_v_seq[t] - v_rest) * (v_v_seq[t] - v_c));
                    if (h_seq[t] >= v_threshold)
                    {
                        spike_seq[t] = 1.0f;
            '''

            if hard_reset:
                code += r'''
                        v_v_seq[t + dt] = v_reset;
                '''
            else:
                code += r'''
                        v_v_seq[t + dt] = h_seq[t] - v_threshold;
                '''

            code += r'''
                    }
                    else
                    {
                        spike_seq[t] = 0.0f;
                        v_v_seq[t + dt] = h_seq[t];
                    }
                }
            }
            }
            '''

        elif dtype == 'fp16':
            code = rf'''
            #include <cuda_fp16.h>
            extern "C" __global__
            void {kernel_name}(const half2* x_seq, half2* v_v_seq, half2* h_seq, half2* spike_seq,
            const half & reciprocal_tau,
            const half & v_c,
            const half & a0,
            const half & v_threshold,
            const half & v_rest, {'const half & v_reset,' if hard_reset else ''}
            const int & neuron_num, const int & numel)
            '''

            code += r'''
            {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
            const int stride = neuron_num >> 1;
            if (index < stride)
            {
                const int numel_2 = numel >> 1;
                const half2 reciprocal_tau_half2 = __half2half2(reciprocal_tau);
                const half2 v_c_half2 = __half2half2(v_c);
                const half2 a0_half2 = __half2half2(a0);
                const half2 v_threshold_half2 = __half2half2(v_threshold);
                const half2 v_rest_half2 = __half2half2(v_rest);
            '''

            if hard_reset:
                code += r'''
                const half2 v_reset_half2 = __half2half2(v_reset);
                '''

            code += r'''
                for(int mem_offset = 0; mem_offset < numel_2; mem_offset += stride)
                {
                    const int t = index + mem_offset;
                    h_seq[t] = __hfma2(__hfma2(__hmul2(__hsub2(v_v_seq[t], v_rest_half2), __hsub2(v_v_seq[t], v_c_half2)), a0_half2, x_seq[t]), reciprocal_tau_half2, v_v_seq[t]);
                    spike_seq[t] = __hgeu2(h_seq[t], v_threshold_half2);
            '''

            if hard_reset:
                code += r'''
                    v_v_seq[t + stride] = __hadd2(__hmul2(spike_seq[t], v_reset_half2), __hmul2(__hsub2(__float2half2_rn(1.0f), spike_seq[t]), h_seq[t]));
                '''
            else:
                code += r'''
                    v_v_seq[t + stride] = __hadd2(__hmul2(spike_seq[t], __hsub2(h_seq[t], v_threshold_half2)), __hmul2(__hsub2(__float2half2_rn(1.0f), spike_seq[t]), h_seq[t]));
                '''

            code += r'''
                }
            }
            }
            '''
        else:
            raise TypeError

        return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)

    @staticmethod
    def create_bptt_kernel(sg_cuda_code_fun, hard_reset: bool, detach_reset: bool, dtype: str):
        """
        Build the backward kernel.

        :param sg_cuda_code_fun: called as ``sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)``;
            its return value is spliced into the kernel and must define ``grad_s_to_h`` from ``over_th``
        :param hard_reset: ``True`` adds the ``v_reset`` kernel argument
        :param detach_reset: ``True`` drops the surrogate-gradient term from ``grad_v_to_h``
        :param dtype: ``'fp32'`` or ``'fp16'``
        :return: the ``cupy.RawKernel`` for this configuration
        :raises TypeError: for any other ``dtype``
        """
        kernel_name = f'QIFNode_bptt_{"hard" if hard_reset else "soft"}Reset_{"detachReset" if detach_reset else ""}_{dtype}'

        code_grad_s_to_h = sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)

        if dtype == 'fp32':
            code = fr'''
            extern "C" __global__
            void {kernel_name}(
            const float* grad_spike_seq, const float* grad_v_seq, const float* h_seq, const float* spike_seq, const float* v_v_seq,
            float* grad_x_seq, float* grad_v_init,
            const float & a0_over_tau, const float & neg_sum_v_rest_v_c, const float & reciprocal_tau,
            const float & v_threshold, {'const float & v_reset,' if hard_reset else ''}
            const int & neuron_num, const int & numel)
            '''

            code += r'''
            {
                const int index = blockIdx.x * blockDim.x + threadIdx.x;
                if (index < neuron_num)
                {
                    float grad_h = 0.0f;  // grad_h will be used recursively
                    for(int mem_offset = numel - neuron_num; mem_offset >= 0; mem_offset -= neuron_num)
                    {
                        const int t = index + mem_offset;
                        const float over_th = h_seq[t] - v_threshold;
            '''
            code += code_grad_s_to_h
            if detach_reset:
                if hard_reset:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - spike_seq[t];
                    '''
                else:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f;
                    '''
            else:
                if hard_reset:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - spike_seq[t] + (v_reset - h_seq[t]) * grad_s_to_h;
                    '''
                else:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - v_threshold * grad_s_to_h;
                    '''

            code += code_grad_v_to_h
            # grad_v_init is dL/dh at the first step times dh/dv_init, so it has to be built from
            # grad_h itself; grad_x_seq[index] already carries the extra factor reciprocal_tau.
            code += r'''
                        grad_h = grad_spike_seq[t] * grad_s_to_h + (grad_v_seq[t] + grad_h * (1.0f + a0_over_tau * (2.0f * v_v_seq[t + neuron_num] + neg_sum_v_rest_v_c))) * grad_v_to_h;
                        grad_x_seq[t] = grad_h * reciprocal_tau;
                    }
                    grad_v_init[index] = grad_h * (1.0f + a0_over_tau * (2.0f * v_v_seq[index] + neg_sum_v_rest_v_c));
                }
            }
            '''

        elif dtype == 'fp16':
            code = fr'''
            #include <cuda_fp16.h>
            extern "C" __global__
            void {kernel_name}(
            const half2* grad_spike_seq, const half2* grad_v_seq, const half2* h_seq, const half2* spike_seq, const half2* v_v_seq,
            half2* grad_x_seq, half2* grad_v_init,
            const half & a0_over_tau, const half & neg_sum_v_rest_v_c,
            const half & reciprocal_tau,
            const half & v_threshold, {'const half & v_reset,' if hard_reset else ''}
            const int & neuron_num, const int & numel)
            '''
            code += r'''
            {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
            const int stride = neuron_num >> 1;
            if (index < stride)
            {
                const half2 a0_over_tau_half2 = __half2half2(a0_over_tau);
                const half2 neg_sum_v_rest_v_c_half2 = __half2half2(neg_sum_v_rest_v_c);
                const half2 v_threshold_half2 = __half2half2(v_threshold);
                const half2 reciprocal_tau_half2 = __half2half2(reciprocal_tau);
            '''

            if hard_reset:
                code += r'''
                const half2 v_reset_half2 = __half2half2(v_reset);
                '''

            code += r'''
                half2 grad_h = __float2half2_rn(0.0f);  // grad_h will be used recursively
                for(int mem_offset = (numel >> 1) - stride; mem_offset >= 0; mem_offset -= stride)
                {
                    const int t = index + mem_offset;
                    const half2 over_th = __hsub2(h_seq[t], v_threshold_half2);
            '''
            code += code_grad_s_to_h

            if detach_reset:
                if hard_reset:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __hsub2(__float2half2_rn(1.0f), spike_seq[t]);
                    '''
                else:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __float2half2_rn(1.0f);
                    '''
            else:
                if hard_reset:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __hfma2(__hsub2(v_reset_half2, h_seq[t]), grad_s_to_h, __hsub2(__float2half2_rn(1.0f), spike_seq[t]));
                    '''
                else:
                    code_grad_v_to_h = r'''
                    const half2 grad_v_to_h = __hsub2(__float2half2_rn(1.0f), __hmul2(v_threshold_half2, grad_s_to_h));
                    '''

            code += code_grad_v_to_h
            # same as the fp32 kernel: grad_v_init is built from grad_h, not from grad_x_seq[index]
            code += r'''
                    grad_h = __hfma2(__hfma2(__hfma2(__hfma2(__float2half2_rn(2.0f), v_v_seq[t + stride], neg_sum_v_rest_v_c_half2), a0_over_tau_half2, __float2half2_rn(1.0f)), grad_h, grad_v_seq[t]), grad_v_to_h, __hmul2(grad_spike_seq[t], grad_s_to_h));
                    grad_x_seq[t] = __hmul2(grad_h, reciprocal_tau_half2);
                }
                grad_v_init[index] = __hmul2(__hfma2(__hfma2(__float2half2_rn(2.0f), v_v_seq[index], neg_sum_v_rest_v_c_half2), a0_over_tau_half2, __float2half2_rn(1.0f)), grad_h);
            }
            }
            '''
        else:
            raise TypeError

        return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)

    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, tau: float, v_threshold: float, v_reset: float, v_rest: float, v_c: float, a0: float, detach_reset: bool, sg_cuda_code_fun):
        """
        Run the forward kernel over a whole input sequence.

        :param ctx: autograd context; filled with everything ``backward`` reads when a gradient is needed
        :param x_seq: input sequence with time on dim 0; ``float32`` or ``float16``
        :param v_init: initial membrane potential (one time step of ``x_seq``)
        :param tau: time constant; the kernel receives ``1 / tau``
        :param v_threshold: firing threshold
        :param v_reset: reset value; ``None`` selects the soft-reset kernel
        :param v_rest: ``v_rest`` of the charge equation
        :param v_c: ``v_c`` of the charge equation
        :param a0: ``a0`` of the charge equation
        :param detach_reset: stored on ``ctx`` for ``backward``
        :param sg_cuda_code_fun: stored on ``ctx`` for ``backward``
        :return: ``(spike_seq, v_seq)``; the padding neuron is stripped again when padding was applied
        :raises NotImplementedError: if ``x_seq`` is neither ``float32`` nor ``float16``
        """
        requires_grad = x_seq.requires_grad or v_init.requires_grad
        device = x_seq.get_device()
        if x_seq.dtype == torch.float32:
            dtype = 'fp32'
            cp_dtype = np.float32
        elif x_seq.dtype == torch.float16:
            dtype = 'fp16'
            cp_dtype = np.half
        else:
            raise NotImplementedError

        use_pad = False
        if dtype == 'fp16' and v_init.numel() % 2 != 0:
            # only fp16 needs even numel because we use half2 to accelerate
            # when numel is odd, we will pad x_seq
            use_pad = True
            x_seq = F.pad(x_seq, (0, 1))  # [T, N] -> [T, N + 1]
            v_init = F.pad(v_init, (0, 1))  # [N] -> [N + 1]

        # one allocation, split into the three per-step buffers the kernel fills
        zero_shape = list(x_seq.shape)
        zero_shape[0] *= 3
        v_seq, h_seq, spike_seq = torch.split(torch.zeros(zero_shape, device=x_seq.device, dtype=x_seq.dtype), x_seq.shape[0])

        # row 0 is v_init, row t + 1 receives the potential after step t
        v_v_seq = torch.cat((v_init.unsqueeze(0), v_seq))

        with cuda_utils.DeviceEnvironment(device):
            numel = x_seq.numel()
            neuron_num = numel // x_seq.shape[0]

            threads = configure.cuda_threads
            if dtype == 'fp16':
                assert neuron_num % 2 == 0
                # we will take two neurons to calculate as one neuron in cuda half2
                blocks = cuda_utils.cal_blocks(neuron_num >> 1)
            else:
                blocks = cuda_utils.cal_blocks(neuron_num)

            cp_numel = cupy.asarray(numel)
            cp_neuron_num = cupy.asarray(neuron_num)
            cp_v_threshold = cupy.asarray(v_threshold, dtype=cp_dtype)
            cp_v_rest = cupy.asarray(v_rest, dtype=cp_dtype)
            cp_v_c = cupy.asarray(v_c, dtype=cp_dtype)
            cp_a0 = cupy.asarray(a0, dtype=cp_dtype)
            cp_reciprocal_tau = cupy.asarray(1.0 / tau, dtype=cp_dtype)
            cp_a0_over_tau = cupy.asarray(a0 / tau, dtype=cp_dtype)
            cp_neg_sum_v_rest_v_c = cupy.asarray(-v_rest - v_c, dtype=cp_dtype)

            hard_reset = v_reset is not None
            cp_v_reset = cupy.asarray(v_reset, dtype=cp_dtype) if hard_reset else None

            # argument order must match the kernel signature; v_reset exists only for hard reset
            kernel_args = [x_seq, v_v_seq, h_seq, spike_seq, cp_reciprocal_tau, cp_v_c, cp_a0, cp_v_threshold, cp_v_rest]
            if hard_reset:
                kernel_args.append(cp_v_reset)
            kernel_args += [cp_neuron_num, cp_numel]
            kernel_args = list(cuda_utils.get_contiguous(*kernel_args))
            # keep references to the buffers the kernel actually writes
            x_seq, v_v_seq, h_seq, spike_seq = kernel_args[:4]

            kernel = MultiStepQIFNodePTT.create_fptt_kernel(hard_reset, dtype)

            kernel(
                (blocks,), (threads,),
                cuda_utils.wrap_args_to_raw_kernel(
                    device,
                    *kernel_args
                )
            )

        if requires_grad:
            ctx.use_pad = use_pad
            if configure.save_spike_as_bool_in_neuron_kernel:
                # spikes go to the bool tensor cache instead of save_for_backward
                ctx.s_shape = spike_seq.shape
                ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike_seq)
                ctx.save_for_backward(h_seq, v_v_seq)
            else:
                ctx.save_for_backward(h_seq, spike_seq, v_v_seq)
            ctx.blocks = blocks
            ctx.threads = threads
            ctx.cp_numel = cp_numel
            ctx.cp_neuron_num = cp_neuron_num
            ctx.cp_a0_over_tau = cp_a0_over_tau
            ctx.cp_neg_sum_v_rest_v_c = cp_neg_sum_v_rest_v_c
            ctx.cp_reciprocal_tau = cp_reciprocal_tau
            ctx.cp_v_threshold = cp_v_threshold
            ctx.cp_v_reset = cp_v_reset
            ctx.detach_reset = detach_reset
            ctx.sg_cuda_code_fun = sg_cuda_code_fun

        if use_pad:
            return spike_seq[..., :-1], v_v_seq[1:, ..., :-1]
        else:
            return spike_seq, v_v_seq[1:, ]

    @staticmethod
    def backward(ctx, grad_spike_seq, grad_v_seq):
        """
        Run the backward kernel.

        :param ctx: autograd context populated by ``forward``
        :param grad_spike_seq: gradient w.r.t. the spike output of ``forward``
        :param grad_v_seq: gradient w.r.t. the membrane potential output of ``forward``
        :return: gradients for ``x_seq`` and ``v_init`` followed by ``None`` for the eight remaining
            ``forward`` inputs
        :raises NotImplementedError: if the incoming gradient is neither ``float32`` nor ``float16``
        """
        if ctx.use_pad:
            # forward padded one extra neuron, so the saved tensors are [T, N + 1] while the incoming
            # gradients are [T, N]; pad the gradients to the same width
            grad_spike_seq = F.pad(grad_spike_seq, (0, 1))
            grad_v_seq = F.pad(grad_v_seq, (0, 1))

        device = grad_spike_seq.get_device()
        if configure.save_spike_as_bool_in_neuron_kernel:
            spike_seq = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
            h_seq, v_v_seq = ctx.saved_tensors
        else:
            h_seq, spike_seq, v_v_seq = ctx.saved_tensors

        # a single zero buffer: every row but the last is grad_x_seq, the last row is grad_v_init
        zero_shape = list(grad_spike_seq.shape)
        zero_shape[0] += 1
        zero_data = torch.zeros(zero_shape, device=grad_spike_seq.device, dtype=grad_spike_seq.dtype)
        grad_x_seq = zero_data[0: -1]
        grad_v_init = zero_data[-1]

        hard_reset = ctx.cp_v_reset is not None

        if grad_spike_seq.dtype == torch.float32:
            dtype = 'fp32'
        elif grad_spike_seq.dtype == torch.float16:
            dtype = 'fp16'
        else:
            raise NotImplementedError

        kernel = MultiStepQIFNodePTT.create_bptt_kernel(ctx.sg_cuda_code_fun, hard_reset, ctx.detach_reset, dtype)

        with cuda_utils.DeviceEnvironment(device):
            # argument order must match the kernel signature; v_reset exists only for hard reset
            kernel_args = [grad_spike_seq, grad_v_seq, h_seq, spike_seq, v_v_seq, grad_x_seq, grad_v_init,
                           ctx.cp_a0_over_tau, ctx.cp_neg_sum_v_rest_v_c, ctx.cp_reciprocal_tau, ctx.cp_v_threshold]
            if hard_reset:
                kernel_args.append(ctx.cp_v_reset)
            kernel_args += [ctx.cp_neuron_num, ctx.cp_numel]
            kernel_args = list(cuda_utils.get_contiguous(*kernel_args))
            # keep references to the buffers the kernel actually writes
            grad_x_seq, grad_v_init = kernel_args[5], kernel_args[6]

            kernel(
                (ctx.blocks,), (ctx.threads,),
                cuda_utils.wrap_args_to_raw_kernel(
                    device,
                    *kernel_args
                )
            )

        if ctx.use_pad:
            return grad_x_seq[..., :-1], grad_v_init[..., :-1], None, None, None, None, None, None, None, None
        else:
            return grad_x_seq, grad_v_init, None, None, None, None, None, None, None, None
class MultiStepIzhikevichNodePTT(torch.autograd.Function):
    """
    Autograd function running an Izhikevich neuron over a whole input sequence with CuPy raw kernels.

    Per time step the forward kernel computes

        ``h = v + reciprocal_tau * (x + a0 * (v - v_rest) * (v - v_c) - w)``
        ``z = w + reciprocal_tau_w * (a * (h - v_rest) - w)``

    emits a spike where ``h >= v_threshold``, resets ``v`` (hard or soft) and stores
    ``w = z + b * spike``. Only ``fp32`` kernels exist.
    """

    @staticmethod
    def create_fptt_kernel(hard_reset: bool, dtype: str):
        """
        Build the forward kernel.

        :param hard_reset: ``True`` adds the ``v_reset`` kernel argument and resets to it on a spike
        :param dtype: must be ``'fp32'``
        :return: the ``cupy.RawKernel`` for this configuration
        :raises TypeError: for any other ``dtype``
        """
        kernel_name = f'IzhikevichNode_fptt_{"hard" if hard_reset else "soft"}Reset_{dtype}'

        if dtype == 'fp32':
            code = rf'''
            extern "C" __global__
            void {kernel_name}(const float* x_seq, float* v_v_seq, float* h_seq, float* w_w_seq, float* spike_seq,
            const float & reciprocal_tau,
            const float & a0,
            const float & v_c,
            const float & v_threshold,
            const float & v_rest,
            const float & reciprocal_tau_w,
            const float & a,
            const float & b, {'const float & v_reset,' if hard_reset else ''}
            const int & neuron_num, const int & numel)
            '''

            code += r'''
            {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
            if (index < neuron_num)
            {
                const int dt = neuron_num;
                for(int mem_offset = 0; mem_offset < numel; mem_offset += neuron_num)
                {
                    const int t = index + mem_offset;
                    h_seq[t] = v_v_seq[t] + reciprocal_tau * (x_seq[t] + a0 * (v_v_seq[t] - v_rest) * (v_v_seq[t] - v_c) - w_w_seq[t]);
                    const float z = w_w_seq[t] + reciprocal_tau_w * (a * (h_seq[t] - v_rest) - w_w_seq[t]);
                    if (h_seq[t] >= v_threshold)
                    {
                        spike_seq[t] = 1.0f;
            '''

            if hard_reset:
                code += r'''
                        v_v_seq[t + dt] = v_reset;
                '''
            else:
                code += r'''
                        v_v_seq[t + dt] = h_seq[t] - v_threshold;
                '''

            code += r'''
                    }
                    else
                    {
                        spike_seq[t] = 0.0f;
                        v_v_seq[t + dt] = h_seq[t];
                    }
                    w_w_seq[t + dt] = z + b * spike_seq[t];
                }
            }
            }
            '''
        else:
            raise TypeError

        return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)

    @staticmethod
    def create_bptt_kernel(sg_cuda_code_fun, hard_reset: bool, detach_reset: bool, dtype: str):
        """
        Build the backward kernel.

        :param sg_cuda_code_fun: called as ``sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)``;
            its return value is spliced into the kernel and must define ``grad_s_to_h`` from ``over_th``
        :param hard_reset: ``True`` adds the ``v_reset`` kernel argument
        :param detach_reset: ``True`` drops the surrogate-gradient term from ``grad_v_to_h``
        :param dtype: must be ``'fp32'``
        :return: the ``cupy.RawKernel`` for this configuration
        :raises TypeError: for any other ``dtype``
        """
        kernel_name = f'IzhikevichNode_bptt_{"hard" if hard_reset else "soft"}Reset_{"detachReset" if detach_reset else ""}_{dtype}'

        code_grad_s_to_h = sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)

        if dtype == 'fp32':
            code = fr'''
            extern "C" __global__
            void {kernel_name}(
            const float* grad_spike_seq, const float* grad_v_seq,
            const float* grad_w_seq, const float* h_seq,
            const float* spike_seq, const float* v_v_seq,
            float* grad_x_seq, float* grad_v_init, float* grad_w_init,
            const float & reciprocal_tau, const float & one_sub_reciprocal_tau_w,
            const float & a_over_tau_w, const float & a0_over_tau,
            const float & b, const float & neg_sum_v_rest_v_c,
            const float & v_threshold, {'const float & v_reset,' if hard_reset else ''}
            const int & neuron_num, const int & numel)
            '''

            code += r'''
            {
                const int index = blockIdx.x * blockDim.x + threadIdx.x;
                if (index < neuron_num)
                {
                    float grad_h = 0.0f;  // grad_h will be used recursively
                    float grad_w = 0.0f;  // grad_w will be used recursively
                    for(int mem_offset = numel - neuron_num; mem_offset >= 0; mem_offset -= neuron_num)
                    {
                        const int t = index + mem_offset;
                        const float over_th = h_seq[t] - v_threshold;
            '''
            code += code_grad_s_to_h
            if detach_reset:
                if hard_reset:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - spike_seq[t];
                    '''
                else:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f;
                    '''
            else:
                if hard_reset:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - spike_seq[t] + (v_reset - h_seq[t]) * grad_s_to_h;
                    '''
                else:
                    code_grad_v_to_h = r'''
                        const float grad_v_to_h = 1.0f - v_threshold * grad_s_to_h;
                    '''

            code += code_grad_v_to_h
            # Two corrections relative to the straightforward recurrence:
            #   * grad_w at step t also receives the external gradient grad_w_seq[t] of the w output;
            #   * grad_v_init is dL/dh at the first step times dh/dv_init, so it is built from grad_h
            #     itself (grad_x_seq[index] already carries the extra factor reciprocal_tau).
            code += r'''
                        grad_w = grad_w_seq[t] - reciprocal_tau * grad_h + one_sub_reciprocal_tau_w * grad_w;
                        grad_h = grad_w * (a_over_tau_w + b * grad_s_to_h) + ((1 + a0_over_tau * (2.0f * v_v_seq[t + neuron_num] + neg_sum_v_rest_v_c)) * grad_h + grad_v_seq[t]) * grad_v_to_h + grad_spike_seq[t] * grad_s_to_h;
                        grad_x_seq[t] = grad_h * reciprocal_tau;
                    }
                    grad_v_init[index] = grad_h * (1.0f + a0_over_tau * (2.0f * v_v_seq[index] + neg_sum_v_rest_v_c));
                    grad_w_init[index] = -reciprocal_tau * grad_h + one_sub_reciprocal_tau_w * grad_w;
                }
            }
            '''
        else:
            raise TypeError

        return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)

    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, w_init: torch.Tensor, tau: float, v_threshold: float, v_reset: float, v_rest: float, a: float, b: float, tau_w: float, v_c: float, a0: float, detach_reset: bool, sg_cuda_code_fun):
        """
        Run the forward kernel over a whole input sequence.

        :param ctx: autograd context; filled with everything ``backward`` reads when a gradient is needed
        :param x_seq: ``float32`` input sequence with time on dim 0
        :param v_init: initial membrane potential (one time step of ``x_seq``)
        :param w_init: initial value of ``w`` (one time step of ``x_seq``)
        :param tau: time constant of ``v``; the kernel receives ``1 / tau``
        :param v_threshold: firing threshold
        :param v_reset: reset value; ``None`` selects the soft-reset kernel
        :param v_rest: ``v_rest`` of the update equations
        :param a: ``a`` of the ``w`` update
        :param b: amount added to ``w`` on a spike
        :param tau_w: time constant of ``w``; the kernel receives ``1 / tau_w``
        :param v_c: ``v_c`` of the charge equation
        :param a0: ``a0`` of the charge equation
        :param detach_reset: stored on ``ctx`` for ``backward``
        :param sg_cuda_code_fun: stored on ``ctx`` for ``backward``
        :return: ``(spike_seq, v_seq, w_seq)``
        :raises NotImplementedError: if ``x_seq`` is not ``float32``
        """
        # autograd calls backward() as soon as ANY tensor input requires grad, so w_init must take
        # part in this test as well; otherwise ctx stays empty and backward() fails.
        requires_grad = x_seq.requires_grad or v_init.requires_grad or w_init.requires_grad
        device = x_seq.get_device()
        if x_seq.dtype == torch.float32:
            dtype = 'fp32'
            cp_dtype = np.float32
        else:
            raise NotImplementedError

        # one allocation, split into the four per-step buffers the kernel fills
        zero_shape = list(x_seq.shape)
        zero_shape[0] *= 4
        v_seq, h_seq, w_seq, spike_seq = torch.split(torch.zeros(zero_shape, device=x_seq.device, dtype=x_seq.dtype), x_seq.shape[0])

        # row 0 is the initial state, row t + 1 receives the state after step t
        v_v_seq = torch.cat((v_init.unsqueeze(0), v_seq))
        w_w_seq = torch.cat((w_init.unsqueeze(0), w_seq))

        with cuda_utils.DeviceEnvironment(device):
            numel = x_seq.numel()
            neuron_num = numel // x_seq.shape[0]

            threads = configure.cuda_threads
            blocks = cuda_utils.cal_blocks(neuron_num)

            cp_numel = cupy.asarray(numel)
            cp_neuron_num = cupy.asarray(neuron_num)
            cp_v_threshold = cupy.asarray(v_threshold, dtype=cp_dtype)
            cp_v_rest = cupy.asarray(v_rest, dtype=cp_dtype)
            cp_v_c = cupy.asarray(v_c, dtype=cp_dtype)
            cp_a0 = cupy.asarray(a0, dtype=cp_dtype)
            cp_a = cupy.asarray(a, dtype=cp_dtype)
            cp_b = cupy.asarray(b, dtype=cp_dtype)
            cp_reciprocal_tau = cupy.asarray(1. / tau, dtype=cp_dtype)
            cp_reciprocal_tau_w = cupy.asarray(1. / tau_w, dtype=cp_dtype)
            cp_a0_over_tau = cupy.asarray(a0 / tau, dtype=cp_dtype)
            cp_a_over_tau_w = cupy.asarray(a / tau_w, dtype=cp_dtype)
            cp_one_sub_reciprocal_tau_w = cupy.asarray(1. - 1. / tau_w, dtype=cp_dtype)
            cp_neg_sum_v_rest_v_c = cupy.asarray(-v_rest - v_c, dtype=cp_dtype)

            hard_reset = v_reset is not None
            cp_v_reset = cupy.asarray(v_reset, dtype=cp_dtype) if hard_reset else None

            # argument order must match the kernel signature; v_reset exists only for hard reset
            kernel_args = [x_seq, v_v_seq, h_seq, w_w_seq, spike_seq, cp_reciprocal_tau, cp_a0, cp_v_c,
                           cp_v_threshold, cp_v_rest, cp_reciprocal_tau_w, cp_a, cp_b]
            if hard_reset:
                kernel_args.append(cp_v_reset)
            kernel_args += [cp_neuron_num, cp_numel]
            kernel_args = list(cuda_utils.get_contiguous(*kernel_args))
            # keep references to the buffers the kernel actually writes
            x_seq, v_v_seq, h_seq, w_w_seq, spike_seq = kernel_args[:5]

            kernel = MultiStepIzhikevichNodePTT.create_fptt_kernel(hard_reset, dtype)

            kernel(
                (blocks,), (threads,),
                cuda_utils.wrap_args_to_raw_kernel(
                    device,
                    *kernel_args
                )
            )

        if requires_grad:
            if configure.save_spike_as_bool_in_neuron_kernel:
                # spikes go to the bool tensor cache instead of save_for_backward
                ctx.s_shape = spike_seq.shape
                ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike_seq)
                ctx.save_for_backward(h_seq, v_v_seq)
            else:
                ctx.save_for_backward(h_seq, spike_seq, v_v_seq)
            ctx.blocks = blocks
            ctx.threads = threads
            ctx.cp_numel = cp_numel
            ctx.cp_neuron_num = cp_neuron_num
            ctx.cp_reciprocal_tau = cp_reciprocal_tau
            ctx.cp_one_sub_reciprocal_tau_w = cp_one_sub_reciprocal_tau_w
            ctx.cp_a_over_tau_w = cp_a_over_tau_w
            ctx.cp_a0_over_tau = cp_a0_over_tau
            ctx.cp_b = cp_b
            ctx.cp_neg_sum_v_rest_v_c = cp_neg_sum_v_rest_v_c
            ctx.cp_v_threshold = cp_v_threshold
            ctx.cp_v_reset = cp_v_reset
            ctx.detach_reset = detach_reset
            ctx.sg_cuda_code_fun = sg_cuda_code_fun

        return spike_seq, v_v_seq[1:, ], w_w_seq[1:, ]

    @staticmethod
    def backward(ctx, grad_spike_seq, grad_v_seq, grad_w_seq):
        """
        Run the backward kernel.

        :param ctx: autograd context populated by ``forward``
        :param grad_spike_seq: gradient w.r.t. the spike output of ``forward``
        :param grad_v_seq: gradient w.r.t. the membrane potential output of ``forward``
        :param grad_w_seq: gradient w.r.t. the ``w`` output of ``forward``
        :return: gradients for ``x_seq``, ``v_init`` and ``w_init`` followed by ``None`` for the eleven
            remaining ``forward`` inputs
        :raises NotImplementedError: if the incoming gradient is not ``float32``
        """
        device = grad_spike_seq.get_device()
        if configure.save_spike_as_bool_in_neuron_kernel:
            spike_seq = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
            h_seq, v_v_seq = ctx.saved_tensors
        else:
            h_seq, spike_seq, v_v_seq = ctx.saved_tensors

        # a single zero buffer: the last two rows are grad_v_init and grad_w_init, the rest grad_x_seq
        zero_shape = list(grad_spike_seq.shape)
        zero_shape[0] += 2
        zero_data = torch.zeros(zero_shape, device=grad_spike_seq.device, dtype=grad_spike_seq.dtype)
        grad_x_seq = zero_data[0: -2]
        grad_v_init = zero_data[-2]
        grad_w_init = zero_data[-1]

        hard_reset = ctx.cp_v_reset is not None

        if grad_spike_seq.dtype == torch.float32:
            dtype = 'fp32'
        else:
            raise NotImplementedError

        kernel = MultiStepIzhikevichNodePTT.create_bptt_kernel(ctx.sg_cuda_code_fun, hard_reset, ctx.detach_reset, dtype)

        with cuda_utils.DeviceEnvironment(device):
            # argument order must match the kernel signature; v_reset exists only for hard reset
            kernel_args = [grad_spike_seq, grad_v_seq, grad_w_seq, h_seq, spike_seq, v_v_seq,
                           grad_x_seq, grad_v_init, grad_w_init,
                           ctx.cp_reciprocal_tau, ctx.cp_one_sub_reciprocal_tau_w, ctx.cp_a_over_tau_w,
                           ctx.cp_a0_over_tau, ctx.cp_b, ctx.cp_neg_sum_v_rest_v_c, ctx.cp_v_threshold]
            if hard_reset:
                kernel_args.append(ctx.cp_v_reset)
            kernel_args += [ctx.cp_neuron_num, ctx.cp_numel]
            kernel_args = list(cuda_utils.get_contiguous(*kernel_args))
            # keep references to the buffers the kernel actually writes
            grad_x_seq, grad_v_init, grad_w_init = kernel_args[6:9]

            kernel(
                (ctx.blocks,), (ctx.threads,),
                cuda_utils.wrap_args_to_raw_kernel(
                    device,
                    *kernel_args
                )
            )

        return grad_x_seq, grad_v_init, grad_w_init, None, None, None, None, None, None, None, None, None, None, None
class MultiStepEIFNodePTT(torch.autograd.Function):
    """Autograd function that runs a multi-step EIF neuron with CuPy raw kernels.

    ``forward`` launches one CUDA kernel that iterates over every time step of
    the input sequence (fptt); ``backward`` launches the matching kernel that
    walks the sequence in reverse (bptt). Kernels are generated as CUDA source
    strings for ``fp32`` and for ``fp16``; the ``fp16`` variant handles two
    neurons per thread through ``half2``. A hard-reset kernel is used when
    ``v_reset`` is given and a soft-reset kernel when it is ``None``.
    """

    @staticmethod
    def create_fptt_kernel(hard_reset: bool, dtype: str):
        """Build the forward (fptt) CUDA kernel.

        For each time step the generated kernel computes

        ``h = v + reciprocal_tau * (x - v + v_rest + delta_T * exp((v - theta_rh) / delta_T))``,

        emits a spike where ``h >= v_threshold`` and writes the membrane
        potential for the next step into the following time slot of
        ``v_v_seq``: ``v_reset`` (hard reset) or ``h - v_threshold`` (soft
        reset) where a spike was emitted, ``h`` otherwise.

        :param hard_reset: generate the hard-reset variant, whose signature has
            an extra ``v_reset`` argument, instead of the soft-reset variant
        :param dtype: ``'fp32'`` or ``'fp16'``
        :return: the ``cupy.RawKernel`` built from the generated source
        :raises TypeError: if ``dtype`` is neither ``'fp32'`` nor ``'fp16'``
        """
        kernel_name = f'EIFNode_fptt_{"hard" if hard_reset else "soft"}Reset_{dtype}'

        if dtype == 'fp32':
            # Kernel signature; ``v_reset`` is only part of it for hard reset.
            code = rf'''
extern "C" __global__
void {kernel_name}(const float* x_seq, float* v_v_seq, float* h_seq, float* spike_seq,
const float & reciprocal_tau,
const float & delta_T,
const float & theta_rh,
const float & v_threshold,
const float & v_rest, {'const float & v_reset,' if hard_reset else ''}
const int & neuron_num, const int & numel)
'''
            # One thread per neuron; the loop advances one time step
            # (``neuron_num`` elements) per iteration.
            code += r'''
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < neuron_num)
{
const int dt = neuron_num;
for(int mem_offset = 0; mem_offset < numel; mem_offset += neuron_num)
{
const int t = index + mem_offset;
h_seq[t] = v_v_seq[t] + reciprocal_tau * (x_seq[t] - v_v_seq[t] + v_rest + delta_T * expf((v_v_seq[t] - theta_rh) / delta_T));
if (h_seq[t] >= v_threshold)
{
spike_seq[t] = 1.0f;
'''
            # Membrane potential after a spike.
            if hard_reset:
                code += r'''
v_v_seq[t + dt] = v_reset;
'''
            else:
                code += r'''
v_v_seq[t + dt] = h_seq[t] - v_threshold;
'''
            # No spike: the potential is carried over unchanged.
            code += r'''
}
else
{
spike_seq[t] = 0.0f;
v_v_seq[t + dt] = h_seq[t];
}
}
}
}
'''
        elif dtype == 'fp16':
            code = rf'''
#include <cuda_fp16.h>
extern "C" __global__
void {kernel_name}(const half2* x_seq, half2* v_v_seq, half2* h_seq, half2* spike_seq,
const half & reciprocal_tau,
const half & delta_T,
const half & theta_rh,
const half & v_threshold,
const half & v_rest, {'const half & v_reset,' if hard_reset else ''}
const int & neuron_num, const int & numel)
'''
            # Each thread handles one half2 (two neurons), so the per-step
            # stride and the element count are both halved. Scalars are
            # broadcast to both half2 lanes.
            code += r'''
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = neuron_num >> 1;
if (index < stride)
{
const int numel_2 = numel >> 1;
const half2 reciprocal_tau_half2 = __half2half2(reciprocal_tau);
const half2 delta_T_half2 = __half2half2(delta_T);
const half2 theta_rh_half2 = __half2half2(theta_rh);
const half2 v_threshold_half2 = __half2half2(v_threshold);
const half2 v_rest_half2 = __half2half2(v_rest);
'''
            if hard_reset:
                code += r'''
const half2 v_reset_half2 = __half2half2(v_reset);
'''
            # The spike is the 1.0/0.0 result of the half2 comparison, so the
            # reset below is expressed arithmetically instead of branching.
            code += r'''
for(int mem_offset = 0; mem_offset < numel_2; mem_offset += stride)
{
const int t = index + mem_offset;
h_seq[t] = __hfma2(__hfma2(h2exp(__h2div(__hsub2(v_v_seq[t], theta_rh_half2), delta_T_half2)), delta_T_half2, __hadd2(__hsub2(x_seq[t], v_v_seq[t]), v_rest_half2)), reciprocal_tau_half2, v_v_seq[t]);
spike_seq[t] = __hgeu2(h_seq[t], v_threshold_half2);
'''
            if hard_reset:
                code += r'''
v_v_seq[t + stride] = __hadd2(__hmul2(spike_seq[t], v_reset_half2), __hmul2(__hsub2(__float2half2_rn(1.0f), spike_seq[t]), h_seq[t]));
'''
            else:
                code += r'''
v_v_seq[t + stride] = __hadd2(__hmul2(spike_seq[t], __hsub2(h_seq[t], v_threshold_half2)), __hmul2(__hsub2(__float2half2_rn(1.0f), spike_seq[t]), h_seq[t]));
'''
            code += r'''
}
}
}
'''
        else:
            raise TypeError(f"unsupported dtype {dtype!r}; expected 'fp32' or 'fp16'")

        return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)

    @staticmethod
    def create_bptt_kernel(sg_cuda_code_fun, hard_reset: bool, detach_reset: bool, dtype: str):
        """Build the backward (bptt) CUDA kernel.

        The kernel walks the sequence from the last time step to the first and
        keeps the running gradient with respect to ``h`` in ``grad_h``:

        ``grad_h = grad_spike * grad_s_to_h + (grad_v + grad_h_next * dh_next_dv) * grad_v_to_h``

        with ``dh_next_dv = one_sub_reciprocal_tau + reciprocal_tau * exp((v - theta_rh) * reciprocal_delta_T)``.
        It writes ``grad_x_seq[t] = grad_h * reciprocal_tau`` for every step
        and finally ``grad_v_init = grad_h * dh_dv`` evaluated at ``v_init``.

        :param sg_cuda_code_fun: callable invoked as
            ``sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)``;
            its return value is spliced into the kernel and must define
            ``grad_s_to_h`` from ``over_th``
        :param hard_reset: generate the hard-reset variant (extra ``v_reset``
            kernel argument) instead of the soft-reset variant
        :param detach_reset: when true, ``grad_v_to_h`` omits the term that
            depends on ``grad_s_to_h`` (no gradient flows through the reset)
        :param dtype: ``'fp32'`` or ``'fp16'``
        :return: the ``cupy.RawKernel`` built from the generated source
        :raises TypeError: if ``dtype`` is neither ``'fp32'`` nor ``'fp16'``
        """
        kernel_name = f'EIFNode_bptt_{"hard" if hard_reset else "soft"}Reset_{"detachReset" if detach_reset else ""}_{dtype}'
        code_grad_s_to_h = sg_cuda_code_fun(x='over_th', y='grad_s_to_h', dtype=dtype)

        if dtype == 'fp32':
            code = fr'''
extern "C" __global__
void {kernel_name}(
const float* grad_spike_seq, const float* grad_v_seq, const float* h_seq, const float* spike_seq, const float* v_v_seq,
float* grad_x_seq, float* grad_v_init,
const float & theta_rh, const float & reciprocal_delta_T,
const float & reciprocal_tau, const float & one_sub_reciprocal_tau,
const float & v_threshold, {'const float & v_reset,' if hard_reset else ''}
const int & neuron_num, const int & numel)
'''
            code += r'''
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < neuron_num)
{
float grad_h = 0.0f; // grad_h will be used recursively
for(int mem_offset = numel - neuron_num; mem_offset >= 0; mem_offset -= neuron_num)
{
const int t = index + mem_offset;
const float over_th = h_seq[t] - v_threshold;
'''
            # Surrogate-gradient snippet: defines grad_s_to_h from over_th.
            code += code_grad_s_to_h

            # Derivative of the post-reset potential with respect to h.
            if detach_reset:
                if hard_reset:
                    code_grad_v_to_h = r'''
const float grad_v_to_h = 1.0f - spike_seq[t];
'''
                else:
                    code_grad_v_to_h = r'''
const float grad_v_to_h = 1.0f;
'''
            else:
                if hard_reset:
                    code_grad_v_to_h = r'''
const float grad_v_to_h = 1.0f - spike_seq[t] + (v_reset - h_seq[t]) * grad_s_to_h;
'''
                else:
                    code_grad_v_to_h = r'''
const float grad_v_to_h = 1.0f - v_threshold * grad_s_to_h;
'''
            code += code_grad_v_to_h

            # grad_v_init must be built from grad_h (which holds the t = 0
            # value once the loop ends). grad_x_seq[index] already carries a
            # factor reciprocal_tau, so using it would scale the gradient of
            # v_init by an extra 1 / tau.
            code += r'''
grad_h = grad_spike_seq[t] * grad_s_to_h + (grad_v_seq[t] + grad_h * (one_sub_reciprocal_tau + reciprocal_tau * expf((v_v_seq[t + neuron_num] - theta_rh) * reciprocal_delta_T))) * grad_v_to_h;
grad_x_seq[t] = grad_h * reciprocal_tau;
}
grad_v_init[index] = grad_h * (one_sub_reciprocal_tau + reciprocal_tau * expf((v_v_seq[index] - theta_rh) * reciprocal_delta_T));
}
}
'''
        elif dtype == 'fp16':
            code = fr'''
#include <cuda_fp16.h>
extern "C" __global__
void {kernel_name}(
const half2* grad_spike_seq, const half2* grad_v_seq, const half2* h_seq, const half2* spike_seq, const half2* v_v_seq,
half2* grad_x_seq, half2* grad_v_init,
const half & theta_rh, const half & reciprocal_delta_T,
const half & reciprocal_tau, const half & one_sub_reciprocal_tau,
const half & v_threshold, {'const half & v_reset,' if hard_reset else ''}
const int & neuron_num, const int & numel)
'''
            # Two neurons per thread (half2); scalars are broadcast to both lanes.
            code += r'''
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = neuron_num >> 1;
if (index < stride)
{
const half2 reciprocal_tau_half2 = __half2half2(reciprocal_tau);
const half2 one_sub_reciprocal_tau_half2 = __half2half2(one_sub_reciprocal_tau);
const half2 reciprocal_delta_T_half2 = __half2half2(reciprocal_delta_T);
const half2 theta_rh_half2 = __half2half2(theta_rh);
const half2 v_threshold_half2 = __half2half2(v_threshold);
'''
            if hard_reset:
                code += r'''
const half2 v_reset_half2 = __half2half2(v_reset);
'''
            code += r'''
half2 grad_h = __float2half2_rn(0.0f); // grad_h will be used recursively
for(int mem_offset = (numel >> 1) - stride; mem_offset >= 0; mem_offset -= stride)
{
const int t = index + mem_offset;
const half2 over_th = __hsub2(h_seq[t], v_threshold_half2);
'''
            # Surrogate-gradient snippet: defines grad_s_to_h from over_th.
            code += code_grad_s_to_h

            # Derivative of the post-reset potential with respect to h.
            if detach_reset:
                if hard_reset:
                    code_grad_v_to_h = r'''
const half2 grad_v_to_h = __hsub2(__float2half2_rn(1.0f), spike_seq[t]);
'''
                else:
                    code_grad_v_to_h = r'''
const half2 grad_v_to_h = __float2half2_rn(1.0f);
'''
            else:
                if hard_reset:
                    code_grad_v_to_h = r'''
const half2 grad_v_to_h = __hfma2(__hsub2(v_reset_half2, h_seq[t]), grad_s_to_h, __hsub2(__float2half2_rn(1.0f), spike_seq[t]));
'''
                else:
                    code_grad_v_to_h = r'''
const half2 grad_v_to_h = __hsub2(__float2half2_rn(1.0f), __hmul2(v_threshold_half2, grad_s_to_h));
'''
            code += code_grad_v_to_h

            # Same correction as in the fp32 kernel: multiply by grad_h, not by
            # grad_x_seq[index], which already includes reciprocal_tau.
            code += r'''
grad_h = __hfma2(__hfma2(__hfma2(h2exp(__hmul2(__hsub2(v_v_seq[t + stride], theta_rh_half2), reciprocal_delta_T_half2)), reciprocal_tau_half2, one_sub_reciprocal_tau_half2), grad_h, grad_v_seq[t]), grad_v_to_h, __hmul2(grad_spike_seq[t], grad_s_to_h));
grad_x_seq[t] = __hmul2(grad_h, reciprocal_tau_half2);
}
grad_v_init[index] = __hmul2(__hfma2(h2exp(__hmul2(__hsub2(v_v_seq[index], theta_rh_half2), reciprocal_delta_T_half2)), reciprocal_tau_half2, one_sub_reciprocal_tau_half2), grad_h);
}
}
'''
        else:
            raise TypeError(f"unsupported dtype {dtype!r}; expected 'fp32' or 'fp16'")

        return cupy.RawKernel(code, kernel_name, options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)

    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, tau: float, v_threshold: float, v_reset: float, v_rest: float, theta_rh: float, delta_T: float, detach_reset: bool, sg_cuda_code_fun):
        """Run the forward kernel over the whole input sequence.

        :param ctx: autograd context; when a gradient is required it receives
            the launch configuration, the CuPy scalars and the saved tensors
            that ``backward`` needs
        :param x_seq: input sequence with time as the first dimension;
            ``torch.float32`` or ``torch.float16``
        :param v_init: initial membrane potential (one value per neuron)
        :param tau: time constant; the kernels receive ``1 / tau`` and
            ``1 - 1 / tau``
        :param v_threshold: firing threshold
        :param v_reset: reset value for hard reset, or ``None`` for soft reset
        :param v_rest: resting potential used in the charge equation
        :param theta_rh: offset inside the exponential term
        :param delta_T: scale of the exponential term
        :param detach_reset: stored for ``backward``; selects the bptt kernel
            variant without gradient through the reset
        :param sg_cuda_code_fun: surrogate-gradient code generator, stored for
            ``backward``
        :return: ``(spike_seq, v_seq)`` where ``v_seq`` holds the membrane
            potential after each step
        :raises NotImplementedError: if ``x_seq`` is neither float32 nor float16
        """
        requires_grad = x_seq.requires_grad or v_init.requires_grad
        device = x_seq.get_device()

        if x_seq.dtype == torch.float32:
            dtype = 'fp32'
            cp_dtype = np.float32
        elif x_seq.dtype == torch.float16:
            dtype = 'fp16'
            cp_dtype = np.half
        else:
            raise NotImplementedError(f'unsupported input dtype {x_seq.dtype}; expected torch.float32 or torch.float16')

        use_pad = False
        if dtype == 'fp16' and v_init.numel() % 2 != 0:
            # only fp16 needs even numel because we use half2 to accelerate
            # when numel is odd, we will pad x_seq
            use_pad = True
            x_seq = F.pad(x_seq, (0, 1))  # [T, N] -> [T, N + 1]
            v_init = F.pad(v_init, (0, 1))  # [N] -> [N + 1]

        # One allocation, split along time into the three output buffers.
        zero_shape = list(x_seq.shape)
        zero_shape[0] *= 3
        v_seq, h_seq, spike_seq = torch.split(torch.zeros(zero_shape, device=x_seq.device, dtype=x_seq.dtype), x_seq.shape[0])

        # v_v_seq = [v_init, v_seq]: slot t holds the potential before step t,
        # slot t + 1 the potential after it.
        v_v_seq = torch.cat((v_init.unsqueeze(0), v_seq))

        with cuda_utils.DeviceEnvironment(device):
            numel = x_seq.numel()
            neuron_num = numel // x_seq.shape[0]

            threads = configure.cuda_threads
            if dtype == 'fp16':
                assert neuron_num % 2 == 0
                # we will take two neurons to calculate as one neuron in cuda half2
                blocks = cuda_utils.cal_blocks(neuron_num >> 1)
            else:
                blocks = cuda_utils.cal_blocks(neuron_num)

            cp_numel = cupy.asarray(numel)
            cp_neuron_num = cupy.asarray(neuron_num)
            cp_v_threshold = cupy.asarray(v_threshold, dtype=cp_dtype)
            cp_v_rest = cupy.asarray(v_rest, dtype=cp_dtype)
            cp_theta_rh = cupy.asarray(theta_rh, dtype=cp_dtype)
            cp_delta_T = cupy.asarray(delta_T, dtype=cp_dtype)
            cp_reciprocal_delta_T = cupy.asarray(1. / delta_T, dtype=cp_dtype)
            cp_reciprocal_tau = cupy.asarray(1./tau, dtype=cp_dtype)
            cp_one_sub_reciprocal_tau = cupy.asarray(1. - 1./tau, dtype=cp_dtype)

            # The argument order below must match the kernel signature.
            if v_reset is None:
                cp_v_reset = None
                hard_reset = False
                (x_seq, v_v_seq, h_seq, spike_seq,
                 cp_reciprocal_tau, cp_delta_T, cp_theta_rh, cp_v_threshold, cp_v_rest,
                 cp_neuron_num, cp_numel) = cuda_utils.get_contiguous(
                    x_seq, v_v_seq, h_seq, spike_seq,
                    cp_reciprocal_tau, cp_delta_T, cp_theta_rh, cp_v_threshold, cp_v_rest,
                    cp_neuron_num, cp_numel)
                kernel_args = [x_seq, v_v_seq, h_seq, spike_seq,
                               cp_reciprocal_tau, cp_delta_T, cp_theta_rh, cp_v_threshold, cp_v_rest,
                               cp_neuron_num, cp_numel]
            else:
                cp_v_reset = cupy.asarray(v_reset, dtype=cp_dtype)
                hard_reset = True
                (x_seq, v_v_seq, h_seq, spike_seq,
                 cp_reciprocal_tau, cp_delta_T, cp_theta_rh, cp_v_threshold, cp_v_rest, cp_v_reset,
                 cp_neuron_num, cp_numel) = cuda_utils.get_contiguous(
                    x_seq, v_v_seq, h_seq, spike_seq,
                    cp_reciprocal_tau, cp_delta_T, cp_theta_rh, cp_v_threshold, cp_v_rest, cp_v_reset,
                    cp_neuron_num, cp_numel)
                kernel_args = [x_seq, v_v_seq, h_seq, spike_seq,
                               cp_reciprocal_tau, cp_delta_T, cp_theta_rh, cp_v_threshold, cp_v_rest, cp_v_reset,
                               cp_neuron_num, cp_numel]

            kernel = MultiStepEIFNodePTT.create_fptt_kernel(hard_reset, dtype)
            kernel(
                (blocks,), (threads,),
                cuda_utils.wrap_args_to_raw_kernel(
                    device,
                    *kernel_args
                )
            )

        if requires_grad:
            ctx.use_pad = use_pad
            if configure.save_spike_as_bool_in_neuron_kernel:
                # Spikes go to the bool tensor cache; only the key and shape
                # are kept on ctx and the float tensor is rebuilt in backward.
                ctx.s_shape = spike_seq.shape
                ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike_seq)
                ctx.save_for_backward(h_seq, v_v_seq)
            else:
                ctx.save_for_backward(h_seq, spike_seq, v_v_seq)
            ctx.blocks = blocks
            ctx.threads = threads
            ctx.cp_numel = cp_numel
            ctx.cp_neuron_num = cp_neuron_num
            ctx.cp_reciprocal_tau = cp_reciprocal_tau
            ctx.cp_one_sub_reciprocal_tau = cp_one_sub_reciprocal_tau
            ctx.cp_theta_rh = cp_theta_rh
            ctx.cp_reciprocal_delta_T = cp_reciprocal_delta_T
            ctx.cp_v_threshold = cp_v_threshold
            ctx.cp_v_reset = cp_v_reset
            ctx.detach_reset = detach_reset
            ctx.sg_cuda_code_fun = sg_cuda_code_fun

        if use_pad:
            # Drop the padded neuron again; v_v_seq[0] is v_init, not an output.
            return spike_seq[..., :-1], v_v_seq[1:, ..., :-1]
        else:
            return spike_seq, v_v_seq[1:, ]

    @staticmethod
    def backward(ctx, grad_spike_seq, grad_v_seq):
        """Run the bptt kernel and return the gradients of ``forward``'s inputs.

        :param ctx: autograd context filled by ``forward``
        :param grad_spike_seq: gradient with respect to the returned spikes
        :param grad_v_seq: gradient with respect to the returned potentials
        :return: a 10-tuple matching ``forward``'s arguments: gradients for
            ``x_seq`` and ``v_init`` followed by ``None`` for the eight
            non-tensor arguments
        :raises NotImplementedError: if the incoming gradient is neither
            float32 nor float16
        """
        if ctx.use_pad:
            # grad_spike_seq.shape = [T, N]
            # grad_v_seq.shape = [T, N]
            # h_seq.shape = [T, N + 1]
            # spike_seq.shape = [T, N + 1]
            grad_spike_seq = F.pad(grad_spike_seq, (0, 1))
            grad_v_seq = F.pad(grad_v_seq, (0, 1))

        device = grad_spike_seq.get_device()

        if configure.save_spike_as_bool_in_neuron_kernel:
            spike_seq = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
            h_seq, v_v_seq = ctx.saved_tensors
        else:
            h_seq, spike_seq, v_v_seq = ctx.saved_tensors

        # One allocation holding grad_x_seq (first T slots) and grad_v_init
        # (last slot).
        zero_shape = list(grad_spike_seq.shape)
        zero_shape[0] += 1
        zero_data = torch.zeros(zero_shape, device=grad_spike_seq.device, dtype=grad_spike_seq.dtype)
        grad_x_seq = zero_data[0: -1]
        grad_v_init = zero_data[-1]

        hard_reset = ctx.cp_v_reset is not None

        if grad_spike_seq.dtype == torch.float32:
            dtype = 'fp32'
        elif grad_spike_seq.dtype == torch.float16:
            dtype = 'fp16'
        else:
            raise NotImplementedError(f'unsupported gradient dtype {grad_spike_seq.dtype}; expected torch.float32 or torch.float16')

        kernel = MultiStepEIFNodePTT.create_bptt_kernel(ctx.sg_cuda_code_fun, hard_reset, ctx.detach_reset, dtype)

        with cuda_utils.DeviceEnvironment(device):
            # The argument order below must match the kernel signature.
            if hard_reset:
                (grad_spike_seq, grad_v_seq, h_seq, spike_seq, v_v_seq, grad_x_seq, grad_v_init,
                 ctx.cp_theta_rh, ctx.cp_reciprocal_delta_T,
                 ctx.cp_reciprocal_tau, ctx.cp_one_sub_reciprocal_tau,
                 ctx.cp_v_threshold, ctx.cp_v_reset,
                 ctx.cp_neuron_num, ctx.cp_numel) = cuda_utils.get_contiguous(
                    grad_spike_seq, grad_v_seq, h_seq, spike_seq, v_v_seq, grad_x_seq, grad_v_init,
                    ctx.cp_theta_rh, ctx.cp_reciprocal_delta_T,
                    ctx.cp_reciprocal_tau, ctx.cp_one_sub_reciprocal_tau,
                    ctx.cp_v_threshold, ctx.cp_v_reset,
                    ctx.cp_neuron_num, ctx.cp_numel)
                kernel_args = [grad_spike_seq, grad_v_seq, h_seq, spike_seq, v_v_seq, grad_x_seq, grad_v_init,
                               ctx.cp_theta_rh, ctx.cp_reciprocal_delta_T,
                               ctx.cp_reciprocal_tau, ctx.cp_one_sub_reciprocal_tau,
                               ctx.cp_v_threshold, ctx.cp_v_reset,
                               ctx.cp_neuron_num, ctx.cp_numel]
            else:
                (grad_spike_seq, grad_v_seq, h_seq, spike_seq, v_v_seq, grad_x_seq, grad_v_init,
                 ctx.cp_theta_rh, ctx.cp_reciprocal_delta_T,
                 ctx.cp_reciprocal_tau, ctx.cp_one_sub_reciprocal_tau,
                 ctx.cp_v_threshold,
                 ctx.cp_neuron_num, ctx.cp_numel) = cuda_utils.get_contiguous(
                    grad_spike_seq, grad_v_seq, h_seq, spike_seq, v_v_seq, grad_x_seq, grad_v_init,
                    ctx.cp_theta_rh, ctx.cp_reciprocal_delta_T,
                    ctx.cp_reciprocal_tau, ctx.cp_one_sub_reciprocal_tau,
                    ctx.cp_v_threshold,
                    ctx.cp_neuron_num, ctx.cp_numel)
                kernel_args = [grad_spike_seq, grad_v_seq, h_seq, spike_seq, v_v_seq, grad_x_seq, grad_v_init,
                               ctx.cp_theta_rh, ctx.cp_reciprocal_delta_T,
                               ctx.cp_reciprocal_tau, ctx.cp_one_sub_reciprocal_tau,
                               ctx.cp_v_threshold,
                               ctx.cp_neuron_num, ctx.cp_numel]

            kernel(
                (ctx.blocks,), (ctx.threads,),
                cuda_utils.wrap_args_to_raw_kernel(
                    device,
                    *kernel_args
                )
            )

        if ctx.use_pad:
            # Strip the gradient of the neuron that forward added as padding.
            return grad_x_seq[..., :-1], grad_v_init[..., :-1], None, None, None, None, None, None, None, None
        else:
            return grad_x_seq, grad_v_init, None, None, None, None, None, None, None, None
def save_cuda_codes(cu_file_path: str = './spikingjelly/activation_based/neuron_kernel_sample.cu'):
    """Write the generated CUDA kernel sources to a file for reading.

    For ``MultiStepIFNodePTT`` and every surrogate function listed in
    ``surrogate._has_cuda_``, the forward kernel source is written for each
    combination of ``hard_reset`` and ``dtype``, followed by the backward
    kernel source for both values of ``detach_reset``. Each source is preceded
    by a ``//`` comment line naming the combination.

    :param cu_file_path: path of the file to create or overwrite
    """
    neuron_classes = (MultiStepIFNodePTT,)
    flag_values = (True, False)
    precisions = ('fp32', 'fp16')

    with open(cu_file_path, 'w+') as out:
        emit = out.write
        emit('// This file is created by spikingjelly.activation_based.neuron_kernel.save_cuda_codes.\n')
        emit('// Note that codes in this file will not be executed This file is just created for reading.\n')

        for neuron_cls in neuron_classes:
            cls_name = neuron_cls.__name__
            emit('\n// ' + cls_name + '\n')

            for sg_cls in surrogate._has_cuda_:
                sg_name = sg_cls.__name__

                for is_hard in flag_values:
                    for precision in precisions:
                        # Forward kernel for this reset mode / precision.
                        emit(f'\n// {cls_name} fptt {sg_name}, hard_reset={is_hard}, dtype={precision}\n')
                        forward_kernel = neuron_cls.create_fptt_kernel(is_hard, precision)
                        emit(forward_kernel.code)

                        # Backward kernels, with and without detached reset.
                        for is_detached in flag_values:
                            emit(f'\n// {cls_name} bptt {sg_name}, hard_reset={is_hard}, dtype={precision}, detach_reset={is_detached}\n')
                            backward_kernel = neuron_cls.create_bptt_kernel(
                                sg_cls().cuda_code, is_hard, is_detached, precision)
                            emit(backward_kernel.code)