├── AudioDec ├── About │ └── README.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── bin │ ├── stream.py │ ├── test.py │ ├── train.py │ └── utils.py ├── config │ ├── autoencoder │ │ ├── symAAD_vctk_48000_hop300.yaml │ │ ├── symAD_c16_vctk_48000_hop320.yaml │ │ ├── symAD_libritts_24000_hop300.yaml │ │ ├── symAD_vctk_48000_hop300.yaml │ │ └── symADuniv_vctk_48000_hop300.yaml │ ├── denoise │ │ └── symAD_vctk_48000_hop300.yaml │ ├── statistic │ │ ├── symAD_libritts_24000_hop300_clean.yaml │ │ ├── symAD_vctk_48000_hop300_clean.yaml │ │ └── symADuniv_vctk_48000_hop300_clean.yaml │ └── vocoder │ │ ├── AudioDec_v0_symAD_vctk_48000_hop300_clean.yaml │ │ ├── AudioDec_v1_symAD_libritts_24000_hop300_clean.yaml │ │ ├── AudioDec_v1_symAD_vctk_48000_hop300_clean.yaml │ │ ├── AudioDec_v2_symAD_vctk_48000_hop300_clean.yaml │ │ └── AudioDec_v3_symADuniv_vctk_48000_hop300_clean.yaml ├── dataloader │ ├── __init__.py │ ├── collater.py │ ├── dataset.py │ └── utils.py ├── figs │ ├── architecture-1.png │ ├── latency.jpg │ └── mos.jpg ├── layers │ ├── activation_function.py │ ├── conv_layer.py │ └── vq_module.py ├── losses │ ├── __init__.py │ ├── adversarial_loss.py │ ├── feat_match_loss.py │ ├── mel_loss.py │ ├── stft_loss.py │ └── waveform_loss.py ├── models │ ├── autoencoder │ │ ├── AudioDec.py │ │ └── modules │ │ │ ├── decoder.py │ │ │ ├── encoder.py │ │ │ ├── projector.py │ │ │ ├── quantizer.py │ │ │ └── residual_unit.py │ ├── utils.py │ └── vocoder │ │ ├── HiFiGAN.py │ │ └── modules │ │ ├── discriminator.py │ │ ├── multi_fusion.py │ │ └── residual_block.py ├── parse_options.sh ├── requirements.txt ├── slurmlogs │ └── README.md ├── stats │ ├── symAD_libritts_24000_hop300_clean.npy │ ├── symAD_vctk_48000_hop300_clean.npy │ └── symADuniv_vctk_48000_hop300_clean.npy ├── trainer │ ├── autoencoder.py │ ├── denoise.py │ ├── trainerGAN.py │ └── vocoder.py └── utils │ └── audiodec.py ├── LICENSE ├── README.md ├── config └── hparams.yaml ├── infer_tts.py ├── models ├── ar.py └── nar.py ├── requirements.txt ├── test └── prompt_wavs │ └── test_1.wav ├── text ├── chinese.py ├── chinese_dict ├── opencpop-strict.txt ├── symbols.py └── tone_sandhi.py └── utils ├── hparams.py └── utils.py /AudioDec/About/README.md: -------------------------------------------------------------------------------- 1 | # AudioDec: An Open-source Streaming High-fidelity Neural Audio Codec 2 | 3 | 4 | 5 | ## Audio quality (Subjective MOS) 6 |

(figure: subjective MOS results; figs/mos.jpg)

9 | 10 | 11 | ## Latency 12 |

(figure: latency results; figs/latency.jpg)

15 | 16 | 17 | 18 | ## Model size 19 | Only the generators. 20 | ```bash 21 | # AutoEncoder (symAD) 22 | # - Encoder 23 | Number of total parmeters:3,806,368 24 | Model size: 14.68MB 25 | # - Decoder 26 | Number of total parmeters:4,035,264 27 | Model size: 15.54MB 28 | # Vocoder (AD v0) 29 | Number of total parmeters:12,932,610 30 | Model size: 49.74MB 31 | # Vocoder (AD v1) 32 | Number of total parmeters:19,461,090 33 | Model size: 74.90MB 34 | # Vocoder (AD v2) 35 | Number of total parmeters:6,927,330 36 | Model size: 26.56MB 37 | ``` 38 | -------------------------------------------------------------------------------- /AudioDec/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /AudioDec/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to AudioDec 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | AudioDec is the implementation of a research paper. 7 | Therefore, we do not plan on accepting many pull requests for new features. 8 | We certainly welcome them for bug fixes. 9 | 10 | 1. Fork the repo and create your branch from `main`. 11 | 2. If you've added code that should be tested, add tests. 12 | 3. If you've changed APIs, update the documentation. 13 | 4. Ensure the test suite passes. 14 | 5. Make sure your code lints. 15 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 16 | 17 | ## Contributor License Agreement ("CLA") 18 | In order to accept your pull request, we need you to submit a CLA. You only need 19 | to do this once to work on any of Meta's open source projects. 20 | 21 | Complete your CLA here: 22 | 23 | ## Issues 24 | We use GitHub issues to track public bugs. Please ensure your description is 25 | clear and has sufficient instructions to be able to reproduce the issue. 26 | 27 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 28 | disclosure of security bugs. In those cases, please go through the process 29 | outlined on that page and do not file a public issue. 30 | 31 | ## License 32 | By contributing to AudioDec, you agree that your contributions will be licensed 33 | under the LICENSE file in the root directory of this source tree. 34 | -------------------------------------------------------------------------------- /AudioDec/bin/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | """Testing stage template.""" 13 | 14 | import os 15 | import abc 16 | import sys 17 | import time 18 | import yaml 19 | import torch 20 | import logging 21 | import soundfile as sf 22 | 23 | from tqdm import tqdm 24 | from bin.utils import load_config 25 | 26 | class TestGEN(abc.ABC): 27 | def __init__( 28 | self, 29 | args, 30 | ): 31 | # set logger 32 | logging.basicConfig( 33 | level=logging.INFO, 34 | stream=sys.stdout, 35 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 36 | ) 37 | 38 | # device 39 | if not torch.cuda.is_available(): 40 | self.device = torch.device('cpu') 41 | logging.info(f"device: cpu") 42 | else: 43 | self.device = torch.device('cuda') 44 | logging.info(f"device: gpu") 45 | 46 | # initialize attribute 47 | if hasattr(args, 'encoder'): 48 | self.encoder_checkpoint = args.encoder 49 | self.encoder_config = self._load_config(args.encoder) 50 | if hasattr(args, 'decoder'): 51 | self.decoder_checkpoint = args.decoder 52 | self.decoder_config = self._load_config(args.decoder) 53 | self.encoder = None 54 | self.decoder = None 55 | self.dataset = None 56 | self.outdir = None 57 | 58 | 59 | @abc.abstractmethod 60 | def initial_folder(self, output_name): 61 | pass 62 | 63 | 64 | @abc.abstractmethod 65 | def load_dataset(self): 66 | pass 67 | 68 | 69 | @abc.abstractmethod 70 | def load_encoder(self): 71 | pass 72 | 73 | 74 | @abc.abstractmethod 75 | def load_decoder(self): 76 | pass 77 | 78 | 79 | @abc.abstractmethod 80 | def encode(self, x): 81 | pass 82 | 83 | 84 | @abc.abstractmethod 85 | def decode(self, z): 86 | pass 87 | 88 | 89 | def run(self): 90 | total_rtf = 0.0 91 | with torch.no_grad(), tqdm(self.dataset, desc="[test]") as pbar: 92 | for idx, (utt_id, x) in enumerate(pbar, 1): 93 | start = time.time() 94 | zq = self.encode(x) 95 | y = self.decode(zq) 96 | y = y.squeeze(1).transpose(1, 0).cpu().numpy() # T x C 97 | rtf = (time.time() - start) / (len(y) / self.decoder_config['sampling_rate']) 98 | pbar.set_postfix({"RTF": rtf}) 99 | total_rtf += rtf 100 | 101 | # output wav file 102 | self._save_wav(os.path.join(self.outdir, f"{utt_id}_output.wav"), y) 103 | 104 | logging.info( 105 | "Finished generation of %d utterances (RTF = %.03f)." % (idx, (total_rtf / idx)) 106 | ) 107 | 108 | 109 | def _save_wav(self, file_name, audio): 110 | sf.write( 111 | file_name, 112 | audio, 113 | self.decoder_config['sampling_rate'], 114 | "PCM_16", 115 | ) 116 | 117 | 118 | def _load_config(self, checkpoint, config_name='config.yml'): 119 | return load_config(checkpoint, config_name) 120 | -------------------------------------------------------------------------------- /AudioDec/bin/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | """Training stage template.""" 13 | 14 | import os 15 | import abc 16 | import sys 17 | import yaml 18 | import random 19 | import logging 20 | import torch 21 | import numpy as np 22 | 23 | from bin.utils import load_config 24 | 25 | 26 | class TrainGAN(abc.ABC): 27 | def __init__( 28 | self, 29 | args, 30 | ): 31 | # set logger 32 | logging.basicConfig( 33 | level=logging.INFO, 34 | stream=sys.stdout, 35 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 36 | ) 37 | 38 | # Fix seed and make backends deterministic 39 | random.seed(args.seed) 40 | np.random.seed(args.seed) 41 | torch.manual_seed(args.seed) 42 | if not torch.cuda.is_available(): 43 | self.device = torch.device('cpu') 44 | logging.info(f"device: cpu") 45 | else: 46 | self.device = torch.device('cuda') 47 | logging.info(f"device: gpu") 48 | torch.cuda.manual_seed_all(args.seed) 49 | if args.disable_cudnn == "False": 50 | torch.backends.cudnn.benchmark = True 51 | 52 | # initialize config 53 | with open(args.config, 'r') as f: 54 | self.config = yaml.load(f, Loader=yaml.FullLoader) 55 | self.config.update(vars(args)) 56 | 57 | # initialize model folder 58 | expdir = os.path.join(args.exp_root, args.tag) 59 | os.makedirs(expdir, exist_ok=True) 60 | self.config["outdir"] = expdir 61 | 62 | # save config 63 | with open(os.path.join(expdir, "config.yml"), "w") as f: 64 | yaml.dump(self.config, f, Dumper=yaml.Dumper) 65 | for key, value in self.config.items(): 66 | logging.info(f"[TrainGAN] {key} = {value}") 67 | 68 | # initialize attribute 69 | self.resume = args.resume 70 | self.data_loader = None 71 | self.model = {} 72 | self.criterion = None 73 | self.optimizer = None 74 | self.scheduler = None 75 | self.trainer = None 76 | 77 | # initialize batch_length 78 | self.batch_length = self.config['batch_length'] 79 | 80 | 81 | @abc.abstractmethod 82 | def initialize_data_loader(self): 83 | pass 84 | 85 | 86 | @abc.abstractmethod 87 | def define_model(self): 88 | pass 89 | 90 | 91 | @abc.abstractmethod 92 | def define_trainer(self): 93 | pass 94 | 95 | 96 | @abc.abstractmethod 97 | def initialize_model(self): 98 | pass 99 | 100 | 101 | @abc.abstractmethod 102 | def define_criterion(self): 103 | pass 104 | 105 | 106 | def run(self): 107 | try: 108 | logging.info(f"The current training step: {self.trainer.steps}") 109 | self.trainer.train_max_steps = self.config["train_max_steps"] 110 | if not self.trainer._check_train_finish(): 111 | self.trainer.run() 112 | if self.config.get("adv_train_max_steps", False) and self.config.get("adv_batch_length", False): 113 | self.batch_length = self.config['adv_batch_length'] 114 | logging.info(f"Reload dataloader for adversarial training.") 115 | self.initialize_data_loader() 116 | self.trainer.data_loader = self.data_loader 117 | self.trainer.train_max_steps = self.config["adv_train_max_steps"] 118 | self.trainer.run() 119 | finally: 120 | self.trainer.save_checkpoint( 121 | os.path.join(self.config["outdir"], f"checkpoint-{self.trainer.steps}steps.pkl") 122 | ) 123 | logging.info(f"Successfully saved checkpoint @ {self.trainer.steps}steps.") 124 | 125 | 126 | def _show_setting(self): 127 | logging.info(self.model['generator']) 128 | logging.info(self.model['discriminator']) 129 | logging.info(self.optimizer['generator']) 130 | logging.info(self.optimizer['discriminator']) 131 | logging.info(self.scheduler['generator']) 132 | logging.info(self.scheduler['discriminator']) 133 | for criterion_ 
in self.criterion.values(): 134 | logging.info(criterion_) 135 | 136 | 137 | def _load_config(self, checkpoint, config_name='config.yml'): 138 | return load_config(checkpoint, config_name) 139 | 140 | -------------------------------------------------------------------------------- /AudioDec/bin/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | 11 | """Utility modules.""" 12 | 13 | import os 14 | import yaml 15 | 16 | 17 | def load_config(checkpoint, config_name='config.yml'): 18 | dirname = os.path.dirname(checkpoint) 19 | config_path = os.path.join(dirname, config_name) 20 | with open(config_path) as f: 21 | config = yaml.load(f, Loader=yaml.Loader) 22 | return config 23 | -------------------------------------------------------------------------------- /AudioDec/config/autoencoder/symAAD_vctk_48000_hop300.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 8 | 9 | 10 | ########################################################### 11 | # DATA SETTING # 12 | ########################################################### 13 | sampling_rate: &sampling_rate 48000 14 | data: 15 | path: /mnt/home/yichiaowu/datasets/list 16 | subset: 17 | train: vctk_48000_clean_trainset_84spk.txt 18 | valid: vctk_48000_clean_validset_84spk.txt 19 | test: vctk_48000_clean_testset.txt 20 | vhstest: VHS_48000_raw_Phase1_test.txt 21 | 22 | ########################################################### 23 | # MODEL SETTING # 24 | ########################################################### 25 | model_type: symAudioDec 26 | train_mode: autoencoder 27 | paradigm: efficient 28 | 29 | generator_params: 30 | input_channels: 1 31 | output_channels: 1 32 | encode_channels: 32 33 | decode_channels: 32 34 | code_dim: 64 35 | codebook_num: 8 36 | codebook_size: 1024 37 | bias: true 38 | enc_ratios: [2, 4, 8, 16] 39 | dec_ratios: [16, 8, 4, 2] 40 | enc_strides: [3, 4, 5, 5] 41 | dec_strides: [5, 5, 4, 3] 42 | mode: causal 43 | codec: activate_audiodec 44 | projector: conv1d 45 | quantier: residual_vq 46 | use_weight_norm: true 47 | 48 | discriminator_params: 49 | scales: 3 # Number of multi-scale discriminator. 50 | scale_downsample_pooling: AvgPool1d # Pooling operation for scale discriminator. 51 | scale_downsample_pooling_params: 52 | kernel_size: 4 # Pooling kernel size. 53 | stride: 2 # Pooling stride. 54 | padding: 2 # Padding size. 55 | scale_discriminator_params: 56 | in_channels: 1 # Number of input channels. 57 | out_channels: 1 # Number of output channels. 58 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. 59 | channels: 128 # Initial number of channels. 60 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 61 | max_groups: 16 # Maximum number of groups in downsampling conv layers. 62 | bias: true 63 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. 64 | nonlinear_activation: LeakyReLU # Nonlinear activation. 
65 | nonlinear_activation_params: 66 | negative_slope: 0.1 67 | follow_official_norm: true # Whether to follow the official norm setting. 68 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. 69 | period_discriminator_params: 70 | in_channels: 1 # Number of input channels. 71 | out_channels: 1 # Number of output channels. 72 | kernel_sizes: [5, 3] # List of kernel sizes. 73 | channels: 32 # Initial number of channels. 74 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. 75 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 76 | bias: true # Whether to use bias parameter in conv layer. 77 | nonlinear_activation: LeakyReLU # Nonlinear activation. 78 | nonlinear_activation_params: # Nonlinear activation paramters. 79 | negative_slope: 0.1 80 | use_weight_norm: true # Whether to apply weight normalization. 81 | use_spectral_norm: false # Whether to apply spectral normalization. 82 | 83 | ########################################################### 84 | # METRIC LOSS SETTING # 85 | ########################################################### 86 | use_mel_loss: true # Whether to use Mel-spectrogram loss. 87 | mel_loss_params: 88 | fs: *sampling_rate 89 | fft_sizes: [2048] 90 | hop_sizes: [300] 91 | win_lengths: [2048] 92 | window: hann_window 93 | num_mels: 80 94 | fmin: 0 95 | fmax: 24000 96 | log_base: null 97 | 98 | use_stft_loss: false # Whether to use multi-resolution STFT loss. 99 | stft_loss_params: 100 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. 101 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss 102 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. 103 | window: hann_window # Window function for STFT-based loss 104 | 105 | use_shape_loss: false # Whether to use waveform shape loss. 106 | shape_loss_params: 107 | winlen: [300] 108 | 109 | ########################################################### 110 | # ADV LOSS SETTING # 111 | ########################################################### 112 | generator_adv_loss_params: 113 | average_by_discriminators: false # Whether to average loss by #discriminators. 114 | 115 | discriminator_adv_loss_params: 116 | average_by_discriminators: false # Whether to average loss by #discriminators. 117 | 118 | use_feat_match_loss: true 119 | feat_match_loss_params: 120 | average_by_discriminators: false # Whether to average loss by #discriminators. 121 | average_by_layers: false # Whether to average loss by #layers in each discriminator. 122 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation. 123 | 124 | ########################################################### 125 | # LOSS WEIGHT SETTING # 126 | ########################################################### 127 | lambda_adv: 1.0 # Loss weight of adversarial loss. 128 | lambda_feat_match: 2.0 # Loss weight of feat match loss. 129 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss. 130 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram spectloss. 131 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss. 132 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss. 133 | 134 | ########################################################### 135 | # DATA LOADER SETTING # 136 | ########################################################### 137 | batch_size: 16 # Batch size. 138 | batch_length: 9600 # Length of each audio in batch (training w/o adv). Make sure dividable by hop_size. 
139 | adv_batch_length: 9600 # Length of each audio in batch (training w/ adv). Make sure dividable by hop_size. 140 | pin_memory: true # Whether to pin memory in Pytorch DataLoader. 141 | num_workers: 2 # Number of workers in Pytorch DataLoader. 142 | 143 | ########################################################### 144 | # OPTIMIZER & SCHEDULER SETTING # 145 | ########################################################### 146 | generator_optimizer_type: Adam 147 | generator_optimizer_params: 148 | lr: 1.0e-4 149 | betas: [0.5, 0.9] 150 | weight_decay: 0.0 151 | generator_scheduler_type: StepLR 152 | generator_scheduler_params: 153 | step_size: 200000 # Generator scheduler step size. 154 | gamma: 1.0 155 | generator_grad_norm: -1 156 | discriminator_optimizer_type: Adam 157 | discriminator_optimizer_params: 158 | lr: 2.0e-4 159 | betas: [0.5, 0.9] 160 | weight_decay: 0.0 161 | discriminator_scheduler_type: MultiStepLR 162 | discriminator_scheduler_params: 163 | gamma: 0.5 164 | milestones: 165 | - 200000 166 | - 400000 167 | - 600000 168 | - 800000 169 | discriminator_grad_norm: -1 170 | 171 | ########################################################### 172 | # INTERVAL SETTING # 173 | ########################################################### 174 | start_steps: # Number of steps to start training 175 | generator: 0 176 | discriminator: 200000 177 | train_max_steps: 200000 # Number of training steps. (w/o adv) 178 | adv_train_max_steps: 700000 # Number of training steps. (w/ adv) 179 | save_interval_steps: 100000 # Interval steps to save checkpoint. 180 | eval_interval_steps: 1000 # Interval steps to evaluate the network. 181 | log_interval_steps: 100 # Interval steps to record the training log. 182 | -------------------------------------------------------------------------------- /AudioDec/config/autoencoder/symAD_c16_vctk_48000_hop320.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 8 | 9 | 10 | ########################################################### 11 | # DATA SETTING # 12 | ########################################################### 13 | sampling_rate: &sampling_rate 48000 14 | data: 15 | path: /mnt/home/yichiaowu/datasets/list 16 | subset: 17 | train: vctk_48000_clean_trainset_84spk.txt 18 | valid: vctk_48000_clean_validset_84spk.txt 19 | test: vctk_48000_clean_testset.txt 20 | 21 | ########################################################### 22 | # MODEL SETTING # 23 | ########################################################### 24 | model_type: symAudioDec 25 | train_mode: autoencoder 26 | paradigm: efficient 27 | 28 | generator_params: 29 | input_channels: 1 30 | output_channels: 1 31 | encode_channels: 32 32 | decode_channels: 32 33 | code_dim: 64 34 | codebook_num: 16 35 | codebook_size: 1024 36 | bias: true 37 | enc_ratios: [2, 4, 8, 16] 38 | dec_ratios: [16, 8, 4, 2] 39 | enc_strides: [2, 4, 5, 8] 40 | dec_strides: [8, 5, 4, 2] 41 | mode: causal 42 | codec: audiodec 43 | projector: conv1d 44 | quantier: residual_vq 45 | 46 | discriminator_params: 47 | scales: 3 # Number of multi-scale discriminator. 48 | scale_downsample_pooling: AvgPool1d # Pooling operation for scale discriminator. 
49 | scale_downsample_pooling_params: 50 | kernel_size: 4 # Pooling kernel size. 51 | stride: 2 # Pooling stride. 52 | padding: 2 # Padding size. 53 | scale_discriminator_params: 54 | in_channels: 1 # Number of input channels. 55 | out_channels: 1 # Number of output channels. 56 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. 57 | channels: 128 # Initial number of channels. 58 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 59 | max_groups: 16 # Maximum number of groups in downsampling conv layers. 60 | bias: true 61 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. 62 | nonlinear_activation: LeakyReLU # Nonlinear activation. 63 | nonlinear_activation_params: 64 | negative_slope: 0.1 65 | follow_official_norm: true # Whether to follow the official norm setting. 66 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. 67 | period_discriminator_params: 68 | in_channels: 1 # Number of input channels. 69 | out_channels: 1 # Number of output channels. 70 | kernel_sizes: [5, 3] # List of kernel sizes. 71 | channels: 32 # Initial number of channels. 72 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. 73 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 74 | bias: true # Whether to use bias parameter in conv layer. 75 | nonlinear_activation: LeakyReLU # Nonlinear activation. 76 | nonlinear_activation_params: # Nonlinear activation paramters. 77 | negative_slope: 0.1 78 | use_weight_norm: true # Whether to apply weight normalization. 79 | use_spectral_norm: false # Whether to apply spectral normalization. 80 | 81 | ########################################################### 82 | # METRIC LOSS SETTING # 83 | ########################################################### 84 | use_mel_loss: true # Whether to use Mel-spectrogram loss. 85 | mel_loss_params: 86 | fs: *sampling_rate 87 | fft_sizes: [2048] 88 | hop_sizes: [300] 89 | win_lengths: [2048] 90 | window: hann_window 91 | num_mels: 80 92 | fmin: 0 93 | fmax: 24000 94 | log_base: null 95 | 96 | use_stft_loss: false # Whether to use multi-resolution STFT loss. 97 | stft_loss_params: 98 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. 99 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss 100 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. 101 | window: hann_window # Window function for STFT-based loss 102 | 103 | use_shape_loss: false # Whether to use waveform shape loss. 104 | shape_loss_params: 105 | winlen: [320] 106 | 107 | ########################################################### 108 | # ADV LOSS SETTING # 109 | ########################################################### 110 | generator_adv_loss_params: 111 | average_by_discriminators: false # Whether to average loss by #discriminators. 112 | 113 | discriminator_adv_loss_params: 114 | average_by_discriminators: false # Whether to average loss by #discriminators. 115 | 116 | use_feat_match_loss: true 117 | feat_match_loss_params: 118 | average_by_discriminators: false # Whether to average loss by #discriminators. 119 | average_by_layers: false # Whether to average loss by #layers in each discriminator. 120 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation. 
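# --- annotation (not part of the original config) ---
# The lambda_* weights in the next section are, in recipes derived from the
# referenced ParallelWaveGAN codebase, presumably combined into one weighted
# sum for the generator objective, roughly:
#   L_G = lambda_mel_loss * L_mel + lambda_vq_loss * L_vq
#         + lambda_adv * L_adv + lambda_feat_match * L_feat_match
# with the adversarial and feature-matching terms becoming active only after
# start_steps.discriminator. This is a hedged reading of the config comments;
# trainer/trainerGAN.py holds the actual combination.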
121 | 122 | ########################################################### 123 | # LOSS WEIGHT SETTING # 124 | ########################################################### 125 | lambda_adv: 1.0 # Loss weight of adversarial loss. 126 | lambda_feat_match: 2.0 # Loss weight of feat match loss. 127 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss. 128 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram spectloss. 129 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss. 130 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss. 131 | 132 | ########################################################### 133 | # DATA LOADER SETTING # 134 | ########################################################### 135 | batch_size: 16 # Batch size. 136 | batch_length: 96000 # Length of each audio in batch (training w/o adv). Make sure dividable by hop_size. 137 | adv_batch_length: 9600 # Length of each audio in batch (training w/ adv). Make sure dividable by hop_size. 138 | pin_memory: true # Whether to pin memory in Pytorch DataLoader. 139 | num_workers: 2 # Number of workers in Pytorch DataLoader. 140 | 141 | ########################################################### 142 | # OPTIMIZER & SCHEDULER SETTING # 143 | ########################################################### 144 | generator_optimizer_type: Adam 145 | generator_optimizer_params: 146 | lr: 1.0e-4 147 | betas: [0.5, 0.9] 148 | weight_decay: 0.0 149 | generator_scheduler_type: StepLR 150 | generator_scheduler_params: 151 | step_size: 200000 # Generator scheduler step size. 152 | gamma: 1.0 153 | generator_grad_norm: -1 154 | discriminator_optimizer_type: Adam 155 | discriminator_optimizer_params: 156 | lr: 2.0e-4 157 | betas: [0.5, 0.9] 158 | weight_decay: 0.0 159 | discriminator_scheduler_type: MultiStepLR 160 | discriminator_scheduler_params: 161 | gamma: 0.5 162 | milestones: 163 | - 200000 164 | - 400000 165 | - 600000 166 | - 800000 167 | discriminator_grad_norm: -1 168 | 169 | ########################################################### 170 | # INTERVAL SETTING # 171 | ########################################################### 172 | start_steps: # Number of steps to start training 173 | generator: 0 174 | discriminator: 500000 175 | train_max_steps: 500000 # Number of training steps. (w/o adv) 176 | adv_train_max_steps: 1000000 # Number of training steps. (w/ adv) 177 | save_interval_steps: 100000 # Interval steps to save checkpoint. 178 | eval_interval_steps: 1000 # Interval steps to evaluate the network. 179 | log_interval_steps: 100 # Interval steps to record the training log. 180 | -------------------------------------------------------------------------------- /AudioDec/config/autoencoder/symAD_libritts_24000_hop300.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 8 | 9 | 10 | ########################################################### 11 | # DATA SETTING # 12 | ########################################################### 13 | sampling_rate: &sampling_rate 24000 14 | data: 15 | path: /mnt/home/yichiaowu/datasets/LibriTTS/LibriTTS/24000 16 | subset: 17 | train: train-clean-450 18 | valid: dev-clean-1utt 19 | test: test-clean-1utt 20 | 21 | ########################################################### 22 | # MODEL SETTING # 23 | ########################################################### 24 | model_type: symAudioDec 25 | train_mode: autoencoder 26 | paradigm: efficient 27 | 28 | generator_params: 29 | input_channels: 1 30 | output_channels: 1 31 | encode_channels: 32 32 | decode_channels: 32 33 | code_dim: 64 34 | codebook_num: 8 35 | codebook_size: 1024 36 | bias: true 37 | enc_ratios: [2, 4, 8, 16] 38 | dec_ratios: [16, 8, 4, 2] 39 | enc_strides: [3, 4, 5, 5] 40 | dec_strides: [5, 5, 4, 3] 41 | mode: causal 42 | codec: audiodec 43 | projector: conv1d 44 | quantier: residual_vq 45 | 46 | discriminator_params: 47 | scales: 3 # Number of multi-scale discriminator. 48 | scale_downsample_pooling: AvgPool1d # Pooling operation for scale discriminator. 49 | scale_downsample_pooling_params: 50 | kernel_size: 4 # Pooling kernel size. 51 | stride: 2 # Pooling stride. 52 | padding: 2 # Padding size. 53 | scale_discriminator_params: 54 | in_channels: 1 # Number of input channels. 55 | out_channels: 1 # Number of output channels. 56 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. 57 | channels: 128 # Initial number of channels. 58 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 59 | max_groups: 16 # Maximum number of groups in downsampling conv layers. 60 | bias: true 61 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. 62 | nonlinear_activation: LeakyReLU # Nonlinear activation. 63 | nonlinear_activation_params: 64 | negative_slope: 0.1 65 | follow_official_norm: true # Whether to follow the official norm setting. 66 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. 67 | period_discriminator_params: 68 | in_channels: 1 # Number of input channels. 69 | out_channels: 1 # Number of output channels. 70 | kernel_sizes: [5, 3] # List of kernel sizes. 71 | channels: 32 # Initial number of channels. 72 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. 73 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 74 | bias: true # Whether to use bias parameter in conv layer. 75 | nonlinear_activation: LeakyReLU # Nonlinear activation. 76 | nonlinear_activation_params: # Nonlinear activation paramters. 77 | negative_slope: 0.1 78 | use_weight_norm: true # Whether to apply weight normalization. 79 | use_spectral_norm: false # Whether to apply spectral normalization. 80 | 81 | ########################################################### 82 | # METRIC LOSS SETTING # 83 | ########################################################### 84 | use_mel_loss: true # Whether to use Mel-spectrogram loss. 85 | mel_loss_params: 86 | fs: *sampling_rate 87 | fft_sizes: [2048] 88 | hop_sizes: [300] 89 | win_lengths: [2048] 90 | window: hann_window 91 | num_mels: 80 92 | fmin: 0 93 | fmax: 12000 94 | log_base: null 95 | 96 | use_stft_loss: false # Whether to use multi-resolution STFT loss. 97 | stft_loss_params: 98 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. 
99 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss 100 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. 101 | window: hann_window # Window function for STFT-based loss 102 | 103 | use_shape_loss: false # Whether to use waveform shape loss. 104 | shape_loss_params: 105 | winlen: [300] 106 | 107 | ########################################################### 108 | # ADV LOSS SETTING # 109 | ########################################################### 110 | generator_adv_loss_params: 111 | average_by_discriminators: false # Whether to average loss by #discriminators. 112 | 113 | discriminator_adv_loss_params: 114 | average_by_discriminators: false # Whether to average loss by #discriminators. 115 | 116 | use_feat_match_loss: true 117 | feat_match_loss_params: 118 | average_by_discriminators: false # Whether to average loss by #discriminators. 119 | average_by_layers: false # Whether to average loss by #layers in each discriminator. 120 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation. 121 | 122 | ########################################################### 123 | # LOSS WEIGHT SETTING # 124 | ########################################################### 125 | lambda_adv: 1.0 # Loss weight of adversarial loss. 126 | lambda_feat_match: 2.0 # Loss weight of feat match loss. 127 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss. 128 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram spectloss. 129 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss. 130 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss. 131 | 132 | ########################################################### 133 | # DATA LOADER SETTING # 134 | ########################################################### 135 | batch_size: 16 # Batch size. 136 | batch_length: 9600 # Length of each audio in batch (training w/o adv). Make sure dividable by hop_size. 137 | adv_batch_length: 9600 # Length of each audio in batch (training w/ adv). Make sure dividable by hop_size. 138 | pin_memory: true # Whether to pin memory in Pytorch DataLoader. 139 | num_workers: 2 # Number of workers in Pytorch DataLoader. 140 | 141 | ########################################################### 142 | # OPTIMIZER & SCHEDULER SETTING # 143 | ########################################################### 144 | generator_optimizer_type: Adam 145 | generator_optimizer_params: 146 | lr: 1.0e-4 147 | betas: [0.5, 0.9] 148 | weight_decay: 0.0 149 | generator_scheduler_type: StepLR 150 | generator_scheduler_params: 151 | step_size: 200000 # Generator scheduler step size. 152 | gamma: 1.0 153 | generator_grad_norm: -1 154 | discriminator_optimizer_type: Adam 155 | discriminator_optimizer_params: 156 | lr: 2.0e-4 157 | betas: [0.5, 0.9] 158 | weight_decay: 0.0 159 | discriminator_scheduler_type: MultiStepLR 160 | discriminator_scheduler_params: 161 | gamma: 0.5 162 | milestones: 163 | - 200000 164 | - 400000 165 | - 600000 166 | - 800000 167 | discriminator_grad_norm: -1 168 | 169 | ########################################################### 170 | # INTERVAL SETTING # 171 | ########################################################### 172 | start_steps: # Number of steps to start training 173 | generator: 0 174 | discriminator: 500000 175 | train_max_steps: 500000 # Number of training steps. (w/o adv) 176 | adv_train_max_steps: 1000000 # Number of training steps. (w/ adv) 177 | save_interval_steps: 100000 # Interval steps to save checkpoint. 
178 | eval_interval_steps: 1000 # Interval steps to evaluate the network. 179 | log_interval_steps: 100 # Interval steps to record the training log. 180 | -------------------------------------------------------------------------------- /AudioDec/config/autoencoder/symAD_vctk_48000_hop300.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 8 | 9 | 10 | ########################################################### 11 | # DATA SETTING # 12 | ########################################################### 13 | sampling_rate: &sampling_rate 48000 14 | data: 15 | path: /mnt/home/yichiaowu/datasets/vctk_noisy/48000 16 | subset: 17 | train: clean_trainset_84spk_wav 18 | valid: clean_validset_84spk_wav 19 | test: clean_testset_wav 20 | 21 | ########################################################### 22 | # MODEL SETTING # 23 | ########################################################### 24 | model_type: symAudioDec 25 | train_mode: autoencoder 26 | paradigm: efficient 27 | 28 | generator_params: 29 | input_channels: 1 30 | output_channels: 1 31 | encode_channels: 32 32 | decode_channels: 32 33 | code_dim: 64 34 | codebook_num: 8 35 | codebook_size: 1024 36 | bias: true 37 | enc_ratios: [2, 4, 8, 16] 38 | dec_ratios: [16, 8, 4, 2] 39 | enc_strides: [3, 4, 5, 5] 40 | dec_strides: [5, 5, 4, 3] 41 | mode: causal 42 | codec: audiodec 43 | projector: conv1d 44 | quantier: residual_vq 45 | 46 | discriminator_params: 47 | scales: 3 # Number of multi-scale discriminator. 48 | scale_downsample_pooling: AvgPool1d # Pooling operation for scale discriminator. 49 | scale_downsample_pooling_params: 50 | kernel_size: 4 # Pooling kernel size. 51 | stride: 2 # Pooling stride. 52 | padding: 2 # Padding size. 53 | scale_discriminator_params: 54 | in_channels: 1 # Number of input channels. 55 | out_channels: 1 # Number of output channels. 56 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. 57 | channels: 128 # Initial number of channels. 58 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 59 | max_groups: 16 # Maximum number of groups in downsampling conv layers. 60 | bias: true 61 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. 62 | nonlinear_activation: LeakyReLU # Nonlinear activation. 63 | nonlinear_activation_params: 64 | negative_slope: 0.1 65 | follow_official_norm: true # Whether to follow the official norm setting. 66 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. 67 | period_discriminator_params: 68 | in_channels: 1 # Number of input channels. 69 | out_channels: 1 # Number of output channels. 70 | kernel_sizes: [5, 3] # List of kernel sizes. 71 | channels: 32 # Initial number of channels. 72 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. 73 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 74 | bias: true # Whether to use bias parameter in conv layer. 75 | nonlinear_activation: LeakyReLU # Nonlinear activation. 76 | nonlinear_activation_params: # Nonlinear activation paramters. 77 | negative_slope: 0.1 78 | use_weight_norm: true # Whether to apply weight normalization. 79 | use_spectral_norm: false # Whether to apply spectral normalization. 
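# --- annotation (not part of the original config) ---
# Worked numbers implied by the settings above (an illustration, not original
# documentation): the encoder strides [3, 4, 5, 5] multiply to 300 samples per
# code frame, which is the "hop300" in this config's name and matches the
# mel-loss hop_sizes: [300] below. At 48 kHz that is 48000 / 300 = 160 frames/s;
# with 8 residual codebooks of size 1024 (10 bits each), the nominal bitrate is
# about 160 * 8 * 10 = 12.8 kbps, assuming one index per codebook per frame.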
80 | 81 | ########################################################### 82 | # METRIC LOSS SETTING # 83 | ########################################################### 84 | use_mel_loss: true # Whether to use Mel-spectrogram loss. 85 | mel_loss_params: 86 | fs: *sampling_rate 87 | fft_sizes: [2048] 88 | hop_sizes: [300] 89 | win_lengths: [2048] 90 | window: hann_window 91 | num_mels: 80 92 | fmin: 0 93 | fmax: 24000 94 | log_base: null 95 | 96 | use_stft_loss: false # Whether to use multi-resolution STFT loss. 97 | stft_loss_params: 98 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. 99 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss 100 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. 101 | window: hann_window # Window function for STFT-based loss 102 | 103 | use_shape_loss: false # Whether to use waveform shape loss. 104 | shape_loss_params: 105 | winlen: [300] 106 | 107 | ########################################################### 108 | # ADV LOSS SETTING # 109 | ########################################################### 110 | generator_adv_loss_params: 111 | average_by_discriminators: false # Whether to average loss by #discriminators. 112 | 113 | discriminator_adv_loss_params: 114 | average_by_discriminators: false # Whether to average loss by #discriminators. 115 | 116 | use_feat_match_loss: true 117 | feat_match_loss_params: 118 | average_by_discriminators: false # Whether to average loss by #discriminators. 119 | average_by_layers: false # Whether to average loss by #layers in each discriminator. 120 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation. 121 | 122 | ########################################################### 123 | # LOSS WEIGHT SETTING # 124 | ########################################################### 125 | lambda_adv: 1.0 # Loss weight of adversarial loss. 126 | lambda_feat_match: 2.0 # Loss weight of feat match loss. 127 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss. 128 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram spectloss. 129 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss. 130 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss. 131 | 132 | ########################################################### 133 | # DATA LOADER SETTING # 134 | ########################################################### 135 | batch_size: 16 # Batch size. 136 | batch_length: 9600 # Length of each audio in batch (training w/o adv). Make sure dividable by hop_size. 137 | adv_batch_length: 9600 # Length of each audio in batch (training w/ adv). Make sure dividable by hop_size. 138 | pin_memory: true # Whether to pin memory in Pytorch DataLoader. 139 | num_workers: 2 # Number of workers in Pytorch DataLoader. 140 | 141 | ########################################################### 142 | # OPTIMIZER & SCHEDULER SETTING # 143 | ########################################################### 144 | generator_optimizer_type: Adam 145 | generator_optimizer_params: 146 | lr: 1.0e-4 147 | betas: [0.5, 0.9] 148 | weight_decay: 0.0 149 | generator_scheduler_type: StepLR 150 | generator_scheduler_params: 151 | step_size: 200000 # Generator scheduler step size. 
152 | gamma: 1.0 153 | generator_grad_norm: -1 154 | discriminator_optimizer_type: Adam 155 | discriminator_optimizer_params: 156 | lr: 2.0e-4 157 | betas: [0.5, 0.9] 158 | weight_decay: 0.0 159 | discriminator_scheduler_type: MultiStepLR 160 | discriminator_scheduler_params: 161 | gamma: 0.5 162 | milestones: 163 | - 200000 164 | - 400000 165 | - 600000 166 | - 800000 167 | discriminator_grad_norm: -1 168 | 169 | ########################################################### 170 | # INTERVAL SETTING # 171 | ########################################################### 172 | start_steps: # Number of steps to start training 173 | generator: 0 174 | discriminator: 200000 175 | train_max_steps: 200000 # Number of training steps. (w/o adv) 176 | adv_train_max_steps: 700000 # Number of training steps. (w/ adv) 177 | save_interval_steps: 100000 # Interval steps to save checkpoint. 178 | eval_interval_steps: 1000 # Interval steps to evaluate the network. 179 | log_interval_steps: 100 # Interval steps to record the training log. 180 | -------------------------------------------------------------------------------- /AudioDec/config/autoencoder/symADuniv_vctk_48000_hop300.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 8 | # Reference (https://github.com/chomeyama/SiFiGAN) 9 | 10 | 11 | ########################################################### 12 | # DATA SETTING # 13 | ########################################################### 14 | sampling_rate: &sampling_rate 48000 15 | data: 16 | path: /mnt/home/yichiaowu/datasets/vctk_noisy/48000 17 | subset: 18 | train: clean_trainset_84spk_wav 19 | valid: clean_validset_84spk_wav 20 | test: clean_testset_wav 21 | 22 | ########################################################### 23 | # MODEL SETTING # 24 | ########################################################### 25 | model_type: symAudioDecUniv 26 | train_mode: autoencoder 27 | paradigm: efficient 28 | 29 | generator_params: 30 | input_channels: 1 31 | output_channels: 1 32 | encode_channels: 32 33 | decode_channels: 32 34 | code_dim: 64 35 | codebook_num: 8 36 | codebook_size: 1024 37 | bias: true 38 | enc_ratios: [2, 4, 8, 16] 39 | dec_ratios: [16, 8, 4, 2] 40 | enc_strides: [3, 4, 5, 5] 41 | dec_strides: [5, 5, 4, 3] 42 | mode: causal 43 | codec: audiodec 44 | projector: conv1d 45 | quantier: residual_vq 46 | 47 | discriminator_params: 48 | fft_sizes: [1024, 2048, 512] # FFT sizes for each spectral discriminator. 49 | hop_sizes: [120, 240, 50] # Hop sizes for each spectral discriminator. 50 | win_lengths: [600, 1200, 240] # Window lengths for each spectral discriminator. 51 | window: "hann_window" # Name of window function. 52 | spectral_discriminator_params: # Params for UnivNet spectral discriminator. 53 | channels: 32 # Number of channels for conv layer. 54 | kernel_sizes: # List of stride sizes in down-sampling CNNs. 55 | - [3, 9] 56 | - [3, 9] 57 | - [3, 9] 58 | - [3, 9] 59 | - [3, 3] 60 | - [3, 3] 61 | strides: # List of kernel sizes in down-sampling CNNs. 62 | - [1, 1] 63 | - [1, 2] 64 | - [1, 2] 65 | - [1, 2] 66 | - [1, 1] 67 | - [1, 1] 68 | bias: true # Whether to add bias parameter in convolution layers. 69 | nonlinear_activation: "LeakyReLU" # Nonlinear activation. 
70 | nonlinear_activation_params: # Nonlinear activation paramters. 71 | negative_slope: 0.2 72 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. 73 | period_discriminator_params: 74 | in_channels: 1 # Number of input channels. 75 | out_channels: 1 # Number of output channels. 76 | kernel_sizes: [5, 3] # List of kernel sizes. 77 | channels: 32 # Initial number of channels. 78 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. 79 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 80 | bias: true # Whether to use bias parameter in conv layer." 81 | nonlinear_activation: "LeakyReLU" # Nonlinear activation. 82 | nonlinear_activation_params: # Nonlinear activation paramters. 83 | negative_slope: 0.1 84 | use_weight_norm: true # Whether to apply weight normalization. 85 | use_spectral_norm: false # Whether to apply spectral normalization. 86 | 87 | ########################################################### 88 | # METRIC LOSS SETTING # 89 | ########################################################### 90 | use_mel_loss: true # Whether to use Mel-spectrogram loss. 91 | mel_loss_params: 92 | fs: *sampling_rate 93 | fft_sizes: [2048] 94 | hop_sizes: [300] 95 | win_lengths: [2048] 96 | window: "hann_window" 97 | num_mels: 80 98 | fmin: 0 99 | fmax: 24000 100 | log_base: null 101 | 102 | use_stft_loss: false # Whether to use multi-resolution STFT loss. 103 | stft_loss_params: 104 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. 105 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss 106 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. 107 | window: "hann_window" # Window function for STFT-based loss 108 | 109 | use_shape_loss: false # Whether to use waveform shape loss. 110 | shape_loss_params: 111 | winlen: [300] 112 | 113 | ########################################################### 114 | # ADV LOSS SETTING # 115 | ########################################################### 116 | generator_adv_loss_params: 117 | average_by_discriminators: false # Whether to average loss by #discriminators. 118 | 119 | discriminator_adv_loss_params: 120 | average_by_discriminators: false # Whether to average loss by #discriminators. 121 | 122 | use_feat_match_loss: true 123 | feat_match_loss_params: 124 | average_by_discriminators: false # Whether to average loss by #discriminators. 125 | average_by_layers: false # Whether to average loss by #layers in each discriminator. 126 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation. 127 | 128 | ########################################################### 129 | # LOSS WEIGHT SETTING # 130 | ########################################################### 131 | lambda_adv: 1.0 # Loss weight of adversarial loss. 132 | lambda_feat_match: 2.0 # Loss weight of feat match loss. 133 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss. 134 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram spectloss. 135 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss. 136 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss. 137 | 138 | ########################################################### 139 | # DATA LOADER SETTING # 140 | ########################################################### 141 | batch_size: 16 # Batch size. 142 | batch_length: 9600 # Length of each audio in batch (training w/o adv). Make sure dividable by hop_size. 
143 | adv_batch_length: 9600 # Length of each audio in batch (training w/ adv). Make sure dividable by hop_size. 144 | pin_memory: true # Whether to pin memory in Pytorch DataLoader. 145 | num_workers: 2 # Number of workers in Pytorch DataLoader. 146 | 147 | ########################################################### 148 | # OPTIMIZER & SCHEDULER SETTING # 149 | ########################################################### 150 | generator_optimizer_type: Adam 151 | generator_optimizer_params: 152 | lr: 1.0e-4 153 | betas: [0.5, 0.9] 154 | weight_decay: 0.0 155 | generator_scheduler_type: StepLR 156 | generator_scheduler_params: 157 | step_size: 200000 # Generator scheduler step size. 158 | gamma: 1.0 159 | generator_grad_norm: -1 160 | discriminator_optimizer_type: Adam 161 | discriminator_optimizer_params: 162 | lr: 2.0e-4 163 | betas: [0.5, 0.9] 164 | weight_decay: 0.0 165 | discriminator_scheduler_type: MultiStepLR 166 | discriminator_scheduler_params: 167 | gamma: 0.5 168 | milestones: 169 | - 200000 170 | - 400000 171 | - 600000 172 | - 800000 173 | discriminator_grad_norm: -1 174 | 175 | ########################################################### 176 | # INTERVAL SETTING # 177 | ########################################################### 178 | start_steps: # Number of steps to start training 179 | generator: 0 180 | discriminator: 500000 181 | train_max_steps: 500000 # Number of training steps. (w/o adv) 182 | adv_train_max_steps: 1000000 # Number of training steps. (w/ adv) 183 | save_interval_steps: 100000 # Interval steps to save checkpoint. 184 | eval_interval_steps: 1000 # Interval steps to evaluate the network. 185 | log_interval_steps: 100 # Interval steps to record the training log. 186 | -------------------------------------------------------------------------------- /AudioDec/config/denoise/symAD_vctk_48000_hop300.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 8 | 9 | 10 | ########################################################### 11 | # DATA SETTING # 12 | ########################################################### 13 | sampling_rate: &sampling_rate 48000 14 | data: 15 | path: /mnt/home/yichiaowu/datasets/vctk_noisy/48000 16 | subset: 17 | clean_train: clean_trainset_84spk_wav 18 | clean_valid: clean_validset_84spk_wav 19 | clean_test: clean_testset_wav 20 | noisy_train: noisy_trainset_84spk_wav 21 | noisy_valid: noisy_validset_84spk_wav 22 | noisy_test: noisy_testset_wav 23 | 24 | ########################################################### 25 | # MODEL SETTING # 26 | ########################################################### 27 | model_type: symAudioDec 28 | train_mode: denoise 29 | initial: exp/autoencoder/symAD_vctk_48000_hop300/checkpoint-200000steps.pkl # for model initialization 30 | 31 | generator_params: 32 | input_channels: 1 33 | output_channels: 1 34 | encode_channels: 32 35 | decode_channels: 32 36 | code_dim: 64 37 | codebook_num: 8 38 | codebook_size: 1024 39 | bias: true 40 | enc_ratios: [2, 4, 8, 16] 41 | dec_ratios: [16, 8, 4, 2] 42 | enc_strides: [3, 4, 5, 5] 43 | dec_strides: [5, 5, 4, 3] 44 | mode: causal 45 | codec: audiodec 46 | projector: conv1d 47 | quantier: residual_vq 48 | 49 | discriminator_params: 50 | scales: 3 # Number of multi-scale discriminator. 51 | scale_downsample_pooling: AvgPool1d # Pooling operation for scale discriminator. 52 | scale_downsample_pooling_params: 53 | kernel_size: 4 # Pooling kernel size. 54 | stride: 2 # Pooling stride. 55 | padding: 2 # Padding size. 56 | scale_discriminator_params: 57 | in_channels: 1 # Number of input channels. 58 | out_channels: 1 # Number of output channels. 59 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. 60 | channels: 128 # Initial number of channels. 61 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 62 | max_groups: 16 # Maximum number of groups in downsampling conv layers. 63 | bias: true 64 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. 65 | nonlinear_activation: LeakyReLU # Nonlinear activation. 66 | nonlinear_activation_params: 67 | negative_slope: 0.1 68 | follow_official_norm: true # Whether to follow the official norm setting. 69 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. 70 | period_discriminator_params: 71 | in_channels: 1 # Number of input channels. 72 | out_channels: 1 # Number of output channels. 73 | kernel_sizes: [5, 3] # List of kernel sizes. 74 | channels: 32 # Initial number of channels. 75 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. 76 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 77 | bias: true # Whether to use bias parameter in conv layer. 78 | nonlinear_activation: LeakyReLU # Nonlinear activation. 79 | nonlinear_activation_params: # Nonlinear activation paramters. 80 | negative_slope: 0.1 81 | use_weight_norm: true # Whether to apply weight normalization. 82 | use_spectral_norm: false # Whether to apply spectral normalization. 83 | 84 | ########################################################### 85 | # METRIC LOSS SETTING # 86 | ########################################################### 87 | use_mel_loss: true # Whether to use Mel-spectrogram loss. 
88 | mel_loss_params: 89 | fs: *sampling_rate 90 | fft_sizes: [2048] 91 | hop_sizes: [300] 92 | win_lengths: [null] 93 | window: hann_window 94 | num_mels: 80 95 | fmin: 0 96 | fmax: 24000 97 | log_base: null 98 | 99 | use_stft_loss: false # Whether to use multi-resolution STFT loss. 100 | stft_loss_params: 101 | fft_sizes: [1024, 2048, 512] # List of FFT sizes for STFT-based loss. 102 | hop_sizes: [120, 240, 50] # List of hop sizes for STFT-based loss. 103 | win_lengths: [600, 1200, 240] # List of window lengths for STFT-based loss. 104 | window: hann_window # Window function for STFT-based loss. 105 | 106 | use_shape_loss: false # Whether to use waveform shape loss. 107 | shape_loss_params: 108 | winlen: [300] 109 | 110 | ########################################################### 111 | # ADV LOSS SETTING # 112 | ########################################################### 113 | generator_adv_loss_params: 114 | average_by_discriminators: false # Whether to average loss by #discriminators. 115 | 116 | discriminator_adv_loss_params: 117 | average_by_discriminators: false # Whether to average loss by #discriminators. 118 | 119 | use_feat_match_loss: true 120 | feat_match_loss_params: 121 | average_by_discriminators: false # Whether to average loss by #discriminators. 122 | average_by_layers: false # Whether to average loss by #layers in each discriminator. 123 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation. 124 | 125 | ########################################################### 126 | # LOSS WEIGHT SETTING # 127 | ########################################################### 128 | lambda_adv: 1.0 # Loss weight of adversarial loss. 129 | lambda_feat_match: 2.0 # Loss weight of feat match loss. 130 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss. 131 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram loss. 132 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss. 133 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss. 134 | 135 | ########################################################### 136 | # DATA LOADER SETTING # 137 | ########################################################### 138 | batch_size: 16 # Batch size. 139 | batch_length: 96000 # Length of each audio in batch. Make sure it is divisible by hop_size. 140 | pin_memory: true # Whether to pin memory in Pytorch DataLoader. 141 | num_workers: 2 # Number of workers in Pytorch DataLoader. 142 | 143 | ########################################################### 144 | # OPTIMIZER & SCHEDULER SETTING # 145 | ########################################################### 146 | generator_optimizer_type: Adam 147 | generator_optimizer_params: 148 | lr: 1.0e-4 149 | betas: [0.5, 0.9] 150 | weight_decay: 0.0 151 | generator_scheduler_type: StepLR 152 | generator_scheduler_params: 153 | step_size: 200000 # Generator scheduler step size.
154 | gamma: 1.0 155 | generator_grad_norm: -1 156 | discriminator_optimizer_type: Adam 157 | discriminator_optimizer_params: 158 | lr: 2.0e-4 159 | betas: [0.5, 0.9] 160 | weight_decay: 0.0 161 | discriminator_scheduler_type: MultiStepLR 162 | discriminator_scheduler_params: 163 | gamma: 0.5 164 | milestones: 165 | - 200000 166 | - 400000 167 | - 600000 168 | - 800000 169 | discriminator_grad_norm: -1 170 | 171 | ########################################################### 172 | # INTERVAL SETTING # 173 | ########################################################### 174 | start_steps: # Number of steps to start training 175 | generator: 0 176 | discriminator: 200000 177 | train_max_steps: 200000 # Number of training steps. 178 | save_interval_steps: 100000 # Interval steps to save checkpoint. 179 | eval_interval_steps: 1000 # Interval steps to evaluate the network. 180 | log_interval_steps: 100 # Interval steps to record the training log. 181 | -------------------------------------------------------------------------------- /AudioDec/config/statistic/symAD_libritts_24000_hop300_clean.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 8 | 9 | 10 | ########################################################### 11 | # DATA SETTING # 12 | ########################################################### 13 | sampling_rate: 24000 14 | data: 15 | path: /mnt/home/yichiaowu/datasets/LibriTTS/LibriTTS/24000 16 | subset: 17 | train: train-clean-450 18 | valid: dev-clean-1utt 19 | test: test-clean-1utt 20 | 21 | ########################################################### 22 | # STATISTIC SETTING # 23 | ########################################################### 24 | analyzer: exp/autoencoder/symAD_libritts_24000_hop300/checkpoint-500000steps.pkl 25 | stats: stats/symAD_libritts_24000_hop300_clean.npy -------------------------------------------------------------------------------- /AudioDec/config/statistic/symAD_vctk_48000_hop300_clean.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 8 | 9 | 10 | ########################################################### 11 | # DATA SETTING # 12 | ########################################################### 13 | sampling_rate: 48000 14 | data: 15 | path: /mnt/home/yichiaowu/datasets/vctk_noisy/48000 16 | subset: 17 | train: clean_trainset_84spk_wav 18 | valid: clean_validset_84spk_wav 19 | test: clean_testset_wav 20 | 21 | ########################################################### 22 | # STATISTIC SETTING # 23 | ########################################################### 24 | analyzer: exp/autoencoder/symAD_vctk_48000_hop300/checkpoint-200000steps.pkl 25 | stats: stats/symAD_vctk_48000_hop300_clean.npy 26 | -------------------------------------------------------------------------------- /AudioDec/config/statistic/symADuniv_vctk_48000_hop300_clean.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 8 | 9 | 10 | ########################################################### 11 | # DATA SETTING # 12 | ########################################################### 13 | sampling_rate: 48000 14 | data: 15 | path: /mnt/home/yichiaowu/datasets/vctk_noisy/48000 16 | subset: 17 | train: clean_trainset_84spk_wav 18 | valid: clean_validset_84spk_wav 19 | test: clean_testset_wav 20 | 21 | ########################################################### 22 | # STATISTIC SETTING # 23 | ########################################################### 24 | analyzer: exp/autoencoder/symADuniv_vctk_48000_hop300/checkpoint-500000steps.pkl 25 | stats: stats/symADuniv_vctk_48000_hop300_clean.npy 26 | -------------------------------------------------------------------------------- /AudioDec/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * # NOQA 2 | from .collater import * # NOQA 3 | from .utils import * # NOQA -------------------------------------------------------------------------------- /AudioDec/dataloader/collater.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | """Customized collater modules for Pytorch DataLoader.""" 13 | 14 | import torch 15 | import numpy as np 16 | 17 | 18 | class CollaterAudio(object): 19 | """Customized collater for loading single audio.""" 20 | 21 | def __init__( 22 | self, 23 | batch_length=9600, 24 | ): 25 | """ 26 | Args: 27 | batch_length (int): The length of audio signal batch. 
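                Example (illustrative sketch; assumes a dataset yielding
                (T, C) numpy arrays, e.g. SingleDataset in dataset.py):
                    >>> collater = CollaterAudio(batch_length=9600)
                    >>> loader = torch.utils.data.DataLoader(
                    ...     dataset, batch_size=16, collate_fn=collater)
                    >>> x = next(iter(loader))  # (B, C, 9600); B may shrink after short clips are filtered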
28 | 29 | """ 30 | self.batch_length = batch_length 31 | 32 | 33 | def __call__(self, batch): 34 | # filter out clips shorter than batch_length 35 | xs = [b for b in batch if len(b) > self.batch_length] 36 | 37 | # random cut 38 | starts, ends = self._random_segment(xs) 39 | x_batch = self._cut(xs, starts, ends) 40 | 41 | return x_batch 42 | 43 | 44 | def _random_segment(self, xs): 45 | x_lengths = [len(x) for x in xs] 46 | start_offsets = np.array( 47 | [ 48 | np.random.randint(0, xl - self.batch_length) 49 | for xl in x_lengths 50 | ] 51 | ) 52 | starts = start_offsets 53 | ends = starts + self.batch_length 54 | return starts, ends 55 | 56 | 57 | def _cut(self, xs, starts, ends): 58 | x_batch = np.array([x[start:end] for x, start, end in zip(xs, starts, ends)]) 59 | x_batch = torch.tensor(x_batch, dtype=torch.float).transpose(2, 1) # (B, C, T) 60 | return x_batch 61 | 62 | 63 | class CollaterAudioPair(CollaterAudio): 64 | """Customized collater for loading audio pairs.""" 65 | 66 | def __init__( 67 | self, 68 | batch_length=9600, 69 | ): 70 | super().__init__( 71 | batch_length=batch_length 72 | ) 73 | 74 | 75 | def __call__(self, batch): 76 | batch = [ 77 | b for b in batch if (len(b[0]) > self.batch_length) and (len(b[0]) == len(b[1])) 78 | ] 79 | assert len(batch) > 0, "No qualified audio pairs!" 80 | xs, ns = [b[0] for b in batch], [b[1] for b in batch] 81 | 82 | # random cut 83 | starts, ends = self._random_segment(xs) 84 | x_batch = self._cut(xs, starts, ends) 85 | n_batch = self._cut(ns, starts, ends) 86 | 87 | return n_batch, x_batch # (input, output) 88 | 89 | 90 | -------------------------------------------------------------------------------- /AudioDec/dataloader/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree.
9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | """PyTorch compatible dataset modules.""" 13 | 14 | import os 15 | import soundfile as sf 16 | from torch.utils.data import Dataset 17 | from dataloader.utils import find_files 18 | 19 | 20 | class SingleDataset(Dataset): 21 | def __init__( 22 | self, 23 | files, 24 | query="*.wav", 25 | load_fn=sf.read, 26 | return_utt_id=False, 27 | subset_num=-1, 28 | ): 29 | self.return_utt_id = return_utt_id 30 | self.load_fn = load_fn 31 | self.subset_num = subset_num 32 | 33 | self.filenames = self._load_list(files, query) 34 | self.utt_ids = self._load_ids(self.filenames) 35 | 36 | 37 | def __getitem__(self, idx): 38 | utt_id = self.utt_ids[idx] 39 | data = self._data(idx) 40 | 41 | if self.return_utt_id: 42 | items = utt_id, data 43 | else: 44 | items = data 45 | 46 | return items 47 | 48 | 49 | def __len__(self): 50 | return len(self.filenames) 51 | 52 | 53 | def _read_list(self, listfile): 54 | filenames = [] 55 | with open(listfile) as f: 56 | for line in f: 57 | line = line.strip() 58 | if len(line): 59 | filenames.append(line) 60 | return filenames 61 | 62 | 63 | def _load_list(self, files, query): 64 | if isinstance(files, list): 65 | filenames = files 66 | else: 67 | if os.path.isdir(files): 68 | filenames = sorted(find_files(files, query)) 69 | elif os.path.isfile(files): 70 | filenames = sorted(self._read_list(files)) 71 | else: 72 | raise ValueError(f"{files} is not a list / existing folder or file!") 73 | 74 | if self.subset_num > 0: 75 | filenames = filenames[:self.subset_num] 76 | assert len(filenames) != 0, f"File list in empty!" 77 | return filenames 78 | 79 | 80 | def _load_ids(self, filenames): 81 | utt_ids = [ 82 | os.path.splitext(os.path.basename(f))[0] for f in filenames 83 | ] 84 | return utt_ids 85 | 86 | 87 | def _data(self, idx): 88 | return self._load_data(self.filenames[idx], self.load_fn) 89 | 90 | 91 | def _load_data(self, filename, load_fn): 92 | if load_fn == sf.read: 93 | data, _ = load_fn(filename, always_2d=True) # (T, C) 94 | else: 95 | data = load_fn(filename) 96 | return data 97 | 98 | 99 | class MultiDataset(SingleDataset): 100 | def __init__( 101 | self, 102 | multi_files, 103 | queries, 104 | load_fns, 105 | return_utt_id=False, 106 | subset_num=-1, 107 | ): 108 | errmsg = f"multi_files({len(multi_files)}), queries({len(queries)}), and load_fns({len(load_fns)}) are length mismatched!" 
109 | assert len(multi_files) == len(queries) == len(load_fns), errmsg 110 | super(MultiDataset, self).__init__( 111 | files=multi_files, 112 | query=queries, 113 | load_fn=load_fns, 114 | return_utt_id=return_utt_id, 115 | subset_num=subset_num, 116 | ) 117 | self._check_length(self.filenames) 118 | 119 | 120 | def _load_list(self, multi_files, queries): 121 | multi_filenames = [] 122 | if isinstance(multi_files, list): 123 | for files, query in zip(multi_files, queries): 124 | multi_filenames.append(super()._load_list(files, query)) 125 | else: 126 | raise ValueError(f"{multi_files} should be a list!") 127 | 128 | return multi_filenames 129 | 130 | 131 | def _load_ids(self, multi_filenames): 132 | return super()._load_ids(multi_filenames[0]) 133 | 134 | 135 | def _data(self, idx): 136 | filenames = [ 137 | f[idx] for f in self.filenames 138 | ] 139 | data = [] 140 | for filename, load_fn in zip(filenames, self.load_fn): 141 | data.append(self._load_data(filename, load_fn)) 142 | return data 143 | 144 | 145 | def _check_length(self, multi_filenames): 146 | errmsg = f"Not all lists have the same number of files!" 147 | self.file_num = len(multi_filenames[0]) 148 | assert all(len(x)==self.file_num for x in multi_filenames), errmsg 149 | 150 | 151 | def __len__(self): 152 | return self.file_num 153 | 154 | -------------------------------------------------------------------------------- /AudioDec/dataloader/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | import os 13 | import fnmatch 14 | import logging 15 | import numpy as np 16 | 17 | 18 | def find_files(root_dir, query="*.wav", include_root_dir=True): 19 | """Find files recursively. 20 | Args: 21 | root_dir (str): Root root_dir to find. 22 | query (str): Query to find. 23 | include_root_dir (bool): If False, root_dir name is not included. 24 | Returns: 25 | list: List of found filenames. 
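        Example (illustrative; the path is a placeholder):
            >>> wav_files = find_files("/path/to/corpus", query="*.wav")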
26 | """ 27 | files = [] 28 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True): 29 | for filename in fnmatch.filter(filenames, query): 30 | files.append(os.path.join(root, filename)) 31 | if not include_root_dir: 32 | files = [file_.replace(root_dir + "/", "") for file_ in files] 33 | 34 | return files 35 | 36 | 37 | def load_files(data_path, query="*.wav", num_core=40): 38 | # sort all files 39 | file_list = sorted(find_files(data_path, query)) 40 | logging.info(f"The number of {os.path.basename(data_path)} files = {len(file_list)}.") 41 | # divide 42 | if num_core < len(file_list): 43 | file_lists = np.array_split(file_list, num_core) 44 | file_lists = [f_list.tolist() for f_list in file_lists] 45 | else: 46 | file_lists = [file_list] 47 | return file_lists -------------------------------------------------------------------------------- /AudioDec/figs/architecture-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/figs/architecture-1.png -------------------------------------------------------------------------------- /AudioDec/figs/latency.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/figs/latency.jpg -------------------------------------------------------------------------------- /AudioDec/figs/mos.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/figs/mos.jpg -------------------------------------------------------------------------------- /AudioDec/layers/activation_function.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | """Activation functions.""" 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | 18 | def get_activation(nonlinear_activation, nonlinear_activation_params={}): 19 | if hasattr(nn, nonlinear_activation): 20 | return getattr(nn, nonlinear_activation)(**nonlinear_activation_params) 21 | else: 22 | raise NotImplementedError(f"Activation {nonlinear_activation} is not supported!") -------------------------------------------------------------------------------- /AudioDec/layers/conv_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | """Convolution layers.""" 13 | 14 | import math 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | 20 | def int2tuple(variable, length): 21 | if isinstance(variable, int): 22 | return (variable,)*length 23 | else: 24 | assert len(variable) == length, f"The length of {variable} is not {length}!" 25 | return variable 26 | 27 | 28 | class Conv1d1x1(nn.Conv1d): 29 | """1x1 Conv1d.""" 30 | 31 | def __init__(self, in_channels, out_channels, bias=True): 32 | super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias) 33 | 34 | 35 | class NonCausalConv1d(nn.Module): 36 | """1D noncausal convloution w/ 2-sides padding.""" 37 | 38 | def __init__( 39 | self, 40 | in_channels, 41 | out_channels, 42 | kernel_size, 43 | stride=1, 44 | padding=-1, 45 | dilation=1, 46 | groups=1, 47 | bias=True): 48 | super().__init__() 49 | self.in_channels = in_channels 50 | self.out_channels = out_channels 51 | self.kernel_size = kernel_size 52 | if padding < 0: 53 | padding = (kernel_size - 1) // 2 * dilation 54 | self.dilation = dilation 55 | self.conv = nn.Conv1d( 56 | in_channels=in_channels, 57 | out_channels=out_channels, 58 | kernel_size=kernel_size, 59 | stride=stride, 60 | padding=padding, 61 | dilation=dilation, 62 | groups=groups, 63 | bias=bias, 64 | ) 65 | 66 | def forward(self, x): 67 | """ 68 | Args: 69 | x (Tensor): Float tensor variable with the shape (B, C, T). 70 | Returns: 71 | Tensor: Float tensor variable with the shape (B, C, T). 72 | """ 73 | x = self.conv(x) 74 | return x 75 | 76 | 77 | class NonCausalConvTranspose1d(nn.Module): 78 | """1D noncausal transpose convloution.""" 79 | 80 | def __init__( 81 | self, 82 | in_channels, 83 | out_channels, 84 | kernel_size, 85 | stride, 86 | padding=-1, 87 | output_padding=-1, 88 | groups=1, 89 | bias=True, 90 | ): 91 | super().__init__() 92 | if padding < 0: 93 | padding = (stride+1) // 2 94 | if output_padding < 0: 95 | output_padding = 1 if stride % 2 else 0 96 | self.deconv = nn.ConvTranspose1d( 97 | in_channels=in_channels, 98 | out_channels=out_channels, 99 | kernel_size=kernel_size, 100 | stride=stride, 101 | padding=padding, 102 | output_padding=output_padding, 103 | groups=groups, 104 | bias=bias, 105 | ) 106 | 107 | def forward(self, x): 108 | """ 109 | Args: 110 | x (Tensor): Float tensor variable with the shape (B, C, T). 111 | Returns: 112 | Tensor: Float tensor variable with the shape (B, C', T'). 
113 | """ 114 | x = self.deconv(x) 115 | return x 116 | 117 | 118 | class CausalConv1d(NonCausalConv1d): 119 | """1D causal convloution w/ 1-side padding.""" 120 | 121 | def __init__( 122 | self, 123 | in_channels, 124 | out_channels, 125 | kernel_size, 126 | stride=1, 127 | dilation=1, 128 | groups=1, 129 | bias=True, 130 | pad_buffer=None, 131 | ): 132 | super(CausalConv1d, self).__init__( 133 | in_channels=in_channels, 134 | out_channels=out_channels, 135 | kernel_size=kernel_size, 136 | stride=stride, 137 | padding=0, 138 | dilation=dilation, 139 | groups=groups, 140 | bias=bias, 141 | ) 142 | self.stride = stride 143 | self.pad_length = (kernel_size - 1) * dilation 144 | if pad_buffer is None: 145 | pad_buffer = torch.zeros(1, in_channels, self.pad_length) 146 | self.register_buffer("pad_buffer", pad_buffer) 147 | 148 | def forward(self, x): 149 | pad = nn.ConstantPad1d((self.pad_length, 0), 0.0) 150 | x = pad(x) 151 | return self.conv(x) 152 | 153 | def inference(self, x): 154 | x = torch.cat((self.pad_buffer, x), -1) 155 | self.pad_buffer = x[:, :, -self.pad_length:] 156 | return self.conv(x) 157 | 158 | def reset_buffer(self): 159 | self.pad_buffer.zero_() 160 | 161 | 162 | class CausalConvTranspose1d(NonCausalConvTranspose1d): 163 | """1D causal transpose convloution.""" 164 | 165 | def __init__( 166 | self, 167 | in_channels, 168 | out_channels, 169 | kernel_size, 170 | stride, 171 | bias=True, 172 | pad_buffer=None, 173 | ): 174 | super(CausalConvTranspose1d, self).__init__( 175 | in_channels=in_channels, 176 | out_channels=out_channels, 177 | kernel_size=kernel_size, 178 | stride=stride, 179 | padding=0, 180 | output_padding=0, 181 | bias=bias, 182 | ) 183 | self.stride = stride 184 | self.pad_length = (math.ceil(kernel_size/stride) - 1) 185 | if pad_buffer is None: 186 | pad_buffer = torch.zeros(1, in_channels, self.pad_length) 187 | self.register_buffer("pad_buffer", pad_buffer) 188 | 189 | def forward(self, x): 190 | pad = nn.ReplicationPad1d((self.pad_length, 0)) 191 | x = pad(x) 192 | return self.deconv(x)[:, :, self.stride : -self.stride] 193 | 194 | def inference(self, x): 195 | x = torch.cat((self.pad_buffer, x), -1) 196 | self.pad_buffer = x[:, :, -self.pad_length:] 197 | return self.deconv(x)[:, :, self.stride : -self.stride] 198 | 199 | def reset_buffer(self): 200 | self.pad_buffer.zero_() 201 | 202 | 203 | class NonCausalConv2d(nn.Module): 204 | """2D noncausal convloution w/ 4-sides padding.""" 205 | 206 | def __init__( 207 | self, 208 | in_channels, 209 | out_channels, 210 | kernel_size, 211 | stride=1, 212 | padding=-1, 213 | dilation=1, 214 | groups=1, 215 | bias=True): 216 | super().__init__() 217 | self.in_channels = in_channels 218 | self.out_channels = out_channels 219 | self.kernel_size = int2tuple(kernel_size, 2) 220 | self.dilation = int2tuple(dilation, 2) 221 | if isinstance(padding, int) and padding < 0: 222 | padding_0 = (self.kernel_size[0] - 1) // 2 * self.dilation[0] 223 | padding_1 = (self.kernel_size[1] - 1) // 2 * self.dilation[1] 224 | padding = (padding_0, padding_1) 225 | 226 | self.conv = nn.Conv2d( 227 | in_channels=in_channels, 228 | out_channels=out_channels, 229 | kernel_size=kernel_size, 230 | stride=stride, 231 | padding=padding, 232 | dilation=dilation, 233 | groups=groups, 234 | bias=bias, 235 | ) 236 | 237 | def forward(self, x): 238 | """ 239 | Args: 240 | x (Tensor): Float tensor variable with the shape (B, C, T). 241 | Returns: 242 | Tensor: Float tensor variable with the shape (B, C, T). 
243 | """ 244 | x = self.conv(x) 245 | return x -------------------------------------------------------------------------------- /AudioDec/layers/vq_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://github.com/lucidrains/vector-quantize-pytorch/) 11 | 12 | """Vector quantizer.""" 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | 19 | class VectorQuantize(nn.Module): 20 | """Vector quantization w/ exponential moving averages (EMA)""" 21 | 22 | def __init__( 23 | self, 24 | dim, 25 | codebook_size, 26 | decay = 0.8, 27 | commitment = 1., 28 | eps = 1e-5, 29 | n_embed = None, 30 | ): 31 | super().__init__() 32 | n_embed = self.default(n_embed, codebook_size) 33 | 34 | self.dim = dim 35 | self.n_embed = n_embed 36 | self.decay = decay 37 | self.eps = eps 38 | self.commitment = commitment 39 | 40 | embed = torch.randn(dim, n_embed) 41 | self.register_buffer('embed', embed) 42 | self.register_buffer('cluster_size', torch.zeros(n_embed)) 43 | self.register_buffer('embed_avg', embed.clone()) 44 | 45 | @property 46 | def codebook(self): 47 | return self.embed.transpose(0, 1) 48 | 49 | def exists(self,val): 50 | return val is not None 51 | 52 | def default(self, val, d): 53 | return val if self.exists(val) else d 54 | 55 | def ema_inplace(self, moving_avg, new, decay): 56 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay)) 57 | 58 | def laplace_smoothing(self, x, n_categories, eps=1e-5): 59 | return (x + eps) / (x.sum() + n_categories * eps) 60 | 61 | def forward(self, input): 62 | dtype = input.dtype 63 | flatten = input.reshape(-1, self.dim) 64 | dist = ( 65 | flatten.pow(2).sum(1, keepdim=True) 66 | - 2 * flatten @ self.embed 67 | + self.embed.pow(2).sum(0, keepdim=True) 68 | ) 69 | _, embed_ind = (-dist).max(1) 70 | embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype) 71 | embed_ind = embed_ind.view(*input.shape[:-1]) 72 | quantize = F.embedding(embed_ind, self.embed.transpose(0, 1)) 73 | 74 | if self.training: 75 | self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) 76 | embed_sum = flatten.transpose(0, 1) @ embed_onehot 77 | self.ema_inplace(self.embed_avg, embed_sum, self.decay) 78 | cluster_size = self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum() 79 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) 80 | self.embed.data.copy_(embed_normalized) 81 | 82 | loss = F.mse_loss(quantize.detach(), input) * self.commitment 83 | quantize = input + (quantize - input).detach() 84 | 85 | avg_probs = torch.mean(embed_onehot, dim=0) 86 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 87 | 88 | return quantize, loss, perplexity 89 | 90 | def forward_index(self, input): 91 | dtype = input.dtype 92 | flatten = input.reshape(-1, self.dim) 93 | dist = ( 94 | flatten.pow(2).sum(1, keepdim=True) 95 | - 2 * flatten @ self.embed 96 | + self.embed.pow(2).sum(0, keepdim=True) 97 | ) 98 | _, embed_ind = (-dist).max(1) 99 | embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype) 100 | embed_ind = embed_ind.view(*input.shape[:-1]) 101 | quantize = F.embedding(embed_ind, self.embed.transpose(0, 1)) 102 | 
quantize = input + (quantize - input).detach() 103 | 104 | return quantize, embed_ind 105 | 106 | 107 | class ResidualVQ(nn.Module): 108 | """ Residual VQ following Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf """ 109 | 110 | def __init__( 111 | self, 112 | *, 113 | num_quantizers, 114 | **kwargs 115 | ): 116 | super().__init__() 117 | self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)]) 118 | 119 | def forward(self, x): 120 | quantized_out = 0. 121 | residual = x 122 | all_losses = [] 123 | all_perplexities = [] 124 | for layer in self.layers: 125 | quantized, loss, perplexity = layer(residual) 126 | # Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33 127 | # We found that considering only the 1st layer VQ's gradient results in better performance 128 | #residual = residual - quantized.detach() # considering all layers' gradients 129 | residual = residual - quantized # considering only the first layer's gradient 130 | quantized_out = quantized_out + quantized 131 | all_losses.append(loss) 132 | all_perplexities.append(perplexity) 133 | all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities)) 134 | return quantized_out, all_losses, all_perplexities 135 | 136 | def forward_index(self, x, flatten_idx=False): 137 | quantized_out = 0. 138 | residual = x 139 | all_indices = [] 140 | for i, layer in enumerate(self.layers): 141 | quantized, indices = layer.forward_index(residual) 142 | #residual = residual - quantized.detach() 143 | residual = residual - quantized 144 | quantized_out = quantized_out + quantized 145 | if flatten_idx: 146 | indices += (self.codebook_size * i) 147 | all_indices.append(indices) 148 | all_indices = torch.stack(all_indices) 149 | return quantized_out, all_indices.squeeze(1) 150 | 151 | def initial(self): 152 | self.codebook = [] 153 | for layer in self.layers: 154 | self.codebook.append(layer.codebook) 155 | self.codebook_size = self.codebook[0].size(0) 156 | self.codebook = torch.stack(self.codebook) 157 | self.codebook = self.codebook.reshape(-1, self.codebook.size(-1)) 158 | 159 | def lookup(self, indices): 160 | quantized_out = F.embedding(indices, self.codebook) # Num x T x C 161 | return torch.sum(quantized_out, dim=0, keepdim=True) 162 | -------------------------------------------------------------------------------- /AudioDec/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .adversarial_loss import * # NOQA 2 | from .feat_match_loss import * # NOQA 3 | from .mel_loss import * # NOQA 4 | from .stft_loss import * # NOQA 5 | from .waveform_loss import * # NOQA -------------------------------------------------------------------------------- /AudioDec/losses/adversarial_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Tomoki Hayashi 5 | # MIT License (https://opensource.org/licenses/MIT) 6 | 7 | """Adversarial loss modules.""" 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | class GeneratorAdversarialLoss(torch.nn.Module): 14 | """Generator adversarial loss module.""" 15 | 16 | def __init__( 17 | self, 18 | average_by_discriminators=True, 19 | loss_type="mse", 20 | ): 21 | """Initialize GeneratorAdversarialLoss module.""" 22 | super().__init__() 23 | self.average_by_discriminators = average_by_discriminators 24 | assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
25 | if loss_type == "mse": 26 | self.criterion = self._mse_loss 27 | else: 28 | self.criterion = self._hinge_loss 29 | 30 | def forward(self, outputs): 31 | """Calculate generator adversarial loss. 32 | 33 | Args: 34 | outputs (Tensor or list): Discriminator outputs or list of 35 | discriminator outputs. 36 | 37 | Returns: 38 | Tensor: Generator adversarial loss value. 39 | 40 | """ 41 | if isinstance(outputs, (tuple, list)): 42 | adv_loss = 0.0 43 | for i, outputs_ in enumerate(outputs): 44 | if isinstance(outputs_, (tuple, list)): 45 | # NOTE(kan-bayashi): case including feature maps 46 | outputs_ = outputs_[-1] 47 | adv_loss += self.criterion(outputs_) 48 | if self.average_by_discriminators: 49 | adv_loss /= i + 1 50 | else: 51 | adv_loss = self.criterion(outputs) 52 | 53 | return adv_loss 54 | 55 | def _mse_loss(self, x): 56 | return F.mse_loss(x, x.new_ones(x.size())) 57 | 58 | def _hinge_loss(self, x): 59 | return -x.mean() 60 | 61 | 62 | class DiscriminatorAdversarialLoss(torch.nn.Module): 63 | """Discriminator adversarial loss module.""" 64 | 65 | def __init__( 66 | self, 67 | average_by_discriminators=True, 68 | loss_type="mse", 69 | ): 70 | """Initialize DiscriminatorAdversarialLoss module.""" 71 | super().__init__() 72 | self.average_by_discriminators = average_by_discriminators 73 | assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." 74 | if loss_type == "mse": 75 | self.fake_criterion = self._mse_fake_loss 76 | self.real_criterion = self._mse_real_loss 77 | else: 78 | self.fake_criterion = self._hinge_fake_loss 79 | self.real_criterion = self._hinge_real_loss 80 | 81 | def forward(self, outputs_hat, outputs): 82 | """Calculate discriminator adversarial loss. 83 | 84 | Args: 85 | outputs_hat (Tensor or list): Discriminator outputs or list of 86 | discriminator outputs calculated from generator outputs. 87 | outputs (Tensor or list): Discriminator outputs or list of 88 | discriminator outputs calculated from groundtruth. 89 | 90 | Returns: 91 | Tensor: Discriminator real loss value. 92 | Tensor: Discriminator fake loss value.
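        Example (illustrative sketch; p_hat and p are hypothetical lists of
        multi-discriminator outputs for generated and real audio):
            >>> criterion = DiscriminatorAdversarialLoss(average_by_discriminators=False)
            >>> real_loss, fake_loss = criterion(p_hat, p)
            >>> dis_loss = real_loss + fake_loss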
93 | 94 | """ 95 | if isinstance(outputs, (tuple, list)): 96 | real_loss = 0.0 97 | fake_loss = 0.0 98 | for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): 99 | if isinstance(outputs_hat_, (tuple, list)): 100 | # NOTE(kan-bayashi): case including feature maps 101 | outputs_hat_ = outputs_hat_[-1] 102 | outputs_ = outputs_[-1] 103 | real_loss += self.real_criterion(outputs_) 104 | fake_loss += self.fake_criterion(outputs_hat_) 105 | if self.average_by_discriminators: 106 | fake_loss /= i + 1 107 | real_loss /= i + 1 108 | else: 109 | real_loss = self.real_criterion(outputs) 110 | fake_loss = self.fake_criterion(outputs_hat) 111 | 112 | return real_loss, fake_loss 113 | 114 | def _mse_real_loss(self, x): 115 | return F.mse_loss(x, x.new_ones(x.size())) 116 | 117 | def _mse_fake_loss(self, x): 118 | return F.mse_loss(x, x.new_zeros(x.size())) 119 | 120 | def _hinge_real_loss(self, x): 121 | return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) 122 | 123 | def _hinge_fake_loss(self, x): 124 | return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) 125 | -------------------------------------------------------------------------------- /AudioDec/losses/feat_match_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Tomoki Hayashi 5 | # MIT License (https://opensource.org/licenses/MIT) 6 | 7 | """Feature matching loss modules.""" 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | class FeatureMatchLoss(torch.nn.Module): 14 | """Feature matching loss module.""" 15 | 16 | def __init__( 17 | self, 18 | average_by_layers=True, 19 | average_by_discriminators=True, 20 | include_final_outputs=False, 21 | ): 22 | """Initialize FeatureMatchLoss module.""" 23 | super().__init__() 24 | self.average_by_layers = average_by_layers 25 | self.average_by_discriminators = average_by_discriminators 26 | self.include_final_outputs = include_final_outputs 27 | 28 | def forward(self, feats_hat, feats): 29 | """Calculate feature matching loss. 30 | 31 | Args: 32 | feats_hat (list): List of lists of discriminator outputs 33 | calculated from generator outputs. 34 | feats (list): List of lists of discriminator outputs 35 | calculated from groundtruth. 36 | 37 | Returns: 38 | Tensor: Feature matching loss value. 39 | 40 | """ 41 | feat_match_loss = 0.0 42 | for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): 43 | feat_match_loss_ = 0.0 44 | if not self.include_final_outputs: 45 | feats_hat_ = feats_hat_[:-1] 46 | feats_ = feats_[:-1] 47 | for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): 48 | feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) 49 | if self.average_by_layers: 50 | feat_match_loss_ /= j + 1 51 | feat_match_loss += feat_match_loss_ 52 | if self.average_by_discriminators: 53 | feat_match_loss /= i + 1 54 | 55 | return feat_match_loss 56 | -------------------------------------------------------------------------------- /AudioDec/losses/mel_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree.
9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | """Mel-spectrogram loss modules.""" 13 | 14 | import librosa 15 | import torch 16 | import torch.nn.functional as F 17 | 18 | 19 | class MelSpectrogram(torch.nn.Module): 20 | """Calculate Mel-spectrogram.""" 21 | 22 | def __init__( 23 | self, 24 | fs=22050, 25 | fft_size=1024, 26 | hop_size=256, 27 | win_length=None, 28 | window="hann_window", 29 | num_mels=80, 30 | fmin=80, 31 | fmax=7600, 32 | center=True, 33 | normalized=False, 34 | onesided=True, 35 | eps=1e-10, 36 | log_base=10.0, 37 | ): 38 | """Initialize MelSpectrogram module.""" 39 | super().__init__() 40 | self.fft_size = fft_size 41 | self.hop_size = hop_size 42 | if win_length is not None: 43 | self.win_length = win_length 44 | else: 45 | self.win_length = fft_size 46 | self.center = center 47 | self.normalized = normalized 48 | self.onesided = onesided 49 | self.register_buffer("window", getattr(torch, window)(self.win_length)) 50 | self.eps = eps 51 | 52 | fmin = 0 if fmin is None else fmin 53 | fmax = fs / 2 if fmax is None else fmax 54 | melmat = librosa.filters.mel( 55 | sr=fs, 56 | n_fft=fft_size, 57 | n_mels=num_mels, 58 | fmin=fmin, 59 | fmax=fmax, 60 | ) 61 | self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) 62 | 63 | self.log_base = log_base 64 | if self.log_base is None: 65 | self.log = torch.log 66 | elif self.log_base == 2.0: 67 | self.log = torch.log2 68 | elif self.log_base == 10.0: 69 | self.log = torch.log10 70 | else: 71 | raise ValueError(f"log_base: {log_base} is not supported.") 72 | 73 | 74 | def forward(self, x): 75 | """Calculate Mel-spectrogram. 76 | 77 | Args: 78 | x (Tensor): Input waveform tensor (B, T) or (B, C, T). 79 | 80 | Returns: 81 | Tensor: Mel-spectrogram (B, #mels, #frames). 82 | 83 | """ 84 | if x.dim() == 3: 85 | # (B, C, T) -> (B*C, T) 86 | x = x.reshape(-1, x.size(2)) 87 | 88 | x_stft = torch.stft(x, self.fft_size, self.hop_size, self.win_length, self.window, return_complex=True) 89 | x_power = x_stft.real ** 2 + x_stft.imag ** 2 90 | x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps)).transpose(2, 1) # (B, D, T') -> (B, T', D) 91 | x_mel = torch.matmul(x_amp, self.melmat) 92 | x_mel = torch.clamp(x_mel, min=self.eps) 93 | 94 | return self.log(x_mel).transpose(1, 2) # (B, D, T') 95 | 96 | 97 | class MultiMelSpectrogramLoss(torch.nn.Module): 98 | """Multi resolution Mel-spectrogram loss.""" 99 | 100 | def __init__( 101 | self, 102 | fs=22050, 103 | fft_sizes=[1024, 2048, 512], 104 | hop_sizes=[120, 240, 50], 105 | win_lengths=[600, 1200, 240], 106 | window="hann_window", 107 | num_mels=80, 108 | fmin=80, 109 | fmax=7600, 110 | center=True, 111 | normalized=False, 112 | onesided=True, 113 | eps=1e-10, 114 | log_base=10.0, 115 | ): 116 | """Initialize Mel-spectrogram loss.""" 117 | super().__init__() 118 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 119 | self.mel_transfers = torch.nn.ModuleList() 120 | for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths): 121 | self.mel_transfers += [ 122 | MelSpectrogram( 123 | fs=fs, 124 | fft_size=fft_size, 125 | hop_size=hop_size, 126 | win_length=win_length, 127 | window=window, 128 | num_mels=num_mels, 129 | fmin=fmin, 130 | fmax=fmax, 131 | center=center, 132 | normalized=normalized, 133 | onesided=onesided, 134 | eps=eps, 135 | log_base=log_base, 136 | ) 137 | ] 138 | 139 | 140 | def forward(self, y_hat, y): 141 | """Calculate Mel-spectrogram loss. 
142 | 143 | Args: 144 | y_hat (Tensor): Generated single tensor (B, C, T). 145 | y (Tensor): Groundtruth single tensor (B, C, T). 146 | 147 | Returns: 148 | Tensor: Mel-spectrogram loss value. 149 | 150 | """ 151 | mel_loss = 0.0 152 | for f in self.mel_transfers: 153 | mel_loss += F.l1_loss(f(y_hat), f(y)) 154 | mel_loss /= len(self.mel_transfers) 155 | 156 | return mel_loss -------------------------------------------------------------------------------- /AudioDec/losses/stft_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | """STFT-based loss modules.""" 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | 18 | 19 | def stft(x, fft_size, hop_size, win_length, window, eps=1e-7): 20 | """Perform STFT and convert to magnitude spectrogram. 21 | 22 | Args: 23 | x (Tensor): Input signal tensor (B, T). 24 | fft_size (int): FFT size. 25 | hop_size (int): Hop size. 26 | win_length (int): Window length. 27 | window (str): Window function type. 28 | 29 | Returns: 30 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 31 | 32 | """ 33 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True) 34 | x_power = x_stft.real ** 2 + x_stft.imag ** 2 35 | return torch.sqrt(torch.clamp(x_power, min=eps)).transpose(2, 1) 36 | 37 | 38 | class SpectralConvergenceLoss(torch.nn.Module): 39 | """Spectral convergence loss module.""" 40 | 41 | def __init__(self): 42 | """Initialize spectral convergence loss module.""" 43 | super(SpectralConvergenceLoss, self).__init__() 44 | 45 | def forward(self, x_mag, y_mag): 46 | """Calculate forward propagation. 47 | 48 | Args: 49 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 50 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 51 | 52 | Returns: 53 | Tensor: Spectral convergence loss value. 54 | 55 | """ 56 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 57 | 58 | 59 | class LogSTFTMagnitudeLoss(torch.nn.Module): 60 | """Log STFT magnitude loss module.""" 61 | 62 | def __init__(self): 63 | """Initialize log STFT magnitude loss module.""" 64 | super(LogSTFTMagnitudeLoss, self).__init__() 65 | 66 | def forward(self, x_mag, y_mag): 67 | """Calculate forward propagation. 68 | 69 | Args: 70 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 71 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 72 | 73 | Returns: 74 | Tensor: Log STFT magnitude loss value.
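        Note: equivalent to the mean absolute error between log magnitudes,
        i.e. mean(|log(y_mag) - log(x_mag)|) over all time-frequency bins.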
75 | 76 | """ 77 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 78 | 79 | 80 | class STFTLoss(torch.nn.Module): 81 | """STFT loss module.""" 82 | 83 | def __init__( 84 | self, 85 | fft_size=1024, 86 | hop_size=120, 87 | win_length=600, 88 | window="hann_window", 89 | ): 90 | """Initialize STFT loss module.""" 91 | super(STFTLoss, self).__init__() 92 | self.fft_size = fft_size 93 | self.hop_size = hop_size 94 | self.win_length = win_length 95 | self.spectral_convergence_loss = SpectralConvergenceLoss() 96 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 97 | self.register_buffer("window", getattr(torch, window)(win_length)) 98 | 99 | 100 | def forward(self, x, y): 101 | """Calculate forward propagation. 102 | 103 | Args: 104 | x (Tensor): Predicted signal (B, T). 105 | y (Tensor): Groundtruth signal (B, T). 106 | 107 | Returns: 108 | Tensor: Spectral convergence loss value. 109 | Tensor: Log STFT magnitude loss value. 110 | 111 | """ 112 | x_mag = stft(x, self.fft_size, self.hop_size, self.win_length, self.window) 113 | y_mag = stft(y, self.fft_size, self.hop_size, self.win_length, self.window) 114 | sc_loss = self.spectral_convergence_loss(x_mag, y_mag) 115 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 116 | 117 | return sc_loss, mag_loss 118 | 119 | 120 | class MultiResolutionSTFTLoss(torch.nn.Module): 121 | """Multi resolution STFT loss module.""" 122 | 123 | def __init__( 124 | self, 125 | fft_sizes=[1024, 2048, 512], 126 | hop_sizes=[120, 240, 50], 127 | win_lengths=[600, 1200, 240], 128 | window="hann_window", 129 | ): 130 | """Initialize Multi resolution STFT loss module. 131 | 132 | Args: 133 | fft_sizes (list): List of FFT sizes. 134 | hop_sizes (list): List of hop sizes. 135 | win_lengths (list): List of window lengths. 136 | window (str): Window function type. 137 | 138 | """ 139 | super(MultiResolutionSTFTLoss, self).__init__() 140 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 141 | self.stft_losses = torch.nn.ModuleList() 142 | for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths): 143 | self.stft_losses += [STFTLoss(fft_size, hop_size, win_length, window)] 144 | 145 | 146 | def forward(self, x, y): 147 | """Calculate forward propagation. 148 | 149 | Args: 150 | x (Tensor): Predicted signal (B, T) or (B, #subband, T). 151 | y (Tensor): Groundtruth signal (B, T) or (B, #subband, T). 152 | 153 | Returns: 154 | Tensor: Multi resolution spectral convergence loss value. 155 | Tensor: Multi resolution log STFT magnitude loss value. 156 | 157 | """ 158 | if len(x.shape) == 3: 159 | x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T) 160 | y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T) 161 | sc_loss = 0.0 162 | mag_loss = 0.0 163 | for f in self.stft_losses: 164 | sc_l, mag_l = f(x, y) 165 | sc_loss += sc_l 166 | mag_loss += mag_l 167 | sc_loss /= len(self.stft_losses) 168 | mag_loss /= len(self.stft_losses) 169 | 170 | return sc_loss, mag_loss 171 | -------------------------------------------------------------------------------- /AudioDec/losses/waveform_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
9 | 10 | """Waveform-based loss modules.""" 11 | 12 | import torch 13 | 14 | 15 | class WaveformShapeLoss(torch.nn.Module): 16 | """Waveform shape loss.""" 17 | 18 | def __init__(self, winlen): 19 | super().__init__() 20 | self.loss = torch.nn.L1Loss() 21 | self.winlen = winlen 22 | self.maxpool = torch.nn.MaxPool1d(self.winlen) 23 | 24 | def forward(self, y_hat, y): 25 | """Calculate L1 loss. 26 | 27 | Args: 28 | y_hat (Tensor): Generated single tensor (B, 1, T). 29 | y (Tensor): Groundtruth single tensor (B, 1, T). 30 | 31 | Returns: 32 | Tensor: L1 loss value. 33 | 34 | """ 35 | ys = self.maxpool(torch.abs(y)) 36 | ys_hat = self.maxpool(torch.abs(y_hat)) 37 | loss = self.loss(ys_hat, ys) 38 | return loss 39 | 40 | 41 | class MultiWindowShapeLoss(torch.nn.Module): 42 | """Multi-window-length waveform shape loss.""" 43 | 44 | def __init__( 45 | self, 46 | winlen=[300, 200, 100], 47 | ): 48 | """Initialize multi-window shape loss module. 49 | 50 | Args: 51 | winlen (list): List of window lengths. 52 | 53 | """ 54 | super(MultiWindowShapeLoss, self).__init__() 55 | self.shape_losses = torch.nn.ModuleList() 56 | for wl in winlen: 57 | self.shape_losses += [WaveformShapeLoss(wl)] 58 | 59 | def forward(self, y_hat, y): 60 | """Calculate L1 loss. 61 | 62 | Args: 63 | y_hat (Tensor): Generated single tensor (B, 1, T). 64 | y (Tensor): Groundtruth single tensor (B, 1, T). 65 | 66 | Returns: 67 | Tensor: L1 loss value. 68 | 69 | """ 70 | loss = 0.0 71 | for f in self.shape_losses: 72 | loss += f(y_hat, y) 73 | loss /= len(self.shape_losses) 74 | 75 | return loss -------------------------------------------------------------------------------- /AudioDec/models/autoencoder/AudioDec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree.
9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | # Reference (https://github.com/jik876/hifi-gan/) 12 | 13 | """AudioDec model.""" 14 | 15 | import torch 16 | import logging 17 | 18 | from AudioDec.layers.conv_layer import CausalConv1d, CausalConvTranspose1d 19 | from AudioDec.models.autoencoder.modules.encoder import Encoder, ActivateEncoder 20 | from AudioDec.models.autoencoder.modules.decoder import Decoder, ActivateDecoder 21 | from AudioDec.models.autoencoder.modules.projector import Projector 22 | from AudioDec.models.autoencoder.modules.quantizer import Quantizer 23 | from AudioDec.models.utils import check_mode 24 | 25 | 26 | ### GENERATOR ### 27 | class Generator(torch.nn.Module): 28 | """AudioDec generator.""" 29 | 30 | def __init__( 31 | self, 32 | input_channels=1, 33 | output_channels=1, 34 | encode_channels=32, 35 | decode_channels=32, 36 | code_dim=64, 37 | codebook_num=8, 38 | codebook_size=1024, 39 | bias=True, 40 | enc_ratios=(2, 4, 8, 16), 41 | dec_ratios=(16, 8, 4, 2), 42 | enc_strides=(3, 4, 5, 5), 43 | dec_strides=(5, 5, 4, 3), 44 | mode='causal', 45 | codec='AudioDec', 46 | projector='conv1d', 47 | quantier='residual_vq', 48 | nonlinear_activation="ELU", 49 | nonlinear_activation_params={}, 50 | use_weight_norm=False, 51 | ): 52 | super().__init__() 53 | if codec == 'audiodec': 54 | encoder = Encoder 55 | decoder = Decoder 56 | elif codec == 'activate_audiodec': 57 | encoder = ActivateEncoder 58 | decoder = ActivateDecoder 59 | else: 60 | raise NotImplementedError(f"Codec ({codec}) is not supported!") 61 | self.mode = mode 62 | self.input_channels = input_channels 63 | 64 | self.encoder = encoder( 65 | input_channels=input_channels, 66 | encode_channels=encode_channels, 67 | channel_ratios=enc_ratios, 68 | strides=enc_strides, 69 | kernel_size=7, 70 | bias=bias, 71 | mode=self.mode, 72 | nonlinear_activation=nonlinear_activation, 73 | nonlinear_activation_params=nonlinear_activation_params, 74 | ) 75 | 76 | self.decoder = decoder( 77 | code_dim=code_dim, 78 | output_channels=output_channels, 79 | decode_channels=decode_channels, 80 | channel_ratios=dec_ratios, 81 | strides=dec_strides, 82 | kernel_size=7, 83 | bias=bias, 84 | mode=self.mode, 85 | nonlinear_activation=nonlinear_activation, 86 | nonlinear_activation_params=nonlinear_activation_params, 87 | ) 88 | 89 | self.projector = Projector( 90 | input_channels=self.encoder.out_channels, 91 | code_dim=code_dim, 92 | kernel_size=3, 93 | stride=1, 94 | bias=False, 95 | mode=self.mode, 96 | model=projector, 97 | ) 98 | 99 | self.quantizer = Quantizer( 100 | code_dim=code_dim, 101 | codebook_num=codebook_num, 102 | codebook_size=codebook_size, 103 | model=quantier, 104 | ) 105 | 106 | # apply weight norm & reset parameters 107 | if use_weight_norm: 108 | self.apply_weight_norm() 109 | self.reset_parameters() 110 | 111 | 112 | def forward(self, x): 113 | (batch, channel, length) = x.size() 114 | if channel != self.input_channels: 115 | x = x.reshape(-1, self.input_channels, length) # (B, C, T) -> (B', C', T) 116 | x = self.encoder(x) 117 | z = self.projector(x) 118 | zq, vqloss, perplexity = self.quantizer(z) 119 | y = self.decoder(zq) 120 | return y, zq, z, vqloss, perplexity 121 | 122 | 123 | def reset_parameters(self): 124 | """Reset parameters. 125 | 126 | This initialization follows the official implementation manner. 
127 | https://github.com/jik876/hifi-gan/blob/master/models.py 128 | 129 | """ 130 | 131 | def _reset_parameters(m): 132 | if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): 133 | m.weight.data.normal_(0.0, 0.01) 134 | logging.debug(f"Reset parameters in {m}.") 135 | 136 | self.apply(_reset_parameters) 137 | 138 | 139 | def remove_weight_norm(self): 140 | """Remove weight normalization module from all of the layers.""" 141 | 142 | def _remove_weight_norm(m): 143 | try: 144 | logging.debug(f"Weight norm is removed from {m}.") 145 | torch.nn.utils.remove_weight_norm(m) 146 | except ValueError: # this module didn't have weight norm 147 | return 148 | 149 | self.apply(_remove_weight_norm) 150 | 151 | 152 | def apply_weight_norm(self): 153 | """Apply weight normalization module to all of the layers.""" 154 | 155 | def _apply_weight_norm(m): 156 | if isinstance(m, torch.nn.Conv1d) or isinstance( 157 | m, torch.nn.ConvTranspose1d 158 | ): 159 | torch.nn.utils.weight_norm(m) 160 | logging.debug(f"Weight norm is applied to {m}.") 161 | 162 | self.apply(_apply_weight_norm) 163 | 164 | 165 | # STREAMING 166 | class StreamGenerator(Generator): 167 | """AudioDec streaming generator.""" 168 | 169 | def __init__( 170 | self, 171 | input_channels=1, 172 | output_channels=1, 173 | encode_channels=32, 174 | decode_channels=32, 175 | code_dim=64, 176 | codebook_num=8, 177 | codebook_size=1024, 178 | bias=True, 179 | enc_ratios=(2, 4, 8, 16), 180 | dec_ratios=(16, 8, 4, 2), 181 | enc_strides=(3, 4, 5, 5), 182 | dec_strides=(5, 5, 4, 3), 183 | mode='causal', 184 | codec='audiodec', 185 | projector='conv1d', 186 | quantier='residual_vq', 187 | nonlinear_activation="ELU", 188 | nonlinear_activation_params={}, 189 | use_weight_norm=False, 190 | ): 191 | super(StreamGenerator, self).__init__( 192 | input_channels=input_channels, 193 | output_channels=output_channels, 194 | encode_channels=encode_channels, 195 | decode_channels=decode_channels, 196 | code_dim=code_dim, 197 | codebook_num=codebook_num, 198 | codebook_size=codebook_size, 199 | bias=bias, 200 | enc_ratios=enc_ratios, 201 | dec_ratios=dec_ratios, 202 | enc_strides=enc_strides, 203 | dec_strides=dec_strides, 204 | mode=mode, 205 | codec=codec, 206 | projector=projector, 207 | quantier=quantier, 208 | nonlinear_activation=nonlinear_activation, 209 | nonlinear_activation_params=nonlinear_activation_params, 210 | use_weight_norm=use_weight_norm, 211 | ) 212 | check_mode(mode, "AudioDec Streamer") 213 | self.reset_buffer() 214 | 215 | 216 | def initial_encoder(self, receptive_length, device): 217 | self.quantizer.initial() 218 | z = self.encode(torch.zeros(1, self.input_channels, receptive_length).to(device)) 219 | idx = self.quantize(z) 220 | zq = self.lookup(idx) 221 | return zq 222 | 223 | 224 | def initial_decoder(self, zq): 225 | self.decode(zq) 226 | 227 | 228 | def encode(self, x): 229 | (batch, channel, length) = x.size() 230 | if channel != self.input_channels: 231 | x = x.reshape(-1, self.input_channels, length) # (B, C, T) -> (B', C', T) 232 | x = self.encoder.encode(x) 233 | z = self.projector.encode(x) 234 | return z 235 | 236 | 237 | def quantize(self, z): 238 | zq, idx = self.quantizer.encode(z) 239 | return idx 240 | 241 | 242 | def lookup(self, idx): 243 | return self.quantizer.decode(idx) 244 | 245 | 246 | def decode(self, zq): 247 | return self.decoder.decode(zq.transpose(2, 1)) 248 | 249 | 250 | def reset_buffer(self): 251 | """Reset the padding buffers of all causal layers.""" 252 | 253 | def _reset_buffer(m): 254 | if
isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d): 255 | m.reset_buffer() 256 | self.apply(_reset_buffer) 257 | -------------------------------------------------------------------------------- /AudioDec/models/autoencoder/modules/decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://ieeexplore.ieee.org/document/9625818) 11 | 12 | """Decoder modules.""" 13 | 14 | import torch 15 | import inspect 16 | 17 | from AudioDec.layers.conv_layer import NonCausalConv1d, NonCausalConvTranspose1d 18 | from AudioDec.layers.conv_layer import CausalConv1d, CausalConvTranspose1d 19 | from AudioDec.layers.activation_function import get_activation 20 | from AudioDec.models.autoencoder.modules.residual_unit import NonCausalResidualUnit 21 | from AudioDec.models.autoencoder.modules.residual_unit import CausalResidualUnit 22 | from AudioDec.models.utils import check_mode 23 | 24 | 25 | class DecoderBlock(torch.nn.Module): 26 | """ Decoder block (upsampling) """ 27 | 28 | def __init__( 29 | self, 30 | in_channels, 31 | out_channels, 32 | stride, 33 | dilations=(1, 3, 9), 34 | bias=True, 35 | mode='causal', 36 | nonlinear_activation="ELU", 37 | nonlinear_activation_params={}, 38 | ): 39 | super().__init__() 40 | self.mode = mode 41 | if self.mode == 'noncausal': 42 | ResidualUnit = NonCausalResidualUnit 43 | ConvTranspose1d = NonCausalConvTranspose1d 44 | elif self.mode == 'causal': 45 | ResidualUnit = CausalResidualUnit 46 | ConvTranspose1d = CausalConvTranspose1d 47 | else: 48 | raise NotImplementedError(f"Mode ({self.mode}) is not supported!") 49 | 50 | self.conv = ConvTranspose1d( 51 | in_channels=in_channels, 52 | out_channels=out_channels, 53 | kernel_size=(2 * stride), 54 | stride=stride, 55 | bias=bias, 56 | ) 57 | 58 | self.res_units = torch.nn.ModuleList() 59 | for idx, dilation in enumerate(dilations): 60 | self.res_units += [ 61 | ResidualUnit( 62 | out_channels, 63 | out_channels, 64 | dilation=dilation, 65 | nonlinear_activation=nonlinear_activation, 66 | nonlinear_activation_params=nonlinear_activation_params, 67 | )] 68 | self.num_res = len(self.res_units) 69 | 70 | def forward(self, x): 71 | x = self.conv(x) 72 | for idx in range(self.num_res): 73 | x = self.res_units[idx](x) 74 | return x 75 | 76 | def inference(self, x): 77 | check_mode(self.mode, inspect.stack()[0][3]) 78 | x = self.conv.inference(x) 79 | for idx in range(self.num_res): 80 | x = self.res_units[idx].inference(x) 81 | return x 82 | 83 | 84 | class Decoder(torch.nn.Module): 85 | def __init__(self, 86 | code_dim, 87 | output_channels, 88 | decode_channels, 89 | channel_ratios=(16, 8, 4, 2), 90 | strides=(5, 5, 4, 3), 91 | kernel_size=7, 92 | bias=True, 93 | mode='causal', 94 | nonlinear_activation="ELU", 95 | nonlinear_activation_params={}, 96 | ): 97 | super().__init__() 98 | assert len(channel_ratios) == len(strides) 99 | self.mode = mode 100 | if self.mode == 'noncausal': 101 | Conv1d = NonCausalConv1d 102 | elif self.mode == 'causal': 103 | Conv1d = CausalConv1d 104 | else: 105 | raise NotImplementedError(f"Mode ({self.mode}) is not supported!") 106 | 107 | self.conv1 = Conv1d( 108 | in_channels=code_dim, 109 | out_channels=(decode_channels * channel_ratios[0]), 110 | 
kernel_size=kernel_size, 111 | stride=1, 112 | bias=False) 113 | 114 | self.conv_blocks = torch.nn.ModuleList() 115 | for idx, stride in enumerate(strides): 116 | in_channels = decode_channels * channel_ratios[idx] 117 | if idx < (len(channel_ratios)-1): 118 | out_channels = decode_channels * channel_ratios[idx+1] 119 | else: 120 | out_channels = decode_channels 121 | self.conv_blocks += [ 122 | DecoderBlock( 123 | in_channels, 124 | out_channels, 125 | stride, 126 | bias=bias, 127 | mode=self.mode, 128 | nonlinear_activation=nonlinear_activation, 129 | nonlinear_activation_params=nonlinear_activation_params, 130 | )] 131 | self.num_blocks = len(self.conv_blocks) 132 | 133 | self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False) 134 | 135 | def forward(self, z): 136 | x = self.conv1(z) 137 | for i in range(self.num_blocks): 138 | x = self.conv_blocks[i](x) 139 | x = self.conv2(x) 140 | return x 141 | 142 | def decode(self, z): 143 | check_mode(self.mode, inspect.stack()[0][3]) 144 | x = self.conv1.inference(z) 145 | for i in range(self.num_blocks): 146 | x = self.conv_blocks[i].inference(x) 147 | x = self.conv2.inference(x) 148 | return x 149 | 150 | 151 | class ActivateDecoder(Decoder): 152 | def __init__(self, 153 | code_dim, 154 | output_channels, 155 | decode_channels, 156 | channel_ratios=(16, 8, 4, 2), 157 | strides=(5, 5, 4, 3), 158 | kernel_size=7, 159 | bias=True, 160 | mode='causal', 161 | nonlinear_activation="ELU", 162 | nonlinear_activation_params={}, 163 | ): 164 | super().__init__( 165 | code_dim=code_dim, 166 | output_channels=output_channels, 167 | decode_channels=decode_channels, 168 | channel_ratios=channel_ratios, 169 | strides=strides, 170 | kernel_size=kernel_size, 171 | bias=bias, 172 | mode=mode, 173 | nonlinear_activation=nonlinear_activation, 174 | nonlinear_activation_params=nonlinear_activation_params, 175 | ) 176 | self.conv_blocks = torch.nn.ModuleList() 177 | for idx, stride in enumerate(strides): 178 | in_channels = decode_channels * channel_ratios[idx] 179 | if idx < (len(channel_ratios)-1): 180 | out_channels = decode_channels * channel_ratios[idx+1] 181 | else: 182 | out_channels = decode_channels 183 | # upsamping + residual 184 | self.conv_blocks += [ 185 | torch.nn.Sequential( 186 | get_activation(nonlinear_activation, nonlinear_activation_params), 187 | DecoderBlock( 188 | in_channels, 189 | out_channels, 190 | stride, 191 | bias=bias, 192 | mode=self.mode, 193 | nonlinear_activation=nonlinear_activation, 194 | nonlinear_activation_params=nonlinear_activation_params, 195 | ), 196 | ) 197 | ] 198 | self.conv_blocks += [get_activation(nonlinear_activation, nonlinear_activation_params)] # for conv2 199 | self.num_blocks = len(self.conv_blocks) 200 | # output activation 201 | self.activation_output = torch.nn.Tanh() 202 | 203 | def forward(self, z): 204 | return self.activation_output(super().forward(z)) 205 | 206 | def decode(self, z): 207 | check_mode(self.mode, inspect.stack()[0][3]) 208 | x = self.conv1.inference(z) 209 | for i in range(self.num_blocks-1): 210 | x = self.conv_blocks[i][0](x) # activation 211 | x = self.conv_blocks[i][1].inference(x) # DecoderBlock 212 | x = self.conv_blocks[-1](x) # activation 213 | x = self.conv2.inference(x) 214 | return self.activation_output(x) -------------------------------------------------------------------------------- /AudioDec/models/autoencoder/modules/encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- 
coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://ieeexplore.ieee.org/document/9625818) 11 | 12 | """Encoder modules.""" 13 | 14 | import torch 15 | import inspect 16 | 17 | from AudioDec.layers.conv_layer import NonCausalConv1d 18 | from AudioDec.layers.conv_layer import CausalConv1d 19 | from AudioDec.layers.activation_function import get_activation 20 | from AudioDec.models.autoencoder.modules.residual_unit import NonCausalResidualUnit 21 | from AudioDec.models.autoencoder.modules.residual_unit import CausalResidualUnit 22 | from AudioDec.models.utils import check_mode 23 | 24 | 25 | class EncoderBlock(torch.nn.Module): 26 | """ Encoder block (downsampling) """ 27 | 28 | def __init__( 29 | self, 30 | in_channels, 31 | out_channels, 32 | stride, 33 | dilations=(1, 3, 9), 34 | bias=True, 35 | mode='causal', 36 | nonlinear_activation="ELU", 37 | nonlinear_activation_params={}, 38 | ): 39 | super().__init__() 40 | self.mode = mode 41 | if self.mode == 'noncausal': 42 | ResidualUnit = NonCausalResidualUnit 43 | Conv1d = NonCausalConv1d 44 | elif self.mode == 'causal': 45 | ResidualUnit = CausalResidualUnit 46 | Conv1d = CausalConv1d 47 | else: 48 | raise NotImplementedError(f"Mode ({self.mode}) is not supported!") 49 | 50 | self.res_units = torch.nn.ModuleList() 51 | for dilation in dilations: 52 | self.res_units += [ 53 | ResidualUnit( 54 | in_channels, 55 | in_channels, 56 | dilation=dilation, 57 | nonlinear_activation=nonlinear_activation, 58 | nonlinear_activation_params=nonlinear_activation_params, 59 | )] 60 | self.num_res = len(self.res_units) 61 | 62 | self.conv = Conv1d( 63 | in_channels=in_channels, 64 | out_channels=out_channels, 65 | kernel_size=(2 * stride), 66 | stride=stride, 67 | bias=bias, 68 | ) 69 | 70 | def forward(self, x): 71 | for idx in range(self.num_res): 72 | x = self.res_units[idx](x) 73 | x = self.conv(x) 74 | return x 75 | 76 | def inference(self, x): 77 | check_mode(self.mode, inspect.stack()[0][3]) 78 | for idx in range(self.num_res): 79 | x = self.res_units[idx].inference(x) 80 | x = self.conv.inference(x) 81 | return x 82 | 83 | 84 | class Encoder(torch.nn.Module): 85 | def __init__(self, 86 | input_channels, 87 | encode_channels, 88 | channel_ratios=(2, 4, 8, 16), 89 | strides=(3, 4, 5, 5), 90 | kernel_size=7, 91 | bias=True, 92 | mode='causal', 93 | nonlinear_activation="ELU", 94 | nonlinear_activation_params={}, 95 | ): 96 | super().__init__() 97 | assert len(channel_ratios) == len(strides) 98 | self.mode = mode 99 | if self.mode == 'noncausal': 100 | Conv1d = NonCausalConv1d 101 | elif self.mode == 'causal': 102 | Conv1d = CausalConv1d 103 | else: 104 | raise NotImplementedError(f"Mode ({self.mode}) is not supported!") 105 | 106 | self.conv = Conv1d( 107 | in_channels=input_channels, 108 | out_channels=encode_channels, 109 | kernel_size=kernel_size, 110 | stride=1, 111 | bias=False) 112 | 113 | self.conv_blocks = torch.nn.ModuleList() 114 | in_channels = encode_channels 115 | for idx, stride in enumerate(strides): 116 | out_channels = encode_channels * channel_ratios[idx] 117 | self.conv_blocks += [ 118 | EncoderBlock( 119 | in_channels, 120 | out_channels, 121 | stride, 122 | bias=bias, 123 | mode=self.mode, 124 | nonlinear_activation=nonlinear_activation, 125 | nonlinear_activation_params=nonlinear_activation_params, 126 | )] 127 | 
in_channels = out_channels 128 | self.num_blocks = len(self.conv_blocks) 129 | self.out_channels = out_channels 130 | 131 | def forward(self, x): 132 | x = self.conv(x) 133 | for i in range(self.num_blocks): 134 | x = self.conv_blocks[i](x) 135 | return x 136 | 137 | def encode(self, x): 138 | check_mode(self.mode, inspect.stack()[0][3]) 139 | x = self.conv.inference(x) 140 | for i in range(self.num_blocks): 141 | x = self.conv_blocks[i].inference(x) 142 | return x 143 | 144 | 145 | class ActivateEncoder(Encoder): 146 | def __init__(self, 147 | input_channels, 148 | encode_channels, 149 | channel_ratios=(2, 4, 8, 16), 150 | strides=(3, 4, 5, 5), 151 | kernel_size=7, 152 | bias=True, 153 | mode='causal', 154 | nonlinear_activation="ELU", 155 | nonlinear_activation_params={}, 156 | ): 157 | super().__init__( 158 | input_channels=input_channels, 159 | encode_channels=encode_channels, 160 | channel_ratios=channel_ratios, 161 | strides=strides, 162 | kernel_size=kernel_size, 163 | bias=bias, 164 | mode=mode, 165 | nonlinear_activation=nonlinear_activation, 166 | nonlinear_activation_params=nonlinear_activation_params, 167 | ) 168 | self.activation = get_activation(nonlinear_activation, nonlinear_activation_params) 169 | 170 | 171 | def forward(self, x): 172 | return self.activation(super().forward(x)) 173 | 174 | def encode(self, x): 175 | return self.activation(super().encode(x)) -------------------------------------------------------------------------------- /AudioDec/models/autoencoder/modules/projector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | """Projector modules.""" 11 | 12 | import torch 13 | import inspect 14 | 15 | from AudioDec.layers.conv_layer import NonCausalConv1d 16 | from AudioDec.layers.conv_layer import CausalConv1d 17 | from AudioDec.models.utils import check_mode 18 | 19 | 20 | class Projector(torch.nn.Module): 21 | def __init__(self, 22 | input_channels, 23 | code_dim, 24 | kernel_size=3, 25 | stride=1, 26 | bias=False, 27 | mode='causal', 28 | model='conv1d', 29 | ): 30 | super().__init__() 31 | self.mode = mode 32 | if self.mode == 'noncausal': 33 | Conv1d = NonCausalConv1d 34 | elif self.mode == 'causal': 35 | Conv1d = CausalConv1d 36 | else: 37 | raise NotImplementedError(f"Mode ({mode}) is not supported!") 38 | 39 | if model == 'conv1d': 40 | self.project = Conv1d(input_channels, code_dim, kernel_size=kernel_size, stride=stride, bias=bias) 41 | elif model == 'conv1d_bn': 42 | self.project = torch.nn.Sequential( 43 | Conv1d(input_channels, code_dim, kernel_size=kernel_size, stride=stride, bias=bias), 44 | torch.nn.BatchNorm1d(code_dim) 45 | ) 46 | else: 47 | raise NotImplementedError(f"Model ({model}) is not supported!") 48 | 49 | def forward(self, x): 50 | return self.project(x) 51 | 52 | def encode(self, x): 53 | check_mode(self.mode, inspect.stack()[0][3]) 54 | return self.project.inference(x) 55 | 56 | 57 | -------------------------------------------------------------------------------- /AudioDec/models/autoencoder/modules/quantizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
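The encoder above is the causal mirror of the decoder: each `EncoderBlock` downsamples by its stride, so the default strides (3, 4, 5, 5) produce one latent frame per 3 × 4 × 5 × 5 = 300 waveform samples, which the `Projector` then maps down to the code dimension. A minimal shape-check sketch (assuming the repository root is on `PYTHONPATH` so the `AudioDec` imports resolve; the channel and code sizes here are illustrative, not the shipped config values):

```python
import torch
from AudioDec.models.autoencoder.modules.encoder import Encoder
from AudioDec.models.autoencoder.modules.projector import Projector

# Default causal encoder: total downsampling factor = 3 * 4 * 5 * 5 = 300.
enc = Encoder(input_channels=1, encode_channels=32,
              channel_ratios=(2, 4, 8, 16), strides=(3, 4, 5, 5))
proj = Projector(input_channels=enc.out_channels, code_dim=64)

x = torch.randn(1, 1, 3000)   # (batch, channels, samples) = 10 hops of 300 samples
z = proj(enc(x))              # latent codes, roughly (1, 64, 10)
print(z.shape)
```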
5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | import torch 11 | 12 | from AudioDec.layers.vq_module import ResidualVQ 13 | 14 | 15 | class Quantizer(torch.nn.Module): 16 | def __init__(self, 17 | code_dim, 18 | codebook_num, 19 | codebook_size, 20 | model='residual_vq', 21 | ): 22 | super().__init__() 23 | # speech 24 | if model == 'residual_vq': 25 | self.codebook = ResidualVQ(dim=code_dim, num_quantizers=codebook_num, codebook_size=codebook_size) 26 | else: 27 | raise NotImplementedError(f"Model ({model}) is not supported!") 28 | 29 | def initial(self): 30 | self.codebook.initial() 31 | 32 | def forward(self, z): 33 | zq, vqloss, perplexity = self.codebook(z.transpose(2, 1)) 34 | zq = zq.transpose(2, 1) 35 | return zq, vqloss, perplexity 36 | 37 | def inference(self, z): 38 | zq, indices = self.codebook.forward_index(z.transpose(2, 1)) 39 | zq = zq.transpose(2, 1) 40 | return zq, indices 41 | 42 | def encode(self, z): 43 | zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True) 44 | return zq, indices 45 | 46 | def decode(self, indices): 47 | z = self.codebook.lookup(indices) 48 | return z 49 | -------------------------------------------------------------------------------- /AudioDec/models/autoencoder/modules/residual_unit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://ieeexplore.ieee.org/document/9625818) 11 | 12 | """Residual Units.""" 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from AudioDec.layers.conv_layer import Conv1d1x1, NonCausalConv1d, CausalConv1d 18 | from AudioDec.layers.activation_function import get_activation 19 | 20 | class NonCausalResidualUnit(nn.Module): 21 | def __init__( 22 | self, 23 | in_channels, 24 | out_channels, 25 | kernel_size=7, 26 | dilation=1, 27 | bias=False, 28 | nonlinear_activation="ELU", 29 | nonlinear_activation_params={}, 30 | ): 31 | super().__init__() 32 | self.activation = get_activation(nonlinear_activation, nonlinear_activation_params) 33 | self.conv1 = NonCausalConv1d( 34 | in_channels=in_channels, 35 | out_channels=out_channels, 36 | kernel_size=kernel_size, 37 | stride=1, 38 | dilation=dilation, 39 | bias=bias, 40 | ) 41 | self.conv2 = Conv1d1x1(out_channels, out_channels, bias) 42 | 43 | def forward(self, x): 44 | y = self.conv1(self.activation(x)) 45 | y = self.conv2(self.activation(y)) 46 | return x + y 47 | 48 | 49 | class CausalResidualUnit(NonCausalResidualUnit): 50 | def __init__( 51 | self, 52 | in_channels, 53 | out_channels, 54 | kernel_size=7, 55 | dilation=1, 56 | bias=False, 57 | nonlinear_activation="ELU", 58 | nonlinear_activation_params={}, 59 | ): 60 | super(CausalResidualUnit, self).__init__( 61 | in_channels=in_channels, 62 | out_channels=out_channels, 63 | kernel_size=kernel_size, 64 | dilation=dilation, 65 | bias=bias, 66 | nonlinear_activation=nonlinear_activation, 67 | nonlinear_activation_params=nonlinear_activation_params, 68 | ) 69 | self.conv1 = CausalConv1d( 70 | in_channels=in_channels, 71 | out_channels=out_channels, 72 | kernel_size=kernel_size, 73 | stride=1, 74 | dilation=dilation, 75 | bias=bias, 76 | ) 
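The `Quantizer` above is a thin wrapper around the residual VQ: `forward` is the training path (straight-through quantization plus the VQ loss and perplexity), while `encode`/`decode` expose flattened codebook indices for transmission and lookup. A hedged round-trip sketch; the sizes are illustrative, and the exact index layout comes from `ResidualVQ` in `layers/vq_module.py`, which is not shown here:

```python
import torch
from AudioDec.models.autoencoder.modules.quantizer import Quantizer

q = Quantizer(code_dim=64, codebook_num=8, codebook_size=1024)

z = torch.randn(1, 64, 10)      # (batch, code_dim, frames), e.g. projector output
zq, vqloss, perplexity = q(z)   # training path: quantized codes + VQ loss
_, indices = q.encode(z)        # inference path: flattened codebook indices
z_hat = q.decode(indices)       # codebook lookup back to continuous codes
```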
77 | 78 | def inference(self, x): 79 | y = self.conv1.inference(self.activation(x)) 80 | y = self.conv2(self.activation(y)) 81 | return x + y -------------------------------------------------------------------------------- /AudioDec/models/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | 11 | """Utility modules.""" 12 | 13 | def check_mode(mode, method): 14 | stream_modes = ['causal'] 15 | assert mode in stream_modes, f"Mode {mode} does not support {method}!" 16 | -------------------------------------------------------------------------------- /AudioDec/models/vocoder/modules/multi_fusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | # Reference (https://github.com/r9y9/wavenet_vocoder) 12 | # Reference (https://github.com/jik876/hifi-gan) 13 | 14 | """Multi-fusion modules.""" 15 | 16 | import math 17 | import torch 18 | import torch.nn as nn 19 | from AudioDec.layers.conv_layer import Conv1d1x1 20 | from AudioDec.models.vocoder.modules.residual_block import HiFiGANResidualBlock 21 | 22 | 23 | class MultiReceptiveField(nn.Module): 24 | """Multi-receptive field module in HiFiGAN.""" 25 | 26 | def __init__( 27 | self, 28 | channels=512, 29 | resblock_kernel_sizes=(3, 7, 11), 30 | resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)], 31 | groups=1, 32 | bias=True, 33 | use_additional_convs=True, 34 | nonlinear_activation="LeakyReLU", 35 | nonlinear_activation_params={"negative_slope": 0.1}, 36 | ): 37 | assert len(resblock_kernel_sizes) == len(resblock_dilations) 38 | super().__init__() 39 | self.num_blocks = len(resblock_kernel_sizes) 40 | 41 | self.blocks = nn.ModuleList() 42 | for i in range(self.num_blocks): 43 | self.blocks += [ 44 | HiFiGANResidualBlock( 45 | kernel_size=resblock_kernel_sizes[i], 46 | channels=channels, 47 | dilations=resblock_dilations[i], 48 | groups=groups, 49 | bias=bias, 50 | use_additional_convs=use_additional_convs, 51 | nonlinear_activation=nonlinear_activation, 52 | nonlinear_activation_params=nonlinear_activation_params, 53 | ) 54 | ] 55 | 56 | def forward(self, c): 57 | """Calculate forward propagation. 58 | 59 | Args: 60 | c (Tensor): Input tensor (B, channels, T). 61 | 62 | Returns: 63 | Tensor: Output tensor (B, channels, T). 
64 | 65 | """ 66 | cs = 0.0 # initialize 67 | for i in range(self.num_blocks): 68 | cs += self.blocks[i](c) 69 | c = cs / self.num_blocks 70 | 71 | return c 72 | 73 | def inference(self, c): 74 | cs = 0.0 # initialize 75 | for i in range(self.num_blocks): 76 | cs += self.blocks[i].inference(c) 77 | c = cs / self.num_blocks 78 | 79 | return c 80 | 81 | 82 | class MultiGroupConv1d(HiFiGANResidualBlock): 83 | """Multi-group convolution module.""" 84 | 85 | def __init__( 86 | self, 87 | channels=512, 88 | resblock_kernel_sizes=(3), 89 | resblock_dilations=[(1, 3, 5)], 90 | groups=3, 91 | bias=True, 92 | use_additional_convs=True, 93 | nonlinear_activation="LeakyReLU", 94 | nonlinear_activation_params={"negative_slope": 0.1}, 95 | ): 96 | assert len(resblock_kernel_sizes) == len(resblock_dilations) == 1 97 | super(MultiGroupConv1d, self).__init__( 98 | kernel_size=resblock_kernel_sizes[0], 99 | channels=channels*groups, 100 | dilations=resblock_dilations[0], 101 | groups=groups, 102 | bias=bias, 103 | use_additional_convs=use_additional_convs, 104 | nonlinear_activation=nonlinear_activation, 105 | nonlinear_activation_params=nonlinear_activation_params, 106 | ) 107 | self.groups = groups 108 | self.conv_out = Conv1d1x1( 109 | in_channels=channels*groups, 110 | out_channels=channels, 111 | bias=False, 112 | ) 113 | 114 | def forward(self, x): 115 | """Calculate forward propagation. 116 | 117 | Args: 118 | x (Tensor): Input tensor (B, channels, T). 119 | 120 | Returns: 121 | Tensor: Output tensor (B, channels, T). 122 | 123 | """ 124 | x = x.repeat(1, self.groups, 1) # (B, n*C, T) 125 | for idx in range(self.num_layer): 126 | xt = self.convs1[idx](self.activation(x)) 127 | if self.use_additional_convs: 128 | xt = self.convs2[idx](self.activation(xt)) 129 | x = xt + x 130 | x = self.conv_out(x) # (B, C, T) 131 | return x 132 | 133 | def inference(self, x): 134 | x = x.repeat(1, self.groups, 1) # (B, n*C, T) 135 | for idx in range(self.num_layer): 136 | xt = self.convs1[idx].inference(self.activation(x)) 137 | if self.use_additional_convs: 138 | xt = self.convs2[idx].inference(self.activation(xt)) 139 | x = xt + x 140 | x = self.conv_out(x) # (B, C, T) 141 | return x -------------------------------------------------------------------------------- /AudioDec/models/vocoder/modules/residual_block.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | # Reference (https://github.com/r9y9/wavenet_vocoder) 12 | # Reference (https://github.com/jik876/hifi-gan) 13 | 14 | """Residual block modules.""" 15 | 16 | import math 17 | 18 | import torch 19 | import torch.nn as nn 20 | from AudioDec.layers.conv_layer import CausalConv1d, Conv1d1x1 21 | 22 | 23 | class HiFiGANResidualBlock(nn.Module): 24 | """Causal Residual block module in HiFiGAN.""" 25 | 26 | def __init__( 27 | self, 28 | kernel_size=3, 29 | channels=512, 30 | dilations=(1, 3, 5), 31 | groups=1, 32 | bias=True, 33 | use_additional_convs=True, 34 | nonlinear_activation="LeakyReLU", 35 | nonlinear_activation_params={"negative_slope": 0.1}, 36 | ): 37 | """Initialize CausalResidualBlock module. 
38 | 39 | Args: 40 | kernel_size (int): Kernel size of dilation convolution layer. 41 | channels (int): Number of channels for convolution layer. 42 | dilations (List[int]): List of dilation factors. 43 | use_additional_convs (bool): Whether to use additional convolution layers. 44 | groups (int): The group number of conv1d (default: 1) 45 | bias (bool): Whether to add bias parameter in convolution layers. 46 | nonlinear_activation (str): Activation function module name. 47 | nonlinear_activation_params (dict): Hyperparameters for activation function. 48 | 49 | """ 50 | super().__init__() 51 | self.use_additional_convs = use_additional_convs 52 | self.convs1 = nn.ModuleList() 53 | if use_additional_convs: 54 | self.convs2 = nn.ModuleList() 55 | assert kernel_size % 2 == 1, "Kernel size must be odd number." 56 | self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params) 57 | for dilation in dilations: 58 | self.convs1 += [ 59 | CausalConv1d( 60 | in_channels=channels, 61 | out_channels=channels, 62 | kernel_size=kernel_size, 63 | stride=1, 64 | dilation=dilation, 65 | groups=groups, 66 | bias=bias, 67 | ) 68 | ] 69 | if use_additional_convs: 70 | self.convs2 += [ 71 | CausalConv1d( 72 | in_channels=channels, 73 | out_channels=channels, 74 | kernel_size=kernel_size, 75 | stride=1, 76 | dilation=1, 77 | groups=groups, 78 | bias=bias, 79 | ) 80 | ] 81 | self.num_layer = len(self.convs1) 82 | 83 | def forward(self, x): 84 | """Calculate forward propagation. 85 | 86 | Args: 87 | x (Tensor): Input tensor (B, channels, T). 88 | 89 | Returns: 90 | Tensor: Output tensor (B, channels, T). 91 | 92 | """ 93 | for idx in range(self.num_layer): 94 | xt = self.convs1[idx](self.activation(x)) 95 | if self.use_additional_convs: 96 | xt = self.convs2[idx](self.activation(xt)) 97 | x = xt + x 98 | return x 99 | 100 | def inference(self, x): 101 | for idx in range(self.num_layer): 102 | xt = self.convs1[idx].inference(self.activation(x)) 103 | if self.use_additional_convs: 104 | xt = self.convs2[idx].inference(self.activation(xt)) 105 | x = xt + x 106 | return x -------------------------------------------------------------------------------- /AudioDec/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 
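For context on the vocoder modules above: `MultiReceptiveField` averages several parallel residual blocks with different kernel sizes (the HiFiGAN multi-receptive-field design), while `MultiGroupConv1d` folds the same computation into a single grouped residual stack followed by a 1×1 fusion convolution, trading some capacity for speed. A small smoke-test sketch (channel counts and kernel choices are illustrative, and the repo must be importable as `AudioDec`):

```python
import torch
from AudioDec.models.vocoder.modules.multi_fusion import MultiReceptiveField, MultiGroupConv1d

c = torch.randn(1, 512, 100)                  # (batch, channels, frames)
mrf = MultiReceptiveField(channels=512)       # three parallel blocks, kernels (3, 7, 11)
mgc = MultiGroupConv1d(channels=512, groups=3,
                       resblock_kernel_sizes=(3,), resblock_dilations=[(1, 3, 5)])
print(mrf(c).shape, mgc(c).shape)             # both preserve (1, 512, 100)
```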
26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 
-------------------------------------------------------------------------------- /AudioDec/requirements.txt: -------------------------------------------------------------------------------- 1 | soundfile 2 | sounddevice 3 | numpy 4 | torch 5 | torchaudio 6 | scikit-learn 7 | librosa 8 | argparse 9 | pyyaml 10 | tqdm 11 | tensorboardX 12 | 13 | -------------------------------------------------------------------------------- /AudioDec/slurmlogs/README.md: -------------------------------------------------------------------------------- 1 | # Folder for saving slurm logs -------------------------------------------------------------------------------- /AudioDec/stats/symAD_libritts_24000_hop300_clean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/stats/symAD_libritts_24000_hop300_clean.npy -------------------------------------------------------------------------------- /AudioDec/stats/symAD_vctk_48000_hop300_clean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/stats/symAD_vctk_48000_hop300_clean.npy -------------------------------------------------------------------------------- /AudioDec/stats/symADuniv_vctk_48000_hop300_clean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/stats/symADuniv_vctk_48000_hop300_clean.npy -------------------------------------------------------------------------------- /AudioDec/trainer/autoencoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | """Training flow of symmetric codec.""" 13 | 14 | import logging 15 | import torch 16 | from trainer.trainerGAN import TrainerVQGAN 17 | 18 | 19 | class Trainer(TrainerVQGAN): 20 | def __init__( 21 | self, 22 | steps, 23 | epochs, 24 | data_loader, 25 | model, 26 | criterion, 27 | optimizer, 28 | scheduler, 29 | config, 30 | device=torch.device("cpu"), 31 | ): 32 | super(Trainer, self).__init__( 33 | steps=steps, 34 | epochs=epochs, 35 | data_loader=data_loader, 36 | model=model, 37 | criterion=criterion, 38 | optimizer=optimizer, 39 | scheduler=scheduler, 40 | config=config, 41 | device=device, 42 | ) 43 | self.fix_encoder = False 44 | self.paradigm = config.get('paradigm', 'efficient') 45 | self.generator_start = config.get('start_steps', {}).get('generator', 0) 46 | self.discriminator_start = config.get('start_steps', {}).get('discriminator', 200000) 47 | 48 | 49 | def _train_step(self, batch): 50 | """Single step of training.""" 51 | mode = 'train' 52 | x = batch 53 | x = x.to(self.device) 54 | 55 | # check generator step 56 | if self.steps < self.generator_start: 57 | self.generator_train = False 58 | else: 59 | self.generator_train = True 60 | 61 | # check discriminator step 62 | if self.steps < self.discriminator_start: 63 | self.discriminator_train = False 64 | else: 65 | self.discriminator_train = True 66 | if (not self.fix_encoder) and (self.paradigm == 'efficient'): 67 | # fix encoder, quantizer, and codebook 68 | for parameter in self.model["generator"].encoder.parameters(): 69 | parameter.requires_grad = False 70 | for parameter in self.model["generator"].projector.parameters(): 71 | parameter.requires_grad = False 72 | for parameter in self.model["generator"].quantizer.parameters(): 73 | parameter.requires_grad = False 74 | self.fix_encoder = True 75 | logging.info("Encoder, projector, quantizer, and codebook are fixed") 76 | 77 | # check codebook updating 78 | if self.fix_encoder: 79 | self.model["generator"].quantizer.codebook.eval() 80 | 81 | ####################### 82 | # Generator # 83 | ####################### 84 | if self.generator_train: 85 | # initialize generator loss 86 | gen_loss = 0.0 87 | 88 | # main genertor operation 89 | y_, zq, z, vqloss, perplexity = self.model["generator"](x) 90 | 91 | # perplexity info 92 | self._perplexity(perplexity, mode=mode) 93 | 94 | # vq loss 95 | gen_loss += self._vq_loss(vqloss, mode=mode) 96 | 97 | # metric loss 98 | gen_loss += self._metric_loss(y_, x, mode=mode) 99 | 100 | # adversarial loss 101 | if self.discriminator_train: 102 | p_ = self.model["discriminator"](y_) 103 | if self.config["use_feat_match_loss"]: 104 | with torch.no_grad(): 105 | p = self.model["discriminator"](x) 106 | else: 107 | p = None 108 | gen_loss += self._adv_loss(p_, p, mode=mode) 109 | 110 | # update generator 111 | self._record_loss('generator_loss', gen_loss, mode=mode) 112 | self._update_generator(gen_loss) 113 | 114 | ####################### 115 | # Discriminator # 116 | ####################### 117 | if self.discriminator_train: 118 | # re-compute y_ which leads better quality 119 | with torch.no_grad(): 120 | y_, _, _, _, _ = self.model["generator"](x) 121 | 122 | p = self.model["discriminator"](x) 123 | p_ = self.model["discriminator"](y_.detach()) 124 | 125 | # discriminator loss & update discriminator 126 | self._update_discriminator(self._dis_loss(p_, p, mode=mode)) 127 | 128 | # update counts 129 | self.steps += 1 130 | self.tqdm.update(1) 131 | 
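The scheduling logic in `_train_step` above implements the "efficient" paradigm: the whole autoencoder trains with VQ and metric losses from the start, and once the discriminator joins, the encoder, projector, quantizer, and codebook are frozen so only the decoder adapts to the adversarial signal. A sketch of the config keys the trainer reads; the step values shown are the in-code defaults, and the `use_feat_match_loss` value is illustrative:

```python
config = {
    "paradigm": "efficient",      # freeze encoder/projector/quantizer once GAN training starts
    "start_steps": {
        "generator": 0,           # VQ + metric losses from the first step
        "discriminator": 200000,  # adversarial losses join after 200k steps
    },
    "use_feat_match_loss": True,  # also run real audio through the discriminator
}
```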
self._check_train_finish() 132 | 133 | 134 | @torch.no_grad() 135 | def _eval_step(self, batch): 136 | """Single step of evaluation.""" 137 | mode = 'eval' 138 | x = batch 139 | x = x.to(self.device) 140 | 141 | # initialize generator loss 142 | gen_loss = 0.0 143 | 144 | # main genertor operation 145 | y_, zq, z, vqloss, perplexity = self.model["generator"](x) 146 | 147 | # perplexity info 148 | self._perplexity(perplexity, mode=mode) 149 | 150 | # vq_loss 151 | gen_loss += self._vq_loss(vqloss, mode=mode) 152 | 153 | # metric loss 154 | gen_loss += self._metric_loss(y_, x, mode=mode) 155 | 156 | if self.discriminator_train: 157 | # adversarial loss 158 | p_ = self.model["discriminator"](y_) 159 | p = self.model["discriminator"](x) 160 | gen_loss += self._adv_loss(p_, p, mode=mode) 161 | 162 | # discriminator loss 163 | self._dis_loss(p_, p, mode=mode) 164 | 165 | # generator loss 166 | self._record_loss('generator_loss', gen_loss, mode=mode) 167 | 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /AudioDec/trainer/denoise.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | """Training flow of symmetric codec.""" 13 | 14 | import logging 15 | import torch 16 | from trainer.trainerGAN import TrainerVQGAN 17 | 18 | 19 | class Trainer(TrainerVQGAN): 20 | def __init__( 21 | self, 22 | steps, 23 | epochs, 24 | data_loader, 25 | model, 26 | criterion, 27 | optimizer, 28 | scheduler, 29 | config, 30 | device=torch.device("cpu"), 31 | ): 32 | super(Trainer, self).__init__( 33 | steps=steps, 34 | epochs=epochs, 35 | data_loader=data_loader, 36 | model=model, 37 | criterion=criterion, 38 | optimizer=optimizer, 39 | scheduler=scheduler, 40 | config=config, 41 | device=device, 42 | ) 43 | # fix quantizer 44 | for parameter in self.model["generator"].quantizer.parameters(): 45 | parameter.requires_grad = False 46 | # fix decoder 47 | for parameter in self.model["generator"].decoder.parameters(): 48 | parameter.requires_grad = False 49 | logging.info("Quantizer, codebook, and decoder are fixed") 50 | 51 | 52 | def _train_step(self, batch): 53 | """Single step of training.""" 54 | mode = 'train' 55 | x_n, x_c = batch 56 | x_n = x_n.to(self.device) 57 | x_c = x_c.to(self.device) 58 | 59 | # fix codebook 60 | self.model["generator"].quantizer.codebook.eval() 61 | 62 | # initialize generator loss 63 | gen_loss = 0.0 64 | 65 | # main genertor operation 66 | y_nc, zq, z, vqloss, perplexity = self.model["generator"](x_n) 67 | 68 | # perplexity info 69 | self._perplexity(perplexity, mode=mode) 70 | 71 | # vq loss 72 | gen_loss += self._vq_loss(vqloss, mode=mode) 73 | 74 | # metric loss 75 | gen_loss += self._metric_loss(y_nc, x_c, mode=mode) 76 | 77 | # update generator 78 | self._record_loss('generator_loss', gen_loss, mode=mode) 79 | self._update_generator(gen_loss) 80 | 81 | # update counts 82 | self.steps += 1 83 | self.tqdm.update(1) 84 | self._check_train_finish() 85 | 86 | 87 | @torch.no_grad() 88 | def _eval_step(self, batch): 89 | """Single step of evaluation.""" 90 | mode = 'eval' 91 | x_n, x_c = batch 92 | x_n = x_n.to(self.device) 93 | x_c = 
x_c.to(self.device) 94 | 95 | # initialize generator loss 96 | gen_loss = 0.0 97 | 98 | # main genertor operation 99 | y_nc, zq, z, vqloss, perplexity = self.model["generator"](x_n) 100 | 101 | # perplexity info 102 | self._perplexity(perplexity, mode=mode) 103 | 104 | # vq_loss 105 | gen_loss += self._vq_loss(vqloss, mode=mode) 106 | 107 | # metric loss 108 | gen_loss += self._metric_loss(y_nc, x_c, mode=mode) 109 | 110 | # generator loss 111 | self._record_loss('generator_loss', gen_loss, mode=mode) 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /AudioDec/trainer/vocoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | # 10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/) 11 | 12 | """Training flow of GAN-based vocoder.""" 13 | 14 | import logging 15 | import torch 16 | from trainer.trainerGAN import TrainerGAN 17 | 18 | 19 | class Trainer(TrainerGAN): 20 | def __init__( 21 | self, 22 | steps, 23 | epochs, 24 | data_loader, 25 | model, 26 | criterion, 27 | optimizer, 28 | scheduler, 29 | config, 30 | device=torch.device("cpu"), 31 | ): 32 | super(Trainer, self).__init__( 33 | steps=steps, 34 | epochs=epochs, 35 | data_loader=data_loader, 36 | model=model, 37 | criterion=criterion, 38 | optimizer=optimizer, 39 | scheduler=scheduler, 40 | config=config, 41 | device=device, 42 | ) 43 | self.fix_analyzer = False 44 | self.generator_start = config.get("generator_train_start_steps", 0) 45 | self.discriminator_start = config.get("discriminator_train_start_steps", 0) 46 | 47 | 48 | def _train_step(self, batch): 49 | """Train model one step.""" 50 | mode = 'train' 51 | x = batch 52 | x = x.to(self.device) 53 | 54 | # fix analyzer 55 | if not self.fix_analyzer: 56 | for parameter in self.model["analyzer"].parameters(): 57 | parameter.requires_grad = False 58 | self.fix_analyzer = True 59 | logging.info("Analyzer is fixed!") 60 | self.model["analyzer"].eval() 61 | 62 | ####################### 63 | # Generator # 64 | ####################### 65 | if self.steps > self.generator_start: 66 | # initialize generator loss 67 | gen_loss = 0.0 68 | 69 | # main genertor operation 70 | e = self.model["analyzer"].encoder(x) 71 | z = self.model["analyzer"].projector(e) 72 | zq, _, _ = self.model["analyzer"].quantizer(z) 73 | y_ = self.model["generator"](zq) 74 | 75 | # metric loss 76 | gen_loss += self._metric_loss(y_, x, mode=mode) 77 | 78 | # adversarial loss 79 | if self.steps > self.discriminator_start: 80 | p_ = self.model["discriminator"](y_) 81 | if self.config["use_feat_match_loss"]: 82 | with torch.no_grad(): 83 | p = self.model["discriminator"](x) 84 | else: 85 | p = None 86 | gen_loss += self._adv_loss(p_, p, mode=mode) 87 | 88 | # update generator 89 | self._record_loss('generator_loss', gen_loss, mode=mode) 90 | self._update_generator(gen_loss) 91 | 92 | ####################### 93 | # Discriminator # 94 | ####################### 95 | if self.steps > self.discriminator_start: 96 | # re-compute y_ which leads better quality 97 | with torch.no_grad(): 98 | e = self.model["analyzer"].encoder(x) 99 | z = self.model["analyzer"].projector(e) 100 | zq, _, _ = self.model["analyzer"].quantizer(z) 101 | y_ = 
self.model["generator"](zq) 102 | p = self.model["discriminator"](x) 103 | p_ = self.model["discriminator"](y_.detach()) 104 | 105 | # discriminator loss & update discriminator 106 | self._update_discriminator(self._dis_loss(p_, p, mode=mode)) 107 | 108 | # update counts 109 | self.steps += 1 110 | self.tqdm.update(1) 111 | self._check_train_finish() 112 | 113 | 114 | @torch.no_grad() 115 | def _eval_step(self, batch): 116 | """Single step of evaluation.""" 117 | mode = 'eval' 118 | x = batch 119 | x = x.to(self.device) 120 | 121 | # initialize generator loss 122 | gen_loss = 0.0 123 | 124 | # main genertor operation 125 | e = self.model["analyzer"].encoder(x) 126 | z = self.model["analyzer"].projector(e) 127 | zq, _, _ = self.model["analyzer"].quantizer(z) 128 | y_ = self.model["generator"](zq) 129 | 130 | # metric loss 131 | gen_loss += self._metric_loss(y_, x, mode=mode) 132 | 133 | # adversarial loss & feature matching loss 134 | if self.steps > self.discriminator_start: 135 | p_ = self.model["discriminator"](y_) 136 | if self.config["use_feat_match_loss"]: 137 | p = self.model["discriminator"](x) 138 | else: 139 | p = None 140 | gen_loss += self._adv_loss(p_, p, mode=mode) 141 | 142 | # discriminator loss 143 | self._dis_loss(p_, p, mode=mode) 144 | 145 | # generator loss 146 | self._record_loss('generator_loss', gen_loss, mode=mode) 147 | 148 | -------------------------------------------------------------------------------- /AudioDec/utils/audiodec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) Meta Platforms, Inc. and affiliates. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 
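To summarize the vocoder trainer above: the pretrained analyzer (autoencoder encoder, projector, and quantizer) is frozen at the first step, and only the vocoder generator and its discriminator are optimized. A condensed, hedged view of that generator path, using the same `model` dict layout as the trainer:

```python
import torch

def vocoder_generator_path(model: dict, x: torch.Tensor) -> torch.Tensor:
    """Condensed view of the generator path in the vocoder Trainer above.

    `model` follows the trainer's layout ({"analyzer", "generator", ...});
    the analyzer's parameters are frozen by the trainer, so gradients only
    flow through the generator that maps quantized codes back to audio.
    """
    e = model["analyzer"].encoder(x)            # x: (B, 1, T) waveform
    z = model["analyzer"].projector(e)          # latent features
    zq, _, _ = model["analyzer"].quantizer(z)   # quantized codes
    return model["generator"](zq)               # reconstructed waveform
```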
9 | 10 | import os 11 | import torch 12 | import math 13 | from typing import Union 14 | from AudioDec.models.autoencoder.AudioDec import StreamGenerator as generator_audiodec 15 | from AudioDec.models.vocoder.HiFiGAN import StreamGenerator as generator_hifigan 16 | from AudioDec.bin.stream import AudioCodec, AudioCodecStreamer 17 | 18 | 19 | class AudioDec(AudioCodec): 20 | def __init__( 21 | self, 22 | tx_device: str = "cpu", 23 | rx_device: str = "cpu", 24 | receptive_length: int = 8192, # actual number is 7209 for symAD_vctk_48000_hop300 25 | ): 26 | super(AudioDec, self).__init__( 27 | tx_device = tx_device, 28 | rx_device = rx_device, 29 | receptive_length = receptive_length, 30 | ) 31 | 32 | def _load_encoder(self, checkpoint): 33 | # load config 34 | config = self._load_config(checkpoint) 35 | # load model 36 | if config['model_type'] in ['symAudioDec', 'symAudioDecUniv']: 37 | encoder = generator_audiodec 38 | else: 39 | raise NotImplementedError(f"Encoder type {config['model_type']} is not supported!") 40 | encoder = encoder(**config['generator_params']) 41 | encoder.load_state_dict(torch.load(checkpoint, map_location='cpu')['model']['generator']) 42 | return encoder 43 | 44 | def _load_decoder(self, checkpoint): 45 | # load config 46 | config = self._load_config(checkpoint) 47 | # load model 48 | if config['model_type'] in ['symAudioDec', 'symAudioDecUniv']: 49 | decoder = generator_audiodec 50 | elif config['model_type'] in ['HiFiGAN', 'UnivNet']: 51 | decoder = generator_hifigan 52 | else: 53 | raise NotImplementedError(f"Decoder {config['model_type']} is not supported!") 54 | decoder = decoder(**config['generator_params']) 55 | decoder.load_state_dict(torch.load(checkpoint, map_location='cpu')['model']['generator']) 56 | return decoder 57 | 58 | def get_hop_length(self, checkpoint): 59 | # load receiver model(s) 60 | assert os.path.exists(checkpoint), f'{checkpoint} does not exist!' 
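As the remainder of `get_hop_length` just below shows, the codec hop length is simply the product of the encoder strides stored in the checkpoint config. A tiny self-contained check of that arithmetic, assuming the `*_hop300` models use the default encoder strides shown earlier:

```python
import math

enc_strides = (3, 4, 5, 5)             # assumed enc_strides for the *_hop300 configs
assert math.prod(enc_strides) == 300   # one codec frame per 300 waveform samples
```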
61 | config = self._load_config(checkpoint) 62 | return math.prod(config['generator_params']['enc_strides']) 63 | 64 | 65 | class AudioDecStreamer(AudioCodecStreamer): 66 | def __init__( 67 | self, 68 | input_device: Union[str, int], 69 | output_device: Union[str, int], 70 | input_channels: int = 1, 71 | output_channels: int = 1, 72 | frame_size: int = 512, 73 | sample_rate: int = 48000, 74 | gain: int = 1.0, 75 | max_latency: float = 0.1, 76 | # encoder params 77 | tx_encoder = None, 78 | tx_device: str = "cpu", 79 | # decoder params 80 | rx_encoder = None, 81 | decoder = None, 82 | rx_device: str = "cpu", 83 | ): 84 | super(AudioDecStreamer, self).__init__( 85 | input_device = input_device, 86 | output_device = output_device, 87 | input_channels = input_channels, 88 | output_channels = output_channels, 89 | frame_size = frame_size, 90 | sample_rate = sample_rate, 91 | gain = gain, 92 | max_latency = max_latency, 93 | tx_encoder = tx_encoder, 94 | tx_device = tx_device, 95 | rx_encoder = rx_encoder, 96 | decoder = decoder, 97 | rx_device = rx_device, 98 | ) 99 | 100 | def _encode(self, x): 101 | x = self.tx_encoder.encode(x) 102 | return self.tx_encoder.quantize(x) 103 | 104 | def _decode(self, x): 105 | x = self.rx_encoder.lookup(x) 106 | return self.decoder.decode(x) 107 | 108 | 109 | def assign_model(model): 110 | if model == 'libritts_v1': 111 | sample_rate = 24000 112 | tx_steps = 500000 113 | rx_steps = 500000 114 | encoder_checkpoint = os.path.join('AudioDec/exp', 'autoencoder', 'symAD_libritts_24000_hop300', f"checkpoint-{tx_steps}steps.pkl") 115 | decoder_checkpoint = os.path.join('AudioDec/exp', 'vocoder', 'AudioDec_v1_symAD_libritts_24000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl") 116 | elif model == 'libritts_sym': 117 | sample_rate = 24000 118 | tx_steps = 500000 119 | rx_steps = 1000000 120 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_libritts_24000_hop300', f"checkpoint-{tx_steps}steps.pkl") 121 | decoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_libritts_24000_hop300', f"checkpoint-{rx_steps}steps.pkl") 122 | elif model == 'vctk_v1': 123 | sample_rate = 48000 124 | tx_steps = 200000 125 | rx_steps = 500000 126 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl") 127 | decoder_checkpoint = os.path.join('exp', 'vocoder', 'AudioDec_v1_symAD_vctk_48000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl") 128 | elif model == 'vctk_sym': 129 | sample_rate = 48000 130 | tx_steps = 200000 131 | rx_steps = 700000 132 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl") 133 | decoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_vctk_48000_hop300', f"checkpoint-{rx_steps}steps.pkl") 134 | elif model == 'vctk_v0': 135 | sample_rate = 48000 136 | tx_steps = 200000 137 | rx_steps = 500000 138 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl") 139 | decoder_checkpoint = os.path.join('exp', 'vocoder', 'AudioDec_v0_symAD_vctk_48000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl") 140 | elif model == 'vctk_v2': 141 | sample_rate = 48000 142 | tx_steps = 200000 143 | rx_steps = 500000 144 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl") 145 | decoder_checkpoint = os.path.join('exp', 'vocoder', 'AudioDec_v2_symAD_vctk_48000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl") 146 
| elif model == 'vctk_denoise': 147 | sample_rate = 48000 148 | tx_steps = 200000 149 | rx_steps = 500000 150 | encoder_checkpoint = os.path.join('exp', 'denoise', 'symAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl") 151 | decoder_checkpoint = os.path.join('exp', 'vocoder', 'AudioDec_v1_symAD_vctk_48000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl") 152 | elif model == 'vctk_univ': 153 | sample_rate = 48000 154 | tx_steps = 500000 155 | rx_steps = 500000 156 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symADuniv_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl") 157 | decoder_checkpoint = os.path.join('exp', 'vocoder', 'AudioDec_v3_symADuniv_vctk_48000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl") 158 | elif model == 'vctk_univ_sym': 159 | sample_rate = 48000 160 | tx_steps = 500000 161 | rx_steps = 1000000 162 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symADuniv_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl") 163 | decoder_checkpoint = os.path.join('exp', 'autoencoder', 'symADuniv_vctk_48000_hop300', f"checkpoint-{rx_steps}steps.pkl") 164 | elif model == 'vctk_activate_sym': 165 | sample_rate = 48000 166 | tx_steps = 200000 167 | rx_steps = 700000 168 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl") 169 | decoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAAD_vctk_48000_hop300', f"checkpoint-{rx_steps}steps.pkl") 170 | elif model == 'vctk_c16h320_sym': 171 | sample_rate = 48000 172 | tx_steps = 500000 173 | rx_steps = 1000000 174 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_c16_vctk_48000_hop320', f"checkpoint-{tx_steps}steps.pkl") 175 | decoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_c16_vctk_48000_hop320', f"checkpoint-{rx_steps}steps.pkl") 176 | else: 177 | raise NotImplementedError(f'Model {model} is not supported!') 178 | 179 | return sample_rate, encoder_checkpoint, decoder_checkpoint 180 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VALL-E 2 | 3 | Inference code for [Wenetspeech4TTS/Audiodec-Valle-Wenetspeech4TTS](https://huggingface.co/Wenetspeech4TTS/Audiodec-Valle-Wenetspeech4TTS) 4 | 5 | ## Installation 6 | 7 | ``` bash 8 | git clone https://github.com/dukGuo/valle-audiodec.git 9 | cd valle-audiodec 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Download pre-train model 14 | ### AudioDec 15 | We use [AudioDec](https://github.com/facebookresearch/AudioDec/) as our speech tokenizer instead of [encodec](https://github.com/facebookresearch/encodec) to further improve audio quality. 16 | 17 | Please download the whole [exp](https://github.com/facebookresearch/AudioDec/releases/download/pretrain_models_v02/exp.zip) folder, unzip and put it in the `AudioDec/exp` directory. 
18 | 19 | ```bash 20 | cd valle-audiodec 21 | wget https://github.com/facebookresearch/AudioDec/releases/download/pretrain_models_v02/exp.zip 22 | unzip exp.zip 23 | mv exp AudioDec/exp 24 | ``` 25 | 26 | ### VALL-E 27 | Checkpoints are available on [Wenetspeech4TTS/Audiodec-Valle-Wenetspeech4TTS](https://huggingface.co/Wenetspeech4TTS/Audiodec-Valle-Wenetspeech4TTS) 28 | 29 | - VALL-E *Basic*: VALL-E trained on the WenetSpeech4TTS Basic subset 30 | - VALL-E *Standard*: VALL-E *Basic* fine-tuned on the WenetSpeech4TTS Standard subset 31 | - VALL-E *Premium*: VALL-E *Standard* fine-tuned on the WenetSpeech4TTS Premium subset 32 | ## Speech Samples 33 | 34 | https://wenetspeech4tts.github.io/wenetspeech4tts 35 | 36 | https://rxy-j.github.io/HPMD-TTS 37 | 38 | ## Inference 39 | 40 | ```bash 41 | cd valle-audiodec 42 | python infer_tts.py \ 43 | --config config/hparams.yaml \ 44 | --ar_ckpt ckpt/basic/ar.pt \ 45 | --nar_ckpt ckpt/basic/nar.pt \ 46 | --prompt_wav test/prompt_wavs/test_1.wav \ 47 | --prompt_text 在夏日阴凉的树荫下,鸭妈妈孵着鸭宝宝。 \ 48 | --text 负责指挥的将军在一旁交代着注意事项,每个人在上面最多只能待九十秒。 49 | ``` 50 | 51 | > To improve audio quality and ensure consistent volume levels across different inputs, normalize the loudness of the prompt waveform before running inference. This preprocessing step keeps the audio input uniform, which leads to more reliable inference results. 52 | > ``` 53 | > sox $in_wave -r $sample_rate -b 16 --norm=-6 $out_wave 54 | > ``` 55 | 56 | ## References 57 | This repository is developed based on the following repositories. 58 | 59 | - [facebookresearch/AudioDec](https://github.com/facebookresearch/AudioDec) 60 | - [lifeiteng/vall-e](https://github.com/lifeiteng/vall-e) 61 | - [fishaudio/Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) 62 | -------------------------------------------------------------------------------- /config/hparams.yaml: -------------------------------------------------------------------------------- 1 | hparams: 2 | 3 | # input_settings: 4 | num_semantic: 216 # pinyin 5 | num_acoustic: 1024 # codebook size of the codec 6 | acoustic_num_layer: 8 # number of codec quantizer layers 7 | 8 | # Training_settings: 9 | batch_size: 8 10 | num_workers: 16 11 | save_checkpoint_step: 5000 12 | learning_rate: 0.0002 13 | max_training_steps: 4000000 14 | temperature: 1.0 15 | grad_accu: 40 16 | dist_backend: "nccl" # distributed training setting 17 | dist_url: "tcp://localhost:12345" 18 | 19 | # AR setting: 20 | GPT2_vocab_size: 1245 # padding+semantic+2eos acoustic 21 | GPT2_n_positions: 2048 # maximum supported sequence length 22 | GPT2_n_ctx: 2048 # same as n_positions 23 | GPT2_n_embd: 1024 # hidden dimension 24 | GPT2_n_layer: 12 # number of layers 25 | GPT2_n_head: 16 # number of attention heads 26 | 27 | # NAR setting: 28 | nar_vocab_size: 1245 # padding+semantic+acoustic+2eos 29 | nar_n_emb: 1024 30 | nar_n_layer: 12 31 | nar_n_head: 16 -------------------------------------------------------------------------------- /infer_tts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import tqdm 5 | import time 6 | import torchaudio 7 | import soundfile as sf 8 | import shutil 9 | 10 | from models.ar import AR 11 | from models.nar import NAR 12 | from utils.utils import * 13 | from text.chinese import generate_token 14 | from AudioDec.utils.audiodec import AudioDec, assign_model 15 | 16 | def tokenize_wav(wav_path,audiodec,device,sample_rate=24000): 17 | 18 | wav, sr = torchaudio.load(wav_path) 19 | if sr != sample_rate: 20 | wav
= torchaudio.functional.resample(wav, sr, sample_rate) 21 | with torch.no_grad(): 22 | wav = wav.unsqueeze(1) #C T-> 1 C T 23 | wav = wav.float().to(device) 24 | z = audiodec.tx_encoder.encode(wav) 25 | idx = audiodec.tx_encoder.quantize(z) 26 | 27 | inc = torch.arange(8)*1024 28 | idx = idx.cpu() - inc.reshape(-1,1) 29 | return idx.numpy().T 30 | 31 | 32 | def do_tts(hp,args): 33 | 34 | 35 | # init codec 36 | model_name = "libritts_v1" 37 | device = args.device 38 | sample_rate, encoder_checkpoint, decoder_checkpoint = assign_model(model_name) 39 | audiodec = AudioDec(tx_device=device , rx_device=device ) 40 | audiodec.load_transmitter(encoder_checkpoint) 41 | audiodec.load_receiver(encoder_checkpoint, decoder_checkpoint) 42 | 43 | # init valle 44 | ar = AR(hp,device) 45 | ar_ckpt = torch.load(args.ar_ckpt, map_location=device) 46 | ar.load_state_dict(ar_ckpt['model']) 47 | ar.eval() 48 | nar = NAR(hp, device).to(device) 49 | nar_ckpt = torch.load(args.nar_ckpt, map_location=device) 50 | nar_weight = nar_ckpt["model"] 51 | if list(nar_weight.keys())[0].startswith('_orig'): 52 | nar_weight = {} 53 | for k, v in nar_ckpt["model"].items(): 54 | nar_weight[k.split("_orig_mod.")[1]] = v 55 | nar.load_state_dict(nar_weight) 56 | nar.eval() 57 | 58 | # prepare data 59 | prompt_text = np.array(generate_token(args.prompt_text)) 60 | prompt_token = tokenize_wav(args.prompt_wav,audiodec,device,sample_rate) 61 | text = np.array(generate_token(args.text)) 62 | 63 | semantic_token = torch.from_numpy(np.concatenate((prompt_text,text))) 64 | t_size, dim_size = prompt_token.shape 65 | semantic_token = semantic_token + 1 # padding 66 | acoustic_token = prompt_token + 1 + hp.num_semantic # padding + len(text) 67 | 68 | acoustic_len = prompt_token.shape[0] 69 | semantic_len = semantic_token.shape[0] 70 | max_len = 2040 - semantic_len-acoustic_len 71 | eos_id = hp.num_semantic + hp.num_acoustic + 1 72 | 73 | first_idx = np.asarray(list(semantic_token) + [eos_id] + list(acoustic_token[:, 0])) 74 | full_semantic = np.stack([semantic_token] * dim_size, axis=1) # t,8 75 | eos_full = np.stack([np.asarray([eos_id, ])] * dim_size, axis=1) # 1, 8 76 | full_idx = np.concatenate([full_semantic, eos_full, acoustic_token], axis=0) 77 | 78 | 79 | pos_ids= np.asarray(list(range(semantic_len + 1)) + list(range(acoustic_len))) 80 | tags = np.asarray([1] * (semantic_len + 1) + [2] * (acoustic_len)) 81 | 82 | data = {} 83 | data["first_idx"] = torch.from_numpy(np.asarray(first_idx)).unsqueeze(0).to(device) 84 | data["seq_lens"] = torch.from_numpy(np.asarray(first_idx.shape[0])).unsqueeze(0).to(device) 85 | data["full_idx"] = torch.from_numpy(np.asarray(full_idx)).unsqueeze(0).to(device) 86 | data["pos_ids"] = torch.from_numpy(np.asarray(pos_ids)).unsqueeze(0).to(device) 87 | data["tags"] = torch.from_numpy(np.asarray(tags)).unsqueeze(0).to(device) 88 | 89 | 90 | # infer 91 | with torch.no_grad(): 92 | first_idx = ar.inference(data,max_len,20) 93 | full_idx = nar.inference(data) 94 | full_idx = (full_idx - (hp.num_semantic+1))[:, :, :-1] 95 | 96 | prompt_idx = full_idx[:,:,:data['prompt_len']] 97 | full_idx = full_idx[:,:,data['prompt_len']:] 98 | full_idx = full_idx.squeeze(0) 99 | 100 | inc = (torch.arange(8)*1024).to(device) 101 | full_idx = full_idx + inc.reshape(-1,1) 102 | zq = audiodec.rx_encoder.lookup(full_idx) 103 | 104 | try: 105 | res_wav = audiodec.decoder.decode(zq) 106 | res_wav = res_wav.squeeze(1).transpose(1, 0).cpu().numpy() 107 | except: 108 | print('error in reconstruction') 109 | 110 | sf.write( 111 | 
'./test/syn.wav', 112 | res_wav, 113 | sample_rate, 114 | "PCM_16",) 115 | 116 | 117 | if __name__ == '__main__': 118 | 119 | import argparse 120 | 121 | parser = argparse.ArgumentParser() 122 | 123 | parser.add_argument('--config', 124 | type=str, 125 | default='config/hparams.yaml', 126 | help='config path') 127 | 128 | parser.add_argument('--ar_ckpt', 129 | type=str, 130 | default='ckpt/basic/ar.pt', 131 | help='AR ckpt path') 132 | 133 | parser.add_argument('--nar_ckpt', 134 | type=str, 135 | default='ckpt/basic/nar.pt', 136 | help='NAR ckpt path') 137 | parser.add_argument('--prompt_wav', 138 | type=str, 139 | default='test/prompt_wavs/test_1.wav', 140 | help='out_dir') 141 | parser.add_argument('--prompt_text', 142 | type=str, 143 | default='在夏日阴凉的树荫下,鸭妈妈孵着鸭宝宝。', 144 | help='out_dir') 145 | parser.add_argument('--text', 146 | type=str, 147 | default='这一段戏同样也表达了亚瑟说的,我原本以为我的人生是一出悲剧,但其实它是一出戏剧。', 148 | help='out_dir') 149 | parser.add_argument('--device', 150 | type=str, 151 | default='cpu', 152 | help='cpu or cuda') 153 | 154 | args = parser.parse_args() 155 | hp = get_config_from_file(args.config).hparams 156 | 157 | do_tts(hp,args) -------------------------------------------------------------------------------- /models/ar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from transformers import GPT2Config, GPT2LMHeadModel 5 | import torch.nn.functional as F 6 | from tqdm import tqdm 7 | 8 | 9 | class AR(nn.Module): 10 | def __init__(self, hparams, device): 11 | super().__init__() 12 | self.hparams = hparams 13 | self.device = device 14 | self.lm_model = GPT2LMHeadModel(config=self._model_config()).to(device) 15 | 16 | def _model_config(self): 17 | return GPT2Config( 18 | vocab_size=self.hparams.GPT2_vocab_size, 19 | n_positions=self.hparams.GPT2_n_positions, 20 | n_ctx=self.hparams.GPT2_n_ctx, 21 | n_embd=self.hparams.GPT2_n_embd, 22 | n_layer=self.hparams.GPT2_n_layer, 23 | n_head=self.hparams.GPT2_n_head, 24 | ) 25 | 26 | 27 | def inference(self, data,max_len=None,topk=None): 28 | idx = data['first_idx'] 29 | pos_ids = data['pos_ids'] 30 | tags = data['tags'] 31 | semantic_len = (tags == 1).sum(dim=1) 32 | prompt_acoustic_len = (tags == 2).sum(dim=1) 33 | data['prompt_len'] = prompt_acoustic_len 34 | 35 | if max_len == None: 36 | max_len = int((semantic_len - 1) * 320 / 16000 * 24000 / 300 - prompt_acoustic_len) + 3 37 | 38 | for j in tqdm(range(max_len)): 39 | lm_outputs = self.lm_model( 40 | input_ids=idx, 41 | attention_mask=None, 42 | position_ids=pos_ids, 43 | ) 44 | logits = lm_outputs['logits'] 45 | logits = logits * self.hparams.temperature 46 | logits[:, :, 0:(self.hparams.num_semantic+1)] = -float('Inf') # semantic token 47 | logits[:, :, self.hparams.num_semantic + 1 + self.hparams.num_acoustic] =-float('Inf') # eos 48 | 49 | logits=logits[:, -1, :] 50 | if topk is not None: 51 | v, _ = torch.topk(logits, min(topk, logits.size(-1))) 52 | logits[logits < v[:, [-1]]] = -float('Inf') 53 | 54 | 55 | probs = logits.softmax(dim=-1) # [b, d] 56 | dist = torch.distributions.categorical.Categorical(probs=probs) 57 | samples = dist.sample().unsqueeze(0) 58 | 59 | idx = torch.cat([idx, samples], dim=1) # [b, t] 60 | pos_ids = torch.cat([pos_ids, pos_ids[:, -1:] + 1], dim=1) # [b, t] 61 | tags = torch.cat([tags, torch.zeros_like(tags[:, -1:]) + 2], dim=1) 62 | if max_len is not None: 63 | if samples.item() == self.hparams.num_semantic + 1 + self.hparams.num_acoustic +1: # eos 64 
| break 65 | if j == max_len-1: 66 | # too long break 67 | samples[:,:] = self.hparams.num_semantic + 1 + self.hparams.num_acoustic +1 68 | idx = torch.cat([idx, samples], dim=1) 69 | pos_ids = torch.cat([pos_ids, pos_ids[:, -1:] + 1], dim=1) # [b, t] 70 | tags = torch.cat([tags, torch.zeros_like(tags[:, -1:]) + 2], dim=1) 71 | break 72 | data['tags'] = tags 73 | data['pos_ids'] = pos_ids 74 | data['first_idx'] = idx 75 | return idx 76 | -------------------------------------------------------------------------------- /models/nar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from torch.nn import TransformerEncoderLayer 5 | import random 6 | import torch.nn.functional as F 7 | import math 8 | 9 | class new_gelu(nn.Module): 10 | def forward(self, x): 11 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 12 | 13 | class NAR(nn.Module): 14 | def __init__(self, hparams, device): 15 | super().__init__() 16 | self.hparams = hparams 17 | self.device = device 18 | 19 | self.token_embeddings = nn.ModuleList([nn.Embedding(self.hparams.nar_vocab_size, embedding_dim=self.hparams.nar_n_emb, padding_idx=0)] * self.hparams.acoustic_num_layer) 20 | self.wpe = nn.Embedding(2048, self.hparams.nar_n_emb) 21 | self.dense_layer = nn.Linear(self.hparams.nar_n_emb * self.hparams.acoustic_num_layer, self.hparams.nar_n_emb, bias=False) 22 | 23 | self.res_embedding = nn.Embedding(self.hparams.acoustic_num_layer - 1, self.hparams.nar_n_emb) 24 | 25 | activation = new_gelu() 26 | self.transformer_layers = nn.ModuleList() 27 | for i in range(self.hparams.nar_n_layer): 28 | self.transformer_layers.append( 29 | TransformerEncoderLayer( 30 | d_model=self.hparams.nar_n_emb, 31 | nhead=self.hparams.nar_n_head , 32 | dim_feedforward=self.hparams.nar_n_emb * 4, 33 | dropout=0.1, 34 | activation=activation, 35 | layer_norm_eps=1e-05, 36 | batch_first=True, 37 | ) 38 | ) 39 | self.mlp_layer = nn.Linear(self.hparams.nar_n_emb, self.hparams.num_acoustic + 2, bias=False) # quant_token_num + eos 40 | 41 | 42 | def inference(self, data): 43 | # unpack data 44 | init_full_idx = data['full_idx'] 45 | pos_ids = data['pos_ids'] 46 | tags = data['tags'] 47 | prompt_len = data['prompt_len'] 48 | 49 | # Calculate lengths 50 | semantic_len = (tags == 1).sum(dim=1) 51 | acoustic_len = (tags == 2).sum(dim=1) 52 | 53 | # Initialize the sequence tensor with indices for each acoustic layer 54 | full_idx = torch.stack([data['first_idx']] * self.hparams.acoustic_num_layer, dim=1) 55 | full_idx[:, :, semantic_len[0]:semantic_len[0] + prompt_len[0]] = init_full_idx[:,semantic_len[0]:, :].transpose(1, 2) 56 | 57 | batch_size, layer_size, t_size = full_idx.size() 58 | 59 | # Create masks 60 | layer_index = torch.ones(size=[batch_size,], device=self.device) 61 | layer_mask = layer_index.unsqueeze(1) > torch.arange(self.hparams.acoustic_num_layer, device=self.device).unsqueeze(0) 62 | prompt_mask = (semantic_len + prompt_len).unsqueeze(1) > torch.arange(t_size, device=self.device).unsqueeze(0) 63 | 64 | # Combine masks 65 | mask = layer_mask.unsqueeze(2) + prompt_mask.unsqueeze(1) 66 | full_idx = torch.where(mask, full_idx, torch.zeros_like(full_idx)) 67 | 68 | for layer_index in range(1, self.hparams.acoustic_num_layer): 69 | layer_index_tensor = torch.LongTensor([layer_index, ]).repeat(batch_size, ).to(self.device) 70 | layer_mask = layer_index_tensor.unsqueeze(1) > 
torch.arange(self.hparams.acoustic_num_layer, device=self.device).unsqueeze(0) 71 | 72 | # Apply masks 73 | mask = layer_mask.unsqueeze(2) + prompt_mask.unsqueeze(1) 74 | mask_full_idx = torch.where(mask, full_idx, torch.zeros_like(full_idx)) 75 | 76 | # Generate embeddings 77 | embeddings = [self.token_embeddings[i](mask_full_idx[:, i, :]) for i in range(layer_size)] 78 | embeddings = torch.cat(embeddings, dim=-1) 79 | embeddings = self.dense_layer(embeddings) 80 | 81 | # Add layer and position embeddings 82 | res_embeddings = self.res_embedding(layer_index_tensor - 1).unsqueeze(1) 83 | outputs = embeddings + self.wpe(pos_ids) + res_embeddings 84 | 85 | # Apply transformer layers 86 | for layer in self.transformer_layers: 87 | outputs = layer(outputs) 88 | 89 | logits = self.mlp_layer(outputs) 90 | logits[:, :, -2:] = -float('Inf') # unused token 91 | 92 | # NAR greedy search 93 | samples = torch.argmax(logits, dim=-1) + self.hparams.num_semantic + 1 94 | samples = torch.where(prompt_mask, full_idx[:, layer_index, :], samples) 95 | full_idx[:, layer_index, :] = samples 96 | 97 | return full_idx[:, :, semantic_len[0]:] 98 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.25.1 2 | numba==0.55.1 3 | tqdm 4 | julius 5 | tabulate 6 | einops 7 | pydub 8 | pypinyin 9 | cn2an 10 | jieba 11 | soundfile 12 | sounddevice 13 | numpy 14 | torch 15 | torchaudio 16 | librosa 17 | argparse 18 | pyyaml 19 | -------------------------------------------------------------------------------- /test/prompt_wavs/test_1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/test/prompt_wavs/test_1.wav -------------------------------------------------------------------------------- /text/chinese.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import cn2an 5 | from pypinyin import lazy_pinyin, Style 6 | 7 | from text.symbols import punctuation 8 | from text.tone_sandhi import ToneSandhi 9 | 10 | current_file_path = os.path.dirname(__file__) 11 | pinyin_to_symbol_map = { 12 | line.split("\t")[0]: line.strip().split("\t")[1] 13 | for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() 14 | } 15 | 16 | import jieba.posseg as psg 17 | 18 | finals= ['a','ai','an','ang','ao','e','ei','en','eng','er','i','i0','ia','ian','iang','iao','ie','in','ing','iong','ir','iu','o','ong','ou','u','ua','uai','uan','uang','ui','un','uo','v','van','ve','vn'] 19 | 20 | rep_map = { 21 | ":": ",", 22 | ";": ",", 23 | ",": ",", 24 | "。": ".", 25 | "!": "!", 26 | "?": "?", 27 | "\n": ".", 28 | "·": ",", 29 | "、": ",", 30 | "...": "…", 31 | "$": ".", 32 | "“": "'", 33 | "”": "'", 34 | '"': "'", 35 | "‘": "'", 36 | "’": "'", 37 | "(": "'", 38 | ")": "'", 39 | "(": "'", 40 | ")": "'", 41 | "《": "'", 42 | "》": "'", 43 | "【": "'", 44 | "】": "'", 45 | "[": "'", 46 | "]": "'", 47 | "—": "-", 48 | "~": "-", 49 | "~": "-", 50 | "「": "'", 51 | "」": "'", 52 | } 53 | 54 | tone_modifier = ToneSandhi() 55 | 56 | 57 | def replace_punctuation(text): 58 | text = text.replace("嗯", "恩").replace("呣", "母") 59 | pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) 60 | 61 | replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) 62 | 63 | replaced_text = re.sub( 64 
| r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text 65 | ) 66 | 67 | return replaced_text 68 | 69 | 70 | def g2p(text): 71 | pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) 72 | sentences = [i for i in re.split(pattern, text) if i.strip() != ""] 73 | phones, tones, word2ph = _g2p(sentences) 74 | assert sum(word2ph) == len(phones) 75 | assert len(word2ph) == len(text) # Sometimes it will crash,you can add a try-catch. 76 | phones = ["_"] + phones + ["_"] 77 | tones = [0] + tones + [0] 78 | word2ph = [1] + word2ph + [1] 79 | return phones, tones, word2ph 80 | 81 | 82 | def _get_initials_finals(word): 83 | initials = [] 84 | finals = [] 85 | orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) 86 | orig_finals = lazy_pinyin( 87 | word, neutral_tone_with_five=True, style=Style.FINALS_TONE3 88 | ) 89 | for c, v in zip(orig_initials, orig_finals): 90 | initials.append(c) 91 | finals.append(v) 92 | return initials, finals 93 | 94 | 95 | def _g2p(segments): 96 | phones_list = [] 97 | tones_list = [] 98 | word2ph = [] 99 | for seg in segments: 100 | # Replace all English words in the sentence 101 | seg = re.sub("[a-zA-Z]+", "", seg) 102 | seg_cut = psg.lcut(seg) 103 | initials = [] 104 | finals = [] 105 | seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) 106 | for word, pos in seg_cut: 107 | if pos == "eng": 108 | continue 109 | sub_initials, sub_finals = _get_initials_finals(word) 110 | sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) 111 | initials.append(sub_initials) 112 | finals.append(sub_finals) 113 | 114 | # assert len(sub_initials) == len(sub_finals) == len(word) 115 | initials = sum(initials, []) 116 | finals = sum(finals, []) 117 | # 118 | for c, v in zip(initials, finals): 119 | raw_pinyin = c + v 120 | # NOTE: post process for pypinyin outputs 121 | # we discriminate i, ii and iii 122 | if c == v: 123 | assert c in punctuation 124 | phone = [c] 125 | tone = "0" 126 | word2ph.append(1) 127 | else: 128 | v_without_tone = v[:-1] 129 | tone = v[-1] 130 | 131 | pinyin = c + v_without_tone 132 | assert tone in "12345" 133 | 134 | if c: 135 | # 多音节 136 | v_rep_map = { 137 | "uei": "ui", 138 | "iou": "iu", 139 | "uen": "un", 140 | } 141 | if v_without_tone in v_rep_map.keys(): 142 | pinyin = c + v_rep_map[v_without_tone] 143 | else: 144 | # 单音节 145 | pinyin_rep_map = { 146 | "ing": "ying", 147 | "i": "yi", 148 | "in": "yin", 149 | "u": "wu", 150 | } 151 | if pinyin in pinyin_rep_map.keys(): 152 | pinyin = pinyin_rep_map[pinyin] 153 | else: 154 | single_rep_map = { 155 | "v": "yu", 156 | "e": "e", 157 | "i": "y", 158 | "u": "w", 159 | } 160 | if pinyin[0] in single_rep_map.keys(): 161 | pinyin = single_rep_map[pinyin[0]] + pinyin[1:] 162 | 163 | assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) 164 | phone = pinyin_to_symbol_map[pinyin].split(" ") 165 | word2ph.append(len(phone)) 166 | 167 | phones_list += phone 168 | tones_list += [int(tone)] * len(phone) 169 | return phones_list, tones_list, word2ph 170 | 171 | 172 | def text_normalize(text): 173 | numbers = re.findall(r"\d+(?:\.?\d+)?", text) 174 | for number in numbers: 175 | text = text.replace(number, cn2an.an2cn(number), 1) 176 | text = replace_punctuation(text) 177 | return text 178 | 179 | def generate_token(text): 180 | 181 | phone_set = "text/chinese_dict" 182 | with open(phone_set) as ps: 183 | phone_map = { pho.strip():idx+1 for idx,pho in enumerate(ps.readlines())} 184 | 185 | norm_text = text_normalize(text) 186 | phones, tones, 
word2ph = g2p(norm_text) 187 | 188 | final_phone_lst = [] 189 | for phone,tone in zip(phones,tones): 190 | if '_' == phone: 191 | continue 192 | if phone in finals: 193 | phone = phone+str(tone) 194 | if phone == '!' or phone == '…': 195 | phone = '.' 196 | if phone == '-': 197 | phone = ',' 198 | final_phone_lst.append(phone.strip()) 199 | if 'iong4' in final_phone_lst: # 默认音素集内不存在iong4 200 | assert 0 201 | token_lst = [phone_map[phone] for phone in final_phone_lst] 202 | 203 | return token_lst 204 | -------------------------------------------------------------------------------- /text/chinese_dict: -------------------------------------------------------------------------------- 1 | ' 2 | , 3 | . 4 | ? 5 | AA 6 | E 7 | EE 8 | En 9 | OO 10 | a1 11 | a2 12 | a3 13 | a4 14 | a5 15 | ai1 16 | ai2 17 | ai3 18 | ai4 19 | ai5 20 | an1 21 | an2 22 | an3 23 | an4 24 | an5 25 | ang1 26 | ang2 27 | ang3 28 | ang4 29 | ang5 30 | ao1 31 | ao2 32 | ao3 33 | ao4 34 | ao5 35 | b 36 | c 37 | ch 38 | d 39 | e1 40 | e2 41 | e3 42 | e4 43 | e5 44 | ei1 45 | ei2 46 | ei3 47 | ei4 48 | ei5 49 | en1 50 | en2 51 | en3 52 | en4 53 | en5 54 | eng1 55 | eng2 56 | eng3 57 | eng4 58 | eng5 59 | er2 60 | er3 61 | er4 62 | er5 63 | f 64 | g 65 | h 66 | i01 67 | i02 68 | i03 69 | i04 70 | i05 71 | i1 72 | i2 73 | i3 74 | i4 75 | i5 76 | ia1 77 | ia2 78 | ia3 79 | ia4 80 | ia5 81 | ian1 82 | ian2 83 | ian3 84 | ian4 85 | ian5 86 | iang1 87 | iang2 88 | iang3 89 | iang4 90 | iang5 91 | iao1 92 | iao2 93 | iao3 94 | iao4 95 | iao5 96 | ie1 97 | ie2 98 | ie3 99 | ie4 100 | ie5 101 | in1 102 | in2 103 | in3 104 | in4 105 | in5 106 | ing1 107 | ing2 108 | ing3 109 | ing4 110 | ing5 111 | iong1 112 | iong2 113 | iong3 114 | iong5 115 | ir1 116 | ir2 117 | ir3 118 | ir4 119 | ir5 120 | iu1 121 | iu2 122 | iu3 123 | iu4 124 | iu5 125 | j 126 | k 127 | l 128 | m 129 | n 130 | o1 131 | o2 132 | o3 133 | o4 134 | o5 135 | ong1 136 | ong2 137 | ong3 138 | ong4 139 | ong5 140 | ou1 141 | ou2 142 | ou3 143 | ou4 144 | ou5 145 | p 146 | q 147 | r 148 | s 149 | sh 150 | t 151 | u1 152 | u2 153 | u3 154 | u4 155 | u5 156 | ua1 157 | ua2 158 | ua3 159 | ua4 160 | ua5 161 | uai1 162 | uai2 163 | uai3 164 | uai4 165 | uai5 166 | uan1 167 | uan2 168 | uan3 169 | uan4 170 | uan5 171 | uang1 172 | uang2 173 | uang3 174 | uang4 175 | uang5 176 | ui1 177 | ui2 178 | ui3 179 | ui4 180 | ui5 181 | un1 182 | un2 183 | un3 184 | un4 185 | un5 186 | uo1 187 | uo2 188 | uo3 189 | uo4 190 | uo5 191 | v1 192 | v2 193 | v3 194 | v4 195 | v5 196 | van1 197 | van2 198 | van3 199 | van4 200 | van5 201 | ve1 202 | ve2 203 | ve3 204 | ve4 205 | ve5 206 | vn1 207 | vn2 208 | vn3 209 | vn4 210 | vn5 211 | w 212 | x 213 | y 214 | z 215 | zh -------------------------------------------------------------------------------- /text/opencpop-strict.txt: -------------------------------------------------------------------------------- 1 | a AA a 2 | ai AA ai 3 | an AA an 4 | ang AA ang 5 | ao AA ao 6 | ba b a 7 | bai b ai 8 | ban b an 9 | bang b ang 10 | bao b ao 11 | bei b ei 12 | ben b en 13 | beng b eng 14 | bi b i 15 | bian b ian 16 | biao b iao 17 | bie b ie 18 | bin b in 19 | bing b ing 20 | bo b o 21 | bu b u 22 | ca c a 23 | cai c ai 24 | can c an 25 | cang c ang 26 | cao c ao 27 | ce c e 28 | cei c ei 29 | cen c en 30 | ceng c eng 31 | cha ch a 32 | chai ch ai 33 | chan ch an 34 | chang ch ang 35 | chao ch ao 36 | che ch e 37 | chen ch en 38 | cheng ch eng 39 | chi ch ir 40 | chong ch ong 41 | chou ch ou 42 | chu ch u 43 | chua ch ua 44 | chuai ch uai 45 | chuan 
ch uan 46 | chuang ch uang 47 | chui ch ui 48 | chun ch un 49 | chuo ch uo 50 | ci c i0 51 | cong c ong 52 | cou c ou 53 | cu c u 54 | cuan c uan 55 | cui c ui 56 | cun c un 57 | cuo c uo 58 | da d a 59 | dai d ai 60 | dan d an 61 | dang d ang 62 | dao d ao 63 | de d e 64 | dei d ei 65 | den d en 66 | deng d eng 67 | di d i 68 | dia d ia 69 | dian d ian 70 | diao d iao 71 | die d ie 72 | ding d ing 73 | diu d iu 74 | dong d ong 75 | dou d ou 76 | du d u 77 | duan d uan 78 | dui d ui 79 | dun d un 80 | duo d uo 81 | e EE e 82 | ei EE ei 83 | en EE en 84 | eng EE eng 85 | er EE er 86 | fa f a 87 | fan f an 88 | fang f ang 89 | fei f ei 90 | fen f en 91 | feng f eng 92 | fo f o 93 | fou f ou 94 | fu f u 95 | ga g a 96 | gai g ai 97 | gan g an 98 | gang g ang 99 | gao g ao 100 | ge g e 101 | gei g ei 102 | gen g en 103 | geng g eng 104 | gong g ong 105 | gou g ou 106 | gu g u 107 | gua g ua 108 | guai g uai 109 | guan g uan 110 | guang g uang 111 | gui g ui 112 | gun g un 113 | guo g uo 114 | ha h a 115 | hai h ai 116 | han h an 117 | hang h ang 118 | hao h ao 119 | he h e 120 | hei h ei 121 | hen h en 122 | heng h eng 123 | hong h ong 124 | hou h ou 125 | hu h u 126 | hua h ua 127 | huai h uai 128 | huan h uan 129 | huang h uang 130 | hui h ui 131 | hun h un 132 | huo h uo 133 | ji j i 134 | jia j ia 135 | jian j ian 136 | jiang j iang 137 | jiao j iao 138 | jie j ie 139 | jin j in 140 | jing j ing 141 | jiong j iong 142 | jiu j iu 143 | ju j v 144 | jv j v 145 | juan j van 146 | jvan j van 147 | jue j ve 148 | jve j ve 149 | jun j vn 150 | jvn j vn 151 | ka k a 152 | kai k ai 153 | kan k an 154 | kang k ang 155 | kao k ao 156 | ke k e 157 | kei k ei 158 | ken k en 159 | keng k eng 160 | kong k ong 161 | kou k ou 162 | ku k u 163 | kua k ua 164 | kuai k uai 165 | kuan k uan 166 | kuang k uang 167 | kui k ui 168 | kun k un 169 | kuo k uo 170 | la l a 171 | lai l ai 172 | lan l an 173 | lang l ang 174 | lao l ao 175 | le l e 176 | lei l ei 177 | leng l eng 178 | li l i 179 | lia l ia 180 | lian l ian 181 | liang l iang 182 | liao l iao 183 | lie l ie 184 | lin l in 185 | ling l ing 186 | liu l iu 187 | lo l o 188 | long l ong 189 | lou l ou 190 | lu l u 191 | luan l uan 192 | lun l un 193 | luo l uo 194 | lv l v 195 | lve l ve 196 | ma m a 197 | mai m ai 198 | man m an 199 | mang m ang 200 | mao m ao 201 | me m e 202 | mei m ei 203 | men m en 204 | meng m eng 205 | mi m i 206 | mian m ian 207 | miao m iao 208 | mie m ie 209 | min m in 210 | ming m ing 211 | miu m iu 212 | mo m o 213 | mou m ou 214 | mu m u 215 | na n a 216 | nai n ai 217 | nan n an 218 | nang n ang 219 | nao n ao 220 | ne n e 221 | nei n ei 222 | nen n en 223 | neng n eng 224 | ni n i 225 | nian n ian 226 | niang n iang 227 | niao n iao 228 | nie n ie 229 | nin n in 230 | ning n ing 231 | niu n iu 232 | nong n ong 233 | nou n ou 234 | nu n u 235 | nuan n uan 236 | nun n un 237 | nuo n uo 238 | nv n v 239 | nve n ve 240 | o OO o 241 | ou OO ou 242 | pa p a 243 | pai p ai 244 | pan p an 245 | pang p ang 246 | pao p ao 247 | pei p ei 248 | pen p en 249 | peng p eng 250 | pi p i 251 | pian p ian 252 | piao p iao 253 | pie p ie 254 | pin p in 255 | ping p ing 256 | po p o 257 | pou p ou 258 | pu p u 259 | qi q i 260 | qia q ia 261 | qian q ian 262 | qiang q iang 263 | qiao q iao 264 | qie q ie 265 | qin q in 266 | qing q ing 267 | qiong q iong 268 | qiu q iu 269 | qu q v 270 | qv q v 271 | quan q van 272 | qvan q van 273 | que q ve 274 | qve q ve 275 | qun q vn 276 | qvn q vn 277 | ran r an 278 | rang r ang 279 | rao r ao 280 | re r 
e 281 | ren r en 282 | reng r eng 283 | ri r ir 284 | rong r ong 285 | rou r ou 286 | ru r u 287 | rua r ua 288 | ruan r uan 289 | rui r ui 290 | run r un 291 | ruo r uo 292 | sa s a 293 | sai s ai 294 | san s an 295 | sang s ang 296 | sao s ao 297 | se s e 298 | sen s en 299 | seng s eng 300 | sha sh a 301 | shai sh ai 302 | shan sh an 303 | shang sh ang 304 | shao sh ao 305 | she sh e 306 | shei sh ei 307 | shen sh en 308 | sheng sh eng 309 | shi sh ir 310 | shou sh ou 311 | shu sh u 312 | shua sh ua 313 | shuai sh uai 314 | shuan sh uan 315 | shuang sh uang 316 | shui sh ui 317 | shun sh un 318 | shuo sh uo 319 | si s i0 320 | song s ong 321 | sou s ou 322 | su s u 323 | suan s uan 324 | sui s ui 325 | sun s un 326 | suo s uo 327 | ta t a 328 | tai t ai 329 | tan t an 330 | tang t ang 331 | tao t ao 332 | te t e 333 | tei t ei 334 | teng t eng 335 | ti t i 336 | tian t ian 337 | tiao t iao 338 | tie t ie 339 | ting t ing 340 | tong t ong 341 | tou t ou 342 | tu t u 343 | tuan t uan 344 | tui t ui 345 | tun t un 346 | tuo t uo 347 | wa w a 348 | wai w ai 349 | wan w an 350 | wang w ang 351 | wei w ei 352 | wen w en 353 | weng w eng 354 | wo w o 355 | wu w u 356 | xi x i 357 | xia x ia 358 | xian x ian 359 | xiang x iang 360 | xiao x iao 361 | xie x ie 362 | xin x in 363 | xing x ing 364 | xiong x iong 365 | xiu x iu 366 | xu x v 367 | xv x v 368 | xuan x van 369 | xvan x van 370 | xue x ve 371 | xve x ve 372 | xun x vn 373 | xvn x vn 374 | ya y a 375 | yan y En 376 | yang y ang 377 | yao y ao 378 | ye y E 379 | yi y i 380 | yin y in 381 | ying y ing 382 | yo y o 383 | yong y ong 384 | you y ou 385 | yu y v 386 | yv y v 387 | yuan y van 388 | yvan y van 389 | yue y ve 390 | yve y ve 391 | yun y vn 392 | yvn y vn 393 | za z a 394 | zai z ai 395 | zan z an 396 | zang z ang 397 | zao z ao 398 | ze z e 399 | zei z ei 400 | zen z en 401 | zeng z eng 402 | zha zh a 403 | zhai zh ai 404 | zhan zh an 405 | zhang zh ang 406 | zhao zh ao 407 | zhe zh e 408 | zhei zh ei 409 | zhen zh en 410 | zheng zh eng 411 | zhi zh ir 412 | zhong zh ong 413 | zhou zh ou 414 | zhu zh u 415 | zhua zh ua 416 | zhuai zh uai 417 | zhuan zh uan 418 | zhuang zh uang 419 | zhui zh ui 420 | zhun zh un 421 | zhuo zh uo 422 | zi z i0 423 | zong z ong 424 | zou z ou 425 | zu z u 426 | zuan z uan 427 | zui z ui 428 | zun z un 429 | zuo z uo 430 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | punctuation = ["!", "?", "…", ",", ".", "'", "-"] 2 | pu_symbols = punctuation + ["SP", "UNK"] 3 | pad = "_" 4 | 5 | # chinese 6 | zh_symbols = [ 7 | "E", 8 | "En", 9 | "a", 10 | "ai", 11 | "an", 12 | "ang", 13 | "ao", 14 | "b", 15 | "c", 16 | "ch", 17 | "d", 18 | "e", 19 | "ei", 20 | "en", 21 | "eng", 22 | "er", 23 | "f", 24 | "g", 25 | "h", 26 | "i", 27 | "i0", 28 | "ia", 29 | "ian", 30 | "iang", 31 | "iao", 32 | "ie", 33 | "in", 34 | "ing", 35 | "iong", 36 | "ir", 37 | "iu", 38 | "j", 39 | "k", 40 | "l", 41 | "m", 42 | "n", 43 | "o", 44 | "ong", 45 | "ou", 46 | "p", 47 | "q", 48 | "r", 49 | "s", 50 | "sh", 51 | "t", 52 | "u", 53 | "ua", 54 | "uai", 55 | "uan", 56 | "uang", 57 | "ui", 58 | "un", 59 | "uo", 60 | "v", 61 | "van", 62 | "ve", 63 | "vn", 64 | "w", 65 | "x", 66 | "y", 67 | "z", 68 | "zh", 69 | "AA", 70 | "EE", 71 | "OO", 72 | ] 73 | num_zh_tones = 6 74 | 75 | # japanese 76 | ja_symbols = [ 77 | "N", 78 | "a", 79 | "a:", 80 | "b", 81 | "by", 82 | "ch", 83 | "d", 84 | "dy", 85 | "e", 86 | 
"e:", 87 | "f", 88 | "g", 89 | "gy", 90 | "h", 91 | "hy", 92 | "i", 93 | "i:", 94 | "j", 95 | "k", 96 | "ky", 97 | "m", 98 | "my", 99 | "n", 100 | "ny", 101 | "o", 102 | "o:", 103 | "p", 104 | "py", 105 | "q", 106 | "r", 107 | "ry", 108 | "s", 109 | "sh", 110 | "t", 111 | "ts", 112 | "ty", 113 | "u", 114 | "u:", 115 | "w", 116 | "y", 117 | "z", 118 | "zy", 119 | ] 120 | num_ja_tones = 2 121 | 122 | # English 123 | en_symbols = [ 124 | "aa", 125 | "ae", 126 | "ah", 127 | "ao", 128 | "aw", 129 | "ay", 130 | "b", 131 | "ch", 132 | "d", 133 | "dh", 134 | "eh", 135 | "er", 136 | "ey", 137 | "f", 138 | "g", 139 | "hh", 140 | "ih", 141 | "iy", 142 | "jh", 143 | "k", 144 | "l", 145 | "m", 146 | "n", 147 | "ng", 148 | "ow", 149 | "oy", 150 | "p", 151 | "r", 152 | "s", 153 | "sh", 154 | "t", 155 | "th", 156 | "uh", 157 | "uw", 158 | "V", 159 | "w", 160 | "y", 161 | "z", 162 | "zh", 163 | ] 164 | num_en_tones = 4 165 | 166 | # combine all symbols 167 | normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols)) 168 | symbols = [pad] + normal_symbols + pu_symbols 169 | sil_phonemes_ids = [symbols.index(i) for i in pu_symbols] 170 | 171 | # combine all tones 172 | num_tones = num_zh_tones + num_ja_tones + num_en_tones 173 | 174 | # language maps 175 | language_id_map = {"ZH": 0, "JP": 1, "EN": 2} 176 | num_languages = len(language_id_map.keys()) 177 | 178 | language_tone_start_map = { 179 | "ZH": 0, 180 | "JP": num_zh_tones, 181 | "EN": num_zh_tones + num_ja_tones, 182 | } 183 | 184 | if __name__ == "__main__": 185 | a = set(zh_symbols) 186 | b = set(en_symbols) 187 | print(sorted(a & b)) 188 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | import collections 4 | import os 5 | import numpy as np 6 | import random 7 | 8 | class ValueWindow(): 9 | def __init__(self, window_size=100): 10 | self._window_size = window_size 11 | self._values = [] 12 | 13 | def append(self, x): 14 | self._values = self._values[-(self._window_size - 1):] + [x] 15 | 16 | @property 17 | def sum(self): 18 | return sum(self._values) 19 | 20 | @property 21 | def count(self): 22 | return len(self._values) 23 | 24 | @property 25 | def average(self): 26 | return self.sum / max(1, self.count) 27 | 28 | def reset(self): 29 | self._values = [] 30 | 31 | 32 | def str2bool(v): 33 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 34 | return True 35 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 36 | return False 37 | else: 38 | raise ValueError('Unsupported value encountered.') 39 | 40 | 41 | # def hparams_string(hparams, args): 42 | # print('Output path: {}'.format(args.logdir_path)) 43 | # values = hparams.values() 44 | # hp = [' %s: %s' % (name, values[name]) 45 | # for name in sorted(values) if name != 'sentences'] 46 | # # print('Hyperparameters:\n' + '\n'.join(hp)) 47 | # return 48 | 49 | 50 | class HParams(): 51 | def __init__(self, **kwargs): 52 | for k, v in kwargs.items(): 53 | if type(v) == dict: 54 | v = HParams(**v) 55 | self[k] = v 56 | 57 | def keys(self): 58 | return self.__dict__.keys() 59 | 60 | def items(self): 61 | return self.__dict__.items() 62 | 63 | def values(self): 64 | return self.__dict__.values() 65 | 66 | def __len__(self): 67 | return len(self.__dict__) 68 | 69 | def __getitem__(self, key): 70 | return getattr(self, key) 71 | 72 | def __setitem__(self, key, value): 73 | return setattr(self, key, value) 74 | 75 | def 
__contains__(self, key): 76 | return key in self.__dict__ 77 | 78 | def __repr__(self): 79 | return self.__dict__.__repr__() 80 | 81 | 82 | def get_config_from_file(file): 83 | with open(file, 'r') as f: 84 | hp = yaml.load(f,Loader=yaml.FullLoader) 85 | hp = HParams(**hp) 86 | return hp 87 | 88 | def update_lr(opt, lr): 89 | for g in opt.param_groups: 90 | g['lr'] = lr 91 | return 92 | 93 | def to_device(tensors, device): 94 | tensors_to_device = [] 95 | for tensor in tensors: 96 | if isinstance(tensor, torch.Tensor): 97 | tensors_to_device.append(tensor.to(device)) 98 | else: 99 | tensors_to_device.append(tensor) 100 | return tensors_to_device 101 | 102 | def calculate_model_params(model): 103 | total = sum([param.numel() for param in model.parameters()]) 104 | para_name = [para[0] for para in list(model.named_parameters())] 105 | print("==================================================") 106 | print("model struct is {}\n".format(str(model))) 107 | print("model params : {}".format(para_name)) 108 | # log("FLOPs: {}".format(flops.total_float_ops)) 109 | print("Number of parameter: %.2fM" % (total / 1e6)) 110 | print("==================================================") 111 | return 112 | 113 | def get_metadata(path): 114 | with open(path, 'r') as f: 115 | metas = [l.strip() for l in f] 116 | random.shuffle(metas) 117 | return metas 118 | 119 | class TemperatureSampler(): 120 | """ 121 | ## Sampler with Temperature 122 | """ 123 | def __init__(self, temperature: float = 1.0): 124 | """ 125 | :param temperature: is the temperature to sample with 126 | """ 127 | self.temperature = temperature 128 | 129 | def __call__(self, logits: torch.Tensor): 130 | """ 131 | Sample from logits 132 | """ 133 | 134 | # Create a categorical distribution with temperature adjusted logits 135 | dist = torch.distributions.categorical.Categorical(logits=logits / self.temperature) 136 | 137 | # Sample 138 | return dist.sample() 139 | 140 | 141 | 142 | class TopKSampler(): 143 | """ 144 | ## Top-k Sampler 145 | """ 146 | def __init__(self, k: int, sampler=TemperatureSampler()): 147 | """ 148 | :param k: is the number of tokens to pick 149 | :param sampler: is the sampler to use for the top-k tokens 150 | 151 | `sampler` can be any sampler that takes a logits tensor as input and returns a token tensor; 152 | e.g. [`TemperatureSampler'](temperature.html). 153 | """ 154 | self.k = k 155 | self.sampler = sampler 156 | 157 | def __call__(self, logits: torch.Tensor): 158 | """ 159 | Sample from logits 160 | """ 161 | # New logits filled with $-\infty$; i.e. zero probability 162 | zeros = logits.new_ones(logits.shape) * float('-inf') 163 | # Pick the largest $k$ logits and their indices 164 | values, indices = torch.topk(logits, self.k, dim=-1) 165 | # Set the values of the top-k selected indices to actual logits. 166 | # Logits of other tokens remain $-\infty$ 167 | zeros.scatter_(-1, indices, values) 168 | 169 | # Sample from the top-k logits with the specified sampler. 170 | return self.sampler(zeros) --------------------------------------------------------------------------------
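
The `TopKSampler` and `TemperatureSampler` helpers above can be chained so that top-k filtering feeds temperature sampling. The sketch below is a minimal usage example, not code from this repository: the AR model's inference loop implements the same top-k filtering inline rather than through these classes, and the vocabulary size 1245 and the values `k=20`, `temperature=1.0` are simply taken from `config/hparams.yaml` and the inference defaults shown earlier. It assumes the repository root is on `PYTHONPATH` so that `utils.utils` is importable.

```python
# Minimal usage sketch (assumption: run from the repository root so utils.utils imports).
import torch

from utils.utils import TemperatureSampler, TopKSampler

# Dummy next-token logits for a batch of one over a 1245-entry vocabulary
# (GPT2_vocab_size in config/hparams.yaml).
logits = torch.randn(1, 1245)

# Keep only the 20 largest logits, then sample from them at temperature 1.0.
sampler = TopKSampler(k=20, sampler=TemperatureSampler(temperature=1.0))
token = sampler(logits)  # LongTensor of shape [1]
print(token.item())
```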
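
For completeness, a short sketch of how `config/hparams.yaml` is consumed through `get_config_from_file` and `HParams`, mirroring the `get_config_from_file(args.config).hparams` call in `infer_tts.py`; the printed values are only what the shipped config would yield under that assumption.

```python
# Minimal sketch of loading the hyperparameter file with the helpers in utils/utils.py.
from utils.utils import get_config_from_file

hp = get_config_from_file("config/hparams.yaml").hparams  # nested dicts become HParams objects

print(hp.GPT2_n_layer)    # attribute access: 12
print(hp["GPT2_n_head"])  # item access via __getitem__: 16
print("nar_n_emb" in hp)  # membership test via __contains__: True
```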