├── .gitignore ├── LICENSE ├── README.md ├── VoiceMOS_baseline_README.md ├── configs ├── LDNet-ML_MobileNetV3_FFN_1e-3.yaml ├── LDNet-ML_MobileNetV3_FFN_1e-3_ft.yaml ├── LDNet-MN_MobileNetV3_RNN_FFN_1e-3_lamb4.yaml ├── LDNet_MobileNetV3_FFN_1e-3.yaml ├── LDNet_MobileNetV3_RNN_5e-3.yaml └── MBNet.yaml ├── data ├── download.sh └── vcc2018 │ ├── vcc2018_testing_data.csv │ ├── vcc2018_training_data.csv │ └── vcc2018_valid_data.csv ├── dataset.py ├── exp ├── FT_LDNet-ML-bs20 │ ├── config.yml │ ├── idtable.pkl │ └── model-6400.pt └── Pretrained-LDNet-ML-2337 │ ├── config.yml │ └── model-27000.pt ├── imgs ├── 60000_distribution.png ├── 60000_sys_scatter_plot_utt.png ├── 60000_utt_scatter_plot_utt.png └── results.png ├── inference.py ├── inference_for_voicemos.py ├── models ├── LDNet.py ├── MBNet.py ├── __init__.py ├── loss.py ├── mobilenetv2.py ├── mobilenetv3.py └── modules.py ├── optimizers.py ├── pack_for_voicemos.sh ├── schedulers.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # specific to thie repo 132 | data/ 133 | *.wav 134 | exp/ 135 | runs/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Wen-Chin Huang (unilight) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LDNet 2 | 3 | Author: Wen-Chin Huang (Nagoya University) 4 | Email: wen.chinhuang@g.sp.m.is.nagoya-u.ac.jp 5 | 6 | This is the official implementation of the paper "LDNet: Unified Listener Dependent Modeling in MOS Prediction for Synthetic Speech". This is a model that takes an input synthetic speech sample and outputs the simulated human rating. 7 | 8 | ![Results](./imgs/results.png) 9 | 10 | ## Requirements 11 | 12 | - PyTorch 1.9 (versions not too old should be fine.) 13 | - librosa 14 | - pandas 15 | - h5py 16 | - scipy 17 | - matplotlib 18 | - tqdm 19 | 20 | ## Usage 21 | 22 | The following instructions are for the VCC2018 benchmark. 23 | 24 | **New**: This system is also one of the baseline systems of the first VoiceMOS Challenge. Please refer to [this document](./VoiceMOS_baseline_README.md) for detailed instructions. 25 | 26 | ### Data preparation 27 | 28 | ``` 29 | # Download the VCC2018 dataset. 
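# (download.sh wgets the VCC2018 submitted-systems converted-speech archive from datashare.ed.ac.uk,
#  extracts it, and leaves the .wav files under data/vcc2018/ — see data/download.sh below)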
30 | cd data 31 | ./download.sh vcc2018 32 | ``` 33 | 34 | ### Training 35 | 36 | We provide configs that correspond to the following rows in the above figure: 37 | 38 | - (a): `MBNet.yaml` 39 | - (d): `LDNet_MobileNetV3_RNN_5e-3.yaml` 40 | - (e): `LDNet_MobileNetV3_FFN_1e-3.yaml` 41 | - (f): `LDNet-MN_MobileNetV3_RNN_FFN_1e-3_lamb4.yaml` 42 | - (g): `LDNet-ML_MobileNetV3_FFN_1e-3.yaml` 43 | 44 | ``` 45 | python train.py --config configs/ --tag 46 | ``` 47 | 48 | By default, the experimental results will be stored in `exp/`, including: 49 | 50 | - `model-.pt`: model checkpoints. 51 | - `config.yml`: the config file. 52 | - `idtable.pkl`: the dictionary that maps listener to ID. 53 | - `training_`: the validation results generated along the training. This file is useful for model selection. Note that the `inference_mode` in the config file decides what mode is used during validation in the training. 54 | 55 | There are some arguments that can be changed: 56 | 57 | - `--exp_dir`: The directory for storing the experimental results. 58 | - `--data_dir`: The data directory. Default is `data/vcc2018`. 59 | - `seed`: random seed. 60 | - `update_freq`: *This is very important. See below.* 61 | 62 | ### Batch size and `update_freq` 63 | 64 | By default, all LDNet models are trained with a batch size of 60. In my experiments, I used a single NVIDIA GeForce RTX 3090 with 24GB mdemory for training. I cannot fit the whole model in the GPU, so I accumulate gradients for `update_freq` forward passes and do one backward update. Before training, please check the `train_batch_size` in the config file, and set `update_freq` properly. For instance, in `configs/LDNet_MobileNetV3_FFN_1e-3.yaml` the `train_batch_size` is 20, so `update_freq` should be set to 3. 65 | 66 | ### Inference 67 | 68 | ``` 69 | python inference.py --tag LDNet-ML_MobileNetV3_FFN_1e-3 --mode mean_listener 70 | ``` 71 | 72 | Use `mode` to specify which inference mode to use. Choices are: `mean_net`, `all_listeners` and `mean_listener`. By default, all checkpoints in the exp directory will be evaluated. 73 | 74 | There are some arguments that can be changed: 75 | 76 | - `ep`: if you want to evaluate one model checkpoint, say, `model-10000.pt`, then simply pass `--ep 10000`. 77 | - `start_ep`: if you want to evaluate model checkpoints after a certain steps, say, 10000 steps later, then simply pass `--start_ep 10000`. 78 | 79 | There are some files you can inspect after the evaluation: 80 | 81 | - `_.csv`: the validation and test set results. 82 | - `__/`: figures that visualize the prediction distributions, including; 83 | - `_distribution.png`: distribution over the score range (1-5). 84 | 85 | - `_utt_scatter_plot_utt`: _utterance-wise_ scatter plot of the ground truth and the predicted scores. 86 | 87 | - `_sys_scatter_plot_utt`: _system-wise_ scatter plot of the ground truth and the predicted scores. 88 | 89 | 90 | ## Acknowledgement 91 | 92 | This repository inherits from this great [unofficial MBNet implementation](https://github.com/sky1456723/Pytorch-MBNet). 
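## Note on `update_freq` (gradient accumulation)

The `update_freq` option described in the Training section above is plain gradient accumulation. The snippet below is a minimal, illustrative sketch of the idea, not the actual `train.py` loop; the dummy model, fake data, and variable names are placeholders:

```python
# Illustrative gradient-accumulation sketch (not the actual train.py loop).
# With train_batch_size = 20 and update_freq = 3, each optimizer step
# effectively sees 20 * 3 = 60 samples, the default LDNet batch size.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(257, 1)                              # stand-in for LDNet
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
update_freq = 3                                              # train_batch_size * update_freq == 60

for step in range(1, 13):
    x = torch.randn(20, 257)                                 # one "physical" batch of 20 samples
    y = torch.rand(20, 1) * 4 + 1                            # fake MOS targets in [1, 5]
    loss = criterion(model(x), y)
    (loss / update_freq).backward()                          # scale so gradients average over the virtual batch
    if step % update_freq == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # grad_clip: 1 in the configs
        optimizer.step()
        optimizer.zero_grad()
```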
93 | 94 | 95 | ## Citation 96 | 97 | If you find this recipe useful, please consider citing following paper: 98 | ``` 99 | @article{huang2021ldnet, 100 | title={LDNet: Unified Listener Dependent Modeling in MOS Prediction for Synthetic Speech}, 101 | author={Huang, Wen-Chin and Cooper, Erica and Yamagishi, Junichi and Toda, Tomoki}, 102 | journal={arXiv preprint arXiv:2110.09103}, 103 | year={2021} 104 | } 105 | ``` 106 | -------------------------------------------------------------------------------- /VoiceMOS_baseline_README.md: -------------------------------------------------------------------------------- 1 | # LDNet: Baeline system of the first VoiceMOS challenge 2 | 3 | Author: Wen-Chin Huang (Nagoya University) 4 | Email: wen.chinhuang@g.sp.m.is.nagoya-u.ac.jp 5 | 6 | The LDNet system implemented in this repository serves as one of the baselines of the first VoiceMOS challenge, a challenge to compare different systems and approaches on the task of predicting the MOS score of synthetic speech. 7 | 8 | ## Training phase (phase 1) 9 | 10 | During the training phase, participants should receive datasets for the main (BVCC) track and the OOD track. Each dataset will contain the training set and the developement set. In this instruction, we demonstrate how to perform: 11 | 12 | - Training a LDNet from scratch with the BVCC training set. 13 | - Inference on the BVCC validation set with the trained LDNet. 14 | - Zero-shot inference on the OOD validation set with the LDNet trained on the BVCC training set. 15 | - Fine-tuning on the OOD labeled training set with with the LDNet pretrained on the BVCC training set. 16 | - Inference on the OOD validation set with the LDNet fine-tuned on the OOD labeled training set. 17 | - Submission generation for the CodaLab platform. 18 | 19 | ### Data preparation for both tracks 20 | 21 | After downloading the dataset preparation scripts for both tracks, please follow the instructions to gather the complete datasets. For the rest of this README, we assume that the data is put under `data/`, but feel free to put it somewhere else. The data directorty should have the following structure: 22 | ``` 23 | data 24 | ├── phase1-main 25 | │    ├── DATA 26 | │    │   ├── mydata_system.csv 27 | │    │   ├── sets 28 | │ │ │ ├── DEVSET 29 | │ │ │ ├── train_mos_list.txt 30 | │ │ │ ├── TRAINSET 31 | │ │ │ └── val_mos_list.txt 32 | │    │   └── wav 33 | │    └─── ... 34 | └── phase1-ood 35 | ├── DATA 36 | │   ├── mydata_system.csv 37 | │   ├── sets 38 | │ │ ├── DEVSET 39 | │ │ ├── train_mos_list.txt 40 | │ │ ├── TRAINSET 41 | │ │ ├── unlabeled_mos_list.txt 42 | │ │ └── val_mos_list.txt 43 | │   └── wav 44 | └── ... 45 | ``` 46 | 47 | ### Pretrained model 48 | 49 | We provide two pretrained models. Feel free to use these pretrained models to test the inference code and generate sample submission files to familizarize yourself with CodaLab. 50 | 51 | - `exp/Pretrained-LDNet-ML-2337`: model trained on the BVCC training set with the `LDNet-ML_MobileNetV3_FFN_1e-3.yaml` config and seed `2337`. 52 | - `exp/FT_LDNet-ML-bs20`: model fine-tuned on the OOD labeled training set with the `LDNet-ML_MobileNetV3_FFN_1e-3_ft.yaml` config and seed `1337`. 53 | 54 | ### Training a LDNet from scratch with the BVCC training set. 55 | 56 | Although you can use any config mentioned in the [main README](./README.md#Training), according to the [results](./imgs/results.png), `LDNet-ML` gives the best results. 
So, here we demonstrate how to use the `LDNet-ML_MobileNetV3_FFN_1e-3.yaml` config with seed `2337` to perform training: 57 | 58 | ``` 59 | python train.py --dataset_name BVCC --data_dir ./data/phase1-main/DATA --config configs/LDNet-ML_MobileNetV3_FFN_1e-3.yaml --update_freq 2 --seed 2337 --tag 60 | ``` 61 | 62 | All checkpoints will be saved in, by default, `exp//model-.pt` 63 | 64 | ### Inference on the BVCC validation set with the trained LDNet. 65 | 66 | After the training ends, we can do inference using the saved checkpoints one-by-one: 67 | 68 | ``` 69 | python inference_for_voicemos.py --dataset_name BVCC --data_dir data/phase1-main/DATA --tag --mode mean_listener 70 | ``` 71 | 72 | This is what can be expected if using the `Pretrained-LDNet-ML-2337` model: 73 | 74 | ``` 75 | [Info] Number of valid samples: 1066 76 | [Info] Model parameters: 957057 77 | ================================================= 78 | [Info] Evaluating ep 27000 79 | 100%|██████████████| 1066/1066 [00:16<00:00, 66.50it/s] 80 | [UTTERANCE] valid MSE: 0.318, LCC: 0.785, SRCC: 0.787, KTAU: 0.597 81 | [SYSTEM] valid MSE: 0.110, LCC: 0.925, SRCC: 0.918, KTAU: 0.757 82 | ``` 83 | 84 | All results will be saved in `exp//BVCC_mean_listener_valid/`, including some figures for inspection (see [here](./README.md#Inference) for more details) and a file named `_answer.txt`. A summary of the results will also be saved in `exp//BVCC_mean_listener.csv`. We can use this file to manually choose which checkpoint we want to use. 85 | 86 | ### Zero-shot inference on the OOD validation set with the LDNet trained on the BVCC training set. 87 | 88 | We can simply change the dataset to perform zero-shot inference on the OOD validation set: 89 | 90 | ``` 91 | python inference_for_voicemos.py --dataset_name OOD --data_dir data/phase1-ood/DATA --tag --mode mean_listener 92 | ``` 93 | 94 | Then, all results will be saved in `exp//OOD_mean_listener_valid/`. 95 | 96 | ### Fine-tuning on the OOD labeled training set with with the LDNet pretrained on the BVCC training set 97 | 98 | We can perform a very simple fine-tuning procedure on the OOD labeled training set with a model pretrained on the BVCC training set. Here we demonstrate how to use the `LDNet-ML_MobileNetV3_FFN_1e-3_ft.yaml` config with the `Pretrained-LDNet-ML-2337` pretrained model to perform fine-tuning: 99 | 100 | ``` 101 | python train.py --dataset_name OOD --data_dir data/phase1-ood/DATA --config configs/LDNet-ML_MobileNetV3_FFN_1e-3_ft.yaml --tag --pretrained_model_path exp/Pretrained-LDNet-ML-2337/model-27000.pt 102 | ``` 103 | 104 | All checkpoints will be saved in, again, by default, `exp//model-.pt` 105 | 106 | ### Inference on the OOD validation set with the LDNet fine-tuned on the OOD labeled training set. 107 | 108 | Similar to above, we can simply change the dataset to perform inference on the OOD validation set using the fine-tuned model: 109 | 110 | ``` 111 | python inference_for_voicemos.py --dataset_name OOD --data_dir data/phase1-ood/DATA --tag --mode mean_listener 112 | ``` 113 | 114 | Then, all results will be saved in `exp//OOD_mean_listener_valid/`. 115 | 116 | ### Submission generation for the CodaLab platform 117 | 118 | The submission format of the CodaLab competition platform is a zip file (can be any name) containing a text file called `answer.txt` (this naming is a **MUST**). We have prepared a convenient packing script. 
If you only want to submit the main track result, pass the zip file name as the first argument, and pass the main track answer file as the the second argument. The following is an example: 119 | 120 | ``` 121 | ./pack_for_voicemos.sh .zip exp//BVCC__/_answer.txt 122 | ``` 123 | 124 | If you want to submit results for both the main track and the OOD track, pass the OOD track answer as the third argument: 125 | 126 | ``` 127 | ./pack_for_voicemos.sh .zip exp//BVCC__/_answer.txt exp//OOD__/_answer.txt 128 | ``` 129 | 130 | Then, submit the generated `.zip` to the CodaLab platform! 131 | 132 | ## Evaluation phase (phase 2) 133 | 134 | Will be updated once we enter phase 2! -------------------------------------------------------------------------------- /configs/LDNet-ML_MobileNetV3_FFN_1e-3.yaml: -------------------------------------------------------------------------------- 1 | # model configurations 2 | model: "LDNet" 3 | audio_input_dim: 257 4 | judge_emb_dim: 128 5 | 6 | encoder_type: "mobilenetv3" 7 | encoder_bneck_configs: 8 | - [16, 3, 16, 16, True, "RE", 3, 1] 9 | - [16, 3, 72, 24, False, "RE", 3, 1] 10 | - [24, 3, 88, 24, False, "RE", 1, 1] 11 | - [24, 5, 96, 40, True, "HS", 3, 1] 12 | - [40, 5, 240, 40, True, "HS", 1, 1] 13 | - [40, 5, 240, 40, True, "HS", 1, 1] 14 | - [40, 5, 120, 48, True, "HS", 1, 1] 15 | - [48, 5, 144, 48, True, "HS", 1, 1] 16 | - [48, 5, 288, 96, True, "HS", 3, 1] 17 | - [96, 5, 576, 96, True, "HS", 1, 1] 18 | - [96, 5, 576, 96, True, "HS", 1, 1] 19 | encoder_output_dim: 256 20 | 21 | decoder_type: "ffn" 22 | decoder_rnn_dim: 128 23 | decoder_dnn_dim: 64 24 | decoder_dropout_rate: 0.3 25 | 26 | activation: "ReLU" 27 | range_clipping: True # this is needed if output_type is scalar 28 | combine_mean_score: False 29 | use_mean_listener: True 30 | 31 | output_type: "scalar" 32 | 33 | # training configurations 34 | optimizer: 35 | name: "RMSprop" 36 | lr: 1.0e-3 37 | # the following params come from 38 | # https://github.com/pytorch/vision/blob/c2ab0c59f42babf9ad01aa616cd8a901daac86dd/references/classification/train.py#L172-L173 39 | eps: 0.0316 40 | alpha: 0.9 41 | scheduler: 42 | name: "stepLR" 43 | step_size: 1000 44 | gamma: 0.97 45 | train_batch_size: 30 46 | test_batch_size: 1 47 | inference_mode: "mean_listener" 48 | 49 | use_mean_net: False 50 | alpha: 0 51 | lambda: 1 52 | tau: 0.5 53 | 54 | padding_mode: "repetitive" # repetitive, zero_padding 55 | mask_loss: False 56 | total_steps: 100000 57 | valid_steps: 1000 58 | grad_clip: 1 59 | -------------------------------------------------------------------------------- /configs/LDNet-ML_MobileNetV3_FFN_1e-3_ft.yaml: -------------------------------------------------------------------------------- 1 | # model configurations 2 | model: "LDNet" 3 | audio_input_dim: 257 4 | judge_emb_dim: 128 5 | 6 | encoder_type: "mobilenetv3" 7 | encoder_bneck_configs: 8 | - [16, 3, 16, 16, True, "RE", 3, 1] 9 | - [16, 3, 72, 24, False, "RE", 3, 1] 10 | - [24, 3, 88, 24, False, "RE", 1, 1] 11 | - [24, 5, 96, 40, True, "HS", 3, 1] 12 | - [40, 5, 240, 40, True, "HS", 1, 1] 13 | - [40, 5, 240, 40, True, "HS", 1, 1] 14 | - [40, 5, 120, 48, True, "HS", 1, 1] 15 | - [48, 5, 144, 48, True, "HS", 1, 1] 16 | - [48, 5, 288, 96, True, "HS", 3, 1] 17 | - [96, 5, 576, 96, True, "HS", 1, 1] 18 | - [96, 5, 576, 96, True, "HS", 1, 1] 19 | encoder_output_dim: 256 20 | 21 | decoder_type: "ffn" 22 | decoder_rnn_dim: 128 23 | decoder_dnn_dim: 64 24 | decoder_dropout_rate: 0.3 25 | 26 | activation: "ReLU" 27 | range_clipping: True # this is needed if 
output_type is scalar 28 | combine_mean_score: False 29 | use_mean_listener: True 30 | 31 | output_type: "scalar" 32 | 33 | # training configurations 34 | optimizer: 35 | name: "RMSprop" 36 | lr: 1.0e-3 37 | # the following params come from 38 | # https://github.com/pytorch/vision/blob/c2ab0c59f42babf9ad01aa616cd8a901daac86dd/references/classification/train.py#L172-L173 39 | eps: 0.0316 40 | alpha: 0.9 41 | scheduler: 42 | name: "stepLR" 43 | step_size: 500 44 | gamma: 0.97 45 | train_batch_size: 30 46 | test_batch_size: 1 47 | inference_mode: "mean_listener" 48 | 49 | use_mean_net: False 50 | alpha: 0 51 | lambda: 1 52 | tau: 0.5 53 | 54 | padding_mode: "repetitive" # repetitive, zero_padding 55 | mask_loss: False 56 | total_steps: 10000 57 | valid_steps: 100 58 | grad_clip: 1 59 | -------------------------------------------------------------------------------- /configs/LDNet-MN_MobileNetV3_RNN_FFN_1e-3_lamb4.yaml: -------------------------------------------------------------------------------- 1 | # model configurations 2 | model: "LDNet" 3 | audio_input_dim: 257 4 | judge_emb_dim: 128 5 | 6 | mean_net_type: "ffn" 7 | mean_net_rnn_dim: 128 8 | mean_net_dnn_dim: 64 9 | mean_net_dropout_rate: 0.3 10 | mean_net_range_clipping: True # this is needed if output_type is scalar 11 | 12 | encoder_type: "mobilenetv3" 13 | encoder_bneck_configs: 14 | - [16, 3, 16, 16, True, "RE", 3, 1] 15 | - [16, 3, 72, 24, False, "RE", 3, 1] 16 | - [24, 3, 88, 24, False, "RE", 1, 1] 17 | - [24, 5, 96, 40, True, "HS", 3, 1] 18 | - [40, 5, 240, 40, True, "HS", 1, 1] 19 | - [40, 5, 240, 40, True, "HS", 1, 1] 20 | - [40, 5, 120, 48, True, "HS", 1, 1] 21 | - [48, 5, 144, 48, True, "HS", 1, 1] 22 | - [48, 5, 288, 96, True, "HS", 3, 1] 23 | - [96, 5, 576, 96, True, "HS", 1, 1] 24 | - [96, 5, 576, 96, True, "HS", 1, 1] 25 | encoder_output_dim: 256 26 | 27 | decoder_type: "rnn" 28 | decoder_rnn_dim: 128 29 | decoder_dnn_dim: 64 30 | decoder_dropout_rate: 0.3 31 | 32 | activation: "ReLU" 33 | range_clipping: True # this is needed if output_type is scalar 34 | combine_mean_score: False 35 | use_mean_listener: False 36 | 37 | output_type: "scalar" 38 | 39 | # training configurations 40 | optimizer: 41 | name: "RMSprop" 42 | lr: 1.0e-3 43 | # the following params come from 44 | # https://github.com/pytorch/vision/blob/c2ab0c59f42babf9ad01aa616cd8a901daac86dd/references/classification/train.py#L172-L173 45 | eps: 0.0316 46 | alpha: 0.9 47 | scheduler: 48 | name: "stepLR" 49 | step_size: 1000 50 | gamma: 0.97 51 | train_batch_size: 20 52 | test_batch_size: 1 53 | inference_mode: "all_listeners" 54 | 55 | use_mean_net: True 56 | alpha: 1 57 | lambda: 4 58 | tau: 0.5 59 | 60 | padding_mode: "repetitive" # repetitive, zero_padding 61 | mask_loss: False 62 | total_steps: 100000 63 | valid_steps: 1000 64 | grad_clip: 1 65 | -------------------------------------------------------------------------------- /configs/LDNet_MobileNetV3_FFN_1e-3.yaml: -------------------------------------------------------------------------------- 1 | # model configurations 2 | model: "LDNet" 3 | audio_input_dim: 257 4 | judge_emb_dim: 128 5 | 6 | encoder_type: "mobilenetv3" 7 | encoder_bneck_configs: 8 | - [16, 3, 16, 16, True, "RE", 3, 1] 9 | - [16, 3, 72, 24, False, "RE", 3, 1] 10 | - [24, 3, 88, 24, False, "RE", 1, 1] 11 | - [24, 5, 96, 40, True, "HS", 3, 1] 12 | - [40, 5, 240, 40, True, "HS", 1, 1] 13 | - [40, 5, 240, 40, True, "HS", 1, 1] 14 | - [40, 5, 120, 48, True, "HS", 1, 1] 15 | - [48, 5, 144, 48, True, "HS", 1, 1] 16 | - [48, 5, 288, 96, True, 
"HS", 3, 1] 17 | - [96, 5, 576, 96, True, "HS", 1, 1] 18 | - [96, 5, 576, 96, True, "HS", 1, 1] 19 | encoder_output_dim: 256 20 | 21 | decoder_type: "ffn" 22 | decoder_dnn_dim: 64 23 | decoder_dropout_rate: 0.3 24 | 25 | activation: "ReLU" 26 | range_clipping: True # this is needed if output_type is scalar 27 | combine_mean_score: False 28 | use_mean_listener: False 29 | 30 | output_type: "scalar" 31 | 32 | # training configurations 33 | optimizer: 34 | name: "RMSprop" 35 | lr: 1.0e-3 36 | # the following params come from 37 | # https://github.com/pytorch/vision/blob/c2ab0c59f42babf9ad01aa616cd8a901daac86dd/references/classification/train.py#L172-L173 38 | eps: 0.0316 39 | alpha: 0.9 40 | scheduler: 41 | name: "stepLR" 42 | step_size: 1000 43 | gamma: 0.97 44 | train_batch_size: 20 45 | test_batch_size: 1 46 | inference_mode: "all_listeners" 47 | 48 | use_mean_net: False 49 | alpha: 0 50 | lambda: 1 51 | tau: 0.5 52 | 53 | padding_mode: "repetitive" # repetitive, zero_padding 54 | mask_loss: False 55 | total_steps: 100000 56 | valid_steps: 1000 57 | grad_clip: 1 58 | -------------------------------------------------------------------------------- /configs/LDNet_MobileNetV3_RNN_5e-3.yaml: -------------------------------------------------------------------------------- 1 | # model configurations 2 | model: "LDNet" 3 | audio_input_dim: 257 4 | judge_emb_dim: 128 5 | 6 | encoder_type: "mobilenetv3" 7 | encoder_bneck_configs: 8 | - [16, 3, 16, 16, True, "RE", 3, 1] 9 | - [16, 3, 72, 24, False, "RE", 3, 1] 10 | - [24, 3, 88, 24, False, "RE", 1, 1] 11 | - [24, 5, 96, 40, True, "HS", 3, 1] 12 | - [40, 5, 240, 40, True, "HS", 1, 1] 13 | - [40, 5, 240, 40, True, "HS", 1, 1] 14 | - [40, 5, 120, 48, True, "HS", 1, 1] 15 | - [48, 5, 144, 48, True, "HS", 1, 1] 16 | - [48, 5, 288, 96, True, "HS", 3, 1] 17 | - [96, 5, 576, 96, True, "HS", 1, 1] 18 | - [96, 5, 576, 96, True, "HS", 1, 1] 19 | encoder_output_dim: 256 20 | 21 | decoder_type: "rnn" 22 | decoder_rnn_dim: 128 23 | decoder_dnn_dim: 64 24 | decoder_dropout_rate: 0.3 25 | 26 | activation: "ReLU" 27 | range_clipping: True # this is needed if output_type is scalar 28 | combine_mean_score: False 29 | use_mean_listener: False 30 | 31 | output_type: "scalar" 32 | 33 | # training configurations 34 | optimizer: 35 | name: "RMSprop" 36 | lr: 5.0e-3 37 | # the following params come from 38 | # https://github.com/pytorch/vision/blob/c2ab0c59f42babf9ad01aa616cd8a901daac86dd/references/classification/train.py#L172-L173 39 | eps: 0.0316 40 | alpha: 0.9 41 | scheduler: 42 | name: "stepLR" 43 | step_size: 1000 44 | gamma: 0.97 45 | train_batch_size: 20 46 | test_batch_size: 1 47 | inference_mode: "all_listeners" 48 | 49 | use_mean_net: False 50 | alpha: 0 51 | lambda: 1 52 | tau: 0.5 53 | 54 | padding_mode: "repetitive" # repetitive, zero_padding 55 | mask_loss: False 56 | total_steps: 100000 57 | valid_steps: 1000 58 | grad_clip: 1 59 | -------------------------------------------------------------------------------- /configs/MBNet.yaml: -------------------------------------------------------------------------------- 1 | # model configurations 2 | model: "MBNet" 3 | audio_input_dim: 257 4 | judge_emb_dim: 86 # simple encoder case 5 | 6 | mean_net_input: "audio" 7 | mean_net_conv_chs: [16, 32, 64, 128] 8 | mean_net_rnn_dim: 128 9 | mean_net_dnn_dim: 128 10 | mean_net_dropout_rate: 0.3 11 | mean_net_range_clipping: True 12 | mean_net_output_type: "scalar" 13 | 14 | decoder_conv_chs: [32, 32, 32] 15 | decoder_rnn_dim: 64 16 | decoder_dnn_dim: 32 17 | 
decoder_dropout_rate: 0.3 18 | 19 | activation: "ReLU" 20 | range_clipping: True 21 | combine_mean_score: True 22 | use_mean_listener: False 23 | 24 | output_type: "scalar" 25 | 26 | # training configurations 27 | optimizer: 28 | name: Adam 29 | lr: 1.0e-4 30 | train_batch_size: 64 31 | test_batch_size: 1 32 | inference_mode: "mean_net" 33 | 34 | use_mean_net: True 35 | alpha: 1 36 | lambda: 4 37 | tau: 0.5 38 | 39 | padding_mode: "repetitive" # repetitive, zero_padding 40 | mask_loss: False 41 | total_steps: 50000 42 | valid_steps: 1000 43 | grad_clip: 1 44 | -------------------------------------------------------------------------------- /data/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dataset=$1 4 | 5 | if [ ${dataset} = "vcc2018" ]; then 6 | mkdir -p ${dataset} 7 | cd ${dataset} 8 | if [ ! -e ./.done ]; then 9 | wget https://datashare.ed.ac.uk/bitstream/handle/10283/3061/vcc2018_submitted_systems_converted_speech.tar.gz 10 | tar zxvf vcc2018_submitted_systems_converted_speech.tar.gz 11 | rm -f vcc2018_submitted_systems_converted_speech.tar.gz 12 | mv mnt/sysope/test_files/testVCC2/*.wav . 13 | rm -rf mnt/ 14 | echo "Successfully finished download." 15 | touch ./.done 16 | else 17 | echo "Already exists. Skip download." 18 | fi 19 | cd ../ 20 | else 21 | echo "Dataset not supported." 22 | fi -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import librosa 3 | import numpy as np 4 | import torch 5 | from torch.utils.data.dataset import Dataset 6 | from torch.utils.data import DataLoader 7 | import pandas as pd 8 | import scipy 9 | 10 | from collections import defaultdict 11 | import h5py 12 | from torch.nn.utils.rnn import pad_sequence 13 | 14 | FFT_SIZE = 512 15 | SGRAM_DIM = FFT_SIZE // 2 + 1 16 | 17 | class ASVBC19Dataset(Dataset): 18 | def __init__(self, original_metadata, data_dir, idtable_path=None, split="train", padding_mode="zero_padding", use_mean_listener=False): 19 | self.data_dir = data_dir 20 | self.split = split 21 | self.padding_mode = padding_mode 22 | self.use_mean_listener = use_mean_listener 23 | 24 | # add mean listener to metadata 25 | if use_mean_listener: 26 | mean_listener_metadata = self.gen_mean_listener_metadata(original_metadata) 27 | metadata = original_metadata + mean_listener_metadata 28 | else: 29 | metadata = original_metadata 30 | 31 | # get judge id table and number of judges 32 | if idtable_path is not None: 33 | if os.path.isfile(idtable_path): 34 | self.idtable = torch.load(idtable_path) 35 | elif self.split == "train": 36 | self.gen_idtable(metadata, idtable_path) 37 | self.num_judges = len(self.idtable) 38 | 39 | self.metadata = [] 40 | if self.split == "train": 41 | #(NOTE) unlight (210921): need to fix this in the future if we want to do training. 
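# In training mode each metadata entry is [wav_name, judge_name, avg_score, score];
# judge names are replaced below by the integer IDs from self.idtable.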
42 | for wav_name, judge_name, avg_score, score in metadata: 43 | self.metadata.append([wav_name, avg_score, score, self.idtable[judge_name]]) 44 | else: 45 | for item in metadata: 46 | self.metadata.append(item) 47 | 48 | # build system list 49 | self.systems = list(set([item[0] for item in metadata])) 50 | 51 | def __getitem__(self, idx): 52 | if self.split == "train": 53 | wav_name, avg_score, score, judge_id = self.metadata[idx] 54 | else: 55 | sys_name, wav_name, avg_score = self.metadata[idx] 56 | 57 | h5_path = os.path.join(self.data_dir, "bin", wav_name.replace(".wav", ".h5")) 58 | data_file = h5py.File(h5_path, 'r') 59 | mag_sgram = np.array(data_file['mag_sgram'][:]) 60 | timestep = mag_sgram.shape[0] 61 | mag_sgram = np.reshape(mag_sgram,(timestep, SGRAM_DIM)) 62 | 63 | if self.split == "train": 64 | return mag_sgram, avg_score, score, judge_id 65 | else: 66 | return mag_sgram, avg_score, sys_name 67 | 68 | def __len__(self): 69 | return len(self.metadata) 70 | 71 | def gen_mean_listener_metadata(self, original_metadata): 72 | assert self.split == "train" 73 | mean_listener_metadata = [] 74 | wav_names = set() 75 | for wav_name, _, avg_score, _ in original_metadata: 76 | if wav_name not in wav_names: 77 | mean_listener_metadata.append([wav_name, "mean_listener", avg_score, avg_score]) 78 | wav_names.add(wav_name) 79 | return mean_listener_metadata 80 | 81 | def gen_idtable(self, metadata, idtable_path): 82 | self.idtable = {} 83 | count = 0 84 | for _, judge_name, _, _ in metadata: 85 | # mean listener always takes the last id 86 | if judge_name not in self.idtable and not judge_name == "mean_listener": 87 | self.idtable[judge_name] = count 88 | count += 1 89 | if self.use_mean_listener: 90 | self.idtable["mean_listener"] = count 91 | count += 1 92 | torch.save(self.idtable, idtable_path) 93 | 94 | def collate_fn(self, batch): 95 | sorted_batch = sorted(batch, key=lambda x: -x[0].shape[0]) 96 | bs = len(sorted_batch) # batch_size 97 | avg_scores = torch.FloatTensor([sorted_batch[i][1] for i in range(bs)]) 98 | mag_sgrams = [torch.from_numpy(sorted_batch[i][0]) for i in range(bs)] 99 | mag_sgrams_lengths = torch.from_numpy(np.array([mag_sgram.size(0) for mag_sgram in mag_sgrams])) 100 | 101 | if self.padding_mode == "zero_padding": 102 | mag_sgrams_padded = pad_sequence(mag_sgrams, batch_first=True) 103 | elif self.padding_mode == "repetitive": 104 | max_len = mag_sgrams_lengths[0] 105 | mag_sgrams_padded = [] 106 | for mag_sgram in mag_sgrams: 107 | this_len = mag_sgram.shape[0] 108 | dup_times = max_len // this_len 109 | remain = max_len - this_len * dup_times 110 | to_dup = [mag_sgram for t in range(dup_times)] 111 | to_dup.append(mag_sgram[:remain, :]) 112 | mag_sgrams_padded.append(torch.Tensor(np.concatenate(to_dup, axis = 0))) 113 | mag_sgrams_padded = torch.stack(mag_sgrams_padded, dim = 0) 114 | else: 115 | raise NotImplementedError 116 | 117 | if not self.split == "train": 118 | sys_names = [sorted_batch[i][2] for i in range(bs)] 119 | return mag_sgrams_padded, avg_scores, sys_names 120 | else: 121 | scores = torch.FloatTensor([sorted_batch[i][2] for i in range(bs)]) 122 | judge_ids = torch.LongTensor([sorted_batch[i][3] for i in range(bs)]) 123 | return mag_sgrams_padded, mag_sgrams_lengths, avg_scores, scores, judge_ids 124 | 125 | 126 | class VCC18Dataset(Dataset): 127 | def __init__(self, original_wav_file, original_score_csv, idtable_path=None, split="traing", use_mean_listener=False): 128 | self.split = split 129 | self.use_mean_listener = use_mean_listener 130 | 
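# lazily-filled cache: wav path -> magnitude spectrogram, so the many ratings of the same wav reuse one STFT (see __getitem__)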
131 | self.features = {} 132 | 133 | # add mean listener to metadata 134 | if use_mean_listener: 135 | mean_listener_wav_file, mean_listener_score_csv = self.gen_mean_listener_metadata(original_wav_file, original_score_csv) 136 | self.wavs = original_wav_file + mean_listener_wav_file 137 | self.scores = original_score_csv.append(mean_listener_score_csv, ignore_index = True) 138 | else: 139 | self.wavs = original_wav_file 140 | self.scores = original_score_csv 141 | 142 | # get judge id table and number of judges 143 | if idtable_path is not None: 144 | if os.path.isfile(idtable_path): 145 | self.idtable = torch.load(idtable_path) 146 | elif self.split == "train": 147 | self.gen_idtable(idtable_path) 148 | for i, judge_i in enumerate(self.scores['JUDGE']): 149 | self.scores['JUDGE'][i] = self.idtable[judge_i] 150 | self.num_judges = len(self.idtable) 151 | 152 | # build system list 153 | self.systems = list(set([name.split("_")[0] for name in self.scores["WAV_PATH"]])) 154 | 155 | def __getitem__(self, idx): 156 | if type(self.wavs[idx]) == int: 157 | wav_name = self.wavs[idx - self.wavs[idx]] 158 | else: 159 | wav_name = self.wavs[idx] 160 | 161 | # cache features 162 | if wav_name not in self.features: 163 | wav, _ = librosa.load(wav_name, sr = 16000) 164 | feature = np.abs(librosa.stft(wav, n_fft = 512)).T 165 | self.features[wav_name] = feature 166 | 167 | return self.features[wav_name], self.scores['MEAN'][idx], self.scores['MOS'][idx], self.scores['JUDGE'][idx], self.scores["WAV_PATH"][idx].split("_")[0] 168 | 169 | def __len__(self): 170 | return len(self.wavs) 171 | 172 | def gen_mean_listener_metadata(self, original_wav_file, original_score_csv): 173 | assert self.split == "train" 174 | ks = ["MEAN", "MOS", "JUDGE", "WAV_PATH"] 175 | mean_listener_wavs = [] 176 | mean_listener_metadata = {k: [] for k in ks} 177 | for i, line in enumerate(original_wav_file): 178 | if not type(line) == int and not line in mean_listener_wavs: 179 | mean_listener_wavs.append(line) 180 | mean_listener_metadata["MEAN"].append(original_score_csv["MEAN"][i]) 181 | mean_listener_metadata["MOS"].append(original_score_csv["MEAN"][i]) 182 | mean_listener_metadata["WAV_PATH"].append(original_score_csv["WAV_PATH"][i]) 183 | mean_listener_metadata["JUDGE"].append("mean_listener") 184 | 185 | return mean_listener_wavs, pd.DataFrame(mean_listener_metadata) 186 | 187 | def gen_idtable(self, idtable_path): 188 | self.idtable = {} 189 | count = 0 190 | for i, judge_i in enumerate(self.scores['JUDGE']): 191 | if judge_i not in self.idtable.keys() and not judge_i == "mean_listener": 192 | self.idtable[judge_i] = count 193 | count += 1 194 | if self.use_mean_listener: 195 | self.idtable["mean_listener"] = count 196 | count += 1 197 | torch.save(self.idtable, idtable_path) 198 | 199 | def collate_fn(self, samples): 200 | # wavs may be list of wave or spectrogram, which has shape (time, feature) or (time,) 201 | wavs, means, scores, judge_ids, sys_names = zip(*samples) 202 | max_len = max(wavs, key = lambda x: x.shape[0]).shape[0] 203 | wav_lengths = torch.from_numpy(np.array([wav.shape[0] for wav in wavs])) 204 | output_wavs = [] 205 | for i, wav in enumerate(wavs): 206 | wav_len = wav.shape[0] 207 | dup_times = max_len//wav_len 208 | remain = max_len - wav_len*dup_times 209 | to_dup = [wav for t in range(dup_times)] 210 | to_dup.append(wav[:remain, :]) 211 | output_wavs.append(torch.Tensor(np.concatenate(to_dup, axis = 0))) 212 | output_wavs = torch.stack(output_wavs, dim = 0) 213 | means = torch.FloatTensor(means) 214 | 
scores = torch.FloatTensor(scores) 215 | 216 | if not self.split == "train": 217 | return output_wavs, means, sys_names 218 | else: 219 | judge_ids = torch.LongTensor(judge_ids) 220 | return output_wavs, wav_lengths, means, scores, judge_ids 221 | 222 | class BCVCCDataset(Dataset): 223 | def __init__(self, original_metadata, data_dir, idtable_path=None, split="train", padding_mode="zero_padding", use_mean_listener=False): 224 | self.data_dir = data_dir 225 | self.split = split 226 | self.padding_mode = padding_mode 227 | self.use_mean_listener = use_mean_listener 228 | 229 | # cache features 230 | self.features = {} 231 | 232 | # add mean listener to metadata 233 | if use_mean_listener: 234 | mean_listener_metadata = self.gen_mean_listener_metadata(original_metadata) 235 | metadata = original_metadata + mean_listener_metadata 236 | else: 237 | metadata = original_metadata 238 | 239 | # get judge id table and number of judges 240 | if idtable_path is not None: 241 | if os.path.isfile(idtable_path): 242 | self.idtable = torch.load(idtable_path) 243 | elif self.split == "train": 244 | self.gen_idtable(metadata, idtable_path) 245 | self.num_judges = len(self.idtable) 246 | 247 | self.metadata = [] 248 | if self.split == "train": 249 | for wav_name, judge_name, avg_score, score in metadata: 250 | self.metadata.append([wav_name, avg_score, score, self.idtable[judge_name]]) 251 | else: 252 | for item in metadata: 253 | self.metadata.append(item) 254 | 255 | # build system list 256 | self.systems = list(set([item[0] for item in metadata])) 257 | 258 | def __getitem__(self, idx): 259 | if self.split == "train": 260 | wav_name, avg_score, score, judge_id = self.metadata[idx] 261 | else: 262 | sys_name, wav_name, avg_score = self.metadata[idx] 263 | 264 | # cache features 265 | if wav_name in self.features: 266 | mag_sgram = self.features[wav_name] 267 | else: 268 | h5_path = os.path.join(self.data_dir, "bin", wav_name + ".h5") 269 | if os.path.isfile(h5_path): 270 | data_file = h5py.File(h5_path, 'r') 271 | mag_sgram = np.array(data_file['mag_sgram'][:]) 272 | timestep = mag_sgram.shape[0] 273 | mag_sgram = np.reshape(mag_sgram,(timestep, SGRAM_DIM)) 274 | else: 275 | wav, _ = librosa.load(os.path.join(self.data_dir, "wav", wav_name), sr = 16000) 276 | mag_sgram = np.abs(librosa.stft(wav, n_fft = 512, hop_length=256, win_length=512, window=scipy.signal.hamming)).astype(np.float32).T 277 | self.features[wav_name] = mag_sgram 278 | 279 | if self.split == "train": 280 | return mag_sgram, avg_score, score, judge_id 281 | else: 282 | return mag_sgram, avg_score, sys_name, wav_name 283 | 284 | def __len__(self): 285 | return len(self.metadata) 286 | 287 | def gen_mean_listener_metadata(self, original_metadata): 288 | assert self.split == "train" 289 | mean_listener_metadata = [] 290 | wav_names = set() 291 | for wav_name, _, avg_score, _ in original_metadata: 292 | if wav_name not in wav_names: 293 | mean_listener_metadata.append([wav_name, "mean_listener", avg_score, avg_score]) 294 | wav_names.add(wav_name) 295 | return mean_listener_metadata 296 | 297 | def gen_idtable(self, metadata, idtable_path): 298 | self.idtable = {} 299 | count = 0 300 | for _, judge_name, _, _ in metadata: 301 | # mean listener always takes the last id 302 | if judge_name not in self.idtable and not judge_name == "mean_listener": 303 | self.idtable[judge_name] = count 304 | count += 1 305 | if self.use_mean_listener: 306 | self.idtable["mean_listener"] = count 307 | count += 1 308 | torch.save(self.idtable, idtable_path) 309 | 
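# collate_fn sorts the batch by spectrogram length (longest first) and pads:
# "zero_padding" uses pad_sequence, while "repetitive" tiles each spectrogram
# until it matches the longest one in the batch.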
310 | def collate_fn(self, batch): 311 | sorted_batch = sorted(batch, key=lambda x: -x[0].shape[0]) 312 | bs = len(sorted_batch) # batch_size 313 | avg_scores = torch.FloatTensor([sorted_batch[i][1] for i in range(bs)]) 314 | mag_sgrams = [torch.from_numpy(sorted_batch[i][0]) for i in range(bs)] 315 | mag_sgrams_lengths = torch.from_numpy(np.array([mag_sgram.size(0) for mag_sgram in mag_sgrams])) 316 | 317 | if self.padding_mode == "zero_padding": 318 | mag_sgrams_padded = pad_sequence(mag_sgrams, batch_first=True) 319 | elif self.padding_mode == "repetitive": 320 | max_len = mag_sgrams_lengths[0] 321 | mag_sgrams_padded = [] 322 | for mag_sgram in mag_sgrams: 323 | this_len = mag_sgram.shape[0] 324 | dup_times = max_len // this_len 325 | remain = max_len - this_len * dup_times 326 | to_dup = [mag_sgram for t in range(dup_times)] 327 | to_dup.append(mag_sgram[:remain, :]) 328 | mag_sgrams_padded.append(torch.Tensor(np.concatenate(to_dup, axis = 0))) 329 | mag_sgrams_padded = torch.stack(mag_sgrams_padded, dim = 0) 330 | else: 331 | raise NotImplementedError 332 | 333 | if not self.split == "train": 334 | sys_names = [sorted_batch[i][2] for i in range(bs)] 335 | wav_names = [sorted_batch[i][3] for i in range(bs)] 336 | return mag_sgrams_padded, avg_scores, sys_names, wav_names 337 | else: 338 | scores = torch.FloatTensor([sorted_batch[i][2] for i in range(bs)]) 339 | judge_ids = torch.LongTensor([sorted_batch[i][3] for i in range(bs)]) 340 | return mag_sgrams_padded, mag_sgrams_lengths, avg_scores, scores, judge_ids 341 | 342 | def get_dataset(dataset_name, data_dir, split, idtable_path=None, padding_mode="zero_padding", use_mean_listener=False): 343 | if dataset_name in ["BVCC", "OOD"]: 344 | names = {"train":"TRAINSET", "valid":"DEVSET", "test":"TESTSET"} 345 | 346 | metadata = defaultdict(dict) 347 | metadata_with_avg = list() 348 | 349 | # read metadata 350 | with open(os.path.join(data_dir, "sets", names[split]), "r") as f: 351 | lines = f.read().splitlines() 352 | 353 | # line has format 354 | for line in lines: 355 | parts = line.split(",") 356 | sys_name = parts[0] 357 | wav_name = parts[1] 358 | score = int(parts[2]) 359 | judge_name = parts[4] 360 | metadata[sys_name + "|" + wav_name][judge_name] = score 361 | 362 | # calculate average score 363 | for _id, v in metadata.items(): 364 | sys_name, wav_name = _id.split("|") 365 | avg_score = np.mean(np.array(list(v.values()))) 366 | if split == "train": 367 | for judge_name, score in v.items(): 368 | metadata_with_avg.append([wav_name, judge_name, avg_score, score]) 369 | else: 370 | # in testing mode, additionally return system name and only average score 371 | metadata_with_avg.append([sys_name, wav_name, avg_score]) 372 | 373 | return BCVCCDataset(metadata_with_avg, data_dir, idtable_path, split, padding_mode, use_mean_listener) 374 | 375 | elif dataset_name == "vcc2018": 376 | names = {"train":"vcc2018_training_data.csv", "valid":"vcc2018_valid_data.csv", "test":"vcc2018_testing_data.csv"} 377 | dataframe = pd.read_csv(os.path.join(data_dir, f'{names[split]}'), index_col=False) 378 | wavs = [] 379 | filename = '' 380 | last = 0 381 | for i in range(len(dataframe)): 382 | if dataframe['WAV_PATH'][i] != filename: 383 | wav_name = os.path.join(data_dir, dataframe['WAV_PATH'][i]) 384 | wavs.append(wav_name) 385 | filename = dataframe['WAV_PATH'][i] 386 | last = 0 387 | else: 388 | last += 1 389 | wavs.append(last) 390 | return VCC18Dataset(wavs, dataframe, idtable_path, split, use_mean_listener) 391 | 392 | elif dataset_name in 
["asv19", "bc19"]: 393 | if split == "train": 394 | raise NotImplementedError 395 | 396 | names = {"train":"train_mos_list.txt", "valid":"val_mos_list.txt", "test":"test_mos_list.txt"} 397 | 398 | metadata = defaultdict(dict) 399 | metadata_with_avg = list() 400 | 401 | # read metadata 402 | with open(os.path.join(data_dir, "sets", names[split]), "r") as f: 403 | lines = f.read().splitlines() 404 | 405 | # line has format 406 | for line in lines: 407 | parts = line.split(",") 408 | wav_name = parts[0] 409 | sys_name = wav_name.split("-")[1] 410 | avg_score = float(parts[1]) 411 | if dataset_name == "asv19": 412 | avg_score = avg_score * 4.0 / 9.0 + 1.0 413 | metadata_with_avg.append([sys_name, wav_name, avg_score]) 414 | 415 | return ASVBC19Dataset(metadata_with_avg, data_dir, idtable_path, split, padding_mode, use_mean_listener) 416 | else: 417 | raise NotImplementedError 418 | 419 | def get_dataloader(dataset, batch_size, num_workers, shuffle=True): 420 | return DataLoader( 421 | dataset, 422 | batch_size=batch_size, 423 | num_workers=num_workers, 424 | shuffle=shuffle, 425 | collate_fn=dataset.collate_fn, 426 | ) 427 | -------------------------------------------------------------------------------- /exp/FT_LDNet-ML-bs20/config.yml: -------------------------------------------------------------------------------- 1 | activation: ReLU 2 | alpha: 0 3 | audio_input_dim: 257 4 | combine_mean_score: false 5 | decoder_dnn_dim: 64 6 | decoder_dropout_rate: 0.3 7 | decoder_rnn_dim: 128 8 | decoder_type: ffn 9 | encoder_bneck_configs: 10 | - - 16 11 | - 3 12 | - 16 13 | - 16 14 | - true 15 | - RE 16 | - 3 17 | - 1 18 | - - 16 19 | - 3 20 | - 72 21 | - 24 22 | - false 23 | - RE 24 | - 3 25 | - 1 26 | - - 24 27 | - 3 28 | - 88 29 | - 24 30 | - false 31 | - RE 32 | - 1 33 | - 1 34 | - - 24 35 | - 5 36 | - 96 37 | - 40 38 | - true 39 | - HS 40 | - 3 41 | - 1 42 | - - 40 43 | - 5 44 | - 240 45 | - 40 46 | - true 47 | - HS 48 | - 1 49 | - 1 50 | - - 40 51 | - 5 52 | - 240 53 | - 40 54 | - true 55 | - HS 56 | - 1 57 | - 1 58 | - - 40 59 | - 5 60 | - 120 61 | - 48 62 | - true 63 | - HS 64 | - 1 65 | - 1 66 | - - 48 67 | - 5 68 | - 144 69 | - 48 70 | - true 71 | - HS 72 | - 1 73 | - 1 74 | - - 48 75 | - 5 76 | - 288 77 | - 96 78 | - true 79 | - HS 80 | - 3 81 | - 1 82 | - - 96 83 | - 5 84 | - 576 85 | - 96 86 | - true 87 | - HS 88 | - 1 89 | - 1 90 | - - 96 91 | - 5 92 | - 576 93 | - 96 94 | - true 95 | - HS 96 | - 1 97 | - 1 98 | encoder_output_dim: 256 99 | encoder_type: mobilenetv3 100 | grad_clip: 1 101 | inference_mode: mean_listener 102 | judge_emb_dim: 128 103 | lambda: 1 104 | mask_loss: false 105 | model: LDNet 106 | num_judges: 286 107 | optimizer: 108 | alpha: 0.9 109 | eps: 0.0316 110 | lr: 0.001 111 | name: RMSprop 112 | output_type: scalar 113 | padding_mode: repetitive 114 | range_clipping: true 115 | scheduler: 116 | gamma: 0.97 117 | name: stepLR 118 | step_size: 500 119 | tau: 0.5 120 | test_batch_size: 1 121 | total_steps: 10000 122 | train_batch_size: 30 123 | use_mean_listener: true 124 | use_mean_net: false 125 | valid_steps: 100 126 | -------------------------------------------------------------------------------- /exp/FT_LDNet-ML-bs20/idtable.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unilight/LDNet/710ed6979b96868643b4a0c916b43b64d8a95a94/exp/FT_LDNet-ML-bs20/idtable.pkl -------------------------------------------------------------------------------- /exp/FT_LDNet-ML-bs20/model-6400.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/unilight/LDNet/710ed6979b96868643b4a0c916b43b64d8a95a94/exp/FT_LDNet-ML-bs20/model-6400.pt -------------------------------------------------------------------------------- /exp/Pretrained-LDNet-ML-2337/config.yml: -------------------------------------------------------------------------------- 1 | activation: ReLU 2 | alpha: 0 3 | audio_input_dim: 257 4 | combine_mean_score: false 5 | decoder_dnn_dim: 64 6 | decoder_dropout_rate: 0.3 7 | decoder_rnn_dim: 128 8 | decoder_type: ffn 9 | encoder_bneck_configs: 10 | - - 16 11 | - 3 12 | - 16 13 | - 16 14 | - true 15 | - RE 16 | - 3 17 | - 1 18 | - - 16 19 | - 3 20 | - 72 21 | - 24 22 | - false 23 | - RE 24 | - 3 25 | - 1 26 | - - 24 27 | - 3 28 | - 88 29 | - 24 30 | - false 31 | - RE 32 | - 1 33 | - 1 34 | - - 24 35 | - 5 36 | - 96 37 | - 40 38 | - true 39 | - HS 40 | - 3 41 | - 1 42 | - - 40 43 | - 5 44 | - 240 45 | - 40 46 | - true 47 | - HS 48 | - 1 49 | - 1 50 | - - 40 51 | - 5 52 | - 240 53 | - 40 54 | - true 55 | - HS 56 | - 1 57 | - 1 58 | - - 40 59 | - 5 60 | - 120 61 | - 48 62 | - true 63 | - HS 64 | - 1 65 | - 1 66 | - - 48 67 | - 5 68 | - 144 69 | - 48 70 | - true 71 | - HS 72 | - 1 73 | - 1 74 | - - 48 75 | - 5 76 | - 288 77 | - 96 78 | - true 79 | - HS 80 | - 3 81 | - 1 82 | - - 96 83 | - 5 84 | - 576 85 | - 96 86 | - true 87 | - HS 88 | - 1 89 | - 1 90 | - - 96 91 | - 5 92 | - 576 93 | - 96 94 | - true 95 | - HS 96 | - 1 97 | - 1 98 | encoder_output_dim: 256 99 | encoder_type: mobilenetv3 100 | grad_clip: 1 101 | inference_mode: mean_listener 102 | judge_emb_dim: 128 103 | lambda: 1 104 | log_steps: 500 105 | mask_loss: false 106 | mean_net_dnn_dim: 64 107 | mean_net_dropout_rate: 0.3 108 | mean_net_range_clipping: true 109 | mean_net_rnn_dim: 128 110 | mean_net_type: ffn 111 | model: LDNet 112 | num_judges: 289 113 | optimizer: 114 | alpha: 0.9 115 | eps: 0.0316 116 | lr: 0.001 117 | name: RMSprop 118 | output_type: scalar 119 | padding_mode: repetitive 120 | range_clipping: true 121 | scheduler: 122 | gamma: 0.97 123 | name: stepLR 124 | step_size: 1000 125 | tau: 0.5 126 | test_batch_size: 1 127 | total_steps: 100000 128 | train_batch_size: 30 129 | use_mean_listener: true 130 | use_mean_net: false 131 | valid_steps: 1000 132 | -------------------------------------------------------------------------------- /exp/Pretrained-LDNet-ML-2337/model-27000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unilight/LDNet/710ed6979b96868643b4a0c916b43b64d8a95a94/exp/Pretrained-LDNet-ML-2337/model-27000.pt -------------------------------------------------------------------------------- /imgs/60000_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unilight/LDNet/710ed6979b96868643b4a0c916b43b64d8a95a94/imgs/60000_distribution.png -------------------------------------------------------------------------------- /imgs/60000_sys_scatter_plot_utt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unilight/LDNet/710ed6979b96868643b4a0c916b43b64d8a95a94/imgs/60000_sys_scatter_plot_utt.png -------------------------------------------------------------------------------- /imgs/60000_utt_scatter_plot_utt.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/unilight/LDNet/710ed6979b96868643b4a0c916b43b64d8a95a94/imgs/60000_utt_scatter_plot_utt.png -------------------------------------------------------------------------------- /imgs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unilight/LDNet/710ed6979b96868643b4a0c916b43b64d8a95a94/imgs/results.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import fnmatch 4 | import os 5 | import yaml 6 | 7 | import numpy as np 8 | import scipy 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from dataset import get_dataloader, get_dataset 13 | from models.MBNet import MBNet 14 | from models.LDNet import LDNet 15 | 16 | import scipy.stats 17 | import matplotlib 18 | # Force matplotlib to not use any Xwindows backend. 19 | matplotlib.use('Agg') 20 | import matplotlib.pyplot as plt 21 | 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | MAX_FRAMES = 1250 24 | 25 | def find_files(root_dir, query="*.wav", include_root_dir=True): 26 | files = [] 27 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True): 28 | for filename in fnmatch.filter(filenames, query): 29 | files.append(os.path.join(root, filename)) 30 | if not include_root_dir: 31 | files = [file_.replace(root_dir + "/", "") for file_ in files] 32 | 33 | return files 34 | 35 | def save_results(ep, valid_result, test_result, result_path): 36 | if os.path.isfile(result_path): 37 | with open(result_path, "r", newline='') as csvfile: 38 | rows = list(csv.reader(csvfile)) 39 | data = {row[0]: row[1:] for row in rows} 40 | else: 41 | data = {} 42 | data[str(ep)] = valid_result + test_result 43 | rows = [[k]+v for k, v in data.items()] 44 | rows = sorted(rows, key=lambda x:int(x[0])) 45 | with open(result_path, "w", newline='') as csvfile: 46 | writer = csv.writer(csvfile) 47 | writer.writerows(rows) 48 | 49 | def inference(mode, model, ep, dataloader, systems, save_dir, name, dataset_name, return_posterior_scores=False): 50 | if return_posterior_scores: 51 | assert mode == "all_listeners" 52 | 53 | ep_scores = [] 54 | predict_mean_scores = [] 55 | post_scores = [] 56 | true_mean_scores = [] 57 | predict_sys_mean_scores = {system:[] for system in systems} 58 | true_sys_mean_scores = {system:[] for system in systems} 59 | 60 | for i, batch in enumerate(tqdm(dataloader)): 61 | mag_sgrams_padded, avg_scores, sys_names = batch 62 | mag_sgrams_padded = mag_sgrams_padded.to(device) 63 | 64 | # avoid OOM caused by long samples 65 | mag_sgrams_padded = mag_sgrams_padded[:, :MAX_FRAMES] 66 | 67 | # forward 68 | with torch.no_grad(): 69 | if mode == "mean_net": 70 | pred_mean_scores = model.only_mean_inference(spectrum = mag_sgrams_padded) 71 | elif mode == "all_listeners": 72 | pred_mean_scores, posterior_scores = model.average_inference(spectrum = mag_sgrams_padded, include_meanspk=return_posterior_scores) 73 | posterior_scores = posterior_scores.cpu().detach().numpy() 74 | post_scores.extend(posterior_scores.tolist()) 75 | elif mode == "mean_listener": 76 | pred_mean_scores = model.mean_listener_inference(spectrum = mag_sgrams_padded) 77 | else: 78 | raise NotImplementedError 79 | 80 | pred_mean_scores = pred_mean_scores.cpu().detach().numpy() 81 | avg_scores = avg_scores.cpu().detach().numpy() 82 | 
predict_mean_scores.extend(pred_mean_scores.tolist()) 83 | true_mean_scores.extend(avg_scores.tolist()) 84 | for j, sys_name in enumerate(sys_names): 85 | predict_sys_mean_scores[sys_name].append(pred_mean_scores[j]) 86 | true_sys_mean_scores[sys_name].append(avg_scores[j]) 87 | 88 | with torch.cuda.device(device): 89 | torch.cuda.empty_cache() 90 | 91 | predict_mean_scores = np.array(predict_mean_scores) 92 | true_mean_scores = np.array(true_mean_scores) 93 | predict_sys_mean_scores = np.array([np.mean(scores) for scores in predict_sys_mean_scores.values()]) 94 | true_sys_mean_scores = np.array([np.mean(scores) for scores in true_sys_mean_scores.values()]) 95 | 96 | # plot utterance-level histrogram 97 | plt.style.use('seaborn-deep') 98 | bins = np.linspace(1, 5, 40) 99 | plt.figure(2) 100 | plt.hist([true_mean_scores, predict_mean_scores], bins, label=['true_mos', 'predict_mos']) 101 | plt.legend(loc='upper right') 102 | plt.xlabel('MOS') 103 | plt.ylabel('number') 104 | plt.show() 105 | plt.savefig(os.path.join(save_dir, dataset_name + "_" + mode + "_" + name, f'{ep}_distribution.png'), dpi=150) 106 | plt.close() 107 | 108 | # utterance level scores 109 | MSE=np.mean((true_mean_scores-predict_mean_scores)**2) 110 | LCC=np.corrcoef(true_mean_scores, predict_mean_scores)[0][1] 111 | SRCC=scipy.stats.spearmanr(true_mean_scores, predict_mean_scores)[0] 112 | KTAU=scipy.stats.kendalltau(true_mean_scores, predict_mean_scores)[0] 113 | ep_scores += [MSE, LCC, SRCC, KTAU] 114 | print("[UTTERANCE] {} MSE: {:.3f}, LCC: {:.3f}, SRCC: {:.3f}, KTAU: {:.3f}".format(name, float(MSE), float(LCC), float(SRCC), float(KTAU))) 115 | 116 | # plotting utterance-level scatter plot 117 | M=np.max([np.max(predict_mean_scores),5]) 118 | plt.figure(3) 119 | plt.scatter(true_mean_scores, predict_mean_scores, s =15, color='b', marker='o', edgecolors='b', alpha=.20) 120 | plt.xlim([0.5,M]) 121 | plt.ylim([0.5,M]) 122 | plt.xlabel('True MOS') 123 | plt.ylabel('Predicted MOS') 124 | plt.title('Utt level LCC= {:.4f}, SRCC= {:.4f}, MSE= {:.4f}, KTAU= {:.4f}'.format(LCC, SRCC, MSE, KTAU)) 125 | plt.show() 126 | plt.savefig(os.path.join(save_dir, dataset_name + "_" + mode + "_" + name, f'{ep}_utt_scatter_plot_utt.png'), dpi=150) 127 | plt.close() 128 | 129 | # system level scores 130 | MSE=np.mean((true_sys_mean_scores-predict_sys_mean_scores)**2) 131 | LCC=np.corrcoef(true_sys_mean_scores, predict_sys_mean_scores)[0][1] 132 | SRCC=scipy.stats.spearmanr(true_sys_mean_scores, predict_sys_mean_scores)[0] 133 | KTAU=scipy.stats.kendalltau(true_sys_mean_scores, predict_sys_mean_scores)[0] 134 | ep_scores += [MSE, LCC, SRCC, KTAU] 135 | print("[SYSTEM] {} MSE: {:.3f}, LCC: {:.3f}, SRCC: {:.3f}, KTAU: {:.3f}".format(name, float(MSE), float(LCC), float(SRCC), float(KTAU))) 136 | 137 | # plotting utterance-level scatter plot 138 | M=np.max([np.max(predict_sys_mean_scores),5]) 139 | plt.figure(3) 140 | plt.scatter(true_sys_mean_scores, predict_sys_mean_scores, s =15, color='b', marker='o', edgecolors='b') 141 | plt.xlim([0.5,M]) 142 | plt.ylim([0.5,M]) 143 | plt.xlabel('True MOS') 144 | plt.ylabel('Predicted MOS') 145 | plt.title('Sys level LCC= {:.4f}, SRCC= {:.4f}, MSE= {:.4f}, KTAU= {:.4f}'.format(LCC, SRCC, MSE, KTAU)) 146 | plt.show() 147 | plt.savefig(os.path.join(save_dir, dataset_name + "_" + mode + "_" + name, f'{ep}_sys_scatter_plot_utt.png'), dpi=150) 148 | plt.close() 149 | 150 | if return_posterior_scores: 151 | post_scores = np.array(post_scores) 152 | return ep_scores, post_scores 153 | else: 154 | return ep_scores, 
None 155 | 156 | def main(): 157 | 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument("--dataset_name", type=str, default = "vcc2018") 160 | parser.add_argument("--data_dir", type=str, default = "data/vcc2018") 161 | parser.add_argument("--exp_dir", type=str, default="exp") 162 | parser.add_argument("--tag", type=str, required=True) 163 | parser.add_argument("--config", type=str, default=None) 164 | parser.add_argument("--ep", type=str, default=None, help="If not specified, evaluate all ckpts.") 165 | parser.add_argument("--start_ep", type=int, default=0, help="Epoch to start evaluation") 166 | parser.add_argument("--mode", type=str, required=True, choices=["mean_net", "all_listeners", "mean_listener"], 167 | help="Inference mode.") 168 | args = parser.parse_args() 169 | 170 | # define dir 171 | save_dir = os.path.join(args.exp_dir, args.tag) 172 | os.makedirs(os.path.join(save_dir, args.dataset_name + "_" + args.mode + "_valid"), exist_ok=True) 173 | os.makedirs(os.path.join(save_dir, args.dataset_name + "_" + args.mode + "_test"), exist_ok=True) 174 | 175 | # read config 176 | if args.config is not None: 177 | print("[Warning] You would probably use the existing config in the exp folder") 178 | config_path = args.config 179 | else: 180 | config_path = os.path.join(save_dir, "config.yml") 181 | with open(config_path, 'r') as file: 182 | config = yaml.load(file, Loader=yaml.FullLoader) 183 | 184 | # read dataset (batch size is always 1 to avoid padding) 185 | valid_set = get_dataset(args.dataset_name, args.data_dir, "valid") 186 | test_set = get_dataset(args.dataset_name, args.data_dir, "test") 187 | valid_loader = get_dataloader(valid_set, batch_size=1, num_workers=1, shuffle=False) 188 | test_loader = get_dataloader(test_set, batch_size=1, num_workers=1, shuffle=False) 189 | print("[Info] Number of validation samples: {}".format(len(valid_set))) 190 | print("[Info] Number of testing samples: {}".format(len(test_set))) 191 | 192 | # define model 193 | if config["model"] == "MBNet": 194 | model = MBNet(config).to(device) 195 | elif config["model"] == "LDNet": 196 | model = LDNet(config).to(device) 197 | else: 198 | raise NotImplementedError 199 | print("[Info] Model parameters: {}".format(model.get_num_params())) 200 | 201 | # either perform inference on one ep (specified by args.ep) or all ep in expdir 202 | if args.ep is not None: 203 | all_ckpts = [os.path.join(save_dir, f"model-{args.ep}.pt")] 204 | else: 205 | # get all ckpts 206 | all_ckpts = find_files(save_dir, "model-*.pt") 207 | 208 | # loop through all ckpts 209 | for model_path in all_ckpts: 210 | ep = os.path.basename(model_path).split(".")[0].split("-")[1] 211 | if int(ep) < args.start_ep: 212 | continue 213 | print("=================================================") 214 | print(f"[Info] Evaluating ep {ep}") 215 | model.load_state_dict(torch.load(model_path), strict=False) 216 | model.eval() 217 | 218 | # returning posterior score was for analyzing listener embedding, but not useful so just ignore it 219 | valid_result, valid_posterior_scores = inference(args.mode, model, ep, valid_loader, valid_set.systems, 220 | save_dir, "valid", args.dataset_name, return_posterior_scores=False) 221 | test_result, test_posterior_scores = inference(args.mode, model, ep, test_loader, test_set.systems, 222 | save_dir, "test", args.dataset_name, return_posterior_scores=False) 223 | 224 | save_results(ep, valid_result, test_result, os.path.join(save_dir, args.dataset_name + "_" + args.mode + ".csv")) 225 | 226 | 227 | if 
__name__ == "__main__": 228 | main() 229 | -------------------------------------------------------------------------------- /inference_for_voicemos.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import fnmatch 4 | import os 5 | import yaml 6 | 7 | import numpy as np 8 | import scipy 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from dataset import get_dataloader, get_dataset 13 | from models.MBNet import MBNet 14 | from models.LDNet import LDNet 15 | 16 | import scipy.stats 17 | import matplotlib 18 | # Force matplotlib to not use any Xwindows backend. 19 | matplotlib.use('Agg') 20 | import matplotlib.pyplot as plt 21 | 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | MAX_FRAMES = 1250 24 | 25 | def find_files(root_dir, query="*.wav", include_root_dir=True): 26 | files = [] 27 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True): 28 | for filename in fnmatch.filter(filenames, query): 29 | files.append(os.path.join(root, filename)) 30 | if not include_root_dir: 31 | files = [file_.replace(root_dir + "/", "") for file_ in files] 32 | 33 | return files 34 | 35 | def save_results(ep, result, result_path): 36 | if os.path.isfile(result_path): 37 | with open(result_path, "r", newline='') as csvfile: 38 | rows = list(csv.reader(csvfile)) 39 | data = {row[0]: row[1:] for row in rows} 40 | else: 41 | data = {} 42 | data[str(ep)] = result 43 | rows = [[k]+v for k, v in data.items()] 44 | rows = sorted(rows, key=lambda x:int(x[0])) 45 | with open(result_path, "w", newline='') as csvfile: 46 | writer = csv.writer(csvfile) 47 | writer.writerows(rows) 48 | 49 | def inference(mode, model, ep, dataloader, systems, save_dir, name, dataset_name, return_posterior_scores=False): 50 | if return_posterior_scores: 51 | assert mode == "all_listeners" 52 | 53 | ep_scores = [] 54 | predict_mean_scores = [] 55 | post_scores = [] 56 | true_mean_scores = [] 57 | predict_sys_mean_scores = {system:[] for system in systems} 58 | true_sys_mean_scores = {system:[] for system in systems} 59 | 60 | submission_scores = {} 61 | 62 | for i, batch in enumerate(tqdm(dataloader)): 63 | mag_sgrams_padded, avg_scores, sys_names, wav_names = batch 64 | mag_sgrams_padded = mag_sgrams_padded.to(device) 65 | 66 | # avoid OOM caused by long samples 67 | mag_sgrams_padded = mag_sgrams_padded[:, :MAX_FRAMES] 68 | 69 | # forward 70 | with torch.no_grad(): 71 | if mode == "mean_net": 72 | pred_mean_scores = model.only_mean_inference(spectrum = mag_sgrams_padded) 73 | elif mode == "all_listeners": 74 | pred_mean_scores, posterior_scores = model.average_inference(spectrum = mag_sgrams_padded, include_meanspk=return_posterior_scores) 75 | posterior_scores = posterior_scores.cpu().detach().numpy() 76 | post_scores.extend(posterior_scores.tolist()) 77 | elif mode == "mean_listener": 78 | pred_mean_scores = model.mean_listener_inference(spectrum = mag_sgrams_padded) 79 | else: 80 | raise NotImplementedError 81 | 82 | pred_mean_scores = pred_mean_scores.cpu().detach().numpy() 83 | avg_scores = avg_scores.cpu().detach().numpy() 84 | predict_mean_scores.extend(pred_mean_scores.tolist()) 85 | true_mean_scores.extend(avg_scores.tolist()) 86 | for j, sys_name in enumerate(sys_names): 87 | predict_sys_mean_scores[sys_name].append(pred_mean_scores[j]) 88 | true_sys_mean_scores[sys_name].append(avg_scores[j]) 89 | for wav_name, pred_mean_score in zip(wav_names, list(pred_mean_scores)): 90 | submission_scores[wav_name] = 
pred_mean_score 91 | 92 | with torch.cuda.device(device): 93 | torch.cuda.empty_cache() 94 | 95 | with open(os.path.join(save_dir, dataset_name + "_" + mode + "_" + name, f"{ep}_answer.txt"), "w") as f: 96 | for k, v in submission_scores.items(): 97 | f.write("{},{}\n".format(k, v)) 98 | 99 | predict_mean_scores = np.array(predict_mean_scores) 100 | true_mean_scores = np.array(true_mean_scores) 101 | predict_sys_mean_scores = np.array([np.mean(scores) for scores in predict_sys_mean_scores.values()]) 102 | true_sys_mean_scores = np.array([np.mean(scores) for scores in true_sys_mean_scores.values()]) 103 | 104 | # plot utterance-level histrogram 105 | plt.style.use('seaborn-deep') 106 | bins = np.linspace(1, 5, 40) 107 | plt.figure(2) 108 | plt.hist([true_mean_scores, predict_mean_scores], bins, label=['true_mos', 'predict_mos']) 109 | plt.legend(loc='upper right') 110 | plt.xlabel('MOS') 111 | plt.ylabel('number') 112 | plt.show() 113 | plt.savefig(os.path.join(save_dir, dataset_name + "_" + mode + "_" + name, f'{ep}_distribution.png'), dpi=150) 114 | plt.close() 115 | 116 | # utterance level scores 117 | MSE=np.mean((true_mean_scores-predict_mean_scores)**2) 118 | LCC=np.corrcoef(true_mean_scores, predict_mean_scores)[0][1] 119 | SRCC=scipy.stats.spearmanr(true_mean_scores, predict_mean_scores)[0] 120 | KTAU=scipy.stats.kendalltau(true_mean_scores, predict_mean_scores)[0] 121 | ep_scores += [MSE, LCC, SRCC, KTAU] 122 | print("[UTTERANCE] {} MSE: {:.3f}, LCC: {:.3f}, SRCC: {:.3f}, KTAU: {:.3f}".format(name, float(MSE), float(LCC), float(SRCC), float(KTAU))) 123 | 124 | # plotting utterance-level scatter plot 125 | M=np.max([np.max(predict_mean_scores),5]) 126 | plt.figure(3) 127 | plt.scatter(true_mean_scores, predict_mean_scores, s =15, color='b', marker='o', edgecolors='b', alpha=.20) 128 | plt.xlim([0.5,M]) 129 | plt.ylim([0.5,M]) 130 | plt.xlabel('True MOS') 131 | plt.ylabel('Predicted MOS') 132 | plt.title('Utt level LCC= {:.4f}, SRCC= {:.4f}, MSE= {:.4f}, KTAU= {:.4f}'.format(LCC, SRCC, MSE, KTAU)) 133 | plt.show() 134 | plt.savefig(os.path.join(save_dir, dataset_name + "_" + mode + "_" + name, f'{ep}_utt_scatter_plot_utt.png'), dpi=150) 135 | plt.close() 136 | 137 | # system level scores 138 | MSE=np.mean((true_sys_mean_scores-predict_sys_mean_scores)**2) 139 | LCC=np.corrcoef(true_sys_mean_scores, predict_sys_mean_scores)[0][1] 140 | SRCC=scipy.stats.spearmanr(true_sys_mean_scores, predict_sys_mean_scores)[0] 141 | KTAU=scipy.stats.kendalltau(true_sys_mean_scores, predict_sys_mean_scores)[0] 142 | ep_scores += [MSE, LCC, SRCC, KTAU] 143 | print("[SYSTEM] {} MSE: {:.3f}, LCC: {:.3f}, SRCC: {:.3f}, KTAU: {:.3f}".format(name, float(MSE), float(LCC), float(SRCC), float(KTAU))) 144 | 145 | # plotting utterance-level scatter plot 146 | M=np.max([np.max(predict_sys_mean_scores),5]) 147 | plt.figure(3) 148 | plt.scatter(true_sys_mean_scores, predict_sys_mean_scores, s =15, color='b', marker='o', edgecolors='b') 149 | plt.xlim([0.5,M]) 150 | plt.ylim([0.5,M]) 151 | plt.xlabel('True MOS') 152 | plt.ylabel('Predicted MOS') 153 | plt.title('Sys level LCC= {:.4f}, SRCC= {:.4f}, MSE= {:.4f}, KTAU= {:.4f}'.format(LCC, SRCC, MSE, KTAU)) 154 | plt.show() 155 | plt.savefig(os.path.join(save_dir, dataset_name + "_" + mode + "_" + name, f'{ep}_sys_scatter_plot_utt.png'), dpi=150) 156 | plt.close() 157 | 158 | if return_posterior_scores: 159 | post_scores = np.array(post_scores) 160 | return ep_scores, post_scores 161 | else: 162 | return ep_scores, None 163 | 164 | def main(): 165 | 166 | parser = 
argparse.ArgumentParser() 167 | parser.add_argument("--phase", type=str, default = "valid", choices = ["valid", "test"]) 168 | parser.add_argument("--dataset_name", type=str, default = "BVCC") 169 | parser.add_argument("--data_dir", type=str, default = "data/phase1-main/DATA") 170 | parser.add_argument("--exp_dir", type=str, default="exp") 171 | parser.add_argument("--tag", type=str, required=True) 172 | parser.add_argument("--config", type=str, default=None) 173 | parser.add_argument("--ep", type=str, default=None, help="If not specified, evaluate all ckpts.") 174 | parser.add_argument("--start_ep", type=int, default=0, help="Epoch to start evaluation") 175 | parser.add_argument("--mode", type=str, required=True, choices=["mean_net", "all_listeners", "mean_listener"], 176 | help="Inference mode.") 177 | args = parser.parse_args() 178 | 179 | # define dir 180 | save_dir = os.path.join(args.exp_dir, args.tag) 181 | os.makedirs(os.path.join(save_dir, args.dataset_name + "_" + args.mode + "_" + args.phase), exist_ok=True) 182 | 183 | # read config 184 | if args.config is not None: 185 | print("[Warning] You would probably use the existing config in the exp folder") 186 | config_path = args.config 187 | else: 188 | config_path = os.path.join(save_dir, "config.yml") 189 | with open(config_path, 'r') as file: 190 | config = yaml.load(file, Loader=yaml.FullLoader) 191 | 192 | # read dataset (batch size is always 1 to avoid padding) 193 | dataset = get_dataset(args.dataset_name, args.data_dir, args.phase) 194 | dataloader = get_dataloader(dataset, batch_size=1, num_workers=1, shuffle=False) 195 | print("[Info] Number of {} samples: {}".format(args.phase, len(dataset))) 196 | 197 | # define model 198 | if config["model"] == "MBNet": 199 | model = MBNet(config).to(device) 200 | elif config["model"] == "LDNet": 201 | model = LDNet(config).to(device) 202 | else: 203 | raise NotImplementedError 204 | print("[Info] Model parameters: {}".format(model.get_num_params())) 205 | 206 | # either perform inference on one ep (specified by args.ep) or all ep in expdir 207 | if args.ep is not None: 208 | all_ckpts = [os.path.join(save_dir, f"model-{args.ep}.pt")] 209 | else: 210 | # get all ckpts 211 | all_ckpts = find_files(save_dir, "model-*.pt") 212 | 213 | # loop through all ckpts 214 | for model_path in all_ckpts: 215 | ep = os.path.basename(model_path).split(".")[0].split("-")[1] 216 | if int(ep) < args.start_ep: 217 | continue 218 | print("=================================================") 219 | print(f"[Info] Evaluating ep {ep}") 220 | model.load_state_dict(torch.load(model_path), strict=False) 221 | model.eval() 222 | 223 | result, _ = inference(args.mode, model, ep, dataloader, dataset.systems, 224 | save_dir, args.phase, args.dataset_name, return_posterior_scores=False) 225 | 226 | save_results(ep, result, os.path.join(save_dir, args.dataset_name + "_" + args.mode + ".csv")) 227 | 228 | 229 | if __name__ == "__main__": 230 | main() 231 | -------------------------------------------------------------------------------- /models/LDNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from functools import partial 7 | from .modules import Projection, MobileNetV2ConvBlocks, MobileNetV3ConvBlocks, STRIDE 8 | 9 | 10 | class LDNet(nn.Module): 11 | def __init__(self, config): 12 | super(LDNet, self).__init__() 13 | self.config = config 14 | 15 | # This is not really used but 
we just keep it 16 | if config["combine_mean_score"]: 17 | assert config["output_type"] == "scalar" 18 | assert config["use_mean_net"] 19 | assert config["mean_net_output_type"] == config["output_type"] 20 | 21 | # define judge embedding 22 | self.num_judges = config["num_judges"] 23 | self.judge_embedding = nn.Embedding(num_embeddings = self.num_judges, embedding_dim = config["judge_emb_dim"]) 24 | 25 | # define activation 26 | if config["activation"] == "ReLU": 27 | activation = nn.ReLU 28 | else: 29 | raise NotImplementedError 30 | 31 | # define encoder 32 | if config["encoder_type"] == "mobilenetv2": 33 | self.encoder = MobileNetV2ConvBlocks(config["encoder_conv_first_ch"], 34 | config["encoder_conv_t"], 35 | config["encoder_conv_c"], 36 | config["encoder_conv_n"], 37 | config["encoder_conv_s"], 38 | config["encoder_output_dim"]) 39 | elif config["encoder_type"] == "mobilenetv3": 40 | self.encoder = MobileNetV3ConvBlocks(config["encoder_bneck_configs"], 41 | config["encoder_output_dim"]) 42 | else: 43 | raise NotImplementedError 44 | 45 | # define decoder 46 | if config["decoder_type"] == "ffn": 47 | decoder_dnn_input_dim = config["encoder_output_dim"] + config["judge_emb_dim"] 48 | elif config["decoder_type"] == "rnn": 49 | self.decoder_rnn = nn.LSTM(input_size = config["encoder_output_dim"] + config["judge_emb_dim"], 50 | hidden_size = config["decoder_rnn_dim"], 51 | num_layers = 1, batch_first = True, bidirectional = True) 52 | decoder_dnn_input_dim = config["decoder_rnn_dim"] * 2 53 | # there is always dnn 54 | self.decoder_dnn = Projection(decoder_dnn_input_dim, config["decoder_dnn_dim"], 55 | activation, config["output_type"], config["range_clipping"]) 56 | 57 | # define mean net 58 | if config["use_mean_net"]: 59 | if config["mean_net_type"] == "ffn": 60 | mean_net_dnn_input_dim = config["encoder_output_dim"] 61 | elif config["mean_net_type"] == "rnn": 62 | self.mean_net_rnn = nn.LSTM(input_size = config["encoder_output_dim"], 63 | hidden_size = config["mean_net_rnn_dim"], 64 | num_layers = 1, batch_first = True, bidirectional = True) 65 | mean_net_dnn_input_dim = config["mean_net_rnn_dim"] * 2 66 | # there is always dnn 67 | self.mean_net_dnn = Projection(mean_net_dnn_input_dim, config["mean_net_dnn_dim"], 68 | activation, config["output_type"], config["mean_net_range_clipping"]) 69 | 70 | def _get_output_dim(self, input_size, num_layers, stride=STRIDE): 71 | """ 72 | calculate the final ouptut width (dim) of a CNN using the following formula 73 | w_i = |_ (w_i-1 - 1) / stride + 1 _| 74 | """ 75 | output_dim = input_size 76 | for _ in range(num_layers): 77 | output_dim = math.floor((output_dim-1)/STRIDE+1) 78 | return output_dim 79 | 80 | def get_num_params(self): 81 | return sum(p.numel() for n, p in self.named_parameters()) 82 | 83 | def forward(self, spectrum, judge_id): 84 | """Calculate forward propagation. 
85 | Args: 86 | spectrum has shape (batch, time, dim) 87 | judge_id has shape (batch) 88 | """ 89 | batch, time, dim = spectrum.shape 90 | 91 | # get judge embedding 92 | judge_feat = self.judge_embedding(judge_id) # (batch, emb_dim) 93 | judge_feat = torch.stack([judge_feat for i in range(time)], dim = 1) #(batch, time, feat_dim) 94 | 95 | # encoder and inject judge embedding 96 | if self.config["encoder_type"] in ["mbnetstyle", "mobilenetv2", "mobilenetv3"]: 97 | spectrum = spectrum.unsqueeze(1) 98 | encoder_outputs = self.encoder(spectrum) # (batch, ch, time, feat_dim) 99 | encoder_outputs = encoder_outputs.view((batch, time, -1)) # (batch, time, feat_dim) 100 | decoder_inputs = torch.cat([encoder_outputs, judge_feat], dim = -1) # concat along feature dimension 101 | else: 102 | raise NotImplementedError 103 | 104 | # mean net 105 | if self.config["use_mean_net"]: 106 | mean_net_inputs = encoder_outputs 107 | if self.config["mean_net_type"] == "rnn": 108 | mean_net_outputs, (h, c) = self.mean_net_rnn(mean_net_inputs) 109 | else: 110 | mean_net_outputs = mean_net_inputs 111 | mean_net_outputs = self.mean_net_dnn(mean_net_outputs) # [batch, time, 1 (scalar) / 5 (categorical)] 112 | 113 | # decoder 114 | if self.config["decoder_type"] == "rnn": 115 | decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs) 116 | else: 117 | decoder_outputs = decoder_inputs 118 | decoder_outputs = self.decoder_dnn(decoder_outputs) # [batch, time, 1 (scalar) / 5 (categorical)] 119 | 120 | # define scores 121 | mean_score = mean_net_outputs if self.config["use_mean_net"] else None 122 | ld_score = decoder_outputs 123 | 124 | return mean_score, ld_score 125 | 126 | def mean_listener_inference(self, spectrum): 127 | assert self.config["use_mean_listener"] 128 | batch, time, dim = spectrum.shape 129 | device = spectrum.device 130 | 131 | # get judge embedding 132 | judge_id = (torch.ones(batch, dtype=torch.long) * self.num_judges - 1).to(device) # (bs) 133 | judge_feat = self.judge_embedding(judge_id) # (bs, emb_dim) 134 | judge_feat = torch.stack([judge_feat for i in range(time)], dim = 1) #(batch, time, feat_dim) 135 | 136 | # encoder and inject judge embedding 137 | if self.config["encoder_type"] in ["mobilenetv2", "mobilenetv3"]: 138 | spectrum = spectrum.unsqueeze(1) 139 | encoder_outputs = self.encoder(spectrum) # (batch, ch, time, feat_dim) 140 | encoder_outputs = encoder_outputs.view((batch, time, -1)) # (batch, time, feat_dim) 141 | decoder_inputs = torch.cat([encoder_outputs, judge_feat], dim = -1) # concat along feature dimension 142 | else: 143 | raise NotImplementedError 144 | 145 | # decoder 146 | if self.config["decoder_type"] == "rnn": 147 | decoder_outputs, (h, c) = self.decoder_rnn(decoder_inputs) 148 | else: 149 | decoder_outputs = decoder_inputs 150 | decoder_outputs = self.decoder_dnn(decoder_outputs) # [batch, time, 1 (scalar) / 5 (categorical)] 151 | 152 | # define scores 153 | decoder_outputs = decoder_outputs.squeeze(-1) 154 | scores = torch.mean(decoder_outputs, dim = 1) 155 | return scores 156 | 157 | def average_inference(self, spectrum, include_meanspk=False): 158 | bs, time, _ = spectrum.shape 159 | device = spectrum.device 160 | if self.config["use_mean_listener"] and not include_meanspk: 161 | actual_num_judges = self.num_judges - 1 162 | else: 163 | actual_num_judges = self.num_judges 164 | 165 | # all judge ids 166 | judge_id = torch.arange(actual_num_judges, dtype=torch.long).repeat(bs, 1).to(device) # (bs, nj) 167 | judge_feat = self.judge_embedding(judge_id) # (bs, nj, 
emb_dim) 168 | judge_feat = torch.stack([judge_feat for i in range(time)], dim = 2) # (bs, nj, time, feat_dim) 169 | 170 | # encoder and inject judge embedding 171 | if self.config["encoder_type"] in ["mobilenetv2", "mobilenetv3"]: 172 | spectrum = spectrum.unsqueeze(1) 173 | encoder_outputs = self.encoder(spectrum) # (batch, ch, time, feat_dim) 174 | encoder_outputs = encoder_outputs.view((bs, time, -1)) # (batch, time, feat_dim) 175 | decoder_inputs = torch.stack([encoder_outputs for i in range(actual_num_judges)], dim = 1) # (bs, nj, time, feat_dim) 176 | decoder_inputs = torch.cat([decoder_inputs, judge_feat], dim = -1) # concat along feature dimension 177 | else: 178 | raise NotImplementedError 179 | 180 | # mean net 181 | if self.config["use_mean_net"]: 182 | mean_net_inputs = encoder_outputs 183 | if self.config["mean_net_type"] == "rnn": 184 | mean_net_outputs, (h, c) = self.mean_net_rnn(mean_net_inputs) 185 | else: 186 | mean_net_outputs = mean_net_inputs 187 | mean_net_outputs = self.mean_net_dnn(mean_net_outputs) # [batch, time, 1 (scalar) / 5 (categorical)] 188 | 189 | # decoder 190 | if self.config["decoder_type"] == "rnn": 191 | decoder_outputs = decoder_inputs.view((bs * actual_num_judges, time, -1)) 192 | decoder_outputs, (h, c) = self.decoder_rnn(decoder_outputs) 193 | else: 194 | decoder_outputs = decoder_inputs 195 | decoder_outputs = self.decoder_dnn(decoder_outputs) # [batch, time, 1 (scalar) / 5 (categorical)] 196 | decoder_outputs = decoder_outputs.view((bs, actual_num_judges, time, -1)) # (bs, nj, time, 1/5) 197 | 198 | if self.config["output_type"] == "scalar": 199 | decoder_outputs = decoder_outputs.squeeze(-1) # (bs, nj, time) 200 | posterior_scores = torch.mean(decoder_outputs, dim=2) 201 | ld_scores = torch.mean(decoder_outputs, dim=1) # (bs, time) 202 | elif self.config["output_type"] == "categorical": 203 | ld_posterior = torch.nn.functional.softmax(decoder_outputs, dim=-1) 204 | ld_scores = torch.inner(ld_posterior, torch.Tensor([1,2,3,4,5]).to(device)) 205 | posterior_scores = torch.mean(ld_scores, dim=2) 206 | ld_scores = torch.mean(ld_scores, dim=1) # (bs, time) 207 | 208 | # define scores 209 | scores = torch.mean(ld_scores, dim = 1) 210 | return scores, posterior_scores -------------------------------------------------------------------------------- /models/MBNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .modules import Projection, MBNetConvBlocks 7 | 8 | 9 | class MBNet(nn.Module): 10 | def __init__(self, config): 11 | super(MBNet, self).__init__() 12 | self.config = config 13 | 14 | # sanity check for MBNet 15 | assert config["combine_mean_score"] 16 | assert config["output_type"] == "scalar" 17 | assert config["use_mean_net"] 18 | assert config["mean_net_output_type"] == config["output_type"] 19 | 20 | # define judge embedding 21 | self.num_judges = config["num_judges"] 22 | self.judge_embedding = nn.Embedding(num_embeddings = self.num_judges, embedding_dim = config["judge_emb_dim"]) 23 | 24 | # define activation 25 | if config["activation"] == "ReLU": 26 | activation = nn.ReLU 27 | else: 28 | raise NotImplementedError 29 | 30 | # define mean net 31 | if config["use_mean_net"]: 32 | if config["mean_net_input"] == "audio": 33 | self.mean_net_conv = MBNetConvBlocks(1, 34 | config["mean_net_conv_chs"], 35 | config["mean_net_dropout_rate"], 36 | activation 37 | ) 38 | else: 39 | raise NotImplementedError 
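# NOTE: only config["mean_net_input"] == "audio" (the raw input spectrogram) is supported as the
# mean-net front-end here; any other value falls through to the NotImplementedError above.
# The mean-net CNN defined above is followed by the BLSTM and projection head defined next.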
40 | 41 | self.mean_net_rnn = nn.LSTM(input_size = config["mean_net_conv_chs"][-1] * 4, 42 | hidden_size = config["mean_net_rnn_dim"], 43 | num_layers = 1, batch_first = True, bidirectional = True) 44 | self.mean_net_dnn = Projection(config["mean_net_rnn_dim"] * 2, config["mean_net_dnn_dim"], 45 | activation, config["mean_net_output_type"], config["mean_net_range_clipping"]) 46 | 47 | # define encoder and decoder (a.k.a. bias net) 48 | self.encoder = nn.Conv2d(in_channels = 1, out_channels = 16, kernel_size = (3,3), padding = (1,1), stride = (1,3)) 49 | self.decoder_conv = MBNetConvBlocks(17, 50 | config["decoder_conv_chs"], 51 | config["decoder_dropout_rate"], 52 | activation 53 | ) 54 | decoder_conv_output_dim = self._get_decoder_conv_output_dim(config["audio_input_dim"], len(config["decoder_conv_chs"])+1) # +1: encoder conv 55 | self.decoder_rnn = nn.LSTM(input_size = config["decoder_conv_chs"][-1] * decoder_conv_output_dim, 56 | hidden_size = config["decoder_rnn_dim"], 57 | num_layers = 1, batch_first = True, bidirectional = True) 58 | self.decoder_dnn = Projection(config["decoder_rnn_dim"] * 2, config["decoder_dnn_dim"], 59 | activation, config["output_type"], config["range_clipping"]) 60 | 61 | def _get_decoder_conv_output_dim(self, decoder_conv_input_size, decoder_conv_num_layers): 62 | # w_out = |_ (w_in - 1) / 3 + 1 _| 63 | output_dim = decoder_conv_input_size 64 | for _ in range(decoder_conv_num_layers): 65 | output_dim = math.floor((output_dim-1)/3+1) 66 | return output_dim 67 | 68 | def get_num_params(self): 69 | return sum(p.numel() for n, p in self.named_parameters()) 70 | 71 | def forward(self, spectrum, judge_id): 72 | """Calculate forward propagation. 73 | Args: 74 | spectrum has shape (batch, time, dim) 75 | judge_id has shape (batch) 76 | """ 77 | batch, time, _ = spectrum.shape 78 | 79 | # get judge embedding 80 | judge_feat = self.judge_embedding(judge_id) # (batch, emb_dim) 81 | judge_feat = torch.stack([judge_feat for i in range(time)], dim = 1) #(batch, time, feat_dim) 82 | 83 | # encoder and inject judge embedding 84 | spectrum = spectrum.unsqueeze(1) 85 | encoder_outputs = self.encoder(spectrum) 86 | judge_feat = judge_feat.unsqueeze(1) # (batch, 1, time, feat_dim) 87 | encoder_outputs = torch.cat([encoder_outputs, judge_feat], dim = 1) # concat along channel dimension, resulting in shape [batch, ch, t, d] 88 | 89 | # decoder 90 | decoder_outputs = self.decoder_conv(encoder_outputs) 91 | decoder_outputs = decoder_outputs.view((batch, time, -1)) 92 | decoder_outputs, _ = self.decoder_rnn(decoder_outputs) 93 | decoder_outputs = self.decoder_dnn(decoder_outputs) # [batch, time, 1 (scalar) / 5 (categorical)] 94 | 95 | # mean net 96 | if self.config["use_mean_net"]: 97 | if self.config["mean_net_input"] == "audio": 98 | mean_net_inputs = spectrum 99 | #mean_net_inputs = mean_net_inputs.unsqueeze(1) 100 | else: 101 | raise NotImplementedError 102 | mean_net_outputs = self.mean_net_conv(mean_net_inputs) 103 | mean_net_outputs = mean_net_outputs.view((batch, time, 512)) 104 | mean_net_outputs, _ = self.mean_net_rnn(mean_net_outputs) 105 | mean_net_outputs = self.mean_net_dnn(mean_net_outputs) # (batch, time, 1 (scalar) / 5 (categorical) 106 | 107 | # define scores 108 | mean_score = mean_net_outputs if self.config["use_mean_net"] else None 109 | ld_score = decoder_outputs + mean_net_outputs if self.config["combine_mean_score"] else decoder_outputs 110 | 111 | return mean_score, ld_score 112 | 113 | def only_mean_inference(self, spectrum): 114 | assert 
self.config["use_mean_net"] 115 | batch, time, _ = spectrum.shape 116 | spectrum = spectrum.unsqueeze(1) 117 | 118 | if self.config["mean_net_input"] == "audio": 119 | mean_net_inputs = spectrum 120 | else: 121 | raise NotImplementedError 122 | mean_net_outputs = self.mean_net_conv(mean_net_inputs) 123 | mean_net_outputs = mean_net_outputs.view((batch, time, 512)) 124 | mean_net_outputs, _ = self.mean_net_rnn(mean_net_outputs) 125 | mean_net_outputs = self.mean_net_dnn(mean_net_outputs) # (batch, seq, 1) 126 | mean_net_outputs = mean_net_outputs.squeeze(-1) 127 | 128 | mean_scores = torch.mean(mean_net_outputs, dim = -1) 129 | return mean_scores 130 | 131 | def average_inference_v1(self, spectrum, include_meanspk=False): 132 | bs, time, _ = spectrum.shape 133 | device = spectrum.device 134 | if self.config["use_mean_listener"] and not include_meanspk: 135 | actual_num_judges = self.num_judges - 1 136 | else: 137 | actual_num_judges = self.num_judges 138 | 139 | # all judge ids 140 | judge_id = torch.arange(actual_num_judges, dtype=torch.long).repeat(bs, 1).to(device) # (bs, nj) 141 | judge_feat = self.judge_embedding(judge_id) # (bs, nj, emb_dim) 142 | judge_feat = torch.stack([judge_feat for i in range(time)], dim = 2) # (bs, nj, time, feat_dim) 143 | 144 | # encoder and inject judge embedding 145 | if self.config["encoder_type"] == "simple": 146 | # encoder_outputs has shape [batch, ch, t, d] 147 | spectrum = spectrum.unsqueeze(1) 148 | encoder_outputs = self.encoder(spectrum) 149 | encoder_outputs = torch.stack([encoder_outputs for i in range(actual_num_judges)], dim = 1) # (bs, nj, ch, time, feat_dim) 150 | judge_feat = judge_feat.unsqueeze(2) # (batch, nj, 1, time, feat_dim) 151 | encoder_outputs = torch.cat([encoder_outputs, judge_feat], dim = 2) # concat along channel dimension 152 | elif self.config["encoder_type"] == "taco2": 153 | # encoder_outputs has shape [batch, t, d] 154 | # concat along feature dimension 155 | encoder_outputs = self.encoder(spectrum) # (bs, time, hidden_dim) 156 | encoder_outputs = torch.stack([encoder_outputs for i in range(actual_num_judges)], dim = 1) # (bs, nj, time, feat_dim) 157 | encoder_outputs = torch.cat([encoder_outputs, judge_feat], dim = 3) 158 | encoder_outputs = encoder_outputs.unsqueeze(2) # (batch, nj, 1, time, feat_dim) 159 | else: 160 | raise NotImplementedError 161 | 162 | # decoder conv 163 | encoder_outputs = encoder_outputs.view(-1, *encoder_outputs.shape[-3:]) 164 | decoder_outputs = self.decoder_conv(encoder_outputs) 165 | 166 | # decoder rnn 167 | decoder_outputs = decoder_outputs.view((bs * actual_num_judges, time, -1)) 168 | decoder_outputs, (h, c) = self.decoder_rnn(decoder_outputs) 169 | decoder_outputs = self.decoder_dnn(decoder_outputs) # (bs * nj, time, 1/5) 170 | decoder_outputs = decoder_outputs.view((bs, actual_num_judges, time, -1)) # (bs, nj, time, 1/5) 171 | if self.config["output_type"] == "scalar": 172 | decoder_outputs = decoder_outputs.squeeze(-1) 173 | posterior_scores = torch.mean(decoder_outputs, dim=2) 174 | ld_scores = torch.mean(decoder_outputs, dim=1) # (bs, time) 175 | elif self.config["output_type"] == "categorical": 176 | ld_posterior = torch.nn.functional.softmax(decoder_outputs, dim=-1) 177 | ld_scores = torch.inner(ld_posterior, torch.Tensor([1,2,3,4,5]).to(device)) 178 | ld_scores = torch.mean(ld_scores, dim=1) # (bs, time) 179 | 180 | # mean net 181 | if self.config["use_mean_net"]: 182 | if self.config["mean_net_input"] == "audio": 183 | mean_net_inputs = spectrum 184 | else: 185 | raise 
NotImplementedError 186 | #print("mean net input shape", mean_net_inputs.shape) 187 | mean_net_outputs = self.mean_net_conv(mean_net_inputs) 188 | mean_net_outputs = mean_net_outputs.view((bs, time, -1)) 189 | mean_net_outputs, (h, c) = self.mean_net_rnn(mean_net_outputs) 190 | mean_net_outputs = self.mean_net_dnn(mean_net_outputs) # (batch, seq, 1) 191 | mean_scores = mean_net_outputs.squeeze(-1) # (bs, time) 192 | 193 | # define scores 194 | if self.config["combine_mean_score"]: 195 | ld_scores += mean_scores 196 | scores = torch.mean(ld_scores, dim = 1) 197 | return scores, posterior_scores -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unilight/LDNet/710ed6979b96868643b4a0c916b43b64d8a95a94/models/__init__.py -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from utils import make_non_pad_mask 6 | 7 | 8 | class Loss(nn.Module): 9 | """ 10 | Implements the clipped MSE loss and categorical loss. 11 | """ 12 | def __init__(self, output_type, alpha, lamb, tau, masked_loss): 13 | super(Loss, self).__init__() 14 | self.output_type = output_type 15 | self.alpha = alpha 16 | self.lamb = lamb 17 | self.tau = tau 18 | self.masked_loss = masked_loss 19 | 20 | if output_type == "scalar": 21 | criterion = torch.nn.MSELoss 22 | elif output_type == "categorical": 23 | criterion = torch.nn.CrossEntropyLoss 24 | 25 | if self.alpha > 0: 26 | self.mean_net_criterion = criterion(reduction="none") 27 | self.main_criterion = criterion(reduction="none") 28 | 29 | def forward_criterion(self, y_hat, label, criterion_module, masks=None): 30 | # might investigate how to combine masked loss with categorical output 31 | if masks is not None: 32 | y_hat = y_hat.masked_select(masks) 33 | label = label.masked_select(masks) 34 | 35 | if self.output_type == "scalar": 36 | y_hat = y_hat.squeeze(-1) 37 | loss = criterion_module(y_hat, label) 38 | threshold = torch.abs(y_hat - label) > self.tau 39 | loss = torch.mean(threshold * loss) 40 | elif self.output_type == "categorical": 41 | # y_hat must have shape [..., 5] 42 | y_hat = y_hat.view((-1, 5)) 43 | label = torch.flatten(label).type(torch.long) 44 | loss = criterion_module(y_hat, label-1) 45 | loss = torch.mean(loss) 46 | return loss 47 | 48 | def forward(self, pred_mean, gt_mean, pred_score, gt_score, lens, device): 49 | """ 50 | Args: 51 | pred_mean, pred_score: [batch, time, 1/5] 52 | """ 53 | # make mask 54 | if self.masked_loss: 55 | masks = make_non_pad_mask(lens).to(device) 56 | else: 57 | masks = None 58 | 59 | # repeat for frame level loss 60 | time = pred_score.shape[1] 61 | gt_mean = gt_mean.unsqueeze(1).repeat(1, time) 62 | gt_score = gt_score.unsqueeze(1).repeat(1, time) 63 | 64 | main_loss = self.forward_criterion(pred_score, gt_score, self.main_criterion, masks) 65 | if self.alpha > 0: 66 | mean_net_loss = self.forward_criterion(pred_mean, gt_mean, self.mean_net_criterion, masks) 67 | return self.alpha * mean_net_loss + self.lamb * main_loss, mean_net_loss, main_loss 68 | else: 69 | return self.lamb * main_loss, None, main_loss 70 | 71 | ##################################################################################### 72 | 73 | # Categorical loss was not useful in 
initial experiments, but I keep it here for future reference 74 | 75 | class CategoricalLoss(nn.Module): 76 | def __init__(self, alpha, lamb): 77 | super(CategoricalLoss, self).__init__() 78 | self.alpha = alpha 79 | self.lamb = lamb 80 | 81 | if self.alpha > 0: 82 | self.mean_net_criterion = nn.CrossEntropyLoss() 83 | self.main_criterion = nn.CrossEntropyLoss() 84 | 85 | def ce(self, y_hat, label, criterion, masks=None): 86 | if masks is not None: 87 | y_hat = y_hat.masked_select(masks) 88 | label = label.masked_select(masks) 89 | ce = criterion(y_hat, label-1) 90 | return ce 91 | 92 | def forward(self, pred_mean, gt_mean, pred_score, gt_score, lens, device): 93 | # make mask 94 | if self.masked_loss: 95 | masks = make_non_pad_mask(lens).to(device) 96 | else: 97 | masks = None 98 | 99 | score_ce = self.ce(pred_score, gt_score, self.main_criterion, masks) 100 | if self.alpha > 0: 101 | mean_ce = self.ce(pred_mean, gt_mean, self.mean_net_criterion, masks) 102 | return self.alpha * mean_ce + self.lamb * score_ce, mean_ce, score_ce 103 | else: 104 | return self.lamb * score_ce, None, score_ce 105 | -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import Tensor 4 | from typing import Callable, Any, Optional, List 5 | 6 | 7 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 8 | 9 | 10 | model_urls = { 11 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 12 | } 13 | 14 | 15 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: 16 | """ 17 | This function is taken from the original tf repo. 18 | It ensures that all layers have a channel number that is divisible by 8 19 | It can be seen here: 20 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 21 | """ 22 | if min_value is None: 23 | min_value = divisor 24 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 25 | # Make sure that round down does not go down by more than 10%. 
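# Worked example of the guard below: with v=11, divisor=8, rounding to the nearest
# multiple gives new_v=8, but 8 < 0.9 * 11, so new_v is bumped up by one divisor to 16.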
26 | if new_v < 0.9 * v: 27 | new_v += divisor 28 | return new_v 29 | 30 | 31 | class ConvBNActivation(nn.Sequential): 32 | def __init__( 33 | self, 34 | in_planes: int, 35 | out_planes: int, 36 | kernel_size: int = 3, 37 | stride: int = 1, 38 | groups: int = 1, 39 | norm_layer: Optional[Callable[..., nn.Module]] = None, 40 | activation_layer: Optional[Callable[..., nn.Module]] = None, 41 | dilation: int = 1, 42 | ) -> None: 43 | padding = (kernel_size - 1) // 2 * dilation 44 | if norm_layer is None: 45 | norm_layer = nn.BatchNorm2d 46 | if activation_layer is None: 47 | activation_layer = nn.ReLU6 48 | super().__init__( 49 | # NOTE(unilight): stride only operates on the last axis 50 | nn.Conv2d(in_planes, out_planes, kernel_size, (1, stride), padding, dilation=dilation, groups=groups, 51 | bias=False), 52 | norm_layer(out_planes), 53 | activation_layer(inplace=True) 54 | ) 55 | self.out_channels = out_planes 56 | 57 | 58 | # necessary for backwards compatibility 59 | ConvBNReLU = ConvBNActivation 60 | 61 | 62 | class InvertedResidual(nn.Module): 63 | def __init__( 64 | self, 65 | inp: int, 66 | oup: int, 67 | stride: int, 68 | expand_ratio: int, 69 | norm_layer: Optional[Callable[..., nn.Module]] = None 70 | ) -> None: 71 | super(InvertedResidual, self).__init__() 72 | self.stride = stride 73 | assert stride in [1, 2, 3] 74 | 75 | if norm_layer is None: 76 | norm_layer = nn.BatchNorm2d 77 | 78 | hidden_dim = int(round(inp * expand_ratio)) 79 | self.use_res_connect = self.stride == 1 and inp == oup 80 | 81 | layers: List[nn.Module] = [] 82 | if expand_ratio != 1: 83 | # pw 84 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 85 | layers.extend([ 86 | # dw 87 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 88 | # pw-linear 89 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 90 | norm_layer(oup), 91 | ]) 92 | self.conv = nn.Sequential(*layers) 93 | self.out_channels = oup 94 | self._is_cn = stride > 1 95 | 96 | def forward(self, x: Tensor) -> Tensor: 97 | if self.use_res_connect: 98 | return x + self.conv(x) 99 | else: 100 | return self.conv(x) 101 | 102 | 103 | class MobileNetV2(nn.Module): 104 | def __init__( 105 | self, 106 | num_classes: int = 1000, 107 | width_mult: float = 1.0, 108 | inverted_residual_setting: Optional[List[List[int]]] = None, 109 | round_nearest: int = 8, 110 | block: Optional[Callable[..., nn.Module]] = None, 111 | norm_layer: Optional[Callable[..., nn.Module]] = None 112 | ) -> None: 113 | """ 114 | MobileNet V2 main class 115 | 116 | Args: 117 | num_classes (int): Number of classes 118 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 119 | inverted_residual_setting: Network structure 120 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 121 | Set to 1 to turn off rounding 122 | block: Module specifying inverted residual building block for mobilenet 123 | norm_layer: Module specifying the normalization layer to use 124 | 125 | """ 126 | super(MobileNetV2, self).__init__() 127 | 128 | if block is None: 129 | block = InvertedResidual 130 | 131 | if norm_layer is None: 132 | norm_layer = nn.BatchNorm2d 133 | 134 | input_channel = 32 135 | last_channel = 1280 136 | 137 | if inverted_residual_setting is None: 138 | inverted_residual_setting = [ 139 | # t, c, n, s 140 | [1, 16, 1, 1], 141 | [6, 24, 2, 2], 142 | [6, 32, 3, 2], 143 | [6, 64, 4, 2], 144 | [6, 96, 3, 1], 145 | [6, 160, 3, 2], 146 | 
[6, 320, 1, 1], 147 | ] 148 | 149 | # only check the first element, assuming user knows t,c,n,s are required 150 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 151 | raise ValueError("inverted_residual_setting should be non-empty " 152 | "or a 4-element list, got {}".format(inverted_residual_setting)) 153 | 154 | # building first layer 155 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 156 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 157 | features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] 158 | # building inverted residual blocks 159 | for t, c, n, s in inverted_residual_setting: 160 | output_channel = _make_divisible(c * width_mult, round_nearest) 161 | for i in range(n): 162 | stride = s if i == 0 else 1 163 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 164 | input_channel = output_channel 165 | # building last several layers 166 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 167 | # make it nn.Sequential 168 | self.features = nn.Sequential(*features) 169 | 170 | # building classifier 171 | self.classifier = nn.Sequential( 172 | nn.Dropout(0.2), 173 | nn.Linear(self.last_channel, num_classes), 174 | ) 175 | 176 | # weight initialization 177 | for m in self.modules(): 178 | if isinstance(m, nn.Conv2d): 179 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 180 | if m.bias is not None: 181 | nn.init.zeros_(m.bias) 182 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 183 | nn.init.ones_(m.weight) 184 | nn.init.zeros_(m.bias) 185 | elif isinstance(m, nn.Linear): 186 | nn.init.normal_(m.weight, 0, 0.01) 187 | nn.init.zeros_(m.bias) 188 | 189 | def _forward_impl(self, x: Tensor) -> Tensor: 190 | # This exists since TorchScript doesn't support inheritance, so the superclass method 191 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 192 | x = self.features(x) 193 | # Cannot use "squeeze" as batch-size can be 1 194 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) 195 | x = torch.flatten(x, 1) 196 | x = self.classifier(x) 197 | return x 198 | 199 | def forward(self, x: Tensor) -> Tensor: 200 | return self._forward_impl(x) 201 | -------------------------------------------------------------------------------- /models/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from functools import partial 4 | from torch import nn, Tensor 5 | from torch.nn import functional as F 6 | from typing import Any, Callable, Dict, List, Optional, Sequence 7 | 8 | from .mobilenetv2 import _make_divisible, ConvBNActivation 9 | 10 | 11 | __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] 12 | 13 | 14 | model_urls = { 15 | "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", 16 | "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", 17 | } 18 | 19 | 20 | class SqueezeExcitation(nn.Module): 21 | # Implemented as described at Figure 4 of the MobileNetV3 paper 22 | def __init__(self, input_channels: int, squeeze_factor: int = 4): 23 | super().__init__() 24 | squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) 25 | self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.fc2 = 
nn.Conv2d(squeeze_channels, input_channels, 1) 28 | 29 | def _scale(self, input: Tensor, inplace: bool) -> Tensor: 30 | scale = F.adaptive_avg_pool2d(input, 1) 31 | scale = self.fc1(scale) 32 | scale = self.relu(scale) 33 | scale = self.fc2(scale) 34 | return F.hardsigmoid(scale, inplace=inplace) 35 | 36 | def forward(self, input: Tensor) -> Tensor: 37 | scale = self._scale(input, True) 38 | return scale * input 39 | 40 | 41 | class InvertedResidualConfig: 42 | # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper 43 | def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool, 44 | activation: str, stride: int, dilation: int, width_mult: float): 45 | self.input_channels = self.adjust_channels(input_channels, width_mult) 46 | self.kernel = kernel 47 | self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) 48 | self.out_channels = self.adjust_channels(out_channels, width_mult) 49 | self.use_se = use_se 50 | self.use_hs = activation == "HS" 51 | self.stride = stride 52 | self.dilation = dilation 53 | 54 | @staticmethod 55 | def adjust_channels(channels: int, width_mult: float): 56 | return _make_divisible(channels * width_mult, 8) 57 | 58 | 59 | class InvertedResidual(nn.Module): 60 | # Implemented as described at section 5 of MobileNetV3 paper 61 | def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module], 62 | se_layer: Callable[..., nn.Module] = SqueezeExcitation): 63 | super().__init__() 64 | if not (1 <= cnf.stride <= 3): 65 | raise ValueError('illegal stride value') 66 | 67 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels 68 | 69 | layers: List[nn.Module] = [] 70 | activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU 71 | 72 | # expand 73 | if cnf.expanded_channels != cnf.input_channels: 74 | layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, 75 | norm_layer=norm_layer, activation_layer=activation_layer)) 76 | 77 | # depthwise 78 | stride = 1 if cnf.dilation > 1 else cnf.stride 79 | layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, 80 | stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, 81 | norm_layer=norm_layer, activation_layer=activation_layer)) 82 | if cnf.use_se: 83 | layers.append(se_layer(cnf.expanded_channels)) 84 | 85 | # project 86 | layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, 87 | activation_layer=nn.Identity)) 88 | 89 | self.block = nn.Sequential(*layers) 90 | self.out_channels = cnf.out_channels 91 | self._is_cn = cnf.stride > 1 92 | 93 | def forward(self, input: Tensor) -> Tensor: 94 | result = self.block(input) 95 | if self.use_res_connect: 96 | result += input 97 | return result 98 | 99 | 100 | class MobileNetV3(nn.Module): 101 | 102 | def __init__( 103 | self, 104 | inverted_residual_setting: List[InvertedResidualConfig], 105 | last_channel: int, 106 | num_classes: int = 1000, 107 | block: Optional[Callable[..., nn.Module]] = None, 108 | norm_layer: Optional[Callable[..., nn.Module]] = None, 109 | **kwargs: Any 110 | ) -> None: 111 | """ 112 | MobileNet V3 main class 113 | 114 | Args: 115 | inverted_residual_setting (List[InvertedResidualConfig]): Network structure 116 | last_channel (int): The number of channels on the penultimate layer 117 | num_classes (int): Number of classes 118 | block (Optional[Callable[..., nn.Module]]): Module 
specifying inverted residual building block for mobilenet 119 | norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use 120 | """ 121 | super().__init__() 122 | 123 | if not inverted_residual_setting: 124 | raise ValueError("The inverted_residual_setting should not be empty") 125 | elif not (isinstance(inverted_residual_setting, Sequence) and 126 | all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])): 127 | raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") 128 | 129 | if block is None: 130 | block = InvertedResidual 131 | 132 | if norm_layer is None: 133 | norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) 134 | 135 | layers: List[nn.Module] = [] 136 | 137 | # building first layer 138 | firstconv_output_channels = inverted_residual_setting[0].input_channels 139 | layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, 140 | activation_layer=nn.Hardswish)) 141 | 142 | # building inverted residual blocks 143 | for cnf in inverted_residual_setting: 144 | layers.append(block(cnf, norm_layer)) 145 | 146 | # building last several layers 147 | lastconv_input_channels = inverted_residual_setting[-1].out_channels 148 | lastconv_output_channels = 6 * lastconv_input_channels 149 | layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, 150 | norm_layer=norm_layer, activation_layer=nn.Hardswish)) 151 | 152 | self.features = nn.Sequential(*layers) 153 | self.avgpool = nn.AdaptiveAvgPool2d(1) 154 | self.classifier = nn.Sequential( 155 | nn.Linear(lastconv_output_channels, last_channel), 156 | nn.Hardswish(inplace=True), 157 | nn.Dropout(p=0.2, inplace=True), 158 | nn.Linear(last_channel, num_classes), 159 | ) 160 | 161 | for m in self.modules(): 162 | if isinstance(m, nn.Conv2d): 163 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 164 | if m.bias is not None: 165 | nn.init.zeros_(m.bias) 166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 167 | nn.init.ones_(m.weight) 168 | nn.init.zeros_(m.bias) 169 | elif isinstance(m, nn.Linear): 170 | nn.init.normal_(m.weight, 0, 0.01) 171 | nn.init.zeros_(m.bias) 172 | 173 | def _forward_impl(self, x: Tensor) -> Tensor: 174 | x = self.features(x) 175 | 176 | x = self.avgpool(x) 177 | x = torch.flatten(x, 1) 178 | 179 | x = self.classifier(x) 180 | 181 | return x 182 | 183 | def forward(self, x: Tensor) -> Tensor: 184 | return self._forward_impl(x) 185 | 186 | 187 | def _mobilenet_v3_conf(arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, 188 | **kwargs: Any): 189 | reduce_divider = 2 if reduced_tail else 1 190 | dilation = 2 if dilated else 1 191 | 192 | bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) 193 | adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) 194 | 195 | if arch == "mobilenet_v3_large": 196 | inverted_residual_setting = [ 197 | bneck_conf(16, 3, 16, 16, False, "RE", 1, 1), 198 | bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1 199 | bneck_conf(24, 3, 72, 24, False, "RE", 1, 1), 200 | bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2 201 | bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), 202 | bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), 203 | bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3 204 | bneck_conf(80, 3, 200, 80, False, "HS", 1, 1), 205 | bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), 206 | bneck_conf(80, 3, 184, 80, False, 
"HS", 1, 1), 207 | bneck_conf(80, 3, 480, 112, True, "HS", 1, 1), 208 | bneck_conf(112, 3, 672, 112, True, "HS", 1, 1), 209 | bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4 210 | bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation), 211 | bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation), 212 | ] 213 | last_channel = adjust_channels(1280 // reduce_divider) # C5 214 | elif arch == "mobilenet_v3_small": 215 | inverted_residual_setting = [ 216 | bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1 217 | bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2 218 | bneck_conf(24, 3, 88, 24, False, "RE", 1, 1), 219 | bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3 220 | bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), 221 | bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), 222 | bneck_conf(40, 5, 120, 48, True, "HS", 1, 1), 223 | bneck_conf(48, 5, 144, 48, True, "HS", 1, 1), 224 | bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4 225 | bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation), 226 | bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation), 227 | ] 228 | last_channel = adjust_channels(1024 // reduce_divider) # C5 229 | else: 230 | raise ValueError("Unsupported model type {}".format(arch)) 231 | 232 | return inverted_residual_setting, last_channel 233 | 234 | 235 | def _mobilenet_v3_model( 236 | arch: str, 237 | inverted_residual_setting: List[InvertedResidualConfig], 238 | last_channel: int, 239 | pretrained: bool, 240 | progress: bool, 241 | **kwargs: Any 242 | ): 243 | model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) 244 | if pretrained: 245 | if model_urls.get(arch, None) is None: 246 | raise ValueError("No checkpoint is available for model type {}".format(arch)) 247 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 248 | model.load_state_dict(state_dict) 249 | return model 250 | 251 | 252 | def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: 253 | """ 254 | Constructs a large MobileNetV3 architecture from 255 | `"Searching for MobileNetV3" `_. 256 | 257 | Args: 258 | pretrained (bool): If True, returns a model pre-trained on ImageNet 259 | progress (bool): If True, displays a progress bar of the download to stderr 260 | """ 261 | arch = "mobilenet_v3_large" 262 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) 263 | return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 264 | 265 | 266 | def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: 267 | """ 268 | Constructs a small MobileNetV3 architecture from 269 | `"Searching for MobileNetV3" `_. 
270 | 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | progress (bool): If True, displays a progress bar of the download to stderr 274 | """ 275 | arch = "mobilenet_v3_small" 276 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) 277 | return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 278 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import List 3 | from torch import nn 4 | import torch 5 | 6 | from .mobilenetv2 import ConvBNActivation, InvertedResidual 7 | from .mobilenetv3 import InvertedResidual as InvertedResidualV3 8 | from .mobilenetv3 import InvertedResidualConfig 9 | 10 | STRIDE=3 11 | 12 | class Projection(nn.Module): 13 | def __init__(self, in_dim, hidden_dim, activation, output_type, range_clipping=False): 14 | super(Projection, self).__init__() 15 | self.output_type = output_type 16 | self.range_clipping = range_clipping 17 | if output_type == "scalar": 18 | output_dim = 1 19 | if range_clipping: 20 | self.proj = nn.Tanh() 21 | elif output_type == "categorical": 22 | output_dim = 5 23 | else: 24 | raise NotImplementedError("wrong output_type: {}".format(output_type)) 25 | 26 | self.net = nn.Sequential( 27 | nn.Linear(in_dim, hidden_dim), 28 | activation(), 29 | nn.Dropout(0.3), 30 | nn.Linear(hidden_dim, output_dim), 31 | ) 32 | 33 | def forward(self, x): 34 | output = self.net(x) 35 | 36 | # range clipping 37 | if self.output_type == "scalar" and self.range_clipping: 38 | return self.proj(output) * 2.0 + 3 39 | else: 40 | return output 41 | 42 | 43 | class MBNetConvBlocks(nn.Module): 44 | def __init__(self, in_channel, ch_list, dropout_rate, activation): 45 | super(MBNetConvBlocks, self).__init__() 46 | 47 | num_layers = len(ch_list) 48 | assert num_layers > 0 49 | self.convs = torch.nn.ModuleList() 50 | for layer in range(num_layers): 51 | in_ch = in_channel if layer == 0 else ch_list[layer-1] 52 | out_ch = ch_list[layer] 53 | self.convs += [ 54 | torch.nn.Sequential( 55 | nn.Conv2d(in_channels = in_ch, out_channels = out_ch, kernel_size = (3,3), padding = (1,1)), 56 | nn.Conv2d(in_channels = out_ch, out_channels = out_ch, kernel_size = (3,3), padding = (1,1)), 57 | nn.Conv2d(in_channels = out_ch, out_channels = out_ch, kernel_size = (3,3), padding = (1,1), stride=(1,3)), 58 | nn.Dropout(dropout_rate), 59 | nn.BatchNorm2d(out_ch), 60 | activation(), 61 | ) 62 | ] 63 | def forward(self, x): 64 | for i in range(len(self.convs)): 65 | x = self.convs[i](x) 66 | return x 67 | 68 | class MobileNetV2ConvBlocks(nn.Module): 69 | def __init__(self, first_ch, t_list, c_list, n_list, s_list, output_dim): 70 | super(MobileNetV2ConvBlocks, self).__init__() 71 | block = InvertedResidual 72 | norm_layer = nn.BatchNorm2d 73 | 74 | # first layer 75 | features: List[nn.Module] = [ConvBNActivation(1, first_ch, stride=STRIDE, norm_layer=norm_layer)] 76 | in_ch = first_ch 77 | # bottleneck layers 78 | for t, c, n, s in zip(t_list, c_list, n_list, s_list): 79 | out_ch = c 80 | for i in range(n): 81 | stride = s if i == 0 else 1 82 | features.append(block(in_ch, out_ch, stride, expand_ratio=t, norm_layer=norm_layer)) 83 | in_ch = out_ch 84 | # last layer 85 | features.append(ConvBNActivation(in_ch, output_dim, kernel_size=1, norm_layer=norm_layer)) 86 | self.features = 
nn.Sequential(*features) 87 | 88 | # weight initialization 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 92 | if m.bias is not None: 93 | nn.init.zeros_(m.bias) 94 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 95 | nn.init.ones_(m.weight) 96 | nn.init.zeros_(m.bias) 97 | elif isinstance(m, nn.Linear): 98 | nn.init.normal_(m.weight, 0, 0.01) 99 | nn.init.zeros_(m.bias) 100 | 101 | def forward(self, x): 102 | time = x.shape[2] 103 | x = self.features(x) 104 | x = nn.functional.adaptive_avg_pool2d(x, (time, 1)) 105 | x = x.squeeze(-1).transpose(1, 2) 106 | return x 107 | 108 | class MobileNetV3ConvBlocks(nn.Module): 109 | def __init__(self, bneck_confs, output_dim): 110 | super(MobileNetV3ConvBlocks, self).__init__() 111 | 112 | bneck_conf = partial(InvertedResidualConfig, width_mult=1) 113 | inverted_residual_setting = [bneck_conf(*b_conf) for b_conf in bneck_confs] 114 | 115 | block = InvertedResidualV3 116 | 117 | # Never tested if a different eps and momentum is needed 118 | #norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) 119 | norm_layer = nn.BatchNorm2d 120 | 121 | layers: List[nn.Module] = [] 122 | 123 | # building first layer 124 | firstconv_output_channels = inverted_residual_setting[0].input_channels 125 | layers.append(ConvBNActivation(1, firstconv_output_channels, kernel_size=3, stride=STRIDE, norm_layer=norm_layer, 126 | activation_layer=nn.Hardswish)) 127 | 128 | # building inverted residual blocks 129 | for cnf in inverted_residual_setting: 130 | layers.append(block(cnf, norm_layer)) 131 | 132 | # building last several layers 133 | lastconv_input_channels = inverted_residual_setting[-1].out_channels 134 | lastconv_output_channels = output_dim 135 | layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, 136 | norm_layer=norm_layer, activation_layer=nn.Hardswish)) 137 | self.features = nn.Sequential(*layers) 138 | 139 | for m in self.modules(): 140 | if isinstance(m, nn.Conv2d): 141 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 142 | if m.bias is not None: 143 | nn.init.zeros_(m.bias) 144 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 145 | nn.init.ones_(m.weight) 146 | nn.init.zeros_(m.bias) 147 | elif isinstance(m, nn.Linear): 148 | nn.init.normal_(m.weight, 0, 0.01) 149 | nn.init.zeros_(m.bias) 150 | 151 | def forward(self, x): 152 | time = x.shape[2] 153 | x = self.features(x) 154 | x = nn.functional.adaptive_avg_pool2d(x, (time, 1)) 155 | x = x.squeeze(-1).transpose(1, 2) 156 | return x -------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | 4 | from typing import Callable, Iterable, Tuple 5 | 6 | import torch 7 | from torch.optim import Optimizer, Adam, RMSprop 8 | 9 | # Reference: https://github.com/s3prl/s3prl/blob/master/s3prl/optimizers.py 10 | 11 | def get_optimizer(model, total_steps, optimizer_config): 12 | optimizer_config = copy.deepcopy(optimizer_config) 13 | optimizer_name = optimizer_config.pop('name') 14 | optimizer = eval(f'get_{optimizer_name}')( 15 | model, 16 | total_steps=total_steps, 17 | **optimizer_config 18 | ) 19 | return optimizer 20 | 21 | 22 | def get_grouped_parameters(model): 23 | named_params = model.named_parameters() 24 | 25 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 26 | grouped_parameters = [ 27 | {'params': [p for n, p 
in named_params if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 28 | {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, 29 | ] 30 | return grouped_parameters 31 | 32 | 33 | 34 | def get_Adam(model, lr=2e-4, **kwargs): 35 | params = model.parameters() 36 | return Adam(params, lr=lr, betas=(0.9, 0.999)) 37 | 38 | def get_RMSprop(model, lr=2e-4, eps=1e-08, alpha=0.99, **kwargs): 39 | params = model.parameters() 40 | return RMSprop(params, lr=lr, momentum=0.9, weight_decay=1e-5, eps=eps, alpha=alpha) 41 | 42 | 43 | def get_AdamW(model, lr=2e-4, **kwargs): 44 | params = model.parameters() 45 | optimizer = AdamW(params, lr=lr) 46 | return optimizer 47 | 48 | 49 | def get_TorchOptim(model, torch_optim_name, **kwargs): 50 | params = model.parameters() 51 | Opt_class = getattr(torch.optim, torch_optim_name) 52 | 53 | kwargs.pop('total_steps') 54 | optim = Opt_class(params, **kwargs) 55 | return optim 56 | 57 | 58 | class AdamW(Optimizer): 59 | """ 60 | Implements Adam algorithm with weight decay fix as introduced in 61 | `Decoupled Weight Decay Regularization <https://arxiv.org/abs/1711.05101>`__. 62 | Parameters: 63 | params (:obj:`Iterable[torch.nn.parameter.Parameter]`): 64 | Iterable of parameters to optimize or dictionaries defining parameter groups. 65 | lr (:obj:`float`, `optional`, defaults to 1e-3): 66 | The learning rate to use. 67 | betas (:obj:`Tuple[float,float]`, `optional`, defaults to (0.9, 0.999)): 68 | Adam's betas parameters (b1, b2). 69 | eps (:obj:`float`, `optional`, defaults to 1e-7): 70 | Adam's epsilon for numerical stability. 71 | weight_decay (:obj:`float`, `optional`, defaults to 0): 72 | Decoupled weight decay to apply. 73 | correct_bias (:obj:`bool`, `optional`, defaults to `True`): 74 | Whether or not to correct bias in Adam (for instance, in Bert TF repository they use :obj:`False`). 75 | """ 76 | 77 | def __init__( 78 | self, 79 | params: Iterable[torch.nn.parameter.Parameter], 80 | lr: float = 1e-3, 81 | betas: Tuple[float, float] = (0.9, 0.999), 82 | eps: float = 1e-7, 83 | weight_decay: float = 0.0, 84 | correct_bias: bool = True, 85 | ): 86 | if lr < 0.0: 87 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 88 | if not 0.0 <= betas[0] < 1.0: 89 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 90 | if not 0.0 <= betas[1] < 1.0: 91 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 92 | if not 0.0 <= eps: 93 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 94 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) 95 | super().__init__(params, defaults) 96 | 97 | def step(self, closure: Callable = None): 98 | """ 99 | Performs a single optimization step. 100 | Arguments: 101 | closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss. 
102 | """ 103 | loss = None 104 | if closure is not None: 105 | loss = closure() 106 | 107 | for group in self.param_groups: 108 | for p in group["params"]: 109 | if p.grad is None: 110 | continue 111 | grad = p.grad.data 112 | if grad.is_sparse: 113 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 114 | 115 | state = self.state[p] 116 | 117 | # State initialization 118 | if len(state) == 0: 119 | state["step"] = 0 120 | # Exponential moving average of gradient values 121 | state["exp_avg"] = torch.zeros_like(p.data) 122 | # Exponential moving average of squared gradient values 123 | state["exp_avg_sq"] = torch.zeros_like(p.data) 124 | 125 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 126 | beta1, beta2 = group["betas"] 127 | 128 | state["step"] += 1 129 | 130 | # Decay the first and second moment running average coefficient 131 | # In-place operations to update the averages at the same time 132 | exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) 133 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 134 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 135 | 136 | step_size = group["lr"] 137 | if group["correct_bias"]: # No bias correction for Bert 138 | bias_correction1 = 1.0 - beta1 ** state["step"] 139 | bias_correction2 = 1.0 - beta2 ** state["step"] 140 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 141 | 142 | p.data.addcdiv_(exp_avg, denom, value=-step_size) 143 | 144 | # Just adding the square of the weights to the loss function is *not* 145 | # the correct way of using L2 regularization/weight decay with Adam, 146 | # since that will interact with the m and v parameters in strange ways. 147 | # 148 | # Instead we want to decay the weights in a manner that doesn't interact 149 | # with the m/v parameters. This is equivalent to adding the square 150 | # of the weights to the loss with plain (non-momentum) SGD. 151 | # Add weight decay at the end (fixed version) 152 | if group["weight_decay"] > 0.0: 153 | p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) 154 | 155 | return loss 156 | 157 | def get_lr(self): 158 | lr = [] 159 | for group in self.param_groups: 160 | for p in group['params']: 161 | state = self.state[p] 162 | if len(state) == 0: 163 | pass 164 | else: 165 | lr.append(group['lr']) 166 | return lr -------------------------------------------------------------------------------- /pack_for_voicemos.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | zip_name=$1 4 | main_answer=$2 5 | 6 | set -e 7 | 8 | # check arguments 9 | if [ $# -le 1 ]; then 10 | echo "Usage: $0 <zip_name> <main_answer> [<ood_answer>]" 11 | exit 1 12 | fi 13 | 14 | # make dir for output zip file 15 | mkdir -p $(dirname ${zip_name}) 16 | 17 | # make temp dir 18 | TEMP_DIR=`mktemp -d` 19 | 20 | # dump main track answer file 21 | cat ${main_answer} > ${TEMP_DIR}/answer.txt 22 | 23 | echo "Main track answer file: ${main_answer}" 24 | 25 | # dump optional OOD track answer file 26 | if [ $# -ge 3 ]; then 27 | ood_answer=$3 28 | cat ${ood_answer} >> ${TEMP_DIR}/answer.txt 29 | echo "OOD track answer file: ${ood_answer}" 30 | fi 31 | echo "=== packing ..." 
32 | 33 | # compress in zip format 34 | zip -j ${zip_name} ${TEMP_DIR}/answer.txt 35 | echo "=== packing done" 36 | echo "Zipped file: ${zip_name}" 37 | 38 | # remove temp dir 39 | rm -rf ${TEMP_DIR} -------------------------------------------------------------------------------- /schedulers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from torch.optim.lr_scheduler import MultiStepLR, StepLR 4 | 5 | # Reference: https://github.com/s3prl/s3prl/blob/master/s3prl/schedulers.py 6 | 7 | def get_scheduler(optimizer, total_steps, scheduler_config): 8 | scheduler_config = copy.deepcopy(scheduler_config) 9 | scheduler_name = scheduler_config.pop('name') 10 | scheduler = eval(f'get_{scheduler_name}')( 11 | optimizer, 12 | num_training_steps=total_steps, 13 | **scheduler_config 14 | ) 15 | return scheduler 16 | 17 | def get_multistep(optimizer, num_training_steps, milestones, gamma): 18 | return MultiStepLR(optimizer, milestones, gamma) 19 | 20 | def get_stepLR(optimizer, num_training_steps, step_size, gamma): 21 | return StepLR(optimizer, step_size, gamma) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import yaml 5 | import warnings 6 | 7 | import numpy as np 8 | import scipy 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.utils.tensorboard import SummaryWriter 13 | from tqdm import tqdm 14 | 15 | from dataset import get_dataloader, get_dataset 16 | from models.MBNet import MBNet 17 | from models.LDNet import LDNet 18 | from models.loss import Loss 19 | from optimizers import get_optimizer 20 | from schedulers import get_scheduler 21 | from inference import save_results 22 | 23 | writer = SummaryWriter() 24 | warnings.filterwarnings("ignore") 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | def valid(mode, model, dataloader, systems, save_dir, steps, prefix): 28 | model.eval() 29 | 30 | predict_mean_scores = [] 31 | true_mean_scores = [] 32 | predict_sys_mean_scores = {system:[] for system in systems} 33 | true_sys_mean_scores = {system:[] for system in systems} 34 | 35 | for i, batch in enumerate(tqdm(dataloader, ncols=0, desc=prefix, unit=" step")): 36 | mag_sgrams_padded, avg_scores, sys_names, wav_names = batch 37 | mag_sgrams_padded = mag_sgrams_padded.to(device) 38 | 39 | # forward 40 | with torch.no_grad(): 41 | try: 42 | # actual inference 43 | if mode == "mean_net": 44 | pred_mean_scores = model.only_mean_inference(spectrum = mag_sgrams_padded) 45 | elif mode == "all_listeners": 46 | pred_mean_scores, _ = model.average_inference(spectrum = mag_sgrams_padded) 47 | elif mode == "mean_listener": 48 | pred_mean_scores = model.mean_listener_inference(spectrum = mag_sgrams_padded) 49 | else: 50 | raise NotImplementedError 51 | 52 | pred_mean_scores = pred_mean_scores.cpu().detach().numpy() 53 | predict_mean_scores.extend(pred_mean_scores.tolist()) 54 | true_mean_scores.extend(avg_scores.tolist()) 55 | for j, sys_name in enumerate(sys_names): 56 | predict_sys_mean_scores[sys_name].append(pred_mean_scores[j]) 57 | true_sys_mean_scores[sys_name].append(avg_scores[j]) 58 | except RuntimeError as e: 59 | if "CUDA out of memory" in str(e): 60 | # print(f"[Runner] - CUDA out of memory at step {global_step}") 61 | with torch.cuda.device(device): 62 | torch.cuda.empty_cache() 63 | 
continue 64 | else: 65 | raise 66 | 67 | predict_mean_scores = np.array(predict_mean_scores) 68 | true_mean_scores = np.array(true_mean_scores) 69 | predict_sys_mean_scores = np.array([np.mean(scores) for scores in predict_sys_mean_scores.values()]) 70 | true_sys_mean_scores = np.array([np.mean(scores) for scores in true_sys_mean_scores.values()]) 71 | 72 | utt_MSE=np.mean((true_mean_scores-predict_mean_scores)**2) 73 | utt_LCC=np.corrcoef(true_mean_scores, predict_mean_scores)[0][1] 74 | utt_SRCC=scipy.stats.spearmanr(true_mean_scores, predict_mean_scores)[0] 75 | utt_KTAU=scipy.stats.kendalltau(true_mean_scores, predict_mean_scores)[0] 76 | sys_MSE=np.mean((true_sys_mean_scores-predict_sys_mean_scores)**2) 77 | sys_LCC=np.corrcoef(true_sys_mean_scores, predict_sys_mean_scores)[0][1] 78 | sys_SRCC=scipy.stats.spearmanr(true_sys_mean_scores, predict_sys_mean_scores)[0] 79 | sys_KTAU=scipy.stats.kendalltau(true_sys_mean_scores, predict_sys_mean_scores)[0] 80 | 81 | print( 82 | f"\n[{prefix}][{steps}][UTT][ MSE = {utt_MSE:.4f} | LCC = {utt_LCC:.4f} | SRCC = {utt_SRCC:.4f} ] [SYS][ MSE = {sys_MSE:.4f} | LCC = {sys_LCC:.4f} | SRCC = {sys_SRCC:.4f} ]\n" 83 | ) 84 | 85 | save_results(steps, [utt_MSE, utt_LCC, utt_SRCC, utt_KTAU], [sys_MSE, sys_LCC, sys_SRCC, sys_KTAU], os.path.join(save_dir, "training_" + mode + ".csv")) 86 | 87 | torch.save(model.state_dict(), os.path.join(save_dir, f"model-{steps}.pt")) 88 | model.train() 89 | 90 | def main(): 91 | 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--dataset_name", type=str, default = "vcc2018") 94 | parser.add_argument("--data_dir", type=str, default = "data/vcc2018") 95 | parser.add_argument("--exp_dir", type=str, default = "exp") 96 | parser.add_argument("--tag", type=str, required=True) 97 | parser.add_argument("--config", type=str, required=True) 98 | parser.add_argument("--update_freq", type=int, default=1, 99 | help="If GPU OOM, decrease the batch size and increase this.") 100 | parser.add_argument('--seed', default=1337, type=int) 101 | 102 | # finetuning related 103 | parser.add_argument("--pretrained_model_path", type=str, default=None) 104 | parser.add_argument("--fix_main_module", action="store_true") 105 | 106 | args = parser.parse_args() 107 | 108 | # Fix seed and make backends deterministic 109 | np.random.seed(args.seed) 110 | torch.manual_seed(args.seed) 111 | if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) 112 | torch.backends.cudnn.deterministic = True 113 | torch.backends.cudnn.benchmark = False # because we have dynamic input size 114 | 115 | # fix issue of too many opened files 116 | # https://github.com/pytorch/pytorch/issues/11201 117 | torch.multiprocessing.set_sharing_strategy('file_system') 118 | 119 | # read config 120 | with open(args.config, 'r') as file: 121 | config = yaml.load(file, Loader=yaml.FullLoader) 122 | print("[Info] LR: {}".format(config["optimizer"]["lr"])) 123 | print("[Info] alpha: {}".format(config["alpha"])) 124 | print("[Info] lambda: {}".format(config["lambda"])) 125 | 126 | # define and make dirs 127 | save_dir = os.path.join(args.exp_dir, args.tag) 128 | os.makedirs(save_dir, exist_ok=True) 129 | idtable_path = os.path.join(save_dir, "idtable.pkl") 130 | 131 | # define dataloders 132 | train_set = get_dataset(args.dataset_name, args.data_dir, "train", idtable_path, config["padding_mode"], config["use_mean_listener"]) 133 | valid_set = get_dataset(args.dataset_name, args.data_dir, "valid", idtable_path) 134 | train_loader = get_dataloader(train_set, 
batch_size=config["train_batch_size"], num_workers=6) 135 | valid_loader = get_dataloader(valid_set, batch_size=config["test_batch_size"], num_workers=1, shuffle=False) 136 | print("[Info] Number of training samples: {}".format(len(train_set))) 137 | print("[Info] Number of validation samples: {}".format(len(valid_set))) 138 | 139 | # get number of judges 140 | num_judges = train_set.num_judges 141 | config["num_judges"] = num_judges 142 | print("[Info] Number of judges: {}".format(num_judges)) 143 | print("[Info] Use mean listener: {}".format("True" if config["use_mean_listener"] else "False")) 144 | 145 | # define model 146 | if config["model"] == "MBNet": 147 | model = MBNet(config).to(device) 148 | elif config["model"] == "LDNet": 149 | model = LDNet(config).to(device) 150 | else: 151 | raise NotImplementedError 152 | print("[Info] Model parameters: {}".format(model.get_num_params())) 153 | criterion = Loss(config["output_type"], config["alpha"], config["lambda"], config["tau"], config["mask_loss"]) 154 | 155 | # finetune 156 | if args.pretrained_model_path is not None: 157 | print("[Info] Loading pretrained model from {}".format(args.pretrained_model_path)) 158 | pretrained_model_state_dict = torch.load(args.pretrained_model_path) 159 | pretrained_model_state_dict.pop("judge_embedding.weight", None) 160 | model.load_state_dict(pretrained_model_state_dict, strict=False) 161 | if args.fix_main_module: 162 | for mod, param in model.named_parameters(): 163 | if mod.startswith("judge_embedding"): 164 | print("[Info] Freezing {}".format(mod)) 165 | param.requires_grad = False 166 | 167 | # optimizer 168 | optimizer = get_optimizer(model, config["total_steps"], config["optimizer"]) 169 | optimizer.zero_grad() 170 | 171 | # scheduler 172 | scheduler = None 173 | if config.get('scheduler'): 174 | scheduler = get_scheduler(optimizer, config["total_steps"], config["scheduler"]) 175 | 176 | # set pbar 177 | pbar = tqdm(total=config["total_steps"], ncols=0, desc="Overall", unit=" step") 178 | 179 | # count accumulated gradients 180 | backward_steps = 0 181 | 182 | # write config 183 | with open(os.path.join(save_dir, 'config.yml'), 'w') as f: 184 | yaml.dump(config, f) 185 | 186 | # actual training loop 187 | model.train() 188 | while pbar.n < pbar.total: 189 | for i, batch in enumerate( 190 | tqdm(train_loader, ncols=0, desc="Train", unit=" step") 191 | ): 192 | try: 193 | if pbar.n >= pbar.total: 194 | break 195 | global_step = pbar.n + 1 196 | 197 | # fetch batch and put on device 198 | mag_sgrams_padded, mag_sgrams_lengths, avg_scores, scores, judge_ids = batch 199 | mag_sgrams_padded = mag_sgrams_padded.to(device) 200 | judge_ids = judge_ids.to(device) 201 | avg_scores = avg_scores.to(device) 202 | scores = scores.to(device) 203 | 204 | # forward 205 | # each has shape [batch, time, 1 (scalar) / 5 (categorical)] 206 | pred_mean_scores, pred_ld_scores = model(spectrum = mag_sgrams_padded, 207 | judge_id = judge_ids, 208 | ) 209 | 210 | # loss calculation 211 | loss, mean_loss, ld_loss = criterion(pred_mean_scores, avg_scores, pred_ld_scores, scores, mag_sgrams_lengths, device) 212 | 213 | (loss / args.update_freq).backward() 214 | 215 | if config["alpha"] > 0: 216 | pbar.set_postfix( 217 | { 218 | "loss": loss.item(), 219 | "mean_loss": mean_loss.item(), 220 | "LD_loss": ld_loss.item(), 221 | } 222 | ) 223 | else: 224 | pbar.set_postfix( 225 | { 226 | "loss": loss.item(), 227 | "LD_loss": ld_loss.item(), 228 | } 229 | ) 230 | 231 | except RuntimeError as e: 232 | if "CUDA out of memory" in 
str(e): 233 | print(f"[Runner] - CUDA out of memory at step {global_step}") 234 | with torch.cuda.device(device): 235 | torch.cuda.empty_cache() 236 | optimizer.zero_grad() 237 | continue 238 | else: 239 | raise 240 | 241 | # release GPU memory 242 | del loss 243 | 244 | # whether to accumulate gradient 245 | backward_steps += 1 246 | if backward_steps % args.update_freq > 0: 247 | continue 248 | 249 | # gradient clipping 250 | grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=config["grad_clip"]) 251 | 252 | # optimize 253 | if math.isnan(grad_norm): 254 | print(f"[Runner] - grad norm is NaN at step {global_step}") 255 | else: 256 | optimizer.step() 257 | optimizer.zero_grad() 258 | 259 | # adjust learning rate 260 | if scheduler: 261 | scheduler.step() 262 | 263 | # evaluate 264 | if global_step % config["valid_steps"] == 0: 265 | valid(config["inference_mode"], model, valid_loader, valid_set.systems, save_dir, global_step, "Valid") 266 | pbar.update(1) 267 | 268 | if __name__ == "__main__": 269 | main() 270 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ################################################################################ 4 | 5 | # The following functions are based on: 6 | # https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/nets_utils.py 7 | 8 | def make_pad_mask(lengths, xs=None, length_dim=-1): 9 | """Make mask tensor containing indices of padded part. 10 | 11 | Args: 12 | lengths (LongTensor or List): Batch of lengths (B,). 13 | xs (Tensor, optional): The reference tensor. 14 | If set, masks will be the same shape as this tensor. 15 | length_dim (int, optional): Dimension indicator of the above tensor. 16 | See the example. 17 | 18 | Returns: 19 | Tensor: Mask tensor containing indices of padded part. 20 | dtype=torch.uint8 in PyTorch 1.2- 21 | dtype=torch.bool in PyTorch 1.2+ (including 1.2) 22 | 23 | Examples: 24 | With only lengths. 25 | 26 | >>> lengths = [5, 3, 2] 27 | >>> make_pad_mask(lengths) 28 | masks = [[0, 0, 0, 0, 0], 29 | [0, 0, 0, 1, 1], 30 | [0, 0, 1, 1, 1]] 31 | 32 | With the reference tensor. 33 | 34 | >>> xs = torch.zeros((3, 2, 4)) 35 | >>> make_pad_mask(lengths, xs) 36 | tensor([[[0, 0, 0, 0], 37 | [0, 0, 0, 0]], 38 | [[0, 0, 0, 1], 39 | [0, 0, 0, 1]], 40 | [[0, 0, 1, 1], 41 | [0, 0, 1, 1]]], dtype=torch.uint8) 42 | >>> xs = torch.zeros((3, 2, 6)) 43 | >>> make_pad_mask(lengths, xs) 44 | tensor([[[0, 0, 0, 0, 0, 1], 45 | [0, 0, 0, 0, 0, 1]], 46 | [[0, 0, 0, 1, 1, 1], 47 | [0, 0, 0, 1, 1, 1]], 48 | [[0, 0, 1, 1, 1, 1], 49 | [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) 50 | 51 | With the reference tensor and dimension indicator. 
52 | 53 | >>> xs = torch.zeros((3, 6, 6)) 54 | >>> make_pad_mask(lengths, xs, 1) 55 | tensor([[[0, 0, 0, 0, 0, 0], 56 | [0, 0, 0, 0, 0, 0], 57 | [0, 0, 0, 0, 0, 0], 58 | [0, 0, 0, 0, 0, 0], 59 | [0, 0, 0, 0, 0, 0], 60 | [1, 1, 1, 1, 1, 1]], 61 | [[0, 0, 0, 0, 0, 0], 62 | [0, 0, 0, 0, 0, 0], 63 | [0, 0, 0, 0, 0, 0], 64 | [1, 1, 1, 1, 1, 1], 65 | [1, 1, 1, 1, 1, 1], 66 | [1, 1, 1, 1, 1, 1]], 67 | [[0, 0, 0, 0, 0, 0], 68 | [0, 0, 0, 0, 0, 0], 69 | [1, 1, 1, 1, 1, 1], 70 | [1, 1, 1, 1, 1, 1], 71 | [1, 1, 1, 1, 1, 1], 72 | [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) 73 | >>> make_pad_mask(lengths, xs, 2) 74 | tensor([[[0, 0, 0, 0, 0, 1], 75 | [0, 0, 0, 0, 0, 1], 76 | [0, 0, 0, 0, 0, 1], 77 | [0, 0, 0, 0, 0, 1], 78 | [0, 0, 0, 0, 0, 1], 79 | [0, 0, 0, 0, 0, 1]], 80 | [[0, 0, 0, 1, 1, 1], 81 | [0, 0, 0, 1, 1, 1], 82 | [0, 0, 0, 1, 1, 1], 83 | [0, 0, 0, 1, 1, 1], 84 | [0, 0, 0, 1, 1, 1], 85 | [0, 0, 0, 1, 1, 1]], 86 | [[0, 0, 1, 1, 1, 1], 87 | [0, 0, 1, 1, 1, 1], 88 | [0, 0, 1, 1, 1, 1], 89 | [0, 0, 1, 1, 1, 1], 90 | [0, 0, 1, 1, 1, 1], 91 | [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) 92 | 93 | """ 94 | if length_dim == 0: 95 | raise ValueError("length_dim cannot be 0: {}".format(length_dim)) 96 | 97 | if not isinstance(lengths, list): 98 | lengths = lengths.tolist() 99 | bs = int(len(lengths)) 100 | if xs is None: 101 | maxlen = int(max(lengths)) 102 | else: 103 | maxlen = xs.size(length_dim) 104 | 105 | seq_range = torch.arange(0, maxlen, dtype=torch.int64) 106 | seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) 107 | seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) 108 | mask = seq_range_expand >= seq_length_expand 109 | 110 | if xs is not None: 111 | assert xs.size(0) == bs, (xs.size(0), bs) 112 | 113 | if length_dim < 0: 114 | length_dim = xs.dim() + length_dim 115 | # ind = (:, None, ..., None, :, , None, ..., None) 116 | ind = tuple( 117 | slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) 118 | ) 119 | mask = mask[ind].expand_as(xs).to(xs.device) 120 | return mask 121 | 122 | 123 | def make_non_pad_mask(lengths, xs=None, length_dim=-1): 124 | """Make mask tensor containing indices of non-padded part. 125 | 126 | Args: 127 | lengths (LongTensor or List): Batch of lengths (B,). 128 | xs (Tensor, optional): The reference tensor. 129 | If set, masks will be the same shape as this tensor. 130 | length_dim (int, optional): Dimension indicator of the above tensor. 131 | See the example. 132 | 133 | Returns: 134 | ByteTensor: mask tensor containing indices of padded part. 135 | dtype=torch.uint8 in PyTorch 1.2- 136 | dtype=torch.bool in PyTorch 1.2+ (including 1.2) 137 | 138 | Examples: 139 | With only lengths. 140 | 141 | >>> lengths = [5, 3, 2] 142 | >>> make_non_pad_mask(lengths) 143 | masks = [[1, 1, 1, 1 ,1], 144 | [1, 1, 1, 0, 0], 145 | [1, 1, 0, 0, 0]] 146 | 147 | With the reference tensor. 148 | 149 | >>> xs = torch.zeros((3, 2, 4)) 150 | >>> make_non_pad_mask(lengths, xs) 151 | tensor([[[1, 1, 1, 1], 152 | [1, 1, 1, 1]], 153 | [[1, 1, 1, 0], 154 | [1, 1, 1, 0]], 155 | [[1, 1, 0, 0], 156 | [1, 1, 0, 0]]], dtype=torch.uint8) 157 | >>> xs = torch.zeros((3, 2, 6)) 158 | >>> make_non_pad_mask(lengths, xs) 159 | tensor([[[1, 1, 1, 1, 1, 0], 160 | [1, 1, 1, 1, 1, 0]], 161 | [[1, 1, 1, 0, 0, 0], 162 | [1, 1, 1, 0, 0, 0]], 163 | [[1, 1, 0, 0, 0, 0], 164 | [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) 165 | 166 | With the reference tensor and dimension indicator. 
167 | 168 | >>> xs = torch.zeros((3, 6, 6)) 169 | >>> make_non_pad_mask(lengths, xs, 1) 170 | tensor([[[1, 1, 1, 1, 1, 1], 171 | [1, 1, 1, 1, 1, 1], 172 | [1, 1, 1, 1, 1, 1], 173 | [1, 1, 1, 1, 1, 1], 174 | [1, 1, 1, 1, 1, 1], 175 | [0, 0, 0, 0, 0, 0]], 176 | [[1, 1, 1, 1, 1, 1], 177 | [1, 1, 1, 1, 1, 1], 178 | [1, 1, 1, 1, 1, 1], 179 | [0, 0, 0, 0, 0, 0], 180 | [0, 0, 0, 0, 0, 0], 181 | [0, 0, 0, 0, 0, 0]], 182 | [[1, 1, 1, 1, 1, 1], 183 | [1, 1, 1, 1, 1, 1], 184 | [0, 0, 0, 0, 0, 0], 185 | [0, 0, 0, 0, 0, 0], 186 | [0, 0, 0, 0, 0, 0], 187 | [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) 188 | >>> make_non_pad_mask(lengths, xs, 2) 189 | tensor([[[1, 1, 1, 1, 1, 0], 190 | [1, 1, 1, 1, 1, 0], 191 | [1, 1, 1, 1, 1, 0], 192 | [1, 1, 1, 1, 1, 0], 193 | [1, 1, 1, 1, 1, 0], 194 | [1, 1, 1, 1, 1, 0]], 195 | [[1, 1, 1, 0, 0, 0], 196 | [1, 1, 1, 0, 0, 0], 197 | [1, 1, 1, 0, 0, 0], 198 | [1, 1, 1, 0, 0, 0], 199 | [1, 1, 1, 0, 0, 0], 200 | [1, 1, 1, 0, 0, 0]], 201 | [[1, 1, 0, 0, 0, 0], 202 | [1, 1, 0, 0, 0, 0], 203 | [1, 1, 0, 0, 0, 0], 204 | [1, 1, 0, 0, 0, 0], 205 | [1, 1, 0, 0, 0, 0], 206 | [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) 207 | 208 | """ 209 | return ~make_pad_mask(lengths, xs, length_dim) --------------------------------------------------------------------------------
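
The masking helpers in utils.py above are how the repository deals with variable-length spectrogram batches; for example, train.py passes the per-utterance `mag_sgrams_lengths` into the criterion, where padded frames are presumably excluded when `mask_loss` is enabled. A minimal sketch of length-aware pooling built on `make_non_pad_mask` follows; the tensor shapes and the masked mean here are illustrative assumptions, not code taken from this repository.

```
import torch

from utils import make_non_pad_mask

# Frame-level scores for a padded batch: [batch, time, 1]
scores = torch.randn(3, 5, 1)
lengths = torch.tensor([5, 3, 2])  # true number of frames per utterance

# Mask with the same shape as `scores`: 1.0 on valid frames, 0.0 on padding
mask = make_non_pad_mask(lengths, scores, length_dim=1).float()

# Masked mean over the time axis -> one utterance-level score per batch item
utt_scores = (scores * mask).sum(dim=1) / mask.sum(dim=1)
print(utt_scores.shape)  # torch.Size([3, 1])
```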