spikingjelly.activation_based.tensor_cache package
Module contents
- class spikingjelly.activation_based.tensor_cache.DataTypeConvertCUDACode [source]
Bases: object
- float2bool = '\n extern "C" __global__\n void float2bool(const float* fs, unsigned char* bs, const int &N)\n {\n // assert N == numel / 8 and numel % 8 == 0\n const int index = blockIdx.x * blockDim.x + threadIdx.x;\n if (index < N)\n {\n bs[index] = 0;\n const int mem_offset = (index << 3);\n #pragma unroll\n for(int i = 0; i < 8; i++)\n {\n bs[index] += ( ((unsigned char) fs[mem_offset + i]) << i);\n }\n }\n }\n '
- half2bool = '\n #include <cuda_fp16.h>\n extern "C" __global__\n void half2bool(const half* fs, unsigned char* bs, const int &N)\n {\n // assert N == numel / 8 and numel % 8 == 0\n const int index = blockIdx.x * blockDim.x + threadIdx.x;\n if (index < N)\n {\n bs[index] = 0;\n const int mem_offset = (index << 3);\n #pragma unroll\n for(int i = 0; i < 8; i++)\n {\n bs[index] += ( ((unsigned char) __half2float(fs[mem_offset + i])) << i);\n }\n }\n }\n '
- bool2float = '\n extern "C" __global__\n void bool2float(const unsigned char* bs, float* fs, const int &N)\n {\n const int index = blockIdx.x * blockDim.x + threadIdx.x;\n if (index < N)\n {\n const int mem_offset = (index << 3);\n unsigned char compressed_v = bs[index];\n #pragma unroll\n for(int i = 0; i < 8; i++)\n {\n fs[mem_offset + i] = (float) (compressed_v % 2);\n compressed_v = (compressed_v >> 1);\n }\n }\n }\n '
- bool2half = '\n #include <cuda_fp16.h>\n extern "C" __global__\n void bool2half(const unsigned char* bs, half* fs, const int &N)\n {\n const int index = blockIdx.x * blockDim.x + threadIdx.x;\n if (index < N)\n {\n const int mem_offset = (index << 3);\n unsigned char compressed_v = bs[index];\n #pragma unroll\n for(int i = 0; i < 8; i++)\n {\n fs[mem_offset + i] = __float2half((float) (compressed_v % 2));\n compressed_v = (compressed_v >> 1);\n }\n }\n }\n '
- spikingjelly.activation_based.tensor_cache.float_spike_to_bool(spike: Tensor) [source]
- Parameters:
spike (torch.Tensor) – a spike tensor whose `dtype=torch.float` or `dtype=torch.half`, and all of whose elements are 0 or 1
- Returns:
(spike_b, s_dtype, s_shape, s_padding)
spike_b: a compressed spike tensor with `dtype=torch.uint8`, in which each element stores 8 spikes
s_dtype: the dtype of the original spike
s_shape: the shape of the original spike
s_padding: the number of padding elements (added so that the element count is a multiple of 8, since each uint8 element packs 8 spikes)
- Return type:
tuple
Compress a float/half spike tensor `spike` to a uint8 tensor `spike_b`. Each element in `spike_b` represents 8 elements of `spike`.
- spikingjelly.activation_based.tensor_cache.bool_spike_to_float(spike_b: Tensor, s_dtype: dtype, s_shape: Size, s_padding: int = 0) [source]
- Parameters:
spike_b (torch.Tensor) – a compressed spike tensor with `dtype=torch.uint8`, in which each element stores 8 spikes
s_dtype (torch.dtype) – the dtype of the original spike
s_shape (torch.Size) – the shape of the original spike
s_padding (int) – the number of padding elements
- Returns:
the original spike tensor
- Return type: