# Source code for spikingjelly.activation_based.memopt.compress

import abc
import torch
import logging


# Names exported by ``from ... import *``: the abstract base class plus the
# five concrete spike compressors defined in this module.
__all__ = [
    "BaseSpikeCompressor",
    "NullSpikeCompressor",
    "BooleanSpikeCompressor",
    "Uint8SpikeCompressor",
    "BitSpikeCompressor",
    "SparseSpikeCompressor",
]

# Triton is an optional dependency. Any ordinary failure while importing it
# (missing package, broken install, driver error, ...) leaves ``triton`` set to
# ``None`` so the module can pick another backend. ``Exception`` is caught
# rather than ``BaseException`` so that ``KeyboardInterrupt`` and ``SystemExit``
# still propagate instead of being mistaken for "Triton is unavailable".
try:
    import triton
except Exception:
    triton = None

# Choose the implementation of ``bit_spike_compress`` / ``bit_spike_decompress``.
# ``triton`` is ``None`` when the optional import above failed.
if triton is not None:
    logging.info(
        "spikingjelly.activation_based.triton_kernel.compress: "
        "Use Triton backend for bit spike compression"
    )
    # Triton imported successfully: take both functions from the sibling
    # ``triton_kernel`` package.
    from ..triton_kernel import bit_spike_compress, bit_spike_decompress
else:
    # Triton unavailable: log the choice; the pure-PyTorch definitions of both
    # functions follow inside this branch.
    logging.info(
        "spikingjelly.activation_based.triton_kernel.compress: "
        "Use PyTorch backend for bit spike compression"
    )

    def bit_spike_compress(s_seq: torch.Tensor) -> torch.Tensor:
        """Pack a spike tensor into bytes, eight elements per ``uint8``.

        The input is cast to ``bool`` and flattened. Flat element ``k`` is
        stored in bit ``k % 8`` (least-significant bit first) of output byte
        ``k // 8``. When the element count is not a multiple of 8, the unused
        high bits of the last byte stay zero.

        :param s_seq: spike tensor of any shape; non-zero entries count as 1
        :type s_seq: torch.Tensor
        :return: 1-D ``uint8`` tensor of length ``ceil(numel / 8)`` on the
            same device as ``s_seq``
        :rtype: torch.Tensor
        """
        flat_bits = s_seq.to(dtype=torch.bool).reshape(-1)
        num_bytes = (flat_bits.numel() + 7) // 8
        packed = torch.zeros(num_bytes, dtype=torch.uint8, device=flat_bits.device)
        for bit in range(8):
            # Every 8th element starting at ``bit`` lands in the same bit
            # position of consecutive output bytes.
            lane = flat_bits[bit::8].to(dtype=torch.uint8)
            lane_len = lane.numel()
            if lane_len == 0:
                continue
            packed[:lane_len].bitwise_or_(lane << bit)
        return packed

    def bit_spike_decompress(s_seq_compressed: torch.Tensor, shape) -> torch.Tensor:
        """Unpack bytes produced by ``bit_spike_compress`` into a ``bool`` tensor.

        Bit ``i`` of byte ``j`` becomes flat element ``8 * j + i``; only the
        first ``prod(shape)`` elements are produced.

        :param s_seq_compressed: packed ``uint8`` tensor
        :type s_seq_compressed: torch.Tensor
        :param shape: shape of the tensor before compression
        :type shape: tuple or torch.Size
        :return: ``bool`` tensor of the requested shape on the same device as
            ``s_seq_compressed``
        :rtype: torch.Tensor
        """
        total = torch.Size(shape).numel()
        unpacked = torch.zeros(
            total, dtype=torch.bool, device=s_seq_compressed.device
        )
        for bit in range(8):
            # Number of flat positions bit, bit + 8, bit + 16, ... below total.
            lane_len = (total - bit + 7) // 8
            lane = (s_seq_compressed >> bit) & 1
            unpacked[bit::8] = lane[:lane_len]
        return unpacked.reshape(shape)


