"""
FlexSN Triton Kernel Templates.
Insert a single-step kernel into a multi-step kernel template.
"""
try:
import triton
except Exception as e:
import logging
from .. import dummy
logging.info(f"spikingjelly.activation_based.triton_kernel.flexsn.template: {e}")
triton = dummy.DummyImport()
from ..torch2triton import compile_triton_code_str
from .info import FlexSNInfo
__all__ = [
"get_flexsn_inference_kernel",
"get_flexsn_inference_final_state_kernel",
"get_flexsn_forward_kernel",
"get_flexsn_backward_kernel",
]
INDENTATION = " " * 4
def _signature(names):
return f",\n{INDENTATION}".join(names)
# Kernel-body fragment (one indentation level): loads the initial value of
# ``{name}`` from ``{name}_init_ptr``. The buffer is viewed as (1, NCL) and the
# block starting at column ``ncl_offset`` of row 0 is read, with out-of-range
# columns padded with zeros. The leading indentation is required because the
# fragment is spliced verbatim into the body of the generated kernel function.
init_state_load_template = """
    {name}_init_ptrs = tl.make_block_ptr(
        {name}_init_ptr,
        shape=(1, NCL),
        strides=(NCL, 1),
        offsets=(0, ncl_offset),
        block_shape=(1, BLOCK_NCL),
        order=(1, 0)
    )
    {name} = tl.load(
        {name}_init_ptrs, boundary_check=(1,), padding_option="zero"
    )
"""
# Kernel-body fragment (one indentation level): writes ``{name}_accumulate`` to
# ``{name}_init_ptr`` through ``convert_and_store``. It is emitted after the
# time loop, so it uses the same (1, NCL) / row-0 view as the other init- and
# final-state fragments and does not depend on the loop variable ``t``.
grad_init_state_store_template = """
    {name}_init_ptrs = tl.make_block_ptr(
        {name}_init_ptr,
        shape=(1, NCL),
        strides=(NCL, 1),
        offsets=(0, ncl_offset),
        block_shape=(1, BLOCK_NCL),
        order=(1, 0)
    )
    convert_and_store({name}_init_ptrs, {name}_accumulate, boundary_check=(1,))
    # tl.store({name}_ptrs, {name}, boundary_check=(1,))
"""
# Loop-body fragment (two indentation levels): writes the current value of
# ``{name}`` to row ``t`` of the (T, NCL) buffer ``{name}_seq_ptr`` through
# ``convert_and_store``. The indentation is required because the fragment is
# spliced verbatim into the body of the generated time loop.
store_template = """
        {name}_ptrs = tl.make_block_ptr(
            {name}_seq_ptr,
            shape=(T, NCL),
            strides=(NCL, 1),
            offsets=(t, ncl_offset),
            block_shape=(1, BLOCK_NCL),
            order=(1, 0)
        )
        convert_and_store({name}_ptrs, {name}, boundary_check=(1,))
        # tl.store({name}_ptrs, {name}, boundary_check=(1,))
"""
# Kernel-body fragment (one indentation level): writes the value ``{name}``
# holds after the time loop to ``{name}_final_ptr``, viewed as (1, NCL), through
# ``convert_and_store``. The indentation is required because the fragment is
# spliced verbatim into the body of the generated kernel function.
final_state_store_template = """
    {name}_final_ptrs = tl.make_block_ptr(
        {name}_final_ptr,
        shape=(1, NCL),
        strides=(NCL, 1),
        offsets=(0, ncl_offset),
        block_shape=(1, BLOCK_NCL),
        order=(1, 0)
    )
    convert_and_store({name}_final_ptrs, {name}, boundary_check=(1,))
"""
# Loop-body fragment (two indentation levels): loads row ``t`` of the (T, NCL)
# buffer ``{name}_seq_ptr`` into ``{name}``, with out-of-range columns padded
# with zeros. The indentation is required because the fragment is spliced
# verbatim into the body of the generated time loop.
load_template = """
        {name}_ptrs = tl.make_block_ptr(
            {name}_seq_ptr,
            shape=(T, NCL),
            strides=(NCL, 1),
            offsets=(t, ncl_offset),
            block_shape=(1, BLOCK_NCL),
            order=(1, 0)
        )
        {name} = tl.load(
            {name}_ptrs, boundary_check=(1,), padding_option="zero"
        )
"""
# Skeleton of the generated Triton module. ``str.format`` placeholders:
#   core_str                 - source of the single-step core function
#   autotune_restore         - comma-separated, quoted pointer names placed in
#                              the ``restore_value`` list of ``triton.autotune``
#   kernel_type / hash       - combined into the generated kernel's name
#   kernel_*_signature       - pointer parameters, one per line
#   init_state_loads / tail  - fragments placed before / after the time loop
#                              (kernel-body level, one indentation level)
#   loop_range               - arguments of ``tl.static_range``
#   loads / computes / stores - fragments placed inside the time loop
#                              (two indentation levels)
# Doubled braces are literal braces in the generated source. The indentation of
# each placeholder line matters: the first line of a fragment inherits it.
kernel_template = """import triton
import triton.language as tl
@triton.jit
def convert_and_store(pointer, value, boundary_check):
    # For block pointers created by tl.make_block_pointer(),
    # implicit type casting is not supported when calling tl.store().
    # This function manually converts dtype and then stores the data.
    value = value.to(pointer.dtype.element_ty.element_ty)
    tl.store(pointer, value, boundary_check=boundary_check)
{core_str}
@triton.autotune(
    configs=[
        triton.Config({{"BLOCK_NCL": f * w * 32}}, num_warps=w)
        for f in [1, 2]
        for w in [2, 4]
    ],
    key=["T", "dtype"],
    restore_value=[{autotune_restore}],
)
@triton.jit
def flexsn_{kernel_type}_kernel_{hash}(
    {kernel_input_signature}, # inputs (including init states)
    {kernel_output_signature}, # outputs
    T: tl.constexpr,
    NCL: tl.constexpr,
    BLOCK_NCL: tl.constexpr,
    dtype: tl.constexpr,
):
    pid_ncl = tl.program_id(0)
    ncl_offset = pid_ncl * BLOCK_NCL
    {init_state_loads}
    for t in tl.static_range({loop_range}):
        {loads}
        {computes}
        {stores}
    {tail}
"""
def get_flexsn_inference_kernel(
    core_str: str, core_name: str, info: FlexSNInfo, verbose: bool = False
):
    """Generate and compile the multi-step FlexSN inference kernel.

    The single-step core function in ``core_str`` is inserted into
    ``kernel_template`` together with generated load/store fragments. In each
    step of the generated time loop the kernel loads every input ``x{i}``,
    calls ``core_name`` with the inputs followed by the current states, and
    stores every output ``s{i}`` and every state ``v{i}`` to its
    ``*_seq_ptr`` buffer. Initial states are read from ``v{i}_init_ptr``
    before the loop.

    :param core_str: Source code of the single-step core function; inserted
        verbatim into the generated module
    :type core_str: str
    :param core_name: Name used to call the core function inside the generated
        kernel; its last 8 characters become the kernel-name suffix
    :type core_name: str
    :param info: Kernel metadata; ``num_inputs``, ``num_states`` and
        ``num_outputs`` are read
    :type info: ``FlexSNInfo``
    :param verbose: If ``True``, print the generated source and ``info``;
        also forwarded to ``compile_triton_code_str``
    :type verbose: bool
    :return: The result of ``compile_triton_code_str`` for the generated
        kernel ``flexsn_inference_kernel_<suffix>``
    """
    kernel_hash = core_name[-8:]
    input_names = [f"x{i}" for i in range(info.num_inputs)]
    state_names = [f"v{i}" for i in range(info.num_states)]
    output_names = [f"s{i}" for i in range(info.num_outputs)]

    # Every output and every state is written to a sequence buffer each step.
    stored_names = output_names + state_names
    stored_ptrs = [f"{name}_seq_ptr" for name in stored_names]

    kernel_input_signature = _signature(
        [f"{name}_seq_ptr" for name in input_names]
        + [f"{name}_init_ptr" for name in state_names]
    )
    kernel_output_signature = _signature(stored_ptrs)
    autotune_restore = ", ".join(f'"{ptr}"' for ptr in stored_ptrs)

    init_state_loads = "".join(
        init_state_load_template.format(name=name) for name in state_names
    )
    loads = "".join(load_template.format(name=name) for name in input_names)
    stores = "".join(store_template.format(name=name) for name in stored_names)

    lhs = ", ".join(stored_names)
    # Join one combined list so that an empty input or state list cannot leave
    # a dangling separator in the generated call.
    core_args = ", ".join(input_names + state_names)

    kernel_str = kernel_template.format(
        core_str=core_str,
        autotune_restore=autotune_restore,
        kernel_type="inference",
        hash=kernel_hash,
        kernel_input_signature=kernel_input_signature,
        kernel_output_signature=kernel_output_signature,
        init_state_loads=init_state_loads,
        loop_range="0, T, 1",
        loads=loads,
        computes=f"{lhs} = {core_name}({core_args})",
        stores=stores,
        tail="",
    ).strip()
    kernel_name = f"flexsn_inference_kernel_{kernel_hash}"

    if verbose:
        print("=" * 40, core_name, "=" * 40)
        print("Generated flexsn inference kernel:")
        print("```")
        print(kernel_str)
        print("```\n")
        print(info)
        print("=" * 40, "=" * len(core_name), "=" * 40)

    return compile_triton_code_str(kernel_str, kernel_name, verbose)
