├── AudioDec
│   ├── About
│   │   └── README.md
│   ├── CODE_OF_CONDUCT.md
│   ├── CONTRIBUTING.md
│   ├── LICENSE
│   ├── README.md
│   ├── bin
│   │   ├── stream.py
│   │   ├── test.py
│   │   ├── train.py
│   │   └── utils.py
│   ├── config
│   │   ├── autoencoder
│   │   │   ├── symAAD_vctk_48000_hop300.yaml
│   │   │   ├── symAD_c16_vctk_48000_hop320.yaml
│   │   │   ├── symAD_libritts_24000_hop300.yaml
│   │   │   ├── symAD_vctk_48000_hop300.yaml
│   │   │   └── symADuniv_vctk_48000_hop300.yaml
│   │   ├── denoise
│   │   │   └── symAD_vctk_48000_hop300.yaml
│   │   ├── statistic
│   │   │   ├── symAD_libritts_24000_hop300_clean.yaml
│   │   │   ├── symAD_vctk_48000_hop300_clean.yaml
│   │   │   └── symADuniv_vctk_48000_hop300_clean.yaml
│   │   └── vocoder
│   │       ├── AudioDec_v0_symAD_vctk_48000_hop300_clean.yaml
│   │       ├── AudioDec_v1_symAD_libritts_24000_hop300_clean.yaml
│   │       ├── AudioDec_v1_symAD_vctk_48000_hop300_clean.yaml
│   │       ├── AudioDec_v2_symAD_vctk_48000_hop300_clean.yaml
│   │       └── AudioDec_v3_symADuniv_vctk_48000_hop300_clean.yaml
│   ├── dataloader
│   │   ├── __init__.py
│   │   ├── collater.py
│   │   ├── dataset.py
│   │   └── utils.py
│   ├── figs
│   │   ├── architecture-1.png
│   │   ├── latency.jpg
│   │   └── mos.jpg
│   ├── layers
│   │   ├── activation_function.py
│   │   ├── conv_layer.py
│   │   └── vq_module.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── adversarial_loss.py
│   │   ├── feat_match_loss.py
│   │   ├── mel_loss.py
│   │   ├── stft_loss.py
│   │   └── waveform_loss.py
│   ├── models
│   │   ├── autoencoder
│   │   │   ├── AudioDec.py
│   │   │   └── modules
│   │   │       ├── decoder.py
│   │   │       ├── encoder.py
│   │   │       ├── projector.py
│   │   │       ├── quantizer.py
│   │   │       └── residual_unit.py
│   │   ├── utils.py
│   │   └── vocoder
│   │       ├── HiFiGAN.py
│   │       └── modules
│   │           ├── discriminator.py
│   │           ├── multi_fusion.py
│   │           └── residual_block.py
│   ├── parse_options.sh
│   ├── requirements.txt
│   ├── slurmlogs
│   │   └── README.md
│   ├── stats
│   │   ├── symAD_libritts_24000_hop300_clean.npy
│   │   ├── symAD_vctk_48000_hop300_clean.npy
│   │   └── symADuniv_vctk_48000_hop300_clean.npy
│   ├── trainer
│   │   ├── autoencoder.py
│   │   ├── denoise.py
│   │   ├── trainerGAN.py
│   │   └── vocoder.py
│   └── utils
│       └── audiodec.py
├── LICENSE
├── README.md
├── config
│   └── hparams.yaml
├── infer_tts.py
├── models
│   ├── ar.py
│   └── nar.py
├── requirements.txt
├── test
│   └── prompt_wavs
│       └── test_1.wav
├── text
│   ├── chinese.py
│   ├── chinese_dict
│   ├── opencpop-strict.txt
│   ├── symbols.py
│   └── tone_sandhi.py
└── utils
    ├── hparams.py
    └── utils.py
/AudioDec/About/README.md:
--------------------------------------------------------------------------------
1 | # AudioDec: An Open-source Streaming High-fidelity Neural Audio Codec
2 |
3 |
4 |
5 | ## Audio quality (Subjective MOS)
6 |
7 |
8 |
9 |
10 |
11 | ## Latency
12 |
13 |
14 |
15 |
16 |
17 |
18 | ## Model size
 19 | Only the generators are counted.
20 | ```bash
21 | # AutoEncoder (symAD)
22 | # - Encoder
 23 | Number of total parameters: 3,806,368
 24 | Model size: 14.68MB
 25 | # - Decoder
 26 | Number of total parameters: 4,035,264
 27 | Model size: 15.54MB
 28 | # Vocoder (AD v0)
 29 | Number of total parameters: 12,932,610
 30 | Model size: 49.74MB
 31 | # Vocoder (AD v1)
 32 | Number of total parameters: 19,461,090
 33 | Model size: 74.90MB
 34 | # Vocoder (AD v2)
 35 | Number of total parameters: 6,927,330
36 | Model size: 26.56MB
37 | ```
38 |
--------------------------------------------------------------------------------
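
A rough cross-check of the "Model size" figures above: parameter count × 4 bytes (float32) approximates the reported sizes. A minimal sketch, assuming PyTorch and an already-instantiated generator module (the `torch.nn.Linear` below is only a stand-in for the AudioDec encoder/decoder/vocoder generators):

```python
import torch

def count_params_and_size(module: torch.nn.Module):
    """Count parameters and estimate the float32 model size in MB."""
    n_params = sum(p.numel() for p in module.parameters())
    size_mb = n_params * 4 / (1024 ** 2)  # 4 bytes per float32 parameter
    return n_params, size_mb

# Stand-in module; with the real symAD encoder this arithmetic lands close to
# the 3,806,368 parameters / ~14.7MB reported above (buffers likely add a bit).
n, mb = count_params_and_size(torch.nn.Linear(1024, 1024))
print(f"Number of total parameters: {n:,}")
print(f"Model size: {mb:.2f}MB")
```
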
/AudioDec/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 |
--------------------------------------------------------------------------------
/AudioDec/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to AudioDec
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | AudioDec is the implementation of a research paper.
7 | Therefore, we do not plan on accepting many pull requests for new features.
8 | We certainly welcome them for bug fixes.
9 |
10 | 1. Fork the repo and create your branch from `main`.
11 | 2. If you've added code that should be tested, add tests.
12 | 3. If you've changed APIs, update the documentation.
13 | 4. Ensure the test suite passes.
14 | 5. Make sure your code lints.
15 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
16 |
17 | ## Contributor License Agreement ("CLA")
18 | In order to accept your pull request, we need you to submit a CLA. You only need
19 | to do this once to work on any of Meta's open source projects.
20 |
21 | Complete your CLA here:
22 |
23 | ## Issues
24 | We use GitHub issues to track public bugs. Please ensure your description is
25 | clear and has sufficient instructions to be able to reproduce the issue.
26 |
27 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
28 | disclosure of security bugs. In those cases, please go through the process
29 | outlined on that page and do not file a public issue.
30 |
31 | ## License
32 | By contributing to AudioDec, you agree that your contributions will be licensed
33 | under the LICENSE file in the root directory of this source tree.
34 |
--------------------------------------------------------------------------------
/AudioDec/bin/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | """Testing stage template."""
13 |
14 | import os
15 | import abc
16 | import sys
17 | import time
18 | import yaml
19 | import torch
20 | import logging
21 | import soundfile as sf
22 |
23 | from tqdm import tqdm
24 | from bin.utils import load_config
25 |
26 | class TestGEN(abc.ABC):
27 | def __init__(
28 | self,
29 | args,
30 | ):
31 | # set logger
32 | logging.basicConfig(
33 | level=logging.INFO,
34 | stream=sys.stdout,
35 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
36 | )
37 |
38 | # device
39 | if not torch.cuda.is_available():
40 | self.device = torch.device('cpu')
41 | logging.info(f"device: cpu")
42 | else:
43 | self.device = torch.device('cuda')
44 | logging.info(f"device: gpu")
45 |
46 | # initialize attribute
47 | if hasattr(args, 'encoder'):
48 | self.encoder_checkpoint = args.encoder
49 | self.encoder_config = self._load_config(args.encoder)
50 | if hasattr(args, 'decoder'):
51 | self.decoder_checkpoint = args.decoder
52 | self.decoder_config = self._load_config(args.decoder)
53 | self.encoder = None
54 | self.decoder = None
55 | self.dataset = None
56 | self.outdir = None
57 |
58 |
59 | @abc.abstractmethod
60 | def initial_folder(self, output_name):
61 | pass
62 |
63 |
64 | @abc.abstractmethod
65 | def load_dataset(self):
66 | pass
67 |
68 |
69 | @abc.abstractmethod
70 | def load_encoder(self):
71 | pass
72 |
73 |
74 | @abc.abstractmethod
75 | def load_decoder(self):
76 | pass
77 |
78 |
79 | @abc.abstractmethod
80 | def encode(self, x):
81 | pass
82 |
83 |
84 | @abc.abstractmethod
85 | def decode(self, z):
86 | pass
87 |
88 |
89 | def run(self):
90 | total_rtf = 0.0
91 | with torch.no_grad(), tqdm(self.dataset, desc="[test]") as pbar:
92 | for idx, (utt_id, x) in enumerate(pbar, 1):
93 | start = time.time()
94 | zq = self.encode(x)
95 | y = self.decode(zq)
96 | y = y.squeeze(1).transpose(1, 0).cpu().numpy() # T x C
97 | rtf = (time.time() - start) / (len(y) / self.decoder_config['sampling_rate'])
98 | pbar.set_postfix({"RTF": rtf})
99 | total_rtf += rtf
100 |
101 | # output wav file
102 | self._save_wav(os.path.join(self.outdir, f"{utt_id}_output.wav"), y)
103 |
104 | logging.info(
105 | "Finished generation of %d utterances (RTF = %.03f)." % (idx, (total_rtf / idx))
106 | )
107 |
108 |
109 | def _save_wav(self, file_name, audio):
110 | sf.write(
111 | file_name,
112 | audio,
113 | self.decoder_config['sampling_rate'],
114 | "PCM_16",
115 | )
116 |
117 |
118 | def _load_config(self, checkpoint, config_name='config.yml'):
119 | return load_config(checkpoint, config_name)
120 |
--------------------------------------------------------------------------------
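
bin/test.py only provides the skeleton; the model- and data-loading hooks are abstract. A hypothetical, minimal subclass (the class name, the identity "codec", and the dummy waveform below are for illustration only and are not part of the repo) that exercises the run()/RTF/_save_wav plumbing end to end:

```python
import os
import torch
from types import SimpleNamespace
from bin.test import TestGEN

class IdentityCodecTest(TestGEN):  # hypothetical subclass for illustration
    def initial_folder(self, output_name):
        self.outdir = os.path.join("exp", output_name)
        os.makedirs(self.outdir, exist_ok=True)

    def load_dataset(self):
        # one second of silence at 48 kHz as a stand-in for (utt_id, waveform) pairs
        self.dataset = [("utt_0001", torch.zeros(48000))]

    def load_encoder(self):
        self.encoder = torch.nn.Identity().to(self.device)

    def load_decoder(self):
        self.decoder = torch.nn.Identity().to(self.device)
        self.decoder_config = {"sampling_rate": 48000}  # normally read from config.yml

    def encode(self, x):
        # (T,) -> (B=1, C=1, T); run() later squeezes/transposes decode()'s output to T x C
        return self.encoder(x.view(1, 1, -1).to(self.device))

    def decode(self, z):
        return self.decoder(z)

test = IdentityCodecTest(SimpleNamespace())  # args without 'encoder'/'decoder' attributes
test.initial_folder("identity_demo")
test.load_dataset()
test.load_encoder()
test.load_decoder()
test.run()  # writes exp/identity_demo/utt_0001_output.wav and logs the RTF
```
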
/AudioDec/bin/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | """Training stage template."""
13 |
14 | import os
15 | import abc
16 | import sys
17 | import yaml
18 | import random
19 | import logging
20 | import torch
21 | import numpy as np
22 |
23 | from bin.utils import load_config
24 |
25 |
26 | class TrainGAN(abc.ABC):
27 | def __init__(
28 | self,
29 | args,
30 | ):
31 | # set logger
32 | logging.basicConfig(
33 | level=logging.INFO,
34 | stream=sys.stdout,
35 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
36 | )
37 |
38 | # Fix seed and make backends deterministic
39 | random.seed(args.seed)
40 | np.random.seed(args.seed)
41 | torch.manual_seed(args.seed)
42 | if not torch.cuda.is_available():
43 | self.device = torch.device('cpu')
44 | logging.info(f"device: cpu")
45 | else:
46 | self.device = torch.device('cuda')
47 | logging.info(f"device: gpu")
48 | torch.cuda.manual_seed_all(args.seed)
49 | if args.disable_cudnn == "False":
50 | torch.backends.cudnn.benchmark = True
51 |
52 | # initialize config
53 | with open(args.config, 'r') as f:
54 | self.config = yaml.load(f, Loader=yaml.FullLoader)
55 | self.config.update(vars(args))
56 |
57 | # initialize model folder
58 | expdir = os.path.join(args.exp_root, args.tag)
59 | os.makedirs(expdir, exist_ok=True)
60 | self.config["outdir"] = expdir
61 |
62 | # save config
63 | with open(os.path.join(expdir, "config.yml"), "w") as f:
64 | yaml.dump(self.config, f, Dumper=yaml.Dumper)
65 | for key, value in self.config.items():
66 | logging.info(f"[TrainGAN] {key} = {value}")
67 |
68 | # initialize attribute
69 | self.resume = args.resume
70 | self.data_loader = None
71 | self.model = {}
72 | self.criterion = None
73 | self.optimizer = None
74 | self.scheduler = None
75 | self.trainer = None
76 |
77 | # initialize batch_length
78 | self.batch_length = self.config['batch_length']
79 |
80 |
81 | @abc.abstractmethod
82 | def initialize_data_loader(self):
83 | pass
84 |
85 |
86 | @abc.abstractmethod
87 | def define_model(self):
88 | pass
89 |
90 |
91 | @abc.abstractmethod
92 | def define_trainer(self):
93 | pass
94 |
95 |
96 | @abc.abstractmethod
97 | def initialize_model(self):
98 | pass
99 |
100 |
101 | @abc.abstractmethod
102 | def define_criterion(self):
103 | pass
104 |
105 |
106 | def run(self):
107 | try:
108 | logging.info(f"The current training step: {self.trainer.steps}")
109 | self.trainer.train_max_steps = self.config["train_max_steps"]
110 | if not self.trainer._check_train_finish():
111 | self.trainer.run()
112 | if self.config.get("adv_train_max_steps", False) and self.config.get("adv_batch_length", False):
113 | self.batch_length = self.config['adv_batch_length']
114 | logging.info(f"Reload dataloader for adversarial training.")
115 | self.initialize_data_loader()
116 | self.trainer.data_loader = self.data_loader
117 | self.trainer.train_max_steps = self.config["adv_train_max_steps"]
118 | self.trainer.run()
119 | finally:
120 | self.trainer.save_checkpoint(
121 | os.path.join(self.config["outdir"], f"checkpoint-{self.trainer.steps}steps.pkl")
122 | )
123 | logging.info(f"Successfully saved checkpoint @ {self.trainer.steps}steps.")
124 |
125 |
126 | def _show_setting(self):
127 | logging.info(self.model['generator'])
128 | logging.info(self.model['discriminator'])
129 | logging.info(self.optimizer['generator'])
130 | logging.info(self.optimizer['discriminator'])
131 | logging.info(self.scheduler['generator'])
132 | logging.info(self.scheduler['discriminator'])
133 | for criterion_ in self.criterion.values():
134 | logging.info(criterion_)
135 |
136 |
137 | def _load_config(self, checkpoint, config_name='config.yml'):
138 | return load_config(checkpoint, config_name)
139 |
140 |
--------------------------------------------------------------------------------
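
TrainGAN reads only a handful of attributes from args: seed, config, tag, exp_root, resume, and disable_cudnn. A minimal sketch of an argument parser that satisfies the template (flag names and defaults here are assumptions; the repo's actual entry script may define more options):

```python
import argparse

parser = argparse.ArgumentParser(description="GAN-based codec training (sketch)")
parser.add_argument("--config", type=str, required=True,
                    help="YAML config, e.g. config/autoencoder/symAD_vctk_48000_hop300.yaml")
parser.add_argument("--tag", type=str, required=True,
                    help="experiment name; outputs go to <exp_root>/<tag>/")
parser.add_argument("--exp_root", type=str, default="exp")
parser.add_argument("--resume", type=str, default="",
                    help="checkpoint to resume from (empty string = train from scratch)")
parser.add_argument("--seed", type=int, default=1337)
parser.add_argument("--disable_cudnn", type=str, default="False",
                    help="the template compares this against the string 'False'")
args = parser.parse_args()
```
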
/AudioDec/bin/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 |
10 |
11 | """Utility modules."""
12 |
13 | import os
14 | import yaml
15 |
16 |
17 | def load_config(checkpoint, config_name='config.yml'):
18 | dirname = os.path.dirname(checkpoint)
19 | config_path = os.path.join(dirname, config_name)
20 | with open(config_path) as f:
21 | config = yaml.load(f, Loader=yaml.Loader)
22 | return config
23 |
--------------------------------------------------------------------------------
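
Usage sketch: load_config() resolves config.yml relative to the checkpoint's directory, so training-time settings can be recovered from a checkpoint path alone. The path below is the one referenced by the denoise config and assumes training has already produced that folder:

```python
from bin.utils import load_config

# config.yml is expected to live next to the checkpoint inside the experiment folder
ckpt = "exp/autoencoder/symAD_vctk_48000_hop300/checkpoint-200000steps.pkl"
config = load_config(ckpt)  # reads exp/autoencoder/symAD_vctk_48000_hop300/config.yml
print(config["sampling_rate"], config["generator_params"]["codebook_num"])
```
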
/AudioDec/config/autoencoder/symAAD_vctk_48000_hop300.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
8 |
9 |
10 | ###########################################################
11 | # DATA SETTING #
12 | ###########################################################
13 | sampling_rate: &sampling_rate 48000
14 | data:
15 | path: /mnt/home/yichiaowu/datasets/list
16 | subset:
17 | train: vctk_48000_clean_trainset_84spk.txt
18 | valid: vctk_48000_clean_validset_84spk.txt
19 | test: vctk_48000_clean_testset.txt
20 | vhstest: VHS_48000_raw_Phase1_test.txt
21 |
22 | ###########################################################
23 | # MODEL SETTING #
24 | ###########################################################
25 | model_type: symAudioDec
26 | train_mode: autoencoder
27 | paradigm: efficient
28 |
29 | generator_params:
30 | input_channels: 1
31 | output_channels: 1
32 | encode_channels: 32
33 | decode_channels: 32
34 | code_dim: 64
35 | codebook_num: 8
36 | codebook_size: 1024
37 | bias: true
38 | enc_ratios: [2, 4, 8, 16]
39 | dec_ratios: [16, 8, 4, 2]
40 | enc_strides: [3, 4, 5, 5]
41 | dec_strides: [5, 5, 4, 3]
42 | mode: causal
43 | codec: activate_audiodec
44 | projector: conv1d
45 | quantier: residual_vq
46 | use_weight_norm: true
47 |
48 | discriminator_params:
49 | scales: 3 # Number of multi-scale discriminator.
50 | scale_downsample_pooling: AvgPool1d # Pooling operation for scale discriminator.
51 | scale_downsample_pooling_params:
52 | kernel_size: 4 # Pooling kernel size.
53 | stride: 2 # Pooling stride.
54 | padding: 2 # Padding size.
55 | scale_discriminator_params:
56 | in_channels: 1 # Number of input channels.
57 | out_channels: 1 # Number of output channels.
58 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
59 | channels: 128 # Initial number of channels.
60 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
61 | max_groups: 16 # Maximum number of groups in downsampling conv layers.
62 | bias: true
63 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
64 | nonlinear_activation: LeakyReLU # Nonlinear activation.
65 | nonlinear_activation_params:
66 | negative_slope: 0.1
67 | follow_official_norm: true # Whether to follow the official norm setting.
68 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
69 | period_discriminator_params:
70 | in_channels: 1 # Number of input channels.
71 | out_channels: 1 # Number of output channels.
72 | kernel_sizes: [5, 3] # List of kernel sizes.
73 | channels: 32 # Initial number of channels.
74 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
75 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
76 | bias: true # Whether to use bias parameter in conv layer.
77 | nonlinear_activation: LeakyReLU # Nonlinear activation.
 78 | nonlinear_activation_params: # Nonlinear activation parameters.
79 | negative_slope: 0.1
80 | use_weight_norm: true # Whether to apply weight normalization.
81 | use_spectral_norm: false # Whether to apply spectral normalization.
82 |
83 | ###########################################################
84 | # METRIC LOSS SETTING #
85 | ###########################################################
86 | use_mel_loss: true # Whether to use Mel-spectrogram loss.
87 | mel_loss_params:
88 | fs: *sampling_rate
89 | fft_sizes: [2048]
90 | hop_sizes: [300]
91 | win_lengths: [2048]
92 | window: hann_window
93 | num_mels: 80
94 | fmin: 0
95 | fmax: 24000
96 | log_base: null
97 |
98 | use_stft_loss: false # Whether to use multi-resolution STFT loss.
99 | stft_loss_params:
100 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
101 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
102 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
103 | window: hann_window # Window function for STFT-based loss
104 |
105 | use_shape_loss: false # Whether to use waveform shape loss.
106 | shape_loss_params:
107 | winlen: [300]
108 |
109 | ###########################################################
110 | # ADV LOSS SETTING #
111 | ###########################################################
112 | generator_adv_loss_params:
113 | average_by_discriminators: false # Whether to average loss by #discriminators.
114 |
115 | discriminator_adv_loss_params:
116 | average_by_discriminators: false # Whether to average loss by #discriminators.
117 |
118 | use_feat_match_loss: true
119 | feat_match_loss_params:
120 | average_by_discriminators: false # Whether to average loss by #discriminators.
121 | average_by_layers: false # Whether to average loss by #layers in each discriminator.
122 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation.
123 |
124 | ###########################################################
125 | # LOSS WEIGHT SETTING #
126 | ###########################################################
127 | lambda_adv: 1.0 # Loss weight of adversarial loss.
128 | lambda_feat_match: 2.0 # Loss weight of feat match loss.
129 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss.
130 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram loss.
131 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss.
132 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss.
133 |
134 | ###########################################################
135 | # DATA LOADER SETTING #
136 | ###########################################################
137 | batch_size: 16 # Batch size.
138 | batch_length: 9600 # Length of each audio in batch (training w/o adv). Make sure divisible by hop_size.
139 | adv_batch_length: 9600 # Length of each audio in batch (training w/ adv). Make sure divisible by hop_size.
140 | pin_memory: true # Whether to pin memory in Pytorch DataLoader.
141 | num_workers: 2 # Number of workers in Pytorch DataLoader.
142 |
143 | ###########################################################
144 | # OPTIMIZER & SCHEDULER SETTING #
145 | ###########################################################
146 | generator_optimizer_type: Adam
147 | generator_optimizer_params:
148 | lr: 1.0e-4
149 | betas: [0.5, 0.9]
150 | weight_decay: 0.0
151 | generator_scheduler_type: StepLR
152 | generator_scheduler_params:
153 | step_size: 200000 # Generator scheduler step size.
154 | gamma: 1.0
155 | generator_grad_norm: -1
156 | discriminator_optimizer_type: Adam
157 | discriminator_optimizer_params:
158 | lr: 2.0e-4
159 | betas: [0.5, 0.9]
160 | weight_decay: 0.0
161 | discriminator_scheduler_type: MultiStepLR
162 | discriminator_scheduler_params:
163 | gamma: 0.5
164 | milestones:
165 | - 200000
166 | - 400000
167 | - 600000
168 | - 800000
169 | discriminator_grad_norm: -1
170 |
171 | ###########################################################
172 | # INTERVAL SETTING #
173 | ###########################################################
174 | start_steps: # Number of steps to start training
175 | generator: 0
176 | discriminator: 200000
177 | train_max_steps: 200000 # Number of training steps. (w/o adv)
178 | adv_train_max_steps: 700000 # Number of training steps. (w/ adv)
179 | save_interval_steps: 100000 # Interval steps to save checkpoint.
180 | eval_interval_steps: 1000 # Interval steps to evaluate the network.
181 | log_interval_steps: 100 # Interval steps to record the training log.
182 |
--------------------------------------------------------------------------------
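
For orientation, the encoder strides in this config determine the hop size (and thus the code frame rate), and together with the residual-VQ settings they determine the bitrate. A small worked check using plain arithmetic on the values above (not a repo API):

```python
import math

sampling_rate = 48000
enc_strides = [3, 4, 5, 5]   # total downsampling factor = hop size
codebook_num = 8
codebook_size = 1024         # 10 bits per codebook index

hop_size = math.prod(enc_strides)                          # 3*4*5*5 = 300 samples per code frame
frame_rate = sampling_rate / hop_size                      # 48000 / 300 = 160 frames per second
bits_per_frame = codebook_num * math.log2(codebook_size)   # 8 * 10 = 80 bits
bitrate_kbps = frame_rate * bits_per_frame / 1000          # 12.8 kbps
print(hop_size, frame_rate, bitrate_kbps)
```

The same arithmetic applied to symAD_c16_vctk_48000_hop320.yaml (strides 2·4·5·8 = 320, 16 codebooks) gives 150 frames/s and 24 kbps.
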
/AudioDec/config/autoencoder/symAD_c16_vctk_48000_hop320.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
8 |
9 |
10 | ###########################################################
11 | # DATA SETTING #
12 | ###########################################################
13 | sampling_rate: &sampling_rate 48000
14 | data:
15 | path: /mnt/home/yichiaowu/datasets/list
16 | subset:
17 | train: vctk_48000_clean_trainset_84spk.txt
18 | valid: vctk_48000_clean_validset_84spk.txt
19 | test: vctk_48000_clean_testset.txt
20 |
21 | ###########################################################
22 | # MODEL SETTING #
23 | ###########################################################
24 | model_type: symAudioDec
25 | train_mode: autoencoder
26 | paradigm: efficient
27 |
28 | generator_params:
29 | input_channels: 1
30 | output_channels: 1
31 | encode_channels: 32
32 | decode_channels: 32
33 | code_dim: 64
34 | codebook_num: 16
35 | codebook_size: 1024
36 | bias: true
37 | enc_ratios: [2, 4, 8, 16]
38 | dec_ratios: [16, 8, 4, 2]
39 | enc_strides: [2, 4, 5, 8]
40 | dec_strides: [8, 5, 4, 2]
41 | mode: causal
42 | codec: audiodec
43 | projector: conv1d
44 | quantier: residual_vq
45 |
46 | discriminator_params:
47 | scales: 3 # Number of multi-scale discriminator.
48 | scale_downsample_pooling: AvgPool1d # Pooling operation for scale discriminator.
49 | scale_downsample_pooling_params:
50 | kernel_size: 4 # Pooling kernel size.
51 | stride: 2 # Pooling stride.
52 | padding: 2 # Padding size.
53 | scale_discriminator_params:
54 | in_channels: 1 # Number of input channels.
55 | out_channels: 1 # Number of output channels.
56 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
57 | channels: 128 # Initial number of channels.
58 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
59 | max_groups: 16 # Maximum number of groups in downsampling conv layers.
60 | bias: true
61 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
62 | nonlinear_activation: LeakyReLU # Nonlinear activation.
63 | nonlinear_activation_params:
64 | negative_slope: 0.1
65 | follow_official_norm: true # Whether to follow the official norm setting.
66 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
67 | period_discriminator_params:
68 | in_channels: 1 # Number of input channels.
69 | out_channels: 1 # Number of output channels.
70 | kernel_sizes: [5, 3] # List of kernel sizes.
71 | channels: 32 # Initial number of channels.
72 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
73 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
74 | bias: true # Whether to use bias parameter in conv layer.
75 | nonlinear_activation: LeakyReLU # Nonlinear activation.
 76 | nonlinear_activation_params: # Nonlinear activation parameters.
77 | negative_slope: 0.1
78 | use_weight_norm: true # Whether to apply weight normalization.
79 | use_spectral_norm: false # Whether to apply spectral normalization.
80 |
81 | ###########################################################
82 | # METRIC LOSS SETTING #
83 | ###########################################################
84 | use_mel_loss: true # Whether to use Mel-spectrogram loss.
85 | mel_loss_params:
86 | fs: *sampling_rate
87 | fft_sizes: [2048]
88 | hop_sizes: [300]
89 | win_lengths: [2048]
90 | window: hann_window
91 | num_mels: 80
92 | fmin: 0
93 | fmax: 24000
94 | log_base: null
95 |
96 | use_stft_loss: false # Whether to use multi-resolution STFT loss.
97 | stft_loss_params:
98 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
99 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
100 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
101 | window: hann_window # Window function for STFT-based loss
102 |
103 | use_shape_loss: false # Whether to use waveform shape loss.
104 | shape_loss_params:
105 | winlen: [320]
106 |
107 | ###########################################################
108 | # ADV LOSS SETTING #
109 | ###########################################################
110 | generator_adv_loss_params:
111 | average_by_discriminators: false # Whether to average loss by #discriminators.
112 |
113 | discriminator_adv_loss_params:
114 | average_by_discriminators: false # Whether to average loss by #discriminators.
115 |
116 | use_feat_match_loss: true
117 | feat_match_loss_params:
118 | average_by_discriminators: false # Whether to average loss by #discriminators.
119 | average_by_layers: false # Whether to average loss by #layers in each discriminator.
120 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation.
121 |
122 | ###########################################################
123 | # LOSS WEIGHT SETTING #
124 | ###########################################################
125 | lambda_adv: 1.0 # Loss weight of adversarial loss.
126 | lambda_feat_match: 2.0 # Loss weight of feat match loss.
127 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss.
128 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram loss.
129 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss.
130 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss.
131 |
132 | ###########################################################
133 | # DATA LOADER SETTING #
134 | ###########################################################
135 | batch_size: 16 # Batch size.
136 | batch_length: 96000 # Length of each audio in batch (training w/o adv). Make sure divisible by hop_size.
137 | adv_batch_length: 9600 # Length of each audio in batch (training w/ adv). Make sure divisible by hop_size.
138 | pin_memory: true # Whether to pin memory in Pytorch DataLoader.
139 | num_workers: 2 # Number of workers in Pytorch DataLoader.
140 |
141 | ###########################################################
142 | # OPTIMIZER & SCHEDULER SETTING #
143 | ###########################################################
144 | generator_optimizer_type: Adam
145 | generator_optimizer_params:
146 | lr: 1.0e-4
147 | betas: [0.5, 0.9]
148 | weight_decay: 0.0
149 | generator_scheduler_type: StepLR
150 | generator_scheduler_params:
151 | step_size: 200000 # Generator scheduler step size.
152 | gamma: 1.0
153 | generator_grad_norm: -1
154 | discriminator_optimizer_type: Adam
155 | discriminator_optimizer_params:
156 | lr: 2.0e-4
157 | betas: [0.5, 0.9]
158 | weight_decay: 0.0
159 | discriminator_scheduler_type: MultiStepLR
160 | discriminator_scheduler_params:
161 | gamma: 0.5
162 | milestones:
163 | - 200000
164 | - 400000
165 | - 600000
166 | - 800000
167 | discriminator_grad_norm: -1
168 |
169 | ###########################################################
170 | # INTERVAL SETTING #
171 | ###########################################################
172 | start_steps: # Number of steps to start training
173 | generator: 0
174 | discriminator: 500000
175 | train_max_steps: 500000 # Number of training steps. (w/o adv)
176 | adv_train_max_steps: 1000000 # Number of training steps. (w/ adv)
177 | save_interval_steps: 100000 # Interval steps to save checkpoint.
178 | eval_interval_steps: 1000 # Interval steps to evaluate the network.
179 | log_interval_steps: 100 # Interval steps to record the training log.
180 |
--------------------------------------------------------------------------------
/AudioDec/config/autoencoder/symAD_libritts_24000_hop300.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
8 |
9 |
10 | ###########################################################
11 | # DATA SETTING #
12 | ###########################################################
13 | sampling_rate: &sampling_rate 24000
14 | data:
15 | path: /mnt/home/yichiaowu/datasets/LibriTTS/LibriTTS/24000
16 | subset:
17 | train: train-clean-450
18 | valid: dev-clean-1utt
19 | test: test-clean-1utt
20 |
21 | ###########################################################
22 | # MODEL SETTING #
23 | ###########################################################
24 | model_type: symAudioDec
25 | train_mode: autoencoder
26 | paradigm: efficient
27 |
28 | generator_params:
29 | input_channels: 1
30 | output_channels: 1
31 | encode_channels: 32
32 | decode_channels: 32
33 | code_dim: 64
34 | codebook_num: 8
35 | codebook_size: 1024
36 | bias: true
37 | enc_ratios: [2, 4, 8, 16]
38 | dec_ratios: [16, 8, 4, 2]
39 | enc_strides: [3, 4, 5, 5]
40 | dec_strides: [5, 5, 4, 3]
41 | mode: causal
42 | codec: audiodec
43 | projector: conv1d
44 | quantier: residual_vq
45 |
46 | discriminator_params:
47 | scales: 3 # Number of multi-scale discriminator.
48 | scale_downsample_pooling: AvgPool1d # Pooling operation for scale discriminator.
49 | scale_downsample_pooling_params:
50 | kernel_size: 4 # Pooling kernel size.
51 | stride: 2 # Pooling stride.
52 | padding: 2 # Padding size.
53 | scale_discriminator_params:
54 | in_channels: 1 # Number of input channels.
55 | out_channels: 1 # Number of output channels.
56 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
57 | channels: 128 # Initial number of channels.
58 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
59 | max_groups: 16 # Maximum number of groups in downsampling conv layers.
60 | bias: true
61 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
62 | nonlinear_activation: LeakyReLU # Nonlinear activation.
63 | nonlinear_activation_params:
64 | negative_slope: 0.1
65 | follow_official_norm: true # Whether to follow the official norm setting.
66 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
67 | period_discriminator_params:
68 | in_channels: 1 # Number of input channels.
69 | out_channels: 1 # Number of output channels.
70 | kernel_sizes: [5, 3] # List of kernel sizes.
71 | channels: 32 # Initial number of channels.
72 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
73 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
74 | bias: true # Whether to use bias parameter in conv layer.
75 | nonlinear_activation: LeakyReLU # Nonlinear activation.
 76 | nonlinear_activation_params: # Nonlinear activation parameters.
77 | negative_slope: 0.1
78 | use_weight_norm: true # Whether to apply weight normalization.
79 | use_spectral_norm: false # Whether to apply spectral normalization.
80 |
81 | ###########################################################
82 | # METRIC LOSS SETTING #
83 | ###########################################################
84 | use_mel_loss: true # Whether to use Mel-spectrogram loss.
85 | mel_loss_params:
86 | fs: *sampling_rate
87 | fft_sizes: [2048]
88 | hop_sizes: [300]
89 | win_lengths: [2048]
90 | window: hann_window
91 | num_mels: 80
92 | fmin: 0
93 | fmax: 12000
94 | log_base: null
95 |
96 | use_stft_loss: false # Whether to use multi-resolution STFT loss.
97 | stft_loss_params:
98 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
99 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
100 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
101 | window: hann_window # Window function for STFT-based loss
102 |
103 | use_shape_loss: false # Whether to use waveform shape loss.
104 | shape_loss_params:
105 | winlen: [300]
106 |
107 | ###########################################################
108 | # ADV LOSS SETTING #
109 | ###########################################################
110 | generator_adv_loss_params:
111 | average_by_discriminators: false # Whether to average loss by #discriminators.
112 |
113 | discriminator_adv_loss_params:
114 | average_by_discriminators: false # Whether to average loss by #discriminators.
115 |
116 | use_feat_match_loss: true
117 | feat_match_loss_params:
118 | average_by_discriminators: false # Whether to average loss by #discriminators.
119 | average_by_layers: false # Whether to average loss by #layers in each discriminator.
120 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation.
121 |
122 | ###########################################################
123 | # LOSS WEIGHT SETTING #
124 | ###########################################################
125 | lambda_adv: 1.0 # Loss weight of adversarial loss.
126 | lambda_feat_match: 2.0 # Loss weight of feat match loss.
127 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss.
128 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram loss.
129 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss.
130 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss.
131 |
132 | ###########################################################
133 | # DATA LOADER SETTING #
134 | ###########################################################
135 | batch_size: 16 # Batch size.
136 | batch_length: 9600 # Length of each audio in batch (training w/o adv). Make sure divisible by hop_size.
137 | adv_batch_length: 9600 # Length of each audio in batch (training w/ adv). Make sure divisible by hop_size.
138 | pin_memory: true # Whether to pin memory in Pytorch DataLoader.
139 | num_workers: 2 # Number of workers in Pytorch DataLoader.
140 |
141 | ###########################################################
142 | # OPTIMIZER & SCHEDULER SETTING #
143 | ###########################################################
144 | generator_optimizer_type: Adam
145 | generator_optimizer_params:
146 | lr: 1.0e-4
147 | betas: [0.5, 0.9]
148 | weight_decay: 0.0
149 | generator_scheduler_type: StepLR
150 | generator_scheduler_params:
151 | step_size: 200000 # Generator scheduler step size.
152 | gamma: 1.0
153 | generator_grad_norm: -1
154 | discriminator_optimizer_type: Adam
155 | discriminator_optimizer_params:
156 | lr: 2.0e-4
157 | betas: [0.5, 0.9]
158 | weight_decay: 0.0
159 | discriminator_scheduler_type: MultiStepLR
160 | discriminator_scheduler_params:
161 | gamma: 0.5
162 | milestones:
163 | - 200000
164 | - 400000
165 | - 600000
166 | - 800000
167 | discriminator_grad_norm: -1
168 |
169 | ###########################################################
170 | # INTERVAL SETTING #
171 | ###########################################################
172 | start_steps: # Number of steps to start training
173 | generator: 0
174 | discriminator: 500000
175 | train_max_steps: 500000 # Number of training steps. (w/o adv)
176 | adv_train_max_steps: 1000000 # Number of training steps. (w/ adv)
177 | save_interval_steps: 100000 # Interval steps to save checkpoint.
178 | eval_interval_steps: 1000 # Interval steps to evaluate the network.
179 | log_interval_steps: 100 # Interval steps to record the training log.
180 |
--------------------------------------------------------------------------------
/AudioDec/config/autoencoder/symAD_vctk_48000_hop300.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
8 |
9 |
10 | ###########################################################
11 | # DATA SETTING #
12 | ###########################################################
13 | sampling_rate: &sampling_rate 48000
14 | data:
15 | path: /mnt/home/yichiaowu/datasets/vctk_noisy/48000
16 | subset:
17 | train: clean_trainset_84spk_wav
18 | valid: clean_validset_84spk_wav
19 | test: clean_testset_wav
20 |
21 | ###########################################################
22 | # MODEL SETTING #
23 | ###########################################################
24 | model_type: symAudioDec
25 | train_mode: autoencoder
26 | paradigm: efficient
27 |
28 | generator_params:
29 | input_channels: 1
30 | output_channels: 1
31 | encode_channels: 32
32 | decode_channels: 32
33 | code_dim: 64
34 | codebook_num: 8
35 | codebook_size: 1024
36 | bias: true
37 | enc_ratios: [2, 4, 8, 16]
38 | dec_ratios: [16, 8, 4, 2]
39 | enc_strides: [3, 4, 5, 5]
40 | dec_strides: [5, 5, 4, 3]
41 | mode: causal
42 | codec: audiodec
43 | projector: conv1d
44 | quantier: residual_vq
45 |
46 | discriminator_params:
47 | scales: 3 # Number of multi-scale discriminator.
48 | scale_downsample_pooling: AvgPool1d # Pooling operation for scale discriminator.
49 | scale_downsample_pooling_params:
50 | kernel_size: 4 # Pooling kernel size.
51 | stride: 2 # Pooling stride.
52 | padding: 2 # Padding size.
53 | scale_discriminator_params:
54 | in_channels: 1 # Number of input channels.
55 | out_channels: 1 # Number of output channels.
56 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
57 | channels: 128 # Initial number of channels.
58 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
59 | max_groups: 16 # Maximum number of groups in downsampling conv layers.
60 | bias: true
61 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
62 | nonlinear_activation: LeakyReLU # Nonlinear activation.
63 | nonlinear_activation_params:
64 | negative_slope: 0.1
65 | follow_official_norm: true # Whether to follow the official norm setting.
66 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
67 | period_discriminator_params:
68 | in_channels: 1 # Number of input channels.
69 | out_channels: 1 # Number of output channels.
70 | kernel_sizes: [5, 3] # List of kernel sizes.
71 | channels: 32 # Initial number of channels.
72 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
73 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
74 | bias: true # Whether to use bias parameter in conv layer.
75 | nonlinear_activation: LeakyReLU # Nonlinear activation.
 76 | nonlinear_activation_params: # Nonlinear activation parameters.
77 | negative_slope: 0.1
78 | use_weight_norm: true # Whether to apply weight normalization.
79 | use_spectral_norm: false # Whether to apply spectral normalization.
80 |
81 | ###########################################################
82 | # METRIC LOSS SETTING #
83 | ###########################################################
84 | use_mel_loss: true # Whether to use Mel-spectrogram loss.
85 | mel_loss_params:
86 | fs: *sampling_rate
87 | fft_sizes: [2048]
88 | hop_sizes: [300]
89 | win_lengths: [2048]
90 | window: hann_window
91 | num_mels: 80
92 | fmin: 0
93 | fmax: 24000
94 | log_base: null
95 |
96 | use_stft_loss: false # Whether to use multi-resolution STFT loss.
97 | stft_loss_params:
98 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
99 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
100 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
101 | window: hann_window # Window function for STFT-based loss
102 |
103 | use_shape_loss: false # Whether to use waveform shape loss.
104 | shape_loss_params:
105 | winlen: [300]
106 |
107 | ###########################################################
108 | # ADV LOSS SETTING #
109 | ###########################################################
110 | generator_adv_loss_params:
111 | average_by_discriminators: false # Whether to average loss by #discriminators.
112 |
113 | discriminator_adv_loss_params:
114 | average_by_discriminators: false # Whether to average loss by #discriminators.
115 |
116 | use_feat_match_loss: true
117 | feat_match_loss_params:
118 | average_by_discriminators: false # Whether to average loss by #discriminators.
119 | average_by_layers: false # Whether to average loss by #layers in each discriminator.
120 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation.
121 |
122 | ###########################################################
123 | # LOSS WEIGHT SETTING #
124 | ###########################################################
125 | lambda_adv: 1.0 # Loss weight of adversarial loss.
126 | lambda_feat_match: 2.0 # Loss weight of feat match loss.
127 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss.
128 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram loss.
129 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss.
130 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss.
131 |
132 | ###########################################################
133 | # DATA LOADER SETTING #
134 | ###########################################################
135 | batch_size: 16 # Batch size.
136 | batch_length: 9600 # Length of each audio in batch (training w/o adv). Make sure divisible by hop_size.
137 | adv_batch_length: 9600 # Length of each audio in batch (training w/ adv). Make sure divisible by hop_size.
138 | pin_memory: true # Whether to pin memory in Pytorch DataLoader.
139 | num_workers: 2 # Number of workers in Pytorch DataLoader.
140 |
141 | ###########################################################
142 | # OPTIMIZER & SCHEDULER SETTING #
143 | ###########################################################
144 | generator_optimizer_type: Adam
145 | generator_optimizer_params:
146 | lr: 1.0e-4
147 | betas: [0.5, 0.9]
148 | weight_decay: 0.0
149 | generator_scheduler_type: StepLR
150 | generator_scheduler_params:
151 | step_size: 200000 # Generator scheduler step size.
152 | gamma: 1.0
153 | generator_grad_norm: -1
154 | discriminator_optimizer_type: Adam
155 | discriminator_optimizer_params:
156 | lr: 2.0e-4
157 | betas: [0.5, 0.9]
158 | weight_decay: 0.0
159 | discriminator_scheduler_type: MultiStepLR
160 | discriminator_scheduler_params:
161 | gamma: 0.5
162 | milestones:
163 | - 200000
164 | - 400000
165 | - 600000
166 | - 800000
167 | discriminator_grad_norm: -1
168 |
169 | ###########################################################
170 | # INTERVAL SETTING #
171 | ###########################################################
172 | start_steps: # Number of steps to start training
173 | generator: 0
174 | discriminator: 200000
175 | train_max_steps: 200000 # Number of training steps. (w/o adv)
176 | adv_train_max_steps: 700000 # Number of training steps. (w/ adv)
177 | save_interval_steps: 100000 # Interval steps to save checkpoint.
178 | eval_interval_steps: 1000 # Interval steps to evaluate the network.
179 | log_interval_steps: 100 # Interval steps to record the training log.
180 |
--------------------------------------------------------------------------------
/AudioDec/config/autoencoder/symADuniv_vctk_48000_hop300.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
8 | # Reference (https://github.com/chomeyama/SiFiGAN)
9 |
10 |
11 | ###########################################################
12 | # DATA SETTING #
13 | ###########################################################
14 | sampling_rate: &sampling_rate 48000
15 | data:
16 | path: /mnt/home/yichiaowu/datasets/vctk_noisy/48000
17 | subset:
18 | train: clean_trainset_84spk_wav
19 | valid: clean_validset_84spk_wav
20 | test: clean_testset_wav
21 |
22 | ###########################################################
23 | # MODEL SETTING #
24 | ###########################################################
25 | model_type: symAudioDecUniv
26 | train_mode: autoencoder
27 | paradigm: efficient
28 |
29 | generator_params:
30 | input_channels: 1
31 | output_channels: 1
32 | encode_channels: 32
33 | decode_channels: 32
34 | code_dim: 64
35 | codebook_num: 8
36 | codebook_size: 1024
37 | bias: true
38 | enc_ratios: [2, 4, 8, 16]
39 | dec_ratios: [16, 8, 4, 2]
40 | enc_strides: [3, 4, 5, 5]
41 | dec_strides: [5, 5, 4, 3]
42 | mode: causal
43 | codec: audiodec
44 | projector: conv1d
45 | quantier: residual_vq
46 |
47 | discriminator_params:
48 | fft_sizes: [1024, 2048, 512] # FFT sizes for each spectral discriminator.
49 | hop_sizes: [120, 240, 50] # Hop sizes for each spectral discriminator.
50 | win_lengths: [600, 1200, 240] # Window lengths for each spectral discriminator.
51 | window: "hann_window" # Name of window function.
52 | spectral_discriminator_params: # Params for UnivNet spectral discriminator.
53 | channels: 32 # Number of channels for conv layer.
 54 | kernel_sizes: # List of kernel sizes in down-sampling CNNs.
55 | - [3, 9]
56 | - [3, 9]
57 | - [3, 9]
58 | - [3, 9]
59 | - [3, 3]
60 | - [3, 3]
 61 | strides: # List of stride sizes in down-sampling CNNs.
62 | - [1, 1]
63 | - [1, 2]
64 | - [1, 2]
65 | - [1, 2]
66 | - [1, 1]
67 | - [1, 1]
68 | bias: true # Whether to add bias parameter in convolution layers.
69 | nonlinear_activation: "LeakyReLU" # Nonlinear activation.
 70 | nonlinear_activation_params: # Nonlinear activation parameters.
71 | negative_slope: 0.2
72 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
73 | period_discriminator_params:
74 | in_channels: 1 # Number of input channels.
75 | out_channels: 1 # Number of output channels.
76 | kernel_sizes: [5, 3] # List of kernel sizes.
77 | channels: 32 # Initial number of channels.
78 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
79 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
 80 | bias: true # Whether to use bias parameter in conv layer.
81 | nonlinear_activation: "LeakyReLU" # Nonlinear activation.
 82 | nonlinear_activation_params: # Nonlinear activation parameters.
83 | negative_slope: 0.1
84 | use_weight_norm: true # Whether to apply weight normalization.
85 | use_spectral_norm: false # Whether to apply spectral normalization.
86 |
87 | ###########################################################
88 | # METRIC LOSS SETTING #
89 | ###########################################################
90 | use_mel_loss: true # Whether to use Mel-spectrogram loss.
91 | mel_loss_params:
92 | fs: *sampling_rate
93 | fft_sizes: [2048]
94 | hop_sizes: [300]
95 | win_lengths: [2048]
96 | window: "hann_window"
97 | num_mels: 80
98 | fmin: 0
99 | fmax: 24000
100 | log_base: null
101 |
102 | use_stft_loss: false # Whether to use multi-resolution STFT loss.
103 | stft_loss_params:
104 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
105 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
106 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
107 | window: "hann_window" # Window function for STFT-based loss
108 |
109 | use_shape_loss: false # Whether to use waveform shape loss.
110 | shape_loss_params:
111 | winlen: [300]
112 |
113 | ###########################################################
114 | # ADV LOSS SETTING #
115 | ###########################################################
116 | generator_adv_loss_params:
117 | average_by_discriminators: false # Whether to average loss by #discriminators.
118 |
119 | discriminator_adv_loss_params:
120 | average_by_discriminators: false # Whether to average loss by #discriminators.
121 |
122 | use_feat_match_loss: true
123 | feat_match_loss_params:
124 | average_by_discriminators: false # Whether to average loss by #discriminators.
125 | average_by_layers: false # Whether to average loss by #layers in each discriminator.
126 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation.
127 |
128 | ###########################################################
129 | # LOSS WEIGHT SETTING #
130 | ###########################################################
131 | lambda_adv: 1.0 # Loss weight of adversarial loss.
132 | lambda_feat_match: 2.0 # Loss weight of feat match loss.
133 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss.
134 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram loss.
135 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss.
136 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss.
137 |
138 | ###########################################################
139 | # DATA LOADER SETTING #
140 | ###########################################################
141 | batch_size: 16 # Batch size.
142 | batch_length: 9600 # Length of each audio in batch (training w/o adv). Make sure divisible by hop_size.
143 | adv_batch_length: 9600 # Length of each audio in batch (training w/ adv). Make sure divisible by hop_size.
144 | pin_memory: true # Whether to pin memory in Pytorch DataLoader.
145 | num_workers: 2 # Number of workers in Pytorch DataLoader.
146 |
147 | ###########################################################
148 | # OPTIMIZER & SCHEDULER SETTING #
149 | ###########################################################
150 | generator_optimizer_type: Adam
151 | generator_optimizer_params:
152 | lr: 1.0e-4
153 | betas: [0.5, 0.9]
154 | weight_decay: 0.0
155 | generator_scheduler_type: StepLR
156 | generator_scheduler_params:
157 | step_size: 200000 # Generator scheduler step size.
158 | gamma: 1.0
159 | generator_grad_norm: -1
160 | discriminator_optimizer_type: Adam
161 | discriminator_optimizer_params:
162 | lr: 2.0e-4
163 | betas: [0.5, 0.9]
164 | weight_decay: 0.0
165 | discriminator_scheduler_type: MultiStepLR
166 | discriminator_scheduler_params:
167 | gamma: 0.5
168 | milestones:
169 | - 200000
170 | - 400000
171 | - 600000
172 | - 800000
173 | discriminator_grad_norm: -1
174 |
175 | ###########################################################
176 | # INTERVAL SETTING #
177 | ###########################################################
178 | start_steps: # Number of steps to start training
179 | generator: 0
180 | discriminator: 500000
181 | train_max_steps: 500000 # Number of training steps. (w/o adv)
182 | adv_train_max_steps: 1000000 # Number of training steps. (w/ adv)
183 | save_interval_steps: 100000 # Interval steps to save checkpoint.
184 | eval_interval_steps: 1000 # Interval steps to evaluate the network.
185 | log_interval_steps: 100 # Interval steps to record the training log.
186 |
--------------------------------------------------------------------------------
/AudioDec/config/denoise/symAD_vctk_48000_hop300.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
8 |
9 |
10 | ###########################################################
11 | # DATA SETTING #
12 | ###########################################################
13 | sampling_rate: &sampling_rate 48000
14 | data:
15 | path: /mnt/home/yichiaowu/datasets/vctk_noisy/48000
16 | subset:
17 | clean_train: clean_trainset_84spk_wav
18 | clean_valid: clean_validset_84spk_wav
19 | clean_test: clean_testset_wav
20 | noisy_train: noisy_trainset_84spk_wav
21 | noisy_valid: noisy_validset_84spk_wav
22 | noisy_test: noisy_testset_wav
23 |
24 | ###########################################################
25 | # MODEL SETTING #
26 | ###########################################################
27 | model_type: symAudioDec
28 | train_mode: denoise
29 | initial: exp/autoencoder/symAD_vctk_48000_hop300/checkpoint-200000steps.pkl # for model initialization
30 |
31 | generator_params:
32 | input_channels: 1
33 | output_channels: 1
34 | encode_channels: 32
35 | decode_channels: 32
36 | code_dim: 64
37 | codebook_num: 8
38 | codebook_size: 1024
39 | bias: true
40 | enc_ratios: [2, 4, 8, 16]
41 | dec_ratios: [16, 8, 4, 2]
42 | enc_strides: [3, 4, 5, 5]
43 | dec_strides: [5, 5, 4, 3]
44 | mode: causal
45 | codec: audiodec
46 | projector: conv1d
47 | quantier: residual_vq
48 |
49 | discriminator_params:
50 | scales: 3 # Number of multi-scale discriminator.
51 | scale_downsample_pooling: AvgPool1d # Pooling operation for scale discriminator.
52 | scale_downsample_pooling_params:
53 | kernel_size: 4 # Pooling kernel size.
54 | stride: 2 # Pooling stride.
55 | padding: 2 # Padding size.
56 | scale_discriminator_params:
57 | in_channels: 1 # Number of input channels.
58 | out_channels: 1 # Number of output channels.
59 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
60 | channels: 128 # Initial number of channels.
61 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
62 | max_groups: 16 # Maximum number of groups in downsampling conv layers.
63 | bias: true
64 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
65 | nonlinear_activation: LeakyReLU # Nonlinear activation.
66 | nonlinear_activation_params:
67 | negative_slope: 0.1
68 | follow_official_norm: true # Whether to follow the official norm setting.
69 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
70 | period_discriminator_params:
71 | in_channels: 1 # Number of input channels.
72 | out_channels: 1 # Number of output channels.
73 | kernel_sizes: [5, 3] # List of kernel sizes.
74 | channels: 32 # Initial number of channels.
75 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
76 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
77 | bias: true # Whether to use bias parameter in conv layer.
78 | nonlinear_activation: LeakyReLU # Nonlinear activation.
80 | nonlinear_activation_params: # Nonlinear activation parameters.
80 | negative_slope: 0.1
81 | use_weight_norm: true # Whether to apply weight normalization.
82 | use_spectral_norm: false # Whether to apply spectral normalization.
83 |
84 | ###########################################################
85 | # METRIC LOSS SETTING #
86 | ###########################################################
87 | use_mel_loss: true # Whether to use Mel-spectrogram loss.
88 | mel_loss_params:
89 | fs: *sampling_rate
90 | fft_sizes: [2048]
91 | hop_sizes: [300]
92 | win_lengths: [null]
93 | window: hann_window
94 | num_mels: 80
95 | fmin: 0
96 | fmax: 24000
97 | log_base: null
98 |
99 | use_stft_loss: false # Whether to use multi-resolution STFT loss.
100 | stft_loss_params:
101 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
102 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
103 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
104 | window: hann_window # Window function for STFT-based loss
105 |
106 | use_shape_loss: false # Whether to use waveform shape loss.
107 | shape_loss_params:
108 | winlen: [300]
109 |
110 | ###########################################################
111 | # ADV LOSS SETTING #
112 | ###########################################################
113 | generator_adv_loss_params:
114 | average_by_discriminators: false # Whether to average loss by #discriminators.
115 |
116 | discriminator_adv_loss_params:
117 | average_by_discriminators: false # Whether to average loss by #discriminators.
118 |
119 | use_feat_match_loss: true
120 | feat_match_loss_params:
121 | average_by_discriminators: false # Whether to average loss by #discriminators.
122 | average_by_layers: false # Whether to average loss by #layers in each discriminator.
123 | include_final_outputs: false # Whether to include final outputs in feat match loss calculation.
124 |
125 | ###########################################################
126 | # LOSS WEIGHT SETTING #
127 | ###########################################################
128 | lambda_adv: 1.0 # Loss weight of adversarial loss.
129 | lambda_feat_match: 2.0 # Loss weight of feat match loss.
130 | lambda_vq_loss: 1.0 # Loss weight of vector quantize loss.
131 | lambda_mel_loss: 45.0 # Loss weight of mel-spectrogram loss.
132 | lambda_stft_loss: 45.0 # Loss weight of multi-resolution stft loss.
133 | lambda_shape_loss: 45.0 # Loss weight of multi-window shape loss.
134 |
135 | ###########################################################
136 | # DATA LOADER SETTING #
137 | ###########################################################
138 | batch_size: 16 # Batch size.
139 | batch_length: 96000 # Length of each audio in batch. Make sure it is divisible by hop_size.
140 | pin_memory: true # Whether to pin memory in PyTorch DataLoader.
141 | num_workers: 2 # Number of workers in PyTorch DataLoader.
142 |
143 | ###########################################################
144 | # OPTIMIZER & SCHEDULER SETTING #
145 | ###########################################################
146 | generator_optimizer_type: Adam
147 | generator_optimizer_params:
148 | lr: 1.0e-4
149 | betas: [0.5, 0.9]
150 | weight_decay: 0.0
151 | generator_scheduler_type: StepLR
152 | generator_scheduler_params:
153 | step_size: 200000 # Generator scheduler step size.
154 | gamma: 1.0
155 | generator_grad_norm: -1
156 | discriminator_optimizer_type: Adam
157 | discriminator_optimizer_params:
158 | lr: 2.0e-4
159 | betas: [0.5, 0.9]
160 | weight_decay: 0.0
161 | discriminator_scheduler_type: MultiStepLR
162 | discriminator_scheduler_params:
163 | gamma: 0.5
164 | milestones:
165 | - 200000
166 | - 400000
167 | - 600000
168 | - 800000
169 | discriminator_grad_norm: -1
170 |
171 | ###########################################################
172 | # INTERVAL SETTING #
173 | ###########################################################
174 | start_steps: # Number of steps to start training
175 | generator: 0
176 | discriminator: 200000
177 | train_max_steps: 200000 # Number of training steps.
178 | save_interval_steps: 100000 # Interval steps to save checkpoint.
179 | eval_interval_steps: 1000 # Interval steps to evaluate the network.
180 | log_interval_steps: 100 # Interval steps to record the training log.
181 |
--------------------------------------------------------------------------------
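As a quick sanity check of the config layout above, the YAML can be loaded directly with PyYAML; the `&sampling_rate` anchor is resolved on load, so `mel_loss_params.fs` mirrors the top-level `sampling_rate`. This is only an illustrative sketch of inspecting the file; the actual loading logic lives in `bin/train.py` and is not reproduced here.

```python
import yaml

# Illustrative only: load the denoise config shown above and inspect a few fields.
with open("config/denoise/symAD_vctk_48000_hop300.yaml") as f:
    config = yaml.safe_load(f)

print(config["sampling_rate"])                      # 48000
print(config["generator_params"]["codebook_num"])   # 8
print(config["mel_loss_params"]["fs"])              # 48000, via the *sampling_rate alias
```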
/AudioDec/config/statistic/symAD_libritts_24000_hop300_clean.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
8 |
9 |
10 | ###########################################################
11 | # DATA SETTING #
12 | ###########################################################
13 | sampling_rate: 24000
14 | data:
15 | path: /mnt/home/yichiaowu/datasets/LibriTTS/LibriTTS/24000
16 | subset:
17 | train: train-clean-450
18 | valid: dev-clean-1utt
19 | test: test-clean-1utt
20 |
21 | ###########################################################
22 | # STATISTIC SETTING #
23 | ###########################################################
24 | analyzer: exp/autoencoder/symAD_libritts_24000_hop300/checkpoint-500000steps.pkl
25 | stats: stats/symAD_libritts_24000_hop300_clean.npy
--------------------------------------------------------------------------------
/AudioDec/config/statistic/symAD_vctk_48000_hop300_clean.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
8 |
9 |
10 | ###########################################################
11 | # DATA SETTING #
12 | ###########################################################
13 | sampling_rate: 48000
14 | data:
15 | path: /mnt/home/yichiaowu/datasets/vctk_noisy/48000
16 | subset:
17 | train: clean_trainset_84spk_wav
18 | valid: clean_validset_84spk_wav
19 | test: clean_testset_wav
20 |
21 | ###########################################################
22 | # STATISTIC SETTING #
23 | ###########################################################
24 | analyzer: exp/autoencoder/symAD_vctk_48000_hop300/checkpoint-200000steps.pkl
25 | stats: stats/symAD_vctk_48000_hop300_clean.npy
26 |
--------------------------------------------------------------------------------
/AudioDec/config/statistic/symADuniv_vctk_48000_hop300_clean.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
8 |
9 |
10 | ###########################################################
11 | # DATA SETTING #
12 | ###########################################################
13 | sampling_rate: 48000
14 | data:
15 | path: /mnt/home/yichiaowu/datasets/vctk_noisy/48000
16 | subset:
17 | train: clean_trainset_84spk_wav
18 | valid: clean_validset_84spk_wav
19 | test: clean_testset_wav
20 |
21 | ###########################################################
22 | # STATISTIC SETTING #
23 | ###########################################################
24 | analyzer: exp/autoencoder/symADuniv_vctk_48000_hop300/checkpoint-500000steps.pkl
25 | stats: stats/symADuniv_vctk_48000_hop300_clean.npy
26 |
--------------------------------------------------------------------------------
/AudioDec/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import * # NOQA
2 | from .collater import * # NOQA
3 | from .utils import * # NOQA
--------------------------------------------------------------------------------
/AudioDec/dataloader/collater.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | """Customized collater modules for Pytorch DataLoader."""
13 |
14 | import torch
15 | import numpy as np
16 |
17 |
18 | class CollaterAudio(object):
19 | """Customized collater for loading single audio."""
20 |
21 | def __init__(
22 | self,
23 | batch_length=9600,
24 | ):
25 | """
26 | Args:
27 | batch_length (int): The length of audio signal batch.
28 |
29 | """
30 | self.batch_length = batch_length
31 |
32 |
33 | def __call__(self, batch):
34 | # filter short batch
35 | xs = [b for b in batch if len(b) > self.batch_length]
36 |
37 | # random cut
38 | starts, ends = self._random_segment(xs)
39 | x_batch = self._cut(xs, starts, ends)
40 |
41 | return x_batch
42 |
43 |
44 | def _random_segment(self, xs):
45 | x_lengths = [len(x) for x in xs]
46 | start_offsets = np.array(
47 | [
48 | np.random.randint(0, xl - self.batch_length)
49 | for xl in x_lengths
50 | ]
51 | )
52 | starts = start_offsets
53 | ends = starts + self.batch_length
54 | return starts, ends
55 |
56 |
57 | def _cut(self, xs, starts, ends):
58 | x_batch = np.array([x[start:end] for x, start, end in zip(xs, starts, ends)])
59 | x_batch = torch.tensor(x_batch, dtype=torch.float).transpose(2, 1) # (B, C, T)
60 | return x_batch
61 |
62 |
63 | class CollaterAudioPair(CollaterAudio):
64 | """Customized collater for loading audio pair."""
65 |
66 | def __init__(
67 | self,
68 | batch_length=9600,
69 | ):
70 | super().__init__(
71 | batch_length=batch_length
72 | )
73 |
74 |
75 | def __call__(self, batch):
76 | batch = [
77 | b for b in batch if (len(b[0]) > self.batch_length) and (len(b[0]) == len(b[1]))
78 | ]
79 | assert len(batch) > 0, "No qualified audio pairs!"
80 | xs, ns = [b[0] for b in batch], [b[1] for b in batch]
81 |
82 | # random cut
83 | starts, ends = self._random_segment(xs)
84 | x_batch = self._cut(xs, starts, ends)
85 | n_batch = self._cut(ns, starts, ends)
86 |
87 | return n_batch, x_batch # (input, output)
88 |
89 |
90 |
--------------------------------------------------------------------------------
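A minimal usage sketch of `CollaterAudio` (assuming mono clips shaped `(T, C)`, as `sf.read(..., always_2d=True)` returns them, and the package-internal import style used by `dataset.py`): clips no longer than `batch_length` are dropped, the rest are randomly cropped and stacked into a `(B, C, T)` float tensor. The arrays below are synthetic stand-ins for real audio.

```python
import numpy as np
from dataloader.collater import CollaterAudio

collater = CollaterAudio(batch_length=9600)
# Two mono clips of different lengths, shaped (T, C); both exceed batch_length so both are kept.
batch = [np.random.randn(48000, 1), np.random.randn(24000, 1)]
x = collater(batch)
print(x.shape)  # torch.Size([2, 1, 9600])
```

In training this object is passed as `collate_fn` to a `torch.utils.data.DataLoader` built on the datasets defined below.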
/AudioDec/dataloader/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | """PyTorch compatible dataset modules."""
13 |
14 | import os
15 | import soundfile as sf
16 | from torch.utils.data import Dataset
17 | from dataloader.utils import find_files
18 |
19 |
20 | class SingleDataset(Dataset):
21 | def __init__(
22 | self,
23 | files,
24 | query="*.wav",
25 | load_fn=sf.read,
26 | return_utt_id=False,
27 | subset_num=-1,
28 | ):
29 | self.return_utt_id = return_utt_id
30 | self.load_fn = load_fn
31 | self.subset_num = subset_num
32 |
33 | self.filenames = self._load_list(files, query)
34 | self.utt_ids = self._load_ids(self.filenames)
35 |
36 |
37 | def __getitem__(self, idx):
38 | utt_id = self.utt_ids[idx]
39 | data = self._data(idx)
40 |
41 | if self.return_utt_id:
42 | items = utt_id, data
43 | else:
44 | items = data
45 |
46 | return items
47 |
48 |
49 | def __len__(self):
50 | return len(self.filenames)
51 |
52 |
53 | def _read_list(self, listfile):
54 | filenames = []
55 | with open(listfile) as f:
56 | for line in f:
57 | line = line.strip()
58 | if len(line):
59 | filenames.append(line)
60 | return filenames
61 |
62 |
63 | def _load_list(self, files, query):
64 | if isinstance(files, list):
65 | filenames = files
66 | else:
67 | if os.path.isdir(files):
68 | filenames = sorted(find_files(files, query))
69 | elif os.path.isfile(files):
70 | filenames = sorted(self._read_list(files))
71 | else:
72 | raise ValueError(f"{files} is not a list / existing folder or file!")
73 |
74 | if self.subset_num > 0:
75 | filenames = filenames[:self.subset_num]
76 | assert len(filenames) != 0, "File list is empty!"
77 | return filenames
78 |
79 |
80 | def _load_ids(self, filenames):
81 | utt_ids = [
82 | os.path.splitext(os.path.basename(f))[0] for f in filenames
83 | ]
84 | return utt_ids
85 |
86 |
87 | def _data(self, idx):
88 | return self._load_data(self.filenames[idx], self.load_fn)
89 |
90 |
91 | def _load_data(self, filename, load_fn):
92 | if load_fn == sf.read:
93 | data, _ = load_fn(filename, always_2d=True) # (T, C)
94 | else:
95 | data = load_fn(filename)
96 | return data
97 |
98 |
99 | class MultiDataset(SingleDataset):
100 | def __init__(
101 | self,
102 | multi_files,
103 | queries,
104 | load_fns,
105 | return_utt_id=False,
106 | subset_num=-1,
107 | ):
108 | errmsg = f"multi_files({len(multi_files)}), queries({len(queries)}), and load_fns({len(load_fns)}) have mismatched lengths!"
109 | assert len(multi_files) == len(queries) == len(load_fns), errmsg
110 | super(MultiDataset, self).__init__(
111 | files=multi_files,
112 | query=queries,
113 | load_fn=load_fns,
114 | return_utt_id=return_utt_id,
115 | subset_num=subset_num,
116 | )
117 | self._check_length(self.filenames)
118 |
119 |
120 | def _load_list(self, multi_files, queries):
121 | multi_filenames = []
122 | if isinstance(multi_files, list):
123 | for files, query in zip(multi_files, queries):
124 | multi_filenames.append(super()._load_list(files, query))
125 | else:
126 | raise ValueError(f"{multi_files} should be a list!")
127 |
128 | return multi_filenames
129 |
130 |
131 | def _load_ids(self, multi_filenames):
132 | return super()._load_ids(multi_filenames[0])
133 |
134 |
135 | def _data(self, idx):
136 | filenames = [
137 | f[idx] for f in self.filenames
138 | ]
139 | data = []
140 | for filename, load_fn in zip(filenames, self.load_fn):
141 | data.append(self._load_data(filename, load_fn))
142 | return data
143 |
144 |
145 | def _check_length(self, multi_filenames):
146 | errmsg = "Not all lists have the same number of files!"
147 | self.file_num = len(multi_filenames[0])
148 | assert all(len(x)==self.file_num for x in multi_filenames), errmsg
149 |
150 |
151 | def __len__(self):
152 | return self.file_num
153 |
154 |
--------------------------------------------------------------------------------
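A sketch of `SingleDataset` pointed at a directory of wav files (the path below is a placeholder). With `return_utt_id=True` each item is an `(utt_id, audio)` pair, where `audio` is a `(T, C)` numpy array because `soundfile.read` is called with `always_2d=True`.

```python
import soundfile as sf
from dataloader.dataset import SingleDataset

dataset = SingleDataset(
    files="/path/to/clean_trainset_84spk_wav",  # placeholder directory of wav files
    query="*.wav",
    load_fn=sf.read,
    return_utt_id=True,
)
utt_id, audio = dataset[0]
print(utt_id, audio.shape)  # utt_id is the filename stem; audio has shape (T, 1) for mono wavs
```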
/AudioDec/dataloader/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | import os
13 | import fnmatch
14 | import logging
15 | import numpy as np
16 |
17 |
18 | def find_files(root_dir, query="*.wav", include_root_dir=True):
19 | """Find files recursively.
20 | Args:
21 | root_dir (str): Root directory to search.
22 | query (str): Filename pattern to match.
23 | include_root_dir (bool): If False, the root_dir prefix is not included in the returned paths.
24 | Returns:
25 | list: List of found filenames.
26 | """
27 | files = []
28 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True):
29 | for filename in fnmatch.filter(filenames, query):
30 | files.append(os.path.join(root, filename))
31 | if not include_root_dir:
32 | files = [file_.replace(root_dir + "/", "") for file_ in files]
33 |
34 | return files
35 |
36 |
37 | def load_files(data_path, query="*.wav", num_core=40):
38 | # sort all files
39 | file_list = sorted(find_files(data_path, query))
40 | logging.info(f"The number of {os.path.basename(data_path)} files = {len(file_list)}.")
41 | # divide
42 | if num_core < len(file_list):
43 | file_lists = np.array_split(file_list, num_core)
44 | file_lists = [f_list.tolist() for f_list in file_lists]
45 | else:
46 | file_lists = [file_list]
47 | return file_lists
--------------------------------------------------------------------------------
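The two helpers cover the common preprocessing pattern: `find_files` walks a directory recursively, and `load_files` sorts the result and splits it into chunks for parallel workers. A short sketch with a placeholder corpus path:

```python
from dataloader.utils import find_files, load_files

wav_files = find_files("/path/to/corpus", query="*.wav")            # recursive, unsorted
file_lists = load_files("/path/to/corpus", query="*.wav", num_core=40)
# file_lists is a list of up to 40 sorted sub-lists, each handled by one worker process.
```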
/AudioDec/figs/architecture-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/figs/architecture-1.png
--------------------------------------------------------------------------------
/AudioDec/figs/latency.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/figs/latency.jpg
--------------------------------------------------------------------------------
/AudioDec/figs/mos.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/figs/mos.jpg
--------------------------------------------------------------------------------
/AudioDec/layers/activation_function.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 |
10 | """Activation functions."""
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 |
17 |
18 | def get_activation(nonlinear_activation, nonlinear_activation_params={}):
19 | if hasattr(nn, nonlinear_activation):
20 | return getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
21 | else:
22 | raise NotImplementedError(f"Activation {nonlinear_activation} is not supported!")
--------------------------------------------------------------------------------
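`get_activation` simply resolves the activation by name from `torch.nn`, which is how the `nonlinear_activation` / `nonlinear_activation_params` pairs in the configs above are instantiated. A minimal sketch (import path assumes the repository root is on `PYTHONPATH`, matching the imports in `models/autoencoder/AudioDec.py`):

```python
import torch
from AudioDec.layers.activation_function import get_activation

act = get_activation("LeakyReLU", {"negative_slope": 0.1})  # resolves to torch.nn.LeakyReLU(0.1)
y = act(torch.randn(2, 32, 100))
```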
/AudioDec/layers/conv_layer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | """Convolution layers."""
13 |
14 | import math
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 |
19 |
20 | def int2tuple(variable, length):
21 | if isinstance(variable, int):
22 | return (variable,)*length
23 | else:
24 | assert len(variable) == length, f"The length of {variable} is not {length}!"
25 | return variable
26 |
27 |
28 | class Conv1d1x1(nn.Conv1d):
29 | """1x1 Conv1d."""
30 |
31 | def __init__(self, in_channels, out_channels, bias=True):
32 | super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
33 |
34 |
35 | class NonCausalConv1d(nn.Module):
36 | """1D noncausal convloution w/ 2-sides padding."""
37 |
38 | def __init__(
39 | self,
40 | in_channels,
41 | out_channels,
42 | kernel_size,
43 | stride=1,
44 | padding=-1,
45 | dilation=1,
46 | groups=1,
47 | bias=True):
48 | super().__init__()
49 | self.in_channels = in_channels
50 | self.out_channels = out_channels
51 | self.kernel_size = kernel_size
52 | if padding < 0:
53 | padding = (kernel_size - 1) // 2 * dilation
54 | self.dilation = dilation
55 | self.conv = nn.Conv1d(
56 | in_channels=in_channels,
57 | out_channels=out_channels,
58 | kernel_size=kernel_size,
59 | stride=stride,
60 | padding=padding,
61 | dilation=dilation,
62 | groups=groups,
63 | bias=bias,
64 | )
65 |
66 | def forward(self, x):
67 | """
68 | Args:
69 | x (Tensor): Float tensor variable with the shape (B, C, T).
70 | Returns:
71 | Tensor: Float tensor variable with the shape (B, C, T).
72 | """
73 | x = self.conv(x)
74 | return x
75 |
76 |
77 | class NonCausalConvTranspose1d(nn.Module):
78 | """1D noncausal transpose convloution."""
79 |
80 | def __init__(
81 | self,
82 | in_channels,
83 | out_channels,
84 | kernel_size,
85 | stride,
86 | padding=-1,
87 | output_padding=-1,
88 | groups=1,
89 | bias=True,
90 | ):
91 | super().__init__()
92 | if padding < 0:
93 | padding = (stride+1) // 2
94 | if output_padding < 0:
95 | output_padding = 1 if stride % 2 else 0
96 | self.deconv = nn.ConvTranspose1d(
97 | in_channels=in_channels,
98 | out_channels=out_channels,
99 | kernel_size=kernel_size,
100 | stride=stride,
101 | padding=padding,
102 | output_padding=output_padding,
103 | groups=groups,
104 | bias=bias,
105 | )
106 |
107 | def forward(self, x):
108 | """
109 | Args:
110 | x (Tensor): Float tensor variable with the shape (B, C, T).
111 | Returns:
112 | Tensor: Float tensor variable with the shape (B, C', T').
113 | """
114 | x = self.deconv(x)
115 | return x
116 |
117 |
118 | class CausalConv1d(NonCausalConv1d):
119 | """1D causal convloution w/ 1-side padding."""
120 |
121 | def __init__(
122 | self,
123 | in_channels,
124 | out_channels,
125 | kernel_size,
126 | stride=1,
127 | dilation=1,
128 | groups=1,
129 | bias=True,
130 | pad_buffer=None,
131 | ):
132 | super(CausalConv1d, self).__init__(
133 | in_channels=in_channels,
134 | out_channels=out_channels,
135 | kernel_size=kernel_size,
136 | stride=stride,
137 | padding=0,
138 | dilation=dilation,
139 | groups=groups,
140 | bias=bias,
141 | )
142 | self.stride = stride
143 | self.pad_length = (kernel_size - 1) * dilation
144 | if pad_buffer is None:
145 | pad_buffer = torch.zeros(1, in_channels, self.pad_length)
146 | self.register_buffer("pad_buffer", pad_buffer)
147 |
148 | def forward(self, x):
149 | pad = nn.ConstantPad1d((self.pad_length, 0), 0.0)
150 | x = pad(x)
151 | return self.conv(x)
152 |
153 | def inference(self, x):
154 | x = torch.cat((self.pad_buffer, x), -1)
155 | self.pad_buffer = x[:, :, -self.pad_length:]
156 | return self.conv(x)
157 |
158 | def reset_buffer(self):
159 | self.pad_buffer.zero_()
160 |
161 |
162 | class CausalConvTranspose1d(NonCausalConvTranspose1d):
163 | """1D causal transpose convloution."""
164 |
165 | def __init__(
166 | self,
167 | in_channels,
168 | out_channels,
169 | kernel_size,
170 | stride,
171 | bias=True,
172 | pad_buffer=None,
173 | ):
174 | super(CausalConvTranspose1d, self).__init__(
175 | in_channels=in_channels,
176 | out_channels=out_channels,
177 | kernel_size=kernel_size,
178 | stride=stride,
179 | padding=0,
180 | output_padding=0,
181 | bias=bias,
182 | )
183 | self.stride = stride
184 | self.pad_length = (math.ceil(kernel_size/stride) - 1)
185 | if pad_buffer is None:
186 | pad_buffer = torch.zeros(1, in_channels, self.pad_length)
187 | self.register_buffer("pad_buffer", pad_buffer)
188 |
189 | def forward(self, x):
190 | pad = nn.ReplicationPad1d((self.pad_length, 0))
191 | x = pad(x)
192 | return self.deconv(x)[:, :, self.stride : -self.stride]
193 |
194 | def inference(self, x):
195 | x = torch.cat((self.pad_buffer, x), -1)
196 | self.pad_buffer = x[:, :, -self.pad_length:]
197 | return self.deconv(x)[:, :, self.stride : -self.stride]
198 |
199 | def reset_buffer(self):
200 | self.pad_buffer.zero_()
201 |
202 |
203 | class NonCausalConv2d(nn.Module):
204 | """2D noncausal convloution w/ 4-sides padding."""
205 |
206 | def __init__(
207 | self,
208 | in_channels,
209 | out_channels,
210 | kernel_size,
211 | stride=1,
212 | padding=-1,
213 | dilation=1,
214 | groups=1,
215 | bias=True):
216 | super().__init__()
217 | self.in_channels = in_channels
218 | self.out_channels = out_channels
219 | self.kernel_size = int2tuple(kernel_size, 2)
220 | self.dilation = int2tuple(dilation, 2)
221 | if isinstance(padding, int) and padding < 0:
222 | padding_0 = (self.kernel_size[0] - 1) // 2 * self.dilation[0]
223 | padding_1 = (self.kernel_size[1] - 1) // 2 * self.dilation[1]
224 | padding = (padding_0, padding_1)
225 |
226 | self.conv = nn.Conv2d(
227 | in_channels=in_channels,
228 | out_channels=out_channels,
229 | kernel_size=kernel_size,
230 | stride=stride,
231 | padding=padding,
232 | dilation=dilation,
233 | groups=groups,
234 | bias=bias,
235 | )
236 |
237 | def forward(self, x):
238 | """
239 | Args:
240 | x (Tensor): Float tensor variable with the shape (B, C, T).
241 | Returns:
242 | Tensor: Float tensor variable with the shape (B, C, T).
243 | """
244 | x = self.conv(x)
245 | return x
--------------------------------------------------------------------------------
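The causal layers expose two entry points: `forward` zero-pads on the left for offline training, while `inference` consumes the registered `pad_buffer` so a stream can be processed chunk by chunk. A minimal shape check (again assuming the repository root is on `PYTHONPATH`):

```python
import torch
from AudioDec.layers.conv_layer import CausalConv1d

conv = CausalConv1d(in_channels=1, out_channels=32, kernel_size=7)
x = torch.randn(1, 1, 300)

y = conv(x)                     # offline: left zero-padding keeps the output length at 300
conv.reset_buffer()             # clear the streaming state
y_stream = conv.inference(x)    # streaming: pads with the internal buffer instead
print(y.shape, y_stream.shape)  # both torch.Size([1, 32, 300])
```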
/AudioDec/layers/vq_module.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/lucidrains/vector-quantize-pytorch/)
11 |
12 | """Vector quantizer."""
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 |
18 |
19 | class VectorQuantize(nn.Module):
20 | """Vector quantization w/ exponential moving averages (EMA)"""
21 |
22 | def __init__(
23 | self,
24 | dim,
25 | codebook_size,
26 | decay = 0.8,
27 | commitment = 1.,
28 | eps = 1e-5,
29 | n_embed = None,
30 | ):
31 | super().__init__()
32 | n_embed = self.default(n_embed, codebook_size)
33 |
34 | self.dim = dim
35 | self.n_embed = n_embed
36 | self.decay = decay
37 | self.eps = eps
38 | self.commitment = commitment
39 |
40 | embed = torch.randn(dim, n_embed)
41 | self.register_buffer('embed', embed)
42 | self.register_buffer('cluster_size', torch.zeros(n_embed))
43 | self.register_buffer('embed_avg', embed.clone())
44 |
45 | @property
46 | def codebook(self):
47 | return self.embed.transpose(0, 1)
48 |
49 | def exists(self,val):
50 | return val is not None
51 |
52 | def default(self, val, d):
53 | return val if self.exists(val) else d
54 |
55 | def ema_inplace(self, moving_avg, new, decay):
56 | moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
57 |
58 | def laplace_smoothing(self, x, n_categories, eps=1e-5):
59 | return (x + eps) / (x.sum() + n_categories * eps)
60 |
61 | def forward(self, input):
62 | dtype = input.dtype
63 | flatten = input.reshape(-1, self.dim)
64 | dist = (
65 | flatten.pow(2).sum(1, keepdim=True)
66 | - 2 * flatten @ self.embed
67 | + self.embed.pow(2).sum(0, keepdim=True)
68 | )
69 | _, embed_ind = (-dist).max(1)
70 | embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
71 | embed_ind = embed_ind.view(*input.shape[:-1])
72 | quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
73 |
74 | if self.training:
75 | self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
76 | embed_sum = flatten.transpose(0, 1) @ embed_onehot
77 | self.ema_inplace(self.embed_avg, embed_sum, self.decay)
78 | cluster_size = self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
79 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
80 | self.embed.data.copy_(embed_normalized)
81 |
82 | loss = F.mse_loss(quantize.detach(), input) * self.commitment
83 | quantize = input + (quantize - input).detach()
84 |
85 | avg_probs = torch.mean(embed_onehot, dim=0)
86 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
87 |
88 | return quantize, loss, perplexity
89 |
90 | def forward_index(self, input):
91 | dtype = input.dtype
92 | flatten = input.reshape(-1, self.dim)
93 | dist = (
94 | flatten.pow(2).sum(1, keepdim=True)
95 | - 2 * flatten @ self.embed
96 | + self.embed.pow(2).sum(0, keepdim=True)
97 | )
98 | _, embed_ind = (-dist).max(1)
99 | embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
100 | embed_ind = embed_ind.view(*input.shape[:-1])
101 | quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
102 | quantize = input + (quantize - input).detach()
103 |
104 | return quantize, embed_ind
105 |
106 |
107 | class ResidualVQ(nn.Module):
108 | """ Residual VQ following algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
109 |
110 | def __init__(
111 | self,
112 | *,
113 | num_quantizers,
114 | **kwargs
115 | ):
116 | super().__init__()
117 | self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)])
118 |
119 | def forward(self, x):
120 | quantized_out = 0.
121 | residual = x
122 | all_losses = []
123 | all_perplexities = []
124 | for layer in self.layers:
125 | quantized, loss, perplexity = layer(residual)
126 | # Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
127 | # We found that considering only the 1st layer VQ's gradient results in better performance
128 | #residual = residual - quantized.detach() # considering all layers' gradients
129 | residual = residual - quantized # considering only the first layer's gradient
130 | quantized_out = quantized_out + quantized
131 | all_losses.append(loss)
132 | all_perplexities.append(perplexity)
133 | all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
134 | return quantized_out, all_losses, all_perplexities
135 |
136 | def forward_index(self, x, flatten_idx=False):
137 | quantized_out = 0.
138 | residual = x
139 | all_indices = []
140 | for i, layer in enumerate(self.layers):
141 | quantized, indices = layer.forward_index(residual)
142 | #residual = residual - quantized.detach()
143 | residual = residual - quantized
144 | quantized_out = quantized_out + quantized
145 | if flatten_idx:
146 | indices += (self.codebook_size * i)
147 | all_indices.append(indices)
148 | all_indices = torch.stack(all_indices)
149 | return quantized_out, all_indices.squeeze(1)
150 |
151 | def initial(self):
152 | self.codebook = []
153 | for layer in self.layers:
154 | self.codebook.append(layer.codebook)
155 | self.codebook_size = self.codebook[0].size(0)
156 | self.codebook = torch.stack(self.codebook)
157 | self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))
158 |
159 | def lookup(self, indices):
160 | quantized_out = F.embedding(indices, self.codebook) # Num x T x C
161 | return torch.sum(quantized_out, dim=0,keepdim=True)
162 |
--------------------------------------------------------------------------------
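A short forward-pass sketch of `ResidualVQ` with the settings used by the configs above (8 quantizers, 64-dimensional codes, 1024 entries per codebook). Each layer quantizes the residual of the previous one; the per-layer commitment losses and perplexities come back stacked.

```python
import torch
from AudioDec.layers.vq_module import ResidualVQ

rvq = ResidualVQ(num_quantizers=8, dim=64, codebook_size=1024)
z = torch.randn(4, 150, 64)          # (B, frames, code_dim)
zq, vq_losses, perplexities = rvq(z)
print(zq.shape, vq_losses.shape, perplexities.shape)
# torch.Size([4, 150, 64]) torch.Size([8]) torch.Size([8])
```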
/AudioDec/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .adversarial_loss import * # NOQA
2 | from .feat_match_loss import * # NOQA
3 | from .mel_loss import * # NOQA
4 | from .stft_loss import * # NOQA
5 | from .waveform_loss import * # NOQA
--------------------------------------------------------------------------------
/AudioDec/losses/adversarial_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2021 Tomoki Hayashi
5 | # MIT License (https://opensource.org/licenses/MIT)
6 |
7 | """Adversarial loss modules."""
8 |
9 | import torch
10 | import torch.nn.functional as F
11 |
12 |
13 | class GeneratorAdversarialLoss(torch.nn.Module):
14 | """Generator adversarial loss module."""
15 |
16 | def __init__(
17 | self,
18 | average_by_discriminators=True,
19 | loss_type="mse",
20 | ):
21 | """Initialize GeneratorAversarialLoss module."""
22 | super().__init__()
23 | self.average_by_discriminators = average_by_discriminators
24 | assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
25 | if loss_type == "mse":
26 | self.criterion = self._mse_loss
27 | else:
28 | self.criterion = self._hinge_loss
29 |
30 | def forward(self, outputs):
31 | """Calcualate generator adversarial loss.
32 |
33 | Args:
34 | outputs (Tensor or list): Discriminator outputs or list of
35 | discriminator outputs.
36 |
37 | Returns:
38 | Tensor: Generator adversarial loss value.
39 |
40 | """
41 | if isinstance(outputs, (tuple, list)):
42 | adv_loss = 0.0
43 | for i, outputs_ in enumerate(outputs):
44 | if isinstance(outputs_, (tuple, list)):
45 | # NOTE(kan-bayashi): case including feature maps
46 | outputs_ = outputs_[-1]
47 | adv_loss += self.criterion(outputs_)
48 | if self.average_by_discriminators:
49 | adv_loss /= i + 1
50 | else:
51 | adv_loss = self.criterion(outputs)
52 |
53 | return adv_loss
54 |
55 | def _mse_loss(self, x):
56 | return F.mse_loss(x, x.new_ones(x.size()))
57 |
58 | def _hinge_loss(self, x):
59 | return -x.mean()
60 |
61 |
62 | class DiscriminatorAdversarialLoss(torch.nn.Module):
63 | """Discriminator adversarial loss module."""
64 |
65 | def __init__(
66 | self,
67 | average_by_discriminators=True,
68 | loss_type="mse",
69 | ):
70 | """Initialize DiscriminatorAversarialLoss module."""
71 | super().__init__()
72 | self.average_by_discriminators = average_by_discriminators
73 | assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
74 | if loss_type == "mse":
75 | self.fake_criterion = self._mse_fake_loss
76 | self.real_criterion = self._mse_real_loss
77 | else:
78 | self.fake_criterion = self._hinge_fake_loss
79 | self.real_criterion = self._hinge_real_loss
80 |
81 | def forward(self, outputs_hat, outputs):
82 | """Calcualate discriminator adversarial loss.
83 |
84 | Args:
85 | outputs_hat (Tensor or list): Discriminator outputs or list of
86 | discriminator outputs calculated from generator outputs.
87 | outputs (Tensor or list): Discriminator outputs or list of
88 | discriminator outputs calculated from groundtruth.
89 |
90 | Returns:
91 | Tensor: Discriminator real loss value.
92 | Tensor: Discriminator fake loss value.
93 |
94 | """
95 | if isinstance(outputs, (tuple, list)):
96 | real_loss = 0.0
97 | fake_loss = 0.0
98 | for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
99 | if isinstance(outputs_hat_, (tuple, list)):
100 | # NOTE(kan-bayashi): case including feature maps
101 | outputs_hat_ = outputs_hat_[-1]
102 | outputs_ = outputs_[-1]
103 | real_loss += self.real_criterion(outputs_)
104 | fake_loss += self.fake_criterion(outputs_hat_)
105 | if self.average_by_discriminators:
106 | fake_loss /= i + 1
107 | real_loss /= i + 1
108 | else:
109 | real_loss = self.real_criterion(outputs)
110 | fake_loss = self.fake_criterion(outputs_hat)
111 |
112 | return real_loss, fake_loss
113 |
114 | def _mse_real_loss(self, x):
115 | return F.mse_loss(x, x.new_ones(x.size()))
116 |
117 | def _mse_fake_loss(self, x):
118 | return F.mse_loss(x, x.new_zeros(x.size()))
119 |
120 | def _hinge_real_loss(self, x):
121 | return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
122 |
123 | def _hinge_fake_loss(self, x):
124 | return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
125 |
--------------------------------------------------------------------------------
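A usage sketch with dummy discriminator logits (one tensor per sub-discriminator; the shapes are arbitrary). With the MSE default, the generator term pushes fake logits toward 1, while the discriminator term returns separate real and fake losses, matching the `average_by_discriminators: false` setting in the configs above.

```python
import torch
from AudioDec.losses.adversarial_loss import (
    GeneratorAdversarialLoss,
    DiscriminatorAdversarialLoss,
)

gen_adv = GeneratorAdversarialLoss(average_by_discriminators=False)
dis_adv = DiscriminatorAdversarialLoss(average_by_discriminators=False)

fake_logits = [torch.randn(2, 1, 100) for _ in range(3)]  # outputs on generated audio
real_logits = [torch.randn(2, 1, 100) for _ in range(3)]  # outputs on real audio

adv_loss = gen_adv(fake_logits)
real_loss, fake_loss = dis_adv(fake_logits, real_logits)
```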
/AudioDec/losses/feat_match_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2021 Tomoki Hayashi
5 | # MIT License (https://opensource.org/licenses/MIT)
6 |
7 | """Feature matching loss modules."""
8 |
9 | import torch
10 | import torch.nn.functional as F
11 |
12 |
13 | class FeatureMatchLoss(torch.nn.Module):
14 | """Feature matching loss module."""
15 |
16 | def __init__(
17 | self,
18 | average_by_layers=True,
19 | average_by_discriminators=True,
20 | include_final_outputs=False,
21 | ):
22 | """Initialize FeatureMatchLoss module."""
23 | super().__init__()
24 | self.average_by_layers = average_by_layers
25 | self.average_by_discriminators = average_by_discriminators
26 | self.include_final_outputs = include_final_outputs
27 |
28 | def forward(self, feats_hat, feats):
29 | """Calcualate feature matching loss.
30 |
31 | Args:
32 | feats_hat (list): List of list of discriminator outputs
33 | calculated from generator outputs.
34 | feats (list): List of list of discriminator outputs
35 | calculated from groundtruth.
36 |
37 | Returns:
38 | Tensor: Feature matching loss value.
39 |
40 | """
41 | feat_match_loss = 0.0
42 | for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
43 | feat_match_loss_ = 0.0
44 | if not self.include_final_outputs:
45 | feats_hat_ = feats_hat_[:-1]
46 | feats_ = feats_[:-1]
47 | for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
48 | feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
49 | if self.average_by_layers:
50 | feat_match_loss_ /= j + 1
51 | feat_match_loss += feat_match_loss_
52 | if self.average_by_discriminators:
53 | feat_match_loss /= i + 1
54 |
55 | return feat_match_loss
56 |
--------------------------------------------------------------------------------
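A sketch of `FeatureMatchLoss` with dummy feature maps: the inputs are lists (one per sub-discriminator) of lists of intermediate activations. With `include_final_outputs=False`, the last feature map of each sub-discriminator is excluded, as in the configs above.

```python
import torch
from AudioDec.losses.feat_match_loss import FeatureMatchLoss

feat_match = FeatureMatchLoss(
    average_by_discriminators=False,
    average_by_layers=False,
    include_final_outputs=False,
)
# One list of intermediate feature maps per sub-discriminator (shapes are arbitrary here).
feats_real = [[torch.randn(2, 16, 50), torch.randn(2, 32, 25)] for _ in range(3)]
feats_fake = [[torch.randn(2, 16, 50), torch.randn(2, 32, 25)] for _ in range(3)]
loss = feat_match(feats_fake, feats_real)
```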
/AudioDec/losses/mel_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | """Mel-spectrogram loss modules."""
13 |
14 | import librosa
15 | import torch
16 | import torch.nn.functional as F
17 |
18 |
19 | class MelSpectrogram(torch.nn.Module):
20 | """Calculate Mel-spectrogram."""
21 |
22 | def __init__(
23 | self,
24 | fs=22050,
25 | fft_size=1024,
26 | hop_size=256,
27 | win_length=None,
28 | window="hann_window",
29 | num_mels=80,
30 | fmin=80,
31 | fmax=7600,
32 | center=True,
33 | normalized=False,
34 | onesided=True,
35 | eps=1e-10,
36 | log_base=10.0,
37 | ):
38 | """Initialize MelSpectrogram module."""
39 | super().__init__()
40 | self.fft_size = fft_size
41 | self.hop_size = hop_size
42 | if win_length is not None:
43 | self.win_length = win_length
44 | else:
45 | self.win_length = fft_size
46 | self.center = center
47 | self.normalized = normalized
48 | self.onesided = onesided
49 | self.register_buffer("window", getattr(torch, window)(self.win_length))
50 | self.eps = eps
51 |
52 | fmin = 0 if fmin is None else fmin
53 | fmax = fs / 2 if fmax is None else fmax
54 | melmat = librosa.filters.mel(
55 | sr=fs,
56 | n_fft=fft_size,
57 | n_mels=num_mels,
58 | fmin=fmin,
59 | fmax=fmax,
60 | )
61 | self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
62 |
63 | self.log_base = log_base
64 | if self.log_base is None:
65 | self.log = torch.log
66 | elif self.log_base == 2.0:
67 | self.log = torch.log2
68 | elif self.log_base == 10.0:
69 | self.log = torch.log10
70 | else:
71 | raise ValueError(f"log_base: {log_base} is not supported.")
72 |
73 |
74 | def forward(self, x):
75 | """Calculate Mel-spectrogram.
76 |
77 | Args:
78 | x (Tensor): Input waveform tensor (B, T) or (B, C, T).
79 |
80 | Returns:
81 | Tensor: Mel-spectrogram (B, #mels, #frames).
82 |
83 | """
84 | if x.dim() == 3:
85 | # (B, C, T) -> (B*C, T)
86 | x = x.reshape(-1, x.size(2))
87 |
88 | x_stft = torch.stft(x, self.fft_size, self.hop_size, self.win_length, self.window, return_complex=True)
89 | x_power = x_stft.real ** 2 + x_stft.imag ** 2
90 | x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps)).transpose(2, 1) # (B, D, T') -> (B, T', D)
91 | x_mel = torch.matmul(x_amp, self.melmat)
92 | x_mel = torch.clamp(x_mel, min=self.eps)
93 |
94 | return self.log(x_mel).transpose(1, 2) # (B, D, T')
95 |
96 |
97 | class MultiMelSpectrogramLoss(torch.nn.Module):
98 | """Multi resolution Mel-spectrogram loss."""
99 |
100 | def __init__(
101 | self,
102 | fs=22050,
103 | fft_sizes=[1024, 2048, 512],
104 | hop_sizes=[120, 240, 50],
105 | win_lengths=[600, 1200, 240],
106 | window="hann_window",
107 | num_mels=80,
108 | fmin=80,
109 | fmax=7600,
110 | center=True,
111 | normalized=False,
112 | onesided=True,
113 | eps=1e-10,
114 | log_base=10.0,
115 | ):
116 | """Initialize Mel-spectrogram loss."""
117 | super().__init__()
118 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
119 | self.mel_transfers = torch.nn.ModuleList()
120 | for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths):
121 | self.mel_transfers += [
122 | MelSpectrogram(
123 | fs=fs,
124 | fft_size=fft_size,
125 | hop_size=hop_size,
126 | win_length=win_length,
127 | window=window,
128 | num_mels=num_mels,
129 | fmin=fmin,
130 | fmax=fmax,
131 | center=center,
132 | normalized=normalized,
133 | onesided=onesided,
134 | eps=eps,
135 | log_base=log_base,
136 | )
137 | ]
138 |
139 |
140 | def forward(self, y_hat, y):
141 | """Calculate Mel-spectrogram loss.
142 |
143 | Args:
144 | y_hat (Tensor): Generated signal tensor (B, C, T).
145 | y (Tensor): Groundtruth signal tensor (B, C, T).
146 |
147 | Returns:
148 | Tensor: Mel-spectrogram loss value.
149 |
150 | """
151 | mel_loss = 0.0
152 | for f in self.mel_transfers:
153 | mel_loss += F.l1_loss(f(y_hat), f(y))
154 | mel_loss /= len(self.mel_transfers)
155 |
156 | return mel_loss
--------------------------------------------------------------------------------
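Instantiated with the values from `mel_loss_params` in the 48 kHz configs above, the loss reduces to a single-resolution mel L1 term (one FFT size, one hop size). A minimal sketch on random waveforms:

```python
import torch
from AudioDec.losses.mel_loss import MultiMelSpectrogramLoss

mel_loss = MultiMelSpectrogramLoss(
    fs=48000,
    fft_sizes=[2048],
    hop_sizes=[300],
    win_lengths=[None],
    num_mels=80,
    fmin=0,
    fmax=24000,
    log_base=None,
)
y_hat = torch.randn(2, 1, 9600)  # generated waveform (B, C, T)
y = torch.randn(2, 1, 9600)      # reference waveform
loss = mel_loss(y_hat, y)
```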
/AudioDec/losses/stft_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | """STFT-based loss modules."""
13 |
14 | import torch
15 | import torch.nn.functional as F
16 |
17 |
18 |
19 | def stft(x, fft_size, hop_size, win_length, window, eps=1e-7):
20 | """Perform STFT and convert to magnitude spectrogram.
21 |
22 | Args:
23 | x (Tensor): Input signal tensor (B, T).
24 | fft_size (int): FFT size.
25 | hop_size (int): Hop size.
26 | win_length (int): Window length.
27 | window (str): Window function type.
28 |
29 | Returns:
30 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
31 |
32 | """
33 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
34 | x_power = x_stft.real ** 2 + x_stft.imag ** 2
35 | return torch.sqrt(torch.clamp(x_power, min=eps)).transpose(2, 1)
36 |
37 |
38 | class SpectralConvergenceLoss(torch.nn.Module):
39 | """Spectral convergence loss module."""
40 |
41 | def __init__(self):
42 | """Initilize spectral convergence loss module."""
43 | super(SpectralConvergenceLoss, self).__init__()
44 |
45 | def forward(self, x_mag, y_mag):
46 | """Calculate forward propagation.
47 |
48 | Args:
49 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
50 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
51 |
52 | Returns:
53 | Tensor: Spectral convergence loss value.
54 |
55 | """
56 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
57 |
58 |
59 | class LogSTFTMagnitudeLoss(torch.nn.Module):
60 | """Log STFT magnitude loss module."""
61 |
62 | def __init__(self):
63 | """Initilize los STFT magnitude loss module."""
64 | super(LogSTFTMagnitudeLoss, self).__init__()
65 |
66 | def forward(self, x_mag, y_mag):
67 | """Calculate forward propagation.
68 |
69 | Args:
70 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
71 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
72 |
73 | Returns:
74 | Tensor: Log STFT magnitude loss value.
75 |
76 | """
77 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
78 |
79 |
80 | class STFTLoss(torch.nn.Module):
81 | """STFT loss module."""
82 |
83 | def __init__(
84 | self,
85 | fft_size=1024,
86 | hop_size=120,
87 | win_length=600,
88 | window="hann_window",
89 | ):
90 | """Initialize STFT loss module."""
91 | super(STFTLoss, self).__init__()
92 | self.fft_size = fft_size
93 | self.hop_size = hop_size
94 | self.win_length = win_length
95 | self.spectral_convergence_loss = SpectralConvergenceLoss()
96 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
97 | self.register_buffer("window", getattr(torch, window)(win_length))
98 |
99 |
100 | def forward(self, x, y):
101 | """Calculate forward propagation.
102 |
103 | Args:
104 | x (Tensor): Predicted signal (B, T).
105 | y (Tensor): Groundtruth signal (B, T).
106 |
107 | Returns:
108 | Tensor: Spectral convergence loss value.
109 | Tensor: Log STFT magnitude loss value.
110 |
111 | """
112 | x_mag = stft(x, self.fft_size, self.hop_size, self.win_length, self.window)
113 | y_mag = stft(y, self.fft_size, self.hop_size, self.win_length, self.window)
114 | sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
115 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
116 |
117 | return sc_loss, mag_loss
118 |
119 |
120 | class MultiResolutionSTFTLoss(torch.nn.Module):
121 | """Multi resolution STFT loss module."""
122 |
123 | def __init__(
124 | self,
125 | fft_sizes=[1024, 2048, 512],
126 | hop_sizes=[120, 240, 50],
127 | win_lengths=[600, 1200, 240],
128 | window="hann_window",
129 | ):
130 | """Initialize Multi resolution STFT loss module.
131 |
132 | Args:
133 | fft_sizes (list): List of FFT sizes.
134 | hop_sizes (list): List of hop sizes.
135 | win_lengths (list): List of window lengths.
136 | window (str): Window function type.
137 |
138 | """
139 | super(MultiResolutionSTFTLoss, self).__init__()
140 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
141 | self.stft_losses = torch.nn.ModuleList()
142 | for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths):
143 | self.stft_losses += [STFTLoss(fft_size, hop_size, win_length, window)]
144 |
145 |
146 | def forward(self, x, y):
147 | """Calculate forward propagation.
148 |
149 | Args:
150 | x (Tensor): Predicted signal (B, T) or (B, #subband, T).
151 | y (Tensor): Groundtruth signal (B, T) or (B, #subband, T).
152 |
153 | Returns:
154 | Tensor: Multi resolution spectral convergence loss value.
155 | Tensor: Multi resolution log STFT magnitude loss value.
156 |
157 | """
158 | if len(x.shape) == 3:
159 | x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T)
160 | y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T)
161 | sc_loss = 0.0
162 | mag_loss = 0.0
163 | for f in self.stft_losses:
164 | sc_l, mag_l = f(x, y)
165 | sc_loss += sc_l
166 | mag_loss += mag_l
167 | sc_loss /= len(self.stft_losses)
168 | mag_loss /= len(self.stft_losses)
169 |
170 | return sc_loss, mag_loss
171 |
--------------------------------------------------------------------------------
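The default resolutions match `stft_loss_params` in the configs above, so a bare instantiation is enough for a quick check. Multi-channel inputs are flattened to `(B*C, T)` before the per-resolution spectral-convergence and log-magnitude terms are averaged.

```python
import torch
from AudioDec.losses.stft_loss import MultiResolutionSTFTLoss

stft_loss = MultiResolutionSTFTLoss()  # fft [1024, 2048, 512], hop [120, 240, 50], win [600, 1200, 240]
y_hat = torch.randn(2, 1, 9600)
y = torch.randn(2, 1, 9600)
sc_loss, mag_loss = stft_loss(y_hat, y)
```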
/AudioDec/losses/waveform_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 |
10 | """Waveform-based loss modules."""
11 |
12 | import torch
13 |
14 |
15 | class WaveformShapeLoss(torch.nn.Module):
16 | """Waveform shape loss."""
17 |
18 | def __init__(self, winlen):
19 | super().__init__()
20 | self.loss = torch.nn.L1Loss()
21 | self.winlen = winlen
22 | self.maxpool = torch.nn.MaxPool1d(self.winlen)
23 |
24 | def forward(self, y_hat, y):
25 | """Calculate L1 loss.
26 |
27 | Args:
28 | y_hat (Tensor): Generated signal tensor (B, 1, T).
29 | y (Tensor): Groundtruth signal tensor (B, 1, T).
30 |
31 | Returns:
32 | Tensor: L1 loss value.
33 |
34 | """
35 | ys = self.maxpool(torch.abs(y))
36 | ys_hat = self.maxpool(torch.abs(y_hat))
37 | loss = self.loss(ys_hat, ys)
38 | return loss
39 |
40 |
41 | class MultiWindowShapeLoss(torch.nn.Module):
42 | """Multi-window-lengthe waveform shape loss."""
43 |
44 | def __init__(
45 | self,
46 | winlen=[300, 200, 100],
47 | ):
48 | """Initialize Multi window shape loss module.
49 |
50 | Args:
51 | winlen (list): List of window lengths.
52 |
53 | """
54 | super(MultiWindowShapeLoss, self).__init__()
55 | self.shape_losses = torch.nn.ModuleList()
56 | for wl in winlen:
57 | self.shape_losses += [WaveformShapeLoss(wl)]
58 |
59 | def forward(self, y_hat, y):
60 | """Calculate L1 loss.
61 |
62 | Args:
63 | y_hat (Tensor): Generated signal tensor (B, 1, T).
64 | y (Tensor): Groundtruth signal tensor (B, 1, T).
65 |
66 | Returns:
67 | Tensor: L1 loss value.
68 |
69 | """
70 | loss = 0.0
71 | for f in self.shape_losses:
72 | loss += f(y_hat, y)
73 | loss /= len(self.shape_losses)
74 |
75 | return loss
--------------------------------------------------------------------------------
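The shape loss compares max-pooled magnitude envelopes of the two waveforms, with one term per window length. A minimal sketch using the default windows:

```python
import torch
from AudioDec.losses.waveform_loss import MultiWindowShapeLoss

shape_loss = MultiWindowShapeLoss(winlen=[300, 200, 100])
y_hat = torch.randn(2, 1, 9600)  # generated waveform (B, 1, T)
y = torch.randn(2, 1, 9600)      # reference waveform
loss = shape_loss(y_hat, y)
```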
/AudioDec/models/autoencoder/AudioDec.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 | # Reference (https://github.com/jik876/hifi-gan/)
12 |
13 | """AudioDec model."""
14 |
15 | import torch
16 | import logging
17 |
18 | from AudioDec.layers.conv_layer import CausalConv1d, CausalConvTranspose1d
19 | from AudioDec.models.autoencoder.modules.encoder import Encoder, ActivateEncoder
20 | from AudioDec.models.autoencoder.modules.decoder import Decoder, ActivateDecoder
21 | from AudioDec.models.autoencoder.modules.projector import Projector
22 | from AudioDec.models.autoencoder.modules.quantizer import Quantizer
23 | from AudioDec.models.utils import check_mode
24 |
25 |
26 | ### GENERATOR ###
27 | class Generator(torch.nn.Module):
28 | """AudioDec generator."""
29 |
30 | def __init__(
31 | self,
32 | input_channels=1,
33 | output_channels=1,
34 | encode_channels=32,
35 | decode_channels=32,
36 | code_dim=64,
37 | codebook_num=8,
38 | codebook_size=1024,
39 | bias=True,
40 | enc_ratios=(2, 4, 8, 16),
41 | dec_ratios=(16, 8, 4, 2),
42 | enc_strides=(3, 4, 5, 5),
43 | dec_strides=(5, 5, 4, 3),
44 | mode='causal',
45 | codec='audiodec',
46 | projector='conv1d',
47 | quantier='residual_vq',
48 | nonlinear_activation="ELU",
49 | nonlinear_activation_params={},
50 | use_weight_norm=False,
51 | ):
52 | super().__init__()
53 | if codec == 'audiodec':
54 | encoder = Encoder
55 | decoder = Decoder
56 | elif codec == 'activate_audiodec':
57 | encoder = ActivateEncoder
58 | decoder = ActivateDecoder
59 | else:
60 | raise NotImplementedError(f"Codec ({codec}) is not supported!")
61 | self.mode = mode
62 | self.input_channels = input_channels
63 |
64 | self.encoder = encoder(
65 | input_channels=input_channels,
66 | encode_channels=encode_channels,
67 | channel_ratios=enc_ratios,
68 | strides=enc_strides,
69 | kernel_size=7,
70 | bias=bias,
71 | mode=self.mode,
72 | nonlinear_activation=nonlinear_activation,
73 | nonlinear_activation_params=nonlinear_activation_params,
74 | )
75 |
76 | self.decoder = decoder(
77 | code_dim=code_dim,
78 | output_channels=output_channels,
79 | decode_channels=decode_channels,
80 | channel_ratios=dec_ratios,
81 | strides=dec_strides,
82 | kernel_size=7,
83 | bias=bias,
84 | mode=self.mode,
85 | nonlinear_activation=nonlinear_activation,
86 | nonlinear_activation_params=nonlinear_activation_params,
87 | )
88 |
89 | self.projector = Projector(
90 | input_channels=self.encoder.out_channels,
91 | code_dim=code_dim,
92 | kernel_size=3,
93 | stride=1,
94 | bias=False,
95 | mode=self.mode,
96 | model=projector,
97 | )
98 |
99 | self.quantizer = Quantizer(
100 | code_dim=code_dim,
101 | codebook_num=codebook_num,
102 | codebook_size=codebook_size,
103 | model=quantier,
104 | )
105 |
106 | # apply weight norm & reset parameters
107 | if use_weight_norm:
108 | self.apply_weight_norm()
109 | self.reset_parameters()
110 |
111 |
112 | def forward(self, x):
113 | (batch, channel, length) = x.size()
114 | if channel != self.input_channels:
115 | x = x.reshape(-1, self.input_channels, length) # (B, C, T) -> (B', C', T)
116 | x = self.encoder(x)
117 | z = self.projector(x)
118 | zq, vqloss, perplexity = self.quantizer(z)
119 | y = self.decoder(zq)
120 | return y, zq, z, vqloss, perplexity
121 |
122 |
123 | def reset_parameters(self):
124 | """Reset parameters.
125 |
126 | This initialization follows the official implementation manner.
127 | https://github.com/jik876/hifi-gan/blob/master/models.py
128 |
129 | """
130 |
131 | def _reset_parameters(m):
132 | if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
133 | m.weight.data.normal_(0.0, 0.01)
134 | logging.debug(f"Reset parameters in {m}.")
135 |
136 | self.apply(_reset_parameters)
137 |
138 |
139 | def remove_weight_norm(self):
140 | """Remove weight normalization module from all of the layers."""
141 |
142 | def _remove_weight_norm(m):
143 | try:
144 | logging.debug(f"Weight norm is removed from {m}.")
145 | torch.nn.utils.remove_weight_norm(m)
146 | except ValueError: # this module didn't have weight norm
147 | return
148 |
149 | self.apply(_remove_weight_norm)
150 |
151 |
152 | def apply_weight_norm(self):
153 | """Apply weight normalization module from all of the layers."""
154 |
155 | def _apply_weight_norm(m):
156 | if isinstance(m, torch.nn.Conv1d) or isinstance(
157 | m, torch.nn.ConvTranspose1d
158 | ):
159 | torch.nn.utils.weight_norm(m)
160 | logging.debug(f"Weight norm is applied to {m}.")
161 |
162 | self.apply(_apply_weight_norm)
163 |
164 |
165 | # STREAMING
166 | class StreamGenerator(Generator):
167 | """AudioDec streaming generator."""
168 |
169 | def __init__(
170 | self,
171 | input_channels=1,
172 | output_channels=1,
173 | encode_channels=32,
174 | decode_channels=32,
175 | code_dim=64,
176 | codebook_num=8,
177 | codebook_size=1024,
178 | bias=True,
179 | enc_ratios=(2, 4, 8, 16),
180 | dec_ratios=(16, 8, 4, 2),
181 | enc_strides=(3, 4, 5, 5),
182 | dec_strides=(5, 5, 4, 3),
183 | mode='causal',
184 | codec='audiodec',
185 | projector='conv1d',
186 | quantier='residual_vq',
187 | nonlinear_activation="ELU",
188 | nonlinear_activation_params={},
189 | use_weight_norm=False,
190 | ):
191 | super(StreamGenerator, self).__init__(
192 | input_channels=input_channels,
193 | output_channels=output_channels,
194 | encode_channels=encode_channels,
195 | decode_channels=decode_channels,
196 | code_dim=code_dim,
197 | codebook_num=codebook_num,
198 | codebook_size=codebook_size,
199 | bias=bias,
200 | enc_ratios=enc_ratios,
201 | dec_ratios=dec_ratios,
202 | enc_strides=enc_strides,
203 | dec_strides=dec_strides,
204 | mode=mode,
205 | codec=codec,
206 | projector=projector,
207 | quantier=quantier,
208 | nonlinear_activation=nonlinear_activation,
209 | nonlinear_activation_params=nonlinear_activation_params,
210 | use_weight_norm=use_weight_norm,
211 | )
212 | check_mode(mode, "AudioDec Streamer")
213 | self.reset_buffer()
214 |
215 |
216 | def initial_encoder(self, receptive_length, device):
217 | self.quantizer.initial()
218 | z = self.encode(torch.zeros(1, self.input_channels, receptive_length).to(device))
219 | idx = self.quantize(z)
220 | zq = self.lookup(idx)
221 | return zq
222 |
223 |
224 | def initial_decoder(self, zq):
225 | self.decode(zq)
226 |
227 |
228 | def encode(self, x):
229 | (batch, channel, length) = x.size()
230 | if channel != self.input_channels:
231 | x = x.reshape(-1, self.input_channels, length) # (B, C, T) -> (B', C', T)
232 | x = self.encoder.encode(x)
233 | z = self.projector.encode(x)
234 | return z
235 |
236 |
237 | def quantize(self, z):
238 | zq, idx = self.quantizer.encode(z)
239 | return idx
240 |
241 |
242 | def lookup(self, idx):
243 | return self.quantizer.decode(idx)
244 |
245 |
246 | def decode(self, zq):
247 | return self.decoder.decode(zq.transpose(2, 1))
248 |
249 |
250 | def reset_buffer(self):
251 |         """Reset the buffers of all causal convolution layers."""
252 |
253 | def _reset_buffer(m):
254 | if isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d):
255 | m.reset_buffer()
256 | self.apply(_reset_buffer)
257 |
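# A minimal, illustrative streaming loop built from the methods above; `model`,
# `device`, and `frames` are placeholder names, and the checkpoint loading and
# frame slicing are assumed to happen elsewhere (e.g. in bin/stream.py):
#
#     model = StreamGenerator(mode='causal')   # weights loaded from a checkpoint beforehand
#     model.eval()
#     zq = model.initial_encoder(receptive_length=8192, device=device)  # warm up encoder/quantizer buffers
#     model.initial_decoder(zq)                                         # warm up decoder buffers
#     for frame in frames:                 # frame: (1, input_channels, frame_size) waveform chunk
#         z = model.encode(frame)          # buffered causal convolutions
#         idx = model.quantize(z)          # flattened codebook indices to transmit
#         zq = model.lookup(idx)           # receiver side: indices -> quantized embeddings
#         y = model.decode(zq)             # streamed output waveform chunk
#     model.reset_buffer()                 # clear buffers before starting a new stream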
--------------------------------------------------------------------------------
/AudioDec/models/autoencoder/modules/decoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://ieeexplore.ieee.org/document/9625818)
11 |
12 | """Decoder modules."""
13 |
14 | import torch
15 | import inspect
16 |
17 | from AudioDec.layers.conv_layer import NonCausalConv1d, NonCausalConvTranspose1d
18 | from AudioDec.layers.conv_layer import CausalConv1d, CausalConvTranspose1d
19 | from AudioDec.layers.activation_function import get_activation
20 | from AudioDec.models.autoencoder.modules.residual_unit import NonCausalResidualUnit
21 | from AudioDec.models.autoencoder.modules.residual_unit import CausalResidualUnit
22 | from AudioDec.models.utils import check_mode
23 |
24 |
25 | class DecoderBlock(torch.nn.Module):
26 | """ Decoder block (upsampling) """
27 |
28 | def __init__(
29 | self,
30 | in_channels,
31 | out_channels,
32 | stride,
33 | dilations=(1, 3, 9),
34 | bias=True,
35 | mode='causal',
36 | nonlinear_activation="ELU",
37 | nonlinear_activation_params={},
38 | ):
39 | super().__init__()
40 | self.mode = mode
41 | if self.mode == 'noncausal':
42 | ResidualUnit = NonCausalResidualUnit
43 | ConvTranspose1d = NonCausalConvTranspose1d
44 | elif self.mode == 'causal':
45 | ResidualUnit = CausalResidualUnit
46 | ConvTranspose1d = CausalConvTranspose1d
47 | else:
48 | raise NotImplementedError(f"Mode ({self.mode}) is not supported!")
49 |
50 | self.conv = ConvTranspose1d(
51 | in_channels=in_channels,
52 | out_channels=out_channels,
53 | kernel_size=(2 * stride),
54 | stride=stride,
55 | bias=bias,
56 | )
57 |
58 | self.res_units = torch.nn.ModuleList()
59 | for idx, dilation in enumerate(dilations):
60 | self.res_units += [
61 | ResidualUnit(
62 | out_channels,
63 | out_channels,
64 | dilation=dilation,
65 | nonlinear_activation=nonlinear_activation,
66 | nonlinear_activation_params=nonlinear_activation_params,
67 | )]
68 | self.num_res = len(self.res_units)
69 |
70 | def forward(self, x):
71 | x = self.conv(x)
72 | for idx in range(self.num_res):
73 | x = self.res_units[idx](x)
74 | return x
75 |
76 | def inference(self, x):
77 | check_mode(self.mode, inspect.stack()[0][3])
78 | x = self.conv.inference(x)
79 | for idx in range(self.num_res):
80 | x = self.res_units[idx].inference(x)
81 | return x
82 |
83 |
84 | class Decoder(torch.nn.Module):
85 | def __init__(self,
86 | code_dim,
87 | output_channels,
88 | decode_channels,
89 | channel_ratios=(16, 8, 4, 2),
90 | strides=(5, 5, 4, 3),
91 | kernel_size=7,
92 | bias=True,
93 | mode='causal',
94 | nonlinear_activation="ELU",
95 | nonlinear_activation_params={},
96 | ):
97 | super().__init__()
98 | assert len(channel_ratios) == len(strides)
99 | self.mode = mode
100 | if self.mode == 'noncausal':
101 | Conv1d = NonCausalConv1d
102 | elif self.mode == 'causal':
103 | Conv1d = CausalConv1d
104 | else:
105 | raise NotImplementedError(f"Mode ({self.mode}) is not supported!")
106 |
107 | self.conv1 = Conv1d(
108 | in_channels=code_dim,
109 | out_channels=(decode_channels * channel_ratios[0]),
110 | kernel_size=kernel_size,
111 | stride=1,
112 | bias=False)
113 |
114 | self.conv_blocks = torch.nn.ModuleList()
115 | for idx, stride in enumerate(strides):
116 | in_channels = decode_channels * channel_ratios[idx]
117 | if idx < (len(channel_ratios)-1):
118 | out_channels = decode_channels * channel_ratios[idx+1]
119 | else:
120 | out_channels = decode_channels
121 | self.conv_blocks += [
122 | DecoderBlock(
123 | in_channels,
124 | out_channels,
125 | stride,
126 | bias=bias,
127 | mode=self.mode,
128 | nonlinear_activation=nonlinear_activation,
129 | nonlinear_activation_params=nonlinear_activation_params,
130 | )]
131 | self.num_blocks = len(self.conv_blocks)
132 |
133 | self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
134 |
135 | def forward(self, z):
136 | x = self.conv1(z)
137 | for i in range(self.num_blocks):
138 | x = self.conv_blocks[i](x)
139 | x = self.conv2(x)
140 | return x
141 |
142 | def decode(self, z):
143 | check_mode(self.mode, inspect.stack()[0][3])
144 | x = self.conv1.inference(z)
145 | for i in range(self.num_blocks):
146 | x = self.conv_blocks[i].inference(x)
147 | x = self.conv2.inference(x)
148 | return x
149 |
150 |
151 | class ActivateDecoder(Decoder):
152 | def __init__(self,
153 | code_dim,
154 | output_channels,
155 | decode_channels,
156 | channel_ratios=(16, 8, 4, 2),
157 | strides=(5, 5, 4, 3),
158 | kernel_size=7,
159 | bias=True,
160 | mode='causal',
161 | nonlinear_activation="ELU",
162 | nonlinear_activation_params={},
163 | ):
164 | super().__init__(
165 | code_dim=code_dim,
166 | output_channels=output_channels,
167 | decode_channels=decode_channels,
168 | channel_ratios=channel_ratios,
169 | strides=strides,
170 | kernel_size=kernel_size,
171 | bias=bias,
172 | mode=mode,
173 | nonlinear_activation=nonlinear_activation,
174 | nonlinear_activation_params=nonlinear_activation_params,
175 | )
176 | self.conv_blocks = torch.nn.ModuleList()
177 | for idx, stride in enumerate(strides):
178 | in_channels = decode_channels * channel_ratios[idx]
179 | if idx < (len(channel_ratios)-1):
180 | out_channels = decode_channels * channel_ratios[idx+1]
181 | else:
182 | out_channels = decode_channels
183 |             # upsampling + residual
184 | self.conv_blocks += [
185 | torch.nn.Sequential(
186 | get_activation(nonlinear_activation, nonlinear_activation_params),
187 | DecoderBlock(
188 | in_channels,
189 | out_channels,
190 | stride,
191 | bias=bias,
192 | mode=self.mode,
193 | nonlinear_activation=nonlinear_activation,
194 | nonlinear_activation_params=nonlinear_activation_params,
195 | ),
196 | )
197 | ]
198 | self.conv_blocks += [get_activation(nonlinear_activation, nonlinear_activation_params)] # for conv2
199 | self.num_blocks = len(self.conv_blocks)
200 | # output activation
201 | self.activation_output = torch.nn.Tanh()
202 |
203 | def forward(self, z):
204 | return self.activation_output(super().forward(z))
205 |
206 | def decode(self, z):
207 | check_mode(self.mode, inspect.stack()[0][3])
208 | x = self.conv1.inference(z)
209 | for i in range(self.num_blocks-1):
210 | x = self.conv_blocks[i][0](x) # activation
211 | x = self.conv_blocks[i][1].inference(x) # DecoderBlock
212 | x = self.conv_blocks[-1](x) # activation
213 | x = self.conv2.inference(x)
214 | return self.activation_output(x)
--------------------------------------------------------------------------------
/AudioDec/models/autoencoder/modules/encoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://ieeexplore.ieee.org/document/9625818)
11 |
12 | """Encoder modules."""
13 |
14 | import torch
15 | import inspect
16 |
17 | from AudioDec.layers.conv_layer import NonCausalConv1d
18 | from AudioDec.layers.conv_layer import CausalConv1d
19 | from AudioDec.layers.activation_function import get_activation
20 | from AudioDec.models.autoencoder.modules.residual_unit import NonCausalResidualUnit
21 | from AudioDec.models.autoencoder.modules.residual_unit import CausalResidualUnit
22 | from AudioDec.models.utils import check_mode
23 |
24 |
25 | class EncoderBlock(torch.nn.Module):
26 | """ Encoder block (downsampling) """
27 |
28 | def __init__(
29 | self,
30 | in_channels,
31 | out_channels,
32 | stride,
33 | dilations=(1, 3, 9),
34 | bias=True,
35 | mode='causal',
36 | nonlinear_activation="ELU",
37 | nonlinear_activation_params={},
38 | ):
39 | super().__init__()
40 | self.mode = mode
41 | if self.mode == 'noncausal':
42 | ResidualUnit = NonCausalResidualUnit
43 | Conv1d = NonCausalConv1d
44 | elif self.mode == 'causal':
45 | ResidualUnit = CausalResidualUnit
46 | Conv1d = CausalConv1d
47 | else:
48 | raise NotImplementedError(f"Mode ({self.mode}) is not supported!")
49 |
50 | self.res_units = torch.nn.ModuleList()
51 | for dilation in dilations:
52 | self.res_units += [
53 | ResidualUnit(
54 | in_channels,
55 | in_channels,
56 | dilation=dilation,
57 | nonlinear_activation=nonlinear_activation,
58 | nonlinear_activation_params=nonlinear_activation_params,
59 | )]
60 | self.num_res = len(self.res_units)
61 |
62 | self.conv = Conv1d(
63 | in_channels=in_channels,
64 | out_channels=out_channels,
65 | kernel_size=(2 * stride),
66 | stride=stride,
67 | bias=bias,
68 | )
69 |
70 | def forward(self, x):
71 | for idx in range(self.num_res):
72 | x = self.res_units[idx](x)
73 | x = self.conv(x)
74 | return x
75 |
76 | def inference(self, x):
77 | check_mode(self.mode, inspect.stack()[0][3])
78 | for idx in range(self.num_res):
79 | x = self.res_units[idx].inference(x)
80 | x = self.conv.inference(x)
81 | return x
82 |
83 |
84 | class Encoder(torch.nn.Module):
85 | def __init__(self,
86 | input_channels,
87 | encode_channels,
88 | channel_ratios=(2, 4, 8, 16),
89 | strides=(3, 4, 5, 5),
90 | kernel_size=7,
91 | bias=True,
92 | mode='causal',
93 | nonlinear_activation="ELU",
94 | nonlinear_activation_params={},
95 | ):
96 | super().__init__()
97 | assert len(channel_ratios) == len(strides)
98 | self.mode = mode
99 | if self.mode == 'noncausal':
100 | Conv1d = NonCausalConv1d
101 | elif self.mode == 'causal':
102 | Conv1d = CausalConv1d
103 | else:
104 | raise NotImplementedError(f"Mode ({self.mode}) is not supported!")
105 |
106 | self.conv = Conv1d(
107 | in_channels=input_channels,
108 | out_channels=encode_channels,
109 | kernel_size=kernel_size,
110 | stride=1,
111 | bias=False)
112 |
113 | self.conv_blocks = torch.nn.ModuleList()
114 | in_channels = encode_channels
115 | for idx, stride in enumerate(strides):
116 | out_channels = encode_channels * channel_ratios[idx]
117 | self.conv_blocks += [
118 | EncoderBlock(
119 | in_channels,
120 | out_channels,
121 | stride,
122 | bias=bias,
123 | mode=self.mode,
124 | nonlinear_activation=nonlinear_activation,
125 | nonlinear_activation_params=nonlinear_activation_params,
126 | )]
127 | in_channels = out_channels
128 | self.num_blocks = len(self.conv_blocks)
129 | self.out_channels = out_channels
130 |
131 | def forward(self, x):
132 | x = self.conv(x)
133 | for i in range(self.num_blocks):
134 | x = self.conv_blocks[i](x)
135 | return x
136 |
137 | def encode(self, x):
138 | check_mode(self.mode, inspect.stack()[0][3])
139 | x = self.conv.inference(x)
140 | for i in range(self.num_blocks):
141 | x = self.conv_blocks[i].inference(x)
142 | return x
143 |
144 |
145 | class ActivateEncoder(Encoder):
146 | def __init__(self,
147 | input_channels,
148 | encode_channels,
149 | channel_ratios=(2, 4, 8, 16),
150 | strides=(3, 4, 5, 5),
151 | kernel_size=7,
152 | bias=True,
153 | mode='causal',
154 | nonlinear_activation="ELU",
155 | nonlinear_activation_params={},
156 | ):
157 | super().__init__(
158 | input_channels=input_channels,
159 | encode_channels=encode_channels,
160 | channel_ratios=channel_ratios,
161 | strides=strides,
162 | kernel_size=kernel_size,
163 | bias=bias,
164 | mode=mode,
165 | nonlinear_activation=nonlinear_activation,
166 | nonlinear_activation_params=nonlinear_activation_params,
167 | )
168 | self.activation = get_activation(nonlinear_activation, nonlinear_activation_params)
169 |
170 |
171 | def forward(self, x):
172 | return self.activation(super().forward(x))
173 |
174 | def encode(self, x):
175 | return self.activation(super().encode(x))
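
# Rough shape sketch for the defaults above (illustrative only): with strides
# (3, 4, 5, 5) the encoder downsamples by 3 * 4 * 5 * 5 = 300 samples per code
# frame, which is the "hop300" in the config names, e.g.
#
#     enc = Encoder(input_channels=1, encode_channels=32)
#     x = torch.randn(1, 1, 48000)   # one second of 48 kHz audio
#     z = enc(x)                     # -> roughly (1, 32 * 16, 48000 / 300) = (1, 512, 160)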
--------------------------------------------------------------------------------
/AudioDec/models/autoencoder/modules/projector.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 |
10 | """Projector modules."""
11 |
12 | import torch
13 | import inspect
14 |
15 | from AudioDec.layers.conv_layer import NonCausalConv1d
16 | from AudioDec.layers.conv_layer import CausalConv1d
17 | from AudioDec.models.utils import check_mode
18 |
19 |
20 | class Projector(torch.nn.Module):
21 | def __init__(self,
22 | input_channels,
23 | code_dim,
24 | kernel_size=3,
25 | stride=1,
26 | bias=False,
27 | mode='causal',
28 | model='conv1d',
29 | ):
30 | super().__init__()
31 | self.mode = mode
32 | if self.mode == 'noncausal':
33 | Conv1d = NonCausalConv1d
34 | elif self.mode == 'causal':
35 | Conv1d = CausalConv1d
36 | else:
37 | raise NotImplementedError(f"Mode ({mode}) is not supported!")
38 |
39 | if model == 'conv1d':
40 | self.project = Conv1d(input_channels, code_dim, kernel_size=kernel_size, stride=stride, bias=bias)
41 | elif model == 'conv1d_bn':
42 | self.project = torch.nn.Sequential(
43 | Conv1d(input_channels, code_dim, kernel_size=kernel_size, stride=stride, bias=bias),
44 | torch.nn.BatchNorm1d(code_dim)
45 | )
46 | else:
47 | raise NotImplementedError(f"Model ({model}) is not supported!")
48 |
49 | def forward(self, x):
50 | return self.project(x)
51 |
52 | def encode(self, x):
53 | check_mode(self.mode, inspect.stack()[0][3])
54 | return self.project.inference(x)
55 |
56 |
57 |
--------------------------------------------------------------------------------
/AudioDec/models/autoencoder/modules/quantizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 |
10 | import torch
11 |
12 | from AudioDec.layers.vq_module import ResidualVQ
13 |
14 |
15 | class Quantizer(torch.nn.Module):
16 | def __init__(self,
17 | code_dim,
18 | codebook_num,
19 | codebook_size,
20 | model='residual_vq',
21 | ):
22 | super().__init__()
23 | # speech
24 | if model == 'residual_vq':
25 | self.codebook = ResidualVQ(dim=code_dim, num_quantizers=codebook_num, codebook_size=codebook_size)
26 | else:
27 | raise NotImplementedError(f"Model ({model}) is not supported!")
28 |
29 | def initial(self):
30 | self.codebook.initial()
31 |
32 | def forward(self, z):
33 | zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
34 | zq = zq.transpose(2, 1)
35 | return zq, vqloss, perplexity
36 |
37 | def inference(self, z):
38 | zq, indices = self.codebook.forward_index(z.transpose(2, 1))
39 | zq = zq.transpose(2, 1)
40 | return zq, indices
41 |
42 | def encode(self, z):
43 | zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
44 | return zq, indices
45 |
46 | def decode(self, indices):
47 | z = self.codebook.lookup(indices)
48 | return z
49 |
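# Back-of-the-envelope bitrate for the typical configuration (illustrative only):
# with codebook_num = 8 and codebook_size = 1024, each code frame costs
# 8 * log2(1024) = 80 bits; with a hop of 300 samples this is
# 48000 / 300 * 80 = 12.8 kbps at 48 kHz, or 24000 / 300 * 80 = 6.4 kbps at 24 kHz.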
--------------------------------------------------------------------------------
/AudioDec/models/autoencoder/modules/residual_unit.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://ieeexplore.ieee.org/document/9625818)
11 |
12 | """Residual Units."""
13 |
14 | import torch
15 | import torch.nn as nn
16 |
17 | from AudioDec.layers.conv_layer import Conv1d1x1, NonCausalConv1d, CausalConv1d
18 | from AudioDec.layers.activation_function import get_activation
19 |
20 | class NonCausalResidualUnit(nn.Module):
21 | def __init__(
22 | self,
23 | in_channels,
24 | out_channels,
25 | kernel_size=7,
26 | dilation=1,
27 | bias=False,
28 | nonlinear_activation="ELU",
29 | nonlinear_activation_params={},
30 | ):
31 | super().__init__()
32 | self.activation = get_activation(nonlinear_activation, nonlinear_activation_params)
33 | self.conv1 = NonCausalConv1d(
34 | in_channels=in_channels,
35 | out_channels=out_channels,
36 | kernel_size=kernel_size,
37 | stride=1,
38 | dilation=dilation,
39 | bias=bias,
40 | )
41 | self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
42 |
43 | def forward(self, x):
44 | y = self.conv1(self.activation(x))
45 | y = self.conv2(self.activation(y))
46 | return x + y
47 |
48 |
49 | class CausalResidualUnit(NonCausalResidualUnit):
50 | def __init__(
51 | self,
52 | in_channels,
53 | out_channels,
54 | kernel_size=7,
55 | dilation=1,
56 | bias=False,
57 | nonlinear_activation="ELU",
58 | nonlinear_activation_params={},
59 | ):
60 | super(CausalResidualUnit, self).__init__(
61 | in_channels=in_channels,
62 | out_channels=out_channels,
63 | kernel_size=kernel_size,
64 | dilation=dilation,
65 | bias=bias,
66 | nonlinear_activation=nonlinear_activation,
67 | nonlinear_activation_params=nonlinear_activation_params,
68 | )
69 | self.conv1 = CausalConv1d(
70 | in_channels=in_channels,
71 | out_channels=out_channels,
72 | kernel_size=kernel_size,
73 | stride=1,
74 | dilation=dilation,
75 | bias=bias,
76 | )
77 |
78 | def inference(self, x):
79 | y = self.conv1.inference(self.activation(x))
80 | y = self.conv2(self.activation(y))
81 | return x + y
--------------------------------------------------------------------------------
/AudioDec/models/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 |
10 |
11 | """Utility modules."""
12 |
13 | def check_mode(mode, method):
14 | stream_modes = ['causal']
15 | assert mode in stream_modes, f"Mode {mode} does not support {method}!"
16 |
--------------------------------------------------------------------------------
/AudioDec/models/vocoder/modules/multi_fusion.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 | # Reference (https://github.com/r9y9/wavenet_vocoder)
12 | # Reference (https://github.com/jik876/hifi-gan)
13 |
14 | """Multi-fusion modules."""
15 |
16 | import math
17 | import torch
18 | import torch.nn as nn
19 | from AudioDec.layers.conv_layer import Conv1d1x1
20 | from AudioDec.models.vocoder.modules.residual_block import HiFiGANResidualBlock
21 |
22 |
23 | class MultiReceptiveField(nn.Module):
24 | """Multi-receptive field module in HiFiGAN."""
25 |
26 | def __init__(
27 | self,
28 | channels=512,
29 | resblock_kernel_sizes=(3, 7, 11),
30 | resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
31 | groups=1,
32 | bias=True,
33 | use_additional_convs=True,
34 | nonlinear_activation="LeakyReLU",
35 | nonlinear_activation_params={"negative_slope": 0.1},
36 | ):
37 | assert len(resblock_kernel_sizes) == len(resblock_dilations)
38 | super().__init__()
39 | self.num_blocks = len(resblock_kernel_sizes)
40 |
41 | self.blocks = nn.ModuleList()
42 | for i in range(self.num_blocks):
43 | self.blocks += [
44 | HiFiGANResidualBlock(
45 | kernel_size=resblock_kernel_sizes[i],
46 | channels=channels,
47 | dilations=resblock_dilations[i],
48 | groups=groups,
49 | bias=bias,
50 | use_additional_convs=use_additional_convs,
51 | nonlinear_activation=nonlinear_activation,
52 | nonlinear_activation_params=nonlinear_activation_params,
53 | )
54 | ]
55 |
56 | def forward(self, c):
57 | """Calculate forward propagation.
58 |
59 | Args:
60 | c (Tensor): Input tensor (B, channels, T).
61 |
62 | Returns:
63 | Tensor: Output tensor (B, channels, T).
64 |
65 | """
66 | cs = 0.0 # initialize
67 | for i in range(self.num_blocks):
68 | cs += self.blocks[i](c)
69 | c = cs / self.num_blocks
70 |
71 | return c
72 |
73 | def inference(self, c):
74 | cs = 0.0 # initialize
75 | for i in range(self.num_blocks):
76 | cs += self.blocks[i].inference(c)
77 | c = cs / self.num_blocks
78 |
79 | return c
80 |
81 |
82 | class MultiGroupConv1d(HiFiGANResidualBlock):
83 | """Multi-group convolution module."""
84 |
85 | def __init__(
86 | self,
87 | channels=512,
88 |         resblock_kernel_sizes=(3,),
89 | resblock_dilations=[(1, 3, 5)],
90 | groups=3,
91 | bias=True,
92 | use_additional_convs=True,
93 | nonlinear_activation="LeakyReLU",
94 | nonlinear_activation_params={"negative_slope": 0.1},
95 | ):
96 | assert len(resblock_kernel_sizes) == len(resblock_dilations) == 1
97 | super(MultiGroupConv1d, self).__init__(
98 | kernel_size=resblock_kernel_sizes[0],
99 | channels=channels*groups,
100 | dilations=resblock_dilations[0],
101 | groups=groups,
102 | bias=bias,
103 | use_additional_convs=use_additional_convs,
104 | nonlinear_activation=nonlinear_activation,
105 | nonlinear_activation_params=nonlinear_activation_params,
106 | )
107 | self.groups = groups
108 | self.conv_out = Conv1d1x1(
109 | in_channels=channels*groups,
110 | out_channels=channels,
111 | bias=False,
112 | )
113 |
114 | def forward(self, x):
115 | """Calculate forward propagation.
116 |
117 | Args:
118 | x (Tensor): Input tensor (B, channels, T).
119 |
120 | Returns:
121 | Tensor: Output tensor (B, channels, T).
122 |
123 | """
124 | x = x.repeat(1, self.groups, 1) # (B, n*C, T)
125 | for idx in range(self.num_layer):
126 | xt = self.convs1[idx](self.activation(x))
127 | if self.use_additional_convs:
128 | xt = self.convs2[idx](self.activation(xt))
129 | x = xt + x
130 | x = self.conv_out(x) # (B, C, T)
131 | return x
132 |
133 | def inference(self, x):
134 | x = x.repeat(1, self.groups, 1) # (B, n*C, T)
135 | for idx in range(self.num_layer):
136 | xt = self.convs1[idx].inference(self.activation(x))
137 | if self.use_additional_convs:
138 | xt = self.convs2[idx].inference(self.activation(xt))
139 | x = xt + x
140 | x = self.conv_out(x) # (B, C, T)
141 | return x
--------------------------------------------------------------------------------
/AudioDec/models/vocoder/modules/residual_block.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 | # Reference (https://github.com/r9y9/wavenet_vocoder)
12 | # Reference (https://github.com/jik876/hifi-gan)
13 |
14 | """Residual block modules."""
15 |
16 | import math
17 |
18 | import torch
19 | import torch.nn as nn
20 | from AudioDec.layers.conv_layer import CausalConv1d, Conv1d1x1
21 |
22 |
23 | class HiFiGANResidualBlock(nn.Module):
24 | """Causal Residual block module in HiFiGAN."""
25 |
26 | def __init__(
27 | self,
28 | kernel_size=3,
29 | channels=512,
30 | dilations=(1, 3, 5),
31 | groups=1,
32 | bias=True,
33 | use_additional_convs=True,
34 | nonlinear_activation="LeakyReLU",
35 | nonlinear_activation_params={"negative_slope": 0.1},
36 | ):
37 |         """Initialize HiFiGANResidualBlock module.
38 |
39 | Args:
40 | kernel_size (int): Kernel size of dilation convolution layer.
41 | channels (int): Number of channels for convolution layer.
42 | dilations (List[int]): List of dilation factors.
43 | use_additional_convs (bool): Whether to use additional convolution layers.
44 | groups (int): The group number of conv1d (default: 1)
45 | bias (bool): Whether to add bias parameter in convolution layers.
46 | nonlinear_activation (str): Activation function module name.
47 | nonlinear_activation_params (dict): Hyperparameters for activation function.
48 |
49 | """
50 | super().__init__()
51 | self.use_additional_convs = use_additional_convs
52 | self.convs1 = nn.ModuleList()
53 | if use_additional_convs:
54 | self.convs2 = nn.ModuleList()
55 | assert kernel_size % 2 == 1, "Kernel size must be odd number."
56 | self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
57 | for dilation in dilations:
58 | self.convs1 += [
59 | CausalConv1d(
60 | in_channels=channels,
61 | out_channels=channels,
62 | kernel_size=kernel_size,
63 | stride=1,
64 | dilation=dilation,
65 | groups=groups,
66 | bias=bias,
67 | )
68 | ]
69 | if use_additional_convs:
70 | self.convs2 += [
71 | CausalConv1d(
72 | in_channels=channels,
73 | out_channels=channels,
74 | kernel_size=kernel_size,
75 | stride=1,
76 | dilation=1,
77 | groups=groups,
78 | bias=bias,
79 | )
80 | ]
81 | self.num_layer = len(self.convs1)
82 |
83 | def forward(self, x):
84 | """Calculate forward propagation.
85 |
86 | Args:
87 | x (Tensor): Input tensor (B, channels, T).
88 |
89 | Returns:
90 | Tensor: Output tensor (B, channels, T).
91 |
92 | """
93 | for idx in range(self.num_layer):
94 | xt = self.convs1[idx](self.activation(x))
95 | if self.use_additional_convs:
96 | xt = self.convs2[idx](self.activation(xt))
97 | x = xt + x
98 | return x
99 |
100 | def inference(self, x):
101 | for idx in range(self.num_layer):
102 | xt = self.convs1[idx].inference(self.activation(x))
103 | if self.use_additional_convs:
104 | xt = self.convs2[idx].inference(self.activation(xt))
105 | x = xt + x
106 | return x
--------------------------------------------------------------------------------
/AudioDec/parse_options.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
4 | # Arnab Ghoshal, Karel Vesely
5 |
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15 | # MERCHANTABLITY OR NON-INFRINGEMENT.
16 | # See the Apache 2 License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 |
20 | # Parse command-line options.
21 | # To be sourced by another script (as in ". parse_options.sh").
22 | # Option format is: --option-name arg
23 | # and shell variable "option_name" gets set to value "arg."
24 | # The exception is --help, which takes no arguments, but prints the
25 | # $help_message variable (if defined).
26 |
27 |
28 | ###
29 | ### The --config file options have lower priority to command line
30 | ### options, so we need to import them first...
31 | ###
32 |
33 | # Now import all the configs specified by command-line, in left-to-right order
34 | for ((argpos=1; argpos<$#; argpos++)); do
35 | if [ "${!argpos}" == "--config" ]; then
36 | argpos_plus1=$((argpos+1))
37 | config=${!argpos_plus1}
38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
39 | . $config # source the config file.
40 | fi
41 | done
42 |
43 |
44 | ###
45 | ### Now we process the command line options
46 | ###
47 | while true; do
48 | [ -z "${1:-}" ] && break; # break if there are no arguments
49 | case "$1" in
50 | # If the enclosing script is called with --help option, print the help
51 | # message and exit. Scripts should put help messages in $help_message
52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
53 | else printf "$help_message\n" 1>&2 ; fi;
54 | exit 0 ;;
55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
56 | exit 1 ;;
57 | # If the first command-line argument begins with "--" (e.g. --foo-bar),
58 | # then work out the variable name as $name, which will equal "foo_bar".
59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
60 |       # Next we test whether the variable in question is undefined-- if so it's
61 | # an invalid option and we die. Note: $0 evaluates to the name of the
62 | # enclosing script.
63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
64 | # is undefined. We then have to wrap this test inside "eval" because
65 | # foo_bar is itself inside a variable ($name).
66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
67 |
68 | oldval="`eval echo \\$$name`";
69 | # Work out whether we seem to be expecting a Boolean argument.
70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
71 | was_bool=true;
72 | else
73 | was_bool=false;
74 | fi
75 |
76 | # Set the variable to the right value-- the escaped quotes make it work if
77 | # the option had spaces, like --cmd "queue.pl -sync y"
78 | eval $name=\"$2\";
79 |
80 | # Check that Boolean-valued arguments are really Boolean.
81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
83 | exit 1;
84 | fi
85 | shift 2;
86 | ;;
87 | *) break;
88 | esac
89 | done
90 |
91 |
92 | # Check for an empty argument to the --cmd option, which can easily occur as a
93 | # result of scripting errors.
94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
95 |
96 |
97 | true; # so this script returns exit code 0.
--------------------------------------------------------------------------------
/AudioDec/requirements.txt:
--------------------------------------------------------------------------------
1 | soundfile
2 | sounddevice
3 | numpy
4 | torch
5 | torchaudio
6 | scikit-learn
7 | librosa
8 | argparse
9 | pyyaml
10 | tqdm
11 | tensorboardX
12 |
13 |
--------------------------------------------------------------------------------
/AudioDec/slurmlogs/README.md:
--------------------------------------------------------------------------------
1 | # Folder for saving slurm logs
--------------------------------------------------------------------------------
/AudioDec/stats/symAD_libritts_24000_hop300_clean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/stats/symAD_libritts_24000_hop300_clean.npy
--------------------------------------------------------------------------------
/AudioDec/stats/symAD_vctk_48000_hop300_clean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/stats/symAD_vctk_48000_hop300_clean.npy
--------------------------------------------------------------------------------
/AudioDec/stats/symADuniv_vctk_48000_hop300_clean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/AudioDec/stats/symADuniv_vctk_48000_hop300_clean.npy
--------------------------------------------------------------------------------
/AudioDec/trainer/autoencoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | """Training flow of symmetric codec."""
13 |
14 | import logging
15 | import torch
16 | from trainer.trainerGAN import TrainerVQGAN
17 |
18 |
19 | class Trainer(TrainerVQGAN):
20 | def __init__(
21 | self,
22 | steps,
23 | epochs,
24 | data_loader,
25 | model,
26 | criterion,
27 | optimizer,
28 | scheduler,
29 | config,
30 | device=torch.device("cpu"),
31 | ):
32 | super(Trainer, self).__init__(
33 | steps=steps,
34 | epochs=epochs,
35 | data_loader=data_loader,
36 | model=model,
37 | criterion=criterion,
38 | optimizer=optimizer,
39 | scheduler=scheduler,
40 | config=config,
41 | device=device,
42 | )
43 | self.fix_encoder = False
44 | self.paradigm = config.get('paradigm', 'efficient')
45 | self.generator_start = config.get('start_steps', {}).get('generator', 0)
46 | self.discriminator_start = config.get('start_steps', {}).get('discriminator', 200000)
47 |
48 |
49 | def _train_step(self, batch):
50 | """Single step of training."""
51 | mode = 'train'
52 | x = batch
53 | x = x.to(self.device)
54 |
55 | # check generator step
56 | if self.steps < self.generator_start:
57 | self.generator_train = False
58 | else:
59 | self.generator_train = True
60 |
61 | # check discriminator step
62 | if self.steps < self.discriminator_start:
63 | self.discriminator_train = False
64 | else:
65 | self.discriminator_train = True
66 | if (not self.fix_encoder) and (self.paradigm == 'efficient'):
67 | # fix encoder, quantizer, and codebook
68 | for parameter in self.model["generator"].encoder.parameters():
69 | parameter.requires_grad = False
70 | for parameter in self.model["generator"].projector.parameters():
71 | parameter.requires_grad = False
72 | for parameter in self.model["generator"].quantizer.parameters():
73 | parameter.requires_grad = False
74 | self.fix_encoder = True
75 | logging.info("Encoder, projector, quantizer, and codebook are fixed")
76 |
77 | # check codebook updating
78 | if self.fix_encoder:
79 | self.model["generator"].quantizer.codebook.eval()
80 |
81 | #######################
82 | # Generator #
83 | #######################
84 | if self.generator_train:
85 | # initialize generator loss
86 | gen_loss = 0.0
87 |
88 |             # main generator operation
89 | y_, zq, z, vqloss, perplexity = self.model["generator"](x)
90 |
91 | # perplexity info
92 | self._perplexity(perplexity, mode=mode)
93 |
94 | # vq loss
95 | gen_loss += self._vq_loss(vqloss, mode=mode)
96 |
97 | # metric loss
98 | gen_loss += self._metric_loss(y_, x, mode=mode)
99 |
100 | # adversarial loss
101 | if self.discriminator_train:
102 | p_ = self.model["discriminator"](y_)
103 | if self.config["use_feat_match_loss"]:
104 | with torch.no_grad():
105 | p = self.model["discriminator"](x)
106 | else:
107 | p = None
108 | gen_loss += self._adv_loss(p_, p, mode=mode)
109 |
110 | # update generator
111 | self._record_loss('generator_loss', gen_loss, mode=mode)
112 | self._update_generator(gen_loss)
113 |
114 | #######################
115 | # Discriminator #
116 | #######################
117 | if self.discriminator_train:
118 |             # re-compute y_, which leads to better quality
119 | with torch.no_grad():
120 | y_, _, _, _, _ = self.model["generator"](x)
121 |
122 | p = self.model["discriminator"](x)
123 | p_ = self.model["discriminator"](y_.detach())
124 |
125 | # discriminator loss & update discriminator
126 | self._update_discriminator(self._dis_loss(p_, p, mode=mode))
127 |
128 | # update counts
129 | self.steps += 1
130 | self.tqdm.update(1)
131 | self._check_train_finish()
132 |
133 |
134 | @torch.no_grad()
135 | def _eval_step(self, batch):
136 | """Single step of evaluation."""
137 | mode = 'eval'
138 | x = batch
139 | x = x.to(self.device)
140 |
141 | # initialize generator loss
142 | gen_loss = 0.0
143 |
144 |         # main generator operation
145 | y_, zq, z, vqloss, perplexity = self.model["generator"](x)
146 |
147 | # perplexity info
148 | self._perplexity(perplexity, mode=mode)
149 |
150 | # vq_loss
151 | gen_loss += self._vq_loss(vqloss, mode=mode)
152 |
153 | # metric loss
154 | gen_loss += self._metric_loss(y_, x, mode=mode)
155 |
156 | if self.discriminator_train:
157 | # adversarial loss
158 | p_ = self.model["discriminator"](y_)
159 | p = self.model["discriminator"](x)
160 | gen_loss += self._adv_loss(p_, p, mode=mode)
161 |
162 | # discriminator loss
163 | self._dis_loss(p_, p, mode=mode)
164 |
165 | # generator loss
166 | self._record_loss('generator_loss', gen_loss, mode=mode)
167 |
168 |
169 |
170 |
171 |
172 |
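# Training schedule implied by the configuration handling above (summary comment,
# not an extra code path):
# - the generator trains from `start_steps.generator` (default 0);
# - the discriminator joins at `start_steps.discriminator` (default 200000);
# - under the 'efficient' paradigm, once the discriminator starts, the encoder,
#   projector, and quantizer are frozen and the codebook is kept in eval mode,
#   so only the decoder keeps adapting to the metric and adversarial losses.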
--------------------------------------------------------------------------------
/AudioDec/trainer/denoise.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | """Training flow of symmetric codec."""
13 |
14 | import logging
15 | import torch
16 | from trainer.trainerGAN import TrainerVQGAN
17 |
18 |
19 | class Trainer(TrainerVQGAN):
20 | def __init__(
21 | self,
22 | steps,
23 | epochs,
24 | data_loader,
25 | model,
26 | criterion,
27 | optimizer,
28 | scheduler,
29 | config,
30 | device=torch.device("cpu"),
31 | ):
32 | super(Trainer, self).__init__(
33 | steps=steps,
34 | epochs=epochs,
35 | data_loader=data_loader,
36 | model=model,
37 | criterion=criterion,
38 | optimizer=optimizer,
39 | scheduler=scheduler,
40 | config=config,
41 | device=device,
42 | )
43 | # fix quantizer
44 | for parameter in self.model["generator"].quantizer.parameters():
45 | parameter.requires_grad = False
46 | # fix decoder
47 | for parameter in self.model["generator"].decoder.parameters():
48 | parameter.requires_grad = False
49 | logging.info("Quantizer, codebook, and decoder are fixed")
50 |
51 |
52 | def _train_step(self, batch):
53 | """Single step of training."""
54 | mode = 'train'
55 | x_n, x_c = batch
56 | x_n = x_n.to(self.device)
57 | x_c = x_c.to(self.device)
58 |
59 | # fix codebook
60 | self.model["generator"].quantizer.codebook.eval()
61 |
62 | # initialize generator loss
63 | gen_loss = 0.0
64 |
65 |         # main generator operation
66 | y_nc, zq, z, vqloss, perplexity = self.model["generator"](x_n)
67 |
68 | # perplexity info
69 | self._perplexity(perplexity, mode=mode)
70 |
71 | # vq loss
72 | gen_loss += self._vq_loss(vqloss, mode=mode)
73 |
74 | # metric loss
75 | gen_loss += self._metric_loss(y_nc, x_c, mode=mode)
76 |
77 | # update generator
78 | self._record_loss('generator_loss', gen_loss, mode=mode)
79 | self._update_generator(gen_loss)
80 |
81 | # update counts
82 | self.steps += 1
83 | self.tqdm.update(1)
84 | self._check_train_finish()
85 |
86 |
87 | @torch.no_grad()
88 | def _eval_step(self, batch):
89 | """Single step of evaluation."""
90 | mode = 'eval'
91 | x_n, x_c = batch
92 | x_n = x_n.to(self.device)
93 | x_c = x_c.to(self.device)
94 |
95 | # initialize generator loss
96 | gen_loss = 0.0
97 |
98 | # main genertor operation
99 | y_nc, zq, z, vqloss, perplexity = self.model["generator"](x_n)
100 |
101 | # perplexity info
102 | self._perplexity(perplexity, mode=mode)
103 |
104 | # vq_loss
105 | gen_loss += self._vq_loss(vqloss, mode=mode)
106 |
107 | # metric loss
108 | gen_loss += self._metric_loss(y_nc, x_c, mode=mode)
109 |
110 | # generator loss
111 | self._record_loss('generator_loss', gen_loss, mode=mode)
112 |
113 |
114 |
115 |
116 |
117 |
--------------------------------------------------------------------------------
/AudioDec/trainer/vocoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | #
10 | # Reference (https://github.com/kan-bayashi/ParallelWaveGAN/)
11 |
12 | """Training flow of GAN-based vocoder."""
13 |
14 | import logging
15 | import torch
16 | from trainer.trainerGAN import TrainerGAN
17 |
18 |
19 | class Trainer(TrainerGAN):
20 | def __init__(
21 | self,
22 | steps,
23 | epochs,
24 | data_loader,
25 | model,
26 | criterion,
27 | optimizer,
28 | scheduler,
29 | config,
30 | device=torch.device("cpu"),
31 | ):
32 | super(Trainer, self).__init__(
33 | steps=steps,
34 | epochs=epochs,
35 | data_loader=data_loader,
36 | model=model,
37 | criterion=criterion,
38 | optimizer=optimizer,
39 | scheduler=scheduler,
40 | config=config,
41 | device=device,
42 | )
43 | self.fix_analyzer = False
44 | self.generator_start = config.get("generator_train_start_steps", 0)
45 | self.discriminator_start = config.get("discriminator_train_start_steps", 0)
46 |
47 |
48 | def _train_step(self, batch):
49 | """Train model one step."""
50 | mode = 'train'
51 | x = batch
52 | x = x.to(self.device)
53 |
54 | # fix analyzer
55 | if not self.fix_analyzer:
56 | for parameter in self.model["analyzer"].parameters():
57 | parameter.requires_grad = False
58 | self.fix_analyzer = True
59 | logging.info("Analyzer is fixed!")
60 | self.model["analyzer"].eval()
61 |
62 | #######################
63 | # Generator #
64 | #######################
65 | if self.steps > self.generator_start:
66 | # initialize generator loss
67 | gen_loss = 0.0
68 |
69 |             # main generator operation
70 | e = self.model["analyzer"].encoder(x)
71 | z = self.model["analyzer"].projector(e)
72 | zq, _, _ = self.model["analyzer"].quantizer(z)
73 | y_ = self.model["generator"](zq)
74 |
75 | # metric loss
76 | gen_loss += self._metric_loss(y_, x, mode=mode)
77 |
78 | # adversarial loss
79 | if self.steps > self.discriminator_start:
80 | p_ = self.model["discriminator"](y_)
81 | if self.config["use_feat_match_loss"]:
82 | with torch.no_grad():
83 | p = self.model["discriminator"](x)
84 | else:
85 | p = None
86 | gen_loss += self._adv_loss(p_, p, mode=mode)
87 |
88 | # update generator
89 | self._record_loss('generator_loss', gen_loss, mode=mode)
90 | self._update_generator(gen_loss)
91 |
92 | #######################
93 | # Discriminator #
94 | #######################
95 | if self.steps > self.discriminator_start:
96 |             # re-compute y_, which leads to better quality
97 | with torch.no_grad():
98 | e = self.model["analyzer"].encoder(x)
99 | z = self.model["analyzer"].projector(e)
100 | zq, _, _ = self.model["analyzer"].quantizer(z)
101 | y_ = self.model["generator"](zq)
102 | p = self.model["discriminator"](x)
103 | p_ = self.model["discriminator"](y_.detach())
104 |
105 | # discriminator loss & update discriminator
106 | self._update_discriminator(self._dis_loss(p_, p, mode=mode))
107 |
108 | # update counts
109 | self.steps += 1
110 | self.tqdm.update(1)
111 | self._check_train_finish()
112 |
113 |
114 | @torch.no_grad()
115 | def _eval_step(self, batch):
116 | """Single step of evaluation."""
117 | mode = 'eval'
118 | x = batch
119 | x = x.to(self.device)
120 |
121 | # initialize generator loss
122 | gen_loss = 0.0
123 |
124 |         # main generator operation
125 | e = self.model["analyzer"].encoder(x)
126 | z = self.model["analyzer"].projector(e)
127 | zq, _, _ = self.model["analyzer"].quantizer(z)
128 | y_ = self.model["generator"](zq)
129 |
130 | # metric loss
131 | gen_loss += self._metric_loss(y_, x, mode=mode)
132 |
133 | # adversarial loss & feature matching loss
134 | if self.steps > self.discriminator_start:
135 | p_ = self.model["discriminator"](y_)
136 | if self.config["use_feat_match_loss"]:
137 | p = self.model["discriminator"](x)
138 | else:
139 | p = None
140 | gen_loss += self._adv_loss(p_, p, mode=mode)
141 |
142 | # discriminator loss
143 | self._dis_loss(p_, p, mode=mode)
144 |
145 | # generator loss
146 | self._record_loss('generator_loss', gen_loss, mode=mode)
147 |
148 |
--------------------------------------------------------------------------------
/AudioDec/utils/audiodec.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) Meta Platforms, Inc. and affiliates.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 |
10 | import os
11 | import torch
12 | import math
13 | from typing import Union
14 | from AudioDec.models.autoencoder.AudioDec import StreamGenerator as generator_audiodec
15 | from AudioDec.models.vocoder.HiFiGAN import StreamGenerator as generator_hifigan
16 | from AudioDec.bin.stream import AudioCodec, AudioCodecStreamer
17 |
18 |
19 | class AudioDec(AudioCodec):
20 | def __init__(
21 | self,
22 | tx_device: str = "cpu",
23 | rx_device: str = "cpu",
24 | receptive_length: int = 8192, # actual number is 7209 for symAD_vctk_48000_hop300
25 | ):
26 | super(AudioDec, self).__init__(
27 | tx_device = tx_device,
28 | rx_device = rx_device,
29 | receptive_length = receptive_length,
30 | )
31 |
32 | def _load_encoder(self, checkpoint):
33 | # load config
34 | config = self._load_config(checkpoint)
35 | # load model
36 | if config['model_type'] in ['symAudioDec', 'symAudioDecUniv']:
37 | encoder = generator_audiodec
38 | else:
39 | raise NotImplementedError(f"Encoder type {config['model_type']} is not supported!")
40 | encoder = encoder(**config['generator_params'])
41 | encoder.load_state_dict(torch.load(checkpoint, map_location='cpu')['model']['generator'])
42 | return encoder
43 |
44 | def _load_decoder(self, checkpoint):
45 | # load config
46 | config = self._load_config(checkpoint)
47 | # load model
48 | if config['model_type'] in ['symAudioDec', 'symAudioDecUniv']:
49 | decoder = generator_audiodec
50 | elif config['model_type'] in ['HiFiGAN', 'UnivNet']:
51 | decoder = generator_hifigan
52 | else:
53 | raise NotImplementedError(f"Decoder {config['model_type']} is not supported!")
54 | decoder = decoder(**config['generator_params'])
55 | decoder.load_state_dict(torch.load(checkpoint, map_location='cpu')['model']['generator'])
56 | return decoder
57 |
58 | def get_hop_length(self, checkpoint):
59 | # load receiver model(s)
60 | assert os.path.exists(checkpoint), f'{checkpoint} does not exist!'
61 | config = self._load_config(checkpoint)
62 | return math.prod(config['generator_params']['enc_strides'])
63 |
64 |
65 | class AudioDecStreamer(AudioCodecStreamer):
66 | def __init__(
67 | self,
68 | input_device: Union[str, int],
69 | output_device: Union[str, int],
70 | input_channels: int = 1,
71 | output_channels: int = 1,
72 | frame_size: int = 512,
73 | sample_rate: int = 48000,
74 |         gain: float = 1.0,
75 | max_latency: float = 0.1,
76 | # encoder params
77 | tx_encoder = None,
78 | tx_device: str = "cpu",
79 | # decoder params
80 | rx_encoder = None,
81 | decoder = None,
82 | rx_device: str = "cpu",
83 | ):
84 | super(AudioDecStreamer, self).__init__(
85 | input_device = input_device,
86 | output_device = output_device,
87 | input_channels = input_channels,
88 | output_channels = output_channels,
89 | frame_size = frame_size,
90 | sample_rate = sample_rate,
91 | gain = gain,
92 | max_latency = max_latency,
93 | tx_encoder = tx_encoder,
94 | tx_device = tx_device,
95 | rx_encoder = rx_encoder,
96 | decoder = decoder,
97 | rx_device = rx_device,
98 | )
99 |
100 | def _encode(self, x):
101 | x = self.tx_encoder.encode(x)
102 | return self.tx_encoder.quantize(x)
103 |
104 | def _decode(self, x):
105 | x = self.rx_encoder.lookup(x)
106 | return self.decoder.decode(x)
107 |
108 |
109 | def assign_model(model):
110 | if model == 'libritts_v1':
111 | sample_rate = 24000
112 | tx_steps = 500000
113 | rx_steps = 500000
114 | encoder_checkpoint = os.path.join('AudioDec/exp', 'autoencoder', 'symAD_libritts_24000_hop300', f"checkpoint-{tx_steps}steps.pkl")
115 | decoder_checkpoint = os.path.join('AudioDec/exp', 'vocoder', 'AudioDec_v1_symAD_libritts_24000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl")
116 | elif model == 'libritts_sym':
117 | sample_rate = 24000
118 | tx_steps = 500000
119 | rx_steps = 1000000
120 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_libritts_24000_hop300', f"checkpoint-{tx_steps}steps.pkl")
121 | decoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_libritts_24000_hop300', f"checkpoint-{rx_steps}steps.pkl")
122 | elif model == 'vctk_v1':
123 | sample_rate = 48000
124 | tx_steps = 200000
125 | rx_steps = 500000
126 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl")
127 | decoder_checkpoint = os.path.join('exp', 'vocoder', 'AudioDec_v1_symAD_vctk_48000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl")
128 | elif model == 'vctk_sym':
129 | sample_rate = 48000
130 | tx_steps = 200000
131 | rx_steps = 700000
132 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl")
133 | decoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_vctk_48000_hop300', f"checkpoint-{rx_steps}steps.pkl")
134 | elif model == 'vctk_v0':
135 | sample_rate = 48000
136 | tx_steps = 200000
137 | rx_steps = 500000
138 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl")
139 | decoder_checkpoint = os.path.join('exp', 'vocoder', 'AudioDec_v0_symAD_vctk_48000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl")
140 | elif model == 'vctk_v2':
141 | sample_rate = 48000
142 | tx_steps = 200000
143 | rx_steps = 500000
144 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl")
145 | decoder_checkpoint = os.path.join('exp', 'vocoder', 'AudioDec_v2_symAD_vctk_48000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl")
146 | elif model == 'vctk_denoise':
147 | sample_rate = 48000
148 | tx_steps = 200000
149 | rx_steps = 500000
150 | encoder_checkpoint = os.path.join('exp', 'denoise', 'symAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl")
151 | decoder_checkpoint = os.path.join('exp', 'vocoder', 'AudioDec_v1_symAD_vctk_48000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl")
152 | elif model == 'vctk_univ':
153 | sample_rate = 48000
154 | tx_steps = 500000
155 | rx_steps = 500000
156 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symADuniv_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl")
157 | decoder_checkpoint = os.path.join('exp', 'vocoder', 'AudioDec_v3_symADuniv_vctk_48000_hop300_clean', f"checkpoint-{rx_steps}steps.pkl")
158 | elif model == 'vctk_univ_sym':
159 | sample_rate = 48000
160 | tx_steps = 500000
161 | rx_steps = 1000000
162 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symADuniv_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl")
163 | decoder_checkpoint = os.path.join('exp', 'autoencoder', 'symADuniv_vctk_48000_hop300', f"checkpoint-{rx_steps}steps.pkl")
164 | elif model == 'vctk_activate_sym':
165 | sample_rate = 48000
166 | tx_steps = 200000
167 | rx_steps = 700000
168 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAAD_vctk_48000_hop300', f"checkpoint-{tx_steps}steps.pkl")
169 | decoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAAD_vctk_48000_hop300', f"checkpoint-{rx_steps}steps.pkl")
170 | elif model == 'vctk_c16h320_sym':
171 | sample_rate = 48000
172 | tx_steps = 500000
173 | rx_steps = 1000000
174 | encoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_c16_vctk_48000_hop320', f"checkpoint-{tx_steps}steps.pkl")
175 | decoder_checkpoint = os.path.join('exp', 'autoencoder', 'symAD_c16_vctk_48000_hop320', f"checkpoint-{rx_steps}steps.pkl")
176 | else:
177 | raise NotImplementedError(f'Model {model} is not supported!')
178 |
179 | return sample_rate, encoder_checkpoint, decoder_checkpoint
180 |
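# Illustrative end-to-end setup using the helpers above; `load_transmitter` and
# `load_receiver` are inherited from AudioCodec (AudioDec/bin/stream.py), and the
# checkpoint paths assume the pretrained `exp/` folder has been downloaded:
#
#     sample_rate, enc_ckpt, dec_ckpt = assign_model('vctk_v1')
#     codec = AudioDec(tx_device='cpu', rx_device='cpu')
#     codec.load_transmitter(enc_ckpt)         # transmitter: encoder + projector + quantizer
#     codec.load_receiver(enc_ckpt, dec_ckpt)  # receiver: codebook lookup + vocoder decoder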
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VALL-E
2 |
3 | Inference code for [Wenetspeech4TTS/Audiodec-Valle-Wenetspeech4TTS](https://huggingface.co/Wenetspeech4TTS/Audiodec-Valle-Wenetspeech4TTS)
4 |
5 | ## Installation
6 |
7 | ``` bash
8 | git clone https://github.com/dukGuo/valle-audiodec.git
9 | cd valle-audiodec
10 | pip install -r requirements.txt
11 | ```
12 |
13 | ## Download pre-trained models
14 | ### AudioDec
15 | We use [AudioDec](https://github.com/facebookresearch/AudioDec/) as our speech tokenizer instead of [encodec](https://github.com/facebookresearch/encodec) to further improve audio quality.
16 |
17 | Please download the whole [exp](https://github.com/facebookresearch/AudioDec/releases/download/pretrain_models_v02/exp.zip) archive, unzip it, and place the resulting folder at `AudioDec/exp`:
18 |
19 | ```bash
20 | cd valle-audiodec
21 | wget https://github.com/facebookresearch/AudioDec/releases/download/pretrain_models_v02/exp.zip
22 | unzip exp.zip
23 | mv exp AudioDec/exp
24 | ```
25 |
26 | ### VALL-E
27 | Checkpoints are available on [Wenetspeech4TTS/Audiodec-Valle-Wenetspeech4TTS](https://huggingface.co/Wenetspeech4TTS/Audiodec-Valle-Wenetspeech4TTS); a download sketch follows the list below.
28 | 
29 | - VALL-E *Basic*: VALL-E trained on the WenetSpeech4TTS Basic subset
30 | - VALL-E *Standard*: VALL-E *Basic* fine-tuned on the WenetSpeech4TTS Standard subset
31 | - VALL-E *Premium*: VALL-E *Standard* fine-tuned on the WenetSpeech4TTS Premium subset
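
One way to fetch the checkpoints is via `huggingface_hub` (a sketch; the exact file layout inside the repository, and whether it matches the `ckpt/basic/*.pt` paths used below, should be verified after downloading):

```python
from huggingface_hub import snapshot_download

# Download the whole model repository into ./ckpt
snapshot_download(
    repo_id="Wenetspeech4TTS/Audiodec-Valle-Wenetspeech4TTS",
    local_dir="ckpt",
)
```
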
32 | ## Speech Samples
33 |
34 | https://wenetspeech4tts.github.io/wenetspeech4tts
35 |
36 | https://rxy-j.github.io/HPMD-TTS
37 |
38 | ## Inference
39 |
40 | ``` bash
41 | cd valle-audiodec
42 | python infer_tts.py \
43 | --config config/hparams.yaml \
44 | --ar_ckpt ckpt/basic/ar.pt \
45 | --nar_ckpt ckpt/basic/nar.pt \
46 | --prompt_wav test/prompt_wavs/test_1.wav \
47 | --prompt_text 在夏日阴凉的树荫下,鸭妈妈孵着鸭宝宝。 \
48 | --text 负责指挥的将军在一旁交代着注意事项,每个人在上面最多只能待九十秒。
49 | ```
50 |
51 | > To improve audio quality and keep volume levels consistent across inputs, normalize the loudness of the prompt waveform before running inference, e.g.:
52 | > ```bash
53 | > sox $in_wave -r $sample_rate -b 16 --norm=-6 $out_wave
54 | > ```
55 |
56 | ## References
57 | This repository is developed based on the following repositories.
58 |
59 | - [facebookresearch/AudioDec](https://github.com/facebookresearch/AudioDec)
60 | - [lifeiteng/vall-e](https://github.com/lifeiteng/vall-e)
61 | - [fishaudio/Bert-VITS2](https://github.com/fishaudio/Bert-VITS2)
62 |
--------------------------------------------------------------------------------
/config/hparams.yaml:
--------------------------------------------------------------------------------
1 | hparams:
2 |
3 | # input_settings:
4 | num_semantic: 216 # pinyin
5 | num_acoustic: 1024 # codebook size of the codec
6 | acoustic_num_layer: 8 # number of codec codebooks (layers)
7 |
8 | # Training_settings:
9 | batch_size: 8
10 | num_workers: 16
11 | save_checkpoint_step: 5000
12 | learning_rate: 0.0002
13 | max_training_steps: 4000000
14 | temperature: 1.0
15 | grad_accu: 40
16 | dist_backend: "nccl" # distributed training setting
17 | dist_url: "tcp://localhost:12345"
18 |
19 | # AR setting:
20 | GPT2_vocab_size: 1245 # padding + semantic + acoustic + 2 eos
21 | GPT2_n_positions: 2048 # maximum supported sequence length
22 | GPT2_n_ctx: 2048 # same as n_positions
23 | GPT2_n_embd: 1024 # hidden dimension
24 | GPT2_n_layer: 12 # number of layers
25 | GPT2_n_head: 16 # number of attention heads
26 |
27 | # NAR setting:
28 | nar_vocab_size: 1245 # padding+semantic+acoustic+2eos
29 | nar_n_emb: 1024
30 | nar_n_layer: 12
31 | nar_n_head: 16
--------------------------------------------------------------------------------
/infer_tts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from tqdm import tqdm
5 | import time
6 | import torchaudio
7 | import soundfile as sf
8 | import shutil
9 |
10 | from models.ar import AR
11 | from models.nar import NAR
12 | from utils.utils import *
13 | from text.chinese import generate_token
14 | from AudioDec.utils.audiodec import AudioDec, assign_model
15 |
16 | def tokenize_wav(wav_path,audiodec,device,sample_rate=24000):
17 |
18 | wav, sr = torchaudio.load(wav_path)
19 | if sr != sample_rate:
20 | wav = torchaudio.functional.resample(wav, sr, sample_rate)
21 | with torch.no_grad():
22 |         wav = wav.unsqueeze(1) # (C, T) -> (C, 1, T); a mono prompt gives (1, 1, T)
23 | wav = wav.float().to(device)
24 | z = audiodec.tx_encoder.encode(wav)
25 | idx = audiodec.tx_encoder.quantize(z)
26 |
27 |     inc = torch.arange(8)*1024 # per-codebook index offsets (8 codebooks x 1024 entries)
28 | idx = idx.cpu() - inc.reshape(-1,1)
29 | return idx.numpy().T
30 |
31 |
32 | def do_tts(hp,args):
33 |
34 |
35 | # init codec
36 | model_name = "libritts_v1"
37 | device = args.device
38 | sample_rate, encoder_checkpoint, decoder_checkpoint = assign_model(model_name)
39 |     audiodec = AudioDec(tx_device=device, rx_device=device)
40 | audiodec.load_transmitter(encoder_checkpoint)
41 | audiodec.load_receiver(encoder_checkpoint, decoder_checkpoint)
42 |
43 | # init valle
44 | ar = AR(hp,device)
45 | ar_ckpt = torch.load(args.ar_ckpt, map_location=device)
46 | ar.load_state_dict(ar_ckpt['model'])
47 | ar.eval()
48 | nar = NAR(hp, device).to(device)
49 | nar_ckpt = torch.load(args.nar_ckpt, map_location=device)
50 | nar_weight = nar_ckpt["model"]
51 | if list(nar_weight.keys())[0].startswith('_orig'):
52 | nar_weight = {}
53 | for k, v in nar_ckpt["model"].items():
54 | nar_weight[k.split("_orig_mod.")[1]] = v
55 | nar.load_state_dict(nar_weight)
56 | nar.eval()
57 |
58 | # prepare data
59 | prompt_text = np.array(generate_token(args.prompt_text))
60 | prompt_token = tokenize_wav(args.prompt_wav,audiodec,device,sample_rate)
61 | text = np.array(generate_token(args.text))
62 |
63 | semantic_token = torch.from_numpy(np.concatenate((prompt_text,text)))
64 | t_size, dim_size = prompt_token.shape
65 | semantic_token = semantic_token + 1 # padding
66 | acoustic_token = prompt_token + 1 + hp.num_semantic # padding + len(text)
67 |
68 | acoustic_len = prompt_token.shape[0]
69 | semantic_len = semantic_token.shape[0]
70 | max_len = 2040 - semantic_len-acoustic_len
71 | eos_id = hp.num_semantic + hp.num_acoustic + 1
72 |
73 | first_idx = np.asarray(list(semantic_token) + [eos_id] + list(acoustic_token[:, 0]))
74 | full_semantic = np.stack([semantic_token] * dim_size, axis=1) # t,8
75 | eos_full = np.stack([np.asarray([eos_id, ])] * dim_size, axis=1) # 1, 8
76 | full_idx = np.concatenate([full_semantic, eos_full, acoustic_token], axis=0)
77 |
78 |
79 | pos_ids= np.asarray(list(range(semantic_len + 1)) + list(range(acoustic_len)))
80 | tags = np.asarray([1] * (semantic_len + 1) + [2] * (acoustic_len))
81 |
82 | data = {}
83 | data["first_idx"] = torch.from_numpy(np.asarray(first_idx)).unsqueeze(0).to(device)
84 | data["seq_lens"] = torch.from_numpy(np.asarray(first_idx.shape[0])).unsqueeze(0).to(device)
85 | data["full_idx"] = torch.from_numpy(np.asarray(full_idx)).unsqueeze(0).to(device)
86 | data["pos_ids"] = torch.from_numpy(np.asarray(pos_ids)).unsqueeze(0).to(device)
87 | data["tags"] = torch.from_numpy(np.asarray(tags)).unsqueeze(0).to(device)
88 |
89 |
90 | # infer
91 | with torch.no_grad():
92 | first_idx = ar.inference(data,max_len,20)
93 | full_idx = nar.inference(data)
94 | full_idx = (full_idx - (hp.num_semantic+1))[:, :, :-1]
95 |
96 | prompt_idx = full_idx[:,:,:data['prompt_len']]
97 | full_idx = full_idx[:,:,data['prompt_len']:]
98 | full_idx = full_idx.squeeze(0)
99 |
100 | inc = (torch.arange(8)*1024).to(device)
101 | full_idx = full_idx + inc.reshape(-1,1)
102 | zq = audiodec.rx_encoder.lookup(full_idx)
103 |
104 | try:
105 | res_wav = audiodec.decoder.decode(zq)
106 | res_wav = res_wav.squeeze(1).transpose(1, 0).cpu().numpy()
107 |     except Exception as e:
108 |         print(f'error in reconstruction: {e}'); return
109 |
110 | sf.write(
111 | './test/syn.wav',
112 | res_wav,
113 | sample_rate,
114 | "PCM_16",)
115 |
116 |
117 | if __name__ == '__main__':
118 |
119 | import argparse
120 |
121 | parser = argparse.ArgumentParser()
122 |
123 | parser.add_argument('--config',
124 | type=str,
125 | default='config/hparams.yaml',
126 | help='config path')
127 |
128 | parser.add_argument('--ar_ckpt',
129 | type=str,
130 | default='ckpt/basic/ar.pt',
131 | help='AR ckpt path')
132 |
133 | parser.add_argument('--nar_ckpt',
134 | type=str,
135 | default='ckpt/basic/nar.pt',
136 | help='NAR ckpt path')
137 | parser.add_argument('--prompt_wav',
138 | type=str,
139 | default='test/prompt_wavs/test_1.wav',
140 |                         help='prompt wav path')
141 | parser.add_argument('--prompt_text',
142 | type=str,
143 | default='在夏日阴凉的树荫下,鸭妈妈孵着鸭宝宝。',
144 |                         help='prompt text (transcript of the prompt wav)')
145 | parser.add_argument('--text',
146 | type=str,
147 | default='这一段戏同样也表达了亚瑟说的,我原本以为我的人生是一出悲剧,但其实它是一出戏剧。',
148 |                         help='text to synthesize')
149 | parser.add_argument('--device',
150 | type=str,
151 | default='cpu',
152 | help='cpu or cuda')
153 |
154 | args = parser.parse_args()
155 | hp = get_config_from_file(args.config).hparams
156 |
157 | do_tts(hp,args)
--------------------------------------------------------------------------------
/models/ar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from transformers import GPT2Config, GPT2LMHeadModel
5 | import torch.nn.functional as F
6 | from tqdm import tqdm
7 |
8 |
9 | class AR(nn.Module):
10 | def __init__(self, hparams, device):
11 | super().__init__()
12 | self.hparams = hparams
13 | self.device = device
14 | self.lm_model = GPT2LMHeadModel(config=self._model_config()).to(device)
15 |
16 | def _model_config(self):
17 | return GPT2Config(
18 | vocab_size=self.hparams.GPT2_vocab_size,
19 | n_positions=self.hparams.GPT2_n_positions,
20 | n_ctx=self.hparams.GPT2_n_ctx,
21 | n_embd=self.hparams.GPT2_n_embd,
22 | n_layer=self.hparams.GPT2_n_layer,
23 | n_head=self.hparams.GPT2_n_head,
24 | )
25 |
26 |
27 | def inference(self, data,max_len=None,topk=None):
28 | idx = data['first_idx']
29 | pos_ids = data['pos_ids']
30 | tags = data['tags']
31 | semantic_len = (tags == 1).sum(dim=1)
32 | prompt_acoustic_len = (tags == 2).sum(dim=1)
33 | data['prompt_len'] = prompt_acoustic_len
34 |
35 |         if max_len is None:
36 | max_len = int((semantic_len - 1) * 320 / 16000 * 24000 / 300 - prompt_acoustic_len) + 3
37 |
38 | for j in tqdm(range(max_len)):
39 | lm_outputs = self.lm_model(
40 | input_ids=idx,
41 | attention_mask=None,
42 | position_ids=pos_ids,
43 | )
44 | logits = lm_outputs['logits']
45 | logits = logits * self.hparams.temperature
46 | logits[:, :, 0:(self.hparams.num_semantic+1)] = -float('Inf') # semantic token
47 | logits[:, :, self.hparams.num_semantic + 1 + self.hparams.num_acoustic] =-float('Inf') # eos
48 |
49 | logits=logits[:, -1, :]
50 | if topk is not None:
51 | v, _ = torch.topk(logits, min(topk, logits.size(-1)))
52 | logits[logits < v[:, [-1]]] = -float('Inf')
53 |
54 |
55 | probs = logits.softmax(dim=-1) # [b, d]
56 | dist = torch.distributions.categorical.Categorical(probs=probs)
57 | samples = dist.sample().unsqueeze(0)
58 |
59 | idx = torch.cat([idx, samples], dim=1) # [b, t]
60 | pos_ids = torch.cat([pos_ids, pos_ids[:, -1:] + 1], dim=1) # [b, t]
61 | tags = torch.cat([tags, torch.zeros_like(tags[:, -1:]) + 2], dim=1)
62 | if max_len is not None:
63 | if samples.item() == self.hparams.num_semantic + 1 + self.hparams.num_acoustic +1: # eos
64 | break
65 | if j == max_len-1:
66 | # too long break
67 | samples[:,:] = self.hparams.num_semantic + 1 + self.hparams.num_acoustic +1
68 | idx = torch.cat([idx, samples], dim=1)
69 | pos_ids = torch.cat([pos_ids, pos_ids[:, -1:] + 1], dim=1) # [b, t]
70 | tags = torch.cat([tags, torch.zeros_like(tags[:, -1:]) + 2], dim=1)
71 | break
72 | data['tags'] = tags
73 | data['pos_ids'] = pos_ids
74 | data['first_idx'] = idx
75 | return idx
76 |
--------------------------------------------------------------------------------
/models/nar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from torch.nn import TransformerEncoderLayer
5 | import random
6 | import torch.nn.functional as F
7 | import math
8 |
9 | class new_gelu(nn.Module):
10 | def forward(self, x):
11 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
12 |
13 | class NAR(nn.Module):
14 | def __init__(self, hparams, device):
15 | super().__init__()
16 | self.hparams = hparams
17 | self.device = device
18 |
19 |         self.token_embeddings = nn.ModuleList([nn.Embedding(self.hparams.nar_vocab_size, embedding_dim=self.hparams.nar_n_emb, padding_idx=0)] * self.hparams.acoustic_num_layer)  # note: [module] * n repeats the same Embedding instance, so all codebook layers share one embedding table
20 | self.wpe = nn.Embedding(2048, self.hparams.nar_n_emb)
21 | self.dense_layer = nn.Linear(self.hparams.nar_n_emb * self.hparams.acoustic_num_layer, self.hparams.nar_n_emb, bias=False)
22 |
23 | self.res_embedding = nn.Embedding(self.hparams.acoustic_num_layer - 1, self.hparams.nar_n_emb)
24 |
25 | activation = new_gelu()
26 | self.transformer_layers = nn.ModuleList()
27 | for i in range(self.hparams.nar_n_layer):
28 | self.transformer_layers.append(
29 | TransformerEncoderLayer(
30 | d_model=self.hparams.nar_n_emb,
31 | nhead=self.hparams.nar_n_head ,
32 | dim_feedforward=self.hparams.nar_n_emb * 4,
33 | dropout=0.1,
34 | activation=activation,
35 | layer_norm_eps=1e-05,
36 | batch_first=True,
37 | )
38 | )
39 | self.mlp_layer = nn.Linear(self.hparams.nar_n_emb, self.hparams.num_acoustic + 2, bias=False) # quant_token_num + eos
40 |
41 |
42 | def inference(self, data):
43 | # unpack data
44 | init_full_idx = data['full_idx']
45 | pos_ids = data['pos_ids']
46 | tags = data['tags']
47 | prompt_len = data['prompt_len']
48 |
49 | # Calculate lengths
50 | semantic_len = (tags == 1).sum(dim=1)
51 | acoustic_len = (tags == 2).sum(dim=1)
52 |
53 | # Initialize the sequence tensor with indices for each acoustic layer
54 | full_idx = torch.stack([data['first_idx']] * self.hparams.acoustic_num_layer, dim=1)
55 | full_idx[:, :, semantic_len[0]:semantic_len[0] + prompt_len[0]] = init_full_idx[:,semantic_len[0]:, :].transpose(1, 2)
56 |
57 | batch_size, layer_size, t_size = full_idx.size()
58 |
59 | # Create masks
60 | layer_index = torch.ones(size=[batch_size,], device=self.device)
61 | layer_mask = layer_index.unsqueeze(1) > torch.arange(self.hparams.acoustic_num_layer, device=self.device).unsqueeze(0)
62 | prompt_mask = (semantic_len + prompt_len).unsqueeze(1) > torch.arange(t_size, device=self.device).unsqueeze(0)
63 |
64 | # Combine masks
65 | mask = layer_mask.unsqueeze(2) + prompt_mask.unsqueeze(1)
66 | full_idx = torch.where(mask, full_idx, torch.zeros_like(full_idx))
67 |
68 | for layer_index in range(1, self.hparams.acoustic_num_layer):
69 | layer_index_tensor = torch.LongTensor([layer_index, ]).repeat(batch_size, ).to(self.device)
70 | layer_mask = layer_index_tensor.unsqueeze(1) > torch.arange(self.hparams.acoustic_num_layer, device=self.device).unsqueeze(0)
71 |
72 | # Apply masks
73 | mask = layer_mask.unsqueeze(2) + prompt_mask.unsqueeze(1)
74 | mask_full_idx = torch.where(mask, full_idx, torch.zeros_like(full_idx))
75 |
76 | # Generate embeddings
77 | embeddings = [self.token_embeddings[i](mask_full_idx[:, i, :]) for i in range(layer_size)]
78 | embeddings = torch.cat(embeddings, dim=-1)
79 | embeddings = self.dense_layer(embeddings)
80 |
81 | # Add layer and position embeddings
82 | res_embeddings = self.res_embedding(layer_index_tensor - 1).unsqueeze(1)
83 | outputs = embeddings + self.wpe(pos_ids) + res_embeddings
84 |
85 | # Apply transformer layers
86 | for layer in self.transformer_layers:
87 | outputs = layer(outputs)
88 |
89 | logits = self.mlp_layer(outputs)
90 | logits[:, :, -2:] = -float('Inf') # unused token
91 |
92 | # NAR greedy search
93 | samples = torch.argmax(logits, dim=-1) + self.hparams.num_semantic + 1
94 | samples = torch.where(prompt_mask, full_idx[:, layer_index, :], samples)
95 | full_idx[:, layer_index, :] = samples
96 |
97 | return full_idx[:, :, semantic_len[0]:]
98 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.25.1
2 | numba==0.55.1
3 | tqdm
4 | julius
5 | tabulate
6 | einops
7 | pydub
8 | pypinyin
9 | cn2an
10 | jieba
11 | soundfile
12 | sounddevice
13 | numpy
14 | torch
15 | torchaudio
16 | librosa
17 | argparse
18 | pyyaml
19 |
--------------------------------------------------------------------------------
/test/prompt_wavs/test_1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dukGuo/valle-audiodec/cbd51784ea9fe0bee05b4a65c95f39523b57be89/test/prompt_wavs/test_1.wav
--------------------------------------------------------------------------------
/text/chinese.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | import cn2an
5 | from pypinyin import lazy_pinyin, Style
6 |
7 | from text.symbols import punctuation
8 | from text.tone_sandhi import ToneSandhi
9 |
10 | current_file_path = os.path.dirname(__file__)
11 | pinyin_to_symbol_map = {
12 | line.split("\t")[0]: line.strip().split("\t")[1]
13 | for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
14 | }
15 |
16 | import jieba.posseg as psg
17 |
18 | finals= ['a','ai','an','ang','ao','e','ei','en','eng','er','i','i0','ia','ian','iang','iao','ie','in','ing','iong','ir','iu','o','ong','ou','u','ua','uai','uan','uang','ui','un','uo','v','van','ve','vn']
19 |
20 | rep_map = {
21 | ":": ",",
22 | ";": ",",
23 | ",": ",",
24 | "。": ".",
25 | "!": "!",
26 | "?": "?",
27 | "\n": ".",
28 | "·": ",",
29 | "、": ",",
30 | "...": "…",
31 | "$": ".",
32 | "“": "'",
33 | "”": "'",
34 | '"': "'",
35 | "‘": "'",
36 | "’": "'",
37 | "(": "'",
38 | ")": "'",
39 | "(": "'",
40 | ")": "'",
41 | "《": "'",
42 | "》": "'",
43 | "【": "'",
44 | "】": "'",
45 | "[": "'",
46 | "]": "'",
47 | "—": "-",
48 | "~": "-",
49 | "~": "-",
50 | "「": "'",
51 | "」": "'",
52 | }
53 |
54 | tone_modifier = ToneSandhi()
55 |
56 |
57 | def replace_punctuation(text):
58 | text = text.replace("嗯", "恩").replace("呣", "母")
59 | pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
60 |
61 | replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
62 |
63 | replaced_text = re.sub(
64 | r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
65 | )
66 |
67 | return replaced_text
68 |
69 |
70 | def g2p(text):
71 | pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
72 | sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
73 | phones, tones, word2ph = _g2p(sentences)
74 | assert sum(word2ph) == len(phones)
75 |     assert len(word2ph) == len(text)  # This can occasionally fail; wrap the call in try/except if needed.
76 | phones = ["_"] + phones + ["_"]
77 | tones = [0] + tones + [0]
78 | word2ph = [1] + word2ph + [1]
79 | return phones, tones, word2ph
80 |
81 |
82 | def _get_initials_finals(word):
83 | initials = []
84 | finals = []
85 | orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
86 | orig_finals = lazy_pinyin(
87 | word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
88 | )
89 | for c, v in zip(orig_initials, orig_finals):
90 | initials.append(c)
91 | finals.append(v)
92 | return initials, finals
93 |
94 |
95 | def _g2p(segments):
96 | phones_list = []
97 | tones_list = []
98 | word2ph = []
99 | for seg in segments:
100 |         # Remove all English words from the sentence
101 | seg = re.sub("[a-zA-Z]+", "", seg)
102 | seg_cut = psg.lcut(seg)
103 | initials = []
104 | finals = []
105 | seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
106 | for word, pos in seg_cut:
107 | if pos == "eng":
108 | continue
109 | sub_initials, sub_finals = _get_initials_finals(word)
110 | sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
111 | initials.append(sub_initials)
112 | finals.append(sub_finals)
113 |
114 | # assert len(sub_initials) == len(sub_finals) == len(word)
115 | initials = sum(initials, [])
116 | finals = sum(finals, [])
117 | #
118 | for c, v in zip(initials, finals):
119 | raw_pinyin = c + v
120 | # NOTE: post process for pypinyin outputs
122 |                 # we distinguish between i, ii and iii
122 | if c == v:
123 | assert c in punctuation
124 | phone = [c]
125 | tone = "0"
126 | word2ph.append(1)
127 | else:
128 | v_without_tone = v[:-1]
129 | tone = v[-1]
130 |
131 | pinyin = c + v_without_tone
132 | assert tone in "12345"
133 |
134 | if c:
135 |                     # syllable has an initial consonant
136 | v_rep_map = {
137 | "uei": "ui",
138 | "iou": "iu",
139 | "uen": "un",
140 | }
141 | if v_without_tone in v_rep_map.keys():
142 | pinyin = c + v_rep_map[v_without_tone]
143 | else:
144 |                     # standalone final with no initial
145 | pinyin_rep_map = {
146 | "ing": "ying",
147 | "i": "yi",
148 | "in": "yin",
149 | "u": "wu",
150 | }
151 | if pinyin in pinyin_rep_map.keys():
152 | pinyin = pinyin_rep_map[pinyin]
153 | else:
154 | single_rep_map = {
155 | "v": "yu",
156 | "e": "e",
157 | "i": "y",
158 | "u": "w",
159 | }
160 | if pinyin[0] in single_rep_map.keys():
161 | pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
162 |
163 | assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
164 | phone = pinyin_to_symbol_map[pinyin].split(" ")
165 | word2ph.append(len(phone))
166 |
167 | phones_list += phone
168 | tones_list += [int(tone)] * len(phone)
169 | return phones_list, tones_list, word2ph
170 |
171 |
172 | def text_normalize(text):
173 | numbers = re.findall(r"\d+(?:\.?\d+)?", text)
174 | for number in numbers:
175 | text = text.replace(number, cn2an.an2cn(number), 1)
176 | text = replace_punctuation(text)
177 | return text
178 |
179 | def generate_token(text):
180 |
181 | phone_set = "text/chinese_dict"
182 | with open(phone_set) as ps:
183 | phone_map = { pho.strip():idx+1 for idx,pho in enumerate(ps.readlines())}
184 |
185 | norm_text = text_normalize(text)
186 | phones, tones, word2ph = g2p(norm_text)
187 |
188 | final_phone_lst = []
189 | for phone,tone in zip(phones,tones):
190 | if '_' == phone:
191 | continue
192 | if phone in finals:
193 | phone = phone+str(tone)
194 | if phone == '!' or phone == '…':
195 | phone = '.'
196 | if phone == '-':
197 | phone = ','
198 | final_phone_lst.append(phone.strip())
199 |     if 'iong4' in final_phone_lst: # iong4 does not exist in the default phone set
200 | assert 0
201 | token_lst = [phone_map[phone] for phone in final_phone_lst]
202 |
203 | return token_lst
204 |
--------------------------------------------------------------------------------
/text/chinese_dict:
--------------------------------------------------------------------------------
1 | '
2 | ,
3 | .
4 | ?
5 | AA
6 | E
7 | EE
8 | En
9 | OO
10 | a1
11 | a2
12 | a3
13 | a4
14 | a5
15 | ai1
16 | ai2
17 | ai3
18 | ai4
19 | ai5
20 | an1
21 | an2
22 | an3
23 | an4
24 | an5
25 | ang1
26 | ang2
27 | ang3
28 | ang4
29 | ang5
30 | ao1
31 | ao2
32 | ao3
33 | ao4
34 | ao5
35 | b
36 | c
37 | ch
38 | d
39 | e1
40 | e2
41 | e3
42 | e4
43 | e5
44 | ei1
45 | ei2
46 | ei3
47 | ei4
48 | ei5
49 | en1
50 | en2
51 | en3
52 | en4
53 | en5
54 | eng1
55 | eng2
56 | eng3
57 | eng4
58 | eng5
59 | er2
60 | er3
61 | er4
62 | er5
63 | f
64 | g
65 | h
66 | i01
67 | i02
68 | i03
69 | i04
70 | i05
71 | i1
72 | i2
73 | i3
74 | i4
75 | i5
76 | ia1
77 | ia2
78 | ia3
79 | ia4
80 | ia5
81 | ian1
82 | ian2
83 | ian3
84 | ian4
85 | ian5
86 | iang1
87 | iang2
88 | iang3
89 | iang4
90 | iang5
91 | iao1
92 | iao2
93 | iao3
94 | iao4
95 | iao5
96 | ie1
97 | ie2
98 | ie3
99 | ie4
100 | ie5
101 | in1
102 | in2
103 | in3
104 | in4
105 | in5
106 | ing1
107 | ing2
108 | ing3
109 | ing4
110 | ing5
111 | iong1
112 | iong2
113 | iong3
114 | iong5
115 | ir1
116 | ir2
117 | ir3
118 | ir4
119 | ir5
120 | iu1
121 | iu2
122 | iu3
123 | iu4
124 | iu5
125 | j
126 | k
127 | l
128 | m
129 | n
130 | o1
131 | o2
132 | o3
133 | o4
134 | o5
135 | ong1
136 | ong2
137 | ong3
138 | ong4
139 | ong5
140 | ou1
141 | ou2
142 | ou3
143 | ou4
144 | ou5
145 | p
146 | q
147 | r
148 | s
149 | sh
150 | t
151 | u1
152 | u2
153 | u3
154 | u4
155 | u5
156 | ua1
157 | ua2
158 | ua3
159 | ua4
160 | ua5
161 | uai1
162 | uai2
163 | uai3
164 | uai4
165 | uai5
166 | uan1
167 | uan2
168 | uan3
169 | uan4
170 | uan5
171 | uang1
172 | uang2
173 | uang3
174 | uang4
175 | uang5
176 | ui1
177 | ui2
178 | ui3
179 | ui4
180 | ui5
181 | un1
182 | un2
183 | un3
184 | un4
185 | un5
186 | uo1
187 | uo2
188 | uo3
189 | uo4
190 | uo5
191 | v1
192 | v2
193 | v3
194 | v4
195 | v5
196 | van1
197 | van2
198 | van3
199 | van4
200 | van5
201 | ve1
202 | ve2
203 | ve3
204 | ve4
205 | ve5
206 | vn1
207 | vn2
208 | vn3
209 | vn4
210 | vn5
211 | w
212 | x
213 | y
214 | z
215 | zh
--------------------------------------------------------------------------------
/text/opencpop-strict.txt:
--------------------------------------------------------------------------------
1 | a AA a
2 | ai AA ai
3 | an AA an
4 | ang AA ang
5 | ao AA ao
6 | ba b a
7 | bai b ai
8 | ban b an
9 | bang b ang
10 | bao b ao
11 | bei b ei
12 | ben b en
13 | beng b eng
14 | bi b i
15 | bian b ian
16 | biao b iao
17 | bie b ie
18 | bin b in
19 | bing b ing
20 | bo b o
21 | bu b u
22 | ca c a
23 | cai c ai
24 | can c an
25 | cang c ang
26 | cao c ao
27 | ce c e
28 | cei c ei
29 | cen c en
30 | ceng c eng
31 | cha ch a
32 | chai ch ai
33 | chan ch an
34 | chang ch ang
35 | chao ch ao
36 | che ch e
37 | chen ch en
38 | cheng ch eng
39 | chi ch ir
40 | chong ch ong
41 | chou ch ou
42 | chu ch u
43 | chua ch ua
44 | chuai ch uai
45 | chuan ch uan
46 | chuang ch uang
47 | chui ch ui
48 | chun ch un
49 | chuo ch uo
50 | ci c i0
51 | cong c ong
52 | cou c ou
53 | cu c u
54 | cuan c uan
55 | cui c ui
56 | cun c un
57 | cuo c uo
58 | da d a
59 | dai d ai
60 | dan d an
61 | dang d ang
62 | dao d ao
63 | de d e
64 | dei d ei
65 | den d en
66 | deng d eng
67 | di d i
68 | dia d ia
69 | dian d ian
70 | diao d iao
71 | die d ie
72 | ding d ing
73 | diu d iu
74 | dong d ong
75 | dou d ou
76 | du d u
77 | duan d uan
78 | dui d ui
79 | dun d un
80 | duo d uo
81 | e EE e
82 | ei EE ei
83 | en EE en
84 | eng EE eng
85 | er EE er
86 | fa f a
87 | fan f an
88 | fang f ang
89 | fei f ei
90 | fen f en
91 | feng f eng
92 | fo f o
93 | fou f ou
94 | fu f u
95 | ga g a
96 | gai g ai
97 | gan g an
98 | gang g ang
99 | gao g ao
100 | ge g e
101 | gei g ei
102 | gen g en
103 | geng g eng
104 | gong g ong
105 | gou g ou
106 | gu g u
107 | gua g ua
108 | guai g uai
109 | guan g uan
110 | guang g uang
111 | gui g ui
112 | gun g un
113 | guo g uo
114 | ha h a
115 | hai h ai
116 | han h an
117 | hang h ang
118 | hao h ao
119 | he h e
120 | hei h ei
121 | hen h en
122 | heng h eng
123 | hong h ong
124 | hou h ou
125 | hu h u
126 | hua h ua
127 | huai h uai
128 | huan h uan
129 | huang h uang
130 | hui h ui
131 | hun h un
132 | huo h uo
133 | ji j i
134 | jia j ia
135 | jian j ian
136 | jiang j iang
137 | jiao j iao
138 | jie j ie
139 | jin j in
140 | jing j ing
141 | jiong j iong
142 | jiu j iu
143 | ju j v
144 | jv j v
145 | juan j van
146 | jvan j van
147 | jue j ve
148 | jve j ve
149 | jun j vn
150 | jvn j vn
151 | ka k a
152 | kai k ai
153 | kan k an
154 | kang k ang
155 | kao k ao
156 | ke k e
157 | kei k ei
158 | ken k en
159 | keng k eng
160 | kong k ong
161 | kou k ou
162 | ku k u
163 | kua k ua
164 | kuai k uai
165 | kuan k uan
166 | kuang k uang
167 | kui k ui
168 | kun k un
169 | kuo k uo
170 | la l a
171 | lai l ai
172 | lan l an
173 | lang l ang
174 | lao l ao
175 | le l e
176 | lei l ei
177 | leng l eng
178 | li l i
179 | lia l ia
180 | lian l ian
181 | liang l iang
182 | liao l iao
183 | lie l ie
184 | lin l in
185 | ling l ing
186 | liu l iu
187 | lo l o
188 | long l ong
189 | lou l ou
190 | lu l u
191 | luan l uan
192 | lun l un
193 | luo l uo
194 | lv l v
195 | lve l ve
196 | ma m a
197 | mai m ai
198 | man m an
199 | mang m ang
200 | mao m ao
201 | me m e
202 | mei m ei
203 | men m en
204 | meng m eng
205 | mi m i
206 | mian m ian
207 | miao m iao
208 | mie m ie
209 | min m in
210 | ming m ing
211 | miu m iu
212 | mo m o
213 | mou m ou
214 | mu m u
215 | na n a
216 | nai n ai
217 | nan n an
218 | nang n ang
219 | nao n ao
220 | ne n e
221 | nei n ei
222 | nen n en
223 | neng n eng
224 | ni n i
225 | nian n ian
226 | niang n iang
227 | niao n iao
228 | nie n ie
229 | nin n in
230 | ning n ing
231 | niu n iu
232 | nong n ong
233 | nou n ou
234 | nu n u
235 | nuan n uan
236 | nun n un
237 | nuo n uo
238 | nv n v
239 | nve n ve
240 | o OO o
241 | ou OO ou
242 | pa p a
243 | pai p ai
244 | pan p an
245 | pang p ang
246 | pao p ao
247 | pei p ei
248 | pen p en
249 | peng p eng
250 | pi p i
251 | pian p ian
252 | piao p iao
253 | pie p ie
254 | pin p in
255 | ping p ing
256 | po p o
257 | pou p ou
258 | pu p u
259 | qi q i
260 | qia q ia
261 | qian q ian
262 | qiang q iang
263 | qiao q iao
264 | qie q ie
265 | qin q in
266 | qing q ing
267 | qiong q iong
268 | qiu q iu
269 | qu q v
270 | qv q v
271 | quan q van
272 | qvan q van
273 | que q ve
274 | qve q ve
275 | qun q vn
276 | qvn q vn
277 | ran r an
278 | rang r ang
279 | rao r ao
280 | re r e
281 | ren r en
282 | reng r eng
283 | ri r ir
284 | rong r ong
285 | rou r ou
286 | ru r u
287 | rua r ua
288 | ruan r uan
289 | rui r ui
290 | run r un
291 | ruo r uo
292 | sa s a
293 | sai s ai
294 | san s an
295 | sang s ang
296 | sao s ao
297 | se s e
298 | sen s en
299 | seng s eng
300 | sha sh a
301 | shai sh ai
302 | shan sh an
303 | shang sh ang
304 | shao sh ao
305 | she sh e
306 | shei sh ei
307 | shen sh en
308 | sheng sh eng
309 | shi sh ir
310 | shou sh ou
311 | shu sh u
312 | shua sh ua
313 | shuai sh uai
314 | shuan sh uan
315 | shuang sh uang
316 | shui sh ui
317 | shun sh un
318 | shuo sh uo
319 | si s i0
320 | song s ong
321 | sou s ou
322 | su s u
323 | suan s uan
324 | sui s ui
325 | sun s un
326 | suo s uo
327 | ta t a
328 | tai t ai
329 | tan t an
330 | tang t ang
331 | tao t ao
332 | te t e
333 | tei t ei
334 | teng t eng
335 | ti t i
336 | tian t ian
337 | tiao t iao
338 | tie t ie
339 | ting t ing
340 | tong t ong
341 | tou t ou
342 | tu t u
343 | tuan t uan
344 | tui t ui
345 | tun t un
346 | tuo t uo
347 | wa w a
348 | wai w ai
349 | wan w an
350 | wang w ang
351 | wei w ei
352 | wen w en
353 | weng w eng
354 | wo w o
355 | wu w u
356 | xi x i
357 | xia x ia
358 | xian x ian
359 | xiang x iang
360 | xiao x iao
361 | xie x ie
362 | xin x in
363 | xing x ing
364 | xiong x iong
365 | xiu x iu
366 | xu x v
367 | xv x v
368 | xuan x van
369 | xvan x van
370 | xue x ve
371 | xve x ve
372 | xun x vn
373 | xvn x vn
374 | ya y a
375 | yan y En
376 | yang y ang
377 | yao y ao
378 | ye y E
379 | yi y i
380 | yin y in
381 | ying y ing
382 | yo y o
383 | yong y ong
384 | you y ou
385 | yu y v
386 | yv y v
387 | yuan y van
388 | yvan y van
389 | yue y ve
390 | yve y ve
391 | yun y vn
392 | yvn y vn
393 | za z a
394 | zai z ai
395 | zan z an
396 | zang z ang
397 | zao z ao
398 | ze z e
399 | zei z ei
400 | zen z en
401 | zeng z eng
402 | zha zh a
403 | zhai zh ai
404 | zhan zh an
405 | zhang zh ang
406 | zhao zh ao
407 | zhe zh e
408 | zhei zh ei
409 | zhen zh en
410 | zheng zh eng
411 | zhi zh ir
412 | zhong zh ong
413 | zhou zh ou
414 | zhu zh u
415 | zhua zh ua
416 | zhuai zh uai
417 | zhuan zh uan
418 | zhuang zh uang
419 | zhui zh ui
420 | zhun zh un
421 | zhuo zh uo
422 | zi z i0
423 | zong z ong
424 | zou z ou
425 | zu z u
426 | zuan z uan
427 | zui z ui
428 | zun z un
429 | zuo z uo
430 |
--------------------------------------------------------------------------------
/text/symbols.py:
--------------------------------------------------------------------------------
1 | punctuation = ["!", "?", "…", ",", ".", "'", "-"]
2 | pu_symbols = punctuation + ["SP", "UNK"]
3 | pad = "_"
4 |
5 | # chinese
6 | zh_symbols = [
7 | "E",
8 | "En",
9 | "a",
10 | "ai",
11 | "an",
12 | "ang",
13 | "ao",
14 | "b",
15 | "c",
16 | "ch",
17 | "d",
18 | "e",
19 | "ei",
20 | "en",
21 | "eng",
22 | "er",
23 | "f",
24 | "g",
25 | "h",
26 | "i",
27 | "i0",
28 | "ia",
29 | "ian",
30 | "iang",
31 | "iao",
32 | "ie",
33 | "in",
34 | "ing",
35 | "iong",
36 | "ir",
37 | "iu",
38 | "j",
39 | "k",
40 | "l",
41 | "m",
42 | "n",
43 | "o",
44 | "ong",
45 | "ou",
46 | "p",
47 | "q",
48 | "r",
49 | "s",
50 | "sh",
51 | "t",
52 | "u",
53 | "ua",
54 | "uai",
55 | "uan",
56 | "uang",
57 | "ui",
58 | "un",
59 | "uo",
60 | "v",
61 | "van",
62 | "ve",
63 | "vn",
64 | "w",
65 | "x",
66 | "y",
67 | "z",
68 | "zh",
69 | "AA",
70 | "EE",
71 | "OO",
72 | ]
73 | num_zh_tones = 6
74 |
75 | # japanese
76 | ja_symbols = [
77 | "N",
78 | "a",
79 | "a:",
80 | "b",
81 | "by",
82 | "ch",
83 | "d",
84 | "dy",
85 | "e",
86 | "e:",
87 | "f",
88 | "g",
89 | "gy",
90 | "h",
91 | "hy",
92 | "i",
93 | "i:",
94 | "j",
95 | "k",
96 | "ky",
97 | "m",
98 | "my",
99 | "n",
100 | "ny",
101 | "o",
102 | "o:",
103 | "p",
104 | "py",
105 | "q",
106 | "r",
107 | "ry",
108 | "s",
109 | "sh",
110 | "t",
111 | "ts",
112 | "ty",
113 | "u",
114 | "u:",
115 | "w",
116 | "y",
117 | "z",
118 | "zy",
119 | ]
120 | num_ja_tones = 2
121 |
122 | # English
123 | en_symbols = [
124 | "aa",
125 | "ae",
126 | "ah",
127 | "ao",
128 | "aw",
129 | "ay",
130 | "b",
131 | "ch",
132 | "d",
133 | "dh",
134 | "eh",
135 | "er",
136 | "ey",
137 | "f",
138 | "g",
139 | "hh",
140 | "ih",
141 | "iy",
142 | "jh",
143 | "k",
144 | "l",
145 | "m",
146 | "n",
147 | "ng",
148 | "ow",
149 | "oy",
150 | "p",
151 | "r",
152 | "s",
153 | "sh",
154 | "t",
155 | "th",
156 | "uh",
157 | "uw",
158 | "V",
159 | "w",
160 | "y",
161 | "z",
162 | "zh",
163 | ]
164 | num_en_tones = 4
165 |
166 | # combine all symbols
167 | normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
168 | symbols = [pad] + normal_symbols + pu_symbols
169 | sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
170 |
171 | # combine all tones
172 | num_tones = num_zh_tones + num_ja_tones + num_en_tones
173 |
174 | # language maps
175 | language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
176 | num_languages = len(language_id_map.keys())
177 |
178 | language_tone_start_map = {
179 | "ZH": 0,
180 | "JP": num_zh_tones,
181 | "EN": num_zh_tones + num_ja_tones,
182 | }
183 |
184 | if __name__ == "__main__":
185 | a = set(zh_symbols)
186 | b = set(en_symbols)
187 | print(sorted(a & b))
188 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import torch
3 | import collections
4 | import os
5 | import numpy as np
6 | import random
7 |
8 | class ValueWindow():
9 | def __init__(self, window_size=100):
10 | self._window_size = window_size
11 | self._values = []
12 |
13 | def append(self, x):
14 | self._values = self._values[-(self._window_size - 1):] + [x]
15 |
16 | @property
17 | def sum(self):
18 | return sum(self._values)
19 |
20 | @property
21 | def count(self):
22 | return len(self._values)
23 |
24 | @property
25 | def average(self):
26 | return self.sum / max(1, self.count)
27 |
28 | def reset(self):
29 | self._values = []
30 |
31 |
32 | def str2bool(v):
33 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
34 | return True
35 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
36 | return False
37 | else:
38 | raise ValueError('Unsupported value encountered.')
39 |
40 |
41 | # def hparams_string(hparams, args):
42 | # print('Output path: {}'.format(args.logdir_path))
43 | # values = hparams.values()
44 | # hp = [' %s: %s' % (name, values[name])
45 | # for name in sorted(values) if name != 'sentences']
46 | # # print('Hyperparameters:\n' + '\n'.join(hp))
47 | # return
48 |
49 |
50 | class HParams():
51 | def __init__(self, **kwargs):
52 | for k, v in kwargs.items():
53 |             if isinstance(v, dict):
54 | v = HParams(**v)
55 | self[k] = v
56 |
57 | def keys(self):
58 | return self.__dict__.keys()
59 |
60 | def items(self):
61 | return self.__dict__.items()
62 |
63 | def values(self):
64 | return self.__dict__.values()
65 |
66 | def __len__(self):
67 | return len(self.__dict__)
68 |
69 | def __getitem__(self, key):
70 | return getattr(self, key)
71 |
72 | def __setitem__(self, key, value):
73 | return setattr(self, key, value)
74 |
75 | def __contains__(self, key):
76 | return key in self.__dict__
77 |
78 | def __repr__(self):
79 | return self.__dict__.__repr__()
80 |
81 |
82 | def get_config_from_file(file):
83 | with open(file, 'r') as f:
84 | hp = yaml.load(f,Loader=yaml.FullLoader)
85 | hp = HParams(**hp)
86 | return hp
87 |
88 | def update_lr(opt, lr):
89 | for g in opt.param_groups:
90 | g['lr'] = lr
91 | return
92 |
93 | def to_device(tensors, device):
94 | tensors_to_device = []
95 | for tensor in tensors:
96 | if isinstance(tensor, torch.Tensor):
97 | tensors_to_device.append(tensor.to(device))
98 | else:
99 | tensors_to_device.append(tensor)
100 | return tensors_to_device
101 |
102 | def calculate_model_params(model):
103 | total = sum([param.numel() for param in model.parameters()])
104 | para_name = [para[0] for para in list(model.named_parameters())]
105 | print("==================================================")
106 | print("model struct is {}\n".format(str(model)))
107 | print("model params : {}".format(para_name))
108 | # log("FLOPs: {}".format(flops.total_float_ops))
109 | print("Number of parameter: %.2fM" % (total / 1e6))
110 | print("==================================================")
111 | return
112 |
113 | def get_metadata(path):
114 | with open(path, 'r') as f:
115 | metas = [l.strip() for l in f]
116 | random.shuffle(metas)
117 | return metas
118 |
119 | class TemperatureSampler():
120 | """
121 | ## Sampler with Temperature
122 | """
123 | def __init__(self, temperature: float = 1.0):
124 | """
125 | :param temperature: is the temperature to sample with
126 | """
127 | self.temperature = temperature
128 |
129 | def __call__(self, logits: torch.Tensor):
130 | """
131 | Sample from logits
132 | """
133 |
134 | # Create a categorical distribution with temperature adjusted logits
135 | dist = torch.distributions.categorical.Categorical(logits=logits / self.temperature)
136 |
137 | # Sample
138 | return dist.sample()
139 |
140 |
141 |
142 | class TopKSampler():
143 | """
144 | ## Top-k Sampler
145 | """
146 | def __init__(self, k: int, sampler=TemperatureSampler()):
147 | """
148 | :param k: is the number of tokens to pick
149 | :param sampler: is the sampler to use for the top-k tokens
150 |
151 | `sampler` can be any sampler that takes a logits tensor as input and returns a token tensor;
152 |     e.g. [`TemperatureSampler`](temperature.html).
153 | """
154 | self.k = k
155 | self.sampler = sampler
156 |
157 | def __call__(self, logits: torch.Tensor):
158 | """
159 | Sample from logits
160 | """
161 | # New logits filled with $-\infty$; i.e. zero probability
162 | zeros = logits.new_ones(logits.shape) * float('-inf')
163 | # Pick the largest $k$ logits and their indices
164 | values, indices = torch.topk(logits, self.k, dim=-1)
165 | # Set the values of the top-k selected indices to actual logits.
166 | # Logits of other tokens remain $-\infty$
167 | zeros.scatter_(-1, indices, values)
168 |
169 | # Sample from the top-k logits with the specified sampler.
170 | return self.sampler(zeros)
--------------------------------------------------------------------------------