├── LICENSE ├── MusicSourceSeparation ├── eval.py └── infer │ ├── .env.example │ ├── .gitignore │ ├── .pre-commit-config.yaml │ ├── .project-root │ ├── LICENSE │ ├── README.assets │ └── eval.png │ ├── README.md │ ├── conda_env_gpu.yaml │ ├── configs │ ├── callbacks │ │ ├── default.yaml │ │ ├── none.yaml │ │ └── wandb.yaml │ ├── config.yaml │ ├── datamodule │ │ ├── musdb18_hq.yaml │ │ └── musdb_dev14.yaml │ ├── evaluation.yaml │ ├── experiment │ │ ├── bass_dis.yaml │ │ ├── drums_dis.yaml │ │ ├── multigpu_default.yaml │ │ ├── other_dis.yaml │ │ └── vocals_dis.yaml │ ├── hydra │ │ └── default.yaml │ ├── infer_bass.yaml │ ├── infer_drums.yaml │ ├── infer_other.yaml │ ├── infer_vocals.yaml │ ├── logger │ │ ├── csv.yaml │ │ ├── many_loggers.yaml │ │ ├── neptune.yaml │ │ ├── none.yaml │ │ ├── tensorboard.yaml │ │ └── wandb.yaml │ ├── model │ │ ├── bass.yaml │ │ ├── drums.yaml │ │ ├── other.yaml │ │ └── vocals.yaml │ ├── paths │ │ └── default.yaml │ └── trainer │ │ ├── ddp.yaml │ │ ├── default.yaml │ │ └── minimal.yaml │ ├── requirements.txt │ ├── run_infer.py │ ├── setup.cfg │ ├── src │ ├── __init__.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── onnx_callback.py │ │ └── wandb_callbacks.py │ ├── datamodules │ │ ├── __init__.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ └── musdb.py │ │ └── musdb_datamodule.py │ ├── dp_tdf │ │ ├── __init__.py │ │ ├── abstract.py │ │ ├── bandsequence.py │ │ ├── dp_tdf_net.py │ │ └── modules.py │ ├── evaluation │ │ ├── eval.py │ │ ├── eval_demo.py │ │ └── separate.py │ ├── layers │ │ ├── __init__.py │ │ ├── batch_norm.py │ │ └── chunk_size.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── data_augmentation.py │ │ ├── omega_resolvers.py │ │ ├── pick_best.py │ │ ├── pylogger.py │ │ ├── rich_utils.py │ │ └── utils.py │ ├── tests │ ├── __init__.py │ ├── helpers │ │ ├── __init__.py │ │ ├── module_available.py │ │ ├── run_command.py │ │ └── runif.py │ ├── smoke │ │ ├── __init__.py │ │ ├── test_commands.py │ │ ├── test_mixed_precision.py │ │ ├── test_sweeps.py │ │ └── test_wandb.py │ ├── submit │ │ └── to_onnx.py │ └── unit │ │ ├── __init__.py │ │ └── test_sth.py │ └── train.py ├── README.md └── SpeechEnhancement ├── eval ├── intrusive.py └── non-intrusive.py └── infer ├── infer_blind_utmos.py ├── infer_large.py ├── infer_pesq.py └── infer_sdr.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
--------------------------------------------------------------------------------
/MusicSourceSeparation/eval.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Dict, Tuple, List
import multiprocessing as mp
from functools import partial
import csv
import argparse
from tqdm import tqdm
from datetime import datetime

def sdr(
    ref: np.ndarray,
    est: np.ndarray,
    eps: float = 1e-10
):
    """Signal-to-distortion ratio in dB: 10 * log10(mean(ref^2) / mean((est - ref)^2))."""
    noise = est - ref
    numerator = np.clip(a=np.mean(ref ** 2), a_min=eps, a_max=None)
    denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None)
    return 10. * np.log10(numerator / denominator)


def load_audio(file_path: str) -> Tuple[np.ndarray, int]:
    """Load audio file and return (audio, sample_rate)."""
    data, sr = sf.read(file_path)
    return data, sr


def calculate_utterance_sdr(ref_audio: np.ndarray, est_audio: np.ndarray) -> float:
    """
    Calculate utterance-level SDR over the full track.
    For stereo input, energy is pooled across both channels before taking the ratio.
    """
    # Trim both arrays to the same length
    min_len = min(len(ref_audio), len(est_audio))
    ref_audio = ref_audio[:min_len]
    est_audio = est_audio[:min_len]

    return sdr(ref_audio, est_audio)


def calculate_chunk_sdr(ref_audio: np.ndarray, est_audio: np.ndarray,
                        chunk_duration: float = 1.0, sr: int = 44100) -> float:
    """
    Calculate chunk-level SDR: the median SDR over non-overlapping chunks
    (1 second by default), with channel energy pooled within each chunk.
    """
    # Trim both arrays to the same length
    min_len = min(len(ref_audio), len(est_audio))
    ref_audio = ref_audio[:min_len]
    est_audio = est_audio[:min_len]

    # Calculate chunk size in samples
    chunk_size = int(chunk_duration * sr)

    sdrs = []
    for start in range(0, len(ref_audio) - chunk_size + 1, chunk_size):
        end = start + chunk_size
        sdrs.append(sdr(ref_audio[start:end], est_audio[start:end]))

    return np.median(sdrs)


def process_single_song(args):
    """Process a single song and return SDR metrics."""
    song_dir, source_path, target_stem, source_stem, source_ext = args

    song_name = song_dir.name
    target_file = song_dir / target_stem
    source_file = source_path / f"{song_name}{source_ext}"

    if not target_file.exists() or not source_file.exists():
        print(f"Skipping {song_name}: Files not found")
        return None

    print(f"Processing: {song_name}")

    try:
        # Load audio files
        ref_audio, ref_sr = load_audio(str(target_file))
        est_audio, est_sr = load_audio(str(source_file))

        # Resample if necessary
        if ref_sr != est_sr:
            try:
                import librosa
                est_audio = librosa.resample(est_audio.T, orig_sr=est_sr, target_sr=ref_sr).T
                est_sr = ref_sr
            except ImportError:
                print(f"Error: librosa not installed. Cannot resample {song_name}")
                return None

        # Calculate SDR metrics
        utterance_sdr = calculate_utterance_sdr(ref_audio, est_audio)
        chunk_sdr = calculate_chunk_sdr(ref_audio, est_audio, sr=ref_sr)

        result = {
            'song_name': song_name,
            'utterance_sdr': utterance_sdr,
            'chunk_sdr': chunk_sdr
        }

        print(f"  {song_name} - Utterance SDR: {utterance_sdr:.2f} dB, Chunk SDR: {chunk_sdr:.2f} dB")

        return result

    except Exception as e:
        print(f"Error processing {song_name}: {e}")
        return None


def process_songs_parallel(target_dir: str, source_dir: str, target_stem: str = "bass.wav",
                           source_stem: str = "bass", source_ext: str = ".flac",
                           step_num: int = 0, num_workers: int = None) -> List[Dict]:
    """
    Process all songs using multiprocessing and calculate SDR metrics.
    """
    target_path = Path(target_dir)
    source_path = Path(source_dir) / source_stem / f"step_{step_num}"

    # Prepare arguments for multiprocessing
    song_dirs = [d for d in target_path.iterdir() if d.is_dir()]
    args_list = [(song_dir, source_path, target_stem, source_stem, source_ext)
                 for song_dir in song_dirs]

    # Set number of workers
    if num_workers is None:
        num_workers = mp.cpu_count()

    print(f"Using {num_workers} workers for parallel processing")

    # Process songs in parallel
    with mp.Pool(num_workers) as pool:
        results = pool.map(process_single_song, args_list)

    # Filter out None results (failed processing)
    results = [r for r in results if r is not None]

    return results


def save_results_to_csv(results: List[Dict], output_file: str):
    """Save results to a CSV file."""
    if not results:
        print("No results to save")
        return

    # Create CSV file
    with open(output_file, 'w', newline='') as csvfile:
        fieldnames = ['song_name', 'utterance_sdr', 'chunk_sdr']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        writer.writeheader()
        for result in results:
            writer.writerow(result)

    print(f"Results saved to {output_file}")


def main():
    parser = argparse.ArgumentParser(description='Calculate SDR metrics for audio files')
    parser.add_argument('--target-dir', type=str, default='/root/autodl-tmp/test',
                        help='Directory containing target audio files')
    parser.add_argument('--source-dir', type=str, default='/root/autodl-tmp/inferred_results',
                        help='Directory containing source audio files')
    # parser.add_argument('--target-stem', type=str, default='drums.wav',
    #                     help='Target stem filename')
    # parser.add_argument('--source-stem', type=str, default='drums',
    #                     help='Source stem name')
    parser.add_argument('--source-ext', type=str, default='.flac',
                        help='Source file extension')
    # parser.add_argument('--step', type=int, default=10,
    #                     help='Step number for source files')
    parser.add_argument('--workers', type=int, default=8,
                        help='Number of worker processes (default: 8)')

    args = parser.parse_args()

    stems = ["vocals", "drums", "bass", "other"]

    for stem in tqdm(stems):
        for step in tqdm(range(21)):
            print(f"Processing files from step_{step}")

            # Process all songs with multiprocessing
            results = process_songs_parallel(
                target_dir=args.target_dir,
                source_dir=args.source_dir,
                target_stem=stem + ".wav",
                source_stem=stem,
                source_ext=args.source_ext,
                step_num=step,
                num_workers=args.workers
            )

            # Save results to CSV (one file per stem/step combination)
            output_csv = stem + "_step_" + str(step) + ".csv"
            save_results_to_csv(results, output_csv)

            # Calculate overall metrics
            if results:
                utterance_sdrs = [r['utterance_sdr'] for r in results]
                chunk_sdrs = [r['chunk_sdr'] for r in results]

                mean_utterance_sdr = np.mean(utterance_sdrs)
                median_chunk_sdr = np.median(chunk_sdrs)

                print("\n=== Overall Results ===")
                print(f"Mean Utterance-level SDR: {mean_utterance_sdr:.4f} dB")
                print(f"Median Chunk-level SDR: {median_chunk_sdr:.4f} dB")
                print(f"Number of songs processed: {len(results)}")

                # Also save summary to CSV
                summary_file = output_csv.replace('.csv', '_summary.csv')
                with open(summary_file, 'w', newline='') as csvfile:
                    writer = csv.writer(csvfile)
                    writer.writerow(['Metric', 'Value'])
                    writer.writerow(['uSDR (dB)', f'{mean_utterance_sdr:.4f}'])
                    writer.writerow(['cSDR (dB)', f'{median_chunk_sdr:.4f}'])

                print(f"Summary saved to {summary_file}")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/MusicSourceSeparation/infer/.env.example:
--------------------------------------------------------------------------------
# the folder of the musdb18hq dataset; contains the train and test folders
data_dir=xxx

PROJECT_ROOT=xxx

# the directory where the checkpoints will be saved
LOG_DIR=xxxx

# You don't have to fill this in if you only want to use TensorBoard.
wandb_api_key=xxxx # go to wandb.ai/settings and copy your key

# Number of CPU cores; will be used in the dataloaders
NUM_WORKERS=0

HYDRA_FULL_ERROR=1

--------------------------------------------------------------------------------
/MusicSourceSeparation/infer/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

.env

predictions/
tmp/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | ### VisualStudioCode 136 | .vscode/* 137 | !.vscode/settings.json 138 | !.vscode/tasks.json 139 | !.vscode/launch.json 140 | !.vscode/extensions.json 141 | *.code-workspace 142 | **/.vscode 143 | 144 | # JetBrains 145 | .idea/ 146 | 147 | # Lightning-Hydra-Template 148 | #data/ 149 | logs/ 150 | wandb/ 151 | .env 152 | .autoenv 153 | 154 | onnx/* 155 | outputs/* 156 | 157 | assets/ 158 | 159 | notebooks/ 160 | 161 | pretrained/ -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.8 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v3.4.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-yaml 12 | - id: check-added-large-files 13 | - id: debug-statements 14 | - id: detect-private-key 15 | 16 | # python code formatting 17 | - repo: https://github.com/psf/black 18 | rev: 20.8b1 19 | hooks: 20 | - id: black 21 | args: [--line-length, "99"] 22 | 23 | # python import sorting 24 | - repo: https://github.com/PyCQA/isort 25 | rev: 5.8.0 26 | hooks: 27 | - id: isort 28 | 29 | # yaml formatting 30 | - repo: https://github.com/pre-commit/mirrors-prettier 31 | rev: v2.3.0 32 | hooks: 33 | - id: prettier 34 | types: [yaml] 35 | 36 | # python code analysis 37 | - repo: https://github.com/PyCQA/flake8 38 | rev: 3.9.2 39 | hooks: 40 | - id: flake8 41 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/.project-root: 
-------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
--------------------------------------------------------------------------------
/MusicSourceSeparation/infer/README.assets/eval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/README.assets/eval.png
--------------------------------------------------------------------------------
/MusicSourceSeparation/infer/README.md:
--------------------------------------------------------------------------------
# Dual-Path TFC-TDF UNet

A PyTorch implementation of the ICASSP 2024 paper: Dual-Path TFC-TDF UNet for Music Source Separation. DTTNet achieves 10.12 dB cSDR on vocals with 86% fewer parameters than BSRNN (SOTA).

Link to our paper:

- arXiv (Accepted Version): https://arxiv.org/abs/2309.08684
- IEEE Xplore (Published Version): https://ieeexplore.ieee.org/document/10448020


## Notes

1. Overlap-add is switched on by default. To switch it off, comment out the **values** of the ```overlap_add``` key in ```configs/infer_*.yaml``` and ```configs/evaluation.yaml```; inference will then be about 4x faster.

![eval](README.assets/eval.png)


## Environment Setup (First Time)

1. Download MUSDB18-HQ from https://sigsep.github.io/datasets/musdb.html
2. (Optional) Edit the validation_set in configs/datamodule/musdb_dev14.yaml
3. Create a Miniconda/Anaconda environment

```
conda env create -f conda_env_gpu.yaml -n DTT
source /root/miniconda3/etc/profile.d/conda.sh
conda activate DTT
pip install -r requirements.txt
export PYTHONPATH=$PYTHONPATH:$(pwd) # for Windows, replace 'export' with 'set'
```

4. Edit the .env file according to the instructions inside it. Using wandb to manage the logs is recommended.

```
cp .env.example .env
vim .env
```


## Environment Setup (After First Time)

Once everything is configured, subsequent sessions only need these two commands to set up the environment:

```
source /root/miniconda3/etc/profile.d/conda.sh
conda activate DTT
```


## Inference

1. Download checkpoints from either:
   - https://mega.nz/folder/E4c1QD7Z#OkgM_dEK1tC5MzpqEBuxvQ
   - https://pan.baidu.com/s/1Tw6Fp6wYVZTqjRE-aDZdyw (code: lgwe)
2. Run:

```
python run_infer.py model=vocals ckpt_path=xxxxx mixture_path=xxxx
```

The separated files will be saved under the folder ```PROJECT_ROOT/infer/songname_suffix/```


Parameter options:

- model=vocals, model=bass, model=drums, model=other


## Evaluation

Change ```pool_workers``` in ```configs/evaluation.yaml```; you can set it to the number of cores in your CPU.
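For reference, ```configs/evaluation.yaml``` ships with ```pool_workers: 8```. A minimal sketch of raising it on a 16-core machine, either by editing the file or via a Hydra-style command-line override (the override form is an assumption based on the other commands in this README):

```
# configs/evaluation.yaml (excerpt)
pool_workers: 16  # e.g. the number of physical CPU cores

# or, equivalently, from the command line:
# python run_eval.py model=vocals pool_workers=16
```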

```
export ckpt_path=xxx # for Windows, replace 'export' with 'set'

python run_eval.py model=vocals logger.wandb.name=xxxx

# or if you don't want to use a logger
python run_eval.py model=vocals logger=[]
```

The result will be saved as eval.csv under the folder ```LOG_DIR/basename(ckpt_path)_suffix```


Parameter options:

- model=vocals, model=bass, model=drums, model=other


## Train

Note that you will need:

- 1 TB of disk space for data augmentation.
  - Otherwise, edit ```configs/datamodule/musdb18_hq.yaml``` so that:
    - ```aug_params=[]```. This will train the model without data augmentation.
- 2× A40 (48 GB), or equivalently 4× RTX 3090 (24 GB).
  - Otherwise, edit ```configs/experiment/vocals_dis.yaml``` so that:
    - ```datamodule.batch_size``` is smaller
    - ```trainer.devices: 1```
    - ```model.bn_norm: BN```
    - ```trainer.sync_batchnorm``` is deleted

### 1. Data Partition

```
python demos/split_dataset.py # data partition
```

### 2. Data Augmentation (Optional)

```
# install aug tools
sudo apt-get update
sudo apt-get install soundstretch

mkdir /root/autodl-tmp/tmp

# perform augmentation
python src/utils/data_augmentation.py --data_dir /root/autodl-tmp/musdb18hq/
```

### 3. Run code

```
python train.py experiment=vocals_dis datamodule=musdb_dev14 trainer=default

# or if you don't want to use a logger

python train.py experiment=vocals_dis datamodule=musdb_dev14 trainer=default logger=[]
```

The 5 best models will be saved under ```LOG_DIR/dtt_vocals_suffix/checkpoints```

### 4. Pick the best model

```
# edit api_key and path
python src/utils/pick_best.py
```


## Bespoke Fine-tune

```
git checkout bespoke
```


## Referenced Repositories

1. TFC-TDF UNet
   1. https://github.com/kuielab/sdx23
   2. https://github.com/kuielab/mdx-net
   3. https://github.com/ws-choi/sdx23
   4. https://github.com/ws-choi/ISMIR2020_U_Nets_SVS
2. BandSplitRNN
   1. https://github.com/amanteur/BandSplitRNN-Pytorch
3. fast-reid (Sync BN)
   1. https://github.com/JDAI-CV/fast-reid
4. Zero_Shot_Audio_Source_Separation (overlap-add)
   1.
https://github.com/RetroCirce/Zero_Shot_Audio_Source_Separation 176 | 177 | 178 | 179 | ## Cite 180 | 181 | ``` 182 | @INPROCEEDINGS{chen_dttnet_2024, 183 | author={Chen, Junyu and Vekkot, Susmitha and Shukla, Pancham}, 184 | booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 185 | title={Music Source Separation Based on a Lightweight Deep Learning Framework (DTTNET: DUAL-PATH TFC-TDF UNET)}, 186 | year={2024}, 187 | volume={}, 188 | number={}, 189 | pages={656-660}, 190 | keywords={Deep learning;Time-frequency analysis;Source separation;Target tracking;Convolution;Market research;Acoustics;source separation;music;audio;dual-path;deep learning}, 191 | doi={10.1109/ICASSP48485.2024.10448020}} 192 | ``` 193 | 194 | 195 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/conda_env_gpu.yaml: -------------------------------------------------------------------------------- 1 | name: mdx-net 2 | 3 | channels: 4 | - pytorch 5 | - nvidia 6 | - anaconda 7 | 8 | dependencies: 9 | - python<=3.9 10 | - pip 11 | - cudatoolkit 12 | - pytorch=1.13.1 13 | - torchaudio 14 | - ffmpeg==4.3 15 | 16 | - pip: 17 | - -r requirements.txt 18 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val/usdr" # name of the logged metric which determines when model is improving 4 | save_top_k: 5 # save k best models (determined by above metric) 5 | save_last: True # additionaly always save model from last epoch 6 | mode: "max" # can be "max" or "min" 7 | verbose: False 8 | dirpath: "checkpoints/" 9 | filename: "{epoch:02d}-{step}" 10 | # 11 | #early_stopping: 12 | # _target_: pytorch_lightning.callbacks.EarlyStopping 13 | # monitor: "val/sdr" # name of the logged metric which determines when model is improving 14 | # patience: 300 # how many epochs of not improving until training stops 15 | # mode: "max" # can be "max" or "min" 16 | # min_delta: 0.05 # minimum change in the monitored metric needed to qualify as an improvement 17 | 18 | #make_onnx: 19 | # _target_: src.callbacks.onnx_callback.MakeONNXCallback 20 | # dirpath: "onnx/" -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | watch_model: 5 | _target_: src.callbacks.wandb_callbacks.WatchModel 6 | log: "all" 7 | log_freq: 100 8 | 9 | #upload_valid_track: 10 | # _target_: src.callbacks.wandb_callbacks.UploadValidTrack 11 | # crop: 3 12 | # upload_after_n_epoch: -1 13 | 14 | #upload_code_as_artifact: 15 | # _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact 16 | # code_dir: ${work_dir}/src 17 | # 18 | #upload_ckpts_as_artifact: 19 | # _target_: 
src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact 20 | # ckpt_dir: "checkpoints/" 21 | # upload_best_only: True 22 | # 23 | #log_f1_precision_recall_heatmap: 24 | # _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap 25 | # 26 | #log_confusion_matrix: 27 | # _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix 28 | # 29 | #log_image_predictions: 30 | # _target_: src.callbacks.wandb_callbacks.LogImagePredictions 31 | # num_samples: 8 32 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - datamodule: musdb18_hq 6 | - model: null 7 | - callbacks: default # set this to null if you don't want to use callbacks 8 | - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`) 9 | - trainer: default 10 | - hparams_search: null 11 | - paths: default.yaml 12 | 13 | - hydra: default 14 | 15 | - experiment: null 16 | 17 | # enable color logging 18 | - override hydra/hydra_logging: colorlog 19 | - override hydra/job_logging: colorlog 20 | 21 | 22 | # path to original working directory 23 | # hydra hijacks working directory by changing it to the current log directory, 24 | # so it's useful to have this path as a special variable 25 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 26 | #work_dir: ${hydra:runtime.cwd} 27 | #output_dir: ${hydra:runtime.output_dir} 28 | 29 | # path to folder with data 30 | 31 | 32 | # use `python run.py debug=true` for easy debugging! 33 | # this will run 1 train, val and test loop with only 1 batch 34 | # equivalent to running `python run.py trainer.fast_dev_run=true` 35 | # (this is placed here just for easier access from command line) 36 | debug: False 37 | 38 | # pretty print config at the start of the run using Rich library 39 | print_config: True 40 | 41 | # disable python warnings if they annoy you 42 | ignore_warnings: True 43 | 44 | wandb_api_key: ${oc.env:wandb_api_key} -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/datamodule/musdb18_hq.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.musdb_datamodule.MusdbDataModule 2 | 3 | # data_dir is specified in config.yaml 4 | data_dir: null 5 | 6 | single_channel: False 7 | 8 | # chunk_size = (hop_length * (dim_t - 1) / sample_rate) secs 9 | sample_rate: 44100 10 | hop_length: ${model.hop_length} # stft hop_length 11 | dim_t: ${model.dim_t} # number of stft frames 12 | 13 | # number of overlapping wave samples between chunks when separating a whole track 14 | overlap: ${model.overlap} 15 | 16 | source_names: 17 | - bass 18 | - drums 19 | - other 20 | - vocals 21 | target_name: ${model.target_name} 22 | 23 | external_datasets: null 24 | #external_datasets: 25 | # - test 26 | 27 | 28 | batch_size: 8 29 | num_workers: 0 30 | pin_memory: False 31 | 32 | aug_params: 33 | - 2 # maximum pitch shift in semitones (-x < shift param < x) 34 | - 20 # maximum time stretch percentage (-x < stretch param < x) 35 | 36 | validation_set: 37 | - Actions - One Minute Smile 38 | - Clara Berry And Wooldog - Waltz For My Victims 39 | - Johnny Lokke - Promises & Lies 40 | - Patrick Talbot - A Reason To Leave 41 | - Triviul - Angelsaint 42 | # - Alexander 
Ross - Goodbye Bolero 43 | # - Fergessen - Nos Palpitants 44 | # - Leaf - Summerghost 45 | # - Skelpolu - Human Mistakes 46 | # - Young Griffo - Pennies 47 | # - ANiMAL - Rockshow 48 | # - James May - On The Line 49 | # - Meaxic - Take A Step 50 | # - Traffic Experiment - Sirens -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/datamodule/musdb_dev14.yaml: -------------------------------------------------------------------------------- 1 | 2 | defaults: 3 | - musdb18_hq 4 | 5 | data_dir: ${oc.env:data_dir} 6 | 7 | has_split_structure: True 8 | 9 | validation_set: 10 | # - Meaxic - Take A Step 11 | # - Skelpolu - Human Mistakes 12 | - Actions - One Minute Smile 13 | - Clara Berry And Wooldog - Waltz For My Victims 14 | - Johnny Lokke - Promises & Lies 15 | - Patrick Talbot - A Reason To Leave 16 | - Triviul - Angelsaint 17 | - Alexander Ross - Goodbye Bolero 18 | - Fergessen - Nos Palpitants 19 | - Leaf - Summerghost 20 | - Skelpolu - Human Mistakes 21 | - Young Griffo - Pennies 22 | - ANiMAL - Rockshow 23 | - James May - On The Line 24 | - Meaxic - Take A Step 25 | - Traffic Experiment - Sirens 26 | 27 | 28 | mode: musdb18hq -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - model: ConvTDFNet_vocals 6 | - logger: 7 | - wandb 8 | - tensorboard 9 | - paths: default.yaml 10 | # enable color logging 11 | - override hydra/hydra_logging: colorlog 12 | - override hydra/job_logging: colorlog 13 | 14 | hydra: 15 | run: 16 | dir: ${get_eval_log_dir:${ckpt_path}} 17 | 18 | #ckpt_path: "G:\\Experiments\\KLRef\\vocals.onnx" 19 | ckpt_path: ${oc.env:ckpt_path} 20 | 21 | split: 'test' 22 | batch_size: 4 23 | device: 'cuda:0' 24 | bss: fast # fast or official 25 | single: False # for debug investigation, only run the model on 1 single song 26 | 27 | #data_dir: ${oc.env:data_dir} 28 | eval_dir: ${oc.env:data_dir} 29 | wandb_api_key: ${oc.env:wandb_api_key} 30 | 31 | logger: 32 | wandb: 33 | # project: mdx_eval_${split} 34 | project: new_eval_order 35 | name: ${get_eval_log_dir:${ckpt_path}} 36 | 37 | pool_workers: 8 38 | double_chunk: False 39 | 40 | overlap_add: 41 | overlap_rate: 0.5 42 | tmp_root: ${paths.root_dir}/tmp # for saving temp chunks, since we use ffmpeg and will need io to disk 43 | samplerate: 44100 -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/experiment/bass_dis.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - multigpu_default 8 | - override /model: bass.yaml 9 | 10 | seed: 2021 11 | 12 | exp_name: bass_g32 13 | 14 | # the name inside project 15 | logger: 16 | wandb: 17 | name: ${exp_name} 18 | 19 | model: 20 | lr: 0.0002 21 | optimizer: adamW 22 | bn_norm: syncBN 23 | audio_ch: 2 # datamodule.single_channel 24 | g: 32 25 | 26 | trainer: 27 | devices: 2 # int or list 28 | sync_batchnorm: True 29 | track_grad_norm: 2 30 | # gradient_clip_val: 5 31 | 32 | datamodule: 33 | batch_size: 8 34 | num_workers: ${oc.decode:${oc.env:NUM_WORKERS}} 35 | pin_memory: False 36 | overlap: ${model.overlap} 37 | audio_ch: 
${model.audio_ch} 38 | epoch_size: 39 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/experiment/drums_dis.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - multigpu_default 8 | - override /model: drums.yaml 9 | 10 | seed: 2021 11 | 12 | exp_name: drums_g32 13 | 14 | # the name inside project 15 | logger: 16 | wandb: 17 | name: ${exp_name} 18 | 19 | model: 20 | lr: 0.0002 21 | optimizer: adamW 22 | bn_norm: syncBN 23 | audio_ch: 2 # datamodule.single_channel 24 | g: 32 25 | 26 | trainer: 27 | devices: 2 # int or list 28 | sync_batchnorm: True 29 | track_grad_norm: 2 30 | # gradient_clip_val: 5 31 | 32 | datamodule: 33 | batch_size: 8 34 | num_workers: ${oc.decode:${oc.env:NUM_WORKERS}} 35 | pin_memory: False 36 | overlap: ${model.overlap} 37 | audio_ch: ${model.audio_ch} 38 | epoch_size: 39 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/experiment/multigpu_default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /callbacks: default 8 | - override /logger: 9 | - wandb 10 | - tensorboard 11 | 12 | 13 | #callbacks: 14 | # early_stopping: 15 | # patience: 1000000 16 | 17 | #datamodule: 18 | # external_datasets: 19 | # - test 20 | 21 | trainer: 22 | max_epochs: 1000000 23 | accelerator: cuda 24 | amp_backend: native 25 | precision: 16 26 | track_grad_norm: -1 -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/experiment/other_dis.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - multigpu_default 8 | - override /model: other.yaml 9 | 10 | seed: 2021 11 | 12 | exp_name: other_g32 13 | 14 | # the name inside project 15 | logger: 16 | wandb: 17 | name: ${exp_name} 18 | 19 | model: 20 | lr: 0.0002 21 | optimizer: adamW 22 | bn_norm: syncBN 23 | audio_ch: 2 # datamodule.single_channel 24 | g: 32 25 | 26 | trainer: 27 | devices: 2 # int or list 28 | sync_batchnorm: True 29 | track_grad_norm: 2 30 | # gradient_clip_val: 5 31 | 32 | datamodule: 33 | batch_size: 8 34 | num_workers: ${oc.decode:${oc.env:NUM_WORKERS}} 35 | pin_memory: False 36 | overlap: ${model.overlap} 37 | audio_ch: ${model.audio_ch} 38 | epoch_size: 39 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/experiment/vocals_dis.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - multigpu_default 8 | - override /model: vocals.yaml 9 | 10 | seed: 2021 11 | 12 | exp_name: vocals_g32 13 | 14 | # the name inside project 15 | logger: 16 | wandb: 17 | name: ${exp_name} 18 | 19 | model: 20 | lr: 0.0002 21 | optimizer: adamW 22 | bn_norm: syncBN 23 | audio_ch: 2 # datamodule.single_channel 24 | g: 32 25 | 26 | trainer: 27 | devices: 2 # int or list 28 | 
sync_batchnorm: True 29 | track_grad_norm: 2 30 | # gradient_clip_val: 5 31 | 32 | datamodule: 33 | batch_size: 8 34 | num_workers: ${oc.decode:${oc.env:NUM_WORKERS}} 35 | pin_memory: False 36 | overlap: ${model.overlap} 37 | audio_ch: ${model.audio_ch} 38 | epoch_size: 39 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # output paths for hydra logs 2 | run: 3 | # dir: logs/runs/${datamodule.target_name}_${exp_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | dir: ${get_train_log_dir:${datamodule.target_name},${exp_name}} 5 | 6 | sweep: 7 | # dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S} 8 | dir: ${get_sweep_log_dir:${datamodule.target_name},${exp_name}} 9 | subdir: ${hydra.job.num} 10 | 11 | # you can set here environment variables that are universal for all users 12 | # for system specific variables (like data paths) it's better to use .env file! 13 | job: 14 | env_set: 15 | EXAMPLE_VAR: "example_value" 16 | 17 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/infer_bass.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - model: bass 6 | - paths: default.yaml 7 | # enable color logging 8 | - override hydra/hydra_logging: colorlog 9 | - override hydra/job_logging: colorlog 10 | 11 | #hydra: 12 | # run: 13 | # dir: ${get_eval_log_dir:${ckpt_path}} 14 | 15 | #ckpt_path: "G:\\Experiments\\KLRef\\vocals.onnx" 16 | ckpt_path: "/root/autodl-tmp/dtt/bassg32_ep2935.ckpt" 17 | mixture_path: "/root/autodl-tmp/test" 18 | batch_size: 16 19 | device: 'cuda:1' 20 | 21 | double_chunk: False 22 | 23 | overlap_add: 24 | overlap_rate: 0.5 25 | tmp_root: "/root/autodl-tmp/tmp" # for saving temp chunks, since we use ffmpeg and will need io to disk 26 | samplerate: 44100 -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/infer_drums.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - model: drums 6 | - paths: default.yaml 7 | # enable color logging 8 | - override hydra/hydra_logging: colorlog 9 | - override hydra/job_logging: colorlog 10 | 11 | #hydra: 12 | # run: 13 | # dir: ${get_eval_log_dir:${ckpt_path}} 14 | 15 | #ckpt_path: "G:\\Experiments\\KLRef\\vocals.onnx" 16 | ckpt_path: "/root/autodl-tmp/dtt/drumsg32_ep1612.ckpt" 17 | mixture_path: "/root/autodl-tmp/test" 18 | batch_size: 16 19 | device: 'cuda:2' 20 | 21 | double_chunk: False 22 | 23 | overlap_add: 24 | overlap_rate: 0.5 25 | tmp_root: "/root/autodl-tmp/tmp" # for saving temp chunks, since we use ffmpeg and will need io to disk 26 | samplerate: 44100 -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/infer_other.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - model: other 6 | - paths: default.yaml 7 | # enable color logging 8 | - override hydra/hydra_logging: colorlog 9 | - override hydra/job_logging: colorlog 10 | 11 | #hydra: 12 | # run: 13 | # dir: ${get_eval_log_dir:${ckpt_path}} 14 
| 15 | #ckpt_path: "G:\\Experiments\\KLRef\\vocals.onnx" 16 | ckpt_path: "/root/autodl-tmp/dtt/otherg32_ep3605.ckpt" 17 | mixture_path: "/root/autodl-tmp/test" 18 | batch_size: 16 19 | device: 'cuda:3' 20 | 21 | double_chunk: False 22 | 23 | overlap_add: 24 | overlap_rate: 0.5 25 | tmp_root: "/root/autodl-tmp/tmp" # for saving temp chunks, since we use ffmpeg and will need io to disk 26 | samplerate: 44100 -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/infer_vocals.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - model: vocals 6 | - paths: default.yaml 7 | # enable color logging 8 | - override hydra/hydra_logging: colorlog 9 | - override hydra/job_logging: colorlog 10 | 11 | #hydra: 12 | # run: 13 | # dir: ${get_eval_log_dir:${ckpt_path}} 14 | 15 | #ckpt_path: "G:\\Experiments\\KLRef\\vocals.onnx" 16 | ckpt_path: "/root/autodl-tmp/dtt/vocalsg32_ep4082.ckpt" 17 | mixture_path: "/root/autodl-tmp/test" 18 | batch_size: 16 19 | device: 'cuda:0' 20 | 21 | double_chunk: False 22 | 23 | overlap_add: 24 | overlap_rate: 0.5 25 | tmp_root: "/root/autodl-tmp/tmp" # for saving temp chunks, since we use ffmpeg and will need io to disk 26 | samplerate: 44100 -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "." 6 | name: "csv/" 7 | version: null 8 | prefix: "" 9 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - aim.yaml 5 | # - comet.yaml 6 | - csv.yaml 7 | # - mlflow.yaml 8 | # - neptune.yaml 9 | # - tensorboard.yaml 10 | - wandb.yaml 11 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project_name: your_name/template-tests 7 | close_after_fit: True 8 | offline_mode: False 9 | experiment_name: null 10 | experiment_id: null 11 | prefix: "" 12 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/logger/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/configs/logger/none.yaml -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: "default"
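# leaving version as null below lets Lightning's TensorBoardLogger auto-assign the next available version_N run subdirectory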
7 | version: null 8 | log_graph: False 9 | default_hp_metric: True 10 | prefix: "" 11 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | project: dtt_${model.target_name} 6 | name: null 7 | save_dir: ${hydra:run.dir} 8 | offline: False # set True to store all logs only locally 9 | id: null # pass correct id to resume experiment! 10 | # entity: "" # set to name of your wandb team or just remove it 11 | log_model: False 12 | prefix: "" 13 | job_type: "train" 14 | group: "" 15 | tags: [] 16 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/model/bass.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.dp_tdf.dp_tdf_net.DPTDFNet 2 | 3 | # abstract parent class 4 | target_name: 'bass' 5 | lr: 0.0001 6 | optimizer: adamW 7 | 8 | dim_f: 864 9 | dim_t: 256 10 | n_fft: 6144 11 | hop_length: 1024 12 | overlap: 3072 13 | 14 | audio_ch: 2 15 | 16 | block_type: TFC_TDF_Res2 17 | num_blocks: 5 18 | l: 3 19 | g: 32 20 | k: 3 21 | bn: 2 22 | bias: False 23 | bn_norm: BN 24 | bandsequence: 25 | rnn_type: LSTM 26 | bidirectional: True 27 | num_layers: 4 28 | n_heads: 2 -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/model/drums.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.dp_tdf.dp_tdf_net.DPTDFNet 2 | 3 | # abstract parent class 4 | target_name: 'drums' 5 | lr: 0.0001 6 | optimizer: adamW 7 | 8 | dim_f: 2048 9 | dim_t: 256 10 | n_fft: 6144 11 | hop_length: 1024 12 | overlap: 3072 13 | 14 | audio_ch: 2 15 | 16 | block_type: TFC_TDF_Res2 17 | num_blocks: 5 18 | l: 3 19 | g: 32 20 | k: 3 21 | bn: 8 22 | bias: False 23 | bn_norm: BN 24 | bandsequence: 25 | rnn_type: LSTM 26 | bidirectional: True 27 | num_layers: 4 28 | n_heads: 2 -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/model/other.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.dp_tdf.dp_tdf_net.DPTDFNet 2 | 3 | # abstract parent class 4 | target_name: 'other' 5 | lr: 0.0001 6 | optimizer: adamW 7 | 8 | dim_f: 2048 9 | dim_t: 256 10 | n_fft: 6144 11 | hop_length: 1024 12 | overlap: 3072 13 | 14 | audio_ch: 2 15 | 16 | block_type: TFC_TDF_Res2 17 | num_blocks: 5 18 | l: 3 19 | g: 32 20 | k: 3 21 | bn: 8 22 | bias: False 23 | bn_norm: BN 24 | bandsequence: 25 | rnn_type: LSTM 26 | bidirectional: True 27 | num_layers: 4 28 | n_heads: 2 -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/model/vocals.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.dp_tdf.dp_tdf_net.DPTDFNet 2 | 3 | # abstract parent class 4 | target_name: 'vocals' 5 | lr: 0.0001 6 | optimizer: adamW 7 | 8 | dim_f: 2048 9 | dim_t: 256 10 | n_fft: 6144 11 | hop_length: 1024 12 | overlap: 3072 13 | 14 | audio_ch: 2 15 | 16 | block_type: TFC_TDF_Res2 17 | num_blocks: 5 18 | l: 3 19 | g: 32 20 | k: 3 21 | bn: 8 22 | bias: False 23 | bn_norm: BN 24 | bandsequence: 25 | rnn_type: LSTM 26 | bidirectional: True 27 | num_layers: 4 28 | 
n_heads: 2 -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${oc.env:LOG_DIR} 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # use "ddp_spawn" instead of "ddp", 5 | # it's slower but normal "ddp" currently doesn't work ideally with hydra 6 | # https://github.com/facebookresearch/hydra/issues/2070 7 | # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn 8 | strategy: ddp_spawn 9 | 10 | accelerator: gpu 11 | devices: 2 12 | num_nodes: 1 13 | sync_batchnorm: True -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 10 7 | 8 | accelerator: cpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | # precision: 16 13 | 14 | # perform a validation loop every N training epochs 15 | check_val_every_n_epoch: 1 16 | 17 | # set True to ensure deterministic results 18 | # makes training slower but gives more reproducibility than just setting seeds 19 | deterministic: False -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/configs/trainer/minimal.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | defaults: 4 | - default 5 | 6 | devices: 4 7 | 8 | resume_from_checkpoint: 9 | auto_lr_find: False 10 | deterministic: True 11 | accelerator: dp 12 | sync_batchnorm: False 13 | 14 | max_epochs: 3000 15 | min_epochs: 1 16 | check_val_every_n_epoch: 10 17 | num_sanity_val_steps: 1 18 | 19 | precision: 16 20 | amp_backend: "native" 21 | amp_level: "O2" -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch==1.13.1 3 | pytorch-lightning==1.9.0 4 | 5 | # --------- hydra --------- # 6 | hydra-core 7 | hydra-colorlog 8 | hydra-optuna-sweeper 9 | 10 | # --------- loggers --------- # 11 | wandb>=0.10.30 12 | tensorboard 13 | 14 | # --------- linters --------- # 15 | pre-commit 16 | black 17 | isort 18 | flake8 19 | 20 | # --------- others
--------- # 21 | python-dotenv 22 | rich 23 | pytest 24 | sh 25 | scikit-learn 26 | pudb 27 | soundfile 28 | pyrootutils 29 | matplotlib 30 | seaborn 31 | onnxruntime-gpu 32 | librosa==0.9.1 33 | museval -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/run_infer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | import soundfile as sf 6 | import torch 7 | from src.utils.utils import load_wav, get_unique_save_path 8 | from src.utils.omega_resolvers import get_eval_log_dir 9 | from pathlib import Path 10 | from tqdm import tqdm 11 | import numpy as np 12 | import fast_bss_eval 13 | import dotenv 14 | from src.evaluation.separate import separate_with_ckpt_TDF, no_overlap_inference, overlap_inference 15 | dotenv.load_dotenv(override=True) 16 | 17 | mode = "other" 18 | 19 | def sdr( 20 | ref: np.ndarray, 21 | est: np.ndarray, 22 | eps: float = 1e-10 23 | ): 24 | r"""Calculate SDR. 25 | """ 26 | noise = est - ref 27 | numerator = np.clip(a=np.mean(ref ** 2), a_min=eps, a_max=None) 28 | denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None) 29 | sdr = 10. * np.log10(numerator / denominator) 30 | return sdr 31 | 32 | @hydra.main(config_path="configs/", config_name="infer_" + mode + ".yaml", version_base='1.1') 33 | def main(config: DictConfig): 34 | # Imports should be nested inside @hydra.main to optimize tab completion 35 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 36 | 37 | from src.utils import utils 38 | 39 | model = hydra.utils.instantiate(config.model) 40 | 41 | ckpt_path = Path(config.ckpt_path) 42 | 43 | checkpoint = torch.load(ckpt_path) 44 | model.load_state_dict(checkpoint["state_dict"]) 45 | model = model.to(config.device) 46 | 47 | mixtures = [] 48 | 49 | # get all mixture.wav recursively from mixture_path 50 | for root, dirs, files in os.walk(config.mixture_path): 51 | for file in files: 52 | if file.endswith("mixture.wav"): 53 | mixtures.append(os.path.join(root, file)) 54 | 55 | print(f"Found {len(mixtures)} mixtures") 56 | 57 | mixture_to_target_map = {} 58 | for mixture in mixtures: 59 | mixture_lastdirname = os.path.basename(os.path.dirname(mixture)) 60 | if mixture_lastdirname not in mixture_to_target_map: 61 | mixture_to_target_map[mixture_lastdirname] = {} 62 | mixture_to_target_map[mixture_lastdirname]["mixture"] = mixture 63 | mixture_to_target_map[mixture_lastdirname]["target"] = os.path.dirname(mixture) + "/" + mode + ".wav" 64 | # check if target exists 65 | if not os.path.exists(mixture_to_target_map[mixture_lastdirname]["target"]): 66 | print(f"Target {mixture_to_target_map[mixture_lastdirname]['target']} does not exist") 67 | continue 68 | 69 | print("Found {} targets".format(len(mixture_to_target_map))) 70 | 71 | inference_steps = 20 72 | candidate_each_step = 10 73 | 74 | for step in tqdm(range(inference_steps)): 75 | all_sdrs = [] 76 | save_path = "/root/autodl-tmp/inferred_results/" + mode + "/step_" + str(step + 1) 77 | prev_step = "/root/autodl-tmp/inferred_results/" + mode + "/step_" + str(step) 78 | if not os.path.exists(save_path): 79 | os.makedirs(save_path, exist_ok=True) 80 | 81 | # ratio is 0 to 0.9 82 | candidate_ratios = [i * 0.1 for i in range(candidate_each_step)] 83 | 84 | for item in tqdm(mixture_to_target_map): 85 | mixture_path = mixture_to_target_map[item]["mixture"] 86 | target_path = mixture_to_target_map[item]["target"] 87
| mixture = load_wav(mixture_path) 88 | target_audio = load_wav(target_path) 89 | prev_audio = load_wav(os.path.join(prev_step, os.path.basename(os.path.dirname(mixture_path)) + ".flac")) 90 | prev_sdr = sdr(target_audio, prev_audio) 91 | 92 | inferred = separate_with_ckpt_TDF(config.batch_size, model, ckpt_path, mixture, prev_audio, target_audio, candidate_ratios, config.device, 93 | config.double_chunk, config.overlap_add) 94 | 95 | mixture_lastdirname = os.path.basename(os.path.dirname(mixture_path)) 96 | curr_save_path = os.path.join(save_path, mixture_lastdirname + ".flac") 97 | sf.write(curr_save_path, inferred.T, 44100) 98 | sdr_value = sdr(target_audio, inferred) 99 | all_sdrs.append(sdr_value) 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | project_name = ... 3 | author = ... 4 | contact = ... 5 | license_file = LICENSE 6 | description_file = README.md 7 | project_template = https://github.com/ashleve/lightning-hydra-template 8 | 9 | 10 | [isort] 11 | line_length = 99 12 | profile = black 13 | filter_files = True 14 | 15 | 16 | [flake8] 17 | max_line_length = 99 18 | show_source = True 19 | format = pylint 20 | ignore = 21 | F401 # Module imported but unused 22 | W504 # Line break occurred after a binary operator 23 | F841 # Local variable name is assigned to but never used 24 | exclude = 25 | .git 26 | __pycache__ 27 | data/* 28 | tests/* 29 | notebooks/* 30 | logs/* 31 | 32 | 33 | [tool:pytest] 34 | python_files = tests/* 35 | log_cli = True 36 | markers = 37 | slow 38 | addopts = 39 | --durations=0 40 | --strict-markers 41 | --doctest-modules 42 | filterwarnings = 43 | ignore::DeprecationWarning 44 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/src/__init__.py -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/src/callbacks/__init__.py -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/callbacks/onnx_callback.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Dict, Any 3 | 4 | import torch 5 | from pytorch_lightning import Callback 6 | import pytorch_lightning as pl 7 | import inspect 8 | from src.models.mdxnet import AbstractMDXNet 9 | 10 | 11 | class MakeONNXCallback(Callback): 12 | """Export the model to ONNX each time a checkpoint is saved.""" 13 | 14 | def __init__(self, dirpath: str): 15 | self.dirpath = dirpath 16 | if not os.path.exists(self.dirpath): 17 | os.mkdir(self.dirpath) 18 | 19 | def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', 20 | checkpoint: Dict[str, Any]) -> dict: 21 | res = super().on_save_checkpoint(trainer, pl_module, checkpoint) 22 | 23 |
var = inspect.signature(pl_module.__init__).parameters 24 | model = pl_module.__class__(**dict((name, pl_module.__dict__[name]) for name in var)) 25 | model.load_state_dict(pl_module.state_dict()) 26 | 27 | target_dir = '{}epoch_{}'.format(self.dirpath, pl_module.current_epoch) 28 | 29 | try: 30 | if not os.path.exists(target_dir): 31 | os.mkdir(target_dir) 32 | 33 | with torch.no_grad(): 34 | torch.onnx.export(model, 35 | torch.zeros(model.inference_chunk_shape), 36 | '{}/{}.onnx'.format(target_dir, model.target_name), 37 | export_params=True, # store the trained parameter weights inside the model file 38 | opset_version=13, # the ONNX version to export the model to 39 | do_constant_folding=True, # whether to execute constant folding for optimization 40 | input_names=['input'], # the model's input names 41 | output_names=['output'], # the model's output names 42 | dynamic_axes={'input': {0: 'batch_size'}, # variable length axes 43 | 'output': {0: 'batch_size'}}) 44 | except: 45 | print('onnx error') 46 | finally: 47 | del model 48 | 49 | return res 50 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/callbacks/wandb_callbacks.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from typing import List, Optional, Any 4 | 5 | import matplotlib.pyplot as plt 6 | import seaborn as sn 7 | import torch 8 | import wandb 9 | from pytorch_lightning import Callback, Trainer 10 | from pytorch_lightning.loggers import WandbLogger 11 | from pytorch_lightning.utilities.types import STEP_OUTPUT 12 | from sklearn import metrics 13 | from sklearn.metrics import f1_score, precision_score, recall_score 14 | 15 | 16 | def get_wandb_logger(trainer: Trainer) -> WandbLogger: 17 | """Safely get Weights&Biases logger from Trainer.""" 18 | 19 | if isinstance(trainer.logger, WandbLogger): 20 | return trainer.logger 21 | 22 | if isinstance(trainer.loggers, list): 23 | for logger in trainer.loggers: 24 | if isinstance(logger, WandbLogger): 25 | return logger 26 | 27 | raise Exception( 28 | "You are using wandb related callback, but WandbLogger was not found for some reason..." 
29 | ) 30 | 31 | 32 | class UploadValidTrack(Callback): 33 | def __init__(self, crop: int, upload_after_n_epoch: int): 34 | self.sample_length = crop * 44100 35 | self.upload_after_n_epoch = upload_after_n_epoch 36 | self.len_left_window = self.len_right_window = self.sample_length // 2 37 | 38 | def on_validation_batch_end( 39 | self, 40 | trainer: 'pl.Trainer', 41 | pl_module: 'pl.LightningModule', 42 | outputs: Optional[STEP_OUTPUT], 43 | batch: Any, 44 | batch_idx: int, 45 | dataloader_idx: int, 46 | ) -> None: 47 | if outputs is None: 48 | return 49 | track_id = outputs['track_id'] 50 | track = outputs['track'] 51 | 52 | logger = get_wandb_logger(trainer=trainer) 53 | experiment = logger.experiment 54 | if pl_module.current_epoch < self.upload_after_n_epoch: 55 | return None 56 | 57 | mid = track.shape[-1]//2 58 | track = track[:, mid-self.len_left_window:mid+self.len_right_window] 59 | 60 | experiment.log({'track={}_epoch={}'.format(track_id, pl_module.current_epoch): 61 | [wandb.Audio(track.T, sample_rate=44100)]}) 62 | 63 | 64 | class WatchModel(Callback): 65 | """Make wandb watch model at the beginning of the run.""" 66 | 67 | def __init__(self, log: str = "gradients", log_freq: int = 100): 68 | self.log = log 69 | self.log_freq = log_freq 70 | 71 | def on_train_start(self, trainer, pl_module): 72 | logger = get_wandb_logger(trainer=trainer) 73 | logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) 74 | 75 | 76 | class UploadCodeAsArtifact(Callback): 77 | """Upload all *.py files to wandb as an artifact, at the beginning of the run.""" 78 | 79 | def __init__(self, code_dir: str): 80 | self.code_dir = code_dir 81 | 82 | def on_train_start(self, trainer, pl_module): 83 | logger = get_wandb_logger(trainer=trainer) 84 | experiment = logger.experiment 85 | 86 | code = wandb.Artifact("project-source", type="code") 87 | for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True): 88 | code.add_file(path) 89 | 90 | experiment.use_artifact(code) 91 | 92 | 93 | class UploadCheckpointsAsArtifact(Callback): 94 | """Upload checkpoints to wandb as an artifact, at the end of run.""" 95 | 96 | def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): 97 | self.ckpt_dir = ckpt_dir 98 | self.upload_best_only = upload_best_only 99 | 100 | def on_train_end(self, trainer, pl_module): 101 | logger = get_wandb_logger(trainer=trainer) 102 | experiment = logger.experiment 103 | 104 | ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") 105 | 106 | if self.upload_best_only: 107 | ckpts.add_file(trainer.checkpoint_callback.best_model_path) 108 | else: 109 | for path in glob.glob(os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True): 110 | ckpts.add_file(path) 111 | 112 | experiment.use_artifact(ckpts) 113 | 114 | 115 | class LogConfusionMatrix(Callback): 116 | """Generate confusion matrix every epoch and send it to wandb. 117 | Expects validation step to return predictions and targets. 
118 | """ 119 | 120 | def __init__(self): 121 | self.preds = [] 122 | self.targets = [] 123 | self.ready = True 124 | 125 | def on_sanity_check_start(self, trainer, pl_module) -> None: 126 | self.ready = False 127 | 128 | def on_sanity_check_end(self, trainer, pl_module): 129 | """Start executing this callback only after all validation sanity checks end.""" 130 | self.ready = True 131 | 132 | def on_validation_batch_end( 133 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 134 | ): 135 | """Gather data from single batch.""" 136 | if self.ready: 137 | self.preds.append(outputs["preds"]) 138 | self.targets.append(outputs["targets"]) 139 | 140 | def on_validation_epoch_end(self, trainer, pl_module): 141 | """Generate confusion matrix.""" 142 | if self.ready: 143 | logger = get_wandb_logger(trainer) 144 | experiment = logger.experiment 145 | 146 | preds = torch.cat(self.preds).cpu().numpy() 147 | targets = torch.cat(self.targets).cpu().numpy() 148 | 149 | confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds) 150 | 151 | # set figure size 152 | plt.figure(figsize=(14, 8)) 153 | 154 | # set labels size 155 | sn.set(font_scale=1.4) 156 | 157 | # set font size 158 | sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g") 159 | 160 | # names should be unique or else charts from different experiments in wandb will overlap 161 | experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False) 162 | 163 | # according to wandb docs this should also work but it crashes 164 | # experiment.log(f{"confusion_matrix/{experiment.name}": plt}) 165 | 166 | # reset plot 167 | plt.clf() 168 | 169 | self.preds.clear() 170 | self.targets.clear() 171 | 172 | 173 | class LogF1PrecRecHeatmap(Callback): 174 | """Generate f1, precision, recall heatmap every epoch and send it to wandb. 175 | Expects validation step to return predictions and targets.
176 | """ 177 | 178 | def __init__(self, class_names: List[str] = None): 179 | self.preds = [] 180 | self.targets = [] 181 | self.ready = True 182 | 183 | def on_sanity_check_start(self, trainer, pl_module): 184 | self.ready = False 185 | 186 | def on_sanity_check_end(self, trainer, pl_module): 187 | """Start executing this callback only after all validation sanity checks end.""" 188 | self.ready = True 189 | 190 | def on_validation_batch_end( 191 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 192 | ): 193 | """Gather data from single batch.""" 194 | if self.ready: 195 | self.preds.append(outputs["preds"]) 196 | self.targets.append(outputs["targets"]) 197 | 198 | def on_validation_epoch_end(self, trainer, pl_module): 199 | """Generate f1, precision and recall heatmap.""" 200 | if self.ready: 201 | logger = get_wandb_logger(trainer=trainer) 202 | experiment = logger.experiment 203 | 204 | preds = torch.cat(self.preds).cpu().numpy() 205 | targets = torch.cat(self.targets).cpu().numpy() 206 | f1 = f1_score(preds, targets, average=None) 207 | r = recall_score(preds, targets, average=None) 208 | p = precision_score(preds, targets, average=None) 209 | data = [f1, p, r] 210 | 211 | # set figure size 212 | plt.figure(figsize=(14, 3)) 213 | 214 | # set labels size 215 | sn.set(font_scale=1.2) 216 | 217 | # set font size 218 | sn.heatmap( 219 | data, 220 | annot=True, 221 | annot_kws={"size": 10}, 222 | fmt=".3f", 223 | yticklabels=["F1", "Precision", "Recall"], 224 | ) 225 | 226 | # names should be unique or else charts from different experiments in wandb will overlap 227 | experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False) 228 | 229 | # reset plot 230 | plt.clf() 231 | 232 | self.preds.clear() 233 | self.targets.clear() 234 | 235 | 236 | class LogImagePredictions(Callback): 237 | """Logs a validation batch and their predictions to wandb.
238 | Example adapted from: 239 | https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY 240 | """ 241 | 242 | def __init__(self, num_samples: int = 8): 243 | super().__init__() 244 | self.num_samples = num_samples 245 | self.ready = True 246 | 247 | def on_sanity_check_start(self, trainer, pl_module): 248 | self.ready = False 249 | 250 | def on_sanity_check_end(self, trainer, pl_module): 251 | """Start executing this callback only after all validation sanity checks end.""" 252 | self.ready = True 253 | 254 | def on_validation_epoch_end(self, trainer, pl_module): 255 | if self.ready: 256 | logger = get_wandb_logger(trainer=trainer) 257 | experiment = logger.experiment 258 | 259 | # get a validation batch from the validation data loader 260 | val_samples = next(iter(trainer.datamodule.val_dataloader())) 261 | val_imgs, val_labels = val_samples 262 | 263 | # run the batch through the network 264 | val_imgs = val_imgs.to(device=pl_module.device) 265 | logits = pl_module(val_imgs) 266 | preds = torch.argmax(logits, axis=-1) 267 | 268 | # log the images as wandb Image 269 | experiment.log( 270 | { 271 | f"Images/{experiment.name}": [ 272 | wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") 273 | for x, pred, y in zip( 274 | val_imgs[: self.num_samples], 275 | preds[: self.num_samples], 276 | val_labels[: self.num_samples], 277 | ) 278 | ] 279 | } 280 | ) 281 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/src/datamodules/__init__.py -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/datamodules/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/src/datamodules/datasets/__init__.py -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/datamodules/datasets/musdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABCMeta, ABC 3 | from pathlib import Path 4 | 5 | import soundfile 6 | from torch.utils.data import Dataset 7 | import torch 8 | import numpy as np 9 | import random 10 | from tqdm import tqdm 11 | 12 | from src.utils.utils import load_wav 13 | from src import utils 14 | import numpy as np 15 | 16 | log = utils.get_pylogger(__name__) 17 | 18 | def check_target_name(target_name, source_names): 19 | try: 20 | assert target_name is not None 21 | except AssertionError: 22 | print('[ERROR] please identify target name.
ex) +datamodule.target_name="vocals"') 23 | exit(-1) 24 | try: 25 | assert target_name in source_names or target_name == 'all' 26 | except AssertionError: 27 | print('[ERROR] target name should be one of "bass", "drums", "other", "vocals", "all"') 28 | exit(-1) 29 | 30 | 31 | def check_sample_rate(sr, sample_track): 32 | try: 33 | sample_rate = soundfile.read(sample_track)[1] 34 | assert sample_rate == sr 35 | except AssertionError: 36 | sample_rate = soundfile.read(sample_track)[1] 37 | print('[ERROR] sampling rate mismatched') 38 | print('\t=> sr in Config file: {}, but sr of data: {}'.format(sr, sample_rate)) 39 | exit(-1) 40 | 41 | 42 | class MusdbDataset(Dataset): 43 | __metaclass__ = ABCMeta 44 | 45 | def __init__(self, data_dir, chunk_size): 46 | self.source_names = ['bass', 'drums', 'other', 'vocals'] 47 | self.chunk_size = chunk_size 48 | self.musdb_path = Path(data_dir) 49 | 50 | 51 | class MusdbTrainDataset(MusdbDataset): 52 | def __init__(self, data_dir, chunk_size, target_name, aug_params, external_datasets, single_channel, epoch_size): 53 | super(MusdbTrainDataset, self).__init__(data_dir, chunk_size) 54 | 55 | self.single_channel = single_channel 56 | self.neg_lst = [x for x in self.source_names if x != target_name] 57 | 58 | self.target_name = target_name 59 | check_target_name(self.target_name, self.source_names) 60 | 61 | if not self.musdb_path.joinpath('metadata').exists(): 62 | os.mkdir(self.musdb_path.joinpath('metadata')) 63 | 64 | splits = ['train'] 65 | if external_datasets is not None: 66 | splits += external_datasets 67 | 68 | # collect paths for datasets and metadata (track names and duration) 69 | datasets, metadata_caches = [], [] 70 | raw_datasets = [] # un-augmented datasets 71 | for split in splits: 72 | raw_datasets.append(self.musdb_path.joinpath(split)) 73 | max_pitch, max_tempo = aug_params 74 | for p in range(-max_pitch, max_pitch+1): 75 | for t in range(-max_tempo, max_tempo+1, 10): 76 | aug_split = split if p==t==0 else split + f'_p={p}_t={t}' 77 | datasets.append(self.musdb_path.joinpath(aug_split)) 78 | metadata_caches.append(self.musdb_path.joinpath('metadata').joinpath(aug_split + '.pkl')) 79 | 80 | # collect all track names and their duration 81 | self.metadata = [] 82 | raw_track_lengths = [] # for calculating epoch size 83 | for i, (dataset, metadata_cache) in enumerate(tqdm(zip(datasets, metadata_caches))): 84 | try: 85 | metadata = torch.load(metadata_cache) 86 | except FileNotFoundError: 87 | print('creating metadata for', dataset) 88 | metadata = [] 89 | for track_name in sorted(os.listdir(dataset)): 90 | track_path = dataset.joinpath(track_name) 91 | track_length = load_wav(track_path.joinpath('vocals.wav')).shape[-1] 92 | metadata.append((track_path, track_length)) 93 | torch.save(metadata, metadata_cache) 94 | 95 | self.metadata += metadata 96 | if dataset in raw_datasets: 97 | raw_track_lengths += [length for path, length in metadata] 98 | 99 | self.epoch_size = sum(raw_track_lengths) // self.chunk_size if epoch_size is None else epoch_size 100 | log.info(f'epoch size: {self.epoch_size}') 101 | 102 | def __getitem__(self, _): 103 | sources = [] 104 | for source_name in self.source_names: 105 | track_path, track_length = random.choice(self.metadata) # random mixing between tracks 106 | source = load_wav(track_path.joinpath(source_name + '.wav'), 107 | track_length=track_length, chunk_size=self.chunk_size) # (2, times) 108 | sources.append(source) 109 | 110 | mix = sum(sources) 111 | 112 | if self.target_name == 'all': 113 | # Targets for
models that separate all four sources (ex. Demucs). 114 | # This adds additional 'source' dimension => batch_shape=[batch, source, channel, time] 115 | target = sources 116 | else: 117 | target = sources[self.source_names.index(self.target_name)] 118 | 119 | mix, target = torch.tensor(mix), torch.tensor(target) 120 | if self.single_channel: 121 | mix = torch.mean(mix, dim=0, keepdim=True) 122 | target = torch.mean(target, dim=0, keepdim=True) 123 | return mix, target 124 | 125 | def __len__(self): 126 | return self.epoch_size 127 | 128 | 129 | class MusdbValidDataset(MusdbDataset): 130 | 131 | def __init__(self, data_dir, chunk_size, target_name, overlap, batch_size, single_channel): 132 | super(MusdbValidDataset, self).__init__(data_dir, chunk_size) 133 | 134 | self.target_name = target_name 135 | check_target_name(self.target_name, self.source_names) 136 | 137 | self.overlap = overlap 138 | self.batch_size = batch_size 139 | self.single_channel = single_channel 140 | 141 | musdb_valid_path = self.musdb_path.joinpath('valid') 142 | self.track_paths = [musdb_valid_path.joinpath(track_name) 143 | for track_name in os.listdir(musdb_valid_path)] 144 | 145 | def __getitem__(self, index): 146 | mix = load_wav(self.track_paths[index].joinpath('mixture.wav')) # (2, time) 147 | 148 | if self.target_name == 'all': 149 | # Targets for models that separate all four sources (ex. Demucs). 150 | # This adds additional 'source' dimension => batch_shape=[batch, source, channel, time] 151 | target = [load_wav(self.track_paths[index].joinpath(source_name + '.wav')) 152 | for source_name in self.source_names] 153 | else: 154 | target = load_wav(self.track_paths[index].joinpath(self.target_name + '.wav')) 155 | 156 | chunk_output_size = self.chunk_size - 2 * self.overlap 157 | left_pad = np.zeros([2, self.overlap]) 158 | right_pad = np.zeros([2, self.overlap + chunk_output_size - (mix.shape[-1] % chunk_output_size)]) 159 | mix_padded = np.concatenate([left_pad, mix, right_pad], 1) 160 | 161 | num_chunks = mix_padded.shape[-1] // chunk_output_size 162 | mix_chunks = np.array([mix_padded[:, i * chunk_output_size: i * chunk_output_size + self.chunk_size] 163 | for i in range(num_chunks)]) 164 | mix_chunk_batches = torch.tensor(mix_chunks, dtype=torch.float32).split(self.batch_size) 165 | target = torch.tensor(target) 166 | 167 | if self.single_channel: 168 | mix_chunk_batches = [torch.mean(t, dim=1, keepdim=True) for t in mix_chunk_batches] 169 | target = torch.mean(target, dim=0, keepdim=True) 170 | 171 | return mix_chunk_batches, target 172 | 173 | def __len__(self): 174 | return len(self.track_paths) -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/datamodules/musdb_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import exists, join 3 | from pathlib import Path 4 | from typing import Optional, Tuple 5 | 6 | from pytorch_lightning import LightningDataModule 7 | from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split 8 | 9 | from src.datamodules.datasets.musdb import MusdbTrainDataset, MusdbValidDataset 10 | 11 | 12 | class MusdbDataModule(LightningDataModule): 13 | """ 14 | LightningDataModule for Musdb18-HQ dataset. 
15 | A DataModule implements 5 key methods: 16 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode) 17 | - setup (things to do on every accelerator in distributed mode) 18 | - train_dataloader (the training dataloader) 19 | - val_dataloader (the validation dataloader(s)) 20 | - test_dataloader (the test dataloader(s)) 21 | This allows you to share a full dataset without explaining how to download, 22 | split, transform and process the data 23 | Read the docs: 24 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html 25 | """ 26 | 27 | def __init__( 28 | self, 29 | data_dir: str, 30 | aug_params, 31 | target_name: str, 32 | overlap: int, 33 | hop_length: int, 34 | dim_t: int, 35 | sample_rate: int, 36 | batch_size: int, 37 | num_workers: int, 38 | pin_memory: bool, 39 | external_datasets, 40 | audio_ch: int, 41 | epoch_size, 42 | **kwargs, 43 | ): 44 | super().__init__() 45 | 46 | self.data_dir = Path(data_dir) 47 | self.target_name = target_name 48 | self.aug_params = aug_params 49 | self.external_datasets = external_datasets 50 | 51 | self.batch_size = batch_size 52 | self.num_workers = num_workers 53 | self.pin_memory = pin_memory 54 | 55 | # audio-related 56 | self.hop_length = hop_length 57 | self.sample_rate = sample_rate 58 | self.single_channel = audio_ch == 1 59 | 60 | # derived 61 | self.chunk_size = hop_length * (dim_t - 1) 62 | self.overlap = overlap 63 | 64 | self.epoch_size = epoch_size 65 | 66 | self.data_train: Optional[Dataset] = None 67 | self.data_val: Optional[Dataset] = None 68 | self.data_test: Optional[Dataset] = None 69 | 70 | trainset_path = self.data_dir.joinpath('train') 71 | validset_path = self.data_dir.joinpath('valid') 72 | 73 | # create validation split 74 | if not exists(validset_path): 75 | from shutil import move 76 | os.mkdir(validset_path) 77 | for track in kwargs['validation_set']: 78 | if trainset_path.joinpath(track).exists(): 79 | move(trainset_path.joinpath(track), validset_path.joinpath(track)) 80 | else: 81 | valid_files = os.listdir(validset_path) 82 | assert set(valid_files) == set(kwargs['validation_set']) 83 | 84 | def setup(self, stage: Optional[str] = None): 85 | """Load data. 
Set variables: self.data_train, self.data_val, self.data_test.""" 86 | self.data_train = MusdbTrainDataset(self.data_dir, 87 | self.chunk_size, 88 | self.target_name, 89 | self.aug_params, 90 | self.external_datasets, 91 | self.single_channel, 92 | self.epoch_size) 93 | 94 | self.data_val = MusdbValidDataset(self.data_dir, 95 | self.chunk_size, 96 | self.target_name, 97 | self.overlap, 98 | self.batch_size, 99 | self.single_channel) 100 | 101 | def train_dataloader(self): 102 | return DataLoader( 103 | dataset=self.data_train, 104 | batch_size=self.batch_size, 105 | num_workers=self.num_workers, 106 | pin_memory=self.pin_memory, 107 | shuffle=True, 108 | ) 109 | 110 | def val_dataloader(self): 111 | return DataLoader( 112 | dataset=self.data_val, 113 | batch_size=1, 114 | num_workers=self.num_workers, 115 | pin_memory=self.pin_memory, 116 | shuffle=False, 117 | ) -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/dp_tdf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/src/dp_tdf/__init__.py -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/dp_tdf/abstract.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from pytorch_lightning import LightningModule 10 | from pytorch_lightning.utilities.types import STEP_OUTPUT 11 | 12 | from src.utils.utils import sdr, simplified_msseval 13 | 14 | 15 | class AbstractModel(LightningModule): 16 | __metaclass__ = ABCMeta 17 | 18 | def __init__(self, target_name, 19 | lr, optimizer, 20 | dim_f, dim_t, n_fft, hop_length, overlap, 21 | audio_ch, 22 | **kwargs): 23 | super().__init__() 24 | self.target_name = target_name 25 | self.lr = lr 26 | self.optimizer = optimizer 27 | self.dim_c_in = audio_ch * 2 28 | self.dim_c_out = audio_ch * 2 29 | self.dim_f = dim_f 30 | self.dim_t = dim_t 31 | self.n_fft = n_fft 32 | self.n_bins = n_fft // 2 + 1 33 | self.hop_length = hop_length 34 | self.audio_ch = audio_ch 35 | 36 | self.chunk_size = hop_length * (self.dim_t - 1) 37 | self.inference_chunk_size = hop_length * (self.dim_t*2 - 1) 38 | self.overlap = overlap 39 | self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False) 40 | self.freq_pad = nn.Parameter(torch.zeros([1, self.dim_c_out, self.n_bins - self.dim_f, 1]), requires_grad=False) 41 | self.inference_chunk_shape = (self.stft(torch.zeros([1, audio_ch, self.inference_chunk_size]))).shape 42 | 43 | 44 | def configure_optimizers(self): 45 | if self.optimizer == 'rmsprop': 46 | print("Using RMSprop optimizer") 47 | return torch.optim.RMSprop(self.parameters(), self.lr) 48 | elif self.optimizer == 'adamW': 49 | print("Using AdamW optimizer") 50 | return torch.optim.AdamW(self.parameters(), self.lr) 51 | 52 | def comp_loss(self, pred_detail, target_wave): 53 | pred_detail = self.istft(pred_detail) 54 | 55 | comp_loss = F.l1_loss(pred_detail, target_wave) 56 | 57 | self.log("train/comp_loss", comp_loss, sync_dist=True, on_step=False, on_epoch=True, prog_bar=False) 58 | 59 | return comp_loss 60 | 61 | 62 | def training_step(self, *args, **kwargs) -> STEP_OUTPUT: 
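# Training step overview: self.stft converts the mixture waveform to a spectrogram, the network predicts the target spectrogram, and comp_loss runs self.istft on that estimate and applies an L1 loss against the target waveform.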
63 | mix_wave, target_wave = args[0] # (batch, c, 261120) 64 | 65 | # input 1 66 | stft_44k = self.stft(mix_wave) # (batch, c*2, 1044, 256) 67 | # forward 68 | t_est_stft = self(stft_44k) # (batch, c, 1044, 256) 69 | 70 | loss = self.comp_loss(t_est_stft, target_wave) 71 | 72 | self.log("train/loss", loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True) 73 | 74 | return {"loss": loss} 75 | 76 | 77 | # Validation SDR is calculated on whole tracks and not chunks since 78 | # short inputs have high possibility of being silent (all-zero signal) 79 | # which leads to very low sdr values regardless of the model. 80 | # A natural procedure would be to split a track into chunk batches and 81 | # load them on multiple gpus, but aggregation was too difficult. 82 | # So instead we load one whole track on a single device (data_loader batch_size should always be 1) 83 | # and do all the batch splitting and aggregation on a single device. 84 | def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: 85 | mix_chunk_batches, target = args[0] 86 | 87 | # remove data_loader batch dimension 88 | # [(b, c, time)], (c, all_times) 89 | mix_chunk_batches, target = [batch[0] for batch in mix_chunk_batches], target[0] 90 | 91 | # process whole track in batches of chunks 92 | target_hat_chunks = [] 93 | for batch in mix_chunk_batches: 94 | # input 95 | stft_44k = self.stft(batch) # (batch, c*2, 1044, 256) 96 | pred_detail = self(stft_44k) # (batch, c, 1044, 256), irm 97 | pred_detail = self.istft(pred_detail) 98 | 99 | target_hat_chunks.append(pred_detail[..., self.overlap:-self.overlap]) 100 | target_hat_chunks = torch.cat(target_hat_chunks) # (b*len(ls),c,t) 101 | 102 | # concat all output chunks (c, all_times) 103 | target_hat = target_hat_chunks.transpose(0, 1).reshape(self.audio_ch, -1)[..., :target.shape[-1]] 104 | 105 | ests = target_hat.detach().cpu().numpy() # (c, all_times) 106 | references = target.cpu().numpy() 107 | score = sdr(ests, references) 108 | 109 | # (src, t, c) 110 | SDR = simplified_msseval(np.expand_dims(references.T, axis=0), np.expand_dims(ests.T, axis=0), chunk_size=44100) 111 | # self.log("val/sdr", score, sync_dist=True, on_step=False, on_epoch=True, logger=True) 112 | 113 | return {'song': score, 'chunk': SDR} 114 | 115 | def validation_epoch_end(self, outputs) -> None: 116 | avg_uSDR = torch.Tensor([x['song'] for x in outputs]).mean() 117 | self.log("val/usdr", avg_uSDR, sync_dist=True, on_step=False, on_epoch=True, logger=True) 118 | 119 | chunks = [x['chunk'][0, :] for x in outputs] 120 | # concat np array 121 | chunks = np.concatenate(chunks, axis=0) 122 | median_cSDR = np.nanmedian(chunks.flatten(), axis=0) 123 | median_cSDR = float(median_cSDR) 124 | self.log("val/csdr", median_cSDR, sync_dist=True, on_step=False, on_epoch=True, logger=True) 125 | 126 | def stft(self, x): 127 | ''' 128 | Args: 129 | x: (batch, c, 261120) 130 | ''' 131 | dim_b = x.shape[0] 132 | x = x.reshape([dim_b * self.audio_ch, -1]) # (batch*c, 261120) 133 | x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, center=True, return_complex=False) # (batch*c, 3073, 256, 2) 134 | x = x.permute([0, 3, 1, 2]) # (batch*c, 2, 3073, 256) 135 | x = x.reshape([dim_b, self.audio_ch, 2, self.n_bins, -1]).reshape([dim_b, self.audio_ch * 2, self.n_bins, -1]) # (batch, c*2, 3073, 256) 136 | return x[:, :, :self.dim_f] # (batch, c*2, 2048, 256) 137 | 138 | def istft(self, x): 139 | ''' 140 | Args: 141 | x: (batch, c*2, 2048, 256) 142 | ''' 143 | dim_b = x.shape[0] 144 | 
x = torch.cat([x, self.freq_pad.repeat([x.shape[0], 1, 1, x.shape[-1]])], -2) # (batch, c*2, 3073, 256) 145 | x = x.reshape([dim_b, self.audio_ch, 2, self.n_bins, -1]).reshape([dim_b * self.audio_ch, 2, self.n_bins, -1]) # (batch*c, 2, 3073, 256) 146 | x = x.permute([0, 2, 3, 1]) # (batch*c, 3073, 256, 2) 147 | x = torch.view_as_complex(x.contiguous()) # (batch*c, 3073, 256) 148 | x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, center=True) # (batch*c, 261120) 149 | return x.reshape([dim_b, self.audio_ch, -1]) # (batch,c,261120) 150 | 151 | def demix(self, mix, inf_chunk_size, batch_size=5, inf_overf=4): 152 | ''' 153 | Args: 154 | mix: (C, L) 155 | Returns: 156 | est: (src, C, L) 157 | ''' 158 | 159 | # batch_size = self.config.inference.batch_size 160 | # = self.chunk_size 161 | # self.instruments = ['bass', 'drums', 'other', 'vocals'] 162 | num_instruments = 1 163 | 164 | inf_hop = inf_chunk_size // inf_overf # hop size 165 | L = mix.shape[1] 166 | pad_size = inf_hop - (L - inf_chunk_size) % inf_hop 167 | mix = torch.cat([torch.zeros(2, inf_chunk_size - inf_hop), torch.Tensor(mix), torch.zeros(2, pad_size + inf_chunk_size - inf_hop)], 1) 168 | mix = mix.cuda() 169 | 170 | chunks = [] 171 | i = 0 172 | while i + inf_chunk_size <= mix.shape[1]: 173 | chunks.append(mix[:, i:i + inf_chunk_size]) 174 | i += inf_hop 175 | chunks = torch.stack(chunks) 176 | 177 | batches = [] 178 | i = 0 179 | while i < len(chunks): 180 | batches.append(chunks[i:i + batch_size]) 181 | i = i + batch_size 182 | 183 | X = torch.zeros(num_instruments, 2, inf_chunk_size - inf_hop) # (src, c, t) 184 | X = X.cuda() 185 | with torch.cuda.amp.autocast(): 186 | with torch.no_grad(): 187 | for batch in batches: 188 | x = self.stft(batch) 189 | x = self(x) 190 | x = self.istft(x) # (batch, c, 261120) 191 | # insert new axis, the model only predict 1 src so we need to add axis 192 | x = x[:,None, ...] 
# (batch, 1, c, 261120) 193 | x = x.repeat([ 1, num_instruments, 1, 1]) # (batch, src, c, 261120) 194 | for w in x: # iterate over batch 195 | a = X[..., :-(inf_chunk_size - inf_hop)] 196 | b = X[..., -(inf_chunk_size - inf_hop):] + w[..., :(inf_chunk_size - inf_hop)] 197 | c = w[..., (inf_chunk_size - inf_hop):] 198 | X = torch.cat([a, b, c], -1) 199 | 200 | estimated_sources = X[..., inf_chunk_size - inf_hop:-(pad_size + inf_chunk_size - inf_hop)] / inf_overf 201 | 202 | assert L == estimated_sources.shape[-1] 203 | 204 | return estimated_sources 205 | 206 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/dp_tdf/bandsequence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Original code from https://github.com/amanteur/BandSplitRNN-Pytorch 5 | class RNNModule(nn.Module): 6 | """ 7 | RNN submodule of BandSequence module 8 | """ 9 | 10 | def __init__( 11 | self, 12 | group_num: int, 13 | input_dim_size: int, 14 | hidden_dim_size: int, 15 | rnn_type: str = 'lstm', 16 | bidirectional: bool = True 17 | ): 18 | super(RNNModule, self).__init__() 19 | self.groupnorm = nn.GroupNorm(group_num, input_dim_size) 20 | self.rnn = getattr(nn, rnn_type)( 21 | input_dim_size, hidden_dim_size, batch_first=True, bidirectional=bidirectional # output is 2*hidden_dim_size because the RNN is bidirectional 22 | ) 23 | self.fc = nn.Linear( 24 | hidden_dim_size * 2 if bidirectional else hidden_dim_size, 25 | input_dim_size 26 | ) 27 | 28 | def forward( 29 | self, 30 | x: torch.Tensor 31 | ): 32 | """ 33 | Input shape: 34 | across T - [batch_size, k_subbands, time, n_features] 35 | OR 36 | across K - [batch_size, time, k_subbands, n_features] 37 | """ 38 | B, K, T, N = x.shape # across T across K (keep in mind T->K, K->T) 39 | # print(x.shape) 40 | 41 | out = x.view(B * K, T, N) # [BK, T, N] [BT, K, N] 42 | 43 | # print(out.shape) 44 | # print(self.groupnorm) 45 | out = self.groupnorm( 46 | out.transpose(-1, -2) 47 | ).transpose(-1, -2) # [BK, T, N] [BT, K, N] 48 | out = self.rnn(out)[0] # [BK, T, H] [BT, K, H], the last dim is the feature dim 49 | out = self.fc(out) # [BK, T, N] [BT, K, N] 50 | 51 | x = out.view(B, K, T, N) + x # [B, K, T, N] [B, T, K, N] 52 | 53 | x = x.permute(0, 2, 1, 3).contiguous() # [B, T, K, N] [B, K, T, N] 54 | return x 55 | 56 | 57 | class BandSequenceModelModule(nn.Module): 58 | """ 59 | BandSequence (2nd) Module of BandSplitRNN. 60 | Runs input through n BiLSTMs in two dimensions - time and subbands.
61 | """ 62 | 63 | def __init__( 64 | self, 65 | # group_num, 66 | input_dim_size: int, 67 | hidden_dim_size: int, 68 | rnn_type: str = 'lstm', 69 | bidirectional: bool = True, 70 | num_layers: int = 12, 71 | n_heads: int = 4, 72 | ): 73 | super(BandSequenceModelModule, self).__init__() 74 | 75 | self.bsrnn = nn.ModuleList([]) 76 | self.n_heads = n_heads 77 | 78 | input_dim_size = input_dim_size // n_heads 79 | hidden_dim_size = hidden_dim_size // n_heads 80 | group_num = input_dim_size // 16 81 | # print(f"input_dim_size: {input_dim_size}, hidden_dim_size: {hidden_dim_size}, group_num: {group_num}") 82 | 83 | # print(group_num, input_dim_size) 84 | 85 | for _ in range(num_layers): 86 | rnn_across_t = RNNModule( 87 | group_num, input_dim_size, hidden_dim_size, rnn_type, bidirectional 88 | ) 89 | rnn_across_k = RNNModule( 90 | group_num, input_dim_size, hidden_dim_size, rnn_type, bidirectional 91 | ) 92 | self.bsrnn.append( 93 | nn.Sequential(rnn_across_t, rnn_across_k) 94 | ) 95 | 96 | def forward(self, x: torch.Tensor): 97 | """ 98 | Input shape: [batch_size, k_subbands, time, n_features] 99 | Output shape: [batch_size, k_subbands, time, n_features] 100 | """ 101 | # x (b,c,t,f) 102 | b,c,t,f = x.shape 103 | x = x.view(b * self.n_heads, c // self.n_heads, t, f) # [b*n_heads, c//n_heads, t, f] 104 | 105 | x = x.permute(0, 3, 2, 1).contiguous() # [b*n_heads, f, t, c//n_heads] 106 | for i in range(len(self.bsrnn)): 107 | x = self.bsrnn[i](x) 108 | 109 | x = x.permute(0, 3, 2, 1).contiguous() # [b*n_heads, c//n_heads, t, f] 110 | x = x.view(b, c, t, f) # [b, c, t, f] 111 | return x 112 | 113 | 114 | if __name__ == '__main__': 115 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 116 | 117 | batch_size, k_subbands, t_timesteps, input_dim = 4, 41, 512, 128 118 | in_features = torch.rand(batch_size, k_subbands, t_timesteps, input_dim).to(device) 119 | 120 | cfg = { 121 | # "t_timesteps": t_timesteps, 122 | "group_num": 32, 123 | "input_dim_size": 128, 124 | "hidden_dim_size": 256, 125 | "rnn_type": "LSTM", 126 | "bidirectional": True, 127 | "num_layers": 1 128 | } 129 | model = BandSequenceModelModule(**cfg).to(device) 130 | _ = model.eval() 131 | 132 | with torch.no_grad(): 133 | out_features = model(in_features) 134 | 135 | print(f"In: {in_features.shape}\nOut: {out_features.shape}") 136 | print(f"Total number of parameters: {sum([p.numel() for p in model.parameters()])}") 137 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/dp_tdf/dp_tdf_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from src.dp_tdf.modules import TFC_TDF, TFC_TDF_Res1, TFC_TDF_Res2 5 | from src.dp_tdf.bandsequence import BandSequenceModelModule 6 | 7 | from src.layers import (get_norm) 8 | from src.dp_tdf.abstract import AbstractModel 9 | 10 | class DPTDFNet(AbstractModel): 11 | def __init__(self, num_blocks, l, g, k, bn, bias, bn_norm, bandsequence, block_type, **kwargs): 12 | 13 | super(DPTDFNet, self).__init__(**kwargs) 14 | # self.save_hyperparameters() 15 | 16 | self.num_blocks = num_blocks 17 | self.l = l 18 | self.g = g 19 | self.k = k 20 | self.bn = bn 21 | self.bias = bias 22 | 23 | self.n = num_blocks // 2 24 | scale = (2, 2) 25 | 26 | if block_type == "TFC_TDF": 27 | T_BLOCK = TFC_TDF 28 | elif block_type == "TFC_TDF_Res1": 29 | T_BLOCK = TFC_TDF_Res1 30 | elif block_type == "TFC_TDF_Res2": 31 | T_BLOCK = TFC_TDF_Res2 32 | else: 33 | 
raise ValueError(f"Unknown block type {block_type}") 34 | 35 | self.first_conv = nn.Sequential( 36 | nn.Conv2d(in_channels=self.dim_c_in, out_channels=g, kernel_size=(1, 1)), 37 | get_norm(bn_norm, g), 38 | nn.ReLU(), 39 | ) 40 | 41 | f = self.dim_f 42 | c = g 43 | self.encoding_blocks = nn.ModuleList() 44 | self.ds = nn.ModuleList() 45 | 46 | for i in range(self.n): 47 | c_in = c 48 | 49 | self.encoding_blocks.append(T_BLOCK(c_in, c, l, f, k, bn, bn_norm, bias=bias)) 50 | self.ds.append( 51 | nn.Sequential( 52 | nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale), 53 | get_norm(bn_norm, c + g), 54 | nn.ReLU() 55 | ) 56 | ) 57 | f = f // 2 58 | c += g 59 | 60 | self.bottleneck_block1 = T_BLOCK(c, c, l, f, k, bn, bn_norm, bias=bias) 61 | self.bottleneck_block2 = BandSequenceModelModule( 62 | **bandsequence, 63 | input_dim_size=c, 64 | hidden_dim_size=2*c 65 | ) 66 | 67 | self.decoding_blocks = nn.ModuleList() 68 | self.us = nn.ModuleList() 69 | for i in range(self.n): 70 | # print(f"i: {i}, in channels: {c}") 71 | self.us.append( 72 | nn.Sequential( 73 | nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale), 74 | get_norm(bn_norm, c - g), 75 | nn.ReLU() 76 | ) 77 | ) 78 | 79 | f = f * 2 80 | c -= g 81 | 82 | self.decoding_blocks.append(T_BLOCK(c, c, l, f, k, bn, bn_norm, bias=bias)) 83 | 84 | self.final_conv = nn.Sequential( 85 | nn.Conv2d(in_channels=c, out_channels=self.dim_c_out, kernel_size=(1, 1)), 86 | ) 87 | 88 | def forward(self, x): 89 | ''' 90 | Args: 91 | x: (batch, c*2, 2048, 256) 92 | ''' 93 | x = self.first_conv(x) 94 | 95 | x = x.transpose(-1, -2) 96 | 97 | ds_outputs = [] 98 | for i in range(self.n): 99 | x = self.encoding_blocks[i](x) 100 | ds_outputs.append(x) 101 | x = self.ds[i](x) 102 | 103 | # print(f"bottleneck in: {x.shape}") 104 | x = self.bottleneck_block1(x) 105 | x = self.bottleneck_block2(x) 106 | 107 | for i in range(self.n): 108 | x = self.us[i](x) 109 | # print(f"us{i} in: {x.shape}") 110 | # print(f"ds{i} out: {ds_outputs[-i - 1].shape}") 111 | x = x * ds_outputs[-i - 1] 112 | x = self.decoding_blocks[i](x) 113 | 114 | x = x.transpose(-1, -2) 115 | 116 | x = self.final_conv(x) 117 | 118 | return x -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/dp_tdf/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from src.layers import (get_norm) 5 | 6 | class TFC(nn.Module): 7 | def __init__(self, c_in, c_out, l, k, bn_norm): 8 | super(TFC, self).__init__() 9 | 10 | self.H = nn.ModuleList() 11 | for i in range(l): 12 | if i == 0: 13 | c_in = c_in 14 | else: 15 | c_in = c_out 16 | self.H.append( 17 | nn.Sequential( 18 | nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k, stride=1, padding=k // 2), 19 | get_norm(bn_norm, c_out), 20 | nn.ReLU(), 21 | ) 22 | ) 23 | 24 | def forward(self, x): 25 | for h in self.H: 26 | x = h(x) 27 | return x 28 | 29 | 30 | class DenseTFC(nn.Module): 31 | def __init__(self, c_in, c_out, l, k, bn_norm): 32 | super(DenseTFC, self).__init__() 33 | 34 | self.conv = nn.ModuleList() 35 | for i in range(l): 36 | self.conv.append( 37 | nn.Sequential( 38 | nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k, stride=1, padding=k // 2), 39 | get_norm(bn_norm, c_out), 40 | nn.ReLU(), 41 | ) 42 | ) 43 | 44 | def forward(self, x): 45 | for layer in self.conv[:-1]: 46 | x = torch.cat([layer(x), x], 1) 47 | return 
self.conv[-1](x) 48 | 49 | 50 | class TFC_TDF(nn.Module): 51 | def __init__(self, c_in, c_out, l, f, k, bn, bn_norm, dense=False, bias=True): 52 | 53 | super(TFC_TDF, self).__init__() 54 | 55 | self.use_tdf = bn is not None 56 | 57 | self.tfc = DenseTFC(c_in, c_out, l, k, bn_norm) if dense else TFC(c_in, c_out, l, k, bn_norm) 58 | 59 | if self.use_tdf: 60 | if bn == 0: 61 | # print(f"TDF={f},{f}") 62 | self.tdf = nn.Sequential( 63 | nn.Linear(f, f, bias=bias), 64 | get_norm(bn_norm, c_out), 65 | nn.ReLU() 66 | ) 67 | else: 68 | # print(f"TDF={f},{f // bn},{f}") 69 | self.tdf = nn.Sequential( 70 | nn.Linear(f, f // bn, bias=bias), 71 | get_norm(bn_norm, c_out), 72 | nn.ReLU(), 73 | nn.Linear(f // bn, f, bias=bias), 74 | get_norm(bn_norm, c_out), 75 | nn.ReLU() 76 | ) 77 | 78 | def forward(self, x): 79 | x = self.tfc(x) 80 | return x + self.tdf(x) if self.use_tdf else x 81 | 82 | 83 | class TFC_TDF_Res1(nn.Module): 84 | def __init__(self, c_in, c_out, l, f, k, bn, bn_norm, dense=False, bias=True): 85 | 86 | super(TFC_TDF_Res1, self).__init__() 87 | 88 | self.use_tdf = bn is not None 89 | 90 | self.tfc = DenseTFC(c_in, c_out, l, k, bn_norm) if dense else TFC(c_in, c_out, l, k, bn_norm) 91 | 92 | self.res = TFC(c_in, c_out, 1, k, bn_norm) 93 | 94 | if self.use_tdf: 95 | if bn == 0: 96 | # print(f"TDF={f},{f}") 97 | self.tdf = nn.Sequential( 98 | nn.Linear(f, f, bias=bias), 99 | get_norm(bn_norm, c_out), 100 | nn.ReLU() 101 | ) 102 | else: 103 | # print(f"TDF={f},{f // bn},{f}") 104 | self.tdf = nn.Sequential( 105 | nn.Linear(f, f // bn, bias=bias), 106 | get_norm(bn_norm, c_out), 107 | nn.ReLU(), 108 | nn.Linear(f // bn, f, bias=bias), 109 | get_norm(bn_norm, c_out), 110 | nn.ReLU() 111 | ) 112 | 113 | def forward(self, x): 114 | res = self.res(x) 115 | x = self.tfc(x) 116 | x = x + res 117 | return x + self.tdf(x) if self.use_tdf else x 118 | 119 | 120 | class TFC_TDF_Res2(nn.Module): 121 | def __init__(self, c_in, c_out, l, f, k, bn, bn_norm, dense=False, bias=True): 122 | 123 | super(TFC_TDF_Res2, self).__init__() 124 | 125 | self.use_tdf = bn is not None 126 | 127 | self.tfc1 = TFC(c_in, c_out, l, k, bn_norm) 128 | self.tfc2 = TFC(c_in, c_out, l, k, bn_norm) 129 | 130 | self.res = TFC(c_in, c_out, 1, k, bn_norm) 131 | 132 | if self.use_tdf: 133 | if bn == 0: 134 | # print(f"TDF={f},{f}") 135 | self.tdf = nn.Sequential( 136 | nn.Linear(f, f, bias=bias), 137 | get_norm(bn_norm, c_out), 138 | nn.ReLU() 139 | ) 140 | else: 141 | # print(f"TDF={f},{f // bn},{f}") 142 | self.tdf = nn.Sequential( 143 | nn.Linear(f, f // bn, bias=bias), 144 | get_norm(bn_norm, c_out), 145 | nn.ReLU(), 146 | nn.Linear(f // bn, f, bias=bias), 147 | get_norm(bn_norm, c_out), 148 | nn.ReLU() 149 | ) 150 | 151 | def forward(self, x): 152 | res = self.res(x) 153 | x = self.tfc1(x) 154 | if self.use_tdf: 155 | x = x + self.tdf(x) 156 | x = self.tfc2(x) 157 | x = x + res 158 | return x 159 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/evaluation/eval.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from pathlib import Path 3 | from typing import Optional, List 4 | 5 | from concurrent import futures 6 | import hydra 7 | import wandb 8 | import os 9 | import shutil 10 | from omegaconf import DictConfig 11 | from pytorch_lightning import LightningDataModule, LightningModule 12 | from pytorch_lightning.loggers import Logger, WandbLogger 13 | import soundfile as sf 14 | 15 | from tqdm import 
tqdm 16 | import numpy as np 17 | from src.callbacks.wandb_callbacks import get_wandb_logger 18 | from src.evaluation.separate import separate_with_onnx_TDF, separate_with_ckpt_TDF 19 | from src.utils import utils 20 | from src.utils.utils import load_wav, sdr, get_median_csdr, save_results, get_metrics 21 | 22 | from src.utils import pylogger 23 | 24 | log = pylogger.get_pylogger(__name__) 25 | 26 | 27 | def evaluation(config: DictConfig): 28 | 29 | assert config.split in ['train', 'valid', 'test'] 30 | 31 | data_dir = Path(config.get('eval_dir')).joinpath(config['split']) 32 | assert data_dir.exists() 33 | 34 | # Init Lightning loggers 35 | loggers: List[Logger] = [] 36 | if "logger" in config: 37 | for _, lg_conf in config.logger.items(): 38 | if "_target_" in lg_conf: 39 | log.info(f"Instantiating logger <{lg_conf._target_}>") 40 | loggers.append(hydra.utils.instantiate(lg_conf)) 41 | 42 | if any([isinstance(l, WandbLogger) for l in loggers]): 43 | utils.wandb_login(key=config.wandb_api_key) 44 | 45 | model = hydra.utils.instantiate(config.model) 46 | target_name = model.target_name 47 | ckpt_path = Path(config.ckpt_path) 48 | is_onnx = os.path.split(ckpt_path)[-1].split('.')[-1] == 'onnx' 49 | shutil.copy(ckpt_path,os.getcwd()) # copy model 50 | 51 | ssdrs = [] 52 | bss_lst = [] 53 | bss_perms = [] 54 | num_tracks = len(listdir(data_dir)) 55 | target_list = [config.model.target_name,"complement"] 56 | 57 | 58 | pool = futures.ProcessPoolExecutor 59 | with pool(config.pool_workers) as pool: 60 | datas = sorted(listdir(data_dir)) 61 | if len(datas) > 27: # if not debugging 62 | # move idx 27 to head 63 | datas = [datas[27]] + datas[:27] + datas[28:] 64 | # iterate datas with batchsize 8 65 | for k in range(0, len(datas), config.pool_workers): 66 | batch = datas[k:k + config.pool_workers] 67 | pendings = [] 68 | for i, track in tqdm(enumerate(batch)): 69 | folder_name = track 70 | track = data_dir.joinpath(track) 71 | mixture = load_wav(track.joinpath('mixture.wav')) # (c, t) 72 | target = load_wav(track.joinpath(target_name + '.wav')) 73 | 74 | if model.audio_ch == 1: 75 | mixture = np.mean(mixture, axis=0, keepdims=True) 76 | target = np.mean(target, axis=0, keepdims=True) 77 | #target_hat = {source: separate(config['batch_size'], models[source], onnxs[source], mixture) for source in sources} 78 | if is_onnx: 79 | target_hat = separate_with_onnx_TDF(config.batch_size, model, ckpt_path, mixture) 80 | else: 81 | target_hat = separate_with_ckpt_TDF(config.batch_size, model, ckpt_path, mixture, config.device, config.double_chunk, config.overlap_add) 82 | 83 | 84 | pendings.append((folder_name, pool.submit( 85 | get_metrics, target_hat, target, mixture, sr=44100,version=config.bss))) 86 | 87 | for wandb_logger in [logger for logger in loggers if isinstance(logger, WandbLogger)]: 88 | mid = mixture.shape[-1] // 2 89 | track = target_hat[:, mid - 44100 * 3:mid + 44100 * 3] 90 | wandb_logger.experiment.log( 91 | {f'track={k+i}_target={target_name}': [wandb.Audio(track.T, sample_rate=44100)]}) 92 | 93 | 94 | for i, (track_name, pending) in tqdm(enumerate(pendings)): 95 | pending = pending.result() 96 | bssmetrics, perms, ssdr = pending 97 | bss_lst.append(bssmetrics) 98 | bss_perms.append(perms) 99 | ssdrs.append(ssdr) 100 | 101 | for logger in loggers: 102 | logger.log_metrics({'song/ssdr': ssdr}, k+i) 103 | logger.log_metrics({'song/csdr': get_median_csdr([bssmetrics])}, k+i) 104 | 105 | log_dir = os.getcwd() 106 | save_results(log_dir, bss_lst, target_list, bss_perms, ssdrs) 107 | 108 | cSDR 
= get_median_csdr(bss_lst) 109 | uSDR = sum(ssdrs)/num_tracks 110 | for logger in loggers: 111 | logger.log_metrics({'metrics/mean_sdr_' + target_name: sum(ssdrs)/num_tracks}) 112 | logger.log_metrics({'metrics/median_csdr_' + target_name: get_median_csdr(bss_lst)}) 113 | # get the path of the log dir 114 | if not isinstance(logger, WandbLogger): 115 | logger.experiment.close() 116 | 117 | if any([isinstance(logger, WandbLogger) for logger in loggers]): 118 | wandb.finish() 119 | 120 | return cSDR, uSDR 121 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/evaluation/eval_demo.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from pathlib import Path 3 | from typing import Optional, List 4 | 5 | from concurrent import futures 6 | import hydra 7 | import wandb 8 | import os 9 | import shutil 10 | from omegaconf import DictConfig 11 | from pytorch_lightning import LightningDataModule, LightningModule 12 | from pytorch_lightning.loggers import Logger, WandbLogger 13 | 14 | from tqdm import tqdm 15 | import numpy as np 16 | from src.callbacks.wandb_callbacks import get_wandb_logger 17 | from src.evaluation.separate import separate_with_onnx_TDF, separate_with_ckpt_TDF 18 | from src.utils import utils 19 | from src.utils.utils import load_wav, sdr, get_median_csdr, save_results, get_metrics 20 | 21 | from src.utils import pylogger 22 | import soundfile as sf 23 | log = pylogger.get_pylogger(__name__) 24 | 25 | 26 | def evaluation(config: DictConfig, idx): 27 | 28 | assert config.split in ['train', 'valid', 'test'] 29 | 30 | data_dir = Path(config.get('eval_dir')).joinpath(config['split']) 31 | assert data_dir.exists() 32 | 33 | model = hydra.utils.instantiate(config.model) 34 | target_name = model.target_name 35 | ckpt_path = Path(config.ckpt_path) 36 | is_onnx = os.path.split(ckpt_path)[-1].split('.')[-1] == 'onnx' 37 | shutil.copy(ckpt_path,os.getcwd()) # copy model 38 | 39 | datas = sorted(listdir(data_dir)) 40 | if len(datas) > 27: # if not debugging 41 | # move idx 27 to head 42 | datas = [datas[27]] + datas[:27] + datas[28:] 43 | 44 | 45 | track = datas[idx] 46 | track = data_dir.joinpath(track) 47 | print(track) 48 | mixture = load_wav(track.joinpath('mixture.wav')) # (c, t) 49 | target = load_wav(track.joinpath(target_name + '.wav')) 50 | if model.audio_ch == 1: 51 | mixture = np.mean(mixture, axis=0, keepdims=True) 52 | target = np.mean(target, axis=0, keepdims=True) 53 | #target_hat = {source: separate(config['batch_size'], models[source], onnxs[source], mixture) for source in sources} 54 | if is_onnx: 55 | target_hat = separate_with_onnx_TDF(config.batch_size, model, ckpt_path, mixture) 56 | else: 57 | target_hat = separate_with_ckpt_TDF(config.batch_size, model, ckpt_path, mixture, config.device, config.double_chunk, overlap_factor=config.overlap_factor) 58 | 59 | bssmetrics, perms, ssdr = get_metrics(target_hat, target, mixture, sr=44100,version=config.bss) 60 | # dump bssmetrics into pkl 61 | import pickle 62 | with open(os.path.join(os.getcwd(),'bssmetrics.pkl'),'wb') as f: 63 | pickle.dump(bssmetrics,f) 64 | 65 | return bssmetrics 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/layers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .batch_norm import * 
-------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/layers/batch_norm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | __all__ = ["IBN", "get_norm"] 8 | 9 | 10 | class BatchNorm(nn.BatchNorm2d): 11 | def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0, 12 | bias_init=0.0, **kwargs): 13 | super().__init__(num_features, eps=eps, momentum=momentum) 14 | if weight_init is not None: nn.init.constant_(self.weight, weight_init) 15 | if bias_init is not None: nn.init.constant_(self.bias, bias_init) 16 | self.weight.requires_grad_(not weight_freeze) 17 | self.bias.requires_grad_(not bias_freeze) 18 | 19 | 20 | class SyncBatchNorm(nn.SyncBatchNorm): 21 | def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0, 22 | bias_init=0.0): 23 | super().__init__(num_features, eps=eps, momentum=momentum) 24 | if weight_init is not None: nn.init.constant_(self.weight, weight_init) 25 | if bias_init is not None: nn.init.constant_(self.bias, bias_init) 26 | self.weight.requires_grad_(not weight_freeze) 27 | self.bias.requires_grad_(not bias_freeze) 28 | 29 | 30 | class IBN(nn.Module): 31 | def __init__(self, planes, bn_norm, **kwargs): 32 | super(IBN, self).__init__() 33 | half1 = int(planes / 2) 34 | self.half = half1 35 | half2 = planes - half1 36 | self.IN = nn.InstanceNorm2d(half1, affine=True) 37 | self.BN = get_norm(bn_norm, half2, **kwargs) 38 | 39 | def forward(self, x): 40 | split = torch.split(x, self.half, 1) 41 | out1 = self.IN(split[0].contiguous()) 42 | out2 = self.BN(split[1].contiguous()) 43 | out = torch.cat((out1, out2), 1) 44 | return out 45 | 46 | 47 | class GhostBatchNorm(BatchNorm): 48 | def __init__(self, num_features, num_splits=1, **kwargs): 49 | super().__init__(num_features, **kwargs) 50 | self.num_splits = num_splits 51 | self.register_buffer('running_mean', torch.zeros(num_features)) 52 | self.register_buffer('running_var', torch.ones(num_features)) 53 | 54 | def forward(self, input): 55 | N, C, H, W = input.shape 56 | if self.training or not self.track_running_stats: 57 | self.running_mean = self.running_mean.repeat(self.num_splits) 58 | self.running_var = self.running_var.repeat(self.num_splits) 59 | outputs = F.batch_norm( 60 | input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var, 61 | self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits), 62 | True, self.momentum, self.eps).view(N, C, H, W) 63 | self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0) 64 | self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0) 65 | return outputs 66 | else: 67 | return F.batch_norm( 68 | input, self.running_mean, self.running_var, 69 | self.weight, self.bias, False, self.momentum, self.eps) 70 | 71 | 72 | class FrozenBatchNorm(nn.Module): 73 | """ 74 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 75 | It contains non-trainable buffers called 76 | "weight" and "bias", "running_mean", "running_var", 77 | initialized to perform identity transformation. 78 | The pre-trained backbone models from Caffe2 only contain "weight" and "bias", 79 | which are computed from the original four parameters of BN. 
80 | The affine transform `x * weight + bias` will perform the equivalent 81 | computation of `(x - running_mean) / sqrt(running_var) * weight + bias`. 82 | When loading a backbone model from Caffe2, "running_mean" and "running_var" 83 | will be left unchanged as identity transformation. 84 | Other pre-trained backbone models may contain all 4 parameters. 85 | The forward is implemented by `F.batch_norm(..., training=False)`. 86 | """ 87 | 88 | _version = 3 89 | 90 | def __init__(self, num_features, eps=1e-5, **kwargs): 91 | super().__init__() 92 | self.num_features = num_features 93 | self.eps = eps 94 | self.register_buffer("weight", torch.ones(num_features)) 95 | self.register_buffer("bias", torch.zeros(num_features)) 96 | self.register_buffer("running_mean", torch.zeros(num_features)) 97 | self.register_buffer("running_var", torch.ones(num_features) - eps) 98 | 99 | def forward(self, x): 100 | if x.requires_grad: 101 | # When gradients are needed, F.batch_norm will use extra memory 102 | # because its backward op computes gradients for weight/bias as well. 103 | scale = self.weight * (self.running_var + self.eps).rsqrt() 104 | bias = self.bias - self.running_mean * scale 105 | scale = scale.reshape(1, -1, 1, 1) 106 | bias = bias.reshape(1, -1, 1, 1) 107 | return x * scale + bias 108 | else: 109 | # When gradients are not needed, F.batch_norm is a single fused op 110 | # and provide more optimization opportunities. 111 | return F.batch_norm( 112 | x, 113 | self.running_mean, 114 | self.running_var, 115 | self.weight, 116 | self.bias, 117 | training=False, 118 | eps=self.eps, 119 | ) 120 | 121 | def _load_from_state_dict( 122 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 123 | ): 124 | version = local_metadata.get("version", None) 125 | 126 | if version is None or version < 2: 127 | # No running_mean/var in early versions 128 | # This will silent the warnings 129 | if prefix + "running_mean" not in state_dict: 130 | state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean) 131 | if prefix + "running_var" not in state_dict: 132 | state_dict[prefix + "running_var"] = torch.ones_like(self.running_var) 133 | 134 | if version is not None and version < 3: 135 | logger = logging.getLogger(__name__) 136 | logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip("."))) 137 | # In version < 3, running_var are used without +eps. 138 | state_dict[prefix + "running_var"] -= self.eps 139 | 140 | super()._load_from_state_dict( 141 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 142 | ) 143 | 144 | def __repr__(self): 145 | return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps) 146 | 147 | @classmethod 148 | def convert_frozen_batchnorm(cls, module): 149 | """ 150 | Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm. 151 | Args: 152 | module (torch.nn.Module): 153 | Returns: 154 | If module is BatchNorm/SyncBatchNorm, returns a new module. 155 | Otherwise, in-place convert module and return it. 
156 | Similar to convert_sync_batchnorm in 157 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py 158 | """ 159 | bn_module = nn.modules.batchnorm 160 | bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm) 161 | res = module 162 | if isinstance(module, bn_module): 163 | res = cls(module.num_features) 164 | if module.affine: 165 | res.weight.data = module.weight.data.clone().detach() 166 | res.bias.data = module.bias.data.clone().detach() 167 | res.running_mean.data = module.running_mean.data 168 | res.running_var.data = module.running_var.data 169 | res.eps = module.eps 170 | else: 171 | for name, child in module.named_children(): 172 | new_child = cls.convert_frozen_batchnorm(child) 173 | if new_child is not child: 174 | res.add_module(name, new_child) 175 | return res 176 | 177 | 178 | def get_norm(norm, out_channels, **kwargs): 179 | """ 180 | Args: 181 | norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN; 182 | or a callable that takes a channel number and returns 183 | the normalization layer as a nn.Module 184 | out_channels: number of channels for normalization layer 185 | 186 | Returns: 187 | nn.Module or None: the normalization layer 188 | """ 189 | # return nn.BatchNorm2d(out_channels) 190 | 191 | if isinstance(norm, str): 192 | if len(norm) == 0: 193 | return None 194 | norm = { 195 | "BN": BatchNorm, 196 | "syncBN": SyncBatchNorm, 197 | "GhostBN": GhostBatchNorm, 198 | "FrozenBN": FrozenBatchNorm, 199 | "GN": lambda channels, **args: nn.GroupNorm(32, channels), 200 | }[norm] 201 | return norm(out_channels, **kwargs) 202 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/layers/chunk_size.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import numpy as np 3 | import torch 4 | 5 | 6 | 7 | def wave_to_batches(mix, inf_ck, overlap, batch_size): 8 | ''' 9 | Args: 10 | mix: (2, N) numpy array 11 | inf_ck: int, the chunk size as the model input (contains 2*overlap) 12 | inf_ck = overlap + true_samples + overlap 13 | overlap: int, the discarded samples at each side 14 | Returns: 15 | a tuples of batches, each batch is a (batch, 2, inf_ck) torch tensor 16 | ''' 17 | true_samples = inf_ck - 2 * overlap 18 | channels = mix.shape[0] 19 | 20 | right_pad = true_samples + overlap - ((mix.shape[-1]) % true_samples) 21 | mixture = np.concatenate((np.zeros((channels, overlap), dtype='float32'), 22 | mix, 23 | np.zeros((channels, right_pad), dtype='float32')), 24 | 1) 25 | 26 | num_chunks = mixture.shape[-1] // true_samples 27 | mix_waves_batched = np.array([mixture[:, i * true_samples: i * true_samples + inf_ck] for i in 28 | range(num_chunks)]) # (x,2,inf_ck) 29 | return torch.tensor(mix_waves_batched, dtype=torch.float32).split(batch_size) 30 | 31 | def batches_to_wave(target_hat_chunks, overlap, org_len): 32 | ''' 33 | Args: 34 | target_hat_chunks: a list of (batch, 2, inf_ck) torch tensors 35 | overlap: int, the discarded samples at each side 36 | org_len: int, the original length of the mixture 37 | Returns: 38 | (2, N) numpy array 39 | ''' 40 | target_hat_chunks = [c[..., overlap:-overlap] for c in target_hat_chunks] 41 | target_hat_chunks = torch.cat(target_hat_chunks) 42 | 43 | # concat all output chunks 44 | return target_hat_chunks.transpose(0, 1).reshape(2, -1)[..., :org_len].detach().cpu().numpy() 45 | 46 | if __name__ == '__main__': 47 | mix = np.random.rand(2, 14318640) 48 | inf_ck = 261120 49 | overlap = 3072 50 
| batch_size = 8 51 | out = wave_to_batches(mix, inf_ck, overlap, batch_size) 52 | in_wav = batches_to_wave(out, overlap, mix.shape[-1]) 53 | print(in_wav.shape) -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/train.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import hydra 4 | import pytorch_lightning as pl 5 | import pyrootutils 6 | import torch 7 | import os 8 | import shutil 9 | from omegaconf import DictConfig 10 | from pytorch_lightning import ( 11 | Callback, 12 | LightningDataModule, 13 | LightningModule, 14 | Trainer, 15 | seed_everything, 16 | ) 17 | from pytorch_lightning.loggers import WandbLogger 18 | from hydra.core.hydra_config import HydraConfig 19 | 20 | from src import utils 21 | 22 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 23 | 24 | log = utils.get_pylogger(__name__) 25 | 26 | 27 | @utils.task_wrapper 28 | def train(cfg: DictConfig) -> Optional[float]: 29 | """Contains training pipeline. 30 | Instantiates all PyTorch Lightning objects from config. 31 | 32 | Args: 33 | cfg (DictConfig): Configuration composed by Hydra. 34 | 35 | Returns: 36 | Optional[float]: Metric score for hyperparameter optimization. 37 | """ 38 | 39 | # Set seed for random number generators in pytorch, numpy and python.random 40 | try: 41 | if "seed" in cfg: 42 | # set seed for random number generators in pytorch, numpy and python.random 43 | if cfg.get("seed"): 44 | pl.seed_everything(cfg.seed, workers=True) 45 | 46 | else: 47 | raise ModuleNotFoundError 48 | 49 | except ModuleNotFoundError: 50 | print('[Error] seed should be fixed for reproducibility \n=> e.g. python run.py +seed=$SEED') 51 | exit(-1) 52 | 53 | # Init Lightning datamodule 54 | log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") 55 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) 56 | 57 | # Init Lightning model 58 | log.info(f"Instantiating model <{cfg.model._target_}>") 59 | model: LightningModule = hydra.utils.instantiate(cfg.model) 60 | 61 | # Init Lightning callbacks 62 | callbacks: List[Callback] = [] 63 | if "callbacks" in cfg: 64 | for _, cb_conf in cfg["callbacks"].items(): 65 | if "_target_" in cb_conf: 66 | log.info(f"Instantiating callback <{cb_conf._target_}>") 67 | callbacks.append(hydra.utils.instantiate(cb_conf)) 68 | 69 | # Init Lightning loggers 70 | if "resume_from_checkpoint" in cfg.trainer: 71 | ckpt_path = cfg.trainer.resume_from_checkpoint 72 | # get the parent directory of the checkpoint path 73 | log_dir = os.path.dirname(os.path.dirname(ckpt_path)) 74 | tensorboard_dir = os.path.join(log_dir, "tensorboard") 75 | if os.path.exists(tensorboard_dir): 76 | # copy tensorboard dir to the parent directory of the checkpoint path 77 | # HydraConfig.get().run.dir returns new dir so do not use it! 
(now fixed) 78 | shutil.copytree(tensorboard_dir,os.path.join(os.getcwd(),"tensorboard")) 79 | 80 | wandb_dir = os.path.join(log_dir, "wandb") 81 | if os.path.exists(wandb_dir): 82 | shutil.copytree(wandb_dir,os.path.join(os.getcwd(),"wandb")) 83 | 84 | 85 | logger: List = [] 86 | if "logger" in cfg: 87 | for _, lg_conf in cfg["logger"].items(): 88 | if "_target_" in lg_conf: 89 | log.info(f"Instantiating logger <{lg_conf._target_}>") 90 | logger.append(hydra.utils.instantiate(lg_conf)) 91 | 92 | for wandb_logger in [l for l in logger if isinstance(l, WandbLogger)]: 93 | utils.wandb_login(key=cfg.wandb_api_key) 94 | # utils.wandb_watch_all(wandb_logger, model) # TODO buggy 95 | break 96 | 97 | # Init Lightning trainer 98 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 99 | # get env variable use_gloo 100 | use_gloo = os.environ.get("USE_GLOO", False) 101 | if use_gloo: 102 | from pytorch_lightning.strategies import DDPStrategy 103 | ddp = DDPStrategy(process_group_backend='gloo') 104 | trainer: Trainer = hydra.utils.instantiate( 105 | cfg.trainer, strategy=ddp, callbacks=callbacks, logger=logger, _convert_="partial" 106 | ) 107 | else: 108 | trainer: Trainer = hydra.utils.instantiate( 109 | cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial" 110 | ) 111 | 112 | # Send some parameters from config to all lightning loggers 113 | log.info("Logging hyperparameters!") 114 | utils.log_hyperparameters( 115 | dict( 116 | cfg=cfg, 117 | model=model, 118 | datamodule=datamodule, 119 | trainer=trainer, 120 | callbacks=callbacks, 121 | logger=logger, 122 | ) 123 | ) 124 | 125 | # Train the model 126 | log.info("Starting training!") 127 | trainer.fit(model=model, datamodule=datamodule) 128 | 129 | # Evaluate model on test set after training 130 | # if not cfg.trainer.get("fast_dev_run"): 131 | # log.info("Starting testing!") 132 | # trainer.test() 133 | 134 | # Make sure everything closed properly 135 | log.info("Finalizing!") 136 | # utils.finish( 137 | # config=cfg, 138 | # model=model, 139 | # datamodule=datamodule, 140 | # trainer=trainer, 141 | # callbacks=callbacks, 142 | # logger=logger, 143 | # ) 144 | 145 | # Print path to best checkpoint 146 | # log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}") 147 | 148 | # Return metric score for hyperparameter optimization 149 | # optimized_metric = cfg.get("optimized_metric") 150 | # if optimized_metric: 151 | # return trainer.callback_metrics[optimized_metric] 152 | return None, None 153 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.pylogger import get_pylogger 2 | from src.utils.rich_utils import enforce_tags, print_config_tree 3 | from src.utils.utils import * 4 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/utils/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess as sp 3 | import tempfile 4 | import warnings 5 | from argparse import ArgumentParser 6 | from concurrent import futures 7 | 8 | import numpy as np 9 | import soundfile as sf 10 | import torch 11 | from tqdm import tqdm 12 | 13 | warnings.simplefilter(action='ignore', category=Warning) 14 | source_names = ['vocals', 'drums', 'bass', 'other'] 15 | sample_rate = 44100 16 | 17 | def main (args): 
18 | data_root = args.data_dir 19 | train = args.train 20 | test = args.test 21 | valid = args.valid 22 | 23 | musdb_train_path = data_root + 'train/' 24 | musdb_test_path = data_root + 'test/' 25 | musdb_valid_path = data_root + 'valid/' 26 | print(f"train={train}, test={test}, valid={valid}") 27 | 28 | mix_name = 'mixture' 29 | 30 | P = [-3, -2, -1, 0, 1, 2, 3] # pitch shift amounts (in semitones) 31 | T = [-30, -20, -10, 0, 10, 20, 30] # time stretch amounts (10 means 10% slower) 32 | 33 | pool = futures.ProcessPoolExecutor 34 | pool_workers = 13 35 | pendings = [] 36 | with pool(pool_workers) as pool: 37 | for p in P: 38 | for t in T: 39 | if not (p==0 and t==0): 40 | if train: 41 | pendings.append(pool.submit(save_shifted_dataset, p, t, musdb_train_path)) 42 | # save_shifted_dataset(p, t, musdb_train_path) 43 | if valid: 44 | save_shifted_dataset(p, t, musdb_valid_path) 45 | if test: 46 | save_shifted_dataset(p, t, musdb_test_path) 47 | for pending in pendings: 48 | pending.result() 49 | 50 | 51 | def shift(wav, pitch, tempo, voice=False, quick=False, samplerate=44100): 52 | def i16_pcm(wav): 53 | if wav.dtype == np.int16: 54 | return wav 55 | return (wav * 2 ** 15).clamp_(-2 ** 15, 2 ** 15 - 1).short() 56 | 57 | def f32_pcm(wav): 58 | if wav.dtype == np.float: 59 | return wav 60 | return wav.float() / 2 ** 15 61 | 62 | """ 63 | tempo is a relative delta in percentage, so tempo=10 means tempo at 110%! 64 | pitch is in semi tones. 65 | Requires `soundstretch` to be installed, see 66 | https://www.surina.net/soundtouch/soundstretch.html 67 | """ 68 | 69 | inputfile = tempfile.NamedTemporaryFile(dir="/root/autodl-tmp/tmp", suffix=".wav") 70 | outfile = tempfile.NamedTemporaryFile(dir="/root/autodl-tmp/tmp", suffix=".wav") 71 | 72 | sf.write(inputfile.name, data=i16_pcm(wav).t().numpy(), samplerate=samplerate, format='WAV') 73 | command = [ 74 | "soundstretch", 75 | inputfile.name, 76 | outfile.name, 77 | f"-pitch={pitch}", 78 | f"-tempo={tempo:.6f}", 79 | ] 80 | if quick: 81 | command += ["-quick"] 82 | if voice: 83 | command += ["-speech"] 84 | try: 85 | sp.run(command, capture_output=True, check=True) 86 | except sp.CalledProcessError as error: 87 | raise RuntimeError(f"Could not change bpm because {error.stderr.decode('utf-8')}") 88 | wav, sr = sf.read(outfile.name, dtype='float32') 89 | # wav = np.float32(wav) 90 | # wav = f32_pcm(torch.from_numpy(wav).t()) 91 | assert sr == samplerate 92 | return wav 93 | 94 | 95 | def save_shifted_dataset(delta_pitch, delta_tempo, data_path): 96 | out_path = data_path[:-1] + f'_p={delta_pitch}_t={delta_tempo}/' 97 | try: 98 | os.mkdir(out_path) 99 | except FileExistsError: 100 | pass 101 | track_names = list(filter(lambda x: os.path.isdir(f'{data_path}/{x}'), sorted(os.listdir(data_path)))) 102 | for track_name in tqdm(track_names): 103 | try: 104 | os.mkdir(f'{out_path}/{track_name}') 105 | except FileExistsError: 106 | pass 107 | for s_name in source_names: 108 | source = load_wav(f'{data_path}/{track_name}/{s_name}.wav') 109 | shifted = shift( 110 | torch.tensor(source), 111 | delta_pitch, 112 | delta_tempo, 113 | voice=s_name == 'vocals') 114 | sf.write(f'{out_path}/{track_name}/{s_name}.wav', shifted, samplerate=sample_rate, format='WAV') 115 | 116 | 117 | def load_wav(path, sr=None): 118 | return sf.read(path, samplerate=sr, dtype='float32')[0].T 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = ArgumentParser() 123 | parser.add_argument('--data_dir', type=str) 124 | parser.add_argument('--train', type=bool, default=True) 125 | 
parser.add_argument('--valid', type=bool, default=False) 126 | parser.add_argument('--test', type=bool, default=False) 127 | 128 | main(parser.parse_args()) -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/utils/omega_resolvers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from src.utils import get_unique_save_path 3 | 4 | cur_suffix = -1 5 | 6 | def get_train_log_dir(a, b, *, _parent_): 7 | global cur_suffix 8 | logbase = os.environ.get("LOG_DIR", "dtt_logs") 9 | if not os.path.exists(logbase): 10 | os.mkdir(logbase) 11 | dir_path = os.path.join(logbase, f"{a}_{b}") 12 | if cur_suffix == -1: 13 | cur_suffix = get_unique_save_path(dir_path) 14 | return f"{dir_path}_{str(cur_suffix)}" 15 | else: 16 | return f"{dir_path}_{str(cur_suffix)}" 17 | 18 | cur_suffix1 = -1 19 | 20 | def get_eval_log_dir(ckpt_path, *, _parent_): 21 | # get environment variable logbase 22 | global cur_suffix1 23 | logbase = os.environ.get("LOG_DIR", "dtt_logs") 24 | if not os.path.exists(logbase): 25 | os.mkdir(logbase) 26 | ckpt_name = os.path.basename(ckpt_path) 27 | # remove the suffix 28 | ckpt_name = ckpt_name.split(".")[0] 29 | dir_path = os.path.join(logbase, f"eval_{ckpt_name}") 30 | if cur_suffix1 == -1: 31 | cur_suffix1 = get_unique_save_path(dir_path) 32 | return f"{dir_path}_{str(cur_suffix1)}" 33 | else: 34 | return f"{dir_path}_{str(cur_suffix1)}" 35 | 36 | cur_suffix2 = -1 37 | 38 | def get_sweep_log_dir(a, b, *, _parent_): 39 | global cur_suffix2 40 | logbase = os.environ.get("LOG_DIR", "dtt_logs") 41 | if not os.path.exists(logbase): 42 | os.mkdir(logbase) 43 | dir_path = os.path.join(logbase, f"m_{a}_{b}") 44 | if cur_suffix2 == -1: 45 | cur_suffix2 = get_unique_save_path(dir_path) 46 | return f"{dir_path}_{str(cur_suffix2)}" 47 | else: 48 | return f"{dir_path}_{str(cur_suffix2)}" -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/utils/pick_best.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import traceback 4 | import pandas as pd 5 | import wandb 6 | import numpy as np 7 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 8 | 9 | class AbstractLoggerQuery: 10 | def get_value_steps(self, tag): 11 | raise NotImplementedError 12 | 13 | # def steps_to_epoch(self, steps): 14 | # raise NotImplementedError 15 | 16 | def get_best(self, tag, method="min"): 17 | v, s = self.get_value_steps(tag) 18 | indexes = [i for i in range(len(s))] 19 | # sort the indexes by v 20 | if method == "min": 21 | indexes.sort(key=lambda x: v[x]) 22 | else: 23 | indexes.sort(key=lambda x: -v[x]) 24 | # get the top 5 indexes 25 | top5_idx = indexes[:3] 26 | # the the values 27 | top5_values = [v[i] for i in top5_idx] 28 | top5_steps = [s[i] for i in top5_idx] 29 | return top5_idx, top5_values, top5_steps 30 | 31 | def get_best_epochs(self, tag, method="min"): 32 | v, e = self.get_value_epochs(tag) 33 | indexes = [i for i in range(len(e))] 34 | # sort the indexes by v 35 | if method == "min": 36 | indexes.sort(key=lambda x: v[x]) 37 | else: 38 | indexes.sort(key=lambda x: -v[x]) 39 | # get the top 5 indexes 40 | top5_idx = indexes[:3] 41 | # the the values 42 | top5_values = [v[i] for i in top5_idx] 43 | top5_epochs = [e[i] for i in top5_idx] 44 | return top5_idx, top5_values, top5_epochs 45 | 46 | def get_last_ep(self, tag): 47 | v, s = 
self.get_value_steps(tag) 48 | return len(s) - 1 49 | 50 | def _bsearch(self, arr, target_step): 51 | l = 0 52 | r = len(arr) - 1 53 | while l < r: 54 | mid = (l + r + 1) // 2 55 | if arr[mid] <= target_step: 56 | l = mid 57 | else: 58 | r = mid - 1 59 | return l 60 | 61 | def steps_to_epoch(self, steps): 62 | ep_ep, ep_step = self.get_value_steps('epoch') 63 | mapped = [] 64 | for vs in steps: 65 | idx = self._bsearch(ep_step, vs) 66 | mapped.append(ep_ep[idx]) 67 | return mapped 68 | 69 | def get_value_epochs(self, tag): 70 | raise NotImplementedError 71 | 72 | 73 | class TFLoggerQuery(AbstractLoggerQuery): 74 | def __init__(self, log_dir): 75 | self.event_acc = EventAccumulator(log_dir) 76 | self.event_acc.Reload() 77 | 78 | def get_value_steps(self, tag): 79 | event_list = self.event_acc.Scalars(tag) 80 | values = list(map(lambda x: x.value, event_list)) 81 | step = list(map(lambda x: x.step, event_list)) 82 | return values, step 83 | 84 | # def steps_to_epoch(self, steps): 85 | # v, s = self.get_value_steps('epoch') 86 | # # zip the two list 87 | # return [int(v[s.index(step)]) for step in steps] 88 | 89 | def get_tags(self): 90 | return self.event_acc.Tags()["scalars"] 91 | 92 | 93 | class WandbLoggerQuery(AbstractLoggerQuery): 94 | def __init__(self, api_key, path): 95 | self.api_key = api_key 96 | self.path = path 97 | 98 | def _get_run(self): 99 | return wandb.Api(api_key=self.api_key).run(self.path) 100 | 101 | def get_exp_name(self): 102 | run = self._get_run() 103 | return run.name 104 | 105 | def get_value_steps(self, tag): 106 | run = self._get_run() 107 | # run.history returns sampled data, do not use it 108 | # history = run.history(keys=[tag]) 109 | history = run.scan_history(keys=[tag, "_step"]) 110 | losses = [row for row in history] 111 | df = pd.DataFrame(losses) 112 | df = df.sort_values(by="_step") 113 | values = df[tag].values 114 | step = df["_step"].values 115 | return values, step 116 | 117 | def get_value_epochs(self, tag, sort_by="epoch"): 118 | run = self._get_run() 119 | history = run.scan_history(keys=[tag, "epoch"]) 120 | losses = [row for row in history] 121 | df = pd.DataFrame(losses) 122 | df = df.sort_values(by=sort_by) 123 | values = df[tag].values 124 | epoch = df["epoch"].values 125 | return values, epoch 126 | 127 | 128 | 129 | def steps_to_runtime(self, steps): 130 | rt_ep, rt_step = self.get_value_steps('_runtime') 131 | mapped = [] 132 | for vs in steps: 133 | idx = self._bsearch(rt_step, vs) 134 | secs = rt_ep[idx] 135 | # seconds to days hours minutes seconds 136 | m, s = divmod(secs, 60) 137 | h, m = divmod(m, 60) 138 | d, h = divmod(h, 24) 139 | mapped.append(f"{d:.0f}d {h:.0f}h {m:.0f}m {s:.0f}s") 140 | return mapped 141 | 142 | def epochs_to_runtime(self, epochs): 143 | rt_ep, ep_lst = self.get_value_epochs('_runtime') 144 | mapped = [] 145 | for vs in epochs: 146 | idx = self._bsearch(ep_lst, vs) 147 | secs = rt_ep[idx] 148 | # seconds to days hours minutes seconds 149 | m, s = divmod(secs, 60) 150 | h, m = divmod(m, 60) 151 | d, h = divmod(h, 24) 152 | mapped.append(f"{d:.0f}d {h:.0f}h {m:.0f}m {s:.0f}s") 153 | return mapped 154 | 155 | def _secs_to_string(self, secs): 156 | # seconds to days hours minutes seconds 157 | m, s = divmod(secs, 60) 158 | h, m = divmod(m, 60) 159 | d, h = divmod(h, 24) 160 | return f"{d:.0f}d {h:.0f}h {m:.0f}m {s:.0f}s" 161 | 162 | def avg_runtime_per_ep(self): 163 | epochs = [i for i in range(1, 100)] 164 | rt_ep, ep_lst = self.get_value_epochs('_runtime', sort_by="_runtime") # 
gotcha: must sort by runtime here, because epochs contains many duplicated values 165 | mapped = [] 166 | for vs in epochs: 167 | idx = self._bsearch(ep_lst, vs) 168 | secs = rt_ep[idx] 169 | mapped.append(secs) 170 | new = [] 171 | for i in range(1, len(mapped)): 172 | new.append(mapped[i] - mapped[i-1]) 173 | meansecs = np.mean(new) 174 | print(new) 175 | return self._secs_to_string(meansecs) 176 | 177 | 178 | if __name__ == "__main__": 179 | tag = "val/usdr" 180 | api_key = "xxxxxxxxxxxx" 181 | path = "username/dtt_vocals/xxxxxxx" 182 | 183 | w_query = WandbLoggerQuery(api_key=api_key, path=path) 184 | best_sdr_idx, best_sdr_values, best_sdr_epochs = w_query.get_best_epochs(tag, method="max") 185 | print("idx", best_sdr_idx) 186 | print("values", best_sdr_values) 187 | print("epochs", best_sdr_epochs) 188 | print("runtime", w_query.epochs_to_runtime(best_sdr_epochs)) -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pytorch_lightning.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name=__name__) -> logging.Logger: 7 | """Initializes multi-GPU-friendly python command line logger.""" 8 | 9 | logger = logging.getLogger(name) 10 | 11 | # this ensures all logging levels get marked with the rank zero decorator 12 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 13 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 14 | for level in logging_levels: 15 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 16 | 17 | return logger -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from omegaconf import DictConfig, OmegaConf, open_dict 9 | from pytorch_lightning.utilities import rank_zero_only 10 | from rich.prompt import Prompt 11 | 12 | from src.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints content of DictConfig using Rich library and its tree structure. 33 | 34 | Args: 35 | cfg (DictConfig): Configuration composed by Hydra. 36 | print_order (Sequence[str], optional): Determines in what order config components are printed. 37 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 38 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 39 | """ 40 | 41 | style = "dim" 42 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 43 | 44 | queue = [] 45 | 46 | # add fields from `print_order` to queue 47 | for field in print_order: 48 | queue.append(field) if field in cfg else log.warning( 49 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
50 | ) 51 | 52 | # add all the other fields to queue (not specified in `print_order`) 53 | for field in cfg: 54 | if field not in queue: 55 | queue.append(field) 56 | 57 | # generate config tree from queue 58 | for field in queue: 59 | branch = tree.add(field, style=style, guide_style=style) 60 | 61 | config_group = cfg[field] 62 | if isinstance(config_group, DictConfig): 63 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 64 | else: 65 | branch_content = str(config_group) 66 | 67 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 68 | 69 | # print config tree 70 | rich.print(tree) 71 | 72 | # save config tree to file 73 | if save_to_file: 74 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 75 | rich.print(tree, file=file) 76 | 77 | 78 | @rank_zero_only 79 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 80 | """Prompts user to input tags from command line if no tags are provided in config.""" 81 | 82 | if not cfg.get("tags"): 83 | if "id" in HydraConfig().cfg.hydra.job: 84 | raise ValueError("Specify tags before launching a multirun!") 85 | 86 | log.warning("No tags provided in config. Prompting user to input tags...") 87 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 88 | tags = [t.strip() for t in tags.split(",") if t != ""] 89 | 90 | with open_dict(cfg): 91 | cfg.tags = tags 92 | 93 | log.info(f"Tags: {cfg.tags}") 94 | 95 | if save_to_file: 96 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 97 | rich.print(cfg.tags, file=file) -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/tests/__init__.py -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/tests/helpers/__init__.py -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/helpers/module_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | from importlib.util import find_spec 3 | 4 | """ 5 | Adapted from: 6 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/imports.py 7 | """ 8 | 9 | 10 | def _module_available(module_path: str) -> bool: 11 | """Check if a path is available in your environment. 
12 | 13 | >>> _module_available('os') 14 | True 15 | >>> _module_available('bla.bla') 16 | False 17 | 18 | """ 19 | try: 20 | return find_spec(module_path) is not None 21 | except AttributeError: 22 | # Python 3.6 23 | return False 24 | except ModuleNotFoundError: 25 | # Python 3.7+ 26 | return False 27 | 28 | 29 | _IS_WINDOWS = platform.system() == "Windows" 30 | _APEX_AVAILABLE = _module_available("apex.amp") 31 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available("deepspeed") 32 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") 33 | _RPC_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.rpc") 34 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/helpers/run_command.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | import sh 5 | 6 | 7 | def run_command(command: List[str]): 8 | """Default method for executing shell commands with pytest.""" 9 | msg = None 10 | try: 11 | sh.python(command) 12 | except sh.ErrorReturnCode as e: 13 | msg = e.stderr.decode() 14 | if msg: 15 | pytest.fail(msg=msg) 16 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/helpers/runif.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Optional 3 | 4 | import pytest 5 | import torch 6 | from packaging.version import Version 7 | from pkg_resources import get_distribution 8 | 9 | """ 10 | Adapted from: 11 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 12 | """ 13 | 14 | from tests.helpers.module_available import ( 15 | _APEX_AVAILABLE, 16 | _DEEPSPEED_AVAILABLE, 17 | _FAIRSCALE_AVAILABLE, 18 | _IS_WINDOWS, 19 | _RPC_AVAILABLE, 20 | ) 21 | 22 | 23 | class RunIf: 24 | """ 25 | RunIf wrapper for conditional skipping of tests. 26 | Fully compatible with `@pytest.mark`. 
27 | 28 | Example: 29 | 30 | @RunIf(min_torch="1.8") 31 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 32 | def test_wrapper(arg1): 33 | assert arg1 > 0 34 | 35 | """ 36 | 37 | def __new__( 38 | self, 39 | min_gpus: int = 0, 40 | min_torch: Optional[str] = None, 41 | max_torch: Optional[str] = None, 42 | min_python: Optional[str] = None, 43 | amp_apex: bool = False, 44 | skip_windows: bool = False, 45 | rpc: bool = False, 46 | fairscale: bool = False, 47 | deepspeed: bool = False, 48 | **kwargs, 49 | ): 50 | """ 51 | Args: 52 | min_gpus: min number of gpus required to run test 53 | min_torch: minimum pytorch version to run test 54 | max_torch: maximum pytorch version to run test 55 | min_python: minimum python version required to run test 56 | amp_apex: NVIDIA Apex is installed 57 | skip_windows: skip test for Windows platform 58 | rpc: requires Remote Procedure Call (RPC) 59 | fairscale: if `fairscale` module is required to run the test 60 | deepspeed: if `deepspeed` module is required to run the test 61 | kwargs: native pytest.mark.skipif keyword arguments 62 | """ 63 | conditions = [] 64 | reasons = [] 65 | 66 | if min_gpus: 67 | conditions.append(torch.cuda.device_count() < min_gpus) 68 | reasons.append(f"GPUs>={min_gpus}") 69 | 70 | if min_torch: 71 | torch_version = get_distribution("torch").version 72 | conditions.append(Version(torch_version) < Version(min_torch)) 73 | reasons.append(f"torch>={min_torch}") 74 | 75 | if max_torch: 76 | torch_version = get_distribution("torch").version 77 | conditions.append(Version(torch_version) >= Version(max_torch)) 78 | reasons.append(f"torch<{max_torch}") 79 | 80 | if min_python: 81 | py_version = ( 82 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 83 | ) 84 | conditions.append(Version(py_version) < Version(min_python)) 85 | reasons.append(f"python>={min_python}") 86 | 87 | if amp_apex: 88 | conditions.append(not _APEX_AVAILABLE) 89 | reasons.append("NVIDIA Apex") 90 | 91 | if skip_windows: 92 | conditions.append(_IS_WINDOWS) 93 | reasons.append("does not run on Windows") 94 | 95 | if rpc: 96 | conditions.append(not _RPC_AVAILABLE) 97 | reasons.append("RPC") 98 | 99 | if fairscale: 100 | conditions.append(not _FAIRSCALE_AVAILABLE) 101 | reasons.append("Fairscale") 102 | 103 | if deepspeed: 104 | conditions.append(not _DEEPSPEED_AVAILABLE) 105 | reasons.append("Deepspeed") 106 | 107 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 108 | return pytest.mark.skipif( 109 | condition=any(conditions), 110 | reason=f"Requires: [{' + '.join(reasons)}]", 111 | **kwargs, 112 | ) 113 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/smoke/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/tests/smoke/__init__.py -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/smoke/test_commands.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | from tests.helpers.runif import RunIf 5 | 6 | 7 | def test_fast_dev_run(): 8 | """Run 1 train, val, test batch.""" 9 | command = ["run.py", "trainer=default", "trainer.fast_dev_run=true"] 10 | run_command(command) 11 | 12 | 13 | def test_default_cpu(): 14 | 
"""Test default configuration on CPU.""" 15 | command = ["run.py", "trainer.max_epochs=1", "trainer.gpus=0"] 16 | run_command(command) 17 | 18 | 19 | @RunIf(min_gpus=1) 20 | def test_default_gpu(): 21 | """Test default configuration on GPU.""" 22 | command = [ 23 | "run.py", 24 | "trainer.max_epochs=1", 25 | "trainer.gpus=1", 26 | "datamodule.pin_memory=True", 27 | ] 28 | run_command(command) 29 | 30 | 31 | @pytest.mark.slow 32 | def test_experiments(): 33 | """Train 1 epoch with all experiment configs.""" 34 | command = ["run.py", "-m", "experiment=glob(*)", "trainer.max_epochs=1"] 35 | run_command(command) 36 | 37 | 38 | def test_limit_batches(): 39 | """Train 1 epoch on 25% of data.""" 40 | command = [ 41 | "run.py", 42 | "trainer=default", 43 | "trainer.max_epochs=1", 44 | "trainer.limit_train_batches=0.25", 45 | "trainer.limit_val_batches=0.25", 46 | "trainer.limit_test_batches=0.25", 47 | ] 48 | run_command(command) 49 | 50 | 51 | def test_gradient_accumulation(): 52 | """Train 1 epoch with gradient accumulation.""" 53 | command = [ 54 | "run.py", 55 | "trainer=default", 56 | "trainer.max_epochs=1", 57 | "trainer.accumulate_grad_batches=10", 58 | ] 59 | run_command(command) 60 | 61 | 62 | def test_double_validation_loop(): 63 | """Train 1 epoch with validation loop twice per epoch.""" 64 | command = [ 65 | "run.py", 66 | "trainer=default", 67 | "trainer.max_epochs=1", 68 | "trainer.val_check_interval=0.5", 69 | ] 70 | run_command(command) 71 | 72 | 73 | def test_csv_logger(): 74 | """Train 5 epochs with 5 batches with CSVLogger.""" 75 | command = [ 76 | "run.py", 77 | "trainer=default", 78 | "trainer.max_epochs=5", 79 | "trainer.limit_train_batches=5", 80 | "logger=csv", 81 | ] 82 | run_command(command) 83 | 84 | 85 | def test_tensorboard_logger(): 86 | """Train 5 epochs with 5 batches with TensorboardLogger.""" 87 | command = [ 88 | "run.py", 89 | "trainer=default", 90 | "trainer.max_epochs=5", 91 | "trainer.limit_train_batches=5", 92 | "logger=tensorboard", 93 | ] 94 | run_command(command) 95 | 96 | 97 | def test_overfit_batches(): 98 | """Overfit to 10 batches over 10 epochs.""" 99 | command = [ 100 | "run.py", 101 | "trainer=default", 102 | "trainer.min_epochs=10", 103 | "trainer.max_epochs=10", 104 | "trainer.overfit_batches=10", 105 | ] 106 | run_command(command) 107 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/smoke/test_mixed_precision.py: -------------------------------------------------------------------------------- 1 | from tests.helpers.run_command import run_command 2 | from tests.helpers.runif import RunIf 3 | 4 | 5 | @RunIf(amp_apex=True) 6 | def test_apex_01(): 7 | """Test mixed-precision level 01.""" 8 | command = [ 9 | "run.py", 10 | "trainer=default", 11 | "trainer.max_epochs=1", 12 | "trainer.gpus=1", 13 | "trainer.amp_backend=apex", 14 | "trainer.amp_level=O1", 15 | "trainer.precision=16", 16 | ] 17 | run_command(command) 18 | 19 | 20 | @RunIf(amp_apex=True) 21 | def test_apex_02(): 22 | """Test mixed-precision level 02.""" 23 | command = [ 24 | "run.py", 25 | "trainer=default", 26 | "trainer.max_epochs=1", 27 | "trainer.gpus=1", 28 | "trainer.amp_backend=apex", 29 | "trainer.amp_level=O2", 30 | "trainer.precision=16", 31 | ] 32 | run_command(command) 33 | 34 | 35 | @RunIf(amp_apex=True) 36 | def test_apex_03(): 37 | """Test mixed-precision level 03.""" 38 | command = [ 39 | "run.py", 40 | "trainer=default", 41 | "trainer.max_epochs=1", 42 | "trainer.gpus=1", 43 | 
"trainer.amp_backend=apex", 44 | "trainer.amp_level=O3", 45 | "trainer.precision=16", 46 | ] 47 | run_command(command) 48 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/smoke/test_sweeps.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | 5 | """ 6 | Use the following command to skip slow tests: 7 | pytest -k "not slow" 8 | """ 9 | 10 | 11 | @pytest.mark.slow 12 | def test_default_sweep(): 13 | """Test default Hydra sweeper.""" 14 | command = [ 15 | "run.py", 16 | "-m", 17 | "datamodule.batch_size=64,128", 18 | "model.lr=0.01,0.02", 19 | "trainer=default", 20 | "trainer.fast_dev_run=true", 21 | ] 22 | run_command(command) 23 | 24 | 25 | @pytest.mark.slow 26 | def test_optuna_sweep(): 27 | """Test Optuna sweeper.""" 28 | command = [ 29 | "run.py", 30 | "-m", 31 | "hparams_search=mnist_optuna", 32 | "trainer=default", 33 | "trainer.fast_dev_run=true", 34 | ] 35 | run_command(command) 36 | 37 | 38 | @pytest.mark.skip(reason="TODO: Add Ax sweep config.") 39 | @pytest.mark.slow 40 | def test_ax_sweep(): 41 | """Test Ax sweeper.""" 42 | command = ["run.py", "-m", "hparams_search=mnist_ax", "trainer.fast_dev_run=true"] 43 | run_command(command) 44 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/smoke/test_wandb.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | 5 | """ 6 | Use the following command to skip slow tests: 7 | pytest -k "not slow" 8 | """ 9 | 10 | 11 | # @pytest.mark.slow 12 | # def test_wandb_optuna_sweep(): 13 | # """Test wandb logging with Optuna sweep.""" 14 | # command = [ 15 | # "run.py", 16 | # "-m", 17 | # "hparams_search=mnist_optuna", 18 | # "trainer=default", 19 | # "trainer.max_epochs=10", 20 | # "trainer.limit_train_batches=20", 21 | # "logger=wandb", 22 | # "logger.wandb.project=template-tests", 23 | # "logger.wandb.group=Optuna_SimpleDenseNet_MNIST", 24 | # "hydra.sweeper.n_trials=5", 25 | # ] 26 | # run_command(command) 27 | 28 | 29 | # @pytest.mark.slow 30 | # def test_wandb_callbacks(): 31 | # """Test wandb callbacks.""" 32 | # command = [ 33 | # "run.py", 34 | # "trainer=default", 35 | # "trainer.max_epochs=3", 36 | # "logger=wandb", 37 | # "logger.wandb.project=template-tests", 38 | # "callbacks=wandb", 39 | # ] 40 | # run_command(command) 41 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/submit/to_onnx.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/tests/submit/to_onnx.py -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yongyizang/TrainingFreeMultiStepASR/49333f1050455f25649cbec902bba0d527561b75/MusicSourceSeparation/infer/tests/unit/__init__.py -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/tests/unit/test_sth.py: -------------------------------------------------------------------------------- 1 | 
import pytest 2 | 3 | from tests.helpers.runif import RunIf 4 | 5 | 6 | def test_something1(): 7 | """Some test description.""" 8 | assert True is True 9 | 10 | 11 | def test_something2(): 12 | """Some test description.""" 13 | assert 1 + 1 == 2 14 | 15 | 16 | @pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0]) 17 | def test_something3(arg1: float): 18 | """Some test description.""" 19 | assert arg1 > 0 20 | 21 | 22 | # use RunIf to skip execution of some tests, e.g. when not on windows or when no gpus are available 23 | @RunIf(skip_windows=True, min_gpus=1) 24 | def test_something4(): 25 | """Some test description.""" 26 | assert True is True 27 | -------------------------------------------------------------------------------- /MusicSourceSeparation/infer/train.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import hydra 3 | from omegaconf import DictConfig, OmegaConf 4 | from src.utils.omega_resolvers import get_train_log_dir, get_sweep_log_dir 5 | # load environment variables from `.env` file if it exists 6 | # recursively searches for `.env` in all folders starting from work dir 7 | from pytorch_lightning.utilities import rank_zero_info 8 | 9 | from src.utils import print_config_tree 10 | 11 | dotenv.load_dotenv(override=True) 12 | 13 | @hydra.main(config_path="configs/", config_name="config.yaml", version_base='1.1') 14 | def main(config: DictConfig): 15 | 16 | # Imports should be nested inside @hydra.main to optimize tab completion 17 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 18 | from src.train import train 19 | from src.utils import utils 20 | 21 | rank_zero_info(OmegaConf.to_yaml(config)) 22 | 23 | # A couple of optional utilities: 24 | # - disabling python warnings 25 | # - easier access to debug mode 26 | # - forcing debug friendly configuration 27 | # - forcing multi-gpu friendly configuration 28 | # You can safely get rid of this line if you don't want those 29 | utils.extras(config) 30 | 31 | # Pretty print config using Rich library 32 | if config.get("print_config"): 33 | print_config_tree(config, resolve=True) 34 | 35 | # Train model 36 | train(config) 37 | 38 | 39 | 40 | if __name__ == "__main__": 41 | # register resolvers with hydra key run.dir 42 | OmegaConf.register_new_resolver("get_train_log_dir", get_train_log_dir) 43 | OmegaConf.register_new_resolver("get_sweep_log_dir", get_sweep_log_dir) 44 | main() 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Training-Free Multi-Step Audio Source Separation 2 | 3 | Official implementation of **"Training-Free Multi-Step Audio Source Separation"** 4 | 5 | We reveal that pretrained one-step audio source separation models can be leveraged for multi-step separation without additional training. Our simple yet effective inference method iteratively applies separation by optimally blending the input mixture with the previous step's separation result. At each step, we determine the optimal blending ratio by maximizing a metric. 6 | 7 | Note that the code is for research purposes and is thus very noisy and not well-structured. We will not provide support for running the code, but we will try to answer questions related to the paper. Running the code directly without any change should yield the exact same results as reported in paper. 8 | 9 | ## Structure 10 | You may need to change the path for your dataset. 
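All of the inference scripts below implement the same search loop described above. The following is a minimal sketch of that loop, not code from this repository: `multi_step_separate`, `separate`, and `score` are illustrative placeholder names, and the fixed 0.0–0.9 ratio grid is only one of the candidate schedules used here.

```python
def multi_step_separate(separate, score, mixture, num_steps=20, num_candidates=10):
    """Training-free multi-step separation (sketch).

    separate: pretrained one-step separator, waveform -> waveform
    score:    metric to maximize (e.g. SDR/PESQ against a reference, or a blind MOS predictor)
    """
    ratios = [r / num_candidates for r in range(num_candidates)]  # candidate blending ratios
    estimate = separate(mixture)                                  # step 0: plain one-step separation
    for _ in range(num_steps):
        candidates = [estimate]                                   # keep the previous estimate as a candidate
        for ratio in ratios:
            blended = (1 - ratio) * mixture + ratio * estimate    # blend mixture with previous estimate
            candidates.append(separate(blended))                  # re-separate the blended input
        estimate = max(candidates, key=score)                     # keep the candidate with the best metric
    return estimate
```

In the speech enhancement scripts, `score` is SDR or PESQ computed against the clean reference, or UTMOS when no reference is available (the blind setting).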
11 | 12 | Run experiments for speech enhancement models using: 13 | `python ./SpeechEnhancement/infer/infer_sdr.py` for SDR-based search, 14 | `python ./SpeechEnhancement/infer/infer_pesq.py` for PESQ-based search, 15 | `python ./SpeechEnhancement/infer/infer_large.py` for large or xlarge model variants, 16 | and `python ./SpeechEnhancement/infer/infer_blind_utmos.py` for blind UTMOS search. 17 | 18 | Run experiments for music separation models using: 19 | `python ./MusicSourceSeparation/infer/run_infer.py` (note that the code is largely a direct copy from the original DTTNet repository, with key changes in `MusicSourceSeparation/infer/src/evaluation/separate.py`) 20 | 21 | Evaluation code is in respective folders as well (see `SpeechEnhancement/eval` and `MusicSourceSeparation/eval.py`). -------------------------------------------------------------------------------- /SpeechEnhancement/eval/intrusive.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import librosa 5 | from pystoi import stoi 6 | from pesq import pesq 7 | from tqdm import tqdm 8 | from multiprocessing import Pool, cpu_count 9 | from functools import partial 10 | import warnings 11 | warnings.filterwarnings('ignore') 12 | 13 | def calculate_si_snr(reference, estimate): 14 | """Calculate SI-SNR between two signals""" 15 | ref = reference - np.mean(reference) 16 | est = estimate - np.mean(estimate) 17 | 18 | alpha = np.dot(ref, est) / np.dot(ref, ref) 19 | s_target = alpha * ref 20 | e_noise = est - s_target 21 | 22 | return 10 * np.log10(np.dot(s_target, s_target) / np.dot(e_noise, e_noise)) 23 | 24 | def evaluate_audio_pair(filename, reference_folder, estimate_folder): 25 | """Calculate STOI, PESQ, and SI-SNR for a pair of audio files""" 26 | try: 27 | ref_path = os.path.join(reference_folder, filename) 28 | est_path = os.path.join(estimate_folder, filename) 29 | 30 | if not os.path.exists(est_path): 31 | return None 32 | 33 | # Load audio 34 | ref_16k, _ = librosa.load(ref_path, sr=16000, mono=True) 35 | est_16k, _ = librosa.load(est_path, sr=16000, mono=True) 36 | 37 | # Match lengths 38 | min_len = min(len(ref_16k), len(est_16k)) 39 | ref_16k = ref_16k[:min_len] 40 | est_16k = est_16k[:min_len] 41 | 42 | # Calculate metrics 43 | stoi_score = stoi(ref_16k, est_16k, fs_sig=16000) 44 | estoi_score = stoi(ref_16k, est_16k, fs_sig=16000, extended=True) 45 | pesq_score = pesq(16000, ref_16k, est_16k, 'wb') 46 | si_snr_score = calculate_si_snr(ref_16k, est_16k) 47 | 48 | return { 49 | 'filename': filename, 50 | 'stoi': stoi_score, 51 | 'estoi': estoi_score, 52 | 'pesq': pesq_score, 53 | 'si_snr': si_snr_score 54 | } 55 | except Exception as e: 56 | print(f"Error processing {filename}: {e}") 57 | return None 58 | 59 | def process_batch(filenames, reference_folder, estimate_folder): 60 | """Process a batch of files""" 61 | results = [] 62 | for filename in filenames: 63 | result = evaluate_audio_pair(filename, reference_folder, estimate_folder) 64 | if result is not None: 65 | results.append(result) 66 | return results 67 | 68 | def main(): 69 | reference_folder = "/root/autodl-fs/clean_testset_wav" # UPDATE THIS 70 | for step_count in tqdm(range(21)): 71 | estimate_folder = "/root/autodl-tmp/inferred_output_vctkdemucs_pesq/step_" + str(step_count) # UPDATE THIS 72 | 73 | # Get list of WAV files 74 | wav_files = [f for f in os.listdir(reference_folder) if f.endswith('.wav')] 75 | print(f"Found {len(wav_files)} audio files to 
process") 76 | 77 | # Set number of workers (leave one core free for system) 78 | num_workers = 4 79 | print(f"Using {num_workers} workers for parallel processing") 80 | 81 | # Create a partial function with fixed reference and estimate folders 82 | evaluate_func = partial(evaluate_audio_pair, 83 | reference_folder=reference_folder, 84 | estimate_folder=estimate_folder) 85 | 86 | # Process files in parallel with progress bar 87 | results = [] 88 | with Pool(num_workers) as pool: 89 | with tqdm(total=len(wav_files), desc="Processing audio files") as pbar: 90 | for result in pool.imap_unordered(evaluate_func, wav_files): 91 | if result is not None: 92 | results.append(result) 93 | pbar.update(1) 94 | 95 | if not results: 96 | print("No results found. Check file paths.") 97 | return 98 | 99 | # Calculate mean scores 100 | stoi_mean = np.mean([r['stoi'] for r in results]) 101 | pesq_mean = np.mean([r['pesq'] for r in results]) 102 | estoi_mean = np.mean([r['estoi'] for r in results]) 103 | si_snr_mean = np.mean([r['si_snr'] for r in results]) 104 | 105 | print(f"\nProcessed {len(results)} files successfully") 106 | print(f"Mean STOI: {stoi_mean:.4f}, Mean PESQ: {pesq_mean:.4f}, Mean SI-SNR: {si_snr_mean:.4f}") 107 | 108 | # # Save to CSV 109 | # df = pd.DataFrame(results) 110 | # df.to_csv('results.csv', index=False) 111 | # print("\nDone. Results saved to results.csv") 112 | 113 | # Save summary statistics 114 | summary = { 115 | 'metric': ['STOI', 'PESQ', 'SI-SNR'], 116 | 'mean': [stoi_mean, pesq_mean, si_snr_mean], 117 | 'std': [ 118 | np.std([r['stoi'] for r in results]), 119 | np.std([r['pesq'] for r in results]), 120 | np.std([r['si_snr'] for r in results]) 121 | ], 122 | 'min': [ 123 | np.min([r['stoi'] for r in results]), 124 | np.min([r['pesq'] for r in results]), 125 | np.min([r['si_snr'] for r in results]) 126 | ], 127 | 'max': [ 128 | np.max([r['stoi'] for r in results]), 129 | np.max([r['pesq'] for r in results]), 130 | np.max([r['si_snr'] for r in results]) 131 | ] 132 | } 133 | summary_df = pd.DataFrame(summary) 134 | summary_df.to_csv("summary_statistics_" + str(step_count) + ".csv", index=False) 135 | 136 | if __name__ == "__main__": 137 | main() -------------------------------------------------------------------------------- /SpeechEnhancement/eval/non-intrusive.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import librosa 5 | import torch 6 | from tqdm import tqdm 7 | import warnings 8 | import utmos 9 | warnings.filterwarnings('ignore') 10 | 11 | from torchmetrics.audio.dnsmos import DeepNoiseSuppressionMeanOpinionScore 12 | 13 | utmos_model = utmos.Score() 14 | 15 | dnsmos_model = DeepNoiseSuppressionMeanOpinionScore( 16 | fs=16000, 17 | personalized=False, # Set to True for personalized MOS 18 | device="cuda" 19 | ) 20 | 21 | def evaluate_audio_file_speechmos(filepath): 22 | """Calculate UTMOS and DNSMOS-P808 metrics using speechmos package (easier)""" 23 | try: 24 | # Load audio at 16kHz 25 | audio, sr = librosa.load(filepath, sr=16000, mono=True) 26 | audio_tensor = torch.from_numpy(audio).unsqueeze(0).float() 27 | dnsmos_scores = dnsmos_model(audio_tensor.cuda()) 28 | utmos_score = utmos_model.calculate_wav(audio_tensor, 16000) 29 | 30 | return { 31 | 'filename': os.path.basename(filepath), 32 | 'utmos': utmos_score.item(), 33 | 'dnsmos_p808': dnsmos_scores[0].item(), 34 | 'dnsmos_sig': dnsmos_scores[1].item(), 35 | 'dnsmos_bak': dnsmos_scores[2].item(), 36 | 
'dnsmos_ovrl': dnsmos_scores[3].item() 37 | } 38 | 39 | except Exception as e: 40 | print(f"Error processing {filepath}: {e}") 41 | return None 42 | 43 | def evaluate_audio_file_torchmetrics(filepath, dnsmos_model): 44 | """Alternative: Calculate DNSMOS using torchmetrics (also easy)""" 45 | try: 46 | # Load audio at 16kHz 47 | audio, sr = librosa.load(filepath, sr=16000, mono=True) 48 | 49 | # Convert to torch tensor 50 | audio_tensor = torch.from_numpy(audio).float() 51 | 52 | # DNSMOS evaluation using torchmetrics 53 | # Returns tensor with [p808_mos, mos_sig, mos_bak, mos_ovr] 54 | dnsmos_scores = dnsmos_model(audio_tensor) 55 | 56 | # UTMOS using torch.hub 57 | utmos_model = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True) 58 | audio_tensor_batch = audio_tensor.unsqueeze(0) 59 | utmos_score = utmos_model(audio_tensor_batch, sr) 60 | 61 | return { 62 | 'filename': os.path.basename(filepath), 63 | 'utmos': utmos_score.item(), 64 | 'dnsmos_p808': dnsmos_scores[0].item(), 65 | 'dnsmos_sig': dnsmos_scores[1].item(), 66 | 'dnsmos_bak': dnsmos_scores[2].item(), 67 | 'dnsmos_ovrl': dnsmos_scores[3].item() 68 | } 69 | 70 | except Exception as e: 71 | print(f"Error processing {filepath}: {e}") 72 | return None 73 | 74 | def main(): 75 | 76 | for step_count in tqdm(range(21)): 77 | # Configuration 78 | audio_folder = "/root/autodl-tmp/inferred_output_vctkdemucs_pesq/step_" + str(step_count) 79 | use_speechmos = True # Set to False to use torchmetrics instead 80 | 81 | print("Processing audio files in folder:", audio_folder) 82 | 83 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 84 | print(f"Using device: {device}") 85 | 86 | # Get list of audio files 87 | audio_extensions = ('.wav', '.mp3', '.flac', '.m4a') 88 | audio_files = [f for f in os.listdir(audio_folder) 89 | if f.lower().endswith(audio_extensions)] 90 | 91 | print(f"Found {len(audio_files)} audio files to process") 92 | 93 | # Process files with progress bar 94 | results = [] 95 | for filename in tqdm(audio_files, desc="Evaluating audio files"): 96 | filepath = os.path.join(audio_folder, filename) 97 | 98 | if use_speechmos: 99 | result = evaluate_audio_file_speechmos(filepath) 100 | else: 101 | result = evaluate_audio_file_torchmetrics(filepath, dnsmos_model) 102 | 103 | if result is not None: 104 | results.append(result) 105 | 106 | if not results: 107 | print("No results found. 
Check file paths.") 108 | return 109 | 110 | # Calculate mean scores 111 | utmos_mean = np.mean([r['utmos'] for r in results]) 112 | dnsmos_ovrl_mean = np.mean([r['dnsmos_ovrl'] for r in results]) 113 | dnsmos_sig_mean = np.mean([r['dnsmos_sig'] for r in results]) 114 | dnsmos_bak_mean = np.mean([r['dnsmos_bak'] for r in results]) 115 | dnsmos_p808_mean = np.mean([r.get('dnsmos_p808', r['dnsmos_ovrl']) for r in results]) 116 | 117 | print(f"\nProcessed {len(results)} files successfully") 118 | print(f"Mean UTMOS: {utmos_mean:.4f}") 119 | print(f"Mean DNSMOS-P808 OVRL: {dnsmos_ovrl_mean:.4f}") 120 | print(f"Mean DNSMOS-P808 SIG: {dnsmos_sig_mean:.4f}") 121 | print(f"Mean DNSMOS-P808 BAK: {dnsmos_bak_mean:.4f}") 122 | print(f"Mean DNSMOS-P808: {dnsmos_p808_mean:.4f}") 123 | 124 | # Save detailed results to CSV 125 | df = pd.DataFrame(results) 126 | df.to_csv('utmos_dnsmos_results.csv', index=False) 127 | print("\nDetailed results saved to utmos_dnsmos_results.csv") 128 | 129 | # Save summary statistics 130 | metrics = ['UTMOS', 'DNSMOS-P808 OVRL', 'DNSMOS-P808 SIG', 'DNSMOS-P808 BAK', 'DNSMOS-P808'] 131 | values = [ 132 | [r['utmos'] for r in results], 133 | [r['dnsmos_ovrl'] for r in results], 134 | [r['dnsmos_sig'] for r in results], 135 | [r['dnsmos_bak'] for r in results], 136 | [r.get('dnsmos_p808', r['dnsmos_ovrl']) for r in results] 137 | ] 138 | 139 | summary = { 140 | 'metric': metrics, 141 | 'mean': [np.mean(v) for v in values], 142 | 'std': [np.std(v) for v in values], 143 | 'min': [np.min(v) for v in values], 144 | 'max': [np.max(v) for v in values], 145 | 'median': [np.median(v) for v in values] 146 | } 147 | 148 | summary_df = pd.DataFrame(summary) 149 | summary_df.to_csv("vctkdemucs_pesq_dnsmos_summary_step_" + str(step_count) + ".csv", index=False) 150 | # print("Summary statistics saved to utmos_dnsmos_summary.csv") 151 | 152 | if __name__ == "__main__": 153 | main() -------------------------------------------------------------------------------- /SpeechEnhancement/infer/infer_blind_utmos.py: -------------------------------------------------------------------------------- 1 | import soundfile as sf 2 | from espnet2.bin.enh_inference import SeparateSpeech 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | from torch_pesq import PesqLoss 7 | import os 8 | import utmos 9 | import librosa 10 | 11 | model = SeparateSpeech.from_pretrained( 12 | model_tag="wyz/vctk_dns2020_whamr_bsrnn_medium_noncausal", 13 | normalize_output_wav=False, 14 | device="cuda", 15 | ) 16 | 17 | utmos_model = utmos.Score() 18 | 19 | print("Model loaded successfully.") 20 | 21 | def mix_inferred_with_noisy(inferred, noisy, ratio): 22 | # if len(inferred) != len(noisy): 23 | # # pad the shorter one with zeros 24 | # if len(inferred) < len(noisy): 25 | # inferred = np.pad(inferred, (0, len(noisy) - len(inferred)), mode='constant') 26 | # else: 27 | # noisy = np.pad(noisy, (0, len(inferred) - len(noisy)), mode='constant') 28 | 29 | if len(inferred) != len(noisy): 30 | raise ValueError("Inferred and noisy audio must have the same length.") 31 | 32 | mixed = (1 - ratio) * noisy + ratio * inferred 33 | return mixed 34 | 35 | # def sdr(clean, noisy): 36 | # noise = noisy - clean 37 | # sdr = 10 * np.log10(np.sum(clean**2) / np.sum(noise**2)) 38 | # return sdr 39 | 40 | def sdr( 41 | ref: np.ndarray, 42 | est: np.ndarray, 43 | eps: float = 1e-10 44 | ): 45 | r"""Calcualte SDR. 
46 | """ 47 | noise = est - ref 48 | numerator = np.clip(a=np.mean(ref ** 2), a_min=eps, a_max=None) 49 | denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None) 50 | sdr = 10. * np.log10(numerator / denominator) 51 | return sdr 52 | 53 | def pesq_score(ref, est): 54 | ref = torch.tensor(ref, dtype=torch.float32) 55 | est = torch.tensor(est, dtype=torch.float32) 56 | mos = pesq.mos(ref, est) 57 | return mos.item() 58 | 59 | def utmos_score(est): 60 | est = torch.tensor(est, dtype=torch.float32) 61 | mos = utmos_model.calculate_wav(est, 16000) 62 | return mos.item() 63 | 64 | def save_audio(file_path, audio, fs): 65 | # if clipping, then normalize. 66 | if np.max(np.abs(audio)) > 1.0: 67 | audio = audio / np.max(np.abs(audio)) 68 | sf.write(file_path, audio, fs) 69 | 70 | noisy_files = [] 71 | noisy_dir = "/root/autodl-tmp/V2_V3_DNSChallenge_Blindset/noisy_blind_testset_v3_challenge_withSNR_16k" 72 | 73 | for root, dirs, files in os.walk(noisy_dir): 74 | for file in files: 75 | if file.endswith(".wav"): 76 | noisy_files.append(os.path.join(root, file)) 77 | 78 | noisy_files.sort() 79 | 80 | print(f"Number of noisy files: {len(noisy_files)}") 81 | 82 | os.makedirs("/root/autodl-tmp/inferred_output_dnsv3_utmos/step_0", exist_ok=True) 83 | 84 | # for noisy_file in tqdm((noisy_files), total=len(noisy_files)): 85 | # audio, fs = sf.read(noisy_file) 86 | # inferred = model(audio[None, :], fs=fs)[0] 87 | # filename = noisy_file.split("/")[-1] 88 | # save_audio(f"/root/autodl-tmp/inferred_output_dnsv3_utmos/step_0/{filename}", inferred[0], fs) 89 | 90 | inference_steps = 20 91 | candidate_each_step = 10 92 | 93 | for i in tqdm(range(inference_steps)): 94 | noisy_dir = f"/root/autodl-tmp/V2_V3_DNSChallenge_Blindset/noisy_blind_testset_v3_challenge_withSNR_16k" 95 | prev_step_dir = f"/root/autodl-tmp/inferred_output_dnsv3_utmos/step_{i}" 96 | current_step_dir = f"/root/autodl-tmp/inferred_output_dnsv3_utmos/step_{i + 1}" 97 | os.makedirs(current_step_dir, exist_ok=True) 98 | clean_files = [] 99 | noisy_files = [] 100 | 101 | for root, dirs, files in os.walk(prev_step_dir): 102 | for file in files: 103 | if file.endswith(".wav"): 104 | clean_files.append(os.path.join(prev_step_dir, file)) 105 | noisy_files.append(os.path.join(noisy_dir, file)) 106 | 107 | all_sdrs = [] 108 | 109 | # start_ratio = i / inference_steps 110 | # ratio_step = (1 - start_ratio) / candidate_each_step 111 | candidate_ratios = [0.1 * i for i in range(candidate_each_step)] 112 | 113 | for clean_file, noisy_file in tqdm(zip(clean_files, noisy_files), total=len(clean_files)): 114 | fs = 16000 115 | candidate_inferred = [] 116 | clean_audio, _ = librosa.load(clean_file, sr=fs) 117 | candidate_inferred.append(np.expand_dims(clean_audio, axis=0)) 118 | noisy_audio, _ = librosa.load(noisy_file, sr=fs) 119 | for ratio in candidate_ratios: 120 | mixed = mix_inferred_with_noisy(clean_audio, noisy_audio, ratio) 121 | if np.max(np.abs(mixed)) > 1.0: 122 | mixed = mixed / np.max(np.abs(mixed)) 123 | inferred = model(mixed[None, :], fs=fs)[0] 124 | candidate_inferred.append(inferred) 125 | sdrs = [utmos_score(inferred) for inferred in candidate_inferred] 126 | best_index = np.argmax(sdrs) 127 | best_inferred = candidate_inferred[best_index] 128 | save_audio(f"{current_step_dir}/{clean_file.split('/')[-1]}", best_inferred[0], fs) 129 | all_sdrs.append(sdrs[best_index]) -------------------------------------------------------------------------------- /SpeechEnhancement/infer/infer_large.py: 
-------------------------------------------------------------------------------- 1 | import soundfile as sf 2 | from espnet2.bin.enh_inference import SeparateSpeech 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | from torch_pesq import PesqLoss 7 | import os 8 | import librosa 9 | 10 | model = SeparateSpeech.from_pretrained( 11 | model_tag="wyz/vctk_dns2020_whamr_bsrnn_large_double_noncausal", # change to large or xlarge 12 | normalize_output_wav=False, 13 | device="cuda", 14 | ) 15 | 16 | print("Model loaded successfully.") 17 | 18 | def pesq_score(ref, est): 19 | ref = torch.tensor(ref, dtype=torch.float32) 20 | est = torch.tensor(est, dtype=torch.float32) 21 | mos = pesq.mos(ref, est) 22 | return mos.item() 23 | 24 | def save_audio(file_path, audio, fs): 25 | # if clipping, then normalize. 26 | if np.max(np.abs(audio)) > 1.0: 27 | audio = audio / np.max(np.abs(audio)) 28 | sf.write(file_path, audio, fs) 29 | 30 | clean_files, noisy_files = [], [] 31 | clean_dir, noisy_dir = "/root/autodl-fs/clean_testset_wav", "/root/autodl-fs/noisy_testset_wav" 32 | # Get the list of files in the clean and noisy directories 33 | for root, dirs, files in os.walk(clean_dir): 34 | for file in files: 35 | if file.endswith(".wav"): 36 | clean_files.append(os.path.join(root, file)) 37 | 38 | for root, dirs, files in os.walk(noisy_dir): 39 | for file in files: 40 | if file.endswith(".wav"): 41 | noisy_files.append(os.path.join(root, file)) 42 | 43 | # Sort the files to ensure they are in the same order 44 | clean_files.sort() 45 | noisy_files.sort() 46 | 47 | print(f"Number of clean files: {len(clean_files)}") 48 | print(f"Number of noisy files: {len(noisy_files)}") 49 | # Check if the number of files is the same 50 | if len(clean_files) != len(noisy_files): 51 | raise ValueError("The number of clean and noisy files must be the same.") 52 | 53 | os.makedirs("/root/autodl-tmp/inferred_xlarge_demand", exist_ok=True) 54 | 55 | for noisy_file in tqdm((noisy_files), total=len(noisy_files)): 56 | audio, fs = sf.read(noisy_file) 57 | inferred = model(audio[None, :], fs=fs)[0] 58 | filename = noisy_file.split("/")[-1] 59 | save_audio(f"/root/autodl-tmp/inferred_xlarge_demand/{filename}", inferred[0], fs) -------------------------------------------------------------------------------- /SpeechEnhancement/infer/infer_pesq.py: -------------------------------------------------------------------------------- 1 | import soundfile as sf 2 | from espnet2.bin.enh_inference import SeparateSpeech 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | from torch_pesq import PesqLoss 7 | import os 8 | import librosa 9 | 10 | model = SeparateSpeech.from_pretrained( 11 | model_tag="wyz/vctk_dns2020_whamr_bsrnn_medium_noncausal", 12 | normalize_output_wav=False, 13 | device="cuda", 14 | ) 15 | 16 | pesq = PesqLoss(0.5, 17 | sample_rate=48000, 18 | ) 19 | 20 | print("Model loaded successfully.") 21 | 22 | def mix_inferred_with_noisy(inferred, noisy, ratio): 23 | if len(inferred) != len(noisy): 24 | # pad the shorter one with zeros 25 | if len(inferred) < len(noisy): 26 | inferred = np.pad(inferred, (0, len(noisy) - len(inferred)), mode='constant') 27 | else: 28 | noisy = np.pad(noisy, (0, len(inferred) - len(noisy)), mode='constant') 29 | if len(inferred) != len(noisy): 30 | raise ValueError("Inferred and noisy audio must have the same length after padding.") 31 | 32 | mixed = (1 - ratio) * noisy + ratio * inferred 33 | return mixed 34 | 35 | # def sdr(clean, noisy): 36 | # noise = noisy - clean 37 
| # sdr = 10 * np.log10(np.sum(clean**2) / np.sum(noise**2)) 38 | # return sdr 39 | 40 | def sdr( 41 | ref: np.ndarray, 42 | est: np.ndarray, 43 | eps: float = 1e-10 44 | ): 45 | r"""Calcualte SDR. 46 | """ 47 | noise = est - ref 48 | numerator = np.clip(a=np.mean(ref ** 2), a_min=eps, a_max=None) 49 | denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None) 50 | sdr = 10. * np.log10(numerator / denominator) 51 | return sdr 52 | 53 | def pesq_score(ref, est): 54 | ref = torch.tensor(ref, dtype=torch.float32) 55 | est = torch.tensor(est, dtype=torch.float32) 56 | mos = pesq.mos(ref, est) 57 | return mos.item() 58 | 59 | def save_audio(file_path, audio, fs): 60 | # if clipping, then normalize. 61 | if np.max(np.abs(audio)) > 1.0: 62 | audio = audio / np.max(np.abs(audio)) 63 | sf.write(file_path, audio, fs) 64 | 65 | clean_files, noisy_files = [], [] 66 | clean_dir, noisy_dir = "/root/autodl-fs/clean_testset_wav", "/root/autodl-fs/noisy_testset_wav" 67 | # Get the list of files in the clean and noisy directories 68 | for root, dirs, files in os.walk(clean_dir): 69 | for file in files: 70 | if file.endswith(".wav"): 71 | clean_files.append(os.path.join(root, file)) 72 | 73 | for root, dirs, files in os.walk(noisy_dir): 74 | for file in files: 75 | if file.endswith(".wav"): 76 | noisy_files.append(os.path.join(root, file)) 77 | 78 | # Sort the files to ensure they are in the same order 79 | clean_files.sort() 80 | noisy_files.sort() 81 | 82 | print(f"Number of clean files: {len(clean_files)}") 83 | print(f"Number of noisy files: {len(noisy_files)}") 84 | # Check if the number of files is the same 85 | if len(clean_files) != len(noisy_files): 86 | raise ValueError("The number of clean and noisy files must be the same.") 87 | 88 | os.makedirs("/root/autodl-tmp/inferred_output_vctkdemucs_pesq/step_0", exist_ok=True) 89 | 90 | # for noisy_file in tqdm((noisy_files), total=len(noisy_files)): 91 | # audio, fs = sf.read(noisy_file) 92 | # inferred = model(audio[None, :], fs=fs)[0] 93 | # filename = noisy_file.split("/")[-1] 94 | # save_audio(f"/root/autodl-tmp/inferred_output_vctkdemucs_pesq/step_0/{filename}", inferred[0], fs) 95 | 96 | inference_steps = 20 97 | candidate_each_step = 10 98 | 99 | for i in tqdm(range(inference_steps)): 100 | target_dir = f"/root/autodl-fs/clean_testset_wav" 101 | noisy_dir = f"/root/autodl-fs/noisy_testset_wav" 102 | prev_step_dir = f"/root/autodl-tmp/inferred_output_vctkdemucs_pesq/step_{i}" 103 | current_step_dir = f"/root/autodl-tmp/inferred_output_vctkdemucs_pesq/step_{i + 1}" 104 | os.makedirs(current_step_dir, exist_ok=True) 105 | target_files = [] 106 | clean_files = [] 107 | noisy_files = [] 108 | 109 | for root, dirs, files in os.walk(target_dir): 110 | for file in files: 111 | if file.endswith(".wav"): 112 | target_files.append(os.path.join(root, file)) 113 | clean_files.append(os.path.join(prev_step_dir, file)) 114 | noisy_files.append(os.path.join(noisy_dir, file)) 115 | 116 | all_sdrs = [] 117 | 118 | candidate_ratios = [0.1 * i for i in range(candidate_each_step)] 119 | 120 | for target_file, clean_file, noisy_file in tqdm(zip(target_files, clean_files, noisy_files), total=len(target_files)): 121 | fs = 48000 122 | target_audio, _ = librosa.load(target_file, sr=fs) 123 | candidate_inferred = [] 124 | clean_audio, _ = librosa.load(clean_file, sr=fs) 125 | candidate_inferred.append(np.expand_dims(clean_audio, axis=0)) 126 | noisy_audio, _ = librosa.load(noisy_file, sr=fs) 127 | for ratio in candidate_ratios: 128 | mixed = 
mix_inferred_with_noisy(clean_audio, noisy_audio, ratio) 129 | if np.max(np.abs(mixed)) > 1.0: 130 | mixed = mixed / np.max(np.abs(mixed)) 131 | inferred = model(mixed[None, :], fs=fs)[0] 132 | candidate_inferred.append(inferred) 133 | sdrs = [pesq_score(target_audio, inferred) for inferred in candidate_inferred] 134 | best_index = np.argmax(sdrs) 135 | best_inferred = candidate_inferred[best_index] 136 | save_audio(f"{current_step_dir}/{clean_file.split('/')[-1]}", best_inferred[0], fs) 137 | all_sdrs.append(sdrs[best_index]) -------------------------------------------------------------------------------- /SpeechEnhancement/infer/infer_sdr.py: -------------------------------------------------------------------------------- 1 | import soundfile as sf 2 | from espnet2.bin.enh_inference import SeparateSpeech 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | from torch_pesq import PesqLoss 7 | import os 8 | import librosa 9 | 10 | model = SeparateSpeech.from_pretrained( 11 | model_tag="wyz/vctk_dns2020_whamr_bsrnn_medium_noncausal", 12 | normalize_output_wav=False, 13 | device="cuda", 14 | ) 15 | 16 | pesq = PesqLoss(0.5, 17 | sample_rate=48000, 18 | ) 19 | 20 | print("Model loaded successfully.") 21 | 22 | def mix_inferred_with_noisy(inferred, noisy, ratio): 23 | if len(inferred) != len(noisy): 24 | # pad the shorter one with zeros 25 | if len(inferred) < len(noisy): 26 | inferred = np.pad(inferred, (0, len(noisy) - len(inferred)), mode='constant') 27 | else: 28 | noisy = np.pad(noisy, (0, len(inferred) - len(noisy)), mode='constant') 29 | if len(inferred) != len(noisy): 30 | raise ValueError("Inferred and noisy audio must have the same length after padding.") 31 | 32 | mixed = (1 - ratio) * noisy + ratio * inferred 33 | return mixed 34 | 35 | # def sdr(clean, noisy): 36 | # noise = noisy - clean 37 | # sdr = 10 * np.log10(np.sum(clean**2) / np.sum(noise**2)) 38 | # return sdr 39 | 40 | def sdr( 41 | ref: np.ndarray, 42 | est: np.ndarray, 43 | eps: float = 1e-10 44 | ): 45 | r"""Calcualte SDR. 46 | """ 47 | noise = est - ref 48 | numerator = np.clip(a=np.mean(ref ** 2), a_min=eps, a_max=None) 49 | denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None) 50 | sdr = 10. * np.log10(numerator / denominator) 51 | return sdr 52 | 53 | def pesq_score(ref, est): 54 | ref = torch.tensor(ref, dtype=torch.float32) 55 | est = torch.tensor(est, dtype=torch.float32) 56 | mos = pesq.mos(ref, est) 57 | return mos.item() 58 | 59 | def save_audio(file_path, audio, fs): 60 | # if clipping, then normalize. 
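# (enhanced outputs can exceed full scale, so rescale by the peak absolute value before writing)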
61 | if np.max(np.abs(audio)) > 1.0: 62 | audio = audio / np.max(np.abs(audio)) 63 | sf.write(file_path, audio, fs) 64 | 65 | clean_files, noisy_files = [], [] 66 | clean_dir, noisy_dir = "/root/autodl-fs/clean_testset_wav", "/root/autodl-fs/noisy_testset_wav" 67 | # Get the list of files in the clean and noisy directories 68 | for root, dirs, files in os.walk(clean_dir): 69 | for file in files: 70 | if file.endswith(".wav"): 71 | clean_files.append(os.path.join(root, file)) 72 | 73 | for root, dirs, files in os.walk(noisy_dir): 74 | for file in files: 75 | if file.endswith(".wav"): 76 | noisy_files.append(os.path.join(root, file)) 77 | 78 | # Sort the files to ensure they are in the same order 79 | clean_files.sort() 80 | noisy_files.sort() 81 | 82 | print(f"Number of clean files: {len(clean_files)}") 83 | print(f"Number of noisy files: {len(noisy_files)}") 84 | # Check if the number of files is the same 85 | if len(clean_files) != len(noisy_files): 86 | raise ValueError("The number of clean and noisy files must be the same.") 87 | 88 | os.makedirs("/root/autodl-tmp/inferred_output_vctkdemucs_sdr/step_0", exist_ok=True) 89 | 90 | for noisy_file in tqdm((noisy_files), total=len(noisy_files)): 91 | audio, fs = sf.read(noisy_file) 92 | inferred = model(audio[None, :], fs=fs)[0] 93 | filename = noisy_file.split("/")[-1] 94 | save_audio(f"/root/autodl-tmp/inferred_output_vctkdemucs_sdr/step_0/{filename}", inferred[0], fs) 95 | 96 | inference_steps = 20 97 | candidate_each_step = 10 98 | 99 | for i in tqdm(range(inference_steps)): 100 | target_dir = f"/root/autodl-fs/clean_testset_wav" 101 | noisy_dir = f"/root/autodl-fs/noisy_testset_wav" 102 | prev_step_dir = f"/root/autodl-tmp/inferred_output_vctkdemucs_sdr/step_{i}" 103 | current_step_dir = f"/root/autodl-tmp/inferred_output_vctkdemucs_sdr/step_{i + 1}" 104 | os.makedirs(current_step_dir, exist_ok=True) 105 | target_files = [] 106 | clean_files = [] 107 | noisy_files = [] 108 | 109 | for root, dirs, files in os.walk(target_dir): 110 | for file in files: 111 | if file.endswith(".wav"): 112 | target_files.append(os.path.join(root, file)) 113 | clean_files.append(os.path.join(prev_step_dir, file)) 114 | noisy_files.append(os.path.join(noisy_dir, file)) 115 | 116 | all_sdrs = [] 117 | 118 | start_ratio = i / inference_steps 119 | ratio_step = (1 - start_ratio) / candidate_each_step 120 | candidate_ratios = [start_ratio + j * ratio_step for j in range(candidate_each_step)] 121 | 122 | for target_file, clean_file, noisy_file in tqdm(zip(target_files, clean_files, noisy_files), total=len(target_files)): 123 | fs = 48000 124 | target_audio, _ = librosa.load(target_file, sr=fs) 125 | candidate_inferred = [] 126 | clean_audio, _ = librosa.load(clean_file, sr=fs) 127 | candidate_inferred.append(np.expand_dims(clean_audio, axis=0)) 128 | noisy_audio, _ = librosa.load(noisy_file, sr=fs) 129 | for ratio in candidate_ratios: 130 | mixed = mix_inferred_with_noisy(clean_audio, noisy_audio, ratio) 131 | if np.max(np.abs(mixed)) > 1.0: 132 | mixed = mixed / np.max(np.abs(mixed)) 133 | inferred = model(mixed[None, :], fs=fs)[0] 134 | candidate_inferred.append(inferred) 135 | sdrs = [sdr(target_audio, inferred) for inferred in candidate_inferred] 136 | best_index = np.argmax(sdrs) 137 | best_inferred = candidate_inferred[best_index] 138 | save_audio(f"{current_step_dir}/{clean_file.split('/')[-1]}", best_inferred[0], fs) 139 | all_sdrs.append(sdrs[best_index]) --------------------------------------------------------------------------------