
Saving and loading checkpoints

Lightning checkpoints

For more details, see Transfer Learning.

Checkpoints generated by Lightning are easy to save and load:

from lightning.pytorch import Trainer  # `from pytorch_lightning import Trainer` on older releases

trainer = Trainer(
    # Save checkpoints under the `default_root_dir` directory
    default_root_dir="checkpoints/acoustic",
    limit_train_batches=2,
    max_epochs=1,
    accelerator="cuda",
)

# Training process; checkpoints are written automatically during fit
# (Trainer.fit returns None, so there is no result to capture)
trainer.fit(model=module, train_dataloaders=train_dataloader)
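# With the default logger, Lightning writes the per-epoch checkpoints under
# <default_root_dir>/lightning_logs/version_<N>/checkpoints/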

# ...

# Restore from the checkpoint
acoustic_module = AcousticModule.load_from_checkpoint(
    "./checkpoints/am_pitche_stats.ckpt",
)

vocoder_module = VocoderModule.load_from_checkpoint(
    "./checkpoints/vocoder.ckpt",
)
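When a checkpoint trained on GPU has to be loaded on a CPU-only machine, `load_from_checkpoint` accepts a `map_location` argument, and `freeze()` switches a loaded module to eval mode with gradients disabled. A minimal sketch, reusing the acoustic module from above:

# Load onto CPU even if the checkpoint was saved from a CUDA run
acoustic_module = AcousticModule.load_from_checkpoint(
    "./checkpoints/am_pitche_stats.ckpt",
    map_location="cpu",
)

# eval mode, no gradients: ready for inference
acoustic_module.freeze()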

Initialize with other parameters

# If you train and save the model like this, these values are stored with the
# checkpoint and reused when loading the weights. You can override them:
LitModel(in_dim=32, out_dim=10)

# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)

# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
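Overriding works because Lightning stores the `__init__` arguments inside the checkpoint whenever the module calls `self.save_hyperparameters()`. A minimal sketch of a `LitModel` that supports this, assuming a single linear layer as the body:

import torch
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Store in_dim/out_dim in the checkpoint so load_from_checkpoint
        # can rebuild the model, or accept overrides for them
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(in_dim, out_dim)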

Resume training state

If you don't just want to load the weights, but want to restore the full training state, do the following:

model = LitModel()
trainer = Trainer()

# automatically restores model, epoch, step, LR schedulers, etc...
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
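Recent Lightning releases also accept `ckpt_path="last"`. A minimal sketch, assuming a ModelCheckpoint callback configured with `save_last=True` so that a `last.ckpt` file exists:

from lightning.pytorch.callbacks import ModelCheckpoint

# save_last=True keeps an up-to-date last.ckpt next to the regular checkpoints
trainer = Trainer(callbacks=[ModelCheckpoint(save_last=True)])
trainer.fit(model, ckpt_path="last")  # resumes from last.ckpt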