
LibriTTS dataset vocoder

LibriTTSDatasetVocoder

Bases: Dataset

Dataset for loading preprocessed UnivNet vocoder data.

Source code in training/datasets/libritts_dataset_vocoder.py
class LibriTTSDatasetVocoder(Dataset):
    r"""Loading preprocessed univnet model data."""

    def __init__(
        self,
        root: str,
        batch_size: int,
        download: bool = True,
        lang: str = "en",
    ):
        r"""A PyTorch dataset for loading preprocessed univnet data.

        Args:
            root (str): Path to the directory where the dataset is found or downloaded.
            batch_size (int): Batch size for the dataset.
            download (bool, optional): Whether to download the dataset if it is not found. Defaults to True.
        """
        self.dataset = datasets.LIBRITTS(root=root, download=download)
        self.batch_size = batch_size

        lang_map = get_lang_map(lang)
        self.preprocess_libtts = PreprocessLibriTTS(
            PreprocessingConfigUnivNet(lang_map.processing_lang_type),
        )

    def __len__(self) -> int:
        r"""Returns the number of samples in the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r"""Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to return.

        Returns:
            Dict[str, Any]: A dictionary containing the sample data.
        """
        # Retrieve the dataset row
        data = self.dataset[idx]

        data = self.preprocess_libtts.univnet(data)

        if data is None:
            # print("Skipping due to preprocessing error")
            rand_idx = np.random.randint(0, self.__len__())
            return self.__getitem__(rand_idx)

        mel, audio, speaker_id = data

        return {
            "mel": mel,
            "audio": audio,
            "speaker_id": speaker_id,
        }

    def collate_fn(self, data: List) -> List:
        r"""Collates a batch of data samples.

        Args:
            data (List): A list of data samples.

        Returns:
            List: A list of collated batch tensors: mels, mel_lens, audios, speaker_ids.
        """
        data_size = len(data)

        idxs = list(range(data_size))

        # Initialize empty lists to store extracted values
        empty_lists: List[List] = [[] for _ in range(4)]
        (
            mels,
            mel_lens,
            audios,
            speaker_ids,
        ) = empty_lists

        # Extract fields from data dictionary and populate the lists
        for idx in idxs:
            data_entry = data[idx]

            mels.append(data_entry["mel"])
            mel_lens.append(data_entry["mel"].shape[1])
            audios.append(data_entry["audio"])
            speaker_ids.append(data_entry["speaker_id"])

        mels = torch.tensor(pad_2D(mels), dtype=torch.float32)
        mel_lens = torch.tensor(mel_lens, dtype=torch.int64)
        audios = torch.tensor(pad_1D(audios), dtype=torch.float32)
        speaker_ids = torch.tensor(speaker_ids, dtype=torch.int64)

        return [
            mels,
            mel_lens,
            audios,
            speaker_ids,
        ]
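
A minimal usage sketch for wiring the dataset into a PyTorch DataLoader. The root path, batch size, and import path are illustrative assumptions, not taken from the repository:

from torch.utils.data import DataLoader

from training.datasets.libritts_dataset_vocoder import LibriTTSDatasetVocoder  # assumed import path

# Hypothetical data root and batch size.
dataset = LibriTTSDatasetVocoder(root="./data", batch_size=16, download=True)

loader = DataLoader(
    dataset,
    batch_size=dataset.batch_size,
    shuffle=True,
    collate_fn=dataset.collate_fn,
)

# Each batch is the list returned by collate_fn.
mels, mel_lens, audios, speaker_ids = next(iter(loader))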

__getitem__(idx)

Returns a sample from the dataset at the given index.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| idx | int | Index of the sample to return. | required |

Returns:

| Type | Description |
|------|-------------|
| Dict[str, Any] | A dictionary containing the sample data. |

Source code in training/datasets/libritts_dataset_vocoder.py
def __getitem__(self, idx: int) -> Dict[str, Any]:
    r"""Returns a sample from the dataset at the given index.

    Args:
        idx (int): Index of the sample to return.

    Returns:
        Dict[str, Any]: A dictionary containing the sample data.
    """
    # Retrieve the dataset row
    data = self.dataset[idx]

    data = self.preprocess_libtts.univnet(data)

    if data is None:
        # print("Skipping due to preprocessing error")
        rand_idx = np.random.randint(0, self.__len__())
        return self.__getitem__(rand_idx)

    mel, audio, speaker_id = data

    return {
        "mel": mel,
        "audio": audio,
        "speaker_id": speaker_id,
    }
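
For illustration, a single sample from the dataset built above can be inspected directly; the index is arbitrary, and the time-axis comment follows from collate_fn reading mel.shape[1] as the frame count:

sample = dataset[0]

mel = sample["mel"]                # mel spectrogram, time dimension on axis 1
audio = sample["audio"]            # waveform for the same utterance
speaker_id = sample["speaker_id"]  # integer speaker identifier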

__init__(root, batch_size, download=True, lang='en')

A PyTorch dataset for loading preprocessed UnivNet data.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| root | str | Path to the directory where the dataset is found or downloaded. | required |
| batch_size | int | Batch size for the dataset. | required |
| download | bool | Whether to download the dataset if it is not found. | True |
| lang | str | Language code used to select the preprocessing configuration. | 'en' |
Source code in training/datasets/libritts_dataset_vocoder.py
def __init__(
    self,
    root: str,
    batch_size: int,
    download: bool = True,
    lang: str = "en",
):
    r"""A PyTorch dataset for loading preprocessed univnet data.

    Args:
        root (str): Path to the directory where the dataset is found or downloaded.
        batch_size (int): Batch size for the dataset.
        download (bool, optional): Whether to download the dataset if it is not found. Defaults to True.
    """
    self.dataset = datasets.LIBRITTS(root=root, download=download)
    self.batch_size = batch_size

    lang_map = get_lang_map(lang)
    self.preprocess_libtts = PreprocessLibriTTS(
        PreprocessingConfigUnivNet(lang_map.processing_lang_type),
    )

__len__()

Returns the number of samples in the dataset.

Returns:

| Type | Description |
|------|-------------|
| int | Number of samples in the dataset. |

Source code in training/datasets/libritts_dataset_vocoder.py
def __len__(self) -> int:
    r"""Returns the number of samples in the dataset.

    Returns:
        int: Number of samples in the dataset.
    """
    return len(self.dataset)

collate_fn(data)

Collates a batch of data samples.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| data | List | A list of data samples. | required |

Returns:

| Type | Description |
|------|-------------|
| List | A list of collated batch tensors: mels, mel_lens, audios, speaker_ids. |

Source code in training/datasets/libritts_dataset_vocoder.py
def collate_fn(self, data: List) -> List:
    r"""Collates a batch of data samples.

    Args:
        data (List): A list of data samples.

    Returns:
        List: A list of collated batch tensors: mels, mel_lens, audios, speaker_ids.
    """
    data_size = len(data)

    idxs = list(range(data_size))

    # Initialize empty lists to store extracted values
    empty_lists: List[List] = [[] for _ in range(4)]
    (
        mels,
        mel_lens,
        audios,
        speaker_ids,
    ) = empty_lists

    # Extract fields from data dictionary and populate the lists
    for idx in idxs:
        data_entry = data[idx]

        mels.append(data_entry["mel"])
        mel_lens.append(data_entry["mel"].shape[1])
        audios.append(data_entry["audio"])
        speaker_ids.append(data_entry["speaker_id"])

    mels = torch.tensor(pad_2D(mels), dtype=torch.float32)
    mel_lens = torch.tensor(mel_lens, dtype=torch.int64)
    audios = torch.tensor(pad_1D(audios), dtype=torch.float32)
    speaker_ids = torch.tensor(speaker_ids, dtype=torch.int64)

    return [
        mels,
        mel_lens,
        audios,
        speaker_ids,
    ]
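
A sketch of how the collated output is typically consumed, assuming the DataLoader configured earlier with collate_fn=dataset.collate_fn; dtypes follow the torch.tensor calls above:

for mels, mel_lens, audios, speaker_ids in loader:
    # mels: padded mel spectrograms, one per sample in the batch (float32)
    # mel_lens: number of mel frames per sample before padding (int64)
    # audios: padded waveforms (float32)
    # speaker_ids: speaker identifiers (int64)
    ...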