Gradient Checkpointing Tools
Tools for implementing gradient checkpointing (GC) with input compression.
- spikingjelly.activation_based.memopt.checkpointing.in_gc_1st_forward() -> bool
Determine whether the current execution is inside the first forward pass of gradient checkpointing.
- Return type:
bool
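The distinction matters because checkpointing runs the forward function twice: once in the normal forward pass and once more during backward, when the discarded activations are recomputed. The sketch below illustrates this with PyTorch's generic `torch.utils.checkpoint.checkpoint`, not with SpikingJelly's own implementation; `forward_fn` and the `calls` list are names made up for the illustration.

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = []  # one entry is appended every time the forward function runs


def forward_fn(x, weight):
    calls.append("forward")
    return torch.tanh(x @ weight.t())


x = torch.randn(5, 3, requires_grad=True)
weight = torch.randn(4, 3, requires_grad=True)

# First forward pass: activations inside forward_fn are not kept.
out = checkpoint(forward_fn, x, weight, use_reentrant=False)
calls_after_forward = len(calls)

# Backward re-runs forward_fn to rebuild the activations it needs.
out.sum().backward()
calls_after_backward = len(calls)

print(calls_after_forward, calls_after_backward)
```

Code with side effects, such as logging or updating statistics, would run in both passes unless it is guarded, which is the kind of situation a first-forward query is useful for.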
- spikingjelly.activation_based.memopt.checkpointing.query_autocast() -> Tuple[str, dtype, bool]
Query the current autocast settings.
- Returns:
a tuple of (device_type, dtype, is_enabled). If is_enabled == False, device_type and dtype are set to "cpu" and torch.get_autocast_dtype("cpu"), respectively.
- Return type:
Tuple[str, torch.dtype, bool]
- spikingjelly.activation_based.memopt.checkpointing.input_compressed_gc(f_forward, x_compressor: BaseSpikeCompressor, x_seq, *args)
Gradient checkpointing with input compression.
- Parameters:
f_forward (Callable) -- the forward function whose arguments will be checkpointed
x_compressor (BaseSpikeCompressor) -- the compressor applied to x_seq
x_seq (torch.Tensor) -- the main input argument, typically a spike train. It is compressed first and then checkpointed
args (tuple) -- other input arguments. They are checkpointed directly, without compression
- Returns:
a Tensor or a tuple
Example
    import torch
    import torch.nn as nn
    from spikingjelly.activation_based.memopt import input_compressed_gc
    from spikingjelly.activation_based.memopt import NullSpikeCompressor

    def simple_forward(x, weight):
        return torch.matmul(x, weight.t())

    x = torch.randn(5, 3, requires_grad=True)
    weight = torch.randn(4, 3, requires_grad=True)
    result = input_compressed_gc(simple_forward, NullSpikeCompressor(), x, weight)
    loss = result.sum()
    loss.backward()
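To make the idea concrete, here is a minimal sketch of input-compressed checkpointing written as a custom `torch.autograd.Function`. It is an illustration under stated assumptions, not SpikingJelly's implementation: the class `_BoolCompressedCheckpoint` is hypothetical, it handles exactly one extra argument, and a cast to `torch.bool` stands in for a real `BaseSpikeCompressor`. The cast is lossless only because the input is assumed to be a binary spike train.

```python
import torch


class _BoolCompressedCheckpoint(torch.autograd.Function):
    """Hypothetical helper: keep a bool copy of x_seq, recompute f_forward in backward."""

    @staticmethod
    def forward(ctx, f_forward, x_seq, weight):
        ctx.f_forward = f_forward
        ctx.x_dtype = x_seq.dtype
        # "Compression": binary spikes survive a cast to bool (1 byte per element).
        ctx.save_for_backward(x_seq.to(torch.bool), weight)
        # Autograd is off inside forward(), so no intermediate activations are kept.
        return f_forward(x_seq, weight)

    @staticmethod
    def backward(ctx, grad_out):
        x_bool, weight = ctx.saved_tensors
        # Decompress the input and rebuild the graph by re-running the forward function.
        x_seq = x_bool.to(ctx.x_dtype).requires_grad_(True)
        weight = weight.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.f_forward(x_seq, weight)
        grad_x, grad_w = torch.autograd.grad(out, (x_seq, weight), grad_out)
        return None, grad_x, grad_w  # no gradient for the callable itself


def simple_forward(x, weight):
    return torch.matmul(x, weight.t())


torch.manual_seed(0)
spikes = (torch.rand(5, 3) > 0.5).float().requires_grad_(True)  # binary spike train
weight = torch.randn(4, 3, requires_grad=True)

out = _BoolCompressedCheckpoint.apply(simple_forward, spikes, weight)
out.sum().backward()
grad_x_gc, grad_w_gc = spikes.grad.clone(), weight.grad.clone()

# Reference: the same computation without checkpointing.
spikes.grad, weight.grad = None, None
simple_forward(spikes, weight).sum().backward()
```

The gradients from the checkpointed path match the reference path, while the tensor held between forward and backward is the compact bool copy rather than the float input.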
- spikingjelly.activation_based.memopt.checkpointing.to_gc_function(x_compressor: BaseSpikeCompressor, f_forward: Callable | None = None)
Convert a forward function into a function wrapped by input_compressed_gc. This API can be used as a decorator or as a conversion function.
- Parameters:
x_compressor (BaseSpikeCompressor) -- the compressor
f_forward (Optional[Callable]) -- the forward function. If None, decorator mode is used; otherwise, conversion-function mode is used. Defaults to None.
- Returns:
the GC-wrapped forward function
- Return type:
Callable
Example
    import torch
    from spikingjelly.activation_based.memopt import to_gc_function
    from spikingjelly.activation_based.memopt import NullSpikeCompressor

    x = torch.randn(5, 3, requires_grad=True)
    weight = torch.randn(4, 3, requires_grad=True)
    compressor = NullSpikeCompressor()

    # Usage 1: as decorator
    @to_gc_function(compressor)
    def decorated_forward(x, weight):
        return torch.matmul(x, weight.t())

    result1 = decorated_forward(x, weight)

    # Usage 2: as conversion function
    def simple_forward(x, weight):
        return torch.matmul(x, weight.t())

    converted_forward = to_gc_function(compressor, simple_forward)
    result2 = converted_forward(x, weight)
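The two usages rely on a common Python pattern in which an optional function argument selects between decorator mode and conversion mode. The sketch below shows that pattern with the standard library only. `to_wrapped_function` and `fake_gc` are hypothetical stand-ins, and `fake_gc` simply calls the function where a real implementation would perform checkpointing.

```python
import functools


def fake_gc(f_forward, x_compressor, x_seq, *args):
    """Placeholder for a checkpointing call: it simply runs f_forward."""
    return f_forward(x_seq, *args)


def to_wrapped_function(x_compressor, f_forward=None):
    def decorator(fn):
        @functools.wraps(fn)  # keep the original name and docstring
        def wrapped(x_seq, *args):
            return fake_gc(fn, x_compressor, x_seq, *args)
        return wrapped

    if f_forward is None:
        return decorator          # decorator mode: @to_wrapped_function(c)
    return decorator(f_forward)   # conversion mode: to_wrapped_function(c, f)


@to_wrapped_function("compressor-placeholder")
def scale(x, factor):
    return x * factor


def shift(x, offset):
    return x + offset


shift_wrapped = to_wrapped_function("compressor-placeholder", shift)
print(scale(3, 2), shift_wrapped(3, 2))
```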
- class spikingjelly.activation_based.memopt.checkpointing.GCContainer(x_compressor: BaseSpikeCompressor | None, *args)
Bases: Sequential
Construct a gradient checkpointing (GC) segment in nn.Sequential style.
- Parameters:
x_compressor (Optional[BaseSpikeCompressor]) -- the spike compressor. If None, NullSpikeCompressor is used
Example
    import torch
    import torch.nn as nn
    from spikingjelly.activation_based.memopt import GCContainer
    from spikingjelly.activation_based.memopt import NullSpikeCompressor

    container = GCContainer(
        NullSpikeCompressor(),
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 5),
    )
    x = torch.randn(3, 10, requires_grad=True)
    result = container(x)
- Returns:
None
- Return type:
None
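For intuition about what a Sequential-style GC segment does, the following sketch subclasses `nn.Sequential` and routes its forward pass through PyTorch's generic `torch.utils.checkpoint.checkpoint`. `CheckpointedSequential` is a hypothetical class written for this illustration; it performs no input compression and is not how `GCContainer` is implemented.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedSequential(nn.Sequential):
    """Hypothetical container: run the whole Sequential as one checkpointed segment."""

    def forward(self, x):
        # Activations inside the segment are dropped and recomputed during backward.
        return checkpoint(super().forward, x, use_reentrant=False)


torch.manual_seed(0)
segment = CheckpointedSequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
x = torch.randn(3, 10, requires_grad=True)

out = segment(x)
out.sum().backward()
grad_checkpointed = x.grad.clone()

# Reference: the same layers without checkpointing give the same input gradient.
x.grad = None
nn.Sequential(*segment)(x).sum().backward()
```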
- class spikingjelly.activation_based.memopt.checkpointing.TCGCContainer(x_compressor: BaseSpikeCompressor | None, *args, n_chunk: int = 1, n_seq_inputs: int = 1, n_outputs: int = 1)
Bases: GCContainer
Temporally chunked GCContainer.
- Parameters:
x_compressor (Optional[BaseSpikeCompressor]) -- the spike compressor. If None, NullSpikeCompressor is used
*args -- modules passed to nn.Sequential. Must be given as positional arguments
n_chunk (int) -- number of chunks. Defaults to 1. Must be given as a keyword argument
n_seq_inputs (int) -- number of sequence inputs that are processed in chunks. Defaults to 1. Must be given as a keyword argument
n_outputs (int) -- number of outputs. This container assumes that all outputs are torch.Tensor. Defaults to 1. Must be given as a keyword argument
Example
    import torch
    import torch.nn as nn
    from spikingjelly.activation_based.memopt import TCGCContainer
    from spikingjelly.activation_based.memopt import NullSpikeCompressor

    # Basic usage
    tc_container = TCGCContainer(
        NullSpikeCompressor(),
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 5),
        n_chunk=4,
    )
    x_seq = torch.randn(8, 3, 10, requires_grad=True)  # T=8
    result = tc_container(x_seq)
    print(f"Input shape: {x_seq.shape}")
    print(f"Output shape: {result.shape}")
- Returns:
None
- Return type:
None
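Temporal chunking can be pictured as splitting the sequence along the time axis, checkpointing each chunk separately, and concatenating the results. The sketch below shows that idea with PyTorch's generic `torch.utils.checkpoint.checkpoint`. The helper `chunked_checkpoint_forward` is hypothetical, and treating dimension 0 as time is an assumption based on the `T=8` comment in the example above. The layers used here are stateless, so chunking does not change the result; how `TCGCContainer` carries neuron state across chunks is outside the scope of this sketch.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def chunked_checkpoint_forward(module, x_seq, n_chunk):
    """Hypothetical helper: checkpoint each temporal chunk of x_seq separately."""
    outputs = []
    for x_chunk in torch.chunk(x_seq, n_chunk, dim=0):  # split along the time axis
        outputs.append(checkpoint(module, x_chunk, use_reentrant=False))
    return torch.cat(outputs, dim=0)  # restore the full time dimension


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
x_seq = torch.randn(8, 3, 10, requires_grad=True)  # shape [T=8, N=3, features=10]

out_chunked = chunked_checkpoint_forward(net, x_seq, n_chunk=4)
out_full = net(x_seq)  # unchunked reference

out_chunked.sum().backward()
print(out_chunked.shape, x_seq.grad.shape)
```

With `n_chunk=4` and `T=8`, each checkpointed call sees two time steps, so the activations recomputed at any one time during backward cover only a quarter of the sequence.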