Source code for spikingjelly.activation_based.learning

from typing import Callable, Union

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import neuron, monitor, base


def stdp_linear_single_step(
    fc: nn.Linear, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[float, torch.Tensor, None],
    trace_post: Union[float, torch.Tensor, None],
    tau_pre: float, tau_post: float,
    f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
):
    if trace_pre is None:
        trace_pre = 0.

    if trace_post is None:
        trace_post = 0.

    weight = fc.weight.data
    trace_pre = trace_pre - trace_pre / tau_pre + in_spike       # shape = [batch_size, N_in]
    trace_post = trace_post - trace_post / tau_post + out_spike  # shape = [batch_size, N_out]

    # [batch_size, N_out, N_in] -> [N_out, N_in]
    delta_w_pre = -f_pre(weight) * (trace_post.unsqueeze(2) * in_spike.unsqueeze(1)).sum(0)
    delta_w_post = f_post(weight) * (trace_pre.unsqueeze(1) * out_spike.unsqueeze(2)).sum(0)
    return trace_pre, trace_post, delta_w_pre + delta_w_post
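
# Usage sketch (not part of the module source): calling the single-step
# linear STDP rule directly. The layer sizes, spike tensors, and time
# constants below are illustrative assumptions.

import torch
import torch.nn as nn
from spikingjelly.activation_based.learning import stdp_linear_single_step

fc = nn.Linear(4, 3, bias=False)
in_spike = (torch.rand(2, 4) > 0.5).float()   # [batch_size=2, N_in=4]
out_spike = (torch.rand(2, 3) > 0.5).float()  # [batch_size=2, N_out=3]
trace_pre, trace_post, delta_w = stdp_linear_single_step(
    fc, in_spike, out_spike, None, None, tau_pre=2., tau_post=2.
)
assert delta_w.shape == fc.weight.shape  # [N_out, N_in]
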
def mstdp_linear_single_step(
    fc: nn.Linear, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[float, torch.Tensor, None],
    trace_post: Union[float, torch.Tensor, None],
    tau_pre: float, tau_post: float,
    f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
):
    if trace_pre is None:
        trace_pre = 0.

    if trace_post is None:
        trace_post = 0.

    weight = fc.weight.data
    trace_pre = trace_pre * math.exp(-1 / tau_pre) + in_spike      # shape = [batch_size, N_in]
    trace_post = trace_post * math.exp(-1 / tau_post) + out_spike  # shape = [batch_size, N_out]

    # shape = [batch_size, N_out, N_in]
    eligibility = f_post(weight) * (trace_pre.unsqueeze(1) * out_spike.unsqueeze(2)) \
        - f_pre(weight) * (trace_post.unsqueeze(2) * in_spike.unsqueeze(1))
    return trace_pre, trace_post, eligibility
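
# Sketch: unlike stdp_linear_single_step, the MSTDP rule returns a per-sample
# eligibility tensor rather than a summed weight update; the reward-weighted
# batch sum is taken later in MSTDPLearner.step. Toy values are assumptions.

import torch
import torch.nn as nn
from spikingjelly.activation_based.learning import mstdp_linear_single_step

fc = nn.Linear(4, 3, bias=False)
in_spike = (torch.rand(2, 4) > 0.5).float()
out_spike = (torch.rand(2, 3) > 0.5).float()
_, _, eligibility = mstdp_linear_single_step(
    fc, in_spike, out_spike, None, None, tau_pre=2., tau_post=2.
)
assert eligibility.shape == (2, 3, 4)  # [batch_size, N_out, N_in]
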
def mstdpet_linear_single_step(
    fc: nn.Linear, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[float, torch.Tensor, None],
    trace_post: Union[float, torch.Tensor, None],
    tau_pre: float, tau_post: float, tau_trace: float,
    f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
):
    if trace_pre is None:
        trace_pre = 0.

    if trace_post is None:
        trace_post = 0.

    weight = fc.weight.data
    trace_pre = trace_pre * math.exp(-1 / tau_pre) + in_spike
    trace_post = trace_post * math.exp(-1 / tau_post) + out_spike

    # torch.outer only accepts 1-D tensors, so in_spike and out_spike must be
    # unbatched, with shape = [N_in] and [N_out]
    eligibility = f_post(weight) * torch.outer(out_spike, trace_pre) \
        - f_pre(weight) * torch.outer(trace_post, in_spike)
    return trace_pre, trace_post, eligibility
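
# Sketch: because of the torch.outer calls above, this rule expects unbatched
# spike vectors. Toy sizes and constants are assumptions.

import torch
import torch.nn as nn
from spikingjelly.activation_based.learning import mstdpet_linear_single_step

fc = nn.Linear(4, 3, bias=False)
in_spike = (torch.rand(4) > 0.5).float()   # [N_in], no batch dimension
out_spike = (torch.rand(3) > 0.5).float()  # [N_out]
_, _, eligibility = mstdpet_linear_single_step(
    fc, in_spike, out_spike, None, None,
    tau_pre=2., tau_post=2., tau_trace=5.
)
assert eligibility.shape == fc.weight.shape  # [N_out, N_in]
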
def stdp_conv2d_single_step(
    conv: nn.Conv2d, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[torch.Tensor, None], trace_post: Union[torch.Tensor, None],
    tau_pre: float, tau_post: float,
    f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
):
    if conv.dilation != (1, 1):
        raise NotImplementedError(
            'STDP with dilation != 1 for Conv2d has not been implemented!'
        )
    if conv.groups != 1:
        raise NotImplementedError(
            'STDP with groups != 1 for Conv2d has not been implemented!'
        )

    stride_h = conv.stride[0]
    stride_w = conv.stride[1]

    if conv.padding != (0, 0):
        pH = conv.padding[0]
        pW = conv.padding[1]
        if conv.padding_mode != 'zeros':
            in_spike = F.pad(
                in_spike, conv._reversed_padding_repeated_twice,
                mode=conv.padding_mode
            )
        else:
            in_spike = F.pad(in_spike, pad=(pW, pW, pH, pH))

    if trace_pre is None:
        trace_pre = torch.zeros_like(in_spike)

    if trace_post is None:
        trace_post = torch.zeros_like(out_spike)

    trace_pre = trace_pre - trace_pre / tau_pre + in_spike
    trace_post = trace_post - trace_post / tau_post + out_spike

    delta_w = torch.zeros_like(conv.weight.data)
    for h in range(conv.weight.shape[2]):
        for w in range(conv.weight.shape[3]):
            h_end = in_spike.shape[2] - conv.weight.shape[2] + 1 + h
            w_end = in_spike.shape[3] - conv.weight.shape[3] + 1 + w
            pre_spike = in_spike[:, :, h:h_end:stride_h, w:w_end:stride_w]  # shape = [batch_size, C_in, h_out, w_out]
            post_spike = out_spike                 # shape = [batch_size, C_out, h_out, w_out]
            weight = conv.weight.data[:, :, h, w]  # shape = [C_out, C_in]

            tr_pre = trace_pre[:, :, h:h_end:stride_h, w:w_end:stride_w]  # shape = [batch_size, C_in, h_out, w_out]
            tr_post = trace_post  # shape = [batch_size, C_out, h_out, w_out]

            delta_w_pre = -(f_pre(weight) *
                            (tr_post.unsqueeze(2) * pre_spike.unsqueeze(1))
                            .permute([1, 2, 0, 3, 4]).sum(dim=[2, 3, 4]))
            delta_w_post = f_post(weight) * \
                (tr_pre.unsqueeze(1) * post_spike.unsqueeze(2)) \
                .permute([1, 2, 0, 3, 4]).sum(dim=[2, 3, 4])
            delta_w[:, :, h, w] += delta_w_pre + delta_w_post

    return trace_pre, trace_post, delta_w
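
# Sketch: the double loop above visits each kernel offset (h, w) and
# correlates the strided input window at that offset with the full output
# trace, so delta_w has exactly the weight's shape. The toy configuration
# below is an assumption; the thresholded conv output stands in for a
# spiking neuron layer.

import torch
import torch.nn as nn
from spikingjelly.activation_based.learning import stdp_conv2d_single_step

conv = nn.Conv2d(2, 3, kernel_size=3, padding=1, bias=False)
in_spike = (torch.rand(1, 2, 8, 8) > 0.5).float()
out_spike = (conv(in_spike) > 0.5).float()  # [batch_size, C_out, h_out, w_out]
trace_pre, trace_post, delta_w = stdp_conv2d_single_step(
    conv, in_spike, out_spike, None, None, tau_pre=2., tau_post=2.
)
assert delta_w.shape == conv.weight.shape  # [C_out, C_in, kH, kW]
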
def stdp_conv1d_single_step(
    conv: nn.Conv1d, in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[torch.Tensor, None], trace_post: Union[torch.Tensor, None],
    tau_pre: float, tau_post: float,
    f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
):
    if conv.dilation != (1, ):
        raise NotImplementedError(
            'STDP with dilation != 1 for Conv1d has not been implemented!'
        )
    if conv.groups != 1:
        raise NotImplementedError(
            'STDP with groups != 1 for Conv1d has not been implemented!'
        )

    stride_l = conv.stride[0]

    if conv.padding != (0, ):
        pL = conv.padding[0]
        if conv.padding_mode != 'zeros':
            in_spike = F.pad(
                in_spike, conv._reversed_padding_repeated_twice,
                mode=conv.padding_mode
            )
        else:
            in_spike = F.pad(in_spike, pad=(pL, pL))

    if trace_pre is None:
        trace_pre = torch.zeros_like(in_spike)

    if trace_post is None:
        trace_post = torch.zeros_like(out_spike)

    trace_pre = trace_pre - trace_pre / tau_pre + in_spike
    trace_post = trace_post - trace_post / tau_post + out_spike

    delta_w = torch.zeros_like(conv.weight.data)
    for l in range(conv.weight.shape[2]):
        l_end = in_spike.shape[2] - conv.weight.shape[2] + 1 + l
        pre_spike = in_spike[:, :, l:l_end:stride_l]  # shape = [batch_size, C_in, l_out]
        post_spike = out_spike                 # shape = [batch_size, C_out, l_out]
        weight = conv.weight.data[:, :, l]     # shape = [C_out, C_in]

        tr_pre = trace_pre[:, :, l:l_end:stride_l]  # shape = [batch_size, C_in, l_out]
        tr_post = trace_post  # shape = [batch_size, C_out, l_out]

        delta_w_pre = -(f_pre(weight) *
                        (tr_post.unsqueeze(2) * pre_spike.unsqueeze(1))
                        .permute([1, 2, 0, 3]).sum(dim=[2, 3]))
        delta_w_post = f_post(weight) * \
            (tr_pre.unsqueeze(1) * post_spike.unsqueeze(2)) \
            .permute([1, 2, 0, 3]).sum(dim=[2, 3])
        delta_w[:, :, l] += delta_w_pre + delta_w_post

    return trace_pre, trace_post, delta_w
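
# Sketch: the Conv1d variant mirrors the Conv2d case with a single spatial
# loop. Same assumptions as the Conv2d sketch above.

import torch
import torch.nn as nn
from spikingjelly.activation_based.learning import stdp_conv1d_single_step

conv = nn.Conv1d(2, 3, kernel_size=3, padding=1, bias=False)
in_spike = (torch.rand(1, 2, 16) > 0.5).float()
out_spike = (conv(in_spike) > 0.5).float()  # stand-in for a spiking layer
_, _, delta_w = stdp_conv1d_single_step(
    conv, in_spike, out_spike, None, None, tau_pre=2., tau_post=2.
)
assert delta_w.shape == conv.weight.shape  # [C_out, C_in, kL]
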
def stdp_multi_step(
    layer: Union[nn.Linear, nn.Conv1d, nn.Conv2d],
    in_spike: torch.Tensor, out_spike: torch.Tensor,
    trace_pre: Union[float, torch.Tensor, None],
    trace_post: Union[float, torch.Tensor, None],
    tau_pre: float, tau_post: float,
    f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
):
    weight = layer.weight.data
    delta_w = torch.zeros_like(weight)
    T = in_spike.shape[0]

    if isinstance(layer, nn.Linear):
        stdp_single_step = stdp_linear_single_step
    elif isinstance(layer, nn.Conv1d):
        stdp_single_step = stdp_conv1d_single_step
    elif isinstance(layer, nn.Conv2d):
        stdp_single_step = stdp_conv2d_single_step
    else:
        raise NotImplementedError(layer)

    for t in range(T):
        trace_pre, trace_post, dw = stdp_single_step(
            layer, in_spike[t], out_spike[t], trace_pre, trace_post,
            tau_pre, tau_post, f_pre, f_post
        )
        delta_w += dw
    return trace_pre, trace_post, delta_w
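
# Sketch: in multi-step mode the spike tensors carry a leading time
# dimension, shape [T, batch_size, ...]; the single-step rule is applied per
# step and the updates are accumulated. Toy values below are assumptions.

import torch
import torch.nn as nn
from spikingjelly.activation_based.learning import stdp_multi_step

fc = nn.Linear(4, 3, bias=False)
T = 5
in_spike = (torch.rand(T, 2, 4) > 0.5).float()   # [T, batch_size, N_in]
out_spike = (torch.rand(T, 2, 3) > 0.5).float()  # [T, batch_size, N_out]
trace_pre, trace_post, delta_w = stdp_multi_step(
    fc, in_spike, out_spike, None, None, tau_pre=2., tau_post=2.
)
assert delta_w.shape == fc.weight.shape
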
class STDPLearner(base.MemoryModule):
    def __init__(
        self, step_mode: str, synapse: Union[nn.Conv2d, nn.Linear],
        sn: neuron.BaseNode, tau_pre: float, tau_post: float,
        f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
    ):
        super().__init__()
        self.step_mode = step_mode
        self.tau_pre = tau_pre
        self.tau_post = tau_post
        self.f_pre = f_pre
        self.f_post = f_post
        self.synapse = synapse
        self.in_spike_monitor = monitor.InputMonitor(synapse)
        self.out_spike_monitor = monitor.OutputMonitor(sn)
        self.register_memory('trace_pre', None)
        self.register_memory('trace_post', None)

    def reset(self):
        super(STDPLearner, self).reset()
        self.in_spike_monitor.clear_recorded_data()
        self.out_spike_monitor.clear_recorded_data()

    def disable(self):
        self.in_spike_monitor.disable()
        self.out_spike_monitor.disable()

    def enable(self):
        self.in_spike_monitor.enable()
        self.out_spike_monitor.enable()
    def step(self, on_grad: bool = True, scale: float = 1.):
        length = len(self.in_spike_monitor.records)
        delta_w = None

        if self.step_mode == 's':
            if isinstance(self.synapse, nn.Linear):
                stdp_f = stdp_linear_single_step
            elif isinstance(self.synapse, nn.Conv2d):
                stdp_f = stdp_conv2d_single_step
            elif isinstance(self.synapse, nn.Conv1d):
                stdp_f = stdp_conv1d_single_step
            else:
                raise NotImplementedError(self.synapse)
        elif self.step_mode == 'm':
            if isinstance(self.synapse, (nn.Linear, nn.Conv1d, nn.Conv2d)):
                stdp_f = stdp_multi_step
            else:
                raise NotImplementedError(self.synapse)
        else:
            raise ValueError(self.step_mode)

        for _ in range(length):
            in_spike = self.in_spike_monitor.records.pop(0)    # [batch_size, N_in]
            out_spike = self.out_spike_monitor.records.pop(0)  # [batch_size, N_out]

            self.trace_pre, self.trace_post, dw = stdp_f(
                self.synapse, in_spike, out_spike,
                self.trace_pre, self.trace_post,
                self.tau_pre, self.tau_post, self.f_pre, self.f_post
            )
            if scale != 1.:
                dw *= scale
            delta_w = dw if (delta_w is None) else (delta_w + dw)

        if on_grad:
            if self.synapse.weight.grad is None:
                self.synapse.weight.grad = -delta_w
            else:
                self.synapse.weight.grad = self.synapse.weight.grad - delta_w
        else:
            return delta_w
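
# Usage sketch: a minimal single-step training loop with STDPLearner. The
# network, sizes, time constants, and learning rate are illustrative
# assumptions.

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, learning, functional

net = nn.Sequential(
    nn.Linear(8, 4, bias=False),
    neuron.IFNode()
)
learner = learning.STDPLearner(
    step_mode='s', synapse=net[0], sn=net[1], tau_pre=2., tau_post=2.
)
# step(on_grad=True) accumulates -delta_w into weight.grad, so plain SGD
# with momentum=0 applies weight += lr * delta_w
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.)

x = (torch.rand(2, 8) > 0.5).float()  # [batch_size, N_in]
for t in range(4):
    net(x)                   # the monitors record pre- and post-synaptic spikes
    learner.step(on_grad=True)

optimizer.step()
optimizer.zero_grad()
functional.reset_net(net)
learner.reset()
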
class MSTDPLearner(base.MemoryModule):
    def __init__(
        self, step_mode: str, batch_size: int,
        synapse: Union[nn.Conv2d, nn.Linear], sn: neuron.BaseNode,
        tau_pre: float, tau_post: float,
        f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
    ):
        super().__init__()
        self.step_mode = step_mode
        self.batch_size = batch_size
        self.tau_pre = tau_pre
        self.tau_post = tau_post
        self.f_pre = f_pre
        self.f_post = f_post
        self.synapse = synapse
        self.in_spike_monitor = monitor.InputMonitor(synapse)
        self.out_spike_monitor = monitor.OutputMonitor(sn)
        self.register_memory('trace_pre', None)
        self.register_memory('trace_post', None)

    def reset(self):
        super(MSTDPLearner, self).reset()
        self.in_spike_monitor.clear_recorded_data()
        self.out_spike_monitor.clear_recorded_data()

    def disable(self):
        self.in_spike_monitor.disable()
        self.out_spike_monitor.disable()

    def enable(self):
        self.in_spike_monitor.enable()
        self.out_spike_monitor.enable()
    def step(self, reward, on_grad: bool = True, scale: float = 1.):
        length = len(self.in_spike_monitor.records)
        delta_w = None

        if self.step_mode == 's':
            if isinstance(self.synapse, nn.Conv2d):
                # stdp_f = mstdp_conv2d_single_step
                raise NotImplementedError(self.synapse)
            elif isinstance(self.synapse, nn.Linear):
                stdp_f = mstdp_linear_single_step
            else:
                raise NotImplementedError(self.synapse)
        elif self.step_mode == 'm':
            if isinstance(self.synapse, (nn.Conv2d, nn.Linear)):
                # stdp_f = mstdp_multi_step
                raise NotImplementedError(self.synapse)
            else:
                raise NotImplementedError(self.synapse)
        else:
            raise ValueError(self.step_mode)

        for _ in range(length):
            if not hasattr(self, 'eligibility'):
                self.eligibility = torch.zeros(
                    self.batch_size, *self.synapse.weight.shape,
                    device=self.synapse.weight.device
                )
            # the reward modulates the eligibility accumulated up to the
            # previous recorded step; the eligibility for the current record
            # is computed below
            dw = (reward.view(-1, 1, 1) * self.eligibility).sum(0)  # [batch_size, N_out, N_in] -> [N_out, N_in]
            if scale != 1.:
                dw *= scale
            delta_w = dw if (delta_w is None) else (delta_w + dw)

            in_spike = self.in_spike_monitor.records.pop(0)    # [batch_size, N_in]
            out_spike = self.out_spike_monitor.records.pop(0)  # [batch_size, N_out]

            self.trace_pre, self.trace_post, self.eligibility = stdp_f(
                self.synapse, in_spike, out_spike,
                self.trace_pre, self.trace_post,
                self.tau_pre, self.tau_post, self.f_pre, self.f_post
            )

        if on_grad:
            if self.synapse.weight.grad is None:
                self.synapse.weight.grad = -delta_w
            else:
                self.synapse.weight.grad = self.synapse.weight.grad - delta_w
        else:
            return delta_w
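
# Usage sketch: reward-modulated STDP (MSTDP). step() takes a per-sample
# reward tensor of shape [batch_size]. Toy sizes and the random reward are
# assumptions.

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, learning

batch_size = 2
net = nn.Sequential(
    nn.Linear(8, 4, bias=False),
    neuron.IFNode()
)
learner = learning.MSTDPLearner(
    step_mode='s', batch_size=batch_size,
    synapse=net[0], sn=net[1], tau_pre=2., tau_post=2.
)

x = (torch.rand(batch_size, 8) > 0.5).float()
for t in range(4):
    net(x)
    reward = torch.rand(batch_size)  # hypothetical per-sample reward signal
    learner.step(reward, on_grad=True)
# delta_w has been accumulated into -synapse.weight.grad; apply it with an
# optimizer as in the STDPLearner sketch above.
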
class MSTDPETLearner(base.MemoryModule):
    def __init__(
        self, step_mode: str, synapse: Union[nn.Conv2d, nn.Linear],
        sn: neuron.BaseNode, tau_pre: float, tau_post: float,
        tau_trace: float,
        f_pre: Callable = lambda x: x, f_post: Callable = lambda x: x
    ):
        super().__init__()
        self.step_mode = step_mode
        self.tau_pre = tau_pre
        self.tau_post = tau_post
        self.tau_trace = tau_trace
        self.f_pre = f_pre
        self.f_post = f_post
        self.synapse = synapse
        self.in_spike_monitor = monitor.InputMonitor(synapse)
        self.out_spike_monitor = monitor.OutputMonitor(sn)
        self.register_memory('trace_pre', None)
        self.register_memory('trace_post', None)
        self.register_memory('trace_e', None)

    def reset(self):
        super(MSTDPETLearner, self).reset()
        self.in_spike_monitor.clear_recorded_data()
        self.out_spike_monitor.clear_recorded_data()

    def disable(self):
        self.in_spike_monitor.disable()
        self.out_spike_monitor.disable()

    def enable(self):
        self.in_spike_monitor.enable()
        self.out_spike_monitor.enable()
    def step(self, reward, on_grad: bool = True, scale: float = 1.):
        length = len(self.in_spike_monitor.records)
        delta_w = None

        if self.step_mode == 's':
            if isinstance(self.synapse, nn.Conv2d):
                # stdp_f = mstdpet_conv2d_single_step
                raise NotImplementedError(self.synapse)
            elif isinstance(self.synapse, nn.Linear):
                stdp_f = mstdpet_linear_single_step
            else:
                raise NotImplementedError(self.synapse)
        elif self.step_mode == 'm':
            if isinstance(self.synapse, (nn.Conv2d, nn.Linear)):
                # stdp_f = mstdpet_multi_step
                raise NotImplementedError(self.synapse)
            else:
                raise NotImplementedError(self.synapse)
        else:
            raise ValueError(self.step_mode)

        for _ in range(length):
            if not hasattr(self, 'eligibility'):
                self.eligibility = torch.zeros(
                    *self.synapse.weight.shape,
                    device=self.synapse.weight.device
                )
            if self.trace_e is None:
                self.trace_e = 0.
            self.trace_e = self.trace_e * math.exp(-1 / self.tau_trace) \
                + self.eligibility / self.tau_trace
            dw = reward * self.trace_e
            if scale != 1.:
                dw *= scale
            delta_w = dw if (delta_w is None) else (delta_w + dw)

            in_spike = self.in_spike_monitor.records.pop(0)
            out_spike = self.out_spike_monitor.records.pop(0)

            self.trace_pre, self.trace_post, self.eligibility = stdp_f(
                self.synapse, in_spike, out_spike,
                self.trace_pre, self.trace_post,
                self.tau_pre, self.tau_post, self.tau_trace,
                self.f_pre, self.f_post
            )

        if on_grad:
            if self.synapse.weight.grad is None:
                self.synapse.weight.grad = -delta_w
            else:
                self.synapse.weight.grad = self.synapse.weight.grad - delta_w
        else:
            return delta_w
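
# Usage sketch: MSTDP with an eligibility trace (MSTDPET). Because the
# underlying single-step rule uses torch.outer, the network is driven here
# with unbatched 1-D spike tensors and a scalar reward. All values are
# illustrative assumptions.

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, learning

net = nn.Sequential(
    nn.Linear(8, 4, bias=False),
    neuron.IFNode()
)
learner = learning.MSTDPETLearner(
    step_mode='s', synapse=net[0], sn=net[1],
    tau_pre=2., tau_post=2., tau_trace=5.
)

x = (torch.rand(8) > 0.5).float()  # [N_in], no batch dimension
for t in range(4):
    net(x)
    learner.step(reward=0.5, on_grad=True)  # hypothetical scalar reward
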