├── conftest.py ├── inr ├── py.typed ├── __init__.py └── systems │ ├── __init__.py │ └── main.py ├── tests ├── __init__.py ├── hypersound │ ├── __init__.py │ └── models │ │ ├── __init__.py │ │ ├── meta │ │ ├── __init__.py │ │ ├── test_hyper.py │ │ └── test_inr.py │ │ ├── test_siren.py │ │ └── test_nerf.py └── test_end_to_end.py ├── hypersound ├── __init__.py ├── py.typed ├── utils │ ├── __init__.py │ ├── wandb.py │ ├── metrics.py │ └── eval.py ├── datasets │ ├── __init__.py │ ├── wrappers │ │ ├── __init__.py │ │ ├── gtzan.py │ │ ├── ljs.py │ │ ├── vctk.py │ │ ├── libritts.py │ │ └── librispeech.py │ ├── transforms.py │ ├── utils.py │ ├── audio.py │ └── base.py ├── models │ ├── __init__.py │ ├── meta │ │ ├── __init__.py │ │ ├── hyper.py │ │ └── inr.py │ ├── nerf.py │ ├── siren.py │ └── encoder.py ├── systems │ ├── __init__.py │ └── loss.py ├── datamodules │ ├── __init__.py │ └── audio.py └── cfg │ ├── default.yaml │ ├── nerf.yaml │ ├── siren.yaml │ ├── inr-nerf.yaml │ ├── inr-siren.yaml │ └── __init__.py ├── rave ├── rave │ ├── __init__.py │ ├── resample.py │ ├── core.py │ └── pqmf.py ├── prior │ ├── __init__.py │ ├── residual_block.py │ ├── core.py │ └── model.py ├── docs │ ├── rave.png │ ├── log_gan.png │ ├── log_distance.png │ ├── log_fidelity.png │ ├── darbouka_prior.mp3 │ ├── maxmsp_screenshot.png │ ├── tensorboard_guide.md │ └── training_setup.md ├── .gitignore ├── debug.py ├── parallel_training.sh ├── generation.py ├── combine_models.py ├── reconstruct.py ├── export_prior.py ├── README.md ├── train_prior.py ├── cli_helper.py ├── train_rave.py └── export_rave.py ├── .env.inr ├── pytest.ini ├── mypy.ini ├── .env.example ├── .gitignore ├── environment.yml ├── setup.cfg ├── pyproject.toml ├── Makefile ├── README.md ├── requirements.txt ├── .github └── workflows │ └── main.yml ├── train.py ├── train_inr.py └── evaluate.py /conftest.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inr/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hypersound/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hypersound/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inr/systems/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hypersound/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/hypersound/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/hypersound/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hypersound/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hypersound/systems/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hypersound/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hypersound/models/meta/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/hypersound/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hypersound/datasets/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/hypersound/models/meta/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rave/rave/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import RAVE -------------------------------------------------------------------------------- /.env.inr: -------------------------------------------------------------------------------- 1 | WANDB_ENTITY=hypersound 2 | WANDB_PROJECT=inr -------------------------------------------------------------------------------- /rave/prior/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Model as Prior -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests 3 | pythonpath = . 
4 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.9 3 | ignore_missing_imports = True 4 | -------------------------------------------------------------------------------- /rave/docs/rave.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUT-AI/hypersound/HEAD/rave/docs/rave.png -------------------------------------------------------------------------------- /rave/docs/log_gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUT-AI/hypersound/HEAD/rave/docs/log_gan.png -------------------------------------------------------------------------------- /rave/docs/log_distance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUT-AI/hypersound/HEAD/rave/docs/log_distance.png -------------------------------------------------------------------------------- /rave/docs/log_fidelity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUT-AI/hypersound/HEAD/rave/docs/log_fidelity.png -------------------------------------------------------------------------------- /rave/docs/darbouka_prior.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUT-AI/hypersound/HEAD/rave/docs/darbouka_prior.mp3 -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | DATA_DIR=~/datasets 2 | RESULTS_DIR=~/results 3 | WANDB_ENTITY=hypersound 4 | WANDB_PROJECT=hypersound -------------------------------------------------------------------------------- /hypersound/cfg/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - settings_schema 3 | 4 | _tags_: 5 | - VCTK 6 | 7 | dataset: VCTK 8 | -------------------------------------------------------------------------------- /rave/docs/maxmsp_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WUT-AI/hypersound/HEAD/rave/docs/maxmsp_screenshot.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .env 3 | .ipynb_checkpoints 4 | .virtual_documents 5 | .vscode/typings 6 | tf_siren 7 | runs/* -------------------------------------------------------------------------------- /rave/.gitignore: -------------------------------------------------------------------------------- 1 | *pycache* 2 | *DS_Store 3 | lightning_logs/ 4 | *.ckpt 5 | *.ts 6 | *libtorch* 7 | *.wav 8 | *.txt 9 | runs -------------------------------------------------------------------------------- /hypersound/cfg/nerf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - settings_schema 3 | 4 | _tags_: 5 | - NERF 6 | 7 | model: 8 | type: NERF 9 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: hypersound 2 | dependencies: 3 | - 
python=3.9 4 | - pip 5 | - cudatoolkit=11.3 6 | - pip: 7 | - -r requirements.txt -------------------------------------------------------------------------------- /rave/debug.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | from glob import glob 3 | from pathlib import Path 4 | from rave.core import search_for_run 5 | 6 | print(search_for_run("runs/engine/prior")) -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .venv,.git,__pycache__,venv,audiosuperres,.vscode 3 | ignore = E203,W503,SIM106,D,CC,TAE,SIM903 4 | per-file-ignores = __init__.py:F401 5 | max-line-length = 119 6 | max-cognitive-complexity = 10 7 | -------------------------------------------------------------------------------- /hypersound/cfg/siren.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - settings_schema 3 | 4 | _tags_: 5 | - SIREN 6 | 7 | model: 8 | type: SIREN 9 | target_network_omega_0: 2000.0 10 | target_network_omega_i: 30.0 11 | target_network_siren_gradient_fix: true 12 | 13 | learning_rate: 1e-6 -------------------------------------------------------------------------------- /hypersound/cfg/inr-nerf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - settings_schema 3 | 4 | _tags_: 5 | - NERF 6 | - INR 7 | 8 | model: 9 | type: NERF 10 | perceptual_loss_lambda: 0.0 11 | 12 | data: 13 | samples_per_epoch: 1000 14 | file_limit_validation: 30 15 | 16 | pl: 17 | max_epochs: 100 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | target-version = ["py39"] 4 | 5 | [tool.isort] 6 | profile = "black" 7 | skip_glob = ['.vscode/*'] 8 | src_paths = ["hypersound"] 9 | known_first_party=["hypersound", "inr"] 10 | 11 | [tool.pyright] 12 | exclude = ['.vscode/*', 'rave/*'] -------------------------------------------------------------------------------- /hypersound/cfg/inr-siren.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - settings_schema 3 | 4 | _tags_: 5 | - SIREN 6 | - INR 7 | 8 | model: 9 | type: SIREN 10 | perceptual_loss_lambda: 0.0 11 | target_network_omega_0: 2000.0 12 | target_network_omega_i: 30.0 13 | target_network_siren_gradient_fix: true 14 | 15 | data: 16 | samples_per_epoch: 1000 17 | file_limit_validation: 30 18 | 19 | pl: 20 | max_epochs: 100 -------------------------------------------------------------------------------- /rave/parallel_training.sh: -------------------------------------------------------------------------------- 1 | function cleanup(){ 2 | rm job_*.sh > /dev/null 2>&1 3 | rm jobs.input > /dev/null 2>&1 4 | } 5 | 6 | 7 | cleanup 8 | 9 | INDEX=0 10 | for file in $(ls instruction_*.txt) 11 | do 12 | cat $file | grep python > job_$INDEX.sh 13 | echo job_$INDEX.sh >> jobs.input 14 | ((INDEX++)) 15 | done 16 | 17 | cat jobs.input | parallel -j$1 'CUDA_VISIBLE_DEVICES=$(({%}-1)) source {} > job_$(({#}-1)).log 2>&1' 18 | 19 | cleanup -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := check 2 | 3 | # Code linting 4 
| .PHONY: lint 5 | lint: 6 | @echo "\n>>> Default linting checks" 7 | @flake8 8 | @isort --check --diff --color . 9 | @black --check --diff --color . 10 | 11 | # mypy checks 12 | .PHONY: mypy 13 | mypy: 14 | @echo "\n>>> mypy checks" 15 | @mypy *.py hypersound 16 | 17 | # Extended linting 18 | .PHONY: qa 19 | qa: 20 | @echo "\n>>> Extended linting checks" 21 | @pylint --disable=all --enable=duplicate-code hypersound --exit-zero 22 | @flake8 --ignore= --select=D --exit-zero 23 | @flake8 --ignore= --select=CC --exit-zero 24 | @flake8 --ignore= --select=TAE --exit-zero 25 | 26 | # Default code check 27 | .PHONY: check 28 | check: lint mypy qa; 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /hypersound/datasets/wrappers/gtzan.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | from torchaudio.datasets.gtzan import GTZAN 5 | 6 | from hypersound.datasets.base import BaseSamples 7 | 8 | ORIGINAL_AUDIO_EXT = ".wav" 9 | VALIDATION_SPLIT = 0.05 10 | 11 | 12 | class GTZAN_Samples(BaseSamples): 13 | def __init__( 14 | self, 15 | root: str, 16 | download: bool, 17 | sample_rate: int, 18 | fold: str, 19 | **kwargs: Any, 20 | ): 21 | 22 | super().__init__(sample_rate, fold, **kwargs) 23 | 24 | GTZAN(root, download=download) 25 | src_dir = Path(root) / "genres" 26 | dst_dir = Path(root) / "genres" 27 | 28 | self.process_recordings_dir(src_dir, dst_dir, ORIGINAL_AUDIO_EXT, validation_split=VALIDATION_SPLIT) 29 | -------------------------------------------------------------------------------- /hypersound/datasets/wrappers/ljs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | from torchaudio.datasets.ljspeech import LJSPEECH 5 | 6 | from hypersound.datasets.base import BaseSamples 7 | 8 | ORIGINAL_AUDIO_EXT = ".wav" 9 | VALIDATION_SPLIT = 0.05 10 | 11 | 12 | class LJS_Samples(BaseSamples): 13 | def __init__( 14 | self, 15 | root: str, 16 | download: bool, 17 | sample_rate: int, 18 | fold: str, 19 | **kwargs: Any, 20 | ): 21 | 22 | super().__init__(sample_rate, fold, **kwargs) 23 | 24 | LJSPEECH(root, download=download) 25 | src_dir = Path(root) / "LJSpeech-1.1" / "wavs" 26 | dst_dir = Path(root) / "LJSpeech-1.1" 27 | 28 | self.process_recordings_dir(src_dir, dst_dir, ORIGINAL_AUDIO_EXT, validation_split=VALIDATION_SPLIT) 29 | -------------------------------------------------------------------------------- /hypersound/datasets/wrappers/vctk.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | from torchaudio.datasets.vctk import VCTK_092 5 | 6 | from hypersound.datasets.base import BaseSamples 7 | 8 | ORIGINAL_AUDIO_EXT = ".flac" 9 | VAL_SPEAKER_SPLIT = 10 10 | 11 | 12 | class VCTK_Samples(BaseSamples): 13 | def __init__( 14 | self, 15 | root: str, 16 | download: bool, 17 | sample_rate: int, 18 | fold: str, 19 | **kwargs: Any, 20 | ): 21 | 22 | super().__init__(sample_rate, fold, **kwargs) 23 | 24 | VCTK_092(root, download=download) 25 | src_dir = Path(root) / "VCTK-Corpus-0.92/wav48_silence_trimmed" 26 | dst_dir = Path(root) / "VCTK-Corpus-0.92" 27 | 28 | self.process_recordings_dir(src_dir, dst_dir, ORIGINAL_AUDIO_EXT, val_speaker_split=VAL_SPEAKER_SPLIT) 29 | -------------------------------------------------------------------------------- 
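The dataset wrappers in this package (`GTZAN_Samples`, `LJS_Samples`, `VCTK_Samples` above, and the LibriTTS/LibriSpeech variants below) all follow the same pattern: download the corpus through `torchaudio`, then let `BaseSamples.process_recordings_dir` resample the recordings and split them into train/validation folds. A minimal, hypothetical usage sketch follows (argument values are illustrative, and further `BaseSamples` keyword arguments such as `padding` or `dequantize` may be required; they normally come from the Hydra config via `init_datamodule`):

```python
from hypersound.datasets.wrappers.vctk import VCTK_Samples

# Download (if needed) and index the VCTK training fold at the requested sample rate.
train_set = VCTK_Samples(
    root="~/datasets",   # DATA_DIR from .env
    download=True,
    sample_rate=22050,   # illustrative value
    fold="train",
    duration=1.0,        # illustrative value, forwarded to BaseSamples
)

waveform, sample_rate = train_set[0]  # wrapper items are (Tensor, int) pairs
```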
/hypersound/datasets/wrappers/libritts.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | from torchaudio.datasets.libritts import LIBRITTS 5 | 6 | from hypersound.datasets.base import BaseSamples 7 | 8 | ORIGINAL_AUDIO_EXT = ".wav" 9 | VAL_SPEAKER_SPLIT = 10 10 | 11 | 12 | class LibriTTS_Samples(BaseSamples): 13 | def __init__( 14 | self, 15 | root: str, 16 | download: bool, 17 | sample_rate: int, 18 | fold: str, 19 | **kwargs: Any, 20 | ): 21 | 22 | super().__init__(sample_rate, fold, **kwargs) 23 | 24 | LIBRITTS(root, download=download, url="train-clean-360") 25 | src_dir = Path(root) / "LibriTTS/train-clean-360" 26 | dst_dir = Path(root) / "LibriTTS-train-clean-360" 27 | 28 | self.process_recordings_dir(src_dir, dst_dir, ORIGINAL_AUDIO_EXT, val_speaker_split=VAL_SPEAKER_SPLIT) 29 | -------------------------------------------------------------------------------- /hypersound/utils/wandb.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import wandb 4 | from matplotlib.figure import Figure 5 | from PIL import Image 6 | 7 | # isort: split 8 | 9 | 10 | # ---------------------------------------------------------------------------------------------- 11 | # Visualization 12 | # ---------------------------------------------------------------------------------------------- 13 | def fig_to_wandb(fig: Figure) -> wandb.Image: 14 | """Convert a matplotlib.Figure to a wandb.Image. 15 | 16 | Parameters 17 | ---------- 18 | fig : Figure 19 | Matplotlib figure. 20 | 21 | Returns 22 | ------- 23 | wandb.Image 24 | Image version of the figure. 25 | 26 | """ 27 | buffer = io.BytesIO() 28 | fig.savefig(buffer, bbox_inches="tight") # type: ignore 29 | 30 | with Image.open(buffer) as img: 31 | img = wandb.Image(img) 32 | 33 | return img 34 | -------------------------------------------------------------------------------- /hypersound/datasets/wrappers/librispeech.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | from torchaudio.datasets.librispeech import LIBRISPEECH 5 | 6 | from hypersound.datasets.base import BaseSamples 7 | 8 | ORIGINAL_AUDIO_EXT = ".flac" 9 | VAL_SPEAKER_SPLIT = 10 10 | 11 | 12 | class LibriSpeech_Samples(BaseSamples): 13 | def __init__( 14 | self, 15 | root: str, 16 | download: bool, 17 | sample_rate: int, 18 | fold: str, 19 | **kwargs: Any, 20 | ): 21 | 22 | super().__init__(sample_rate, fold, **kwargs) 23 | 24 | LIBRISPEECH(root, download=download, url="train-clean-360") 25 | src_dir = Path(root) / "LibriSpeech/train-clean-360" 26 | dst_dir = Path(root) / "LibriSpeech-train-clean-360" 27 | 28 | self.process_recordings_dir(src_dir, dst_dir, ORIGINAL_AUDIO_EXT, val_speaker_split=VAL_SPEAKER_SPLIT) 29 | -------------------------------------------------------------------------------- /rave/prior/residual_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import cached_conv as cc 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | 8 | def __init__(self, res_size, skp_size, kernel_size, dilation): 9 | super().__init__() 10 | fks = (kernel_size - 1) * dilation + 1 11 | 12 | self.dconv = cc.Conv1d( 13 | res_size, 14 | 2 * res_size, 15 | kernel_size, 16 | padding=(fks - 1, 0), 17 | dilation=dilation, 18 | ) 19 | 20 | self.rconv = nn.Conv1d(res_size, res_size, 1) 21 
| self.sconv = nn.Conv1d(res_size, skp_size, 1) 22 | 23 | def forward(self, x, skp): 24 | res = x.clone() 25 | 26 | x = self.dconv(x) 27 | xa, xb = torch.split(x, x.shape[1] // 2, 1) 28 | 29 | x = torch.sigmoid(xa) * torch.tanh(xb) 30 | res = res + self.rconv(x) 31 | skp = skp + self.sconv(x) 32 | return res, skp 33 | -------------------------------------------------------------------------------- /rave/generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | torch.set_grad_enabled(False) 5 | from effortless_config import Config 6 | from rave.core import search_for_run 7 | from prior import Prior 8 | import math 9 | 10 | import soundfile as sf 11 | 12 | 13 | class args(Config): 14 | PRIOR_CKPT = None # path to prior .ckpt file 15 | EXPORTED_RAVE = None # path to .ts file 16 | LENGTH = 10 # in second 17 | OUT_PATH = "unconditional.wav" 18 | 19 | 20 | args.parse_args() 21 | 22 | args.PRIOR_CKPT = search_for_run(args.PRIOR_CKPT) 23 | 24 | prior = Prior.load_from_checkpoint(args.PRIOR_CKPT).eval() 25 | rave = torch.jit.load(args.EXPORTED_RAVE).eval() 26 | 27 | sr = int(rave.sampling_rate.item()) 28 | n_samples = math.ceil(sr * args.LENGTH) 29 | 30 | x = torch.zeros(1, 1, n_samples) 31 | z = rave.encode(x).zero_() 32 | z = prior.quantized_normal.encode(z) 33 | z = prior.generate(z) 34 | z = prior.diagonal_shift.inverse(prior.quantized_normal.decode(z)) 35 | 36 | y = rave.decode(z).reshape(-1).numpy() 37 | sf.write(args.OUT_PATH, y, sr) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HyperSound 2 | Source code for paper "Hypernetworks build Implicit Neural Representations of 3 | Sounds". [(arxiv)](https://arxiv.org/abs/2302.04959) 4 | 5 | ## Setup 6 | 7 | Setup conda environment: 8 | 9 | ```console 10 | conda env create -f environment.yml 11 | ``` 12 | 13 | Populate `.env` file with settings from `.env.example`, e.g.: 14 | 15 | ```txt 16 | DATA_DIR=~/datasets 17 | RESULTS_DIR=~/results 18 | WANDB_ENTITY=hypersound 19 | WANDB_PROJECT=hypersound 20 | ``` 21 | 22 | Make sure that `pytorch-yard` is using the appropriate version (defined in `train.py`). If not, then correct package version with something like: 23 | 24 | ```console 25 | pip install --force-reinstall pytorch-yard==2022.9.1 26 | ``` 27 | 28 | ## Experiments 29 | 30 | Default experiment: 31 | 32 | ```console 33 | python train.py 34 | ``` 35 | 36 | Custom settings: 37 | 38 | ```console 39 | python train.py cfg.learning_rate=0.01 cfg.pl.max_epochs=100 40 | ``` 41 | 42 | Isolated training of a target network on a single recording: 43 | 44 | ```console 45 | python train_inr.py 46 | ``` 47 | -------------------------------------------------------------------------------- /rave/docs/tensorboard_guide.md: -------------------------------------------------------------------------------- 1 | # Tensorboard guide 2 | 3 | ## Latent space size estimation 4 | 5 | During training, RAVE regularly estimates the **size** of the latent space given a specific dataset for a given *fidelity*. The fidelity parameter is a percentage that defines how well the model should be able to reconstruct an input audio sample. 6 | 7 | Usually values around 80% yield correct yet not accurate reconstructions. Values around 95% are most of the time sufficient to have both a compact latent space and correct reconstructions. 
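As a minimal sketch of how such an estimate can be obtained (assuming a PCA-style criterion on a batch of collected latent codes; the function below and its arguments are illustrative, not part of the codebase):

```python
import torch


def estimate_latent_size(latents: torch.Tensor, fidelity: float = 0.95) -> int:
    """Smallest number of latent dimensions reaching the requested explained-variance ratio.

    latents: (num_frames, latent_dim) matrix of encoder outputs collected on the dataset.
    """
    centered = latents - latents.mean(dim=0, keepdim=True)
    # Singular values of the centered latent matrix give the variance captured
    # by each principal direction.
    _, singular_values, _ = torch.linalg.svd(centered, full_matrices=False)
    variance = singular_values ** 2
    explained = torch.cumsum(variance, dim=0) / variance.sum()  # cumulative explained-variance ratio
    return int((explained < fidelity).sum().item()) + 1
```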
8 | 9 | We log the estimated size of the latent space for several values of fidelity in tensorboard (80, 90, 95 and 99%). 10 | 11 | ![log_fidelity](log_fidelity.png) 12 | 13 | ## Reconstruction error 14 | 15 | The values you should look at for tracking the reconstruction error of the model are the *distance* and *validation* logs. 16 | 17 | ![log_distance.png](log_distance.png) 18 | 19 | When the second phase kicks in, those values increase - **that's usually normal**. 20 | 21 | ## Adversarial losses 22 | 23 | The `loss_dis, loss_gen, pred_true, pred_fake` losses only appear during the second phase. They are usually harder to read, as most GAN losses are, but we include here an example of what *normal* logs should look like. 24 | 25 | ![log_gan.png](log_gan.png) -------------------------------------------------------------------------------- /rave/combine_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from effortless_config import Config 4 | 5 | 6 | class args(Config): 7 | PRIOR = None 8 | RAVE = None 9 | NAME = "combined" 10 | 11 | 12 | args.parse_args() 13 | 14 | 15 | class Combined(nn.Module): 16 | def __init__(self, prior, rave): 17 | super().__init__() 18 | self._prior = torch.jit.load(prior) 19 | self._rave = torch.jit.load(rave) 20 | 21 | self.register_buffer("encode_params", self._rave.encode_params) 22 | self.register_buffer("decode_params", self._rave.decode_params) 23 | self.register_buffer("forward_params", self._rave.forward_params) 24 | self.register_buffer("prior_params", self._prior.forward_params) 25 | 26 | @torch.jit.export 27 | def encode(self, x): 28 | return self._rave.encode(x) 29 | 30 | @torch.jit.export 31 | def encode_amortized(self, x): 32 | return self._rave.encode_amortized(x) 33 | 34 | @torch.jit.export 35 | def decode(self, x): 36 | return self._rave.decode(x) 37 | 38 | @torch.jit.export 39 | def prior(self, x): 40 | return self._prior(x) 41 | 42 | @torch.jit.export 43 | def forward(self, x): 44 | return self._rave(x) 45 | 46 | 47 | model = torch.jit.script(Combined(args.PRIOR, args.RAVE)) 48 | torch.jit.save(model, f"{args.NAME}.ts") 49 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Generic 2 | jupyterlab==3.2.* 3 | jupyterlab-lsp==3.9.* 4 | librosa==0.9.2 5 | matplotlib 6 | more-itertools 7 | numpy==1.22.* 8 | packaging 9 | pandas==1.3.* 10 | pathos==0.2.* 11 | pyparsing==2.4.7 12 | rich 13 | scikit-learn 14 | scipy==1.8.* 15 | seaborn 16 | soundfile 17 | typer==0.4.* 18 | wandb==0.12.* 19 | 20 | # Code quality 21 | black 22 | black[jupyter] 23 | flake8 24 | flake8-annotations-complexity 25 | flake8-cognitive-complexity 26 | flake8-docstrings 27 | flake8-simplify 28 | isort[colors] 29 | mypy 30 | pydocstyle 31 | pylint 32 | pytest 33 | 34 | # PyTorch & Hydra 35 | --find-links https://download.pytorch.org/whl/cu113/torch_stable.html 36 | torch==1.11.* 37 | torchaudio==0.11.* 38 | torchmetrics==0.9.3 39 | torchvision==0.12.* 40 | hydra-core==1.2.* 41 | git+https://github.com/pytorch/hydra-torch/#subdirectory=hydra-configs-torch 42 | git+https://github.com/pytorch/hydra-torch/#subdirectory=hydra-configs-torchvision 43 | pytorch-lightning==1.5.* 44 | lightning-bolts==0.4.* 45 | pytorch-yard==2022.9.1 46 | torch_tb_profiler==0.3.* 47 | torchinfo==1.5.* 48 | hypnettorch==0.0.4 49 | auraloss==0.2.* 50 | cdpam==0.0.6 51 | pystoi==0.3.3 52 |
pesq==0.0.4 53 | 54 | # RAVE 55 | effortless-config==0.7.0 56 | einops==0.4.0 57 | GPUtil==1.4.0 58 | tqdm==4.62.3 59 | git+https://github.com/caillonantoine/cached_conv.git@v2.3.6#egg=cached_conv 60 | git+https://github.com/caillonantoine/UDLS.git@v1.0.0#egg=udls -------------------------------------------------------------------------------- /rave/rave/resample.py: -------------------------------------------------------------------------------- 1 | from scipy.signal import kaiserord, firwin 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .pqmf import kaiser_filter 8 | 9 | import cached_conv as cc 10 | 11 | 12 | class Resampling(nn.Module): 13 | def __init__(self, target_sr, source_sr): 14 | super().__init__() 15 | self.source_sr = source_sr 16 | self.taget_sr = target_sr 17 | 18 | ratio = target_sr // source_sr 19 | assert int(ratio) == ratio 20 | self.identity = target_sr == source_sr 21 | 22 | if self.identity: 23 | self.upsample = nn.Identity() 24 | self.downsample = nn.Identity() 25 | return 26 | 27 | wc = np.pi / ratio 28 | filt = kaiser_filter(wc, 140) 29 | filt = torch.from_numpy(filt).float() 30 | 31 | self.downsample = cc.Conv1d( 32 | 1, 33 | 1, 34 | len(filt), 35 | stride=ratio, 36 | padding=cc.get_padding(len(filt), ratio), 37 | ) 38 | 39 | self.downsample.weight.data.copy_(filt.reshape(1, 1, -1)) 40 | self.downsample.bias.data.zero_() 41 | 42 | pad = len(filt) % ratio 43 | 44 | filt = nn.functional.pad(filt, (pad, 0)) 45 | filt = filt.reshape(-1, ratio).permute(1, 0) # ratio x T 46 | 47 | pad = (filt.shape[-1] + 1) % 2 48 | filt = nn.functional.pad(filt, (pad, 0)).unsqueeze(1) 49 | 50 | self.upsample = cc.Conv1d( 51 | 1, 52 | 2, 53 | filt.shape[-1], 54 | stride=1, 55 | padding=cc.get_padding(filt.shape[-1]), 56 | ) 57 | 58 | self.upsample.weight.data.copy_(filt) 59 | self.upsample.bias.data.zero_() 60 | 61 | self.ratio = ratio 62 | 63 | def from_target_sampling_rate(self, x): 64 | return self.downsample(x) 65 | 66 | def to_target_sampling_rate(self, x): 67 | x = self.upsample(x) # B x 2 x T 68 | x = x.permute(0, 2, 1).reshape(x.shape[0], -1).unsqueeze(1) 69 | return x -------------------------------------------------------------------------------- /rave/reconstruct.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.set_grad_enabled(False) 4 | 5 | from os import environ, makedirs, path 6 | from pathlib import Path 7 | 8 | import GPUtil as gpu 9 | import librosa as li 10 | import soundfile as sf 11 | from effortless_config import Config 12 | from rave import RAVE 13 | from rave.core import search_for_run 14 | from tqdm import tqdm 15 | 16 | 17 | class args(Config): 18 | CKPT = None # PATH TO YOUR PRETRAINED CHECKPOINT 19 | WAV_FOLDER = None # PATH TO YOUR WAV FOLDER 20 | OUT = "./reconstruction/" 21 | 22 | 23 | args.parse_args() 24 | 25 | # GPU DISCOVERY 26 | CUDA = gpu.getAvailable(maxMemory=0.05) 27 | if len(CUDA): 28 | environ["CUDA_VISIBLE_DEVICES"] = str(CUDA[0]) 29 | use_gpu = 1 30 | elif torch.cuda.is_available(): 31 | print("Cuda is available but no fully free GPU found.") 32 | print("Reconstruction may be slower due to concurrent processes.") 33 | use_gpu = 1 34 | else: 35 | print("No GPU found.") 36 | use_gpu = 0 37 | 38 | device = torch.device("cuda:0" if use_gpu else "cpu") 39 | 40 | # LOAD RAVE 41 | rave = ( 42 | RAVE.load_from_checkpoint( 43 | search_for_run(args.CKPT), 44 | strict=False, 45 | ) 46 | .eval() 47 | .to(device) 48 | ) 49 | 50 | # COMPUTE LATENT 
COMPRESSION RATIO 51 | x = torch.randn(1, 1, 2**14).to(device) 52 | z = rave.encode(x) 53 | ratio = x.shape[-1] // z.shape[-1] 54 | 55 | # SEARCH FOR WAV FILES 56 | audios = tqdm(list(Path(args.WAV_FOLDER).rglob("*.wav"))) 57 | 58 | # RECONSTRUCTION 59 | makedirs(args.OUT, exist_ok=True) 60 | for audio in audios: 61 | audio_name = path.splitext(path.basename(audio))[0] 62 | audios.set_description(audio_name) 63 | 64 | # LOAD AUDIO TO TENSOR 65 | x, sr = li.load(audio, sr=rave.sr) 66 | x = torch.from_numpy(x).reshape(1, 1, -1).float().to(device) 67 | 68 | # PAD AUDIO 69 | n_sample = x.shape[-1] 70 | pad = (ratio - (n_sample % ratio)) % ratio 71 | x = torch.nn.functional.pad(x, (0, pad)) 72 | 73 | # ENCODE / DECODE 74 | y = rave.decode(rave.encode(x)) 75 | y = y.reshape(-1).cpu().numpy()[:n_sample] 76 | 77 | # WRITE AUDIO 78 | sf.write(path.join(args.OUT, f"reconstruction_{audio_name}.wav"), y, sr) 79 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | branches: [ main ] 6 | 7 | # Allows you to run this workflow manually from the Actions tab 8 | workflow_dispatch: 9 | 10 | jobs: 11 | code-quality: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v1 15 | name: Checkout code 16 | - uses: actions/setup-python@v2 17 | name: Setup Python 3.9 18 | with: 19 | python-version: "3.9" 20 | - uses: syphar/restore-virtualenv@c23446cde59e2779e6bf8f4dc27aa5c0d6b9ec97 21 | name: Restore virtualenv 22 | id: cache-virtualenv 23 | 24 | - uses: syphar/restore-pip-download-cache@8755b20190cfbeb76a425368e201406b09c4b922 25 | name: Restore pip download cache 26 | if: steps.cache-virtualenv.outputs.cache-hit != 'true' 27 | 28 | - run: pip install -r requirements.txt 29 | name: Install requirements 30 | if: steps.cache-virtualenv.outputs.cache-hit != 'true' 31 | 32 | - uses: actions/cache@v2 33 | name: Restore .mypy_cache 34 | with: 35 | path: .mypy_cache 36 | key: ${{ runner.os }}-${{ hashFiles('**/*.py') }} 37 | restore-keys: | 38 | ${{ runner.os }}- 39 | 40 | # Linters 41 | - run: flake8 hypersound *.py 42 | name: "🕵️ flake8" 43 | if: ${{ always() }} 44 | 45 | - run: isort --check --diff --color hypersound *.py 46 | name: "🕵️ isort" 47 | if: ${{ always() }} 48 | 49 | - run: black --check --diff --color hypersound *.py 50 | name: "🕵️ black" 51 | if: ${{ always() }} 52 | 53 | - run: mypy hypersound *.py 54 | name: "🕵️ mypy" 55 | if: ${{ always() }} 56 | 57 | # Extra linters 58 | - run: pylint --disable=all --enable=duplicate-code hypersound *.py 59 | name: "✨ pylint (duplicate code)" 60 | if: ${{ always() }} 61 | continue-on-error: true 62 | 63 | - run: flake8 --ignore= --select=D 64 | name: "✨ flake8 (docstrings)" 65 | if: ${{ always() }} 66 | continue-on-error: true 67 | 68 | - run: flake8 --ignore= --select=CC 69 | name: "✨ flake8 (complexity)" 70 | if: ${{ always() }} 71 | continue-on-error: true 72 | 73 | - run: flake8 --ignore= --select=TAE 74 | name: "✨ flake8 (type annotation complexity)" 75 | if: ${{ always() }} 76 | continue-on-error: true 77 | -------------------------------------------------------------------------------- /rave/prior/core.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class QuantizedNormal(nn.Module): 7 | def __init__(self, resolution, dither=True): 8 | super().__init__() 9 
| self.resolution = resolution 10 | self.dither = dither 11 | self.clamp = 4 12 | 13 | def from_normal(self, x): 14 | return .5 * (1 + torch.erf(x / math.sqrt(2))) 15 | 16 | def to_normal(self, x): 17 | x = torch.erfinv(2 * x - 1) * math.sqrt(2) 18 | return torch.clamp(x, -self.clamp, self.clamp) 19 | 20 | def encode(self, x): 21 | x = self.from_normal(x) 22 | x = torch.floor(x * self.resolution) 23 | x = torch.clamp(x, 0, self.resolution - 1) 24 | return self.to_stack_one_hot(x.long()) 25 | 26 | def to_stack_one_hot(self, x): 27 | x = nn.functional.one_hot(x, self.resolution) 28 | x = x.permute(0, 2, 1, 3) 29 | x = x.reshape(x.shape[0], x.shape[1], -1) 30 | x = x.permute(0, 2, 1).float() 31 | return x 32 | 33 | def decode(self, x): 34 | x = x.permute(0, 2, 1) 35 | x = x.reshape(x.shape[0], x.shape[1], -1, self.resolution) 36 | x = torch.argmax(x, -1) / self.resolution 37 | if self.dither: 38 | x = x + torch.rand_like(x) / self.resolution 39 | x = self.to_normal(x) 40 | x = x.permute(0, 2, 1) 41 | return x 42 | 43 | 44 | class DiagonalShift(nn.Module): 45 | def __init__(self, groups=1): 46 | super().__init__() 47 | assert isinstance(groups, int) 48 | assert groups > 0 49 | self.groups = groups 50 | 51 | def shift(self, x: torch.Tensor, i: int, n_dim: int): 52 | i = i // self.groups 53 | n_dim = n_dim // self.groups 54 | start = i 55 | end = -n_dim + i + 1 56 | end = end if end else None 57 | return x[..., start:end] 58 | 59 | def forward(self, x): 60 | n_dim = x.shape[1] 61 | x = torch.split(x, 1, 1) 62 | x = [ 63 | self.shift(_x, i, n_dim) for _x, i in zip( 64 | x, 65 | torch.arange(n_dim).flip(0), 66 | ) 67 | ] 68 | x = torch.cat(list(x), 1) 69 | return x 70 | 71 | def inverse(self, x): 72 | x = x.flip(1) 73 | x = self.forward(x) 74 | x = x.flip(1) 75 | return x 76 | -------------------------------------------------------------------------------- /hypersound/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | from scipy.signal import lfilter # type: ignore 5 | from torch import Tensor 6 | 7 | 8 | class Transform: 9 | def __call__(self, x: Tensor) -> Tensor: 10 | raise NotImplementedError 11 | 12 | 13 | class RandomApply(Transform): 14 | def __init__(self, transform: Callable[[Tensor], Tensor], p: float = 0.5): 15 | assert 0.0 <= p <= 1.0 16 | self.transform = transform 17 | self.p = p 18 | 19 | def __call__(self, x: Tensor) -> Tensor: 20 | if torch.rand(1).item() < self.p: 21 | x = self.transform(x) 22 | return x 23 | 24 | 25 | class RandomCrop(Transform): 26 | def __init__(self, total_samples: int, random: bool = True): 27 | self.total_samples = total_samples 28 | self.random = random 29 | 30 | def __call__(self, x: Tensor) -> Tensor: 31 | if not self.random: 32 | start = 0 33 | else: 34 | start = int(torch.randint(x.shape[-1] - self.total_samples, size=(1,)).item()) 35 | 36 | x = x[..., start : start + self.total_samples] 37 | return x 38 | 39 | 40 | class Dequantize(Transform): 41 | def __init__(self, bit_depth: int): 42 | self.bit_depth = bit_depth 43 | 44 | def __call__(self, x: Tensor) -> Tensor: 45 | x = x + (torch.rand(x.shape[-1]) - 0.5) / 2**self.bit_depth 46 | return x 47 | 48 | 49 | class RandomPhaseMangle(Transform): 50 | def __init__(self, min_f: int, max_f: int, amplitude: float, sample_rate: int): 51 | self.min_f = torch.tensor(min_f) 52 | self.max_f = torch.tensor(max_f) 53 | self.amplitude = torch.tensor(amplitude) 54 | self.sample_rate = sample_rate 55 | 56 | 
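# RandomPhaseMangle draws a random pole angle and turns it into a second-order
# all-pass-style IIR filter: `b` in `pole_to_z_filter` is `a` reversed, so the
# magnitude response stays (near-)flat while the phase is randomized, mirroring
# the phase-mangling augmentation used in RAVE.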
def random_angle(self) -> Tensor: 57 | min_f = torch.log(self.min_f) 58 | max_f = torch.log(self.max_f) 59 | 60 | rand = torch.exp(torch.rand(1).item() * (max_f - min_f) + min_f) 61 | rand = 2 * torch.pi * rand / self.sample_rate 62 | 63 | return rand 64 | 65 | def pole_to_z_filter(self, angle: Tensor) -> tuple[list[float], list[float]]: 66 | z0 = self.amplitude * torch.exp(1j * angle) 67 | a = [1.0, float(-2 * torch.real(z0)), float(torch.abs(z0) ** 2)] 68 | b = [float(torch.abs(z0) ** 2), float(-2 * torch.real(z0)), 1.0] 69 | return b, a 70 | 71 | def __call__(self, x: Tensor) -> Tensor: 72 | angle = self.random_angle() 73 | b, a = self.pole_to_z_filter(angle) 74 | 75 | return torch.tensor(lfilter(b, a, x.numpy())).float() 76 | -------------------------------------------------------------------------------- /rave/docs/training_setup.md: -------------------------------------------------------------------------------- 1 | ![logo](rave.png) 2 | 3 | # Training setup 4 | 5 | 1. You should train on a _CUDA-enabled_ machine (i.e. with an nvidia card) 6 | - You can use either **Linux** or **Windows** 7 | - However, we advise using **Linux** if available 8 | - Training RAVE without a hardware accelerator (GPU, TPU) will take ages, and is not recommended 9 | 2. Make sure that you have CUDA enabled 10 | - Go to a terminal and enter `nvidia-smi` 11 | - If a message appears with the name of your graphics card and the available memory, it's all good! 12 | - Otherwise, you have to install **cuda** on your computer (we don't provide support for that, lots of guides are available online) 13 | 3. Let's install Python! 14 | 15 | # Python installation 16 | 17 | Python is often pre-installed on most computers, but we won't use this version. Instead, we will install a **conda** distribution on the machine. This keeps different versions of Python separate for different projects, and allows regular users to install new packages without sudo access. 18 | 19 | You can follow the [instructions here](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) to install a miniconda environment on your computer. 20 | 21 | Once installed, you know that you are inside your miniconda environment if there's a "`(base)`" at the beginning of your terminal prompt. 22 | 23 | # RAVE installation 24 | 25 | We will create a new virtual environment for RAVE. 26 | 27 | ```bash 28 | conda create -n rave python=3.9 29 | ``` 30 | 31 | Each time we want to use RAVE, we can (and **should**) activate this environment using 32 | 33 | ```bash 34 | conda activate rave 35 | ``` 36 | 37 | Let's clone RAVE and install the requirements! 38 | 39 | ```bash 40 | git clone https://github.com/acids-ircam/RAVE 41 | cd RAVE 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | You can now use `python cli_helper.py` to start a new training! 46 | 47 | # About the dataset 48 | 49 | A good rule of thumb is **more is better**. You might want to have _at least_ 3h of homogeneous recordings to train RAVE, more if your dataset is complex (e.g. mixtures of instruments, lots of variations...)
50 | 51 | If you have a folder filled with various audio files (any extension, any sampling rate), you can use the `resample` utility in this folder 52 | 53 | ```bash 54 | conda activate rave 55 | resample --sr TARGET_SAMPLING_RATE --augment 56 | ``` 57 | 58 | It will convert, resample, crop and augment all audio files present in the directory to an output directory called `out_TARGET_SAMPLING_RATE/` (which is the one you should give to `cli_helper.py` when asked for the path of the .wav files). 59 | -------------------------------------------------------------------------------- /hypersound/models/nerf.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, cast 2 | 3 | import torch 4 | import torch.nn 5 | from torch import Tensor, nn 6 | from torch.nn.parameter import Parameter 7 | 8 | from hypersound.cfg import TargetNetworkMode 9 | from hypersound.models.meta.inr import INR 10 | 11 | 12 | class NERF(INR): 13 | def __init__( 14 | self, 15 | input_size: int, 16 | output_size: int, 17 | hidden_sizes: list[int], 18 | bias: bool, 19 | mode: TargetNetworkMode, 20 | encoding_length: int, 21 | learnable_encoding: bool, 22 | ): 23 | super().__init__( 24 | input_size=2 * encoding_length * input_size, 25 | output_size=output_size, 26 | hidden_sizes=hidden_sizes, 27 | bias=bias, 28 | activation_fn=nn.ReLU(), 29 | mode=mode, 30 | ) 31 | self.input_size = input_size # override super().input_size 32 | self.encoding_length = encoding_length 33 | 34 | freq = torch.ones((encoding_length,), dtype=torch.float32) 35 | for i in range(len(freq)): 36 | freq[i] = 2**i 37 | freq = freq * torch.pi 38 | freq = Parameter(freq, requires_grad=learnable_encoding) 39 | 40 | self.params["freq"] = freq 41 | self.register_parameter("freq", freq) 42 | 43 | def forward( # type: ignore 44 | self, 45 | x: Tensor, 46 | weights: Optional[dict[str, Tensor]] = None, 47 | return_activations: bool = False, 48 | ) -> Union[Tensor, tuple[Tensor, list[tuple[Tensor, Tensor]]]]: 49 | if weights is None: 50 | # Single mode, x --> (num_samples, input_size) 51 | freq = cast(Tensor, self.params["freq"]) # (encoding_length,) 52 | else: 53 | # Batch mode, x --> (batch_size, num_samples, input_size) 54 | freq = weights.get("freq", cast(Tensor, self.params["freq"]).tile(x.shape[0], 1)) 55 | freq = freq.unsqueeze(1).unsqueeze(1) # (batch_size, _, _, encoding_length) 56 | 57 | x = x.unsqueeze(-1) * freq 58 | x = torch.cat((torch.sin(x), torch.cos(x)), dim=-1).flatten(-2, -1) 59 | # x --> (_batch_size_, num_samples, input_size * encoding_length * 2) 60 | 61 | activations: list[tuple[Tensor, Tensor]] = [] 62 | 63 | for i in range(self.n_layers): 64 | x = self._forward(x, layer_idx=i, weights=weights) 65 | 66 | h = x 67 | 68 | if i != self.n_layers - 1: 69 | x = self._activation_fn(x) 70 | 71 | if return_activations: 72 | activations.append((x, h)) 73 | 74 | if return_activations: 75 | return x, activations 76 | else: 77 | return x 78 | -------------------------------------------------------------------------------- /hypersound/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, Optional, cast 3 | 4 | import torch.utils.data 5 | from omegaconf import OmegaConf 6 | from pytorch_yard import RootConfig 7 | from torch import Tensor 8 | 9 | from hypersound.cfg import Dataset, Settings 10 | from hypersound.datamodules.audio import AudioDataModule 11 | from 
hypersound.datasets.wrappers.gtzan import GTZAN_Samples 12 | from hypersound.datasets.wrappers.librispeech import LibriSpeech_Samples 13 | from hypersound.datasets.wrappers.libritts import LibriTTS_Samples 14 | from hypersound.datasets.wrappers.ljs import LJS_Samples 15 | from hypersound.datasets.wrappers.vctk import VCTK_Samples 16 | 17 | 18 | def init_datamodule( 19 | root_cfg: RootConfig, resample_rate: Optional[int] = None 20 | ) -> tuple[AudioDataModule, Optional[AudioDataModule]]: 21 | cfg = cast(Settings, root_cfg.cfg) 22 | kwargs = dict( 23 | root=root_cfg.data_dir, 24 | download=True, 25 | duration=cfg.data.duration, 26 | padding=cfg.transforms.padding, 27 | dequantize=cfg.transforms.dequantize, 28 | phase_mangle=cfg.transforms.phase_mangle, 29 | random_crop=cfg.transforms.random_crop, 30 | start_offset=cfg.transforms.start_offset, 31 | ) 32 | 33 | _dataset: Callable[[str], torch.utils.data.Dataset[tuple[Tensor, int]]] 34 | if cfg.dataset is Dataset.VCTK: 35 | _dataset = partial(VCTK_Samples, **kwargs) 36 | elif cfg.dataset is Dataset.LJS: 37 | _dataset = partial(LJS_Samples, **kwargs) 38 | elif cfg.dataset is Dataset.GTZAN: 39 | _dataset = partial(GTZAN_Samples, **kwargs) 40 | elif cfg.dataset is Dataset.LIBRITTS: 41 | _dataset = partial(LibriTTS_Samples, **kwargs) 42 | elif cfg.dataset is Dataset.LIBRISPEECH: 43 | _dataset = partial(LibriSpeech_Samples, **kwargs) 44 | else: 45 | raise ValueError(f"Unknown dataset: {cfg.dataset}.") 46 | 47 | train_dataset = _dataset(fold="train", sample_rate=cfg.data.sample_rate) # type: ignore 48 | validation_dataset = _dataset(fold="validation", sample_rate=cfg.data.sample_rate) # type: ignore 49 | 50 | main_dm = AudioDataModule( 51 | train_dataset=train_dataset, 52 | validation_dataset=validation_dataset, 53 | batch_size=cfg.batch_size, 54 | **cast(dict[str, Any], OmegaConf.to_container(cfg.data, resolve=True)), 55 | ) 56 | 57 | if resample_rate: 58 | # Prepare datamodule for resampling evaluation 59 | _cfg = cast(dict[str, Any], OmegaConf.to_container(cfg.data, resolve=True)) 60 | _cfg.update({"sample_rate": resample_rate}) 61 | interpolation_dm = AudioDataModule( 62 | train_dataset=train_dataset, 63 | validation_dataset=_dataset(fold="validation", sample_rate=resample_rate), # type: ignore 64 | batch_size=cfg.batch_size, 65 | **_cfg, 66 | ) 67 | interpolation_dm.prepare_data() 68 | interpolation_dm.setup() 69 | else: 70 | interpolation_dm = None 71 | 72 | return main_dm, interpolation_dm 73 | -------------------------------------------------------------------------------- /hypersound/datamodules/audio.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, cast 2 | 3 | import torch 4 | from pytorch_lightning import LightningDataModule 5 | from torch.utils.data import DataLoader, Dataset, RandomSampler 6 | from torch.utils.data.dataset import Subset 7 | 8 | from ..datasets.audio import AudioDataset, IndicesAudioAndSpectrogram 9 | 10 | 11 | class AudioDataModule(LightningDataModule): 12 | """A data module for generic audio datasets. 13 | 14 | Uses the `hypersound.datasets.audio.AudioDataset` wrapper to return 15 | pairs of (waveform, melspectrogram) tensors. 
16 | 17 | """ 18 | 19 | def __init__( 20 | self, 21 | train_dataset: Dataset[Any], 22 | validation_dataset: Dataset[Any], 23 | batch_size: int, 24 | num_workers: int, 25 | samples_per_epoch: int, 26 | file_limit: Optional[int] = None, 27 | file_limit_validation: Optional[int] = None, 28 | **kwargs: Any, 29 | ): 30 | super().__init__() # type: ignore 31 | 32 | self.train = AudioDataset(train_dataset, **kwargs) 33 | self.validation = AudioDataset(validation_dataset, **kwargs) 34 | self.batch_size = batch_size 35 | self.num_workers = num_workers 36 | 37 | self.file_limit = file_limit 38 | self.file_limit_validation = file_limit_validation 39 | self.samples_per_epoch = samples_per_epoch 40 | 41 | def prepare_data(self, *args: None, **kwargs: None): 42 | pass 43 | 44 | def setup(self, stage: Optional[str] = None): 45 | if self.file_limit: 46 | _indices: list[int] = torch.randperm(len(self.train), generator=torch.Generator().manual_seed(0)).tolist() # type: ignore # noqa 47 | self.train = cast(AudioDataset, Subset(self.train, _indices[: self.file_limit])) 48 | 49 | if self.file_limit_validation: 50 | _indices: list[int] = torch.randperm(len(self.validation), generator=torch.Generator().manual_seed(0)).tolist() # type: ignore # noqa 51 | self.validation = cast(AudioDataset, Subset(self.validation, _indices[: self.file_limit_validation])) 52 | 53 | @property 54 | def shape(self): 55 | return self.train.shape 56 | 57 | def train_dataloader(self, *args: None, **kwargs: None) -> DataLoader[IndicesAudioAndSpectrogram]: 58 | assert self.train 59 | return DataLoader( 60 | self.train, 61 | batch_size=self.batch_size, 62 | sampler=RandomSampler(self.train, num_samples=self.samples_per_epoch), 63 | num_workers=self.num_workers, 64 | drop_last=True, 65 | ) 66 | 67 | def val_dataloader(self, *args: None, **kwargs: None) -> DataLoader[IndicesAudioAndSpectrogram]: 68 | assert self.validation 69 | return DataLoader( 70 | self.validation, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True 71 | ) 72 | 73 | def __repr__(self) -> str: 74 | return ( 75 | f" {self.train}, {self.validation}>" 79 | ) 80 | -------------------------------------------------------------------------------- /rave/export_prior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.set_grad_enabled(False) 4 | import torch.nn as nn 5 | from effortless_config import Config 6 | import logging 7 | from termcolor import colored 8 | 9 | import math 10 | 11 | from cached_conv import use_cached_conv 12 | 13 | logging.basicConfig(level=logging.INFO, 14 | format=colored("[%(relativeCreated).2f] ", "green") + 15 | "%(message)s") 16 | 17 | logging.info("exporting model") 18 | 19 | 20 | class args(Config): 21 | RUN = None 22 | NAME = "latent" 23 | 24 | 25 | args.parse_args() 26 | use_cached_conv(True) 27 | 28 | import cached_conv as cc 29 | from prior.model import Model 30 | from rave.core import search_for_run 31 | 32 | 33 | class TraceModel(nn.Module): 34 | def __init__(self, pretrained: Model): 35 | super().__init__() 36 | data_size = pretrained.data_size 37 | 38 | self.data_size = data_size 39 | self.pretrained = pretrained 40 | 41 | x = torch.zeros(1, 1, 2**14) 42 | z = self.pretrained.encode(x) 43 | ratio = x.shape[-1] // z.shape[-1] 44 | 45 | self.register_buffer( 46 | "forward_params", 47 | torch.tensor([1, ratio, data_size, ratio]), 48 | ) 49 | 50 | self.pretrained.synth = None 51 | 52 | self.register_buffer( 53 | "previous_step", 54 | 
self.pretrained.quantized_normal.encode( 55 | torch.zeros(1, data_size, 1)), 56 | ) 57 | 58 | self.pre_diag_cache = cc.CachedPadding1d(data_size - 1) 59 | self.pre_diag_cache(z) 60 | self.pre_diag_cache = torch.jit.script(self.pre_diag_cache) 61 | 62 | def step_forward(self, temp): 63 | # PREDICT NEXT STEP 64 | x = self.pretrained.forward(self.previous_step) 65 | x = x / temp 66 | x = self.pretrained.post_process_prediction(x, argmax=False) 67 | self.previous_step.copy_(x.clone()) 68 | 69 | # DECODE AND SHIFT PREDICTION 70 | x = self.pretrained.quantized_normal.decode(x) 71 | x = self.pre_diag_cache(x) 72 | x = self.pretrained.diagonal_shift.inverse(x) 73 | 74 | return x 75 | 76 | def forward(self, temp: torch.Tensor): 77 | x = torch.zeros( 78 | temp.shape[0], 79 | self.data_size, 80 | temp.shape[-1], 81 | ).to(temp) 82 | 83 | temp = temp.mean(-1, keepdim=True) 84 | temp = nn.functional.softplus(temp) / math.log(2) 85 | 86 | for i in range(x.shape[-1]): 87 | x[..., i:i + 1] = self.step_forward(temp) 88 | 89 | return x 90 | 91 | 92 | logging.info("loading model from checkpoint") 93 | 94 | RUN = search_for_run(args.RUN) 95 | logging.info(f"using {RUN}") 96 | 97 | model = Model.load_from_checkpoint(RUN, strict=False).eval() 98 | 99 | logging.info("warmup forward pass") 100 | 101 | x = torch.zeros(1, 1, 2**17) 102 | x = model.encode(x) 103 | x = torch.zeros_like(x) 104 | x = model.quantized_normal.encode(model.diagonal_shift(x)) 105 | x = x[..., -1:] 106 | model(x) 107 | 108 | logging.info("script model") 109 | model = TraceModel(model) 110 | model = torch.jit.script(model) 111 | 112 | logging.info("save model") 113 | model.save(f"prior_{args.NAME}.ts") -------------------------------------------------------------------------------- /tests/hypersound/models/meta/test_hyper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from hypersound.cfg import TargetNetworkMode 4 | from hypersound.models.meta.hyper import MLPHyperNetwork 5 | from hypersound.models.nerf import NERF 6 | from hypersound.models.siren import SIREN 7 | 8 | 9 | def test_siren_hypernetwork() -> None: 10 | target_network = SIREN( 11 | input_size=1, 12 | output_size=1, 13 | hidden_sizes=[32, 16], 14 | mode=TargetNetworkMode.TARGET_NETWORK, 15 | bias=True, 16 | omega_0=30, 17 | omega_i=30, 18 | learnable_omega=False, 19 | gradient_fix=True, 20 | ) 21 | hypernetwork = MLPHyperNetwork( 22 | target_network=target_network, 23 | shared_params=["o0", "o1", "o2"], 24 | input_size=32, 25 | layer_sizes=[64], 26 | ) 27 | 28 | z = torch.rand((2, 32)) 29 | x = torch.rand((2, 32, 1)) * 2 - 1 30 | 31 | weights = hypernetwork(z) 32 | y = target_network(x, weights=weights) 33 | 34 | assert y.shape == (2, 32, 1) 35 | 36 | 37 | def test_nerf_hypernetwork() -> None: 38 | target_network = NERF( 39 | input_size=1, 40 | output_size=1, 41 | hidden_sizes=[32, 16], 42 | mode=TargetNetworkMode.TARGET_NETWORK, 43 | bias=True, 44 | encoding_length=6, 45 | learnable_encoding=False, 46 | ) 47 | hypernetwork = MLPHyperNetwork( 48 | target_network=target_network, 49 | shared_params=["freq"], 50 | input_size=32, 51 | layer_sizes=[64], 52 | ) 53 | 54 | z = torch.rand((2, 32)) 55 | x = torch.rand((2, 32, 1)) 56 | 57 | weights = hypernetwork(z) 58 | y = target_network(x, weights=weights) 59 | 60 | assert y.shape == (2, 32, 1) 61 | 62 | 63 | def test_siren_hypernetwork_with_learnable_omega() -> None: 64 | target_network = SIREN( 65 | input_size=1, 66 | output_size=1, 67 | hidden_sizes=[32, 16], 68 | bias=True, 
69 | mode=TargetNetworkMode.TARGET_NETWORK, 70 | omega_0=30, 71 | omega_i=30, 72 | learnable_omega=True, 73 | gradient_fix=True, 74 | ) 75 | hypernetwork = MLPHyperNetwork( 76 | target_network=target_network, 77 | shared_params=[], 78 | input_size=32, 79 | layer_sizes=[64], 80 | ) 81 | 82 | z = torch.rand((2, 32)) 83 | x = torch.rand((2, 32, 1)) * 2 - 1 84 | 85 | weights = hypernetwork(z) 86 | y = target_network(x, weights=weights) 87 | 88 | assert y.shape == (2, 32, 1) 89 | 90 | 91 | def test_nerf_hypernetwork_with_learnable_encoding() -> None: 92 | target_network = NERF( 93 | input_size=1, 94 | output_size=1, 95 | hidden_sizes=[32, 16], 96 | mode=TargetNetworkMode.TARGET_NETWORK, 97 | bias=True, 98 | encoding_length=6, 99 | learnable_encoding=True, 100 | ) 101 | hypernetwork = MLPHyperNetwork( 102 | target_network=target_network, 103 | shared_params=[], 104 | input_size=32, 105 | layer_sizes=[64], 106 | ) 107 | 108 | z = torch.rand((2, 32)) 109 | x = torch.rand((2, 32, 1)) * 2 - 1 110 | 111 | weights = hypernetwork(z) 112 | y = target_network(x, weights=weights) 113 | 114 | assert y.shape == (2, 32, 1) 115 | -------------------------------------------------------------------------------- /hypersound/models/siren.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, cast 2 | 3 | import torch 4 | import torch.nn 5 | from torch import Tensor, nn 6 | from torch.nn.parameter import Parameter 7 | 8 | from hypersound.cfg import TargetNetworkMode 9 | from hypersound.models.meta.inr import INR 10 | 11 | 12 | class Sine(nn.Module): 13 | def forward(self, x: Tensor) -> Tensor: # type: ignore 14 | return torch.sin(x) 15 | 16 | 17 | class SIREN(INR): 18 | def __init__( 19 | self, 20 | input_size: int, 21 | output_size: int, 22 | hidden_sizes: list[int], 23 | bias: bool, 24 | mode: TargetNetworkMode, 25 | omega_0: float, 26 | omega_i: float, 27 | learnable_omega: bool, 28 | gradient_fix: bool, 29 | ): 30 | super().__init__( 31 | input_size=input_size, 32 | output_size=output_size, 33 | hidden_sizes=hidden_sizes, 34 | bias=bias, 35 | activation_fn=Sine(), 36 | mode=mode, 37 | ) 38 | 39 | for i in range(self.n_layers): 40 | omega_val = torch.ones((1,), dtype=torch.float32) 41 | if i == 0: 42 | omega_val *= omega_0 43 | else: 44 | omega_val *= omega_i 45 | omega = Parameter(omega_val, requires_grad=learnable_omega) 46 | 47 | self.params[f"o{i}"] = omega 48 | self.register_parameter(f"o{i}", omega) 49 | 50 | self.gradient_fix = gradient_fix 51 | self.init_siren() 52 | 53 | def init_siren(self) -> None: 54 | for i in range(self.n_layers): 55 | w_i = self.params[f"w{i}"] 56 | _, n_features_in = w_i.shape 57 | 58 | if i == 0: 59 | std = 1 / n_features_in 60 | else: 61 | # NOTE: Version used in https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb # noqa 62 | std = (6.0 / n_features_in) ** 0.5 63 | if self.gradient_fix: 64 | std /= float(self.params["o0"]) 65 | 66 | with torch.no_grad(): 67 | w_i.uniform_(-std, std) 68 | 69 | def forward( # type: ignore 70 | self, 71 | x: Tensor, 72 | weights: Optional[dict[str, Tensor]] = None, 73 | return_activations: bool = False, 74 | ) -> Union[Tensor, tuple[Tensor, list[tuple[Tensor, Tensor]]]]: 75 | 76 | activations: list[tuple[Tensor, Tensor]] = [] 77 | 78 | for i in range(self.n_layers): 79 | if weights is None: 80 | omega = cast(Tensor, self.params[f"o{i}"]) 81 | x = x * omega 82 | else: 83 | omega = weights.get(f"o{i}", torch.stack(x.shape[0] * [cast(Tensor, 
self.params[f"o{i}"])], dim=0)) 84 | x = x * omega.unsqueeze(-1) 85 | 86 | x = self._forward(x, layer_idx=i, weights=weights) 87 | 88 | h = x 89 | 90 | if i != self.n_layers - 1: 91 | x = self._activation_fn(x) 92 | 93 | if return_activations: 94 | activations.append((x, h)) 95 | 96 | if return_activations: 97 | return x, activations 98 | else: 99 | return x 100 | -------------------------------------------------------------------------------- /tests/test_end_to_end.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from hypersound.cfg import TargetNetworkMode 4 | from hypersound.models.meta.hyper import MLPHyperNetwork 5 | from hypersound.models.nerf import NERF 6 | from hypersound.models.siren import SIREN 7 | 8 | 9 | def test_siren_hypernetwork_with_shared_layers() -> None: 10 | target_network = SIREN( 11 | input_size=1, 12 | output_size=1, 13 | hidden_sizes=[32, 16], 14 | bias=True, 15 | mode=TargetNetworkMode.TARGET_NETWORK, 16 | omega_0=1, 17 | omega_i=1, 18 | learnable_omega=False, 19 | gradient_fix=True, 20 | ) 21 | hypernetwork = MLPHyperNetwork( 22 | target_network=target_network, 23 | shared_params=["w0", "b0", "o0", "o1", "o2"], 24 | input_size=32, 25 | layer_sizes=[64], 26 | ) 27 | 28 | z = torch.rand((2, 32)) 29 | x = torch.rand((2, 32, 1)) * 2 - 1 30 | 31 | weights = hypernetwork(z) 32 | y = target_network(x, weights=weights) 33 | 34 | assert y.shape == (2, 32, 1) 35 | 36 | 37 | def test_nerf_hypernetwork_with_shared_layers() -> None: 38 | target_network = NERF( 39 | input_size=1, 40 | output_size=1, 41 | hidden_sizes=[32, 16], 42 | bias=True, 43 | mode=TargetNetworkMode.TARGET_NETWORK, 44 | encoding_length=4, 45 | learnable_encoding=False, 46 | ) 47 | hypernetwork = MLPHyperNetwork( 48 | target_network=target_network, 49 | shared_params=["freq", "w2", "b2"], 50 | input_size=32, 51 | layer_sizes=[64], 52 | ) 53 | 54 | z = torch.rand((2, 32)) 55 | x = torch.rand((2, 32, 1)) * 2 - 1 56 | 57 | weights = hypernetwork(z) 58 | y = target_network(x, weights=weights) 59 | 60 | assert y.shape == (2, 32, 1) 61 | 62 | 63 | def test_siren_hypernetwork_with_shared_layers_with_learnable_omega() -> None: 64 | target_network = SIREN( 65 | input_size=1, 66 | output_size=1, 67 | hidden_sizes=[32, 32, 16], 68 | bias=True, 69 | mode=TargetNetworkMode.TARGET_NETWORK, 70 | omega_0=30, 71 | omega_i=1, 72 | learnable_omega=True, 73 | gradient_fix=True, 74 | ) 75 | hypernetwork = MLPHyperNetwork( 76 | target_network=target_network, 77 | shared_params=["w3", "b3"], 78 | input_size=32, 79 | layer_sizes=[64], 80 | ) 81 | 82 | z = torch.rand((2, 32)) 83 | x = torch.rand((2, 32, 1)) * 2 - 1 84 | 85 | weights = hypernetwork(z) 86 | y = target_network(x, weights=weights) 87 | 88 | assert y.shape == (2, 32, 1) 89 | 90 | 91 | def test_nerf_hypernetwork_with_shared_layers_with_learnable_encoding() -> None: 92 | target_network = NERF( 93 | input_size=1, 94 | output_size=1, 95 | hidden_sizes=[32, 16, 8], 96 | bias=True, 97 | mode=TargetNetworkMode.TARGET_NETWORK, 98 | encoding_length=6, 99 | learnable_encoding=True, 100 | ) 101 | hypernetwork = MLPHyperNetwork( 102 | target_network=target_network, 103 | shared_params=["w2", "b2"], 104 | input_size=32, 105 | layer_sizes=[64], 106 | ) 107 | 108 | z = torch.rand((2, 32)) 109 | x = torch.rand((2, 32, 1)) * 2 - 1 110 | 111 | weights = hypernetwork(z) 112 | y = target_network(x, weights=weights) 113 | 114 | assert y.shape == (2, 32, 1) 115 | 
-------------------------------------------------------------------------------- /hypersound/datasets/audio.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from torch.utils.data.dataset import Dataset 7 | from torchaudio.transforms import AmplitudeToDB, MelSpectrogram 8 | 9 | IndicesAudioAndSpectrogram = tuple[Tensor, Tensor, Tensor] 10 | 11 | 12 | class AudioDataset(Dataset[IndicesAudioAndSpectrogram]): 13 | def __init__( 14 | self, 15 | dataset: Dataset[tuple[Tensor, int]], 16 | duration: float, 17 | sample_rate: int, 18 | n_fft: int, 19 | hop_length: int, 20 | n_mels: int, 21 | data_norm: bool, 22 | index_scaling: Optional[tuple[float, float]], 23 | proportional_index_scaling: bool, 24 | ): 25 | self._dataset = dataset 26 | self.duration = duration 27 | self.sample_rate = sample_rate 28 | self.n_fft = n_fft 29 | self.hop_length = hop_length 30 | self.n_mels = n_mels 31 | self.data_norm = data_norm 32 | 33 | self.total_samples = int(duration * sample_rate) 34 | self.indices = torch.arange(0, self.total_samples, dtype=torch.float).unsqueeze(-1) 35 | 36 | if index_scaling: 37 | min_val, max_val = index_scaling 38 | 39 | if proportional_index_scaling: 40 | # Normalize config setting to 1 second 41 | min_val *= self.duration 42 | max_val *= self.duration 43 | 44 | assert min_val < max_val 45 | self.indices = min_val + (max_val - min_val) * self.indices / (self.total_samples - 1) 46 | 47 | self.spec_transform = nn.ModuleList( 48 | [ 49 | MelSpectrogram(sample_rate=self.sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels), 50 | AmplitudeToDB(), 51 | ] 52 | ) 53 | 54 | @property 55 | def shape(self): 56 | n_channels = 1 57 | freq_size = self.n_mels 58 | 59 | return ( 60 | (self.total_samples, 1), 61 | (self.total_samples,), 62 | (n_channels, freq_size, 1 + self.total_samples // self.hop_length), 63 | ) 64 | 65 | def to_spectrogram(self, y: Tensor) -> Tensor: 66 | for transform in self.spec_transform: 67 | y = transform(y) 68 | y = y.unsqueeze(-3) 69 | return y 70 | 71 | def __getitem__(self, idx: int) -> IndicesAudioAndSpectrogram: 72 | # Load raw recording from the underlying dataset 73 | y = self._dataset[idx][0] 74 | samplerate = self._dataset[idx][1] 75 | 76 | y = y[0] # type: ignore 77 | assert samplerate == self.sample_rate, f"Got {samplerate}, expected {self.sample_rate}" 78 | assert y.dim() == 1 # type: ignore 79 | assert len(y) == self.total_samples, "Audio signal has an invalid length" 80 | 81 | # Normalize audio 82 | if self.data_norm and y.abs().max() > 0: 83 | y = y / y.abs().max() # type: ignore 84 | 85 | # Generate spectrogram 86 | spec = self.to_spectrogram(y) # type: ignore 87 | 88 | # Generate time indices 89 | indices = self.indices 90 | 91 | return indices, y, spec # type: ignore 92 | 93 | def __len__(self): 94 | return len(self._dataset) # type: ignore 95 | 96 | def __str__(self) -> str: 97 | return ( 98 | f" {self.shape}>" 101 | ) 102 | -------------------------------------------------------------------------------- /rave/README.md: -------------------------------------------------------------------------------- 1 | ![rave_logo](docs/rave.png) 2 | 3 | # RAVE: Realtime Audio Variational autoEncoder 4 | 5 | Official implementation of _RAVE: A variational autoencoder for fast and high-quality neural audio synthesis_ ([article link](https://arxiv.org/abs/2111.05011)) by Antoine Caillon and Philippe Esling. 
6 | 7 | If you use RAVE as a part of a music performance or installation, be sure to cite either this repository or the article ! 8 | 9 | ## Colab 10 | 11 | We provide a Google Colab notebook that handles training a RAVE model on a custom dataset ! 12 | 13 | [![colab_badge](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1aK8K186QegnWVMAhfnFRofk_Jf7BBUxl?usp=sharing) 14 | 15 | ## Installation 16 | 17 | RAVE needs `python 3.9`. Install the dependencies using 18 | 19 | ```bash 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | Detailed instructions to set up a training station for this project are available [here](docs/training_setup.md). 24 | 25 | ## Preprocessing 26 | 27 | RAVE comes with two command line utilities, `resample` and `duration`. `resample` lets you pre-process (silence removal, loudness normalization) and augment (compression) an entire directory of audio files (.mp3, .aiff, .opus, .wav, .aac). `duration` prints out the total duration of a .wav folder. 28 | 29 | ## Training 30 | 31 | Both RAVE and the prior model are available in this repo. For most users we recommend using the `cli_helper.py` script, since it will generate a set of instructions for training and exporting both RAVE and the prior model on a specific dataset. 32 | 33 | ```bash 34 | python cli_helper.py 35 | ``` 36 | 37 | However, if you want to customize your training further, you can use the provided `train_{rave, prior}.py` and `export_{rave, prior}.py` scripts manually. 38 | 39 | ## Reconstructing audio 40 | 41 | Once trained, you can reconstruct an entire folder containing wav files using 42 | 43 | ```bash 44 | python reconstruct.py --ckpt /path/to/checkpoint --wav-folder /path/to/wav/folder 45 | ``` 46 | 47 | You can also export RAVE to a `torchscript` file using `export_rave.py` and use the `encode` and `decode` methods on tensors. 48 | 49 | ## Realtime usage 50 | 51 | **UPDATE** 52 | 53 | If you want to use the realtime mode, you should update your dependencies ! 54 | 55 | ```bash 56 | pip install -r requirements.txt 57 | ``` 58 | 59 | RAVE and the prior model can be used in realtime on live audio streams, allowing creative interactions with both models. 60 | 61 | ### [nn~](https://github.com/acids-ircam/nn_tilde) 62 | 63 | RAVE is compatible with the **nn~** max/msp and PureData external. 64 | 65 | ![max_msp_screenshot](docs/maxmsp_screenshot.png) 66 | 67 | An audio example of the prior sampling patch is available in the `docs/` folder. 68 | 69 | ### [RAVE vst](https://github.com/acids-ircam/rave_vst) 70 | 71 | You can also use RAVE as a VST audio plugin using the RAVE vst ! 72 | 73 | ![plugin_screenshot](https://github.com/acids-ircam/rave_vst/blob/main/assets/rave_screenshot_audio_panel.png?raw=true) 74 | 75 | ## Discussion 76 | 77 | If you have questions, want to share your experience with RAVE or share musical pieces made with the model, you can use the [Discussion tab](https://github.com/acids-ircam/RAVE/discussions) ! 78 | 79 | ## Demonstration 80 | 81 | ### RAVE x nn~ 82 | 83 | Demonstration of what you can do with RAVE and the nn~ external for maxmsp ! 84 | 85 | [![RAVE x nn~](http://img.youtube.com/vi/dMZs04TzxUI/mqdefault.jpg)](https://www.youtube.com/watch?v=dMZs04TzxUI) 86 | 87 | ### embedded RAVE 88 | 89 | Using nn~ for puredata, RAVE can be used in realtime on embedded platforms !
90 | 91 | [![RAVE x nn~](http://img.youtube.com/vi/jAIRf4nGgYI/mqdefault.jpg)](https://www.youtube.com/watch?v=jAIRf4nGgYI) 92 | -------------------------------------------------------------------------------- /rave/train_prior.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import random_split, DataLoader 3 | import pytorch_lightning as pl 4 | 5 | from rave.core import search_for_run 6 | 7 | from prior.model import Model 8 | from effortless_config import Config 9 | from os import environ, path 10 | import os 11 | 12 | from udls import SimpleDataset, simple_audio_preprocess 13 | import numpy as np 14 | 15 | import math 16 | 17 | import GPUtil as gpu 18 | 19 | 20 | class args(Config): 21 | RESOLUTION = 32 22 | 23 | RES_SIZE = 512 24 | SKP_SIZE = 256 25 | KERNEL_SIZE = 3 26 | CYCLE_SIZE = 4 27 | N_LAYERS = 10 28 | PRETRAINED_VAE = None 29 | 30 | PREPROCESSED = None 31 | WAV = None 32 | N_SIGNAL = 65536 33 | 34 | BATCH = 8 35 | CKPT = None 36 | MAX_STEPS = 10000000 37 | VAL_EVERY = 10000 38 | 39 | NAME = None 40 | 41 | 42 | args.parse_args() 43 | assert args.NAME is not None 44 | 45 | 46 | def get_n_signal(a, m): 47 | k = a.KERNEL_SIZE 48 | cs = a.CYCLE_SIZE 49 | l = a.N_LAYERS 50 | 51 | rf = (k - 1) * sum(2**(np.arange(l) % cs)) + 1 52 | ratio = m.encode_params[-1].item() 53 | 54 | return 2**math.ceil(math.log2(rf * ratio)) 55 | 56 | 57 | model = Model( 58 | resolution=args.RESOLUTION, 59 | res_size=args.RES_SIZE, 60 | skp_size=args.SKP_SIZE, 61 | kernel_size=args.KERNEL_SIZE, 62 | cycle_size=args.CYCLE_SIZE, 63 | n_layers=args.N_LAYERS, 64 | pretrained_vae=args.PRETRAINED_VAE, 65 | ) 66 | 67 | args.N_SIGNAL = max(args.N_SIGNAL, get_n_signal(args, model.synth)) 68 | 69 | dataset = SimpleDataset( 70 | args.PREPROCESSED, 71 | args.WAV, 72 | preprocess_function=simple_audio_preprocess(model.sr, args.N_SIGNAL), 73 | split_set="full", 74 | transforms=lambda x: x.reshape(1, -1).astype(np.float32), 75 | ) 76 | 77 | val = max((2 * len(dataset)) // 100, 1) 78 | train = len(dataset) - val 79 | train, val = random_split(dataset, [train, val]) 80 | 81 | num_workers = 0 if os.name == "nt" else 8 82 | train = DataLoader(train, args.BATCH, True, drop_last=True, num_workers=num_workers) 83 | val = DataLoader(val, args.BATCH, False, num_workers=num_workers) 84 | 85 | # CHECKPOINT CALLBACKS 86 | validation_checkpoint = pl.callbacks.ModelCheckpoint( 87 | monitor="validation", 88 | filename="best", 89 | ) 90 | last_checkpoint = pl.callbacks.ModelCheckpoint(filename="last") 91 | 92 | CUDA = gpu.getAvailable(maxMemory=.05) 93 | VISIBLE_DEVICES = environ.get("CUDA_VISIBLE_DEVICES", "") 94 | 95 | if VISIBLE_DEVICES: 96 | use_gpu = int(int(VISIBLE_DEVICES) >= 0) 97 | elif len(CUDA): 98 | environ["CUDA_VISIBLE_DEVICES"] = str(CUDA[0]) 99 | use_gpu = 1 100 | elif torch.cuda.is_available(): 101 | print("Cuda is available but no fully free GPU found.") 102 | print("Training may be slower due to concurrent processes.") 103 | use_gpu = 1 104 | else: 105 | print("No GPU found.") 106 | use_gpu = 0 107 | 108 | val_check = {} 109 | if len(train) >= args.VAL_EVERY: 110 | val_check["val_check_interval"] = args.VAL_EVERY 111 | else: 112 | nepoch = args.VAL_EVERY // len(train) 113 | val_check["check_val_every_n_epoch"] = nepoch 114 | 115 | trainer = pl.Trainer( 116 | logger=pl.loggers.TensorBoardLogger(path.join("runs", args.NAME), 117 | name="prior"), 118 | gpus=use_gpu, 119 | callbacks=[validation_checkpoint, last_checkpoint], 120 | 
max_epochs=100000, 121 | max_steps=args.MAX_STEPS, 122 | **val_check, 123 | ) 124 | 125 | run = search_for_run(args.CKPT) 126 | if run is not None: 127 | step = torch.load(run, map_location='cpu')["global_step"] 128 | trainer.fit_loop.epoch_loop._batches_that_stepped = step 129 | 130 | trainer.fit(model, train, val, ckpt_path=run) 131 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from pathlib import Path 3 | from typing import Optional, Type, cast 4 | 5 | import hydra 6 | import pytorch_lightning as pl 7 | import pytorch_yard 8 | from omegaconf import OmegaConf 9 | from omegaconf.dictconfig import DictConfig 10 | from pytorch_yard.configs import get_tags 11 | from pytorch_yard.experiments.lightning import LightningExperiment 12 | from pytorch_yard.utils.logging import info, info_bold 13 | 14 | from hypersound.cfg import Settings 15 | from hypersound.datasets.utils import init_datamodule 16 | from hypersound.systems.main import HyperNetworkAE 17 | 18 | 19 | class HyperSound(LightningExperiment): 20 | def __init__(self, config_path: str, settings_cls: Type[Settings], settings_group: Optional[str] = None) -> None: 21 | super().__init__(config_path, settings_cls, settings_group=settings_group) 22 | 23 | self.cfg: Settings 24 | """ Experiment config. """ 25 | 26 | def entry(self, root_cfg: pytorch_yard.RootConfig): 27 | super().entry(root_cfg) 28 | 29 | # Do not use pytorch-yard template specializations as we use a monolithic `main` here. 30 | def setup_system(self): 31 | pass 32 | 33 | def setup_datamodule(self): 34 | pass 35 | 36 | # ------------------------------------------------------------------------ 37 | # Experiment specific code 38 | # ------------------------------------------------------------------------ 39 | def main(self): 40 | # -------------------------------------------------------------------- 41 | # W&B init 42 | # -------------------------------------------------------------------- 43 | tags: list[str] = get_tags(cast(DictConfig, self.root_cfg)) 44 | self.run.tags = tags 45 | self.run.notes = str(self.root_cfg.notes) 46 | self.wandb_logger.log_hyperparams(OmegaConf.to_container(self.root_cfg.cfg, resolve=True)) # type: ignore 47 | 48 | # -------------------------------------------------------------------- 49 | # Data module setup 50 | # -------------------------------------------------------------------- 51 | Path(self.root_cfg.data_dir).mkdir(parents=True, exist_ok=True) 52 | 53 | self.datamodule, _ = init_datamodule(self.root_cfg) 54 | self.datamodule.prepare_data() 55 | 56 | # -------------------------------------------------------------------- 57 | # System setup 58 | # -------------------------------------------------------------------- 59 | self.system = HyperNetworkAE( 60 | cfg=self.cfg, 61 | input_length=self.datamodule.train.shape[1][0], 62 | spec_transform=copy.deepcopy(self.datamodule.train.spec_transform), 63 | ) 64 | 65 | info_bold("System architecture:") 66 | info(str(self.system)) 67 | # info_bold(f"Size of target network: {cast(Any, self.system.target_network).num_params:,d}") 68 | 69 | info_bold(f"Input shape: {self.datamodule.shape}") 70 | 71 | # -------------------------------------------------------------------- 72 | # Trainer setup 73 | # -------------------------------------------------------------------- 74 | self.setup_callbacks() 75 | 76 | num_sanity_val_steps = -1 if self.cfg.validate_before_training 
else 0 77 | 78 | self.trainer: pl.Trainer = hydra.utils.instantiate( # type: ignore 79 | self.cfg.pl, 80 | logger=self.wandb_logger, 81 | callbacks=self.callbacks, 82 | enable_checkpointing=self.cfg.save_checkpoints, 83 | num_sanity_val_steps=num_sanity_val_steps, 84 | ) 85 | 86 | self.trainer.fit( # type: ignore 87 | self.system, 88 | datamodule=self.datamodule, 89 | ckpt_path=self.cfg.resume_path, 90 | ) 91 | 92 | 93 | if __name__ == "__main__": 94 | HyperSound("hypersound", Settings) 95 | -------------------------------------------------------------------------------- /tests/hypersound/models/meta/test_inr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | from hypersound.cfg import TargetNetworkMode 8 | from hypersound.models.meta.inr import INR 9 | 10 | EPS = 1e-3 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "num_samples", 15 | [ 16 | 1, 17 | 50, 18 | 32768, 19 | ], 20 | ) 21 | def test_inr_standard_inference(num_samples: int) -> None: 22 | model = INR( 23 | input_size=1, 24 | output_size=1, 25 | hidden_sizes=[16], 26 | activation_fn=torch.nn.ReLU(), 27 | bias=True, 28 | mode=TargetNetworkMode.INR, 29 | ) 30 | 31 | # x --> (num_samples, 1) 32 | x: Tensor = torch.rand((num_samples, 1)) 33 | 34 | assert model(x, weights=None).shape == (num_samples, 1) 35 | 36 | 37 | @pytest.mark.parametrize( 38 | "num_samples", 39 | [ 40 | 1, 41 | 50, 42 | 32768, 43 | ], 44 | ) 45 | def test_inr_train_loop(num_samples: int) -> None: 46 | model = INR( 47 | input_size=1, 48 | output_size=1, 49 | hidden_sizes=[16], 50 | activation_fn=torch.nn.ReLU(), 51 | bias=True, 52 | mode=TargetNetworkMode.INR, 53 | ) 54 | optimizer = torch.optim.Adam(model.parameters()) 55 | loss_fn = torch.nn.L1Loss() 56 | x = torch.rand((num_samples, 1)) 57 | y = torch.ones((num_samples, 1)) 58 | y0 = model(x) 59 | 60 | for _ in range(25): 61 | optimizer.zero_grad() 62 | loss = loss_fn(model(x), y) 63 | loss.backward() 64 | optimizer.step() 65 | 66 | assert F.mse_loss(y.squeeze(), y0.squeeze()) > EPS 67 | 68 | 69 | @pytest.mark.parametrize( 70 | "num_samples", 71 | [ 72 | 1, 73 | 50, 74 | 32768, 75 | ], 76 | ) 77 | def test_inr_tn_equivalence(num_samples: int) -> None: 78 | x = torch.rand((num_samples, 1)) 79 | 80 | inr = INR( 81 | input_size=1, 82 | output_size=1, 83 | hidden_sizes=[16], 84 | activation_fn=torch.nn.ReLU(), 85 | bias=True, 86 | mode=TargetNetworkMode.INR, 87 | ) 88 | y_inr = inr(x) 89 | 90 | tn = INR( 91 | input_size=1, 92 | output_size=1, 93 | hidden_sizes=[16], 94 | activation_fn=torch.nn.ReLU(), 95 | bias=True, 96 | mode=TargetNetworkMode.TARGET_NETWORK, 97 | ) 98 | weights = {name: param.unsqueeze(0) for name, param in inr.params.items()} 99 | y_tn = tn(x.unsqueeze(0), weights=weights) 100 | 101 | assert F.mse_loss(y_inr.squeeze(), y_tn.squeeze()) < EPS 102 | 103 | 104 | @pytest.mark.parametrize( 105 | "n_models, num_samples, input_size, output_size, hidden_sizes", 106 | [ 107 | (5, 32768, 1, 1, [16]), 108 | (4, 32768, 3, 2, [32, 4]), 109 | (2, 32768, 1, 1, [10, 20, 30]), 110 | (1, 32768, 2, 1, [128, 12]), 111 | (3, 32768, 1, 2, [1, 23, 12]), 112 | (7, 32768, 3, 1, [4, 6, 7, 8, 5, 7]), 113 | (6, 32768, 3, 2, [3, 3, 2, 1, 4, 2, 1, 4, 5, 3, 1, 2, 3, 4, 1, 2]), 114 | ], 115 | ) 116 | def test_inr_tn_batch_equivalence( 117 | n_models: int, 118 | num_samples: int, 119 | input_size: int, 120 | output_size: int, 121 | hidden_sizes: list[int], 122 | ) -> None: 
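    # Verifies that a batch of independently initialised INRs (mode=INR) and a single
    # TARGET_NETWORK instance evaluated with their parameters stacked along a leading batch
    # dimension produce matching outputs.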
123 | x_inr = [torch.rand((num_samples, input_size)) for _ in range(n_models)] 124 | x_tn = torch.stack(x_inr, 0) 125 | 126 | inrs = [ 127 | INR( 128 | input_size=input_size, 129 | output_size=output_size, 130 | hidden_sizes=hidden_sizes, 131 | activation_fn=torch.nn.ReLU(), 132 | bias=True, 133 | mode=TargetNetworkMode.INR, 134 | ) 135 | for _ in range(n_models) 136 | ] 137 | y_inr = torch.stack([inr(x) for inr, x in zip(inrs, x_inr)], dim=0) 138 | 139 | weights = {} 140 | for param_key in inrs[0].params.keys(): 141 | weights[param_key] = torch.stack([model.params[param_key] for model in inrs], dim=0) 142 | 143 | tn = INR( 144 | input_size=input_size, 145 | output_size=output_size, 146 | hidden_sizes=hidden_sizes, 147 | activation_fn=torch.nn.ReLU(), 148 | bias=True, 149 | mode=TargetNetworkMode.TARGET_NETWORK, 150 | ) 151 | y_tn = tn(x_tn, weights=weights) 152 | 153 | assert F.mse_loss(y_inr.squeeze(), y_tn.squeeze()) < EPS 154 | -------------------------------------------------------------------------------- /tests/hypersound/models/test_siren.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn 4 | import torch.nn.functional as F 5 | 6 | from hypersound.cfg import TargetNetworkMode 7 | from hypersound.models.siren import SIREN 8 | 9 | EPS = 1e-3 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "num_samples", 14 | [ 15 | 1, 16 | 50, 17 | 32768, 18 | ], 19 | ) 20 | def test_siren_standard_inference(num_samples: int) -> None: 21 | model = SIREN( 22 | input_size=1, 23 | output_size=1, 24 | hidden_sizes=[16], 25 | bias=True, 26 | mode=TargetNetworkMode.INR, 27 | omega_0=1.0, 28 | omega_i=1.0, 29 | learnable_omega=False, 30 | gradient_fix=True, 31 | ) 32 | x = torch.rand((num_samples, 1)) 33 | 34 | assert model(x, weights=None).shape == (num_samples, 1) 35 | 36 | 37 | @pytest.mark.parametrize( 38 | "num_samples", 39 | [ 40 | 1, 41 | 50, 42 | 32768, 43 | ], 44 | ) 45 | def test_siren_train_loop(num_samples: int) -> None: 46 | model = SIREN( 47 | input_size=1, 48 | output_size=1, 49 | hidden_sizes=[16], 50 | mode=TargetNetworkMode.INR, 51 | bias=True, 52 | omega_0=30.0, 53 | omega_i=1.0, 54 | learnable_omega=False, 55 | gradient_fix=True, 56 | ) 57 | optimizer = torch.optim.Adam(model.parameters()) 58 | loss_fn = torch.nn.L1Loss() 59 | x = torch.ones((num_samples, 1)) 60 | y = torch.rand((num_samples, 1)) 61 | y0 = model(x) 62 | 63 | for _ in range(25): 64 | optimizer.zero_grad() 65 | loss = loss_fn(model(x), y) 66 | loss.backward() 67 | optimizer.step() 68 | 69 | assert F.mse_loss(y.squeeze(), y0.squeeze()) > EPS 70 | 71 | 72 | @pytest.mark.parametrize( 73 | "num_samples", 74 | [ 75 | 1, 76 | 50, 77 | 32768, 78 | ], 79 | ) 80 | def test_siren_inr_tn_equivalence(num_samples: int) -> None: 81 | x = torch.rand((num_samples, 1)) 82 | 83 | inr = SIREN( 84 | input_size=1, 85 | output_size=1, 86 | hidden_sizes=[16], 87 | bias=True, 88 | mode=TargetNetworkMode.INR, 89 | omega_0=30.0, 90 | omega_i=1, 91 | learnable_omega=False, 92 | gradient_fix=True, 93 | ) 94 | y_inr = inr(x) 95 | 96 | tn = SIREN( 97 | input_size=1, 98 | output_size=1, 99 | hidden_sizes=[16], 100 | bias=True, 101 | mode=TargetNetworkMode.TARGET_NETWORK, 102 | omega_0=30.0, 103 | omega_i=1, 104 | learnable_omega=False, 105 | gradient_fix=True, 106 | ) 107 | weights = {name: param.unsqueeze(0) for name, param in inr.params.items()} 108 | y_tn = tn(x.unsqueeze(0), weights=weights) 109 | 110 | assert F.mse_loss(y_inr.squeeze(), y_tn.squeeze()) < EPS 111 
| 112 | 113 | @pytest.mark.parametrize( 114 | "n_models, num_samples, input_size, output_size, hidden_sizes", 115 | [ 116 | (5, 32768, 1, 1, [16]), 117 | (4, 32768, 3, 2, [32, 4]), 118 | (2, 32768, 1, 1, [10, 20, 30]), 119 | (1, 32768, 2, 1, [128, 12]), 120 | (3, 32768, 1, 2, [1, 23, 12]), 121 | (7, 32768, 3, 1, [4, 6, 7, 8, 5, 7]), 122 | (6, 32768, 3, 2, [3, 3, 2, 1, 4, 2, 1, 4, 5, 3, 1, 2, 3, 4, 1, 2]), 123 | ], 124 | ) 125 | def test_siren_inr_tn_batch_equivalence( 126 | n_models: int, 127 | num_samples: int, 128 | input_size: int, 129 | output_size: int, 130 | hidden_sizes: list[int], 131 | ) -> None: 132 | x_inr = [torch.rand((num_samples, input_size)) for _ in range(n_models)] 133 | x_tn = torch.stack(x_inr, 0) 134 | 135 | inrs = [ 136 | SIREN( 137 | input_size=input_size, 138 | output_size=output_size, 139 | hidden_sizes=hidden_sizes, 140 | bias=True, 141 | mode=TargetNetworkMode.INR, 142 | omega_0=30.0, 143 | omega_i=30.0, 144 | learnable_omega=False, 145 | gradient_fix=True, 146 | ) 147 | for _ in range(n_models) 148 | ] 149 | y_inr = torch.stack([model(x) for model, x in zip(inrs, x_inr)], dim=0) 150 | 151 | weights = {} 152 | for param_key in inrs[0].params.keys(): 153 | if inrs[0].params[param_key].requires_grad: 154 | weights[param_key] = torch.stack([model.params[param_key] for model in inrs], dim=0) 155 | 156 | tn = SIREN( 157 | input_size=input_size, 158 | output_size=output_size, 159 | hidden_sizes=hidden_sizes, 160 | bias=True, 161 | mode=TargetNetworkMode.TARGET_NETWORK, 162 | omega_0=30.0, 163 | omega_i=30.0, 164 | learnable_omega=False, 165 | gradient_fix=True, 166 | ) 167 | y_tn = tn(x_tn, weights=weights) 168 | 169 | assert F.mse_loss(y_inr.squeeze(), y_tn.squeeze()) < EPS 170 | -------------------------------------------------------------------------------- /hypersound/systems/loss.py: -------------------------------------------------------------------------------- 1 | import librosa.filters 2 | import numpy as np 3 | import numpy.typing as npt 4 | import torch 5 | from torch import Tensor 6 | 7 | 8 | class MultiSTFTLoss(torch.nn.Module): 9 | def __init__( 10 | self, 11 | fft_sizes: list[int], 12 | hop_sizes: list[int], 13 | win_lengths: list[int], 14 | sample_rate: int, 15 | n_bins: int, 16 | freq_weights_p: float = 0.0, 17 | freq_weights_warmup_epochs: int = 500, 18 | ): 19 | super().__init__() 20 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all 21 | self.fft_sizes = fft_sizes 22 | self.hop_sizes = hop_sizes 23 | self.win_lengths = win_lengths 24 | 25 | self.stft_losses = torch.nn.ModuleList() 26 | for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths): 27 | self.stft_losses += [ 28 | STFTLoss( 29 | fft_size, 30 | hop_size, 31 | win_length, 32 | sample_rate, 33 | n_bins, 34 | freq_weights_p=freq_weights_p, 35 | freq_weights_warmup_epochs=freq_weights_warmup_epochs, 36 | ) 37 | ] 38 | 39 | def update(self, current_epoch) -> None: 40 | for loss in self.stft_losses: 41 | loss.update(current_epoch) 42 | 43 | def forward(self, x: Tensor, x_hat: Tensor) -> Tensor: # type: ignore 44 | return torch.stack([loss(x, x_hat) for loss in self.stft_losses]).mean() 45 | 46 | 47 | class STFTLoss(torch.nn.Module): 48 | def __init__( 49 | self, 50 | fft_size: int, 51 | hop_size: int, 52 | win_length: int, 53 | sample_rate: int, 54 | n_bins: int, 55 | freq_weights_p: float, 56 | freq_weights_warmup_epochs: int, 57 | ): 58 | super().__init__() 59 | self.fft_size = fft_size 60 | self.hop_size = hop_size 61 | self.win_length = 
win_length 62 | self.window = torch.hann_window(win_length) 63 | self.sample_rate = sample_rate 64 | self.n_bins = n_bins 65 | 66 | self.eps = 1e-8 67 | 68 | assert sample_rate is not None # Must set sample rate to use mel scale 69 | # assert n_bins <= fft_size # Must be more FFT bins than Mel bins 70 | fb: npt.NDArray[np.float32] = librosa.filters.mel( 71 | sr=sample_rate, n_fft=fft_size, n_mels=n_bins 72 | ) # type: ignore # noqa 73 | 74 | self.fb: Tensor 75 | self.register_buffer( 76 | "fb", 77 | torch.tensor(fb).unsqueeze(0), 78 | ) 79 | 80 | self.freq_weights_p = freq_weights_p 81 | self.freq_weights_warmup_epochs = freq_weights_warmup_epochs 82 | self.mask = self.compute_mask(0) 83 | 84 | def update(self, current_epoch) -> None: 85 | self.mask = self.compute_mask(current_epoch) 86 | 87 | def compute_mask(self, current_epoch: int) -> Tensor: 88 | if self.freq_weights_p == 0: 89 | return torch.ones((self.n_bins, 1)) 90 | else: 91 | mask = (torch.arange(128) + 1) ** self.freq_weights_p 92 | 93 | if current_epoch > self.freq_weights_warmup_epochs: 94 | return self.mask 95 | 96 | else: 97 | mask = mask / mask.sum() * self.n_bins # Make mask sum to the same value as mask of ones 98 | alpha = 1 - current_epoch / self.freq_weights_warmup_epochs 99 | mask = alpha * torch.ones(mask.shape) + (1 - alpha) * mask 100 | return mask.unsqueeze(1) 101 | 102 | def stft(self, x: Tensor) -> tuple[Tensor, Tensor]: 103 | x_stft = torch.stft( 104 | x, 105 | self.fft_size, 106 | self.hop_size, 107 | self.win_length, 108 | self.window, 109 | return_complex=True, 110 | ) 111 | x_mag = torch.sqrt(torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)) 112 | x_phs = torch.angle(x_stft) 113 | 114 | return x_mag, x_phs 115 | 116 | def forward(self, x: Tensor, x_hat: Tensor) -> Tensor: # type: ignore 117 | self.window = self.window.to(x.device) 118 | x_mag, _ = self.stft(x.view(-1, x.size(-1))) 119 | x_hat_mag, _ = self.stft(x_hat.view(-1, x_hat.size(-1))) 120 | 121 | x_mag = torch.matmul(self.fb, x_mag) 122 | x_hat_mag = torch.matmul(self.fb, x_hat_mag) 123 | 124 | # Standardize? 
125 | # x_mag = (x_mag - x_mag.mean([1, 2], keepdim=True)) / x_mag.std([1, 2], keepdim=True) 126 | # x_hat_mag = (x_hat_mag - x_hat_mag.mean([1, 2], keepdim=True)) / x_hat_mag.std([1, 2], keepdim=True) 127 | 128 | # compute loss terms 129 | l1_corrected = (x_mag - x_hat_mag).abs() * self.mask.to(x.device) 130 | l1 = l1_corrected.mean() 131 | l2_corrected = (x_mag - x_hat_mag).pow(2) * self.mask.to(x.device) 132 | l2 = l2_corrected.mean().sqrt() 133 | 134 | return l1 + l2 135 | -------------------------------------------------------------------------------- /tests/hypersound/models/test_nerf.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | from hypersound.cfg import TargetNetworkMode 8 | from hypersound.models.nerf import NERF 9 | 10 | EPS = 1e-3 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "num_samples, encoding_length", 15 | [ 16 | (1, 1), 17 | (1, 6), 18 | (50, 5), 19 | (32768, 1), 20 | (32768, 10), 21 | (32768, 16), 22 | ], 23 | ) 24 | def test_nerf_standard_inference(num_samples: int, encoding_length: int) -> None: 25 | model = NERF( 26 | input_size=1, 27 | output_size=1, 28 | hidden_sizes=[16], 29 | bias=True, 30 | mode=TargetNetworkMode.INR, 31 | encoding_length=encoding_length, 32 | learnable_encoding=False, 33 | ) 34 | 35 | x: Tensor = torch.rand((num_samples, 1)) 36 | 37 | assert model(x, weights=None).shape == (num_samples, 1) 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "num_samples, encoding_length", 42 | [ 43 | (1, 1), 44 | (1, 6), 45 | (50, 5), 46 | (32768, 1), 47 | (32768, 10), 48 | (32768, 16), 49 | ], 50 | ) 51 | def test_nerf_train_loop(num_samples: int, encoding_length: int) -> None: 52 | model = NERF( 53 | input_size=1, 54 | output_size=1, 55 | hidden_sizes=[16], 56 | bias=True, 57 | mode=TargetNetworkMode.INR, 58 | encoding_length=encoding_length, 59 | learnable_encoding=False, 60 | ) 61 | optimizer = torch.optim.Adam(model.parameters()) 62 | loss_fn = torch.nn.L1Loss() 63 | x = torch.rand((num_samples, 1)) 64 | y = torch.rand((num_samples, 1)) 65 | y0 = model(x) 66 | 67 | for _ in range(25): 68 | optimizer.zero_grad() 69 | loss = loss_fn(model(x), y) 70 | loss.backward() 71 | optimizer.step() 72 | 73 | assert F.mse_loss(y.squeeze(), y0.squeeze()) > EPS 74 | 75 | 76 | @pytest.mark.parametrize( 77 | "num_samples, encoding_length", 78 | [ 79 | (1, 1), 80 | (1, 6), 81 | (50, 5), 82 | (32768, 1), 83 | (32768, 10), 84 | (32768, 16), 85 | ], 86 | ) 87 | def test_nerf_inr_tn_equivalence(num_samples: int, encoding_length: int) -> None: 88 | x = torch.rand((num_samples, 1)) 89 | 90 | inr = NERF( 91 | input_size=1, 92 | output_size=1, 93 | hidden_sizes=[16], 94 | bias=True, 95 | mode=TargetNetworkMode.INR, 96 | encoding_length=encoding_length, 97 | learnable_encoding=False, 98 | ) 99 | y_inr = inr(x) 100 | 101 | tn = NERF( 102 | input_size=1, 103 | output_size=1, 104 | hidden_sizes=[16], 105 | bias=True, 106 | mode=TargetNetworkMode.TARGET_NETWORK, 107 | encoding_length=encoding_length, 108 | learnable_encoding=False, 109 | ) 110 | weights = {name: param.unsqueeze(0) for name, param in inr.params.items() if param.requires_grad} 111 | y_tn = tn(x.unsqueeze(0), weights=weights) 112 | 113 | assert F.mse_loss(y_inr.squeeze(), y_tn.squeeze()) < EPS 114 | 115 | 116 | @pytest.mark.parametrize( 117 | "n_models, num_samples, input_size, output_size, hidden_sizes, encoding_length", 118 | [ 119 | (5, 32768, 1, 1, [16], 10), 
120 | (4, 32768, 3, 2, [32, 4], 2), 121 | (2, 32768, 1, 1, [10, 20, 30], 16), 122 | (1, 32768, 2, 1, [128, 12], 1), 123 | (3, 32768, 1, 2, [1, 23, 12], 10), 124 | (7, 32768, 3, 1, [4, 6, 7, 8, 5, 7], 16), 125 | (6, 32768, 3, 2, [3, 3, 2, 1, 4, 2, 1, 4, 5, 3, 1, 2, 3, 4, 1, 2], 5), 126 | ], 127 | ) 128 | def test_nerf_inr_tn_batch_equivalence( 129 | n_models: int, 130 | num_samples: int, 131 | input_size: int, 132 | output_size: int, 133 | hidden_sizes: list[int], 134 | encoding_length: int, 135 | ) -> None: 136 | x_inr = [torch.rand((num_samples, input_size)) for _ in range(n_models)] 137 | x_tn = torch.stack(x_inr, dim=0) 138 | 139 | inrs = [ 140 | NERF( 141 | input_size=input_size, 142 | output_size=output_size, 143 | hidden_sizes=hidden_sizes, 144 | bias=True, 145 | mode=TargetNetworkMode.INR, 146 | encoding_length=encoding_length, 147 | learnable_encoding=False, 148 | ) 149 | for _ in range(n_models) 150 | ] 151 | y_inr = torch.stack([model(x) for model, x in zip(inrs, x_inr)], dim=0) 152 | 153 | weights = {} 154 | for param_key in inrs[0].params.keys(): 155 | if inrs[0].params[param_key].requires_grad: 156 | weights[param_key] = torch.stack([model.params[param_key] for model in inrs], dim=0) 157 | 158 | tn = NERF( 159 | input_size=input_size, 160 | output_size=output_size, 161 | hidden_sizes=hidden_sizes, 162 | bias=True, 163 | mode=TargetNetworkMode.TARGET_NETWORK, 164 | encoding_length=encoding_length, 165 | learnable_encoding=False, 166 | ) 167 | y_tn = tn(x_tn, weights=weights) 168 | 169 | assert F.mse_loss(y_inr.squeeze(), y_tn.squeeze()) < EPS 170 | -------------------------------------------------------------------------------- /train_inr.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from pathlib import Path 3 | from typing import Optional, Type, cast 4 | 5 | import hydra 6 | import pytorch_lightning as pl 7 | import pytorch_yard 8 | import torch 9 | import torch.utils.data 10 | from dotenv import load_dotenv # type: ignore 11 | from omegaconf import OmegaConf 12 | from omegaconf.dictconfig import DictConfig 13 | from pytorch_yard.configs import get_tags 14 | from pytorch_yard.experiments.lightning import LightningExperiment 15 | from torch import Tensor 16 | from torch.utils.data import RandomSampler, TensorDataset 17 | 18 | from hypersound.cfg import Settings 19 | from hypersound.datasets.utils import init_datamodule 20 | from hypersound.utils.metrics import reduce_metric 21 | from inr.systems.main import INRSystem 22 | 23 | 24 | class SingleINRExperiment(LightningExperiment): 25 | def __init__(self, config_path: str, settings_cls: Type[Settings], settings_group: Optional[str] = None) -> None: 26 | super().__init__(config_path, settings_cls, settings_group=settings_group) 27 | 28 | self.cfg: Settings 29 | """ Experiment config. """ 30 | 31 | def entry(self, root_cfg: pytorch_yard.RootConfig): 32 | super().entry(root_cfg) 33 | 34 | # Do not use pytorch-yard template specializations as we use a monolithic `main` here. 
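    # NOTE: `main` below fits a fresh INRSystem from scratch on every validation example
    # (batch size is forced to 1 and checkpointing is disabled) and aggregates the per-example
    # metrics into `combined_metrics/*` W&B summary entries via `reduce_metric`.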
35 | def setup_system(self): 36 | pass 37 | 38 | def setup_datamodule(self): 39 | pass 40 | 41 | # ------------------------------------------------------------------------ 42 | # Experiment specific code 43 | # ------------------------------------------------------------------------ 44 | def main(self): 45 | # -------------------------------------------------------------------- 46 | # W&B init 47 | # -------------------------------------------------------------------- 48 | tags: list[str] = get_tags(cast(DictConfig, self.root_cfg)) 49 | self.run.tags = tags 50 | self.run.notes = str(self.root_cfg.notes) 51 | self.wandb_logger.log_hyperparams(OmegaConf.to_container(self.root_cfg.cfg, resolve=True)) # type: ignore 52 | 53 | # -------------------------------------------------------------------- 54 | # Data module setup 55 | # -------------------------------------------------------------------- 56 | Path(self.root_cfg.data_dir).mkdir(parents=True, exist_ok=True) 57 | 58 | self.root_cfg.cfg = cast(Settings, self.root_cfg.cfg) 59 | self.root_cfg.cfg.batch_size = 1 60 | self.root_cfg.cfg.save_checkpoints = False 61 | 62 | self.datamodule, _ = init_datamodule(self.root_cfg) 63 | self.datamodule.prepare_data() 64 | self.datamodule.setup() 65 | 66 | # -------------------------------------------------------------------- 67 | # Trainer setup 68 | # -------------------------------------------------------------------- 69 | self.setup_callbacks() 70 | 71 | steps_to_log = range(self.cfg.log.examples) 72 | combined_metrics: list[dict[str, Tensor]] = [] 73 | 74 | for i, (indices, audio, spectrograms) in enumerate(self.datamodule.val_dataloader()): 75 | callbacks = copy.deepcopy(self.callbacks) 76 | 77 | self.trainer: pl.Trainer = hydra.utils.instantiate( # type: ignore 78 | self.cfg.pl, 79 | logger=self.wandb_logger, 80 | callbacks=callbacks, 81 | enable_checkpointing=False, 82 | num_sanity_val_steps=0, 83 | ) 84 | 85 | indices = torch.cat([indices]) 86 | audio = torch.cat([audio]) 87 | spectrograms = torch.cat([spectrograms]) 88 | dataset = TensorDataset(indices, audio, spectrograms) 89 | 90 | dataloader = torch.utils.data.DataLoader( 91 | dataset, 92 | batch_size=1, 93 | sampler=RandomSampler(dataset, replacement=True, num_samples=self.cfg.data.samples_per_epoch), 94 | num_workers=self.cfg.data.num_workers, 95 | ) 96 | 97 | log_reconstruction = i in steps_to_log 98 | 99 | self.system = INRSystem( 100 | cfg=self.cfg, 101 | spec_transform=copy.deepcopy(self.datamodule.train.spec_transform), 102 | idx=i, 103 | extended_logging=log_reconstruction, 104 | ) 105 | self.trainer.fit( # type: ignore 106 | self.system, 107 | train_dataloaders=dataloader, 108 | ckpt_path=None, 109 | ) 110 | combined_metrics.append(self.system.metrics) 111 | 112 | assert isinstance(self.system, INRSystem) 113 | self.wandb_logger.experiment.summary["combined_metrics/compression_ratio"] = self.system.compression_ratio() # type: ignore # noqa 114 | self.wandb_logger.experiment.summary["combined_metrics/inr_idx"] = i + 1 # type: ignore 115 | for key in combined_metrics[0]: 116 | self.wandb_logger.experiment.summary[f"combined_metrics/{key}"] = reduce_metric(combined_metrics, key) # type: ignore # noqa 117 | 118 | 119 | if __name__ == "__main__": 120 | load_dotenv(".env.inr", verbose=True, override=True) 121 | SingleINRExperiment("hypersound", Settings) 122 | -------------------------------------------------------------------------------- /rave/cli_helper.py: -------------------------------------------------------------------------------- 
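# Interactive helper: asks a few questions about the dataset and desired configuration, then
# prints (and saves to `instruction_{name}.txt`) the full sequence of train/export commands
# for RAVE, the prior model, and the realtime (cached-convolution) exports.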
1 | from os import path 2 | import shutil 3 | 4 | 5 | class Print: 6 | def __init__(self): 7 | self.msg = "" 8 | 9 | def __call__(self, msg, end="\n"): 10 | print(msg, end=end) 11 | self.msg += msg + end 12 | 13 | 14 | p = Print() 15 | 16 | 17 | def header(title: str, out=p): 18 | term_size = shutil.get_terminal_size().columns 19 | term_size = min(120, term_size) 20 | pad_l = (term_size - len(title)) // 2 21 | 22 | out(term_size * "=") 23 | out(pad_l * " " + title.upper()) 24 | out(term_size * "=") 25 | out("") 26 | 27 | 28 | def subsection(title: str): 29 | p("") 30 | size = len(title) 31 | p(title.upper()) 32 | p(size * "-") 33 | p("") 34 | 35 | 36 | if __name__ == "__main__": 37 | header("rave command line helper", out=print) 38 | 39 | name = "" 40 | while not name: 41 | name = input("choose a name for the training: ") 42 | name = name.lower().replace(" ", "_") 43 | 44 | data = "" 45 | while not data: 46 | data = input("path to the .wav files: ") 47 | 48 | preprocessed = "" 49 | while not preprocessed: 50 | preprocessed = input("temporary folder (fast drive): ") 51 | 52 | sampling_rate = input("sampling rate (defaults to 48000): ") 53 | multiband_number = input("multiband number (defaults to 16): ") 54 | n_signal = input("training example duration (defaults to 65536 samples): ") 55 | 56 | configurations = ["default", "small", "large"] 57 | configuration = "" 58 | while not configuration: 59 | conf = input("configuration (default, small, large): ") 60 | if conf in configurations: 61 | configuration = conf 62 | else: 63 | print(f"configuration {conf} not understood.") 64 | 65 | prior_resolution = input("prior resolution (defaults to 32): ") 66 | fidelity = input("reconstruction fidelity (defaults to 0.95): ") 67 | no_latency = input("latency compensation (defaults to false): ") 68 | latent_size = input("latent size (learned if left blank): ") 69 | pressure = input("regularization strength (defaults to 0.1): ") 70 | 71 | header(f"{name}: training instructions") 72 | subsection("train rave") 73 | 74 | p("Train rave (both training stages are included)") 75 | p("") 76 | 77 | cmd = "python train_rave.py " 78 | cmd += f"-c {configuration} " 79 | cmd += f"--name {name} " 80 | cmd += f"--wav {data} " 81 | prep_rave = path.join(preprocessed, name, "rave") 82 | cmd += f"--preprocessed {prep_rave} " 83 | 84 | if sampling_rate: 85 | cmd += f"--sr {sampling_rate} " 86 | if multiband_number: 87 | cmd += f"--data-size {multiband_number} " 88 | if n_signal: 89 | cmd += f"--n-signal {n_signal} " 90 | if no_latency: 91 | cmd += f"--no-latency {no_latency.lower()} " 92 | if latent_size: 93 | cmd += f"--cropped-latent-size {int(latent_size)} " 94 | if pressure: 95 | cmd += f"--max-kl {pressure} " 96 | 97 | p(cmd) 98 | p("") 99 | 100 | p("You can follow the training using tensorboard") 101 | p("") 102 | p("tensorboard --logdir . 
--bind_all") 103 | p("") 104 | p("Once the training has reached a satisfactory state, kill it (ctrl + C)") 105 | 106 | subsection("train prior") 107 | 108 | p( 109 | f"Export the latent space trained on {name}.", 110 | end="\n\n", 111 | ) 112 | 113 | cmd = "python export_rave.py " 114 | 115 | run = path.join("runs", name, "rave") 116 | cmd += f"--run {run} " 117 | cmd += f"--cached false " 118 | if fidelity: 119 | cmd += f"--fidelity {fidelity} " 120 | cmd += f"--name {name}" 121 | 122 | p(cmd) 123 | p("") 124 | 125 | p( 126 | f"Train the prior model.", 127 | end="\n\n", 128 | ) 129 | 130 | cmd = "python train_prior.py " 131 | 132 | if prior_resolution: 133 | cmd += f"--resolution {prior_resolution} " 134 | 135 | cmd += f"--pretrained-vae rave_{name}.ts " 136 | prep_prior = path.join(preprocessed, name, "prior") 137 | cmd += f"--preprocessed {prep_prior} " 138 | cmd += f"--wav {data} " 139 | 140 | if n_signal: 141 | cmd += f"--n-signal {n_signal} " 142 | 143 | cmd += f"--name {name}" 144 | 145 | p(cmd) 146 | p("") 147 | p("Once the training has reached a satisfactory state, kill it (ctrl + C)") 148 | 149 | p("") 150 | header("export to max msp (coming soon)") 151 | 152 | p("In order to use both **rave** and the **prior** model inside max/msp, we have to export them using **cached convolutions**." 153 | ) 154 | p("") 155 | 156 | cmd = "python export_rave.py " 157 | 158 | run = path.join("runs", name, "rave") 159 | cmd += f"--run {run} " 160 | cmd += f"--cached true " 161 | if fidelity: 162 | cmd += f"--fidelity {fidelity} " 163 | cmd += f"--name {name}_rt" 164 | p(cmd) 165 | cmd = "python export_prior.py " 166 | 167 | run = path.join("runs", name, "prior") 168 | cmd += f"--run {run} " 169 | cmd += f"--name {name}_rt" 170 | 171 | p(cmd) 172 | 173 | cmd = "python combine_models.py " 174 | cmd += f"--prior prior_{name}_rt.ts " 175 | cmd += f"--rave rave_{name}_rt.ts " 176 | cmd += f"--name {name}" 177 | 178 | p(cmd) 179 | 180 | with open(f"instruction_{name}.txt", "w") as out: 181 | out.write(p.msg) -------------------------------------------------------------------------------- /rave/prior/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | from tqdm import tqdm 5 | 6 | from .residual_block import ResidualBlock 7 | from .core import DiagonalShift, QuantizedNormal 8 | 9 | import cached_conv as cc 10 | 11 | 12 | class Model(pl.LightningModule): 13 | 14 | def __init__(self, resolution, res_size, skp_size, kernel_size, cycle_size, 15 | n_layers, pretrained_vae): 16 | super().__init__() 17 | self.save_hyperparameters() 18 | 19 | self.diagonal_shift = DiagonalShift() 20 | self.quantized_normal = QuantizedNormal(resolution) 21 | 22 | self.synth = torch.jit.load(pretrained_vae) 23 | self.sr = self.synth.sampling_rate.item() 24 | data_size = self.synth.cropped_latent_size 25 | 26 | self.pre_net = nn.Sequential( 27 | cc.Conv1d( 28 | resolution * data_size, 29 | res_size, 30 | kernel_size, 31 | padding=cc.get_padding(kernel_size, mode="causal"), 32 | groups=data_size, 33 | ), 34 | nn.LeakyReLU(.2), 35 | ) 36 | 37 | self.residuals = nn.ModuleList([ 38 | ResidualBlock( 39 | res_size, 40 | skp_size, 41 | kernel_size, 42 | 2**(i % cycle_size), 43 | ) for i in range(n_layers) 44 | ]) 45 | 46 | self.post_net = nn.Sequential( 47 | cc.Conv1d(skp_size, skp_size, 1), 48 | nn.LeakyReLU(.2), 49 | cc.Conv1d( 50 | skp_size, 51 | resolution * data_size, 52 | 1, 53 | groups=data_size, 54 | ), 55 | ) 
56 | 57 | self.data_size = data_size 58 | 59 | self.val_idx = 0 60 | 61 | def configure_optimizers(self): 62 | p = [] 63 | p.extend(list(self.pre_net.parameters())) 64 | p.extend(list(self.residuals.parameters())) 65 | p.extend(list(self.post_net.parameters())) 66 | return torch.optim.Adam(p, lr=1e-4) 67 | 68 | @torch.no_grad() 69 | def encode(self, x): 70 | self.synth.eval() 71 | return self.synth.encode(x) 72 | 73 | @torch.no_grad() 74 | def decode(self, z): 75 | self.synth.eval() 76 | return self.synth.decode(z) 77 | 78 | def forward(self, x): 79 | res = self.pre_net(x) 80 | skp = torch.tensor(0.).to(x) 81 | for layer in self.residuals: 82 | res, skp = layer(res, skp) 83 | x = self.post_net(skp) 84 | return x 85 | 86 | @torch.no_grad() 87 | def generate(self, x, argmax: bool = False): 88 | for i in tqdm(range(x.shape[-1] - 1)): 89 | if cc.USE_BUFFER_CONV: 90 | start = i 91 | else: 92 | start = None 93 | 94 | pred = self.forward(x[..., start:i + 1]) 95 | 96 | if not cc.USE_BUFFER_CONV: 97 | pred = pred[..., -1:] 98 | 99 | pred = self.post_process_prediction(pred, argmax=argmax) 100 | 101 | x[..., i + 1:i + 2] = pred 102 | return x 103 | 104 | def split_classes(self, x): 105 | # B x D*C x T 106 | x = x.permute(0, 2, 1) 107 | x = x.reshape(x.shape[0], x.shape[1], self.data_size, -1) 108 | x = x.permute(0, 2, 1, 3) # B x D x T x C 109 | return x 110 | 111 | def post_process_prediction(self, x, argmax: bool = False): 112 | x = self.split_classes(x) 113 | shape = x.shape[:-1] 114 | x = x.reshape(-1, x.shape[-1]) 115 | 116 | if argmax: 117 | x = torch.argmax(x, -1) 118 | else: 119 | x = torch.softmax(x - torch.logsumexp(x, -1, keepdim=True), -1) 120 | x = torch.multinomial(x, 1, True).squeeze(-1) 121 | 122 | x = x.reshape(shape[0], shape[1], shape[2]) 123 | x = self.quantized_normal.to_stack_one_hot(x) 124 | return x 125 | 126 | def training_step(self, batch, batch_idx): 127 | x = self.encode(batch) 128 | x = self.quantized_normal.encode(self.diagonal_shift(x)) 129 | pred = self.forward(x) 130 | 131 | x = torch.argmax(self.split_classes(x[..., 1:]), -1) 132 | pred = self.split_classes(pred[..., :-1]) 133 | 134 | loss = nn.functional.cross_entropy( 135 | pred.reshape(-1, self.quantized_normal.resolution), 136 | x.reshape(-1), 137 | ) 138 | 139 | self.log("latent_prediction", loss) 140 | return loss 141 | 142 | def validation_step(self, batch, batch_idx): 143 | x = self.encode(batch) 144 | x = self.quantized_normal.encode(self.diagonal_shift(x)) 145 | pred = self.forward(x) 146 | 147 | x = torch.argmax(self.split_classes(x[..., 1:]), -1) 148 | pred = self.split_classes(pred[..., :-1]) 149 | 150 | loss = nn.functional.cross_entropy( 151 | pred.reshape(-1, self.quantized_normal.resolution), 152 | x.reshape(-1), 153 | ) 154 | 155 | self.log("validation", loss) 156 | return batch 157 | 158 | def validation_epoch_end(self, out): 159 | x = torch.randn_like(self.encode(out[0])) 160 | x = self.quantized_normal.encode(self.diagonal_shift(x)) 161 | z = self.generate(x) 162 | z = self.diagonal_shift.inverse(self.quantized_normal.decode(z)) 163 | 164 | y = self.decode(z) 165 | self.logger.experiment.add_audio( 166 | "generation", 167 | y.reshape(-1), 168 | self.val_idx, 169 | self.synth.sampling_rate.item(), 170 | ) 171 | self.val_idx += 1 172 | -------------------------------------------------------------------------------- /hypersound/cfg/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from enum import Enum, 
auto 3 | from typing import Optional 4 | 5 | from pytorch_yard.configs.cfg.lightning import LightningConf, LightningSettings 6 | 7 | 8 | class Dataset(Enum): 9 | VCTK = auto() 10 | LJS = auto() 11 | GTZAN = auto() 12 | LIBRITTS = auto() 13 | LIBRISPEECH = auto() 14 | 15 | 16 | class ModelType(Enum): 17 | NERF = auto() 18 | SIREN = auto() 19 | 20 | 21 | class TargetNetworkMode(Enum): 22 | INR = auto() 23 | TARGET_NETWORK = auto() 24 | MODULATOR = auto() 25 | RESIDUAL = auto() 26 | 27 | 28 | class LRScheduler(Enum): 29 | ONE_CYCLE = auto() 30 | CONSTANT = auto() 31 | 32 | 33 | @dataclass 34 | class DataSettings: 35 | num_workers: int = 8 36 | 37 | duration: float = 1.4861 # 1.4861 @ 22050 Hz --> 2^15 (32768) samples 38 | 39 | sample_rate: int = 22050 40 | n_fft: int = 1024 41 | hop_length: int = 256 42 | n_mels: int = 128 43 | data_norm: bool = True 44 | 45 | samples_per_epoch: int = 10_000 46 | 47 | file_limit: Optional[int] = None # Limit on files used in training 48 | file_limit_validation: Optional[int] = 512 49 | 50 | # Boundaries for target network input indices. If provided, will rescale indices from [0-n_samples] to this value. 51 | # Setting value is normalized for a clip duration of 1.0 second if `proportional_index_scaling` is enabled. 52 | # A range of [-m, m] will use basal frequencies of up to m/π Hz, i.e. +/-300 ~ 100 Hz. 53 | index_scaling: Optional[tuple[float, float]] = (-1, 1) 54 | proportional_index_scaling: bool = False 55 | 56 | 57 | @dataclass 58 | class TransformSettings: 59 | start_offset: float = 0.5 60 | padding: bool = False 61 | dequantize: bool = True 62 | phase_mangle: bool = True 63 | random_crop: bool = True 64 | 65 | 66 | @dataclass 67 | class ModelSettings: 68 | type: ModelType = ModelType.NERF 69 | target_network_mode: TargetNetworkMode = TargetNetworkMode.TARGET_NETWORK 70 | 71 | layer_norm: bool = True 72 | encoder_channels: int = 16 73 | embedding_size: int = 32 74 | 75 | reconstruction_loss_lambda: float = 1.0 76 | 77 | activations_decay_loss_lambda: float = 0.0 78 | 79 | perceptual_loss_lambda: float = 1.0 80 | perceptual_loss_decay_epochs: Optional[int] = None 81 | perceptual_loss_final_lambda: float = 1.0 82 | perceptual_loss_freq_weights_p: float = 0.0 83 | perceptual_loss_freq_weights_warmup_epochs: int = 500 84 | 85 | fft_sizes: list[int] = field(default_factory=lambda: [2048, 1024, 512, 256, 128]) 86 | hop_sizes: list[int] = field(default_factory=lambda: [512, 256, 128, 64, 32]) 87 | win_lengths: list[int] = field(default_factory=lambda: [2048, 1024, 512, 256, 128]) 88 | 89 | hypernetwork_layer_sizes: list[int] = field(default_factory=lambda: [400, 768, 768, 768, 768, 400]) 90 | 91 | target_network_shared_layers: Optional[list[int]] = None 92 | target_network_layer_sizes: list[int] = field(default_factory=lambda: 5 * [90]) 93 | 94 | target_network_omega_0: float = 1.0 95 | target_network_omega_i: float = 1.0 96 | target_network_share_omega: bool = True 97 | target_network_learnable_omega: bool = False 98 | target_network_siren_gradient_fix: bool = True 99 | 100 | target_network_share_encoding: bool = True 101 | target_network_encoding_length: int = 10 102 | target_network_learnable_encoding: bool = False 103 | 104 | 105 | @dataclass 106 | class LogSettings: 107 | examples: int = 5 # Number of examples to plot to wandb during training/validation 108 | warmup_epochs: int = 100 109 | warmup_every_n_epoch: int = 1 # Controls how often logging is done in the initial training phase 110 | normal_every_n_epoch: int = 50 # Controls how often logging is 
done after warmup epochs have passed 111 | debug: bool = False 112 | 113 | 114 | # Experiment settings validation schema & default values 115 | @dataclass 116 | class Settings(LightningSettings): 117 | # ---------------------------------------------------------------------------------------------- 118 | # General experiment settings 119 | # ---------------------------------------------------------------------------------------------- 120 | batch_size: int = 16 121 | scheduler: LRScheduler = LRScheduler.ONE_CYCLE 122 | learning_rate: float = 0.0001 123 | learning_rate_div_factor: float = 25.0 # or 25.0 124 | 125 | # ---------------------------------------------------------------------------------------------- 126 | # Experiment logging settings 127 | # ---------------------------------------------------------------------------------------------- 128 | log: LogSettings = field(default_factory=lambda: LogSettings()) 129 | 130 | # ---------------------------------------------------------------------------------------------- 131 | # Data settings 132 | # ---------------------------------------------------------------------------------------------- 133 | dataset: Dataset = Dataset.VCTK 134 | data: DataSettings = field(default_factory=lambda: DataSettings()) 135 | transforms: TransformSettings = field(default_factory=lambda: TransformSettings()) 136 | 137 | # ---------------------------------------------------------------------------------------------- 138 | # Model settings 139 | # ---------------------------------------------------------------------------------------------- 140 | model: ModelSettings = field(default_factory=lambda: ModelSettings()) 141 | 142 | # ---------------------------------------------------------------------------------------------- 143 | # PyTorch Lightning overrides 144 | # ---------------------------------------------------------------------------------------------- 145 | pl: LightningConf = LightningConf( 146 | max_epochs=2500, 147 | check_val_every_n_epoch=100, 148 | deterministic=False, 149 | ) 150 | 151 | validate_before_training: bool = False 152 | -------------------------------------------------------------------------------- /rave/train_rave.py: -------------------------------------------------------------------------------- 1 | # fmt: off 2 | import os 3 | from os import environ, path 4 | 5 | import GPUtil as gpu 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | import torch 9 | from effortless_config import Config, setting 10 | from rave.core import EMAModelCheckPoint, random_phase_mangle, search_for_run 11 | from rave.model import RAVE 12 | from torch.utils.data import DataLoader, random_split 13 | from udls import SimpleDataset, simple_audio_preprocess 14 | from udls.transforms import Compose, Dequantize, RandomApply, RandomCrop 15 | 16 | if __name__ == "__main__": 17 | 18 | class args(Config): 19 | groups = ["small", "large"] 20 | 21 | DATA_SIZE = 16 22 | CAPACITY = setting(default=64, small=32, large=64) 23 | LATENT_SIZE = 128 24 | BIAS = True 25 | NO_LATENCY = False 26 | RATIOS = setting( 27 | default=[4, 4, 4, 2], 28 | small=[4, 4, 4, 2], 29 | large=[4, 4, 2, 2, 2], 30 | ) 31 | 32 | MIN_KL = 1e-1 33 | MAX_KL = 1e-1 34 | CROPPED_LATENT_SIZE = 0 35 | FEATURE_MATCH = True 36 | 37 | LOUD_STRIDE = 1 38 | 39 | USE_NOISE = True 40 | NOISE_RATIOS = [4, 4, 4] 41 | NOISE_BANDS = 5 42 | 43 | D_CAPACITY = 16 44 | D_MULTIPLIER = 4 45 | D_N_LAYERS = 4 46 | 47 | WARMUP = setting(default=1000000, small=1000000, large=3000000) 48 | MODE = "hinge" 49 | CKPT = 
None 50 | 51 | PREPROCESSED = None 52 | WAV = None 53 | WAV_VAL = None 54 | SR = 48000 55 | N_SIGNAL = 65536 56 | MAX_STEPS = setting(default=3000000, small=3000000, large=6000000) 57 | VAL_EVERY = 10000 58 | 59 | BATCH = 8 60 | 61 | NAME = None 62 | 63 | args.parse_args() 64 | 65 | assert args.NAME is not None 66 | model = RAVE( 67 | data_size=args.DATA_SIZE, 68 | capacity=args.CAPACITY, 69 | latent_size=args.LATENT_SIZE, 70 | ratios=args.RATIOS, 71 | bias=args.BIAS, 72 | loud_stride=args.LOUD_STRIDE, 73 | use_noise=args.USE_NOISE, 74 | noise_ratios=args.NOISE_RATIOS, 75 | noise_bands=args.NOISE_BANDS, 76 | d_capacity=args.D_CAPACITY, 77 | d_multiplier=args.D_MULTIPLIER, 78 | d_n_layers=args.D_N_LAYERS, 79 | warmup=args.WARMUP, 80 | mode=args.MODE, 81 | no_latency=args.NO_LATENCY, 82 | sr=args.SR, 83 | min_kl=args.MIN_KL, 84 | max_kl=args.MAX_KL, 85 | cropped_latent_size=args.CROPPED_LATENT_SIZE, 86 | feature_match=args.FEATURE_MATCH, 87 | ) 88 | 89 | x = torch.zeros(args.BATCH, 2**14) 90 | model.validation_step(x, 0) 91 | 92 | preprocess = lambda name: simple_audio_preprocess(args.SR, 2 * args.N_SIGNAL,)( 93 | name 94 | ).astype(np.float16) 95 | 96 | def get_dataset(out_folder, folder): 97 | return SimpleDataset( 98 | out_folder, 99 | folder, 100 | preprocess_function=preprocess, 101 | split_set="full", 102 | transforms=Compose([ 103 | lambda x: x.astype(np.float32), 104 | RandomCrop(args.N_SIGNAL), 105 | RandomApply( 106 | lambda x: random_phase_mangle(x, 20, 2000, .99, args.SR), 107 | p=.8, 108 | ), 109 | Dequantize(16), 110 | lambda x: x.astype(np.float32), 111 | ]), 112 | ) 113 | 114 | train = get_dataset(path.join(args.PREPROCESSED, "train"), args.WAV) 115 | val = get_dataset(path.join(args.PREPROCESSED + "val"), args.WAV_VAL) 116 | 117 | # val = max((2 * len(dataset)) // 100, 1) 118 | # train = len(dataset) - val 119 | # train, val = random_split( 120 | # dataset, 121 | # [train, val], 122 | # generator=torch.Generator().manual_seed(42), 123 | # ) 124 | 125 | num_workers = 0 if os.name == "nt" else 8 126 | train = DataLoader(train, args.BATCH, True, drop_last=True, num_workers=num_workers) 127 | val = DataLoader(val, args.BATCH, False, num_workers=num_workers) 128 | 129 | # CHECKPOINT CALLBACKS 130 | validation_checkpoint = pl.callbacks.ModelCheckpoint( 131 | monitor="validation", 132 | filename="best", 133 | ) 134 | last_checkpoint = pl.callbacks.ModelCheckpoint(filename="last") 135 | 136 | CUDA = gpu.getAvailable(maxMemory=0.05) 137 | VISIBLE_DEVICES = environ.get("CUDA_VISIBLE_DEVICES", "") 138 | 139 | if VISIBLE_DEVICES: 140 | use_gpu = int(int(VISIBLE_DEVICES) >= 0) 141 | elif len(CUDA): 142 | environ["CUDA_VISIBLE_DEVICES"] = str(CUDA[0]) 143 | use_gpu = 1 144 | elif torch.cuda.is_available(): 145 | print("Cuda is available but no fully free GPU found.") 146 | print("Training may be slower due to concurrent processes.") 147 | use_gpu = 1 148 | else: 149 | print("No GPU found.") 150 | use_gpu = 0 151 | 152 | val_check = {} 153 | if len(train) >= args.VAL_EVERY: 154 | val_check["val_check_interval"] = args.VAL_EVERY 155 | else: 156 | nepoch = args.VAL_EVERY // len(train) 157 | val_check["check_val_every_n_epoch"] = nepoch 158 | 159 | trainer = pl.Trainer( 160 | logger=pl.loggers.TensorBoardLogger(path.join("runs", args.NAME), 161 | name="rave"), 162 | gpus=use_gpu, 163 | callbacks=[validation_checkpoint, last_checkpoint], 164 | max_epochs=100000, 165 | max_steps=args.MAX_STEPS, 166 | **val_check, 167 | ) 168 | 169 | run = search_for_run(args.CKPT, mode="last") 170 | if run is 
None: run = search_for_run(args.CKPT, mode="best") 171 | if run is not None: 172 | step = torch.load(run, map_location='cpu')["global_step"] 173 | trainer.fit_loop.epoch_loop._batches_that_stepped = step 174 | 175 | trainer.fit(model, train, val, ckpt_path=run) 176 | -------------------------------------------------------------------------------- /rave/rave/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.fft as fft 4 | from einops import rearrange 5 | import numpy as np 6 | from random import random 7 | from scipy.signal import lfilter 8 | from pytorch_lightning.callbacks import ModelCheckpoint 9 | import librosa as li 10 | from pathlib import Path 11 | 12 | 13 | def mod_sigmoid(x): 14 | return 2 * torch.sigmoid(x)**2.3 + 1e-7 15 | 16 | 17 | def multiscale_stft(signal, scales, overlap): 18 | """ 19 | Compute a stft on several scales, with a constant overlap value. 20 | Parameters 21 | ---------- 22 | signal: torch.Tensor 23 | input signal to process ( B X C X T ) 24 | 25 | scales: list 26 | scales to use 27 | overlap: float 28 | overlap between windows ( 0 - 1 ) 29 | """ 30 | signal = rearrange(signal, "b c t -> (b c) t") 31 | stfts = [] 32 | for s in scales: 33 | S = torch.stft( 34 | signal, 35 | s, 36 | int(s * (1 - overlap)), 37 | s, 38 | torch.hann_window(s).to(signal), 39 | True, 40 | normalized=True, 41 | return_complex=True, 42 | ).abs() 43 | stfts.append(S) 44 | return stfts 45 | 46 | 47 | def random_angle(min_f=20, max_f=8000, sr=24000): 48 | min_f = np.log(min_f) 49 | max_f = np.log(max_f) 50 | rand = np.exp(random() * (max_f - min_f) + min_f) 51 | rand = 2 * np.pi * rand / sr 52 | return rand 53 | 54 | 55 | def pole_to_z_filter(omega, amplitude=.9): 56 | z0 = amplitude * np.exp(1j * omega) 57 | a = [1, -2 * np.real(z0), abs(z0)**2] 58 | b = [abs(z0)**2, -2 * np.real(z0), 1] 59 | return b, a 60 | 61 | 62 | def random_phase_mangle(x, min_f, max_f, amp, sr): 63 | angle = random_angle(min_f, max_f, sr) 64 | b, a = pole_to_z_filter(angle, amp) 65 | return lfilter(b, a, x) 66 | 67 | 68 | class EMAModelCheckPoint(ModelCheckpoint): 69 | 70 | def __init__(self, model: torch.nn.Module, alpha=.999, *args, **kwargs): 71 | super().__init__(*args, **kwargs) 72 | 73 | self.shadow = {} 74 | for n, p in model.named_parameters(): 75 | if p.requires_grad: 76 | self.shadow[n] = p.data.clone() 77 | self.model = model 78 | self.alpha = alpha 79 | 80 | def on_train_batch_end(self, *args, **kwargs): 81 | with torch.no_grad(): 82 | for n, p in self.model.named_parameters(): 83 | if n in self.shadow: 84 | self.shadow[n] *= self.alpha 85 | self.shadow[n] += (1 - self.alpha) * p.data 86 | 87 | def on_validation_epoch_start(self, *args, **kwargs): 88 | self.swap() 89 | 90 | def on_validation_epoch_end(self, *args, **kwargs): 91 | self.swap() 92 | 93 | def swap(self): 94 | for n, p in self.model.named_parameters(): 95 | if n in self.shadow: 96 | tmp = p.data.clone() 97 | p.data.copy_(self.shadow[n]) 98 | self.shadow[n] = tmp 99 | 100 | def save_checkpoint(self, *args, **kwargs): 101 | self.swap() 102 | super().save_checkpoint(*args, **kwargs) 103 | self.swap() 104 | 105 | 106 | class Loudness(nn.Module): 107 | 108 | def __init__(self, sr, block_size, n_fft=2048): 109 | super().__init__() 110 | self.sr = sr 111 | self.block_size = block_size 112 | self.n_fft = n_fft 113 | 114 | f = np.linspace(0, sr / 2, n_fft // 2 + 1) + 1e-7 115 | a_weight = li.A_weighting(f).reshape(-1, 1) 116 | 117 | 
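        # A-weighting curve (in dB) evaluated at each FFT bin centre; `forward` adds it to the
        # log-magnitude STFT so the per-frame loudness estimate follows perceptual weighting.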
self.register_buffer("a_weight", torch.from_numpy(a_weight).float()) 118 | self.register_buffer("window", torch.hann_window(self.n_fft)) 119 | 120 | def forward(self, x): 121 | x = torch.stft( 122 | x.squeeze(1), 123 | self.n_fft, 124 | self.block_size, 125 | self.n_fft, 126 | center=True, 127 | window=self.window, 128 | return_complex=True, 129 | ).abs() 130 | x = torch.log(x + 1e-7) + self.a_weight 131 | return torch.mean(x, 1, keepdim=True) 132 | 133 | 134 | def amp_to_impulse_response(amp, target_size): 135 | """ 136 | transforms frequecny amps to ir on the last dimension 137 | """ 138 | amp = torch.stack([amp, torch.zeros_like(amp)], -1) 139 | amp = torch.view_as_complex(amp) 140 | amp = fft.irfft(amp) 141 | 142 | filter_size = amp.shape[-1] 143 | 144 | amp = torch.roll(amp, filter_size // 2, -1) 145 | win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device) 146 | 147 | amp = amp * win 148 | 149 | amp = nn.functional.pad( 150 | amp, 151 | (0, int(target_size) - int(filter_size)), 152 | ) 153 | amp = torch.roll(amp, -filter_size // 2, -1) 154 | 155 | return amp 156 | 157 | 158 | def fft_convolve(signal, kernel): 159 | """ 160 | convolves signal by kernel on the last dimension 161 | """ 162 | signal = nn.functional.pad(signal, (0, signal.shape[-1])) 163 | kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0)) 164 | 165 | output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel)) 166 | output = output[..., output.shape[-1] // 2:] 167 | 168 | return output 169 | 170 | 171 | def search_for_run(run_path, mode="last"): 172 | if run_path is None: return None 173 | if ".ckpt" in run_path: return run_path 174 | ckpts = map(str, Path(run_path).rglob("*.ckpt")) 175 | ckpts = filter(lambda e: mode in e, ckpts) 176 | ckpts = sorted(ckpts) 177 | if len(ckpts): return ckpts[-1] 178 | else: return None 179 | 180 | 181 | def get_beta_kl(step, warmup, min_beta, max_beta): 182 | if step > warmup: return max_beta 183 | t = step / warmup 184 | min_beta_log = np.log(min_beta) 185 | max_beta_log = np.log(max_beta) 186 | beta_log = t * (max_beta_log - min_beta_log) + min_beta_log 187 | return np.exp(beta_log) 188 | 189 | 190 | def get_beta_kl_cyclic(step, cycle_size, min_beta, max_beta): 191 | return get_beta_kl(step % cycle_size, cycle_size // 2, min_beta, max_beta) 192 | 193 | 194 | def get_beta_kl_cyclic_annealed(step, cycle_size, warmup, min_beta, max_beta): 195 | min_beta = get_beta_kl(step, warmup, min_beta, max_beta) 196 | return get_beta_kl_cyclic(step, cycle_size, min_beta, max_beta) 197 | -------------------------------------------------------------------------------- /hypersound/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | import librosa 4 | import torch 5 | from cdpam import CDPAM 6 | from pesq import NoUtterancesError # type: ignore 7 | from torch import Tensor 8 | from torchmetrics import Metric 9 | from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality 10 | from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio 11 | from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility 12 | from torchmetrics.image.psnr import PeakSignalNoiseRatio 13 | from torchmetrics.regression.mae import MeanAbsoluteError 14 | from torchmetrics.regression.mse import MeanSquaredError 15 | 16 | METRICS = [ 17 | "MAE", 18 | "MSE", 19 | "SNR", 20 | "SI-SNR", 21 | "PSNR", 22 | "LSD", 23 | "PESQ", 24 | "STOI", 25 | "CDPAM", 26 | ] 27 | 28 | 29 | def 
reduce_metric(outputs: list[dict[str, Tensor]], key: str) -> float: 30 | try: 31 | values = [out[key] for out in outputs if not out[key].isnan()] 32 | except KeyError: 33 | return float("nan") 34 | if not values: 35 | return float("nan") 36 | return float(torch.stack(values).mean().detach()) 37 | 38 | 39 | def _get_metric(is_active: bool, metric: Type[Metric], preds: Tensor, target: Tensor) -> Tensor: 40 | if not is_active: 41 | return torch.tensor(float("nan")) 42 | 43 | _metric = metric(full_state_update=False).to(preds.device) # type: ignore 44 | _metric.update(preds=preds, target=target) # type: ignore 45 | return _metric.compute() # type: ignore 46 | 47 | 48 | def _resample(signal: Tensor, orig_sr: int, target_sr: int): 49 | device = signal.device 50 | 51 | signal = torch.tensor( 52 | librosa.resample( # type: ignore 53 | signal.detach().cpu().numpy(), 54 | orig_sr=orig_sr, 55 | target_sr=target_sr, 56 | ) 57 | ) 58 | 59 | return signal.to(device) 60 | 61 | 62 | def compute_metrics( 63 | preds: Tensor, 64 | target: Tensor, 65 | sample_rate: int, 66 | *, 67 | mae: bool = True, 68 | mse: bool = True, 69 | snr: bool = True, 70 | si_snr: bool = True, 71 | psnr: bool = True, 72 | lsd: bool = True, 73 | pesq: bool = False, 74 | stoi: bool = False, 75 | cdpam: bool = False, 76 | ) -> dict[str, Tensor]: 77 | 78 | results: dict[str, Tensor] = {} 79 | 80 | ScaleInvariantSignalNoiseRatio.full_state_update = False 81 | 82 | results["MSE"] = _get_metric(mse, MeanSquaredError, preds, target).detach() 83 | results["MAE"] = _get_metric(mae, MeanAbsoluteError, preds, target).detach() 84 | results["LSD"] = log_spectral_distance(lsd, preds, target, sample_rate=sample_rate).detach() 85 | results["SNR"] = _get_metric(snr, SignalNoiseRatio, preds, target).detach() 86 | results["PSNR"] = _get_metric(psnr, PeakSignalNoiseRatio, preds, target).detach() 87 | results["SI-SNR"] = _get_metric(si_snr, ScaleInvariantSignalNoiseRatio, preds, target).detach() 88 | 89 | results["PESQ"] = torch.tensor(float("nan")) 90 | if pesq: 91 | preds_16k = torch.stack([_resample(y, sample_rate, 16000) for y in preds]) 92 | target_16k = torch.stack([_resample(y, sample_rate, 16000) for y in target]) 93 | 94 | try: 95 | _pesq = PerceptualEvaluationSpeechQuality(fs=16000, mode="wb") 96 | _pesq.update(preds=preds_16k, target=target_16k) 97 | results["PESQ"] = _pesq.compute().detach() 98 | except NoUtterancesError: 99 | print("PESQ computation failed, skipping batch - no utterances found.") 100 | results["PESQ"] = torch.tensor(float("nan")) 101 | 102 | results["STOI"] = torch.tensor(float("nan")) 103 | if stoi: 104 | _stoi = ShortTimeObjectiveIntelligibility(fs=sample_rate) 105 | _stoi.update(preds=preds, target=target) 106 | results["STOI"] = _stoi.compute().detach() 107 | 108 | results["CDPAM"] = torch.tensor(float("nan")) 109 | if cdpam: 110 | _cdpam = CDPAM(dev=preds.device) 111 | results["CDPAM"] = compute_cdpam(_cdpam, preds=preds, target=target, sample_rate=sample_rate).detach() 112 | 113 | results = {metric: value.to("cpu") for metric, value in results.items()} 114 | 115 | return results 116 | 117 | 118 | def log_spectral_distance( 119 | is_active: bool, preds: Tensor, target: Tensor, sample_rate: int, eps: float = 1e-12 120 | ) -> Tensor: 121 | """ 122 | Log spectral distance between two spectrograms as per https://arxiv.org/pdf/2203.14941v1.pdf. 123 | `hop_length` and `n_fft` are computed as in the paper. 
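    In terms of the magnitude spectrograms S (target) and S_hat (prediction), the value
    returned below is the batch mean of: mean over frames t of
    sqrt( mean over frequency bins f of ( log10( S(t, f)^2 / S_hat(t, f)^2 ) )^2 ),
    with `eps` added to avoid division by zero and log of zero.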
124 | """ 125 | 126 | if not is_active: 127 | return torch.tensor(0.0) 128 | 129 | hop_length = int(sample_rate / 100) 130 | n_fft = int(2048 / (44100 / sample_rate)) 131 | target_stft = _to_spectrogram(target, hop_length=hop_length, n_fft=n_fft) 132 | pred_stft = _to_spectrogram(preds, hop_length=hop_length, n_fft=n_fft) 133 | 134 | lsd = torch.log10(target_stft**2 / ((pred_stft + eps) ** 2) + eps) ** 2 135 | lsd = torch.mean(lsd, dim=-1) ** 0.5 136 | lsd = torch.mean(lsd, dim=-1) 137 | return lsd.mean() 138 | 139 | 140 | def _to_spectrogram(audio: Tensor, hop_length: int, n_fft: int) -> Tensor: 141 | stft = torch.stft( 142 | audio, 143 | hop_length=hop_length, 144 | n_fft=n_fft, 145 | window=torch.hann_window(window_length=n_fft).to(audio.device), 146 | return_complex=True, 147 | pad_mode="constant", 148 | ) 149 | stft = torch.abs(stft) 150 | stft = torch.transpose(stft, -1, -2) 151 | return stft 152 | 153 | 154 | def compute_cdpam( 155 | model: CDPAM, 156 | preds: Tensor, 157 | target: Tensor, 158 | sample_rate: int, 159 | ) -> Tensor: 160 | """ 161 | Computes CDPAM metric introduced in https://arxiv.org/abs/2102.05109. 162 | Requires CDPAM model, which operates on audio sampled with 22050 Hz. 163 | In case of different sampling rate, resamples recordings. 164 | """ 165 | 166 | if sample_rate != 22050: 167 | preds = _resample(preds, sample_rate, 22050) 168 | target = _resample(target, sample_rate, 22050) 169 | 170 | preds = torch.round(preds.float() * 32768) 171 | target = torch.round(target.float() * 32768) 172 | 173 | with torch.no_grad(): 174 | return model.forward(target, preds).detach().cpu().mean() # type: ignore 175 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import librosa 6 | import torch 7 | import typer 8 | from more_itertools import chunked 9 | from rich import print 10 | from rich.progress import track 11 | from torch import Tensor 12 | 13 | from hypersound.utils.eval import EvalModel, load_recordings 14 | from hypersound.utils.metrics import compute_metrics, reduce_metric 15 | 16 | 17 | def main( 18 | # fmt: off 19 | # Eval model 20 | work_dir: Optional[Path] = typer.Argument( 21 | None, help="Working directory (directory of the run to evaluate)." 22 | ), 23 | cfg_path: Path = typer.Option( 24 | Path("code.cfg"), "--cfg_path", "-c", help="Config file used for training." 25 | ), 26 | ckpt_path: Path = typer.Option( 27 | Path("checkpoints/"), "--ckpt_path", "-p", help="Checkpoint file or directory." # noqa 28 | ), 29 | reference_dir: Path = typer.Option( 30 | Path("eval/reference"), "--reference_dir", "-r", help="Directory containing reference recordings." # noqa 31 | ), 32 | generated_dir: Path = typer.Option( 33 | Path("eval/generated"), "--generated_dir", "-g", help="Directory containing recordings generated by the model." # noqa 34 | ), 35 | sample_rate: Optional[int] = typer.Option( 36 | None, "--sample_rate", "-s", help="Audio interpolation rate. If None, will use default training rate." 37 | ), 38 | audio_format: str = typer.Option( 39 | "wav", "--audio_format", "-f", help="Audio format." 40 | ), 41 | generate_training: bool = typer.Option( 42 | False, "--train", help="If set, will generate recordings for the training fold." 
43 | ), 44 | is_cuda: bool = typer.Option( 45 | False, "--cuda", help="Enable GPU processing", 46 | ), 47 | evaluation_fold: str = typer.Option( 48 | "validation", help="Determines which fold to evaluate on (not applicable if using absolute recording paths)." 49 | ), 50 | clean: bool = typer.Option( 51 | False, "--clean", help="Remove generated recordings after evaluation." 52 | ), 53 | file_limit_override: Optional[int] = typer.Option( 54 | None, "--file_limit", help="Override config validation file limit." 55 | ), 56 | # Metrics 57 | batch_size: int = typer.Option( 58 | 16, help="Metric evaluation batch size." 59 | ), 60 | pesq: bool = typer.Option( 61 | False, help="If set, will compute PESQ metric on audio resampled to 16000 Hz." 62 | ), 63 | stoi: bool = typer.Option( 64 | False, help="If set, will compute STOI." 65 | ), 66 | cdpam: bool = typer.Option( 67 | False, help="If set, will compute CDPAM metric on audio resampled to 22050 Hz." 68 | ), 69 | # fmt: on 70 | ): 71 | """ 72 | Evaluate a trained model. 73 | """ 74 | 75 | if work_dir: 76 | eval_model = EvalModel( 77 | cfg_path=work_dir / cfg_path, 78 | ckpt_path=work_dir / ckpt_path, 79 | audio_format=audio_format, 80 | is_cuda=is_cuda, 81 | batch_size=batch_size, 82 | sample_rate=sample_rate, 83 | file_limit_override=file_limit_override, 84 | ) 85 | 86 | reference_dir = work_dir / reference_dir 87 | generated_dir = work_dir / generated_dir 88 | 89 | eval_model.generate_recordings(reference_dir, generated_dir, generate_training) 90 | 91 | reference_dir = reference_dir / evaluation_fold 92 | generated_dir = generated_dir / evaluation_fold 93 | else: 94 | if not reference_dir.is_absolute() or not generated_dir.is_absolute(): 95 | raise RuntimeError("When no work_dir is specified, dataset paths must be absolute.") 96 | 97 | reference, generated = load_recordings(reference_dir, generated_dir, audio_format) 98 | 99 | sample_rate = sample_rate or 22050 # Set default for metric evaluation 100 | 101 | print("[bold yellow]Evaluating recordings:") 102 | print(f" - reference ({len(reference)}): {reference_dir}") 103 | print(f" - generated ({len(generated)}): {generated_dir}") 104 | 105 | if cdpam and sample_rate != 22050: 106 | print( 107 | "[bold red]CDPAM can only be used with sampling rate 22050 Hz, " 108 | "recordings will be resampled for evaluation of this metric." 109 | ) 110 | 111 | if pesq and sample_rate != 16000: 112 | print( 113 | "[bold red]PESQ can only be used with sampling rate 16000 Hz, " 114 | "recordings will be resampled for evaluation of this metric." 115 | ) 116 | 117 | metrics: list[dict[str, Tensor]] = [] 118 | 119 | for ref_paths, gen_paths in track( 120 | zip(chunked(reference, batch_size), chunked(generated, batch_size)), 121 | description="Running evaluation...", 122 | total=len(reference) // batch_size, 123 | ): 124 | for (ref_path, gen_path) in zip(ref_paths, gen_paths): 125 | assert ref_path.name == gen_path.name.replace("_reconstruction", "").replace("reconstruction_", ""), ( 126 | f"Names of the recordings in target and generated directory should match, " 127 | f"but got: {str(ref_path)} and {str(gen_path)}." 
128 | ) 129 | 130 | ref_recordings: Tensor = torch.stack( 131 | [torch.tensor(librosa.load(str(path), sr=sample_rate)[0]) for path in ref_paths] # type: ignore 132 | ) 133 | gen_recordings: Tensor = torch.stack( 134 | [torch.tensor(librosa.load(str(path), sr=sample_rate)[0]) for path in gen_paths] # type: ignore 135 | ) 136 | 137 | if is_cuda: 138 | ref_recordings = ref_recordings.to("cuda") 139 | gen_recordings = gen_recordings.to("cuda") 140 | 141 | metrics.append( 142 | compute_metrics( 143 | preds=gen_recordings, 144 | target=ref_recordings, 145 | sample_rate=sample_rate, 146 | pesq=pesq, 147 | stoi=stoi, 148 | cdpam=cdpam, 149 | ) 150 | ) 151 | 152 | print("Finished evaluation.") 153 | for key in metrics[0]: 154 | print(f"\t[bold yellow]{key}:[/] {reduce_metric(metrics, key):.4f}") 155 | 156 | if work_dir and clean: 157 | print("Removing generated recordings after evaluation...") 158 | shutil.rmtree(generated_dir) 159 | shutil.rmtree(reference_dir) 160 | 161 | 162 | if __name__ == "__main__": 163 | typer.run(main) 164 | -------------------------------------------------------------------------------- /hypersound/utils/eval.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from math import ceil 3 | from pathlib import Path 4 | from typing import Optional, cast 5 | 6 | import soundfile as sf 7 | import torch 8 | from more_itertools import chunked 9 | from omegaconf import OmegaConf 10 | from pathos.threading import ThreadPool as Pool 11 | from pytorch_yard import RootConfig 12 | from rich import print 13 | from rich.progress import track 14 | from torch import Tensor 15 | 16 | from hypersound.cfg import Settings 17 | from hypersound.datasets.audio import AudioDataset 18 | from hypersound.datasets.utils import init_datamodule 19 | from hypersound.systems.main import HyperNetworkAE 20 | 21 | 22 | def load_directory(recording_dir: Path, audio_extension: str) -> list[Path]: 23 | recordings = list(recording_dir.rglob(f"*.{audio_extension}")) 24 | recordings = sorted(recordings, key=lambda path: str(path.name)) 25 | return recordings 26 | 27 | 28 | def load_recordings(reference_dir: Path, generated_dir: Path, audio_extension: str) -> tuple[list[Path], list[Path]]: 29 | reference = load_directory(reference_dir, audio_extension) 30 | generated = load_directory(generated_dir, audio_extension) 31 | 32 | assert len(reference) == len(generated), "Number of reference and generated recordings differs." 33 | assert len(reference) > 0, "Data directory is empty." 
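    # Recordings are paired positionally after sorting by file name, so every generated file is
    # expected to share its name with the corresponding reference recording (evaluate.py
    # additionally ignores a "reconstruction" prefix/suffix when comparing the names).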
34 | 35 | return reference, generated 36 | 37 | 38 | class EvalModel: 39 | def __init__( 40 | self, 41 | cfg_path: Path, 42 | ckpt_path: Path, 43 | audio_format: str, 44 | is_cuda: bool, 45 | batch_size: int, 46 | sample_rate: Optional[int] = None, 47 | file_limit_override: Optional[int] = None, 48 | ) -> None: 49 | self.cfg_path = cfg_path 50 | self.ckpt_path = ckpt_path 51 | self.sample_rate = sample_rate 52 | self.audio_format = audio_format 53 | self.is_cuda = is_cuda 54 | self.batch_size = batch_size 55 | self.file_limit_override = file_limit_override 56 | 57 | self.cfg: Settings 58 | 59 | self._load_config() 60 | self._setup() 61 | self._load_checkpoint() 62 | 63 | def _load_config(self): 64 | # Config 65 | try: 66 | self.root_cfg = cast(RootConfig, OmegaConf.load(self.cfg_path)) 67 | except Exception: 68 | raise RuntimeError(f"{self.cfg_path} is not a valid config file.") 69 | 70 | self.cfg = cast(Settings, OmegaConf.merge(OmegaConf.structured(Settings), self.root_cfg.cfg)) 71 | 72 | if self.file_limit_override is not None: 73 | self.cfg.data.file_limit_validation = self.file_limit_override or None 74 | 75 | print(f"Loaded config: {OmegaConf.to_yaml(self.root_cfg, resolve=True)}") 76 | 77 | def _setup(self): 78 | # Model setup 79 | self.datamodule, self.interpolation_dm = init_datamodule(self.root_cfg, self.sample_rate) 80 | self.datamodule.prepare_data() 81 | self.datamodule.setup() 82 | 83 | self.system = HyperNetworkAE( 84 | cfg=self.cfg, 85 | input_length=self.datamodule.train.shape[1][0], 86 | spec_transform=copy.deepcopy(self.datamodule.train.spec_transform), 87 | ) 88 | 89 | def _load_checkpoint(self): 90 | # Checkpoint 91 | checkpoint = self.ckpt_path 92 | 93 | if checkpoint.suffix != ".ckpt": 94 | print(f"Searching for a valid checkpoint file in: {self.ckpt_path}") 95 | if not self.ckpt_path.is_dir(): 96 | raise RuntimeError(f"{self.ckpt_path} is not a directory.") 97 | 98 | checkpoint = [name for name in self.ckpt_path.iterdir() if name.suffix == ".ckpt"][-1] 99 | 100 | if not checkpoint.is_file(): 101 | raise RuntimeError(f"{checkpoint} is not a valid file.") 102 | 103 | print(f"Loading checkpoint {checkpoint}...") 104 | self.system = HyperNetworkAE.load_from_checkpoint(str(checkpoint)) # type: ignore 105 | 106 | if self.is_cuda: 107 | self.system.to("cuda") 108 | 109 | print(self.system) 110 | 111 | def generate_recordings(self, reference_dir: Path, generated_dir: Path, generate_training: bool = False) -> None: 112 | print("Verifying recording files...") 113 | 114 | if generate_training: 115 | self._generate_recordings("train", reference_dir, generated_dir) 116 | self._generate_recordings("validation", reference_dir, generated_dir) 117 | 118 | def _generate_recordings(self, fold: str, reference_dir: Path, generated_dir: Path) -> None: 119 | dataset: AudioDataset = getattr(self.datamodule, fold) 120 | reference_dir = reference_dir / fold 121 | generated_dir = generated_dir / fold 122 | 123 | if reference_dir.is_dir(): 124 | assert len(list(reference_dir.iterdir())) == len(dataset), \ 125 | "The reference directory and source dataset have a different number of recordings." # fmt: skip 126 | else: 127 | self._process_dataset(dataset, fold, reference_dir, generate=False) 128 | 129 | if generated_dir.is_dir(): 130 | assert len(list(generated_dir.iterdir())) == len(dataset), \ 131 | "The generated directory has an invalid number of recordings." 
# fmt: skip 132 | else: 133 | self._process_dataset(dataset, fold, generated_dir, generate=True) 134 | 135 | def _save_recording(self, output_dir: Path, idx: int, signal: Tensor): 136 | sf.write( # type: ignore 137 | str(output_dir / f"{idx}.{self.audio_format}"), 138 | signal.detach().cpu().numpy(), 139 | self.sample_rate or self.cfg.data.sample_rate, 140 | format=self.audio_format, 141 | subtype="PCM_24", 142 | ) 143 | 144 | def _process_dataset(self, dataset: AudioDataset, fold: str, output_dir: Path, generate: bool): 145 | output_dir.mkdir(parents=True) 146 | 147 | description = "Generated" if generate else "Reference" 148 | 149 | p = Pool(self.batch_size) 150 | 151 | for idxs in track( 152 | chunked(range(len(dataset)), self.batch_size), 153 | description=f"Recreating: {description} / {fold}", 154 | total=ceil(len(dataset) / self.batch_size), 155 | ): 156 | 157 | indices = torch.stack([dataset[idx][0] for idx in idxs]) 158 | signal = torch.stack([dataset[idx][1] for idx in idxs]) 159 | 160 | assert self.interpolation_dm 161 | 162 | if generate: 163 | if self.sample_rate and fold == "validation": 164 | # evaluate on interpolated indices, base input signal remains the same 165 | indices = torch.stack([self.interpolation_dm.validation[idx][0] for idx in idxs]) 166 | 167 | if self.is_cuda: 168 | indices = indices.to("cuda") 169 | signal = signal.to("cuda") 170 | 171 | signal = self.system.reconstruct(indices, signal) 172 | elif self.sample_rate and fold == "validation": # ground truth for interpolation 173 | signal = torch.stack([self.interpolation_dm.validation[idx][1] for idx in idxs]) 174 | 175 | p.map( # type: ignore 176 | self._save_recording, 177 | [output_dir] * self.batch_size, 178 | idxs, 179 | signal.unbind(0), 180 | ) 181 | -------------------------------------------------------------------------------- /hypersound/models/meta/hyper.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import ABC 3 | from dataclasses import dataclass 4 | from typing import cast 5 | 6 | import torch 7 | from torch import Tensor, nn 8 | 9 | from hypersound.models.siren import SIREN 10 | 11 | from .inr import INR 12 | 13 | 14 | @dataclass 15 | class ParamInfo: 16 | start_idx: int 17 | end_idx: int 18 | shape: tuple[int, ...] 
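# Illustrative sketch (hypothetical sizes): for a target network whose only generated parameters
# are an 8x1 weight "w0" and an 8-element bias "b0", `_compute_param_info` would produce
# {"w0": ParamInfo(start_idx=0, end_idx=8, shape=(8, 1)),
#  "b0": ParamInfo(start_idx=8, end_idx=16, shape=(8,))},
# and MLPHyperNetwork.forward would slice its flat output `x` of shape [batch, 16] into
# x[:, 0:8].reshape(-1, 8, 1) and x[:, 8:16].reshape(-1, 8).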
19 | 20 | 21 | def calc_fan_in_and_out(shape: tuple[int, ...]) -> tuple[int, int]: 22 | """Code copied from hypnettorch library""" 23 | assert len(shape) > 1 24 | 25 | fan_in = shape[1] 26 | fan_out = shape[0] 27 | 28 | if len(shape) > 2: 29 | receptive_field_size = int(torch.prod(torch.tensor(shape[2:]))) 30 | else: 31 | receptive_field_size = 1 32 | 33 | fan_in *= receptive_field_size 34 | fan_out *= receptive_field_size 35 | 36 | return fan_in, fan_out 37 | 38 | 39 | class HyperNetwork(ABC, torch.nn.Module): 40 | pass 41 | 42 | 43 | class MLPHyperNetwork(HyperNetwork): 44 | def __init__( 45 | self, 46 | target_network: INR, 47 | shared_params: list[str], 48 | input_size: int, 49 | layer_sizes: list[int], 50 | ): 51 | super().__init__() 52 | self.input_size = input_size 53 | self._param_info = self._compute_param_info(target_network, shared_params) 54 | 55 | layers: list[nn.Linear] = [] 56 | 57 | in_sizes = [input_size] + layer_sizes 58 | out_sizes = layer_sizes + [target_network.num_params(shared_params=shared_params)] 59 | 60 | for n_in, n_out in zip(in_sizes, out_sizes): 61 | layer = nn.Linear(n_in, n_out) 62 | layers.append(layer) 63 | 64 | self.net = nn.ModuleList(layers) 65 | self.activation_fn = nn.ELU() 66 | 67 | self._init_weights(target_network) 68 | 69 | def __call__(self, x: Tensor) -> dict[str, Tensor]: # type: ignore 70 | return super().__call__(x) 71 | 72 | def forward(self, x: Tensor) -> dict[str, Tensor]: # type: ignore 73 | for i, layer in enumerate(self.net): 74 | x = layer(x) 75 | if i < len(self.net) - 1: 76 | x = self.activation_fn(x) 77 | 78 | return { 79 | param_name: x[:, info.start_idx : info.end_idx].reshape((-1, *info.shape)) 80 | for param_name, info in self._param_info.items() 81 | } 82 | 83 | def _init_weights(self, target_network: INR, init_var: float = 1.0): 84 | """All comments relate to `hypnettorch.hnets.mlp_hnet`""" 85 | 86 | # Compute input variance - #608-617 87 | # Since we only have single input this variance is simply equal to initial variance 88 | input_variance = init_var 89 | 90 | # Init hidden layers #636-651 91 | # Ignore batch norm layers since we don't use them 92 | for layer in cast(list[nn.Linear], self.net)[:-1]: 93 | nn.init.kaiming_uniform_(layer.weight, mode="fan_in", nonlinearity="relu") # type: ignore 94 | nn.init.zeros_(layer.bias) 95 | 96 | # Current variances and values from lines #652-697 default to Nones 97 | # Init biases of last layer to zero # 705-706 98 | nn.init.zeros_(cast(list[nn.Linear], self.net)[-1].bias) 99 | 100 | # Line #705 101 | c_relu = 2 102 | 103 | fan_in, _ = cast(tuple[int, int], nn.init._calculate_fan_in_and_fan_out(self.net[-1].weight)) # type: ignore 104 | 105 | """Generic initialization.""" 106 | for i in range(target_network.n_layers): 107 | if f"b{i}" in self._param_info: 108 | c_bias = 2 109 | 110 | b_info = self._param_info[f"b{i}"] 111 | m_fan_out = b_info.shape[0] 112 | m_fan_in = m_fan_out 113 | var_in = c_relu / (2.0 * fan_in * input_variance) 114 | self._init_last_layer_for_param(b_info, var_in) 115 | else: 116 | c_bias = 1 117 | 118 | if f"w{i}" in self._param_info: 119 | w_info = self._param_info[f"w{i}"] 120 | m_fan_in, m_fan_out = calc_fan_in_and_out(w_info.shape) 121 | 122 | var_in = c_relu / (c_bias * m_fan_in * fan_in * input_variance) 123 | self._init_last_layer_for_param(w_info, var_in) 124 | 125 | """Pseudo-SIREN initialization""" 126 | tn_heads: nn.Linear = cast(list[nn.Linear], self.net)[-1] 127 | tn_heads_w, tn_heads_b = tn_heads.weight, tn_heads.bias 128 | 129 | if 
isinstance(target_network, SIREN): 130 | for i in range(target_network.n_layers): 131 | if f"w{i}" in self._param_info: 132 | w_info = self._param_info[f"w{i}"] 133 | _, n_in = w_info.shape 134 | if n_in == 1: 135 | bound = 1 / n_in 136 | else: 137 | bound = math.sqrt(6.0 / n_in) 138 | if target_network.gradient_fix: 139 | bound = bound / float(target_network.params[f"o{i}"]) 140 | 141 | with torch.no_grad(): 142 | tn_heads_w[w_info.start_idx : w_info.end_idx, :].multiply_(torch.tensor(1e-3)) 143 | tn_heads_b[w_info.start_idx : w_info.end_idx].uniform_(-bound, bound) 144 | 145 | if f"b{i}" in self._param_info: 146 | b_info = self._param_info[f"b{i}"] 147 | with torch.no_grad(): 148 | tn_heads_w[b_info.start_idx : b_info.end_idx, :].multiply_(torch.tensor(1e-3)) 149 | tn_heads_b[b_info.start_idx : b_info.end_idx].uniform_(-bound, bound) 150 | 151 | """Initialization for non-layer learnable parameters""" 152 | with torch.no_grad(): 153 | for param_name, param_info in self._param_info.items(): 154 | if param_name == "freq": 155 | tn_heads_w[param_info.start_idx : param_info.end_idx, :].uniform_(-1e-4, 1e-4) 156 | for i, base_freq in enumerate(target_network.params[param_name].data): 157 | tn_heads_b[param_info.start_idx + i : param_info.start_idx + i + 1].uniform_( 158 | base_freq - 1e-2, base_freq + 1e-2 159 | ) 160 | if "o" in param_name: 161 | tn_heads_w[param_info.start_idx : param_info.end_idx, :].uniform_(-1e-4, 1e-4) 162 | omega = target_network.params[param_name].data 163 | tn_heads_b[param_info.start_idx : param_info.end_idx].uniform_(omega - 1e-2, omega + 1e-2) 164 | 165 | def _init_last_layer_for_param(self, param_info: ParamInfo, var_in: float): 166 | var = var_in 167 | 168 | std = math.sqrt(var) 169 | a = math.sqrt(3.0) * std 170 | 171 | nn.init._no_grad_uniform_( # type: ignore 172 | cast(list[nn.Linear], self.net)[-1].weight[param_info.start_idx : param_info.end_idx, :], -a, a 173 | ) 174 | 175 | def _compute_param_info(self, target_network: INR, shared_params: list[str]) -> dict[str, ParamInfo]: 176 | info: dict[str, ParamInfo] = {} 177 | current_idx = 0 178 | for param_name, param in target_network.params.items(): 179 | if param_name not in shared_params: 180 | info[param_name] = ParamInfo( 181 | start_idx=current_idx, end_idx=current_idx + param.numel(), shape=param.shape 182 | ) 183 | current_idx += param.numel() 184 | 185 | assert current_idx == target_network.num_params(shared_params=shared_params) 186 | 187 | return info 188 | 189 | @property 190 | def param_info(self) -> dict[str, ParamInfo]: 191 | return self._param_info 192 | -------------------------------------------------------------------------------- /hypersound/models/meta/inr.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Literal, Optional, Union, cast, overload 3 | 4 | import torch 5 | import torch.nn 6 | from torch import Tensor, nn 7 | from torch.nn.parameter import Parameter 8 | 9 | from hypersound.cfg import TargetNetworkMode 10 | 11 | 12 | class Backbone(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def __call__( # type: ignore 17 | self, 18 | x: Tensor, 19 | ) -> Tensor: 20 | return super().__call__(x) 21 | 22 | 23 | class INR(nn.Module): 24 | def __init__( 25 | self, 26 | input_size: int, 27 | output_size: int, 28 | hidden_sizes: list[int], 29 | activation_fn: nn.Module, 30 | bias: bool, 31 | mode: TargetNetworkMode, 32 | ): 33 | super().__init__() 34 | 35 | self.input_size = input_size 36 | 
self.output_size = output_size 37 | self.n_layers = 1 + len(hidden_sizes) 38 | self.mode = mode 39 | self.bias = bias 40 | self.params: dict[str, Parameter] = {} 41 | 42 | self._activation_fn = activation_fn 43 | 44 | for i, (n_in, n_out) in enumerate(zip([input_size] + hidden_sizes, hidden_sizes + [output_size])): 45 | w = Parameter(torch.empty((n_out, n_in)), requires_grad=True) 46 | nn.init.kaiming_uniform_(w, a=math.sqrt(5)) # type: ignore 47 | self.params[f"w{i}"] = w 48 | 49 | if bias is not None: 50 | b = Parameter(torch.empty((n_out,)), requires_grad=True) 51 | fan_in, _ = cast(tuple[int, int], nn.init._calculate_fan_in_and_fan_out(w)) # type: ignore 52 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 53 | nn.init.uniform_(b, -bound, bound) 54 | self.params[f"b{i}"] = b 55 | 56 | for name, param in self.params.items(): 57 | self.register_parameter(name, param) 58 | 59 | def num_params(self, shared_params: Optional[list[str]] = None) -> int: 60 | num_params = 0 61 | if shared_params is None: 62 | shared_params = [] 63 | 64 | learnable_params = { 65 | param_name: param 66 | for param_name, param in self.params.items() 67 | if param.requires_grad and param_name not in shared_params 68 | } 69 | 70 | for param in learnable_params.values(): 71 | num_params += param.numel() 72 | 73 | return num_params 74 | 75 | def freeze_params(self, shared_params: list[str]) -> None: # TODO: Verify 76 | assert self.mode is TargetNetworkMode.TARGET_NETWORK 77 | for param_name, param in self.params.items(): 78 | if param_name not in shared_params: 79 | param.requires_grad = False 80 | 81 | @overload 82 | def __call__( 83 | self, 84 | x: Tensor, 85 | weights: Optional[dict[str, Tensor]] = None, 86 | *, 87 | return_activations: Literal[False] = ..., 88 | ) -> Tensor: 89 | ... 90 | 91 | @overload 92 | def __call__( 93 | self, 94 | x: Tensor, 95 | weights: Optional[dict[str, Tensor]] = None, 96 | *, 97 | return_activations: Literal[True], 98 | ) -> tuple[Tensor, list[tuple[Tensor, Tensor]]]: 99 | ... 100 | 101 | def __call__( # type: ignore 102 | self, 103 | x: Tensor, 104 | weights: Optional[dict[str, Tensor]] = None, 105 | *, 106 | return_activations: bool = False, 107 | ) -> Union[Tensor, tuple[Tensor, list[tuple[Tensor, Tensor]]]]: 108 | return super().__call__(x, weights, return_activations=return_activations) 109 | 110 | def forward( # type: ignore 111 | self, 112 | x: Tensor, 113 | weights: Optional[dict[str, Tensor]] = None, 114 | *, 115 | return_activations: bool = False, 116 | ) -> Union[Tensor, tuple[Tensor, list[tuple[Tensor, Tensor]]]]: 117 | """ 118 | Forward function for INR network. Works in two modes: if `weights` is None, behaves as a 119 | standard MLP, expects input size [S, I] and returns output [S, O], where S is sequence 120 | length and I and O are input and output sizes. 121 | If `weights` is provided will run batched inference with multiple models in parallel, 122 | with expected input shape of [N, S, I] and output shape [N, S, O]. Weights are expected to 123 | be dict of Tensors, where first dimension of each tensor is always N. 
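        Example (hypothetical sizes): with N = 16 generated INRs, sequence length S = 32768 and
        scalar coordinates/outputs (I = O = 1), `x` has shape [16, 32768, 1], every tensor in
        `weights` has a leading dimension of 16 (e.g. a hidden layer weight of shape
        [16, hidden_size, 1]), and the returned tensor has shape [16, 32768, 1].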
124 | """ 125 | activations: list[tuple[Tensor, Tensor]] = [] 126 | 127 | for i in range(self.n_layers): 128 | x = self._forward(x, layer_idx=i, weights=weights) 129 | 130 | h = x 131 | 132 | if i != self.n_layers - 1: 133 | x = self._activation_fn(x) 134 | 135 | if return_activations: 136 | activations.append((x, h)) 137 | 138 | if return_activations: 139 | return x, activations 140 | else: 141 | return x 142 | 143 | def _forward(self, x: Tensor, layer_idx: int, weights: Optional[dict[str, Tensor]] = None) -> Tensor: 144 | weight_name = f"w{layer_idx}" 145 | bias_name = f"b{layer_idx}" 146 | 147 | weight_matrix, bias = None, None 148 | 149 | if self.mode is TargetNetworkMode.INR: 150 | if weights is not None: 151 | raise ValueError("Can't provide weights in `inr` mode.") 152 | 153 | x = torch.matmul(x, self.params[weight_name].T) 154 | if self.bias: 155 | x = x + self.params[bias_name] 156 | 157 | elif self.mode in (TargetNetworkMode.TARGET_NETWORK, TargetNetworkMode.RESIDUAL, TargetNetworkMode.MODULATOR): 158 | if weights is None: 159 | raise ValueError("`weights` are required for inference in `target_network` mode.") 160 | 161 | if self.mode in (TargetNetworkMode.RESIDUAL, TargetNetworkMode.MODULATOR): 162 | weight_matrix = weights[weight_name] 163 | if not self.params[weight_name].requires_grad or not ( 164 | self.bias and self.params[bias_name].requires_grad # FIXME? 165 | ): 166 | raise ValueError() 167 | if self.bias: 168 | bias = weights[bias_name] 169 | 170 | if self.mode is TargetNetworkMode.RESIDUAL: 171 | weight_matrix = weight_matrix + torch.stack( 172 | x.shape[0] * [cast(Tensor, self.params[weight_name])], dim=0 173 | ) 174 | else: 175 | weight_matrix = weight_matrix * torch.stack( 176 | x.shape[0] * [cast(Tensor, self.params[weight_name])], dim=0 177 | ) 178 | 179 | else: 180 | weight_matrix = weights.get( 181 | weight_name, torch.stack(x.shape[0] * [cast(Tensor, self.params[weight_name])], dim=0) 182 | ) 183 | 184 | if self.bias: 185 | bias = weights.get( 186 | bias_name, torch.stack(x.shape[0] * [cast(Tensor, self.params[bias_name])], dim=0) 187 | ) 188 | 189 | assert weight_matrix is not None 190 | x = torch.bmm(x, weight_matrix.permute(0, 2, 1)) 191 | if self.bias: 192 | assert bias is not None 193 | if self.mode is TargetNetworkMode.RESIDUAL: 194 | x = x + bias.unsqueeze(1) 195 | elif self.mode is TargetNetworkMode.MODULATOR: 196 | x = x * bias.unsqueeze(1) 197 | else: 198 | x = x + bias.unsqueeze(1) 199 | else: 200 | raise ValueError(f"Unknown mode: `{self.mode}.") 201 | 202 | return x 203 | -------------------------------------------------------------------------------- /rave/export_rave.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from effortless_config import Config 4 | import logging 5 | from termcolor import colored 6 | import cached_conv as cc 7 | 8 | logging.basicConfig(level=logging.INFO, 9 | format=colored("[%(relativeCreated).2f] ", "green") + 10 | "%(message)s") 11 | 12 | logging.info("exporting model") 13 | 14 | 15 | class args(Config): 16 | RUN = None 17 | SR = None 18 | CACHED = False 19 | FIDELITY = .95 20 | NAME = "vae" 21 | STEREO = False 22 | DETERMINISTIC = False 23 | 24 | 25 | args.parse_args() 26 | cc.use_cached_conv(args.CACHED) 27 | 28 | from rave.model import RAVE 29 | from rave.resample import Resampling 30 | from rave.core import search_for_run 31 | 32 | import numpy as np 33 | import math 34 | 35 | 36 | class TraceModel(nn.Module): 37 | def __init__(self, 
pretrained: RAVE, resample: Resampling, 38 | fidelity: float): 39 | super().__init__() 40 | 41 | latent_size = pretrained.latent_size 42 | self.resample = resample 43 | 44 | self.pqmf = pretrained.pqmf 45 | self.encoder = pretrained.encoder 46 | self.decoder = pretrained.decoder 47 | 48 | self.register_buffer("latent_pca", pretrained.latent_pca) 49 | self.register_buffer("latent_mean", pretrained.latent_mean) 50 | self.register_buffer("latent_size", torch.tensor(latent_size)) 51 | self.register_buffer( 52 | "sampling_rate", 53 | torch.tensor(self.resample.taget_sr), 54 | ) 55 | try: 56 | self.register_buffer("max_batch_size", 57 | torch.tensor(cc.MAX_BATCH_SIZE)) 58 | except: 59 | print( 60 | "You should upgrade cached_conv if you want to use RAVE in batch mode !" 61 | ) 62 | self.register_buffer("max_batch_size", torch.tensor(1)) 63 | self.trained_cropped = bool(pretrained.cropped_latent_size) 64 | self.deterministic = args.DETERMINISTIC 65 | 66 | if self.trained_cropped: 67 | self.cropped_latent_size = pretrained.cropped_latent_size 68 | else: 69 | latent_size = np.argmax(pretrained.fidelity.numpy() > fidelity) 70 | latent_size = 2**math.ceil(math.log2(latent_size)) 71 | self.cropped_latent_size = latent_size 72 | 73 | x = torch.zeros(1, 1, 2**14) 74 | z = self.encode(x) 75 | ratio = x.shape[-1] // z.shape[-1] 76 | 77 | self.register_buffer( 78 | "encode_params", 79 | torch.tensor([ 80 | 1, 81 | 1, 82 | self.cropped_latent_size, 83 | ratio, 84 | ])) 85 | 86 | self.register_buffer( 87 | "decode_params", 88 | torch.tensor([ 89 | self.cropped_latent_size, 90 | ratio, 91 | 2 if args.STEREO else 1, 92 | 1, 93 | ])) 94 | 95 | self.register_buffer("forward_params", 96 | torch.tensor([1, 1, 2 if args.STEREO else 1, 1])) 97 | 98 | self.stereo = args.STEREO 99 | 100 | def post_process_distribution(self, mean, scale): 101 | std = nn.functional.softplus(scale) + 1e-4 102 | return mean, std 103 | 104 | def reparametrize(self, mean, std): 105 | var = std * std 106 | logvar = torch.log(var) 107 | 108 | z = torch.randn_like(mean) * std + mean 109 | kl = (mean * mean + var - logvar - 1).sum(1).mean() 110 | 111 | return z, kl 112 | 113 | @torch.jit.export 114 | def encode(self, x): 115 | x = self.resample.from_target_sampling_rate(x) 116 | 117 | if self.pqmf is not None: 118 | x = self.pqmf(x) 119 | 120 | mean, scale = self.encoder(x) 121 | mean, std = self.post_process_distribution(mean, scale) 122 | 123 | if self.deterministic: 124 | z = mean 125 | else: 126 | z = self.reparametrize(mean, std)[0] 127 | 128 | z = z - self.latent_mean.unsqueeze(-1) 129 | z = nn.functional.conv1d(z, self.latent_pca.unsqueeze(-1)) 130 | 131 | z = z[:, :self.cropped_latent_size] 132 | return z 133 | 134 | @torch.jit.export 135 | def encode_amortized(self, x): 136 | x = self.resample.from_target_sampling_rate(x) 137 | 138 | if self.pqmf is not None: 139 | x = self.pqmf(x) 140 | 141 | mean, scale = self.encoder(x) 142 | mean, std = self.post_process_distribution(mean, scale) 143 | var = std * std 144 | 145 | mean = mean - self.latent_mean.unsqueeze(-1) 146 | 147 | mean = nn.functional.conv1d(mean, self.latent_pca.unsqueeze(-1)) 148 | var = nn.functional.conv1d(var, self.latent_pca.unsqueeze(-1).pow(2)) 149 | 150 | mean = mean[:, :self.cropped_latent_size] 151 | var = var[:, :self.cropped_latent_size] 152 | std = var.sqrt() 153 | 154 | return mean, std 155 | 156 | @torch.jit.export 157 | def decode(self, z): 158 | if self.trained_cropped: # PERFORM PCA BEFORE PADDING 159 | z = nn.functional.conv1d(z, 
self.latent_pca.T.unsqueeze(-1)) 160 | z = z + self.latent_mean.unsqueeze(-1) 161 | 162 | if self.stereo and z.shape[0] == 1: # DUPLICATE LATENT PATH 163 | z = z.expand(2, z.shape[1], z.shape[2]) 164 | 165 | # CAT WITH SAMPLES FROM PRIOR DISTRIBUTION 166 | pad_size = self.latent_size.item() - z.shape[1] 167 | 168 | if self.deterministic: 169 | pad_latent = torch.zeros( 170 | z.shape[0], 171 | pad_size, 172 | z.shape[-1], 173 | device=z.device, 174 | ) 175 | else: 176 | pad_latent = torch.randn( 177 | z.shape[0], 178 | pad_size, 179 | z.shape[-1], 180 | device=z.device, 181 | ) 182 | 183 | z = torch.cat([z, pad_latent], 1) 184 | 185 | if not self.trained_cropped: # PERFORM PCA AFTER PADDING 186 | z = nn.functional.conv1d(z, self.latent_pca.T.unsqueeze(-1)) 187 | z = z + self.latent_mean.unsqueeze(-1) 188 | 189 | x = self.decoder(z, add_noise=not self.deterministic) 190 | 191 | if self.pqmf is not None: 192 | x = self.pqmf.inverse(x) 193 | 194 | x = self.resample.to_target_sampling_rate(x) 195 | 196 | if self.stereo: 197 | x = x.permute(1, 0, 2) 198 | return x 199 | 200 | def forward(self, x): 201 | return self.decode(self.encode(x)) 202 | 203 | 204 | logging.info("loading model from checkpoint") 205 | 206 | RUN = search_for_run(args.RUN) 207 | logging.info(f"using {RUN}") 208 | model = RAVE.load_from_checkpoint(RUN, strict=False).eval() 209 | 210 | logging.info("flattening weights") 211 | for m in model.modules(): 212 | if hasattr(m, "weight_g"): 213 | nn.utils.remove_weight_norm(m) 214 | 215 | logging.info("warmup forward pass") 216 | x = torch.zeros(1, 1, 2**14) 217 | if model.pqmf is not None: 218 | x = model.pqmf(x) 219 | 220 | z, _ = model.reparametrize(*model.encoder(x)) 221 | 222 | if args.STEREO: 223 | z = z.expand(2, *z.shape[1:]) 224 | 225 | y = model.decoder(z) 226 | 227 | if model.pqmf is not None: 228 | y = model.pqmf.inverse(y) 229 | 230 | model.discriminator = None 231 | 232 | sr = model.sr 233 | 234 | if args.SR is not None: 235 | target_sr = int(args.SR) 236 | else: 237 | target_sr = sr 238 | 239 | logging.info("build resampling model") 240 | resample = Resampling(target_sr, sr) 241 | x = torch.zeros(1, 1, 2**14) 242 | resample.to_target_sampling_rate(resample.from_target_sampling_rate(x)) 243 | 244 | logging.info("script model") 245 | model = TraceModel(model, resample, args.FIDELITY) 246 | model(x) 247 | 248 | model = torch.jit.script(model) 249 | logging.info(f"save rave_{args.NAME}.ts") 250 | model.save(f"rave_{args.NAME}.ts") 251 | -------------------------------------------------------------------------------- /rave/rave/pqmf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from scipy.signal import kaiser, kaiserord, kaiser_beta, firwin 4 | from scipy.optimize import fmin 5 | import math 6 | import numpy as np 7 | from einops import rearrange 8 | 9 | import cached_conv as cc 10 | 11 | 12 | def reverse_half(x): 13 | mask = torch.ones_like(x) 14 | mask[..., 1::2, ::2] = -1 15 | 16 | return x * mask 17 | 18 | 19 | def center_pad_next_pow_2(x): 20 | next_2 = 2**math.ceil(math.log2(x.shape[-1])) 21 | pad = next_2 - x.shape[-1] 22 | return nn.functional.pad(x, (pad // 2, pad // 2 + int(pad % 2))) 23 | 24 | 25 | def make_odd(x): 26 | if not x.shape[-1] % 2: 27 | x = nn.functional.pad(x, (0, 1)) 28 | return x 29 | 30 | 31 | def get_qmf_bank(h, n_band): 32 | """ 33 | Modulates an input protoype filter into a bank of 34 | cosine modulated filters 35 | Parameters 36 | ---------- 37 | h: 
torch.Tensor 38 | prototype filter 39 | n_band: int 40 | number of sub-bands 41 | """ 42 | k = torch.arange(n_band).reshape(-1, 1) 43 | N = h.shape[-1] 44 | t = torch.arange(-(N // 2), N // 2 + 1) 45 | 46 | p = (-1)**k * math.pi / 4 47 | 48 | mod = torch.cos((2 * k + 1) * math.pi / (2 * n_band) * t + p) 49 | hk = 2 * h * mod 50 | 51 | return hk 52 | 53 | 54 | def kaiser_filter(wc, atten, N=None): 55 | """ 56 | Computes a kaiser lowpass filter 57 | Parameters 58 | ---------- 59 | wc: float 60 | Angular frequency 61 | 62 | atten: float 63 | Attenuation (dB, positive) 64 | """ 65 | N_, beta = kaiserord(atten, wc / np.pi) 66 | N_ = 2 * (N_ // 2) + 1 67 | N = N if N is not None else N_ 68 | h = firwin(N, wc, window=('kaiser', beta), scale=False, nyq=np.pi) 69 | return h 70 | 71 | 72 | def loss_wc(wc, atten, M, N): 73 | """ 74 | Computes the objective described in https://ieeexplore.ieee.org/document/681427 75 | """ 76 | h = kaiser_filter(wc, atten, N) 77 | g = np.convolve(h, h[::-1], "full") 78 | g = abs(g[g.shape[-1] // 2::2 * M][1:]) 79 | return np.max(g) 80 | 81 | 82 | def get_prototype(atten, M, N=None): 83 | """ 84 | Given an attenuation objective and the number of bands 85 | returns the corresponding lowpass filter 86 | """ 87 | wc = fmin(lambda w: loss_wc(w, atten, M, N), 1 / M, disp=0)[0] 88 | return kaiser_filter(wc, atten, N) 89 | 90 | 91 | def polyphase_forward(x, hk, rearrange_filter=True): 92 | """ 93 | Polyphase implementation of the analysis process (fast) 94 | Parameters 95 | ---------- 96 | x: torch.Tensor 97 | signal to analyse ( B x 1 x T ) 98 | 99 | hk: torch.Tensor 100 | filter bank ( M x T ) 101 | """ 102 | x = rearrange(x, "b c (t m) -> b (c m) t", m=hk.shape[0]) 103 | if rearrange_filter: 104 | hk = rearrange(hk, "c (t m) -> c m t", m=hk.shape[0]) 105 | x = nn.functional.conv1d(x, hk, padding=hk.shape[-1] // 2)[..., :-1] 106 | return x 107 | 108 | 109 | def polyphase_inverse(x, hk, rearrange_filter=True): 110 | """ 111 | Polyphase implementation of the synthesis process (fast) 112 | Parameters 113 | ---------- 114 | x: torch.Tensor 115 | signal to synthesize from ( B x 1 x T ) 116 | 117 | hk: torch.Tensor 118 | filter bank ( M x T ) 119 | """ 120 | 121 | m = hk.shape[0] 122 | 123 | if rearrange_filter: 124 | hk = hk.flip(-1) 125 | hk = rearrange(hk, "c (t m) -> m c t", m=m) # polyphase 126 | 127 | pad = hk.shape[-1] // 2 + 1 128 | x = nn.functional.conv1d(x, hk, padding=int(pad))[..., :-1] * m 129 | 130 | x = x.flip(1) 131 | x = rearrange(x, "b (c m) t -> b c (t m)", m=m) 132 | x = x[..., 2 * hk.shape[1]:] 133 | return x 134 | 135 | 136 | def classic_forward(x, hk): 137 | """ 138 | Naive implementation of the analysis process (slow) 139 | Parameters 140 | ---------- 141 | x: torch.Tensor 142 | signal to analyse ( B x 1 x T ) 143 | 144 | hk: torch.Tensor 145 | filter bank ( M x T ) 146 | """ 147 | x = nn.functional.conv1d( 148 | x, 149 | hk.unsqueeze(1), 150 | stride=hk.shape[0], 151 | padding=hk.shape[-1] // 2, 152 | )[..., :-1] 153 | return x 154 | 155 | 156 | def classic_inverse(x, hk): 157 | """ 158 | Naive implementation of the synthesis process (slow) 159 | Parameters 160 | ---------- 161 | x: torch.Tensor 162 | signal to synthesize from ( B x 1 x T ) 163 | 164 | hk: torch.Tensor 165 | filter bank ( M x T ) 166 | """ 167 | hk = hk.flip(-1) 168 | y = torch.zeros(*x.shape[:2], hk.shape[0] * x.shape[-1]).to(x) 169 | y[..., ::hk.shape[0]] = x * hk.shape[0] 170 | y = nn.functional.conv1d( 171 | y, 172 | hk.unsqueeze(0), 173 | padding=hk.shape[-1] // 2, 174 | )[..., 1:] 
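    # At this point each sub-band has been upsampled by zero-insertion (scaled by the number of
    # bands) and filtered with the time-reversed filter bank; conv1d sums the contributions of
    # all bands into a single full-rate channel.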
175 | return y 176 | 177 | 178 | class PQMF(nn.Module): 179 | """ 180 | Pseudo Quadrature Mirror Filter multiband decomposition / reconstruction 181 | Parameters 182 | ---------- 183 | attenuation: int 184 | Attenuation of the rejected bands (dB, 80 - 120) 185 | n_band: int 186 | Number of bands, must be a power of 2 if the polyphase implementation 187 | is needed 188 | """ 189 | def __init__(self, attenuation, n_band, polyphase=True): 190 | super().__init__() 191 | h = get_prototype(attenuation, n_band) 192 | 193 | if polyphase: 194 | power = math.log2(n_band) 195 | assert power == math.floor( 196 | power 197 | ), "when using the polyphase algorithm, n_band must be a power of 2" 198 | 199 | h = torch.from_numpy(h).float() 200 | hk = get_qmf_bank(h, n_band) 201 | hk = center_pad_next_pow_2(hk) 202 | 203 | self.register_buffer("hk", hk) 204 | self.register_buffer("h", h) 205 | self.n_band = n_band 206 | self.polyphase = polyphase 207 | 208 | def forward(self, x): 209 | if self.n_band == 1: 210 | return x 211 | elif self.polyphase: 212 | x = polyphase_forward(x, self.hk) 213 | else: 214 | x = classic_forward(x, self.hk) 215 | 216 | x = reverse_half(x) 217 | 218 | return x 219 | 220 | def inverse(self, x): 221 | if self.n_band == 1: 222 | return x 223 | 224 | x = reverse_half(x) 225 | 226 | if self.polyphase: 227 | return polyphase_inverse(x, self.hk) 228 | else: 229 | return classic_inverse(x, self.hk) 230 | 231 | 232 | class CachedPQMF(PQMF): 233 | def __init__(self, *args, **kwargs): 234 | super().__init__(*args, **kwargs) 235 | 236 | hkf = make_odd(self.hk).unsqueeze(1) 237 | 238 | hki = self.hk.flip(-1) 239 | hki = rearrange(hki, "c (t m) -> m c t", m=self.hk.shape[0]) 240 | hki = make_odd(hki) 241 | 242 | self.forward_conv = cc.Conv1d( 243 | hkf.shape[1], 244 | hkf.shape[0], 245 | hkf.shape[2], 246 | padding=cc.get_padding(hkf.shape[-1]), 247 | stride=hkf.shape[0], 248 | bias=False, 249 | ) 250 | self.forward_conv.weight.data.copy_(hkf) 251 | 252 | self.inverse_conv = cc.Conv1d( 253 | hki.shape[1], 254 | hki.shape[0], 255 | hki.shape[-1], 256 | padding=cc.get_padding(hki.shape[-1]), 257 | bias=False, 258 | ) 259 | self.inverse_conv.weight.data.copy_(hki) 260 | 261 | def script_cache(self): 262 | self.forward_conv.script_cache() 263 | self.inverse_conv.script_cache() 264 | 265 | def forward(self, x): 266 | x = self.forward_conv(x) 267 | x = reverse_half(x) 268 | return x 269 | 270 | def inverse(self, x): 271 | x = reverse_half(x) 272 | m = self.hk.shape[0] 273 | x = self.inverse_conv(x) * m 274 | x = x.flip(1) 275 | x = x.permute(0, 2, 1) 276 | x = x.reshape(x.shape[0], x.shape[1], -1, m).permute(0, 2, 1, 3) 277 | x = x.reshape(x.shape[0], x.shape[1], -1) 278 | return x 279 | -------------------------------------------------------------------------------- /hypersound/datasets/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | from abc import ABC 3 | from functools import partial 4 | from pathlib import Path 5 | from typing import Optional, cast 6 | 7 | import librosa 8 | import numpy as np 9 | import numpy.typing as npt 10 | import soundfile as sf 11 | import torchaudio 12 | from more_itertools import chunked 13 | from pathos.threading import ThreadPool as Pool 14 | from pytorch_yard.utils.logging import info_bold 15 | from rich.progress import Progress 16 | from torch import Tensor 17 | from torch.utils.data import Dataset 18 | from torchvision.transforms import transforms 19 | 20 | from hypersound.datasets.transforms import ( 
21 | Dequantize, 22 | RandomApply, 23 | RandomCrop, 24 | RandomPhaseMangle, 25 | Transform, 26 | ) 27 | 28 | PROCESSING_BATCH_SIZE = 8 29 | 30 | 31 | class BaseSamples(Dataset[tuple[Tensor, int]], ABC): 32 | def __init__( 33 | self, 34 | sample_rate: int, 35 | fold: str, 36 | duration: float, 37 | start_offset: float, 38 | padding: bool, 39 | dequantize: bool, 40 | phase_mangle: bool, 41 | random_crop: bool, 42 | transforms: Optional[transforms.Compose] = None, 43 | ): 44 | 45 | self._recordings: list[Path] 46 | 47 | assert 8000 <= sample_rate <= 48000, f"Sample rate {sample_rate} is out of expected range of 8-48 kHz" 48 | 49 | self.sample_rate = sample_rate 50 | self.duration = duration 51 | self.total_samples = int(duration * sample_rate) 52 | self.start_offset = int(start_offset * sample_rate) 53 | self.padding = padding 54 | self.dequantize = dequantize 55 | self.phase_mangle = phase_mangle 56 | self.random_crop = random_crop 57 | 58 | info_bold(f"Sample rate: {sample_rate}, duration: {duration:.2f} --> N_SIGNAL = {self.total_samples}") 59 | 60 | assert fold in ["train", "validation"] 61 | self.fold = fold 62 | 63 | self.transforms = transforms or self.default_transforms() 64 | 65 | self.suffixes = f"-{sample_rate}-{self.total_samples}" 66 | self.suffixes = self.suffixes + ("p" if self.padding else "") 67 | self.suffixes = self.suffixes + (f"s{int(self.start_offset)}" if self.start_offset else "") 68 | 69 | @staticmethod 70 | def _save_recording( 71 | recording_src: Path, 72 | audio_dst_dir: Path, 73 | sample_rate: int, 74 | segment_len: int, 75 | start_offset: int, 76 | padding: bool, 77 | ): 78 | signal, _ = librosa.load(str(recording_src), sr=sample_rate) # type: ignore 79 | signal = cast(npt.NDArray[np.float32], signal) 80 | 81 | signal = signal[start_offset:] 82 | 83 | if padding: 84 | pad = segment_len - len(signal) % segment_len 85 | signal = cast(npt.NDArray[np.float32], np.pad(signal, (0, pad))) # type: ignore 86 | else: 87 | signal = signal[: len(signal) - len(signal) % segment_len] 88 | 89 | signal = signal.reshape(-1, segment_len) 90 | 91 | for i, y in enumerate(signal): 92 | dst_name = recording_src.with_stem(f"{recording_src.stem}-{i + 1}").with_suffix(".wav").name 93 | sf.write( # type: ignore 94 | str(audio_dst_dir / dst_name), 95 | y, 96 | sample_rate, 97 | format="wav", 98 | subtype="PCM_16", 99 | ) 100 | 101 | def __len__(self): 102 | return len(self._recordings) 103 | 104 | def __getitem__(self, idx: int) -> tuple[Tensor, int]: 105 | signal, sample_rate = cast(tuple[Tensor, int], torchaudio.load(str(self._recordings[idx]))) # type: ignore 106 | if self.transforms is not None: 107 | signal = cast(Tensor, self.transforms(signal)) 108 | return signal, sample_rate 109 | 110 | def process_recordings_dir( 111 | self, 112 | src_dir: Path, 113 | dst_dir: Path, 114 | original_audio_ext: str, 115 | validation_split: Optional[float] = None, 116 | val_speaker_split: Optional[int] = None, 117 | ): 118 | dst_dir = Path(str(dst_dir) + self.suffixes) 119 | 120 | if not dst_dir.is_dir(): 121 | info_bold("Preparing processed version of the dataset...") 122 | p = Pool(PROCESSING_BATCH_SIZE) 123 | 124 | dst_dir.mkdir() 125 | 126 | with Progress() as progress: 127 | 128 | def process_recordings(recordings: list[Path], fold: str): 129 | (dst_dir / fold).mkdir(exist_ok=True) 130 | 131 | progress.update( 132 | progress_recording, 133 | total=len(recordings), 134 | completed=0, 135 | description=f"[red]Processing {fold} recordings...", 136 | ) 137 | 138 | _save = partial( 139 | 
self._save_recording, 140 | audio_dst_dir=dst_dir / fold, 141 | sample_rate=self.sample_rate, 142 | segment_len=self.total_samples * 2, 143 | start_offset=self.start_offset, 144 | padding=self.padding, 145 | ) 146 | 147 | for recording_batch in chunked(recordings, PROCESSING_BATCH_SIZE): 148 | p.map(_save, recording_batch) # type: ignore 149 | progress.update(progress_recording, advance=len(recording_batch)) 150 | 151 | if validation_split: 152 | recordings = sorted([path for path in src_dir.glob(f"*{original_audio_ext}") if not path.is_dir()]) 153 | random.Random(1).shuffle(recordings) 154 | 155 | progress_recording = progress.add_task("[red]Processing recordings...", total=len(recordings)) 156 | 157 | val_size = int(validation_split * len(recordings)) 158 | recordings_train = recordings[:-val_size] 159 | recordings_val = recordings[-val_size:] 160 | 161 | process_recordings(recordings_train, "train") 162 | process_recordings(recordings_val, "validation") 163 | elif val_speaker_split: 164 | speakers = sorted([path for path in src_dir.iterdir() if path.is_dir()]) 165 | 166 | progress_speaker = progress.add_task("[red]Processing speakers...", total=len(speakers)) 167 | progress_recording = progress.add_task("[yellow]Resampling speaker recordings...", total=0) 168 | 169 | speakers_train = speakers[:-val_speaker_split] 170 | speakers_val = speakers[-val_speaker_split:] 171 | 172 | for speaker in speakers_train: 173 | process_recordings(list(speaker.rglob(f"*{original_audio_ext}")), "train") 174 | progress.update(progress_speaker, advance=1) 175 | for speaker in speakers_val: 176 | process_recordings(list(speaker.rglob(f"*{original_audio_ext}")), "validation") 177 | progress.update(progress_speaker, advance=1) 178 | else: 179 | raise RuntimeError("Both `validation_split` and `val_speaker_split` are unspecified.") 180 | 181 | p.close() # type: ignore 182 | 183 | files = list((dst_dir / self.fold).glob("*.wav")) 184 | files.sort() 185 | self._recordings = files 186 | 187 | def default_transforms(self): 188 | _transforms: list[Transform] = [] 189 | 190 | if self.fold == "validation": 191 | _transforms.append(RandomCrop(self.total_samples, random=False)) 192 | else: 193 | _transforms.append(RandomCrop(self.total_samples, random=self.random_crop)) 194 | if self.phase_mangle: 195 | _transforms.append( 196 | RandomApply( 197 | RandomPhaseMangle( 198 | min_f=20, 199 | max_f=2000, 200 | amplitude=0.99, 201 | sample_rate=self.sample_rate, 202 | ), 203 | p=0.8, 204 | ) 205 | ) 206 | 207 | if self.dequantize: 208 | _transforms.append(Dequantize(16)) 209 | 210 | return transforms.Compose(_transforms) 211 | -------------------------------------------------------------------------------- /hypersound/models/encoder.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from typing import Any, Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | 9 | # https://github.com/wesbz/SoundStream/blob/main/net.py 10 | 11 | 12 | # Signal encoder 13 | # -------------- 14 | 15 | 16 | class CausalConv1d(nn.Conv1d): 17 | def __init__(self, *args: Any, **kwargs: Any): 18 | super().__init__(*args, **kwargs) # type: ignore 19 | self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1) 20 | 21 | def forward(self, x: Tensor): # type: ignore 22 | return self._conv_forward(F.pad(x, [self.causal_padding, 0]), self.weight, self.bias) 23 | 24 | 25 | class 
CausalConvTranspose1d(nn.ConvTranspose1d): 26 | def __init__(self, adjust_padding: int, *args: Any, **kwargs: Any): 27 | super().__init__(*args, **kwargs) # type: ignore 28 | self.adjust_padding = adjust_padding 29 | self.causal_padding = ( 30 | self.dilation[0] * (self.kernel_size[0] - 1) 31 | + self.output_padding[0] 32 | + 1 33 | - self.stride[0] 34 | + self.adjust_padding 35 | ) 36 | 37 | def forward(self, x: Tensor, output_size: Optional[list[int]] = None): # type: ignore 38 | if self.padding_mode != "zeros": 39 | raise ValueError("Only `zeros` padding mode is supported for ConvTranspose1d") 40 | 41 | assert isinstance(self.padding, tuple) 42 | output_padding = self._output_padding( 43 | x, output_size, self.stride, self.padding, self.kernel_size, self.dilation # type: ignore 44 | ) 45 | return F.conv_transpose1d( 46 | x, self.weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation 47 | )[..., : -self.causal_padding] 48 | 49 | 50 | class ResidualUnit(nn.Module): 51 | def __init__(self, in_channels: int, out_channels: int, dilation: int, layer_norm_size: Optional[int] = None): 52 | super().__init__() 53 | 54 | self.dilation = dilation 55 | 56 | self.layers = nn.Sequential( 57 | # SEANet: first kernel_size=3 58 | CausalConv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, dilation=dilation), 59 | nn.LayerNorm([out_channels, layer_norm_size]) if layer_norm_size else nn.Identity(), 60 | nn.ELU(), 61 | nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1), 62 | nn.LayerNorm([out_channels, layer_norm_size]) if layer_norm_size else nn.Identity(), 63 | nn.ELU(), 64 | ) 65 | 66 | def forward(self, x: Tensor): # type: ignore 67 | return x + self.layers(x) 68 | 69 | 70 | class EncoderBlock(nn.Module): 71 | def __init__(self, channels: int, stride: int, layer_norm_size: Optional[int] = None): 72 | super().__init__() 73 | 74 | self.layers = nn.Sequential( 75 | ResidualUnit( 76 | in_channels=channels // 2, out_channels=channels // 2, dilation=1, layer_norm_size=layer_norm_size 77 | ), 78 | ResidualUnit( 79 | in_channels=channels // 2, out_channels=channels // 2, dilation=3, layer_norm_size=layer_norm_size 80 | ), 81 | ResidualUnit( 82 | in_channels=channels // 2, out_channels=channels // 2, dilation=9, layer_norm_size=layer_norm_size 83 | ), 84 | CausalConv1d(in_channels=channels // 2, out_channels=channels, kernel_size=2 * stride, stride=stride), 85 | nn.LayerNorm([channels, ceil(layer_norm_size / stride)]) if layer_norm_size else nn.Identity(), 86 | nn.ELU(), 87 | ) 88 | 89 | def forward(self, x: Tensor): # type: ignore 90 | return self.layers(x) 91 | 92 | 93 | class Encoder(nn.Module): 94 | """Audio encoder.""" 95 | 96 | def __init__( 97 | self, 98 | C: int, # SEANET: 32 99 | D: int, 100 | layer_norm: bool, 101 | **kwargs: Any, 102 | ): 103 | super().__init__() 104 | 105 | self.net = nn.Sequential( 106 | CausalConv1d(in_channels=1, out_channels=C, kernel_size=7), # 107 | nn.LayerNorm([16, 32768]) if layer_norm else nn.Identity(), # 16x32768 108 | nn.ELU(), 109 | EncoderBlock( 110 | channels=2 * C, stride=2, layer_norm_size=32768 if layer_norm else None 111 | ), # SEANet: Stride 2 # AudioGen 2 112 | EncoderBlock( 113 | channels=4 * C, stride=4, layer_norm_size=16384 if layer_norm else None 114 | ), # SEANet: Stride 2 # AudioGen 2 115 | EncoderBlock( 116 | channels=8 * C, stride=5, layer_norm_size=4096 if layer_norm else None 117 | ), # SEANet: Stride 8 # AudioGen 2 118 | EncoderBlock( 119 | channels=16 * C, stride=8, 
layer_norm_size=820 if layer_norm else None 120 | ), # SEANet: Stride 8 # AudioGen 4 121 | CausalConv1d( 122 | in_channels=16 * C, out_channels=D, kernel_size=3 123 | ), # SEANet: kernel_size=7, channels=128 # AudioGen: 7 124 | ) 125 | 126 | def __call__(self, x: Tensor) -> Tensor: # type: ignore 127 | return super().__call__(x) 128 | 129 | def forward(self, x: Tensor) -> Tensor: # type: ignore 130 | x = self.net(x.unsqueeze(1)) 131 | return x.flatten(1, 2) 132 | 133 | def output_width(self, input_length: Any) -> int: 134 | return ceil(input_length / (2 * 4 * 5 * 8)) 135 | 136 | 137 | # STFT discriminator 138 | # ---------------------- 139 | 140 | 141 | class ResidualUnit2d(nn.Module): 142 | def __init__(self, in_channels: int, N: int, m: int, s_t: int, s_f: int): 143 | super().__init__() 144 | 145 | self.s_t = s_t 146 | self.s_f = s_f 147 | 148 | self.layers = nn.Sequential( 149 | nn.Conv2d(in_channels=in_channels, out_channels=N, kernel_size=(3, 3), padding="same"), 150 | nn.ELU(), 151 | nn.Conv2d(in_channels=N, out_channels=m * N, kernel_size=(s_f + 2, s_t + 2), stride=(s_f, s_t)), 152 | ) 153 | 154 | self.skip_connection = nn.Conv2d( 155 | in_channels=in_channels, out_channels=m * N, kernel_size=(1, 1), stride=(s_f, s_t) 156 | ) 157 | 158 | def forward(self, x: Tensor) -> Tensor: # type: ignore 159 | return self.layers(F.pad(x, [self.s_t + 1, 0, self.s_f + 1, 0])) + self.skip_connection(x) 160 | 161 | 162 | class STFTDiscriminator(nn.Module): 163 | """STFT discriminator.""" 164 | 165 | def __init__( 166 | self, 167 | C: int, 168 | win_length: int, 169 | hop_size: int, 170 | **kwargs: Any, 171 | ): 172 | super().__init__() 173 | 174 | self.win_length = win_length 175 | self.hop_size = hop_size 176 | 177 | self.window: Tensor 178 | self.register_buffer("window", torch.hann_window(self.win_length)) 179 | 180 | frequency_bins = self.win_length // 2 181 | 182 | self.layers = nn.ModuleList( 183 | [ 184 | nn.Sequential(nn.Conv2d(in_channels=2, out_channels=32, kernel_size=(7, 7)), nn.ELU()), 185 | nn.Sequential(ResidualUnit2d(in_channels=32, N=C, m=2, s_t=1, s_f=2), nn.ELU()), 186 | nn.Sequential(ResidualUnit2d(in_channels=2 * C, N=2 * C, m=2, s_t=2, s_f=2), nn.ELU()), 187 | nn.Sequential(ResidualUnit2d(in_channels=4 * C, N=4 * C, m=1, s_t=1, s_f=2), nn.ELU()), 188 | nn.Sequential(ResidualUnit2d(in_channels=4 * C, N=4 * C, m=2, s_t=2, s_f=2), nn.ELU()), 189 | nn.Sequential(ResidualUnit2d(in_channels=8 * C, N=8 * C, m=1, s_t=1, s_f=2), nn.ELU()), 190 | nn.Sequential(ResidualUnit2d(in_channels=8 * C, N=8 * C, m=2, s_t=2, s_f=2), nn.ELU()), 191 | nn.Conv2d(in_channels=16 * C, out_channels=1, kernel_size=(frequency_bins // 2**6, 1)), 192 | ] 193 | ) 194 | 195 | def __call__(self, x: Tensor) -> Tensor: # type: ignore 196 | return super().__call__(x) 197 | 198 | # def features_lengths(self, lengths): 199 | # return [ 200 | # lengths-6, 201 | # lengths-6, 202 | # torch.div(lengths-5, 2, rounding_mode="floor"), 203 | # torch.div(lengths-5, 2, rounding_mode="floor"), 204 | # torch.div(lengths-3, 4, rounding_mode="floor"), 205 | # torch.div(lengths-3, 4, rounding_mode="floor"), 206 | # torch.div(lengths+1, 8, rounding_mode="floor"), 207 | # torch.div(lengths+1, 8, rounding_mode="floor") 208 | # ] 209 | 210 | def forward(self, x: Tensor) -> list[Tensor]: # type: ignore 211 | 212 | x = torch.stft( 213 | x, 214 | self.win_length, 215 | self.hop_size, 216 | self.win_length, 217 | self.window, 218 | return_complex=False, 219 | ).permute(0, 3, 1, 2) 220 | 221 | feature_map: list[Tensor] = [] 222 | for layer 
in self.layers: 223 | x = layer(x) 224 | feature_map.append(x) 225 | # print(x.shape) # DEBUG 226 | return feature_map 227 | -------------------------------------------------------------------------------- /inr/systems/main.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from typing import Any, Iterable, Optional, Union, cast 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim 8 | import torch.optim.lr_scheduler 9 | import wandb 10 | from pytorch_yard.lightning import LightningModuleWithWandb 11 | from torch import Tensor 12 | from torch.nn import L1Loss 13 | 14 | # isort: split 15 | 16 | from hypersound.cfg import ModelType, Settings, TargetNetworkMode 17 | from hypersound.datasets.audio import IndicesAudioAndSpectrogram 18 | from hypersound.models.meta.inr import INR 19 | from hypersound.models.nerf import NERF 20 | from hypersound.models.siren import SIREN 21 | from hypersound.systems.loss import MultiSTFTLoss 22 | from hypersound.utils.metrics import METRICS, compute_metrics 23 | from hypersound.utils.wandb import fig_to_wandb 24 | 25 | 26 | class INRSystem(LightningModuleWithWandb): 27 | """Functional representation of audio data learned example by example.""" 28 | 29 | def __init__(self, cfg: Settings, spec_transform: nn.ModuleList, idx: int, extended_logging: bool): 30 | super().__init__() 31 | self.save_hyperparameters() # type: ignore 32 | 33 | self.cfg = cfg 34 | """Main experiment config.""" 35 | 36 | self.spec_transform = spec_transform 37 | """Raw signal to spectrogram transform provided as an `nn.ModuleList`.""" 38 | 39 | self.inr: Union[NERF, SIREN] 40 | """INR network""" 41 | 42 | self.idx = idx 43 | """Example index, used for logging.""" 44 | 45 | self.extended_logging = extended_logging 46 | """If True, will log reconstructions obtained with this system""" 47 | 48 | # ------------------------------------------------------------------------------------------ 49 | 50 | if self.cfg.model.type is ModelType.SIREN: 51 | self.inr = SIREN( 52 | input_size=1, 53 | output_size=1, 54 | hidden_sizes=self.cfg.model.target_network_layer_sizes, 55 | bias=True, 56 | mode=TargetNetworkMode.INR, 57 | omega_0=self.cfg.model.target_network_omega_0, 58 | omega_i=self.cfg.model.target_network_omega_i, 59 | learnable_omega=self.cfg.model.target_network_learnable_omega, 60 | gradient_fix=self.cfg.model.target_network_siren_gradient_fix, 61 | ) 62 | 63 | elif self.cfg.model.type is ModelType.NERF: 64 | self.inr = NERF( 65 | input_size=1, 66 | output_size=1, 67 | hidden_sizes=self.cfg.model.target_network_layer_sizes, 68 | bias=True, 69 | mode=TargetNetworkMode.INR, 70 | encoding_length=self.cfg.model.target_network_encoding_length, 71 | learnable_encoding=self.cfg.model.target_network_learnable_encoding, 72 | ) 73 | else: 74 | raise ValueError(f"Unknown type of INR: {self.cfg.model.type}") 75 | 76 | self.reconstruction_loss = L1Loss() 77 | 78 | self.perceptual_loss = MultiSTFTLoss( 79 | sample_rate=self.cfg.data.sample_rate, 80 | fft_sizes=self.cfg.model.fft_sizes, 81 | hop_sizes=self.cfg.model.hop_sizes, 82 | win_lengths=self.cfg.model.win_lengths, 83 | n_bins=self.cfg.data.n_mels, 84 | freq_weights_warmup_epochs=cfg.model.perceptual_loss_freq_weights_warmup_epochs, 85 | freq_weights_p=self.cfg.model.perceptual_loss_freq_weights_p 86 | ) 87 | 88 | self.data: Optional[IndicesAudioAndSpectrogram] = None 89 | self.metrics: dict[str, Tensor] = {} 90 | 91 | self._activations: 
list[tuple[Tensor, Tensor]] 92 | 93 | def compression_ratio(self) -> float: 94 | if self.data is not None: 95 | indices, _, _ = self.data 96 | else: 97 | raise ValueError("You must run at least one training step to get the compression rate.") 98 | num_params = cast(INR, self.inr).num_params() 99 | num_samples = indices.numel() 100 | return num_samples / num_params 101 | 102 | def on_fit_start(self) -> None: 103 | super().on_fit_start() 104 | 105 | self._log_inr(init_log=True) 106 | 107 | def forward( # type: ignore 108 | self, 109 | indices: Tensor, 110 | audio: Tensor, 111 | spectrogram: Tensor, 112 | log_reconstructions: bool = False, 113 | ) -> tuple[dict[str, Tensor], Tensor]: 114 | """Reconstruct a spectrogram with. 115 | 116 | Optionally, log examples to wandb. 117 | 118 | Parameters 119 | ---------- 120 | indices : Tensor 121 | Time indices. 122 | audio : Tensor 123 | Audio samples at given indices. 124 | spectrogram: Tensor 125 | Spectrogram data (full length). 126 | log_reconstructions: Whether to log reconstructions. 127 | Returns 128 | ------- 129 | dict[str, Tensor] 130 | Loss values. 131 | 132 | Tensor 133 | Audio reconstructions. 134 | 135 | """ 136 | # Audio shape should be (1, num_samples) 137 | assert len(audio.shape) == 2 138 | assert audio.shape[0] == 1 139 | assert audio.shape[1] > 1 140 | 141 | # Indices shape should be (num_samples, 1) 142 | assert len(indices.shape) == 3 143 | assert indices.shape[0] == 1 144 | indices = indices.squeeze(dim=0) 145 | assert indices.shape[0] == audio.shape[1] 146 | assert indices.shape[1] == 1 147 | 148 | # Spectrogram shape should be (1, num_mels, num_bins) 149 | assert len(spectrogram.shape) == 4 150 | assert spectrogram.shape[0] == 1 151 | spectrogram = spectrogram.squeeze(dim=0) 152 | assert spectrogram.shape[0] == 1 153 | assert spectrogram.shape[1] > 1 154 | assert spectrogram.shape[2] > 1 155 | 156 | if log_reconstructions: 157 | audio_reconstruction, self._activations = self.inr(indices, return_activations=True) 158 | else: 159 | audio_reconstruction = self.inr(indices) 160 | 161 | audio_reconstruction = audio_reconstruction.squeeze(dim=1).unsqueeze(dim=0) 162 | assert audio_reconstruction.shape == audio.shape 163 | 164 | spectrogram_reconstruction = self.to_spectrogram(audio_reconstruction) 165 | assert spectrogram_reconstruction.shape == spectrogram.shape 166 | 167 | reconstruction_loss = self.reconstruction_loss(audio, audio_reconstruction) 168 | perceptual_loss = self.perceptual_loss(audio, audio_reconstruction) 169 | 170 | if log_reconstructions: 171 | self.log_reconstructions( 172 | audio=audio, 173 | audio_reconstructions=audio_reconstruction, 174 | spectrograms=spectrogram, 175 | spectrogram_reconstructions=spectrogram_reconstruction, 176 | ) 177 | 178 | total_loss = ( 179 | self.cfg.model.reconstruction_loss_lambda * reconstruction_loss 180 | + self.cfg.model.perceptual_loss_lambda * perceptual_loss 181 | ) 182 | 183 | return ( 184 | dict( 185 | loss=total_loss, 186 | reconstruction=reconstruction_loss.detach(), 187 | perceptual=perceptual_loss.detach(), 188 | ), 189 | audio_reconstruction, 190 | ) 191 | 192 | def _on_epoch_end(self, outputs: list[dict[str, Tensor]]) -> None: # type: ignore 193 | metrics = { 194 | f"loss/total/r{self.idx}": outputs[-1]["loss"], 195 | f"loss/reconstruction/r{self.idx}": outputs[-1]["reconstruction"], 196 | f"loss/perceptual/r{self.idx}": outputs[-1]["perceptual"], 197 | "epoch": self.epoch, 198 | } 199 | 200 | for metric in METRICS: 201 | if metric in outputs[-1]: 202 | 
metrics[f"metric/{metric}/r{self.idx}"] = outputs[-1][metric] 203 | 204 | self.log_wandb( 205 | metrics, 206 | ) 207 | 208 | if self.extended_logging: 209 | with self.no_train(): 210 | self.forward(*cast(IndicesAudioAndSpectrogram, self.data), log_reconstructions=True) 211 | 212 | # Training 213 | # ---------------------------------------------------------------------------------------------- 214 | def training_step(self, batch: IndicesAudioAndSpectrogram, batch_idx: int) -> dict[str, Tensor]: # type: ignore 215 | if self.data is None: 216 | self.data = batch 217 | 218 | is_last = batch_idx == cast(Any, self.trainer).num_training_batches - 1 219 | 220 | loss, reconstructions = self.forward(*batch) 221 | 222 | metrics = compute_metrics( 223 | reconstructions, 224 | batch[1], 225 | sample_rate=self.cfg.data.sample_rate, 226 | pesq=is_last, 227 | stoi=is_last, 228 | cdpam=is_last, 229 | ) 230 | loss.update(metrics) 231 | 232 | self.metrics = loss 233 | 234 | return loss 235 | 236 | def training_epoch_end(self, outputs: list[dict[str, Tensor]]) -> None: # type: ignore 237 | self._on_epoch_end(outputs=outputs) 238 | 239 | # Optimizers 240 | # ---------------------------------------------------------------------------------------------- 241 | def configure_optimizers(self): # type: ignore 242 | # Main model optimization 243 | optimizer_main = torch.optim.Adam( 244 | self.inr.parameters(), 245 | lr=self.cfg.learning_rate, 246 | ) 247 | 248 | scheduler_ = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer_main, lr_lambda=lambda _: 1.0) 249 | 250 | optimizers = [optimizer_main] 251 | schedulers = [scheduler_] 252 | 253 | return optimizers, schedulers # type: ignore 254 | 255 | # Helpers 256 | # ---------------------------------------------------------------------------------------------- 257 | def to_spectrogram(self, y: Tensor) -> Tensor: 258 | """Convert raw audio signal to spectrogram using the pre-defined transform. 259 | 260 | Parameters 261 | ---------- 262 | y : Tensor 263 | Audio signal. 264 | 265 | Returns 266 | ------- 267 | Tensor 268 | Spectrogram. 269 | 270 | """ 271 | for transform in self.spec_transform: 272 | y = transform(y) 273 | return y 274 | 275 | # Visualizations 276 | # ---------------------------------------------------------------------------------------------- 277 | def log_reconstructions( 278 | self, 279 | audio: Tensor, 280 | audio_reconstructions: Tensor, 281 | spectrograms: Tensor, 282 | spectrogram_reconstructions: Tensor, 283 | ) -> None: 284 | """Log reconstruction samples to wandb. 
285 | 286 | Parameters 287 | ---------- 288 | audio : Tensor 289 | [description] 290 | audio_reconstructions : Tensor 291 | [description] 292 | spectrograms : Tensor 293 | [description] 294 | spectrogram_reconstructions : Tensor 295 | [description] 296 | 297 | """ 298 | self._log_as_imgs( 299 | audio, 300 | audio_reconstructions, 301 | spectrograms, 302 | spectrogram_reconstructions, 303 | ) 304 | self._log_audio(audio, audio_reconstructions) 305 | self._log_inr() 306 | 307 | def _log_as_imgs( 308 | self, 309 | audio: Tensor, 310 | audio_reconstructions: Tensor, 311 | spectrograms: Tensor, 312 | spectrogram_reconstructions: Tensor, 313 | ): 314 | assert len(audio) == len(audio_reconstructions) == len(spectrograms) == len(spectrogram_reconstructions) 315 | 316 | with plt.style.context("bmh"): # type: ignore 317 | fig, axes = plt.subplots( # type: ignore 318 | nrows=4, ncols=len(audio), figsize=(3 * len(audio), 6), squeeze=False 319 | ) 320 | fig.subplots_adjust(hspace=0.01, wspace=0.01) 321 | axes = cast(Any, axes) 322 | 323 | for i, ax in enumerate(axes[0]): 324 | ax.imshow(spectrograms[i].squeeze().detach().cpu().numpy(), origin="lower", aspect="auto") 325 | ax.set_axis_off() 326 | for i, ax in enumerate(axes[1]): 327 | ax.imshow( 328 | spectrogram_reconstructions[i].squeeze().detach().cpu().numpy(), origin="lower", aspect="auto" 329 | ) 330 | ax.set_axis_off() 331 | for i, ax in enumerate(axes[2]): 332 | ax.plot(audio[i].detach().cpu().numpy()) 333 | ax.set_ylim(-1, 1) 334 | ax.xaxis.set_visible(False) 335 | ax.yaxis.set_visible(False) 336 | for i, ax in enumerate(axes[3]): 337 | ax.plot(audio_reconstructions[i].detach().cpu().numpy()) 338 | ax.set_ylim(-1, 1) 339 | ax.xaxis.set_visible(False) 340 | ax.yaxis.set_visible(False) 341 | 342 | self.log_wandb( 343 | { 344 | f"reconstruction/r{self.idx}": fig_to_wandb(fig), 345 | "epoch": self.epoch, 346 | } 347 | ) 348 | plt.close("all") # type: ignore 349 | gc.collect() 350 | 351 | def _log_audio( 352 | self, 353 | audio: Tensor, 354 | audio_reconstructions: Tensor, 355 | ): 356 | for i, (_audio, _reconstruction) in enumerate( 357 | zip( 358 | cast(Iterable[Tensor], audio), 359 | cast(Iterable[Tensor], audio_reconstructions), 360 | ) 361 | ): 362 | self.log_wandb( 363 | { 364 | f"audio/in/{i}/r{self.idx}": wandb.Audio( 365 | _audio.detach().cpu().numpy(), sample_rate=self.cfg.data.sample_rate 366 | ), 367 | "epoch": self.epoch, 368 | }, 369 | ) 370 | self.log_wandb( 371 | { 372 | f"audio/out/{i}/r{self.idx}": wandb.Audio( 373 | _reconstruction.detach().cpu().numpy(), sample_rate=self.cfg.data.sample_rate 374 | ), 375 | "epoch": self.epoch, 376 | }, 377 | ) 378 | 379 | def _log_inr(self, init_log: bool = False): 380 | for layer in self.inr.params: 381 | if layer[0] not in ["w", "b"]: 382 | continue 383 | self.log_wandb( 384 | { 385 | f"layer/{layer}/r{self.idx}": wandb.Histogram(self.inr.params[layer].detach().cpu().numpy()), 386 | "epoch": self.epoch if not init_log else 0, 387 | }, 388 | ) 389 | 390 | if not init_log: 391 | for i, (activations, preactivations) in enumerate(self._activations): 392 | self.log_wandb( 393 | { 394 | f"activations/{i}/r{self.idx}": wandb.Histogram(activations.detach().cpu().numpy()), 395 | f"preactivations/{i}/r{self.idx}": wandb.Histogram(preactivations.detach().cpu().numpy()), 396 | "epoch": self.epoch, 397 | }, 398 | ) 399 | --------------------------------------------------------------------------------
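As a closing orientation aid, here is a minimal, hypothetical shape check for the Encoder defined in hypersound/models/encoder.py above. The values C=16, D=128, and the 32768-sample input length are illustrative assumptions (suggested only by the inline comments), not the project's configured defaults.

    import torch

    from hypersound.models.encoder import Encoder

    # Sketch only: C, D, and the input length are assumptions, not project defaults.
    encoder = Encoder(C=16, D=128, layer_norm=False)

    waveform = torch.randn(2, 32768)      # (batch, samples); total encoder stride is 2 * 4 * 5 * 8 = 320
    latent = encoder(waveform)            # forward() adds a channel dim, then flattens (D, width) into D * width

    print(encoder.output_width(32768))    # ceil(32768 / 320) == 103
    print(latent.shape)                   # (2, 128 * 103) == (2, 13184)

The same stride arithmetic accounts for the hard-coded layer_norm_size values (32768, 16384, 4096, 820) handed to the EncoderBlocks: each is the sequence length entering that block, and ceil(layer_norm_size / stride) inside the block is the length leaving it.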