Lightning Data Module
Lightning Data Module
More details are here: LightningDataModule
A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
def setup(self, stage: str):
self.mnist_test = MNIST(self.data_dir, train=False)
self.mnist_predict = MNIST(self.data_dir, train=False)
mnist_full = MNIST(self.data_dir, train=True)
self.mnist_train, self.mnist_val = random_split(
mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=self.batch_size)
def teardown(self, stage: str):
# Used to clean-up when the run is finished
...
Not sure about this module, I use the from torch.utils.data import Dataset, because it's much easy to transfer from dunky11 code.
Maybe for the next datasets it makes sense to try the new way to build the dataset with LightningDataModule