├── __init__.py
├── biodenoising
│   ├── htdemucs.py
│   ├── __init__.py
│   ├── img
│   │   ├── demucs.png
│   │   └── pavucontrol.png
│   ├── datasets
│   │   ├── __init__.py
│   │   └── gng.py
│   ├── hubconf.py
│   ├── denoiser
│   │   ├── __init__.py
│   │   ├── noise.py
│   │   ├── spec.py
│   │   ├── resample.py
│   │   ├── executor.py
│   │   ├── distrib.py
│   │   ├── dsp.py
│   │   ├── data.py
│   │   ├── evaluate_speech.py
│   │   ├── states.py
│   │   ├── live.py
│   │   ├── stft_loss.py
│   │   ├── pretrained.py
│   │   ├── enhance.py
│   │   ├── utils.py
│   │   ├── audio.py
│   │   ├── evaluate.py
│   │   ├── denoise.py
│   │   └── loss.py
│   ├── conf
│   │   ├── config_adapt.yaml
│   │   └── config.yaml
│   └── selection_table.py
├── biodenoising.jpg
├── setup.cfg
├── requirements.txt
├── MANIFEST.in
├── scripts
│   ├── rir.sh
│   ├── biodenoising_demo_long_audio.ipynb
│   ├── Biodenoising_demo.ipynb
│   └── Biodenoising_denoise_zip_demo.ipynb
├── LICENSE
├── compute_data_stats.py
├── pyproject.toml
├── setup.py
├── plot.py
├── .gitignore
├── adapt.py
├── evaluate.py
├── train.py
├── results.py
└── prepare_experiments.py

/__init__.py:
--------------------------------------------------------------------------------
1 | from .biodenoising import *
--------------------------------------------------------------------------------
/biodenoising/htdemucs.py:
--------------------------------------------------------------------------------
1 | from biodenoising.denoiser.htdemucs import HTDemucs
--------------------------------------------------------------------------------
/biodenoising.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/earthspecies/biodenoising/HEAD/biodenoising.jpg
--------------------------------------------------------------------------------
/biodenoising/__init__.py:
--------------------------------------------------------------------------------
1 | from .denoiser import *
2 | from .datasets import *
3 | from .adapt import denoise
--------------------------------------------------------------------------------
/biodenoising/img/demucs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/earthspecies/biodenoising/HEAD/biodenoising/img/demucs.png
--------------------------------------------------------------------------------
/biodenoising/img/pavucontrol.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/earthspecies/biodenoising/HEAD/biodenoising/img/pavucontrol.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [pep8]
2 | max-line-length = 100
3 | 
4 | [flake8]
5 | max-line-length = 100
6 | 
7 | [yapf]
8 | column_limit = 100
--------------------------------------------------------------------------------
/biodenoising/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .NoiseCleanRndSet import NoiseCleanRndSet, NoiseCleanValidSet, NoiseClean1BalancedSet, NoiseClean2BalancedSet, NoiseClean1WeightedSet, NoiseCleanAdaptSet
2 | from .dataprune import *
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | julius
2 | librosa
3 | norbert
4 | numpy
5 | einops
6 | openunmix
7 | asteroid
8 | sounddevice
9 | torchaudio
10 | torch
11 | omegaconf==1.4.1
12 | noisereduce
13 | scikit-fuzzy
14 | prosemble
15 | 
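A minimal end-to-end sketch of how the published package is used (it mirrors the demo notebook under scripts/; `biodenoising16k_dns48` and `convert_audio` are the entry points the notebook itself uses, and "noisy.wav" is a hypothetical input file):

import torch
import torchaudio
from biodenoising import pretrained
from biodenoising.denoiser.dsp import convert_audio

model = pretrained.biodenoising16k_dns48()   # downloads the weights on first use
wav, sr = torchaudio.load("noisy.wav")       # hypothetical noisy recording
wav = convert_audio(wav, sr, model.sample_rate, model.chin)
with torch.no_grad():
    denoised = model(wav[None])[0]           # add, then drop, the batch dimension
torchaudio.save("denoised.wav", denoised.cpu(), model.sample_rate)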
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.md
2 | include LICENSE
3 | include setup.cfg
4 | include *.txt
5 | include train.py
6 | include prepare_experiments.py
7 | include generate_training.py
8 | include denoise.py
9 | include evaluate.py
10 | include plot.py
11 | recursive-include datasets *
12 | recursive-include scripts *
--------------------------------------------------------------------------------
/biodenoising/hubconf.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License
2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser.
3 | 
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | # author: adefossez
8 | 
9 | dependencies = ['torch']
10 | from denoiser.pretrained import dns48, dns64, master64, demucsv4  # noqa
--------------------------------------------------------------------------------
/scripts/rir.sh:
--------------------------------------------------------------------------------
1 | RIR_PATH="/home/marius/data/biodenoising16k/rir"
2 | 
3 | mkdir -p $RIR_PATH
4 | 
5 | AZURE_URL="https://dnschallengepublic.blob.core.windows.net/dns5archive/V5_training_dataset"
6 | BLOB="datasets_fullband.impulse_responses_000.tar.bz2"
7 | 
8 | URL="$AZURE_URL/$BLOB"
9 | echo "Download: $BLOB"
10 | 
11 | # ### DRY RUN: print HTTP response and Content-Length
12 | # # WITHOUT downloading the files
13 | # curl -s -I "$URL" | head -n 2
14 | 
15 | ### Actually download the files: UNCOMMENT when ready to download
16 | # curl "$URL" -o "$RIR_PATH/$BLOB"
17 | 
18 | ### Same as above, but using wget
19 | # wget "$URL" -O "$RIR_PATH/$BLOB"
20 | 
21 | ### Same, + unpack files on the fly
22 | curl "$URL" | tar -C "$RIR_PATH" -f - -x -j
--------------------------------------------------------------------------------
/biodenoising/denoiser/__init__.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License
2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser.
3 | 
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
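#
# The wildcard imports below re-export the models, audio utilities, losses and
# pretrained loaders at the biodenoising.denoiser level, so that (a sketch,
# assuming the pretrained weights are reachable):
#
#     from biodenoising import denoiser
#     model = denoiser.pretrained.biodenoising16k_dns48()
#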
7 | 
8 | from .audio import *
9 | from .demucs import *
10 | from .hdemucs import *
11 | from .htdemucs import *
12 | from .transformer import *
13 | from .cleanunet import *
14 | from .spec import *
15 | from .utils import *
16 | from .augment import *
17 | from .pretrained import *
18 | from .distrib import *
19 | # from .enhance import *
20 | # from .evaluate import *
21 | # from .data import *
22 | # from .dsp import *
23 | # from .executor import *
24 | # from .live import *
25 | from .resample import *
26 | from .solver import *
27 | # from .stft_loss import *
28 | from .loss import *
29 | from .noise import *
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Earth Species Project
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/biodenoising/denoiser/noise.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | 
5 | def noise_psd(N, psd=lambda f: 1):
6 |     """Generate N samples of noise shaped by the power spectral density `psd`."""
7 |     X_white = np.fft.rfft(np.random.randn(N))
8 |     S = psd(np.fft.rfftfreq(N))
9 |     # Normalize S
10 |     S = S / np.sqrt(np.mean(S**2))
11 |     X_shaped = X_white * S
12 |     return np.fft.irfft(X_shaped)
13 | 
14 | def PSDGenerator(f):
15 |     """Decorator turning a PSD shape f(freqs) into a generator of N noise samples."""
16 |     return lambda N: noise_psd(N, f)
17 | 
18 | @PSDGenerator
19 | def white_noise(f):
20 |     return 1
21 | 
22 | @PSDGenerator
23 | def blue_noise(f):
24 |     return np.sqrt(f)
25 | 
26 | @PSDGenerator
27 | def violet_noise(f):
28 |     return f
29 | 
30 | @PSDGenerator
31 | def brownian_noise(f):
32 |     return 1/np.where(f == 0, float('inf'), f)
33 | 
34 | @PSDGenerator
35 | def pink_noise(f):
36 |     return 1/np.where(f == 0, float('inf'), np.sqrt(f))
37 | 
38 | 
39 | def get_noise(N, rngnp):
40 |     """Pick a random noise color and generate N samples of it."""
41 |     noises = {0: white_noise, 1: blue_noise, 2: violet_noise, 3: brownian_noise, 4: pink_noise}
42 |     noise_type = rngnp.integers(low=0, high=len(noises))
43 |     noise = noises[noise_type](N)
44 |     # nframes_offset = int(rngnp.uniform() * N//2)
45 |     # noise = noise[nframes_offset:nframes_offset+N]
46 |     return noise
--------------------------------------------------------------------------------
/compute_data_stats.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | import torchaudio
5 | 
6 | 
7 | parser = argparse.ArgumentParser(
8 |         'stats',
9 |         description="Compute duration statistics for the dataset folders")
10 | parser.add_argument("--biodenoising_dir", type=str, required=True,
11 |                     help="directory where data is")
12 | 
13 | def stats_dir(indir):
14 |     ### get all subdirs
15 |     subdirs = [f.name for f in os.scandir(indir) if f.is_dir()]
16 |     for subdir in subdirs:
17 |         ### get all wav files
18 |         files = [f.name for f in os.scandir(os.path.join(indir, subdir)) if f.is_file() and f.name.endswith('.wav')]
19 |         duration_sum = 0
20 |         for file in files:
21 |             ### get file duration with torchaudio.info
22 |             info = torchaudio.info(os.path.join(indir, subdir, file))
23 |             duration = info.num_frames / info.sample_rate
24 |             duration_sum += duration
25 |         print(subdir, duration_sum / 3600)
26 | 
27 | def stats(indir):
28 |     ### compute per-subset duration stats, in hours
29 |     print('Noisy')
30 |     stats_dir(os.path.join(indir, 'dev', 'noisy'))
31 |     print('Clean')
32 |     stats_dir(os.path.join(indir, 'train', 'clean'))
33 |     print('Noise')
34 |     stats_dir(os.path.join(indir, 'train', 'noise'))
35 | 
36 | 
37 | if __name__ == "__main__":
38 |     args = parser.parse_args()
39 |     stats(args.biodenoising_dir)
--------------------------------------------------------------------------------
/biodenoising/denoiser/spec.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | 
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
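# A round-trip sketch of the two helpers below (assuming a real-valued [B, T]
# tensor x on CPU or CUDA; on Apple MPS the tensors are moved to the CPU first
# because torch.stft/torch.istft are not supported there):
#
#     z = spectro(x, n_fft=512, hop_length=128)             # complex [B, F, T'] STFT
#     y = ispectro(z, hop_length=128, length=x.shape[-1])   # back to [B, T]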
7 | """Conveniance wrapper to perform STFT and iSTFT""" 8 | 9 | import torch as th 10 | 11 | 12 | def spectro(x, n_fft=512, hop_length=None, pad=0): 13 | *other, length = x.shape 14 | x = x.reshape(-1, length) 15 | is_mps = x.device.type == 'mps' 16 | if is_mps: 17 | x = x.cpu() 18 | z = th.stft(x, 19 | n_fft * (1 + pad), 20 | hop_length or n_fft // 4, 21 | window=th.hann_window(n_fft).to(x), 22 | win_length=n_fft, 23 | normalized=True, 24 | center=True, 25 | return_complex=True, 26 | pad_mode='reflect') 27 | _, freqs, frame = z.shape 28 | return z.view(*other, freqs, frame) 29 | 30 | 31 | def ispectro(z, hop_length=None, length=None, pad=0): 32 | *other, freqs, frames = z.shape 33 | n_fft = 2 * freqs - 2 34 | z = z.view(-1, freqs, frames) 35 | win_length = n_fft // (1 + pad) 36 | is_mps = z.device.type == 'mps' 37 | if is_mps: 38 | z = z.cpu() 39 | x = th.istft(z, 40 | n_fft, 41 | hop_length, 42 | window=th.hann_window(win_length).to(z.real), 43 | win_length=win_length, 44 | normalized=True, 45 | length=length, 46 | center=True) 47 | _, length = x.shape 48 | return x.view(*other, length) 49 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # https://setuptools.readthedocs.io/en/latest/setuptools.html#setup-cfg-only-projects 3 | requires = [ 4 | "setuptools >= 40.9.0", 5 | "wheel" 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [project] 10 | name = 'biodenoising' 11 | description = "Animal vocalization denoising" 12 | 13 | requires-python = '>=3.8.0' 14 | version = "0.3.3" 15 | authors = [ 16 | { name="Marius Miron", email="info@mariusmiron.com" }, 17 | ] 18 | readme = "README.md" 19 | dependencies = ['julius', 'librosa', 'norbert', 'numpy>=1.19', 'six', 'sounddevice>=0.4', 'torch>=1.5', 'torchaudio>=0.5', 'openunmix', 'asteroid','einops', 'omegaconf==1.4.1', 'noisereduce','scikit-fuzzy','prosemble'] 20 | 21 | classifiers=[ 22 | 'Topic :: Multimedia :: Sound/Audio :: Analysis', 23 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 24 | "Intended Audience :: Information Technology", 25 | "Intended Audience :: Science/Research", 26 | "Intended Audience :: System Administrators", 27 | "Intended Audience :: Telecommunications Industry", 28 | "Programming Language :: Python :: 3", 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Programming Language :: Python :: 3.8", 31 | "Programming Language :: Python :: 3.9", 32 | "Programming Language :: Python :: 3.10", 33 | "Programming Language :: Python :: 3.11", 34 | "Programming Language :: Python :: 3.12", 35 | ] 36 | license = { file="LICENSE" } 37 | 38 | [project.urls] 39 | Homepage = 'https://github.com/earthspecies/biodenoising' 40 | Issues = 'https://github.com/earthspecies/biodenoising/issues' 41 | 42 | [project.scripts] 43 | biodenoise = "denoise:main" 44 | biodenoise-adapt = "adapt:main" 45 | 46 | [tool.setuptools.packages.find] 47 | where = ["."] 48 | include = ["biodenoising","biodenoising.adapt","biodenoising.denoiser","biodenoising.datasets","biodenoising.conf"] 49 | 50 | [tool.setuptools.package-data] 51 | "*" = ["*.yaml"] 52 | 53 | [tool.setuptools] 54 | py-modules = ["denoise", "adapt"] 55 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Earth Species Project. 
3 | 
4 | 
5 | from pathlib import Path
6 | 
7 | from setuptools import setup, find_packages
8 | 
9 | NAME = 'biodenoising'
10 | DESCRIPTION = (
11 |     'Animal vocalization denoising')
12 | 
13 | URL = 'https://github.com/earthspecies/biodenoising'
14 | EMAIL = 'info@mariusmiron.com'
15 | AUTHOR = 'Marius Miron'
16 | REQUIRES_PYTHON = '>=3.8.0'
17 | VERSION = "0.3.3"
18 | 
19 | HERE = Path(__file__).parent
20 | 
21 | REQUIRED = [
22 |     'julius',
23 |     'numpy>=1.19',
24 |     'six',
25 |     'sounddevice>=0.4',
26 |     'torch>=1.5',
27 |     'torchaudio>=0.5',
28 |     'openunmix',
29 |     'einops',
30 |     'omegaconf==1.4.1',
31 |     'noisereduce',
32 |     'scikit-fuzzy',
33 |     'prosemble',
34 |     'librosa',
35 |     'norbert'
36 | ]
37 | 
38 | REQUIRED_LINKS = [
39 | ]
40 | 
41 | try:
42 |     with open(HERE / "README.md", encoding='utf-8') as f:
43 |         long_description = '\n' + f.read()
44 | except FileNotFoundError:
45 |     long_description = DESCRIPTION
46 | 
47 | setup(
48 |     name=NAME,
49 |     version=VERSION,
50 |     description=DESCRIPTION,
51 |     long_description=long_description,
52 |     long_description_content_type='text/markdown',
53 |     author=AUTHOR,
54 |     author_email=EMAIL,
55 |     python_requires=REQUIRES_PYTHON,
56 |     url=URL,
57 |     packages=find_packages(),
58 |     install_requires=REQUIRED,
59 |     dependency_links=REQUIRED_LINKS,
60 |     include_package_data=True,
61 |     license='MIT License',
62 |     classifiers=[
63 |         # Trove classifiers
64 |         # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
65 |         'Topic :: Multimedia :: Sound/Audio :: Analysis',
66 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
67 |         "Intended Audience :: Information Technology",
68 |         "Intended Audience :: Science/Research",
69 |         "Intended Audience :: System Administrators",
70 |         "Intended Audience :: Telecommunications Industry",
71 |         "Programming Language :: Python :: 3",
72 |         "Programming Language :: Python :: 3 :: Only",
73 |         "Programming Language :: Python :: 3.8",
74 |         "Programming Language :: Python :: 3.9",
75 |         "Programming Language :: Python :: 3.10",
76 |         "Programming Language :: Python :: 3.11",
77 |         "Programming Language :: Python :: 3.12",
78 |     ],
79 | )
--------------------------------------------------------------------------------
/biodenoising/denoiser/resample.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License
2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser.
3 | 
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | # author: adefossez
8 | 
9 | import math
10 | 
11 | import torch as th
12 | from torch.nn import functional as F
13 | 
14 | 
15 | def sinc(t):
16 |     """sinc.
17 | 
18 |     :param t: the input tensor
19 |     """
20 |     return th.where(t == 0, th.tensor(1., device=t.device, dtype=t.dtype), th.sin(t) / t)
21 | 
22 | 
23 | def kernel_upsample2(zeros=56):
24 |     """kernel_upsample2.
25 | 
26 |     """
27 |     win = th.hann_window(4 * zeros + 1, periodic=False)
28 |     winodd = win[1::2]
29 |     t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
30 |     t *= math.pi
31 |     kernel = (sinc(t) * winodd).view(1, 1, -1)
32 |     return kernel
33 | 
34 | 
35 | def upsample2(x, zeros=56):
36 |     """
37 |     Upsampling the input by 2 using sinc interpolation.
38 |     Smith, Julius, and Phil Gossett.
"A flexible sampling-rate conversion method." 39 | ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing. 40 | Vol. 9. IEEE, 1984. 41 | """ 42 | *other, time = x.shape 43 | kernel = kernel_upsample2(zeros).to(x) 44 | out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view(*other, time) 45 | y = th.stack([x, out], dim=-1) 46 | return y.view(*other, -1) 47 | 48 | 49 | def kernel_downsample2(zeros=56): 50 | """kernel_downsample2. 51 | 52 | """ 53 | win = th.hann_window(4 * zeros + 1, periodic=False) 54 | winodd = win[1::2] 55 | t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros) 56 | t.mul_(math.pi) 57 | kernel = (sinc(t) * winodd).view(1, 1, -1) 58 | return kernel 59 | 60 | 61 | def downsample2(x, zeros=56): 62 | """ 63 | Downsampling the input by 2 using sinc interpolation. 64 | Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method." 65 | ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing. 66 | Vol. 9. IEEE, 1984. 67 | """ 68 | if x.shape[-1] % 2 != 0: 69 | x = F.pad(x, (0, 1)) 70 | xeven = x[..., ::2] 71 | xodd = x[..., 1::2] 72 | *other, time = xodd.shape 73 | kernel = kernel_downsample2(zeros).to(x) 74 | out = xeven + F.conv1d(xodd.view(-1, 1, time), kernel, padding=zeros)[..., :-1].view( 75 | *other, time) 76 | return out.view(*other, -1).mul(0.5) 77 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | import os 4 | import argparse 5 | from multiprocessing import Pool 6 | from tqdm import tqdm 7 | import pandas as pd 8 | import seaborn as sns 9 | import numpy as np 10 | from matplotlib import pyplot as plt 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | "--results_path", type=str, required=True, help="Path to results folder containing csvs with the SISDRS and filenames" 15 | ) 16 | parser.add_argument("--num_processes",default=1,type=int,help="number of processes for multiprocessing") 17 | 18 | 19 | def process_file(args): 20 | filename, conf = args 21 | method = filename.split('.csv')[0] 22 | seed = None 23 | if ',seed=' in filename: 24 | seed = int(method.split(',seed=')[1]) 25 | method = method.split(',seed=')[0] 26 | 27 | print("Processing file {}".format(filename)) 28 | #read a csv file into a pandas dataframe and return it 29 | df = pd.read_csv(os.path.join(conf["results_path"],filename),usecols=[1,2,3]) 30 | df = pd.melt(df, id_vars=["filename"], var_name="metric", value_name="dB") 31 | df['seed'] = seed 32 | df['method'] = method 33 | return df.reset_index(drop=True) 34 | 35 | 36 | 37 | if __name__ == "__main__": 38 | args = parser.parse_args() 39 | arg_dic = dict(vars(args)) 40 | 41 | files = [file for file in os.listdir(os.path.join(arg_dic["results_path"])) if file.endswith('.csv') and not file.startswith('.')] 42 | assert len(files) > 0, "No csv files found in the results folder" 43 | if arg_dic["num_processes"] > 1: 44 | with Pool(processes=arg_dic["num_processes"]) as pool: 45 | mp_args = [[f,arg_dic] for f in files] 46 | results = tqdm(pool.map(process_file, mp_args), total=len(files)) 47 | df = pd.concat(results) 48 | else: 49 | for i,f in enumerate(files): 50 | result = process_file([f,arg_dic]) 51 | if i == 0: 52 | df = result 53 | else: 54 | df = pd.concat([df,result]) 55 | # import pdb; pdb.set_trace() 56 | fig, ax = plt.subplots() 57 | # sns.stripplot(data=df, x="metric", y="dB", 
hue="method",dodge=True, alpha=.2, legend=False, palette="dark:yellow") 58 | bb = sns.barplot(data=df, x="metric", y="dB", hue="method", ax=ax, palette="Blues",dodge=True,estimator=np.median) 59 | pp = sns.pointplot(data=df, x="metric", y="dB", hue="method", ax=ax, palette='dark:black',dodge=.64, linestyle="none", errorbar=None,estimator='mean', legend=False) 60 | for i in pp.containers: 61 | ax.bar_label(i, fmt='%.2f') 62 | # ax.bar_label(pp.containers[0]) 63 | plt.show() -------------------------------------------------------------------------------- /biodenoising/denoiser/executor.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser. 3 | 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # author: adefossez 8 | """ 9 | Start multiple process locally for DDP. 10 | """ 11 | 12 | from pathlib import Path 13 | import logging 14 | import subprocess as sp 15 | import sys 16 | 17 | from hydra import utils 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class ChildrenManager: 23 | def __init__(self): 24 | self.children = [] 25 | self.failed = False 26 | 27 | def add(self, child): 28 | child.rank = len(self.children) 29 | self.children.append(child) 30 | 31 | def __enter__(self): 32 | return self 33 | 34 | def __exit__(self, exc_type, exc_value, traceback): 35 | if exc_value is not None: 36 | logger.error("An exception happened while starting workers %r", exc_value) 37 | self.failed = True 38 | try: 39 | while self.children and not self.failed: 40 | for child in list(self.children): 41 | try: 42 | exitcode = child.wait(0.1) 43 | except sp.TimeoutExpired: 44 | continue 45 | else: 46 | self.children.remove(child) 47 | if exitcode: 48 | logger.error(f"Worker {child.rank} died, killing all workers") 49 | self.failed = True 50 | except KeyboardInterrupt: 51 | logger.error("Received keyboard interrupt, trying to kill all workers.") 52 | self.failed = True 53 | for child in self.children: 54 | child.terminate() 55 | if not self.failed: 56 | logger.info("All workers completed successfully") 57 | 58 | 59 | def start_ddp_workers(cfg): 60 | import torch as th 61 | log = utils.HydraConfig().hydra.job_logging.handlers.file.filename 62 | rendezvous_file = Path(cfg.rendezvous_file) 63 | if rendezvous_file.exists(): 64 | rendezvous_file.unlink() 65 | 66 | world_size = th.cuda.device_count() 67 | if not world_size: 68 | logger.error( 69 | "DDP is only available on GPU. 
Make sure GPUs are properly configured with cuda.")
70 |         sys.exit(1)
71 |     logger.info(f"Starting {world_size} worker processes for DDP.")
72 |     with ChildrenManager() as manager:
73 |         for rank in range(world_size):
74 |             kwargs = {}
75 |             argv = list(sys.argv)
76 |             argv += [f"world_size={world_size}", f"rank={rank}"]
77 |             if rank > 0:
78 |                 kwargs['stdin'] = sp.DEVNULL
79 |                 kwargs['stdout'] = sp.DEVNULL
80 |                 kwargs['stderr'] = sp.DEVNULL
81 |                 log += f".{rank}"
82 |             argv.append("hydra.job_logging.handlers.file.filename=" + log)
83 |             manager.add(sp.Popen([sys.executable] + argv, cwd=utils.get_original_cwd(), **kwargs))
84 |     sys.exit(int(manager.failed))
--------------------------------------------------------------------------------
/biodenoising/denoiser/distrib.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License
2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser.
3 | 
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | # author: adefossez
8 | 
9 | import logging
10 | import os
11 | 
12 | import torch
13 | from torch.utils.data.distributed import DistributedSampler
14 | from torch.utils.data import DataLoader, Subset
15 | from torch.nn.parallel.distributed import DistributedDataParallel
16 | 
17 | logger = logging.getLogger(__name__)
18 | rank = 0
19 | world_size = 1
20 | 
21 | 
22 | def init(args):
23 |     """init.
24 | 
25 |     Initialize DDP using the given rendezvous file.
26 |     """
27 |     global rank, world_size
28 |     if args.ddp:
29 |         assert args.rank is not None and args.world_size is not None
30 |         rank = args.rank
31 |         world_size = args.world_size
32 |     if world_size == 1:
33 |         return
34 |     torch.cuda.set_device(rank)
35 |     torch.distributed.init_process_group(
36 |         backend=args.ddp_backend,
37 |         init_method='file://' + os.path.abspath(args.rendezvous_file),
38 |         world_size=world_size,
39 |         rank=rank)
40 |     logger.debug("Distributed rendezvous went well, rank %d/%d", rank, world_size)
41 | 
42 | 
43 | def average(metrics, count=1.):
44 |     """average.
45 | 
46 |     Average all the relevant metrics across processes.
47 |     `metrics` should be a 1D float32 vector. Returns the average of `metrics`
48 |     over all hosts. You can use `count` to control the weight of each worker.
49 |     """
50 |     if world_size == 1:
51 |         return metrics
52 |     tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
53 |     tensor *= count
54 |     torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
55 |     return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
56 | 
57 | 
58 | def wrap(model):
59 |     """wrap.
60 | 
61 |     Wrap a model with DDP if distributed training is enabled.
62 |     """
63 |     if world_size == 1:
64 |         return model
65 |     else:
66 |         return DistributedDataParallel(
67 |             model,
68 |             device_ids=[torch.cuda.current_device()],
69 |             output_device=torch.cuda.current_device())
70 | 
71 | 
72 | def barrier():
73 |     if world_size > 1:
74 |         torch.distributed.barrier()
75 | 
76 | 
77 | def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
78 |     """loader.
79 | 
80 |     Create a dataloader properly in case of distributed training.
81 |     If a gradient is going to be computed you must set `shuffle=True`.
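
    A sketch of typical use (any map-style dataset)::

        train_loader = loader(train_set, batch_size=16, shuffle=True, num_workers=4)
        valid_loader = loader(valid_set, batch_size=1, shuffle=False)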
82 | 
83 |     :param dataset: the dataset to be parallelized
84 |     :param args: relevant args for the loader
85 |     :param shuffle: shuffle examples
86 |     :param klass: loader class
87 |     :param kwargs: relevant args
88 |     """
89 | 
90 |     if world_size == 1:
91 |         return klass(dataset, *args, shuffle=shuffle, **kwargs)
92 | 
93 |     if shuffle:
94 |         # train means we will compute backward, we use DistributedSampler
95 |         sampler = DistributedSampler(dataset)
96 |         # We ignore shuffle, DistributedSampler already shuffles
97 |         return klass(dataset, *args, **kwargs, sampler=sampler)
98 |     else:
99 |         # We make a manual shard, as DistributedSampler would otherwise replicate some examples
100 |         dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
101 |         return klass(dataset, *args, shuffle=shuffle, **kwargs)
--------------------------------------------------------------------------------
/biodenoising/denoiser/dsp.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License
2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser.
3 | 
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | # author: adefossez
8 | 
9 | import julius
10 | import numpy as np
11 | import torch
12 | from torch.nn import functional as F
13 | 
14 | 
15 | def hz_to_mel(f):
16 |     return 2595 * np.log10(1 + f / 700)
17 | 
18 | 
19 | def mel_to_hz(m):
20 |     return 700 * (10**(m / 2595) - 1)
21 | 
22 | 
23 | def mel_frequencies(n_mels, fmin, fmax):
24 |     low = hz_to_mel(fmin)
25 |     high = hz_to_mel(fmax)
26 |     mels = np.linspace(low, high, n_mels)
27 |     return mel_to_hz(mels)
28 | 
29 | 
30 | def convert_audio_channels(wav, channels=2):
31 |     """Convert audio to the given number of channels."""
32 |     *shape, src_channels, length = wav.shape
33 |     if src_channels == channels:
34 |         pass
35 |     elif channels == 1:
36 |         # Case 1:
37 |         # The caller asked for 1-channel audio, but the stream has multiple
38 |         # channels: downmix all channels.
39 |         wav = wav.mean(dim=-2, keepdim=True)
40 |     elif src_channels == 1:
41 |         # Case 2:
42 |         # The caller asked for multiple channels, but the input file has a
43 |         # single channel: replicate the audio over all channels.
44 |         wav = wav.expand(*shape, channels, length)
45 |     elif src_channels >= channels:
46 |         # Case 3:
47 |         # The caller asked for multiple channels, and the input file has
48 |         # more channels than requested. In that case return the first channels.
49 |         wav = wav[..., :channels, :]
50 |     else:
51 |         # Case 4: the input is multichannel but has fewer channels than requested.
52 |         raise ValueError('The audio file has fewer channels than requested but is not mono.')
53 |     return wav
54 | 
55 | 
56 | def convert_audio(wav, from_samplerate, to_samplerate, channels):
57 |     """Convert audio from a given samplerate to a target one and target number of channels."""
58 |     wav = convert_audio_channels(wav, channels)
59 |     return julius.resample_frac(wav, from_samplerate, to_samplerate)
60 | 
61 | 
62 | class LowPassFilters(torch.nn.Module):
63 |     """
64 |     Bank of low pass filters.
65 | 
66 |     Args:
67 |         cutoffs (list[float]): list of cutoff frequencies, in [0, 1] expressed as `f/f_s` where
68 |             f_s is the samplerate.
69 |         width (int): width of the filters (i.e. kernel_size=2 * width + 1).
70 |             Default to `2 / min(cutoffs)`. Longer filters will have better attenuation
71 |             but more side effects.
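
    Example (a sketch)::

        lpf = LowPassFilters([0.1, 0.25])   # cutoffs expressed as f / f_s
        y = lpf(torch.randn(2, 16000))      # -> shape (2, 2, 16000), i.e. (F, *, T)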
72 |     Shape:
73 |         - Input: `(*, T)`
74 |         - Output: `(F, *, T)` with `F` the len of `cutoffs`.
75 |     """
76 | 
77 |     def __init__(self, cutoffs: list, width: int = None):
78 |         super().__init__()
79 |         self.cutoffs = cutoffs
80 |         if width is None:
81 |             width = int(2 / min(cutoffs))
82 |         self.width = width
83 |         window = torch.hamming_window(2 * width + 1, periodic=False)
84 |         t = np.arange(-width, width + 1, dtype=np.float32)
85 |         filters = []
86 |         for cutoff in cutoffs:
87 |             sinc = torch.from_numpy(np.sinc(2 * cutoff * t))
88 |             filters.append(2 * cutoff * sinc * window)
89 |         self.register_buffer("filters", torch.stack(filters).unsqueeze(1))
90 | 
91 |     def forward(self, input):
92 |         *others, t = input.shape
93 |         input = input.view(-1, 1, t)
94 |         out = F.conv1d(input, self.filters, padding=self.width)
95 |         return out.permute(1, 0, 2).reshape(-1, *others, t)
96 | 
97 |     def __repr__(self):
98 |         return "LowPassFilters(width={},cutoffs={})".format(self.width, self.cutoffs)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | biodenoising/dataset/
2 | biodenoising/egs/
3 | biodenoising/outputs/
4 | outputs/
5 | tmp/
6 | egs/
7 | 
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 | 
13 | # C extensions
14 | *.so
15 | 
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 | 
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 | 
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 | 
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 | 
61 | # Translations
62 | *.mo
63 | *.pot
64 | 
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 | 
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 | 
75 | # Scrapy stuff:
76 | .scrapy
77 | 
78 | # Sphinx documentation
79 | docs/_build/
80 | 
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 | 
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 | 
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 | 
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 | 
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 | 
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 117 | .pdm.toml 118 | .pdm-python 119 | .pdm-build/ 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 169 | #.idea/ 170 | -------------------------------------------------------------------------------- /biodenoising/denoiser/data.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser. 3 | 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # author: adefossez and adiyoss 8 | 9 | import json 10 | import logging 11 | import os 12 | import re 13 | 14 | from .audio import Audioset 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def match_dns(noisy, clean): 20 | """match_dns. 21 | Match noisy and clean DNS dataset filenames. 22 | 23 | :param noisy: list of the noisy filenames 24 | :param clean: list of the clean filenames 25 | """ 26 | logger.debug("Matching noisy and clean for dns dataset") 27 | noisydict = {} 28 | extra_noisy = [] 29 | for path, size in noisy: 30 | match = re.search(r'fileid_(\d+)\.wav$', path) 31 | if match is None: 32 | # maybe we are mixing some other dataset in 33 | extra_noisy.append((path, size)) 34 | else: 35 | noisydict[match.group(1)] = (path, size) 36 | noisy[:] = [] 37 | extra_clean = [] 38 | copied = list(clean) 39 | clean[:] = [] 40 | for path, size in copied: 41 | match = re.search(r'fileid_(\d+)\.wav$', path) 42 | if match is None: 43 | extra_clean.append((path, size)) 44 | else: 45 | noisy.append(noisydict[match.group(1)]) 46 | clean.append((path, size)) 47 | extra_noisy.sort() 48 | extra_clean.sort() 49 | clean += extra_clean 50 | noisy += extra_noisy 51 | 52 | 53 | def match_files(noisy, clean, matching="sort"): 54 | """match_files. 55 | Sort files to match noisy and clean filenames. 
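
    A sketch of typical use (both lists are modified in place)::

        match_files(noisy, clean, matching="sort")   # or matching="dns" for DNS filenames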
56 | :param noisy: list of the noisy filenames 57 | :param clean: list of the clean filenames 58 | :param matching: the matching function, at this point only sort is supported 59 | """ 60 | if matching == "dns": 61 | # dns dataset filenames don't match when sorted, we have to manually match them 62 | match_dns(noisy, clean) 63 | elif matching == "sort": 64 | noisy.sort() 65 | clean.sort() 66 | else: 67 | raise ValueError(f"Invalid value for matching {matching}") 68 | 69 | 70 | class NoisyCleanSet: 71 | def __init__(self, json_dir, matching="sort", length=None, stride=None, 72 | pad=True, sample_rate=None, seed=0, with_path=False, nsources=1): 73 | """__init__. 74 | 75 | :param json_dir: directory containing both clean.json and noisy.json 76 | :param matching: matching function for the files 77 | :param length: maximum sequence length 78 | :param stride: the stride used for splitting audio sequences 79 | :param pad: pad the end of the sequence with zeros 80 | :param sample_rate: the signals sampling rate 81 | :param seed: the random seed 82 | :param with_path: return the path of the file with the audio 83 | :param nsources: the number of sources in the mixture 84 | """ 85 | noisy_json = os.path.join(json_dir, 'noisy.json') 86 | clean_json = os.path.join(json_dir, 'clean.json') 87 | with open(noisy_json, 'r') as f: 88 | noisy = json.load(f) 89 | with open(clean_json, 'r') as f: 90 | clean = json.load(f) 91 | 92 | match_files(noisy, clean, matching) 93 | kw = {'length': length, 'stride': stride, 'pad': pad, 'sample_rate': sample_rate, 'with_path': with_path} 94 | self.clean_set = Audioset(clean, **kw) 95 | self.noisy_set = Audioset(noisy, **kw) 96 | self.with_path = with_path 97 | self.nsources = nsources 98 | assert len(self.clean_set) == len(self.noisy_set) 99 | 100 | def __getitem__(self, index): 101 | noisy = self.noisy_set[index] 102 | clean = self.clean_set[index] 103 | return noisy, clean 104 | 105 | def __len__(self): 106 | return len(self.noisy_set) 107 | -------------------------------------------------------------------------------- /biodenoising/conf/config_adapt.yaml: -------------------------------------------------------------------------------- 1 | dset: 2 | low_snr: -10 3 | high_snr: 10 4 | 5 | # Dataset related 6 | full_size: 1000000 7 | nsources: 1 8 | sample_rate: 16000 9 | segment: 4 # in seconds, 10 | stride: 1. # in seconds, how much to stride between training examples 11 | pad: true # if training sample is too short, pad it 12 | epoch_size: 1000 # number of samples per epoch 13 | exclude: #data to exclude 14 | exclude_noise: #noise to exclude 15 | repeat_prob: 0. # probability to repeat a short sample 16 | random_repeat: false # repeat a short sample several random times 17 | random_pad: false # pad a short sample with random silence intervals 18 | silence_prob: 0.2 # probability to add silence 19 | noise_prob: 0. # probability to add noise 20 | eval_window_size: 0 # evaluation window size in samples 21 | normalize: true # normalize the clean signal 22 | random_gain: True # random gain 23 | low_gain: 0.1 # low gain 24 | high_gain: 1. 
# high gain
25 | parallel_noise: False # mix with noise from the same file -> for domain adaptation
26 | annotations: False # use annotations to extract segments
27 | annotations_begin_column: Begin # column name for segment start time in annotation files
28 | annotations_end_column: End # column name for segment end time in annotation files
29 | annotations_label_column: None # column name for segment label in annotation files
30 | annotations_label_value: None # filter annotations by this label value
31 | annotations_extension: .csv # extension of annotation files
32 | 
33 | 
34 | # Dataset Augmentation
35 | remix: false # remix noise and clean
36 | bandmask: 0. # drop at most this fraction of freqs in mel scale
37 | shift: 0. # random shift, number of samples
38 | shift_same: false # shift noise and clean by the same amount
39 | revecho: 0. # add reverb-like augment
40 | timescale: 0.2 # random time scaling
41 | flip: 0. # random flip
42 | mixup: 0.99 # mixup
43 | 
44 | # Logging and printing, and does not impact training
45 | num_prints: 5
46 | device: cuda
47 | num_workers: 10
48 | show: 0 # just show the model and its size and exit
49 | 
50 | # Checkpointing, by default automatically load last checkpoint
51 | checkpoint: true
52 | continue_from: '' # Path to a checkpoint.th file to start from.
53 | # this is not used in the name of the experiment!
54 | # so use dummy=something if you don't want to mix up experiments.
55 | continue_best: false # continue from best, not last state if continue_from is set.
56 | continue_pretrained: # use either dns48, dns64 or master64 to fine tune from pretrained-model
57 | restart: true # Ignore existing checkpoints
58 | checkpoint_file: checkpoint.th
59 | best_file: best.th # will contain only best model at any point
60 | history_file: history.json
61 | samples_dir: samples
62 | save_again: false # if true, only load checkpoint and save again, useful to reexport best.th
63 | model_path: ''
64 | biodenoising16k_dns48: false
65 | 
66 | # Other stuff
67 | seed: 0
68 | dummy: # use this if you want twice the same exp, with a different name
69 | 
70 | # Evaluation stuff
71 | pesq: True # compute pesq?
72 | eval_every: 1 # compute test metrics every so many epochs
73 | dry: 0.
# dry/wet knob value at eval 74 | streaming: False # use streaming evaluation for Demucs 75 | 76 | # Optimization related 77 | optim: adam 78 | swa_scheduler: false 79 | swa_start: 1 80 | lr: 1e-6 81 | #lr: 1e-5 ### sisdr loss 82 | #lr: 1e-6 ### l1 loss 83 | weight_decay: 0 84 | beta1: 0.9 85 | beta2: 0.999 86 | loss: l1 87 | stft_loss: False 88 | stft_sc_factor: .5 89 | stft_mag_factor: .5 90 | stft_mask: False 91 | stft_mask_threshold: -60 92 | epochs: 30 93 | batch_size: 16 94 | clip_grad_norm: 10 95 | clamp_loss: 30 96 | rms_loss: 0 97 | 98 | # Teacher-student experiment 99 | teacher_student: False 100 | bootstrap_remix: False 101 | rescale_to_input_mixture: False 102 | apply_mixture_consistency: False 103 | n_epochs_teacher_update: 1 104 | teacher_momentum: 0.01 105 | other_noise: True 106 | 107 | # Models 108 | model: biodenoising16k_dns48 109 | demucs: 110 | chin: 1 111 | chout: 1 112 | hidden: 48 113 | max_hidden: 10000 114 | causal: true 115 | glu: true 116 | depth: 5 117 | kernel_size: 8 118 | stride: 4 119 | normalize: true 120 | resample: 4 121 | growth: 2 122 | rescale: 0.1 123 | noisereduce: False 124 | nr_mode: 'concat' 125 | cleanunet: 126 | channels_input: 1 127 | channels_output: 1 128 | channels_H: 64 129 | max_H: 768 130 | encoder_n_layers: 8 131 | kernel_size: 4 132 | stride: 2 133 | tsfm_n_layers: 5 134 | tsfm_n_head: 8 135 | tsfm_d_model: 512 136 | tsfm_d_inner: 2048 137 | 138 | # Experiment launching, distributed 139 | ddp: false 140 | ddp_backend: nccl 141 | rendezvous_file: ./rendezvous 142 | 143 | # Internal config, don't set manually 144 | rank: 145 | world_size: 146 | -------------------------------------------------------------------------------- /biodenoising/selection_table.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from typing import List, Optional, Sequence, Tuple 4 | 5 | import pandas as pd 6 | import torch 7 | 8 | 9 | def _normalize_column_name(name: str) -> str: 10 | """ 11 | Normalize a column name for fuzzy matching. 12 | Lowercase, strip, and remove spaces/underscores for robust comparisons. 13 | """ 14 | norm = name.lower().strip() 15 | norm = norm.replace(" ", "").replace("_", "") 16 | return norm 17 | 18 | 19 | def find_selection_table_for(audio_filepath: str) -> Optional[str]: 20 | """ 21 | Look for a csv/tsv/txt file in the same directory as the audio file whose 22 | filename contains the audio stem. 23 | 24 | Parameters 25 | ---------- 26 | audio_filepath : str 27 | Full path to the audio file. 28 | 29 | Returns 30 | ------- 31 | Optional[str] 32 | Path to the first matching table file, or None if none found. 33 | """ 34 | directory = os.path.dirname(audio_filepath) 35 | stem = os.path.splitext(os.path.basename(audio_filepath))[0] 36 | candidates: List[str] = [] 37 | for ext in ("csv", "tsv", "txt"): 38 | pattern = os.path.join(directory, f"*{stem}*.{ext}") 39 | candidates.extend(glob.glob(pattern)) 40 | return candidates[0] if candidates else None 41 | 42 | 43 | def load_events_seconds(table_path: Optional[str]) -> List[Tuple[float, float]]: 44 | """ 45 | Load selection table and extract start/end times in seconds. 46 | Tries to match columns for start and end using fuzzy rules. 47 | 48 | Parameters 49 | ---------- 50 | table_path : Optional[str] 51 | Path to the table file (csv, tsv, or txt). If None, returns empty list. 52 | 53 | Returns 54 | ------- 55 | List[Tuple[float, float]] 56 | A list of (start_seconds, end_seconds) tuples. 
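
    Example (a sketch; the recording path is hypothetical)::

        table = find_selection_table_for("recordings/whale.wav")
        events = load_events_seconds(table)   # e.g. [(0.5, 2.1), (3.0, 4.2)]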
57 | """ 58 | if table_path is None: 59 | return [] 60 | 61 | try: 62 | df = pd.read_csv(table_path, sep=None, engine="python") 63 | except Exception: 64 | if table_path.endswith((".tsv", ".txt")): 65 | df = pd.read_csv(table_path, sep="\t") 66 | else: 67 | df = pd.read_csv(table_path) 68 | 69 | normalized = {col: _normalize_column_name(col) for col in df.columns} 70 | 71 | start_aliases = {"start", "beginning", "begintime", "begin"} 72 | end_aliases = {"end", "endtime"} 73 | 74 | start_col: Optional[str] = None 75 | end_col: Optional[str] = None 76 | for col, norm in normalized.items(): 77 | if start_col is None and (norm == "start" or any(alias in norm for alias in start_aliases)): 78 | start_col = col 79 | if end_col is None and (norm == "end" or any(alias in norm for alias in end_aliases)): 80 | end_col = col 81 | if start_col is not None and end_col is not None: 82 | break 83 | 84 | if start_col is None or end_col is None: 85 | return [] 86 | 87 | try: 88 | starts = pd.to_numeric(df[start_col], errors="coerce").astype(float).tolist() 89 | ends = pd.to_numeric(df[end_col], errors="coerce").astype(float).tolist() 90 | except Exception: 91 | return [] 92 | 93 | events: List[Tuple[float, float]] = [] 94 | for s, e in zip(starts, ends): 95 | if pd.isna(s) or pd.isna(e): 96 | continue 97 | if e <= s: 98 | continue 99 | events.append((float(s), float(e))) 100 | return events 101 | 102 | 103 | def build_mask_from_events( 104 | length_frames: int, 105 | sample_rate: int, 106 | events: Sequence[Tuple[float, float]], 107 | device: torch.device | str, 108 | ) -> torch.Tensor: 109 | """ 110 | Build a 1D mask tensor of shape [length_frames] with ones inside any event 111 | intervals and zeros elsewhere. 112 | 113 | Parameters 114 | ---------- 115 | length_frames : int 116 | Total number of frames (samples) in the signal. 117 | sample_rate : int 118 | Sample rate of the signal. 119 | events : Sequence[Tuple[float, float]] 120 | Event intervals in seconds. 121 | device : torch.device | str 122 | Torch device for the output tensor. 123 | 124 | Returns 125 | ------- 126 | torch.Tensor 127 | A 1D tensor mask of length `length_frames`. 128 | """ 129 | if not events: 130 | return torch.ones(length_frames, device=device) 131 | 132 | mask = torch.zeros(length_frames, device=device) 133 | for start_s, end_s in events: 134 | start_idx = int(max(0, round(start_s * sample_rate))) 135 | end_idx = int(min(length_frames, round(end_s * sample_rate))) 136 | if end_idx > start_idx: 137 | mask[start_idx:end_idx] = 1.0 138 | if mask.sum().item() == 0: 139 | mask[:] = 1.0 140 | return mask 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /biodenoising/denoiser/evaluate_speech.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser. 3 | 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # author: adiyoss 8 | 9 | import argparse 10 | from concurrent.futures import ProcessPoolExecutor 11 | import json 12 | import logging 13 | import sys 14 | 15 | from pesq import pesq 16 | from pystoi import stoi 17 | import torch 18 | 19 | from .data import NoisyCleanSet 20 | from .enhance import add_flags, get_estimate 21 | from . 
import distrib, pretrained
22 | from .utils import bold, LogProgress
23 | 
24 | logger = logging.getLogger(__name__)
25 | 
26 | parser = argparse.ArgumentParser(
27 |         'denoiser.evaluate',
28 |         description='Speech enhancement using Demucs - Evaluate model performance')
29 | add_flags(parser)
30 | parser.add_argument('--data_dir', help='directory including noisy.json and clean.json files')
31 | parser.add_argument('--matching', default="sort", help='set this to dns for the dns dataset.')
32 | parser.add_argument('--no_pesq', action="store_false", dest="pesq", default=True,
33 |                     help="Don't compute PESQ.")
34 | parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
35 |                     default=logging.INFO, help="More logging")
36 | 
37 | 
38 | def evaluate(args, model=None, data_loader=None):
39 |     total_pesq = 0
40 |     total_stoi = 0
41 |     total_cnt = 0
42 |     updates = 5
43 | 
44 |     # Load model
45 |     if not model:
46 |         model = pretrained.get_model(args).to(args.device)
47 |     model.eval()
48 | 
49 |     # Load data
50 |     if data_loader is None:
51 |         dataset = NoisyCleanSet(args.data_dir,
52 |                                 matching=args.matching, sample_rate=model.sample_rate)
53 |         data_loader = distrib.loader(dataset, batch_size=1, num_workers=2)
54 |     pendings = []
55 |     with ProcessPoolExecutor(args.num_workers) as pool:
56 |         with torch.no_grad():
57 |             iterator = LogProgress(logger, data_loader, name="Eval estimates")
58 |             for i, data in enumerate(iterator):
59 |                 # Get batch data
60 |                 noisy, clean = [x.to(args.device) for x in data]
61 |                 # If device is CPU, we do parallel evaluation in each CPU worker.
62 |                 if args.device == 'cpu':
63 |                     pendings.append(
64 |                         pool.submit(_estimate_and_run_metrics, clean, model, noisy, args))
65 |                 else:
66 |                     estimate = get_estimate(model, noisy, args)
67 |                     estimate = estimate.cpu()
68 |                     clean = clean.cpu()
69 |                     pendings.append(
70 |                         pool.submit(_run_metrics, clean, estimate, args, model.sample_rate))
71 |                 total_cnt += clean.shape[0]
72 | 
73 |         for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
74 |             pesq_i, stoi_i = pending.result()
75 |             total_pesq += pesq_i
76 |             total_stoi += stoi_i
77 | 
78 |     metrics = [total_pesq, total_stoi]
79 |     pesq, stoi = distrib.average([m / total_cnt for m in metrics], total_cnt)
80 |     logger.info(bold(f'Test set performance: PESQ={pesq}, STOI={stoi}.'))
81 |     return pesq, stoi
82 | 
83 | 
84 | def _estimate_and_run_metrics(clean, model, noisy, args):
85 |     estimate = get_estimate(model, noisy, args)
86 |     return _run_metrics(clean, estimate, args, sr=model.sample_rate)
87 | 
88 | 
89 | def _run_metrics(clean, estimate, args, sr):
90 |     estimate = estimate.numpy()[:, 0]
91 |     clean = clean.numpy()[:, 0]
92 |     if args.pesq:
93 |         pesq_i = get_pesq(clean, estimate, sr=sr)
94 |     else:
95 |         pesq_i = 0
96 |     stoi_i = get_stoi(clean, estimate, sr=sr)
97 |     return pesq_i, stoi_i
98 | 
99 | 
100 | def get_pesq(ref_sig, out_sig, sr):
101 |     """Calculate PESQ.
102 |     Args:
103 |         ref_sig: numpy.ndarray, [B, T]
104 |         out_sig: numpy.ndarray, [B, T]
105 |     Returns:
106 |         PESQ
107 |     """
108 |     pesq_val = 0
109 |     for i in range(len(ref_sig)):
110 |         pesq_val += pesq(sr, ref_sig[i], out_sig[i], 'wb')
111 |     return pesq_val
112 | 
113 | 
114 | def get_stoi(ref_sig, out_sig, sr):
115 |     """Calculate STOI.
116 | Args: 117 | ref_sig: numpy.ndarray, [B, T] 118 | out_sig: numpy.ndarray, [B, T] 119 | Returns: 120 | STOI 121 | """ 122 | stoi_val = 0 123 | for i in range(len(ref_sig)): 124 | stoi_val += stoi(ref_sig[i], out_sig[i], sr, extended=False) 125 | return stoi_val 126 | 127 | 128 | def main(): 129 | args = parser.parse_args() 130 | logging.basicConfig(stream=sys.stderr, level=args.verbose) 131 | logger.debug(args) 132 | pesq, stoi = evaluate(args) 133 | json.dump({'pesq': pesq, 'stoi': stoi}, sys.stdout) 134 | sys.stdout.write('\n') 135 | 136 | 137 | if __name__ == '__main__': 138 | main() 139 | -------------------------------------------------------------------------------- /scripts/biodenoising_demo_long_audio.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "4FYJAV3ElZ5h" 7 | }, 8 | "source": [ 9 | "#Biodenoising - Animal vocalization denoising\n", 10 | "This is a demo for animal vocalization denoising without access to clean data.\n", 11 | "For the more info check the [associated page](https://mariusmiron.com/research/biodenoising/) and the code repository on [github](https://github.com/earthspecies/biodenoising).\n", 12 | "\n", 13 | "First, let's install the package from pip:" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "colab": { 21 | "base_uri": "https://localhost:8080/" 22 | }, 23 | "id": "pl7pGTRidM1Q", 24 | "outputId": "119e964a-0fa7-4251-9ffc-1e4ed23f45fe" 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "!pip install biodenoising" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "yzZIcD2PmJxB" 35 | }, 36 | "source": [ 37 | "We import the libraries:" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "id": "cypBqtncgE9Q" 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "from IPython import display as disp\n", 49 | "import os\n", 50 | "import torch\n", 51 | "import torchaudio\n", 52 | "from biodenoising import pretrained\n", 53 | "from biodenoising.denoiser.dsp import convert_audio" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": { 59 | "id": "mMfAk2Z2mRp3" 60 | }, 61 | "source": [ 62 | "We download some noisy animal vocalizations from the biodenoising_validation dataset. Note that these files, species, noise conditions were not seen during training, to test for generalization." 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": { 68 | "id": "ayCjnCxFmeJH" 69 | }, 70 | "source": [ 71 | "We set the device, gpu or cpu. You can use a computing instance with a GPU for faster processing." 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": { 78 | "id": "snrSxC8Uh4lt" 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "if torch.cuda.is_available():\n", 83 | " device = torch.device('cuda')\n", 84 | "else:\n", 85 | " device = torch.device('cpu')" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": { 91 | "id": "czIUjWoRmjh-" 92 | }, 93 | "source": [ 94 | "Let's load the 16kHz model. If it's the first time you run this, it will download the model locally." 
95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": { 101 | "colab": { 102 | "base_uri": "https://localhost:8080/" 103 | }, 104 | "id": "Kpuap38kgnxY", 105 | "outputId": "0a5fd42c-192e-493c-a986-48e523cf315e" 106 | }, 107 | "outputs": [], 108 | "source": [ 109 | "model = pretrained.biodenoising16k_dns48().to(device)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": { 115 | "id": "c8tHqnWtm2v9" 116 | }, 117 | "source": [ 118 | "We use the model above to denoise the first demo sound." 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": { 125 | "colab": { 126 | "base_uri": "https://localhost:8080/", 127 | "height": 134 128 | }, 129 | "id": "XexsuKUEg2W4", 130 | "outputId": "d863247a-a9d3-4f7f-acba-5b8f3ee10aaa" 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "wav, sr = torchaudio.load(os.path.join('whale.wav'))\n", 135 | "wav = convert_audio(wav, sr, model.sample_rate, model.chin).to(device)\n", 136 | "\n", 137 | "if wav.shape[-1] > 640000:\n", 138 | " import asteroid\n", 139 | " ola_model = asteroid.dsp.overlap_add.LambdaOverlapAdd(\n", 140 | " nnet=model, # function to apply to each segment.\n", 141 | " n_src=1, # number of sources in the output of nnet\n", 142 | " window_size=640000, # Size of segmenting window\n", 143 | " hop_size=640000//4, # segmentation hop size\n", 144 | " window=\"hann\", # Type of the window (see scipy.signal.get_window)\n", 145 | " reorder_chunks=False, # Whether to reorder each consecutive segment.\n", 146 | " enable_grad=False, # Set gradient calculation on or off (see torch.set_grad_enabled)\n", 147 | " )\n", 148 | " ola_model.window = ola_model.window.to(device)\n", 149 | " with torch.no_grad():\n", 150 | " denoised = ola_model(wav[None])[0]\n", 151 | "else:\n", 152 | " with torch.no_grad():\n", 153 | " denoised = model(wav[None])[0]\n", 154 | "disp.display(disp.Audio(wav.data.cpu().numpy(), rate=model.sample_rate))\n", 155 | "disp.display(disp.Audio(denoised.data.cpu().numpy(), rate=model.sample_rate))" 156 | ] 157 | } 158 | ], 159 | "metadata": { 160 | "accelerator": "GPU", 161 | "colab": { 162 | "gpuType": "T4", 163 | "provenance": [] 164 | }, 165 | "kernelspec": { 166 | "display_name": "Python 3", 167 | "name": "python3" 168 | }, 169 | "language_info": { 170 | "name": "python" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 0 175 | } 176 | -------------------------------------------------------------------------------- /biodenoising/denoiser/states.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """ 8 | Utilities to save and load models. 
9 | """ 10 | from contextlib import contextmanager 11 | 12 | import functools 13 | import hashlib 14 | import inspect 15 | import io 16 | from pathlib import Path 17 | import warnings 18 | 19 | from omegaconf import OmegaConf 20 | # from dora.log import fatal 21 | import torch 22 | 23 | 24 | def _check_diffq(): 25 | try: 26 | import diffq # noqa 27 | except ImportError: 28 | print('Trying to use DiffQ, but diffq is not installed.\n' 29 | 'On Windows run: python.exe -m pip install diffq \n' 30 | 'On Linux/Mac, run: python3 -m pip install diffq') 31 | 32 | 33 | def get_quantizer(model, args, optimizer=None): 34 | """Return the quantizer given the XP quantization args.""" 35 | quantizer = None 36 | if args.diffq: 37 | _check_diffq() 38 | from diffq import DiffQuantizer 39 | quantizer = DiffQuantizer( 40 | model, min_size=args.min_size, group_size=args.group_size) 41 | if optimizer is not None: 42 | quantizer.setup_optimizer(optimizer) 43 | elif args.qat: 44 | _check_diffq() 45 | from diffq import UniformQuantizer 46 | quantizer = UniformQuantizer( 47 | model, bits=args.qat, min_size=args.min_size) 48 | return quantizer 49 | 50 | 51 | def load_model(path_or_package, strict=False): 52 | """Load a model from the given serialized model, either given as a dict (already loaded) 53 | or a path to a file on disk.""" 54 | if isinstance(path_or_package, dict): 55 | package = path_or_package 56 | elif isinstance(path_or_package, (str, Path)): 57 | with warnings.catch_warnings(): 58 | warnings.simplefilter("ignore") 59 | path = path_or_package 60 | package = torch.load(path, 'cpu') 61 | else: 62 | raise ValueError(f"Invalid type for {path_or_package}.") 63 | 64 | klass = package["klass"] 65 | args = package["args"] 66 | kwargs = package["kwargs"] 67 | 68 | if strict: 69 | model = klass(*args, **kwargs) 70 | else: 71 | sig = inspect.signature(klass) 72 | for key in list(kwargs): 73 | if key not in sig.parameters: 74 | warnings.warn("Dropping inexistant parameter " + key) 75 | del kwargs[key] 76 | model = klass(*args, **kwargs) 77 | 78 | state = package["state"] 79 | 80 | set_state(model, state) 81 | return model 82 | 83 | 84 | def get_state(model, quantizer, half=False): 85 | """Get the state from a model, potentially with quantization applied. 86 | If `half` is True, model are stored as half precision, which shouldn't impact performance 87 | but half the state size.""" 88 | if quantizer is None: 89 | dtype = torch.half if half else None 90 | state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()} 91 | else: 92 | state = quantizer.get_quantized_state() 93 | state['__quantized'] = True 94 | return state 95 | 96 | 97 | def set_state(model, state, quantizer=None): 98 | """Set the state on a given model.""" 99 | if state.get('__quantized'): 100 | if quantizer is not None: 101 | quantizer.restore_quantized_state(model, state['quantized']) 102 | else: 103 | _check_diffq() 104 | from diffq import restore_quantized_state 105 | restore_quantized_state(model, state) 106 | else: 107 | model.load_state_dict(state) 108 | return state 109 | 110 | 111 | def save_with_checksum(content, path): 112 | """Save the given value on disk, along with a sha256 hash. 
-------------------------------------------------------------------------------- /adapt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import biodenoising 3 | from biodenoising.adapt import ConfigParser # Import ConfigParser from the package 4 | import logging 5 | import os 6 | import sys 7 | import yaml 8 | import pandas as pd 9 | import soundfile as sf 10 | import torchaudio 11 | import numpy as np 12 | from pathlib import Path 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | 18 | # Use the imported ConfigParser 19 | parser = ConfigParser() 20 | parser.add_argument("--steps", default=5, type=int, help="Number of steps to use for adaptation") 21 | parser.add_argument("--noisy_dir", type=str, default=None, 22 | help="path to the directory with noisy wav files") 23 | parser.add_argument("--noise_dir", type=str, default=None, 24 | help="path to the directory with noise wav files") 25 | parser.add_argument("--test_dir", type=str, default=None, 26 | help="for evaluation purposes only: path to the directory containing clean.json and noise.json files") 27 | parser.add_argument("--out_dir", type=str, default="enhanced", 28 | help="directory where enhanced wav files are written") 29 | parser.add_argument('--noisy_estimate', action="store_true", help="compute the noise as the difference between the noisy and the estimated signal") 30 | parser.add_argument("--cfg", type=str, default="biodenoising/conf/config_adapt.yaml", 31 | help="path to the adaptation config yaml file") 32 | parser.add_argument("--epochs", default=5, type=int, help="Number of epochs per step") 33 | parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG, 34 | default=logging.INFO, help="more logging") 35 | parser.add_argument("--method", choices=["biodenoising16k_dns48"], default="biodenoising16k_dns48", help="Method to use for denoising") 36 | parser.add_argument("--segment", default=4, type=int, 
help="minimum segment size in seconds") 37 | parser.add_argument("--highpass", default=20, type=int, help="apply a highpass filter with this cutoff before separating") 38 | parser.add_argument("--peak_height", default=0.008, type=float, help="filter segments with rms lower than this value") 39 | parser.add_argument("--transform",choices=["none", "time_scale"], default="none",help="Transform input by pitch shifting or time scaling") 40 | parser.add_argument('--revecho', type=float, default=0,help='revecho probability') 41 | parser.add_argument("--use_top", default=1., type=float, help="use the top ratio of files for training, sorted by their rms") 42 | parser.add_argument('--num_valid', type=float, default=0,help='the number of files to use for validation') 43 | parser.add_argument('--antialiasing', action="store_true",help="use an antialiasing filter when time scaling back") 44 | parser.add_argument('--keep_original_sr', action="store_true",help="keep the original sample rate of the audio rather than the model sample rate") 45 | parser.add_argument("--force_sample_rate", default=0, type=int, help="Force the model to take samples of this sample rate") 46 | parser.add_argument("--time_scale_factor", default=0, type=int, help="If the model has a different sample rate, play the audio slower or faster with this factor. If force_sample_rate this automatically changes.") 47 | parser.add_argument('--noise_reduce', action="store_true",help="use noisereduce preprocessing") 48 | parser.add_argument('--amp_scale', action="store_true",help="scale to the amplitude of the input") 49 | parser.add_argument('--interactive', action="store_true",help="pause at each step to allow the user to delete some files and continue") 50 | parser.add_argument("--window_size", type=int, default=0, 51 | help="size of the window for continuous processing") 52 | parser.add_argument('--selection_table', action="store_true", help="Enable event masking via selection tables (csv/tsv/txt) located next to audio files.") 53 | parser.add_argument('--device', default="cuda") 54 | parser.add_argument('--dry', type=float, default=0.1, 55 | help='dry/wet knob coefficient. 
0 is only denoised, 1 only input signal.') 56 | parser.add_argument('--num_workers', type=int, default=5) 57 | parser.add_argument('--annotations', action="store_true", default=False, 58 | help="Use annotation files to extract segments from audio files") 59 | parser.add_argument('--annotations_begin_column', type=str, default="Begin", 60 | help="Column name for segment start time in annotation files") 61 | parser.add_argument('--annotations_end_column', type=str, default="End", 62 | help="Column name for segment end time in annotation files") 63 | parser.add_argument('--annotations_label_column', type=str, default=None, 64 | help="Column name for segment label in annotation files") 65 | parser.add_argument('--annotations_label_value', type=str, default=None, 66 | help="Filter annotations by this label value") 67 | parser.add_argument('--annotations_extension', type=str, default=".csv", 68 | help="Extension of annotation files") 69 | parser.add_argument('--processed_dir', type=str, default=None, 70 | help="Directory for storing preprocessed audio segments") 71 | 72 | def run_adaptation_main(args): 73 | logging.basicConfig(stream=sys.stderr, level=args.verbose) 74 | logger.debug(args) 75 | 76 | # Call the refactored adaptation function from the module 77 | model = biodenoising.adapt.run_adaptation(args) 78 | 79 | return model 80 | 81 | def main() -> None: 82 | args = parser.parse() 83 | if args.method == 'biodenoising16k_dns48': 84 | args.biodenoising16k_dns48 = True 85 | run_adaptation_main(args) 86 | 87 | 88 | if __name__ == "__main__": 89 | main() -------------------------------------------------------------------------------- /biodenoising/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dset: debug 3 | - hydra/job_logging: colorlog 4 | - hydra/hydra_logging: colorlog 5 | 6 | # Dataset related 7 | full_size: 1000000 8 | nsources: 1 9 | sample_rate: 16000 10 | segment: 4 # in seconds, 11 | stride: 1. # in seconds, how much to stride between training examples 12 | pad: true # if training sample is too short, pad it 13 | epoch_size: 1000 # number of samples per epoch 14 | exclude: #data to exclude 15 | exclude_noise: #noise to exclude 16 | repeat_prob: 0 # probability to repeat a short sample 17 | random_repeat: false # repeat a short sample several random times 18 | random_pad: false # pad a short sample with random silence intervals 19 | silence_prob: 0.2 # probability to add silence 20 | noise_prob: 0 # probability to add noise 21 | eval_window_size: 0 # evaluation window size in samples 22 | normalize: true # normalize the clean signal 23 | random_gain: true # random gain 24 | low_gain: 0.5 # low gain 25 | high_gain: 1. # high gain 26 | parallel_noise: false # mix with noise from the same file -> for domain adaptation 27 | 28 | # Dataset Augmentation 29 | remix: false # remix noise and clean 30 | bandmask: 0 # drop at most this fraction of freqs in mel scale 31 | shift: 35000 # random shift, number of samples 32 | shift_same: false # shift noise and clean by the same amount 33 | revecho: 0 # add reverb like augment 34 | timescale: 0.8 # random time scaling 35 | flip: 0 # random flip 36 | mixup: 0. 
# mixup 37 | # remix: false # remix noise and clean 38 | # bandmask: 0.2 # drop at most this fraction of freqs in mel scale 39 | # shift: 64000 # random shift, number of samples 40 | # shift_same: false # shift noise and clean by the same amount 41 | # revecho: 1 # add reverb like augment 42 | # timescale: 4 # random time scaling 43 | # flip: 0.1 # random flip 44 | 45 | # Logging and printing; does not impact training 46 | num_prints: 5 47 | device: cuda 48 | num_workers: 10 49 | verbose: 0 50 | show: 0 # just show the model and its size and exit 51 | 52 | # Checkpointing, by default automatically load last checkpoint 53 | checkpoint: true 54 | continue_from: '' # Path to a checkpoint.th file to start from. 55 | # this is not used in the name of the experiment! 56 | # so use dummy=something to avoid mixing up experiments. 57 | continue_best: false # continue from best, not last state if continue_from is set. 58 | continue_pretrained: # use either dns48, dns64 or master64 to fine-tune from a pretrained model 59 | restart: true # Ignore existing checkpoints 60 | checkpoint_file: checkpoint.th 61 | best_file: best.th # will contain only best model at any point 62 | history_file: history.json 63 | samples_dir: samples 64 | save_again: false # if true, only load checkpoint and save again, useful to reexport best.th 65 | 66 | # Other stuff 67 | seed: 0 68 | dummy: # use this if you want twice the same exp, with a different name 69 | 70 | # Evaluation stuff 71 | pesq: True # compute pesq? 72 | eval_every: 1 # compute test metrics every this many epochs 73 | dry: 0. # dry/wet knob value at eval 74 | streaming: False # use streaming evaluation for Demucs 75 | 76 | # Optimization related 77 | optim: adam 78 | swa_scheduler: true 79 | swa_start: 3 80 | lr: 1e-5 ### l1 loss 81 | #lr: 1e-3 ### sisdr loss 82 | weight_decay: 0 83 | beta1: 0.9 84 | beta2: 0.999 85 | loss: l1 86 | stft_loss: False 87 | stft_sc_factor: .5 88 | stft_mag_factor: .5 89 | stft_mask: False 90 | stft_mask_threshold: -60 91 | epochs: 30 92 | batch_size: 32 93 | clip_grad_norm: 10 94 | clamp_loss: 30 95 | rms_loss: 0 96 | 97 | # Teacher-student experiment 98 | teacher_student: False 99 | bootstrap_remix: False 100 | rescale_to_input_mixture: False 101 | apply_mixture_consistency: False 102 | n_epochs_teacher_update: 1 103 | teacher_momentum: 0.01 104 | other_noise: True 105 | 106 | # Models 107 | model: demucs # either demucs or cleanunet 108 | demucs: 109 | chin: 1 110 | chout: 1 111 | hidden: 48 112 | max_hidden: 10000 113 | causal: true 114 | glu: true 115 | depth: 5 116 | kernel_size: 8 117 | stride: 4 118 | normalize: true 119 | resample: 4 120 | growth: 2 121 | rescale: 0.1 122 | noisereduce: False 123 | nr_mode: 'concat' 124 | cleanunet: 125 | channels_input: 1 126 | channels_output: 1 127 | channels_H: 64 128 | max_H: 768 129 | encoder_n_layers: 8 130 | kernel_size: 4 131 | stride: 2 132 | tsfm_n_layers: 5 133 | tsfm_n_head: 8 134 | tsfm_d_model: 512 135 | tsfm_d_inner: 2048 136 | 137 | # Experiment launching, distributed 138 | ddp: false 139 | ddp_backend: nccl 140 | rendezvous_file: ./rendezvous 141 | 142 | # Internal config, don't set manually 143 | rank: 144 | world_size: 145 | 146 | # Hydra config 147 | hydra: 148 | run: 149 | dir: ./outputs/exp_${hydra.job.override_dirname} 150 | job: 151 | config: 152 | # configuration for the ${hydra.job.override_dirname} runtime variable 153 | override_dirname: 154 | kv_sep: '=' 155 | item_sep: ',' 156 | # Remove all paths, as the / in them would mess up things 157 | # Remove params that would not impact the training itself 158 | # Remove all slurm and submit params. 159 | # This is ugly I know... 160 | exclude_keys: [ 161 | 'hydra.job_logging.handlers.file.filename', 162 | 'dset.train', 'dset.valid', 'dset.test', 'dset.noisy_json', 'dset.noisy_dir', 163 | 'num_prints', 'continue_from', 'save_again', 164 | 'device', 'num_workers', 'print_freq', 'restart', 'verbose', 165 | 'log', 'ddp', 'ddp_backend', 'rendezvous_file', 'rank', 'world_size'] 166 | job_logging: 167 | handlers: 168 | file: 169 | class: logging.FileHandler 170 | mode: w 171 | formatter: colorlog 172 | filename: trainer.log 173 | console: 174 | class: logging.StreamHandler 175 | formatter: colorlog 176 | stream: ext://sys.stderr 177 | 178 | hydra_logging: 179 | handlers: 180 | console: 181 | class: logging.StreamHandler 182 | formatter: colorlog 183 | stream: ext://sys.stderr 184 | 
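How this configuration is consumed (a sketch; it assumes train.py at the repository root is the Hydra entry point bound to these defaults): any key above can be overridden from the command line, e.g. `python train.py dset=debug demucs.hidden=64 epochs=30 batch_size=32`, and the run directory is then derived from those overrides through the `hydra.job.override_dirname` settings above.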
-------------------------------------------------------------------------------- /biodenoising/denoiser/live.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser. 3 | 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # author: adefossez 8 | 9 | import argparse 10 | import sys 11 | 12 | import sounddevice as sd 13 | import torch 14 | 15 | from .demucs import DemucsStreamer 16 | from .pretrained import add_model_flags, get_model 17 | from .utils import bold 18 | 19 | 20 | def get_parser(): 21 | parser = argparse.ArgumentParser( 22 | "denoiser.live", 23 | description="Performs live speech enhancement, reading audio from " 24 | "the default mic (or interface specified by --in) and " 25 | "writing the enhanced version to 'Soundflower (2ch)' " 26 | "(or the interface specified by --out)." 27 | ) 28 | parser.add_argument( 29 | "-i", "--in", dest="in_", 30 | help="name or index of input interface.") 31 | parser.add_argument( 32 | "-o", "--out", default="Soundflower (2ch)", 33 | help="name or index of output interface.") 34 | add_model_flags(parser) 35 | parser.add_argument( 36 | "--no_compressor", action="store_false", dest="compressor", 37 | help="Deactivate compressor on output, might lead to clipping.") 38 | parser.add_argument( 39 | "--device", default="cpu") 40 | parser.add_argument( 41 | "--dry", type=float, default=0.04, 42 | help="Dry/wet knob, between 0 and 1. 0=maximum noise removal " 43 | "but it might cause distortions. Default is 0.04") 44 | parser.add_argument( 45 | "-t", "--num_threads", type=int, 46 | help="Number of threads. If you have DDR3 RAM, setting -t 1 can " 47 | "improve performance.") 48 | parser.add_argument( 49 | "-f", "--num_frames", type=int, default=1, 50 | help="Number of frames to process at once. 
Larger values increase " 51 | "the overall lag, but will improve speed.") 52 | return parser 53 | 54 | 55 | def parse_audio_device(device): 56 | if device is None: 57 | return device 58 | try: 59 | return int(device) 60 | except ValueError: 61 | return device 62 | 63 | 64 | def query_devices(device, kind): 65 | try: 66 | caps = sd.query_devices(device, kind=kind) 67 | except ValueError: 68 | message = bold(f"Invalid {kind} audio interface {device}.\n") 69 | message += ( 70 | "If you are on Mac OS X, try installing Soundflower " 71 | "(https://github.com/mattingalls/Soundflower).\n" 72 | "You can list available interfaces with `python3 -m sounddevice` on Linux and OS X, " 73 | "and `python.exe -m sounddevice` on Windows. You must have at least one loopback " 74 | "audio interface to use this.") 75 | print(message, file=sys.stderr) 76 | sys.exit(1) 77 | return caps 78 | 79 | 80 | def main(): 81 | args = get_parser().parse_args() 82 | if args.num_threads: 83 | torch.set_num_threads(args.num_threads) 84 | 85 | model = get_model(args).to(args.device) 86 | model.eval() 87 | print("Model loaded.") 88 | streamer = DemucsStreamer(model, dry=args.dry, num_frames=args.num_frames) 89 | 90 | device_in = parse_audio_device(args.in_) 91 | caps = query_devices(device_in, "input") 92 | channels_in = min(caps['max_input_channels'], 2) 93 | stream_in = sd.InputStream( 94 | device=device_in, 95 | samplerate=model.sample_rate, 96 | channels=channels_in) 97 | 98 | device_out = parse_audio_device(args.out) 99 | caps = query_devices(device_out, "output") 100 | channels_out = min(caps['max_output_channels'], 2) 101 | stream_out = sd.OutputStream( 102 | device=device_out, 103 | samplerate=model.sample_rate, 104 | channels=channels_out) 105 | 106 | stream_in.start() 107 | stream_out.start() 108 | first = True 109 | current_time = 0 110 | last_log_time = 0 111 | last_error_time = 0 112 | cooldown_time = 2 113 | log_delta = 10 114 | sr_ms = model.sample_rate / 1000 115 | stride_ms = streamer.stride / sr_ms 116 | print(f"Ready to process audio, total lag: {streamer.total_length / sr_ms:.1f}ms.") 117 | while True: 118 | try: 119 | if current_time > last_log_time + log_delta: 120 | last_log_time = current_time 121 | tpf = streamer.time_per_frame * 1000 122 | rtf = tpf / stride_ms 123 | print(f"time per frame: {tpf:.1f}ms, ", end='') 124 | print(f"RTF: {rtf:.1f}") 125 | streamer.reset_time_per_frame() 126 | 127 | length = streamer.total_length if first else streamer.stride 128 | first = False 129 | current_time += length / model.sample_rate 130 | frame, overflow = stream_in.read(length) 131 | frame = torch.from_numpy(frame).mean(dim=1).to(args.device) 132 | with torch.no_grad(): 133 | out = streamer.feed(frame[None])[0] 134 | if not out.numel(): 135 | continue 136 | if args.compressor: 137 | out = 0.99 * torch.tanh(out) 138 | out = out[:, None].repeat(1, channels_out) 139 | mx = out.abs().max().item() 140 | if mx > 1: 141 | print("Clipping!!") 142 | out.clamp_(-1, 1) 143 | out = out.cpu().numpy() 144 | underflow = stream_out.write(out) 145 | if overflow or underflow: 146 | if current_time >= last_error_time + cooldown_time: 147 | last_error_time = current_time 148 | tpf = 1000 * streamer.time_per_frame 149 | print(f"Not processing audio fast enough, time per frame is {tpf:.1f}ms " 150 | f"(should be less than {stride_ms:.1f}ms).") 151 | except KeyboardInterrupt: 152 | print("Stopping") 153 | break 154 | stream_out.stop() 155 | stream_in.stop() 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | 
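An illustrative invocation of the live loop above (the module path is assumed from this package's layout; interface names and indices depend on your system): `python -m biodenoising.denoiser.live --dns48 -o "Soundflower (2ch)" -f 4`. Programmatically, the same streaming API reduces to calling `streamer.feed(frame[None])` on each captured chunk and `streamer.flush()` once at the end.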
-------------------------------------------------------------------------------- /biodenoising/denoiser/stft_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 3 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser. 4 | 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | # Original copyright 2019 Tomoki Hayashi 10 | # MIT License (https://opensource.org/licenses/MIT) 11 | 12 | """STFT-based Loss modules.""" 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | EPS = 1e-8 18 | 19 | def stft(x, fft_size, hop_size, win_length, window): 20 | """Perform STFT and convert to magnitude spectrogram. 21 | Args: 22 | x (Tensor): Input signal tensor (B, T). 23 | fft_size (int): FFT size. 24 | hop_size (int): Hop size. 25 | win_length (int): Window length. 26 | window (str): Window function type. 27 | Returns: 28 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 29 | """ 30 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True, pad_mode='constant') # pad_mode='constant' keeps the mstft loss deterministic: the default 'reflect' padding uses reflection_pad1d_backward_out_cuda, which has no deterministic implementation 31 | x_stft = torch.view_as_real(x_stft) 32 | real = x_stft[..., 0] 33 | imag = x_stft[..., 1] 34 | 35 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 36 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 37 | 38 | 39 | class SpectralConvergengeLoss(torch.nn.Module): 40 | """Spectral convergence loss module.""" 41 | 42 | def __init__(self): 43 | """Initialize spectral convergence loss module.""" 44 | super(SpectralConvergengeLoss, self).__init__() 45 | 46 | def forward(self, x_mag, y_mag): 47 | """Calculate forward propagation. 48 | Args: 49 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 50 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 51 | Returns: 52 | Tensor: Spectral convergence loss value. 53 | """ 54 | return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro")+EPS) 55 | 56 | 57 | class LogSTFTMagnitudeLoss(torch.nn.Module): 58 | """Log STFT magnitude loss module.""" 59 | 60 | def __init__(self): 61 | """Initialize log STFT magnitude loss module.""" 62 | super(LogSTFTMagnitudeLoss, self).__init__() 63 | 64 | def forward(self, x_mag, y_mag): 65 | """Calculate forward propagation. 66 | Args: 67 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 68 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 69 | Returns: 70 | Tensor: Log STFT magnitude loss value. 71 | """ 72 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 73 | 74 | 75 | class STFTLoss(torch.nn.Module): 76 | """STFT loss module.""" 77 | 78 | def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): 79 | """Initialize STFT loss module.""" 80 | super(STFTLoss, self).__init__() 81 | self.fft_size = fft_size 82 | self.shift_size = shift_size 83 | self.win_length = win_length 84 | self.register_buffer("window", getattr(torch, window)(win_length)) 85 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 86 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 87 | 88 | def forward(self, x, y): 89 | """Calculate forward propagation. 90 | Args: 91 | x (Tensor): Predicted signal (B, T). 92 | y (Tensor): Groundtruth signal (B, T). 93 | Returns: 94 | Tensor: Spectral convergence loss value. 95 | Tensor: Log STFT magnitude loss value. 96 | """ 97 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 98 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 99 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 100 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 101 | 102 | return sc_loss, mag_loss 103 | 104 | 105 | class MultiResolutionSTFTLoss(torch.nn.Module): 106 | """Multi resolution STFT loss module.""" 107 | 108 | def __init__(self, 109 | fft_sizes=[1024, 2048, 512], 110 | hop_sizes=[120, 240, 50], 111 | win_lengths=[600, 1200, 240], 112 | window="hann_window", factor_sc=0.1, factor_mag=0.1, threshold=-60, mask=False): 113 | """Initialize Multi resolution STFT loss module. 114 | Args: 115 | fft_sizes (list): List of FFT sizes. 116 | hop_sizes (list): List of hop sizes. 117 | win_lengths (list): List of window lengths. 118 | window (str): Window function type. 119 | factor_sc / factor_mag (float): balancing factors for the spectral convergence and log magnitude losses; threshold (float, dB) and mask (bool) optionally mask out bins whose target magnitude falls below the threshold. 120 | """ 121 | super(MultiResolutionSTFTLoss, self).__init__() 122 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 123 | self.stft_losses = torch.nn.ModuleList() 124 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 125 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 126 | self.factor_sc = factor_sc 127 | self.factor_mag = factor_mag 128 | self.threshold = threshold 129 | self.mask = mask 130 | 131 | def forward(self, x, y): 132 | """Calculate forward propagation. 133 | Args: 134 | x (Tensor): Predicted signal (B, T). 135 | y (Tensor): Groundtruth signal (B, T). 136 | Returns: 137 | Tensor: Multi resolution spectral convergence loss value. 138 | Tensor: Multi resolution log STFT magnitude loss value. 139 | """ 140 | if self.mask: 141 | y_dB = 20 * torch.log10(y.abs() + 1e-10) 142 | mask = (y_dB > self.threshold).float() 143 | y = y * mask 144 | x = x * mask 145 | sc_loss = 0.0 146 | mag_loss = 0.0 147 | for f in self.stft_losses: 148 | sc_l, mag_l = f(x, y) 149 | sc_loss += sc_l 150 | mag_loss += mag_l 151 | sc_loss /= len(self.stft_losses) 152 | mag_loss /= len(self.stft_losses) 153 | 154 | return self.factor_sc*sc_loss, self.factor_mag*mag_loss 155 | 
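A minimal usage sketch for the multi-resolution loss above (illustrative values; the 0.5 factors mirror `stft_sc_factor` and `stft_mag_factor` in conf/config.yaml): `loss_fn = MultiResolutionSTFTLoss(factor_sc=0.5, factor_mag=0.5)`, then, for a pair of (B, T) waveforms, `sc, mag = loss_fn(estimate, clean)` and `loss = sc + mag`.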
71 | """ 72 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 73 | 74 | 75 | class STFTLoss(torch.nn.Module): 76 | """STFT loss module.""" 77 | 78 | def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): 79 | """Initialize STFT loss module.""" 80 | super(STFTLoss, self).__init__() 81 | self.fft_size = fft_size 82 | self.shift_size = shift_size 83 | self.win_length = win_length 84 | self.register_buffer("window", getattr(torch, window)(win_length)) 85 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 86 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 87 | 88 | def forward(self, x, y): 89 | """Calculate forward propagation. 90 | Args: 91 | x (Tensor): Predicted signal (B, T). 92 | y (Tensor): Groundtruth signal (B, T). 93 | Returns: 94 | Tensor: Spectral convergence loss value. 95 | Tensor: Log STFT magnitude loss value. 96 | """ 97 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 98 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 99 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 100 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 101 | 102 | return sc_loss, mag_loss 103 | 104 | 105 | class MultiResolutionSTFTLoss(torch.nn.Module): 106 | """Multi resolution STFT loss module.""" 107 | 108 | def __init__(self, 109 | fft_sizes=[1024, 2048, 512], 110 | hop_sizes=[120, 240, 50], 111 | win_lengths=[600, 1200, 240], 112 | window="hann_window", factor_sc=0.1, factor_mag=0.1, threshold=-60, mask=False): 113 | """Initialize Multi resolution STFT loss module. 114 | Args: 115 | fft_sizes (list): List of FFT sizes. 116 | hop_sizes (list): List of hop sizes. 117 | win_lengths (list): List of window lengths. 118 | window (str): Window function type. 119 | factor (float): a balancing factor across different losses. 120 | """ 121 | super(MultiResolutionSTFTLoss, self).__init__() 122 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 123 | self.stft_losses = torch.nn.ModuleList() 124 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 125 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 126 | self.factor_sc = factor_sc 127 | self.factor_mag = factor_mag 128 | self.threshold = threshold 129 | self.mask = mask 130 | 131 | def forward(self, x, y): 132 | """Calculate forward propagation. 133 | Args: 134 | x (Tensor): Predicted signal (B, T). 135 | y (Tensor): Groundtruth signal (B, T). 136 | Returns: 137 | Tensor: Multi resolution spectral convergence loss value. 138 | Tensor: Multi resolution log STFT magnitude loss value. 139 | """ 140 | if self.mask: 141 | y_dB = 20 * torch.log10(y.abs() + 1e-10) 142 | mask = (y_dB > self.threshold).float() 143 | y = y * mask 144 | x = x * mask 145 | sc_loss = 0.0 146 | mag_loss = 0.0 147 | for f in self.stft_losses: 148 | sc_l, mag_l = f(x, y) 149 | sc_loss += sc_l 150 | mag_loss += mag_l 151 | sc_loss /= len(self.stft_losses) 152 | mag_loss /= len(self.stft_losses) 153 | 154 | return self.factor_sc*sc_loss, self.factor_mag*mag_loss 155 | -------------------------------------------------------------------------------- /biodenoising/denoiser/pretrained.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser. 
3 | 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # author: adefossez 8 | 9 | import os 10 | import logging 11 | 12 | import torch.hub 13 | 14 | from .cleanunet import CleanUNet 15 | from .demucs import Demucs 16 | from .htdemucs import HTDemucs 17 | from .utils import deserialize_model, load_model_state_dict 18 | from .states import set_state 19 | 20 | logger = logging.getLogger(__name__) 21 | ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/" 22 | DNS_48_URL = ROOT + "dns48-11decc9d8e3f0998.th" 23 | DNS_64_URL = ROOT + "dns64-a7761ff99a7d5bb6.th" 24 | MASTER_64_URL = ROOT + "master64-8a5dfb4bb92753dd.th" 25 | VALENTINI_NC = ROOT + 'valentini_nc-93fc4337.th' # Non causal Demucs on Valentini 26 | ### list of all music demucs models https://raw.githubusercontent.com/facebookresearch/demucs/main/demucs/remote/files.txt 27 | DEMUCSV4_URL = "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th" 28 | CLEANUNET_URL = "https://github.com/NVIDIA/CleanUNet/raw/main/exp/DNS-large-full/checkpoint/pretrained.pkl" 29 | 30 | BIO_ROOT = "https://storage.googleapis.com/esp-public-files/biodenoising/" 31 | BIO_DNS_48_URL = BIO_ROOT + "model-16kHz-dns48.th" 32 | 33 | def _demucs(pretrained, url, **kwargs): 34 | if '/demucs/' not in url: # URLs under /demucs/ ship a full HTDemucs package; all others are plain Demucs state dicts 35 | model = Demucs(**kwargs, sample_rate=16_000) 36 | if pretrained: 37 | 38 | if '/demucs/' in url: 39 | full_model = torch.hub.load_state_dict_from_url(url, map_location='cpu') 40 | args = full_model["args"] 41 | kwargs = dict(kwargs, **full_model["kwargs"]) 42 | model = HTDemucs(*args, **kwargs) 43 | state_dict = full_model['state'] 44 | state_dict = set_state(model, state_dict) 45 | else: 46 | state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu') 47 | if 'model' in state_dict: 48 | state_dict = state_dict['model']['state'] 49 | model = load_model_state_dict(model, state_dict) 50 | return model 51 | 52 | def biodenoising16k_dns48(pretrained=True): 53 | return _demucs(pretrained, BIO_DNS_48_URL, hidden=48) 54 | 55 | def dns48(pretrained=True): 56 | return _demucs(pretrained, DNS_48_URL, hidden=48) 57 | 58 | 59 | def dns64(pretrained=True): 60 | return _demucs(pretrained, DNS_64_URL, hidden=64) 61 | 62 | 63 | def master64(pretrained=True): 64 | return _demucs(pretrained, MASTER_64_URL, hidden=64) 65 | 66 | 67 | def valentini_nc(pretrained=True): 68 | return _demucs(pretrained, VALENTINI_NC, hidden=64, causal=False, stride=2, resample=2) 69 | 70 | def demucsv4(pretrained=True): 71 | return _demucs(pretrained, DEMUCSV4_URL) 72 | 73 | def cleanunet_speech(args): 74 | if 'cleanunet' not in args: 75 | args.cleanunet = { 76 | "channels_input": 1, 77 | "channels_output": 1, 78 | "channels_H": 64, 79 | "max_H": 768, 80 | "encoder_n_layers": 8, 81 | "kernel_size": 4, 82 | "stride": 2, 83 | "tsfm_n_layers": 5, 84 | "tsfm_n_head": 8, 85 | "tsfm_d_model": 512, 86 | "tsfm_d_inner": 2048 87 | } 88 | model = CleanUNet(**args.cleanunet) 89 | checkpoint = torch.hub.load_state_dict_from_url(CLEANUNET_URL, map_location='cpu') 90 | model.load_state_dict(checkpoint['model_state_dict']) 91 | return model 92 | 93 | 94 | 95 | def add_model_flags(parser): 96 | group = parser.add_mutually_exclusive_group(required=False) 97 | group.add_argument("-m", "--model_path", help="Path to local trained model.") 98 | group.add_argument("--biodenoising16k_dns48", action="store_true", 99 | help="Use the biodenoising 16kHz pre-trained real time H=48 model.") 100 | group.add_argument("--dns48", action="store_true", 101 | help="Use pre-trained real time H=48 model trained on DNS.") 102 | group.add_argument("--dns64", action="store_true", 103 | help="Use pre-trained real time H=64 model trained on DNS.") 104 | group.add_argument("--master64", action="store_true", 105 | help="Use pre-trained real time H=64 model trained on DNS and Valentini.") 106 | group.add_argument("--valentini_nc", action="store_true", 107 | help="Use pre-trained H=64 model trained on Valentini, non causal.") 108 | group.add_argument("--demucsv4", action="store_true", 109 | help="Use pre-trained music hybrid-transformer demucs.") 110 | group.add_argument("--cleanunet_speech", action="store_true", 111 | help="Use pre-trained cleanunet model trained on DNS.") 112 | 113 | 114 | def get_model(args): 115 | """ 116 | Load local model package or torchhub pre-trained model. 117 | """ 118 | if args.model_path: 119 | logger.info("Loading model from %s", args.model_path) 120 | pkg = torch.load(args.model_path, 'cpu', weights_only=False) 121 | if 'model' in pkg: 122 | # if 'best_state' in pkg: 123 | # logger.info("Loading best model state.") 124 | # pkg['model']['state'] = pkg['best_state'] 125 | model = deserialize_model(pkg['model']) 126 | else: 127 | 128 | model = deserialize_model(pkg) 129 | elif args.biodenoising16k_dns48: 130 | logger.info("Loading the biodenoising 16kHz pre-trained real time H=48 model.") 131 | model = biodenoising16k_dns48() 132 | elif args.dns64: 133 | logger.info("Loading pre-trained real time H=64 model trained on DNS.") 134 | model = dns64() 135 | elif args.master64: 136 | logger.info("Loading pre-trained real time H=64 model trained on DNS and Valentini.") 137 | model = master64() 138 | elif args.demucsv4: 139 | logger.info("Loading pre-trained music hybrid-transformer demucs (v4).") 140 | model = demucsv4() 141 | elif args.valentini_nc: 142 | logger.info("Loading pre-trained H=64 model trained on Valentini.") 143 | model = valentini_nc() 144 | elif args.cleanunet_speech: 145 | logger.info("Loading pre-trained cleanunet model trained on DNS.") 146 | model = cleanunet_speech(args) 147 | else: 148 | logger.info("Loading pre-trained real time H=48 model trained on DNS.") 149 | model = dns48() 150 | logger.debug(model) 151 | return model 152 | 
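A sketch of how the helpers above compose in a standalone script (`get_model` reads exactly the attributes that `add_model_flags` defines): `parser = argparse.ArgumentParser()`, `add_model_flags(parser)`, `args = parser.parse_args(['--biodenoising16k_dns48'])`, `model = get_model(args).eval()`.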
-------------------------------------------------------------------------------- /biodenoising/denoiser/enhance.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser. 3 | 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # author: adiyoss 8 | 9 | import argparse 10 | from concurrent.futures import ProcessPoolExecutor 11 | import json 12 | import logging 13 | import os 14 | import sys 15 | 16 | import torch 17 | import torchaudio 18 | 19 | from .audio import Audioset, find_audio_files 20 | from . 
import distrib, pretrained 21 | from .demucs import DemucsStreamer 22 | 23 | from .utils import LogProgress 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def add_flags(parser): 29 | """ 30 | Add the flags for the argument parser that are related to model loading and evaluation. 31 | """ 32 | pretrained.add_model_flags(parser) 33 | parser.add_argument('--device', default="cpu") 34 | parser.add_argument('--dry', type=float, default=0, 35 | help='dry/wet knob coefficient. 0 is only denoised, 1 only input signal.') 36 | parser.add_argument('--num_workers', type=int, default=10) 37 | parser.add_argument('--streaming', action="store_true", 38 | help="use true streaming evaluation for Demucs") 39 | 40 | 41 | parser = argparse.ArgumentParser( 42 | 'denoiser.enhance', 43 | description="Speech enhancement using Demucs - Generate enhanced files") 44 | add_flags(parser) 45 | parser.add_argument("--out_dir", type=str, default="enhanced", 46 | help="directory where enhanced wav files are written") 47 | parser.add_argument("--batch_size", default=1, type=int, help="batch size") 48 | parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG, 49 | default=logging.INFO, help="more logging") 50 | 51 | group = parser.add_mutually_exclusive_group() 52 | group.add_argument("--noisy_dir", type=str, default=None, 53 | help="directory including noisy wav files") 54 | group.add_argument("--noisy_json", type=str, default=None, 55 | help="json file including noisy wav files") 56 | 57 | 58 | def get_estimate(model, noisy, args): 59 | torch.set_num_threads(1) 60 | if args.streaming: 61 | streamer = DemucsStreamer(model, dry=args.dry) 62 | with torch.no_grad(): 63 | estimate = torch.cat([ 64 | streamer.feed(noisy[0]), 65 | streamer.flush()], dim=1)[None] 66 | else: 67 | with torch.no_grad(): 68 | if hasattr(model, 'ola_forward'): 69 | while noisy.ndim < 3: 70 | noisy = noisy.unsqueeze(0) 71 | # if noisy.shape[-1] < model.window_size: 72 | # noisy = torch.cat([noisy, torch.zeros((1,1,model.window_size - noisy.shape[-1])).to(args.device)], dim=-1) 73 | estimate = model.forward(noisy) 74 | else: 75 | estimate = model(noisy) 76 | assert estimate.shape[-1] == noisy.shape[-1] 77 | estimate = (1 - args.dry) * estimate + args.dry * noisy 78 | return estimate 79 | 80 | 81 | def save_wavs(estimates, noisy_sigs, filenames, out_dir, sr=16_000, experiment_logger=None): 82 | # Write result 83 | for estimate, noisy, filename in zip(estimates, noisy_sigs, filenames): 84 | filename = os.path.join(out_dir, os.path.basename(filename).rsplit(".", 1)[0]) 85 | write(noisy, filename + "_noisy.wav", sr=sr) 86 | write(estimate, filename + "_enhanced.wav", sr=sr) 87 | if experiment_logger is not None: 88 | experiment_logger.log_audio(noisy.squeeze().detach().cpu().numpy(), 89 | sample_rate=sr, 90 | file_name=filename + "_noisy.wav", 91 | metadata=None, overwrite=False, 92 | step=experiment_logger.step) 93 | experiment_logger.log_audio(estimate.squeeze().detach().cpu().numpy(), 94 | sample_rate=sr, 95 | file_name=filename + "_enhanced.wav", 96 | metadata=None, overwrite=False, 97 | step=experiment_logger.step) 98 | 99 | 100 | def write(wav, filename, sr=16_000): 101 | # Normalize audio if it prevents clipping 102 | wav = wav / max(wav.abs().max().item(), 1) 103 | torchaudio.save(filename, wav.cpu(), sr) 104 | 105 | 106 | def get_dataset(args, sample_rate, channels): 107 | if hasattr(args, 'dset'): 108 | paths = args.dset 109 | else: 110 | paths = args 111 | if paths.noisy_json: 112 | with open(paths.noisy_json) as f: 113 | files = json.load(f) 114 | elif paths.noisy_dir: 115 | files = find_audio_files(paths.noisy_dir) 116 | else: 117 | logger.warning( 118 | "Small sample set was not provided by either noisy_dir or noisy_json. " 119 | "Skipping enhancement.") 120 | return None 121 | return Audioset(files, with_path=True, 122 | sample_rate=sample_rate, channels=channels, convert=True) 123 | 124 | 125 | def _estimate_and_save(model, noisy_signals, filenames, out_dir, args, experiment_logger=None): 126 | estimate = get_estimate(model, noisy_signals, args) 127 | save_wavs(estimate, noisy_signals, filenames, out_dir, sr=model.sample_rate, experiment_logger=experiment_logger) 128 | 129 | 130 | def enhance(args, model=None, local_out_dir=None, experiment_logger=None): 131 | # Load model 132 | if not model: 133 | model = pretrained.get_model(args).to(args.device) 134 | model.eval() 135 | if local_out_dir: 136 | out_dir = local_out_dir 137 | else: 138 | out_dir = args.out_dir 139 | 140 | dset = get_dataset(args, model.sample_rate, model.chin) 141 | if dset is None: 142 | return 143 | loader = distrib.loader(dset, batch_size=1) 144 | 145 | if distrib.rank == 0: 146 | os.makedirs(out_dir, exist_ok=True) 147 | distrib.barrier() 148 | 149 | with ProcessPoolExecutor(args.num_workers) as pool: 150 | iterator = LogProgress(logger, loader, name="Generate enhanced files") 151 | pendings = [] 152 | for i, data in enumerate(iterator): 153 | # Get batch data 154 | if len(data) > 1 and isinstance(data[1][0], str): 155 | noisy_signals, filenames, _ = data 156 | else: 157 | noisy_signals, filenames = data, [str(i)] # fix: the batch itself is the signal when no paths are returned 158 | noisy_signals = noisy_signals.to(args.device) 159 | if args.device == 'cpu' and args.num_workers > 1: 160 | pendings.append( 161 | pool.submit(_estimate_and_save, 162 | model, noisy_signals, filenames, out_dir, args)) 163 | else: 164 | # Forward 165 | estimate = get_estimate(model, noisy_signals, args) 166 | save_wavs(estimate, noisy_signals, filenames, out_dir, sr=model.sample_rate, experiment_logger=experiment_logger) 167 | 168 | if pendings: 169 | print('Waiting for pending jobs...') 170 | for pending in LogProgress(logger, pendings, updates=5, name="Generate enhanced files"): 171 | pending.result() 172 | 173 | 174 | if __name__ == "__main__": 175 | args = parser.parse_args() 176 | logging.basicConfig(stream=sys.stderr, level=args.verbose) 177 | logger.debug(args) 178 | enhance(args, local_out_dir=args.out_dir) 179 | 
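An illustrative invocation of this script (the module path is assumed from the package layout): `python -m biodenoising.denoiser.enhance --biodenoising16k_dns48 --noisy_dir ./noisy --out_dir ./enhanced --device cpu`. For each input file, `save_wavs` above writes a `*_noisy.wav` / `*_enhanced.wav` pair into `--out_dir`.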
-------------------------------------------------------------------------------- /biodenoising/denoiser/utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser. 3 | 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # author: adefossez 8 | 9 | import functools 10 | import logging 11 | from contextlib import contextmanager 12 | import inspect 13 | import time 14 | import sys 15 | 16 | import torch 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def capture_init(init): 21 | """capture_init. 
22 | 23 | Decorate `__init__` with this, and you can then 24 | recover the *args and **kwargs passed to it in `self._init_args_kwargs` 25 | """ 26 | @functools.wraps(init) 27 | def __init__(self, *args, **kwargs): 28 | self._init_args_kwargs = (args, kwargs) 29 | init(self, *args, **kwargs) 30 | 31 | return __init__ 32 | 33 | 34 | def deserialize_model(package, strict=False): 35 | """Instantiate a model from a package produced by `serialize_model`. 36 | 37 | """ 38 | klass = package['class'] 39 | kwargs = package['kwargs'] 40 | if 'sample_rate' not in kwargs: 41 | logger.warning( 42 | "Training sample rate not available! 16kHz will be assumed. " 43 | "If you used a different sample rate at train time, please fix your checkpoint " 44 | "with the command `./train.py [TRAINING_ARGS] save_again=true`.") 45 | if strict: 46 | model = klass(*package['args'], **kwargs) 47 | else: 48 | sig = inspect.signature(klass) 49 | kw = package['kwargs'] 50 | for key in list(kw): 51 | if key not in sig.parameters: 52 | logger.warning("Dropping nonexistent parameter %s", key) 53 | del kw[key] 54 | model = klass(*package['args'], **kw) 55 | model.load_state_dict(package['state']) 56 | return model 57 | 58 | 59 | def copy_state(state): 60 | return {k: v.cpu().clone() for k, v in state.items()} 61 | 62 | 63 | def serialize_model(model): 64 | args, kwargs = model._init_args_kwargs 65 | state = copy_state(model.state_dict()) 66 | return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state} 67 | 68 | 69 | @contextmanager 70 | def swap_state(model, state): 71 | """ 72 | Context manager that swaps the state of a model, e.g: 73 | 74 | # model is in old state 75 | with swap_state(model, new_state): 76 | # model in new state 77 | # model back to old state 78 | """ 79 | old_state = copy_state(model.state_dict()) 80 | model.load_state_dict(state) 81 | try: 82 | yield 83 | finally: 84 | model.load_state_dict(old_state) 85 | 86 | 87 | def pull_metric(history, name): 88 | out = [] 89 | for metrics in history: 90 | if name in metrics: 91 | out.append(metrics[name]) 92 | return out 93 | 94 | 95 | class LogProgress: 96 | """ 97 | Like tqdm, but logging progress as discrete log lines rather than in real time. 98 | Args: 99 | - logger: logger obtained from `logging.getLogger`, 100 | - iterable: iterable object to wrap 101 | - updates (int): number of lines that will be printed, e.g. 102 | if `updates=5`, log every 1/5th of the total length. 103 | - total (int): length of the iterable, in case it does not support 104 | `len`. 105 | - name (str): prefix to use in the log. 106 | - level: logging level (like `logging.INFO`). 
107 | """ 108 | def __init__(self, 109 | logger, 110 | iterable, 111 | updates=5, 112 | total=None, 113 | name="LogProgress", 114 | level=logging.INFO): 115 | self.iterable = iterable 116 | self.total = total or len(iterable) 117 | self.updates = updates 118 | self.name = name 119 | self.logger = logger 120 | self.level = level 121 | 122 | def update(self, **infos): 123 | self._infos = infos 124 | 125 | def __iter__(self): 126 | self._iterator = iter(self.iterable) 127 | self._index = -1 128 | self._infos = {} 129 | self._begin = time.time() 130 | return self 131 | 132 | def __next__(self): 133 | self._index += 1 134 | try: 135 | value = next(self._iterator) 136 | except StopIteration: 137 | raise 138 | else: 139 | return value 140 | finally: 141 | log_every = max(1, self.total // self.updates) 142 | # logging is delayed by 1 it, in order to have the metrics from update 143 | if self._index >= 1 and self._index % log_every == 0: 144 | self._log() 145 | 146 | def _log(self): 147 | self._speed = (1 + self._index) / (time.time() - self._begin) 148 | infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items()) 149 | if self._speed < 1e-4: 150 | speed = "oo sec/it" 151 | elif self._speed < 0.1: 152 | speed = f"{1/self._speed:.1f} sec/it" 153 | else: 154 | speed = f"{self._speed:.1f} it/sec" 155 | out = f"{self.name} | {self._index}/{self.total} | {speed}" 156 | if infos: 157 | out += " | " + infos 158 | self.logger.log(self.level, out) 159 | 160 | 161 | def colorize(text, color): 162 | """ 163 | Display text with some ANSI color in the terminal. 164 | """ 165 | code = f"\033[{color}m" 166 | restore = "\033[0m" 167 | return "".join([code, text, restore]) 168 | 169 | 170 | def bold(text): 171 | """ 172 | Display text in bold in the terminal. 
173 | """ 174 | return colorize(text, "1") 175 | 176 | def load_model_state_dict(model, state_dict): 177 | current_model_dict = model.state_dict() 178 | new_state_dict={k:v for k,v in state_dict.items() if list(v.size())==list(current_model_dict[k].size())} 179 | model.load_state_dict(new_state_dict, strict=False) 180 | return model 181 | 182 | def apply_output_transform(rec_sources_wavs, input_mix_std, 183 | input_mix_mean, input_mom, args): 184 | if args.rescale_to_input_mixture: 185 | rec_sources_wavs = (rec_sources_wavs * input_mix_std) + input_mix_mean 186 | if args.apply_mixture_consistency: 187 | rec_sources_wavs = apply_mixture_consistency(rec_sources_wavs, input_mom) 188 | return rec_sources_wavs 189 | 190 | 191 | def apply_mixture_consistency(pr_batch, input_mixture, mix_weights_type="uniform"): 192 | """Apply mixture consistency 193 | :param pr_batch: Torch Tensors of size: 194 | batch_size x self.n_sources x length_of_wavs 195 | :param input_mixture: Torch Tensors of size: 196 | batch_size x 1 x length_of_wavs 197 | :param mix_weights_type: type of wights applied 198 | """ 199 | num_sources = pr_batch.shape[1] 200 | pr_mixture = torch.sum(pr_batch, 1, keepdim=True) 201 | 202 | if mix_weights_type == "magsq": 203 | mix_weights = torch.mean(pr_batch ** 2, -1, keepdim=True) 204 | mix_weights /= torch.sum(mix_weights, 1, keepdim=True) + 1e-8 205 | elif mix_weights_type == "uniform": 206 | mix_weights = 1.0 / num_sources 207 | else: 208 | raise ValueError( 209 | "Invalid mixture consistency weight type: {}" "".format(mix_weights_type) 210 | ) 211 | 212 | source_correction = mix_weights * (input_mixture - pr_mixture) 213 | return pr_batch + source_correction 214 | -------------------------------------------------------------------------------- /scripts/Biodenoising_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "4FYJAV3ElZ5h" 7 | }, 8 | "source": [ 9 | "#Biodenoising - Animal vocalization denoising\n", 10 | "This is a demo for animal vocalization denoising without access to clean data.\n", 11 | "For the more info check the [associated page](https://mariusmiron.com/research/biodenoising/) and the code repository on [github](https://github.com/earthspecies/biodenoising).\n", 12 | "\n", 13 | "First, let's install the package from pip:" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "colab": { 21 | "base_uri": "https://localhost:8080/" 22 | }, 23 | "id": "pl7pGTRidM1Q", 24 | "outputId": "cb0bdaef-c7fc-4aa1-9b61-cb566bf51b51" 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "!pip install biodenoising" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "yzZIcD2PmJxB" 35 | }, 36 | "source": [ 37 | "We import the libraries:" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "id": "cypBqtncgE9Q" 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "from IPython import display as disp\n", 49 | "import os\n", 50 | "import torch\n", 51 | "import torchaudio\n", 52 | "from biodenoising import pretrained\n", 53 | "from biodenoising.denoiser.dsp import convert_audio" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": { 59 | "id": "mMfAk2Z2mRp3" 60 | }, 61 | "source": [ 62 | "We download some noisy animal vocalizations from the biodenoising_validation dataset. 
Note that these files, species, and noise conditions were not seen during training, in order to test generalization." 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "colab": { 70 | "base_uri": "https://localhost:8080/" 71 | }, 72 | "id": "NIWalUP7g96a", 73 | "outputId": "49cf9cec-9955-4f73-cec4-591b83afcca7" 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "!wget https://storage.googleapis.com/esp-public-files/biodenoising/demo/benchmark/12_terrestrial_original.wav\n", 78 | "!wget https://storage.googleapis.com/esp-public-files/biodenoising/demo/benchmark/14_terrestrial_original.wav\n", 79 | "!wget https://storage.googleapis.com/esp-public-files/biodenoising/demo/benchmark/36_underwater_original.wav\n", 80 | "!wget https://storage.googleapis.com/esp-public-files/biodenoising/demo/benchmark/30_terrestrial_original.wav\n", 81 | "!wget https://storage.googleapis.com/esp-public-files/biodenoising/demo/benchmark/34_terrestrial_original.wav\n", 82 | "!wget https://storage.googleapis.com/esp-public-files/biodenoising/demo/benchmark/3_terrestrial_original.wav\n", 83 | "!wget https://storage.googleapis.com/esp-public-files/biodenoising/demo/benchmark/21_terrestrial_original.wav\n", 84 | "!wget https://storage.googleapis.com/esp-public-files/biodenoising/demo/benchmark/52_underwater_original.wav\n", 85 | "!wget https://storage.googleapis.com/esp-public-files/biodenoising/demo/benchmark/24_underwater_original.wav" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": { 91 | "id": "ayCjnCxFmeJH" 92 | }, 93 | "source": [ 94 | "We set the device, GPU or CPU. You can use a computing instance with a GPU for faster processing." 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": { 101 | "id": "snrSxC8Uh4lt" 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "if torch.cuda.is_available():\n", 106 | " device = torch.device('cuda')\n", 107 | "else:\n", 108 | " device = torch.device('cpu')" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": { 114 | "id": "czIUjWoRmjh-" 115 | }, 116 | "source": [ 117 | "Let's load the 16kHz model. If this is the first time you run it, the model will be downloaded locally." 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": { 124 | "colab": { 125 | "base_uri": "https://localhost:8080/" 126 | }, 127 | "id": "Kpuap38kgnxY", 128 | "outputId": "e8c0652c-bcf7-4dea-921d-10430aa0993f" 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "model = pretrained.biodenoising16k_dns48().to(device)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": { 138 | "id": "c8tHqnWtm2v9" 139 | }, 140 | "source": [ 141 | "We use the model above to denoise the first demo sound." 
142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": { 148 | "colab": { 149 | "base_uri": "https://localhost:8080/", 150 | "height": 134 151 | }, 152 | "id": "XexsuKUEg2W4", 153 | "outputId": "d06ac4ed-f819-4df2-f607-157c9c637c25" 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "wav, sr = torchaudio.load(os.path.join('12_terrestrial_original.wav'))\n", 158 | "wav = convert_audio(wav, sr, model.sample_rate, model.chin).to(device)\n", 159 | "with torch.no_grad():\n", 160 | " denoised = model(wav[None])[0]\n", 161 | "disp.display(disp.Audio(wav.data.cpu().numpy(), rate=model.sample_rate))\n", 162 | "disp.display(disp.Audio(denoised.data.cpu().numpy(), rate=model.sample_rate))" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": { 168 | "id": "A6Ub0om2nNlc" 169 | }, 170 | "source": [ 171 | "We now process the remaining files." 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": { 178 | "colab": { 179 | "background_save": true, 180 | "base_uri": "https://localhost:8080/", 181 | "height": 964 182 | }, 183 | "id": "yCMxF2oSjbBH", 184 | "outputId": "e8095075-88cc-4157-b3c4-8aabacf668f0" 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "file_list = ['36_underwater_original.wav','14_terrestrial_original.wav','3_terrestrial_original.wav','21_terrestrial_original.wav','24_underwater_original.wav','30_terrestrial_original.wav','34_terrestrial_original.wav']\n", 189 | "for f in file_list:\n", 190 | " wav, sr = torchaudio.load(os.path.join(f))\n", 191 | " wav = convert_audio(wav, sr, model.sample_rate, model.chin).to(device)\n", 192 | " with torch.no_grad():\n", 193 | " denoised = model(wav[None])[0]\n", 194 | " print(f)\n", 195 | " disp.display(disp.Audio(wav.data.cpu().numpy(), rate=model.sample_rate))\n", 196 | " disp.display(disp.Audio(denoised.data.cpu().numpy(), rate=model.sample_rate))" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": { 202 | "id": "xKY-Tnh9nZib" 203 | }, 204 | "source": [ 205 | "You can download the whole benchmarking dataset from zenodo and process it (subfolder 16000/noisy). You can also check the clean vocalizations and the added noise." 
206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": { 212 | "colab": { 213 | "base_uri": "https://localhost:8080/" 214 | }, 215 | "id": "z_i6k-iGfsih", 216 | "outputId": "a6739888-addf-4be7-e1bf-c4a73838c2c9" 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "!wget https://zenodo.org/records/13736465/files/biodenoising_validation_1.0.zip\n", 221 | "!unzip biodenoising_validation_1.0.zip" 222 | ] 223 | } 224 | ], 225 | "metadata": { 226 | "accelerator": "GPU", 227 | "colab": { 228 | "gpuType": "T4", 229 | "provenance": [] 230 | }, 231 | "kernelspec": { 232 | "display_name": "Python 3", 233 | "name": "python3" 234 | }, 235 | "language_info": { 236 | "name": "python" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 0 241 | } 242 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from concurrent.futures import ProcessPoolExecutor 3 | import json 4 | import logging 5 | import os 6 | import sys 7 | import pandas as pd 8 | 9 | import torch 10 | import torchaudio 11 | import torchmetrics 12 | from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio 13 | import noisereduce 14 | import librosa 15 | import numpy as np 16 | import scipy 17 | import biodenoising 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def add_flags(parser): 23 | """ 24 | Add the flags for the argument parser that are related to model loading and evaluation" 25 | """ 26 | #biodenoising.denoiser.pretrained.add_model_flags(parser) 27 | parser.add_argument('--device', default="cpu") 28 | parser.add_argument('--num_workers', type=int, default=1) 29 | 30 | 31 | parser = argparse.ArgumentParser( 32 | 'denoise', 33 | description="Generate denoised files") 34 | add_flags(parser) 35 | parser.add_argument("--batch_size", default=1, type=int, help="batch size") 36 | parser.add_argument("--sample_rate", default=16000, type=int, help="sample_rate") 37 | parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG, 38 | default=logging.INFO, help="more logging") 39 | parser.add_argument('--filter_silent', action="store_true",help="filter silent examples based on peaks") 40 | parser.add_argument("--data_dir", type=str, default=None, 41 | help="path to the parent directory containing subdirectories named clean, noisy, denoised (this one containing subdirs with all methods) ") 42 | 43 | def get_peaks(wav, sample_rate, smoothing_window=6, peak_window=10): 44 | wav = noisereduce.reduce_noise(y=wav, sr=sample_rate) 45 | spec = librosa.magphase(librosa.stft(wav, n_fft=2048, hop_length=512, win_length=2048, window=np.ones, center=True))[0] 46 | frames2time = 512/sample_rate 47 | rms = librosa.feature.rms(S=spec).squeeze() 48 | rms = np.nan_to_num(rms) 49 | if hasattr(rms, "__len__"): 50 | if smoothing_window>len(rms): 51 | smoothing_window = len(rms)//3 52 | if smoothing_window>1: 53 | rms = scipy.signal.savgol_filter(rms, smoothing_window, 2) # window size 3, polynomial order 2 54 | ### compute peaks in both channels 55 | if peak_window>len(rms): 56 | peak_window = len(rms)//2 57 | peaks, _ = scipy.signal.find_peaks(rms, height=0.01, distance=peak_window) 58 | allowed = int(3 * sample_rate / 512) 59 | peaks = peaks * frames2time 60 | else: 61 | peaks = np.array([]) 62 | return peaks 63 | 64 | class EvaluationSet(torch.utils.data.Dataset): 65 | def __init__(self, data_dir, 
subdir, sample_rate=None): 66 | """ 67 | """ 68 | self.files = biodenoising.denoiser.audio.find_audio_files(os.path.join(data_dir,'clean')) 69 | self.sample_rate = sample_rate 70 | self.subdir = subdir 71 | 72 | def __len__(self): 73 | return len(self.files) 74 | 75 | def __getitem__(self, index): 76 | file, _ = self.files[index] 77 | clean, sr = torchaudio.load(str(file)) 78 | assert sr == self.sample_rate, f"Expected {file} to have sample rate of {self.sample_rate}, but got {sr}" 79 | noisy, sr = torchaudio.load(str(file.replace('clean','noisy'))) 80 | denoised, sr = torchaudio.load(str(file.replace('clean',"denoised"+os.sep+self.subdir))) 81 | assert sr == self.sample_rate, f"Expected {file} to have sample rate of {self.sample_rate}, but got {sr}" 82 | return clean, noisy, denoised, file 83 | 84 | 85 | def _run_metrics(clean, estimate, noisy, args, sr): 86 | if estimate.shape[-1] < clean.shape[-1]: 87 | clean = clean[..., :estimate.shape[-1]] 88 | noisy = noisy[..., :estimate.shape[-1]] 89 | sisdr = scale_invariant_signal_distortion_ratio(estimate, clean) 90 | sisdr_noisy = scale_invariant_signal_distortion_ratio(noisy, clean) 91 | sisdri = sisdr - sisdr_noisy 92 | return sisdri.mean().item(), sisdr.mean().item() 93 | 94 | def _evaluate(clean_signals, noisy_signals, denoised_signals, filenames, data_dir, subdir, sample_rate, args): 95 | sisdri_all, sisdr_all, fnames = [], [], [] 96 | for clean, noisy, denoised, filename in zip(clean_signals, noisy_signals, denoised_signals, filenames): 97 | run = True 98 | if args.filter_silent: 99 | peaks = get_peaks(denoised.to('cpu').numpy().squeeze(), args.sample_rate) 100 | if len(peaks) == 0: 101 | run = False 102 | if run: 103 | sisdri_i, sisdr_i = _run_metrics(clean, denoised, noisy, args, sr=args.sample_rate) 104 | sisdri_all.append(sisdri_i) 105 | sisdr_all.append(sisdr_i) 106 | fnames.append(os.path.basename(filename)) 107 | return sisdri_all, sisdr_all, fnames 108 | 109 | def process(args): 110 | out_dir = args.data_dir 111 | assert os.path.exists(os.path.join(args.data_dir,'clean')), f"Directory {os.path.join(args.data_dir,'clean')} does not exist" 112 | subdirs = [name for name in os.listdir(os.path.join(args.data_dir,'denoised')) if os.path.isdir(os.path.join(args.data_dir,'denoised', name))] 113 | os.makedirs(os.path.join(args.data_dir, 'results'), exist_ok=True) 114 | for subdir in sorted(subdirs): 115 | print(f"Evaluating {subdir}") 116 | 117 | dset = EvaluationSet(args.data_dir, subdir, args.sample_rate) 118 | 119 | if dset is None: 120 | return 121 | loader = biodenoising.denoiser.distrib.loader(dset, batch_size=1) 122 | sisdri_all, sisdr_all, fns_all = [], [], [] 123 | df = pd.DataFrame(columns=['sisdri','sisdr','filename']) 124 | with ProcessPoolExecutor(args.num_workers) as pool: 125 | iterator = biodenoising.denoiser.utils.LogProgress(logger, loader, name="Evaluating files") 126 | pendings = [] 127 | for data in iterator: 128 | # Get batch data 129 | clean_signals, noisy_signals, denoised_signals, filenames = data 130 | clean_signals = torch.nan_to_num(clean_signals.to(args.device)) 131 | noisy_signals = torch.nan_to_num(noisy_signals.to(args.device)) 132 | denoised_signals = torch.nan_to_num(denoised_signals.to(args.device)) 133 | if args.device == 'cpu' and args.num_workers > 1: 134 | pendings.append( 135 | pool.submit(_evaluate, 136 | clean_signals, noisy_signals, denoised_signals, filenames, out_dir, subdir, args.sample_rate, args)) 137 | else: 138 | sisdri, sisdr, fns = _evaluate(clean_signals, noisy_signals, 
denoised_signals, filenames, out_dir, subdir, args.sample_rate, args) 139 | sisdri_all.extend(sisdri) 140 | sisdr_all.extend(sisdr) 141 | fns_all.extend(fns) 142 | 143 | if pendings: 144 | print('Waiting for pending jobs...') 145 | for pending in biodenoising.denoiser.utils.LogProgress(logger, pendings, updates=5, name="Evaluating files"): 146 | sisdri, sisdr, fns = pending.result() 147 | sisdri_all.extend(sisdri) 148 | sisdr_all.extend(sisdr) 149 | fns_all.extend(fns) 150 | df['sisdri'] = sisdri_all 151 | df['sisdr'] = sisdr_all 152 | df['filename'] = fns_all 153 | df.to_csv(os.path.join(args.data_dir, 'results', subdir+'.csv')) 154 | print("mean SDRi {} SDR {}".format(df['sisdri'].mean(), df['sisdr'].mean())) 155 | print("median SDRi {} SDR {}".format(df['sisdri'].median(), df['sisdr'].median())) 156 | print(f"Done evaluating {subdir}") 157 | 158 | 159 | if __name__ == "__main__": 160 | args = parser.parse_args() 161 | logging.basicConfig(stream=sys.stderr, level=args.verbose) 162 | logger.debug(args) 163 | process(args) 164 | -------------------------------------------------------------------------------- /biodenoising/denoiser/audio.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import json 3 | from pathlib import Path 4 | import math 5 | import os 6 | import sys 7 | import inspect 8 | import random 9 | import torch 10 | import torchaudio 11 | from torch.nn import functional as F 12 | 13 | from .dsp import convert_audio 14 | 15 | Info = namedtuple("Info", ["length", "sample_rate", "channels"]) 16 | 17 | 18 | def get_info(path): 19 | try: 20 | info = torchaudio.info(path) 21 | if hasattr(info, 'num_frames'): 22 | # new version of torchaudio 23 | return Info(info.num_frames, info.sample_rate, info.num_channels) 24 | else: 25 | siginfo = info[0] 26 | return Info(siginfo.length // siginfo.channels, siginfo.rate, siginfo.channels) 27 | except Exception as e: 28 | print("Error while loading {}: {}".format(path, e)) 29 | return None 30 | 31 | 32 | def find_audio_files(path, exts=[".wav"], progress=True): 33 | audio_files = [] 34 | for root, folders, files in os.walk(path, followlinks=True): 35 | for file in files: 36 | file = Path(root) / file 37 | if file.suffix.lower() in exts: 38 | audio_files.append(str(file.resolve())) 39 | return build_meta(audio_files, progress=progress) 40 | 41 | def build_meta(audio_files, progress=True): 42 | meta = [] 43 | for idx, file in enumerate(audio_files): 44 | info = get_info(file) 45 | if info is not None: 46 | meta.append((file, info.length)) 47 | if progress: 48 | print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr) 49 | meta.sort() 50 | return meta 51 | 52 | def repeat_and_pad(audio, duration_samples, repeat_prob=0.5, random_repeat=False, random_pad=False, random_obj=None): 53 | assert random_obj is not None, "random_obj should be provided" 54 | if random_obj.random() > repeat_prob: 55 | ### repeat 56 | if random_repeat: 57 | initial_length = audio.shape[-1] 58 | nrepeats = random_obj.randint(2,math.ceil(duration_samples/audio.shape[-1])) 59 | for i in range(nrepeats): 60 | gain = random_obj.uniform(0.1,1.3) 61 | audio[...,i*initial_length:(i+1)*initial_length] *= gain 62 | else: 63 | nrepeats = int(math.ceil(duration_samples/audio.shape[-1])) 64 | audio = audio.repeat(1,nrepeats) 65 | audio = audio[...,:duration_samples] 66 | if audio.shape[-1] < duration_samples: 67 | if random_pad: 68 | audio_final = 
torch.zeros((audio.shape[0],duration_samples), device=audio.device) 69 | ### pad 70 | pad_left = random_obj.randint(0,duration_samples - audio.shape[-1]) 71 | audio_final[...,pad_left:pad_left+audio.shape[-1]] = audio 72 | audio = audio_final 73 | pad_right = duration_samples - audio.shape[-1] - pad_left 74 | else: 75 | pad_left = int((duration_samples - audio.shape[-1])//2) 76 | pad_right = int(duration_samples - audio.shape[-1] - pad_left) 77 | audio = F.pad(audio, (pad_left, pad_right), mode='constant', value=0) 78 | return audio 79 | 80 | class Audioset(torch.utils.data.Dataset): 81 | def __init__(self, files=None, length=None, stride=None, 82 | pad=True, with_path=False, sample_rate=None, 83 | channels=None, convert=False, repeat_prob=0.5, 84 | random_repeat=False, random_pad=False, use_subset=False, 85 | random_obj=None, resample_to_sr=None): 86 | """ 87 | files should be a list [(file, length)] 88 | """ 89 | self.files = files 90 | self.with_path = with_path 91 | self.length = length 92 | self.num_examples = [] 93 | self.offsets = [] 94 | self.subsets = {} 95 | self.num_examples_subsets = {} 96 | self.offsets_subsets = {} 97 | self.use_subset = use_subset 98 | self.length = length 99 | self.stride = stride or length 100 | self.sample_rate = sample_rate 101 | self.channels = channels 102 | self.convert = convert 103 | self.repeat_prob = repeat_prob 104 | self.random_repeat = random_repeat 105 | self.random_pad = random_pad 106 | self.random_obj = random_obj if random_obj is not None else random.Random(0) 107 | self.use_old = True if 'frame_offset' in inspect.getfullargspec(torchaudio.load)[0] else False 108 | self.resample_to_sr = resample_to_sr 109 | 110 | for idx, (file, file_length) in enumerate(self.files): 111 | subset = os.path.basename(file).split('_')[0] 112 | if subset not in self.subsets.keys(): 113 | self.subsets[subset] = [] 114 | self.num_examples_subsets[subset] = [] 115 | self.offsets_subsets[subset] = [] 116 | if length is None: 117 | examples = 1 118 | elif file_length < length: 119 | examples = 1 #if pad else 0 120 | elif pad: 121 | examples = int(math.ceil((file_length - self.length) / self.stride) + 1) 122 | else: 123 | examples = (file_length - self.length) // self.stride + 1 124 | self.num_examples.append(examples) 125 | # self.num_frames.append(file_length) 126 | self.num_examples_subsets[subset].append(examples) 127 | self.subsets[subset].append((file,file_length)) 128 | for chunk_idx in range(examples): 129 | if self.stride is None: 130 | self.offsets.append((idx, 0)) 131 | self.offsets_subsets[subset].append((idx, 0)) 132 | else: 133 | self.offsets.append((idx, chunk_idx * self.stride)) 134 | self.offsets_subsets[subset].append((idx, chunk_idx * self.stride)) 135 | 136 | if len(self.subsets.keys()) > 1: 137 | self.subset = list(self.subsets.keys())[0] 138 | 139 | def __len__(self): 140 | if self.use_subset: 141 | return sum(self.num_examples_subsets[self.subset]) 142 | else: 143 | return sum(self.num_examples) 144 | 145 | def total_examples(self): 146 | if self.use_subset: 147 | return sum(sum(self.num_examples_subsets[subset]) for subset in self.subsets.keys()) 148 | else: 149 | return sum(self.num_examples) 150 | 151 | def set_subset(self, subset_id): 152 | subsets = list(self.subsets.keys()) 153 | subset = subsets[subset_id] 154 | self.subset = subset 155 | 156 | def __getitem__(self, index): 157 | if self.use_subset: 158 | file_idx, offset = self.offsets_subsets[self.subset][index] 159 | else: 160 | file_idx, offset = self.offsets[index] 161 | filename, 
num_frames = self.files[file_idx] 162 | nframes = -1 if self.length is None else self.length 163 | if self.use_old: 164 | out, sr = torchaudio.load(filename, frame_offset=offset, num_frames=nframes) 165 | else: 166 | out, sr = torchaudio.load(filename, offset=offset, num_frames=nframes) 167 | 168 | target_channels = self.channels or out.shape[0] 169 | if self.resample_to_sr and sr != self.resample_to_sr: 170 | target_sr = self.resample_to_sr 171 | # Use convert_audio function to resample 172 | out = convert_audio(out, sr, target_sr, target_channels) 173 | sr = target_sr 174 | if out.shape[0] != target_channels: 175 | ### take the mean of the channels 176 | out = out.mean(dim=0, keepdim=True) 177 | if nframes > out.shape[-1]: 178 | out = repeat_and_pad(out, self.length, repeat_prob=self.repeat_prob, random_repeat=self.random_repeat, random_pad=self.random_pad, random_obj=self.random_obj) 179 | if self.with_path: 180 | return out, filename, sr 181 | else: 182 | return out 183 | -------------------------------------------------------------------------------- /biodenoising/denoiser/evaluate.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 2 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser. 3 | 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # author: adiyoss 8 | 9 | import argparse 10 | from concurrent.futures import ProcessPoolExecutor 11 | import json 12 | import logging 13 | import sys 14 | import os 15 | 16 | import torch 17 | import torchmetrics 18 | from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio 19 | 20 | from .data import NoisyCleanSet 21 | from .enhance import add_flags, get_estimate 22 | from . import distrib, pretrained 23 | from .utils import bold, LogProgress 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | parser = argparse.ArgumentParser( 28 | 'denoiser.evaluate', 29 | description='Denoising using Demucs - Evaluate model performance') 30 | add_flags(parser) 31 | parser.add_argument('--data_dir', help='directory including noisy.json and clean.json files') 32 | parser.add_argument('--matching', default="sort", help='set this to dns for the DNS dataset.') 33 | parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG, 34 | default=logging.INFO, help="More logging") 35 | 36 | 37 | def evaluate(args, model=None, data_loader=None, experiment_logger=None, window_size=0, hop_size=0): 38 | total_sisdri = 0 39 | total_sisdr = 0 40 | total_sisdrn = 0 41 | all_sisdri = [] 42 | all_sisdr = [] 43 | all_sisdrn = [] 44 | total_cnt = 0 45 | updates = 5 46 | 47 | # Load model 48 | if not model: 49 | model = pretrained.get_model(args).to(args.device) 50 | model.eval() 51 | 52 | if window_size > 0: 53 | import asteroid 54 | ola_model = asteroid.dsp.overlap_add.LambdaOverlapAdd( 55 | nnet=model, # function to apply to each segment. 56 | n_src=1, # number of sources in the output of nnet 57 | window_size=window_size, # Size of segmenting window 58 | hop_size=hop_size, # segmentation hop size 59 | window="hann", # Type of the window (see scipy.signal.get_window) 60 | reorder_chunks=False, # Whether to reorder each consecutive segment.
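            # LambdaOverlapAdd splits the input into overlapping chunks of
            # `window_size` samples taken every `hop_size` samples, runs `nnet`
            # on each chunk, and overlap-adds the results with the Hann window,
            # which keeps memory bounded on long recordings.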
61 | enable_grad=False, # Set gradient calculation on of off (see torch.set_grad_enabled) 62 | ) 63 | ola_model.window = ola_model.window.to(args.device) 64 | else: 65 | ola_model = None 66 | 67 | # Load data 68 | if data_loader is None: 69 | dataset = NoisyCleanSet(args.data_dir, 70 | matching=args.matching, sample_rate=args.sample_rate, with_path=True) 71 | data_loader = distrib.loader(dataset, batch_size=1, num_workers=2) 72 | pendings = [] 73 | with ProcessPoolExecutor(args.num_workers) as pool: 74 | with torch.no_grad(): 75 | iterator = LogProgress(logger, data_loader, name="Eval estimates") 76 | for i, data in enumerate(iterator): 77 | # Get batch data 78 | tnoisy, tclean = [x for x in data] 79 | if len(tnoisy) > 1 and isinstance(tnoisy[1][0], str): 80 | noisy = tnoisy[0].to(args.device) 81 | clean = tclean[0].to(args.device) 82 | filename = str(os.path.basename(tnoisy[1][0]).rsplit(".", 1)[0]) 83 | else: 84 | noisy = tnoisy.to(args.device) 85 | clean = tclean.to(args.device) 86 | filename = str(i) 87 | #### If device is CPU, we do parallel evaluation in each CPU worker. 88 | if args.device == 'cpu': 89 | pendings.append( 90 | pool.submit(_estimate_and_run_metrics, clean, model, noisy, args, filename, experiment_logger)) 91 | else: 92 | if ola_model is not None and noisy.shape[-1] > (2*window_size): 93 | estimate = get_estimate(ola_model, noisy, args) 94 | else: 95 | estimate = get_estimate(model, noisy, args) 96 | sisdri_i, sisdr_i, sisdrn_i = _run_metrics(clean, estimate, noisy, args, sr=args.sample_rate, filename=filename, experiment_logger=experiment_logger) 97 | total_sisdri += sisdri_i 98 | total_sisdr += sisdr_i 99 | total_sisdrn += sisdrn_i 100 | all_sisdri.append(sisdri_i) 101 | all_sisdr.append(sisdr_i) 102 | all_sisdrn.append(sisdrn_i) 103 | total_cnt += clean.shape[0] 104 | 105 | if args.device == 'cpu': 106 | for pending in LogProgress(logger, pendings, updates, name="Eval metrics"): 107 | sisdri_i, sisdr_i, sisdrn_i = pending.result() 108 | total_sisdri += sisdri_i 109 | total_sisdr += sisdr_i 110 | total_sisdrn += sisdrn_i 111 | all_sisdri.append(sisdri_i) 112 | all_sisdr.append(sisdr_i) 113 | all_sisdrn.append(sisdrn_i) 114 | 115 | 116 | metrics = [total_sisdri, total_sisdr, total_sisdrn] 117 | sisdri_mean, sisdr_mean, sisdrn_mean = distrib.average([m/total_cnt for m in metrics], total_cnt) 118 | sisdri_median = torch.median(torch.tensor(all_sisdri)).item() 119 | sisdr_median = torch.median(torch.tensor(all_sisdr)).item() 120 | sisdrn_median = torch.median(torch.tensor(all_sisdrn)).item() 121 | logger.info(bold(f'SISDR performance: sisdri_mean={round(sisdri_mean,2)}, sisdr_mean={round(sisdr_mean,2)}, sisdri_median={round(sisdri_median,2)}, sisdr_median={round(sisdr_median,2)}')) 122 | return sisdri_mean, sisdr_mean, sisdri_median, sisdr_median, sisdrn_mean, sisdrn_median 123 | 124 | 125 | def _estimate_and_run_metrics(clean, model, noisy, args, filename, experiment_logger=None): 126 | estimate = get_estimate(model, noisy, args) 127 | return _run_metrics(clean, estimate, noisy, args, sr=args.sample_rate, filename=filename, experiment_logger=experiment_logger) 128 | 129 | 130 | def _run_metrics(clean, estimate, noisy, args, sr, filename, experiment_logger=None): 131 | sisdr_noisy = scale_invariant_signal_distortion_ratio(noisy, clean) 132 | sisdr = scale_invariant_signal_distortion_ratio(estimate[:,0:1,:], clean) 133 | sisdri = sisdr - sisdr_noisy 134 | metadata = {'sisdr': sisdr.mean().item(), 'sisdr_noisy': sisdr_noisy.mean().item(), 'sisdri': sisdri.mean().item()} 
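    # SI-SDRi is the improvement over the unprocessed input:
    # sisdri = SI-SDR(estimate, clean) - SI-SDR(noisy, clean), so positive
    # values mean the estimate is closer to the clean reference than the
    # noisy mixture was.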
135 | if estimate.shape[1] > 1: 136 | sisdr_noise = scale_invariant_signal_distortion_ratio(estimate[:,1:2,:], noisy-clean) 137 | metadata.update({'sisdr_noise': sisdr_noise.mean().item()}) 138 | if experiment_logger is not None: 139 | experiment_logger.log_audio(estimate[:,1:2,:].squeeze().detach().cpu().numpy(), 140 | sample_rate=sr, 141 | file_name=filename + "_noise_est.wav", 142 | metadata=metadata, overwrite=False, 143 | step=experiment_logger.step) 144 | else: 145 | sisdr_noise = torch.zeros_like(sisdr) 146 | if experiment_logger is not None: 147 | experiment_logger.log_audio(noisy.squeeze().detach().cpu().numpy(), 148 | sample_rate=sr, 149 | file_name=filename + "_noisy.wav", 150 | metadata=metadata, overwrite=False, 151 | step=experiment_logger.step) 152 | experiment_logger.log_audio(estimate[:,0,:].squeeze().detach().cpu().numpy(), 153 | sample_rate=sr, 154 | file_name=filename + "_enhanced.wav", 155 | metadata=metadata, overwrite=False, 156 | step=experiment_logger.step) 157 | experiment_logger.log_audio(clean.squeeze().detach().cpu().numpy(), 158 | sample_rate=sr, 159 | file_name=filename + "_clean.wav", 160 | metadata=metadata, overwrite=False, 161 | step=experiment_logger.step) 162 | return sisdri.mean().item(), sisdr.mean().item(), sisdr_noise.mean().item() 163 | 164 | 165 | def main(): 166 | args = parser.parse_args() 167 | logging.basicConfig(stream=sys.stderr, level=args.verbose) 168 | logger.debug(args) 169 | sisdri, sisdr, *_ = evaluate(args)  # evaluate() also returns medians and the noise SI-SDR 170 | json.dump({'sisdri': sisdri, 'sisdr': sisdr}, sys.stdout) 171 | sys.stdout.write('\n') 172 | 173 | 174 | if __name__ == '__main__': 175 | main() 176 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Adapted from https://github.com/facebookresearch/demucs under the MIT License 3 | # Original Copyright (c) Earth Species Project. This work is based on Facebook's denoiser. 4 | 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree.
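# Illustrative invocation (a sketch): train.py is a Hydra app configured by
# biodenoising/conf/config.yaml, so any config key can be overridden on the
# command line. The key names below follow their usage in run(); the paths
# and values are placeholders, not recommended settings.
#
#   python train.py model=demucs sample_rate=16000 batch_size=32 seed=0 \
#       dset.train=/path/to/train dset.valid=/path/to/valid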
8 | 9 | import logging 10 | import os 11 | 12 | import hydra 13 | import random 14 | import numpy as np 15 | import biodenoising 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | def run(args): 20 | experiment_logger = None 21 | if "cometml" in args: 22 | import comet_ml 23 | os.environ["COMET_API_KEY"] = args.cometml['api-key'] 24 | experiment_logger = comet_ml.Experiment(args.cometml['api-key'], project_name=args.cometml['project'], log_code=False) 25 | experiment_logger.log_parameters(args) 26 | experiment_name = os.path.basename(os.getcwd()) 27 | experiment_logger.set_name(experiment_name) 28 | 29 | import torch 30 | 31 | biodenoising.denoiser.distrib.init(args) 32 | 33 | ### set the random seed 34 | torch.manual_seed(args.seed) 35 | random.seed(args.seed) 36 | np.random.seed(args.seed) 37 | torch.backends.cudnn.deterministic = True 38 | torch.use_deterministic_algorithms(True) 39 | rng = random.Random(args.seed) 40 | rngnp = np.random.default_rng(seed=args.seed) 41 | 42 | def seed_worker(worker_id): 43 | worker_seed = torch.initial_seed() % 2**32 44 | np.random.seed(worker_seed) 45 | random.seed(worker_seed) 46 | 47 | g = torch.Generator() 48 | g.manual_seed(args.seed) 49 | rngth = torch.Generator(device=args.device) 50 | rngth.manual_seed(args.seed) 51 | 52 | 53 | if args.sample_rate == 48000: 54 | args.demucs.resample = 8 55 | if args.model=="demucs": 56 | if 'chout' in args.demucs: 57 | args.demucs['chout'] = args.demucs['chout']*args.nsources 58 | model = biodenoising.denoiser.demucs.Demucs(**args.demucs, sample_rate=args.sample_rate) 59 | if args.teacher_student: 60 | model_teacher = biodenoising.denoiser.demucs.Demucs(**args.demucs, sample_rate=args.sample_rate).to(torch.device("cpu")) 61 | elif args.model=="cleanunet": 62 | model = biodenoising.denoiser.cleanunet.CleanUNet(**args.cleanunet) 63 | if args.teacher_student: 64 | model_teacher = biodenoising.denoiser.cleanunet.CleanUNet(**args.cleanunet).to(torch.device("cpu")) 65 | 66 | if args.show: 67 | logger.info(model) 68 | mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20 69 | logger.info('Size: %.1f MB', mb) 70 | if hasattr(model, 'valid_length'): 71 | field = model.valid_length(1) 72 | logger.info('Field: %.1f ms', field / args.sample_rate * 1000) 73 | return 74 | 75 | assert args.batch_size % biodenoising.denoiser.distrib.world_size == 0 76 | args.batch_size //= biodenoising.denoiser.distrib.world_size 77 | length = int(args.segment * args.sample_rate) 78 | stride = int(args.stride * args.sample_rate) 79 | # Demucs requires a specific number of samples to avoid 0 padding during training 80 | if hasattr(model, 'valid_length'): 81 | length = model.valid_length(length) 82 | kwargs_valid = {"sample_rate": args.sample_rate,"seed": args.seed,"nsources": args.nsources,"exclude": args.exclude,"exclude_noise": args.exclude_noise, "rng":rng, "rngnp":rngnp, "rngth":rngth } 83 | kwargs_train = {"sample_rate": args.sample_rate,"seed": args.seed,"nsources": args.nsources,"exclude": args.exclude,"exclude_noise": args.exclude_noise, "rng":rng, "rngnp":rngnp, "rngth":rngth, 84 | 'repeat_prob': args.repeat_prob, 'random_repeat': args.random_repeat, 'random_pad': args.random_pad, 'silence_prob': args.silence_prob, 'noise_prob': args.noise_prob, 85 | 'normalize':args.normalize, 'random_gain':args.random_gain, 'low_gain':args.low_gain, 'high_gain':args.high_gain} 86 | if 'seed=' in args.dset.train: 87 | args.dset.train = args.dset.train.replace('seed=', f'seed={args.seed}') 88 | if args.continue_from and 'seed=' in 
args.continue_from: 89 | args.continue_from = args.continue_from.replace('seed=', f'seed={args.seed}') 90 | if args.continue_pretrained and 'seed=' in args.continue_pretrained: 91 | args.continue_pretrained = args.continue_pretrained.replace('seed=', f'seed={args.seed}') 92 | # Building datasets and loaders 93 | tr_dataset = biodenoising.datasets.NoiseClean1WeightedSet( 94 | args.dset.train, length=length, stride=stride, pad=args.pad, epoch_size=args.epoch_size, 95 | low_snr=args.dset.low_snr,high_snr=args.dset.high_snr,**kwargs_train) 96 | tr_loader = biodenoising.denoiser.distrib.loader( 97 | tr_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, worker_init_fn=seed_worker, generator=g) 98 | if args.dset.valid: 99 | # cv_dataset = biodenoising.denoiser.data.NoisyCleanSet(args.dset.valid, **kwargs) 100 | # cv_loader = biodenoising.denoiser.distrib.loader(cv_dataset, batch_size=1, num_workers=args.num_workers) 101 | cv_dataset = biodenoising.datasets.NoiseCleanValidSet( 102 | args.dset.valid, length=length, stride=0, pad=False, epoch_size=args.epoch_size, 103 | low_snr=args.dset.low_snr,high_snr=args.dset.high_snr,**kwargs_valid) 104 | cv_loader = biodenoising.denoiser.distrib.loader( 105 | cv_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers//4) 106 | else: 107 | cv_loader = None 108 | if args.dset.test: 109 | del kwargs_valid["exclude"] 110 | del kwargs_valid["exclude_noise"] 111 | del kwargs_valid["rng"] 112 | del kwargs_valid["rngnp"] 113 | del kwargs_valid["rngth"] 114 | if isinstance(args.dset.test, str): 115 | args.dset.test = {'biodenoising':args.dset.test} 116 | tt_dataset = {} 117 | tt_loader = {} 118 | for key, value in args.dset.test.items(): 119 | tt_dataset[key] = biodenoising.denoiser.data.NoisyCleanSet(value, stride=0, pad=False, with_path=True, **kwargs_valid) 120 | tt_loader[key] = biodenoising.denoiser.distrib.loader(tt_dataset[key], batch_size=1, shuffle=False, num_workers=args.num_workers//4) 121 | else: 122 | tt_loader = None 123 | data = {"tr_loader": tr_loader, "cv_loader": cv_loader, "tt_loader": tt_loader} 124 | 125 | if args.continue_pretrained: 126 | args.epochs = np.maximum(1, np.ceil(args.full_size / len(tr_loader.dataset))) 127 | else: 128 | args.epochs = np.maximum(1, np.ceil(args.full_size / len(tr_loader.dataset))) 129 | print("Train size", len(tr_loader.dataset)) 130 | # args.lr = args.lr * args.batch_size / 16 131 | if torch.cuda.is_available(): 132 | model.cuda() 133 | 134 | # optimizer 135 | if args.optim == "adam": 136 | optimizer = torch.optim.NAdam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) 137 | total_steps = int(args.epochs * len(tr_loader)) 138 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, total_steps=total_steps)#, cycle_momentum=False 139 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) 140 | elif args.optim == "lion": 141 | import lion_pytorch 142 | optimizer = lion_pytorch.Lion(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) 143 | total_steps = int(args.epochs * len(tr_loader)) 144 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, total_steps=total_steps)#, cycle_momentum=False 145 | else: 146 | logger.fatal('Invalid optimizer %s', args.optim) 147 | os._exit(1) 148 | 149 | # Construct Solver 150 | if args.teacher_student: 151 | solver = biodenoising.denoiser.solver.TeacherStudentSolver(data, model, model_teacher, optimizer, args, rng=rng, rngnp=rngnp, 
rngth=rngth, seed=args.seed, experiment_logger=experiment_logger, scheduler=scheduler) 152 | else: 153 | solver = biodenoising.denoiser.solver.Solver(data, model, optimizer, args, rng=rng, rngnp=rngnp, rngth=rngth, seed=args.seed, experiment_logger=experiment_logger, scheduler=scheduler) 154 | solver.train() 155 | 156 | 157 | def _main(args): 158 | global __file__ 159 | # Updating paths in config 160 | for key, value in args.dset.items(): 161 | if key=='test': 162 | ### replace all subkeys 163 | for k,v in value.items(): 164 | args.dset.test[k] = hydra.utils.to_absolute_path(v.replace('<>', os.getenv('USER'))) 165 | elif isinstance(value, str) and key not in ["matching"]: 166 | args.dset[key] = hydra.utils.to_absolute_path(value) 167 | args.continue_pretrained = args.continue_pretrained.replace('<>', os.getenv('USER')) 168 | __file__ = hydra.utils.to_absolute_path(__file__) 169 | if args.verbose: 170 | logger.setLevel(logging.DEBUG) 171 | logging.getLogger("denoise").setLevel(logging.DEBUG) 172 | 173 | logger.info("For logs, checkpoints and samples check %s", os.getcwd()) 174 | logger.debug(args) 175 | if args.ddp and args.rank is None: 176 | biodenoising.denoiser.executor.start_ddp_workers(args) 177 | else: 178 | run(args) 179 | 180 | 181 | @hydra.main(config_path="biodenoising/conf/config.yaml") 182 | def main(args): 183 | try: 184 | _main(args) 185 | except Exception: 186 | logger.exception("Some error happened") 187 | # Hydra intercepts exit code, fixed in beta but I could not get the beta to work 188 | os._exit(1) 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | -------------------------------------------------------------------------------- /results.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | import os 4 | import argparse 5 | from multiprocessing import Pool 6 | from tqdm import tqdm 7 | import pandas as pd 8 | import seaborn as sns 9 | import numpy as np 10 | import scipy 11 | import confidence_intervals 12 | from matplotlib import pyplot as plt 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--dataset_path", type=str, required=True, help="Path to biodenoising-validation" 18 | ) 19 | parser.add_argument("--num_processes",default=1,type=int,help="number of processes for multiprocessing") 20 | 21 | def mad(data, axis=None): 22 | return np.mean(np.absolute(data - np.mean(data, axis)), axis) 23 | 24 | def results(df): 25 | methods = ['demucs_demucs_noisy_step0','demucs_demucs_noisereduce_step0','demucs_demucs_none_step0','demucs_demucs_time_scale_step0','noisereduce'] 26 | df = df[df['method'].isin(methods)] 27 | print(df['method'].value_counts()) 28 | median_df = df.groupby(['method','metric'])['dB'].apply(np.median).apply(lambda x: round(x,2)).reset_index() 29 | mad_df = df.groupby(['method','metric'])['dB'].apply(scipy.stats.median_abs_deviation).apply(lambda x: round(x,2)).reset_index() 30 | print(median_df) 31 | print(mad_df) 32 | latex_row = ' & ' 33 | for method in methods: 34 | for metric in ['sisdr','sisdri']: 35 | data = df[(df['method'] == method) & (df['metric'] == metric)]['dB'].to_numpy() 36 | median, (ci_low, ci_high) = confidence_intervals.evaluate_with_conf_int(data, np.median, labels=None, conditions=None, num_bootstraps=1000, alpha=5) 37 | #latex_row += '\SI{' +str(median_df[(median_df['method'] == method) & (median_df['metric'] == metric)]['dB'].values[0]) + '}{} (\SI{' + str(mad_df[(mad_df['method'] == method) & (mad_df['metric'] == metric)]['dB'].values[0]) 
+ '}{}) & ' 38 | latex_row += '\SI{' + str(np.round(median,2)) + '}{} $\\frac{'+ str(np.round(ci_low,2)) + '}{' + str(np.round(ci_high,2)) +'}$ & ' 39 | print(latex_row) 40 | 41 | def results_diff(df,methods): 42 | df = df[df['seed'] == 0] 43 | df1 = df[df['method'].isin(methods.values())].reset_index(drop=True) 44 | df2 = df[df['method'].isin(methods.keys())].reset_index(drop=True) 45 | df2['method'] = df2['method'].map(methods) 46 | 47 | # print(df1['metric'].value_counts()) 48 | # for method in methods.values(): 49 | # print(method,scipy.stats.wilcoxon(df1[(df1['method'] == method) & (df1['metric'] == 'sisdri')]['dB'],df2[(df2['method'] == method) & (df2['metric'] == 'sisdri')]['dB'])) 50 | # print(method,scipy.stats.ttest_ind(df1[(df1['method'] == method) & (df1['metric'] == 'sisdr')]['dB'],df2[(df2['method'] == method) & (df2['metric'] == 'sisdr')]['dB'], equal_var=False)) 51 | # # # plt.hist(df2[df2['method'] == method]['dB'], bins=20, alpha=0.5, label=method) 52 | # # # plt.show() 53 | # scipy.stats.probplot(df2[(df2['method'] == method )& (df2['metric'] == 'sisdri')]['dB'], dist="norm", plot=plt) 54 | # plt.title("shapiro {}".format(scipy.stats.shapiro(df2[(df2['method'] == method) & (df2['metric'] == 'sisdri')]['dB']))) 55 | # plt.show() 56 | # # scipy.stats.probplot(df1[df1['method'] == method& df1['metric'] == 'sisdri']['dB'], dist="norm", plot=plt) 57 | # # plt.title("shapiro {}".format(scipy.stats.shapiro(df1[df1['method'] == method & df1['metric'] == 'sisdri']['dB']))) 58 | # # plt.show() 59 | 60 | print(df1['method'].value_counts()) 61 | print(df2['method'].value_counts()) 62 | median_df = df2.groupby(['method','metric'])['dB'].apply(np.median).apply(lambda x: round(x,2)).reset_index() 63 | mad_df = df2.groupby(['method','metric'])['dB'].apply(scipy.stats.median_abs_deviation).apply(lambda x: round(x,2)).reset_index() 64 | print(median_df) 65 | print(mad_df) 66 | df1 = df1.set_index(['method','metric','filename']) 67 | df2 = df2.set_index(['method','metric','filename']) 68 | diff_df = df2.subtract(df1) 69 | # print(len(diff_df.reset_index()[diff_df.reset_index()['dB']<-0.1 & diff_df.reset_index()['metric'] == 'sisdri']),len(diff_df.reset_index()[diff_df.reset_index()['metric'] == 'sisdri'])) 70 | # import pdb; pdb.set_trace() 71 | # diff_df = (df1.groupby(['method','metric','filename'])['dB'].mean() - df2.groupby(['method','metric','filename'])['dB'].mean()).reset_index() 72 | # mean_df = diff_df.groupby(['method','metric'])['dB'].apply(np.mean).apply(lambda x: round(x,2)).reset_index() 73 | mean_df = diff_df.groupby(['method','metric'])['dB'].apply(np.median).apply(lambda x: round(x,2)).reset_index() 74 | mad_df = diff_df.groupby(['method','metric'])['dB'].apply(scipy.stats.median_abs_deviation).apply(lambda x: round(x,2)).reset_index() 75 | print(mean_df) 76 | print(mad_df) 77 | diff_df = diff_df.reset_index() 78 | latex_row = '' 79 | for metric in ['sisdr','sisdri']: 80 | latex_row += '\\\\' + metric + ' & ' 81 | for method in methods.values(): 82 | #latex_row += '\SI{' +str(mean_df[(mean_df['method'] == method) & (mean_df['metric'] == metric)]['dB'].values[0]) + '}{} (\SI{' + str(mad_df[(mad_df['method'] == method) & (mad_df['metric'] == metric)]['dB'].values[0]) + '}{}) & ' 83 | data = diff_df[(diff_df['method'] == method) & (diff_df['metric'] == metric)]['dB'].to_numpy() 84 | median, (ci_low, ci_high) = confidence_intervals.evaluate_with_conf_int(data, np.median, labels=None, conditions=None, num_bootstraps=1000, alpha=5) 85 | latex_row += '\SI{' + 
str(np.round(median,2)) + '}{} $\\frac{'+ str(np.round(ci_low,2)) + '}{' + str(np.round(ci_high,2)) +'}$ & ' 86 | 87 | print(latex_row) 88 | 89 | 90 | def process_file(args): 91 | filename, conf = args 92 | method = filename.split('.csv')[0] 93 | seed = None 94 | if ',seed=' in filename: 95 | seed = int(method.split(',seed=')[1].split(',')[0]) 96 | suffix = '' 97 | if len(method.split(',seed=')[1].split(',')) > 1: 98 | suffix = ","+method.split(',seed=')[1].split(',')[1] 99 | method = method.split(',seed=')[0] + suffix 100 | 101 | print("Processing file {}".format(filename)) 102 | #read a csv file into a pandas dataframe and return it 103 | df = pd.read_csv(os.path.join(conf["subset_path"],filename),usecols=[1,2,3]) 104 | df = pd.melt(df, id_vars=["filename"], var_name="metric", value_name="dB") 105 | df['seed'] = seed 106 | df['method'] = method 107 | assert len(df) == 124 or len(df) == 2568 , "Data is missing from csv {} data shape {}".format(filename, len(df)) 108 | return df.reset_index(drop=True) 109 | 110 | def process_folder(arg_dic): 111 | files = [file for file in os.listdir(os.path.join(arg_dic["subset_path"])) if file.endswith('.csv') and not file.startswith('.')] 112 | files = sorted(files) 113 | assert len(files) > 0, "No csv files found in the results folder" 114 | if arg_dic["num_processes"] > 1: 115 | with Pool(processes=arg_dic["num_processes"]) as pool: 116 | mp_args = [[f,arg_dic] for f in files] 117 | results = tqdm(pool.map(process_file, mp_args), total=len(files)) 118 | df = pd.concat(results) 119 | else: 120 | for i,f in enumerate(files): 121 | result = process_file([f,arg_dic]) 122 | if i == 0: 123 | df = result 124 | else: 125 | df = pd.concat([df,result]) 126 | return df 127 | 128 | if __name__ == "__main__": 129 | args = parser.parse_args() 130 | arg_dic = dict(vars(args)) 131 | 132 | for subset in ["16000","16000_large","16000_snr_experiments"]: 133 | print("Processing subset {}".format(subset)) 134 | if subset == "16000_snr_experiments": 135 | for snr in [-5,0,5,10]: 136 | print("Processing snr {}".format(snr)) 137 | arg_dic["subset_path"] = os.path.join(arg_dic["dataset_path"],subset,str(snr),"results") 138 | df = process_folder(arg_dic) 139 | mean_df = df.groupby(['method','metric','filename'])['dB'].mean().reset_index() 140 | results(df) 141 | else: 142 | arg_dic["subset_path"] = os.path.join(arg_dic["dataset_path"],subset,"results") 143 | df = process_folder(arg_dic) 144 | mean_df = df.groupby(['method','metric','filename'])['dB'].mean().reset_index() 145 | results(df) 146 | if subset == "16000": 147 | print("Ablation small") 148 | methods = {'demucs_demucs_noisy_small_step0':'demucs_demucs_noisy_step0','demucs_demucs_noisereduce_small_step0':'demucs_demucs_noisereduce_step0','demucs_demucs_none_small_step0':'demucs_demucs_none_step0','demucs_demucs_time_scale_small_step0':'demucs_demucs_time_scale_step0'} 149 | results_diff(df, methods) 150 | print("Ablation random") 151 | methods = {'demucs_random_noisy_step0':'demucs_demucs_noisy_step0','demucs_random_noisereduce_step0':'demucs_demucs_noisereduce_step0','demucs_random_none_step0':'demucs_demucs_none_step0','demucs_random_time_scale_step0':'demucs_demucs_time_scale_step0'} 152 | results_diff(df, methods) 153 | print("Ablation cleanunet") 154 | methods = 
{'cleanunet_cleanunet_noisy_step0':'demucs_demucs_noisy_step0','cleanunet_cleanunet_noisereduce_step0':'demucs_demucs_noisereduce_step0','demucs_random_none_step0':'demucs_demucs_none_step0','cleanunet_cleanunet_demucs_time_scale_step0':'demucs_demucs_time_scale_step0'} 155 | results_diff(df, methods) 156 | print("Ablation time scale") 157 | methods = {'demucs_demucs_noisy_step0,timescale=0':'demucs_demucs_noisy_step0','demucs_demucs_noisereduce_step0,timescale=0':'demucs_demucs_noisereduce_step0','demucs_demucs_none_step0,timescale=0':'demucs_demucs_none_step0','demucs_demucs_time_scale_step0,timescale=0':'demucs_demucs_time_scale_step0'} 158 | results_diff(df, methods) 159 | print("Ablation small exclude") 160 | methods = {'demucs_demucs_noisy_excl_step0':'demucs_demucs_noisy_step0','demucs_demucs_noisereduce_excl_step0':'demucs_demucs_noisereduce_step0','demucs_demucs_none_excl_step0':'demucs_demucs_none_step0','demucs_demucs_time_scale_excl_step0':'demucs_demucs_time_scale_step0'} 161 | results_diff(df, methods) -------------------------------------------------------------------------------- /biodenoising/datasets/gng.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from scipy import spatial 4 | import networkx as nx 5 | import matplotlib.pyplot as plt 6 | from sklearn import decomposition 7 | 8 | __authors__ = 'Adrien Guille' 9 | 10 | ''' 11 | Simple implementation of the Growing Neural Gas algorithm, based on: 12 | A Growing Neural Gas Network Learns Topologies. B. Fritzke, Advances in Neural 13 | Information Processing Systems 7, 1995. 14 | ''' 15 | 16 | 17 | class GrowingNeuralGas: 18 | 19 | def __init__(self, input_data): 20 | self.network = None 21 | self.data = input_data 22 | self.units_created = 0 23 | plt.style.use('ggplot') 24 | 25 | def find_nearest_units(self, observation): 26 | distance = [] 27 | for u, attributes in self.network.nodes(data=True): 28 | vector = attributes['vector'] 29 | dist = spatial.distance.euclidean(vector, observation) 30 | distance.append((u, dist)) 31 | distance.sort(key=lambda x: x[1]) 32 | ranking = [u for u, dist in distance] 33 | return ranking 34 | 35 | def prune_connections(self, a_max): 36 | nodes_to_remove = [] 37 | for u, v, attributes in self.network.edges(data=True): 38 | if attributes['age'] > a_max: 39 | nodes_to_remove.append((u, v)) 40 | for u, v in nodes_to_remove: 41 | self.network.remove_edge(u, v) 42 | 43 | nodes_to_remove = [] 44 | for u in self.network.nodes(): 45 | if self.network.degree(u) == 0: 46 | nodes_to_remove.append(u) 47 | for u in nodes_to_remove: 48 | self.network.remove_node(u) 49 | 50 | def fit_network(self, e_b, e_n, a_max, l, a, d, passes=1, plot_evolution=False): 51 | # logging variables 52 | accumulated_local_error = [] 53 | global_error = [] 54 | network_order = [] 55 | network_size = [] 56 | total_units = [] 57 | self.units_created = 0 58 | # 0. start with two units a and b at random position w_a and w_b 59 | w_a = [np.random.uniform(-2, 2) for _ in range(np.shape(self.data)[1])] 60 | w_b = [np.random.uniform(-2, 2) for _ in range(np.shape(self.data)[1])] 61 | self.network = nx.Graph() 62 | self.network.add_node(self.units_created, vector=w_a, error=0) 63 | self.units_created += 1 64 | self.network.add_node(self.units_created, vector=w_b, error=0) 65 | self.units_created += 1 66 | # 1. 
iterate through the data 67 | sequence = 0 68 | for p in range(passes): 69 | # print(' Pass #%d' % (p + 1)) 70 | np.random.shuffle(self.data) 71 | steps = 0 72 | for observation in self.data: 73 | # 2. find the nearest unit s_1 and the second nearest unit s_2 74 | nearest_units = self.find_nearest_units(observation) 75 | s_1 = nearest_units[0] 76 | s_2 = nearest_units[1] 77 | # 3. increment the age of all edges emanating from s_1 78 | for u, v, attributes in self.network.edges(data=True, nbunch=[s_1]): 79 | self.network.add_edge(u, v, age=attributes['age']+1) 80 | # 4. add the squared distance between the observation and the nearest unit in input space 81 | self.network.nodes[s_1]['error'] += spatial.distance.euclidean(observation, self.network.nodes[s_1]['vector'])**2 82 | # 5 .move s_1 and its direct topological neighbors towards the observation by the fractions 83 | # e_b and e_n, respectively, of the total distance 84 | update_w_s_1 = e_b * \ 85 | (np.subtract(observation, 86 | self.network.nodes[s_1]['vector'])) 87 | self.network.nodes[s_1]['vector'] = np.add( 88 | self.network.nodes[s_1]['vector'], update_w_s_1) 89 | 90 | for neighbor in self.network.neighbors(s_1): 91 | update_w_s_n = e_n * \ 92 | (np.subtract(observation, 93 | self.network.nodes[neighbor]['vector'])) 94 | self.network.nodes[neighbor]['vector'] = np.add( 95 | self.network.nodes[neighbor]['vector'], update_w_s_n) 96 | # 6. if s_1 and s_2 are connected by an edge, set the age of this edge to zero 97 | # if such an edge doesn't exist, create it 98 | self.network.add_edge(s_1, s_2, age=0) 99 | # 7. remove edges with an age larger than a_max 100 | # if this results in units having no emanating edges, remove them as well 101 | self.prune_connections(a_max) 102 | # 8. if the number of steps so far is an integer multiple of parameter l, insert a new unit 103 | steps += 1 104 | if steps % l == 0: 105 | if plot_evolution: 106 | self.plot_network('visualization/sequence/' + str(sequence) + '.png') 107 | sequence += 1 108 | # 8.a determine the unit q with the maximum accumulated error 109 | q = 0 110 | error_max = 0 111 | for u in self.network.nodes(): 112 | if self.network.nodes[u]['error'] > error_max: 113 | error_max = self.network.nodes[u]['error'] 114 | q = u 115 | # 8.b insert a new unit r halfway between q and its neighbor f with the largest error variable 116 | f = -1 117 | largest_error = -1 118 | for u in self.network.neighbors(q): 119 | if self.network.nodes[u]['error'] > largest_error: 120 | largest_error = self.network.nodes[u]['error'] 121 | f = u 122 | w_r = 0.5 * (np.add(self.network.nodes[q]['vector'], self.network.nodes[f]['vector'])) 123 | r = self.units_created 124 | self.units_created += 1 125 | # 8.c insert edges connecting the new unit r with q and f 126 | # remove the original edge between q and f 127 | self.network.add_node(r, vector=w_r, error=0) 128 | self.network.add_edge(r, q, age=0) 129 | self.network.add_edge(r, f, age=0) 130 | self.network.remove_edge(q, f) 131 | # 8.d decrease the error variables of q and f by multiplying them with a 132 | # initialize the error variable of r with the new value of the error variable of q 133 | self.network.nodes[q]['error'] *= a 134 | self.network.nodes[f]['error'] *= a 135 | self.network.nodes[r]['error'] = self.network.nodes[q]['error'] 136 | # 9. 
decrease all error variables by multiplying them with a constant d 137 | error = 0 138 | for u in self.network.nodes(): 139 | error += self.network.nodes[u]['error'] 140 | accumulated_local_error.append(error) 141 | network_order.append(self.network.order()) 142 | network_size.append(self.network.size()) 143 | total_units.append(self.units_created) 144 | for u in self.network.nodes(): 145 | self.network.nodes[u]['error'] *= d 146 | if self.network.degree(nbunch=[u]) == 0: 147 | print(u) 148 | global_error.append(self.compute_global_error()) 149 | # plt.clf() 150 | # plt.title('Accumulated local error') 151 | # plt.xlabel('iterations') 152 | # plt.plot(range(len(accumulated_local_error)), accumulated_local_error) 153 | # plt.savefig('visualization/accumulated_local_error.png') 154 | # plt.clf() 155 | # plt.title('Global error') 156 | # plt.xlabel('passes') 157 | # plt.plot(range(len(global_error)), global_error) 158 | # plt.savefig('visualization/global_error.png') 159 | # plt.clf() 160 | # plt.title('Neural network properties') 161 | # plt.plot(range(len(network_order)), network_order, label='Network order') 162 | # plt.plot(range(len(network_size)), network_size, label='Network size') 163 | # plt.legend() 164 | # plt.savefig('visualization/network_properties.png') 165 | 166 | def plot_network(self, file_path): 167 | plt.clf() 168 | plt.scatter(self.data[:, 0], self.data[:, 1]) 169 | node_pos = {} 170 | for u in self.network.nodes(): 171 | vector = self.network.nodes[u]['vector'] 172 | node_pos[u] = (vector[0], vector[1]) 173 | nx.draw(self.network, pos=node_pos) 174 | plt.draw() 175 | plt.savefig(file_path) 176 | 177 | def number_of_clusters(self): 178 | return nx.number_connected_components(self.network) 179 | 180 | def cluster_data(self): 181 | unit_to_cluster = np.zeros(self.units_created) 182 | cluster = 0 183 | for c in nx.connected_components(self.network): 184 | for unit in c: 185 | unit_to_cluster[unit] = cluster 186 | cluster += 1 187 | clustered_data = [] 188 | for observation in self.data: 189 | nearest_units = self.find_nearest_units(observation) 190 | s = nearest_units[0] 191 | clustered_data.append((observation, unit_to_cluster[s])) 192 | return clustered_data 193 | 194 | def reduce_dimension(self, clustered_data): 195 | transformed_clustered_data = [] 196 | svd = decomposition.PCA(n_components=2) 197 | transformed_observations = svd.fit_transform(self.data) 198 | for i in range(len(clustered_data)): 199 | transformed_clustered_data.append((transformed_observations[i], clustered_data[i][1])) 200 | return transformed_clustered_data 201 | 202 | def plot_clusters(self, clustered_data): 203 | number_of_clusters = nx.number_connected_components(self.network) 204 | plt.clf() 205 | plt.title('Cluster affectation') 206 | color = ['r', 'b', 'g', 'k', 'm', 'r', 'b', 'g', 'k', 'm'] 207 | for i in range(number_of_clusters): 208 | observations = [observation for observation, s in clustered_data if s == i] 209 | if len(observations) > 0: 210 | observations = np.array(observations) 211 | plt.scatter(observations[:, 0], observations[:, 1], color=color[i], label='cluster #'+str(i)) 212 | plt.legend() 213 | plt.savefig('visualization/clusters.png') 214 | 215 | def compute_global_error(self): 216 | global_error = 0 217 | for observation in self.data: 218 | nearest_units = self.find_nearest_units(observation) 219 | s_1 = nearest_units[0] 220 | global_error += spatial.distance.euclidean(observation, self.network.nodes[s_1]['vector'])**2 221 | return global_error 222 | 223 | 224 | 
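# A minimal usage sketch of the class above (illustrative only; the
# hyperparameter values follow Fritzke's 1995 paper and the reference
# implementation, and are assumptions rather than settings used elsewhere).
if __name__ == '__main__':
    # two Gaussian blobs as toy observations
    toy = np.concatenate([
        np.random.normal(loc=-1.0, scale=0.3, size=(200, 2)),
        np.random.normal(loc=1.0, scale=0.3, size=(200, 2)),
    ])
    gng = GrowingNeuralGas(toy)
    # e_b/e_n: winner/neighbour learning rates, a_max: maximum edge age,
    # l: unit-insertion period in steps, a/d: error decay factors
    gng.fit_network(e_b=0.1, e_n=0.006, a_max=10, l=200, a=0.5, d=0.995, passes=8)
    print('clusters found:', gng.number_of_clusters())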
-------------------------------------------------------------------------------- /scripts/Biodenoising_denoise_zip_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Biodenoising Denoise Demo (ZIP input)\n", 8 | "\n", 9 | "This notebook lets you:\n", 10 | "- Upload a ZIP archive of audio files (e.g., WAV/FLAC)\n", 11 | "- Extract into a timestamped folder\n", 12 | "- Run denoising using `denoise.py` on the extracted folder\n", 13 | "\n", 14 | "Works in Jupyter/Colab. If using Colab, enable GPU for speed.\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# Optional: install biodenoising (marius/fixes) + runtime deps on Colab\n", 24 | "from __future__ import annotations\n", 25 | "import sys\n", 26 | "import subprocess\n", 27 | "import importlib\n", 28 | "\n", 29 | "IN_COLAB = \"google.colab\" in sys.modules\n", 30 | "BRANCH_URL = \"git+https://github.com/earthspecies/biodenoising@marius/fixes\"\n", 31 | "\n", 32 | "# Base deps\n", 33 | "if IN_COLAB:\n", 34 | " try:\n", 35 | " import torch # type: ignore\n", 36 | " import torchaudio # type: ignore\n", 37 | " import soundfile # type: ignore\n", 38 | " import yaml # type: ignore\n", 39 | " import pandas # type: ignore\n", 40 | " except Exception:\n", 41 | " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"torch\", \"torchaudio\", \"--index-url\", \"https://download.pytorch.org/whl/cu121\"], check=True)\n", 42 | " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"soundfile\", \"pyyaml\", \"pandas\", \"numpy\", \"scipy\", \"tqdm\", \"librosa\"], check=True)\n", 43 | "\n", 44 | "# Ensure biodenoising is installed/updated from branch\n", 45 | "try:\n", 46 | " import biodenoising # type: ignore\n", 47 | " print(\"Found biodenoising. 
Ensuring branch version...\")\n", 48 | " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--upgrade\", \"--no-cache-dir\", BRANCH_URL], check=True)\n", 49 | " biodenoising = importlib.reload(biodenoising) # type: ignore\n", 50 | " print(\"Ensured biodenoising from marius/fixes branch.\")\n", 51 | "except Exception:\n", 52 | " print(\"Installing biodenoising from GitHub branch marius/fixes...\")\n", 53 | " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--no-cache-dir\", BRANCH_URL], check=True)\n", 54 | " import biodenoising # type: ignore\n", 55 | " print(\"Installed biodenoising from marius/fixes branch.\")\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Setup: environment and paths (Colab/Jupyter-friendly)\n", 65 | "from __future__ import annotations\n", 66 | "import os\n", 67 | "import sys\n", 68 | "import time\n", 69 | "from pathlib import Path\n", 70 | "\n", 71 | "# Use current working directory so this works on Colab (/content) and Jupyter\n", 72 | "WORKSPACE = Path.cwd().resolve()\n", 73 | "print(f\"CWD: {Path.cwd()}\")\n", 74 | "print(\"Python:\", sys.version)\n", 75 | "\n", 76 | "# Create base dirs for this demo\n", 77 | "BASE = WORKSPACE / \"scripts\"\n", 78 | "UPLOADS_DIR = BASE / \"denoise_uploads\"\n", 79 | "OUTPUTS_DIR = BASE / \"denoise_outputs\"\n", 80 | "for d in (UPLOADS_DIR, OUTPUTS_DIR):\n", 81 | " d.mkdir(parents=True, exist_ok=True)\n", 82 | "print(\"Uploads dir:\", UPLOADS_DIR)\n", 83 | "print(\"Outputs dir:\", OUTPUTS_DIR)\n" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "# Upload ZIP (Colab or Jupyter)\n", 93 | "from __future__ import annotations\n", 94 | "from datetime import datetime\n", 95 | "import zipfile\n", 96 | "\n", 97 | "zip_path = None\n", 98 | "try:\n", 99 | " from google.colab import files as colab_files # type: ignore\n", 100 | " print(\"Detected Colab. 
Use chooser to upload a ZIP.\")\n", 101 | " uploaded = colab_files.upload()\n", 102 | " if uploaded:\n", 103 | " name = next(iter(uploaded.keys()))\n", 104 | " data = uploaded[name]\n", 105 | " ts = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", 106 | " zip_path = UPLOADS_DIR / f\"upload_{ts}.zip\"\n", 107 | " with open(zip_path, \"wb\") as f:\n", 108 | " f.write(data)\n", 109 | " print(\"Saved:\", zip_path)\n", 110 | "except Exception as e:\n", 111 | " print(\"Colab uploader unavailable:\", repr(e))\n", 112 | "\n", 113 | "if zip_path is None:\n", 114 | " try:\n", 115 | " import ipywidgets as widgets # type: ignore\n", 116 | " from IPython.display import display # type: ignore\n", 117 | "\n", 118 | " file_uploader = widgets.FileUpload(accept=\".zip\", multiple=False)\n", 119 | " display(file_uploader)\n", 120 | " print(\"Use the widget above, then re-run this cell once.\")\n", 121 | " if file_uploader.value:\n", 122 | " item = list(file_uploader.value.values())[0]\n", 123 | " ts = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", 124 | " zip_path = UPLOADS_DIR / f\"upload_{ts}.zip\"\n", 125 | " with open(zip_path, \"wb\") as f:\n", 126 | " f.write(item[\"content\"]) # type: ignore[index]\n", 127 | " print(\"Saved:\", zip_path)\n", 128 | " except Exception as e:\n", 129 | " print(\"ipywidgets uploader unavailable:\", repr(e))\n", 130 | "\n", 131 | "zip_path\n" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "# Extract ZIP to a timestamped folder under scripts/denoise_uploads\n", 141 | "from __future__ import annotations\n", 142 | "import time\n", 143 | "import zipfile\n", 144 | "from pathlib import Path\n", 145 | "\n", 146 | "assert zip_path is not None and Path(zip_path).exists(), \"zip_path must be set to an existing file\"\n", 147 | "\n", 148 | "extract_ts = time.strftime(\"%Y%m%d_%H%M%S\")\n", 149 | "extract_dir = UPLOADS_DIR / f\"unzipped_{extract_ts}\"\n", 150 | "extract_dir.mkdir(parents=True, exist_ok=True)\n", 151 | "\n", 152 | "with zipfile.ZipFile(zip_path, \"r\") as zf:\n", 153 | " zf.extractall(extract_dir)\n", 154 | "\n", 155 | "AUDIO_EXTS = {\".wav\", \".flac\", \".mp3\", \".ogg\", \".m4a\", \".aac\"}\n", 156 | "audio_files = [p for p in extract_dir.rglob(\"*\") if p.suffix.lower() in AUDIO_EXTS]\n", 157 | "print(f\"Extracted to: {extract_dir}\")\n", 158 | "print(f\"Found {len(audio_files)} audio files (recursively)\")\n", 159 | "\n", 160 | "extract_dir\n" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "# Configure and run denoising via denoise.py\n", 170 | "from __future__ import annotations\n", 171 | "from pathlib import Path\n", 172 | "import torch # type: ignore\n", 173 | "\n", 174 | "from biodenoising import denoiser as dn # package import\n", 175 | "from biodenoising.denoiser import denoise as denoise_cli # packaged denoise module\n", 176 | "\n", 177 | "# Build args using the denoise parser defaults\n", 178 | "args = denoise_cli.parser.parse_args([])\n", 179 | "\n", 180 | "# Required IO\n", 181 | "args.noisy_dir = str(extract_dir)\n", 182 | "run_ts = time.strftime(\"%Y%m%d_%H%M%S\")\n", 183 | "run_out_dir = OUTPUTS_DIR / f\"denoise_run_{run_ts}\"\n", 184 | "run_out_dir.mkdir(parents=True, exist_ok=True)\n", 185 | "args.out_dir = str(run_out_dir)\n", 186 | "\n", 187 | "# Device/method\n", 188 | "args.method = getattr(args, \"method\", \"biodenoising16k_dns48\")\n", 189 | 
"args.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 190 | "\n", 191 | "# Optional flags (sane defaults for batch denoising)\n", 192 | "args.transform = getattr(args, \"transform\", \"none\")\n", 193 | "args.keep_original_sr = getattr(args, \"keep_original_sr\", True)\n", 194 | "args.noise_reduce = getattr(args, \"noise_reduce\", False)\n", 195 | "args.selection_table = getattr(args, \"selection_table\", False)\n", 196 | "args.window_size = getattr(args, \"window_size\", 0)\n", 197 | "args.batch_size = getattr(args, \"batch_size\", 1)\n", 198 | "args.num_workers = getattr(args, \"num_workers\", 2)\n", 199 | "\n", 200 | "print({\n", 201 | " \"noisy_dir\": args.noisy_dir,\n", 202 | " \"out_dir\": args.out_dir,\n", 203 | " \"method\": args.method,\n", 204 | " \"device\": args.device,\n", 205 | " \"transform\": args.transform,\n", 206 | " \"keep_original_sr\": args.keep_original_sr,\n", 207 | " \"window_size\": args.window_size,\n", 208 | "})\n", 209 | "\n", 210 | "# Patch get_dataset in packaged denoise to avoid NameError on global args\n", 211 | "if not hasattr(denoise_cli, \"_patched_get_dataset\"):\n", 212 | " def _patched_get_dataset(noisy_dir, sample_rate, channels, keep_original_sr):\n", 213 | " resample_to_sr = sample_rate if not keep_original_sr else None\n", 214 | " if noisy_dir:\n", 215 | " files = dn.audio.find_audio_files(noisy_dir)\n", 216 | " else:\n", 217 | " print(\"No noisy_dir provided; skipping denoising.\")\n", 218 | " return None\n", 219 | " return dn.audio.Audioset(\n", 220 | " files,\n", 221 | " with_path=True,\n", 222 | " sample_rate=sample_rate,\n", 223 | " channels=channels,\n", 224 | " convert=True,\n", 225 | " resample_to_sr=resample_to_sr,\n", 226 | " )\n", 227 | " denoise_cli._patched_get_dataset = True\n", 228 | " denoise_cli.get_dataset = _patched_get_dataset\n", 229 | "\n", 230 | "# Run denoising\n", 231 | "# The CLI's entrypoint is denoise.denoise(args)\n", 232 | "denoise_cli.denoise(args, local_out_dir=args.out_dir)\n", 233 | "print(\"Denoising completed.\")\n", 234 | "run_out_dir\n" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "# Inspect outputs\n", 244 | "from __future__ import annotations\n", 245 | "from pathlib import Path\n", 246 | "\n", 247 | "out_dir = run_out_dir\n", 248 | "print(\"Output dir:\", out_dir)\n", 249 | "\n", 250 | "if out_dir.exists():\n", 251 | " enhanced = sorted(Path(out_dir).rglob(\"*.wav\"))\n", 252 | " print(f\"Enhanced WAV files: {len(enhanced)}\")\n", 253 | " for p in enhanced[:10]:\n", 254 | " print(\"-\", p.relative_to(out_dir))\n", 255 | "else:\n", 256 | " print(\"No outputs found.\")\n" 257 | ] 258 | } 259 | ], 260 | "metadata": { 261 | "language_info": { 262 | "name": "python" 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 2 267 | } 268 | -------------------------------------------------------------------------------- /biodenoising/denoiser/denoise.py: -------------------------------------------------------------------------------- 1 | ''' 2 | python denoise.py --input /home/marius/data/biodenoising_validation/16000/noisy/ --output /home/marius/data/biodenoising_validation/16000/denoised/ 3 | ''' 4 | import argparse 5 | from concurrent.futures import ProcessPoolExecutor 6 | import logging 7 | import os 8 | import sys 9 | import numpy as np 10 | 11 | import torch 12 | import torchaudio 13 | import librosa 14 | import norbert 15 | 16 | from biodenoising.selection_table import ( 17 | 
build_mask_from_events, 18 | find_selection_table_for, 19 | load_events_seconds, 20 | ) 21 | 22 | from .demucs import DemucsStreamer 23 | from .pretrained import add_model_flags, get_model 24 | from .distrib import rank, loader, barrier 25 | from .audio import find_audio_files, build_meta, Audioset 26 | from .utils import LogProgress 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | ALLOWED_EXTENSIONS = set(['.wav','.mp3','.flac','.ogg','.aif','.aiff','.wmv','.WAV','.MP3','.FLAC','.OGG','.AIF','.AIFF','.WMV']) 31 | 32 | def add_flags(parser): 33 | """ 34 | Add the flags to the argument parser that are related to model loading and evaluation 35 | """ 36 | add_model_flags(parser) 37 | parser.add_argument('--device', default="cpu") 38 | parser.add_argument('--dry', type=float, default=0, 39 | help='dry/wet knob coefficient. 0 is only denoised, 1 only input signal.') 40 | parser.add_argument('--num_workers', type=int, default=5) 41 | parser.add_argument('--streaming', action="store_true", 42 | help="true streaming evaluation for Demucs") 43 | 44 | 45 | parser = argparse.ArgumentParser( 46 | 'denoise', 47 | description="Generate denoised files") 48 | add_flags(parser) 49 | parser.add_argument("--output", type=str, default="enhanced", 50 | help="directory where enhanced wav files are written") 51 | parser.add_argument("--batch_size", default=1, type=int, help="batch size") 52 | parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG, 53 | default=logging.INFO, help="more logging") 54 | parser.add_argument("--method",choices=["demucs"], default="demucs",help="Method to use for denoising") 55 | parser.add_argument("--transform",choices=["none", "time_scale"], default="none",help="Transform input by time scaling") 56 | parser.add_argument("--sample_rate",type=int,choices=[16000], default=16000,help="Sample rate of the model") 57 | parser.add_argument("--keep_original_sr", action="store_true",help="keep the original sample rate of the audio rather than the model sample rate") 58 | parser.add_argument('--selection_table', action="store_true", help="Enable event masking via selection tables (csv/tsv/txt) located next to audio files.") 59 | parser.add_argument("--input", type=str, default=None, 60 | help="path to a directory with noisy audio files or to a single noisy audio file") 61 | parser.add_argument("--window_size", type=int, default=0, 62 | help="size of the window for continuous processing") 63 | 64 | def normalize(wav): 65 | return wav / max(wav.abs().max().item(), 1) 66 | 67 | def get_estimate(model, noisy_signals, args): 68 | torch.set_num_threads(1) 69 | estimated_signals = torch.zeros_like(noisy_signals) 70 | for c in range(noisy_signals.shape[1]): 71 | noisy = noisy_signals[:,c:c+1,:] 72 | if args.method=='demucs' and args.streaming: 73 | streamer = DemucsStreamer(model, dry=args.dry) 74 | with torch.no_grad(): 75 | estimate = torch.cat([ 76 | streamer.feed(noisy[0]), 77 | streamer.flush()], dim=1)[None] 78 | else: 79 | with torch.no_grad(): 80 | if hasattr(model, 'ola_forward'): 81 | while noisy.ndim < 3: 82 | noisy = noisy.unsqueeze(0) 83 | estimate = model.forward(noisy) 84 | else: 85 | estimate = model(noisy) 86 | estimate = (1 - args.dry) * estimate + args.dry * noisy 87 | estimated_signals[:,c:c+1,:] = estimate 88 | return estimated_signals 89 | 90 | def time_scaling(signal, scaling): 91 | output_size = int(signal.shape[-1] * scaling) 92 | ref = torch.arange(output_size, device=signal.device, dtype=signal.dtype).div_(scaling) 93 | 94 | ref1 = ref.clone().type(torch.int64)
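# ref holds the fractional source position of every output sample (output index
# n reads from input position n / scaling); ref1 is its floor, ref2 below is the
# next index clamped to the last sample, and r is the fractional remainder used
# to linearly interpolate between the two. E.g. scaling=0.5 halves the length,
# reading from input positions 0, 2, 4, ...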
95 | ref2 = torch.min(ref1 + 1, torch.full_like(ref1, signal.shape[-1] - 1, dtype=torch.int64)) 96 | r = ref - ref1.type(ref.type()) 97 | scaled_signal = signal[..., ref1] * (1 - r) + signal[..., ref2] * r 98 | 99 | return scaled_signal 100 | 101 | 102 | def save_wavs(estimates, noisy_sigs, filenames, out_dir, sr=16_000, write_noisy=False): 103 | # Write result 104 | if rank == 0: 105 | os.makedirs(out_dir, exist_ok=True) 106 | for estimate, noisy, filename in zip(estimates, noisy_sigs, filenames): 107 | filename = os.path.join(out_dir, os.path.basename(filename).rsplit(".", 1)[0]) 108 | if write_noisy: 109 | write(estimate, filename + "_enhanced.wav", sr=sr) 110 | write(noisy, filename + "_noisy.wav", sr=sr) 111 | else: 112 | write(estimate, filename + ".wav", sr=sr) 113 | 114 | 115 | def write(wav, filename, sr=16_000): 116 | # Optional normalization to prevent clipping (disabled by default) 117 | # wav = wav / max(wav.abs().max().item(), 1) 118 | torchaudio.save(filename, wav.cpu(), sr) 119 | 120 | 121 | def get_dataset(noisy_dir, sample_rate, channels, keep_original_sr): 122 | resample_to_sr = sample_rate if not keep_original_sr else None 123 | 124 | if noisy_dir: 125 | if os.path.isdir(noisy_dir): 126 | files = find_audio_files(noisy_dir) 127 | else: 128 | audio_files = [noisy_dir] 129 | files = build_meta(audio_files) 130 | else: 131 | logger.warning( 132 | "No input was provided via noisy_dir. " 133 | "Skipping denoising.") 134 | return None 135 | return Audioset(files, with_path=True, 136 | sample_rate=sample_rate, channels=channels, convert=True, resample_to_sr=resample_to_sr) 137 | 138 | 139 | def _estimate_and_save(model, noisy_signals, filenames, out_dir, sample_rate, data_sample_rate, args): 140 | save_sr = data_sample_rate if args.keep_original_sr else sample_rate 141 | ### Forward 142 | estimate = get_estimate(model, noisy_signals, args) 143 | experiment = args.method 144 | 145 | if args.transform == 'none': 146 | if args.selection_table: 147 | masked_estimates = [] 148 | for i, fn in enumerate(filenames): 149 | table = find_selection_table_for(fn) 150 | events = load_events_seconds(table) 151 | length_frames = estimate.shape[-1] 152 | mask_1d = build_mask_from_events(length_frames, save_sr, events, estimate.device) 153 | mask = mask_1d.view(1, 1, -1) 154 | masked_estimates.append(estimate[i:i+1] * mask) 155 | estimate = torch.cat(masked_estimates, dim=0) if masked_estimates else estimate 156 | save_wavs(estimate, noisy_signals, filenames, out_dir, sr=save_sr) 157 | else: 158 | experiment += '_'+args.transform 159 | estimate_sum = estimate 160 | #noisy_signals = noisy_signals[None,None,:].float() 161 | for i in range(1,4): 162 | ### transform 163 | ### time scaling 164 | noisy_signals = time_scaling(noisy_signals, np.power(2, -0.5)) 165 | # print("Scale to: {}".format(np.power(2, -0.5))) 166 | 167 | ### forward 168 | estimate = get_estimate(model, noisy_signals, args) 169 | 170 | ### transform back 171 | ### time scaling 172 | estimate_write = time_scaling(estimate, np.power(2, i*0.5)) 173 | # print("Scale back: {}".format(np.power(2, i*0.5))) 174 | 175 | if estimate_sum.shape[-1] > estimate_write.shape[-1]: 176 | estimate_sum[...,:estimate_write.shape[-1]] += estimate_write 177 | elif estimate_sum.shape[-1] < estimate_write.shape[-1]: 178 | estimate_sum += estimate_write[...,:estimate_sum.shape[-1]] 179 | else: 180 | estimate_sum += estimate_write 181 | 182 | estimate_out = estimate_sum/4.
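# The branch above is a test-time ensemble over time scales: besides the first
# pass at the original rate, the input is compressed by 2**-0.5 three more times
# (cumulative factors 2**-0.5, 2**-1, 2**-1.5), denoised, stretched back by the
# inverse factor 2**(i*0.5), and the four aligned estimates are averaged, hence
# the division by 4.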
183 | if args.selection_table: 184 | masked_estimates = [] 185 | for i, fn in enumerate(filenames): 186 | table = find_selection_table_for(fn) 187 | events = load_events_seconds(table) 188 | length_frames = estimate_out.shape[-1] 189 | mask_1d = build_mask_from_events(length_frames, save_sr, events, estimate_out.device) 190 | mask = mask_1d.view(1, 1, -1) 191 | masked_estimates.append(estimate_out[i:i+1] * mask) 192 | estimate_out = torch.cat(masked_estimates, dim=0) if masked_estimates else estimate_out 193 | save_wavs(estimate_out, noisy_signals, filenames, out_dir, sr=save_sr) 194 | 195 | 196 | 197 | def denoise(args, model=None, local_out_dir=None): 198 | # if args.device == 'cpu' and args.num_workers > 1: 199 | # torch.multiprocessing.set_sharing_strategy('file_system') 200 | sample_rate = args.sample_rate 201 | channels = 1 202 | # Load model 203 | if args.method=='demucs': 204 | if not model: 205 | model = get_model(args).to(args.device) 206 | if args.sample_rate != model.sample_rate: 207 | logger.warning(f"Model sample rate is {model.sample_rate}, " 208 | f"but the provided sample rate is {args.sample_rate}. " 209 | f"Resampling will be performed.") 210 | sample_rate = model.sample_rate 211 | channels = model.chin 212 | else: 213 | sys.exit("Method not implemented") 214 | 215 | model.eval() 216 | 217 | if local_out_dir: 218 | out_dir = local_out_dir 219 | else: 220 | out_dir = args.out_dir 221 | 222 | dset = get_dataset(args.noisy_dir, sample_rate, channels, args.keep_original_sr) 223 | if dset is None: 224 | return 225 | dloader = loader(dset, batch_size=1, shuffle=False) 226 | 227 | if 'demucs' in args.method: 228 | barrier() 229 | 230 | with ProcessPoolExecutor(args.num_workers) as pool: 231 | iterator = LogProgress(logger, dloader, name="Denoising files") 232 | pendings = [] 233 | for data in iterator: 234 | # Get batch data 235 | noisy_signals, filenames, data_sample_rate = data 236 | noisy_signals = noisy_signals.to(args.device) 237 | if args.device == 'cpu' and args.num_workers > 1: 238 | pendings.append( 239 | pool.submit(_estimate_and_save, 240 | model, noisy_signals, filenames, out_dir, sample_rate, data_sample_rate, args)) 241 | else: 242 | if args.window_size > 0: 243 | import asteroid 244 | ola_model = asteroid.dsp.overlap_add.LambdaOverlapAdd( 245 | nnet=model, # function to apply to each segment. 246 | n_src=1, # number of sources in the output of nnet 247 | window_size=args.window_size, # Size of segmenting window 248 | hop_size=args.window_size//4, # segmentation hop size 249 | window="hann", # Type of the window (see scipy.signal.get_window 250 | reorder_chunks=False, # Whether to reorder each consecutive segment. 
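# A hop of window_size//4 gives 75% overlap between consecutive Hann-windowed
# chunks, which the overlap-add reconstruction cross-fades back together.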
251 | enable_grad=False, # Set gradient calculation on of off (see torch.set_grad_enabled) 252 | ) 253 | ola_model.window = ola_model.window.to(args.device) 254 | _estimate_and_save(ola_model, noisy_signals, filenames, out_dir, sample_rate, data_sample_rate, args) 255 | else: 256 | _estimate_and_save(model, noisy_signals, filenames, out_dir, sample_rate, data_sample_rate, args) 257 | 258 | if pendings: 259 | print('Waiting for pending jobs...') 260 | for pending in LogProgress(logger, pendings, updates=5, name="Denoising files"): 261 | pending.result() 262 | 263 | 264 | 265 | if __name__ == "__main__": 266 | args = parser.parse_args() 267 | logging.basicConfig(stream=sys.stderr, level=args.verbose) 268 | logger.debug(args) 269 | os.makedirs(args.output, exist_ok=True) 270 | if os.path.isdir(args.input): 271 | ### walk each subfolder recursively 272 | for root, dirs, files in os.walk(args.input): 273 | audio_files = [f for f in files if os.path.splitext(f)[1] in ALLOWED_EXTENSIONS] 274 | if len(audio_files) == 0: 275 | continue 276 | args.noisy_dir = root 277 | relative_path = os.path.relpath(root, args.input) 278 | args.out_dir = os.path.join(args.output, relative_path) 279 | os.makedirs(args.out_dir, exist_ok=True) 280 | denoise(args, local_out_dir=args.out_dir) 281 | elif os.path.splitext(args.input)[1] in ALLOWED_EXTENSIONS: 282 | args.noisy_dir = args.input 283 | os.makedirs(args.output, exist_ok=True) 284 | denoise(args, local_out_dir=args.output) 285 | -------------------------------------------------------------------------------- /prepare_experiments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import sys 6 | import biodenoising 7 | import pandas as pd 8 | import numpy as np 9 | import math 10 | import random 11 | import julius 12 | import torchaudio 13 | 14 | parser = argparse.ArgumentParser( 15 | 'prepare_experiments', 16 | description="Generate json files with all the audios") 17 | parser.add_argument("--data_dir", type=str, default="enhanced", required=True, 18 | help="directory where the training data is containing train, valid and test subdirectories") 19 | parser.add_argument("--step", default=0, type=int, help="step") 20 | parser.add_argument("--method",choices=["demucs", "cleanunet"], default="demucs",help="Method that was used to denoise at step 'step'") 21 | parser.add_argument("--approach",choices=["noisy2cleaned","noisier2noisy","noisereduce"], default="noisy2cleaned",help="Our approach vs using noisy files and noise") 22 | parser.add_argument("--tag", default="",help="This is used to tag the models at steps>0 with the origin of training data at step 0") 23 | parser.add_argument("--seed", default=-1, type=int, help="seed for step>0") 24 | parser.add_argument("--transform",choices=["none", "time_scale", "all"], default="none",help="Transform input by pitch shifting or time scaling") 25 | parser.add_argument("--dataprune",choices=["none","kmeans"], default="none",help="Data pruning method using the model's saved activations.") 26 | parser.add_argument('--prune_ratio', type=float, default=1., help="use the prune ratio from all the files") 27 | parser.add_argument('--with_previous', action="store_true",help="add previous steps") 28 | parser.add_argument('--with_lower_sr', action="store_true",help="add 16kHz data to the 48kHz experiments") 29 | parser.add_argument('--balance', action="store_true",help="downsample the larger datasets to match the smaller ones") 30 
| parser.add_argument('--num_valid', type=int, default=0, help="number of validation files") 31 | parser.add_argument('--use_ratio', type=float, default=1., help="use the top ratio of the files") 32 | parser.add_argument('--version', type=str, default='') 33 | parser.add_argument('--exclude', nargs='+', default=[]) 34 | parser.add_argument('--overfit', action="store_true",help="oracle ablation") 35 | 36 | test_sets = {'biodenoising': 'biodenoising_validation', 'crows':'carrion_crows_denoising','zfinches': 'zebra_finch_denoising'} 37 | 38 | def to_json_folder(data_dict, args): 39 | json_dict = {'train':[], 'valid':[]} 40 | for split, dirs in data_dict.items(): 41 | for d in dirs: 42 | meta=biodenoising.denoiser.audio.find_audio_files(d) 43 | if args.approach=='noisy2cleaned' and 'valid' not in data_dict.keys() and split=='train': 44 | random.shuffle(meta) 45 | if args.num_valid > 0: 46 | json_dict['valid'] += meta[:args.num_valid] 47 | json_dict['train'] += meta[args.num_valid:] 48 | else: 49 | json_dict[split] += meta 50 | return json_dict 51 | 52 | def to_json_list(data_dict, args, json_dict=None): 53 | json_dict = {} if json_dict is None else json_dict ### avoid a shared mutable default argument 54 | for split, filelist in data_dict.items(): 55 | meta = biodenoising.denoiser.audio.build_meta(filelist) 56 | json_dict.setdefault(split, []) 57 | json_dict[split].extend(meta) 58 | return json_dict 59 | 60 | def write_json(json_dict, filename, args): 61 | for split, meta in json_dict.items(): 62 | exp_dirname = args.method + '_' + args.transform + args.tag + '_step' + str(args.step) 63 | if args.dataprune!='none': 64 | exp_dirname += '_'+args.dataprune + '_'+str(args.prune_ratio) 65 | if args.with_previous: 66 | exp_dirname += '_prev' 67 | if args.seed>=0: 68 | exp_dirname += ',seed='+str(args.seed) 69 | if args.approach=='noisier2noisy': 70 | exp_dirname = 'noisy' 71 | elif args.approach=='noisereduce': 72 | exp_dirname = 'noisereduce' 73 | out_dir = os.path.join('biodenoising','egs', os.path.basename(os.path.normpath(args.data_dir)), exp_dirname, split) 74 | os.makedirs(out_dir, exist_ok=True) 75 | fname = os.path.join(out_dir, filename) 76 | with open(fname, 'w', encoding='utf-8') as f: 77 | json.dump(meta, f, ensure_ascii=False, indent=4) 78 | 79 | def resample(inpath_file, outdir, to_samplerate): 80 | """Convert audio from a given samplerate to a target one.""" 81 | out_file = os.path.basename(inpath_file) 82 | if not os.path.exists(os.path.join(outdir, out_file)): 83 | # read wav with torchaudio 84 | wav, sr = torchaudio.load(inpath_file) 85 | wav = julius.resample_frac(wav, sr, to_samplerate) 86 | # write wav with torchaudio 87 | torchaudio.save(os.path.join(outdir, out_file), wav, to_samplerate) 88 | return os.path.join(outdir, out_file) 89 | 90 | def generate_json(args): 91 | if args.with_previous: 92 | steps = range(0, args.step+1) 93 | else: 94 | steps = [args.step] 95 | if args.with_lower_sr and args.data_dir.endswith('48k'): 96 | data_dirs = [args.data_dir, args.data_dir.replace('48k','16k')] 97 | else: 98 | data_dirs = [args.data_dir] 99 | if args.dataprune!='none': 100 | mds = [[] for _ in data_dirs] 101 | for i, data_dir in enumerate(data_dirs): 102 | for step in steps: 103 | experiment = 'clean_'+args.method+ '_' + args.transform + args.tag +'_step'+str(args.step)+args.version 104 | if args.seed>=0: 105 | experiment += ',seed='+str(args.seed) 106 | md = pd.read_csv(os.path.join(data_dir, experiment+".csv")) 107 | nclusters = len(md['dataset'].unique()) 108 | prune = biodenoising.datasets.dataprune.Prune(
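### Prune (biodenoising/datasets/dataprune.py) clusters the csv entries, e.g.
### kmeans over the model's saved activations (--dataprune), and prune() keeps
### roughly a prune_fraction share of the files; number_cluster='auto' lets it
### pick the number of clusters.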
109 | prune_type=True, 110 | ssl_type=args.dataprune, 111 | number_cluster='auto', 112 | random_state=42, 113 | dataframe=md, 114 | data_frame_clustered=True 115 | ) 116 | md = prune.prune(prune_fraction=args.prune_ratio) 117 | mds[i].append(md) 118 | else: 119 | mds = None 120 | 121 | if args.transform!='all': 122 | if args.approach=='noisy2cleaned': 123 | clean_dirs_dict = {split:[] for split in ['train','valid']} 124 | clean_dirs={'train':[]} 125 | for i, data_dir in enumerate(data_dirs): 126 | for step in steps: 127 | if mds is not None: 128 | md = mds[i][step] 129 | else: 130 | experiment = 'clean_'+args.method+ '_' + args.transform + args.tag +'_step'+str(args.step)+args.version 131 | if args.seed>=0: 132 | experiment += ',seed='+str(args.seed) 133 | md = pd.read_csv(os.path.join(data_dir, experiment+".csv")) 134 | counts = md['dataset'].value_counts() 135 | counts_filter = counts[counts>1000] ### for the median statistic, ignore datasets with fewer than 1000 files 136 | if len(counts_filter)==0: 137 | ### no dataset is large enough: keep them all for the median 138 | filter_flag = False 139 | else: 140 | counts = counts_filter 141 | filter_flag = True 142 | magnitude = int(math.log10(counts.median()))+1 143 | for col in md['dataset'].unique(): 144 | if all(not col.startswith(x) for x in args.exclude): 145 | ds = md[md['dataset']==col] 146 | ds = ds.sort_values(by='metric',ascending=False,ignore_index=True) 147 | filenames = ds['fn'].values.tolist() 148 | col_magnitude = int(math.log10(len(filenames)))+1 149 | if args.balance and filter_flag and col_magnitude>magnitude: 150 | reduced = int(len(filenames)/ (10**(col_magnitude-magnitude))) 151 | print("Reducing dataset size for {} from {} to {}".format(col, len(filenames), reduced)) 152 | filenames = filenames[:reduced] 153 | filenames = [f for f in filenames if os.path.exists(f)] 154 | if i==1: ### resample to 48kHz from 16kHz 155 | experiment = 'clean_16_'+args.method+ '_' + args.transform + args.tag +'_step'+str(step)+args.version 156 | if args.seed>=0: 157 | experiment += ',seed='+str(args.seed) 158 | out_dir = os.path.join(args.data_dir,'train', experiment) 159 | os.makedirs(out_dir, exist_ok=True) 160 | filenames = [resample(f, out_dir, 48000) for f in filenames] 161 | if len(filenames)>args.num_valid: 162 | if args.num_valid>0: 163 | clean_dirs_dict['valid'].extend(filenames[:args.num_valid]) 164 | clean_dirs_dict['train'].extend(filenames[args.num_valid:int(len(filenames)*args.use_ratio)]) 165 | else: 166 | clean_dirs_dict['train'].extend(filenames) 167 | 168 | if os.path.exists(os.path.join(args.data_dir, 'train','clean')): 169 | clean_dirs={'train':[os.path.join(args.data_dir, 'train','clean',f) for f in os.listdir(os.path.join(args.data_dir, 'train','clean'))]} 170 | elif args.approach=='noisereduce': 171 | clean_dirs = {split:[os.path.join(args.data_dir,'train','noisereduce')] for split in ['train']} 172 | else: 173 | clean_dirs = {split:[os.path.join(args.data_dir,'dev',f) for f in os.listdir(os.path.join(args.data_dir,'dev'))] for split in ['train']} 174 | noise_dirs_dict = {split:[os.path.join(args.data_dir, split,'noise',f) for f in os.listdir(os.path.join(args.data_dir, split,'noise'))] for split in ['train']} 175 | else: ### use all data 176 | if args.approach=='noisy2cleaned': 177 | clean_dirs_dict = {split:[] for split in ['train','valid']} 178 | clean_dirs={'train':[]} 179 | for transforms in ['none','time_scale']: 180 | experiment = 'clean_'+args.method+ '_' + transforms + args.tag +'_step'+str(args.step)+args.version 181 | if args.seed>=0: 182 | experiment += ',seed='+str(args.seed)
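### Each denoising run is expected to leave a csv named
### clean_<method>_<transform><tag>_step<step><version>[,seed=<seed>].csv in
### data_dir, with at least 'fn', 'dataset' and 'metric' columns per denoised
### file; step N thus consumes the outputs of the model trained at step N-1.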
183 | md = pd.read_csv(os.path.join(args.data_dir, experiment+".csv")) 184 | for col in md['dataset'].unique(): 185 | if all(not col.startswith(x) for x in args.exclude): 186 | ds = md[md['dataset']==col] 187 | ds = ds.sort_values(by='metric',ascending=False,ignore_index=True) 188 | filenames = ds['fn'].values.tolist() 189 | if len(filenames)>args.num_valid: 190 | if args.num_valid>0: 191 | clean_dirs_dict['valid'].extend(filenames[:args.num_valid]) 192 | clean_dirs_dict['train'].extend(filenames[args.num_valid:int(len(filenames)*args.use_ratio)]) 193 | else: 194 | clean_dirs_dict['train'].extend(filenames) 195 | if os.path.exists(os.path.join(args.data_dir, 'train','clean')): 196 | clean_dirs={'train':[os.path.join(args.data_dir, 'train','clean',f) for f in os.listdir(os.path.join(args.data_dir, 'train','clean'))]} 197 | elif args.approach=='noisereduce': 198 | clean_dirs = {split:[os.path.join(args.data_dir,'train','noisereduce')] for split in ['train']} 199 | else: 200 | clean_dirs = {split:[os.path.join(args.data_dir,'dev',f) for f in os.listdir(os.path.join(args.data_dir,'dev'))] for split in ['train']} 201 | noise_dirs_dict = {split:[os.path.join(args.data_dir, split,'noise',f) for f in os.listdir(os.path.join(args.data_dir, split,'noise'))] for split in ['train']} 202 | if args.overfit: 203 | for td in test_sets.values(): 204 | sample_rate = 16000 if args.data_dir.endswith('16k') else 48000 205 | test_dir = os.path.join(os.path.dirname(args.data_dir), td, str(sample_rate)) 206 | if os.path.exists(test_dir): 207 | clean_dirs['train'].extend([os.path.join(test_dir, 'clean', f) for f in os.listdir(os.path.join(test_dir, 'clean'))]) 208 | noise_dirs_dict['train'].extend([os.path.join(test_dir, 'noise', f) for f in os.listdir(os.path.join(test_dir, 'noise'))]) 209 | json_dict = to_json_folder(clean_dirs, args) 210 | if args.approach=='noisy2cleaned': 211 | json_dict = to_json_list(clean_dirs_dict, args, json_dict) 212 | write_json(json_dict, 'clean.json', args) 213 | # ### match noise diversity to the clean diversity in the validation set 214 | # len_valid_clean = len(json_dict['valid']) if 'valid' in json_dict.keys() else 0 215 | # args.num_valid = int(len_valid_clean // len(noise_dirs_dict['train'])) if len_valid_clean>0 else 0 216 | json_dict = to_json_folder(noise_dirs_dict, args) 217 | write_json(json_dict, 'noise.json', args) 218 | 219 | if __name__ == "__main__": 220 | args = parser.parse_args() 221 | generate_json(args) 222 | 223 | #### python prepare_experiments.py --data_dir /home/marius/data/biodenoising48k/ 224 | #### python prepare_experiments.py --data_dir /home/marius/data/biodenoising48k/ --step 1 --method demucs --transform none -------------------------------------------------------------------------------- /biodenoising/denoiser/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.loss import _Loss 3 | 4 | class RMSSmoothLoss(_Loss): 5 | r"""RMS Smooth Loss. 6 | 7 | Args: 8 | threshold (float): dB threshold for silence detection (unused by the current ``forward``). 9 | kernel_size (int): smoothing kernel size (unused by the current ``forward``). 10 | reduction (string, optional): Specifies the reduction to apply to the output: 11 | ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied, 12 | ``'mean'``: the sum of the output will be divided by the number of 13 | elements in the output. 14 | 15 | Shape: 16 | - x : :math:`(batch, time)`. 17 | - y: :math:`(batch, time)`.
18 | 19 | Returns: 20 | :class:`torch.Tensor`: 21 | """ 22 | 23 | def __init__(self, threshold=-60, kernel_size=512, reduction="mean"): 24 | super().__init__(reduction=reduction) 25 | self.threshold = threshold 26 | self.kernel_size = kernel_size 27 | 28 | def forward(self, x, y): 29 | """Calculate forward propagation. 30 | Args: 31 | x (Tensor): Predicted signal (B, T). 32 | y (Tensor): Groundtruth signal (B, T). 33 | """ 34 | x = torch.nan_to_num(x, nan=1e-8, posinf=1, neginf=-1) 35 | x_energy = x.pow(2).sqrt() 36 | #y_energy = y.pow(2).mean(axis=-1).sqrt() 37 | 38 | ### x_shifted is x_energy shifted to the right by 1 39 | x_shifted = torch.roll(x_energy, shifts=-1, dims=1) 40 | #y_shifted = torch.cat([y_energy[..., :-1], y_energy[..., 1:]], dim=-1) 41 | ### take difference between x_energy and x_shifted 42 | loss = x_energy[...,1:] - x_shifted[...,1:] 43 | #y_energy = y_energy - y_shifted 44 | 45 | # ### perform convolution on the energy with a constant kernel 46 | # kernel = torch.ones(1, 1, self.kernel_size, self.kernel_size) / self.kernel_size**2 47 | # x_energy = torch.functional.conv1d(x_energy, kernel, padding=self.kernel_size//2) 48 | # #y_energy = torch.functional.conv1d(y_energy, kernel, padding=self.kernel_size//2) 49 | # #### concatenate filters on the x axis 50 | 51 | ### shift x_energy samples to the right 52 | # x_energy = torch.cat([x_energy[..., :-self.kernel_size//2], x_energy[..., self.kernel_size//2:]], dim=-1) 53 | #y_energy = torch.cat([y_energy[..., :-self.kernel_size//2], y_energy[..., self.kernel_size//2:]], dim=-1) 54 | 55 | # y_dB = 20 * torch.log10(y.abs() + 1e-10) 56 | # ### detect non silence segments 57 | # segments = y_dB > self.threshold 58 | # ### delete segments smaller than 10ms 59 | # segments = segments & (y_dB > -60) 60 | # segments = segments.float() 61 | # y = y * segments 62 | # x = x * segments 63 | 64 | return loss.mean() 65 | 66 | 67 | class PairwiseNegSDR(_Loss): 68 | r"""Base class for pairwise negative SI-SDR, SD-SDR and SNR on a batch. 69 | 70 | Args: 71 | sdr_type (str): choose between ``snr`` for plain SNR, ``sisdr`` for 72 | SI-SDR and ``sdsdr`` for SD-SDR [1]. 73 | zero_mean (bool, optional): by default it zero mean the target 74 | and estimate before computing the loss. 75 | take_log (bool, optional): by default the log10 of sdr is returned. 76 | 77 | Shape: 78 | - est_targets : :math:`(batch, nsrc, ...)`. 79 | - targets: :math:`(batch, nsrc, ...)`. 80 | 81 | Returns: 82 | :class:`torch.Tensor`: with shape :math:`(batch, nsrc, nsrc)`. Pairwise losses. 83 | 84 | Examples 85 | >>> import torch 86 | >>> from asteroid.losses import PITLossWrapper 87 | >>> targets = torch.randn(10, 2, 32000) 88 | >>> est_targets = torch.randn(10, 2, 32000) 89 | >>> loss_func = PITLossWrapper(PairwiseNegSDR("sisdr"), 90 | >>> pit_from='pairwise') 91 | >>> loss = loss_func(est_targets, targets) 92 | 93 | References 94 | [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE 95 | International Conference on Acoustics, Speech and Signal 96 | Processing (ICASSP) 2019. 
97 | """ 98 | 99 | def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8): 100 | super(PairwiseNegSDR, self).__init__() 101 | assert sdr_type in ["snr", "sisdr", "sdsdr"] 102 | self.sdr_type = sdr_type 103 | self.zero_mean = zero_mean 104 | self.take_log = take_log 105 | self.EPS = EPS 106 | 107 | def forward(self, est_targets, targets): 108 | if targets.size() != est_targets.size() or targets.ndim != 3: 109 | raise TypeError( 110 | f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {est_targets.size()} instead" 111 | ) 112 | assert targets.size() == est_targets.size() 113 | # Step 1. Zero-mean norm 114 | if self.zero_mean: 115 | mean_source = torch.mean(targets, dim=2, keepdim=True) 116 | mean_estimate = torch.mean(est_targets, dim=2, keepdim=True) 117 | targets = targets - mean_source 118 | est_targets = est_targets - mean_estimate 119 | # Step 2. Pair-wise SI-SDR. (Reshape to use broadcast) 120 | s_target = torch.unsqueeze(targets, dim=1) 121 | s_estimate = torch.unsqueeze(est_targets, dim=2) 122 | 123 | if self.sdr_type in ["sisdr", "sdsdr"]: 124 | # [batch, n_src, n_src, 1] 125 | pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True) 126 | # [batch, 1, n_src, 1] 127 | s_target_energy = torch.sum(s_target**2, dim=3, keepdim=True) + self.EPS 128 | # [batch, n_src, n_src, time] 129 | pair_wise_proj = pair_wise_dot * s_target / s_target_energy 130 | else: 131 | # [batch, n_src, n_src, time] 132 | pair_wise_proj = s_target.repeat(1, s_target.shape[2], 1, 1) 133 | if self.sdr_type in ["sdsdr", "snr"]: 134 | e_noise = s_estimate - s_target 135 | else: 136 | e_noise = s_estimate - pair_wise_proj 137 | # [batch, n_src, n_src] 138 | pair_wise_sdr = torch.sum(pair_wise_proj**2, dim=3) / ( 139 | torch.sum(e_noise**2, dim=3) + self.EPS 140 | ) 141 | if self.take_log: 142 | pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS) 143 | return -torch.mean(pair_wise_sdr) 144 | 145 | 146 | class SingleSrcNegSDR(_Loss): 147 | r"""Base class for single-source negative SI-SDR, SD-SDR and SNR. 148 | 149 | Args: 150 | sdr_type (str): choose between ``snr`` for plain SNR, ``sisdr`` for 151 | SI-SDR and ``sdsdr`` for SD-SDR [1]. 152 | zero_mean (bool, optional): by default it zero mean the target and 153 | estimate before computing the loss. 154 | take_log (bool, optional): by default the log10 of sdr is returned. 155 | reduction (string, optional): Specifies the reduction to apply to 156 | the output: 157 | ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied, 158 | ``'mean'``: the sum of the output will be divided by the number of 159 | elements in the output. 160 | 161 | Shape: 162 | - est_targets : :math:`(batch, time)`. 163 | - targets: :math:`(batch, time)`. 164 | 165 | Returns: 166 | :class:`torch.Tensor`: with shape :math:`(batch)` if ``reduction='none'`` else 167 | [] scalar if ``reduction='mean'``. 168 | 169 | Examples 170 | >>> import torch 171 | >>> from asteroid.losses import PITLossWrapper 172 | >>> targets = torch.randn(10, 2, 32000) 173 | >>> est_targets = torch.randn(10, 2, 32000) 174 | >>> loss_func = PITLossWrapper(SingleSrcNegSDR("sisdr"), 175 | >>> pit_from='pw_pt') 176 | >>> loss = loss_func(est_targets, targets) 177 | 178 | References 179 | [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE 180 | International Conference on Acoustics, Speech and Signal 181 | Processing (ICASSP) 2019. 
182 | """ 183 | 184 | def __init__(self, sdr_type, zero_mean=True, take_log=True, reduction="mean", EPS=1e-8): 185 | assert reduction != "sum", NotImplementedError 186 | super().__init__(reduction=reduction) 187 | 188 | assert sdr_type in ["snr", "sisdr", "sdsdr"] 189 | self.sdr_type = sdr_type 190 | self.zero_mean = zero_mean 191 | self.take_log = take_log 192 | self.EPS = 1e-8 193 | 194 | def forward(self, est_target, target): 195 | if target.ndim == 3 and est_target.ndim == 3: 196 | target = target.squeeze(1) 197 | est_target = est_target.squeeze(1) 198 | if target.size() != est_target.size() or target.ndim != 2: 199 | raise TypeError( 200 | f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead" 201 | ) 202 | # Step 1. Zero-mean norm 203 | if self.zero_mean: 204 | mean_source = torch.mean(target, dim=1, keepdim=True) 205 | mean_estimate = torch.mean(est_target, dim=1, keepdim=True) 206 | target = target - mean_source 207 | est_target = est_target - mean_estimate 208 | # Step 2. Pair-wise SI-SDR. 209 | if self.sdr_type in ["sisdr", "sdsdr"]: 210 | # [batch, 1] 211 | dot = torch.sum(est_target * target, dim=1, keepdim=True) 212 | # [batch, 1] 213 | s_target_energy = torch.sum(target**2, dim=1, keepdim=True) + self.EPS 214 | # [batch, time] 215 | scaled_target = dot * target / s_target_energy 216 | else: 217 | # [batch, time] 218 | scaled_target = target 219 | if self.sdr_type in ["sdsdr", "snr"]: 220 | e_noise = est_target - target 221 | else: 222 | e_noise = est_target - scaled_target 223 | # [batch] 224 | losses = torch.sum(scaled_target**2, dim=1) / (torch.sum(e_noise**2, dim=1) + self.EPS) 225 | if self.take_log: 226 | losses = 10 * torch.log10(losses + self.EPS) 227 | losses = losses.mean() if self.reduction == "mean" else losses 228 | return -losses 229 | 230 | 231 | class MultiSrcNegSDR(_Loss): 232 | r"""Base class for computing negative SI-SDR, SD-SDR and SNR for a given 233 | permutation of source and their estimates. 234 | 235 | Args: 236 | sdr_type (str): choose between ``snr`` for plain SNR, ``sisdr`` for 237 | SI-SDR and ``sdsdr`` for SD-SDR [1]. 238 | zero_mean (bool, optional): by default it zero mean the target 239 | and estimate before computing the loss. 240 | take_log (bool, optional): by default the log10 of sdr is returned. 241 | 242 | Shape: 243 | - est_targets : :math:`(batch, nsrc, time)`. 244 | - targets: :math:`(batch, nsrc, time)`. 245 | 246 | Returns: 247 | :class:`torch.Tensor`: with shape :math:`(batch)` if ``reduction='none'`` else 248 | [] scalar if ``reduction='mean'``. 249 | 250 | Examples 251 | >>> import torch 252 | >>> from asteroid.losses import PITLossWrapper 253 | >>> targets = torch.randn(10, 2, 32000) 254 | >>> est_targets = torch.randn(10, 2, 32000) 255 | >>> loss_func = PITLossWrapper(MultiSrcNegSDR("sisdr"), 256 | >>> pit_from='perm_avg') 257 | >>> loss = loss_func(est_targets, targets) 258 | 259 | References 260 | [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE 261 | International Conference on Acoustics, Speech and Signal 262 | Processing (ICASSP) 2019. 
263 | 264 | """ 265 | 266 | def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8): 267 | super().__init__() 268 | 269 | assert sdr_type in ["snr", "sisdr", "sdsdr"] 270 | self.sdr_type = sdr_type 271 | self.zero_mean = zero_mean 272 | self.take_log = take_log 273 | self.EPS = 1e-8 274 | 275 | def forward(self, est_targets, targets): 276 | if targets.size() != est_targets.size() or targets.ndim != 3: 277 | raise TypeError( 278 | f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {est_targets.size()} instead" 279 | ) 280 | if self.zero_mean: 281 | mean_source = torch.mean(targets, dim=2, keepdim=True) 282 | mean_estimate = torch.mean(est_targets, dim=2, keepdim=True) 283 | targets = targets - mean_source 284 | est_targets = est_targets - mean_estimate 285 | # Step 2. Pair-wise SI-SDR. 286 | if self.sdr_type in ["sisdr", "sdsdr"]: 287 | # [batch, n_src] 288 | pair_wise_dot = torch.sum(est_targets * targets, dim=2, keepdim=True) 289 | # [batch, n_src] 290 | s_target_energy = torch.sum(targets**2, dim=2, keepdim=True) + self.EPS 291 | # [batch, n_src, time] 292 | scaled_targets = pair_wise_dot * targets / s_target_energy 293 | else: 294 | # [batch, n_src, time] 295 | scaled_targets = targets 296 | if self.sdr_type in ["sdsdr", "snr"]: 297 | e_noise = est_targets - targets 298 | else: 299 | e_noise = est_targets - scaled_targets 300 | # [batch, n_src] 301 | pair_wise_sdr = torch.sum(scaled_targets**2, dim=2) / ( 302 | torch.sum(e_noise**2, dim=2) + self.EPS 303 | ) 304 | if self.take_log: 305 | pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS) 306 | return -torch.mean(pair_wise_sdr, dim=-1) 307 | 308 | 309 | # aliases 310 | pairwise_neg_sisdr = PairwiseNegSDR("sisdr") 311 | pairwise_neg_sdsdr = PairwiseNegSDR("sdsdr") 312 | pairwise_neg_snr = PairwiseNegSDR("snr") 313 | singlesrc_neg_sisdr = SingleSrcNegSDR("sisdr") 314 | singlesrc_neg_sdsdr = SingleSrcNegSDR("sdsdr") 315 | singlesrc_neg_snr = SingleSrcNegSDR("snr") 316 | multisrc_neg_sisdr = MultiSrcNegSDR("sisdr") 317 | multisrc_neg_sdsdr = MultiSrcNegSDR("sdsdr") 318 | multisrc_neg_snr = MultiSrcNegSDR("snr") 319 | --------------------------------------------------------------------------------