
Log STFT Magnitude Loss

LogSTFTMagnitudeLoss

Bases: Module

Log STFT magnitude loss module. Log STFT magnitude loss is a loss function commonly used in speech and audio signal processing tasks such as speech enhancement and source separation. It is closely related to the spectral convergence loss, and the two are often used together. Both measure the similarity between a predicted and a groundtruth magnitude spectrogram, but the log STFT magnitude loss compares them on a logarithmic scale.
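The following stdlib-only sketch makes the contrast with spectral convergence concrete. It scores two equally sized relative errors on a two-bin toy "spectrogram". It assumes the commonly used definition of spectral convergence, ‖|Y| − |X|‖_F / ‖|Y|‖_F (the Frobenius norm of the magnitude difference divided by that of the target). That definition is not part of this module. The function names are hypothetical.

```python
import math


def spectral_convergence(x_mag, y_mag):
    """Assumed common definition: ||Y - X||_F / ||Y||_F on flattened magnitudes."""
    num = math.sqrt(sum((y - x) ** 2 for x, y in zip(x_mag, y_mag)))
    den = math.sqrt(sum(y ** 2 for y in y_mag))
    return num / den


def log_magnitude_l1(x_mag, y_mag):
    """Mean absolute difference of natural-log magnitudes (as in forward)."""
    return sum(abs(math.log(y) - math.log(x)) for x, y in zip(x_mag, y_mag)) / len(y_mag)


target = [100.0, 1.0]     # toy spectrogram: one loud bin, one quiet bin
loud_err = [90.0, 1.0]    # 10% error on the loud bin
quiet_err = [100.0, 0.9]  # 10% error on the quiet bin

# Spectral convergence: about 0.1 vs about 0.001 (dominated by the loud bin).
print(spectral_convergence(loud_err, target), spectral_convergence(quiet_err, target))
# Log magnitude L1: the same value (about 0.0527) for both cases.
print(log_magnitude_l1(loud_err, target), log_magnitude_l1(quiet_err, target))
```

Spectral convergence is dominated by the loud bin. The log-magnitude term gives the same score to a 10% error in either bin, which is why the two are complementary.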

The log STFT magnitude loss is calculated as the mean absolute error (L1 distance) between the natural logarithms of the predicted and groundtruth magnitude spectrograms. Taking the logarithm puts the magnitudes on a scale proportional to decibels. This compresses their large dynamic range and is closer to how loudness is perceived than the linear scale. A difference of logarithms is a log ratio, so the loss measures relative rather than absolute errors, and quiet time-frequency bins contribute as much as loud ones.
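Here is the computation in forward written out. It trims both inputs to a common frame count T, takes natural logarithms, then applies F.l1_loss with its default mean reduction. X and Y are the predicted and groundtruth magnitude spectrograms of shape (B, T, K), where K is the number of frequency bins. Because 20 log10 equals (20 / ln 10) times the natural log, the loss equals the mean absolute error in decibels divided by about 8.686.

$$
\mathcal{L}_{\mathrm{mag}}(X, Y) = \frac{1}{B\,T\,K} \sum_{b=1}^{B} \sum_{t=1}^{T} \sum_{k=1}^{K} \bigl| \ln Y_{b,t,k} - \ln X_{b,t,k} \bigr|
$$

$$
20 \log_{10} X_{b,t,k} = \frac{20}{\ln 10} \ln X_{b,t,k} \approx 8.686 \, \ln X_{b,t,k}
\quad\Longrightarrow\quad
\mathcal{L}_{\mathrm{mag}}(X, Y) \approx \frac{1}{8.686} \cdot \frac{1}{B\,T\,K} \sum_{b,t,k} \bigl| 20 \log_{10} Y_{b,t,k} - 20 \log_{10} X_{b,t,k} \bigr|
$$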

Source code in training/loss/log_stft_magnitude_loss.py
import torch
import torch.nn.functional as F
from torch.nn import Module


class LogSTFTMagnitudeLoss(Module):
    r"""Log STFT magnitude loss module.

    Log STFT magnitude loss is a loss function commonly used in speech and audio signal processing tasks such as speech enhancement and source separation. It is closely related to the spectral convergence loss, and the two are often used together. Both measure the similarity between a predicted and a groundtruth magnitude spectrogram, but the log STFT magnitude loss compares them on a logarithmic scale.

    The log STFT magnitude loss is calculated as the mean absolute error (L1 distance) between the natural logarithms of the predicted and groundtruth magnitude spectrograms. Taking the logarithm puts the magnitudes on a scale proportional to decibels. This compresses their large dynamic range and is closer to how loudness is perceived than the linear scale. A difference of logarithms is a log ratio, so the loss measures relative rather than absolute errors, and quiet time-frequency bins contribute as much as loud ones.
    """

    def __init__(self):
        r"""Initialize log STFT magnitude loss module."""
        super().__init__()

    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor) -> torch.Tensor:
        r"""Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        # Trim both spectrograms to the shorter number of frames (dimension 1)
        min_len = min(x_mag.shape[1], y_mag.shape[1])
        x_mag = x_mag[:, :min_len]
        y_mag = y_mag[:, :min_len]

        # Mean absolute (L1) difference of the log magnitudes. torch.log(0) is -inf,
        # so magnitudes must be strictly positive (clamp them upstream if needed).
        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
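This is a minimal usage sketch. It assumes the magnitudes come from torch.stft with a Hann window. The helper stft_magnitude and its n_fft, hop_length and eps values are hypothetical choices, not taken from this repository. The module is restated compactly so the snippet runs on its own.

```python
import torch
import torch.nn.functional as F
from torch.nn import Module


class LogSTFTMagnitudeLoss(Module):
    """Compact restatement of the module above so this snippet is self-contained."""

    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor) -> torch.Tensor:
        min_len = min(x_mag.shape[1], y_mag.shape[1])
        x_mag, y_mag = x_mag[:, :min_len], y_mag[:, :min_len]
        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))


def stft_magnitude(
    x: torch.Tensor, n_fft: int = 512, hop_length: int = 128, eps: float = 1e-7
) -> torch.Tensor:
    """Hypothetical helper: (B, T) waveform -> (B, #frames, #freq_bins) magnitude."""
    window = torch.hann_window(n_fft, device=x.device)
    spec = torch.stft(
        x, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True
    )  # complex, shape (B, #freq_bins, #frames)
    # Clamp to a small floor so torch.log in the loss never sees an exact zero,
    # then move frames to dimension 1 as the loss expects.
    return spec.abs().clamp(min=eps).transpose(1, 2)


torch.manual_seed(0)
target = torch.randn(2, 16000)                     # (B, T) groundtruth waveforms
prediction = target + 0.1 * torch.randn(2, 16000)  # noisy "predicted" waveforms

criterion = LogSTFTMagnitudeLoss()
x_mag = stft_magnitude(prediction)
y_mag = stft_magnitude(target)
loss = criterion(x_mag, y_mag)
print(tuple(x_mag.shape), loss.item())
```

The clamp in the helper matters. Without it, an all-zero frame (for example, digital silence) would make torch.log return -inf, and the loss would become infinite or NaN.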

__init__()

Initialize log STFT magnitude loss module.


forward(x_mag, y_mag)

Calculate forward propagation.

Parameters:

x_mag (Tensor, required): Magnitude spectrogram of the predicted signal, shape (B, #frames, #freq_bins).
y_mag (Tensor, required): Magnitude spectrogram of the groundtruth signal, shape (B, #frames, #freq_bins).

Returns:

Tensor: Log STFT magnitude loss value (a scalar, averaged over all elements).
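The sketch below uses toy tensors to illustrate two properties of forward. The shapes and the 2x gain are arbitrary choices for illustration, and the function name is hypothetical.

* Inputs with different frame counts are trimmed to the shorter one.
* Because ln(gY) − ln(Y) = ln g, a uniform gain error g yields a loss of |ln g| regardless of the signal level.

The computation is restated as a function so the snippet runs on its own.

```python
import math

import torch
import torch.nn.functional as F


def log_stft_magnitude_loss(x_mag: torch.Tensor, y_mag: torch.Tensor) -> torch.Tensor:
    """Same computation as LogSTFTMagnitudeLoss.forward, written as a function."""
    min_len = min(x_mag.shape[1], y_mag.shape[1])  # trim to the shorter frame count
    x_mag, y_mag = x_mag[:, :min_len], y_mag[:, :min_len]
    return F.l1_loss(torch.log(y_mag), torch.log(x_mag))


torch.manual_seed(0)
y_mag = torch.rand(1, 12, 5) + 0.1  # groundtruth: 12 frames, 5 bins, strictly positive
x_mag = 2.0 * y_mag[:, :10]         # prediction: only 10 frames, uniform 2x (~6 dB) gain error

loss = log_stft_magnitude_loss(x_mag, y_mag)  # computed on the first 10 frames only
print(loss.item(), math.log(2.0))            # the two values agree up to float32 rounding
```

Scaling both inputs by the same factor leaves the loss unchanged. That factor-invariance is what lets quiet and loud passages count equally.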
