Gradient Checkpointing Tools

Tools for implementing gradient checkpointing (GC) with input compression.

spikingjelly.activation_based.memopt.checkpointing.in_gc_1st_forward() -> bool

Determine whether the current execution is inside the first forward pass of gradient checkpointing (as opposed to the recomputation pass that runs during backward).

Return type:

bool
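Gradient checkpointing runs a wrapped function twice: once in the first forward pass, where intermediate activations are discarded, and once more during backward to recompute them. A query like in_gc_1st_forward() lets code (for example, a stateful module) tell those two calls apart. The toy sketch below illustrates the idea with a hand-written checkpoint built on torch.autograd.Function; the flag _IN_FIRST_FORWARD, in_first_forward() and SimpleCheckpoint are hypothetical names for illustration, not SpikingJelly's implementation.

```python
import torch

# Hypothetical module-level flag, analogous to what in_gc_1st_forward() reports.
_IN_FIRST_FORWARD = False


def in_first_forward() -> bool:
    return _IN_FIRST_FORWARD


class SimpleCheckpoint(torch.autograd.Function):
    """Hypothetical toy checkpoint: keep only the input, recompute f during backward."""

    @staticmethod
    def forward(ctx, f, x):
        global _IN_FIRST_FORWARD
        ctx.f = f
        ctx.save_for_backward(x)  # stash the input, not the intermediate activations
        _IN_FIRST_FORWARD = True
        try:
            # autograd is already disabled inside Function.forward,
            # so f builds no graph and keeps no activations here
            y = f(x)
        finally:
            _IN_FIRST_FORWARD = False
        return y

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            y = ctx.f(x)  # recomputation pass: in_first_forward() is False here
        (grad_x,) = torch.autograd.grad(y, x, grad_y)
        return None, grad_x  # no gradient for the non-tensor argument f


calls = []


def f(x):
    calls.append(in_first_forward())  # record which pass is running
    return torch.sin(x) * 2


x = torch.randn(4, requires_grad=True)
SimpleCheckpoint.apply(f, x).sum().backward()
print(calls)  # [True, False]
```

The first entry comes from the forward call and the second from the recomputation triggered by backward().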

spikingjelly.activation_based.memopt.checkpointing.query_autocast() -> Tuple[str, dtype, bool]

Query the current autocast (automatic mixed precision) settings.

Returns:

a tuple (device_type, dtype, is_enabled). If is_enabled == False, device_type and dtype are set to "cpu" and torch.get_autocast_dtype("cpu"), respectively.

Return type:

Tuple[str, torch.dtype, bool]

spikingjelly.activation_based.memopt.checkpointing.input_compressed_gc(f_forward, x_compressor: BaseSpikeCompressor, x_seq, *args)

Gradient checkpointing with input compression.

Parameters:
  • f_forward (Callable) -- the forward function to be checkpointed

  • x_compressor (BaseSpikeCompressor) -- the compressor applied to x_seq

  • x_seq (torch.Tensor) -- the main input argument, typically a spike train. It is compressed first and then stashed for the backward pass

  • args (tuple) -- other input arguments. These tensors are stashed directly, without compression

Returns:

a Tensor or a tuple

Return type:

torch.Tensor or tuple
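The compress-then-stash idea can be sketched with a custom torch.autograd.Function: only the compressed x_seq and the uncompressed extra arguments are saved in the forward pass, and f_forward is re-run on the decompressed input during backward. BoolSpikeCompressor and CompressedCheckpoint are hypothetical and do not reproduce the BaseSpikeCompressor interface. The sketch assumes x_seq holds binary (0/1) spikes, so storing it as torch.bool is lossless, that every extra argument is a tensor, and that f_forward returns a single tensor.

```python
import torch


class BoolSpikeCompressor:
    """Hypothetical compressor: 0/1 float spikes -> torch.bool (1 byte per element)."""

    def compress(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.bool)

    def decompress(self, c: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        return c.to(dtype)


class CompressedCheckpoint(torch.autograd.Function):
    """Hypothetical checkpoint that stashes a compressed copy of x_seq."""

    @staticmethod
    def forward(ctx, f, compressor, x_seq, *args):
        ctx.f, ctx.compressor, ctx.x_dtype = f, compressor, x_seq.dtype
        # stash the compressed spikes plus the uncompressed extra tensors
        ctx.save_for_backward(compressor.compress(x_seq), *args)
        return f(x_seq, *args)  # no graph is built inside Function.forward

    @staticmethod
    def backward(ctx, grad_out):
        c, *args = ctx.saved_tensors
        x_seq = ctx.compressor.decompress(c, ctx.x_dtype).requires_grad_(True)
        args = [a.detach().requires_grad_(a.requires_grad) for a in args]
        with torch.enable_grad():
            out = ctx.f(x_seq, *args)  # recompute the forward pass
        inputs = [x_seq] + [a for a in args if a.requires_grad]
        grads = iter(torch.autograd.grad(out, inputs, grad_out))
        grad_x = next(grads)
        grad_args = [next(grads) if a.requires_grad else None for a in args]
        return (None, None, grad_x, *grad_args)  # None for f and compressor


def simple_forward(x, weight):
    return torch.matmul(x, weight.t())


torch.manual_seed(0)
x_seq = (torch.rand(5, 3) > 0.5).float().requires_grad_(True)  # binary spikes
weight = torch.randn(4, 3, requires_grad=True)
out = CompressedCheckpoint.apply(simple_forward, BoolSpikeCompressor(), x_seq, weight)
out.sum().backward()
```

Storing spikes as torch.bool takes one byte per element instead of four for float32; denser bit-packing would need a more elaborate compressor.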


Example:

import torch
from spikingjelly.activation_based.memopt import input_compressed_gc
from spikingjelly.activation_based.memopt import NullSpikeCompressor


def simple_forward(x, weight):
    return torch.matmul(x, weight.t())


x = torch.randn(5, 3, requires_grad=True)  # stands in for the spike train x_seq
weight = torch.randn(4, 3, requires_grad=True)  # extra argument, stashed without compression
# NullSpikeCompressor leaves x uncompressed; simple_forward is re-run during backward
result = input_compressed_gc(simple_forward, NullSpikeCompressor(), x, weight)
loss = result.sum()
loss.backward()

spikingjelly.activation_based.memopt.checkpointing.to_gc_function(x_compressor: BaseSpikeCompressor, f_forward: Callable | None = None)

Convert a forward function into a function wrapped by input_compressed_gc. This API can be used either as a decorator or as a conversion function.

Parameters:
  • x_compressor (BaseSpikeCompressor) -- the compressor

  • f_forward (Optional[Callable]) -- the forward function. If None, decorator mode is used; otherwise, conversion-function mode is used. Defaults to None

Returns:

the GC-wrapped forward function

Return type:

Callable
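The dual decorator/conversion behaviour selected by f_forward is a common Python pattern. The sketch below shows one way it can be implemented; gc_call is a hypothetical stand-in for input_compressed_gc that only records the call and then runs the function (so the sketch runs without PyTorch), and to_gc_function_sketch is not SpikingJelly's source.

```python
import functools
from typing import Any, Callable, List, Optional, Tuple

calls: List[Tuple[str, Any]] = []


def gc_call(f: Callable, compressor: Any, x_seq: Any, *args: Any) -> Any:
    """Hypothetical stand-in for input_compressed_gc: record the call, then run f."""
    calls.append((f.__name__, compressor))
    return f(x_seq, *args)


def to_gc_function_sketch(compressor: Any, f_forward: Optional[Callable] = None):
    def decorator(f: Callable) -> Callable:
        @functools.wraps(f)  # keep f's name and docstring on the wrapper
        def wrapped(x_seq: Any, *args: Any) -> Any:
            return gc_call(f, compressor, x_seq, *args)

        return wrapped

    if f_forward is None:
        return decorator  # decorator mode: @to_gc_function_sketch(compressor)
    return decorator(f_forward)  # conversion mode: to_gc_function_sketch(compressor, f)


@to_gc_function_sketch("null")
def add(x, y):
    return x + y


def mul(x, y):
    return x * y


mul_gc = to_gc_function_sketch("null", mul)
print(add(2, 3), mul_gc(2, 3), add.__name__)  # 5 6 add
```

Because functools.wraps copies __name__ and __doc__, the wrapped function still introspects like the original in both modes.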


Example:

import torch
from spikingjelly.activation_based.memopt import to_gc_function
from spikingjelly.activation_based.memopt import NullSpikeCompressor

x = torch.randn(5, 3, requires_grad=True)
weight = torch.randn(4, 3, requires_grad=True)
compressor = NullSpikeCompressor()


# Usage 1: as decorator
@to_gc_function(compressor)
def decorated_forward(x, weight):
    return torch.matmul(x, weight.t())


result1 = decorated_forward(x, weight)


# Usage 2: as conversion function
def simple_forward(x, weight):
    return torch.matmul(x, weight.t())


converted_forward = to_gc_function(compressor, simple_forward)
result2 = converted_forward(x, weight)

class spikingjelly.activation_based.memopt.checkpointing.GCContainer(x_compressor: BaseSpikeCompressor | None, *args)

Bases: Sequential

Construct a gradient checkpointing segment (GC segment) in nn.Sequential style.

Parameters:

x_compressor (Optional[BaseSpikeCompressor]) -- spike compressor. If None, NullSpikeCompressor is used


Example:

import torch
import torch.nn as nn
from spikingjelly.activation_based.memopt import GCContainer
from spikingjelly.activation_based.memopt import NullSpikeCompressor

container = GCContainer(
    NullSpikeCompressor(), nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)
)

x = torch.randn(3, 10, requires_grad=True)
result = container(x)

Returns:

None

Return type:

None
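Conceptually, a GC segment is an nn.Sequential whose forward pass is checkpointed as one unit. The sketch below approximates that with torch.utils.checkpoint.checkpoint. CheckpointedSequential is hypothetical, omits input compression, and assumes stateless layers; it routes through an explicit super_forward that calls nn.Sequential.forward(self, ...) directly rather than super().forward, in the spirit of the super_forward method documented here.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedSequential(nn.Sequential):
    """Hypothetical simplified GC segment: no compression, stateless layers only."""

    def super_forward(self, x: torch.Tensor) -> torch.Tensor:
        # call nn.Sequential.forward explicitly rather than super().forward
        return nn.Sequential.forward(self, x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not torch.is_grad_enabled():
            return self.super_forward(x)  # inference: nothing to recompute
        # activations inside the segment are freed and recomputed during backward
        return checkpoint(self.super_forward, x, use_reentrant=False)


torch.manual_seed(0)
seg = CheckpointedSequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
ref = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
ref.load_state_dict(seg.state_dict())  # identical weights for comparison

x = torch.randn(3, 10)
seg(x).sum().backward()
ref(x).sum().backward()
```

The checkpointed and plain segments produce the same parameter gradients; the checkpointed one simply trades extra compute in backward for not storing the intermediate activations.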

super_forward(input)

Same as nn.Sequential.forward.

This function must be explicitly specified and used in __init__ instead of super().forward; otherwise, infinite recursion can occur in multiprocess scenarios.

stateless_forward(x, *args)
stateful_forward(x, *args)
forward(x, *args)

class spikingjelly.activation_based.memopt.checkpointing.TCGCContainer(x_compressor: BaseSpikeCompressor | None, *args, n_chunk: int = 1, n_seq_inputs: int = 1, n_outputs: int = 1)

Bases: GCContainer

Temporally chunked GCContainer: the sequence inputs are split into n_chunk chunks along the time dimension.

Parameters:
  • x_compressor (Optional[BaseSpikeCompressor]) -- spike compressor. If None, NullSpikeCompressor is used

  • *args -- modules passed to nn.Sequential. Must be given as positional arguments

  • n_chunk (int) -- number of chunks. Defaults to 1. Must be given as a keyword argument

  • n_seq_inputs (int) -- number of sequence inputs to be split into chunks. Defaults to 1. Must be given as a keyword argument

  • n_outputs (int) -- number of outputs. This container assumes that all outputs are torch.Tensor. Defaults to 1. Must be given as a keyword argument


Example:

import torch
import torch.nn as nn
from spikingjelly.activation_based.memopt import TCGCContainer
from spikingjelly.activation_based.memopt import NullSpikeCompressor

# Basic usage
tc_container = TCGCContainer(
    NullSpikeCompressor(),
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
    n_chunk=4,
)
x_seq = torch.randn(8, 3, 10, requires_grad=True)  # T=8
result = tc_container(x_seq)
print(f"Input shape: {x_seq.shape}")
print(f"Output shape: {result.shape}")

Returns:

None

Return type:

None
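Temporal chunking can be sketched by splitting the sequence along the time dimension (dim 0) and checkpointing each chunk separately, so that during backward only one chunk's activations are rebuilt at a time. ChunkedCheckpointSequential is hypothetical: it handles a single sequence input and a single tensor output (the n_seq_inputs = n_outputs = 1 case), uses torch.chunk (which can yield fewer than n_chunk chunks when T is not divisible by n_chunk), and ignores stateful modules, whose state would have to be carried across chunks.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ChunkedCheckpointSequential(nn.Sequential):
    """Hypothetical temporally chunked GC segment (one input, one output, stateless)."""

    def __init__(self, *modules: nn.Module, n_chunk: int = 1):
        super().__init__(*modules)
        self.n_chunk = n_chunk

    def super_forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.Sequential.forward(self, x)

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        # split along time (dim 0) and checkpoint every chunk independently
        chunks = torch.chunk(x_seq, self.n_chunk, dim=0)
        outs = [checkpoint(self.super_forward, c, use_reentrant=False) for c in chunks]
        return torch.cat(outs, dim=0)


torch.manual_seed(0)
tc = ChunkedCheckpointSequential(
    nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5), n_chunk=4
)
x_seq = torch.randn(8, 3, 10, requires_grad=True)  # T=8 -> 4 chunks of 2 steps
y_seq = tc(x_seq)
y_seq.sum().backward()
print(y_seq.shape)  # torch.Size([8, 3, 5])
```

Because each chunk is its own checkpoint, peak recomputation memory scales with the chunk length rather than with the full sequence length T.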

forward(x_seq: Tensor, *args)