# Source code for spikingjelly.activation_based.examples.memopt.data_module

import PIL
import lightning as L
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as transforms
from spikingjelly.datasets import CIFAR10DVSTEBNSplit


class Cutout:
    """Randomly mask out one or more square patches from an image.

    A single (H, W) mask is built over the last two dimensions and broadcast
    over every leading dimension. A stack of shape (T, C, H, W) therefore gets
    identical holes in every frame and channel.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): Nominal side length (in pixels) of each square patch.
            Used when ``max_length`` is None. The patch extends ``length // 2``
            pixels on each side of a random centre, so an odd ``length``
            removes ``length - 1`` pixels per side. Patches are clipped at the
            image border.
        max_length (int): If not None, each patch's nominal length is drawn
            with ``np.random.randint(1, max_length)``. That is uniform over
            ``[1, max_length)``, upper bound exclusive. In this case
            ``length`` is ignored.

    Raises:
        ValueError: If both ``length`` and ``max_length`` are None, or if
            ``max_length`` is smaller than 2 (no valid length could be drawn).
    """

    def __init__(self, n_holes, length=None, max_length=None):
        # Fail fast: previously these configurations constructed fine but
        # crashed on the first call (TypeError on ``None // 2`` or a numpy
        # ValueError from ``randint(1, max_length)``).
        if length is None and max_length is None:
            raise ValueError("Cutout requires either `length` or `max_length`.")
        if max_length is not None and max_length < 2:
            raise ValueError("`max_length` must be at least 2.")
        self.n_holes = n_holes
        self.length = length
        self.max_length = max_length

    def __call__(self, img):
        """Apply the cutout mask.

        Args:
            img (Tensor): Image tensor of shape (..., H, W).

        Returns:
            Tensor: ``img`` multiplied by a 0/1 mask containing ``n_holes``
            zeroed squares, with the same shape as ``img``.
        """
        h = img.size(-2)
        w = img.size(-1)
        mask = np.ones((h, w), np.float32)
        for _ in range(self.n_holes):
            # Draw order (centre, then optional length) is kept stable so
            # seeded runs stay reproducible.
            y = np.random.randint(h)
            x = np.random.randint(w)
            length = self.length
            if self.max_length is not None:
                length = np.random.randint(1, self.max_length)
            half = length // 2
            y1 = np.clip(y - half, 0, h)
            y2 = np.clip(y + half, 0, h)
            x1 = np.clip(x - half, 0, w)
            x2 = np.clip(x + half, 0, w)
            mask[y1:y2, x1:x2] = 0.0
        # Put the mask on the image's device so CUDA tensors are supported.
        # For CPU inputs this is a no-op.
        mask = torch.from_numpy(mask).to(img.device).expand_as(img)
        return img * mask
class CIFAR10DVSNDA:
    """Neuromorphic data augmentation (NDA) for CIFAR10-DVS frame tensors.

    Each call draws ``M`` operations with replacement, using
    ``np.random.choice``, from the ordered pool (roll, rotate, cutout). It
    applies them in the drawn order. ``N`` scales every operation's strength:

    * roll   -- shifts dims 2 and 3 by independent integer offsets drawn
      from ``[-(2N + 1), 2N + 1]``;
    * rotate -- ``transforms.RandomRotation(degrees=15 * N)``;
    * cutout -- ``Cutout(n_holes=1, length=8 * N)``.

    The roll along dims 2 and 3 implies a 4-D input, presumably
    (T, C, H, W). TODO: confirm against the caller.
    """

    def __init__(self, M=1, N=2):
        self.M = M
        self.N = N

    def __call__(self, data):
        strength = self.N
        rotation = transforms.RandomRotation(degrees=15 * strength)
        hole = Cutout(n_holes=1, length=8 * strength)

        def shift(frames, N=1):
            # Symmetric integer range; randint's upper bound is exclusive,
            # hence the +1.
            reach = 2 * N + 1
            dy = np.random.randint(-reach, reach + 1)
            dx = np.random.randint(-reach, reach + 1)
            return torch.roll(frames, shifts=(dy, dx), dims=(2, 3))

        pool = [
            shift,
            lambda frames, N: rotation(frames),
            lambda frames, N: hole(frames),
        ]
        for augment in np.random.choice(pool, self.M):
            data = augment(data, strength)
        return data
