spikingjelly.timing_based package

Encoding

class spikingjelly.timing_based.encoding.GaussianTuning(n: int, m: int, x_min: Tensor, x_max: Tensor)

Bases: object

Gaussian Receptive Field Population Coding

All neurons are arranged in a grid in which each row corresponds to an input dimension and each column to one neuron of that dimension. For a given input value, each neuron computes its response from its Gaussian tuning curve. The response is then mapped to a spike time: a higher response yields an earlier spike, and a lower response a later one. If the spike time reaches or exceeds a predefined maximum spike time, the neuron is inactive (no spike, represented by -1).

Reference:
Sander M. Bohte, Joost N. Kok, Han La Poutré.
Error-backpropagation in temporally encoded networks of spiking neurons.
Neurocomputing, Volume 48, Issues 1–4, 2002, Pages 17-37.
ISSN 0925-2312.
https://doi.org/10.1016/S0925-2312(01)00658-0
(https://www.sciencedirect.com/science/article/pii/S0925231201006580)

Neuron Spatial Receptive Field (Lecture):
https://youtu.be/fCqt07IXUPI?si=jVT-QlmEgrbQZkB2

Parameters:
  • n (int) -- The number of channels (input dimensions) in the input data.

  • m (int) -- The number of neurons per input channel.

  • x_min (torch.Tensor) -- 1D tensor of shape (n,) representing the minimum value of the input data for each input dimension.

  • x_max (torch.Tensor) -- 1D tensor of shape (n,) representing the maximum value of the input data for each input dimension.

Raises:

ValueError -- If x_min and x_max do not have shape (n,) or if any element in x_min is not less than the corresponding element in x_max.
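The two ValueError conditions can be written as a small check. This is a minimal sketch, assuming plain Python sequences in place of 1D tensors; check_input_range is a hypothetical helper, not part of spikingjelly.

```python
def check_input_range(n, x_min, x_max):
    """Hypothetical helper: raise ValueError under the documented conditions.

    x_min and x_max are plain sequences standing in for 1D tensors of shape (n,).
    """
    # Both bounds must provide exactly one value per input dimension.
    if len(x_min) != n or len(x_max) != n:
        raise ValueError(f"x_min and x_max must have shape ({n},)")
    # Each dimension needs a non-empty range: x_min[i] < x_max[i].
    for i, (lo, hi) in enumerate(zip(x_min, x_max)):
        if not lo < hi:
            raise ValueError(f"x_min[{i}]={lo} is not less than x_max[{i}]={hi}")


check_input_range(3, [0.0, 0.0, 0.0], [1.0, 1.0, 1.0])  # passes silently
```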

Examples

>>> import torch
>>> from spikingjelly.timing_based import encoding
>>> x_min = torch.tensor([0.0])
>>> x_max = torch.tensor([1.0])
>>> encoder = encoding.GaussianTuning(n=1, m=4, x_min=x_min, x_max=x_max)
>>> x = torch.tensor([[[0.1, 0.5, 0.9]]])
>>> spikes = encoder.encode(x, max_spike_time=100)
>>> print(spikes)
tensor([[[[42., 10., 85., -1.],
          [92., 25., 25., 92.],
          [-1., 85., 10., 42.]]]])
>>> print(spikes.shape)  # (batch_size, channels_count, samples_count, neurons_count)
torch.Size([1, 1, 3, 4])

The grid of neuronal receptive fields:

Each one is a Gaussian curve defined by its center (mu) and variance (sigma^2).
The grid is made up of m neurons for each of the n input dimensions.

┌───┐ ┌───┐ ┌───┐ ┌───┐     ┌───┐
│ 1 │ │ 2 │ │ 3 │ │ 4 │ ... │ m │ <- m neurons (m=5) for Dimension 0
└───┘ └───┘ └───┘ └───┘     └───┘
┌───┐ ┌───┐ ┌───┐ ┌───┐     ┌───┐
│ 1 │ │ 2 │ │ 3 │ │ 4 │ ... │ m │ <- m neurons (m=5) for Dimension 1
└───┘ └───┘ └───┘ └───┘     └───┘
 ...   ...   ...   ...       ...
┌───┐ ┌───┐ ┌───┐ ┌───┐     ┌───┐
│ 1 │ │ 2 │ │ 3 │ │ 4 │ ... │ m │ <- m neurons (m=5) for Dimension n
└───┘ └───┘ └───┘ └───┘     └───┘
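The sketch below computes one possible layout of the m receptive fields for a single input dimension. It assumes the spacing rule from Bohte et al. (2002): centers mu_i = x_min + (2i - 3)/2 * (x_max - x_min)/(m - 2) for i = 1..m, and a shared width sigma = (x_max - x_min) / (beta * (m - 2)) with beta = 1.5. Under these assumptions it reproduces the response values in the diagram that follows, but it is not necessarily how GaussianTuning is implemented; receptive_fields and gaussian_response are hypothetical names.

```python
import math


def receptive_fields(x_min, x_max, m, beta=1.5):
    """Centers and shared width of m Gaussian receptive fields on [x_min, x_max].

    Assumes Bohte et al.'s spacing rule; requires m > 2.
    """
    width = (x_max - x_min) / (m - 2)
    # i runs from 1 to m; the outermost centers lie just outside the input range.
    centers = [x_min + (2 * i - 3) / 2 * width for i in range(1, m + 1)]
    sigma = width / beta
    return centers, sigma


def gaussian_response(x, mu, sigma):
    """Response in (0, 1] of a neuron with center mu and width sigma to input x."""
    return math.exp(-((x - mu) ** 2) / (2 * sigma ** 2))


centers, sigma = receptive_fields(0.0, 1.0, m=4)
print([round(c, 2) for c in centers])  # [-0.25, 0.25, 0.75, 1.25]
print([round(gaussian_response(0.1, mu, sigma), 4) for mu in centers])
# [0.5762, 0.9037, 0.1494, 0.0026]
```

Placing the outermost centers outside [x_min, x_max] lets inputs near the range edges still excite more than one neuron strongly.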

Each neuron computes its response to the input from its Gaussian tuning curve (here n = 1, m = 4 and max_spike_time = 100, with inputs 0.1, 0.5 and 0.9 as in the first example):

     0.1      0.1      0.1      0.1
     0.5      0.5      0.5      0.5     <-- input values
     0.9      0.9      0.9      0.9
      │        │        │        │
      ▼        ▼        ▼        ▼
    ┌───┐    ┌───┐    ┌───┐    ┌───┐
    │ 1 │    │ 2 │    │ 3 │    │ 4 │    <-- 4 neurons
    └───┘    └───┘    └───┘    └───┘
      │        │        │        │
      ▼        ▼        ▼        ▼
    0.5762   0.9037   0.1494   0.0026
    0.0796   0.7548   0.7548   0.0796   <-- responses (between 0 and 1)
    0.0026   0.1494   0.9037   0.5762
      │        │        │        │
      ▼        ▼        ▼        ▼
      42       10       85       -1
      92       25       25       92     <-- spike times
      -1       85       10       42

lower response  -> later spike time
higher response -> earlier spike time

If the spike time >= max_spike_time, the neuron becomes inactive (no spike, -1).
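One mapping consistent with every value in the diagram is spike_time = round(max_spike_time * (1 - response)), with -1 whenever the rounded time reaches max_spike_time. The sketch below applies that rule to the first row of responses; the rule is inferred from the example numbers and may differ from the library's exact implementation, and response_to_spike_time is a hypothetical name.

```python
def response_to_spike_time(response, max_spike_time):
    """Map a response in [0, 1] to a spike time; -1 marks an inactive neuron.

    Assumed rule: a response of 1 fires at t = 0 and smaller responses fire later.
    """
    t = round(max_spike_time * (1.0 - response))
    # Spikes at or beyond the window limit are dropped.
    return -1 if t >= max_spike_time else t


responses = [0.5762, 0.9037, 0.1494, 0.0026]
print([response_to_spike_time(r, max_spike_time=100) for r in responses])
# [42, 10, 85, -1]
```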

Examples

>>> import torch
>>> from spikingjelly.timing_based import encoding
>>> x_min = torch.tensor([0.0, 0.0, 0.0])
>>> x_max = torch.tensor([1.0, 1.0, 1.0])
>>> encoder = encoding.GaussianTuning(n=3, m=5, x_min=x_min, x_max=x_max)
>>> x = torch.tensor([[[0.1, 0.5, 0.9], [0.2, 0.6, 0.8], [0.3, 0.7, 0.4]]])
>>> spikes = encoder.encode(x, max_spike_time=100)
>>> print(spikes)
tensor([[[[51.,  4., 80., -1., -1.],
          [99., 68.,  0., 68., 99.],
          [-1., -1., 80.,  4., 51.]],

         [[74.,  1., 60., 98., -1.],
          [-1., 85., 10., 42., 96.],
          [-1., 98., 60.,  1., 74.]],

         [[89., 16., 33., 94., -1.],
          [-1., 94., 33., 16., 89.],
          [96., 42., 10., 85., -1.]]]])
>>> print(spikes.shape)  # (batch_size, channels_count, samples_count, neurons_count)
torch.Size([1, 3, 3, 5])

encode(x: Tensor, max_spike_time: int = 50) → Tensor
Parameters:
  • x (torch.Tensor) -- Input tensor of shape (batch_size, channels_count, samples_count).

  • max_spike_time (int) -- The maximum spike time. Neurons whose spike time reaches or exceeds this value become inactive (-1). Defaults to 50.

Returns:

Encoded spike times of shape (batch_size, channels_count, samples_count, neurons_count), with -1 marking inactive neurons.

Return type:

torch.Tensor

Raises:

AssertionError -- If the input tensor x does not have shape (batch_size, channels_count, samples_count).
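Putting the pieces together, the sketch below encodes a nested list of shape (batch_size, channels_count, samples_count) into spike times of shape (batch_size, channels_count, samples_count, neurons_count) and asserts on the input shape as encode does. It assumes Bohte et al.'s center and width rule (beta = 1.5) and the rounding rule spike_time = round(max_spike_time * (1 - response)); under these assumptions it reproduces the output of the first example above. It uses plain Python instead of torch, and encode_sketch is a hypothetical name, not the spikingjelly implementation.

```python
import math


def encode_sketch(x, x_min, x_max, m, max_spike_time=50, beta=1.5):
    """Encode x[batch][channel][sample] into x[batch][channel][sample][neuron] spike times.

    Requires m > 2. -1 marks an inactive neuron.
    """
    n = len(x_min)
    # Mirror the documented AssertionError on a non-3D input; requiring exactly
    # n channels is an extra assumption of this sketch.
    assert all(
        isinstance(b, list)
        and len(b) == n
        and all(isinstance(c, list) and all(isinstance(v, (int, float)) for v in c) for c in b)
        for b in x
    ), "x must have shape (batch_size, channels_count, samples_count)"

    # Per-channel receptive fields: m centers and one shared width sigma.
    fields = []
    for lo, hi in zip(x_min, x_max):
        width = (hi - lo) / (m - 2)
        centers = [lo + (2 * i - 3) / 2 * width for i in range(1, m + 1)]
        fields.append((centers, width / beta))

    def spike_time(value, mu, sigma):
        # Gaussian response in (0, 1], then the inferred rounding rule.
        r = math.exp(-((value - mu) ** 2) / (2 * sigma ** 2))
        t = round(max_spike_time * (1.0 - r))
        return -1 if t >= max_spike_time else t

    return [
        [
            [[spike_time(v, mu, sigma) for mu in centers] for v in samples]
            for samples, (centers, sigma) in zip(batch, fields)
        ]
        for batch in x
    ]


print(encode_sketch([[[0.1, 0.5, 0.9]]], x_min=[0.0], x_max=[1.0], m=4, max_spike_time=100))
# [[[[42, 10, 85, -1], [92, 25, 25, 92], [-1, 85, 10, 42]]]]
```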

Neuron

class spikingjelly.timing_based.neuron.Tempotron(in_features, out_features, T, tau=15.0, tau_s=3.75, v_threshold=1.0)

Bases: Module

Tempotron is a Leaky Integrate-and-Fire (LIF) neuron model that receives spikes from sensory neurons and learns to classify spatiotemporal patterns of those spikes.

Reference:
Gütig R, Sompolinsky H.
The tempotron: a neuron that learns spike timing-based decisions.
Nat Neurosci.
2006 Mar;9(3):420-8.
DOI: 10.1038/nn1643.
Epub 2006 Feb 12.
PMID: 16474393.

Neuronal Simulation:

                 ┌─────────────────────┐
Time:            │0 1 2 3 4 5 6 7 8 T-1│10 11 12 13 14 15 16 17 18 19 20 ......
Sensory Neuron 1:│------------|--------│---------------------------------------
Sensory Neuron 2:│----|----------------│---------------------------------------
                 └─────────────────────┘
                       Time Window

The Tempotron neuron receives the spike times of sensory neurons within a defined time window and learns to classify different spatiotemporal patterns of those spikes.

The Tempotron does not use the rate of incoming spikes; it uses their precise timing and spatial arrangement (which sensory neuron fired, and when).

For example:

-|--|------|--|----  vs  -|-----|-------|-|- (single neuron)

Both spike trains contain the same number of spikes (four) but have different timing patterns.
Spike patterns can also span multiple neurons.
Parameters:
  • in_features (int) -- Number of input (sensory) neurons.

  • out_features (int) -- Number of output (Tempotron) neurons.

  • T (int) -- Length of the time window over which spike patterns are considered.

  • tau (float, optional) -- Membrane decay time constant, by default 15.0.

  • tau_s (float, optional) -- Synaptic current time constant, by default 3.75 (= 15.0 / 4).

  • v_threshold (float, optional) -- Membrane threshold voltage, by default 1.0.
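The parameters tau, tau_s and v_threshold enter through the postsynaptic potential (PSP) kernel of Gütig and Sompolinsky (2006): each input spike at t_i adds w_i * K(t - t_i) to the membrane voltage, where K(t) = V0 * (exp(-t / tau) - exp(-t / tau_s)) for t >= 0 and V0 scales the kernel peak to 1. The sketch below follows that paper for one output neuron with a resting potential of 0. It is an assumption that spikingjelly's Tempotron discretizes the voltage the same way, the -1 "no spike" convention is borrowed from the encoder, and psp_kernel and membrane_trace are hypothetical names. It also illustrates the timing point above: two inputs with the same spike count reach different peak voltages.

```python
import math


def psp_kernel(t, tau=15.0, tau_s=3.75):
    """Normalized PSP kernel K(t); zero before the input spike (assumes tau != tau_s)."""
    if t < 0:
        return 0.0
    # exp(-t/tau) - exp(-t/tau_s) peaks at t_peak; V0 rescales that peak to 1.
    t_peak = tau * tau_s * math.log(tau / tau_s) / (tau - tau_s)
    v0 = 1.0 / (math.exp(-t_peak / tau) - math.exp(-t_peak / tau_s))
    return v0 * (math.exp(-t / tau) - math.exp(-t / tau_s))


def membrane_trace(spike_times, weights, T, tau=15.0, tau_s=3.75):
    """Voltage V(t) for t = 0..T-1 of one output neuron, with resting potential 0.

    spike_times[i] is the single spike time of input neuron i, or -1 for no spike.
    """
    trace = []
    for t in range(T):
        v = sum(
            w * psp_kernel(t - t_i, tau, tau_s)
            for t_i, w in zip(spike_times, weights)
            if t_i >= 0
        )
        trace.append(v)
    return trace


weights = [0.6, 0.6]
coincident = membrane_trace([0, 0], weights, T=40)  # both inputs fire at t = 0
spread = membrane_trace([0, 15], weights, T=40)  # same spike count, different timing
v_threshold = 1.0
print(max(coincident) >= v_threshold, max(spread) >= v_threshold)  # True False
```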

forward(in_spikes, ret_type)
Parameters:
  • in_spikes (torch.Tensor) -- Shape: (batch_size, in_neurons_count). The spike timings from the sensory neurons.

  • ret_type (str) -- The type of output to return: one of 'v', 'v_max', or 'spikes'.

Returns:

The output selected by ret_type.

Return type:

torch.Tensor

Raises:

ValueError -- If an invalid ret_type is provided. It should be one of 'v', 'v_max', or 'spikes'.
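The three ret_type options can be pictured as different views of one neuron's voltage trace. The sketch below is a hypothetical illustration, not spikingjelly's code: it assumes 'v' is the trace over the time window, 'v_max' its maximum, and 'spikes' the first time step at which the voltage reaches v_threshold (-1 if it never does, borrowing the encoder's convention). The actual shapes and spike semantics of Tempotron.forward may differ, and select_output is a hypothetical name.

```python
def select_output(trace, ret_type, v_threshold=1.0):
    """Hypothetical: derive the documented ret_type outputs from one voltage trace."""
    if ret_type == "v":
        return list(trace)  # voltage at every time step
    if ret_type == "v_max":
        return max(trace)  # peak voltage within the window
    if ret_type == "spikes":
        # First threshold crossing; the tempotron's decision only depends on
        # whether the threshold is crossed at all.
        for t, v in enumerate(trace):
            if v >= v_threshold:
                return t
        return -1
    raise ValueError("ret_type should be one of 'v', 'v_max', or 'spikes'")


trace = [0.0, 0.4, 0.9, 1.1, 0.7]
print(select_output(trace, "v_max"), select_output(trace, "spikes"))  # 1.1 3
```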

mse_loss(v_max, label)

Mean Squared Error Loss for Tempotron Neuron

wrong_mask: Identifies neurons that misclassified the input. A neuron is considered to have misclassified if:
  - It fired (v_max >= threshold) when it shouldn't have (not the correct class).
  - It didn't fire (v_max < threshold) when it should have (the correct class).

loss: Computes the squared difference between v_max and the threshold voltage for the misclassified neurons, averaged over the batch size.

Parameters:
  • v_max (torch.Tensor) -- Shape: (batch_size, out_neurons_count). The maximum voltage of each output Tempotron neuron within the time window.

  • label (torch.Tensor) -- Shape: (batch_size,). The true class labels for the input data.

Return type:

torch.Tensor
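A plain-Python sketch of the loss described above, for illustration only. It assumes a threshold of 1.0 (the Tempotron default), that label is the index of the output neuron for the correct class, and that squared errors are summed over misclassified neurons and then divided by the batch size; spikingjelly's own normalization may differ. tempotron_mse_loss is a hypothetical name.

```python
def tempotron_mse_loss(v_max, label, v_threshold=1.0):
    """Hypothetical sketch: squared threshold error of misclassified neurons, per sample."""
    total = 0.0
    for row, correct in zip(v_max, label):
        for k, v in enumerate(row):
            fired = v >= v_threshold
            should_fire = k == correct
            # wrong_mask: fired for the wrong class, or stayed silent for the right one.
            if fired != should_fire:
                total += (v - v_threshold) ** 2
    return total / len(v_max)


v_max = [[1.2, 0.5], [0.8, 0.3]]  # batch of 2 samples, 2 output neurons
label = [0, 0]
print(round(tempotron_mse_loss(v_max, label), 4))  # 0.02
```

Only the second sample contributes: neuron 0 should fire but peaks at 0.8, giving (0.8 - 1.0)^2 = 0.04, which is then divided by the batch size of 2.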

Examples