├── .gitignore
├── .gitmodules
├── requirements.txt
├── model.py
├── LICENSE
├── util.py
├── reconstruct.py
├── train.py
├── dataset.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
venv/
.venv/
venv
__pycache__/
*.wav
*.ckpt
logs/
--------------------------------------------------------------------------------

/.gitmodules:
--------------------------------------------------------------------------------
[submodule "sgmse"]
	path = sgmse
	url = https://github.com/sp-uhh/sgmse.git
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
h5py==3.6.0
ipympl==0.8.8
ipywidgets==7.6.5
jupyter==1.0.0
jupyter-client==6.1.12
jupyter-console==6.4.0
jupyter-core==4.7.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.2
librosa==0.9.1
Ninja==1.10.2.3
numpy==1.22.2
pandas==1.4.0
pesq==0.0.4
Pillow==9.0.1
protobuf==3.19.4
pyroomacoustics==0.6.0
pystoi==0.3.3
pytorch-lightning==1.6.5
scipy==1.8.0
sdeint==0.2.4
setuptools==59.5.0 # fixes https://github.com/pytorch/pytorch/issues/69894
seaborn==0.11.2
torch==1.12.0
torch-ema==0.3
torchaudio==0.12.0
torchvision==0.13.0
torchinfo==1.6.3
torchsde==0.2.5
tqdm==4.63.0
wandb==0.12.11
--------------------------------------------------------------------------------

/model.py:
--------------------------------------------------------------------------------
import torch
import sys

sys.path.append("sgmse")
from sgmse.model import ScoreModel
from util import evaluate_model_pr

class PRScoreModel(ScoreModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch, batch_idx)
        self.log('valid_loss', loss, on_step=False, on_epoch=True)

        # Evaluate phase retrieval performance
        if batch_idx == 0 and self.num_eval_files != 0:
            pesq, si_sdr, estoi = evaluate_model_pr(self, self.num_eval_files)
            self.log('pesq', pesq, on_step=False, on_epoch=True)
            self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
            self.log('estoi', estoi, on_step=False, on_epoch=True)

        return loss

    def _pA(self, xt, y, eps=1e-8):
        # Magnitude projection: replace the magnitude of xt with that of y
        # while keeping the phase of xt (eps guards against division by zero)
        return y.abs()*xt/torch.clamp(xt.abs(), min=eps)
--------------------------------------------------------------------------------
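The `_pA` method above is the magnitude projection used during reconstruction: it keeps the phase of the current iterate `xt` but forces its magnitude to match the observation `y`. A minimal standalone sketch (not part of the repository) checking both properties on random complex spectrograms:

```python
import torch

def pA(xt, y, eps=1e-8):
    # same operation as PRScoreModel._pA: y's magnitudes, xt's phases
    return y.abs() * xt / torch.clamp(xt.abs(), min=eps)

xt = torch.randn(257, 128, dtype=torch.complex64)  # current iterate, arbitrary phase
y = torch.randn(257, 128, dtype=torch.complex64)   # observation carrying the target magnitudes

proj = pA(xt, y)
assert torch.allclose(proj.abs(), y.abs(), atol=1e-5)       # magnitudes now match y
assert torch.allclose(proj.angle(), xt.angle(), atol=1e-4)  # phases of xt are preserved
```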
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Signal Processing (SP), Universität Hamburg

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/util.py:
--------------------------------------------------------------------------------
import torch

from pesq import pesq
from pystoi import stoi

from sgmse.util.other import si_sdr

def evaluate_model_pr(model, num_eval_files):
    # Settings
    sr = 16000
    snr = 0.5
    N = 30
    corrector_steps = 0

    valid_set = model.data_module.valid_set

    # Select evaluation files uniformly across the validation set
    total_num_files = len(valid_set)
    indices = torch.linspace(0, total_num_files-1, num_eval_files, dtype=torch.int)

    _pesq = 0
    _si_sdr = 0
    _estoi = 0

    # Iterate over files
    for i in indices:
        # Load spectrograms and convert to waveforms
        X, Y = valid_set[i]
        x, y = model.to_audio(X), model.to_audio(Y)

        norm_factor = y.abs().max().item()
        y = y / norm_factor
        x = x / norm_factor

        # Reverse sampling
        sampler = model.get_pc_sampler(
            'reverse_diffusion', 'none', Y.unsqueeze(0).cuda(), N=N,
            corrector_steps=corrector_steps, snr=snr)
        sample, _ = sampler()

        x_hat = model.to_audio(sample.squeeze(), y.shape[-1])
        x_hat = x_hat.squeeze().cpu().numpy()

        x = x.squeeze().cpu().numpy()
        y = y.squeeze().cpu().numpy()

        _si_sdr += si_sdr(x, x_hat)
        _pesq += pesq(sr, x, x_hat, 'wb')
        _estoi += stoi(x, x_hat, sr, extended=True)

    return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
--------------------------------------------------------------------------------
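`evaluate_model_pr` averages three standard metrics over the selected validation files. A standalone sketch of the three metric calls (signatures taken from their use above) on synthetic signals; PESQ and STOI are designed for speech, so the noise signals here are placeholders for illustration only:

```python
import numpy as np
from pesq import pesq
from pystoi import stoi
from sgmse.util.other import si_sdr  # requires the sgmse submodule on sys.path

sr = 16000
x = np.random.randn(sr).astype(np.float32)                  # 1 s "reference" signal
x_hat = x + 0.05 * np.random.randn(sr).astype(np.float32)   # slightly distorted "estimate"

print('PESQ  :', pesq(sr, x, x_hat, 'wb'))           # wide-band PESQ, higher is better
print('ESTOI :', stoi(x, x_hat, sr, extended=True))  # extended STOI, higher is better
print('SI-SDR:', si_sdr(x, x_hat))                   # scale-invariant SDR in dB
```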
/reconstruct.py:
--------------------------------------------------------------------------------
import torch
import soundfile as sf
from torchaudio import load
import sys
from argparse import ArgumentParser


sys.path.append("sgmse")
from sgmse.util.other import pad_spec
from model import PRScoreModel

def main():
    parser = ArgumentParser()
    parser.add_argument("--input", type=str, required=True, help="Path to input WAV file")
    parser.add_argument("--output", type=str, required=True, help="Output filename for reconstructed audio")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--N", type=int, default=30, help="The number of steps for the reverse SDE solver")

    args = parser.parse_args()

    # Load score model
    model = PRScoreModel.load_from_checkpoint(args.ckpt, base_dir='', batch_size=1, num_workers=0, kwargs=dict(gpu=False))
    model.eval(no_ema=False)
    model.freeze()
    model.cuda()

    reconstruct(in_file=args.input, out_file=args.output, model=model, N=args.N)


def reconstruct(in_file, out_file, model, N):
    model_fs = 16000

    # Load wav
    y, fs = load(in_file)
    assert fs == model_fs
    T_orig = y.size(1)

    # Normalize
    norm_factor = y.abs().max()
    y = y / norm_factor

    # Prepare DNN input
    Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
    Y = pad_spec(Y)

    # Discard phase
    Y = Y.abs() + 0j

    # Reverse sampling
    sampler = model.get_pc_sampler(
        'reverse_diffusion', 'none', Y.cuda(), N=N, corrector_steps=0, snr=0)
    sample, _ = sampler()

    # Apply final magnitude projection (enforce known magnitudes on output)
    sample = model._pA(sample, Y)

    # Backward transform to time domain
    x_hat = model.to_audio(sample.squeeze(), T_orig)

    # Renormalize
    x_hat = x_hat * norm_factor

    # Write reconstructed wav file
    sf.write(out_file, x_hat.cpu().numpy(), model_fs)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
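The same pipeline can also be driven from Python rather than the command line. A sketch mirroring `main()` above; the checkpoint and audio paths are placeholders:

```python
from model import PRScoreModel
from reconstruct import reconstruct

# Placeholder paths — substitute your own checkpoint and 16 kHz wav file
model = PRScoreModel.load_from_checkpoint(
    'path/to/checkpoint.ckpt', base_dir='', batch_size=1, num_workers=0,
    kwargs=dict(gpu=False))
model.eval(no_ema=False)  # loads the EMA weights, as in main()
model.freeze()
model.cuda()

reconstruct(in_file='speech_16khz.wav', out_file='reconstructed.wav', model=model, N=30)
```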
/train.py:
--------------------------------------------------------------------------------
import argparse
from argparse import ArgumentParser
import sys

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

sys.path.append("sgmse")
from sgmse.backbones.shared import BackboneRegistry
from sgmse.sdes import SDERegistry

from model import PRScoreModel
from dataset import PRDataModule


def get_argparse_groups(parser):
    # Collect the parsed values (from the module-level `args`) into one
    # Namespace per argument group
    groups = {}
    for group in parser._action_groups:
        group_dict = { a.dest: getattr(args, a.dest, None) for a in group._group_actions }
        groups[group.title] = argparse.Namespace(**group_dict)
    return groups


if __name__ == '__main__':
    # Throwaway parser for dynamic args - see https://stackoverflow.com/a/25320537/3090225
    base_parser = ArgumentParser(add_help=False)
    parser = ArgumentParser()
    for parser_ in (base_parser, parser):
        parser_.add_argument("--backbone", type=str, choices=BackboneRegistry.get_all_names(), default="ncsnpp")
        parser_.add_argument("--sde", type=str, choices=SDERegistry.get_all_names(), default="ouve")
        parser_.add_argument("--no_wandb", action='store_true', help="Turn off logging to W&B, using local default logger instead")

    temp_args, _ = base_parser.parse_known_args()

    # Add specific args for PRScoreModel, pl.Trainer, the SDE class and backbone DNN class
    backbone_cls = BackboneRegistry.get_by_name(temp_args.backbone)
    sde_class = SDERegistry.get_by_name(temp_args.sde)
    parser = pl.Trainer.add_argparse_args(parser)
    PRScoreModel.add_argparse_args(
        parser.add_argument_group("PRScoreModel", description=PRScoreModel.__name__))
    sde_class.add_argparse_args(
        parser.add_argument_group("SDE", description=sde_class.__name__))
    backbone_cls.add_argparse_args(
        parser.add_argument_group("Backbone", description=backbone_cls.__name__))

    # Add data module args
    data_module_cls = PRDataModule
    data_module_cls.add_argparse_args(
        parser.add_argument_group("DataModule", description=data_module_cls.__name__))

    # Parse args and separate into groups
    args = parser.parse_args()
    arg_groups = get_argparse_groups(parser)

    # Initialize the model (it instantiates the data module internally)
    model = PRScoreModel(
        backbone=args.backbone, sde=args.sde, data_module_cls=data_module_cls,
        **{
            **vars(arg_groups['PRScoreModel']),
            **vars(arg_groups['SDE']),
            **vars(arg_groups['Backbone']),
            **vars(arg_groups['DataModule'])
        }
    )

    # Set up logger configuration
    if args.no_wandb:
        logger = TensorBoardLogger(save_dir="logs", name="tensorboard")
    else:
        logger = WandbLogger(project="diffphase", log_model=True, save_dir="logs")
        logger.experiment.log_code(".")

    # Set up callbacks for logger
    callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
    if args.num_eval_files:
        checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}",
            save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
        checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}",
            save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
        callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]

    # Initialize the Trainer
    trainer = pl.Trainer.from_argparse_args(
        arg_groups['pl.Trainer'],
        strategy=DDPStrategy(find_unused_parameters=False), logger=logger,
        log_every_n_steps=10, num_sanity_val_steps=0,
        callbacks=callbacks
    )

    # Train model
    trainer.fit(model)
--------------------------------------------------------------------------------
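`train.py` uses a two-stage argparse pattern: a throwaway parser first reads only `--backbone` and `--sde`, so that the classes selected by those flags can then register their own arguments on the real parser. A minimal standalone sketch of the pattern in isolation (the `--ch_mult` option here is illustrative only):

```python
from argparse import ArgumentParser

argv = ["--backbone", "ncsnpp", "--ch_mult", "1", "2"]

# Stage 1: a throwaway parser that only reads the selector flag,
# silently ignoring arguments it does not know yet
base_parser = ArgumentParser(add_help=False)
base_parser.add_argument("--backbone", default="ncsnpp")
temp_args, _ = base_parser.parse_known_args(argv)

# Stage 2: the real parser, extended with options owned by the selected class
parser = ArgumentParser()
parser.add_argument("--backbone", default="ncsnpp")
if temp_args.backbone == "ncsnpp":  # stand-in for BackboneRegistry.get_by_name(...)
    parser.add_argument("--ch_mult", type=int, nargs='+')

args = parser.parse_args(argv)
print(args)  # Namespace(backbone='ncsnpp', ch_mult=[1, 2])
```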
/dataset.py:
--------------------------------------------------------------------------------
from os.path import join
import torch
from torch.utils.data import Dataset
from glob import glob
from torchaudio import load
import numpy as np
import torch.nn.functional as F

from sgmse.data_module import SpecsDataModule


class PRSpecs(Dataset):
    def __init__(self, data_dir, subset, dummy, shuffle_spec, num_frames,
            format='default', spec_transform=None,
            stft_kwargs=None, phase_init="zero", **ignored_kwargs):

        # Read file paths according to file naming format
        if format == "default":
            self.clean_files = sorted(glob(join(data_dir, subset) + '/clean/*.wav'))
        elif format == "no_noisy":
            self.clean_files = sorted(glob(join(data_dir, subset, '**', '*.wav'), recursive=True))
        else:
            # Feel free to add your own directory format
            raise NotImplementedError(f"Directory format {format} unknown!")

        self.dummy = dummy
        self.num_frames = num_frames
        self.shuffle_spec = shuffle_spec
        self.spec_transform = spec_transform
        self.phase_init = phase_init

        assert all(k in stft_kwargs.keys() for k in ["n_fft", "hop_length", "center", "window"]), "misconfigured STFT kwargs"
        self.stft_kwargs = stft_kwargs
        self.hop_length = self.stft_kwargs["hop_length"]
        assert self.stft_kwargs.get("center", None) == True, "'center' must be True for current implementation"

    def __getitem__(self, i):
        x, _ = load(self.clean_files[i])

        # Formula applies for center=True
        target_len = (self.num_frames - 1) * self.hop_length
        current_len = x.size(-1)
        pad = max(target_len - current_len, 0)
        if pad == 0:
            # Extract a random part of the audio file
            if self.shuffle_spec:
                start = int(np.random.uniform(0, current_len-target_len))
            else:
                start = int((current_len-target_len)/2)
            x = x[..., start:start+target_len]
        else:
            # Pad audio if the length T is smaller than num_frames
            x = F.pad(x, (pad//2, pad//2+(pad%2)), mode='constant')

        normfac = x.abs().max()
        x = x / normfac

        # X is the clean complex spectrogram
        X = torch.stft(x, **self.stft_kwargs)

        # Y is the phaseless (or random-phase) spectrogram
        if self.phase_init == 'random':
            Y = X.abs() * torch.exp(1j * 2*np.pi * torch.rand_like(X.abs()))
        elif self.phase_init == 'zero':
            Y = X.abs() + 0j
        else:
            raise NotImplementedError(f"Phase initialization mode {self.phase_init} not implemented!")

        X, Y = self.spec_transform(X), self.spec_transform(Y)
        return X, Y

    def __len__(self):
        if self.dummy:
            # For debugging, shrink the data set size
            return int(len(self.clean_files)/200)
        else:
            return len(self.clean_files)


class PRDataModule(SpecsDataModule):
    @staticmethod
    def add_argparse_args(parser):
        parser = super(PRDataModule, PRDataModule).add_argparse_args(parser)
        parser.add_argument("--phase_init", type=str, default="zero", choices=["zero", "random"], help="Type of phase initialization")
        return parser

    def __init__(self, phase_init="zero", **kwargs):
        super().__init__(**kwargs)
        self.phase_init = phase_init

    def setup(self, stage=None):
        specs_kwargs = dict(
            stft_kwargs=self.stft_kwargs, num_frames=self.num_frames,
            spec_transform=self.spec_fwd, phase_init=self.phase_init, **self.kwargs
        )
        if stage == 'fit' or stage is None:
            self.train_set = PRSpecs(data_dir=self.base_dir, subset='train',
                dummy=self.dummy, shuffle_spec=True, format=self.format, **specs_kwargs)
            self.valid_set = PRSpecs(data_dir=self.base_dir, subset='valid',
                dummy=self.dummy, shuffle_spec=False, format=self.format, **specs_kwargs)
        if stage == 'test' or stage is None:
            self.test_set = PRSpecs(data_dir=self.base_dir, subset='test',
                dummy=self.dummy, shuffle_spec=False, format=self.format, **specs_kwargs)
--------------------------------------------------------------------------------
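The two phase initializations in `PRSpecs.__getitem__` both keep the clean magnitudes and differ only in the phase they attach. A standalone sketch (not part of the repository) on a random stand-in spectrogram:

```python
import torch
import numpy as np

X = torch.randn(257, 128, dtype=torch.complex64)  # stand-in for a clean complex spectrogram

# "zero" initialization: clean magnitudes, zero phase everywhere
Y_zero = X.abs() + 0j

# "random" initialization: clean magnitudes, i.i.d. uniform random phases
Y_rand = X.abs() * torch.exp(1j * 2*np.pi * torch.rand_like(X.abs()))

assert torch.allclose(Y_zero.abs(), X.abs())
assert torch.allclose(Y_rand.abs(), X.abs(), atol=1e-5)
```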
/README.md:
--------------------------------------------------------------------------------
# DiffPhase: Generative Diffusion-based STFT Phase Retrieval

This repository contains the official PyTorch implementation for the paper [1]:

- Tal Peer, Simon Welker, Timo Gerkmann. [*"DiffPhase: Generative Diffusion-based STFT Phase Retrieval"*](https://ieeexplore.ieee.org/abstract/document/10095396), 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Rhodes Island, Greece, Jun. 2023. [[arxiv]](https://arxiv.org/abs/2211.04332) [[bibtex]](#citations--references)

Audio examples are available [on our project page](https://www.inf.uni-hamburg.de/en/inst/ab/sp/publications/icassp2023-diffphase).

DiffPhase is an adaptation of the SGMSE+ diffusion-based speech enhancement method to phase retrieval. SGMSE+ is described in [2] and [3] and has its own [repository](https://github.com/sp-uhh/sgmse).

## Installation
- Clone this repository along with the [sgmse](https://github.com/sp-uhh/sgmse) repository, which is included as a submodule:
```bash
git clone --recurse-submodules https://github.com/sp-uhh/diffphase.git
```
- Create a new virtual environment with Python 3.8 (we have not tested other Python versions, but they may work).
- Install the package dependencies via `pip install -r requirements.txt`.
- If using W&B logging (default):
    - Set up a [wandb.ai](https://wandb.ai/) account
    - Log in via `wandb login` before running our code.
- If not using W&B logging:
    - Pass the option `--no_wandb` to `train.py`.
    - Your logs will be stored as local TensorBoard logs. Run `tensorboard --logdir logs/` to see them.


## Pretrained checkpoints

We provide two pretrained checkpoints:
- [DiffPhase](https://drive.google.com/file/d/19sQLF20kmkdvCxVhiP2e8y_BrrFqwTyB/view?usp=sharing), using the default SGMSE configuration. This model has ~65M parameters.
- [DiffPhase-small](https://drive.google.com/file/d/1zsp-bqhB9G_KWHeK8HaFAgNSzZ5epVbW/view?usp=sharing), with ~22M parameters.

Usage:
- For resuming training, you can use the `--resume_from_checkpoint` option of `train.py`.
- For performing phase reconstructions with these checkpoints, use the `--ckpt` option of `reconstruct.py` (see section **Evaluation** below).


## Training

Training is done by executing `train.py`. A minimal running example with default settings (as in our paper [1]) can be run with

```bash
python train.py --base_dir <your_base_dir>
```

where `<your_base_dir>` should be a path to a folder containing subdirectories `train/` and `valid/`. Each subdirectory must itself have a directory `clean/`. We currently only support training with `.wav` files sampled at 16 kHz. A sketch for sanity-checking this layout is given after this section.

For the DiffPhase-small variant, use the following options:

```bash
python train.py --num_res_blocks 1 --attn_resolutions 0 --ch_mult 1 1 2 2 1 --base_dir <your_base_dir>
```

To see all available training options, run `python train.py --help`. Also see the [sgmse](https://github.com/sp-uhh/sgmse) repository for more information.
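A small helper (not part of the repository) for sanity-checking the expected `base_dir` layout before training, assuming the default directory format; the path passed at the bottom is a placeholder:

```python
from pathlib import Path

def check_base_dir(base_dir):
    # Expects <base_dir>/{train,valid}/clean/*.wav, as required by the default format
    for subset in ('train', 'valid'):
        clean = Path(base_dir) / subset / 'clean'
        wavs = sorted(clean.glob('*.wav'))
        assert wavs, f"expected 16 kHz .wav files in {clean}"
        print(f"{clean}: {len(wavs)} files")

check_base_dir('your_base_dir')  # placeholder path
```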
## Evaluation

We provide an example script that takes a `.wav` file as input, removes the phase and writes a reconstructed signal to another `.wav` file. Reconstruction is performed using the same procedure described in our paper. To use it, run

```bash
python reconstruct.py --input <input_wav> --output <output_wav> --ckpt <checkpoint> --N <num_steps>
```


## Citations / References

We kindly ask you to cite our paper in your publication when using any of our research or code:
```bib
@inproceedings{peerDiffPhase2023,
  title = {{DiffPhase: Generative Diffusion-based STFT Phase Retrieval}},
  booktitle = {{2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}},
  author = {Peer, Tal and Welker, Simon and Gerkmann, Timo},
  date = {2023-06},
  doi = {10.1109/ICASSP49357.2023.10095396}
}
```

>[1] Tal Peer, Simon Welker, Timo Gerkmann. "DiffPhase: Generative Diffusion-based STFT Phase Retrieval", 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), Rhodes Island, Greece, Jun. 2023.
>
>[2] Simon Welker, Julius Richter, Timo Gerkmann. "Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain", ISCA Interspeech, Incheon, Korea, Sep. 2022.
>
>[3] Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, Timo Gerkmann. "Speech Enhancement and Dereverberation with Diffusion-Based Generative Models", IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023.
--------------------------------------------------------------------------------