
Training process

Lightning checkpoints

For more details, see Transfer Learning.

Checkpoints generated by Lightning are easy to save and restore:

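```python
# Lightning 2.x; on older releases use `from pytorch_lightning import Trainer`
from lightning.pytorch import Trainer

# `module`, `train_dataloader`, AcousticModule and VocoderModule
# come from the project and are defined elsewhere (not shown here).

trainer = Trainer(
    # Logs and checkpoints are written under `default_root_dir`
    # (by default to lightning_logs/version_*/checkpoints/)
    default_root_dir="checkpoints/acoustic",
    limit_train_batches=2,
    max_epochs=1,
    accelerator="cuda",
)

# Training process; fit() returns None, the trained weights stay in `module`
trainer.fit(model=module, train_dataloaders=train_dataloader)

# ...

# Restore from the checkpoint
acoustic_module = AcousticModule.load_from_checkpoint(
    "./checkpoints/am_pitche_stats.ckpt",
)

vocoder_module = VocoderModule.load_from_checkpoint(
    "./checkpoints/vocoder.ckpt",
)
```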