├── LICENSE ├── README.md ├── config ├── .DS_Store ├── __init__.py ├── conf.py ├── datasets │ └── data-config.yaml ├── logging.ini └── pipelines │ ├── asr_to_tts.yaml │ └── yt_data.yaml ├── manager ├── __init__.py ├── downloader.py └── runner.py ├── modules ├── __init__.py ├── audio.py ├── audio_superres.py ├── audio_superres_utils.py ├── chunking.py ├── common.py ├── demucs_utils.py ├── denoise_audio.py ├── denoiser_utils.py ├── lang_list.py ├── music_separation.py └── transcribe.py ├── requirements.txt ├── utils ├── __init__.py ├── data.py ├── helpers.py ├── io.py ├── loggers.py └── validators.py └── workers ├── __init__.py └── pipeline.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Audio pipeline for TTS Datasets

This project aims to provide high-level APIs for various feature engineering techniques for processing audio files. The implementation follows a modular, config-based approach, so any dataset-processing pipeline can be built and managed within the same framework.
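
For orientation, here is a simplified sketch of how the pipeline configs under `config/pipelines/` are consumed (see the next section for the schema): every `target:` string is resolved to a Python object and instantiated with its `args`. This mirrors the `get_obj_from_str` helper in `utils/helpers.py` used by `manager/runner.py`; treat it as an illustrative sketch rather than the exact implementation.

```
import importlib
import yaml

def obj_from_str(path):
    # "modules.DenoiseAudio" -> the DenoiseAudio class exported by `modules`
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

with open("config/pipelines/yt_data.yaml") as fh:
    cfg = yaml.safe_load(fh)["pipeline"]

# Build the ordered processor list declared under `processors:`
processors = [
    obj_from_str(proc["target"])(**proc.get("args", {}))
    for proc in cfg["processors"]
]
```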

### Creating a new pipeline

Create a YAML file under the `config/pipelines/` directory with the following structure:

```
pipeline:
  loader:
    target: manager.Downloader
    args:
      configs:
        - config/datasets/data-config.yaml
      save_dir: raw_data/yt_data
  manager:
    target: manager.YoutubeRunner
  processors:
    - name: chunking
      target: modules.AudioChunking
      args:
        model_choice: pydub_chunking
    - name: denoise_audio
      target: modules.DenoiseAudio
      args:
        model_choice: meta_denoiser_dns48
    - name: audio_superres
      target: modules.SuperResAudio
      args:
        model_choice: voicefixer
    ...
```

**Pipeline Schema**

- **Loader**: The entry point for fetching data from various sources like S3, local systems, or blob storage.
- **Manager**: Specifies the manager class responsible for running the pipeline.
- **Processors**: An ordered list of processors to apply for feature extraction or other manipulations.


If new feature extractors or managers are required for your needs, check the `modules/` directory to understand the structure, and create or update the objects as needed.

### Run pipelines

```
python workers/pipeline.py --configs config/pipelines/yt_data.yaml
```

## Acknowledgements
Credit to a few of the amazing folks in the community who helped make this happen:
- [bud-studio](https://bud.studio/) - For providing an initial framework
--------------------------------------------------------------------------------
/config/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LAION-AI/Text-to-speech/4d69b12975b3a74f37b11c93edf83d55e133b649/config/.DS_Store
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
from .conf import settings
--------------------------------------------------------------------------------
/config/conf.py:
--------------------------------------------------------------------------------
from os import path as osp
from environs import Env

from utils.helpers import NestedNamespace


env = Env()
env.read_env()

DIR_PATH = osp.dirname(osp.realpath(__file__))
ROOT_PATH = osp.abspath(osp.join(osp.dirname(__file__), ".."
+ osp.sep)) 12 | 13 | 14 | settings = { 15 | "CONSTANTS": {}, 16 | "ROOT_PATH": ROOT_PATH, 17 | "DEBUG": env.bool("DEBUG", False), 18 | "LOG_LEVEL": env.str("LOG_LEVEL", "DEBUG"), 19 | "LOG_DIR": env.str("LOG_DIR", osp.join(ROOT_PATH, "logs")), 20 | "DATA_DIR": env.str("DATA_DIR", osp.join(ROOT_PATH, "data")), 21 | "CACHE_DIR": env.str("CACHE_DIR", osp.join(ROOT_PATH, "cache")), 22 | "huggingface": {"HF_TOKEN": env.str("HF_TOKEN")}, 23 | } 24 | 25 | 26 | settings = NestedNamespace(settings) 27 | -------------------------------------------------------------------------------- /config/datasets/data-config.yaml: -------------------------------------------------------------------------------- 1 | sources: 2 | - https://www.youtube.com/watch?v=4utBo-9hMSc 3 | - https://www.youtube.com/watch?v=ncAnC7pVr7w -------------------------------------------------------------------------------- /config/logging.ini: -------------------------------------------------------------------------------- 1 | [loggers] 2 | keys=root, server_log, module_log 3 | 4 | [handlers] 5 | keys=consoleHandler, server_hand, module_hand 6 | 7 | [formatters] 8 | keys=simpleFormatter, detailedFormatter, json 9 | 10 | [logger_root] 11 | handlers=consoleHandler 12 | level=NOTSET 13 | 14 | [logger_server_log] 15 | level=NOTSET 16 | handlers=server_hand 17 | qualname=server_log 18 | 19 | [logger_module_log] 20 | level=NOTSET 21 | handlers=module_hand 22 | qualname=module_log 23 | 24 | [handler_server_hand] 25 | class=logging.handlers.TimedRotatingFileHandler 26 | level=NOTSET 27 | formatter=json 28 | args=('%(logdir)s/server.log', 'W6', 1, 5, None, False, True) 29 | 30 | [handler_module_hand] 31 | class=logging.handlers.TimedRotatingFileHandler 32 | level=NOTSET 33 | formatter=json 34 | args=('%(logdir)s/module.log', 'W6', 1, 5, None, False, True) 35 | 36 | [handler_consoleHandler] 37 | class=StreamHandler 38 | level=NOTSET 39 | formatter=simpleFormatter 40 | args=(sys.stdout,) 41 | 42 | [formatter_simpleFormatter] 43 | format=%(asctime)s - [%(levelname)s] - %(lineno)d - %(message)s 44 | datefmt=%Y-%m-%d %H:%M:%S 45 | 46 | [formatter_detailedFormatter] 47 | class=utils.loggers.ColorFormatter 48 | datefmt=%Y-%m-%d %H:%M:%S 49 | 50 | [formatter_json] 51 | class=pythonjsonlogger.jsonlogger.JsonFormatter 52 | format=%(asctime)s - [%(threadName)-12.12s] [%(levelname)s] - %(name)s - (%(filename)s).%(funcName)s(%(lineno)d) - %(message)s 53 | datefmt=%Y-%m-%d %H:%M:%S -------------------------------------------------------------------------------- /config/pipelines/asr_to_tts.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | loader: 3 | target: manager.Downloader 4 | args: 5 | configs: 6 | - config/datasets/data-config.yaml 7 | manager: 8 | target: manager.ASR2TTSRunner 9 | processors: 10 | - name: downloader 11 | target: modules.Downloader 12 | - name: voice_activity_detection 13 | target: modules.VoiceActivityDetection 14 | args: 15 | model_choice: webrtc_voice_activity_detection 16 | - name: denoiser 17 | target: modules.DenoiseAudio 18 | args: 19 | model_choice: master64 20 | - name: superres 21 | target: modules.SuperResAudio 22 | args: 23 | model_choice: voicefixer -------------------------------------------------------------------------------- /config/pipelines/yt_data.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | loader: 3 | target: manager.Downloader 4 | args: 5 | configs: 6 | - config/datasets/data-config.yaml 7 | 
      save_dir: raw_data/yt_data
  manager:
    target: manager.YoutubeRunner
  processors:
    - name: chunking
      target: modules.AudioChunking
      args:
        model_choice: pydub_chunking
    - name: denoise_audio
      target: modules.DenoiseAudio
      args:
        model_choice: meta_denoiser_dns48
    - name: audio_superres
      target: modules.SuperResAudio
      args:
        model_choice: voicefixer
    - name: transcription
      target: modules.TranscribeAudio
      args:
        model_choice: openai_whisper_base
--------------------------------------------------------------------------------
/manager/__init__.py:
--------------------------------------------------------------------------------
from utils.loggers import init_loggers

init_loggers()


from .downloader import Downloader
from .runner import *
--------------------------------------------------------------------------------
/manager/downloader.py:
--------------------------------------------------------------------------------
import os
import re
import zipfile
from uuid import uuid4
from tqdm import tqdm
from pytube import YouTube
import wget
from urllib.parse import urlparse

from utils.io import load_configs
from utils.helpers import exists
from config import settings


class Downloader:
    def __init__(self, configs=None, save_dir="data/") -> None:
        if exists(configs):
            self.configs = load_configs(configs[0])
        else:
            self.configs = None
        self.save_dir = save_dir

    def download_from_youtube(self, url, save_dir=None):
        # Fall back to the instance-level save_dir so callers can omit it
        save_dir = save_dir or self.save_dir
        yt = YouTube(url)
        stream = yt.streams.get_highest_resolution()
        save_path = os.path.join(save_dir, f"{str(uuid4())}")
        file_path = stream.download(output_path=save_path)

        directory, filename = os.path.split(file_path)
        file_root, file_extension = os.path.splitext(filename)
        sanitized_root = re.sub(r"[^a-zA-Z0-9 ]", "", file_root)
        sanitized_root = sanitized_root.replace(" ", "_")

        new_filename = f"{sanitized_root}{file_extension}"
        new_file_path = os.path.join(directory, new_filename)

        os.rename(file_path, new_file_path)
        metadata = {"video": new_file_path, "source": url}
        return (metadata, True)

    def download_from_url(self, url, save_dir=None):
        save_dir = save_dir or self.save_dir
        save_path = os.path.join(save_dir, f"{str(uuid4())}.zip")
        wget.download(url, save_path)
        metadata = {"file": save_path, "source": url}
        return (metadata, True)

    def unzip_file(self, path, save_dir):
        with zipfile.ZipFile(path, 'r') as zip_ref:
            zip_ref.extractall(save_dir)

    def walk_files(self, save_dir=None):
        save_dir = save_dir or self.save_dir
        for path in tqdm(self.configs.sources):
            if "youtube.com" in path:
                metadata, _ = self.download_from_youtube(path, save_dir)
                yield metadata
            else:
                metadata, _ = self.download_from_url(path, save_dir)
                yield metadata

    def download_file(self, metadata, save_dir):
        if "source" in metadata:
            parsed_url = urlparse(metadata["source"])
            if "youtube.com" in parsed_url.netloc:
                return self.download_from_youtube(metadata["source"], save_dir)
            else:
                metadata, status = self.download_from_url(metadata["source"], save_dir)
                if status and metadata['file'].endswith('.zip'):
                    self.unzip_file(metadata['file'], save_dir)
                return (metadata, status)

    def __call__(self, metadata, save_dir=None):
        if not exists(save_dir):
            save_dir = os.path.join(settings.CACHE_DIR, "tmp", "downloads")
        return self.download_file(metadata, save_dir)
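
# Usage example (illustrative sketch; the config path and save_dir mirror
# config/pipelines/yt_data.yaml rather than a shipped fixture)
if __name__ == "__main__":
    downloader = Downloader(
        configs=["config/datasets/data-config.yaml"],
        save_dir="raw_data/yt_data",
    )
    for metadata in downloader.walk_files():
        print(metadata)  # e.g. {"video": "raw_data/yt_data/<uuid>/<title>.mp4", "source": "<url>"}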
--------------------------------------------------------------------------------
/manager/runner.py:
--------------------------------------------------------------------------------
from os import path as osp
from uuid import uuid4
from tqdm import tqdm
from pathlib import Path
import inspect
import time

from modules.audio import convert2wav
from config import settings
from utils.io import load_configs, merge_configs
from utils.helpers import exists, get_obj_from_str
from utils.loggers import get_logger

logger = get_logger("module_log")


class Runner:
    ALLOWED_PROCESSORS = [
        "downloader",
        "speaker_diarization",
        "chunking",
        "music_separation",
        "denoise_audio",
        "gender_classification",
        "emotion_classification",
        "transcription",
        "audio_superres",
    ]

    def __init__(self, configs, lazy_load=True) -> None:
        self.config = load_configs(configs)
        self.lazy_load = lazy_load
        self.processed_dags = {}

        self.processors = {}
        if not hasattr(self.config, "processors"):
            self.config = merge_configs(self.config, {"processors": []})
        for proc in self.config.processors:
            assert (
                proc.name in self.ALLOWED_PROCESSORS
            ), f"Processor {proc.name} is not supported"
            assert (
                proc.name not in self.processors
            ), f"A processor already exists with the name {proc.name}"

            obj = get_obj_from_str(proc.target)
            if inspect.isclass(obj) and not self.lazy_load:
                self.processors[proc.name] = {"obj": obj(**proc.args), "loaded": True}
            else:
                self.processors[proc.name] = {"obj": obj, "loaded": False}

    def load_processors(self, names=None, reload=False):
        if not (self.lazy_load or reload):
            return
        if exists(names) and isinstance(names, str):
            names = [names]

        for proc in self.config.processors:
            if (
                (not exists(names) or proc.name in names)
                and (not self.processors[proc.name]["loaded"] or reload)
                and inspect.isclass(self.processors[proc.name]["obj"])
            ):
                if hasattr(proc, "args"):
                    self.processors[proc.name]["obj"] = self.processors[proc.name][
                        "obj"
                    ](**proc.args)
                else:
                    self.processors[proc.name]["obj"] = self.processors[proc.name][
                        "obj"
                    ]()
                self.processors[proc.name]["loaded"] = True

    def offload_processors(self, names=None):
        if exists(names) and isinstance(names, str):
            names = [names]

        for proc in self.config.processors:
            if not exists(names) or proc.name in names:
                del self.processors[proc.name]
                self.processors[proc.name] = {
                    "obj": get_obj_from_str(proc.target),
                    "loaded": False,
                }
                logger.info(f"Dag {proc.name} offloaded")

    def resolve_dag_processor(self, name):
        raw_name = name
        name = name.split(".")[0]
        assert name in self.processors, f"Processor for dag {raw_name} is not declared"
        self.load_processors(name)
        return self.processors[name]["obj"]

    def refactor_dag_name_with_ordinal(self, name):
        # Track how many times each processor name has been used so repeated
        # dags get unique ordinal suffixes (e.g. "chunking.0", "chunking.1").
        base_name = name.split(".")[0]

        if base_name not in self.processed_dags:
            self.processed_dags[base_name] = 0
        else:
            self.processed_dags[base_name] += 1

        if "." not in name:
            name = f"{name}.{self.processed_dags[base_name]}"
        return name

    def run_dag(self, name, **kwargs):
        processor = self.resolve_dag_processor(name)
        return processor(**kwargs)

    def cleanup_dag(self, name):
        if self.lazy_load:
            logger.info("Cleaning up dag: %s", name)
            self.offload_processors(name)
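
# Lazy-loading contract, sketched for reference (behaviour as implemented
# above; the config path is an assumption):
#
#     runner = Runner("config/pipelines/yt_data.yaml", lazy_load=True)
#     runner.processors["chunking"]["loaded"]   # False -> bare class stored
#     runner.load_processors("chunking")        # instantiates AudioChunking(**args)
#     runner.offload_processors("chunking")     # drops the instance, keeps the class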

class YoutubeRunner(Runner):
    def __init__(self, configs, lazy_load=True) -> None:
        super().__init__(configs, lazy_load)

    def __call__(self, file_metadata, **kwargs):
        cache_dir = osp.join(settings.CACHE_DIR, "tmp", str(uuid4()))
        base_name = osp.splitext(osp.basename(file_metadata["video"]))[0]
        wav_path = convert2wav(file_metadata["video"])

        dag_name = "chunking"
        logger.info(f"Running pipeline -> {dag_name}")
        now = time.time()
        audio_chunks = self.run_dag(
            dag_name,
            audio_path=wav_path,
            save_to_file=True,
            save_dir=osp.join("data", base_name, "chunked_audio"),
        )
        file_metadata[dag_name] = audio_chunks
        file_metadata[f"{dag_name}_proc_time"] = time.time() - now

        dag_name = "denoise_audio"
        logger.info(f"Running pipeline -> {dag_name}")
        total_time = 0
        for v, va in tqdm(
            enumerate(file_metadata["chunking"]["audio_chunks"]),
            desc=dag_name,
        ):
            now = time.time()
            enhanced_audio = self.run_dag(
                dag_name,
                audio_path=va["filepath"],
                save_to_file=True,
                save_dir=osp.join("data", base_name, "denoise_audio"),
            )
            proc_time = time.time() - now
            file_metadata["chunking"]["audio_chunks"][v].update(
                {dag_name: enhanced_audio, "enhancement_proc_time": proc_time}
            )
            total_time += proc_time
        file_metadata[f"{dag_name}_proc_time"] = total_time
        self.cleanup_dag(dag_name)

        dag_name = "audio_superres"
        logger.info(f"Running pipeline -> {dag_name}")
        total_time = 0
        for v, va in tqdm(
            enumerate(file_metadata["chunking"]["audio_chunks"]),
            desc=dag_name,
        ):
            now = time.time()
            enhanced_audio = self.run_dag(
                dag_name,
                audio_path=va["denoise_audio"],
                save_to_file=True,
                save_dir=osp.join("data", base_name, "superres_audio"),
            )
            proc_time = time.time() - now
            file_metadata["chunking"]["audio_chunks"][v].update(
                {dag_name: enhanced_audio, "enhancement_proc_time": proc_time}
            )
            total_time += proc_time
        file_metadata[f"{dag_name}_proc_time"] = total_time
        self.cleanup_dag(dag_name)

        dag_name = "transcription"
        logger.info(f"Running pipeline -> {dag_name}")
        total_time = 0
        for v, va in tqdm(
            enumerate(file_metadata["chunking"]["audio_chunks"]),
            desc=dag_name,
        ):
            now = time.time()
            transcription = self.run_dag(
                dag_name,
                audio_path=va["audio_superres"],
            )
            proc_time = time.time() - now
            total_time += proc_time

            transcript_dir = Path("data") / base_name / "transcription"
            transcript_dir.mkdir(parents=True, exist_ok=True)

            chunk_name = osp.splitext(osp.basename(va["filepath"]))[0]
            file_path = str(transcript_dir / f"{chunk_name}.txt")
            with open(file_path, "w", encoding="utf-8") as transcript_file:
                transcript_file.write(transcription)

            # Record the transcript path for this chunk
            file_metadata["chunking"]["audio_chunks"][v].update(
                {dag_name: file_path, "transcription_proc_time": proc_time}
            )
        file_metadata[f"{dag_name}_proc_time"] = total_time
        self.cleanup_dag(dag_name)
        return file_metadata
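
# For reference, the file_metadata returned by YoutubeRunner roughly looks
# like this (illustrative values only):
#
#     {
#         "video": "raw_data/yt_data/<uuid>/<title>.mp4",
#         "source": "https://www.youtube.com/watch?v=...",
#         "chunking": {
#             "audio_chunks": [
#                 {
#                     "duration": 12.3,
#                     "filepath": "data/<title>/chunked_audio/<title>_chunk_0.wav",
#                     "sample_rate": 16000,
#                     "denoise_audio": "data/<title>/denoise_audio/...",
#                     "audio_superres": "data/<title>/superres_audio/...",
#                     "transcription": "data/<title>/transcription/..._chunk_0.txt",
#                     "enhancement_proc_time": 0.8,
#                     "transcription_proc_time": 1.1,
#                 },
#             ],
#             "total_chunk_duration": 12.3,
#             "total_audio_duration": 15.0,
#         },
#         "chunking_proc_time": 1.2,
#         "denoise_audio_proc_time": 0.8,
#         "audio_superres_proc_time": 2.4,
#         "transcription_proc_time": 1.1,
#     }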
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
from .transcribe import TranscribeAudio
from .chunking import AudioChunking
from .denoise_audio import DenoiseAudio
from .audio_superres import SuperResAudio
from .music_separation import PartitionAudio
--------------------------------------------------------------------------------
/modules/audio.py:
--------------------------------------------------------------------------------
import torchaudio
import lameenc
from pathlib import Path
from os import system as os_system
from collections import namedtuple

FFMPEG_BIN = "ffmpeg"
Info = namedtuple("Info", ["length", "sample_rate", "channels"])


def load_audio(audio_path):
    audio, sr = torchaudio.load(audio_path)
    return audio, sr


def get_audio_info(audio_path):
    info = torchaudio.info(audio_path)
    if hasattr(info, "num_frames"):
        return Info(info.num_frames, info.sample_rate, info.num_channels)
    else:
        siginfo = info[0]
        return Info(siginfo.length // siginfo.channels, siginfo.rate, siginfo.channels)


def normalize_audio(wav):
    return wav / max(wav.abs().max().item(), 1)


def convert_channels(wav, channels):
    if wav.shape[0] != channels:
        wav = wav.mean(dim=0, keepdim=True).expand(channels, -1)
    return wav


def convert_audio(wav, from_sr, to_sr, channels):
    if from_sr != to_sr:
        wav = torchaudio.transforms.Resample(from_sr, to_sr)(wav)
    return convert_channels(wav, channels)


def save_audio(wav, path, sr, bitrate=320, bits_per_sample=16):
    path = Path(path)
    if path.suffix.lower() == ".mp3":
        encode_mp3(wav, path, sr, bitrate)
    else:
        torchaudio.save(str(path), wav, sample_rate=sr, bits_per_sample=bits_per_sample)


def encode_mp3(wav, path, sr, bitrate):
    # lameenc's setters do not chain, so configure the encoder step by step,
    # and flush it at the end to avoid truncating the final MP3 frames.
    channels = wav.shape[0] if wav.ndim == 2 else 1
    pcm = (wav.clamp_(-1, 1) * (2 ** 15 - 1)).short().data.cpu().numpy().T
    encoder = lameenc.Encoder()
    encoder.set_bit_rate(bitrate)
    encoder.set_in_sample_rate(sr)
    encoder.set_channels(channels)
    mp3_data = encoder.encode(pcm.tobytes())
    mp3_data += encoder.flush()
    with open(path, "wb") as f:
        f.write(mp3_data)


def convert_and_trim_ffmpeg(src_file, dst_file, sr, start_tm, end_tm):
    # Quote the paths so files with spaces survive the shell invocation
    cmd = f'{FFMPEG_BIN} -i "{src_file}" -ar {sr} -ac 1 -ss {start_tm} -to {end_tm} "{dst_file}" -y -loglevel panic'
    os_system(cmd)


def get_duration(wave_file, sr=16000):
    # Trust the sample rate reported by the file over the `sr` argument
    y, file_sr = torchaudio.load(wave_file)
    return y.shape[-1] / file_sr


def convert2wav(audio_path):
    ext = Path(audio_path).suffix.lower()
    supported_ext = [".sph", ".wav", ".mp3", ".flac", ".ogg", ".mp4"]

    if ext not in supported_ext:
        raise NotImplementedError(f"Audio format {ext} is not supported")

    if ext != ".wav":
        dst_path = audio_path.replace(ext, ".wav")
        cmd = f'{FFMPEG_BIN} -i "{audio_path}" "{dst_path}"'
        os_system(cmd)
        return dst_path
    else:
        return audio_path


# Usage example
if __name__ == '__main__':
    audio, sr = load_audio("audio_path")
    audio = normalize_audio(audio)
    audio = convert_audio(audio, sr, 44100, 1)
    save_audio(audio, "output_path", 44100)
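
# Optional hardening sketch (an assumption, not behaviour used elsewhere in
# this repo): os.system with interpolated paths stays fragile even when
# quoted; passing an argument list to subprocess avoids shell parsing.
import subprocess

def convert2wav_safe(audio_path):
    dst_path = str(Path(audio_path).with_suffix(".wav"))
    subprocess.run([FFMPEG_BIN, "-y", "-i", audio_path, dst_path], check=True)
    return dst_path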
--------------------------------------------------------------------------------
/modules/audio_superres.py:
--------------------------------------------------------------------------------
from os import path as osp
import os
from audiosr import super_resolution
from functools import partial
import argparse

from .common import Base
from modules.audio_superres_utils import load_audiosr
from voicefixer import VoiceFixer
from config import settings

cache_dir = osp.join(settings.CACHE_DIR, "weights", "enhancement")


class SuperResAudio(Base):
    MODEL_CHOICES = {
        "audiosr": {
            "model": partial(
                load_audiosr,
                args=argparse.Namespace(
                    **{
                        "model_name": None,
                        "device": "auto",
                    }
                ),
            ),
            "target": "sr_with_audiosr",
        },
        "voicefixer": {
            "model": VoiceFixer,
            "target": "sr_with_voicefixer",
        },
    }

    def sr_with_audiosr(self, audio_path, **kwargs):
        waveform = super_resolution(
            self.model["model"],
            audio_path,
            guidance_scale=3.5,
            ddim_steps=50,
            latent_t_per_second=12.8,
        )
        return waveform

    def sr_with_voicefixer(self, audio_path, **kwargs):
        save_dir = kwargs.get("save_dir")
        if not osp.exists(save_dir):
            os.makedirs(save_dir)
        original_file_name = osp.basename(audio_path)
        self.model["model"].restore(
            input=audio_path,  # low quality .wav/.flac file
            output=osp.join(save_dir, original_file_name),  # save file path
            cuda=True,  # GPU acceleration
            mode=0,
        )
        return osp.join(save_dir, original_file_name)
--------------------------------------------------------------------------------
/modules/audio_superres_utils.py:
--------------------------------------------------------------------------------
from audiosr import build_model


def load_audiosr(args):
    return build_model(args.model_name, device=args.device)
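
# Usage sketch (illustrative; assumes voicefixer weights can be fetched and
# that the input chunk exists):
if __name__ == "__main__":
    from modules.audio_superres import SuperResAudio

    upscaler = SuperResAudio(model_choice="voicefixer")
    out_path = upscaler(audio_path="chunk_0.wav", save_dir="superres_out")
    print(out_path)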
--------------------------------------------------------------------------------
/modules/chunking.py:
--------------------------------------------------------------------------------
from os import path as osp
import os
from typing import Any

from pydub import AudioSegment
from pydub.silence import split_on_silence

from .common import Base
from . import audio


class AudioChunking(Base):
    MODEL_CHOICES = {
        "pydub_chunking": {
            "target": "chunk_by_silence",
        }
    }

    def __init__(self, model_choice: str, **kwargs) -> None:
        super().__init__(model_choice, **kwargs)

    def chunk_by_silence(self, audio_path, silence_len=800, silence_thresh=-40,
                         min_chunk_len=2.0, max_chunk_len=25, **kwargs) -> Any:
        audio_info = audio.get_audio_info(audio_path)
        audio_segment = AudioSegment.from_wav(audio_path)

        # Use pydub to split audio based on silence
        chunks = split_on_silence(audio_segment,
                                  min_silence_len=silence_len,
                                  silence_thresh=silence_thresh)

        chunk_list = []
        total_chunk_duration = 0

        for i, chunk in enumerate(chunks):
            chunk_duration = chunk.duration_seconds
            # Skip chunks outside the configured duration window (seconds)
            if chunk_duration < min_chunk_len or chunk_duration > max_chunk_len:
                continue
            meta = {
                "duration": chunk_duration,
                "filepath": None,
                "sample_rate": chunk.frame_rate,
            }

            total_chunk_duration += meta["duration"]

            if "save_to_file" in kwargs and kwargs["save_to_file"]:
                meta["filepath"] = self.save_to_file(chunk, chunk.frame_rate, audio_path, i, save_dir=kwargs["save_dir"])

            chunk_list.append(meta)

        return {
            "audio_chunks": chunk_list,
            "total_chunk_duration": total_chunk_duration,
            # Info.length is in frames; convert to seconds for consistency
            "total_audio_duration": audio_info.length / audio_info.sample_rate,
        }

    def save_to_file(self, audio_chunk, sr, audio_path, chunk_idx, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        original_file_name = osp.basename(audio_path)
        file_name_without_extension = osp.splitext(original_file_name)[0]
        chunk_file_name = f"{file_name_without_extension}_chunk_{chunk_idx}.wav"
        chunk_file_path = osp.join(save_dir, chunk_file_name)

        audio_chunk.export(chunk_file_path, format="wav")

        return chunk_file_path

    def __call__(self, audio_path: str = None, **kwargs) -> Any:
        assert audio_path is not None, "audio_path is required"

        audio_chunks_info = self.chunk_by_silence(audio_path, **kwargs)

        return audio_chunks_info
--------------------------------------------------------------------------------
/modules/common.py:
--------------------------------------------------------------------------------
import torch
from uuid import uuid4

from typing import Union, Optional, Any
from functools import partial
import os
from abc import abstractmethod

from .
import audio as audio_ops 10 | from utils.helpers import exists 11 | from utils.loggers import get_logger 12 | 13 | 14 | logger = get_logger("module_log") 15 | 16 | 17 | class Base: 18 | MODEL_CHOICES = {} 19 | 20 | def __init__( 21 | self, 22 | model_choice: str, 23 | sampling_rate: int = 16000, 24 | padding: Union[bool, str] = True, 25 | max_length: Optional[int] = None, 26 | pad_to_multiple_of: Optional[int] = None, 27 | max_audio_len: int = 5, 28 | **kwargs, 29 | ) -> None: 30 | self.model_choice = model_choice.lower() 31 | assert ( 32 | self.model_choice in self.MODEL_CHOICES 33 | ), f"Unrecognized model choice {self.model_choice}" 34 | model = self.MODEL_CHOICES[self.model_choice] 35 | if isinstance(model, dict): 36 | self.model = {} 37 | for key, value in model.items(): 38 | if key in ["target"]: 39 | continue 40 | self.model[key] = value(**kwargs) 41 | elif isinstance(model, partial): 42 | self.model = model(**kwargs) 43 | else: 44 | raise NotImplementedError("Not sure how to handle this model choice") 45 | 46 | self.sampling_rate = sampling_rate 47 | self.padding = padding 48 | self.max_length = max_length 49 | self.pad_to_multiple_of = pad_to_multiple_of 50 | self.max_audio_len = max_audio_len 51 | 52 | self.__post__init__() 53 | 54 | def __post__init__(self): 55 | for key, value in self.MODEL_CHOICES.items(): 56 | if ( 57 | isinstance(value, dict) 58 | and "target" in value 59 | and isinstance(value["target"], str) 60 | ): 61 | self.MODEL_CHOICES[key]["target"] = getattr(self, value["target"]) 62 | 63 | @abstractmethod 64 | def predict(self, **kwargs): 65 | self.model(**kwargs) 66 | 67 | def __call__( 68 | self, audio_path: str = None, audio: torch.Tensor = None, **kwargs 69 | ) -> Any: 70 | assert exists(audio_path) or exists( 71 | audio 72 | ), "Either audio_path or audio tensor is required" 73 | if isinstance(self.model, dict): 74 | prediction = self.MODEL_CHOICES[self.model_choice]["target"]( 75 | audio_path=audio_path, audio=audio, **kwargs 76 | ) 77 | else: 78 | prediction = self.predict(audio_path=audio_path, audio=audio, **kwargs) 79 | return prediction 80 | 81 | def save_to_file(self, audio, sr, save_dir, start_dur=None, stop_dur=None): 82 | # Handling audio with more than 2 dimensions 83 | if audio.ndim > 2: 84 | print(f"Warning: Audio has {audio.ndim} dimensions, averaging over channels for simplicity.") 85 | audio = torch.mean(audio, dim=-1) 86 | 87 | if exists(start_dur): 88 | start_sample = round(start_dur * sr) 89 | audio = audio[start_sample:] 90 | 91 | if exists(stop_dur): 92 | stop_sample = round(stop_dur * sr) 93 | audio = audio[:stop_sample] 94 | 95 | if not os.path.exists(save_dir): 96 | os.makedirs(save_dir) 97 | 98 | if audio.ndim == 1: 99 | audio = audio.unsqueeze(0) 100 | 101 | save_path = ( 102 | os.path.join(save_dir, f"{str(uuid4())}.wav") 103 | if not os.path.splitext(save_dir)[-1] 104 | else save_dir 105 | ) 106 | audio_ops.save_audio(wav=audio, path=save_path, sr=sr) 107 | return save_path 108 | -------------------------------------------------------------------------------- /modules/demucs_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import argparse 8 | import sys 9 | from pathlib import Path 10 | import subprocess 11 | 12 | from dora.log import fatal 13 | import torch as th 14 | import torchaudio as ta 15 | 16 | from demucs.apply import apply_model, BagOfModels 17 | from demucs.audio import AudioFile, convert_audio, save_audio 18 | from demucs.pretrained import get_model_from_args, add_model_flags, ModelLoadingError 19 | 20 | 21 | def load_track(track, audio_channels, samplerate): 22 | errors = {} 23 | wav = None 24 | 25 | try: 26 | wav = AudioFile(track).read( 27 | streams=0, samplerate=samplerate, channels=audio_channels 28 | ) 29 | except FileNotFoundError: 30 | errors["ffmpeg"] = "FFmpeg is not installed." 31 | except subprocess.CalledProcessError: 32 | errors["ffmpeg"] = "FFmpeg could not read the file." 33 | 34 | if wav is None: 35 | try: 36 | wav, sr = ta.load(str(track)) 37 | except RuntimeError as err: 38 | errors["torchaudio"] = err.args[0] 39 | else: 40 | wav = convert_audio(wav, sr, samplerate, audio_channels) 41 | 42 | if wav is None: 43 | print( 44 | f"Could not load file {track}. " "Maybe it is not a supported file format? " 45 | ) 46 | for backend, error in errors.items(): 47 | print( 48 | f"When trying to load using {backend}, got the following error: {error}" 49 | ) 50 | sys.exit(1) 51 | return wav 52 | 53 | 54 | def get_parser(): 55 | parser = argparse.ArgumentParser( 56 | "demucs.separate", description="Separate the sources for the given tracks" 57 | ) 58 | parser.add_argument( 59 | "tracks", nargs="+", type=Path, default=[], help="Path to tracks" 60 | ) 61 | add_model_flags(parser) 62 | parser.add_argument("-v", "--verbose", action="store_true") 63 | parser.add_argument( 64 | "-o", 65 | "--out", 66 | type=Path, 67 | default=Path("separated"), 68 | help="Folder where to put extracted tracks. A subfolder " 69 | "with the model name will be created.", 70 | ) 71 | parser.add_argument( 72 | "--filename", 73 | default="{track}/{stem}.{ext}", 74 | help="Set the name of output file. \n" 75 | 'Use "{track}", "{trackext}", "{stem}", "{ext}" to use ' 76 | "variables of track name without extension, track extension, " 77 | "stem name and default output file extension. \n" 78 | 'Default is "{track}/{stem}.{ext}".', 79 | ) 80 | parser.add_argument( 81 | "-d", 82 | "--device", 83 | default="cuda" if th.cuda.is_available() else "cpu", 84 | help="Device to use, default is cuda if available else cpu", 85 | ) 86 | parser.add_argument( 87 | "--shifts", 88 | default=1, 89 | type=int, 90 | help="Number of random shifts for equivariant stabilization." 91 | "Increase separation time but improves quality for Demucs. 10 was used " 92 | "in the original paper.", 93 | ) 94 | parser.add_argument( 95 | "--overlap", default=0.25, type=float, help="Overlap between the splits." 96 | ) 97 | split_group = parser.add_mutually_exclusive_group() 98 | split_group.add_argument( 99 | "--no-split", 100 | action="store_false", 101 | dest="split", 102 | default=True, 103 | help="Doesn't split audio in chunks. " "This can use large amounts of memory.", 104 | ) 105 | split_group.add_argument( 106 | "--segment", 107 | type=int, 108 | help="Set split size of each chunk. " 109 | "This can help save memory of graphic card. ", 110 | ) 111 | parser.add_argument( 112 | "--two-stems", 113 | dest="stem", 114 | metavar="STEM", 115 | help="Only separate audio into {STEM} and no_{STEM}. 
", 116 | ) 117 | group = parser.add_mutually_exclusive_group() 118 | group.add_argument( 119 | "--int24", action="store_true", help="Save wav output as 24 bits wav." 120 | ) 121 | group.add_argument( 122 | "--float32", action="store_true", help="Save wav output as float32 (2x bigger)." 123 | ) 124 | parser.add_argument( 125 | "--clip-mode", 126 | default="rescale", 127 | choices=["rescale", "clamp"], 128 | help="Strategy for avoiding clipping: rescaling entire signal " 129 | "if necessary (rescale) or hard clipping (clamp).", 130 | ) 131 | format_group = parser.add_mutually_exclusive_group() 132 | format_group.add_argument( 133 | "--flac", action="store_true", help="Convert the output wavs to flac." 134 | ) 135 | format_group.add_argument( 136 | "--mp3", action="store_true", help="Convert the output wavs to mp3." 137 | ) 138 | parser.add_argument( 139 | "--mp3-bitrate", default=320, type=int, help="Bitrate of converted mp3." 140 | ) 141 | parser.add_argument( 142 | "--mp3-preset", 143 | choices=range(2, 8), 144 | type=int, 145 | default=2, 146 | help="Encoder preset of MP3, 2 for highest quality, 7 for " 147 | "fastest speed. Default is 2", 148 | ) 149 | parser.add_argument( 150 | "-j", 151 | "--jobs", 152 | default=0, 153 | type=int, 154 | help="Number of jobs. This can increase memory usage but will " 155 | "be much faster when multiple cores are available.", 156 | ) 157 | 158 | return parser 159 | 160 | 161 | def load_demucs_model(args): 162 | try: 163 | model = get_model_from_args(args) 164 | except ModelLoadingError as error: 165 | fatal(error.args[0]) 166 | 167 | if isinstance(model, BagOfModels): 168 | print( 169 | f"Selected model is a bag of {len(model.models)} models. " 170 | "You will see that many progress bars per track." 171 | ) 172 | 173 | model.cpu() 174 | model.eval() 175 | return model 176 | 177 | 178 | def main(opts=None): 179 | parser = get_parser() 180 | args = parser.parse_args(opts) 181 | 182 | model = load_demucs_model(args) 183 | 184 | if args.stem is not None and args.stem not in model.sources: 185 | fatal( 186 | 'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format( 187 | stem=args.stem, sources=", ".join(model.sources) 188 | ) 189 | ) 190 | out = args.out / args.name 191 | out.mkdir(parents=True, exist_ok=True) 192 | print(f"Separated tracks will be stored in {out.resolve()}") 193 | for track in args.tracks: 194 | if not track.exists(): 195 | print( 196 | f"File {track} does not exist. 
If the path contains spaces, " 197 | 'please try again after surrounding the entire path with quotes "".', 198 | file=sys.stderr, 199 | ) 200 | continue 201 | print(f"Separating track {track}") 202 | wav = load_track(track, model.audio_channels, model.samplerate) 203 | 204 | ref = wav.mean(0) 205 | wav -= ref.mean() 206 | wav /= ref.std() 207 | sources = apply_model( 208 | model, 209 | wav[None], 210 | device=args.device, 211 | shifts=args.shifts, 212 | split=args.split, 213 | overlap=args.overlap, 214 | progress=True, 215 | num_workers=args.jobs, 216 | segment=args.segment, 217 | )[0] 218 | sources *= ref.std() 219 | sources += ref.mean() 220 | 221 | if args.mp3: 222 | ext = "mp3" 223 | elif args.flac: 224 | ext = "flac" 225 | else: 226 | ext = "wav" 227 | kwargs = { 228 | "samplerate": model.samplerate, 229 | "bitrate": args.mp3_bitrate, 230 | "preset": args.mp3_preset, 231 | "clip": args.clip_mode, 232 | "as_float": args.float32, 233 | "bits_per_sample": 24 if args.int24 else 16, 234 | } 235 | if args.stem is None: 236 | for source, name in zip(sources, model.sources): 237 | stem = out / args.filename.format( 238 | track=track.name.rsplit(".", 1)[0], 239 | trackext=track.name.rsplit(".", 1)[-1], 240 | stem=name, 241 | ext=ext, 242 | ) 243 | stem.parent.mkdir(parents=True, exist_ok=True) 244 | save_audio(source, str(stem), **kwargs) 245 | else: 246 | sources = list(sources) 247 | stem = out / args.filename.format( 248 | track=track.name.rsplit(".", 1)[0], 249 | trackext=track.name.rsplit(".", 1)[-1], 250 | stem=args.stem, 251 | ext=ext, 252 | ) 253 | stem.parent.mkdir(parents=True, exist_ok=True) 254 | save_audio(sources.pop(model.sources.index(args.stem)), str(stem), **kwargs) 255 | # Warning : after poping the stem, selected stem is no longer in the list 'sources' 256 | other_stem = th.zeros_like(sources[0]) 257 | for i in sources: 258 | other_stem += i 259 | stem = out / args.filename.format( 260 | track=track.name.rsplit(".", 1)[0], 261 | trackext=track.name.rsplit(".", 1)[-1], 262 | stem="no_" + args.stem, 263 | ext=ext, 264 | ) 265 | stem.parent.mkdir(parents=True, exist_ok=True) 266 | save_audio(other_stem, str(stem), **kwargs) 267 | -------------------------------------------------------------------------------- /modules/denoise_audio.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Union, Optional 2 | import torch 3 | import argparse 4 | from os import path as osp 5 | import os 6 | from functools import partial 7 | from denoiser.audio import Audioset 8 | 9 | from .common import Base 10 | from . 
import audio as audio_ops
from .denoiser_utils import get_model
from utils.helpers import exists
from config import settings

cache_dir = osp.join(settings.CACHE_DIR, "weights", "enhancement")


class DenoiseAudio(Base):
    MODEL_CHOICES = {
        "meta_denoiser_master64": {
            "model": partial(
                get_model,
                args=argparse.Namespace(
                    **{
                        "model_path": None,
                        "hub_dir": osp.join(cache_dir, "fair-denoiser-master64"),
                        "master64": True,
                        "dns64": False,
                        "valentini_nc": False,
                        "dns48": False,
                    }
                ),
            ),
            "target": "enhance_with_denoiser",
        },
        "meta_denoiser_dns64": {
            "model": partial(
                get_model,
                args=argparse.Namespace(
                    **{
                        "model_path": None,
                        "hub_dir": osp.join(cache_dir, "meta-denoiser-dns64"),
                        "master64": False,
                        "dns64": True,
                        "valentini_nc": False,
                        "dns48": False,
                    }
                ),
            ),
            "target": "enhance_with_denoiser",
        },
        "meta_denoiser_valentini": {
            "model": partial(
                get_model,
                args=argparse.Namespace(
                    **{
                        "model_path": None,
                        "hub_dir": osp.join(cache_dir, "meta-denoiser-valentini"),
                        "master64": False,
                        "dns64": False,
                        "dns48": False,
                        # get_model checks args.valentini_nc (see denoiser_utils)
                        "valentini_nc": True,
                    }
                ),
            ),
            "target": "enhance_with_denoiser",
        },
        "meta_denoiser_dns48": {
            "model": partial(
                get_model,
                args=argparse.Namespace(
                    **{
                        "model_path": None,
                        "hub_dir": osp.join(cache_dir, "meta-denoiser-dns48"),
                        "master64": False,
                        "dns64": False,
                        "valentini_nc": False,
                        "dns48": True,
                    }
                ),
            ),
            "target": "enhance_with_denoiser",
        },
    }

    def __init__(
        self,
        model_choice: str,
        sampling_rate: int = 16000,
        padding: Union[bool, str] = True,
        max_length: Union[int, None] = None,
        pad_to_multiple_of: Union[int, None] = None,
        max_audio_len: int = 5,
        dry=0,
        **kwargs,
    ) -> None:
        super().__init__(
            model_choice,
            sampling_rate,
            padding,
            max_length,
            pad_to_multiple_of,
            max_audio_len,
            **kwargs,
        )
        self.dry = dry

    def save_to_file(self, audio, sr, save_dir, audio_path, start_dur=None, stop_dur=None):
        # Handling audio with more than 2 dimensions
        if audio.ndim > 2:
            print(f"Warning: Audio has {audio.ndim} dimensions, averaging over channels for simplicity.")
            audio = torch.mean(audio, dim=-1)

        if exists(start_dur):
            start_sample = round(start_dur * sr)
            audio = audio[start_sample:]

        if exists(stop_dur):
            stop_sample = round(stop_dur * sr)
            audio = audio[:stop_sample]

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        if audio.ndim == 1:
            audio = audio.unsqueeze(0)
        original_file_name = osp.basename(audio_path)
        save_path = (
            os.path.join(save_dir, original_file_name)
            if not os.path.splitext(save_dir)[-1]
            else save_dir
        )
        audio_ops.save_audio(wav=audio, path=save_path, sr=sr)
        return save_path
    def enhance_with_denoiser(self, audio_path, save_to_file=False, **kwargs):
        metadata = [(audio_path, audio_ops.get_audio_info(audio_path))]
        dataset = Audioset(
            metadata,
            with_path=False,
            sample_rate=self.model["model"].sample_rate,
            channels=self.model["model"].chin,
            convert=True,
        )
        signal = dataset[0]
        with torch.no_grad():
            # get_model returns a CUDA model, so the input follows it to GPU
            estimate = self.model["model"](signal.cuda())
            # estimate = (1 - self.dry) * estimate + self.dry * signal
        if save_to_file:
            save_dir = kwargs.get("save_dir")
            enhanced_audio = estimate.detach().cpu().squeeze(0)
            denoised_path = self.save_to_file(
                enhanced_audio,
                sr=self.model["model"].sample_rate,
                save_dir=save_dir,
                audio_path=audio_path,
            )
            return denoised_path
        # Without save_to_file, hand back the denoised tensor directly
        return estimate.detach().cpu().squeeze(0)

    def predict(self, audio_path, **kwargs) -> torch.Tensor:
        if hasattr(self.model, "enhance_file"):
            enhanced_audio = self.model.enhance_file(audio_path)
        else:
            raise NotImplementedError(
                f"{self.model_choice} doesn't have any supported methods"
            )
        if enhanced_audio.ndim == 3:
            enhanced_audio = enhanced_audio.squeeze(-1)
        return enhanced_audio

    def __call__(
        self,
        audio_path: str = None,
        audio: torch.Tensor = None,
        save_to_file=False,
        offload=False,
        **kwargs,
    ) -> Any:
        assert exists(audio_path) or exists(
            audio
        ), "Either audio_path or audio tensor is required"

        if isinstance(self.model, dict):
            # Dict-backed choices handle saving themselves in enhance_with_denoiser
            return self.MODEL_CHOICES[self.model_choice]["target"](
                audio_path=audio_path, audio=audio, save_to_file=save_to_file, **kwargs
            )

        enhanced_audio = self.predict(audio_path=audio_path, audio=audio, **kwargs)
        if save_to_file:
            audio_info = audio_ops.get_audio_info(audio_path)
            enhanced_audio = self.save_to_file(
                enhanced_audio.detach().cpu(),
                sr=audio_info.sample_rate,
                save_dir=kwargs.get("save_dir"),
                audio_path=audio_path,
            )
        elif offload:
            enhanced_audio = enhanced_audio.detach().cpu()
        return enhanced_audio
--------------------------------------------------------------------------------
/modules/denoiser_utils.py:
--------------------------------------------------------------------------------
import torch.hub
from denoiser.demucs import Demucs
from denoiser.utils import deserialize_model

from .common import logger


ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/"
DNS_48_URL = ROOT + "dns48-11decc9d8e3f0998.th"
DNS_64_URL = ROOT + "dns64-a7761ff99a7d5bb6.th"
MASTER_64_URL = ROOT + "master64-8a5dfb4bb92753dd.th"
VALENTINI_NC = ROOT + "valentini_nc-93fc4337.th"  # Non causal Demucs on Valentini


def _demucs(pretrained, url, model_dir=None, **kwargs):
    model = Demucs(**kwargs, sample_rate=16_000)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url, model_dir=model_dir, map_location="cpu"
        )
        model.load_state_dict(state_dict)
    return model


def dns48(pretrained=True, model_dir=None):
    return _demucs(pretrained, DNS_48_URL, hidden=48, model_dir=model_dir)


def dns64(pretrained=True, model_dir=None):
    return _demucs(pretrained, DNS_64_URL, hidden=64, model_dir=model_dir)


def master64(pretrained=True, model_dir=None):
    return _demucs(pretrained, MASTER_64_URL, hidden=64, model_dir=model_dir)


def valentini_nc(pretrained=True, model_dir=None):
    return _demucs(
        pretrained,
        VALENTINI_NC,
        hidden=64,
        causal=False,
        stride=2,
        resample=2,
        model_dir=model_dir,
    )


def add_model_flags(parser):
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument("-m", "--model_path", help="Path to local trained model.")
    group.add_argument(
        "--hub_dir", default=None, help="Path to save torch/hf hub models."
    )
    group.add_argument(
        "--dns48",
        action="store_true",
        help="Use pre-trained real time H=48 model trained on DNS.",
    )
    group.add_argument(
        "--dns64",
        action="store_true",
        help="Use pre-trained real time H=64 model trained on DNS.",
    )
    group.add_argument(
        "--master64",
        action="store_true",
        help="Use pre-trained real time H=64 model trained on DNS and Valentini.",
    )
    group.add_argument(
        "--valentini_nc",
        action="store_true",
        help="Use pre-trained H=64 model trained on Valentini, non causal.",
    )


def get_model(args):
    """
    Load local model package or torchhub pre-trained model.
    """
    if args.model_path:
        logger.info("Loading model from %s", args.model_path)
        pkg = torch.load(args.model_path, "cpu")
        if "model" in pkg:
            if "best_state" in pkg:
                pkg["model"]["state"] = pkg["best_state"]
            model = deserialize_model(pkg["model"])
        else:
            model = deserialize_model(pkg)
    elif args.dns64:
        logger.info("Loading pre-trained real time H=64 model trained on DNS.")
        model = dns64(model_dir=args.hub_dir)
    elif args.master64:
        logger.info(
            "Loading pre-trained real time H=64 model trained on DNS and Valentini."
        )
        model = master64(model_dir=args.hub_dir)
    elif args.dns48:
        logger.info("Loading pre-trained real time H=48 model trained on DNS.")
        model = dns48(model_dir=args.hub_dir)
    elif args.valentini_nc:
        logger.info("Loading pre-trained H=64 model trained on Valentini.")
        model = valentini_nc(model_dir=args.hub_dir)
    else:
        logger.info("Loading pre-trained real time H=48 model trained on DNS.")
        model = dns48(model_dir=args.hub_dir)
    logger.debug(model)
    return model.cuda()
-------------------------------------------------------------------------------- /modules/lang_list.py: -------------------------------------------------------------------------------- 1 | # Language dict 2 | language_code_to_name = { 3 | "afr": "Afrikaans", 4 | "amh": "Amharic", 5 | "arb": "Modern Standard Arabic", 6 | "ary": "Moroccan Arabic", 7 | "arz": "Egyptian Arabic", 8 | "asm": "Assamese", 9 | "ast": "Asturian", 10 | "azj": "North Azerbaijani", 11 | "bel": "Belarusian", 12 | "ben": "Bengali", 13 | "bos": "Bosnian", 14 | "bul": "Bulgarian", 15 | "cat": "Catalan", 16 | "ceb": "Cebuano", 17 | "ces": "Czech", 18 | "ckb": "Central Kurdish", 19 | "cmn": "Mandarin Chinese", 20 | "cym": "Welsh", 21 | "dan": "Danish", 22 | "deu": "German", 23 | "ell": "Greek", 24 | "eng": "English", 25 | "est": "Estonian", 26 | "eus": "Basque", 27 | "fin": "Finnish", 28 | "fra": "French", 29 | "gaz": "West Central Oromo", 30 | "gle": "Irish", 31 | "glg": "Galician", 32 | "guj": "Gujarati", 33 | "heb": "Hebrew", 34 | "hin": "Hindi", 35 | "hrv": "Croatian", 36 | "hun": "Hungarian", 37 | "hye": "Armenian", 38 | "ibo": "Igbo", 39 | "ind": "Indonesian", 40 | "isl": "Icelandic", 41 | "ita": "Italian", 42 | "jav": "Javanese", 43 | "jpn": "Japanese", 44 | "kam": "Kamba", 45 | "kan": "Kannada", 46 | "kat": "Georgian", 47 | "kaz": "Kazakh", 48 | "kea": "Kabuverdianu", 49 | "khk": "Halh Mongolian", 50 | "khm": "Khmer", 51 | "kir": "Kyrgyz", 52 | "kor": "Korean", 53 | "lao": "Lao", 54 | "lit": "Lithuanian", 55 | "ltz": "Luxembourgish", 56 | "lug": "Ganda", 57 | "luo": "Luo", 58 | "lvs": "Standard Latvian", 59 | "mai": "Maithili", 60 | "mal": "Malayalam", 61 | "mar": "Marathi", 62 | "mkd": "Macedonian", 63 | "mlt": 
"Maltese", 64 | "mni": "Meitei", 65 | "mya": "Burmese", 66 | "nld": "Dutch", 67 | "nno": "Norwegian Nynorsk", 68 | "nob": "Norwegian Bokm\u00e5l", 69 | "npi": "Nepali", 70 | "nya": "Nyanja", 71 | "oci": "Occitan", 72 | "ory": "Odia", 73 | "pan": "Punjabi", 74 | "pbt": "Southern Pashto", 75 | "pes": "Western Persian", 76 | "pol": "Polish", 77 | "por": "Portuguese", 78 | "ron": "Romanian", 79 | "rus": "Russian", 80 | "slk": "Slovak", 81 | "slv": "Slovenian", 82 | "sna": "Shona", 83 | "snd": "Sindhi", 84 | "som": "Somali", 85 | "spa": "Spanish", 86 | "srp": "Serbian", 87 | "swe": "Swedish", 88 | "swh": "Swahili", 89 | "tam": "Tamil", 90 | "tel": "Telugu", 91 | "tgk": "Tajik", 92 | "tgl": "Tagalog", 93 | "tha": "Thai", 94 | "tur": "Turkish", 95 | "ukr": "Ukrainian", 96 | "urd": "Urdu", 97 | "uzn": "Northern Uzbek", 98 | "vie": "Vietnamese", 99 | "xho": "Xhosa", 100 | "yor": "Yoruba", 101 | "yue": "Cantonese", 102 | "zlm": "Colloquial Malay", 103 | "zsm": "Standard Malay", 104 | "zul": "Zulu", 105 | } 106 | LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()} 107 | 108 | # Source langs: S2ST / S2TT / ASR don't need source lang 109 | # T2TT / T2ST use this 110 | text_source_language_codes = [ 111 | "afr", 112 | "amh", 113 | "arb", 114 | "ary", 115 | "arz", 116 | "asm", 117 | "azj", 118 | "bel", 119 | "ben", 120 | "bos", 121 | "bul", 122 | "cat", 123 | "ceb", 124 | "ces", 125 | "ckb", 126 | "cmn", 127 | "cym", 128 | "dan", 129 | "deu", 130 | "ell", 131 | "eng", 132 | "est", 133 | "eus", 134 | "fin", 135 | "fra", 136 | "gaz", 137 | "gle", 138 | "glg", 139 | "guj", 140 | "heb", 141 | "hin", 142 | "hrv", 143 | "hun", 144 | "hye", 145 | "ibo", 146 | "ind", 147 | "isl", 148 | "ita", 149 | "jav", 150 | "jpn", 151 | "kan", 152 | "kat", 153 | "kaz", 154 | "khk", 155 | "khm", 156 | "kir", 157 | "kor", 158 | "lao", 159 | "lit", 160 | "lug", 161 | "luo", 162 | "lvs", 163 | "mai", 164 | "mal", 165 | "mar", 166 | "mkd", 167 | "mlt", 168 | "mni", 169 | "mya", 170 | "nld", 171 | "nno", 172 | "nob", 173 | "npi", 174 | "nya", 175 | "ory", 176 | "pan", 177 | "pbt", 178 | "pes", 179 | "pol", 180 | "por", 181 | "ron", 182 | "rus", 183 | "slk", 184 | "slv", 185 | "sna", 186 | "snd", 187 | "som", 188 | "spa", 189 | "srp", 190 | "swe", 191 | "swh", 192 | "tam", 193 | "tel", 194 | "tgk", 195 | "tgl", 196 | "tha", 197 | "tur", 198 | "ukr", 199 | "urd", 200 | "uzn", 201 | "vie", 202 | "yor", 203 | "yue", 204 | "zsm", 205 | "zul", 206 | ] 207 | TEXT_SOURCE_LANGUAGE_NAMES = sorted( 208 | [language_code_to_name[code] for code in text_source_language_codes] 209 | ) 210 | 211 | # Target langs: 212 | # S2ST / T2ST 213 | s2st_target_language_codes = [ 214 | "eng", 215 | "arb", 216 | "ben", 217 | "cat", 218 | "ces", 219 | "cmn", 220 | "cym", 221 | "dan", 222 | "deu", 223 | "est", 224 | "fin", 225 | "fra", 226 | "hin", 227 | "ind", 228 | "ita", 229 | "jpn", 230 | "kor", 231 | "mlt", 232 | "nld", 233 | "pes", 234 | "pol", 235 | "por", 236 | "ron", 237 | "rus", 238 | "slk", 239 | "spa", 240 | "swe", 241 | "swh", 242 | "tel", 243 | "tgl", 244 | "tha", 245 | "tur", 246 | "ukr", 247 | "urd", 248 | "uzn", 249 | "vie", 250 | ] 251 | S2ST_TARGET_LANGUAGE_NAMES = sorted( 252 | [language_code_to_name[code] for code in s2st_target_language_codes] 253 | ) 254 | 255 | # S2TT / ASR 256 | S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES 257 | # T2TT 258 | T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES 259 | -------------------------------------------------------------------------------- 
/modules/music_separation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import sys
4 | from dora.log import fatal
5 | from os import path as osp
6 | from functools import partial
7 | from demucs.apply import apply_model, BagOfModels
8 | from demucs.htdemucs import HTDemucs
9 | 
10 | from .common import Base
11 | from .demucs_utils import load_demucs_model, load_track
12 | from . import audio as audio_ops
13 | from utils.helpers import exists
14 | from config import settings
15 | 
16 | cache_dir = osp.join(settings.CACHE_DIR, "weights", "partition")
17 | 
18 | 
19 | class PartitionAudio(Base):
20 |     MODEL_CHOICES = {
21 |         "meta_demucs_htdemucs": {
22 |             "model": partial(
23 |                 load_demucs_model,
24 |                 args=argparse.Namespace(**{"segment": None, "name": "htdemucs"}),
25 |             ),
26 |             "target": "partition_with_demucs",
27 |         },
28 |         "meta_demucs_htdemucs_ft": {
29 |             "model": partial(
30 |                 load_demucs_model,
31 |                 args=argparse.Namespace(**{"segment": None, "name": "htdemucs_ft"}),
32 |             ),
33 |             "target": "partition_with_demucs",
34 |         },
35 |         "meta_demucs_htdemucs_6s": {
36 |             "model": partial(
37 |                 load_demucs_model,
38 |                 args=argparse.Namespace(**{"segment": None, "name": "htdemucs_6s"}),
39 |             ),
40 |             "target": "partition_with_demucs",
41 |         },
42 |         "meta_demucs_htdemucs_mmi": {
43 |             "model": partial(
44 |                 load_demucs_model,
45 |                 args=argparse.Namespace(**{"segment": None, "name": "htdemucs_mmi"}),
46 |             ),
47 |             "target": "partition_with_demucs",
48 |         },
49 |         "meta_demucs_mdx": {
50 |             "model": partial(
51 |                 load_demucs_model,
52 |                 args=argparse.Namespace(**{"segment": None, "name": "mdx"}),
53 |             ),
54 |             "target": "partition_with_demucs",
55 |         },
56 |         "meta_demucs_mdx_q": {
57 |             "model": partial(
58 |                 load_demucs_model,
59 |                 args=argparse.Namespace(**{"segment": None, "name": "mdx_q"}),
60 |             ),
61 |             "target": "partition_with_demucs",
62 |         },
63 |         "meta_demucs_mdx_extra": {
64 |             "model": partial(
65 |                 load_demucs_model,
66 |                 args=argparse.Namespace(**{"segment": None, "name": "mdx_extra"}),
67 |             ),
68 |             "target": "partition_with_demucs",
69 |         },
70 |         "meta_demucs_mdx_extra_q": {
71 |             "model": partial(
72 |                 load_demucs_model,
73 |                 args=argparse.Namespace(**{"segment": None, "name": "mdx_extra_q"}),
74 |             ),
75 |             "target": "partition_with_demucs",
76 |         },
77 |     }
78 | 
79 |     def partition_with_demucs(
80 |         self, audio_path, save_to_file=False, save_partitions=None, **kwargs
81 |     ):
82 |         stem = "vocals"
83 |         ext = kwargs.get("ext", "wav")
84 |         float32 = False  # output as float32 wavs, unused if 'mp3' is True.
85 |         int24 = False
86 |         segment = kwargs.get("segment", 15)
87 | 
88 |         # Normalise save_partitions to a list so the `len(partitions) in
89 |         # save_partitions` membership checks below work for scalar inputs too.
90 |         if exists(save_partitions) and not isinstance(save_partitions, (list, tuple)):
91 |             save_partitions = [int(save_partitions)]
92 |         if save_to_file:
93 |             save_dir = kwargs.get("save_dir") or osp.join(
94 |                 settings.CACHE_DIR,
95 |                 "tmp",
96 |                 "partitions",
97 |                 f"{osp.splitext(osp.split(audio_path)[-1])[0]}",
98 |             )
99 | 
100 |         max_allowed_segment = float("inf")
101 |         if isinstance(self.model["model"], HTDemucs):
102 |             max_allowed_segment = float(self.model["model"].segment)
103 |         elif isinstance(self.model["model"], BagOfModels):
104 |             max_allowed_segment = self.model["model"].max_allowed_segment
105 |         if segment is not None and segment > max_allowed_segment:
106 |             fatal(
107 |                 "Cannot use a Transformer model with a longer segment "
108 |                 f"than it was trained for. Maximum segment is: {max_allowed_segment}"
109 |             )
110 | 
111 |         if stem is not None and stem not in self.model["model"].sources:
112 |             fatal(
113 |                 'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format(
114 |                     stem=stem, sources=", ".join(self.model["model"].sources)
115 |                 )
116 |             )
117 | 
118 |         if not osp.isfile(audio_path):
119 |             print(
120 |                 f"File {audio_path} does not exist. If the path contains spaces, "
121 |                 'please try again after surrounding the entire path with quotes "".',
122 |                 file=sys.stderr,
123 |             )
124 |             raise FileNotFoundError(audio_path)
125 | 
126 |         print(f"Separating track {audio_path}")
127 |         wav = load_track(
128 |             audio_path,
129 |             self.model["model"].audio_channels,
130 |             self.model["model"].samplerate,
131 |         )
132 | 
133 |         ref = wav.mean(0)
134 |         wav -= ref.mean()
135 |         wav /= ref.std()
136 |         sources = apply_model(
137 |             self.model["model"],
138 |             wav[None],
139 |             device="cpu",
140 |             shifts=kwargs.get("shifts", 1),
141 |             split=kwargs.get("split", False),
142 |             overlap=kwargs.get("overlap", 0.25),
143 |             progress=True,
144 |             num_workers=0,
145 |             segment=segment,
146 |         )[0]
147 |         sources *= ref.std()
148 |         sources += ref.mean()
149 | 
150 |         kwargs = {
151 |             "samplerate": self.model["model"].samplerate,
152 |             "bitrate": kwargs.get("mp3_bitrate"),
153 |             "preset": kwargs.get("mp3_preset"),
154 |             "clip": kwargs.get("clip_mode", "rescale"),
155 |             "as_float": float32,
156 |             "bits_per_sample": 24 if int24 else 16,
157 |         }
158 |         if stem is None:
159 |             partitions = []
160 |             for source, name in zip(sources, self.model["model"].sources):
161 |                 if save_to_file and (
162 |                     not exists(save_partitions) or len(partitions) in save_partitions
163 |                 ):
164 |                     save_path = self.save_to_file(
165 |                         source,
166 |                         save_dir=osp.join(save_dir, f"partition_{len(partitions)}.wav"),
167 |                     )
168 |                     partitions.append(save_path)
169 |                 else:
170 |                     partitions.append(source)
171 |         else:
172 |             sources = list(sources)
173 |             partitions = []
174 |             for i in range(0, 2):
175 |                 if not i:
176 |                     source = sources.pop(self.model["model"].sources.index(stem))
177 |                 else:
178 |                     # Warning: after popping the stem, the selected stem is no
179 |                     # longer in the list `sources`.
180 |                     source = torch.zeros_like(sources[0])
181 |                     for src in sources:
182 |                         source += src
183 | 
184 |                 if save_to_file and (
185 |                     not exists(save_partitions) or len(partitions) in save_partitions
186 |                 ):
187 |                     save_path = self.save_to_file(
188 |                         source,
189 |                         save_dir=save_dir,
190 |                     )
191 |                     partitions.append(save_path)
192 |                 else:
193 |                     partitions.append(source)
194 | 
195 |         return torch.vstack(partitions) if not save_to_file else partitions
196 | 
197 |     def predict(self, audio_path, **kwargs) -> torch.Tensor:
198 |         if hasattr(self.model, "separate_file"):
199 |             partitions = self.model.separate_file(audio_path)
200 |         else:
201 |             raise NotImplementedError(
202 |                 f"{self.model_choice} doesn't have any supported methods"
203 |             )
204 |         return partitions
--------------------------------------------------------------------------------
/modules/transcribe.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import whisper
3 | from transformers import pipeline
4 | from os import path as osp
5 | from functools import partial
6 | from .lang_list import LANGUAGE_NAME_TO_CODE
7 | from .common import Base
8 | from . import audio as audio_ops
9 | from utils.helpers import exists
10 | from config import settings
11 | 
12 | cache_dir = osp.join(settings.CACHE_DIR, "weights", "transcription")
13 | 
14 | 
15 | class TranscribeAudio(Base):
16 |     MODEL_CHOICES = {
17 |         "openai_whisper_base": partial(
18 |             whisper.load_model,
19 |             name="base",
20 |             download_root=osp.join(cache_dir, "openai-whisper-base"),
21 |         ),
22 |         "openai_whisper_medium": partial(
23 |             whisper.load_model,
24 |             name="medium",
25 |             download_root=osp.join(cache_dir, "openai-whisper-medium"),
26 |         ),
27 |         "openai_whisper_large": partial(
28 |             whisper.load_model,
29 |             name="large",
30 |             download_root=osp.join(cache_dir, "openai-whisper-large"),
31 |         ),
32 |     }
33 | 
34 |     def predict(
35 |         self, audio_path: str = None, audio: torch.Tensor = None, **kwargs
36 |     ) -> str:
37 |         if exists(audio_path):
38 |             if isinstance(self.model, whisper.Whisper):
39 |                 transcription = self.model.transcribe(audio_path)["text"]
40 |             else:
41 |                 raise NotImplementedError(
42 |                     f"{self.model_choice} doesn't have any supported methods"
43 |                 )
44 |         else:
45 |             transcription = self.model.transcribe(audio)["text"]
46 |         return transcription
47 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | environs
3 | bson
4 | shortuuid
5 | python-json-logger
6 | pydantic
7 | requests
8 | omegaconf
9 | transformers
10 | torchvision
11 | denoiser
12 | webrtcvad
13 | sphfile
14 | pytube
15 | wget
16 | voicefixer
17 | audiosr==0.0.5
18 | librosa
19 | git+https://github.com/openai/whisper.git
20 | git+https://github.com/facebookresearch/demucs#egg=demucs
21 | git+https://github.com/pyannote/pyannote-audio.git
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import shortuuid
3 | import pandas as pd
4 | 
5 | 
6 | def normalize_json(obj, parent_key=None):
7 |     updated_data = {}
8 | 
9 |     def _normalize(obj, parent_key):
10 |         nonlocal updated_data
11 |         for key, value in obj.items():
12 |             new_key = f"{parent_key}.{key}" if parent_key else key
13 | 
14 |             if isinstance(value, dict):
15 |                 _normalize(value, new_key)
16 |             elif isinstance(value, list):
17 |                 if not value or not isinstance(value[0], dict):
18 |                     updated_data[new_key] = json.dumps(value)
19 |                 else:
20 |                     _normalize(value[0], new_key)
21 |             else:
22 |                 updated_data[new_key] = value
23 | 
24 |     if isinstance(obj, list):
25 |         for item in obj:
26 |             _normalize(item, parent_key)
27 |     else:
28 |         _normalize(obj, parent_key)
29 | 
30 |     if isinstance(updated_data, dict):
31 |         records, meta = [], []
32 |         for k, v in updated_data.items():
33 |             if isinstance(v, list):
34 |                 records.append(k)
35 |             else:
36 |                 meta.append(k)
37 | 
38 |         dfs = [
39 |             pd.json_normalize(
40 |                 updated_data,
41 |                 record_path=record,
42 |                 record_prefix=f"{record}.",
43 |                 meta=meta,
44 |             )
45 |             for record in records
46 |         ]
47 | 
48 |         # pd.concat raises on an empty list; if no record paths were found,
49 |         # fall back to normalising the flattened dict directly.
50 |         df = (
51 |             pd.concat(dfs, ignore_index=True)
52 |             if dfs
53 |             else pd.json_normalize(updated_data)
54 |         )
55 |         index = [str(shortuuid.uuid()) for _ in range(df.shape[0])]
56 |         return df.rename(index=dict(zip(list(df.index), index)))
57 | 
58 |     return updated_data
--------------------------------------------------------------------------------
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | from hashlib import md5, sha1, sha256
2 | import json
3 | import random
4 | import time
5 | import importlib
6 | from os.path import isfile
7 | from pathlib import Path
8 | from types import SimpleNamespace
9 | from uuid import uuid4
10 | 
11 | from bson import ObjectId
12 | from pydantic import BaseModel
13 | 
14 | EXCLUDED_SPECIAL_FIELDS = "exclude_special_fields"
15 | 
16 | 
17 | class NestedNamespace(SimpleNamespace):
18 |     def __init__(self, dictionary):
19 |         for key, value in dictionary.items():
20 |             if isinstance(value, dict):
21 |                 value = NestedNamespace(value)
22 |             elif isinstance(value, list):
23 |                 # Only dicts can become namespaces; other items (including
24 |                 # nested lists, which previously crashed here) pass through.
25 |                 value = tuple(
26 |                     NestedNamespace(val) if isinstance(val, dict) else val
27 |                     for val in value
28 |                 )
29 |             setattr(self, key, value)
30 | 
31 | 
32 | class SpecialExclusionBaseModel(BaseModel):
33 |     _special_exclusions: set[str]
34 | 
35 |     def dict(self, **kwargs):
36 |         # Pop `exclude` so it is not passed to super().dict() twice, and
37 |         # replace the sentinel entry with the model's special exclusions.
38 |         exclude = set(kwargs.pop("exclude", None) or set())
39 |         if EXCLUDED_SPECIAL_FIELDS in exclude:
40 |             exclude = (exclude - {EXCLUDED_SPECIAL_FIELDS}) | set(
41 |                 self._special_exclusions
42 |             )
43 |         return super().dict(exclude=exclude, **kwargs)
44 | 
45 | 
46 | def default(value, d):
47 |     return value if value is not None else (d() if callable(d) else d)
48 | 
49 | 
50 | def exists(value):
51 |     return value is not None
52 | 
53 | 
54 | def get_obj_from_str(string, reload=False):
55 |     module, cls = string.rsplit(".", 1)
56 |     if reload:
57 |         module_imp = importlib.import_module(module)
58 |         importlib.reload(module_imp)
59 |     return getattr(importlib.import_module(module), cls)
60 | 
61 | 
62 | def freeze(o):
63 |     if isinstance(o, dict):
64 |         return frozenset((k, freeze(v)) for k, v in o.items())
65 |     if isinstance(o, list):
66 |         return tuple(freeze(v) for v in o)
67 |     return o
68 | 
69 | 
70 | def make_hash(data, algorithm="sha1", serializer=None, sort=False):
71 |     algos = {"sha1": sha1, "sha256": sha256, "md5": md5}
72 |     data = default(data, "")
73 |     data = sorted(data) if sort else data
74 |     if serializer == "json":
75 |         data = json.dumps(data).encode("utf-8")
76 |     elif serializer == "freeze":
77 |         # freeze() yields a hashable structure, not bytes; serialise its repr
78 |         # before digesting (the original fed a frozenset straight to hashlib).
79 |         data = repr(freeze(data)).encode("utf-8")
80 |     else:
81 |         data = data.encode("utf-8")
82 |     return algos[algorithm.lower()](data).hexdigest()
83 | 
84 | 
85 | def generate_oid(serialize=True):
86 |     return str(ObjectId()) if serialize else ObjectId()
87 | 
--------------------------------------------------------------------------------
/utils/io.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from pathlib import Path
3 | from tqdm import tqdm
4 | from omegaconf import OmegaConf, DictConfig, ListConfig
5 | import json
6 | import requests
7 | from .validators import uri_validator
8 | 
9 | 
10 | def load_configs(configs):
11 |     if isinstance(configs, (DictConfig, ListConfig)):
12 |         return configs
13 | 
14 |     if isinstance(configs, list):
15 |         config = (
16 |             OmegaConf.merge(*[OmegaConf.load(conf) for conf in configs])
17 |             if osp.exists(configs[0])
18 |             else OmegaConf.create(configs)
19 |         )
20 |     else:
21 |         config = OmegaConf.load(configs) if osp.exists(configs) else OmegaConf.create(configs)
22 | 
23 |     return config
24 | 
25 | 
26 | def merge_configs(*confs):
27 |     return OmegaConf.merge(*[load_configs(conf) for conf in confs])
28 | 
29 | 
30 | def load_metadata(metadata_path):
31 |     if not osp.isfile(metadata_path):
32 |         metadata_path = osp.join(metadata_path, "metadata.jsonl")
33 | 
34 |     if not osp.isfile(metadata_path):
35 |         raise FileNotFoundError(f"'metadata.jsonl' file missing in '{metadata_path}'")
36 | 
37 |     with open(metadata_path, "r") as file:
38 |         ext = osp.splitext(metadata_path)[-1]
39 |         return (
40 |             json.load(file)
41 |             if ext == ".json"
42 |             else [json.loads(line) for line in file.read().splitlines()]
43 |         )
44 | 
45 | 
46 | def save_metadata(metadata, save_path):
47 |     parent_dir = Path(osp.dirname(save_path))
48 |     parent_dir.mkdir(exist_ok=True, parents=True)
49 | 
50 |     with open(save_path, "w") as f:
51 |         if save_path.endswith(".jsonl"):
52 |             f.writelines(f"{json.dumps(entry)}\n" for entry in metadata)
53 |         elif save_path.endswith(".json"):
54 |             json.dump(metadata, f)
55 |         elif save_path.endswith(".txt"):
56 |             f.write("\n".join(metadata) if isinstance(metadata, list) else metadata)
57 |         else:
58 |             raise NotImplementedError(f"{osp.splitext(save_path)[-1]} is not supported...")
59 | 
60 | 
61 | def download_file_from_url(url, save_path, show_progress=True):
62 |     if not uri_validator(url):
63 |         raise ValueError(f"'{url}' doesn't seem to be a valid url")
64 | 
65 |     parent_dir = Path(osp.dirname(save_path))
66 |     parent_dir.mkdir(exist_ok=True, parents=True)
67 | 
68 |     # Issue the request before opening the file so a failed download doesn't
69 |     # leave an empty file behind; stream in chunks rather than byte-by-byte.
70 |     response = requests.get(url, stream=True, timeout=20)
71 |     if response.status_code != 200:
72 |         raise AssertionError(
73 |             f"Couldn't download the file, the url returned status code: {response.status_code}"
74 |         )
75 | 
76 |     with open(save_path, "wb") as handle:
77 |         for data in tqdm(response.iter_content(chunk_size=8192), disable=not show_progress):
78 |             handle.write(data)
--------------------------------------------------------------------------------
/utils/loggers.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timezone
2 | import logging
3 | import logging.config
4 | from pathlib import Path
5 | from config import settings
6 | 
7 | # Configure external library logging
8 | external_logs = ["requests", "urllib3"]
9 | for log in external_logs:
10 |     logging.getLogger(log).setLevel(logging.CRITICAL)
11 | 
12 | log_levels = {
13 |     "DEBUG": logging.DEBUG,
14 |     "WARNING": logging.WARNING,
15 |     "INFO": logging.INFO,
16 |     "ERROR": logging.ERROR,
17 |     "CRITICAL": logging.CRITICAL,
18 | }
19 | 
20 | 
21 | class ColorFormatter(logging.Formatter):
22 |     color_map = {
23 |         logging.DEBUG: "\x1b[1;30m",
24 |         logging.INFO: "\x1b[0;37m",
25 |         logging.WARNING: "\x1b[1;33m",
26 |         logging.ERROR: "\x1b[1;31m",
27 |         logging.CRITICAL: "\x1b[1;35m",
28 |     }
29 |     reset = "\x1b[0m"
30 |     # `%(levelname)-5.5s` pads/truncates the level name; the original format
31 |     # string had the width spec outside the conversion (`%(levelname)s-5.5s`).
32 |     _format = "%(asctime)s - [%(threadName)-12.12s] [%(levelname)-5.5s] - %(name)s - (%(filename)s).%(funcName)s(%(lineno)d) - %(message)s"
33 | 
34 |     def format(self, record):
35 |         formatter = logging.Formatter(
36 |             f"{self.color_map.get(record.levelno)}{self._format}{self.reset}"
37 |         )
38 |         return formatter.format(record)
39 | 
40 | 
41 | def create_file_logger(filename: str, log_level: str = settings.LOG_LEVEL, log_dir: str = settings.LOG_DIR):
42 |     # Use an explicit UTC timestamp; the previous naive utcnow().astimezone()
43 |     # mislabelled the UTC wall-clock time as local time.
44 |     log_filename = f"{filename}_{datetime.now(timezone.utc).strftime('%Y-%m-%dT%H-%M-%S')}"
45 |     Path(log_dir).mkdir(exist_ok=True, parents=True)
46 | 
47 |     logger = logging.getLogger(log_filename)
48 |     logger.setLevel(log_levels.get(log_level, logging.DEBUG))
49 | 
50 |     file_handler = logging.FileHandler(Path(log_dir) / f"{log_filename}.log")
51 |     file_handler.setFormatter(ColorFormatter())
52 |     logger.addHandler(file_handler)
53 |     logger.debug(f"{log_filename} logger initialized!")
54 | 
55 |     return logger
56 | 
57 | 
58 | def init_loggers(log_dir: str = settings.LOG_DIR):
59 |     Path(log_dir).mkdir(exist_ok=True, parents=True)
60 |     logging.config.fileConfig(
61 |         Path(__file__).resolve().parents[1] / "config" / "logging.ini",
62 |         disable_existing_loggers=False,
defaults={"logdir": log_dir}, 58 | ) 59 | 60 | 61 | def get_logger(log_name: str, log_level: str = settings.LOG_LEVEL): 62 | logger = logging.getLogger(log_name) 63 | logger.setLevel(log_levels.get(log_level, logging.DEBUG)) 64 | return logger 65 | -------------------------------------------------------------------------------- /utils/validators.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlparse 2 | 3 | 4 | def uri_validator(url): 5 | try: 6 | result = urlparse(url) 7 | return all([result.scheme, result.netloc]) 8 | except: 9 | return False 10 | -------------------------------------------------------------------------------- /workers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LAION-AI/Text-to-speech/4d69b12975b3a74f37b11c93edf83d55e133b649/workers/__init__.py -------------------------------------------------------------------------------- /workers/pipeline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | dir_path = os.path.dirname(os.path.realpath(__file__)) 5 | sys.path.append(os.path.abspath(os.path.join(dir_path, os.pardir))) 6 | 7 | import argparse 8 | from omegaconf import OmegaConf 9 | 10 | from utils.io import load_configs 11 | from utils.helpers import get_obj_from_str 12 | 13 | 14 | def run(configs): 15 | config = load_configs(configs)["pipeline"] 16 | downloader = get_obj_from_str(config["loader"]["target"])( 17 | **config["loader"]["args"] 18 | ) 19 | manager = get_obj_from_str(config["manager"]["target"])( 20 | configs=OmegaConf.to_yaml(config) 21 | ) 22 | for file_metadata in downloader.walk_files(): 23 | manager(file_metadata=file_metadata) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser( 28 | "TTS audio pipeline", 29 | description="Run data workers for preparing tts audio datasets", 30 | ) 31 | parser.add_argument( 32 | "--config", 33 | nargs="+", 34 | help="Config file path for pipeline orchestration" 35 | "Config will be merged from left to right", 36 | ) 37 | args = parser.parse_args() 38 | 39 | run(args.config) 40 | --------------------------------------------------------------------------------