Source code for spikingjelly.activation_based.auto_cuda.neuron_kernel

import torch
import torch.nn.functional as F
import numpy as np
import logging

try:
    import cupy
except BaseException as e:
    logging.info(f'spikingjelly.activation_based.auto_cuda.neuron_kernel: {e}')
    cupy = None
    

from .. import cuda_utils, surrogate
from ... import configure
from typing import Callable, Iterable
from . import base, cfunction
import math

def neuronal_hard_reset(v_next: str, h: str, spike: str, v_reset: str, dtype: str = 'float'):
    if dtype == 'float':
        return f'{v_next} = {h} * (1.0f - {spike}) + {v_reset} * {spike};'
    elif dtype == 'half2':
        return f'{v_next} = __hfma2({h}, __hsub2(__float2half2_rn(1.0f), {spike}), __hmul2({v_reset}, {spike}));'
    else:
        raise NotImplementedError(dtype)

def neuronal_soft_reset(v_next: str, h: str, spike: str, v_th: str, dtype: str = 'float'):
    if dtype == 'float':
        return f'{v_next} = {h} - {v_th} * {spike};'
    elif dtype == 'half2':
        return f'{v_next} = __hsub2({h}, __hmul2({v_th}, {spike}));'
    else:
        raise NotImplementedError(dtype)

def neuronal_fire(spike: str, v: str, v_th: str, dtype: str = 'float'):
    if dtype == 'float':
        return cfunction.heaviside(y=spike, x=f'({v} - {v_th})', dtype=dtype)
    elif dtype == 'half2':
        return cfunction.heaviside(y=spike, x=f'__hsub2({v}, {v_th})', dtype=dtype)
    else:
        raise NotImplementedError(dtype)

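# Editorial note: the three helpers above do not compute anything at runtime; they only return
# CUDA source fragments as Python strings. A minimal sketch of what they emit for the 'float'
# dtype (the half2 branches emit the intrinsic-based equivalents). The demo function below is
# hypothetical and not part of the original module.
def _demo_reset_code_strings():
    hard = neuronal_hard_reset(v_next='v_v_seq[t + dt]', h='h_seq[t]', spike='spike_seq[t]',
                               v_reset='v_reset', dtype='float')
    # -> 'v_v_seq[t + dt] = h_seq[t] * (1.0f - spike_seq[t]) + v_reset * spike_seq[t];'
    soft = neuronal_soft_reset(v_next='v_v_seq[t + dt]', h='h_seq[t]', spike='spike_seq[t]',
                               v_th='v_th', dtype='float')
    # -> 'v_v_seq[t + dt] = h_seq[t] - v_th * spike_seq[t];'
    fire = neuronal_fire(spike='spike_seq[t]', v='h_seq[t]', v_th='v_th', dtype='float')
    # -> a Heaviside statement generated by cfunction.heaviside on (h_seq[t] - v_th)
    return hard, soft, fire
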
class NeuronFPTTKernel(base.CKernel2D):
    def __init__(self, hard_reset: bool, dtype: str):
        super().__init__(
            kernel_name=f'{self.__class__.__name__}_{dtype}_{"hard_reset" if hard_reset else "soft_reset"}',
            reverse=False)
        self.hard_reset = hard_reset
        self.dtype = dtype
        self.add_param(ctype=f'const {dtype} *', cname='x_seq')
        self.add_param(ctype=f'{dtype} *', cname='v_v_seq')
        self.add_param(ctype=f'{dtype} *', cname='h_seq')
        self.add_param(ctype=f'{dtype} *', cname='spike_seq')
        self.add_param(ctype=f'{dtype} &', cname='v_th')
        if hard_reset:
            self.add_param(ctype=f'{dtype} &', cname='v_reset')

    def neuronal_charge(self) -> str:
        """
        :return: CUDA code
        :rtype: str

        Returns CUDA code for calculating :math:`H[t] = f(X[t], V[t-1], ...)`.

        This function should define how ``h_seq[t]`` is calculated by ``x_seq[t], v_v_seq[t]`` and other params if
        the neuron needs them.

        For example, the IF neuron defines this function as:

        .. code-block:: python

            def neuronal_charge(self) -> str:
                # note that v_v_seq[t] is v_seq[t - dt]
                return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)
        """
        return '// neuronal_charge should be defined here!'

    @property
    def core(self):
        core_codes = base.CodeTyper(18)
        core_codes.append(self.neuronal_charge())
        core_codes.append(neuronal_fire(spike='spike_seq[t]', v='h_seq[t]', v_th='v_th', dtype=self.dtype))
        if self.hard_reset:
            core_codes.append(
                neuronal_hard_reset(v_next='v_v_seq[t + dt]', h='h_seq[t]', spike='spike_seq[t]', v_reset='v_reset',
                                    dtype=self.dtype))
        else:
            core_codes.append(
                neuronal_soft_reset(v_next='v_v_seq[t + dt]', h='h_seq[t]', spike='spike_seq[t]', v_th='v_th',
                                    dtype=self.dtype))
        self._core = core_codes.codes
        return self._core

class NeuronBPTTKernel(base.CKernel2D):
    def __init__(self, surrogate_function: Callable, hard_reset: bool, detach_reset: bool, dtype: str):
        super().__init__(
            kernel_name=f'{self.__class__.__name__}_{dtype}_{"hard_reset" if hard_reset else "soft_reset"}_{"detach_reset" if detach_reset else "nodetach_reset"}',
            reverse=True)
        self.surrogate_function = surrogate_function
        self.hard_reset = hard_reset
        self.detach_reset = detach_reset
        self.dtype = dtype
        self.add_param(ctype=f'const {dtype} *', cname='grad_spike_seq')
        self.add_param(ctype=f'const {dtype} *', cname='grad_v_seq')
        self.add_param(ctype=f'const {dtype} *', cname='h_seq')
        self.add_param(ctype=f'{dtype} *', cname='grad_x_seq')
        self.add_param(ctype=f'{dtype} *', cname='grad_v_init')
        self.add_param(ctype=f'{dtype} &', cname='v_th')
        if hard_reset:
            self.add_param(ctype=f'{dtype} &', cname='v_reset')

    @property
    def pre_core(self):
        codes = base.CodeTyper(16)
        if self.dtype == 'float':
            codes.append('float grad_h = 0.0f;')
        elif self.dtype == 'half2':
            codes.append(cfunction.float2half2(y='half2 grad_h', x='0.0f'))
        else:
            raise NotImplementedError(self.dtype)
        self._pre_core = codes.codes
        return self._pre_core

    @property
    def post_core(self):
        codes = base.CodeTyper(16)
        codes.append(self.grad_h_next_to_v())
        codes.append(cfunction.mul(z='grad_v_init[index]', x='grad_h', y='grad_h_next_to_v', dtype=self.dtype))
        self._post_core = codes.codes
        return self._post_core

    def grad_h_next_to_v(self) -> str:
        """
        :return: CUDA code
        :rtype: str

        Returns CUDA code for calculating :math:`\\frac{\\mathrm{d} H[t+1]}{\\mathrm{d} V[t]}`.

        This function should define how ``grad_h_next_to_v`` is calculated. Note that ``grad_h_next_to_v`` has not
        been declared. Thus, this function should also declare ``grad_h_next_to_v``.

        For example, the IF neuron defines this function as:

        .. code-block:: python

            def grad_h_next_to_v(self) -> str:
                return cfunction.constant(y=f'const {self.dtype} grad_h_next_to_v', x=1., dtype=self.dtype)
        """
        return '// grad_h_next_to_v should be defined here!'

    def grad_h_to_x(self) -> str:
        """
        :return: CUDA code
        :rtype: str

        Returns CUDA code for calculating :math:`\\frac{\\mathrm{d} H[t]}{\\mathrm{d} X[t]}`.

        This function should define how ``grad_h_to_x`` is calculated. Note that ``grad_h_to_x`` has not been
        declared. Thus, this function should also declare ``grad_h_to_x``.

        For example, the IF neuron defines this function as:

        .. code-block:: python

            def grad_h_to_x(self) -> str:
                return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)
        """
        return '// grad_h_to_x should be defined here!'

    @property
    def core(self):
        core_codes = base.CodeTyper(18)
        core_codes.append(cfunction.sub(z=f'const {self.dtype} over_th', x='h_seq[t]', y='v_th', dtype=self.dtype))
        core_codes.append(cfunction.heaviside(y=f'const {self.dtype} spike_seq_t', x='over_th', dtype=self.dtype))
        core_codes.append(self.surrogate_function(y=f'const {self.dtype} grad_s_to_h', x='over_th', dtype=self.dtype))

        if self.hard_reset:
            core_codes.append(
                cfunction.sub(z=f'{self.dtype} grad_v_to_h', x=cfunction.constant(y=None, x=1., dtype=self.dtype),
                              y='spike_seq_t', dtype=self.dtype))

            if not self.detach_reset:
                with base.CodeBlock(core_codes):
                    core_codes.append(
                        cfunction.sub(z=f'{self.dtype} temp_var', x='v_reset', y='h_seq[t]', dtype=self.dtype))
                    core_codes.append(cfunction.mul(z=f'temp_var', x='temp_var', y='grad_s_to_h', dtype=self.dtype))
                    core_codes.append(cfunction.add(z=f'grad_v_to_h', x='temp_var', y='grad_v_to_h', dtype=self.dtype))

        else:
            core_codes.append(f'{self.dtype} grad_v_to_h = {cfunction.constant(None, 1., dtype=self.dtype)}')

            if not self.detach_reset:
                with base.CodeBlock(core_codes):
                    core_codes.append(
                        cfunction.mul(z=f'{self.dtype} temp_var', x='v_th', y='grad_s_to_h', dtype=self.dtype))
                    core_codes.append(cfunction.sub(z=f'grad_v_to_h', x='grad_v_to_h', y='temp_var', dtype=self.dtype))

        core_codes.append(self.grad_h_next_to_v())
        core_codes.append(cfunction.mul(z=f'grad_h', x='grad_h', y='grad_h_next_to_v', dtype=self.dtype))
        core_codes.append(cfunction.add(z='grad_h', x='grad_v_seq[t]', y='grad_h', dtype=self.dtype))
        core_codes.append(cfunction.mul(z='grad_h', x='grad_h', y='grad_v_to_h', dtype=self.dtype))

        with base.CodeBlock(core_codes):
            core_codes.append(
                cfunction.mul(z=f'{self.dtype} temp_var', x='grad_spike_seq[t]', y='grad_s_to_h', dtype=self.dtype))
            core_codes.append(cfunction.add(z='grad_h', x='grad_h', y='temp_var', dtype=self.dtype))

        core_codes.append(self.grad_h_to_x())
        core_codes.append(cfunction.mul(z='grad_x_seq[t]', x='grad_h', y='grad_h_to_x', dtype=self.dtype))
        self._core = core_codes.codes
        return self._core

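# Editorial note: a worked restatement of the recursion that the ``core`` property above emits,
# written in the notation of the docstrings (S = spike, H = charge, V = membrane potential):
#
#     grad_h[t]     = (grad_v_seq[t] + grad_h[t+1] * dH[t+1]/dV[t]) * dV[t]/dH[t]
#                     + grad_spike_seq[t] * dS[t]/dH[t]
#     grad_x_seq[t] = grad_h[t] * dH[t]/dX[t]
#
# ``pre_core`` initializes grad_h to 0 for the step after the last one, the time loop runs backward
# (``reverse=True``), and ``post_core`` finally writes grad_v_init = grad_h[0] * dH[0]/dV_init.
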
class IFNodeFPTTKernel(NeuronFPTTKernel):
    def neuronal_charge(self) -> str:
        return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)

class IFNodeBPTTKernel(NeuronBPTTKernel):
    def grad_h_next_to_v(self) -> str:
        return cfunction.constant(y=f'const {self.dtype} grad_h_next_to_v', x=1., dtype=self.dtype)

    def grad_h_to_x(self) -> str:
        return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)

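# Editorial note: a minimal sketch of how the IF kernels above are put together. Accessing ``core``
# only renders CUDA source text (compilation is assumed to happen later, at call time), so no GPU
# is needed here. ``surrogate.Sigmoid().cuda_codes`` is assumed to satisfy the
# ``Callable(y=..., x=..., dtype=...)`` interface that ``NeuronBPTTKernel.core`` calls; the demo
# function itself is editorial and not part of the original module.
def _demo_if_kernel_codes():
    fptt = IFNodeFPTTKernel(hard_reset=True, dtype='float')
    bptt = IFNodeBPTTKernel(surrogate_function=surrogate.Sigmoid().cuda_codes,
                            hard_reset=True, detach_reset=False, dtype='float')
    # fptt.core chains neuronal_charge -> neuronal_fire -> neuronal_hard_reset for one time step;
    # bptt.core emits the surrogate gradient and the grad_h recursion sketched above.
    return fptt.core, bptt.core
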
def if_requires_grad(items: Iterable):
    requires_grad = False
    for item in items:
        if isinstance(item, torch.Tensor):
            if item.requires_grad:
                requires_grad = True
                break
    return requires_grad

def scalar_to_cupy(py_dict: dict, ref: str = 'x_seq'):
    device = py_dict[ref].get_device()
    dtype = py_dict[ref].dtype
    with cuda_utils.DeviceEnvironment(device):
        for key, value in py_dict.items():
            if isinstance(value, float):
                if dtype == torch.float32:
                    value = cupy.asarray(value, dtype=np.float32)
                elif dtype == torch.float16:
                    value = cupy.asarray([value, value], dtype=np.float16)
                else:
                    raise NotImplementedError(dtype)
                py_dict[key] = value
            elif isinstance(value, int):
                py_dict[key] = cupy.asarray(value)

def new_tensors(news: tuple, py_dict: dict, ref: str = 'x_seq'):
    ref = py_dict[ref]
    zero_shape = list(ref.shape)
    zero_shape[0] *= news.__len__()
    for i, item in enumerate(
            torch.split(torch.zeros(zero_shape, device=ref.device, dtype=ref.dtype), ref.shape[0])):
        py_dict[news[i]] = item

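# Editorial note: a small sketch of what ``new_tensors`` adds to ``py_dict``: for every name in
# ``news`` it inserts a zero tensor shaped like ``py_dict[ref]``. The tensor sizes used here are
# arbitrary and the demo function is not part of the original module.
def _demo_new_tensors():
    py_dict = {'x_seq': torch.rand(4, 8)}  # [T, N] = [4, 8]
    new_tensors(('h_seq', 'spike_seq', 'v_seq'), py_dict)
    assert py_dict['h_seq'].shape == py_dict['x_seq'].shape
    assert py_dict['spike_seq'].shape == py_dict['x_seq'].shape
    assert py_dict['v_seq'].shape == py_dict['x_seq'].shape
    return py_dict
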
class NeuronATGFBase:
    @staticmethod
    def pre_forward(py_dict: dict):
        """
        :param py_dict: a dict built from the neuron's forward autograd function. It should at least contain
            ``x_seq, v_init, v_reset``
        :type py_dict: dict
        :return: requires_grad, blocks, threads, py_dict

            requires_grad: bool
                if any tensor in ``py_dict`` requires grad, then ``requires_grad = True``; else ``requires_grad = False``

            blocks: int
                CUDA param used in calling the CUDA kernel

            threads: int
                CUDA param used in calling the CUDA kernel. The default value is ``spikingjelly.configure.cuda_threads``

            py_dict: dict
                Compared with the input ``py_dict``, the returned ``py_dict`` will:

                    * convert all ``float/int`` scalars in ``py_dict`` to ``cupy.ndarray``

                    * add ``h_seq, spike_seq, v_v_seq`` to ``py_dict``. ``h_seq, spike_seq`` are zero tensors
                      with the same shape as ``x_seq``. ``v_v_seq`` is concatenated from ``v_init`` and ``v_seq``,
                      which is a zero tensor with the same shape as ``x_seq``

                    * add ``N, numel`` to ``py_dict``. Note that ``x_seq.shape = [T, N]`` and ``numel = T * N``.
                      A specific case is that ``x_seq.dtype == torch.half``, then ``N = math.ceil(N / 2)``, and
                      ``numel = N * x_seq.shape[0]``. Note that ``N, numel`` in the returned ``py_dict`` are
                      ``cupy.ndarray``

        :rtype: tuple
        """
        device = py_dict['x_seq'].get_device()
        requires_grad = if_requires_grad(py_dict.values())
        scalar_to_cupy(py_dict)

        new_tensors(('h_seq', 'spike_seq', 'v_seq'), py_dict)
        py_dict['v_v_seq'] = torch.cat((py_dict.pop('v_init').unsqueeze(0), py_dict.pop('v_seq')))
        numel = py_dict['x_seq'].numel()
        N = py_dict['x_seq'].shape[1]

        threads = configure.cuda_threads
        if py_dict['x_seq'].dtype == torch.float16:
            # we will take two neurons to calculate as one neuron in cuda half2
            # pad will be implemented by the kernel.__call__
            N = math.ceil(N / 2)
            numel = N * py_dict['x_seq'].shape[0]

        blocks = cuda_utils.cal_blocks(N)

        with cuda_utils.DeviceEnvironment(device):
            numel = cupy.asarray(numel)
            N = cupy.asarray(N)

        py_dict['numel'] = numel
        py_dict['N'] = N

        return requires_grad, blocks, threads, py_dict

    @staticmethod
    def ctx_save(ctx, requires_grad: bool, *args, **kwargs):
        """
        :param ctx: ``ctx`` in :class:`torch.autograd.Function`
        :param requires_grad: if any tensor in forward params requires grad
        :type requires_grad: bool
        :param args: tensors that need to be saved by ``ctx.save_for_backward``
        :param kwargs: items that need to be saved by ``ctx.xx = xx``

        Saves ``*args, **kwargs`` in ``ctx`` by ``ctx.save_for_backward(*args)`` and ``ctx.xx = xx`` for all ``xx``
        in ``kwargs.items()``.
        """
        if requires_grad:
            ctx.save_for_backward(*args)
            for key, value in kwargs.items():
                ctx.__setattr__(key, value)

    @staticmethod
    def pre_backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor):
        """
        :param ctx: ``ctx`` in :class:`torch.autograd.Function`
        :param grad_spike_seq: gradients of ``spike_seq``
        :type grad_spike_seq: torch.Tensor
        :param grad_v_seq: gradients of ``v_seq``
        :type grad_v_seq: torch.Tensor
        :return: backward_kernel, blocks, threads, py_dict

            backward_kernel: NeuronBPTTKernel
                The CUDA kernel used for backward. It should be provided in ``ctx.backward_kernel``

            blocks: int
                CUDA param used in calling the CUDA kernel. It should be provided in ``ctx.blocks``

            threads: int
                CUDA param used in calling the CUDA kernel. It should be provided in ``ctx.threads``

        :rtype: tuple
        """
        backward_kernel = ctx.backward_kernel
        blocks = ctx.blocks
        threads = ctx.threads

        h_seq = ctx.saved_tensors[0]
        numel = ctx.numel
        N = ctx.N
        v_th = ctx.v_th
        v_reset = ctx.v_reset

        zero_shape = list(grad_spike_seq.shape)
        zero_shape[0] += 1
        zero_data = torch.zeros(zero_shape, device=grad_spike_seq.device, dtype=grad_spike_seq.dtype)
        grad_x_seq = zero_data[0: -1]
        grad_v_init = zero_data[-1]

        py_dict = {
            'numel': numel,
            'N': N,
            'grad_spike_seq': grad_spike_seq,
            'grad_v_seq': grad_v_seq,
            'h_seq': h_seq,
            'grad_x_seq': grad_x_seq,
            'grad_v_init': grad_v_init,
            'v_th': v_th,
            'v_reset': v_reset
        }

        return backward_kernel, blocks, threads, py_dict

class IFNodeATGF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, v_th: float, v_reset: float or None,
                forward_kernel: IFNodeFPTTKernel, backward_kernel: IFNodeBPTTKernel):
        py_dict = {
            'x_seq': x_seq,
            'v_init': v_init,
            'v_th': v_th,
            'v_reset': v_reset
        }
        requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict)

        if py_dict['v_reset'] is None:
            py_dict.pop('v_reset')

        forward_kernel((blocks,), (threads,), py_dict)

        if 'v_reset' not in py_dict:
            py_dict['v_reset'] = None

        NeuronATGFBase.ctx_save(ctx, requires_grad, py_dict['h_seq'], blocks=blocks, threads=threads,
                                numel=py_dict['numel'], N=py_dict['N'], v_th=py_dict['v_th'],
                                v_reset=py_dict['v_reset'], backward_kernel=backward_kernel)

        return py_dict['spike_seq'], py_dict['v_v_seq'][1:, ]

    @staticmethod
    def backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor):

        backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward(ctx, grad_spike_seq, grad_v_seq)

        if py_dict['v_reset'] is None:
            py_dict.pop('v_reset')

        backward_kernel((blocks,), (threads,), py_dict)

        if 'v_reset' not in py_dict:
            py_dict['v_reset'] = None

        return py_dict['grad_x_seq'], py_dict['grad_v_init'], None, None, None, None

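# Editorial note: an end-to-end usage sketch for IFNodeATGF, guarded because it needs cupy and a
# CUDA device. The shapes, v_th=1. and v_reset=0. are arbitrary, and ``surrogate.Sigmoid().cuda_codes``
# is again assumed to match the surrogate Callable signature used by NeuronBPTTKernel. This mirrors
# how the forward signature above is meant to be called and is not part of the original module.
def _demo_ifnode_atgf():
    if cupy is None or not torch.cuda.is_available():
        return None
    T, N = 8, 64
    x_seq = torch.rand(T, N, device='cuda:0', requires_grad=True)
    v_init = torch.zeros(N, device='cuda:0')
    forward_kernel = IFNodeFPTTKernel(hard_reset=True, dtype='float')
    backward_kernel = IFNodeBPTTKernel(surrogate_function=surrogate.Sigmoid().cuda_codes,
                                       hard_reset=True, detach_reset=False, dtype='float')
    spike_seq, v_seq = IFNodeATGF.apply(x_seq, v_init, 1., 0., forward_kernel, backward_kernel)
    spike_seq.sum().backward()
    return spike_seq, v_seq, x_seq.grad
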
class LIFNodeFPTTKernel(NeuronFPTTKernel):
    def __init__(self, decay_input: bool, hard_reset: bool, dtype: str):
        super().__init__(hard_reset, dtype)
        self.decay_input = decay_input
        self.add_param(ctype=f'const {dtype} &', cname='decay')

    def neuronal_charge(self) -> str:
        if self.hard_reset:
            codes = cfunction.sub(z=f'{self.dtype} LIFNodeFPTTKernel_temp_var', x='v_v_seq[t]', y='v_reset',
                                  dtype=self.dtype)
        else:
            codes = f'{self.dtype} LIFNodeFPTTKernel_temp_var = v_v_seq[t];'

        if self.decay_input:
            codes += cfunction.sub(z='LIFNodeFPTTKernel_temp_var', x='x_seq[t]', y='LIFNodeFPTTKernel_temp_var',
                                   dtype=self.dtype)
            codes += cfunction.mul(z='LIFNodeFPTTKernel_temp_var', x='decay', y='LIFNodeFPTTKernel_temp_var',
                                   dtype=self.dtype)
        else:
            codes += cfunction.mul(z='LIFNodeFPTTKernel_temp_var', x='decay', y='LIFNodeFPTTKernel_temp_var',
                                   dtype=self.dtype)
            codes += cfunction.sub(z='LIFNodeFPTTKernel_temp_var', x='x_seq[t]', y='LIFNodeFPTTKernel_temp_var',
                                   dtype=self.dtype)

        codes += cfunction.add(z='h_seq[t]', x='LIFNodeFPTTKernel_temp_var', y='v_v_seq[t]', dtype=self.dtype)
        return codes

class LIFNodeBPTTKernel(NeuronBPTTKernel):
    def __init__(self, decay_input: bool, surrogate_function: Callable, hard_reset: bool, detach_reset: bool, dtype: str):
        super().__init__(surrogate_function, hard_reset, detach_reset, dtype)
        self.decay_input = decay_input
        self.add_param(ctype=f'const {dtype} &', cname='decay')

    def grad_h_next_to_v(self) -> str:
        return cfunction.sub(z=f'const {self.dtype} grad_h_next_to_v',
                             x=cfunction.constant(None, x=1., dtype=self.dtype), y='decay', dtype=self.dtype)

    def grad_h_to_x(self) -> str:
        if not self.decay_input:
            return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)
        else:
            return f'const {self.dtype} grad_h_to_x = decay;'

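# Editorial note: a short worked derivation of the two gradients above, read directly from
# LIFNodeFPTTKernel.neuronal_charge. For hard reset with decay_input=True the generated charge is
#
#     H[t] = V[t-1] + decay * (X[t] - (V[t-1] - v_reset))
#
# so dH[t]/dX[t] = decay (grad_h_to_x) and dH[t+1]/dV[t] = 1 - decay (grad_h_next_to_v). With
# decay_input=False the charge is H[t] = V[t-1] + X[t] - decay * (V[t-1] - v_reset), which gives
# dH[t]/dX[t] = 1 and the same dH[t+1]/dV[t] = 1 - decay. The soft-reset case drops v_reset and
# leaves both derivatives unchanged.
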
class LIFNodeATGF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, v_th: float, v_reset: float or None, decay: float,
                forward_kernel: LIFNodeFPTTKernel, backward_kernel: LIFNodeBPTTKernel):
        py_dict = {
            'x_seq': x_seq,
            'v_init': v_init,
            'v_th': v_th,
            'v_reset': v_reset,
            'decay': decay,
        }
        requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict)

        if py_dict['v_reset'] is None:
            py_dict.pop('v_reset')

        forward_kernel((blocks,), (threads,), py_dict)

        if 'v_reset' not in py_dict:
            py_dict['v_reset'] = None

        NeuronATGFBase.ctx_save(ctx, requires_grad, py_dict['h_seq'], blocks=blocks, threads=threads,
                                numel=py_dict['numel'], N=py_dict['N'], v_th=py_dict['v_th'],
                                v_reset=py_dict['v_reset'], backward_kernel=backward_kernel, decay=py_dict['decay'])

        return py_dict['spike_seq'], py_dict['v_v_seq'][1:, ]

    @staticmethod
    def backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor):

        backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward(ctx, grad_spike_seq, grad_v_seq)
        py_dict['decay'] = ctx.decay

        if py_dict['v_reset'] is None:
            py_dict.pop('v_reset')

        backward_kernel((blocks,), (threads,), py_dict)

        if 'v_reset' not in py_dict:
            py_dict['v_reset'] = None

        return py_dict['grad_x_seq'], py_dict['grad_v_init'], None, None, None, None, None

class ParametricLIFNodeFPTTKernel(NeuronFPTTKernel):
    def __init__(self, decay_input: bool, hard_reset: bool, dtype: str):
        super().__init__(hard_reset, dtype)
        self.decay_input = decay_input
        self.add_param(ctype=f'const {dtype} *', cname='decay')

    def neuronal_charge(self) -> str:
        if self.hard_reset:
            codes = cfunction.sub(z=f'{self.dtype} LIFNodeFPTTKernel_temp_var', x='v_v_seq[t]', y='v_reset',
                                  dtype=self.dtype)
        else:
            codes = f'{self.dtype} LIFNodeFPTTKernel_temp_var = v_v_seq[t];'

        if self.decay_input:
            codes += cfunction.sub(z='LIFNodeFPTTKernel_temp_var', x='x_seq[t]', y='LIFNodeFPTTKernel_temp_var',
                                   dtype=self.dtype)
            codes += cfunction.mul(z='LIFNodeFPTTKernel_temp_var', x='decay[0]', y='LIFNodeFPTTKernel_temp_var',
                                   dtype=self.dtype)
        else:
            codes += cfunction.mul(z='LIFNodeFPTTKernel_temp_var', x='decay[0]', y='LIFNodeFPTTKernel_temp_var',
                                   dtype=self.dtype)
            codes += cfunction.sub(z='LIFNodeFPTTKernel_temp_var', x='x_seq[t]', y='LIFNodeFPTTKernel_temp_var',
                                   dtype=self.dtype)

        codes += cfunction.add(z='h_seq[t]', x='LIFNodeFPTTKernel_temp_var', y='v_v_seq[t]', dtype=self.dtype)
        return codes

class ParametricLIFNodeBPTTKernel(NeuronBPTTKernel):
    def __init__(self, decay_input: bool, surrogate_function: Callable, hard_reset: bool, detach_reset: bool, dtype: str):
        super().__init__(surrogate_function, hard_reset, detach_reset, dtype)
        self.decay_input = decay_input
        self.add_param(ctype=f'const {dtype} *', cname='decay')
        self.add_param(ctype=f'float *', cname='grad_decay')  # float to avoid overflow
        self.add_param(ctype=f'const {dtype} *', cname='v_v_seq')

    def grad_h_next_to_v(self) -> str:
        return cfunction.sub(z=f'const {self.dtype} grad_h_next_to_v',
                             x=cfunction.constant(None, x=1., dtype=self.dtype), y='decay[0]', dtype=self.dtype)

    def grad_h_to_x(self) -> str:
        if not self.decay_input:
            return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)
        else:
            return f'const {self.dtype} grad_h_to_x = decay[0];'

    @property
    def head(self):
        # override
        codes = '''
        {
            const int index = blockIdx.x * blockDim.x + threadIdx.x;
        '''

        codes += fr'''
            __shared__ float sdata[{configure.cuda_threads}];
        '''

        codes += '''
            if (index < N)
            {
                const int dt = N;
        '''

        codes += self.pre_core

        if self.reverse:
            codes += '''
                for(int t = numel - N + index; t >= 0; t -= dt)
                {
            '''
        else:
            codes += '''
                for(int t = index; t < numel; t += dt)
                {
            '''

        return codes

    @property
    def pre_core(self):
        codes = base.CodeTyper(16)
        # use float to avoid overflow
        codes.append('sdata[threadIdx.x] = 0.0f;')
        return super().pre_core + '\n' + codes.codes

    @property
    def core(self):
        core_codes = base.CodeTyper(18)
        with base.CodeBlock(core_codes):
            if self.decay_input:
                core_codes.append(
                    cfunction.sub(z=f'{self.dtype} temp_var', x='h_seq[t]', y='v_v_seq[t]', dtype=self.dtype))
                core_codes.append(cfunction.mul(z='temp_var', x='temp_var', y='grad_h', dtype=self.dtype))
                core_codes.append(cfunction.div(z='temp_var', x='temp_var', y='decay[0]', dtype=self.dtype))
            else:
                if self.hard_reset:
                    core_codes.append(
                        cfunction.sub(z=f'{self.dtype} temp_var', x='v_reset', y='v_v_seq[t]', dtype=self.dtype))
                    core_codes.append(cfunction.mul(z='temp_var', x='temp_var', y='grad_h', dtype=self.dtype))
                else:
                    core_codes.append(
                        cfunction.mul(z=f'{self.dtype} temp_var', x='grad_h', y='v_v_seq[t]', dtype=self.dtype))
                    core_codes.append(cfunction.neg(y='temp_var', x='temp_var', dtype=self.dtype))

            if self.dtype == 'float':
                core_codes.append('sdata[threadIdx.x] += temp_var;')
            elif self.dtype == 'half2':
                core_codes.append(
                    'sdata[threadIdx.x] += __half2float(__hadd(__low2half(temp_var), __high2half(temp_var)));')
            else:
                raise NotImplementedError(self.dtype)

        return super().core + '\n' + core_codes.codes

    @property
    def tail(self):
        codes = '''
                }
        '''

        codes += self.post_core

        codes += '''
            }
            else
            {
                sdata[threadIdx.x] = 0.0f;
            }
            int threadx = blockDim.x;
            #pragma unroll
            for (int stride = threadx >> 1; stride > 0; stride = stride >> 1)
            {
                // Synchronize all threads before the next loop
                __syncthreads();
                if (threadIdx.x < stride)
                {
                    sdata[threadIdx.x] += sdata[threadIdx.x + stride];
                }
            }
            __syncthreads();
            if (threadIdx.x == 0)
            {
                atomicAdd(grad_decay, sdata[0]);
            }
        }
        '''
        return codes

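# Editorial note: the ``head``/``core``/``tail`` overrides above accumulate the gradient of the
# shared ``decay`` parameter. Each thread sums its per-timestep contribution into
# ``sdata[threadIdx.x]``, ``tail`` performs a shared-memory tree reduction within the block, and
# thread 0 adds the block's partial sum into ``grad_decay`` with ``atomicAdd``. A CPU sketch of
# the same reduction idea (illustrative only, not part of the module; it assumes the input length
# is a power of two, as ``blockDim.x`` is):
def _demo_block_reduction(partial_sums):
    # equivalent of the tree reduction in ``tail``: repeatedly fold the upper half onto the lower half
    data = list(partial_sums)
    stride = len(data) // 2
    while stride > 0:
        for i in range(stride):
            data[i] += data[i + stride]
        stride //= 2
    return data[0]
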
class ParametricLIFNodeATGF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, v_th: float, v_reset: float or None,
                decay: torch.Tensor, forward_kernel: ParametricLIFNodeFPTTKernel,
                backward_kernel: ParametricLIFNodeBPTTKernel):
        if x_seq.dtype == torch.float16 and v_init.numel() % 2 != 0:
            raise ValueError(
                'When using the PLIF neuron with the half2 cupy backend, the number of neurons should be even to '
                'avoid the wrong gradient of tau caused by padding!')
        py_dict = {
            'x_seq': x_seq,
            'v_init': v_init,
            'v_th': v_th,
            'v_reset': v_reset,
            'decay': decay,
        }
        requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict)

        if py_dict['v_reset'] is None:
            py_dict.pop('v_reset')

        forward_kernel((blocks,), (threads,), py_dict)

        if 'v_reset' not in py_dict:
            py_dict['v_reset'] = None

        NeuronATGFBase.ctx_save(ctx, requires_grad, py_dict['h_seq'], py_dict['v_v_seq'], blocks=blocks,
                                threads=threads, numel=py_dict['numel'], N=py_dict['N'], v_th=py_dict['v_th'],
                                v_reset=py_dict['v_reset'], backward_kernel=backward_kernel, decay=py_dict['decay'])

        return py_dict['spike_seq'], py_dict['v_v_seq'][1:, ]

    @staticmethod
    def backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor):

        backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward(ctx, grad_spike_seq, grad_v_seq)
        py_dict['decay'] = ctx.decay
        py_dict['grad_decay'] = torch.zeros_like(ctx.decay, dtype=torch.float)
        py_dict['v_v_seq'] = ctx.saved_tensors[1]

        if py_dict['v_reset'] is None:
            py_dict.pop('v_reset')

        backward_kernel((blocks,), (threads,), py_dict)

        if 'v_reset' not in py_dict:
            py_dict['v_reset'] = None

        return py_dict['grad_x_seq'], py_dict['grad_v_init'], None, None, py_dict['grad_decay'], None, None