spikingjelly.activation_based.tensor_cache package

Module contents

class spikingjelly.activation_based.tensor_cache.DataTypeConvertCUDACode[source]

Bases: object

float2bool =

    extern "C" __global__
    void float2bool(const float* fs, unsigned char* bs, const int &N)
    {
        // assert N == numel / 8 and numel % 8 == 0
        const int index = blockIdx.x * blockDim.x + threadIdx.x;
        if (index < N)
        {
            bs[index] = 0;
            const int mem_offset = (index << 3);
            #pragma unroll
            for(int i = 0; i < 8; i++)
            {
                bs[index] += ( ((unsigned char) fs[mem_offset + i]) << i);
            }
        }
    }

half2bool =

    #include <cuda_fp16.h>
    extern "C" __global__
    void half2bool(const half* fs, unsigned char* bs, const int &N)
    {
        // assert N == numel / 8 and numel % 8 == 0
        const int index = blockIdx.x * blockDim.x + threadIdx.x;
        if (index < N)
        {
            bs[index] = 0;
            const int mem_offset = (index << 3);
            #pragma unroll
            for(int i = 0; i < 8; i++)
            {
                bs[index] += ( ((unsigned char) __half2float(fs[mem_offset + i])) << i);
            }
        }
    }

bool2float =

    extern "C" __global__
    void bool2float(const unsigned char* bs, float* fs, const int &N)
    {
        const int index = blockIdx.x * blockDim.x + threadIdx.x;
        if (index < N)
        {
            const int mem_offset = (index << 3);
            unsigned char compressed_v = bs[index];
            #pragma unroll
            for(int i = 0; i < 8; i++)
            {
                fs[mem_offset + i] = (float) (compressed_v % 2);
                compressed_v = (compressed_v >> 1);
            }
        }
    }

bool2half =

    #include <cuda_fp16.h>
    extern "C" __global__
    void bool2half(const unsigned char* bs, half* fs, const int &N)
    {
        const int index = blockIdx.x * blockDim.x + threadIdx.x;
        if (index < N)
        {
            const int mem_offset = (index << 3);
            unsigned char compressed_v = bs[index];
            #pragma unroll
            for(int i = 0; i < 8; i++)
            {
                fs[mem_offset + i] = __float2half((float) (compressed_v % 2));
                compressed_v = (compressed_v >> 1);
            }
        }
    }
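
All four kernels implement the same bit-packing scheme: every group of 8 consecutive 0/1 spike values is packed into the bits of a single uint8 (element i of the group becomes bit i), and unpacking reads those bits back out. Below is a minimal PyTorch sketch of the equivalent packing logic, for illustration only; the package itself runs the CUDA kernels above on the GPU:

    import torch

    def pack_spikes(spike_flat: torch.Tensor) -> torch.Tensor:
        # spike_flat: 1-D float tensor of 0/1 values whose length is a multiple of 8
        groups = spike_flat.to(torch.uint8).view(-1, 8)
        # bit i has weight 2**i, matching the (value << i) accumulation in float2bool
        weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)
        return (groups * weights).sum(dim=1, dtype=torch.uint8)

    def unpack_spikes(packed: torch.Tensor) -> torch.Tensor:
        # inverse of pack_spikes: read the 8 bits of every byte back out as 0/1 floats
        shifts = torch.arange(8, dtype=torch.uint8)
        bits = (packed.unsqueeze(1) >> shifts) & 1
        return bits.flatten().float()

    x = (torch.rand(16) > 0.5).float()
    assert torch.equal(unpack_spikes(pack_spikes(x)), x)
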
spikingjelly.activation_based.tensor_cache.float_spike_to_bool(spike: Tensor)[source]
Parameters:

spike (torch.Tensor) – a spike tensor with dtype=torch.float or dtype=torch.half whose elements are all 0 or 1

Returns:

(spike_b, s_dtype, s_shape, s_padding)

  • spike_b: a compressed spike tensor with dtype=torch.uint8; each element stores 8 spikes

  • s_dtype: the dtype of the original spike

  • s_shape: the shape of the original spike

  • s_padding: the number of padding elements

Return type:

tuple

Compress a float/half spike tensor spike to a uint8 tensor spike_b. Each element in spike_b represents 8 elements of spike.
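
A hedged usage sketch of the compression step; a CUDA tensor is assumed here because the conversion is backed by the CUDA kernels above:

    import torch
    from spikingjelly.activation_based import tensor_cache

    # assumption: a CUDA device is available
    spike = (torch.rand(2, 1000, device='cuda') > 0.5).float()

    spike_b, s_dtype, s_shape, s_padding = tensor_cache.float_spike_to_bool(spike)
    # spike_b is a uint8 tensor holding roughly spike.numel() / 8 elements
    # (the spike is padded with s_padding elements so its length is a multiple of 8)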

spikingjelly.activation_based.tensor_cache.bool_spike_to_float(spike_b: Tensor, s_dtype: dtype, s_shape: Size, s_padding: int = 0)[source]
Parameters:
  • spike_b (torch.Tensor) – a compressed spike tensor with dtype=torch.uint8 and each element stores 8 spikes

  • s_dtype (torch.dtype) – the dtype of the original spike

  • s_shape (torch.Size) – the shape of the original spike

  • s_padding (int) – the number of padding elements

Returns:

the original tensor

Return type:

torch.Tensor
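
bool_spike_to_float is the inverse of float_spike_to_bool: it decompresses spike_b back to the original float/half spike tensor. Continuing the sketch above:

    # round trip: decompress and compare with the original spike tensor
    restored = tensor_cache.bool_spike_to_float(spike_b, s_dtype, s_shape, s_padding)
    assert restored.dtype == s_dtype
    assert restored.shape == s_shape
    assert torch.equal(restored, spike)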

spikingjelly.activation_based.tensor_cache.tensor_key(x: Tensor)[source]
class spikingjelly.activation_based.tensor_cache.BoolTensorCache[source]

Bases: object

store_bool(spike: FloatTensor)[source]
get_float(tk, spike_shape: Size)[source]
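
BoolTensorCache appears to be a small cache that stores spike tensors in the packed uint8 form and rebuilds the float tensor on demand. A hedged sketch of how the two methods might fit together, assuming store_bool returns a key (as produced by tensor_key) that is later passed to get_float along with the original shape; this usage is an assumption read off the signatures, not a confirmed API contract:

    import torch
    from spikingjelly.activation_based import tensor_cache

    cache = tensor_cache.BoolTensorCache()

    # assumption: a CUDA device is available
    spike = (torch.rand(4, 256, device='cuda') > 0.5).float()

    tk = cache.store_bool(spike)                    # assumed: returns a key identifying the stored tensor
    spike_again = cache.get_float(tk, spike.shape)  # assumed: rebuilds the float tensor from the packed entry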