Training#

In this notebook we will go through the basic process of training an automatic mixing model. This will involve combining a dataset with a model and an appropriate training loop. For this demonstration we will use PyTorch Lightning to facilitate the training.

Dataset#

For this demonstration we will use a small subset of the DSD100 dataset. This is a music source separation dataset, but we will use it to demonstrate how you can train an automatic mixing model. Since this subset is very small, it can easily be downloaded, but we should not expect the model to perform very well after training.

This notebook can be used as a starting point, for example by swapping in a different dataset such as ENST-drums or MedleyDB after it has been downloaded. Since those datasets are quite large, we will focus only on this small dataset for demonstration purposes.

GPU#

This notebook supports training on the GPU. In Colab, you can enable this by setting the Runtime type to GPU using the menu bar at the top.
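
Before launching training, it can be useful to confirm from Python that a GPU is actually visible to PyTorch. A minimal check might look like this (the `accelerator` variable name is just an illustration, chosen to match the Lightning terminology used in the configuration below):

```python
import torch

# Check whether a CUDA-capable GPU is visible to PyTorch.
use_gpu = torch.cuda.is_available()
accelerator = "gpu" if use_gpu else "cpu"
print(f"CUDA available: {use_gpu} -> accelerator: '{accelerator}'")
```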

Learn More#

If you want to train these models on your own server, with much more control than this demo provides, we encourage you to take a look at the training recipes we provide in the automix-toolkit repository.

But, let’s get started by installing the automix-toolkit.

!pip install git+https://github.com/csteinmetz1/automix-toolkit
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/csteinmetz1/automix-toolkit
  Cloning https://github.com/csteinmetz1/automix-toolkit to /tmp/pip-req-build-xunuziao
  Running command git clone -q https://github.com/csteinmetz1/automix-toolkit /tmp/pip-req-build-xunuziao
Requirement already satisfied: torch in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (1.12.1+cu113)
Requirement already satisfied: torchvision in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (0.13.1+cu113)
Requirement already satisfied: torchaudio in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (0.12.1+cu113)
Collecting pytorch_lightning
  Downloading pytorch_lightning-1.8.3.post1-py3-none-any.whl (798 kB)
     |████████████████████████████████| 798 kB 6.2 MB/s 
Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (4.64.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (1.21.6)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (3.2.2)
Collecting pedalboard
  Downloading pedalboard-0.6.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
     |████████████████████████████████| 3.2 MB 21.9 MB/s 
Requirement already satisfied: scipy in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (1.7.3)
Collecting auraloss
  Downloading auraloss-0.2.2-py3-none-any.whl (15 kB)
Collecting wget
  Downloading wget-3.2.zip (10 kB)
Collecting pyloudnorm
  Downloading pyloudnorm-0.1.0-py3-none-any.whl (9.3 kB)
Collecting sklearn
  Downloading sklearn-0.0.post1.tar.gz (3.6 kB)
Requirement already satisfied: librosa in /usr/local/lib/python3.8/dist-packages (from auraloss->automix-toolkit==0.0.1) (0.8.1)
Requirement already satisfied: joblib>=0.14 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (1.2.0)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (21.3)
Requirement already satisfied: soundfile>=0.10.2 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (0.11.0)
Requirement already satisfied: numba>=0.43.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (0.56.4)
Requirement already satisfied: decorator>=3.0.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (4.4.2)
Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (1.0.2)
Requirement already satisfied: pooch>=1.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (1.6.0)
Requirement already satisfied: audioread>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (3.0.0)
Requirement already satisfied: resampy>=0.2.2 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (0.4.2)
Requirement already satisfied: llvmlite<0.40,>=0.39.0dev0 in /usr/local/lib/python3.8/dist-packages (from numba>=0.43.0->librosa->auraloss->automix-toolkit==0.0.1) (0.39.1)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.8/dist-packages (from numba>=0.43.0->librosa->auraloss->automix-toolkit==0.0.1) (4.13.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.8/dist-packages (from numba>=0.43.0->librosa->auraloss->automix-toolkit==0.0.1) (57.4.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->librosa->auraloss->automix-toolkit==0.0.1) (3.0.9)
Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (2.23.0)
Requirement already satisfied: appdirs>=1.3.0 in /usr/local/lib/python3.8/dist-packages (from pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (1.4.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (2022.9.24)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (3.0.4)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn!=0.19.0,>=0.14.0->librosa->auraloss->automix-toolkit==0.0.1) (3.1.0)
Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.8/dist-packages (from soundfile>=0.10.2->librosa->auraloss->automix-toolkit==0.0.1) (1.15.1)
Requirement already satisfied: pycparser in /usr/local/lib/python3.8/dist-packages (from cffi>=1.0->soundfile>=0.10.2->librosa->auraloss->automix-toolkit==0.0.1) (2.21)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.8/dist-packages (from importlib-metadata->numba>=0.43.0->librosa->auraloss->automix-toolkit==0.0.1) (3.10.0)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->automix-toolkit==0.0.1) (2.8.2)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->automix-toolkit==0.0.1) (1.4.4)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->automix-toolkit==0.0.1) (0.11.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib->automix-toolkit==0.0.1) (1.15.0)
Requirement already satisfied: future>=0.16.0 in /usr/local/lib/python3.8/dist-packages (from pyloudnorm->automix-toolkit==0.0.1) (0.16.0)
Collecting lightning-utilities==0.3.*
  Downloading lightning_utilities-0.3.0-py3-none-any.whl (15 kB)
Requirement already satisfied: PyYAML>=5.4 in /usr/local/lib/python3.8/dist-packages (from pytorch_lightning->automix-toolkit==0.0.1) (6.0)
Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.8/dist-packages (from pytorch_lightning->automix-toolkit==0.0.1) (4.1.1)
Requirement already satisfied: fsspec[http]>2021.06.0 in /usr/local/lib/python3.8/dist-packages (from pytorch_lightning->automix-toolkit==0.0.1) (2022.11.0)
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.11.0-py3-none-any.whl (512 kB)
     |████████████████████████████████| 512 kB 13.2 MB/s 
Collecting tensorboardX>=2.2
  Downloading tensorboardX-2.5.1-py2.py3-none-any.whl (125 kB)
     |████████████████████████████████| 125 kB 16.5 MB/s 
Collecting fire
  Downloading fire-0.4.0.tar.gz (87 kB)
     |████████████████████████████████| 87 kB 4.5 MB/s 
Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.8/dist-packages (from fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (3.8.3)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (1.3.3)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (1.8.1)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (22.1.0)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (1.3.1)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (6.0.2)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (2.1.1)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (4.0.2)
Requirement already satisfied: protobuf<=3.20.1,>=3.8.0 in /usr/local/lib/python3.8/dist-packages (from tensorboardX>=2.2->pytorch_lightning->automix-toolkit==0.0.1) (3.19.6)
Requirement already satisfied: termcolor in /usr/local/lib/python3.8/dist-packages (from fire->lightning-utilities==0.3.*->pytorch_lightning->automix-toolkit==0.0.1) (2.1.1)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.8/dist-packages (from torchvision->automix-toolkit==0.0.1) (7.1.2)
Building wheels for collected packages: automix-toolkit, fire, sklearn, wget
  Building wheel for automix-toolkit (setup.py) ... done
  Created wheel for automix-toolkit: filename=automix_toolkit-0.0.1-py3-none-any.whl size=35727 sha256=41ec86fd33641e15f1b88c98c5f4adaef21655e067705a52bb7e3bbfcc4a912b
  Stored in directory: /tmp/pip-ephem-wheel-cache-g99h00r8/wheels/66/2a/85/4c0a92c4a2d0108f71a9a138ac530a0346a7d57496aaab973a
  Building wheel for fire (setup.py) ... done
  Created wheel for fire: filename=fire-0.4.0-py2.py3-none-any.whl size=115943 sha256=c5af00138987f1be8a3a1becc72566337f2dce59fae222d17be81476dac1f266
  Stored in directory: /root/.cache/pip/wheels/1f/10/06/2a990ee4d73a8479fe2922445e8a876d38cfbfed052284c6a1
  Building wheel for sklearn (setup.py) ... done
  Created wheel for sklearn: filename=sklearn-0.0.post1-py3-none-any.whl size=2344 sha256=0c745271874af668ae419fc82a86022e0853f2bac3348ab8c59677db812e8acf
  Stored in directory: /root/.cache/pip/wheels/14/25/f7/1cc0956978ae479e75140219088deb7a36f60459df242b1a72
  Building wheel for wget (setup.py) ... done
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9674 sha256=63316b14cc93e741ac0fedd98e6bbd8017e1c24dca18f9d40a7d74bab6ea382c
  Stored in directory: /root/.cache/pip/wheels/bd/a8/c3/3cf2c14a1837a4e04bd98631724e81f33f462d86a1d895fae0
Successfully built automix-toolkit fire sklearn wget
Installing collected packages: fire, torchmetrics, tensorboardX, lightning-utilities, wget, sklearn, pytorch-lightning, pyloudnorm, pedalboard, auraloss, automix-toolkit
Successfully installed auraloss-0.2.2 automix-toolkit-0.0.1 fire-0.4.0 lightning-utilities-0.3.0 pedalboard-0.6.6 pyloudnorm-0.1.0 pytorch-lightning-1.8.3.post1 sklearn-0.0.post1 tensorboardX-2.5.1 torchmetrics-0.11.0 wget-3.2
import os
import torch
import pytorch_lightning as pl
import IPython
import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt
import librosa.display

from argparse import Namespace

%matplotlib inline
%load_ext autoreload
%autoreload 2

from automix.data import DSD100Dataset
from automix.system import System

First we will download the pretrained encoder checkpoint and the dataset subset, then unzip the archive.

os.makedirs("checkpoints/", exist_ok=True)
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/encoder.ckpt
!mv encoder.ckpt checkpoints/encoder.ckpt
    
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/DSD100subset.zip
!unzip -o DSD100subset.zip 
--2022-12-01 17:48:19--  https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/encoder.ckpt
Resolving huggingface.co (huggingface.co)... 54.147.99.175, 34.227.196.80, 2600:1f18:147f:e800:3df1:c2fc:20aa:9b45, ...
Connecting to huggingface.co (huggingface.co)|54.147.99.175|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/ec/ee/ecee38df047e3f2db1bd8c31a742f3a08f557470cd67cb487402a9c3ed91b5ea/90c13ab981715e1fc1ae079f15fb6da36d61d6aad29ae5dddd4d3bfd4594546a?response-content-disposition=attachment%3B%20filename%3D%22encoder.ckpt%22&Expires=1670165709&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2VjL2VlL2VjZWUzOGRmMDQ3ZTNmMmRiMWJkOGMzMWE3NDJmM2EwOGY1NTc0NzBjZDY3Y2I0ODc0MDJhOWMzZWQ5MWI1ZWEvOTBjMTNhYjk4MTcxNWUxZmMxYWUwNzlmMTVmYjZkYTM2ZDYxZDZhYWQyOWFlNWRkZGQ0ZDNiZmQ0NTk0NTQ2YT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPWF0dGFjaG1lbnQlM0IlMjBmaWxlbmFtZSUzRCUyMmVuY29kZXIuY2twdCUyMiIsIkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY3MDE2NTcwOX19fV19&Signature=tppl1MHYKCBPl1qmPJq2DoBMHFJMssgCO-6lTTUY79LaUJuFqTV0XnSrlRxbAVKJk5gnAg-OenbPpMAnRkxqTnKEUZFmd3rZritCLbIR9dGyBj93E6wrcOjOuM69kz--z2YGtaWqcCyrniuJHtDFO24PciKp6LteNdhzdSqfUguzBPRJ3r-iQJ8aCPMjaUeAB6KXQvo579ZuqO7cEbuQn-gMvtUlGyHiHwNjNDTf6GhDPGjs4FaTV6g~KOzKDE4imUl6nhQOM3pktQ1D8Mi6Ry0dkMLijAXA8Na0JVXsPyxYQ8Y6mmKq7D-3HS7OG9WOK7i7HdQApZQWEDpfh34cdA__&Key-Pair-Id=KVTP0A1DKRTAX [following]
--2022-12-01 17:48:19--  https://cdn-lfs.huggingface.co/repos/ec/ee/ecee38df047e3f2db1bd8c31a742f3a08f557470cd67cb487402a9c3ed91b5ea/90c13ab981715e1fc1ae079f15fb6da36d61d6aad29ae5dddd4d3bfd4594546a?response-content-disposition=attachment%3B%20filename%3D%22encoder.ckpt%22&Expires=1670165709&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2VjL2VlL2VjZWUzOGRmMDQ3ZTNmMmRiMWJkOGMzMWE3NDJmM2EwOGY1NTc0NzBjZDY3Y2I0ODc0MDJhOWMzZWQ5MWI1ZWEvOTBjMTNhYjk4MTcxNWUxZmMxYWUwNzlmMTVmYjZkYTM2ZDYxZDZhYWQyOWFlNWRkZGQ0ZDNiZmQ0NTk0NTQ2YT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPWF0dGFjaG1lbnQlM0IlMjBmaWxlbmFtZSUzRCUyMmVuY29kZXIuY2twdCUyMiIsIkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY3MDE2NTcwOX19fV19&Signature=tppl1MHYKCBPl1qmPJq2DoBMHFJMssgCO-6lTTUY79LaUJuFqTV0XnSrlRxbAVKJk5gnAg-OenbPpMAnRkxqTnKEUZFmd3rZritCLbIR9dGyBj93E6wrcOjOuM69kz--z2YGtaWqcCyrniuJHtDFO24PciKp6LteNdhzdSqfUguzBPRJ3r-iQJ8aCPMjaUeAB6KXQvo579ZuqO7cEbuQn-gMvtUlGyHiHwNjNDTf6GhDPGjs4FaTV6g~KOzKDE4imUl6nhQOM3pktQ1D8Mi6Ry0dkMLijAXA8Na0JVXsPyxYQ8Y6mmKq7D-3HS7OG9WOK7i7HdQApZQWEDpfh34cdA__&Key-Pair-Id=KVTP0A1DKRTAX
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 108.138.94.122, 108.138.94.25, 108.138.94.14, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|108.138.94.122|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 48624134 (46M) [binary/octet-stream]
Saving to: ‘encoder.ckpt’

encoder.ckpt        100%[===================>]  46.37M   121MB/s    in 0.4s    

2022-12-01 17:48:19 (121 MB/s) - ‘encoder.ckpt’ saved [48624134/48624134]

--2022-12-01 17:48:20--  https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/DSD100subset.zip
Resolving huggingface.co (huggingface.co)... 54.147.99.175, 34.227.196.80, 2600:1f18:147f:e800:3df1:c2fc:20aa:9b45, ...
Connecting to huggingface.co (huggingface.co)|54.147.99.175|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/ec/ee/ecee38df047e3f2db1bd8c31a742f3a08f557470cd67cb487402a9c3ed91b5ea/3544bf18ffbea78aee3273ba8267a6cb15aa04b52bc430e2f39755d40d212208?response-content-disposition=attachment%3B%20filename%3D%22DSD100subset.zip%22&Expires=1670176100&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2VjL2VlL2VjZWUzOGRmMDQ3ZTNmMmRiMWJkOGMzMWE3NDJmM2EwOGY1NTc0NzBjZDY3Y2I0ODc0MDJhOWMzZWQ5MWI1ZWEvMzU0NGJmMThmZmJlYTc4YWVlMzI3M2JhODI2N2E2Y2IxNWFhMDRiNTJiYzQzMGUyZjM5NzU1ZDQwZDIxMjIwOD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPWF0dGFjaG1lbnQlM0IlMjBmaWxlbmFtZSUzRCUyMkRTRDEwMHN1YnNldC56aXAlMjIiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE2NzAxNzYxMDB9fX1dfQ__&Signature=0ybLBK0AJD~mWLyg~Lp1o2gD96bVJ0mRTYkScvfMf5l97oAZnqpmtI4FLHw49OvJugCzoyL5Dc46A5Vy2N28x0~64D1YDSdgoTa6o7tA~DVI857Qjw0~7Ljm8fPdzao6nZnMA9wmbPmUgqj9FnW1J81oTguXLn6jVtrsGQZYbXiN6JOhs~XWfne5m9RsScrzwjxypdBRD158vtRLtOgQCGcKChw83fh3KKMQG5VVSWAQpOMDnrEVhRx8SPCEBmOJx0sTF2SQQ1X94IelnEisF5dOeVgeZSralsEAbKSOA~3FX2wR4KmtFVJY2mvCUgTaz2BNLr57E3r9RrJDJdRZ0w__&Key-Pair-Id=KVTP0A1DKRTAX [following]
--2022-12-01 17:48:20--  https://cdn-lfs.huggingface.co/repos/ec/ee/ecee38df047e3f2db1bd8c31a742f3a08f557470cd67cb487402a9c3ed91b5ea/3544bf18ffbea78aee3273ba8267a6cb15aa04b52bc430e2f39755d40d212208?response-content-disposition=attachment%3B%20filename%3D%22DSD100subset.zip%22&Expires=1670176100&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2VjL2VlL2VjZWUzOGRmMDQ3ZTNmMmRiMWJkOGMzMWE3NDJmM2EwOGY1NTc0NzBjZDY3Y2I0ODc0MDJhOWMzZWQ5MWI1ZWEvMzU0NGJmMThmZmJlYTc4YWVlMzI3M2JhODI2N2E2Y2IxNWFhMDRiNTJiYzQzMGUyZjM5NzU1ZDQwZDIxMjIwOD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPWF0dGFjaG1lbnQlM0IlMjBmaWxlbmFtZSUzRCUyMkRTRDEwMHN1YnNldC56aXAlMjIiLCJDb25kaXRpb24iOnsiRGF0ZUxlc3NUaGFuIjp7IkFXUzpFcG9jaFRpbWUiOjE2NzAxNzYxMDB9fX1dfQ__&Signature=0ybLBK0AJD~mWLyg~Lp1o2gD96bVJ0mRTYkScvfMf5l97oAZnqpmtI4FLHw49OvJugCzoyL5Dc46A5Vy2N28x0~64D1YDSdgoTa6o7tA~DVI857Qjw0~7Ljm8fPdzao6nZnMA9wmbPmUgqj9FnW1J81oTguXLn6jVtrsGQZYbXiN6JOhs~XWfne5m9RsScrzwjxypdBRD158vtRLtOgQCGcKChw83fh3KKMQG5VVSWAQpOMDnrEVhRx8SPCEBmOJx0sTF2SQQ1X94IelnEisF5dOeVgeZSralsEAbKSOA~3FX2wR4KmtFVJY2mvCUgTaz2BNLr57E3r9RrJDJdRZ0w__&Key-Pair-Id=KVTP0A1DKRTAX
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 108.138.94.25, 108.138.94.14, 108.138.94.23, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|108.138.94.25|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 126074934 (120M) [application/zip]
Saving to: ‘DSD100subset.zip’

DSD100subset.zip    100%[===================>] 120.23M  37.3MB/s    in 3.2s    

2022-12-01 17:48:23 (37.3 MB/s) - ‘DSD100subset.zip’ saved [126074934/126074934]

Archive:  DSD100subset.zip
   creating: DSD100subset/
  inflating: DSD100subset/dsd100.xlsx  
   creating: DSD100subset/Sources/
   creating: DSD100subset/Sources/Dev/
   creating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/drums.wav  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/other.wav  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/bass.wav  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/vocals.wav  
   creating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/vocals.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/bass.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/drums.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/other.wav  
   creating: DSD100subset/Sources/Test/
   creating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/bass.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/vocals.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/other.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/drums.wav  
   creating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/vocals.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/drums.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/other.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/bass.wav  
   creating: DSD100subset/Mixtures/
   creating: DSD100subset/Mixtures/Test/
   creating: DSD100subset/Mixtures/Test/005 - Angela Thomas Wade - Milk Cow Blues/
  inflating: DSD100subset/Mixtures/Test/005 - Angela Thomas Wade - Milk Cow Blues/mixture.wav  
   creating: DSD100subset/Mixtures/Test/049 - Young Griffo - Facade/
  inflating: DSD100subset/Mixtures/Test/049 - Young Griffo - Facade/mixture.wav  
   creating: DSD100subset/Mixtures/Dev/
   creating: DSD100subset/Mixtures/Dev/055 - Angels In Amplifiers - I'm Alright/
  inflating: DSD100subset/Mixtures/Dev/055 - Angels In Amplifiers - I'm Alright/mixture.wav  
   creating: DSD100subset/Mixtures/Dev/081 - Patrick Talbot - Set Me Free/
  inflating: DSD100subset/Mixtures/Dev/081 - Patrick Talbot - Set Me Free/mixture.wav  
  inflating: DSD100subset/dsd100subset.txt  

Configuration#

Here we select whether we want to train on the CPU or GPU and which model we will use.

!nvidia-smi # check for GPU
Thu Dec  1 17:48:26 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   56C    P0    29W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
args = {
    "dataset_dir" :  "./DSD100subset",
    "dataset_name" : "DSD100",
    "automix_model" : "dmc",
    "pretrained_encoder" : True,
    "train_length" : 65536,
    "val_length" : 65536,
    "accelerator" : "gpu", # you can select "cpu" or "gpu"
    "devices" : 1, 
    "batch_size" : 4,
    "lr" : 3e-4,
    "max_epochs" : 25, 
    "schedule" : "none",
    "recon_losses" : ["sd"],
    "recon_loss_weights" : [1.0],
    "sample_rate" : 44100,
    "num_workers" : 2,
}
args = Namespace(**args)
    
pl.seed_everything(42, workers=True)
INFO:lightning_lite.utilities.seed:Global seed set to 42
42
# setup callbacks
callbacks = [
    #LogAudioCallback(),
    pl.callbacks.LearningRateMonitor(logging_interval="step"),
    pl.callbacks.ModelCheckpoint(
        filename=f"{args.dataset_name}-{args.automix_model}"
        + "_epoch-{epoch}-step-{step}",
        monitor="val/loss_epoch",
        mode="min",
        save_last=True,
        auto_insert_metric_name=False,
    ),
]

# we will not use Weights & Biases here
#wandb_logger = WandbLogger(save_dir=log_dir, project="automix-notebook")

# create PyTorch Lightning trainer
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)

# create the System
system = System(**vars(args))
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.8/dist-packages/torchaudio/functional/functional.py:539: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (257) may be set too low.
  warnings.warn(
Loaded weights from ./checkpoints/encoder.ckpt

Dataset#

Now we will create the train/val/test datasets. For demonstration purposes, we will use the same four songs across all three splits.

train_dataset = DSD100Dataset(
    args.dataset_dir,
    args.train_length,
    44100,
    indices=[0, 4],
    num_examples_per_epoch=100,
)
val_dataset = DSD100Dataset(
    args.dataset_dir,
    args.val_length,
    44100,
    indices=[0, 4],
    num_examples_per_epoch=100,
)
test_dataset = DSD100Dataset(
    args.dataset_dir,
    args.train_length,
    44100,
    indices=[0, 4],
    num_examples_per_epoch=100,
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    persistent_workers=True,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    persistent_workers=False,
)
100%|██████████| 4/4 [00:00<00:00, 894.93it/s]
Found 4 mixes. Using 4 in this subset.
100%|██████████| 4/4 [00:00<00:00, 2512.69it/s]
Found 4 mixes. Using 4 in this subset.
100%|██████████| 4/4 [00:00<00:00, 4591.47it/s]
Found 4 mixes. Using 4 in this subset.

Logging#

We can launch an instance of TensorBoard within our notebook to monitor the training process. Be patient, it can take ~60 seconds for the window to show.

%load_ext tensorboard
%tensorboard  --logdir="lightning_logs"

Train!#

Now we are ready to launch the training process.

trainer.fit(system, train_dataloader, val_dataloader)
WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /content/lightning_logs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name         | Type                        | Params
-------------------------------------------------------------
0 | model        | DifferentiableMixingConsole | 12.4 M
1 | recon_losses | ModuleDict                  | 0     
2 | sisdr        | SISDRLoss                   | 0     
3 | mrstft       | MultiResolutionSTFTLoss     | 0     
-------------------------------------------------------------
12.4 M    Trainable params
0         Non-trainable params
12.4 M    Total params
49.732    Total estimated model params size (MB)
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py:1558: PossibleUserWarning: The number of training batches (25) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=25` reached.

Test#

After training for a few epochs, we will test the system by creating a mix from one of the songs in the training set.
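
The cell below applies a random gain of up to ±12 dB to each track before mixing. That decibel-to-linear conversion can be sketched on its own (the seeded `rng` here is just for reproducibility of the illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Draw a random gain in [-12, +12] dB, as the cell below does per track.
gain_dB = rng.uniform(-12.0, 12.0)

# Convert decibels to a linear amplitude multiplier: g = 10^(dB / 20).
gain_ln = 10 ** (gain_dB / 20.0)

# Sanity check: +6 dB roughly doubles the amplitude, -6 dB roughly halves it.
print(round(10 ** (6.0 / 20.0), 3), round(10 ** (-6.0 / 20.0), 3))  # → 1.995 0.501
```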

import glob
import torchaudio
start_sample = 262144 * 2
end_sample = 262144 * 3

# load the input tracks
track_dir = "DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/"
track_ext = "wav"

track_filepaths = glob.glob(os.path.join(track_dir, f"*.{track_ext}"))
track_filepaths = sorted(track_filepaths)
track_names = []
tracks = []
for idx, track_filepath in enumerate(track_filepaths):
    x, sr = torchaudio.load(track_filepath)
    x = x[:, start_sample:end_sample]

    # split each stereo file into mono tracks
    for n in range(x.shape[0]):
        x_sub = x[n : n + 1, :]

        # apply a random gain of up to +/-12 dB to each track
        gain_dB = np.random.rand() * 12
        gain_dB *= np.random.choice([1.0, -1.0])
        gain_ln = 10 ** (gain_dB / 20.0)
        x_sub *= gain_ln

        tracks.append(x_sub)
        track_names.append(os.path.basename(track_filepath))
        IPython.display.display(ipd.Audio(x[n, :].view(1, -1).numpy(), rate=sr, normalize=True))
        print(idx + 1, os.path.basename(track_filepath))

# add dummy tracks of silence if needed (mixwaveunet expects a fixed number of inputs)
if system.hparams.automix_model == "mixwaveunet":
    while len(tracks) < 8:
        tracks.append(torch.zeros(x.shape))

# stack tracks into a tensor
tracks = torch.stack(tracks, dim=0)
tracks = tracks.permute(1, 0, 2)
# tracks have shape (1, num_tracks, seq_len)
print(tracks.shape)

# listen to the input (mono) before mixing
input_mix = tracks.sum(dim=1, keepdim=True)
input_mix /= input_mix.abs().max()
print(input_mix.shape)
plt.figure(figsize=(10, 2))
librosa.display.waveshow(input_mix.view(2,-1).numpy(), sr=sr, zorder=3)
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(input_mix.view(1,-1).numpy(), rate=sr, normalize=False))
1 bass.wav
1 bass.wav
2 drums.wav
2 drums.wav
3 other.wav
3 other.wav
4 vocals.wav
4 vocals.wav
torch.Size([1, 8, 262144])
torch.Size([1, 1, 262144])
/usr/local/lib/python3.8/dist-packages/librosa/util/utils.py:198: UserWarning: librosa.util.frame called with axis=-1 on a non-contiguous input. This will result in a copy.
  warnings.warn(
../_images/04_training_16_17.png

Above we can hear the tracks with a simple mono mix. Now we will create a mix with the model we just trained.

tracks = tracks.view(1,8,-1)

with torch.no_grad():
    y_hat, p = system(tracks)

# view the mix
print(y_hat.shape)
y_hat /= y_hat.abs().max()
plt.figure(figsize=(10, 2))
librosa.display.waveshow(y_hat.view(2,-1).cpu().numpy(), sr=sr, zorder=3)
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(y_hat.view(2,-1).cpu().numpy(), rate=sr, normalize=True))

# print the parameters
if system.hparams.automix_model == "dmc":
    for track_fp, param in zip(track_names, p.squeeze()):
        print(os.path.basename(track_fp), param)
torch.Size([1, 2, 262144])
/usr/local/lib/python3.8/dist-packages/librosa/util/utils.py:198: UserWarning: librosa.util.frame called with axis=-1 on a non-contiguous input. This will result in a copy.
  warnings.warn(
../_images/04_training_18_2.png
bass.wav tensor([-7.9330,  0.6107])
bass.wav tensor([0.7972, 0.6918])
drums.wav tensor([-0.9448,  0.6856])
drums.wav tensor([-2.4009,  0.6927])
other.wav tensor([5.5242, 0.7276])
other.wav tensor([10.0456,  0.7959])
vocals.wav tensor([8.6141, 0.7512])
vocals.wav tensor([-0.8755,  0.6691])

You should be able to hear that the levels have been adjusted and the sources panned to sound more like the original mix, indicating that our system has learned to overfit the songs in our very small training set.
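
The two numbers printed for each mono track above are the parameters predicted by the differentiable mixing console. The exact parameter ranges and pan law are defined inside automix-toolkit; purely as an illustrative sketch (the `[gain_dB, pan]` interpretation and the constant-power pan law below are assumptions, not the toolkit's implementation), applying such parameters to mono tracks could look like this:

```python
import numpy as np

def apply_gain_and_pan(tracks, params):
    """Mix mono tracks to stereo with per-track gain (dB) and pan (0=left, 1=right).

    tracks: (num_tracks, num_samples) mono sources
    params: (num_tracks, 2) array of [gain_dB, pan]
    returns: (2, num_samples) stereo mix
    """
    mix = np.zeros((2, tracks.shape[1]))
    for track, (gain_db, pan) in zip(tracks, params):
        gain_lin = 10 ** (gain_db / 20.0)  # dB -> linear amplitude
        theta = pan * (np.pi / 2.0)        # constant-power pan law
        mix[0] += gain_lin * np.cos(theta) * track  # left channel
        mix[1] += gain_lin * np.sin(theta) * track  # right channel
    return mix

# Two constant "tracks" at 0 dB: one panned hard left, one hard right.
tracks = np.ones((2, 4))
params = np.array([[0.0, 0.0], [0.0, 1.0]])
stereo = apply_gain_and_pan(tracks, params)
print(stereo.shape)  # → (2, 4)
```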