class BaseSpikeCompressor(abc.ABC):
    # Class-level flag; the compressors in this module that document a
    # 0/1-only input requirement override it with True.
    requires_strictly_binary = False

    def __init__(self):
        r"""
        Abstract base class for spike compressors.

        To implement a spike compressor, inherit from this class and provide
        the ``_compress`` and ``_decompress`` methods. The public
        ``compress`` / ``decompress`` entry points call those hooks inside
        ``torch.no_grad()``.

        :return: None
        :rtype: None
        """

    @abc.abstractmethod
    def _compress(self, s_seq: torch.Tensor) -> torch.Tensor:
        """Subclass hook that performs the actual compression."""

    @abc.abstractmethod
    def _decompress(self, s_seq: torch.Tensor, shape) -> torch.Tensor:
        """Subclass hook that performs the actual decompression."""

    def compress(self, s_seq: torch.Tensor) -> torch.Tensor:
        r"""
        Compress a spike sequence.

        :param s_seq: input spike sequence
        :type s_seq: torch.Tensor
        :return: compressed spike sequence
        :rtype: torch.Tensor
        """
        # Gradient tracking is disabled while the subclass hook runs.
        with torch.no_grad():
            compressed = self._compress(s_seq)
        return compressed

    def decompress(self, s_seq: torch.Tensor, shape) -> torch.Tensor:
        r"""
        Decompress a spike sequence.

        :param s_seq: compressed spike sequence
        :type s_seq: torch.Tensor
        :param shape: original shape
        :type shape: tuple or torch.Size
        :return: decompressed spike sequence
        :rtype: torch.Tensor
        """
        # Gradient tracking is disabled while the subclass hook runs.
        with torch.no_grad():
            restored = self._decompress(s_seq, shape)
        return restored
class NullSpikeCompressor(BaseSpikeCompressor):
    requires_strictly_binary = False

    def __init__(self):
        r"""
        Null spike compressor.

        Compression and decompression are both identity mappings, so this
        compressor handles non-binary tensors losslessly. For instance, the
        input layer of an SNN should use ``NullSpikeCompressor`` because its
        input is a float tensor rather than a binary one.

        Example:

        .. code-block:: python

            import torch
            from spikingjelly.activation_based.memopt.compress import (
                NullSpikeCompressor,
            )

            compressor = NullSpikeCompressor()
            x = torch.randn(32, 10)
            compressed = compressor.compress(x)
            decompressed = compressor.decompress(compressed, x.shape)

        :return: None
        :rtype: None
        """
        super().__init__()

    def _compress(self, s_seq: torch.Tensor) -> torch.Tensor:
        # Identity: the very same tensor object is handed back.
        return s_seq

    def _decompress(self, s_seq: torch.Tensor, shape) -> torch.Tensor:
        # Identity: ``shape`` is accepted for interface parity but not used.
        return s_seq
class BooleanSpikeCompressor(BaseSpikeCompressor):
    requires_strictly_binary = True

    def __init__(self):
        r"""
        Boolean spike compressor.

        Stores spike sequences as ``torch.bool`` to save memory. The input
        must be strictly binary: the cast to ``bool`` maps every non-zero
        value to ``True``.

        Example:

        .. code-block:: python

            import torch
            from spikingjelly.activation_based.memopt.compress import (
                BooleanSpikeCompressor,
            )

            compressor = BooleanSpikeCompressor()
            spikes = torch.randint(0, 2, (32, 100)).float()
            compressed = compressor.compress(spikes)
            decompressed = compressor.decompress(compressed, spikes.shape)

        :return: None
        :rtype: None
        """
        super().__init__()
        # dtype restored on decompression; overwritten by each ``_compress``.
        self.s_seq_dtype = torch.float32

    def _compress(self, s_seq: torch.Tensor) -> torch.Tensor:
        # Remember the source dtype so ``_decompress`` can cast back to it.
        self.s_seq_dtype = s_seq.dtype
        as_bool = s_seq.to(dtype=torch.bool)
        return as_bool

    def _decompress(self, s_seq: torch.Tensor, shape) -> torch.Tensor:
        restored = s_seq.to(dtype=self.s_seq_dtype)
        return restored.reshape(shape)
