# Source code for spikingjelly.cext.layer

import torch
import torch.nn as nn
import torch.nn.functional as F
import spikingjelly.cext.functional
import warnings
import math
import numpy as np
import time
class SparseLinear(nn.Linear):
    '''
    .. _SparseLinear-cn:

    .. _SparseLinear-en:

    :param in_features: size of each input sample
    :type in_features: int
    :param out_features: size of each output sample
    :type out_features: int
    :param bias: If set to ``False``, the layer will not learn an additive bias. Default: ``True``
    :type bias: bool

    The fully connected layer for sparse inputs. This module behaves almost the same as ``torch.nn.Linear``.

    .. admonition:: Warning
        :class: warning

        This layer is implemented by converting ``sparse`` to a sparse format and then doing a sparse matrix
        multiplication. If ``sparse`` is not sparse enough, this layer will be much slower than ``torch.mm``.

    .. admonition:: Warning
        :class: warning

        The sparse matrix multiplication introduces some numerical error, but the error is not significant.

    .. admonition:: Warning
        :class: warning

        This layer does not support running on CPU.
    '''

    def forward(self, sparse: torch.Tensor) -> torch.Tensor:
        # The weight is transposed before it is handed to the sparse kernel,
        # mirroring what a dense linear layer computes (input @ weight^T).
        product = spikingjelly.cext.functional.sparse_mm_dense(sparse, self.weight.t())
        if self.bias is not None:
            product = product + self.bias
        return product
class AutoSparseLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, in_spikes: bool = False):
        '''
        .. _AutoSparseLinear-cn:

        .. _AutoSparseLinear-en:

        :param in_features: size of each input sample
        :type in_features: int
        :param out_features: size of each output sample
        :type out_features: int
        :param bias: If set to ``False``, the layer will not learn an additive bias. Default: ``True``
        :type bias: bool
        :param in_spikes: Whether inputs are spikes, i.e. all elements are 0 or 1. Default: ``False``
        :type in_spikes: bool

        The auto sparse fully connected layer. For an input whose batch size has no known critical sparsity yet,
        this layer first runs :ref:`AutoSparseLinear.benchmark <AutoSparseLinear.benchmark-en>` to obtain it.
        The critical sparsity is the sparsity at which the sparse and the dense matrix multiplication have the
        same speed. Once the critical sparsity of a batch size is known, this layer selects the sparse or the
        dense matrix multiplication according to the sparsity of the current input.

        .. admonition:: Warning
            :class: warning

            The sparse matrix multiplication introduces some numerical error, but the error is not significant.

        .. admonition:: Warning
            :class: warning

            The sparse matrix multiplication does not support running on CPU. When the input is not on a CUDA
            device, the dense matrix multiplication is always used.
        '''
        super().__init__(in_features, out_features, bias)
        # Maps the batch size of the input to its critical sparsity (or ``None``).
        # When the input's sparsity is not lower than the critical sparsity, forward
        # uses the sparse matrix multiplication; otherwise it uses the dense one.
        self.critical_sparsity = {}
        self.in_spikes = in_spikes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The sparse kernel is CUDA-only, so every non-CUDA input takes the dense path.
        if not x.is_cuda:
            return F.linear(x, self.weight, self.bias)

        batch_size = x.shape[0]
        if batch_size not in self.critical_sparsity:
            # Unknown batch size: run the benchmark to get its critical sparsity.
            self.benchmark(batch_size, x.device)

        csp = self.critical_sparsity[batch_size]
        if csp is None:
            # The benchmark found no sparsity at which the sparse path is faster.
            return F.linear(x, self.weight, self.bias)

        # Measuring the sparsity must not be recorded in the autograd graph.
        with torch.no_grad():
            if self.in_spikes:
                # For 0/1 inputs the mean is the fraction of ones.
                sparsity = 1 - x.mean().item()
            else:
                sparsity = (x == 0).float().mean().item()

        if sparsity < csp:
            return F.linear(x, self.weight, self.bias)

        # ``self.weight`` is [out_features, in_features]; it has to be transposed so
        # that x @ weight^T is computed, exactly as SparseLinear and F.linear do.
        out = spikingjelly.cext.functional.sparse_mm_dense(x, self.weight.t())
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self) -> str:
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, critical_sparsity={self.critical_sparsity}'

    @staticmethod
    def _time_forward_backward(layer: nn.Module, inputs: torch.Tensor, run_times: int) -> float:
        # Run ``2 * run_times`` forward+backward passes of ``layer`` on ``inputs`` and
        # return the total time of the last ``run_times`` passes; the first half only
        # serves as warm-up and is not counted.
        elapsed = 0.0
        for i in range(run_times * 2):
            layer.zero_grad()
            torch.cuda.synchronize()
            t_start = time.perf_counter()
            layer(inputs).sum().backward()
            torch.cuda.synchronize()
            t_step = time.perf_counter() - t_start
            if i >= run_times:
                elapsed += t_step
        return elapsed

    @torch.enable_grad()
    def benchmark(self, batch_size: int, device=None, run_times=1024, precision=1e-4, verbose=True):
        '''
        .. _AutoSparseLinear.benchmark-cn:

        .. _AutoSparseLinear.benchmark-en:

        :param batch_size: batch size of the input
        :type batch_size: int
        :param device: where to run the benchmark. If ``None``, the device of this layer's weight is used
        :type device: str or torch.device or None
        :param run_times: the number of timed repetitions of the sparse/dense matrix multiplication at each
            sparsity (the same number of untimed warm-up repetitions is run first). The benchmark result is
            more reliable with a larger ``run_times``
        :type run_times: int
        :param precision: the precision of the binary search for the critical sparsity
        :type precision: float
        :param verbose: If ``True``, this function prints logs while running
        :type verbose: bool

        Uses binary search to find the critical sparsity for inputs whose batch size is ``batch_size``. At
        each candidate sparsity, the forward and backward passes of a sparse and a dense linear layer are
        timed and compared, until the search interval is narrower than ``precision``. If the dense matrix
        multiplication is still not slower than the sparse one when the search stops, the critical sparsity
        is set to ``None``. The result is stored in ``self.critical_sparsity[batch_size]``.

        If ``device`` is not a CUDA device, no timing is performed and the critical sparsity is set to
        ``None``, because the sparse matrix multiplication is only available on CUDA.
        '''
        if len(self.critical_sparsity) > 4:
            warnings.warn('AutoSparseLinear: The batch size of the input has changed more than 4 times. '
                          'AutoSparseLinear may waste too much time on running benchmark.')

        if device is None:
            device = self.weight.device

        if torch.device(device).type != 'cuda':
            # Nothing to measure: the sparse kernel and torch.cuda.synchronize() need CUDA.
            self.critical_sparsity[batch_size] = None
            return

        if verbose:
            print(f'{self} is running benchmark for batch_size={batch_size} at precision={precision} on device={device}')

        bias = self.bias is not None

        # Fresh layers with the same shape as this one; only their speed matters.
        fc_sparse = SparseLinear(self.in_features, self.out_features, bias)
        fc_sparse.to(device)
        fc_dense = nn.Linear(self.in_features, self.out_features, bias)
        fc_dense.to(device)

        sparsity_high = 1.0
        sparsity_low = 0.1

        # Binary search for the critical sparsity.
        while True:
            sparsity_target = (sparsity_low + sparsity_high) / 2
            x = torch.rand(size=[batch_size, self.in_features], device=device)
            # Uniform samples above the target become 1, the rest 0, so the fraction
            # of zeros is approximately ``sparsity_target``.
            sparse = (x > sparsity_target).to(x)
            # The sparsity actually realised by this random draw.
            sparsity_actual = (sparse == 0).to(x).mean().item()

            # Time forward + backward of the sparse and of the dense layer on the same input.
            t_sparse = self._time_forward_backward(fc_sparse, sparse, run_times)
            t_dense = self._time_forward_backward(fc_dense, sparse, run_times)

            if verbose:
                print(f'sparisity_a={sparsity_actual}, t_sparse={t_sparse}, t_dense={t_dense}')

            if t_sparse > t_dense:
                # Sparse is slower: the critical sparsity must be higher.
                sparsity_low = sparsity_actual
            elif t_sparse < t_dense:
                # Sparse is faster: the critical sparsity may be lower.
                sparsity_high = sparsity_actual
            else:
                break

            if sparsity_high - sparsity_low < precision:
                break

        if t_sparse < t_dense:
            self.critical_sparsity[batch_size] = sparsity_actual
        else:
            # The sparse multiplication is still not faster when the search stops,
            # so never use it for this batch size.
            self.critical_sparsity[batch_size] = None

        if verbose:
            print(f'critical_sparsity[{batch_size}]={self.critical_sparsity[batch_size]}')

        # Release the temporary layers and tensors and return cached GPU memory.
        del x, sparse, fc_sparse, fc_dense
        torch.cuda.empty_cache()