def get_flexsn_inference_final_state_kernel(
    core_str: str, core_name: str, info: FlexSNInfo, verbose: bool = False
):
    """Generate and compile the FlexSN inference kernel that keeps only final states.

    Same structure as the plain inference kernel, except that states are not
    stored per step. In each step of the generated time loop the kernel loads
    every input ``x{i}``, calls ``core_name`` with the inputs followed by the
    current states, and stores every output ``s{i}`` to ``s{i}_seq_ptr``.
    After the loop, each state ``v{i}`` is written once to ``v{i}_final_ptr``.

    :param core_str: Source code of the single-step core function; inserted
        verbatim into the generated module
    :type core_str: str
    :param core_name: Name used to call the core function inside the generated
        kernel; its last 8 characters become the kernel-name suffix
    :type core_name: str
    :param info: Kernel metadata; ``num_inputs``, ``num_states`` and
        ``num_outputs`` are read
    :type info: ``FlexSNInfo``
    :param verbose: If ``True``, print the generated source and ``info``;
        also forwarded to ``compile_triton_code_str``
    :type verbose: bool
    :return: The result of ``compile_triton_code_str`` for the generated
        kernel ``flexsn_inference_final_state_kernel_<suffix>``
    """
    kernel_hash = core_name[-8:]
    input_names = [f"x{i}" for i in range(info.num_inputs)]
    state_names = [f"v{i}" for i in range(info.num_states)]
    output_names = [f"s{i}" for i in range(info.num_outputs)]

    # Outputs go to per-step sequence buffers, states to single final buffers.
    written_ptrs = [f"{name}_seq_ptr" for name in output_names]
    written_ptrs += [f"{name}_final_ptr" for name in state_names]

    kernel_input_signature = _signature(
        [f"{name}_seq_ptr" for name in input_names]
        + [f"{name}_init_ptr" for name in state_names]
    )
    kernel_output_signature = _signature(written_ptrs)
    autotune_restore = ", ".join(f'"{ptr}"' for ptr in written_ptrs)

    init_state_loads = "".join(
        init_state_load_template.format(name=name) for name in state_names
    )
    loads = "".join(load_template.format(name=name) for name in input_names)
    stores = "".join(store_template.format(name=name) for name in output_names)
    tail = "".join(
        final_state_store_template.format(name=name) for name in state_names
    )

    lhs = ", ".join(output_names + state_names)
    # Join one combined list so that an empty input or state list cannot leave
    # a dangling separator in the generated call.
    core_args = ", ".join(input_names + state_names)

    kernel_str = kernel_template.format(
        core_str=core_str,
        autotune_restore=autotune_restore,
        kernel_type="inference_final_state",
        hash=kernel_hash,
        kernel_input_signature=kernel_input_signature,
        kernel_output_signature=kernel_output_signature,
        init_state_loads=init_state_loads,
        loop_range="0, T, 1",
        loads=loads,
        computes=f"{lhs} = {core_name}({core_args})",
        stores=stores,
        tail=tail,
    ).strip()
    kernel_name = f"flexsn_inference_final_state_kernel_{kernel_hash}"

    if verbose:
        print("=" * 40, core_name, "=" * 40)
        print("Generated flexsn inference-final-state kernel:")
        print("```")
        print(kernel_str)
        print("```\n")
        print(info)
        print("=" * 40, "=" * len(core_name), "=" * 40)

    return compile_triton_code_str(kernel_str, kernel_name, verbose)
