├── scripts ├── __init__.py ├── compress.py ├── utils.py ├── test.py ├── metrics.py ├── trainer_no_adv.py └── trainer_adv.py ├── esc ├── modules │ ├── loss │ │ ├── __init__.py │ │ ├── gan_loss.py │ │ └── generator_loss.py │ ├── vq │ │ ├── __init__.py │ │ ├── initialize.py │ │ └── codebook.py │ ├── transformer │ │ ├── __init__.py │ │ ├── scale.py │ │ └── attention.py │ ├── __init__.py │ └── convolution │ │ └── layers.py ├── __init__.py └── models │ ├── __init__.py │ ├── utils.py │ ├── discriminator.py │ ├── csrvq.py │ ├── codecs.py │ └── base.py ├── baselines └── descript │ ├── dac │ ├── compare │ │ ├── __init__.py │ │ └── encodec.py │ ├── nn │ │ ├── __init__.py │ │ ├── layers.py │ │ ├── quantize.py │ │ └── loss.py │ ├── model │ │ ├── __init__.py │ │ ├── discriminator.py │ │ ├── base.py │ │ └── dac.py │ ├── __init__.py │ ├── __main__.py │ └── utils │ │ ├── decode.py │ │ ├── encode.py │ │ └── __init__.py │ ├── README.md │ ├── conf │ ├── descript_6k_final.yml │ ├── 16khz_dns_9k.yml │ └── 16khz_dns_9k_tiny.yml │ └── scripts │ └── train_customize_no_adv.py ├── assets ├── results.png └── architecture.png ├── requirements.txt ├── configs ├── 9kbps_esc_base.yaml ├── ablations │ ├── 9kbps_csvq_conv.yaml │ ├── 9kbps_rvq_conv.yaml │ ├── 9kbps_csvq_swinT.yaml │ └── 9kbps_rvq_swinT.yaml ├── 9kbps_esc_large.yaml └── 9kbps_esc_base_adv.yaml ├── LICENSE ├── main.py ├── .gitignore ├── scripts_all.sh └── README.md /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /esc/modules/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /esc/modules/vq/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /esc/modules/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baselines/descript/dac/compare/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /esc/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ESC, RVQCodecs -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzGuu830/efficient-speech-codec/HEAD/assets/results.png -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzGuu830/efficient-speech-codec/HEAD/assets/architecture.png -------------------------------------------------------------------------------- /baselines/descript/dac/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from . import layers 2 | from . import loss 3 | from . 
import quantize 4 | -------------------------------------------------------------------------------- /esc/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .codecs import ESC, RVQCodecs, make_model 2 | from .discriminator import Discriminator -------------------------------------------------------------------------------- /baselines/descript/dac/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import CodecMixin 2 | from .base import DACFile 3 | from .dac import DAC 4 | from .discriminator import Discriminator 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | torchaudio==2.0.0 3 | transformers==4.36.2 4 | accelerate==0.26.1 5 | timm==0.9.12 6 | einops==0.7.0 7 | pesq==0.0.4 8 | wandb>=0.16.2 9 | git+https://github.com/descriptinc/audiotools -------------------------------------------------------------------------------- /baselines/descript/dac/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | # preserved here for legacy reasons 4 | __model_version__ = "latest" 5 | 6 | import audiotools 7 | 8 | audiotools.ml.BaseModel.INTERN += ["dac.**"] 9 | audiotools.ml.BaseModel.EXTERN += ["einops"] 10 | 11 | 12 | from . import nn 13 | from . import model 14 | from . import utils 15 | from .model import DAC 16 | from .model import DACFile 17 | -------------------------------------------------------------------------------- /esc/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer.attention import TransformerLayer 2 | from .transformer.scale import PatchEmbed, PatchDeEmbed 3 | 4 | from .vq.quantization import ProductVectorQuantize, ResidualVectorQuantize, ProductResidualVectorQuantize 5 | from .loss.generator_loss import MelSpectrogramLoss, ComplexSTFTLoss 6 | from .loss.gan_loss import GANLoss 7 | from .convolution.layers import ConvolutionLayer, Convolution2D -------------------------------------------------------------------------------- /esc/models/utils.py: -------------------------------------------------------------------------------- 1 | from ..modules import TransformerLayer, ConvolutionLayer, Convolution2D 2 | 3 | def blk_func(blk, feat, feat_shape): 4 | Wh, Ww = feat_shape 5 | if isinstance(blk, TransformerLayer): 6 | feat_next, Wh, Ww = blk(feat, Wh, Ww) 7 | elif isinstance(blk, ConvolutionLayer): 8 | feat_next = blk(feat) 9 | Wh, Ww = Wh//2, Ww 10 | elif isinstance(blk, Convolution2D): 11 | feat_next = blk(feat) 12 | 13 | return feat_next, (Wh, Ww) -------------------------------------------------------------------------------- /baselines/descript/README.md: -------------------------------------------------------------------------------- 1 | ## Descript's Audio Codec (DAC) Experimental Reproduction 2 | 3 | This folder is mostly borrowed from [Descript's Github Repository](https://github.com/descriptinc/descript-audio-codec). 4 | 5 | We adapt a few features for customized reproduction. For developmental setups, refer to the original repository. 6 | 7 | 8 | ## Reproduce DAC Baselines 9 | 10 | ```ruby 11 | torchrun --nproc_per_node gpu train_customize.py --config 16kHz_dns_9k.yml 12 | ``` 13 | This reproduces 16kHz (0.5kbps ~ 9.0kbps) DAC with adversarial setups. 
14 | 15 | ```ruby 16 | torchrun --nproc_per_node gpu train_customize_no_adv.py --config 16kHz_dns_9k_tiny.yml 17 | ``` 18 | This reproduces 16kHz (0.5kbps ~ 9.0kbps) DAC in non-adversarial setups. -------------------------------------------------------------------------------- /configs/9kbps_esc_base.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_data_path: ../data/train 3 | val_data_path: ../data/val 4 | num_workers: 36 5 | train_bs_per_device: 9 6 | val_bs_per_device: 4 7 | 8 | model_name: csvq+swinT 9 | model: 10 | backbone: transformer 11 | in_dim: 2 12 | in_freq: 192 13 | h_dims: [45,72,96,144,192,384] 14 | max_streams: 6 15 | win_len: 20 16 | hop_len: 5 17 | sr: 16000 18 | patch_size: [3,2] 19 | swin_heads: [3,6,12,24,24] 20 | swin_depth: 2 21 | window_size: 4 22 | mlp_ratio: 4. 23 | overlap: 2 24 | group_size: 3 25 | codebook_size: 1024 26 | codebook_dims: [32,32,16,12,8,6] 27 | l2norm: True 28 | 29 | loss: 30 | stft_weight: 1.0 31 | cm_weight: .25 32 | cb_weight: 1.0 33 | mel_weight: .25 -------------------------------------------------------------------------------- /configs/ablations/9kbps_csvq_conv.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_data_path: ../data/dnscustom/processed_wav/train 3 | val_data_path: ../data/dnscustom/processed_wav/test 4 | num_workers: 36 5 | train_bs_per_device: 9 6 | val_bs_per_device: 8 7 | 8 | model_name: csvq+conv 9 | model: 10 | backbone: convolution 11 | in_dim: 2 12 | in_freq: 192 13 | h_dims: [45,72,96,144,192,384] 14 | max_streams: 6 15 | kernel_size: [5,2] 16 | patch_size: [3,2] 17 | conv_depth: 1 18 | overlap: 2 19 | group_size: 3 20 | codebook_size: 1024 21 | codebook_dim: [8,8,8,8,8,8] 22 | l2norm: True 23 | win_len: 20 24 | hop_len: 5 25 | sr: 16000 26 | 27 | loss: 28 | stft_weight: 1.0 29 | cm_weight: .25 30 | cb_weight: 1.0 31 | mel_weight: .25 -------------------------------------------------------------------------------- /configs/ablations/9kbps_rvq_conv.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_data_path: ../data/dnscustom/processed_wav/train 3 | val_data_path: ../data/dnscustom/processed_wav/test 4 | num_workers: 36 5 | train_bs_per_device: 9 6 | val_bs_per_device: 8 7 | 8 | model_name: rvq+conv 9 | model: 10 | backbone: convolution 11 | in_dim: 2 12 | in_freq: 192 13 | h_dims: [45,72,96,144,192,384] 14 | max_streams: 6 15 | kernel_size: [5,2] 16 | patch_size: [3,2] 17 | conv_depth: 1 18 | overlap: 2 19 | num_rvqs: 6 20 | group_size: 3 21 | codebook_size: 1024 22 | codebook_dim: [8,8,8,8,8,8] 23 | l2norm: True 24 | win_len: 20 25 | hop_len: 5 26 | sr: 16000 27 | 28 | loss: 29 | stft_weight: 1.0 30 | cm_weight: .25 31 | cb_weight: 1.0 32 | mel_weight: .25 -------------------------------------------------------------------------------- /configs/9kbps_esc_large.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_data_path: ../data/dnscustom/processed_wav/train 3 | val_data_path: ../data/dnscustom/processed_wav/test 4 | num_workers: 36 5 | train_bs_per_device: 9 6 | val_bs_per_device: 6 7 | 8 | model_name: csvq+swinT 9 | model: 10 | backbone: transformer 11 | in_dim: 2 12 | in_freq: 192 13 | h_dims: [45,72,96,144,192,384] 14 | max_streams: 6 15 | patch_size: [3,2] 16 | swin_heads: [3,6,12,24,24] 17 | swin_depth: 4 18 | window_size: 4 19 | mlp_ratio: 4. 
20 | overlap: 2 21 | group_size: 3 22 | codebook_size: 1024 23 | codebook_dims: [8,8,8,8,8,8] 24 | l2norm: True 25 | win_len: 20 26 | hop_len: 5 27 | sr: 16000 28 | 29 | loss: 30 | stft_weight: 1.0 31 | cm_weight: .25 32 | cb_weight: 1.0 33 | mel_weight: .25 -------------------------------------------------------------------------------- /baselines/descript/dac/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import argbind 4 | 5 | from dac.utils import download 6 | from dac.utils.decode import decode 7 | from dac.utils.encode import encode 8 | 9 | STAGES = ["encode", "decode", "download"] 10 | 11 | 12 | def run(stage: str): 13 | """Run stages. 14 | 15 | Parameters 16 | ---------- 17 | stage : str 18 | Stage to run 19 | """ 20 | if stage not in STAGES: 21 | raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}") 22 | stage_fn = globals()[stage] 23 | 24 | if stage == "download": 25 | stage_fn() 26 | return 27 | 28 | stage_fn() 29 | 30 | 31 | if __name__ == "__main__": 32 | group = sys.argv.pop(1) 33 | args = argbind.parse_args(group=group) 34 | 35 | with argbind.scope(args): 36 | run(group) 37 | -------------------------------------------------------------------------------- /configs/ablations/9kbps_csvq_swinT.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_data_path: ../data/dnscustom/processed_wav/train 3 | val_data_path: ../data/dnscustom/processed_wav/test 4 | num_workers: 36 5 | train_bs_per_device: 9 6 | val_bs_per_device: 6 7 | 8 | model_name: csvq+swinT 9 | model: 10 | backbone: transformer 11 | in_dim: 2 12 | in_freq: 192 13 | h_dims: [45,72,96,144,192,384] 14 | max_streams: 6 15 | patch_size: [3,2] 16 | swin_heads: [3,6,12,24,24] 17 | swin_depth: 2 18 | window_size: 4 19 | mlp_ratio: 4. 20 | overlap: 2 21 | group_size: 3 22 | codebook_size: 1024 23 | codebook_dims: [8,8,8,8,8,8] 24 | l2norm: True 25 | win_len: 20 26 | hop_len: 5 27 | sr: 16000 28 | 29 | loss: 30 | stft_weight: 1.0 31 | cm_weight: .25 32 | cb_weight: 1.0 33 | mel_weight: .25 -------------------------------------------------------------------------------- /configs/ablations/9kbps_rvq_swinT.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_data_path: ../data/dnscustom/processed_wav/train 3 | val_data_path: ../data/dnscustom/processed_wav/test 4 | num_workers: 36 5 | train_bs_per_device: 18 6 | val_bs_per_device: 6 7 | 8 | model_name: rvq+swinT 9 | model: 10 | backbone: transformer 11 | in_dim: 2 12 | in_freq: 192 13 | h_dims: [45,72,96,144,192,384] 14 | max_streams: 6 15 | patch_size: [3,2] 16 | swin_heads: [3,6,12,24,24] 17 | swin_depth: 2 18 | window_size: 4 19 | mlp_ratio: 4. 
20 | overlap: 2 21 | num_rvqs: 6 22 | group_size: 3 23 | codebook_size: 1024 24 | codebook_dim: 8 25 | l2norm: True 26 | win_len: 20 27 | hop_len: 5 28 | sr: 16000 29 | 30 | loss: 31 | stft_weight: 1.0 32 | cm_weight: .25 33 | cb_weight: 1.0 34 | mel_weight: .25 -------------------------------------------------------------------------------- /baselines/descript/dac/nn/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from torch.nn.utils import weight_norm 7 | 8 | 9 | def WNConv1d(*args, **kwargs): 10 | return weight_norm(nn.Conv1d(*args, **kwargs)) 11 | 12 | 13 | def WNConvTranspose1d(*args, **kwargs): 14 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 15 | 16 | 17 | # Scripting this brings model speed up 1.4x 18 | @torch.jit.script 19 | def snake(x, alpha): 20 | shape = x.shape 21 | x = x.reshape(shape[0], shape[1], -1) 22 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) 23 | x = x.reshape(shape) 24 | return x 25 | 26 | 27 | class Snake1d(nn.Module): 28 | def __init__(self, channels): 29 | super().__init__() 30 | self.alpha = nn.Parameter(torch.ones(1, channels, 1)) 31 | 32 | def forward(self, x): 33 | return snake(x, self.alpha) 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2024] [Yuzhe Gu] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /configs/9kbps_esc_base_adv.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_data_path: ../data/train 3 | val_data_path: ../data/val 4 | num_workers: 36 5 | train_bs_per_device: 9 6 | val_bs_per_device: 4 7 | 8 | model_name: csvq+swinT 9 | model: 10 | backbone: transformer 11 | in_dim: 2 12 | in_freq: 192 13 | h_dims: [45,72,96,144,192,384] 14 | max_streams: 6 15 | win_len: 20 16 | hop_len: 5 17 | sr: 16000 18 | patch_size: [3,2] 19 | swin_heads: [3,6,12,24,24] 20 | swin_depth: 2 21 | window_size: 4 22 | mlp_ratio: 4. 
23 | overlap: 2 24 | group_size: 3 25 | codebook_size: 1024 26 | codebook_dims: [8,8,8,8,8,8] 27 | l2norm: True 28 | 29 | discriminator: 30 | sample_rate: 16000 31 | rates: [] 32 | periods: [2, 3, 5, 7, 11] 33 | fft_sizes: [2048, 1024, 512] 34 | bands: 35 | - [0.0, 0.1] 36 | - [0.1, 0.25] 37 | - [0.25, 0.5] 38 | - [0.5, 0.75] 39 | - [0.75, 1.0] 40 | 41 | loss: 42 | stft_weight: 0.0 43 | cm_weight: .25 44 | cb_weight: 1.0 45 | mel_weight: 15.0 46 | gen_weight: 1.0 47 | feat_weight: 2.0 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from scripts.trainer_no_adv import main as train_no_adv 4 | from scripts.trainer_adv import main as train_adv 5 | from scripts.utils import read_yaml, dict2namespace 6 | 7 | def parse_args_config(): 8 | parser = argparse.ArgumentParser() 9 | 10 | # Experimental Setups 11 | parser.add_argument("--exp_name", default="esc9kbps", type=str) 12 | parser.add_argument("--wandb_project", default=None, type=str) 13 | parser.add_argument("--lr", default=1.e-4, type=float) 14 | parser.add_argument("--num_epochs", default=80, type=int) 15 | parser.add_argument("--num_pretraining_epochs", default=10, type=int) 16 | parser.add_argument("--num_devices", default=4, type=int) 17 | parser.add_argument("--num_warmup_steps", default=0, type=int) 18 | parser.add_argument("--val_metric", default="PESQ", type=str) 19 | parser.add_argument("--scheduler_type", default="constant", type=str) 20 | parser.add_argument("--dropout_rate", type=float, default=1.0) 21 | parser.add_argument("--adv_training", default=False, action="store_true") 22 | parser.add_argument("--pretrain_ckp", type=str, default=None) 23 | 24 | parser.add_argument("--log_steps", default=5, type=int) 25 | parser.add_argument("--save_path", default="./output", type=str) 26 | parser.add_argument("--config_path", default="./configs/9kbps_esc_base.yaml") 27 | parser.add_argument("--seed", default=1234, type=int) 28 | 29 | args = parser.parse_args() 30 | config = dict2namespace(read_yaml(args.config_path)) 31 | return args, config 32 | 33 | 34 | if __name__ == "__main__": 35 | args, config = parse_args_config() 36 | if args.adv_training: 37 | train_adv(args, config) 38 | else: 39 | train_no_adv(args, config) 40 | -------------------------------------------------------------------------------- /baselines/descript/dac/compare/encodec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from audiotools import AudioSignal 3 | from audiotools.ml import BaseModel 4 | from encodec import EncodecModel 5 | 6 | 7 | class Encodec(BaseModel): 8 | def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0): 9 | super().__init__() 10 | 11 | if sample_rate == 24000: 12 | self.model = EncodecModel.encodec_model_24khz() 13 | else: 14 | self.model = EncodecModel.encodec_model_48khz() 15 | self.model.set_target_bandwidth(bandwidth) 16 | self.sample_rate = 44100 17 | 18 | def forward( 19 | self, 20 | audio_data: torch.Tensor, 21 | sample_rate: int = 44100, 22 | n_quantizers: int = None, 23 | ): 24 | signal = AudioSignal(audio_data, sample_rate) 25 | signal.resample(self.model.sample_rate) 26 | recons = self.model(signal.audio_data) 27 | recons = AudioSignal(recons, self.model.sample_rate) 28 | recons.resample(sample_rate) 29 | return {"audio": recons.audio_data} 30 | 31 | 32 | if __name__ == "__main__": 33 | import numpy as np 34 | from 
functools import partial 35 | 36 | model = Encodec() 37 | 38 | for n, m in model.named_modules(): 39 | o = m.extra_repr() 40 | p = sum([np.prod(p.size()) for p in m.parameters()]) 41 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 42 | setattr(m, "extra_repr", partial(fn, o=o, p=p)) 43 | print(model) 44 | print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) 45 | 46 | length = 88200 * 2 47 | x = torch.randn(1, 1, length).to(model.device) 48 | x.requires_grad_(True) 49 | x.retain_grad() 50 | 51 | # Make a forward pass 52 | out = model(x)["audio"] 53 | 54 | print(x.shape, out.shape) 55 | -------------------------------------------------------------------------------- /baselines/descript/conf/descript_6k_final.yml: -------------------------------------------------------------------------------- 1 | # Model setup 2 | DAC: 3 | sample_rate: 16000 4 | encoder_dim: 64 5 | encoder_rates: [2, 4, 5, 8] 6 | decoder_dim: 1536 7 | decoder_rates: [8, 5, 4, 2] 8 | 9 | # Quantization 10 | n_codebooks: 12 11 | codebook_size: 1024 12 | codebook_dim: 8 13 | quantizer_dropout: 0.5 14 | 15 | # Discriminator 16 | Discriminator: 17 | sample_rate: 16000 18 | rates: [] 19 | periods: [2, 3, 5, 7, 11] 20 | fft_sizes: [2048, 1024, 512] 21 | bands: 22 | - [0.0, 0.1] 23 | - [0.1, 0.25] 24 | - [0.25, 0.5] 25 | - [0.5, 0.75] 26 | - [0.75, 1.0] 27 | 28 | # Optimization 29 | AdamW: 30 | betas: [0.8, 0.99] 31 | lr: 0.0001 32 | ExponentialLR: 33 | gamma: 0.999996 34 | 35 | amp: false 36 | val_batch_size: 16 37 | batch_size: 12 38 | device: cuda 39 | num_iters: 400000 40 | save_iters: [10000, 50000, 100000, 200000] 41 | valid_freq: 4000 42 | sample_freq: 10000 43 | num_workers: 8 44 | log_every: 5 45 | val_idx: [0, 1, 2, 3, 4, 5, 6, 7] 46 | seed: 53 47 | lambdas: 48 | mel/loss: 15.0 49 | adv/feat_loss: 2.0 50 | adv/gen_loss: 1.0 51 | vq/commitment_loss: 0.25 52 | vq/codebook_loss: 1.0 53 | 54 | # Transforms 55 | build_transform: 56 | preprocess: 57 | - Identity 58 | augment_prob: 0.0 59 | augment: 60 | - Identity 61 | postprocess: 62 | - VolumeNorm 63 | - RescaleAudio 64 | - ShiftPhase 65 | # - Identity 66 | 67 | # Loss setup 68 | MultiScaleSTFTLoss: 69 | window_lengths: [2048, 512] 70 | MelSpectrogramLoss: 71 | n_mels: [5, 10, 20, 40, 80, 160, 320] 72 | window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 73 | mel_fmin: [0, 0, 0, 0, 0, 0, 0] 74 | mel_fmax: [null, null, null, null, null, null, null] 75 | pow: 1.0 76 | clamp_eps: 1.0e-5 77 | mag_weight: 0.0 78 | 79 | save_path: /scratch/eys9/descript-audio-codec/runs/compare_study_dns/ 80 | wb_project_name: Neural_Speech_Coding 81 | wb_exp_name: DAC16k-Original -------------------------------------------------------------------------------- /baselines/descript/conf/16khz_dns_9k.yml: -------------------------------------------------------------------------------- 1 | # Model setup 2 | DAC: 3 | sample_rate: 16000 4 | encoder_dim: 64 5 | encoder_rates: [2, 4, 5, 8] 6 | decoder_dim: 1536 7 | decoder_rates: [8, 5, 4, 2] 8 | 9 | # Quantization 10 | n_codebooks: 18 11 | codebook_size: 1024 12 | codebook_dim: 8 13 | quantizer_dropout: 0.5 14 | 15 | # Discriminator 16 | Discriminator: 17 | sample_rate: 16000 18 | rates: [] 19 | periods: [2, 3, 5, 7, 11] 20 | fft_sizes: [2048, 1024, 512] 21 | bands: 22 | - [0.0, 0.1] 23 | - [0.1, 0.25] 24 | - [0.25, 0.5] 25 | - [0.5, 0.75] 26 | - [0.75, 1.0] 27 | 28 | # Optimization 29 | AdamW: 30 | betas: [0.8, 0.99] 31 | lr: 0.0001 32 | ExponentialLR: 33 | gamma: 0.999996 34 | 35 | amp: false 36 | val_batch_size: 24 37 | 
batch_size: 16 38 | device: cuda 39 | num_iters: 400000 40 | save_iters: [10000, 50000, 100000, 200000] 41 | valid_freq: 4000 42 | sample_freq: 10000 43 | num_workers: 32 44 | log_every: 5 45 | val_idx: [0, 1, 2, 3, 4, 5, 6, 7] 46 | seed: 53 47 | lambdas: 48 | mel/loss: 15.0 49 | adv/feat_loss: 2.0 50 | adv/gen_loss: 1.0 51 | vq/commitment_loss: 0.25 52 | vq/codebook_loss: 1.0 53 | 54 | # Transforms 55 | build_transform: 56 | preprocess: 57 | - Identity 58 | augment_prob: 0.0 59 | augment: 60 | - Identity 61 | postprocess: 62 | - VolumeNorm 63 | - RescaleAudio 64 | - ShiftPhase 65 | 66 | # Loss setup 67 | MultiScaleSTFTLoss: 68 | window_lengths: [2048, 512] 69 | MelSpectrogramLoss: 70 | n_mels: [5, 10, 20, 40, 80, 160, 320] 71 | window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 72 | mel_fmin: [0, 0, 0, 0, 0, 0, 0] 73 | mel_fmax: [null, null, null, null, null, null, null] 74 | pow: 1.0 75 | clamp_eps: 1.0e-5 76 | mag_weight: 0.0 77 | 78 | 79 | data_path: ../DNS_CHALLENGE/processed_wav 80 | save_path: ../dac_output/DAC16kHz_9kbps_base/ 81 | wb_project_name: Neural_Speech_Coding 82 | wb_exp_name: DAC16kHz_9kbps_base -------------------------------------------------------------------------------- /baselines/descript/conf/16khz_dns_9k_tiny.yml: -------------------------------------------------------------------------------- 1 | # Model setup 2 | DAC: 3 | sample_rate: 16000 4 | encoder_dim: 32 5 | encoder_rates: [2, 4, 5, 8] 6 | decoder_dim: 288 7 | decoder_rates: [8, 5, 4, 2] 8 | 9 | # Quantization 10 | n_codebooks: 18 11 | codebook_size: 1024 12 | codebook_dim: 8 13 | quantizer_dropout: 0.5 14 | 15 | # Discriminator 16 | Discriminator: 17 | sample_rate: 16000 18 | rates: [] 19 | periods: [2, 3, 5, 7, 11] 20 | fft_sizes: [2048, 1024, 512] 21 | bands: 22 | - [0.0, 0.1] 23 | - [0.1, 0.25] 24 | - [0.25, 0.5] 25 | - [0.5, 0.75] 26 | - [0.75, 1.0] 27 | 28 | # Optimization 29 | AdamW: 30 | betas: [0.8, 0.99] 31 | lr: 0.0001 32 | ExponentialLR: 33 | gamma: 0.999996 34 | 35 | amp: false 36 | val_batch_size: 32 37 | batch_size: 16 38 | device: cuda 39 | num_iters: 400000 40 | save_iters: [10000, 50000, 100000, 200000] 41 | valid_freq: 4000 42 | sample_freq: 10000 43 | num_workers: 32 44 | log_every: 5 45 | val_idx: [0, 1, 2, 3, 4, 5, 6, 7] 46 | seed: 53 47 | lambdas: 48 | mel/loss: 15.0 49 | adv/feat_loss: 2.0 50 | adv/gen_loss: 1.0 51 | vq/commitment_loss: 0.25 52 | vq/codebook_loss: 1.0 53 | 54 | # Transforms 55 | build_transform: 56 | preprocess: 57 | - Identity 58 | augment_prob: 0.0 59 | augment: 60 | - Identity 61 | postprocess: 62 | - VolumeNorm 63 | - RescaleAudio 64 | - ShiftPhase 65 | 66 | # Loss setup 67 | MultiScaleSTFTLoss: 68 | window_lengths: [2048, 512] 69 | MelSpectrogramLoss: 70 | n_mels: [5, 10, 20, 40, 80, 160, 320] 71 | window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 72 | mel_fmin: [0, 0, 0, 0, 0, 0, 0] 73 | mel_fmax: [null, null, null, null, null, null, null] 74 | pow: 1.0 75 | clamp_eps: 1.0e-5 76 | mag_weight: 0.0 77 | 78 | 79 | data_path: ../DNS_CHALLENGE/processed_wav 80 | save_path: ../dac_output/DAC16kHz_9kbps_tiny/ 81 | wb_project_name: Neural_Speech_Coding 82 | wb_exp_name: DAC16kHz_9kbps_tiny -------------------------------------------------------------------------------- /esc/modules/loss/gan_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class GANLoss(nn.Module): 6 | """ 7 | Computes a discriminator loss, given a discriminator on 8 | 
generated waveforms/spectrograms compared to ground truth 9 | waveforms/spectrograms. Computes the loss for both the 10 | discriminator and the generator in separate functions. 11 | 12 | Adapted from DAC https://github.com/descriptinc/descript-audio-codec/blob/main/ 13 | """ 14 | 15 | def __init__(self, discriminator): 16 | super().__init__() 17 | self.discriminator = discriminator 18 | 19 | def forward(self, fake, real): 20 | """ 21 | fake/real: audio tensor of shape [batchsize, channel, len] 22 | """ 23 | if fake.dim() == 2: fake = fake.unsqueeze(1) 24 | if real.dim() == 2: real = real.unsqueeze(1) 25 | 26 | d_fake = self.discriminator(**dict(x=fake)) 27 | d_real = self.discriminator(**dict(x=real)) 28 | return d_fake, d_real 29 | 30 | def discriminator_loss(self, fake, real): 31 | d_fake, d_real = self.forward(fake.clone().detach(), real) 32 | 33 | loss_d = 0 34 | for x_fake, x_real in zip(d_fake, d_real): 35 | loss_d += torch.mean(x_fake[-1] ** 2, dim=[1,2,3]) 36 | loss_d += torch.mean((1 - x_real[-1]) ** 2, dim=[1,2,3]) 37 | return loss_d 38 | 39 | def generator_loss(self, fake, real): 40 | d_fake, d_real = self.forward(fake, real) 41 | 42 | loss_g = 0 43 | for x_fake in d_fake: 44 | loss_g += torch.mean((1 - x_fake[-1]) ** 2, dim=[1,2,3]) 45 | 46 | loss_feature = 0 47 | 48 | for i in range(len(d_fake)): 49 | for j in range(len(d_fake[i]) - 1): 50 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach(), reduction="none").mean([1,2,3]) 51 | return loss_g, loss_feature -------------------------------------------------------------------------------- /scripts/compress.py: -------------------------------------------------------------------------------- 1 | from esc.models import make_model 2 | from .utils import read_yaml 3 | import torch, os, torchaudio, argparse, warnings 4 | warnings.filterwarnings("ignore") 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--input", type=str, required=True, help="input 16kHz mono audio file to encode") 9 | parser.add_argument("--save_path", type=str, default="./output", help="folder to save codes and reconstructed audio") 10 | 11 | parser.add_argument("--model_path", type=str, required=True, help="folder contains model configuration and checkpoint") 12 | parser.add_argument("--num_streams", type=int, default=6, help="number of transmitted streams in encoding") 13 | 14 | parser.add_argument("--device", type=str, default="cpu") 15 | return parser.parse_args() 16 | 17 | def main(args): 18 | 19 | x, sr = torchaudio.load(f"{args.input}") 20 | x = x.to(args.device) 21 | 22 | model = make_model(read_yaml(f"{args.model_path}/config.yaml")['model']) 23 | model.load_state_dict( 24 | torch.load(f"{args.model_path}/model.pth", map_location="cpu")["model_state_dict"], 25 | ) 26 | model = model.to(args.device) 27 | 28 | codes, size = model.encode(x, num_streams=args.num_streams) 29 | recon_x = model.decode(codes, size) 30 | 31 | fname = args.input.split("/")[-1] 32 | if not os.path.exists(args.save_path): 33 | os.makedirs(args.save_path) 34 | torchaudio.save(f"{args.save_path}/decoded_{args.num_streams*1.5}kbps_{fname}", recon_x, sr) 35 | torch.save(codes, f"{args.save_path}/encoded_{args.num_streams*1.5}kbps_{fname.split('.')[0]}.pth") 36 | print(f"compression outputs saved into {args.save_path}") 37 | 38 | if __name__ == "__main__": 39 | args = parse_args() 40 | main(args) 41 | 42 | """ 43 | python -m scripts.compress \ 44 | --input ./audio.wav \ 45 | --save_path ./output \ 46 | --model_path ./esc9kbps \ 47 | 
--num_streams 6 \ 48 | --device cpu 49 | 50 | """ -------------------------------------------------------------------------------- /esc/modules/convolution/layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Convolution2D(nn.Module): 4 | """2D Convolution Layer""" 5 | def __init__(self, 6 | in_channels, 7 | out_channels, 8 | kernel_size=(5,2), 9 | scale=True, 10 | transpose=False): 11 | super().__init__() 12 | 13 | stride = (2,1) if scale else (1,1) 14 | conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=(2,1)) if not transpose \ 15 | else nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding=(1,0)) 16 | self.conv = conv 17 | self.transpose, self.scale = transpose, scale 18 | 19 | def forward(self, x): 20 | F, T = x.size(-2), x.size(-1) 21 | y = self.conv(x) 22 | 23 | if self.scale: 24 | y = y[..., :F*2, :T] if self.transpose else y[..., :F//2, :T] 25 | else: 26 | y = y[..., :F, :T] 27 | 28 | return y 29 | 30 | class ResidualUnit(nn.Module): 31 | def __init__(self, dim: int) -> None: 32 | super().__init__() 33 | 34 | self.block = nn.Sequential(*[ 35 | Convolution2D(dim, dim, kernel_size=(5,2), scale=False), 36 | nn.BatchNorm2d(dim), 37 | nn.PReLU(), 38 | Convolution2D(dim, dim, kernel_size=(5,2), scale=False), 39 | nn.BatchNorm2d(dim), 40 | nn.PReLU(), 41 | ]) 42 | 43 | def forward(self, x): 44 | y = self.block(x) 45 | 46 | return x + y 47 | 48 | 49 | class ConvolutionLayer(nn.Module): 50 | def __init__(self, in_dim, out_dim, depth=1, 51 | kernel_size=(5,2), transpose=False) -> None: 52 | super().__init__() 53 | 54 | blocks = [ResidualUnit(in_dim) for _ in range(depth)] 55 | blocks += [Convolution2D(in_dim, out_dim, kernel_size, scale=True, transpose=transpose), 56 | nn.BatchNorm2d(out_dim), 57 | nn.PReLU(),] 58 | 59 | self.blocks = nn.Sequential(*blocks) 60 | 61 | def forward(self, x): 62 | 63 | y = self.blocks(x) 64 | return y -------------------------------------------------------------------------------- /esc/modules/vq/initialize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import warnings 4 | 5 | 6 | @torch.no_grad() 7 | def codebook_init_forward_hook_pvq(self, input, output): 8 | """ initializes codebook from data """ 9 | 10 | if (not self.training) or (self.codebook_initialized.item() == 1): 11 | return # no initialization during inference 12 | 13 | if self.verbose_init is True: 14 | if self.kmeans_init is None: 15 | print("Initializing Product VQs with KaimingNormal") 16 | elif self.kmeans_init is True: 17 | print('Initializing Product VQs with k-means++') 18 | elif self.kmeans_init is False: 19 | print('Initializing Product VQs by randomly choosing from z_e') 20 | 21 | outputs, _ = output 22 | _, z_e_downs, _ = outputs 23 | # z_e_downs [B, group_size, T, codebook_dim] 24 | for i in range(self.num_vqs): 25 | if self.kmeans_init is not None: 26 | z_e_i = z_e_downs[:,i] # [B, T, codebook_dim] 27 | init_codebook = sample_centroids(z_e_i, self.codebook_size, self.kmeans_init) 28 | self.vqs[i].embedding.weight.data = init_codebook 29 | else: 30 | nn.init.kaiming_normal_(self.vqs[i].embedding.weight) 31 | 32 | self.codebook_initialized.fill_(1) # set boolean flag 33 | return 34 | 35 | @torch.no_grad() 36 | def sample_centroids(z_e, codebook_size, use_kmeans=False): 37 | """ create an initialize codebook one-time from z_e 38 | Args: 39 | z_e: encoded embedding Tensor of size 
[bs,T,d] 40 | codebook_size: number of codewords 41 | 42 | returns: 43 | new_codebook: Tensor of size [codebook_size, d] 44 | """ 45 | 46 | z_e = z_e.reshape(-1, z_e.size(-1)) # bs*T, d 47 | if codebook_size >= z_e.size(0): 48 | e_msg = f'\ncodebook size > warmup samples: {codebook_size} vs {z_e.size(0)}. ' + \ 49 | 'recommended to decrease the codebook size or increase batch size.' 50 | warnings.warn(e_msg) 51 | # repeat until it fits and add noise 52 | repeat = 1 + codebook_size // z_e.shape[0] 53 | new_codes = z_e.data.tile([repeat, 1])[:codebook_size] 54 | new_codes += 1e-3 * torch.randn_like(new_codes.data) 55 | else: 56 | # you have more warmup samples than codebook. subsample data 57 | if use_kmeans: 58 | from torchpq.clustering import KMeans 59 | kmeans = KMeans(n_clusters=codebook_size, distance='euclidean', init_mode="kmeans++") 60 | kmeans.fit(z_e.data.T.contiguous()) 61 | new_codes = kmeans.centroids.T 62 | else: 63 | indices = torch.randint(low=0, high=codebook_size, size=(codebook_size,)) 64 | indices = indices.to(z_e.device) 65 | new_codes = torch.index_select(z_e, 0, indices).to(z_e.device).data 66 | 67 | return new_codes -------------------------------------------------------------------------------- /esc/modules/loss/generator_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | import torchaudio.transforms as T 5 | import torch 6 | 7 | MEL_WINDOWS = [32,64,128,256,512,1024,2048] 8 | MEL_BINS = [5,10,20,40,80,160,320] 9 | SR = 16000 10 | POWER = 0.3 11 | 12 | class ComplexSTFTLoss(nn.Module): 13 | """L2 Loss on Complex STFTs (Power Law Compressed https://arxiv.org/pdf/1811.07030)""" 14 | def __init__(self, weight=1.0, power_law=True): 15 | super().__init__() 16 | self.power_law = power_law 17 | self.weight = weight 18 | 19 | def forward(self, raw_feat, recon_feat): 20 | """ 21 | Args: 22 | raw_feat/recon_feat: (B,2,F,T) 23 | returns: (B,) 24 | """ 25 | if self.power_law: 26 | raw_feat = power_law(raw_feat, power=POWER) 27 | recon_feat = power_law(recon_feat, power=POWER) 28 | 29 | return self.weight * F.mse_loss(raw_feat,recon_feat,reduction="none").mean([1,2,3]) 30 | 31 | def power_law(stft, power=POWER, eps=1e-10): 32 | mask = torch.sign(stft) 33 | power_law_compressed = (torch.abs(stft) + eps) ** power 34 | power_law_compressed = power_law_compressed * mask 35 | return power_law_compressed 36 | 37 | class MelSpectrogramLoss(nn.Module): 38 | """ 39 | L1 MelSpectrogram Loss 40 | Implementation adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py 41 | """ 42 | def __init__(self, weight=1.0, 43 | win_lengths=MEL_WINDOWS, n_mels=MEL_BINS, clamp_eps=1e-5,): 44 | super().__init__() 45 | 46 | self.n_mels = n_mels 47 | self.mel_transf = nn.ModuleList( [ 48 | T.MelSpectrogram( 49 | sample_rate=SR, n_fft=w, win_length=w, 50 | hop_length=w//4, n_mels=n_mels[i], power=1) 51 | for i, w in enumerate(win_lengths) 52 | ] ) 53 | self.clamp_eps = clamp_eps 54 | self.weight = weight 55 | 56 | def forward(self, raw_audio, recon_audio): 57 | """ 58 | Args: 59 | raw_audio/recon_audio: (B,L) 60 | returns: (B,) 61 | """ 62 | mel_loss = 0.0 63 | for mel_trans in self.mel_transf: 64 | x_mels, y_mels = mel_trans(raw_audio), mel_trans(recon_audio) 65 | 66 | # magnitude loss 67 | mel_loss += F.l1_loss(x_mels, y_mels, reduction="none").mean([1,2]) 68 | # log magnitude loss 69 | mel_loss += F.l1_loss( 70 | x_mels.clamp(self.clamp_eps).pow(2).log10(), 71 | 
y_mels.clamp(self.clamp_eps).pow(2).log10(), 72 | reduction="none" 73 | ).mean([1,2]) 74 | 75 | return self.weight * mel_loss -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | celerybeat.pid 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | .dmypy.json 118 | dmypy.json 119 | 120 | # Pyre type checker 121 | .pyre/ 122 | 123 | # pytype static type analyzer 124 | .pytype/ 125 | 126 | # profiling data 127 | .prof 128 | 129 | # Editors and IDEs 130 | # See https://help.github.com/articles/ignoring-files for more about ignoring files. 
131 | # Visual Studio Code 132 | .vscode/ 133 | # Intellij 134 | .idea/ 135 | # Sublime Text 136 | *.sublime-workspace 137 | 138 | # Windows image file caches 139 | Thumbs.db 140 | ehthumbs.db 141 | 142 | # Folder config file 143 | Desktop.ini 144 | 145 | # Recycle Bin used on file shares 146 | $RECYCLE.BIN/ 147 | 148 | # macOS files 149 | .DS_Store 150 | .AppleDouble 151 | .LSOverride 152 | 153 | # Icon must end with two \r 154 | Icon 155 | 156 | # Thumbnails 157 | ._* 158 | 159 | # Files that might appear in the root of a volume 160 | .DocumentRevisions-V100 161 | .fseventsd 162 | .Spotlight-V100 163 | .TemporaryItems 164 | .Trashes 165 | .VolumeIcon.icns 166 | .com.apple.timemachine.dontbackup 167 | .PKInstallSandboxManager 168 | .PKInstallSandboxManager-SystemSoftware 169 | 170 | 171 | # old dev repo 172 | dev-deep-audio-signal-coding/ -------------------------------------------------------------------------------- /baselines/descript/dac/utils/decode.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | import argbind 5 | import numpy as np 6 | import torch 7 | from audiotools import AudioSignal 8 | from tqdm import tqdm 9 | 10 | from dac import DACFile 11 | from dac.utils import load_model 12 | 13 | warnings.filterwarnings("ignore", category=UserWarning) 14 | 15 | 16 | @argbind.bind(group="decode", positional=True, without_prefix=True) 17 | @torch.inference_mode() 18 | @torch.no_grad() 19 | def decode( 20 | input: str, 21 | output: str = "", 22 | weights_path: str = "", 23 | model_tag: str = "latest", 24 | model_bitrate: str = "8kbps", 25 | device: str = "cuda", 26 | model_type: str = "44khz", 27 | verbose: bool = False, 28 | ): 29 | """Decode audio from codes. 30 | 31 | Parameters 32 | ---------- 33 | input : str 34 | Path to input directory or file 35 | output : str, optional 36 | Path to output directory, by default "". 37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 38 | weights_path : str, optional 39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 40 | model_tag and model_type. 41 | model_tag : str, optional 42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 43 | model_bitrate: str 44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 45 | device : str, optional 46 | Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. 47 | model_type : str, optional 48 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
49 | """ 50 | generator = load_model( 51 | model_type=model_type, 52 | model_bitrate=model_bitrate, 53 | tag=model_tag, 54 | load_path=weights_path, 55 | ) 56 | generator.to(device) 57 | generator.eval() 58 | 59 | # Find all .dac files in input directory 60 | _input = Path(input) 61 | input_files = list(_input.glob("**/*.dac")) 62 | 63 | # If input is a .dac file, add it to the list 64 | if _input.suffix == ".dac": 65 | input_files.append(_input) 66 | 67 | # Create output directory 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(input_files)), desc=f"Decoding files"): 72 | # Load file 73 | artifact = DACFile.load(input_files[i]) 74 | 75 | # Reconstruct audio from codes 76 | recons = generator.decompress(artifact, verbose=verbose) 77 | 78 | # Compute output path 79 | relative_path = input_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = input_files[i] 84 | output_name = relative_path.with_suffix(".wav").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | # Write to file 89 | recons.write(output_path) 90 | 91 | 92 | if __name__ == "__main__": 93 | args = argbind.parse_args() 94 | with argbind.scope(args): 95 | decode() 96 | -------------------------------------------------------------------------------- /baselines/descript/dac/utils/encode.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from pathlib import Path 4 | 5 | import argbind 6 | import numpy as np 7 | import torch 8 | from audiotools import AudioSignal 9 | from audiotools.core import util 10 | from tqdm import tqdm 11 | 12 | from dac.utils import load_model 13 | 14 | warnings.filterwarnings("ignore", category=UserWarning) 15 | 16 | 17 | @argbind.bind(group="encode", positional=True, without_prefix=True) 18 | @torch.inference_mode() 19 | @torch.no_grad() 20 | def encode( 21 | input: str, 22 | output: str = "", 23 | weights_path: str = "", 24 | model_tag: str = "latest", 25 | model_bitrate: str = "8kbps", 26 | n_quantizers: int = None, 27 | device: str = "cuda", 28 | model_type: str = "44khz", 29 | win_duration: float = 5.0, 30 | verbose: bool = False, 31 | ): 32 | """Encode audio files in input path to .dac format. 33 | 34 | Parameters 35 | ---------- 36 | input : str 37 | Path to input audio file or directory 38 | output : str, optional 39 | Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 40 | weights_path : str, optional 41 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the 42 | model_tag and model_type. 43 | model_tag : str, optional 44 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 45 | model_bitrate: str 46 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 47 | n_quantizers : int, optional 48 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. 49 | device : str, optional 50 | Device to use, by default "cuda" 51 | model_type : str, optional 52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
53 | """ 54 | generator = load_model( 55 | model_type=model_type, 56 | model_bitrate=model_bitrate, 57 | tag=model_tag, 58 | load_path=weights_path, 59 | ) 60 | generator.to(device) 61 | generator.eval() 62 | kwargs = {"n_quantizers": n_quantizers} 63 | 64 | # Find all audio files in input path 65 | input = Path(input) 66 | audio_files = util.find_audio(input) 67 | 68 | output = Path(output) 69 | output.mkdir(parents=True, exist_ok=True) 70 | 71 | for i in tqdm(range(len(audio_files)), desc="Encoding files"): 72 | # Load file 73 | signal = AudioSignal(audio_files[i]) 74 | 75 | # Encode audio to .dac format 76 | artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) 77 | 78 | # Compute output path 79 | relative_path = audio_files[i].relative_to(input) 80 | output_dir = output / relative_path.parent 81 | if not relative_path.name: 82 | output_dir = output 83 | relative_path = audio_files[i] 84 | output_name = relative_path.with_suffix(".dac").name 85 | output_path = output_dir / output_name 86 | output_path.parent.mkdir(parents=True, exist_ok=True) 87 | 88 | artifact.save(output_path) 89 | 90 | 91 | if __name__ == "__main__": 92 | args = argbind.parse_args() 93 | with argbind.scope(args): 94 | encode() 95 | -------------------------------------------------------------------------------- /scripts_all.sh: -------------------------------------------------------------------------------- 1 | ## Training Final Models 2 | accelerate launch main.py \ 3 | --exp_name esc-base-non-adv \ 4 | --config_path ./configs/9kbps_esc_base.yaml \ 5 | --wandb_project efficient-speech-codec \ 6 | --lr 1.0e-4 \ 7 | --num_epochs 80 \ 8 | --num_pretraining_epochs 15 \ 9 | --num_devices 4 \ 10 | --dropout_rate 0.75 \ 11 | --save_path ../output \ 12 | --seed 53 13 | 14 | accelerate launch main.py \ 15 | --exp_name esc-base-adv \ 16 | --adv_training \ 17 | --config_path ./configs/9kbps_esc_base_adv.yaml \ 18 | --wandb_project efficient-speech-codec \ 19 | --lr 1.0e-4 \ 20 | --num_epochs 80 \ 21 | --num_pretraining_epochs 15 \ 22 | --num_devices 4 \ 23 | --dropout_rate 0.75 \ 24 | --save_path ../output \ 25 | --seed 53 26 | 27 | # accelerate launch main.py \ 28 | # --exp_name esc-base-post-adv \ 29 | # --adv_training \ 30 | # --pretrain_ckp ../esc9kbps_base_non_adversarial/model.pth \ 31 | # --config_path ./configs/9kbps_esc_base_adv.yaml \ 32 | # --wandb_project efficient-speech-codec \ 33 | # --lr 1.0e-4 \ 34 | # --num_epochs 20 \ 35 | # --num_pretraining_epochs 0 \ 36 | # --num_devices 4 \ 37 | # --dropout_rate 0.75 \ 38 | # --save_path ../output \ 39 | # --seed 53 40 | 41 | accelerate launch main.py \ 42 | --exp_name esc-large-non-adv \ 43 | --config_path ./configs/9kbps_esc_large.yaml \ 44 | --wandb_project efficient-speech-codec \ 45 | --lr 1.0e-4 \ 46 | --num_epochs 80 \ 47 | --num_pretraining_epochs 15 \ 48 | --num_devices 4 \ 49 | --dropout_rate 0.75 \ 50 | --save_path ../output \ 51 | --seed 53 52 | 53 | 54 | ## Method Ablations 55 | accelerate launch main.py \ 56 | --exp_name csvq+swinT \ 57 | --config_path ./configs/ablations/9kbps_csvq_swinT.yaml \ 58 | --wandb_project efficient-speech-codec \ 59 | --lr 1.0e-4 \ 60 | --num_epochs 50 \ 61 | --num_pretraining_epochs 5 \ 62 | --num_devices 4 \ 63 | --dropout_rate 0.75 \ 64 | --save_path ../output \ 65 | --seed 53 66 | 67 | accelerate launch main.py \ 68 | --exp_name csvq+conv_9kbps \ 69 | --config_path ./configs/ablations/9kbps_csvq_conv.yaml \ 70 | --wandb_project efficient-speech-codec \ 71 | --lr 1.0e-4 \ 72 | --num_epochs 50 \ 73 | 
--num_pretraining_epochs 5 \ 74 | --num_devices 4 \ 75 | --dropout_rate 0.75 \ 76 | --save_path ../output \ 77 | --seed 53 78 | 79 | accelerate launch main.py \ 80 | --exp_name rvq+swinT \ 81 | --config_path ./configs/ablations/9kbps_rvq_swinT.yaml \ 82 | --wandb_project efficient-speech-codec \ 83 | --lr 1.0e-4 \ 84 | --num_epochs 50 \ 85 | --num_pretraining_epochs 5 \ 86 | --num_devices 2 \ 87 | --dropout_rate 0.75 \ 88 | --save_path ../output \ 89 | --seed 53 90 | 91 | accelerate launch main.py \ 92 | --exp_name rvq+conv \ 93 | --config_path ./configs/ablations/9kbps_rvq_conv.yaml \ 94 | --wandb_project efficient-speech-codec \ 95 | --lr 1.0e-4 \ 96 | --num_epochs 50 \ 97 | --num_pretraining_epochs 5 \ 98 | --num_devices 4 \ 99 | --dropout_rate 0.75 \ 100 | --save_path ../output \ 101 | --seed 53 102 | 103 | accelerate launch main.py \ 104 | --exp_name csvq+swinT_w/o_pretraining \ 105 | --config_path ./configs/ablations/9kbps_csvq_swinT.yaml \ 106 | --wandb_project efficient-speech-codec \ 107 | --lr 1.0e-4 \ 108 | --num_epochs 50 \ 109 | --num_pretraining_epochs 0 \ 110 | --num_devices 2 \ 111 | --dropout_rate 0.75 \ 112 | --save_path ../output \ 113 | --seed 53 -------------------------------------------------------------------------------- /esc/modules/vq/codebook.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | 5 | class Codebook(nn.Module): 6 | def __init__(self, 7 | embedding_dim: int=256, 8 | num_embeddings: int=1024, 9 | l2norm: bool=False, 10 | ): 11 | super().__init__() 12 | 13 | self.embedding = nn.Embedding(num_embeddings, embedding_dim) 14 | nn.init.kaiming_normal_(self.embedding.weight) 15 | 16 | self.embedding_dim = embedding_dim 17 | self.num_embeddings = num_embeddings 18 | self.l2norm = l2norm 19 | 20 | def quantize_to_code(self, z_e): 21 | """ Quantize input vector to codebook indices. 22 | Args: 23 | z_e (Tensor): input vector with shape (bs, *, embedding_dim) 24 | Returns: 25 | Tensor of indices with shape (bs, *) 26 | """ 27 | 28 | codebook = self.embedding.weight # [num_embeddings, embedding_dim] 29 | z_flat = rearrange(z_e, "b t d -> (b t) d") # [*, embedding_dim] 30 | 31 | if self.l2norm: 32 | codebook = F.normalize(codebook, dim=-1) 33 | z_flat = F.normalize(z_flat, dim=-1) 34 | 35 | dist = ( 36 | z_flat.pow(2).sum(1, keepdim=True) 37 | - 2 * z_flat @ codebook.t() 38 | + codebook.pow(2).sum(1, keepdim=True).t() 39 | ) 40 | indices = dist.min(1).indices 41 | indices = rearrange(indices, "(b t) -> b t", b=z_e.size(0)) 42 | 43 | return indices 44 | 45 | def dequantize_code(self, code): 46 | """ De-quantize code indices to vectors 47 | Args: 48 | code (Tensor): code with shape (bs, *) 49 | Returns: 50 | Tensor of quantized vector with shape (bs, *, embedding_dim) 51 | """ 52 | codebook = self.embedding.weight 53 | z_q = F.embedding(code, codebook) 54 | 55 | return z_q 56 | 57 | def forward(self, z_e): 58 | """ Vector Quantization Forward Function. 
59 | Args: 60 | z_e (Tensor): input vector with shape (bs, T, embedding_dim) 61 | z_q (Tensor): quantized vector with shape (bs, T, embedding_dim) 62 | """ 63 | 64 | code = self.quantize_to_code(z_e) 65 | z_q = self.dequantize_code(code) 66 | 67 | if self.training: # Straight-Through Estimator 68 | commitment_loss = F.mse_loss(z_q.detach(), z_e, reduction="none").mean([1,2]) 69 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1,2]) 70 | z_q = z_e + (z_q - z_e).detach() 71 | else: 72 | commitment_loss = F.mse_loss(z_q, z_e, reduction="none").mean([1,2]) 73 | codebook_loss = commitment_loss 74 | 75 | return z_q, code, codebook_loss, commitment_loss 76 | 77 | def encode(self, z_e): 78 | code = self.quantize_to_code(z_e) 79 | return code 80 | 81 | def decode(self, code): 82 | z_q = self.dequantize_code(code) 83 | return z_q 84 | 85 | def count_posterior(code, codebook_size): 86 | """ Compute the posterior codebook distribution P(q|e) on a total batch of encoded features 87 | Args: 88 | code: quantized discrete code of size [B, T] 89 | codebook_size: total number of entries 90 | returns: posterior distribution with size [B, codebook_size] 91 | """ 92 | one_hot = F.one_hot(code, num_classes=codebook_size) # B T codebook_size 93 | counts = one_hot.sum(dim=1) # B codebook_size 94 | posterior = counts / code.size(1) 95 | 96 | return posterior 97 | -------------------------------------------------------------------------------- /baselines/descript/dac/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import argbind 4 | from audiotools import ml 5 | 6 | import dac 7 | 8 | DAC = dac.model.DAC 9 | Accelerator = ml.Accelerator 10 | 11 | __MODEL_LATEST_TAGS__ = { 12 | ("44khz", "8kbps"): "0.0.1", 13 | ("24khz", "8kbps"): "0.0.4", 14 | ("16khz", "8kbps"): "0.0.5", 15 | ("44khz", "16kbps"): "1.0.0", 16 | } 17 | 18 | __MODEL_URLS__ = { 19 | ( 20 | "44khz", 21 | "0.0.1", 22 | "8kbps", 23 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", 24 | ( 25 | "24khz", 26 | "0.0.4", 27 | "8kbps", 28 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", 29 | ( 30 | "16khz", 31 | "0.0.5", 32 | "8kbps", 33 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", 34 | ( 35 | "44khz", 36 | "1.0.0", 37 | "16kbps", 38 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", 39 | } 40 | 41 | 42 | @argbind.bind(group="download", positional=True, without_prefix=True) 43 | def download( 44 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" 45 | ): 46 | """ 47 | Function that downloads the weights file from URL if a local cache is not found. 48 | 49 | Parameters 50 | ---------- 51 | model_type : str 52 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". 53 | model_bitrate: str 54 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 55 | Only 44khz model supports 16kbps. 56 | tag : str 57 | The tag of the model to download. Defaults to "latest". 58 | 59 | Returns 60 | ------- 61 | Path 62 | Directory path required to load model via audiotools. 
63 | """ 64 | model_type = model_type.lower() 65 | tag = tag.lower() 66 | 67 | assert model_type in [ 68 | "44khz", 69 | "24khz", 70 | "16khz", 71 | ], "model_type must be one of '44khz', '24khz', or '16khz'" 72 | 73 | assert model_bitrate in [ 74 | "8kbps", 75 | "16kbps", 76 | ], "model_bitrate must be one of '8kbps', or '16kbps'" 77 | 78 | if tag == "latest": 79 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] 80 | 81 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) 82 | 83 | if download_link is None: 84 | raise ValueError( 85 | f"Could not find model with tag {tag} and model type {model_type}" 86 | ) 87 | 88 | local_path = ( 89 | Path.home() 90 | / ".cache" 91 | / "descript" 92 | / "dac" 93 | / f"weights_{model_type}_{model_bitrate}_{tag}.pth" 94 | ) 95 | if not local_path.exists(): 96 | local_path.parent.mkdir(parents=True, exist_ok=True) 97 | 98 | # Download the model 99 | import requests 100 | 101 | response = requests.get(download_link) 102 | 103 | if response.status_code != 200: 104 | raise ValueError( 105 | f"Could not download model. Received response code {response.status_code}" 106 | ) 107 | local_path.write_bytes(response.content) 108 | 109 | return local_path 110 | 111 | 112 | def load_model( 113 | model_type: str = "44khz", 114 | model_bitrate: str = "8kbps", 115 | tag: str = "latest", 116 | load_path: str = None, 117 | ): 118 | if not load_path: 119 | load_path = download( 120 | model_type=model_type, model_bitrate=model_bitrate, tag=tag 121 | ) 122 | generator = DAC.load(load_path) 123 | return generator 124 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import torch, torchaudio 2 | import transformers 3 | import numpy as np 4 | import argparse, yaml, glob 5 | from huggingface_hub import hf_hub_download 6 | from torch.utils.data import Dataset, DataLoader, default_collate 7 | 8 | from esc.modules import ComplexSTFTLoss, MelSpectrogramLoss 9 | 10 | 11 | def quantization_dropout(dropout_rate: float, max_streams: int): 12 | """ 13 | Args: 14 | dropout_rate: probability that applies quantization dropout 15 | max_streams: maximum number of streams codec can take 16 | returns: sampled number of streams for current batch 17 | """ 18 | assert dropout_rate >=0 and dropout_rate <=1, "dropout_rate must be within [0, 1]" 19 | # Do Random Sample N w prob dropout_rate 20 | do_sample = np.random.choice([0, 1], p=[1-dropout_rate, dropout_rate]) 21 | if do_sample: 22 | streams = np.random.randint(1, max_streams+1) 23 | else: 24 | streams = max_streams 25 | return streams 26 | 27 | class EvalSet(Dataset): 28 | def __init__(self, eval_folder_path) -> None: 29 | super().__init__() 30 | self.testset_files = glob.glob(f"{eval_folder_path}/*.wav") 31 | if not self.testset_files: 32 | self.testset_files = glob.glob(f"{eval_folder_path}/*/*.wav") 33 | self.testset_files = self.testset_files[:180000] 34 | 35 | def __len__(self): 36 | return len(self.testset_files) 37 | 38 | def __getitem__(self, i): 39 | x, _ = torchaudio.load(self.testset_files[i]) 40 | return x[0, :-80] 41 | 42 | def make_dataloader(data_path, batch_size, shuffle, num_workers=0): 43 | ds = EvalSet(data_path) 44 | dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, 45 | collate_fn=default_collate, num_workers=num_workers) 46 | return dl 47 | 48 | def make_optimizer(params, lr): 49 | return torch.optim.AdamW(params, lr) 50 | 51 | GAMMAR = 0.999996 
52 | def make_scheduler(optimizer, scheduler_type, total_steps=250000, warmup_steps=0): 53 | if scheduler_type == "constant": 54 | scheduler = transformers.get_constant_schedule(optimizer) 55 | elif scheduler_type == "constant_warmup": 56 | scheduler = transformers.get_constant_schedule_with_warmup( 57 | optimizer, num_warmup_steps=warmup_steps) 58 | elif scheduler_type == "cosine_warmup": 59 | scheduler = transformers.get_cosine_schedule_with_warmup( 60 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) 61 | elif scheduler_type == "exponential_decay": 62 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMAR) 63 | else: 64 | raise ValueError(f"{scheduler_type} must be in ('constant', 'constant_warmup', 'cosine_warmup', 'exponential_decay')") 65 | return scheduler 66 | 67 | def make_losses(name="mel_loss"): 68 | if name == "mel_loss": 69 | return MelSpectrogramLoss() 70 | elif name == "stft_loss": 71 | return ComplexSTFTLoss(power_law=True) 72 | else: 73 | raise ValueError("Supported losses are (mel_loss, stft_loss)") 74 | 75 | def dict2namespace(config): 76 | namespace = argparse.Namespace() 77 | for key, value in config.items(): 78 | if isinstance(value, dict): 79 | new_value = dict2namespace(value) 80 | else: 81 | new_value = value 82 | setattr(namespace, key, new_value) 83 | return namespace 84 | 85 | def namespace2dict(config): 86 | return vars(config) 87 | 88 | def read_yaml(pth): 89 | with open(pth, 'r') as f: 90 | config = yaml.safe_load(f) 91 | return config 92 | 93 | def download_data_hf(repo_id="../dnscustom", 94 | filename="testset.tar.gz", 95 | local_dir="./data"): 96 | 97 | file_path = hf_hub_download(repo_id=repo_id, 98 | filename=filename, 99 | repo_type="dataset", 100 | local_dir=local_dir) 101 | print(f"File has been downloaded and is located at {file_path}") 102 | return file_path 103 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | from .metrics import EntropyCounter, PESQ, MelSpectrogramDistance, SISDR 2 | from .utils import read_yaml, EvalSet 3 | from esc.models import make_model 4 | 5 | from torch.utils.data import DataLoader, default_collate 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | import argparse, torch, json 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--eval_folder_path", type=str, required=True) 14 | parser.add_argument("--batch_size", type=int, default=1) 15 | 16 | parser.add_argument("--model_path", type=str, required=True, help="folder containing the model configuration and checkpoint") 17 | parser.add_argument("--save_path", type=str, default=None, help="folder to save test statistics") 18 | 19 | parser.add_argument("--device", type=str, default="cpu") 20 | return parser.parse_args() 21 | 22 | @torch.no_grad() 23 | def eval_epoch(model, eval_loader:DataLoader, 24 | metric_funcs:dict, e_counter:EntropyCounter, device: str, bps_per_stream: float, 25 | num_streams=None, verbose: bool=True): 26 | model.eval() 27 | 28 | all_perf = {k:[] for k in metric_funcs.keys()} 29 | all_perf["utilization"] = [] 30 | eval_range = range(num_streams,num_streams+1) if num_streams is not None \ 31 | else range(1, model.max_streams+1) # 1.5kbps -> 9kbps 32 | for s in eval_range: 33 | perf = {k:[] for k in metric_funcs.keys()} 34 | e_counter.reset_stats(num_streams=s) 35 | for _, x in tqdm(enumerate(eval_loader), total=len(eval_loader),
desc=f"Evaluating Codec at {s*bps_per_stream:.2f}kbps"): 36 | x = x.to(device) 37 | outputs = model(**dict(x=x, x_feat=None, num_streams=s)) 38 | recon_x, codes = outputs["recon_audio"], outputs["codes"] 39 | 40 | for k, func in metric_funcs.items(): 41 | perf[k].extend(func(x, recon_x).tolist()) 42 | e_counter.update(codes) 43 | 44 | for k, v in perf.items(): 45 | all_perf[k].append(round(np.mean(v),4)) 46 | rate, _ = e_counter.compute_utilization() 47 | perf["utilization"] = [rate] 48 | all_perf["utilization"].append(rate) 49 | 50 | if verbose: 51 | print(f"Test Metrics at {s*1.5:.2f}kbps: ", end="") 52 | print(" | ".join(f"{k}: {np.mean(v):.4f}" for k, v in perf.items())) 53 | 54 | model.train() 55 | return all_perf 56 | 57 | def run(args): 58 | # Data 59 | eval_set = EvalSet(args.eval_folder_path) 60 | eval_loader = DataLoader(eval_set, batch_size=args.batch_size, shuffle=False, collate_fn=default_collate) 61 | 62 | # Metrics 63 | metric_funcs = {"PESQ": PESQ(), "MelDistance": MelSpectrogramDistance().to(args.device), "SISDR": SISDR().to(args.device)} 64 | 65 | # Model 66 | cfg = read_yaml(f"{args.model_path}/config.yaml") 67 | model = make_model(cfg['model'], cfg['model_name']) 68 | model.load_state_dict( 69 | torch.load(f"{args.model_path}/model.pth", map_location="cpu")["model_state_dict"], 70 | ) 71 | model = model.to(args.device) 72 | e_counter = EntropyCounter(cfg['model']['codebook_size'], num_streams=cfg['model']['max_streams'], 73 | num_groups=cfg['model']['group_size'], device=args.device) 74 | 75 | performances = eval_epoch( 76 | model, eval_loader, metric_funcs, e_counter, args.device, 77 | num_streams=None, verbose=True, bps_per_stream=1.5, # evaluate across all bitrates 78 | ) 79 | 80 | save_path = args.model_path if args.save_path is None else args.save_path 81 | json.dump(performances, open(f"{save_path}/perf_stats.json", "w"), indent=2) 82 | print(f"Test statistics saved into {save_path}/perf_stats.json") 83 | 84 | 85 | if __name__ == "__main__": 86 | args = parse_args() 87 | run(args) 88 | 89 | 90 | """ 91 | python -m scripts.test \ 92 | --eval_folder_path ../evaluation_set/test \ 93 | --batch_size 12 \ 94 | --model_path ./esc9kbps \ 95 | --device cuda 96 | 97 | 98 | python -m scripts.test \ 99 | --eval_folder_path ../data/ESC_evaluation/test \ 100 | --batch_size 6 \ 101 | --model_path ../output/csvq_conv_9kbps \ 102 | --device cuda 103 | 104 | export CUDA_VISIBLE_DEVICES=1 105 | python -m scripts.test \ 106 | --eval_folder_path ../data/ESC_evaluation/test \ 107 | --batch_size 6 \ 108 | --model_path ../output/rvq_conv_9kbps \ 109 | --device cuda 110 | 111 | export CUDA_VISIBLE_DEVICES=2 112 | python -m scripts.test \ 113 | --eval_folder_path ../data/ESC_evaluation/test \ 114 | --batch_size 6 \ 115 | --model_path ../output/rvq_swinT_9kbps \ 116 | --device cuda 117 | 118 | """ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Efficient Speech Coding with Cross-Scale Residual Vector Quantized Transformers 2 | 3 | This is the code repository for the neural speech codec presented in the EMNLP 2024 paper **ESC: Efficient Speech Coding with Cross-Scale Residual Vector Quantized Transformers** [[paper](https://arxiv.org/abs/2404.19441)] 4 | - Our neural speech codec ESC, within only 30MB, efficiently compresses 16kHz speech at bitrates of 1.5, 3, 4.5, 6, 7.5, and 9kbps, while maintaining comparative reconstruction quality to Descript's audio codec. 
5 | - We provide pretrained model checkpoints [[download](#model-checkpoints)] for different ESC variants and DAC models, as well as a demo webpage [[link](https://efficient-speech-codec.notion.site/)] including multilingual speech samples. 6 | 7 | ![An illustration of ESC Architecture](assets/architecture.png) 8 | ## Usage 9 | 10 | ### Environment Setup 11 | ```bash 12 | conda create -n esc python=3.8 13 | conda activate esc 14 | 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ### Compress and decompress audio 19 | ```ruby 20 | python -m scripts.compress --input /path/to/input.wav --save_path /path/to/output --model_path /path/to/model --num_streams 6 --device cpu 21 | ``` 22 | This will create `.pth` (code) and `.wav` (reconstructed audio) files under the specified `save_path`. Our codec supports `num_streams` from 1 to 6, corresponding to bitrates 1.5 ~ 9.0 kbps. For programmatic usage, you can compress audio tensors using `torchaudio` as follows: 23 | 24 | ```python 25 | import torchaudio, torch 26 | from esc import ESC 27 | model = ESC(**config) 28 | model.load_state_dict(torch.load("model.pth", map_location="cpu"),) 29 | x, _ = torchaudio.load("input.wav") 30 | # Enc. (@ num_streams*1.5 kbps) 31 | codes, f_shape = model.encode(x, num_streams=6) 32 | # Dec. 33 | recon_x = model.decode(codes, f_shape) 34 | ``` 35 | For more details, see the `example.ipynb` notebook. 36 | 37 | ### Training 38 | 39 | We provide development training and evaluation datasets on [Hugging Face](https://huggingface.co/datasets/Tracygu/dnscustom/tree/main). For custom training, set the `train_data_path` in `exp.yaml` to the parent directory containing `.wav` audio segments. Run the following to start training: 40 | 41 | ```ruby 42 | WANDB_API_KEY=your_API_key 43 | accelerate launch main.py --exp_name esc9kbps --config_path ./configs/9kbps_esc_base.yaml --wandb_project efficient-speech-codec --lr 1.0e-4 --num_epochs 80 --num_pretraining_epochs 15 --num_devices 4 --dropout_rate 0.75 --save_path /path/to/output --seed 53 44 | ``` 45 | 46 | We use the `accelerate` library to handle distributed training and the `wandb` library for monitoring. To enable adversarial training with the same discriminator as in DAC, include the `--adv_training` flag. 47 | 48 | Training a base ESC model on 4 RTX4090 GPUs takes ~16 hours for 250k steps on 3-second speech clips with a batch size of 36. Detailed experiment configurations can be found in the `configs/` folder. For the complete set of experiments presented in the paper, refer to `scripts_all.sh`. 49 | 50 | ### Evaluation 51 | 52 | ```ruby 53 | CUDA_VISIBLE_DEVICES=0 54 | python -m scripts.test --eval_folder_path path/to/data --batch_size 12 --model_path /path/to/model --device cuda 55 | ``` 56 | This will run codec evaluation across all available bandwidths on the specified test set folder. We report four metrics: `PESQ`, `Mel-Distance`, `SI-SDR`, and `Bitrate-Utilization-Rate`. Evaluation statistics will be saved under `model_path` by default. 57 | 58 | ### Model Checkpoints 59 | You can download the pre-trained model checkpoints below: 60 | 61 | | Codec | Checkpoint | #Param.
| 62 | |--------|-------------------------------------------------|----------| 63 | | ESC-Base | [Download](https://drive.google.com/file/d/1OF1ab3az6nKOY8owSUhUH0ksYHFmR1bc/view?usp=sharing) | 8.39M | 64 | | ESC-Base(adv) | [Download](https://drive.google.com/file/d/1_g1dFYhY7qXKWkcq8_Q6I-kv8tQW_SF7/view?usp=sharing) | 8.39M | 65 | | ESC-Large | [Download](https://drive.google.com/file/d/180Q4zctqeNnDmRvoMsVQ-3iCB5FriJbN/view?usp=sharing) | 15.58M | 66 | | DAC-Tiny(adv) | [Download](https://drive.google.com/file/d/1ED-B_S7ftsb8CqoFGTNkWUIrMKrk-iiu/view?usp=sharing) | 8.17M | 67 | | DAC-Tiny | [Download](https://drive.google.com/file/d/1jk8zPYBYmxgsiSzrgoQynF6hnzoiIuX8/view?usp=sharing) | 8.17M | 68 | | DAC-Base(adv) | [Download](https://drive.google.com/file/d/1moy0FX-aPlx54MajBRuE-zjYeNlJUjI6/view?usp=sharing) | 74.31M | 69 | 70 | ## Results 71 | 72 | ![Performance Evaluation](assets/results.png) 73 | We provide a comprehensive performance comparison of ESC with Descript's audio codec (DAC) at different scales of model sizes (w/ and w/o adversarial trainings). 74 | 75 | ## Reference 76 | If you find our work useful or relevant to your research, please kindly cite our paper: 77 | ```bibtex 78 | @article{gu2024esc, 79 | title={ESC: Efficient Speech Coding with Cross-Scale Residual Vector Quantized Transformers}, 80 | author={Gu, Yuzhe and Diao, Enmao}, 81 | journal={arXiv preprint arXiv:2404.19441}, 82 | year={2024} 83 | } 84 | ``` -------------------------------------------------------------------------------- /esc/modules/transformer/scale.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from typing import Literal 3 | 4 | import torch.nn as nn 5 | 6 | 7 | def pixel_unshuffle(input, downscale_factor:tuple=(2,1)): 8 | s1, s2 = downscale_factor 9 | B, H, W, C = input.size() 10 | C_, H_, W_ = C*(s1*s2), H//s1, W//s2 11 | 12 | unshuffle_out = input.reshape(B, H_, s1, W_, s2, C).\ 13 | permute(0,1,3,2,4,5).reshape(B, H_, W_, C_) 14 | return unshuffle_out 15 | 16 | def pixel_shuffle(input, upscale_factor:tuple=(2,1)): 17 | s1, s2 = upscale_factor 18 | B, H, W, C = input.size() 19 | C_, H_, W_ = C//(s1*s2), H*s1, W*s2 20 | 21 | shuffle_out = input.reshape(B, H, W, s1, s2, C_).\ 22 | permute(0,1,3,2,4,5).reshape(B, H_, W_, C_) 23 | return shuffle_out 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 2D Linear Patchify """ 28 | def __init__(self, 29 | freq: int=192, 30 | in_chans: int=2, 31 | patch_size: tuple=(3,2), 32 | embed_dim: int=48, 33 | norm_layer=nn.LayerNorm, 34 | backbone: Literal['transformer', 'convolution']='transformer',): 35 | super().__init__() 36 | 37 | self.H = freq // patch_size[0] 38 | self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size) 39 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 40 | self.backbone = backbone 41 | 42 | def forward(self, x): 43 | 44 | x = self.proj(x) # B2FT -> BCHW 45 | if self.backbone == "convolution": 46 | return x # for convolution backbones, no normalization 47 | 48 | x = rearrange(x, "b c h w -> b (h w) c") # BCHW -> BCL -> BLC 49 | x = self.norm(x) 50 | return x 51 | 52 | class PatchDeEmbed(nn.Module): 53 | """ 2D Linear De-Patchify """ 54 | def __init__(self, 55 | freq: int=192, 56 | in_chans: int=2, 57 | patch_size: tuple=(3,2), 58 | embed_dim: int=48, 59 | backbone: Literal['transformer', 'convolution']='transformer',): 60 | super().__init__() 61 | 62 | self.patch_size = patch_size 63 | self.H = freq // patch_size[0] 64 | 
self.backbone = backbone 65 | 66 | self.de_proj1 = nn.Conv2d(embed_dim, 67 | embed_dim*patch_size[0]*patch_size[1], 68 | kernel_size=5, stride=1, padding=2) 69 | self.de_proj2 = nn.Conv2d(embed_dim, 70 | in_chans, 71 | kernel_size=3, stride=1, padding=1) 72 | 73 | def forward(self, x): 74 | if self.backbone == "transformer": 75 | x = rearrange(x, "b (h w) c -> b c h w", h=self.H) 76 | 77 | x = self.de_proj1(x) # B C*scale H W 78 | x = pixel_shuffle(x.permute(0,2,3,1), self.patch_size) # B F T C 79 | x = self.de_proj2(x.permute(0,3,1,2)) # BCFT -> B2FT 80 | 81 | return x 82 | 83 | class PatchMerge(nn.Module): 84 | """Patch Merging Layer: Perform Pixel Unshuffle and Downscale""" 85 | def __init__(self, 86 | in_dim: int, 87 | out_dim: int, 88 | scale_factor: tuple=(2,1), 89 | norm_layer=nn.LayerNorm): 90 | super().__init__() 91 | s1, s2 = scale_factor 92 | 93 | self.norm = norm_layer(s1*s2*in_dim) 94 | self.down = nn.Linear(s1*s2*in_dim, out_dim, bias=False) 95 | self.scale_factor = scale_factor 96 | 97 | def forward(self, x, H): 98 | """ Forward function. 99 | Args: 100 | x: Input feature, tensor size (B, H*W, in_dim) 101 | H: num_patches along Freq Domain 102 | returns: downscaled feature x, tensor size (B, H*W//2, out_dim) 103 | """ 104 | 105 | x = rearrange(x, "b (h w) c -> b h w c", h=H) 106 | pad_input = (H%2 == 1) 107 | if pad_input: 108 | x = nn.functional.pad(x, (0,0,0,0,0,H%2)) 109 | 110 | x = pixel_unshuffle(x, self.scale_factor) 111 | x = rearrange(x, "b h w c -> b (h w) c") 112 | x = self.norm(x) 113 | x = self.down(x) 114 | 115 | return x 116 | 117 | class PatchSplit(nn.Module): 118 | """Patch Splitting Layer: Perform Pixel Shuffle and Upscale""" 119 | def __init__(self, 120 | in_dim: int, 121 | out_dim: int, 122 | scale_factor: tuple=(2,1), 123 | norm_layer=nn.LayerNorm): 124 | super().__init__() 125 | s1, s2 = scale_factor 126 | 127 | self.norm = norm_layer(in_dim) 128 | self.up = nn.Linear(in_dim, out_dim*s1*s2, bias=False) 129 | self.scale_factor = scale_factor 130 | 131 | def forward(self, x, H): 132 | """ Forward function. 
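(Editorial note, illustrative shape trace assuming scale_factor=(2,1): the (B, H*W, in_dim) input is projected to out_dim*2 channels and pixel-shuffled along the frequency axis, yielding (B, (2*H)*W, out_dim).)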
133 | Args: 134 | x: Input feature, tensor size (B, H*W, in_dim) 135 | H: num_patches along Freq Domain 136 | returns: upscaled feature x, tensor size (B, H*W*2, out_dim) 137 | """ 138 | 139 | x = self.norm(x) 140 | x = self.up(x) 141 | 142 | x = rearrange(x, "b (h w) c -> b h w c", h=H) 143 | x = pixel_shuffle(x, self.scale_factor) 144 | x = rearrange(x, "b h w c -> b (h w) c") 145 | return x 146 | -------------------------------------------------------------------------------- /scripts/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchaudio.transforms as T 5 | import numpy as np 6 | from pesq import pesq 7 | 8 | MEL_WINDOWS = [32,64,128,256,512,1024,2048] 9 | MEL_BINS = [5,10,20,40,80,160,320] 10 | SR = 16000 11 | 12 | class EntropyCounter: 13 | """Counter maintaining codebook utilization rate on a held-out validation set""" 14 | def __init__(self, codebook_size=1024, 15 | num_streams=6, num_groups=3, 16 | device="cuda"): 17 | 18 | self.num_groups = num_groups 19 | self.codebook_size = codebook_size 20 | self.device = device 21 | 22 | self.reset_stats(num_streams) 23 | 24 | def reset_stats(self, num_streams): 25 | self.codebook_counts = { 26 | f"stream_{S}_group_{G+1}": torch.zeros(self.codebook_size, device=self.device) \ 27 | for S in range(num_streams) for G in range(self.num_groups) 28 | } # counts codeword stats for each codebook 29 | self.total_counts = 0 30 | self.dist = None # posterior distribution for each codebook 31 | self.entropy = None # entropy stats for each codebook 32 | 33 | self.max_entropy_per_book = np.log2(self.codebook_size) 34 | self.max_total_entropy = num_streams * self.num_groups * self.max_entropy_per_book 35 | self.num_streams = num_streams 36 | 37 | def update(self, codes): 38 | """ Update codebook counts and total counts from a batch of codes 39 | Args: 40 | codes: (B, num_streams, group_size, *) 41 | """ 42 | assert codes.size(1) == self.num_streams and codes.size(2) == self.num_groups, "code indices size not match" 43 | num_codes = codes.size(0) * codes.size(-1) 44 | self.total_counts += num_codes 45 | 46 | for s in range(self.num_streams): 47 | stream_s_code = codes[:, s] # (B, group_size, *) 48 | for g in range(self.num_groups): 49 | stream_s_group_g_code = stream_s_code[:,g] # (B, *) 50 | one_hot = F.one_hot(stream_s_group_g_code, num_classes=self.codebook_size) # (B, *, codebook_size) 51 | self.codebook_counts[f"stream_{s}_group_{g+1}"] += one_hot.view(-1, self.codebook_size).sum(0) # (*, codebook_size) 52 | 53 | def _form_distribution(self): 54 | """After iterating over a held-out set, compute posterior distribution for each codebook""" 55 | assert self.total_counts > 0, "No data collected, please update on a specific dataset" 56 | self.dist = {} 57 | for k, _counts in self.codebook_counts.items(): 58 | self.dist[k] = _counts / torch.tensor(self.total_counts, device=_counts.device) 59 | 60 | def _form_entropy(self): 61 | """After forming codebook posterior distributions, compute entropy for each distribution""" 62 | assert self.dist is not None, "Please compute posterior distribution first using self._form_distribution()" 63 | 64 | self.entropy = {} 65 | for k, dist in self.dist.items(): 66 | self.entropy[k] = (-torch.sum(dist * torch.log2(dist+1e-10))).item() 67 | 68 | def compute_utilization(self): 69 | """After forming entropy statistics for each codebook, compute utilization ratio (bitrate efficiency)""" 70 | 
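# Editorial note (worked example with illustrative numbers): utilization is the empirical codebook
# entropy divided by the maximum log2(codebook_size). With codebook_size=1024 the per-book
# maximum is 10 bits, so a measured entropy of 9.2 bits gives a per-book utilization of 0.92;
# the first returned value aggregates this ratio over all num_streams * num_groups codebooks.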
if self.dist is None: self._form_distribution() 71 | if self.entropy is None: self._form_entropy() 72 | 73 | utilization = {} 74 | for k, e in self.entropy.items(): 75 | utilization[k] = round(e/self.max_entropy_per_book, 4) 76 | 77 | return round(sum(self.entropy.values())/self.max_total_entropy, 4), utilization 78 | 79 | class PESQ: 80 | """Batch-wise computing of PESQ scores""" 81 | def __call__(self, x, y): 82 | """ 83 | Args: 84 | x: source audio Tensor (B, L) 85 | y: recon audio Tensor (B, L) 86 | returns: (B,) 87 | """ 88 | batch_pesq = [] 89 | for b in range(x.size(0)): 90 | ref = x[b].cpu().numpy() 91 | deg = y[b].cpu().numpy() 92 | batch_pesq.append(pesq(SR, ref, deg, 'wb')) 93 | 94 | return torch.tensor(batch_pesq) 95 | 96 | class MelSpectrogramDistance(nn.Module): 97 | """ 98 | L1 Log MelSpectrogram Distance 99 | Implementation adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py 100 | """ 101 | def __init__(self, win_lengths=MEL_WINDOWS, 102 | n_mels=MEL_BINS, clamp_eps=1e-5,): 103 | super().__init__() 104 | self.mel_transf = nn.ModuleList([ 105 | T.MelSpectrogram(sample_rate=SR, 106 | n_fft=w, win_length=w, hop_length=w//4, 107 | n_mels=n_mels[i], power=1) 108 | for i, w in enumerate(win_lengths) 109 | ]) 110 | self.clamp_eps = clamp_eps 111 | 112 | def forward(self, raw_audio, recon_audio): 113 | mel_loss = 0.0 114 | for mel_trans in self.mel_transf: 115 | x_mels, y_mels = mel_trans(raw_audio), mel_trans(recon_audio) 116 | mel_loss += F.l1_loss( # log mel loss 117 | x_mels.clamp(self.clamp_eps).pow(2).log10(), 118 | y_mels.clamp(self.clamp_eps).pow(2).log10(), 119 | reduction="none" 120 | ).mean(dim=[1,2]) 121 | return mel_loss 122 | 123 | class SISDR(nn.Module): 124 | """ 125 | Scale-Invariant Source-to-Distortion Ratio 126 | Implementation adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py 127 | """ 128 | def __init__(self, scaling: int = True, 129 | reduction: str = "none", zero_mean: int = True): 130 | self.scaling = scaling 131 | self.reduction = reduction 132 | self.zero_mean = zero_mean 133 | super().__init__() 134 | 135 | def forward(self, x, y): 136 | eps = 1e-8 137 | 138 | references = x.unsqueeze(1) if x.dim() == 2 else x # add channel dim 139 | estimates = y.unsqueeze(1) if y.dim() == 2 else y # add channel dim 140 | 141 | nb = references.shape[0] 142 | references = references.reshape(nb, 1, -1).permute(0, 2, 1) 143 | estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) 144 | 145 | # samples now on axis 1 146 | if self.zero_mean: 147 | mean_reference = references.mean(dim=1, keepdim=True) 148 | mean_estimate = estimates.mean(dim=1, keepdim=True) 149 | else: 150 | mean_reference = 0 151 | mean_estimate = 0 152 | 153 | _references = references - mean_reference 154 | _estimates = estimates - mean_estimate 155 | 156 | references_projection = (_references**2).sum(dim=-2) + eps 157 | references_on_estimates = (_estimates * _references).sum(dim=-2) + eps 158 | 159 | scale = ( 160 | (references_on_estimates / references_projection).unsqueeze(1) 161 | if self.scaling 162 | else 1 163 | ) 164 | e_true = scale * _references 165 | e_res = _estimates - e_true 166 | 167 | signal = (e_true**2).sum(dim=1) 168 | noise = (e_res**2).sum(dim=1) 169 | sdr = 10 * torch.log10(signal/noise + eps) 170 | 171 | return sdr.squeeze(1) # (B,) 172 | -------------------------------------------------------------------------------- /baselines/descript/dac/model/discriminator.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from audiotools import AudioSignal 5 | from audiotools import ml 6 | from audiotools import STFTParams 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | 11 | def WNConv1d(*args, **kwargs): 12 | act = kwargs.pop("act", True) 13 | conv = weight_norm(nn.Conv1d(*args, **kwargs)) 14 | if not act: 15 | return conv 16 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 17 | 18 | 19 | def WNConv2d(*args, **kwargs): 20 | act = kwargs.pop("act", True) 21 | conv = weight_norm(nn.Conv2d(*args, **kwargs)) 22 | if not act: 23 | return conv 24 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 25 | 26 | 27 | class MPD(nn.Module): 28 | def __init__(self, period): 29 | super().__init__() 30 | self.period = period 31 | self.convs = nn.ModuleList( 32 | [ 33 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), 34 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), 35 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), 36 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), 37 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), 38 | ] 39 | ) 40 | self.conv_post = WNConv2d( 41 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False 42 | ) 43 | 44 | def pad_to_period(self, x): 45 | t = x.shape[-1] 46 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect") 47 | return x 48 | 49 | def forward(self, x): 50 | fmap = [] 51 | 52 | x = self.pad_to_period(x) 53 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 54 | 55 | for layer in self.convs: 56 | x = layer(x) 57 | fmap.append(x) 58 | 59 | x = self.conv_post(x) 60 | fmap.append(x) 61 | 62 | return fmap 63 | 64 | 65 | class MSD(nn.Module): 66 | def __init__(self, rate: int = 1, sample_rate: int = 44100): 67 | super().__init__() 68 | self.convs = nn.ModuleList( 69 | [ 70 | WNConv1d(1, 16, 15, 1, padding=7), 71 | WNConv1d(16, 64, 41, 4, groups=4, padding=20), 72 | WNConv1d(64, 256, 41, 4, groups=16, padding=20), 73 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20), 74 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), 75 | WNConv1d(1024, 1024, 5, 1, padding=2), 76 | ] 77 | ) 78 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) 79 | self.sample_rate = sample_rate 80 | self.rate = rate 81 | 82 | def forward(self, x): 83 | x = AudioSignal(x, self.sample_rate) 84 | x.resample(self.sample_rate // self.rate) 85 | x = x.audio_data 86 | 87 | fmap = [] 88 | 89 | for l in self.convs: 90 | x = l(x) 91 | fmap.append(x) 92 | x = self.conv_post(x) 93 | fmap.append(x) 94 | 95 | return fmap 96 | 97 | 98 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 99 | 100 | 101 | class MRD(nn.Module): 102 | def __init__( 103 | self, 104 | window_length: int, 105 | hop_factor: float = 0.25, 106 | sample_rate: int = 44100, 107 | bands: list = BANDS, 108 | ): 109 | """Complex multi-band spectrogram discriminator. 110 | Parameters 111 | ---------- 112 | window_length : int 113 | Window length of STFT. 114 | hop_factor : float, optional 115 | Hop factor of the STFT, defaults to ``0.25 * window_length``. 116 | sample_rate : int, optional 117 | Sampling rate of audio in Hz, by default 44100 118 | bands : list, optional 119 | Bands to run discriminator over. 
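Notes
-----
(Editorial, illustrative): with the default ``bands`` and ``window_length=2048``, ``n_fft`` is
1025 bins, so the band boundaries become [0, 102), [102, 256), [256, 512), [512, 768),
[768, 1025); each band is processed by its own convolutional stack and the outputs are
concatenated before the final convolution.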
120 | """ 121 | super().__init__() 122 | 123 | self.window_length = window_length 124 | self.hop_factor = hop_factor 125 | self.sample_rate = sample_rate 126 | self.stft_params = STFTParams( 127 | window_length=window_length, 128 | hop_length=int(window_length * hop_factor), 129 | match_stride=True, 130 | ) 131 | 132 | n_fft = window_length // 2 + 1 133 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 134 | self.bands = bands 135 | 136 | ch = 32 137 | convs = lambda: nn.ModuleList( 138 | [ 139 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), 140 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 143 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), 144 | ] 145 | ) 146 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 147 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) 148 | 149 | def spectrogram(self, x): 150 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) 151 | x = torch.view_as_real(x.stft()) 152 | x = rearrange(x, "b 1 f t c -> (b 1) c t f") 153 | # Split into bands 154 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 155 | return x_bands 156 | 157 | def forward(self, x): 158 | x_bands = self.spectrogram(x) 159 | fmap = [] 160 | 161 | x = [] 162 | for band, stack in zip(x_bands, self.band_convs): 163 | for layer in stack: 164 | band = layer(band) 165 | fmap.append(band) 166 | x.append(band) 167 | 168 | x = torch.cat(x, dim=-1) 169 | x = self.conv_post(x) 170 | fmap.append(x) 171 | 172 | return fmap 173 | 174 | 175 | class Discriminator(ml.BaseModel): 176 | def __init__( 177 | self, 178 | rates: list = [], 179 | periods: list = [2, 3, 5, 7, 11], 180 | fft_sizes: list = [2048, 1024, 512], 181 | sample_rate: int = 44100, 182 | bands: list = BANDS, 183 | ): 184 | """Discriminator that combines multiple discriminators. 185 | 186 | Parameters 187 | ---------- 188 | rates : list, optional 189 | sampling rates (in Hz) to run MSD at, by default [] 190 | If empty, MSD is not used. 
191 | periods : list, optional 192 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] 193 | fft_sizes : list, optional 194 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] 195 | sample_rate : int, optional 196 | Sampling rate of audio in Hz, by default 44100 197 | bands : list, optional 198 | Bands to run MRD at, by default `BANDS` 199 | """ 200 | super().__init__() 201 | discs = [] 202 | discs += [MPD(p) for p in periods] 203 | discs += [MSD(r, sample_rate=sample_rate) for r in rates] 204 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] 205 | self.discriminators = nn.ModuleList(discs) 206 | 207 | def preprocess(self, y): 208 | # Remove DC offset 209 | y = y - y.mean(dim=-1, keepdims=True) 210 | # Peak normalize the volume of input audio 211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 212 | return y 213 | 214 | def forward(self, x): 215 | x = self.preprocess(x) 216 | fmaps = [d(x) for d in self.discriminators] 217 | return fmaps 218 | 219 | 220 | if __name__ == "__main__": 221 | disc = Discriminator() 222 | x = torch.zeros(1, 1, 44100) 223 | results = disc(x) 224 | for i, result in enumerate(results): 225 | print(f"disc{i}") 226 | for i, r in enumerate(result): 227 | print(r.shape, r.mean(), r.min(), r.max()) 228 | print() 229 | -------------------------------------------------------------------------------- /esc/models/discriminator.py: -------------------------------------------------------------------------------- 1 | """Same discriminator from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/model/discriminator.py 2 | Requires audiotools package to be installed. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from audiotools import AudioSignal 9 | from audiotools import ml 10 | from audiotools import STFTParams 11 | from einops import rearrange 12 | from torch.nn.utils import weight_norm 13 | 14 | 15 | def WNConv1d(*args, **kwargs): 16 | act = kwargs.pop("act", True) 17 | conv = weight_norm(nn.Conv1d(*args, **kwargs)) 18 | if not act: 19 | return conv 20 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 21 | 22 | 23 | def WNConv2d(*args, **kwargs): 24 | act = kwargs.pop("act", True) 25 | conv = weight_norm(nn.Conv2d(*args, **kwargs)) 26 | if not act: 27 | return conv 28 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 29 | 30 | 31 | class MPD(nn.Module): 32 | def __init__(self, period): 33 | super().__init__() 34 | self.period = period 35 | self.convs = nn.ModuleList( 36 | [ 37 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), 38 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), 39 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), 40 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), 41 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), 42 | ] 43 | ) 44 | self.conv_post = WNConv2d( 45 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False 46 | ) 47 | 48 | def pad_to_period(self, x): 49 | t = x.shape[-1] 50 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect") 51 | return x 52 | 53 | def forward(self, x): 54 | fmap = [] 55 | 56 | x = self.pad_to_period(x) 57 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 58 | 59 | for layer in self.convs: 60 | x = layer(x) 61 | fmap.append(x) 62 | 63 | x = self.conv_post(x) 64 | fmap.append(x) 65 | 66 | return fmap 67 | 68 | 69 | class MSD(nn.Module): 70 | def __init__(self, rate: int = 1, sample_rate: int = 44100): 71 | super().__init__() 72 | self.convs = nn.ModuleList( 73 
| [ 74 | WNConv1d(1, 16, 15, 1, padding=7), 75 | WNConv1d(16, 64, 41, 4, groups=4, padding=20), 76 | WNConv1d(64, 256, 41, 4, groups=16, padding=20), 77 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20), 78 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), 79 | WNConv1d(1024, 1024, 5, 1, padding=2), 80 | ] 81 | ) 82 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) 83 | self.sample_rate = sample_rate 84 | self.rate = rate 85 | 86 | def forward(self, x): 87 | x = AudioSignal(x, self.sample_rate) 88 | x.resample(self.sample_rate // self.rate) 89 | x = x.audio_data 90 | 91 | fmap = [] 92 | 93 | for l in self.convs: 94 | x = l(x) 95 | fmap.append(x) 96 | x = self.conv_post(x) 97 | fmap.append(x) 98 | 99 | return fmap 100 | 101 | 102 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 103 | 104 | 105 | class MRD(nn.Module): 106 | def __init__( 107 | self, 108 | window_length: int, 109 | hop_factor: float = 0.25, 110 | sample_rate: int = 44100, 111 | bands: list = BANDS, 112 | ): 113 | """Complex multi-band spectrogram discriminator. 114 | Parameters 115 | ---------- 116 | window_length : int 117 | Window length of STFT. 118 | hop_factor : float, optional 119 | Hop factor of the STFT, defaults to ``0.25 * window_length``. 120 | sample_rate : int, optional 121 | Sampling rate of audio in Hz, by default 44100 122 | bands : list, optional 123 | Bands to run discriminator over. 124 | """ 125 | super().__init__() 126 | 127 | self.window_length = window_length 128 | self.hop_factor = hop_factor 129 | self.sample_rate = sample_rate 130 | self.stft_params = STFTParams( 131 | window_length=window_length, 132 | hop_length=int(window_length * hop_factor), 133 | match_stride=True, 134 | ) 135 | 136 | n_fft = window_length // 2 + 1 137 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 138 | self.bands = bands 139 | 140 | ch = 32 141 | convs = lambda: nn.ModuleList( 142 | [ 143 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), 144 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 145 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 146 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 147 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), 148 | ] 149 | ) 150 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 151 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) 152 | 153 | def spectrogram(self, x): 154 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) 155 | x = torch.view_as_real(x.stft()) 156 | x = rearrange(x, "b 1 f t c -> (b 1) c t f") 157 | # Split into bands 158 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 159 | return x_bands 160 | 161 | def forward(self, x): 162 | x_bands = self.spectrogram(x) 163 | fmap = [] 164 | 165 | x = [] 166 | for band, stack in zip(x_bands, self.band_convs): 167 | for layer in stack: 168 | band = layer(band) 169 | fmap.append(band) 170 | x.append(band) 171 | 172 | x = torch.cat(x, dim=-1) 173 | x = self.conv_post(x) 174 | fmap.append(x) 175 | 176 | return fmap 177 | 178 | 179 | class Discriminator(ml.BaseModel): 180 | def __init__( 181 | self, 182 | rates: list = [], 183 | periods: list = [2, 3, 5, 7, 11], 184 | fft_sizes: list = [2048, 1024, 512], 185 | sample_rate: int = 44100, 186 | bands: list = BANDS, 187 | ): 188 | """Discriminator that combines multiple discriminators. 
189 | 190 | Parameters 191 | ---------- 192 | rates : list, optional 193 | sampling rates (in Hz) to run MSD at, by default [] 194 | If empty, MSD is not used. 195 | periods : list, optional 196 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] 197 | fft_sizes : list, optional 198 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] 199 | sample_rate : int, optional 200 | Sampling rate of audio in Hz, by default 44100 201 | bands : list, optional 202 | Bands to run MRD at, by default `BANDS` 203 | """ 204 | super().__init__() 205 | discs = [] 206 | discs += [MPD(p) for p in periods] 207 | discs += [MSD(r, sample_rate=sample_rate) for r in rates] 208 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] 209 | self.discriminators = nn.ModuleList(discs) 210 | 211 | def preprocess(self, y): 212 | # Remove DC offset 213 | y = y - y.mean(dim=-1, keepdims=True) 214 | # Peak normalize the volume of input audio 215 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 216 | return y 217 | 218 | def forward(self, x): 219 | x = self.preprocess(x) 220 | fmaps = [d(x) for d in self.discriminators] 221 | return fmaps 222 | 223 | 224 | if __name__ == "__main__": 225 | disc = Discriminator() 226 | x = torch.zeros(1, 1, 44100) 227 | results = disc(x) 228 | for i, result in enumerate(results): 229 | print(f"disc{i}") 230 | for i, r in enumerate(result): 231 | print(r.shape, r.mean(), r.min(), r.max()) 232 | print() -------------------------------------------------------------------------------- /esc/models/csrvq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Literal 4 | 5 | from ..modules import ProductVectorQuantize, TransformerLayer, PatchDeEmbed, ConvolutionLayer, Convolution2D 6 | from .utils import blk_func 7 | 8 | class CrossScaleRVQ(nn.Module): 9 | """Cross-Scale Residual Vector Quantization Framework""" 10 | def __init__(self, backbone: Literal['transformer', 'convolution']="transformer") -> None: 11 | super().__init__() 12 | if backbone == "transformer": self.dims = 3 13 | elif backbone == "convolution": self.dims = 4 14 | 15 | def pre_fuse(self, enc, dec): 16 | """Compute residuals to quantize""" 17 | return enc - dec 18 | 19 | def post_fuse(self, residual_q, dec): 20 | """Add back quantized residuals""" 21 | return residual_q + dec 22 | 23 | def csrvq(self, enc: torch.tensor, dec: torch.tensor, vq: ProductVectorQuantize, 24 | transmit: bool=True, freeze_vq: bool=False): 25 | """ Forward Function combining encoding and decoding at a single bitstream/resolution scale 26 | Args: 27 | enc (Tensor): Tensor of encoded feature with shape (B, H*W, C) / (B, C, H, W) 28 | dec (Tensor): Tensor of decoded feature with shape (B, H*W, C) / (B, C, H, W) 29 | vq (ProductVectorQuantize): product quantizer at this stream level 30 | transmit (Boolean): whether this stream is transmitted (perform quantization or not) 31 | freeze_vq (Boolean): whether freeze the codebook (in a pre-training stage) 32 | Returns: 33 | Tensor of dec_refine (decoded feature conditioned on quantized encodings) 34 | """ 35 | if not self.training and not transmit: 36 | return dec, 0., 0., None 37 | 38 | residual = self.pre_fuse(enc, dec) 39 | outputs = vq(residual, freeze_vq) 40 | residual_q, code = outputs["z_q"], outputs["codes"] 41 | cm_loss, cb_loss = outputs["cm_loss"], outputs["cb_loss"] 42 | 43 | if not transmit: # masking non-transmitted streams 44 | cm_loss, cb_loss = 
cm_loss * 0., cb_loss * 0. 45 | residual_q *= 0. 46 | 47 | dec_refine = self.post_fuse(residual_q, dec) 48 | return dec_refine, cm_loss, cb_loss, code 49 | 50 | def csrvq_encode(self, enc, dec, vq): 51 | 52 | residual = self.pre_fuse(enc, dec) 53 | code = vq.encode(residual) 54 | return code 55 | 56 | def csrvq_decode(self, codes, dec, vq): 57 | 58 | residual_q = vq.decode(codes, self.dims) 59 | dec_refine = self.post_fuse(residual_q, dec) 60 | return dec_refine 61 | 62 | 63 | class CrossScaleRVQDecoder(CrossScaleRVQ): 64 | def __init__(self, 65 | backbone: Literal['transformer', 'convolution'], 66 | in_freq: int, 67 | in_dim: int, 68 | h_dims: list, 69 | patch_size: tuple, 70 | kernel_size: list=[], 71 | conv_depth: int=1, 72 | swin_heads: list=[], 73 | swin_depth: int=2, 74 | window_size: int=4, 75 | mlp_ratio: float=4.,) -> None: 76 | super().__init__(backbone) 77 | 78 | in_dims, out_dims = h_dims[:-1], h_dims[1:] 79 | 80 | blocks = nn.ModuleList() 81 | for i in range(len(in_dims)): 82 | layer = ConvolutionLayer(in_dims[i], out_dims[i], conv_depth, kernel_size, transpose=True) if backbone == "convolution" \ 83 | else TransformerLayer( 84 | in_dims[i], out_dims[i], swin_heads[i], swin_depth, window_size, mlp_ratio, 85 | activation=nn.GELU, norm_layer=nn.LayerNorm, scale="up", scale_factor=(2,1) 86 | ) 87 | blocks.append(layer) 88 | 89 | self.patch_deembed = PatchDeEmbed(in_freq, in_dim, patch_size, h_dims[-1], backbone) 90 | self.post_nn = Convolution2D(h_dims[-1], h_dims[-1], kernel_size, scale=False) if backbone == "convolution" \ 91 | else TransformerLayer( 92 | h_dims[-1], h_dims[-1], swin_heads[-1], swin_depth, window_size, mlp_ratio, 93 | activation=nn.GELU, norm_layer=nn.LayerNorm, scale=None 94 | ) 95 | self.blocks = blocks 96 | 97 | def forward(self, enc_hs: list, num_streams: int, quantizers: nn.ModuleList, feat_shape: tuple, freeze_vq: bool=False): 98 | """Forward Function: step-wise cross-scale decoding 99 | Args: 100 | enc_hs (List[Tensor, ...]): a list of encoded features at all scales 101 | num_streams (int): number of bitstreams to use (max_streams when freeze_vq is True) 102 | quantizers (ModuleList): a modulelist of multi-scale quantizers 103 | feat_shape (Tuple): (Wh, Ww) feature shape at bottom level 104 | freeze_vq (Boolean): freeze vq layers during pre-training 105 | Returns: 106 | recon_feat: reconstructed complex spectrum (Bs, 2, F, T) 107 | codes: discrete indices (Bs, num_streams, group_size, T//overlap) 108 | num_streams is always max_stream in training mode 109 | cm_loss, cb_loss: VQ losses (Bs, ) 110 | """ 111 | z0, cm_loss, cb_loss, code = self.csrvq(enc=enc_hs[-1], dec=0.0, vq=quantizers[0], 112 | transmit=True, freeze_vq=freeze_vq) 113 | codes, dec_hs = [code], [z0] 114 | for i, blk in enumerate(self.blocks): 115 | dec_i_refine, cm_loss_i, cb_loss_i, code_i = self.csrvq( 116 | enc=enc_hs[-1-i], dec=dec_hs[i], vq=quantizers[i+1], transmit=(i