spikingjelly.activation_based.examples.memopt.lightning_modules 源代码
from lightning import LightningModule
from torchmetrics import MeanMetric
from torchmetrics.classification import Accuracy
import torch.nn as nn
[文档]
class ClassificationLightningModule(LightningModule):
    """LightningModule wrapping a classification network with loss/accuracy tracking.

    ``net`` is run on each batch and ``criterion`` is applied to its raw
    output. Epoch-to-date mean loss and multiclass accuracy are accumulated
    with torchmetrics, shown on the progress bar during training, and
    logged/printed (rank 0 only) at the end of every training and validation
    epoch, after which the metrics are reset.

    Args:
        net: Network whose output is passed unchanged to ``criterion``.
        criterion: Loss module called as ``criterion(y, label)`` on the raw
            network output and raw labels; it must handle their sizes itself.
        num_classes: Number of classes for the multiclass accuracy metrics.
        y_with_T: If ``True``, the network output is averaged over dim 0
            before the accuracy update (presumably a leading time-step
            dimension, i.e. ``[T, N, C]`` -- confirm with the network used).
            The loss is still computed on the un-averaged output.
        **kwargs: Accepted so callers may pass extra options; currently ignored.
    """

    def __init__(
        self,
        net: nn.Module,
        criterion: nn.Module,
        num_classes: int,
        y_with_T: bool = False,  # for computing accuracy
        **kwargs,
    ):
        super().__init__()
        self.y_with_T = y_with_T
        self.num_classes = num_classes
        # Separate metric objects per stage so train/val accumulations never mix.
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.net = net
        self.criterion = criterion

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output."""
        return self.net(x)

    def _shared_step(self, batch):
        """Run forward + loss for one batch; return ``(y, label, batch_loss)``.

        ``batch_loss`` is computed on the raw output and raw labels and keeps
        its autograd graph. Afterwards ``y`` is averaged over dim 0 when
        ``y_with_T`` is set, and ``label`` is reduced with ``argmax(dim=1)``
        when it has more than one dimension, so both are ready for the
        accuracy metric.
        """
        x, label = batch[0].float(), batch[1]
        y = self(x)
        # The criterion sees the raw output/labels: it must properly handle the sizes!
        batch_loss = self.criterion(y, label)
        if self.y_with_T:
            y = y.mean(dim=0)
        if label.ndim > 1:
            # Presumably one-hot or soft labels -> class indices for accuracy.
            label = label.argmax(dim=1)
        return y, label, batch_loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss, update train metrics, and return the loss.

        Returns:
            The un-detached loss tensor, which Lightning backpropagates.
        """
        y, label, batch_loss = self._shared_step(batch)
        self.train_acc.update(y, label)
        # detach() yields the same values as the former `.data` without
        # bypassing autograd's version tracking.
        self.train_loss.update(batch_loss.detach())
        # Epoch-to-date running values for the progress bar.
        # NOTE(review): compute() runs every step; under DDP torchmetrics
        # synchronizes across processes on compute by default, which may be
        # costly -- confirm this is acceptable for large runs.
        self.log("train_loss", self.train_loss.compute(), prog_bar=True)
        self.log("train_acc", self.train_acc.compute() * 100, prog_bar=True)
        return batch_loss

    def on_train_epoch_end(self):
        """Log the epoch's training loss/accuracy, reset the metrics, and print on rank 0."""
        train_acc = self.train_acc.compute()
        train_loss = self.train_loss.compute()
        self.log("train_loss", train_loss, on_epoch=True, sync_dist=True)
        self.log("train_acc", train_acc * 100, on_epoch=True, sync_dist=True)
        self.train_acc.reset()
        self.train_loss.reset()
        if self.global_rank == 0:
            print(
                f"Epoch {self.current_epoch}/{self.trainer.max_epochs}: "
                f"train_loss={train_loss:.2f}, "
                f"train_acc={train_acc * 100:.2f}%"
            )

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss, update val metrics, and return the loss."""
        y, label, batch_loss = self._shared_step(batch)
        self.val_acc.update(y, label)
        self.val_loss.update(batch_loss.detach())
        return batch_loss

    def on_validation_epoch_end(self):
        """Log the epoch's validation accuracy/loss, reset the metrics, and print on rank 0."""
        val_acc = self.val_acc.compute()
        val_loss = self.val_loss.compute()
        self.log("val_acc", val_acc * 100, on_epoch=True, sync_dist=True)
        self.log("val_loss", val_loss, on_epoch=True, sync_dist=True)
        self.val_acc.reset()
        self.val_loss.reset()
        if self.global_rank == 0:
            print(
                f"Epoch {self.current_epoch}/{self.trainer.max_epochs}: "
                f"val_loss={val_loss:.2f}, val_acc={val_acc * 100:.2f}%"
            )