Skip to content

TacotronSTFT

TacotronSTFT

Bases: Module

Source code in training/preprocess/tacotron_stft.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
class TacotronSTFT(Module):
    r"""STFT front end producing linear and log-compressed mel spectrograms.

    The mel filterbank and the Hann window are stored as buffers
    (``mel_basis`` and ``hann_window``) so they follow the module across
    devices and are included in its state dict.
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        center: bool,
        mel_fmax: Optional[int],
        mel_fmin: float = 0.0,
    ):
        r"""Build the mel filterbank and analysis window for the given STFT setup.

        Args:
            filter_length (int): FFT size (stored as ``n_fft``).
            hop_length (int): Number of samples between successive frames.
            win_length (int): Length of the Hann analysis window.
            n_mel_channels (int): Number of mel bins in the filterbank.
            sampling_rate (int): Sampling rate of the input waveforms.
            center (bool): Forwarded to ``torch.stft`` as ``center``. Note that
                ``_spectrogram`` always reflect-pads the waveform by
                ``(n_fft - hop_size) / 2`` on each side before the STFT.
            mel_fmax (int or None): Upper frequency bound handed to
                ``librosa.filters.mel``.
            mel_fmin (float): Lower frequency bound handed to
                ``librosa.filters.mel``.
        """
        super().__init__()

        # STFT geometry.
        self.n_fft = filter_length
        self.hop_size = hop_length
        self.win_size = win_length
        self.center = center

        # Mel filterbank configuration.
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.fmin = mel_fmin
        self.fmax = mel_fmax

        # librosa returns a NumPy matrix; keep it as a float32 tensor.
        filterbank = librosa.filters.mel(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        window = torch.hann_window(win_length)

        self.register_buffer("mel_basis", torch.from_numpy(filterbank).float())
        self.register_buffer("hann_window", window)

    def _spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Run the STFT on a batch of waves and expose it as real/imag pairs.

        Args:
            y (torch.Tensor): Input waveforms with values in [-1, 1].

        Returns:
            torch.Tensor: ``torch.view_as_real`` view of the one-sided complex
            STFT, i.e. with a trailing axis of size 2 (real, imaginary).
        """
        # Samples are expected to be normalised to [-1, 1].
        assert torch.min(y.data) >= -1
        assert torch.max(y.data) <= 1

        # Symmetric reflect padding; a channel axis is added temporarily
        # for the pad call and removed right after.
        side = int((self.n_fft - self.hop_size) / 2)
        padded = torch.nn.functional.pad(
            y.unsqueeze(1),
            (side, side),
            mode="reflect",
        ).squeeze(1)

        complex_spec = torch.stft(
            padded,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,  # type: ignore
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        return torch.view_as_real(complex_spec)

    def linear_spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Compute the magnitude (linear) spectrogram of a batch of waves.

        Args:
            y (torch.Tensor): Input waveforms.

        Returns:
            torch.Tensor: L2 norm of the STFT over its real/imag axis.
        """
        return torch.norm(self._spectrogram(y), p=2, dim=-1)

    def forward(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Compute the magnitude spectrogram and the log-mel spectrogram.

        Args:
            y (torch.FloatTensor): Input waveforms with shape (B, T) in range [-1, 1]

        Returns:
            torch.FloatTensor: Magnitude spectrogram of shape (B, n_fft // 2 + 1, frames)
            torch.FloatTensor: Log-compressed mel-spectrogram of shape (B, n_mel_channels, frames)
        """
        real_imag = self._spectrogram(y)

        # The small epsilon under the square root is kept from the original
        # formula (it keeps the root away from exactly zero).
        energy = real_imag.pow(2).sum(-1)
        spec = torch.sqrt(energy + (1e-9))

        projected = torch.matmul(self.mel_basis, spec)  # type: ignore
        mel = self.spectral_normalize_torch(projected)

        return spec, mel

    def spectral_normalize_torch(self, magnitudes: torch.Tensor) -> torch.Tensor:
        r"""Log-compress magnitudes using the default compression settings.

        Args:
            magnitudes (torch.Tensor): Input magnitudes.

        Returns:
            torch.Tensor: Result of ``dynamic_range_compression_torch(magnitudes)``.
        """
        compressed = self.dynamic_range_compression_torch(magnitudes)
        return compressed

    def dynamic_range_compression_torch(
        self,
        x: torch.Tensor,
        C: int = 1,
        clip_val: float = 1e-5,
    ) -> torch.Tensor:
        r"""Apply ``log(clamp(x, min=clip_val) * C)`` element-wise.

        Args:
            x (torch.Tensor): Input tensor.
            C (int): Multiplicative factor applied before the logarithm.
            clip_val (float): Lower bound applied to ``x`` before scaling.

        Returns:
            torch.Tensor: Log-compressed tensor.
        """
        floored = torch.clamp(x, min=clip_val)
        return torch.log(floored * C)

    # NOTE: audio np.ndarray changed to torch.FloatTensor!
    def get_mel_from_wav(self, audio: torch.Tensor) -> torch.Tensor:
        r"""Compute the log-mel spectrogram of one waveform without gradients.

        Args:
            audio (torch.Tensor): A single waveform; a leading batch axis is
                added before calling ``forward``.

        Returns:
            torch.Tensor: Mel-spectrogram with the batch axis removed again.
        """
        batch = audio.unsqueeze(0)
        with torch.no_grad():
            mel = self.forward(batch)[1]
        return mel.squeeze(0)

__init__(filter_length, hop_length, win_length, n_mel_channels, sampling_rate, center, mel_fmax, mel_fmin=0.0)

TacotronSTFT module that computes mel-spectrograms from a batch of waves.

Parameters:

Name Type Description Default
filter_length int

Length of the filter window.

required
hop_length int

Number of samples between successive frames.

required
win_length int

Size of the STFT window.

required
n_mel_channels int

Number of mel bins.

required
sampling_rate int

Sampling rate of the input waveforms.

required
mel_fmin float

Minimum frequency for the mel filter bank.

0.0
mel_fmax int or None

Maximum frequency for the mel filter bank.

required
center bool

Whether to pad the input signal on both sides.

required
Source code in training/preprocess/tacotron_stft.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def __init__(
    self,
    filter_length: int,
    hop_length: int,
    win_length: int,
    n_mel_channels: int,
    sampling_rate: int,
    center: bool,
    mel_fmax: Optional[int],
    mel_fmin: float = 0.0,
):
    r"""Build the mel filterbank and analysis window for the given STFT setup.

    Args:
        filter_length (int): FFT size (stored as ``n_fft``).
        hop_length (int): Number of samples between successive frames.
        win_length (int): Length of the Hann analysis window.
        n_mel_channels (int): Number of mel bins in the filterbank.
        sampling_rate (int): Sampling rate of the input waveforms.
        center (bool): Stored as ``self.center`` for later use by the STFT.
        mel_fmax (int or None): Upper frequency bound handed to
            ``librosa.filters.mel``.
        mel_fmin (float): Lower frequency bound handed to
            ``librosa.filters.mel``.
    """
    super().__init__()

    # STFT geometry.
    self.n_fft = filter_length
    self.hop_size = hop_length
    self.win_size = win_length
    self.center = center

    # Mel filterbank configuration.
    self.n_mel_channels = n_mel_channels
    self.sampling_rate = sampling_rate
    self.fmin = mel_fmin
    self.fmax = mel_fmax

    # librosa returns a NumPy matrix; keep it as a float32 tensor.
    filterbank = librosa.filters.mel(
        sr=sampling_rate,
        n_fft=filter_length,
        n_mels=n_mel_channels,
        fmin=mel_fmin,
        fmax=mel_fmax,
    )
    window = torch.hann_window(win_length)

    # Buffers are registered in the same order as before: mel_basis first.
    self.register_buffer("mel_basis", torch.from_numpy(filterbank).float())
    self.register_buffer("hann_window", window)

dynamic_range_compression_torch(x, C=1, clip_val=1e-05)

Applies dynamic range compression to x.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
C int

Compression factor.

1
clip_val float

Clipping value.

1e-05

Returns:

Type Description
Tensor

torch.Tensor: Output tensor.

Source code in training/preprocess/tacotron_stft.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def dynamic_range_compression_torch(
    self,
    x: torch.Tensor,
    C: int = 1,
    clip_val: float = 1e-5,
) -> torch.Tensor:
    r"""Apply ``log(clamp(x, min=clip_val) * C)`` element-wise.

    Args:
        x (torch.Tensor): Input tensor.
        C (int): Multiplicative factor applied before the logarithm.
        clip_val (float): Lower bound applied to ``x`` before scaling.

    Returns:
        torch.Tensor: Log-compressed tensor.
    """
    # Clamp first so the logarithm never sees values below clip_val.
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)

