spikingjelly.activation_based.quantize 源代码

import torch



class round_atgf(torch.autograd.Function):
    """
    Autograd function that rounds in the forward pass and passes the gradient
    through unchanged in the backward pass.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        # Nothing is stored on ``ctx``: the backward pass does not depend on ``x``.
        rounded = x.round()
        return rounded

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # The derivative of the rounding is treated as 1, so the incoming
        # gradient is returned as it is.
        return grad_output
@torch.jit.ignore
def round(x: torch.Tensor):
    """
    :param x: the input tensor
    :type x: torch.Tensor
    :return: the output tensor
    :rtype: torch.Tensor

    Compute ``y = torch.round(x)`` while overriding the gradient with
    :math:`\\frac{\\partial y}{\\partial x} = 1`, so that the incoming gradient
    is passed to ``x`` unchanged.
    """
    # NOTE: at module level this name shadows the builtin ``round``.
    rounded = round_atgf.apply(x)
    return rounded
class ceil_atgf(torch.autograd.Function):
    """
    Autograd function that applies the ceiling in the forward pass and passes
    the gradient through unchanged in the backward pass.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        # Nothing is stored on ``ctx``: the backward pass does not depend on ``x``.
        ceiled = x.ceil()
        return ceiled

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # The derivative of the ceiling is treated as 1, so the incoming
        # gradient is returned as it is.
        return grad_output
@torch.jit.ignore
def ceil(x: torch.Tensor):
    """
    :param x: the input tensor
    :type x: torch.Tensor
    :return: the output tensor
    :rtype: torch.Tensor

    Compute ``y = torch.ceil(x)`` while overriding the gradient with
    :math:`\\frac{\\partial y}{\\partial x} = 1`, so that the incoming gradient
    is passed to ``x`` unchanged.
    """
    ceiled = ceil_atgf.apply(x)
    return ceiled
class floor_atgf(torch.autograd.Function):
    """
    Autograd function that applies the floor in the forward pass and passes
    the gradient through unchanged in the backward pass.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        # Nothing is stored on ``ctx``: the backward pass does not depend on ``x``.
        floored = x.floor()
        return floored

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # The derivative of the floor is treated as 1, so the incoming
        # gradient is returned as it is.
        return grad_output
@torch.jit.ignore
def floor(x: torch.Tensor):
    """
    :param x: the input tensor
    :type x: torch.Tensor
    :return: the output tensor
    :rtype: torch.Tensor

    Compute ``y = torch.floor(x)`` while overriding the gradient with
    :math:`\\frac{\\partial y}{\\partial x} = 1`, so that the incoming gradient
    is passed to ``x`` unchanged.
    """
    floored = floor_atgf.apply(x)
    return floored
@torch.jit.script
def clamp_backward(grad_output: torch.Tensor, x: torch.Tensor, min_value: float, max_value: float):
    """
    :param grad_output: the gradient with respect to the clamped output
    :type grad_output: torch.Tensor
    :param x: the tensor that was clamped in the forward pass
    :type x: torch.Tensor
    :param min_value: lower-bound of the clamping range
    :type min_value: float
    :param max_value: upper-bound of the clamping range
    :type max_value: float
    :return: ``grad_output`` where ``min_value <= x <= max_value``, multiplied by zero elsewhere
    :rtype: torch.Tensor

    Gradient of the clamp: elements of ``x`` inside the closed range keep their
    gradient, the others have it multiplied by zero.
    """
    # Both comparisons are cast to the dtype/device of ``x`` so that their
    # product is a 0/1 mask.
    not_below = (x >= min_value).to(x)
    not_above = (x <= max_value).to(x)
    return grad_output * (not_below * not_above)
class clamp_atgf(torch.autograd.Function):
    """
    Autograd function that clamps ``x`` to ``[min_value, max_value]`` in the
    forward pass and delegates the gradient to ``clamp_backward``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, min_value: float, max_value: float):
        # The input and the bounds are only needed by the backward pass, so
        # they are recorded only when ``x`` takes part in differentiation.
        if x.requires_grad:
            ctx.min_value, ctx.max_value = min_value, max_value
            ctx.save_for_backward(x)
        clamped = torch.clamp(x, min_value, max_value)
        return clamped

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (x,) = ctx.saved_tensors
        grad_x = clamp_backward(grad_output, x, ctx.min_value, ctx.max_value)
        # ``min_value`` and ``max_value`` are plain floats: no gradient for them.
        return grad_x, None, None
