├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ └── config.yml ├── pull_request_template.md └── workflows │ ├── lint_and_format.yml │ ├── mypy_check.yml │ └── tests.yml ├── .gitignore ├── .vscode └── settings.json ├── CITATION.cff ├── LICENSE ├── README.md ├── docs ├── datasets.md ├── image │ ├── line2.svg │ ├── line3.svg │ ├── line4.svg │ ├── titleLine.svg │ ├── titleLine2b.svg │ ├── titleLine2t.svg │ ├── titleLine3b.svg │ ├── titleLine3t.svg │ ├── titleLine4b.svg │ ├── titleLine4t.svg │ ├── titleLine5b.svg │ ├── titleLine5t.svg │ └── utmosv2.PNG ├── inference.md ├── reproduction.md └── training.md ├── inference.py ├── poster.pdf ├── pyproject.toml ├── quickstart.ipynb ├── tests └── core_tests │ └── test_create.py ├── train.py └── utmosv2 ├── __init__.py ├── _core ├── __init__.py ├── create.py └── model │ ├── __init__.py │ ├── _common.py │ └── _models.py ├── _import.py ├── _settings ├── __init__.py └── _config.py ├── config ├── c_fusion_stage2.py ├── c_fusion_stage3.py ├── c_spec_only_stage1.py ├── c_spec_only_stage2.py ├── c_ssl_only_stage1.py ├── c_ssl_only_stage2.py ├── fusion_stage2.py ├── fusion_stage2_wo_bc.py ├── fusion_stage2_wo_bvcc.py ├── fusion_stage2_wo_sarulab.py ├── fusion_stage2_wo_somos.py ├── fusion_stage3.py ├── fusion_stage3_wo_bc.py ├── fusion_stage3_wo_bvcc.py ├── fusion_stage3_wo_sarulab.py ├── fusion_stage3_wo_somos.py ├── fusion_wo_stage1and2.py ├── fusion_wo_stage2.py ├── spec_only.py ├── spec_only_wo_bc.py ├── spec_only_wo_bvcc.py ├── spec_only_wo_sarulab.py ├── spec_only_wo_somos.py ├── ssl_only_stage1.py ├── ssl_only_stage1_wo_bc.py ├── ssl_only_stage1_wo_bvcc.py ├── ssl_only_stage1_wo_sarulab.py ├── ssl_only_stage1_wo_somos.py ├── ssl_only_stage2.py ├── ssl_only_stage2_wo_bc.py ├── ssl_only_stage2_wo_bvcc.py ├── ssl_only_stage2_wo_sarulab.py └── ssl_only_stage2_wo_somos.py ├── dataset ├── __init__.py ├── _base.py ├── _schema.py ├── _utils.py ├── multi_spec.py ├── ssl.py └── ssl_multispec.py ├── loss ├── __init__.py └── _losses.py 
├── model ├── __init__.py ├── multi_spec.py ├── ssl.py └── ssl_multispec.py ├── preprocess ├── __init__.py └── _preprocess.py ├── runner ├── __init__.py ├── _inference.py └── _train.py ├── transform ├── __init__.py └── _xymasking.py └── utils ├── __init__.py ├── _constants.py ├── _download.py ├── _pure ├── __init__.py ├── initializers.py ├── metrics.py ├── save.py └── split.py └── _task_dependents ├── __init__.py ├── initializers.py ├── log.py ├── metrics.py └── save.py /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: 🐞 Bug Report 2 | description: Create a report to help us improve UTMOSv2. 3 | labels: ["bug"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: Thanks for taking the time to report a bug! Please fill out the sections below to help us reproduce and fix the issue. 8 | 9 | - type: checkboxes 10 | id: checks 11 | attributes: 12 | label: ✅ Checks 13 | options: 14 | - label: I have searched [existing/past issues](https://github.com/sarulab-speech/UTMOSv2/issues) and found no similar issue. 15 | required: true 16 | 17 | - type: textarea 18 | id: description 19 | attributes: 20 | label: ✏️ Description 21 | description: A clear and concise description of what the bug is. 22 | placeholder: Include screenshots, logs, and any other relevant information to receive the most helpful response. 23 | validations: 24 | required: true 25 | 26 | - type: textarea 27 | attributes: 28 | label: 💻 Environment 29 | description: Please specify the software and hardware you used to produce the bug. 30 | placeholder: OS, Python version, UTMOSv2 version, etc. 31 | validations: 32 | required: true 33 | 34 | - type: textarea 35 | id: expected-behavior 36 | attributes: 37 | label: 🌟 Expected behavior 38 | description: A clear and concise description of what you expected to happen. 39 | placeholder: Describe the expected behavior. 
40 | validations: 41 | required: true 42 | 43 | - type: textarea 44 | id: steps-to-reproduce 45 | attributes: 46 | label: 🚶‍♀️ Steps to reproduce 47 | description: Please provide detailed steps to reproduce the bug. 48 | placeholder: | 49 | 1. 50 | 2. 51 | 3. 52 | validations: 53 | required: true 54 | 55 | - type: textarea 56 | id: additional-notes 57 | attributes: 58 | label: 🗒️ Additional notes 59 | description: Add any other context about the problem here. 60 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: 🖐️ Questions / 💬 Discussions 4 | url: https://github.com/sarulab-speech/UTMOSv2/discussions 5 | about: Get help with using UTMOSv2, share ideas, and discuss features. -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## 🎯 Motivation 2 | 3 | 4 | ## 📝 Description of Changes 5 | 6 | 7 | ## 🔖 Additional Notes -------------------------------------------------------------------------------- /.github/workflows/lint_and_format.yml: -------------------------------------------------------------------------------- 1 | name: Lint and Format 2 | 3 | on: push 4 | 5 | jobs: 6 | ckecks: 7 | runs-on: ubuntu-latest 8 | permissions: 9 | contents: write 10 | steps: 11 | - uses: actions/checkout@v4 12 | with: 13 | ref: ${{ github.head_ref }} 14 | fetch-depth: 0 15 | - name: Set up Python 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: "3.12" 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install ruff 23 | - name: Run Ruff (lint) 24 | run: ruff check --output-format=github . 25 | - name: Run Ruff (format) 26 | run: | 27 | ruff format . 
28 | - name: Commit changes 29 | uses: stefanzweifel/git-auto-commit-action@v5 30 | with: 31 | commit_message: Apply automatic code formatting 32 | -------------------------------------------------------------------------------- /.github/workflows/mypy_check.yml: -------------------------------------------------------------------------------- 1 | name: Type Check 2 | 3 | on: push 4 | 5 | jobs: 6 | checks: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | 11 | - uses: actions/setup-python@v5 12 | with: 13 | python-version: 3.11 14 | 15 | - name: Install 16 | run: | 17 | python -m pip install -U pip 18 | pip install torch --index-url https://download.pytorch.org/whl/cpu 19 | pip install --progress-bar off -U .[check] 20 | 21 | - name: Mypy Check 22 | run: mypy . 23 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: push 4 | 5 | jobs: 6 | tests: 7 | runs-on: ubuntu-latest 8 | 9 | strategy: 10 | matrix: 11 | python-version: ['3.9', '3.10', '3.11', '3.12'] 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Setup Python${{ matrix.python-version }} 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | 21 | - name: Set up cache 22 | uses: actions/cache@v4 23 | with: 24 | path: ~/.cache/pip 25 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 26 | restore-keys: | 27 | ${{ runner.os }}-pip- 28 | 29 | - name: Install 30 | run: | 31 | python -m pip install -U pip 32 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 33 | pip install --progress-bar off -U .[test] 34 | 35 | - name: Test 36 | run: pytest 37 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | data/ 2 | data2/ 3 | preds/ 4 | preprocessed_data/ 5 | wandb/ 6 | 7 | .ruff_cache/ 8 | 9 | *.sif 10 | *.log 11 | *.out 12 | 13 | .DS_Store 14 | 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | *.so 20 | 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | .ipynb_checkpoints 41 | 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | .python-version 46 | 47 | .env 48 | 49 | .mypy_cache 50 | 51 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": false, 3 | "editor.formatOnPaste": false, 4 | "editor.formatOnType": false, 5 | "editor.formatOnSaveMode": "file", 6 | "[python]": { 7 | "editor.formatOnSave": true, 8 | "editor.codeActionsOnSave": { 9 | "source.fixAll": "explicit", 10 | "source.organizeImports.ruff": "explicit" 11 | }, 12 | "editor.defaultFormatter": "charliermarsh.ruff" 13 | }, 14 | "[json]": { 15 | "editor.defaultFormatter": "esbenp.prettier-vscode", 16 | "editor.formatOnSave": true, 17 | "editor.tabSize": 2 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you find UTMOSv2 useful in your research, please cite the following paper." 
3 | authors: 4 | - name: Kaito Baba 5 | title: "UTMOSv2: UTokyo-SaruLab MOS Prediction System" 6 | url: "https://github.com/sarulab-speech/UTMOSv2" 7 | preferred-citation: 8 | type: conference-paper 9 | authors: 10 | - family-names: "Baba" 11 | given-names: "Kaito" 12 | - family-names: "Nakata" 13 | given-names: "Wataru" 14 | - family-names: "Saito" 15 | given-names: "Yuki" 16 | - family-names: "Saruwatari" 17 | given-names: "Hiroshi" 18 | collection-title: "IEEE Spoken Language Technology Workshop (SLT)" 19 | title: "The t05 system for the VoiceMOS Challenge 2024: Transfer learning from deep image classifier to naturalness MOS prediction of high-quality synthetic speech" 20 | year: 2024 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 sarulab-speech 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/datasets.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | 5 |
📂 Used Datasets 📂
6 | 7 | 8 | 9 |

10 | 11 | The datasets used in this project and their respective licenses are as follows: 12 | 13 | - **BVCC**: 14 | - **License**: 15 | - Data Downloading and Processing Scripts, Listening Test Results: [BSD 3-Clause License](https://opensource.org/license/bsd-3-clause) 16 | - VCC2016, VCC2018: [Creative Commons License: Attribution 4.0 International](http://creativecommons.org/licenses/by/4.0/legalcode) 17 | - VCC2020: [Open Database License](http://opendatacommons.org/licenses/odbl/1.0/) 18 | - **Link to dataset**: [https://www.codabench.org/competitions/2650/](https://www.codabench.org/competitions/2650/) 19 | 20 | - **Sarulab Data**: 21 | - **License**: [MIT License](https://opensource.org/licenses/MIT) 22 | - **Link to dataset**: [https://github.com/sarulab-speech/VMC2024-sarulab-data](https://github.com/sarulab-speech/VMC2024-sarulab-data) 23 | 24 | - **Blizzard Challenges**: 25 | - **Link to dataset**: [https://www.cstr.ed.ac.uk/projects/blizzard/data.html](https://www.cstr.ed.ac.uk/projects/blizzard/data.html) 26 | 27 | - **SOMOS**: 28 | - **License**: [Creative Commons License: Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/) 29 | - **Link to dataset**: [https://innoetics.github.io/publications/somos-dataset/index.html](https://innoetics.github.io/publications/somos-dataset/index.html) 30 | 31 | ## 🙏 Acknowledgments 32 | 33 | We acknowledge the creators of the datasets used. We would like to thank the organisers of the Blizzard Challenge for the provision of resources. 
34 | -------------------------------------------------------------------------------- /docs/image/line2.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /docs/image/line3.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /docs/image/line4.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /docs/image/titleLine.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /docs/image/titleLine2b.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/image/titleLine2t.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/image/titleLine3b.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/image/titleLine3t.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/image/titleLine4b.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | 24 | 25 | 26 | 
-------------------------------------------------------------------------------- /docs/image/titleLine4t.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/image/titleLine5b.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/image/titleLine5t.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/image/utmosv2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sarulab-speech/UTMOSv2/00b80845f85fad4e3c23a743851304e0e23c5a02/docs/image/utmosv2.PNG -------------------------------------------------------------------------------- /docs/inference.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | 5 |
📘 Guide to Inference 📘
6 | 7 | 8 | 9 |

10 | 11 | Please refer to [this section](https://github.com/sarulab-speech/UTMOSv2?tab=readme-ov-file#---quick-prediction--------) for basic inference methods. 12 | 13 | If you want to use the `inference.py` script rather than the `utmosv2` library, please install some additional dependencies: 14 | 15 | ```bash 16 | pip install --upgrade pip # enable PEP 660 support 17 | pip install -e .[optional] 18 | ``` 19 | 20 | > [!NOTE] 21 | > If you are using zsh, make sure to escape the square brackets like this: 22 | > 23 | > ```zsh 24 | > pip install -e '.[optional]' 25 | > ``` 26 | 27 |

28 |
📌 Data-domain ID for the MOS Prediction 📌
29 | 30 | 31 | 32 |

33 | 34 | By default, the data-domain ID for MOS prediction is set to `sarulab` (sarulab-data). To use a different data-domain ID, specify the `predict_dataset` argument (Python) or the `--predict_dataset` flag (inference script) with one of the following options: 35 | 36 | - `sarulab` (default) 37 | - `bvcc` 38 | - `blizzard2008`, `blizzard2009`, `blizzard2010-EH1`, `blizzard2010-EH2`, `blizzard2010-ES1`, `blizzard2010-ES3`, `blizzard2011` 39 | - `somos` 40 | 41 | For example, to make predictions with the data-domain ID set to `somos`, do the following: 42 | 43 | - If you are using it in your Python code: 44 | 45 | ```python 46 | mos = model.predict(input_dir="/path/to/wav/dir/", predict_dataset="somos") 47 | ``` 48 | 49 | - If you are using the inference script: 50 | 51 | ```bash 52 | python inference.py --input_dir /path/to/wav/dir/ --out_path /path/to/output/file.csv --predict_dataset somos 53 | ``` 54 | 55 |

56 |
✂️ Predicting Only a Subset of Files ✂️
57 | 58 | 59 | 60 |

61 | 62 | By default, all `.wav` files in the input directory (`input_dir` / `--input_dir`) are used for prediction. To predict only a subset of these files, use the `val_list_path` argument (Python) or the `--val_list_path` flag (inference script): 63 | 64 | - If you are using it in your Python code: 65 | 66 | ```python 67 | mos = model.predict(input_dir="/path/to/wav/dir/", val_list_path="/path/to/your/val/list.txt") 68 | ``` 69 | 70 | or, you can provide the list directly: 71 | 72 | ```python 73 | mos = model.predict( 74 | input_dir="/path/to/wav/dir/", 75 | val_list=["sys00691-utt0682e32", "sys00691-utt31fd854", "sys00691-utt33a4826", ...] 76 | ) 77 | ``` 78 | 79 | - If you are using the inference script: 80 | 81 | ```bash 82 | python inference.py --input_dir /path/to/wav/dir/ --out_path /path/to/output/file.csv --val_list_path /path/to/your/val/list.txt 83 | ``` 84 | 85 | The list file should contain one utterance ID per line, as shown below. The file extension `.wav` is optional and can be included or omitted. 86 | 87 | ```text 88 | sys00691-utt0682e32 89 | sys00691-utt31fd854 90 | sys00691-utt33a4826 91 | ... 92 | ``` 93 | 94 |

95 |
📈 Specify the Fold and the Number of Repetitions for More Accurate Predictions 📈
96 | 97 | 98 | 99 |

100 | 101 | In the paper, prediction is repeated five times for each of the five folds, each time on a randomly selected frame of the input speech waveform, and all results are averaged. To do this for more accurate predictions, specify the fold and the number of repetitions as follows: 102 | 103 | - If you are using it in your Python code: 104 | 105 | ```python 106 | model = utmosv2.create_model(fold=2) 107 | mos = model.predict(input_dir="/path/to/wav/dir/", num_repetitions=5) 108 | ``` 109 | 110 | - If you are using the inference script: 111 | 112 | ```bash 113 | python inference.py --input_dir /path/to/wav/dir/ --out_path /path/to/output/file.csv --fold 2 --num_repetitions 5 114 | ``` 115 | 116 | Here, the `--fold` option specifies the fold number to be used. If set to `-1`, all folds will be used. The `--num_repetitions` option specifies the number of repetitions. To reproduce the paper's setting with the inference script, use `--fold -1 --num_repetitions 5`. 117 | 118 |

119 |
🎯 Specify a Configuration File 🎯
120 | 121 | 122 | 123 |

124 | 125 | To specify a configuration file for predictions, do the following: 126 | 127 | - If you are using it in your Python code: 128 | 129 | ```python 130 | model = utmosv2.create_model(config="configuration_file_name") 131 | mos = model.predict(input_dir="/path/to/wav/dir/") 132 | ``` 133 | 134 | - If you are using the inference script: 135 | 136 | ```bash 137 | python inference.py --config configuration_file_name --input_dir /path/to/wav/dir/ --out_path /path/to/output/file.csv 138 | ``` 139 | 140 | By default, `fusion_stage3`, which is the full UTMOSv2 model, is used. 141 | 142 |

143 |
⚖️ Make Predictions Using Your Own Weights ⚖️
144 | 145 | 146 | 147 |

148 | 149 | If you are using it in your Python code, specify the checkpoint path with the `checkpoint_path` argument to make predictions using your own weights: 150 | 151 | ```python 152 | model = utmosv2.create_model(checkpoint_path="/path/to/your/weight.pth") 153 | mos = model.predict(input_dir="/path/to/wav/dir/") 154 | ``` 155 | 156 | If you are using the inference script, specify the path to the weights with the `--weight` option to make predictions using your own weights: 157 | 158 | ```bash 159 | python inference.py --input_dir /path/to/wav/dir/ --out_path /path/to/output/file.csv --weight /path/to/your/weight.pth 160 | ``` 161 | 162 | The `checkpoint_path` argument and `--weight` option can specify either the configuration file name or the path to the weight `.pth` file. By default, `models/{config_name}/fold{now_fold}_s{seed}_best_model.pth` is used. 163 | 164 | The weights must be compatible with the model specified by the `config` argument or the `--config` option. 165 | 166 | > [!NOTE] 167 | > In this case, the specified weights will be used for all folds. To use different weights for each of the five folds (0 to 4), run each fold separately and write each result to its own output file: 168 | > 169 | > ```bash 170 | > for i in {0..4}; do 171 | > python inference.py --input_path /path/to/wav/file.wav --out_path /path/to/output/file_fold${i}.csv --weight /path/to/your/weight_fold${i}.pth --fold $i 172 | > done 173 | > ``` 174 | -------------------------------------------------------------------------------- /docs/training.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | 5 |
📚 Guide to Training 📚
6 | 7 | 8 | 9 |

10 | 11 | To train UTMOSv2 following the methods described in the paper or used in the competition, please refer to [this document](reproduction.md). 12 | 13 |

14 |
📩 Install Training Dependencies 📩
15 | 16 | 17 | 18 |

19 | 20 | To install the dependencies required for training, run the following command: 21 | 22 | ```bash 23 | pip install --upgrade pip # enable PEP 660 support 24 | pip install -e .[train,optional] 25 | ``` 26 | 27 | > [!NOTE] 28 | > If you are using zsh, make sure to escape the square brackets like this: 29 | > 30 | > ```zsh 31 | > pip install -e '.[train,optional]' 32 | > ``` 33 | 34 |

35 |
🚀 Train UTMOSv2 Using Your Own Data 🚀
36 | 37 | 38 | 39 |

40 | 41 | To train UTMOSv2 using your own data, you need to create a JSON file that contains the location and name of your data. Here is an example structure for the JSON file: 42 | 43 | ```json 44 | { 45 | "data": [ 46 | { 47 | "name": "dataset1", 48 | "dir": "/path/to/your/dataset1", 49 | "mos_list": "/path/to/your/moslist1.txt" 50 | }, 51 | { 52 | "name": "dataset2", 53 | "dir": "/path/to/your/dataset2", 54 | "mos_list": "/path/to/your/moslist2.txt" 55 | } 56 | 57 | ] 58 | } 59 | ``` 60 | 61 | Add as many entries to the `data` list as needed (note that JSON itself does not allow comments). Here, `name` is used to identify the data-domain ID, and `dir` specifies the directory where the corresponding `.wav` files are located. Additionally, `mos_list` is the path to a text file that records the MOS value of each `.wav` file in the directory, one `utt-id,MOS` pair per line, in the following format: 62 | 63 | ```text 64 | sys64e2f-utt491a78a,2.375 65 | sys64e2f-utt8485f83,3.625 66 | sys7ab3c-utt1417b69,4.0 67 | ... 68 | ``` 69 | 70 | The file extension `.wav` is optional and can be included or omitted. Only files that exist in `dir` and are also listed in `mos_list` will be used. 71 | 72 | Specify `name`, `dir`, and `mos_list` for each data-domain ID you want to train on. 73 | 74 | Save this JSON file with an appropriate name (for example, `data_config.json`) and run the following command: 75 | 76 | ```bash 77 | python train.py --config spec_only --data_config data_config.json 78 | ``` 79 | 80 |

81 |
🧪 Fine-tuning from Pre-trained Weights 🧪
82 | 83 | 84 | 85 |

86 | 87 | To continue training from existing weights, specify the `--weight` option and train as follows. This is useful when you want to perform additional training using weights learned in a previous stage or when fine-tuning. 88 | 89 | ```bash 90 | python train.py --config spec_only --data_config data_config.json --weight /path/to/your/weights.pth 91 | ``` 92 | 93 | The `--weight` option can specify either the configuration file name or the path to the weight `.pth` file. If the configuration file name is specified, `models/{config_name}/fold{now_fold}_s{seed}_best_model.pth` is used. 94 | 95 |

96 |
🔬 Using Weights & Biases (wandb) for Experiment Tracking 🔬
97 | 98 | 99 | 100 |

101 | 102 | To use Weights & Biases (wandb) for experiment tracking, specify the `--wandb` option. You will also need to set the `WANDB_API_KEY` in your `.env` file or environment variables, or follow the prompt during execution to input your API key directly in the command line. 103 | 104 | ```bash 105 | python train.py --config spec_only --data_config data_config.json --wandb 106 | ``` 107 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from utmosv2._settings import configure_defaults, configure_inference_args 8 | from utmosv2._settings._config import Config 9 | from utmosv2.runner import run_inference 10 | from utmosv2.utils import ( 11 | get_dataloader, 12 | get_dataset, 13 | get_inference_data, 14 | get_model, 15 | make_submission_file, 16 | print_metrics, 17 | save_preds, 18 | save_test_preds, 19 | show_inference_data, 20 | ) 21 | 22 | 23 | def main(cfg: Config) -> None:  # Predict MOS for cfg's inference data, averaging over the selected folds and cfg.inference.num_tta repetitions, then save the predictions (plus metrics and a submission file when cfg.reproduce is set). 24 | data = get_inference_data(cfg) 25 | show_inference_data(data) 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | cfg.print_config = True # type: ignore 29 | 30 | test_preds = np.zeros(data.shape[0])  # running sum over every (fold, repetition) pass; normalized after the loop 31 | if cfg.reproduce: 32 | test_metrics: dict[str, float] = {} 33 | 34 | for fold in range(cfg.num_folds): 35 | if 0 <= cfg.inference.fold < cfg.num_folds and fold != cfg.inference.fold:  # an in-range cfg.inference.fold selects a single fold; any other value (e.g. -1) runs all folds 36 | continue 37 | 38 | cfg.now_fold = fold # type: ignore 39 | 40 | model = get_model(cfg, device) 41 | 42 | cfg.print_config = False # type: ignore 43 | print(f"+*+*[[Fold {fold + 1}/{cfg.num_folds}]]" + "+*" * 30) 44 | 45 | for cycle in range(cfg.inference.num_tta):  # the dataset and dataloader are rebuilt on every repetition 46 | test_dataset = get_dataset(cfg, data, "test") 47 | test_dataloader = get_dataloader(cfg, test_dataset, "test") 48 | test_preds_tta, test_metrics_tta = run_inference( 49 | cfg, model,
test_dataloader, cycle, data, device 50 | ) 51 | test_preds += test_preds_tta 52 | if cfg.reproduce: 53 | assert test_metrics_tta is not None 54 | for k, v in test_metrics_tta.items(): 55 | test_metrics[k] = test_metrics.get(k, 0) + v 56 | 57 | fold_cnt = 1 if 0 <= cfg.inference.fold < cfg.num_folds else cfg.num_folds  # must mirror the fold-selection condition used in the loop above 58 | print(f"Average of {fold_cnt} folds") 59 | test_preds /= fold_cnt * cfg.inference.num_tta 60 | if cfg.reproduce: 61 | test_metrics = { 62 | k: v / fold_cnt / cfg.inference.num_tta for k, v in test_metrics.items() 63 | } 64 | print_metrics(test_metrics) 65 | save_test_preds(cfg, data, test_preds, test_metrics) 66 | make_submission_file(cfg, data, test_preds) 67 | else: 68 | save_preds(cfg, data, test_preds) 69 | 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument( 74 | "-c", "--config", type=str, default="fusion_stage3", help="config file name" 75 | ) 76 | parser.add_argument("-f", "--fold", type=int, default=0, help="fold number")  # per docs/inference.md, -1 means all folds 77 | parser.add_argument( 78 | "-s", "--seed", type=int, default=42, help="random seed for split" 79 | ) 80 | parser.add_argument("-d", "--input_dir", type=str, help="data path") 81 | parser.add_argument("-p", "--input_path", type=str, help="data path") 82 | parser.add_argument("-o", "--out_path", type=str, help="output path") 83 | parser.add_argument( 84 | "-n", 85 | "--num_workers", 86 | type=int, 87 | default=4, 88 | help="number of workers for dataloader", 89 | ) 90 | parser.add_argument( 91 | "-t", 92 | "--val_list_path", 93 | type=str, 94 | help="test data path", 95 | ) 96 | parser.add_argument( 97 | "-w", "--weight", type=str, default=None, help="path to the weight file to load" 98 | ) 99 | parser.add_argument( 100 | "-pd", 101 | "--predict_dataset", 102 | type=str, 103 | default="sarulab", 104 | help="predict dataset", 105 | ) 106 | parser.add_argument( 107 | "-nr", 108 | "--num_repetitions", 109 | type=int, 110 | default=1, 111 | help="number of repetitions for
prediction", 112 | ) 113 | parser.add_argument( 114 | "-e", 115 | "--reproduce", 116 | action="store_true", 117 | help="Run the experiment as described in the paper, including all necessary steps for reproducibility.", 118 | ) 119 | parser.add_argument( 120 | "-fi", 121 | "--final", 122 | action="store_true", 123 | help="final submission", 124 | ) 125 | args = parser.parse_args() 126 | 127 | if args.input_dir is None and args.input_path is None:  # exactly one of --input_dir and --input_path is required 128 | raise ValueError( 129 | "Either input_dir or input_path must be provided when you use your own data." 130 | ) 131 | if args.input_dir is not None and args.input_path is not None: 132 | raise ValueError( 133 | "Only one of input_dir or input_path must be provided when you use your own data." 134 | ) 135 | 136 | cfg = importlib.import_module("utmosv2.config." + args.config)  # the imported config module itself is used as the Config object passed to main() 137 | configure_inference_args(cfg, args) 138 | configure_defaults(cfg) 139 | 140 | main(cfg) 141 | -------------------------------------------------------------------------------- /poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sarulab-speech/UTMOSv2/00b80845f85fad4e3c23a743851304e0e23c5a02/poster.pdf -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "utmosv2" 7 | description = "UTokyo-SaruLab MOS Prediction System" 8 | readme = "README.md" 9 | license = { file = "LICENSE" } 10 | authors = [{ name = "Kaito Baba" }] 11 | classifiers = [ 12 | "License :: OSI Approved :: MIT License", 13 | "Programming Language :: Python :: 3.9", 14 | "Programming Language :: Python :: 3.10", 15 | "Programming Language :: Python :: 3.11", 16 | "Programming Language :: Python :: 3.12", 17 | "Programming Language :: Python :: 3 ::
Only", 18 | ] 19 | dependencies = [ 20 | "numpy>=1.24.4", 21 | "torch>=2.3.1", 22 | "timm>=1.0.7", 23 | "librosa>=0.10.2", 24 | "tqdm>=4.66.4", 25 | "transformers>=4.42.4", 26 | "typing-extensions" 27 | ] 28 | requires-python = ">=3.9" 29 | dynamic = ["version"] 30 | 31 | [project.optional-dependencies] 32 | check = ["ruff", "mypy", "types-setuptools", "types-tqdm"] 33 | train = ["scikit-learn>=1.3.2", "wandb>=0.17.0", "python-dotenv>=1.0.1"] 34 | optional = ["pandas>=2.2.2"] 35 | test = ["pytest"] 36 | 37 | [tool.setuptools.dynamic] 38 | version = { attr = "utmosv2.__version__" } 39 | 40 | [tool.setuptools.packages.find] 41 | include = ["utmosv2*"] 42 | 43 | [tool.mypy] 44 | python_version = "3.11" 45 | ignore_missing_imports = true 46 | disallow_untyped_defs = true 47 | exclude = ["^build/"] 48 | -------------------------------------------------------------------------------- /quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 🚀 Quick Introduction to MOS Prediction using UTMOSv2\n", 8 | "\n", 9 | "In this Jupyter notebook, we will introduce a method for predicting MOS (Mean Opinion Score) using UTMOSv2." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## 🛠 Installation" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "vscode": { 24 | "languageId": "plaintext" 25 | } 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "!GIT_LFS_SKIP_SMUDGE=1 pip install git+https://github.com/sarulab-speech/UTMOSv2.git" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "vscode": { 37 | "languageId": "plaintext" 38 | } 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "import utmosv2" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "## 🔮 Make predictions" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "To predict the MOS of a single wav file:" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "vscode": { 64 | "languageId": "plaintext" 65 | } 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "model = utmosv2.create_model(pretrained=True)\n", 70 | "mos = model.predict(input_path=\"/path/to/wav/file.wav\")\n", 71 | "print(mos)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "To predict the MOS of all .wav files in a folder:" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "vscode": { 86 | "languageId": "plaintext" 87 | } 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "model = utmosv2.create_model(pretrained=True)\n", 92 | "mos = model.predict(input_dir=\"/path/to/wav/dir/\")\n", 93 | "print(mos)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "Note that either `input_path` or `input_dir` must be specified, but not both." 
101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "For more details on how to use the inference script, please refer to [inference guide](https://github.com/sarulab-speech/UTMOSv2/blob/main/docs/inference.md)." 108 | ] 109 | } 110 | ], 111 | "metadata": { 112 | "language_info": { 113 | "name": "python" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 2 118 | } 119 | -------------------------------------------------------------------------------- /tests/core_tests/test_create.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | 5 | import pytest 6 | 7 | from utmosv2._core.create import create_model 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "pretrained", 12 | [ 13 | pytest.param( 14 | True, 15 | marks=pytest.mark.skipif( 16 | sys.version_info[:2] != (3, 11), 17 | reason="To avoid downloading the model weights multiple times", 18 | ), 19 | ), 20 | False, 21 | ], 22 | ) 23 | def test_create_model(pretrained: bool) -> None: 24 | model = create_model(pretrained=pretrained) 25 | assert hasattr(model, "forward") 26 | assert hasattr(model, "predict") 27 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import wandb 8 | from dotenv import load_dotenv 9 | 10 | from utmosv2._settings import configure_args, configure_defaults 11 | from utmosv2._settings._config import Config 12 | from utmosv2.runner import run_train 13 | from utmosv2.utils import ( 14 | get_dataloader, 15 | get_dataset, 16 | get_loss, 17 | get_metrics, 18 | get_model, 19 | get_optimizer, 20 | get_scheduler, 21 | get_train_data, 22 | save_oof_preds, 23 | split_data, 24 | ) 25 | 26 | 27 | def main(cfg: Config) -> None: 28 | data = 
get_train_data(cfg) 29 | print(data.head()) 30 | oof_preds = np.zeros(data.shape[0]) 31 | 32 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 33 | print(f"Device: {device}") 34 | 35 | cfg.print_config = True # type: ignore 36 | 37 | for fold, (train_idx, val_idx) in enumerate(split_data(cfg, data)): 38 | if 0 <= cfg.fold < cfg.num_folds and fold != cfg.fold: 39 | continue 40 | 41 | cfg.now_fold = fold # type: ignore 42 | 43 | train_data = data.iloc[train_idx] 44 | val_data = data.iloc[val_idx] 45 | 46 | train_dataset = get_dataset(cfg, train_data, "train") 47 | val_dataset = get_dataset(cfg, val_data, "valid") 48 | 49 | train_dataloader = get_dataloader(cfg, train_dataset, "train") 50 | val_dataloader = get_dataloader(cfg, val_dataset, "valid") 51 | 52 | model = get_model(cfg, device) 53 | criterions = get_loss(cfg) 54 | metrics = get_metrics() 55 | optimizer = get_optimizer(cfg, model) 56 | scheduler = get_scheduler( 57 | cfg, optimizer, len(train_dataloader) * cfg.run.num_epochs 58 | ) 59 | 60 | cfg.print_config = False # type: ignore 61 | print(f"+*+*[[Fold {fold + 1}/{cfg.num_folds}]]" + "+*" * 30) 62 | if cfg.wandb: 63 | wandb.init( 64 | project="voice-mos-challenge-2024", 65 | name=cfg.config_name, 66 | config={ 67 | "fold": fold, 68 | "seed": cfg.split.seed, 69 | }, 70 | ) 71 | 72 | run_train( 73 | cfg, 74 | model, 75 | train_dataloader, 76 | val_dataloader, 77 | val_data, 78 | oof_preds, 79 | fold, 80 | criterions, 81 | metrics, 82 | optimizer, 83 | scheduler, 84 | device, 85 | ) 86 | if cfg.wandb: 87 | wandb.finish() 88 | 89 | save_oof_preds(cfg, data, oof_preds, cfg.fold) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument( 95 | "-c", "--config", type=str, required=True, help="config file name" 96 | ) 97 | parser.add_argument("-f", "--fold", type=int, default=-1, help="fold number") 98 | parser.add_argument( 99 | "-s", "--seed", type=int, default=42, help="random seed for split" 
100 | ) 101 | parser.add_argument( 102 | "-i", "--input_dir", type=str, default="data/main/DATA", help="data path" 103 | ) 104 | parser.add_argument( 105 | "-dc", "--data_config", type=str, help="path to the data config file" 106 | ) 107 | parser.add_argument( 108 | "-n", 109 | "--num_workers", 110 | type=int, 111 | default=4, 112 | help="number of workers for dataloader", 113 | ) 114 | parser.add_argument( 115 | "-w", "--weight", type=str, help="path to the weight file to load" 116 | ) 117 | parser.add_argument( 118 | "-e", 119 | "--reproduce", 120 | action="store_true", 121 | help="Run the experiment as described in the paper, including all necessary steps for reproducibility.", 122 | ) 123 | parser.add_argument( 124 | "-wb", "--wandb", action="store_true", help="Use wandb for logging" 125 | ) 126 | args = parser.parse_args() 127 | 128 | if args.reproduce is None and args.data_config is None: 129 | raise ValueError("Either --reproduce or --data_config must be specified") 130 | 131 | cfg = importlib.import_module("utmosv2.config." 
+ args.config) 132 | configure_args(cfg, args) 133 | configure_defaults(cfg) 134 | 135 | load_dotenv() 136 | if cfg.wandb: 137 | wandb.login(key=os.getenv("WANDB_API_KEY")) 138 | 139 | main(cfg) 140 | -------------------------------------------------------------------------------- /utmosv2/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2 import config, dataset, loss, model, preprocess, runner, transform, utils 2 | from utmosv2._core import UTMOSv2Model, create_model 3 | 4 | __all__ = [ 5 | "config", 6 | "dataset", 7 | "loss", 8 | "model", 9 | "preprocess", 10 | "runner", 11 | "transform", 12 | "utils", 13 | "create_model", 14 | "UTMOSv2Model", 15 | ] 16 | 17 | __version__ = "1.2.0" 18 | -------------------------------------------------------------------------------- /utmosv2/_core/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2._core.create import create_model 2 | from utmosv2._core.model import UTMOSv2Model 3 | from utmosv2._core.model._common import UTMOSv2ModelMixin 4 | 5 | __all__ = ["UTMOSv2Model", "UTMOSv2ModelMixin", "create_model"] 6 | -------------------------------------------------------------------------------- /utmosv2/_core/create.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import importlib 4 | from pathlib import Path 5 | from types import SimpleNamespace 6 | from typing import Literal 7 | 8 | import torch 9 | 10 | from utmosv2._core.model import UTMOSv2Model 11 | from utmosv2._settings import configure_execution 12 | from utmosv2.utils._constants import _UTMOSV2_CHACHE 13 | from utmosv2.utils._download import download_pretrained_weights_from_hf 14 | 15 | 16 | def create_model( 17 | pretrained: bool = True, 18 | config: str = "fusion_stage3", 19 | fold: int = 0, 20 | checkpoint_path: Path | str | None = None, 21 | seed: int = 42, 22 | 
22 | device: torch.device | str | Literal["auto"] = "auto", 23 | ) -> UTMOSv2Model: 24 | """ 25 | Create a UTMOSv2 model with the specified configuration and optional pretrained weights. 26 | 27 | Args: 28 | pretrained (bool): 29 | If True, loads pretrained weights. Defaults to True. 30 | config (str): 31 | The configuration name to load for the model. Defaults to "fusion_stage3". 32 | fold (int): 33 | The fold number for the pretrained weights (used for model selection). Defaults to 0. 34 | checkpoint_path (Path | str | None): 35 | Path to a specific model checkpoint. If None, the cached checkpoint `models/<config>/fold<fold>_s<seed>_best_model.pth` under the UTMOSv2 cache directory is used, and it is downloaded from the Hugging Face Hub if missing. Defaults to None. 36 | seed (int): 37 | The seed used for model training to select the correct checkpoint. Defaults to 42. 38 | device (torch.device | str | Literal["auto"]): Device passed as `map_location` when loading the checkpoint (only used when `pretrained` is True); "auto" selects CUDA if available, otherwise CPU. Defaults to "auto". 39 | Returns: 40 | UTMOSv2Model: The initialized UTMOSv2 model. 41 | 42 | Raises: 43 | FileNotFoundError: If the specified checkpoint file is not found. 44 | 45 | Notes: 46 | - The configuration is dynamically loaded from `utmosv2.config`. 47 | - If `pretrained` is True and `checkpoint_path` is not provided, the function downloads the pretrained weights from the Hugging Face Hub unless they are already present in the local cache. 48 | """ 49 | _cfg = importlib.import_module(f"utmosv2.config.{config}") 50 | # Avoid issues with pickling `types.ModuleType`, 51 | # making it easier to use with multiprocessing, DDP, etc.
52 | cfg = SimpleNamespace( 53 | **{k: v for k, v in _cfg.__dict__.items() if not k.startswith("__")} 54 | ) 55 | configure_execution(cfg) 56 | 57 | model = UTMOSv2Model(cfg) 58 | 59 | if pretrained: 60 | if checkpoint_path is None: 61 | checkpoint_path = ( 62 | _UTMOSV2_CHACHE 63 | / "models" 64 | / config 65 | / f"fold{fold}_s{seed}_best_model.pth" 66 | ) 67 | if not checkpoint_path.exists(): 68 | download_pretrained_weights_from_hf(config, fold) 69 | if isinstance(checkpoint_path, str): 70 | checkpoint_path = Path(checkpoint_path) 71 | if not checkpoint_path.exists(): 72 | raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") 73 | 74 | device = torch.device( 75 | ("cuda" if torch.cuda.is_available() else "cpu") 76 | if device == "auto" 77 | else device 78 | ) 79 | model.load_state_dict(torch.load(checkpoint_path, map_location=device)) 80 | print(f"Loaded checkpoint from {checkpoint_path}") 81 | 82 | return model 83 | -------------------------------------------------------------------------------- /utmosv2/_core/model/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2._core.model._common import UTMOSv2ModelMixin 2 | from utmosv2._core.model._models import UTMOSv2Model 3 | 4 | __all__ = ["UTMOSv2Model", "UTMOSv2ModelMixin"] 5 | -------------------------------------------------------------------------------- /utmosv2/_core/model/_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any 4 | 5 | from utmosv2._core.model._common import UTMOSv2ModelMixin 6 | from utmosv2._settings._config import Config 7 | from utmosv2.model import ( 8 | MultiSpecExtModel, 9 | MultiSpecModelV2, 10 | SSLExtModel, 11 | SSLMultiSpecExtModelV1, 12 | SSLMultiSpecExtModelV2, 13 | ) 14 | 15 | if TYPE_CHECKING: 16 | import torch 17 | import torch.nn as nn 18 | 19 | 20 | class 
UTMOSv2Model(UTMOSv2ModelMixin): 21 | """ 22 | UTMOSv2Model class that wraps different models specified by the configuration. 23 | This class allows for flexible model selection and provides a unified interface for evaluation, calling, and prediction. 24 | """ 25 | 26 | def __init__(self, cfg: Config): 27 | """ 28 | Initialize the UTMOSv2Model with a specified configuration. 29 | 30 | Args: 31 | cfg (SimpleNamespace | ModuleType): Configuration object that contains the model configuration. 32 | 33 | Raises: 34 | ValueError: If the model name specified in the configuration is not recognized. 35 | """ 36 | models = { 37 | "multi_spec_ext": MultiSpecExtModel, 38 | "multi_specv2": MultiSpecModelV2, 39 | "sslext": SSLExtModel, 40 | "ssl_multispec_ext": SSLMultiSpecExtModelV1, 41 | "ssl_multispec_ext_v2": SSLMultiSpecExtModelV2, 42 | } 43 | if cfg.model.name not in models: 44 | raise ValueError(f"Unknown model name: {cfg.model.name}") 45 | self._model = models[cfg.model.name](cfg) 46 | self._cfg_value = cfg 47 | 48 | @property 49 | def _cfg(self) -> Config: 50 | return self._cfg_value 51 | 52 | def eval(self) -> "nn.Module": 53 | return self._model.eval() 54 | 55 | def __call__(self, *args: Any, **kwargs: Any) -> "torch.Tensor": 56 | return self._model(*args, **kwargs) 57 | 58 | def __getattr__(self, name: str) -> Any: 59 | return getattr(self._model, name) 60 | 61 | def __setattr__(self, name: str, value: Any) -> None: 62 | if name == "_model": 63 | super().__setattr__(name, value) 64 | else: 65 | setattr(self._model, name, value) 66 | 67 | def __delattr__(self, name: str) -> None: 68 | delattr(self._model, name) 69 | 70 | def __repr__(self) -> str: 71 | return f"UTMOSv2Model({'('.join(self._model.__repr__().split('(')[1:])}" 72 | 73 | def __str__(self) -> str: 74 | return self.__repr__() 75 | 76 | def __dir__(self) -> list[str]: 77 | return super().__dir__() + self._model.__dir__() 78 | -------------------------------------------------------------------------------- 
/utmosv2/_import.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import importlib 4 | import types 5 | from typing import Any 6 | 7 | 8 | class _LazyImport(types.ModuleType): 9 | def __init__(self, name: str): 10 | super().__init__(name) 11 | self._name = name 12 | self._module: types.ModuleType | None = None 13 | 14 | def __getattr__(self, name: str) -> Any: 15 | if self._module is None: 16 | self._module = importlib.import_module(self._name) 17 | self.__dict__.update(self._module.__dict__) 18 | return getattr(self._module, name) 19 | -------------------------------------------------------------------------------- /utmosv2/_settings/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2._settings._config import ( 2 | configure_args, 3 | configure_defaults, 4 | configure_execution, 5 | configure_inference_args, 6 | ) 7 | 8 | __all__ = [ 9 | "configure_args", 10 | "configure_defaults", 11 | "configure_inference_args", 12 | "configure_execution", 13 | ] 14 | -------------------------------------------------------------------------------- /utmosv2/_settings/_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | import sys 5 | from pathlib import Path 6 | from types import ModuleType, SimpleNamespace 7 | 8 | if sys.version_info >= (3, 10): 9 | from typing import TypeAlias 10 | 11 | # NOTE: Python 3.12 introduces the type statement, so once Python 3.11 is dropped, 12 | # it should be updated to use that instead. 
13 | Config: TypeAlias = SimpleNamespace | ModuleType 14 | else: 15 | from typing import Union 16 | 17 | from typing_extensions import TypeAlias 18 | 19 | Config: TypeAlias = Union[SimpleNamespace, ModuleType] 20 | 21 | 22 | def configure_args(cfg: Config, args: argparse.Namespace) -> None: 23 | cfg.fold = args.fold # type: ignore 24 | cfg.split.seed = args.seed # type: ignore 25 | cfg.config_name = args.config # type: ignore 26 | cfg.input_dir = args.input_dir and Path(args.input_dir) # type: ignore 27 | cfg.num_workers = args.num_workers # type: ignore 28 | cfg.weight = args.weight # type: ignore 29 | cfg.save_path = Path("models") / cfg.config_name # type: ignore 30 | cfg.wandb = args.wandb # type: ignore 31 | cfg.reproduce = args.reproduce # type: ignore 32 | cfg.data_config = args.data_config # type: ignore 33 | cfg.phase = "train" # type: ignore 34 | 35 | 36 | def configure_inference_args(cfg: Config, args: argparse.Namespace) -> None: 37 | cfg.inference.fold = args.fold # type: ignore 38 | cfg.split.seed = args.seed # type: ignore 39 | cfg.config_name = args.config # type: ignore 40 | cfg.input_dir = args.input_dir and Path(args.input_dir) # type: ignore 41 | cfg.input_path = args.input_path and Path(args.input_path) # type: ignore 42 | cfg.num_workers = args.num_workers # type: ignore 43 | cfg.weight = args.weight # type: ignore 44 | if not cfg.weight: 45 | cfg.weight = cfg.config_name # type: ignore 46 | cfg.inference.val_list_path = args.val_list_path and Path(args.val_list_path) # type: ignore 47 | cfg.save_path = Path("models") / cfg.config_name # type: ignore 48 | cfg.predict_dataset = args.predict_dataset # type: ignore 49 | cfg.final = args.final # type: ignore 50 | cfg.inference.num_tta = args.num_repetitions # type: ignore 51 | cfg.reproduce = args.reproduce # type: ignore 52 | cfg.out_path = args.out_path and Path(args.out_path) # type: ignore 53 | cfg.data_config = None # type: ignore 54 | cfg.phase = "inference" # type: ignore 55 | 56 | 57 | def 
configure_defaults(cfg: Config) -> None: 58 | if cfg.id_name is None: 59 | cfg.id_name = "utt_id" # type: ignore 60 | 61 | 62 | def configure_execution(cfg: Config) -> None: 63 | cfg.data_config = None # type: ignore 64 | cfg.phase = "prediction" # type: ignore 65 | cfg.print_config = False # type: ignore 66 | -------------------------------------------------------------------------------- /utmosv2/config/c_fusion_stage2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 16 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="stratified_group", 21 | target="mos", 22 | group="sys_id", 23 | ) 24 | 25 | external_data: list[str] | str = "all" 26 | use_bvcc = True 27 | 28 | 29 | validation_dataset = "sarulab" 30 | 31 | dataset = SimpleNamespace( 32 | name="ssl_multispec_ext", 33 | specs=[ 34 | SimpleNamespace( 35 | mode="melspec", 36 | n_fft=4096, 37 | hop_length=32, 38 | win_length=4096, 39 | n_mels=512, 40 | shape=(512, 512), 41 | norm=80, 42 | ), 43 | SimpleNamespace( 44 | mode="melspec", 45 | n_fft=4096, 46 | hop_length=32, 47 | win_length=2048, 48 | n_mels=512, 49 | shape=(512, 512), 50 | norm=80, 51 | ), 52 | SimpleNamespace( 53 | mode="melspec", 54 | n_fft=4096, 55 | hop_length=32, 56 | win_length=1024, 57 | n_mels=512, 58 | shape=(512, 512), 59 | norm=80, 60 | ), 61 | SimpleNamespace( 62 | mode="melspec", 63 | n_fft=4096, 64 | hop_length=32, 65 | win_length=512, 66 | n_mels=512, 67 | shape=(512, 512), 68 | norm=80, 69 | ), 70 | ], 71 | spec_frames=SimpleNamespace( 72 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 73 | ), 74 | 
ssl=SimpleNamespace( 75 | duration=3, 76 | ), 77 | ) 78 | transform = dict( 79 | train=transforms.Compose( 80 | [ 81 | transforms.Resize((512, 512)), 82 | XYMasking( 83 | num_masks_x=(0, 2), 84 | num_masks_y=(0, 2), 85 | mask_x_length=(10, 40), 86 | mask_y_length=(10, 30), 87 | fill_value=0, 88 | p=0.5, 89 | ), 90 | # transforms.ToTensor(), 91 | ] 92 | ), 93 | valid=transforms.Compose( 94 | [ 95 | transforms.Resize((512, 512)), 96 | # transforms.ToTensor() 97 | ] 98 | ), 99 | ) 100 | 101 | loss = [ 102 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 103 | (SimpleNamespace(name="mse"), 0.2), 104 | ] 105 | 106 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 107 | 108 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) 109 | 110 | model = SimpleNamespace( 111 | name="ssl_multispec_ext", 112 | multi_spec=SimpleNamespace( 113 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 114 | pretrained=True, 115 | num_classes=1, 116 | pool_type="catavgmax", 117 | # feature_height=16, 118 | atten=True, 119 | # classifier=None, 120 | ), 121 | ssl=SimpleNamespace( 122 | name="facebook/wav2vec2-base", 123 | attn=1, 124 | freeze=True, 125 | num_classes=1, 126 | ), 127 | ssl_spec=SimpleNamespace( 128 | ssl_weight="c_ssl_only_stage2", 129 | spec_weight="c_spec_only_stage2", 130 | num_classes=1, 131 | freeze=True, 132 | ), 133 | ) 134 | 135 | run = SimpleNamespace( 136 | mixup=True, 137 | mixup_alpha=0.4, 138 | num_epochs=8, 139 | ) 140 | 141 | main_metric = "sys_srcc" 142 | id_name = None 143 | 144 | 145 | inference = SimpleNamespace( 146 | save_path=Path("preds"), 147 | submit_save_path=Path("submissions"), 148 | num_tta=5, 149 | batch_size=8, 150 | extend="tile", 151 | ) 152 | -------------------------------------------------------------------------------- /utmosv2/config/c_fusion_stage3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from 
pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 8 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="stratified_group", 21 | target="mos", 22 | group="sys_id", 23 | ) 24 | 25 | external_data: list[str] | str = "all" 26 | use_bvcc = True 27 | 28 | 29 | validation_dataset = "sarulab" 30 | 31 | dataset = SimpleNamespace( 32 | name="ssl_multispec_ext", 33 | specs=[ 34 | SimpleNamespace( 35 | mode="melspec", 36 | n_fft=4096, 37 | hop_length=32, 38 | win_length=4096, 39 | n_mels=512, 40 | shape=(512, 512), 41 | norm=80, 42 | ), 43 | SimpleNamespace( 44 | mode="melspec", 45 | n_fft=4096, 46 | hop_length=32, 47 | win_length=2048, 48 | n_mels=512, 49 | shape=(512, 512), 50 | norm=80, 51 | ), 52 | SimpleNamespace( 53 | mode="melspec", 54 | n_fft=4096, 55 | hop_length=32, 56 | win_length=1024, 57 | n_mels=512, 58 | shape=(512, 512), 59 | norm=80, 60 | ), 61 | SimpleNamespace( 62 | mode="melspec", 63 | n_fft=4096, 64 | hop_length=32, 65 | win_length=512, 66 | n_mels=512, 67 | shape=(512, 512), 68 | norm=80, 69 | ), 70 | ], 71 | spec_frames=SimpleNamespace( 72 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 73 | ), 74 | ssl=SimpleNamespace( 75 | duration=3, 76 | ), 77 | ) 78 | transform = dict( 79 | train=transforms.Compose( 80 | [ 81 | transforms.Resize((512, 512)), 82 | XYMasking( 83 | num_masks_x=(0, 2), 84 | num_masks_y=(0, 2), 85 | mask_x_length=(10, 40), 86 | mask_y_length=(10, 30), 87 | fill_value=0, 88 | p=0.5, 89 | ), 90 | # transforms.ToTensor(), 91 | ] 92 | ), 93 | valid=transforms.Compose( 94 | [ 95 | transforms.Resize((512, 512)), 96 | # transforms.ToTensor() 97 | ] 98 | ), 99 | ) 100 | 101 | loss = [ 102 | (SimpleNamespace(name="pairwize_diff", margin=0.2, 
norm="l1"), 0.7), 103 | (SimpleNamespace(name="mse"), 0.2), 104 | ] 105 | 106 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) 107 | 108 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) 109 | 110 | model = SimpleNamespace( 111 | name="ssl_multispec_ext", 112 | multi_spec=SimpleNamespace( 113 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 114 | pretrained=True, 115 | num_classes=1, 116 | pool_type="catavgmax", 117 | # feature_height=16, 118 | atten=True, 119 | # classifier=None, 120 | ), 121 | ssl=SimpleNamespace( 122 | name="facebook/wav2vec2-base", 123 | attn=1, 124 | freeze=False, 125 | num_classes=1, 126 | ), 127 | ssl_spec=SimpleNamespace( 128 | ssl_weight="c_ssl_only_stage2", 129 | spec_weight="c_spec_only_stage2", 130 | num_classes=1, 131 | freeze=False, 132 | ), 133 | ) 134 | 135 | run = SimpleNamespace( 136 | mixup=True, 137 | mixup_alpha=0.4, 138 | num_epochs=2, 139 | ) 140 | 141 | main_metric = "sys_srcc" 142 | id_name = None 143 | 144 | 145 | inference = SimpleNamespace( 146 | save_path=Path("preds"), 147 | submit_save_path=Path("submissions"), 148 | num_tta=5, 149 | batch_size=8, 150 | extend="tile", 151 | ) 152 | -------------------------------------------------------------------------------- /utmosv2/config/c_spec_only_stage1.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 10 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="stratified_group", 21 | target="mos", 22 | group="sys_id", 23 | ) 24 | 25 | external_data: list[str] | str = [] 26 | use_bvcc = True 27 | 28 | 29 | validation_dataset = "bvcc" 30 
| 31 | dataset = SimpleNamespace( 32 | name="multi_spec", 33 | specs=[ 34 | SimpleNamespace( 35 | mode="melspec", 36 | n_fft=4096, 37 | hop_length=32, 38 | win_length=4096, 39 | n_mels=512, 40 | shape=(512, 512), 41 | norm=80, 42 | ), 43 | SimpleNamespace( 44 | mode="melspec", 45 | n_fft=4096, 46 | hop_length=32, 47 | win_length=2048, 48 | n_mels=512, 49 | shape=(512, 512), 50 | norm=80, 51 | ), 52 | SimpleNamespace( 53 | mode="melspec", 54 | n_fft=4096, 55 | hop_length=32, 56 | win_length=1024, 57 | n_mels=512, 58 | shape=(512, 512), 59 | norm=80, 60 | ), 61 | SimpleNamespace( 62 | mode="melspec", 63 | n_fft=4096, 64 | hop_length=32, 65 | win_length=512, 66 | n_mels=512, 67 | shape=(512, 512), 68 | norm=80, 69 | ), 70 | ], 71 | spec_frames=SimpleNamespace( 72 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 73 | ), 74 | ) 75 | transform = dict( 76 | train=transforms.Compose( 77 | [ 78 | transforms.Resize((512, 512)), 79 | XYMasking( 80 | num_masks_x=(0, 2), 81 | num_masks_y=(0, 2), 82 | mask_x_length=(10, 40), 83 | mask_y_length=(10, 30), 84 | fill_value=0, 85 | p=0.5, 86 | ), 87 | # transforms.ToTensor(), 88 | ] 89 | ), 90 | valid=transforms.Compose( 91 | [ 92 | transforms.Resize((512, 512)), 93 | # transforms.ToTensor() 94 | ] 95 | ), 96 | ) 97 | 98 | loss = [ 99 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 100 | (SimpleNamespace(name="mse"), 0.2), 101 | ] 102 | 103 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 104 | 105 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 106 | 107 | model = SimpleNamespace( 108 | name="multi_specv2", 109 | multi_spec=SimpleNamespace( 110 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 111 | pretrained=True, 112 | num_classes=1, 113 | pool_type="catavgmax", 114 | # feature_height=16, 115 | atten=True, 116 | # classifier=None, 117 | ), 118 | ) 119 | 120 | run = SimpleNamespace( 121 | mixup=True, 122 | mixup_alpha=0.4, 123 | 
num_epochs=20, 124 | ) 125 | 126 | main_metric = "sys_srcc" 127 | id_name = None 128 | 129 | 130 | inference = SimpleNamespace( 131 | save_path=Path("preds"), 132 | submit_save_path=Path("submissions"), 133 | num_tta=5, 134 | batch_size=8, 135 | extend="tile", 136 | ) 137 | -------------------------------------------------------------------------------- /utmosv2/config/c_spec_only_stage2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 10 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="stratified_group", 21 | target="mos", 22 | group="sys_id", 23 | ) 24 | 25 | external_data: list[str] | str = ["sarulab"] 26 | use_bvcc = False 27 | 28 | 29 | validation_dataset = "sarulab" 30 | 31 | dataset = SimpleNamespace( 32 | name="multi_spec", 33 | specs=[ 34 | SimpleNamespace( 35 | mode="melspec", 36 | n_fft=4096, 37 | hop_length=32, 38 | win_length=4096, 39 | n_mels=512, 40 | shape=(512, 512), 41 | norm=80, 42 | ), 43 | SimpleNamespace( 44 | mode="melspec", 45 | n_fft=4096, 46 | hop_length=32, 47 | win_length=2048, 48 | n_mels=512, 49 | shape=(512, 512), 50 | norm=80, 51 | ), 52 | SimpleNamespace( 53 | mode="melspec", 54 | n_fft=4096, 55 | hop_length=32, 56 | win_length=1024, 57 | n_mels=512, 58 | shape=(512, 512), 59 | norm=80, 60 | ), 61 | SimpleNamespace( 62 | mode="melspec", 63 | n_fft=4096, 64 | hop_length=32, 65 | win_length=512, 66 | n_mels=512, 67 | shape=(512, 512), 68 | norm=80, 69 | ), 70 | ], 71 | spec_frames=SimpleNamespace( 72 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 73 | ), 74 | ) 75 | transform = dict( 76 | 
train=transforms.Compose( 77 | [ 78 | transforms.Resize((512, 512)), 79 | XYMasking( 80 | num_masks_x=(0, 2), 81 | num_masks_y=(0, 2), 82 | mask_x_length=(10, 40), 83 | mask_y_length=(10, 30), 84 | fill_value=0, 85 | p=0.5, 86 | ), 87 | # transforms.ToTensor(), 88 | ] 89 | ), 90 | valid=transforms.Compose( 91 | [ 92 | transforms.Resize((512, 512)), 93 | # transforms.ToTensor() 94 | ] 95 | ), 96 | ) 97 | 98 | loss = [ 99 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 100 | (SimpleNamespace(name="mse"), 0.2), 101 | ] 102 | 103 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) 104 | 105 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) 106 | 107 | model = SimpleNamespace( 108 | name="multi_specv2", 109 | multi_spec=SimpleNamespace( 110 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 111 | pretrained=True, 112 | num_classes=1, 113 | pool_type="catavgmax", 114 | # feature_height=16, 115 | atten=True, 116 | # classifier=None, 117 | ), 118 | ) 119 | 120 | run = SimpleNamespace( 121 | mixup=True, 122 | mixup_alpha=0.4, 123 | num_epochs=5, 124 | ) 125 | 126 | main_metric = "sys_srcc" 127 | id_name = None 128 | 129 | 130 | inference = SimpleNamespace( 131 | save_path=Path("preds"), 132 | submit_save_path=Path("submissions"), 133 | num_tta=5, 134 | batch_size=8, 135 | extend="tile", 136 | ) 137 | -------------------------------------------------------------------------------- /utmosv2/config/c_ssl_only_stage1.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="stratified_group", 17 | target="mos", 18 | group="sys_id", 19 | ) 20 | 21 | dataset = 
SimpleNamespace( 22 | name="sslext", 23 | ssl=SimpleNamespace( 24 | duration=3, 25 | ), 26 | ) 27 | 28 | external_data: list[str] | str = "all" 29 | use_bvcc = True 30 | 31 | 32 | validation_dataset = "sarulab" 33 | 34 | loss = [ 35 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 36 | (SimpleNamespace(name="mse"), 0.2), 37 | ] 38 | 39 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 40 | 41 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 42 | 43 | model_path = "model" 44 | model = SimpleNamespace( 45 | name="sslext", 46 | ssl=SimpleNamespace( 47 | name="facebook/wav2vec2-base", 48 | attn=1, 49 | freeze=True, 50 | num_classes=1, 51 | ), 52 | ) 53 | 54 | run = SimpleNamespace( 55 | mixup=True, 56 | mixup_alpha=0.4, 57 | num_epochs=20, 58 | ) 59 | 60 | main_metric = "sys_srcc" 61 | id_name = None 62 | 63 | 64 | inference = SimpleNamespace( 65 | save_path=Path("preds"), 66 | submit_save_path=Path("submissions"), 67 | num_tta=5, 68 | batch_size=8, 69 | # extend="tile", 70 | ) 71 | -------------------------------------------------------------------------------- /utmosv2/config/c_ssl_only_stage2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="stratified_group", 17 | target="mos", 18 | group="sys_id", 19 | ) 20 | 21 | dataset = SimpleNamespace( 22 | name="sslext", 23 | ssl=SimpleNamespace( 24 | duration=3, 25 | ), 26 | ) 27 | 28 | external_data: list[str] | str = "all" 29 | use_bvcc = True 30 | 31 | 32 | validation_dataset = "sarulab" 33 | 34 | loss = [ 35 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 36 | 
(SimpleNamespace(name="mse"), 0.2), 37 | ] 38 | 39 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) 40 | 41 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) 42 | 43 | model_path = "model" 44 | model = SimpleNamespace( 45 | name="sslext", 46 | ssl=SimpleNamespace( 47 | name="facebook/wav2vec2-base", 48 | attn=1, 49 | freeze=False, 50 | num_classes=1, 51 | ), 52 | ) 53 | 54 | run = SimpleNamespace( 55 | mixup=True, 56 | mixup_alpha=0.4, 57 | num_epochs=5, 58 | ) 59 | 60 | main_metric = "sys_srcc" 61 | id_name = None 62 | 63 | 64 | inference = SimpleNamespace( 65 | save_path=Path("preds"), 66 | submit_save_path=Path("submissions"), 67 | num_tta=5, 68 | batch_size=8, 69 | # extend="tile", 70 | ) 71 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_stage2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from types import SimpleNamespace 4 | 5 | from torchvision import transforms 6 | 7 | from utmosv2.transform import XYMasking 8 | 9 | batch_size = 16 10 | num_folds = 5 11 | 12 | sr = 16000 13 | 14 | preprocess = SimpleNamespace( 15 | top_db=30, min_seconds=None, save_path="preprocessed_data/clip_audio" 16 | ) 17 | 18 | split = SimpleNamespace( 19 | type="sgkf_kind", 20 | target="mos", 21 | group="sys_id", 22 | kind="dataset", 23 | ) 24 | 25 | external_data: list[str] | str = "all" 26 | use_bvcc = True 27 | 28 | predict_dataset = "ysaito" 29 | # predict_dataset = "bvcc" 30 | 31 | validation_dataset = "each" 32 | 33 | dataset = SimpleNamespace( 34 | name="ssl_multispec_ext", 35 | specs=[ 36 | SimpleNamespace( 37 | mode="melspec", 38 | n_fft=4096, 39 | hop_length=32, 40 | win_length=4096, 41 | n_mels=512, 42 | shape=(512, 512), 43 | norm=80, 44 | ), 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=2048, 50 | n_mels=512, 51 | shape=(512, 512), 52 | 
norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=1024, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=512, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | ], 73 | spec_frames=SimpleNamespace( 74 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 75 | ), 76 | ssl=SimpleNamespace( 77 | duration=3, 78 | ), 79 | ) 80 | transform = dict( 81 | train=transforms.Compose( 82 | [ 83 | transforms.Resize((512, 512)), 84 | XYMasking( 85 | num_masks_x=(0, 2), 86 | num_masks_y=(0, 2), 87 | mask_x_length=(10, 40), 88 | mask_y_length=(10, 30), 89 | fill_value=0, 90 | p=0.5, 91 | ), 92 | # transforms.ToTensor(), 93 | ] 94 | ), 95 | valid=transforms.Compose( 96 | [ 97 | transforms.Resize((512, 512)), 98 | # transforms.ToTensor() 99 | ] 100 | ), 101 | ) 102 | 103 | loss = [ 104 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 105 | (SimpleNamespace(name="mse"), 0.2), 106 | ] 107 | 108 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 109 | 110 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) 111 | 112 | model = SimpleNamespace( 113 | name="ssl_multispec_ext_v2", 114 | multi_spec=SimpleNamespace( 115 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 116 | pretrained=True, 117 | num_classes=1, 118 | pool_type="catavgmax", 119 | # feature_height=16, 120 | atten=True, 121 | # classifier=None, 122 | ), 123 | ssl=SimpleNamespace( 124 | name="facebook/wav2vec2-base", 125 | attn=1, 126 | freeze=True, 127 | num_classes=1, 128 | ), 129 | ssl_spec=SimpleNamespace( 130 | ssl_weight="ssl_only_stage2", 131 | spec_weight="spec_only", 132 | num_classes=1, 133 | freeze=True, 134 | ), 135 | ) 136 | 137 | run = SimpleNamespace( 138 | mixup=True, 139 | mixup_alpha=0.4, 140 | num_epochs=8, 141 | ) 142 | 143 | main_metric = 
"sys_srcc" 144 | id_name = None 145 | 146 | 147 | inference = SimpleNamespace( 148 | save_path="preds", 149 | submit_save_path="submissions", 150 | num_tta=5, 151 | batch_size=8, 152 | extend="tile", 153 | ) 154 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_stage2_wo_bc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 16 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | "sarulab", 28 | # "blizzard2008", 29 | # "blizzard2009", 30 | # "blizzard2011", 31 | # "blizzard2010-EH1", 32 | # "blizzard2010-EH2", 33 | # "blizzard2010-ES1", 34 | # "blizzard2010-ES3", 35 | "somos", 36 | ] 37 | use_bvcc = True 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="ssl_multispec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | 
spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ssl=SimpleNamespace( 86 | duration=3, 87 | ), 88 | ) 89 | transform = dict( 90 | train=transforms.Compose( 91 | [ 92 | transforms.Resize((512, 512)), 93 | XYMasking( 94 | num_masks_x=(0, 2), 95 | num_masks_y=(0, 2), 96 | mask_x_length=(10, 40), 97 | mask_y_length=(10, 30), 98 | fill_value=0, 99 | p=0.5, 100 | ), 101 | # transforms.ToTensor(), 102 | ] 103 | ), 104 | valid=transforms.Compose( 105 | [ 106 | transforms.Resize((512, 512)), 107 | # transforms.ToTensor() 108 | ] 109 | ), 110 | ) 111 | 112 | loss = [ 113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 114 | (SimpleNamespace(name="mse"), 0.2), 115 | ] 116 | 117 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 118 | 119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) 120 | 121 | model = SimpleNamespace( 122 | name="ssl_multispec_ext_v2", 123 | multi_spec=SimpleNamespace( 124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 125 | pretrained=True, 126 | num_classes=1, 127 | pool_type="catavgmax", 128 | # feature_height=16, 129 | atten=True, 130 | # classifier=None, 131 | ), 132 | ssl=SimpleNamespace( 133 | name="facebook/wav2vec2-base", 134 | attn=1, 135 | freeze=True, 136 | num_classes=1, 137 | ), 138 | ssl_spec=SimpleNamespace( 139 | ssl_weight="ssl_only_stage2_wo_bc", 140 | spec_weight="spec_only_wo_bc", 141 | num_classes=1, 142 | freeze=True, 143 | ), 144 | ) 145 | 146 | run = SimpleNamespace( 147 | mixup=True, 148 | mixup_alpha=0.4, 149 | num_epochs=8, 150 | ) 151 | 152 | main_metric = "sys_srcc" 153 | id_name = None 154 | 155 | 156 | inference = SimpleNamespace( 157 | save_path=Path("preds"), 158 | submit_save_path=Path("submissions"), 159 | num_tta=5, 160 | batch_size=8, 161 | extend="tile", 162 | ) 163 | -------------------------------------------------------------------------------- 
/utmosv2/config/fusion_stage2_wo_bvcc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 16 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | "sarulab", 28 | "blizzard2008", 29 | "blizzard2009", 30 | "blizzard2011", 31 | "blizzard2010-EH1", 32 | "blizzard2010-EH2", 33 | "blizzard2010-ES1", 34 | "blizzard2010-ES3", 35 | "somos", 36 | ] 37 | use_bvcc = False 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="ssl_multispec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ssl=SimpleNamespace( 86 | duration=3, 87 | ), 88 | ) 89 | transform = dict( 90 | train=transforms.Compose( 91 | [ 92 | transforms.Resize((512, 512)), 93 | XYMasking( 94 | 
num_masks_x=(0, 2), 95 | num_masks_y=(0, 2), 96 | mask_x_length=(10, 40), 97 | mask_y_length=(10, 30), 98 | fill_value=0, 99 | p=0.5, 100 | ), 101 | # transforms.ToTensor(), 102 | ] 103 | ), 104 | valid=transforms.Compose( 105 | [ 106 | transforms.Resize((512, 512)), 107 | # transforms.ToTensor() 108 | ] 109 | ), 110 | ) 111 | 112 | loss = [ 113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 114 | (SimpleNamespace(name="mse"), 0.2), 115 | ] 116 | 117 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 118 | 119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) 120 | 121 | model = SimpleNamespace( 122 | name="ssl_multispec_ext_v2", 123 | multi_spec=SimpleNamespace( 124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 125 | pretrained=True, 126 | num_classes=1, 127 | pool_type="catavgmax", 128 | # feature_height=16, 129 | atten=True, 130 | # classifier=None, 131 | ), 132 | ssl=SimpleNamespace( 133 | name="facebook/wav2vec2-base", 134 | attn=1, 135 | freeze=True, 136 | num_classes=1, 137 | ), 138 | ssl_spec=SimpleNamespace( 139 | ssl_weight="ssl_only_stage2_wo_bvcc", 140 | spec_weight="spec_only_wo_bvcc", 141 | num_classes=1, 142 | freeze=True, 143 | ), 144 | ) 145 | 146 | run = SimpleNamespace( 147 | mixup=True, 148 | mixup_alpha=0.4, 149 | num_epochs=8, 150 | ) 151 | 152 | main_metric = "sys_srcc" 153 | id_name = None 154 | 155 | 156 | inference = SimpleNamespace( 157 | save_path=Path("preds"), 158 | submit_save_path=Path("submissions"), 159 | num_tta=5, 160 | batch_size=8, 161 | extend="tile", 162 | ) 163 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_stage2_wo_sarulab.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import 
XYMasking 9 | 10 | batch_size = 16 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | # "sarulab", 28 | "blizzard2008", 29 | "blizzard2009", 30 | "blizzard2011", 31 | "blizzard2010-EH1", 32 | "blizzard2010-EH2", 33 | "blizzard2010-ES1", 34 | "blizzard2010-ES3", 35 | "somos", 36 | ] 37 | use_bvcc = True 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="ssl_multispec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ssl=SimpleNamespace( 86 | duration=3, 87 | ), 88 | ) 89 | transform = dict( 90 | train=transforms.Compose( 91 | [ 92 | transforms.Resize((512, 512)), 93 | XYMasking( 94 | num_masks_x=(0, 2), 95 | num_masks_y=(0, 2), 96 | mask_x_length=(10, 40), 97 | mask_y_length=(10, 30), 98 | fill_value=0, 99 | p=0.5, 100 | ), 101 | # transforms.ToTensor(), 102 | ] 103 | ), 104 | valid=transforms.Compose( 105 | [ 106 | transforms.Resize((512, 512)), 107 | # transforms.ToTensor() 108 | ] 109 | ), 110 | ) 
111 | 112 | loss = [ 113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 114 | (SimpleNamespace(name="mse"), 0.2), 115 | ] 116 | 117 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 118 | 119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) 120 | 121 | model = SimpleNamespace( 122 | name="ssl_multispec_ext_v2", 123 | multi_spec=SimpleNamespace( 124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 125 | pretrained=True, 126 | num_classes=1, 127 | pool_type="catavgmax", 128 | # feature_height=16, 129 | atten=True, 130 | # classifier=None, 131 | ), 132 | ssl=SimpleNamespace( 133 | name="facebook/wav2vec2-base", 134 | attn=1, 135 | freeze=True, 136 | num_classes=1, 137 | ), 138 | ssl_spec=SimpleNamespace( 139 | ssl_weight="ssl_only_stage2_wo_sarulab", 140 | spec_weight="spec_only_wo_sarulab", 141 | num_classes=1, 142 | freeze=True, 143 | ), 144 | ) 145 | 146 | run = SimpleNamespace( 147 | mixup=True, 148 | mixup_alpha=0.4, 149 | num_epochs=8, 150 | ) 151 | 152 | main_metric = "sys_srcc" 153 | id_name = None 154 | 155 | 156 | inference = SimpleNamespace( 157 | save_path=Path("preds"), 158 | submit_save_path=Path("submissions"), 159 | num_tta=5, 160 | batch_size=8, 161 | extend="tile", 162 | ) 163 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_stage2_wo_somos.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 16 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 
26 | external_data: list[str] | str = [ 27 | "sarulab", 28 | "blizzard2008", 29 | "blizzard2009", 30 | "blizzard2011", 31 | "blizzard2010-EH1", 32 | "blizzard2010-EH2", 33 | "blizzard2010-ES1", 34 | "blizzard2010-ES3", 35 | # "somos", 36 | ] 37 | use_bvcc = True 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="ssl_multispec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ssl=SimpleNamespace( 86 | duration=3, 87 | ), 88 | ) 89 | transform = dict( 90 | train=transforms.Compose( 91 | [ 92 | transforms.Resize((512, 512)), 93 | XYMasking( 94 | num_masks_x=(0, 2), 95 | num_masks_y=(0, 2), 96 | mask_x_length=(10, 40), 97 | mask_y_length=(10, 30), 98 | fill_value=0, 99 | p=0.5, 100 | ), 101 | # transforms.ToTensor(), 102 | ] 103 | ), 104 | valid=transforms.Compose( 105 | [ 106 | transforms.Resize((512, 512)), 107 | # transforms.ToTensor() 108 | ] 109 | ), 110 | ) 111 | 112 | loss = [ 113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 114 | (SimpleNamespace(name="mse"), 0.2), 115 | ] 116 | 117 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 118 | 119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5) 120 | 121 | 
model = SimpleNamespace( 122 | name="ssl_multispec_ext_v2", 123 | multi_spec=SimpleNamespace( 124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 125 | pretrained=True, 126 | num_classes=1, 127 | pool_type="catavgmax", 128 | # feature_height=16, 129 | atten=True, 130 | # classifier=None, 131 | ), 132 | ssl=SimpleNamespace( 133 | name="facebook/wav2vec2-base", 134 | attn=1, 135 | freeze=True, 136 | num_classes=1, 137 | ), 138 | ssl_spec=SimpleNamespace( 139 | ssl_weight="ssl_only_stage2_wo_somos", 140 | spec_weight="spec_only_wo_somos", 141 | num_classes=1, 142 | freeze=True, 143 | ), 144 | ) 145 | 146 | run = SimpleNamespace( 147 | mixup=True, 148 | mixup_alpha=0.4, 149 | num_epochs=8, 150 | ) 151 | 152 | main_metric = "sys_srcc" 153 | id_name = None 154 | 155 | 156 | inference = SimpleNamespace( 157 | save_path=Path("preds"), 158 | submit_save_path=Path("submissions"), 159 | num_tta=5, 160 | batch_size=8, 161 | extend="tile", 162 | ) 163 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_stage3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 8 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = "all" 27 | use_bvcc = True 28 | 29 | 30 | validation_dataset = "each" 31 | 32 | dataset = SimpleNamespace( 33 | name="ssl_multispec_ext", 34 | specs=[ 35 | SimpleNamespace( 36 | mode="melspec", 37 | n_fft=4096, 38 | hop_length=32, 39 | win_length=4096, 40 | n_mels=512, 41 | shape=(512, 512), 42 | 
norm=80, 43 | ), 44 | SimpleNamespace( 45 | mode="melspec", 46 | n_fft=4096, 47 | hop_length=32, 48 | win_length=2048, 49 | n_mels=512, 50 | shape=(512, 512), 51 | norm=80, 52 | ), 53 | SimpleNamespace( 54 | mode="melspec", 55 | n_fft=4096, 56 | hop_length=32, 57 | win_length=1024, 58 | n_mels=512, 59 | shape=(512, 512), 60 | norm=80, 61 | ), 62 | SimpleNamespace( 63 | mode="melspec", 64 | n_fft=4096, 65 | hop_length=32, 66 | win_length=512, 67 | n_mels=512, 68 | shape=(512, 512), 69 | norm=80, 70 | ), 71 | ], 72 | spec_frames=SimpleNamespace( 73 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 74 | ), 75 | ssl=SimpleNamespace( 76 | duration=3, 77 | ), 78 | ) 79 | transform = dict( 80 | train=transforms.Compose( 81 | [ 82 | transforms.Resize((512, 512)), 83 | XYMasking( 84 | num_masks_x=(0, 2), 85 | num_masks_y=(0, 2), 86 | mask_x_length=(10, 40), 87 | mask_y_length=(10, 30), 88 | fill_value=0, 89 | p=0.5, 90 | ), 91 | # transforms.ToTensor(), 92 | ] 93 | ), 94 | valid=transforms.Compose( 95 | [ 96 | transforms.Resize((512, 512)), 97 | # transforms.ToTensor() 98 | ] 99 | ), 100 | ) 101 | 102 | loss = [ 103 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 104 | (SimpleNamespace(name="mse"), 0.2), 105 | ] 106 | 107 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) 108 | 109 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) 110 | 111 | model = SimpleNamespace( 112 | name="ssl_multispec_ext_v2", 113 | multi_spec=SimpleNamespace( 114 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 115 | pretrained=True, 116 | num_classes=1, 117 | pool_type="catavgmax", 118 | # feature_height=16, 119 | atten=True, 120 | # classifier=None, 121 | ), 122 | ssl=SimpleNamespace( 123 | name="facebook/wav2vec2-base", 124 | attn=1, 125 | freeze=False, 126 | num_classes=1, 127 | ), 128 | ssl_spec=SimpleNamespace( 129 | ssl_weight="ssl_only_stage2", 130 | spec_weight="spec_only", 131 | num_classes=1, 
132 | freeze=False, 133 | ), 134 | ) 135 | 136 | run = SimpleNamespace( 137 | mixup=True, 138 | mixup_alpha=0.4, 139 | num_epochs=2, 140 | ) 141 | 142 | main_metric = "sys_srcc" 143 | id_name = None 144 | 145 | 146 | inference = SimpleNamespace( 147 | save_path=Path("preds"), 148 | submit_save_path=Path("submissions"), 149 | num_tta=5, 150 | batch_size=8, 151 | extend="tile", 152 | ) 153 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_stage3_wo_bc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 8 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | "sarulab", 28 | # "blizzard2008", 29 | # "blizzard2009", 30 | # "blizzard2011", 31 | # "blizzard2010-EH1", 32 | # "blizzard2010-EH2", 33 | # "blizzard2010-ES1", 34 | # "blizzard2010-ES3", 35 | "somos", 36 | ] 37 | use_bvcc = True 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="ssl_multispec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 
72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ssl=SimpleNamespace( 86 | duration=3, 87 | ), 88 | ) 89 | transform = dict( 90 | train=transforms.Compose( 91 | [ 92 | transforms.Resize((512, 512)), 93 | XYMasking( 94 | num_masks_x=(0, 2), 95 | num_masks_y=(0, 2), 96 | mask_x_length=(10, 40), 97 | mask_y_length=(10, 30), 98 | fill_value=0, 99 | p=0.5, 100 | ), 101 | # transforms.ToTensor(), 102 | ] 103 | ), 104 | valid=transforms.Compose( 105 | [ 106 | transforms.Resize((512, 512)), 107 | # transforms.ToTensor() 108 | ] 109 | ), 110 | ) 111 | 112 | loss = [ 113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 114 | (SimpleNamespace(name="mse"), 0.2), 115 | ] 116 | 117 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) 118 | 119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) 120 | 121 | model = SimpleNamespace( 122 | name="ssl_multispec_ext_v2", 123 | multi_spec=SimpleNamespace( 124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 125 | pretrained=True, 126 | num_classes=1, 127 | pool_type="catavgmax", 128 | # feature_height=16, 129 | atten=True, 130 | # classifier=None, 131 | ), 132 | ssl=SimpleNamespace( 133 | name="facebook/wav2vec2-base", 134 | attn=1, 135 | freeze=False, 136 | num_classes=1, 137 | ), 138 | ssl_spec=SimpleNamespace( 139 | ssl_weight="ssl_only_stage2_wo_bc", 140 | spec_weight="spec_only_wo_bc", 141 | num_classes=1, 142 | freeze=False, 143 | ), 144 | ) 145 | 146 | run = SimpleNamespace( 147 | mixup=True, 148 | mixup_alpha=0.4, 149 | num_epochs=2, 150 | ) 151 | 152 | main_metric = "sys_srcc" 153 | id_name = None 154 | 155 | 156 | inference = SimpleNamespace( 157 | save_path=Path("preds"), 158 | submit_save_path=Path("submissions"), 159 
| num_tta=5, 160 | batch_size=8, 161 | extend="tile", 162 | ) 163 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_stage3_wo_bvcc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 8 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | "sarulab", 28 | "blizzard2008", 29 | "blizzard2009", 30 | "blizzard2011", 31 | "blizzard2010-EH1", 32 | "blizzard2010-EH2", 33 | "blizzard2010-ES1", 34 | "blizzard2010-ES3", 35 | "somos", 36 | ] 37 | use_bvcc = False 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="ssl_multispec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ssl=SimpleNamespace( 86 | 
duration=3, 87 | ), 88 | ) 89 | transform = dict( 90 | train=transforms.Compose( 91 | [ 92 | transforms.Resize((512, 512)), 93 | XYMasking( 94 | num_masks_x=(0, 2), 95 | num_masks_y=(0, 2), 96 | mask_x_length=(10, 40), 97 | mask_y_length=(10, 30), 98 | fill_value=0, 99 | p=0.5, 100 | ), 101 | # transforms.ToTensor(), 102 | ] 103 | ), 104 | valid=transforms.Compose( 105 | [ 106 | transforms.Resize((512, 512)), 107 | # transforms.ToTensor() 108 | ] 109 | ), 110 | ) 111 | 112 | loss = [ 113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 114 | (SimpleNamespace(name="mse"), 0.2), 115 | ] 116 | 117 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) 118 | 119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) 120 | 121 | model = SimpleNamespace( 122 | name="ssl_multispec_ext_v2", 123 | multi_spec=SimpleNamespace( 124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 125 | pretrained=True, 126 | num_classes=1, 127 | pool_type="catavgmax", 128 | # feature_height=16, 129 | atten=True, 130 | # classifier=None, 131 | ), 132 | ssl=SimpleNamespace( 133 | name="facebook/wav2vec2-base", 134 | attn=1, 135 | freeze=False, 136 | num_classes=1, 137 | ), 138 | ssl_spec=SimpleNamespace( 139 | ssl_weight="ssl_only_stage2_wo_bvcc", 140 | spec_weight="spec_only_wo_bvcc", 141 | num_classes=1, 142 | freeze=False, 143 | ), 144 | ) 145 | 146 | run = SimpleNamespace( 147 | mixup=True, 148 | mixup_alpha=0.4, 149 | num_epochs=2, 150 | ) 151 | 152 | main_metric = "sys_srcc" 153 | id_name = None 154 | 155 | 156 | inference = SimpleNamespace( 157 | save_path=Path("preds"), 158 | submit_save_path=Path("submissions"), 159 | num_tta=5, 160 | batch_size=8, 161 | extend="tile", 162 | ) 163 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_stage3_wo_sarulab.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 
from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 8 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | # "sarulab", 28 | "blizzard2008", 29 | "blizzard2009", 30 | "blizzard2011", 31 | "blizzard2010-EH1", 32 | "blizzard2010-EH2", 33 | "blizzard2010-ES1", 34 | "blizzard2010-ES3", 35 | "somos", 36 | ] 37 | use_bvcc = True 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="ssl_multispec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ssl=SimpleNamespace( 86 | duration=3, 87 | ), 88 | ) 89 | transform = dict( 90 | train=transforms.Compose( 91 | [ 92 | transforms.Resize((512, 512)), 93 | XYMasking( 94 | num_masks_x=(0, 2), 95 | num_masks_y=(0, 2), 96 | mask_x_length=(10, 40), 97 | mask_y_length=(10, 30), 98 | fill_value=0, 99 | p=0.5, 100 | ), 101 | # transforms.ToTensor(), 102 | ] 
103 | ), 104 | valid=transforms.Compose( 105 | [ 106 | transforms.Resize((512, 512)), 107 | # transforms.ToTensor() 108 | ] 109 | ), 110 | ) 111 | 112 | loss = [ 113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 114 | (SimpleNamespace(name="mse"), 0.2), 115 | ] 116 | 117 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) 118 | 119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) 120 | 121 | model = SimpleNamespace( 122 | name="ssl_multispec_ext_v2", 123 | multi_spec=SimpleNamespace( 124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 125 | pretrained=True, 126 | num_classes=1, 127 | pool_type="catavgmax", 128 | # feature_height=16, 129 | atten=True, 130 | # classifier=None, 131 | ), 132 | ssl=SimpleNamespace( 133 | name="facebook/wav2vec2-base", 134 | attn=1, 135 | freeze=False, 136 | num_classes=1, 137 | ), 138 | ssl_spec=SimpleNamespace( 139 | ssl_weight="ssl_only_stage2_wo_sarulab", 140 | spec_weight="spec_only_wo_sarulab", 141 | num_classes=1, 142 | freeze=False, 143 | ), 144 | ) 145 | 146 | run = SimpleNamespace( 147 | mixup=True, 148 | mixup_alpha=0.4, 149 | num_epochs=2, 150 | ) 151 | 152 | main_metric = "sys_srcc" 153 | id_name = None 154 | 155 | 156 | inference = SimpleNamespace( 157 | save_path=Path("preds"), 158 | submit_save_path=Path("submissions"), 159 | num_tta=5, 160 | batch_size=8, 161 | extend="tile", 162 | ) 163 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_stage3_wo_somos.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 8 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, 
save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | "sarulab", 28 | "blizzard2008", 29 | "blizzard2009", 30 | "blizzard2011", 31 | "blizzard2010-EH1", 32 | "blizzard2010-EH2", 33 | "blizzard2010-ES1", 34 | "blizzard2010-ES3", 35 | # "somos", 36 | ] 37 | use_bvcc = True 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="ssl_multispec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ssl=SimpleNamespace( 86 | duration=3, 87 | ), 88 | ) 89 | transform = dict( 90 | train=transforms.Compose( 91 | [ 92 | transforms.Resize((512, 512)), 93 | XYMasking( 94 | num_masks_x=(0, 2), 95 | num_masks_y=(0, 2), 96 | mask_x_length=(10, 40), 97 | mask_y_length=(10, 30), 98 | fill_value=0, 99 | p=0.5, 100 | ), 101 | # transforms.ToTensor(), 102 | ] 103 | ), 104 | valid=transforms.Compose( 105 | [ 106 | transforms.Resize((512, 512)), 107 | # transforms.ToTensor() 108 | ] 109 | ), 110 | ) 111 | 112 | loss = [ 113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 114 | (SimpleNamespace(name="mse"), 0.2), 115 | ] 116 | 
117 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4) 118 | 119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8) 120 | 121 | model = SimpleNamespace( 122 | name="ssl_multispec_ext_v2", 123 | multi_spec=SimpleNamespace( 124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 125 | pretrained=True, 126 | num_classes=1, 127 | pool_type="catavgmax", 128 | # feature_height=16, 129 | atten=True, 130 | # classifier=None, 131 | ), 132 | ssl=SimpleNamespace( 133 | name="facebook/wav2vec2-base", 134 | attn=1, 135 | freeze=False, 136 | num_classes=1, 137 | ), 138 | ssl_spec=SimpleNamespace( 139 | ssl_weight="ssl_only_stage2_wo_somos", 140 | spec_weight="spec_only_wo_somos", 141 | num_classes=1, 142 | freeze=False, 143 | ), 144 | ) 145 | 146 | run = SimpleNamespace( 147 | mixup=True, 148 | mixup_alpha=0.4, 149 | num_epochs=2, 150 | ) 151 | 152 | main_metric = "sys_srcc" 153 | id_name = None 154 | 155 | 156 | inference = SimpleNamespace( 157 | save_path=Path("preds"), 158 | submit_save_path=Path("submissions"), 159 | num_tta=5, 160 | batch_size=8, 161 | extend="tile", 162 | ) 163 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_wo_stage1and2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 8 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = "all" 27 | use_bvcc = True 28 | 29 | 30 | validation_dataset = "each" 31 | 32 | dataset = SimpleNamespace( 33 | 
name="ssl_multispec_ext", 34 | specs=[ 35 | SimpleNamespace( 36 | mode="melspec", 37 | n_fft=4096, 38 | hop_length=32, 39 | win_length=4096, 40 | n_mels=512, 41 | shape=(512, 512), 42 | norm=80, 43 | ), 44 | SimpleNamespace( 45 | mode="melspec", 46 | n_fft=4096, 47 | hop_length=32, 48 | win_length=2048, 49 | n_mels=512, 50 | shape=(512, 512), 51 | norm=80, 52 | ), 53 | SimpleNamespace( 54 | mode="melspec", 55 | n_fft=4096, 56 | hop_length=32, 57 | win_length=1024, 58 | n_mels=512, 59 | shape=(512, 512), 60 | norm=80, 61 | ), 62 | SimpleNamespace( 63 | mode="melspec", 64 | n_fft=4096, 65 | hop_length=32, 66 | win_length=512, 67 | n_mels=512, 68 | shape=(512, 512), 69 | norm=80, 70 | ), 71 | ], 72 | spec_frames=SimpleNamespace( 73 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 74 | ), 75 | ssl=SimpleNamespace( 76 | duration=3, 77 | ), 78 | ) 79 | transform = dict( 80 | train=transforms.Compose( 81 | [ 82 | transforms.Resize((512, 512)), 83 | XYMasking( 84 | num_masks_x=(0, 2), 85 | num_masks_y=(0, 2), 86 | mask_x_length=(10, 40), 87 | mask_y_length=(10, 30), 88 | fill_value=0, 89 | p=0.5, 90 | ), 91 | # transforms.ToTensor(), 92 | ] 93 | ), 94 | valid=transforms.Compose( 95 | [ 96 | transforms.Resize((512, 512)), 97 | # transforms.ToTensor() 98 | ] 99 | ), 100 | ) 101 | 102 | loss = [ 103 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 104 | (SimpleNamespace(name="mse"), 0.2), 105 | ] 106 | 107 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 108 | 109 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 110 | 111 | model = SimpleNamespace( 112 | name="ssl_multispec_ext_v2", 113 | multi_spec=SimpleNamespace( 114 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 115 | pretrained=True, 116 | num_classes=1, 117 | pool_type="catavgmax", 118 | # feature_height=16, 119 | atten=True, 120 | # classifier=None, 121 | ), 122 | ssl=SimpleNamespace( 123 | name="facebook/wav2vec2-base", 
124 | attn=1, 125 | freeze=False, 126 | num_classes=1, 127 | ), 128 | ssl_spec=SimpleNamespace( 129 | ssl_weight=None, 130 | spec_weight=None, 131 | num_classes=1, 132 | freeze=False, 133 | ), 134 | ) 135 | 136 | run = SimpleNamespace( 137 | mixup=True, 138 | mixup_alpha=0.4, 139 | num_epochs=20, 140 | ) 141 | 142 | main_metric = "sys_srcc" 143 | id_name = None 144 | 145 | 146 | inference = SimpleNamespace( 147 | save_path=Path("preds"), 148 | submit_save_path=Path("submissions"), 149 | num_tta=5, 150 | batch_size=8, 151 | extend="tile", 152 | ) 153 | -------------------------------------------------------------------------------- /utmosv2/config/fusion_wo_stage2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 8 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = "all" 27 | use_bvcc = True 28 | 29 | 30 | validation_dataset = "each" 31 | 32 | dataset = SimpleNamespace( 33 | name="ssl_multispec_ext", 34 | specs=[ 35 | SimpleNamespace( 36 | mode="melspec", 37 | n_fft=4096, 38 | hop_length=32, 39 | win_length=4096, 40 | n_mels=512, 41 | shape=(512, 512), 42 | norm=80, 43 | ), 44 | SimpleNamespace( 45 | mode="melspec", 46 | n_fft=4096, 47 | hop_length=32, 48 | win_length=2048, 49 | n_mels=512, 50 | shape=(512, 512), 51 | norm=80, 52 | ), 53 | SimpleNamespace( 54 | mode="melspec", 55 | n_fft=4096, 56 | hop_length=32, 57 | win_length=1024, 58 | n_mels=512, 59 | shape=(512, 512), 60 | norm=80, 61 | ), 62 | SimpleNamespace( 63 | mode="melspec", 64 | 
n_fft=4096, 65 | hop_length=32, 66 | win_length=512, 67 | n_mels=512, 68 | shape=(512, 512), 69 | norm=80, 70 | ), 71 | ], 72 | spec_frames=SimpleNamespace( 73 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 74 | ), 75 | ssl=SimpleNamespace( 76 | duration=3, 77 | ), 78 | ) 79 | transform = dict( 80 | train=transforms.Compose( 81 | [ 82 | transforms.Resize((512, 512)), 83 | XYMasking( 84 | num_masks_x=(0, 2), 85 | num_masks_y=(0, 2), 86 | mask_x_length=(10, 40), 87 | mask_y_length=(10, 30), 88 | fill_value=0, 89 | p=0.5, 90 | ), 91 | # transforms.ToTensor(), 92 | ] 93 | ), 94 | valid=transforms.Compose( 95 | [ 96 | transforms.Resize((512, 512)), 97 | # transforms.ToTensor() 98 | ] 99 | ), 100 | ) 101 | 102 | loss = [ 103 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 104 | (SimpleNamespace(name="mse"), 0.2), 105 | ] 106 | 107 | optimizer = SimpleNamespace(name="adamw", lr=1e-4, weight_decay=1e-4) 108 | 109 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 110 | 111 | model = SimpleNamespace( 112 | name="ssl_multispec_ext_v2", 113 | multi_spec=SimpleNamespace( 114 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 115 | pretrained=True, 116 | num_classes=1, 117 | pool_type="catavgmax", 118 | # feature_height=16, 119 | atten=True, 120 | # classifier=None, 121 | ), 122 | ssl=SimpleNamespace( 123 | name="facebook/wav2vec2-base", 124 | attn=1, 125 | freeze=False, 126 | num_classes=1, 127 | ), 128 | ssl_spec=SimpleNamespace( 129 | ssl_weight="ssl_only_stage2", 130 | spec_weight="spec_only", 131 | num_classes=1, 132 | freeze=False, 133 | ), 134 | ) 135 | 136 | run = SimpleNamespace( 137 | mixup=True, 138 | mixup_alpha=0.4, 139 | num_epochs=20, 140 | ) 141 | 142 | main_metric = "sys_srcc" 143 | id_name = None 144 | 145 | 146 | inference = SimpleNamespace( 147 | save_path=Path("preds"), 148 | submit_save_path=Path("submissions"), 149 | num_tta=5, 150 | batch_size=8, 151 | extend="tile", 152 | ) 153 | 
-------------------------------------------------------------------------------- /utmosv2/config/spec_only.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 10 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = "all" 27 | use_bvcc = True 28 | 29 | 30 | validation_dataset = "each" 31 | 32 | dataset = SimpleNamespace( 33 | name="multi_spec_ext", 34 | specs=[ 35 | SimpleNamespace( 36 | mode="melspec", 37 | n_fft=4096, 38 | hop_length=32, 39 | win_length=4096, 40 | n_mels=512, 41 | shape=(512, 512), 42 | norm=80, 43 | ), 44 | SimpleNamespace( 45 | mode="melspec", 46 | n_fft=4096, 47 | hop_length=32, 48 | win_length=2048, 49 | n_mels=512, 50 | shape=(512, 512), 51 | norm=80, 52 | ), 53 | SimpleNamespace( 54 | mode="melspec", 55 | n_fft=4096, 56 | hop_length=32, 57 | win_length=1024, 58 | n_mels=512, 59 | shape=(512, 512), 60 | norm=80, 61 | ), 62 | SimpleNamespace( 63 | mode="melspec", 64 | n_fft=4096, 65 | hop_length=32, 66 | win_length=512, 67 | n_mels=512, 68 | shape=(512, 512), 69 | norm=80, 70 | ), 71 | ], 72 | spec_frames=SimpleNamespace( 73 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 74 | ), 75 | ) 76 | transform = dict( 77 | train=transforms.Compose( 78 | [ 79 | transforms.Resize((512, 512)), 80 | XYMasking( 81 | num_masks_x=(0, 2), 82 | num_masks_y=(0, 2), 83 | mask_x_length=(10, 40), 84 | mask_y_length=(10, 30), 85 | fill_value=0, 86 | p=0.5, 87 | ), 88 | # transforms.ToTensor(), 89 | ] 90 | ), 91 | 
valid=transforms.Compose( 92 | [ 93 | transforms.Resize((512, 512)), 94 | # transforms.ToTensor() 95 | ] 96 | ), 97 | ) 98 | 99 | loss = [ 100 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 101 | (SimpleNamespace(name="mse"), 0.2), 102 | ] 103 | 104 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 105 | 106 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 107 | 108 | model = SimpleNamespace( 109 | name="multi_spec_ext", 110 | multi_spec=SimpleNamespace( 111 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 112 | pretrained=True, 113 | num_classes=1, 114 | pool_type="catavgmax", 115 | # feature_height=16, 116 | atten=True, 117 | # classifier=None, 118 | ), 119 | ) 120 | 121 | run = SimpleNamespace( 122 | mixup=True, 123 | mixup_alpha=0.4, 124 | num_epochs=20, 125 | ) 126 | 127 | main_metric = "sys_srcc" 128 | id_name = None 129 | 130 | 131 | inference = SimpleNamespace( 132 | save_path=Path("preds"), 133 | submit_save_path=Path("submissions"), 134 | num_tta=5, 135 | batch_size=8, 136 | extend="tile", 137 | ) 138 | -------------------------------------------------------------------------------- /utmosv2/config/spec_only_wo_bc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 10 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | "sarulab", 28 | # "blizzard2008", 29 | # "blizzard2009", 30 | # "blizzard2011", 31 | # "blizzard2010-EH1", 32 | # "blizzard2010-EH2", 33 | # 
"blizzard2010-ES1", 34 | # "blizzard2010-ES3", 35 | "somos", 36 | ] 37 | use_bvcc = True 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="multi_spec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ) 86 | transform = dict( 87 | train=transforms.Compose( 88 | [ 89 | transforms.Resize((512, 512)), 90 | XYMasking( 91 | num_masks_x=(0, 2), 92 | num_masks_y=(0, 2), 93 | mask_x_length=(10, 40), 94 | mask_y_length=(10, 30), 95 | fill_value=0, 96 | p=0.5, 97 | ), 98 | # transforms.ToTensor(), 99 | ] 100 | ), 101 | valid=transforms.Compose( 102 | [ 103 | transforms.Resize((512, 512)), 104 | # transforms.ToTensor() 105 | ] 106 | ), 107 | ) 108 | 109 | loss = [ 110 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 111 | (SimpleNamespace(name="mse"), 0.2), 112 | ] 113 | 114 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 115 | 116 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 117 | 118 | model = SimpleNamespace( 119 | name="multi_spec_ext", 120 | multi_spec=SimpleNamespace( 121 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 122 | pretrained=True, 123 | num_classes=1, 124 | pool_type="catavgmax", 125 | # 
feature_height=16, 126 | atten=True, 127 | # classifier=None, 128 | ), 129 | ) 130 | 131 | run = SimpleNamespace( 132 | mixup=True, 133 | mixup_alpha=0.4, 134 | num_epochs=20, 135 | ) 136 | 137 | main_metric = "sys_srcc" 138 | id_name = None 139 | 140 | 141 | inference = SimpleNamespace( 142 | save_path=Path("preds"), 143 | submit_save_path=Path("submissions"), 144 | num_tta=5, 145 | batch_size=8, 146 | extend="tile", 147 | ) 148 | -------------------------------------------------------------------------------- /utmosv2/config/spec_only_wo_bvcc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 10 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | "sarulab", 28 | "blizzard2008", 29 | "blizzard2009", 30 | "blizzard2011", 31 | "blizzard2010-EH1", 32 | "blizzard2010-EH2", 33 | "blizzard2010-ES1", 34 | "blizzard2010-ES3", 35 | "somos", 36 | ] 37 | use_bvcc = False 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="multi_spec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 
512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ) 86 | transform = dict( 87 | train=transforms.Compose( 88 | [ 89 | transforms.Resize((512, 512)), 90 | XYMasking( 91 | num_masks_x=(0, 2), 92 | num_masks_y=(0, 2), 93 | mask_x_length=(10, 40), 94 | mask_y_length=(10, 30), 95 | fill_value=0, 96 | p=0.5, 97 | ), 98 | # transforms.ToTensor(), 99 | ] 100 | ), 101 | valid=transforms.Compose( 102 | [ 103 | transforms.Resize((512, 512)), 104 | # transforms.ToTensor() 105 | ] 106 | ), 107 | ) 108 | 109 | loss = [ 110 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 111 | (SimpleNamespace(name="mse"), 0.2), 112 | ] 113 | 114 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 115 | 116 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 117 | 118 | model = SimpleNamespace( 119 | name="multi_spec_ext", 120 | multi_spec=SimpleNamespace( 121 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 122 | pretrained=True, 123 | num_classes=1, 124 | pool_type="catavgmax", 125 | # feature_height=16, 126 | atten=True, 127 | # classifier=None, 128 | ), 129 | ) 130 | 131 | run = SimpleNamespace( 132 | mixup=True, 133 | mixup_alpha=0.4, 134 | num_epochs=20, 135 | ) 136 | 137 | main_metric = "sys_srcc" 138 | id_name = None 139 | 140 | 141 | inference = SimpleNamespace( 142 | save_path=Path("preds"), 143 | submit_save_path=Path("submissions"), 144 | num_tta=5, 145 | batch_size=8, 146 | extend="tile", 147 | ) 148 | -------------------------------------------------------------------------------- /utmosv2/config/spec_only_wo_sarulab.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 
from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 10 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | # "sarulab", 28 | "blizzard2008", 29 | "blizzard2009", 30 | "blizzard2011", 31 | "blizzard2010-EH1", 32 | "blizzard2010-EH2", 33 | "blizzard2010-ES1", 34 | "blizzard2010-ES3", 35 | "somos", 36 | ] 37 | use_bvcc = True 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="multi_spec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ) 86 | transform = dict( 87 | train=transforms.Compose( 88 | [ 89 | transforms.Resize((512, 512)), 90 | XYMasking( 91 | num_masks_x=(0, 2), 92 | num_masks_y=(0, 2), 93 | mask_x_length=(10, 40), 94 | mask_y_length=(10, 30), 95 | fill_value=0, 96 | p=0.5, 97 | ), 98 | # transforms.ToTensor(), 99 | ] 100 | ), 101 | valid=transforms.Compose( 102 | [ 103 | 
transforms.Resize((512, 512)), 104 | # transforms.ToTensor() 105 | ] 106 | ), 107 | ) 108 | 109 | loss = [ 110 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 111 | (SimpleNamespace(name="mse"), 0.2), 112 | ] 113 | 114 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 115 | 116 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 117 | 118 | model = SimpleNamespace( 119 | name="multi_spec_ext", 120 | multi_spec=SimpleNamespace( 121 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 122 | pretrained=True, 123 | num_classes=1, 124 | pool_type="catavgmax", 125 | # feature_height=16, 126 | atten=True, 127 | # classifier=None, 128 | ), 129 | ) 130 | 131 | run = SimpleNamespace( 132 | mixup=True, 133 | mixup_alpha=0.4, 134 | num_epochs=20, 135 | ) 136 | 137 | main_metric = "sys_srcc" 138 | id_name = None 139 | 140 | 141 | inference = SimpleNamespace( 142 | save_path=Path("preds"), 143 | submit_save_path=Path("submissions"), 144 | num_tta=5, 145 | batch_size=8, 146 | extend="tile", 147 | ) 148 | -------------------------------------------------------------------------------- /utmosv2/config/spec_only_wo_somos.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | from torchvision import transforms 7 | 8 | from utmosv2.transform import XYMasking 9 | 10 | batch_size = 10 11 | num_folds = 5 12 | 13 | sr = 16000 14 | 15 | preprocess = SimpleNamespace( 16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 17 | ) 18 | 19 | split = SimpleNamespace( 20 | type="sgkf_kind", 21 | target="mos", 22 | group="sys_id", 23 | kind="dataset", 24 | ) 25 | 26 | external_data: list[str] | str = [ 27 | "sarulab", 28 | "blizzard2008", 29 | "blizzard2009", 30 | "blizzard2011", 31 | "blizzard2010-EH1", 32 | "blizzard2010-EH2", 33 | "blizzard2010-ES1", 34 | "blizzard2010-ES3", 
35 | # "somos", 36 | ] 37 | use_bvcc = True 38 | 39 | 40 | validation_dataset = "each" 41 | 42 | dataset = SimpleNamespace( 43 | name="multi_spec_ext", 44 | specs=[ 45 | SimpleNamespace( 46 | mode="melspec", 47 | n_fft=4096, 48 | hop_length=32, 49 | win_length=4096, 50 | n_mels=512, 51 | shape=(512, 512), 52 | norm=80, 53 | ), 54 | SimpleNamespace( 55 | mode="melspec", 56 | n_fft=4096, 57 | hop_length=32, 58 | win_length=2048, 59 | n_mels=512, 60 | shape=(512, 512), 61 | norm=80, 62 | ), 63 | SimpleNamespace( 64 | mode="melspec", 65 | n_fft=4096, 66 | hop_length=32, 67 | win_length=1024, 68 | n_mels=512, 69 | shape=(512, 512), 70 | norm=80, 71 | ), 72 | SimpleNamespace( 73 | mode="melspec", 74 | n_fft=4096, 75 | hop_length=32, 76 | win_length=512, 77 | n_mels=512, 78 | shape=(512, 512), 79 | norm=80, 80 | ), 81 | ], 82 | spec_frames=SimpleNamespace( 83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile" 84 | ), 85 | ) 86 | transform = dict( 87 | train=transforms.Compose( 88 | [ 89 | transforms.Resize((512, 512)), 90 | XYMasking( 91 | num_masks_x=(0, 2), 92 | num_masks_y=(0, 2), 93 | mask_x_length=(10, 40), 94 | mask_y_length=(10, 30), 95 | fill_value=0, 96 | p=0.5, 97 | ), 98 | # transforms.ToTensor(), 99 | ] 100 | ), 101 | valid=transforms.Compose( 102 | [ 103 | transforms.Resize((512, 512)), 104 | # transforms.ToTensor() 105 | ] 106 | ), 107 | ) 108 | 109 | loss = [ 110 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 111 | (SimpleNamespace(name="mse"), 0.2), 112 | ] 113 | 114 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 115 | 116 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 117 | 118 | model = SimpleNamespace( 119 | name="multi_spec_ext", 120 | multi_spec=SimpleNamespace( 121 | backbone="tf_efficientnetv2_s.in21k_ft_in1k", 122 | pretrained=True, 123 | num_classes=1, 124 | pool_type="catavgmax", 125 | # feature_height=16, 126 | atten=True, 127 | # 
classifier=None, 128 | ), 129 | ) 130 | 131 | run = SimpleNamespace( 132 | mixup=True, 133 | mixup_alpha=0.4, 134 | num_epochs=20, 135 | ) 136 | 137 | main_metric = "sys_srcc" 138 | id_name = None 139 | 140 | 141 | inference = SimpleNamespace( 142 | save_path=Path("preds"), 143 | submit_save_path=Path("submissions"), 144 | num_tta=5, 145 | batch_size=8, 146 | extend="tile", 147 | ) 148 | -------------------------------------------------------------------------------- /utmosv2/config/ssl_only_stage1.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="sgkf_kind", 17 | target="mos", 18 | group="sys_id", 19 | kind="dataset", 20 | ) 21 | 22 | dataset = SimpleNamespace( 23 | name="sslext", 24 | ssl=SimpleNamespace( 25 | duration=3, 26 | ), 27 | ) 28 | 29 | external_data: list[str] | str = "all" 30 | use_bvcc = True 31 | 32 | 33 | validation_dataset = "each" 34 | 35 | loss = [ 36 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 37 | (SimpleNamespace(name="mse"), 0.2), 38 | ] 39 | 40 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 41 | 42 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 43 | 44 | model_path = "model" 45 | model = SimpleNamespace( 46 | name="sslext", 47 | ssl=SimpleNamespace( 48 | name="facebook/wav2vec2-base", 49 | attn=1, 50 | freeze=True, 51 | num_classes=1, 52 | ), 53 | ) 54 | 55 | run = SimpleNamespace( 56 | mixup=True, 57 | mixup_alpha=0.4, 58 | num_epochs=20, 59 | ) 60 | 61 | main_metric = "sys_srcc" 62 | id_name = None 63 | 64 | 65 | inference = SimpleNamespace( 66 | save_path=Path("preds"), 67 | 
submit_save_path=Path("submissions"), 68 | num_tta=5, 69 | batch_size=8, 70 | # extend="tile", 71 | ) 72 | -------------------------------------------------------------------------------- /utmosv2/config/ssl_only_stage1_wo_bc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="sgkf_kind", 17 | target="mos", 18 | group="sys_id", 19 | kind="dataset", 20 | ) 21 | 22 | dataset = SimpleNamespace( 23 | name="sslext", 24 | ssl=SimpleNamespace( 25 | duration=3, 26 | ), 27 | ) 28 | 29 | external_data: list[str] | str = [ 30 | "sarulab", 31 | # "blizzard2008", 32 | # "blizzard2009", 33 | # "blizzard2011", 34 | # "blizzard2010-EH1", 35 | # "blizzard2010-EH2", 36 | # "blizzard2010-ES1", 37 | # "blizzard2010-ES3", 38 | "somos", 39 | ] 40 | use_bvcc = True 41 | 42 | 43 | validation_dataset = "each" 44 | 45 | loss = [ 46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 47 | (SimpleNamespace(name="mse"), 0.2), 48 | ] 49 | 50 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 51 | 52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 53 | 54 | model_path = "model" 55 | model = SimpleNamespace( 56 | name="sslext", 57 | ssl=SimpleNamespace( 58 | name="facebook/wav2vec2-base", 59 | attn=1, 60 | freeze=True, 61 | num_classes=1, 62 | ), 63 | ) 64 | 65 | run = SimpleNamespace( 66 | mixup=True, 67 | mixup_alpha=0.4, 68 | num_epochs=20, 69 | ) 70 | 71 | main_metric = "sys_srcc" 72 | id_name = None 73 | 74 | 75 | inference = SimpleNamespace( 76 | save_path=Path("preds"), 77 | submit_save_path=Path("submissions"), 78 | num_tta=5, 79 | batch_size=8, 80 | # extend="tile", 
81 | ) 82 | -------------------------------------------------------------------------------- /utmosv2/config/ssl_only_stage1_wo_bvcc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="sgkf_kind", 17 | target="mos", 18 | group="sys_id", 19 | kind="dataset", 20 | ) 21 | 22 | dataset = SimpleNamespace( 23 | name="sslext", 24 | ssl=SimpleNamespace( 25 | duration=3, 26 | ), 27 | ) 28 | 29 | external_data: list[str] | str = [ 30 | "sarulab", 31 | "blizzard2008", 32 | "blizzard2009", 33 | "blizzard2011", 34 | "blizzard2010-EH1", 35 | "blizzard2010-EH2", 36 | "blizzard2010-ES1", 37 | "blizzard2010-ES3", 38 | "somos", 39 | ] 40 | use_bvcc = False 41 | 42 | 43 | validation_dataset = "each" 44 | 45 | loss = [ 46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 47 | (SimpleNamespace(name="mse"), 0.2), 48 | ] 49 | 50 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 51 | 52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 53 | 54 | model_path = "model" 55 | model = SimpleNamespace( 56 | name="sslext", 57 | ssl=SimpleNamespace( 58 | name="facebook/wav2vec2-base", 59 | attn=1, 60 | freeze=True, 61 | num_classes=1, 62 | ), 63 | ) 64 | 65 | run = SimpleNamespace( 66 | mixup=True, 67 | mixup_alpha=0.4, 68 | num_epochs=20, 69 | ) 70 | 71 | main_metric = "sys_srcc" 72 | id_name = None 73 | 74 | 75 | inference = SimpleNamespace( 76 | save_path=Path("preds"), 77 | submit_save_path=Path("submissions"), 78 | num_tta=5, 79 | batch_size=8, 80 | # extend="tile", 81 | ) 82 | -------------------------------------------------------------------------------- 
/utmosv2/config/ssl_only_stage1_wo_sarulab.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="sgkf_kind", 17 | target="mos", 18 | group="sys_id", 19 | kind="dataset", 20 | ) 21 | 22 | dataset = SimpleNamespace( 23 | name="sslext", 24 | ssl=SimpleNamespace( 25 | duration=3, 26 | ), 27 | ) 28 | 29 | external_data: list[str] | str = [ 30 | # "sarulab", 31 | "blizzard2008", 32 | "blizzard2009", 33 | "blizzard2011", 34 | "blizzard2010-EH1", 35 | "blizzard2010-EH2", 36 | "blizzard2010-ES1", 37 | "blizzard2010-ES3", 38 | "somos", 39 | ] 40 | use_bvcc = True 41 | 42 | 43 | validation_dataset = "each" 44 | 45 | loss = [ 46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 47 | (SimpleNamespace(name="mse"), 0.2), 48 | ] 49 | 50 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 51 | 52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 53 | 54 | model_path = "model" 55 | model = SimpleNamespace( 56 | name="sslext", 57 | ssl=SimpleNamespace( 58 | name="facebook/wav2vec2-base", 59 | attn=1, 60 | freeze=True, 61 | num_classes=1, 62 | ), 63 | ) 64 | 65 | run = SimpleNamespace( 66 | mixup=True, 67 | mixup_alpha=0.4, 68 | num_epochs=20, 69 | ) 70 | 71 | main_metric = "sys_srcc" 72 | id_name = None 73 | 74 | 75 | inference = SimpleNamespace( 76 | save_path=Path("preds"), 77 | submit_save_path=Path("submissions"), 78 | num_tta=5, 79 | batch_size=8, 80 | # extend="tile", 81 | ) 82 | -------------------------------------------------------------------------------- /utmosv2/config/ssl_only_stage1_wo_somos.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="sgkf_kind", 17 | target="mos", 18 | group="sys_id", 19 | kind="dataset", 20 | ) 21 | 22 | dataset = SimpleNamespace( 23 | name="sslext", 24 | ssl=SimpleNamespace( 25 | duration=3, 26 | ), 27 | ) 28 | 29 | external_data: list[str] | str = [ 30 | "sarulab", 31 | "blizzard2008", 32 | "blizzard2009", 33 | "blizzard2011", 34 | "blizzard2010-EH1", 35 | "blizzard2010-EH2", 36 | "blizzard2010-ES1", 37 | "blizzard2010-ES3", 38 | # "somos", 39 | ] 40 | use_bvcc = True 41 | 42 | 43 | validation_dataset = "each" 44 | 45 | loss = [ 46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 47 | (SimpleNamespace(name="mse"), 0.2), 48 | ] 49 | 50 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4) 51 | 52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7) 53 | 54 | model_path = "model" 55 | model = SimpleNamespace( 56 | name="sslext", 57 | ssl=SimpleNamespace( 58 | name="facebook/wav2vec2-base", 59 | attn=1, 60 | freeze=True, 61 | num_classes=1, 62 | ), 63 | ) 64 | 65 | run = SimpleNamespace( 66 | mixup=True, 67 | mixup_alpha=0.4, 68 | num_epochs=20, 69 | ) 70 | 71 | main_metric = "sys_srcc" 72 | id_name = None 73 | 74 | 75 | inference = SimpleNamespace( 76 | save_path=Path("preds"), 77 | submit_save_path=Path("submissions"), 78 | num_tta=5, 79 | batch_size=8, 80 | # extend="tile", 81 | ) 82 | -------------------------------------------------------------------------------- /utmosv2/config/ssl_only_stage2.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="sgkf_kind", 17 | target="mos", 18 | group="sys_id", 19 | kind="dataset", 20 | ) 21 | 22 | dataset = SimpleNamespace( 23 | name="sslext", 24 | ssl=SimpleNamespace( 25 | duration=3, 26 | ), 27 | ) 28 | 29 | external_data: list[str] | str = "all" 30 | use_bvcc = True 31 | 32 | 33 | validation_dataset = "each" 34 | 35 | loss = [ 36 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 37 | (SimpleNamespace(name="mse"), 0.2), 38 | ] 39 | 40 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) 41 | 42 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) 43 | 44 | model_path = "model" 45 | model = SimpleNamespace( 46 | name="sslext", 47 | ssl=SimpleNamespace( 48 | name="facebook/wav2vec2-base", 49 | attn=1, 50 | freeze=False, 51 | num_classes=1, 52 | ), 53 | ) 54 | 55 | run = SimpleNamespace( 56 | mixup=True, 57 | mixup_alpha=0.4, 58 | num_epochs=5, 59 | ) 60 | 61 | main_metric = "sys_srcc" 62 | id_name = None 63 | 64 | 65 | inference = SimpleNamespace( 66 | save_path=Path("preds"), 67 | submit_save_path=Path("submissions"), 68 | num_tta=5, 69 | batch_size=8, 70 | # extend="tile", 71 | ) 72 | -------------------------------------------------------------------------------- /utmosv2/config/ssl_only_stage2_wo_bc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | 
type="sgkf_kind", 17 | target="mos", 18 | group="sys_id", 19 | kind="dataset", 20 | ) 21 | 22 | dataset = SimpleNamespace( 23 | name="sslext", 24 | ssl=SimpleNamespace( 25 | duration=3, 26 | ), 27 | ) 28 | 29 | external_data: list[str] | str = [ 30 | "sarulab", 31 | # "blizzard2008", 32 | # "blizzard2009", 33 | # "blizzard2011", 34 | # "blizzard2010-EH1", 35 | # "blizzard2010-EH2", 36 | # "blizzard2010-ES1", 37 | # "blizzard2010-ES3", 38 | "somos", 39 | ] 40 | use_bvcc = True 41 | 42 | 43 | validation_dataset = "each" 44 | 45 | loss = [ 46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 47 | (SimpleNamespace(name="mse"), 0.2), 48 | ] 49 | 50 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) 51 | 52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) 53 | 54 | model_path = "model" 55 | model = SimpleNamespace( 56 | name="sslext", 57 | ssl=SimpleNamespace( 58 | name="facebook/wav2vec2-base", 59 | attn=1, 60 | freeze=False, 61 | num_classes=1, 62 | ), 63 | ) 64 | 65 | run = SimpleNamespace( 66 | mixup=True, 67 | mixup_alpha=0.4, 68 | num_epochs=5, 69 | ) 70 | 71 | main_metric = "sys_srcc" 72 | id_name = None 73 | 74 | 75 | inference = SimpleNamespace( 76 | save_path=Path("preds"), 77 | submit_save_path=Path("submissions"), 78 | num_tta=5, 79 | batch_size=8, 80 | # extend="tile", 81 | ) 82 | -------------------------------------------------------------------------------- /utmosv2/config/ssl_only_stage2_wo_bvcc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="sgkf_kind", 17 | target="mos", 18 | group="sys_id", 19 | kind="dataset", 20 | ) 21 | 
22 | dataset = SimpleNamespace( 23 | name="sslext", 24 | ssl=SimpleNamespace( 25 | duration=3, 26 | ), 27 | ) 28 | 29 | external_data: list[str] | str = [ 30 | "sarulab", 31 | "blizzard2008", 32 | "blizzard2009", 33 | "blizzard2011", 34 | "blizzard2010-EH1", 35 | "blizzard2010-EH2", 36 | "blizzard2010-ES1", 37 | "blizzard2010-ES3", 38 | "somos", 39 | ] 40 | use_bvcc = False 41 | 42 | 43 | validation_dataset = "each" 44 | 45 | loss = [ 46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 47 | (SimpleNamespace(name="mse"), 0.2), 48 | ] 49 | 50 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) 51 | 52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) 53 | 54 | model_path = "model" 55 | model = SimpleNamespace( 56 | name="sslext", 57 | ssl=SimpleNamespace( 58 | name="facebook/wav2vec2-base", 59 | attn=1, 60 | freeze=False, 61 | num_classes=1, 62 | ), 63 | ) 64 | 65 | run = SimpleNamespace( 66 | mixup=True, 67 | mixup_alpha=0.4, 68 | num_epochs=5, 69 | ) 70 | 71 | main_metric = "sys_srcc" 72 | id_name = None 73 | 74 | 75 | inference = SimpleNamespace( 76 | save_path=Path("preds"), 77 | submit_save_path=Path("submissions"), 78 | num_tta=5, 79 | batch_size=8, 80 | # extend="tile", 81 | ) 82 | -------------------------------------------------------------------------------- /utmosv2/config/ssl_only_stage2_wo_sarulab.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="sgkf_kind", 17 | target="mos", 18 | group="sys_id", 19 | kind="dataset", 20 | ) 21 | 22 | dataset = SimpleNamespace( 23 | name="sslext", 24 | ssl=SimpleNamespace( 25 | duration=3, 26 | ), 
27 | ) 28 | 29 | external_data: list[str] | str = [ 30 | # "sarulab", 31 | "blizzard2008", 32 | "blizzard2009", 33 | "blizzard2011", 34 | "blizzard2010-EH1", 35 | "blizzard2010-EH2", 36 | "blizzard2010-ES1", 37 | "blizzard2010-ES3", 38 | "somos", 39 | ] 40 | use_bvcc = True 41 | 42 | 43 | validation_dataset = "each" 44 | 45 | loss = [ 46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 47 | (SimpleNamespace(name="mse"), 0.2), 48 | ] 49 | 50 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) 51 | 52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) 53 | 54 | model_path = "model" 55 | model = SimpleNamespace( 56 | name="sslext", 57 | ssl=SimpleNamespace( 58 | name="facebook/wav2vec2-base", 59 | attn=1, 60 | freeze=False, 61 | num_classes=1, 62 | ), 63 | ) 64 | 65 | run = SimpleNamespace( 66 | mixup=True, 67 | mixup_alpha=0.4, 68 | num_epochs=5, 69 | ) 70 | 71 | main_metric = "sys_srcc" 72 | id_name = None 73 | 74 | 75 | inference = SimpleNamespace( 76 | save_path=Path("preds"), 77 | submit_save_path=Path("submissions"), 78 | num_tta=5, 79 | batch_size=8, 80 | # extend="tile", 81 | ) 82 | -------------------------------------------------------------------------------- /utmosv2/config/ssl_only_stage2_wo_somos.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | from types import SimpleNamespace 5 | 6 | batch_size = 32 7 | num_folds = 5 8 | 9 | sr = 16000 10 | 11 | preprocess = SimpleNamespace( 12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data") 13 | ) 14 | 15 | split = SimpleNamespace( 16 | type="sgkf_kind", 17 | target="mos", 18 | group="sys_id", 19 | kind="dataset", 20 | ) 21 | 22 | dataset = SimpleNamespace( 23 | name="sslext", 24 | ssl=SimpleNamespace( 25 | duration=3, 26 | ), 27 | ) 28 | 29 | external_data: list[str] | str = [ 30 | "sarulab", 31 | "blizzard2008", 32 | 
"blizzard2009", 33 | "blizzard2011", 34 | "blizzard2010-EH1", 35 | "blizzard2010-EH2", 36 | "blizzard2010-ES1", 37 | "blizzard2010-ES3", 38 | # "somos", 39 | ] 40 | use_bvcc = True 41 | 42 | 43 | validation_dataset = "each" 44 | 45 | loss = [ 46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7), 47 | (SimpleNamespace(name="mse"), 0.2), 48 | ] 49 | 50 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4) 51 | 52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9) 53 | 54 | model_path = "model" 55 | model = SimpleNamespace( 56 | name="sslext", 57 | ssl=SimpleNamespace( 58 | name="facebook/wav2vec2-base", 59 | attn=1, 60 | freeze=False, 61 | num_classes=1, 62 | ), 63 | ) 64 | 65 | run = SimpleNamespace( 66 | mixup=True, 67 | mixup_alpha=0.4, 68 | num_epochs=5, 69 | ) 70 | 71 | main_metric = "sys_srcc" 72 | id_name = None 73 | 74 | 75 | inference = SimpleNamespace( 76 | save_path=Path("preds"), 77 | submit_save_path=Path("submissions"), 78 | num_tta=5, 79 | batch_size=8, 80 | # extend="tile", 81 | ) 82 | -------------------------------------------------------------------------------- /utmosv2/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2.dataset.multi_spec import MultiSpecDataset, MultiSpecExtDataset 2 | from utmosv2.dataset.ssl import SSLDataset, SSLExtDataset 3 | from utmosv2.dataset.ssl_multispec import SSLLMultiSpecExtDataset 4 | 5 | __all__ = [ 6 | "MultiSpecDataset", 7 | "MultiSpecExtDataset", 8 | "SSLLMultiSpecExtDataset", 9 | "SSLDataset", 10 | "SSLExtDataset", 11 | ] 12 | -------------------------------------------------------------------------------- /utmosv2/dataset/_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | from collections.abc import Callable 5 | from typing import TYPE_CHECKING 6 | 7 | import torch 8 | 9 | from 
utmosv2._settings._config import Config 10 | 11 | if TYPE_CHECKING: 12 | import pandas as pd 13 | 14 | from utmosv2.dataset._schema import DatasetSchema 15 | 16 | 17 | class _BaseDataset(torch.utils.data.Dataset, abc.ABC): 18 | def __init__( 19 | self, 20 | cfg: Config, 21 | data: "pd.DataFrame" | list[DatasetSchema], 22 | phase: str, 23 | transform: dict[str, Callable[[torch.Tensor], torch.Tensor]] | None = None, 24 | ): 25 | self.cfg = cfg 26 | self.data = data 27 | self.phase = phase 28 | self.transform = transform 29 | 30 | def __len__(self) -> int: 31 | return len(self.data) 32 | 33 | @abc.abstractmethod 34 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]: 35 | pass 36 | -------------------------------------------------------------------------------- /utmosv2/dataset/_schema.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | 6 | 7 | @dataclass 8 | class DatasetSchema: 9 | file_path: Path 10 | dataset: str 11 | mos: int | None = None 12 | -------------------------------------------------------------------------------- /utmosv2/dataset/_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import librosa 5 | import numpy as np 6 | 7 | from utmosv2._settings._config import Config 8 | 9 | 10 | def load_audio(cfg: Config, file: Path) -> np.ndarray: 11 | if file.suffix in [".wav", ".flac"]: 12 | y, sr = librosa.load(file, sr=None) 13 | y = librosa.resample(y, orig_sr=sr, target_sr=cfg.sr) 14 | else: 15 | y = np.load(file) 16 | return y 17 | 18 | 19 | def extend_audio(cfg: Config, y: np.ndarray, length: int, type: str) -> np.ndarray: 20 | if y.shape[0] > length: 21 | return y 22 | elif type == "tile": 23 | n = length // y.shape[0] + 1 24 | y = np.tile(y, n) 25 | return y 26 | else: 27 | raise NotImplementedError 28 | 29 
| 30 | def select_random_start(y: np.ndarray, length: int) -> np.ndarray: 31 | start = np.random.randint(0, y.shape[0] - length) 32 | return y[start : start + length] 33 | 34 | 35 | def get_dataset_map(cfg: Config) -> dict[str, int]: 36 | if cfg.data_config: 37 | with open(cfg.data_config, "r") as f: 38 | datasets = json.load(f) 39 | return {d["name"]: i for i, d in enumerate(datasets["data"])} 40 | else: 41 | return { 42 | "bvcc": 0, 43 | "sarulab": 1, 44 | "blizzard2008": 2, 45 | "blizzard2009": 3, 46 | "blizzard2010-EH1": 4, 47 | "blizzard2010-EH2": 5, 48 | "blizzard2010-ES1": 6, 49 | "blizzard2010-ES3": 7, 50 | "blizzard2011": 8, 51 | "somos": 9, 52 | } 53 | 54 | 55 | def get_dataset_num(cfg: Config) -> int: 56 | return len(get_dataset_map(cfg)) 57 | -------------------------------------------------------------------------------- /utmosv2/dataset/multi_spec.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from typing import TYPE_CHECKING 5 | 6 | import librosa 7 | import numpy as np 8 | import torch 9 | 10 | from utmosv2._settings._config import Config 11 | from utmosv2.dataset._base import _BaseDataset 12 | from utmosv2.dataset._utils import ( 13 | extend_audio, 14 | get_dataset_map, 15 | load_audio, 16 | select_random_start, 17 | ) 18 | from utmosv2.preprocess._preprocess import remove_silent_section 19 | 20 | if TYPE_CHECKING: 21 | import pandas as pd 22 | 23 | from utmosv2.dataset._schema import DatasetSchema 24 | 25 | 26 | class MultiSpecDataset(_BaseDataset): 27 | """ 28 | Dataset class for mel-spectrogram feature extractor. This class is responsible for 29 | loading audio data, generating multiple spectrograms for each sample, and 30 | applying the necessary transformations. 31 | 32 | Args: 33 | cfg (SimpleNamespace): The configuration object containing dataset and model settings. 
34 | data (list[DatasetSchema] | pd.DataFrame): The dataset containing file paths and labels. 35 | phase (str): The phase of the dataset, either "train" or any other phase (e.g., "valid"). 36 | transform (str, dict[Callable[[torch.Tensor], torch.Tensor]] | None): Transformation function to apply to spectrograms. 37 | """ 38 | 39 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]: 40 | """ 41 | Get the spectrogram and target MOS for a given index. 42 | 43 | Args: 44 | idx (int): Index of the sample. 45 | 46 | Returns: 47 | tuple: The spectrogram (torch.Tensor) and target MOS (torch.Tensor) for the sample. 48 | """ 49 | row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx] 50 | file = row.file_path 51 | y = load_audio(self.cfg, file) 52 | if ( 53 | hasattr(self.cfg.dataset, "remove_silent_section") 54 | and self.cfg.dataset.remove_silent_section 55 | ): 56 | y = remove_silent_section(y) 57 | specs = [] 58 | length = int(self.cfg.dataset.spec_frames.frame_sec * self.cfg.sr) 59 | y = extend_audio(self.cfg, y, length, type=self.cfg.dataset.spec_frames.extend) 60 | for _ in range(self.cfg.dataset.spec_frames.num_frames): 61 | y1 = select_random_start(y, length) 62 | for spec_cfg in self.cfg.dataset.specs: 63 | spec = _make_spctrogram(self.cfg, spec_cfg, y1) 64 | if self.cfg.dataset.spec_frames.mixup_inner: 65 | y2 = select_random_start(y, length) 66 | spec2 = _make_spctrogram(self.cfg, spec_cfg, y2) 67 | lmd = np.random.beta( 68 | self.cfg.dataset.spec_frames.mixup_alpha, 69 | self.cfg.dataset.spec_frames.mixup_alpha, 70 | ) 71 | spec = lmd * spec + (1 - lmd) * spec2 72 | spec = np.stack([spec, spec, spec], axis=0) 73 | # spec = np.transpose(spec, (1, 2, 0)) 74 | spec_tensor = torch.tensor(spec, dtype=torch.float32) 75 | phase = "train" if self.phase == "train" else "valid" 76 | assert self.transform is not None, "Transform must be provided." 
77 | spec_tensor = self.transform[phase](spec_tensor) 78 | specs.append(spec_tensor) 79 | spec_tensor = torch.stack(specs).float() 80 | 81 | target = row.mos or 0.0 82 | target = torch.tensor(target, dtype=torch.float32) 83 | 84 | return spec_tensor, target 85 | 86 | 87 | class MultiSpecExtDataset(MultiSpecDataset): 88 | """ 89 | Dataset class for mel-spectrogram feature extractor with data-domain embedding. 90 | 91 | Args: 92 | cfg (SimpleNamespace | ModuleType): 93 | The configuration object containing dataset and model settings. 94 | data (pd.DataFrame | list[DatasetSchema]): 95 | The dataset containing file paths and labels. 96 | phase (str): 97 | The phase of the dataset, either "train" or any other phase (e.g., "valid"). 98 | transform (dict[str, Callable[[torch.Tensor], torch.Tensor]] | None): 99 | Transformation function to apply to spectrograms. 100 | """ 101 | 102 | def __init__( 103 | self, 104 | cfg: Config, 105 | data: "pd.DataFrame" | list[DatasetSchema], 106 | phase: str, 107 | transform: dict[str, Callable[[torch.Tensor], torch.Tensor]] | None = None, 108 | ): 109 | super().__init__(cfg, data, phase, transform) 110 | self.dataset_map = get_dataset_map(cfg) 111 | 112 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]: 113 | """ 114 | Get the spectrogram, data-domain embedding, and target MOS for a given index. 115 | 116 | Args: 117 | idx (int): Index of the sample. 118 | 119 | Returns: 120 | tuple: A tuple containing the generated spectrogram (torch.Tensor), data-domain embedding (torch.Tensor), 121 | and target MOS (torch.Tensor). 
122 | """ 123 | spec, target = super().__getitem__(idx) 124 | row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx] 125 | 126 | d = np.zeros(len(self.dataset_map)) 127 | d[self.dataset_map[row.dataset]] = 1 128 | dt = torch.tensor(d, dtype=torch.float32) 129 | 130 | return spec, dt, target 131 | 132 | 133 | def _make_spctrogram(cfg: Config, spec_cfg: Config, y: np.ndarray) -> np.ndarray: 134 | if spec_cfg.mode == "melspec": 135 | return _make_melspec(cfg, spec_cfg, y) 136 | elif spec_cfg.mode == "stft": 137 | return _make_stft(cfg, spec_cfg, y) 138 | else: 139 | raise NotImplementedError 140 | 141 | 142 | def _make_melspec(cfg: Config, spec_cfg: Config, y: np.ndarray) -> np.ndarray: 143 | spec = librosa.feature.melspectrogram( 144 | y=y, 145 | sr=cfg.sr, 146 | n_fft=spec_cfg.n_fft, 147 | hop_length=spec_cfg.hop_length, 148 | n_mels=spec_cfg.n_mels, 149 | win_length=spec_cfg.win_length, 150 | ) 151 | spec = librosa.power_to_db(spec, ref=np.max) 152 | if spec_cfg.norm is not None: 153 | spec = (spec + spec_cfg.norm) / spec_cfg.norm 154 | return spec 155 | 156 | 157 | def _make_stft(cfg: Config, spec_cfg: Config, y: np.ndarray) -> np.ndarray: 158 | spec = librosa.stft(y=y, n_fft=spec_cfg.n_fft, hop_length=spec_cfg.hop_length) 159 | spec = np.abs(spec) 160 | spec = librosa.amplitude_to_db(spec) 161 | return spec 162 | -------------------------------------------------------------------------------- /utmosv2/dataset/ssl.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from utmosv2._settings._config import Config 9 | from utmosv2.dataset._base import _BaseDataset 10 | from utmosv2.dataset._utils import ( 11 | extend_audio, 12 | get_dataset_map, 13 | load_audio, 14 | select_random_start, 15 | ) 16 | from utmosv2.preprocess._preprocess import remove_silent_section 17 | 18 | if 
TYPE_CHECKING: 19 | import pandas as pd 20 | 21 | from utmosv2.dataset._schema import DatasetSchema 22 | 23 | 24 | class SSLDataset(_BaseDataset): 25 | """ 26 | Dataset class for SSL (Self-Supervised Learning) feature extractor. 27 | This class handles audio loading, extending, and random selection of a segment from the audio. 28 | 29 | Args: 30 | cfg (SimpleNamespace | ModuleType): 31 | The configuration object containing dataset and model settings. 32 | data (pd.DataFrame | list[DatasetSchema]): 33 | The dataset containing file paths and MOS labels. 34 | phase (str): 35 | The phase of the dataset, either "train" or any other phase (e.g., "valid"). 36 | transform (dict[str, Callable[[torch.Tensor], torch.Tensor]] | None): 37 | Transformation function to apply to spectrograms. 38 | """ 39 | 40 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]: 41 | """ 42 | Get the processed audio, and target MOS for a given index. 43 | 44 | Args: 45 | idx (int): Index of the sample. 46 | Returns: 47 | tuple: A tuple containing the processed audio (torch.Tensor), and target MOS (torch.Tensor). 48 | """ 49 | row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx] 50 | file = row.file_path 51 | y = load_audio(self.cfg, file) 52 | if ( 53 | hasattr(self.cfg.dataset, "remove_silent_section") 54 | and self.cfg.dataset.remove_silent_section 55 | ): 56 | y = remove_silent_section(y) 57 | length = int(self.cfg.dataset.ssl.duration * self.cfg.sr) 58 | y = extend_audio(self.cfg, y, length, type="tile") 59 | y = select_random_start(y, length) 60 | 61 | target = row.mos or 0.0 62 | target = torch.tensor(target, dtype=torch.float32) 63 | 64 | return torch.from_numpy(y), target 65 | 66 | 67 | class SSLExtDataset(SSLDataset): 68 | """ 69 | Dataset class for SSL (Self-Supervised Learning) feature extractor with data-domein embedding. 70 | 71 | Args: 72 | cfg (SimpleNamespace | ModuleType): 73 | The configuration object containing dataset and model settings. 
74 | data (pd.DataFrame | list[DatasetSchema]): 75 | The dataset containing file paths and MOS labels. 76 | phase (str): 77 | The phase of the dataset, either "train" or any other phase (e.g., "valid"). 78 | """ 79 | 80 | def __init__( 81 | self, cfg: Config, data: "pd.DataFrame" | list[DatasetSchema], phase: str 82 | ): 83 | super().__init__(cfg, data, phase) 84 | self.dataset_map = get_dataset_map(cfg) 85 | 86 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]: 87 | """ 88 | Get the processed audio, data-domain embedding, and target MOS for a given index. 89 | 90 | Args: 91 | idx (int): Index of the sample. 92 | Returns: 93 | tuple: A tuple containing the processed audio (torch.Tensor), data-domain embedding (torch.Tensor), 94 | and target MOS (torch.Tensor). 95 | """ 96 | y, target = super().__getitem__(idx) 97 | row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx] 98 | 99 | d = np.zeros(len(self.dataset_map)) 100 | d[self.dataset_map[row.dataset]] = 1 101 | dt = torch.tensor(d, dtype=torch.float32) 102 | 103 | return y, dt, target 104 | -------------------------------------------------------------------------------- /utmosv2/dataset/ssl_multispec.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from typing import TYPE_CHECKING 5 | 6 | import torch 7 | 8 | from utmosv2._settings._config import Config 9 | from utmosv2.dataset import MultiSpecDataset, SSLExtDataset 10 | from utmosv2.dataset._base import _BaseDataset 11 | 12 | if TYPE_CHECKING: 13 | import pandas as pd 14 | 15 | from utmosv2.dataset._schema import DatasetSchema 16 | 17 | 18 | class SSLLMultiSpecExtDataset(_BaseDataset): 19 | """ 20 | Dataset class that combines both SSL (Self-Supervised Learning) and Multi-Spectrogram datasets. 
21 | This dataset uses both SSLExtDataset and MultiSpecDataset to provide different representations 22 | of the same audio sample. 23 | 24 | Args: 25 | cfg (SimpleNamespace | ModuleType): 26 | The configuration object containing dataset and model settings. 27 | data (pd.DataFrame | list[DatasetSchema]): 28 | The dataset containing file paths and MOS labels. 29 | phase (str): 30 | The phase of the dataset, either "train" or any other phase (e.g., "valid"). 31 | transform (dict[str, Callable[[torch.Tensor], torch.Tensor]] | None): 32 | Transformation function to apply to spectrograms. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | cfg: Config, 38 | data: "pd.DataFrame" | list[DatasetSchema], 39 | phase: str, 40 | transform: dict[str, Callable[[torch.Tensor], torch.Tensor]] | None = None, 41 | ): 42 | super().__init__(cfg, data, phase, transform) 43 | self.ssl = SSLExtDataset(cfg, data, phase) 44 | self.multi_spec = MultiSpecDataset(cfg, data, phase, transform) 45 | 46 | def __len__(self) -> int: 47 | return len(self.data) 48 | 49 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]: 50 | """ 51 | Get data for SSL feature extractor, mel-spectrogram feature extractor, data-domain embedding, and target MOS for a given index. 52 | 53 | Args: 54 | idx (int): Index of the sample. 55 | 56 | Returns: 57 | tuple: data for SSL feature extractor (torch.Tensor), data for mel-spectrogram feature extractor (torch.Tensor), 58 | data-domain id (torch.Tensor), and target MOS (torch.Tensor). 
59 | """ 60 | x1, d, target = self.ssl[idx] 61 | x2, _ = self.multi_spec[idx] 62 | 63 | return x1, x2, d, target 64 | -------------------------------------------------------------------------------- /utmosv2/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2.loss._losses import CombinedLoss, PairwizeDiffLoss 2 | 3 | __all__ = ["PairwizeDiffLoss", "CombinedLoss"] 4 | -------------------------------------------------------------------------------- /utmosv2/loss/_losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class PairwizeDiffLoss(nn.Module): 9 | """ 10 | Pairwise difference loss function for comparing input and target tensors. 11 | The loss is based on the difference between pairs of inputs and pairs of targets, 12 | with a specified margin and norm ("l1" or "l2_squared"). 13 | """ 14 | 15 | def __init__(self, margin: float = 0.2, norm: str = "l1"): 16 | """ 17 | Initialize the PairwizeDiffLoss with the specified margin and norm. 18 | 19 | Args: 20 | margin (float): 21 | The margin value used for the loss function. Defaults to 0.2. 22 | norm (str): 23 | The norm to use for the difference calculation. Must be "l1" or "l2_squared". Defaults to "l1". 24 | """ 25 | super().__init__() 26 | self.margin = margin 27 | self.norm = norm 28 | 29 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 30 | """ 31 | Compute the pairwise difference loss between input and target tensors. 32 | 33 | Args: 34 | input (torch.Tensor): The input tensor. 35 | target (torch.Tensor): The target tensor. 36 | 37 | Returns: 38 | torch.Tensor: The computed loss. 
39 | """ 40 | s = input.unsqueeze(1) - input.unsqueeze(0) 41 | t = target.unsqueeze(1) - target.unsqueeze(0) 42 | if self.norm not in ["l1", "l2_squared"]: 43 | raise ValueError( 44 | f'Unknown norm: {self.norm}. Must be one of ["l1", "l2_squared"]' 45 | ) 46 | norm_fn = { 47 | "l1": torch.abs, 48 | "l2_squared": lambda x: x**2, 49 | }[self.norm] 50 | loss = F.relu(norm_fn(s - t) - self.margin) # type: ignore 51 | return loss.mean().div(2) 52 | 53 | 54 | class CombinedLoss(nn.Module): 55 | """ 56 | A combined loss function that allows for multiple loss functions to be weighted and combined. 57 | 58 | Args: 59 | weighted_losses (list[tuple[nn.Module, float]]): 60 | A list of loss functions and their associated weights. 61 | """ 62 | 63 | def __init__(self, weighted_losses: list[tuple[nn.Module, float]]): 64 | super().__init__() 65 | self.weighted_losses = weighted_losses 66 | 67 | def forward( 68 | self, input: torch.Tensor, target: torch.Tensor 69 | ) -> list[tuple[float, torch.Tensor]]: 70 | """ 71 | Compute the weighted loss for each loss function in the list. 72 | 73 | Args: 74 | input (torch.Tensor): The input tensor. 75 | target (torch.Tensor): The target tensor. 76 | 77 | Returns: 78 | list[tuple[float, torch.Tensor]]: 79 | A list of tuples where each contains a weight and the corresponding computed loss. 
80 | """ 81 | return [(w, loss(input, target)) for loss, w in self.weighted_losses] 82 | -------------------------------------------------------------------------------- /utmosv2/model/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2.model.multi_spec import MultiSpecExtModel, MultiSpecModelV2 2 | from utmosv2.model.ssl import SSLExtModel 3 | from utmosv2.model.ssl_multispec import SSLMultiSpecExtModelV1, SSLMultiSpecExtModelV2 4 | 5 | __all__ = [ 6 | "MultiSpecExtModel", 7 | "MultiSpecModelV2", 8 | "SSLExtModel", 9 | "SSLMultiSpecExtModelV1", 10 | "SSLMultiSpecExtModelV2", 11 | ] 12 | -------------------------------------------------------------------------------- /utmosv2/model/ssl.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from transformers import AutoFeatureExtractor, AutoModel 7 | 8 | from utmosv2._settings._config import Config 9 | from utmosv2.dataset._utils import get_dataset_num 10 | 11 | 12 | class _SSLEncoder(nn.Module): 13 | def __init__(self, sr: int, model_name: str, freeze: bool): 14 | super().__init__() 15 | self.sr = sr 16 | self.processor = AutoFeatureExtractor.from_pretrained(model_name) 17 | self.model = AutoModel.from_pretrained(model_name) 18 | if freeze: 19 | for param in self.model.parameters(): 20 | param.requires_grad = False 21 | 22 | def forward(self, x: tuple[torch.Tensor]) -> tuple[torch.Tensor]: 23 | x = self.processor( 24 | [t.cpu().numpy() for t in x], 25 | sampling_rate=self.sr, 26 | return_tensors="pt", 27 | ).to(self.model.device) 28 | outputs = self.model(**x, output_hidden_states=True) 29 | return outputs.hidden_states 30 | 31 | 32 | class SSLExtModel(nn.Module): 33 | """ 34 | A self-supervised learning (SSL) model extended with data-domain id. 
35 | This model uses an encoder to process input data, applies attention layers if configured, 36 | and combines the features with data-domain embeddings before classification. 37 | 38 | Args: 39 | cfg (SimpleNamespace | ModuleType): 40 | Configuration object containing model and dataset settings. 41 | name (str | None): 42 | Optional name for the SSL encoder. Defaults to the name specified in `cfg.model.ssl.name`. 43 | """ 44 | 45 | def __init__(self, cfg: Config, name: str | None = None): 46 | super().__init__() 47 | self.cfg = cfg 48 | self.encoder = _SSLEncoder( 49 | cfg.sr, name or cfg.model.ssl.name, cfg.model.ssl.freeze 50 | ) 51 | hidden_num, in_features = get_ssl_output_shape(name or cfg.model.ssl.name) 52 | self.weights = nn.Parameter(F.softmax(torch.randn(hidden_num), dim=0)) 53 | if cfg.model.ssl.attn: 54 | self.attn = nn.ModuleList( 55 | [ 56 | nn.MultiheadAttention( 57 | embed_dim=in_features, 58 | num_heads=8, 59 | dropout=0.2, 60 | batch_first=True, 61 | ) 62 | for _ in range(cfg.model.ssl.attn) 63 | ] 64 | ) 65 | self.num_dataset = get_dataset_num(cfg) 66 | self.fc: nn.Linear | nn.Identity = nn.Linear( 67 | in_features * 2 + self.num_dataset, cfg.model.ssl.num_classes 68 | ) 69 | 70 | def forward(self, xt: tuple[torch.Tensor], d: torch.Tensor) -> torch.Tensor: 71 | """ 72 | Forward pass of the SSLExtModel. 73 | 74 | Args: 75 | x (torch.Tensor): 76 | Input tensor representing the features to be processed by the SSL encoder. 77 | d (torch.Tensor): 78 | Dataset-specific information tensor. 79 | 80 | Returns: 81 | torch.Tensor: 82 | Output tensor after applying the SSL encoder, attention (if configured), and fully connected layers. 
83 | """ 84 | xt = self.encoder(xt) 85 | x: torch.Tensor = sum([t * w for t, w in zip(xt, self.weights)]) 86 | if self.cfg.model.ssl.attn: 87 | y = x 88 | for attn in self.attn: 89 | y, _ = attn(y, y, y) 90 | x = torch.cat([torch.mean(y, dim=1), torch.max(x, dim=1)[0]], dim=1) 91 | else: 92 | x = torch.cat([torch.mean(x, dim=1), torch.max(x, dim=1)[0]], dim=1) 93 | x = self.fc(torch.cat([x, d], dim=1)) 94 | return x 95 | 96 | 97 | def get_ssl_output_shape(name: str) -> tuple[int, int]: 98 | if name in [ 99 | "facebook/w2v-bert-2.0", 100 | "facebook/wav2vec2-large", 101 | "facebook/wav2vec2-large-robust", 102 | "facebook/wav2vec2-large-960h", 103 | "microsoft/wavlm-large", 104 | "facebook/wav2vec2-large-xlsr-53", 105 | ]: 106 | return 25, 1024 107 | elif name in [ 108 | "facebook/hubert-base-ls960", 109 | "facebook/data2vec-audio-base-960h", 110 | "microsoft/wavlm-base", 111 | "microsoft/wavlm-base-plus", 112 | "microsoft/wavlm-base-plus-sv", 113 | "facebook/wav2vec2-base", 114 | ]: 115 | return 13, 768 116 | else: 117 | raise NotImplementedError 118 | -------------------------------------------------------------------------------- /utmosv2/model/ssl_multispec.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from utmosv2._settings._config import Config 7 | from utmosv2.dataset._utils import get_dataset_num 8 | from utmosv2.model import MultiSpecExtModel, MultiSpecModelV2, SSLExtModel 9 | 10 | 11 | class SSLMultiSpecExtModelV1(nn.Module): 12 | def __init__(self, cfg: Config): 13 | super().__init__() 14 | self.cfg = cfg 15 | self.ssl = SSLExtModel(cfg) 16 | self.spec_long = MultiSpecModelV2(cfg) 17 | self.ssl.load_state_dict( 18 | torch.load( 19 | f"outputs/{cfg.model.ssl_spec.ssl_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth" 20 | ) 21 | ) 22 | self.spec_long.load_state_dict( 23 | torch.load( 24 | 
f"outputs/{cfg.model.ssl_spec.spec_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth" 25 | ) 26 | ) 27 | if cfg.model.ssl_spec.freeze: 28 | for param in self.ssl.parameters(): 29 | param.requires_grad = False 30 | for param in self.spec_long.parameters(): 31 | param.requires_grad = False 32 | self.ssl.fc = nn.Identity() 33 | self.spec_long.fc = nn.Identity() 34 | 35 | self.num_dataset = get_dataset_num(cfg) 36 | 37 | self.fc = nn.Linear( 38 | cast(int, self.ssl.fc.in_features) 39 | + cast(int, self.spec_long.fc.in_features) 40 | + self.num_dataset, 41 | cfg.model.ssl_spec.num_classes, 42 | ) 43 | 44 | def forward( 45 | self, x1: torch.Tensor, x2: torch.Tensor, d: torch.Tensor 46 | ) -> torch.Tensor: 47 | x1 = self.ssl(x1, torch.zeros(x1.shape[0], self.num_dataset).to(x1.device)) 48 | x2 = self.spec_long(x2) 49 | x = torch.cat([x1, x2, d], dim=1) 50 | x = self.fc(x) 51 | return x 52 | 53 | 54 | class SSLMultiSpecExtModelV2(nn.Module): 55 | def __init__(self, cfg: Config): 56 | super().__init__() 57 | self.cfg = cfg 58 | self.ssl = SSLExtModel(cfg) 59 | self.spec_long = MultiSpecExtModel(cfg) 60 | if cfg.model.ssl_spec.ssl_weight is not None and cfg.phase == "train": 61 | self.ssl.load_state_dict( 62 | torch.load( 63 | f"outputs/{cfg.model.ssl_spec.ssl_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth" 64 | ) 65 | ) 66 | if cfg.model.ssl_spec.spec_weight is not None and cfg.phase == "train": 67 | self.spec_long.load_state_dict( 68 | torch.load( 69 | f"outputs/{cfg.model.ssl_spec.spec_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth" 70 | ) 71 | ) 72 | if cfg.model.ssl_spec.freeze: 73 | for param in self.ssl.parameters(): 74 | param.requires_grad = False 75 | for param in self.spec_long.parameters(): 76 | param.requires_grad = False 77 | ssl_input = self.ssl.fc.in_features 78 | spec_long_input = self.spec_long.fc.in_features 79 | self.ssl.fc = nn.Identity() 80 | self.spec_long.fc = nn.Identity() 81 | 82 | self.num_dataset = 
get_dataset_num(cfg) 83 | 84 | self.fc = nn.Linear( 85 | cast(int, ssl_input) + cast(int, spec_long_input) + self.num_dataset, 86 | cfg.model.ssl_spec.num_classes, 87 | ) 88 | 89 | def forward( 90 | self, x1: torch.Tensor, x2: torch.Tensor, d: torch.Tensor 91 | ) -> torch.Tensor: 92 | x1 = self.ssl(x1, torch.zeros(x1.shape[0], self.num_dataset).to(x1.device)) 93 | x2 = self.spec_long( 94 | x2, torch.zeros(x1.shape[0], self.num_dataset).to(x1.device) 95 | ) 96 | x = torch.cat([x1, x2, d], dim=1) 97 | x = self.fc(x) 98 | return x 99 | -------------------------------------------------------------------------------- /utmosv2/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2.preprocess._preprocess import ( 2 | add_sys_mean, 3 | preprocess, 4 | preprocess_test, 5 | remove_silent_section, 6 | ) 7 | 8 | __all__ = ["add_sys_mean", "preprocess", "preprocess_test", "remove_silent_section"] 9 | -------------------------------------------------------------------------------- /utmosv2/runner/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2.runner._inference import run_inference 2 | from utmosv2.runner._train import run_train, train_1epoch, validate_1epoch 3 | 4 | __all__ = [ 5 | "run_train", 6 | "train_1epoch", 7 | "validate_1epoch", 8 | "run_inference", 9 | ] 10 | -------------------------------------------------------------------------------- /utmosv2/runner/_inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import numpy as np 6 | import torch 7 | from torch.cuda.amp import autocast 8 | from tqdm import tqdm 9 | 10 | from utmosv2._settings._config import Config 11 | from utmosv2.utils import calc_metrics, print_metrics 12 | 13 | if TYPE_CHECKING: 14 | import pandas as pd 15 | 16 | 17 | def run_inference( 18 | cfg: 
Config, 19 | model: torch.nn.Module, 20 | test_dataloader: torch.utils.data.DataLoader, 21 | cycle: int, 22 | test_data: "pd.DataFrame", 23 | device: torch.device, 24 | ) -> tuple[np.ndarray, dict[str, float] | None]: 25 | """ 26 | Run inference on the test dataset using the provided model. 27 | 28 | Args: 29 | cfg (SimpleNamespace | ModuleType): 30 | Configuration object containing inference settings. 31 | It includes settings for test-time augmentation (TTA) and reproducibility. 32 | model (torch.nn.Module): 33 | The trained model to be used for inference. 34 | test_dataloader (torch.utils.data.DataLoader): 35 | Dataloader for the test dataset. 36 | cycle (int): 37 | Current cycle of test-time augmentation (TTA) if used. 38 | test_data (pd.DataFrame): 39 | DataFrame containing test data, used for metric calculation if reproducibility is enabled. 40 | device (torch.device): 41 | Device to run inference on (e.g., 'cuda' or 'cpu'). 42 | 43 | Returns: 44 | tuple[np.ndarray, dict[str, float] | None]: 45 | - test_preds: Array containing the model's predictions for the test dataset. 46 | - test_metrics: Dictionary containing the calculated metrics if reproducibility is enabled; otherwise, None. 
47 | """ 48 | model.eval() 49 | test_preds_ls = [] 50 | pbar = tqdm( 51 | test_dataloader, 52 | total=len(test_dataloader), 53 | desc=f" [Inference] ({cycle + 1}/{cfg.inference.num_tta})", 54 | ) 55 | 56 | with torch.no_grad(): 57 | for t in pbar: 58 | x = t[:-1] 59 | x = [t.to(device, non_blocking=True) for t in x] 60 | with autocast(): 61 | output = model(*x).squeeze(1) 62 | test_preds_ls.append(output.cpu().numpy()) 63 | test_preds = np.concatenate(test_preds_ls) 64 | if cfg.reproduce: 65 | test_metrics = calc_metrics(test_data, test_preds) 66 | print_metrics(test_metrics) 67 | else: 68 | test_metrics = None 69 | 70 | return test_preds, test_metrics 71 | -------------------------------------------------------------------------------- /utmosv2/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2.transform._xymasking import XYMasking 2 | 3 | __all__ = ["XYMasking"] 4 | -------------------------------------------------------------------------------- /utmosv2/transform/_xymasking.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import numpy as np 6 | 7 | if TYPE_CHECKING: 8 | import torch 9 | 10 | 11 | class XYMasking: 12 | """ 13 | Apply random rectangular masks to an image along the x and y axes. This augmentation 14 | is useful for randomly masking parts of an image during training to improve robustness. 15 | 16 | Args: 17 | num_masks_x (int | tuple[int, int]): 18 | The number of masks to apply along the x-axis. 19 | If a tuple is provided, a random number in the range will be used. 20 | num_masks_y (int | tuple[int, int]): 21 | The number of masks to apply along the y-axis. 22 | If a tuple is provided, a random number in the range will be used. 23 | mask_x_length (int | tuple[int, int]): 24 | The length of each mask along the x-axis. 
25 | If a tuple is provided, a random length in the range will be used. 26 | mask_y_length (int | tuple[int, int]): 27 | The length of each mask along the y-axis. 28 | If a tuple is provided, a random length in the range will be used. 29 | fill_value (int): 30 | The value to fill the masked areas with. 31 | p (float): 32 | The probability of applying the masking. Defaults to 1.0 (always apply masking). 33 | """ 34 | 35 | def __init__( 36 | self, 37 | num_masks_x: int | tuple[int, int], 38 | num_masks_y: int | tuple[int, int], 39 | mask_x_length: int | tuple[int, int], 40 | mask_y_length: int | tuple[int, int], 41 | fill_value: int, 42 | p: float = 1.0, 43 | ): 44 | self.num_masks_x = num_masks_x 45 | self.num_masks_y = num_masks_y 46 | self.mask_x_length = mask_x_length 47 | self.mask_y_length = mask_y_length 48 | self.fill_value = fill_value 49 | self.p = p 50 | 51 | def __call__(self, img: "torch.Tensor") -> "torch.Tensor": 52 | """ 53 | Apply the XY masking to the given image. 54 | 55 | Args: 56 | img (torch.Tensor): The input image tensor of shape (channels, width, height). 57 | 58 | Returns: 59 | torch.Tensor: The image tensor with masks applied along the x and y axes. 
60 | """ 61 | if np.random.rand() < self.p: 62 | return img 63 | _, width, height = img.shape 64 | num_masks_x = ( 65 | np.random.randint(*self.num_masks_x) 66 | if isinstance(self.num_masks_x, tuple) 67 | else self.num_masks_x 68 | ) 69 | for _ in range(num_masks_x): 70 | mask_x_length = ( 71 | np.random.randint(*self.mask_x_length) 72 | if isinstance(self.mask_x_length, tuple) 73 | else self.mask_x_length 74 | ) 75 | x = np.random.randint(0, width - mask_x_length) 76 | img[:, :, x : x + mask_x_length] = self.fill_value 77 | 78 | num_masks_y = ( 79 | np.random.randint(*self.num_masks_y) 80 | if isinstance(self.num_masks_y, tuple) 81 | else self.num_masks_y 82 | ) 83 | for _ in range(num_masks_y): 84 | mask_y_length = ( 85 | np.random.randint(*self.mask_y_length) 86 | if isinstance(self.mask_y_length, tuple) 87 | else self.mask_y_length 88 | ) 89 | y = np.random.randint(0, height - mask_y_length) 90 | img[:, y : y + mask_y_length, :] = self.fill_value 91 | 92 | return img 93 | -------------------------------------------------------------------------------- /utmosv2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2.utils._pure import ( 2 | get_dataloader, 3 | get_loss, 4 | get_optimizer, 5 | get_scheduler, 6 | print_metrics, 7 | save_oof_preds, 8 | split_data, 9 | ) 10 | from utmosv2.utils._task_dependents import ( 11 | calc_metrics, 12 | get_data, 13 | get_dataset, 14 | get_inference_data, 15 | get_metrics, 16 | get_model, 17 | get_train_data, 18 | make_submission_file, 19 | save_preds, 20 | save_test_preds, 21 | show_inference_data, 22 | ) 23 | from utmosv2.utils._download import download_pretrained_weights_from_hf 24 | 25 | __all__ = [ 26 | "get_dataloader", 27 | "get_loss", 28 | "get_optimizer", 29 | "get_scheduler", 30 | "print_metrics", 31 | "save_oof_preds", 32 | "split_data", 33 | "calc_metrics", 34 | "get_data", 35 | "get_dataset", 36 | "get_inference_data", 37 | "get_train_data", 38 | 
"get_metrics", 39 | "get_model", 40 | "make_submission_file", 41 | "save_preds", 42 | "save_test_preds", 43 | "show_inference_data", 44 | "download_pretrained_weights_from_hf", 45 | ] 46 | -------------------------------------------------------------------------------- /utmosv2/utils/_constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | _UTMOSV2_CHACHE = Path(os.getenv("UTMOSV2_CHACHE", "~/.cache/utmosv2")).expanduser() 5 | -------------------------------------------------------------------------------- /utmosv2/utils/_download.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | from utmosv2.utils._constants import _UTMOSV2_CHACHE 4 | 5 | 6 | def download_pretrained_weights_from_github(cfg_name: str) -> None: 7 | if cfg_name != "fusion_stage3": 8 | raise ValueError(f"{cfg_name} is not stored.") 9 | print(f"Downloading pretrained weights for `{cfg_name}`...") 10 | try: 11 | subprocess.run( 12 | [ 13 | "git", 14 | "clone", 15 | "--filter=blob:none", 16 | "--no-checkout", 17 | "https://github.com/sarulab-speech/UTMOSv2.git", 18 | _UTMOSV2_CHACHE.as_posix(), 19 | ], 20 | check=True, 21 | ) 22 | subprocess.run( 23 | ["git", "sparse-checkout", "set", "models"], 24 | cwd=_UTMOSV2_CHACHE, 25 | check=True, 26 | ) 27 | subprocess.run( 28 | ["git", "checkout"], 29 | cwd=_UTMOSV2_CHACHE, 30 | check=True, 31 | ) 32 | except subprocess.CalledProcessError as e: 33 | print(f"Failed to download pretrained weights: {e}") 34 | print("Done.") 35 | 36 | 37 | def download_pretrained_weights_from_hf(cfg_name: str, now_fold: int) -> None: 38 | if cfg_name != "fusion_stage3": 39 | raise ValueError(f"{cfg_name} is not stored.") 40 | print(f"Downloading pretrained weights for `{cfg_name}`...") 41 | url = f"https://huggingface.co/sarulab-speech/UTMOSv2/resolve/main/fold{now_fold}_s42_best_model.pth" 42 | try: 43 | subprocess.run( 44 | [ 45 
| "wget", 46 | "-P", 47 | (_UTMOSV2_CHACHE / "models" / cfg_name).as_posix(), 48 | url, 49 | ] 50 | ) 51 | except subprocess.CalledProcessError as e: 52 | print(f"Failed to download pretrained weights: {e}") 53 | print("Done.") 54 | -------------------------------------------------------------------------------- /utmosv2/utils/_pure/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2.utils._pure.initializers import ( 2 | get_dataloader, 3 | get_loss, 4 | get_optimizer, 5 | get_scheduler, 6 | ) 7 | from utmosv2.utils._pure.metrics import print_metrics 8 | from utmosv2.utils._pure.save import save_oof_preds 9 | from utmosv2.utils._pure.split import split_data 10 | 11 | __all__ = [ 12 | "get_dataloader", 13 | "get_loss", 14 | "get_optimizer", 15 | "get_scheduler", 16 | "print_metrics", 17 | "save_oof_preds", 18 | "split_data", 19 | ] 20 | -------------------------------------------------------------------------------- /utmosv2/utils/_pure/initializers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | 7 | from utmosv2._settings._config import Config 8 | from utmosv2.loss import CombinedLoss, PairwizeDiffLoss 9 | 10 | 11 | def get_dataloader( 12 | cfg: Config, dataset: torch.utils.data.Dataset, phase: str 13 | ) -> torch.utils.data.DataLoader: 14 | """ 15 | Return a DataLoader for the specified dataset and phase. 16 | 17 | Args: 18 | cfg (Config): 19 | Configuration object providing `batch_size` and `num_workers`; the "test" phase uses `inference.batch_size` instead of `batch_size`. `pin_memory` is always True and is not read from the configuration. 20 | dataset (torch.utils.data.Dataset): 21 | The dataset to load data from. 22 | phase (str): 23 | The pipeline phase. Must be one of ["train", "valid", "test"]. 24 | 25 | Returns: 26 | torch.utils.data.DataLoader: A DataLoader for the given dataset and phase. Only the "train" loader shuffles.
27 | 28 | Raises: 29 | ValueError: If the phase is not one of ["train", "valid", "test"]. 30 | """ 31 | if phase == "train": 32 | return torch.utils.data.DataLoader( 33 | dataset, 34 | batch_size=cfg.batch_size, 35 | shuffle=True, 36 | num_workers=cfg.num_workers, 37 | pin_memory=True, 38 | ) 39 | elif phase == "valid": 40 | return torch.utils.data.DataLoader( 41 | dataset, 42 | batch_size=cfg.batch_size, 43 | shuffle=False, 44 | num_workers=cfg.num_workers, 45 | pin_memory=True, 46 | ) 47 | elif phase == "test": 48 | return torch.utils.data.DataLoader( 49 | dataset, 50 | batch_size=cfg.inference.batch_size, 51 | shuffle=False, 52 | num_workers=cfg.num_workers, 53 | pin_memory=True, 54 | ) 55 | else: 56 | raise ValueError(f"Phase must be one of [train, valid, test], but got {phase}") 57 | 58 | 59 | def _get_unit_loss(loss_cfg: Config) -> nn.Module: 60 | if loss_cfg.name == "pairwize_diff": 61 | return PairwizeDiffLoss(loss_cfg.margin, loss_cfg.norm) 62 | elif loss_cfg.name == "mse": 63 | return nn.MSELoss() 64 | else: 65 | raise NotImplementedError 66 | 67 | 68 | def _get_combined_loss(cfg: Config) -> nn.Module: 69 | if cfg.print_config: 70 | print( 71 | "Using losses: " 72 | + ", ".join([f"{loss_cfg.name} ({w})" for loss_cfg, w in cfg.loss]) 73 | ) 74 | weighted_losses = [(_get_unit_loss(loss_cfg), w) for loss_cfg, w in cfg.loss] 75 | return CombinedLoss(weighted_losses) 76 | 77 | 78 | def get_loss(cfg: Config) -> nn.Module: 79 | """ 80 | Return the appropriate loss function based on the configuration. 81 | 82 | Args: 83 | cfg (Config): 84 | Configuration object containing the loss settings. 85 | If `cfg.loss` is a list of `(loss_cfg, weight)` pairs, a weighted `CombinedLoss` is returned. 86 | Otherwise `cfg.loss` is a single loss config. Supported loss names are "pairwize_diff" and "mse"; any other name raises NotImplementedError. 87 | 88 | Returns: 89 | nn.Module: The configured loss function, either a single loss or a combined loss module.
90 | """ 91 | if isinstance(cfg.loss, list): 92 | return _get_combined_loss(cfg) 93 | else: 94 | return _get_unit_loss(cfg.loss) 95 | 96 | 97 | def get_optimizer(cfg: Config, model: nn.Module) -> optim.Optimizer: 98 | """ 99 | Return the optimizer based on the configuration settings. 100 | 101 | Args: 102 | cfg (Config): 103 | Configuration object containing optimizer settings. 104 | The optimizer name and learning rate are specified in `cfg.optimizer`; "adamw" and "sgd" also read `cfg.optimizer.weight_decay`, while "adam" ignores it. 105 | model (nn.Module): 106 | The model whose parameters will be optimized. 107 | 108 | Returns: 109 | optim.Optimizer: The configured optimizer (Adam, AdamW, or SGD). 110 | 111 | Raises: 112 | NotImplementedError: If the specified optimizer is not implemented. 113 | """ 114 | if cfg.print_config: 115 | print(f"Using optimizer: {cfg.optimizer.name}") 116 | if cfg.optimizer.name == "adam": 117 | return optim.Adam(model.parameters(), lr=cfg.optimizer.lr) 118 | elif cfg.optimizer.name == "adamw": 119 | return optim.AdamW( 120 | model.parameters(), 121 | lr=cfg.optimizer.lr, 122 | weight_decay=cfg.optimizer.weight_decay, 123 | ) 124 | elif cfg.optimizer.name == "sgd": 125 | return optim.SGD( 126 | model.parameters(), 127 | lr=cfg.optimizer.lr, 128 | weight_decay=cfg.optimizer.weight_decay, 129 | ) 130 | else: 131 | raise NotImplementedError 132 | 133 | 134 | def get_scheduler( 135 | cfg: Config, optimizer: optim.Optimizer, n_iterations: int 136 | ) -> optim.lr_scheduler.LRScheduler: 137 | """ 138 | Return the learning rate scheduler based on the configuration settings. 139 | 140 | Args: 141 | cfg (Config): 142 | Configuration object containing scheduler settings. If `cfg.scheduler` is None, a LambdaLR with a constant factor of 1 (i.e. an unchanged learning rate) is returned. 143 | The scheduler name, T_max, and eta_min are specified in `cfg.scheduler`; only "cosine" is supported. 144 | optimizer (optim.Optimizer): 145 | The optimizer for which the learning rate will be scheduled. 146 | n_iterations (int): 147 | The number of iterations, used as CosineAnnealingLR's T_max when `cfg.scheduler.T_max` is falsy (e.g. None or 0).
148 | 149 | Returns: 150 | optim.lr_scheduler.LRScheduler: The configured learning rate scheduler. 151 | 152 | Raises: 153 | NotImplementedError: If the specified scheduler is not implemented. 154 | """ 155 | if cfg.print_config: 156 | print(f"Using scheduler: {cfg.scheduler}") 157 | if cfg.scheduler is None: 158 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1) 159 | if cfg.scheduler.name == "cosine": 160 | return optim.lr_scheduler.CosineAnnealingLR( 161 | optimizer, 162 | T_max=cfg.scheduler.T_max or n_iterations, 163 | eta_min=cfg.scheduler.eta_min, 164 | ) 165 | else: 166 | raise NotImplementedError 167 | -------------------------------------------------------------------------------- /utmosv2/utils/_pure/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | def print_metrics(metrics: dict[str, float]) -> None: 5 | """ 6 | Print the given metrics in a formatted string. 7 | 8 | Args: 9 | metrics (dict[str, float]): 10 | A dictionary of metric names and their corresponding values. 11 | 12 | Returns: 13 | None: This function prints the metrics to the console on one line as "metric_name: value" entries joined by ", ", with values formatted to 4 decimal places. 14 | """ 15 | print(", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])) 16 | -------------------------------------------------------------------------------- /utmosv2/utils/_pure/save.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from utmosv2._import import _LazyImport 4 | from utmosv2._settings._config import Config 5 | 6 | if TYPE_CHECKING: 7 | import numpy as np 8 | import pandas as pd 9 | else: 10 | pd = _LazyImport("pandas") 11 | 12 | 13 | def save_oof_preds( 14 | cfg: Config, data: "pd.DataFrame", oof_preds: "np.ndarray", fold: int 15 | ) -> None: 16 | """ 17 | Save out-of-fold (OOF) predictions to a CSV file.
18 | 19 | Args: 20 | cfg (Config): 21 | Configuration object containing settings for saving OOF predictions. 22 | Includes `id_name` for the ID column, `save_path` for the save directory, and `split.seed`, which is embedded in the file name. 23 | data (pd.DataFrame): 24 | The original dataset containing the ID column. 25 | oof_preds (np.ndarray): 26 | The array of OOF predictions. 27 | fold (int): 28 | The current fold number used in cross-validation. 29 | 30 | Returns: 31 | None: The function writes `fold{fold}_s{split.seed}_oof_preds.csv` under `cfg.save_path`, with columns `cfg.id_name` and "oof_preds" and no index column. 32 | """ 33 | oof_df = pd.DataFrame({cfg.id_name: data[cfg.id_name], "oof_preds": oof_preds}) 34 | oof_df.to_csv( 35 | cfg.save_path / f"fold{fold}_s{cfg.split.seed}_oof_preds.csv", index=False 36 | ) 37 | -------------------------------------------------------------------------------- /utmosv2/utils/_pure/split.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Generator 4 | from typing import TYPE_CHECKING 5 | 6 | import numpy as np 7 | 8 | from utmosv2._import import _LazyImport 9 | from utmosv2._settings._config import Config 10 | 11 | if TYPE_CHECKING: 12 | import pandas as pd 13 | from sklearn.model_selection import ( 14 | GroupKFold, 15 | KFold, 16 | StratifiedGroupKFold, 17 | StratifiedKFold, 18 | ) 19 | else: 20 | _model_selection = _LazyImport("sklearn.model_selection") 21 | GroupKFold = _model_selection.GroupKFold 22 | KFold = _model_selection.KFold 23 | StratifiedGroupKFold = _model_selection.StratifiedGroupKFold 24 | StratifiedKFold = _model_selection.StratifiedKFold 25 | 26 | 27 | def split_data( 28 | cfg: Config, data: "pd.DataFrame" 29 | ) -> Generator[tuple[np.ndarray, np.ndarray], None, None]: 30 | """ 31 | Split the data into training and validation sets based on the specified splitting method in the configuration.
32 | 33 | Args: 34 | cfg (Config): Configuration object containing the splitting settings. It includes: 35 | - split.type: Type of split to use ('simple', 'stratified', 'group', 'stratified_group', or 'sgkf_kind'). 36 | - num_folds: Number of folds for K-Fold cross-validation. 37 | - split.seed: Random seed for shuffling (unused by 'group', whose GroupKFold is built without shuffling). 38 | - split.target: Target column used for stratification in 'stratified', 'stratified_group', and 'sgkf_kind' (cast to int first). 39 | - split.group: Group column used for grouping in 'group', 'stratified_group', and 'sgkf_kind'. 40 | - split.kind: Column used in the 'sgkf_kind' case: each of its unique values selects a subset that gets its own StratifiedGroupKFold, and the i-th folds of all subsets are concatenated into fold i. 41 | data (pd.DataFrame): The dataset to be split. 42 | 43 | Yields: 44 | tuple[np.ndarray, np.ndarray]: Indices of training and validation sets for each fold. NOTE(review): in the 'sgkf_kind' branch each subset's split presumably returns positions relative to that subset (restarting at 0), so the concatenated arrays would not point at the matching rows of `data` and would overlap across kinds. If callers index `data` positionally (e.g. `data.iloc[idx]`), each subset's positions need mapping back first, e.g. `np.flatnonzero(mask)[idx]`; TODO confirm against the caller. 45 | 46 | Raises: 47 | NotImplementedError: If the split type specified in the configuration is not implemented. 48 | """ 49 | if cfg.print_config: 50 | print(f"Using split: {cfg.split.type}") 51 | if cfg.split.type == "simple": 52 | kf = KFold(n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed) 53 | for train_idx, valid_idx in kf.split(data): 54 | yield train_idx, valid_idx 55 | elif cfg.split.type == "stratified": 56 | kf = StratifiedKFold( 57 | n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed 58 | ) 59 | for train_idx, valid_idx in kf.split(data, data[cfg.split.target].astype(int)): 60 | yield train_idx, valid_idx 61 | elif cfg.split.type == "group": 62 | kf = GroupKFold(n_splits=cfg.num_folds) 63 | for train_idx, valid_idx in kf.split(data, groups=data[cfg.split.group]): 64 | yield train_idx, valid_idx 65 | elif cfg.split.type == "stratified_group": 66 | kf = StratifiedGroupKFold( 67 | n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed 68 | ) 69 | for train_idx, valid_idx in kf.split( 70 | data, data[cfg.split.target].astype(int), groups=data[cfg.split.group] 71 | ): 72 | yield train_idx, valid_idx 73 | elif cfg.split.type == "sgkf_kind": 74 | kind =
data[cfg.split.kind].unique() 75 | kf = [ 76 | StratifiedGroupKFold( 77 | n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed 78 | ) 79 | for _ in range(len(kind)) 80 | ] 81 | kf = [ 82 | kf_i.split( 83 | data[data[cfg.split.kind] == ds], 84 | data[data[cfg.split.kind] == ds][cfg.split.target].astype(int), 85 | groups=data[data[cfg.split.kind] == ds][cfg.split.group], 86 | ) 87 | for kf_i, ds in zip(kf, kind) 88 | ] 89 | for ds_idx in zip(*kf): 90 | train_idx = np.concatenate([d[0] for d in ds_idx]) 91 | valid_idx = np.concatenate([d[1] for d in ds_idx]) 92 | yield train_idx, valid_idx 93 | else: 94 | raise NotImplementedError 95 | -------------------------------------------------------------------------------- /utmosv2/utils/_task_dependents/__init__.py: -------------------------------------------------------------------------------- 1 | from utmosv2.utils._task_dependents.initializers import ( 2 | get_data, 3 | get_dataset, 4 | get_inference_data, 5 | get_metrics, 6 | get_model, 7 | get_train_data, 8 | ) 9 | from utmosv2.utils._task_dependents.log import show_inference_data 10 | from utmosv2.utils._task_dependents.metrics import calc_metrics 11 | from utmosv2.utils._task_dependents.save import ( 12 | make_submission_file, 13 | save_preds, 14 | save_test_preds, 15 | ) 16 | 17 | __all__ = [ 18 | "get_data", 19 | "get_dataset", 20 | "get_inference_data", 21 | "get_metrics", 22 | "get_model", 23 | "get_train_data", 24 | "show_inference_data", 25 | "calc_metrics", 26 | "make_submission_file", 27 | "save_preds", 28 | "save_test_preds", 29 | ] 30 | -------------------------------------------------------------------------------- /utmosv2/utils/_task_dependents/log.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | if TYPE_CHECKING: 4 | import pandas as pd 5 | 6 | 7 | def show_inference_data(data: "pd.DataFrame") -> None: 8 | print( 9 | data[[c for c in data.columns if c !=
"mos"]] 10 | .rename(columns={"dataset": "predict_dataset"}) 11 | .head() 12 | ) 13 | -------------------------------------------------------------------------------- /utmosv2/utils/_task_dependents/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import numpy as np 6 | import scipy.stats 7 | 8 | if TYPE_CHECKING: 9 | import pandas as pd 10 | 11 | 12 | def calc_metrics(data: "pd.DataFrame", preds: np.ndarray) -> dict[str, float]: 13 | data = data.copy() 14 | data["preds"] = preds 15 | data_sys = data.groupby("sys_id", as_index=False)[["mos", "preds"]].mean() 16 | res = {} 17 | for name, d in {"utt": data, "sys": data_sys}.items(): 18 | res[f"{name}_mse"] = np.mean((d["mos"].values - d["preds"].values) ** 2) 19 | res[f"{name}_lcc"] = np.corrcoef(d["mos"].values, d["preds"].values)[0][1] 20 | res[f"{name}_srcc"] = scipy.stats.spearmanr(d["mos"].values, d["preds"].values)[ 21 | 0 22 | ] 23 | res[f"{name}_ktau"] = scipy.stats.kendalltau( 24 | d["mos"].values, d["preds"].values 25 | )[0] 26 | return res 27 | -------------------------------------------------------------------------------- /utmosv2/utils/_task_dependents/save.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | from typing import TYPE_CHECKING 5 | 6 | import numpy as np 7 | 8 | from utmosv2._import import _LazyImport 9 | from utmosv2._settings._config import Config 10 | from utmosv2.utils._task_dependents.initializers import _get_test_save_name 11 | 12 | if TYPE_CHECKING: 13 | import pandas as pd 14 | else: 15 | pd = _LazyImport("pandas") 16 | 17 | 18 | def save_test_preds( 19 | cfg: Config, 20 | data: "pd.DataFrame", 21 | test_preds: np.ndarray, 22 | test_metrics: dict[str, float], 23 | ) -> None: 24 | test_df = pd.DataFrame({cfg.id_name: data[cfg.id_name], "test_preds":
test_preds}) 25 | cfg.inference.save_path.mkdir(parents=True, exist_ok=True) 26 | save_path = ( 27 | cfg.inference.save_path 28 | / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})_test_preds{'_final' if cfg.final else ''}.csv" 29 | ) 30 | test_df.to_csv(save_path, index=False) 31 | save_path = ( 32 | cfg.inference.save_path 33 | / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})_val_score{'_final' if cfg.final else ''}.json" 34 | ) 35 | with open(save_path, "w") as f: 36 | json.dump(test_metrics, f) 37 | print(f"Test predictions are saved to {save_path}") 38 | 39 | 40 | def make_submission_file( 41 | cfg: Config, data: "pd.DataFrame", test_preds: np.ndarray 42 | ) -> None: 43 | submit = pd.DataFrame({cfg.id_name: data[cfg.id_name], "prediction": test_preds}) 44 | ( 45 | cfg.inference.submit_save_path 46 | / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})" 47 | ).mkdir(parents=True, exist_ok=True) 48 | sub_file = ( 49 | cfg.inference.submit_save_path 50 | / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})" 51 | / "answer.txt" 52 | ) 53 | submit.to_csv( 54 | sub_file, 55 | index=False, 56 | header=False, 57 | ) 58 | print(f"Submission file is saved to {sub_file}") 59 | 60 | 61 | def save_preds(cfg: Config, data: "pd.DataFrame", test_preds: np.ndarray) -> None: 62 | pred = pd.DataFrame({cfg.id_name: data[cfg.id_name], "mos": test_preds}) 63 | if cfg.out_path is None: 64 | print("Predictions:") 65 | print(pred) 66 | else: 67 | pred.to_csv(cfg.out_path, index=False) 68 | print(f"Predictions are saved to {cfg.out_path}") 69 | --------------------------------------------------------------------------------