Models#

In this notebook we will dig into how the two automatic mixing models we discussed can be implemented in PyTorch. As usual, we will assume you have already installed the automix package from automix-toolkit. If not, you can install it with the following command:

!pip install git+https://github.com/csteinmetz1/automix-toolkit
import os
import torch
import numpy as np
from automix.utils import count_parameters

MixWaveUNet#

First, we will take a look at the Mix-Wave-U-Net. Recall that this model is based on Wave-U-Net, a time-domain audio source separation model that is itself based on the well-known U-Net architecture.

The overall architecture of the network is comprised of two types of blocks: the downsampling blocks (shown on the left) and the upsampling blocks (shown on the right). The network applies a number of these blocks, downsampling and then upsampling the signal at different temporal resolutions. Characteristic of U-Net-like architectures are the skip connections that carry information from each level in the downsampling branch to the corresponding level in the upsampling branch.

We can start by importing the MixWaveUNet class.

from automix.models.mixwaveunet import MixWaveUNet

Then we can construct this model, supplying the desired hyperparameters. Below we will create a version of the model that accepts 8 input channels and produces a stereo (2) mix. We will use the default downsampling and upsampling kernel size of 13 and a kernel size of 5 for the final output convolution. As in the original MixWaveUNet, we use 12 downsampling and upsampling blocks and increase the number of convolutional channels by 24 at each block. We also have the option to use either additive (“add”) or concatenative (“concat”) skip connections. In this case, we will follow the original model and use concatenation.

model = MixWaveUNet( 
    ninputs = 8,    # the number of input recordings we can mix (this is fixed at training)
    noutputs = 2,   # the number of channels for the mix. Normally this is 2 for stereo mix
    ds_kernel = 13, # kernel size for the convolutional layers in the Downsampling Blocks
    us_kernel = 13, # kernel size for the convolutional layers in the Upsampling Blocks
    out_kernel = 5, # kernel size for the convolutional layer in the final layer
    layers = 12,    # Number of blocks in the upsampling and downsampling paths
    ch_growth = 24, # Number of convolutional channels to add at each layer
    skip = "concat" # We can use either "add" or "concat" skip connections. ("add" will save parameters and memory)
)

Then we can count the number of parameters in this model.

print(f"{count_parameters(model)/1e6:0.3f} M")
17.833 M

This model is currently untrained, but we can still demonstrate how it generates a mix. The model expects as input a tensor of shape (batch_size, num_tracks, seq_len). In this case num_tracks is equal to the number of input recordings that we want to mix together and seq_len corresponds to the number of samples in each recording. Since we will stack these into a single tensor, each recording must be of the same length. Let’s consider the following example of how we could generate a mix from this untrained model.

batch_size = 2
num_tracks = 8
seq_length = 262144

x = torch.randn(batch_size, num_tracks, seq_length)
y_hat, p = model(x)

print(x.shape, y_hat.shape)
torch.Size([2, 8, 262144]) torch.Size([2, 2, 262144])

You can see that after passing in 8 tracks we get two stereo mixes, one for each of the batch items. Importantly, note that calling model(x) will return two values. The first, y_hat, is the mixture. The second value, p, represents the parameters that created the mix. However, the parameters will only be populated for models that use explicit parameters, like the DMC. So in this case we can see that p is a zero tensor.

print(p)
tensor([0.])

Now that we understand the basic operation of the Mix-Wave-U-Net at a high level, let’s investigate the inner workings of the model. To do so, we will define and connect the inner components of the original MixWaveUNet class.

Downsampling Block#

Let’s start with the DownsamplingBlock. Here we have reproduced the implementation from the automix package. It is composed of a few basic submodules: a Conv1d, a BatchNorm1d, a PReLU activation, and a final Conv1d with stride=2. While the original Wave-U-Net used decimation to downsample, we can use a strided convolution to achieve a similar downsampling operation.

class DownsamplingBlock(torch.nn.Module):
    def __init__(
        self,
        ch_in: int,
        ch_out: int,
        kernel_size: int = 15,
    ):
        super().__init__()

        assert kernel_size % 2 != 0  # kernel must be odd length
        padding = kernel_size // 2  # calculate same padding

        self.conv1 = torch.nn.Conv1d(
            ch_in,
            ch_out,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.bn = torch.nn.BatchNorm1d(ch_out)
        self.prelu = torch.nn.PReLU(ch_out)
        self.conv2 = torch.nn.Conv1d(
            ch_out,
            ch_out,
            kernel_size=kernel_size,
            stride=2,
            padding=padding,
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        x = self.prelu(x)
        x_ds = self.conv2(x)
        return x_ds, x

Note that the forward() will return two tensors:

  • x_ds - the downsampled (by factor of 2) signal.

  • x - the processed signal before any downsampling.

Saving the intermediate tensors before downsampling allows them to be used later in the upsampling path when we employ the skip connections.
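To make the two return values concrete, here is a quick shape check of the block above (the batch size, channel counts, and signal length are arbitrary values chosen for illustration):

block = DownsamplingBlock(ch_in=8, ch_out=24, kernel_size=13)
x = torch.randn(2, 8, 1024)      # (batch, channels, samples)
x_ds, x_skip = block(x)
print(x_ds.shape, x_skip.shape)  # torch.Size([2, 24, 512]) torch.Size([2, 24, 1024])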

Upsampling Block#

As we see below, the UpsamplingBlock follows a very similar pattern to the DownsamplingBlock, with the exception that it upsamples the signal and must also aggregate information from the skip connections. As before, we have a series connection of convolution, batch normalization, and activation, along with an upsampling layer that doubles the temporal resolution (nearest-neighbor interpolation by default).

class UpsamplingBlock(torch.nn.Module):
    def __init__(
        self,
        ch_in: int,
        ch_out: int,
        kernel_size: int = 5,
        skip: str = "add",
    ):
        super().__init__()

        assert kernel_size % 2 != 0  # kernel must be odd length
        padding = kernel_size // 2  # calculate same padding

        self.skip = skip
        self.conv = torch.nn.Conv1d(
            ch_in,
            ch_out,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.bn = torch.nn.BatchNorm1d(ch_out)
        self.prelu = torch.nn.PReLU(ch_out)
        self.us = torch.nn.Upsample(scale_factor=2)

    def forward(self, x: torch.Tensor, skip: torch.Tensor):
        x = self.us(x)  # upsample by x2

        # handle skip connections
        if self.skip == "add":
            x = x + skip
        elif self.skip == "concat":
            x = torch.cat((x, skip), dim=1)
        elif self.skip == "none":
            pass
        else:
            raise NotImplementedError()

        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x

A unique part of the forward() in the UpsamplingBlock is that it takes two tensors as input:

  • x - the output of the previous upsampling layer.

  • skip - the output of the respective downsampling layer (same resolution) which creates a skip connection.

We can see that we first upsample the input tensor x and then combine it with the skip connection. Our implementation includes a few different options. In the "add" case we simply sum the two tensors, a pointwise sum between the signals in each channel. This is less expressive but saves memory and lowers the parameter count. In the "concat" case we concatenate the two tensors along the channel dimension, which results in a new tensor with twice the number of channels. This provides more flexibility, since the convolutional layer that follows can decide how to mix these signals together. Finally, there is also the option to forgo the skip connections entirely.
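As a concrete illustration of the "concat" case, here is a small shape check (the channel counts and lengths are arbitrary; note that ch_in must equal the number of upsampled channels plus the number of skip channels):

block = UpsamplingBlock(ch_in=48, ch_out=24, kernel_size=13, skip="concat")
x = torch.randn(2, 24, 512)      # output from a previous (deeper) decoder layer
skip = torch.randn(2, 24, 1024)  # saved encoder activation at the matching resolution
y = block(x, skip)               # upsample -> concat (48 channels) -> conv -> 24 channels
print(y.shape)                   # torch.Size([2, 24, 1024])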

Encoder#

Now let’s use these building blocks to construct the Mix-Wave-U-Net. We will start with the Encoder, which is composed of a series connection of DownsamplingBlocks. We will use a for loop to construct each layer and then store them in a ModuleList. At the first layer, we will ensure the convolution accepts the same number of channels as there are input recordings (ninputs). For the other blocks, we will increase the number of channels by ch_growth at each iteration (ch_growth = 24).

ninputs = 8     # the number of input recordings we can mix (this is fixed at training)
noutputs = 2    # the number of channels for the mix. Normally this is 2 for stereo mix
ds_kernel = 13  # kernel size for the convolutional layers in the Downsampling Blocks
us_kernel = 13  # kernel size for the convolutional layers in the Upsampling Blocks
out_kernel = 5  # kernel size for the convolutional layer in the final layer
layers = 12     # Number of blocks in the upsampling and downsampling paths
ch_growth = 24  # Number of convolutional channels to add at each layer
skip = "concat" 
encoder = torch.nn.ModuleList()

for n in np.arange(layers):
    if n == 0:
        ch_in = ninputs
        ch_out = ch_growth
    else:
        ch_in = ch_out
        ch_out = ch_in + ch_growth

    encoder.append(DownsamplingBlock(ch_in, ch_out, kernel_size=ds_kernel))

And now we can see the layers we created.

print(encoder)
ModuleList(
  (0): DownsamplingBlock(
    (conv1): Conv1d(8, 24, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=24)
    (conv2): Conv1d(24, 24, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (1): DownsamplingBlock(
    (conv1): Conv1d(24, 48, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=48)
    (conv2): Conv1d(48, 48, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (2): DownsamplingBlock(
    (conv1): Conv1d(48, 72, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=72)
    (conv2): Conv1d(72, 72, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (3): DownsamplingBlock(
    (conv1): Conv1d(72, 96, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=96)
    (conv2): Conv1d(96, 96, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (4): DownsamplingBlock(
    (conv1): Conv1d(96, 120, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=120)
    (conv2): Conv1d(120, 120, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (5): DownsamplingBlock(
    (conv1): Conv1d(120, 144, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=144)
    (conv2): Conv1d(144, 144, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (6): DownsamplingBlock(
    (conv1): Conv1d(144, 168, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(168, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=168)
    (conv2): Conv1d(168, 168, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (7): DownsamplingBlock(
    (conv1): Conv1d(168, 192, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=192)
    (conv2): Conv1d(192, 192, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (8): DownsamplingBlock(
    (conv1): Conv1d(192, 216, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(216, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=216)
    (conv2): Conv1d(216, 216, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (9): DownsamplingBlock(
    (conv1): Conv1d(216, 240, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=240)
    (conv2): Conv1d(240, 240, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (10): DownsamplingBlock(
    (conv1): Conv1d(240, 264, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(264, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=264)
    (conv2): Conv1d(264, 264, kernel_size=(13,), stride=(2,), padding=(6,))
  )
  (11): DownsamplingBlock(
    (conv1): Conv1d(264, 288, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=288)
    (conv2): Conv1d(288, 288, kernel_size=(13,), stride=(2,), padding=(6,))
  )
)

Embedding/Latent#

In the middle of the network we will include a single convolutional layer to produce the latent embedding.

embedding = torch.nn.Conv1d(ch_out, ch_out, kernel_size=1)
print(embedding)
Conv1d(288, 288, kernel_size=(1,), stride=(1,))

Decoder#

In a similar manner to the encoder, we will construct the decoder by storing UpsamplingBlocks in a ModuleList. However, in this case we will count backwards (step=-1) as we create the layers, starting with the number of channels in the final layer of the encoder and decrementing this value by ch_growth at each iteration. Note also that when we use "concat" skip connections we double the number of input channels in each block to accommodate the additional channels from the skip connections.

decoder = torch.nn.ModuleList()
for n in np.arange(layers, stop=0, step=-1):

    ch_in = ch_out
    ch_out = ch_in - ch_growth

    if ch_out < ch_growth:
        ch_out = ch_growth

    if skip == "concat":
        ch_in *= 2

    decoder.append(
        UpsamplingBlock(
            ch_in,
            ch_out,
            kernel_size=us_kernel,
            skip=skip,
        )
    )
print(decoder)
ModuleList(
  (0): UpsamplingBlock(
    (conv): Conv1d(576, 264, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(264, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=264)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (1): UpsamplingBlock(
    (conv): Conv1d(528, 240, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=240)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (2): UpsamplingBlock(
    (conv): Conv1d(480, 216, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(216, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=216)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (3): UpsamplingBlock(
    (conv): Conv1d(432, 192, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=192)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (4): UpsamplingBlock(
    (conv): Conv1d(384, 168, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(168, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=168)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (5): UpsamplingBlock(
    (conv): Conv1d(336, 144, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=144)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (6): UpsamplingBlock(
    (conv): Conv1d(288, 120, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=120)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (7): UpsamplingBlock(
    (conv): Conv1d(240, 96, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=96)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (8): UpsamplingBlock(
    (conv): Conv1d(192, 72, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=72)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (9): UpsamplingBlock(
    (conv): Conv1d(144, 48, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=48)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (10): UpsamplingBlock(
    (conv): Conv1d(96, 24, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=24)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
  (11): UpsamplingBlock(
    (conv): Conv1d(48, 24, kernel_size=(13,), stride=(1,), padding=(6,))
    (bn): BatchNorm1d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (prelu): PReLU(num_parameters=24)
    (us): Upsample(scale_factor=2.0, mode='nearest')
  )
)

Output#

Finally, we have the output convolution, which collects the output channels from the final layer of the decoder (concatenated with the original input recordings, hence ch_out + ninputs input channels) and maps them to the stereo mixture (noutputs=2).

output_conv = torch.nn.Conv1d(
    ch_out + ninputs,
    noutputs,
    kernel_size=out_kernel,
    padding=out_kernel // 2,
)
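Printing this layer confirms that it maps the 24 channels from the final decoder block plus the 8 input recordings (32 channels in total) down to the 2 output channels:

print(output_conv)
Conv1d(32, 2, kernel_size=(5,), stride=(1,), padding=(2,))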

Forward#

The forward pass of the model involves simply iterating over the blocks in the encoder and storing the outputs as well as the skip connections. We will just use a list to store these tensors. Then we pass the final output from the encoder to the embedding() layer and use another for loop to iterate over the blocks in the decoder, this time passing the respective skip connection to each block. We use skips.pop() to retrieve the last skip connection from the encoder (LIFO). Finally, we implement the last skip connection, which concatenates the original input recordings x_in before the output convolution. Again, recall that we return torch.zeros(1) as a dummy tensor since this model does not give us interpretable parameters.

def forward(x: torch.Tensor): 
    x_in = x
    skips = [] # storage 

    for enc in encoder:
        x, skip = enc(x)
        skips.append(skip)

    x = embedding(x)

    for dec in decoder:
        skip = skips.pop()
        x = dec(x, skip)

    x = torch.cat((x_in, x), dim=1)
    y = output_conv(x)

    return y, torch.zeros(1)  # return dummy parameters

We can then test the model just as we did before and see that the output shapes are the same.

batch_size = 2
num_tracks = 8
seq_length = 262144

x = torch.randn(batch_size, num_tracks, seq_length)
y_hat, p = forward(x)

print(x.shape, y_hat.shape)
torch.Size([2, 8, 262144]) torch.Size([2, 2, 262144])

Differentiable Mixing Console (DMC)#

Now that we have seen how the Mix-Wave-U-Net, a direct transformation approach, can be implemented, we will shift our focus to the Differentiable Mixing Console, which is a parameter estimation approach.

Similar to before, we will first inspect the model from a high level and then go through its basic components so we can get an understanding of their operation.

from automix.models.dmc import DifferentiableMixingConsole, PostProcessor, Mixer, ShortChunkCNN_Res
from automix.utils import restore_from_0to1

First we will create the main modules of the system.

  • ShortChunkCNN_Res - This is our encoder. It operates on mel spectrograms and has been pretrained.

  • PostProcessor - This is an MLP that will project our embeddings to the parameters of the mixing console.

  • Mixer - The differentiable mixer class. In this case our mixer supports gain and stereo panning operations.

We will also need to download the pretrained model checkpoint for the encoder.

os.makedirs("checkpoints", exist_ok=True)
# download the pretrained models for the encoder
!wget https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/encoder.ckpt
!mv encoder.ckpt checkpoints/encoder.ckpt
encoder_ckpt_path = "checkpoints/encoder.ckpt"
--2024-08-29 16:41:32--  https://huggingface.co/csteinmetz1/automix-toolkit/resolve/main/encoder.ckpt
Resolving huggingface.co (huggingface.co)... 2600:9000:2751:e200:17:b174:6d00:93a1, 2600:9000:2751:9e00:17:b174:6d00:93a1, 2600:9000:2751:cc00:17:b174:6d00:93a1, ...
Connecting to huggingface.co (huggingface.co)|2600:9000:2751:e200:17:b174:6d00:93a1|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/ec/ee/ecee38df047e3f2db1bd8c31a742f3a08f557470cd67cb487402a9c3ed91b5ea/90c13ab981715e1fc1ae079f15fb6da36d61d6aad29ae5dddd4d3bfd4594546a?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27encoder.ckpt%3B+filename%3D%22encoder.ckpt%22%3B&Expires=1725176471&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyNTE3NjQ3MX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9lYy9lZS9lY2VlMzhkZjA0N2UzZjJkYjFiZDhjMzFhNzQyZjNhMDhmNTU3NDcwY2Q2N2NiNDg3NDAyYTljM2VkOTFiNWVhLzkwYzEzYWI5ODE3MTVlMWZjMWFlMDc5ZjE1ZmI2ZGEzNmQ2MWQ2YWFkMjlhZTVkZGRkNGQzYmZkNDU5NDU0NmE%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=DwbNKNJC-jgNqziEOYqUuU4t0i0T-3ZRH3Zx0fmO-TRwLD8HFSKZSg0PFpu%7ExmbakSF0d%7EfTRIsNq7Q5qi0saUswTUWQSEELZB%7EVOTBBe5IY10rYFOUjSaaOPsMj%7EFes52vicMWcRD1dWaGgoqxY98RyT1DtcIoGFp-FC9QAYYrHK-vHV8P7ZD9jDKptlAaIWv9MG9TLtJSSVIFz3iEkWrdCy5ylklwijgSrENluSiefNmStg4z9T8j5Or42F4-KsLE2zmn8jyvRoMOa8ClxlzlhJ5PCx2lxL5VxC%7En7vZDAQyIuM7BSj4rPueylWxaDOXtHXgf84oHsNdPTRzTyrA__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-08-29 16:41:32--  https://cdn-lfs.huggingface.co/repos/ec/ee/ecee38df047e3f2db1bd8c31a742f3a08f557470cd67cb487402a9c3ed91b5ea/90c13ab981715e1fc1ae079f15fb6da36d61d6aad29ae5dddd4d3bfd4594546a?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27encoder.ckpt%3B+filename%3D%22encoder.ckpt%22%3B&Expires=1725176471&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyNTE3NjQ3MX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9lYy9lZS9lY2VlMzhkZjA0N2UzZjJkYjFiZDhjMzFhNzQyZjNhMDhmNTU3NDcwY2Q2N2NiNDg3NDAyYTljM2VkOTFiNWVhLzkwYzEzYWI5ODE3MTVlMWZjMWFlMDc5ZjE1ZmI2ZGEzNmQ2MWQ2YWFkMjlhZTVkZGRkNGQzYmZkNDU5NDU0NmE%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=DwbNKNJC-jgNqziEOYqUuU4t0i0T-3ZRH3Zx0fmO-TRwLD8HFSKZSg0PFpu%7ExmbakSF0d%7EfTRIsNq7Q5qi0saUswTUWQSEELZB%7EVOTBBe5IY10rYFOUjSaaOPsMj%7EFes52vicMWcRD1dWaGgoqxY98RyT1DtcIoGFp-FC9QAYYrHK-vHV8P7ZD9jDKptlAaIWv9MG9TLtJSSVIFz3iEkWrdCy5ylklwijgSrENluSiefNmStg4z9T8j5Or42F4-KsLE2zmn8jyvRoMOa8ClxlzlhJ5PCx2lxL5VxC%7En7vZDAQyIuM7BSj4rPueylWxaDOXtHXgf84oHsNdPTRzTyrA__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 2600:9000:20c4:4800:11:f807:5180:93a1, 2600:9000:20c4:b000:11:f807:5180:93a1, 2600:9000:20c4:8400:11:f807:5180:93a1, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|2600:9000:20c4:4800:11:f807:5180:93a1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 48624134 (46M) [binary/octet-stream]
Saving to: ‘encoder.ckpt’

encoder.ckpt        100%[===================>]  46.37M  15.3MB/s    in 3.0s    

2024-08-29 16:41:37 (15.3 MB/s) - ‘encoder.ckpt’ saved [48624134/48624134]

Encoder#

The role of the encoder is to extract information from each input recording that will be used to create a mix. This implicitly involves determining the identity of the source (e.g. drums, guitar, vocals) as well as other factors such as its level. We adopt a standard 2D convolutional network that operates on log mel spectrograms. In the original paper the authors used the VGGish architecture pretrained on AudioSet.

To facilitate faster training and simpler code, we opt for the Short Chunk CNN (with residual connections), which is very similar but enables easy computation of mel spectrograms with torchaudio. In addition, we use a checkpoint obtained by pretraining the model on a music tagging task, which should aid in learning. We will not go into detail about how the encoder itself is implemented, but you can see the details here.
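For reference, the torchaudio front end inside the encoder looks roughly like the following sketch (the exact n_fft, hop length, and number of mel bands used by ShortChunkCNN_Res may differ; the values here are only illustrative):

import torchaudio

# a log mel spectrogram front end, similar in spirit to the one inside the encoder
spec = torchaudio.transforms.MelSpectrogram(sample_rate=44100, n_fft=512, n_mels=128)
to_db = torchaudio.transforms.AmplitudeToDB()
mel = to_db(spec(torch.randn(1, 44100)))  # 1 second of noise -> (1, n_mels, num_frames)
print(mel.shape)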

sample_rate = 44100
encoder = ShortChunkCNN_Res(sample_rate, ckpt_path=encoder_ckpt_path)
print(encoder)
Loaded weights from checkpoints/encoder.ckpt
ShortChunkCNN_Res(
  (spec): MelSpectrogram(
    (spectrogram): Spectrogram()
    (mel_scale): MelScale()
  )
  (to_db): AmplitudeToDB()
  (spec_bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Res_2d(
    (conv_1): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_3): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (layer2): Res_2d(
    (conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_3): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (layer3): Res_2d(
    (conv_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (layer4): Res_2d(
    (conv_1): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (layer5): Res_2d(
    (conv_1): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (layer6): Res_2d(
    (conv_1): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (layer7): Res_2d(
    (conv_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn_2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_3): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn_3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (dense1): Linear(in_features=512, out_features=512, bias=True)
  (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dense2): Linear(in_features=512, out_features=50, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (relu): ReLU()
  (resample): Resample()
)
/home/martinez/Documents/anaconda3/envs/dafx24/lib/python3.9/site-packages/torchaudio/functional/functional.py:584: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (257) may be set too low.
  warnings.warn(

Post-Processor#

The role of the Post-Processor is to take the track embedding and context embedding (for each track and context pair) and compute a set of control parameters for the current track. We can implement this as a simple multi-layer perceptron (MLP) with three layers. This network uses a sigmoid activation function to map all outputs between 0 and 1, which is the format our Mixer expects. Inside the Mixer these parameters will be denormalized to the correct ranges. In the original paper the authors use a tanh activation so that parameters are scaled between -1 and 1, but this is just a design choice.

class PostProcessor(torch.nn.Module):
    def __init__(self, num_params: int, d_embed: int) -> None:
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(d_embed, 256),
            torch.nn.Dropout(0.2),
            torch.nn.PReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.Dropout(0.2),
            torch.nn.PReLU(),
            torch.nn.Linear(256, num_params),
            torch.nn.Sigmoid(),
        )

    def forward(self, z: torch.Tensor):
        return self.mlp(z)
    
postprocessor = PostProcessor(2, 2 * encoder.d_embed)
print(postprocessor)
PostProcessor(
  (mlp): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): PReLU(num_parameters=1)
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): Dropout(p=0.2, inplace=False)
    (5): PReLU(num_parameters=1)
    (6): Linear(in_features=256, out_features=2, bias=True)
    (7): Sigmoid()
  )
)

Mixer#

The role of the mixer is to process the individual channels in a mix given the control parameters and produce a stereo mix. In the original paper the mixer included the Transformation Network. This network was first pretrained to emulate common audio effects like an equalizer, compressor, and reverberation model. In our setup we will consider the simple case for the Transformation Network, which uses only gain (level) and panning parameters. Since these operations are differentiable we do not need to worry about the proxy method or any other differentiable signal processing techniques.

As you can see below, we implement the differentiable mixer by simply applying the gain and panning operations, which enables the use of autodiff for the gradient computation during training. This is both fast and memory efficient. An extension of our implementation could add in more effects like equalization or compression.

class Mixer(torch.nn.Module):
    def __init__(
        self,
        sample_rate: float,
        min_gain_dB: float = -48.0,
        max_gain_dB: float = 24.0,
    ) -> None:
        super().__init__()
        self.num_params = 2
        self.param_names = ["Gain dB", "Pan"]
        self.sample_rate = sample_rate
        self.min_gain_dB = min_gain_dB
        self.max_gain_dB = max_gain_dB

    def forward(self, x: torch.Tensor, p: torch.Tensor):
        """Generate a mix of stems given mixing parameters normalized to (0,1).

        Args:
            x (torch.Tensor): Batch of waveform stem tensors with shape (bs, num_tracks, seq_len).
            p (torch.Tensor): Batch of normalized mixing parameters (0,1) for each stem with shape (bs, num_tracks, num_params)

        Returns:
            y (torch.Tensor): Batch of stereo waveform mixes with shape (bs, 2, seq_len)
        """
        bs, num_tracks, seq_len = x.size()

        # ------------- apply gain -------------
        gain_dB = p[..., 0]  # get gain parameter
        gain_dB = restore_from_0to1(gain_dB, self.min_gain_dB, self.max_gain_dB)
        gain_lin = 10 ** (gain_dB / 20.0)  # convert gain from dB scale to linear
        gain_lin = gain_lin.view(bs, num_tracks, 1)  # reshape for multiplication
        x = x * gain_lin  # apply gain (bs, num_tracks, seq_len)

        # ------------- apply panning -------------
        # expand mono stems to stereo, then apply panning
        x = x.view(bs, num_tracks, 1, -1)  # (bs, num_tracks, 1, seq_len)
        x = x.repeat(1, 1, 2, 1)  # (bs, num_tracks, 2, seq_len)

        pan = p[..., 1]  # get pan parameter
        pan_theta = pan * torch.pi / 2
        left_gain = torch.cos(pan_theta)
        right_gain = torch.sin(pan_theta)
        pan_gains_lin = torch.stack([left_gain, right_gain], dim=-1)
        pan_gains_lin = pan_gains_lin.view(bs, num_tracks, 2, 1)  # reshape for multiply
        x = x * pan_gains_lin  # (bs, num_tracks, 2, seq_len)

        # ----------------- apply mix -------------
        # generate a mix for each batch item by summing stereo tracks
        y = torch.sum(x, dim=1)  # (bs, 2, seq_len)

        p = torch.cat(
            (
                gain_dB.view(bs, num_tracks, 1),
                pan.view(bs, num_tracks, 1),
            ),
            dim=-1,
        )

        return y, p

mixer = Mixer(sample_rate)
print(mixer)
Mixer()
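The panning stage above implements a constant-power (sin/cos) pan law: a normalized pan of 0.5 places a source in the center with equal left/right gains of about 0.707, and the total power stays constant as the source moves across the stereo field. A quick check:

pan = torch.tensor([0.0, 0.5, 1.0])  # hard left, center, hard right
theta = pan * torch.pi / 2
left, right = torch.cos(theta), torch.sin(theta)
print(left, right)             # left ≈ [1.0, 0.707, 0.0], right ≈ [0.0, 0.707, 1.0]
print(left ** 2 + right ** 2)  # ≈ [1.0, 1.0, 1.0] (constant power)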

Here we will set up some inputs that we can use for our example. You can adjust these values to see how the results change. In this case we will use a batch size of 2 and mixes with 4 input recordings, each approximately 3 seconds in length (131072 samples at a sample rate of 44100 Hz). Then we will generate a tensor of noise to represent these tracks.

batch_size = 2
num_tracks = 4
num_samples = 131072

x = torch.randn(batch_size, num_tracks, num_samples)
bs, num_tracks, seq_len = x.size()

Generating embeddings#

As a first step, we need to generate embeddings with our encoder for each of the input recordings in each batch item. Since each batch item contains multiple tracks (in this case 4), one option would be to loop over the tracks and pass them to the encoder one by one. However, this would create an unnecessary bottleneck. Instead, we use a small trick to compute all of the embeddings in the batch at once.

We do this simply by moving all of the input recordings into the batch dimension. This gives us an effective batch size of eff_bs = bs * num_tracks. After moving all the recordings to the batch dimension, we can pass them to the encoder, which expects a tensor of shape (bs, seq_len). After generating the embeddings z, it is simply a matter of reshaping the tensor to restore the track dimension, which gives us a tensor of shape (bs, num_tracks, d_embed).

# move tracks to the batch dimension to fully parallelize embedding computation
x = x.view(bs * num_tracks, -1)
print(f"We get {bs}x{num_tracks} items in first dim: {x.shape}")

# generate single embedding for each track
z = encoder(x)
z = z.view(bs, num_tracks, -1)  # (bs, num_tracks, d_embed)
print(f"We get {num_tracks} embeddings of size {encoder.d_embed}: {z.shape}")
We get 2x4 items in first dim: torch.Size([8, 131072])
We get 4 embeddings of size 512: torch.Size([2, 4, 512])

“Context” embedding#

Key to the DMC is the concept of the “context” embedding, which enables effective cross-channel communication between the recordings within a mixture when the post-processor makes decisions about the parameters for each channel. We compute the context embedding by simply taking the mean of all the track embeddings for each batch item. We can see this in the figure above represented as \(z_\mu\). After taking this mean we then copy (using torch.repeat) the mean embedding once for each track. This way we can concatenate these copied embeddings with each of the track embeddings.

However, recall that during training we use a fixed number of tracks, and therefore some songs may have fewer than num_tracks active tracks, with the remaining tracks being silence. These empty tracks would corrupt our context embedding. One way to handle this is to use the track_mask, which is also provided by the dataset. This is a tensor of boolean values telling us which tracks are not active and should therefore be masked. For example, consider the case where we have four total tracks but only the first three are active in the first batch item, while all are active in the second. We would use the following track_mask.

track_mask = torch.tensor(
    [[False, False, False, True], 
     [False, False, False, False]]).view(2,-1)
print(track_mask)
tensor([[False, False, False,  True],
        [False, False, False, False]])
# generate the "context" embedding
c = []
for bidx in range(bs): # loop over each batch for "dynamic" context computation
    c_n = z[bidx, ~track_mask[bidx, :], :].mean(
        dim=0, keepdim=True
    )  # (bs, 1, d_embed)
    c_n = c_n.repeat(num_tracks, 1)  # (bs, num_tracks, d_embed)
    c.append(c_n)
c = torch.stack(c, dim=0)
print(c.shape)
torch.Size([2, 4, 512])

Note: Another way to implement this would be to fill the embeddings for inactive tracks with zeros and then take the sum across each batch item. To get the mean we could then divide each sum by the number of False values in each track_mask. This would enable us to avoid the for loop, as sketched below.
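Here is a minimal sketch of that vectorized alternative, reusing the z, c, and track_mask tensors from the cells above; it should produce the same result as the loop-based version.

active = (~track_mask).unsqueeze(-1).float()  # (bs, num_tracks, 1), 1.0 for active tracks
c_vec = (z * active).sum(dim=1, keepdim=True) / active.sum(dim=1, keepdim=True)
c_vec = c_vec.repeat(1, num_tracks, 1)        # (bs, num_tracks, d_embed)
print(torch.allclose(c_vec, c, atol=1e-6))    # expected: True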

At the end of this process we will have num_tracks embeddings, each of size d_embed*2 after the concatenation.

# fuse the track embs and context embs
z_final = torch.cat((z, c), dim=-1)  # (bs, num_tracks, d_embed*2)
print("final embedding", z_final.shape)
final embedding torch.Size([2, 4, 1024])

Estimate mixing parameters#

Now that we have the embeddings for each track we will use the Post-Processor to estimate the mixing parameters (gain and panning) for each track. This requires running each of the num_tracks embeddings through the Post-Processor. However, since the Linear layers in PyTorch operate on the last dimension, we can compute these in parallel automatically. So by passing our final embedding tensor of shape (bs, num_tracks, d_embed*2) into the Post-Processor we can generate all the mixing parameters at once. As we see below, we get a parameter tensor containing 2 parameters (gain and pan) for each of the num_tracks. Almost there; now the final step is to use these parameters and the Mixer to create the mix.

# estimate mixing parameters for each track (in parallel)
p = postprocessor(z_final)  # (bs, num_tracks, num_params)
print(p.shape)
torch.Size([2, 4, 2])

Generate the mix#

We already discussed how the Mixer is implemented. Here we will call the mixer, passing in the tracks as well as the parameters we just predicted. Inside the Mixer these parameters will be denormalized from the (0,1) range to their full range.

We can see that we get two return values from calling the Mixer. The first is the stereo mix, which is the same length as the inputs but has only two channels. We also get a new tensor for the parameters, which has the same shape as the input parameters but contains the denormalized values. This will enable us to inspect the parameters estimated by the model in a human-interpretable form.
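For reference, restore_from_0to1 (imported earlier from automix.utils) rescales a normalized value back to a target range. A minimal sketch of the behavior we rely on, assuming a standard linear rescaling:

def restore_from_0to1_sketch(x: torch.Tensor, min_val: float, max_val: float) -> torch.Tensor:
    # map a normalized value in (0, 1) back to [min_val, max_val]
    return x * (max_val - min_val) + min_val

# e.g. a normalized gain of 0.5 with min_gain_dB=-48 and max_gain_dB=24 corresponds to -12 dB
print(restore_from_0to1_sketch(torch.tensor(0.5), -48.0, 24.0))  # tensor(-12.)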

# generate the stereo mix
x = x.view(bs, num_tracks, -1)  # move tracks back from batch dim
y, p = mixer(x, p)  # (bs, 2, seq_len) # and denormalized params
print(y.shape, p.shape)
torch.Size([2, 2, 131072]) torch.Size([2, 4, 2])

We can easily print out the parameters for each track as follows. In this example all of the parameters are very similar since we are using an untrained network and noise as input.

for tidx, track_params in enumerate(p[0,...]):
    print(f"{tidx} gain dB:{track_params[0]:0.3f}  pan:{track_params[1]:0.3f}")
0 gain dB:-11.733  pan:0.504
1 gain dB:-11.911  pan:0.509
2 gain dB:-11.903  pan:0.509
3 gain dB:-11.732  pan:0.505

That concludes the section on the models. Hopefully this provided some insight into the inner workings of these two automatic mixing models. Both implementations are simple and could be built upon to extend their features.