@torch.jit.ignore
def clamp(x: torch.Tensor, min_value: float, max_value: float):
    """
    :param x: the input tensor
    :type x: torch.Tensor
    :param min_value: lower-bound of the range to be clamped to
    :type min_value: float
    :param max_value: upper-bound of the range to be clamped to
    :type max_value: float
    :return: the output tensor
    :rtype: torch.Tensor

    Apply ``y = torch.clamp(x, min_value, max_value)`` with re-defining gradient as:

    .. math::

        \\frac{\\partial y}{\\partial x} = \\begin{cases}
            1, \\rm{min\\_value} \\leq x \\leq \\rm{max\\_value} \\\\
            0, \\rm{otherwise}
        \\end{cases}
    """
    # The forward/backward pair lives in ``clamp_atgf``; this wrapper only
    # exposes it as a plain function.
    return clamp_atgf.apply(x, min_value, max_value)
@torch.jit.script
def step_quantize_forward(x: torch.Tensor, step: float):
    """
    :param x: the input tensor
    :type x: torch.Tensor
    :param step: the quantize step
    :type step: float
    :return: ``round(x / step) * step``
    :rtype: torch.Tensor

    Forward computation of the step quantizer.
    """
    scaled = x / step
    # In-place rounding acts on the temporary ``scaled``; ``x`` is not modified.
    torch.round_(scaled)
    return scaled * step
class step_quantize_atgf(torch.autograd.Function):
    """
    Autograd function that quantizes with ``step_quantize_forward`` in the
    forward pass and passes the gradient through unchanged in the backward pass.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, step: float):
        # Nothing is stored on ``ctx``: the backward pass depends on neither
        # ``x`` nor ``step``.
        quantized = step_quantize_forward(x, step)
        return quantized

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One entry per forward argument: identity for ``x``, nothing for ``step``.
        grad_x, grad_step = grad_output, None
        return grad_x, grad_step
@torch.jit.ignore
def step_quantize(x: torch.Tensor, step: float):
    """
    :param x: the input tensor
    :type x: torch.Tensor
    :param step: the quantize step
    :type step: float
    :return: the quantized tensor
    :rtype: torch.Tensor

    Snap every element of ``x`` to the nearest ``i * step``, where ``i`` is an
    integer. Note that the gradient is defined by
    :math:`\\frac{\\partial y}{\\partial x} = 1`.

    .. image:: ../_static/API/activation_based//quantize/step_quantize.*
        :width: 100%

    """
    quantized = step_quantize_atgf.apply(x, step)
    return quantized
# Bare string literal, never executed. It holds a matplotlib script that plots
# quantize.step_quantize for step = 1 and step = 2 and saves the figure as
# step_quantize.pdf / .svg / .png under
# ./docs/source/_static/API/activation_based/quantize/.
""" import torch from spikingjelly.activation_based import quantize from matplotlib import pyplot as plt plt.style.use(['science', 'grid']) fig = plt.figure(dpi=200, figsize=(8, 4)) x = torch.arange(-4, 4, 0.01) colormap = plt.get_cmap('tab10') for i, step in zip(range(2), [1, 2]): plt.subplot(1, 2, i + 1) y = quantize.step_quantize(x, step) plt.plot(x, y, label=f'y = step_quantize(x, {step})', c=colormap(i)) plt.xlabel('Input') plt.ylabel('Output') plt.xticks(step / 2 * torch.as_tensor([-3, -1, 1, 3])) plt.grid(ls='--') plt.legend() # plt.show() plt.savefig('./docs/source/_static/API/activation_based/quantize/step_quantize.pdf') plt.savefig('./docs/source/_static/API/activation_based/quantize/step_quantize.svg') plt.savefig('./docs/source/_static/API/activation_based/quantize/step_quantize.png') """
@torch.jit.script
def k_bit_quantize_forward(x: torch.Tensor, k: int):
    """
    :param x: the input tensor
    :type x: torch.Tensor
    :param k: the bit number of output
    :type k: int
    :return: ``round((2 ** k - 1) * x) / (2 ** k - 1)``
    :rtype: torch.Tensor

    Forward computation of the k-bit quantizer.
    """
    # Largest integer level, ``2 ** k - 1``, as a float.
    levels = float(1 << k) - 1.
    # In-place rounding acts on the temporary product; ``x`` is not modified.
    return torch.round_(x * levels) / levels
