Training#
In this notebook we will go through the basic process of training an automatic mixing model. This involves combining a dataset with a model and an appropriate training loop. For this demonstration we will use PyTorch Lightning to facilitate the training.
Dataset#
For this demonstration we will use a subset of the DSD100 dataset. This is a music source separation dataset, but we will use it to demonstrate how you can train an automatic mixing model. Since this is a very small subset of the dataset, it can be downloaded easily, but we should not expect our model to perform very well after training.
This notebook can be used as a starting point, for example by swapping in a different dataset such as ENST-drums or MedleyDB after you have downloaded it. Since those datasets are quite large, we will focus only on this small dataset for demonstration purposes.
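As a sketch of what swapping datasets involves, usually only the dataset options in the training configuration need to change. Note that the `dataset_name` string below is an assumption, not a value taken from this notebook; check the training recipes in the automix-toolkit repository for the identifiers the toolkit actually accepts.

```python
# Hypothetical configuration change for training on ENST-drums instead of
# the DSD100 subset. The "dataset_name" value is an assumption -- consult
# the automix-toolkit training recipes for the supported identifiers.
args = {
    "dataset_dir": "./ENST-drums",   # root of the downloaded dataset
    "dataset_name": "ENST-drums",    # assumed identifier
    # ... keep the remaining options from the DSD100 configuration below
}
```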
GPU#
This notebook supports training with a GPU. You can enable this by setting the Runtime to GPU in Colab using the menu bar at the top.
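As a small convenience sketch (not part of the toolkit), you can pick the Lightning accelerator string programmatically based on whether a CUDA device is visible, falling back to the CPU otherwise:

```python
# Pick the PyTorch Lightning accelerator string based on CUDA availability.
# Falls back to "cpu" when torch is missing or no GPU is visible.
def pick_accelerator() -> str:
    try:
        import torch
        if torch.cuda.is_available():
            return "gpu"
    except ImportError:
        pass
    return "cpu"

print(pick_accelerator())
```

The returned string can be passed directly as the `accelerator` option in the training configuration later in this notebook.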
Learn More#
If you want to train these models on your own server, with much more control than this demo offers, we encourage you to take a look at the training recipes we provide in the automix-toolkit repository.
But, let’s get started by installing the automix-toolkit.
!pip install git+https://github.com/csteinmetz1/automix-toolkit
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/csteinmetz1/automix-toolkit
Cloning https://github.com/csteinmetz1/automix-toolkit to /tmp/pip-req-build-xunuziao
Running command git clone -q https://github.com/csteinmetz1/automix-toolkit /tmp/pip-req-build-xunuziao
Requirement already satisfied: torch in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (1.12.1+cu113)
Requirement already satisfied: torchvision in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (0.13.1+cu113)
Requirement already satisfied: torchaudio in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (0.12.1+cu113)
Collecting pytorch_lightning
Downloading pytorch_lightning-1.8.3.post1-py3-none-any.whl (798 kB)
|████████████████████████████████| 798 kB 6.2 MB/s
Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (4.64.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (1.21.6)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (3.2.2)
Collecting pedalboard
Downloading pedalboard-0.6.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
|████████████████████████████████| 3.2 MB 21.9 MB/s
Requirement already satisfied: scipy in /usr/local/lib/python3.8/dist-packages (from automix-toolkit==0.0.1) (1.7.3)
Collecting auraloss
Downloading auraloss-0.2.2-py3-none-any.whl (15 kB)
Collecting wget
Downloading wget-3.2.zip (10 kB)
Collecting pyloudnorm
Downloading pyloudnorm-0.1.0-py3-none-any.whl (9.3 kB)
Collecting sklearn
Downloading sklearn-0.0.post1.tar.gz (3.6 kB)
Requirement already satisfied: librosa in /usr/local/lib/python3.8/dist-packages (from auraloss->automix-toolkit==0.0.1) (0.8.1)
Requirement already satisfied: joblib>=0.14 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (1.2.0)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (21.3)
Requirement already satisfied: soundfile>=0.10.2 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (0.11.0)
Requirement already satisfied: numba>=0.43.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (0.56.4)
Requirement already satisfied: decorator>=3.0.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (4.4.2)
Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (1.0.2)
Requirement already satisfied: pooch>=1.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (1.6.0)
Requirement already satisfied: audioread>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (3.0.0)
Requirement already satisfied: resampy>=0.2.2 in /usr/local/lib/python3.8/dist-packages (from librosa->auraloss->automix-toolkit==0.0.1) (0.4.2)
Requirement already satisfied: llvmlite<0.40,>=0.39.0dev0 in /usr/local/lib/python3.8/dist-packages (from numba>=0.43.0->librosa->auraloss->automix-toolkit==0.0.1) (0.39.1)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.8/dist-packages (from numba>=0.43.0->librosa->auraloss->automix-toolkit==0.0.1) (4.13.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.8/dist-packages (from numba>=0.43.0->librosa->auraloss->automix-toolkit==0.0.1) (57.4.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->librosa->auraloss->automix-toolkit==0.0.1) (3.0.9)
Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.8/dist-packages (from pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (2.23.0)
Requirement already satisfied: appdirs>=1.3.0 in /usr/local/lib/python3.8/dist-packages (from pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (1.4.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (2022.9.24)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests>=2.19.0->pooch>=1.0->librosa->auraloss->automix-toolkit==0.0.1) (3.0.4)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn!=0.19.0,>=0.14.0->librosa->auraloss->automix-toolkit==0.0.1) (3.1.0)
Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.8/dist-packages (from soundfile>=0.10.2->librosa->auraloss->automix-toolkit==0.0.1) (1.15.1)
Requirement already satisfied: pycparser in /usr/local/lib/python3.8/dist-packages (from cffi>=1.0->soundfile>=0.10.2->librosa->auraloss->automix-toolkit==0.0.1) (2.21)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.8/dist-packages (from importlib-metadata->numba>=0.43.0->librosa->auraloss->automix-toolkit==0.0.1) (3.10.0)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->automix-toolkit==0.0.1) (2.8.2)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->automix-toolkit==0.0.1) (1.4.4)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->automix-toolkit==0.0.1) (0.11.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib->automix-toolkit==0.0.1) (1.15.0)
Requirement already satisfied: future>=0.16.0 in /usr/local/lib/python3.8/dist-packages (from pyloudnorm->automix-toolkit==0.0.1) (0.16.0)
Collecting lightning-utilities==0.3.*
Downloading lightning_utilities-0.3.0-py3-none-any.whl (15 kB)
Requirement already satisfied: PyYAML>=5.4 in /usr/local/lib/python3.8/dist-packages (from pytorch_lightning->automix-toolkit==0.0.1) (6.0)
Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.8/dist-packages (from pytorch_lightning->automix-toolkit==0.0.1) (4.1.1)
Requirement already satisfied: fsspec[http]>2021.06.0 in /usr/local/lib/python3.8/dist-packages (from pytorch_lightning->automix-toolkit==0.0.1) (2022.11.0)
Collecting torchmetrics>=0.7.0
Downloading torchmetrics-0.11.0-py3-none-any.whl (512 kB)
|████████████████████████████████| 512 kB 13.2 MB/s
Collecting tensorboardX>=2.2
Downloading tensorboardX-2.5.1-py2.py3-none-any.whl (125 kB)
|████████████████████████████████| 125 kB 16.5 MB/s
Collecting fire
Downloading fire-0.4.0.tar.gz (87 kB)
|████████████████████████████████| 87 kB 4.5 MB/s
Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.8/dist-packages (from fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (3.8.3)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (1.3.3)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (1.8.1)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (22.1.0)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (1.3.1)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (6.0.2)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (2.1.1)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch_lightning->automix-toolkit==0.0.1) (4.0.2)
Requirement already satisfied: protobuf<=3.20.1,>=3.8.0 in /usr/local/lib/python3.8/dist-packages (from tensorboardX>=2.2->pytorch_lightning->automix-toolkit==0.0.1) (3.19.6)
Requirement already satisfied: termcolor in /usr/local/lib/python3.8/dist-packages (from fire->lightning-utilities==0.3.*->pytorch_lightning->automix-toolkit==0.0.1) (2.1.1)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.8/dist-packages (from torchvision->automix-toolkit==0.0.1) (7.1.2)
Building wheels for collected packages: automix-toolkit, fire, sklearn, wget
Building wheel for automix-toolkit (setup.py) ... done
Created wheel for automix-toolkit: filename=automix_toolkit-0.0.1-py3-none-any.whl size=35727 sha256=41ec86fd33641e15f1b88c98c5f4adaef21655e067705a52bb7e3bbfcc4a912b
Stored in directory: /tmp/pip-ephem-wheel-cache-g99h00r8/wheels/66/2a/85/4c0a92c4a2d0108f71a9a138ac530a0346a7d57496aaab973a
Building wheel for fire (setup.py) ... done
Created wheel for fire: filename=fire-0.4.0-py2.py3-none-any.whl size=115943 sha256=c5af00138987f1be8a3a1becc72566337f2dce59fae222d17be81476dac1f266
Stored in directory: /root/.cache/pip/wheels/1f/10/06/2a990ee4d73a8479fe2922445e8a876d38cfbfed052284c6a1
Building wheel for sklearn (setup.py) ... done
Created wheel for sklearn: filename=sklearn-0.0.post1-py3-none-any.whl size=2344 sha256=0c745271874af668ae419fc82a86022e0853f2bac3348ab8c59677db812e8acf
Stored in directory: /root/.cache/pip/wheels/14/25/f7/1cc0956978ae479e75140219088deb7a36f60459df242b1a72
Building wheel for wget (setup.py) ... done
Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9674 sha256=63316b14cc93e741ac0fedd98e6bbd8017e1c24dca18f9d40a7d74bab6ea382c
Stored in directory: /root/.cache/pip/wheels/bd/a8/c3/3cf2c14a1837a4e04bd98631724e81f33f462d86a1d895fae0
Successfully built automix-toolkit fire sklearn wget
Installing collected packages: fire, torchmetrics, tensorboardX, lightning-utilities, wget, sklearn, pytorch-lightning, pyloudnorm, pedalboard, auraloss, automix-toolkit
Successfully installed auraloss-0.2.2 automix-toolkit-0.0.1 fire-0.4.0 lightning-utilities-0.3.0 pedalboard-0.6.6 pyloudnorm-0.1.0 pytorch-lightning-1.8.3.post1 sklearn-0.0.post1 tensorboardX-2.5.1 torchmetrics-0.11.0 wget-3.2
import os
import torch
import pytorch_lightning as pl
import IPython
import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt
import librosa.display
from argparse import Namespace
%matplotlib inline
%load_ext autoreload
%autoreload 2
from automix.data import DSD100Dataset
from automix.system import System
First we will download the dataset subset and unzip the archive as well as the pretrained encoder checkpoint.
os.makedirs("checkpoints/", exist_ok=True)
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/encoder.ckpt
!mv encoder.ckpt checkpoints/encoder.ckpt
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/DSD100subset.zip
!unzip -o DSD100subset.zip
--2022-12-01 17:48:19-- https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/encoder.ckpt
Resolving huggingface.co (huggingface.co)... 54.147.99.175, 34.227.196.80, 2600:1f18:147f:e800:3df1:c2fc:20aa:9b45, ...
Connecting to huggingface.co (huggingface.co)|54.147.99.175|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/ec/ee/ecee38df047e3f2db1bd8c31a742f3a08f557470cd67cb487402a9c3ed91b5ea/90c13ab...?response-content-disposition=attachment%3B%20filename%3D%22encoder.ckpt%22&... [following]
--2022-12-01 17:48:19-- https://cdn-lfs.huggingface.co/repos/ec/ee/ecee38df.../90c13ab...?response-content-disposition=attachment%3B%20filename%3D%22encoder.ckpt%22&...
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 108.138.94.122, 108.138.94.25, 108.138.94.14, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|108.138.94.122|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 48624134 (46M) [binary/octet-stream]
Saving to: ‘encoder.ckpt’
encoder.ckpt 100%[===================>] 46.37M 121MB/s in 0.4s
2022-12-01 17:48:19 (121 MB/s) - ‘encoder.ckpt’ saved [48624134/48624134]
--2022-12-01 17:48:20-- https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/DSD100subset.zip
Resolving huggingface.co (huggingface.co)... 54.147.99.175, 34.227.196.80, 2600:1f18:147f:e800:3df1:c2fc:20aa:9b45, ...
Connecting to huggingface.co (huggingface.co)|54.147.99.175|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/ec/ee/ecee38df047e3f2db1bd8c31a742f3a08f557470cd67cb487402a9c3ed91b5ea/3544bf18...?response-content-disposition=attachment%3B%20filename%3D%22DSD100subset.zip%22&... [following]
--2022-12-01 17:48:20-- https://cdn-lfs.huggingface.co/repos/ec/ee/ecee38df.../3544bf18...?response-content-disposition=attachment%3B%20filename%3D%22DSD100subset.zip%22&...
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 108.138.94.25, 108.138.94.14, 108.138.94.23, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|108.138.94.25|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 126074934 (120M) [application/zip]
Saving to: ‘DSD100subset.zip’
DSD100subset.zip 100%[===================>] 120.23M 37.3MB/s in 3.2s
2022-12-01 17:48:23 (37.3 MB/s) - ‘DSD100subset.zip’ saved [126074934/126074934]
Archive: DSD100subset.zip
creating: DSD100subset/
inflating: DSD100subset/dsd100.xlsx
creating: DSD100subset/Sources/
creating: DSD100subset/Sources/Dev/
creating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/
inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/drums.wav
inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/other.wav
inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/bass.wav
inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/vocals.wav
creating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/
inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/vocals.wav
inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/bass.wav
inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/drums.wav
inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/other.wav
creating: DSD100subset/Sources/Test/
creating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/
inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/bass.wav
inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/vocals.wav
inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/other.wav
inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/drums.wav
creating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/
inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/vocals.wav
inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/drums.wav
inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/other.wav
inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/bass.wav
creating: DSD100subset/Mixtures/
creating: DSD100subset/Mixtures/Test/
creating: DSD100subset/Mixtures/Test/005 - Angela Thomas Wade - Milk Cow Blues/
inflating: DSD100subset/Mixtures/Test/005 - Angela Thomas Wade - Milk Cow Blues/mixture.wav
creating: DSD100subset/Mixtures/Test/049 - Young Griffo - Facade/
inflating: DSD100subset/Mixtures/Test/049 - Young Griffo - Facade/mixture.wav
creating: DSD100subset/Mixtures/Dev/
creating: DSD100subset/Mixtures/Dev/055 - Angels In Amplifiers - I'm Alright/
inflating: DSD100subset/Mixtures/Dev/055 - Angels In Amplifiers - I'm Alright/mixture.wav
creating: DSD100subset/Mixtures/Dev/081 - Patrick Talbot - Set Me Free/
inflating: DSD100subset/Mixtures/Dev/081 - Patrick Talbot - Set Me Free/mixture.wav
inflating: DSD100subset/dsd100subset.txt
Configuration#
Here we select whether we want to train on the CPU or GPU, and which model we will use.
!nvidia-smi # check for GPU
Thu Dec 1 17:48:26 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 56C P0 29W / 70W | 0MiB / 15109MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
args = {
"dataset_dir" : "./DSD100subset",
"dataset_name" : "DSD100",
"automix_model" : "dmc",
"pretrained_encoder" : True,
"train_length" : 65536,
"val_length" : 65536,
"accelerator" : "gpu", # you can select "cpu" or "gpu"
"devices" : 1,
"batch_size" : 4,
"lr" : 3e-4,
"max_epochs" : 25,
"schedule" : "none",
"recon_losses" : ["sd"],
"recon_loss_weights" : [1.0],
"sample_rate" : 44100,
"num_workers" : 2,
}
args = Namespace(**args)
pl.seed_everything(42, workers=True)
INFO:lightning_lite.utilities.seed:Global seed set to 42
42
# setup callbacks
callbacks = [
#LogAudioCallback(),
pl.callbacks.LearningRateMonitor(logging_interval="step"),
pl.callbacks.ModelCheckpoint(
filename=f"{args.dataset_name}-{args.automix_model}"
+ "_epoch-{epoch}-step-{step}",
monitor="val/loss_epoch",
mode="min",
save_last=True,
auto_insert_metric_name=False,
),
]
# we will not use Weights & Biases in this demo
#wandb_logger = WandbLogger(save_dir=log_dir, project="automix-notebook")
# create PyTorch Lightning trainer
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)
# create the System
system = System(**vars(args))
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.8/dist-packages/torchaudio/functional/functional.py:539: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (257) may be set too low.
warnings.warn(
Loaded weights from ./checkpoints/encoder.ckpt
Dataset#
Now we will create datasets for train/val/test, but here we will use the same four songs across all splits for demonstration purposes.
train_dataset = DSD100Dataset(
args.dataset_dir,
args.train_length,
44100,
indices=[0, 4],
num_examples_per_epoch=100,
)
val_dataset = DSD100Dataset(
args.dataset_dir,
args.val_length,
44100,
indices=[0, 4],
num_examples_per_epoch=100,
)
test_dataset = DSD100Dataset(
args.dataset_dir,
args.train_length,
44100,
indices=[0, 4],
num_examples_per_epoch=100,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
persistent_workers=True,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
persistent_workers=False,
)
100%|██████████| 4/4 [00:00<00:00, 894.93it/s]
Found 4 mixes. Using 4 in this subset.
100%|██████████| 4/4 [00:00<00:00, 2512.69it/s]
Found 4 mixes. Using 4 in this subset.
100%|██████████| 4/4 [00:00<00:00, 4591.47it/s]
Found 4 mixes. Using 4 in this subset.
Logging#
We can launch an instance of TensorBoard within our notebook to monitor the training process. Be patient; it can take around 60 seconds for the window to appear.
%load_ext tensorboard
%tensorboard --logdir="lightning_logs"
Train!#
Now we are ready to launch the training process.
trainer.fit(system, train_dataloader, val_dataloader)
WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /content/lightning_logs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
| Name | Type | Params
-------------------------------------------------------------
0 | model | DifferentiableMixingConsole | 12.4 M
1 | recon_losses | ModuleDict | 0
2 | sisdr | SISDRLoss | 0
3 | mrstft | MultiResolutionSTFTLoss | 0
-------------------------------------------------------------
12.4 M Trainable params
0 Non-trainable params
12.4 M Total params
49.732 Total estimated model params size (MB)
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py:1558: PossibleUserWarning: The number of training batches (25) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=25` reached.
Test#
After training for a few epochs we will test the system by creating a mix from one of the songs that was in the training set.
import glob
import torchaudio
start_sample = 262144 * 2
end_sample = 262144 * 3
# load the input tracks
track_dir = "DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/"
track_ext = "wav"
track_filepaths = glob.glob(os.path.join(track_dir, f"*.{track_ext}"))
track_filepaths = sorted(track_filepaths)
track_names = []
tracks = []
for idx, track_filepath in enumerate(track_filepaths):
x, sr = torchaudio.load(track_filepath)
x = x[:, start_sample: end_sample]
for n in range(x.shape[0]):
x_sub = x[n:n+1, :]
gain_dB = np.random.rand() * 12
gain_dB *= np.random.choice([1.0, -1.0])
gain_ln = 10 ** (gain_dB/20.0)
x_sub *= gain_ln
tracks.append(x_sub)
track_names.append(os.path.basename(track_filepath))
IPython.display.display(ipd.Audio(x[n, :].view(1,-1).numpy(), rate=sr, normalize=True))
print(idx+1, os.path.basename(track_filepath))
# add dummy tracks of silence if needed
if system.hparams.automix_model == "mixwaveunet" and len(tracks) < 8:
tracks.append(torch.zeros(x.shape))
# stack tracks into a tensor
tracks = torch.stack(tracks, dim=0)
tracks = tracks.permute(1, 0, 2)
# tracks have shape (1, num_tracks, seq_len)
print(tracks.shape)
# listen to the input (mono) before mixing
input_mix = tracks.sum(dim=1, keepdim=True)
input_mix /= input_mix.abs().max()
print(input_mix.shape)
plt.figure(figsize=(10, 2))
librosa.display.waveshow(input_mix.view(-1).numpy(), sr=sr, zorder=3)  # mono waveform
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(input_mix.view(1,-1).numpy(), rate=sr, normalize=False))
1 bass.wav
1 bass.wav
2 drums.wav
2 drums.wav
3 other.wav
3 other.wav
4 vocals.wav
4 vocals.wav
torch.Size([1, 8, 262144])
torch.Size([1, 1, 262144])
/usr/local/lib/python3.8/dist-packages/librosa/util/utils.py:198: UserWarning: librosa.util.frame called with axis=-1 on a non-contiguous input. This will result in a copy.
warnings.warn(
Above we can hear the tracks with a simple mono mix. Now we will create a mix with the model we just trained.
tracks = tracks.view(1,8,-1)
with torch.no_grad():
y_hat, p = system(tracks)
# view the mix
print(y_hat.shape)
y_hat /= y_hat.abs().max()
plt.figure(figsize=(10, 2))
librosa.display.waveshow(y_hat.view(2,-1).cpu().numpy(), sr=sr, zorder=3)
plt.ylim([-1,1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(y_hat.view(2,-1).cpu().numpy(), rate=sr, normalize=True))
# print the parameters
if system.hparams.automix_model == "dmc":
for track_fp, param in zip(track_names, p.squeeze()):
print(os.path.basename(track_fp), param)
torch.Size([1, 2, 262144])
/usr/local/lib/python3.8/dist-packages/librosa/util/utils.py:198: UserWarning: librosa.util.frame called with axis=-1 on a non-contiguous input. This will result in a copy.
warnings.warn(
bass.wav tensor([-7.9330, 0.6107])
bass.wav tensor([0.7972, 0.6918])
drums.wav tensor([-0.9448, 0.6856])
drums.wav tensor([-2.4009, 0.6927])
other.wav tensor([5.5242, 0.7276])
other.wav tensor([10.0456, 0.7959])
vocals.wav tensor([8.6141, 0.7512])
vocals.wav tensor([-0.8755, 0.6691])
You should be able to hear that the levels have been adjusted and the sources panned to sound more like the original mix, indicating that our system has learned to overfit the songs in our very small training set.
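To get a feel for what the two printed parameters per track might mean, here is a small stdlib-only sketch. Interpreting them as (gain in dB, pan position in [0, 1]) with constant-power panning is an assumption on our part; check the `DifferentiableMixingConsole` source in automix-toolkit to confirm how the parameters are actually mapped to the mix.

```python
import math

# Convert an assumed (gain in dB, pan in [0, 1]) parameter pair into
# left/right channel amplitudes using constant-power panning. This mapping
# is an assumption for illustration, not taken from the toolkit source.
def interpret_params(gain_dB: float, pan: float):
    gain_lin = 10 ** (gain_dB / 20.0)  # dB -> linear amplitude
    theta = pan * math.pi / 2          # 0 = hard left, 1 = hard right
    left = gain_lin * math.cos(theta)
    right = gain_lin * math.sin(theta)
    return left, right

# A pan of 0.5 splits the signal equally between the channels.
l, r = interpret_params(0.0, 0.5)
print(f"left={l:.3f} right={r:.3f}")
```

Under this reading, a track like `other.wav tensor([10.0456, 0.7959])` would be boosted by about 10 dB and panned toward the right channel.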