spikingjelly.activation_based.distributed.topology 源代码
from __future__ import annotations
from dataclasses import dataclass
from numbers import Integral
from types import MappingProxyType
from typing import Mapping, Tuple
[文档]
@dataclass(frozen=True)
class SNNDistributedTopology:
world_size: int
dims: Mapping[str, int]
def __post_init__(self):
frozen_dims = MappingProxyType(dict(self.dims))
object.__setattr__(self, "dims", frozen_dims)
if not isinstance(self.world_size, Integral) or isinstance(self.world_size, bool):
raise TypeError(
f"world_size must be an integer, but got {type(self.world_size).__name__}."
)
if self.world_size <= 0:
raise ValueError(f"world_size must be positive, but got {self.world_size}.")
if not self.dims:
raise ValueError("dims must not be empty.")
volume = 1
for name, size in frozen_dims.items():
if not isinstance(name, str):
raise TypeError(
f"Topology dimension names must be strings, but got {type(name).__name__}."
)
if not name:
raise ValueError("Topology dimension names must be non-empty.")
if not isinstance(size, Integral) or isinstance(size, bool):
raise TypeError(
f"Topology dimension '{name}' must be an integer, but got {type(size).__name__}."
)
if size <= 0:
raise ValueError(
f"Topology dimension '{name}' must be positive, but got {size}."
)
volume *= size
if volume != self.world_size:
raise ValueError(
f"Topology dims {dict(frozen_dims)} multiply to {volume}, but world_size={self.world_size}."
)
@property
def ordered_dim_names(self) -> Tuple[str, ...]:
preferred = ("dp", "tp", "pp", "vpp")
names = list(self.dims.keys())
ordered = [name for name in preferred if name in self.dims]
ordered.extend(name for name in names if name not in ordered)
return tuple(ordered)
@property
def mesh_shape(self) -> Tuple[int, ...]:
return tuple(int(self.dims[name]) for name in self.ordered_dim_names)
[文档]
@classmethod
def from_mapping(
cls,
dims: Mapping[str, int],
*,
world_size: int | None = None,
) -> "SNNDistributedTopology":
normalized_dims = {}
for key, value in dims.items():
if isinstance(value, bool):
raise TypeError(
f"Topology dimension '{key}' must be an integer, but got bool."
)
if isinstance(value, Integral):
normalized_dims[key] = int(value)
continue
if isinstance(value, float) and value.is_integer():
normalized_dims[key] = int(value)
continue
if isinstance(value, float):
raise TypeError(
f"Topology dimension '{key}' must be an integer, but got float."
)
raise TypeError(
f"Topology dimension '{key}' must be an integer, but got {type(value).__name__}."
)
if world_size is None:
volume = 1
for size in normalized_dims.values():
volume *= size
world_size = volume
else:
if isinstance(world_size, bool):
raise TypeError("world_size must be an integer, but got bool.")
if isinstance(world_size, Integral):
world_size = int(world_size)
elif isinstance(world_size, float) and world_size.is_integer():
world_size = int(world_size)
elif isinstance(world_size, float):
raise TypeError("world_size must be an integer, but got float.")
else:
raise TypeError(
f"world_size must be an integer, but got {type(world_size).__name__}."
)
return cls(world_size=int(world_size), dims=normalized_dims)