class Uint8SpikeCompressor(BaseSpikeCompressor):
    requires_strictly_binary = False

    def __init__(self):
        r"""
        Uint8 spike compressor.

        Stores spike sequences as ``torch.uint8`` to save memory. Unlike the
        boolean compressor it can hold non-binary integer values, but only
        integers representable in ``uint8`` (0 to 255) survive the round
        trip.

        Example:

        .. code-block:: python

            import torch
            from spikingjelly.activation_based.memopt.compress import (
                Uint8SpikeCompressor,
            )

            compressor = Uint8SpikeCompressor()
            x = torch.randint(0, 4, (32, 10)).float()
            compressed = compressor.compress(x)
            decompressed = compressor.decompress(compressed, x.shape)

        :return: None
        :rtype: None
        """
        super().__init__()
        # dtype restored on decompression; overwritten by each ``_compress``.
        self.s_seq_dtype = torch.float32

    def _compress(self, s_seq: torch.Tensor) -> torch.Tensor:
        # Remember the source dtype so ``_decompress`` can cast back to it.
        self.s_seq_dtype = s_seq.dtype
        as_uint8 = s_seq.to(dtype=torch.uint8)
        return as_uint8

    def _decompress(self, s_seq: torch.Tensor, shape) -> torch.Tensor:
        restored = s_seq.to(dtype=self.s_seq_dtype)
        return restored.reshape(shape)
class BitSpikeCompressor(BaseSpikeCompressor):
    requires_strictly_binary = True

    def __init__(self):
        r"""
        Bit-level spike compressor.

        Packs 8 binary spikes into one byte via ``bit_spike_compress``,
        giving a high memory compression ratio. The input must be strictly
        binary spikes (0 or 1).

        The dtype of the most recently compressed tensor is remembered and
        restored on decompression, matching the other dtype-converting
        compressors in this module.

        Example:

        .. code-block:: python

            import torch
            from spikingjelly.activation_based.memopt.compress import BitSpikeCompressor

            compressor = BitSpikeCompressor()
            spikes = torch.randint(0, 2, (32, 1000)).float()
            compressed = compressor.compress(spikes)
            decompressed = compressor.decompress(compressed, spikes.shape)

        :return: None
        :rtype: None
        """
        super().__init__()
        # dtype restored on decompression; overwritten by each ``_compress``.
        self.s_seq_dtype = torch.float32

    def _compress(self, s_seq: torch.Tensor) -> torch.Tensor:
        # Record the source dtype; without this, ``_decompress`` would always
        # return float32 even for e.g. float16 inputs.
        self.s_seq_dtype = s_seq.dtype
        return bit_spike_compress(s_seq)

    def _decompress(self, s_seq: torch.Tensor, shape) -> torch.Tensor:
        s_seq_decompressed = bit_spike_decompress(s_seq, shape)
        return s_seq_decompressed.to(dtype=self.s_seq_dtype)
class SparseSpikeCompressor(BaseSpikeCompressor):
    requires_strictly_binary = True

    def __init__(self, dtype=torch.int64):
        r"""
        Sparse spike compressor.

        Stores only the flat position indices of the non-zero spikes, which
        suits sparse spike sequences. The input must be strictly binary
        spikes (0 or 1): decompression writes 1 at every stored position.

        Example:

        .. code-block:: python

            import torch
            from spikingjelly.activation_based.memopt.compress import (
                SparseSpikeCompressor,
            )

            compressor = SparseSpikeCompressor()
            spikes = (torch.rand(32, 1000) < 0.04).float()
            compressed = compressor.compress(spikes)
            decompressed = compressor.decompress(compressed, spikes.shape)

        :param dtype: index data type. Defaults to ``torch.int64``
        :type dtype: torch.dtype
        :return: None
        :rtype: None
        """
        super().__init__()
        self.dtype = dtype
        # dtype restored on decompression; overwritten by each ``_compress``.
        self.s_seq_dtype = torch.float32

    def _compress(self, s_seq: torch.Tensor) -> torch.Tensor:
        # Remember the source dtype so ``_decompress`` can rebuild it.
        self.s_seq_dtype = s_seq.dtype
        flat = s_seq.reshape(-1)
        positions = torch.nonzero(flat)
        return positions.to(dtype=self.dtype)

    def _decompress(self, s_seq: torch.Tensor, shape) -> torch.Tensor:
        numel = torch.Size(shape).numel()
        dense = torch.zeros(numel, dtype=self.s_seq_dtype, device=s_seq.device)
        # scatter_ needs int64 indices regardless of the storage dtype.
        positions = s_seq.to(dtype=torch.int64).reshape(-1)
        dense.scatter_(dim=0, index=positions, value=1)
        return dense.reshape(shape)