def get_flexsn_forward_kernel(
    core_str: str,
    core_name: str,
    info: FlexSNInfo,
    verbose: bool = False,
):
    """Generate and compile the multi-step FlexSN forward kernel.

    In each step of the generated time loop the kernel loads every input
    ``x{i}``, calls ``core_name`` with the inputs followed by the current
    states, assigns the results to the names in ``info.fwd_core_recipients``,
    and stores every name in ``info.fwd_kernel_returns`` to its
    ``<name>_seq_ptr`` buffer. Initial states are read from ``v{i}_init_ptr``
    before the loop.

    :param core_str: Source code of the single-step core function; inserted
        verbatim into the generated module
    :type core_str: str
    :param core_name: Name used to call the core function inside the generated
        kernel; its last 8 characters become the kernel-name suffix
    :type core_name: str
    :param info: Kernel metadata; ``num_inputs``, ``num_states``,
        ``fwd_kernel_returns`` and ``fwd_core_recipients`` are read
    :type info: ``FlexSNInfo``
    :param verbose: If ``True``, print the generated source and ``info``;
        also forwarded to ``compile_triton_code_str``
    :type verbose: bool
    :return: The result of ``compile_triton_code_str`` for the generated
        kernel ``flexsn_forward_kernel_<suffix>``
    """
    kernel_hash = core_name[-8:]
    input_names = [f"x{i}" for i in range(info.num_inputs)]
    state_names = [f"v{i}" for i in range(info.num_states)]
    # Names the kernel stores each step (the original code notes: unique).
    stored_names = info.fwd_kernel_returns
    # Assignment targets for the core call (the original code notes: `_` for
    # duplicates).
    recipients = info.fwd_core_recipients

    kernel_input_signature = _signature(
        [f"{name}_seq_ptr" for name in input_names]
        + [f"{name}_init_ptr" for name in state_names]
    )
    kernel_output_signature = _signature(
        [f"{name}_seq_ptr" for name in stored_names]
    )
    autotune_restore = ", ".join(f'"{name}_seq_ptr"' for name in stored_names)

    init_state_loads = "".join(
        init_state_load_template.format(name=name) for name in state_names
    )
    loads = "".join(load_template.format(name=name) for name in input_names)
    stores = "".join(store_template.format(name=name) for name in stored_names)

    lhs = ", ".join(recipients)
    # Join one combined list so that an empty input or state list cannot leave
    # a dangling separator in the generated call.
    core_args = ", ".join(input_names + state_names)

    kernel_str = kernel_template.format(
        core_str=core_str,
        autotune_restore=autotune_restore,
        kernel_type="forward",
        hash=kernel_hash,
        kernel_input_signature=kernel_input_signature,
        kernel_output_signature=kernel_output_signature,
        init_state_loads=init_state_loads,
        loop_range="0, T, 1",
        loads=loads,
        computes=f"{lhs} = {core_name}({core_args})",
        stores=stores,
        tail="",
    ).strip()
    kernel_name = f"flexsn_forward_kernel_{kernel_hash}"

    if verbose:
        print("=" * 40, core_name, "=" * 40)
        print("Generating flexsn forward kernel:")
        print("```")
        print(kernel_str)
        print("```")
        print(info)
        print("=" * 40, "=" * len(core_name), "=" * 40)

    return compile_triton_code_str(kernel_str, kernel_name, verbose)