forward(y)

Computes mel-spectrograms from a batch of waves.

Parameters:

Name Type Description Default
y FloatTensor

Input waveforms with shape (B, T) in range [-1, 1]

required

Returns:

Type Description
Tensor

torch.FloatTensor: Spectrogram of shape (B, n_spec_channels, T)

Tensor

torch.FloatTensor: Mel-spectrogram of shape (B, n_mel_channels, T)

Source code in training/preprocess/tacotron_stft.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def forward(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the magnitude spectrogram and the log-mel spectrogram.

    Args:
        y (torch.FloatTensor): Input waveforms with shape (B, T) in range [-1, 1]

    Returns:
        torch.FloatTensor: Spectrogram of shape (B, n_spec_channels, T)
        torch.FloatTensor: Mel-spectrogram of shape (B, n_mel_channels, T)
    """
    # The last axis of the `_spectrogram` output is reduced away below.
    real_imag = self._spectrogram(y)

    # The small epsilon under the square root is kept from the original
    # formula (it keeps the root away from exactly zero).
    energy = real_imag.pow(2).sum(-1)
    spec = torch.sqrt(energy + (1e-9))

    projected = torch.matmul(self.mel_basis, spec)  # type: ignore
    mel = self.spectral_normalize_torch(projected)

    return spec, mel

linear_spectrogram(y)

Computes the linear spectrogram of a batch of waves.

Parameters:

Name Type Description Default
y Tensor

Input waveforms.

required

Returns:

Type Description
Tensor

torch.Tensor: Linear spectrogram.

Source code in training/preprocess/tacotron_stft.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
def linear_spectrogram(self, y: torch.Tensor) -> torch.Tensor:
    r"""Compute the magnitude (linear) spectrogram of a batch of waves.

    Args:
        y (torch.Tensor): Input waveforms.

    Returns:
        torch.Tensor: L2 norm of the ``_spectrogram`` output over its last axis.
    """
    return torch.norm(self._spectrogram(y), p=2, dim=-1)

spectral_normalize_torch(magnitudes)

Applies dynamic range compression to magnitudes.

Parameters:

Name Type Description Default
magnitudes Tensor

Input magnitudes.

required

Returns:

Type Description
Tensor

torch.Tensor: Output magnitudes.

Source code in training/preprocess/tacotron_stft.py
126
127
128
129
130
131
132
133
134
135
def spectral_normalize_torch(self, magnitudes: torch.Tensor) -> torch.Tensor:
    r"""Log-compress magnitudes using the default compression settings.

    Args:
        magnitudes (torch.Tensor): Input magnitudes.

    Returns:
        torch.Tensor: Result of ``dynamic_range_compression_torch(magnitudes)``.
    """
    compressed = self.dynamic_range_compression_torch(magnitudes)
    return compressed