import torch
try:
import triton
import triton.language as tl
except BaseException as e:
import logging
from . import dummy
logging.info(f"spikingjelly.activation_based.triton_kernel.compress: {e}")
triton = dummy.DummyImport()
tl = dummy.DummyImport()
from .triton_utils import contiguous_and_device_guard
__all__ = ["bit_spike_compress", "bit_spike_decompress"]
def _bit_spike_compress_pytorch(s_seq: torch.Tensor) -> torch.Tensor:
    """Pack a spike tensor into bytes (eight spikes per ``uint8``) with plain PyTorch ops.

    The input is flattened and converted to ``bool``, so every non-zero entry
    counts as a spike. Flattened element ``8 * j + k`` ends up in bit ``k`` of
    output byte ``j``; the unused high bits of the last byte are zero.

    :param s_seq: spike tensor of any shape
    :type s_seq: torch.Tensor
    :return: 1-D ``uint8`` tensor with ``ceil(numel / 8)`` elements, on the
        same device as ``s_seq``
    :rtype: torch.Tensor
    """
    spikes = s_seq.to(dtype=torch.bool).reshape(-1).to(dtype=torch.uint8)
    n_spikes = spikes.numel()
    n_bytes = (n_spikes + 7) // 8

    # Zero-pad to a whole number of bytes so the bits can be laid out as one
    # row of eight per output byte.
    padded = torch.zeros(n_bytes * 8, dtype=torch.uint8, device=spikes.device)
    padded[:n_spikes] = spikes
    lanes = padded.reshape(n_bytes, 8)

    # Column k carries weight 2**k. The entries are 0/1, so every product and
    # every row sum fits in a byte.
    weights = torch.tensor(
        [1 << k for k in range(8)], dtype=torch.uint8, device=spikes.device
    )
    return (lanes * weights).sum(dim=1).to(dtype=torch.uint8)
def _bit_spike_decompress_pytorch(
    s_seq_compressed: torch.Tensor, shape
) -> torch.Tensor:
    """Unpack a bit-packed ``uint8`` tensor into one element per bit with plain PyTorch ops.

    Bit ``k`` of byte ``j`` becomes flattened output element ``8 * j + k``.
    Only the first ``prod(shape)`` bits are kept.

    :param s_seq_compressed: 1-D packed tensor, one byte per eight spikes
    :type s_seq_compressed: torch.Tensor
    :param shape: shape of the unpacked tensor
    :type shape: tuple
    :return: ``uint8`` tensor of shape ``shape`` holding 0/1 values, on the
        same device as ``s_seq_compressed``
    :rtype: torch.Tensor
    """
    n_bits = torch.Size(shape).numel()
    n_bytes = (n_bits + 7) // 8
    packed = s_seq_compressed[:n_bytes]

    # One 0/1 lane per bit position. Stacking along dim 1 and flattening
    # interleaves the lanes back into the original element order.
    lanes = [(packed >> k) & 1 for k in range(8)]
    bits = torch.stack(lanes, dim=1).reshape(-1)[:n_bits]
    return bits.to(dtype=torch.uint8).reshape(shape)
@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE": b}) for b in [64, 128, 256]],
    key=[],
    restore_value=["s_seq_compressed_ptr"],
)
@triton.jit
def _bit_spike_compress_triton(
    s_seq_ptr,  # input spikes (original note: fp32, 0 or 1)
    s_seq_compressed_ptr,  # output buffer, one byte per eight spikes
    n_elements,  # number of input spikes
    n_compressed_elements,  # number of output bytes
    BLOCK_SIZE: tl.constexpr,  # output bytes produced by one program
):
    # Bit-packing kernel: each program produces BLOCK_SIZE output bytes.
    # Output byte j collects input elements 8*j .. 8*j+7, with element 8*j+k
    # stored in bit k.
    #
    # The program id is promoted to int64 so that all offsets are 64-bit.
    # With 32-bit offsets, `store_offsets * 8` wraps for inputs approaching
    # 2**31 elements; the wrapped (negative) offsets pass the
    # `< n_elements` mask and cause out-of-bounds loads.
    pid = tl.program_id(0).to(tl.int64)
    store_offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    store_mask = store_offsets < n_compressed_elements

    s_seq_compressed = tl.zeros([BLOCK_SIZE], dtype=tl.uint8)
    for i in tl.static_range(8):
        # Lane i reads the i-th spike of every byte handled by this program.
        load_offsets = i + store_offsets * 8
        load_mask = load_offsets < n_elements
        # Out-of-range positions read as 0 and leave their bit cleared.
        s_seq = tl.load(s_seq_ptr + load_offsets, mask=load_mask, other=0.0)
        s_seq = s_seq.to(tl.uint8)
        s_seq_compressed = s_seq_compressed | (s_seq << i)

    tl.store(s_seq_compressed_ptr + store_offsets, s_seq_compressed, mask=store_mask)