def _to_float32_tensor(x):
    """Convert one sample's frame array to a float32 ``torch.Tensor`` (copying).

    This is a module-level function rather than a lambda so that transform
    pipelines using it can be pickled. DataLoader workers started with the
    ``spawn`` method (the default on Windows and macOS) require that.
    """
    return torch.tensor(x, dtype=torch.float32)


class CIFAR10DVSDataModule(L.LightningDataModule):
    """LightningDataModule for CIFAR10-DVS frames using ``CIFAR10DVSTEBNSplit``.

    Samples are loaded with ``data_type="frame"`` and integrated into ``T``
    frames (``split_by="number"``).

    * Training samples are converted to float32, randomly cropped (70-100 %
      of the area) back to 128x128 with nearest-neighbour interpolation,
      resized to 48x48, randomly flipped horizontally and passed through
      ``CIFAR10DVSNDA(M=1, N=2)``.
    * Test samples are only converted to float32 and resized to 48x48.

    The test split serves as the validation, test and predict set.

    Args:
        data_dir: Root directory passed to ``CIFAR10DVSTEBNSplit``.
        T: Number of frames each sample is integrated into.
        batch_size: Mini-batch size for every dataloader.
        num_workers: Number of DataLoader worker processes.
    """

    def __init__(
        self, data_dir: str, T: int, batch_size: int = 128, num_workers: int = 4
    ):
        super().__init__()
        self.data_dir = data_dir
        self.T = T
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_split(self, train: bool, **kwargs):
        """Build the train or test ``CIFAR10DVSTEBNSplit`` with this module's settings.

        Extra keyword arguments (e.g. ``transform``) are forwarded unchanged.
        """
        return CIFAR10DVSTEBNSplit(
            self.data_dir,
            train=train,
            data_type="frame",
            frames_number=self.T,
            split_by="number",
            **kwargs,
        )

    def prepare_data(self):
        """Instantiate both splits once, without transforms.

        The datasets are discarded. Presumably constructing them triggers the
        dataset's own one-time preparation (e.g. frame integration). Confirm
        against ``CIFAR10DVSTEBNSplit``.
        """
        self._make_split(train=True)
        self._make_split(train=False)

    def setup(self, stage: str):
        """Create ``self.train_set`` and ``self.test_set`` with their transforms.

        Both splits are built regardless of ``stage``.
        """
        self.train_set = self._make_split(
            train=True,
            transform=transforms.Compose(
                [
                    transforms.Lambda(_to_float32_tensor),
                    # Same as the former ``PIL.Image.NEAREST`` (int 0), but it
                    # does not depend on ``PIL.Image`` already being imported.
                    transforms.RandomResizedCrop(
                        128,
                        scale=(0.7, 1.0),
                        interpolation=transforms.InterpolationMode.NEAREST,
                    ),
                    transforms.Resize(size=(48, 48)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    CIFAR10DVSNDA(M=1, N=2),
                ]
            ),
        )
        self.test_set = self._make_split(
            train=False,
            transform=transforms.Compose(
                [
                    transforms.Lambda(_to_float32_tensor),
                    transforms.Resize(size=(48, 48)),
                ]
            ),
        )

    def train_dataloader(self):
        """Shuffled training loader. The last incomplete batch is dropped."""
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )

    def val_dataloader(self):
        """Deterministic loader over the test split. Keeps every sample."""
        return DataLoader(
            self.test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=False,
        )

    def test_dataloader(self):
        """Same loader as validation (test split)."""
        return self.val_dataloader()

    def predict_dataloader(self):
        """Same loader as validation (test split)."""
        return self.val_dataloader()