spikingjelly.timing_based.encoding 源代码

from typing import Annotated

import einops
import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator


class _GaussianTuningConfig(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    neurons_count: Annotated[
        int,
        Field(
            description="The number of neurons per input channel.",
            gt=2,
        ),
    ] = 5

    max_spike_time: Annotated[
        int,
        Field(
            description="""
            The maximum spike time for the neurons
            beyond which neurons become inactive.
            """,
            gt=0,
        ),
    ] = 100

    beta: Annotated[
        float,
        Field(
            description="The sharpness of the tuning curves.",
            gt=0.0,
        ),
    ] = 1.5

    channels_count: Annotated[
        int,
        Field(
            description="The number of channels/dims in the input data.",
            gt=0,
        ),
    ] = 1

    input_min: Annotated[
        torch.Tensor,
        Field(
            description="The minimum value of the input data.",
        ),
    ] = torch.tensor(0.0)

    input_max: Annotated[
        torch.Tensor,
        Field(
            description="The maximum value of the input data.",
        ),
    ] = torch.tensor(1.0)

    @model_validator(mode="after")
    def sanity_check(self):
        if not self.input_min.shape == self.input_max.shape == (self.channels_count,):
            raise ValueError(
                f"input_min and input_max should have shape ({self.channels_count},),"
                f" but got shapes {self.input_min.shape} and {self.input_max.shape}."
            )
        if not torch.all(self.input_min < self.input_max):
            raise ValueError("All elements of input_min must be less than input_max.")

        return self


class _GaussianTuning:
    def __init__(self, cfg: dict, device: torch.device) -> None:
        self.cfg = _GaussianTuningConfig(**cfg) if cfg else _GaussianTuningConfig()
        self.device = device
        self._build_receptive_fields()

    def encode(self, x: torch.Tensor, max_spike_time: int = 50) -> torch.Tensor:
        # ---- Input validation -------------------------------------------------------
        assert x.dim() == 3 and x.shape[1] == self.cfg.channels_count, (
            "Input tensor x must be 3-dimensional"
            "(batch_size, channels_count, samples_count)."
            f"Got shape: {x.shape}."
        )
        # Due to reshaping, we need to know the original shape for output formatting
        x_shape_orig = x.shape
        # ---- Input reshaping --------------------------------------------------------
        x = einops.rearrange(x, "b c s -> (b s) c")
        x = einops.repeat(x, "bs c -> bs c n", n=self.cfg.neurons_count)
        # ---- Neuron Responses -------------------------------------------------------
        neuron_responses = self._calculate_neurons_response(
            x,
            self.neuron_centers,
            self.neuron_variances,
        )
        # ---- Spike time mapping -----------------------------------------------------
        # higher response -> earlier spike time | lower response -> later spike time
        max_spike_time = max_spike_time or self.cfg.max_spike_time
        spike_times = (max_spike_time * (1 - neuron_responses)).round()
        spike_times[spike_times >= max_spike_time] = -1  # inactive neurons
        # ---- Output formatting ------------------------------------------------------
        # Reshape back to (batch_size, samples_count, channels_count, neurons_count)
        batch_size = x_shape_orig[0]
        samples_count = x_shape_orig[2]
        out_spikes = einops.rearrange(
            spike_times,
            "(b s) c n -> b c s n",
            b=batch_size,
            s=samples_count,
        )

        return out_spikes

    def _build_receptive_fields(self) -> None:
        """
        Build the Gaussian receptive fields for each neuron in each input dimension.
        Each neuron is defined by its center (mu) and variance (sigma^2).
        """
        # A row of neuron indices [1, 2, ..., m] for each input dimension
        neuron_indices = (
            torch.arange(start=1, end=self.cfg.neurons_count + 1)  # A row of neurons
            .unsqueeze(0)  # add extra dimension to handle other input dimensions
            .repeat(self.cfg.channels_count, 1)  # repeat for each dimension
            .float()
            .to(self.device)
        )
        # Repeat input_min and input_max for each neurons row
        _input_min = self.cfg.input_min.unsqueeze(-1).repeat(1, self.cfg.neurons_count)
        _input_max = self.cfg.input_max.unsqueeze(-1).repeat(1, self.cfg.neurons_count)
        _input_range = _input_max - _input_min
        # Calculate the centers and variances for each neuron in each dimension
        self.neuron_centers = self._calculate_neurons_centers(
            _input_min,
            _input_range,
            neuron_indices,
            self.cfg.neurons_count,
        )
        self.neuron_variances = self._calculate_neurons_variances(
            _input_range,
            self.cfg.neurons_count,
            self.cfg.beta,
        )

    @staticmethod
    def _calculate_neurons_centers(
        in_min: torch.Tensor,
        in_range: torch.Tensor,
        indices: torch.Tensor,
        neurons_count: int,
    ) -> torch.Tensor:
        """
        mu = in_min + ((2 * indices - 3) / (2 * (neurons_count - 2))) * in_range
        where:
            indices: neuron index (1 to m)
            neurons_count: number of neurons
            in_min: minimum value of the input data
            in_range: range of the input data (in_max - in_min)
        """
        nominator = 2 * indices - 3
        denominator = 2 * (neurons_count - 2)
        centers = in_min + (nominator / denominator) * in_range

        return centers

    @staticmethod
    def _calculate_neurons_variances(
        in_range: torch.Tensor,
        neurons_count: int,
        beta: float,
    ) -> torch.Tensor:
        """
        sigma^2 = ( (in_range / (beta * (neurons_count - 2)) ) )^2
        where:
            neurons_count: number of neurons
            in_range: range of the input data (in_max - in_min)
            beta: sharpness of the tuning curves
        """
        nominator = in_range
        denominator = beta * (neurons_count - 2)
        variances = (nominator / denominator).square()

        return variances

    @staticmethod
    def _calculate_neurons_response(
        x: torch.Tensor,
        centers: torch.Tensor,
        variances: torch.Tensor,
    ) -> torch.Tensor:
        """
        neuron_response = exp( -( (x - centers)^2 / (2 * variances) ) )
        where:
            x: input data
            centers: neuron centers (mu)
            variances: neuron variances (sigma^2)
        """
        nominator = (x - centers).square()
        denominator = 2 * variances
        neuron_responses = torch.exp(-(nominator / denominator))

        return neuron_responses


# NOTE: Facade class for backward compatibility
[文档] class GaussianTuning: def __init__( self, n: int, m: int, x_min: torch.Tensor, x_max: torch.Tensor, ) -> None: """ **Gaussian Receptive Field Population Coding** All neurons are arranged in a grid where each row corresponds to an input dimension and each column corresponds to a neuron in that dimension. When an input value is given, each neuron computes its response based on its Gaussian tuning curve. The response is then mapped to a spike time, where a higher response results in an earlier spike time, and a lower response results in a later spike time. If the spike time exceeds a predefined maximum spike time, the neuron becomes inactive (no spike, represented by -1). Reference: | Sander M. Bohte, Joost N. Kok, Han La Poutré, | Error-backpropagation in temporally encoded networks of spiking neurons, | Neurocomputing, | Volume 48, Issues 1–4, | 2002, | Pages 17-37, | ISSN 0925-2312, | https://doi.org/10.1016/S0925-2312(01)00658-0. | (https://www.sciencedirect.com/science/article/pii/S0925231201006580) Neuron Spatial Receptive Field (Lecture): | https://youtu.be/fCqt07IXUPI?si=jVT-QlmEgrbQZkB2 Parameters ---------- n : int The number of channels/dims in the input data. m : int The number of neurons per input channel. x_min : torch.Tensor 1D tensor of shape (n,) representing the minimum value of the input data for each input dimension. x_max : torch.Tensor 1D tensor of shape (n,) representing the maximum value of the input data for each input dimension. Raises ------ ValueError If x_min and x_max do not have shape (n,) or if any element in x_min is not less than the corresponding element in x_max. Example ------- .. code-block:: python >>> import torch >>> from spikingjelly.timing_based import encoding >>> x_min = torch.tensor([0.0]) >>> x_max = torch.tensor([1.0]) >>> encoder = encoding.GaussianTuning(n=1, m=4, x_min=x_min, x_max=x_max) >>> x = torch.tensor([[[0.1, 0.5, 0.9]]]) >>> spikes = encoder.encode(x, max_spike_time=100) >>> print(spikes) tensor([[[[42., 10., 85., -1.], [92., 25., 25., 92.], [-1., 85., 10., 42.]]]]) >>> print(spikes.shape) torch.Size([1, 1, 3, 4]) # (batch_size, channels_count, samples_count, neurons_count) An array/grid of Neuronal Receptive Fields:: Each one is a Gaussian curve defined by its center (mu) and variance (sigma^2). The grid is made up of m neurons for each of the n input dimensions. ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ │ 1 │ │ 2 │ │ 3 │ │ 4 │ ... │ m │ <- m neurons (m=5) for Dimension 0 └───┘ └───┘ └───┘ └───┘ └───┘ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ │ 1 │ │ 2 │ │ 3 │ │ 4 │ ... │ m │ <- m neurons (m=5) for Dimension 1 └───┘ └───┘ └───┘ └───┘ └───┘ ... ... ... ... ... ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ │ 1 │ │ 2 │ │ 3 │ │ 4 │ ... │ m │ <- m neurons (m=5) for Dimension n └───┘ └───┘ └───┘ └───┘ └───┘ Each neuron computes the response to the input based on its Gaussian tuning curve:: 0.1 0.1 0.1 0.1 0.5 0.5 0.5 0.5 0.9 0.9 0.9 0.9 │ │ │ │ ▼ ▼ ▼ ▼ ┌───┐ ┌───┐ ┌───┐ ┌───┐ │ 1 │ │ 2 │ │ 3 │ │ 4 │ <-- 4 neurons └───┘ └───┘ └───┘ └───┘ │ │ │ │ ▼ ▼ ▼ ▼ 0.5762 0.9037 0.1494 0.0026 0.0796 0.7548 0.7548 0.0796 <-- responses (probability of firing) 0.0026 0.1494 0.9037 0.5762 │ │ │ │ ▼ ▼ ▼ ▼ 42 10 85 -1 92 25 25 92 <-- spike times -1 85 10 42 lower response -> later spike time higher response -> earlier spike time if the spike time >= max_spike_time, neuron becomes inactive (no spike, -1) Example ------- .. code-block:: python >>> x_min = torch.tensor([0.0, 0.0, 0.0]) >>> x_max = torch.tensor([1.0, 1.0, 1.0]) >>> encoder = GaussianTuning(n=3, m=5, x_min=x_min, x_max=x_max) >>> x = torch.tensor([[[0.1, 0.5, 0.9], [0.2, 0.6, 0.8], [0.3, 0.7, 0.4]]]) >>> spikes = encoder.encode(x, max_spike_time=100) >>> print(spikes) tensor([[[[51., 4., 80., -1., -1.], [99., 68., 0., 68., 99.], [-1., -1., 80., 4., 51.]], [[74., 1., 60., 98., -1.], [-1., 85., 10., 42., 96.], [-1., 98., 60., 1., 74.]], [[89., 16., 33., 94., -1.], [-1., 94., 33., 16., 89.], [96., 42., 10., 85., -1.]]]]) >>> print(spikes.shape) torch.Size([1, 3, 3, 5]) # (batch_size, channels_count, samples_count, neurons_count) """ cfg = dict( neurons_count=m, channels_count=n, input_min=x_min, input_max=x_max, ) self._encoder = _GaussianTuning(cfg, device=x_min.device) self._set_attributes()
[文档] def encode( self, x: torch.Tensor, max_spike_time: int = 50, ) -> torch.Tensor: """ Parameters ---------- x : torch.Tensor Input tensor of shape (batch_size, channels_count, samples_count). max_spike_time : int The maximum spike time for the neurons beyond which neurons become inactive. Returns ------- torch.Tensor Encoded spike times of shape (batch_size, channels_count, samples_count, neurons_count). Raises ------ AssertionError If the input tensor x does not have shape (batch_size, channels_count, samples_count). """ encoded = self._encoder.encode(x, max_spike_time) return encoded
def _set_attributes(self) -> None: self.m = self._encoder.cfg.neurons_count self.n = self._encoder.cfg.channels_count self.mu = self._encoder.neuron_centers self.sigma2 = self._encoder.neuron_variances