@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE": b}) for b in [64, 128, 256]],
    key=[],
    restore_value=["s_seq_decompressed_ptr"],
)
@triton.jit
def _bit_spike_decompress_triton(
    s_seq_compressed_ptr,  # input buffer, one byte per eight spikes
    s_seq_decompressed_ptr,  # output buffer, one element per spike
    n_compressed_elements,  # number of input bytes
    n_decompressed_elements,  # number of output spikes
    # Original note: must be divisible by 8. Every autotuned size is; the
    # kernel body itself does not appear to rely on it -- TODO confirm.
    BLOCK_SIZE: tl.constexpr,
):
    # Bit-unpacking kernel: each program reads BLOCK_SIZE packed bytes and
    # writes bit k of byte j to output element 8*j+k.
    #
    # The program id is promoted to int64 so that all offsets are 64-bit.
    # With 32-bit offsets, `load_offsets * 8` wraps for outputs approaching
    # 2**31 elements; the wrapped (negative) offsets pass the
    # `< n_decompressed_elements` mask and cause out-of-bounds stores.
    pid = tl.program_id(0).to(tl.int64)
    load_offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    load_mask = load_offsets < n_compressed_elements
    s_seq_compressed = tl.load(
        s_seq_compressed_ptr + load_offsets,
        mask=load_mask,
        other=0,
    )

    for i in tl.static_range(8):
        # Lane i writes the i-th spike of every byte handled by this program.
        store_offsets = i + load_offsets * 8
        store_mask = store_offsets < n_decompressed_elements
        tl.store(
            s_seq_decompressed_ptr + store_offsets,
            (s_seq_compressed >> i) & 1,
            mask=store_mask,
        )
# [docs]
@contiguous_and_device_guard
def bit_spike_compress(s_seq):
    """Pack a spike tensor into a ``uint8`` tensor holding eight spikes per byte.

    .. _bit_spike_compress-cn:
    .. _bit_spike_compress-en:

    The input is flattened; flattened element ``8 * j + k`` is stored in bit
    ``k`` of output byte ``j``. Tensors on a CUDA device are packed by a Triton
    kernel, all other tensors by a pure-PyTorch fallback.

    The input is expected to contain only the values 0 and 1. Other values are
    converted differently by the two paths (``bool`` conversion in the PyTorch
    fallback, ``uint8`` conversion in the Triton kernel), so the result for
    them may depend on the device.

    :param s_seq: spike tensor of any shape
    :type s_seq: torch.Tensor
    :return: 1-D ``uint8`` tensor with ``ceil(s_seq.numel() / 8)`` elements on
        the same device as ``s_seq``
    :rtype: torch.Tensor
    """
    flat = s_seq.reshape(-1)
    if flat.device.type == "cuda":
        total = flat.numel()
        packed_len = (total + 7) // 8
        packed = torch.zeros(packed_len, dtype=torch.uint8, device=flat.device)

        def grid(meta):
            # One program per BLOCK_SIZE output bytes.
            return (triton.cdiv(packed_len, meta["BLOCK_SIZE"]),)

        with torch.cuda.device(flat.device):
            _bit_spike_compress_triton[grid](flat, packed, total, packed_len)
        return packed

    # Non-CUDA tensors take the pure-PyTorch path.
    return _bit_spike_compress_pytorch(flat)
# [docs]
@contiguous_and_device_guard
def bit_spike_decompress(s_seq_compressed, shape):
    """Unpack a bit-packed ``uint8`` tensor into a spike tensor of the given shape.

    .. _bit_spike_decompress-cn:
    .. _bit_spike_decompress-en:

    Inverse of :func:`bit_spike_compress`: bit ``k`` of byte ``j`` becomes
    flattened output element ``8 * j + k``. Tensors on a CUDA device are
    unpacked by a Triton kernel, all other tensors by a pure-PyTorch fallback.

    :param s_seq_compressed: packed ``uint8`` tensor produced by
        :func:`bit_spike_compress`
    :type s_seq_compressed: torch.Tensor
    :param shape: shape of the original, uncompressed tensor
    :type shape: tuple
    :return: tensor of shape ``shape`` with values 0 or 1 on the same device
        as ``s_seq_compressed``. The dtype is ``uint8``; cast the result if a
        floating-point spike tensor is needed.
    :rtype: torch.Tensor
    """
    if s_seq_compressed.device.type == "cuda":
        packed_len = s_seq_compressed.numel()
        total = torch.Size(shape).numel()
        unpacked = torch.zeros(
            total, dtype=torch.uint8, device=s_seq_compressed.device
        )

        def grid(meta):
            # One program per BLOCK_SIZE packed bytes.
            return (triton.cdiv(packed_len, meta["BLOCK_SIZE"]),)

        with torch.cuda.device(s_seq_compressed.device):
            _bit_spike_decompress_triton[grid](
                s_seq_compressed, unpacked, packed_len, total
            )
        return unpacked.reshape(shape)

    # Non-CUDA tensors take the pure-PyTorch path.
    return _bit_spike_decompress_pytorch(s_seq_compressed, shape)