├── .gitignore
├── LICENSE
├── README.md
├── fairseq_checkpoints
├── download_stacking_checkpoints.sh
└── download_strong_checkpoints.sh
├── poetry.lock
├── pyproject.toml
├── remove_silenceWav.py
├── stacking
├── ensemble_multidomain_scripts
│ ├── README.md
│ ├── WavLM.py
│ ├── calc_result.py
│ ├── calc_testphase_result.py
│ ├── collect_stage1_result.py
│ ├── collect_stage1_testphase_result.py
│ ├── collect_stage2_result.py
│ ├── collect_stage2_testphase_result.py
│ ├── convert_strong_learner_result.py
│ ├── convert_strong_learner_testphase_result.py
│ ├── data_util.py
│ ├── external_mos_list.txt
│ ├── extract_ssl_feature.py
│ ├── gp_models.py
│ ├── make_ensemble_dataset.py
│ ├── make_ensemble_dataset_wotest.py
│ ├── make_ensemble_testphase.py
│ ├── models.py
│ ├── modules.py
│ ├── opt_stage1.py
│ ├── opt_stage2.py
│ ├── opt_stage3.py
│ ├── pred_stage1.py
│ ├── pred_stage1_ood.sh
│ ├── pred_testphase_stage1.py
│ ├── pred_testphase_stage1_main.sh
│ ├── pred_testphase_stage1_ood.sh
│ ├── pred_testphase_stage2-3_main.sh
│ ├── pred_testphase_stage2-3_ood.sh
│ ├── pred_testphase_stage2.py
│ ├── pred_testphase_stage3.py
│ ├── run_stage1.py
│ ├── run_stage1.sh
│ ├── run_stage2-3_main.sh
│ ├── run_stage2-3_ood.sh
│ ├── run_stage2.py
│ ├── run_stage3.py
│ ├── stage2-method
│ │ ├── main-strong1-weak48.yaml
│ │ └── ood-strong1-weak144.yaml
│ └── unused
│ │ ├── opt_stage1_all.sh
│ │ ├── pred_stage1_exactgp.sh
│ │ ├── run_stage1_exactgp.sh
│ │ ├── run_stage1_other.sh
│ │ └── run_stage2-end_opt.sh
└── strong_learner_result
│ ├── main1
│ ├── answer-main.csvfold_0
│ ├── answer-main.csvfold_1
│ ├── answer-main.csvfold_2
│ ├── answer-main.csvfold_3
│ ├── answer-main.csvfold_4
│ ├── answer-main.csvtest_0
│ ├── answer-main.csvtest_1
│ ├── answer-main.csvtest_2
│ ├── answer-main.csvtest_3
│ ├── answer-main.csvtest_4
│ ├── answer-main.csvval_0
│ ├── answer-main.csvval_1
│ ├── answer-main.csvval_2
│ ├── answer-main.csvval_3
│ └── answer-main.csvval_4
│ └── ood1
│ ├── answer-ood.csvfold_0
│ ├── answer-ood.csvfold_1
│ ├── answer-ood.csvfold_2
│ ├── answer-ood.csvtest_0
│ ├── answer-ood.csvtest_1
│ ├── answer-ood.csvtest_2
│ ├── answer-ood.csvval_0
│ ├── answer-ood.csvval_1
│ └── answer-ood.csvval_2
└── strong
├── README.md
├── TRAINSET_external.txt
├── WavLM.py
├── configs
├── cv.yaml
├── dataset
│ ├── cv.yaml
│ ├── default.yaml
│ └── wo_augment.yaml
├── default.yaml
├── model
│ ├── default.yaml
│ └── wo_phoneme.yaml
├── optuna-main.yaml
├── optuna-ood.yaml
└── train
│ └── default.yaml
├── cross_validation.py
├── data
├── data_augment.py
├── dataset.py
├── lightning_module.py
├── loss_function.py
├── model.py
├── modules.py
├── param_tuning.py
├── predict.py
├── text
└── symbols.py
├── train.py
├── transcribe_speech.py
└── transcriptions_clustered.csv
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
4 |
5 | ### Python ###
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
110 | __pypackages__/
111 |
112 | # Celery stuff
113 | celerybeat-schedule
114 | celerybeat.pid
115 |
116 | # SageMath parsed files
117 | *.sage.py
118 |
119 | # Environments
120 | .env
121 | .venv
122 | env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 |
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 |
132 | # Rope project settings
133 | .ropeproject
134 |
135 | # mkdocs documentation
136 | /site
137 |
138 | # mypy
139 | .mypy_cache/
140 | .dmypy.json
141 | dmypy.json
142 |
143 | # Pyre type checker
144 | .pyre/
145 |
146 | # pytype static type analyzer
147 | .pytype/
148 |
149 | # Cython debug symbols
150 | cython_debug/
151 |
152 | # PyCharm
153 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
154 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
155 | # and can be added to the global gitignore or merged into this file. For a more nuclear
156 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
157 | #.idea/
158 |
159 | # End of https://www.toptal.com/developers/gitignore/api/python
160 | *.pt
161 | *.pth
162 | outputs/
163 | stacking/out/
164 | stacking/data
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Saruwatari&Koyama laboratory, The University of Tokyo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UTMOS: UTokyo-SaruLab MOS Prediction System
2 |
3 | Official implementation of ["UTMOS: UTokyo-SaruLab System for VoiceMOS Challenge 2022"](https://arxiv.org/abs/2204.02152) accepted by INTERSPEECH 2022.
4 |
5 | >**Abstract:**
6 | We present the UTokyo-SaruLab mean opinion score (MOS) prediction system submitted to VoiceMOS Challenge 2022. The challenge is to predict the MOS values of speech samples collected from previous Blizzard Challenges and Voice Conversion Challenges for two tracks: a main track for in-domain prediction and an out-of-domain (OOD) track for which there is less labeled data from different listening tests. Our system is based on ensemble learning of strong and weak learners. Strong learners incorporate several improvements to the previous fine-tuning models of self-supervised learning (SSL) models, while weak learners use basic machine-learning methods to predict scores from SSL features.
7 | In the Challenge, our system had the highest score on several metrics for both the main and OOD tracks. In addition, we conducted ablation studies to investigate the effectiveness of our proposed methods.
8 |
9 | 🏆 Our system achieved the 1st places in 10/16 metrics at [the VoiceMOS Challenge 2022](https://voicemos-challenge-2022.github.io/)!
10 |
11 | Demo for UTMOS is available: [](https://huggingface.co/spaces/sarulab-speech/UTMOS-demo)
12 |
13 | ## Quick Prediction
14 | You can simply use a pretrained UTMOS strong learner trained on the VoiceMOS Challenge 2022 Main Track Dataset. We support both single and batch processings in a [NISQA](https://github.com/gabrielmittag/NISQA)-like interface.
15 |
16 | Git clone the Hugging Face repo:
17 | ```
18 | git clone https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
19 | cd UTMOS-demo
20 | pip install -r requirements.txt
21 | ```
22 |
23 | To predict the MOS of a single wav file:
24 | ```
25 | python predict.py --mode predict_file --inp_path /path/to/wav/file.wav --out_path /path/to/csv/file.csv
26 | ```
27 |
28 | To predict the MOS of all .wav files in a folder use:
29 | ```
30 | python predict.py --mode predict_dir --inp_dir /path/to/wav/dir/ --bs --out_path /path/to/csv/file.csv
31 | ```
32 |
33 | ## How to use the whole functionality
34 |
35 | ### Enviornment setup
36 |
37 | 1. This repo uses poetry as the python envoirnmet manager. Install poetry following [this instruction](https://python-poetry.org/docs/#installation) first.
38 | 1. Install required python packages using `poetry install`. And enter the python enviornment with `poetry shell`. All following operations **requires** to be inside the poetry shell enviornment.
39 | 1. Second, download necessary fairseq checkpoint using [download_strong_checkpoints.sh](fairseq_checkpoints/download_strong_checkpoints.sh) for strong and [download_stacking_checkpoints.sh](fairseq_checkpoints/download_stacking_checkpoints.sh) for stacking.
40 | 1. Next, run the following command to exclude bad wav file from main track training set.
41 | The original data will be saved with `.bak` suffix.
42 | ```shell
43 | python remove_silenceWav.py --path_to_dataset path-to-dataset/phase1-main/
44 | ```
45 |
46 | ## Model training
47 | Our system predicts MOS with small errors by stacking of strong and weak learners.
48 | - To run training and inference with a single strong learner, see [strong/README.md](strong/README.md).
49 | - To run stacking, see [stacking/ensemble_multidomain_scripts/README.md](stacking/ensemble_multidomain_scripts/README.md).
50 |
51 | If you encounter any problems regarding running the code, feel free to submit an issue. The code is not fully tested.
52 |
--------------------------------------------------------------------------------
/fairseq_checkpoints/download_stacking_checkpoints.sh:
--------------------------------------------------------------------------------
1 | gdown --fuzzy https://drive.google.com/file/d/19-C7SMQvEFAYLG5uc47NX_MY03JCbI4x/view?usp=sharing
2 | gdown --fuzzy https://drive.google.com/file/d/1PlbT_9_B4F9BsD_ija84sUTVw7almNX8/view?usp=sharing
3 | gdown --fuzzy https://drive.google.com/file/d/1rMu6PQ9vz3qPz4oIm72JDuIr5AHIbCOb/view?usp=sharing
4 | wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt -O wav2vec_small.pt
5 | wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt -O wav2vec_vox_new.pt
6 | wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/w2v_large_lv_fsh_swbd_cv.pt -O w2v_large_lv_fsh_swbd_cv.pt
7 | wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt -O xlsr_53_56k.pt
8 | wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt -O hubert_base_ls960.pt
9 | wget https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt -O hubert_large_ll60k.pt
10 |
--------------------------------------------------------------------------------
/fairseq_checkpoints/download_strong_checkpoints.sh:
--------------------------------------------------------------------------------
1 | wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt -O wav2vec_small.pt
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "utmos22"
3 | version = "0.1.0"
4 | description = "UTokyo-Sarulab submission for VoiceMOS challenge 2022"
5 | authors = [
6 | "Takaaki Saeki ",
7 | "Xin Detai ",
8 | "Wataru Nakata ",
9 | "Tomoki Koriyama ",
10 | ]
11 |
12 | [tool.poetry.dependencies]
13 | python = "^3.8"
14 | transformers = "^4.15.0"
15 | torch = "1.11"
16 | datasets = "^1.17.0"
17 | matplotlib = "^3.5.1"
18 | seaborn = "^0.11.2"
19 | scikit-learn = "^1.0.2"
20 | pytorch-lightning = "^1.5.8"
21 | phonemizer = "^3.0.1"
22 | SoundFile = "^0.10.3"
23 | python-Levenshtein = "^0.12.2"
24 | wandb = "^0.12.10"
25 | hydra-core = "^1.1.1"
26 | hydra-optuna-sweeper = "^1.1.2"
27 | augment = {git = "https://github.com/facebookresearch/WavAugment.git", rev = "main"}
28 | fairseq = {git = "https://github.com/sarulab-speech/fairseq.git", rev = "for_utmos"}
29 | pandas = "^1.4.1"
30 | gpytorch = "^1.6.0"
31 | lightgbm = "^3.3.2"
32 | gdown = "^4.4.0"
33 |
34 |
35 | [tool.poetry.dev-dependencies]
36 |
37 | [build-system]
38 | requires = ["poetry-core>=1.0.0"]
39 | build-backend = "poetry.core.masonry.api"
40 |
--------------------------------------------------------------------------------
/remove_silenceWav.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import shutil
4 |
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--path_to_dataset', type=str, required=True)
8 | args = parser.parse_args()
9 | train_set_path = Path(args.path_to_dataset)/'DATA/sets/TRAINSET'
10 | train_mos_list_path = Path(args.path_to_dataset)/'DATA/sets/train_mos_list.txt'
11 | shutil.copyfile( train_mos_list_path, train_mos_list_path.with_suffix('.txt.bak'))
12 | shutil.copyfile( train_set_path, train_set_path.with_suffix('.bak'))
13 | with open(train_mos_list_path) as f:
14 | lines = f.readlines()
15 | new_lines =[]
16 | for line in lines:
17 | if 'sys4bafa-uttc2e86f6.wav' in line:
18 | continue
19 | new_lines.append(line)
20 | with open(train_mos_list_path, 'w') as f:
21 | f.writelines(new_lines)
22 | with open(train_set_path) as f:
23 | lines = f.readlines()
24 | new_lines =[]
25 | for line in lines:
26 | if 'sys4bafa-uttc2e86f6.wav' in line:
27 | continue
28 | new_lines.append(line)
29 | with open(train_set_path, 'w') as f:
30 | f.writelines(new_lines)
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/README.md:
--------------------------------------------------------------------------------
1 | # Stacking of strong and weak learners
2 |
3 | ## Data split
4 | Firstly link `data` to `../data`.
5 | Then run the following commands.
6 | ```shell
7 | python make_ensemble_dataset.py --datatrack phase1-main
8 | python make_ensemble_dataset.py --datatrack phase1-ood
9 | python make_ensemble_dataset_wotest.py --datatrack external
10 | python make_ensemble_testphase.py --datatrack phase1-main
11 | python make_ensemble_testphase.py --datatrack phase1-odd
12 | ```
13 |
14 | ## Feature extraction with SSL model
15 | Place the ckpt file of the pretrained model to `../pretrained_model`.
16 | Then run the following command.
17 | ```shell
18 | python extract_ssl_feature.py
19 | ```
20 |
21 | ## Converting results of strong learners for stacking
22 | Place the respective result files to `../strong_learner_result/main1` and `../strong_learner_result/ood1`.
23 | Then run the following commands.
24 | ```shell
25 | python convert_strong_learner_result.py phase1-main main1
26 | python convert_strong_learner_result.py phase1-ood ood1
27 | python convert_strong_learner_testphase_result.py testphase-main main1
28 | python convert_strong_learner_testphase_result.py testphase-ood ood1
29 | ```
30 |
31 | ## Stage1
32 | For both main and OOD tracks, run the following command to perform stage1.
33 | ```shell
34 | ./run_stage1.sh
35 | ```
36 |
37 | ## Stage2 and 3 for Main track
38 | Run the following commands.
39 | ```shell
40 | ./run_stage2-3_main.sh # Run stage 2 and 3
41 | ./pred_testphase_stage1_main.sh # Predict stage 1
42 | ./pred_testphase_stage2-3_main.sh # Predict stage 2 and 3
43 | ```
44 |
45 | ## Stage 2 and 3 for OOD track
46 | Run the following commands.
47 | ```shell
48 | ./pred_stage1_ood.sh # Predict by cross-domain
49 | ./run_stage2-3_ood.sh # Run stage 2 and 3
50 | ./pred_testphase_stage1_ood.sh # Predict stage 1
51 | ./pred_testphase_stage2-3_ood.sh # Predict stage 2 and 3
52 | ```
53 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/WavLM.py:
--------------------------------------------------------------------------------
1 | ../../strong/WavLM.py
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/calc_result.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 |
5 | import itertools
6 |
7 | import yaml
8 | import numpy as np
9 | import pandas as pd
10 | import scipy
11 | import scipy.stats
12 |
13 | def get_arg():
14 | import argparse
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('datatrack')
17 | parser.add_argument('feat_type', default='weak36')
18 | return parser.parse_args()
19 |
20 | K_CV = 5
21 |
22 | def evaluate_val(df_val):
23 |
24 | assert len(df_val) == len(df_val.index.unique())
25 |
26 | mse = np.mean(np.square(df_val['true_mos'] - df_val['pred_mos']))
27 | utt_srcc = scipy.stats.spearmanr(df_val['true_mos'], df_val['pred_mos'])[0]
28 | print('CV UTT MSE: {:f}'.format(mse))
29 | print('CV UTT SRCC: {:f}'.format(utt_srcc))
30 |
31 | # sys
32 | df_val['system_ID'] = df_val.index.str.extract(r'^(.+?)-').values
33 |
34 | df_val_sys = df_val.groupby('system_ID')['pred_mos'].mean()
35 | df_true_sys = df_val.groupby('system_ID')['true_mos'].mean()
36 |
37 | df_sys = pd.merge(df_val_sys, df_true_sys, on='system_ID', how='left')
38 |
39 | sys_mse = np.mean(np.square(df_sys['true_mos'] - df_sys['pred_mos']))
40 | sys_srcc = scipy.stats.spearmanr(df_sys['true_mos'], df_sys['pred_mos'])[0]
41 | print('CV SYS MSE: {:f}'.format(sys_mse))
42 | print('CV SYS SRCC: {:f}'.format(sys_srcc))
43 |
44 |
45 | return {'cv_utt_mse': mse}
46 |
47 | def evaluate_test(df_test, pred_datatrack):
48 |
49 | utt_mse = np.mean(np.square(df_test['true_mos'] - df_test['pred_mos']))
50 | utt_srcc = scipy.stats.spearmanr(df_test['true_mos'], df_test['pred_mos'])[0]
51 | print('TEST UTT MSE: {:f}'.format(utt_mse))
52 | print('TEST UTT SRCC: {:f}'.format(utt_srcc))
53 |
54 | df_test['system_ID'] = df_test.index.str.extract(r'^(.+?)-').values
55 |
56 | df_test_sys = df_test.groupby('system_ID')['pred_mos'].mean()
57 |
58 | df_true_sys = pd.read_csv(f'../data/{pred_datatrack}/DATA/mydata_system.csv')
59 |
60 | df_sys = pd.merge(df_test_sys, df_true_sys, on='system_ID', how='left').set_index('system_ID')
61 |
62 | sys_mse = np.mean(np.square(df_sys['mean'] - df_sys['pred_mos']))
63 | sys_srcc = scipy.stats.spearmanr(df_sys['mean'], df_sys['pred_mos'])[0]
64 | print('TEST SYS MSE: {:f}'.format(sys_mse))
65 | print('TEST SYS SRCC: {:f}'.format(sys_srcc))
66 |
67 | return {'test_utt_mse': utt_mse,
68 | 'test_utt_srcc': utt_srcc,
69 | 'test_sys_mse': sys_mse,
70 | 'test_sys_srcc': sys_srcc}
71 |
72 |
73 |
74 | def calc_strong_learner_score(pred_datatrack, model_type, k_cv=K_CV):
75 |
76 | print('Stage1, {}, {}'.format(pred_datatrack, model_type))
77 |
78 | result_dir = Path('../out/ensemble-multidomain/stage1') / \
79 | pred_datatrack / f'{model_type}'
80 | print(result_dir)
81 |
82 | use_cv_result = True
83 |
84 | df_vals = []
85 | df_tests = []
86 |
87 | for i_cv in range(k_cv):
88 | if use_cv_result:
89 | df_vals.append(pd.read_csv(result_dir / str(i_cv) / f'val.csv',
90 | index_col=0))
91 | df_tests.append(pd.read_csv(result_dir / str(i_cv) / f'test.csv',
92 | index_col=0))
93 | else:
94 | pred_dir = result_dir / str(i_cv) / f'pred-{pred_datatrack}'
95 | df_vals.append(pd.read_csv(pred_dir / f'train.csv',
96 | index_col=0))
97 | df_tests.append(pd.read_csv(pred_dir / f'test.csv',
98 | index_col=0))
99 |
100 | if use_cv_result:
101 | df_val = pd.concat(df_vals)
102 | else:
103 | df_val = sum(df_vals) / len(df_vals)
104 | df_test = sum(df_tests) / len(df_tests)
105 |
106 | result = {'stage': 'stage1', 'train_datatrack': 'all',
107 | 'model_type': model_type, 'feat_type': 'nn'}
108 |
109 | result.update(evaluate_val(df_val))
110 | result.update(evaluate_test(df_test, pred_datatrack))
111 |
112 | return result
113 |
114 | def calc_stage1_score(pred_datatrack, train_datatrack, model_type, ssl_type):
115 |
116 | print('Stage1, {}, {}, {}, {}'.format(pred_datatrack, train_datatrack, model_type, ssl_type))
117 |
118 | result_dir = Path('../out/ensemble-multidomain/stage1') / \
119 | train_datatrack / f'{model_type}-{ssl_type}'
120 | print(result_dir)
121 |
122 | use_cv_result = (pred_datatrack == train_datatrack or train_datatrack.startswith('phase1-all'))
123 |
124 | df_vals = []
125 | df_tests = []
126 |
127 | for i_cv in range(K_CV):
128 | if use_cv_result:
129 | df_vals.append(pd.read_csv(result_dir / str(i_cv) / f'val.csv',
130 | index_col=0))
131 | df_tests.append(pd.read_csv(result_dir / str(i_cv) / f'test.csv',
132 | index_col=0))
133 | else:
134 | pred_dir = result_dir / str(i_cv) / f'pred-{pred_datatrack}'
135 | df_vals.append(pd.read_csv(pred_dir / f'train.csv',
136 | index_col=0))
137 | df_tests.append(pd.read_csv(pred_dir / f'test.csv',
138 | index_col=0))
139 |
140 | if use_cv_result:
141 | df_val = pd.concat(df_vals)
142 | else:
143 | df_val = sum(df_vals) / len(df_vals)
144 | df_test = sum(df_tests) / len(df_tests)
145 |
146 | result = {'stage': 'stage1', 'train_datatrack': train_datatrack,
147 | 'model_type': model_type, 'feat_type': ssl_type}
148 |
149 | result.update(evaluate_val(df_val))
150 | result.update(evaluate_test(df_test, pred_datatrack))
151 |
152 | return result
153 |
154 | def calc_stage_n_score(pred_datatrack, stage, model_type, feat_type):
155 |
156 | result_dir = Path('../out/ensemble-multidomain') / stage / \
157 | pred_datatrack / f'{model_type}-{feat_type}'
158 | print(result_dir)
159 |
160 | df_vals = []
161 | df_tests = []
162 |
163 | for i_cv in range(K_CV):
164 | df_vals.append(pd.read_csv(result_dir / str(i_cv) / f'val.csv',
165 | index_col=0))
166 | df_tests.append(pd.read_csv(result_dir / str(i_cv) / f'test.csv',
167 | index_col=0))
168 |
169 | df_val = pd.concat(df_vals)
170 | df_test = sum(df_tests) / len(df_tests)
171 |
172 | result = {'stage': stage, 'train_datatrack': pred_datatrack,
173 | 'model_type': model_type, 'feat_type': feat_type}
174 |
175 | result.update(evaluate_val(df_val))
176 | result.update(evaluate_test(df_test, pred_datatrack))
177 |
178 | return result
179 |
180 |
181 |
182 | def main():
183 |
184 | args = get_arg()
185 |
186 | feat_conf = yaml.safe_load(open('./stage2-method/{}.yaml'.format(args.feat_type)))
187 | print(feat_conf)
188 |
189 | data = []
190 |
191 | # stage1
192 | for model_type in feat_conf['strong_learners']:
193 | k_cv = 3 if args.datatrack == 'phase1-ood' else K_CV
194 | data.append(calc_strong_learner_score(args.datatrack, model_type, k_cv=k_cv))
195 |
196 |
197 | for train_datatrack, model_type, ssl_type in itertools.product(
198 | feat_conf['weak_learners']['datatracks'],
199 | feat_conf['weak_learners']['model_types'],
200 | feat_conf['weak_learners']['ssl_types']):
201 | data.append(calc_stage1_score(args.datatrack, train_datatrack, model_type, ssl_type))
202 |
203 | # stage2
204 | for model_type in feat_conf['weak_learners']['model_types']:
205 | data.append(calc_stage_n_score(args.datatrack, 'stage2', model_type, args.feat_type))
206 |
207 | # stage3
208 | for model_type in ['ridge']:
209 | data.append(calc_stage_n_score(args.datatrack, 'stage3', model_type, args.feat_type))
210 |
211 | df = pd.DataFrame.from_dict(data)
212 | result_dir = Path('../out/ensemble-multidomain/result')
213 | os.makedirs(result_dir, exist_ok=True)
214 | df.to_csv(result_dir / f'{args.datatrack}-{args.feat_type}.csv')
215 |
216 |
217 |
218 | if __name__ == '__main__':
219 | main()
220 |
221 |
222 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/calc_testphase_result.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 |
5 | import itertools
6 |
7 | import yaml
8 | import numpy as np
9 | import pandas as pd
10 | import scipy
11 | import scipy.stats
12 |
13 | def get_arg():
14 | import argparse
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('datatrack')
17 | parser.add_argument('feat_type', default='weak36')
18 | return parser.parse_args()
19 |
20 | K_CV = 5
21 |
22 |
23 | def main():
24 |
25 | args = get_arg()
26 |
27 | feat_conf = yaml.safe_load(open('./stage2-method/{}.yaml'.format(args.feat_type)))
28 | print(feat_conf)
29 |
30 | train_datatrack = 'phase1-main' if args.datatrack in ['testphase-main', 'valphase-main'] else 'phase1-ood'
31 |
32 | result_dir = Path('../out/ensemble-multidomain') / 'stage3' / \
33 | train_datatrack / f'ridge-{args.feat_type}'
34 |
35 | df_tests = []
36 |
37 | for i_cv in range(K_CV):
38 | df_tests.append(pd.read_csv(result_dir / str(i_cv) / f'pred-{args.datatrack}/test.csv',
39 | index_col=0))
40 |
41 | df_test = sum(df_tests) / len(df_tests)
42 |
43 | answer_dir = Path('../out/ensemble-multidomain/answer')
44 | os.makedirs(answer_dir, exist_ok=True)
45 | df_test['pred_mos'].to_csv(answer_dir / f'{args.datatrack}-{args.feat_type}.csv',
46 | header=None)
47 |
48 |
49 |
50 | if __name__ == '__main__':
51 | main()
52 |
53 |
54 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/collect_stage1_result.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import itertools
5 |
6 | import pandas as pd
7 | import itertools
8 | import yaml
9 |
10 | K_CV = 5
11 |
12 | def get_arg():
13 | import argparse
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('datatrack')
16 | parser.add_argument('feat_type', default='weak36')
17 | return parser.parse_args()
18 |
19 |
20 | def get_learner_data(stage1_result_dir, pred_datatrack, use_cv_result, use_upper_lower, column_tag, k_cv=K_CV):
21 |
22 | df_vals = []
23 | df_tests = []
24 | for i_cv in range(k_cv):
25 | if use_cv_result:
26 | df_vals.append(pd.read_csv(stage1_result_dir / str(i_cv) / f'val.csv',
27 | index_col=0))
28 | df_tests.append(pd.read_csv(stage1_result_dir / str(i_cv) / f'test.csv',
29 | index_col=0))
30 | else:
31 | pred_dir = stage1_result_dir / str(i_cv) / f'pred-{pred_datatrack}'
32 | df_vals.append(pd.read_csv(pred_dir / f'train.csv',
33 | index_col=0))
34 | df_tests.append(pd.read_csv(pred_dir / f'test.csv',
35 | index_col=0))
36 |
37 | if use_cv_result:
38 | df_train = pd.concat(df_vals)
39 | else:
40 | df_train = sum(df_vals) / len(df_vals)
41 | df_test = sum(df_tests) / len(df_tests)
42 |
43 | # empty column df
44 | df_train_new = df_train[[]].copy()
45 | df_test_new = df_test[[]].copy()
46 |
47 | if use_upper_lower:
48 | col_name = 'mean-' + column_tag
49 | df_train_new[col_name] = df_train['pred_mos'].copy()
50 | df_test_new[col_name] = df_test['pred_mos'].copy()
51 |
52 | col_name = 'lower-' + column_tag
53 | df_train_new[col_name] = df_train['lower_mos'].copy()
54 | df_test_new[col_name] = df_test['lower_mos'].copy()
55 |
56 | col_name = 'upper-' + column_tag
57 | df_train_new[col_name] = df_train['upper_mos'].copy()
58 | df_test_new[col_name] = df_test['upper_mos'].copy()
59 |
60 | else:
61 | col_name = 'pred-' + column_tag
62 | df_train_new[col_name] = df_train['pred_mos'].copy()
63 | df_test_new[col_name] = df_test['pred_mos'].copy()
64 |
65 | return df_train_new, df_test_new
66 |
67 |
68 | def main():
69 |
70 | args = get_arg()
71 |
72 | stage2_data_dir = Path('../out/ensemble-multidomain/data-stage2') / args.datatrack / args.feat_type
73 | stage1_result_base_dir = Path('../out/ensemble-multidomain/stage1')
74 |
75 | feat_conf = yaml.safe_load(open('./stage2-method/{}.yaml'.format(args.feat_type)))
76 | print(feat_conf)
77 |
78 | df_train_list = []
79 | df_test_list = []
80 |
81 | for strong_learner in feat_conf['strong_learners']:
82 |
83 | stage1_result_dir = stage1_result_base_dir / args.datatrack / strong_learner
84 |
85 | column_tag = strong_learner
86 | k_cv = 3 if args.datatrack == 'phase1-ood' else K_CV
87 | df_train, df_test = get_learner_data(stage1_result_dir=stage1_result_dir,
88 | pred_datatrack=args.datatrack,
89 | use_cv_result=True,
90 | use_upper_lower=False,
91 | column_tag=strong_learner,
92 | k_cv=k_cv)
93 |
94 | df_train_list.append(df_train)
95 | df_test_list.append(df_test)
96 |
97 | for train_datatrack, model_type, ssl_type in itertools.product(
98 | feat_conf['weak_learners']['datatracks'],
99 | feat_conf['weak_learners']['model_types'],
100 | feat_conf['weak_learners']['ssl_types']):
101 |
102 | if model_type == 'autogp':
103 | if train_datatrack == 'phase1-main':
104 | model_type = 'svgp'
105 | else:
106 | model_type = 'exactgp'
107 |
108 | use_cv_result = (args.datatrack == train_datatrack or train_datatrack.startswith('phase1-all'))
109 | use_upper_lower = (model_type in ['svgp', 'exactgp'])
110 |
111 | stage1_result_dir = stage1_result_base_dir / train_datatrack / f'{model_type}-{ssl_type}'
112 |
113 | column_tag = f'{train_datatrack}---{model_type}---{ssl_type}'
114 |
115 | df_train, df_test = get_learner_data(stage1_result_dir=stage1_result_dir,
116 | pred_datatrack=args.datatrack,
117 | use_cv_result=use_cv_result,
118 | use_upper_lower=use_upper_lower,
119 | column_tag=column_tag)
120 |
121 | df_train_list.append(df_train)
122 | df_test_list.append(df_test)
123 |
124 | df_train_all = pd.concat(df_train_list, axis=1)
125 | df_test_all = pd.concat(df_test_list, axis=1)
126 |
127 | df_train_all.sort_index(inplace=True)
128 | df_test_all.sort_index(inplace=True)
129 |
130 |
131 | print('Columns: {}'.format(df_train_all.columns))
132 | print('Train: {}'.format(df_train_all.shape))
133 | print('Test: {}'.format(df_test_all.shape))
134 |
135 | os.makedirs(stage2_data_dir, exist_ok=True)
136 | df_train_all.to_csv(stage2_data_dir / 'train-X.csv')
137 | df_test_all.to_csv(stage2_data_dir / 'test-X.csv')
138 |
139 |
140 | if __name__ == '__main__':
141 | main()
142 |
143 |
144 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/collect_stage1_testphase_result.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import itertools
5 |
6 | import pandas as pd
7 | import itertools
8 | import yaml
9 |
10 | K_CV = 5
11 |
def get_arg():
    """Parse CLI args: the target datatrack and an optional feature-set tag."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('datatrack')
    # FIX: argparse ignores `default=` on a required positional; nargs='?'
    # makes the argument optional so the declared default actually applies.
    parser.add_argument('feat_type', nargs='?', default='weak36')
    return parser.parse_args()
18 |
19 |
def get_learner_data(stage1_result_dir, pred_datatrack, use_upper_lower, column_tag, k_cv=K_CV):
    """Average one learner's per-fold test-phase predictions.

    Reads ``test.csv`` from each of the ``k_cv`` fold directories under
    ``stage1_result_dir``, averages them element-wise, and returns a frame
    whose columns are renamed with ``column_tag`` (mean/lower/upper triplet
    when ``use_upper_lower`` is set, otherwise a single ``pred-`` column).
    """
    fold_frames = [
        pd.read_csv(stage1_result_dir / str(i) / f'pred-{pred_datatrack}' / 'test.csv',
                    index_col=0)
        for i in range(k_cv)
    ]
    df_avg = sum(fold_frames) / len(fold_frames)

    # Index-only copy: keep the wavname index, add only renamed columns.
    df_out = df_avg[[]].copy()

    if use_upper_lower:
        column_map = [('mean', 'pred_mos'),
                      ('lower', 'lower_mos'),
                      ('upper', 'upper_mos')]
    else:
        column_map = [('pred', 'pred_mos')]

    for prefix, src in column_map:
        df_out[f'{prefix}-{column_tag}'] = df_avg[src].copy()

    return df_out
48 |
49 |
def main():
    """Collect stage-1 test-phase predictions from all learners into a single
    feature matrix (test-X.csv) consumed by stage-2 stacking."""

    args = get_arg()

    stage2_data_dir = Path('../out/ensemble-multidomain/data-stage2') / args.datatrack / args.feat_type
    stage1_result_base_dir = Path('../out/ensemble-multidomain/stage1')

    # Feature-set config lists the strong learners plus the grid of weak learners.
    feat_conf = yaml.safe_load(open('./stage2-method/{}.yaml'.format(args.feat_type)))
    print(feat_conf)

    df_test_list = []

    for strong_learner in feat_conf['strong_learners']:

        # train_datatrack = 'phase1-main' if args.datatrack == 'testphase-main' else 'phase1-ood'

        stage1_result_dir = stage1_result_base_dir / args.datatrack / strong_learner

        # OOD strong learners were trained with 3 CV folds, others with K_CV (5).
        k_cv = 3 if args.datatrack == 'testphase-ood' else K_CV

        column_tag = strong_learner
        df_test = get_learner_data(stage1_result_dir=stage1_result_dir,
                                   pred_datatrack=args.datatrack,
                                   use_upper_lower=False,
                                   column_tag=strong_learner,
                                   k_cv=k_cv)

        df_test_list.append(df_test)

    for train_datatrack, model_type, ssl_type in itertools.product(
            feat_conf['weak_learners']['datatracks'],
            feat_conf['weak_learners']['model_types'],
            feat_conf['weak_learners']['ssl_types']):

        # 'autogp' is a placeholder resolved per training track: SVGP for the
        # larger main set, exact GP for the smaller tracks.
        if model_type == 'autogp':
            if train_datatrack == 'phase1-main':
                model_type = 'svgp'
            else:
                model_type = 'exactgp'

        # NOTE(review): computed but never passed anywhere in this script —
        # this file's get_learner_data takes no use_cv_result argument; looks
        # like a leftover from the non-testphase variant. Confirm before removing.
        use_cv_result = (args.datatrack == train_datatrack or train_datatrack.startswith('phase1-all'))
        # GP models emit mean/lower/upper columns; other models a single prediction.
        use_upper_lower = (model_type in ['svgp', 'exactgp'])

        stage1_result_dir = stage1_result_base_dir / train_datatrack / f'{model_type}-{ssl_type}'

        column_tag = f'{train_datatrack}---{model_type}---{ssl_type}'

        df_test = get_learner_data(stage1_result_dir=stage1_result_dir,
                                   pred_datatrack=args.datatrack,
                                   use_upper_lower=use_upper_lower,
                                   column_tag=column_tag)

        df_test_list.append(df_test)

    # One column (or mean/lower/upper triplet) per learner, aligned on wavname.
    df_test_all = pd.concat(df_test_list, axis=1)

    df_test_all.sort_index(inplace=True)


    print('Columns: {}'.format(df_test_all.columns))
    print('Test: {}'.format(df_test_all.shape))

    os.makedirs(stage2_data_dir, exist_ok=True)
    df_test_all.to_csv(stage2_data_dir / 'test-X.csv')


if __name__ == '__main__':
    main()
118 |
119 |
120 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/collect_stage2_result.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import itertools
5 |
6 | import pandas as pd
7 | import itertools
8 | import yaml
9 |
10 | K_CV = 5
11 |
def get_arg():
    """Parse CLI args: the target datatrack and an optional feature-set tag."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('datatrack')
    # FIX: a required positional never uses `default=`; nargs='?' makes the
    # argument optional so 'weak36' is actually applied when omitted.
    parser.add_argument('feat_type', nargs='?', default='weak36')
    return parser.parse_args()
18 |
19 |
def get_learner_data(stage2_result_dir, pred_datatrack, use_upper_lower, column_tag):
    """Assemble one learner's stage-2 outputs as stacking features.

    Train features are the out-of-fold validation predictions concatenated
    across all K_CV folds; test features are the per-fold test predictions
    averaged element-wise. Columns are renamed with ``column_tag``
    (mean/lower/upper when ``use_upper_lower``, else a single pred column).
    """
    val_frames, test_frames = [], []
    for fold in range(K_CV):
        fold_dir = stage2_result_dir / str(fold)
        val_frames.append(pd.read_csv(fold_dir / 'val.csv', index_col=0))
        test_frames.append(pd.read_csv(fold_dir / 'test.csv', index_col=0))

    df_train = pd.concat(val_frames)
    df_test = sum(test_frames) / len(test_frames)

    # Index-only copies: keep wavnames, drop raw columns.
    df_train_out = df_train[[]].copy()
    df_test_out = df_test[[]].copy()

    if use_upper_lower:
        column_map = [('mean', 'pred_mos'),
                      ('lower', 'lower_mos'),
                      ('upper', 'upper_mos')]
    else:
        column_map = [('pred', 'pred_mos')]

    for prefix, src in column_map:
        name = f'{prefix}-{column_tag}'
        df_train_out[name] = df_train[src].copy()
        df_test_out[name] = df_test[src].copy()

    return df_train_out, df_test_out
56 |
57 |
def main():
    """Collect stage-2 CV predictions into stage-3 train/test feature matrices."""

    args = get_arg()

    stage3_data_dir = Path('../out/ensemble-multidomain/data-stage3') / args.datatrack / args.feat_type
    stage2_result_base_dir = Path('../out/ensemble-multidomain/stage2')

    # Close the config file deterministically instead of leaking the handle.
    with open('./stage2-method/{}.yaml'.format(args.feat_type)) as f:
        feat_conf = yaml.safe_load(f)
    print(feat_conf)

    df_train_list = []
    df_test_list = []

    for model_type in feat_conf['weak_learners']['model_types']:

        if model_type == 'autogp':
            # BUG FIX: this branch referenced an undefined `train_datatrack`
            # (NameError as soon as an 'autogp' entry appeared). Stage-2
            # results live under args.datatrack, so that track decides the GP
            # flavor: SVGP for the large main set, exact GP otherwise.
            if args.datatrack == 'phase1-main':
                model_type = 'svgp'
            else:
                model_type = 'exactgp'

        # GP models emit mean/lower/upper columns; other models a single prediction.
        use_upper_lower = (model_type in ['svgp', 'exactgp'])

        stage2_result_dir = stage2_result_base_dir / args.datatrack / f'{model_type}-{args.feat_type}'

        column_tag = model_type

        df_train, df_test = get_learner_data(stage2_result_dir, args.datatrack,
                                             use_upper_lower, column_tag)

        df_train_list.append(df_train)
        df_test_list.append(df_test)

    df_train_all = pd.concat(df_train_list, axis=1)
    df_test_all = pd.concat(df_test_list, axis=1)

    df_train_all.sort_index(inplace=True)
    df_test_all.sort_index(inplace=True)

    print('Columns: {}'.format(df_train_all.columns))
    print('Train: {}'.format(df_train_all.shape))
    print('Test: {}'.format(df_test_all.shape))

    os.makedirs(stage3_data_dir, exist_ok=True)
    df_train_all.to_csv(stage3_data_dir / 'train-X.csv')
    df_test_all.to_csv(stage3_data_dir / 'test-X.csv')


if __name__ == '__main__':
    main()
108 |
109 |
110 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/collect_stage2_testphase_result.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import itertools
5 |
6 | import pandas as pd
7 | import itertools
8 | import yaml
9 |
10 | K_CV = 5
11 |
def get_arg():
    """Parse CLI args: the target datatrack and an optional feature-set tag."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('datatrack')
    # FIX: argparse silently ignores `default=` on a required positional;
    # nargs='?' makes the default effective.
    parser.add_argument('feat_type', nargs='?', default='weak36')
    return parser.parse_args()
18 |
19 |
def get_learner_data(stage2_result_dir, pred_datatrack, use_upper_lower, column_tag, k_cv=K_CV):
    """Average per-fold stage-2 test-phase predictions and tag the columns.

    Loads ``<fold>/pred-<pred_datatrack>/test.csv`` for each fold, averages
    them, and returns an index-aligned frame whose columns carry
    ``column_tag`` (mean/lower/upper triplet for GP models, else one column).
    """
    frames = []
    for fold in range(k_cv):
        csv_path = stage2_result_dir / str(fold) / f'pred-{pred_datatrack}' / 'test.csv'
        frames.append(pd.read_csv(csv_path, index_col=0))

    df_mean = sum(frames) / len(frames)

    # keep only the index; renamed columns are added below
    out = df_mean[[]].copy()

    if use_upper_lower:
        out['mean-' + column_tag] = df_mean['pred_mos'].copy()
        out['lower-' + column_tag] = df_mean['lower_mos'].copy()
        out['upper-' + column_tag] = df_mean['upper_mos'].copy()
    else:
        out['pred-' + column_tag] = df_mean['pred_mos'].copy()

    return out
48 |
49 |
def main():
    """Collect stage-2 test-phase predictions into the stage-3 test matrix."""

    args = get_arg()

    stage3_data_dir = Path('../out/ensemble-multidomain/data-stage3') / args.datatrack / args.feat_type
    stage2_result_base_dir = Path('../out/ensemble-multidomain/stage2')

    # Close the config file deterministically instead of leaking the handle.
    with open('./stage2-method/{}.yaml'.format(args.feat_type)) as f:
        feat_conf = yaml.safe_load(f)
    print(feat_conf)

    df_test_list = []

    # BUG FIX: `train_datatrack` was assigned *after* its use in the 'autogp'
    # branch below, raising NameError on the first 'autogp' iteration. It is
    # loop-invariant, so compute it once up front.
    train_datatrack = 'phase1-main' if args.datatrack in ['testphase-main', 'valphase-main'] else 'phase1-ood'

    for model_type in feat_conf['weak_learners']['model_types']:

        if model_type == 'autogp':
            # SVGP scales to the larger main set; exact GP otherwise.
            if train_datatrack == 'phase1-main':
                model_type = 'svgp'
            else:
                model_type = 'exactgp'

        # GP models emit mean/lower/upper columns; other models a single prediction.
        use_upper_lower = (model_type in ['svgp', 'exactgp'])

        stage2_result_dir = stage2_result_base_dir / train_datatrack / f'{model_type}-{args.feat_type}'

        column_tag = model_type

        # NOTE(review): k_cv is left at its default (5) even for OOD tracks
        # that elsewhere use 3 folds — confirm against the stage-2 layout.
        df_test = get_learner_data(stage2_result_dir, args.datatrack,
                                   use_upper_lower, column_tag)
        df_test_list.append(df_test)

    df_test_all = pd.concat(df_test_list, axis=1)

    df_test_all.sort_index(inplace=True)

    print('Columns: {}'.format(df_test_all.columns))
    print('Test: {}'.format(df_test_all.shape))

    os.makedirs(stage3_data_dir, exist_ok=True)
    df_test_all.to_csv(stage3_data_dir / 'test-X.csv')


if __name__ == '__main__':
    main()
97 |
98 |
99 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/convert_strong_learner_result.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from pathlib import Path
4 | import os
5 |
6 |
def get_arg():
    """Parse the datatrack and strong-learner identifiers from the command line."""
    import argparse
    parser = argparse.ArgumentParser()
    for positional in ('datatrack', 'learner'):
        parser.add_argument(positional)
    return parser.parse_args()
13 |
14 |
def get_merge_df(df_true, df_pred):
    """Left-join ground-truth MOS onto predictions, indexed by wavname.

    Rows with no matching truth keep NaN in 'true_mos'.
    """
    merged = pd.merge(df_pred, df_true, on="wavname", how="left")
    return merged.set_index('wavname')[['pred_mos', 'true_mos']]
21 |
def get_true_mos(datatrack):
    """Load train/val ground-truth MOS lists for `datatrack`.

    Returns {'train': df, 'val': df}, each with columns
    ['wavname', 'true_mos'].
    """
    # Removed a leftover notebook no-op (`df.shape, df.shape` expression).
    df_true_dict = {}
    for split in ('train', 'val'):
        path = f'../data/{datatrack}/DATA/sets/{split}_mos_list.txt'
        df_true_dict[split] = pd.read_csv(path, header=None,
                                          names=['wavname', 'true_mos'])

    return df_true_dict
35 |
36 |
def main():
    """Convert a strong learner's per-fold CV answers into stage-1 format.

    Writes <stage1>/<datatrack>/<learner>/<fold>/{val,test}.csv with
    (pred_mos, true_mos) columns indexed by wavname.
    """
    args = get_arg()

    # Removed unused local `phase`.
    in_dir = Path('../strong_learner_result') / args.learner
    out_base_dir = Path('../out/ensemble-multidomain/stage1') \
        / args.datatrack / args.learner

    df_true_dict = get_true_mos(args.datatrack)

    # The OOD track was trained with 3 CV folds; everything else with 5.
    k_cv = 3 if args.datatrack == 'phase1-ood' else 5
    answer_file_tag = 'answer-ood' if args.datatrack == 'phase1-ood' else 'answer-main'

    for i_cv in range(k_cv):
        out_dir = out_base_dir / str(i_cv)
        os.makedirs(out_dir, exist_ok=True)
        print(out_dir)

        for split in ['train', 'val']:
            # The stacking pipeline re-labels splits: the learner's training
            # folds become stacking 'val', its validation becomes 'test'.
            stacking_split_name = {'train': 'val', 'val': 'test'}[split]
            learner_split_name = {'train': 'fold', 'val': 'val'}[split]
            in_path = in_dir / f'{answer_file_tag}.csv{learner_split_name}_{i_cv}'

            df_pred = pd.read_csv(in_path, header=None,
                                  names=['wavbase', 'pred_mos'])
            df_pred["wavname"] = df_pred["wavbase"] + ".wav"

            df = get_merge_df(df_true_dict[split], df_pred)

            df.to_csv(out_dir / f'{stacking_split_name}.csv')


if __name__ == '__main__':
    main()
72 |
73 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/convert_strong_learner_testphase_result.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from pathlib import Path
4 | import os
5 |
6 |
def get_arg():
    """Return the parsed namespace with `datatrack` and `learner`."""
    import argparse
    cli = argparse.ArgumentParser()
    cli.add_argument('datatrack')
    cli.add_argument('learner')
    return cli.parse_args()
13 |
14 |
def get_merge_df(df_true, df_pred):
    """Attach true MOS to predictions; index by wavname, keep pred/true cols."""
    joined = df_pred.merge(df_true, on="wavname", how="left")
    joined = joined.set_index('wavname')
    return joined[['pred_mos', 'true_mos']]
21 |
def main():
    """Convert a strong learner's test-phase answers into stage-1 prediction
    files: <stage1>/<datatrack>/<learner>/<fold>/pred-<datatrack>/test.csv."""

    args = get_arg()

    # Removed unused local `phase`.
    in_dir = Path('../strong_learner_result') / args.learner
    out_base_dir = Path('../out/ensemble-multidomain/stage1/') \
        / args.datatrack / args.learner

    # The OOD track used 3 CV folds; the main track used 5.
    k_cv = 3 if args.datatrack == 'testphase-ood' else 5
    answer_file_tag = 'answer-ood' if args.datatrack == 'testphase-ood' else 'answer-main'

    for i_cv in range(k_cv):
        out_dir = out_base_dir / str(i_cv) / f'pred-{args.datatrack}'
        os.makedirs(out_dir, exist_ok=True)
        print(out_dir)

        in_path = in_dir / f'{answer_file_tag}.csvtest_{i_cv}'

        df_pred = pd.read_csv(in_path, header=None,
                              names=['wavbase', 'pred_mos'])
        df_pred["wavname"] = df_pred["wavbase"] + ".wav"
        # Test phase has no ground truth; a sentinel keeps the downstream
        # (pred_mos, true_mos) schema uniform.
        df_pred['true_mos'] = -99.0
        df = df_pred.set_index('wavname')[['pred_mos', 'true_mos']]

        df.to_csv(out_dir / 'test.csv')


if __name__ == '__main__':
    main()
52 |
53 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/data_util.py:
--------------------------------------------------------------------------------
1 |
2 | from pathlib import Path
3 | from logging import getLogger
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | logger = getLogger(__name__)
9 |
def make_stage1_data(fold_file, utt_data_dir):
    """Load one fold's MOS labels plus the matching utterance embeddings.

    Returns (df, X, y): the fold frame indexed by wavname (sorted), an
    (n, dim) embedding matrix stacked from per-utterance .npy files, and
    the true-MOS vector.
    """
    df = pd.read_csv(fold_file, header=None, index_col=0,
                     names=['wavname', 'true_mos']).sort_index()

    # one .npy per utterance, keyed by the wav file's stem
    X = np.stack([np.load(utt_data_dir / (name.split('.')[0] + '.npy'))
                  for name in df.index])
    y = df['true_mos'].values

    return df, X, y
25 |
def load_stage1_data(datatrack, ssl_type, i_cv):
    """Load the train/val/test splits of CV fold `i_cv` for stage-1 training.

    Returns {split: {'X': embeddings, 'y': true MOS, 'df': fold frame}}.
    """
    utt_data_dir = Path('../out/utt_data') / ssl_type
    fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack

    logger.info('load data')

    data = {}
    for split in ['train', 'val', 'test']:
        # fold CSVs are named e.g. train-0.csv, val-0.csv, test-0.csv
        fold_file = fold_dir / f'{split}-{i_cv}.csv'
        df, X, y = make_stage1_data(fold_file, utt_data_dir)

        logger.info('[{}]\tX: {}, y: {}'.format(split, X.shape, y.shape))

        data[split] = {'X': X, 'y': y, 'df': df}

    return data
42 |
43 |
def load_stage1_train_all_data(datatrack, ssl_type):
    """Load the full (un-folded) stage-1 training set for `datatrack`."""
    utt_data_dir = Path('../out/utt_data') / ssl_type
    fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack

    logger.info('load data')

    df, X, y = make_stage1_data(fold_dir / 'train-all.csv', utt_data_dir)

    logger.info('[train-all]\tX: {}, y: {}'.format(X.shape, y.shape))

    return {'X': X, 'y': y, 'df': df}
58 |
59 |
def load_stage1_test_data(datatrack, ssl_type):
    """Load the stage-1 test set (fold file test-0.csv) for `datatrack`."""
    utt_data_dir = Path('../out/utt_data') / ssl_type
    fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack

    logger.info('load data')

    df, X, y = make_stage1_data(fold_dir / 'test-0.csv', utt_data_dir)

    logger.info('[test]\tX: {}, y: {}'.format(X.shape, y.shape))

    return {'X': X, 'y': y, 'df': df}
74 |
75 |
def make_stage2_data(fold_file, df_all_X):
    """Select one fold's rows from the pooled feature matrix.

    Returns (df_fold, X, y) with rows ordered by sorted wavname.
    """
    df_fold = pd.read_csv(fold_file, header=None, index_col=0,
                          names=['wavname', 'true_mos']).sort_index()

    X = df_all_X.loc[df_fold.index.values, :].values
    y = df_fold['true_mos'].values

    return df_fold, X, y
88 |
89 |
def load_stage2_data(datatrack, feat_type, i_cv):
    """Load stage-2 train/val/test features for CV fold `i_cv`.

    Train/val rows come from train-X.csv and test rows from test-X.csv;
    each split is subset by its fold CSV.

    Returns {split: {'X': ..., 'y': ..., 'df': ...}}.
    """
    data_dir = Path('../out/ensemble-multidomain') / 'data-stage2' / datatrack / feat_type
    fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack

    logger.info('load data')


    data = {}
    for split in ['train', 'val', 'test']:
        # NOTE(review): the pooled matrix is re-read on every iteration, so
        # train-X.csv is parsed twice — harmless but redundant.
        data_path = data_dir / '{}-X.csv'.format('test' if split == 'test' else 'train')
        df_all_X = pd.read_csv(data_path, index_col=0)
        fold_file = fold_dir / f'{split}-{i_cv}.csv'
        df_fold, X, y = make_stage2_data(fold_file, df_all_X)

        logger.info('[{}]\tX: {}, y: {}'.format(split, X.shape, y.shape))

        data[split] = {'X': X, 'y': y, 'df': df_fold}

    return data
109 |
110 |
def load_stage2_train_all_data(datatrack, feat_type):
    """Load all stage-2 training features/labels (no fold split)."""
    data_dir = Path('../out/ensemble-multidomain') / 'data-stage2' / datatrack / feat_type
    fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack

    logger.info('load data')

    df_all_X = pd.read_csv(data_dir / 'train-X.csv', index_col=0)
    df, X, y = make_stage2_data(fold_dir / 'train-all.csv', df_all_X)

    logger.info('[train-all]\tX: {}, y: {}'.format(X.shape, y.shape))

    return {'X': X, 'y': y, 'df': df}
127 |
128 |
def load_stage2_test_data(datatrack, feat_type):
    """Load the stage-2 test features/labels (fold file test-0.csv)."""
    data_dir = Path('../out/ensemble-multidomain') / 'data-stage2' / datatrack / feat_type
    fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack

    logger.info('load data')

    df_all_X = pd.read_csv(data_dir / 'test-X.csv', index_col=0)
    df, X, y = make_stage2_data(fold_dir / 'test-0.csv', df_all_X)

    logger.info('[test]\tX: {}, y: {}'.format(X.shape, y.shape))

    return {'X': X, 'y': y, 'df': df}
145 |
def make_stage3_data(fold_file, df_all_X):
    # Stage-3 feature files share the stage-2 layout, so reuse its loader.
    return make_stage2_data(fold_file, df_all_X)
148 |
149 |
def load_stage3_data(datatrack, feat_type, i_cv):
    """Load stage-3 train/val/test features for CV fold `i_cv`.

    Identical to load_stage2_data except it reads from data-stage3.
    Returns {split: {'X': ..., 'y': ..., 'df': ...}}.
    """
    data_dir = Path('../out/ensemble-multidomain') / 'data-stage3' / datatrack / feat_type
    fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack

    logger.info('load data')


    data = {}
    for split in ['train', 'val', 'test']:
        data_path = data_dir / '{}-X.csv'.format('test' if split == 'test' else 'train')
        df_all_X = pd.read_csv(data_path, index_col=0)
        fold_file = fold_dir / f'{split}-{i_cv}.csv'
        # make_stage2_data is called directly; it is the same function that
        # make_stage3_data delegates to.
        df_fold, X, y = make_stage2_data(fold_file, df_all_X)

        logger.info('[{}]\tX: {}, y: {}'.format(split, X.shape, y.shape))

        data[split] = {'X': X, 'y': y, 'df': df_fold}

    return data
169 |
170 |
def load_stage3_train_all_data(datatrack, feat_type):
    """Load all stage-3 training features/labels (no fold split)."""
    data_dir = Path('../out/ensemble-multidomain') / 'data-stage3' / datatrack / feat_type
    fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack

    logger.info('load data')

    df_all_X = pd.read_csv(data_dir / 'train-X.csv', index_col=0)
    df, X, y = make_stage2_data(fold_dir / 'train-all.csv', df_all_X)

    logger.info('[train-all]\tX: {}, y: {}'.format(X.shape, y.shape))

    return {'X': X, 'y': y, 'df': df}
188 |
189 |
def load_stage3_test_data(datatrack, feat_type):
    """Load the stage-3 test features/labels (fold file test-0.csv)."""
    data_dir = Path('../out/ensemble-multidomain') / 'data-stage3' / datatrack / feat_type
    fold_dir = Path('../out/ensemble-multidomain/fold') / datatrack

    logger.info('load data')

    df_all_X = pd.read_csv(data_dir / 'test-X.csv', index_col=0)
    df, X, y = make_stage3_data(fold_dir / 'test-0.csv', df_all_X)

    logger.info('[test]\tX: {}, y: {}'.format(X.shape, y.shape))

    return {'X': X, 'y': y, 'df': df}
206 |
207 |
def normalize_score(val):
    """Map a MOS in [1, 5] onto [-1, 1].

    >>> normalize_score(1)
    -1.0
    >>> normalize_score(3)
    0.0
    >>> normalize_score(5)
    1.0
    """
    # FIX: the doctest previously showed `0`, but the function returns the
    # float 0.0 (repr "0.0"), so the doctest failed under the doctest runner.
    return (val - 3.0) / 2.0
218 |
def inverse_normalize_score(val):
    """Map a normalized score in [-1, 1] back onto the MOS range [1, 5].

    >>> inverse_normalize_score(-1)
    1.0
    >>> inverse_normalize_score(0)
    3.0
    >>> inverse_normalize_score(1)
    5.0
    """
    return val * 2.0 + 3.0
229 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/extract_ssl_feature.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from tqdm import tqdm
5 | from pathlib import Path
6 | import torchaudio
7 | import fairseq
8 | import torch
9 | import os
10 |
11 | import sys
12 | sys.path.append('./external_libs/WavLM')
13 | from WavLM import WavLM, WavLMConfig
14 |
def get_arg():
    """Parse the (currently empty) command-line argument set."""
    import argparse
    return argparse.ArgumentParser().parse_args()
19 |
20 |
def extract_mean(wavpath, ssl_model, device, use_wavlm):
    """Return the time-averaged SSL feature vector for one wav file.

    WavLM and fairseq models expose different forward APIs, hence the
    branch; both paths squeeze the leading (batch) axis and mean-pool the
    frame-level features over time.
    """
    with torch.no_grad():
        if use_wavlm:
            wav = torchaudio.load(wavpath)[0]
            res = ssl_model.extract_features(wav.to(device))
            # res[0]: assumed (1, frames, dim) — squeeze batch, average frames
            return res[0].squeeze(0).mean(dim=0)
        else:
            wav = torchaudio.load(wavpath)[0]
            res = ssl_model(wav.to(device), mask=False, features_only=True)
            # res['x']: assumed (1, frames, dim) — squeeze batch, average frames
            return res['x'].squeeze(0).mean(dim=0)
31 |
32 |
def extract_feature(datatrack, ssl_type):
    """Extract a mean-pooled SSL embedding for every wav in `datatrack` and
    save one .npy per utterance under ../out/utt_data/<ssl_type>/."""

    device = torch.device('cuda')

    wav_dir = Path(f'../data/{datatrack}/DATA/wav/')

    # checkpoint path per SSL model family (raises KeyError on unknown ssl_type)
    base_ckpt_file = {
        'w2v_small': '../../fairseq_checkpoints/wav2vec_small.pt',
        'w2v_xlsr': '../../fairseq_checkpoints/xlsr_53_56k.pt',
        'w2v_large': '../../fairseq_checkpoints/wav2vec_vox_new.pt',
        'w2v_large2': '../../fairseq_checkpoints/w2v_large_lv_fsh_swbd_cv.pt',
        'wavlm_base': '../../fairseq_checkpoints/WavLM-Base.pt',
        'wavlm_large': '../../fairseq_checkpoints/WavLM-Large.pt',
        'hubert_base': '../../fairseq_checkpoints/hubert_base_ls960.pt',
        'hubert_large': '../../fairseq_checkpoints/hubert_large_ll60k.pt',
    }[ssl_type]

    print('base_ckpt_file: {}'.format(base_ckpt_file))

    # WavLM ships its own loader; everything else goes through fairseq.
    use_wavlm = ssl_type in ['wavlm_base', 'wavlm_large']

    if use_wavlm:
        checkpoint = torch.load(base_ckpt_file)
        cfg = WavLMConfig(checkpoint['cfg'])
        ssl_model = WavLM(cfg)
        ssl_model.load_state_dict(checkpoint['model'])
    else:
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([base_ckpt_file])
        ssl_model = model[0]
        ssl_model.remove_pretraining_modules()

    ssl_model.to(device)
    ssl_model.eval()

    # print(ssl_model)

    out_dir = Path(f'../out/utt_data/{ssl_type}')
    os.makedirs(out_dir, exist_ok=True)

    wavpath_list = list(wav_dir.glob('*.wav'))

    for wavpath in tqdm(wavpath_list):
        vec = extract_mean(wavpath, ssl_model, device, use_wavlm)
        # output keyed by the wav file's stem, e.g. foo.wav -> foo.npy
        outpath = out_dir / (wavpath.stem + '.npy')

        vec = vec.detach().cpu().numpy()
        np.save(outpath, vec)
80 |
81 |
def main():
    """Run mean-feature extraction for every (datatrack, ssl_type) pair."""
    args = get_arg()

    ssl_types = [
        'w2v_large2', 'w2v_xlsr',
        'wavlm_base', 'wavlm_large',
        'hubert_large', 'hubert_base',
        'w2v_small', 'w2v_large',
    ]
    datatracks = ['phase1-main', 'phase1-ood', 'testphase-main', 'testphase-ood']

    for datatrack in datatracks:
        for ssl_type in ssl_types:
            print(f'datatrack {datatrack}, ssl_type: {ssl_type}')
            extract_feature(datatrack, ssl_type)


if __name__ == '__main__':
    main()
102 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/gp_models.py:
--------------------------------------------------------------------------------
1 |
2 | from pathlib import Path
3 | from logging import getLogger
4 | import joblib
5 | import os
6 | import json
7 | import pickle
8 |
9 | import math
10 | import numpy as np
11 | from sklearn.preprocessing import StandardScaler
12 | from sklearn.cluster import MiniBatchKMeans
13 | from sklearn.metrics import mean_squared_error, pairwise_distances
14 | import torch
15 | import torch.utils
16 | import gpytorch
17 |
18 | from data_util import normalize_score, inverse_normalize_score
19 |
20 | logger = getLogger(__name__)
21 |
class Dataset(torch.utils.data.Dataset):
    """Minimal in-memory dataset pairing features with targets.

    `transform` is stored but not applied by __getitem__ (kept for API
    compatibility).
    """

    def __init__(self, dx, dy, transform=None):
        self._dx = dx
        self._dy = dy
        self.transform = transform

    def __len__(self):
        return len(self._dx)

    def __getitem__(self, idx):
        return self._dx[idx], self._dy[idx]
36 |
class SVGPModel(gpytorch.models.ApproximateGP):
    """Approximate GP with a natural-gradient variational distribution and an
    ARD RBF kernel; zero mean (targets are normalized to [-1, 1] upstream)."""

    def __init__(self, initial_inducing, initial_lengthscale):
        # One variational parameter set per inducing point; inducing
        # locations are optimized jointly with the rest of the model.
        variational_distribution = gpytorch.variational.NaturalVariationalDistribution(initial_inducing.size(0))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, initial_inducing, variational_distribution, learn_inducing_locations=True
        )

        super().__init__(variational_strategy)

        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(ard_num_dims=initial_inducing.size(1)))
        # Seed the kernel lengthscale; all ARD dims start from this value.
        self.covar_module.base_kernel.lengthscale = initial_lengthscale



    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
57 |
58 |
59 |
class SVGP:
    """Sparse variational GP regressor (gpytorch) over SSL/stacked features.

    Targets are normalized from MOS [1, 5] to [-1, 1] for training and
    mapped back on prediction; inputs are standardized with a fitted
    StandardScaler. Requires CUDA (device is hard-coded below).
    """

    def __init__(self, params=None, stage='stage1'):

        if params is None:
            self.params = {
                'max_inducings': 1024,
                'batch_size': 1024,
                # 'training_epochs': 3000 if stage == 'stage1' else 1000,
                # fewer optimization steps for the cheaper stage-2 features
                'training_iters': 10000 if stage == 'stage1' else 2500,
            }

        else:
            self.params = params

        self.device = torch.device('cuda')

    def get_num_inducing(self, num_data):
        """Return the inducing-point count: `max_inducings` when there is
        enough data, otherwise the largest power of two <= num_data."""
        max_inducings = self.params['max_inducings']

        if num_data >= max_inducings:
            return max_inducings

        power = math.floor((math.log(num_data) / math.log(2)))
        num_inducings = int(2 ** power)

        return num_inducings


    def train(self, train_X, train_y, val_X, val_y):
        """Fit the SVGP on (train_X, train_y); logs validation MSE every
        100 epochs. Sets self.X_scaler, self.gpr and self.conf."""

        X_scaler = StandardScaler()
        train_X_sc = torch.from_numpy(X_scaler.fit_transform(train_X).astype(np.float32))
        train_y_sc = torch.from_numpy(normalize_score(train_y).astype(np.float32))

        val_X_sc = torch.from_numpy(X_scaler.transform(val_X).astype(np.float32))

        # dataloader
        train_X_sc = train_X_sc.to(self.device)
        train_y_sc = train_y_sc.to(self.device)
        val_X_sc = val_X_sc.to(self.device)

        dataset = Dataset(train_X_sc, train_y_sc)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.params['batch_size'], shuffle=True)

        # initial inducing locations: k-means centroids over the training data
        num_inducings = self.get_num_inducing(len(train_X_sc))
        logger.info('Num inducing points: {}'.format(num_inducings))

        kmeans = MiniBatchKMeans(num_inducings)

        for i in range(5):
            for b, (x_B, y_B) in enumerate(dataloader):
                kmeans.partial_fit(x_B.cpu().numpy())

        initial_inducing = torch.from_numpy(kmeans.cluster_centers_.astype(np.float32))

        # initial lengthscale: median pairwise distance within one batch
        for b, (x_B, y_B) in enumerate(dataloader):
            D = pairwise_distances(x_B.cpu().numpy())

            # strict lower triangle: each pair counted once, no self-distances
            distances = D[np.tril_indices(len(D), k=-1)]
            # initial_lengthscale = np.sqrt(np.median(distances))
            initial_lengthscale = np.median(distances)

            break

        # initialize likelihood and model
        gpr = SVGPModel(initial_inducing=initial_inducing,
                        initial_lengthscale=initial_lengthscale)

        likelihood = gpytorch.likelihoods.GaussianLikelihood()

        gpr = gpr.to(self.device)
        likelihood = likelihood.to(self.device)

        # optimizers: natural gradients for the variational distribution,
        # Adam for the kernel/likelihood hyperparameters
        variational_ngd_optimizer = gpytorch.optim.NGD(gpr.variational_parameters(),
                                                       num_data=train_X_sc.size(0), lr=0.1)

        hyperparameter_optimizer = torch.optim.Adam([
            {'params': gpr.hyperparameters()},
            {'params': likelihood.parameters()},
        ], lr=0.01)

        # training
        gpr.train()
        likelihood.train()

        mll = gpytorch.mlls.VariationalELBO(likelihood, gpr, num_data=train_X_sc.size(0))

        # 'training_iters' counts minibatch steps; convert to whole epochs
        num_epochs = max(1, self.params['training_iters'] // len(dataloader))
        for i in range(num_epochs):
            gpr.train()
            likelihood.train()

            lower_bound = 0
            for b, (x_B, y_B) in enumerate(dataloader):
                # x_B = x_B.to(device)
                # y_B = y_B.to(device)

                variational_ngd_optimizer.zero_grad()
                hyperparameter_optimizer.zero_grad()

                output = gpr(x_B)
                bound = mll(output, y_B)
                loss = -bound
                loss.backward()

                variational_ngd_optimizer.step()
                hyperparameter_optimizer.step()

                lower_bound += bound.data.cpu().item()

            # periodic validation in the original (de-normalized) MOS scale
            if i % 100 == 0:
                gpr.eval()
                likelihood.eval()
                with torch.no_grad():
                    preds = gpr(val_X_sc)
                    mean = preds.mean
                    pred_y_sc = mean.cpu().numpy()

                pred_y = inverse_normalize_score(pred_y_sc)

                mse = mean_squared_error(val_y, pred_y.ravel())

                logger.info('Iter %d/%d - ELBO: %.3f - val_loss %.3f ' % (
                    i + 1, num_epochs, lower_bound, mse,
                ))

        self.X_scaler = X_scaler
        self.gpr = gpr

        # persisted so load_model can rebuild an identically-shaped model
        self.conf = {
            'num_inducings': num_inducings,
            'input_dim': train_X.shape[1],
        }

    def predict(self, X, df):
        """Add pred/lower/upper MOS columns to `df` for inputs X; returns df."""
        self.gpr.eval()

        X_sc = torch.from_numpy(self.X_scaler.transform(X).astype(np.float32))
        X_sc = X_sc.to(self.device)

        with torch.no_grad():
            preds = self.gpr(X_sc)
            mean = preds.mean
            # gpytorch's confidence_region() is mean +/- 2 stddev
            lower, upper = preds.confidence_region()

        mean_y = inverse_normalize_score(mean.cpu().numpy())
        lower_y = inverse_normalize_score(lower.cpu().numpy())
        upper_y = inverse_normalize_score(upper.cpu().numpy())

        df['pred_mos'] = mean_y.ravel()
        df['lower_mos'] = lower_y.ravel()
        df['upper_mos'] = upper_y.ravel()

        return df


    def save_model(self, out_dir: Path):
        """Persist model weights, input scaler and shape config to `out_dir`."""
        torch.save(self.gpr.state_dict(), out_dir / 'model.pt')
        joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib')

        with open(out_dir / 'model_config.json', encoding="utf-8", mode="w") as f:
            json.dump(self.conf, f, ensure_ascii=False, indent=2)


    def load_model(self, model_dir: Path, train_X=None):
        """Rebuild the model skeleton and load saved weights.

        Older runs lack model_config.json; then `train_X` is required to
        re-derive the inducing count and input dimension.
        """
        if os.path.exists(model_dir / 'model_config.json'):
            # NOTE(review): file opened in 'rb' and never explicitly closed
            self.conf = json.load(open(model_dir / 'model_config.json', 'rb'))
        else:
            assert train_X is not None
            self.conf = {
                'num_inducings': self.get_num_inducing(len(train_X)),
                'input_dim': train_X.shape[1],
            }

        # placeholder inducing points; real values come from the state dict
        initial_inducing = torch.from_numpy(
            np.empty((self.conf['num_inducings'], self.conf['input_dim']), dtype=np.float32))

        self.gpr = SVGPModel(initial_inducing, 1.0)
        self.gpr.to(self.device)

        self.gpr.load_state_dict(torch.load(model_dir / 'model.pt', map_location=self.device))
        self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib')
246 |
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP regression model: zero mean plus a scaled ARD RBF kernel."""

    def __init__(self, train_x, train_y, likelihood, initial_lengthscale):
        super().__init__(train_x, train_y, likelihood)

        self.mean_module = gpytorch.means.ZeroMean()

        # One lengthscale per input dimension (ARD), wrapped in an output scale.
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(ard_num_dims=train_x.size(1)),
        )

        # Start optimization from a data-driven lengthscale supplied by the caller.
        self.covar_module.base_kernel.lengthscale = initial_lengthscale

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
283 |
284 |
class ExactGP:
    """Exact GP regressor over standardized inputs and normalized MOS scores.

    Wraps ExactGPModel + GaussianLikelihood and trains by maximizing the
    exact marginal log likelihood with Adam.
    """

    def __init__(self, params=None, stage='stage1'):
        # `stage` is accepted for interface parity with the other learners;
        # the iteration budget is currently stage-independent.
        if params is None:
            self.params = {
                'training_iters': 10000,
            }
        else:
            self.params = params

        self.device = torch.device('cuda')

    def train(self, train_X, train_y, val_X, val_y):
        """Fit the GP on (train_X, train_y), logging validation MSE every 100 iters."""
        X_scaler = StandardScaler()
        train_X_sc = torch.from_numpy(X_scaler.fit_transform(train_X).astype(np.float32))
        train_y_sc = torch.from_numpy(normalize_score(train_y).astype(np.float32))

        val_X_sc = torch.from_numpy(X_scaler.transform(val_X).astype(np.float32))

        train_X_sc = train_X_sc.to(self.device)
        train_y_sc = train_y_sc.to(self.device)
        val_X_sc = val_X_sc.to(self.device)

        # Median pairwise distance of the training data as initial lengthscale.
        D = pairwise_distances(train_X_sc.cpu().numpy())
        distances = D[np.tril_indices(len(D), k=-1)]
        initial_lengthscale = np.median(distances)

        likelihood = gpytorch.likelihoods.GaussianLikelihood(
            noise_constraint=gpytorch.constraints.GreaterThan(1e-3),
        )
        gpr = ExactGPModel(train_X_sc, train_y_sc, likelihood,
                           initial_lengthscale=initial_lengthscale)

        gpr = gpr.to(self.device)
        likelihood = likelihood.to(self.device)

        optimizer = torch.optim.Adam(gpr.parameters(), lr=0.01)

        gpr.train()
        likelihood.train()

        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gpr)

        for i in range(self.params['training_iters']):
            gpr.train()
            likelihood.train()

            optimizer.zero_grad()

            output = gpr(train_X_sc)
            loss = -mll(output, train_y_sc)
            loss.backward()

            optimizer.step()

            if i % 100 == 0:
                # Periodic validation on the original (un-normalized) MOS scale.
                gpr.eval()
                likelihood.eval()
                with torch.no_grad():
                    preds = gpr(val_X_sc)
                    mean = preds.mean
                    pred_y_sc = mean.cpu().numpy()

                pred_y = inverse_normalize_score(pred_y_sc)

                mse = mean_squared_error(val_y, pred_y.ravel())

                # Fix: format the scalar loss value, not the 0-dim tensor.
                logger.info('Iter %d/%d - NMLL: %.3f - val_loss %.3f ' % (
                    i + 1, self.params['training_iters'], loss.item(), mse,
                ))

        self.X_scaler = X_scaler
        self.gpr = gpr

        # Shapes are needed to rebuild the conditioning set at load time.
        self.conf = {
            'output_shape': train_y.shape,
            'input_shape': train_X.shape,
        }

    def predict(self, X, df):
        """Predict MOS mean and a confidence band for X; write them into df."""
        self.gpr.eval()

        X_sc = torch.from_numpy(self.X_scaler.transform(X).astype(np.float32))
        X_sc = X_sc.to(self.device)

        with torch.no_grad():
            preds = self.gpr(X_sc)
            mean = preds.mean
            lower, upper = preds.confidence_region()

        mean_y = inverse_normalize_score(mean.cpu().numpy())
        lower_y = inverse_normalize_score(lower.cpu().numpy())
        upper_y = inverse_normalize_score(upper.cpu().numpy())

        df['pred_mos'] = mean_y.ravel()
        df['lower_mos'] = lower_y.ravel()
        df['upper_mos'] = upper_y.ravel()

        return df

    def save_model(self, out_dir: Path):
        """Persist model weights, input scaler, and model config under out_dir."""
        torch.save(self.gpr.state_dict(), out_dir / 'model.pt')
        joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib')

        with open(out_dir / 'model_config.json', encoding="utf-8", mode="w") as f:
            json.dump(self.conf, f, ensure_ascii=False, indent=2)

    def load_model(self, model_dir: Path, train_X, train_y):
        """Restore a trained model.

        train_X/train_y are required to rebuild the ExactGP conditioning set
        before loading the state dict.

        Fix: the config file is now opened in text mode inside a context
        manager; the previous `json.load(open(path, 'rb'))` leaked the file
        handle.
        """
        with open(model_dir / 'model_config.json', encoding='utf-8') as f:
            self.conf = json.load(f)

        self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib')

        train_X_sc = torch.from_numpy(self.X_scaler.transform(train_X).astype(np.float32))
        train_y_sc = torch.from_numpy(normalize_score(train_y).astype(np.float32))

        likelihood = gpytorch.likelihoods.GaussianLikelihood(
            noise_constraint=gpytorch.constraints.GreaterThan(1e-3),
        )
        # The lengthscale here is a placeholder; real values come from the state dict.
        self.gpr = ExactGPModel(train_X_sc, train_y_sc, likelihood,
                                initial_lengthscale=1.0)
        self.gpr.to(self.device)

        self.gpr.load_state_dict(torch.load(model_dir / 'model.pt', map_location=self.device))
435 |
436 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/make_ensemble_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import argparse
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | from sklearn.model_selection import KFold
8 |
9 | SEED_CV = 0
10 | K_CV = 5
11 |
12 |
def get_arg():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with a single ``datatrack`` attribute.
    """
    parser = argparse.ArgumentParser()
    # Fix: the argument is required, so the previous `default='phase1-main'`
    # was dead code and has been removed.
    parser.add_argument(
        "--datatrack", type=str, required=True,
        help="phase1-main or phase1-ood",
    )
    return parser.parse_args()
19 |
def write_wavnames(outpath, wavnames, mos_lookup):
    """Write 'wavname,mos' lines (sorted by name) to outpath.

    One known-broken utterance is always excluded from the output.
    """
    with open(outpath, 'w') as f:
        for name in sorted(wavnames):
            # This utterance is deliberately dropped from every split.
            if name == 'sys4bafa-uttc2e86f6.wav':
                print('Skip sys4bafa-uttc2e86f6')
                continue
            print('{},{}'.format(name, mos_lookup[name]), file=f)
28 |
29 |
def main():
    """Create 5-fold CV split files (train/val/test csv) for stacking."""
    args = get_arg()

    datadir = Path('../data', args.datatrack, 'DATA')
    outdir = Path('../out/ensemble-multidomain/fold', args.datatrack)

    os.makedirs(outdir, exist_ok=True)

    moslists = {
        "train": datadir / "sets/train_mos_list.txt",
        "val": datadir / "sets/val_mos_list.txt",
    }

    mos_lookup = {}
    wavnames = {'train': [], 'val': []}

    # Collect wav names and their MOS from both official splits.
    for split in ["train", "val"]:
        with open(moslists[split], "r") as fr:
            for line in fr:
                name, score = line.strip().split(",")[:2]
                mos_lookup[name] = score
                wavnames[split].append(name)

    kf = KFold(n_splits=K_CV, random_state=SEED_CV, shuffle=True)
    train_wavnames = np.asarray(wavnames['train'])

    for i, (tr_idx, va_idx) in enumerate(kf.split(train_wavnames)):
        write_wavnames(outdir / f'train-{i}.csv', train_wavnames[tr_idx], mos_lookup)
        write_wavnames(outdir / f'val-{i}.csv', train_wavnames[va_idx], mos_lookup)
        # The "test" list is the official val split; identical for every fold.
        write_wavnames(outdir / f'test-{i}.csv', wavnames['val'], mos_lookup)

    write_wavnames(outdir / 'train-all.csv', train_wavnames, mos_lookup)
72 |
73 |
74 | if __name__ == '__main__':
75 | main()
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/make_ensemble_dataset_wotest.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import argparse
4 | from pathlib import Path
5 | import random
6 |
7 | import numpy as np
8 | from sklearn.model_selection import KFold
9 |
10 | SEED_CV = 0
11 | K_CV = 5
12 |
13 | NUM_VAL_FILES = 100
14 |
def get_arg():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with a single ``datatrack`` attribute.
    """
    parser = argparse.ArgumentParser()
    # Fix: the argument is required, so the previous `default='phase1-main'`
    # was dead code and has been removed.
    parser.add_argument(
        "--datatrack", type=str, required=True,
        help="phase1-main or phase1-ood",
    )
    return parser.parse_args()
21 |
def write_wavnames(outpath, wavnames, mos_lookup):
    """Write 'wavname,mos' lines (sorted by name) to outpath.

    One known-broken utterance is always excluded from the output.
    """
    with open(outpath, 'w') as f:
        for name in sorted(wavnames):
            # This utterance is deliberately dropped from every split.
            if name == 'sys4bafa-uttc2e86f6.wav':
                print('Skip sys4bafa-uttc2e86f6')
                continue
            print('{},{}'.format(name, mos_lookup[name]), file=f)
30 |
31 |
def main():
    """Create CV folds when no official test set exists.

    A fixed random subset of NUM_VAL_FILES wav names serves as the test split.
    """
    args = get_arg()

    datadir = Path('../data', args.datatrack, 'DATA')
    outdir = Path('../out/ensemble-multidomain/fold', args.datatrack + '-wo_test')

    os.makedirs(outdir, exist_ok=True)

    # The 'external' track ships its own MOS list; other tracks use train+val.
    if args.datatrack == 'external':
        moslist_files = ['./external_mos_list.txt']
    else:
        moslist_files = [
            datadir / "sets/train_mos_list.txt",
            datadir / "sets/val_mos_list.txt",
        ]

    mos_lookup = {}
    wavnames = []

    for mos_file in moslist_files:
        with open(mos_file, "r") as fr:
            for line in fr:
                name, score = line.strip().split(",")[:2]
                mos_lookup[name] = score
                wavnames.append(name)

    train_wavnames = np.asarray(wavnames)

    # Deterministically sample names to act as the shared test split.
    rng = np.random.default_rng(SEED_CV)
    test_wavnames = list(sorted(rng.permutation(train_wavnames)[:NUM_VAL_FILES]))

    kf = KFold(n_splits=K_CV, random_state=SEED_CV, shuffle=True)

    for i, (tr_idx, va_idx) in enumerate(kf.split(train_wavnames)):
        write_wavnames(outdir / f'train-{i}.csv', train_wavnames[tr_idx], mos_lookup)
        write_wavnames(outdir / f'val-{i}.csv', train_wavnames[va_idx], mos_lookup)
        write_wavnames(outdir / f'test-{i}.csv', test_wavnames, mos_lookup)

    write_wavnames(outdir / 'train-all.csv', train_wavnames, mos_lookup)
79 |
80 |
81 | if __name__ == '__main__':
82 | main()
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/make_ensemble_testphase.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import argparse
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | from sklearn.model_selection import KFold
8 |
9 | SEED_CV = 0
10 | K_CV = 5
11 |
12 |
def get_arg():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with a single ``datatrack`` attribute.
    """
    parser = argparse.ArgumentParser()
    # Fix: main() asserts datatrack is 'phase1-main' or 'phase1-ood', so the
    # old help text ("testphase-main or testphase-ood") was misleading; the
    # dead default (unused because required=True) is dropped as well.
    parser.add_argument(
        "--datatrack", type=str, required=True,
        help="phase1-main or phase1-ood",
    )
    return parser.parse_args()
19 |
def write_wavnames(outpath, wavnames, mos_lookup):
    """Write 'wavname,mos' lines (sorted by name) to outpath.

    One known-broken utterance is always excluded from the output.
    """
    with open(outpath, 'w') as f:
        for name in sorted(wavnames):
            # This utterance is deliberately dropped from every split.
            if name == 'sys4bafa-uttc2e86f6.wav':
                print('Skip sys4bafa-uttc2e86f6')
                continue
            print('{},{}'.format(name, mos_lookup[name]), file=f)
28 |
29 |
def main():
    """Write the test-phase wav list (with dummy MOS) for a given datatrack."""
    args = get_arg()

    assert args.datatrack in ['phase1-main', 'phase1-ood']

    # Map the training-phase track name to its test-phase counterpart.
    if args.datatrack == 'phase1-main':
        pred_datatrack = 'testphase-main'
    else:
        pred_datatrack = 'testphase-ood'

    datadir = Path('../data', args.datatrack, 'DATA')
    outdir = Path('../out/ensemble-multidomain/fold', pred_datatrack)

    os.makedirs(outdir, exist_ok=True)

    mos_lookup = {}
    wavnames = []

    with open(datadir / "sets/test_mos_list.txt", "r") as fr:
        for line in fr:
            name = line.strip().split(",")[0]
            # Ground-truth MOS is unknown at test phase; use a sentinel.
            mos_lookup[name] = -99.0
            wavnames.append(name)

    write_wavnames(outdir / 'test-0.csv', wavnames, mos_lookup)
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/models.py:
--------------------------------------------------------------------------------
1 |
import pickle
from logging import getLogger
from pathlib import Path

import joblib
import lightgbm
import optuna
import sklearn.ensemble
import sklearn.linear_model
import sklearn.model_selection
import sklearn.svm
import torch
from sklearn.metrics import make_scorer, mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from data_util import normalize_score, inverse_normalize_score
18 |
19 | logger = getLogger(__name__)
20 |
21 | K_CV = 5
22 | HP_OPT_TRIALS = 100
23 | HP_OPT_SEED = 0
24 | LinearSVR_OPT_MAX_ITER = 1000
25 | KernelSVR_OPT_MAX_ITER = 100000
26 |
class Ridge:
    """Ridge regression learner on standardized features and normalized scores."""

    def __init__(self, params=None):
        # Default hyperparameters were tuned with w2v_small features.
        if params is None:
            self.params = {
                'alpha': 36.315622558743634,
            }
        else:
            self.params = params

    def train(self, train_X, train_y, val_X, val_y):
        """Fit on the training split and log MSE on the validation split."""
        scaler = StandardScaler()
        X_tr = scaler.fit_transform(train_X)
        y_tr = normalize_score(train_y)

        model = sklearn.linear_model.Ridge(**(self.params))
        model.fit(X_tr, y_tr)

        val_pred = inverse_normalize_score(model.predict(scaler.transform(val_X)))
        logger.info('Val MSE: {:f}'.format(mean_squared_error(val_y, val_pred)))

        self.ridge = model
        self.X_scaler = scaler

    def predict(self, X, df):
        """Write inverse-normalized predictions into df['pred_mos']; return df."""
        preds = inverse_normalize_score(
            self.ridge.predict(self.X_scaler.transform(X)))
        df['pred_mos'] = preds.ravel()
        return df

    def save_model(self, out_dir: Path):
        """Persist the fitted regressor and input scaler."""
        joblib.dump(self.ridge, out_dir / 'model.joblib')
        joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib')

    def load_model(self, model_dir: Path):
        """Load a regressor/scaler pair written by save_model."""
        self.ridge = joblib.load(model_dir / 'model.joblib')
        self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib')

    def optimize_hp(self, train_X, train_y):
        """Search alpha with OptunaSearchCV; return the best parameter dict."""
        scaler = StandardScaler()
        X_sc = scaler.fit_transform(train_X)
        y_sc = normalize_score(train_y)

        search = optuna.integration.OptunaSearchCV(
            sklearn.linear_model.Ridge(),
            {'alpha': optuna.distributions.LogUniformDistribution(1e-5, 1e+5)},
            cv=K_CV,
            n_trials=HP_OPT_TRIALS,
            random_state=HP_OPT_SEED,
            scoring=make_scorer(mean_squared_error, greater_is_better=False),
            verbose=0)

        search.fit(X_sc, y_sc)

        return search.best_params_
107 |
108 |
class LinearSVR:
    """Linear SVR learner on standardized features and normalized scores."""

    def __init__(self, params=None, stage='stage1'):
        # Default hyperparameters were tuned with w2v_small features.
        if params is None:
            self.params = {
                'C': 0.01982058833734277,
                'epsilon': 0.23072531432463972,
            }
        else:
            self.params = params

        self.max_iter = 10000
        # Stages after stage1 use a reduced budget for hyperparameter search.
        if stage == 'stage1':
            self.opt_max_iter = self.max_iter
        else:
            self.opt_max_iter = LinearSVR_OPT_MAX_ITER

    def train(self, train_X, train_y, val_X, val_y):
        """Fit on the training split and log MSE on the validation split."""
        scaler = StandardScaler()
        X_tr = scaler.fit_transform(train_X)
        y_tr = normalize_score(train_y)

        model = sklearn.svm.LinearSVR(max_iter=self.max_iter, **self.params)
        model.fit(X_tr, y_tr)

        val_pred = inverse_normalize_score(model.predict(scaler.transform(val_X)))
        logger.info('Val MSE: {:f}'.format(mean_squared_error(val_y, val_pred)))

        self.svr = model
        self.X_scaler = scaler

    def predict(self, X, df):
        """Write inverse-normalized predictions into df['pred_mos']; return df."""
        preds = inverse_normalize_score(
            self.svr.predict(self.X_scaler.transform(X)))
        df['pred_mos'] = preds.ravel()
        return df

    def save_model(self, out_dir: Path):
        """Persist the fitted regressor and input scaler."""
        joblib.dump(self.svr, out_dir / 'model.joblib')
        joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib')

    def load_model(self, model_dir: Path):
        """Load a regressor/scaler pair written by save_model."""
        self.svr = joblib.load(model_dir / 'model.joblib')
        self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib')

    def optimize_hp(self, train_X, train_y):
        """Search C and epsilon with OptunaSearchCV; return the best params."""
        scaler = StandardScaler()
        X_sc = scaler.fit_transform(train_X)
        y_sc = normalize_score(train_y)

        search = optuna.integration.OptunaSearchCV(
            sklearn.svm.LinearSVR(max_iter=self.opt_max_iter),
            {
                'C': optuna.distributions.LogUniformDistribution(1e-5, 1e+5),
                'epsilon': optuna.distributions.UniformDistribution(0, 1),
            },
            cv=K_CV,
            n_trials=HP_OPT_TRIALS,
            random_state=HP_OPT_SEED,
            scoring=make_scorer(mean_squared_error, greater_is_better=False),
            verbose=0)

        search.fit(X_sc, y_sc)

        return search.best_params_
197 |
198 |
class KernelSVR:
    """RBF-kernel SVR learner on standardized features and normalized scores."""

    def __init__(self, params=None, stage='stage1'):
        # Default hyperparameters were tuned with w2v_small features.
        if params is None:
            self.params = {
                'C': 4.483354499092266,
                'epsilon': 0.2177054099781604,
                'gamma': 0.0006981540829311363,
            }
        else:
            self.params = params

        # Stages after stage1 cap iterations during hyperparameter search;
        # stage1 lets libsvm run to convergence (-1 = no limit).
        if stage == 'stage1':
            self.opt_max_iter = -1
        else:
            self.opt_max_iter = KernelSVR_OPT_MAX_ITER

    def train(self, train_X, train_y, val_X, val_y):
        """Fit on the training split and log MSE on the validation split."""
        scaler = StandardScaler()
        X_tr = scaler.fit_transform(train_X)
        y_tr = normalize_score(train_y)

        model = sklearn.svm.SVR(kernel='rbf', **self.params)
        model.fit(X_tr, y_tr)

        val_pred = inverse_normalize_score(model.predict(scaler.transform(val_X)))
        logger.info('Val MSE: {:f}'.format(mean_squared_error(val_y, val_pred)))

        self.svr = model
        self.X_scaler = scaler

    def predict(self, X, df):
        """Write inverse-normalized predictions into df['pred_mos']; return df."""
        preds = inverse_normalize_score(
            self.svr.predict(self.X_scaler.transform(X)))
        df['pred_mos'] = preds.ravel()
        return df

    def save_model(self, out_dir: Path):
        """Persist the fitted regressor and input scaler."""
        joblib.dump(self.svr, out_dir / 'model.joblib')
        joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib')

    def load_model(self, model_dir: Path):
        """Load a regressor/scaler pair written by save_model."""
        self.svr = joblib.load(model_dir / 'model.joblib')
        self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib')

    def optimize_hp(self, train_X, train_y):
        """Search C, epsilon, and gamma with OptunaSearchCV; return best params."""
        scaler = StandardScaler()
        X_sc = scaler.fit_transform(train_X)
        y_sc = normalize_score(train_y)

        search = optuna.integration.OptunaSearchCV(
            sklearn.svm.SVR(max_iter=self.opt_max_iter),
            {
                'C': optuna.distributions.LogUniformDistribution(1e-5, 1e+5),
                'epsilon': optuna.distributions.UniformDistribution(0, 1),
                'gamma': optuna.distributions.LogUniformDistribution(1e-5, 1e+5),
            },
            cv=K_CV,
            n_trials=HP_OPT_TRIALS,
            random_state=HP_OPT_SEED,
            scoring=make_scorer(mean_squared_error, greater_is_better=False),
            verbose=0)

        search.fit(X_sc, y_sc)

        return search.best_params_
292 |
293 |
class RandomForest:
    """Random forest learner on standardized features and normalized scores."""

    def __init__(self, params=None):
        if params is None:
            self.params = {
                'n_estimators': 100,
                'max_depth': 1000,
            }
        else:
            self.params = params

    def train(self, train_X, train_y, val_X, val_y):
        """Fit on the training split and log MSE on the validation split."""
        scaler = StandardScaler()
        X_tr = scaler.fit_transform(train_X)
        y_tr = normalize_score(train_y)

        model = sklearn.ensemble.RandomForestRegressor(**self.params)
        model.fit(X_tr, y_tr)

        val_pred = inverse_normalize_score(model.predict(scaler.transform(val_X)))
        logger.info('Val MSE: {:f}'.format(mean_squared_error(val_y, val_pred)))

        self.rf = model
        self.X_scaler = scaler

    def predict(self, X, df):
        """Write inverse-normalized predictions into df['pred_mos']; return df."""
        preds = inverse_normalize_score(
            self.rf.predict(self.X_scaler.transform(X)))
        df['pred_mos'] = preds.ravel()
        return df

    def save_model(self, out_dir: Path):
        """Persist the fitted regressor and input scaler."""
        joblib.dump(self.rf, out_dir / 'model.joblib')
        joblib.dump(self.X_scaler, out_dir / 'X_scaler.joblib')

    def load_model(self, model_dir: Path):
        """Load a regressor/scaler pair written by save_model."""
        self.rf = joblib.load(model_dir / 'model.joblib')
        self.X_scaler = joblib.load(model_dir / 'X_scaler.joblib')

    def optimize_hp(self, train_X, train_y):
        """Search forest hyperparameters with OptunaSearchCV; return best params."""
        scaler = StandardScaler()
        X_sc = scaler.fit_transform(train_X)
        y_sc = normalize_score(train_y)

        search = optuna.integration.OptunaSearchCV(
            sklearn.ensemble.RandomForestRegressor(),
            {
                'n_estimators': optuna.distributions.IntUniformDistribution(1, 100),
                'max_depth': optuna.distributions.IntLogUniformDistribution(1, 1000),
                'max_features': optuna.distributions.CategoricalDistribution(['auto', 'sqrt', 'log2']),
            },
            cv=K_CV,
            n_trials=HP_OPT_TRIALS,
            random_state=HP_OPT_SEED,
            scoring=make_scorer(mean_squared_error, greater_is_better=False),
            verbose=0)

        search.fit(X_sc, y_sc)

        return search.best_params_
377 |
378 |
class LightGBM:
    """Gradient-boosted tree learner (LightGBM) with early stopping.

    Unlike the other learners, train() operates on raw features and raw MOS
    targets (no scaler / score normalization).
    """

    def __init__(self, params=None):
        # Default hyperparameters were tuned with w2v_small features.
        if params is None:
            self.params = {
                'lambda_l1': 0.03708013547428929,
                'lambda_l2': 3.1884740170707856e-07,
                'num_leaves': 220,
                'feature_fraction': 0.6747205882024254,
                'bagging_fraction': 0.9367956222111139,
                'bagging_freq': 2,
                'min_child_samples': 92,
                'max_depth': 10
            }
        else:
            self.params = params

    def train(self, train_X, train_y, val_X, val_y):
        """Fit with early stopping on the validation split and log its MSE."""
        train_set = lightgbm.Dataset(train_X, train_y)
        valid_set = lightgbm.Dataset(val_X, val_y, reference=train_set)

        lgb_model = lightgbm.train(
            params=self.params,
            train_set=train_set,
            valid_sets=[train_set, valid_set],
            num_boost_round=10000,
            early_stopping_rounds=10,
        )

        pred_y = lgb_model.predict(val_X, num_iteration=lgb_model.best_iteration)

        val_mse = mean_squared_error(val_y, pred_y)

        logger.info('Val MSE: {:f}'.format(val_mse))

        self.lgb_model = lgb_model

    def predict(self, X, df):
        """Write best-iteration predictions into df['pred_mos']; return df."""
        pred_y = self.lgb_model.predict(X, num_iteration=self.lgb_model.best_iteration)

        df['pred_mos'] = pred_y.ravel()

        return df

    def save_model(self, out_dir: Path):
        """Pickle the trained booster under out_dir."""
        with open(out_dir / 'model.pkl', 'wb') as f:
            pickle.dump(self.lgb_model, f)

    def load_model(self, model_dir: Path):
        """Load a booster written by save_model.

        Fix: the pickle file is now opened via a context manager; the old
        `pickle.load(open(path, 'rb'))` leaked the file handle.
        """
        with open(model_dir / 'model.pkl', 'rb') as f:
            self.lgb_model = pickle.load(f)

    def optimize_hp(self, train_X, train_y):
        """Tune LightGBM parameters with LightGBMTunerCV (CV on normalized data)."""
        X_scaler = StandardScaler()
        train_X_sc = X_scaler.fit_transform(train_X)
        train_y_sc = normalize_score(train_y)

        lgb_train = optuna.integration.lightgbm.Dataset(train_X_sc, train_y_sc)

        lgbm_params = {
            'objective': 'regression',
            'metric': 'mse',
            'verbosity': -1,
        }

        folds = sklearn.model_selection.KFold(n_splits=K_CV, shuffle=True, random_state=HP_OPT_SEED)

        tuner_cv = optuna.integration.lightgbm.LightGBMTunerCV(
            lgbm_params, lgb_train,
            num_boost_round=1000,
            early_stopping_rounds=100,
            folds=folds,
            optuna_seed=HP_OPT_SEED,
        )

        tuner_cv.run()

        return tuner_cv.best_params
470 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/modules.py:
--------------------------------------------------------------------------------
1 | ../../strong/modules.py
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/opt_stage1.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import logging
5 | from logging import getLogger
6 | import random
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from data_util import load_stage1_train_all_data
12 |
13 | from models import Ridge, LinearSVR, KernelSVR, RandomForest, LightGBM
14 | from gp_models import SVGP
15 | import json
16 |
17 | logger = getLogger(__name__)
18 |
19 | RAND_SEED = 0
20 |
def get_arg():
    """Parse positional CLI arguments: method, datatrack, ssl_type."""
    import argparse
    p = argparse.ArgumentParser()
    for arg_name in ('method', 'datatrack', 'ssl_type'):
        p.add_argument(arg_name)
    return p.parse_args()
28 |
29 |
30 |
31 |
def main():
    """Tune stage-1 hyperparameters for one (method, datatrack, ssl_type) combo."""
    args = get_arg()

    # Fix all RNG seeds for reproducible search.
    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)

    data = load_stage1_train_all_data(
        datatrack=args.datatrack,
        ssl_type=args.ssl_type,
    )

    # Select the learner; 'rf' and 'svgp' are not supported for HP search.
    if args.method == 'ridge':
        model = Ridge()
    elif args.method == 'linear_svr':
        model = LinearSVR(stage='stage1')
    elif args.method == 'kernel_svr':
        model = KernelSVR(stage='stage1')
    elif args.method == 'lightgbm':
        model = LightGBM()
    elif args.method in ('rf', 'svgp'):
        raise NotImplementedError()
    else:
        raise RuntimeError('Not supported method: "{}"'.format(args.method))

    best_params = model.optimize_hp(data['X'], data['y'])
    logger.info(best_params)

    out_dir = (Path('../out/ensemble-multidomain/opt_hp_stage1')
               / args.datatrack / f'{args.method}-{args.ssl_type}')
    os.makedirs(out_dir, exist_ok=True)

    with open(out_dir / 'params.json', encoding="utf-8", mode="w") as f:
        json.dump(best_params, f, ensure_ascii=False, indent=2)
70 |
71 |
72 | if __name__ == '__main__':
73 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{')
74 | main()
75 |
76 |
77 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/opt_stage2.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import logging
5 | from logging import getLogger
6 | import random
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from data_util import load_stage2_train_all_data
12 |
13 | from models import Ridge, LinearSVR, KernelSVR, RandomForest, LightGBM
14 | from gp_models import SVGP
15 | import json
16 |
17 | logger = getLogger(__name__)
18 |
19 | RAND_SEED = 0
20 |
def get_arg():
    """Parse positional CLI arguments: method, datatrack, feat_type."""
    import argparse
    p = argparse.ArgumentParser()
    for arg_name in ('method', 'datatrack', 'feat_type'):
        p.add_argument(arg_name)
    return p.parse_args()
28 |
29 |
30 |
31 |
def main():
    """Tune stage-2 hyperparameters for one (method, datatrack, feat_type) combo."""
    args = get_arg()

    # Fix all RNG seeds for reproducible search.
    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)

    data = load_stage2_train_all_data(
        datatrack=args.datatrack,
        feat_type=args.feat_type,
    )

    # Select the learner; 'rf' and 'svgp' are not supported for HP search.
    if args.method == 'ridge':
        model = Ridge()
    elif args.method == 'linear_svr':
        model = LinearSVR(stage='stage2')
    elif args.method == 'kernel_svr':
        model = KernelSVR(stage='stage2')
    elif args.method == 'lightgbm':
        model = LightGBM()
    elif args.method in ('rf', 'svgp'):
        raise NotImplementedError()
    else:
        raise RuntimeError('Not supported method: "{}"'.format(args.method))

    best_params = model.optimize_hp(data['X'], data['y'])
    logger.info(best_params)

    out_dir = (Path('../out/ensemble-multidomain/opt_hp_stage2')
               / args.datatrack / f'{args.method}-{args.feat_type}')
    os.makedirs(out_dir, exist_ok=True)

    with open(out_dir / 'params.json', encoding="utf-8", mode="w") as f:
        json.dump(best_params, f, ensure_ascii=False, indent=2)
70 |
71 |
72 | if __name__ == '__main__':
73 | logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{')
74 | main()
75 |
76 |
77 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/opt_stage3.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import logging
5 | from logging import getLogger
6 | import random
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from data_util import load_stage3_train_all_data
12 |
13 | from models import Ridge, LinearSVR, KernelSVR, RandomForest, LightGBM
14 | from gp_models import SVGP
15 | import json
16 |
17 | logger = getLogger(__name__)
18 |
19 | RAND_SEED = 0
20 |
def get_arg():
    """Parse positional CLI arguments: method, datatrack, feat_type."""
    import argparse

    parser = argparse.ArgumentParser()
    for name in ('method', 'datatrack', 'feat_type'):
        parser.add_argument(name)
    return parser.parse_args()
28 |
29 |
30 |
31 |
def main():
    """Optimize stage-3 hyper-parameters for one (method, datatrack, feat_type) and save them as JSON."""
    args = get_arg()

    # Fix every RNG so repeated optimization runs are reproducible.
    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)

    data = load_stage3_train_all_data(
        datatrack=args.datatrack,
        feat_type=args.feat_type,
    )

    # rf / svgp are deliberately disabled for this stage.
    if args.method in ('rf', 'svgp'):
        raise NotImplementedError()

    factories = {
        'ridge': Ridge,
        'linear_svr': LinearSVR,
        'kernel_svr': KernelSVR,
        'lightgbm': LightGBM,
    }
    if args.method not in factories:
        raise RuntimeError('Not supported method: "{}"'.format(args.method))
    model = factories[args.method]()

    best_params = model.optimize_hp(data['X'], data['y'])

    logger.info(best_params)

    out_dir = (Path('../out/ensemble-multidomain/opt_hp_stage3')
               / args.datatrack / f'{args.method}-{args.feat_type}')
    os.makedirs(out_dir, exist_ok=True)

    with open(out_dir / 'params.json', encoding="utf-8", mode="w") as f:
        json.dump(best_params, f, ensure_ascii=False, indent=2)
70 |
71 |
if __name__ == '__main__':
    # INFO-level logging with brace-style formatting: "<logger name>: <message>".
    logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{')
    main()
75 |
76 |
77 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/pred_stage1.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import logging
5 | from logging import getLogger
6 | import random
7 | import json
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from data_util import load_stage1_data, load_stage1_train_all_data, load_stage1_test_data
13 |
14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest
15 | from gp_models import SVGP, ExactGP
16 |
17 | logger = getLogger(__name__)
18 |
19 | RAND_SEED = 0
20 |
def get_arg():
    """Parse CLI: method, train_datatrack, ssl_type, i_cv (int), pred_datatrack."""
    import argparse

    parser = argparse.ArgumentParser()
    for name in ('method', 'train_datatrack', 'ssl_type'):
        parser.add_argument(name)
    parser.add_argument('i_cv', type=int)
    parser.add_argument('pred_datatrack')
    return parser.parse_args()
30 |
31 |
def main():
    """Apply a trained stage-1 model (one CV fold) to another datatrack and dump predictions.

    Writes train.csv and test.csv under the model's pred-<datatrack> directory.
    """
    args = get_arg()

    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)

    # Flat dispatch table; unknown method names fail loudly.
    constructors = {
        'svgp': SVGP,
        'exactgp': ExactGP,
        'rf': RandomForest,
        'ridge': Ridge,
        'linear_svr': LinearSVR,
        'kernel_svr': KernelSVR,
        'lightgbm': LightGBM,
    }
    if args.method not in constructors:
        raise RuntimeError('Not supported method: "{}"'.format(args.method))
    model = constructors[args.method]()

    model_dir = (Path('../out/ensemble-multidomain/stage1') / args.train_datatrack
                 / f'{args.method}-{args.ssl_type}' / str(args.i_cv))
    out_dir = model_dir / f'pred-{args.pred_datatrack}'
    os.makedirs(out_dir, exist_ok=True)

    logger.info('Outdir: {}'.format(out_dir))

    if args.method == 'exactgp':
        # ExactGP must be reconstructed from the training data it was fit on.
        train_data = load_stage1_data(
            datatrack=args.train_datatrack,
            ssl_type=args.ssl_type,
            i_cv=args.i_cv,
        )
        model.load_model(model_dir, train_data['train']['X'], train_data['train']['y'])
    else:
        # svgp and all sklearn-style models restore from the directory alone.
        model.load_model(model_dir)

    pred_data = {
        'train': load_stage1_train_all_data(
            datatrack=args.pred_datatrack,
            ssl_type=args.ssl_type,
        ),
        'test': load_stage1_test_data(
            datatrack=args.pred_datatrack,
            ssl_type=args.ssl_type,
        ),
    }

    for split in ('train', 'test'):
        df = model.predict(pred_data[split]['X'], pred_data[split]['df'])
        df.to_csv(out_dir / f'{split}.csv')
97 |
98 |
99 |
if __name__ == '__main__':
    # INFO-level logging with brace-style formatting: "<logger name>: <message>".
    logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{')
    main()
103 |
104 |
105 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/pred_stage1_ood.sh:
--------------------------------------------------------------------------------
1 |
set -eu

# SSL feature types (one stage-1 weak learner is trained per type).
ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large"
#ssl_types="w2v_small w2v_large wavlm_base wavlm_large hubert_base hubert_large"

# for ood
# Predict the phase1-ood track with every stage-1 model trained on the
# external and main tracks: all methods x all SSL types x 5 CV folds.
for train_datatrack in external-wo_test phase1-main; do
    for pred_datatrack in phase1-ood ; do
        for ssl_type in $ssl_types; do
            for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do
                for i_cv in 0 1 2 3 4; do
                    echo "${method}, ${train_datatrack}, ${ssl_type}, ${i_cv}, ${pred_datatrack}"
                    python -u pred_stage1.py ${method} ${train_datatrack} ${ssl_type} ${i_cv} ${pred_datatrack}
                done
            done
        done
    done
done

echo "done"
22 |
23 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/pred_testphase_stage1.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import logging
5 | from logging import getLogger
6 | import random
7 | import json
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from data_util import load_stage1_data, load_stage1_train_all_data, load_stage1_test_data
13 |
14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest
15 | from gp_models import SVGP, ExactGP
16 |
17 | logger = getLogger(__name__)
18 |
19 | RAND_SEED = 0
20 |
def get_arg():
    """Parse CLI: method, train_datatrack, ssl_type, i_cv (int), pred_datatrack."""
    import argparse

    parser = argparse.ArgumentParser()
    for name in ('method', 'train_datatrack', 'ssl_type'):
        parser.add_argument(name)
    parser.add_argument('i_cv', type=int)
    parser.add_argument('pred_datatrack')
    return parser.parse_args()
30 |
31 |
def main():
    """Run a trained stage-1 model (one CV fold) on a test-phase track and write test.csv."""
    args = get_arg()

    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)

    # Flat dispatch table; unknown method names fail loudly.
    constructors = {
        'svgp': SVGP,
        'exactgp': ExactGP,
        'rf': RandomForest,
        'ridge': Ridge,
        'linear_svr': LinearSVR,
        'kernel_svr': KernelSVR,
        'lightgbm': LightGBM,
    }
    if args.method not in constructors:
        raise RuntimeError('Not supported method: "{}"'.format(args.method))
    model = constructors[args.method]()

    model_dir = (Path('../out/ensemble-multidomain/stage1') / args.train_datatrack
                 / f'{args.method}-{args.ssl_type}' / str(args.i_cv))
    out_dir = model_dir / f'pred-{args.pred_datatrack}'
    os.makedirs(out_dir, exist_ok=True)

    logger.info('Outdir: {}'.format(out_dir))

    if args.method == 'exactgp':
        # ExactGP must be reconstructed from the training data it was fit on.
        train_data = load_stage1_data(
            datatrack=args.train_datatrack,
            ssl_type=args.ssl_type,
            i_cv=args.i_cv,
        )
        model.load_model(model_dir, train_data['train']['X'], train_data['train']['y'])
    else:
        # svgp and all sklearn-style models restore from the directory alone.
        model.load_model(model_dir)

    # Test phase: only the test split is available for prediction.
    test_data = load_stage1_test_data(
        datatrack=args.pred_datatrack,
        ssl_type=args.ssl_type,
    )

    df_test = model.predict(test_data['X'], test_data['df'])
    df_test.to_csv(out_dir / 'test.csv')
91 |
92 |
93 |
if __name__ == '__main__':
    # INFO-level logging with brace-style formatting: "<logger name>: <message>".
    logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{')
    main()
97 |
98 |
99 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/pred_testphase_stage1_main.sh:
--------------------------------------------------------------------------------
1 |
set -eu

# SSL feature types (one stage-1 weak learner is trained per type).
ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large"
# ssl_types="w2v_small w2v_large wavlm_base wavlm_large hubert_base hubert_large"

# Predict the test-phase main track with every stage-1 model trained on
# phase1-main: all methods x all SSL types x 5 CV folds.
for train_datatrack in phase1-main; do
    for pred_datatrack in testphase-main; do
        for ssl_type in ${ssl_types}; do
            for method in ridge linear_svr kernel_svr rf lightgbm exactgp ; do
                for i_cv in 0 1 2 3 4; do
                    echo "${method}, ${train_datatrack}, ${ssl_type}, ${i_cv}, ${pred_datatrack}"
                    python -u pred_testphase_stage1.py ${method} ${train_datatrack} ${ssl_type} ${i_cv} ${pred_datatrack}
                done
            done
        done
    done
done

echo "done"
21 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/pred_testphase_stage1_ood.sh:
--------------------------------------------------------------------------------
1 |
set -eu

# SSL feature types (one stage-1 weak learner is trained per type).
ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large"

# Predict the test-phase OOD track with stage-1 models from all three
# training tracks: all methods x all SSL types x 5 CV folds.
for train_datatrack in phase1-ood external-wo_test phase1-main; do
    for pred_datatrack in testphase-ood ; do
        for ssl_type in ${ssl_types}; do
            for method in ridge linear_svr kernel_svr rf lightgbm exactgp ; do
                for i_cv in 0 1 2 3 4; do
                    echo "${method}, ${train_datatrack}, ${ssl_type}, ${i_cv}, ${pred_datatrack}"
                    python -u pred_testphase_stage1.py ${method} ${train_datatrack} ${ssl_type} ${i_cv} ${pred_datatrack}
                done
            done
        done
    done
done

echo "done"
20 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/pred_testphase_stage2-3_main.sh:
--------------------------------------------------------------------------------
1 |
set -eu

# Test-phase stage 2-3 pipeline for the main track: collect stage-1
# predictions, predict with each stage-2 model, collect, predict with the
# stage-3 aggregator, then score.
train_datatrack=phase1-main
pred_datatrack=testphase-main
feat_type=main-strong1-weak48

python -u collect_stage1_testphase_result.py ${pred_datatrack} ${feat_type}

for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do
    for i_cv in 0 1 2 3 4; do
        echo "${method}, ${train_datatrack}, ${feat_type}, ${i_cv}, ${pred_datatrack}"
        python -u pred_testphase_stage2.py ${method} ${train_datatrack} ${feat_type} ${i_cv} ${pred_datatrack}
    done
done

echo "Collect stage2 data."
python -u collect_stage2_testphase_result.py ${pred_datatrack} ${feat_type}

# Stage 3 uses ridge only.
for method in ridge; do
    for i_cv in 0 1 2 3 4; do
        echo "Run stage3: ${method}, ${train_datatrack}, ${feat_type}, ${i_cv}, ${pred_datatrack}"
        python -u pred_testphase_stage3.py ${method} ${train_datatrack} ${feat_type} ${i_cv} ${pred_datatrack}
    done
done

echo "Calculate result."

python -u calc_testphase_result.py ${pred_datatrack} ${feat_type}

echo "done"
32 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/pred_testphase_stage2-3_ood.sh:
--------------------------------------------------------------------------------
1 |
set -eu

# Test-phase stage 2-3 pipeline for the OOD track: collect stage-1
# predictions, predict with each stage-2 model, collect, predict with the
# stage-3 aggregator, then score.
train_datatrack=phase1-ood
pred_datatrack=testphase-ood
feat_type=ood-strong1-weak144

python -u collect_stage1_testphase_result.py ${pred_datatrack} ${feat_type}

for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do
    for i_cv in 0 1 2 3 4; do
        echo "${method}, ${train_datatrack}, ${feat_type}, ${i_cv}, ${pred_datatrack}"
        python -u pred_testphase_stage2.py ${method} ${train_datatrack} ${feat_type} ${i_cv} ${pred_datatrack}
    done
done

echo "Collect stage2 data."
python -u collect_stage2_testphase_result.py ${pred_datatrack} ${feat_type}

# Stage 3 uses ridge only.
for method in ridge; do
    for i_cv in 0 1 2 3 4; do
        echo "Run stage3: ${method}, ${train_datatrack}, ${feat_type}, ${i_cv}, ${pred_datatrack}"
        python -u pred_testphase_stage3.py ${method} ${train_datatrack} ${feat_type} ${i_cv} ${pred_datatrack}
    done
done

echo "Calculate result."

python -u calc_testphase_result.py ${pred_datatrack} ${feat_type}

echo "done"
32 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/pred_testphase_stage2.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import logging
5 | from logging import getLogger
6 | import random
7 | import json
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from data_util import load_stage2_data, load_stage2_test_data
13 |
14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest
15 | from gp_models import SVGP, ExactGP
16 |
17 | logger = getLogger(__name__)
18 |
19 | RAND_SEED = 0
20 |
def get_arg():
    """Parse CLI: method, train_datatrack, feat_type, i_cv (int), pred_datatrack."""
    import argparse

    parser = argparse.ArgumentParser()
    for name in ('method', 'train_datatrack', 'feat_type'):
        parser.add_argument(name)
    parser.add_argument('i_cv', type=int)
    parser.add_argument('pred_datatrack')
    return parser.parse_args()
30 |
31 |
def main():
    """Run a trained stage-2 model (one CV fold) on a test-phase track and write test.csv."""
    args = get_arg()

    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)

    # Flat dispatch table; unknown method names fail loudly.
    constructors = {
        'svgp': SVGP,
        'exactgp': ExactGP,
        'rf': RandomForest,
        'ridge': Ridge,
        'linear_svr': LinearSVR,
        'kernel_svr': KernelSVR,
        'lightgbm': LightGBM,
    }
    if args.method not in constructors:
        raise RuntimeError('Not supported method: "{}"'.format(args.method))
    model = constructors[args.method]()

    model_dir = (Path('../out/ensemble-multidomain/stage2') / args.train_datatrack
                 / f'{args.method}-{args.feat_type}' / str(args.i_cv))
    out_dir = model_dir / f'pred-{args.pred_datatrack}'
    os.makedirs(out_dir, exist_ok=True)

    logger.info('Outdir: {}'.format(out_dir))

    if args.method == 'exactgp':
        # ExactGP must be reconstructed from the training data it was fit on.
        train_data = load_stage2_data(
            datatrack=args.train_datatrack,
            feat_type=args.feat_type,
            i_cv=args.i_cv,
        )
        model.load_model(model_dir, train_data['train']['X'], train_data['train']['y'])
    else:
        # svgp and all sklearn-style models restore from the directory alone.
        model.load_model(model_dir)

    # Test phase: only the test split is available for prediction.
    test_data = load_stage2_test_data(
        datatrack=args.pred_datatrack,
        feat_type=args.feat_type,
    )

    df_test = model.predict(test_data['X'], test_data['df'])
    df_test.to_csv(out_dir / 'test.csv')
91 |
92 |
93 |
if __name__ == '__main__':
    # INFO-level logging with brace-style formatting: "<logger name>: <message>".
    logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{')
    main()
97 |
98 |
99 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/pred_testphase_stage3.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import logging
5 | from logging import getLogger
6 | import random
7 | import json
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from data_util import load_stage3_data, load_stage3_test_data
13 |
14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest
15 | from gp_models import SVGP, ExactGP
16 |
17 | logger = getLogger(__name__)
18 |
19 | RAND_SEED = 0
20 |
def get_arg():
    """Parse CLI: method, train_datatrack, feat_type, i_cv (int), pred_datatrack."""
    import argparse

    parser = argparse.ArgumentParser()
    for name in ('method', 'train_datatrack', 'feat_type'):
        parser.add_argument(name)
    parser.add_argument('i_cv', type=int)
    parser.add_argument('pred_datatrack')
    return parser.parse_args()
30 |
31 |
def main():
    """Run a trained stage-3 model (one CV fold) on a test-phase track and write test.csv."""
    args = get_arg()

    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)

    # Flat dispatch table; unknown method names fail loudly.
    constructors = {
        'svgp': SVGP,
        'exactgp': ExactGP,
        'rf': RandomForest,
        'ridge': Ridge,
        'linear_svr': LinearSVR,
        'kernel_svr': KernelSVR,
        'lightgbm': LightGBM,
    }
    if args.method not in constructors:
        raise RuntimeError('Not supported method: "{}"'.format(args.method))
    model = constructors[args.method]()

    model_dir = (Path('../out/ensemble-multidomain/stage3') / args.train_datatrack
                 / f'{args.method}-{args.feat_type}' / str(args.i_cv))
    out_dir = model_dir / f'pred-{args.pred_datatrack}'
    os.makedirs(out_dir, exist_ok=True)

    logger.info('Outdir: {}'.format(out_dir))

    if args.method == 'exactgp':
        # ExactGP must be reconstructed from the training data it was fit on.
        train_data = load_stage3_data(
            datatrack=args.train_datatrack,
            feat_type=args.feat_type,
            i_cv=args.i_cv,
        )
        model.load_model(model_dir, train_data['train']['X'], train_data['train']['y'])
    else:
        # svgp and all sklearn-style models restore from the directory alone.
        model.load_model(model_dir)

    # Test phase: only the test split is available for prediction.
    test_data = load_stage3_test_data(
        datatrack=args.pred_datatrack,
        feat_type=args.feat_type,
    )

    df_test = model.predict(test_data['X'], test_data['df'])
    df_test.to_csv(out_dir / 'test.csv')
91 |
92 |
93 |
if __name__ == '__main__':
    # INFO-level logging with brace-style formatting: "<logger name>: <message>".
    logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{')
    main()
97 |
98 |
99 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/run_stage1.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import logging
5 | from logging import getLogger
6 | import random
7 | import json
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from data_util import load_stage1_data
13 |
14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest
15 | from gp_models import SVGP, ExactGP
16 |
17 | logger = getLogger(__name__)
18 |
19 | RAND_SEED = 0
20 |
def get_arg():
    """Parse CLI: method, datatrack, ssl_type, i_cv (int), optional --use_opt flag."""
    import argparse

    parser = argparse.ArgumentParser()
    for name in ('method', 'datatrack', 'ssl_type'):
        parser.add_argument(name)
    parser.add_argument('i_cv', type=int)
    parser.add_argument('--use_opt', action='store_true', default=False)
    return parser.parse_args()
30 |
31 |
32 |
33 |
def main():
    """Train one stage-1 model on a single CV fold, save predictions and model weights.

    With --use_opt, previously tuned hyper-parameters are read from the
    opt_hp_stage1 output directory (only for the sklearn-style methods).
    """
    args = get_arg()

    # Fix every RNG for reproducibility.
    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)

    data = load_stage1_data(
        datatrack=args.datatrack,
        ssl_type=args.ssl_type,
        i_cv=args.i_cv,
    )

    if args.method == 'svgp':
        model = SVGP()
    elif args.method == 'exactgp':
        model = ExactGP()
    elif args.method == 'rf':
        model = RandomForest()
    else:
        if args.use_opt:
            param_file = Path('../out/ensemble-multidomain/opt_hp_stage1') / args.datatrack / \
                f'{args.method}-{args.ssl_type}' / 'params.json'
            # Read the tuned params as text inside a context manager; the
            # previous json.load(open(param_file, 'rb')) leaked the handle.
            with open(param_file, encoding='utf-8') as f:
                params = json.load(f)
            logger.info('Params: {}'.format(params))
        else:
            params = {}

        if args.method == 'ridge':
            model = Ridge(params=params)
        elif args.method == 'linear_svr':
            model = LinearSVR(params=params)
        elif args.method == 'kernel_svr':
            model = KernelSVR(params=params)
        elif args.method == 'lightgbm':
            model = LightGBM(params=params)
        else:
            raise RuntimeError('Not supported method: "{}"'.format(args.method))

    model.train(data['train']['X'], data['train']['y'],
                data['val']['X'], data['val']['y'])

    df_val = model.predict(data['val']['X'], data['val']['df'])
    df_test = model.predict(data['test']['X'], data['test']['df'])

    out_dir = Path('../out/ensemble-multidomain/stage1') / args.datatrack / \
        f'{args.method}-{args.ssl_type}' / str(args.i_cv)
    os.makedirs(out_dir, exist_ok=True)

    df_val.to_csv(out_dir / 'val.csv')
    df_test.to_csv(out_dir / 'test.csv')

    model.save_model(out_dir)
87 |
88 |
89 |
90 |
if __name__ == '__main__':
    # INFO-level logging with brace-style formatting: "<logger name>: <message>".
    logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{')
    main()
94 |
95 |
96 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/run_stage1.sh:
--------------------------------------------------------------------------------
1 |
set -eu

# SSL feature types (one stage-1 weak learner is trained per type).
ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large"

# Train every stage-1 model: 3 datatracks x 6 methods x 8 SSL types x 5 CV folds.
for datatrack in phase1-ood phase1-main external-wo_test; do
    for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do
        for ssl_type in ${ssl_types}; do
            for i_cv in 0 1 2 3 4; do
                echo "${datatrack}, ${method}, ${ssl_type}, ${i_cv}"
                python -u run_stage1.py ${method} ${datatrack} ${ssl_type} ${i_cv}
            done
        done
    done
done

echo "done"
18 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/run_stage2-3_main.sh:
--------------------------------------------------------------------------------
1 |
set -eu

# Stage 2-3 training pipeline for the main track: collect stage-1
# predictions, train stage-2 models, collect, train the stage-3 aggregator,
# then score the ensemble.
datatrack=phase1-main
feat_type=main-strong1-weak48

python -u collect_stage1_result.py ${datatrack} ${feat_type}

for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do
    for i_cv in 0 1 2 3 4; do
        echo "Run stage2: ${method}, ${datatrack}, ${feat_type}, ${i_cv}"
        python -u run_stage2.py ${method} ${datatrack} ${feat_type} ${i_cv}
    done
done

echo "Collect stage2 data."
python -u collect_stage2_result.py ${datatrack} ${feat_type}

# Stage 3 uses ridge only.
for method in ridge; do
    for i_cv in 0 1 2 3 4; do
        echo "Run stage3: ${method} ${datatrack}, ${feat_type}, ${i_cv}"
        python -u run_stage3.py ${method} ${datatrack} ${feat_type} ${i_cv}
    done
done

echo "Calculate result."

python -u calc_result.py ${datatrack} ${feat_type}

echo "done"
31 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/run_stage2-3_ood.sh:
--------------------------------------------------------------------------------
1 |
set -eu

# Stage 2-3 training pipeline for the OOD track: collect stage-1
# predictions, train stage-2 models, collect, train the stage-3 aggregator,
# then score the ensemble.
datatrack=phase1-ood
feat_type=ood-strong1-weak144

python -u collect_stage1_result.py ${datatrack} ${feat_type}

for method in ridge linear_svr kernel_svr rf lightgbm exactgp; do
    for i_cv in 0 1 2 3 4; do
        echo "Run stage2: ${method}, ${datatrack}, ${feat_type}, ${i_cv}"
        python -u run_stage2.py ${method} ${datatrack} ${feat_type} ${i_cv}
    done
done

echo "Collect stage2 data."
python -u collect_stage2_result.py ${datatrack} ${feat_type}

# Stage 3 uses ridge only.
for method in ridge; do
    for i_cv in 0 1 2 3 4; do
        echo "Run stage3: ${method} ${datatrack}, ${feat_type}, ${i_cv}"
        python -u run_stage3.py ${method} ${datatrack} ${feat_type} ${i_cv}
    done
done

echo "Calculate result."

python -u calc_result.py ${datatrack} ${feat_type}

echo "done"
31 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/run_stage2.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import logging
5 | from logging import getLogger
6 | import random
7 | import json
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from data_util import load_stage2_data
13 |
14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest
15 | from gp_models import SVGP, ExactGP
16 |
17 | logger = getLogger(__name__)
18 |
19 | RAND_SEED = 0
20 |
def get_arg():
    """Parse CLI: method, datatrack, feat_type, i_cv (int), optional --use_opt flag."""
    import argparse

    parser = argparse.ArgumentParser()
    for name in ('method', 'datatrack', 'feat_type'):
        parser.add_argument(name)
    parser.add_argument('i_cv', type=int)
    parser.add_argument('--use_opt', action='store_true', default=False)
    return parser.parse_args()
30 |
31 |
32 |
33 |
def main():
    """Train one stage-2 ensemble model on a single CV fold, save predictions and weights.

    'autogp' is an alias resolved per datatrack: SVGP for phase1-main,
    ExactGP for everything else. Outputs are keyed by the resolved name.
    """
    args = get_arg()

    # Fix every RNG for reproducibility.
    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)

    data = load_stage2_data(
        datatrack=args.datatrack,
        feat_type=args.feat_type,
        i_cv=args.i_cv,
    )

    method = args.method

    # Resolve the 'autogp' alias before dispatching.
    if method == 'autogp':
        if args.datatrack == 'phase1-main':
            method = 'svgp'
        else:
            method = 'exactgp'

    if method == 'svgp':
        model = SVGP(stage='stage2')
    elif method == 'exactgp':
        model = ExactGP(stage='stage2')
    elif method == 'rf':
        model = RandomForest()
    else:
        if args.use_opt:
            param_file = Path('../out/ensemble-multidomain/opt_hp_stage2') / args.datatrack / \
                f'{method}-{args.feat_type}' / 'params.json'
            # Read the tuned params as text inside a context manager; the
            # previous json.load(open(param_file, 'rb')) leaked the handle.
            with open(param_file, encoding='utf-8') as f:
                params = json.load(f)
            logger.info('Params: {}'.format(params))
        else:
            params = {}

        if method == 'ridge':
            model = Ridge(params=params)
        elif method == 'linear_svr':
            model = LinearSVR(params=params)
        elif method == 'kernel_svr':
            model = KernelSVR(params=params)
        elif method == 'lightgbm':
            model = LightGBM(params=params)
        else:
            raise RuntimeError('Not supported method: "{}"'.format(method))

    model.train(data['train']['X'], data['train']['y'],
                data['val']['X'], data['val']['y'])

    df_val = model.predict(data['val']['X'], data['val']['df'])
    df_test = model.predict(data['test']['X'], data['test']['df'])

    out_dir = Path('../out/ensemble-multidomain/stage2') / args.datatrack / \
        f'{method}-{args.feat_type}' / str(args.i_cv)
    os.makedirs(out_dir, exist_ok=True)

    df_val.to_csv(out_dir / 'val.csv')
    df_test.to_csv(out_dir / 'test.csv')

    model.save_model(out_dir)
95 |
96 |
97 |
98 |
if __name__ == '__main__':
    # INFO-level logging with brace-style formatting: "<logger name>: <message>".
    logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{')
    main()
102 |
103 |
104 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/run_stage3.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from pathlib import Path
4 | import logging
5 | from logging import getLogger
6 | import random
7 | import json
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from data_util import load_stage3_data
13 |
14 | from models import Ridge, LinearSVR, KernelSVR, LightGBM, RandomForest
15 | from gp_models import SVGP, ExactGP
16 |
17 | logger = getLogger(__name__)
18 |
19 | RAND_SEED = 0
20 |
def get_arg():
    """Parse CLI: method, datatrack, feat_type, i_cv (int), optional --use_opt flag."""
    import argparse

    parser = argparse.ArgumentParser()
    for name in ('method', 'datatrack', 'feat_type'):
        parser.add_argument(name)
    parser.add_argument('i_cv', type=int)
    parser.add_argument('--use_opt', action='store_true', default=False)
    return parser.parse_args()
30 |
31 |
32 |
33 |
def main():
    """Train one stage-3 aggregator model on a single CV fold, save predictions and weights.

    NOTE(review): the GP models are constructed with stage='stage2' here as
    well — presumably the stage-2 configuration is reused; confirm intended.
    """
    args = get_arg()

    # Fix every RNG for reproducibility.
    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)

    data = load_stage3_data(
        datatrack=args.datatrack,
        feat_type=args.feat_type,
        i_cv=args.i_cv,
    )

    if args.method == 'svgp':
        model = SVGP(stage='stage2')
    elif args.method == 'exactgp':
        model = ExactGP(stage='stage2')
    elif args.method == 'rf':
        model = RandomForest()
    else:
        if args.use_opt:
            param_file = Path('../out/ensemble-multidomain/opt_hp_stage3') / args.datatrack / \
                f'{args.method}-{args.feat_type}' / 'params.json'
            # Read the tuned params as text inside a context manager; the
            # previous json.load(open(param_file, 'rb')) leaked the handle.
            with open(param_file, encoding='utf-8') as f:
                params = json.load(f)
            logger.info('Params: {}'.format(params))
        else:
            params = {}

        if args.method == 'ridge':
            model = Ridge(params=params)
        elif args.method == 'linear_svr':
            model = LinearSVR(params=params)
        elif args.method == 'kernel_svr':
            model = KernelSVR(params=params)
        elif args.method == 'lightgbm':
            model = LightGBM(params=params)
        else:
            raise RuntimeError('Not supported method: "{}"'.format(args.method))

    model.train(data['train']['X'], data['train']['y'],
                data['val']['X'], data['val']['y'])

    df_val = model.predict(data['val']['X'], data['val']['df'])
    df_test = model.predict(data['test']['X'], data['test']['df'])

    out_dir = Path('../out/ensemble-multidomain/stage3') / args.datatrack / \
        f'{args.method}-{args.feat_type}' / str(args.i_cv)
    os.makedirs(out_dir, exist_ok=True)

    df_val.to_csv(out_dir / 'val.csv')
    df_test.to_csv(out_dir / 'test.csv')

    model.save_model(out_dir)
87 |
88 |
89 |
90 |
if __name__ == '__main__':
    # INFO-level logging with brace-style formatting: "<logger name>: <message>".
    logging.basicConfig(level=logging.INFO, format='{name}: {message}', style='{')
    main()
94 |
95 |
96 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/stage2-method/main-strong1-weak48.yaml:
--------------------------------------------------------------------------------
1 |
# Stage-2 feature definition for the main track: 1 strong learner plus
# 48 weak learners (1 datatrack x 6 model types x 8 SSL types).
strong_learners:
  - main1
weak_learners:
  datatracks:
    - phase1-main
  model_types:
    - ridge
    - linear_svr
    - kernel_svr
    - lightgbm
    - rf
    - exactgp
  ssl_types:
    - w2v_small
    - w2v_large
    - w2v_large2
    - w2v_xlsr
    - hubert_base
    - hubert_large
    - wavlm_base
    - wavlm_large
23 |
24 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/stage2-method/ood-strong1-weak144.yaml:
--------------------------------------------------------------------------------
1 |
# Stage-2 feature definition for the OOD track: 1 strong learner plus
# 144 weak learners (3 datatracks x 6 model types x 8 SSL types).
strong_learners:
  - ood1
weak_learners:
  datatracks:
    - phase1-main
    - external-wo_test
    - phase1-ood
  model_types:
    - ridge
    - linear_svr
    - kernel_svr
    - lightgbm
    - rf
    - exactgp
  ssl_types:
    - w2v_small
    - w2v_large
    - w2v_large2
    - w2v_xlsr
    - hubert_base
    - hubert_large
    - wavlm_base
    - wavlm_large
25 |
26 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/unused/opt_stage1_all.sh:
--------------------------------------------------------------------------------
1 |
# Optimise stage-1 weak-learner hyperparameters for every
# method x datatrack x SSL-feature combination via opt_stage1.py.
# Lives under unused/.

# Abort on any command failure or use of an unset variable.
set -eu

ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large"

for method in ridge lightgbm linear_svr kernel_svr; do
    for datatrack in phase1-ood phase1-main students-wo_test ; do
        for ssl_type in ${ssl_types}; do
            # Progress marker for the current combination.
            echo "${datatrack}, ${ssl_type}, ${method}"
            poetry run python -u opt_stage1.py ${method} ${datatrack} ${ssl_type}
        done
    done
done

echo "done"
16 |
17 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/unused/pred_stage1_exactgp.sh:
--------------------------------------------------------------------------------
1 |
# Run stage-1 ExactGP prediction for every combination of training
# datatrack, prediction datatrack, SSL feature type, and CV fold
# via pred_stage1.py. Lives under unused/.

# Abort on any command failure or use of an unset variable.
set -eu

for train_datatrack in phase1-ood students-wo_test phase1-main; do
    for pred_datatrack in phase1-ood phase1-main ; do
        for ssl_type in w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large; do
            for method in exactgp; do
                for i_cv in 0 1 2 3 4; do
                    # Progress marker for the current combination.
                    echo "${method}, ${train_datatrack}, ${ssl_type}, ${i_cv}, ${pred_datatrack}"
                    poetry run python -u pred_stage1.py ${method} ${train_datatrack} ${ssl_type} ${i_cv} ${pred_datatrack}
                done
            done
        done
    done
done

echo "done"
18 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/unused/run_stage1_exactgp.sh:
--------------------------------------------------------------------------------
1 |
# Train stage-1 ExactGP weak learners for every datatrack, SSL feature
# type, and CV fold via run_stage1.py. Lives under unused/.

# Abort on any command failure or use of an unset variable.
set -eu

ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large"

for datatrack in phase1-ood phase1-main students-wo_test; do
    for ssl_type in ${ssl_types}; do
        for i_cv in 0 1 2 3 4; do
            # Progress marker for the current combination.
            echo "${datatrack}, ${ssl_type}, ${i_cv}"
            poetry run python -u run_stage1.py exactgp ${datatrack} ${ssl_type} ${i_cv}
        done
    done
done




echo "done"
19 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/unused/run_stage1_other.sh:
--------------------------------------------------------------------------------
1 |
# Train the non-GP stage-1 weak learners (ridge, SVR variants, random
# forest, LightGBM) for every datatrack, SSL feature type, and CV fold
# via run_stage1.py. Lives under unused/.

# Abort on any command failure or use of an unset variable.
set -eu

ssl_types="w2v_small w2v_large w2v_large2 w2v_xlsr wavlm_base wavlm_large hubert_base hubert_large"

for datatrack in phase1-ood phase1-main students-wo_test ; do
    for method in ridge linear_svr kernel_svr rf lightgbm; do
        for ssl_type in ${ssl_types} ; do
            for i_cv in 0 1 2 3 4; do
                # Progress marker for the current combination.
                echo "${method}, ${datatrack}, ${ssl_type}, ${i_cv}"
                poetry run python -u run_stage1.py ${method} ${datatrack} ${ssl_type} ${i_cv}
            done
        done
    done
done



echo "done"
20 |
--------------------------------------------------------------------------------
/stacking/ensemble_multidomain_scripts/unused/run_stage2-end_opt.sh:
--------------------------------------------------------------------------------
1 |
# End-to-end stage-2 -> stage-3 stacking pipeline for the main track,
# using optimised hyperparameters (--use_opt): collect stage-1
# predictions, optimise and train the stage-2 learners per CV fold,
# collect their outputs, optimise and train the stage-3 ridge blender,
# then compute the final metrics. Lives under unused/.

# Abort on any command failure or use of an unset variable.
set -eu

datatrack=phase1-main
feat_type=main-strong1-weak48-opt

# Gather stage-1 predictions into the stage-2 feature set.
poetry run python -u collect_stage1_result.py ${datatrack} ${feat_type}

# Hyperparameter search for the stage-2 learners.
# NOTE(review): rf and exactgp are trained below with --use_opt but are
# not optimised here — presumably they fall back to defaults; confirm
# against run_stage2.py.
for method in ridge linear_svr kernel_svr lightgbm; do
    echo "Optimize hyperparameter for stage2: ${datatrack}, ${feat_type}, ${method}"
    poetry run python -u opt_stage2.py ${method} ${datatrack} ${feat_type}
done


# Train the stage-2 models for each of the 5 CV folds.
for method in exactgp ridge linear_svr kernel_svr rf lightgbm; do
    for i_cv in 0 1 2 3 4; do
        echo "Run stage2: ${method}, ${datatrack}, ${feat_type}, ${i_cv}"
        poetry run python -u run_stage2.py --use_opt ${method} ${datatrack} ${feat_type} ${i_cv}
    done
done

echo "Collect stage2 data."
poetry run python -u collect_stage2_result.py ${datatrack} ${feat_type}


# Stage 3: optimise and train the final ridge stacking model.
for method in ridge; do
    echo "Optimize hyperparameter for stage3: ${datatrack}, ${feat_type}, ${method}"
    poetry run python -u opt_stage3.py ${method} ${datatrack} ${feat_type}
done

for method in ridge; do
    for i_cv in 0 1 2 3 4; do
        echo "Run stage3: ${method} ${datatrack}, ${feat_type}, ${i_cv}"
        poetry run python -u run_stage3.py --use_opt ${method} ${datatrack} ${feat_type} ${i_cv}
    done
done


echo "Calculate result."

# Evaluate the stacked predictions.
poetry run python -u calc_result.py ${datatrack} ${feat_type}

echo "done"
44 |
--------------------------------------------------------------------------------
/stacking/strong_learner_result/ood1/answer-ood.csvfold_0:
--------------------------------------------------------------------------------
1 | sys3ac0c-utt161761a,3.00628410698846
2 | sys1b29a-uttc13c8ac,3.2306305915117264
3 | sys6d946-utt3fa137b,4.020784974098206
4 | sysc5b22-utte3cfb18,4.155950903892517
5 | sys50620-utt9080d4c,3.1653470546007156
6 | sysaefdf-utt8625c9a,4.133486866950989
7 | sys2dea8-utt0fb1b5f,3.668229043483734
8 | sys6aa80-uttb961ed0,3.049679197371006
9 | sys6aa80-utt11c30b8,3.0484387166798115
10 | sys2dea8-uttaca68d6,3.56439208984375
11 | sys86a3b-uttd7ee46d,3.4484071135520935
12 | sys86a3b-uttd408860,3.4142578840255737
13 | sys8569c-utt339d107,4.2832270860672
14 | sys3ac0c-utt3727514,2.874504864215851
15 | sys6d946-utt29b5124,3.8212228417396545
16 | sysc5b22-uttf413f6c,4.127067804336548
17 | sys9900c-utt380e04a,2.799784615635872
18 | sysc138b-utt4aa722c,3.3437674045562744
19 | sysc0247-uttcf421d3,4.266936302185059
20 | sysaefdf-utt4627f92,3.9700832962989807
21 | sys6d946-utt74b6d95,4.13299822807312
22 | sys288fc-utt8aac2a0,1.708821415901184
23 | sysc0232-utt984d929,4.142822504043579
24 | sys3ac0c-uttc822565,2.9671433679759502
25 | sysc0247-utt324aad1,4.186380982398987
26 | sys6aa80-utt944f00a,3.2828843891620636
27 | sys1b29a-utt31813b4,3.1065728440880775
28 | sys288fc-utt3ada6c5,1.6936789751052856
29 | sysc5b22-utt2ef4cc2,4.162316918373108
30 | sys25583-utt6330ecc,3.866409659385681
31 | sys0c3c7-uttc349635,1.7121723890304565
32 | sys6aa80-uttbee99e7,3.272907167673111
33 | sysc138b-utt3e6bb57,3.114454969763756
34 | sysc8b6d-utt2174cf3,3.2371485382318497
35 | sys3ac0c-utt9d9e772,3.394138514995575
36 | sysc0247-utt7bf8152,4.188221216201782
37 | sysc5b22-uttd0d4f7d,3.80684095621109
38 | sysc138b-utte0f2b88,3.3935784697532654
39 | sys50620-uttdd8c402,3.1611065566539764
40 | sys9900c-utt8a28b36,3.014586901292205
41 | syseb559-utt71685f4,2.927502065896988
42 | sys25583-uttdd97db1,3.775578737258911
43 | sysc93c1-utt9d574ec,1.705985426902771
44 | sysc8b6d-uttc01a331,3.4588173627853394
45 | sys8569c-utt4bd6700,4.181773781776428
46 | sys6d946-uttdc8ae80,3.7061699628829956
47 |
--------------------------------------------------------------------------------
/stacking/strong_learner_result/ood1/answer-ood.csvfold_1:
--------------------------------------------------------------------------------
1 | sys288fc-utt2a428da,1.590195894241333
2 | sys8569c-uttda01d83,4.474315762519836
3 | sys3bec5-uttf9e9b65,4.395640730857849
4 | sysc5b22-utt507ff1e,4.440494537353516
5 | sys6d946-uttfc65d31,3.9793806672096252
6 | sys0c3c7-uttcde36fc,1.5656745433807373
7 | sysc138b-uttfaae8c5,3.9728047251701355
8 | sysaefdf-utta22cc47,4.128987550735474
9 | sys1b29a-utt238cb31,3.4775221943855286
10 | sysaefdf-uttac28750,4.094674468040466
11 | sys86a3b-utt5f51da3,3.7170485854148865
12 | sysc5b22-utt84b31de,4.41584038734436
13 | sysc0232-utt625ccdc,4.285597205162048
14 | sys86a3b-utt86c4d97,3.384651690721512
15 | sysc93c1-utt4a9ee10,1.7306538820266724
16 | sysc5b22-utt05b5a77,4.438835382461548
17 | syseb559-uttcd7fb9e,2.8986603021621704
18 | sysc0247-utt9720993,4.307410955429077
19 | sysc8b6d-utt3118c52,3.9715484380722046
20 | sysc8b6d-utt9bf948a,3.35931795835495
21 | sys25583-utt6518d56,4.095972418785095
22 | sysc8b6d-utt9ccaa2f,4.0339906215667725
23 | sysc0247-utt021f653,4.136787533760071
24 | sys50620-utt712ce0b,2.371592879295349
25 | sysc0247-utt457d98b,4.3827223777771
26 | sys0c3c7-uttb54c043,1.7769298553466797
27 | sysc0232-utt74896c3,4.17369544506073
28 | sysc93c1-utt4a606f1,1.8465670347213745
29 | sys50620-utt8f042c0,2.980513686314225
30 | sys1b29a-utt014d122,3.5027151703834534
31 | sys288fc-utt81bc852,1.5887526273727417
32 | sys3bec5-utte0c7cdf,4.340376496315002
33 | sys1b29a-utt1770396,3.6666112542152405
34 | sys6d946-utt3f304d4,4.209040522575378
35 | syseb559-utt4dd550d,2.7324198484420776
36 | sysc138b-uttcad56b6,3.19346222281456
37 | sys0c3c7-utt330976a,1.609305739402771
38 | sys3bec5-utt3301016,4.412176251411438
39 | sys3ac0c-utt5e1312b,3.5013919472694397
40 | sys6d946-utt6c1c6e1,4.254178404808044
41 | sys9900c-utt1f822b9,2.952831447124481
42 | sys3bec5-utt5ed2652,4.356784820556641
43 | sys2dea8-utt30d9387,3.3734480142593384
44 | sys86a3b-uttecd81d0,3.312489926815033
45 | sys3ac0c-uttafe7bb1,2.9006077349185944
46 |
--------------------------------------------------------------------------------
/stacking/strong_learner_result/ood1/answer-ood.csvfold_2:
--------------------------------------------------------------------------------
1 | sysc0247-utt7596bba,3.916542708873749
2 | sys6d946-utt85e8181,3.309552401304245
3 | sysc5b22-utta450ecd,4.0385085344314575
4 | sysc8b6d-uttefe40e7,2.875736527144909
5 | sys3ac0c-uttc3491bb,3.1840667724609375
6 | sysc8b6d-uttbe29207,3.261176884174347
7 | sys288fc-utt116506b,1.8067857027053833
8 | syseb559-uttaa26ddd,3.0563898496329784
9 | sys288fc-uttcd33316,1.7399002313613892
10 | sys3ac0c-utt8152810,3.080054387450218
11 | sys6aa80-utta989ec2,3.151083752512932
12 | syseb559-uttfb60593,3.3584741950035095
13 | sysc0247-uttfad83f9,3.9709458351135254
14 | sysaefdf-utt1626164,3.972193658351898
15 | sys8569c-uttdf23051,4.064173221588135
16 | syseb559-uttfee71ad,3.322394996881485
17 | sys9900c-utt8e27993,3.5867892503738403
18 | sys86a3b-utt9594edb,3.778729021549225
19 | sys9900c-uttbabd33b,3.0449920147657394
20 | sys288fc-uttd2e0fbb,1.7499089241027832
21 | syseb559-utt80e03e6,3.3057122230529785
22 | sys86a3b-uttd537f38,3.1441497206687927
23 | sys9900c-utt7692d21,2.997207031119615
24 | sys2dea8-uttfee0d1c,3.522417902946472
25 | sys3bec5-uttb78d8af,4.011132478713989
26 | sys9900c-uttee29c4f,3.3247605562210083
27 | sysc93c1-utt6ffd131,1.8141653537750244
28 | sys6aa80-utt1458fca,3.6074836254119873
29 | sysc93c1-utte26df86,2.301818013191223
30 | sys3bec5-uttf0150e9,4.055609583854675
31 | sysc138b-utt5c64cce,3.0564831867814064
32 | sys86a3b-utt34ebcb5,3.5041947960853577
33 | sys6d946-utt97ad31f,3.905802607536316
34 | sys25583-utt85a7380,3.869354546070099
35 | sysc93c1-utt76f14cc,1.7454818487167358
36 | sysc0232-utt3f80869,4.01626980304718
37 | sys86a3b-uttd076e83,3.429683208465576
38 | sysaefdf-uttf1939ce,3.731216847896576
39 | sys3bec5-utt602db43,4.103129863739014
40 | sys288fc-uttf1d17fa,1.746074914932251
41 | sys25583-utt1c24cff,4.08445405960083
42 | sysc138b-utt67b77b4,3.580138087272644
43 | sys86a3b-utt81bb889,3.896349310874939
44 | sysc8b6d-utt5feb6c6,3.8013290762901306
45 | sys6aa80-utt7c44bb4,3.3627112805843353
46 |
--------------------------------------------------------------------------------
/stacking/strong_learner_result/ood1/answer-ood.csvval_0:
--------------------------------------------------------------------------------
1 | sys0c3c7-utt29396c4,1.8738696575164795
2 | sys0c3c7-utt32149b2,1.7467237710952759
3 | sys0c3c7-utt3b43620,1.7445292472839355
4 | sys0c3c7-utte2ff6b2,1.7807897329330444
5 | sys1b29a-uttb5c2dfd,2.840162009000778
6 | sys1b29a-uttbe1f861,3.097307428717613
7 | sys1b29a-uttf1ba227,3.3140459656715393
8 | sys25583-utt32a3068,4.226199269294739
9 | sys25583-utt640ccda,4.173765182495117
10 | sys25583-uttb584e2d,4.127885103225708
11 | sys25583-utte43fff8,3.7679272294044495
12 | sys288fc-utt0b9b8c4,1.6744518280029297
13 | sys288fc-utt178b639,1.6881271600723267
14 | sys288fc-utt27c7f8e,1.6820378303527832
15 | sys288fc-utt4ca70be,1.6948258876800537
16 | sys288fc-utt737b11a,1.7346054315567017
17 | sys288fc-utt875c9bc,1.6459565162658691
18 | sys288fc-uttc86d188,1.7195290327072144
19 | sys2dea8-utt5d0ff00,3.190113589167595
20 | sys2dea8-uttc0684c5,3.69060617685318
21 | sys3ac0c-utt380b246,2.8424201756715775
22 | sys3ac0c-utt6e8c4c8,3.246820494532585
23 | sys3bec5-utt08b2712,4.067929744720459
24 | sys3bec5-utt256038c,4.0526769161224365
25 | sys3bec5-utt57f6999,3.77788507938385
26 | sys3bec5-utt6fa8766,4.025399327278137
27 | sys3bec5-uttf96e3f6,3.908892512321472
28 | sys40775-utt053fb06,2.8519225120544434
29 | sys40775-utt06280d0,3.298975795507431
30 | sys40775-utt21f7eb2,3.1713304817676544
31 | sys40775-utt25debbd,3.4360039830207825
32 | sys40775-utt262d917,3.363330453634262
33 | sys40775-utt2c689f4,3.5800936818122864
34 | sys40775-utt31d8803,3.100691355764866
35 | sys40775-utt3a6022b,3.7426048517227173
36 | sys40775-utt3c9d3a9,3.3877606987953186
37 | sys40775-utt3caa162,3.276692509651184
38 | sys40775-utt40cc32a,3.606023669242859
39 | sys40775-utt4830112,3.020839897915721
40 | sys40775-utt4beb3cb,3.937600076198578
41 | sys40775-utt52bccfe,3.4492527544498444
42 | sys40775-utt59ca17a,2.6853365898132324
43 | sys40775-utt63aa530,3.3543978333473206
44 | sys40775-utt6bccf3e,3.305658459663391
45 | sys40775-utt6cc4d95,3.031956598162651
46 | sys40775-utt734346c,3.5151914954185486
47 | sys40775-utt757daca,3.0054497895762324
48 | sys40775-utt7ab4c4c,3.481399267911911
49 | sys40775-utt7f58b52,3.616021752357483
50 | sys40775-utt8c70e0c,3.2666018903255463
51 | sys40775-utt9043da6,3.1890525817871094
52 | sys40775-utt95d9b13,3.7385385036468506
53 | sys40775-utt979fa4d,3.2476234287023544
54 | sys40775-utt98d2608,3.5850866436958313
55 | sys40775-uttacfa755,3.217980235815048
56 | sys40775-uttbffd493,3.032057635486126
57 | sys40775-uttca8a7fd,3.4135411977767944
58 | sys40775-uttcd51f1b,3.4589690566062927
59 | sys40775-uttd28fea3,2.8292620480060577
60 | sys40775-uttd4f5d6e,3.164977937936783
61 | sys40775-uttd803e0a,3.1480540484189987
62 | sys40775-uttddf3183,3.511105179786682
63 | sys40775-uttde965d9,3.2861602008342743
64 | sys40775-uttdeee18d,3.0033161328174174
65 | sys40775-uttdfaf42e,2.863734170794487
66 | sys40775-utte1baab8,3.1256515234708786
67 | sys40775-utte451004,3.608478367328644
68 | sys40775-utted1807d,3.1984187215566635
69 | sys40775-uttee1039e,3.0977888256311417
70 | sys40775-uttf0806be,3.4857234954833984
71 | sys40775-uttf1ca905,3.2617311477661133
72 | sys40775-uttf46b095,3.5599228739738464
73 | sys40775-uttfe39bc5,3.5992850065231323
74 | sys50620-utt36a51d2,3.6462497115135193
75 | sys50620-uttb648ed6,3.2346756011247635
76 | sys50620-uttdc9bc1b,3.1156226620078087
77 | sys50620-utte891e60,2.767196550965309
78 | sys50620-uttf932cea,2.980430446565151
79 | sys5a87a-utt3153e35,3.0460551269352436
80 | sys5a87a-uttd0f0a73,2.9686387814581394
81 | sys6aa80-utt41c16b6,3.295067608356476
82 | sys6aa80-utta52cb79,3.3978131711483
83 | sys6aa80-uttb7c3c36,3.3251181840896606
84 | sys6aa80-uttc4aa100,3.7246773838996887
85 | sys6aa80-utte8f1cbc,3.2285984605550766
86 | sys6d946-utt067a45d,3.814286947250366
87 | sys6d946-utt454ee70,4.088953971862793
88 | sys6d946-uttb8f803b,3.907171308994293
89 | sys6d946-uttc5d8752,4.120617151260376
90 | sys7d2dc-utt04fca18,4.019955635070801
91 | sys7d2dc-utt5b21c48,3.657771944999695
92 | sys8569c-utt53191c9,4.233504056930542
93 | sys8569c-utt75e2e48,4.241759300231934
94 | sys8569c-utt89043c6,4.2122883796691895
95 | sys86a3b-utt0dfad3e,3.167993053793907
96 | sys86a3b-utt6292b56,3.2046364545822144
97 | sys9900c-utt310ee00,3.095589391887188
98 | sys9900c-utt95df11f,3.000581094471272
99 | sys9900c-utt99ec0d0,2.683213710784912
100 | sys9900c-uttb72803f,3.0643272027373314
101 | sys9900c-utte2b3cdc,3.2304811626672745
102 | sys9900c-utte7af0d7,2.5658408403396606
103 | sysaefdf-utt0c7d358,3.761621117591858
104 | sysaefdf-utt33e9539,4.158675312995911
105 | sysaefdf-uttc8398cf,4.072099804878235
106 | sysaefdf-uttd832fe6,3.330040603876114
107 | sysc0232-utt0d5592e,4.234232664108276
108 | sysc0232-utt1136939,4.041603207588196
109 | sysc0232-utt321b7f3,3.6770675778388977
110 | sysc0232-utt43516ea,4.184963822364807
111 | sysc0232-utta98dc3e,4.0392502546310425
112 | sysc0232-uttcc93fbe,3.8121009469032288
113 | sysc0247-utt0c53cda,4.085257172584534
114 | sysc0247-utt71242b9,4.184931635856628
115 | sysc0247-utte133661,4.232062220573425
116 | sysc0247-uttfe8f1b6,4.241726398468018
117 | sysc138b-utt41cd1c3,2.7210413813591003
118 | sysc138b-utt52345ed,2.9520521461963654
119 | sysc138b-utt5ab7a39,3.299651026725769
120 | sysc138b-utt9d47077,3.0674551352858543
121 | sysc138b-uttc0e1eb7,3.020436340942979
122 | sysc5b22-utt2d3463f,4.21875536441803
123 | sysc5b22-utt3e37da2,4.168697476387024
124 | sysc5b22-utted9388f,4.244592905044556
125 | sysc5b22-uttf51fc7c,4.172379493713379
126 | sysc8b6d-utt445f8c7,3.9167977571487427
127 | sysc8b6d-utt6ab2d26,3.6579365730285645
128 | sysc8b6d-utt7eeebd8,3.894798457622528
129 | sysc8b6d-uttfe63322,3.786387324333191
130 | sysc93c1-utt249fc00,1.7767733335494995
131 | sysc93c1-utt577a573,1.7516868114471436
132 | syseb559-utt1464f1a,2.5668071806430817
133 | syseb559-utt274445f,2.5056876838207245
134 | syseb559-utt5337b78,3.019061068072915
135 | syseb559-utt7b8a073,2.846351996064186
136 | syseb559-uttdb7f27a,2.7430557310581207
137 |
--------------------------------------------------------------------------------
/stacking/strong_learner_result/ood1/answer-ood.csvval_1:
--------------------------------------------------------------------------------
1 | sys0c3c7-utt29396c4,1.6569277048110962
2 | sys0c3c7-utt32149b2,1.5723577737808228
3 | sys0c3c7-utt3b43620,1.5654006004333496
4 | sys0c3c7-utte2ff6b2,1.604346752166748
5 | sys1b29a-uttb5c2dfd,3.4084918797016144
6 | sys1b29a-uttbe1f861,3.729147791862488
7 | sys1b29a-uttf1ba227,3.6519864201545715
8 | sys25583-utt32a3068,4.376435160636902
9 | sys25583-utt640ccda,4.414887428283691
10 | sys25583-uttb584e2d,4.447765827178955
11 | sys25583-utte43fff8,4.259235620498657
12 | sys288fc-utt0b9b8c4,1.5869495868682861
13 | sys288fc-utt178b639,1.5875202417373657
14 | sys288fc-utt27c7f8e,1.5941483974456787
15 | sys288fc-utt4ca70be,1.5977908372879028
16 | sys288fc-utt737b11a,1.8113304376602173
17 | sys288fc-utt875c9bc,1.5783995389938354
18 | sys288fc-uttc86d188,1.6000697612762451
19 | sys2dea8-utt5d0ff00,3.2767574787139893
20 | sys2dea8-uttc0684c5,3.652103900909424
21 | sys3ac0c-utt380b246,3.1722365468740463
22 | sys3ac0c-utt6e8c4c8,3.511052906513214
23 | sys3bec5-utt08b2712,4.320664405822754
24 | sys3bec5-utt256038c,4.418888330459595
25 | sys3bec5-utt57f6999,4.28934383392334
26 | sys3bec5-utt6fa8766,4.370813608169556
27 | sys3bec5-uttf96e3f6,4.3447675704956055
28 | sys40775-utt053fb06,3.485974669456482
29 | sys40775-utt06280d0,3.6210970282554626
30 | sys40775-utt21f7eb2,3.370717942714691
31 | sys40775-utt25debbd,3.524880826473236
32 | sys40775-utt262d917,3.469281166791916
33 | sys40775-utt2c689f4,3.5347546339035034
34 | sys40775-utt31d8803,3.364650994539261
35 | sys40775-utt3a6022b,4.066094398498535
36 | sys40775-utt3c9d3a9,3.7404050827026367
37 | sys40775-utt3caa162,3.8969979882240295
38 | sys40775-utt40cc32a,3.77112877368927
39 | sys40775-utt4830112,3.718696892261505
40 | sys40775-utt4beb3cb,3.9126113653182983
41 | sys40775-utt52bccfe,3.6475443243980408
42 | sys40775-utt59ca17a,2.9844800736755133
43 | sys40775-utt63aa530,3.4839371144771576
44 | sys40775-utt6bccf3e,3.4588995575904846
45 | sys40775-utt6cc4d95,3.3658704459667206
46 | sys40775-utt734346c,3.536608338356018
47 | sys40775-utt757daca,3.281584292650223
48 | sys40775-utt7ab4c4c,3.6340607404708862
49 | sys40775-utt7f58b52,3.5484229922294617
50 | sys40775-utt8c70e0c,3.6158772706985474
51 | sys40775-utt9043da6,3.355036199092865
52 | sys40775-utt95d9b13,3.745011806488037
53 | sys40775-utt979fa4d,3.4204566180706024
54 | sys40775-utt98d2608,3.7399336099624634
55 | sys40775-uttacfa755,2.953236449509859
56 | sys40775-uttbffd493,3.28841096162796
57 | sys40775-uttca8a7fd,3.6459875106811523
58 | sys40775-uttcd51f1b,3.506250262260437
59 | sys40775-uttd28fea3,3.4177536964416504
60 | sys40775-uttd4f5d6e,3.343674421310425
61 | sys40775-uttd803e0a,3.503384828567505
62 | sys40775-uttddf3183,3.475356549024582
63 | sys40775-uttde965d9,3.6865957975387573
64 | sys40775-uttdeee18d,3.341072738170624
65 | sys40775-uttdfaf42e,3.3064666390419006
66 | sys40775-utte1baab8,3.430553048849106
67 | sys40775-utte451004,3.3815099596977234
68 | sys40775-utted1807d,3.7209166288375854
69 | sys40775-uttee1039e,3.744911015033722
70 | sys40775-uttf0806be,3.442750334739685
71 | sys40775-uttf1ca905,3.821177899837494
72 | sys40775-uttf46b095,3.7114606499671936
73 | sys40775-uttfe39bc5,3.68908029794693
74 | sys50620-utt36a51d2,3.8827565908432007
75 | sys50620-uttb648ed6,3.0986974611878395
76 | sys50620-uttdc9bc1b,2.894637681543827
77 | sys50620-utte891e60,2.5575685501098633
78 | sys50620-uttf932cea,3.071582153439522
79 | sys5a87a-utt3153e35,3.135112166404724
80 | sys5a87a-uttd0f0a73,2.8925078213214874
81 | sys6aa80-utt41c16b6,3.6064711213111877
82 | sys6aa80-utta52cb79,3.840636134147644
83 | sys6aa80-uttb7c3c36,3.733500063419342
84 | sys6aa80-uttc4aa100,3.693002223968506
85 | sys6aa80-utte8f1cbc,3.446985363960266
86 | sys6d946-utt067a45d,4.0865994691848755
87 | sys6d946-utt454ee70,4.334984540939331
88 | sys6d946-uttb8f803b,4.222865581512451
89 | sys6d946-uttc5d8752,4.313713550567627
90 | sys7d2dc-utt04fca18,4.010199785232544
91 | sys7d2dc-utt5b21c48,3.5608004331588745
92 | sys8569c-utt53191c9,4.4660950899124146
93 | sys8569c-utt75e2e48,4.4166131019592285
94 | sys8569c-utt89043c6,4.4364928007125854
95 | sys86a3b-utt0dfad3e,3.55155611038208
96 | sys86a3b-utt6292b56,3.365819603204727
97 | sys9900c-utt310ee00,2.9474673829972744
98 | sys9900c-utt95df11f,2.4395697712898254
99 | sys9900c-utt99ec0d0,2.7191194593906403
100 | sys9900c-uttb72803f,2.4801058173179626
101 | sys9900c-utte2b3cdc,3.209576740860939
102 | sys9900c-utte7af0d7,2.9360008537769318
103 | sysaefdf-utt0c7d358,4.201140284538269
104 | sysaefdf-utt33e9539,4.343910217285156
105 | sysaefdf-uttc8398cf,4.286288022994995
106 | sysaefdf-uttd832fe6,3.982454299926758
107 | sysc0232-utt0d5592e,4.392175316810608
108 | sysc0232-utt1136939,4.324562072753906
109 | sysc0232-utt321b7f3,4.107685089111328
110 | sysc0232-utt43516ea,4.312808871269226
111 | sysc0232-utta98dc3e,4.302412629127502
112 | sysc0232-uttcc93fbe,4.046707510948181
113 | sysc0247-utt0c53cda,4.259897470474243
114 | sysc0247-utt71242b9,4.30824601650238
115 | sysc0247-utte133661,4.403515815734863
116 | sysc0247-uttfe8f1b6,4.245032429695129
117 | sysc138b-utt41cd1c3,3.3730322420597076
118 | sysc138b-utt52345ed,3.030466416850686
119 | sysc138b-utt5ab7a39,3.429475098848343
120 | sysc138b-utt9d47077,3.3303037881851196
121 | sysc138b-uttc0e1eb7,3.2648874521255493
122 | sysc5b22-utt2d3463f,4.465267658233643
123 | sysc5b22-utt3e37da2,4.396283984184265
124 | sysc5b22-utted9388f,4.438539385795593
125 | sysc5b22-uttf51fc7c,4.36857283115387
126 | sysc8b6d-utt445f8c7,4.024686098098755
127 | sysc8b6d-utt6ab2d26,3.605237662792206
128 | sysc8b6d-utt7eeebd8,3.8477973341941833
129 | sysc8b6d-uttfe63322,4.0216004848480225
130 | sysc93c1-utt249fc00,1.6773823499679565
131 | sysc93c1-utt577a573,1.6771522760391235
132 | syseb559-utt1464f1a,2.2051857709884644
133 | syseb559-utt274445f,2.5563432574272156
134 | syseb559-utt5337b78,2.7770099490880966
135 | syseb559-utt7b8a073,2.8372543454170227
136 | syseb559-uttdb7f27a,2.366033673286438
137 |
--------------------------------------------------------------------------------
/stacking/strong_learner_result/ood1/answer-ood.csvval_2:
--------------------------------------------------------------------------------
1 | sys0c3c7-utt29396c4,1.9007413387298584
2 | sys0c3c7-utt32149b2,1.7681547403335571
3 | sys0c3c7-utt3b43620,1.721413016319275
4 | sys0c3c7-utte2ff6b2,1.7883329391479492
5 | sys1b29a-uttb5c2dfd,3.177410289645195
6 | sys1b29a-uttbe1f861,3.485043078660965
7 | sys1b29a-uttf1ba227,3.257159322500229
8 | sys25583-utt32a3068,4.03691840171814
9 | sys25583-utt640ccda,4.121394991874695
10 | sys25583-uttb584e2d,4.063861012458801
11 | sys25583-utte43fff8,3.9808879494667053
12 | sys288fc-utt0b9b8c4,1.7661837339401245
13 | sys288fc-utt178b639,1.7247849702835083
14 | sys288fc-utt27c7f8e,1.7177363634109497
15 | sys288fc-utt4ca70be,1.7433276176452637
16 | sys288fc-utt737b11a,1.8274472951889038
17 | sys288fc-utt875c9bc,1.7218009233474731
18 | sys288fc-uttc86d188,1.7508944272994995
19 | sys2dea8-utt5d0ff00,3.3823030591011047
20 | sys2dea8-uttc0684c5,3.7632473707199097
21 | sys3ac0c-utt380b246,3.2561529874801636
22 | sys3ac0c-utt6e8c4c8,3.215494468808174
23 | sys3bec5-utt08b2712,4.037221789360046
24 | sys3bec5-utt256038c,4.032892227172852
25 | sys3bec5-utt57f6999,3.992736518383026
26 | sys3bec5-utt6fa8766,4.005888819694519
27 | sys3bec5-uttf96e3f6,4.036606192588806
28 | sys40775-utt053fb06,3.1620380580425262
29 | sys40775-utt06280d0,3.2348039001226425
30 | sys40775-utt21f7eb2,3.249326527118683
31 | sys40775-utt25debbd,3.384929269552231
32 | sys40775-utt262d917,3.477863311767578
33 | sys40775-utt2c689f4,3.4279374182224274
34 | sys40775-utt31d8803,3.2081397473812103
35 | sys40775-utt3a6022b,4.024236440658569
36 | sys40775-utt3c9d3a9,3.5308963656425476
37 | sys40775-utt3caa162,3.556710362434387
38 | sys40775-utt40cc32a,3.562081277370453
39 | sys40775-utt4830112,3.3248510658740997
40 | sys40775-utt4beb3cb,3.7101657390594482
41 | sys40775-utt52bccfe,3.6892759203910828
42 | sys40775-utt59ca17a,2.932113580405712
43 | sys40775-utt63aa530,3.2914924323558807
44 | sys40775-utt6bccf3e,3.3274621665477753
45 | sys40775-utt6cc4d95,3.0198783073574305
46 | sys40775-utt734346c,3.5169814229011536
47 | sys40775-utt757daca,3.221805825829506
48 | sys40775-utt7ab4c4c,3.293033629655838
49 | sys40775-utt7f58b52,3.5033302903175354
50 | sys40775-utt8c70e0c,3.3420273661613464
51 | sys40775-utt9043da6,3.2828216552734375
52 | sys40775-utt95d9b13,3.5908207297325134
53 | sys40775-utt979fa4d,3.229120597243309
54 | sys40775-utt98d2608,3.323885351419449
55 | sys40775-uttacfa755,3.0765668898820877
56 | sys40775-uttbffd493,2.9936748747713864
57 | sys40775-uttca8a7fd,3.5150104761123657
58 | sys40775-uttcd51f1b,3.2664534151554108
59 | sys40775-uttd28fea3,2.7796780467033386
60 | sys40775-uttd4f5d6e,2.963381800800562
61 | sys40775-uttd803e0a,3.008986245840788
62 | sys40775-uttddf3183,3.4704970121383667
63 | sys40775-uttde965d9,3.3853626251220703
64 | sys40775-uttdeee18d,3.074149541556835
65 | sys40775-uttdfaf42e,2.9378270842134953
66 | sys40775-utte1baab8,3.2391035854816437
67 | sys40775-utte451004,3.4486441910266876
68 | sys40775-utted1807d,3.173550084233284
69 | sys40775-uttee1039e,3.364699602127075
70 | sys40775-uttf0806be,3.372744768857956
71 | sys40775-uttf1ca905,3.4473859667778015
72 | sys40775-uttf46b095,3.8440520763397217
73 | sys40775-uttfe39bc5,3.4940754771232605
74 | sys50620-utt36a51d2,3.7633183002471924
75 | sys50620-uttb648ed6,3.342268943786621
76 | sys50620-uttdc9bc1b,3.2014998346567154
77 | sys50620-utte891e60,2.83543960750103
78 | sys50620-uttf932cea,3.337063044309616
79 | sys5a87a-utt3153e35,2.974347261711955
80 | sys5a87a-uttd0f0a73,3.0582089573144913
81 | sys6aa80-utt41c16b6,3.3762663900852203
82 | sys6aa80-utta52cb79,3.5862852334976196
83 | sys6aa80-uttb7c3c36,3.4562728106975555
84 | sys6aa80-uttc4aa100,3.7309938073158264
85 | sys6aa80-utte8f1cbc,3.317980080842972
86 | sys6d946-utt067a45d,3.9896318316459656
87 | sys6d946-utt454ee70,4.0640339851379395
88 | sys6d946-uttb8f803b,3.968318819999695
89 | sys6d946-uttc5d8752,4.050647020339966
90 | sys7d2dc-utt04fca18,3.9717133045196533
91 | sys7d2dc-utt5b21c48,3.563229262828827
92 | sys8569c-utt53191c9,4.033492922782898
93 | sys8569c-utt75e2e48,4.029201507568359
94 | sys8569c-utt89043c6,4.080493688583374
95 | sys86a3b-utt0dfad3e,3.4544460773468018
96 | sys86a3b-utt6292b56,3.313985913991928
97 | sys9900c-utt310ee00,3.524359345436096
98 | sys9900c-utt95df11f,3.493915557861328
99 | sys9900c-utt99ec0d0,2.957903254777193
100 | sys9900c-uttb72803f,3.5492897033691406
101 | sys9900c-utte2b3cdc,4.015172481536865
102 | sys9900c-utte7af0d7,2.7136725187301636
103 | sysaefdf-utt0c7d358,3.929300010204315
104 | sysaefdf-utt33e9539,4.067215442657471
105 | sysaefdf-uttc8398cf,4.009889483451843
106 | sysaefdf-uttd832fe6,3.434192717075348
107 | sysc0232-utt0d5592e,4.063883185386658
108 | sysc0232-utt1136939,3.911326766014099
109 | sysc0232-utt321b7f3,3.7473238706588745
110 | sysc0232-utt43516ea,3.8803505897521973
111 | sysc0232-utta98dc3e,3.950296401977539
112 | sysc0232-uttcc93fbe,3.526122808456421
113 | sysc0247-utt0c53cda,3.925826609134674
114 | sysc0247-utt71242b9,4.021496653556824
115 | sysc0247-utte133661,4.058210730552673
116 | sysc0247-uttfe8f1b6,3.98443067073822
117 | sysc138b-utt41cd1c3,3.050193063914776
118 | sysc138b-utt52345ed,2.932532951235771
119 | sysc138b-utt5ab7a39,3.177796170115471
120 | sysc138b-utt9d47077,3.122036322951317
121 | sysc138b-uttc0e1eb7,3.0465613305568695
122 | sysc5b22-utt2d3463f,4.089601039886475
123 | sysc5b22-utt3e37da2,4.042580723762512
124 | sysc5b22-utted9388f,4.073251724243164
125 | sysc5b22-uttf51fc7c,4.013024687767029
126 | sysc8b6d-utt445f8c7,4.025208115577698
127 | sysc8b6d-utt6ab2d26,4.048284530639648
128 | sysc8b6d-utt7eeebd8,3.9827736616134644
129 | sysc8b6d-uttfe63322,4.045578479766846
130 | sysc93c1-utt249fc00,2.2003649473190308
131 | sysc93c1-utt577a573,2.0730539560317993
132 | syseb559-utt1464f1a,2.8193936198949814
133 | syseb559-utt274445f,3.435039311647415
134 | syseb559-utt5337b78,3.293728530406952
135 | syseb559-utt7b8a073,3.699833929538727
136 | syseb559-uttdb7f27a,3.16067473590374
137 |
--------------------------------------------------------------------------------
/strong/README.md:
--------------------------------------------------------------------------------
1 | # UTMOS Strong Learner
2 |
3 | Training and inference scripts for the UTMOS strong learner.
4 |
5 | ## Prerequisites
6 |
7 | * poetry
8 | * [WavAugment](https://github.com/facebookresearch/WavAugment)
9 |
10 | ## Pretrained model
11 | Pretrained **UTMOS strong** models for the main and OOD tracks are available.
12 | For the model details, refer to the [paper](https://arxiv.org/abs/2204.02152).
13 |
14 | - [Main](https://drive.google.com/drive/folders/1U4XQze8mJqV4TRMwTcY6T247RpmU5hRg?usp=sharing)
15 | - [OOD](https://drive.google.com/drive/folders/1dPlV92fyKY1arei7TcU2ZFB-wZkYhIqK?usp=sharing)
16 |
17 | Note that each of the above directories contains pretrained models obtained with five different random seeds.
18 |
19 | ## Setup
20 | 1. Download SSL model checkpoints from [fairseq repo](https://github.com/pytorch/fairseq).
21 | 1. Run the following commands.
22 | ```shell
23 | cd path/to/this/repository
24 | poetry install
25 | cd strong/
26 | ln -s path/to/dataset/ data/
27 | poetry shell
28 | ```
29 |
30 | ## Preprocessing
31 | The phoneme transcription is already present in this repository.
32 | You can also perform the transcription on your own!
33 | ```shell
34 | cd strong/
35 | python transcribe_speech.py
36 | ```
37 |
38 | ## Training
39 |
40 | To train the strong learner, run the following commands for each of the tracks.
41 |
42 | Main track
43 | ```shell
44 | cd strong/
45 | python train.py dataset.use_data.ood=False dataset.use_data.external=False
46 | ```
47 | OOD track
48 | ```shell
49 | cd strong/
50 | python train.py train.model_selection_metric=val_SRCC_system_ood
51 | ```
52 |
53 | ## Prediction
54 | To predict scores with the trained strong learner, run the following command.
55 | ```shell
56 | python predict.py +ckpt_path=outputs/${date}/${time}/train_outputs/hoge.ckpt
57 | ```
58 | To perform prediction with the pretrained model, run the following command.
59 |
60 | ```shell
61 | python predict.py +ckpt_path=outputs/${date}/${time}/train_outputs/hoge.ckpt +paper_weights=True
62 | ```
63 |
--------------------------------------------------------------------------------
/strong/configs/cv.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: cv
3 | - model: default
4 | - train: default
5 |
6 | batch_size_and_model: "wav2vec2-base-4"
7 | debug: False
8 | deepspeed: False
9 | outfile: answer.csv
--------------------------------------------------------------------------------
/strong/configs/dataset/cv.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | data_dir: data/phase1-main/DATA
3 | data_sources:
4 | -
5 | name: "main"
6 | train_mos_list_path: data/phase1-main/DATA/sets/TRAINSET
7 | val_mos_list_path: data/phase1-main/DATA/sets/DEVSET
8 | test_mos_list_path: data/phase1-main/DATA/sets/test.scp
9 | wav_dir: data/phase1-main/DATA/wav/
10 | data_dir: data/phase1-main/DATA
11 | outfile: answer-main.csv
12 | -
13 | name: "ood"
14 | train_mos_list_path: data/phase1-ood/DATA/sets/TRAINSET
15 | val_mos_list_path: data/phase1-ood/DATA/sets/DEVSET
16 | test_mos_list_path: data/phase1-ood/DATA/sets/test.scp
17 | wav_dir: data/phase1-ood/DATA/wav/
18 | data_dir: data/phase1-ood/DATA
19 | outfile: answer-ood.csv
20 | -
21 | name: "external"
22 | train_mos_list_path: TRAINSET_external.txt
23 | wav_dir: data/phase1-ood/DATA/wav/
24 | data_dir: data/phase1-ood/DATA
25 | k_cv: 5
26 | use_data:
27 | main: True
28 | ood: True
29 | external: True
30 | datamodule:
31 | _target_: dataset.CVDataModule
32 | only_mean: False
33 | additional_datas:
34 | -
35 | _target_: dataset.PhonemeData
36 | transcription_file_path: 'transcriptions_clustered.csv'
37 | with_reference: True
38 | -
39 | _target_: dataset.NormalizeScore
40 | org_max: 5.0
41 | org_min: 1.0
42 | normalize_to_max: 1.0
43 | normalize_to_min: -1.0
44 | -
45 | _target_: dataset.AugmentWav
46 | pitch_shift_minmax:
47 | min: -300
48 | max: 300
49 | random_time_warp_f: 1.0
50 | -
51 | _target_: dataset.SliceWav
52 | max_wav_seconds: 10
53 |
--------------------------------------------------------------------------------
/strong/configs/dataset/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | data_dir: data/phase1-main/DATA
3 | data_sources:
4 | -
5 | name: "main"
6 | train_mos_list_path: data/phase1-main/DATA/sets/TRAINSET
7 | val_mos_list_path: data/phase1-main/DATA/sets/DEVSET
8 | test_mos_list_path: data/phase1-main/DATA/sets/test.scp
9 | wav_dir: data/phase1-main/DATA/wav/
10 | data_dir: data/phase1-main/DATA
11 | outfile: answer-main.csv
12 | -
13 | name: "ood"
14 | train_mos_list_path: data/phase1-ood/DATA/sets/TRAINSET
15 | val_mos_list_path: data/phase1-ood/DATA/sets/DEVSET
16 | test_mos_list_path: data/phase1-ood/DATA/sets/test.scp
17 | wav_dir: data/phase1-ood/DATA/wav/
18 | data_dir: data/phase1-ood/DATA
19 | outfile: answer-ood.csv
20 | -
21 | name: "external"
22 | train_mos_list_path: TRAINSET_external.txt
23 | wav_dir: data/phase1-ood/DATA/wav/
24 | data_dir: data/phase1-ood/DATA
25 | use_data:
26 | main: True
27 | ood: True
28 | external: True
29 | datamodule:
30 | _target_: dataset.DataModule
31 | only_mean: False
32 | additional_datas:
33 | -
34 | _target_: dataset.PhonemeData
35 | transcription_file_path: 'transcriptions_clustered.csv'
36 | with_reference: True
37 | -
38 | _target_: dataset.NormalizeScore
39 | org_max: 5.0
40 | org_min: 1.0
41 | normalize_to_max: 1.0
42 | normalize_to_min: -1.0
43 | -
44 | _target_: dataset.AugmentWav
45 | pitch_shift_minmax:
46 | min: -300
47 | max: 300
48 | random_time_warp_f: 1.0
49 | -
50 | _target_: dataset.SliceWav
51 | max_wav_seconds: 10
52 |
--------------------------------------------------------------------------------
/strong/configs/dataset/wo_augment.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | data_dir: data/phase1-main/DATA
3 | data_sources:
4 | -
5 | name: "main"
6 | train_mos_list_path: data/phase1-main/DATA/sets/TRAINSET
7 | val_mos_list_path: data/phase1-main/DATA/sets/DEVSET
8 | test_mos_list_path: data/phase1-main/DATA/sets/test.scp
9 | test_post_mos_list_path: data/phase1-main/DATA/sets/TESTSET
10 | wav_dir: data/phase1-main/DATA/wav/
11 | data_dir: data/phase1-main/DATA
12 | outfile: answer-main.csv
13 | -
14 | name: "ood"
15 | train_mos_list_path: data/phase1-ood/DATA/sets/TRAINSET
16 | val_mos_list_path: data/phase1-ood/DATA/sets/DEVSET
17 | test_mos_list_path: data/phase1-ood/DATA/sets/test.scp
18 | test_post_mos_list_path: data/phase1-ood/DATA/sets/TESTSET
19 | wav_dir: data/phase1-ood/DATA/wav/
20 | data_dir: data/phase1-ood/DATA
21 | outfile: answer-ood.csv
22 | -
23 | name: "external"
24 | train_mos_list_path: TRAINSET_external.txt
25 | wav_dir: data/phase1-ood/DATA/wav/
26 | data_dir: data/phase1-ood/DATA
27 | use_data:
28 | main: True
29 | ood: True
30 | external: True
31 | datamodule:
32 | _target_: dataset.DataModule
33 | only_mean: False
34 | additional_datas:
35 | -
36 | _target_: dataset.PhonemeData
37 | transcription_file_path: 'transcriptions_clustered.csv'
38 | with_reference: True
39 | -
40 | _target_: dataset.NormalizeScore
41 | org_max: 5.0
42 | org_min: 1.0
43 | normalize_to_max: 1.0
44 | normalize_to_min: -1.0
45 | -
46 | _target_: dataset.SliceWav
47 | max_wav_seconds: 10
48 |
--------------------------------------------------------------------------------
/strong/configs/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: default
3 | - model: default
4 | - train: default
5 |
6 | debug: False
7 | deepspeed: False
8 | outfile: answer.csv
9 |
--------------------------------------------------------------------------------
/strong/configs/model/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | lightning_module:
3 | _target_: lightning_module.BaselineLightningModule
4 | WavLM: False
5 |
6 | feature_extractors:
7 | -
8 | _target_: model.load_ssl_model
9 | cp_path: ../fairseq_checkpoints/wav2vec_small.pt
10 |
11 | -
12 | _target_: model.PhonemeEncoder
13 | hidden_dim: 256
14 | emb_dim: 256
15 | out_dim: 256
16 | n_lstm_layers: 3
17 | vocab_size: 198
18 | -
19 | _target_: model.DomainEmbedding
20 | n_domains: 3
21 | domain_dim: 128
22 |
23 | output_layers:
24 | -
25 | _target_: model.LDConditioner
26 | judge_dim: 128
27 | num_judges: 3000
28 | -
29 | _target_: model.Projection
30 | hidden_dim: 2048
31 | activation:
32 | _target_: torch.nn.ReLU
33 | range_clipping: False
34 |
35 |
--------------------------------------------------------------------------------
/strong/configs/model/wo_phoneme.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | lightning_module:
3 | _target_: lightning_module.BaselineLightningModule
4 | WavLM: False
5 |
6 | feature_extractors:
7 | -
8 | _target_: model.load_ssl_model
9 | cp_path: ../fairseq_checkpoints/fairseq/wav2vec_small.pt
10 |
11 | -
12 | _target_: model.DomainEmbedding
13 | n_domains: 3
14 | domain_dim: 128
15 |
16 | output_layers:
17 | -
18 | _target_: model.LDConditioner
19 | judge_dim: 128
20 | num_judges: 3000
21 | -
22 | _target_: model.Projection
23 | hidden_dim: 2048
24 | activation:
25 | _target_: torch.nn.ReLU
26 | range_clipping: False
27 |
28 |
--------------------------------------------------------------------------------
/strong/configs/optuna-main.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: default
3 | - model: default
4 | - train: default
5 | - override hydra/sweeper: optuna
6 |
7 | hydra:
8 | sweeper:
9 | n_trials: 100
10 | direction: maximize
11 | storage: ???
12 | study_name: sslmos_listener_ld_contrastive_abci
13 | n_jobs: 1
14 | search_space:
15 | train.criterion.loss_weights.1:
16 | type: float
17 | low: 0.0
18 | high: 2.0
19 | step: 0.1
20 | train.criterion.loss_instances.1.margin:
21 | type: float
22 | low: 0.0
23 | high: 0.5
24 | step: 0.1
25 | train.criterion.loss_instances.0.tau:
26 | type: categorical
27 | choices:
28 | - 0.1
29 | - 0.25
30 | train.criterion.loss_instances.0.mode:
31 | type: categorical
32 | choices:
33 | - 'frame'
34 | dataset.additional_datas.2.pitch_shift_minmax.min:
35 | type: categorical
36 | choices:
37 | - 0
38 | - -100
39 | - -200
40 | - -300
41 | dataset.additional_datas.2.pitch_shift_minmax.max:
42 | type: categorical
43 | choices:
44 | - 0
45 | - 100
46 | - 200
47 | - 300
48 | dataset.additional_datas.2.random_time_warp_f:
49 | type: float
50 | low: 1
51 | high: 3
52 | step: 0.5
53 | dataset.use_data.ood:
54 | type: categorical
55 | choices:
56 | - True
57 | - False
58 | dataset.use_data.external:
59 | type: categorical
60 | choices:
61 | - True
62 | - False
63 | dataset.only_mean:
64 | type: categorical
65 | choices:
66 | - True
67 | - False
68 | batch_size_and_model:
69 | type: categorical
70 | choices:
71 | - "wav2vec2-base-4"
72 | - "wav2vec2-base-8"
73 | - "wav2vec2-base-16"
74 | - "wavlm-large-4"
75 | batch_size_and_model: "wav2vec2-base-4"
76 | tuning_target: ???
77 | debug: False
78 | outfile: answer.csv
79 |
--------------------------------------------------------------------------------
/strong/configs/optuna-ood.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: default
3 | - model: default
4 | - train: default
5 | - override hydra/sweeper: optuna
6 |
7 | hydra:
8 | sweeper:
9 | n_trials: 100
10 | direction: maximize
11 | storage: ???
12 | study_name: sslmos_listener_ld_contrastive_abci
13 | n_jobs: 1
14 | search_space:
15 | train.criterion.loss_weights.1:
16 | type: float
17 | low: 0.0
18 | high: 2.0
19 | step: 0.1
20 | train.criterion.loss_instances.1.margin:
21 | type: float
22 | low: 0.0
23 | high: 0.5
24 | step: 0.1
25 | train.criterion.loss_instances.0.tau:
26 | type: categorical
27 | choices:
28 | - 0.1
29 | - 0.25
30 | train.criterion.loss_instances.0.mode:
31 | type: categorical
32 | choices:
33 | - 'frame'
34 | dataset.additional_datas.2.pitch_shift_minmax.min:
35 | type: categorical
36 | choices:
37 | - 0
38 | - -100
39 | - -200
40 | - -300
41 | dataset.additional_datas.2.pitch_shift_minmax.max:
42 | type: categorical
43 | choices:
44 | - 0
45 | - 100
46 | - 200
47 | - 300
48 | dataset.additional_datas.2.random_time_warp_f:
49 | type: float
50 | low: 1
51 | high: 3
52 | step: 0.5
53 | dataset.use_data.main:
54 | type: categorical
55 | choices:
56 | - True
57 | - False
58 | dataset.use_data.external:
59 | type: categorical
60 | choices:
61 | - True
62 | - False
63 | dataset.only_mean:
64 | type: categorical
65 | choices:
66 | - True
67 | - False
68 | batch_size_and_model:
69 | type: categorical
70 | choices:
71 | - "wav2vec2-base-4"
72 | - "wav2vec2-base-8"
73 | - "wav2vec2-base-16"
74 | - "wavlm-large-4"
75 | batch_size_and_model: "wav2vec2-base-4"
76 | tuning_target: ???
77 | debug: False
78 | outfile: answer.csv
79 |
--------------------------------------------------------------------------------
/strong/configs/train/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | seed: 1234
3 | use_wandb: False
4 | model_selection_metric: val_SRCC_system_main
5 | train_batch_size: 12
6 | val_batch_size: 1
7 | test_batch_size: 1
8 | out_dir: train_output/
9 | trainer_args:
10 | max_steps: 15_000
11 | gpus: [0]
12 | deterministic: True
13 | auto_select_gpus: False
14 | benchmark: True
15 | precision: 32
16 | gradient_clip_val: 1.0
17 | flush_logs_every_n_steps: 10
18 | val_check_interval: 0.5
19 | accumulate_grad_batches: 2
20 | # strategy: ddp
21 | optimizer:
22 | _target_: torch.optim.Adam
23 | lr: 2e-5
24 | scheduler:
25 | _target_: transformers.get_linear_schedule_with_warmup
26 | num_warmup_steps: 4000
27 | num_training_steps: 15_000
28 | early_stopping:
29 | patience: 100
30 |
31 | criterion:
32 | _target_: loss_function.CombineLosses
33 | loss_weights:
34 | - 1.0
35 | - 0.5
36 | loss_instances:
37 | -
38 | _target_: loss_function.ClippedMSELoss
39 | criterion:
40 | _target_: torch.nn.MSELoss
41 | reduction: 'none'
42 | tau: 0.25
43 | mode: 'frame'
44 | -
45 | _target_: loss_function.ContrastiveLoss
46 | margin: 0.1
47 |
--------------------------------------------------------------------------------
/strong/cross_validation.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from pytorch_lightning.loggers.csv_logs import CSVLogger
3 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
4 | from pytorch_lightning.callbacks import ModelCheckpoint
5 | from pytorch_lightning import Trainer
6 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
7 | from dataset import CVDataModule, TestDataModule
8 | from lightning_module import UTMOSLightningModule
9 | import hydra
10 | import wandb
11 |
@hydra.main(config_path="configs",config_name='cv')
def cross_validation(cfg):
    """Entry point: train/evaluate one cross-validation fold picked by the config.

    Reads the fold count (k_cv) and fold index (i_cv) from cfg, runs
    fit_and_test for that fold and pushes the returned metrics to wandb.
    """
    n_folds = cfg.dataset.k_cv
    fold_idx = cfg.dataset.i_cv

    print("------- Using subset_{} out of {}-fold -------".format(fold_idx, n_folds))
    fold_metrics = fit_and_test(cfg, n_folds, fold_idx)
    wandb.log(fold_metrics)
21 |
22 |
def fit_and_test(cfg, k_cv, i_cv):
    """Train on fold ``i_cv`` of a ``k_cv``-fold split, then evaluate.

    Tests the best checkpoint on the training fold, the validation set and
    the test set, and returns the main-track system SRCC logged for the
    training ("fold") split.

    Args:
        cfg: hydra config (see configs/cv.yaml).
        k_cv: total number of cross-validation folds.
        i_cv: index of the fold to hold out.

    Returns:
        The ``test_SRCC_SYS_main_i_cv_{i_cv}_set_name_fold`` metric value.
    """
    debug = cfg.debug
    if debug:
        # shrink the run so a debug session finishes quickly
        cfg.train.trainer_args.max_steps=10


    loggers = []
    loggers.append(CSVLogger(save_dir=cfg.train.out_dir, name="train_log"))
    loggers.append(TensorBoardLogger(save_dir=cfg.train.out_dir, name="tf_log"))
    if cfg.train.use_wandb:
        loggers.append(WandbLogger(project="voicemos",offline=debug))

    # keep the single best checkpoint according to the configured metric
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.train.out_dir,
        save_weights_only=True,
        save_top_k=1,
        save_last=True,
        every_n_epochs=1,
        monitor=cfg.train.model_selection_metric,
        mode='max'
    )
    callbacks = [checkpoint_callback]
    earlystop_callback = EarlyStopping(
        monitor="val_loss", min_delta=0.0, patience=cfg.train.early_stopping.patience, mode="min"
    )
    callbacks.append(earlystop_callback)

    trainer = Trainer(
        **cfg.train.trainer_args,
        default_root_dir=hydra.utils.get_original_cwd(),
        limit_train_batches=0.01 if debug else 1.0,
        limit_val_batches=0.5 if debug else 1.0,
        callbacks=callbacks,
        logger=loggers,
    )

    # three datamodules: the held-out fold, the official val set, the test set
    datamodule = CVDataModule(cfg=cfg, k_cv=k_cv, i_cv=i_cv)
    val_datamodule = TestDataModule(cfg=cfg, i_cv=i_cv, set_name='val')
    test_datamodule = TestDataModule(cfg=cfg, i_cv=i_cv, set_name='test')
    lightning_module = UTMOSLightningModule(cfg)
    trainer.fit(lightning_module, datamodule=datamodule)

    if debug:
        # debug: evaluate the in-memory weights (no best checkpoint reload)
        trainer.test(lightning_module, verbose=True, datamodule=datamodule)
        result = trainer.logged_metrics["test_SRCC_SYS_main_i_cv_{}_set_name_{}".format(i_cv, "fold")]
        trainer.test(lightning_module, verbose=True, datamodule=val_datamodule)
        trainer.test(lightning_module, verbose=True, datamodule=test_datamodule)
    else:
        # evaluate the best checkpoint selected by the ModelCheckpoint callback
        trainer.test(lightning_module, verbose=True, datamodule=datamodule, ckpt_path=checkpoint_callback.best_model_path)
        result = trainer.logged_metrics["test_SRCC_SYS_main_i_cv_{}_set_name_{}".format(i_cv, "fold")]
        trainer.test(lightning_module, verbose=True, datamodule=val_datamodule, ckpt_path=checkpoint_callback.best_model_path)
        trainer.test(lightning_module, verbose=True, datamodule=test_datamodule, ckpt_path=checkpoint_callback.best_model_path)

    return result
77 |
if __name__ == "__main__":
    # hydra builds cfg from configs/cv.yaml plus any CLI overrides
    cross_validation()
80 |
--------------------------------------------------------------------------------
/strong/data:
--------------------------------------------------------------------------------
1 | /media/ssd/voicemos/data/2022/
--------------------------------------------------------------------------------
/strong/data_augment.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | import augment
4 | import numpy as np
5 | import random
6 |
7 |
class ChainRunner:
    """
    Wraps an instance of augment.EffectChain so it can be applied to
    pytorch tensors, falling back to the unmodified input when the chain
    produces NaN/Inf values.
    """

    def __init__(self, chain):
        self.chain = chain

    def __call__(self, x):
        """
        x: torch.Tensor, (channels, length). Must be placed on CPU.
        """
        n_channels = x.size(0)
        n_samples = x.size(1)
        # common audio metadata shared by source and target descriptors
        meta = {'precision': 32,       # precision (16, 32 bits)
                'rate': 16000.0,       # sampling rate
                'bits_per_sample': 32} # size of the sample

        src_info = dict(meta, channels=n_channels, length=n_samples)
        target_info = dict(meta, channels=1, length=n_samples)

        y = self.chain.apply(
            x, src_info=src_info, target_info=target_info)

        # sox effects can occasionally blow up numerically; return the
        # untouched input in that case rather than propagating NaN/Inf
        if torch.isnan(y).any() or torch.isinf(y).any():
            return x.clone()
        return y
38 |
39 |
def random_pitch_shift(a=-300, b=300):
    """Draw an integer pitch shift uniformly from [a, b] (sox pitch units, presumably cents — TODO confirm)."""
    low, high = a, b
    return random.randint(low, high)
42 |
def random_time_warp(f=1):
    """Draw a tempo factor from [1 - 0.1*f, 1 + 0.1*f]; default range is [0.9, 1.1]."""
    offset = (random.random() - 0.5) / 5
    return 1 + f * offset
46 |
if __name__ == '__main__':
    # Smoke test: build a pitch-shift + tempo-warp effect chain and run it
    # on one second of random noise, printing shapes and sample values.
    chain = augment.EffectChain()
    chain.pitch(random_pitch_shift).rate(16000)
    chain.tempo(random_time_warp)
    chain = ChainRunner(chain)
    wav = torch.randn((1, 16000))
    augmented = chain(wav)
    print(wav.shape, augmented.shape)
    print(wav[:, :-1000], augmented[:, :-1000])
56 |
--------------------------------------------------------------------------------
/strong/lightning_module.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | import torch.nn as nn
4 | from WavLM import WavLM, WavLMConfig
5 | import os
6 | import fairseq
7 | import numpy as np
8 | import scipy.stats
9 | import hydra
10 | from transformers import AdamW, get_linear_schedule_with_warmup
11 | from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection
12 | import wandb
13 |
14 |
class UTMOSLightningModule(pl.LightningModule):
    """LightningModule for the UTMOS strong learner.

    A set of feature extractors (SSL model, phoneme encoder, domain
    embedding, ...) produces a dict of features; a chain of output layers
    maps them to a (frame-level) score.  Scores are learned in a normalized
    range and converted back to MOS with ``x * 2 + 3`` for metrics/outputs.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.construct_model()
        self.prepare_domain_table()
        self.save_hyperparameters()

    def construct_model(self):
        """Instantiate feature extractors, output layers and the loss from cfg."""
        self.feature_extractors = nn.ModuleList([
            hydra.utils.instantiate(feature_extractor) for feature_extractor in self.cfg.model.feature_extractors
        ])
        output_dim = sum([feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors])
        output_layers = []
        # output layers are chained: each consumes the previous layer's width
        for output_layer in self.cfg.model.output_layers:
            output_layers.append(
                hydra.utils.instantiate(output_layer, input_dim=output_dim)
            )
            output_dim = output_layers[-1].get_output_dim()

        self.output_layers = nn.ModuleList(output_layers)

        self.criterion = self.configure_criterion()

    def prepare_domain_table(self):
        """Build the {domain index -> domain name} table used for metric logging.

        Disabled sources are removed from ``cfg.dataset.data_sources`` in
        place; ``test_epoch_end`` relies on that same (mutated) list so that
        output-file indices line up with domain ids.
        """
        self.domain_table = {}
        data_sources = self.cfg.dataset.data_sources
        # Remove disabled sources.  Iterating indices in reverse makes the
        # in-place pops safe (the previous forward pop-while-iterating
        # skipped the element following each removal).
        use = self.cfg.dataset.use_data
        disabled = {name for name in ('main', 'ood', 'external') if not use[name]}
        for idx in reversed(range(len(data_sources))):
            if data_sources[idx]['name'] in disabled:
                data_sources.pop(idx)
        for i, datasource in enumerate(data_sources):
            # sources without a validation list (e.g. "external") get no metrics
            if not hasattr(datasource, 'val_mos_list_path'):
                continue
            self.domain_table[i] = datasource["name"]

    def forward(self, inputs):
        """Run every feature extractor, then thread the result through the output layers."""
        outputs = {}
        for feature_extractor in self.feature_extractors:
            outputs.update(feature_extractor(inputs))
        x = outputs
        for output_layer in self.output_layers:
            x = output_layer(x, inputs)
        return x

    def training_step(self, batch, batch_idx):
        outputs = self(batch)
        loss = self.criterion(outputs, batch['score'])
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=self.cfg.train.train_batch_size
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Validate one utterance (val batch size is 1); returns per-utt results."""
        outputs = self(batch)
        loss = self.criterion(outputs, batch['score'])
        if outputs.dim() > 1:
            # collapse frame-level predictions to one utterance-level score
            outputs = outputs.mean(dim=1).squeeze(-1)
        return {
            "loss": loss,
            # map normalized score back to the MOS range
            "outputs": outputs.cpu().numpy()[0] * 2 + 3.0,
            "filename": batch["wavname"][0],
            "domain": batch["domain"][0],
            "utt_avg_score": batch["utt_avg_score"][0].item(),
            "sys_avg_score": batch["sys_avg_score"][0].item()
        }

    def validation_epoch_end(self, outputs):
        """Aggregate per-utterance results into per-domain SRCC/MSE metrics."""
        val_loss = torch.stack([out["loss"] for out in outputs]).mean().item()
        self.log("val_loss", val_loss, on_epoch=True, prog_bar=True, logger=True)
        for domain_id in self.domain_table:
            outputs_domain = [out for out in outputs if out["domain"] == domain_id]
            if len(outputs_domain) == 0:
                continue
            _, SRCC, MSE = self.calc_score(outputs_domain)
            self.log(
                "val_SRCC_system_{}".format(self.domain_table[domain_id]),
                SRCC,
                on_epoch=True,
                prog_bar=True,
                logger=True
            )
            self.log(
                "val_MSE_system_{}".format(self.domain_table[domain_id]),
                MSE,
                on_epoch=True,
                prog_bar=True,
                logger=True
            )
            if domain_id == 0:
                # also expose the first domain's SRCC under a generic key
                # (was a no-op "...".format(...) call)
                self.log(
                    "val_SRCC_system",
                    SRCC,
                    on_epoch=True,
                    prog_bar=True,
                    logger=True
                )

    def test_step(self, batch, batch_idx):
        """Evaluate one utterance (test batch size is 1)."""
        outputs = self(batch)
        labels = batch['score']
        filenames = batch['wavname']
        # single criterion evaluation (the original computed it twice)
        loss = self.criterion(outputs, labels)
        if outputs.dim() > 1:
            outputs = outputs.mean(dim=1).squeeze(-1)
        return {
            "loss": loss,
            "outputs": outputs.cpu().detach().numpy()[0] * 2 + 3.0,
            "labels": labels.cpu().detach().numpy()[0] * 2 + 3.0,
            "filename": filenames[0],
            "domain": batch["domain"][0],
            "i_cv": batch["i_cv"][0],
            "set_name": batch["set_name"][0],
            "utt_avg_score": batch["utt_avg_score"][0].item(),
            "sys_avg_score": batch["sys_avg_score"][0].item()
        }

    def test_epoch_end(self, outputs):
        """Write per-domain prediction files and log per-domain system SRCC."""
        # one output file per (still enabled) data source; indices line up
        # with domain ids because prepare_domain_table pruned the same list
        outfiles = [datasource["outfile"] + '{}_{}'.format(outputs[0]['set_name'], outputs[0]['i_cv']) for datasource in self.cfg.dataset.data_sources if hasattr(datasource, 'outfile')]
        for domain_id in self.domain_table:
            outputs_domain = [out for out in outputs if out["domain"] == domain_id]
            predictions, SRCC, MSE = self.calc_score(outputs_domain)
            self.log(
                "test_SRCC_SYS_{}_i_cv_{}_set_name_{}".format(self.domain_table[domain_id], outputs[0]['i_cv'], outputs[0]['set_name']),
                SRCC,
            )
            if domain_id == 0:
                # generic key for the first domain (was a no-op .format() call)
                self.log(
                    "test_SRCC_SYS",
                    SRCC
                )
            with open(outfiles[domain_id], "w") as fw:
                for k, v in predictions.items():
                    outl = k.split(".")[0] + "," + str(v) + "\n"
                    fw.write(outl)
            try:
                wandb.save(outfiles[domain_id])
            except Exception:
                # wandb may not be initialized (use_wandb=False); best-effort
                print('outfile {} saved'.format(outfiles[domain_id]))

    def configure_optimizers(self):
        """Build the optimizer and a per-step LR scheduler from the config."""
        optimizer = hydra.utils.instantiate(
            self.cfg.train.optimizer,
            params=self.parameters()
        )
        scheduler = hydra.utils.instantiate(
            self.cfg.train.scheduler,
            optimizer=optimizer
        )
        scheduler = {"scheduler": scheduler,
                     "interval": "step", "frequency": 1}
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def configure_criterion(self):
        """Instantiate the training criterion declared under cfg.train.criterion."""
        return hydra.utils.instantiate(self.cfg.train.criterion, _recursive_=True)

    def calc_score(self, outputs, verbose=False):
        """Compute utterance- and system-level metrics from per-utterance outputs.

        Args:
            outputs: list of dicts as returned by validation_step/test_step.
            verbose: if True, print utterance- and system-level statistics.

        Returns:
            (predictions, system_SRCC, system_MSE) where ``predictions`` maps
            wav filename -> predicted MOS.
        """

        def systemID(uttID):
            # system id is the part of the filename before the first '-'
            return uttID.split("-")[0]

        predictions = {}
        true_MOS = {}
        true_sys_MOS_avg = {}
        for out in outputs:
            predictions[out["filename"]] = out["outputs"]
            true_MOS[out["filename"]] = out["utt_avg_score"]
            true_sys_MOS_avg[out["filename"].split("-")[0]] = out["sys_avg_score"]

        ## compute correls.
        sorted_uttIDs = sorted(predictions.keys())
        ts = []
        ps = []
        for uttID in sorted_uttIDs:
            t = true_MOS[uttID]
            p = predictions[uttID]
            ts.append(t)
            ps.append(p)

        truths = np.array(ts)
        # (removed a stray debug print of the raw prediction list here)
        preds = np.array(ps)

        ### UTTERANCE
        MSE = np.mean((truths - preds) ** 2)
        LCC = np.corrcoef(truths, preds)
        SRCC = scipy.stats.spearmanr(truths.T, preds.T)
        KTAU = scipy.stats.kendalltau(truths, preds)
        if verbose:
            print("[UTTERANCE] Test error= %f" % MSE)
            print("[UTTERANCE] Linear correlation coefficient= %f" % LCC[0][1])
            print("[UTTERANCE] Spearman rank correlation coefficient= %f" % SRCC[0])
            print("[UTTERANCE] Kendall Tau rank correlation coefficient= %f" % KTAU[0])

        ### SYSTEM
        pred_sys_MOSes = {}
        for uttID in sorted_uttIDs:
            sysID = systemID(uttID)
            pred_sys_MOSes.setdefault(sysID, []).append(predictions[uttID])

        pred_sys_MOS_avg = {}
        for k, v in pred_sys_MOSes.items():
            avg_MOS = sum(v) / (len(v) * 1.0)
            pred_sys_MOS_avg[k] = avg_MOS

        ## make lists sorted by system
        pred_sysIDs = sorted(pred_sys_MOS_avg.keys())
        sys_p = []
        sys_t = []
        for sysID in pred_sysIDs:
            sys_p.append(pred_sys_MOS_avg[sysID])
            sys_t.append(true_sys_MOS_avg[sysID])

        sys_true = np.array(sys_t)
        sys_predicted = np.array(sys_p)

        MSE = np.mean((sys_true - sys_predicted) ** 2)
        LCC = np.corrcoef(sys_true, sys_predicted)
        SRCC = scipy.stats.spearmanr(sys_true.T, sys_predicted.T)
        KTAU = scipy.stats.kendalltau(sys_true, sys_predicted)
        if verbose:
            print("[SYSTEM] Test error= %f" % MSE)
            print("[SYSTEM] Linear correlation coefficient= %f" % LCC[0][1])
            print("[SYSTEM] Spearman rank correlation coefficient= %f" % SRCC[0])
            print("[SYSTEM] Kendall Tau rank correlation coefficient= %f" % KTAU[0])

        return predictions, SRCC[0], MSE
249 |
250 |
251 |
class DeepSpeedBaselineLightningModule(UTMOSLightningModule):
    """UTMOS lightning module variant backed by DeepSpeed's CPU Adam.

    All model construction and training logic is inherited from
    ``UTMOSLightningModule``; only the optimizer is swapped out (and, unlike
    the parent, no LR scheduler is configured).  The redundant ``__init__``
    that merely forwarded to ``super().__init__`` has been removed.
    """

    def configure_optimizers(self):
        # imported lazily so deepspeed is only required when this class is used
        from deepspeed.ops.adam import DeepSpeedCPUAdam
        return DeepSpeedCPUAdam(self.parameters())
--------------------------------------------------------------------------------
/strong/loss_function.py:
--------------------------------------------------------------------------------
1 | from numpy import dtype
2 | import torch
3 | import torch.nn as nn
4 |
class ContrastiveLoss(nn.Module):
    '''
    Contrastive Loss

    Penalizes pairwise prediction differences that deviate from the
    ground-truth pairwise differences by more than the margin.

    Args:
        margin: non-neg value, the smaller the stricter the loss will be, default: 0.2
    '''
    def __init__(self, margin=0.2):
        super().__init__()
        self.margin = margin

    def forward(self, pred_score, gt_score):
        # collapse frame-level predictions [batch, time, 1] to utterance level
        if pred_score.dim() > 2:
            pred_score = pred_score.mean(dim=1).squeeze(1)
        # pred_score, gt_score: tensor, [batch_size]
        gt_pairwise = gt_score.unsqueeze(1) - gt_score.unsqueeze(0)
        pred_pairwise = pred_score.unsqueeze(1) - pred_score.unsqueeze(0)
        zeros = torch.zeros(gt_pairwise.shape).to(gt_pairwise.device)
        hinge = torch.maximum(zeros, torch.abs(pred_pairwise - gt_pairwise) - self.margin)
        # each unordered pair appears twice in the matrix, hence the halving
        return hinge.mean().div(2)
25 |
26 |
class ClippedMSELoss(nn.Module):
    """
    clipped MSE loss for listener-dependent model

    Per-element losses are zeroed wherever the absolute prediction error is
    within ``tau``; the remaining elements are averaged.
    """
    def __init__(self, criterion, tau, mode='frame'):
        super().__init__()
        self.tau = torch.tensor(tau, dtype=torch.float)

        self.criterion = criterion
        self.mode = mode


    def forward_criterion(self, y_hat, label):
        # drop the trailing singleton score dimension
        predictions = y_hat.squeeze(-1)
        elementwise = self.criterion(predictions, label)
        # keep only elements whose error exceeds the tolerance tau
        keep_mask = torch.abs(predictions - label) > self.tau
        return torch.mean(keep_mask * elementwise)

    def forward(self, pred_score, gt_score):
        """
        Args:
            pred_mean, pred_score: [batch, time, 1/5]
        """
        n_frames = pred_score.shape[1]
        if self.mode == 'utt':
            # utterance level: average predictions over time first
            pred_score = pred_score.mean(dim=1)
        else:
            # frame level: repeat the target once per frame
            gt_score = gt_score.unsqueeze(1).repeat(1, n_frames)
        return self.forward_criterion(pred_score, gt_score)  # lamb 1.0
60 |
class CombineLosses(nn.Module):
    '''
    Combine losses

    Computes a weighted sum of several loss modules evaluated on the same
    (prediction, target) pair.

    Args:
        loss_weights: a list of weights for each loss
        loss_instances: a list of loss modules, parallel to loss_weights
    '''
    def __init__(self, loss_weights: list, loss_instances: list):
        super().__init__()
        self.loss_weights = loss_weights
        self.loss_instances = nn.ModuleList(loss_instances)

    def forward(self, pred_score, gt_score):
        total = torch.zeros((), dtype=torch.float, device=pred_score.device)
        for weight, loss_fn in zip(self.loss_weights, self.loss_instances):
            total = total + weight * loss_fn(pred_score, gt_score)
        return total
76 |
--------------------------------------------------------------------------------
/strong/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from WavLM import WavLM, WavLMConfig
4 | from text.symbols import symbols
5 | import fairseq
6 | import os
7 | import hydra
8 |
def load_ssl_model(cp_path):
    """Load a pretrained SSL feature extractor (wav2vec 2.0 / WavLM) from a checkpoint.

    Args:
        cp_path: checkpoint path, relative to the original (pre-hydra) cwd.

    Returns:
        An ``SSL_model`` wrapping the loaded network and its output dimension.

    Raises:
        ValueError: if the checkpoint filename is not a supported model type.
    """
    # hydra changes the cwd per run; resolve relative to the launch directory
    cp_path = os.path.join(
        hydra.utils.get_original_cwd(),
        cp_path
    )
    ssl_model_type = cp_path.split("/")[-1]
    wavlm = "WavLM" in ssl_model_type
    if wavlm:
        checkpoint = torch.load(cp_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        ssl_model = WavLM(cfg)
        ssl_model.load_state_dict(checkpoint['model'])
        # WavLM-Large has a wider transformer than the Base variant
        SSL_OUT_DIM = 1024 if 'Large' in ssl_model_type else 768
    else:
        if ssl_model_type == "wav2vec_small.pt":
            SSL_OUT_DIM = 768
        elif ssl_model_type in ["w2v_large_lv_fsh_swbd_cv.pt", "xlsr_53_56k.pt"]:
            SSL_OUT_DIM = 1024
        else:
            # raise instead of print + exit() so callers get a real traceback
            raise ValueError(
                "*** ERROR *** SSL model type " + ssl_model_type + " not supported."
            )
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [cp_path]
        )
        ssl_model = model[0]
        ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, SSL_OUT_DIM, wavlm)
39 |
class SSL_model(nn.Module):
    """Thin wrapper that extracts frame-level features from an SSL backbone.

    Handles both fairseq wav2vec-style models and WavLM, which expose
    different feature-extraction APIs.
    """

    def __init__(self, ssl_model, ssl_out_dim, wavlm) -> None:
        super().__init__()
        self.ssl_model = ssl_model
        self.ssl_out_dim = ssl_out_dim
        self.WavLM = wavlm

    def forward(self, batch):
        waveform = batch['wav'].squeeze(1)  # [batches, audio_len]
        if self.WavLM:
            features = self.ssl_model.extract_features(waveform)[0]
        else:
            result = self.ssl_model(waveform, mask=False, features_only=True)
            features = result["x"]
        return {"ssl-feature": features}

    def get_output_dim(self):
        return self.ssl_out_dim
57 |
58 |
class PhonemeEncoder(nn.Module):
    '''
    PhonemeEncoder consists of an embedding layer, an LSTM layer, and a linear layer.
    Args:
        vocab_size: the size of the vocabulary
        hidden_dim: the size of the hidden state of the LSTM
        emb_dim: the size of the embedding layer
        out_dim: the size of the output of the linear layer
        n_lstm_layers: the number of LSTM layers
        with_reference: if True, a reference phoneme sequence is also encoded
            and concatenated with the input features before the linear layer
    '''
    def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim, n_lstm_layers, with_reference=True) -> None:
        super().__init__()
        self.with_reference = with_reference
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim,
                               num_layers=n_lstm_layers, dropout=0.1, bidirectional=True)
        self.linear = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim * self.with_reference, out_dim),
            nn.ReLU()
        )
        self.out_dim = out_dim

    def _encode(self, seq, lens):
        # Embed, pack, LSTM-encode; combine the last hidden states of the
        # final and first layers/directions as a fixed-size feature.
        emb = self.embedding(seq)
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            emb, lens, batch_first=True, enforce_sorted=False)
        _, (ht, _) = self.encoder(packed)
        return ht[-1] + ht[0]

    def forward(self, batch):
        seq = batch['phonemes']
        lens = batch['phoneme_lens']
        reference_seq = batch['reference']
        reference_lens = batch['reference_lens']
        feature = self._encode(seq, lens)
        if self.with_reference:
            if reference_seq is None or reference_lens is None:
                raise ValueError("reference_batch and reference_lens should not be None when with_reference is True")
            # BUG FIX: encode the reference sequence itself.  The original
            # re-encoded the input sequence here, silently ignoring the
            # reference phonemes.
            reference_feature = self._encode(reference_seq, reference_lens)
            feature = self.linear(torch.cat([feature, reference_feature], 1))
        else:
            feature = self.linear(feature)
        return {"phoneme-feature": feature}

    def get_output_dim(self):
        return self.out_dim
105 |
class DomainEmbedding(nn.Module):
    """Maps integer domain ids to learned embedding vectors."""

    def __init__(self, n_domains, domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains, domain_dim)
        self.output_dim = domain_dim

    def forward(self, batch):
        domain_ids = batch['domains']
        return {"domain-feature": self.embedding(domain_ids)}

    def get_output_dim(self):
        return self.output_dim
115 |
116 |
class LDConditioner(nn.Module):
    '''
    Conditions ssl output by listener embedding
    '''
    def __init__(self, input_dim, judge_dim, num_judges=None):
        super().__init__()
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        assert num_judges != None
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
        # concat [self.output_layer, phoneme features]

        self.decoder_rnn = nn.LSTM(
            input_size=self.input_dim + self.judge_dim,
            hidden_size=512,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )  # linear?
        self.out_dim = self.decoder_rnn.hidden_size * 2

    def get_output_dim(self):
        return self.out_dim

    def forward(self, x, batch):
        """Concatenate frame-level features with utterance-level conditioning
        vectors (phoneme / domain / judge) and run them through the BLSTM."""
        features = x['ssl-feature']

        def _append(frames, utt_vec):
            # Broadcast an utterance-level vector over the time axis and concat.
            expanded = utt_vec.unsqueeze(1).expand(-1, frames.size(1), -1)
            return torch.cat((frames, expanded), dim=2)

        if 'phoneme-feature' in x.keys():
            features = _append(features, x['phoneme-feature'])
        if 'domain-feature' in x.keys():
            features = _append(features, x['domain-feature'])
        judge_ids = batch['judge_id']
        if judge_ids is not None:
            features = _append(features, self.judge_embedding(judge_ids))
        decoder_output, _ = self.decoder_rnn(features)
        return decoder_output
171 |
class Projection(nn.Module):
    """Projects features to a scalar score, optionally clipped to the MOS range [1, 5]."""

    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super(Projection, self).__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        if range_clipping:
            # tanh squashes to [-1, 1]; forward rescales onto [1, 5]
            self.proj = nn.Tanh()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )
        self.output_dim = output_dim

    def forward(self, x, batch):
        score = self.net(x)
        if not self.range_clipping:
            return score
        # range clipping: map tanh output onto the MOS scale
        return self.proj(score) * 2.0 + 3

    def get_output_dim(self):
        return self.output_dim
198 |
--------------------------------------------------------------------------------
/strong/param_tuning.py:
--------------------------------------------------------------------------------
1 | from lib2to3.pytree import Base
2 | from pytorch_lightning.loggers.csv_logs import CSVLogger
3 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
4 | from pytorch_lightning.callbacks import ModelCheckpoint
5 | from pytorch_lightning import Trainer
6 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
7 | from pytorch_lightning.callbacks import LearningRateMonitor
8 | import hydra
9 | from dataset import DataModule
10 | import wandb
11 |
@hydra.main(config_path="configs", config_name='optuna-main')
def train(cfg):
    """Train one UTMOS configuration for hyper-parameter search.

    Returns the logged metric named by ``cfg.tuning_target`` (e.g. system-level
    SRCC) so an Optuna sweeper can maximize it.
    """
    debug = cfg.debug

    # Map the joint "backbone + batch size" search dimension onto the config.
    _PRESETS = {
        "wav2vec2-base-4": ("fairseq/wav2vec_small.pt", 4),
        "wav2vec2-base-8": ("fairseq/wav2vec_small.pt", 8),
        "wav2vec2-base-16": ("fairseq/wav2vec_small.pt", 16),
        "wav2vec2-base-32": ("fairseq/wav2vec_small.pt", 32),
        "wavlm-large-4": ("fairseq/WavLM-Large.pt", 4),
    }
    if cfg.batch_size_and_model in _PRESETS:
        cp_path, batch_size = _PRESETS[cfg.batch_size_and_model]
        cfg.model.feature_extractors[0]["cp_path"] = cp_path
        cfg.train.train_batch_size = batch_size
    print(cfg.batch_size_and_model)
    print(cfg.model.feature_extractors[0]["cp_path"])
    print(cfg.train.train_batch_size)

    # BUG FIX: pop in descending index order. The original popped 0, then 1,
    # then 2 from the shrinking list, so with several flags set it removed the
    # wrong data sources.
    # NOTE(review): removing a source when its use_data flag is True looks
    # inverted -- confirm the intended flag semantics against the config.
    for index, flagged in ((2, cfg.dataset.use_data.external),
                           (1, cfg.dataset.use_data.ood),
                           (0, cfg.dataset.use_data.main)):
        if flagged:
            cfg.dataset.data_sources.pop(index)

    loggers = []
    loggers.append(CSVLogger(save_dir=cfg.train.out_dir, name="train_log"))
    loggers.append(TensorBoardLogger(save_dir=cfg.train.out_dir, name="tf_log"))
    wandb_logger = None
    if cfg.train.use_wandb:
        wandb_logger = WandbLogger(project="voicemos", offline=debug)
        loggers.append(wandb_logger)

    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.train.out_dir,
        save_weights_only=True,
        save_top_k=1,
        save_last=False,
        every_n_epochs=1,
        monitor=cfg.train.model_selection_metric,
        mode='max'
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [checkpoint_callback, lr_monitor]
    earlystop_callback = EarlyStopping(
        monitor="val_loss", min_delta=0.0, patience=cfg.train.early_stopping.patience, mode="min"
    )
    callbacks.append(earlystop_callback)

    trainer = Trainer(
        **cfg.train.trainer_args,
        default_root_dir=hydra.utils.get_original_cwd(),
        limit_train_batches=0.01 if debug else 1.0,
        limit_val_batches=0.5 if debug else 1.0,
        callbacks=callbacks,
        logger=loggers,
    )

    lightning_module = hydra.utils.instantiate(cfg.model.lightning_module, cfg=cfg, _recursive_=False)
    # BUG FIX: the original called `wandblogger.watch(...)` on an undefined
    # name, raising NameError on every run. Watch only when wandb is enabled.
    if wandb_logger is not None:
        wandb_logger.watch(lightning_module)
    datamodule = hydra.utils.instantiate(cfg.dataset.datamodule, cfg=cfg, _recursive_=False)
    trainer.fit(lightning_module, datamodule=datamodule)
    trainer.test(lightning_module, datamodule=datamodule, ckpt_path=checkpoint_callback.best_model_path)

    SRCC_system = trainer.logged_metrics[cfg.tuning_target]
    wandb.finish()

    return SRCC_system


if __name__ == "__main__":
    train()
87 |
--------------------------------------------------------------------------------
/strong/predict.py:
--------------------------------------------------------------------------------
1 | from omegaconf import open_dict
2 | from pytorch_lightning import Trainer
3 | import hydra
4 | import os
5 | import pathlib
6 | from lightning_module import UTMOSLightningModule
7 | from dataset import TestDataModule, DataModule
8 | import torch
9 |
@hydra.main(config_path="configs", config_name='default')
def predict(cfg):
    """
    Run test-set inference from a trained checkpoint.

    Specify checkpoint path as follows:

    python predict.py +ckpt_path="outputs/${date}/${time}/train_outputs/hoge.ckpt"
    """

    trainer = Trainer(
        **cfg.train.trainer_args,
        default_root_dir=hydra.utils.get_original_cwd(),
    )

    # hydra changes the cwd, so resolve relative checkpoint paths against
    # the original working directory.
    ckpt_path = pathlib.Path(cfg.ckpt_path)
    if not ckpt_path.is_absolute():
        ckpt_path = (pathlib.Path(hydra.utils.get_original_cwd()) / ckpt_path)

    if 'paper_weights' in cfg.keys():
        # Restore the dataset flags the checkpoint was trained with so the
        # datamodule matches the paper weights.
        with open_dict(cfg):
            ckpt = torch.load(ckpt_path)
            use_data = ckpt['hyper_parameters']['cfg']['dataset']['use_data']
            cfg.dataset.use_data['external'] = use_data['lancers']
            cfg.dataset.use_data['main'] = use_data['main']
            cfg.dataset.use_data['ood'] = use_data['ood']
            cfg.dataset.only_mean = ckpt['hyper_parameters']['cfg']['dataset']['only_mean']
        lightning_module = UTMOSLightningModule.load_from_checkpoint(ckpt_path, cfg=cfg, paper_weight=cfg.paper_weights)
        # (removed: a leftover no-op expression statement `lightning_module.cfg`)
    else:
        lightning_module = UTMOSLightningModule.load_from_checkpoint(ckpt_path)
    print(lightning_module.cfg)
    datamodule = DataModule(lightning_module.cfg)
    test_datamodule = TestDataModule(cfg=lightning_module.cfg, i_cv=0, set_name='test')
    # Evaluate on both the CV datamodule's test split and the held-out test set.
    trainer.test(
        lightning_module,
        verbose=True,
        datamodule=datamodule,
        ckpt_path=ckpt_path
    )
    trainer.test(
        lightning_module,
        verbose=True,
        datamodule=test_datamodule,
        ckpt_path=ckpt_path
    )


if __name__ == "__main__":
    predict()
57 |
--------------------------------------------------------------------------------
/strong/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Defines the set of symbols used in text input to the model.
5 | '''
6 | _pad = '_'
7 | _punctuation = ';:,.!?¡¿—…"«»“” '
8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
9 | _numbers = '0123456789'
10 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ'̪'̃"
11 |
12 |
13 | # Export all symbols:
14 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + list(_numbers)
15 |
16 | # Special symbol ids
17 | SPACE_ID = symbols.index(" ")
--------------------------------------------------------------------------------
/strong/train.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from pytorch_lightning.loggers.csv_logs import CSVLogger
3 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
4 | from pytorch_lightning.callbacks import ModelCheckpoint
5 | from pytorch_lightning import Trainer, seed_everything
6 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
7 | from pytorch_lightning.callbacks import LearningRateMonitor
8 | from dataset import CVDataModule, TestDataModule
9 | from lightning_module import UTMOSLightningModule
10 | import hydra
11 | import wandb
12 |
13 |
14 |
@hydra.main(config_path="configs", config_name='default')
def train(cfg):
    """Train UTMOS with the given hydra config, then test on the CV split and
    the held-out test set."""
    debug = cfg.debug
    if debug:
        # Shrink the run so a full pipeline pass finishes quickly.
        cfg.train.train_batch_size = 4
        cfg.train.trainer_args.max_steps = 10

    loggers = [
        CSVLogger(save_dir=cfg.train.out_dir, name="train_log"),
        TensorBoardLogger(save_dir=cfg.train.out_dir, name="tf_log"),
    ]
    if cfg.train.use_wandb:
        loggers.append(WandbLogger(project="voicemos", offline=debug))

    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.train.out_dir,
        save_weights_only=False,
        save_top_k=1,
        save_last=True,
        every_n_epochs=1,
        monitor=cfg.train.model_selection_metric,
        mode='max',
    )
    earlystop_callback = EarlyStopping(
        monitor="val_loss", min_delta=0.0, patience=cfg.train.early_stopping.patience, mode="min"
    )
    callbacks = [
        checkpoint_callback,
        LearningRateMonitor(logging_interval='step'),
        earlystop_callback,
    ]

    trainer = Trainer(
        **cfg.train.trainer_args,
        default_root_dir=hydra.utils.get_original_cwd(),
        limit_train_batches=0.01 if debug else 1.0,
        limit_val_batches=0.5 if debug else 1.0,
        callbacks=callbacks,
        logger=loggers,
    )

    datamodule = hydra.utils.instantiate(cfg.dataset.datamodule, cfg=cfg, _recursive_=False)
    test_datamodule = TestDataModule(cfg=cfg, i_cv=0, set_name='test')
    lightning_module = UTMOSLightningModule(cfg)

    trainer.fit(lightning_module, datamodule=datamodule)

    if debug:
        # No reliable best checkpoint in a truncated debug run; test in-memory.
        trainer.test(lightning_module, datamodule=datamodule)
        trainer.test(lightning_module, datamodule=test_datamodule)
    else:
        trainer.test(lightning_module, datamodule=datamodule, ckpt_path=checkpoint_callback.best_model_path)
        trainer.test(lightning_module, datamodule=test_datamodule, ckpt_path=checkpoint_callback.best_model_path)
    if cfg.train.use_wandb:
        wandb.save(checkpoint_callback.best_model_path)


if __name__ == "__main__":
    train()
70 |
--------------------------------------------------------------------------------
/strong/transcribe_speech.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
3 | from datasets import load_dataset
4 | import soundfile as sf
5 | import torch
6 | from datasets import set_caching_enabled
7 | import pandas as pd
8 | import Levenshtein
9 | import numpy as np
10 | from sklearn.cluster import DBSCAN
11 |
def cluster_transcriptions(df: pd.DataFrame):
    """Cluster similar ASR transcriptions and attach a per-cluster reference.

    DBSCAN groups rows by normalized Levenshtein distance between their
    transcriptions; each row's 'reference' column becomes the generalized
    median string of its cluster, while noise rows (cluster == -1) keep
    their own transcription.
    """
    texts = df['transcription'].to_list()

    def _normalized_levenshtein(a, b):
        # DBSCAN only sees row indices; look the actual strings up here.
        i, j = int(a[0]), int(b[0])
        return Levenshtein.distance(texts[i], texts[j]) / max(len(texts[i]), len(texts[j]))

    indices = np.arange(len(texts)).reshape(-1, 1)
    clustering = DBSCAN(eps=0.3, metric=_normalized_levenshtein, n_jobs=20, min_samples=3).fit(indices)
    df['cluster'] = clustering.labels_
    text_medians = df.groupby('cluster').apply(lambda g: Levenshtein.median(g['transcription'].to_list()))
    references = []
    for _, row in df.iterrows():
        label = row['cluster']
        references.append(row['transcription'] if label == -1 else text_medians[label])
    df['reference'] = references
    return df
30 |
31 |
32 |
if __name__ == '__main__':
    set_caching_enabled(False)
    # load datasets

    from tqdm import tqdm

    # Phoneme recognizer used to transcribe all tracks.
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
    _ = model.to('cuda')
    batch_size = 1
    all_df = pd.DataFrame()

    def collate_fn(batch):
        # batch_size is fixed to 1, so only the first item matters.
        wav_name = batch[0]['audio']['path']
        wav_array = batch[0]['audio']['array']
        return wav_name, processor(wav_array, sampling_rate=16_000, return_tensors="pt").input_values

    for track in 'main', 'ood', 'unlabeled':
        wav_names = []
        transcriptions = []
        if track == 'main':
            dataset = load_dataset("sarulab-speech/bvcc-voicemos2022", "main_track", data_dir="data/phase1-main/", use_auth_token=True, download_mode="force_redownload")
        elif track == 'ood':
            dataset = load_dataset("sarulab-speech/bvcc-voicemos2022", "ood_track", data_dir="data/phase1-ood/", use_auth_token=True, download_mode="force_redownload")
        else:
            dataset = load_dataset("sarulab-speech/bvcc-voicemos2022", "ood_track_unlabeled", data_dir="data/phase1-ood/", use_auth_token=True, download_mode='force_redownload')
        for stage in 'train', 'validation', 'test':
            if track == 'unlabeled' and stage != 'train':
                # The unlabeled set only ships a train split.
                continue
            print('Processing {track} track {stage}'.format(track=track, stage=stage))
            dl = torch.utils.data.DataLoader(dataset[stage], batch_size=batch_size, num_workers=4, collate_fn=collate_fn)
            for wav_name, data in tqdm(dl):

                # retrieve logits
                with torch.no_grad():
                    logits = model(data.to('cuda')).logits

                # take argmax and decode
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = processor.batch_decode(predicted_ids)
                transcriptions.extend(transcription)
                wav_names.append(wav_name)
        del dataset
        df = pd.DataFrame({"wav_name": wav_names, "transcription": transcriptions})
        df['wav_name'] = df['wav_name'].apply(lambda x: x.split("/")[-1])
        # Unlabeled audio belongs to the OOD track for clustering purposes.
        df['track'] = track if track != 'unlabeled' else 'ood'
        all_df = pd.concat([all_df, df], ignore_index=True)
    result = pd.concat(
        [
            cluster_transcriptions(all_df[all_df['track'] == 'main'].copy()),
            cluster_transcriptions(all_df[all_df['track'] == 'ood'].copy()),
        ],
        ignore_index=True
    )
    # BUG FIX: the original called .format(track) on a literal without
    # placeholders (a no-op leftover from a per-track filename); also removed
    # redundant in-loop re-imports of numpy/pandas already imported above.
    result.to_csv('transcriptions_clustered.csv', index=False)
87 |
--------------------------------------------------------------------------------