Datasets for automix systems#

In this notebook, we will first discuss the datasets used to train automix systems. Thereafter, we will see how to pre-process the data and set up the dataloaders for training the deep learning models behind these systems.

Training automix models requires paired multitrack stems and their corresponding mixdowns. The desired properties for these datasets are listed below:

  1. Time-aligned stems and mixes: We require time-aligned stems and mixes so that the models can learn time-wise transformation relationships.

  2. Diverse instrument categories: The more diverse the set of instruments in the dataset, the more likely the trained system is to perform well on real-world songs.

  3. Diverse genres of songs: Mixing practices vary slightly from one genre to another. Hence, if the dataset has multitrack mixes from different genres, the trained system will be exposed to a more diverse distribution of data.

  4. Dry multitrack stems: Mixing involves processing the recorded dry stems for corrective and aesthetic reasons before summing them to form a cohesive mixture. For a model to learn the correct way to process stems into mixes, we need to train it on pairs of dry, unprocessed stems and mixes. More recently, however, approaches that use processed stems from datasets like MUSDB18 to train automix systems have been explored. These approaches use an effect normalisation method to deal with pre-processed (wet) stems. These methods are beyond the scope of this tutorial, but we recommend having a look at this paper presented at ISMIR 2022.

Here we list the datasets available for training automix systems.

| Dataset | Size (hrs) | No. of songs | No. of instrument categories | No. of tracks | Type | Usage permissions | Other info | Remarks |
|---|---|---|---|---|---|---|---|---|
| MedleyDB | 7.2 | 122 | 82 | 1-26 | Multitrack, WAV | Open | 44.1 kHz, 16 bit, stereo | - |
| ENST Drums | 1.25 | - | 1 | 8 | Drums, WAV/AVI | Limited | 44.1 kHz, 16 bit, stereo | Drums-only dataset |
| Cambridge Multitrack | >3 | >50 | >5 | 5-70 | Multitrack, WAV | Open | 44.1 kHz, 16/24 bit, stereo | Not time-aligned; recordings are not uniform across songs |
| MUSDB18 | ~10 | 150 | 4 | 4 | Multitrack, WAV | Open | 44.1 kHz, stereo | Used mainly for source separation; wet stems |
| Slakh | 145 | 2100 | 34 | 4-48 | Synthesised, FLAC | Open | 44.1 kHz, 16 bit, stereo | Used mainly for source separation; sometimes wet stems |
| Shaking Through | 4.5 | 68 | >30 | >40 | Multitrack, WAV | User only | 44.1/88.2 kHz, 16/24 bit, stereo | - |
| BitMIDI | - | >1M | >5 | >5 | Multitrack MIDI | Open | MIDI data | MIDI data submitted by users worldwide |

For this tutorial, we will use ENST-drums for training the Wave-U-Net, and ENST-drums, DSD100, and MedleyDB for training the Differentiable Mixing Console (DMC).

In the following section, we will discuss the recommended pre-processing methods for these datasets and how to set up dataloaders for training the models. This notebook assumes that you have already installed the automix package.

We define dataset classes for DSD100, MedleyDB, and ENST Drums, and then use the __getitem__() method to load the audio data into the dataloader for training and testing.

Listed below are a few of the variables you are advised to define in the dataset class definition:#

  1. Root directory of the folder containing the dataset.

  2. Length of the audio you wish to load for training/testing.

  3. Sample rate at which you wish to load the audio data.
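
For instance, a minimal skeleton holding these variables might look like this. This is an illustrative sketch, not the toolkit's actual class; the full dataset classes are shown later in this notebook.

```
import torch


class MultitrackDataset(torch.utils.data.Dataset):
    """Minimal skeleton; the real classes below add file discovery and loading."""

    def __init__(self, root_dir: str, length: int, sample_rate: float) -> None:
        self.root_dir = root_dir        # root directory of the dataset
        self.length = length            # number of frames to load per example
        self.sample_rate = sample_rate  # sample rate to load audio at
```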

Pre-processing advice for loading multitrack data:#

  1. Discard the examples from the dataset that are shorter than the prescribed length.

     ```
     # code from automix/data/drums.py
     # remove any mixes that are shorter than the requested length
     self.mix_filepaths = [
         fp
         for fp in self.mix_filepaths
         # use torchaudio.info to read metadata; this is much faster
         # than loading the whole audio file
         if torchaudio.info(fp).num_frames > self.length
     ]
     ```
    
  2. Normalise the stems and the mixes after loading. The toolkit peak normalises, as shown below; a loudness-based alternative is sketched after this list.

     ```
     # code from automix/data/drums.py
     y /= y.abs().max().clamp(1e-8)  # peak normalise
     ```
    
  3. Look out for silence in the loaded audio: Common practice is to generate a random starting index for the frame from which the audio is loaded. However, some of the multitrack stems, or the mix as a whole, may contain only silence in the loaded chunk, which produces NaNs in the audio tensor when it is normalised (a short demonstration follows after this list). In the code block below, we show how to check for silence: we keep generating a new starting index (offset) until the loaded audio has some content and is not just silence (silent is False).

     ```
     # code from automix/data/drums.py
     # load a chunk of the mix
     silent = True
     while silent:
         # get random offset
         offset = np.random.randint(0, md.num_frames - self.length - 1)

         y, sr = torchaudio.load(
             mix_filepath,
             frame_offset=offset,
             num_frames=self.length,
         )
         energy = (y**2).mean()
         if energy > 1e-8:
             silent = False

     # only normalise audio that is not silent
     y /= y.abs().max().clamp(1e-8)  # peak normalise
     ```
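
As noted in step 2, the toolkit peak normalises. If you prefer loudness normalisation instead, pyloudnorm (already a dependency of automix-toolkit) provides a BS.1770 meter. Below is a minimal sketch of that alternative; the function name and target level are our own choices, not part of the toolkit:

```
import numpy as np
import pyloudnorm as pyln
import torch


def loudness_normalise(y: torch.Tensor, sample_rate: int, target_lufs: float = -24.0) -> torch.Tensor:
    """Normalise a (channels, samples) tensor to a target integrated loudness."""
    data = y.numpy().T  # pyloudnorm expects (samples, channels)
    meter = pyln.Meter(sample_rate)  # BS.1770 loudness meter
    loudness = meter.integrated_loudness(data)  # measure input loudness (LUFS)
    data = pyln.normalize.loudness(data, loudness, target_lufs)  # apply gain
    return torch.from_numpy(data.T.copy())
```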
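
To see why the silence check (and the clamp) in step 3 matters, here is a small standalone demonstration of the NaN issue:

```
import torch

silent_chunk = torch.zeros(2, 65536)

# dividing by a zero peak produces NaNs (0/0)
print((silent_chunk / silent_chunk.abs().max()).isnan().all())  # tensor(True)

# clamping the divisor avoids the NaNs
print((silent_chunk / silent_chunk.abs().max().clamp(1e-8)).isnan().any())  # tensor(False)
```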
    

ENST Drums#

Described below is the folder structure of the ENST Drums dataset:

  • ENST-Drums

    • drummer_1

      • annotation

      • audio

        • accompaniment

        • dry mix

        • hi-hat

        • kick

        • overhead L

        • overhead R

        • snare

        • tom 1

        • tom 2

        • wet mix

    • drummer_2

      • (same structure as drummer_1)

    • drummer_3

      • (same structure as drummer_1)

We are going to use audio from the wet mix folder for this tutorial.

In automix/data/drums.py, we define an ENSTDrumsDataset class and use its __getitem__() method to load data for the dataloader in our training loop.

class ENSTDrumsDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        root_dir: str,
        length: int,
        sample_rate: float,
        drummers: List[int] = [1, 2],
        track_names: List[str] = [
            "kick",
            "snare",
            "hi-hat",
            "overhead_L",
            "overhead_R",
            "tom_1",
            "tom_2",
            "tom_3",
        ],
        indices: Tuple[int, int] = [0, 1],
        wet_mix: bool = False,
        hits: bool = False,
        num_examples_per_epoch: int = 1000,
        seed: int = 42,
    ) -> None:
  • We use indices to define the train-test split.
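
    A sketch of how this could work is shown below. This is illustrative only; the exact file-discovery logic lives inside the toolkit's dataset class, and the glob pattern and folder names here are assumptions based on the structure above.

    ```
    import glob
    import os

    root_dir = "./ENST-drums"  # hypothetical dataset location
    indices = (0, 1)

    # collect all wet-mix files, then keep only this split's songs
    mix_filepaths = sorted(
        glob.glob(os.path.join(root_dir, "drummer_*", "audio", "wet_mix", "*.wav"))
    )
    mix_filepaths = mix_filepaths[indices[0] : indices[1]]
    ```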

  • In the __getitem__() of the dataset class, we first generate mix_idx, a random number between 0 and the number of songs in the directory (the length of mix_filepaths). This allows us to randomly pick a mix/song from mix_filepaths.

    def __getitem__(self, _):
            # select a mix at random
            mix_idx = np.random.randint(0, len(self.mix_filepaths))
            mix_filepath = self.mix_filepaths[mix_idx]
            example_id = os.path.basename(mix_filepath)
            drummer_id = os.path.normpath(mix_filepath).split(os.path.sep)[-4]
    
            md = torchaudio.info(mix_filepath)  # check length
    
  • Next, we load the mix (y) from the filepath. Make sure to check for silence as discussed above. Once the mix is loaded, peak normalise it.

            # load the chunk of the mix
            silent = True
            while silent:
                # get random offset
                offset = np.random.randint(0, md.num_frames - self.length - 1)
    
                y, sr = torchaudio.load(
                    mix_filepath,
                    frame_offset=offset,
                    num_frames=self.length,
                )
                energy = (y**2).mean()
                if energy > 1e-8:
                    silent = False
    
            y /= y.abs().max().clamp(1e-8)  # peak normalize
    
  • The last step is to load the stems. max_num_tracks is the maximum number of tracks you want to load. Some songs might have fewer or more stems than this number. We keep track of empty stems using pad, which is True whenever a stem is empty. __getitem__() returns the stems tensor (x), the mix (y), and the pad information.

      # -------------------- load the tracks from disk --------------------
      x = torch.zeros((self.max_num_tracks, self.length))
      pad = [True] * self.max_num_tracks  # note which tracks are empty
    
      for tidx, track_name in enumerate(self.track_names):
          track_filepath = os.path.join(
              self.root_dir,
              drummer_id,
              "audio",
              track_name,
              example_id,
          )
          if os.path.isfile(track_filepath):
              x_s, sr = torchaudio.load(
                  track_filepath,
                  frame_offset=offset,
                  num_frames=self.length,
              )
              x_s /= x_s.abs().max().clamp(1e-6)
              x_s *= 10 ** (-12 / 20.0)
              x[tidx, :] = x_s
              pad[tidx] = False
    
      return x, y, pad
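
Putting it together, usage might look like this. This is a hypothetical sketch: the root_dir path is an assumption and the dataset must be present on disk.

```
dataset = ENSTDrumsDataset(
    root_dir="./ENST-drums",  # assumed local path
    length=65536,
    sample_rate=44100,
    drummers=[1, 2],
    wet_mix=True,  # use the wet mixes, as discussed above
)

x, y, pad = dataset[0]
print(x.shape, y.shape)  # stems: (max_num_tracks, length); mix: (channels, length)
```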
    

DSD100 Dataset#

Below described is the folder structure of the DSD100 dataset:

  • DSD100

    • Mixtures

      • Dev

        • Songdir (songname)

          • mixture.wav

      • Test

        • Songdir (songname)

          • mixture.wav

    • Sources

      • Dev

        • Songdir (songname)

          • vocals.wav

          • bass.wav

          • drums.wav

          • other.wav

      • Test

        • (same structure as Dev)

Note: Accompaniment is the sum of bass, drums, and other.

For the purpose of training our models, we use:

Input: vocals, bass, drums, and other

Output: Mixture

We will first define a dataset class and use its __getitem__() method to load items into the dataloader.

#Code from automix/data/dsd100.py

class DSD100Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        root_dir: str,
        length: int,
        sample_rate: float,
        indices: Tuple[int, int],
        track_names: List[str] = ["bass", "drums", "other", "vocals"],
        num_examples_per_epoch: int = 1000,
    ) -> None:

Hereafter, we follow a similar structure in __getitem__() as in the case of ENSTDrums. A condensed sketch follows the list below.

  • We first pick a mix_filepath at random and then look for a non-silent part to load the mix (y).

  • Then, we load the stems (x) starting from the same offset, for the prescribed length.

  • We peak normalise all the loaded stems and the mix and save the empty-stem information in the pad variable.

  • We then return x, y, and pad.
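
Here is a condensed sketch of this flow. It is illustrative only: the actual implementation is in automix/data/dsd100.py, and the stereo channel layout (each stem contributing two rows) is an assumption based on the stem shapes printed later in this notebook.

```
import os

import numpy as np
import torch
import torchaudio


def load_example(mix_filepath: str, track_filepaths: list, length: int):
    """Load a random non-silent chunk of the mix and the stems at the same offset."""
    md = torchaudio.info(mix_filepath)

    # search for a non-silent chunk of the mix
    silent = True
    while silent:
        offset = np.random.randint(0, md.num_frames - length - 1)
        y, sr = torchaudio.load(mix_filepath, frame_offset=offset, num_frames=length)
        silent = (y**2).mean() <= 1e-8
    y /= y.abs().max().clamp(1e-8)  # peak normalise the mix

    # load each stem at the same offset; assumes stereo stems, as in DSD100
    x = torch.zeros(2 * len(track_filepaths), length)
    pad = [True] * len(track_filepaths)
    for tidx, track_filepath in enumerate(track_filepaths):
        if os.path.isfile(track_filepath):
            x_s, sr = torchaudio.load(track_filepath, frame_offset=offset, num_frames=length)
            x_s /= x_s.abs().max().clamp(1e-8)  # peak normalise the stem
            x[2 * tidx : 2 * tidx + 2, :] = x_s  # each channel fills one row
            pad[tidx] = False

    return x, y, pad
```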

MedleyDB Dataset#

Described below is the folder structure for MedleyDB:

  • MedleyDB

    • songname

      • songname_MIX.wav

      • songname_STEMS

        • songname_STEM_{stem_number}.wav

      • songname_RAW

        • songname_RAW_{stem_number}_{track_number}.wav

  • The STEMS folder contains some of the RAW audio tracks combined into single audio files.

  • The RAW folder contains each audio track individually.

We define the corresponding dataset class like before.

class MedleyDBDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        root_dirs: List[str],
        length: int,
        sample_rate: float,
        indices: List[int] = [0, 100],
        max_num_tracks: int = 16,
        num_examples_per_epoch: int = 1000,
        buffer_size_gb: float = 3.0,
        buffer_reload_rate: int = 200,
        normalization: str = "peak",
    ) -> None:
  • indices defines the train-test split.

  • buffer_size_gb specifies the amount of data loaded into RAM.

  • buffer_reload_rate specifies how many examples are served before a new subset is loaded into RAM.

For large datasets like MedleyDB, whose songs can contain a large number of stems, it can be very time-consuming to load audio tracks from disk on every iteration. Instead, we can load a small random subset of the dataset into RAM every few iterations to speed up training.

We load data into RAM (tracking the total bytes with nbytes_loaded) every time items_since_load > buffer_reload_rate:

# code from automix/data/medleydb.py

def reload_buffer(self):
    self.examples = []  # clear buffer
    self.items_since_load = 0  # reset iteration counter
    nbytes_loaded = 0  # counter for data in RAM

    # load a different subset in each reload
    random.shuffle(self.mix_dirs)

    # load files into RAM
    for mix_dir in self.mix_dirs:
        mix_id = os.path.basename(mix_dir)
        mix_filepath = glob.glob(os.path.join(mix_dir, "*.wav"))[0]

        # now check the length of the mix
        try:
            y, sr = torchaudio.load(mix_filepath)
        except:
            print(f"Skipping {mix_filepath}")
            continue

        mix_num_frames = y.shape[-1]
        nbytes = y.element_size() * y.nelement()
        nbytes_loaded += nbytes

        # now find all the track filepaths
        track_filepaths = glob.glob(os.path.join(mix_dir, f"{mix_id}_RAW", "*.wav"))

        if len(track_filepaths) > self.max_num_tracks:
            continue

        # check length of each track
        tracks = []
        for tidx, track_filepath in enumerate(track_filepaths):
            x, sr = torchaudio.load(track_filepath)
            tracks.append(x)

            nbytes = x.element_size() * x.nelement()
            nbytes_loaded += nbytes

            track_num_frames = x.shape[-1]
            if track_num_frames < mix_num_frames:
                mix_num_frames = track_num_frames

        # store this example
        example = {
            "mix_id": os.path.dirname(mix_filepath).split(os.sep)[-1],
            "mix_filepath": mix_filepath,
            "mix_audio": y,
            "num_frames": mix_num_frames,
            "track_filepaths": track_filepaths,
            "track_audio": tracks,
        }

        self.examples.append(example)

        # stop once the buffer size limit is reached
        if nbytes_loaded > self.buffer_size_gb * 1e9:
            break
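
On the consuming side, __getitem__() can then count how many items have been served and trigger reload_buffer() once the counter exceeds buffer_reload_rate. Below is a minimal standalone sketch of this pattern (not the toolkit's exact class):

```
import random


class BufferedExamples:
    """Minimal sketch of the RAM-buffer pattern described above."""

    def __init__(self, filepaths, buffer_reload_rate: int = 200):
        self.filepaths = filepaths
        self.buffer_reload_rate = buffer_reload_rate
        self.examples = []
        self.items_since_load = buffer_reload_rate + 1  # force a load on first access

    def reload_buffer(self):
        # the real class loads audio into RAM; here we just pick a few paths
        self.examples = random.sample(self.filepaths, k=min(4, len(self.filepaths)))
        self.items_since_load = 0

    def __getitem__(self, _):
        self.items_since_load += 1
        if self.items_since_load > self.buffer_reload_rate:
            self.reload_buffer()  # refresh the buffer periodically
        return random.choice(self.examples)
```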
With the dataset classes defined, let's install the automix package so we can try loading some data.

!pip install git+https://github.com/csteinmetz1/automix-toolkit
from automix.data import DSD100Dataset
import torch
import torchaudio
import matplotlib.pyplot as plt
import librosa
import librosa.display
import IPython
import numpy as np
import os

Now we will download a subset of DSD100 and load it using the dataloader.

# First, let's download a subset of DSD100
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/DSD100subset.zip
!unzip -o DSD100subset.zip 
--2022-12-04 03:54:36--  https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/DSD100subset.zip
HTTP request sent, awaiting response... 200 OK
Length: 126074934 (120M) [application/zip]
Saving to: ‘DSD100subset.zip.1’

DSD100subset.zip.1  100%[===================>] 120.23M  64.7MB/s    in 1.9s    

2022-12-04 03:54:39 (64.7 MB/s) - ‘DSD100subset.zip.1’ saved [126074934/126074934]

Archive:  DSD100subset.zip
  inflating: DSD100subset/dsd100.xlsx  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/drums.wav  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/other.wav  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/bass.wav  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/vocals.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/vocals.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/bass.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/drums.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/other.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/bass.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/vocals.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/other.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/drums.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/vocals.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/drums.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/other.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/bass.wav  
  inflating: DSD100subset/Mixtures/Test/005 - Angela Thomas Wade - Milk Cow Blues/mixture.wav  
  inflating: DSD100subset/Mixtures/Test/049 - Young Griffo - Facade/mixture.wav  
  inflating: DSD100subset/Mixtures/Dev/055 - Angels In Amplifiers - I'm Alright/mixture.wav  
  inflating: DSD100subset/Mixtures/Dev/081 - Patrick Talbot - Set Me Free/mixture.wav  
  inflating: DSD100subset/dsd100subset.txt  

Load the dataset#

We will use the DSD100Dataset class from the automix.data module. We load data at a 44.1 kHz sample rate and set the train length to 65536 frames. We will split the dataset so that the first four examples form the train set and the rest the test set; this is indicated using indices.

num_frames = 65536
sample_rate = 44100

train_dataset = DSD100Dataset(
    "./DSD100subset",
    num_frames,
    sample_rate,
    indices=[0, 4],
    num_examples_per_epoch=100,)

#Define the dataloader
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    persistent_workers=True,
)

print(train_dataloader)
100%|██████████| 4/4 [00:00<00:00, 4095.00it/s]
Found 4 mixes. Using 4 in this subset.
<torch.utils.data.dataloader.DataLoader object at 0x7fe9669f5460>

Loop over the dataloader to load examples with a batch size of 1. We will see the shape of the loaded data.

for i, (stems, mix, pad) in enumerate(train_dataloader):
    print("Stems shape: ", stems.shape)
    print("Mix shape: ", mix.shape)
    print("Pad shape: ", len(pad))
    print("Pad: ", pad)
    break
Stems shape:  torch.Size([1, 8, 65536])
Mix shape:  torch.Size([1, 2, 65536])
Pad shape:  1
Pad:  tensor([[False, False, False, False, False, False, False, False]])
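
As a quick sanity check, we can plot and audition the loaded mix chunk using the libraries imported earlier. This is a sketch; mix and sample_rate come from the cells above.

```
import IPython.display as ipd
import matplotlib.pyplot as plt

mix_np = mix[0].numpy()  # first item in the batch: (2, 65536)

plt.figure(figsize=(10, 2))
plt.plot(mix_np[0], linewidth=0.5)  # left channel waveform
plt.title("Loaded mix chunk (left channel)")
plt.tight_layout()
plt.show()

ipd.Audio(mix_np, rate=sample_rate)  # stereo playback in the notebook
```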