├── LICENSE ├── README.md ├── gomin │ ├── __init__.py │ ├── config.py │ ├── models │ │ ├── __init__.py │ │ ├── common.py │ │ ├── diffusion.py │ │ ├── diffusion_wrapper.py │ │ └── gan.py │ ├── run.py │ └── utils.py ├── requirements.txt └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Gaudio Lab, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GOMIN: Gaudio Open Melspectrogram Inversion Network 2 | 3 | tl;dr - GOMIN is a general-purpose, general-source model for melspectrogram -> waveform. 4 | 5 | 6 | ## About 7 | 8 | GOMIN is a general-purpose, general-source model for converting melspectrograms to waveforms. 9 | 10 | We open source this code and provide pretrained models to further research in music / general audio generation. 11 | The models have been trained on a diverse range of audio datasets, including speech signals, music stems, animal sound recordings, and foley sound stems. 12 | To cover these varied fields of audio and make the model more universal and robust, several improvements are applied to its neural vocoder baselines. 13 | This makes GOMIN suitable for various applications and research endeavors. 14 | 15 | ### Supported models 16 | 17 | The available models are based on two state-of-the-art neural vocoder models: BigVGAN \[[Lee et al. 2022](https://arxiv.org/abs/2206.04658)\] and DiffWave \[[Kong et al. 2020](https://arxiv.org/abs/2009.09761)\]. 18 | These models have been slightly modified to improve their performance in generating general audio signals beyond just speech signals. 19 | 20 | The GAN-based models have been enhanced with Feature-wise Linear Modulation (FiLM) \[[Perez et al. 2017](https://arxiv.org/abs/1709.07871)\] after every upsampling block. 21 | The modulation parameters, i.e. the shift and scale, are computed from the raw melspectrogram, and each upsampling layer has its own distinct parameters; they are not shared across layers. 22 | This modification improves tonal consistency and leads to better sound reconstruction for general audio signals. 23 | Note that the code for this model is largely adapted from [HiFi-GAN](https://github.com/jik876/hifi-gan) \[[Kong et al. 
2020](https://arxiv.org/abs/2010.05646)\], not directly from the [BigVGAN](https://github.com/NVIDIA/BigVGAN) repository. 24 | 25 | For Diffusion-based models, in addition to FiLM, we have also adjusted the noise schedule to better accommodate universal audio generation. 26 | This was achieved by interpolating two popular schedules, the linear and the cosine schedule. 27 | The interpolated schedule roughly follows the linear schedule near $t=0$ and the cosine schedule near $t=T$. 28 | This noise schedule injects more noise in earlier steps, helping the model handle more diverse and complex data distributions. 29 | Research has shown that the choice of noise schedule is crucial in high-resolution image generation \[[Chen 2023](https://arxiv.org/abs/2301.10972), [Hoogeboom et al. 2023](https://arxiv.org/abs/2301.11093)\]. 30 | We believe that further research in this area will lead to even better audio generation in the future. 31 | 32 | 33 | ## Requirements 34 | 35 | - python3 ≥ 3.10 36 | - torch (tested on 1.13.1+cu116) 37 | - torchaudio (tested on 0.13.1+cu116) 38 | - diffusers (tested on 0.17.0) 39 | - librosa ≥ 0.9.2 40 | - numpy == 1.24.3 41 | - pyyaml 42 | - scipy ≥ 1.10.0 43 | - soundfile ≥ 0.11.0 44 | - tqdm 45 | 46 | 47 | ## Install 48 | You can install this package with the following commands. 49 | ```Shell 50 | $ git clone https://github.com/ryeoat3/gomin.git 51 | $ cd gomin 52 | $ pip install -e . 53 | ``` 54 | 55 | 56 | ## Pretrained checkpoints 57 | 58 | You can download pretrained checkpoints from Google Drive ([gan](https://drive.google.com/file/d/1TyNCS7fdeeCJK66x_n9TeR_SurPaft4L), [diffusion](https://drive.google.com/file/d/1vkrTICKruShu_0ofM3vTc3No2Fj5rMxD)). 59 | 60 | 61 | ## Inference 62 | 63 | ### Python 64 | ```Python 65 | >>> from gomin.models import GomiGAN, DiffusionWrapper 66 | >>> from gomin.config import GANConfig, DiffusionConfig 67 | 68 | # Model loading 69 | >>> model = GomiGAN.from_pretrained( 70 | pretrained_model_path="CHECKPOINT_PATH", **GANConfig().__dict__ 71 | ) # for GAN model 72 | >>> model = DiffusionWrapper.from_pretrained( 73 | pretrained_model_path="CHECKPOINT_PATH", **DiffusionConfig().__dict__ 74 | ) # for diffusion model 75 | 76 | # To convert your waveform into a mel-spectrogram, run: 77 | >>> assert waveform.ndim == 2 78 | >>> melspec = model.prepare_melspectrogram(waveform) 79 | 80 | # To reconstruct a waveform from a mel-spectrogram, run: 81 | >>> assert melspec.ndim == 3 82 | >>> waveform = model(melspec) 83 | ``` 84 | 85 | ### CLI 86 | ```Shell 87 | $ python -m gomin.run -m {MODEL: 'gan' | 'diffusion'} -p {MODEL_PATH} -i {INPUT_FILES} -o {OUTPUT_FILES} 88 | ``` 89 | 90 | **Arguments** 91 | 92 | - `-m` (`--model`): model type to run; either 'gan' or 'diffusion' is valid. 93 | - `-p` (`--model_path`): directory path for the model checkpoint. Default value is `'checkpoints/'`. 94 | - `-i` (`--input_files`): list of file paths or a directory path for input files. 95 | - `-o` (`--output_files`): list of file paths or a directory path for output files. 96 | 97 | **Notes** 98 | 99 | For CLI runs, the `input_files` and `output_files` options support 100 | 101 | - a list of files (e.g. `-i a.wav b.wav c.wav -o a_out.wav b_out.wav c_out.wav`) 102 | - a directory path (e.g. `-i inputs/ -o outputs/`) 103 | - and a path with the wildcard `*` (e.g. 
`-i inputs/**/*.wav -o outputs/**/*.pt`) 104 | 105 | The extensions of `input_files` and `output_files` determine which process to run: 106 | 107 | - Analysis and synthesis 108 | - If both are audio files (with a `.wav` or `.mp3` extension), the program will first convert the audio to a mel-spectrogram and then reconstruct it back to a waveform. 109 | - Analysis 110 | - If the `input_files` extensions are `.wav` or `.mp3` and the `output_files` are `.pt` files, the program will convert the waveforms to mel-spectrograms and save them. Saved files include the mel-spectrogram tensor and its configuration (`config.py::MelConfig`). 111 | - Synthesis 112 | - If the `input_files` are `.pt` files and the `output_files` are audio files, the program will reconstruct the mel-spectrograms into waveforms. 113 | 114 | 115 | ## References 116 | - [Lee et al. 2022](https://arxiv.org/abs/2206.04658): Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon, "BigVGAN: A Universal Neural Vocoder with Large-Scale Training" 117 | - [Kong et al. 2020](https://arxiv.org/abs/2009.09761): Zhifeng Kong, Wei Ping, Jiaji Huang, Kexin Zhao, Bryan Catanzaro, "DiffWave: A Versatile Diffusion Model for Audio Synthesis" 118 | - [Perez et al. 2017](https://arxiv.org/abs/1709.07871): Ethan Perez, Florian Strub, Harm de Vries, Vincent Dumoulin, Aaron Courville, "FiLM: Visual Reasoning with a General Conditioning Layer" 119 | - [Chen 2023](https://arxiv.org/abs/2301.10972): Ting Chen, "On the Importance of Noise Scheduling for Diffusion Models" 120 | - [Hoogeboom et al. 2023](https://arxiv.org/abs/2301.11093): Emiel Hoogeboom, Jonathan Heek, Tim Salimans, "simple diffusion: End-to-end diffusion for high resolution images" 121 | 122 | ## Related repos 123 | - HiFi-GAN: https://github.com/jik876/hifi-gan 124 | - DiffWave: https://github.com/lmnt-com/diffwave 125 | 126 | 127 | ## How to cite 128 | If you find this work useful, please cite the following: 129 | ```bibtex 130 | % Workshop paper for DCASE Workshop 2023 131 | @inproceedings{Kang2023, 132 | author = "Kang, Minsung and Oh, Sangshin and Moon, Hyeongi and Lee, Kyungyun and Chon, Ben Sangbae", 133 | title = "FALL-E: A Foley Sound Synthesis Model and Strategies", 134 | booktitle = "Proceedings of the 8th Detection and Classification of Acoustic Scenes and Events 2023 Workshop (DCASE2023)", 135 | address = "Tampere, Finland", 136 | year = "2023", 137 | } 138 | ``` 139 | 140 | ## License 141 | [MIT License](LICENSE) 142 | -------------------------------------------------------------------------------- /gomin/__init__.py: -------------------------------------------------------------------------------- 1 | # empty 2 | -------------------------------------------------------------------------------- /gomin/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | 5 | @dataclass 6 | class MelConfig: 7 | n_mels: int = 128 8 | sample_rate: int = 24000 9 | win_length: int = 1024 10 | hop_length: int = 256 11 | 12 | 13 | @dataclass 14 | class DiffusionConfig: 15 | in_channels: int = 128 16 | residual_layers: int = 30 17 | residual_channels: int = 128 18 | dilation_cycle_length: int = 10 19 | num_diffusion_steps: int = 50 20 | 21 | # mel config 22 | sample_rate: int = 24000 23 | win_length: int = 1024 24 | hop_length: int = 256 25 | 26 | 27 | @dataclass 28 | class GANConfig: 29 | in_channels: int = 128 30 | upsample_in_channels: int = 1536 31 | upsample_strides: List[int] = 
field(default_factory=lambda: [4, 4, 2, 2, 2, 2]) 32 | resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11]) 33 | resblock_dilations: List[List[int]] = field( 34 | default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 35 | ) 36 | sample_rate: int = 24000 37 | 38 | # mel config 39 | win_length: int = 1024 40 | hop_length: int = 256 41 | -------------------------------------------------------------------------------- /gomin/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gan import GomiGAN 2 | from .diffusion_wrapper import DiffusionWrapper 3 | from .common import MelspecInversion 4 | -------------------------------------------------------------------------------- /gomin/models/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torchaudio.transforms import MelSpectrogram 7 | 8 | 9 | def get_least_power2_above(x): 10 | return np.power(2, math.ceil(np.log2(x))) 11 | 12 | 13 | class MelspecInversion(nn.Module): 14 | def __init__( 15 | self, 16 | n_mels: int = 128, 17 | sample_rate: int = 24000, 18 | win_length: int = 1024, 19 | hop_length: int = 256, 20 | f_min: float = 0.0, 21 | f_max: float | None = None, 22 | ): 23 | super().__init__() 24 | self.n_mels = n_mels 25 | self.sample_rate = sample_rate 26 | self.win_length = win_length 27 | self.hop_length = hop_length 28 | self.f_min = f_min 29 | self.f_max = f_max or (self.sample_rate / 2.0) 30 | self.melspec_layer = None 31 | 32 | @classmethod 33 | def from_pretrained(cls, pretrained_model_path, **config): 34 | model = cls(**config) 35 | model.load_state_dict(torch.load(pretrained_model_path, map_location="cpu")) 36 | return model 37 | 38 | def prepare_melspectrogram(self, audio): 39 | if self.melspec_layer is None: 40 | self.melspec_layer = MelSpectrogram( 41 | n_mels=self.n_mels, 42 | sample_rate=self.sample_rate, 43 | n_fft=get_least_power2_above(self.win_length), 44 | win_length=self.win_length, 45 | hop_length=self.hop_length, 46 | f_min=self.f_min, 47 | f_max=self.f_max, 48 | center=True, 49 | power=2.0, 50 | mel_scale="slaney", 51 | norm="slaney", 52 | normalized=True, 53 | pad_mode="constant", 54 | ) 55 | self.melspec_layer = self.melspec_layer.to(audio.device) 56 | 57 | melspec = self.melspec_layer(audio) 58 | melspec = 10 * torch.log10(melspec + 1e-10) 59 | melspec = torch.clamp((melspec + 100) / 100, min=0.0) 60 | return melspec 61 | -------------------------------------------------------------------------------- /gomin/models/diffusion.py: -------------------------------------------------------------------------------- 1 | # This file is modified from https://github.com/lmnt-com/diffwave 2 | 3 | # Original terms for the license are as follows: 4 | 5 | # Apache License 6 | # Version 2.0, January 2004 7 | # http://www.apache.org/licenses/ 8 | 9 | # TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | # 1. Definitions. 12 | 13 | # "License" shall mean the terms and conditions for use, reproduction, 14 | # and distribution as defined by Sections 1 through 9 of this document. 15 | 16 | # "Licensor" shall mean the copyright owner or entity authorized by 17 | # the copyright owner that is granting the License. 18 | 19 | # "Legal Entity" shall mean the union of the acting entity and all 20 | # other entities that control, are controlled by, or are under common 21 | # control with that entity. 
For the purposes of this definition, 22 | # "control" means (i) the power, direct or indirect, to cause the 23 | # direction or management of such entity, whether by contract or 24 | # otherwise, or (ii) ownership of fifty percent (50%) or more of the 25 | # outstanding shares, or (iii) beneficial ownership of such entity. 26 | 27 | # "You" (or "Your") shall mean an individual or Legal Entity 28 | # exercising permissions granted by this License. 29 | 30 | # "Source" form shall mean the preferred form for making modifications, 31 | # including but not limited to software source code, documentation 32 | # source, and configuration files. 33 | 34 | # "Object" form shall mean any form resulting from mechanical 35 | # transformation or translation of a Source form, including but 36 | # not limited to compiled object code, generated documentation, 37 | # and conversions to other media types. 38 | 39 | # "Work" shall mean the work of authorship, whether in Source or 40 | # Object form, made available under the License, as indicated by a 41 | # copyright notice that is included in or attached to the work 42 | # (an example is provided in the Appendix below). 43 | 44 | # "Derivative Works" shall mean any work, whether in Source or Object 45 | # form, that is based on (or derived from) the Work and for which the 46 | # editorial revisions, annotations, elaborations, or other modifications 47 | # represent, as a whole, an original work of authorship. For the purposes 48 | # of this License, Derivative Works shall not include works that remain 49 | # separable from, or merely link (or bind by name) to the interfaces of, 50 | # the Work and Derivative Works thereof. 51 | 52 | # "Contribution" shall mean any work of authorship, including 53 | # the original version of the Work and any modifications or additions 54 | # to that Work or Derivative Works thereof, that is intentionally 55 | # submitted to Licensor for inclusion in the Work by the copyright owner 56 | # or by an individual or Legal Entity authorized to submit on behalf of 57 | # the copyright owner. For the purposes of this definition, "submitted" 58 | # means any form of electronic, verbal, or written communication sent 59 | # to the Licensor or its representatives, including but not limited to 60 | # communication on electronic mailing lists, source code control systems, 61 | # and issue tracking systems that are managed by, or on behalf of, the 62 | # Licensor for the purpose of discussing and improving the Work, but 63 | # excluding communication that is conspicuously marked or otherwise 64 | # designated in writing by the copyright owner as "Not a Contribution." 65 | 66 | # "Contributor" shall mean Licensor and any individual or Legal Entity 67 | # on behalf of whom a Contribution has been received by Licensor and 68 | # subsequently incorporated within the Work. 69 | 70 | # 2. Grant of Copyright License. Subject to the terms and conditions of 71 | # this License, each Contributor hereby grants to You a perpetual, 72 | # worldwide, non-exclusive, no-charge, royalty-free, irrevocable 73 | # copyright license to reproduce, prepare Derivative Works of, 74 | # publicly display, publicly perform, sublicense, and distribute the 75 | # Work and such Derivative Works in Source or Object form. 76 | 77 | # 3. Grant of Patent License. 
Subject to the terms and conditions of 78 | # this License, each Contributor hereby grants to You a perpetual, 79 | # worldwide, non-exclusive, no-charge, royalty-free, irrevocable 80 | # (except as stated in this section) patent license to make, have made, 81 | # use, offer to sell, sell, import, and otherwise transfer the Work, 82 | # where such license applies only to those patent claims licensable 83 | # by such Contributor that are necessarily infringed by their 84 | # Contribution(s) alone or by combination of their Contribution(s) 85 | # with the Work to which such Contribution(s) was submitted. If You 86 | # institute patent litigation against any entity (including a 87 | # cross-claim or counterclaim in a lawsuit) alleging that the Work 88 | # or a Contribution incorporated within the Work constitutes direct 89 | # or contributory patent infringement, then any patent licenses 90 | # granted to You under this License for that Work shall terminate 91 | # as of the date such litigation is filed. 92 | 93 | # 4. Redistribution. You may reproduce and distribute copies of the 94 | # Work or Derivative Works thereof in any medium, with or without 95 | # modifications, and in Source or Object form, provided that You 96 | # meet the following conditions: 97 | 98 | # (a) You must give any other recipients of the Work or 99 | # Derivative Works a copy of this License; and 100 | 101 | # (b) You must cause any modified files to carry prominent notices 102 | # stating that You changed the files; and 103 | 104 | # (c) You must retain, in the Source form of any Derivative Works 105 | # that You distribute, all copyright, patent, trademark, and 106 | # attribution notices from the Source form of the Work, 107 | # excluding those notices that do not pertain to any part of 108 | # the Derivative Works; and 109 | 110 | # (d) If the Work includes a "NOTICE" text file as part of its 111 | # distribution, then any Derivative Works that You distribute must 112 | # include a readable copy of the attribution notices contained 113 | # within such NOTICE file, excluding those notices that do not 114 | # pertain to any part of the Derivative Works, in at least one 115 | # of the following places: within a NOTICE text file distributed 116 | # as part of the Derivative Works; within the Source form or 117 | # documentation, if provided along with the Derivative Works; or, 118 | # within a display generated by the Derivative Works, if and 119 | # wherever such third-party notices normally appear. The contents 120 | # of the NOTICE file are for informational purposes only and 121 | # do not modify the License. You may add Your own attribution 122 | # notices within Derivative Works that You distribute, alongside 123 | # or as an addendum to the NOTICE text from the Work, provided 124 | # that such additional attribution notices cannot be construed 125 | # as modifying the License. 126 | 127 | # You may add Your own copyright statement to Your modifications and 128 | # may provide additional or different license terms and conditions 129 | # for use, reproduction, or distribution of Your modifications, or 130 | # for any such Derivative Works as a whole, provided Your use, 131 | # reproduction, and distribution of the Work otherwise complies with 132 | # the conditions stated in this License. 133 | 134 | # 5. Submission of Contributions. 
Unless You explicitly state otherwise, 135 | # any Contribution intentionally submitted for inclusion in the Work 136 | # by You to the Licensor shall be under the terms and conditions of 137 | # this License, without any additional terms or conditions. 138 | # Notwithstanding the above, nothing herein shall supersede or modify 139 | # the terms of any separate license agreement you may have executed 140 | # with Licensor regarding such Contributions. 141 | 142 | # 6. Trademarks. This License does not grant permission to use the trade 143 | # names, trademarks, service marks, or product names of the Licensor, 144 | # except as required for reasonable and customary use in describing the 145 | # origin of the Work and reproducing the content of the NOTICE file. 146 | 147 | # 7. Disclaimer of Warranty. Unless required by applicable law or 148 | # agreed to in writing, Licensor provides the Work (and each 149 | # Contributor provides its Contributions) on an "AS IS" BASIS, 150 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 151 | # implied, including, without limitation, any warranties or conditions 152 | # of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 153 | # PARTICULAR PURPOSE. You are solely responsible for determining the 154 | # appropriateness of using or redistributing the Work and assume any 155 | # risks associated with Your exercise of permissions under this License. 156 | 157 | # 8. Limitation of Liability. In no event and under no legal theory, 158 | # whether in tort (including negligence), contract, or otherwise, 159 | # unless required by applicable law (such as deliberate and grossly 160 | # negligent acts) or agreed to in writing, shall any Contributor be 161 | # liable to You for damages, including any direct, indirect, special, 162 | # incidental, or consequential damages of any character arising as a 163 | # result of this License or out of the use or inability to use the 164 | # Work (including but not limited to damages for loss of goodwill, 165 | # work stoppage, computer failure or malfunction, or any and all 166 | # other commercial damages or losses), even if such Contributor 167 | # has been advised of the possibility of such damages. 168 | 169 | # 9. Accepting Warranty or Additional Liability. While redistributing 170 | # the Work or Derivative Works thereof, You may choose to offer, 171 | # and charge a fee for, acceptance of support, warranty, indemnity, 172 | # or other liability obligations and/or rights consistent with this 173 | # License. However, in accepting such obligations, You may act only 174 | # on Your own behalf and on Your sole responsibility, not on behalf 175 | # of any other Contributor, and only if You agree to indemnify, 176 | # defend, and hold each Contributor harmless for any liability 177 | # incurred by, or claims asserted against, such Contributor by reason 178 | # of your accepting any such warranty or additional liability. 179 | 180 | # END OF TERMS AND CONDITIONS 181 | 182 | # APPENDIX: How to apply the Apache License to your work. 183 | 184 | # To apply the Apache License to your work, attach the following 185 | # boilerplate notice, with the fields enclosed by brackets "[]" 186 | # replaced with your own identifying information. (Don't include 187 | # the brackets!) The text should be enclosed in the appropriate 188 | # comment syntax for the file format. 
We also recommend that a 189 | # file or class name and description of purpose be included on the 190 | # same "printed page" as the copyright notice for easier 191 | # identification within third-party archives. 192 | 193 | # Copyright [yyyy] [name of copyright owner] 194 | 195 | # Licensed under the Apache License, Version 2.0 (the "License"); 196 | # you may not use this file except in compliance with the License. 197 | # You may obtain a copy of the License at 198 | 199 | # http://www.apache.org/licenses/LICENSE-2.0 200 | 201 | # Unless required by applicable law or agreed to in writing, software 202 | # distributed under the License is distributed on an "AS IS" BASIS, 203 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 204 | # See the License for the specific language governing permissions and 205 | # limitations under the License. 206 | 207 | 208 | from math import sqrt 209 | import torch 210 | import torch.nn as nn 211 | import torch.nn.functional as F 212 | 213 | 214 | @torch.jit.script 215 | def silu(x): 216 | return x * torch.sigmoid(x) 217 | 218 | 219 | def Conv1d(*args, **kwargs): 220 | layer = nn.Conv1d(*args, **kwargs) 221 | nn.init.kaiming_normal_(layer.weight) 222 | return layer 223 | 224 | 225 | class DiffusionEmbedding(nn.Module): 226 | """Diffusion Timestep Embedding""" 227 | 228 | def __init__(self, max_steps): 229 | super().__init__() 230 | self.register_buffer( 231 | "embedding", self._build_embedding(max_steps), persistent=False 232 | ) 233 | self.projection1 = nn.Linear(128, 512) 234 | self.projection2 = nn.Linear(512, 512) 235 | 236 | def _build_embedding(self, max_steps): 237 | steps = torch.arange(max_steps).unsqueeze(1) 238 | dims = torch.arange(64).unsqueeze(0) 239 | table = steps * 10.0 ** (dims * 4.0 / 63.0) 240 | table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) 241 | return table 242 | 243 | def _lerp_embedding(self, t): 244 | low_idx = torch.floor(t).long() 245 | high_idx = torch.ceil(t).long() 246 | low = self.embedding[low_idx] 247 | high = self.embedding[high_idx] 248 | return low + (high - low) * (t - low_idx) 249 | 250 | def forward(self, timestep): 251 | if timestep.dtype in [torch.int32, torch.int64]: 252 | x = self.embedding[timestep] 253 | else: 254 | x = self._lerp_embedding(timestep) 255 | x = self.projection1(x) 256 | x = silu(x) 257 | x = self.projection2(x) 258 | x = silu(x) 259 | return x 260 | 261 | 262 | class SpectrogramUpsampler(nn.Module): 263 | """Convolution-based Upsampler (16x)""" 264 | 265 | def __init__(self): 266 | super().__init__() 267 | self.conv1 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8]) 268 | self.conv2 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8]) 269 | 270 | def forward(self, x): 271 | x = self.conv1(x) 272 | x = F.leaky_relu(x, 0.4) 273 | x = self.conv2(x) 274 | x = F.leaky_relu(x, 0.4) 275 | x = torch.squeeze(x, 1) 276 | return x 277 | 278 | 279 | class ResidualBlock(nn.Module): 280 | """Building Block for Diffusion Model""" 281 | 282 | def __init__(self, n_mels, residual_channels, dilation): 283 | super().__init__() 284 | self.dilated_conv = Conv1d( 285 | residual_channels, 286 | 2 * residual_channels, 287 | 3, 288 | padding=dilation, 289 | dilation=dilation, 290 | ) 291 | self.diffusion_projection = nn.Linear(512, residual_channels) 292 | self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1) 293 | self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) 294 | 295 | def forward(self, x, 
timestep, conditioner): 296 | timestep = self.diffusion_projection(timestep).unsqueeze(-1) 297 | y = x + timestep 298 | conditioner = self.conditioner_projection(conditioner) 299 | y = self.dilated_conv(y) + conditioner 300 | y_gate, y_filter = torch.chunk(y, 2, dim=1) 301 | y = torch.sigmoid(y_gate) * torch.tanh(y_filter) 302 | 303 | y = self.output_projection(y) 304 | residual, skip = torch.chunk(y, 2, dim=1) 305 | return (x + residual) / sqrt(2.0), skip 306 | 307 | 308 | class GomiDiff(nn.Module): 309 | """GomiDiff: Gaudio open mel-spectrogram inversion with diffusion models. 310 | 311 | Based on Diffwave (Kong et al. 2020). 312 | 313 | Args: 314 | n_mels (int): number of frequency bins 315 | residual_layers (int): number of residual layers 316 | residual_channels (int): dimension of channels in aresidual layer 317 | dilation_cycle_length (int): number of dilation cycles 318 | num_diffusion_steps (int): number of diffusion steps 319 | """ 320 | 321 | def __init__( 322 | self, 323 | in_channels: int, 324 | residual_layers: int, 325 | residual_channels: int, 326 | dilation_cycle_length: int, 327 | num_diffusion_steps: int, 328 | ): 329 | super().__init__() 330 | self.dilation_cycle_length = dilation_cycle_length 331 | self.num_diffusion_steps = num_diffusion_steps 332 | 333 | self.film_layers = nn.ModuleList( 334 | [ 335 | nn.Conv1d(residual_channels, 2 * residual_channels, 1) 336 | for _ in range(residual_layers // dilation_cycle_length) 337 | ] 338 | ) 339 | 340 | self.input_projection = Conv1d(1, residual_channels, 1) 341 | self.diffusion_embedding = DiffusionEmbedding(num_diffusion_steps) 342 | self.spectrogram_upsampler = SpectrogramUpsampler() 343 | self.residual_layers = nn.ModuleList( 344 | [ 345 | ResidualBlock( 346 | in_channels, residual_channels, 2 ** (i % dilation_cycle_length) 347 | ) 348 | for i in range(residual_layers) 349 | ] 350 | ) 351 | self.skip_projection = Conv1d(residual_channels, residual_channels, 1) 352 | self.output_projection = Conv1d(residual_channels, 1, 1) 353 | nn.init.zeros_(self.output_projection.weight) 354 | 355 | def forward(self, signal, timestep, spectrogram): 356 | x = self.input_projection(signal) 357 | x = F.relu(x) 358 | 359 | timestep = self.diffusion_embedding(timestep) 360 | spectrogram = self.spectrogram_upsampler(spectrogram) 361 | 362 | skip = None 363 | for i, layer in enumerate(self.residual_layers): 364 | if i % self.dilation_cycle_length == 0: 365 | film = self.film_layers[i // self.dilation_cycle_length](spectrogram) 366 | scale, shift = torch.chunk(film, 2, dim=1) 367 | x = (1 + 0.01 * scale) * x + shift 368 | 369 | x, skip_connection = layer(x, timestep, spectrogram) 370 | skip = skip_connection if skip is None else skip_connection + skip 371 | 372 | x = skip / sqrt(len(self.residual_layers)) 373 | x = self.skip_projection(x) 374 | x = F.relu(x) 375 | x = self.output_projection(x) 376 | return x 377 | -------------------------------------------------------------------------------- /gomin/models/diffusion_wrapper.py: -------------------------------------------------------------------------------- 1 | import diffusers 2 | import torch 3 | from tqdm import tqdm 4 | 5 | from .common import MelspecInversion 6 | from .diffusion import GomiDiff 7 | 8 | 9 | class DiffusionWrapper(MelspecInversion): 10 | def __init__( 11 | self, 12 | in_channels: int, 13 | residual_layers: int, 14 | residual_channels: int, 15 | dilation_cycle_length: int, 16 | num_diffusion_steps: int, 17 | **mel_config, 18 | ): 19 | super().__init__(n_mels=in_channels, 
**mel_config) 20 | self.model = GomiDiff( 21 | in_channels=in_channels, 22 | residual_layers=residual_layers, 23 | residual_channels=residual_channels, 24 | dilation_cycle_length=dilation_cycle_length, 25 | num_diffusion_steps=num_diffusion_steps, 26 | ) 27 | self.scheduler = diffusers.DDPMScheduler( 28 | beta_start=0.0001, 29 | beta_end=0.05, 30 | num_train_timesteps=self.model.num_diffusion_steps, 31 | ) 32 | self.scheduler.set_timesteps(num_inference_steps=self.model.num_diffusion_steps) 33 | 34 | @torch.no_grad() 35 | def forward(self, spectrogram, return_whole_sequence=False): 36 | shape = (spectrogram.size(0), 1, self.hop_length * spectrogram.size(-1)) 37 | 38 | x = torch.randn(*shape, device=spectrogram.device) 39 | 40 | if return_whole_sequence: 41 | output_sequence = [x.clone()] 42 | 43 | for t in tqdm(self.scheduler.timesteps, total=len(self.scheduler.timesteps)): 44 | timestep = torch.tensor([t], device=spectrogram.device).long() 45 | predicted_noise = self.model(x, timestep, spectrogram) 46 | 47 | scheduler_output = self.scheduler.step( 48 | predicted_noise, timestep=t, sample=x 49 | ) 50 | x = scheduler_output["prev_sample"] 51 | 52 | if return_whole_sequence: 53 | output_sequence.insert(0, x.clone()) 54 | 55 | if return_whole_sequence: 56 | return output_sequence 57 | return x 58 | -------------------------------------------------------------------------------- /gomin/models/gan.py: -------------------------------------------------------------------------------- 1 | # This file is modified from https://github.com/jik876/hifi-gan 2 | 3 | # Original terms for the license are as follows: 4 | 5 | # MIT License 6 | 7 | # Copyright (c) 2020 Jungil Kong 8 | 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to the following conditions: 15 | 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | # SOFTWARE. 
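# Module overview (descriptive comment): this file implements the GAN-based
# melspectrogram inversion model. It combines a HiFi-GAN/BigVGAN-style
# transposed-convolution upsampling generator with anti-aliased Snake
# activations (AmpBlock) and FiLM conditioning whose scale/shift parameters
# are computed from the input melspectrogram (FreqTemporalFiLM), as described
# in the README.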
26 | 27 | import functools 28 | from typing import List 29 | 30 | import numpy as np 31 | import scipy.signal 32 | 33 | import torch 34 | import torch.nn as nn 35 | import torch.nn.functional as F 36 | 37 | from .common import MelspecInversion 38 | 39 | 40 | def get_padding(kernel_size, dilation=1): 41 | return int((kernel_size * dilation - dilation) / 2) 42 | 43 | 44 | class Snake(nn.Module): 45 | """Periodic activation function with learned parameter""" 46 | 47 | def __init__(self, kernel_size): 48 | super().__init__() 49 | self.alpha = nn.Parameter(data=torch.randn(kernel_size), requires_grad=True) 50 | 51 | def forward(self, x): 52 | return x + (torch.sin(self.alpha * x) ** 2) / self.alpha 53 | 54 | 55 | class AmpBlock(nn.Module): 56 | """Anti-aliased Multi-Periodicity Block with 2 Convolution per Step""" 57 | 58 | def __init__(self, channels, kernel_size, dilation, sample_rate): 59 | super().__init__() 60 | 61 | conv1d = functools.partial(nn.Conv1d, channels, channels, kernel_size, 1) 62 | 63 | # Low pass filter 64 | filter_ = scipy.signal.firwin( 65 | numtaps=12, cutoff=sample_rate / 4, width=0.6, fs=sample_rate 66 | ) 67 | filter_ = np.sqrt(2) * torch.from_numpy(filter_).reshape(1, 1, -1) 68 | self.register_buffer("lpf", filter_.tile(channels, 1, 1).float()) 69 | 70 | # Layer: snake activations 71 | self.snakes1 = nn.ModuleList([Snake([channels, 1]) for _ in dilation]) 72 | self.snakes2 = nn.ModuleList([Snake([channels, 1]) for _ in dilation]) 73 | 74 | self.convs1 = nn.ModuleList( 75 | [conv1d(dilation=d, padding=get_padding(kernel_size, d)) for d in dilation] 76 | ) 77 | self.convs2 = nn.ModuleList( 78 | [conv1d(dilation=1, padding=get_padding(kernel_size, 1)) for _ in dilation] 79 | ) 80 | 81 | def upsample(self, x): 82 | b, c, _ = x.shape 83 | x = torch.cat([x.unsqueeze(-1), torch.zeros_like(x.unsqueeze(-1))], dim=-1) 84 | x = x.reshape(b, c, -1) 85 | x = F.conv1d(x, self.lpf, padding=(self.lpf.shape[-1] // 2), groups=c) 86 | return x, c 87 | 88 | def downsample(self, x, groups): 89 | x = F.conv1d(x, self.lpf, padding=(self.lpf.shape[-1] // 2), groups=groups) 90 | x = x[..., 1:-1:2] 91 | return x 92 | 93 | def forward(self, x): 94 | for i in range(len(self.convs1)): 95 | xt, c = self.upsample(x) 96 | xt = self.snakes1[i](xt) 97 | xt = self.downsample(xt, c) 98 | xt = self.convs1[i](xt) 99 | xt = self.snakes2[i](xt) 100 | xt = self.convs2[i](xt) 101 | x = xt + x 102 | return x 103 | 104 | 105 | class FreqTemporalFiLM(nn.Module): 106 | def __init__(self, in_channels, upsample_in_channels, upsample_strides): 107 | super().__init__() 108 | in_ch = in_channels # initial value 109 | out_ch = upsample_in_channels # for scale and shift each 110 | self.convs = nn.ModuleList() 111 | for s in upsample_strides: 112 | self.convs.append( 113 | nn.ConvTranspose1d(in_ch, out_ch, (2 * s), s, padding=(s // 2)) 114 | ) 115 | in_ch = out_ch 116 | out_ch = out_ch // 2 117 | 118 | def forward(self, x): 119 | scales, shifts = [], [] 120 | for conv in self.convs: 121 | x = conv(x) 122 | y1, y2 = torch.tensor_split(x, 2, dim=1) 123 | scales.append(y1) 124 | shifts.append(y2) 125 | return scales, shifts 126 | 127 | 128 | class GomiGAN(MelspecInversion): 129 | """GomiGAN: Gaudio open mel-spectrogram inversion GAN. 130 | 131 | Based on HiFi-GAN (Kong et al. 2020) and BigVGAN (Lee et al. 2022). 132 | 133 | Args: 134 | in_channels (int): Number of input channels (i.e. 
number of mel bands) 135 | upsample_in_channels (int): Number of channels in the first upsampling layer 136 | upsample_strides (List[int]): Upsampling strides 137 | resblock_kernel_sizes (List[int]): Kernel sizes for the resblocks 138 | resblock_dilations (List[List[int]]): Dilations for each resblock 139 | sample_rate (int): Sample rate of the audio signal 140 | """ 141 | 142 | def __init__( 143 | self, 144 | in_channels: int, 145 | upsample_in_channels: int, 146 | upsample_strides: List[int], 147 | resblock_kernel_sizes: List[int], 148 | resblock_dilations: List[List[int]], 149 | sample_rate: int = 24000, 150 | **mel_config, 151 | ): 152 | super().__init__(n_mels=in_channels, sample_rate=sample_rate, **mel_config) 153 | self.num_kernels = len(resblock_kernel_sizes) 154 | self.num_upsamples = len(upsample_strides) 155 | self.upsample_in_channels = upsample_in_channels 156 | self.sample_rate = sample_rate 157 | 158 | # Initial sample rate for anti-aliased snake function 159 | sr = (2 * sample_rate) / np.prod(upsample_strides) 160 | 161 | # Layers 162 | self.film_generator = FreqTemporalFiLM( 163 | in_channels, upsample_in_channels, upsample_strides 164 | ) 165 | self.conv_pre = nn.Conv1d(in_channels, upsample_in_channels, 7, 1, padding=3) 166 | 167 | self.snake_ups = nn.ModuleList() 168 | self.ups = nn.ModuleList() 169 | self.resblocks = nn.ModuleList() 170 | for i, us in enumerate(upsample_strides): 171 | ch = upsample_in_channels // (2 ** (i + 1)) 172 | sr *= us 173 | 174 | self.snake_ups.append(Snake([2 * ch, 1])) 175 | self.ups.append(nn.ConvTranspose1d(2 * ch, ch, 2 * us, us, us // 2)) 176 | 177 | for rk, rd in zip(resblock_kernel_sizes, resblock_dilations): 178 | self.resblocks.append(AmpBlock(ch, rk, rd, sr)) 179 | 180 | self.snake_post = Snake([ch, 1]) 181 | self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3) 182 | 183 | def forward(self, x): 184 | # Get FiLM weights 185 | scales, shifts = self.film_generator(x) 186 | 187 | x = self.conv_pre(x) 188 | for i in range(self.num_upsamples): 189 | x = self.snake_ups[i](x) 190 | x = self.ups[i](x) 191 | 192 | # Apply FiLM weights 193 | x = (1 + 0.01 * scales[i]) * x + shifts[i] 194 | 195 | xs = 0.0 196 | for j in range(self.num_kernels): 197 | xs += self.resblocks[i * self.num_kernels + j](x) 198 | x = xs / self.num_kernels 199 | x = self.snake_post(x) 200 | x = self.conv_post(x) 201 | x = torch.tanh(x) 202 | return x 203 | -------------------------------------------------------------------------------- /gomin/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Union 4 | 5 | import librosa 6 | import numpy as np 7 | import soundfile as sf 8 | import torch 9 | 10 | from . import config 11 | from . 
import models 12 | from .utils import AUDIO_EXTS, MEL_EXTS, device_type, preprocess_inout_files 13 | 14 | 15 | def load_audio( 16 | audio_path: Union[str, bytes, os.PathLike], 17 | sample_rate: int, 18 | mono=True, 19 | fast_resample=False, 20 | ): 21 | """Load and resample audio file""" 22 | 23 | if fast_resample: 24 | audio, orig_sr = librosa.load(audio_path, sr=None, mono=mono) 25 | audio = librosa.resample( 26 | audio, orig_sr=orig_sr, target_sr=sample_rate, res_type="polyphase" 27 | ) 28 | else: 29 | audio, _ = librosa.load(audio_path, sr=sample_rate, mono=mono) 30 | 31 | audio = np.atleast_2d(audio) 32 | audio = torch.from_numpy(audio).float() 33 | return audio 34 | 35 | 36 | @torch.no_grad() 37 | def _analysis( 38 | model: torch.nn.Module, audio: Union[np.ndarray, torch.Tensor], device: torch.device 39 | ): 40 | """Convert waveform into melspectrogram.""" 41 | if isinstance(audio, np.ndarray): 42 | audio = torch.from_numpy(audio).float() 43 | 44 | if audio.ndim < 2: 45 | audio = torch.atleast_2d(audio) 46 | 47 | if audio.device != device: 48 | audio = audio.to(device) 49 | 50 | melspec = model.prepare_melspectrogram(audio) 51 | return melspec 52 | 53 | 54 | @torch.no_grad() 55 | def _synthesis(model: torch.nn.Module, melspec: torch.Tensor, device: torch.device): 56 | """Convert melspectrogram into waveform.""" 57 | if melspec.ndim < 3: 58 | melspec = torch.atleast_3d(melspec) 59 | 60 | if melspec.device != device: 61 | melspec = melspec.to(device) 62 | 63 | recon_audio = model(melspec) 64 | recon_audio = recon_audio.squeeze().cpu() 65 | return recon_audio 66 | 67 | 68 | def analysis_synthesis( 69 | model: torch.nn.Module, 70 | in_file: Union[str, bytes, os.PathLike], 71 | out_file: Union[str, bytes, os.PathLike], 72 | device: torch.device, 73 | fast_resample=False, 74 | ): 75 | """Process and save files.""" 76 | _, input_ext = os.path.splitext(in_file) 77 | _, output_ext = os.path.splitext(out_file) 78 | 79 | if input_ext in AUDIO_EXTS: 80 | audio = load_audio(in_file, model.sample_rate, fast_resample=fast_resample) 81 | melspec = _analysis(model, audio, device=device) 82 | elif input_ext in MEL_EXTS: 83 | if input_ext == ".npy": 84 | melspec = np.load(in_file) 85 | melspec = torch.from_numpy(melspec).float() 86 | else: 87 | melspec_dict = torch.load(in_file) 88 | melspec = melspec_dict["melspec"] 89 | assert melspec_dict["n_mels"] == model.n_mels, ( 90 | f"Wrong `n_mels`. expected [{model.n_mels}], got" 91 | f" [{melspec_dict['n_mels']}]." 92 | ) 93 | assert melspec_dict["sample_rate"] == model.sample_rate, ( 94 | f"Wrong `sample_rate`. expected [{model.sample_rate}], got" 95 | f" [{melspec_dict['sample_rate']}]." 96 | ) 97 | assert melspec_dict["win_length"] == model.win_length, ( 98 | f"Wrong `win_length`. expected [{model.win_length}], got" 99 | f" [{melspec_dict['win_length']}]." 100 | ) 101 | assert melspec_dict["hop_length"] == model.hop_length, ( 102 | f"Wrong `hop_length`. expected [{model.hop_length}], got" 103 | f" [{melspec_dict['hop_length']}]." 104 | ) 105 | else: 106 | print(f"Unsupported input file extension: {input_ext} for {in_file}.") 107 | 108 | if output_ext in MEL_EXTS: 109 | assert output_ext == ".pt", ( 110 | "Only '.pt' file is supported for melspectrogram extraction. got" 111 | f" '{output_ext}'." 
112 | ) 113 | torch.save( 114 | { 115 | "melspec": melspec.cpu(), 116 | "n_mels": model.n_mels, 117 | "sample_rate": model.sample_rate, 118 | "win_length": model.win_length, 119 | "hop_length": model.hop_length, 120 | }, 121 | out_file, 122 | ) 123 | elif output_ext in AUDIO_EXTS: 124 | assert output_ext == ".wav", ( 125 | "Only '.wav' file is supported for melspectrogram inversion. got" 126 | f" '{output_ext}'." 127 | ) 128 | recon_audio = _synthesis(model, melspec, device=device) 129 | sf.write(out_file, recon_audio, model.sample_rate, subtype="PCM_16") 130 | else: 131 | print(f"Unsupported output file extension: {output_ext} for {out_file}.") 132 | 133 | 134 | def parse_args(): 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument( 137 | "-m", 138 | "--model", 139 | type=str, 140 | default=None, 141 | choices=["diffusion", "gan"], 142 | help="Model type to run.", 143 | ) 144 | parser.add_argument( 145 | "-p", 146 | "--model_path", 147 | type=str, 148 | default="checkpoints/", 149 | help="Directory path to model checkpoint.", 150 | ) 151 | parser.add_argument( 152 | "-i", 153 | "--input_files", 154 | nargs="+", 155 | type=str, 156 | help=f"Path to input files. Audio files with {AUDIO_EXTS} extension and" 157 | f" melspectrogram files with {MEL_EXTS} extension are supported.", 158 | ) 159 | parser.add_argument( 160 | "-o", 161 | "--output_files", 162 | nargs="+", 163 | type=str, 164 | default=["outputs/"], 165 | help="(Optional) Path to output files. Audio files with '.wav' extension and" 166 | " melspectrogram files with '.pt' extension are supported. If both input and" 167 | " output files are melspectrograms, an error will be raised. (default: `outputs/`)", 168 | ) 169 | parser.add_argument( 170 | "-d", 171 | "--device", 172 | nargs="+", 173 | type=device_type, 174 | default=["cpu"], 175 | help="Device to use. Currently multi-GPU inference is not supported. (default: " 176 | "'cpu')", 177 | ) 178 | parser.add_argument( 179 | "--n_mels", 180 | type=int, 181 | default=None, 182 | help="Number of bins in melspectrogram. If `args.model` is provided, this value" 183 | " will be ignored.", 184 | ) 185 | parser.add_argument( 186 | "--sample_rate", 187 | type=int, 188 | default=None, 189 | help="Sample rate for audio sample and melspectrogram. If `args.model` is" 190 | " provided, this value will be ignored.", 191 | ) 192 | parser.add_argument( 193 | "--win_length", 194 | type=int, 195 | default=None, 196 | help="Window length for melspectrogram. If `args.model` is provided, this value" 197 | " will be ignored.", 198 | ) 199 | parser.add_argument( 200 | "--hop_length", 201 | type=int, 202 | default=None, 203 | help="Hop length for melspectrogram. 
If `args.model` is provided, this value" 204 | " will be ignored.", 205 | ) 206 | 207 | args = parser.parse_args() 208 | return args 209 | 210 | 211 | def process(model, device, input_files, output_files): 212 | model.to(device) 213 | 214 | for in_file, out_file in zip(input_files, output_files): 215 | analysis_synthesis( 216 | model=model, 217 | in_file=in_file, 218 | out_file=out_file, 219 | device=device, 220 | fast_resample=len(input_files) > 1, 221 | ) 222 | 223 | 224 | def main(): 225 | args = parse_args() 226 | args.input_files, args.output_files = preprocess_inout_files( 227 | args.input_files, args.output_files 228 | ) 229 | 230 | if args.model == "gan": 231 | model = models.GomiGAN.from_pretrained( 232 | pretrained_model_path="checkpoints/gan_state_dict.pt", 233 | **config.GANConfig().__dict__, 234 | ) 235 | elif args.model == "diffusion": 236 | model = models.DiffusionWrapper.from_pretrained( 237 | pretrained_model_path="checkpoints/diffusion_state_dict.pt", 238 | **config.DiffusionConfig().__dict__, 239 | ) 240 | elif args.model is None: 241 | model = models.MelspecInversion( 242 | n_mels=args.n_mels, 243 | sample_rate=args.sample_rate, 244 | win_length=args.win_length, 245 | hop_length=args.hop_length, 246 | ) 247 | 248 | process(model, args.device[0], args.input_files, args.output_files) 249 | 250 | 251 | if __name__ == "__main__": 252 | main() 253 | -------------------------------------------------------------------------------- /gomin/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from typing import Final, List, Union 4 | import warnings 5 | 6 | 7 | AUDIO_EXTS: Final = [".wav", ".mp3"] 8 | MEL_EXTS: Final = [".npy", ".pt"] 9 | OUTPUT_EXTS: Final = [".wav", ".pt"] 10 | INPUT_EXTS: Final = AUDIO_EXTS + MEL_EXTS 11 | 12 | 13 | def device_type(device: Union[str, int]): 14 | if device == "cpu": 15 | return device 16 | 17 | assert device.isnumeric() 18 | return f"cuda:{device}" 19 | 20 | 21 | def check_paths(input_filelist: List[str], output_filelist: List[str]): 22 | """Check errors and return filelist if no error detected. 23 | 24 | This methode checks 3 types of errors: 25 | (1) Extension 26 | Check if filelists have proper extensions. The allowed extensions 27 | are defined as constants. 28 | (2) Lengths 29 | Check input and output filelist have same length 30 | (3) Undefined behaviours 31 | If there is 'mel -> mel' match in input and output filelists, it 32 | raises error. 33 | """ 34 | 35 | def _check_extension(filelist: List[str], key: str): 36 | """Check if filelist""" 37 | allowed_exts = INPUT_EXTS if key == "input" else OUTPUT_EXTS 38 | 39 | if any([os.path.splitext(x)[1] not in allowed_exts for x in filelist]): 40 | raise ValueError( 41 | f"Unsupported format for `{key}_files`. Only {allowed_exts} are" 42 | " supported." 43 | ) 44 | 45 | def _check_lengths(): 46 | if len(input_filelist) != len(output_filelist): 47 | raise ValueError( 48 | f"Mistmatched input / output length. Input length:" 49 | f" {len(input_filelist)}, Output length: {len(output_filelist)}" 50 | ) 51 | 52 | def _check_mel_to_mel(): 53 | for in_file, out_file in zip(input_filelist, output_filelist): 54 | _, in_ext = os.path.splitext(in_file) 55 | _, out_ext = os.path.splitext(out_file) 56 | if in_ext == out_ext in MEL_EXTS: 57 | raise ValueError( 58 | "Unsupported behaviour. Behaviour for melspectrogram to" 59 | f" melspectrogram is not defined. got '-i {in_file} -o " 60 | f"{out_file}'." 
61 | ) 62 | 63 | _check_extension(input_filelist, key="input") 64 | _check_extension(output_filelist, key="output") 65 | _check_lengths() 66 | _check_mel_to_mel() 67 | return input_filelist, output_filelist 68 | 69 | 70 | def preprocess_inout_files(input_files: List[str], output_files: List[str]): 71 | """Process input and output filelists. 72 | 73 | For various types of input / output path arguments, it preprocess paths. For 74 | input paths, it reads text files containing audio file paths or search 75 | directories using `glob.glob`. For output paths, it automatically generates 76 | output file paths corresponding to input file paths. 77 | 78 | For more informations or examples, please refer `tests/testcases_path.yaml`. 79 | 80 | Params: 81 | input_files: file paths for inputs. 82 | output_files: file paths for outputs. 83 | 84 | Notes: 85 | `input_files` and `output_files` support various formats. However only 86 | naive list of paths is supported for multiple item, i.e, multiple items 87 | for directory path, text file or paths including wild card will result 88 | in unexpected outcome. 89 | 90 | Supported formats for arguments include: 91 | single or multiple audio / mel file paths: 92 | - ["foo/bar/input2.wav"] 93 | - ["foo/bar/input1.wav", "bar/input2.pt", "foo/input3.mp3"] 94 | (single) text file path 95 | - ["foo/bar/input.txt"] 96 | (single) directory 97 | - ["foo/"] 98 | - ["bar/"] 99 | (single) file or directory path with wild card 100 | - ["foo/*.wav"] 101 | - ["foo/**/*] 102 | - ["bar/**/input*.wav"] 103 | 104 | """ 105 | 106 | def _common_process(filelist, key): 107 | # Only "input" and "output" are allowed for key 108 | assert key in ["input", "output"], f"Unknown key: {key}" 109 | 110 | file_base, file_ext = os.path.splitext(filelist[0]) 111 | 112 | if file_ext == ".txt": 113 | return [line.strip() for line in open(filelist[0])] 114 | elif "*" in file_base: 115 | return filelist[0] 116 | elif file_ext == "": 117 | # Directory path case 118 | return os.path.join(file_base, "**/*") 119 | 120 | return filelist 121 | 122 | def _add_file(input_filelist: List[str], output_filelist: List[str], new_item): 123 | """Check file name confliction and add file path to the list. 124 | 125 | If `new_item` exists in `input_filelist` or `output_filelist`, rename it 126 | by add extension (e.g., ".wav") of input file name right before its 127 | real extension. 
128 | """ 129 | while new_item in input_filelist + output_filelist: 130 | new_item_base, new_item_ext = os.path.splitext(new_item) 131 | _, input_ext = os.path.splitext(input_filelist[len(output_filelist)]) 132 | new_item = new_item_base + input_ext + new_item_ext 133 | output_filelist.append(new_item) 134 | 135 | # Read text files or add wild card to the path prpoperly 136 | input_files = _common_process(input_files, "input") 137 | output_files = _common_process(output_files, "output") 138 | 139 | if len(input_files) == 0: 140 | # If `input_files` is empty, make a warning and exit immediately 141 | warnings.warn("No inputs.", UserWarning) 142 | return [], [] 143 | elif "*" not in input_files and "*" not in output_files: 144 | # If there are no wild card in both lists, return them 145 | return check_paths(input_files, output_files) 146 | 147 | # If there is wild card in `input_files`, `output_files` should have it too 148 | assert "*" in output_files 149 | 150 | # Search `input_files` using `glob.glob` 151 | if "*" in input_files: 152 | in_base, in_ext = os.path.splitext(input_files) 153 | 154 | input_files = [] 155 | for ext in INPUT_EXTS if in_ext in ["", ".*"] else [in_ext]: 156 | input_files.extend(glob.glob(f"{in_base}{ext}", recursive=True)) 157 | 158 | # Aliases for further processing 159 | input_root = os.path.commonpath([os.path.split(f)[0] for f in input_files]) 160 | input_prefix = os.path.commonprefix([os.path.split(f)[1] for f in input_files]) 161 | 162 | keep_subdirs = "**" in output_files 163 | output_root, output_tail = (x.split("*", 1)[0] for x in os.path.split(output_files)) 164 | 165 | _, output_ext = os.path.splitext(output_files) 166 | if not output_ext: 167 | output_ext = ".wav" 168 | 169 | output_files = [] 170 | for in_file in input_files: 171 | subdir, filename = os.path.split(in_file) 172 | subdir = subdir.removeprefix(input_root).strip("/") 173 | 174 | # Drop sub-directory path if needed 175 | if not keep_subdirs: 176 | subdir = "" 177 | 178 | # If output files have certain naming convention 179 | # e.g.) "foo/bar/output*.wav" 180 | if output_tail: 181 | filename = output_tail + filename.removeprefix(input_prefix) 182 | 183 | file_base, _ = os.path.splitext(filename) 184 | filepath = os.path.join(output_root, subdir, f"{file_base}{output_ext}") 185 | 186 | # Check if `filepath` conflicts and add it to `output_files` 187 | _add_file(input_files, output_files, filepath) 188 | 189 | return check_paths(input_files, output_files) 190 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # PYTHON 3.10.9 2 | torch==1.13.1+cu116 3 | torchaudio==0.13.1+cu116 4 | diffusers==0.17.0\ 5 | librosa>=0.9.2 6 | scipy>=1.10.0 7 | soundfile>=0.11.0 8 | numpy==1.24.3 9 | pyyaml 10 | tqdm -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="gaudio_open_vocoder", 5 | version="0.1.0", 6 | packages=find_packages(include=["gomin", "gomin.*"]), 7 | install_requires=[ 8 | "torch>=1.13.1", 9 | "torchaudio>=0.13.1", 10 | "diffusers>=0.17.0", 11 | "librosa>=0.9.2", 12 | "scipy>=1.10.0", 13 | "soundfile>=0.11.0", 14 | "numpy>=1.24.3", 15 | "pyyaml", 16 | "tqdm", 17 | ], 18 | ) 19 | --------------------------------------------------------------------------------