# Source code for spikingjelly.activation_based.triton_kernel.compress

import torch

try:
    import triton
    import triton.language as tl
except BaseException as e:
    import logging
    from . import dummy

    logging.info(f"spikingjelly.activation_based.triton_kernel.compress: {e}")
    triton = dummy.DummyImport()
    tl = dummy.DummyImport()

from .triton_utils import contiguous_and_device_guard

__all__ = ["bit_spike_compress", "bit_spike_decompress"]


def _bit_spike_compress_pytorch(s_seq: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch bit-packing used when the input is not on a CUDA device.

    The input is flattened; element ``8 * k + b`` becomes bit ``b`` of output
    byte ``k``. Any nonzero value counts as a spike (``Tensor.to(torch.bool)``
    semantics). Returns a 1-D ``uint8`` tensor of length
    ``ceil(s_seq.numel() / 8)`` on the input's device.
    """
    flat_bits = s_seq.reshape(-1).to(dtype=torch.bool).to(dtype=torch.uint8)

    # Zero-pad to a whole number of bytes so every group of 8 is complete.
    tail = (-flat_bits.numel()) % 8
    if tail:
        flat_bits = torch.cat((flat_bits, flat_bits.new_zeros(tail)))

    # Place value of each bit position within a byte: 1, 2, 4, ..., 128.
    place_values = torch.tensor(
        [1 << bit for bit in range(8)], dtype=torch.uint8, device=flat_bits.device
    )
    # Each row holds at most one copy of each place value, so the sum is <= 255.
    packed = (flat_bits.reshape(-1, 8) * place_values).sum(dim=1)
    return packed.to(dtype=torch.uint8)


def _bit_spike_decompress_pytorch(
    s_seq_compressed: torch.Tensor, shape
) -> torch.Tensor:
    """Pure-PyTorch inverse of the bit-packing, used off CUDA.

    Bit ``b`` of packed byte ``k`` becomes element ``8 * k + b`` of the
    flattened result, which is truncated to ``prod(shape)`` elements and
    reshaped to ``shape``. The result is a ``uint8`` tensor of 0/1 values on
    the same device as ``s_seq_compressed``.
    """
    total = torch.Size(shape).numel()
    shifts = torch.arange(8, device=s_seq_compressed.device)

    # Broadcast every byte against the 8 shift amounts -> (n_bytes, 8) bits;
    # row-major flattening then yields bit b of byte k at index 8 * k + b.
    bits = (s_seq_compressed.unsqueeze(-1) >> shifts) & 1
    return bits.reshape(-1)[:total].to(dtype=torch.uint8).reshape(shape)


@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE": b}) for b in [64, 128, 256]],
    key=[],
    restore_value=["s_seq_compressed_ptr"],
)
@triton.jit
def _bit_spike_compress_triton(
    s_seq_ptr,  # spike values (float32 from the Python wrapper), ideally 0 or 1
    s_seq_compressed_ptr,  # uint8 output, one byte per 8 input elements
    n_elements,
    n_compressed_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program produces BLOCK_SIZE output bytes. Output byte k gathers
    # input elements 8*k .. 8*k+7, with element 8*k+i stored in bit i.
    pid = tl.program_id(0)
    store_offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    store_mask = store_offsets < n_compressed_elements

    s_seq_compressed = tl.zeros([BLOCK_SIZE], dtype=tl.uint8)

    for i in tl.static_range(8):
        load_offsets = i + store_offsets * 8
        load_mask = load_offsets < n_elements
        s_seq = tl.load(s_seq_ptr + load_offsets, mask=load_mask, other=0.0)
        # Binarize with ``!= 0`` (same rule as ``Tensor.to(torch.bool)``).
        # A plain ``.to(tl.uint8)`` would truncate fractional values and turn
        # e.g. 2.0 into 0b10, which after ``<< i`` sets bit i+1 and corrupts
        # the neighbouring spike.
        spike = (s_seq != 0).to(tl.uint8)
        s_seq_compressed = s_seq_compressed | (spike << i)

    tl.store(s_seq_compressed_ptr + store_offsets, s_seq_compressed, mask=store_mask)


@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE": b}) for b in [64, 128, 256]],
    key=[],
    restore_value=["s_seq_decompressed_ptr"],
)
@triton.jit
def _bit_spike_decompress_triton(
    s_seq_compressed_ptr,
    s_seq_decompressed_ptr,
    n_compressed_elements,
    n_decompressed_elements,
    BLOCK_SIZE: tl.constexpr,  # tl.arange needs a power of two (configs: 64/128/256)
):
    # Each program unpacks BLOCK_SIZE packed bytes: bit ``b`` of packed byte
    # ``k`` is written to output position ``8 * k + b``.
    byte_idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    packed = tl.load(
        s_seq_compressed_ptr + byte_idx,
        mask=byte_idx < n_compressed_elements,
        other=0,
    )

    # First output position fed by each packed byte.
    out_base = byte_idx * 8
    for b in tl.static_range(8):
        out_idx = out_base + b
        bit = (packed >> b) & 1
        tl.store(
            s_seq_decompressed_ptr + out_idx,
            bit,
            mask=out_idx < n_decompressed_elements,
        )


@contiguous_and_device_guard
def bit_spike_compress(s_seq):
    """Bit-pack a binary spike tensor into a compact ``uint8`` tensor.

    .. _bit_spike_compress-cn:

    .. _bit_spike_compress-en:

    The input is flattened and every 8 consecutive elements are packed into
    one byte: element ``8 * k + i`` becomes bit ``i`` of output byte ``k``.
    The last byte is zero-padded when ``s_seq.numel()`` is not a multiple
    of 8. CUDA tensors are packed by a Triton kernel; tensors on other
    devices use a pure-PyTorch fallback.

    Values are expected to be exactly 0 or 1 (binary spikes); other values
    are not guaranteed to survive a compress/decompress round trip.

    :param s_seq: spike tensor of any shape (typically ``float32``)
    :type s_seq: torch.Tensor
    :return: 1-D ``uint8`` tensor of length ``ceil(s_seq.numel() / 8)`` on the
        same device as ``s_seq`` (about 8x fewer elements than the input)
    :rtype: torch.Tensor
    """
    s_seq = s_seq.reshape(-1)
    if s_seq.device.type != "cuda":
        return _bit_spike_compress_pytorch(s_seq)

    n_elements = s_seq.numel()
    n_compressed_elements = (n_elements + 7) // 8
    # The kernel stores every byte below n_compressed_elements, so the output
    # does not need to be zero-filled first.
    s_seq_compressed = torch.empty(
        n_compressed_elements, dtype=torch.uint8, device=s_seq.device
    )
    if n_compressed_elements == 0:
        # Nothing to pack: avoid launching (and autotuning) an empty grid.
        return s_seq_compressed

    grid = lambda meta: (triton.cdiv(n_compressed_elements, meta["BLOCK_SIZE"]),)
    with torch.cuda.device(s_seq.device):
        _bit_spike_compress_triton[grid](
            s_seq,
            s_seq_compressed,
            n_elements,
            n_compressed_elements,
        )
    return s_seq_compressed


@contiguous_and_device_guard
def bit_spike_decompress(s_seq_compressed, shape):
    """Unpack a bit-packed ``uint8`` tensor back into a spike tensor.

    .. _bit_spike_decompress-cn:

    .. _bit_spike_decompress-en:

    Inverse of :func:`bit_spike_compress`: bit ``i`` of packed byte ``k``
    becomes element ``8 * k + i`` of the flattened result, which is then
    reshaped to ``shape``. CUDA tensors are unpacked by a Triton kernel;
    tensors on other devices use a pure-PyTorch fallback.

    :param s_seq_compressed: packed tensor returned by
        :func:`bit_spike_compress`; assumed to hold at least
        ``ceil(prod(shape) / 8)`` bytes
    :type s_seq_compressed: torch.Tensor
    :param shape: shape of the original (uncompressed) tensor
    :type shape: tuple of int
    :return: tensor of shape ``shape`` holding 0/1 values with dtype
        ``torch.uint8`` (not floating point; cast with ``.to(dtype)`` if a
        float spike tensor is needed), on the same device as
        ``s_seq_compressed``
    :rtype: torch.Tensor
    """
    if s_seq_compressed.device.type != "cuda":
        return _bit_spike_decompress_pytorch(s_seq_compressed, shape)

    n_compressed_elements = s_seq_compressed.numel()
    n_decompressed_elements = torch.Size(shape).numel()
    # Zero-filled so that any positions not covered by the packed input
    # (when it is shorter than needed) read as 0 instead of garbage.
    s_seq_decompressed = torch.zeros(
        n_decompressed_elements, dtype=torch.uint8, device=s_seq_compressed.device
    )
    if n_compressed_elements == 0 or n_decompressed_elements == 0:
        # Nothing to unpack: avoid launching (and autotuning) a no-op kernel.
        return s_seq_decompressed.reshape(shape)

    grid = lambda meta: (triton.cdiv(n_compressed_elements, meta["BLOCK_SIZE"]),)
    with torch.cuda.device(s_seq_compressed.device):
        _bit_spike_decompress_triton[grid](
            s_seq_compressed,
            s_seq_decompressed,
            n_compressed_elements,
            n_decompressed_elements,
        )
    return s_seq_decompressed.reshape(shape)