├── model
│   ├── lstm.py
│   ├── conv.py
│   └── seanet.py
├── utils.py
├── LICENSE
├── requirements.txt
├── synthesize.ipynb
├── datasets.py
├── README.md
├── inference.ipynb
├── generate_data.py
└── train.py

/model/lstm.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | 
3 | 
4 | class SLSTM(nn.Module):
5 |     """
6 |     LSTM that manages its hidden state and the data layout internally.
7 |     Expects input in convolutional layout (batch, channels, time).
8 |     """
9 |     def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
10 |         super().__init__()
11 |         self.skip = skip
12 |         self.lstm = nn.LSTM(dimension, dimension, num_layers)
13 | 
14 |     def forward(self, x):
15 |         x = x.permute(2, 0, 1)
16 |         y, _ = self.lstm(x)
17 |         if self.skip:
18 |             y = y + x
19 |         y = y.permute(1, 2, 0)
20 |         return y
21 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | 
5 | def seed_worker(worker_id):
6 |     """
7 |     Seed each torch.utils.data.DataLoader worker for reproducible data loading.
8 |     """
9 |     worker_seed = torch.initial_seed() % 2**32
10 |     np.random.seed(worker_seed)
11 |     random.seed(worker_seed)
12 | 
13 | 
14 | def set_seed(seed):
15 |     """
16 |     Set all random seeds for reproducibility.
17 |     """
18 |     random.seed(seed)
19 |     np.random.seed(seed)
20 |     torch.manual_seed(seed)
21 |     if torch.cuda.is_available():
22 |         torch.cuda.manual_seed_all(seed)
23 |     torch.backends.cudnn.deterministic = True
24 |     torch.backends.cudnn.benchmark = False
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Yongyi Zang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | audioread==3.0.1 3 | auraloss==0.4.0 4 | certifi==2024.2.2 5 | cffi==1.16.0 6 | charset-normalizer==3.3.2 7 | decorator==5.1.1 8 | einops==0.8.0 9 | filelock==3.14.0 10 | fsspec==2024.5.0 11 | grpcio==1.64.0 12 | idna==3.7 13 | Jinja2==3.1.4 14 | joblib==1.4.2 15 | lazy_loader==0.4 16 | librosa==0.10.2.post1 17 | llvmlite==0.42.0 18 | Markdown==3.6 19 | MarkupSafe==2.1.5 20 | mpmath==1.3.0 21 | msgpack==1.0.8 22 | networkx==3.3 23 | numba==0.59.1 24 | numpy==1.26.4 25 | nvidia-cublas-cu12==12.1.3.1 26 | nvidia-cuda-cupti-cu12==12.1.105 27 | nvidia-cuda-nvrtc-cu12==12.1.105 28 | nvidia-cuda-runtime-cu12==12.1.105 29 | nvidia-cudnn-cu12==8.9.2.26 30 | nvidia-cufft-cu12==11.0.2.54 31 | nvidia-curand-cu12==10.3.2.106 32 | nvidia-cusolver-cu12==11.4.5.107 33 | nvidia-cusparse-cu12==12.1.0.106 34 | nvidia-nccl-cu12==2.20.5 35 | nvidia-nvjitlink-cu12==12.5.40 36 | nvidia-nvtx-cu12==12.1.105 37 | packaging==24.0 38 | platformdirs==4.2.2 39 | pooch==1.8.1 40 | protobuf==5.27.0 41 | pycparser==2.22 42 | requests==2.32.2 43 | scikit-learn==1.5.0 44 | scipy==1.13.1 45 | six==1.16.0 46 | soundfile==0.12.1 47 | soxr==0.3.7 48 | sympy==1.12 49 | tensorboard==2.16.2 50 | tensorboard-data-server==0.7.2 51 | threadpoolctl==3.5.0 52 | torch==2.3.0 53 | tqdm==4.66.4 54 | triton==2.3.0 55 | typing_extensions==4.12.0 56 | urllib3==2.2.1 57 | Werkzeug==3.0.3 58 | -------------------------------------------------------------------------------- /synthesize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import librosa\n", 11 | "import soundfile as sf\n", 12 | "\n", 13 | "azimuth_list = [-60, 120] # Left and right speaker azimuths\n", 14 | "azimuth_list = [np.deg2rad(azimuth) for azimuth in azimuth_list]\n", 15 | "\n", 16 | "w_signal = \"W.wav\"\n", 17 | "x_signal = \"X.wav\"\n", 18 | "y_signal = \"Y.wav\"\n", 19 | "\n", 20 | "w, _ = librosa.load(w_signal, sr=44100)\n", 21 | "x, _ = librosa.load(x_signal, sr=44100)\n", 22 | "y, _ = librosa.load(y_signal, sr=44100)\n", 23 | "\n", 24 | "x = x[:len(w)]\n", 25 | "y = y[:len(w)]\n", 26 | "\n", 27 | "left = 0.3 * w + x * np.cos(azimuth_list[0]) + y * np.sin(azimuth_list[0])\n", 28 | "right = 0.3 * w + x * np.cos(azimuth_list[1]) + y * np.sin(azimuth_list[1])\n", 29 | "signal = np.array([left, right]).T\n", 30 | "\n", 31 | "sf.write(\"test.wav\", signal, 44100)" 32 | ] 33 | } 34 | ], 35 | "metadata": { 36 | "kernelspec": { 37 | "display_name": "ambisonizer", 38 | "language": "python", 39 | "name": "python3" 40 | }, 41 | "language_info": { 42 | "codemirror_mode": { 43 | "name": "ipython", 44 | "version": 3 45 | }, 46 | "file_extension": ".py", 47 | "mimetype": "text/x-python", 48 | "name": "python", 49 | "nbconvert_exporter": "python", 50 | "pygments_lexer": "ipython3", 51 | "version": "3.11.4" 52 | } 53 | }, 54 | "nbformat": 4, 55 | "nbformat_minor": 2 56 | } 57 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from torch.utils.data import Dataset, DataLoader 4 | import librosa 5 | 6 | class 
Ambisonizer(Dataset):
7 |     """
8 |     Dataset class for the Ambisonizer dataset.
9 |     """
10 |     def __init__(self, base_dir, partition="train", max_len=120000):
11 |         assert partition in ["train", "val", "test"], "Invalid partition. Must be one of ['train', 'val', 'test']"
12 |         self.base_dir = base_dir
13 |         self.partition = partition
14 |         self.base_dir = os.path.join(base_dir, partition)
15 |         self.max_len = max_len
16 |         self.file_list = os.listdir(self.base_dir)
17 | 
18 |     def __len__(self):
19 |         return len(self.file_list)
20 | 
21 |     def __getitem__(self, index):
22 |         base_path = os.path.join(self.base_dir, self.file_list[index])
23 |         w, _ = librosa.load(os.path.join(base_path, "W.wav"), sr=44100, mono=True)
24 |         x, _ = librosa.load(os.path.join(base_path, "X.wav"), sr=44100, mono=True)
25 |         y, _ = librosa.load(os.path.join(base_path, "Y.wav"), sr=44100, mono=True)
26 | 
27 |         # random crop index
28 |         crop_idx = np.random.randint(0, len(w)-self.max_len)
29 |         w = w[crop_idx:crop_idx+self.max_len]
30 |         x = x[crop_idx:crop_idx+self.max_len]
31 |         y = y[crop_idx:crop_idx+self.max_len]
32 | 
33 |         azimuth_list = [22.5, 202.5]
34 |         azimuth_list = [np.deg2rad(azimuth) for azimuth in azimuth_list]
35 |         left = w + x * np.cos(azimuth_list[0]) + y * np.sin(azimuth_list[0])
36 |         right = w + x * np.cos(azimuth_list[1]) + y * np.sin(azimuth_list[1])
37 |         sig = np.array([left, right])
38 |         target = np.array([x, y])
39 |         random_gain = np.random.uniform(0.5, 1.0)
40 |         sig *= random_gain
41 |         target *= random_gain
42 |         return sig, target
43 | 
44 | def test(source_dir, partition):
45 |     dataset = Ambisonizer(base_dir=source_dir, partition=partition)
46 |     print(len(dataset))
47 |     print(dataset[0][0].shape)
48 |     print(dataset[0][1].shape)
49 |     print(dataset[0][0].dtype)
50 |     print(dataset[0][1].dtype)
51 |     dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
52 |     for i, (x, y) in enumerate(dataloader):
53 |         print(x.shape)
54 |         print(y.shape)
55 |         break
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ambisonizer: Neural Upmixing as Spherical Harmonics Generation
2 | Yongyi Zang*, Yifan Wang* (* Equal Contribution), Minglun Lee
3 | 
4 | [arXiv link](https://arxiv.org/abs/2405.13428)
5 | 
6 | We directly generate the Ambisonic B-format from a mono channel to achieve mono-to-any audio upmixing, and use the stereo signal as a condition to achieve stereo-to-any upmixing.
7 | 
8 | The model implementation (defined in `/model`) largely references the [EnCodec Implementation](https://github.com/facebookresearch/encodec) of SEANet.
9 | 
10 | ## Updates
11 | - May 2024: We release the model weights, pre-processed ambisonic impulse responses, and relevant scripts for training and inference.
12 | 
13 | ## Getting Started
14 | ### Prepare Environment
15 | ```bash
16 | conda create --name ambisonizer python=3.10
17 | conda activate ambisonizer
18 | pip install -r requirements.txt
19 | ```
20 | 
21 | ### Run Inference
22 | We provide an example model checkpoint [here](https://drive.google.com/file/d/1S9VPkvPs0LI3oZzZRwoeLmdbm6BJpUph/view?usp=sharing). Once this file is downloaded, you can use `inference.ipynb` to run inference with the model.
23 | 
24 | Note that the inference result will be the W, X and Y channels of first-order Ambisonic B-format. We provide a simple script in `synthesize.ipynb` to help you convert the ambisonic signals to stereo signals given two azimuth angles, as sketched below. Feel free to adjust it to your needs.
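For reference, the conversion in `synthesize.ipynb` boils down to a weighted sum of the B-format channels. Below is a minimal sketch of that decode, assuming 44.1 kHz `W.wav`/`X.wav`/`Y.wav` files produced by the inference notebook; the `0.3` gain on W and the two azimuth angles are simply the defaults used in the notebook and can be adjusted:

```python
import numpy as np
import librosa
import soundfile as sf

azimuths = np.deg2rad([-60, 120])  # virtual speaker azimuths in degrees, as in synthesize.ipynb

w, _ = librosa.load("W.wav", sr=44100)
x, _ = librosa.load("X.wav", sr=44100)
y, _ = librosa.load("Y.wav", sr=44100)
x, y = x[:len(w)], y[:len(w)]  # make sure all channels share the same length

# Each output channel mixes the omnidirectional W term with the first-order X/Y terms.
left = 0.3 * w + x * np.cos(azimuths[0]) + y * np.sin(azimuths[0])
right = 0.3 * w + x * np.cos(azimuths[1]) + y * np.sin(azimuths[1])

sf.write("stereo.wav", np.stack([left, right], axis=-1), 44100)
```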
25 | 
26 | ### Train
27 | To train the model, start by downloading the [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html#musdb18-hq-uncompressed-wav) dataset, then use the `generate_data.py` script to generate the training data as needed. For reproducibility, the pre-processed ambisonic impulse responses are available for download [here](https://drive.google.com/file/d/1aGC9pqxZMZPnDjctRqp3wWs-b6HzJYNC/view?usp=sharing).
28 | 
29 | Once the data is ready, you can train the model using the `train.py` script. You may need to edit the available partitions defined in the `datasets.py` script to your specific needs. This is an example training command (the learning rate is currently fixed inside `train.py`):
30 | ```bash
31 | python train.py --base_dir [path to the dataset] --epochs 100 --batch_size 16 --num_workers 8 --embed_dim 64 --log_dir [path to the log directory]
32 | ```
33 | 
34 | ## Citation
35 | If you find any part of our work useful, please consider citing us:
36 | ```bibtex
37 | @article{zang2024ambisonizer,
38 |   title={Ambisonizer: Neural Upmixing as Spherical Harmonics Generation},
39 |   author={Zang, Yongyi and Wang, Yifan and Lee, Minglun},
40 |   journal={arXiv preprint arXiv:2405.13428},
41 |   year={2024}
42 | }
43 | ```
44 | 
--------------------------------------------------------------------------------
/inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import argparse\n",
10 |     "import os, sys\n",
11 |     "import torch\n",
12 |     "import numpy as np\n",
13 |     "from tqdm import tqdm\n",
14 |     "import datetime, random\n",
15 |     "import torch.nn.functional as F\n",
16 |     "from torch import nn\n",
17 |     "from torch import optim\n",
18 |     "from torch.utils.data import DataLoader\n",
19 |     "from torch.utils.tensorboard import SummaryWriter\n",
20 |     "from datasets import Ambisonizer\n",
21 |     "from model.seanet import SEANet\n",
22 |     "from utils import seed_worker, set_seed\n",
23 |     "import auraloss\n",
24 |     "\n",
25 |     "device = 'cpu'\n",
26 |     "\n",
27 |     "checkpoint_path = None # Fill this in with the path to the checkpoint you want to load\n",
28 |     "\n",
29 |     "model = SEANet(480000, 64).to(device)\n",
30 |     "model.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
31 |     "print(model)"
32 |    ]
33 |   },
34 |   {
35 |    "cell_type": "code",
36 |    "execution_count": null,
37 |    "metadata": {},
38 |    "outputs": [],
39 |    "source": [
40 |     "source_audio = None # Fill this in with the path to the audio file you want to use\n",
41 |     "import librosa\n",
42 |     "y, sr = librosa.load(source_audio, sr=44100, mono=False)\n",
43 |     "start_idx = np.random.randint(0, y.shape[1] - 480000)\n",
44 |     "y = y[:, start_idx:start_idx+480000]\n",
45 |     "print(y.shape)"
46 |    ]
47 |   },
48 |   {
49 |    "cell_type": "code",
50 |    "execution_count": null,
51 |    "metadata": {},
52 |    "outputs": [],
53 |    "source": [
54 |     "y = torch.tensor(y).to(device).float().unsqueeze(0)\n",
55 |     "print(y.shape)\n",
56 |     "y_pred, _, _ = model(y)\n",
57 |     "print(y_pred.shape)"
58 |    ]
59 |   },
60 |   {
61 |    "cell_type": "code",
62 |    "execution_count": null,
63 |    "metadata": {},
64 |    "outputs": [],
65 |    "source": [
66 |     "# save y\n",
67 |     "y = y.squeeze().detach().cpu().numpy()\n",
68 |     "y = np.transpose(y, (1, 0))\n",
69 |     "y = np.ascontiguousarray(y)\n",
70 |     "import soundfile as sf\n",
71 |     "sf.write('y.wav', y, 44100)\n",
72 |     "y_mono = np.mean(y, axis=1)\n",
73 |     "sf.write('y_mono.wav', y_mono, 44100)"
74 | 
] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "y_pred = y_pred.squeeze().detach().cpu().numpy()\n", 83 | "x = y_pred[0]\n", 84 | "y = y_pred[1]\n", 85 | "w = np.mean(y, axis=0)\n", 86 | "\n", 87 | "sf.write('X.wav', x, 44100)\n", 88 | "sf.write('Y.wav', y, 44100)\n", 89 | "sf.write('W.wav', w, 44100)" 90 | ] 91 | } 92 | ], 93 | "metadata": { 94 | "kernelspec": { 95 | "display_name": "ambisonizer", 96 | "language": "python", 97 | "name": "python3" 98 | }, 99 | "language_info": { 100 | "codemirror_mode": { 101 | "name": "ipython", 102 | "version": 3 103 | }, 104 | "file_extension": ".py", 105 | "mimetype": "text/x-python", 106 | "name": "python", 107 | "nbconvert_exporter": "python", 108 | "pygments_lexer": "ipython3", 109 | "version": "3.10.13" 110 | } 111 | }, 112 | "nbformat": 4, 113 | "nbformat_minor": 2 114 | } 115 | -------------------------------------------------------------------------------- /generate_data.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import librosa 4 | import soundfile as sf 5 | import audiomentations as augs 6 | from scipy import signal 7 | from tqdm import tqdm 8 | 9 | np.random.seed(42) 10 | audio_source_dir = None # Fill this with the path to the audio source directory 11 | ir_source_dir = None # Fill this with the path to the IR source directory 12 | data_output_dir = None # Fill this with the path to the output directory 13 | n_samples = None # Fill this with the number of samples you want to generate 14 | 15 | assert audio_source_dir is not None, "Please fill in the audio source directory" 16 | assert ir_source_dir is not None, "Please fill in the IR source directory" 17 | assert data_output_dir is not None, "Please fill in the output directory" 18 | assert n_samples is not None, "Please fill in the number of samples to generate" 19 | 20 | sample_length = 10 # seconds 21 | source_max_width = np.pi # in radians 22 | 23 | irs = os.listdir(ir_source_dir) 24 | audio_files = os.listdir(audio_source_dir) 25 | print("Found", len(irs), "IRs and", len(audio_files), "audio files") 26 | 27 | sample_configs = [] 28 | for i in tqdm(range(n_samples), desc="Generating config for each sample"): 29 | ir = np.random.choice(irs) 30 | audio = np.random.choice(audio_files) 31 | sample_configs.append((ir, audio)) 32 | print("Generated", len(sample_configs), "sample configs") 33 | 34 | def convolve(ir, audio): 35 | return signal.fftconvolve(audio, ir, mode="full")[:len(audio)] 36 | 37 | augment = augs.Compose([ 38 | augs.Gain(min_gain_db=-10, max_gain_db=10, p=0.7), 39 | augs.AirAbsorption(min_distance=0.1, max_distance=10, p=0.7), 40 | augs.SevenBandParametricEQ(p=0.7), 41 | augs.GainTransition(p=0.7), 42 | ]) 43 | 44 | def get_coefficients(azimuth, elevation=0): 45 | # assuming that azimuth and elevation are in radians. 
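    # The gains below are the real first-order spherical-harmonic (B-format) panning
    # coefficients; the sqrt(3) factor on X/Y/Z relative to W is consistent with N3D-style
    # normalization. With elevation = 0 this reduces to
    #   W = 1, X = sqrt(3) * cos(azimuth), Y = sqrt(3) * sin(azimuth), Z = 0,
    # so a source at azimuth 0 encodes entirely into X, and one at azimuth pi/2 entirely into Y.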
46 | w = 1 47 | x = np.cos(elevation) * np.cos(azimuth) * np.sqrt(3) 48 | y = np.cos(elevation) * np.sin(azimuth) * np.sqrt(3) 49 | z = np.sin(elevation) * np.sqrt(3) 50 | return w, x, y, z 51 | 52 | def process_sample_config(sample_config): 53 | ir, audio = sample_config 54 | audio_name = audio.split(".")[0] 55 | ir_path = os.path.join(ir_source_dir, ir) 56 | w_ir, _ = librosa.load(os.path.join(ir_path, "W.wav"), sr=44100, mono=True) 57 | x_ir, _ = librosa.load(os.path.join(ir_path, "X.wav"), sr=44100, mono=True) 58 | y_ir, _ = librosa.load(os.path.join(ir_path, "Y.wav"), sr=44100, mono=True) 59 | # z_ir, _ = librosa.load(os.path.join(ir_path, "Z.wav"), sr=44100, mono=True) 60 | 61 | # randomly chop ir 62 | ir_length = np.random.uniform(0.3, 1.0) 63 | ir_length = int(ir_length*44100) 64 | w_ir = w_ir[:ir_length] 65 | x_ir = x_ir[:ir_length] 66 | y_ir = y_ir[:ir_length] 67 | 68 | ir_fadeout_length = np.random.uniform(0.05, 0.3) 69 | ir_fadeout_length = int(ir_fadeout_length*44100) 70 | ir_fadeout = np.linspace(1, 0, ir_fadeout_length) 71 | w_ir[-ir_fadeout_length:] *= ir_fadeout 72 | x_ir[-ir_fadeout_length:] *= ir_fadeout 73 | y_ir[-ir_fadeout_length:] *= ir_fadeout 74 | 75 | audio_path = os.path.join(audio_source_dir, audio) 76 | 77 | w_channel = np.zeros((sample_length*44100,)) 78 | x_channel = np.zeros((sample_length*44100,)) 79 | y_channel = np.zeros((sample_length*44100,)) 80 | # z_channel = np.zeros((sample_length*44100,)) 81 | 82 | audios = os.listdir(audio_path) 83 | audio_length = librosa.get_duration(path=os.path.join(audio_path, audios[0])) 84 | start_index = int(np.random.uniform(0, int((audio_length - sample_length) * 44100))) 85 | 86 | azimuths = np.random.uniform(-np.pi, np.pi, len(audios)) 87 | 88 | for i, audio in enumerate(audios): 89 | curr_azi = azimuths[i] 90 | y, _ = librosa.load(os.path.join(audio_path, audio), sr=44100, mono=False) 91 | if len(y.shape) == 1: 92 | y = y[start_index:start_index+int(sample_length*44100)] 93 | y = augment(samples=y, sample_rate=44100) 94 | w_c, x_c, y_c, z_c = get_coefficients(curr_azi) 95 | w_channel += convolve(w_ir, y) * w_c 96 | x_channel += convolve(x_ir, y) * x_c 97 | y_channel += convolve(y_ir, y) * y_c 98 | # z_channel += convolve(z_ir, y) * z_c 99 | elif len(y.shape) == 2: 100 | y = y[:, start_index:start_index+int(sample_length*44100)] 101 | y = augment(samples=y, sample_rate=44100) 102 | left_sig = y[0] * 0.5 103 | right_sig = y[1] * 0.5 104 | source_width = np.random.uniform(0, source_max_width) 105 | left_azimuth = curr_azi - source_width/2 106 | right_azimuth = curr_azi + source_width/2 107 | left_w_c, left_x_c, left_y_c, left_z_c = get_coefficients(left_azimuth) 108 | right_w_c, right_x_c, right_y_c, right_z_c = get_coefficients(right_azimuth) 109 | w_channel += convolve(w_ir, left_sig) * left_w_c + convolve(w_ir, right_sig) * right_w_c 110 | x_channel += convolve(x_ir, left_sig) * left_x_c + convolve(x_ir, right_sig) * right_x_c 111 | y_channel += convolve(y_ir, left_sig) * left_y_c + convolve(y_ir, right_sig) * right_y_c 112 | # z_channel += convolve(z_ir, left_sig) * left_z_c + convolve(z_ir, right_sig) * right_z_c 113 | else: 114 | raise ValueError("Audio file is not mono or stereo") 115 | data = np.array([w_channel, x_channel, y_channel]) 116 | if np.max(np.abs(data)) == 0: 117 | pass 118 | elif np.max(np.abs(data)) > 1: 119 | data = data / np.max(np.abs(data)) 120 | 121 | data = data[:, :int(sample_length*44100)] 122 | 123 | curr_output_dir = os.path.join(data_output_dir, ir.split(".")[0] + "|" + 
audio_name.split(".")[0].replace(" ", "_") + "|" + str(start_index/44100).format(".2f")) 124 | if not os.path.exists(curr_output_dir): 125 | os.makedirs(curr_output_dir) 126 | sf.write(os.path.join(curr_output_dir, "W.wav"), data[0], 44100) 127 | sf.write(os.path.join(curr_output_dir, "X.wav"), data[1], 44100) 128 | sf.write(os.path.join(curr_output_dir, "Y.wav"), data[2], 44100) 129 | # sf.write(os.path.join(curr_output_dir, "Z.wav"), data[3], 44100) 130 | 131 | # use multiprocessing and tqdm 132 | import multiprocessing as mp 133 | pool = mp.Pool(mp.cpu_count()) 134 | with tqdm(total=len(sample_configs), desc="Processing samples") as pbar: 135 | for _ in pool.imap_unordered(process_sample_config, sample_configs): 136 | pbar.update() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | import datetime, random 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from torch import optim 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | from datasets import Ambisonizer 13 | from model.seanet import SEANet 14 | from utils import seed_worker, set_seed 15 | import auraloss 16 | 17 | def kl_divergence(mu, log_var): 18 | """ 19 | Computes the KL divergence between the learned latent distribution and a standard Gaussian distribution. 20 | Args: 21 | mu (torch.Tensor): Mean of the learned latent distribution. 22 | log_var (torch.Tensor): Log variance of the learned latent distribution. 23 | Returns: 24 | kl_div (torch.Tensor): KL divergence term. 25 | """ 26 | log_var = torch.clamp(log_var, min=-10, max=10) 27 | kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - torch.exp(log_var), dim=1) 28 | return kl_div.mean() 29 | 30 | def main(args): 31 | # Set the seed for reproducibility 32 | set_seed(42) 33 | 34 | device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu") 35 | 36 | mrstft = auraloss.freq.MultiResolutionSTFTLoss().to(device) 37 | mse = nn.MSELoss().to(device) 38 | 39 | # Create the dataset 40 | path = args.base_dir 41 | train_dataset = Ambisonizer(path, partition="train") 42 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, worker_init_fn=seed_worker) 43 | 44 | val_dataset = Ambisonizer(path, partition="val") 45 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, worker_init_fn=seed_worker) 46 | 47 | # Create the model 48 | model = SEANet(120000, args.embed_dim).to(device) 49 | 50 | # Create the optimizer 51 | optimizer = optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-9) 52 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-7) 53 | 54 | # Create the directory for the logs 55 | log_dir = os.path.join(args.log_dir, str(args.embed_dim)) 56 | os.makedirs(log_dir, exist_ok=True) 57 | 58 | # get current time 59 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 60 | log_dir = os.path.join(log_dir, current_time) 61 | os.makedirs(log_dir, exist_ok=True) 62 | 63 | # Create the summary writer 64 | writer = SummaryWriter(log_dir=log_dir) 65 | 66 | # Create the directory for the checkpoints 67 | checkpoint_dir = os.path.join(log_dir, "checkpoints") 68 | os.makedirs(checkpoint_dir, exist_ok=True) 69 
| 70 | # Save config for reproducibility 71 | with open(os.path.join(log_dir, "config.json"), "w") as f: 72 | f.write(str(vars(args))) 73 | 74 | best_val_loss = float("inf") 75 | 76 | # Train the model 77 | for epoch in range(args.epochs): 78 | model.train() 79 | for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}")): 80 | if args.debug and i > 20: 81 | break 82 | source, target = batch 83 | source = source.to(device) 84 | target = target.to(device) 85 | pred, mu, log_var = model(source) 86 | kl_loss = kl_divergence(mu, log_var) 87 | mrstft_loss = mrstft(pred, target) 88 | mse_loss = mse(pred, target) 89 | loss = mrstft_loss + mse_loss * 10 + kl_loss 90 | loss.backward() 91 | nn.utils.clip_grad_norm_(model.parameters(), 2.0) 92 | optimizer.step() 93 | optimizer.zero_grad(set_to_none=True) 94 | writer.add_scalar("Train/Loss", loss.item(), epoch * len(train_loader) + i) 95 | writer.add_scalar("Train/MRSTFT", mrstft_loss.item(), epoch * len(train_loader) + i) 96 | writer.add_scalar("Train/MSE", mse_loss.item(), epoch * len(train_loader) + i) 97 | writer.add_scalar("Train/KL", kl_loss.item(), epoch * len(train_loader) + i) 98 | scheduler.step() 99 | writer.add_scalar("LR/train", scheduler.get_last_lr()[0], epoch * len(train_loader) + i) 100 | 101 | model.eval() 102 | val_loss_dict = { 103 | 'MRSTFT': 0, 104 | 'MSE': 0, 105 | 'KL': 0, 106 | } 107 | with torch.no_grad(): 108 | for i, batch in enumerate(tqdm(val_loader, desc=f"Validation")): 109 | if args.debug and i > 20: 110 | break 111 | source, target = batch 112 | source = source.to(device) 113 | target = target.to(device) 114 | pred, mu, log_var = model(source) 115 | kl_loss = kl_divergence(mu, log_var) 116 | mrstft_loss = mrstft(pred, target) 117 | mse_loss = mse(pred, target) 118 | val_loss_dict['MRSTFT'] += mrstft_loss.item() 119 | val_loss_dict['MSE'] += mse_loss.item() 120 | val_loss_dict['KL'] += kl_loss.item() 121 | val_loss_dict['MRSTFT'] /= len(val_loader) 122 | val_loss_dict['MSE'] /= len(val_loader) 123 | val_loss_dict['KL'] /= len(val_loader) 124 | val_loss = val_loss_dict['MRSTFT'] + val_loss_dict['MSE'] * 10 + val_loss_dict['KL'] 125 | writer.add_scalar("Val/Loss", val_loss, epoch) 126 | writer.add_scalar("Val/MRSTFT", val_loss_dict['MRSTFT'], epoch) 127 | writer.add_scalar("Val/MSE", val_loss_dict['MSE'], epoch) 128 | writer.add_scalar("Val/KL", val_loss_dict['KL'], epoch) 129 | if val_loss < best_val_loss: 130 | best_val_loss = val_loss 131 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_model.pth")) 132 | # save checkpoint every 10 epochs 133 | if (epoch + 1) % 10 == 0: 134 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"epoch_{epoch + 1}.pth")) 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument("--base_dir", type=str, required=True, help="The base directory of the dataset.") 139 | parser.add_argument("--epochs", type=int, default=5000, help="The number of epochs to train.") 140 | parser.add_argument("--debug", action="store_true", help="Run in debug mode; will only run for 20 steps for each epoch.") 141 | parser.add_argument("--gpu", type=int, default=0, help="The GPU to use.") 142 | parser.add_argument("--batch_size", type=int, default=32, help="The batch size for training.") 143 | parser.add_argument("--num_workers", type=int, default=8, help="The number of workers for the data loader.") 144 | parser.add_argument("--embed_dim", type=int, default=64, help="The number of workers for the data loader.") 
145 |     parser.add_argument("--log_dir", type=str, default="logs", help="The directory for the logs.")
146 | 
147 |     args = parser.parse_args()
148 |     main(args)
--------------------------------------------------------------------------------
/model/conv.py:
--------------------------------------------------------------------------------
1 | """Convolutional layers wrappers and utilities."""
2 | 
3 | import math
4 | import typing as tp
5 | import warnings
6 | import einops
7 | 
8 | import torch
9 | from torch import nn
10 | from torch.nn import functional as F
11 | from torch.nn.utils import spectral_norm, weight_norm
12 | 
13 | 
14 | class ConvLayerNorm(nn.LayerNorm):
15 |     """
16 |     Convolution-friendly LayerNorm that moves channels to the last dimension
17 |     before running the normalization and moves them back to their original position right after.
18 |     """
19 |     def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
20 |         super().__init__(normalized_shape, **kwargs)
21 | 
22 |     def forward(self, x):
23 |         x = einops.rearrange(x, 'b ... t -> b t ...')
24 |         x = super().forward(x)
25 |         x = einops.rearrange(x, 'b t ... -> b ... t')
26 |         return x
27 | 
28 | 
29 | CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
30 |                                  'time_layer_norm', 'layer_norm', 'time_group_norm'])
31 | 
32 | 
33 | def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
34 |     assert norm in CONV_NORMALIZATIONS
35 |     if norm == 'weight_norm':
36 |         return weight_norm(module)
37 |     elif norm == 'spectral_norm':
38 |         return spectral_norm(module)
39 |     else:
40 |         # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other choice
41 |         # doesn't need reparametrization.
42 |         return module
43 | 
44 | 
45 | def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
46 |     """Return the proper normalization module. If causal is True, this will ensure the returned
47 |     module is causal, or raise an error if the normalization doesn't support causal evaluation.
48 |     """
49 |     assert norm in CONV_NORMALIZATIONS
50 |     if norm == 'layer_norm':
51 |         assert isinstance(module, nn.modules.conv._ConvNd)
52 |         return ConvLayerNorm(module.out_channels, **norm_kwargs)
53 |     elif norm == 'time_group_norm':
54 |         if causal:
55 |             raise ValueError("GroupNorm doesn't support causal evaluation.")
56 |         assert isinstance(module, nn.modules.conv._ConvNd)
57 |         return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
58 |     else:
59 |         return nn.Identity()
60 | 
61 | 
62 | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
63 |                                  padding_total: int = 0) -> int:
64 |     """See `pad_for_conv1d`.
65 |     """
66 |     length = x.shape[-1]
67 |     n_frames = (length - kernel_size + padding_total) / stride + 1
68 |     ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
69 |     return ideal_length - length
70 | 
71 | 
72 | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
73 |     """Pad for a convolution to make sure that the last window is full.
74 |     Extra padding is added at the end. This is required to ensure that we can rebuild
75 |     an output of the same length, as otherwise, even with padding, some time steps
76 |     might get removed.
77 |     For instance, with total padding = 4, kernel size = 4, stride = 2:
78 |         0 0 1 2 3 4 5 0 0   # (0s are padding)
79 |         1   2   3           # (output frames of a convolution, last 0 is never used)
80 |         0 0 1 2 3 4 5 0     # (output of tr. conv., but pos.
5 is going to get removed as padding) 81 | 1 2 3 4 # once you removed padding, we are missing one time step ! 82 | """ 83 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) 84 | return F.pad(x, (0, extra_padding)) 85 | 86 | 87 | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): 88 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input. 89 | If this is the case, we insert extra 0 padding to the right before the reflection happen. 90 | """ 91 | length = x.shape[-1] 92 | padding_left, padding_right = paddings 93 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 94 | if mode == 'reflect': 95 | max_pad = max(padding_left, padding_right) 96 | extra_pad = 0 97 | if length <= max_pad: 98 | extra_pad = max_pad - length + 1 99 | x = F.pad(x, (0, extra_pad)) 100 | padded = F.pad(x, paddings, mode, value) 101 | end = padded.shape[-1] - extra_pad 102 | return padded[..., :end] 103 | else: 104 | return F.pad(x, paddings, mode, value) 105 | 106 | 107 | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): 108 | """Remove padding from x, handling properly zero padding. Only for 1d!""" 109 | padding_left, padding_right = paddings 110 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 111 | assert (padding_left + padding_right) <= x.shape[-1] 112 | end = x.shape[-1] - padding_right 113 | return x[..., padding_left: end] 114 | 115 | 116 | class NormConv1d(nn.Module): 117 | """Wrapper around Conv1d and normalization applied to this conv 118 | to provide a uniform interface across normalization approaches. 119 | """ 120 | def __init__(self, *args, causal: bool = False, norm: str = 'none', 121 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 122 | super().__init__() 123 | self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) 124 | self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) 125 | self.norm_type = norm 126 | 127 | def forward(self, x): 128 | x = self.conv(x) 129 | x = self.norm(x) 130 | return x 131 | 132 | 133 | class NormConv2d(nn.Module): 134 | """Wrapper around Conv2d and normalization applied to this conv 135 | to provide a uniform interface across normalization approaches. 136 | """ 137 | def __init__(self, *args, norm: str = 'none', 138 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 139 | super().__init__() 140 | self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) 141 | self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) 142 | self.norm_type = norm 143 | 144 | def forward(self, x): 145 | x = self.conv(x) 146 | x = self.norm(x) 147 | return x 148 | 149 | 150 | class NormConvTranspose1d(nn.Module): 151 | """Wrapper around ConvTranspose1d and normalization applied to this conv 152 | to provide a uniform interface across normalization approaches. 
153 | """ 154 | def __init__(self, *args, causal: bool = False, norm: str = 'none', 155 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 156 | super().__init__() 157 | self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) 158 | self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) 159 | self.norm_type = norm 160 | 161 | def forward(self, x): 162 | x = self.convtr(x) 163 | x = self.norm(x) 164 | return x 165 | 166 | 167 | class NormConvTranspose2d(nn.Module): 168 | """Wrapper around ConvTranspose2d and normalization applied to this conv 169 | to provide a uniform interface across normalization approaches. 170 | """ 171 | def __init__(self, *args, norm: str = 'none', 172 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 173 | super().__init__() 174 | self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) 175 | self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) 176 | 177 | def forward(self, x): 178 | x = self.convtr(x) 179 | x = self.norm(x) 180 | return x 181 | 182 | 183 | class SConv1d(nn.Module): 184 | """Conv1d with some builtin handling of asymmetric or causal padding 185 | and normalization. 186 | """ 187 | def __init__(self, in_channels: int, out_channels: int, 188 | kernel_size: int, stride: int = 1, dilation: int = 1, 189 | groups: int = 1, bias: bool = True, causal: bool = False, 190 | norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, 191 | pad_mode: str = 'reflect'): 192 | super().__init__() 193 | # warn user on unusual setup between dilation and stride 194 | if stride > 1 and dilation > 1: 195 | warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' 196 | f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') 197 | self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, 198 | dilation=dilation, groups=groups, bias=bias, causal=causal, 199 | norm=norm, norm_kwargs=norm_kwargs) 200 | self.causal = causal 201 | self.pad_mode = pad_mode 202 | 203 | def forward(self, x): 204 | B, C, T = x.shape 205 | kernel_size = self.conv.conv.kernel_size[0] 206 | stride = self.conv.conv.stride[0] 207 | dilation = self.conv.conv.dilation[0] 208 | kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations 209 | padding_total = kernel_size - stride 210 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) 211 | if self.causal: 212 | # Left padding for causal 213 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) 214 | else: 215 | # Asymmetric padding required for odd strides 216 | padding_right = padding_total // 2 217 | padding_left = padding_total - padding_right 218 | x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) 219 | return self.conv(x) 220 | 221 | 222 | class SConvTranspose1d(nn.Module): 223 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding 224 | and normalization. 
225 | """ 226 | def __init__(self, in_channels: int, out_channels: int, 227 | kernel_size: int, stride: int = 1, causal: bool = False, 228 | norm: str = 'none', trim_right_ratio: float = 1., 229 | norm_kwargs: tp.Dict[str, tp.Any] = {}): 230 | super().__init__() 231 | self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, 232 | causal=causal, norm=norm, norm_kwargs=norm_kwargs) 233 | self.causal = causal 234 | self.trim_right_ratio = trim_right_ratio 235 | assert self.causal or self.trim_right_ratio == 1., \ 236 | "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" 237 | assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. 238 | 239 | def forward(self, x): 240 | kernel_size = self.convtr.convtr.kernel_size[0] 241 | stride = self.convtr.convtr.stride[0] 242 | padding_total = kernel_size - stride 243 | 244 | y = self.convtr(x) 245 | 246 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be 247 | # removed at the very end, when keeping only the right length for the output, 248 | # as removing it here would require also passing the length at the matching layer 249 | # in the encoder. 250 | if self.causal: 251 | # Trim the padding on the right according to the specified ratio 252 | # if trim_right_ratio = 1.0, trim everything from right 253 | padding_right = math.ceil(padding_total * self.trim_right_ratio) 254 | padding_left = padding_total - padding_right 255 | y = unpad1d(y, (padding_left, padding_right)) 256 | else: 257 | # Asymmetric padding required for odd strides 258 | padding_right = padding_total // 2 259 | padding_left = padding_total - padding_right 260 | y = unpad1d(y, (padding_left, padding_right)) 261 | return y 262 | -------------------------------------------------------------------------------- /model/seanet.py: -------------------------------------------------------------------------------- 1 | """Encodec SEANet-based encoder and decoder implementation.""" 2 | 3 | import typing as tp 4 | 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch 8 | 9 | from model.conv import ( 10 | SConv1d, 11 | SConvTranspose1d, 12 | ) 13 | 14 | from model.lstm import SLSTM 15 | 16 | class ChannelAttention(nn.Module): 17 | def __init__(self, channels, reduction_ratio=16): 18 | super(ChannelAttention, self).__init__() 19 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 20 | self.fc = nn.Sequential( 21 | nn.Linear(channels, channels // reduction_ratio), 22 | nn.ReLU(inplace=True), 23 | nn.Linear(channels // reduction_ratio, channels), 24 | nn.Sigmoid() 25 | ) 26 | 27 | def forward(self, x): 28 | b, c, _ = x.size() 29 | y = self.avg_pool(x).view(b, c) 30 | y = self.fc(y).view(b, c, 1) 31 | return x * y 32 | 33 | class SEANetResnetBlock(nn.Module): 34 | """Residual block from SEANet model. 35 | Args: 36 | dim (int): Dimension of the input/output 37 | kernel_sizes (list): List of kernel sizes for the convolutions. 38 | dilations (list): List of dilations for the convolutions. 39 | activation (str): Activation function. 40 | activation_params (dict): Parameters to provide to the activation function 41 | norm (str): Normalization method. 42 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 43 | causal (bool): Whether to use fully causal convolution. 44 | pad_mode (str): Padding mode for the convolutions. 
45 | compress (int): Reduced dimensionality in residual branches (from Demucs v3) 46 | true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. 47 | """ 48 | def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], 49 | activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, 50 | norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, 51 | pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): 52 | super().__init__() 53 | assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' 54 | act = getattr(nn, activation) 55 | hidden = dim // compress 56 | block = [] 57 | for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): 58 | in_chs = dim if i == 0 else hidden 59 | out_chs = dim if i == len(kernel_sizes) - 1 else hidden 60 | block += [ 61 | act(**activation_params), 62 | SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, 63 | norm=norm, norm_kwargs=norm_params, 64 | causal=causal, pad_mode=pad_mode), 65 | ] 66 | self.block = nn.Sequential(*block) 67 | self.shortcut: nn.Module 68 | if true_skip: 69 | self.shortcut = nn.Identity() 70 | else: 71 | self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, 72 | causal=causal, pad_mode=pad_mode) 73 | 74 | def forward(self, x): 75 | return self.shortcut(x) + self.block(x) 76 | 77 | 78 | class SEANetEncoder(nn.Module): 79 | """SEANet encoder. 80 | Args: 81 | channels (int): Audio channels. 82 | dimension (int): Intermediate representation dimension. 83 | n_filters (int): Base width for the model. 84 | n_residual_layers (int): nb of residual layers. 85 | ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of 86 | upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here 87 | that must match the decoder order 88 | activation (str): Activation function. 89 | activation_params (dict): Parameters to provide to the activation function 90 | norm (str): Normalization method. 91 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 92 | kernel_size (int): Kernel size for the initial convolution. 93 | last_kernel_size (int): Kernel size for the initial convolution. 94 | residual_kernel_size (int): Kernel size for the residual layers. 95 | dilation_base (int): How much to increase the dilation with each layer. 96 | causal (bool): Whether to use fully causal convolution. 97 | pad_mode (str): Padding mode for the convolutions. 98 | true_skip (bool): Whether to use true skip connection or a simple 99 | (streamable) convolution as the skip connection in the residual network blocks. 100 | compress (int): Reduced dimensionality in residual branches (from Demucs v3). 101 | lstm (int): Number of LSTM layers at the end of the encoder. 
102 | """ 103 | def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1, 104 | ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, 105 | norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, 106 | last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, 107 | pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2): 108 | super().__init__() 109 | self.channels = channels 110 | self.dimension = dimension 111 | self.n_filters = n_filters 112 | self.ratios = list(reversed(ratios)) 113 | del ratios 114 | self.n_residual_layers = n_residual_layers 115 | self.hop_length = np.prod(self.ratios) 116 | 117 | act = getattr(nn, activation) 118 | mult = 1 119 | model: tp.List[nn.Module] = [ 120 | SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, 121 | causal=causal, pad_mode=pad_mode) 122 | ] 123 | # Downsample to raw audio scale 124 | for i, ratio in enumerate(self.ratios): 125 | # Add residual layers 126 | for j in range(n_residual_layers): 127 | model += [ 128 | SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], 129 | dilations=[dilation_base ** j, 1], 130 | norm=norm, norm_params=norm_params, 131 | activation=activation, activation_params=activation_params, 132 | causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] 133 | 134 | # Add downsampling layers 135 | model += [ 136 | act(**activation_params), 137 | SConv1d(mult * n_filters, mult * n_filters * 2, 138 | kernel_size=ratio * 2, stride=ratio, 139 | norm=norm, norm_kwargs=norm_params, 140 | causal=causal, pad_mode=pad_mode), 141 | ] 142 | mult *= 2 143 | 144 | if lstm: 145 | model += [SLSTM(mult * n_filters, num_layers=lstm)] 146 | 147 | model += [ 148 | act(**activation_params), 149 | SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params, 150 | causal=causal, pad_mode=pad_mode) 151 | ] 152 | 153 | self.model = nn.ModuleList(model) 154 | 155 | def forward(self, x): 156 | outputs = [] 157 | for layer in self.model: 158 | x = layer(x) 159 | if layer.__class__.__name__ == "SConv1d": 160 | outputs.append(x) 161 | return outputs 162 | 163 | 164 | class SEANetDecoder(nn.Module): 165 | """SEANet decoder. 166 | Args: 167 | channels (int): Audio channels. 168 | dimension (int): Intermediate representation dimension. 169 | n_filters (int): Base width for the model. 170 | n_residual_layers (int): nb of residual layers. 171 | ratios (Sequence[int]): kernel size and stride ratios 172 | activation (str): Activation function. 173 | activation_params (dict): Parameters to provide to the activation function 174 | final_activation (str): Final activation function after all convolutions. 175 | final_activation_params (dict): Parameters to provide to the activation function 176 | norm (str): Normalization method. 177 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 178 | kernel_size (int): Kernel size for the initial convolution. 179 | last_kernel_size (int): Kernel size for the initial convolution. 180 | residual_kernel_size (int): Kernel size for the residual layers. 181 | dilation_base (int): How much to increase the dilation with each layer. 182 | causal (bool): Whether to use fully causal convolution. 183 | pad_mode (str): Padding mode for the convolutions. 
184 | true_skip (bool): Whether to use true skip connection or a simple 185 | (streamable) convolution as the skip connection in the residual network blocks. 186 | compress (int): Reduced dimensionality in residual branches (from Demucs v3). 187 | lstm (int): Number of LSTM layers at the end of the encoder. 188 | trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. 189 | If equal to 1.0, it means that all the trimming is done at the right. 190 | """ 191 | def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1, 192 | ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, 193 | final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, 194 | norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, 195 | last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, 196 | pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2, 197 | trim_right_ratio: float = 1.0): 198 | super().__init__() 199 | self.dimension = dimension 200 | self.channels = channels 201 | self.n_filters = n_filters 202 | self.ratios = ratios 203 | del ratios 204 | self.n_residual_layers = n_residual_layers 205 | self.hop_length = np.prod(self.ratios) 206 | 207 | act = getattr(nn, activation) 208 | mult = int(2 ** len(self.ratios)) 209 | model: tp.List[nn.Module] = [ 210 | SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, 211 | causal=causal, pad_mode=pad_mode) 212 | ] 213 | 214 | if lstm: 215 | model += [SLSTM(mult * n_filters, num_layers=lstm)] 216 | 217 | # Upsample to raw audio scale 218 | for i, ratio in enumerate(self.ratios): 219 | # Add upsampling layers 220 | model += [ 221 | act(**activation_params), 222 | SConvTranspose1d(mult * n_filters, mult * n_filters // 2, 223 | kernel_size=ratio * 2, stride=ratio, 224 | norm=norm, norm_kwargs=norm_params, 225 | causal=causal, trim_right_ratio=trim_right_ratio), 226 | ] 227 | # Add residual layers 228 | for j in range(n_residual_layers): 229 | model += [ 230 | SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], 231 | dilations=[dilation_base ** j, 1], 232 | activation=activation, activation_params=activation_params, 233 | norm=norm, norm_params=norm_params, causal=causal, 234 | pad_mode=pad_mode, compress=compress, true_skip=true_skip), 235 | ChannelAttention(mult * n_filters // 2) 236 | ] 237 | mult //= 2 238 | 239 | # Add final layers 240 | model += [ 241 | act(**activation_params), 242 | SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params, 243 | causal=causal, pad_mode=pad_mode) 244 | ] 245 | # Add optional final activation to decoder (eg. 
tanh)
246 |         if final_activation is not None:
247 |             final_act = getattr(nn, final_activation)
248 |             final_activation_params = final_activation_params or {}
249 |             model += [
250 |                 final_act(**final_activation_params)
251 |             ]
252 |         self.model = nn.ModuleList(model)
253 | 
254 |     def forward(self, z, encoder_outputs):
255 |         # for i in encoder_outputs:
256 |         #     print(i.shape)
257 |         skip_connection_idx = 0
258 |         for i, layer in enumerate(self.model):
259 |             z = layer(z)
260 |             if i < len(self.model) - 1 and layer.__class__.__name__ == "SConvTranspose1d":
261 |                 curr_encoder_output = encoder_outputs[-skip_connection_idx - 2]
262 |                 # crop or pad with 0 the encoder output to match the size of the current decoder output
263 |                 if curr_encoder_output.shape[-1] > z.shape[-1]:
264 |                     curr_encoder_output = curr_encoder_output[..., :z.shape[-1]]
265 |                 elif curr_encoder_output.shape[-1] < z.shape[-1]:
266 |                     curr_encoder_output = torch.nn.functional.pad(curr_encoder_output, (0, z.shape[-1] - curr_encoder_output.shape[-1]))
267 |                 skip_connection_idx += 1
268 |                 z = z + curr_encoder_output
269 |         return z
270 | 
271 | 
272 | class SEANet(nn.Module):
273 |     def __init__(self, length, hidden_dim):
274 |         super(SEANet, self).__init__()
275 |         self.encoder = SEANetEncoder(channels=1, dimension=hidden_dim)
276 |         self.decoder = SEANetDecoder(channels=2, dimension=hidden_dim * 2)
277 |         self.stereo_conditioner = SEANetEncoder(channels=2, dimension=hidden_dim)
278 |         self.length = length
279 |         self.hidden_dim = hidden_dim
280 |         encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim * 2, nhead=8, batch_first=True)
281 |         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=8)
282 |         self.mu_layer = nn.Linear(hidden_dim, hidden_dim)
283 |         self.log_var_layer = nn.Linear(hidden_dim, hidden_dim)
284 |         # set weights in the log_var_layer to be all zeros to begin with
285 |         # this helps stabilize training
286 |         self.log_var_layer.weight.data.zero_()
287 |         self.log_var_layer.bias.data.zero_()
288 | 
289 |     def forward(self, x):
290 |         assert x.shape[1] == 2 and x.shape[2] == self.length, f"Expected input of size [batch size, 2, {self.length}] but got {x.shape}"
291 |         mono_x = x.mean(dim=1, keepdim=True)
292 |         encoder_outputs = self.encoder(mono_x)
293 |         z = encoder_outputs[-1]
294 |         z = z.permute(0, 2, 1)
295 |         conditioner_outputs = self.stereo_conditioner(x)
296 |         c = conditioner_outputs[-1]
297 |         c = c.permute(0, 2, 1)
298 |         mu = self.mu_layer(c)
299 |         log_var = self.log_var_layer(c)
300 |         std = torch.exp(0.5 * log_var)
301 |         eps = torch.randn_like(std)
302 |         c = eps * std + mu
303 |         z = torch.cat([z, c], dim=2)
304 |         z = self.transformer_encoder(z)
305 |         z = z.permute(0, 2, 1)
306 |         y = self.decoder(z, encoder_outputs[:-1])
307 |         return y, mu, log_var
308 | 
309 | def test():
310 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
311 |     model = SEANet(480000, 256).to(device)
312 |     x = torch.rand(2, 2, 480000).to(device)
313 |     y, mu, log_var = model(x)  # the model returns the prediction plus the latent statistics
314 |     assert y.shape == (2, 2, 480000), f"Got {y.shape}"  # two output channels: X and Y
315 |     print("SEANet test passed")
316 | 
317 |     # get number of parameters in the model
318 |     pytorch_total_params = sum(p.numel() for p in model.parameters())
319 |     print(f"Number of parameters in the model: {pytorch_total_params/1e6} M")
320 | 
321 | 
322 | if __name__ == '__main__':
323 |     test()
324 | 
--------------------------------------------------------------------------------
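As a quick end-to-end reference for how the pieces above fit together, here is a minimal sketch that mirrors `inference.ipynb`: the checkpoint and input file names are placeholders, the input is assumed to be a 44.1 kHz stereo file of at least 480000 samples, and the W channel is taken as the mono mix of the input, which is exact under the stereo mixdown used in `datasets.py`.

```python
import torch
import librosa
import soundfile as sf

from model.seanet import SEANet

device = "cpu"
model = SEANet(480000, 64).to(device)  # embed_dim 64, matching the README training command
model.load_state_dict(torch.load("best_model.pth", map_location=device))  # placeholder checkpoint path
model.eval()

stereo, _ = librosa.load("input.wav", sr=44100, mono=False)  # placeholder stereo input, >= 480000 samples
stereo = stereo[:, :480000]

x_in = torch.tensor(stereo).float().unsqueeze(0).to(device)  # (1, 2, 480000)
with torch.no_grad():
    pred, mu, log_var = model(x_in)  # pred: (1, 2, 480000) -> predicted X and Y channels

x_ch, y_ch = pred.squeeze(0).cpu().numpy()
w_ch = stereo.mean(axis=0)  # W taken as the mono mix of the input (exact for the datasets.py mixdown)

for name, sig in zip(["W", "X", "Y"], [w_ch, x_ch, y_ch]):
    sf.write(f"{name}.wav", sig, 44100)
```

From there, `synthesize.ipynb` (or the decode snippet in the README) can render the W/X/Y files to stereo for any pair of azimuth angles.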