Multi Resolution STFT Loss
MultiResolutionSTFTLoss
Bases: Module
Multi resolution STFT loss module.
The multi-resolution STFT loss module is a PyTorch module that computes the spectral convergence loss and the log STFT magnitude loss between a predicted signal and a ground-truth signal at multiple resolutions. It is designed for speech and audio signal processing tasks such as speech enhancement and source separation.
The module is constructed with a list of tuples, where each tuple contains the FFT size, hop size (frame shift), and window length for one resolution. For each resolution, it computes the spectral convergence and log STFT magnitude losses using the STFTLoss module, a PyTorch module that computes the STFT of a signal and the corresponding magnitude spectrogram.
The spectral convergence loss measures how far the predicted magnitude spectrogram is from the ground-truth magnitude spectrogram, while the log STFT magnitude loss measures the same difference between logarithmically scaled magnitude spectrograms. Taking the logarithm compresses the dynamic range in the way a decibel scale does, which is more perceptually meaningful than the linear scale.
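For reference, the two losses are commonly written as below, in the form popularized by Parallel WaveGAN-style vocoder training. Whether STFTLoss uses exactly these expressions is an assumption, since its source is not shown here. In this sketch, $x$ is the predicted signal, $y$ is the ground-truth signal, $|\mathrm{STFT}(\cdot)|$ is the magnitude spectrogram, $\|\cdot\|_F$ is the Frobenius norm, $\|\cdot\|_1$ is the sum of absolute values, and $N$ is the number of time-frequency bins.

```latex
L_{\mathrm{sc}}(x, y) =
  \frac{\bigl\| \, |\mathrm{STFT}(y)| - |\mathrm{STFT}(x)| \, \bigr\|_F}
       {\bigl\| \, |\mathrm{STFT}(y)| \, \bigr\|_F}
\qquad
L_{\mathrm{mag}}(x, y) =
  \frac{1}{N} \bigl\| \log|\mathrm{STFT}(y)| - \log|\mathrm{STFT}(x)| \bigr\|_1
```

Under these definitions, the spectral convergence term is dominated by high-energy bins, while the log magnitude term gives low-energy bins comparable weight.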
The module returns the spectral convergence loss and the log STFT magnitude loss, each averaged across all resolutions. Because each resolution trades time resolution against frequency resolution differently, averaging lets the loss capture both fine-grained and coarse-grained spectral information in the predicted and ground-truth signals.
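The per-resolution computation and the final averaging can be sketched as follows. This is a minimal, self-contained illustration and not the module's actual source: the function names `stft_magnitude` and `multi_resolution_stft_loss`, the Hann window, the `1e-7` clamp, and the example resolution values are all assumptions chosen for the sketch.

```python
import torch
import torch.nn.functional as F


def stft_magnitude(x, fft_size, hop_size, win_length):
    """Magnitude spectrogram of a (B, T) signal, shaped (B, frames, bins)."""
    # Assumption: a Hann window; the real module may use a different one.
    window = torch.hann_window(win_length, device=x.device, dtype=x.dtype)
    spec = torch.stft(
        x,
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    # Clamp so that the logarithm taken later never sees an exact zero.
    return spec.abs().clamp(min=1e-7).transpose(1, 2)


def multi_resolution_stft_loss(x, y, resolutions):
    """Return (spectral convergence, log STFT magnitude), averaged over resolutions."""
    sc_losses, mag_losses = [], []
    for fft_size, hop_size, win_length in resolutions:
        x_mag = stft_magnitude(x, fft_size, hop_size, win_length)
        y_mag = stft_magnitude(y, fft_size, hop_size, win_length)

        # Spectral convergence: norm of the error relative to the target norm.
        diff_norm = torch.sqrt(((y_mag - x_mag) ** 2).sum())
        target_norm = torch.sqrt((y_mag ** 2).sum())
        sc_losses.append(diff_norm / target_norm)

        # Log STFT magnitude: mean absolute error between log magnitudes.
        mag_losses.append(F.l1_loss(torch.log(x_mag), torch.log(y_mag)))

    # Average each loss across all resolutions.
    return torch.stack(sc_losses).mean(), torch.stack(mag_losses).mean()


# Illustrative (FFT size, hop size, window length) tuples, not project defaults.
example_resolutions = [(512, 128, 512), (1024, 256, 1024), (256, 64, 256)]
target = torch.randn(2, 4096)                      # ground-truth batch (B, T)
predicted = target + 0.1 * torch.randn(2, 4096)    # noisy prediction (B, T)
sc_loss, mag_loss = multi_resolution_stft_loss(predicted, target, example_resolutions)
total_loss = sc_loss + mag_loss                    # one common way to combine them
```

Both returned values are scalar tensors, so they can be summed, optionally with weights, and backpropagated like any other training loss.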
Source code in training/loss/multi_resolution_stft_loss.py
__init__(resolutions)
Initialize Multi resolution STFT loss module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
resolutions | list | List of (FFT size, shift size, window length). | required |
forward(x, y)
Calculate forward propagation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Predicted signal (B, T). | required |
y | Tensor | Groundtruth signal (B, T). | required |
Returns:
Name | Type | Description |
---|---|---|
Tensor | Tensor | Multi resolution spectral convergence loss value. |
Tensor | Tensor | Multi resolution log STFT magnitude loss value. |