Training#

In this notebook we will go through the basic process of training an automatic mixing model. This involves combining a dataset with a model and an appropriate training loop. For this demonstration we will use PyTorch Lightning to facilitate the training.

Dataset#

For this demonstration we will use a small subset of the DSD100 dataset. DSD100 is a music source separation dataset, but we will repurpose it here to demonstrate how to train an automatic mixing model. Since the subset is very small it can be downloaded quickly, but we should not expect our model to perform very well after training on it.

This notebook can be used as a starting point, for example by swapping in a different dataset such as ENST-drums or MedleyDB after it has been downloaded, as sketched below. Since those datasets are quite large, we will focus on this small subset for demonstration purposes.
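Swapping datasets could look like the following. This is a minimal sketch: the class name ENSTDrumsDataset and its constructor arguments are assumptions modeled on the DSD100Dataset usage later in this notebook, so check the automix-toolkit source for the exact signature.

from automix.data import ENSTDrumsDataset  # assumed import; verify against the repo

# hypothetical drop-in replacement, mirroring the DSD100Dataset usage below
train_dataset = ENSTDrumsDataset(
    "./ENST-drums",  # path to the downloaded dataset
    65536,           # example length in samples
    44100,           # sample rate
)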

GPU#

This notebook supports training on the GPU. In Colab, you can enable this by setting the Runtime type to GPU using the menu bar at the top.
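Once the imports below have run, you can also confirm from Python that PyTorch sees a GPU; a minimal check using standard PyTorch:

import torch

# report the device that training would use
if torch.cuda.is_available():
    print("GPU available:", torch.cuda.get_device_name(0))
else:
    print("No GPU found; training will run on the CPU.")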

Learn More#

If you want to train these models on your own server with more control than this demo provides, we encourage you to take a look at the training recipes in the automix-toolkit repository.

But, let’s get started by installing the automix-toolkit.

!pip install git+https://github.com/csteinmetz1/automix-toolkit
import os
import torch
import pytorch_lightning as pl
import IPython
import numpy as np
import IPython.display as ipd
import matplotlib.pyplot as plt
import librosa.display

from argparse import Namespace

%matplotlib inline
%load_ext autoreload
%autoreload 2

from automix.data import DSD100Dataset
from automix.system import System

First we will download the dataset subset and unzip the archive as well as the pretrained encoder checkpoint.

os.makedirs("checkpoints/", exist_ok=True)
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/encoder.ckpt
!mv encoder.ckpt checkpoints/encoder.ckpt
    
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/DSD100subset.zip
!unzip -o DSD100subset.zip 
--2024-08-29 16:42:07--  https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/encoder.ckpt
Resolving huggingface.co (huggingface.co)... connected.
HTTP request sent, awaiting response... 200 OK
Length: 48624134 (46M) [binary/octet-stream]
Saving to: ‘encoder.ckpt’

2024-08-29 16:42:08 (82.4 MB/s) - ‘encoder.ckpt’ saved [48624134/48624134]

--2024-08-29 16:42:08--  https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/DSD100subset.zip
HTTP request sent, awaiting response... 200 OK
Length: 126074934 (120M) [application/zip]
Saving to: ‘DSD100subset.zip’

2024-08-29 16:42:10 (103 MB/s) - ‘DSD100subset.zip’ saved [126074934/126074934]

Archive:  DSD100subset.zip
  inflating: DSD100subset/dsd100.xlsx  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/drums.wav  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/other.wav  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/bass.wav  
  inflating: DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/vocals.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/vocals.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/bass.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/drums.wav  
  inflating: DSD100subset/Sources/Dev/055 - Angels In Amplifiers - I'm Alright/other.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/bass.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/vocals.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/other.wav  
  inflating: DSD100subset/Sources/Test/049 - Young Griffo - Facade/drums.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/vocals.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/drums.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/other.wav  
  inflating: DSD100subset/Sources/Test/005 - Angela Thomas Wade - Milk Cow Blues/bass.wav  
  inflating: DSD100subset/Mixtures/Test/005 - Angela Thomas Wade - Milk Cow Blues/mixture.wav  
  inflating: DSD100subset/Mixtures/Test/049 - Young Griffo - Facade/mixture.wav  
  inflating: DSD100subset/Mixtures/Dev/055 - Angels In Amplifiers - I'm Alright/mixture.wav  
  inflating: DSD100subset/Mixtures/Dev/081 - Patrick Talbot - Set Me Free/mixture.wav  
  inflating: DSD100subset/dsd100subset.txt  

Configuration#

Here we select whether we want to train on the CPU or GPU, and configure which model we will use.

!nvidia-smi # check for GPU
Thu Aug 29 16:42:11 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX A6000               Off | 00000000:18:00.0 Off |                  Off |
| 30%   31C    P8              16W / 250W |    359MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000               Off | 00000000:3B:00.0 Off |                  Off |
| 30%   33C    P8              23W / 250W |     12MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA RTX A6000               Off | 00000000:86:00.0 Off |                  Off |
| 30%   32C    P8              15W / 250W |     12MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA RTX A6000               Off | 00000000:AF:00.0 Off |                  Off |
| 30%   30C    P8              26W / 250W |     12MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      2096      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A   2427920      C   .../anaconda3/envs/dafxtest/bin/python      344MiB |
|    1   N/A  N/A      2096      G   /usr/lib/xorg/Xorg                            4MiB |
|    2   N/A  N/A      2096      G   /usr/lib/xorg/Xorg                            4MiB |
|    3   N/A  N/A      2096      G   /usr/lib/xorg/Xorg                            4MiB |
+---------------------------------------------------------------------------------------+
args = {
    "dataset_dir" : "./DSD100subset",  # path to the dataset root
    "dataset_name" : "DSD100",
    "automix_model" : "dmc",           # Differentiable Mixing Console
    "pretrained_encoder" : True,       # load the encoder checkpoint we downloaded
    "train_length" : 65536,            # training example length in samples
    "val_length" : 65536,              # validation example length in samples
    "accelerator" : "gpu",             # you can select "cpu" or "gpu"
    "devices" : 1,
    "batch_size" : 4,
    "lr" : 3e-4,
    "max_epochs" : 10,
    "schedule" : "none",
    "recon_losses" : ["sd"],           # reconstruction loss(es) to optimize
    "recon_loss_weights" : [1.0],
    "sample_rate" : 44100,
    "num_workers" : 2,
}
args = Namespace(**args)
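Rather than hard-coding the accelerator, you could also select it based on what is actually available; a small sketch using standard PyTorch:

# fall back to the CPU automatically when no GPU is visible
args.accelerator = "gpu" if torch.cuda.is_available() else "cpu"
print("Using accelerator:", args.accelerator)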
    
pl.seed_everything(42, workers=True)
Seed set to 42
42
# setup callbacks
callbacks = [
    #LogAudioCallback(),
    pl.callbacks.LearningRateMonitor(logging_interval="step"),
    pl.callbacks.ModelCheckpoint(
        filename=f"{args.dataset_name}-{args.automix_model}"
        + "_epoch-{epoch}-step-{step}",
        monitor="val/loss_epoch",
        mode="min",
        save_last=True,
        auto_insert_metric_name=False,
    ),
]

# we will not use Weights & Biases in this notebook
#wandb_logger = WandbLogger(save_dir=log_dir, project="automix-notebook")

# create the PyTorch Lightning trainer
trainer = pl.Trainer(
    max_epochs=args.max_epochs,
    accelerator=args.accelerator,
    devices=args.devices,
    callbacks=callbacks,
    # Add other trainer arguments here if needed
)

# create the System
system = System(**vars(args))
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/martinez/Documents/anaconda3/envs/dafx24/lib/python3.9/site-packages/torchaudio/functional/functional.py:584: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (257) may be set too low.
  warnings.warn(
Loaded weights from ./checkpoints/encoder.ckpt

Dataset#

Now we will create train/val/test datasets. For demonstration purposes we will use the same four songs across all three splits.

train_dataset = DSD100Dataset(
    args.dataset_dir,
    args.train_length,
    44100,
    indices=[0, 4],
    num_examples_per_epoch=100,
)
val_dataset = DSD100Dataset(
    args.dataset_dir,
    args.val_length,
    44100,
    indices=[0, 4],
    num_examples_per_epoch=100,
)
test_dataset = DSD100Dataset(
    args.dataset_dir,
    args.train_length,
    44100,
    indices=[0, 4],
    num_examples_per_epoch=100,
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    persistent_workers=True,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    persistent_workers=False,
)
100%|██████████| 4/4 [00:00<00:00, 3587.94it/s]
Found 4 mixes. Using 4 in this subset.
100%|██████████| 4/4 [00:00<00:00, 3676.80it/s]
Found 4 mixes. Using 4 in this subset.
100%|██████████| 4/4 [00:00<00:00, 3710.95it/s]
Found 4 mixes. Using 4 in this subset.
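Before training, it can be useful to sanity-check the tensors coming out of the dataloader. The sketch below assumes each example is a (tracks, mix) pair; if DSD100Dataset returns additional items, adjust the unpacking accordingly.

# draw one batch and inspect the shapes (assumed (tracks, mix) pairs)
batch = next(iter(train_dataloader))
tracks, mix = batch[0], batch[1]
print(tracks.shape)  # expected: (batch_size, num_tracks, train_length)
print(mix.shape)     # expected: (batch_size, 2, train_length)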

Logging#

We can launch an instance of TensorBoard within the notebook to monitor the training process. Be patient; it can take around 60 seconds for the window to appear.

%load_ext tensorboard
%tensorboard  --logdir="lightning_logs"
Reusing TensorBoard on port 6006 (pid 2428432), started 23 days, 22:22:30 ago. (Use '!kill 2428432' to kill it.)

Train!#

Now we are ready to launch the training process.

trainer.fit(system, train_dataloader, val_dataloader)
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name         | Type                        | Params | Mode 
---------------------------------------------------------------------
0 | model        | DifferentiableMixingConsole | 12.4 M | train
1 | recon_losses | ModuleDict                  | 0      | train
2 | sisdr        | SISDRLoss                   | 0      | train
3 | mrstft       | MultiResolutionSTFTLoss     | 0      | train
---------------------------------------------------------------------
12.4 M    Trainable params
0         Non-trainable params
12.4 M    Total params
49.732    Total estimated model params size (MB)
/home/martinez/Documents/anaconda3/envs/dafx24/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (25) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=10` reached.

Test#

After training for a few epochs we will test the system by creating a mix from one of the songs that was in the training set.

import glob
import torchaudio
start_sample = 262144 * 2
end_sample = 262144 * 3

# load the input tracks
track_dir = "DSD100subset/Sources/Dev/081 - Patrick Talbot - Set Me Free/"
track_ext = "wav"

track_filepaths = glob.glob(os.path.join(track_dir, f"*.{track_ext}"))
track_filepaths = sorted(track_filepaths)
track_names = []
tracks = []
for idx, track_filepath in enumerate(track_filepaths):
    x, sr = torchaudio.load(track_filepath)
    x = x[:, start_sample:end_sample]

    # treat each channel of the stereo file as a separate mono track
    for n in range(x.shape[0]):
        x_sub = x[n : n + 1, :]

        # apply a random gain of up to +/-12 dB to each track
        gain_dB = np.random.rand() * 12
        gain_dB *= np.random.choice([1.0, -1.0])
        gain_ln = 10 ** (gain_dB / 20.0)
        x_sub *= gain_ln

        tracks.append(x_sub)
        track_names.append(os.path.basename(track_filepath))
        IPython.display.display(ipd.Audio(x[n, :].view(1, -1).numpy(), rate=sr, normalize=True))
        print(idx + 1, os.path.basename(track_filepath))

# pad with silent dummy tracks if the model expects a fixed number of inputs
if system.hparams.automix_model == "mixwaveunet":
    while len(tracks) < 8:
        tracks.append(torch.zeros_like(tracks[0]))

# stack tracks into a tensor
tracks = torch.stack(tracks, dim=0)
tracks = tracks.permute(1, 0, 2)
# tracks have shape (1, num_tracks, seq_len)
print(tracks.shape)

# listen to the input (mono) before mixing
input_mix = tracks.sum(dim=1, keepdim=True)
input_mix /= input_mix.abs().max()
print(input_mix.shape)
plt.figure(figsize=(10, 2))
librosa.display.waveshow(input_mix.view(-1).numpy(), sr=sr, zorder=3)
plt.ylim([-1, 1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(input_mix.view(1, -1).numpy(), rate=sr, normalize=False))
1 bass.wav
1 bass.wav
2 drums.wav
2 drums.wav
3 other.wav
3 other.wav
4 vocals.wav
4 vocals.wav
torch.Size([1, 8, 262144])
torch.Size([1, 1, 262144])
[Figure: waveform of the summed mono input mix]

Above we can hear the tracks with a simple mono mix. Now we will create a mix with the model we just trained.

# ensure the tracks have shape (batch, num_tracks, seq_len)
tracks = tracks.view(1, 8, -1)

# pass the tracks through the trained model to predict a stereo mix
with torch.no_grad():
    y_hat, p = system(tracks)

# view the mix
print(y_hat.shape)
y_hat /= y_hat.abs().max()
plt.figure(figsize=(10, 2))
librosa.display.waveshow(y_hat.view(2, -1).cpu().numpy(), sr=sr, zorder=3)
plt.ylim([-1, 1])
plt.grid(c="lightgray")
plt.show()
IPython.display.display(ipd.Audio(y_hat.view(2, -1).cpu().numpy(), rate=sr, normalize=True))

# print the predicted parameters for each track
if system.hparams.automix_model == "dmc":
    for track_fp, param in zip(track_names, p.squeeze()):
        print(os.path.basename(track_fp), param)
torch.Size([1, 2, 262144])
[Figure: waveform of the predicted stereo mix]
bass.wav tensor([-8.9969,  0.5640])
bass.wav tensor([-4.5226,  0.6432])
drums.wav tensor([-6.7178,  0.5868])
drums.wav tensor([-5.4589,  0.6139])
other.wav tensor([-0.3160,  0.6668])
other.wav tensor([6.6966, 0.7568])
vocals.wav tensor([6.4484, 0.7839])
vocals.wav tensor([-2.8909,  0.6442])

You should be able to hear that the levels have been adjusted and the sources panned to sound more like the original mix, indicating that our system learned to overfit the songs in our very small training set. For the DMC model, each row printed above appears to correspond to one track's predicted gain (in dB) and pan parameters.
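Finally, you could save the predicted mix to disk to audition it outside the notebook; a minimal sketch using torchaudio (the filename is arbitrary):

# write the normalized stereo mix to a WAV file
torchaudio.save("predicted_mix.wav", y_hat.view(2, -1).cpu(), sr)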