
STFT

stft(x, fft_size, hop_size, win_length, window)

Perform an STFT and convert the result to a magnitude spectrogram. The Short-Time Fourier Transform (STFT) analyzes how the frequency content of a signal changes over time. The signal is split into short, typically overlapping frames whose starts are `hop_size` samples apart. Each frame is multiplied by a window, and a Fourier transform is computed for each frame. Stacking the per-frame spectra gives a time-frequency representation of the signal, and this function keeps only the magnitude of each complex coefficient.
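The framing, windowing, and per-frame FFT steps described above can be sketched in NumPy for a single 1-D signal. This is a minimal illustration, not the code path used by `stft` (which calls `torch.stft`). `naive_magnitude_stft` is a hypothetical helper. It assumes the window already has length `fft_size` and imitates the `torch.stft` defaults `center=True` and `pad_mode="reflect"`. It also omits the clamping that `stft` applies.

```python
import numpy as np


def naive_magnitude_stft(x, fft_size, hop_size, window):
    """Hypothetical sketch: magnitude STFT of a 1-D signal.

    Assumes len(window) == fft_size. Pads fft_size // 2 samples on both
    sides by reflection, similar to torch.stft(center=True, pad_mode="reflect").
    Returns an array of shape (#frames, fft_size // 2 + 1).
    """
    pad = fft_size // 2
    x = np.pad(x, pad, mode="reflect")
    n_frames = 1 + (len(x) - fft_size) // hop_size
    # Slice the padded signal into frames that start hop_size samples apart.
    frames = np.stack(
        [x[i * hop_size : i * hop_size + fft_size] for i in range(n_frames)]
    )
    # Window each frame, then take the one-sided FFT along the last axis.
    spec = np.fft.rfft(frames * window, axis=-1)
    return np.abs(spec)


# A sine with exactly 8 cycles per 64 samples should peak in frequency bin 8.
x = np.sin(2 * np.pi * 8 * np.arange(256) / 64)
window = 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(64) / 64)  # periodic Hann
mag = naive_magnitude_stft(x, fft_size=64, hop_size=16, window=window)
print(mag.shape)  # (17, 33)
```

The result is laid out as (frames, frequency bins), which matches one batch item of the (B, #frames, fft_size // 2 + 1) output of `stft`.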

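Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Tensor` | Input signal tensor of shape (B, T). | required |
| `fft_size` | `int` | FFT size (points per frame). The output has `fft_size // 2 + 1` frequency bins. | required |
| `hop_size` | `int` | Hop size: number of samples between the starts of consecutive frames. | required |
| `win_length` | `int` | Window length in samples. Must not exceed `fft_size`. | required |
| `window` | `Tensor` | 1-D window tensor of length `win_length`, for example `torch.hann_window(win_length)`. | required |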
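The `torch.stft` documentation states that when `win_length < fft_size`, the window is zero-padded on both sides to length `fft_size` before it is applied. The NumPy sketch below illustrates that padding. `pad_window_to_fft_size` is a hypothetical helper that `stft` does not call. It assumes the left pad is `(fft_size - win_length) // 2`, which is our understanding of PyTorch's implementation.

```python
import numpy as np


def pad_window_to_fft_size(window, fft_size):
    """Hypothetical sketch: center-pad a window with zeros to fft_size."""
    win_length = len(window)
    if win_length > fft_size:
        raise ValueError("win_length must not exceed fft_size")
    # Assumed split: any odd leftover sample goes to the right side.
    left = (fft_size - win_length) // 2
    right = fft_size - win_length - left
    return np.pad(window, (left, right))


print(pad_window_to_fft_size(np.ones(6), 10).tolist())
# [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```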

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | Magnitude spectrogram of shape (B, #frames, fft_size // 2 + 1). |
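`stft` does not override the `center=True` default of `torch.stft`, so `fft_size // 2` samples are padded on each side of the signal. For a signal of `T` samples, the number of frames is therefore `1 + (T + 2 * (fft_size // 2) - fft_size) // hop_size`. For an even `fft_size`, this reduces to `1 + T // hop_size`. `num_frames` below is a hypothetical helper for checking this arithmetic and is not part of the module.

```python
def num_frames(num_samples, fft_size, hop_size, center=True):
    """Hypothetical helper: STFT frame count for a signal of num_samples."""
    if center:
        # torch.stft(center=True) pads fft_size // 2 samples on both sides.
        num_samples += 2 * (fft_size // 2)
    return 1 + (num_samples - fft_size) // hop_size


# 1 s of 16 kHz audio, fft_size=1024, hop_size=256:
frames = num_frames(16000, fft_size=1024, hop_size=256)
bins = 1024 // 2 + 1
print(frames, bins)  # 63 513, so the output shape is (B, 63, 513)
```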

Source code in training/loss/stft.py
```python
import torch


def stft(
    x: torch.Tensor,
    fft_size: int,
    hop_size: int,
    win_length: int,
    window: torch.Tensor,
) -> torch.Tensor:
    r"""Perform STFT and convert to magnitude spectrogram.

    STFT stands for Short-Time Fourier Transform. It analyzes how the frequency
    content of a signal changes over time: the signal is split into short
    windowed frames, and a Fourier transform is computed for each frame,
    giving a time-frequency representation of the signal.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor of length win_length.

    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    # Complex STFT of shape (B, fft_size // 2 + 1, #frames).
    x_stft = torch.stft(
        x,
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    # Move real and imaginary parts into a trailing dimension of size 2.
    x_stft = torch.view_as_real(x_stft)

    real = x_stft[..., 0]
    imag = x_stft[..., 1]

    # NOTE (kan-bayashi): clamp is needed to avoid nan or inf
    # (the gradient of sqrt is infinite at zero power).
    # Transpose to (B, #frames, fft_size // 2 + 1).
    return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
```
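The clamp in the last line matters mainly for backpropagation. The derivative of `sqrt(y)` is `1 / (2 * sqrt(y))`, which is infinite at `y = 0`, for example in a silent frame. Through the chain rule with `real**2`, that becomes `0 * inf = nan`. The sketch below reproduces this with autograd. `magnitude` is a hypothetical stand-in for the final line of `stft`, not a function from the module.

```python
import torch


def magnitude(real, imag, eps=None):
    """Hypothetical stand-in for the magnitude step, with optional clamping."""
    power = real**2 + imag**2
    if eps is not None:
        power = torch.clamp(power, min=eps)
    return torch.sqrt(power)


# Silent frame: real and imaginary parts are exactly zero.
real_a = torch.zeros(4, requires_grad=True)
magnitude(real_a, torch.zeros(4)).sum().backward()
print(real_a.grad)  # all nan: 0 * inf from the chain rule at sqrt(0)

real_b = torch.zeros(4, requires_grad=True)
mag_b = magnitude(real_b, torch.zeros(4), eps=1e-7)
mag_b.sum().backward()
print(real_b.grad)  # finite: clamp passes no gradient below eps
print(mag_b)        # each value is sqrt(1e-7), about 3.16e-4
```

A side effect is that the spectrogram never drops below `sqrt(1e-7)`, which also keeps a downstream logarithm finite.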