Training process
Lightning checkpoints
For more details, see: Transfer Learning
The new checkpoints generated by Lightning are easy to save and load:
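# Assumption: the unified `lightning` package is installed; with the older
# package name, use `from pytorch_lightning import Trainer` instead.
from lightning.pytorch import Trainer

trainer = Trainer(
    # Save checkpoints to the `default_root_dir` directory
    default_root_dir="checkpoints/acoustic",
    # Keep the run short: 2 training batches, 1 epoch
    limit_train_batches=2,
    max_epochs=1,
    accelerator="cuda",
)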
# Training process...
# `Trainer.fit` returns None, so there is no result to capture
trainer.fit(model=module, train_dataloaders=train_dataloader)
# ...
# Restore from the checkpoint
acoustic_module = AcousticModule.load_from_checkpoint(
    "./checkpoints/am_pitche_stats.ckpt",
)
vocoder_module = VocoderModule.load_from_checkpoint(
    "./checkpoints/vocoder.ckpt",
)
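The training example above writes to `checkpoints/acoustic`, while the restore example loads from fixed file paths. With default logger settings, Lightning typically places checkpoint files in a nested subfolder of `default_root_dir` (for example `lightning_logs/version_0/checkpoints/`), so finding the file to restore can take a small search. The following is a minimal sketch using only the standard library; `find_latest_checkpoint` is a hypothetical helper for illustration, not part of Lightning or of this project.

```python
from pathlib import Path
from typing import Optional


def find_latest_checkpoint(root_dir: str) -> Optional[Path]:
    """Return the most recently modified .ckpt file under root_dir, or None.

    Hypothetical helper: it only inspects the file system and does not
    depend on Lightning itself.
    """
    # rglob searches recursively, so nested folders such as
    # lightning_logs/version_0/checkpoints/ are covered as well.
    candidates = list(Path(root_dir).rglob("*.ckpt"))
    if not candidates:
        return None
    # Pick the file with the newest modification time.
    return max(candidates, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    latest = find_latest_checkpoint("checkpoints/acoustic")
    print(latest if latest is not None else "no checkpoint found")
```

The returned path can then be passed to `AcousticModule.load_from_checkpoint(...)` in place of a hard-coded file name. Selecting by modification time is a simplification; if checkpoints are copied or touched after training, an explicit path is safer.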