class k_bit_quantize_atgf(torch.autograd.Function):
    """
    Autograd function that quantizes with ``k_bit_quantize_forward`` in the
    forward pass and passes the gradient through unchanged in the backward pass.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, k: int):
        # Nothing is stored on ``ctx``: the backward pass depends on neither
        # ``x`` nor ``k``.
        quantized = k_bit_quantize_forward(x, k)
        return quantized

    @staticmethod
    def backward(ctx, grad_output):
        # One entry per forward argument: identity for ``x``, nothing for ``k``.
        grad_x, grad_k = grad_output, None
        return grad_x, grad_k
@torch.jit.ignore
def k_bit_quantize(x: torch.Tensor, k: int):
    """
    :param x: a float tensor whose range is ``[0, 1]``.
    :type x: torch.Tensor
    :param k: the bit number of output
    :type k: int
    :return: ``y = round((2 ** k - 1) * x) / (2 ** k - 1)``
    :rtype: torch.Tensor

    The k-bit quantizer defined in `DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients <https://arxiv.org/abs/1606.06160>`_.

    An input whose range is ``[0, 1]`` is quantized to the nearest ``i / (2 ** k - 1)``, where ``i = 0, 1, ..., (2 ** k - 1)``.

    Note that the gradient is defined by :math:`\\frac{\\partial y}{\\partial x} = 1`.

    To bring an input whose range is ``(-inf, inf)`` into the range ``(0, 1)``, use :class:`torch.sigmoid`, :class:`torch.nn.Hardtanh` or the
    ``clamp_*`` functions (e.g., :class:`spikingjelly.activation_based.quantize.clamp_by_linear`) in ``spikingjelly.activation_based.quantize``.

    .. image:: ../_static/API/activation_based//quantize/k_bit_quantize.*
        :width: 100%

    Codes example:

    .. code-block:: python

        x = torch.rand(8)
        y = k_bit_quantize(x, 2)
        print(f'x={x}')
        print(f'y={y}')
        # x=tensor([0.6965, 0.5697, 0.9883, 0.0438, 0.1332, 0.7613, 0.9704, 0.2384])
        # y=tensor([0.6667, 0.6667, 1.0000, 0.0000, 0.0000, 0.6667, 1.0000, 0.3333])
    """
    quantized = k_bit_quantize_atgf.apply(x, k)
    return quantized
def affine_k_bit_quantize(x: torch.Tensor, k: int, w: torch.Tensor, b: torch.Tensor):
    """
    :param x: a float tensor whose range is ``[0, 1]``.
    :type x: torch.Tensor
    :param k: the bit number of output
    :type k: int
    :param w: the weight of the affine transform
    :type w: torch.Tensor
    :param b: the bias of the affine transform
    :type b: torch.Tensor
    :return: ``y = w * round((2 ** k - 1) * x) / (2 ** k - 1) + b``
    :rtype: torch.Tensor

    Quantize ``x`` with :func:`k_bit_quantize`, then scale the result by ``w`` and shift it by ``b``:
    ``y = w * round((2 ** k - 1) * x) / (2 ** k - 1) + b``.
    """
    quantized = k_bit_quantize(x, k)
    return w * quantized + b
# Bare string literal, never executed. It holds a matplotlib script that plots
# quantize.k_bit_quantize for k = 2 and k = 3 and saves the figure as
# k_bit_quantize.pdf / .svg / .png under
# ./docs/source/_static/API/activation_based/quantize/.
""" import torch from spikingjelly.activation_based import quantize from matplotlib import pyplot as plt plt.style.use(['science', 'grid']) fig = plt.figure(dpi=200, figsize=(8, 4)) x = torch.arange(0, 1, 0.001) colormap = plt.get_cmap('tab10') for i, k in zip(range(2), [2, 3]): plt.subplot(1, 2, i + 1) y = quantize.k_bit_quantize(x, k=k) plt.plot(x, y, label=f'y = k_bit_quantize(x, {k})', c=colormap(i)) plt.xlabel('Input') plt.ylabel('Output') plt.grid(ls='--') plt.legend() # plt.show() plt.savefig('./docs/source/_static/API/activation_based/quantize/k_bit_quantize.pdf') plt.savefig('./docs/source/_static/API/activation_based/quantize/k_bit_quantize.svg') plt.savefig('./docs/source/_static/API/activation_based/quantize/k_bit_quantize.png') """
@torch.jit.script
def clamp_by_linear(x: torch.Tensor, eps: float = 1e-5):
    """
    :param x: the input tensor to be normed, whose range is ``(-inf, inf)``
    :type x: torch.Tensor
    :param eps: a value added to the denominator for numerical stability. The default value is ``1e-5``
    :type eps: float
    :return: the normed tensor. The minimum of ``x`` is mapped to ``0`` and the maximum to
        ``(max(x) - min(x)) / (max(x) - min(x) + eps)``, which is below ``1`` when ``eps > 0``
    :rtype: torch.Tensor

    Using the linear transform to clamp the input range from ``(-inf, inf)`` to ``[0., 1.]``:

    .. math::

        y = \\frac{x - \\rm{min}(x)}{\\rm{max}(x) - \\rm{min}(x) + eps}

    The minimum and maximum are taken over all elements of ``x``.
    """
    x_min = torch.min(x)
    # ``eps`` is added to the span, exactly as in the formula above. Adding it
    # to ``max(x)`` first would let floating-point rounding absorb it when
    # ``|max(x)|`` is large, and a constant input would then produce 0 / 0 = NaN.
    denominator = (torch.max(x) - x_min) + eps
    return (x - x_min) / denominator