├── .gitignore ├── LICENSE ├── README.md ├── SharedTrainer.py ├── configs ├── NB-BLSTM.yaml ├── NBC.yaml ├── NBC2.yaml ├── SpatialNet.yaml ├── datasets │ ├── CHiME3_moving_rir_cfg.npz │ ├── chime3_moving.yaml │ ├── chime3_moving_plus_static.yaml │ ├── chime3_static.yaml │ ├── sms_wsj.yaml │ ├── sms_wsj_plus.yaml │ ├── sms_wsj_plus_diffuse.npz │ ├── sms_wsj_rir_cfg.npz │ ├── spatialized_wsj0_mix.yaml │ └── whamr.yaml ├── onlineSpatialNet.yaml └── wsj0-mix │ ├── mix_2_spk_cv.txt │ ├── mix_2_spk_tr.txt │ ├── mix_2_spk_tt.txt │ ├── mix_3_spk_cv.txt │ ├── mix_3_spk_tr.txt │ ├── mix_3_spk_tt.txt │ └── speaker_gender.csv ├── data_loaders ├── __init__.py ├── chime3_moving.py ├── libricss.py ├── reverb.py ├── sms_wsj.py ├── sms_wsj_plus.py ├── spatialized_wsj0_mix.py ├── spk4_wsj0_mix_sp.py ├── utils │ ├── array_geometry.py │ ├── collate_func.py │ ├── diffuse_noise.py │ ├── mix.py │ ├── my_distributed_sampler.py │ ├── rand.py │ └── window.py └── whamr.py ├── generate_rirs.py ├── images ├── model_size_and_flops.png └── results.png ├── models ├── __init__.py ├── arch │ ├── NBC.py │ ├── NBC2.py │ ├── NBSS.py │ ├── OnlineSpatialNet.py │ ├── SpatialNet.py │ ├── base │ │ ├── linear_group.py │ │ ├── non_linear.py │ │ ├── norm.py │ │ └── retention.py │ └── blstm2_fc1.py ├── io │ ├── __init__.py │ ├── cirm.py │ ├── loss.py │ ├── norm.py │ └── stft.py ├── oracle_beamformer.py └── utils │ ├── __init__.py │ ├── base_cli.py │ ├── dnsmos.py │ ├── ensemble.py │ ├── flops.py │ ├── general_steps.py │ ├── git_tools.py │ ├── metrics.py │ ├── my_earlystopping.py │ ├── my_json_encoder.py │ ├── my_logger.py │ ├── my_progress_bar.py │ ├── my_rich_progress_bar.py │ ├── my_save_config_callback.py │ └── shared_cli.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .mypy_cache 3 | .vscode 4 | lightning_logs 5 | logs 6 | .dataset 7 | *.log 8 | *.out 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Changsheng Quan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multichannel Speech Separation, Denoising and Dereverberation 2 | 3 | The official repo of: 4 | [1] Changsheng Quan, Xiaofei Li. [Multi-channel Narrow-band Deep Speech Separation with Full-band Permutation Invariant Training](https://arxiv.org/abs/2110.05966). In ICASSP 2022. 5 | [2] Changsheng Quan, Xiaofei Li. [Multichannel Speech Separation with Narrow-band Conformer](https://arxiv.org/abs/2204.04464). In Interspeech 2022. 6 | [3] Changsheng Quan, Xiaofei Li. [NBC2: Multichannel Speech Separation with Revised Narrow-band Conformer](https://arxiv.org/abs/2212.02076). arXiv:2212.02076. 7 | [4] Changsheng Quan, Xiaofei Li. [SpatialNet: Extensively Learning Spatial Information for Multichannel Joint Speech Separation, Denoising and Dereverberation](https://arxiv.org/abs/2307.16516). TASLP, 2024. 8 | [5] Changsheng Quan, Xiaofei Li. [Multichannel Long-Term Streaming Neural Speech Enhancement for Static and Moving Speakers](https://arxiv.org/abs/2403.07675). IEEE Signal Processing Letters, 2024. 9 | 10 | Audio examples can be found at [https://audio.westlake.edu.cn/Research/nbss.htm](https://audio.westlake.edu.cn/Research/nbss.htm) and [https://audio.westlake.edu.cn/Research/SpatialNet.htm](https://audio.westlake.edu.cn/Research/SpatialNet.htm). 11 | More information about our group can be found at [https://audio.westlake.edu.cn](https://audio.westlake.edu.cn/Publications.htm). 12 | 13 | ## Performance 14 | SpatialNet: 15 | - Performance (see `images/results.png`) 16 |
17 | - Computational cost (see `images/model_size_and_flops.png`) 18 |
19 | 20 | ## Requirements 21 | 22 | ```bash 23 | pip install -r requirements.txt 24 | 25 | # gpuRIR: check https://github.com/DavidDiazGuerra/gpuRIR 26 | ``` 27 | 28 | ## Generate Dataset SMS-WSJ-Plus 29 | 30 | Generate RIRs for the dataset `SMS-WSJ_plus` used in the `SpatialNet` ablation experiments. 31 | 32 | ```bash 33 | CUDA_VISIBLE_DEVICES=0 python generate_rirs.py --rir_dir ~/datasets/SMS_WSJ_Plus_rirs --save_to configs/datasets/sms_wsj_rir_cfg.npz 34 | cp configs/datasets/sms_wsj_plus_diffuse.npz ~/datasets/SMS_WSJ_Plus_rirs/diffuse.npz # copy diffuse parameters 35 | ``` 36 | 37 | For SMS-WSJ, please see https://github.com/fgnt/sms_wsj 38 | 39 | ## Train & Test 40 | 41 | This project is built on the `pytorch-lightning` package, in particular its [command line interface (CLI)](https://pytorch-lightning.readthedocs.io/en/latest/cli/lightning_cli_intermediate.html). Thus we recommend you to have some knowledge about the lightning CLI. For Chinese users, you can learn the CLI & lightning with this beginner project [pytorch_lightning_template_for_beginners](https://github.com/Audio-WestlakeU/pytorch_lightning_template_for_beginners). 42 | 43 | **Train** SpatialNet on the 0-th GPU with network config file `configs/SpatialNet.yaml` and dataset config file `configs/datasets/sms_wsj_plus.yaml` (replace the RIR & clean speech dirs before training). 44 | 45 | ```bash 46 | python SharedTrainer.py fit \ 47 | --config=configs/SpatialNet.yaml \ # network config 48 | --config=configs/datasets/sms_wsj_plus.yaml \ # dataset config 49 | --model.channels=[0,1,2,3,4,5] \ # the channels used 50 | --model.arch.dim_input=12 \ # input dim per T-F point, i.e. 2 * the number of channels 51 | --model.arch.dim_output=4 \ # output dim per T-F point, i.e. 2 * the number of sources 52 | --model.arch.num_freqs=129 \ # the number of frequencies, related to model.stft.n_fft 53 | --trainer.precision=bf16-mixed \ # mixed precision training, can also be 16-mixed or 32, where 32 can produce the best performance 54 | --model.compile=true \ # compile the network, requires torch>=2.0. the compiled model is trained much faster 55 | --data.batch_size=[2,4] \ # batch size for train and val 56 | --trainer.devices=0, \ 57 | --trainer.max_epochs=100 # better performance may be obtained if more epochs are given 58 | ``` 59 | 60 | More GPUs can be used by appending the GPU indexes to `trainer.devices`, e.g. `--trainer.devices=0,1,2,3,`. 61 | 62 | **Resume** training from a checkpoint: 63 | 64 | ```bash 65 | python SharedTrainer.py fit --config=logs/SpatialNet/version_x/config.yaml \ 66 | --data.batch_size=[2,2] \ 67 | --trainer.devices=0, \ 68 | --ckpt_path=logs/SpatialNet/version_x/checkpoints/last.ckpt 69 | ``` 70 | 71 | where `version_x` should be replaced with the version you want to resume. 72 | 73 | **Test** the trained model: 74 | 75 | ```bash 76 | python SharedTrainer.py test --config=logs/SpatialNet/version_x/config.yaml \ 77 | --ckpt_path=logs/SpatialNet/version_x/checkpoints/epochY_neg_si_sdrZ.ckpt \ 78 | --trainer.devices=0, 79 | ``` 80 | 81 | ## Module Version 82 | 83 | | network | file | 84 | |:---|:---| 85 | | NB-BLSTM [1] / NBC [2] / NBC2 [3] | models/arch/NBSS.py | 86 | | SpatialNet [4] | models/arch/SpatialNet.py | 87 | | online SpatialNet [5] | models/arch/OnlineSpatialNet.py | 88 | 89 | ## Note 90 | The dataset generation & training commands for `NB-BLSTM`/`NBC`/`NBC2` are available in the `NBSS` branch.
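A quick sanity check for the three `--model.arch.*` overrides above: they are all derived from the channel list, the number of sources and the STFT size. The following minimal Python sketch (illustration only, not a file of this repo) reproduces the arithmetic behind `dim_input=12`, `dim_output=4` and `num_freqs=129` for 6 channels, 2 sources and `n_fft=256`:

```python
# Illustration only: how the SpatialNet dimension overrides are derived.
channels = [0, 1, 2, 3, 4, 5]  # --model.channels
num_sources = 2                # e.g. 2-speaker mixtures in SMS-WSJ-Plus
n_fft = 256                    # model.stft.n_fft in configs/SpatialNet.yaml

dim_input = 2 * len(channels)  # real + imag per channel at each T-F point -> 12
dim_output = 2 * num_sources   # real + imag per source at each T-F point  -> 4
num_freqs = n_fft // 2 + 1     # one-sided STFT frequency bins              -> 129

print(dim_input, dim_output, num_freqs)  # 12 4 129
```

These values must be overridden consistently whenever the channel list, the number of sources, or `n_fft` changes.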
91 | -------------------------------------------------------------------------------- /configs/NB-BLSTM.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 2 2 | trainer: 3 | gradient_clip_val: 5 4 | gradient_clip_algorithm: norm 5 | devices: null 6 | accelerator: gpu 7 | strategy: auto 8 | sync_batchnorm: false 9 | precision: 32 10 | model: 11 | arch: 12 | class_path: models.arch.blstm2_fc1.BLSTM2_FC1 13 | init_args: 14 | activation: "" 15 | hidden_size: 16 | - 256 17 | - 128 18 | n_repeat_last_lstm: 1 19 | dropout: null 20 | channels: [0, 1, 2, 3, 4, 5] 21 | ref_channel: 0 22 | stft: 23 | class_path: models.io.stft.STFT 24 | init_args: 25 | n_fft: 256 26 | n_hop: 128 27 | loss: 28 | class_path: models.io.loss.Loss 29 | init_args: 30 | loss_func: models.io.loss.neg_si_sdr 31 | pit: True 32 | norm: 33 | class_path: models.io.norm.Norm 34 | init_args: 35 | mode: frequency 36 | optimizer: [Adam, { lr: 0.001 }] 37 | lr_scheduler: [ReduceLROnPlateau, { mode: min, factor: 0.5, patience: 10, min_lr: 0.0001 }] 38 | exp_name: exp 39 | metrics: [SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI] 40 | val_metric: loss 41 | -------------------------------------------------------------------------------- /configs/NBC.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 2 2 | trainer: 3 | gradient_clip_val: 5 4 | gradient_clip_algorithm: norm 5 | devices: null 6 | accelerator: gpu 7 | strategy: auto 8 | sync_batchnorm: false 9 | precision: 32 10 | model: 11 | arch: 12 | class_path: models.arch.NBC.NBC 13 | init_args: 14 | # dim_input: 12 15 | # dim_output: 4 16 | encoder_kernel_size: 4 17 | n_heads: 8 18 | n_layers: 4 19 | activation: "" 20 | hidden_size: 192 21 | norm_first: true 22 | ffn_size: 384 23 | inner_conv_kernel_size: 3 24 | inner_conv_groups: 8 25 | inner_conv_bias: true 26 | inner_conv_layers: 3 27 | inner_conv_mid_norm: GN 28 | channels: [0, 1, 2, 3, 4, 5] 29 | ref_channel: 0 30 | stft: 31 | class_path: models.io.stft.STFT 32 | init_args: 33 | n_fft: 256 34 | n_hop: 128 35 | loss: 36 | class_path: models.io.loss.Loss 37 | init_args: 38 | loss_func: models.io.loss.neg_si_sdr 39 | pit: True 40 | norm: 41 | class_path: models.io.norm.Norm 42 | init_args: 43 | mode: frequency 44 | optimizer: [Adam, { lr: 0.001 }] 45 | lr_scheduler: [ReduceLROnPlateau, { mode: min, factor: 0.5, patience: 5, min_lr: 0.0001 }] 46 | exp_name: exp 47 | metrics: [SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI] 48 | val_metric: loss 49 | -------------------------------------------------------------------------------- /configs/NBC2.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 2 2 | trainer: 3 | gradient_clip_val: 5 4 | gradient_clip_algorithm: norm 5 | devices: null 6 | accelerator: gpu 7 | strategy: auto 8 | sync_batchnorm: false 9 | precision: 32 10 | model: 11 | arch: 12 | class_path: models.arch.NBC2.NBC2 13 | init_args: 14 | # dim_input: 12 15 | # dim_output: 4 16 | n_layers: 8 # 12 for large 17 | encoder_kernel_size: 5 18 | dim_hidden: 96 # 192 for large 19 | dim_ffn: 192 # 384 for large 20 | num_freqs: 129 21 | block_kwargs: 22 | n_heads: 2 23 | dropout: 0 24 | conv_kernel_size: 3 25 | n_conv_groups: 8 26 | norms: [LN, GBN, GBN] 27 | group_batch_norm_kwargs: 28 | # group_size: 129 29 | share_along_sequence_dim: false 30 | channels: [0, 1, 2, 3, 4, 5] 31 | ref_channel: 0 32 | stft: 33 | class_path: models.io.stft.STFT 34 | init_args: 35 | n_fft: 256 36 | n_hop: 
128 37 | loss: 38 | class_path: models.io.loss.Loss 39 | init_args: 40 | loss_func: models.io.loss.neg_si_sdr 41 | pit: True 42 | norm: 43 | class_path: models.io.norm.Norm 44 | init_args: 45 | mode: frequency 46 | optimizer: [Adam, { lr: 0.001 }] 47 | lr_scheduler: [ExponentialLR, { gamma: 0.99 }] 48 | exp_name: exp 49 | metrics: [SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI] 50 | val_metric: loss 51 | -------------------------------------------------------------------------------- /configs/SpatialNet.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 2 2 | trainer: 3 | gradient_clip_val: 5 4 | gradient_clip_algorithm: norm 5 | devices: null 6 | accelerator: gpu 7 | strategy: auto 8 | sync_batchnorm: false 9 | precision: 32 10 | model: 11 | arch: 12 | class_path: models.arch.SpatialNet.SpatialNet 13 | init_args: 14 | # dim_input: 12 15 | # dim_output: 4 16 | num_layers: 8 # 12 for large 17 | encoder_kernel_size: 5 18 | dim_hidden: 96 # 192 for large 19 | dim_ffn: 192 # 384 for large 20 | num_heads: 4 21 | dropout: [0, 0, 0] 22 | kernel_size: [5, 3] 23 | conv_groups: [8, 8] 24 | norms: ["LN", "LN", "GN", "LN", "LN", "LN"] 25 | dim_squeeze: 8 # 16 for large 26 | num_freqs: 129 27 | full_share: 0 28 | channels: [0, 1, 2, 3, 4, 5] 29 | ref_channel: 0 30 | stft: 31 | class_path: models.io.stft.STFT 32 | init_args: 33 | n_fft: 256 34 | n_hop: 128 35 | loss: 36 | class_path: models.io.loss.Loss 37 | init_args: 38 | loss_func: models.io.loss.neg_si_sdr 39 | pit: True 40 | norm: 41 | class_path: models.io.norm.Norm 42 | init_args: 43 | mode: frequency 44 | optimizer: [Adam, { lr: 0.001 }] 45 | lr_scheduler: [ExponentialLR, { gamma: 0.99 }] 46 | exp_name: exp 47 | metrics: [SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI] 48 | val_metric: loss 49 | -------------------------------------------------------------------------------- /configs/datasets/CHiME3_moving_rir_cfg.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/configs/datasets/CHiME3_moving_rir_cfg.npz -------------------------------------------------------------------------------- /configs/datasets/chime3_moving.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: data_loaders.chime3_moving.CHiME3MovingDataModule 3 | init_args: 4 | wsj0_dir: ~/datasets/wsj0 5 | rir_dir: ~/datasets/CHiME3_moving_rirs 6 | chime3_dir: ~/datasets/CHiME3 7 | target: direct_path 8 | datasets: ["train_moving(0.12,0.4)", "val_moving(0.12,0.4)", "test_moving(0.12,0.4)", "test_moving(0.12,0.4)"] 9 | audio_time_len: [4.0, 32.0, 32.0, 32.0] 10 | snr: [-5, 10] 11 | batch_size: [4, 8] 12 | -------------------------------------------------------------------------------- /configs/datasets/chime3_moving_plus_static.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: data_loaders.chime3_moving.CHiME3MovingDataModule 3 | init_args: 4 | wsj0_dir: ~/datasets/wsj0 5 | rir_dir: ~/datasets/CHiME3_moving_rirs 6 | chime3_dir: ~/datasets/CHiME3 7 | target: direct_path 8 | datasets: ["train_moving(0.12,0.4,0.5)", "val_moving(0.12,0.4,0.5)", "test_moving(0.12,0.4,0.5)", "test_moving(0.12,0.4,0.5)"] 9 | audio_time_len: [4.0, 32.0, 32.0, 32.0] 10 | snr: [-5, 10] 11 | batch_size: [4, 8] 12 | -------------------------------------------------------------------------------- 
/configs/datasets/chime3_static.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: data_loaders.chime3_moving.CHiME3MovingDataModule 3 | init_args: 4 | wsj0_dir: ~/datasets/wsj0 5 | rir_dir: ~/datasets/CHiME3_moving_rirs 6 | chime3_dir: ~/datasets/CHiME3 7 | target: direct_path 8 | datasets: ["train", "val", "test", "test"] 9 | audio_time_len: [4.0, 32.0, 32.0, 32.0] 10 | snr: [-5, 10] 11 | batch_size: [4, 8] 12 | -------------------------------------------------------------------------------- /configs/datasets/sms_wsj.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: data_loaders.sms_wsj.SmsWsjDataModule 3 | init_args: 4 | sms_wsj_dir: ~/datasets/sms_wsj/ 5 | target: direct_path 6 | audio_time_len: [4.0, 4.0, null] 7 | batch_size: [2, 1] 8 | test_set: test 9 | -------------------------------------------------------------------------------- /configs/datasets/sms_wsj_plus.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: data_loaders.sms_wsj_plus.SmsWsjPlusDataModule 3 | init_args: 4 | sms_wsj_dir: ~/datasets/sms_wsj/ 5 | rir_dir: ~/datasets/SMS_WSJ_Plus_rirs/ 6 | target: direct_path 7 | datasets: ["train_si284", "cv_dev93", "test_eval92", "test_eval92"] 8 | audio_time_len: [4.0, 4.0, null, null] 9 | ovlp: mid 10 | speech_overlap_ratio: [0.1, 1.0] 11 | sir: [-5, 5] 12 | snr: [0, 20] 13 | num_spk: 2 14 | noise_type: ["babble", "white"] 15 | batch_size: [2, 1] 16 | -------------------------------------------------------------------------------- /configs/datasets/sms_wsj_plus_diffuse.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/configs/datasets/sms_wsj_plus_diffuse.npz -------------------------------------------------------------------------------- /configs/datasets/sms_wsj_rir_cfg.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/configs/datasets/sms_wsj_rir_cfg.npz -------------------------------------------------------------------------------- /configs/datasets/spatialized_wsj0_mix.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: data_loaders.spatialized_wsj0_mix.SpatializedWSJ0MixDataModule 3 | init_args: 4 | sp_wsj0_dir: ~/datasets/spatialized-wsj0-mix 5 | version: min 6 | target: reverb 7 | sample_rate: 8000 8 | num_speakers: 2 9 | audio_time_len: [4.0, 4.0, null] 10 | batch_size: [2, 1] 11 | -------------------------------------------------------------------------------- /configs/datasets/whamr.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: data_loaders.whamr.WHAMRDataModule 3 | init_args: 4 | whamr_dir: ~/datasets/whamr 5 | version: min 6 | target: anechoic 7 | sample_rate: 8000 8 | audio_time_len: [4.0, 4.0, null] 9 | batch_size: [2, 1] 10 | -------------------------------------------------------------------------------- /configs/onlineSpatialNet.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 2 2 | trainer: 3 | gradient_clip_val: 1 4 | gradient_clip_algorithm: norm 5 | devices: null 6 | accelerator: gpu 7 | strategy: auto 8 | sync_batchnorm: false 9 
| precision: 32 10 | model: 11 | arch: 12 | class_path: models.arch.OnlineSpatialNet.OnlineSpatialNet 13 | init_args: 14 | # dim_input: 16 15 | # dim_output: 4 16 | num_layers: 8 17 | encoder_kernel_size: 5 18 | dim_hidden: 96 19 | dim_ffn: 192 20 | num_heads: 4 21 | dropout: [0, 0, 0] 22 | kernel_size: [5, 3] 23 | conv_groups: [8, 8] 24 | norms: ["LN", "LN", "GN", "LN", "LN", "LN"] 25 | dim_squeeze: 8 26 | # num_freqs: 257 27 | full_share: 0 # set to 9999 to not share the full-band module, which will increase the model performance with the cost of larger parameter size. 28 | attention: mamba(16,4) # mhsa(251)/ret(2)/mamba(16,4) 29 | decay: [4, 5, 9, 10] 30 | rope: false 31 | channels: [0, 1, 2, 3, 4, 5] 32 | ref_channel: 0 33 | stft: 34 | class_path: models.io.stft.STFT 35 | init_args: {} # by default set to {} to avoid using wrong stft config 36 | # n_fft: 256 37 | # n_hop: 128 38 | loss: 39 | class_path: models.io.loss.Loss 40 | init_args: 41 | loss_func: models.io.loss.neg_snr 42 | pit: true 43 | norm: 44 | class_path: models.io.norm.Norm 45 | init_args: 46 | mode: utterance 47 | online: true 48 | optimizer: [AdamW, { lr: 0.001, weight_decay: 0.001}] 49 | lr_scheduler: [ExponentialLR, { gamma: 0.99 }] 50 | # lr_scheduler: [ReduceLROnPlateau, {mode: max, factor: 0.5, patience: 5, min_lr: 0.0001}] 51 | exp_name: exp 52 | metrics: [SNR, SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI] 53 | val_metric: loss 54 | early_stopping: 55 | enable: false 56 | monitor: val/metric 57 | patience: 10 58 | mode: max 59 | min_delta: 0.1 60 | -------------------------------------------------------------------------------- /configs/wsj0-mix/speaker_gender.csv: -------------------------------------------------------------------------------- 1 | ID Gender 2 | 001 M 3 | 002 F 4 | 00a F 5 | 00b M 6 | 00c M 7 | 00d M 8 | 00f F 9 | 010 M 10 | 011 F 11 | 012 M 12 | 013 M 13 | 014 F 14 | 015 M 15 | 016 F 16 | 017 F 17 | 018 F 18 | 019 F 19 | 01l M 20 | 01a F 21 | 01b F 22 | 01c F 23 | 01d F 24 | 01e M 25 | 01f F 26 | 01g M 27 | 01h F 28 | 01i M 29 | 01j F 30 | 01k F 31 | 01m F 32 | 01n F 33 | 01o F 34 | 01p F 35 | 01q F 36 | 01r M 37 | 01s M 38 | 01t M 39 | 01u F 40 | 01v F 41 | 01w M 42 | 01x F 43 | 01y M 44 | 01z M 45 | 020 M 46 | 021 M 47 | 022 F 48 | 023 F 49 | 024 M 50 | 025 M 51 | 026 M 52 | 027 F 53 | 028 F 54 | 029 M 55 | 02a F 56 | 02b M 57 | 02c F 58 | 02d F 59 | 02e F 60 | 02f F 61 | 050 F 62 | 051 M 63 | 052 M 64 | 053 F 65 | 200 M 66 | 201 M 67 | 202 F 68 | 203 F 69 | 204 F 70 | 205 F 71 | 206 F 72 | 207 M 73 | 208 M 74 | 209 F 75 | 20a F 76 | 20b F 77 | 20c M 78 | 20d F 79 | 20e F 80 | 20f M 81 | 20g M 82 | 20h F 83 | 20i M 84 | 20j M 85 | 20k M 86 | 20l M 87 | 20m M 88 | 20n M 89 | 20o M 90 | 20p F 91 | 20q M 92 | 20r M 93 | 20s M 94 | 20t F 95 | 20u M 96 | 20v M 97 | 22g M 98 | 22h M 99 | 400 M 100 | 401 F 101 | 403 M 102 | 404 F 103 | 405 M 104 | 406 M 105 | 407 F 106 | 408 M 107 | 409 F 108 | 40a M 109 | 40b M 110 | 40c M 111 | 40d F 112 | 40e F 113 | 40f M 114 | 40g F 115 | 40h F 116 | 40i M 117 | 40j M 118 | 40k M 119 | 40l F 120 | 40m F 121 | 40n M 122 | 40o F 123 | 40p F 124 | 420 F 125 | 421 F 126 | 422 M 127 | 423 M 128 | 430 F 129 | 431 M 130 | 432 F 131 | 440 M 132 | 441 F 133 | 442 M 134 | 443 M 135 | 444 F 136 | 445 F 137 | 446 M 138 | 447 M -------------------------------------------------------------------------------- /data_loaders/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/data_loaders/__init__.py -------------------------------------------------------------------------------- /data_loaders/spatialized_wsj0_mix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from os.path import * 4 | from pathlib import Path 5 | from typing import * 6 | from typing import Callable, List, Optional, Tuple 7 | 8 | import numpy as np 9 | import soundfile as sf 10 | import torch 11 | from pytorch_lightning import LightningDataModule 12 | from pytorch_lightning.utilities.rank_zero import rank_zero_info 13 | from torch.utils.data import DataLoader, Dataset 14 | 15 | from data_loaders.utils.collate_func import default_collate_func 16 | from data_loaders.utils.my_distributed_sampler import MyDistributedSampler 17 | from data_loaders.utils.rand import randint 18 | 19 | 20 | class SpatializedWSJMixDataset(Dataset): 21 | """The Spatialized WSJ0-2/3Mix dataset""" 22 | 23 | def __init__( 24 | self, 25 | sp_wsj0_dir: str, 26 | dataset: str, 27 | version: str = 'min', 28 | target: str = 'reverb', 29 | audio_time_len: Optional[float] = None, 30 | sample_rate: int = 8000, 31 | num_speakers: int = 2, 32 | ) -> None: 33 | """The Spatialized WSJ-2/3Mix dataset 34 | 35 | Args: 36 | sp_wsj0_dir: a dir contains [2speakers_reverb] 37 | dataset: tr, cv, tt 38 | target: anechoic or reverb 39 | version: min or max 40 | audio_time_len: cut the audio to `audio_time_len` seconds if given audio_time_len 41 | """ 42 | super().__init__() 43 | assert target in ['anechoic', 'reverb'], target 44 | assert sample_rate in [8000, 16000], sample_rate 45 | assert dataset in ['tr', 'cv', 'tt'], dataset 46 | assert version in ['min', 'max'], version 47 | assert num_speakers in [2, 3], num_speakers 48 | 49 | self.sp_wsj0_dir = str(Path(sp_wsj0_dir).expanduser()) 50 | self.wav_dir = Path(self.sp_wsj0_dir) / f"{num_speakers}speakers_{target}" / {8000: 'wav8k', 16000: 'wav16k'}[sample_rate] / version / dataset 51 | self.files = [basename(str(x)) for x in list((self.wav_dir / 'mix').rglob('*.wav'))] 52 | self.files.sort() 53 | assert len(self.files) > 0, f"dir is empty or not exists: {self.sp_wsj0_dir}" 54 | 55 | self.version = version 56 | self.dataset = dataset 57 | self.target = target 58 | self.audio_time_len = audio_time_len 59 | self.sr = sample_rate 60 | 61 | def __getitem__(self, index_seed: Union[int, Tuple[int, int]]): 62 | if type(index_seed) == int: 63 | index = index_seed 64 | if self.dataset == 'tr': 65 | seed = random.randint(a=0, b=99999999) 66 | else: 67 | seed = index 68 | else: 69 | index, seed = index_seed 70 | g = torch.Generator() 71 | g.manual_seed(seed) 72 | 73 | mix, sr = sf.read(self.wav_dir / 'mix' / self.files[index]) 74 | s1, sr = sf.read(self.wav_dir / 's1' / self.files[index]) 75 | s2, sr = sf.read(self.wav_dir / 's2' / self.files[index]) 76 | assert sr == self.sr, (sr, self.sr) 77 | mix = mix.T 78 | target = np.stack([s1.T, s2.T], axis=0) # [spk, chn, time] 79 | 80 | # pad or cut signals 81 | T = mix.shape[-1] 82 | start = 0 83 | if self.audio_time_len: 84 | frames = int(sr * self.audio_time_len) 85 | if T < frames: 86 | mix = np.pad(mix, pad_width=((0, 0), (0, frames - T)), mode='constant', constant_values=0) 87 | target = np.pad(target, pad_width=((0, 0), (0, 0), (0, frames - T)), mode='constant', constant_values=0) 88 | elif T > frames: 89 | start = randint(g, low=0, high=T - frames) 90 | mix = mix[:, start:start + frames] 
91 | target = target[:, :, start:start + frames] 92 | 93 | paras = { 94 | 'index': index, 95 | 'seed': seed, 96 | 'wavname': self.files[index], 97 | 'wavdir': str(self.wav_dir), 98 | 'sample_rate': self.sr, 99 | 'dataset': self.dataset, 100 | 'target': self.target, 101 | 'version': self.version, 102 | 'audio_time_len': self.audio_time_len, 103 | 'start': start, 104 | } 105 | 106 | return torch.as_tensor(mix, dtype=torch.float32), torch.as_tensor(target, dtype=torch.float32), paras 107 | 108 | def __len__(self): 109 | return len(self.files) 110 | 111 | 112 | class SpatializedWSJ0MixDataModule(LightningDataModule): 113 | 114 | def __init__( 115 | self, 116 | sp_wsj0_dir: str, 117 | version: str = 'min', 118 | target: str = 'reverb', 119 | sample_rate: int = 8000, 120 | num_speakers: int = 2, 121 | audio_time_len: Tuple[Optional[float], Optional[float], Optional[float]] = [4.0, 4.0, None], # audio_time_len (seconds) for training, val, test. 122 | batch_size: List[int] = [1, 1], 123 | test_set: str = 'test', # the dataset to test: train, val, test 124 | num_workers: int = 5, 125 | collate_func_train: Callable = default_collate_func, 126 | collate_func_val: Callable = default_collate_func, 127 | collate_func_test: Callable = default_collate_func, 128 | seeds: Tuple[Optional[int], int, int] = [None, 2, 3], # random seeds for train, val and test sets 129 | # if pin_memory=True, will occupy a lot of memory & speed up 130 | pin_memory: bool = True, 131 | # prefetch how many samples, will increase the memory occupied when pin_memory=True 132 | prefetch_factor: int = 5, 133 | persistent_workers: bool = False, 134 | ): 135 | super().__init__() 136 | self.sp_wsj0_dir = sp_wsj0_dir 137 | self.version = version 138 | self.target = target 139 | self.sample_rate = sample_rate 140 | self.num_speakers = num_speakers 141 | self.audio_time_len = audio_time_len 142 | self.persistent_workers = persistent_workers 143 | self.test_set = test_set 144 | 145 | rank_zero_info("dataset: SpatializedWSJ0Mix") 146 | rank_zero_info(f'train/valid/test set: {version} {target} {sample_rate}, time length={audio_time_len}, {num_speakers}spk') 147 | assert audio_time_len[2] == None, 'test audio time length should be None' 148 | 149 | self.batch_size = batch_size[0] 150 | self.batch_size_val = batch_size[1] 151 | self.batch_size_test = 1 152 | if len(batch_size) > 2: 153 | self.batch_size_test = batch_size[2] 154 | rank_zero_info(f'batch size: train={self.batch_size}; val={self.batch_size_val}; test={self.batch_size_test}') 155 | assert self.batch_size_test == 1, "batch size for test should be 1 as the audios have different length" 156 | 157 | self.num_workers = num_workers 158 | 159 | self.collate_func_train = collate_func_train 160 | self.collate_func_val = collate_func_val 161 | self.collate_func_test = collate_func_test 162 | 163 | self.seeds = [] 164 | for seed in seeds: 165 | self.seeds.append(seed if seed is not None else random.randint(0, 1000000)) 166 | 167 | self.pin_memory = pin_memory 168 | self.prefetch_factor = prefetch_factor 169 | 170 | def setup(self, stage=None): 171 | if stage is not None and stage == 'test': 172 | audio_time_len = None 173 | else: 174 | audio_time_len = self.audio_time_len[0] 175 | 176 | self.train = SpatializedWSJMixDataset( 177 | sp_wsj0_dir=self.sp_wsj0_dir, 178 | dataset='tr', 179 | version=self.version, 180 | target=self.target, 181 | audio_time_len=audio_time_len, 182 | sample_rate=self.sample_rate, 183 | ) 184 | self.val = SpatializedWSJMixDataset( 185 | sp_wsj0_dir=self.sp_wsj0_dir, 186 
| dataset='cv', 187 | version=self.version, 188 | target=self.target, 189 | audio_time_len=self.audio_time_len[1], 190 | sample_rate=self.sample_rate, 191 | ) 192 | self.test = SpatializedWSJMixDataset( 193 | sp_wsj0_dir=self.sp_wsj0_dir, 194 | dataset='tt', 195 | version=self.version, 196 | target=self.target, 197 | audio_time_len=None, 198 | sample_rate=self.sample_rate, 199 | ) 200 | 201 | def train_dataloader(self) -> DataLoader: 202 | return DataLoader( 203 | self.train, 204 | sampler=MyDistributedSampler(self.train, seed=self.seeds[0], shuffle=True), 205 | batch_size=self.batch_size, 206 | collate_fn=self.collate_func_train, 207 | num_workers=self.num_workers, 208 | prefetch_factor=self.prefetch_factor, 209 | pin_memory=self.pin_memory, 210 | persistent_workers=self.persistent_workers, 211 | ) 212 | 213 | def val_dataloader(self) -> DataLoader: 214 | return DataLoader( 215 | self.val, 216 | sampler=MyDistributedSampler(self.val, seed=self.seeds[1], shuffle=False), 217 | batch_size=self.batch_size_val, 218 | collate_fn=self.collate_func_val, 219 | num_workers=self.num_workers, 220 | prefetch_factor=self.prefetch_factor, 221 | pin_memory=self.pin_memory, 222 | persistent_workers=self.persistent_workers, 223 | ) 224 | 225 | def test_dataloader(self) -> DataLoader: 226 | if self.test_set == 'test': 227 | dataset = self.test 228 | elif self.test_set == 'val': 229 | dataset = self.val 230 | else: # train 231 | dataset = self.train 232 | 233 | return DataLoader( 234 | dataset, 235 | sampler=MyDistributedSampler(self.test, seed=self.seeds[2], shuffle=False), 236 | batch_size=self.batch_size_test, 237 | collate_fn=self.collate_func_test, 238 | num_workers=self.num_workers, 239 | prefetch_factor=self.prefetch_factor, 240 | pin_memory=self.pin_memory, 241 | persistent_workers=self.persistent_workers, 242 | ) 243 | 244 | 245 | if __name__ == '__main__': 246 | """python -m data_loaders.spatialized_wsj0_mix""" 247 | from argparse import ArgumentParser 248 | parser = ArgumentParser("") 249 | parser.add_argument('--sp_wsj0_dir', type=str, default='~/datasets/spatialized-wsj0-mix') 250 | parser.add_argument('--version', type=str, default='min') 251 | parser.add_argument('--target', type=str, default='reverb') 252 | parser.add_argument('--sample_rate', type=int, default=8000) 253 | parser.add_argument('--gen_unprocessed', type=bool, default=True) 254 | parser.add_argument('--gen_target', type=bool, default=True) 255 | parser.add_argument('--save_dir', type=str, default='dataset/sp_wsj') 256 | parser.add_argument('--dataset', type=str, default='train', choices=['train', 'val', 'test']) 257 | 258 | args = parser.parse_args() 259 | os.makedirs(args.save_dir, exist_ok=True) 260 | 261 | if not args.gen_unprocessed and not args.gen_target: 262 | exit() 263 | 264 | datamodule = SpatializedWSJ0MixDataModule(args.sp_wsj0_dir, args.version, target=args.target, sample_rate=args.sample_rate, batch_size=[1, 1], num_workers=1) 265 | datamodule.setup() 266 | if args.dataset == 'train': 267 | dataloader = datamodule.train_dataloader() 268 | elif args.dataset == 'val': 269 | dataloader = datamodule.val_dataloader() 270 | else: 271 | assert args.dataset == 'test' 272 | dataloader = datamodule.test_dataloader() 273 | 274 | for idx, (mix, tar, paras) in enumerate(dataloader): 275 | # write target to dir 276 | print(mix.shape, tar.shape, paras) 277 | 278 | if idx > 10: 279 | continue 280 | 281 | if args.gen_target: 282 | tar_path = Path(f"{args.save_dir}/{args.target}/{args.dataset}").expanduser() 283 | 
tar_path.mkdir(parents=True, exist_ok=True) 284 | for i in range(2): 285 | # assert np.max(np.abs(tar[0, i, 0, :].numpy())) <= 1 286 | sp = tar_path / (paras[0]['wavname'] + f'_spk{i}.wav') 287 | if not sp.exists(): 288 | sf.write(sp, tar[0, i, 0, :].numpy(), samplerate=paras[0]['sample_rate']) 289 | 290 | # write unprocessed's 0-th channel 291 | if args.gen_unprocessed: 292 | tar_path = Path(f"{args.save_dir}/unprocessed/{args.dataset}").expanduser() 293 | tar_path.mkdir(parents=True, exist_ok=True) 294 | # assert np.max(np.abs(mix[0, 0, :].numpy())) <= 1 295 | sp = tar_path / (paras[0]['wavname']) 296 | if not sp.exists(): 297 | sf.write(sp, mix[0, 0, :].numpy(), samplerate=paras[0]['sample_rate']) 298 | -------------------------------------------------------------------------------- /data_loaders/utils/array_geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tqdm 3 | from numpy.linalg import norm 4 | 5 | 6 | def normalize(vec: np.ndarray) -> np.ndarray: 7 | # get unit vector 8 | vec = vec / norm(vec) 9 | vec = vec / norm(vec) 10 | assert np.isclose(norm(vec), 1), 'norm of vec is not close to 1' 11 | return vec 12 | 13 | 14 | def circular_array_geometry(radius: float, mic_num: int) -> np.ndarray: 15 | # generate the geometry of a circular array centered at the origin; the array position can later be changed by rotating it or moving its center 16 | pos_rcv = np.empty((mic_num, 3)) 17 | v1 = np.array([1, 0, 0]) # position of the first microphone (must be a unit vector) 18 | v1 = normalize(v1) # unit vector 19 | # rotate v1 horizontally around the origin by `angle` to generate the positions of the other mics 20 | angles = np.arange(0, 2 * np.pi, 2 * np.pi / mic_num) 21 | for idx, angle in enumerate(angles): 22 | x = v1[0] * np.cos(angle) - v1[1] * np.sin(angle) 23 | y = v1[0] * np.sin(angle) + v1[1] * np.cos(angle) 24 | pos_rcv[idx, :] = normalize(np.array([x, y, 0])) 25 | # apply the radius 26 | pos_rcv *= radius 27 | return pos_rcv 28 | 29 | 30 | def linear_array_geometry(radius: float, mic_num: int) -> np.ndarray: 31 | xs = np.arange(start=0, stop=radius * mic_num, step=radius) 32 | xs -= np.mean(xs) # move the center to the origin 33 | pos_rcv = np.zeros((mic_num, 3)) 34 | pos_rcv[:, 0] = xs 35 | return pos_rcv 36 | 37 | 38 | def chime3_array_geometry() -> np.ndarray: 39 | # TODO: add microphone orientation vectors, and mark each microphone as omnidirectional/semi-directional 40 | pos_rcv = np.zeros((6, 3)) 41 | pos_rcv[0, :] = np.array([-0.1, 0.095, 0]) 42 | pos_rcv[1, :] = np.array([0, 0.095, 0]) 43 | pos_rcv[2, :] = np.array([0.1, 0.095, 0]) 44 | pos_rcv[3, :] = np.array([-0.1, -0.095, 0]) 45 | pos_rcv[4, :] = np.array([0, -0.095, 0]) 46 | pos_rcv[5, :] = np.array([0.1, -0.095, 0]) 47 | 48 | # verify that the side lengths are correct and that adjacent sides are perpendicular 49 | assert np.isclose(np.linalg.norm(pos_rcv[0, :] - pos_rcv[1, :]), 0.1), 'distance between #1 and #2 is wrong' 50 | assert np.isclose(np.linalg.norm(pos_rcv[1, :] - pos_rcv[2, :]), 0.1), 'distance between #2 and #3 is wrong' 51 | assert np.isclose(np.linalg.norm(pos_rcv[0, :] - pos_rcv[3, :]), 0.19), 'distance between #1 and #4 is wrong' 52 | assert np.isclose(np.linalg.norm(pos_rcv[2, :] - pos_rcv[5, :]), 0.19), 'distance between #3 and #6 is wrong' 53 | assert np.isclose(np.linalg.norm(pos_rcv[3, :] - pos_rcv[4, :]), 0.1), 'distance between #4 and #5 is wrong' 54 | assert np.isclose(np.linalg.norm(pos_rcv[4, :] - pos_rcv[5, :]), 0.1), 'distance between #5 and #6 is wrong' 55 | assert np.isclose(np.dot(pos_rcv[0, :] - pos_rcv[1, :], pos_rcv[0, :] - pos_rcv[3, :]), 0), 'not perpendicular' 56 | assert np.isclose(np.dot(pos_rcv[2, :] - pos_rcv[5, :], pos_rcv[4, :] - pos_rcv[5, :]), 0), 'not perpendicular' 57 | return pos_rcv 58 | 59 | 60 | def libricss_array_geometry() -> np.ndarray: 61 | pos_rcv = np.zeros((7, 3)) 62 | pos_rcv_c = circular_array_geometry(radius=0.0425, mic_num=6)
63 | pos_rcv[1:, :] = pos_rcv_c 64 | return pos_rcv 65 | -------------------------------------------------------------------------------- /data_loaders/utils/collate_func.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | 7 | 8 | def default_collate_func(batches: List[Tuple[Tensor, Tensor, Dict[str, Any]]]) -> List[Any]: 9 | mini_batch = [] 10 | for x in zip(*batches): 11 | if isinstance(x[0], np.ndarray): 12 | x = [torch.tensor(x[i]) for i in range(len(x))] 13 | if isinstance(x[0], Tensor): 14 | x = torch.stack(x) 15 | mini_batch.append(x) 16 | return mini_batch 17 | -------------------------------------------------------------------------------- /data_loaders/utils/diffuse_noise.py: -------------------------------------------------------------------------------- 1 | ############################################################################################# 2 | # Translated from https://github.com/ehabets/ANF-Generator. Implementation of the method in 3 | # 4 | # Generating Nonstationary Multisensor Signals under a Spatial Coherence Constraint. 5 | # Habets, Emanuël A. P. and Cohen, Israel and Gannot, Sharon 6 | # 7 | # Note: Though the generated noise is diffuse, it doesn't simulate the reverberation of rooms 8 | # 9 | # Copyright: Changsheng Quan @ Audio Lab of Westlake University 10 | ############################################################################################# 11 | 12 | import math 13 | 14 | import numpy as np 15 | import scipy.special 16 | from scipy.signal import stft, istft 17 | 18 | 19 | def gen_desired_spatial_coherence(pos_mics: np.ndarray, fs: int, noise_field: str = 'spherical', c: float = 343.0, nfft: int = 256) -> np.ndarray: 20 | """generate desired spatial coherence for one array 21 | 22 | Args: 23 | pos_mics: microphone positions, shape (num_mics, 3) 24 | fs: sampling frequency 25 | noise_field: 'spherical' or 'cylindrical' 26 | c: sound velocity 27 | nfft: points of fft 28 | 29 | Raises: 30 | Exception: Unknown noise field if noise_field != 'spherical' and != 'cylindrical' 31 | 32 | Returns: 33 | np.ndarray: desired spatial coherence, shape [num_mics, num_mics, num_freqs] 34 | np.ndarray: desired mixing matrices, shape [num_freqs, num_mics, num_mics] 35 | 36 | 37 | Reference: E. A. P. Habets, “Arbitrary noise field generator.” https://github.com/ehabets/ANF-Generator 38 | """ 39 | assert pos_mics.shape[1] == 3, pos_mics.shape 40 | M = pos_mics.shape[0] 41 | num_freqs = nfft // 2 + 1 42 | 43 | # compute desired spatial coherence matrix 44 | ww = 2 * math.pi * fs * np.array(list(range(num_freqs))) / nfft 45 | dist = np.linalg.norm(pos_mics[:, np.newaxis, :] - pos_mics[np.newaxis, :, :], axis=-1, keepdims=True) 46 | if noise_field == 'spherical': 47 | DSC = np.sinc(ww * dist / (c * math.pi)) 48 | elif noise_field == 'cylindrical': 49 | DSC = scipy.special.jv(0, ww * dist / c) # zeroth-order Bessel function of the first kind 50 | else: 51 | raise Exception('Unknown noise field') 52 | 53 | # compute mixing matrices of the desired spatial coherence matrix 54 | Cs = np.zeros((num_freqs, M, M), dtype=np.complex128) 55 | for k in range(1, num_freqs): 56 | D, V = np.linalg.eig(DSC[:, :, k]) 57 | C = V.T * np.sqrt(D)[:, np.newaxis] 58 | # C = scipy.linalg.cholesky(DSC[:, :, k]) 59 | Cs[k, ...] = C
60 | 61 | return DSC, Cs 62 | 63 | 64 | def gen_diffuse_noise(noise: np.ndarray, L: int, Cs: np.ndarray, nfft: int = 256, rng: np.random.Generator = np.random.default_rng()) -> np.ndarray: 65 | """generate diffuse noise with the mixing matrices of the desired spatial coherence 66 | 67 | Args: 68 | noise: at least `num_mic*L` samples long 69 | L: the length in samples 70 | Cs: mixing matrices, shape [num_freqs, num_mics, num_mics] 71 | nfft: the number of fft points 72 | rng: random number generator used for reproducibility 73 | 74 | Returns: 75 | np.ndarray: multi-channel diffuse noise, shape [num_mics, L] 76 | """ 77 | 78 | M = Cs.shape[-1] 79 | assert noise.shape[-1] >= M * L, noise.shape 80 | 81 | # Generate M mutually 'independent' input signals 82 | # noise = noise - np.mean(noise) 83 | assert noise.shape[-1] >= M * L, ("The noise signal should be at least `num_mic*L` samples long", noise.shape, M, L) 84 | start = rng.integers(low=0, high=noise.shape[-1] - M * L + 1) 85 | noise = noise[start:start + M * L].reshape(M, L) 86 | noise = noise - np.mean(noise, axis=-1, keepdims=True) 87 | f, t, N = stft(noise, window='hann', nperseg=nfft, noverlap=0.75 * nfft, nfft=nfft) # N: [M,F,T] 88 | # Generate output in the STFT domain for each frequency bin k 89 | X = np.einsum('fmn,mft->nft', np.conj(Cs), N) 90 | # Compute inverse STFT 91 | F, x = istft(X, window='hann', nperseg=nfft, noverlap=0.75 * nfft, nfft=nfft) 92 | x = x[:, :L] 93 | return x # [M, L] 94 | 95 | 96 | if __name__ == '__main__': 97 | import matplotlib.pyplot as plt 98 | import soundfile as sf 99 | from pathlib import Path 100 | # pos_mic = np.random.randn(8, 3) 101 | # pos_mic = np.array([[0, 0, 1.5], [0, 0.2, 1.5]]) 102 | nfft = 1024 103 | num_mics = 3 104 | pos_mics = [[0, 0, 1.5]] 105 | for i in range(1, num_mics): 106 | pos_mics.append([0, 0.3 * i, 1.5]) 107 | pos_mics = np.array(pos_mics) 108 | DSC, Cs = gen_desired_spatial_coherence(pos_mics=pos_mics, fs=8000, noise_field='spherical', nfft=nfft) 109 | 110 | # wav_files = Path('dataset/datasets/fma_small/000').rglob('*.mp3') 111 | # noise = np.concatenate([sf.read(wav_file,always_2d=True)[0][:,0] for wav_file in wav_files]) 112 | noise = np.random.randn(8000 * 22 * 8) 113 | x = gen_diffuse_noise(noise=noise, L=8000 * 20, Cs=Cs, nfft=nfft) # 20 s at 8 kHz 114 | 115 | f, t, X = stft(x, window='hann', nperseg=nfft, noverlap=0.75 * nfft, nfft=nfft) # X: [M,F,T] 116 | cross_psd_0 = np.mean(X[[0], :, :] * np.conj(X[1:, :, :]), axis=-1) 117 | cross_psd_1 = np.mean(np.abs(X[[0], :, :])**2, axis=-1) * np.mean(np.abs(X[1:, :, :])**2, axis=-1) 118 | cross_psd = cross_psd_0 / np.sqrt(cross_psd_1) 119 | sc_generated = np.real(cross_psd) # the real part is an even function of f. Only the real part is plotted here because the imaginary part is small? Because the spatial coherence above is real-valued? 120 | 121 | ww = 2 * math.pi * 8000 * np.array(list(range(nfft // 2 + 1))) / nfft 122 | if num_mics > 2: 123 | dist = np.linalg.norm(pos_mics[1:] - pos_mics[[0], ...], axis=-1, keepdims=True) 124 | else: 125 | dist = np.linalg.norm(pos_mics[[1]] - pos_mics[[0], ...], axis=-1, keepdims=True) 126 | sc_theory = np.sinc(ww * dist / (343 * math.pi)) 127 | 128 | for i in range(len(sc_theory)): 129 | plt.plot(list(range(nfft // 2 + 1)), sc_theory[i]) 130 | plt.plot(list(range(nfft // 2 + 1)), sc_generated[i]) 131 | plt.title(f"Chn {i+2} vs.
Chn {1}") 132 | plt.show() 133 | -------------------------------------------------------------------------------- /data_loaders/utils/mix.py: -------------------------------------------------------------------------------- 1 | ############################################################################################################## 2 | # Note: there are datasets relaying on these functions. 3 | # Changing the implementations might change the data generated by these datasets. 4 | # 5 | # Copyright: Changsheng Quan @ Audio Lab of Westlake University 6 | ############################################################################################################## 7 | """ 8 | mid: 9 | --- 10 | ------- 11 | headtail: 12 | ------ 13 | ---- 14 | startend: 15 | --- start/end --- 16 | ------ ------ 17 | full: 18 | --------- 19 | --------- 20 | hms: headtail, mid or startend 21 | fhms: full, headtail, mid or startend 22 | """ 23 | 24 | from typing import * 25 | 26 | import numpy as np 27 | from numpy import ndarray 28 | from numpy.random import Generator 29 | from scipy.signal import fftconvolve 30 | 31 | OVLPS = ['mid', 'headtail', 'startend', 'full', 'hms', 'fhms'] 32 | 33 | 34 | def sample_an_overlap(ovlp_type: str, num_spk: int, rng: Generator) -> str: 35 | assert ovlp_type in OVLPS, ovlp_type 36 | assert num_spk in [1, 2], num_spk 37 | 38 | if num_spk == 1: 39 | ovlp_type = 'full' 40 | elif ovlp_type == 'fhms': 41 | types = ['full', 'headtail', 'mid', 'startend'] 42 | which_type = rng.integers(low=0, high=len(types)) 43 | ovlp_type = types[which_type] 44 | elif ovlp_type == 'hms': 45 | types = ['headtail', 'mid', 'startend'] 46 | which_type = rng.integers(low=0, high=len(types)) 47 | ovlp_type = types[which_type] 48 | else: 49 | assert ovlp_type in ['full', 'headtail', 'mid', 'startend'], ovlp_type 50 | 51 | if ovlp_type == 'startend': 52 | types = ['start', 'end'] 53 | which_type = rng.integers(low=0, high=len(types)) 54 | ovlp_type = types[which_type] 55 | else: 56 | ovlp_type = ovlp_type 57 | 58 | return ovlp_type 59 | 60 | 61 | def sample_ovlp_ratio_and_cal_length(ovlp_type: str, ratio_range: Tuple[float, float], target_len: Optional[int], lens: List[int], rng: Generator) -> Tuple[float, List[int], int]: 62 | """sample one overlap ratio and calculate the needed length for each wav 63 | 64 | Returns: 65 | Tuple[float, List[int]]: ovlp_ratio, needed length 66 | """ 67 | for rr in ratio_range: 68 | assert rr >= 0 and rr <= 1, rr 69 | assert ratio_range[0] <= ratio_range[1], ratio_range 70 | 71 | if target_len == None: 72 | mix_frames = max(lens) 73 | if ovlp_type == 'full': 74 | ovlp_ratio = 1 75 | elif ovlp_type == 'headtail': 76 | low = ratio_range[0] 77 | high = np.min(lens) / np.max(lens) 78 | if low > high: 79 | ovlp_ratio = high 80 | else: 81 | ovlp_ratio = rng.uniform(low=low, high=high) 82 | mix_frames = round((np.min(lens) + np.max(lens)) / (1 + ovlp_ratio)) 83 | elif ovlp_type == 'mid': 84 | ovlp_ratio = np.min(lens) / np.max(lens) 85 | else: 86 | assert ovlp_type in ['start', 'end'], ovlp_type 87 | ovlp_ratio = np.min(lens) / np.max(lens) 88 | else: 89 | mix_frames = target_len 90 | ovlp_ratio = rng.uniform(low=ratio_range[0], high=ratio_range[1]) 91 | if ovlp_type == 'full': 92 | lens = [mix_frames] * len(lens) 93 | ovlp_ratio = 1 94 | elif ovlp_type == 'headtail': 95 | lens = [int(mix_frames * (0.5 + ovlp_ratio / 2))] * len(lens) 96 | else: 97 | assert ovlp_type in ['mid', 'start', 'end'], ovlp_type 98 | max_idx = lens.index(max(lens)) 99 | min_idx = lens.index(min(lens)) 100 
| if max_idx == min_idx: 101 | max_idx = [1, 0][max_idx] 102 | lens[max_idx] = mix_frames 103 | lens[min_idx] = int(mix_frames * ovlp_ratio) 104 | return ovlp_ratio, lens, mix_frames 105 | 106 | 107 | def pad_or_cut(wavs: List[ndarray], lens: List[int], rng: Generator) -> List[ndarray]: 108 | """repeat signals if they are shorter than the length needed, then cut them to needed 109 | """ 110 | for i, wav in enumerate(wavs): 111 | # repeat 112 | while len(wav) < lens[i]: 113 | wav = np.concatenate([wav, wav]) 114 | # cut to needed length 115 | if len(wav) > lens[i]: 116 | start = rng.integers(low=0, high=len(wav) - lens[i] + 1) 117 | wav = wav[start:start + lens[i]] 118 | wavs[i] = wav 119 | return wavs 120 | 121 | 122 | def convolve(wav: ndarray, rir: ndarray, rir_target: ndarray, ref_channel: Optional[int] = 0, align: bool = True) -> Tuple[ndarray, ndarray]: 123 | assert wav.ndim == 1, wav.shape 124 | assert rir.ndim == 2 and rir_target.ndim == 2, (rir.shape, rir_target.shape) 125 | 126 | rvbt = fftconvolve(wav[np.newaxis, :], rir, mode='full', axes=-1) 127 | target = rvbt if rir is rir_target else fftconvolve(wav[np.newaxis, :], rir_target, mode='full', axes=-1) 128 | if align: 129 | # Note: don't take the dry clean source as target if the ref_channel is not correct 130 | rir = rir[ref_channel, ...] 131 | delay = np.argmax(rir) 132 | rvbt = rvbt[:, delay:delay + wav.shape[-1]] 133 | target = target[:, delay:delay + wav.shape[-1]] 134 | return rvbt, target 135 | 136 | def convolve_v2(wav: ndarray, rir: ndarray, rir_target: ndarray, ref_channel: Optional[int] = 0, align: bool = True) -> Tuple[ndarray, ndarray]: 137 | assert wav.ndim == 1, wav.shape 138 | assert rir.ndim == 2 and rir_target.ndim == 2, (rir.shape, rir_target.shape) 139 | 140 | rvbt = fftconvolve(wav[np.newaxis, :], rir, mode='full', axes=-1) 141 | target = rvbt if rir is rir_target else fftconvolve(wav[np.newaxis, :], rir_target, mode='full', axes=-1) 142 | if align: 143 | # Note: don't take the dry clean source as target if the ref_channel is not correct 144 | rir_align = rir_target[ref_channel, ...] # use rir_target rather than rir 145 | delay = np.argmax(rir_align) 146 | rvbt = rvbt[:, delay:delay + wav.shape[-1]] 147 | target = target[:, delay:delay + wav.shape[-1]] 148 | return rvbt, target 149 | 150 | 151 | def convolve_traj(wav: np.ndarray, traj_rirs: np.ndarray, traj_rirs_tar: np.ndarray, samples_per_rir: Union[np.ndarray, int], ref_channel: Optional[int] = 0, align: bool = True) -> np.ndarray: 152 | """Convolve wav by using a set of trajectory rirs (Note: the generated audio signal using this method has click noise) 153 | 154 | Args: 155 | wav: shape [time] 156 | traj_rirs: shape [num_rirs, num_mics, num_samples] 157 | traj_rirs_tar: shape [num_rirs, num_mics, num_samples] 158 | samples_per_rir: the number of samples in wav for each trajectory rir 159 | ref_channel: reference channel. Defaults to 0. 160 | align: Defaults to True. 
161 | 162 | Returns: 163 | np.ndarray: [num_mics, time] 164 | """ 165 | assert wav.ndim == 1, "not implemented" 166 | wav_samps = wav.shape[0] 167 | if isinstance(samples_per_rir, np.ndarray): 168 | assert samples_per_rir.ndim == 1 169 | assert samples_per_rir.sum() == wav.shape[-1], "the number of samples specified in samples_per_rir should match that of the wav" 170 | else: 171 | if wav_samps % samples_per_rir == 0: 172 | samples_per_rir = [samples_per_rir] * (wav_samps // samples_per_rir) 173 | else: 174 | samples_per_rir = [samples_per_rir] * (wav_samps // samples_per_rir) + [wav_samps % samples_per_rir] 175 | (num_rirs, num_mics, rir_samps), rir_samps_t = traj_rirs.shape, traj_rirs_tar.shape[-1] 176 | assert num_rirs == len(samples_per_rir), "the number of rirs should match the length of samples_per_rir" 177 | 178 | rvbt = np.zeros((num_mics, rir_samps + wav_samps - 1), dtype=np.float32) 179 | target = np.zeros((num_mics, rir_samps_t + wav_samps - 1), dtype=np.float32) 180 | start_samp = 0 181 | for i, n_samps in enumerate(samples_per_rir): 182 | wav_i = wav[start_samp:start_samp + n_samps] 183 | rir_i = traj_rirs[i] 184 | rir_i_tar = traj_rirs_tar[i] 185 | rvbt[:, start_samp:start_samp + n_samps + rir_samps - 1] += fftconvolve(wav_i[np.newaxis], rir_i, mode='full', axes=-1) 186 | target[:, start_samp:start_samp + n_samps + rir_samps_t - 1] += fftconvolve(wav_i[np.newaxis], rir_i_tar, mode='full', axes=-1) 187 | start_samp += n_samps 188 | 189 | if align: 190 | rir = traj_rirs_tar[0, ref_channel, ...] 191 | delay = np.argmax(rir) 192 | rvbt = rvbt[:, delay:delay + wav.shape[-1]] 193 | target = target[:, delay:delay + wav.shape[-1]] 194 | return rvbt, target 195 | 196 | 197 | def convolve_traj_with_win(wav: np.ndarray, traj_rirs: np.ndarray, samples_per_rir: int, wintype: str = 'trapezium20') -> np.ndarray: 198 | """Convolve wav by using a set of trajectory rirs (Note: the generated audio signal using this method barely have click noise) 199 | 200 | Args: 201 | wav: shape [T] 202 | traj_rirs: shape [num_rirs, num_mics, num_samples] 203 | samples_per_rir: the number of samples in wav for each trajectory rir 204 | wintype: hann, tri, or trapezium. by default, trapezium20 is used. 
205 | 206 | Returns: 207 | np.ndarray: [num_mics, time] 208 | """ 209 | assert wav.ndim == 1, "not implemented" 210 | wav_samps = wav.shape[0] 211 | 212 | hop = samples_per_rir 213 | samples_per_rir = samples_per_rir * 2 214 | num_rirs, num_mics, rir_samps = traj_rirs.shape 215 | 216 | if wintype == 'hann': 217 | win = np.hanning(samples_per_rir) 218 | elif wintype.startswith('trapezium'): # 10 points on each of the left and right sides 219 | n = int(wintype.replace('trapezium', '')) 220 | assert hop - n > 0, hop 221 | tri = np.arange(0, n) / (n - 1) 222 | tri2 = np.arange((n - 1), -1, -1) / (n - 1) 223 | zlen = (hop - n) // 2 224 | onelen = hop - n - zlen 225 | win = np.concatenate([np.zeros(zlen), tri, np.ones(onelen * 2), tri2, np.zeros(zlen)]) 226 | else: 227 | assert wintype == 'tri', wintype 228 | win = np.concatenate([np.arange(0, samples_per_rir // 2), np.arange(samples_per_rir // 2 - 1, -1, -1)]) / (samples_per_rir // 2 - 1) 229 | 230 | out = np.zeros((num_mics, rir_samps + wav_samps - 1), dtype=np.float32) 231 | for i, start_samp in enumerate(range(0, wav_samps + hop - 1, hop)): 232 | rir_i = traj_rirs[i] 233 | 234 | if start_samp == 0: 235 | wav_i = wav[start_samp:start_samp + hop] * win[hop:] 236 | out[:, start_samp:start_samp + hop + rir_samps - 1] += fftconvolve(wav_i[np.newaxis], rir_i, axes=-1) 237 | elif wav.shape[-1] >= start_samp + hop: 238 | wav_i = wav[start_samp - hop:start_samp + hop] * win 239 | out[:, start_samp - hop:start_samp + hop + rir_samps - 1] += fftconvolve(wav_i[np.newaxis], rir_i, axes=-1) 240 | else: 241 | wav_i = wav[start_samp - hop:] * win[:wav.shape[-1] - start_samp + hop] 242 | out[:, start_samp - hop:] += fftconvolve(wav_i[np.newaxis], rir_i, axes=-1) 243 | 244 | return out 245 | 246 | 247 | def align(rir: np.ndarray, rvbt: np.ndarray, target: np.ndarray, src: np.ndarray): 248 | assert rir.ndim == 1 and src.ndim == 1, (rir.shape, src.shape) 249 | delay = np.argmax(rir) 250 | rvbt = rvbt[:, delay:delay + src.shape[-1]] 251 | target = target[:, delay:delay + src.shape[-1]] 252 | return rvbt, target 253 | 254 | 255 | def convolve1(wav: ndarray, rir: ndarray, ref_channel: Optional[int] = 0, align: bool = True) -> Union[Tuple[ndarray, ndarray], ndarray]: 256 | assert wav.ndim == 1, wav.shape 257 | while wav.ndim < rir.ndim: 258 | wav = wav[np.newaxis, ...]
259 | rvbt = fftconvolve(wav, rir, mode='full', axes=-1) 260 | if align: 261 | # Note: don't take the dry clean source as target if the ref_channel is not correct 262 | if rir.ndim >= 2: 263 | rir = rir[..., ref_channel, :] # the second last dim is regarded as the channel dim 264 | delay = np.argmax(rir) 265 | rvbt = rvbt[..., delay:delay + wav.shape[-1]] 266 | return rvbt 267 | 268 | 269 | def overlap2(rvbts: List[ndarray], targets: List[ndarray], ovlp_type: str, mix_frames: int, rng: Generator) -> Tuple[ndarray, ndarray]: 270 | assert np.array([rvbt_i.shape == target_i.shape for (rvbt_i, target_i) in zip(rvbts, targets)]).all(), "rvbt and target should have the same shape" 271 | assert len(rvbts) <= 2 and len(targets) <= 2, "this function is used only for two-speaker overlapping" 272 | assert rvbts[0].ndim == 2 and rvbts[0].shape[0] < rvbts[0].shape[1], "rvbt should have a shape of [chn, time]" 273 | 274 | num_spk, chn_num = len(rvbts), rvbts[0].shape[0] 275 | 276 | rvbt = np.zeros((num_spk, chn_num, mix_frames), dtype=np.float32) 277 | target = np.zeros((num_spk, chn_num, mix_frames), dtype=np.float32) 278 | 279 | for i, (rvbt_i, target_i) in enumerate(zip(rvbts, targets)): 280 | # overlap signals 281 | Ti = rvbt_i.shape[-1] # use all revbt signals 282 | 283 | if ovlp_type == 'full': 284 | shift = 0 285 | elif ovlp_type == 'mid': 286 | if Ti == mix_frames: 287 | shift = 0 288 | else: 289 | shift = rng.integers(low=0, high=mix_frames - Ti + 1) # [0, mix_frames - use_len] 290 | elif ovlp_type == 'start' or ovlp_type == 'end': 291 | assert num_spk == 2 292 | if Ti == mix_frames: 293 | shift = 0 294 | else: 295 | shift = {'start': 0, 'end': mix_frames - Ti}[ovlp_type] 296 | else: 297 | assert ovlp_type == 'headtail', ovlp_type 298 | assert num_spk == 2 299 | shift = 0 if i == 0 else (mix_frames - Ti) 300 | 301 | rvbt[i, :, shift:shift + Ti] = rvbt_i[:, :] 302 | target[i, :, shift:shift + Ti] = target_i[:, :] 303 | return rvbt, target 304 | 305 | 306 | def overlap3(rvbts: List[ndarray], targets: List[ndarray], mix_frames: int, rng: Generator, output_stream: int = 2) -> Tuple[ndarray, ndarray]: 307 | assert np.array([rvbt_i.shape == target_i.shape for (rvbt_i, target_i) in zip(rvbts, targets)]).all(), "rvbt and target should have the same shape" 308 | assert len(rvbts) == 3 and len(targets) == 3, "this function is used only for 3-speaker overlapping" 309 | assert output_stream == 2, "2-stream output is supported only" 310 | assert rvbts[0].ndim == 2 and rvbts[0].shape[0] < rvbts[0].shape[1], "rvbt should have a shape of [chn, time]" 311 | 312 | num_spk, chn_num = len(rvbts), rvbts[0].shape[0] 313 | 314 | rvbt = np.zeros((2, chn_num, mix_frames), dtype=np.float32) 315 | target = np.zeros((2, chn_num, mix_frames), dtype=np.float32) 316 | 317 | rvbt[0, :, :] = rvbts[0][:, :] 318 | rvbt[1, :, :rvbts[1].shape[-1]] = rvbts[1][:, :] 319 | rvbt[1, :, -rvbts[2].shape[-1]:] = rvbts[2][:, :] 320 | 321 | target[0, :, :] = targets[0][:, :] 322 | target[1, :, :targets[1].shape[-1]] = targets[1][:, :] 323 | target[1, :, -targets[2].shape[-1]:] = targets[2][:, :] 324 | 325 | return rvbt, target 326 | 327 | 328 | def cal_coeff_for_adjusting_relative_energy(wav1: ndarray, wav2: ndarray, target_dB: float) -> Optional[float]: 329 | r"""calculate the coefficient used for adjusting the relative energy of two signals 330 | 331 | Args: 332 | wav1: the first wav 333 | wav2: the second wav 334 | target_dB: the target relative energy in dB, i.e. 
after adjusting: 10 * log_10 (average_energy(wav1) / average_energy(wav2 * coeff)) = target_dB 335 | 336 | Returns: 337 | float: coeff 338 | """ 339 | # compute averaged energy over time and channel 340 | ae1 = np.sum(wav1**2) / np.prod(wav1.shape) 341 | ae2 = np.sum(wav2**2) / np.prod(wav2.shape) 342 | if ae1 == 0 or ae2 == 0 or not np.isfinite(ae1) or not np.isfinite(ae2): 343 | return None 344 | # compute the coefficients 345 | coeff = np.sqrt(ae1 / ae2 * np.power(10, -target_dB / 10)) 346 | return coeff # multiply it with wav2 347 | -------------------------------------------------------------------------------- /data_loaders/utils/my_distributed_sampler.py: -------------------------------------------------------------------------------- 1 | ############################################################################################################## 2 | # Code reproducity is essential for DL. The MyDistributedSampler tries to make datasets reproducible by 3 | # generating a seed for each dataset item at specific epoch. 4 | # 5 | # Copyright: Changsheng Quan @ Audio Lab of Westlake University 2023 6 | ############################################################################################################## 7 | 8 | 9 | 10 | import math 11 | from typing import Iterator, Optional 12 | 13 | import torch 14 | from pytorch_lightning.utilities.rank_zero import rank_zero_warn 15 | from torch.utils.data import Dataset 16 | from torch.utils.data.distributed import DistributedSampler, T_co 17 | 18 | 19 | class MyDistributedSampler(DistributedSampler[T_co]): 20 | r"""Sampler for single GPU and multi GPU (or Distributed) cases. Change int index to a tuple (index, random seed for this index). 21 | This sampler is used to enhance the reproducibility of datasets by generating random seed for each item at each epoch. 
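Concretely, __iter__ yields (index, seed) tuples instead of plain int indices, so the paired dataset's __getitem__ is expected to accept either form (see e.g. WHAMRDataset.__getitem__).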
22 | """ 23 | 24 | def __init__( 25 | self, 26 | dataset: Dataset, 27 | num_replicas: Optional[int] = None, 28 | rank: Optional[int] = None, 29 | shuffle: bool = True, 30 | seed: int = 0, 31 | drop_last: bool = False, 32 | ) -> None: 33 | try: 34 | super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) 35 | except: 36 | # if error raises, it is running on single GPU 37 | # thus, set num_replicas=1, rank=0 38 | super().__init__(dataset, 1, 0, shuffle, seed, drop_last) 39 | self.last_epoch = -1 40 | 41 | def __iter__(self) -> Iterator[T_co]: 42 | if self.shuffle: 43 | # deterministically shuffle based on epoch and seed 44 | g = torch.Generator() 45 | g.manual_seed(self.seed + self.epoch) 46 | if self.last_epoch >= self.epoch: 47 | if self.epoch != 0: 48 | rank_zero_warn(f'shuffle is true but the epoch value doesn\'t get update, thus the order of training data won\'t change at epoch={self.epoch}') 49 | else: 50 | self.last_epoch = self.epoch 51 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore 52 | else: 53 | g = torch.Generator() 54 | g.manual_seed(self.seed) 55 | 56 | indices = list(range(len(self.dataset))) # type: ignore 57 | 58 | seeds = [] 59 | for i in range(len(indices)): 60 | seed = torch.randint(high=9999999999, size=(1,), generator=g)[0].item() 61 | seeds.append(seed) 62 | indices = list(zip(indices, seeds)) 63 | 64 | # drop last 65 | if not self.drop_last: 66 | # add extra samples to make it evenly divisible 67 | padding_size = self.total_size - len(indices) 68 | if padding_size <= len(indices): 69 | indices += indices[:padding_size] 70 | else: 71 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 72 | else: 73 | # remove tail of data to make it evenly divisible. 74 | indices = indices[:self.total_size] 75 | assert len(indices) == self.total_size 76 | 77 | # subsample 78 | indices = indices[self.rank:self.total_size:self.num_replicas] 79 | assert len(indices) == self.num_samples 80 | 81 | return iter(indices) # type: ignore 82 | 83 | def __len__(self) -> int: 84 | return self.num_samples 85 | 86 | def set_epoch(self, epoch: int) -> None: 87 | r""" 88 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 89 | use a different random ordering for each epoch. Otherwise, the next iteration of this 90 | sampler will yield the same ordering. 91 | 92 | Args: 93 | epoch (int): Epoch number. 
94 | """ 95 | self.epoch = epoch 96 | -------------------------------------------------------------------------------- /data_loaders/utils/rand.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def randint(g: torch.Generator, low: int, high: int) -> int: 5 | """return a value sampled in [low, high) 6 | """ 7 | if low == high: 8 | return low 9 | r = torch.randint(low=low, high=high, size=(1,), generator=g, device='cpu') 10 | return r[0].item() # type:ignore 11 | 12 | 13 | def randfloat(g: torch.Generator, low: float, high: float) -> float: 14 | """return a value sampled in [low, high) 15 | """ 16 | if low == high: 17 | return low 18 | r = torch.rand(size=(1,), generator=g, device='cpu')[0].item() 19 | return float(low + r * (high - low)) 20 | -------------------------------------------------------------------------------- /data_loaders/utils/window.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def reverberation_time_shortening_window(rir: np.ndarray, original_T60: float, target_T60: float, sr: int = 8000, time_after_max: float = 0.002, time_before_max: float = None) -> np.ndarray: 5 | """shorten the T60 of a given rir 6 | 7 | Args: 8 | rir: the rir array 9 | original_T60: the T60 of the rir 10 | target_T60: the target T60 11 | sr: sample rate 12 | time_after_max: time in seconds after the maximum value in rir taken as part of the direct path. Defaults to 0.002. 13 | time_before_max: time in seconds before the maximum value in rir taken as part of the direct path. By default, all the values before the maximum are taken as direct path. 14 | 15 | Returns: 16 | np.ndarray: the reverberation time shortening window 17 | """ 18 | 19 | if original_T60 <= target_T60: 20 | return np.ones(shape=rir.shape) 21 | shape = rir.shape 22 | rir = rir.reshape(-1, shape[-1]) 23 | win = np.empty(shape=rir.shape, dtype=rir.dtype) 24 | q = 3 / (target_T60 * sr) - 3 / (original_T60 * sr) 25 | exps = 10**(-q * np.arange(rir.shape[-1])) 26 | idx_max_array = np.argmax(np.abs(rir), axis=-1) 27 | for i, idx_max in enumerate(idx_max_array): 28 | N1 = idx_max + int(time_after_max * sr) 29 | win[i, :N1] = 1 30 | win[i, N1:] = exps[:rir.shape[-1] - N1] 31 | if time_before_max: 32 | N0 = int(idx_max - time_before_max * sr) 33 | if N0 > 0: 34 | win[i, :N0] = 0 35 | win = win.reshape(shape) 36 | return win 37 | 38 | 39 | def rectangular_window(rir: np.ndarray, sr: int = 8000, time_before_after_max: float = 0.002) -> np.ndarray: 40 | assert rir.ndim == 1, rir.ndim 41 | idx = int(np.argmax(np.abs(rir))) 42 | win = np.zeros(shape=rir.shape) 43 | N = int(sr * time_before_after_max) 44 | win[max(0, idx - N):idx + N + 1] = 1 45 | return win 46 | 47 | 48 | if __name__ == '__main__': 49 | rir = np.random.rand(3, 2, 10000) 50 | rir[..., 1000] = 2 51 | win = reverberation_time_shortening_window(rir, original_T60=0.8, target_T60=0.1, sr=8000) 52 | print(win.shape) 53 | -------------------------------------------------------------------------------- /data_loaders/whamr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from os.path import * 4 | from pathlib import Path 5 | from typing import * 6 | 7 | import numpy as np 8 | import soundfile as sf 9 | import torch 10 | from pytorch_lightning import LightningDataModule 11 | from pytorch_lightning.utilities.rank_zero import rank_zero_info 12 | 13 | from torch.utils.data import 
DataLoader, Dataset 14 | 15 | from data_loaders.utils.collate_func import default_collate_func 16 | from data_loaders.utils.my_distributed_sampler import MyDistributedSampler 17 | from data_loaders.utils.rand import randint 18 | 19 | 20 | class WHAMRDataset(Dataset): 21 | """The WHAMR! dataset""" 22 | 23 | def __init__( 24 | self, 25 | whamr_dir: str, 26 | dataset: str, 27 | version: str = 'min', 28 | target: str = 'anechoic', 29 | audio_time_len: Optional[float] = None, 30 | sample_rate: int = 8000, 31 | ) -> None: 32 | """The WHAMR! dataset 33 | 34 | Args: 35 | whamr_dir: a dir contains [wav8k] 36 | dataset: tr, cv, tt 37 | target: anechoic or reverb 38 | version: min or max 39 | audio_time_len: cut the audio to `audio_time_len` seconds if given audio_time_len 40 | """ 41 | super().__init__() 42 | assert target in ['anechoic', 'reverb'], target 43 | assert sample_rate in [8000, 16000], sample_rate 44 | assert dataset in ['tr', 'cv', 'tt'], dataset 45 | assert version in ['min', 'max'], version 46 | 47 | self.whamr_dir = str(Path(whamr_dir).expanduser()) 48 | self.wav_dir = Path(self.whamr_dir) / {8000: 'wav8k', 16000: 'wav16k'}[sample_rate] / version / dataset 49 | self.files = [basename(str(x)) for x in list((self.wav_dir / 'mix_both_reverb').rglob('*.wav'))] 50 | self.files.sort() 51 | assert len(self.files) > 0, (self.whamr_dir, ': files is empty!') 52 | 53 | self.version = version 54 | self.dataset = dataset 55 | self.target = target 56 | self.audio_time_len = audio_time_len 57 | self.sr = sample_rate 58 | 59 | def __getitem__(self, index_seed: Union[int, Tuple[int, int]]): 60 | if type(index_seed) == int: 61 | index = index_seed 62 | if self.dataset == 'tr': 63 | seed = random.randint(a=0, b=99999999) 64 | else: 65 | seed = index 66 | else: 67 | index, seed = index_seed 68 | g = torch.Generator() 69 | g.manual_seed(seed) 70 | 71 | mix, sr = sf.read(self.wav_dir / 'mix_both_reverb' / self.files[index]) 72 | s1, sr = sf.read(self.wav_dir / ('s1_' + self.target) / self.files[index]) 73 | s2, sr = sf.read(self.wav_dir / ('s2_' + self.target) / self.files[index]) 74 | assert sr == self.sr, (sr, self.sr) 75 | mix = mix.T 76 | target = np.stack([s1.T, s2.T], axis=0) # [spk, chn, time] 77 | 78 | # pad or cut signals 79 | T = mix.shape[-1] 80 | start = 0 81 | if self.audio_time_len: 82 | frames = int(sr * self.audio_time_len) 83 | if T < frames: 84 | mix = np.pad(mix, pad_width=((0, 0), (0, frames - T)), mode='constant', constant_values=0) 85 | target = np.pad(target, pad_width=((0, 0), (0, 0), (0, frames - T)), mode='constant', constant_values=0) 86 | elif T > frames: 87 | start = randint(g, low=0, high=T - frames) 88 | mix = mix[:, start:start + frames] 89 | target = target[:, :, start:start + frames] 90 | 91 | paras = { 92 | 'index': index, 93 | 'seed': seed, 94 | 'wavname': self.files[index], 95 | 'wavdir': str(self.wav_dir), 96 | 'sample_rate': self.sr, 97 | 'dataset': self.dataset, 98 | 'target': self.target, 99 | 'version': self.version, 100 | 'audio_time_len': self.audio_time_len, 101 | 'start': start, 102 | } 103 | 104 | return torch.as_tensor(mix, dtype=torch.float32), torch.as_tensor(target, dtype=torch.float32), paras 105 | 106 | def __len__(self): 107 | return len(self.files) 108 | 109 | 110 | class WHAMRDataModule(LightningDataModule): 111 | 112 | def __init__( 113 | self, 114 | whamr_dir: str, 115 | version: str = 'min', 116 | target: str = 'anechoic', 117 | sample_rate: int = 8000, 118 | audio_time_len: Tuple[Optional[float], Optional[float], Optional[float]] = [4.0, 4.0, 
None], # audio_time_len (seconds) for training, val, test. 119 | batch_size: List[int] = [1, 1], 120 | test_set: str = 'test', # the dataset to test: train, val, test 121 | num_workers: int = 5, 122 | collate_func_train: Callable = default_collate_func, 123 | collate_func_val: Callable = default_collate_func, 124 | collate_func_test: Callable = default_collate_func, 125 | seeds: Tuple[Optional[int], int, int] = [None, 2, 3], # random seeds for train, val and test sets 126 | # if pin_memory=True, will occupy a lot of memory & speed up 127 | pin_memory: bool = True, 128 | # prefetch how many samples, will increase the memory occupied when pin_memory=True 129 | prefetch_factor: int = 5, 130 | persistent_workers: bool = False, 131 | ): 132 | super().__init__() 133 | self.whamr_dir = whamr_dir 134 | self.version = version 135 | self.target = target 136 | self.sample_rate = sample_rate 137 | self.audio_time_len = audio_time_len 138 | self.persistent_workers = persistent_workers 139 | self.test_set = test_set 140 | 141 | rank_zero_info(f'dataset: WHAMR!, datasets for train/valid/test: {version} {target} {sample_rate}, time length: {audio_time_len}') 142 | assert audio_time_len[2] == None, audio_time_len 143 | 144 | self.batch_size = batch_size[0] 145 | self.batch_size_val = batch_size[1] 146 | self.batch_size_test = 1 147 | if len(batch_size) > 2: 148 | self.batch_size_test = batch_size[2] 149 | rank_zero_info(f'batch size: train={self.batch_size}; val={self.batch_size_val}; test={self.batch_size_test}') 150 | # assert self.batch_size_val == 1, "batch size for validation should be 1 as the audios have different length" 151 | 152 | self.num_workers = num_workers 153 | 154 | self.collate_func_train = collate_func_train 155 | self.collate_func_val = collate_func_val 156 | self.collate_func_test = collate_func_test 157 | 158 | self.seeds = [] 159 | for seed in seeds: 160 | self.seeds.append(seed if seed is not None else random.randint(0, 1000000)) 161 | 162 | self.pin_memory = pin_memory 163 | self.prefetch_factor = prefetch_factor 164 | 165 | def setup(self, stage=None): 166 | if stage is not None and stage == 'test': 167 | audio_time_len = None 168 | else: 169 | audio_time_len = self.audio_time_len 170 | 171 | self.train = WHAMRDataset( 172 | whamr_dir=self.whamr_dir, 173 | dataset='tr', 174 | version=self.version, 175 | target=self.target, 176 | audio_time_len=self.audio_time_len[0] if stage != 'test' else None, 177 | sample_rate=self.sample_rate, 178 | ) 179 | self.val = WHAMRDataset( 180 | whamr_dir=self.whamr_dir, 181 | dataset='cv', 182 | version=self.version, 183 | target=self.target, 184 | audio_time_len=self.audio_time_len[1] if stage != 'test' else None, 185 | sample_rate=self.sample_rate, 186 | ) 187 | self.test = WHAMRDataset( 188 | whamr_dir=self.whamr_dir, 189 | dataset='tt', 190 | version=self.version, 191 | target=self.target, 192 | audio_time_len=self.audio_time_len[2], 193 | sample_rate=self.sample_rate, 194 | ) 195 | 196 | def train_dataloader(self) -> DataLoader: 197 | return DataLoader( 198 | self.train, 199 | sampler=MyDistributedSampler(self.train, seed=self.seeds[0], shuffle=True), 200 | batch_size=self.batch_size, 201 | collate_fn=self.collate_func_train, 202 | num_workers=self.num_workers, 203 | prefetch_factor=self.prefetch_factor, 204 | pin_memory=self.pin_memory, 205 | persistent_workers=self.persistent_workers, 206 | ) 207 | 208 | def val_dataloader(self) -> DataLoader: 209 | return DataLoader( 210 | self.val, 211 | sampler=MyDistributedSampler(self.val, 
seed=self.seeds[1], shuffle=False), 212 | batch_size=self.batch_size_val, 213 | collate_fn=self.collate_func_val, 214 | num_workers=self.num_workers, 215 | prefetch_factor=self.prefetch_factor, 216 | pin_memory=self.pin_memory, 217 | persistent_workers=self.persistent_workers, 218 | ) 219 | 220 | def test_dataloader(self) -> DataLoader: 221 | if self.test_set == 'test': 222 | dataset = self.test 223 | elif self.test_set == 'val': 224 | dataset = self.val 225 | else: # train 226 | dataset = self.train 227 | 228 | return DataLoader( 229 | dataset, 230 | sampler=MyDistributedSampler(dataset, seed=self.seeds[2], shuffle=False), # sample from the selected dataset, not always from self.test 231 | batch_size=self.batch_size_test, 232 | collate_fn=self.collate_func_test, 233 | num_workers=self.num_workers, 234 | prefetch_factor=self.prefetch_factor, 235 | pin_memory=self.pin_memory, 236 | persistent_workers=self.persistent_workers, 237 | ) 238 | 239 | 240 | if __name__ == '__main__': 241 | """python -m data_loaders.whamr""" 242 | from argparse import ArgumentParser 243 | parser = ArgumentParser("") 244 | parser.add_argument('--whamr_dir', type=str, default='~/datasets/whamr') 245 | parser.add_argument('--version', type=str, default='min') 246 | parser.add_argument('--target', type=str, default='anechoic') 247 | parser.add_argument('--sample_rate', type=int, default=8000) 248 | parser.add_argument('--gen_unprocessed', type=bool, default=True) 249 | parser.add_argument('--gen_target', type=bool, default=True) 250 | parser.add_argument('--save_dir', type=str, default='dataset/whamr') 251 | parser.add_argument('--dataset', type=str, default='train', choices=['train', 'val', 'test']) 252 | 253 | args = parser.parse_args() 254 | os.makedirs(args.save_dir, exist_ok=True) 255 | 256 | if not args.gen_unprocessed and not args.gen_target: 257 | exit() 258 | 259 | datamodule = WHAMRDataModule(args.whamr_dir, args.version, target=args.target, sample_rate=args.sample_rate, batch_size=[1, 1], num_workers=1) 260 | datamodule.setup() 261 | if args.dataset == 'train': 262 | dataloader = datamodule.train_dataloader() 263 | elif args.dataset == 'val': 264 | dataloader = datamodule.val_dataloader() 265 | else: 266 | assert args.dataset == 'test' 267 | dataloader = datamodule.test_dataloader() 268 | 269 | for idx, (mix, tar, paras) in enumerate(dataloader): 270 | # write target to dir 271 | print(mix.shape, tar.shape, paras) 272 | 273 | if idx > 10: 274 | continue 275 | 276 | if args.gen_target: 277 | tar_path = Path(f"{args.save_dir}/{args.target}/{args.dataset}").expanduser() 278 | tar_path.mkdir(parents=True, exist_ok=True) 279 | for i in range(2): 280 | # assert np.max(np.abs(tar[0, i, 0, :].numpy())) <= 1 281 | sp = tar_path / (paras[0]['wavname'] + f'_spk{i}.wav') 282 | if not sp.exists(): 283 | sf.write(sp, tar[0, i, 0, :].numpy(), samplerate=paras[0]['sample_rate']) 284 | 285 | # write unprocessed's 0-th channel 286 | if args.gen_unprocessed: 287 | tar_path = Path(f"{args.save_dir}/unprocessed/{args.dataset}").expanduser() 288 | tar_path.mkdir(parents=True, exist_ok=True) 289 | # assert np.max(np.abs(mix[0, 0, :].numpy())) <= 1 290 | sp = tar_path / (paras[0]['wavname']) 291 | if not sp.exists(): 292 | sf.write(sp, mix[0, 0, :].numpy(), samplerate=paras[0]['sample_rate']) 293 | -------------------------------------------------------------------------------- /images/model_size_and_flops.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/images/model_size_and_flops.png -------------------------------------------------------------------------------- /images/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/images/results.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/models/__init__.py -------------------------------------------------------------------------------- /models/arch/NBC.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.init as init 8 | from torch import Tensor 9 | 10 | 11 | class Linear(nn.Linear): 12 | """ 13 | Wrapper class of torch.nn.Linear 14 | Weight initialize by xavier initialization and bias initialize to zeros. 15 | """ 16 | 17 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: 18 | super(Linear, self).__init__(in_features=in_features, out_features=out_features, bias=bias) 19 | 20 | init.xavier_uniform_(self.weight) 21 | if bias: 22 | init.zeros_(self.bias) 23 | 24 | 25 | class RelativePositionalEncoding(nn.Module): 26 | """This class returns the relative positional encoding in a range. 27 | 28 | for i in [-m, m]: 29 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) 30 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) 31 | 32 | Arguments 33 | --------- 34 | max_len : int 35 | Max length of the input sequences (default 1000). 36 | 37 | Example 38 | ------- 39 | >>> a = torch.rand((8, 120, 512)) 40 | >>> enc = RelativePositionalEncoding(input_size=a.shape[-1]) 41 | >>> b = enc(a) 42 | >>> b.shape 43 | torch.Size([1, 239, 512]) 44 | """ 45 | 46 | def __init__(self, input_size, max_len=1000): 47 | super().__init__() 48 | # [-m, -m+1, ..., -1, 0, 1, ..., m] 49 | self.max_len = max_len 50 | self.zero_index = max_len 51 | pe = torch.zeros(self.max_len * 2 + 1, input_size, requires_grad=False) 52 | positions = torch.arange(-self.max_len, self.max_len + 1).unsqueeze(1).float() 53 | denominator = torch.exp(torch.arange(0, input_size, 2).float() * -(math.log(10000.0) / input_size)) 54 | 55 | pe[:, 0::2] = torch.sin(positions * denominator) 56 | pe[:, 1::2] = torch.cos(positions * denominator) 57 | pe = pe.unsqueeze(0) 58 | self.register_buffer("pe", pe) 59 | 60 | def forward(self, x: torch.Tensor) -> torch.Tensor: 61 | """ 62 | Args: 63 | x (torch.Tensor): shape [batch, time, feature] 64 | 65 | Returns: 66 | torch.Tensor: relative positional encoding 67 | """ 68 | B, T, F = x.shape 69 | start, end = -T + 1, T - 1 70 | return self.pe[:, start + self.zero_index:end + self.zero_index + 1].clone().detach() 71 | 72 | 73 | class RelativePositionalMultiHeadAttention(nn.Module): 74 | """ 75 | Multi-head attention with relative positional encoding. 
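The attention score is the sum of a content term, (query + u_bias) @ key^T, and a relative-position term, (query + v_bias) @ pos_embedding^T, scaled by 1/sqrt(d_model).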
76 | This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" 77 | """ 78 | 79 | def __init__( 80 | self, 81 | d_model: int = 256, 82 | num_heads: int = 8, 83 | dropout: float = 0.1, 84 | ): 85 | super(RelativePositionalMultiHeadAttention, self).__init__() 86 | assert d_model % num_heads == 0, "d_model % num_heads should be zero." 87 | self.d_model = d_model 88 | self.d_head = int(d_model / num_heads) 89 | self.num_heads = num_heads 90 | self.sqrt_dim = math.sqrt(d_model) 91 | 92 | self.query_proj = Linear(d_model, d_model) 93 | self.key_proj = Linear(d_model, d_model) 94 | self.value_proj = Linear(d_model, d_model) 95 | self.pos_proj = Linear(d_model, d_model, bias=False) 96 | self.rel_pos = RelativePositionalEncoding(d_model) 97 | 98 | self.dropout = nn.Dropout(p=dropout) 99 | self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) 100 | self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) 101 | init.xavier_uniform_(self.u_bias) 102 | init.xavier_uniform_(self.v_bias) 103 | 104 | self.out_proj = Linear(d_model, d_model) 105 | 106 | def forward(self, query: Tensor, key: Optional[Tensor] = None, value: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 107 | if key is None: 108 | key = query 109 | if value is None: 110 | value = query 111 | 112 | batch_size, time_frames, feature_size = value.shape 113 | pos_embedding = self.rel_pos.forward(value) 114 | 115 | query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) 116 | key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) 117 | value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) 118 | 119 | content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3)) 120 | 121 | pos_embedding = self.pos_proj(pos_embedding) 122 | #### my implementation to calculate pos_score #### 123 | pos_index = self._get_relative_pos_index(time_frames, pos_embedding.device) 124 | pos_embedding = pos_embedding[:, pos_index, :].view(1, time_frames, time_frames, self.num_heads, self.d_head) 125 | # original matmul 126 | # pos_score = torch.matmul((query + self.v_bias).transpose(1, 2).unsqueeze(3), pos_embedding.permute(0, 3, 1, 4, 2)).squeeze(3) 127 | # faster than matmul, and saving memory 128 | qv_bias = (query + self.v_bias).transpose(1, 2) # [B, N, T, D] 129 | pos_embedding = pos_embedding.permute(0, 3, 1, 4, 2) # [1, N, T, D, T], 1 is broadcasted to B 130 | pos_score = torch.einsum("abcd,abcdf->abcf", qv_bias, pos_embedding) 131 | score = (content_score + pos_score) / self.sqrt_dim 132 | 133 | if attn_mask is not None: 134 | score += attn_mask # add -inf to the masked positions before softmax 135 | 136 | attn = F.softmax(score, -1) 137 | attn = self.dropout(attn) 138 | output = torch.matmul(attn, value).transpose(1, 2) # [B, T, N, D] 139 | 140 | # output 141 | output = output.contiguous().view(batch_size, -1, self.d_model) 142 | 143 | return self.out_proj(output), attn 144 | 145 | def _get_relative_pos_index(self, T: int, device: torch.device) -> Tensor: 146 | with torch.no_grad(): 147 | pos1 = torch.arange(start=0, end=T, dtype=torch.long, device=device, requires_grad=False).unsqueeze(1) 148 | pos2 = torch.arange(start=0, end=T, dtype=torch.long, device=device, requires_grad=False).unsqueeze(0) 149 | relative_pos = pos1 - pos2 150 | """ now, relative_pos=[ 151 | [0,-1,-2,...,-(T-1)], 152 | [1, 0,-1,...,-(T-2)], 153 |
... 154 | [T-1,T-2,..., 1, 0] 155 | ] 156 | """ 157 | pos_index = relative_pos[:, :] + (T - 1) # (T-1) is the index of the relative position 0 158 | return pos_index 159 | 160 | 161 | class NBCBlock(nn.Module): 162 | 163 | def __init__( 164 | self, 165 | dim_model: int = 192, 166 | num_head: int = 8, 167 | dim_ffn: int = 384, 168 | dropout: float = 0.1, 169 | activation: Callable = F.silu, 170 | layer_norm_eps: float = 1e-5, 171 | norm_first: bool = True, 172 | n_conv_groups: int = 384, 173 | conv_kernel_size: int = 3, 174 | conv_bias: bool = True, 175 | n_conv_layers: int = 3, 176 | conv_mid_norm: str = "GN", 177 | ) -> None: 178 | super().__init__() 179 | 180 | self.self_attn = RelativePositionalMultiHeadAttention(dim_model, num_head, dropout=dropout) 181 | 182 | # Implementation of Feedforward model 183 | self.linear1 = Linear(dim_model, dim_ffn) 184 | self.dropout = nn.Dropout(dropout) 185 | self.linear2 = Linear(dim_ffn, dim_model) 186 | 187 | self.norm_first = norm_first 188 | self.norm1 = nn.LayerNorm(dim_model, eps=layer_norm_eps) 189 | self.norm2 = nn.LayerNorm(dim_model, eps=layer_norm_eps) 190 | self.dropout1 = nn.Dropout(dropout) 191 | self.dropout2 = nn.Dropout(dropout) 192 | 193 | self.activation = activation 194 | 195 | convs = [] 196 | for l in range(n_conv_layers): 197 | convs.append(nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=conv_kernel_size, padding='same', groups=n_conv_groups, bias=conv_bias)) 198 | if conv_mid_norm != None: 199 | if conv_mid_norm == 'GN': 200 | convs.append(nn.GroupNorm(8, dim_ffn)) 201 | else: 202 | raise Exception('Unspoorted mid norm ' + conv_mid_norm) 203 | convs.append(nn.SiLU()) 204 | self.conv = nn.Sequential(*convs) 205 | 206 | def forward(self, x: Tensor, att_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 207 | r""" 208 | 209 | Args: 210 | x: shape [batch, seq, feature] 211 | att_mask: the mask for attentions. 
shape [batch, seq, seq] 212 | 213 | Shape: 214 | out: shape [batch, seq, feature] 215 | attention: shape [batch, head, seq, seq] 216 | """ 217 | 218 | if self.norm_first: 219 | x_, attn = self._sa_block(self.norm1(x), att_mask) 220 | x = x + x_ 221 | x = x + self._ff_block(self.norm2(x)) 222 | else: 223 | x_, attn = self._sa_block(x, att_mask) 224 | x = self.norm1(x + x_) 225 | x = self.norm2(x + self._ff_block(x)) 226 | 227 | return x, attn 228 | 229 | # self-attention block 230 | def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tuple[Tensor, Tensor]: 231 | x, attn = self.self_attn(x, attn_mask=attn_mask) 232 | return self.dropout1(x), attn 233 | 234 | # feed forward block 235 | def _ff_block(self, x: Tensor) -> Tensor: 236 | x = self.linear2(self.dropout(self.conv(self.activation(self.linear1(x)).transpose(-1, -2)).transpose(-1, -2))) 237 | return self.dropout2(x) 238 | 239 | 240 | class NBC(nn.Module): 241 | 242 | def __init__( 243 | self, 244 | dim_input: int = 16, # 2*8 245 | dim_output: int = 4, # 2*2 246 | n_layers: int = 4, 247 | encoder_kernel_size: int = 4, 248 | n_heads: int = 8, 249 | activation: Optional[str] = "", 250 | hidden_size: int = 192, 251 | norm_first: bool = True, 252 | ffn_size: int = 384, 253 | inner_conv_kernel_size: int = 3, 254 | inner_conv_groups: int = 8, 255 | inner_conv_bias: bool = True, 256 | inner_conv_layers: int = 3, 257 | inner_conv_mid_norm: str = "GN", 258 | ): 259 | super().__init__() 260 | # encoder 261 | self.encoder = nn.Conv1d(in_channels=dim_input, out_channels=hidden_size, kernel_size=encoder_kernel_size, stride=1) 262 | # self-attention net 263 | self.sa_layers = nn.ModuleList() 264 | for l in range(n_layers): 265 | self.sa_layers.append( 266 | NBCBlock( 267 | dim_model=hidden_size, 268 | num_head=n_heads, 269 | norm_first=norm_first, 270 | dim_ffn=ffn_size, 271 | n_conv_groups=inner_conv_groups, 272 | conv_kernel_size=inner_conv_kernel_size, 273 | conv_bias=inner_conv_bias, 274 | n_conv_layers=inner_conv_layers, 275 | conv_mid_norm=inner_conv_mid_norm, 276 | )) 277 | 278 | # decoder 279 | assert activation == '', 'not implemented' 280 | self.decoder = nn.ConvTranspose1d(in_channels=hidden_size, out_channels=dim_output, kernel_size=encoder_kernel_size, stride=1) 281 | 282 | def forward(self, x: Tensor) -> Tensor: 283 | # x: [Batch, NumFreqs, Time, Feature] 284 | B, F, T, H = x.shape 285 | x = x.reshape(B * F, T, H) 286 | x = self.encoder(x.permute(0, 2, 1)).permute(0, 2, 1) 287 | attns = [] 288 | for m in self.sa_layers: 289 | x, attn = m(x) 290 | attns.append(attn) 291 | y = self.decoder(x.permute(0, 2, 1)).permute(0, 2, 1) 292 | y = y.reshape(B, F, T, -1) 293 | return y.contiguous() # , attns 294 | 295 | 296 | if __name__ == '__main__': 297 | Batch, Freq, Time, Chn, Spk = 1, 257, 100, 8, 2 298 | x = torch.randn((Batch, Freq, Time, Chn * 2)) 299 | m = NBC(dim_input=Chn * 2, dim_output=Spk * 2, n_layers=4) 300 | y = m(x) 301 | print(y.shape) 302 | -------------------------------------------------------------------------------- /models/arch/NBC2.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch import Tensor 7 | from torch.nn import Module, MultiheadAttention 8 | from torch.nn.parameter import Parameter 9 | 10 | 11 | class LayerNorm(nn.LayerNorm): 12 | 13 | def __init__(self, transpose: bool, **kwargs) -> None: 14 | super().__init__(**kwargs) 15 
| self.transpose = transpose 16 | 17 | def forward(self, input: Tensor) -> Tensor: 18 | if self.transpose: 19 | input = input.transpose(-1, -2) # [B, H, T] -> [B, T, H] 20 | o = super().forward(input) 21 | if self.transpose: 22 | o = o.transpose(-1, -2) 23 | return o 24 | 25 | 26 | class BatchNorm1d(nn.Module): 27 | 28 | def __init__(self, transpose: bool, **kwargs) -> None: 29 | super().__init__() 30 | self.transpose = transpose 31 | self.bn = nn.BatchNorm1d(**kwargs) 32 | 33 | def forward(self, input: Tensor) -> Tensor: 34 | if self.transpose == False: 35 | input = input.transpose(-1, -2) # [B, T, H] -> [B, H, T] 36 | o = self.bn.forward(input) # accepts [B, H, T] 37 | if self.transpose == False: 38 | o = o.transpose(-1, -2) 39 | return o 40 | 41 | 42 | class GroupNorm(nn.GroupNorm): 43 | 44 | def __init__(self, transpose: bool, **kwargs) -> None: 45 | super().__init__(**kwargs) 46 | self.transpose = transpose 47 | 48 | def forward(self, input: Tensor) -> Tensor: 49 | if self.transpose == False: 50 | input = input.transpose(-1, -2) # [B, T, H] -> [B, H, T] 51 | o = super().forward(input) # accepts [B, H, T] 52 | if self.transpose == False: 53 | o = o.transpose(-1, -2) 54 | return o 55 | 56 | 57 | class GroupBatchNorm(Module): 58 | """Applies Group Batch Normalization over a group of inputs 59 | 60 | This layer uses statistics computed from input data in both training and 61 | evaluation modes. 62 | """ 63 | 64 | dim_hidden: int 65 | group_size: int 66 | eps: float 67 | affine: bool 68 | transpose: bool 69 | share_along_sequence_dim: bool 70 | 71 | def __init__( 72 | self, 73 | dim_hidden: int, 74 | group_size: int, 75 | share_along_sequence_dim: bool = False, 76 | transpose: bool = False, 77 | affine: bool = True, 78 | eps: float = 1e-5, 79 | ) -> None: 80 | """ 81 | Args: 82 | dim_hidden (int): hidden dimension 83 | group_size (int): the size of group 84 | share_along_sequence_dim (bool): share statistics along the sequence dimension. Defaults to False. 85 | transpose (bool): whether the shape of input is [B, T, H] or [B, H, T]. Defaults to False, i.e. [B, T, H]. 86 | affine (bool): affine transformation. Defaults to True. 87 | eps (float): Defaults to 1e-5. 88 | """ 89 | super(GroupBatchNorm, self).__init__() 90 | 91 | self.dim_hidden = dim_hidden 92 | self.group_size = group_size 93 | self.eps = eps 94 | self.affine = affine 95 | self.transpose = transpose 96 | self.share_along_sequence_dim = share_along_sequence_dim 97 | if self.affine: 98 | if transpose: 99 | self.weight = Parameter(torch.empty([dim_hidden, 1])) 100 | self.bias = Parameter(torch.empty([dim_hidden, 1])) 101 | else: 102 | self.weight = Parameter(torch.empty([dim_hidden])) 103 | self.bias = Parameter(torch.empty([dim_hidden])) 104 | self.reset_parameters() 105 | 106 | def reset_parameters(self) -> None: 107 | if self.affine: 108 | init.ones_(self.weight) 109 | init.zeros_(self.bias) 110 | 111 | def forward(self, input: Tensor) -> Tensor: 112 | """ 113 | Args: 114 | input: shape [B, T, H] if transpose=False, else shape [B, H, T] , where B = num of groups * group size. 
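Mean and variance are computed per group over the group-size and hidden dimensions (and additionally over the sequence dimension when share_along_sequence_dim=True), so all items in one group share the same normalization statistics.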
115 | """ 116 | assert (input.shape[0] // self.group_size) * self.group_size, f'batch size {input.shape[0]} is not divisible by group size {self.group_size}' 117 | if self.transpose == False: 118 | B, T, H = input.shape 119 | input = input.reshape(B // self.group_size, self.group_size, T, H) 120 | 121 | if self.share_along_sequence_dim: 122 | var, mean = torch.var_mean(input, dim=(1, 2, 3), unbiased=False, keepdim=True) 123 | else: 124 | var, mean = torch.var_mean(input, dim=(1, 3), unbiased=False, keepdim=True) 125 | 126 | output = (input - mean) / torch.sqrt(var + self.eps) 127 | if self.affine: 128 | output = output * self.weight + self.bias 129 | output = output.reshape(B, T, H) 130 | else: 131 | B, H, T = input.shape 132 | input = input.reshape(B // self.group_size, self.group_size, H, T) 133 | 134 | if self.share_along_sequence_dim: 135 | var, mean = torch.var_mean(input, dim=(1, 2, 3), unbiased=False, keepdim=True) 136 | else: 137 | var, mean = torch.var_mean(input, dim=(1, 2), unbiased=False, keepdim=True) 138 | 139 | output = (input - mean) / torch.sqrt(var + self.eps) 140 | if self.affine: 141 | output = output * self.weight + self.bias 142 | 143 | output = output.reshape(B, H, T) 144 | 145 | return output 146 | 147 | def extra_repr(self) -> str: 148 | return '{dim_hidden}, {group_size}, share_along_sequence_dim={share_along_sequence_dim}, transpose={transpose}, eps={eps}, ' \ 149 | 'affine={affine}'.format(**self.__dict__) 150 | 151 | 152 | class NBC2Block(nn.Module): 153 | 154 | def __init__( 155 | self, 156 | dim_hidden: int, 157 | dim_ffn: int, 158 | n_heads: int, 159 | dropout: float = 0, 160 | conv_kernel_size: int = 3, 161 | n_conv_groups: int = 8, 162 | norms: Tuple[str, str, str] = ("LN", "GBN", "GBN"), 163 | group_batch_norm_kwargs: Dict[str, Any] = { 164 | 'group_size': 257, 165 | 'share_along_sequence_dim': False, 166 | }, 167 | ) -> None: 168 | super().__init__() 169 | # self-attention 170 | self.norm1 = self._new_norm(norms[0], dim_hidden, False, n_conv_groups, **group_batch_norm_kwargs) 171 | self.self_attn = MultiheadAttention(embed_dim=dim_hidden, num_heads=n_heads, batch_first=True) 172 | self.dropout1 = nn.Dropout(dropout) 173 | 174 | # Convolutional Feedforward 175 | self.norm2 = self._new_norm(norms[1], dim_hidden, False, n_conv_groups, **group_batch_norm_kwargs) 176 | self.linear1 = nn.Linear(dim_hidden, dim_ffn) 177 | self.conv = nn.Sequential( 178 | nn.SiLU(), 179 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=conv_kernel_size, padding='same', groups=n_conv_groups, bias=True), 180 | nn.SiLU(), 181 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=conv_kernel_size, padding='same', groups=n_conv_groups, bias=True), 182 | self._new_norm(norms[2], dim_ffn, True, n_conv_groups, **group_batch_norm_kwargs), 183 | nn.SiLU(), 184 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=conv_kernel_size, padding='same', groups=n_conv_groups, bias=True), 185 | nn.SiLU(), 186 | nn.Dropout(dropout), 187 | ) 188 | self.linear2 = nn.Linear(dim_ffn, dim_hidden) 189 | self.dropout2 = nn.Dropout(dropout) 190 | 191 | nn.init.xavier_uniform_(self.linear1.weight) 192 | nn.init.xavier_uniform_(self.linear2.weight) 193 | nn.init.zeros_(self.linear1.bias) 194 | nn.init.zeros_(self.linear2.bias) 195 | 196 | def forward(self, x: Tensor, att_mask: Optional[Tensor] = None) -> Tensor: 197 | r""" 198 | 199 | Args: 200 | x: shape [batch, seq, feature] 201 | att_mask: the mask for attentions. 
shape [batch, seq, seq] 202 | 203 | Shape: 204 | out: shape [batch, seq, feature] 205 | attention: shape [batch, head, seq, seq] 206 | """ 207 | 208 | x_, attn = self._sa_block(self.norm1(x), att_mask) 209 | x = x + x_ 210 | x = x + self._ff_block(self.norm2(x)) 211 | 212 | return x, attn 213 | 214 | # self-attention block 215 | def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tuple[Tensor, Tensor]: 216 | if isinstance(self.self_attn, MultiheadAttention): 217 | x, attn = self.self_attn.forward(x, x, x, average_attn_weights=False, attn_mask=attn_mask) 218 | else: 219 | x, attn = self.self_attn(x, attn_mask=attn_mask) 220 | return self.dropout1(x), attn 221 | 222 | # conv feed forward block 223 | def _ff_block(self, x: Tensor) -> Tensor: 224 | x = self.linear2(self.conv(self.linear1(x).transpose(-1, -2)).transpose(-1, -2)) 225 | return self.dropout2(x) 226 | 227 | def _new_norm(self, norm_type: str, dim_hidden: int, transpose: bool, num_conv_groups: int, **freq_norm_kwargs): 228 | if norm_type == 'LN': 229 | norm = LayerNorm(normalized_shape=dim_hidden, transpose=transpose) 230 | elif norm_type == 'GBN': 231 | norm = GroupBatchNorm(dim_hidden=dim_hidden, transpose=transpose, **freq_norm_kwargs) 232 | elif norm_type == 'BN': 233 | norm = BatchNorm1d(num_features=dim_hidden, transpose=transpose) 234 | elif norm_type == 'GN': 235 | norm = GroupNorm(num_groups=num_conv_groups, num_channels=dim_hidden, transpose=transpose) 236 | else: 237 | raise Exception(norm_type) 238 | return norm 239 | 240 | 241 | class NBC2(nn.Module): 242 | 243 | def __init__( 244 | self, 245 | dim_input: int, 246 | dim_output: int, 247 | n_layers: int, 248 | encoder_kernel_size: int = 5, 249 | dim_hidden: int = 192, 250 | dim_ffn: int = 384, 251 | num_freqs: int = 257, 252 | block_kwargs: Dict[str, Any] = { 253 | 'n_heads': 2, 254 | 'dropout': 0, 255 | 'conv_kernel_size': 3, 256 | 'n_conv_groups': 8, 257 | 'norms': ("LN", "GBN", "GBN"), 258 | 'group_batch_norm_kwargs': { 259 | 'share_along_sequence_dim': False, 260 | }, 261 | }, 262 | ): 263 | super().__init__() 264 | block_kwargs['group_batch_norm_kwargs']['group_size'] = num_freqs 265 | 266 | # encoder 267 | self.encoder = nn.Conv1d(in_channels=dim_input, out_channels=dim_hidden, kernel_size=encoder_kernel_size, stride=1, padding="same") 268 | 269 | # self-attention net 270 | self.sa_layers = nn.ModuleList() 271 | for l in range(n_layers): 272 | self.sa_layers.append(NBC2Block(dim_hidden=dim_hidden, dim_ffn=dim_ffn, **block_kwargs)) 273 | 274 | # decoder 275 | self.decoder = nn.Linear(in_features=dim_hidden, out_features=dim_output) 276 | 277 | def forward(self, x: Tensor) -> Tensor: 278 | # x: [Batch, NumFreqs, Time, Feature] 279 | B, F, T, H = x.shape 280 | x = x.reshape(B * F, T, H) 281 | x = self.encoder(x.permute(0, 2, 1)).permute(0, 2, 1) 282 | # attns = [] 283 | for m in self.sa_layers: 284 | x, attn = m(x) 285 | del attn 286 | # attns.append(attn) 287 | y = self.decoder(x) 288 | y = y.reshape(B, F, T, -1) 289 | return y.contiguous() # , attns 290 | 291 | 292 | if __name__ == '__main__': 293 | x = torch.randn((5, 257, 100, 16)) 294 | NBC2_small = NBC2( 295 | dim_input=16, 296 | dim_output=4, 297 | n_layers=8, 298 | dim_hidden=96, 299 | dim_ffn=192, 300 | block_kwargs={ 301 | 'n_heads': 2, 302 | 'dropout': 0, 303 | 'conv_kernel_size': 3, 304 | 'n_conv_groups': 8, 305 | 'norms': ("LN", "GBN", "GBN"), 306 | 'group_batch_norm_kwargs': { 307 | 'group_size': 257, 308 | 'share_along_sequence_dim': False, 309 | }, 310 | }, 311 | ) 312 | y = NBC2_small(x) 
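# the narrow-band blocks process each frequency independently, so the output keeps the [batch, freq, time, dim_output] layout; for this input y.shape should be (5, 257, 100, 4)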
313 | print(NBC2_small) 314 | print(y.shape) 315 | NBC2_large = NBC2( 316 | dim_input=16, 317 | dim_output=4, 318 | n_layers=12, 319 | dim_hidden=192, 320 | dim_ffn=384, 321 | block_kwargs={ 322 | 'n_heads': 2, 323 | 'dropout': 0, 324 | 'conv_kernel_size': 3, 325 | 'n_conv_groups': 8, 326 | 'norms': ("LN", "GBN", "GBN"), 327 | 'group_batch_norm_kwargs': { 328 | 'group_size': 257, 329 | 'share_along_sequence_dim': False, 330 | }, 331 | }, 332 | ) 333 | y = NBC2_large(x) 334 | print(NBC2_large) 335 | print(y.shape) 336 | -------------------------------------------------------------------------------- /models/arch/NBSS.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from torchmetrics.functional.audio import permutation_invariant_training as pit 7 | from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio as si_sdr 8 | 9 | from models.arch.blstm2_fc1 import BLSTM2_FC1 10 | from models.arch.NBC import NBC 11 | from models.arch.NBC2 import NBC2 12 | 13 | 14 | def neg_si_sdr(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 15 | batch_size = target.shape[0] 16 | si_snr_val = si_sdr(preds=preds, target=target) 17 | return -torch.mean(si_snr_val.view(batch_size, -1), dim=1) 18 | 19 | 20 | class NBSS(nn.Module): 21 | """Multi-channel Narrow-band Deep Speech Separation with Full-band Permutation Invariant Training. 22 | 23 | A module version of NBSS which takes time domain signal as input, and outputs time domain signal. 24 | 25 | Arch could be NB-BLSTM or NBC 26 | """ 27 | 28 | def __init__( 29 | self, 30 | n_channel: int = 8, 31 | n_speaker: int = 2, 32 | n_fft: int = 512, 33 | n_overlap: int = 256, 34 | ref_channel: int = 0, 35 | arch: str = "NB_BLSTM", # could also be NBC, NBC2 36 | arch_kwargs: Dict[str, Any] = dict(), 37 | ): 38 | super().__init__() 39 | 40 | if arch == "NB_BLSTM": 41 | self.arch: nn.Module = BLSTM2_FC1(dim_input=n_channel * 2, dim_output=n_speaker * 2, **arch_kwargs) 42 | elif arch == "NBC": 43 | self.arch = NBC(dim_input=n_channel * 2, dim_output=n_speaker * 2, **arch_kwargs) 44 | elif arch == 'NBC2': 45 | self.arch = NBC2(dim_input=n_channel * 2, dim_output=n_speaker * 2, **arch_kwargs) 46 | else: 47 | raise Exception(f"Unkown arch={arch}") 48 | 49 | self.register_buffer('window', torch.hann_window(n_fft), False) # self.window, will be moved to self.device at training time 50 | self.n_fft = n_fft 51 | self.n_overlap = n_overlap 52 | self.ref_channel = ref_channel 53 | self.n_channel = n_channel 54 | self.n_speaker = n_speaker 55 | 56 | def forward(self, x: Tensor) -> Tensor: 57 | """forward 58 | 59 | Args: 60 | x: time domain signal, shape [Batch, Channel, Time] 61 | 62 | Returns: 63 | y: the predicted time domain signal, shape [Batch, Speaker, Time] 64 | """ 65 | 66 | # STFT 67 | B, C, T = x.shape 68 | x = x.reshape((B * C, T)) 69 | X = torch.stft(x, n_fft=self.n_fft, hop_length=self.n_overlap, window=self.window, win_length=self.n_fft, return_complex=True) 70 | X = X.reshape((B, C, X.shape[-2], X.shape[-1])) # (batch, channel, freq, time frame) 71 | X = X.permute(0, 2, 3, 1) # (batch, freq, time frame, channel) 72 | 73 | # normalization by using ref_channel 74 | F, TF = X.shape[1], X.shape[2] 75 | Xr = X[..., self.ref_channel].clone() # copy 76 | XrMM = torch.abs(Xr).mean(dim=2) # Xr_magnitude_mean: mean of the magnitude of the ref channel of X 77 | X[:, :, :, :] /= (XrMM.reshape(B, 
F, 1, 1) + 1e-8) 78 | 79 | # to real 80 | X = torch.view_as_real(X) # [B, F, T, C, 2] 81 | X = X.reshape(B, F, TF, C * 2) 82 | 83 | # network processing 84 | output = self.arch(X) 85 | 86 | # to complex 87 | output = output.reshape(B, F, TF, self.n_speaker, 2) 88 | output = torch.view_as_complex(output) # [B, F, TF, S] 89 | 90 | # inverse normalization 91 | Ys_hat = torch.empty(size=(B, self.n_speaker, F, TF), dtype=torch.complex64, device=output.device) 92 | XrMM = torch.unsqueeze(XrMM, dim=2).expand(-1, -1, TF) 93 | for spk in range(self.n_speaker): 94 | Ys_hat[:, spk, :, :] = output[:, :, :, spk] * XrMM[:, :, :] 95 | 96 | # iSTFT with frequency binding 97 | ys_hat = torch.istft(Ys_hat.reshape(B * self.n_speaker, F, TF), n_fft=self.n_fft, hop_length=self.n_overlap, window=self.window, win_length=self.n_fft, length=T) 98 | ys_hat = ys_hat.reshape(B, self.n_speaker, T) 99 | return ys_hat 100 | 101 | 102 | if __name__ == '__main__': 103 | x = torch.randn(size=(10, 8, 16000)) 104 | ys = torch.randn(size=(10, 2, 16000)) 105 | 106 | NBSS_with_NB_BLSTM = NBSS(n_channel=8, n_speaker=2, arch="NB_BLSTM") 107 | ys_hat = NBSS_with_NB_BLSTM(x) 108 | neg_sisdr_loss, best_perm = pit(preds=ys_hat, target=ys, metric_func=neg_si_sdr, eval_func='min') 109 | print(ys_hat.shape, neg_sisdr_loss.mean()) 110 | 111 | NBSS_with_NBC = NBSS(n_channel=8, n_speaker=2, arch="NBC") 112 | ys_hat = NBSS_with_NBC(x) 113 | neg_sisdr_loss, best_perm = pit(preds=ys_hat, target=ys, metric_func=neg_si_sdr, eval_func='min') 114 | print(ys_hat.shape, neg_sisdr_loss.mean()) 115 | 116 | NBSS_with_NBC_small = NBSS(n_channel=8, 117 | n_speaker=2, 118 | arch="NBC2", 119 | arch_kwargs={ 120 | "n_layers": 8, # 12 for large 121 | "dim_hidden": 96, # 192 for large 122 | "dim_ffn": 192, # 384 for large 123 | "block_kwargs": { 124 | 'n_heads': 2, 125 | 'dropout': 0, 126 | 'conv_kernel_size': 3, 127 | 'n_conv_groups': 8, 128 | 'norms': ("LN", "GBN", "GBN"), 129 | 'group_batch_norm_kwargs': { 130 | 'group_size': 257, # 129 for 8k Hz 131 | 'share_along_sequence_dim': False, 132 | }, 133 | } 134 | },) 135 | ys_hat = NBSS_with_NBC_small(x) 136 | neg_sisdr_loss, best_perm = pit(preds=ys_hat, target=ys, metric_func=neg_si_sdr, eval_func='min') 137 | print(ys_hat.shape, neg_sisdr_loss.mean()) 138 | -------------------------------------------------------------------------------- /models/arch/SpatialNet.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | from models.arch.base.norm import * 6 | from models.arch.base.non_linear import * 7 | from models.arch.base.linear_group import LinearGroup 8 | from torch import Tensor 9 | from torch.nn import MultiheadAttention 10 | 11 | 12 | class SpatialNetLayer(nn.Module): 13 | 14 | def __init__( 15 | self, 16 | dim_hidden: int, 17 | dim_ffn: int, 18 | dim_squeeze: int, 19 | num_freqs: int, 20 | num_heads: int, 21 | dropout: Tuple[float, float, float] = (0, 0, 0), 22 | kernel_size: Tuple[int, int] = (5, 3), 23 | conv_groups: Tuple[int, int] = (8, 8), 24 | norms: List[str] = ("LN", "LN", "GN", "LN", "LN", "LN"), 25 | padding: str = 'zeros', 26 | full: nn.Module = None, 27 | ) -> None: 28 | super().__init__() 29 | f_conv_groups = conv_groups[0] 30 | t_conv_groups = conv_groups[1] 31 | f_kernel_size = kernel_size[0] 32 | t_kernel_size = kernel_size[1] 33 | 34 | # cross-band block 35 | # frequency-convolutional module 36 | self.fconv1 = nn.ModuleList([ 37 | new_norm(norms[3], dim_hidden, seq_last=True, 
group_size=None, num_groups=f_conv_groups), 38 | nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding), 39 | nn.PReLU(dim_hidden), 40 | ]) 41 | # full-band linear module 42 | self.norm_full = new_norm(norms[5], dim_hidden, seq_last=False, group_size=None, num_groups=f_conv_groups) 43 | self.full_share = False if full == None else True 44 | self.squeeze = nn.Sequential(nn.Conv1d(in_channels=dim_hidden, out_channels=dim_squeeze, kernel_size=1), nn.SiLU()) 45 | self.dropout_full = nn.Dropout2d(dropout[2]) if dropout[2] > 0 else None 46 | self.full = LinearGroup(num_freqs, num_freqs, num_groups=dim_squeeze) if full == None else full 47 | self.unsqueeze = nn.Sequential(nn.Conv1d(in_channels=dim_squeeze, out_channels=dim_hidden, kernel_size=1), nn.SiLU()) 48 | # frequency-convolutional module 49 | self.fconv2 = nn.ModuleList([ 50 | new_norm(norms[4], dim_hidden, seq_last=True, group_size=None, num_groups=f_conv_groups), 51 | nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding), 52 | nn.PReLU(dim_hidden), 53 | ]) 54 | 55 | # narrow-band block 56 | # MHSA module 57 | self.norm_mhsa = new_norm(norms[0], dim_hidden, seq_last=False, group_size=None, num_groups=t_conv_groups) 58 | self.mhsa = MultiheadAttention(embed_dim=dim_hidden, num_heads=num_heads, batch_first=True) 59 | self.dropout_mhsa = nn.Dropout(dropout[0]) 60 | # T-ConvFFN module 61 | self.tconvffn = nn.ModuleList([ 62 | new_norm(norms[1], dim_hidden, seq_last=True, group_size=None, num_groups=t_conv_groups), 63 | nn.Conv1d(in_channels=dim_hidden, out_channels=dim_ffn, kernel_size=1), 64 | nn.SiLU(), 65 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=t_kernel_size, padding='same', groups=t_conv_groups), 66 | nn.SiLU(), 67 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=t_kernel_size, padding='same', groups=t_conv_groups), 68 | new_norm(norms[2], dim_ffn, seq_last=True, group_size=None, num_groups=t_conv_groups), 69 | nn.SiLU(), 70 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=t_kernel_size, padding='same', groups=t_conv_groups), 71 | nn.SiLU(), 72 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_hidden, kernel_size=1), 73 | ]) 74 | self.dropout_tconvffn = nn.Dropout(dropout[1]) 75 | 76 | def forward(self, x: Tensor, att_mask: Optional[Tensor] = None) -> Tensor: 77 | r""" 78 | Args: 79 | x: shape [B, F, T, H] 80 | att_mask: the mask for attention along T. 
shape [B, T, T] 81 | 82 | Shape: 83 | out: shape [B, F, T, H] 84 | """ 85 | x = x + self._fconv(self.fconv1, x) 86 | x = x + self._full(x) 87 | x = x + self._fconv(self.fconv2, x) 88 | x_, attn = self._tsa(x, att_mask) 89 | x = x + x_ 90 | x = x + self._tconvffn(x) 91 | return x, attn 92 | 93 | def _tsa(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tuple[Tensor, Tensor]: 94 | B, F, T, H = x.shape 95 | x = self.norm_mhsa(x) 96 | x = x.reshape(B * F, T, H) 97 | need_weights = False if hasattr(self, "need_weights") else self.need_weights 98 | x, attn = self.mhsa.forward(x, x, x, need_weights=need_weights, average_attn_weights=False, attn_mask=attn_mask) 99 | x = x.reshape(B, F, T, H) 100 | return self.dropout_mhsa(x), attn 101 | 102 | def _tconvffn(self, x: Tensor) -> Tensor: 103 | B, F, T, H0 = x.shape 104 | # T-Conv 105 | x = x.transpose(-1, -2) # [B,F,H,T] 106 | x = x.reshape(B * F, H0, T) 107 | for m in self.tconvffn: 108 | if type(m) == GroupBatchNorm: 109 | x = m(x, group_size=F) 110 | else: 111 | x = m(x) 112 | x = x.reshape(B, F, H0, T) 113 | x = x.transpose(-1, -2) # [B,F,T,H] 114 | return self.dropout_tconvffn(x) 115 | 116 | def _fconv(self, ml: nn.ModuleList, x: Tensor) -> Tensor: 117 | B, F, T, H = x.shape 118 | x = x.permute(0, 2, 3, 1) # [B,T,H,F] 119 | x = x.reshape(B * T, H, F) 120 | for m in ml: 121 | if type(m) == GroupBatchNorm: 122 | x = m(x, group_size=T) 123 | else: 124 | x = m(x) 125 | x = x.reshape(B, T, H, F) 126 | x = x.permute(0, 3, 1, 2) # [B,F,T,H] 127 | return x 128 | 129 | def _full(self, x: Tensor) -> Tensor: 130 | B, F, T, H = x.shape 131 | x = self.norm_full(x) 132 | x = x.permute(0, 2, 3, 1) # [B,T,H,F] 133 | x = x.reshape(B * T, H, F) 134 | x = self.squeeze(x) # [B*T,H',F] 135 | if self.dropout_full: 136 | x = x.reshape(B, T, -1, F) 137 | x = x.transpose(1, 3) # [B,F,H',T] 138 | x = self.dropout_full(x) # dropout some frequencies in one utterance 139 | x = x.transpose(1, 3) # [B,T,H',F] 140 | x = x.reshape(B * T, -1, F) 141 | 142 | x = self.full(x) # [B*T,H',F] 143 | x = self.unsqueeze(x) # [B*T,H,F] 144 | x = x.reshape(B, T, H, F) 145 | x = x.permute(0, 3, 1, 2) # [B,F,T,H] 146 | return x 147 | 148 | def extra_repr(self) -> str: 149 | return f"full_share={self.full_share}" 150 | 151 | 152 | class SpatialNet(nn.Module): 153 | 154 | def __init__( 155 | self, 156 | dim_input: int, # the input dim for each time-frequency point 157 | dim_output: int, # the output dim for each time-frequency point 158 | dim_squeeze: int, 159 | num_layers: int, 160 | num_freqs: int, 161 | encoder_kernel_size: int = 5, 162 | dim_hidden: int = 192, 163 | dim_ffn: int = 384, 164 | num_heads: int = 2, 165 | dropout: Tuple[float, float, float] = (0, 0, 0), 166 | kernel_size: Tuple[int, int] = (5, 3), 167 | conv_groups: Tuple[int, int] = (8, 8), 168 | norms: List[str] = ("LN", "LN", "GN", "LN", "LN", "LN"), 169 | padding: str = 'zeros', 170 | full_share: int = 0, # share from layer 0 171 | ): 172 | super().__init__() 173 | 174 | # encoder 175 | self.encoder = nn.Conv1d(in_channels=dim_input, out_channels=dim_hidden, kernel_size=encoder_kernel_size, stride=1, padding="same") 176 | 177 | # spatialnet layers 178 | full = None 179 | layers = [] 180 | for l in range(num_layers): 181 | layer = SpatialNetLayer( 182 | dim_hidden=dim_hidden, 183 | dim_ffn=dim_ffn, 184 | dim_squeeze=dim_squeeze, 185 | num_freqs=num_freqs, 186 | num_heads=num_heads, 187 | dropout=dropout, 188 | kernel_size=kernel_size, 189 | conv_groups=conv_groups, 190 | norms=norms, 191 | padding=padding, 192 | full=full if 
l > full_share else None, 193 | ) 194 | if hasattr(layer, 'full'): 195 | full = layer.full 196 | layers.append(layer) 197 | self.layers = nn.ModuleList(layers) 198 | 199 | # decoder 200 | self.decoder = nn.Linear(in_features=dim_hidden, out_features=dim_output) 201 | 202 | def forward(self, x: Tensor, return_attn_score: bool = False) -> Tensor: 203 | # x: [Batch, Freq, Time, Feature] 204 | B, F, T, H0 = x.shape 205 | x = self.encoder(x.reshape(B * F, T, H0).permute(0, 2, 1)).permute(0, 2, 1) 206 | H = x.shape[2] 207 | 208 | attns = [] if return_attn_score else None 209 | x = x.reshape(B, F, T, H) 210 | for m in self.layers: 211 | setattr(m, "need_weights", return_attn_score) 212 | x, attn = m(x) 213 | if return_attn_score: 214 | attns.append(attn) 215 | 216 | y = self.decoder(x) 217 | if return_attn_score: 218 | return y.contiguous(), attns 219 | else: 220 | return y.contiguous() 221 | 222 | 223 | if __name__ == '__main__': 224 | # CUDA_VISIBLE_DEVICES=7, python -m models.arch.SpatialNet 225 | x = torch.randn((1, 129, 251, 12)) #.cuda() # 251 = 4 second; 129 = 8 kHz; 257 = 16 kHz 226 | spatialnet_small = SpatialNet( 227 | dim_input=12, 228 | dim_output=4, 229 | num_layers=8, 230 | dim_hidden=96, 231 | dim_ffn=192, 232 | kernel_size=(5, 3), 233 | conv_groups=(8, 8), 234 | norms=("LN", "LN", "GN", "LN", "LN", "LN"), 235 | dim_squeeze=8, 236 | num_freqs=129, 237 | full_share=0, 238 | ) #.cuda() 239 | # from packaging.version import Version 240 | # if Version(torch.__version__) >= Version('2.0.0'): 241 | # SSFNet_small = torch.compile(SSFNet_small) 242 | # torch.cuda.synchronize(7) 243 | import time 244 | ts = time.time() 245 | y = spatialnet_small(x) 246 | # torch.cuda.synchronize(7) 247 | te = time.time() 248 | print(spatialnet_small) 249 | print(y.shape) 250 | print(te - ts) 251 | 252 | spatialnet_small = spatialnet_small.to('meta') 253 | x = x.to('meta') 254 | from torch.utils.flop_counter import FlopCounterMode # requires torch>=2.1.0 255 | with FlopCounterMode(spatialnet_small, display=False) as fcm: 256 | y = spatialnet_small(x) 257 | flops_forward_eval = fcm.get_total_flops() 258 | res = y.sum() 259 | res.backward() 260 | flops_backward_eval = fcm.get_total_flops() - flops_forward_eval 261 | 262 | params_eval = sum(param.numel() for param in spatialnet_small.parameters()) 263 | print(f"flops_forward={flops_forward_eval/1e9:.2f}G, flops_back={flops_backward_eval/1e9:.2f}G, params={params_eval/1e6:.2f} M") 264 | -------------------------------------------------------------------------------- /models/arch/base/linear_group.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import * 4 | import math 5 | 6 | 7 | class LinearGroup(nn.Module): 8 | 9 | def __init__(self, in_features: int, out_features: int, num_groups: int, bias: bool = True) -> None: 10 | super(LinearGroup, self).__init__() 11 | self.in_features = in_features 12 | self.out_features = out_features 13 | self.num_groups = num_groups 14 | self.weight = Parameter(torch.empty((num_groups, out_features, in_features))) 15 | if bias: 16 | self.bias = Parameter(torch.empty(num_groups, out_features)) 17 | else: 18 | self.register_parameter('bias', None) 19 | self.reset_parameters() 20 | 21 | def reset_parameters(self) -> None: 22 | # same as linear 23 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 24 | if self.bias is not None: 25 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 26 | bound = 1 / math.sqrt(fan_in) if 
fan_in > 0 else 0 27 | init.uniform_(self.bias, -bound, bound) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | """shape [..., group, feature]""" 31 | x = torch.einsum("...gh,gkh->...gk", x, self.weight) 32 | if self.bias is not None: 33 | x = x + self.bias 34 | return x 35 | 36 | def extra_repr(self) -> str: 37 | return f"{self.in_features}, {self.out_features}, num_groups={self.num_groups}, bias={True if self.bias is not None else False}" 38 | 39 | 40 | class Conv1dGroup(nn.Module): 41 | 42 | def __init__(self, in_features: int, out_features: int, num_groups: int, kernel_size: int, bias: bool = True) -> None: 43 | super(Conv1dGroup, self).__init__() 44 | self.in_features = in_features 45 | self.out_features = out_features 46 | self.num_groups = num_groups 47 | self.kernel_size = kernel_size 48 | 49 | self.weight = Parameter(torch.empty((num_groups, out_features, in_features, kernel_size))) 50 | if bias: 51 | self.bias = Parameter(torch.empty(num_groups, out_features)) 52 | else: 53 | self.register_parameter('bias', None) 54 | self.reset_parameters() 55 | 56 | def reset_parameters(self) -> None: 57 | # same as linear 58 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 59 | if self.bias is not None: 60 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 61 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 62 | init.uniform_(self.bias, -bound, bound) 63 | 64 | def forward(self, x: Tensor) -> Tensor: 65 | """shape [batch, time, group, feature]""" 66 | (B, T, G, F), K = x.shape, self.kernel_size 67 | x = x.permute(0, 2, 3, 1).reshape(B * G * F, 1, 1, T) # [B*G*F,1,1,T] 68 | x = torch.nn.functional.unfold(x, kernel_size=(1, K), padding=(0, K // 2)) # [B*G*F,K,T] 69 | x = x.reshape(B, G, F, K, T) 70 | x = torch.einsum("bgfkt,gofk->btgo", x, self.weight) 71 | if self.bias is not None: 72 | x = x + self.bias 73 | return x 74 | 75 | def extra_repr(self) -> str: 76 | return f"{self.in_features}, {self.out_features}, num_groups={self.num_groups}, kernel_size={self.kernel_size}, bias={True if self.bias is not None else False}" 77 | -------------------------------------------------------------------------------- /models/arch/base/non_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | 5 | class PReLU(nn.PReLU): 6 | 7 | def __init__(self, num_parameters: int = 1, init: float = 0.25, dim: int = 1, device=None, dtype=None) -> None: 8 | super().__init__(num_parameters, init, device, dtype) 9 | self.dim = dim 10 | 11 | def forward(self, input: Tensor) -> Tensor: 12 | if self.dim == 1: 13 | # [B, Chn, Feature] 14 | return super().forward(input) 15 | else: 16 | return super().forward(input.transpose(self.dim, 1)).transpose(self.dim, 1) 17 | 18 | 19 | def new_non_linear(non_linear_type: str, dim_hidden: int, seq_last: bool) -> nn.Module: 20 | if non_linear_type.lower() == 'prelu': 21 | return PReLU(num_parameters=dim_hidden, dim=1 if seq_last == True else -1) 22 | elif non_linear_type.lower() == 'silu': 23 | return nn.SiLU() 24 | elif non_linear_type.lower() == 'sigmoid': 25 | return nn.Sigmoid() 26 | elif non_linear_type.lower() == 'relu': 27 | return nn.ReLU() 28 | elif non_linear_type.lower() == 'leakyrelu': 29 | return nn.LeakyReLU() 30 | elif non_linear_type.lower() == 'elu': 31 | return nn.ELU() 32 | else: 33 | raise Exception(non_linear_type) 34 | 35 | 36 | if __name__ == '__main__': 37 | x = torch.rand(size=(12, 10, 100)) 38 | prelu = new_non_linear('PReLU', 10, True) 39 | y = 
prelu(x) 40 | print(y.shape) 41 | -------------------------------------------------------------------------------- /models/arch/base/norm.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torch import Tensor 7 | from torch.nn import Module 8 | from torch.nn.parameter import Parameter 9 | 10 | 11 | class LayerNorm(nn.LayerNorm): 12 | 13 | def __init__(self, seq_last: bool, **kwargs) -> None: 14 | """ 15 | Arg s: 16 | seq_last (bool): whether the sequence dim is the last dim 17 | """ 18 | super().__init__(**kwargs) 19 | self.seq_last = seq_last 20 | 21 | def forward(self, input: Tensor) -> Tensor: 22 | if self.seq_last: 23 | input = input.transpose(-1, 1) # [B, H, Seq] -> [B, Seq, H], or [B,H,w,h] -> [B,h,w,H] 24 | o = super().forward(input) 25 | if self.seq_last: 26 | o = o.transpose(-1, 1) 27 | return o 28 | 29 | 30 | class GlobalLayerNorm(nn.Module): 31 | """gLN in convtasnet""" 32 | 33 | def __init__(self, dim_hidden: int, seq_last: bool, eps: float = 1e-5) -> None: 34 | super().__init__() 35 | self.dim_hidden = dim_hidden 36 | self.seq_last = seq_last 37 | self.eps = eps 38 | 39 | if seq_last: 40 | self.weight = Parameter(torch.empty([dim_hidden, 1])) 41 | self.bias = Parameter(torch.empty([dim_hidden, 1])) 42 | else: 43 | self.weight = Parameter(torch.empty([dim_hidden])) 44 | self.bias = Parameter(torch.empty([dim_hidden])) 45 | init.ones_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input: Tensor) -> Tensor: 49 | """ 50 | Args: 51 | input (Tensor): shape [B, Seq, H] or [B, H, Seq] 52 | """ 53 | var, mean = torch.var_mean(input, dim=(1, 2), unbiased=False, keepdim=True) 54 | 55 | output = (input - mean) / torch.sqrt(var + self.eps) 56 | output = output * self.weight + self.bias 57 | return output 58 | 59 | def extra_repr(self) -> str: 60 | return '{dim_hidden}, seq_last={seq_last}, eps={eps}'.format(**self.__dict__) 61 | 62 | 63 | class BatchNorm1d(nn.Module): 64 | 65 | def __init__(self, seq_last: bool, **kwargs) -> None: 66 | super().__init__() 67 | self.seq_last = seq_last 68 | self.bn = nn.BatchNorm1d(**kwargs) 69 | 70 | def forward(self, input: Tensor) -> Tensor: 71 | if not self.seq_last: 72 | input = input.transpose(-1, -2) # [B, Seq, H] -> [B, H, Seq] 73 | o = self.bn.forward(input) # accepts [B, H, Seq] 74 | if not self.seq_last: 75 | o = o.transpose(-1, -2) 76 | return o 77 | 78 | 79 | class GroupNorm(nn.GroupNorm): 80 | 81 | def __init__(self, seq_last: bool, **kwargs) -> None: 82 | super().__init__(**kwargs) 83 | self.seq_last = seq_last 84 | 85 | def forward(self, input: Tensor) -> Tensor: 86 | if self.seq_last == False: 87 | input = input.transpose(-1, 1) # [B, ..., H] -> [B, H, ...] 88 | o = super().forward(input) # accepts [B, H, ...] 89 | if self.seq_last == False: 90 | o = o.transpose(-1, 1) 91 | return o 92 | 93 | 94 | class GroupBatchNorm(Module): 95 | """Applies Group Batch Normalization over a group of inputs 96 | 97 | This layer uses statistics computed from input data in both training and 98 | evaluation modes. 99 | 100 | see: `Changsheng Quan, Xiaofei Li. NBC2: Multichannel Speech Separation with Revised Narrow-band Conformer. 
arXiv:2212.02076.` 101 | 102 | """ 103 | 104 | dim_hidden: int 105 | group_size: int 106 | eps: float 107 | affine: bool 108 | seq_last: bool 109 | share_along_sequence_dim: bool 110 | 111 | def __init__( 112 | self, 113 | dim_hidden: int, 114 | group_size: Optional[int], 115 | share_along_sequence_dim: bool = False, 116 | seq_last: bool = False, 117 | affine: bool = True, 118 | eps: float = 1e-5, 119 | dims_norm: List[int] = None, 120 | dim_affine: int = None, 121 | ) -> None: 122 | """ 123 | Args: 124 | dim_hidden (int): hidden dimension 125 | group_size (int): the size of group, optional 126 | share_along_sequence_dim (bool): share statistics along the sequence dimension. Defaults to False. 127 | seq_last (bool): whether the shape of input is [B, Seq, H] or [B, H, Seq]. Defaults to False, i.e. [B, Seq, H]. 128 | affine (bool): affine transformation. Defaults to True. 129 | eps (float): Defaults to 1e-5. 130 | dims_norm: the dims for normalization 131 | dim_affine: the dims for affine transformation 132 | """ 133 | super(GroupBatchNorm, self).__init__() 134 | 135 | self.dim_hidden = dim_hidden 136 | self.group_size = group_size 137 | self.eps = eps 138 | self.affine = affine 139 | self.seq_last = seq_last 140 | self.share_along_sequence_dim = share_along_sequence_dim 141 | if self.affine: 142 | if seq_last: 143 | weight = torch.empty([dim_hidden, 1]) 144 | bias = torch.empty([dim_hidden, 1]) 145 | else: 146 | self.weight = torch.empty([dim_hidden]) 147 | self.bias = torch.empty([dim_hidden]) 148 | 149 | assert (dims_norm is not None and dim_affine is not None) or (dims_norm is not None), (dims_norm, dim_affine, 'should be none at the time') 150 | self.dims_norm, self.dim_affine = dims_norm, dim_affine 151 | if dim_affine is not None: 152 | assert dim_affine < 0, dim_affine 153 | weight = weight.squeeze() 154 | bias = bias.squeeze() 155 | while dim_affine < -1: 156 | weight = weight.unsqueeze(-1) 157 | bias = bias.unsqueeze(-1) 158 | dim_affine += 1 159 | 160 | self.weight = Parameter(weight) 161 | self.bias = Parameter(bias) 162 | self.reset_parameters() 163 | 164 | def reset_parameters(self) -> None: 165 | if self.affine: 166 | init.ones_(self.weight) 167 | init.zeros_(self.bias) 168 | 169 | def forward(self, x: Tensor, group_size: int = None) -> Tensor: 170 | """ 171 | Args: 172 | x: shape [B, Seq, H] if seq_last=False, else shape [B, H, Seq] , where B = num of groups * group size. 173 | group_size: the size of one group. 
if not given anywhere, the input must be 4-dim tensor with shape [B, group_size, Seq, H] or [B, group_size, H, Seq] 174 | """ 175 | if self.group_size != None: 176 | assert group_size == None or group_size == self.group_size, (group_size, self.group_size) 177 | group_size = self.group_size 178 | 179 | if group_size is not None: 180 | assert (x.shape[0] // group_size) * group_size, f'batch size {x.shape[0]} is not divisible by group size {group_size}' 181 | 182 | original_shape = x.shape 183 | if self.dims_norm is not None: 184 | var, mean = torch.var_mean(x, dim=self.dims_norm, unbiased=False, keepdim=True) 185 | output = (x - mean) / torch.sqrt(var + self.eps) 186 | if self.affine: 187 | output = output * self.weight + self.bias 188 | elif self.seq_last == False: 189 | if x.ndim == 4: 190 | assert group_size is None or group_size == x.shape[1], (group_size, x.shape) 191 | B, group_size, Seq, H = x.shape 192 | else: 193 | B, Seq, H = x.shape 194 | x = x.reshape(B // group_size, group_size, Seq, H) 195 | 196 | if self.share_along_sequence_dim: 197 | var, mean = torch.var_mean(x, dim=(1, 2, 3), unbiased=False, keepdim=True) 198 | else: 199 | var, mean = torch.var_mean(x, dim=(1, 3), unbiased=False, keepdim=True) 200 | 201 | output = (x - mean) / torch.sqrt(var + self.eps) 202 | if self.affine: 203 | output = output * self.weight + self.bias 204 | 205 | output = output.reshape(original_shape) 206 | else: 207 | if x.ndim == 4: 208 | assert group_size is None or group_size == x.shape[1], (group_size, x.shape) 209 | B, group_size, H, Seq = x.shape 210 | else: 211 | B, H, Seq = x.shape 212 | x = x.reshape(B // group_size, group_size, H, Seq) 213 | 214 | if self.share_along_sequence_dim: 215 | var, mean = torch.var_mean(x, dim=(1, 2, 3), unbiased=False, keepdim=True) 216 | else: 217 | var, mean = torch.var_mean(x, dim=(1, 2), unbiased=False, keepdim=True) 218 | 219 | output = (x - mean) / torch.sqrt(var + self.eps) 220 | if self.affine: 221 | output = output * self.weight + self.bias 222 | 223 | output = output.reshape(original_shape) 224 | 225 | return output 226 | 227 | def extra_repr(self) -> str: 228 | return '{dim_hidden}, {group_size}, share_along_sequence_dim={share_along_sequence_dim}, seq_last={seq_last}, eps={eps}, ' \ 229 | 'affine={affine}'.format(**self.__dict__) 230 | 231 | 232 | def new_norm(norm_type: str, dim_hidden: int, seq_last: bool, group_size: int = None, num_groups: int = None, dims_norm: List[int] = None, dim_affine: int = None) -> nn.Module: 233 | if norm_type.upper() == 'LN': 234 | norm = LayerNorm(normalized_shape=dim_hidden, seq_last=seq_last) 235 | elif norm_type.upper() == 'GBN': 236 | norm = GroupBatchNorm(dim_hidden=dim_hidden, seq_last=seq_last, group_size=group_size, share_along_sequence_dim=False, dims_norm=dims_norm, dim_affine=dim_affine) 237 | elif norm_type == 'GBNShare': 238 | norm = GroupBatchNorm(dim_hidden=dim_hidden, seq_last=seq_last, group_size=group_size, share_along_sequence_dim=True, dims_norm=dims_norm, dim_affine=dim_affine) 239 | elif norm_type.upper() == 'BN': 240 | norm = BatchNorm1d(num_features=dim_hidden, seq_last=seq_last) 241 | elif norm_type.upper() == 'GN': 242 | norm = GroupNorm(num_groups=num_groups, num_channels=dim_hidden, seq_last=seq_last) 243 | elif norm == 'gLN': 244 | norm = GlobalLayerNorm(dim_hidden, seq_last=seq_last) 245 | else: 246 | raise Exception(norm_type) 247 | return norm 248 | -------------------------------------------------------------------------------- /models/arch/blstm2_fc1.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | from typing import Optional, Tuple 4 | 5 | 6 | class BLSTM2_FC1(nn.Module): 7 | 8 | def __init__( 9 | self, 10 | dim_input: int, 11 | dim_output: int, 12 | activation: Optional[str] = "", 13 | hidden_size: Tuple[int, int] = (256, 128), 14 | n_repeat_last_lstm: int = 1, 15 | dropout: Optional[float] = None, 16 | ): 17 | """Two layers of BiLSTMs & one fully connected layer 18 | 19 | Args: 20 | input_size: the input size for the features of the first BiLSTM layer 21 | output_size: the output size for the features of the last BiLSTM layer 22 | hidden_size: the hidden size of each BiLSTM layer. Defaults to (256, 128). 23 | """ 24 | 25 | super().__init__() 26 | 27 | self.input_size = dim_input 28 | self.output_size = dim_output 29 | self.hidden_size = hidden_size 30 | self.activation = activation 31 | self.dropout = dropout 32 | 33 | self.blstm1 = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size[0], batch_first=True, bidirectional=True) # type:ignore 34 | self.blstm2 = nn.LSTM(input_size=self.hidden_size[0] * 2, hidden_size=self.hidden_size[1], batch_first=True, bidirectional=True, num_layers=n_repeat_last_lstm) # type:ignore 35 | if dropout is not None: 36 | self.dropout1 = nn.Dropout(p=dropout) 37 | self.dropout2 = nn.Dropout(p=dropout) 38 | 39 | self.linear = nn.Linear(self.hidden_size[1] * 2, self.output_size) # type:ignore 40 | if self.activation is not None and len(self.activation) > 0: # type:ignore 41 | self.activation_func = getattr(nn, self.activation)() # type:ignore 42 | else: 43 | self.activation_func = None 44 | 45 | def forward(self, x: Tensor) -> Tensor: 46 | """forward 47 | 48 | Args: 49 | x: shape [batch, num_freqs, seq, input_size] 50 | 51 | Returns: 52 | Tensor: shape [batch, num_freqs, seq, output_size] 53 | """ 54 | # x: [Batch, NumFreqs, Time, Feature] 55 | B, F, T, H = x.shape 56 | x = x.reshape(B * F, T, H) 57 | x, _ = self.blstm1(x) 58 | if self.dropout: 59 | x = self.dropout1(x) 60 | x, _ = self.blstm2(x) 61 | if self.dropout: 62 | x = self.dropout2(x) 63 | if self.activation_func is not None: 64 | y = self.activation_func(self.linear(x)) 65 | else: 66 | y = self.linear(x) 67 | 68 | y = y.reshape(B, F, T, -1) 69 | return y -------------------------------------------------------------------------------- /models/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/models/io/__init__.py -------------------------------------------------------------------------------- /models/io/cirm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | EPSILON = np.finfo(np.float32).eps 5 | 6 | 7 | def build_complex_ideal_ratio_mask(noisy: torch.Tensor, clean: torch.Tensor) -> torch.Tensor: 8 | """Build the complex ratio mask. 
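    (Explanatory note, derived from the computation below: writing the noisy STFT as X and the clean STFT as Y,
    the uncompressed mask is M = Y * conj(X) / |X|^2, i.e. mask_real = Re(conj(X) * Y) / |X|^2 and
    mask_imag = Im(conj(X) * Y) / |X|^2; the mask is then compressed to the range (-K, K) by `compress_cIRM`.)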
9 | 10 | Args: 11 | noisy: [..., F, T], noisy complex-valued stft coefficients 12 | clean: [..., F, T], clean complex-valued stft coefficients 13 | 14 | References: 15 | https://ieeexplore.ieee.org/document/7364200 16 | 17 | Returns: 18 | [..., F, T, 2] 19 | """ 20 | noisy_real, noisy_imag = noisy.real, noisy.imag 21 | clean_real, clean_imag = clean.real, clean.imag 22 | 23 | denominator = torch.square(noisy_real) + torch.square(noisy_imag) + EPSILON 24 | 25 | mask_real = (noisy_real * clean_real + noisy_imag * clean_imag) / denominator 26 | mask_imag = (noisy_real * clean_imag - noisy_imag * clean_real) / denominator 27 | 28 | complex_ratio_mask = torch.stack((mask_real, mask_imag), dim=-1) 29 | 30 | cirm = compress_cIRM(complex_ratio_mask, K=10, C=0.1) 31 | cirm = torch.view_as_complex(cirm) 32 | return cirm 33 | 34 | 35 | def compress_cIRM(mask, K=10, C=0.1): 36 | """Compress the value of cIRM from (-inf, +inf) to [-K ~ K]. 37 | 38 | References: 39 | https://ieeexplore.ieee.org/document/7364200 40 | """ 41 | if torch.is_tensor(mask): 42 | mask = -100 * (mask <= -100) + mask * (mask > -100) 43 | mask = K * (1 - torch.exp(-C * mask)) / (1 + torch.exp(-C * mask)) 44 | else: 45 | mask = -100 * (mask <= -100) + mask * (mask > -100) 46 | mask = K * (1 - np.exp(-C * mask)) / (1 + np.exp(-C * mask)) 47 | return mask 48 | 49 | 50 | def decompress_cIRM(mask, K=10, limit=9.9): 51 | """Decompress cIRM from [-K ~ K] to [-inf, +inf]. 52 | 53 | Args: 54 | mask: cIRM mask 55 | K: default 10 56 | limit: default 0.1 57 | 58 | References: 59 | https://ieeexplore.ieee.org/document/7364200 60 | """ 61 | mask = torch.view_as_real(mask) 62 | mask = (limit * (mask >= limit) - limit * (mask <= -limit) + mask * (torch.abs(mask) < limit)) 63 | mask = -K * torch.log((K - mask) / (K + mask)) 64 | return torch.view_as_complex(mask) 65 | 66 | 67 | def complex_mul(noisy_r, noisy_i, mask_r, mask_i): 68 | r = noisy_r * mask_r - noisy_i * mask_i 69 | i = noisy_r * mask_i + noisy_i * mask_r 70 | return r, i 71 | 72 | 73 | def complex_mul_v2(noisy, mask): 74 | return noisy * mask 75 | -------------------------------------------------------------------------------- /models/io/loss.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | import torch 3 | from torch import nn 4 | 5 | from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio as si_sdr 6 | from torchmetrics.functional.audio import signal_noise_ratio as snr 7 | from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio as sa_sdr 8 | from torchmetrics.functional.audio import permutation_invariant_training as pit 9 | from torchmetrics.functional.audio import pit_permutate as permutate 10 | from typing import * 11 | from models.io.cirm import build_complex_ideal_ratio_mask, decompress_cIRM 12 | from models.io.stft import STFT 13 | 14 | 15 | def neg_sa_sdr(preds: Tensor, target: Tensor, scale_invariant: bool = False) -> Tensor: 16 | batch_size = target.shape[0] 17 | sa_sdr_val = sa_sdr(preds=preds, target=target, scale_invariant=scale_invariant) 18 | return -torch.mean(sa_sdr_val.view(batch_size, -1), dim=1) 19 | 20 | 21 | def neg_si_sdr(preds: Tensor, target: Tensor) -> Tensor: 22 | """calculate neg_si_sdr loss for a batch 23 | 24 | Returns: 25 | loss: shape [batch], real 26 | """ 27 | batch_size = target.shape[0] 28 | si_sdr_val = si_sdr(preds=preds, target=target) 29 | return -torch.mean(si_sdr_val.view(batch_size, -1), dim=1) 30 | 31 | 32 | def 
neg_snr(preds: Tensor, target: Tensor) -> Tensor: 33 | """calculate neg_snr loss for a batch 34 | 35 | Returns: 36 | loss: shape [batch], real 37 | """ 38 | batch_size = target.shape[0] 39 | snr_val = snr(preds=preds, target=target) 40 | return -torch.mean(snr_val.view(batch_size, -1), dim=1) 41 | 42 | 43 | def _mse(preds: Tensor, target: Tensor) -> Tensor: 44 | """calculate mse loss for a batch 45 | 46 | Returns: 47 | loss: shape [batch], real 48 | """ 49 | batch_size = target.shape[0] 50 | diff = preds - target 51 | diff = diff.view(batch_size, -1) 52 | mse_val = torch.mean(diff**2, dim=1) 53 | return mse_val 54 | 55 | 56 | def cirm_mse(preds: Tensor, target: Tensor) -> Tensor: 57 | """calculate mse loss for a batch of cirms 58 | 59 | Returns: 60 | loss: shape [batch], real 61 | """ 62 | return _mse(preds=preds, target=target) 63 | 64 | 65 | def cc_mse(preds: Tensor, target: Tensor) -> Tensor: 66 | """calculate mse loss for a batch of STFT coefficients 67 | 68 | Returns: 69 | loss: shape [batch], real 70 | """ 71 | return _mse(preds=preds, target=target) 72 | 73 | 74 | class Loss(nn.Module): 75 | is_scale_invariant_loss: bool 76 | name: str 77 | mask: str 78 | 79 | def __init__(self, loss_func: Callable, pit: bool, loss_func_kwargs: Dict[str, Any] = dict()): 80 | super().__init__() 81 | 82 | self.loss_func = loss_func 83 | self.pit = pit 84 | self.loss_func_kwargs = loss_func_kwargs 85 | self.is_scale_invariant_loss = { 86 | neg_sa_sdr: True if 'scale_invariant' in loss_func_kwargs and loss_func_kwargs['scale_invariant'] == True else False, 87 | neg_si_sdr: True, 88 | neg_snr: False, 89 | cirm_mse: False, 90 | cc_mse: False, 91 | }[loss_func] 92 | self.name = loss_func.__name__ 93 | self.mask = 'cirm' if self.loss_func == cirm_mse else None 94 | 95 | def forward(self, yr_hat: Tensor, yr: Tensor, reorder: bool = None, reduce_batch: bool = True, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: 96 | if self.mask is not None: 97 | out, Xr, stft = kwargs['out'], kwargs['Xr'], kwargs['stft'] 98 | Yr, _ = stft.stft(yr) 99 | preds, target = out, self.to_mask(Yr=Yr, Xr=Xr) 100 | preds, target = torch.view_as_real(preds), torch.view_as_real(target) 101 | elif self.loss_func == cc_mse: 102 | out, XrMM, stft = kwargs['out'], kwargs['XrMM'], kwargs['stft'] 103 | Yr, _ = stft.stft(yr) 104 | Yr = Yr / XrMM 105 | preds, target = torch.view_as_real(out), torch.view_as_real(Yr) 106 | else: 107 | preds, target = yr_hat, yr 108 | 109 | perms = None 110 | if self.pit: 111 | losses, perms = pit(preds=preds, target=target, metric_func=self.loss_func, eval_func='min', mode="permutation-wise", **self.loss_func_kwargs) 112 | else: 113 | losses = self.loss_func(preds=preds, target=target, **self.loss_func_kwargs) 114 | 115 | if reorder and perms is not None: 116 | yr_hat = permutate(yr_hat, perm=perms) 117 | 118 | return losses.mean() if reduce_batch else losses, perms, yr_hat 119 | 120 | def to_CC(self, out: Tensor, Xr: Tensor, stft: STFT, XrMM: Tensor) -> Tensor: 121 | if self.loss_func == cirm_mse: 122 | cIRM = decompress_cIRM(mask=out) 123 | Yr = Xr * cIRM 124 | return Yr, {'out': out, 'Xr': Xr, 'stft': stft, 'XrMM': XrMM} 125 | else: 126 | return out, {'out': out, 'Xr': Xr, 'stft': stft, 'XrMM': XrMM} 127 | 128 | def to_mask(self, Yr: Tensor, Xr: Tensor): 129 | if self.mask == 'cirm': 130 | return build_complex_ideal_ratio_mask(noisy=Xr, clean=Yr) 131 | else: 132 | raise Exception(f'not implemented for mask type {self.mask}') 133 | 134 | def extra_repr(self) -> str: 135 | kwargs = "" 136 | for k, v in 
self.loss_func_kwargs.items(): 137 | kwargs += f'{k}={v},' 138 | 139 | return f"loss_func={self.loss_func.__name__}({kwargs}), pit={self.pit}, mask={self.mask}" 140 | -------------------------------------------------------------------------------- /models/io/norm.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import Tensor 3 | from typing import * 4 | 5 | import torch 6 | 7 | 8 | def forgetting_normalization(XrMag: Tensor, sliding_window_len: int = 192) -> Tensor: 9 | # https://github.com/Audio-WestlakeU/FullSubNet/blob/e97448375cd1e883276ad583317b1828318910dc/audio_zen/model/base_model.py#L103C19-L103C19 10 | alpha = (sliding_window_len - 1) / (sliding_window_len + 1) 11 | mu = 0 12 | mu_list = [] 13 | B, _, F, T = XrMag.shape 14 | XrMM = XrMag.mean(dim=2, keepdim=True).detach().cpu() # [B,1,1,T] 15 | for t in range(T): 16 | if t < sliding_window_len: 17 | alpha_this = min((t - 1) / (t + 1), alpha) 18 | else: 19 | alpha_this = alpha 20 | mu = alpha_this * mu + (1 - alpha_this) * XrMM[..., t] 21 | mu_list.append(mu) 22 | 23 | XrMM = torch.stack(mu_list, dim=-1).to(XrMag.device) 24 | return XrMM 25 | 26 | 27 | # 雨洁的实现 28 | # def cumulative_normalization(original_signal_mag: Tensor, sliding_window_len: int = 192) -> Tensor: 29 | # alpha = (sliding_window_len - 1) / (sliding_window_len + 1) 30 | # eps = 1e-10 31 | # mu = 0 32 | # mu_list = [] 33 | # batch_size, frame_num, freq_num = original_signal_mag.shape 34 | # for frame_idx in range(frame_num): 35 | # if frame_idx < sliding_window_len: 36 | # alp = torch.min(torch.tensor([(frame_idx - 1) / (frame_idx + 1), alpha])) 37 | # mu = alp * mu + (1 - alp) * torch.mean(original_signal_mag[:, frame_idx, :], dim=-1).reshape(batch_size, 1) 38 | # else: 39 | # current_frame_mu = torch.mean(original_signal_mag[:, frame_idx, :], dim=-1).reshape(batch_size, 1) 40 | # mu = alpha * mu + (1 - alpha) * current_frame_mu 41 | # mu_list.append(mu) 42 | 43 | # XrMM = torch.stack(mu_list, dim=-1).permute(0, 2, 1).reshape(batch_size, frame_num, 1, 1) 44 | # return XrMM 45 | 46 | 47 | class Norm(nn.Module): 48 | 49 | def __init__(self, mode: Optional[Literal['utterance', 'frequency', 'forgetting', 'none']], online: bool = True) -> None: 50 | super().__init__() 51 | self.mode = mode 52 | self.online = online 53 | assert mode != 'forgetting' or online == True, 'forgetting is one online normalization' 54 | 55 | def forward(self, X: Tensor, norm_paras: Any = None, inverse: bool = False) -> Any: 56 | if not inverse: 57 | return self.norm(X, norm_paras=norm_paras) 58 | else: 59 | return self.inorm(X, norm_paras=norm_paras) 60 | 61 | def norm(self, X: Tensor, norm_paras: Any = None, ref_channel: int = None, eps: float = 1e-6) -> Tuple[Tensor, Any]: 62 | """ normalization 63 | Args: 64 | X: [B, Chn, F, T], complex 65 | norm_paras: the paramters for inverse normalization or for the normalization of other X's 66 | eps: 1e-6!=0 when dtype=float16 67 | 68 | Returns: 69 | the normalized tensor and the paramters for inverse normalization 70 | """ 71 | if self.mode == 'none' or self.mode is None: 72 | Xr = X[:, [ref_channel], :, :].clone() 73 | return X, (Xr, None) 74 | 75 | B, C, F, T = X.shape 76 | if norm_paras is None: 77 | Xr = X[:, [ref_channel], :, :].clone() # [B,1,F,T], complex 78 | 79 | if self.mode == 'frequency': 80 | if self.online: 81 | XrMM = torch.abs(Xr) + eps # [B,1,F,T] 82 | else: 83 | XrMM = torch.abs(Xr).mean(dim=3, keepdim=True) + eps # Xr_magnitude_mean, [B,1,F,1] 84 | elif 
self.mode == 'forgetting': 85 | XrMM = forgetting_normalization(XrMag=torch.abs(Xr)) + eps # [B,1,1,T] 86 | else: 87 | assert self.mode == 'utterance', self.mode 88 | if self.online: 89 | XrMM = torch.abs(Xr).mean(dim=(2,), keepdim=True) + eps # Xr_magnitude_mean, [B,1,1,T] 90 | else: 91 | XrMM = torch.abs(Xr).mean(dim=(2, 3), keepdim=True) + eps # Xr_magnitude_mean, [B,1,1,1] 92 | else: 93 | Xr, XrMM = norm_paras 94 | X[:, :, :, :] /= XrMM 95 | return X, (Xr, XrMM) 96 | 97 | def inorm(self, X: Tensor, norm_paras: Any) -> Tensor: 98 | """ inverse normalization 99 | Args: 100 | X: [B, Chn, F, T], complex 101 | norm_paras: the paramters for inverse normalization 102 | 103 | Returns: 104 | the normalized tensor and the paramters for inverse normalization 105 | """ 106 | 107 | Xr, XrMM = norm_paras 108 | return X * XrMM 109 | 110 | def extra_repr(self) -> str: 111 | return f"{self.mode}, online={self.online}" 112 | 113 | 114 | if __name__ == '__main__': 115 | 116 | # x = torch.randn((10, 1, 129, 251)) 117 | # y1 = forgetting_normalization(x) 118 | # y2 = cumulative_normalization(x.squeeze().transpose(-1, -2)) 119 | # print((y1.squeeze() == y2.squeeze()).all()) 120 | # print(torch.allclose(y1.squeeze(), y2.squeeze())) 121 | 122 | x = torch.randn((2, 1, 129, 251), dtype=torch.complex64).cuda() 123 | # norm = Norm('forgetting') 124 | norm = Norm('utterance', online=False) 125 | for i in range(10): 126 | y = norm.norm(x, ref_channel=0) 127 | import time 128 | torch.cuda.synchronize() 129 | ts = time.time() 130 | for i in range(1): 131 | y = norm.norm(x, ref_channel=0) 132 | torch.cuda.synchronize() 133 | te = time.time() 134 | print(te - ts) 135 | -------------------------------------------------------------------------------- /models/io/stft.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping 2 | import torch 3 | from torch import nn 4 | 5 | from torch import Tensor 6 | from typing import * 7 | 8 | paras_16k = { 9 | 'n_fft': 512, 10 | 'n_hop': 256, 11 | 'win_len': 512, 12 | } 13 | 14 | paras_8k = { 15 | 'n_fft': 256, 16 | 'n_hop': 128, 17 | 'win_len': 256, 18 | } 19 | 20 | 21 | class STFT(nn.Module): 22 | 23 | def __init__(self, n_fft: int, n_hop: int, win_len: Optional[int] = None, win: str = 'hann_window') -> None: 24 | super().__init__() 25 | self.n_fft, self.n_hop, self.win_len = n_fft, n_hop, win_len if win_len is not None else n_fft 26 | self.repr = str((n_fft, n_hop, win, win_len)) 27 | 28 | assert win in ['hann_window', 'sqrt_hann_window'], win 29 | if win == 'hann_window': 30 | window = torch.hann_window(self.n_fft) 31 | else: 32 | # For FT-JNF. 
Deep Non-linear Filters for Multi-channel Speech Enhancement and Separation 33 | assert win == 'sqrt_hann_window', win 34 | window = torch.sqrt(torch.hann_window(self.n_fft)) 35 | self.register_buffer('window', window) 36 | 37 | def forward(self, X: Tensor, original_len: int = None, inverse=False) -> Any: 38 | """istft 39 | Args: 40 | X: complex [..., F, T] 41 | original_len: original length 42 | inverse: stft or istft 43 | """ 44 | if not inverse: 45 | return self.stft(x) 46 | else: 47 | return self.istft(x, original_len=original_len) 48 | 49 | def stft(self, x: Tensor) -> Tuple[Tensor, int]: 50 | """stft 51 | Args: 52 | x: [..., time] 53 | 54 | Returns: 55 | the complex STFT domain representation of shape [..., freq, time] and the original length of the time domain waveform 56 | """ 57 | shape = list(x.shape) 58 | x = x.reshape(-1, shape[-1]) 59 | if x.is_cuda: 60 | with torch.autocast(device_type=x.device.type, dtype=torch.float32): # use float32 for stft & istft 61 | X = torch.stft(x, n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, return_complex=True) 62 | else: 63 | X = torch.stft(x, n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, return_complex=True) 64 | F, T = X.shape[-2:] 65 | X = X.reshape(shape=shape[:-1] + [F, T]) 66 | return X, shape[-1] 67 | 68 | def istft(self, X: Tensor, original_len: int = None) -> Tensor: 69 | """istft 70 | Args: 71 | X: complex [..., F, T] 72 | original_len: returned by stft 73 | 74 | Returns: 75 | the complex STFT domain representation of shape [..., freq, time] and the original length of the time domain waveform 76 | """ 77 | shape = list(X.shape) 78 | X = X.reshape(-1, *shape[-2:]) 79 | if X.is_cuda: 80 | with torch.autocast(device_type=X.device.type, dtype=torch.float32): # use float32 for stft & istft 81 | # iSTFT is problematic when batch size is larger than 16 82 | # x = torch.istft(X, n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, length=original_len) 83 | xs = [] 84 | for b in range(X.shape[0]): 85 | xb = torch.istft(X[b], n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, length=original_len) 86 | xs.append(xb) 87 | x = torch.stack(xs, dim=0) 88 | else: 89 | # iSTFT is problematic when batch size is larger than 16 90 | # x = torch.istft(X, n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, length=original_len) 91 | xs = [] 92 | for b in range(X.shape[0]): 93 | xb = torch.istft(X[b], n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, length=original_len) 94 | xs.append(xb) 95 | x = torch.stack(xs, dim=0) 96 | x = x.reshape(shape=shape[:-2] + [original_len]) 97 | return x 98 | 99 | def extra_repr(self) -> str: 100 | return self.repr 101 | 102 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 103 | return 104 | 105 | 106 | if __name__ == '__main__': 107 | x = torch.randn((1, 1, 8000 * 4)) 108 | stft = STFT(**paras_8k) 109 | X, ol = stft.stft(x) 110 | x_p = stft.istft(X, ol) 111 | print(x.shape, x_p.shape, X.shape) 112 | print(torch.allclose(x, x_p, rtol=1e-1)) 113 | print() 114 | -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from models.utils.git_tools import tag_and_log_git_status 2 | from 
models.utils.my_json_encoder import MyJsonEncoder 3 | from models.utils.my_logger import MyLogger 4 | from models.utils.my_progress_bar import MyProgressBar 5 | from models.utils.my_rich_progress_bar import MyRichProgressBar 6 | -------------------------------------------------------------------------------- /models/utils/base_cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic Command Line Interface, provides command line controls for training, test, and inference. Be sure to import this file before `import torch`, otherwise the OMP_NUM_THREADS would not work. 3 | """ 4 | 5 | import os 6 | 7 | os.environ["OMP_NUM_THREADS"] = str(2) # limit the threads to reduce cpu overloads, will speed up when there are lots of CPU cores on the running machine 8 | os.environ['OPENBLAS_NUM_THREADS'] = '2' 9 | os.environ["MKL_NUM_THREADS"] = str(2) 10 | 11 | from typing import * 12 | 13 | import torch 14 | if torch.multiprocessing.get_start_method() != 'spawn': 15 | torch.multiprocessing.set_start_method('spawn', force=True) # fix stoi stuck 16 | 17 | from models.utils import MyRichProgressBar as RichProgressBar 18 | # from pytorch_lightning.loggers import TensorBoardLogger 19 | from models.utils.my_logger import MyLogger as TensorBoardLogger 20 | 21 | from pytorch_lightning.callbacks import (LearningRateMonitor, ModelSummary) 22 | from pytorch_lightning.cli import LightningArgumentParser, LightningCLI 23 | 24 | torch.backends.cuda.matmul.allow_tf32 = True # The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later. 25 | torch.backends.cudnn.allow_tf32 = True # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 26 | from packaging.version import Version 27 | if Version(torch.__version__) >= Version('2.0.0'): 28 | torch._dynamo.config.optimize_ddp = False # fix this issue: https://github.com/pytorch/pytorch/issues/111279#issuecomment-1870641439 29 | torch._dynamo.config.cache_size_limit = 64 30 | 31 | 32 | class BaseCLI(LightningCLI): 33 | 34 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: 35 | self.add_model_invariant_arguments_to_parser(parser) 36 | 37 | def add_model_invariant_arguments_to_parser(self, parser: LightningArgumentParser) -> None: 38 | # RichProgressBar 39 | parser.add_lightning_class_args(RichProgressBar, nested_key='progress_bar') 40 | parser.set_defaults({"progress_bar.console_kwargs": { 41 | "force_terminal": True, 42 | "no_color": True, 43 | "width": 200, 44 | }}) 45 | 46 | # LearningRateMonitor 47 | parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor") 48 | learning_rate_monitor_defaults = { 49 | "learning_rate_monitor.logging_interval": "epoch", 50 | } 51 | parser.set_defaults(learning_rate_monitor_defaults) 52 | 53 | # ModelSummary 54 | parser.add_lightning_class_args(ModelSummary, 'model_summary') 55 | model_summary_defaults = { 56 | "model_summary.max_depth": 2, 57 | } 58 | parser.set_defaults(model_summary_defaults) 59 | 60 | def before_fit(self): 61 | resume_from_checkpoint: str = self.config['fit']['ckpt_path'] 62 | if resume_from_checkpoint is not None and resume_from_checkpoint.endswith('last.ckpt'): 63 | # log in same dir 64 | # resume_from_checkpoint example: /mnt/home/quancs/projects/NBSS_pmt/logs/NBSS_ifp/version_29/checkpoints/last.ckpt 65 | resume_from_checkpoint = os.path.normpath(resume_from_checkpoint) 66 | splits = resume_from_checkpoint.split(os.path.sep) 67 | version = 
int(splits[-3].replace('version_', '')) 68 | save_dir = os.path.sep.join(splits[:-3]) 69 | self.trainer.logger = TensorBoardLogger(save_dir=save_dir, name="", version=version, default_hp_metric=False) 70 | else: 71 | model_name = self.model.name if hasattr(self.model, 'name') else type(self.model).__name__ 72 | self.trainer.logger = TensorBoardLogger('logs/', name=model_name, default_hp_metric=False) 73 | 74 | def before_test(self): 75 | if self.config['test']['ckpt_path'] != None: 76 | ckpt_path = self.config['test']['ckpt_path'] 77 | else: 78 | raise Exception('You should give --ckpt_path if you want to test') 79 | epoch = os.path.basename(ckpt_path).split('_')[0] 80 | write_dir = os.path.dirname(os.path.dirname(ckpt_path)) 81 | 82 | test_set = 'test' 83 | if 'test_set' in self.config['test']['data']: 84 | test_set = self.config['test']['data']["test_set"] 85 | elif 'init_args' in self.config['test']['data'] and 'test_set' in self.config['test']['data']['init_args']: 86 | test_set = self.config['test']['data']['init_args']["test_set"] 87 | exp_save_path = os.path.normpath(write_dir + '/' + epoch + '_' + test_set + '_set') 88 | 89 | self.copy_ckpt(exp_save_path=exp_save_path, ckpt_path=ckpt_path) 90 | 91 | import time 92 | # add 10 seconds for threads to simultaneously detect the next version 93 | self.trainer.logger = TensorBoardLogger(exp_save_path, name='', default_hp_metric=False) 94 | time.sleep(10) 95 | 96 | def after_test(self): 97 | if not self.trainer.is_global_zero: 98 | return 99 | import fnmatch 100 | files = fnmatch.filter(os.listdir(self.trainer.log_dir), 'events.out.tfevents.*') 101 | for f in files: 102 | os.remove(self.trainer.log_dir + '/' + f) 103 | print('tensorboard log file for test is removed: ' + self.trainer.log_dir + '/' + f) 104 | 105 | def before_predict(self): 106 | if self.config['predict']['ckpt_path'] != None: 107 | ckpt_path = self.config['predict']['ckpt_path'] 108 | else: 109 | raise Exception('You should give --ckpt_path if you want to test') 110 | epoch = os.path.basename(ckpt_path).split('_')[0] 111 | write_dir = os.path.dirname(os.path.dirname(ckpt_path)) 112 | 113 | exp_save_path = os.path.normpath(write_dir + '/' + epoch + '_predict_set') 114 | 115 | self.copy_ckpt(exp_save_path=exp_save_path, ckpt_path=ckpt_path) 116 | 117 | import time 118 | # add 10 seconds for threads to simultaneously detect the next version 119 | self.trainer.logger = TensorBoardLogger(exp_save_path, name='', default_hp_metric=False) 120 | time.sleep(10) 121 | 122 | def after_predict(self): 123 | if not self.trainer.is_global_zero: 124 | return 125 | import fnmatch 126 | files = fnmatch.filter(os.listdir(self.trainer.log_dir), 'events.out.tfevents.*') 127 | for f in files: 128 | os.remove(self.trainer.log_dir + '/' + f) 129 | print('tensorboard log file for predict is removed: ' + self.trainer.log_dir + '/' + f) 130 | 131 | def copy_ckpt(self, exp_save_path: str, ckpt_path: str): 132 | # copy checkpoint to save path 133 | from pathlib import Path 134 | Path(exp_save_path).mkdir(exist_ok=True) 135 | if (Path(exp_save_path) / Path(ckpt_path).name).exists() == False: 136 | import shutil 137 | shutil.copyfile(ckpt_path, Path(exp_save_path) / Path(ckpt_path).name) 138 | -------------------------------------------------------------------------------- /models/utils/dnsmos.py: -------------------------------------------------------------------------------- 1 | # The following DNSMOS implementation will be available in the next release of torchmetrics. 
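# (Orientation note summarizing the code below: the required ONNX models are downloaded on first use to
#  ~/.torchmetrics/DNSMOS by `_prepare_dnsmos`, and `librosa`, `onnxruntime` and `requests` must be installed
#  for `deep_noise_suppression_mean_opinion_score` to run.)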
2 | 3 | import os 4 | from functools import lru_cache 5 | from typing import Any, Dict, Optional 6 | 7 | import numpy as np 8 | import torch 9 | from torch import Tensor 10 | 11 | from lightning_utilities.core.imports import RequirementCache 12 | 13 | _REQUESTS_AVAILABLE = RequirementCache("requests") 14 | _LIBROSA_AVAILABLE = RequirementCache("librosa") 15 | _ONNXRUNTIME_AVAILABLE = RequirementCache("onnxruntime") 16 | 17 | from torchmetrics.utilities import rank_zero_info 18 | 19 | if _LIBROSA_AVAILABLE and _ONNXRUNTIME_AVAILABLE and _REQUESTS_AVAILABLE: 20 | import librosa 21 | import onnxruntime as ort 22 | import requests 23 | from onnxruntime import InferenceSession 24 | else: 25 | librosa, ort, requests = None, None, None # type:ignore 26 | 27 | class InferenceSession: # type:ignore 28 | """Dummy InferenceSession.""" 29 | 30 | def __init__(self, **kwargs: Dict[str, Any]) -> None: 31 | ... 32 | 33 | 34 | __doctest_requires__ = {("deep_noise_suppression_mean_opinion_score", "_load_session"): ["requests", "librosa", "onnxruntime"]} 35 | 36 | SAMPLING_RATE = 16000 37 | INPUT_LENGTH = 9.01 38 | DNSMOS_DIR = "~/.torchmetrics/DNSMOS" 39 | 40 | 41 | def _prepare_dnsmos(dnsmos_dir: str) -> None: 42 | """Download required DNSMOS files. 43 | 44 | Args: 45 | dnsmos_dir: a dir to save the downloaded files. Defaults to "~/.torchmetrics". 46 | 47 | """ 48 | # https://raw.githubusercontent.com/microsoft/DNS-Challenge/master/DNSMOS/DNSMOS/model_v8.onnx 49 | # https://raw.githubusercontent.com/microsoft/DNS-Challenge/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx 50 | # https://raw.githubusercontent.com/microsoft/DNS-Challenge/master/DNSMOS/pDNSMOS/sig_bak_ovr.onnx 51 | url = "https://raw.githubusercontent.com/microsoft/DNS-Challenge/master" 52 | dnsmos_dir = os.path.expanduser(dnsmos_dir) 53 | 54 | # save to or load from ~/torchmetrics/dnsmos/. 55 | for file in ["DNSMOS/DNSMOS/model_v8.onnx", "DNSMOS/DNSMOS/sig_bak_ovr.onnx", "DNSMOS/pDNSMOS/sig_bak_ovr.onnx"]: 56 | saveto = os.path.join(dnsmos_dir, file[7:]) 57 | os.makedirs(os.path.dirname(saveto), exist_ok=True) 58 | if os.path.exists(saveto): 59 | # try load onnx 60 | try: 61 | _ = InferenceSession(saveto) 62 | continue # skip downloading if succeeded 63 | except Exception as _: 64 | os.remove(saveto) 65 | urlf = f"{url}/{file}" 66 | rank_zero_info(f"downloading {urlf} to {saveto}") 67 | myfile = requests.get(urlf) 68 | with open(saveto, "wb") as f: 69 | f.write(myfile.content) 70 | 71 | 72 | @lru_cache 73 | def _load_session( 74 | path: str, 75 | device: torch.device, 76 | ) -> InferenceSession: 77 | """Load onnxruntime session. 
78 | 79 | Args: 80 | path: the model path 81 | device: the device used 82 | 83 | Returns: 84 | onnxruntime session 85 | 86 | """ 87 | path = os.path.expanduser(path) 88 | if not os.path.exists(path): 89 | _prepare_dnsmos(DNSMOS_DIR) 90 | 91 | opts = ort.SessionOptions() 92 | opts.inter_op_num_threads = 4 93 | opts.intra_op_num_threads = 4 94 | 95 | if device.type == "cpu": 96 | infs = InferenceSession(path, providers=["CPUExecutionProvider"], sess_options=opts) 97 | elif "CUDAExecutionProvider" in ort.get_all_providers(): 98 | providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] 99 | provider_options = [{"device_id": device.index}, {}] 100 | infs = InferenceSession(path, providers=providers, provider_options=provider_options, sess_options=opts) 101 | else: 102 | infs = InferenceSession(path, providers=["CPUExecutionProvider"], sess_options=opts) 103 | 104 | return infs 105 | 106 | 107 | def _audio_melspec( 108 | audio: np.ndarray, 109 | n_mels: int = 120, 110 | frame_size: int = 320, 111 | hop_length: int = 160, 112 | sr: int = 16000, 113 | to_db: bool = True, 114 | ) -> np.ndarray: 115 | """Calculate the mel-spectrogram of an audio. 116 | 117 | Args: 118 | audio: [..., T] 119 | n_mels: the number of mel-frequencies 120 | frame_size: stft length 121 | hop_length: stft hop length 122 | sr: sample rate of audio 123 | to_db: convert to dB scale if `True` is given 124 | 125 | Returns: 126 | mel-spectrogram: [..., num_mel, T'] 127 | 128 | """ 129 | shape = audio.shape 130 | audio = audio.reshape(-1, shape[-1]) 131 | mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels) 132 | mel_spec = mel_spec.transpose(0, 2, 1) 133 | mel_spec = mel_spec.reshape(shape[:-1] + mel_spec.shape[1:]) 134 | if to_db: 135 | for b in range(mel_spec.shape[0]): 136 | mel_spec[b, ...] = (librosa.power_to_db(mel_spec[b], ref=np.max) + 40) / 40 137 | return mel_spec 138 | 139 | 140 | def _polyfit_val(mos: np.ndarray, personalized: bool) -> np.ndarray: 141 | """Use polyfit to convert raw mos values to DNSMOS values. 142 | 143 | Args: 144 | mos: the raw mos values, [..., 4] 145 | personalized: whether interfering speaker is penalized 146 | 147 | Returns: 148 | DNSMOS: [..., 4] 149 | 150 | """ 151 | if personalized: 152 | p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046]) 153 | p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726]) 154 | p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132]) 155 | else: 156 | p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535]) 157 | p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439]) # x**2*v0 + x**1*v1+ v2 158 | p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546]) 159 | 160 | mos[..., 1] = p_sig(mos[..., 1]) 161 | mos[..., 2] = p_bak(mos[..., 2]) 162 | mos[..., 3] = p_ovr(mos[..., 3]) 163 | return mos 164 | 165 | 166 | def deep_noise_suppression_mean_opinion_score(preds: Tensor, fs: int, personalized: bool, device: Optional[str] = None) -> Tensor: 167 | """Calculate `Deep Noise Suppression performance evaluation based on Mean Opinion Score`_ (DNSMOS). 168 | 169 | Human subjective evaluation is the ”gold standard” to evaluate speech quality optimized for human perception. 170 | Perceptual objective metrics serve as a proxy for subjective scores. The conventional and widely used metrics 171 | require a reference clean speech signal, which is unavailable in real recordings. 
The no-reference approaches 172 | correlate poorly with human ratings and are not widely adopted in the research community. One of the biggest 173 | use cases of these perceptual objective metrics is to evaluate noise suppression algorithms. DNSMOS generalizes 174 | well in challenging test conditions with a high correlation to human ratings in stack ranking noise suppression 175 | methods. More details can be found in [DNSMOS paper](https://arxiv.org/pdf/2010.15258.pdf). 176 | 177 | 178 | .. note:: using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed. 179 | Install as ``pip install librosa onnxruntime-gpu requests``. 180 | 181 | Args: 182 | preds: [..., time] 183 | fs: sampling frequency 184 | personalized: whether interfering speaker is penalized 185 | device: the device used for calculating DNSMOS, can be cpu or cuda:n, where n is the index of gpu. 186 | If None is given, then the device of input is used. 187 | 188 | Returns: 189 | Float tensor with shape ``(..., 4)`` of DNSMOS values per sample, i.e. [p808_mos, mos_sig, mos_bak, mos_ovr] 190 | 191 | Raises: 192 | ModuleNotFoundError: 193 | If ``librosa``, ``onnxruntime`` or ``requests`` packages are not installed 194 | 195 | Example: 196 | >>> from torch import randn 197 | >>> from torchmetrics.functional.audio.dnsmos import deep_noise_suppression_mean_opinion_score 198 | >>> g = torch.manual_seed(1) 199 | >>> preds = randn(8000) 200 | >>> deep_noise_suppression_mean_opinion_score(preds, 8000, False) 201 | tensor([2.2285, 2.1132, 1.3972, 1.3652], dtype=torch.float64) 202 | 203 | """ 204 | if not _LIBROSA_AVAILABLE or not _ONNXRUNTIME_AVAILABLE or not _REQUESTS_AVAILABLE: 205 | raise ModuleNotFoundError("DNSMOS metric requires that librosa, onnxruntime and requests are installed." 
206 | " Install as `pip install librosa onnxruntime-gpu requests`.") 207 | device = torch.device(device) if device is not None else preds.device 208 | 209 | onnx_sess = _load_session(f"{DNSMOS_DIR}/{'p' if personalized else ''}DNSMOS/sig_bak_ovr.onnx", device) 210 | p808_onnx_sess = _load_session(f"{DNSMOS_DIR}/DNSMOS/model_v8.onnx", device) 211 | 212 | desired_fs = SAMPLING_RATE 213 | if fs != desired_fs: 214 | audio = librosa.resample(preds.cpu().numpy(), orig_sr=fs, target_sr=desired_fs) 215 | else: 216 | audio = preds.cpu().numpy() 217 | 218 | # normalize audio 219 | audio = audio / np.max(np.abs(audio),axis=-1,keepdims=True) 220 | 221 | len_samples = int(INPUT_LENGTH * desired_fs) 222 | while audio.shape[-1] < len_samples: 223 | audio = np.concatenate([audio, audio], axis=-1) 224 | 225 | 226 | num_hops = int(np.floor(audio.shape[-1] / desired_fs) - INPUT_LENGTH) + 1 227 | 228 | moss = [] 229 | hop_len_samples = desired_fs 230 | for idx in range(num_hops): 231 | audio_seg = audio[..., int(idx * hop_len_samples):int((idx + INPUT_LENGTH) * hop_len_samples)] 232 | if audio_seg.shape[-1] < len_samples: 233 | continue 234 | shape = audio_seg.shape 235 | audio_seg = audio_seg.reshape((-1, shape[-1])) 236 | 237 | input_features = np.array(audio_seg).astype("float32") 238 | p808_input_features = np.array(_audio_melspec(audio=audio_seg[..., :-160])).astype("float32") 239 | 240 | if device.type != "cpu" and "CUDAExecutionProvider" in ort.get_all_providers(): 241 | input_features = ort.OrtValue.ortvalue_from_numpy(input_features, device.type, device.index) 242 | p808_input_features = ort.OrtValue.ortvalue_from_numpy(p808_input_features, device.type, device.index) 243 | 244 | oi = {"input_1": input_features} 245 | p808_oi = {"input_1": p808_input_features} 246 | mos_np = np.concatenate([p808_onnx_sess.run(None, p808_oi)[0], onnx_sess.run(None, oi)[0]], axis=-1, dtype="float64") 247 | mos_np = _polyfit_val(mos_np, personalized) 248 | 249 | mos_np = mos_np.reshape(shape[:-1] + (4, )) 250 | moss.append(mos_np) 251 | return torch.from_numpy(np.mean(np.stack(moss, axis=-1), axis=-1)) # [p808_mos, mos_sig, mos_bak, mos_ovr] 252 | -------------------------------------------------------------------------------- /models/utils/ensemble.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import * 3 | import torch 4 | 5 | 6 | def ensemble(opts: Union[int, str, List[str]], ckpt: str) -> Tuple[List[str], Dict]: 7 | """ensemble checkpoints 8 | 9 | Args: 10 | opts: ensemble last N epochs if opts is int; ensemble globed checkpoints if opts is str; ensemble specified checkpoints if opts is a list. 11 | ckpt: the current checkpoint path 12 | 13 | Returns: 14 | ckpts: the checkpoints ensembled 15 | state_dict 16 | """ 17 | # parse the ensemble args to obtains the ckpts to ensemble 18 | if isinstance(opts, int): 19 | ckpts = [] 20 | if opts > 0: 21 | epoch = int(Path(ckpt).name.split('_')[0].replace('epoch', '')) 22 | for epc in range(max(0, epoch - opts), epoch, 1): 23 | path = list(Path(ckpt).parent.glob(f'epoch{epc}_*'))[0] 24 | ckpts.append(path) 25 | elif isinstance(opts, list): 26 | assert len(opts) > 0, opts 27 | ckpts = list(set(opts)) 28 | else: # e.g. 
logs/SSFNetLM/version_100/checkpoints/epoch* or epoch* 29 | assert isinstance(opts, str), opts 30 | ckpts = list(Path(opts).parent.glob(Path(opts).name)) 31 | if len(ckpts) == 0: 32 | ckpts = list(Path(ckpt).parent.glob(opts)) 33 | assert len(ckpts) > 0, f"checkpoints not found in {opts} or {Path(ckpt).parent/opts}" 34 | ckpts = ckpts + [ckpt] 35 | 36 | # remove redundant ckpt 37 | ckpts_ = dict() 38 | for ckpt in ckpts: 39 | ckpts_[Path(ckpt).name] = str(ckpt) 40 | ckpts = list(ckpts_.values()) 41 | ckpts.sort() 42 | 43 | # load weights from checkpoints 44 | state_dict = dict() 45 | for path in ckpts: 46 | data = torch.load(path, map_location='cpu') 47 | for k, v in data['state_dict'].items(): 48 | if k in state_dict: 49 | state_dict[k] += (v / len(ckpts)) 50 | else: 51 | state_dict[k] = (v / len(ckpts)) 52 | return ckpts, state_dict 53 | -------------------------------------------------------------------------------- /models/utils/flops.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | import torch 5 | import yaml 6 | from jsonargparse import ArgumentParser 7 | from pytorch_lightning import LightningModule 8 | from argparse import ArgumentParser as Parser 9 | import traceback 10 | from typing import * 11 | 12 | import operator 13 | 14 | from lightning_utilities.core.imports import compare_version 15 | 16 | _TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True) 17 | 18 | 19 | # this function is ported from lightning 20 | def measure_flops( 21 | model: torch.nn.Module, 22 | forward_fn: Callable[[], torch.Tensor], 23 | loss_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, 24 | total: bool = True, 25 | ) -> int: 26 | """Utility to compute the total number of FLOPs used by a module during training or during inference. 27 | 28 | It's recommended to create a meta-device model for this, because: 29 | 1) the flops of LSTM cannot be measured if the model is not a meta-device model: 30 | 31 | Example:: 32 | 33 | with torch.device("meta"): 34 | model = MyModel() 35 | x = torch.randn(2, 32) 36 | 37 | model_fwd = lambda: model(x) 38 | fwd_flops = measure_flops(model, model_fwd) 39 | 40 | model_loss = lambda y: y.sum() 41 | fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss) 42 | 43 | Args: 44 | model: The model whose FLOPs should be measured. 45 | forward_fn: A function that runs ``forward`` on the model and returns the result. 46 | loss_fn: A function that computes the loss given the ``forward_fn`` output. If provided, the loss and `backward` 47 | FLOPs will be included in the result. 
48 | 49 | """ 50 | if not _TORCH_GREATER_EQUAL_2_1: 51 | raise ImportError("`measure_flops` requires PyTorch >= 2.1.") 52 | from torch.utils.flop_counter import FlopCounterMode 53 | 54 | flop_counter = FlopCounterMode(model, display=False) 55 | with flop_counter: 56 | if loss_fn is None: 57 | forward_fn() 58 | else: 59 | loss_fn(forward_fn()).backward() 60 | if total: 61 | return flop_counter.get_total_flops() 62 | else: 63 | return flop_counter 64 | 65 | 66 | def detailed_flops(flop_counter) -> str: 67 | sss = "" 68 | for k, v in flop_counter.get_flop_counts().items(): 69 | ss = f"{k}: {{" 70 | for kk, vv in v.items(): 71 | ss += f" {str(kk)}:{vv}" 72 | ss += " }\n" 73 | sss += ss 74 | return sss 75 | 76 | 77 | class FakeModule(torch.nn.Module): 78 | 79 | def __init__(self, module: LightningModule) -> None: 80 | super().__init__() 81 | self.module = module 82 | 83 | def forward(self, x): 84 | return self.module.predict_step(x, 0) 85 | 86 | 87 | def _get_num_params(model: torch.nn.Module): 88 | num_params = sum(param.numel() for param in model.parameters()) 89 | return num_params 90 | 91 | 92 | def _test_FLOPs(model: LightningModule, save_dir: str, num_chns: int, fs: int, audio_time_len: int = 4, num_params: int = None): 93 | if _TORCH_GREATER_EQUAL_2_1: 94 | x = torch.randn(1, num_chns, int(fs * audio_time_len), dtype=torch.float32).to('meta') 95 | model = model.to('meta') 96 | 97 | model_fwd = lambda: model(x, istft=False) 98 | fwd_flops = measure_flops(model, model_fwd, total=False) 99 | 100 | model_loss = lambda y: y[0].sum() 101 | fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss) 102 | 103 | with open(os.path.join(save_dir, 'FLOPs-detailed.txt'), 'w') as f: 104 | f.write(detailed_flops(fwd_flops)) 105 | flops_forward_eval, flops_backward_eval = fwd_flops.get_total_flops(), fwd_and_bwd_flops - fwd_flops.get_total_flops() 106 | else: 107 | print( 108 | "Warning: FLOPs is measured with torchtnt.utils.flops.FlopTensorDispatchMode which doesn't support LSTM, if your model has LSTMs inside please upgrade to torch>=2.1.0, and use torch.utils.flop_counter.FlopCounterMode with tensor and model on meta device" 109 | ) 110 | module = FakeModule(model) 111 | 112 | import copy 113 | from torchtnt.utils.flops import FlopTensorDispatchMode 114 | 115 | x = torch.randn(1, num_chns, int(fs * audio_time_len), dtype=torch.float32) 116 | flops_forward_eval, flops_backward_eval = 0, 0 117 | try: 118 | with FlopTensorDispatchMode(module) as ftdm: 119 | res = module(x).mean() 120 | flops_forward = copy.deepcopy(ftdm.flop_counts) 121 | flops_forward_eval = sum(list(flops_forward[''].values())) # MACs 122 | 123 | with open(os.path.join(save_dir, 'FLOPs-detailed.txt'), 'w') as f: 124 | for k, v in flops_forward.items(): 125 | f.write(str(k) + ': { ') 126 | for kk, vv in v.items(): 127 | f.write(str(kk).replace('.default', '') + ': ' + str(vv) + ', ') 128 | f.write(' }\n') 129 | 130 | ftdm.reset() 131 | 132 | res.backward() 133 | flops_backward = copy.deepcopy(ftdm.flop_counts) 134 | flops_backward_eval = sum(list(flops_backward[''].values())) 135 | except Exception as e: 136 | exp_file = os.path.join(save_dir, 'FLOPs-failed.txt') 137 | traceback.print_exc(file=open(exp_file, 'w')) 138 | print(f"FLOPs test failed '{repr(e)}', see {exp_file}") 139 | 140 | params_eval = num_params if num_params is not None else _get_num_params(module) 141 | flops_forward_eval_avg = flops_forward_eval / audio_time_len 142 | print( 143 | f"FLOPs: forward={flops_forward_eval/1e9:.2f} G, {flops_forward_eval_avg/1e9:.2f} 
G/s, back={flops_backward_eval/1e9:.2f} G, params: {params_eval/1e6:.3f} M, detailed: {os.path.join(save_dir, 'FLOPs-detailed.txt')}" 144 | ) 145 | 146 | with open(os.path.join(save_dir, 'FLOPs.yaml'), 'w') as f: 147 | yaml.dump( 148 | { 149 | "flops_forward" if _TORCH_GREATER_EQUAL_2_1 else "macs_forward": f"{flops_forward_eval/1e9:.2f} G", 150 | "flops_forward_avg" if _TORCH_GREATER_EQUAL_2_1 else "macs_forward_avg": f"{flops_forward_eval_avg/1e9:.2f} G/s", 151 | "flops_backward" if _TORCH_GREATER_EQUAL_2_1 else "macs_backward": f"{flops_backward_eval/1e9:.2f} G", 152 | "params": f"{params_eval/1e6:.3f} M", 153 | "fs": fs, 154 | "audio_time_len": audio_time_len, 155 | "num_chns": num_chns, 156 | }, f) 157 | f.close() 158 | 159 | 160 | def import_class(class_path: str): 161 | try: 162 | iclass = importlib.import_module(class_path) 163 | return iclass 164 | except: 165 | imodule = importlib.import_module('.'.join(class_path.split('.')[:-1])) 166 | iclass = getattr(imodule, class_path.split('.')[-1]) 167 | return iclass 168 | 169 | 170 | def _test_FLOPs_from_config(save_dir: str, model_class_path: str, num_chns: int, fs: int, audio_time_len: int = 4, config_file: str = None): 171 | if config_file is None: 172 | config_file = os.path.join(save_dir, 'config.yaml') 173 | 174 | model_class = import_class(model_class_path) 175 | with open(config_file, 'r', encoding='utf-8') as f: 176 | config = yaml.load(f, yaml.FullLoader) 177 | parser = ArgumentParser() 178 | parser.add_class_arguments(model_class) 179 | 180 | if 'compile' in config['model']: 181 | config['model']['compile'] = False # compiled model will fail to test its flops 182 | try: 183 | if 'compile' in config['model']['arch']['init_args']: 184 | config['model']['arch']['init_args']['compile'] = False 185 | except: 186 | ... 187 | model_config = parser.instantiate_classes(config['model']) 188 | model = model_class(**model_config.as_dict()) 189 | num_params = _get_num_params(model=model) 190 | try: 191 | # torcheval report error for shared modules, so config to not share 192 | if "full_share" in config['model']['arch']['init_args']: 193 | if config['model']['arch']['init_args']['full_share'] == True: 194 | config['model']['arch']['init_args']['full_share'] = False 195 | model_config = parser.instantiate_classes(config['model']) 196 | model = model_class(**model_config.as_dict()) 197 | elif type(config['model']['arch']['init_args']['full_share']) == int or config['model']['arch']['init_args']['full_share'] == None: 198 | config['model']['arch']['init_args']['full_share'] = 9999 199 | model_config = parser.instantiate_classes(config['model']) 200 | model = model_class(**model_config.as_dict()) 201 | except Exception as e: 202 | ... 
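# measure FLOPs with the (possibly re-instantiated, non-shared) model; num_params was computed above from the original configuration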
203 | _test_FLOPs(model, save_dir=save_dir, num_chns=num_chns, fs=fs, audio_time_len=audio_time_len, num_params=num_params) 204 | 205 | 206 | def write_FLOPs(model: LightningModule, save_dir: str, num_chns: int, fs: int = None, nfft: int = None, audio_time_len: int = 4, model_class_path: str = None): 207 | assert fs is not None or nfft is not None, (fs, nfft) 208 | if model_class_path is None: 209 | model_class_path = f"{str(model.__class__.__module__)}.{type(model).__name__}" 210 | 211 | if fs: 212 | cmd = f'CUDA_VISIBLE_DEVICES={model.device.index}, python -m models.utils.flops ' + f'--save_dir {save_dir} --model_class_path {model_class_path} ' + f'--num_chns {num_chns} --fs {fs} --audio_time_len {audio_time_len}' 213 | else: 214 | cmd = f'CUDA_VISIBLE_DEVICES={model.device.index}, python -m models.utils.flops ' + f'--save_dir {save_dir} --model_class_path {model_class_path} ' + f'--num_chns {num_chns} --nfft {nfft} --audio_time_len {audio_time_len}' 215 | print(cmd) 216 | os.system(cmd) 217 | 218 | 219 | if __name__ == '__main__': 220 | # CUDA_VISIBLE_DEVICES=5, python -m models.utils.flops --save_dir logs/SSFNetLM/version_90 --model_class_path models.SSFNetLM.SSFNetLM --num_chns 6 --fs 8000 221 | # CUDA_VISIBLE_DEVICES=5, python -m models.utils.flops --save_dir logs/SSFNetLM/version_90 --model_class_path models.SSFNetLM.SSFNetLM --num_chns 6 --nfft 256 222 | parser = Parser() 223 | parser.add_argument('--save_dir', type=str, required=True, help='save FLOPs to dir') 224 | parser.add_argument('--model_class_path', type=str, required=True, help='the import path of your Lightning Module') 225 | parser.add_argument('--num_chns', type=int, required=True, help='the number of microphone channels') 226 | parser.add_argument('--fs', type=int, default=None, help='sampling rate') 227 | parser.add_argument('--nfft', type=int, default=None, help='sampling rate') 228 | parser.add_argument('--audio_time_len', type=float, default=4., help='seconds of test mixture waveform') 229 | parser.add_argument('--config_file', type=str, default=None, help='config file path') 230 | args = parser.parse_args() 231 | 232 | fs = args.fs 233 | if fs is None: 234 | if args.nfft is None: 235 | print('MACs test error: you should specify the fs or nfft') 236 | exit(-1) 237 | fs = {256: 8000, 512: 16000, 320: 16000, 160: 8000}[args.nfft] 238 | 239 | _test_FLOPs_from_config( 240 | save_dir=args.save_dir, 241 | model_class_path=args.model_class_path, 242 | num_chns=args.num_chns, 243 | fs=fs, 244 | audio_time_len=args.audio_time_len, 245 | config_file=args.config_file, 246 | ) 247 | -------------------------------------------------------------------------------- /models/utils/general_steps.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import * 4 | from pathlib import Path 5 | 6 | import pytorch_lightning as pl 7 | import soundfile as sf 8 | import torch 9 | from numpy import ndarray 10 | from pandas import DataFrame 11 | from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn 12 | from torch import Tensor 13 | 14 | from models.utils import MyJsonEncoder, tag_and_log_git_status 15 | from models.utils.ensemble import ensemble 16 | from models.utils.flops import write_FLOPs 17 | from models.utils.metrics import (cal_metrics_functional, cal_pesq, recover_scale) 18 | 19 | 20 | def on_validation_epoch_end(self: pl.LightningModule, cpu_metric_input: List[Tuple[ndarray, ndarray, int]], N: int = 5) -> None: 21 | """calculate 
heavy metrics for every N epochs 22 | 23 | Args: 24 | self: LightningModule 25 | cpu_metric_input: the input list for cal_metrics_functional 26 | N: the number of epochs. Defaults to 5. 27 | """ 28 | 29 | if self.current_epoch != 0 and self.current_epoch % N != (N - 1): 30 | cpu_metric_input.clear() 31 | return 32 | 33 | if len(cpu_metric_input) == 0: 34 | return 35 | 36 | torch.multiprocessing.set_sharing_strategy('file_system') 37 | num_thread = torch.multiprocessing.cpu_count() // (self.trainer.world_size * 2) 38 | p = torch.multiprocessing.Pool(min(num_thread, len(cpu_metric_input))) 39 | cpu_metrics = list(p.starmap(cal_metrics_functional, cpu_metric_input)) 40 | p.close() 41 | p.join() 42 | 43 | for k in cpu_metric_input[0][0]: 44 | ms = list(filter(None, [m[0][k.lower()] for m in cpu_metrics])) 45 | if len(ms) > 0: 46 | self.log(f'val/{k}', sum(ms) / len(ms), sync_dist=True, batch_size=len(ms)) 47 | 48 | cpu_metric_input.clear() 49 | 50 | 51 | def on_test_epoch_end(self: pl.LightningModule, results: List[Dict[str, Any]], cpu_metric_input: List, exp_save_path: str): 52 | """ calculate cpu metrics on CPU, collect results, save results to file 53 | 54 | Args: 55 | self: LightningModule 56 | results: the result list 57 | cpu_metric_input: the input list for cal_metrics_functional 58 | exp_save_path: the path to save result file 59 | """ 60 | 61 | # calculate metrics, input_metrics, improve_metrics on CPU using multiprocessing to speed up 62 | torch.multiprocessing.set_sharing_strategy('file_system') 63 | num_thread = torch.multiprocessing.cpu_count() // (self.trainer.world_size * 2) 64 | p = torch.multiprocessing.Pool(min(num_thread, len(cpu_metric_input))) 65 | cpu_metrics = list(p.starmap(cal_metrics_functional, cpu_metric_input)) 66 | p.close() 67 | p.join() 68 | for i, m in enumerate(cpu_metrics): 69 | metrics, input_metrics, imp_metrics = m 70 | results[i].update(input_metrics) 71 | results[i].update(imp_metrics) 72 | results[i].update(metrics) 73 | 74 | # gather results from all GPUs 75 | import torch.distributed as dist 76 | 77 | # collect results from other gpus if world_size > 1 78 | if self.trainer.world_size > 1: 79 | dist.barrier() 80 | results_list = [None for obj in results] 81 | dist.all_gather_object(results_list, results) # gather results from all gpus 82 | # merge them 83 | exist = set() 84 | results = [] 85 | for rs in results_list: 86 | if rs == None: 87 | continue 88 | for r in rs: 89 | if r['wavname'] not in exist: 90 | results.append(r) 91 | exist.add(r['wavname']) 92 | 93 | # save collected data on 0-th gpu 94 | if self.trainer.is_global_zero: 95 | # save 96 | import datetime 97 | x = datetime.datetime.now() 98 | dtstr = x.strftime('%Y%m%d_%H%M%S.%f') 99 | path = os.path.join(exp_save_path, 'results_{}.json'.format(dtstr)) 100 | # write results to json 101 | f = open(path, 'w', encoding='utf-8') 102 | json.dump(results, f, indent=4, cls=MyJsonEncoder) 103 | f.close() 104 | # write mean to json 105 | df = DataFrame(results) 106 | df.mean(numeric_only=True).to_json(os.path.join(exp_save_path, 'results_mean.json'), indent=4) 107 | self.print('results: ', os.path.join(exp_save_path, 'results_mean.json'), ' ', path) 108 | 109 | 110 | def on_predict_batch_end( 111 | self: pl.LightningModule, 112 | outputs: Optional[Any], 113 | batch: Any, 114 | ) -> None: 115 | """save predicted results to `log_dir/examples` 116 | 117 | Args: 118 | self: LightningModule 119 | outputs: _description_ 120 | batch: _description_ 121 | """ 122 | save_dir = self.trainer.logger.log_dir + '/' 
+ 'examples' 123 | os.makedirs(save_dir, exist_ok=True) 124 | 125 | if not isinstance(batch, Tensor): 126 | input, target, paras = batch 127 | if 'saveto' in paras[0]: 128 | for b in range(len(paras)): 129 | saveto = paras[b]['saveto'] 130 | if isinstance(saveto, str): 131 | saveto = [saveto] 132 | 133 | assert isinstance(saveto, list), ('saveto should be a list of size num_speakers', type(saveto)) 134 | for spk, spk_saveto in enumerate(saveto): 135 | if isinstance(spk_saveto, dict): 136 | input_saveto = spk_saveto['input'] if 'input' in spk_saveto else None 137 | target_saveto = spk_saveto['target'] if 'target' in spk_saveto else None 138 | pred_saveto = spk_saveto['prediction'] if 'prediction' in spk_saveto else None 139 | else: 140 | pred_saveto, input_saveto, target_saveto = spk_saveto, None, None 141 | 142 | # save predictions 143 | if pred_saveto: 144 | y = outputs[b][spk] 145 | assert len(y.shape) == 1, y.shape 146 | save_path = Path(save_dir) / pred_saveto 147 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 148 | sf.write(save_path, y.detach().cpu().numpy(), samplerate=paras[b]['sample_rate']) 149 | # save input 150 | if input_saveto: 151 | y = input[b].T # [T,CHN] 152 | save_path = Path(save_dir) / input_saveto 153 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 154 | sf.write(save_path, y.detach().cpu().numpy(), samplerate=paras[b]['sample_rate']) 155 | # # save target 156 | # if input_saveto and target is not None: 157 | # y = target[b, spk, :, :].T # [T,CHN] 158 | # save_path = save_dir + '/' + target_saveto 159 | # os.makedirs(os.path.dirname(save_path), exist_ok=True) 160 | # sf.write(save_path, y.detach().cpu().numpy(), samplerate=paras[b]['sample_rate']) 161 | 162 | 163 | def on_load_checkpoint( 164 | self: pl.LightningModule, 165 | checkpoint: Dict[str, Any], 166 | ensemble_opts: Union[int, str, List[str], Literal[None]] = None, 167 | compile: bool = True, 168 | reset: List[str] = [], 169 | ) -> None: 170 | """load checkpoint 171 | 172 | Args: 173 | self: LightningModule 174 | checkpoint: the loaded weights 175 | ensemble_opts: opts for ensemble. Defaults to None. 176 | compile: whether the checkpoint is a compiled one. Defaults to True. 
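reset: state entries to reset when loading the checkpoint; each entry must be 'optimizer' or 'lr_scheduler'. Defaults to [].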
177 | """ 178 | from pytorch_lightning.strategies import FSDPStrategy 179 | if isinstance(self.trainer.strategy, FSDPStrategy): 180 | rank_zero_warn('using fsdp, ensemble is disenabled') 181 | return super(pl.LightningModule, self).on_load_checkpoint(checkpoint) 182 | 183 | if ensemble_opts: 184 | ckpt = self.trainer.ckpt_path 185 | ckpts, state_dict = ensemble(opts=ensemble_opts, ckpt=ckpt) 186 | self.print(f'rank {self.trainer.local_rank}/{self.trainer.world_size}, ensemble {ensemble_opts}: {ckpts}') 187 | checkpoint['state_dict'] = state_dict 188 | 189 | # rename weights for removing _orig_mod in name 190 | name_mapping = {} # {name without _orig_mod: the actual name} 191 | parameters = self.state_dict() 192 | for k, v in parameters.items(): 193 | name_mapping[k.replace('_orig_mod.', '')] = k 194 | 195 | state_dict = checkpoint['state_dict'] 196 | state_dict_new = dict() 197 | for k, v, in state_dict.items(): 198 | state_dict_new[name_mapping[k.replace('_orig_mod.', '')]] = v 199 | checkpoint['state_dict'] = state_dict_new 200 | 201 | # reset optimizer and lr_scheduler 202 | if reset is not None: 203 | for key in reset: 204 | assert key in ['optimizer', 'lr_scheduler'], f'unsupported reset key {key}' 205 | if key == 'optimizer': 206 | checkpoint['optimizer'] = dict() 207 | checkpoint['optimizer_states'] = [] 208 | rank_zero_info('reset optimizer') 209 | elif key == 'lr_scheduler': 210 | checkpoint['lr_scheduler'] = dict() 211 | checkpoint['lr_schedulers'] = [] 212 | rank_zero_info('reset lr_scheduler') 213 | 214 | return super(pl.LightningModule, self).on_load_checkpoint(checkpoint) 215 | 216 | 217 | def on_train_start(self: pl.LightningModule, exp_name: str, model_name: str, num_chns: int, nfft: int, model_class_path: str = None): 218 | """ 1) add git tags/write requirements for better change tracking; 2) write model architecture to file; 3) measure the model FLOPs 219 | 220 | Args: 221 | self: LightningModule 222 | exp_name: `notag` or exp_name, add git tag e.g. 
'model_name_v10' if exp_name!='notag' 223 | model_name: the model name 224 | num_chns: the number of channels for FLOPs test 225 | nfft: the number of fft points 226 | model_class_path: the path to import the self 227 | """ 228 | if self.current_epoch == 0: 229 | if self.trainer.is_global_zero and hasattr(self.logger, 'log_dir') and 'notag' not in exp_name: 230 | # add git tags for better change tracking 231 | # note: if change self.logger.log_dir to self.trainer.log_dir, the training will stuck on multi-gpu training 232 | tag_and_log_git_status(self.logger.log_dir + '/git.out', self.logger.version, exp_name, model_name=model_name) 233 | 234 | if self.trainer.is_global_zero and hasattr(self.logger, 'log_dir'): 235 | # write model architecture to file 236 | with open(self.logger.log_dir + '/model.txt', 'a') as f: 237 | f.write(str(self)) 238 | f.write('\n\n\n') 239 | # measure the model FLOPs, the num_chns here only means the original channels 240 | # write_FLOPs(model=self, save_dir=self.logger.log_dir, num_chns=num_chns, nfft=nfft, model_class_path=model_class_path) 241 | 242 | 243 | def configure_optimizers( 244 | self: pl.LightningModule, 245 | optimizer: str, 246 | optimizer_kwargs: Dict[str, Any], 247 | monitor: str = 'val/loss', 248 | lr_scheduler: str = None, 249 | lr_scheduler_kwargs: Dict[str, Any] = None, 250 | ): 251 | """configure optimizer and lr_scheduler""" 252 | if optimizer == 'Adam' and self.trainer.precision == '16-mixed': 253 | if 'eps' not in optimizer_kwargs: 254 | optimizer_kwargs['eps'] = 1e-4 # according to https://discuss.pytorch.org/t/adam-half-precision-nans/1765 255 | rank_zero_info('setting the eps of Adam to 1e-4 for FP16 mixed precision training') 256 | else: 257 | allowed_minimum = torch.finfo(torch.float16).eps 258 | assert optimizer_kwargs['eps'] >= allowed_minimum, f"You should specify an eps greater than the allowed minimum of the FP16 precision: {optimizer_kwargs['eps']} {allowed_minimum}" 259 | optimizer = getattr(torch.optim, optimizer)(self.parameters(), **optimizer_kwargs) 260 | 261 | if lr_scheduler is not None and len(lr_scheduler) > 0: 262 | lr_scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler) 263 | return { 264 | 'optimizer': optimizer, 265 | 'lr_scheduler': { 266 | 'scheduler': lr_scheduler(optimizer, **lr_scheduler_kwargs), 267 | 'monitor': monitor, 268 | } 269 | } 270 | else: 271 | return optimizer 272 | 273 | 274 | def test_setp_write_example(self, xr: Tensor, yr: Tensor, yr_hat: Tensor, sample_rate: int, paras: Dict[str, Any], result_dict: Dict[str, Any], wavname: str, exp_save_path: str): 275 | """ 276 | Args: 277 | xr: [B,T] 278 | yr: [B,Spk,T] 279 | yr_hat: [B,Spk,T] 280 | """ 281 | 282 | # write examples 283 | abs_max = max(torch.max(torch.abs(xr[0, ...])), torch.max(torch.abs(yr[0, ...]))) 284 | 285 | def write_wav(wav_path: str, wav: torch.Tensor, norm_to: torch.Tensor = None): 286 | # make sure wav don't have illegal values (abs greater than 1) 287 | if norm_to: 288 | wav = wav / torch.max(torch.abs(wav)) * norm_to 289 | if abs_max > 1: 290 | wav /= abs_max 291 | abs_max_wav = torch.max(torch.abs(wav)) 292 | if abs_max_wav > 1: 293 | import warnings 294 | warnings.warn(f"abs_max_wav > 1, {abs_max_wav}") 295 | wav /= abs_max_wav 296 | sf.write(wav_path, wav.detach().cpu().numpy(), sample_rate) 297 | 298 | pattern = '.'.join(wavname.split('.')[:-1]) + '{name}' # remove .wav in wavname 299 | example_dir = os.path.join(exp_save_path, 'examples', str(paras[0]['index'])) 300 | os.makedirs(example_dir, exist_ok=True) 301 | # save 
preds and targets for each speaker 302 | for i in range(yr.shape[1]): 303 | # write ys 304 | wav_path = os.path.join(example_dir, pattern.format(name=f"_spk{i+1}.wav")) 305 | write_wav(wav_path=wav_path, wav=yr[0, i]) 306 | # write ys_hat 307 | wav_path = os.path.join(example_dir, pattern.format(name=f"_spk{i+1}_p.wav")) 308 | write_wav(wav_path=wav_path, wav=yr_hat[0, i]) #, norm_to=ys[0, i].abs().max()) 309 | # write mix 310 | wav_path = os.path.join(example_dir, pattern.format(name=f"_mix.wav")) 311 | write_wav(wav_path=wav_path, wav=xr[0, :]) 312 | 313 | # write paras & results 314 | f = open(os.path.join(example_dir, pattern.format(name=f"_paras.json")), 'w', encoding='utf-8') 315 | paras[0]['metrics'] = result_dict 316 | json.dump(paras[0], f, indent=4, cls=MyJsonEncoder) 317 | f.close() 318 | -------------------------------------------------------------------------------- /models/utils/git_tools.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | def tag_and_log_git_status(log_to: str, version: str, exp_name: str, model_name: str) -> None: 4 | # add git tags for better change tracking 5 | import subprocess 6 | gitout = open(log_to, 'a', encoding='utf-8') 7 | del_tag = f'git tag -d {model_name}_v{version}' 8 | add_tag = f'git tag -a {model_name}_v{version} -m "{exp_name}"' 9 | print_branch = "git branch -vv" 10 | print_status = 'git status' 11 | print_status2 = f'pip freeze > {str(Path(log_to).expanduser().parent)}/requirements_pip.txt' 12 | print_status3 = f'conda list -e > {str(Path(log_to).expanduser().parent)}/requirements_conda.txt' 13 | cmds = [del_tag, add_tag, print_branch, print_status, print_status2, print_status3] 14 | for cmd in cmds: 15 | p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE, encoding="utf-8", universal_newlines=True, shell=True) 16 | o, err = p.communicate() 17 | gitout.write(f'========={cmd}=========\n{o}\n\n\n') 18 | gitout.close() 19 | -------------------------------------------------------------------------------- /models/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple, Union 2 | import warnings 3 | from torchmetrics import Metric 4 | from torchmetrics.collections import MetricCollection 5 | from torchmetrics.audio import * 6 | from torchmetrics.functional.audio import * 7 | from torch import Tensor 8 | import torch 9 | import pesq as pesq_backend 10 | import numpy as np 11 | from typing import * 12 | from models.utils.dnsmos import deep_noise_suppression_mean_opinion_score 13 | 14 | ALL_AUDIO_METRICS = ['SDR', 'SI_SDR', 'SI_SNR', 'SNR', 'NB_PESQ', 'WB_PESQ', 'STOI', 'DNSMOS', 'pDNSMOS'] 15 | 16 | 17 | def get_metric_list_on_device(device: Optional[str]): 18 | metric_device = { 19 | None: ['SDR', 'SI_SDR', 'SNR', 'SI_SNR', 'NB_PESQ', 'WB_PESQ', 'STOI', 'ESTOI', 'DNSMOS', 'pDNSMOS'], 20 | "cpu": ['NB_PESQ', 'WB_PESQ', 'STOI', 'ESTOI'], 21 | "gpu": ['SDR', 'SI_SDR', 'SNR', 'SI_SNR', 'DNSMOS', 'pDNSMOS'], 22 | } 23 | return metric_device[device] 24 | 25 | 26 | def cal_metrics_functional( 27 | metric_list: List[str], 28 | preds: Tensor, 29 | target: Tensor, 30 | original: Optional[Tensor], 31 | fs: int, 32 | device_only: Literal['cpu', 'gpu', None] = None, # cpu-only: pesq, stoi; 33 | chunk: Tuple[float, float] = None, # (chunk length, hop length) in seconds for chunk-wise metric evaluation 34 | suffix: str = "", 35 | ) -> Tuple[Dict[str, 
Tensor], Dict[str, Tensor], Dict[str, Tensor]]: 36 | metrics, input_metrics, imp_metrics = {}, {}, {} 37 | if chunk is not None: 38 | clen, chop = int(fs * chunk[0]), int(fs * chunk[1]) 39 | for i in range(int((preds.shape[-1] / fs - chunk[0]) / chunk[1]) + 1): 40 | metrics_chunk, input_metrics_chunk, imp_metrics_chunk = cal_metrics_functional( 41 | metric_list, 42 | preds[..., i * chop:i * chop + clen], 43 | target[..., i * chop:i * chop + clen], 44 | original[..., i * chop:i * chop + clen] if original is not None else None, 45 | fs, 46 | device_only, 47 | chunk=None, 48 | suffix=f"_{i*chunk[1]+1}s-{i*chunk[1]+chunk[0]}s", 49 | ) 50 | metrics.update(metrics_chunk), input_metrics.update(input_metrics_chunk), imp_metrics.update(imp_metrics_chunk) 51 | 52 | if device_only is None or device_only == 'cpu': 53 | preds_cpu = preds.detach().cpu() 54 | target_cpu = target.detach().cpu() 55 | original_cpu = original.detach().cpu() if original is not None else None 56 | else: 57 | preds_cpu = None 58 | target_cpu = None 59 | original_cpu = None 60 | 61 | for m in metric_list: 62 | mname = m.lower() 63 | if m.upper() not in get_metric_list_on_device(device=device_only): 64 | continue 65 | 66 | if m.upper() == 'SDR': 67 | ## not use signal_distortion_ratio for it gives NaN sometimes 68 | metric_func = lambda: signal_distortion_ratio(preds, target).detach().cpu() 69 | input_metric_func = lambda: signal_distortion_ratio(original, target).detach().cpu() 70 | # assert preds.dim() == 2 and target.dim() == 2 and original.dim() == 2, '(spk, time)!' 71 | # metric_func = lambda: torch.tensor(bss_eval_sources(target_cpu.numpy(), preds_cpu.numpy(), False)[0]).mean().detach().cpu() 72 | # input_metric_func = lambda: torch.tensor(bss_eval_sources(target_cpu.numpy(), original_cpu.numpy(), False)[0]).mean().detach().cpu() 73 | elif m.upper() == 'SI_SDR': 74 | metric_func = lambda: scale_invariant_signal_distortion_ratio(preds, target).detach().cpu() 75 | input_metric_func = lambda: scale_invariant_signal_distortion_ratio(original, target).detach().cpu() 76 | elif m.upper() == 'SI_SNR': 77 | metric_func = lambda: scale_invariant_signal_noise_ratio(preds, target).detach().cpu() 78 | input_metric_func = lambda: scale_invariant_signal_noise_ratio(original, target).detach().cpu() 79 | elif m.upper() == 'SNR': 80 | metric_func = lambda: signal_noise_ratio(preds, target).detach().cpu() 81 | input_metric_func = lambda: signal_noise_ratio(original, target).detach().cpu() 82 | elif m.upper() == 'NB_PESQ': 83 | metric_func = lambda: perceptual_evaluation_speech_quality(preds_cpu, target_cpu, fs, 'nb', n_processes=0) 84 | input_metric_func = lambda: perceptual_evaluation_speech_quality(original_cpu, target_cpu, fs, 'nb', n_processes=0) 85 | elif m.upper() == 'WB_PESQ': 86 | metric_func = lambda: perceptual_evaluation_speech_quality(preds_cpu, target_cpu, fs, 'wb', n_processes=0) 87 | input_metric_func = lambda: perceptual_evaluation_speech_quality(original_cpu, target_cpu, fs, 'wb', n_processes=0) 88 | elif m.upper() == 'STOI': 89 | metric_func = lambda: short_time_objective_intelligibility(preds_cpu, target_cpu, fs) 90 | input_metric_func = lambda: short_time_objective_intelligibility(original_cpu, target_cpu, fs) 91 | elif m.upper() == 'ESTOI': 92 | metric_func = lambda: short_time_objective_intelligibility(preds_cpu, target_cpu, fs, extended=True) 93 | input_metric_func = lambda: short_time_objective_intelligibility(original_cpu, target_cpu, fs, extended=True) 94 | elif m.upper() == 'DNSMOS': 95 | metric_func = lambda: 
deep_noise_suppression_mean_opinion_score(preds, fs, False) 96 | input_metric_func = lambda: deep_noise_suppression_mean_opinion_score(original, fs, False) 97 | elif m.upper() == 'PDNSMOS': # personalized DNSMOS 98 | metric_func = lambda: deep_noise_suppression_mean_opinion_score(preds, fs, True) 99 | input_metric_func = lambda: deep_noise_suppression_mean_opinion_score(original, fs, True) 100 | else: 101 | raise ValueError('Unknown audio metric ' + m) 102 | 103 | if m.upper() == 'WB_PESQ' and fs == 8000: 104 | # warnings.warn("There is narrow band (nb) mode only when sampling rate is 8000Hz") 105 | continue # only narrow band (nb) mode is available when the sampling rate is 8000Hz, so skip WB_PESQ 106 | 107 | try: 108 | if mname == 'dnsmos': 109 | # p808_mos, mos_sig, mos_bak, mos_ovr 110 | m_val = metric_func().cpu().numpy() 111 | 112 | for idx, mid in enumerate(['p808', 'sig', 'bak', 'ovr']): 113 | mname_i = mname + '_' + mid + suffix 114 | metrics[mname_i] = np.mean(m_val[..., idx]).item() 115 | metrics[mname_i + '_all'] = m_val[..., idx].tolist() 116 | if original is None: 117 | continue 118 | 119 | if 'input_' + mname_i not in input_metrics.keys(): 120 | im_val = input_metric_func().cpu().numpy() 121 | input_metrics['input_' + mname_i] = np.mean(im_val[..., idx]).item() 122 | input_metrics['input_' + mname_i + '_all'] = im_val[..., idx].tolist() 123 | 124 | imp_metrics[mname_i + '_i'] = metrics[mname_i] - input_metrics['input_' + mname_i] # _i means improvement 125 | imp_metrics[mname_i + '_all' + '_i'] = (m_val[..., idx] - im_val[..., idx]).tolist() 126 | continue 127 | 128 | mname = mname + suffix 129 | m_val = metric_func().cpu().numpy() 130 | metrics[mname] = np.mean(m_val).item() 131 | metrics[mname + '_all'] = m_val.tolist() # _all means not averaged 132 | if original is None: 133 | continue 134 | 135 | if 'input_' + mname not in input_metrics.keys(): 136 | im_val = input_metric_func().cpu().numpy() 137 | input_metrics['input_' + mname] = np.mean(im_val).item() 138 | input_metrics['input_' + mname + '_all'] = im_val.tolist() 139 | 140 | imp_metrics[mname + '_i'] = metrics[mname] - input_metrics['input_' + mname] # _i means improvement 141 | imp_metrics[mname + '_all' + '_i'] = (m_val - im_val).tolist() 142 | except Exception as e: 143 | metrics[mname] = None 144 | metrics[mname + '_all'] = None 145 | if 'input_' + mname not in input_metrics.keys(): 146 | input_metrics['input_' + mname] = None 147 | input_metrics['input_' + mname + '_all'] = None 148 | imp_metrics[mname + '_i'] = None 149 | imp_metrics[mname + '_all' + '_i'] = None 150 | 151 | return metrics, input_metrics, imp_metrics 152 | 153 | def mypesq(preds: np.ndarray, target: np.ndarray, mode: str, fs: int) -> np.ndarray: 154 | # use ndarray instead of Tensor here, because Tensors can trigger multiprocessing errors on Linux 155 | ori_shape = preds.shape 156 | if type(preds) == Tensor: 157 | preds = preds.detach().cpu().numpy() 158 | target = target.detach().cpu().numpy() 159 | else: 160 | assert type(preds) == np.ndarray, type(preds) 161 | assert type(target) == np.ndarray, type(target) 162 | 163 | if preds.ndim == 1: 164 | pesq_val = pesq_backend.pesq(fs=fs, ref=target, deg=preds, mode=mode) 165 | else: 166 | preds = preds.reshape(-1, ori_shape[-1]) 167 | target = target.reshape(-1, ori_shape[-1]) 168 | pesq_val = np.empty(shape=(preds.shape[0])) 169 | for b in range(preds.shape[0]): 170 | pesq_val[b] = pesq_backend.pesq(fs=fs, ref=target[b, :], deg=preds[b, :], mode=mode) 171 | pesq_val = pesq_val.reshape(ori_shape[:-1]) 172 | return pesq_val 173 | 174 | 175 | def cal_pesq(ys: np.ndarray, 
ys_hat: np.ndarray, sample_rate: int) -> Tuple[float, float]: 176 | try: 177 | if sample_rate == 16000: 178 | wb_pesq_val = mypesq(preds=ys_hat, target=ys, fs=sample_rate, mode='wb').mean() 179 | nb_pesq_val = mypesq(preds=ys_hat, target=ys, fs=sample_rate, mode='nb').mean() 180 | return [wb_pesq_val, nb_pesq_val] 181 | elif sample_rate == 8000: 182 | nb_pesq_val = mypesq(preds=ys_hat, target=ys, fs=sample_rate, mode='nb').mean() 183 | return [None, nb_pesq_val] 184 | else: 185 | ... 186 | except Exception as e: 187 | ... 188 | # warnings.warn(str(e)) 189 | return [None, None] 190 | 191 | 192 | def recover_scale(preds: Tensor, mixture: Tensor, scale_src_together: bool, norm_if_exceed_1: bool = True) -> Tensor: 193 | """recover the original scale of the predicted wavs by solving min ||Y^T a - X||_F, because SI-SDR training loses the scale 194 | 195 | Args: 196 | preds: prediction, shape [batch, n_src, time] 197 | mixture: mixture or noisy or reverberant signal, shape [batch, time] 198 | scale_src_together: keep the relative energy level between sources. can be used for scale-invariant SA-SDR 199 | norm_if_exceed_1: normalize the magnitude if it exceeds one 200 | 201 | Returns: 202 | Tensor: the scale-recovered preds 203 | """ 204 | # TODO: add some kind of weighting mechanism to make the predicted scales more precise 205 | # recover the original scale: solve min ||Y^T a - X||_F to obtain the scales of the speakers' predictions, because SI-SDR training loses the scale 206 | if scale_src_together: 207 | a = torch.linalg.lstsq(preds.sum(dim=-2, keepdim=True).transpose(-1, -2), mixture.unsqueeze(-1)).solution 208 | else: 209 | a = torch.linalg.lstsq(preds.transpose(-1, -2), mixture.unsqueeze(-1)).solution 210 | 211 | preds = preds * a 212 | 213 | if norm_if_exceed_1: 214 | # normalize the audios so that the maximum doesn't exceed 1 215 | max_vals = torch.max(torch.abs(preds), dim=-1).values 216 | norm = torch.where(max_vals > 1, max_vals, 1) 217 | preds = preds / norm.unsqueeze(-1) 218 | return preds 219 | 220 | 221 | if __name__ == "__main__": 222 | x, y, m = torch.rand((2, 8000 * 8)), torch.rand((2, 8000 * 8)), torch.rand((2, 8000 * 8)) 223 | m, im, mi = cal_metrics_functional(["si_sdr"], preds=x, target=y, original=m, fs=8000, chunk=[4, 1]) 224 | print(m) 225 | print(im) 226 | print(mi) 227 | -------------------------------------------------------------------------------- /models/utils/my_earlystopping.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning.callbacks import EarlyStopping 3 | 4 | 5 | class MyEarlyStopping(EarlyStopping): 6 | 7 | def __init__( 8 | self, 9 | enable: bool = True, # enable EarlyStopping or not 10 | **kwargs, 11 | ): 12 | super().__init__(**kwargs) 13 | self.enable = enable 14 | 15 | def _should_skip_check(self, trainer: "pl.Trainer") -> bool: 16 | if self.enable == False: 17 | return True 18 | else: 19 | return super()._should_skip_check(trainer) 20 | -------------------------------------------------------------------------------- /models/utils/my_json_encoder.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from torch import Tensor 4 | from pytorch_lightning.utilities.rank_zero import rank_zero_warn 5 | 6 | 7 | class MyJsonEncoder(json.JSONEncoder): 8 | large_array_size: int = 100 9 | ignore_large_array: bool = True 10 | 11 | def default(self, obj): 12 | if isinstance(obj, np.int64) or isinstance(obj, np.float64) or isinstance(obj, 
np.float32): 13 | return obj.item() 14 | elif isinstance(obj, np.ndarray): 15 | if obj.size == 1: 16 | return obj.item() 17 | else: 18 | if obj.size > self.large_array_size: 19 | if self.ignore_large_array: 20 | rank_zero_warn('large array is ignored while saved to json file.') 21 | return None 22 | else: 23 | rank_zero_warn('large array detected. saving it in json is slow. please remove it') 24 | return obj.tolist() 25 | elif isinstance(obj, Tensor): 26 | if obj.numel() == 1: 27 | return obj.item() 28 | else: 29 | if obj.numel() > self.large_array_size: 30 | if self.ignore_large_array: 31 | rank_zero_warn('large array is ignored while saved to json file.') 32 | return None 33 | else: 34 | rank_zero_warn('large array detected. saving it in json is slow. please remove it') 35 | return obj.detach().cpu().numpy().tolist() 36 | return json.JSONEncoder.default(self, obj) 37 | -------------------------------------------------------------------------------- /models/utils/my_logger.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | from pytorch_lightning.loggers import TensorBoardLogger 3 | from pytorch_lightning.utilities import rank_zero_only 4 | 5 | 6 | class MyLogger(TensorBoardLogger): 7 | 8 | @rank_zero_only 9 | def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: 10 | for k, v in metrics.items(): 11 | _my_step = step 12 | if k.startswith('val/'): # use epoch for val metrics 13 | _my_step = int(metrics['epoch']) 14 | super().log_metrics(metrics={k: v}, step=_my_step) 15 | -------------------------------------------------------------------------------- /models/utils/my_progress_bar.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks.progress import TQDMProgressBar 2 | import sys 3 | from torch import Tensor 4 | from pytorch_lightning import Trainer 5 | 6 | 7 | class MyProgressBar(TQDMProgressBar): 8 | """print out the metrics on_validation_epoch_end 9 | """ 10 | 11 | def on_validation_epoch_end(self, trainer: Trainer, pl_module): 12 | super().on_validation_epoch_end(trainer, pl_module) 13 | sys.stdout.flush() 14 | if trainer.is_global_zero: 15 | metrics = trainer.logged_metrics 16 | infos = f"\x1B[1A\x1B[K\nEpoch {trainer.current_epoch} metrics: " 17 | for k, v in metrics.items(): 18 | value = v 19 | if isinstance(v, Tensor): 20 | value = v.item() 21 | if isinstance(value, float): 22 | infos += k + f"={value:.4f} " 23 | else: 24 | infos += k + f"={value} " 25 | if len(metrics) > 0: 26 | sys.stdout.write(f'{infos}\x1B[K\n') 27 | sys.stdout.flush() 28 | -------------------------------------------------------------------------------- /models/utils/my_rich_progress_bar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from pytorch_lightning import Trainer 4 | from pytorch_lightning.callbacks import RichProgressBar 5 | from pytorch_lightning.callbacks.progress.rich_progress import * 6 | from torch import Tensor 7 | 8 | 9 | class MyRichProgressBar(RichProgressBar): 10 | """A progress bar prints metrics at the end of each epoch 11 | """ 12 | 13 | def on_validation_end(self, trainer: Trainer, pl_module): 14 | super().on_validation_end(trainer, pl_module) 15 | sys.stdout.flush() 16 | if trainer.is_global_zero: 17 | metrics = trainer.logged_metrics 18 | infos = f"Epoch {trainer.current_epoch} metrics: " 19 | for k, v in metrics.items(): 20 | if k.startswith('train/'): 
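# skip 'train/...' entries when printing the end-of-epoch metric summary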
21 | continue 22 | value = v 23 | if isinstance(v, Tensor): 24 | value = v.item() 25 | if isinstance(value, float): 26 | if abs(value) < 1: 27 | infos += k + f"={value:.4e} " 28 | else: 29 | infos += k + f"={value:.4f} " 30 | else: 31 | infos += k + f"={value} " 32 | if len(metrics) > 0: 33 | sys.stdout.write(f'{infos}\n') 34 | sys.stdout.flush() 35 | -------------------------------------------------------------------------------- /models/utils/my_save_config_callback.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from jsonargparse import Namespace 3 | from pytorch_lightning import Trainer, LightningModule 4 | from pytorch_lightning.cli import SaveConfigCallback 5 | 6 | 7 | class MySaveConfigCallback(SaveConfigCallback): 8 | ignores: List[str] = ['progress_bar', 'learning_rate_monitor', 'model_summary'] # 'model_checkpoint', 9 | 10 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: 11 | for ignore in MySaveConfigCallback.ignores: 12 | self.del_config(ignore) 13 | super().setup(trainer, pl_module, stage) 14 | 15 | @staticmethod 16 | def add_ignores(ignore: str): 17 | MySaveConfigCallback.ignores.append(ignore) 18 | 19 | def del_config(self, ignore: str): 20 | if '.' not in ignore: 21 | if ignore in self.config: 22 | del self.config[ignore] 23 | else: 24 | config: Namespace = self.config 25 | ignore_namespace = ignore.split('.') 26 | for idx, name in enumerate(ignore_namespace): 27 | if idx != len(ignore_namespace) - 1: 28 | if name in config: 29 | config = config[name] 30 | else: 31 | return 32 | else: 33 | del config[name] 34 | -------------------------------------------------------------------------------- /models/utils/shared_cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Command Line Interface for different models, provides command line controls for training, test, and inference 3 | """ 4 | import os 5 | 6 | os.environ["OMP_NUM_THREADS"] = str(1) # limit the threads to reduce cpu overloads, will speed up when there are lots of CPU cores on the running machine 7 | 8 | import torch 9 | 10 | torch.backends.cuda.matmul.allow_tf32 = True # The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later. 11 | torch.backends.cudnn.allow_tf32 = True # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
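# Note: TF32 trades a small amount of float32 matmul/conv precision for speed on Ampere and newer GPUs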
12 | 13 | import pytorch_lightning as pl 14 | from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, ModelSummary 15 | from pytorch_lightning.cli import LightningArgumentParser, LightningCLI 16 | from jsonargparse import lazy_instance 17 | 18 | from models.utils import MyRichProgressBar as RichProgressBar 19 | # from pytorch_lightning.loggers import TensorBoardLogger 20 | from models.utils.my_logger import MyLogger as TensorBoardLogger 21 | from models.utils.my_save_config_callback import MySaveConfigCallback as SaveConfigCallback 22 | 23 | 24 | class SharedCLI(LightningCLI): 25 | 26 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: 27 | parser.set_defaults({"trainer.strategy": "ddp_find_unused_parameters_false"}) 28 | 29 | # RichProgressBar 30 | parser.add_lightning_class_args(RichProgressBar, nested_key='progress_bar') 31 | if pl.__version__.startswith('1.5.'): 32 | parser.set_defaults({ 33 | "progress_bar.refresh_rate_per_second": 1, 34 | }) 35 | else: 36 | parser.set_defaults({"progress_bar.console_kwargs": { 37 | "force_terminal": True, 38 | "no_color": True, 39 | "width": 200, 40 | }}) 41 | 42 | # LearningRateMonitor 43 | parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor") 44 | learning_rate_monitor_defaults = { 45 | "learning_rate_monitor.logging_interval": "epoch", 46 | } 47 | parser.set_defaults(learning_rate_monitor_defaults) 48 | 49 | # ModelSummary 50 | parser.add_lightning_class_args(ModelSummary, 'model_summary') 51 | model_summary_defaults = { 52 | "model_summary.max_depth": -1, 53 | } 54 | parser.set_defaults(model_summary_defaults) 55 | 56 | def before_fit(self): 57 | resume_from_checkpoint: str = self.config['fit']['ckpt_path'] 58 | if resume_from_checkpoint is not None and resume_from_checkpoint.endswith('last.ckpt'): 59 | # log in same dir 60 | # resume_from_checkpoint example: /mnt/home/quancs/projects/NBSS_pmt/logs/NBSS_ifp/version_29/checkpoints/last.ckpt 61 | resume_from_checkpoint = os.path.normpath(resume_from_checkpoint) 62 | splits = resume_from_checkpoint.split(os.path.sep) 63 | version = int(splits[-3].replace('version_', '')) 64 | save_dir = os.path.sep.join(splits[:-3]) 65 | self.trainer.logger = TensorBoardLogger(save_dir=save_dir, name="", version=version, default_hp_metric=False) 66 | else: 67 | model_name = type(self.model).__name__ 68 | self.trainer.logger = TensorBoardLogger('logs/', name=model_name, default_hp_metric=False) 69 | 70 | def before_test(self): 71 | if self.config['test']['ckpt_path'] != None: 72 | ckpt_path = self.config['test']['ckpt_path'] 73 | else: 74 | raise Exception('You should give --ckpt_path if you want to test') 75 | epoch = os.path.basename(ckpt_path).split('_')[0] 76 | write_dir = os.path.dirname(os.path.dirname(ckpt_path)) 77 | 78 | test_set = 'test' 79 | if 'test_set' in self.config['test']['data']: 80 | test_set = self.config['test']['data']["test_set"] 81 | elif 'init_args' in self.config['test']['data'] and 'test_set' in self.config['test']['data']['init_args']: 82 | test_set = self.config['test']['data']['init_args']["test_set"] 83 | exp_save_path = os.path.normpath(write_dir + '/' + epoch + '_' + test_set + '_set') 84 | 85 | self.trainer.logger = TensorBoardLogger(exp_save_path, name='', default_hp_metric=False) 86 | print(self.trainer.log_dir) 87 | 88 | def after_test(self): 89 | if not self.trainer.is_global_zero: 90 | return 91 | import fnmatch 92 | files = fnmatch.filter(os.listdir(self.trainer.log_dir), 
'events.out.tfevents.*') 93 | for f in files: 94 | os.remove(self.trainer.log_dir + '/' + f) 95 | print('tensorboard log file for test is removed: ' + self.trainer.log_dir + '/' + f) 96 | 97 | 98 | if __name__ == '__main__': 99 | cli = SharedCLI( 100 | pl.LightningModule, 101 | pl.LightningDataModule, 102 | save_config_callback=SaveConfigCallback, 103 | save_config_kwargs={'overwrite': True}, 104 | subclass_mode_data=True, 105 | subclass_mode_model=False, 106 | run=True, 107 | ) 108 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonargparse[signatures,urls]>=4.3.1 2 | torchmetrics[audio] 3 | omegaconf 4 | pytorch-lightning>=2.0.0 5 | torch>=1.12.1 6 | rich 7 | mir_eval 8 | soundfile 9 | pesq 10 | tensorboard 11 | pandas 12 | torcheval==0.0.6 13 | mamba-ssm 14 | causal-conv1d>=1.1.0 --------------------------------------------------------------------------------