def get_flexsn_backward_kernel(
    core_str: str,
    core_name: str,
    info: FlexSNInfo,
    verbose: bool = False,
):
    """Generate and compile the multi-step FlexSN backward kernel.

    The generated time loop runs from ``T-1`` down to ``0``. In each step the
    kernel loads the output gradients ``grad_s{i}``, the state gradients
    ``grad_v{i}`` and the saved intermediates ``res{i}_b``, then adds each
    ``grad_v{i}`` to its running accumulator ``grad_v{i}_accumulate``. It then
    calls ``core_name`` with the intermediates, the output gradients and the
    accumulators, and assigns the results to ``grad_x{i}`` and the
    accumulators. Each ``grad_x{i}`` is stored per step to
    ``grad_x{i}_seq_ptr``; after the loop the accumulators are written to
    ``grad_v{i}_init_ptr``.

    :param core_str: Source code of the single-step backward core function;
        inserted verbatim into the generated module
    :type core_str: str
    :param core_name: Name used to call the core function inside the generated
        kernel; its last 8 characters become the kernel-name suffix
    :type core_name: str
    :param info: Kernel metadata; ``num_inputs``, ``num_states``,
        ``num_outputs``, ``c2k_return_mapping`` and ``fwd_core_returns`` are
        read
    :type info: ``FlexSNInfo``
    :param verbose: If ``True``, print the generated source and ``info``;
        also forwarded to ``compile_triton_code_str``
    :type verbose: bool
    :return: The result of ``compile_triton_code_str`` for the generated
        kernel ``flexsn_backward_kernel_<suffix>``
    :raises AssertionError: If the number of intermediates plus outputs plus
        states differs from ``len(info.fwd_core_returns)``
    """
    kernel_hash = core_name[-8:]
    num_outputs = info.num_outputs
    num_inputs = info.num_inputs
    num_states = info.num_states
    n = len(info.c2k_return_mapping)  # number of intermediate results
    assert n + num_outputs + num_states == len(info.fwd_core_returns)

    step_indent = INDENTATION + INDENTATION  # indentation of the loop body

    grad_out_names = [f"grad_s{i}" for i in range(num_outputs)]
    grad_state_names = [f"grad_v{i}" for i in range(num_states)]
    # res{i}_b differs slightly from res{i}_f of the forward kernel, since
    # res{i}_b might come from s{i} or v{i}.
    residual_names = [f"res{i}_b" for i in range(n)]
    grad_in_names = [f"grad_x{i}" for i in range(num_inputs)]
    accumulator_names = [f"{name}_accumulate" for name in grad_state_names]

    loaded_names = grad_out_names + grad_state_names + residual_names
    kernel_input_signature = _signature(
        [f"{name}_seq_ptr" for name in loaded_names]
    )
    kernel_output_signature = _signature(
        [f"{name}_seq_ptr" for name in grad_in_names]
        + [f"{name}_init_ptr" for name in grad_state_names]
    )

    # The two halves are glued with a literal ", ", exactly as before.
    autotune_restore = (
        ", ".join(f'"{name}_seq_ptr"' for name in grad_in_names)
        + ", "
        + ", ".join(f'"{name}_init_ptr"' for name in grad_state_names)
    )

    # The accumulators start at zero before the time loop.
    init_state_loads = f"\n{INDENTATION}".join(
        f"{acc} = tl.zeros([1, BLOCK_NCL], dtype=dtype)"
        for acc in accumulator_names
    )
    loads = "".join(load_template.format(name=name) for name in loaded_names)
    stores = "".join(store_template.format(name=name) for name in grad_in_names)

    # First add this step's incoming state gradients to the accumulators ...
    accumulate_lines = f"\n{step_indent}".join(
        f"{acc} = {acc} + {grad}"
        for acc, grad in zip(accumulator_names, grad_state_names)
    )
    # ... then call the core, which yields the input gradients and the updated
    # accumulators.
    lhs = ", ".join(grad_in_names) + ", " + ", ".join(accumulator_names)
    core_args = ", ".join(residual_names + grad_out_names + accumulator_names)
    computes = (
        accumulate_lines + f"\n{step_indent}{lhs} = {core_name}({core_args})"
    )

    tail = f"\n{INDENTATION}".join(
        grad_init_state_store_template.format(name=name)
        for name in grad_state_names
    )

    kernel_str = kernel_template.format(
        core_str=core_str,
        autotune_restore=autotune_restore,
        kernel_type="backward",
        hash=kernel_hash,
        kernel_input_signature=kernel_input_signature,
        kernel_output_signature=kernel_output_signature,
        init_state_loads=init_state_loads,
        loop_range="T-1, -1, -1",
        loads=loads,
        computes=computes,
        stores=stores,
        tail=tail,
    ).strip()
    kernel_name = f"flexsn_backward_kernel_{kernel_hash}"

    if verbose:
        print("=" * 40, core_name, "=" * 40)
        print("Generated flexsn backward kernel:")
        print("```")
        print(kernel_str)
        print("```\n")
        print(info)
        print("=" * 40, "=" * len(core_name), "=" * 40)

    return compile_triton_code_str(kernel_str, kernel_name, verbose)