Saving and loading checkpoints
Lightning checkpoints
For more details, see Transfer Learning.
The checkpoints generated by Lightning are easy to save and load:
trainer = Trainer(
    # Save checkpoints to the `default_root_dir` directory
    default_root_dir="checkpoints/acoustic",
    limit_train_batches=2,
    max_epochs=1,
    accelerator="cuda",
)
# Training process (trainer.fit returns None, so there is no result to capture)...
trainer.fit(model=module, train_dataloaders=train_dataloader)
# ...
# Restore from the checkpoint
acoustic_module = AcousticModule.load_from_checkpoint(
    "./checkpoints/am_pitche_stats.ckpt",
)
vocoder_module = VocoderModule.load_from_checkpoint(
    "./checkpoints/vocoder.ckpt",
)
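Checkpoints can also be written explicitly, independent of the automatic saving above, via the Trainer's save_checkpoint method. A minimal sketch (the output path below is illustrative):
# Manually persist the full training state to a checkpoint file
trainer.save_checkpoint("./checkpoints/acoustic/manual.ckpt")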
Initialize with other parameters
# If you train and save the model with these values, they are used again
# when the weights are loaded. You can still override them:
LitModel(in_dim=32, out_dim=10)

# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)

# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
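Note that load_from_checkpoint can only rebuild the module from stored init arguments if they were recorded with save_hyperparameters(). A minimal sketch of what such a LitModel might look like (the linear layer is illustrative; the import assumes the standalone lightning package, older installs use pytorch_lightning instead):
from torch import nn
import lightning.pytorch as pl

class LitModel(pl.LightningModule):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Store in_dim/out_dim in the checkpoint so load_from_checkpoint
        # can rebuild the model without re-passing them
        self.save_hyperparameters()
        self.layer = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.layer(x)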
Resume training state
If you want to restore the full training state, not just the model weights, do the following:
model = LitModel()
trainer = Trainer()
# automatically restores model, epoch, step, LR schedulers, etc...
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
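Resuming requires a checkpoint file on disk; by default the Trainer saves one at the end of each epoch, and a ModelCheckpoint callback gives finer control over where checkpoints go and which are kept. A minimal sketch (monitoring val_loss assumes your module logs that metric; the import path assumes the standalone lightning package):
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep only the 3 best checkpoints, ranked by the logged `val_loss`
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/acoustic",
    monitor="val_loss",
    save_top_k=3,
)
trainer = Trainer(callbacks=[checkpoint_callback])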