├── .gitignore ├── LICENSE ├── README.md ├── fairseq_checkpoints ├── download_stacking_checkpoints.sh └── download_strong_checkpoints.sh ├── poetry.lock ├── pyproject.toml ├── remove_silenceWav.py ├── stacking ├── ensemble_multidomain_scripts │ ├── README.md │ ├── WavLM.py │ ├── calc_result.py │ ├── calc_testphase_result.py │ ├── collect_stage1_result.py │ ├── collect_stage1_testphase_result.py │ ├── collect_stage2_result.py │ ├── collect_stage2_testphase_result.py │ ├── convert_strong_learner_result.py │ ├── convert_strong_learner_testphase_result.py │ ├── data_util.py │ ├── external_mos_list.txt │ ├── extract_ssl_feature.py │ ├── gp_models.py │ ├── make_ensemble_dataset.py │ ├── make_ensemble_dataset_wotest.py │ ├── make_ensemble_testphase.py │ ├── models.py │ ├── modules.py │ ├── opt_stage1.py │ ├── opt_stage2.py │ ├── opt_stage3.py │ ├── pred_stage1.py │ ├── pred_stage1_ood.sh │ ├── pred_testphase_stage1.py │ ├── pred_testphase_stage1_main.sh │ ├── pred_testphase_stage1_ood.sh │ ├── pred_testphase_stage2-3_main.sh │ ├── pred_testphase_stage2-3_ood.sh │ ├── pred_testphase_stage2.py │ ├── pred_testphase_stage3.py │ ├── run_stage1.py │ ├── run_stage1.sh │ ├── run_stage2-3_main.sh │ ├── run_stage2-3_ood.sh │ ├── run_stage2.py │ ├── run_stage3.py │ ├── stage2-method │ │ ├── main-strong1-weak48.yaml │ │ └── ood-strong1-weak144.yaml │ └── unused │ │ ├── opt_stage1_all.sh │ │ ├── pred_stage1_exactgp.sh │ │ ├── run_stage1_exactgp.sh │ │ ├── run_stage1_other.sh │ │ └── run_stage2-end_opt.sh └── strong_learner_result │ ├── main1 │ ├── answer-main.csvfold_0 │ ├── answer-main.csvfold_1 │ ├── answer-main.csvfold_2 │ ├── answer-main.csvfold_3 │ ├── answer-main.csvfold_4 │ ├── answer-main.csvtest_0 │ ├── answer-main.csvtest_1 │ ├── answer-main.csvtest_2 │ ├── answer-main.csvtest_3 │ ├── answer-main.csvtest_4 │ ├── answer-main.csvval_0 │ ├── answer-main.csvval_1 │ ├── answer-main.csvval_2 │ ├── answer-main.csvval_3 │ └── answer-main.csvval_4 │ └── ood1 │ ├── answer-ood.csvfold_0 │ ├── answer-ood.csvfold_1 │ ├── answer-ood.csvfold_2 │ ├── answer-ood.csvtest_0 │ ├── answer-ood.csvtest_1 │ ├── answer-ood.csvtest_2 │ ├── answer-ood.csvval_0 │ ├── answer-ood.csvval_1 │ └── answer-ood.csvval_2 └── strong ├── README.md ├── TRAINSET_external.txt ├── WavLM.py ├── configs ├── cv.yaml ├── dataset │ ├── cv.yaml │ ├── default.yaml │ └── wo_augment.yaml ├── default.yaml ├── model │ ├── default.yaml │ └── wo_phoneme.yaml ├── optuna-main.yaml ├── optuna-ood.yaml └── train │ └── default.yaml ├── cross_validation.py ├── data ├── data_augment.py ├── dataset.py ├── lightning_module.py ├── loss_function.py ├── model.py ├── modules.py ├── param_tuning.py ├── predict.py ├── text └── symbols.py ├── train.py ├── transcribe_speech.py └── transcriptions_clustered.csv /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 110 | __pypackages__/ 111 | 112 | # Celery stuff 113 | celerybeat-schedule 114 | celerybeat.pid 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | 146 | # pytype static type analyzer 147 | .pytype/ 148 | 149 | # Cython debug symbols 150 | cython_debug/ 151 | 152 | # PyCharm 153 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 154 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 155 | # and can be added to the global gitignore or merged into this file. For a more nuclear 156 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 157 | #.idea/ 158 | 159 | # End of https://www.toptal.com/developers/gitignore/api/python 160 | *.pt 161 | *.pth 162 | outputs/ 163 | stacking/out/ 164 | stacking/data -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Saruwatari&Koyama laboratory, The University of Tokyo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UTMOS: UTokyo-SaruLab MOS Prediction System 2 | 3 | Official implementation of ["UTMOS: UTokyo-SaruLab System for VoiceMOS Challenge 2022"](https://arxiv.org/abs/2204.02152) accepted by INTERSPEECH 2022. 4 | 5 | >**Abstract:**
6 | We present the UTokyo-SaruLab mean opinion score (MOS) prediction system submitted to VoiceMOS Challenge 2022. The challenge is to predict the MOS values of speech samples collected from previous Blizzard Challenges and Voice Conversion Challenges for two tracks: a main track for in-domain prediction and an out-of-domain (OOD) track for which there is less labeled data from different listening tests. Our system is based on ensemble learning of strong and weak learners. Strong learners incorporate several improvements to the previous fine-tuning models of self-supervised learning (SSL) models, while weak learners use basic machine-learning methods to predict scores from SSL features. 7 | In the Challenge, our system had the highest score on several metrics for both the main and OOD tracks. In addition, we conducted ablation studies to investigate the effectiveness of our proposed methods. 8 | 9 | 🏆 Our system achieved the 1st places in 10/16 metrics at [the VoiceMOS Challenge 2022](https://voicemos-challenge-2022.github.io/)! 10 | 11 | Demo for UTMOS is available: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/sarulab-speech/UTMOS-demo) 12 | 13 | ## Quick Prediction 14 | You can simply use a pretrained UTMOS strong learner trained on the VoiceMOS Challenge 2022 Main Track Dataset. We support both single and batch processings in a [NISQA](https://github.com/gabrielmittag/NISQA)-like interface. 15 | 16 | Git clone the Hugging Face repo: 17 | ``` 18 | git clone https://huggingface.co/spaces/sarulab-speech/UTMOS-demo 19 | cd UTMOS-demo 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | To predict the MOS of a single wav file: 24 | ``` 25 | python predict.py --mode predict_file --inp_path /path/to/wav/file.wav --out_path /path/to/csv/file.csv 26 | ``` 27 | 28 | To predict the MOS of all .wav files in a folder use: 29 | ``` 30 | python predict.py --mode predict_dir --inp_dir /path/to/wav/dir/ --bs --out_path /path/to/csv/file.csv 31 | ``` 32 | 33 | ## How to use the whole functionality 34 | 35 | ### Enviornment setup 36 | 37 | 1. This repo uses poetry as the python envoirnmet manager. Install poetry following [this instruction](https://python-poetry.org/docs/#installation) first. 38 | 1. Install required python packages using `poetry install`. And enter the python enviornment with `poetry shell`. All following operations **requires** to be inside the poetry shell enviornment. 39 | 1. Second, download necessary fairseq checkpoint using [download_strong_checkpoints.sh](fairseq_checkpoints/download_strong_checkpoints.sh) for strong and [download_stacking_checkpoints.sh](fairseq_checkpoints/download_stacking_checkpoints.sh) for stacking. 40 | 1. Next, run the following command to exclude bad wav file from main track training set. 41 | The original data will be saved with `.bak` suffix. 42 | ```shell 43 | python remove_silenceWav.py --path_to_dataset path-to-dataset/phase1-main/ 44 | ``` 45 | 46 | ## Model training 47 | Our system predicts MOS with small errors by stacking of strong and weak learners. 48 | - To run training and inference with a single strong learner, see [strong/README.md](strong/README.md). 49 | - To run stacking, see [stacking/ensemble_multidomain_scripts/README.md](stacking/ensemble_multidomain_scripts/README.md). 50 | 51 | If you encounter any problems regarding running the code, feel free to submit an issue. The code is not fully tested. 52 | -------------------------------------------------------------------------------- /fairseq_checkpoints/download_stacking_checkpoints.sh: -------------------------------------------------------------------------------- 1 | gdown --fuzzy https://drive.google.com/file/d/19-C7SMQvEFAYLG5uc47NX_MY03JCbI4x/view?usp=sharing 2 | gdown --fuzzy https://drive.google.com/file/d/1PlbT_9_B4F9BsD_ija84sUTVw7almNX8/view?usp=sharing 3 | gdown --fuzzy https://drive.google.com/file/d/1rMu6PQ9vz3qPz4oIm72JDuIr5AHIbCOb/view?usp=sharing 4 | wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt -O wav2vec_small.pt 5 | wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt -O wav2vec_vox_new.pt 6 | wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/w2v_large_lv_fsh_swbd_cv.pt -O w2v_large_lv_fsh_swbd_cv.pt 7 | wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt -O xlsr_53_56k.pt 8 | wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt -O hubert_base_ls960.pt 9 | wget https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt -O hubert_large_ll60k.pt 10 | -------------------------------------------------------------------------------- /fairseq_checkpoints/download_strong_checkpoints.sh: -------------------------------------------------------------------------------- 1 | wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt -O wav2vec_small.pt -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "utmos22" 3 | version = "0.1.0" 4 | description = "UTokyo-Sarulab submission for VoiceMOS challenge 2022" 5 | authors = [ 6 | "Takaaki Saeki ", 7 | "Xin Detai ", 8 | "Wataru Nakata ", 9 | "Tomoki Koriyama ", 10 | ] 11 | 12 | [tool.poetry.dependencies] 13 | python = "^3.8" 14 | transformers = "^4.15.0" 15 | torch = "1.11" 16 | datasets = "^1.17.0" 17 | matplotlib = "^3.5.1" 18 | seaborn = "^0.11.2" 19 | scikit-learn = "^1.0.2" 20 | pytorch-lightning = "^1.5.8" 21 | phonemizer = "^3.0.1" 22 | SoundFile = "^0.10.3" 23 | python-Levenshtein = "^0.12.2" 24 | wandb = "^0.12.10" 25 | hydra-core = "^1.1.1" 26 | hydra-optuna-sweeper = "^1.1.2" 27 | augment = {git = "https://github.com/facebookresearch/WavAugment.git", rev = "main"} 28 | fairseq = {git = "https://github.com/sarulab-speech/fairseq.git", rev = "for_utmos"} 29 | pandas = "^1.4.1" 30 | gpytorch = "^1.6.0" 31 | lightgbm = "^3.3.2" 32 | gdown = "^4.4.0" 33 | 34 | 35 | [tool.poetry.dev-dependencies] 36 | 37 | [build-system] 38 | requires = ["poetry-core>=1.0.0"] 39 | build-backend = "poetry.core.masonry.api" 40 | -------------------------------------------------------------------------------- /remove_silenceWav.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import shutil 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--path_to_dataset', type=str, required=True) 8 | args = parser.parse_args() 9 | train_set_path = Path(args.path_to_dataset)/'DATA/sets/TRAINSET' 10 | train_mos_list_path = Path(args.path_to_dataset)/'DATA/sets/train_mos_list.txt' 11 | shutil.copyfile( train_mos_list_path, train_mos_list_path.with_suffix('.txt.bak')) 12 | shutil.copyfile( train_set_path, train_set_path.with_suffix('.bak')) 13 | with open(train_mos_list_path) as f: 14 | lines = f.readlines() 15 | new_lines =[] 16 | for line in lines: 17 | if 'sys4bafa-uttc2e86f6.wav' in line: 18 | continue 19 | new_lines.append(line) 20 | with open(train_mos_list_path, 'w') as f: 21 | f.writelines(new_lines) 22 | with open(train_set_path) as f: 23 | lines = f.readlines() 24 | new_lines =[] 25 | for line in lines: 26 | if 'sys4bafa-uttc2e86f6.wav' in line: 27 | continue 28 | new_lines.append(line) 29 | with open(train_set_path, 'w') as f: 30 | f.writelines(new_lines) 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/README.md: -------------------------------------------------------------------------------- 1 | # Stacking of strong and weak learners 2 | 3 | ## Data split 4 | Firstly link `data` to `../data`. 5 | Then run the following commands. 6 | ```shell 7 | python make_ensemble_dataset.py --datatrack phase1-main 8 | python make_ensemble_dataset.py --datatrack phase1-ood 9 | python make_ensemble_dataset_wotest.py --datatrack external 10 | python make_ensemble_testphase.py --datatrack phase1-main 11 | python make_ensemble_testphase.py --datatrack phase1-odd 12 | ``` 13 | 14 | ## Feature extraction with SSL model 15 | Place the ckpt file of the pretrained model to `../pretrained_model`. 16 | Then run the following command. 17 | ```shell 18 | python extract_ssl_feature.py 19 | ``` 20 | 21 | ## Converting results of strong learners for stacking 22 | Place the respective result files to `../strong_learner_result/main1` and `../strong_learner_result/ood1`. 23 | Then run the following commands. 24 | ```shell 25 | python convert_strong_learner_result.py phase1-main main1 26 | python convert_strong_learner_result.py phase1-ood ood1 27 | python convert_strong_learner_testphase_result.py testphase-main main1 28 | python convert_strong_learner_testphase_result.py testphase-ood ood1 29 | ``` 30 | 31 | ## Stage1 32 | For both main and OOD tracks, run the following command to perform stage1. 33 | ```shell 34 | ./run_stage1.sh 35 | ``` 36 | 37 | ## Stage2 and 3 for Main track 38 | Run the following commands. 39 | ```shell 40 | ./run_stage2-3_main.sh # Run stage 2 and 3 41 | ./pred_testphase_stage1_main.sh # Predict stage 1 42 | ./pred_testphase_stage2-3_main.sh # Predict stage 2 and 3 43 | ``` 44 | 45 | ## Stage 2 and 3 for OOD track 46 | Run the following commands. 47 | ```shell 48 | ./pred_stage1_ood.sh # Predict by cross-domain 49 | ./run_stage2-3_ood.sh # Run stage 2 and 3 50 | ./pred_testphase_stage1_ood.sh # Predict stage 1 51 | ./pred_testphase_stage2-3_ood.sh # Predict stage 2 and 3 52 | ``` 53 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/WavLM.py: -------------------------------------------------------------------------------- 1 | ../../strong/WavLM.py -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/calc_result.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | 5 | import itertools 6 | 7 | import yaml 8 | import numpy as np 9 | import pandas as pd 10 | import scipy 11 | import scipy.stats 12 | 13 | def get_arg(): 14 | import argparse 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('datatrack') 17 | parser.add_argument('feat_type', default='weak36') 18 | return parser.parse_args() 19 | 20 | K_CV = 5 21 | 22 | def evaluate_val(df_val): 23 | 24 | assert len(df_val) == len(df_val.index.unique()) 25 | 26 | mse = np.mean(np.square(df_val['true_mos'] - df_val['pred_mos'])) 27 | utt_srcc = scipy.stats.spearmanr(df_val['true_mos'], df_val['pred_mos'])[0] 28 | print('CV UTT MSE: {:f}'.format(mse)) 29 | print('CV UTT SRCC: {:f}'.format(utt_srcc)) 30 | 31 | # sys 32 | df_val['system_ID'] = df_val.index.str.extract(r'^(.+?)-').values 33 | 34 | df_val_sys = df_val.groupby('system_ID')['pred_mos'].mean() 35 | df_true_sys = df_val.groupby('system_ID')['true_mos'].mean() 36 | 37 | df_sys = pd.merge(df_val_sys, df_true_sys, on='system_ID', how='left') 38 | 39 | sys_mse = np.mean(np.square(df_sys['true_mos'] - df_sys['pred_mos'])) 40 | sys_srcc = scipy.stats.spearmanr(df_sys['true_mos'], df_sys['pred_mos'])[0] 41 | print('CV SYS MSE: {:f}'.format(sys_mse)) 42 | print('CV SYS SRCC: {:f}'.format(sys_srcc)) 43 | 44 | 45 | return {'cv_utt_mse': mse} 46 | 47 | def evaluate_test(df_test, pred_datatrack): 48 | 49 | utt_mse = np.mean(np.square(df_test['true_mos'] - df_test['pred_mos'])) 50 | utt_srcc = scipy.stats.spearmanr(df_test['true_mos'], df_test['pred_mos'])[0] 51 | print('TEST UTT MSE: {:f}'.format(utt_mse)) 52 | print('TEST UTT SRCC: {:f}'.format(utt_srcc)) 53 | 54 | df_test['system_ID'] = df_test.index.str.extract(r'^(.+?)-').values 55 | 56 | df_test_sys = df_test.groupby('system_ID')['pred_mos'].mean() 57 | 58 | df_true_sys = pd.read_csv(f'../data/{pred_datatrack}/DATA/mydata_system.csv') 59 | 60 | df_sys = pd.merge(df_test_sys, df_true_sys, on='system_ID', how='left').set_index('system_ID') 61 | 62 | sys_mse = np.mean(np.square(df_sys['mean'] - df_sys['pred_mos'])) 63 | sys_srcc = scipy.stats.spearmanr(df_sys['mean'], df_sys['pred_mos'])[0] 64 | print('TEST SYS MSE: {:f}'.format(sys_mse)) 65 | print('TEST SYS SRCC: {:f}'.format(sys_srcc)) 66 | 67 | return {'test_utt_mse': utt_mse, 68 | 'test_utt_srcc': utt_srcc, 69 | 'test_sys_mse': sys_mse, 70 | 'test_sys_srcc': sys_srcc} 71 | 72 | 73 | 74 | def calc_strong_learner_score(pred_datatrack, model_type, k_cv=K_CV): 75 | 76 | print('Stage1, {}, {}'.format(pred_datatrack, model_type)) 77 | 78 | result_dir = Path('../out/ensemble-multidomain/stage1') / \ 79 | pred_datatrack / f'{model_type}' 80 | print(result_dir) 81 | 82 | use_cv_result = True 83 | 84 | df_vals = [] 85 | df_tests = [] 86 | 87 | for i_cv in range(k_cv): 88 | if use_cv_result: 89 | df_vals.append(pd.read_csv(result_dir / str(i_cv) / f'val.csv', 90 | index_col=0)) 91 | df_tests.append(pd.read_csv(result_dir / str(i_cv) / f'test.csv', 92 | index_col=0)) 93 | else: 94 | pred_dir = result_dir / str(i_cv) / f'pred-{pred_datatrack}' 95 | df_vals.append(pd.read_csv(pred_dir / f'train.csv', 96 | index_col=0)) 97 | df_tests.append(pd.read_csv(pred_dir / f'test.csv', 98 | index_col=0)) 99 | 100 | if use_cv_result: 101 | df_val = pd.concat(df_vals) 102 | else: 103 | df_val = sum(df_vals) / len(df_vals) 104 | df_test = sum(df_tests) / len(df_tests) 105 | 106 | result = {'stage': 'stage1', 'train_datatrack': 'all', 107 | 'model_type': model_type, 'feat_type': 'nn'} 108 | 109 | result.update(evaluate_val(df_val)) 110 | result.update(evaluate_test(df_test, pred_datatrack)) 111 | 112 | return result 113 | 114 | def calc_stage1_score(pred_datatrack, train_datatrack, model_type, ssl_type): 115 | 116 | print('Stage1, {}, {}, {}, {}'.format(pred_datatrack, train_datatrack, model_type, ssl_type)) 117 | 118 | result_dir = Path('../out/ensemble-multidomain/stage1') / \ 119 | train_datatrack / f'{model_type}-{ssl_type}' 120 | print(result_dir) 121 | 122 | use_cv_result = (pred_datatrack == train_datatrack or train_datatrack.startswith('phase1-all')) 123 | 124 | df_vals = [] 125 | df_tests = [] 126 | 127 | for i_cv in range(K_CV): 128 | if use_cv_result: 129 | df_vals.append(pd.read_csv(result_dir / str(i_cv) / f'val.csv', 130 | index_col=0)) 131 | df_tests.append(pd.read_csv(result_dir / str(i_cv) / f'test.csv', 132 | index_col=0)) 133 | else: 134 | pred_dir = result_dir / str(i_cv) / f'pred-{pred_datatrack}' 135 | df_vals.append(pd.read_csv(pred_dir / f'train.csv', 136 | index_col=0)) 137 | df_tests.append(pd.read_csv(pred_dir / f'test.csv', 138 | index_col=0)) 139 | 140 | if use_cv_result: 141 | df_val = pd.concat(df_vals) 142 | else: 143 | df_val = sum(df_vals) / len(df_vals) 144 | df_test = sum(df_tests) / len(df_tests) 145 | 146 | result = {'stage': 'stage1', 'train_datatrack': train_datatrack, 147 | 'model_type': model_type, 'feat_type': ssl_type} 148 | 149 | result.update(evaluate_val(df_val)) 150 | result.update(evaluate_test(df_test, pred_datatrack)) 151 | 152 | return result 153 | 154 | def calc_stage_n_score(pred_datatrack, stage, model_type, feat_type): 155 | 156 | result_dir = Path('../out/ensemble-multidomain') / stage / \ 157 | pred_datatrack / f'{model_type}-{feat_type}' 158 | print(result_dir) 159 | 160 | df_vals = [] 161 | df_tests = [] 162 | 163 | for i_cv in range(K_CV): 164 | df_vals.append(pd.read_csv(result_dir / str(i_cv) / f'val.csv', 165 | index_col=0)) 166 | df_tests.append(pd.read_csv(result_dir / str(i_cv) / f'test.csv', 167 | index_col=0)) 168 | 169 | df_val = pd.concat(df_vals) 170 | df_test = sum(df_tests) / len(df_tests) 171 | 172 | result = {'stage': stage, 'train_datatrack': pred_datatrack, 173 | 'model_type': model_type, 'feat_type': feat_type} 174 | 175 | result.update(evaluate_val(df_val)) 176 | result.update(evaluate_test(df_test, pred_datatrack)) 177 | 178 | return result 179 | 180 | 181 | 182 | def main(): 183 | 184 | args = get_arg() 185 | 186 | feat_conf = yaml.safe_load(open('./stage2-method/{}.yaml'.format(args.feat_type))) 187 | print(feat_conf) 188 | 189 | data = [] 190 | 191 | # stage1 192 | for model_type in feat_conf['strong_learners']: 193 | k_cv = 3 if args.datatrack == 'phase1-ood' else K_CV 194 | data.append(calc_strong_learner_score(args.datatrack, model_type, k_cv=k_cv)) 195 | 196 | 197 | for train_datatrack, model_type, ssl_type in itertools.product( 198 | feat_conf['weak_learners']['datatracks'], 199 | feat_conf['weak_learners']['model_types'], 200 | feat_conf['weak_learners']['ssl_types']): 201 | data.append(calc_stage1_score(args.datatrack, train_datatrack, model_type, ssl_type)) 202 | 203 | # stage2 204 | for model_type in feat_conf['weak_learners']['model_types']: 205 | data.append(calc_stage_n_score(args.datatrack, 'stage2', model_type, args.feat_type)) 206 | 207 | # stage3 208 | for model_type in ['ridge']: 209 | data.append(calc_stage_n_score(args.datatrack, 'stage3', model_type, args.feat_type)) 210 | 211 | df = pd.DataFrame.from_dict(data) 212 | result_dir = Path('../out/ensemble-multidomain/result') 213 | os.makedirs(result_dir, exist_ok=True) 214 | df.to_csv(result_dir / f'{args.datatrack}-{args.feat_type}.csv') 215 | 216 | 217 | 218 | if __name__ == '__main__': 219 | main() 220 | 221 | 222 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/calc_testphase_result.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | 5 | import itertools 6 | 7 | import yaml 8 | import numpy as np 9 | import pandas as pd 10 | import scipy 11 | import scipy.stats 12 | 13 | def get_arg(): 14 | import argparse 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('datatrack') 17 | parser.add_argument('feat_type', default='weak36') 18 | return parser.parse_args() 19 | 20 | K_CV = 5 21 | 22 | 23 | def main(): 24 | 25 | args = get_arg() 26 | 27 | feat_conf = yaml.safe_load(open('./stage2-method/{}.yaml'.format(args.feat_type))) 28 | print(feat_conf) 29 | 30 | train_datatrack = 'phase1-main' if args.datatrack in ['testphase-main', 'valphase-main'] else 'phase1-ood' 31 | 32 | result_dir = Path('../out/ensemble-multidomain') / 'stage3' / \ 33 | train_datatrack / f'ridge-{args.feat_type}' 34 | 35 | df_tests = [] 36 | 37 | for i_cv in range(K_CV): 38 | df_tests.append(pd.read_csv(result_dir / str(i_cv) / f'pred-{args.datatrack}/test.csv', 39 | index_col=0)) 40 | 41 | df_test = sum(df_tests) / len(df_tests) 42 | 43 | answer_dir = Path('../out/ensemble-multidomain/answer') 44 | os.makedirs(answer_dir, exist_ok=True) 45 | df_test['pred_mos'].to_csv(answer_dir / f'{args.datatrack}-{args.feat_type}.csv', 46 | header=None) 47 | 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | 53 | 54 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/collect_stage1_result.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import itertools 5 | 6 | import pandas as pd 7 | import itertools 8 | import yaml 9 | 10 | K_CV = 5 11 | 12 | def get_arg(): 13 | import argparse 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('datatrack') 16 | parser.add_argument('feat_type', default='weak36') 17 | return parser.parse_args() 18 | 19 | 20 | def get_learner_data(stage1_result_dir, pred_datatrack, use_cv_result, use_upper_lower, column_tag, k_cv=K_CV): 21 | 22 | df_vals = [] 23 | df_tests = [] 24 | for i_cv in range(k_cv): 25 | if use_cv_result: 26 | df_vals.append(pd.read_csv(stage1_result_dir / str(i_cv) / f'val.csv', 27 | index_col=0)) 28 | df_tests.append(pd.read_csv(stage1_result_dir / str(i_cv) / f'test.csv', 29 | index_col=0)) 30 | else: 31 | pred_dir = stage1_result_dir / str(i_cv) / f'pred-{pred_datatrack}' 32 | df_vals.append(pd.read_csv(pred_dir / f'train.csv', 33 | index_col=0)) 34 | df_tests.append(pd.read_csv(pred_dir / f'test.csv', 35 | index_col=0)) 36 | 37 | if use_cv_result: 38 | df_train = pd.concat(df_vals) 39 | else: 40 | df_train = sum(df_vals) / len(df_vals) 41 | df_test = sum(df_tests) / len(df_tests) 42 | 43 | # empty column df 44 | df_train_new = df_train[[]].copy() 45 | df_test_new = df_test[[]].copy() 46 | 47 | if use_upper_lower: 48 | col_name = 'mean-' + column_tag 49 | df_train_new[col_name] = df_train['pred_mos'].copy() 50 | df_test_new[col_name] = df_test['pred_mos'].copy() 51 | 52 | col_name = 'lower-' + column_tag 53 | df_train_new[col_name] = df_train['lower_mos'].copy() 54 | df_test_new[col_name] = df_test['lower_mos'].copy() 55 | 56 | col_name = 'upper-' + column_tag 57 | df_train_new[col_name] = df_train['upper_mos'].copy() 58 | df_test_new[col_name] = df_test['upper_mos'].copy() 59 | 60 | else: 61 | col_name = 'pred-' + column_tag 62 | df_train_new[col_name] = df_train['pred_mos'].copy() 63 | df_test_new[col_name] = df_test['pred_mos'].copy() 64 | 65 | return df_train_new, df_test_new 66 | 67 | 68 | def main(): 69 | 70 | args = get_arg() 71 | 72 | stage2_data_dir = Path('../out/ensemble-multidomain/data-stage2') / args.datatrack / args.feat_type 73 | stage1_result_base_dir = Path('../out/ensemble-multidomain/stage1') 74 | 75 | feat_conf = yaml.safe_load(open('./stage2-method/{}.yaml'.format(args.feat_type))) 76 | print(feat_conf) 77 | 78 | df_train_list = [] 79 | df_test_list = [] 80 | 81 | for strong_learner in feat_conf['strong_learners']: 82 | 83 | stage1_result_dir = stage1_result_base_dir / args.datatrack / strong_learner 84 | 85 | column_tag = strong_learner 86 | k_cv = 3 if args.datatrack == 'phase1-ood' else K_CV 87 | df_train, df_test = get_learner_data(stage1_result_dir=stage1_result_dir, 88 | pred_datatrack=args.datatrack, 89 | use_cv_result=True, 90 | use_upper_lower=False, 91 | column_tag=strong_learner, 92 | k_cv=k_cv) 93 | 94 | df_train_list.append(df_train) 95 | df_test_list.append(df_test) 96 | 97 | for train_datatrack, model_type, ssl_type in itertools.product( 98 | feat_conf['weak_learners']['datatracks'], 99 | feat_conf['weak_learners']['model_types'], 100 | feat_conf['weak_learners']['ssl_types']): 101 | 102 | if model_type == 'autogp': 103 | if train_datatrack == 'phase1-main': 104 | model_type = 'svgp' 105 | else: 106 | model_type = 'exactgp' 107 | 108 | use_cv_result = (args.datatrack == train_datatrack or train_datatrack.startswith('phase1-all')) 109 | use_upper_lower = (model_type in ['svgp', 'exactgp']) 110 | 111 | stage1_result_dir = stage1_result_base_dir / train_datatrack / f'{model_type}-{ssl_type}' 112 | 113 | column_tag = f'{train_datatrack}---{model_type}---{ssl_type}' 114 | 115 | df_train, df_test = get_learner_data(stage1_result_dir=stage1_result_dir, 116 | pred_datatrack=args.datatrack, 117 | use_cv_result=use_cv_result, 118 | use_upper_lower=use_upper_lower, 119 | column_tag=column_tag) 120 | 121 | df_train_list.append(df_train) 122 | df_test_list.append(df_test) 123 | 124 | df_train_all = pd.concat(df_train_list, axis=1) 125 | df_test_all = pd.concat(df_test_list, axis=1) 126 | 127 | df_train_all.sort_index(inplace=True) 128 | df_test_all.sort_index(inplace=True) 129 | 130 | 131 | print('Columns: {}'.format(df_train_all.columns)) 132 | print('Train: {}'.format(df_train_all.shape)) 133 | print('Test: {}'.format(df_test_all.shape)) 134 | 135 | os.makedirs(stage2_data_dir, exist_ok=True) 136 | df_train_all.to_csv(stage2_data_dir / 'train-X.csv') 137 | df_test_all.to_csv(stage2_data_dir / 'test-X.csv') 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | 143 | 144 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/collect_stage1_testphase_result.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import itertools 5 | 6 | import pandas as pd 7 | import itertools 8 | import yaml 9 | 10 | K_CV = 5 11 | 12 | def get_arg(): 13 | import argparse 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('datatrack') 16 | parser.add_argument('feat_type', default='weak36') 17 | return parser.parse_args() 18 | 19 | 20 | def get_learner_data(stage1_result_dir, pred_datatrack, use_upper_lower, column_tag, k_cv=K_CV): 21 | 22 | df_tests = [] 23 | for i_cv in range(k_cv): 24 | pred_dir = stage1_result_dir / str(i_cv) / f'pred-{pred_datatrack}' 25 | df_tests.append(pd.read_csv(pred_dir / f'test.csv', 26 | index_col=0)) 27 | 28 | df_test = sum(df_tests) / len(df_tests) 29 | 30 | # empty column df 31 | df_test_new = df_test[[]].copy() 32 | 33 | if use_upper_lower: 34 | col_name = 'mean-' + column_tag 35 | df_test_new[col_name] = df_test['pred_mos'].copy() 36 | 37 | col_name = 'lower-' + column_tag 38 | df_test_new[col_name] = df_test['lower_mos'].copy() 39 | 40 | col_name = 'upper-' + column_tag 41 | df_test_new[col_name] = df_test['upper_mos'].copy() 42 | 43 | else: 44 | col_name = 'pred-' + column_tag 45 | df_test_new[col_name] = df_test['pred_mos'].copy() 46 | 47 | return df_test_new 48 | 49 | 50 | def main(): 51 | 52 | args = get_arg() 53 | 54 | stage2_data_dir = Path('../out/ensemble-multidomain/data-stage2') / args.datatrack / args.feat_type 55 | stage1_result_base_dir = Path('../out/ensemble-multidomain/stage1') 56 | 57 | feat_conf = yaml.safe_load(open('./stage2-method/{}.yaml'.format(args.feat_type))) 58 | print(feat_conf) 59 | 60 | df_test_list = [] 61 | 62 | for strong_learner in feat_conf['strong_learners']: 63 | 64 | # train_datatrack = 'phase1-main' if args.datatrack == 'testphase-main' else 'phase1-ood' 65 | 66 | stage1_result_dir = stage1_result_base_dir / args.datatrack / strong_learner 67 | 68 | k_cv = 3 if args.datatrack == 'testphase-ood' else K_CV 69 | 70 | column_tag = strong_learner 71 | df_test = get_learner_data(stage1_result_dir=stage1_result_dir, 72 | pred_datatrack=args.datatrack, 73 | use_upper_lower=False, 74 | column_tag=strong_learner, 75 | k_cv=k_cv) 76 | 77 | df_test_list.append(df_test) 78 | 79 | for train_datatrack, model_type, ssl_type in itertools.product( 80 | feat_conf['weak_learners']['datatracks'], 81 | feat_conf['weak_learners']['model_types'], 82 | feat_conf['weak_learners']['ssl_types']): 83 | 84 | if model_type == 'autogp': 85 | if train_datatrack == 'phase1-main': 86 | model_type = 'svgp' 87 | else: 88 | model_type = 'exactgp' 89 | 90 | use_cv_result = (args.datatrack == train_datatrack or train_datatrack.startswith('phase1-all')) 91 | use_upper_lower = (model_type in ['svgp', 'exactgp']) 92 | 93 | stage1_result_dir = stage1_result_base_dir / train_datatrack / f'{model_type}-{ssl_type}' 94 | 95 | column_tag = f'{train_datatrack}---{model_type}---{ssl_type}' 96 | 97 | df_test = get_learner_data(stage1_result_dir=stage1_result_dir, 98 | pred_datatrack=args.datatrack, 99 | use_upper_lower=use_upper_lower, 100 | column_tag=column_tag) 101 | 102 | df_test_list.append(df_test) 103 | 104 | df_test_all = pd.concat(df_test_list, axis=1) 105 | 106 | df_test_all.sort_index(inplace=True) 107 | 108 | 109 | print('Columns: {}'.format(df_test_all.columns)) 110 | print('Test: {}'.format(df_test_all.shape)) 111 | 112 | os.makedirs(stage2_data_dir, exist_ok=True) 113 | df_test_all.to_csv(stage2_data_dir / 'test-X.csv') 114 | 115 | 116 | if __name__ == '__main__': 117 | main() 118 | 119 | 120 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/collect_stage2_result.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import itertools 5 | 6 | import pandas as pd 7 | import itertools 8 | import yaml 9 | 10 | K_CV = 5 11 | 12 | def get_arg(): 13 | import argparse 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('datatrack') 16 | parser.add_argument('feat_type', default='weak36') 17 | return parser.parse_args() 18 | 19 | 20 | def get_learner_data(stage2_result_dir, pred_datatrack, use_upper_lower, column_tag): 21 | 22 | df_vals = [] 23 | df_tests = [] 24 | for i_cv in range(K_CV): 25 | df_vals.append(pd.read_csv(stage2_result_dir / str(i_cv) / f'val.csv', 26 | index_col=0)) 27 | df_tests.append(pd.read_csv(stage2_result_dir / str(i_cv) / f'test.csv', 28 | index_col=0)) 29 | 30 | df_train = pd.concat(df_vals) 31 | df_test = sum(df_tests) / len(df_tests) 32 | 33 | # empty column df 34 | df_train_new = df_train[[]].copy() 35 | df_test_new = df_test[[]].copy() 36 | 37 | if use_upper_lower: 38 | col_name = 'mean-' + column_tag 39 | df_train_new[col_name] = df_train['pred_mos'].copy() 40 | df_test_new[col_name] = df_test['pred_mos'].copy() 41 | 42 | col_name = 'lower-' + column_tag 43 | df_train_new[col_name] = df_train['lower_mos'].copy() 44 | df_test_new[col_name] = df_test['lower_mos'].copy() 45 | 46 | col_name = 'upper-' + column_tag 47 | df_train_new[col_name] = df_train['upper_mos'].copy() 48 | df_test_new[col_name] = df_test['upper_mos'].copy() 49 | 50 | else: 51 | col_name = 'pred-' + column_tag 52 | df_train_new[col_name] = df_train['pred_mos'].copy() 53 | df_test_new[col_name] = df_test['pred_mos'].copy() 54 | 55 | return df_train_new, df_test_new 56 | 57 | 58 | def main(): 59 | 60 | args = get_arg() 61 | 62 | stage3_data_dir = Path('../out/ensemble-multidomain/data-stage3') / args.datatrack / args.feat_type 63 | stage2_result_base_dir = Path('../out/ensemble-multidomain/stage2') 64 | 65 | feat_conf = yaml.safe_load(open('./stage2-method/{}.yaml'.format(args.feat_type))) 66 | print(feat_conf) 67 | 68 | df_train_list = [] 69 | df_test_list = [] 70 | 71 | for model_type in feat_conf['weak_learners']['model_types']: 72 | 73 | if model_type == 'autogp': 74 | if train_datatrack == 'phase1-main': 75 | model_type = 'svgp' 76 | else: 77 | model_type = 'exactgp' 78 | 79 | use_upper_lower = (model_type in ['svgp', 'exactgp']) 80 | 81 | stage2_result_dir = stage2_result_base_dir / args.datatrack / f'{model_type}-{args.feat_type}' 82 | 83 | column_tag = model_type 84 | 85 | df_train, df_test = get_learner_data(stage2_result_dir, args.datatrack, 86 | use_upper_lower, column_tag) 87 | 88 | df_train_list.append(df_train) 89 | df_test_list.append(df_test) 90 | 91 | df_train_all = pd.concat(df_train_list, axis=1) 92 | df_test_all = pd.concat(df_test_list, axis=1) 93 | 94 | df_train_all.sort_index(inplace=True) 95 | df_test_all.sort_index(inplace=True) 96 | 97 | print('Columns: {}'.format(df_train_all.columns)) 98 | print('Train: {}'.format(df_train_all.shape)) 99 | print('Test: {}'.format(df_test_all.shape)) 100 | 101 | os.makedirs(stage3_data_dir, exist_ok=True) 102 | df_train_all.to_csv(stage3_data_dir / 'train-X.csv') 103 | df_test_all.to_csv(stage3_data_dir / 'test-X.csv') 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | 109 | 110 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/collect_stage2_testphase_result.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import itertools 5 | 6 | import pandas as pd 7 | import itertools 8 | import yaml 9 | 10 | K_CV = 5 11 | 12 | def get_arg(): 13 | import argparse 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('datatrack') 16 | parser.add_argument('feat_type', default='weak36') 17 | return parser.parse_args() 18 | 19 | 20 | def get_learner_data(stage2_result_dir, pred_datatrack, use_upper_lower, column_tag, k_cv=K_CV): 21 | 22 | df_tests = [] 23 | for i_cv in range(k_cv): 24 | pred_dir = stage2_result_dir / str(i_cv) / f'pred-{pred_datatrack}' 25 | df_tests.append(pd.read_csv(pred_dir / f'test.csv', 26 | index_col=0)) 27 | 28 | df_test = sum(df_tests) / len(df_tests) 29 | 30 | # empty column df 31 | df_test_new = df_test[[]].copy() 32 | 33 | if use_upper_lower: 34 | col_name = 'mean-' + column_tag 35 | df_test_new[col_name] = df_test['pred_mos'].copy() 36 | 37 | col_name = 'lower-' + column_tag 38 | df_test_new[col_name] = df_test['lower_mos'].copy() 39 | 40 | col_name = 'upper-' + column_tag 41 | df_test_new[col_name] = df_test['upper_mos'].copy() 42 | 43 | else: 44 | col_name = 'pred-' + column_tag 45 | df_test_new[col_name] = df_test['pred_mos'].copy() 46 | 47 | return df_test_new 48 | 49 | 50 | def main(): 51 | 52 | args = get_arg() 53 | 54 | stage3_data_dir = Path('../out/ensemble-multidomain/data-stage3') / args.datatrack / args.feat_type 55 | stage2_result_base_dir = Path('../out/ensemble-multidomain/stage2') 56 | 57 | feat_conf = yaml.safe_load(open('./stage2-method/{}.yaml'.format(args.feat_type))) 58 | print(feat_conf) 59 | 60 | df_test_list = [] 61 | 62 | 63 | for model_type in feat_conf['weak_learners']['model_types']: 64 | 65 | if model_type == 'autogp': 66 | if train_datatrack == 'phase1-main': 67 | model_type = 'svgp' 68 | else: 69 | model_type = 'exactgp' 70 | 71 | use_upper_lower = (model_type in ['svgp', 'exactgp']) 72 | 73 | train_datatrack = 'phase1-main' if args.datatrack in ['testphase-main', 'valphase-main'] else 'phase1-ood' 74 | 75 | stage2_result_dir = stage2_result_base_dir / train_datatrack / f'{model_type}-{args.feat_type}' 76 | 77 | column_tag = model_type 78 | 79 | df_test = get_learner_data(stage2_result_dir, args.datatrack, 80 | use_upper_lower, column_tag) 81 | df_test_list.append(df_test) 82 | 83 | df_test_all = pd.concat(df_test_list, axis=1) 84 | 85 | df_test_all.sort_index(inplace=True) 86 | 87 | 88 | print('Columns: {}'.format(df_test_all.columns)) 89 | print('Test: {}'.format(df_test_all.shape)) 90 | 91 | os.makedirs(stage3_data_dir, exist_ok=True) 92 | df_test_all.to_csv(stage3_data_dir / 'test-X.csv') 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | 98 | 99 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/convert_strong_learner_result.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from pathlib import Path 4 | import os 5 | 6 | 7 | def get_arg(): 8 | import argparse 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('datatrack') 11 | parser.add_argument('learner') 12 | return parser.parse_args() 13 | 14 | 15 | def get_merge_df(df_true, df_pred): 16 | 17 | df = pd.merge(df_pred, df_true, on="wavname", how="left") 18 | df = df.set_index('wavname')[['pred_mos', 'true_mos']] 19 | 20 | return df 21 | 22 | def get_true_mos(datatrack): 23 | 24 | train_true_path = f'../data/{datatrack}/DATA/sets/train_mos_list.txt' 25 | val_true_path = f'../data/{datatrack}/DATA/sets/val_mos_list.txt' 26 | 27 | df_true_dict = {} 28 | df_true_dict['train'] = pd.read_csv(train_true_path, header=None, 29 | names=['wavname', 'true_mos']) 30 | df_true_dict['val'] = pd.read_csv(val_true_path, header=None, 31 | names=['wavname', 'true_mos']) 32 | df_true_dict['train'].shape, df_true_dict['val'].shape 33 | 34 | return df_true_dict 35 | 36 | 37 | def main(): 38 | 39 | args = get_arg() 40 | 41 | phase = 'main' 42 | in_dir = Path('../strong_learner_result') / args.learner 43 | out_base_dir = Path(f'../out/ensemble-multidomain/stage1') \ 44 | / args.datatrack / args.learner 45 | 46 | df_true_dict = get_true_mos(args.datatrack) 47 | 48 | k_cv = 3 if args.datatrack == 'phase1-ood' else 5 49 | answer_file_tag = 'answer-ood' if args.datatrack == 'phase1-ood' else 'answer-main' 50 | 51 | for i_cv in range(k_cv): 52 | out_dir = out_base_dir / str(i_cv) 53 | os.makedirs(out_dir, exist_ok=True) 54 | print(out_dir) 55 | 56 | for split in ['train', 'val']: 57 | stacking_split_name = {'train': 'val', 'val': 'test'}[split] 58 | learner_split_name = {'train': 'fold', 'val': 'val'}[split] 59 | in_path = in_dir / f'{answer_file_tag}.csv{learner_split_name}_{i_cv}' 60 | 61 | df_pred = pd.read_csv(in_path, header=None, 62 | names=['wavbase', 'pred_mos']) 63 | df_pred["wavname"] = df_pred["wavbase"] + ".wav" 64 | 65 | df = get_merge_df(df_true_dict[split], df_pred) 66 | 67 | df.to_csv(out_dir / f'{stacking_split_name}.csv') 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | 73 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/convert_strong_learner_testphase_result.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from pathlib import Path 4 | import os 5 | 6 | 7 | def get_arg(): 8 | import argparse 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('datatrack') 11 | parser.add_argument('learner') 12 | return parser.parse_args() 13 | 14 | 15 | def get_merge_df(df_true, df_pred): 16 | 17 | df = pd.merge(df_pred, df_true, on="wavname", how="left") 18 | df = df.set_index('wavname')[['pred_mos', 'true_mos']] 19 | 20 | return df 21 | 22 | def main(): 23 | 24 | args = get_arg() 25 | 26 | phase = 'main' 27 | in_dir = Path('../strong_learner_result') / args.learner 28 | out_base_dir = Path(f'../out/ensemble-multidomain/stage1/') \ 29 | / args.datatrack / args.learner 30 | 31 | k_cv = 3 if args.datatrack == 'testphase-ood' else 5 32 | answer_file_tag = 'answer-ood' if args.datatrack == 'testphase-ood' else 'answer-main' 33 | 34 | for i_cv in range(k_cv): 35 | out_dir = out_base_dir / str(i_cv) / f'pred-{args.datatrack}' 36 | os.makedirs(out_dir, exist_ok=True) 37 | print(out_dir) 38 | 39 | in_path = in_dir / f'{answer_file_tag}.csvtest_{i_cv}' 40 | 41 | df_pred = pd.read_csv(in_path, header=None, 42 | names=['wavbase', 'pred_mos']) 43 | df_pred["wavname"] = df_pred["wavbase"] + ".wav" 44 | df_pred['true_mos'] = -99.0 45 | df = df_pred.set_index('wavname')[['pred_mos', 'true_mos']] 46 | 47 | df.to_csv(out_dir / f'test.csv') 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | 53 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/data_util.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | from logging import getLogger 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | logger = getLogger(__name__) 9 | 10 | def make_stage1_data(fold_file, utt_data_dir): 11 | 12 | df = pd.read_csv(fold_file, header=None, index_col=0, 13 | names=['wavname', 'true_mos']) 14 | df = df.sort_index() 15 | 16 | embeddings = [] 17 | for wavname in df.index: 18 | wavbase = wavname.split('.')[0] 19 | embeddings.append(np.load(utt_data_dir / f'{wavbase}.npy')) 20 | 21 | X = np.stack(embeddings) 22 | y = df['true_mos'].values 23 | 24 | return df, X, y 25 | 26 | def load_stage1_data(datatrack, ssl_type, i_cv): 27 | utt_data_dir = Path('../out/utt_data') / ssl_type 28 | fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack 29 | 30 | logger.info('load data') 31 | 32 | data = {} 33 | for split in ['train', 'val', 'test']: 34 | fold_file = fold_dir / f'{split}-{i_cv}.csv' 35 | df, X, y = make_stage1_data(fold_file, utt_data_dir) 36 | 37 | logger.info('[{}]\tX: {}, y: {}'.format(split, X.shape, y.shape)) 38 | 39 | data[split] = {'X': X, 'y': y, 'df': df} 40 | 41 | return data 42 | 43 | 44 | def load_stage1_train_all_data(datatrack, ssl_type): 45 | utt_data_dir = Path('../out/utt_data') / ssl_type 46 | fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack 47 | 48 | logger.info('load data') 49 | 50 | fold_file = fold_dir / f'train-all.csv' 51 | df, X, y = make_stage1_data(fold_file, utt_data_dir) 52 | 53 | logger.info('[train-all]\tX: {}, y: {}'.format(X.shape, y.shape)) 54 | 55 | train_data = {'X': X, 'y': y, 'df': df} 56 | 57 | return train_data 58 | 59 | 60 | def load_stage1_test_data(datatrack, ssl_type): 61 | utt_data_dir = Path('../out/utt_data') / ssl_type 62 | fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack 63 | 64 | logger.info('load data') 65 | 66 | fold_file = fold_dir / f'test-0.csv' 67 | df, X, y = make_stage1_data(fold_file, utt_data_dir) 68 | 69 | logger.info('[test]\tX: {}, y: {}'.format(X.shape, y.shape)) 70 | 71 | train_data = {'X': X, 'y': y, 'df': df} 72 | 73 | return train_data 74 | 75 | 76 | def make_stage2_data(fold_file, df_all_X): 77 | 78 | df_fold = pd.read_csv(fold_file, header=None, index_col=0, 79 | names=['wavname', 'true_mos']) 80 | df_fold = df_fold.sort_index() 81 | 82 | wavnames = df_fold.index.values 83 | 84 | X = df_all_X.loc[wavnames, :].values 85 | y = df_fold['true_mos'].values 86 | 87 | return df_fold, X, y 88 | 89 | 90 | def load_stage2_data(datatrack, feat_type, i_cv): 91 | data_dir = Path('../out/ensemble-multidomain') / 'data-stage2' / datatrack / feat_type 92 | fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack 93 | 94 | logger.info('load data') 95 | 96 | 97 | data = {} 98 | for split in ['train', 'val', 'test']: 99 | data_path = data_dir / '{}-X.csv'.format('test' if split == 'test' else 'train') 100 | df_all_X = pd.read_csv(data_path, index_col=0) 101 | fold_file = fold_dir / f'{split}-{i_cv}.csv' 102 | df_fold, X, y = make_stage2_data(fold_file, df_all_X) 103 | 104 | logger.info('[{}]\tX: {}, y: {}'.format(split, X.shape, y.shape)) 105 | 106 | data[split] = {'X': X, 'y': y, 'df': df_fold} 107 | 108 | return data 109 | 110 | 111 | def load_stage2_train_all_data(datatrack, feat_type): 112 | data_dir = Path('../out/ensemble-multidomain') / 'data-stage2' / datatrack / feat_type 113 | fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack 114 | 115 | logger.info('load data') 116 | 117 | data_path = data_dir / f'train-X.csv' 118 | df_all_X = pd.read_csv(data_path, index_col=0) 119 | fold_file = fold_dir / f'train-all.csv' 120 | df, X, y = make_stage2_data(fold_file, df_all_X) 121 | 122 | logger.info('[train-all]\tX: {}, y: {}'.format(X.shape, y.shape)) 123 | 124 | train_data = {'X': X, 'y': y, 'df': df} 125 | 126 | return train_data 127 | 128 | 129 | def load_stage2_test_data(datatrack, feat_type): 130 | data_dir = Path('../out/ensemble-multidomain') / 'data-stage2' / datatrack / feat_type 131 | fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack 132 | 133 | logger.info('load data') 134 | 135 | data_path = data_dir / f'test-X.csv' 136 | df_all_X = pd.read_csv(data_path, index_col=0) 137 | fold_file = fold_dir / f'test-0.csv' 138 | df, X, y = make_stage2_data(fold_file, df_all_X) 139 | 140 | logger.info('[test]\tX: {}, y: {}'.format(X.shape, y.shape)) 141 | 142 | train_data = {'X': X, 'y': y, 'df': df} 143 | 144 | return train_data 145 | 146 | def make_stage3_data(fold_file, df_all_X): 147 | return make_stage2_data(fold_file, df_all_X) 148 | 149 | 150 | def load_stage3_data(datatrack, feat_type, i_cv): 151 | data_dir = Path('../out/ensemble-multidomain') / 'data-stage3' / datatrack / feat_type 152 | fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack 153 | 154 | logger.info('load data') 155 | 156 | 157 | data = {} 158 | for split in ['train', 'val', 'test']: 159 | data_path = data_dir / '{}-X.csv'.format('test' if split == 'test' else 'train') 160 | df_all_X = pd.read_csv(data_path, index_col=0) 161 | fold_file = fold_dir / f'{split}-{i_cv}.csv' 162 | df_fold, X, y = make_stage2_data(fold_file, df_all_X) 163 | 164 | logger.info('[{}]\tX: {}, y: {}'.format(split, X.shape, y.shape)) 165 | 166 | data[split] = {'X': X, 'y': y, 'df': df_fold} 167 | 168 | return data 169 | 170 | 171 | def load_stage3_train_all_data(datatrack, feat_type): 172 | data_dir = Path('../out/ensemble-multidomain') / 'data-stage3' / datatrack / feat_type 173 | fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack 174 | 175 | logger.info('load data') 176 | 177 | 178 | data_path = data_dir / f'train-X.csv' 179 | df_all_X = pd.read_csv(data_path, index_col=0) 180 | fold_file = fold_dir / f'train-all.csv' 181 | df, X, y = make_stage2_data(fold_file, df_all_X) 182 | 183 | logger.info('[train-all]\tX: {}, y: {}'.format(X.shape, y.shape)) 184 | 185 | train_data = {'X': X, 'y': y, 'df': df} 186 | 187 | return train_data 188 | 189 | 190 | def load_stage3_test_data(datatrack, feat_type): 191 | data_dir = Path('../out/ensemble-multidomain') / 'data-stage3' / datatrack / feat_type 192 | fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack 193 | 194 | logger.info('load data') 195 | 196 | data_path = data_dir / f'test-X.csv' 197 | df_all_X = pd.read_csv(data_path, index_col=0) 198 | fold_file = fold_dir / f'test-0.csv' 199 | df, X, y = make_stage3_data(fold_file, df_all_X) 200 | 201 | logger.info('[test]\tX: {}, y: {}'.format(X.shape, y.shape)) 202 | 203 | train_data = {'X': X, 'y': y, 'df': df} 204 | 205 | return train_data 206 | 207 | 208 | def normalize_score(val): 209 | """ 210 | >>> normalize_score(1) 211 | -1.0 212 | >>> normalize_score(3) 213 | 0 214 | >>> normalize_score(5) 215 | 1.0 216 | """ 217 | return (val - 3.0) / 2.0 218 | 219 | def inverse_normalize_score(val): 220 | """ 221 | >>> inverse_normalize_score(-1) 222 | 1.0 223 | >>> inverse_normalize_score(0) 224 | 3.0 225 | >>> inverse_normalize_score(1) 226 | 5.0 227 | """ 228 | return (val * 2.0) + 3.0 229 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/extract_ssl_feature.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from tqdm import tqdm 5 | from pathlib import Path 6 | import torchaudio 7 | import fairseq 8 | import torch 9 | import os 10 | 11 | import sys 12 | sys.path.append('./external_libs/WavLM') 13 | from WavLM import WavLM, WavLMConfig 14 | 15 | def get_arg(): 16 | import argparse 17 | parser = argparse.ArgumentParser() 18 | return parser.parse_args() 19 | 20 | 21 | def extract_mean(wavpath, ssl_model, device, use_wavlm): 22 | with torch.no_grad(): 23 | if use_wavlm: 24 | wav = torchaudio.load(wavpath)[0] 25 | res = ssl_model.extract_features(wav.to(device)) 26 | return res[0].squeeze(0).mean(dim=0) 27 | else: 28 | wav = torchaudio.load(wavpath)[0] 29 | res = ssl_model(wav.to(device), mask=False, features_only=True) 30 | return res['x'].squeeze(0).mean(dim=0) 31 | 32 | 33 | def extract_feature(datatrack, ssl_type): 34 | 35 | device = torch.device('cuda') 36 | 37 | wav_dir = Path(f'../data/{datatrack}/DATA/wav/') 38 | 39 | base_ckpt_file = { 40 | 'w2v_small': '../../fairseq_checkpoints/wav2vec_small.pt', 41 | 'w2v_xlsr': '../../fairseq_checkpoints/xlsr_53_56k.pt', 42 | 'w2v_large': '../../fairseq_checkpoints/wav2vec_vox_new.pt', 43 | 'w2v_large2': '../../fairseq_checkpoints/w2v_large_lv_fsh_swbd_cv.pt', 44 | 'wavlm_base': '../../fairseq_checkpoints/WavLM-Base.pt', 45 | 'wavlm_large': '../../fairseq_checkpoints/WavLM-Large.pt', 46 | 'hubert_base': '../../fairseq_checkpoints/hubert_base_ls960.pt', 47 | 'hubert_large': '../../fairseq_checkpoints/hubert_large_ll60k.pt', 48 | }[ssl_type] 49 | 50 | print('base_ckpt_file: {}'.format(base_ckpt_file)) 51 | 52 | use_wavlm = ssl_type in ['wavlm_base', 'wavlm_large'] 53 | 54 | if use_wavlm: 55 | checkpoint = torch.load(base_ckpt_file) 56 | cfg = WavLMConfig(checkpoint['cfg']) 57 | ssl_model = WavLM(cfg) 58 | ssl_model.load_state_dict(checkpoint['model']) 59 | else: 60 | model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([base_ckpt_file]) 61 | ssl_model = model[0] 62 | ssl_model.remove_pretraining_modules() 63 | 64 | ssl_model.to(device) 65 | ssl_model.eval() 66 | 67 | # print(ssl_model) 68 | 69 | out_dir = Path(f'../out/utt_data/{ssl_type}') 70 | os.makedirs(out_dir, exist_ok=True) 71 | 72 | wavpath_list = list(wav_dir.glob('*.wav')) 73 | 74 | for wavpath in tqdm(wavpath_list): 75 | vec = extract_mean(wavpath, ssl_model, device, use_wavlm) 76 | outpath = out_dir / (wavpath.stem + '.npy') 77 | 78 | vec = vec.detach().cpu().numpy() 79 | np.save(outpath, vec) 80 | 81 | 82 | def main(): 83 | args = get_arg() 84 | 85 | ssl_types = [ 86 | 'w2v_large2', 'w2v_xlsr', 87 | 'wavlm_base', 'wavlm_large', 88 | 'hubert_large', 'hubert_base', 89 | 'w2v_small', 'w2v_large', 90 | ] 91 | datatracks = ['phase1-main', 'phase1-ood', 'testphase-main', 'testphase-ood'] 92 | 93 | for datatrack in datatracks: 94 | for ssl_type in ssl_types: 95 | print('datatrack {}, ssl_type: {}'.format( 96 | datatrack, ssl_type)) 97 | extract_feature(datatrack, ssl_type) 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/gp_models.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | from logging import getLogger 4 | import joblib 5 | import os 6 | import json 7 | import pickle 8 | 9 | import math 10 | import numpy as np 11 | from sklearn.preprocessing import StandardScaler 12 | from sklearn.cluster import MiniBatchKMeans 13 | from sklearn.metrics import mean_squared_error, pairwise_distances 14 | import torch 15 | import torch.utils 16 | import gpytorch 17 | 18 | from data_util import normalize_score, inverse_normalize_score 19 | 20 | logger = getLogger(__name__) 21 | 22 | class Dataset(torch.utils.data.Dataset): 23 | def __init__(self, dx, dy, transform=None): 24 | 25 | self._N = len(dx) 26 | self._dx = dx 27 | self._dy = dy 28 | 29 | self.transform = transform 30 | 31 | def __len__(self): 32 | return self._N 33 | 34 | def __getitem__(self, idx): 35 | return self._dx[idx], self._dy[idx] 36 | 37 | class SVGPModel(gpytorch.models.ApproximateGP): 38 | def __init__(self, initial_inducing, initial_lengthscale): 39 | variational_distribution = gpytorch.variational.NaturalVariationalDistribution(initial_inducing.size(0)) 40 | variational_strategy = gpytorch.variational.VariationalStrategy( 41 | self, initial_inducing, variational_distribution, learn_inducing_locations=True 42 | ) 43 | 44 | super().__init__(variational_strategy) 45 | 46 | self.mean_module = gpytorch.means.ZeroMean() 47 | self.covar_module = gpytorch.kernels.ScaleKernel( 48 | gpytorch.kernels.RBFKernel(ard_num_dims=initial_inducing.size(1))) 49 | self.covar_module.base_kernel.lengthscale = initial_lengthscale 50 | 51 | 52 | 53 | def forward(self, x): 54 | mean_x = self.mean_module(x) 55 | covar_x = self.covar_module(x) 56 | return gpytorch.distributions.MultivariateNormal(mean_x, covar_x) 57 | 58 | 59 | 60 | class SVGP: 61 | 62 | def __init__(self, params=None, stage='stage1'): 63 | 64 | if params is None: 65 | self.params = { 66 | 'max_inducings': 1024, 67 | 'batch_size': 1024, 68 | # 'training_epochs': 3000 if stage == 'stage1' else 1000, 69 | 'training_iters': 10000 if stage == 'stage1' else 2500, 70 | } 71 | 72 | else: 73 | self.params = params 74 | 75 | self.device = torch.device('cuda') 76 | 77 | def get_num_inducing(self, num_data): 78 | max_inducings = self.params['max_inducings'] 79 | 80 | if num_data >= max_inducings: 81 | return max_inducings 82 | 83 | power = math.floor((math.log(num_data) / math.log(2))) 84 | num_inducings = int(2 ** power) 85 | 86 | return num_inducings 87 | 88 | 89 | def train(self, train_X, train_y, val_X, val_y): 90 | 91 | X_scaler = StandardScaler() 92 | train_X_sc = torch.from_numpy(X_scaler.fit_transform(train_X).astype(np.float32)) 93 | train_y_sc = torch.from_numpy(normalize_score(train_y).astype(np.float32)) 94 | 95 | val_X_sc = torch.from_numpy(X_scaler.transform(val_X).astype(np.float32)) 96 | 97 | # dataloader 98 | train_X_sc = train_X_sc.to(self.device) 99 | train_y_sc = train_y_sc.to(self.device) 100 | val_X_sc = val_X_sc.to(self.device) 101 | 102 | dataset = Dataset(train_X_sc, train_y_sc) 103 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.params['batch_size'], shuffle=True) 104 | 105 | # initial inducing 106 | num_inducings = self.get_num_inducing(len(train_X_sc)) 107 | logger.info('Num inducing points: {}'.format(num_inducings)) 108 | 109 | kmeans = MiniBatchKMeans(num_inducings) 110 | 111 | for i in range(5): 112 | for b, (x_B, y_B) in enumerate(dataloader): 113 | kmeans.partial_fit(x_B.cpu().numpy()) 114 | 115 | initial_inducing = torch.from_numpy(kmeans.cluster_centers_.astype(np.float32)) 116 | 117 | # initial lengthscale 118 | for b, (x_B, y_B) in enumerate(dataloader): 119 | D = pairwise_distances(x_B.cpu().numpy()) 120 | 121 | distances = D[np.tril_indices(len(D), k=-1)] 122 | # initial_lengthscale = np.sqrt(np.median(distances)) 123 | initial_lengthscale = np.median(distances) 124 | 125 | break 126 | 127 | # initialize likelihood and model 128 | gpr = SVGPModel(initial_inducing=initial_inducing, 129 | initial_lengthscale=initial_lengthscale) 130 | 131 | likelihood = gpytorch.likelihoods.GaussianLikelihood() 132 | 133 | gpr = gpr.to(self.device) 134 | likelihood = likelihood.to(self.device) 135 | 136 | # optimizer 137 | variational_ngd_optimizer = gpytorch.optim.NGD(gpr.variational_parameters(), 138 | num_data=train_X_sc.size(0), lr=0.1) 139 | 140 | hyperparameter_optimizer = torch.optim.Adam([ 141 | {'params': gpr.hyperparameters()}, 142 | {'params': likelihood.parameters()}, 143 | ], lr=0.01) 144 | 145 | # training 146 | gpr.train() 147 | likelihood.train() 148 | 149 | mll = gpytorch.mlls.VariationalELBO(likelihood, gpr, num_data=train_X_sc.size(0)) 150 | 151 | num_epochs = max(1, self.params['training_iters'] // len(dataloader)) 152 | for i in range(num_epochs): 153 | gpr.train() 154 | likelihood.train() 155 | 156 | lower_bound = 0 157 | for b, (x_B, y_B) in enumerate(dataloader): 158 | # x_B = x_B.to(device) 159 | # y_B = y_B.to(device) 160 | 161 | variational_ngd_optimizer.zero_grad() 162 | hyperparameter_optimizer.zero_grad() 163 | 164 | output = gpr(x_B) 165 | bound = mll(output, y_B) 166 | loss = -bound 167 | loss.backward() 168 | 169 | variational_ngd_optimizer.step() 170 | hyperparameter_optimizer.step() 171 | 172 | lower_bound += bound.data.cpu().item() 173 | 174 | if i % 100 == 0: 175 | gpr.eval() 176 | likelihood.eval() 177 | with torch.no_grad(): 178 | preds = gpr(val_X_sc) 179 | mean = preds.mean 180 | pred_y_sc = mean.cpu().numpy() 181 | 182 | pred_y = inverse_normalize_score(pred_y_sc) 183 | 184 | mse = mean_squared_error(val_y, pred_y.ravel()) 185 | 186 | logger.info('Iter %d/%d - ELBO: %.3f - val_loss %.3f ' % ( 187 | i + 1, num_epochs, lower_bound, mse, 188 | )) 189 | 190 | self.X_scaler = X_scaler 191 | self.gpr = gpr 192 | 193 | self.conf = { 194 | 'num_inducings': num_inducings, 195 | 'input_dim': train_X.shape[1], 196 | } 197 | 198 | def predict(self, X, df): 199 | self.gpr.eval() 200 | 201 | X_sc = torch.from_numpy(self.X_scaler.transform(X).astype(np.float32)) 202 | X_sc = X_sc.to(self.device) 203 | 204 | with torch.no_grad(): 205 | preds = self.gpr(X_sc) 206 | mean = preds.mean 207 | lower, upper = preds.confidence_region() 208 | 209 | mean_y = inverse_normalize_score(mean.cpu().numpy()) 210 | lower_y = inverse_normalize_score(lower.cpu().numpy()) 211 | upper_y = inverse_normalize_score(upper.cpu().numpy()) 212 | 213 | df['pred_mos'] = mean_y.ravel() 214 | df['lower_mos'] = lower_y.ravel() 215 | df['upper_mos'] = upper_y.ravel() 216 | 217 | return df 218 | 219 | 220 | def save_model(self, out_dir: Path): 221 | torch.save(self.gpr.state_dict(), out_dir / 'model.pt') 222 | joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib') 223 | 224 | with open(out_dir / 'model_config.json', encoding="utf-8", mode="w") as f: 225 | json.dump(self.conf, f, ensure_ascii=False, indent=2) 226 | 227 | 228 | def load_model(self, model_dir: Path, train_X=None): 229 | if os.path.exists(model_dir / 'model_config.json'): 230 | self.conf = json.load(open(model_dir / 'model_config.json', 'rb')) 231 | else: 232 | assert train_X is not None 233 | self.conf = { 234 | 'num_inducings': self.get_num_inducing(len(train_X)), 235 | 'input_dim': train_X.shape[1], 236 | } 237 | 238 | initial_inducing = torch.from_numpy( 239 | np.empty((self.conf['num_inducings'], self.conf['input_dim']), dtype=np.float32)) 240 | 241 | self.gpr = SVGPModel(initial_inducing, 1.0) 242 | self.gpr.to(self.device) 243 | 244 | self.gpr.load_state_dict(torch.load(model_dir / 'model.pt', map_location=self.device)) 245 | self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib') 246 | 247 | class ExactGPModel(gpytorch.models.ExactGP): 248 | def __init__(self, train_x, train_y, likelihood, initial_lengthscale): 249 | super().__init__(train_x, train_y, likelihood) 250 | self.mean_module = gpytorch.means.ZeroMean() 251 | 252 | # lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 6.0) 253 | # outputscale_prior = gpytorch.priors.GammaPrior(2.0, 0.15) 254 | 255 | self.covar_module = gpytorch.kernels.ScaleKernel( 256 | gpytorch.kernels.RBFKernel( 257 | # lengthscale_prior=lengthscale_prior, 258 | ard_num_dims=train_x.size(1), 259 | ), 260 | # outputscale_prior=outputscale_prior 261 | ) 262 | 263 | # self.covar_module = gpytorch.kernels.ScaleKernel( 264 | # gpytorch.kernels.RBFKernel( 265 | # lengthscale_prior=lengthscale_prior, 266 | # ard_num_dims=train_x.size(1), 267 | # ), 268 | # outputscale_prior=outputscale_prior 269 | # ) 270 | 271 | 272 | # Initialize lengthscale and outputscale to mean of priors 273 | # self.covar_module.base_kernel.lengthscale = initial_lengthscale 274 | self.covar_module.base_kernel.lengthscale = initial_lengthscale 275 | # self.covar_module.outputscale = outputscale_prior.mean 276 | 277 | 278 | 279 | def forward(self, x): 280 | mean_x = self.mean_module(x) 281 | covar_x = self.covar_module(x) 282 | return gpytorch.distributions.MultivariateNormal(mean_x, covar_x) 283 | 284 | 285 | class ExactGP: 286 | 287 | def __init__(self, params=None, stage='stage1'): 288 | if params is None: 289 | self.params = { 290 | # 'max_inducings': 1024, 291 | # 'batch_size': 1024, 292 | # 'training_epochs': 3000 if stage == 'stage1' else 1000, 293 | # 'training_iters': 10000 if stage == 'stage1' else 2500, 294 | 'training_iters': 10000, 295 | } 296 | 297 | else: 298 | self.params = params 299 | 300 | self.device = torch.device('cuda') 301 | 302 | def train(self, train_X, train_y, val_X, val_y): 303 | 304 | X_scaler = StandardScaler() 305 | train_X_sc = torch.from_numpy(X_scaler.fit_transform(train_X).astype(np.float32)) 306 | train_y_sc = torch.from_numpy(normalize_score(train_y).astype(np.float32)) 307 | 308 | val_X_sc = torch.from_numpy(X_scaler.transform(val_X).astype(np.float32)) 309 | 310 | # dataloader 311 | train_X_sc = train_X_sc.to(self.device) 312 | train_y_sc = train_y_sc.to(self.device) 313 | val_X_sc = val_X_sc.to(self.device) 314 | 315 | # dataset = Dataset(train_X_sc, train_y_sc) 316 | # dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.params['batch_size'], shuffle=True) 317 | 318 | # initial lengthscale 319 | D = pairwise_distances(train_X_sc.cpu().numpy()) 320 | 321 | distances = D[np.tril_indices(len(D), k=-1)] 322 | # initial_lengthscale = np.sqrt(np.median(distances)) 323 | initial_lengthscale = np.median(distances) 324 | 325 | # initialize likelihood and model 326 | likelihood = gpytorch.likelihoods.GaussianLikelihood( 327 | noise_constraint=gpytorch.constraints.GreaterThan(1e-3), 328 | ) 329 | gpr = ExactGPModel(train_X_sc, train_y_sc, likelihood, 330 | initial_lengthscale=initial_lengthscale) 331 | 332 | gpr = gpr.to(self.device) 333 | likelihood = likelihood.to(self.device) 334 | 335 | # optimizer 336 | optimizer = torch.optim.Adam(gpr.parameters(), lr=0.01) 337 | 338 | # training 339 | gpr.train() 340 | likelihood.train() 341 | 342 | mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gpr) 343 | 344 | for i in range(self.params['training_iters']): 345 | gpr.train() 346 | likelihood.train() 347 | 348 | optimizer.zero_grad() 349 | 350 | output = gpr(train_X_sc) 351 | loss = -mll(output, train_y_sc) 352 | loss.backward() 353 | 354 | optimizer.step() 355 | 356 | if i % 100 == 0: 357 | gpr.eval() 358 | likelihood.eval() 359 | with torch.no_grad(): 360 | preds = gpr(val_X_sc) 361 | mean = preds.mean 362 | pred_y_sc = mean.cpu().numpy() 363 | 364 | pred_y = inverse_normalize_score(pred_y_sc) 365 | 366 | mse = mean_squared_error(val_y, pred_y.ravel()) 367 | 368 | logger.info('Iter %d/%d - NMLL: %.3f - val_loss %.3f ' % ( 369 | i + 1, self.params['training_iters'], loss, mse, 370 | )) 371 | # logger.info('Pred_y var %.3f ' % ( 372 | # pred_y.var(), 373 | # )) 374 | 375 | self.X_scaler = X_scaler 376 | self.gpr = gpr 377 | 378 | self.conf = { 379 | 'output_shape': train_y.shape, 380 | 'input_shape': train_X.shape, 381 | } 382 | 383 | def predict(self, X, df): 384 | self.gpr.eval() 385 | 386 | X_sc = torch.from_numpy(self.X_scaler.transform(X).astype(np.float32)) 387 | X_sc = X_sc.to(self.device) 388 | 389 | with torch.no_grad(): 390 | preds = self.gpr(X_sc) 391 | mean = preds.mean 392 | lower, upper = preds.confidence_region() 393 | 394 | mean_y = inverse_normalize_score(mean.cpu().numpy()) 395 | lower_y = inverse_normalize_score(lower.cpu().numpy()) 396 | upper_y = inverse_normalize_score(upper.cpu().numpy()) 397 | 398 | df['pred_mos'] = mean_y.ravel() 399 | df['lower_mos'] = lower_y.ravel() 400 | df['upper_mos'] = upper_y.ravel() 401 | 402 | return df 403 | 404 | 405 | def save_model(self, out_dir: Path): 406 | torch.save(self.gpr.state_dict(), out_dir / 'model.pt') 407 | joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib') 408 | # pickle.dump(self.gpr, open(out_dir / 'gpr_pkl.pkl', 'wb')) 409 | 410 | with open(out_dir / 'model_config.json', encoding="utf-8", mode="w") as f: 411 | json.dump(self.conf, f, ensure_ascii=False, indent=2) 412 | 413 | 414 | def load_model(self, model_dir: Path, train_X, train_y): 415 | self.conf = json.load(open(model_dir / 'model_config.json', 'rb')) 416 | 417 | self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib') 418 | 419 | train_X_sc = torch.from_numpy(self.X_scaler.transform(train_X).astype(np.float32)) 420 | train_y_sc = torch.from_numpy(normalize_score(train_y).astype(np.float32)) 421 | 422 | # dummy_X = torch.from_numpy( 423 | # np.empty(self.conf['input_shape'], dtype=np.float32)) 424 | # dummy_y = torch.from_numpy( 425 | # np.empty(self.conf['output_shape'], dtype=np.float32)) 426 | 427 | likelihood = gpytorch.likelihoods.GaussianLikelihood( 428 | noise_constraint=gpytorch.constraints.GreaterThan(1e-3), 429 | ) 430 | self.gpr = ExactGPModel(train_X_sc, train_y_sc, likelihood, 431 | initial_lengthscale=1.0) 432 | self.gpr.to(self.device) 433 | 434 | self.gpr.load_state_dict(torch.load(model_dir / 'model.pt', map_location=self.device)) 435 | 436 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/make_ensemble_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | from sklearn.model_selection import KFold 8 | 9 | SEED_CV = 0 10 | K_CV = 5 11 | 12 | 13 | def get_arg(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--datatrack", type=str, required=True, help="phase1-main or phase1-ood", 17 | default='phase1-main',) 18 | return parser.parse_args() 19 | 20 | def write_wavnames(outpath, wavnames, mos_lookup): 21 | 22 | with open(outpath, 'w') as f: 23 | for wavname in sorted(wavnames): 24 | if wavname == 'sys4bafa-uttc2e86f6.wav': 25 | print('Skip sys4bafa-uttc2e86f6') 26 | continue 27 | print('{},{}'.format(wavname, mos_lookup[wavname]), file=f) 28 | 29 | 30 | def main(): 31 | args = get_arg() 32 | 33 | datadir = Path('../data', args.datatrack, 'DATA') 34 | # outdir = Path('./out/ensemble', args.datatrack, 'fold') 35 | outdir = Path('../out/ensemble-multidomain/fold', args.datatrack) 36 | 37 | os.makedirs(outdir, exist_ok=True) 38 | 39 | moslists = { 40 | "train": datadir / "sets/train_mos_list.txt", 41 | "val": datadir / "sets/val_mos_list.txt", 42 | } 43 | mos_lookup = {} 44 | 45 | wavnames = { 46 | 'train': [], 47 | 'val': [], 48 | } 49 | 50 | for split in ["train", "val"]: 51 | with open(moslists[split], "r") as fr: 52 | for line in fr: 53 | parts = line.strip().split(",") 54 | wavname = parts[0] 55 | mos = parts[1] 56 | mos_lookup[wavname] = mos 57 | wavnames[split].append(wavname) 58 | 59 | Kf = KFold(n_splits=K_CV, random_state=SEED_CV, shuffle=True) 60 | train_wavnames = np.asarray(wavnames['train']) 61 | 62 | for i, (cv_train_idx, cv_val_idx) in enumerate(Kf.split(train_wavnames)): 63 | 64 | cv_train_wavnames = train_wavnames[cv_train_idx] 65 | cv_val_wavnames = train_wavnames[cv_val_idx] 66 | 67 | write_wavnames(outdir / f'train-{i}.csv', cv_train_wavnames, mos_lookup) 68 | write_wavnames(outdir / f'val-{i}.csv', cv_val_wavnames, mos_lookup) 69 | write_wavnames(outdir / f'test-{i}.csv', wavnames['val'], mos_lookup) 70 | 71 | write_wavnames(outdir / f'train-all.csv', train_wavnames, mos_lookup) 72 | 73 | 74 | if __name__ == '__main__': 75 | main() -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/make_ensemble_dataset_wotest.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | from pathlib import Path 5 | import random 6 | 7 | import numpy as np 8 | from sklearn.model_selection import KFold 9 | 10 | SEED_CV = 0 11 | K_CV = 5 12 | 13 | NUM_VAL_FILES = 100 14 | 15 | def get_arg(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | "--datatrack", type=str, required=True, help="phase1-main or phase1-ood", 19 | default='phase1-main',) 20 | return parser.parse_args() 21 | 22 | def write_wavnames(outpath, wavnames, mos_lookup): 23 | 24 | with open(outpath, 'w') as f: 25 | for wavname in sorted(wavnames): 26 | if wavname == 'sys4bafa-uttc2e86f6.wav': 27 | print('Skip sys4bafa-uttc2e86f6') 28 | continue 29 | print('{},{}'.format(wavname, mos_lookup[wavname]), file=f) 30 | 31 | 32 | def main(): 33 | args = get_arg() 34 | 35 | datadir = Path('../data', args.datatrack, 'DATA') 36 | outdir = Path('../out/ensemble-multidomain/fold', args.datatrack + '-wo_test') 37 | 38 | os.makedirs(outdir, exist_ok=True) 39 | 40 | if args.datatrack == 'external': 41 | moslist_files = ['./external_mos_list.txt'] 42 | else: 43 | moslist_files = [ 44 | datadir / "sets/train_mos_list.txt", 45 | datadir / "sets/val_mos_list.txt", 46 | ] 47 | mos_lookup = {} 48 | 49 | wavnames = [] 50 | 51 | for mos_file in moslist_files: 52 | with open(mos_file, "r") as fr: 53 | for line in fr: 54 | parts = line.strip().split(",") 55 | wavname = parts[0] 56 | mos = parts[1] 57 | mos_lookup[wavname] = mos 58 | wavnames.append(wavname) 59 | 60 | train_wavnames = np.asarray(wavnames) 61 | 62 | rng = np.random.default_rng(SEED_CV) 63 | 64 | test_wavnames = list(sorted(rng.permutation(train_wavnames)[:NUM_VAL_FILES])) 65 | 66 | Kf = KFold(n_splits=K_CV, random_state=SEED_CV, shuffle=True) 67 | 68 | 69 | for i, (cv_train_idx, cv_val_idx) in enumerate(Kf.split(train_wavnames)): 70 | 71 | cv_train_wavnames = train_wavnames[cv_train_idx] 72 | cv_val_wavnames = train_wavnames[cv_val_idx] 73 | 74 | write_wavnames(outdir / f'train-{i}.csv', cv_train_wavnames, mos_lookup) 75 | write_wavnames(outdir / f'val-{i}.csv', cv_val_wavnames, mos_lookup) 76 | write_wavnames(outdir / f'test-{i}.csv', test_wavnames, mos_lookup) 77 | 78 | write_wavnames(outdir / f'train-all.csv', train_wavnames, mos_lookup) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/make_ensemble_testphase.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | from sklearn.model_selection import KFold 8 | 9 | SEED_CV = 0 10 | K_CV = 5 11 | 12 | 13 | def get_arg(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--datatrack", type=str, required=True, help="testphase-main or testphase-ood", 17 | default='testphase-main',) 18 | return parser.parse_args() 19 | 20 | def write_wavnames(outpath, wavnames, mos_lookup): 21 | 22 | with open(outpath, 'w') as f: 23 | for wavname in sorted(wavnames): 24 | if wavname == 'sys4bafa-uttc2e86f6.wav': 25 | print('Skip sys4bafa-uttc2e86f6') 26 | continue 27 | print('{},{}'.format(wavname, mos_lookup[wavname]), file=f) 28 | 29 | 30 | def main(): 31 | args = get_arg() 32 | 33 | assert args.datatrack in ['phase1-main', 'phase1-ood'] 34 | 35 | if args.datatrack == 'phase1-main': 36 | pred_datatrack = 'testphase-main' 37 | elif args.datatrack == 'phase1-ood': 38 | pred_datatrack = 'testphase-ood' 39 | 40 | datadir = Path('../data', args.datatrack, 'DATA') 41 | outdir = Path('../out/ensemble-multidomain/fold', pred_datatrack) 42 | 43 | os.makedirs(outdir, exist_ok=True) 44 | 45 | moslist_file = datadir / "sets/test_mos_list.txt" 46 | mos_lookup = {} 47 | 48 | wavnames = [] 49 | 50 | with open(moslist_file, "r") as fr: 51 | for line in fr: 52 | parts = line.strip().split(",") 53 | wavname = parts[0] 54 | mos = -99.0 # dummy 55 | mos_lookup[wavname] = mos 56 | wavnames.append(wavname) 57 | 58 | write_wavnames(outdir / f'test-0.csv', wavnames, mos_lookup) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/models.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | from pathlib import Path 4 | from logging import getLogger 5 | 6 | import joblib 7 | 8 | import sklearn.linear_model 9 | import sklearn.svm 10 | from sklearn.preprocessing import StandardScaler 11 | from sklearn.metrics import mean_squared_error, make_scorer 12 | from sklearn.pipeline import Pipeline 13 | import lightgbm 14 | import torch 15 | import optuna 16 | 17 | from data_util import normalize_score, inverse_normalize_score 18 | 19 | logger = getLogger(__name__) 20 | 21 | K_CV = 5 22 | HP_OPT_TRIALS = 100 23 | HP_OPT_SEED = 0 24 | LinearSVR_OPT_MAX_ITER = 1000 25 | KernelSVR_OPT_MAX_ITER = 100000 26 | 27 | class Ridge: 28 | 29 | def __init__(self, params=None): 30 | 31 | if params is None: 32 | # use hyperparameters tuned by w2v_small 33 | self.params = { 34 | 'alpha': 36.315622558743634, 35 | } 36 | 37 | else: 38 | self.params = params 39 | 40 | def train(self, train_X, train_y, val_X, val_y): 41 | 42 | X_scaler = StandardScaler() 43 | train_X_sc = X_scaler.fit_transform(train_X) 44 | train_y_sc = normalize_score(train_y) 45 | 46 | val_X_sc = X_scaler.transform(val_X) 47 | 48 | ridge = sklearn.linear_model.Ridge(**(self.params)) 49 | 50 | ridge.fit(train_X_sc, train_y_sc) 51 | 52 | pred_y = inverse_normalize_score(ridge.predict(val_X_sc)) 53 | 54 | val_mse = mean_squared_error(val_y, pred_y) 55 | 56 | logger.info('Val MSE: {:f}'.format(val_mse)) 57 | 58 | self.ridge = ridge 59 | self.X_scaler = X_scaler 60 | 61 | 62 | def predict(self, X, df): 63 | """ 64 | Calculate results and insert them in pd.DataDrame columns 65 | """ 66 | X_sc = self.X_scaler.transform(X) 67 | 68 | pred_y = inverse_normalize_score(self.ridge.predict(X_sc)) 69 | 70 | df['pred_mos'] = pred_y.ravel() 71 | 72 | return df 73 | 74 | def save_model(self, out_dir: Path): 75 | 76 | joblib.dump(self.ridge, out_dir / 'model.joblib') 77 | joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib') 78 | 79 | def load_model(self, model_dir: Path): 80 | 81 | self.ridge = joblib.load(model_dir / 'model.joblib') 82 | self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib') 83 | 84 | def optimize_hp(self, train_X, train_y): 85 | 86 | X_scaler = StandardScaler() 87 | train_X_sc = X_scaler.fit_transform(train_X) 88 | train_y_sc = normalize_score(train_y) 89 | 90 | param_distributions = { 91 | 'alpha': optuna.distributions.LogUniformDistribution(1e-5, 1e+5) 92 | } 93 | 94 | model = sklearn.linear_model.Ridge() 95 | scoring = make_scorer(mean_squared_error, greater_is_better=False) 96 | 97 | optuna_search = optuna.integration.OptunaSearchCV(model, param_distributions, 98 | cv=K_CV, 99 | n_trials=HP_OPT_TRIALS, 100 | random_state=HP_OPT_SEED, 101 | scoring=scoring, 102 | verbose=0) 103 | 104 | optuna_search.fit(train_X_sc, train_y_sc) 105 | 106 | return optuna_search.best_params_ 107 | 108 | 109 | class LinearSVR: 110 | 111 | def __init__(self, params=None, stage='stage1'): 112 | 113 | if params is None: 114 | # use hyperparameters tuned by w2v_small 115 | self.params = { 116 | 'C': 0.01982058833734277, 117 | 'epsilon': 0.23072531432463972, 118 | } 119 | 120 | else: 121 | self.params = params 122 | 123 | self.max_iter = 10000 124 | if stage == 'stage1': 125 | self.opt_max_iter = self.max_iter 126 | else: 127 | self.opt_max_iter = LinearSVR_OPT_MAX_ITER 128 | 129 | def train(self, train_X, train_y, val_X, val_y): 130 | 131 | X_scaler = StandardScaler() 132 | train_X_sc = X_scaler.fit_transform(train_X) 133 | train_y_sc = normalize_score(train_y) 134 | 135 | val_X_sc = X_scaler.transform(val_X) 136 | 137 | svr = sklearn.svm.LinearSVR(max_iter=self.max_iter, **self.params) 138 | 139 | svr.fit(train_X_sc, train_y_sc) 140 | 141 | pred_y = inverse_normalize_score(svr.predict(val_X_sc)) 142 | 143 | val_mse = mean_squared_error(val_y, pred_y) 144 | 145 | logger.info('Val MSE: {:f}'.format(val_mse)) 146 | 147 | self.svr = svr 148 | self.X_scaler = X_scaler 149 | 150 | 151 | def predict(self, X, df): 152 | """ 153 | Calculate results and insert them in pd.DataDrame columns 154 | """ 155 | X_sc = self.X_scaler.transform(X) 156 | 157 | pred_y = inverse_normalize_score(self.svr.predict(X_sc)) 158 | 159 | df['pred_mos'] = pred_y.ravel() 160 | 161 | return df 162 | 163 | def save_model(self, out_dir: Path): 164 | 165 | joblib.dump(self.svr, out_dir / 'model.joblib') 166 | joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib') 167 | 168 | def load_model(self, model_dir: Path): 169 | 170 | self.svr = joblib.load(model_dir / 'model.joblib') 171 | self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib') 172 | 173 | def optimize_hp(self, train_X, train_y): 174 | 175 | X_scaler = StandardScaler() 176 | train_X_sc = X_scaler.fit_transform(train_X) 177 | train_y_sc = normalize_score(train_y) 178 | 179 | param_distributions = { 180 | 'C': optuna.distributions.LogUniformDistribution(1e-5, 1e+5), 181 | 'epsilon': optuna.distributions.UniformDistribution(0, 1), 182 | } 183 | 184 | model = sklearn.svm.LinearSVR(max_iter=self.opt_max_iter) 185 | scoring = make_scorer(mean_squared_error, greater_is_better=False) 186 | 187 | optuna_search = optuna.integration.OptunaSearchCV(model, param_distributions, 188 | cv=K_CV, 189 | n_trials=HP_OPT_TRIALS, 190 | random_state=HP_OPT_SEED, 191 | scoring=scoring, 192 | verbose=0) 193 | 194 | optuna_search.fit(train_X_sc, train_y_sc) 195 | 196 | return optuna_search.best_params_ 197 | 198 | 199 | class KernelSVR: 200 | 201 | def __init__(self, params=None, stage='stage1'): 202 | 203 | if params is None: 204 | # use hyperparameters tuned by w2v_small 205 | self.params = { 206 | 'C': 4.483354499092266, 207 | 'epsilon': 0.2177054099781604, 208 | 'gamma': 0.0006981540829311363, 209 | } 210 | 211 | else: 212 | self.params = params 213 | 214 | # self.max_iter = 10000 215 | if stage == 'stage1': 216 | self.opt_max_iter = -1 217 | else: 218 | self.opt_max_iter = KernelSVR_OPT_MAX_ITER 219 | 220 | 221 | def train(self, train_X, train_y, val_X, val_y): 222 | 223 | X_scaler = StandardScaler() 224 | train_X_sc = X_scaler.fit_transform(train_X) 225 | train_y_sc = normalize_score(train_y) 226 | 227 | val_X_sc = X_scaler.transform(val_X) 228 | 229 | svr = sklearn.svm.SVR(kernel='rbf', **self.params) 230 | 231 | svr.fit(train_X_sc, train_y_sc) 232 | 233 | pred_y = inverse_normalize_score(svr.predict(val_X_sc)) 234 | 235 | val_mse = mean_squared_error(val_y, pred_y) 236 | 237 | logger.info('Val MSE: {:f}'.format(val_mse)) 238 | 239 | self.svr = svr 240 | self.X_scaler = X_scaler 241 | 242 | 243 | def predict(self, X, df): 244 | """ 245 | Calculate results and insert them in pd.DataDrame columns 246 | """ 247 | X_sc = self.X_scaler.transform(X) 248 | 249 | pred_y = inverse_normalize_score(self.svr.predict(X_sc)) 250 | 251 | df['pred_mos'] = pred_y.ravel() 252 | 253 | return df 254 | 255 | def save_model(self, out_dir: Path): 256 | 257 | joblib.dump(self.svr, out_dir / 'model.joblib') 258 | joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib') 259 | 260 | def load_model(self, model_dir: Path): 261 | 262 | self.svr = joblib.load(model_dir / 'model.joblib') 263 | self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib') 264 | 265 | 266 | 267 | def optimize_hp(self, train_X, train_y): 268 | 269 | X_scaler = StandardScaler() 270 | train_X_sc = X_scaler.fit_transform(train_X) 271 | train_y_sc = normalize_score(train_y) 272 | 273 | param_distributions = { 274 | 'C': optuna.distributions.LogUniformDistribution(1e-5, 1e+5), 275 | 'epsilon': optuna.distributions.UniformDistribution(0, 1), 276 | 'gamma': optuna.distributions.LogUniformDistribution(1e-5, 1e+5), 277 | } 278 | 279 | model = sklearn.svm.SVR(max_iter=self.opt_max_iter) 280 | scoring = make_scorer(mean_squared_error, greater_is_better=False) 281 | 282 | optuna_search = optuna.integration.OptunaSearchCV(model, param_distributions, 283 | cv=K_CV, 284 | n_trials=HP_OPT_TRIALS, 285 | random_state=HP_OPT_SEED, 286 | scoring=scoring, 287 | verbose=0) 288 | 289 | optuna_search.fit(train_X_sc, train_y_sc) 290 | 291 | return optuna_search.best_params_ 292 | 293 | 294 | class RandomForest: 295 | 296 | def __init__(self, params=None): 297 | 298 | if params is None: 299 | self.params = { 300 | 'n_estimators': 100, 301 | 'max_depth': 1000, 302 | } 303 | 304 | else: 305 | self.params = params 306 | 307 | def train(self, train_X, train_y, val_X, val_y): 308 | 309 | X_scaler = StandardScaler() 310 | train_X_sc = X_scaler.fit_transform(train_X) 311 | train_y_sc = normalize_score(train_y) 312 | 313 | val_X_sc = X_scaler.transform(val_X) 314 | 315 | rf = sklearn.ensemble.RandomForestRegressor(**self.params) 316 | 317 | rf.fit(train_X_sc, train_y_sc) 318 | 319 | pred_y = inverse_normalize_score(rf.predict(val_X_sc)) 320 | 321 | val_mse = mean_squared_error(val_y, pred_y) 322 | 323 | logger.info('Val MSE: {:f}'.format(val_mse)) 324 | 325 | self.rf = rf 326 | self.X_scaler = X_scaler 327 | 328 | 329 | def predict(self, X, df): 330 | """ 331 | Calculate results and insert them in pd.DataDrame columns 332 | """ 333 | X_sc = self.X_scaler.transform(X) 334 | 335 | pred_y = inverse_normalize_score(self.rf.predict(X_sc)) 336 | 337 | df['pred_mos'] = pred_y.ravel() 338 | 339 | return df 340 | 341 | def save_model(self, out_dir: Path): 342 | 343 | joblib.dump(self.rf, out_dir / 'model.joblib') 344 | joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib') 345 | 346 | def load_model(self, model_dir: Path): 347 | 348 | self.rf = joblib.load(model_dir / 'model.joblib') 349 | self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib') 350 | 351 | 352 | def optimize_hp(self, train_X, train_y): 353 | 354 | X_scaler = StandardScaler() 355 | train_X_sc = X_scaler.fit_transform(train_X) 356 | train_y_sc = normalize_score(train_y) 357 | 358 | param_distributions = { 359 | 'n_estimators': optuna.distributions.IntUniformDistribution(1, 100), 360 | 'max_depth': optuna.distributions.IntLogUniformDistribution(1, 1000), 361 | 'max_features': optuna.distributions.CategoricalDistribution(['auto', 'sqrt', 'log2']), 362 | } 363 | 364 | model = sklearn.ensemble.RandomForestRegressor() 365 | scoring = make_scorer(mean_squared_error, greater_is_better=False) 366 | 367 | optuna_search = optuna.integration.OptunaSearchCV(model, param_distributions, 368 | cv=K_CV, 369 | n_trials=HP_OPT_TRIALS, 370 | random_state=HP_OPT_SEED, 371 | scoring=scoring, 372 | verbose=0) 373 | 374 | optuna_search.fit(train_X_sc, train_y_sc) 375 | 376 | return optuna_search.best_params_ 377 | 378 | 379 | class LightGBM: 380 | 381 | def __init__(self, params=None): 382 | 383 | if params is None: 384 | # use hyperparameters tuned by w2v_small 385 | self.params = { 386 | 'lambda_l1': 0.03708013547428929, 387 | 'lambda_l2': 3.1884740170707856e-07, 388 | 'num_leaves': 220, 389 | 'feature_fraction': 0.6747205882024254, 390 | 'bagging_fraction': 0.9367956222111139, 391 | 'bagging_freq': 2, 392 | 'min_child_samples': 92, 393 | 'max_depth': 10 394 | } 395 | 396 | else: 397 | self.params = params 398 | 399 | def train(self, train_X, train_y, val_X, val_y): 400 | 401 | train_set = lightgbm.Dataset(train_X, train_y) 402 | valid_set = lightgbm.Dataset(val_X, val_y, reference=train_set) 403 | 404 | lgb_model = lightgbm.train( 405 | params = self.params, 406 | train_set = train_set, 407 | valid_sets = [train_set, valid_set], 408 | num_boost_round = 10000, 409 | early_stopping_rounds = 10, 410 | ) 411 | 412 | pred_y = lgb_model.predict(val_X, num_iteration=lgb_model.best_iteration) 413 | 414 | val_mse = mean_squared_error(val_y, pred_y) 415 | 416 | logger.info('Val MSE: {:f}'.format(val_mse)) 417 | 418 | self.lgb_model = lgb_model 419 | 420 | 421 | def predict(self, X, df): 422 | """ 423 | Calculate results and insert them in pd.DataDrame columns 424 | """ 425 | pred_y = self.lgb_model.predict(X, num_iteration=self.lgb_model.best_iteration) 426 | 427 | df['pred_mos'] = pred_y.ravel() 428 | 429 | return df 430 | 431 | def save_model(self, out_dir: Path): 432 | 433 | with open(out_dir / 'model.pkl', 'wb') as f: 434 | pickle.dump(self.lgb_model, f) 435 | 436 | 437 | def load_model(self, model_dir: Path): 438 | 439 | self.lgb_model = pickle.load(open(model_dir / 'model.pkl', 'rb')) 440 | 441 | 442 | def optimize_hp(self, train_X, train_y): 443 | 444 | X_scaler = StandardScaler() 445 | train_X_sc = X_scaler.fit_transform(train_X) 446 | train_y_sc = normalize_score(train_y) 447 | 448 | lgb_train = optuna.integration.lightgbm.Dataset(train_X_sc, train_y_sc) 449 | 450 | lgbm_params = { 451 | 'objective': 'regression', 452 | 'metric': 'mse', 453 | 'verbosity': -1, 454 | } 455 | 456 | folds = sklearn.model_selection.KFold(n_splits=K_CV, shuffle=True, random_state=HP_OPT_SEED) 457 | 458 | tuner_cv = optuna.integration.lightgbm.LightGBMTunerCV( 459 | lgbm_params, lgb_train, 460 | num_boost_round=1000, 461 | early_stopping_rounds=100, 462 | # verbose_eval=20, 463 | folds=folds, 464 | optuna_seed=HP_OPT_SEED, 465 | ) 466 | 467 | tuner_cv.run() 468 | 469 | return tuner_cv.best_params 470 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/modules.py: -------------------------------------------------------------------------------- 1 | ../../strong/modules.py -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/opt_stage1.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import logging 5 | from logging import getLogger 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from data_util import load_stage1_train_all_data 12 | 13 | from models import Ridge, LinearSVR, KernelSVR, RandomForest, LightGBM 14 | from gp_models import SVGP 15 | import json 16 | 17 | logger = getLogger(__name__) 18 | 19 | RAND_SEED = 0 20 | 21 | def get_arg(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('method') 25 | parser.add_argument('datatrack') 26 | parser.add_argument('ssl_type') 27 | return parser.parse_args() 28 | 29 | 30 | 31 | 32 | def main(): 33 | args = get_arg() 34 | 35 | random.seed(RAND_SEED) 36 | np.random.seed(RAND_SEED) 37 | torch.manual_seed(RAND_SEED) 38 | 39 | data = load_stage1_train_all_data( 40 | datatrack=args.datatrack, 41 | ssl_type=args.ssl_type, 42 | ) 43 | 44 | if args.method == 'ridge': 45 | model = Ridge() 46 | elif args.method == 'linear_svr': 47 | model = LinearSVR(stage='stage1') 48 | elif args.method == 'kernel_svr': 49 | model = KernelSVR(stage='stage1') 50 | elif args.method == 'rf': 51 | raise NotImplementedError() 52 | # model = RandomForest() 53 | elif args.method == 'lightgbm': 54 | model = LightGBM() 55 | elif args.method == 'svgp': 56 | raise NotImplementedError() 57 | else: 58 | raise RuntimeError('Not supported method: "{}"'.format(args.method)) 59 | 60 | best_params = model.optimize_hp(data['X'], data['y']) 61 | 62 | logger.info(best_params) 63 | 64 | out_dir = Path('../out/ensemble-multidomain/opt_hp_stage1') / args.datatrack / \ 65 | f'{args.method}-{args.ssl_type}' 66 | os.makedirs(out_dir, exist_ok=True) 67 | 68 | with open(out_dir / 'params.json', encoding="utf-8", mode="w") as f: 69 | json.dump(best_params, f, ensure_ascii=False, indent=2) 70 | 71 | 72 | if __name__ == '__main__': 73 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{') 74 | main() 75 | 76 | 77 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/opt_stage2.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import logging 5 | from logging import getLogger 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from data_util import load_stage2_train_all_data 12 | 13 | from models import Ridge, LinearSVR, KernelSVR, RandomForest, LightGBM 14 | from gp_models import SVGP 15 | import json 16 | 17 | logger = getLogger(__name__) 18 | 19 | RAND_SEED = 0 20 | 21 | def get_arg(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('method') 25 | parser.add_argument('datatrack') 26 | parser.add_argument('feat_type') 27 | return parser.parse_args() 28 | 29 | 30 | 31 | 32 | def main(): 33 | args = get_arg() 34 | 35 | random.seed(RAND_SEED) 36 | np.random.seed(RAND_SEED) 37 | torch.manual_seed(RAND_SEED) 38 | 39 | data = load_stage2_train_all_data( 40 | datatrack=args.datatrack, 41 | feat_type=args.feat_type, 42 | ) 43 | 44 | if args.method == 'ridge': 45 | model = Ridge() 46 | elif args.method == 'linear_svr': 47 | model = LinearSVR(stage='stage2') 48 | elif args.method == 'kernel_svr': 49 | model = KernelSVR(stage='stage2') 50 | elif args.method == 'rf': 51 | raise NotImplementedError() 52 | # model = RandomForest() 53 | elif args.method == 'lightgbm': 54 | model = LightGBM() 55 | elif args.method == 'svgp': 56 | raise NotImplementedError() 57 | else: 58 | raise RuntimeError('Not supported method: "{}"'.format(args.method)) 59 | 60 | best_params = model.optimize_hp(data['X'], data['y']) 61 | 62 | logger.info(best_params) 63 | 64 | out_dir = Path('../out/ensemble-multidomain/opt_hp_stage2') / args.datatrack / \ 65 | f'{args.method}-{args.feat_type}' 66 | os.makedirs(out_dir, exist_ok=True) 67 | 68 | with open(out_dir / 'params.json', encoding="utf-8", mode="w") as f: 69 | json.dump(best_params, f, ensure_ascii=False, indent=2) 70 | 71 | 72 | if __name__ == '__main__': 73 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{') 74 | main() 75 | 76 | 77 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/opt_stage3.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import logging 5 | from logging import getLogger 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from data_util import load_stage3_train_all_data 12 | 13 | from models import Ridge, LinearSVR, KernelSVR, RandomForest, LightGBM 14 | from gp_models import SVGP 15 | import json 16 | 17 | logger = getLogger(__name__) 18 | 19 | RAND_SEED = 0 20 | 21 | def get_arg(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('method') 25 | parser.add_argument('datatrack') 26 | parser.add_argument('feat_type') 27 | return parser.parse_args() 28 | 29 | 30 | 31 | 32 | def main(): 33 | args = get_arg() 34 | 35 | random.seed(RAND_SEED) 36 | np.random.seed(RAND_SEED) 37 | torch.manual_seed(RAND_SEED) 38 | 39 | data = load_stage3_train_all_data( 40 | datatrack=args.datatrack, 41 | feat_type=args.feat_type, 42 | ) 43 | 44 | if args.method == 'ridge': 45 | model = Ridge() 46 | elif args.method == 'linear_svr': 47 | model = LinearSVR() 48 | elif args.method == 'kernel_svr': 49 | model = KernelSVR() 50 | elif args.method == 'rf': 51 | raise NotImplementedError() 52 | # model = RandomForest() 53 | elif args.method == 'lightgbm': 54 | model = LightGBM() 55 | elif args.method == 'svgp': 56 | raise NotImplementedError() 57 | else: 58 | raise RuntimeError('Not supported method: "{}"'.format(args.method)) 59 | 60 | best_params = model.optimize_hp(data['X'], data['y']) 61 | 62 | logger.info(best_params) 63 | 64 | out_dir = Path('../out/ensemble-multidomain/opt_hp_stage3') / args.datatrack / \ 65 | f'{args.method}-{args.feat_type}' 66 | os.makedirs(out_dir, exist_ok=True) 67 | 68 | with open(out_dir / 'params.json', encoding="utf-8", mode="w") as f: 69 | json.dump(best_params, f, ensure_ascii=False, indent=2) 70 | 71 | 72 | if __name__ == '__main__': 73 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{') 74 | main() 75 | 76 | 77 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/pred_stage1.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import logging 5 | from logging import getLogger 6 | import random 7 | import json 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from data_util import load_stage1_data, load_stage1_train_all_data, load_stage1_test_data 13 | 14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest 15 | from gp_models import SVGP, ExactGP 16 | 17 | logger = getLogger(__name__) 18 | 19 | RAND_SEED = 0 20 | 21 | def get_arg(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('method') 25 | parser.add_argument('train_datatrack') 26 | parser.add_argument('ssl_type') 27 | parser.add_argument('i_cv', type=int) 28 | parser.add_argument('pred_datatrack') 29 | return parser.parse_args() 30 | 31 | 32 | def main(): 33 | args = get_arg() 34 | 35 | random.seed(RAND_SEED) 36 | np.random.seed(RAND_SEED) 37 | torch.manual_seed(RAND_SEED) 38 | 39 | if args.method == 'svgp': 40 | model = SVGP() 41 | elif args.method == 'exactgp': 42 | model = ExactGP() 43 | elif args.method == 'rf': 44 | model = RandomForest() 45 | else: 46 | if args.method == 'ridge': 47 | model = Ridge() 48 | elif args.method == 'linear_svr': 49 | model = LinearSVR() 50 | elif args.method == 'kernel_svr': 51 | model = KernelSVR() 52 | elif args.method == 'lightgbm': 53 | model = LightGBM() 54 | else: 55 | raise RuntimeError('Not supported method: "{}"'.format(args.method)) 56 | 57 | model_dir = Path('../out/ensemble-multidomain/stage1') / args.train_datatrack / \ 58 | f'{args.method}-{args.ssl_type}' / str(args.i_cv) 59 | out_dir = model_dir / f'pred-{args.pred_datatrack}' 60 | os.makedirs(out_dir, exist_ok=True) 61 | 62 | logger.info('Outdir: {}'.format(out_dir)) 63 | 64 | if args.method == 'svgp': 65 | # train_data = load_stage1_data( 66 | # datatrack=args.train_datatrack, 67 | # ssl_type=args.ssl_type, 68 | # i_cv=args.i_cv, 69 | # ) 70 | # model.load_model(model_dir, train_data['train']['X']) 71 | model.load_model(model_dir) 72 | elif args.method == 'exactgp': 73 | train_data = load_stage1_data( 74 | datatrack=args.train_datatrack, 75 | ssl_type=args.ssl_type, 76 | i_cv=args.i_cv, 77 | ) 78 | model.load_model(model_dir, train_data['train']['X'], train_data['train']['y']) 79 | else: 80 | model.load_model(model_dir) 81 | 82 | pred_data = {} 83 | pred_data['train'] = load_stage1_train_all_data( 84 | datatrack=args.pred_datatrack, 85 | ssl_type=args.ssl_type, 86 | ) 87 | pred_data['test'] = load_stage1_test_data( 88 | datatrack=args.pred_datatrack, 89 | ssl_type=args.ssl_type, 90 | ) 91 | 92 | df_train = model.predict(pred_data['train']['X'], pred_data['train']['df']) 93 | df_test = model.predict(pred_data['test']['X'], pred_data['test']['df']) 94 | 95 | df_train.to_csv(out_dir / 'train.csv') 96 | df_test.to_csv(out_dir / 'test.csv') 97 | 98 | 99 | 100 | if __name__ == '__main__': 101 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{') 102 | main() 103 | 104 | 105 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/pred_stage1_ood.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large" 5 | #ssl_types="w2v_small w2v_large wavlm_base wavlm_large hubert_base hubert_large" 6 | 7 | # for ood 8 | for train_datatrack in external-wo_test phase1-main; do 9 | for pred_datatrack in phase1-ood ; do 10 | for ssl_type in $ssl_types; do 11 | for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do 12 | for i_cv in 0 1 2 3 4; do 13 | echo "${method}, ${train_datatrack}, ${ssl_type}, ${i_cv}, ${pred_datatrack}" 14 | python -u pred_stage1.py ${method} ${train_datatrack} ${ssl_type} ${i_cv} ${pred_datatrack} 15 | done 16 | done 17 | done 18 | done 19 | done 20 | 21 | echo "done" 22 | 23 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/pred_testphase_stage1.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import logging 5 | from logging import getLogger 6 | import random 7 | import json 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from data_util import load_stage1_data, load_stage1_train_all_data, load_stage1_test_data 13 | 14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest 15 | from gp_models import SVGP, ExactGP 16 | 17 | logger = getLogger(__name__) 18 | 19 | RAND_SEED = 0 20 | 21 | def get_arg(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('method') 25 | parser.add_argument('train_datatrack') 26 | parser.add_argument('ssl_type') 27 | parser.add_argument('i_cv', type=int) 28 | parser.add_argument('pred_datatrack') 29 | return parser.parse_args() 30 | 31 | 32 | def main(): 33 | args = get_arg() 34 | 35 | random.seed(RAND_SEED) 36 | np.random.seed(RAND_SEED) 37 | torch.manual_seed(RAND_SEED) 38 | 39 | if args.method == 'svgp': 40 | model = SVGP() 41 | elif args.method == 'exactgp': 42 | model = ExactGP() 43 | elif args.method == 'rf': 44 | model = RandomForest() 45 | else: 46 | if args.method == 'ridge': 47 | model = Ridge() 48 | elif args.method == 'linear_svr': 49 | model = LinearSVR() 50 | elif args.method == 'kernel_svr': 51 | model = KernelSVR() 52 | elif args.method == 'lightgbm': 53 | model = LightGBM() 54 | else: 55 | raise RuntimeError('Not supported method: "{}"'.format(args.method)) 56 | 57 | model_dir = Path('../out/ensemble-multidomain/stage1') / args.train_datatrack / \ 58 | f'{args.method}-{args.ssl_type}' / str(args.i_cv) 59 | out_dir = model_dir / f'pred-{args.pred_datatrack}' 60 | os.makedirs(out_dir, exist_ok=True) 61 | 62 | logger.info('Outdir: {}'.format(out_dir)) 63 | 64 | if args.method == 'svgp': 65 | # train_data = load_stage1_data( 66 | # datatrack=args.train_datatrack, 67 | # ssl_type=args.ssl_type, 68 | # i_cv=args.i_cv, 69 | # ) 70 | # model.load_model(model_dir, train_data['train']['X']) 71 | model.load_model(model_dir) 72 | elif args.method == 'exactgp': 73 | train_data = load_stage1_data( 74 | datatrack=args.train_datatrack, 75 | ssl_type=args.ssl_type, 76 | i_cv=args.i_cv, 77 | ) 78 | model.load_model(model_dir, train_data['train']['X'], train_data['train']['y']) 79 | else: 80 | model.load_model(model_dir) 81 | 82 | pred_data = {} 83 | pred_data['test'] = load_stage1_test_data( 84 | datatrack=args.pred_datatrack, 85 | ssl_type=args.ssl_type, 86 | ) 87 | 88 | df_test = model.predict(pred_data['test']['X'], pred_data['test']['df']) 89 | 90 | df_test.to_csv(out_dir / 'test.csv') 91 | 92 | 93 | 94 | if __name__ == '__main__': 95 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{') 96 | main() 97 | 98 | 99 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/pred_testphase_stage1_main.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large" 5 | # ssl_types="w2v_small w2v_large wavlm_base wavlm_large hubert_base hubert_large" 6 | 7 | for train_datatrack in phase1-main; do 8 | for pred_datatrack in testphase-main; do 9 | for ssl_type in ${ssl_types}; do 10 | for method in ridge linear_svr kernel_svr rf lightgbm exactgp ; do 11 | for i_cv in 0 1 2 3 4; do 12 | echo "${method}, ${train_datatrack}, ${ssl_type}, ${i_cv}, ${pred_datatrack}" 13 | python -u pred_testphase_stage1.py ${method} ${train_datatrack} ${ssl_type} ${i_cv} ${pred_datatrack} 14 | done 15 | done 16 | done 17 | done 18 | done 19 | 20 | echo "done" 21 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/pred_testphase_stage1_ood.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large" 5 | 6 | for train_datatrack in phase1-ood external-wo_test phase1-main; do 7 | for pred_datatrack in testphase-ood ; do 8 | for ssl_type in ${ssl_types}; do 9 | for method in ridge linear_svr kernel_svr rf lightgbm exactgp ; do 10 | for i_cv in 0 1 2 3 4; do 11 | echo "${method}, ${train_datatrack}, ${ssl_type}, ${i_cv}, ${pred_datatrack}" 12 | python -u pred_testphase_stage1.py ${method} ${train_datatrack} ${ssl_type} ${i_cv} ${pred_datatrack} 13 | done 14 | done 15 | done 16 | done 17 | done 18 | 19 | echo "done" 20 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/pred_testphase_stage2-3_main.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | train_datatrack=phase1-main 5 | pred_datatrack=testphase-main 6 | feat_type=main-strong1-weak48 7 | 8 | python -u collect_stage1_testphase_result.py ${pred_datatrack} ${feat_type} 9 | 10 | for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do 11 | for i_cv in 0 1 2 3 4; do 12 | echo "${method}, ${train_datatrack}, ${feat_type}, ${i_cv}, ${pred_datatrack}" 13 | python -u pred_testphase_stage2.py ${method} ${train_datatrack} ${feat_type} ${i_cv} ${pred_datatrack} 14 | done 15 | done 16 | 17 | echo "Collect stage2 data." 18 | python -u collect_stage2_testphase_result.py ${pred_datatrack} ${feat_type} 19 | 20 | for method in ridge; do 21 | for i_cv in 0 1 2 3 4; do 22 | echo "Run stage3: ${method}, ${train_datatrack}, ${feat_type}, ${i_cv}, ${pred_datatrack}" 23 | python -u pred_testphase_stage3.py ${method} ${train_datatrack} ${feat_type} ${i_cv} ${pred_datatrack} 24 | done 25 | done 26 | 27 | echo "Calculate result." 28 | 29 | python -u calc_testphase_result.py ${pred_datatrack} ${feat_type} 30 | 31 | echo "done" 32 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/pred_testphase_stage2-3_ood.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | train_datatrack=phase1-ood 5 | pred_datatrack=testphase-ood 6 | feat_type=ood-strong1-weak144 7 | 8 | python -u collect_stage1_testphase_result.py ${pred_datatrack} ${feat_type} 9 | 10 | for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do 11 | for i_cv in 0 1 2 3 4; do 12 | echo "${method}, ${train_datatrack}, ${feat_type}, ${i_cv}, ${pred_datatrack}" 13 | python -u pred_testphase_stage2.py ${method} ${train_datatrack} ${feat_type} ${i_cv} ${pred_datatrack} 14 | done 15 | done 16 | 17 | echo "Collect stage2 data." 18 | python -u collect_stage2_testphase_result.py ${pred_datatrack} ${feat_type} 19 | 20 | for method in ridge; do 21 | for i_cv in 0 1 2 3 4; do 22 | echo "Run stage3: ${method}, ${train_datatrack}, ${feat_type}, ${i_cv}, ${pred_datatrack}" 23 | python -u pred_testphase_stage3.py ${method} ${train_datatrack} ${feat_type} ${i_cv} ${pred_datatrack} 24 | done 25 | done 26 | 27 | echo "Calculate result." 28 | 29 | python -u calc_testphase_result.py ${pred_datatrack} ${feat_type} 30 | 31 | echo "done" 32 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/pred_testphase_stage2.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import logging 5 | from logging import getLogger 6 | import random 7 | import json 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from data_util import load_stage2_data, load_stage2_test_data 13 | 14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest 15 | from gp_models import SVGP, ExactGP 16 | 17 | logger = getLogger(__name__) 18 | 19 | RAND_SEED = 0 20 | 21 | def get_arg(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('method') 25 | parser.add_argument('train_datatrack') 26 | parser.add_argument('feat_type') 27 | parser.add_argument('i_cv', type=int) 28 | parser.add_argument('pred_datatrack') 29 | return parser.parse_args() 30 | 31 | 32 | def main(): 33 | args = get_arg() 34 | 35 | random.seed(RAND_SEED) 36 | np.random.seed(RAND_SEED) 37 | torch.manual_seed(RAND_SEED) 38 | 39 | if args.method == 'svgp': 40 | model = SVGP() 41 | elif args.method == 'exactgp': 42 | model = ExactGP() 43 | elif args.method == 'rf': 44 | model = RandomForest() 45 | else: 46 | if args.method == 'ridge': 47 | model = Ridge() 48 | elif args.method == 'linear_svr': 49 | model = LinearSVR() 50 | elif args.method == 'kernel_svr': 51 | model = KernelSVR() 52 | elif args.method == 'lightgbm': 53 | model = LightGBM() 54 | else: 55 | raise RuntimeError('Not supported method: "{}"'.format(args.method)) 56 | 57 | model_dir = Path('../out/ensemble-multidomain/stage2') / args.train_datatrack / \ 58 | f'{args.method}-{args.feat_type}' / str(args.i_cv) 59 | out_dir = model_dir / f'pred-{args.pred_datatrack}' 60 | os.makedirs(out_dir, exist_ok=True) 61 | 62 | logger.info('Outdir: {}'.format(out_dir)) 63 | 64 | if args.method == 'svgp': 65 | # train_data = load_stage1_data( 66 | # datatrack=args.train_datatrack, 67 | # feat_type=args.feat_type, 68 | # i_cv=args.i_cv, 69 | # ) 70 | # model.load_model(model_dir, train_data['train']['X']) 71 | model.load_model(model_dir) 72 | elif args.method == 'exactgp': 73 | train_data = load_stage2_data( 74 | datatrack=args.train_datatrack, 75 | feat_type=args.feat_type, 76 | i_cv=args.i_cv, 77 | ) 78 | model.load_model(model_dir, train_data['train']['X'], train_data['train']['y']) 79 | else: 80 | model.load_model(model_dir) 81 | 82 | pred_data = {} 83 | pred_data['test'] = load_stage2_test_data( 84 | datatrack=args.pred_datatrack, 85 | feat_type=args.feat_type, 86 | ) 87 | 88 | df_test = model.predict(pred_data['test']['X'], pred_data['test']['df']) 89 | 90 | df_test.to_csv(out_dir / 'test.csv') 91 | 92 | 93 | 94 | if __name__ == '__main__': 95 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{') 96 | main() 97 | 98 | 99 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/pred_testphase_stage3.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import logging 5 | from logging import getLogger 6 | import random 7 | import json 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from data_util import load_stage3_data, load_stage3_test_data 13 | 14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest 15 | from gp_models import SVGP, ExactGP 16 | 17 | logger = getLogger(__name__) 18 | 19 | RAND_SEED = 0 20 | 21 | def get_arg(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('method') 25 | parser.add_argument('train_datatrack') 26 | parser.add_argument('feat_type') 27 | parser.add_argument('i_cv', type=int) 28 | parser.add_argument('pred_datatrack') 29 | return parser.parse_args() 30 | 31 | 32 | def main(): 33 | args = get_arg() 34 | 35 | random.seed(RAND_SEED) 36 | np.random.seed(RAND_SEED) 37 | torch.manual_seed(RAND_SEED) 38 | 39 | if args.method == 'svgp': 40 | model = SVGP() 41 | elif args.method == 'exactgp': 42 | model = ExactGP() 43 | elif args.method == 'rf': 44 | model = RandomForest() 45 | else: 46 | if args.method == 'ridge': 47 | model = Ridge() 48 | elif args.method == 'linear_svr': 49 | model = LinearSVR() 50 | elif args.method == 'kernel_svr': 51 | model = KernelSVR() 52 | elif args.method == 'lightgbm': 53 | model = LightGBM() 54 | else: 55 | raise RuntimeError('Not supported method: "{}"'.format(args.method)) 56 | 57 | model_dir = Path('../out/ensemble-multidomain/stage3') / args.train_datatrack / \ 58 | f'{args.method}-{args.feat_type}' / str(args.i_cv) 59 | out_dir = model_dir / f'pred-{args.pred_datatrack}' 60 | os.makedirs(out_dir, exist_ok=True) 61 | 62 | logger.info('Outdir: {}'.format(out_dir)) 63 | 64 | if args.method == 'svgp': 65 | # train_data = load_stage1_data( 66 | # datatrack=args.train_datatrack, 67 | # feat_type=args.feat_type, 68 | # i_cv=args.i_cv, 69 | # ) 70 | # model.load_model(model_dir, train_data['train']['X']) 71 | model.load_model(model_dir) 72 | elif args.method == 'exactgp': 73 | train_data = load_stage3_data( 74 | datatrack=args.train_datatrack, 75 | feat_type=args.feat_type, 76 | i_cv=args.i_cv, 77 | ) 78 | model.load_model(model_dir, train_data['train']['X'], train_data['train']['y']) 79 | else: 80 | model.load_model(model_dir) 81 | 82 | pred_data = {} 83 | pred_data['test'] = load_stage3_test_data( 84 | datatrack=args.pred_datatrack, 85 | feat_type=args.feat_type, 86 | ) 87 | 88 | df_test = model.predict(pred_data['test']['X'], pred_data['test']['df']) 89 | 90 | df_test.to_csv(out_dir / 'test.csv') 91 | 92 | 93 | 94 | if __name__ == '__main__': 95 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{') 96 | main() 97 | 98 | 99 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/run_stage1.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import logging 5 | from logging import getLogger 6 | import random 7 | import json 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from data_util import load_stage1_data 13 | 14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest 15 | from gp_models import SVGP, ExactGP 16 | 17 | logger = getLogger(__name__) 18 | 19 | RAND_SEED = 0 20 | 21 | def get_arg(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('method') 25 | parser.add_argument('datatrack') 26 | parser.add_argument('ssl_type') 27 | parser.add_argument('i_cv', type=int) 28 | parser.add_argument('--use_opt', action='store_true', default=False) 29 | return parser.parse_args() 30 | 31 | 32 | 33 | 34 | def main(): 35 | args = get_arg() 36 | 37 | random.seed(RAND_SEED) 38 | np.random.seed(RAND_SEED) 39 | torch.manual_seed(RAND_SEED) 40 | 41 | data = load_stage1_data( 42 | datatrack=args.datatrack, 43 | ssl_type=args.ssl_type, 44 | i_cv=args.i_cv, 45 | ) 46 | 47 | if args.method == 'svgp': 48 | model = SVGP() 49 | elif args.method == 'exactgp': 50 | model = ExactGP() 51 | elif args.method == 'rf': 52 | model = RandomForest() 53 | else: 54 | if args.use_opt: 55 | param_file = Path('../out/ensemble-multidomain/opt_hp_stage1') / args.datatrack / \ 56 | f'{args.method}-{args.ssl_type}' / 'params.json' 57 | params = json.load(open(param_file, 'rb')) 58 | logger.info('Params: {}'.format(params)) 59 | else: 60 | params = {} 61 | 62 | if args.method == 'ridge': 63 | model = Ridge(params=params) 64 | elif args.method == 'linear_svr': 65 | model = LinearSVR(params=params) 66 | elif args.method == 'kernel_svr': 67 | model = KernelSVR(params=params) 68 | elif args.method == 'lightgbm': 69 | model = LightGBM(params=params) 70 | else: 71 | raise RuntimeError('Not supported method: "{}"'.format(args.method)) 72 | 73 | model.train(data['train']['X'], data['train']['y'], 74 | data['val']['X'], data['val']['y']) 75 | 76 | df_val = model.predict(data['val']['X'], data['val']['df']) 77 | df_test = model.predict(data['test']['X'], data['test']['df']) 78 | 79 | out_dir = Path('../out/ensemble-multidomain/stage1') / args.datatrack / \ 80 | f'{args.method}-{args.ssl_type}' / str(args.i_cv) 81 | os.makedirs(out_dir, exist_ok=True) 82 | 83 | df_val.to_csv(out_dir / 'val.csv') 84 | df_test.to_csv(out_dir / 'test.csv') 85 | 86 | model.save_model(out_dir) 87 | 88 | 89 | 90 | 91 | if __name__ == '__main__': 92 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{') 93 | main() 94 | 95 | 96 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/run_stage1.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large" 5 | 6 | for datatrack in phase1-ood phase1-main external-wo_test; do 7 | for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do 8 | for ssl_type in ${ssl_types}; do 9 | for i_cv in 0 1 2 3 4; do 10 | echo "${datatrack}, ${method}, ${ssl_type}, ${i_cv}" 11 | python -u run_stage1.py ${method} ${datatrack} ${ssl_type} ${i_cv} 12 | done 13 | done 14 | done 15 | done 16 | 17 | echo "done" 18 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/run_stage2-3_main.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | datatrack=phase1-main 5 | feat_type=main-strong1-weak48 6 | 7 | python -u collect_stage1_result.py ${datatrack} ${feat_type} 8 | 9 | for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do 10 | for i_cv in 0 1 2 3 4; do 11 | echo "Run stage2: ${method}, ${datatrack}, ${feat_type}, ${i_cv}" 12 | python -u run_stage2.py ${method} ${datatrack} ${feat_type} ${i_cv} 13 | done 14 | done 15 | 16 | echo "Collect stage2 data." 17 | python -u collect_stage2_result.py ${datatrack} ${feat_type} 18 | 19 | for method in ridge; do 20 | for i_cv in 0 1 2 3 4; do 21 | echo "Run stage3: ${method} ${datatrack}, ${feat_type}, ${i_cv}" 22 | python -u run_stage3.py ${method} ${datatrack} ${feat_type} ${i_cv} 23 | done 24 | done 25 | 26 | echo "Calculate result." 27 | 28 | python -u calc_result.py ${datatrack} ${feat_type} 29 | 30 | echo "done" 31 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/run_stage2-3_ood.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | datatrack=phase1-ood 5 | feat_type=ood-strong1-weak144 6 | 7 | python -u collect_stage1_result.py ${datatrack} ${feat_type} 8 | 9 | for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do 10 | for i_cv in 0 1 2 3 4; do 11 | echo "Run stage2: ${method}, ${datatrack}, ${feat_type}, ${i_cv}" 12 | python -u run_stage2.py ${method} ${datatrack} ${feat_type} ${i_cv} 13 | done 14 | done 15 | 16 | echo "Collect stage2 data." 17 | python -u collect_stage2_result.py ${datatrack} ${feat_type} 18 | 19 | for method in ridge; do 20 | for i_cv in 0 1 2 3 4; do 21 | echo "Run stage3: ${method} ${datatrack}, ${feat_type}, ${i_cv}" 22 | python -u run_stage3.py ${method} ${datatrack} ${feat_type} ${i_cv} 23 | done 24 | done 25 | 26 | echo "Calculate result." 27 | 28 | python -u calc_result.py ${datatrack} ${feat_type} 29 | 30 | echo "done" 31 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/run_stage2.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import logging 5 | from logging import getLogger 6 | import random 7 | import json 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from data_util import load_stage2_data 13 | 14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest 15 | from gp_models import SVGP, ExactGP 16 | 17 | logger = getLogger(__name__) 18 | 19 | RAND_SEED = 0 20 | 21 | def get_arg(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('method') 25 | parser.add_argument('datatrack') 26 | parser.add_argument('feat_type') 27 | parser.add_argument('i_cv', type=int) 28 | parser.add_argument('--use_opt', action='store_true', default=False) 29 | return parser.parse_args() 30 | 31 | 32 | 33 | 34 | def main(): 35 | args = get_arg() 36 | 37 | random.seed(RAND_SEED) 38 | np.random.seed(RAND_SEED) 39 | torch.manual_seed(RAND_SEED) 40 | 41 | data = load_stage2_data( 42 | datatrack=args.datatrack, 43 | feat_type=args.feat_type, 44 | i_cv=args.i_cv, 45 | ) 46 | 47 | method = args.method 48 | 49 | if method == 'autogp': 50 | if args.datatrack == 'phase1-main': 51 | method = 'svgp' 52 | else: 53 | method = 'exactgp' 54 | 55 | if method == 'svgp': 56 | model = SVGP(stage='stage2') 57 | elif method == 'exactgp': 58 | model = ExactGP(stage='stage2') 59 | elif method == 'rf': 60 | model = RandomForest() 61 | else: 62 | if args.use_opt: 63 | param_file = Path('../out/ensemble-multidomain/opt_hp_stage2') / args.datatrack / \ 64 | f'{method}-{args.feat_type}' / 'params.json' 65 | params = json.load(open(param_file, 'rb')) 66 | logger.info('Params: {}'.format(params)) 67 | else: 68 | params = {} 69 | 70 | if method == 'ridge': 71 | model = Ridge(params=params) 72 | elif method == 'linear_svr': 73 | model = LinearSVR(params=params) 74 | elif method == 'kernel_svr': 75 | model = KernelSVR(params=params) 76 | elif method == 'lightgbm': 77 | model = LightGBM(params=params) 78 | else: 79 | raise RuntimeError('Not supported method: "{}"'.format(method)) 80 | 81 | model.train(data['train']['X'], data['train']['y'], 82 | data['val']['X'], data['val']['y']) 83 | 84 | df_val = model.predict(data['val']['X'], data['val']['df']) 85 | df_test = model.predict(data['test']['X'], data['test']['df']) 86 | 87 | out_dir = Path('../out/ensemble-multidomain/stage2') / args.datatrack / \ 88 | f'{method}-{args.feat_type}' / str(args.i_cv) 89 | os.makedirs(out_dir, exist_ok=True) 90 | 91 | df_val.to_csv(out_dir / 'val.csv') 92 | df_test.to_csv(out_dir / 'test.csv') 93 | 94 | model.save_model(out_dir) 95 | 96 | 97 | 98 | 99 | if __name__ == '__main__': 100 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{') 101 | main() 102 | 103 | 104 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/run_stage3.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import logging 5 | from logging import getLogger 6 | import random 7 | import json 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from data_util import load_stage3_data 13 | 14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest 15 | from gp_models import SVGP, ExactGP 16 | 17 | logger = getLogger(__name__) 18 | 19 | RAND_SEED = 0 20 | 21 | def get_arg(): 22 | import argparse 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('method') 25 | parser.add_argument('datatrack') 26 | parser.add_argument('feat_type') 27 | parser.add_argument('i_cv', type=int) 28 | parser.add_argument('--use_opt', action='store_true', default=False) 29 | return parser.parse_args() 30 | 31 | 32 | 33 | 34 | def main(): 35 | args = get_arg() 36 | 37 | random.seed(RAND_SEED) 38 | np.random.seed(RAND_SEED) 39 | torch.manual_seed(RAND_SEED) 40 | 41 | data = load_stage3_data( 42 | datatrack=args.datatrack, 43 | feat_type=args.feat_type, 44 | i_cv=args.i_cv, 45 | ) 46 | 47 | if args.method == 'svgp': 48 | model = SVGP(stage='stage2') 49 | elif args.method == 'exactgp': 50 | model = ExactGP(stage='stage2') 51 | elif args.method == 'rf': 52 | model = RandomForest() 53 | else: 54 | if args.use_opt: 55 | param_file = Path('../out/ensemble-multidomain/opt_hp_stage3') / args.datatrack / \ 56 | f'{args.method}-{args.feat_type}' / 'params.json' 57 | params = json.load(open(param_file, 'rb')) 58 | logger.info('Params: {}'.format(params)) 59 | else: 60 | params = {} 61 | 62 | if args.method == 'ridge': 63 | model = Ridge(params=params) 64 | elif args.method == 'linear_svr': 65 | model = LinearSVR(params=params) 66 | elif args.method == 'kernel_svr': 67 | model = KernelSVR(params=params) 68 | elif args.method == 'lightgbm': 69 | model = LightGBM(params=params) 70 | else: 71 | raise RuntimeError('Not supported method: "{}"'.format(args.method)) 72 | 73 | model.train(data['train']['X'], data['train']['y'], 74 | data['val']['X'], data['val']['y']) 75 | 76 | df_val = model.predict(data['val']['X'], data['val']['df']) 77 | df_test = model.predict(data['test']['X'], data['test']['df']) 78 | 79 | out_dir = Path('../out/ensemble-multidomain/stage3') / args.datatrack / \ 80 | f'{args.method}-{args.feat_type}' / str(args.i_cv) 81 | os.makedirs(out_dir, exist_ok=True) 82 | 83 | df_val.to_csv(out_dir / 'val.csv') 84 | df_test.to_csv(out_dir / 'test.csv') 85 | 86 | model.save_model(out_dir) 87 | 88 | 89 | 90 | 91 | if __name__ == '__main__': 92 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{') 93 | main() 94 | 95 | 96 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/stage2-method/main-strong1-weak48.yaml: -------------------------------------------------------------------------------- 1 | 2 | strong_learners: 3 | - main1 4 | weak_learners: 5 | datatracks: 6 | - phase1-main 7 | model_types: 8 | - ridge 9 | - linear_svr 10 | - kernel_svr 11 | - lightgbm 12 | - rf 13 | - exactgp 14 | ssl_types: 15 | - w2v_small 16 | - w2v_large 17 | - w2v_large2 18 | - w2v_xlsr 19 | - hubert_base 20 | - hubert_large 21 | - wavlm_base 22 | - wavlm_large 23 | 24 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/stage2-method/ood-strong1-weak144.yaml: -------------------------------------------------------------------------------- 1 | 2 | strong_learners: 3 | - ood1 4 | weak_learners: 5 | datatracks: 6 | - phase1-main 7 | - external-wo_test 8 | - phase1-ood 9 | model_types: 10 | - ridge 11 | - linear_svr 12 | - kernel_svr 13 | - lightgbm 14 | - rf 15 | - exactgp 16 | ssl_types: 17 | - w2v_small 18 | - w2v_large 19 | - w2v_large2 20 | - w2v_xlsr 21 | - hubert_base 22 | - hubert_large 23 | - wavlm_base 24 | - wavlm_large 25 | 26 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/unused/opt_stage1_all.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large" 5 | 6 | for method in ridge lightgbm linear_svr kernel_svr; do 7 | for datatrack in phase1-ood phase1-main students-wo_test ; do 8 | for ssl_type in ${ssl_types}; do 9 | echo "${datatrack}, ${ssl_type}, ${method}" 10 | poetry run python -u opt_stage1.py ${method} ${datatrack} ${ssl_type} 11 | done 12 | done 13 | done 14 | 15 | echo "done" 16 | 17 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/unused/pred_stage1_exactgp.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | for train_datatrack in phase1-ood students-wo_test phase1-main; do 5 | for pred_datatrack in phase1-ood phase1-main ; do 6 | for ssl_type in w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large; do 7 | for method in exactgp; do 8 | for i_cv in 0 1 2 3 4; do 9 | echo "${method}, ${train_datatrack}, ${ssl_type}, ${i_cv}, ${pred_datatrack}" 10 | poetry run python -u pred_stage1.py ${method} ${train_datatrack} ${ssl_type} ${i_cv} ${pred_datatrack} 11 | done 12 | done 13 | done 14 | done 15 | done 16 | 17 | echo "done" 18 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/unused/run_stage1_exactgp.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large" 5 | 6 | for datatrack in phase1-ood phase1-main students-wo_test; do 7 | for ssl_type in ${ssl_types}; do 8 | for i_cv in 0 1 2 3 4; do 9 | echo "${datatrack}, ${ssl_type}, ${i_cv}" 10 | poetry run python -u run_stage1.py exactgp ${datatrack} ${ssl_type} ${i_cv} 11 | done 12 | done 13 | done 14 | 15 | 16 | 17 | 18 | echo "done" 19 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/unused/run_stage1_other.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large" 5 | 6 | for datatrack in phase1-ood phase1-main students-wo_test ; do 7 | for method in ridge linear_svr kernel_svr rf lightgbm; do 8 | for ssl_type in ${ssl_types} ; do 9 | for i_cv in 0 1 2 3 4; do 10 | echo "${method}, ${datatrack}, ${ssl_type}, ${i_cv}" 11 | poetry run python -u run_stage1.py ${method} ${datatrack} ${ssl_type} ${i_cv} 12 | done 13 | done 14 | done 15 | done 16 | 17 | 18 | 19 | echo "done" 20 | -------------------------------------------------------------------------------- /stacking/ensemble_multidomain_scripts/unused/run_stage2-end_opt.sh: -------------------------------------------------------------------------------- 1 | 2 | set -eu 3 | 4 | datatrack=phase1-main 5 | feat_type=main-strong1-weak48-opt 6 | 7 | poetry run python -u collect_stage1_result.py ${datatrack} ${feat_type} 8 | 9 | for method in ridge linear_svr kernel_svr lightgbm; do 10 | echo "Optimize hyperparameter for stage2: ${datatrack}, ${feat_type}, ${method}" 11 | poetry run python -u opt_stage2.py ${method} ${datatrack} ${feat_type} 12 | done 13 | 14 | 15 | for method in exactgp ridge linear_svr kernel_svr rf lightgbm; do 16 | for i_cv in 0 1 2 3 4; do 17 | echo "Run stage2: ${method}, ${datatrack}, ${feat_type}, ${i_cv}" 18 | poetry run python -u run_stage2.py --use_opt ${method} ${datatrack} ${feat_type} ${i_cv} 19 | done 20 | done 21 | 22 | echo "Collect stage2 data." 23 | poetry run python -u collect_stage2_result.py ${datatrack} ${feat_type} 24 | 25 | 26 | for method in ridge; do 27 | echo "Optimize hyperparameter for stage3: ${datatrack}, ${feat_type}, ${method}" 28 | poetry run python -u opt_stage3.py ${method} ${datatrack} ${feat_type} 29 | done 30 | 31 | for method in ridge; do 32 | for i_cv in 0 1 2 3 4; do 33 | echo "Run stage3: ${method} ${datatrack}, ${feat_type}, ${i_cv}" 34 | poetry run python -u run_stage3.py --use_opt ${method} ${datatrack} ${feat_type} ${i_cv} 35 | done 36 | done 37 | 38 | 39 | echo "Calculate result." 40 | 41 | poetry run python -u calc_result.py ${datatrack} ${feat_type} 42 | 43 | echo "done" 44 | -------------------------------------------------------------------------------- /stacking/strong_learner_result/ood1/answer-ood.csvfold_0: -------------------------------------------------------------------------------- 1 | sys3ac0c-utt161761a,3.00628410698846 2 | sys1b29a-uttc13c8ac,3.2306305915117264 3 | sys6d946-utt3fa137b,4.020784974098206 4 | sysc5b22-utte3cfb18,4.155950903892517 5 | sys50620-utt9080d4c,3.1653470546007156 6 | sysaefdf-utt8625c9a,4.133486866950989 7 | sys2dea8-utt0fb1b5f,3.668229043483734 8 | sys6aa80-uttb961ed0,3.049679197371006 9 | sys6aa80-utt11c30b8,3.0484387166798115 10 | sys2dea8-uttaca68d6,3.56439208984375 11 | sys86a3b-uttd7ee46d,3.4484071135520935 12 | sys86a3b-uttd408860,3.4142578840255737 13 | sys8569c-utt339d107,4.2832270860672 14 | sys3ac0c-utt3727514,2.874504864215851 15 | sys6d946-utt29b5124,3.8212228417396545 16 | sysc5b22-uttf413f6c,4.127067804336548 17 | sys9900c-utt380e04a,2.799784615635872 18 | sysc138b-utt4aa722c,3.3437674045562744 19 | sysc0247-uttcf421d3,4.266936302185059 20 | sysaefdf-utt4627f92,3.9700832962989807 21 | sys6d946-utt74b6d95,4.13299822807312 22 | sys288fc-utt8aac2a0,1.708821415901184 23 | sysc0232-utt984d929,4.142822504043579 24 | sys3ac0c-uttc822565,2.9671433679759502 25 | sysc0247-utt324aad1,4.186380982398987 26 | sys6aa80-utt944f00a,3.2828843891620636 27 | sys1b29a-utt31813b4,3.1065728440880775 28 | sys288fc-utt3ada6c5,1.6936789751052856 29 | sysc5b22-utt2ef4cc2,4.162316918373108 30 | sys25583-utt6330ecc,3.866409659385681 31 | sys0c3c7-uttc349635,1.7121723890304565 32 | sys6aa80-uttbee99e7,3.272907167673111 33 | sysc138b-utt3e6bb57,3.114454969763756 34 | sysc8b6d-utt2174cf3,3.2371485382318497 35 | sys3ac0c-utt9d9e772,3.394138514995575 36 | sysc0247-utt7bf8152,4.188221216201782 37 | sysc5b22-uttd0d4f7d,3.80684095621109 38 | sysc138b-utte0f2b88,3.3935784697532654 39 | sys50620-uttdd8c402,3.1611065566539764 40 | sys9900c-utt8a28b36,3.014586901292205 41 | syseb559-utt71685f4,2.927502065896988 42 | sys25583-uttdd97db1,3.775578737258911 43 | sysc93c1-utt9d574ec,1.705985426902771 44 | sysc8b6d-uttc01a331,3.4588173627853394 45 | sys8569c-utt4bd6700,4.181773781776428 46 | sys6d946-uttdc8ae80,3.7061699628829956 47 | -------------------------------------------------------------------------------- /stacking/strong_learner_result/ood1/answer-ood.csvfold_1: -------------------------------------------------------------------------------- 1 | sys288fc-utt2a428da,1.590195894241333 2 | sys8569c-uttda01d83,4.474315762519836 3 | sys3bec5-uttf9e9b65,4.395640730857849 4 | sysc5b22-utt507ff1e,4.440494537353516 5 | sys6d946-uttfc65d31,3.9793806672096252 6 | sys0c3c7-uttcde36fc,1.5656745433807373 7 | sysc138b-uttfaae8c5,3.9728047251701355 8 | sysaefdf-utta22cc47,4.128987550735474 9 | sys1b29a-utt238cb31,3.4775221943855286 10 | sysaefdf-uttac28750,4.094674468040466 11 | sys86a3b-utt5f51da3,3.7170485854148865 12 | sysc5b22-utt84b31de,4.41584038734436 13 | sysc0232-utt625ccdc,4.285597205162048 14 | sys86a3b-utt86c4d97,3.384651690721512 15 | sysc93c1-utt4a9ee10,1.7306538820266724 16 | sysc5b22-utt05b5a77,4.438835382461548 17 | syseb559-uttcd7fb9e,2.8986603021621704 18 | sysc0247-utt9720993,4.307410955429077 19 | sysc8b6d-utt3118c52,3.9715484380722046 20 | sysc8b6d-utt9bf948a,3.35931795835495 21 | sys25583-utt6518d56,4.095972418785095 22 | sysc8b6d-utt9ccaa2f,4.0339906215667725 23 | sysc0247-utt021f653,4.136787533760071 24 | sys50620-utt712ce0b,2.371592879295349 25 | sysc0247-utt457d98b,4.3827223777771 26 | sys0c3c7-uttb54c043,1.7769298553466797 27 | sysc0232-utt74896c3,4.17369544506073 28 | sysc93c1-utt4a606f1,1.8465670347213745 29 | sys50620-utt8f042c0,2.980513686314225 30 | sys1b29a-utt014d122,3.5027151703834534 31 | sys288fc-utt81bc852,1.5887526273727417 32 | sys3bec5-utte0c7cdf,4.340376496315002 33 | sys1b29a-utt1770396,3.6666112542152405 34 | sys6d946-utt3f304d4,4.209040522575378 35 | syseb559-utt4dd550d,2.7324198484420776 36 | sysc138b-uttcad56b6,3.19346222281456 37 | sys0c3c7-utt330976a,1.609305739402771 38 | sys3bec5-utt3301016,4.412176251411438 39 | sys3ac0c-utt5e1312b,3.5013919472694397 40 | sys6d946-utt6c1c6e1,4.254178404808044 41 | sys9900c-utt1f822b9,2.952831447124481 42 | sys3bec5-utt5ed2652,4.356784820556641 43 | sys2dea8-utt30d9387,3.3734480142593384 44 | sys86a3b-uttecd81d0,3.312489926815033 45 | sys3ac0c-uttafe7bb1,2.9006077349185944 46 | -------------------------------------------------------------------------------- /stacking/strong_learner_result/ood1/answer-ood.csvfold_2: -------------------------------------------------------------------------------- 1 | sysc0247-utt7596bba,3.916542708873749 2 | sys6d946-utt85e8181,3.309552401304245 3 | sysc5b22-utta450ecd,4.0385085344314575 4 | sysc8b6d-uttefe40e7,2.875736527144909 5 | sys3ac0c-uttc3491bb,3.1840667724609375 6 | sysc8b6d-uttbe29207,3.261176884174347 7 | sys288fc-utt116506b,1.8067857027053833 8 | syseb559-uttaa26ddd,3.0563898496329784 9 | sys288fc-uttcd33316,1.7399002313613892 10 | sys3ac0c-utt8152810,3.080054387450218 11 | sys6aa80-utta989ec2,3.151083752512932 12 | syseb559-uttfb60593,3.3584741950035095 13 | sysc0247-uttfad83f9,3.9709458351135254 14 | sysaefdf-utt1626164,3.972193658351898 15 | sys8569c-uttdf23051,4.064173221588135 16 | syseb559-uttfee71ad,3.322394996881485 17 | sys9900c-utt8e27993,3.5867892503738403 18 | sys86a3b-utt9594edb,3.778729021549225 19 | sys9900c-uttbabd33b,3.0449920147657394 20 | sys288fc-uttd2e0fbb,1.7499089241027832 21 | syseb559-utt80e03e6,3.3057122230529785 22 | sys86a3b-uttd537f38,3.1441497206687927 23 | sys9900c-utt7692d21,2.997207031119615 24 | sys2dea8-uttfee0d1c,3.522417902946472 25 | sys3bec5-uttb78d8af,4.011132478713989 26 | sys9900c-uttee29c4f,3.3247605562210083 27 | sysc93c1-utt6ffd131,1.8141653537750244 28 | sys6aa80-utt1458fca,3.6074836254119873 29 | sysc93c1-utte26df86,2.301818013191223 30 | sys3bec5-uttf0150e9,4.055609583854675 31 | sysc138b-utt5c64cce,3.0564831867814064 32 | sys86a3b-utt34ebcb5,3.5041947960853577 33 | sys6d946-utt97ad31f,3.905802607536316 34 | sys25583-utt85a7380,3.869354546070099 35 | sysc93c1-utt76f14cc,1.7454818487167358 36 | sysc0232-utt3f80869,4.01626980304718 37 | sys86a3b-uttd076e83,3.429683208465576 38 | sysaefdf-uttf1939ce,3.731216847896576 39 | sys3bec5-utt602db43,4.103129863739014 40 | sys288fc-uttf1d17fa,1.746074914932251 41 | sys25583-utt1c24cff,4.08445405960083 42 | sysc138b-utt67b77b4,3.580138087272644 43 | sys86a3b-utt81bb889,3.896349310874939 44 | sysc8b6d-utt5feb6c6,3.8013290762901306 45 | sys6aa80-utt7c44bb4,3.3627112805843353 46 | -------------------------------------------------------------------------------- /stacking/strong_learner_result/ood1/answer-ood.csvval_0: -------------------------------------------------------------------------------- 1 | sys0c3c7-utt29396c4,1.8738696575164795 2 | sys0c3c7-utt32149b2,1.7467237710952759 3 | sys0c3c7-utt3b43620,1.7445292472839355 4 | sys0c3c7-utte2ff6b2,1.7807897329330444 5 | sys1b29a-uttb5c2dfd,2.840162009000778 6 | sys1b29a-uttbe1f861,3.097307428717613 7 | sys1b29a-uttf1ba227,3.3140459656715393 8 | sys25583-utt32a3068,4.226199269294739 9 | sys25583-utt640ccda,4.173765182495117 10 | sys25583-uttb584e2d,4.127885103225708 11 | sys25583-utte43fff8,3.7679272294044495 12 | sys288fc-utt0b9b8c4,1.6744518280029297 13 | sys288fc-utt178b639,1.6881271600723267 14 | sys288fc-utt27c7f8e,1.6820378303527832 15 | sys288fc-utt4ca70be,1.6948258876800537 16 | sys288fc-utt737b11a,1.7346054315567017 17 | sys288fc-utt875c9bc,1.6459565162658691 18 | sys288fc-uttc86d188,1.7195290327072144 19 | sys2dea8-utt5d0ff00,3.190113589167595 20 | sys2dea8-uttc0684c5,3.69060617685318 21 | sys3ac0c-utt380b246,2.8424201756715775 22 | sys3ac0c-utt6e8c4c8,3.246820494532585 23 | sys3bec5-utt08b2712,4.067929744720459 24 | sys3bec5-utt256038c,4.0526769161224365 25 | sys3bec5-utt57f6999,3.77788507938385 26 | sys3bec5-utt6fa8766,4.025399327278137 27 | sys3bec5-uttf96e3f6,3.908892512321472 28 | sys40775-utt053fb06,2.8519225120544434 29 | sys40775-utt06280d0,3.298975795507431 30 | sys40775-utt21f7eb2,3.1713304817676544 31 | sys40775-utt25debbd,3.4360039830207825 32 | sys40775-utt262d917,3.363330453634262 33 | sys40775-utt2c689f4,3.5800936818122864 34 | sys40775-utt31d8803,3.100691355764866 35 | sys40775-utt3a6022b,3.7426048517227173 36 | sys40775-utt3c9d3a9,3.3877606987953186 37 | sys40775-utt3caa162,3.276692509651184 38 | sys40775-utt40cc32a,3.606023669242859 39 | sys40775-utt4830112,3.020839897915721 40 | sys40775-utt4beb3cb,3.937600076198578 41 | sys40775-utt52bccfe,3.4492527544498444 42 | sys40775-utt59ca17a,2.6853365898132324 43 | sys40775-utt63aa530,3.3543978333473206 44 | sys40775-utt6bccf3e,3.305658459663391 45 | sys40775-utt6cc4d95,3.031956598162651 46 | sys40775-utt734346c,3.5151914954185486 47 | sys40775-utt757daca,3.0054497895762324 48 | sys40775-utt7ab4c4c,3.481399267911911 49 | sys40775-utt7f58b52,3.616021752357483 50 | sys40775-utt8c70e0c,3.2666018903255463 51 | sys40775-utt9043da6,3.1890525817871094 52 | sys40775-utt95d9b13,3.7385385036468506 53 | sys40775-utt979fa4d,3.2476234287023544 54 | sys40775-utt98d2608,3.5850866436958313 55 | sys40775-uttacfa755,3.217980235815048 56 | sys40775-uttbffd493,3.032057635486126 57 | sys40775-uttca8a7fd,3.4135411977767944 58 | sys40775-uttcd51f1b,3.4589690566062927 59 | sys40775-uttd28fea3,2.8292620480060577 60 | sys40775-uttd4f5d6e,3.164977937936783 61 | sys40775-uttd803e0a,3.1480540484189987 62 | sys40775-uttddf3183,3.511105179786682 63 | sys40775-uttde965d9,3.2861602008342743 64 | sys40775-uttdeee18d,3.0033161328174174 65 | sys40775-uttdfaf42e,2.863734170794487 66 | sys40775-utte1baab8,3.1256515234708786 67 | sys40775-utte451004,3.608478367328644 68 | sys40775-utted1807d,3.1984187215566635 69 | sys40775-uttee1039e,3.0977888256311417 70 | sys40775-uttf0806be,3.4857234954833984 71 | sys40775-uttf1ca905,3.2617311477661133 72 | sys40775-uttf46b095,3.5599228739738464 73 | sys40775-uttfe39bc5,3.5992850065231323 74 | sys50620-utt36a51d2,3.6462497115135193 75 | sys50620-uttb648ed6,3.2346756011247635 76 | sys50620-uttdc9bc1b,3.1156226620078087 77 | sys50620-utte891e60,2.767196550965309 78 | sys50620-uttf932cea,2.980430446565151 79 | sys5a87a-utt3153e35,3.0460551269352436 80 | sys5a87a-uttd0f0a73,2.9686387814581394 81 | sys6aa80-utt41c16b6,3.295067608356476 82 | sys6aa80-utta52cb79,3.3978131711483 83 | sys6aa80-uttb7c3c36,3.3251181840896606 84 | sys6aa80-uttc4aa100,3.7246773838996887 85 | sys6aa80-utte8f1cbc,3.2285984605550766 86 | sys6d946-utt067a45d,3.814286947250366 87 | sys6d946-utt454ee70,4.088953971862793 88 | sys6d946-uttb8f803b,3.907171308994293 89 | sys6d946-uttc5d8752,4.120617151260376 90 | sys7d2dc-utt04fca18,4.019955635070801 91 | sys7d2dc-utt5b21c48,3.657771944999695 92 | sys8569c-utt53191c9,4.233504056930542 93 | sys8569c-utt75e2e48,4.241759300231934 94 | sys8569c-utt89043c6,4.2122883796691895 95 | sys86a3b-utt0dfad3e,3.167993053793907 96 | sys86a3b-utt6292b56,3.2046364545822144 97 | sys9900c-utt310ee00,3.095589391887188 98 | sys9900c-utt95df11f,3.000581094471272 99 | sys9900c-utt99ec0d0,2.683213710784912 100 | sys9900c-uttb72803f,3.0643272027373314 101 | sys9900c-utte2b3cdc,3.2304811626672745 102 | sys9900c-utte7af0d7,2.5658408403396606 103 | sysaefdf-utt0c7d358,3.761621117591858 104 | sysaefdf-utt33e9539,4.158675312995911 105 | sysaefdf-uttc8398cf,4.072099804878235 106 | sysaefdf-uttd832fe6,3.330040603876114 107 | sysc0232-utt0d5592e,4.234232664108276 108 | sysc0232-utt1136939,4.041603207588196 109 | sysc0232-utt321b7f3,3.6770675778388977 110 | sysc0232-utt43516ea,4.184963822364807 111 | sysc0232-utta98dc3e,4.0392502546310425 112 | sysc0232-uttcc93fbe,3.8121009469032288 113 | sysc0247-utt0c53cda,4.085257172584534 114 | sysc0247-utt71242b9,4.184931635856628 115 | sysc0247-utte133661,4.232062220573425 116 | sysc0247-uttfe8f1b6,4.241726398468018 117 | sysc138b-utt41cd1c3,2.7210413813591003 118 | sysc138b-utt52345ed,2.9520521461963654 119 | sysc138b-utt5ab7a39,3.299651026725769 120 | sysc138b-utt9d47077,3.0674551352858543 121 | sysc138b-uttc0e1eb7,3.020436340942979 122 | sysc5b22-utt2d3463f,4.21875536441803 123 | sysc5b22-utt3e37da2,4.168697476387024 124 | sysc5b22-utted9388f,4.244592905044556 125 | sysc5b22-uttf51fc7c,4.172379493713379 126 | sysc8b6d-utt445f8c7,3.9167977571487427 127 | sysc8b6d-utt6ab2d26,3.6579365730285645 128 | sysc8b6d-utt7eeebd8,3.894798457622528 129 | sysc8b6d-uttfe63322,3.786387324333191 130 | sysc93c1-utt249fc00,1.7767733335494995 131 | sysc93c1-utt577a573,1.7516868114471436 132 | syseb559-utt1464f1a,2.5668071806430817 133 | syseb559-utt274445f,2.5056876838207245 134 | syseb559-utt5337b78,3.019061068072915 135 | syseb559-utt7b8a073,2.846351996064186 136 | syseb559-uttdb7f27a,2.7430557310581207 137 | -------------------------------------------------------------------------------- /stacking/strong_learner_result/ood1/answer-ood.csvval_1: -------------------------------------------------------------------------------- 1 | sys0c3c7-utt29396c4,1.6569277048110962 2 | sys0c3c7-utt32149b2,1.5723577737808228 3 | sys0c3c7-utt3b43620,1.5654006004333496 4 | sys0c3c7-utte2ff6b2,1.604346752166748 5 | sys1b29a-uttb5c2dfd,3.4084918797016144 6 | sys1b29a-uttbe1f861,3.729147791862488 7 | sys1b29a-uttf1ba227,3.6519864201545715 8 | sys25583-utt32a3068,4.376435160636902 9 | sys25583-utt640ccda,4.414887428283691 10 | sys25583-uttb584e2d,4.447765827178955 11 | sys25583-utte43fff8,4.259235620498657 12 | sys288fc-utt0b9b8c4,1.5869495868682861 13 | sys288fc-utt178b639,1.5875202417373657 14 | sys288fc-utt27c7f8e,1.5941483974456787 15 | sys288fc-utt4ca70be,1.5977908372879028 16 | sys288fc-utt737b11a,1.8113304376602173 17 | sys288fc-utt875c9bc,1.5783995389938354 18 | sys288fc-uttc86d188,1.6000697612762451 19 | sys2dea8-utt5d0ff00,3.2767574787139893 20 | sys2dea8-uttc0684c5,3.652103900909424 21 | sys3ac0c-utt380b246,3.1722365468740463 22 | sys3ac0c-utt6e8c4c8,3.511052906513214 23 | sys3bec5-utt08b2712,4.320664405822754 24 | sys3bec5-utt256038c,4.418888330459595 25 | sys3bec5-utt57f6999,4.28934383392334 26 | sys3bec5-utt6fa8766,4.370813608169556 27 | sys3bec5-uttf96e3f6,4.3447675704956055 28 | sys40775-utt053fb06,3.485974669456482 29 | sys40775-utt06280d0,3.6210970282554626 30 | sys40775-utt21f7eb2,3.370717942714691 31 | sys40775-utt25debbd,3.524880826473236 32 | sys40775-utt262d917,3.469281166791916 33 | sys40775-utt2c689f4,3.5347546339035034 34 | sys40775-utt31d8803,3.364650994539261 35 | sys40775-utt3a6022b,4.066094398498535 36 | sys40775-utt3c9d3a9,3.7404050827026367 37 | sys40775-utt3caa162,3.8969979882240295 38 | sys40775-utt40cc32a,3.77112877368927 39 | sys40775-utt4830112,3.718696892261505 40 | sys40775-utt4beb3cb,3.9126113653182983 41 | sys40775-utt52bccfe,3.6475443243980408 42 | sys40775-utt59ca17a,2.9844800736755133 43 | sys40775-utt63aa530,3.4839371144771576 44 | sys40775-utt6bccf3e,3.4588995575904846 45 | sys40775-utt6cc4d95,3.3658704459667206 46 | sys40775-utt734346c,3.536608338356018 47 | sys40775-utt757daca,3.281584292650223 48 | sys40775-utt7ab4c4c,3.6340607404708862 49 | sys40775-utt7f58b52,3.5484229922294617 50 | sys40775-utt8c70e0c,3.6158772706985474 51 | sys40775-utt9043da6,3.355036199092865 52 | sys40775-utt95d9b13,3.745011806488037 53 | sys40775-utt979fa4d,3.4204566180706024 54 | sys40775-utt98d2608,3.7399336099624634 55 | sys40775-uttacfa755,2.953236449509859 56 | sys40775-uttbffd493,3.28841096162796 57 | sys40775-uttca8a7fd,3.6459875106811523 58 | sys40775-uttcd51f1b,3.506250262260437 59 | sys40775-uttd28fea3,3.4177536964416504 60 | sys40775-uttd4f5d6e,3.343674421310425 61 | sys40775-uttd803e0a,3.503384828567505 62 | sys40775-uttddf3183,3.475356549024582 63 | sys40775-uttde965d9,3.6865957975387573 64 | sys40775-uttdeee18d,3.341072738170624 65 | sys40775-uttdfaf42e,3.3064666390419006 66 | sys40775-utte1baab8,3.430553048849106 67 | sys40775-utte451004,3.3815099596977234 68 | sys40775-utted1807d,3.7209166288375854 69 | sys40775-uttee1039e,3.744911015033722 70 | sys40775-uttf0806be,3.442750334739685 71 | sys40775-uttf1ca905,3.821177899837494 72 | sys40775-uttf46b095,3.7114606499671936 73 | sys40775-uttfe39bc5,3.68908029794693 74 | sys50620-utt36a51d2,3.8827565908432007 75 | sys50620-uttb648ed6,3.0986974611878395 76 | sys50620-uttdc9bc1b,2.894637681543827 77 | sys50620-utte891e60,2.5575685501098633 78 | sys50620-uttf932cea,3.071582153439522 79 | sys5a87a-utt3153e35,3.135112166404724 80 | sys5a87a-uttd0f0a73,2.8925078213214874 81 | sys6aa80-utt41c16b6,3.6064711213111877 82 | sys6aa80-utta52cb79,3.840636134147644 83 | sys6aa80-uttb7c3c36,3.733500063419342 84 | sys6aa80-uttc4aa100,3.693002223968506 85 | sys6aa80-utte8f1cbc,3.446985363960266 86 | sys6d946-utt067a45d,4.0865994691848755 87 | sys6d946-utt454ee70,4.334984540939331 88 | sys6d946-uttb8f803b,4.222865581512451 89 | sys6d946-uttc5d8752,4.313713550567627 90 | sys7d2dc-utt04fca18,4.010199785232544 91 | sys7d2dc-utt5b21c48,3.5608004331588745 92 | sys8569c-utt53191c9,4.4660950899124146 93 | sys8569c-utt75e2e48,4.4166131019592285 94 | sys8569c-utt89043c6,4.4364928007125854 95 | sys86a3b-utt0dfad3e,3.55155611038208 96 | sys86a3b-utt6292b56,3.365819603204727 97 | sys9900c-utt310ee00,2.9474673829972744 98 | sys9900c-utt95df11f,2.4395697712898254 99 | sys9900c-utt99ec0d0,2.7191194593906403 100 | sys9900c-uttb72803f,2.4801058173179626 101 | sys9900c-utte2b3cdc,3.209576740860939 102 | sys9900c-utte7af0d7,2.9360008537769318 103 | sysaefdf-utt0c7d358,4.201140284538269 104 | sysaefdf-utt33e9539,4.343910217285156 105 | sysaefdf-uttc8398cf,4.286288022994995 106 | sysaefdf-uttd832fe6,3.982454299926758 107 | sysc0232-utt0d5592e,4.392175316810608 108 | sysc0232-utt1136939,4.324562072753906 109 | sysc0232-utt321b7f3,4.107685089111328 110 | sysc0232-utt43516ea,4.312808871269226 111 | sysc0232-utta98dc3e,4.302412629127502 112 | sysc0232-uttcc93fbe,4.046707510948181 113 | sysc0247-utt0c53cda,4.259897470474243 114 | sysc0247-utt71242b9,4.30824601650238 115 | sysc0247-utte133661,4.403515815734863 116 | sysc0247-uttfe8f1b6,4.245032429695129 117 | sysc138b-utt41cd1c3,3.3730322420597076 118 | sysc138b-utt52345ed,3.030466416850686 119 | sysc138b-utt5ab7a39,3.429475098848343 120 | sysc138b-utt9d47077,3.3303037881851196 121 | sysc138b-uttc0e1eb7,3.2648874521255493 122 | sysc5b22-utt2d3463f,4.465267658233643 123 | sysc5b22-utt3e37da2,4.396283984184265 124 | sysc5b22-utted9388f,4.438539385795593 125 | sysc5b22-uttf51fc7c,4.36857283115387 126 | sysc8b6d-utt445f8c7,4.024686098098755 127 | sysc8b6d-utt6ab2d26,3.605237662792206 128 | sysc8b6d-utt7eeebd8,3.8477973341941833 129 | sysc8b6d-uttfe63322,4.0216004848480225 130 | sysc93c1-utt249fc00,1.6773823499679565 131 | sysc93c1-utt577a573,1.6771522760391235 132 | syseb559-utt1464f1a,2.2051857709884644 133 | syseb559-utt274445f,2.5563432574272156 134 | syseb559-utt5337b78,2.7770099490880966 135 | syseb559-utt7b8a073,2.8372543454170227 136 | syseb559-uttdb7f27a,2.366033673286438 137 | -------------------------------------------------------------------------------- /stacking/strong_learner_result/ood1/answer-ood.csvval_2: -------------------------------------------------------------------------------- 1 | sys0c3c7-utt29396c4,1.9007413387298584 2 | sys0c3c7-utt32149b2,1.7681547403335571 3 | sys0c3c7-utt3b43620,1.721413016319275 4 | sys0c3c7-utte2ff6b2,1.7883329391479492 5 | sys1b29a-uttb5c2dfd,3.177410289645195 6 | sys1b29a-uttbe1f861,3.485043078660965 7 | sys1b29a-uttf1ba227,3.257159322500229 8 | sys25583-utt32a3068,4.03691840171814 9 | sys25583-utt640ccda,4.121394991874695 10 | sys25583-uttb584e2d,4.063861012458801 11 | sys25583-utte43fff8,3.9808879494667053 12 | sys288fc-utt0b9b8c4,1.7661837339401245 13 | sys288fc-utt178b639,1.7247849702835083 14 | sys288fc-utt27c7f8e,1.7177363634109497 15 | sys288fc-utt4ca70be,1.7433276176452637 16 | sys288fc-utt737b11a,1.8274472951889038 17 | sys288fc-utt875c9bc,1.7218009233474731 18 | sys288fc-uttc86d188,1.7508944272994995 19 | sys2dea8-utt5d0ff00,3.3823030591011047 20 | sys2dea8-uttc0684c5,3.7632473707199097 21 | sys3ac0c-utt380b246,3.2561529874801636 22 | sys3ac0c-utt6e8c4c8,3.215494468808174 23 | sys3bec5-utt08b2712,4.037221789360046 24 | sys3bec5-utt256038c,4.032892227172852 25 | sys3bec5-utt57f6999,3.992736518383026 26 | sys3bec5-utt6fa8766,4.005888819694519 27 | sys3bec5-uttf96e3f6,4.036606192588806 28 | sys40775-utt053fb06,3.1620380580425262 29 | sys40775-utt06280d0,3.2348039001226425 30 | sys40775-utt21f7eb2,3.249326527118683 31 | sys40775-utt25debbd,3.384929269552231 32 | sys40775-utt262d917,3.477863311767578 33 | sys40775-utt2c689f4,3.4279374182224274 34 | sys40775-utt31d8803,3.2081397473812103 35 | sys40775-utt3a6022b,4.024236440658569 36 | sys40775-utt3c9d3a9,3.5308963656425476 37 | sys40775-utt3caa162,3.556710362434387 38 | sys40775-utt40cc32a,3.562081277370453 39 | sys40775-utt4830112,3.3248510658740997 40 | sys40775-utt4beb3cb,3.7101657390594482 41 | sys40775-utt52bccfe,3.6892759203910828 42 | sys40775-utt59ca17a,2.932113580405712 43 | sys40775-utt63aa530,3.2914924323558807 44 | sys40775-utt6bccf3e,3.3274621665477753 45 | sys40775-utt6cc4d95,3.0198783073574305 46 | sys40775-utt734346c,3.5169814229011536 47 | sys40775-utt757daca,3.221805825829506 48 | sys40775-utt7ab4c4c,3.293033629655838 49 | sys40775-utt7f58b52,3.5033302903175354 50 | sys40775-utt8c70e0c,3.3420273661613464 51 | sys40775-utt9043da6,3.2828216552734375 52 | sys40775-utt95d9b13,3.5908207297325134 53 | sys40775-utt979fa4d,3.229120597243309 54 | sys40775-utt98d2608,3.323885351419449 55 | sys40775-uttacfa755,3.0765668898820877 56 | sys40775-uttbffd493,2.9936748747713864 57 | sys40775-uttca8a7fd,3.5150104761123657 58 | sys40775-uttcd51f1b,3.2664534151554108 59 | sys40775-uttd28fea3,2.7796780467033386 60 | sys40775-uttd4f5d6e,2.963381800800562 61 | sys40775-uttd803e0a,3.008986245840788 62 | sys40775-uttddf3183,3.4704970121383667 63 | sys40775-uttde965d9,3.3853626251220703 64 | sys40775-uttdeee18d,3.074149541556835 65 | sys40775-uttdfaf42e,2.9378270842134953 66 | sys40775-utte1baab8,3.2391035854816437 67 | sys40775-utte451004,3.4486441910266876 68 | sys40775-utted1807d,3.173550084233284 69 | sys40775-uttee1039e,3.364699602127075 70 | sys40775-uttf0806be,3.372744768857956 71 | sys40775-uttf1ca905,3.4473859667778015 72 | sys40775-uttf46b095,3.8440520763397217 73 | sys40775-uttfe39bc5,3.4940754771232605 74 | sys50620-utt36a51d2,3.7633183002471924 75 | sys50620-uttb648ed6,3.342268943786621 76 | sys50620-uttdc9bc1b,3.2014998346567154 77 | sys50620-utte891e60,2.83543960750103 78 | sys50620-uttf932cea,3.337063044309616 79 | sys5a87a-utt3153e35,2.974347261711955 80 | sys5a87a-uttd0f0a73,3.0582089573144913 81 | sys6aa80-utt41c16b6,3.3762663900852203 82 | sys6aa80-utta52cb79,3.5862852334976196 83 | sys6aa80-uttb7c3c36,3.4562728106975555 84 | sys6aa80-uttc4aa100,3.7309938073158264 85 | sys6aa80-utte8f1cbc,3.317980080842972 86 | sys6d946-utt067a45d,3.9896318316459656 87 | sys6d946-utt454ee70,4.0640339851379395 88 | sys6d946-uttb8f803b,3.968318819999695 89 | sys6d946-uttc5d8752,4.050647020339966 90 | sys7d2dc-utt04fca18,3.9717133045196533 91 | sys7d2dc-utt5b21c48,3.563229262828827 92 | sys8569c-utt53191c9,4.033492922782898 93 | sys8569c-utt75e2e48,4.029201507568359 94 | sys8569c-utt89043c6,4.080493688583374 95 | sys86a3b-utt0dfad3e,3.4544460773468018 96 | sys86a3b-utt6292b56,3.313985913991928 97 | sys9900c-utt310ee00,3.524359345436096 98 | sys9900c-utt95df11f,3.493915557861328 99 | sys9900c-utt99ec0d0,2.957903254777193 100 | sys9900c-uttb72803f,3.5492897033691406 101 | sys9900c-utte2b3cdc,4.015172481536865 102 | sys9900c-utte7af0d7,2.7136725187301636 103 | sysaefdf-utt0c7d358,3.929300010204315 104 | sysaefdf-utt33e9539,4.067215442657471 105 | sysaefdf-uttc8398cf,4.009889483451843 106 | sysaefdf-uttd832fe6,3.434192717075348 107 | sysc0232-utt0d5592e,4.063883185386658 108 | sysc0232-utt1136939,3.911326766014099 109 | sysc0232-utt321b7f3,3.7473238706588745 110 | sysc0232-utt43516ea,3.8803505897521973 111 | sysc0232-utta98dc3e,3.950296401977539 112 | sysc0232-uttcc93fbe,3.526122808456421 113 | sysc0247-utt0c53cda,3.925826609134674 114 | sysc0247-utt71242b9,4.021496653556824 115 | sysc0247-utte133661,4.058210730552673 116 | sysc0247-uttfe8f1b6,3.98443067073822 117 | sysc138b-utt41cd1c3,3.050193063914776 118 | sysc138b-utt52345ed,2.932532951235771 119 | sysc138b-utt5ab7a39,3.177796170115471 120 | sysc138b-utt9d47077,3.122036322951317 121 | sysc138b-uttc0e1eb7,3.0465613305568695 122 | sysc5b22-utt2d3463f,4.089601039886475 123 | sysc5b22-utt3e37da2,4.042580723762512 124 | sysc5b22-utted9388f,4.073251724243164 125 | sysc5b22-uttf51fc7c,4.013024687767029 126 | sysc8b6d-utt445f8c7,4.025208115577698 127 | sysc8b6d-utt6ab2d26,4.048284530639648 128 | sysc8b6d-utt7eeebd8,3.9827736616134644 129 | sysc8b6d-uttfe63322,4.045578479766846 130 | sysc93c1-utt249fc00,2.2003649473190308 131 | sysc93c1-utt577a573,2.0730539560317993 132 | syseb559-utt1464f1a,2.8193936198949814 133 | syseb559-utt274445f,3.435039311647415 134 | syseb559-utt5337b78,3.293728530406952 135 | syseb559-utt7b8a073,3.699833929538727 136 | syseb559-uttdb7f27a,3.16067473590374 137 | -------------------------------------------------------------------------------- /strong/README.md: -------------------------------------------------------------------------------- 1 | # UTMOS Strong Learner 2 | 3 | Training and inference scripts for the UTMOS strong learner. 4 | 5 | ## Prerequesities 6 | 7 | * poetry 8 | * [WavAugment](https://github.com/facebookresearch/WavAugment) 9 | 10 | ## Pretrained model 11 | Pretrained **UTMOS strong** models for the main and OOD tracks are available. 12 | For the model details, refer to the [paper](https://arxiv.org/abs/2204.02152). 13 | 14 | - [Main](https://drive.google.com/drive/folders/1U4XQze8mJqV4TRMwTcY6T247RpmU5hRg?usp=sharing) 15 | - [OOD](https://drive.google.com/drive/folders/1dPlV92fyKY1arei7TcU2ZFB-wZkYhIqK?usp=sharing) 16 | 17 | Note that each of the above directory contains pretrained models obtained with five different random seeds. 18 | 19 | ## Setup 20 | 1. Download SSL model checkpoints from [fairseq repo](https://github.com/pytorch/fairseq). 21 | 1. Run the following commands. 22 | ```shell 23 | cd path/to/this/repository 24 | poetry install 25 | cd strong/ 26 | ln -s path/to/dataset/ data/ 27 | poetry shell 28 | ``` 29 | 30 | ## Preprocessing 31 | The phoeneme transcription is alreday present in this repository. 32 | You can also perform the transcribe by your own! 33 | ```shell 34 | cd strong/ 35 | python transcribe_speech.py 36 | ``` 37 | 38 | ## Training 39 | 40 | To train the strong learner, run the following commands for each of the tracks. 41 | 42 | Main track 43 | ```shell 44 | cd strong/ 45 | python train.py dataset.use_data.ood=False dataset.use_data.external=False 46 | ``` 47 | OOD track 48 | ```shell 49 | cd strong/ 50 | python train.py train.model_selection_metric=val_SRCC_system_ood 51 | ``` 52 | 53 | ## Prediction 54 | To predict scores with the trained strong learner, run the following command. 55 | ```shell 56 | python predict.py +ckpt_path=outputs/${date}/${time}/train_outputs/hoge.ckpt 57 | ``` 58 | To perform prdiction with the pretrained model, run the following command. 59 | 60 | ```shell 61 | python predict.py +ckpt_path=outputs/${date}/${time}/train_outputs/hoge.ckpt +paper_weights=True 62 | ``` 63 | -------------------------------------------------------------------------------- /strong/configs/cv.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: cv 3 | - model: default 4 | - train: default 5 | 6 | batch_size_and_model: "wav2vec2-base-4" 7 | debug: False 8 | deepspeed: False 9 | outfile: answer.csv -------------------------------------------------------------------------------- /strong/configs/dataset/cv.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | data_dir: data/phase1-main/DATA 3 | data_sources: 4 | - 5 | name: "main" 6 | train_mos_list_path: data/phase1-main/DATA/sets/TRAINSET 7 | val_mos_list_path: data/phase1-main/DATA/sets/DEVSET 8 | test_mos_list_path: data/phase1-main/DATA/sets/test.scp 9 | wav_dir: data/phase1-main/DATA/wav/ 10 | data_dir: data/phase1-main/DATA 11 | outfile: answer-main.csv 12 | - 13 | name: "ood" 14 | train_mos_list_path: data/phase1-ood/DATA/sets/TRAINSET 15 | val_mos_list_path: data/phase1-ood/DATA/sets/DEVSET 16 | test_mos_list_path: data/phase1-ood/DATA/sets/test.scp 17 | wav_dir: data/phase1-ood/DATA/wav/ 18 | data_dir: data/phase1-ood/DATA 19 | outfile: answer-ood.csv 20 | - 21 | name: "external" 22 | train_mos_list_path: TRAINSET_external.txt 23 | wav_dir: data/phase1-ood/DATA/wav/ 24 | data_dir: data/phase1-ood/DATA 25 | k_cv: 5 26 | use_data: 27 | main: True 28 | ood: True 29 | external: True 30 | datamodule: 31 | _target_: dataset.CVDataModule 32 | only_mean: False 33 | additional_datas: 34 | - 35 | _target_: dataset.PhonemeData 36 | transcription_file_path: 'transcriptions_clustered.csv' 37 | with_reference: True 38 | - 39 | _target_: dataset.NormalizeScore 40 | org_max: 5.0 41 | org_min: 1.0 42 | normalize_to_max: 1.0 43 | normalize_to_min: -1.0 44 | - 45 | _target_: dataset.AugmentWav 46 | pitch_shift_minmax: 47 | min: -300 48 | max: 300 49 | random_time_warp_f: 1.0 50 | - 51 | _target_: dataset.SliceWav 52 | max_wav_seconds: 10 53 | -------------------------------------------------------------------------------- /strong/configs/dataset/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | data_dir: data/phase1-main/DATA 3 | data_sources: 4 | - 5 | name: "main" 6 | train_mos_list_path: data/phase1-main/DATA/sets/TRAINSET 7 | val_mos_list_path: data/phase1-main/DATA/sets/DEVSET 8 | test_mos_list_path: data/phase1-main/DATA/sets/test.scp 9 | wav_dir: data/phase1-main/DATA/wav/ 10 | data_dir: data/phase1-main/DATA 11 | outfile: answer-main.csv 12 | - 13 | name: "ood" 14 | train_mos_list_path: data/phase1-ood/DATA/sets/TRAINSET 15 | val_mos_list_path: data/phase1-ood/DATA/sets/DEVSET 16 | test_mos_list_path: data/phase1-ood/DATA/sets/test.scp 17 | wav_dir: data/phase1-ood/DATA/wav/ 18 | data_dir: data/phase1-ood/DATA 19 | outfile: answer-ood.csv 20 | - 21 | name: "external" 22 | train_mos_list_path: TRAINSET_external.txt 23 | wav_dir: data/phase1-ood/DATA/wav/ 24 | data_dir: data/phase1-ood/DATA 25 | use_data: 26 | main: True 27 | ood: True 28 | external: True 29 | datamodule: 30 | _target_: dataset.DataModule 31 | only_mean: False 32 | additional_datas: 33 | - 34 | _target_: dataset.PhonemeData 35 | transcription_file_path: 'transcriptions_clustered.csv' 36 | with_reference: True 37 | - 38 | _target_: dataset.NormalizeScore 39 | org_max: 5.0 40 | org_min: 1.0 41 | normalize_to_max: 1.0 42 | normalize_to_min: -1.0 43 | - 44 | _target_: dataset.AugmentWav 45 | pitch_shift_minmax: 46 | min: -300 47 | max: 300 48 | random_time_warp_f: 1.0 49 | - 50 | _target_: dataset.SliceWav 51 | max_wav_seconds: 10 52 | -------------------------------------------------------------------------------- /strong/configs/dataset/wo_augment.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | data_dir: data/phase1-main/DATA 3 | data_sources: 4 | - 5 | name: "main" 6 | train_mos_list_path: data/phase1-main/DATA/sets/TRAINSET 7 | val_mos_list_path: data/phase1-main/DATA/sets/DEVSET 8 | test_mos_list_path: data/phase1-main/DATA/sets/test.scp 9 | test_post_mos_list_path: data/phase1-main/DATA/sets/TESTSET 10 | wav_dir: data/phase1-main/DATA/wav/ 11 | data_dir: data/phase1-main/DATA 12 | outfile: answer-main.csv 13 | - 14 | name: "ood" 15 | train_mos_list_path: data/phase1-ood/DATA/sets/TRAINSET 16 | val_mos_list_path: data/phase1-ood/DATA/sets/DEVSET 17 | test_mos_list_path: data/phase1-ood/DATA/sets/test.scp 18 | test_post_mos_list_path: data/phase1-ood/DATA/sets/TESTSET 19 | wav_dir: data/phase1-ood/DATA/wav/ 20 | data_dir: data/phase1-ood/DATA 21 | outfile: answer-ood.csv 22 | - 23 | name: "external" 24 | train_mos_list_path: TRAINSET_external.txt 25 | wav_dir: data/phase1-ood/DATA/wav/ 26 | data_dir: data/phase1-ood/DATA 27 | use_data: 28 | main: True 29 | ood: True 30 | external: True 31 | datamodule: 32 | _target_: dataset.DataModule 33 | only_mean: False 34 | additional_datas: 35 | - 36 | _target_: dataset.PhonemeData 37 | transcription_file_path: 'transcriptions_clustered.csv' 38 | with_reference: True 39 | - 40 | _target_: dataset.NormalizeScore 41 | org_max: 5.0 42 | org_min: 1.0 43 | normalize_to_max: 1.0 44 | normalize_to_min: -1.0 45 | - 46 | _target_: dataset.SliceWav 47 | max_wav_seconds: 10 48 | -------------------------------------------------------------------------------- /strong/configs/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: default 3 | - model: default 4 | - train: default 5 | 6 | debug: False 7 | deepspeed: False 8 | outfile: answer.csv 9 | -------------------------------------------------------------------------------- /strong/configs/model/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | lightning_module: 3 | _target_: lightning_module.BaselineLightningModule 4 | WavLM: False 5 | 6 | feature_extractors: 7 | - 8 | _target_: model.load_ssl_model 9 | cp_path: ../fairseq_checkpoints/wav2vec_small.pt 10 | 11 | - 12 | _target_: model.PhonemeEncoder 13 | hidden_dim: 256 14 | emb_dim: 256 15 | out_dim: 256 16 | n_lstm_layers: 3 17 | vocab_size: 198 18 | - 19 | _target_: model.DomainEmbedding 20 | n_domains: 3 21 | domain_dim: 128 22 | 23 | output_layers: 24 | - 25 | _target_: model.LDConditioner 26 | judge_dim: 128 27 | num_judges: 3000 28 | - 29 | _target_: model.Projection 30 | hidden_dim: 2048 31 | activation: 32 | _target_: torch.nn.ReLU 33 | range_clipping: False 34 | 35 | -------------------------------------------------------------------------------- /strong/configs/model/wo_phoneme.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | lightning_module: 3 | _target_: lightning_module.BaselineLightningModule 4 | WavLM: False 5 | 6 | feature_extractors: 7 | - 8 | _target_: model.load_ssl_model 9 | cp_path: ../fairseq_checkpoints/fairseq/wav2vec_small.pt 10 | 11 | - 12 | _target_: model.DomainEmbedding 13 | n_domains: 3 14 | domain_dim: 128 15 | 16 | output_layers: 17 | - 18 | _target_: model.LDConditioner 19 | judge_dim: 128 20 | num_judges: 3000 21 | - 22 | _target_: model.Projection 23 | hidden_dim: 2048 24 | activation: 25 | _target_: torch.nn.ReLU 26 | range_clipping: False 27 | 28 | -------------------------------------------------------------------------------- /strong/configs/optuna-main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: default 3 | - model: default 4 | - train: default 5 | - override hydra/sweeper: optuna 6 | 7 | hydra: 8 | sweeper: 9 | n_trials: 100 10 | direction: maximize 11 | storage: ??? 12 | study_name: sslmos_listener_ld_contrastive_abci 13 | n_jobs: 1 14 | search_space: 15 | train.criterion.loss_weights.1: 16 | type: float 17 | low: 0.0 18 | high: 2.0 19 | step: 0.1 20 | train.criterion.loss_instances.1.margin: 21 | type: float 22 | low: 0.0 23 | high: 0.5 24 | step: 0.1 25 | train.criterion.loss_instances.0.tau: 26 | type: categorical 27 | choices: 28 | - 0.1 29 | - 0.25 30 | train.criterion.loss_instances.0.mode: 31 | type: categorical 32 | choices: 33 | - 'frame' 34 | dataset.additional_datas.2.pitch_shift_minmax.min: 35 | type: categorical 36 | choices: 37 | - 0 38 | - -100 39 | - -200 40 | - -300 41 | dataset.additional_datas.2.pitch_shift_minmax.max: 42 | type: categorical 43 | choices: 44 | - 0 45 | - 100 46 | - 200 47 | - 300 48 | dataset.additional_datas.2.random_time_warp_f: 49 | type: float 50 | low: 1 51 | high: 3 52 | step: 0.5 53 | dataset.use_data.ood: 54 | type: categorical 55 | choices: 56 | - True 57 | - False 58 | dataset.use_data.external: 59 | type: categorical 60 | choices: 61 | - True 62 | - False 63 | dataset.only_mean: 64 | type: categorical 65 | choices: 66 | - True 67 | - False 68 | batch_size_and_model: 69 | type: categorical 70 | choices: 71 | - "wav2vec2-base-4" 72 | - "wav2vec2-base-8" 73 | - "wav2vec2-base-16" 74 | - "wavlm-large-4" 75 | batch_size_and_model: "wav2vec2-base-4" 76 | tuning_target: ??? 77 | debug: False 78 | outfile: answer.csv 79 | -------------------------------------------------------------------------------- /strong/configs/optuna-ood.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: default 3 | - model: default 4 | - train: default 5 | - override hydra/sweeper: optuna 6 | 7 | hydra: 8 | sweeper: 9 | n_trials: 100 10 | direction: maximize 11 | storage: ??? 12 | study_name: sslmos_listener_ld_contrastive_abci 13 | n_jobs: 1 14 | search_space: 15 | train.criterion.loss_weights.1: 16 | type: float 17 | low: 0.0 18 | high: 2.0 19 | step: 0.1 20 | train.criterion.loss_instances.1.margin: 21 | type: float 22 | low: 0.0 23 | high: 0.5 24 | step: 0.1 25 | train.criterion.loss_instances.0.tau: 26 | type: categorical 27 | choices: 28 | - 0.1 29 | - 0.25 30 | train.criterion.loss_instances.0.mode: 31 | type: categorical 32 | choices: 33 | - 'frame' 34 | dataset.additional_datas.2.pitch_shift_minmax.min: 35 | type: categorical 36 | choices: 37 | - 0 38 | - -100 39 | - -200 40 | - -300 41 | dataset.additional_datas.2.pitch_shift_minmax.max: 42 | type: categorical 43 | choices: 44 | - 0 45 | - 100 46 | - 200 47 | - 300 48 | dataset.additional_datas.2.random_time_warp_f: 49 | type: float 50 | low: 1 51 | high: 3 52 | step: 0.5 53 | dataset.use_data.main: 54 | type: categorical 55 | choices: 56 | - True 57 | - False 58 | dataset.use_data.external: 59 | type: categorical 60 | choices: 61 | - True 62 | - False 63 | dataset.only_mean: 64 | type: categorical 65 | choices: 66 | - True 67 | - False 68 | batch_size_and_model: 69 | type: categorical 70 | choices: 71 | - "wav2vec2-base-4" 72 | - "wav2vec2-base-8" 73 | - "wav2vec2-base-16" 74 | - "wavlm-large-4" 75 | batch_size_and_model: "wav2vec2-base-4" 76 | tuning_target: ??? 77 | debug: False 78 | outfile: answer.csv 79 | -------------------------------------------------------------------------------- /strong/configs/train/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | seed: 1234 3 | use_wandb: False 4 | model_selection_metric: val_SRCC_system_main 5 | train_batch_size: 12 6 | val_batch_size: 1 7 | test_batch_size: 1 8 | out_dir: train_output/ 9 | trainer_args: 10 | max_steps: 15_000 11 | gpus: [0] 12 | deterministic: True 13 | auto_select_gpus: False 14 | benchmark: True 15 | precision: 32 16 | gradient_clip_val: 1.0 17 | flush_logs_every_n_steps: 10 18 | val_check_interval: 0.5 19 | accumulate_grad_batches: 2 20 | # strategy: ddp 21 | optimizer: 22 | _target_: torch.optim.Adam 23 | lr: 2e-5 24 | scheduler: 25 | _target_: transformers.get_linear_schedule_with_warmup 26 | num_warmup_steps: 4000 27 | num_training_steps: 15_000 28 | early_stopping: 29 | patience: 100 30 | 31 | criterion: 32 | _target_: loss_function.CombineLosses 33 | loss_weights: 34 | - 1.0 35 | - 0.5 36 | loss_instances: 37 | - 38 | _target_: loss_function.ClippedMSELoss 39 | criterion: 40 | _target_: torch.nn.MSELoss 41 | reduction: 'none' 42 | tau: 0.25 43 | mode: 'frame' 44 | - 45 | _target_: loss_function.ContrastiveLoss 46 | margin: 0.1 47 | -------------------------------------------------------------------------------- /strong/cross_validation.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pytorch_lightning.loggers.csv_logs import CSVLogger 3 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger 4 | from pytorch_lightning.callbacks import ModelCheckpoint 5 | from pytorch_lightning import Trainer 6 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 7 | from dataset import CVDataModule, TestDataModule 8 | from lightning_module import UTMOSLightningModule 9 | import hydra 10 | import wandb 11 | 12 | @hydra.main(config_path="configs",config_name='cv') 13 | def cross_validation(cfg): 14 | 15 | k_cv = cfg.dataset.k_cv 16 | i_cv = cfg.dataset.i_cv 17 | 18 | print("------- Using subset_{} out of {}-fold -------".format(i_cv, k_cv)) 19 | d_metrics = fit_and_test(cfg, k_cv, i_cv) 20 | wandb.log(d_metrics) 21 | 22 | 23 | def fit_and_test(cfg, k_cv, i_cv): 24 | debug = cfg.debug 25 | if debug: 26 | cfg.train.trainer_args.max_steps=10 27 | 28 | 29 | loggers = [] 30 | loggers.append(CSVLogger(save_dir=cfg.train.out_dir, name="train_log")) 31 | loggers.append(TensorBoardLogger(save_dir=cfg.train.out_dir, name="tf_log")) 32 | if cfg.train.use_wandb: 33 | loggers.append(WandbLogger(project="voicemos",offline=debug)) 34 | 35 | checkpoint_callback = ModelCheckpoint( 36 | dirpath=cfg.train.out_dir, 37 | save_weights_only=True, 38 | save_top_k=1, 39 | save_last=True, 40 | every_n_epochs=1, 41 | monitor=cfg.train.model_selection_metric, 42 | mode='max' 43 | ) 44 | callbacks = [checkpoint_callback] 45 | earlystop_callback = EarlyStopping( 46 | monitor="val_loss", min_delta=0.0, patience=cfg.train.early_stopping.patience, mode="min" 47 | ) 48 | callbacks.append(earlystop_callback) 49 | 50 | trainer = Trainer( 51 | **cfg.train.trainer_args, 52 | default_root_dir=hydra.utils.get_original_cwd(), 53 | limit_train_batches=0.01 if debug else 1.0, 54 | limit_val_batches=0.5 if debug else 1.0, 55 | callbacks=callbacks, 56 | logger=loggers, 57 | ) 58 | 59 | datamodule = CVDataModule(cfg=cfg, k_cv=k_cv, i_cv=i_cv) 60 | val_datamodule = TestDataModule(cfg=cfg, i_cv=i_cv, set_name='val') 61 | test_datamodule = TestDataModule(cfg=cfg, i_cv=i_cv, set_name='test') 62 | lightning_module = UTMOSLightningModule(cfg) 63 | trainer.fit(lightning_module, datamodule=datamodule) 64 | 65 | if debug: 66 | trainer.test(lightning_module, verbose=True, datamodule=datamodule) 67 | result = trainer.logged_metrics["test_SRCC_SYS_main_i_cv_{}_set_name_{}".format(i_cv, "fold")] 68 | trainer.test(lightning_module, verbose=True, datamodule=val_datamodule) 69 | trainer.test(lightning_module, verbose=True, datamodule=test_datamodule) 70 | else: 71 | trainer.test(lightning_module, verbose=True, datamodule=datamodule, ckpt_path=checkpoint_callback.best_model_path) 72 | result = trainer.logged_metrics["test_SRCC_SYS_main_i_cv_{}_set_name_{}".format(i_cv, "fold")] 73 | trainer.test(lightning_module, verbose=True, datamodule=val_datamodule, ckpt_path=checkpoint_callback.best_model_path) 74 | trainer.test(lightning_module, verbose=True, datamodule=test_datamodule, ckpt_path=checkpoint_callback.best_model_path) 75 | 76 | return result 77 | 78 | if __name__ == "__main__": 79 | cross_validation() 80 | -------------------------------------------------------------------------------- /strong/data: -------------------------------------------------------------------------------- 1 | /media/ssd/voicemos/data/2022/ -------------------------------------------------------------------------------- /strong/data_augment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import augment 4 | import numpy as np 5 | import random 6 | 7 | 8 | class ChainRunner: 9 | """ 10 | Takes an instance of augment.EffectChain and applies it on pytorch tensors. 11 | """ 12 | 13 | def __init__(self, chain): 14 | self.chain = chain 15 | 16 | def __call__(self, x): 17 | """ 18 | x: torch.Tensor, (channels, length). Must be placed on CPU. 19 | """ 20 | src_info = {'channels': x.size(0), # number of channels 21 | 'length': x.size(1), # length of the sequence 22 | 'precision': 32, # precision (16, 32 bits) 23 | 'rate': 16000.0, # sampling rate 24 | 'bits_per_sample': 32} # size of the sample 25 | 26 | target_info = {'channels': 1, 27 | 'length': x.size(1), 28 | 'precision': 32, 29 | 'rate': 16000.0, 30 | 'bits_per_sample': 32} 31 | 32 | y = self.chain.apply( 33 | x, src_info=src_info, target_info=target_info) 34 | 35 | if torch.isnan(y).any() or torch.isinf(y).any(): 36 | return x.clone() 37 | return y 38 | 39 | 40 | def random_pitch_shift(a=-300, b=300): 41 | return random.randint(a, b) 42 | 43 | def random_time_warp(f=1): 44 | # time warp range: [1-0.1*f, 1+0.1*f], default is [0.9, 1.1] 45 | return 1 + f * (random.random() - 0.5) / 5 46 | 47 | if __name__ == '__main__': 48 | chain = augment.EffectChain() 49 | chain.pitch(random_pitch_shift).rate(16000) 50 | chain.tempo(random_time_warp) 51 | chain = ChainRunner(chain) 52 | wav = torch.randn((1, 16000)) 53 | augmented = chain(wav) 54 | print(wav.shape, augmented.shape) 55 | print(wav[:, :-1000], augmented[:, :-1000]) 56 | -------------------------------------------------------------------------------- /strong/lightning_module.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn as nn 4 | from WavLM import WavLM, WavLMConfig 5 | import os 6 | import fairseq 7 | import numpy as np 8 | import scipy.stats 9 | import hydra 10 | from transformers import AdamW, get_linear_schedule_with_warmup 11 | from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection 12 | import wandb 13 | 14 | 15 | class UTMOSLightningModule(pl.LightningModule): 16 | def __init__(self, cfg): 17 | super().__init__() 18 | self.cfg = cfg 19 | self.construct_model() 20 | self.prepare_domain_table() 21 | self.save_hyperparameters() 22 | 23 | def construct_model(self): 24 | self.feature_extractors = nn.ModuleList([ 25 | hydra.utils.instantiate(feature_extractor) for feature_extractor in self.cfg.model.feature_extractors 26 | ]) 27 | output_dim = sum([ feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors]) 28 | output_layers = [] 29 | for output_layer in self.cfg.model.output_layers: 30 | output_layers.append( 31 | hydra.utils.instantiate(output_layer,input_dim=output_dim) 32 | ) 33 | output_dim = output_layers[-1].get_output_dim() 34 | 35 | self.output_layers = nn.ModuleList(output_layers) 36 | 37 | self.criterion = self.configure_criterion() 38 | 39 | def prepare_domain_table(self): 40 | self.domain_table = {} 41 | data_sources = self.cfg.dataset.data_sources 42 | for idx, datasource in enumerate(data_sources): 43 | if not self.cfg.dataset.use_data.external and datasource['name'] == 'external': 44 | data_sources.pop(idx) 45 | for idx, datasource in enumerate(data_sources): 46 | if not self.cfg.dataset.use_data.main and datasource['name'] == 'main': 47 | data_sources.pop(idx) 48 | for idx, datasource in enumerate(data_sources): 49 | if not self.cfg.dataset.use_data.ood and datasource['name'] == 'ood': 50 | data_sources.pop(idx) 51 | for i, datasource in enumerate(data_sources): 52 | if not hasattr(datasource,'val_mos_list_path'): 53 | continue 54 | self.domain_table[i] = datasource["name"] 55 | 56 | def forward(self, inputs): 57 | outputs = {} 58 | for feature_extractor in self.feature_extractors: 59 | outputs.update(feature_extractor(inputs)) 60 | x = outputs 61 | for output_layer in self.output_layers: 62 | x = output_layer(x,inputs) 63 | return x 64 | 65 | def training_step(self, batch, batch_idx): 66 | outputs = self(batch) 67 | loss = self.criterion(outputs, batch['score']) 68 | self.log( 69 | "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=self.cfg.train.train_batch_size 70 | ) 71 | return loss 72 | 73 | def validation_step(self, batch, batch_idx): 74 | outputs = self(batch) 75 | loss = self.criterion(outputs, batch['score']) 76 | if outputs.dim() > 1: 77 | outputs = outputs.mean(dim=1).squeeze(-1) 78 | return { 79 | "loss": loss, 80 | "outputs": outputs.cpu().numpy()[0]*2 +3.0, 81 | "filename": batch["wavname"][0], 82 | "domain": batch["domain"][0], 83 | "utt_avg_score": batch["utt_avg_score"][0].item(), 84 | "sys_avg_score": batch["sys_avg_score"][0].item() 85 | } 86 | 87 | def validation_epoch_end(self, outputs): 88 | val_loss = torch.stack([out["loss"] for out in outputs]).mean().item() 89 | self.log("val_loss", val_loss, on_epoch=True, prog_bar=True, logger=True) 90 | for domain_id in self.domain_table: 91 | outputs_domain = [out for out in outputs if out["domain"] == domain_id] 92 | if len(outputs_domain) == 0: 93 | continue 94 | _, SRCC, MSE = self.calc_score(outputs_domain) 95 | self.log( 96 | "val_SRCC_system_{}".format(self.domain_table[domain_id]), 97 | SRCC, 98 | on_epoch=True, 99 | prog_bar=True, 100 | logger=True 101 | ) 102 | self.log( 103 | "val_MSE_system_{}".format(self.domain_table[domain_id]), 104 | MSE, 105 | on_epoch=True, 106 | prog_bar=True, 107 | logger=True 108 | ) 109 | if domain_id == 0: 110 | self.log( 111 | "val_SRCC_system".format(self.domain_table[domain_id]), 112 | SRCC, 113 | on_epoch=True, 114 | prog_bar=True, 115 | logger=True 116 | ) 117 | 118 | def test_step(self, batch, batch_idx): 119 | outputs = self(batch) 120 | loss = self.criterion(outputs, batch['score']) 121 | labels = batch['score'] 122 | filenames = batch['wavname'] 123 | loss = self.criterion(outputs, labels) 124 | if outputs.dim() > 1: 125 | outputs = outputs.mean(dim=1).squeeze(-1) 126 | return { 127 | "loss": loss, 128 | "outputs": outputs.cpu().detach().numpy()[0]*2 +3.0, 129 | "labels": labels.cpu().detach().numpy()[0] *2 +3.0, 130 | "filename": filenames[0], 131 | "domain": batch["domain"][0], 132 | "i_cv": batch["i_cv"][0], 133 | "set_name": batch["set_name"][0], 134 | "utt_avg_score": batch["utt_avg_score"][0].item(), 135 | "sys_avg_score": batch["sys_avg_score"][0].item() 136 | } 137 | 138 | def test_epoch_end(self, outputs): 139 | outfiles = [datasource["outfile"] + '{}_{}'.format(outputs[0]['set_name'],outputs[0]['i_cv']) for datasource in self.cfg.dataset.data_sources if hasattr(datasource,'outfile')] 140 | for domain_id in self.domain_table: 141 | outputs_domain = [out for out in outputs if out["domain"] == domain_id] 142 | predictions, SRCC, MSE = self.calc_score(outputs_domain) 143 | self.log( 144 | "test_SRCC_SYS_{}_i_cv_{}_set_name_{}".format(self.domain_table[domain_id], outputs[0]['i_cv'], outputs[0]['set_name']), 145 | SRCC, 146 | ) 147 | if domain_id == 0: 148 | self.log( 149 | "test_SRCC_SYS".format(self.domain_table[domain_id]), 150 | SRCC 151 | ) 152 | with open(outfiles[domain_id], "w") as fw: 153 | for k, v in predictions.items(): 154 | outl = k.split(".")[0] + "," + str(v) + "\n" 155 | fw.write(outl) 156 | try: 157 | wandb.save(outfiles[domain_id]) 158 | except: 159 | print('outfile {} saved'.format(outfiles[domain_id])) 160 | 161 | def configure_optimizers(self): 162 | optimizer = hydra.utils.instantiate( 163 | self.cfg.train.optimizer, 164 | params=self.parameters() 165 | ) 166 | scheduler = hydra.utils.instantiate( 167 | self.cfg.train.scheduler, 168 | optimizer=optimizer 169 | ) 170 | scheduler = {"scheduler": scheduler, 171 | "interval": "step", "frequency": 1} 172 | return {"optimizer": optimizer, "lr_scheduler": scheduler} 173 | 174 | def configure_criterion(self): 175 | return hydra.utils.instantiate(self.cfg.train.criterion,_recursive_=True) 176 | 177 | def calc_score(self, outputs, verbose=False): 178 | 179 | def systemID(uttID): 180 | return uttID.split("-")[0] 181 | 182 | predictions = {} 183 | true_MOS = {} 184 | true_sys_MOS_avg = {} 185 | for out in outputs: 186 | predictions[out["filename"]] = out["outputs"] 187 | true_MOS[out["filename"]] = out["utt_avg_score"] 188 | true_sys_MOS_avg[out["filename"].split("-")[0]] = out["sys_avg_score"] 189 | 190 | ## compute correls. 191 | sorted_uttIDs = sorted(predictions.keys()) 192 | ts = [] 193 | ps = [] 194 | for uttID in sorted_uttIDs: 195 | t = true_MOS[uttID] 196 | p = predictions[uttID] 197 | ts.append(t) 198 | ps.append(p) 199 | 200 | truths = np.array(ts) 201 | print(ps) 202 | preds = np.array(ps) 203 | 204 | ### UTTERANCE 205 | MSE = np.mean((truths - preds) ** 2) 206 | LCC = np.corrcoef(truths, preds) 207 | SRCC = scipy.stats.spearmanr(truths.T, preds.T) 208 | KTAU = scipy.stats.kendalltau(truths, preds) 209 | if verbose: 210 | print("[UTTERANCE] Test error= %f" % MSE) 211 | print("[UTTERANCE] Linear correlation coefficient= %f" % LCC[0][1]) 212 | print("[UTTERANCE] Spearman rank correlation coefficient= %f" % SRCC[0]) 213 | print("[UTTERANCE] Kendall Tau rank correlation coefficient= %f" % KTAU[0]) 214 | 215 | ### SYSTEM 216 | pred_sys_MOSes = {} 217 | for uttID in sorted_uttIDs: 218 | sysID = systemID(uttID) 219 | noop = pred_sys_MOSes.setdefault(sysID, []) 220 | pred_sys_MOSes[sysID].append(predictions[uttID]) 221 | 222 | pred_sys_MOS_avg = {} 223 | for k, v in pred_sys_MOSes.items(): 224 | avg_MOS = sum(v) / (len(v) * 1.0) 225 | pred_sys_MOS_avg[k] = avg_MOS 226 | 227 | ## make lists sorted by system 228 | pred_sysIDs = sorted(pred_sys_MOS_avg.keys()) 229 | sys_p = [] 230 | sys_t = [] 231 | for sysID in pred_sysIDs: 232 | sys_p.append(pred_sys_MOS_avg[sysID]) 233 | sys_t.append(true_sys_MOS_avg[sysID]) 234 | 235 | sys_true = np.array(sys_t) 236 | sys_predicted = np.array(sys_p) 237 | 238 | MSE = np.mean((sys_true - sys_predicted) ** 2) 239 | LCC = np.corrcoef(sys_true, sys_predicted) 240 | SRCC = scipy.stats.spearmanr(sys_true.T, sys_predicted.T) 241 | KTAU = scipy.stats.kendalltau(sys_true, sys_predicted) 242 | if verbose: 243 | print("[SYSTEM] Test error= %f" % MSE) 244 | print("[SYSTEM] Linear correlation coefficient= %f" % LCC[0][1]) 245 | print("[SYSTEM] Spearman rank correlation coefficient= %f" % SRCC[0]) 246 | print("[SYSTEM] Kendall Tau rank correlation coefficient= %f" % KTAU[0]) 247 | 248 | return predictions, SRCC[0], MSE 249 | 250 | 251 | 252 | class DeepSpeedBaselineLightningModule(UTMOSLightningModule): 253 | def __init__(self, cfg): 254 | super().__init__(cfg) 255 | 256 | def configure_optimizers(self): 257 | from deepspeed.ops.adam import DeepSpeedCPUAdam 258 | return DeepSpeedCPUAdam(self.parameters()) -------------------------------------------------------------------------------- /strong/loss_function.py: -------------------------------------------------------------------------------- 1 | from numpy import dtype 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ContrastiveLoss(nn.Module): 6 | ''' 7 | Contrastive Loss 8 | Args: 9 | margin: non-neg value, the smaller the stricter the loss will be, default: 0.2 10 | 11 | ''' 12 | def __init__(self, margin=0.2): 13 | super(ContrastiveLoss, self).__init__() 14 | self.margin = margin 15 | 16 | def forward(self, pred_score, gt_score): 17 | if pred_score.dim() > 2: 18 | pred_score = pred_score.mean(dim=1).squeeze(1) 19 | # pred_score, gt_score: tensor, [batch_size] 20 | gt_diff = gt_score.unsqueeze(1) - gt_score.unsqueeze(0) 21 | pred_diff = pred_score.unsqueeze(1) - pred_score.unsqueeze(0) 22 | loss = torch.maximum(torch.zeros(gt_diff.shape).to(gt_diff.device), torch.abs(pred_diff - gt_diff) - self.margin) 23 | loss = loss.mean().div(2) 24 | return loss 25 | 26 | 27 | class ClippedMSELoss(nn.Module): 28 | """ 29 | clipped MSE loss for listener-dependent model 30 | """ 31 | def __init__(self, criterion,tau,mode='frame'): 32 | super(ClippedMSELoss, self).__init__() 33 | self.tau = torch.tensor(tau,dtype=torch.float) 34 | 35 | self.criterion = criterion 36 | self.mode = mode 37 | 38 | 39 | def forward_criterion(self, y_hat, label): 40 | 41 | y_hat = y_hat.squeeze(-1) 42 | loss = self.criterion(y_hat, label) 43 | threshold = torch.abs(y_hat - label) > self.tau 44 | loss = torch.mean(threshold * loss) 45 | return loss 46 | 47 | def forward(self, pred_score, gt_score): 48 | """ 49 | Args: 50 | pred_mean, pred_score: [batch, time, 1/5] 51 | """ 52 | # repeat for frame level loss 53 | time = pred_score.shape[1] 54 | if self.mode == 'utt': 55 | pred_score = pred_score.mean(dim=1) 56 | else: 57 | gt_score = gt_score.unsqueeze(1).repeat(1, time) 58 | main_loss = self.forward_criterion(pred_score, gt_score) 59 | return main_loss # lamb 1.0 60 | 61 | class CombineLosses(nn.Module): 62 | ''' 63 | Combine losses 64 | Args: 65 | loss_weights: a list of weights for each loss 66 | ''' 67 | def __init__(self, loss_weights:list, loss_instances:list): 68 | super(CombineLosses, self).__init__() 69 | self.loss_weights = loss_weights 70 | self.loss_instances = nn.ModuleList(loss_instances) 71 | def forward(self, pred_score, gt_score): 72 | loss = torch.tensor(0,dtype=torch.float).to(pred_score.device) 73 | for loss_weight, loss_instance in zip(self.loss_weights, self.loss_instances): 74 | loss += loss_weight * loss_instance(pred_score,gt_score) 75 | return loss 76 | -------------------------------------------------------------------------------- /strong/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from WavLM import WavLM, WavLMConfig 4 | from text.symbols import symbols 5 | import fairseq 6 | import os 7 | import hydra 8 | 9 | def load_ssl_model(cp_path): 10 | cp_path = os.path.join( 11 | hydra.utils.get_original_cwd(), 12 | cp_path 13 | ) 14 | ssl_model_type = cp_path.split("/")[-1] 15 | wavlm = "WavLM" in ssl_model_type 16 | if wavlm: 17 | checkpoint = torch.load(cp_path) 18 | cfg = WavLMConfig(checkpoint['cfg']) 19 | ssl_model = WavLM(cfg) 20 | ssl_model.load_state_dict(checkpoint['model']) 21 | if 'Large' in ssl_model_type: 22 | SSL_OUT_DIM = 1024 23 | else: 24 | SSL_OUT_DIM = 768 25 | else: 26 | if ssl_model_type == "wav2vec_small.pt": 27 | SSL_OUT_DIM = 768 28 | elif ssl_model_type in ["w2v_large_lv_fsh_swbd_cv.pt", "xlsr_53_56k.pt"]: 29 | SSL_OUT_DIM = 1024 30 | else: 31 | print("*** ERROR *** SSL model type " + ssl_model_type + " not supported.") 32 | exit() 33 | model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( 34 | [cp_path] 35 | ) 36 | ssl_model = model[0] 37 | ssl_model.remove_pretraining_modules() 38 | return SSL_model(ssl_model, SSL_OUT_DIM, wavlm) 39 | 40 | class SSL_model(nn.Module): 41 | def __init__(self,ssl_model,ssl_out_dim,wavlm) -> None: 42 | super(SSL_model,self).__init__() 43 | self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim 44 | self.WavLM = wavlm 45 | 46 | def forward(self,batch): 47 | wav = batch['wav'] 48 | wav = wav.squeeze(1) # [batches, audio_len] 49 | if self.WavLM: 50 | x = self.ssl_model.extract_features(wav)[0] 51 | else: 52 | res = self.ssl_model(wav, mask=False, features_only=True) 53 | x = res["x"] 54 | return {"ssl-feature":x} 55 | def get_output_dim(self): 56 | return self.ssl_out_dim 57 | 58 | 59 | class PhonemeEncoder(nn.Module): 60 | ''' 61 | PhonemeEncoder consists of an embedding layer, an LSTM layer, and a linear layer. 62 | Args: 63 | vocab_size: the size of the vocabulary 64 | hidden_dim: the size of the hidden state of the LSTM 65 | emb_dim: the size of the embedding layer 66 | out_dim: the size of the output of the linear layer 67 | n_lstm_layers: the number of LSTM layers 68 | ''' 69 | def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim,n_lstm_layers,with_reference=True) -> None: 70 | super().__init__() 71 | self.with_reference = with_reference 72 | self.embedding = nn.Embedding(vocab_size, emb_dim) 73 | self.encoder = nn.LSTM(emb_dim, hidden_dim, 74 | num_layers=n_lstm_layers, dropout=0.1, bidirectional=True) 75 | self.linear = nn.Sequential( 76 | nn.Linear(hidden_dim + hidden_dim*self.with_reference, out_dim), 77 | nn.ReLU() 78 | ) 79 | self.out_dim = out_dim 80 | 81 | def forward(self,batch): 82 | seq = batch['phonemes'] 83 | lens = batch['phoneme_lens'] 84 | reference_seq = batch['reference'] 85 | reference_lens = batch['reference_lens'] 86 | emb = self.embedding(seq) 87 | emb = torch.nn.utils.rnn.pack_padded_sequence( 88 | emb, lens, batch_first=True, enforce_sorted=False) 89 | _, (ht, _) = self.encoder(emb) 90 | feature = ht[-1] + ht[0] 91 | if self.with_reference: 92 | if reference_seq==None or reference_lens ==None: 93 | raise ValueError("reference_batch and reference_lens should not be None when with_reference is True") 94 | reference_emb = self.embedding(reference_seq) 95 | reference_emb = torch.nn.utils.rnn.pack_padded_sequence( 96 | reference_emb, reference_lens, batch_first=True, enforce_sorted=False) 97 | _, (ht_ref, _) = self.encoder(emb) 98 | reference_feature = ht_ref[-1] + ht_ref[0] 99 | feature = self.linear(torch.cat([feature,reference_feature],1)) 100 | else: 101 | feature = self.linear(feature) 102 | return {"phoneme-feature": feature} 103 | def get_output_dim(self): 104 | return self.out_dim 105 | 106 | class DomainEmbedding(nn.Module): 107 | def __init__(self,n_domains,domain_dim) -> None: 108 | super().__init__() 109 | self.embedding = nn.Embedding(n_domains,domain_dim) 110 | self.output_dim = domain_dim 111 | def forward(self, batch): 112 | return {"domain-feature": self.embedding(batch['domains'])} 113 | def get_output_dim(self): 114 | return self.output_dim 115 | 116 | 117 | class LDConditioner(nn.Module): 118 | ''' 119 | Conditions ssl output by listener embedding 120 | ''' 121 | def __init__(self,input_dim, judge_dim, num_judges=None): 122 | super().__init__() 123 | self.input_dim = input_dim 124 | self.judge_dim = judge_dim 125 | self.num_judges = num_judges 126 | assert num_judges !=None 127 | self.judge_embedding = nn.Embedding(num_judges, self.judge_dim) 128 | # concat [self.output_layer, phoneme features] 129 | 130 | self.decoder_rnn = nn.LSTM( 131 | input_size = self.input_dim + self.judge_dim, 132 | hidden_size = 512, 133 | num_layers = 1, 134 | batch_first = True, 135 | bidirectional = True 136 | ) # linear? 137 | self.out_dim = self.decoder_rnn.hidden_size*2 138 | 139 | def get_output_dim(self): 140 | return self.out_dim 141 | 142 | 143 | def forward(self, x, batch): 144 | judge_ids = batch['judge_id'] 145 | if 'phoneme-feature' in x.keys(): 146 | concatenated_feature = torch.cat((x['ssl-feature'], x['phoneme-feature'].unsqueeze(1).expand(-1,x['ssl-feature'].size(1) ,-1)),dim=2) 147 | else: 148 | concatenated_feature = x['ssl-feature'] 149 | if 'domain-feature' in x.keys(): 150 | concatenated_feature = torch.cat( 151 | ( 152 | concatenated_feature, 153 | x['domain-feature'] 154 | .unsqueeze(1) 155 | .expand(-1, concatenated_feature.size(1), -1), 156 | ), 157 | dim=2, 158 | ) 159 | if judge_ids != None: 160 | concatenated_feature = torch.cat( 161 | ( 162 | concatenated_feature, 163 | self.judge_embedding(judge_ids) 164 | .unsqueeze(1) 165 | .expand(-1, concatenated_feature.size(1), -1), 166 | ), 167 | dim=2, 168 | ) 169 | decoder_output, (h, c) = self.decoder_rnn(concatenated_feature) 170 | return decoder_output 171 | 172 | class Projection(nn.Module): 173 | def __init__(self, input_dim, hidden_dim, activation, range_clipping=False): 174 | super(Projection, self).__init__() 175 | self.range_clipping = range_clipping 176 | output_dim = 1 177 | if range_clipping: 178 | self.proj = nn.Tanh() 179 | 180 | self.net = nn.Sequential( 181 | nn.Linear(input_dim, hidden_dim), 182 | activation, 183 | nn.Dropout(0.3), 184 | nn.Linear(hidden_dim, output_dim), 185 | ) 186 | self.output_dim = output_dim 187 | 188 | def forward(self, x, batch): 189 | output = self.net(x) 190 | 191 | # range clipping 192 | if self.range_clipping: 193 | return self.proj(output) * 2.0 + 3 194 | else: 195 | return output 196 | def get_output_dim(self): 197 | return self.output_dim 198 | -------------------------------------------------------------------------------- /strong/param_tuning.py: -------------------------------------------------------------------------------- 1 | from lib2to3.pytree import Base 2 | from pytorch_lightning.loggers.csv_logs import CSVLogger 3 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger 4 | from pytorch_lightning.callbacks import ModelCheckpoint 5 | from pytorch_lightning import Trainer 6 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 7 | from pytorch_lightning.callbacks import LearningRateMonitor 8 | import hydra 9 | from dataset import DataModule 10 | import wandb 11 | 12 | @hydra.main(config_path="configs",config_name='optuna-main') 13 | def train(cfg): 14 | debug = cfg.debug 15 | 16 | if cfg.batch_size_and_model == "wav2vec2-base-4": 17 | cfg.model.feature_extractors[0]["cp_path"] = "fairseq/wav2vec_small.pt" 18 | cfg.train.train_batch_size = 4 19 | elif cfg.batch_size_and_model == "wav2vec2-base-8": 20 | cfg.model.feature_extractors[0]["cp_path"] = "fairseq/wav2vec_small.pt" 21 | cfg.train.train_batch_size = 8 22 | elif cfg.batch_size_and_model == "wav2vec2-base-16": 23 | cfg.model.feature_extractors[0]["cp_path"] = "fairseq/wav2vec_small.pt" 24 | cfg.train.train_batch_size = 16 25 | elif cfg.batch_size_and_model == "wav2vec2-base-32": 26 | cfg.model.feature_extractors[0]["cp_path"] = "fairseq/wav2vec_small.pt" 27 | cfg.train.train_batch_size = 32 28 | elif cfg.batch_size_and_model == "wavlm-large-4": 29 | cfg.model.feature_extractors[0]["cp_path"] = "fairseq/WavLM-Large.pt" 30 | cfg.train.train_batch_size = 4 31 | print(cfg.batch_size_and_model) 32 | print(cfg.model.feature_extractors[0]["cp_path"]) 33 | print(cfg.train.train_batch_size) 34 | 35 | if cfg.dataset.use_data.main: 36 | cfg.dataset.data_sources.pop(0) 37 | if cfg.dataset.use_data.ood: 38 | cfg.dataset.data_sources.pop(1) 39 | if cfg.dataset.use_data.external: 40 | cfg.dataset.data_sources.pop(2) 41 | 42 | loggers = [] 43 | loggers.append(CSVLogger(save_dir=cfg.train.out_dir, name="train_log")) 44 | loggers.append(TensorBoardLogger(save_dir=cfg.train.out_dir, name="tf_log")) 45 | if cfg.train.use_wandb: 46 | loggers.append(WandbLogger(project="voicemos",offline=debug)) 47 | 48 | checkpoint_callback = ModelCheckpoint( 49 | dirpath=cfg.train.out_dir, 50 | save_weights_only=True, 51 | save_top_k=1, 52 | save_last=False, 53 | every_n_epochs=1, 54 | monitor=cfg.train.model_selection_metric, 55 | mode='max' 56 | ) 57 | lr_monitor = LearningRateMonitor(logging_interval='step') 58 | callbacks = [checkpoint_callback,lr_monitor] 59 | earlystop_callback = EarlyStopping( 60 | monitor="val_loss", min_delta=0.0, patience=cfg.train.early_stopping.patience, mode="min" 61 | ) 62 | callbacks.append(earlystop_callback) 63 | 64 | trainer = Trainer( 65 | **cfg.train.trainer_args, 66 | default_root_dir=hydra.utils.get_original_cwd(), 67 | limit_train_batches=0.01 if debug else 1.0, 68 | limit_val_batches=0.5 if debug else 1.0, 69 | callbacks=callbacks, 70 | logger=loggers, 71 | ) 72 | 73 | lightning_module = hydra.utils.instantiate(cfg.model.lightning_module,cfg=cfg , _recursive_=False) 74 | wandblogger.watch(lightning_module) 75 | datamodule = hydra.utils.instantiate(cfg.dataset.datamodule,cfg=cfg,_recursive_=False) 76 | trainer.fit(lightning_module,datamodule=datamodule) 77 | trainer.test(lightning_module, datamodule=datamodule,ckpt_path=checkpoint_callback.best_model_path) 78 | 79 | 80 | SRCC_system = trainer.logged_metrics[cfg.tuning_target] 81 | wandb.finish() 82 | 83 | return SRCC_system 84 | 85 | if __name__ == "__main__": 86 | train() 87 | -------------------------------------------------------------------------------- /strong/predict.py: -------------------------------------------------------------------------------- 1 | from omegaconf import open_dict 2 | from pytorch_lightning import Trainer 3 | import hydra 4 | import os 5 | import pathlib 6 | from lightning_module import UTMOSLightningModule 7 | from dataset import TestDataModule, DataModule 8 | import torch 9 | 10 | @hydra.main(config_path="configs",config_name='default') 11 | def predict(cfg): 12 | """ 13 | Specify ckeckpoint path as follows: 14 | 15 | python predict.py +ckpt_path="outputs/${date}/${time}/train_outputs/hoge.ckpt" 16 | """ 17 | 18 | trainer = Trainer( 19 | **cfg.train.trainer_args, 20 | default_root_dir=hydra.utils.get_original_cwd(), 21 | ) 22 | 23 | ckpt_path = pathlib.Path(cfg.ckpt_path) 24 | if not ckpt_path.is_absolute(): 25 | ckpt_path = (pathlib.Path(hydra.utils.get_original_cwd()) / ckpt_path) 26 | 27 | if 'paper_weights' in cfg.keys(): 28 | with open_dict(cfg): 29 | ckpt = torch.load(ckpt_path) 30 | use_data = ckpt['hyper_parameters']['cfg']['dataset']['use_data'] 31 | cfg.dataset.use_data['external'] = use_data['lancers'] 32 | cfg.dataset.use_data['main'] = use_data['main'] 33 | cfg.dataset.use_data['ood'] = use_data['ood'] 34 | cfg.dataset.only_mean = ckpt['hyper_parameters']['cfg']['dataset']['only_mean'] 35 | lightning_module = UTMOSLightningModule.load_from_checkpoint(ckpt_path,cfg=cfg,paper_weight=cfg.paper_weights) 36 | lightning_module.cfg 37 | else: 38 | lightning_module = UTMOSLightningModule.load_from_checkpoint(ckpt_path) 39 | print(lightning_module.cfg) 40 | datamodule = DataModule(lightning_module.cfg) 41 | test_datamodule = TestDataModule(cfg=lightning_module.cfg, i_cv=0, set_name='test') 42 | trainer.test( 43 | lightning_module, 44 | verbose=True, 45 | datamodule=datamodule, 46 | ckpt_path=ckpt_path 47 | ) 48 | trainer.test( 49 | lightning_module, 50 | verbose=True, 51 | datamodule=test_datamodule, 52 | ckpt_path=ckpt_path 53 | ) 54 | 55 | if __name__ == "__main__": 56 | predict() 57 | -------------------------------------------------------------------------------- /strong/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | ''' 6 | _pad = '_' 7 | _punctuation = ';:,.!?¡¿—…"«»“” ' 8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 9 | _numbers = '0123456789' 10 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ'̪'̃" 11 | 12 | 13 | # Export all symbols: 14 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + list(_numbers) 15 | 16 | # Special symbol ids 17 | SPACE_ID = symbols.index(" ") -------------------------------------------------------------------------------- /strong/train.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pytorch_lightning.loggers.csv_logs import CSVLogger 3 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger 4 | from pytorch_lightning.callbacks import ModelCheckpoint 5 | from pytorch_lightning import Trainer, seed_everything 6 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 7 | from pytorch_lightning.callbacks import LearningRateMonitor 8 | from dataset import CVDataModule, TestDataModule 9 | from lightning_module import UTMOSLightningModule 10 | import hydra 11 | import wandb 12 | 13 | 14 | 15 | @hydra.main(config_path="configs",config_name='default') 16 | def train(cfg): 17 | debug = cfg.debug 18 | if debug: 19 | cfg.train.train_batch_size=4 20 | cfg.train.trainer_args.max_steps=10 21 | 22 | loggers = [] 23 | loggers.append(CSVLogger(save_dir=cfg.train.out_dir, name="train_log")) 24 | loggers.append(TensorBoardLogger(save_dir=cfg.train.out_dir, name="tf_log")) 25 | if cfg.train.use_wandb: 26 | loggers.append(WandbLogger(project="voicemos",offline=debug)) 27 | 28 | checkpoint_callback = ModelCheckpoint( 29 | dirpath=cfg.train.out_dir, 30 | save_weights_only=False, 31 | save_top_k=1, 32 | save_last=True, 33 | every_n_epochs=1, 34 | monitor=cfg.train.model_selection_metric, 35 | mode='max' 36 | ) 37 | lr_monitor = LearningRateMonitor(logging_interval='step') 38 | callbacks = [checkpoint_callback,lr_monitor] 39 | earlystop_callback = EarlyStopping( 40 | monitor="val_loss", min_delta=0.0, patience=cfg.train.early_stopping.patience, mode="min" 41 | ) 42 | callbacks.append(earlystop_callback) 43 | 44 | trainer = Trainer( 45 | **cfg.train.trainer_args, 46 | default_root_dir=hydra.utils.get_original_cwd(), 47 | limit_train_batches=0.01 if debug else 1.0, 48 | limit_val_batches=0.5 if debug else 1.0, 49 | callbacks=callbacks, 50 | logger=loggers, 51 | ) 52 | 53 | datamodule = hydra.utils.instantiate(cfg.dataset.datamodule,cfg=cfg,_recursive_=False) 54 | test_datamodule = TestDataModule(cfg=cfg, i_cv=0, set_name='test') 55 | lightning_module = UTMOSLightningModule(cfg) 56 | 57 | trainer.fit(lightning_module, datamodule=datamodule) 58 | 59 | if debug: 60 | trainer.test(lightning_module, datamodule=datamodule) 61 | trainer.test(lightning_module, datamodule=test_datamodule) 62 | else: 63 | trainer.test(lightning_module, datamodule=datamodule,ckpt_path=checkpoint_callback.best_model_path) 64 | trainer.test(lightning_module, datamodule=test_datamodule,ckpt_path=checkpoint_callback.best_model_path) 65 | if cfg.train.use_wandb: 66 | wandb.save(checkpoint_callback.best_model_path) 67 | 68 | if __name__ == "__main__": 69 | train() 70 | -------------------------------------------------------------------------------- /strong/transcribe_speech.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC 3 | from datasets import load_dataset 4 | import soundfile as sf 5 | import torch 6 | from datasets import set_caching_enabled 7 | import pandas as pd 8 | import Levenshtein 9 | import numpy as np 10 | from sklearn.cluster import DBSCAN 11 | 12 | def cluster_transcriptions(df:pd.DataFrame): 13 | data = df['transcription'].to_list() 14 | def lev_metric(x, y): 15 | i, j = int(x[0]), int(y[0]) # extract indices 16 | return Levenshtein.distance(data[i], data[j])/max(len(data[i]), len(data[j])) 17 | 18 | X = np.arange(len(data)).reshape(-1, 1) 19 | result = DBSCAN(eps=0.3, metric=lev_metric,n_jobs=20,min_samples=3).fit(X) 20 | df['cluster'] = result.labels_ 21 | text_medians = df.groupby('cluster').apply(lambda x:Levenshtein.median(x['transcription'].to_list())) 22 | medians = [] 23 | for idx, row in df.iterrows(): 24 | if row['cluster'] == -1: 25 | medians.append(row['transcription']) 26 | else: 27 | medians.append(text_medians[row['cluster']]) 28 | df['reference'] = medians 29 | return df 30 | 31 | 32 | 33 | if __name__ == '__main__': 34 | set_caching_enabled(False) 35 | # load datasets 36 | 37 | processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft") 38 | model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft") 39 | from tqdm import tqdm 40 | _ = model.to('cuda') 41 | batch_size = 1 42 | all_df = pd.DataFrame() 43 | def collate_fn(batch): 44 | import numpy as np 45 | wav_name = batch[0]['audio']['path'] 46 | wav_array = batch[0]['audio']['array'] 47 | return wav_name, processor(wav_array,sampling_rate=16_000,return_tensors="pt").input_values 48 | for track in 'main', 'ood', 'unlabeled': 49 | wav_names = [] 50 | transcriptions = [] 51 | if track == 'main': 52 | dataset = load_dataset("sarulab-speech/bvcc-voicemos2022","main_track", data_dir="data/phase1-main/", use_auth_token=True,download_mode="force_redownload") 53 | elif track == 'ood': 54 | dataset = load_dataset("sarulab-speech/bvcc-voicemos2022","ood_track", data_dir="data/phase1-ood/", use_auth_token=True,download_mode="force_redownload") 55 | else: 56 | dataset = load_dataset("sarulab-speech/bvcc-voicemos2022","ood_track_unlabeled", data_dir="data/phase1-ood/", use_auth_token=True,download_mode='force_redownload') 57 | for stage in 'train', 'validation', 'test': 58 | if track == 'unlabeled' and stage != 'train': 59 | continue 60 | print('Processing {track} track {stage}'.format(track=track, stage=stage)) 61 | dl = torch.utils.data.DataLoader(dataset[stage], batch_size=batch_size, num_workers=4,collate_fn=collate_fn) 62 | for wav_name, data in tqdm(dl): 63 | 64 | # retrieve logits 65 | with torch.no_grad(): 66 | logits = model(data.to('cuda')).logits 67 | 68 | # take argmax and decode 69 | predicted_ids = torch.argmax(logits, dim=-1) 70 | transcription = processor.batch_decode(predicted_ids) 71 | transcriptions.extend(transcription) 72 | wav_names.append(wav_name) 73 | del(dataset) 74 | import pandas as pd 75 | df = pd.DataFrame({"wav_name": wav_names, "transcription": transcriptions}) 76 | df['wav_name'] = df['wav_name'].apply(lambda x: x.split("/")[-1]) 77 | df['track'] = track if track != 'unlabeled' else 'ood' 78 | all_df = pd.concat([all_df, df],ignore_index=True) 79 | result = pd.concat( 80 | [ 81 | cluster_transcriptions(all_df[all_df['track'] == 'main'].copy()), 82 | cluster_transcriptions(all_df[all_df['track'] == 'ood'].copy()), 83 | ], 84 | ignore_index=True 85 | ) 86 | result.to_csv('transcriptions_clustered.csv'.format(track), index=False) 87 | --------------------------------------------------------------------------------