├── .gitignore
├── LICENSE
├── README.md
├── SharedTrainer.py
├── configs
│   ├── NB-BLSTM.yaml
│   ├── NBC.yaml
│   ├── NBC2.yaml
│   ├── SpatialNet.yaml
│   ├── datasets
│   │   ├── CHiME3_moving_rir_cfg.npz
│   │   ├── chime3_moving.yaml
│   │   ├── chime3_moving_plus_static.yaml
│   │   ├── chime3_static.yaml
│   │   ├── sms_wsj.yaml
│   │   ├── sms_wsj_plus.yaml
│   │   ├── sms_wsj_plus_diffuse.npz
│   │   ├── sms_wsj_rir_cfg.npz
│   │   ├── spatialized_wsj0_mix.yaml
│   │   └── whamr.yaml
│   ├── onlineSpatialNet.yaml
│   └── wsj0-mix
│       ├── mix_2_spk_cv.txt
│       ├── mix_2_spk_tr.txt
│       ├── mix_2_spk_tt.txt
│       ├── mix_3_spk_cv.txt
│       ├── mix_3_spk_tr.txt
│       ├── mix_3_spk_tt.txt
│       └── speaker_gender.csv
├── data_loaders
│   ├── __init__.py
│   ├── chime3_moving.py
│   ├── libricss.py
│   ├── reverb.py
│   ├── sms_wsj.py
│   ├── sms_wsj_plus.py
│   ├── spatialized_wsj0_mix.py
│   ├── spk4_wsj0_mix_sp.py
│   ├── utils
│   │   ├── array_geometry.py
│   │   ├── collate_func.py
│   │   ├── diffuse_noise.py
│   │   ├── mix.py
│   │   ├── my_distributed_sampler.py
│   │   ├── rand.py
│   │   └── window.py
│   └── whamr.py
├── generate_rirs.py
├── images
│   ├── model_size_and_flops.png
│   └── results.png
├── models
│   ├── __init__.py
│   ├── arch
│   │   ├── NBC.py
│   │   ├── NBC2.py
│   │   ├── NBSS.py
│   │   ├── OnlineSpatialNet.py
│   │   ├── SpatialNet.py
│   │   ├── base
│   │   │   ├── linear_group.py
│   │   │   ├── non_linear.py
│   │   │   ├── norm.py
│   │   │   └── retention.py
│   │   └── blstm2_fc1.py
│   ├── io
│   │   ├── __init__.py
│   │   ├── cirm.py
│   │   ├── loss.py
│   │   ├── norm.py
│   │   └── stft.py
│   ├── oracle_beamformer.py
│   └── utils
│       ├── __init__.py
│       ├── base_cli.py
│       ├── dnsmos.py
│       ├── ensemble.py
│       ├── flops.py
│       ├── general_steps.py
│       ├── git_tools.py
│       ├── metrics.py
│       ├── my_earlystopping.py
│       ├── my_json_encoder.py
│       ├── my_logger.py
│       ├── my_progress_bar.py
│       ├── my_rich_progress_bar.py
│       ├── my_save_config_callback.py
│       └── shared_cli.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .mypy_cache
3 | .vscode
4 | lightning_logs
5 | logs
6 | .dataset
7 | *.log
8 | *.out
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Changsheng Quan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multichannel Speech Separation, Denoising and Dereverberation
2 |
3 | The official repo of:
4 | [1] Changsheng Quan, Xiaofei Li. [Multi-channel Narrow-band Deep Speech Separation with Full-band Permutation Invariant Training](https://arxiv.org/abs/2110.05966). In ICASSP 2022.
5 | [2] Changsheng Quan, Xiaofei Li. [Multichannel Speech Separation with Narrow-band Conformer](https://arxiv.org/abs/2204.04464). In Interspeech 2022.
6 | [3] Changsheng Quan, Xiaofei Li. [NBC2: Multichannel Speech Separation with Revised Narrow-band Conformer](https://arxiv.org/abs/2212.02076). arXiv:2212.02076.
7 | [4] Changsheng Quan, Xiaofei Li. [SpatialNet: Extensively Learning Spatial Information for Multichannel Joint Speech Separation, Denoising and Dereverberation](https://arxiv.org/abs/2307.16516). TASLP, 2024.
8 | [5] Changsheng Quan, Xiaofei Li. [Multichannel Long-Term Streaming Neural Speech Enhancement for Static and Moving Speakers](https://arxiv.org/abs/2403.07675). IEEE Signal Processing Letters, 2024.
9 |
10 | Audio examples can be found at [https://audio.westlake.edu.cn/Research/nbss.htm](https://audio.westlake.edu.cn/Research/nbss.htm) and [https://audio.westlake.edu.cn/Research/SpatialNet.htm](https://audio.westlake.edu.cn/Research/SpatialNet.htm).
11 | More information about our group can be found at [https://audio.westlake.edu.cn](https://audio.westlake.edu.cn/Publications.htm).
12 |
13 | ## Performance
14 | SpatialNet:
15 | - Performance: see `images/results.png`
16 |
17 | - Computational cost: see `images/model_size_and_flops.png`
18 |
19 |
20 | ## Requirements
21 |
22 | ```bash
23 | pip install -r requirements.txt
24 |
25 | # gpuRIR: check https://github.com/DavidDiazGuerra/gpuRIR
26 | ```
27 |
28 | ## Generate Dataset SMS-WSJ-Plus
29 |
30 | Generate RIRs for the `SMS-WSJ-Plus` dataset used in the `SpatialNet` ablation experiments.
31 |
32 | ```bash
33 | CUDA_VISIBLE_DEVICES=0 python generate_rirs.py --rir_dir ~/datasets/SMS_WSJ_Plus_rirs --save_to configs/datasets/sms_wsj_rir_cfg.npz
34 | cp configs/datasets/sms_wsj_plus_diffuse.npz ~/datasets/SMS_WSJ_Plus_rirs/diffuse.npz # copy diffuse parameters
35 | ```
36 |
37 | For SMS-WSJ, please see https://github.com/fgnt/sms_wsj
38 |
39 | ## Train & Test
40 |
41 | This project is built on the `pytorch-lightning` package, in particular its [command line interface (CLI)](https://pytorch-lightning.readthedocs.io/en/latest/cli/lightning_cli_intermediate.html). Thus we recommend getting familiar with the lightning CLI first. For Chinese users, you can learn the CLI & lightning with this beginner project [pytorch_lightning_template_for_beginners](https://github.com/Audio-WestlakeU/pytorch_lightning_template_for_beginners).
42 |
43 | **Train** SpatialNet on the 0-th GPU with the network config file `configs/SpatialNet.yaml` and the dataset config file `configs/datasets/sms_wsj_plus.yaml` (replace the RIR & clean speech dirs before training).
44 |
45 | ```bash
46 | python SharedTrainer.py fit \
47 | --config=configs/SpatialNet.yaml \ # network config
48 | --config=configs/datasets/sms_wsj_plus.yaml \ # dataset config
49 | --model.channels=[0,1,2,3,4,5] \ # the channels used
50 | --model.arch.dim_input=12 \ # input dim per T-F point, i.e. 2 * the number of channels
51 | --model.arch.dim_output=4 \ # output dim per T-F point, i.e. 2 * the number of sources
52 | --model.arch.num_freqs=129 \ # the number of frequencies, related to model.stft.n_fft
53 | --trainer.precision=bf16-mixed \ # mixed precision training, can also be 16-mixed or 32, where 32 can produce the best performance
54 | --model.compile=true \ # compile the network, requires torch>=2.0. the compiled model is trained much faster
55 | --data.batch_size=[2,4] \ # batch size for train and val
56 | --trainer.devices=0, \
57 | --trainer.max_epochs=100 # better performance may be obtained if more epochs are given
58 | ```
59 |
60 | More GPUs can be used by appending their indexes to `trainer.devices`, e.g. `--trainer.devices=0,1,2,3,`.
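The dimension-related options in the command above follow directly from the chosen channels, number of sources, and STFT size. A minimal sketch of the arithmetic (plain Python; the variable names and the real+imaginary interpretation are illustrative assumptions based on the comments above):

```python
channels = [0, 1, 2, 3, 4, 5]  # --model.channels
num_sources = 2                # number of speakers in SMS-WSJ-Plus
n_fft = 256                    # models.io.stft.STFT in configs/SpatialNet.yaml

dim_input = 2 * len(channels)  # 2 * the number of channels -> 12
dim_output = 2 * num_sources   # 2 * the number of sources  -> 4
num_freqs = n_fft // 2 + 1     # one-sided STFT bins         -> 129
```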
61 |
62 | **Resume** training from a checkpoint:
63 |
64 | ```bash
65 | python SharedTrainer.py fit --config=logs/SpatialNet/version_x/config.yaml \
66 | --data.batch_size=[2,2] \
67 | --trainer.devices=0, \
68 | --ckpt_path=logs/SpatialNet/version_x/checkpoints/last.ckpt
69 | ```
70 |
71 | where `version_x` should be replaced with the version you want to resume.
72 |
73 | **Test** the model trained:
74 |
75 | ```bash
76 | python SharedTrainer.py test --config=logs/SpatialNet/version_x/config.yaml \
77 | --ckpt_path=logs/SpatialNet/version_x/checkpoints/epochY_neg_si_sdrZ.ckpt \
78 | --trainer.devices=0,
79 | ```
80 |
81 | ## Module Version
82 |
83 | | network | file |
84 | |:---|:---|
85 | | NB-BLSTM [1] / NBC [2] / NBC2 [3] | models/arch/NBSS.py |
86 | | SpatialNet [4] | models/arch/SpatialNet.py |
87 | | online SpatialNet [5] | models/arch/OnlineSpatialNet.py |
88 |
89 | ## Note
90 | The dataset generation & training commands for the `NB-BLSTM`/`NBC`/`NBC2` are available in the `NBSS` branch.
91 |
--------------------------------------------------------------------------------
/configs/NB-BLSTM.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 2
2 | trainer:
3 | gradient_clip_val: 5
4 | gradient_clip_algorithm: norm
5 | devices: null
6 | accelerator: gpu
7 | strategy: auto
8 | sync_batchnorm: false
9 | precision: 32
10 | model:
11 | arch:
12 | class_path: models.arch.blstm2_fc1.BLSTM2_FC1
13 | init_args:
14 | activation: ""
15 | hidden_size:
16 | - 256
17 | - 128
18 | n_repeat_last_lstm: 1
19 | dropout: null
20 | channels: [0, 1, 2, 3, 4, 5]
21 | ref_channel: 0
22 | stft:
23 | class_path: models.io.stft.STFT
24 | init_args:
25 | n_fft: 256
26 | n_hop: 128
27 | loss:
28 | class_path: models.io.loss.Loss
29 | init_args:
30 | loss_func: models.io.loss.neg_si_sdr
31 | pit: True
32 | norm:
33 | class_path: models.io.norm.Norm
34 | init_args:
35 | mode: frequency
36 | optimizer: [Adam, { lr: 0.001 }]
37 | lr_scheduler: [ReduceLROnPlateau, { mode: min, factor: 0.5, patience: 10, min_lr: 0.0001 }]
38 | exp_name: exp
39 | metrics: [SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI]
40 | val_metric: loss
41 |
--------------------------------------------------------------------------------
/configs/NBC.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 2
2 | trainer:
3 | gradient_clip_val: 5
4 | gradient_clip_algorithm: norm
5 | devices: null
6 | accelerator: gpu
7 | strategy: auto
8 | sync_batchnorm: false
9 | precision: 32
10 | model:
11 | arch:
12 | class_path: models.arch.NBC.NBC
13 | init_args:
14 | # dim_input: 12
15 | # dim_output: 4
16 | encoder_kernel_size: 4
17 | n_heads: 8
18 | n_layers: 4
19 | activation: ""
20 | hidden_size: 192
21 | norm_first: true
22 | ffn_size: 384
23 | inner_conv_kernel_size: 3
24 | inner_conv_groups: 8
25 | inner_conv_bias: true
26 | inner_conv_layers: 3
27 | inner_conv_mid_norm: GN
28 | channels: [0, 1, 2, 3, 4, 5]
29 | ref_channel: 0
30 | stft:
31 | class_path: models.io.stft.STFT
32 | init_args:
33 | n_fft: 256
34 | n_hop: 128
35 | loss:
36 | class_path: models.io.loss.Loss
37 | init_args:
38 | loss_func: models.io.loss.neg_si_sdr
39 | pit: True
40 | norm:
41 | class_path: models.io.norm.Norm
42 | init_args:
43 | mode: frequency
44 | optimizer: [Adam, { lr: 0.001 }]
45 | lr_scheduler: [ReduceLROnPlateau, { mode: min, factor: 0.5, patience: 5, min_lr: 0.0001 }]
46 | exp_name: exp
47 | metrics: [SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI]
48 | val_metric: loss
49 |
--------------------------------------------------------------------------------
/configs/NBC2.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 2
2 | trainer:
3 | gradient_clip_val: 5
4 | gradient_clip_algorithm: norm
5 | devices: null
6 | accelerator: gpu
7 | strategy: auto
8 | sync_batchnorm: false
9 | precision: 32
10 | model:
11 | arch:
12 | class_path: models.arch.NBC2.NBC2
13 | init_args:
14 | # dim_input: 12
15 | # dim_output: 4
16 | n_layers: 8 # 12 for large
17 | encoder_kernel_size: 5
18 | dim_hidden: 96 # 192 for large
19 | dim_ffn: 192 # 384 for large
20 | num_freqs: 129
21 | block_kwargs:
22 | n_heads: 2
23 | dropout: 0
24 | conv_kernel_size: 3
25 | n_conv_groups: 8
26 | norms: [LN, GBN, GBN]
27 | group_batch_norm_kwargs:
28 | # group_size: 129
29 | share_along_sequence_dim: false
30 | channels: [0, 1, 2, 3, 4, 5]
31 | ref_channel: 0
32 | stft:
33 | class_path: models.io.stft.STFT
34 | init_args:
35 | n_fft: 256
36 | n_hop: 128
37 | loss:
38 | class_path: models.io.loss.Loss
39 | init_args:
40 | loss_func: models.io.loss.neg_si_sdr
41 | pit: True
42 | norm:
43 | class_path: models.io.norm.Norm
44 | init_args:
45 | mode: frequency
46 | optimizer: [Adam, { lr: 0.001 }]
47 | lr_scheduler: [ExponentialLR, { gamma: 0.99 }]
48 | exp_name: exp
49 | metrics: [SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI]
50 | val_metric: loss
51 |
--------------------------------------------------------------------------------
/configs/SpatialNet.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 2
2 | trainer:
3 | gradient_clip_val: 5
4 | gradient_clip_algorithm: norm
5 | devices: null
6 | accelerator: gpu
7 | strategy: auto
8 | sync_batchnorm: false
9 | precision: 32
10 | model:
11 | arch:
12 | class_path: models.arch.SpatialNet.SpatialNet
13 | init_args:
14 | # dim_input: 12
15 | # dim_output: 4
16 | num_layers: 8 # 12 for large
17 | encoder_kernel_size: 5
18 | dim_hidden: 96 # 192 for large
19 | dim_ffn: 192 # 384 for large
20 | num_heads: 4
21 | dropout: [0, 0, 0]
22 | kernel_size: [5, 3]
23 | conv_groups: [8, 8]
24 | norms: ["LN", "LN", "GN", "LN", "LN", "LN"]
25 | dim_squeeze: 8 # 16 for large
26 | num_freqs: 129
27 | full_share: 0
28 | channels: [0, 1, 2, 3, 4, 5]
29 | ref_channel: 0
30 | stft:
31 | class_path: models.io.stft.STFT
32 | init_args:
33 | n_fft: 256
34 | n_hop: 128
35 | loss:
36 | class_path: models.io.loss.Loss
37 | init_args:
38 | loss_func: models.io.loss.neg_si_sdr
39 | pit: True
40 | norm:
41 | class_path: models.io.norm.Norm
42 | init_args:
43 | mode: frequency
44 | optimizer: [Adam, { lr: 0.001 }]
45 | lr_scheduler: [ExponentialLR, { gamma: 0.99 }]
46 | exp_name: exp
47 | metrics: [SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI]
48 | val_metric: loss
49 |
--------------------------------------------------------------------------------
/configs/datasets/CHiME3_moving_rir_cfg.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/configs/datasets/CHiME3_moving_rir_cfg.npz
--------------------------------------------------------------------------------
/configs/datasets/chime3_moving.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: data_loaders.chime3_moving.CHiME3MovingDataModule
3 | init_args:
4 | wsj0_dir: ~/datasets/wsj0
5 | rir_dir: ~/datasets/CHiME3_moving_rirs
6 | chime3_dir: ~/datasets/CHiME3
7 | target: direct_path
8 | datasets: ["train_moving(0.12,0.4)", "val_moving(0.12,0.4)", "test_moving(0.12,0.4)", "test_moving(0.12,0.4)"]
9 | audio_time_len: [4.0, 32.0, 32.0, 32.0]
10 | snr: [-5, 10]
11 | batch_size: [4, 8]
12 |
--------------------------------------------------------------------------------
/configs/datasets/chime3_moving_plus_static.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: data_loaders.chime3_moving.CHiME3MovingDataModule
3 | init_args:
4 | wsj0_dir: ~/datasets/wsj0
5 | rir_dir: ~/datasets/CHiME3_moving_rirs
6 | chime3_dir: ~/datasets/CHiME3
7 | target: direct_path
8 | datasets: ["train_moving(0.12,0.4,0.5)", "val_moving(0.12,0.4,0.5)", "test_moving(0.12,0.4,0.5)", "test_moving(0.12,0.4,0.5)"]
9 | audio_time_len: [4.0, 32.0, 32.0, 32.0]
10 | snr: [-5, 10]
11 | batch_size: [4, 8]
12 |
--------------------------------------------------------------------------------
/configs/datasets/chime3_static.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: data_loaders.chime3_moving.CHiME3MovingDataModule
3 | init_args:
4 | wsj0_dir: ~/datasets/wsj0
5 | rir_dir: ~/datasets/CHiME3_moving_rirs
6 | chime3_dir: ~/datasets/CHiME3
7 | target: direct_path
8 | datasets: ["train", "val", "test", "test"]
9 | audio_time_len: [4.0, 32.0, 32.0, 32.0]
10 | snr: [-5, 10]
11 | batch_size: [4, 8]
12 |
--------------------------------------------------------------------------------
/configs/datasets/sms_wsj.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: data_loaders.sms_wsj.SmsWsjDataModule
3 | init_args:
4 | sms_wsj_dir: ~/datasets/sms_wsj/
5 | target: direct_path
6 | audio_time_len: [4.0, 4.0, null]
7 | batch_size: [2, 1]
8 | test_set: test
9 |
--------------------------------------------------------------------------------
/configs/datasets/sms_wsj_plus.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: data_loaders.sms_wsj_plus.SmsWsjPlusDataModule
3 | init_args:
4 | sms_wsj_dir: ~/datasets/sms_wsj/
5 | rir_dir: ~/datasets/SMS_WSJ_Plus_rirs/
6 | target: direct_path
7 | datasets: ["train_si284", "cv_dev93", "test_eval92", "test_eval92"]
8 | audio_time_len: [4.0, 4.0, null, null]
9 | ovlp: mid
10 | speech_overlap_ratio: [0.1, 1.0]
11 | sir: [-5, 5]
12 | snr: [0, 20]
13 | num_spk: 2
14 | noise_type: ["babble", "white"]
15 | batch_size: [2, 1]
16 |
--------------------------------------------------------------------------------
/configs/datasets/sms_wsj_plus_diffuse.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/configs/datasets/sms_wsj_plus_diffuse.npz
--------------------------------------------------------------------------------
/configs/datasets/sms_wsj_rir_cfg.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/configs/datasets/sms_wsj_rir_cfg.npz
--------------------------------------------------------------------------------
/configs/datasets/spatialized_wsj0_mix.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: data_loaders.spatialized_wsj0_mix.SpatializedWSJ0MixDataModule
3 | init_args:
4 | sp_wsj0_dir: ~/datasets/spatialized-wsj0-mix
5 | version: min
6 | target: reverb
7 | sample_rate: 8000
8 | num_speakers: 2
9 | audio_time_len: [4.0, 4.0, null]
10 | batch_size: [2, 1]
11 |
--------------------------------------------------------------------------------
/configs/datasets/whamr.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: data_loaders.whamr.WHAMRDataModule
3 | init_args:
4 | whamr_dir: ~/datasets/whamr
5 | version: min
6 | target: anechoic
7 | sample_rate: 8000
8 | audio_time_len: [4.0, 4.0, null]
9 | batch_size: [2, 1]
10 |
--------------------------------------------------------------------------------
/configs/onlineSpatialNet.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 2
2 | trainer:
3 | gradient_clip_val: 1
4 | gradient_clip_algorithm: norm
5 | devices: null
6 | accelerator: gpu
7 | strategy: auto
8 | sync_batchnorm: false
9 | precision: 32
10 | model:
11 | arch:
12 | class_path: models.arch.OnlineSpatialNet.OnlineSpatialNet
13 | init_args:
14 | # dim_input: 16
15 | # dim_output: 4
16 | num_layers: 8
17 | encoder_kernel_size: 5
18 | dim_hidden: 96
19 | dim_ffn: 192
20 | num_heads: 4
21 | dropout: [0, 0, 0]
22 | kernel_size: [5, 3]
23 | conv_groups: [8, 8]
24 | norms: ["LN", "LN", "GN", "LN", "LN", "LN"]
25 | dim_squeeze: 8
26 | # num_freqs: 257
27 | full_share: 0 # set to 9999 to not share the full-band module, which improves performance at the cost of a larger parameter size
28 | attention: mamba(16,4) # mhsa(251)/ret(2)/mamba(16,4)
29 | decay: [4, 5, 9, 10]
30 | rope: false
31 | channels: [0, 1, 2, 3, 4, 5]
32 | ref_channel: 0
33 | stft:
34 | class_path: models.io.stft.STFT
35 | init_args: {} # by default set to {} to avoid using wrong stft config
36 | # n_fft: 256
37 | # n_hop: 128
38 | loss:
39 | class_path: models.io.loss.Loss
40 | init_args:
41 | loss_func: models.io.loss.neg_snr
42 | pit: true
43 | norm:
44 | class_path: models.io.norm.Norm
45 | init_args:
46 | mode: utterance
47 | online: true
48 | optimizer: [AdamW, { lr: 0.001, weight_decay: 0.001}]
49 | lr_scheduler: [ExponentialLR, { gamma: 0.99 }]
50 | # lr_scheduler: [ReduceLROnPlateau, {mode: max, factor: 0.5, patience: 5, min_lr: 0.0001}]
51 | exp_name: exp
52 | metrics: [SNR, SDR, SI_SDR, NB_PESQ, WB_PESQ, eSTOI]
53 | val_metric: loss
54 | early_stopping:
55 | enable: false
56 | monitor: val/metric
57 | patience: 10
58 | mode: max
59 | min_delta: 0.1
60 |
--------------------------------------------------------------------------------
/configs/wsj0-mix/speaker_gender.csv:
--------------------------------------------------------------------------------
1 | ID Gender
2 | 001 M
3 | 002 F
4 | 00a F
5 | 00b M
6 | 00c M
7 | 00d M
8 | 00f F
9 | 010 M
10 | 011 F
11 | 012 M
12 | 013 M
13 | 014 F
14 | 015 M
15 | 016 F
16 | 017 F
17 | 018 F
18 | 019 F
19 | 01l M
20 | 01a F
21 | 01b F
22 | 01c F
23 | 01d F
24 | 01e M
25 | 01f F
26 | 01g M
27 | 01h F
28 | 01i M
29 | 01j F
30 | 01k F
31 | 01m F
32 | 01n F
33 | 01o F
34 | 01p F
35 | 01q F
36 | 01r M
37 | 01s M
38 | 01t M
39 | 01u F
40 | 01v F
41 | 01w M
42 | 01x F
43 | 01y M
44 | 01z M
45 | 020 M
46 | 021 M
47 | 022 F
48 | 023 F
49 | 024 M
50 | 025 M
51 | 026 M
52 | 027 F
53 | 028 F
54 | 029 M
55 | 02a F
56 | 02b M
57 | 02c F
58 | 02d F
59 | 02e F
60 | 02f F
61 | 050 F
62 | 051 M
63 | 052 M
64 | 053 F
65 | 200 M
66 | 201 M
67 | 202 F
68 | 203 F
69 | 204 F
70 | 205 F
71 | 206 F
72 | 207 M
73 | 208 M
74 | 209 F
75 | 20a F
76 | 20b F
77 | 20c M
78 | 20d F
79 | 20e F
80 | 20f M
81 | 20g M
82 | 20h F
83 | 20i M
84 | 20j M
85 | 20k M
86 | 20l M
87 | 20m M
88 | 20n M
89 | 20o M
90 | 20p F
91 | 20q M
92 | 20r M
93 | 20s M
94 | 20t F
95 | 20u M
96 | 20v M
97 | 22g M
98 | 22h M
99 | 400 M
100 | 401 F
101 | 403 M
102 | 404 F
103 | 405 M
104 | 406 M
105 | 407 F
106 | 408 M
107 | 409 F
108 | 40a M
109 | 40b M
110 | 40c M
111 | 40d F
112 | 40e F
113 | 40f M
114 | 40g F
115 | 40h F
116 | 40i M
117 | 40j M
118 | 40k M
119 | 40l F
120 | 40m F
121 | 40n M
122 | 40o F
123 | 40p F
124 | 420 F
125 | 421 F
126 | 422 M
127 | 423 M
128 | 430 F
129 | 431 M
130 | 432 F
131 | 440 M
132 | 441 F
133 | 442 M
134 | 443 M
135 | 444 F
136 | 445 F
137 | 446 M
138 | 447 M
--------------------------------------------------------------------------------
/data_loaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/data_loaders/__init__.py
--------------------------------------------------------------------------------
/data_loaders/spatialized_wsj0_mix.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from os.path import *
4 | from pathlib import Path
5 | from typing import *
6 | from typing import Callable, List, Optional, Tuple
7 |
8 | import numpy as np
9 | import soundfile as sf
10 | import torch
11 | from pytorch_lightning import LightningDataModule
12 | from pytorch_lightning.utilities.rank_zero import rank_zero_info
13 | from torch.utils.data import DataLoader, Dataset
14 |
15 | from data_loaders.utils.collate_func import default_collate_func
16 | from data_loaders.utils.my_distributed_sampler import MyDistributedSampler
17 | from data_loaders.utils.rand import randint
18 |
19 |
20 | class SpatializedWSJMixDataset(Dataset):
21 | """The Spatialized WSJ0-2/3Mix dataset"""
22 |
23 | def __init__(
24 | self,
25 | sp_wsj0_dir: str,
26 | dataset: str,
27 | version: str = 'min',
28 | target: str = 'reverb',
29 | audio_time_len: Optional[float] = None,
30 | sample_rate: int = 8000,
31 | num_speakers: int = 2,
32 | ) -> None:
33 | """The Spatialized WSJ-2/3Mix dataset
34 |
35 | Args:
36 | sp_wsj0_dir: a dir containing [2speakers_reverb]
37 | dataset: tr, cv, tt
38 | target: anechoic or reverb
39 | version: min or max
40 | audio_time_len: cut the audio to `audio_time_len` seconds if given
41 | """
42 | super().__init__()
43 | assert target in ['anechoic', 'reverb'], target
44 | assert sample_rate in [8000, 16000], sample_rate
45 | assert dataset in ['tr', 'cv', 'tt'], dataset
46 | assert version in ['min', 'max'], version
47 | assert num_speakers in [2, 3], num_speakers
48 |
49 | self.sp_wsj0_dir = str(Path(sp_wsj0_dir).expanduser())
50 | self.wav_dir = Path(self.sp_wsj0_dir) / f"{num_speakers}speakers_{target}" / {8000: 'wav8k', 16000: 'wav16k'}[sample_rate] / version / dataset
51 | self.files = [basename(str(x)) for x in list((self.wav_dir / 'mix').rglob('*.wav'))]
52 | self.files.sort()
53 | assert len(self.files) > 0, f"dir is empty or not exists: {self.sp_wsj0_dir}"
54 |
55 | self.version = version
56 | self.dataset = dataset
57 | self.target = target
58 | self.audio_time_len = audio_time_len
59 | self.sr = sample_rate
60 |
61 | def __getitem__(self, index_seed: Union[int, Tuple[int, int]]):
62 | if type(index_seed) == int:
63 | index = index_seed
64 | if self.dataset == 'tr':
65 | seed = random.randint(a=0, b=99999999)
66 | else:
67 | seed = index
68 | else:
69 | index, seed = index_seed
70 | g = torch.Generator()
71 | g.manual_seed(seed)
72 |
73 | mix, sr = sf.read(self.wav_dir / 'mix' / self.files[index])
74 | s1, sr = sf.read(self.wav_dir / 's1' / self.files[index])
75 | s2, sr = sf.read(self.wav_dir / 's2' / self.files[index])
76 | assert sr == self.sr, (sr, self.sr)
77 | mix = mix.T
78 | target = np.stack([s1.T, s2.T], axis=0) # [spk, chn, time]
79 |
80 | # pad or cut signals
81 | T = mix.shape[-1]
82 | start = 0
83 | if self.audio_time_len:
84 | frames = int(sr * self.audio_time_len)
85 | if T < frames:
86 | mix = np.pad(mix, pad_width=((0, 0), (0, frames - T)), mode='constant', constant_values=0)
87 | target = np.pad(target, pad_width=((0, 0), (0, 0), (0, frames - T)), mode='constant', constant_values=0)
88 | elif T > frames:
89 | start = randint(g, low=0, high=T - frames)
90 | mix = mix[:, start:start + frames]
91 | target = target[:, :, start:start + frames]
92 |
93 | paras = {
94 | 'index': index,
95 | 'seed': seed,
96 | 'wavname': self.files[index],
97 | 'wavdir': str(self.wav_dir),
98 | 'sample_rate': self.sr,
99 | 'dataset': self.dataset,
100 | 'target': self.target,
101 | 'version': self.version,
102 | 'audio_time_len': self.audio_time_len,
103 | 'start': start,
104 | }
105 |
106 | return torch.as_tensor(mix, dtype=torch.float32), torch.as_tensor(target, dtype=torch.float32), paras
107 |
108 | def __len__(self):
109 | return len(self.files)
110 |
111 |
112 | class SpatializedWSJ0MixDataModule(LightningDataModule):
113 |
114 | def __init__(
115 | self,
116 | sp_wsj0_dir: str,
117 | version: str = 'min',
118 | target: str = 'reverb',
119 | sample_rate: int = 8000,
120 | num_speakers: int = 2,
121 | audio_time_len: Tuple[Optional[float], Optional[float], Optional[float]] = [4.0, 4.0, None], # audio_time_len (seconds) for training, val, test.
122 | batch_size: List[int] = [1, 1],
123 | test_set: str = 'test', # the dataset to test: train, val, test
124 | num_workers: int = 5,
125 | collate_func_train: Callable = default_collate_func,
126 | collate_func_val: Callable = default_collate_func,
127 | collate_func_test: Callable = default_collate_func,
128 | seeds: Tuple[Optional[int], int, int] = [None, 2, 3], # random seeds for train, val and test sets
129 | # pin_memory=True occupies more memory but speeds up data loading
130 | pin_memory: bool = True,
131 | # how many samples to prefetch; increases the memory occupied when pin_memory=True
132 | prefetch_factor: int = 5,
133 | persistent_workers: bool = False,
134 | ):
135 | super().__init__()
136 | self.sp_wsj0_dir = sp_wsj0_dir
137 | self.version = version
138 | self.target = target
139 | self.sample_rate = sample_rate
140 | self.num_speakers = num_speakers
141 | self.audio_time_len = audio_time_len
142 | self.persistent_workers = persistent_workers
143 | self.test_set = test_set
144 |
145 | rank_zero_info("dataset: SpatializedWSJ0Mix")
146 | rank_zero_info(f'train/valid/test set: {version} {target} {sample_rate}, time length={audio_time_len}, {num_speakers}spk')
147 | assert audio_time_len[2] == None, 'test audio time length should be None'
148 |
149 | self.batch_size = batch_size[0]
150 | self.batch_size_val = batch_size[1]
151 | self.batch_size_test = 1
152 | if len(batch_size) > 2:
153 | self.batch_size_test = batch_size[2]
154 | rank_zero_info(f'batch size: train={self.batch_size}; val={self.batch_size_val}; test={self.batch_size_test}')
155 | assert self.batch_size_test == 1, "batch size for test should be 1 as the audios have different length"
156 |
157 | self.num_workers = num_workers
158 |
159 | self.collate_func_train = collate_func_train
160 | self.collate_func_val = collate_func_val
161 | self.collate_func_test = collate_func_test
162 |
163 | self.seeds = []
164 | for seed in seeds:
165 | self.seeds.append(seed if seed is not None else random.randint(0, 1000000))
166 |
167 | self.pin_memory = pin_memory
168 | self.prefetch_factor = prefetch_factor
169 |
170 | def setup(self, stage=None):
171 | if stage is not None and stage == 'test':
172 | audio_time_len = None
173 | else:
174 | audio_time_len = self.audio_time_len[0]
175 |
176 | self.train = SpatializedWSJMixDataset(
177 | sp_wsj0_dir=self.sp_wsj0_dir,
178 | dataset='tr',
179 | version=self.version,
180 | target=self.target,
181 | audio_time_len=audio_time_len,
182 | sample_rate=self.sample_rate,
183 | )
184 | self.val = SpatializedWSJMixDataset(
185 | sp_wsj0_dir=self.sp_wsj0_dir,
186 | dataset='cv',
187 | version=self.version,
188 | target=self.target,
189 | audio_time_len=self.audio_time_len[1],
190 | sample_rate=self.sample_rate,
191 | )
192 | self.test = SpatializedWSJMixDataset(
193 | sp_wsj0_dir=self.sp_wsj0_dir,
194 | dataset='tt',
195 | version=self.version,
196 | target=self.target,
197 | audio_time_len=None,
198 | sample_rate=self.sample_rate,
199 | )
200 |
201 | def train_dataloader(self) -> DataLoader:
202 | return DataLoader(
203 | self.train,
204 | sampler=MyDistributedSampler(self.train, seed=self.seeds[0], shuffle=True),
205 | batch_size=self.batch_size,
206 | collate_fn=self.collate_func_train,
207 | num_workers=self.num_workers,
208 | prefetch_factor=self.prefetch_factor,
209 | pin_memory=self.pin_memory,
210 | persistent_workers=self.persistent_workers,
211 | )
212 |
213 | def val_dataloader(self) -> DataLoader:
214 | return DataLoader(
215 | self.val,
216 | sampler=MyDistributedSampler(self.val, seed=self.seeds[1], shuffle=False),
217 | batch_size=self.batch_size_val,
218 | collate_fn=self.collate_func_val,
219 | num_workers=self.num_workers,
220 | prefetch_factor=self.prefetch_factor,
221 | pin_memory=self.pin_memory,
222 | persistent_workers=self.persistent_workers,
223 | )
224 |
225 | def test_dataloader(self) -> DataLoader:
226 | if self.test_set == 'test':
227 | dataset = self.test
228 | elif self.test_set == 'val':
229 | dataset = self.val
230 | else: # train
231 | dataset = self.train
232 |
233 | return DataLoader(
234 | dataset,
235 | sampler=MyDistributedSampler(dataset, seed=self.seeds[2], shuffle=False),  # sample from the dataset selected by test_set
236 | batch_size=self.batch_size_test,
237 | collate_fn=self.collate_func_test,
238 | num_workers=self.num_workers,
239 | prefetch_factor=self.prefetch_factor,
240 | pin_memory=self.pin_memory,
241 | persistent_workers=self.persistent_workers,
242 | )
243 |
244 |
245 | if __name__ == '__main__':
246 | """python -m data_loaders.spatialized_wsj0_mix"""
247 | from argparse import ArgumentParser
248 | parser = ArgumentParser("")
249 | parser.add_argument('--sp_wsj0_dir', type=str, default='~/datasets/spatialized-wsj0-mix')
250 | parser.add_argument('--version', type=str, default='min')
251 | parser.add_argument('--target', type=str, default='reverb')
252 | parser.add_argument('--sample_rate', type=int, default=8000)
253 | parser.add_argument('--gen_unprocessed', type=bool, default=True)
254 | parser.add_argument('--gen_target', type=bool, default=True)
255 | parser.add_argument('--save_dir', type=str, default='dataset/sp_wsj')
256 | parser.add_argument('--dataset', type=str, default='train', choices=['train', 'val', 'test'])
257 |
258 | args = parser.parse_args()
259 | os.makedirs(args.save_dir, exist_ok=True)
260 |
261 | if not args.gen_unprocessed and not args.gen_target:
262 | exit()
263 |
264 | datamodule = SpatializedWSJ0MixDataModule(args.sp_wsj0_dir, args.version, target=args.target, sample_rate=args.sample_rate, batch_size=[1, 1], num_workers=1)
265 | datamodule.setup()
266 | if args.dataset == 'train':
267 | dataloader = datamodule.train_dataloader()
268 | elif args.dataset == 'val':
269 | dataloader = datamodule.val_dataloader()
270 | else:
271 | assert args.dataset == 'test'
272 | dataloader = datamodule.test_dataloader()
273 |
274 | for idx, (mix, tar, paras) in enumerate(dataloader):
275 | # write target to dir
276 | print(mix.shape, tar.shape, paras)
277 |
278 | if idx > 10:
279 | continue
280 |
281 | if args.gen_target:
282 | tar_path = Path(f"{args.save_dir}/{args.target}/{args.dataset}").expanduser()
283 | tar_path.mkdir(parents=True, exist_ok=True)
284 | for i in range(2):
285 | # assert np.max(np.abs(tar[0, i, 0, :].numpy())) <= 1
286 | sp = tar_path / (paras[0]['wavname'] + f'_spk{i}.wav')
287 | if not sp.exists():
288 | sf.write(sp, tar[0, i, 0, :].numpy(), samplerate=paras[0]['sample_rate'])
289 |
290 | # write unprocessed's 0-th channel
291 | if args.gen_unprocessed:
292 | tar_path = Path(f"{args.save_dir}/unprocessed/{args.dataset}").expanduser()
293 | tar_path.mkdir(parents=True, exist_ok=True)
294 | # assert np.max(np.abs(mix[0, 0, :].numpy())) <= 1
295 | sp = tar_path / (paras[0]['wavname'])
296 | if not sp.exists():
297 | sf.write(sp, mix[0, 0, :].numpy(), samplerate=paras[0]['sample_rate'])
298 |
--------------------------------------------------------------------------------
/data_loaders/utils/array_geometry.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tqdm
3 | from numpy.linalg import norm
4 |
5 |
6 | def normalize(vec: np.ndarray) -> np.ndarray:
7 | # get unit vector
8 | vec = vec / norm(vec)
9 | vec = vec / norm(vec)
10 | assert np.isclose(norm(vec), 1), 'norm of vec is not close to 1'
11 | return vec
12 |
13 |
14 | def circular_array_geometry(radius: float, mic_num: int) -> np.ndarray:
15 | # generate the geometry of a circular array centered at the origin; the array position can later be changed by rotation or by moving the center
16 | pos_rcv = np.empty((mic_num, 3))
17 | v1 = np.array([1, 0, 0]) # position of the first microphone (must be a unit vector)
18 | v1 = normalize(v1) # make it a unit vector
19 | # rotate v1 horizontally around the origin by `angle` to generate the positions of the other mics
20 | angles = np.arange(0, 2 * np.pi, 2 * np.pi / mic_num)
21 | for idx, angle in enumerate(angles):
22 | x = v1[0] * np.cos(angle) - v1[1] * np.sin(angle)
23 | y = v1[0] * np.sin(angle) + v1[1] * np.cos(angle)
24 | pos_rcv[idx, :] = normalize(np.array([x, y, 0]))
25 | # set the radius
26 | pos_rcv *= radius
27 | return pos_rcv
28 |
29 |
30 | def linear_array_geometry(radius: float, mic_num: int) -> np.ndarray:
31 | xs = np.arange(start=0, stop=radius * mic_num, step=radius)
32 | xs -= np.mean(xs) # move the center to the origin
33 | pos_rcv = np.zeros((mic_num, 3))
34 | pos_rcv[:, 0] = xs
35 | return pos_rcv
36 |
37 |
38 | def chime3_array_geometry() -> np.ndarray:
39 | # TODO: add microphone orientation vectors, and omnidirectional/semi-directional patterns
40 | pos_rcv = np.zeros((6, 3))
41 | pos_rcv[0, :] = np.array([-0.1, 0.095, 0])
42 | pos_rcv[1, :] = np.array([0, 0.095, 0])
43 | pos_rcv[2, :] = np.array([0.1, 0.095, 0])
44 | pos_rcv[3, :] = np.array([-0.1, -0.095, 0])
45 | pos_rcv[4, :] = np.array([0, -0.095, 0])
46 | pos_rcv[5, :] = np.array([0.1, -0.095, 0])
47 |
48 | # verify that the side lengths are correct and that the sides are perpendicular to each other
49 | assert np.isclose(np.linalg.norm(pos_rcv[0, :] - pos_rcv[1, :]), 0.1), 'distance between #1 and #2 is wrong'
50 | assert np.isclose(np.linalg.norm(pos_rcv[1, :] - pos_rcv[2, :]), 0.1), 'distance between #2 and #3 is wrong'
51 | assert np.isclose(np.linalg.norm(pos_rcv[0, :] - pos_rcv[3, :]), 0.19), 'distance between #1 and #4 is wrong'
52 | assert np.isclose(np.linalg.norm(pos_rcv[2, :] - pos_rcv[5, :]), 0.19), 'distance between #3 and #6 is wrong'
53 | assert np.isclose(np.linalg.norm(pos_rcv[3, :] - pos_rcv[4, :]), 0.1), 'distance between #4 and #5 is wrong'
54 | assert np.isclose(np.linalg.norm(pos_rcv[4, :] - pos_rcv[5, :]), 0.1), 'distance between #5 and #6 is wrong'
55 | assert np.isclose(np.dot(pos_rcv[0, :] - pos_rcv[1, :], pos_rcv[0, :] - pos_rcv[3, :]), 0), 'not vertical'
56 | assert np.isclose(np.dot(pos_rcv[2, :] - pos_rcv[5, :], pos_rcv[4, :] - pos_rcv[5, :]), 0), 'not vertical'
57 | return pos_rcv
58 |
59 |
60 | def libricss_array_geometry() -> np.ndarray:
61 | pos_rcv = np.zeros((7, 3))
62 | pos_rcv_c = circular_array_geometry(radius=0.0425, mic_num=6)
63 | pos_rcv[1:, :] = pos_rcv_c
64 | return pos_rcv
65 |
--------------------------------------------------------------------------------
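A quick usage sketch for the geometry helpers above (illustrative values; the 0.0425 m radius is the one used in `libricss_array_geometry`):

```python
import numpy as np

from data_loaders.utils.array_geometry import circular_array_geometry, linear_array_geometry

# 6-mic circular array with a 4.25 cm radius: every mic lies on the circle
pos_circ = circular_array_geometry(radius=0.0425, mic_num=6)
assert pos_circ.shape == (6, 3)
assert np.allclose(np.linalg.norm(pos_circ, axis=-1), 0.0425)

# 4-mic linear array with 5 cm spacing, centered at the origin by the helper
pos_lin = linear_array_geometry(radius=0.05, mic_num=4)
print(pos_lin[:, 0])  # x coordinates, symmetric around 0
```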
/data_loaders/utils/collate_func.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | from torch import Tensor
6 |
7 |
8 | def default_collate_func(batches: List[Tuple[Tensor, Tensor, Dict[str, Any]]]) -> List[Any]:
9 | mini_batch = []
10 | for x in zip(*batches):
11 | if isinstance(x[0], np.ndarray):
12 | x = [torch.tensor(x[i]) for i in range(len(x))]
13 | if isinstance(x[0], Tensor):
14 | x = torch.stack(x)
15 | mini_batch.append(x)
16 | return mini_batch
17 |
--------------------------------------------------------------------------------
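A minimal sketch of what `default_collate_func` produces for a batch of `(mix, target, paras)` tuples (the shapes are illustrative):

```python
import torch

from data_loaders.utils.collate_func import default_collate_func

batch = [
    (torch.rand(6, 32000), torch.rand(2, 6, 32000), {'index': 0, 'sample_rate': 8000}),
    (torch.rand(6, 32000), torch.rand(2, 6, 32000), {'index': 1, 'sample_rate': 8000}),
]
mix, target, paras = default_collate_func(batch)
print(mix.shape)     # torch.Size([2, 6, 32000])    tensors are stacked along a new batch dim
print(target.shape)  # torch.Size([2, 2, 6, 32000])
print(len(paras))    # 2: the paras dicts are collected as-is, one per item
```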
/data_loaders/utils/diffuse_noise.py:
--------------------------------------------------------------------------------
1 | #############################################################################################
2 | # Translated from https://github.com/ehabets/ANF-Generator. Implementation of the method in
3 | #
4 | # Generating Nonstationary Multisensor Signals under a Spatial Coherence Constraint.
5 | # Habets, Emanuël A. P. and Cohen, Israel and Gannot, Sharon
6 | #
7 | # Note: though the generated noise is diffuse, it does not simulate the reverberation of rooms
8 | #
9 | # Copyright: Changsheng Quan @ Audio Lab of Westlake University
10 | #############################################################################################
11 |
12 | import math
13 |
14 | import numpy as np
15 | import scipy.special
16 | from scipy.signal import stft, istft
17 |
18 |
19 | def gen_desired_spatial_coherence(pos_mics: np.ndarray, fs: int, noise_field: str = 'spherical', c: float = 343.0, nfft: int = 256) -> np.ndarray:
20 | """generate desired spatial coherence for one array
21 |
22 | Args:
23 | pos_mics: microphone positions, shape (num_mics, 3)
24 | fs: sampling frequency
25 | noise_field: 'spherical' or 'cylindrical'
26 | c: sound velocity
27 | nfft: points of fft
28 |
29 | Raises:
30 | Exception: Unknown noise field if noise_field != 'spherical' and != 'cylindrical'
31 |
32 | Returns:
33 | np.ndarray: desired spatial coherence, shape [num_mics, num_mics, num_freqs]
34 | np.ndarray: desired mixing matrices, shape [num_freqs, num_mics, num_mics]
35 |
36 |
37 | Reference: E. A. P. Habets, “Arbitrary noise field generator.” https://github.com/ehabets/ANF-Generator
38 | """
39 | assert pos_mics.shape[1] == 3, pos_mics.shape
40 | M = pos_mics.shape[0]
41 | num_freqs = nfft // 2 + 1
42 |
43 | # compute the desired spatial coherence matrix
44 | ww = 2 * math.pi * fs * np.array(list(range(num_freqs))) / nfft
45 | dist = np.linalg.norm(pos_mics[:, np.newaxis, :] - pos_mics[np.newaxis, :, :], axis=-1, keepdims=True)
46 | if noise_field == 'spherical':
47 | DSC = np.sinc(ww * dist / (c * math.pi))
48 | elif noise_field == 'cylindrical':
49 | DSC = scipy.special.jv(0, ww * dist / c)  # zeroth-order Bessel function of the first kind
50 | else:
51 | raise Exception('Unknown noise field')
52 |
53 | # compute the mixing matrices of the desired spatial coherence matrix
54 | Cs = np.zeros((num_freqs, M, M), dtype=np.complex128)
55 | for k in range(1, num_freqs):
56 | D, V = np.linalg.eig(DSC[:, :, k])
57 | C = V.T * np.sqrt(D)[:, np.newaxis]
58 | # C = scipy.linalg.cholesky(DSC[:, :, k])
59 | Cs[k, ...] = C
60 |
61 | return DSC, Cs
62 |
63 |
64 | def gen_diffuse_noise(noise: np.ndarray, L: int, Cs: np.ndarray, nfft: int = 256, rng: np.random.Generator = np.random.default_rng()) -> np.ndarray:
65 | """generate diffuse noise with the mixing matrices of the desired spatial coherence
66 |
67 | Args:
68 | noise: at least `num_mic*L` samples long
69 | L: the length in samples
70 | Cs: mixing matrices, shape [num_freqs, num_mics, num_mics]
71 | nfft: the number of fft points
72 | rng: random number generator used for reproducibility
73 |
74 | Returns:
75 | np.ndarray: multi-channel diffuse noise, shape [num_mics, L]
76 | """
77 |
78 | M = Cs.shape[-1]
79 | assert noise.shape[-1] >= M * L, noise.shape
80 |
81 | # Generate M mutually 'independent' input signals
82 | # noise = noise - np.mean(noise)
83 | assert noise.shape[-1] >= M * L, ("The noise signal should be at least `num_mic*L` samples long", noise.shape, M, L)
84 | start = rng.integers(low=0, high=noise.shape[-1] - M * L + 1)
85 | noise = noise[start:start + M * L].reshape(M, L)
86 | noise = noise - np.mean(noise, axis=-1, keepdims=True)
87 | f, t, N = stft(noise, window='hann', nperseg=nfft, noverlap=0.75 * nfft, nfft=nfft) # N: [M,F,T]
88 | # Generate output in the STFT domain for each frequency bin k
89 | X = np.einsum('fmn,mft->nft', np.conj(Cs), N)
90 | # Compute inverse STFT
91 | F, x = istft(X, window='hann', nperseg=nfft, noverlap=0.75 * nfft, nfft=nfft)
92 | x = x[:, :L]
93 | return x # [M, L]
94 |
95 |
96 | if __name__ == '__main__':
97 | import matplotlib.pyplot as plt
98 | import soundfile as sf
99 | from pathlib import Path
100 | # pos_mic = np.random.randn(8, 3)
101 | # pos_mic = np.array([[0, 0, 1.5], [0, 0.2, 1.5]])
102 | nfft = 1024
103 | num_mics = 3
104 | pos_mics = [[0, 0, 1.5]]
105 | for i in range(1, num_mics):
106 | pos_mics.append([0, 0.3 * i, 1.5])
107 | pos_mics = np.array(pos_mics)
108 | DSC, Cs = gen_desired_spatial_coherence(pos_mics=pos_mics, fs=8000, noise_field='spherical', nfft=nfft)
109 |
110 | # wav_files = Path('dataset/datasets/fma_small/000').rglob('*.mp3')
111 | # noise = np.concatenate([sf.read(wav_file,always_2d=True)[0][:,0] for wav_file in wav_files])
112 | noise = np.random.randn(8000 * 22 * 8)
113 | x = gen_diffuse_noise(noise=noise, L=8000 * 20, Cs=Cs, nfft=nfft)  # 20 seconds at 8 kHz
114 |
115 | f, t, X = stft(x, window='hann', nperseg=nfft, noverlap=0.75 * nfft, nfft=nfft) # X: [M,F,T]
116 | cross_psd_0 = np.mean(X[[0], :, :] * np.conj(X[1:, :, :]), axis=-1)
117 | cross_psd_1 = np.mean(np.abs(X[[0], :, :])**2, axis=-1) * np.mean(np.abs(X[1:, :, :])**2, axis=-1)
118 | cross_psd = cross_psd_0 / np.sqrt(cross_psd_1)
119 | sc_generated = np.real(cross_psd)  # the real part is an even function of f; only the real part is plotted because the imaginary part is small (the desired spatial coherence is real-valued)
120 |
121 | ww = 2 * math.pi * 8000 * np.array(list(range(nfft // 2 + 1))) / nfft
122 | if num_mics > 2:
123 | dist = np.linalg.norm(pos_mics[1:] - pos_mics[[0], ...], axis=-1, keepdims=True)
124 | else:
125 | dist = np.linalg.norm(pos_mics[[1]] - pos_mics[[0], ...], axis=-1, keepdims=True)
126 | sc_theory = np.sinc(ww * dist / (343 * math.pi))
127 |
128 | for i in range(len(sc_theory)):
129 | plt.plot(list(range(nfft // 2 + 1)), sc_theory[i])
130 | plt.plot(list(range(nfft // 2 + 1)), sc_generated[i])
131 | plt.title(f"Chn {i+2} vs. Chn {1}")
132 | plt.show()
133 |
--------------------------------------------------------------------------------
/data_loaders/utils/mix.py:
--------------------------------------------------------------------------------
1 | ##############################################################################################################
2 | # Note: there are datasets relying on these functions.
3 | # Changing the implementations might change the data generated by these datasets.
4 | #
5 | # Copyright: Changsheng Quan @ Audio Lab of Westlake University
6 | ##############################################################################################################
7 | """
8 | mid:
9 | ---
10 | -------
11 | headtail:
12 | ------
13 | ----
14 | startend:
15 | --- start/end ---
16 | ------ ------
17 | full:
18 | ---------
19 | ---------
20 | hms: headtail, mid or startend
21 | fhms: full, headtail, mid or startend
22 | """
23 |
24 | from typing import *
25 |
26 | import numpy as np
27 | from numpy import ndarray
28 | from numpy.random import Generator
29 | from scipy.signal import fftconvolve
30 |
31 | OVLPS = ['mid', 'headtail', 'startend', 'full', 'hms', 'fhms']
32 |
33 |
34 | def sample_an_overlap(ovlp_type: str, num_spk: int, rng: Generator) -> str:
35 | assert ovlp_type in OVLPS, ovlp_type
36 | assert num_spk in [1, 2], num_spk
37 |
38 | if num_spk == 1:
39 | ovlp_type = 'full'
40 | elif ovlp_type == 'fhms':
41 | types = ['full', 'headtail', 'mid', 'startend']
42 | which_type = rng.integers(low=0, high=len(types))
43 | ovlp_type = types[which_type]
44 | elif ovlp_type == 'hms':
45 | types = ['headtail', 'mid', 'startend']
46 | which_type = rng.integers(low=0, high=len(types))
47 | ovlp_type = types[which_type]
48 | else:
49 | assert ovlp_type in ['full', 'headtail', 'mid', 'startend'], ovlp_type
50 |
51 | if ovlp_type == 'startend':
52 | types = ['start', 'end']
53 | which_type = rng.integers(low=0, high=len(types))
54 | ovlp_type = types[which_type]
55 | else:
56 | ovlp_type = ovlp_type
57 |
58 | return ovlp_type
59 |
60 |
61 | def sample_ovlp_ratio_and_cal_length(ovlp_type: str, ratio_range: Tuple[float, float], target_len: Optional[int], lens: List[int], rng: Generator) -> Tuple[float, List[int], int]:
62 | """sample one overlap ratio and calculate the needed length for each wav
63 |
64 | Returns:
65 | Tuple[float, List[int]]: ovlp_ratio, needed length
66 | """
67 | for rr in ratio_range:
68 | assert rr >= 0 and rr <= 1, rr
69 | assert ratio_range[0] <= ratio_range[1], ratio_range
70 |
71 | if target_len == None:
72 | mix_frames = max(lens)
73 | if ovlp_type == 'full':
74 | ovlp_ratio = 1
75 | elif ovlp_type == 'headtail':
76 | low = ratio_range[0]
77 | high = np.min(lens) / np.max(lens)
78 | if low > high:
79 | ovlp_ratio = high
80 | else:
81 | ovlp_ratio = rng.uniform(low=low, high=high)
82 | mix_frames = round((np.min(lens) + np.max(lens)) / (1 + ovlp_ratio))
83 | elif ovlp_type == 'mid':
84 | ovlp_ratio = np.min(lens) / np.max(lens)
85 | else:
86 | assert ovlp_type in ['start', 'end'], ovlp_type
87 | ovlp_ratio = np.min(lens) / np.max(lens)
88 | else:
89 | mix_frames = target_len
90 | ovlp_ratio = rng.uniform(low=ratio_range[0], high=ratio_range[1])
91 | if ovlp_type == 'full':
92 | lens = [mix_frames] * len(lens)
93 | ovlp_ratio = 1
94 | elif ovlp_type == 'headtail':
95 | lens = [int(mix_frames * (0.5 + ovlp_ratio / 2))] * len(lens)
96 | else:
97 | assert ovlp_type in ['mid', 'start', 'end'], ovlp_type
98 | max_idx = lens.index(max(lens))
99 | min_idx = lens.index(min(lens))
100 | if max_idx == min_idx:
101 | max_idx = [1, 0][max_idx]
102 | lens[max_idx] = mix_frames
103 | lens[min_idx] = int(mix_frames * ovlp_ratio)
104 | return ovlp_ratio, lens, mix_frames
105 |
106 |
107 | def pad_or_cut(wavs: List[ndarray], lens: List[int], rng: Generator) -> List[ndarray]:
108 | """repeat signals if they are shorter than the length needed, then cut them to needed
109 | """
110 | for i, wav in enumerate(wavs):
111 | # repeat
112 | while len(wav) < lens[i]:
113 | wav = np.concatenate([wav, wav])
114 | # cut to needed length
115 | if len(wav) > lens[i]:
116 | start = rng.integers(low=0, high=len(wav) - lens[i] + 1)
117 | wav = wav[start:start + lens[i]]
118 | wavs[i] = wav
119 | return wavs
120 |
121 |
122 | def convolve(wav: ndarray, rir: ndarray, rir_target: ndarray, ref_channel: Optional[int] = 0, align: bool = True) -> Tuple[ndarray, ndarray]:
123 | assert wav.ndim == 1, wav.shape
124 | assert rir.ndim == 2 and rir_target.ndim == 2, (rir.shape, rir_target.shape)
125 |
126 | rvbt = fftconvolve(wav[np.newaxis, :], rir, mode='full', axes=-1)
127 | target = rvbt if rir is rir_target else fftconvolve(wav[np.newaxis, :], rir_target, mode='full', axes=-1)
128 | if align:
129 | # Note: don't take the dry clean source as target if the ref_channel is not correct
130 | rir = rir[ref_channel, ...]
131 | delay = np.argmax(rir)
132 | rvbt = rvbt[:, delay:delay + wav.shape[-1]]
133 | target = target[:, delay:delay + wav.shape[-1]]
134 | return rvbt, target
135 |
136 | def convolve_v2(wav: ndarray, rir: ndarray, rir_target: ndarray, ref_channel: Optional[int] = 0, align: bool = True) -> Tuple[ndarray, ndarray]:
137 | assert wav.ndim == 1, wav.shape
138 | assert rir.ndim == 2 and rir_target.ndim == 2, (rir.shape, rir_target.shape)
139 |
140 | rvbt = fftconvolve(wav[np.newaxis, :], rir, mode='full', axes=-1)
141 | target = rvbt if rir is rir_target else fftconvolve(wav[np.newaxis, :], rir_target, mode='full', axes=-1)
142 | if align:
143 | # Note: don't take the dry clean source as target if the ref_channel is not correct
144 | rir_align = rir_target[ref_channel, ...] # use rir_target rather than rir
145 | delay = np.argmax(rir_align)
146 | rvbt = rvbt[:, delay:delay + wav.shape[-1]]
147 | target = target[:, delay:delay + wav.shape[-1]]
148 | return rvbt, target
149 |
150 |
151 | def convolve_traj(wav: np.ndarray, traj_rirs: np.ndarray, traj_rirs_tar: np.ndarray, samples_per_rir: Union[np.ndarray, int], ref_channel: Optional[int] = 0, align: bool = True) -> np.ndarray:
152 | """Convolve wav by using a set of trajectory rirs (Note: the generated audio signal using this method has click noise)
153 |
154 | Args:
155 | wav: shape [time]
156 | traj_rirs: shape [num_rirs, num_mics, num_samples]
157 | traj_rirs_tar: shape [num_rirs, num_mics, num_samples]
158 | samples_per_rir: the number of samples in wav for each trajectory rir
159 | ref_channel: reference channel. Defaults to 0.
160 | align: Defaults to True.
161 |
162 | Returns:
163 | np.ndarray: [num_mics, time]
164 | """
165 | assert wav.ndim == 1, "not implemented"
166 | wav_samps = wav.shape[0]
167 | if isinstance(samples_per_rir, np.ndarray):
168 | assert samples_per_rir.ndim == 1
169 | assert samples_per_rir.sum() == wav.shape[-1], "the number of samples specified in samples_per_rir should match that of the wav"
170 | else:
171 | if wav_samps % samples_per_rir == 0:
172 | samples_per_rir = [samples_per_rir] * (wav_samps // samples_per_rir)
173 | else:
174 | samples_per_rir = [samples_per_rir] * (wav_samps // samples_per_rir) + [wav_samps % samples_per_rir]
175 | (num_rirs, num_mics, rir_samps), rir_samps_t = traj_rirs.shape, traj_rirs_tar.shape[-1]
176 | assert num_rirs == len(samples_per_rir), "the number of rirs should match the length of samples_per_rir"
177 |
178 | rvbt = np.zeros((num_mics, rir_samps + wav_samps - 1), dtype=np.float32)
179 | target = np.zeros((num_mics, rir_samps_t + wav_samps - 1), dtype=np.float32)
180 | start_samp = 0
181 | for i, n_samps in enumerate(samples_per_rir):
182 | wav_i = wav[start_samp:start_samp + n_samps]
183 | rir_i = traj_rirs[i]
184 | rir_i_tar = traj_rirs_tar[i]
185 | rvbt[:, start_samp:start_samp + n_samps + rir_samps - 1] += fftconvolve(wav_i[np.newaxis], rir_i, mode='full', axes=-1)
186 | target[:, start_samp:start_samp + n_samps + rir_samps_t - 1] += fftconvolve(wav_i[np.newaxis], rir_i_tar, mode='full', axes=-1)
187 | start_samp += n_samps
188 |
189 | if align:
190 | rir = traj_rirs_tar[0, ref_channel, ...]
191 | delay = np.argmax(rir)
192 | rvbt = rvbt[:, delay:delay + wav.shape[-1]]
193 | target = target[:, delay:delay + wav.shape[-1]]
194 | return rvbt, target
195 |
196 |
197 | def convolve_traj_with_win(wav: np.ndarray, traj_rirs: np.ndarray, samples_per_rir: int, wintype: str = 'trapezium20') -> np.ndarray:
198 | """Convolve wav by using a set of trajectory rirs (Note: the audio generated using this method barely has any click noise)
199 |
200 | Args:
201 | wav: shape [T]
202 | traj_rirs: shape [num_rirs, num_mics, num_samples]
203 | samples_per_rir: the number of samples in wav for each trajectory rir
204 | wintype: hann, tri, or trapezium. by default, trapezium20 is used.
205 |
206 | Returns:
207 | np.ndarray: [num_mics, time]
208 | """
209 | assert wav.ndim == 1, "not implemented"
210 | wav_samps = wav.shape[0]
211 |
212 | hop = samples_per_rir
213 | samples_per_rir = samples_per_rir * 2
214 | num_rirs, num_mics, rir_samps = traj_rirs.shape
215 |
216 | if wintype == 'hann':
217 | win = np.hanning(samples_per_rir)
218 | elif wintype.startswith('trapezium'):  # linear ramps of n points on the left and right sides
219 | n = int(wintype.replace('trapezium', ''))
220 | assert hop - n > 0, hop
221 | tri = np.arange(0, n) / (n - 1)
222 | tri2 = np.arange((n - 1), -1, -1) / (n - 1)
223 | zlen = (hop - n) // 2
224 | onelen = hop - n - zlen
225 | win = np.concatenate([np.zeros(zlen), tri, np.ones(onelen * 2), tri2, np.zeros(zlen)])
226 | else:
227 | assert wintype == 'tri', wintype
228 | win = np.concatenate([np.arange(0, samples_per_rir // 2), np.arange(samples_per_rir // 2 - 1, -1, -1)]) / (samples_per_rir // 2 - 1)
229 |
230 | out = np.zeros((num_mics, rir_samps + wav_samps - 1), dtype=np.float32)
231 | for i, start_samp in enumerate(range(0, wav_samps + hop - 1, hop)):
232 | rir_i = traj_rirs[i]
233 |
234 | if start_samp == 0:
235 | wav_i = wav[start_samp:start_samp + hop] * win[hop:]
236 | out[:, start_samp:start_samp + hop + rir_samps - 1] += fftconvolve(wav_i[np.newaxis], rir_i, axes=-1)
237 | elif wav.shape[-1] >= start_samp + hop:
238 | wav_i = wav[start_samp - hop:start_samp + hop] * win
239 | out[:, start_samp - hop:start_samp + hop + rir_samps - 1] += fftconvolve(wav_i[np.newaxis], rir_i, axes=-1)
240 | else:
241 | wav_i = wav[start_samp - hop:] * win[:wav.shape[-1] - start_samp + hop]
242 | out[:, start_samp - hop:] += fftconvolve(wav_i[np.newaxis], rir_i, axes=-1)
243 |
244 | return out
245 |
246 |
247 | def align(rir: np.ndarray, rvbt: np.ndarray, target: np.ndarray, src: np.ndarray):
248 | assert rir.ndim == 1 and src.ndim == 1, (rir.shape, src.shape)
249 | delay = np.argmax(rir)
250 | rvbt = rvbt[:, delay:delay + src.shape[-1]]
251 | target = target[:, delay:delay + src.shape[-1]]
252 | return rvbt, target
253 |
254 |
255 | def convolve1(wav: ndarray, rir: ndarray, ref_channel: Optional[int] = 0, align: bool = True) -> Union[Tuple[ndarray, ndarray], ndarray]:
256 | assert wav.ndim == 1, wav.shape
257 | while wav.ndim < rir.ndim:
258 | wav = wav[np.newaxis, ...]
259 | rvbt = fftconvolve(wav, rir, mode='full', axes=-1)
260 | if align:
261 | # Note: don't take the dry clean source as target if the ref_channel is not correct
262 | if rir.ndim >= 2:
263 | rir = rir[..., ref_channel, :] # the second last dim is regarded as the channel dim
264 | delay = np.argmax(rir)
265 | rvbt = rvbt[..., delay:delay + wav.shape[-1]]
266 | return rvbt
267 |
268 |
269 | def overlap2(rvbts: List[ndarray], targets: List[ndarray], ovlp_type: str, mix_frames: int, rng: Generator) -> Tuple[ndarray, ndarray]:
270 | assert np.array([rvbt_i.shape == target_i.shape for (rvbt_i, target_i) in zip(rvbts, targets)]).all(), "rvbt and target should have the same shape"
271 | assert len(rvbts) <= 2 and len(targets) <= 2, "this function is used only for two-speaker overlapping"
272 | assert rvbts[0].ndim == 2 and rvbts[0].shape[0] < rvbts[0].shape[1], "rvbt should have a shape of [chn, time]"
273 |
274 | num_spk, chn_num = len(rvbts), rvbts[0].shape[0]
275 |
276 | rvbt = np.zeros((num_spk, chn_num, mix_frames), dtype=np.float32)
277 | target = np.zeros((num_spk, chn_num, mix_frames), dtype=np.float32)
278 |
279 | for i, (rvbt_i, target_i) in enumerate(zip(rvbts, targets)):
280 | # overlap signals
281 | Ti = rvbt_i.shape[-1] # use all revbt signals
282 |
283 | if ovlp_type == 'full':
284 | shift = 0
285 | elif ovlp_type == 'mid':
286 | if Ti == mix_frames:
287 | shift = 0
288 | else:
289 | shift = rng.integers(low=0, high=mix_frames - Ti + 1) # [0, mix_frames - use_len]
290 | elif ovlp_type == 'start' or ovlp_type == 'end':
291 | assert num_spk == 2
292 | if Ti == mix_frames:
293 | shift = 0
294 | else:
295 | shift = {'start': 0, 'end': mix_frames - Ti}[ovlp_type]
296 | else:
297 | assert ovlp_type == 'headtail', ovlp_type
298 | assert num_spk == 2
299 | shift = 0 if i == 0 else (mix_frames - Ti)
300 |
301 | rvbt[i, :, shift:shift + Ti] = rvbt_i[:, :]
302 | target[i, :, shift:shift + Ti] = target_i[:, :]
303 | return rvbt, target
304 |
305 |
306 | def overlap3(rvbts: List[ndarray], targets: List[ndarray], mix_frames: int, rng: Generator, output_stream: int = 2) -> Tuple[ndarray, ndarray]:
307 | assert np.array([rvbt_i.shape == target_i.shape for (rvbt_i, target_i) in zip(rvbts, targets)]).all(), "rvbt and target should have the same shape"
308 | assert len(rvbts) == 3 and len(targets) == 3, "this function is used only for 3-speaker overlapping"
309 | assert output_stream == 2, "2-stream output is supported only"
310 | assert rvbts[0].ndim == 2 and rvbts[0].shape[0] < rvbts[0].shape[1], "rvbt should have a shape of [chn, time]"
311 |
312 | num_spk, chn_num = len(rvbts), rvbts[0].shape[0]
313 |
314 | rvbt = np.zeros((2, chn_num, mix_frames), dtype=np.float32)
315 | target = np.zeros((2, chn_num, mix_frames), dtype=np.float32)
316 |
317 | rvbt[0, :, :] = rvbts[0][:, :]
318 | rvbt[1, :, :rvbts[1].shape[-1]] = rvbts[1][:, :]
319 | rvbt[1, :, -rvbts[2].shape[-1]:] = rvbts[2][:, :]
320 |
321 | target[0, :, :] = targets[0][:, :]
322 | target[1, :, :targets[1].shape[-1]] = targets[1][:, :]
323 | target[1, :, -targets[2].shape[-1]:] = targets[2][:, :]
324 |
325 | return rvbt, target
326 |
327 |
328 | def cal_coeff_for_adjusting_relative_energy(wav1: ndarray, wav2: ndarray, target_dB: float) -> Optional[float]:
329 | r"""calculate the coefficient used for adjusting the relative energy of two signals
330 |
331 | Args:
332 | wav1: the first wav
333 | wav2: the second wav
334 | target_dB: the target relative energy in dB, i.e. after adjusting: 10 * log_10 (average_energy(wav1) / average_energy(wav2 * coeff)) = target_dB
335 |
336 | Returns:
337 | float: coeff
338 | """
339 | # compute averaged energy over time and channel
340 | ae1 = np.sum(wav1**2) / np.prod(wav1.shape)
341 | ae2 = np.sum(wav2**2) / np.prod(wav2.shape)
342 | if ae1 == 0 or ae2 == 0 or not np.isfinite(ae1) or not np.isfinite(ae2):
343 | return None
344 | # compute the coefficients
345 | coeff = np.sqrt(ae1 / ae2 * np.power(10, -target_dB / 10))
346 | return coeff # multiply it with wav2
347 |
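# A minimal self-check sketch (an addition, not from the original repo): after multiplying
# wav2 by the returned coeff, 10*log10(avg_energy(wav1) / avg_energy(wav2 * coeff)) should
# equal target_dB.
if __name__ == '__main__':
    _rng = np.random.default_rng(0)
    wav1 = _rng.standard_normal((2, 8000)).astype(np.float32)
    wav2 = 0.1 * _rng.standard_normal((2, 8000)).astype(np.float32)
    coeff = cal_coeff_for_adjusting_relative_energy(wav1, wav2, target_dB=5.0)
    assert coeff is not None
    print(10 * np.log10(np.mean(wav1**2) / np.mean((wav2 * coeff)**2)))  # expected: ~5.0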
--------------------------------------------------------------------------------
/data_loaders/utils/my_distributed_sampler.py:
--------------------------------------------------------------------------------
1 | ##############################################################################################################
2 | # Code reproducibility is essential for DL. The MyDistributedSampler tries to make datasets reproducible by
3 | # generating a seed for each dataset item at a specific epoch.
4 | #
5 | # Copyright: Changsheng Quan @ Audio Lab of Westlake University 2023
6 | ##############################################################################################################
7 |
8 |
9 |
10 | import math
11 | from typing import Iterator, Optional
12 |
13 | import torch
14 | from pytorch_lightning.utilities.rank_zero import rank_zero_warn
15 | from torch.utils.data import Dataset
16 | from torch.utils.data.distributed import DistributedSampler, T_co
17 |
18 |
19 | class MyDistributedSampler(DistributedSampler[T_co]):
20 | r"""Sampler for single GPU and multi GPU (or Distributed) cases. Change int index to a tuple (index, random seed for this index).
21 | This sampler is used to enhance the reproducibility of datasets by generating random seed for each item at each epoch.
22 | """
23 |
24 | def __init__(
25 | self,
26 | dataset: Dataset,
27 | num_replicas: Optional[int] = None,
28 | rank: Optional[int] = None,
29 | shuffle: bool = True,
30 | seed: int = 0,
31 | drop_last: bool = False,
32 | ) -> None:
33 | try:
34 | super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
35 | except:
36 | # if error raises, it is running on single GPU
37 | # thus, set num_replicas=1, rank=0
38 | super().__init__(dataset, 1, 0, shuffle, seed, drop_last)
39 | self.last_epoch = -1
40 |
41 | def __iter__(self) -> Iterator[T_co]:
42 | if self.shuffle:
43 | # deterministically shuffle based on epoch and seed
44 | g = torch.Generator()
45 | g.manual_seed(self.seed + self.epoch)
46 | if self.last_epoch >= self.epoch:
47 | if self.epoch != 0:
48 |                     rank_zero_warn(f'shuffle is true but the epoch value doesn\'t get updated, thus the order of the training data won\'t change at epoch={self.epoch}')
49 | else:
50 | self.last_epoch = self.epoch
51 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore
52 | else:
53 | g = torch.Generator()
54 | g.manual_seed(self.seed)
55 |
56 | indices = list(range(len(self.dataset))) # type: ignore
57 |
58 | seeds = []
59 | for i in range(len(indices)):
60 | seed = torch.randint(high=9999999999, size=(1,), generator=g)[0].item()
61 | seeds.append(seed)
62 | indices = list(zip(indices, seeds))
63 |
64 | # drop last
65 | if not self.drop_last:
66 | # add extra samples to make it evenly divisible
67 | padding_size = self.total_size - len(indices)
68 | if padding_size <= len(indices):
69 | indices += indices[:padding_size]
70 | else:
71 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
72 | else:
73 | # remove tail of data to make it evenly divisible.
74 | indices = indices[:self.total_size]
75 | assert len(indices) == self.total_size
76 |
77 | # subsample
78 | indices = indices[self.rank:self.total_size:self.num_replicas]
79 | assert len(indices) == self.num_samples
80 |
81 | return iter(indices) # type: ignore
82 |
83 | def __len__(self) -> int:
84 | return self.num_samples
85 |
86 | def set_epoch(self, epoch: int) -> None:
87 | r"""
88 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
89 | use a different random ordering for each epoch. Otherwise, the next iteration of this
90 | sampler will yield the same ordering.
91 |
92 | Args:
93 | epoch (int): Epoch number.
94 | """
95 | self.epoch = epoch
96 |
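# A minimal single-GPU usage sketch (an addition, not from the repo): the sampler yields
# (index, seed) tuples, which a Dataset.__getitem__ can unpack to seed its own per-item
# randomness reproducibly (see e.g. data_loaders/whamr.py for how the real datasets use it).
if __name__ == '__main__':

    class ToyDataset(Dataset):

        def __len__(self) -> int:
            return 4

        def __getitem__(self, index_seed):
            index, seed = index_seed  # provided by MyDistributedSampler
            g = torch.Generator()
            g.manual_seed(seed)
            return index, torch.rand(1, generator=g).item()

    dataset = ToyDataset()
    sampler = MyDistributedSampler(dataset, shuffle=True, seed=0)
    sampler.set_epoch(0)
    for index_seed in sampler:
        print(dataset[index_seed])  # identical output for the same epoch and seed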
--------------------------------------------------------------------------------
/data_loaders/utils/rand.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def randint(g: torch.Generator, low: int, high: int) -> int:
5 | """return a value sampled in [low, high)
6 | """
7 | if low == high:
8 | return low
9 | r = torch.randint(low=low, high=high, size=(1,), generator=g, device='cpu')
10 | return r[0].item() # type:ignore
11 |
12 |
13 | def randfloat(g: torch.Generator, low: float, high: float) -> float:
14 | """return a value sampled in [low, high)
15 | """
16 | if low == high:
17 | return low
18 | r = torch.rand(size=(1,), generator=g, device='cpu')[0].item()
19 | return float(low + r * (high - low))
20 |
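# A minimal usage sketch (an addition, not from the repo): both helpers sample from a
# seeded torch.Generator, so the same seed always reproduces the same values.
if __name__ == '__main__':
    g = torch.Generator()
    g.manual_seed(42)
    print(randint(g, low=0, high=10), randfloat(g, low=0.0, high=1.0))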
--------------------------------------------------------------------------------
/data_loaders/utils/window.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def reverberation_time_shortening_window(rir: np.ndarray, original_T60: float, target_T60: float, sr: int = 8000, time_after_max: float = 0.002, time_before_max: float = None) -> np.ndarray:
5 | """shorten the T60 of a given rir
6 |
7 | Args:
8 | rir: the rir array
9 | original_T60: the T60 of the rir
10 | target_T60: the target T60
11 | sr: sample rate
12 | time_after_max: time in seconds after the maximum value in rir taken as part of the direct path. Defaults to 0.002.
13 | time_before_max: time in seconds before the maximum value in rir taken as part of the direct path. By default, all the values before the maximum are taken as direct path.
14 |
15 | Returns:
16 | np.ndarray: the reverberation time shortening window
17 | """
18 |
19 | if original_T60 <= target_T60:
20 | return np.ones(shape=rir.shape)
21 | shape = rir.shape
22 | rir = rir.reshape(-1, shape[-1])
23 | win = np.empty(shape=rir.shape, dtype=rir.dtype)
24 | q = 3 / (target_T60 * sr) - 3 / (original_T60 * sr)
25 | exps = 10**(-q * np.arange(rir.shape[-1]))
26 | idx_max_array = np.argmax(np.abs(rir), axis=-1)
27 | for i, idx_max in enumerate(idx_max_array):
28 | N1 = idx_max + int(time_after_max * sr)
29 | win[i, :N1] = 1
30 | win[i, N1:] = exps[:rir.shape[-1] - N1]
31 | if time_before_max:
32 | N0 = int(idx_max - time_before_max * sr)
33 | if N0 > 0:
34 | win[i, :N0] = 0
35 | win = win.reshape(shape)
36 | return win
37 |
38 |
39 | def rectangular_window(rir: np.ndarray, sr: int = 8000, time_before_after_max: float = 0.002) -> np.ndarray:
40 | assert rir.ndim == 1, rir.ndim
41 | idx = int(np.argmax(np.abs(rir)))
42 | win = np.zeros(shape=rir.shape)
43 | N = int(sr * time_before_after_max)
44 | win[max(0, idx - N):idx + N + 1] = 1
45 | return win
46 |
47 |
48 | if __name__ == '__main__':
49 | rir = np.random.rand(3, 2, 10000)
50 | rir[..., 1000] = 2
51 | win = reverberation_time_shortening_window(rir, original_T60=0.8, target_T60=0.1, sr=8000)
52 | print(win.shape)
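    # A hedged illustration (an addition, not from the repo): the window is typically
    # multiplied with the RIR, so that convolving a dry source with (rir * win) yields a
    # signal whose reverberation time is roughly target_T60.
    shortened_rir = rir * win
    print(shortened_rir.shape)  # (3, 2, 10000), same shape as rir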
53 |
--------------------------------------------------------------------------------
/data_loaders/whamr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from os.path import *
4 | from pathlib import Path
5 | from typing import *
6 |
7 | import numpy as np
8 | import soundfile as sf
9 | import torch
10 | from pytorch_lightning import LightningDataModule
11 | from pytorch_lightning.utilities.rank_zero import rank_zero_info
12 |
13 | from torch.utils.data import DataLoader, Dataset
14 |
15 | from data_loaders.utils.collate_func import default_collate_func
16 | from data_loaders.utils.my_distributed_sampler import MyDistributedSampler
17 | from data_loaders.utils.rand import randint
18 |
19 |
20 | class WHAMRDataset(Dataset):
21 | """The WHAMR! dataset"""
22 |
23 | def __init__(
24 | self,
25 | whamr_dir: str,
26 | dataset: str,
27 | version: str = 'min',
28 | target: str = 'anechoic',
29 | audio_time_len: Optional[float] = None,
30 | sample_rate: int = 8000,
31 | ) -> None:
32 | """The WHAMR! dataset
33 |
34 | Args:
35 | whamr_dir: a dir contains [wav8k]
36 | dataset: tr, cv, tt
37 | target: anechoic or reverb
38 | version: min or max
39 |             audio_time_len: cut the audio to `audio_time_len` seconds if given
40 | """
41 | super().__init__()
42 | assert target in ['anechoic', 'reverb'], target
43 | assert sample_rate in [8000, 16000], sample_rate
44 | assert dataset in ['tr', 'cv', 'tt'], dataset
45 | assert version in ['min', 'max'], version
46 |
47 | self.whamr_dir = str(Path(whamr_dir).expanduser())
48 | self.wav_dir = Path(self.whamr_dir) / {8000: 'wav8k', 16000: 'wav16k'}[sample_rate] / version / dataset
49 | self.files = [basename(str(x)) for x in list((self.wav_dir / 'mix_both_reverb').rglob('*.wav'))]
50 | self.files.sort()
51 | assert len(self.files) > 0, (self.whamr_dir, ': files is empty!')
52 |
53 | self.version = version
54 | self.dataset = dataset
55 | self.target = target
56 | self.audio_time_len = audio_time_len
57 | self.sr = sample_rate
58 |
59 | def __getitem__(self, index_seed: Union[int, Tuple[int, int]]):
60 | if type(index_seed) == int:
61 | index = index_seed
62 | if self.dataset == 'tr':
63 | seed = random.randint(a=0, b=99999999)
64 | else:
65 | seed = index
66 | else:
67 | index, seed = index_seed
68 | g = torch.Generator()
69 | g.manual_seed(seed)
70 |
71 | mix, sr = sf.read(self.wav_dir / 'mix_both_reverb' / self.files[index])
72 | s1, sr = sf.read(self.wav_dir / ('s1_' + self.target) / self.files[index])
73 | s2, sr = sf.read(self.wav_dir / ('s2_' + self.target) / self.files[index])
74 | assert sr == self.sr, (sr, self.sr)
75 | mix = mix.T
76 | target = np.stack([s1.T, s2.T], axis=0) # [spk, chn, time]
77 |
78 | # pad or cut signals
79 | T = mix.shape[-1]
80 | start = 0
81 | if self.audio_time_len:
82 | frames = int(sr * self.audio_time_len)
83 | if T < frames:
84 | mix = np.pad(mix, pad_width=((0, 0), (0, frames - T)), mode='constant', constant_values=0)
85 | target = np.pad(target, pad_width=((0, 0), (0, 0), (0, frames - T)), mode='constant', constant_values=0)
86 | elif T > frames:
87 | start = randint(g, low=0, high=T - frames)
88 | mix = mix[:, start:start + frames]
89 | target = target[:, :, start:start + frames]
90 |
91 | paras = {
92 | 'index': index,
93 | 'seed': seed,
94 | 'wavname': self.files[index],
95 | 'wavdir': str(self.wav_dir),
96 | 'sample_rate': self.sr,
97 | 'dataset': self.dataset,
98 | 'target': self.target,
99 | 'version': self.version,
100 | 'audio_time_len': self.audio_time_len,
101 | 'start': start,
102 | }
103 |
104 | return torch.as_tensor(mix, dtype=torch.float32), torch.as_tensor(target, dtype=torch.float32), paras
105 |
106 | def __len__(self):
107 | return len(self.files)
108 |
109 |
110 | class WHAMRDataModule(LightningDataModule):
111 |
112 | def __init__(
113 | self,
114 | whamr_dir: str,
115 | version: str = 'min',
116 | target: str = 'anechoic',
117 | sample_rate: int = 8000,
118 | audio_time_len: Tuple[Optional[float], Optional[float], Optional[float]] = [4.0, 4.0, None], # audio_time_len (seconds) for training, val, test.
119 | batch_size: List[int] = [1, 1],
120 | test_set: str = 'test', # the dataset to test: train, val, test
121 | num_workers: int = 5,
122 | collate_func_train: Callable = default_collate_func,
123 | collate_func_val: Callable = default_collate_func,
124 | collate_func_test: Callable = default_collate_func,
125 | seeds: Tuple[Optional[int], int, int] = [None, 2, 3], # random seeds for train, val and test sets
126 | # if pin_memory=True, will occupy a lot of memory & speed up
127 | pin_memory: bool = True,
128 | # prefetch how many samples, will increase the memory occupied when pin_memory=True
129 | prefetch_factor: int = 5,
130 | persistent_workers: bool = False,
131 | ):
132 | super().__init__()
133 | self.whamr_dir = whamr_dir
134 | self.version = version
135 | self.target = target
136 | self.sample_rate = sample_rate
137 | self.audio_time_len = audio_time_len
138 | self.persistent_workers = persistent_workers
139 | self.test_set = test_set
140 |
141 | rank_zero_info(f'dataset: WHAMR!, datasets for train/valid/test: {version} {target} {sample_rate}, time length: {audio_time_len}')
142 | assert audio_time_len[2] == None, audio_time_len
143 |
144 | self.batch_size = batch_size[0]
145 | self.batch_size_val = batch_size[1]
146 | self.batch_size_test = 1
147 | if len(batch_size) > 2:
148 | self.batch_size_test = batch_size[2]
149 | rank_zero_info(f'batch size: train={self.batch_size}; val={self.batch_size_val}; test={self.batch_size_test}')
150 | # assert self.batch_size_val == 1, "batch size for validation should be 1 as the audios have different length"
151 |
152 | self.num_workers = num_workers
153 |
154 | self.collate_func_train = collate_func_train
155 | self.collate_func_val = collate_func_val
156 | self.collate_func_test = collate_func_test
157 |
158 | self.seeds = []
159 | for seed in seeds:
160 | self.seeds.append(seed if seed is not None else random.randint(0, 1000000))
161 |
162 | self.pin_memory = pin_memory
163 | self.prefetch_factor = prefetch_factor
164 |
165 | def setup(self, stage=None):
166 | if stage is not None and stage == 'test':
167 | audio_time_len = None
168 | else:
169 | audio_time_len = self.audio_time_len
170 |
171 | self.train = WHAMRDataset(
172 | whamr_dir=self.whamr_dir,
173 | dataset='tr',
174 | version=self.version,
175 | target=self.target,
176 | audio_time_len=self.audio_time_len[0] if stage != 'test' else None,
177 | sample_rate=self.sample_rate,
178 | )
179 | self.val = WHAMRDataset(
180 | whamr_dir=self.whamr_dir,
181 | dataset='cv',
182 | version=self.version,
183 | target=self.target,
184 | audio_time_len=self.audio_time_len[1] if stage != 'test' else None,
185 | sample_rate=self.sample_rate,
186 | )
187 | self.test = WHAMRDataset(
188 | whamr_dir=self.whamr_dir,
189 | dataset='tt',
190 | version=self.version,
191 | target=self.target,
192 | audio_time_len=self.audio_time_len[2],
193 | sample_rate=self.sample_rate,
194 | )
195 |
196 | def train_dataloader(self) -> DataLoader:
197 | return DataLoader(
198 | self.train,
199 | sampler=MyDistributedSampler(self.train, seed=self.seeds[0], shuffle=True),
200 | batch_size=self.batch_size,
201 | collate_fn=self.collate_func_train,
202 | num_workers=self.num_workers,
203 | prefetch_factor=self.prefetch_factor,
204 | pin_memory=self.pin_memory,
205 | persistent_workers=self.persistent_workers,
206 | )
207 |
208 | def val_dataloader(self) -> DataLoader:
209 | return DataLoader(
210 | self.val,
211 | sampler=MyDistributedSampler(self.val, seed=self.seeds[1], shuffle=False),
212 | batch_size=self.batch_size_val,
213 | collate_fn=self.collate_func_val,
214 | num_workers=self.num_workers,
215 | prefetch_factor=self.prefetch_factor,
216 | pin_memory=self.pin_memory,
217 | persistent_workers=self.persistent_workers,
218 | )
219 |
220 | def test_dataloader(self) -> DataLoader:
221 | if self.test_set == 'test':
222 | dataset = self.test
223 | elif self.test_set == 'val':
224 | dataset = self.val
225 | else: # train
226 | dataset = self.train
227 |
228 | return DataLoader(
229 | dataset,
230 |             sampler=MyDistributedSampler(dataset, seed=self.seeds[2], shuffle=False),
231 | batch_size=self.batch_size_test,
232 | collate_fn=self.collate_func_test,
233 | num_workers=self.num_workers,
234 | prefetch_factor=self.prefetch_factor,
235 | pin_memory=self.pin_memory,
236 | persistent_workers=self.persistent_workers,
237 | )
238 |
239 |
240 | if __name__ == '__main__':
241 | """python -m data_loaders.whamr"""
242 | from argparse import ArgumentParser
243 | parser = ArgumentParser("")
244 | parser.add_argument('--whamr_dir', type=str, default='~/datasets/whamr')
245 | parser.add_argument('--version', type=str, default='min')
246 | parser.add_argument('--target', type=str, default='anechoic')
247 | parser.add_argument('--sample_rate', type=int, default=8000)
248 | parser.add_argument('--gen_unprocessed', type=bool, default=True)
249 | parser.add_argument('--gen_target', type=bool, default=True)
250 | parser.add_argument('--save_dir', type=str, default='dataset/whamr')
251 | parser.add_argument('--dataset', type=str, default='train', choices=['train', 'val', 'test'])
252 |
253 | args = parser.parse_args()
254 | os.makedirs(args.save_dir, exist_ok=True)
255 |
256 | if not args.gen_unprocessed and not args.gen_target:
257 | exit()
258 |
259 | datamodule = WHAMRDataModule(args.whamr_dir, args.version, target=args.target, sample_rate=args.sample_rate, batch_size=[1, 1], num_workers=1)
260 | datamodule.setup()
261 | if args.dataset == 'train':
262 | dataloader = datamodule.train_dataloader()
263 | elif args.dataset == 'val':
264 | dataloader = datamodule.val_dataloader()
265 | else:
266 | assert args.dataset == 'test'
267 | dataloader = datamodule.test_dataloader()
268 |
269 | for idx, (mix, tar, paras) in enumerate(dataloader):
270 | # write target to dir
271 | print(mix.shape, tar.shape, paras)
272 |
273 | if idx > 10:
274 | continue
275 |
276 | if args.gen_target:
277 | tar_path = Path(f"{args.save_dir}/{args.target}/{args.dataset}").expanduser()
278 | tar_path.mkdir(parents=True, exist_ok=True)
279 | for i in range(2):
280 | # assert np.max(np.abs(tar[0, i, 0, :].numpy())) <= 1
281 | sp = tar_path / (paras[0]['wavname'] + f'_spk{i}.wav')
282 | if not sp.exists():
283 | sf.write(sp, tar[0, i, 0, :].numpy(), samplerate=paras[0]['sample_rate'])
284 |
285 | # write unprocessed's 0-th channel
286 | if args.gen_unprocessed:
287 | tar_path = Path(f"{args.save_dir}/unprocessed/{args.dataset}").expanduser()
288 | tar_path.mkdir(parents=True, exist_ok=True)
289 | # assert np.max(np.abs(mix[0, 0, :].numpy())) <= 1
290 | sp = tar_path / (paras[0]['wavname'])
291 | if not sp.exists():
292 | sf.write(sp, mix[0, 0, :].numpy(), samplerate=paras[0]['sample_rate'])
293 |
--------------------------------------------------------------------------------
/images/model_size_and_flops.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/images/model_size_and_flops.png
--------------------------------------------------------------------------------
/images/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/images/results.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/models/__init__.py
--------------------------------------------------------------------------------
/models/arch/NBC.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Callable, Optional, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.nn.init as init
8 | from torch import Tensor
9 |
10 |
11 | class Linear(nn.Linear):
12 | """
13 | Wrapper class of torch.nn.Linear
14 | Weight initialize by xavier initialization and bias initialize to zeros.
15 | """
16 |
17 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
18 | super(Linear, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
19 |
20 | init.xavier_uniform_(self.weight)
21 | if bias:
22 | init.zeros_(self.bias)
23 |
24 |
25 | class RelativePositionalEncoding(nn.Module):
26 | """This class returns the relative positional encoding in a range.
27 |
28 | for i in [-m, m]:
29 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
30 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
31 |
32 | Arguments
33 | ---------
34 | max_len : int
35 | Max length of the input sequences (default 1000).
36 |
37 | Example
38 | -------
39 | >>> a = torch.rand((8, 120, 512))
40 | >>> enc = RelativePositionalEncoding(input_size=a.shape[-1])
41 | >>> b = enc(a)
42 | >>> b.shape
43 | torch.Size([1, 239, 512])
44 | """
45 |
46 | def __init__(self, input_size, max_len=1000):
47 | super().__init__()
48 | # [-m, -m+1, ..., -1, 0, 1, ..., m]
49 | self.max_len = max_len
50 | self.zero_index = max_len
51 | pe = torch.zeros(self.max_len * 2 + 1, input_size, requires_grad=False)
52 | positions = torch.arange(-self.max_len, self.max_len + 1).unsqueeze(1).float()
53 | denominator = torch.exp(torch.arange(0, input_size, 2).float() * -(math.log(10000.0) / input_size))
54 |
55 | pe[:, 0::2] = torch.sin(positions * denominator)
56 | pe[:, 1::2] = torch.cos(positions * denominator)
57 | pe = pe.unsqueeze(0)
58 | self.register_buffer("pe", pe)
59 |
60 | def forward(self, x: torch.Tensor) -> torch.Tensor:
61 | """
62 | Args:
63 | x (torch.Tensor): shape [batch, time, feature]
64 |
65 | Returns:
66 | torch.Tensor: relative positional encoding
67 | """
68 | B, T, F = x.shape
69 | start, end = -T + 1, T - 1
70 | return self.pe[:, start + self.zero_index:end + self.zero_index + 1].clone().detach()
71 |
72 |
73 | class RelativePositionalMultiHeadAttention(nn.Module):
74 | """
75 | Multi-head attention with relative positional encoding.
76 | This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
77 | """
78 |
79 | def __init__(
80 | self,
81 | d_model: int = 256,
82 | num_heads: int = 8,
83 | dropout: float = 0.1,
84 | ):
85 | super(RelativePositionalMultiHeadAttention, self).__init__()
86 | assert d_model % num_heads == 0, "d_model % num_heads should be zero."
87 | self.d_model = d_model
88 | self.d_head = int(d_model / num_heads)
89 | self.num_heads = num_heads
90 | self.sqrt_dim = math.sqrt(d_model)
91 |
92 | self.query_proj = Linear(d_model, d_model)
93 | self.key_proj = Linear(d_model, d_model)
94 | self.value_proj = Linear(d_model, d_model)
95 | self.pos_proj = Linear(d_model, d_model, bias=False)
96 | self.rel_pos = RelativePositionalEncoding(d_model)
97 |
98 | self.dropout = nn.Dropout(p=dropout)
99 | self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
100 | self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
101 | init.xavier_uniform_(self.u_bias)
102 | init.xavier_uniform_(self.v_bias)
103 |
104 | self.out_proj = Linear(d_model, d_model)
105 |
106 | def forward(self, query: Tensor, key: Optional[Tensor] = None, value: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
107 | if key is None:
108 | key = query
109 | if value is None:
110 | value = query
111 |
112 | batch_size, time_frames, feature_size = value.shape
113 | pos_embedding = self.rel_pos.forward(value)
114 |
115 | query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
116 | key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
117 | value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
118 |
119 | content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
120 |
121 | pos_embedding = self.pos_proj(pos_embedding)
122 | #### my implementation to calculate pos_score ####
123 | pos_index = self._get_relative_pos_index(time_frames, pos_embedding.device)
124 | pos_embedding = pos_embedding[:, pos_index, :].view(1, time_frames, time_frames, self.num_heads, self.d_head)
125 | # original matmul
126 | # pos_score = torch.matmul((query + self.v_bias).transpose(1, 2).unsqueeze(3), pos_embedding.permute(0, 3, 1, 4, 2)).squeeze(3)
127 | # faster than matmul, and saving memory
128 | qv_bias = (query + self.v_bias).transpose(1, 2) # [B, N, T, D]
129 | pos_embedding = pos_embedding.permute(0, 3, 1, 4, 2) # [1, N, T, D, T], 1 is broadcasted to B
130 | pos_score = torch.einsum("abcd,abcdf->abcf", qv_bias, pos_embedding)
131 | score = (content_score + pos_score) / self.sqrt_dim
132 |
133 | if attn_mask is not None:
134 |             score += attn_mask # give -inf to mask the corresponding points before softmax
135 |
136 | attn = F.softmax(score, -1)
137 | attn = self.dropout(attn)
138 | output = torch.matmul(attn, value).transpose(1, 2) # [B, T, N, D]
139 |
140 | # output
141 | output = output.contiguous().view(batch_size, -1, self.d_model)
142 |
143 | return self.out_proj(output), attn
144 |
145 | def _get_relative_pos_index(self, T: int, device: torch.device) -> Tensor:
146 | with torch.no_grad():
147 | pos1 = torch.arange(start=0, end=T, dtype=torch.long, device=device, requires_grad=False).unsqueeze(1)
148 | pos2 = torch.arange(start=0, end=T, dtype=torch.long, device=device, requires_grad=False).unsqueeze(0)
149 | relative_pos = pos1 - pos2
150 | """ now, relative_pos=[
151 | [0,-1,-2,...,-(T-1)],
152 | [1, 0,-1,...,-(T-2)],
153 | ...
154 | [T-1,T-2,..., 1, 0]
155 | ]
156 | """
157 | pos_index = relative_pos[:, :] + (T - 1) # (T-1) is the index of the relative position 0
158 | return pos_index
159 |
160 |
161 | class NBCBlock(nn.Module):
162 |
163 | def __init__(
164 | self,
165 | dim_model: int = 192,
166 | num_head: int = 8,
167 | dim_ffn: int = 384,
168 | dropout: float = 0.1,
169 | activation: Callable = F.silu,
170 | layer_norm_eps: float = 1e-5,
171 | norm_first: bool = True,
172 | n_conv_groups: int = 384,
173 | conv_kernel_size: int = 3,
174 | conv_bias: bool = True,
175 | n_conv_layers: int = 3,
176 | conv_mid_norm: str = "GN",
177 | ) -> None:
178 | super().__init__()
179 |
180 | self.self_attn = RelativePositionalMultiHeadAttention(dim_model, num_head, dropout=dropout)
181 |
182 | # Implementation of Feedforward model
183 | self.linear1 = Linear(dim_model, dim_ffn)
184 | self.dropout = nn.Dropout(dropout)
185 | self.linear2 = Linear(dim_ffn, dim_model)
186 |
187 | self.norm_first = norm_first
188 | self.norm1 = nn.LayerNorm(dim_model, eps=layer_norm_eps)
189 | self.norm2 = nn.LayerNorm(dim_model, eps=layer_norm_eps)
190 | self.dropout1 = nn.Dropout(dropout)
191 | self.dropout2 = nn.Dropout(dropout)
192 |
193 | self.activation = activation
194 |
195 | convs = []
196 | for l in range(n_conv_layers):
197 | convs.append(nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=conv_kernel_size, padding='same', groups=n_conv_groups, bias=conv_bias))
198 | if conv_mid_norm != None:
199 | if conv_mid_norm == 'GN':
200 | convs.append(nn.GroupNorm(8, dim_ffn))
201 | else:
202 |                     raise Exception('Unsupported mid norm ' + conv_mid_norm)
203 | convs.append(nn.SiLU())
204 | self.conv = nn.Sequential(*convs)
205 |
206 | def forward(self, x: Tensor, att_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
207 | r"""
208 |
209 | Args:
210 | x: shape [batch, seq, feature]
211 | att_mask: the mask for attentions. shape [batch, seq, seq]
212 |
213 | Shape:
214 | out: shape [batch, seq, feature]
215 | attention: shape [batch, head, seq, seq]
216 | """
217 |
218 | if self.norm_first:
219 | x_, attn = self._sa_block(self.norm1(x), att_mask)
220 | x = x + x_
221 | x = x + self._ff_block(self.norm2(x))
222 | else:
223 | x_, attn = self._sa_block(x, att_mask)
224 | x = self.norm1(x + x_)
225 | x = self.norm2(x + self._ff_block(x))
226 |
227 | return x, attn
228 |
229 | # self-attention block
230 | def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
231 | x, attn = self.self_attn(x, attn_mask=attn_mask)
232 | return self.dropout1(x), attn
233 |
234 | # feed forward block
235 | def _ff_block(self, x: Tensor) -> Tensor:
236 | x = self.linear2(self.dropout(self.conv(self.activation(self.linear1(x)).transpose(-1, -2)).transpose(-1, -2)))
237 | return self.dropout2(x)
238 |
239 |
240 | class NBC(nn.Module):
241 |
242 | def __init__(
243 | self,
244 | dim_input: int = 16, # 2*8
245 | dim_output: int = 4, # 2*2
246 | n_layers: int = 4,
247 | encoder_kernel_size: int = 4,
248 | n_heads: int = 8,
249 | activation: Optional[str] = "",
250 | hidden_size: int = 192,
251 | norm_first: bool = True,
252 | ffn_size: int = 384,
253 | inner_conv_kernel_size: int = 3,
254 | inner_conv_groups: int = 8,
255 | inner_conv_bias: bool = True,
256 | inner_conv_layers: int = 3,
257 | inner_conv_mid_norm: str = "GN",
258 | ):
259 | super().__init__()
260 | # encoder
261 | self.encoder = nn.Conv1d(in_channels=dim_input, out_channels=hidden_size, kernel_size=encoder_kernel_size, stride=1)
262 | # self-attention net
263 | self.sa_layers = nn.ModuleList()
264 | for l in range(n_layers):
265 | self.sa_layers.append(
266 | NBCBlock(
267 | dim_model=hidden_size,
268 | num_head=n_heads,
269 | norm_first=norm_first,
270 | dim_ffn=ffn_size,
271 | n_conv_groups=inner_conv_groups,
272 | conv_kernel_size=inner_conv_kernel_size,
273 | conv_bias=inner_conv_bias,
274 | n_conv_layers=inner_conv_layers,
275 | conv_mid_norm=inner_conv_mid_norm,
276 | ))
277 |
278 | # decoder
279 | assert activation == '', 'not implemented'
280 | self.decoder = nn.ConvTranspose1d(in_channels=hidden_size, out_channels=dim_output, kernel_size=encoder_kernel_size, stride=1)
281 |
282 | def forward(self, x: Tensor) -> Tensor:
283 | # x: [Batch, NumFreqs, Time, Feature]
284 | B, F, T, H = x.shape
285 | x = x.reshape(B * F, T, H)
286 | x = self.encoder(x.permute(0, 2, 1)).permute(0, 2, 1)
287 | attns = []
288 | for m in self.sa_layers:
289 | x, attn = m(x)
290 | attns.append(attn)
291 | y = self.decoder(x.permute(0, 2, 1)).permute(0, 2, 1)
292 | y = y.reshape(B, F, T, -1)
293 | return y.contiguous() # , attns
294 |
295 |
296 | if __name__ == '__main__':
297 | Batch, Freq, Time, Chn, Spk = 1, 257, 100, 8, 2
298 | x = torch.randn((Batch, Freq, Time, Chn * 2))
299 | m = NBC(dim_input=Chn * 2, dim_output=Spk * 2, n_layers=4)
300 | y = m(x)
301 | print(y.shape)
302 |
--------------------------------------------------------------------------------
/models/arch/NBC2.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torch import Tensor
7 | from torch.nn import Module, MultiheadAttention
8 | from torch.nn.parameter import Parameter
9 |
10 |
11 | class LayerNorm(nn.LayerNorm):
12 |
13 | def __init__(self, transpose: bool, **kwargs) -> None:
14 | super().__init__(**kwargs)
15 | self.transpose = transpose
16 |
17 | def forward(self, input: Tensor) -> Tensor:
18 | if self.transpose:
19 | input = input.transpose(-1, -2) # [B, H, T] -> [B, T, H]
20 | o = super().forward(input)
21 | if self.transpose:
22 | o = o.transpose(-1, -2)
23 | return o
24 |
25 |
26 | class BatchNorm1d(nn.Module):
27 |
28 | def __init__(self, transpose: bool, **kwargs) -> None:
29 | super().__init__()
30 | self.transpose = transpose
31 | self.bn = nn.BatchNorm1d(**kwargs)
32 |
33 | def forward(self, input: Tensor) -> Tensor:
34 | if self.transpose == False:
35 | input = input.transpose(-1, -2) # [B, T, H] -> [B, H, T]
36 | o = self.bn.forward(input) # accepts [B, H, T]
37 | if self.transpose == False:
38 | o = o.transpose(-1, -2)
39 | return o
40 |
41 |
42 | class GroupNorm(nn.GroupNorm):
43 |
44 | def __init__(self, transpose: bool, **kwargs) -> None:
45 | super().__init__(**kwargs)
46 | self.transpose = transpose
47 |
48 | def forward(self, input: Tensor) -> Tensor:
49 | if self.transpose == False:
50 | input = input.transpose(-1, -2) # [B, T, H] -> [B, H, T]
51 | o = super().forward(input) # accepts [B, H, T]
52 | if self.transpose == False:
53 | o = o.transpose(-1, -2)
54 | return o
55 |
56 |
57 | class GroupBatchNorm(Module):
58 | """Applies Group Batch Normalization over a group of inputs
59 |
60 | This layer uses statistics computed from input data in both training and
61 | evaluation modes.
62 | """
63 |
64 | dim_hidden: int
65 | group_size: int
66 | eps: float
67 | affine: bool
68 | transpose: bool
69 | share_along_sequence_dim: bool
70 |
71 | def __init__(
72 | self,
73 | dim_hidden: int,
74 | group_size: int,
75 | share_along_sequence_dim: bool = False,
76 | transpose: bool = False,
77 | affine: bool = True,
78 | eps: float = 1e-5,
79 | ) -> None:
80 | """
81 | Args:
82 | dim_hidden (int): hidden dimension
83 | group_size (int): the size of group
84 | share_along_sequence_dim (bool): share statistics along the sequence dimension. Defaults to False.
85 | transpose (bool): whether the shape of input is [B, T, H] or [B, H, T]. Defaults to False, i.e. [B, T, H].
86 | affine (bool): affine transformation. Defaults to True.
87 | eps (float): Defaults to 1e-5.
88 | """
89 | super(GroupBatchNorm, self).__init__()
90 |
91 | self.dim_hidden = dim_hidden
92 | self.group_size = group_size
93 | self.eps = eps
94 | self.affine = affine
95 | self.transpose = transpose
96 | self.share_along_sequence_dim = share_along_sequence_dim
97 | if self.affine:
98 | if transpose:
99 | self.weight = Parameter(torch.empty([dim_hidden, 1]))
100 | self.bias = Parameter(torch.empty([dim_hidden, 1]))
101 | else:
102 | self.weight = Parameter(torch.empty([dim_hidden]))
103 | self.bias = Parameter(torch.empty([dim_hidden]))
104 | self.reset_parameters()
105 |
106 | def reset_parameters(self) -> None:
107 | if self.affine:
108 | init.ones_(self.weight)
109 | init.zeros_(self.bias)
110 |
111 | def forward(self, input: Tensor) -> Tensor:
112 | """
113 | Args:
114 | input: shape [B, T, H] if transpose=False, else shape [B, H, T] , where B = num of groups * group size.
115 | """
116 |         assert input.shape[0] % self.group_size == 0, f'batch size {input.shape[0]} is not divisible by group size {self.group_size}'
117 | if self.transpose == False:
118 | B, T, H = input.shape
119 | input = input.reshape(B // self.group_size, self.group_size, T, H)
120 |
121 | if self.share_along_sequence_dim:
122 | var, mean = torch.var_mean(input, dim=(1, 2, 3), unbiased=False, keepdim=True)
123 | else:
124 | var, mean = torch.var_mean(input, dim=(1, 3), unbiased=False, keepdim=True)
125 |
126 | output = (input - mean) / torch.sqrt(var + self.eps)
127 | if self.affine:
128 | output = output * self.weight + self.bias
129 | output = output.reshape(B, T, H)
130 | else:
131 | B, H, T = input.shape
132 | input = input.reshape(B // self.group_size, self.group_size, H, T)
133 |
134 | if self.share_along_sequence_dim:
135 | var, mean = torch.var_mean(input, dim=(1, 2, 3), unbiased=False, keepdim=True)
136 | else:
137 | var, mean = torch.var_mean(input, dim=(1, 2), unbiased=False, keepdim=True)
138 |
139 | output = (input - mean) / torch.sqrt(var + self.eps)
140 | if self.affine:
141 | output = output * self.weight + self.bias
142 |
143 | output = output.reshape(B, H, T)
144 |
145 | return output
146 |
147 | def extra_repr(self) -> str:
148 | return '{dim_hidden}, {group_size}, share_along_sequence_dim={share_along_sequence_dim}, transpose={transpose}, eps={eps}, ' \
149 | 'affine={affine}'.format(**self.__dict__)
150 |
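# A worked illustration of GroupBatchNorm (an addition, not from the repo): with
# transpose=False and group_size=G, an input of shape [B, T, H] is viewed as
# [B//G, G, T, H] and all G items of a group (e.g. the G=num_freqs frequency bands of one
# utterance, see NBC2 below) share normalization statistics:
#   gbn = GroupBatchNorm(dim_hidden=96, group_size=257)
#   x = torch.randn(2 * 257, 100, 96)  # 2 utterances, 257 frequencies each
#   y = gbn(x)                         # same shape, normalized per utterance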
151 |
152 | class NBC2Block(nn.Module):
153 |
154 | def __init__(
155 | self,
156 | dim_hidden: int,
157 | dim_ffn: int,
158 | n_heads: int,
159 | dropout: float = 0,
160 | conv_kernel_size: int = 3,
161 | n_conv_groups: int = 8,
162 | norms: Tuple[str, str, str] = ("LN", "GBN", "GBN"),
163 | group_batch_norm_kwargs: Dict[str, Any] = {
164 | 'group_size': 257,
165 | 'share_along_sequence_dim': False,
166 | },
167 | ) -> None:
168 | super().__init__()
169 | # self-attention
170 | self.norm1 = self._new_norm(norms[0], dim_hidden, False, n_conv_groups, **group_batch_norm_kwargs)
171 | self.self_attn = MultiheadAttention(embed_dim=dim_hidden, num_heads=n_heads, batch_first=True)
172 | self.dropout1 = nn.Dropout(dropout)
173 |
174 | # Convolutional Feedforward
175 | self.norm2 = self._new_norm(norms[1], dim_hidden, False, n_conv_groups, **group_batch_norm_kwargs)
176 | self.linear1 = nn.Linear(dim_hidden, dim_ffn)
177 | self.conv = nn.Sequential(
178 | nn.SiLU(),
179 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=conv_kernel_size, padding='same', groups=n_conv_groups, bias=True),
180 | nn.SiLU(),
181 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=conv_kernel_size, padding='same', groups=n_conv_groups, bias=True),
182 | self._new_norm(norms[2], dim_ffn, True, n_conv_groups, **group_batch_norm_kwargs),
183 | nn.SiLU(),
184 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=conv_kernel_size, padding='same', groups=n_conv_groups, bias=True),
185 | nn.SiLU(),
186 | nn.Dropout(dropout),
187 | )
188 | self.linear2 = nn.Linear(dim_ffn, dim_hidden)
189 | self.dropout2 = nn.Dropout(dropout)
190 |
191 | nn.init.xavier_uniform_(self.linear1.weight)
192 | nn.init.xavier_uniform_(self.linear2.weight)
193 | nn.init.zeros_(self.linear1.bias)
194 | nn.init.zeros_(self.linear2.bias)
195 |
196 |     def forward(self, x: Tensor, att_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
197 | r"""
198 |
199 | Args:
200 | x: shape [batch, seq, feature]
201 | att_mask: the mask for attentions. shape [batch, seq, seq]
202 |
203 | Shape:
204 | out: shape [batch, seq, feature]
205 | attention: shape [batch, head, seq, seq]
206 | """
207 |
208 | x_, attn = self._sa_block(self.norm1(x), att_mask)
209 | x = x + x_
210 | x = x + self._ff_block(self.norm2(x))
211 |
212 | return x, attn
213 |
214 | # self-attention block
215 | def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
216 | if isinstance(self.self_attn, MultiheadAttention):
217 | x, attn = self.self_attn.forward(x, x, x, average_attn_weights=False, attn_mask=attn_mask)
218 | else:
219 | x, attn = self.self_attn(x, attn_mask=attn_mask)
220 | return self.dropout1(x), attn
221 |
222 | # conv feed forward block
223 | def _ff_block(self, x: Tensor) -> Tensor:
224 | x = self.linear2(self.conv(self.linear1(x).transpose(-1, -2)).transpose(-1, -2))
225 | return self.dropout2(x)
226 |
227 | def _new_norm(self, norm_type: str, dim_hidden: int, transpose: bool, num_conv_groups: int, **freq_norm_kwargs):
228 | if norm_type == 'LN':
229 | norm = LayerNorm(normalized_shape=dim_hidden, transpose=transpose)
230 | elif norm_type == 'GBN':
231 | norm = GroupBatchNorm(dim_hidden=dim_hidden, transpose=transpose, **freq_norm_kwargs)
232 | elif norm_type == 'BN':
233 | norm = BatchNorm1d(num_features=dim_hidden, transpose=transpose)
234 | elif norm_type == 'GN':
235 | norm = GroupNorm(num_groups=num_conv_groups, num_channels=dim_hidden, transpose=transpose)
236 | else:
237 | raise Exception(norm_type)
238 | return norm
239 |
240 |
241 | class NBC2(nn.Module):
242 |
243 | def __init__(
244 | self,
245 | dim_input: int,
246 | dim_output: int,
247 | n_layers: int,
248 | encoder_kernel_size: int = 5,
249 | dim_hidden: int = 192,
250 | dim_ffn: int = 384,
251 | num_freqs: int = 257,
252 | block_kwargs: Dict[str, Any] = {
253 | 'n_heads': 2,
254 | 'dropout': 0,
255 | 'conv_kernel_size': 3,
256 | 'n_conv_groups': 8,
257 | 'norms': ("LN", "GBN", "GBN"),
258 | 'group_batch_norm_kwargs': {
259 | 'share_along_sequence_dim': False,
260 | },
261 | },
262 | ):
263 | super().__init__()
264 | block_kwargs['group_batch_norm_kwargs']['group_size'] = num_freqs
265 |
266 | # encoder
267 | self.encoder = nn.Conv1d(in_channels=dim_input, out_channels=dim_hidden, kernel_size=encoder_kernel_size, stride=1, padding="same")
268 |
269 | # self-attention net
270 | self.sa_layers = nn.ModuleList()
271 | for l in range(n_layers):
272 | self.sa_layers.append(NBC2Block(dim_hidden=dim_hidden, dim_ffn=dim_ffn, **block_kwargs))
273 |
274 | # decoder
275 | self.decoder = nn.Linear(in_features=dim_hidden, out_features=dim_output)
276 |
277 | def forward(self, x: Tensor) -> Tensor:
278 | # x: [Batch, NumFreqs, Time, Feature]
279 | B, F, T, H = x.shape
280 | x = x.reshape(B * F, T, H)
281 | x = self.encoder(x.permute(0, 2, 1)).permute(0, 2, 1)
282 | # attns = []
283 | for m in self.sa_layers:
284 | x, attn = m(x)
285 | del attn
286 | # attns.append(attn)
287 | y = self.decoder(x)
288 | y = y.reshape(B, F, T, -1)
289 | return y.contiguous() # , attns
290 |
291 |
292 | if __name__ == '__main__':
293 | x = torch.randn((5, 257, 100, 16))
294 | NBC2_small = NBC2(
295 | dim_input=16,
296 | dim_output=4,
297 | n_layers=8,
298 | dim_hidden=96,
299 | dim_ffn=192,
300 | block_kwargs={
301 | 'n_heads': 2,
302 | 'dropout': 0,
303 | 'conv_kernel_size': 3,
304 | 'n_conv_groups': 8,
305 | 'norms': ("LN", "GBN", "GBN"),
306 | 'group_batch_norm_kwargs': {
307 | 'group_size': 257,
308 | 'share_along_sequence_dim': False,
309 | },
310 | },
311 | )
312 | y = NBC2_small(x)
313 | print(NBC2_small)
314 | print(y.shape)
315 | NBC2_large = NBC2(
316 | dim_input=16,
317 | dim_output=4,
318 | n_layers=12,
319 | dim_hidden=192,
320 | dim_ffn=384,
321 | block_kwargs={
322 | 'n_heads': 2,
323 | 'dropout': 0,
324 | 'conv_kernel_size': 3,
325 | 'n_conv_groups': 8,
326 | 'norms': ("LN", "GBN", "GBN"),
327 | 'group_batch_norm_kwargs': {
328 | 'group_size': 257,
329 | 'share_along_sequence_dim': False,
330 | },
331 | },
332 | )
333 | y = NBC2_large(x)
334 | print(NBC2_large)
335 | print(y.shape)
336 |
--------------------------------------------------------------------------------
/models/arch/NBSS.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 | from torchmetrics.functional.audio import permutation_invariant_training as pit
7 | from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio as si_sdr
8 |
9 | from models.arch.blstm2_fc1 import BLSTM2_FC1
10 | from models.arch.NBC import NBC
11 | from models.arch.NBC2 import NBC2
12 |
13 |
14 | def neg_si_sdr(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
15 | batch_size = target.shape[0]
16 | si_snr_val = si_sdr(preds=preds, target=target)
17 | return -torch.mean(si_snr_val.view(batch_size, -1), dim=1)
18 |
19 |
20 | class NBSS(nn.Module):
21 | """Multi-channel Narrow-band Deep Speech Separation with Full-band Permutation Invariant Training.
22 |
23 | A module version of NBSS which takes time domain signal as input, and outputs time domain signal.
24 |
25 | Arch could be NB-BLSTM or NBC
26 | """
27 |
28 | def __init__(
29 | self,
30 | n_channel: int = 8,
31 | n_speaker: int = 2,
32 | n_fft: int = 512,
33 | n_overlap: int = 256,
34 | ref_channel: int = 0,
35 | arch: str = "NB_BLSTM", # could also be NBC, NBC2
36 | arch_kwargs: Dict[str, Any] = dict(),
37 | ):
38 | super().__init__()
39 |
40 | if arch == "NB_BLSTM":
41 | self.arch: nn.Module = BLSTM2_FC1(dim_input=n_channel * 2, dim_output=n_speaker * 2, **arch_kwargs)
42 | elif arch == "NBC":
43 | self.arch = NBC(dim_input=n_channel * 2, dim_output=n_speaker * 2, **arch_kwargs)
44 | elif arch == 'NBC2':
45 | self.arch = NBC2(dim_input=n_channel * 2, dim_output=n_speaker * 2, **arch_kwargs)
46 | else:
47 | raise Exception(f"Unkown arch={arch}")
48 |
49 | self.register_buffer('window', torch.hann_window(n_fft), False) # self.window, will be moved to self.device at training time
50 | self.n_fft = n_fft
51 | self.n_overlap = n_overlap
52 | self.ref_channel = ref_channel
53 | self.n_channel = n_channel
54 | self.n_speaker = n_speaker
55 |
56 | def forward(self, x: Tensor) -> Tensor:
57 | """forward
58 |
59 | Args:
60 | x: time domain signal, shape [Batch, Channel, Time]
61 |
62 | Returns:
63 | y: the predicted time domain signal, shape [Batch, Speaker, Time]
64 | """
65 |
66 | # STFT
67 | B, C, T = x.shape
68 | x = x.reshape((B * C, T))
69 | X = torch.stft(x, n_fft=self.n_fft, hop_length=self.n_overlap, window=self.window, win_length=self.n_fft, return_complex=True)
70 | X = X.reshape((B, C, X.shape[-2], X.shape[-1])) # (batch, channel, freq, time frame)
71 | X = X.permute(0, 2, 3, 1) # (batch, freq, time frame, channel)
72 |
73 | # normalization by using ref_channel
74 | F, TF = X.shape[1], X.shape[2]
75 | Xr = X[..., self.ref_channel].clone() # copy
76 | XrMM = torch.abs(Xr).mean(dim=2) # Xr_magnitude_mean: mean of the magnitude of the ref channel of X
77 | X[:, :, :, :] /= (XrMM.reshape(B, F, 1, 1) + 1e-8)
78 |
79 | # to real
80 | X = torch.view_as_real(X) # [B, F, T, C, 2]
81 | X = X.reshape(B, F, TF, C * 2)
82 |
83 | # network processing
84 | output = self.arch(X)
85 |
86 | # to complex
87 | output = output.reshape(B, F, TF, self.n_speaker, 2)
88 | output = torch.view_as_complex(output) # [B, F, TF, S]
89 |
90 | # inverse normalization
91 | Ys_hat = torch.empty(size=(B, self.n_speaker, F, TF), dtype=torch.complex64, device=output.device)
92 | XrMM = torch.unsqueeze(XrMM, dim=2).expand(-1, -1, TF)
93 | for spk in range(self.n_speaker):
94 | Ys_hat[:, spk, :, :] = output[:, :, :, spk] * XrMM[:, :, :]
95 |
96 | # iSTFT with frequency binding
97 | ys_hat = torch.istft(Ys_hat.reshape(B * self.n_speaker, F, TF), n_fft=self.n_fft, hop_length=self.n_overlap, window=self.window, win_length=self.n_fft, length=T)
98 | ys_hat = ys_hat.reshape(B, self.n_speaker, T)
99 | return ys_hat
100 |
101 |
102 | if __name__ == '__main__':
103 | x = torch.randn(size=(10, 8, 16000))
104 | ys = torch.randn(size=(10, 2, 16000))
105 |
106 | NBSS_with_NB_BLSTM = NBSS(n_channel=8, n_speaker=2, arch="NB_BLSTM")
107 | ys_hat = NBSS_with_NB_BLSTM(x)
108 | neg_sisdr_loss, best_perm = pit(preds=ys_hat, target=ys, metric_func=neg_si_sdr, eval_func='min')
109 | print(ys_hat.shape, neg_sisdr_loss.mean())
110 |
111 | NBSS_with_NBC = NBSS(n_channel=8, n_speaker=2, arch="NBC")
112 | ys_hat = NBSS_with_NBC(x)
113 | neg_sisdr_loss, best_perm = pit(preds=ys_hat, target=ys, metric_func=neg_si_sdr, eval_func='min')
114 | print(ys_hat.shape, neg_sisdr_loss.mean())
115 |
116 | NBSS_with_NBC_small = NBSS(n_channel=8,
117 | n_speaker=2,
118 | arch="NBC2",
119 | arch_kwargs={
120 | "n_layers": 8, # 12 for large
121 | "dim_hidden": 96, # 192 for large
122 | "dim_ffn": 192, # 384 for large
123 | "block_kwargs": {
124 | 'n_heads': 2,
125 | 'dropout': 0,
126 | 'conv_kernel_size': 3,
127 | 'n_conv_groups': 8,
128 | 'norms': ("LN", "GBN", "GBN"),
129 | 'group_batch_norm_kwargs': {
130 | 'group_size': 257, # 129 for 8k Hz
131 | 'share_along_sequence_dim': False,
132 | },
133 | }
134 | },)
135 | ys_hat = NBSS_with_NBC_small(x)
136 | neg_sisdr_loss, best_perm = pit(preds=ys_hat, target=ys, metric_func=neg_si_sdr, eval_func='min')
137 | print(ys_hat.shape, neg_sisdr_loss.mean())
138 |
--------------------------------------------------------------------------------
/models/arch/SpatialNet.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 |
3 | import torch
4 | import torch.nn as nn
5 | from models.arch.base.norm import *
6 | from models.arch.base.non_linear import *
7 | from models.arch.base.linear_group import LinearGroup
8 | from torch import Tensor
9 | from torch.nn import MultiheadAttention
10 |
11 |
12 | class SpatialNetLayer(nn.Module):
13 |
14 | def __init__(
15 | self,
16 | dim_hidden: int,
17 | dim_ffn: int,
18 | dim_squeeze: int,
19 | num_freqs: int,
20 | num_heads: int,
21 | dropout: Tuple[float, float, float] = (0, 0, 0),
22 | kernel_size: Tuple[int, int] = (5, 3),
23 | conv_groups: Tuple[int, int] = (8, 8),
24 | norms: List[str] = ("LN", "LN", "GN", "LN", "LN", "LN"),
25 | padding: str = 'zeros',
26 | full: nn.Module = None,
27 | ) -> None:
28 | super().__init__()
29 | f_conv_groups = conv_groups[0]
30 | t_conv_groups = conv_groups[1]
31 | f_kernel_size = kernel_size[0]
32 | t_kernel_size = kernel_size[1]
33 |
34 | # cross-band block
35 | # frequency-convolutional module
36 | self.fconv1 = nn.ModuleList([
37 | new_norm(norms[3], dim_hidden, seq_last=True, group_size=None, num_groups=f_conv_groups),
38 | nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding),
39 | nn.PReLU(dim_hidden),
40 | ])
41 | # full-band linear module
42 | self.norm_full = new_norm(norms[5], dim_hidden, seq_last=False, group_size=None, num_groups=f_conv_groups)
43 | self.full_share = False if full == None else True
44 | self.squeeze = nn.Sequential(nn.Conv1d(in_channels=dim_hidden, out_channels=dim_squeeze, kernel_size=1), nn.SiLU())
45 | self.dropout_full = nn.Dropout2d(dropout[2]) if dropout[2] > 0 else None
46 | self.full = LinearGroup(num_freqs, num_freqs, num_groups=dim_squeeze) if full == None else full
47 | self.unsqueeze = nn.Sequential(nn.Conv1d(in_channels=dim_squeeze, out_channels=dim_hidden, kernel_size=1), nn.SiLU())
48 | # frequency-convolutional module
49 | self.fconv2 = nn.ModuleList([
50 | new_norm(norms[4], dim_hidden, seq_last=True, group_size=None, num_groups=f_conv_groups),
51 | nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding),
52 | nn.PReLU(dim_hidden),
53 | ])
54 |
55 | # narrow-band block
56 | # MHSA module
57 | self.norm_mhsa = new_norm(norms[0], dim_hidden, seq_last=False, group_size=None, num_groups=t_conv_groups)
58 | self.mhsa = MultiheadAttention(embed_dim=dim_hidden, num_heads=num_heads, batch_first=True)
59 | self.dropout_mhsa = nn.Dropout(dropout[0])
60 | # T-ConvFFN module
61 | self.tconvffn = nn.ModuleList([
62 | new_norm(norms[1], dim_hidden, seq_last=True, group_size=None, num_groups=t_conv_groups),
63 | nn.Conv1d(in_channels=dim_hidden, out_channels=dim_ffn, kernel_size=1),
64 | nn.SiLU(),
65 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=t_kernel_size, padding='same', groups=t_conv_groups),
66 | nn.SiLU(),
67 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=t_kernel_size, padding='same', groups=t_conv_groups),
68 | new_norm(norms[2], dim_ffn, seq_last=True, group_size=None, num_groups=t_conv_groups),
69 | nn.SiLU(),
70 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=t_kernel_size, padding='same', groups=t_conv_groups),
71 | nn.SiLU(),
72 | nn.Conv1d(in_channels=dim_ffn, out_channels=dim_hidden, kernel_size=1),
73 | ])
74 | self.dropout_tconvffn = nn.Dropout(dropout[1])
75 |
76 | def forward(self, x: Tensor, att_mask: Optional[Tensor] = None) -> Tensor:
77 | r"""
78 | Args:
79 | x: shape [B, F, T, H]
80 | att_mask: the mask for attention along T. shape [B, T, T]
81 |
82 | Shape:
83 | out: shape [B, F, T, H]
84 | """
85 | x = x + self._fconv(self.fconv1, x)
86 | x = x + self._full(x)
87 | x = x + self._fconv(self.fconv2, x)
88 | x_, attn = self._tsa(x, att_mask)
89 | x = x + x_
90 | x = x + self._tconvffn(x)
91 | return x, attn
92 |
93 | def _tsa(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
94 | B, F, T, H = x.shape
95 | x = self.norm_mhsa(x)
96 | x = x.reshape(B * F, T, H)
97 |         need_weights = self.need_weights if hasattr(self, "need_weights") else False
98 | x, attn = self.mhsa.forward(x, x, x, need_weights=need_weights, average_attn_weights=False, attn_mask=attn_mask)
99 | x = x.reshape(B, F, T, H)
100 | return self.dropout_mhsa(x), attn
101 |
102 | def _tconvffn(self, x: Tensor) -> Tensor:
103 | B, F, T, H0 = x.shape
104 | # T-Conv
105 | x = x.transpose(-1, -2) # [B,F,H,T]
106 | x = x.reshape(B * F, H0, T)
107 | for m in self.tconvffn:
108 | if type(m) == GroupBatchNorm:
109 | x = m(x, group_size=F)
110 | else:
111 | x = m(x)
112 | x = x.reshape(B, F, H0, T)
113 | x = x.transpose(-1, -2) # [B,F,T,H]
114 | return self.dropout_tconvffn(x)
115 |
116 | def _fconv(self, ml: nn.ModuleList, x: Tensor) -> Tensor:
117 | B, F, T, H = x.shape
118 | x = x.permute(0, 2, 3, 1) # [B,T,H,F]
119 | x = x.reshape(B * T, H, F)
120 | for m in ml:
121 | if type(m) == GroupBatchNorm:
122 | x = m(x, group_size=T)
123 | else:
124 | x = m(x)
125 | x = x.reshape(B, T, H, F)
126 | x = x.permute(0, 3, 1, 2) # [B,F,T,H]
127 | return x
128 |
129 | def _full(self, x: Tensor) -> Tensor:
130 | B, F, T, H = x.shape
131 | x = self.norm_full(x)
132 | x = x.permute(0, 2, 3, 1) # [B,T,H,F]
133 | x = x.reshape(B * T, H, F)
134 | x = self.squeeze(x) # [B*T,H',F]
135 | if self.dropout_full:
136 | x = x.reshape(B, T, -1, F)
137 | x = x.transpose(1, 3) # [B,F,H',T]
138 | x = self.dropout_full(x) # dropout some frequencies in one utterance
139 | x = x.transpose(1, 3) # [B,T,H',F]
140 | x = x.reshape(B * T, -1, F)
141 |
142 | x = self.full(x) # [B*T,H',F]
143 | x = self.unsqueeze(x) # [B*T,H,F]
144 | x = x.reshape(B, T, H, F)
145 | x = x.permute(0, 3, 1, 2) # [B,F,T,H]
146 | return x
147 |
148 | def extra_repr(self) -> str:
149 | return f"full_share={self.full_share}"
150 |
151 |
152 | class SpatialNet(nn.Module):
153 |
154 | def __init__(
155 | self,
156 | dim_input: int, # the input dim for each time-frequency point
157 | dim_output: int, # the output dim for each time-frequency point
158 | dim_squeeze: int,
159 | num_layers: int,
160 | num_freqs: int,
161 | encoder_kernel_size: int = 5,
162 | dim_hidden: int = 192,
163 | dim_ffn: int = 384,
164 | num_heads: int = 2,
165 | dropout: Tuple[float, float, float] = (0, 0, 0),
166 | kernel_size: Tuple[int, int] = (5, 3),
167 | conv_groups: Tuple[int, int] = (8, 8),
168 | norms: List[str] = ("LN", "LN", "GN", "LN", "LN", "LN"),
169 | padding: str = 'zeros',
170 | full_share: int = 0, # share from layer 0
171 | ):
172 | super().__init__()
173 |
174 | # encoder
175 | self.encoder = nn.Conv1d(in_channels=dim_input, out_channels=dim_hidden, kernel_size=encoder_kernel_size, stride=1, padding="same")
176 |
177 | # spatialnet layers
178 | full = None
179 | layers = []
180 | for l in range(num_layers):
181 | layer = SpatialNetLayer(
182 | dim_hidden=dim_hidden,
183 | dim_ffn=dim_ffn,
184 | dim_squeeze=dim_squeeze,
185 | num_freqs=num_freqs,
186 | num_heads=num_heads,
187 | dropout=dropout,
188 | kernel_size=kernel_size,
189 | conv_groups=conv_groups,
190 | norms=norms,
191 | padding=padding,
192 | full=full if l > full_share else None,
193 | )
194 | if hasattr(layer, 'full'):
195 | full = layer.full
196 | layers.append(layer)
197 | self.layers = nn.ModuleList(layers)
198 |
199 | # decoder
200 | self.decoder = nn.Linear(in_features=dim_hidden, out_features=dim_output)
201 |
202 | def forward(self, x: Tensor, return_attn_score: bool = False) -> Tensor:
203 | # x: [Batch, Freq, Time, Feature]
204 | B, F, T, H0 = x.shape
205 | x = self.encoder(x.reshape(B * F, T, H0).permute(0, 2, 1)).permute(0, 2, 1)
206 | H = x.shape[2]
207 |
208 | attns = [] if return_attn_score else None
209 | x = x.reshape(B, F, T, H)
210 | for m in self.layers:
211 | setattr(m, "need_weights", return_attn_score)
212 | x, attn = m(x)
213 | if return_attn_score:
214 | attns.append(attn)
215 |
216 | y = self.decoder(x)
217 | if return_attn_score:
218 | return y.contiguous(), attns
219 | else:
220 | return y.contiguous()
221 |
222 |
223 | if __name__ == '__main__':
224 | # CUDA_VISIBLE_DEVICES=7, python -m models.arch.SpatialNet
225 | x = torch.randn((1, 129, 251, 12)) #.cuda() # 251 = 4 second; 129 = 8 kHz; 257 = 16 kHz
226 | spatialnet_small = SpatialNet(
227 | dim_input=12,
228 | dim_output=4,
229 | num_layers=8,
230 | dim_hidden=96,
231 | dim_ffn=192,
232 | kernel_size=(5, 3),
233 | conv_groups=(8, 8),
234 | norms=("LN", "LN", "GN", "LN", "LN", "LN"),
235 | dim_squeeze=8,
236 | num_freqs=129,
237 | full_share=0,
238 | ) #.cuda()
239 | # from packaging.version import Version
240 | # if Version(torch.__version__) >= Version('2.0.0'):
241 | # SSFNet_small = torch.compile(SSFNet_small)
242 | # torch.cuda.synchronize(7)
243 | import time
244 | ts = time.time()
245 | y = spatialnet_small(x)
246 | # torch.cuda.synchronize(7)
247 | te = time.time()
248 | print(spatialnet_small)
249 | print(y.shape)
250 | print(te - ts)
251 |
252 | spatialnet_small = spatialnet_small.to('meta')
253 | x = x.to('meta')
254 | from torch.utils.flop_counter import FlopCounterMode # requires torch>=2.1.0
255 | with FlopCounterMode(spatialnet_small, display=False) as fcm:
256 | y = spatialnet_small(x)
257 | flops_forward_eval = fcm.get_total_flops()
258 | res = y.sum()
259 | res.backward()
260 | flops_backward_eval = fcm.get_total_flops() - flops_forward_eval
261 |
262 | params_eval = sum(param.numel() for param in spatialnet_small.parameters())
263 | print(f"flops_forward={flops_forward_eval/1e9:.2f}G, flops_back={flops_backward_eval/1e9:.2f}G, params={params_eval/1e6:.2f} M")
264 |
--------------------------------------------------------------------------------
/models/arch/base/linear_group.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torch.nn import *
4 | import math
5 |
6 |
7 | class LinearGroup(nn.Module):
8 |
9 | def __init__(self, in_features: int, out_features: int, num_groups: int, bias: bool = True) -> None:
10 | super(LinearGroup, self).__init__()
11 | self.in_features = in_features
12 | self.out_features = out_features
13 | self.num_groups = num_groups
14 | self.weight = Parameter(torch.empty((num_groups, out_features, in_features)))
15 | if bias:
16 | self.bias = Parameter(torch.empty(num_groups, out_features))
17 | else:
18 | self.register_parameter('bias', None)
19 | self.reset_parameters()
20 |
21 | def reset_parameters(self) -> None:
22 | # same as linear
23 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
24 | if self.bias is not None:
25 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
26 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
27 | init.uniform_(self.bias, -bound, bound)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | """shape [..., group, feature]"""
31 | x = torch.einsum("...gh,gkh->...gk", x, self.weight)
32 | if self.bias is not None:
33 | x = x + self.bias
34 | return x
35 |
36 | def extra_repr(self) -> str:
37 | return f"{self.in_features}, {self.out_features}, num_groups={self.num_groups}, bias={True if self.bias is not None else False}"
38 |
39 |
40 | class Conv1dGroup(nn.Module):
41 |
42 | def __init__(self, in_features: int, out_features: int, num_groups: int, kernel_size: int, bias: bool = True) -> None:
43 | super(Conv1dGroup, self).__init__()
44 | self.in_features = in_features
45 | self.out_features = out_features
46 | self.num_groups = num_groups
47 | self.kernel_size = kernel_size
48 |
49 | self.weight = Parameter(torch.empty((num_groups, out_features, in_features, kernel_size)))
50 | if bias:
51 | self.bias = Parameter(torch.empty(num_groups, out_features))
52 | else:
53 | self.register_parameter('bias', None)
54 | self.reset_parameters()
55 |
56 | def reset_parameters(self) -> None:
57 | # same as linear
58 | init.kaiming_uniform_(self.weight, a=math.sqrt(5))
59 | if self.bias is not None:
60 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
61 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
62 | init.uniform_(self.bias, -bound, bound)
63 |
64 | def forward(self, x: Tensor) -> Tensor:
65 | """shape [batch, time, group, feature]"""
66 | (B, T, G, F), K = x.shape, self.kernel_size
67 | x = x.permute(0, 2, 3, 1).reshape(B * G * F, 1, 1, T) # [B*G*F,1,1,T]
68 | x = torch.nn.functional.unfold(x, kernel_size=(1, K), padding=(0, K // 2)) # [B*G*F,K,T]
69 | x = x.reshape(B, G, F, K, T)
70 | x = torch.einsum("bgfkt,gofk->btgo", x, self.weight)
71 | if self.bias is not None:
72 | x = x + self.bias
73 | return x
74 |
75 | def extra_repr(self) -> str:
76 | return f"{self.in_features}, {self.out_features}, num_groups={self.num_groups}, kernel_size={self.kernel_size}, bias={True if self.bias is not None else False}"
77 |
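78 | 
79 | if __name__ == '__main__':
80 |     # Minimal sanity check (illustrative sketch): LinearGroup applies an independent linear
81 |     # map to every group, so it should match applying the g-th weight/bias to the g-th
82 |     # group slice one group at a time.
83 |     x = torch.randn(4, 100, 8, 16)  # [batch, time, group, feature]
84 |     lg = LinearGroup(in_features=16, out_features=32, num_groups=8)
85 |     y = lg(x)  # [4, 100, 8, 32]
86 |     y_ref = torch.stack([x[..., g, :] @ lg.weight[g].T + lg.bias[g] for g in range(8)], dim=-2)
87 |     print(y.shape, torch.allclose(y, y_ref, atol=1e-5))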
--------------------------------------------------------------------------------
/models/arch/base/non_linear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor, nn
3 |
4 |
5 | class PReLU(nn.PReLU):
6 |
7 | def __init__(self, num_parameters: int = 1, init: float = 0.25, dim: int = 1, device=None, dtype=None) -> None:
8 | super().__init__(num_parameters, init, device, dtype)
9 | self.dim = dim
10 |
11 | def forward(self, input: Tensor) -> Tensor:
12 | if self.dim == 1:
13 | # [B, Chn, Feature]
14 | return super().forward(input)
15 | else:
16 | return super().forward(input.transpose(self.dim, 1)).transpose(self.dim, 1)
17 |
18 |
19 | def new_non_linear(non_linear_type: str, dim_hidden: int, seq_last: bool) -> nn.Module:
20 | if non_linear_type.lower() == 'prelu':
21 | return PReLU(num_parameters=dim_hidden, dim=1 if seq_last == True else -1)
22 | elif non_linear_type.lower() == 'silu':
23 | return nn.SiLU()
24 | elif non_linear_type.lower() == 'sigmoid':
25 | return nn.Sigmoid()
26 | elif non_linear_type.lower() == 'relu':
27 | return nn.ReLU()
28 | elif non_linear_type.lower() == 'leakyrelu':
29 | return nn.LeakyReLU()
30 | elif non_linear_type.lower() == 'elu':
31 | return nn.ELU()
32 | else:
33 | raise Exception(non_linear_type)
34 |
35 |
36 | if __name__ == '__main__':
37 | x = torch.rand(size=(12, 10, 100))
38 | prelu = new_non_linear('PReLU', 10, True)
39 | y = prelu(x)
40 | print(y.shape)
41 |
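42 |     # With seq_last=False the feature dim is the last dim, so the same factory can be
43 |     # used for [batch, time, feature] tensors as well (illustrative sketch)
44 |     x2 = torch.rand(size=(12, 100, 10))
45 |     prelu2 = new_non_linear('PReLU', 10, False)
46 |     print(prelu2(x2).shape)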
--------------------------------------------------------------------------------
/models/arch/base/norm.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torch import Tensor
7 | from torch.nn import Module
8 | from torch.nn.parameter import Parameter
9 |
10 |
11 | class LayerNorm(nn.LayerNorm):
12 |
13 | def __init__(self, seq_last: bool, **kwargs) -> None:
14 | """
15 |         Args:
16 | seq_last (bool): whether the sequence dim is the last dim
17 | """
18 | super().__init__(**kwargs)
19 | self.seq_last = seq_last
20 |
21 | def forward(self, input: Tensor) -> Tensor:
22 | if self.seq_last:
23 | input = input.transpose(-1, 1) # [B, H, Seq] -> [B, Seq, H], or [B,H,w,h] -> [B,h,w,H]
24 | o = super().forward(input)
25 | if self.seq_last:
26 | o = o.transpose(-1, 1)
27 | return o
28 |
29 |
30 | class GlobalLayerNorm(nn.Module):
31 | """gLN in convtasnet"""
32 |
33 | def __init__(self, dim_hidden: int, seq_last: bool, eps: float = 1e-5) -> None:
34 | super().__init__()
35 | self.dim_hidden = dim_hidden
36 | self.seq_last = seq_last
37 | self.eps = eps
38 |
39 | if seq_last:
40 | self.weight = Parameter(torch.empty([dim_hidden, 1]))
41 | self.bias = Parameter(torch.empty([dim_hidden, 1]))
42 | else:
43 | self.weight = Parameter(torch.empty([dim_hidden]))
44 | self.bias = Parameter(torch.empty([dim_hidden]))
45 | init.ones_(self.weight)
46 | init.zeros_(self.bias)
47 |
48 | def forward(self, input: Tensor) -> Tensor:
49 | """
50 | Args:
51 | input (Tensor): shape [B, Seq, H] or [B, H, Seq]
52 | """
53 | var, mean = torch.var_mean(input, dim=(1, 2), unbiased=False, keepdim=True)
54 |
55 | output = (input - mean) / torch.sqrt(var + self.eps)
56 | output = output * self.weight + self.bias
57 | return output
58 |
59 | def extra_repr(self) -> str:
60 | return '{dim_hidden}, seq_last={seq_last}, eps={eps}'.format(**self.__dict__)
61 |
62 |
63 | class BatchNorm1d(nn.Module):
64 |
65 | def __init__(self, seq_last: bool, **kwargs) -> None:
66 | super().__init__()
67 | self.seq_last = seq_last
68 | self.bn = nn.BatchNorm1d(**kwargs)
69 |
70 | def forward(self, input: Tensor) -> Tensor:
71 | if not self.seq_last:
72 | input = input.transpose(-1, -2) # [B, Seq, H] -> [B, H, Seq]
73 | o = self.bn.forward(input) # accepts [B, H, Seq]
74 | if not self.seq_last:
75 | o = o.transpose(-1, -2)
76 | return o
77 |
78 |
79 | class GroupNorm(nn.GroupNorm):
80 |
81 | def __init__(self, seq_last: bool, **kwargs) -> None:
82 | super().__init__(**kwargs)
83 | self.seq_last = seq_last
84 |
85 | def forward(self, input: Tensor) -> Tensor:
86 | if self.seq_last == False:
87 | input = input.transpose(-1, 1) # [B, ..., H] -> [B, H, ...]
88 | o = super().forward(input) # accepts [B, H, ...]
89 | if self.seq_last == False:
90 | o = o.transpose(-1, 1)
91 | return o
92 |
93 |
94 | class GroupBatchNorm(Module):
95 | """Applies Group Batch Normalization over a group of inputs
96 |
97 | This layer uses statistics computed from input data in both training and
98 | evaluation modes.
99 |
100 | see: `Changsheng Quan, Xiaofei Li. NBC2: Multichannel Speech Separation with Revised Narrow-band Conformer. arXiv:2212.02076.`
101 |
102 | """
103 |
104 | dim_hidden: int
105 | group_size: int
106 | eps: float
107 | affine: bool
108 | seq_last: bool
109 | share_along_sequence_dim: bool
110 |
111 | def __init__(
112 | self,
113 | dim_hidden: int,
114 | group_size: Optional[int],
115 | share_along_sequence_dim: bool = False,
116 | seq_last: bool = False,
117 | affine: bool = True,
118 | eps: float = 1e-5,
119 | dims_norm: List[int] = None,
120 | dim_affine: int = None,
121 | ) -> None:
122 | """
123 | Args:
124 | dim_hidden (int): hidden dimension
125 | group_size (int): the size of group, optional
126 | share_along_sequence_dim (bool): share statistics along the sequence dimension. Defaults to False.
127 | seq_last (bool): whether the shape of input is [B, Seq, H] or [B, H, Seq]. Defaults to False, i.e. [B, Seq, H].
128 | affine (bool): affine transformation. Defaults to True.
129 | eps (float): Defaults to 1e-5.
130 |             dims_norm: the dims to normalize over (if given, this overrides the group-based normalization)
131 |             dim_affine: the (negative) dim along which the affine transformation is applied
132 | """
133 | super(GroupBatchNorm, self).__init__()
134 |
135 | self.dim_hidden = dim_hidden
136 | self.group_size = group_size
137 | self.eps = eps
138 | self.affine = affine
139 | self.seq_last = seq_last
140 | self.share_along_sequence_dim = share_along_sequence_dim
141 | if self.affine:
142 | if seq_last:
143 | weight = torch.empty([dim_hidden, 1])
144 | bias = torch.empty([dim_hidden, 1])
145 | else:
146 |                 weight = torch.empty([dim_hidden])
147 |                 bias = torch.empty([dim_hidden])
148 |
149 |             assert (dims_norm is None) == (dim_affine is None), (dims_norm, dim_affine, 'dims_norm and dim_affine should be given or omitted together')
150 | self.dims_norm, self.dim_affine = dims_norm, dim_affine
151 | if dim_affine is not None:
152 | assert dim_affine < 0, dim_affine
153 | weight = weight.squeeze()
154 | bias = bias.squeeze()
155 | while dim_affine < -1:
156 | weight = weight.unsqueeze(-1)
157 | bias = bias.unsqueeze(-1)
158 | dim_affine += 1
159 |
160 | self.weight = Parameter(weight)
161 | self.bias = Parameter(bias)
162 | self.reset_parameters()
163 |
164 | def reset_parameters(self) -> None:
165 | if self.affine:
166 | init.ones_(self.weight)
167 | init.zeros_(self.bias)
168 |
169 | def forward(self, x: Tensor, group_size: int = None) -> Tensor:
170 | """
171 | Args:
172 | x: shape [B, Seq, H] if seq_last=False, else shape [B, H, Seq] , where B = num of groups * group size.
173 | group_size: the size of one group. if not given anywhere, the input must be 4-dim tensor with shape [B, group_size, Seq, H] or [B, group_size, H, Seq]
174 | """
175 | if self.group_size != None:
176 | assert group_size == None or group_size == self.group_size, (group_size, self.group_size)
177 | group_size = self.group_size
178 |
179 | if group_size is not None:
180 |             assert x.shape[0] % group_size == 0, f'batch size {x.shape[0]} is not divisible by group size {group_size}'
181 |
182 | original_shape = x.shape
183 | if self.dims_norm is not None:
184 | var, mean = torch.var_mean(x, dim=self.dims_norm, unbiased=False, keepdim=True)
185 | output = (x - mean) / torch.sqrt(var + self.eps)
186 | if self.affine:
187 | output = output * self.weight + self.bias
188 | elif self.seq_last == False:
189 | if x.ndim == 4:
190 | assert group_size is None or group_size == x.shape[1], (group_size, x.shape)
191 | B, group_size, Seq, H = x.shape
192 | else:
193 | B, Seq, H = x.shape
194 | x = x.reshape(B // group_size, group_size, Seq, H)
195 |
196 | if self.share_along_sequence_dim:
197 | var, mean = torch.var_mean(x, dim=(1, 2, 3), unbiased=False, keepdim=True)
198 | else:
199 | var, mean = torch.var_mean(x, dim=(1, 3), unbiased=False, keepdim=True)
200 |
201 | output = (x - mean) / torch.sqrt(var + self.eps)
202 | if self.affine:
203 | output = output * self.weight + self.bias
204 |
205 | output = output.reshape(original_shape)
206 | else:
207 | if x.ndim == 4:
208 | assert group_size is None or group_size == x.shape[1], (group_size, x.shape)
209 | B, group_size, H, Seq = x.shape
210 | else:
211 | B, H, Seq = x.shape
212 | x = x.reshape(B // group_size, group_size, H, Seq)
213 |
214 | if self.share_along_sequence_dim:
215 | var, mean = torch.var_mean(x, dim=(1, 2, 3), unbiased=False, keepdim=True)
216 | else:
217 | var, mean = torch.var_mean(x, dim=(1, 2), unbiased=False, keepdim=True)
218 |
219 | output = (x - mean) / torch.sqrt(var + self.eps)
220 | if self.affine:
221 | output = output * self.weight + self.bias
222 |
223 | output = output.reshape(original_shape)
224 |
225 | return output
226 |
227 | def extra_repr(self) -> str:
228 | return '{dim_hidden}, {group_size}, share_along_sequence_dim={share_along_sequence_dim}, seq_last={seq_last}, eps={eps}, ' \
229 | 'affine={affine}'.format(**self.__dict__)
230 |
231 |
232 | def new_norm(norm_type: str, dim_hidden: int, seq_last: bool, group_size: int = None, num_groups: int = None, dims_norm: List[int] = None, dim_affine: int = None) -> nn.Module:
233 | if norm_type.upper() == 'LN':
234 | norm = LayerNorm(normalized_shape=dim_hidden, seq_last=seq_last)
235 | elif norm_type.upper() == 'GBN':
236 | norm = GroupBatchNorm(dim_hidden=dim_hidden, seq_last=seq_last, group_size=group_size, share_along_sequence_dim=False, dims_norm=dims_norm, dim_affine=dim_affine)
237 | elif norm_type == 'GBNShare':
238 | norm = GroupBatchNorm(dim_hidden=dim_hidden, seq_last=seq_last, group_size=group_size, share_along_sequence_dim=True, dims_norm=dims_norm, dim_affine=dim_affine)
239 | elif norm_type.upper() == 'BN':
240 | norm = BatchNorm1d(num_features=dim_hidden, seq_last=seq_last)
241 | elif norm_type.upper() == 'GN':
242 | norm = GroupNorm(num_groups=num_groups, num_channels=dim_hidden, seq_last=seq_last)
243 |     elif norm_type == 'gLN':
244 | norm = GlobalLayerNorm(dim_hidden, seq_last=seq_last)
245 | else:
246 | raise Exception(norm_type)
247 | return norm
248 |
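249 | 
250 | if __name__ == '__main__':
251 |     # Minimal sanity check (illustrative sketch): with dims_norm=[1, 3] and dim_affine=-1,
252 |     # every (batch, time) position of a [B, group, Seq, H] input is normalized over the
253 |     # group and hidden dims, so the printed mean should be ~0 and the variance ~1.
254 |     gbn = GroupBatchNorm(dim_hidden=16, group_size=None, seq_last=False, dims_norm=[1, 3], dim_affine=-1)
255 |     x = torch.randn(3, 2, 50, 16)  # [B, group, Seq, H]
256 |     y = gbn(x)
257 |     print(y.shape, y.mean(dim=(1, 3)).abs().max().item(), y.var(dim=(1, 3), unbiased=False).mean().item())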
--------------------------------------------------------------------------------
/models/arch/blstm2_fc1.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch import Tensor
3 | from typing import Optional, Tuple
4 |
5 |
6 | class BLSTM2_FC1(nn.Module):
7 |
8 | def __init__(
9 | self,
10 | dim_input: int,
11 | dim_output: int,
12 | activation: Optional[str] = "",
13 | hidden_size: Tuple[int, int] = (256, 128),
14 | n_repeat_last_lstm: int = 1,
15 | dropout: Optional[float] = None,
16 | ):
17 | """Two layers of BiLSTMs & one fully connected layer
18 |
19 | Args:
20 |             dim_input: the input feature size of the first BiLSTM layer
21 |             dim_output: the output feature size of the final fully connected layer
22 | hidden_size: the hidden size of each BiLSTM layer. Defaults to (256, 128).
23 | """
24 |
25 | super().__init__()
26 |
27 | self.input_size = dim_input
28 | self.output_size = dim_output
29 | self.hidden_size = hidden_size
30 | self.activation = activation
31 | self.dropout = dropout
32 |
33 | self.blstm1 = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size[0], batch_first=True, bidirectional=True) # type:ignore
34 | self.blstm2 = nn.LSTM(input_size=self.hidden_size[0] * 2, hidden_size=self.hidden_size[1], batch_first=True, bidirectional=True, num_layers=n_repeat_last_lstm) # type:ignore
35 | if dropout is not None:
36 | self.dropout1 = nn.Dropout(p=dropout)
37 | self.dropout2 = nn.Dropout(p=dropout)
38 |
39 | self.linear = nn.Linear(self.hidden_size[1] * 2, self.output_size) # type:ignore
40 | if self.activation is not None and len(self.activation) > 0: # type:ignore
41 | self.activation_func = getattr(nn, self.activation)() # type:ignore
42 | else:
43 | self.activation_func = None
44 |
45 | def forward(self, x: Tensor) -> Tensor:
46 | """forward
47 |
48 | Args:
49 | x: shape [batch, num_freqs, seq, input_size]
50 |
51 | Returns:
52 | Tensor: shape [batch, num_freqs, seq, output_size]
53 | """
54 | # x: [Batch, NumFreqs, Time, Feature]
55 | B, F, T, H = x.shape
56 | x = x.reshape(B * F, T, H)
57 | x, _ = self.blstm1(x)
58 | if self.dropout:
59 | x = self.dropout1(x)
60 | x, _ = self.blstm2(x)
61 | if self.dropout:
62 | x = self.dropout2(x)
63 | if self.activation_func is not None:
64 | y = self.activation_func(self.linear(x))
65 | else:
66 | y = self.linear(x)
67 |
68 | y = y.reshape(B, F, T, -1)
69 | return y
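70 | 
71 | 
72 | if __name__ == '__main__':
73 |     # Minimal usage sketch (illustrative): the narrow-band BLSTM stack is applied
74 |     # independently to every frequency of a [batch, freq, time, feature] input.
75 |     import torch
76 |     m = BLSTM2_FC1(dim_input=12, dim_output=4)
77 |     x = torch.randn(2, 129, 251, 12)
78 |     print(m(x).shape)  # torch.Size([2, 129, 251, 4])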
--------------------------------------------------------------------------------
/models/io/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Audio-WestlakeU/NBSS/cc42fc8ad2e6642c09b8f4169a85b4766dc22b7e/models/io/__init__.py
--------------------------------------------------------------------------------
/models/io/cirm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | EPSILON = np.finfo(np.float32).eps
5 |
6 |
7 | def build_complex_ideal_ratio_mask(noisy: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
8 | """Build the complex ratio mask.
9 |
10 | Args:
11 | noisy: [..., F, T], noisy complex-valued stft coefficients
12 | clean: [..., F, T], clean complex-valued stft coefficients
13 |
14 | References:
15 | https://ieeexplore.ieee.org/document/7364200
16 |
17 | Returns:
18 |         [..., F, T], the compressed cIRM as a complex-valued tensor
19 | """
20 | noisy_real, noisy_imag = noisy.real, noisy.imag
21 | clean_real, clean_imag = clean.real, clean.imag
22 |
23 | denominator = torch.square(noisy_real) + torch.square(noisy_imag) + EPSILON
24 |
25 | mask_real = (noisy_real * clean_real + noisy_imag * clean_imag) / denominator
26 | mask_imag = (noisy_real * clean_imag - noisy_imag * clean_real) / denominator
27 |
28 | complex_ratio_mask = torch.stack((mask_real, mask_imag), dim=-1)
29 |
30 | cirm = compress_cIRM(complex_ratio_mask, K=10, C=0.1)
31 | cirm = torch.view_as_complex(cirm)
32 | return cirm
33 |
34 |
35 | def compress_cIRM(mask, K=10, C=0.1):
36 | """Compress the value of cIRM from (-inf, +inf) to [-K ~ K].
37 |
38 | References:
39 | https://ieeexplore.ieee.org/document/7364200
40 | """
41 | if torch.is_tensor(mask):
42 | mask = -100 * (mask <= -100) + mask * (mask > -100)
43 | mask = K * (1 - torch.exp(-C * mask)) / (1 + torch.exp(-C * mask))
44 | else:
45 | mask = -100 * (mask <= -100) + mask * (mask > -100)
46 | mask = K * (1 - np.exp(-C * mask)) / (1 + np.exp(-C * mask))
47 | return mask
48 |
49 |
50 | def decompress_cIRM(mask, K=10, limit=9.9):
51 | """Decompress cIRM from [-K ~ K] to [-inf, +inf].
52 |
53 | Args:
54 | mask: cIRM mask
55 | K: default 10
56 |         limit: default 9.9
57 |
58 | References:
59 | https://ieeexplore.ieee.org/document/7364200
60 | """
61 | mask = torch.view_as_real(mask)
62 | mask = (limit * (mask >= limit) - limit * (mask <= -limit) + mask * (torch.abs(mask) < limit))
63 | mask = -K * torch.log((K - mask) / (K + mask))
64 | return torch.view_as_complex(mask)
65 |
66 |
67 | def complex_mul(noisy_r, noisy_i, mask_r, mask_i):
68 | r = noisy_r * mask_r - noisy_i * mask_i
69 | i = noisy_r * mask_i + noisy_i * mask_r
70 | return r, i
71 |
72 |
73 | def complex_mul_v2(noisy, mask):
74 | return noisy * mask
75 |
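76 | 
77 | if __name__ == '__main__':
78 |     # Minimal round-trip check (illustrative sketch): build the compressed cIRM from
79 |     # noisy/clean STFT coefficients, decompress it, and apply it back to the noisy STFT.
80 |     # A purely scaled target gives a mask of ~0.6 everywhere, well inside the compression range.
81 |     noisy = torch.randn(2, 129, 251, dtype=torch.complex64)
82 |     clean = 0.6 * noisy
83 |     cirm = build_complex_ideal_ratio_mask(noisy=noisy, clean=clean)  # compressed, complex, [..., F, T]
84 |     est = noisy * decompress_cIRM(cirm)
85 |     print(cirm.shape, torch.allclose(est, clean, atol=1e-3))  # True: decompression inverts the compression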
--------------------------------------------------------------------------------
/models/io/loss.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 | import torch
3 | from torch import nn
4 |
5 | from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio as si_sdr
6 | from torchmetrics.functional.audio import signal_noise_ratio as snr
7 | from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio as sa_sdr
8 | from torchmetrics.functional.audio import permutation_invariant_training as pit
9 | from torchmetrics.functional.audio import pit_permutate as permutate
10 | from typing import *
11 | from models.io.cirm import build_complex_ideal_ratio_mask, decompress_cIRM
12 | from models.io.stft import STFT
13 |
14 |
15 | def neg_sa_sdr(preds: Tensor, target: Tensor, scale_invariant: bool = False) -> Tensor:
16 | batch_size = target.shape[0]
17 | sa_sdr_val = sa_sdr(preds=preds, target=target, scale_invariant=scale_invariant)
18 | return -torch.mean(sa_sdr_val.view(batch_size, -1), dim=1)
19 |
20 |
21 | def neg_si_sdr(preds: Tensor, target: Tensor) -> Tensor:
22 | """calculate neg_si_sdr loss for a batch
23 |
24 | Returns:
25 | loss: shape [batch], real
26 | """
27 | batch_size = target.shape[0]
28 | si_sdr_val = si_sdr(preds=preds, target=target)
29 | return -torch.mean(si_sdr_val.view(batch_size, -1), dim=1)
30 |
31 |
32 | def neg_snr(preds: Tensor, target: Tensor) -> Tensor:
33 | """calculate neg_snr loss for a batch
34 |
35 | Returns:
36 | loss: shape [batch], real
37 | """
38 | batch_size = target.shape[0]
39 | snr_val = snr(preds=preds, target=target)
40 | return -torch.mean(snr_val.view(batch_size, -1), dim=1)
41 |
42 |
43 | def _mse(preds: Tensor, target: Tensor) -> Tensor:
44 | """calculate mse loss for a batch
45 |
46 | Returns:
47 | loss: shape [batch], real
48 | """
49 | batch_size = target.shape[0]
50 | diff = preds - target
51 | diff = diff.view(batch_size, -1)
52 | mse_val = torch.mean(diff**2, dim=1)
53 | return mse_val
54 |
55 |
56 | def cirm_mse(preds: Tensor, target: Tensor) -> Tensor:
57 | """calculate mse loss for a batch of cirms
58 |
59 | Returns:
60 | loss: shape [batch], real
61 | """
62 | return _mse(preds=preds, target=target)
63 |
64 |
65 | def cc_mse(preds: Tensor, target: Tensor) -> Tensor:
66 | """calculate mse loss for a batch of STFT coefficients
67 |
68 | Returns:
69 | loss: shape [batch], real
70 | """
71 | return _mse(preds=preds, target=target)
72 |
73 |
74 | class Loss(nn.Module):
75 | is_scale_invariant_loss: bool
76 | name: str
77 |     mask: Optional[str]
78 |
79 | def __init__(self, loss_func: Callable, pit: bool, loss_func_kwargs: Dict[str, Any] = dict()):
80 | super().__init__()
81 |
82 | self.loss_func = loss_func
83 | self.pit = pit
84 | self.loss_func_kwargs = loss_func_kwargs
85 | self.is_scale_invariant_loss = {
86 | neg_sa_sdr: True if 'scale_invariant' in loss_func_kwargs and loss_func_kwargs['scale_invariant'] == True else False,
87 | neg_si_sdr: True,
88 | neg_snr: False,
89 | cirm_mse: False,
90 | cc_mse: False,
91 | }[loss_func]
92 | self.name = loss_func.__name__
93 | self.mask = 'cirm' if self.loss_func == cirm_mse else None
94 |
95 | def forward(self, yr_hat: Tensor, yr: Tensor, reorder: bool = None, reduce_batch: bool = True, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
96 | if self.mask is not None:
97 | out, Xr, stft = kwargs['out'], kwargs['Xr'], kwargs['stft']
98 | Yr, _ = stft.stft(yr)
99 | preds, target = out, self.to_mask(Yr=Yr, Xr=Xr)
100 | preds, target = torch.view_as_real(preds), torch.view_as_real(target)
101 | elif self.loss_func == cc_mse:
102 | out, XrMM, stft = kwargs['out'], kwargs['XrMM'], kwargs['stft']
103 | Yr, _ = stft.stft(yr)
104 | Yr = Yr / XrMM
105 | preds, target = torch.view_as_real(out), torch.view_as_real(Yr)
106 | else:
107 | preds, target = yr_hat, yr
108 |
109 | perms = None
110 | if self.pit:
111 | losses, perms = pit(preds=preds, target=target, metric_func=self.loss_func, eval_func='min', mode="permutation-wise", **self.loss_func_kwargs)
112 | else:
113 | losses = self.loss_func(preds=preds, target=target, **self.loss_func_kwargs)
114 |
115 | if reorder and perms is not None:
116 | yr_hat = permutate(yr_hat, perm=perms)
117 |
118 | return losses.mean() if reduce_batch else losses, perms, yr_hat
119 |
120 |     def to_CC(self, out: Tensor, Xr: Tensor, stft: STFT, XrMM: Tensor) -> Tuple[Tensor, Dict[str, Any]]:
121 | if self.loss_func == cirm_mse:
122 | cIRM = decompress_cIRM(mask=out)
123 | Yr = Xr * cIRM
124 | return Yr, {'out': out, 'Xr': Xr, 'stft': stft, 'XrMM': XrMM}
125 | else:
126 | return out, {'out': out, 'Xr': Xr, 'stft': stft, 'XrMM': XrMM}
127 |
128 | def to_mask(self, Yr: Tensor, Xr: Tensor):
129 | if self.mask == 'cirm':
130 | return build_complex_ideal_ratio_mask(noisy=Xr, clean=Yr)
131 | else:
132 | raise Exception(f'not implemented for mask type {self.mask}')
133 |
134 | def extra_repr(self) -> str:
135 | kwargs = ""
136 | for k, v in self.loss_func_kwargs.items():
137 | kwargs += f'{k}={v},'
138 |
139 | return f"loss_func={self.loss_func.__name__}({kwargs}), pit={self.pit}, mask={self.mask}"
140 |
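141 | 
142 | if __name__ == '__main__':
143 |     # Minimal usage sketch (illustrative): permutation-invariant SI-SDR loss on a batch of
144 |     # two-speaker waveform estimates; reorder=True also returns the estimates re-ordered
145 |     # to the best permutation found by PIT.
146 |     loss_fn = Loss(loss_func=neg_si_sdr, pit=True)
147 |     yr_hat = torch.randn(4, 2, 8000)  # [batch, spk, time]
148 |     yr = torch.randn(4, 2, 8000)
149 |     loss, perms, yr_hat_reordered = loss_fn(yr_hat, yr, reorder=True)
150 |     print(loss_fn, loss.item(), perms.shape, yr_hat_reordered.shape)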
--------------------------------------------------------------------------------
/models/io/norm.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch import Tensor
3 | from typing import *
4 |
5 | import torch
6 |
7 |
8 | def forgetting_normalization(XrMag: Tensor, sliding_window_len: int = 192) -> Tensor:
9 | # https://github.com/Audio-WestlakeU/FullSubNet/blob/e97448375cd1e883276ad583317b1828318910dc/audio_zen/model/base_model.py#L103C19-L103C19
10 | alpha = (sliding_window_len - 1) / (sliding_window_len + 1)
11 | mu = 0
12 | mu_list = []
13 | B, _, F, T = XrMag.shape
14 | XrMM = XrMag.mean(dim=2, keepdim=True).detach().cpu() # [B,1,1,T]
15 | for t in range(T):
16 | if t < sliding_window_len:
17 | alpha_this = min((t - 1) / (t + 1), alpha)
18 | else:
19 | alpha_this = alpha
20 | mu = alpha_this * mu + (1 - alpha_this) * XrMM[..., t]
21 | mu_list.append(mu)
22 |
23 | XrMM = torch.stack(mu_list, dim=-1).to(XrMag.device)
24 | return XrMM
25 |
26 |
27 | # Yujie's implementation
28 | # def cumulative_normalization(original_signal_mag: Tensor, sliding_window_len: int = 192) -> Tensor:
29 | # alpha = (sliding_window_len - 1) / (sliding_window_len + 1)
30 | # eps = 1e-10
31 | # mu = 0
32 | # mu_list = []
33 | # batch_size, frame_num, freq_num = original_signal_mag.shape
34 | # for frame_idx in range(frame_num):
35 | # if frame_idx < sliding_window_len:
36 | # alp = torch.min(torch.tensor([(frame_idx - 1) / (frame_idx + 1), alpha]))
37 | # mu = alp * mu + (1 - alp) * torch.mean(original_signal_mag[:, frame_idx, :], dim=-1).reshape(batch_size, 1)
38 | # else:
39 | # current_frame_mu = torch.mean(original_signal_mag[:, frame_idx, :], dim=-1).reshape(batch_size, 1)
40 | # mu = alpha * mu + (1 - alpha) * current_frame_mu
41 | # mu_list.append(mu)
42 |
43 | # XrMM = torch.stack(mu_list, dim=-1).permute(0, 2, 1).reshape(batch_size, frame_num, 1, 1)
44 | # return XrMM
45 |
46 |
47 | class Norm(nn.Module):
48 |
49 | def __init__(self, mode: Optional[Literal['utterance', 'frequency', 'forgetting', 'none']], online: bool = True) -> None:
50 | super().__init__()
51 | self.mode = mode
52 | self.online = online
53 |         assert mode != 'forgetting' or online == True, 'forgetting is an online normalization, so online must be True'
54 |
55 | def forward(self, X: Tensor, norm_paras: Any = None, inverse: bool = False) -> Any:
56 | if not inverse:
57 | return self.norm(X, norm_paras=norm_paras)
58 | else:
59 | return self.inorm(X, norm_paras=norm_paras)
60 |
61 | def norm(self, X: Tensor, norm_paras: Any = None, ref_channel: int = None, eps: float = 1e-6) -> Tuple[Tensor, Any]:
62 | """ normalization
63 | Args:
64 | X: [B, Chn, F, T], complex
65 | norm_paras: the paramters for inverse normalization or for the normalization of other X's
66 | eps: 1e-6!=0 when dtype=float16
67 |
68 | Returns:
69 |             the normalized tensor and the parameters for inverse normalization
70 | """
71 | if self.mode == 'none' or self.mode is None:
72 | Xr = X[:, [ref_channel], :, :].clone()
73 | return X, (Xr, None)
74 |
75 | B, C, F, T = X.shape
76 | if norm_paras is None:
77 | Xr = X[:, [ref_channel], :, :].clone() # [B,1,F,T], complex
78 |
79 | if self.mode == 'frequency':
80 | if self.online:
81 | XrMM = torch.abs(Xr) + eps # [B,1,F,T]
82 | else:
83 | XrMM = torch.abs(Xr).mean(dim=3, keepdim=True) + eps # Xr_magnitude_mean, [B,1,F,1]
84 | elif self.mode == 'forgetting':
85 | XrMM = forgetting_normalization(XrMag=torch.abs(Xr)) + eps # [B,1,1,T]
86 | else:
87 | assert self.mode == 'utterance', self.mode
88 | if self.online:
89 | XrMM = torch.abs(Xr).mean(dim=(2,), keepdim=True) + eps # Xr_magnitude_mean, [B,1,1,T]
90 | else:
91 | XrMM = torch.abs(Xr).mean(dim=(2, 3), keepdim=True) + eps # Xr_magnitude_mean, [B,1,1,1]
92 | else:
93 | Xr, XrMM = norm_paras
94 | X[:, :, :, :] /= XrMM
95 | return X, (Xr, XrMM)
96 |
97 | def inorm(self, X: Tensor, norm_paras: Any) -> Tensor:
98 | """ inverse normalization
99 | Args:
100 | X: [B, Chn, F, T], complex
101 |             norm_paras: the parameters for inverse normalization
102 | 
103 |         Returns:
104 |             the de-normalized tensor
105 | """
106 |
107 | Xr, XrMM = norm_paras
108 | return X * XrMM
109 |
110 | def extra_repr(self) -> str:
111 | return f"{self.mode}, online={self.online}"
112 |
113 |
114 | if __name__ == '__main__':
115 |
116 | # x = torch.randn((10, 1, 129, 251))
117 | # y1 = forgetting_normalization(x)
118 | # y2 = cumulative_normalization(x.squeeze().transpose(-1, -2))
119 | # print((y1.squeeze() == y2.squeeze()).all())
120 | # print(torch.allclose(y1.squeeze(), y2.squeeze()))
121 |
122 | x = torch.randn((2, 1, 129, 251), dtype=torch.complex64).cuda()
123 | # norm = Norm('forgetting')
124 | norm = Norm('utterance', online=False)
125 | for i in range(10):
126 | y = norm.norm(x, ref_channel=0)
127 | import time
128 | torch.cuda.synchronize()
129 | ts = time.time()
130 | for i in range(1):
131 | y = norm.norm(x, ref_channel=0)
132 | torch.cuda.synchronize()
133 | te = time.time()
134 | print(te - ts)
135 |
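136 |     # Round-trip sanity check (illustrative sketch): inverse normalization should restore the input
137 |     X = torch.randn((2, 1, 129, 251), dtype=torch.complex64)
138 |     X0 = X.clone()  # norm() divides X in place, so keep a copy for the comparison
139 |     Xn, paras = norm.norm(X, ref_channel=0)
140 |     Xr = norm.inorm(Xn, paras)
141 |     print(torch.allclose(Xr, X0, atol=1e-5))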
--------------------------------------------------------------------------------
/models/io/stft.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Mapping
2 | import torch
3 | from torch import nn
4 |
5 | from torch import Tensor
6 | from typing import *
7 |
8 | paras_16k = {
9 | 'n_fft': 512,
10 | 'n_hop': 256,
11 | 'win_len': 512,
12 | }
13 |
14 | paras_8k = {
15 | 'n_fft': 256,
16 | 'n_hop': 128,
17 | 'win_len': 256,
18 | }
19 |
20 |
21 | class STFT(nn.Module):
22 |
23 | def __init__(self, n_fft: int, n_hop: int, win_len: Optional[int] = None, win: str = 'hann_window') -> None:
24 | super().__init__()
25 | self.n_fft, self.n_hop, self.win_len = n_fft, n_hop, win_len if win_len is not None else n_fft
26 | self.repr = str((n_fft, n_hop, win, win_len))
27 |
28 | assert win in ['hann_window', 'sqrt_hann_window'], win
29 | if win == 'hann_window':
30 | window = torch.hann_window(self.n_fft)
31 | else:
32 | # For FT-JNF. Deep Non-linear Filters for Multi-channel Speech Enhancement and Separation
33 | assert win == 'sqrt_hann_window', win
34 | window = torch.sqrt(torch.hann_window(self.n_fft))
35 | self.register_buffer('window', window)
36 |
37 | def forward(self, X: Tensor, original_len: int = None, inverse=False) -> Any:
38 |         """stft or istft
39 | Args:
40 | X: complex [..., F, T]
41 | original_len: original length
42 | inverse: stft or istft
43 | """
44 |         if not inverse:
45 |             return self.stft(X)
46 |         else:
47 |             return self.istft(X, original_len=original_len)
48 |
49 | def stft(self, x: Tensor) -> Tuple[Tensor, int]:
50 | """stft
51 | Args:
52 | x: [..., time]
53 |
54 | Returns:
55 | the complex STFT domain representation of shape [..., freq, time] and the original length of the time domain waveform
56 | """
57 | shape = list(x.shape)
58 | x = x.reshape(-1, shape[-1])
59 | if x.is_cuda:
60 | with torch.autocast(device_type=x.device.type, dtype=torch.float32): # use float32 for stft & istft
61 | X = torch.stft(x, n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, return_complex=True)
62 | else:
63 | X = torch.stft(x, n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, return_complex=True)
64 | F, T = X.shape[-2:]
65 | X = X.reshape(shape=shape[:-1] + [F, T])
66 | return X, shape[-1]
67 |
68 | def istft(self, X: Tensor, original_len: int = None) -> Tensor:
69 | """istft
70 | Args:
71 | X: complex [..., F, T]
72 | original_len: returned by stft
73 |
74 | Returns:
75 |             the time domain waveform of shape [..., time]
76 | """
77 | shape = list(X.shape)
78 | X = X.reshape(-1, *shape[-2:])
79 | if X.is_cuda:
80 | with torch.autocast(device_type=X.device.type, dtype=torch.float32): # use float32 for stft & istft
81 | # iSTFT is problematic when batch size is larger than 16
82 | # x = torch.istft(X, n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, length=original_len)
83 | xs = []
84 | for b in range(X.shape[0]):
85 | xb = torch.istft(X[b], n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, length=original_len)
86 | xs.append(xb)
87 | x = torch.stack(xs, dim=0)
88 | else:
89 | # iSTFT is problematic when batch size is larger than 16
90 | # x = torch.istft(X, n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, length=original_len)
91 | xs = []
92 | for b in range(X.shape[0]):
93 | xb = torch.istft(X[b], n_fft=self.n_fft, hop_length=self.n_hop, win_length=self.win_len, window=self.window, length=original_len)
94 | xs.append(xb)
95 | x = torch.stack(xs, dim=0)
96 | x = x.reshape(shape=shape[:-2] + [original_len])
97 | return x
98 |
99 | def extra_repr(self) -> str:
100 | return self.repr
101 |
102 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
103 | return
104 |
105 |
106 | if __name__ == '__main__':
107 | x = torch.randn((1, 1, 8000 * 4))
108 | stft = STFT(**paras_8k)
109 | X, ol = stft.stft(x)
110 | x_p = stft.istft(X, ol)
111 | print(x.shape, x_p.shape, X.shape)
112 | print(torch.allclose(x, x_p, rtol=1e-1))
113 | print()
114 |
--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from models.utils.git_tools import tag_and_log_git_status
2 | from models.utils.my_json_encoder import MyJsonEncoder
3 | from models.utils.my_logger import MyLogger
4 | from models.utils.my_progress_bar import MyProgressBar
5 | from models.utils.my_rich_progress_bar import MyRichProgressBar
6 |
--------------------------------------------------------------------------------
/models/utils/base_cli.py:
--------------------------------------------------------------------------------
1 | """
2 | Basic Command Line Interface, provides command line controls for training, test, and inference. Be sure to import this file before `import torch`, otherwise the OMP_NUM_THREADS would not work.
3 | """
4 |
5 | import os
6 |
7 | os.environ["OMP_NUM_THREADS"] = str(2) # limit the threads to reduce cpu overloads, will speed up when there are lots of CPU cores on the running machine
8 | os.environ['OPENBLAS_NUM_THREADS'] = '2'
9 | os.environ["MKL_NUM_THREADS"] = str(2)
10 |
11 | from typing import *
12 |
13 | import torch
14 | if torch.multiprocessing.get_start_method() != 'spawn':
15 | torch.multiprocessing.set_start_method('spawn', force=True) # fix stoi stuck
16 |
17 | from models.utils import MyRichProgressBar as RichProgressBar
18 | # from pytorch_lightning.loggers import TensorBoardLogger
19 | from models.utils.my_logger import MyLogger as TensorBoardLogger
20 |
21 | from pytorch_lightning.callbacks import (LearningRateMonitor, ModelSummary)
22 | from pytorch_lightning.cli import LightningArgumentParser, LightningCLI
23 |
24 | torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 on matmul; this flag defaults to False in PyTorch 1.12 and later
25 | torch.backends.cudnn.allow_tf32 = True  # allow TF32 on cuDNN; this flag defaults to True
26 | from packaging.version import Version
27 | if Version(torch.__version__) >= Version('2.0.0'):
28 | torch._dynamo.config.optimize_ddp = False # fix this issue: https://github.com/pytorch/pytorch/issues/111279#issuecomment-1870641439
29 | torch._dynamo.config.cache_size_limit = 64
30 |
31 |
32 | class BaseCLI(LightningCLI):
33 |
34 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
35 | self.add_model_invariant_arguments_to_parser(parser)
36 |
37 | def add_model_invariant_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
38 | # RichProgressBar
39 | parser.add_lightning_class_args(RichProgressBar, nested_key='progress_bar')
40 | parser.set_defaults({"progress_bar.console_kwargs": {
41 | "force_terminal": True,
42 | "no_color": True,
43 | "width": 200,
44 | }})
45 |
46 | # LearningRateMonitor
47 | parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor")
48 | learning_rate_monitor_defaults = {
49 | "learning_rate_monitor.logging_interval": "epoch",
50 | }
51 | parser.set_defaults(learning_rate_monitor_defaults)
52 |
53 | # ModelSummary
54 | parser.add_lightning_class_args(ModelSummary, 'model_summary')
55 | model_summary_defaults = {
56 | "model_summary.max_depth": 2,
57 | }
58 | parser.set_defaults(model_summary_defaults)
59 |
60 | def before_fit(self):
61 | resume_from_checkpoint: str = self.config['fit']['ckpt_path']
62 | if resume_from_checkpoint is not None and resume_from_checkpoint.endswith('last.ckpt'):
63 | # log in same dir
64 | # resume_from_checkpoint example: /mnt/home/quancs/projects/NBSS_pmt/logs/NBSS_ifp/version_29/checkpoints/last.ckpt
65 | resume_from_checkpoint = os.path.normpath(resume_from_checkpoint)
66 | splits = resume_from_checkpoint.split(os.path.sep)
67 | version = int(splits[-3].replace('version_', ''))
68 | save_dir = os.path.sep.join(splits[:-3])
69 | self.trainer.logger = TensorBoardLogger(save_dir=save_dir, name="", version=version, default_hp_metric=False)
70 | else:
71 | model_name = self.model.name if hasattr(self.model, 'name') else type(self.model).__name__
72 | self.trainer.logger = TensorBoardLogger('logs/', name=model_name, default_hp_metric=False)
73 |
74 | def before_test(self):
75 | if self.config['test']['ckpt_path'] != None:
76 | ckpt_path = self.config['test']['ckpt_path']
77 | else:
78 | raise Exception('You should give --ckpt_path if you want to test')
79 | epoch = os.path.basename(ckpt_path).split('_')[0]
80 | write_dir = os.path.dirname(os.path.dirname(ckpt_path))
81 |
82 | test_set = 'test'
83 | if 'test_set' in self.config['test']['data']:
84 | test_set = self.config['test']['data']["test_set"]
85 | elif 'init_args' in self.config['test']['data'] and 'test_set' in self.config['test']['data']['init_args']:
86 | test_set = self.config['test']['data']['init_args']["test_set"]
87 | exp_save_path = os.path.normpath(write_dir + '/' + epoch + '_' + test_set + '_set')
88 |
89 | self.copy_ckpt(exp_save_path=exp_save_path, ckpt_path=ckpt_path)
90 |
91 | import time
92 |         # sleep 10 seconds so that all processes detect the next logger version simultaneously
93 | self.trainer.logger = TensorBoardLogger(exp_save_path, name='', default_hp_metric=False)
94 | time.sleep(10)
95 |
96 | def after_test(self):
97 | if not self.trainer.is_global_zero:
98 | return
99 | import fnmatch
100 | files = fnmatch.filter(os.listdir(self.trainer.log_dir), 'events.out.tfevents.*')
101 | for f in files:
102 | os.remove(self.trainer.log_dir + '/' + f)
103 | print('tensorboard log file for test is removed: ' + self.trainer.log_dir + '/' + f)
104 |
105 | def before_predict(self):
106 | if self.config['predict']['ckpt_path'] != None:
107 | ckpt_path = self.config['predict']['ckpt_path']
108 | else:
109 |             raise Exception('You should give --ckpt_path if you want to predict')
110 | epoch = os.path.basename(ckpt_path).split('_')[0]
111 | write_dir = os.path.dirname(os.path.dirname(ckpt_path))
112 |
113 | exp_save_path = os.path.normpath(write_dir + '/' + epoch + '_predict_set')
114 |
115 | self.copy_ckpt(exp_save_path=exp_save_path, ckpt_path=ckpt_path)
116 |
117 | import time
118 |         # sleep 10 seconds so that all processes detect the next logger version simultaneously
119 | self.trainer.logger = TensorBoardLogger(exp_save_path, name='', default_hp_metric=False)
120 | time.sleep(10)
121 |
122 | def after_predict(self):
123 | if not self.trainer.is_global_zero:
124 | return
125 | import fnmatch
126 | files = fnmatch.filter(os.listdir(self.trainer.log_dir), 'events.out.tfevents.*')
127 | for f in files:
128 | os.remove(self.trainer.log_dir + '/' + f)
129 | print('tensorboard log file for predict is removed: ' + self.trainer.log_dir + '/' + f)
130 |
131 | def copy_ckpt(self, exp_save_path: str, ckpt_path: str):
132 | # copy checkpoint to save path
133 | from pathlib import Path
134 | Path(exp_save_path).mkdir(exist_ok=True)
135 | if (Path(exp_save_path) / Path(ckpt_path).name).exists() == False:
136 | import shutil
137 | shutil.copyfile(ckpt_path, Path(exp_save_path) / Path(ckpt_path).name)
138 |
--------------------------------------------------------------------------------
/models/utils/dnsmos.py:
--------------------------------------------------------------------------------
1 | # The following DNSMOS implementation will be available in the next release of torchmetrics.
2 |
3 | import os
4 | from functools import lru_cache
5 | from typing import Any, Dict, Optional
6 |
7 | import numpy as np
8 | import torch
9 | from torch import Tensor
10 |
11 | from lightning_utilities.core.imports import RequirementCache
12 |
13 | _REQUESTS_AVAILABLE = RequirementCache("requests")
14 | _LIBROSA_AVAILABLE = RequirementCache("librosa")
15 | _ONNXRUNTIME_AVAILABLE = RequirementCache("onnxruntime")
16 |
17 | from torchmetrics.utilities import rank_zero_info
18 |
19 | if _LIBROSA_AVAILABLE and _ONNXRUNTIME_AVAILABLE and _REQUESTS_AVAILABLE:
20 | import librosa
21 | import onnxruntime as ort
22 | import requests
23 | from onnxruntime import InferenceSession
24 | else:
25 | librosa, ort, requests = None, None, None # type:ignore
26 |
27 | class InferenceSession: # type:ignore
28 | """Dummy InferenceSession."""
29 |
30 | def __init__(self, **kwargs: Dict[str, Any]) -> None:
31 | ...
32 |
33 |
34 | __doctest_requires__ = {("deep_noise_suppression_mean_opinion_score", "_load_session"): ["requests", "librosa", "onnxruntime"]}
35 |
36 | SAMPLING_RATE = 16000
37 | INPUT_LENGTH = 9.01
38 | DNSMOS_DIR = "~/.torchmetrics/DNSMOS"
39 |
40 |
41 | def _prepare_dnsmos(dnsmos_dir: str) -> None:
42 | """Download required DNSMOS files.
43 |
44 | Args:
45 | dnsmos_dir: a dir to save the downloaded files. Defaults to "~/.torchmetrics".
46 |
47 | """
48 | # https://raw.githubusercontent.com/microsoft/DNS-Challenge/master/DNSMOS/DNSMOS/model_v8.onnx
49 | # https://raw.githubusercontent.com/microsoft/DNS-Challenge/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx
50 | # https://raw.githubusercontent.com/microsoft/DNS-Challenge/master/DNSMOS/pDNSMOS/sig_bak_ovr.onnx
51 | url = "https://raw.githubusercontent.com/microsoft/DNS-Challenge/master"
52 | dnsmos_dir = os.path.expanduser(dnsmos_dir)
53 |
54 | # save to or load from ~/torchmetrics/dnsmos/.
55 | for file in ["DNSMOS/DNSMOS/model_v8.onnx", "DNSMOS/DNSMOS/sig_bak_ovr.onnx", "DNSMOS/pDNSMOS/sig_bak_ovr.onnx"]:
56 | saveto = os.path.join(dnsmos_dir, file[7:])
57 | os.makedirs(os.path.dirname(saveto), exist_ok=True)
58 | if os.path.exists(saveto):
59 | # try load onnx
60 | try:
61 | _ = InferenceSession(saveto)
62 | continue # skip downloading if succeeded
63 | except Exception as _:
64 | os.remove(saveto)
65 | urlf = f"{url}/{file}"
66 | rank_zero_info(f"downloading {urlf} to {saveto}")
67 | myfile = requests.get(urlf)
68 | with open(saveto, "wb") as f:
69 | f.write(myfile.content)
70 |
71 |
72 | @lru_cache
73 | def _load_session(
74 | path: str,
75 | device: torch.device,
76 | ) -> InferenceSession:
77 | """Load onnxruntime session.
78 |
79 | Args:
80 | path: the model path
81 | device: the device used
82 |
83 | Returns:
84 | onnxruntime session
85 |
86 | """
87 | path = os.path.expanduser(path)
88 | if not os.path.exists(path):
89 | _prepare_dnsmos(DNSMOS_DIR)
90 |
91 | opts = ort.SessionOptions()
92 | opts.inter_op_num_threads = 4
93 | opts.intra_op_num_threads = 4
94 |
95 | if device.type == "cpu":
96 | infs = InferenceSession(path, providers=["CPUExecutionProvider"], sess_options=opts)
97 | elif "CUDAExecutionProvider" in ort.get_all_providers():
98 | providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
99 | provider_options = [{"device_id": device.index}, {}]
100 | infs = InferenceSession(path, providers=providers, provider_options=provider_options, sess_options=opts)
101 | else:
102 | infs = InferenceSession(path, providers=["CPUExecutionProvider"], sess_options=opts)
103 |
104 | return infs
105 |
106 |
107 | def _audio_melspec(
108 | audio: np.ndarray,
109 | n_mels: int = 120,
110 | frame_size: int = 320,
111 | hop_length: int = 160,
112 | sr: int = 16000,
113 | to_db: bool = True,
114 | ) -> np.ndarray:
115 | """Calculate the mel-spectrogram of an audio.
116 |
117 | Args:
118 | audio: [..., T]
119 | n_mels: the number of mel-frequencies
120 | frame_size: stft length
121 | hop_length: stft hop length
122 | sr: sample rate of audio
123 | to_db: convert to dB scale if `True` is given
124 |
125 | Returns:
126 | mel-spectrogram: [..., num_mel, T']
127 |
128 | """
129 | shape = audio.shape
130 | audio = audio.reshape(-1, shape[-1])
131 | mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels)
132 | mel_spec = mel_spec.transpose(0, 2, 1)
133 | mel_spec = mel_spec.reshape(shape[:-1] + mel_spec.shape[1:])
134 | if to_db:
135 | for b in range(mel_spec.shape[0]):
136 | mel_spec[b, ...] = (librosa.power_to_db(mel_spec[b], ref=np.max) + 40) / 40
137 | return mel_spec
138 |
139 |
140 | def _polyfit_val(mos: np.ndarray, personalized: bool) -> np.ndarray:
141 | """Use polyfit to convert raw mos values to DNSMOS values.
142 |
143 | Args:
144 | mos: the raw mos values, [..., 4]
145 | personalized: whether interfering speaker is penalized
146 |
147 | Returns:
148 | DNSMOS: [..., 4]
149 |
150 | """
151 | if personalized:
152 | p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046])
153 | p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726])
154 | p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132])
155 | else:
156 | p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
157 | p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439]) # x**2*v0 + x**1*v1+ v2
158 | p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])
159 |
160 | mos[..., 1] = p_sig(mos[..., 1])
161 | mos[..., 2] = p_bak(mos[..., 2])
162 | mos[..., 3] = p_ovr(mos[..., 3])
163 | return mos
164 |
165 |
166 | def deep_noise_suppression_mean_opinion_score(preds: Tensor, fs: int, personalized: bool, device: Optional[str] = None) -> Tensor:
167 | """Calculate `Deep Noise Suppression performance evaluation based on Mean Opinion Score`_ (DNSMOS).
168 |
169 |     Human subjective evaluation is the "gold standard" to evaluate speech quality optimized for human perception.
170 | Perceptual objective metrics serve as a proxy for subjective scores. The conventional and widely used metrics
171 | require a reference clean speech signal, which is unavailable in real recordings. The no-reference approaches
172 | correlate poorly with human ratings and are not widely adopted in the research community. One of the biggest
173 | use cases of these perceptual objective metrics is to evaluate noise suppression algorithms. DNSMOS generalizes
174 | well in challenging test conditions with a high correlation to human ratings in stack ranking noise suppression
175 | methods. More details can be found in [DNSMOS paper](https://arxiv.org/pdf/2010.15258.pdf).
176 |
177 |
178 | .. note:: using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed.
179 | Install as ``pip install librosa onnxruntime-gpu requests``.
180 |
181 | Args:
182 | preds: [..., time]
183 | fs: sampling frequency
184 | personalized: whether interfering speaker is penalized
185 | device: the device used for calculating DNSMOS, can be cpu or cuda:n, where n is the index of gpu.
186 | If None is given, then the device of input is used.
187 |
188 | Returns:
189 | Float tensor with shape ``(..., 4)`` of DNSMOS values per sample, i.e. [p808_mos, mos_sig, mos_bak, mos_ovr]
190 |
191 | Raises:
192 | ModuleNotFoundError:
193 | If ``librosa``, ``onnxruntime`` or ``requests`` packages are not installed
194 |
195 | Example:
196 | >>> from torch import randn
197 | >>> from torchmetrics.functional.audio.dnsmos import deep_noise_suppression_mean_opinion_score
198 | >>> g = torch.manual_seed(1)
199 | >>> preds = randn(8000)
200 | >>> deep_noise_suppression_mean_opinion_score(preds, 8000, False)
201 | tensor([2.2285, 2.1132, 1.3972, 1.3652], dtype=torch.float64)
202 |
203 | """
204 | if not _LIBROSA_AVAILABLE or not _ONNXRUNTIME_AVAILABLE or not _REQUESTS_AVAILABLE:
205 | raise ModuleNotFoundError("DNSMOS metric requires that librosa, onnxruntime and requests are installed."
206 | " Install as `pip install librosa onnxruntime-gpu requests`.")
207 | device = torch.device(device) if device is not None else preds.device
208 |
209 | onnx_sess = _load_session(f"{DNSMOS_DIR}/{'p' if personalized else ''}DNSMOS/sig_bak_ovr.onnx", device)
210 | p808_onnx_sess = _load_session(f"{DNSMOS_DIR}/DNSMOS/model_v8.onnx", device)
211 |
212 | desired_fs = SAMPLING_RATE
213 | if fs != desired_fs:
214 | audio = librosa.resample(preds.cpu().numpy(), orig_sr=fs, target_sr=desired_fs)
215 | else:
216 | audio = preds.cpu().numpy()
217 |
218 | # normalize audio
219 | audio = audio / np.max(np.abs(audio),axis=-1,keepdims=True)
220 |
221 | len_samples = int(INPUT_LENGTH * desired_fs)
222 | while audio.shape[-1] < len_samples:
223 | audio = np.concatenate([audio, audio], axis=-1)
224 |
225 |
226 | num_hops = int(np.floor(audio.shape[-1] / desired_fs) - INPUT_LENGTH) + 1
227 |
228 | moss = []
229 | hop_len_samples = desired_fs
230 | for idx in range(num_hops):
231 | audio_seg = audio[..., int(idx * hop_len_samples):int((idx + INPUT_LENGTH) * hop_len_samples)]
232 | if audio_seg.shape[-1] < len_samples:
233 | continue
234 | shape = audio_seg.shape
235 | audio_seg = audio_seg.reshape((-1, shape[-1]))
236 |
237 | input_features = np.array(audio_seg).astype("float32")
238 | p808_input_features = np.array(_audio_melspec(audio=audio_seg[..., :-160])).astype("float32")
239 |
240 | if device.type != "cpu" and "CUDAExecutionProvider" in ort.get_all_providers():
241 | input_features = ort.OrtValue.ortvalue_from_numpy(input_features, device.type, device.index)
242 | p808_input_features = ort.OrtValue.ortvalue_from_numpy(p808_input_features, device.type, device.index)
243 |
244 | oi = {"input_1": input_features}
245 | p808_oi = {"input_1": p808_input_features}
246 | mos_np = np.concatenate([p808_onnx_sess.run(None, p808_oi)[0], onnx_sess.run(None, oi)[0]], axis=-1, dtype="float64")
247 | mos_np = _polyfit_val(mos_np, personalized)
248 |
249 | mos_np = mos_np.reshape(shape[:-1] + (4, ))
250 | moss.append(mos_np)
251 | return torch.from_numpy(np.mean(np.stack(moss, axis=-1), axis=-1)) # [p808_mos, mos_sig, mos_bak, mos_ovr]
252 |
--------------------------------------------------------------------------------
/models/utils/ensemble.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import *
3 | import torch
4 |
5 |
6 | def ensemble(opts: Union[int, str, List[str]], ckpt: str) -> Tuple[List[str], Dict]:
7 | """ensemble checkpoints
8 |
9 | Args:
10 |         opts: ensemble last N epochs if opts is int; ensemble globbed checkpoints if opts is str; ensemble specified checkpoints if opts is a list.
11 | ckpt: the current checkpoint path
12 |
13 | Returns:
14 | ckpts: the checkpoints ensembled
15 | state_dict
16 | """
17 | # parse the ensemble args to obtains the ckpts to ensemble
18 | if isinstance(opts, int):
19 | ckpts = []
20 | if opts > 0:
21 | epoch = int(Path(ckpt).name.split('_')[0].replace('epoch', ''))
22 | for epc in range(max(0, epoch - opts), epoch, 1):
23 | path = list(Path(ckpt).parent.glob(f'epoch{epc}_*'))[0]
24 | ckpts.append(path)
25 | elif isinstance(opts, list):
26 | assert len(opts) > 0, opts
27 | ckpts = list(set(opts))
28 | else: # e.g. logs/SSFNetLM/version_100/checkpoints/epoch* or epoch*
29 | assert isinstance(opts, str), opts
30 | ckpts = list(Path(opts).parent.glob(Path(opts).name))
31 | if len(ckpts) == 0:
32 | ckpts = list(Path(ckpt).parent.glob(opts))
33 | assert len(ckpts) > 0, f"checkpoints not found in {opts} or {Path(ckpt).parent/opts}"
34 | ckpts = ckpts + [ckpt]
35 |
36 | # remove redundant ckpt
37 | ckpts_ = dict()
38 | for ckpt in ckpts:
39 | ckpts_[Path(ckpt).name] = str(ckpt)
40 | ckpts = list(ckpts_.values())
41 | ckpts.sort()
42 |
43 | # load weights from checkpoints
44 | state_dict = dict()
45 | for path in ckpts:
46 | data = torch.load(path, map_location='cpu')
47 | for k, v in data['state_dict'].items():
48 | if k in state_dict:
49 | state_dict[k] += (v / len(ckpts))
50 | else:
51 | state_dict[k] = (v / len(ckpts))
52 | return ckpts, state_dict
53 |
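54 | 
55 | if __name__ == '__main__':
56 |     # Minimal usage sketch (illustrative): average the weights of two toy checkpoints.
57 |     # The temporary file names and the parameter name below are made up for this demo.
58 |     import os
59 |     import tempfile
60 |     tmp = tempfile.mkdtemp()
61 |     for name, w in [('epoch0_a.ckpt', 0.0), ('epoch1_b.ckpt', 1.0)]:
62 |         torch.save({'state_dict': {'layer.weight': torch.full((2, 2), w)}}, os.path.join(tmp, name))
63 |     ckpts, state_dict = ensemble(opts=[os.path.join(tmp, 'epoch0_a.ckpt')], ckpt=os.path.join(tmp, 'epoch1_b.ckpt'))
64 |     print(ckpts, state_dict['layer.weight'])  # the two checkpoints are averaged: all entries are 0.5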
--------------------------------------------------------------------------------
/models/utils/flops.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | import torch
5 | import yaml
6 | from jsonargparse import ArgumentParser
7 | from pytorch_lightning import LightningModule
8 | from argparse import ArgumentParser as Parser
9 | import traceback
10 | from typing import *
11 |
12 | import operator
13 |
14 | from lightning_utilities.core.imports import compare_version
15 |
16 | _TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
17 |
18 |
19 | # this function is ported from lightning
20 | def measure_flops(
21 | model: torch.nn.Module,
22 | forward_fn: Callable[[], torch.Tensor],
23 | loss_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
24 | total: bool = True,
25 | ) -> int:
26 | """Utility to compute the total number of FLOPs used by a module during training or during inference.
27 |
28 | It's recommended to create a meta-device model for this, because:
29 | 1) the flops of LSTM cannot be measured if the model is not a meta-device model:
30 |
31 | Example::
32 |
33 | with torch.device("meta"):
34 | model = MyModel()
35 | x = torch.randn(2, 32)
36 |
37 | model_fwd = lambda: model(x)
38 | fwd_flops = measure_flops(model, model_fwd)
39 |
40 | model_loss = lambda y: y.sum()
41 | fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)
42 |
43 | Args:
44 | model: The model whose FLOPs should be measured.
45 | forward_fn: A function that runs ``forward`` on the model and returns the result.
46 | loss_fn: A function that computes the loss given the ``forward_fn`` output. If provided, the loss and `backward`
47 | FLOPs will be included in the result.
48 |
49 | """
50 | if not _TORCH_GREATER_EQUAL_2_1:
51 | raise ImportError("`measure_flops` requires PyTorch >= 2.1.")
52 | from torch.utils.flop_counter import FlopCounterMode
53 |
54 | flop_counter = FlopCounterMode(model, display=False)
55 | with flop_counter:
56 | if loss_fn is None:
57 | forward_fn()
58 | else:
59 | loss_fn(forward_fn()).backward()
60 | if total:
61 | return flop_counter.get_total_flops()
62 | else:
63 | return flop_counter
64 |
65 |
66 | def detailed_flops(flop_counter) -> str:
67 | sss = ""
68 | for k, v in flop_counter.get_flop_counts().items():
69 | ss = f"{k}: {{"
70 | for kk, vv in v.items():
71 | ss += f" {str(kk)}:{vv}"
72 | ss += " }\n"
73 | sss += ss
74 | return sss
75 |
76 |
77 | class FakeModule(torch.nn.Module):
78 |
79 | def __init__(self, module: LightningModule) -> None:
80 | super().__init__()
81 | self.module = module
82 |
83 | def forward(self, x):
84 | return self.module.predict_step(x, 0)
85 |
86 |
87 | def _get_num_params(model: torch.nn.Module):
88 | num_params = sum(param.numel() for param in model.parameters())
89 | return num_params
90 |
91 |
92 | def _test_FLOPs(model: LightningModule, save_dir: str, num_chns: int, fs: int, audio_time_len: int = 4, num_params: int = None):
93 | if _TORCH_GREATER_EQUAL_2_1:
94 | x = torch.randn(1, num_chns, int(fs * audio_time_len), dtype=torch.float32).to('meta')
95 | model = model.to('meta')
96 |
97 | model_fwd = lambda: model(x, istft=False)
98 | fwd_flops = measure_flops(model, model_fwd, total=False)
99 |
100 | model_loss = lambda y: y[0].sum()
101 | fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)
102 |
103 | with open(os.path.join(save_dir, 'FLOPs-detailed.txt'), 'w') as f:
104 | f.write(detailed_flops(fwd_flops))
105 | flops_forward_eval, flops_backward_eval = fwd_flops.get_total_flops(), fwd_and_bwd_flops - fwd_flops.get_total_flops()
106 | else:
107 | print(
108 | "Warning: FLOPs is measured with torchtnt.utils.flops.FlopTensorDispatchMode which doesn't support LSTM, if your model has LSTMs inside please upgrade to torch>=2.1.0, and use torch.utils.flop_counter.FlopCounterMode with tensor and model on meta device"
109 | )
110 | module = FakeModule(model)
111 |
112 | import copy
113 | from torchtnt.utils.flops import FlopTensorDispatchMode
114 |
115 | x = torch.randn(1, num_chns, int(fs * audio_time_len), dtype=torch.float32)
116 | flops_forward_eval, flops_backward_eval = 0, 0
117 | try:
118 | with FlopTensorDispatchMode(module) as ftdm:
119 | res = module(x).mean()
120 | flops_forward = copy.deepcopy(ftdm.flop_counts)
121 | flops_forward_eval = sum(list(flops_forward[''].values())) # MACs
122 |
123 | with open(os.path.join(save_dir, 'FLOPs-detailed.txt'), 'w') as f:
124 | for k, v in flops_forward.items():
125 | f.write(str(k) + ': { ')
126 | for kk, vv in v.items():
127 | f.write(str(kk).replace('.default', '') + ': ' + str(vv) + ', ')
128 | f.write(' }\n')
129 |
130 | ftdm.reset()
131 |
132 | res.backward()
133 | flops_backward = copy.deepcopy(ftdm.flop_counts)
134 | flops_backward_eval = sum(list(flops_backward[''].values()))
135 | except Exception as e:
136 | exp_file = os.path.join(save_dir, 'FLOPs-failed.txt')
137 | traceback.print_exc(file=open(exp_file, 'w'))
138 | print(f"FLOPs test failed '{repr(e)}', see {exp_file}")
139 |
140 | params_eval = num_params if num_params is not None else _get_num_params(model)  # use model: module only exists in the torchtnt branch
141 | flops_forward_eval_avg = flops_forward_eval / audio_time_len
142 | print(
143 | f"FLOPs: forward={flops_forward_eval/1e9:.2f} G, {flops_forward_eval_avg/1e9:.2f} G/s, back={flops_backward_eval/1e9:.2f} G, params: {params_eval/1e6:.3f} M, detailed: {os.path.join(save_dir, 'FLOPs-detailed.txt')}"
144 | )
145 |
146 | with open(os.path.join(save_dir, 'FLOPs.yaml'), 'w') as f:
147 | yaml.dump(
148 | {
149 | "flops_forward" if _TORCH_GREATER_EQUAL_2_1 else "macs_forward": f"{flops_forward_eval/1e9:.2f} G",
150 | "flops_forward_avg" if _TORCH_GREATER_EQUAL_2_1 else "macs_forward_avg": f"{flops_forward_eval_avg/1e9:.2f} G/s",
151 | "flops_backward" if _TORCH_GREATER_EQUAL_2_1 else "macs_backward": f"{flops_backward_eval/1e9:.2f} G",
152 | "params": f"{params_eval/1e6:.3f} M",
153 | "fs": fs,
154 | "audio_time_len": audio_time_len,
155 | "num_chns": num_chns,
156 | }, f)
157 | f.close()
158 |
159 |
160 | def import_class(class_path: str):
161 | try:
162 | iclass = importlib.import_module(class_path)
163 | return iclass
164 | except Exception:  # class_path points at a class rather than a module; fall back below
165 | imodule = importlib.import_module('.'.join(class_path.split('.')[:-1]))
166 | iclass = getattr(imodule, class_path.split('.')[-1])
167 | return iclass
168 |
169 |
170 | def _test_FLOPs_from_config(save_dir: str, model_class_path: str, num_chns: int, fs: int, audio_time_len: int = 4, config_file: str = None):
171 | if config_file is None:
172 | config_file = os.path.join(save_dir, 'config.yaml')
173 |
174 | model_class = import_class(model_class_path)
175 | with open(config_file, 'r', encoding='utf-8') as f:
176 | config = yaml.load(f, yaml.FullLoader)
177 | parser = ArgumentParser()
178 | parser.add_class_arguments(model_class)
179 |
180 | if 'compile' in config['model']:
181 | config['model']['compile'] = False # compiled model will fail to test its flops
182 | try:
183 | if 'compile' in config['model']['arch']['init_args']:
184 | config['model']['arch']['init_args']['compile'] = False
185 | except Exception:  # 'arch'/'init_args' may be absent for some models
186 | ...
187 | model_config = parser.instantiate_classes(config['model'])
188 | model = model_class(**model_config.as_dict())
189 | num_params = _get_num_params(model=model)
190 | try:
191 | # torcheval report error for shared modules, so config to not share
192 | if "full_share" in config['model']['arch']['init_args']:
193 | if config['model']['arch']['init_args']['full_share'] == True:
194 | config['model']['arch']['init_args']['full_share'] = False
195 | model_config = parser.instantiate_classes(config['model'])
196 | model = model_class(**model_config.as_dict())
197 | elif type(config['model']['arch']['init_args']['full_share']) == int or config['model']['arch']['init_args']['full_share'] == None:
198 | config['model']['arch']['init_args']['full_share'] = 9999
199 | model_config = parser.instantiate_classes(config['model'])
200 | model = model_class(**model_config.as_dict())
201 | except Exception as e:
202 | ...
203 | _test_FLOPs(model, save_dir=save_dir, num_chns=num_chns, fs=fs, audio_time_len=audio_time_len, num_params=num_params)
204 |
205 |
206 | def write_FLOPs(model: LightningModule, save_dir: str, num_chns: int, fs: int = None, nfft: int = None, audio_time_len: int = 4, model_class_path: str = None):
207 | assert fs is not None or nfft is not None, (fs, nfft)
208 | if model_class_path is None:
209 | model_class_path = f"{str(model.__class__.__module__)}.{type(model).__name__}"
210 |
211 | if fs:
212 | cmd = f'CUDA_VISIBLE_DEVICES={model.device.index}, python -m models.utils.flops ' + f'--save_dir {save_dir} --model_class_path {model_class_path} ' + f'--num_chns {num_chns} --fs {fs} --audio_time_len {audio_time_len}'
213 | else:
214 | cmd = f'CUDA_VISIBLE_DEVICES={model.device.index}, python -m models.utils.flops ' + f'--save_dir {save_dir} --model_class_path {model_class_path} ' + f'--num_chns {num_chns} --nfft {nfft} --audio_time_len {audio_time_len}'
215 | print(cmd)
216 | os.system(cmd)
217 |
218 |
219 | if __name__ == '__main__':
220 | # CUDA_VISIBLE_DEVICES=5, python -m models.utils.flops --save_dir logs/SSFNetLM/version_90 --model_class_path models.SSFNetLM.SSFNetLM --num_chns 6 --fs 8000
221 | # CUDA_VISIBLE_DEVICES=5, python -m models.utils.flops --save_dir logs/SSFNetLM/version_90 --model_class_path models.SSFNetLM.SSFNetLM --num_chns 6 --nfft 256
222 | parser = Parser()
223 | parser.add_argument('--save_dir', type=str, required=True, help='save FLOPs to dir')
224 | parser.add_argument('--model_class_path', type=str, required=True, help='the import path of your Lightning Module')
225 | parser.add_argument('--num_chns', type=int, required=True, help='the number of microphone channels')
226 | parser.add_argument('--fs', type=int, default=None, help='sampling rate')
227 | parser.add_argument('--nfft', type=int, default=None, help='the number of fft points, used to infer the sampling rate when --fs is not given')
228 | parser.add_argument('--audio_time_len', type=float, default=4., help='seconds of test mixture waveform')
229 | parser.add_argument('--config_file', type=str, default=None, help='config file path')
230 | args = parser.parse_args()
231 |
232 | fs = args.fs
233 | if fs is None:
234 | if args.nfft is None:
235 | print('FLOPs/MACs test error: you should specify either --fs or --nfft')
236 | exit(-1)
237 | fs = {256: 8000, 512: 16000, 320: 16000, 160: 8000}[args.nfft]
238 |
239 | _test_FLOPs_from_config(
240 | save_dir=args.save_dir,
241 | model_class_path=args.model_class_path,
242 | num_chns=args.num_chns,
243 | fs=fs,
244 | audio_time_len=args.audio_time_len,
245 | config_file=args.config_file,
246 | )
247 |
--------------------------------------------------------------------------------
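A minimal usage sketch for `measure_flops` above, assuming a toy `torch.nn.Sequential` model on the meta device; the model and input shape are placeholders, not part of this repository.

```python
# Hypothetical sketch: counting forward and forward+backward FLOPs with
# measure_flops (the toy model and input shape are placeholders).
import torch
from models.utils.flops import measure_flops

with torch.device("meta"):  # meta device: no real memory is allocated
    model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1))
    x = torch.randn(8, 32)

model_fwd = lambda: model(x)
fwd_flops = measure_flops(model, model_fwd)                          # forward only
fwd_bwd_flops = measure_flops(model, model_fwd, lambda y: y.sum())   # forward + backward
print(f"forward: {fwd_flops}, forward+backward: {fwd_bwd_flops}")
```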
/models/utils/general_steps.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import *
4 | from pathlib import Path
5 |
6 | import pytorch_lightning as pl
7 | import soundfile as sf
8 | import torch
9 | from numpy import ndarray
10 | from pandas import DataFrame
11 | from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
12 | from torch import Tensor
13 |
14 | from models.utils import MyJsonEncoder, tag_and_log_git_status
15 | from models.utils.ensemble import ensemble
16 | from models.utils.flops import write_FLOPs
17 | from models.utils.metrics import (cal_metrics_functional, cal_pesq, recover_scale)
18 |
19 |
20 | def on_validation_epoch_end(self: pl.LightningModule, cpu_metric_input: List[Tuple[ndarray, ndarray, int]], N: int = 5) -> None:
21 | """calculate heavy metrics for every N epochs
22 |
23 | Args:
24 | self: LightningModule
25 | cpu_metric_input: the input list for cal_metrics_functional
26 | N: the interval (in epochs) between heavy-metric calculations. Defaults to 5.
27 | """
28 |
29 | if self.current_epoch != 0 and self.current_epoch % N != (N - 1):
30 | cpu_metric_input.clear()
31 | return
32 |
33 | if len(cpu_metric_input) == 0:
34 | return
35 |
36 | torch.multiprocessing.set_sharing_strategy('file_system')
37 | num_thread = max(1, torch.multiprocessing.cpu_count() // (self.trainer.world_size * 2))  # at least one worker
38 | p = torch.multiprocessing.Pool(min(num_thread, len(cpu_metric_input)))
39 | cpu_metrics = list(p.starmap(cal_metrics_functional, cpu_metric_input))
40 | p.close()
41 | p.join()
42 |
43 | for k in cpu_metric_input[0][0]:
44 | ms = list(filter(None, [m[0][k.lower()] for m in cpu_metrics]))
45 | if len(ms) > 0:
46 | self.log(f'val/{k}', sum(ms) / len(ms), sync_dist=True, batch_size=len(ms))
47 |
48 | cpu_metric_input.clear()
49 |
50 |
51 | def on_test_epoch_end(self: pl.LightningModule, results: List[Dict[str, Any]], cpu_metric_input: List, exp_save_path: str):
52 | """ calculate cpu metrics on CPU, collect results, save results to file
53 |
54 | Args:
55 | self: LightningModule
56 | results: the result list
57 | cpu_metric_input: the input list for cal_metrics_functional
58 | exp_save_path: the path to save result file
59 | """
60 |
61 | # calculate metrics, input_metrics, improve_metrics on CPU using multiprocessing to speed up
62 | torch.multiprocessing.set_sharing_strategy('file_system')
63 | num_thread = max(1, torch.multiprocessing.cpu_count() // (self.trainer.world_size * 2))  # at least one worker
64 | p = torch.multiprocessing.Pool(min(num_thread, len(cpu_metric_input)))
65 | cpu_metrics = list(p.starmap(cal_metrics_functional, cpu_metric_input))
66 | p.close()
67 | p.join()
68 | for i, m in enumerate(cpu_metrics):
69 | metrics, input_metrics, imp_metrics = m
70 | results[i].update(input_metrics)
71 | results[i].update(imp_metrics)
72 | results[i].update(metrics)
73 |
74 | # gather results from all GPUs
75 | import torch.distributed as dist
76 |
77 | # collect results from other gpus if world_size > 1
78 | if self.trainer.world_size > 1:
79 | dist.barrier()
80 | results_list = [None for obj in results]
81 | dist.all_gather_object(results_list, results) # gather results from all gpus
82 | # merge them
83 | exist = set()
84 | results = []
85 | for rs in results_list:
86 | if rs is None:
87 | continue
88 | for r in rs:
89 | if r['wavname'] not in exist:
90 | results.append(r)
91 | exist.add(r['wavname'])
92 |
93 | # save collected data on 0-th gpu
94 | if self.trainer.is_global_zero:
95 | # save
96 | import datetime
97 | x = datetime.datetime.now()
98 | dtstr = x.strftime('%Y%m%d_%H%M%S.%f')
99 | path = os.path.join(exp_save_path, 'results_{}.json'.format(dtstr))
100 | # write results to json
101 | f = open(path, 'w', encoding='utf-8')
102 | json.dump(results, f, indent=4, cls=MyJsonEncoder)
103 | f.close()
104 | # write mean to json
105 | df = DataFrame(results)
106 | df.mean(numeric_only=True).to_json(os.path.join(exp_save_path, 'results_mean.json'), indent=4)
107 | self.print('results: ', os.path.join(exp_save_path, 'results_mean.json'), ' ', path)
108 |
109 |
110 | def on_predict_batch_end(
111 | self: pl.LightningModule,
112 | outputs: Optional[Any],
113 | batch: Any,
114 | ) -> None:
115 | """save predicted results to `log_dir/examples`
116 |
117 | Args:
118 | self: LightningModule
119 | outputs: the predicted waveforms returned by predict_step
120 | batch: the current batch, i.e. (input, target, paras) or a plain Tensor
121 | """
122 | save_dir = self.trainer.logger.log_dir + '/' + 'examples'
123 | os.makedirs(save_dir, exist_ok=True)
124 |
125 | if not isinstance(batch, Tensor):
126 | input, target, paras = batch
127 | if 'saveto' in paras[0]:
128 | for b in range(len(paras)):
129 | saveto = paras[b]['saveto']
130 | if isinstance(saveto, str):
131 | saveto = [saveto]
132 |
133 | assert isinstance(saveto, list), ('saveto should be a list of size num_speakers', type(saveto))
134 | for spk, spk_saveto in enumerate(saveto):
135 | if isinstance(spk_saveto, dict):
136 | input_saveto = spk_saveto['input'] if 'input' in spk_saveto else None
137 | target_saveto = spk_saveto['target'] if 'target' in spk_saveto else None
138 | pred_saveto = spk_saveto['prediction'] if 'prediction' in spk_saveto else None
139 | else:
140 | pred_saveto, input_saveto, target_saveto = spk_saveto, None, None
141 |
142 | # save predictions
143 | if pred_saveto:
144 | y = outputs[b][spk]
145 | assert len(y.shape) == 1, y.shape
146 | save_path = Path(save_dir) / pred_saveto
147 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
148 | sf.write(save_path, y.detach().cpu().numpy(), samplerate=paras[b]['sample_rate'])
149 | # save input
150 | if input_saveto:
151 | y = input[b].T # [T,CHN]
152 | save_path = Path(save_dir) / input_saveto
153 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
154 | sf.write(save_path, y.detach().cpu().numpy(), samplerate=paras[b]['sample_rate'])
155 | # # save target
156 | # if target_saveto and target is not None:
157 | # y = target[b, spk, :, :].T # [T,CHN]
158 | # save_path = save_dir + '/' + target_saveto
159 | # os.makedirs(os.path.dirname(save_path), exist_ok=True)
160 | # sf.write(save_path, y.detach().cpu().numpy(), samplerate=paras[b]['sample_rate'])
161 |
162 |
163 | def on_load_checkpoint(
164 | self: pl.LightningModule,
165 | checkpoint: Dict[str, Any],
166 | ensemble_opts: Union[int, str, List[str], Literal[None]] = None,
167 | compile: bool = True,
168 | reset: List[str] = [],
169 | ) -> None:
170 | """load checkpoint
171 |
172 | Args:
173 | self: LightningModule
174 | checkpoint: the loaded weights
175 | ensemble_opts: opts for ensemble. Defaults to None.
176 | compile: whether the checkpoint is a compiled one. Defaults to True.
177 | """
178 | from pytorch_lightning.strategies import FSDPStrategy
179 | if isinstance(self.trainer.strategy, FSDPStrategy):
180 | rank_zero_warn('using fsdp, ensemble is disabled')
181 | return super(pl.LightningModule, self).on_load_checkpoint(checkpoint)
182 |
183 | if ensemble_opts:
184 | ckpt = self.trainer.ckpt_path
185 | ckpts, state_dict = ensemble(opts=ensemble_opts, ckpt=ckpt)
186 | self.print(f'rank {self.trainer.local_rank}/{self.trainer.world_size}, ensemble {ensemble_opts}: {ckpts}')
187 | checkpoint['state_dict'] = state_dict
188 |
189 | # rename weights for removing _orig_mod in name
190 | name_mapping = {} # {name without _orig_mod: the actual name}
191 | parameters = self.state_dict()
192 | for k, v in parameters.items():
193 | name_mapping[k.replace('_orig_mod.', '')] = k
194 |
195 | state_dict = checkpoint['state_dict']
196 | state_dict_new = dict()
197 | for k, v, in state_dict.items():
198 | state_dict_new[name_mapping[k.replace('_orig_mod.', '')]] = v
199 | checkpoint['state_dict'] = state_dict_new
200 |
201 | # reset optimizer and lr_scheduler
202 | if reset is not None:
203 | for key in reset:
204 | assert key in ['optimizer', 'lr_scheduler'], f'unsupported reset key {key}'
205 | if key == 'optimizer':
206 | checkpoint['optimizer'] = dict()
207 | checkpoint['optimizer_states'] = []
208 | rank_zero_info('reset optimizer')
209 | elif key == 'lr_scheduler':
210 | checkpoint['lr_scheduler'] = dict()
211 | checkpoint['lr_schedulers'] = []
212 | rank_zero_info('reset lr_scheduler')
213 |
214 | return super(pl.LightningModule, self).on_load_checkpoint(checkpoint)
215 |
216 |
217 | def on_train_start(self: pl.LightningModule, exp_name: str, model_name: str, num_chns: int, nfft: int, model_class_path: str = None):
218 | """ 1) add git tags/write requirements for better change tracking; 2) write model architecture to file; 3) measure the model FLOPs
219 |
220 | Args:
221 | self: LightningModule
222 | exp_name: `notag` or exp_name, add git tag e.g. 'model_name_v10' if exp_name!='notag'
223 | model_name: the model name
224 | num_chns: the number of channels for FLOPs test
225 | nfft: the number of fft points
226 | model_class_path: the path to import the self
227 | """
228 | if self.current_epoch == 0:
229 | if self.trainer.is_global_zero and hasattr(self.logger, 'log_dir') and 'notag' not in exp_name:
230 | # add git tags for better change tracking
231 | # note: if self.logger.log_dir is changed to self.trainer.log_dir, training will get stuck on multi-gpu training
232 | tag_and_log_git_status(self.logger.log_dir + '/git.out', self.logger.version, exp_name, model_name=model_name)
233 |
234 | if self.trainer.is_global_zero and hasattr(self.logger, 'log_dir'):
235 | # write model architecture to file
236 | with open(self.logger.log_dir + '/model.txt', 'a') as f:
237 | f.write(str(self))
238 | f.write('\n\n\n')
239 | # measure the model FLOPs, the num_chns here only means the original channels
240 | # write_FLOPs(model=self, save_dir=self.logger.log_dir, num_chns=num_chns, nfft=nfft, model_class_path=model_class_path)
241 |
242 |
243 | def configure_optimizers(
244 | self: pl.LightningModule,
245 | optimizer: str,
246 | optimizer_kwargs: Dict[str, Any],
247 | monitor: str = 'val/loss',
248 | lr_scheduler: str = None,
249 | lr_scheduler_kwargs: Dict[str, Any] = None,
250 | ):
251 | """configure optimizer and lr_scheduler"""
252 | if optimizer == 'Adam' and self.trainer.precision == '16-mixed':
253 | if 'eps' not in optimizer_kwargs:
254 | optimizer_kwargs['eps'] = 1e-4 # according to https://discuss.pytorch.org/t/adam-half-precision-nans/1765
255 | rank_zero_info('setting the eps of Adam to 1e-4 for FP16 mixed precision training')
256 | else:
257 | allowed_minimum = torch.finfo(torch.float16).eps
258 | assert optimizer_kwargs['eps'] >= allowed_minimum, f"You should specify an eps no smaller than the allowed minimum of FP16 precision: {optimizer_kwargs['eps']} {allowed_minimum}"
259 | optimizer = getattr(torch.optim, optimizer)(self.parameters(), **optimizer_kwargs)
260 |
261 | if lr_scheduler is not None and len(lr_scheduler) > 0:
262 | lr_scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler)
263 | return {
264 | 'optimizer': optimizer,
265 | 'lr_scheduler': {
266 | 'scheduler': lr_scheduler(optimizer, **lr_scheduler_kwargs),
267 | 'monitor': monitor,
268 | }
269 | }
270 | else:
271 | return optimizer
272 |
273 |
274 | def test_setp_write_example(self, xr: Tensor, yr: Tensor, yr_hat: Tensor, sample_rate: int, paras: Dict[str, Any], result_dict: Dict[str, Any], wavname: str, exp_save_path: str):
275 | """
276 | Args:
277 | xr: [B,T]
278 | yr: [B,Spk,T]
279 | yr_hat: [B,Spk,T]
280 | """
281 |
282 | # write examples
283 | abs_max = max(torch.max(torch.abs(xr[0, ...])), torch.max(torch.abs(yr[0, ...])))
284 |
285 | def write_wav(wav_path: str, wav: torch.Tensor, norm_to: torch.Tensor = None):
286 | # make sure the wav doesn't contain illegal values (absolute values greater than 1)
287 | if norm_to is not None:
288 | wav = wav / torch.max(torch.abs(wav)) * norm_to
289 | if abs_max > 1:
290 | wav /= abs_max
291 | abs_max_wav = torch.max(torch.abs(wav))
292 | if abs_max_wav > 1:
293 | import warnings
294 | warnings.warn(f"abs_max_wav > 1, {abs_max_wav}")
295 | wav /= abs_max_wav
296 | sf.write(wav_path, wav.detach().cpu().numpy(), sample_rate)
297 |
298 | pattern = '.'.join(wavname.split('.')[:-1]) + '{name}' # remove .wav in wavname
299 | example_dir = os.path.join(exp_save_path, 'examples', str(paras[0]['index']))
300 | os.makedirs(example_dir, exist_ok=True)
301 | # save preds and targets for each speaker
302 | for i in range(yr.shape[1]):
303 | # write ys
304 | wav_path = os.path.join(example_dir, pattern.format(name=f"_spk{i+1}.wav"))
305 | write_wav(wav_path=wav_path, wav=yr[0, i])
306 | # write ys_hat
307 | wav_path = os.path.join(example_dir, pattern.format(name=f"_spk{i+1}_p.wav"))
308 | write_wav(wav_path=wav_path, wav=yr_hat[0, i]) #, norm_to=ys[0, i].abs().max())
309 | # write mix
310 | wav_path = os.path.join(example_dir, pattern.format(name=f"_mix.wav"))
311 | write_wav(wav_path=wav_path, wav=xr[0, :])
312 |
313 | # write paras & results
314 | f = open(os.path.join(example_dir, pattern.format(name=f"_paras.json")), 'w', encoding='utf-8')
315 | paras[0]['metrics'] = result_dict
316 | json.dump(paras[0], f, indent=4, cls=MyJsonEncoder)
317 | f.close()
318 |
--------------------------------------------------------------------------------
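A minimal sketch of how these shared steps might be wired into a LightningModule; `MyModel`, its network, and the optimizer settings are illustrative assumptions, not modules of this repository.

```python
# Hypothetical wiring sketch: a LightningModule delegating its hooks to the
# shared helpers above (MyModel and all hyper-parameters are illustrative).
import pytorch_lightning as pl
import torch
from models.utils import general_steps

class MyModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(257, 257)

    def forward(self, x):
        return self.net(x)

    def configure_optimizers(self):
        # Adam + ReduceLROnPlateau monitored on val/loss, via the shared helper
        return general_steps.configure_optimizers(
            self,
            optimizer='Adam',
            optimizer_kwargs={'lr': 1e-3},
            lr_scheduler='ReduceLROnPlateau',
            lr_scheduler_kwargs={'mode': 'min', 'factor': 0.5, 'patience': 5},
            monitor='val/loss',
        )

    def on_train_start(self):
        # git tagging / requirements dump / model architecture dump
        general_steps.on_train_start(self, exp_name='notag', model_name='MyModel', num_chns=6, nfft=256)
```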
/models/utils/git_tools.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | def tag_and_log_git_status(log_to: str, version: str, exp_name: str, model_name: str) -> None:
4 | # add git tags for better change tracking
5 | import subprocess
6 | gitout = open(log_to, 'a', encoding='utf-8')
7 | del_tag = f'git tag -d {model_name}_v{version}'
8 | add_tag = f'git tag -a {model_name}_v{version} -m "{exp_name}"'
9 | print_branch = "git branch -vv"
10 | print_status = 'git status'
11 | save_pip_reqs = f'pip freeze > {str(Path(log_to).expanduser().parent)}/requirements_pip.txt'
12 | save_conda_reqs = f'conda list -e > {str(Path(log_to).expanduser().parent)}/requirements_conda.txt'
13 | cmds = [del_tag, add_tag, print_branch, print_status, save_pip_reqs, save_conda_reqs]
14 | for cmd in cmds:
15 | p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE, encoding="utf-8", universal_newlines=True, shell=True)
16 | o, err = p.communicate()
17 | gitout.write(f'========={cmd}=========\n{o}\n\n\n')
18 | gitout.close()
19 |
--------------------------------------------------------------------------------
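A minimal sketch of calling `tag_and_log_git_status` directly; the paths, version, and names are placeholders. In this repository it is normally invoked from `on_train_start` in general_steps.py.

```python
# Hypothetical direct call; log_to, version and names are placeholders.
from models.utils.git_tools import tag_and_log_git_status

tag_and_log_git_status(
    log_to='logs/MyModel/version_0/git.out',  # git tag/branch/status output is appended here
    version='0',
    exp_name='my_experiment',                 # used as the tag message
    model_name='MyModel',                     # the created tag will be MyModel_v0
)
```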
/models/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Tuple, Union
2 | import warnings
3 | from torchmetrics import Metric
4 | from torchmetrics.collections import MetricCollection
5 | from torchmetrics.audio import *
6 | from torchmetrics.functional.audio import *
7 | from torch import Tensor
8 | import torch
9 | import pesq as pesq_backend
10 | import numpy as np
11 | from typing import *
12 | from models.utils.dnsmos import deep_noise_suppression_mean_opinion_score
13 |
14 | ALL_AUDIO_METRICS = ['SDR', 'SI_SDR', 'SI_SNR', 'SNR', 'NB_PESQ', 'WB_PESQ', 'STOI', 'ESTOI', 'DNSMOS', 'pDNSMOS']
15 |
16 |
17 | def get_metric_list_on_device(device: Optional[str]):
18 | metric_device = {
19 | None: ['SDR', 'SI_SDR', 'SNR', 'SI_SNR', 'NB_PESQ', 'WB_PESQ', 'STOI', 'ESTOI', 'DNSMOS', 'pDNSMOS'],
20 | "cpu": ['NB_PESQ', 'WB_PESQ', 'STOI', 'ESTOI'],
21 | "gpu": ['SDR', 'SI_SDR', 'SNR', 'SI_SNR', 'DNSMOS', 'pDNSMOS'],
22 | }
23 | return metric_device[device]
24 |
25 |
26 | def cal_metrics_functional(
27 | metric_list: List[str],
28 | preds: Tensor,
29 | target: Tensor,
30 | original: Optional[Tensor],
31 | fs: int,
32 | device_only: Literal['cpu', 'gpu', None] = None, # cpu-only: pesq, stoi;
33 | chunk: Tuple[float, float] = None, # (chunk length, hop length) in seconds for chunk-wise metric evaluation
34 | suffix: str = "",
35 | ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, Tensor]]:
36 | metrics, input_metrics, imp_metrics = {}, {}, {}
37 | if chunk is not None:
38 | clen, chop = int(fs * chunk[0]), int(fs * chunk[1])
39 | for i in range(int((preds.shape[-1] / fs - chunk[0]) / chunk[1]) + 1):
40 | metrics_chunk, input_metrics_chunk, imp_metrics_chunk = cal_metrics_functional(
41 | metric_list,
42 | preds[..., i * chop:i * chop + clen],
43 | target[..., i * chop:i * chop + clen],
44 | original[..., i * chop:i * chop + clen] if original is not None else None,
45 | fs,
46 | device_only,
47 | chunk=None,
48 | suffix=f"_{i*chunk[1]+1}s-{i*chunk[1]+chunk[0]}s",
49 | )
50 | metrics.update(metrics_chunk), input_metrics.update(input_metrics_chunk), imp_metrics.update(imp_metrics_chunk)
51 |
52 | if device_only is None or device_only == 'cpu':
53 | preds_cpu = preds.detach().cpu()
54 | target_cpu = target.detach().cpu()
55 | original_cpu = original.detach().cpu() if original is not None else None
56 | else:
57 | preds_cpu = None
58 | target_cpu = None
59 | original_cpu = None
60 |
61 | for m in metric_list:
62 | mname = m.lower()
63 | if m.upper() not in get_metric_list_on_device(device=device_only):
64 | continue
65 |
66 | if m.upper() == 'SDR':
67 | ## signal_distortion_ratio may give NaN sometimes; a mir_eval-based alternative (bss_eval_sources) is kept below, commented out
68 | metric_func = lambda: signal_distortion_ratio(preds, target).detach().cpu()
69 | input_metric_func = lambda: signal_distortion_ratio(original, target).detach().cpu()
70 | # assert preds.dim() == 2 and target.dim() == 2 and original.dim() == 2, '(spk, time)!'
71 | # metric_func = lambda: torch.tensor(bss_eval_sources(target_cpu.numpy(), preds_cpu.numpy(), False)[0]).mean().detach().cpu()
72 | # input_metric_func = lambda: torch.tensor(bss_eval_sources(target_cpu.numpy(), original_cpu.numpy(), False)[0]).mean().detach().cpu()
73 | elif m.upper() == 'SI_SDR':
74 | metric_func = lambda: scale_invariant_signal_distortion_ratio(preds, target).detach().cpu()
75 | input_metric_func = lambda: scale_invariant_signal_distortion_ratio(original, target).detach().cpu()
76 | elif m.upper() == 'SI_SNR':
77 | metric_func = lambda: scale_invariant_signal_noise_ratio(preds, target).detach().cpu()
78 | input_metric_func = lambda: scale_invariant_signal_noise_ratio(original, target).detach().cpu()
79 | elif m.upper() == 'SNR':
80 | metric_func = lambda: signal_noise_ratio(preds, target).detach().cpu()
81 | input_metric_func = lambda: signal_noise_ratio(original, target).detach().cpu()
82 | elif m.upper() == 'NB_PESQ':
83 | metric_func = lambda: perceptual_evaluation_speech_quality(preds_cpu, target_cpu, fs, 'nb', n_processes=0)
84 | input_metric_func = lambda: perceptual_evaluation_speech_quality(original_cpu, target_cpu, fs, 'nb', n_processes=0)
85 | elif m.upper() == 'WB_PESQ':
86 | metric_func = lambda: perceptual_evaluation_speech_quality(preds_cpu, target_cpu, fs, 'wb', n_processes=0)
87 | input_metric_func = lambda: perceptual_evaluation_speech_quality(original_cpu, target_cpu, fs, 'wb', n_processes=0)
88 | elif m.upper() == 'STOI':
89 | metric_func = lambda: short_time_objective_intelligibility(preds_cpu, target_cpu, fs)
90 | input_metric_func = lambda: short_time_objective_intelligibility(original_cpu, target_cpu, fs)
91 | elif m.upper() == 'ESTOI':
92 | metric_func = lambda: short_time_objective_intelligibility(preds_cpu, target_cpu, fs, extended=True)
93 | input_metric_func = lambda: short_time_objective_intelligibility(original_cpu, target_cpu, fs, extended=True)
94 | elif m.upper() == 'DNSMOS':
95 | metric_func = lambda: deep_noise_suppression_mean_opinion_score(preds, fs, False)
96 | input_metric_func = lambda: deep_noise_suppression_mean_opinion_score(original, fs, False)
97 | elif m.upper() == 'PDNSMOS': # personalized DNSMOS
98 | metric_func = lambda: deep_noise_suppression_mean_opinion_score(preds, fs, True)
99 | input_metric_func = lambda: deep_noise_suppression_mean_opinion_score(original, fs, True)
100 | else:
101 | raise ValueError('Unknown audio metric ' + m)
102 |
103 | if m.upper() == 'WB_PESQ' and fs == 8000:
104 | # warnings.warn("There is narrow band (nb) mode only when sampling rate is 8000Hz")
105 | continue # Note there is narrow band (nb) mode only when sampling rate is 8000Hz
106 |
107 | try:
108 | if mname == 'dnsmos':
109 | # p808_mos, mos_sig, mos_bak, mos_ovr
110 | m_val = metric_func().cpu().numpy()
111 |
112 | for idx, mid in enumerate(['p808', 'sig', 'bak', 'ovr']):
113 | mname_i = mname + '_' + mid + suffix
114 | metrics[mname_i] = np.mean(m_val[..., idx]).item()
115 | metrics[mname_i + '_all'] = m_val[..., idx].tolist()
116 | if original is None:
117 | continue
118 |
119 | if 'input_' + mname_i not in input_metrics.keys():
120 | im_val = input_metric_func().cpu().numpy()
121 | input_metrics['input_' + mname_i] = np.mean(im_val[..., idx]).item()
122 | input_metrics['input_' + mname_i + '_all'] = im_val[..., idx].tolist()
123 |
124 | imp_metrics[mname_i + '_i'] = metrics[mname_i] - input_metrics['input_' + mname_i] # _i means improvement
125 | imp_metrics[mname_i + '_all' + '_i'] = (m_val[..., idx] - im_val[..., idx]).tolist()
126 | continue
127 |
128 | mname = mname + suffix
129 | m_val = metric_func().cpu().numpy()
130 | metrics[mname] = np.mean(m_val).item()
131 | metrics[mname + '_all'] = m_val.tolist() # _all means not averaged
132 | if original is None:
133 | continue
134 |
135 | if 'input_' + mname not in input_metrics.keys():
136 | im_val = input_metric_func().cpu().numpy()
137 | input_metrics['input_' + mname] = np.mean(im_val).item()
138 | input_metrics['input_' + mname + '_all'] = im_val.tolist()
139 |
140 | imp_metrics[mname + '_i'] = metrics[mname] - input_metrics['input_' + mname] # _i means improvement
141 | imp_metrics[mname + '_all' + '_i'] = (m_val - im_val).tolist()
142 | except Exception as e:
143 | metrics[mname] = None
144 | metrics[mname + '_all'] = None
145 | if 'input_' + mname not in input_metrics.keys():
146 | input_metrics['input_' + mname] = None
147 | input_metrics['input_' + mname + '_all'] = None
148 | imp_metrics[mname + '_i'] = None
149 | imp_metrics[mname + '_all' + '_i'] = None  # keep the same key as in the success path
150 |
151 | return metrics, input_metrics, imp_metrics
152 |
153 | def mypesq(preds: np.ndarray, target: np.ndarray, mode: str, fs: int) -> np.ndarray:
154 | # use ndarray here because Tensor causes some multiprocessing errors on Linux
155 | ori_shape = preds.shape
156 | if type(preds) == Tensor:
157 | preds = preds.detach().cpu().numpy()
158 | target = target.detach().cpu().numpy()
159 | else:
160 | assert type(preds) == np.ndarray, type(preds)
161 | assert type(target) == np.ndarray, type(target)
162 |
163 | if preds.ndim == 1:
164 | pesq_val = pesq_backend.pesq(fs=fs, ref=target, deg=preds, mode=mode)
165 | else:
166 | preds = preds.reshape(-1, ori_shape[-1])
167 | target = target.reshape(-1, ori_shape[-1])
168 | pesq_val = np.empty(shape=(preds.shape[0]))
169 | for b in range(preds.shape[0]):
170 | pesq_val[b] = pesq_backend.pesq(fs=fs, ref=target[b, :], deg=preds[b, :], mode=mode)
171 | pesq_val = pesq_val.reshape(ori_shape[:-1])
172 | return pesq_val
173 |
174 |
175 | def cal_pesq(ys: np.ndarray, ys_hat: np.ndarray, sample_rate: int) -> List[Optional[float]]:
176 | try:
177 | if sample_rate == 16000:
178 | wb_pesq_val = mypesq(preds=ys_hat, target=ys, fs=sample_rate, mode='wb').mean()
179 | nb_pesq_val = mypesq(preds=ys_hat, target=ys, fs=sample_rate, mode='nb').mean()
180 | return [wb_pesq_val, nb_pesq_val]
181 | elif sample_rate == 8000:
182 | nb_pesq_val = mypesq(preds=ys_hat, target=ys, fs=sample_rate, mode='nb').mean()
183 | return [None, nb_pesq_val]
184 | else:
185 | ...
186 | except Exception as e:
187 | ...
188 | # warnings.warn(str(e))
189 | return [None, None]
190 |
191 |
192 | def recover_scale(preds: Tensor, mixture: Tensor, scale_src_together: bool, norm_if_exceed_1: bool = True) -> Tensor:
193 | """recover wav's original scale by solving min ||Y^T a - X||F, cuz sisdr will lose scale
194 |
195 | Args:
196 | preds: prediction, shape [batch, n_src, time]
197 | mixture: mixture or noisy or reverberant signal, shape [batch, time]
198 | scale_src_together: keep the relative energy level between sources; can be used for scale-invariant SA-SDR
199 | norm_if_exceed_1: normalize the magnitude if it exceeds one
200 |
201 | Returns:
202 | Tensor: the scale-recovered preds
203 | """
204 | # TODO: add some kind of weighting mechanism to make the predicted scales more precise
205 | # recover the wav's original scale: solve min ||Y^T a - X||_F to obtain the scales of the speakers' predictions, because SI-SDR loses the scale
206 | if scale_src_together:
207 | a = torch.linalg.lstsq(preds.sum(dim=-2, keepdim=True).transpose(-1, -2), mixture.unsqueeze(-1)).solution
208 | else:
209 | a = torch.linalg.lstsq(preds.transpose(-1, -2), mixture.unsqueeze(-1)).solution
210 |
211 | preds = preds * a
212 |
213 | if norm_if_exceed_1:
214 | # normalize the audios so that the maximum doesn't exceed 1
215 | max_vals = torch.max(torch.abs(preds), dim=-1).values
216 | norm = torch.where(max_vals > 1, max_vals, 1)
217 | preds = preds / norm.unsqueeze(-1)
218 | return preds
219 |
220 |
221 | if __name__ == "__main__":
222 | x, y, m = torch.rand((2, 8000 * 8)), torch.rand((2, 8000 * 8)), torch.rand((2, 8000 * 8))
223 | m, im, mi = cal_metrics_functional(["si_sdr"], preds=x, target=y, original=m, fs=8000, chunk=[4, 1])
224 | print(m)
225 | print(im)
226 | print(mi)
227 |
--------------------------------------------------------------------------------
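A small sketch of using `recover_scale` and `cal_metrics_functional` together; the signals below are random placeholders, so the metric values themselves are meaningless.

```python
# Hypothetical sketch combining recover_scale and cal_metrics_functional;
# all signals here are random placeholders.
import torch
from models.utils.metrics import cal_metrics_functional, recover_scale

fs = 8000
preds = torch.rand(1, 2, fs * 4)    # [batch, n_src, time] network outputs
target = torch.rand(1, 2, fs * 4)   # [batch, n_src, time] references
mixture = torch.rand(1, fs * 4)     # [batch, time] mixture

# restore the scale lost by SI-SDR-style training
preds = recover_scale(preds, mixture, scale_src_together=False)

metrics, input_metrics, imp_metrics = cal_metrics_functional(
    ['SI_SDR', 'SNR'],
    preds[0], target[0],
    original=mixture.expand(2, -1),  # unprocessed mixture as the input reference
    fs=fs,
)
print(metrics['si_sdr'], input_metrics['input_si_sdr'], imp_metrics['si_sdr_i'])
```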
/models/utils/my_earlystopping.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from pytorch_lightning.callbacks import EarlyStopping
3 |
4 |
5 | class MyEarlyStopping(EarlyStopping):
6 |
7 | def __init__(
8 | self,
9 | enable: bool = True, # enable EarlyStopping or not
10 | **kwargs,
11 | ):
12 | super().__init__(**kwargs)
13 | self.enable = enable
14 |
15 | def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
16 | if not self.enable:
17 | return True
18 | else:
19 | return super()._should_skip_check(trainer)
20 |
--------------------------------------------------------------------------------
/models/utils/my_json_encoder.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | from torch import Tensor
4 | from pytorch_lightning.utilities.rank_zero import rank_zero_warn
5 |
6 |
7 | class MyJsonEncoder(json.JSONEncoder):
8 | large_array_size: int = 100
9 | ignore_large_array: bool = True
10 |
11 | def default(self, obj):
12 | if isinstance(obj, np.int64) or isinstance(obj, np.float64) or isinstance(obj, np.float32):
13 | return obj.item()
14 | elif isinstance(obj, np.ndarray):
15 | if obj.size == 1:
16 | return obj.item()
17 | else:
18 | if obj.size > self.large_array_size:
19 | if self.ignore_large_array:
20 | rank_zero_warn('a large array is ignored when saving to the json file.')
21 | return None
22 | else:
23 | rank_zero_warn('a large array was detected; saving it to json is slow, please remove it.')
24 | return obj.tolist()
25 | elif isinstance(obj, Tensor):
26 | if obj.numel() == 1:
27 | return obj.item()
28 | else:
29 | if obj.numel() > self.large_array_size:
30 | if self.ignore_large_array:
31 | rank_zero_warn('a large array is ignored when saving to the json file.')
32 | return None
33 | else:
34 | rank_zero_warn('a large array was detected; saving it to json is slow, please remove it.')
35 | return obj.detach().cpu().numpy().tolist()
36 | return json.JSONEncoder.default(self, obj)
37 |
--------------------------------------------------------------------------------
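A minimal sketch showing how `MyJsonEncoder` lets `json.dump`/`json.dumps` serialize numpy and torch values directly; the dictionary content is illustrative.

```python
# Hypothetical sketch: serializing a result dict containing numpy / torch
# values with MyJsonEncoder (the dict content is illustrative).
import json
import numpy as np
import torch
from models.utils.my_json_encoder import MyJsonEncoder

result = {
    'si_sdr': np.float32(12.3),                 # numpy scalar -> plain float
    'si_sdr_all': torch.tensor([11.8, 12.8]),   # small tensor -> list
}
print(json.dumps(result, indent=4, cls=MyJsonEncoder))
```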
/models/utils/my_logger.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | from pytorch_lightning.loggers import TensorBoardLogger
3 | from pytorch_lightning.utilities import rank_zero_only
4 |
5 |
6 | class MyLogger(TensorBoardLogger):
7 |
8 | @rank_zero_only
9 | def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
10 | for k, v in metrics.items():
11 | _my_step = step
12 | if k.startswith('val/'): # use epoch for val metrics
13 | _my_step = int(metrics['epoch'])
14 | super().log_metrics(metrics={k: v}, step=_my_step)
15 |
--------------------------------------------------------------------------------
/models/utils/my_progress_bar.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.callbacks.progress import TQDMProgressBar
2 | import sys
3 | from torch import Tensor
4 | from pytorch_lightning import Trainer
5 |
6 |
7 | class MyProgressBar(TQDMProgressBar):
8 | """print out the metrics on_validation_epoch_end
9 | """
10 |
11 | def on_validation_epoch_end(self, trainer: Trainer, pl_module):
12 | super().on_validation_epoch_end(trainer, pl_module)
13 | sys.stdout.flush()
14 | if trainer.is_global_zero:
15 | metrics = trainer.logged_metrics
16 | infos = f"\x1B[1A\x1B[K\nEpoch {trainer.current_epoch} metrics: "
17 | for k, v in metrics.items():
18 | value = v
19 | if isinstance(v, Tensor):
20 | value = v.item()
21 | if isinstance(value, float):
22 | infos += k + f"={value:.4f} "
23 | else:
24 | infos += k + f"={value} "
25 | if len(metrics) > 0:
26 | sys.stdout.write(f'{infos}\x1B[K\n')
27 | sys.stdout.flush()
28 |
--------------------------------------------------------------------------------
/models/utils/my_rich_progress_bar.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from pytorch_lightning import Trainer
4 | from pytorch_lightning.callbacks import RichProgressBar
5 | from pytorch_lightning.callbacks.progress.rich_progress import *
6 | from torch import Tensor
7 |
8 |
9 | class MyRichProgressBar(RichProgressBar):
10 | """A progress bar prints metrics at the end of each epoch
11 | """
12 |
13 | def on_validation_end(self, trainer: Trainer, pl_module):
14 | super().on_validation_end(trainer, pl_module)
15 | sys.stdout.flush()
16 | if trainer.is_global_zero:
17 | metrics = trainer.logged_metrics
18 | infos = f"Epoch {trainer.current_epoch} metrics: "
19 | for k, v in metrics.items():
20 | if k.startswith('train/'):
21 | continue
22 | value = v
23 | if isinstance(v, Tensor):
24 | value = v.item()
25 | if isinstance(value, float):
26 | if abs(value) < 1:
27 | infos += k + f"={value:.4e} "
28 | else:
29 | infos += k + f"={value:.4f} "
30 | else:
31 | infos += k + f"={value} "
32 | if len(metrics) > 0:
33 | sys.stdout.write(f'{infos}\n')
34 | sys.stdout.flush()
35 |
--------------------------------------------------------------------------------
/models/utils/my_save_config_callback.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | from jsonargparse import Namespace
3 | from pytorch_lightning import Trainer, LightningModule
4 | from pytorch_lightning.cli import SaveConfigCallback
5 |
6 |
7 | class MySaveConfigCallback(SaveConfigCallback):
8 | ignores: List[str] = ['progress_bar', 'learning_rate_monitor', 'model_summary'] # 'model_checkpoint',
9 |
10 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
11 | for ignore in MySaveConfigCallback.ignores:
12 | self.del_config(ignore)
13 | super().setup(trainer, pl_module, stage)
14 |
15 | @staticmethod
16 | def add_ignores(ignore: str):
17 | MySaveConfigCallback.ignores.append(ignore)
18 |
19 | def del_config(self, ignore: str):
20 | if '.' not in ignore:
21 | if ignore in self.config:
22 | del self.config[ignore]
23 | else:
24 | config: Namespace = self.config
25 | ignore_namespace = ignore.split('.')
26 | for idx, name in enumerate(ignore_namespace):
27 | if idx != len(ignore_namespace) - 1:
28 | if name in config:
29 | config = config[name]
30 | else:
31 | return
32 | else:
33 | if name in config: del config[name]  # skip keys that are not present
34 |
--------------------------------------------------------------------------------
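A small sketch of extending the ignore list from a model-specific CLI before training starts; the key names below are illustrative assumptions, not keys required by this repository.

```python
# Hypothetical sketch: drop extra keys from the saved config.yaml
# (the key names below are illustrative).
from models.utils.my_save_config_callback import MySaveConfigCallback

MySaveConfigCallback.add_ignores('model_checkpoint')               # a top-level key
MySaveConfigCallback.add_ignores('model.arch.init_args.compile')   # a nested key
```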
/models/utils/shared_cli.py:
--------------------------------------------------------------------------------
1 | """
2 | Command Line Interface for different models, provides command line controls for training, test, and inference
3 | """
4 | import os
5 |
6 | os.environ["OMP_NUM_THREADS"] = str(1) # limit the threads to reduce cpu overloads, will speed up when there are lots of CPU cores on the running machine
7 |
8 | import torch
9 |
10 | torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 for matmul (defaults to False in PyTorch 1.12 and later)
11 | torch.backends.cudnn.allow_tf32 = True  # allow TF32 for cuDNN (defaults to True)
12 |
13 | import pytorch_lightning as pl
14 | from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, ModelSummary
15 | from pytorch_lightning.cli import LightningArgumentParser, LightningCLI
16 | from jsonargparse import lazy_instance
17 |
18 | from models.utils import MyRichProgressBar as RichProgressBar
19 | # from pytorch_lightning.loggers import TensorBoardLogger
20 | from models.utils.my_logger import MyLogger as TensorBoardLogger
21 | from models.utils.my_save_config_callback import MySaveConfigCallback as SaveConfigCallback
22 |
23 |
24 | class SharedCLI(LightningCLI):
25 |
26 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
27 | parser.set_defaults({"trainer.strategy": "ddp_find_unused_parameters_false"})
28 |
29 | # RichProgressBar
30 | parser.add_lightning_class_args(RichProgressBar, nested_key='progress_bar')
31 | if pl.__version__.startswith('1.5.'):
32 | parser.set_defaults({
33 | "progress_bar.refresh_rate_per_second": 1,
34 | })
35 | else:
36 | parser.set_defaults({"progress_bar.console_kwargs": {
37 | "force_terminal": True,
38 | "no_color": True,
39 | "width": 200,
40 | }})
41 |
42 | # LearningRateMonitor
43 | parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor")
44 | learning_rate_monitor_defaults = {
45 | "learning_rate_monitor.logging_interval": "epoch",
46 | }
47 | parser.set_defaults(learning_rate_monitor_defaults)
48 |
49 | # ModelSummary
50 | parser.add_lightning_class_args(ModelSummary, 'model_summary')
51 | model_summary_defaults = {
52 | "model_summary.max_depth": -1,
53 | }
54 | parser.set_defaults(model_summary_defaults)
55 |
56 | def before_fit(self):
57 | resume_from_checkpoint: str = self.config['fit']['ckpt_path']
58 | if resume_from_checkpoint is not None and resume_from_checkpoint.endswith('last.ckpt'):
59 | # log in same dir
60 | # resume_from_checkpoint example: /mnt/home/quancs/projects/NBSS_pmt/logs/NBSS_ifp/version_29/checkpoints/last.ckpt
61 | resume_from_checkpoint = os.path.normpath(resume_from_checkpoint)
62 | splits = resume_from_checkpoint.split(os.path.sep)
63 | version = int(splits[-3].replace('version_', ''))
64 | save_dir = os.path.sep.join(splits[:-3])
65 | self.trainer.logger = TensorBoardLogger(save_dir=save_dir, name="", version=version, default_hp_metric=False)
66 | else:
67 | model_name = type(self.model).__name__
68 | self.trainer.logger = TensorBoardLogger('logs/', name=model_name, default_hp_metric=False)
69 |
70 | def before_test(self):
71 | if self.config['test']['ckpt_path'] is not None:
72 | ckpt_path = self.config['test']['ckpt_path']
73 | else:
74 | raise Exception('You should give --ckpt_path if you want to test')
75 | epoch = os.path.basename(ckpt_path).split('_')[0]
76 | write_dir = os.path.dirname(os.path.dirname(ckpt_path))
77 |
78 | test_set = 'test'
79 | if 'test_set' in self.config['test']['data']:
80 | test_set = self.config['test']['data']["test_set"]
81 | elif 'init_args' in self.config['test']['data'] and 'test_set' in self.config['test']['data']['init_args']:
82 | test_set = self.config['test']['data']['init_args']["test_set"]
83 | exp_save_path = os.path.normpath(write_dir + '/' + epoch + '_' + test_set + '_set')
84 |
85 | self.trainer.logger = TensorBoardLogger(exp_save_path, name='', default_hp_metric=False)
86 | print(self.trainer.log_dir)
87 |
88 | def after_test(self):
89 | if not self.trainer.is_global_zero:
90 | return
91 | import fnmatch
92 | files = fnmatch.filter(os.listdir(self.trainer.log_dir), 'events.out.tfevents.*')
93 | for f in files:
94 | os.remove(self.trainer.log_dir + '/' + f)
95 | print('tensorboard log file for test is removed: ' + self.trainer.log_dir + '/' + f)
96 |
97 |
98 | if __name__ == '__main__':
99 | cli = SharedCLI(
100 | pl.LightningModule,
101 | pl.LightningDataModule,
102 | save_config_callback=SaveConfigCallback,
103 | save_config_kwargs={'overwrite': True},
104 | subclass_mode_data=True,
105 | subclass_mode_model=False,
106 | run=True,
107 | )
108 |
--------------------------------------------------------------------------------
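A hedged sketch of how a model entry point built on `SharedCLI` might be invoked, following the comment-style command examples used elsewhere in this repo; the script name, config files, and checkpoint paths below are placeholders, not commands documented by this repository.

```python
# Hypothetical command lines for an entry script built on SharedCLI
# (script name, config files and checkpoint paths are placeholders):
#
#   python your_trainer.py fit --config your_model.yaml --config your_dataset.yaml
#   python your_trainer.py test --config logs/YourModel/version_x/config.yaml \
#       --ckpt_path logs/YourModel/version_x/checkpoints/last.ckpt
```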
/requirements.txt:
--------------------------------------------------------------------------------
1 | jsonargparse[signatures,urls]>=4.3.1
2 | torchmetrics[audio]
3 | omegaconf
4 | pytorch-lightning>=2.0.0
5 | torch>=1.12.1
6 | rich
7 | mir_eval
8 | soundfile
9 | pesq
10 | tensorboard
11 | pandas
12 | torcheval==0.0.6
13 | mamba-ssm
14 | causal-conv1d>=1.1.0
--------------------------------------------------------------------------------