6 |
7 |
8 |
9 |
10 |
11 | To train UTMOSv2 following the methods described in the paper or used in the competition, please refer to [this document](reproduction.md).
12 |
13 |
14 |
## 📩 Install Training Dependencies 📩
15 |
16 |
17 |
18 |
19 |
20 | To install the dependencies required for training, run the following command:
21 |
22 | ```bash
23 | pip install --upgrade pip # enable PEP 660 support
24 | pip install -e .[train,optional]
25 | ```
26 |
27 | > [!NOTE]
28 | > If you are using zsh, quote the argument so that the square brackets are not expanded as a glob pattern:
29 | >
30 | > ```zsh
31 | > pip install -e '.[train,optional]'
32 | > ```
33 |
34 |
35 |
## 🚀 Train UTMOSv2 Using Your Own Data 🚀
36 |
37 |
38 |
39 |
40 |
41 | To train UTMOSv2 using your own data, you need to create a JSON file that specifies the name and location of each dataset. Here is an example structure for the JSON file:
42 |
43 | ```json
44 | {
45 | "data": [
46 | {
47 | "name": "dataset1",
48 | "dir": "/path/to/your/dataset1",
49 | "mos_list": "/path/to/your/moslist1.txt"
50 | },
51 | {
52 | "name": "dataset2",
53 | "dir": "/path/to/your/dataset2",
54 | "mos_list": "/path/to/your/moslist2.txt"
55 | }
56 | // Add more data entries as needed
57 | ]
58 | }
59 | ```
60 |
61 | Here, `name` identifies the data-domain ID, and `dir` specifies the directory containing the corresponding `.wav` files. `mos_list` is a text file that records the MOS value for each `.wav` file in that directory, in the following format:
62 |
63 | ```text
64 | sys64e2f-utt491a78a,2.375
65 | sys64e2f-utt8485f83,3.625
66 | sys7ab3c-utt1417b69,4.0
67 | ...
68 | ```
69 |
70 | The `.wav` extension in the MOS list is optional and can be included or omitted. Only the files that appear both in `dir` and in `mos_list` are used.
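
As a rough illustration of this matching (a minimal standalone sketch, not the actual UTMOSv2 loader; the paths are placeholders), the following Python snippet pairs the entries of a `mos_list` file with the `.wav` files found in `dir`:

```python
import csv
from pathlib import Path


def load_mos_pairs(wav_dir: str, mos_list: str) -> dict[str, float]:
    """Return {utterance_id: MOS} for entries present in both wav_dir and mos_list."""
    wav_ids = {p.stem for p in Path(wav_dir).glob("*.wav")}
    pairs: dict[str, float] = {}
    with open(mos_list) as f:
        for row in csv.reader(f):
            if not row:
                continue  # skip blank lines
            utt_id, mos = row[0], float(row[1])
            if utt_id.endswith(".wav"):  # the extension is optional in the MOS list
                utt_id = utt_id[: -len(".wav")]
            if utt_id in wav_ids:  # keep only files present in both places
                pairs[utt_id] = mos
    return pairs


print(load_mos_pairs("/path/to/your/dataset1", "/path/to/your/moslist1.txt"))
```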
71 |
72 | Specify a `name`, `dir`, and `mos_list` entry for each data-domain ID you want to train on.
73 |
74 | Save this JSON file with an appropriate name, for example `data_config.json`, and run the following command:
75 |
76 | ```bash
77 | python train.py --config spec_only --data_config data_config.json
78 | ```
79 |
80 |
81 |
## 🧪 Fine-tuning from Pre-trained Weights 🧪
82 |
83 |
84 |
85 |
86 |
87 | To continue training from existing weights, specify the `--weight` option as shown below. This is useful when you want to run additional training on top of weights learned in a previous stage, or when fine-tuning.
88 |
89 | ```bash
90 | python train.py --config spec_only --data_config data_config.json --weight /path/to/your/weights.pth
91 | ```
92 |
93 | The `--weight` option accepts either a configuration name or the path to a `.pth` weight file. If a configuration name is given, `models/{config_name}/fold{now_fold}_s{seed}_best_model.pth` is used.
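
For example, assuming a previous `spec_only` run has saved its best checkpoints under `models/spec_only/`, either form works (the paths below are placeholders):

```bash
# Point directly at a .pth checkpoint file
python train.py --config spec_only --data_config data_config.json --weight models/spec_only/fold0_s42_best_model.pth

# Or give a configuration name; models/spec_only/fold{now_fold}_s{seed}_best_model.pth is resolved for each fold
python train.py --config spec_only --data_config data_config.json --weight spec_only
```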
94 |
95 |
96 |
## 🔬 Using Weights & Biases (wandb) for Experiment Tracking 🔬
97 |
98 |
99 |
100 |
101 |
102 | To use Weights & Biases (wandb) for experiment tracking, specify the `--wandb` option. You will also need to set `WANDB_API_KEY` in your `.env` file or environment variables, or enter your API key at the prompt during execution.
103 |
104 | ```bash
105 | python train.py --config spec_only --data_config data_config.json --wandb
106 | ```
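
For reference, `train.py` calls `load_dotenv()` and reads the key via `os.getenv("WANDB_API_KEY")`, so a minimal `.env` file in the project root is sufficient (the value below is a placeholder):

```text
WANDB_API_KEY=your_api_key_here
```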
107 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import importlib
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from utmosv2._settings import configure_defaults, configure_inference_args
8 | from utmosv2._settings._config import Config
9 | from utmosv2.runner import run_inference
10 | from utmosv2.utils import (
11 | get_dataloader,
12 | get_dataset,
13 | get_inference_data,
14 | get_model,
15 | make_submission_file,
16 | print_metrics,
17 | save_preds,
18 | save_test_preds,
19 | show_inference_data,
20 | )
21 |
22 |
23 | def main(cfg: Config) -> None:
24 | data = get_inference_data(cfg)
25 | show_inference_data(data)
26 |
27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 | cfg.print_config = True # type: ignore
29 |
30 | test_preds = np.zeros(data.shape[0])
31 | if cfg.reproduce:
32 | test_metrics: dict[str, float] = {}
33 |
34 | for fold in range(cfg.num_folds):
35 | if 0 <= cfg.inference.fold < cfg.num_folds and fold != cfg.inference.fold:
36 | continue
37 |
38 | cfg.now_fold = fold # type: ignore
39 |
40 | model = get_model(cfg, device)
41 |
42 | cfg.print_config = False # type: ignore
43 | print(f"+*+*[[Fold {fold + 1}/{cfg.num_folds}]]" + "+*" * 30)
44 |
45 | for cycle in range(cfg.inference.num_tta):
46 | test_dataset = get_dataset(cfg, data, "test")
47 | test_dataloader = get_dataloader(cfg, test_dataset, "test")
48 | test_preds_tta, test_metrics_tta = run_inference(
49 | cfg, model, test_dataloader, cycle, data, device
50 | )
51 | test_preds += test_preds_tta
52 | if cfg.reproduce:
53 | assert test_metrics_tta is not None
54 | for k, v in test_metrics_tta.items():
55 | test_metrics[k] = test_metrics.get(k, 0) + v
56 |
57 | fold_cnt = 1 if 0 <= cfg.inference.fold < cfg.num_folds else cfg.num_folds
58 | print(f"Average of {fold_cnt} folds")
59 | test_preds /= fold_cnt * cfg.inference.num_tta
60 | if cfg.reproduce:
61 | test_metrics = {
62 | k: v / fold_cnt / cfg.inference.num_tta for k, v in test_metrics.items()
63 | }
64 | print_metrics(test_metrics)
65 | save_test_preds(cfg, data, test_preds, test_metrics)
66 | make_submission_file(cfg, data, test_preds)
67 | else:
68 | save_preds(cfg, data, test_preds)
69 |
70 |
71 | if __name__ == "__main__":
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument(
74 | "-c", "--config", type=str, default="fusion_stage3", help="config file name"
75 | )
76 | parser.add_argument("-f", "--fold", type=int, default=0, help="fold number")
77 | parser.add_argument(
78 | "-s", "--seed", type=int, default=42, help="random seed for split"
79 | )
80 | parser.add_argument("-d", "--input_dir", type=str, help="data path")
81 | parser.add_argument("-p", "--input_path", type=str, help="data path")
82 | parser.add_argument("-o", "--out_path", type=str, help="output path")
83 | parser.add_argument(
84 | "-n",
85 | "--num_workers",
86 | type=int,
87 | default=4,
88 | help="number of workers for dataloader",
89 | )
90 | parser.add_argument(
91 | "-t",
92 | "--val_list_path",
93 | type=str,
94 | help="test data path",
95 | )
96 | parser.add_argument(
97 | "-w", "--weight", type=str, default=None, help="path to the weight file to load"
98 | )
99 | parser.add_argument(
100 | "-pd",
101 | "--predict_dataset",
102 | type=str,
103 | default="sarulab",
104 | help="predict dataset",
105 | )
106 | parser.add_argument(
107 | "-nr",
108 | "--num_repetitions",
109 | type=int,
110 | default=1,
111 | help="number of repetitions for prediction",
112 | )
113 | parser.add_argument(
114 | "-e",
115 | "--reproduce",
116 | action="store_true",
117 | help="Run the experiment as described in the paper, including all necessary steps for reproducibility.",
118 | )
119 | parser.add_argument(
120 | "-fi",
121 | "--final",
122 | action="store_true",
123 | help="final submission",
124 | )
125 | args = parser.parse_args()
126 |
127 | if args.input_dir is None and args.input_path is None:
128 | raise ValueError(
129 | "Either input_dir or input_path must be provided when you use your own data."
130 | )
131 | if args.input_dir is not None and args.input_path is not None:
132 | raise ValueError(
133 | "Only one of input_dir or input_path must be provided when you use your own data."
134 | )
135 |
136 | cfg = importlib.import_module("utmosv2.config." + args.config)
137 | configure_inference_args(cfg, args)
138 | configure_defaults(cfg)
139 |
140 | main(cfg)
141 |
--------------------------------------------------------------------------------
/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sarulab-speech/UTMOSv2/00b80845f85fad4e3c23a743851304e0e23c5a02/poster.pdf
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "utmosv2"
7 | description = "UTokyo-SaruLab MOS Prediction System"
8 | readme = "README.md"
9 | license = { file = "LICENSE" }
10 | authors = [{ name = "Kaito Baba" }]
11 | classifiers = [
12 | "License :: OSI Approved :: MIT License",
13 | "Programming Language :: Python :: 3.9",
14 | "Programming Language :: Python :: 3.10",
15 | "Programming Language :: Python :: 3.11",
16 | "Programming Language :: Python :: 3.12",
17 | "Programming Language :: Python :: 3 :: Only",
18 | ]
19 | dependencies = [
20 | "numpy>=1.24.4",
21 | "torch>=2.3.1",
22 | "timm>=1.0.7",
23 | "librosa>=0.10.2",
24 | "tqdm>=4.66.4",
25 | "transformers>=4.42.4",
26 | "typing-extensions"
27 | ]
28 | requires-python = ">=3.9"
29 | dynamic = ["version"]
30 |
31 | [project.optional-dependencies]
32 | check = ["ruff", "mypy", "types-setuptools", "types-tqdm"]
33 | train = ["scikit-learn>=1.3.2", "wandb>=0.17.0", "python-dotenv>=1.0.1"]
34 | optional = ["pandas>=2.2.2"]
35 | test = ["pytest"]
36 |
37 | [tool.setuptools.dynamic]
38 | version = { attr = "utmosv2.__version__" }
39 |
40 | [tool.setuptools.packages.find]
41 | include = ["utmosv2*"]
42 |
43 | [tool.mypy]
44 | python_version = "3.11"
45 | ignore_missing_imports = true
46 | disallow_untyped_defs = true
47 | exclude = ["^build/"]
48 |
--------------------------------------------------------------------------------
/quickstart.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# 🚀 Quick Introduction to MOS Prediction using UTMOSv2\n",
8 | "\n",
9 | "In this Jupyter notebook, we will introduce a method for predicting MOS (Mean Opinion Score) using UTMOSv2."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## 🛠 Installation"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {
23 | "vscode": {
24 | "languageId": "plaintext"
25 | }
26 | },
27 | "outputs": [],
28 | "source": [
29 | "!GIT_LFS_SKIP_SMUDGE=1 pip install git+https://github.com/sarulab-speech/UTMOSv2.git"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "metadata": {
36 | "vscode": {
37 | "languageId": "plaintext"
38 | }
39 | },
40 | "outputs": [],
41 | "source": [
42 | "import utmosv2"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {},
48 | "source": [
49 | "## 🔮 Make predictions"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "To predict the MOS of a single wav file:"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {
63 | "vscode": {
64 | "languageId": "plaintext"
65 | }
66 | },
67 | "outputs": [],
68 | "source": [
69 | "model = utmosv2.create_model(pretrained=True)\n",
70 | "mos = model.predict(input_path=\"/path/to/wav/file.wav\")\n",
71 | "print(mos)"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "To predict the MOS of all .wav files in a folder:"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {
85 | "vscode": {
86 | "languageId": "plaintext"
87 | }
88 | },
89 | "outputs": [],
90 | "source": [
91 | "model = utmosv2.create_model(pretrained=True)\n",
92 | "mos = model.predict(input_dir=\"/path/to/wav/dir/\")\n",
93 | "print(mos)"
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "metadata": {},
99 | "source": [
100 | "Note that either `input_path` or `input_dir` must be specified, but not both."
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "metadata": {},
106 | "source": [
107 | "For more details on how to use the inference script, please refer to [inference guide](https://github.com/sarulab-speech/UTMOSv2/blob/main/docs/inference.md)."
108 | ]
109 | }
110 | ],
111 | "metadata": {
112 | "language_info": {
113 | "name": "python"
114 | }
115 | },
116 | "nbformat": 4,
117 | "nbformat_minor": 2
118 | }
119 |
--------------------------------------------------------------------------------
/tests/core_tests/test_create.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import sys
4 |
5 | import pytest
6 |
7 | from utmosv2._core.create import create_model
8 |
9 |
10 | @pytest.mark.parametrize(
11 | "pretrained",
12 | [
13 | pytest.param(
14 | True,
15 | marks=pytest.mark.skipif(
16 | sys.version_info[:2] != (3, 11),
17 | reason="To avoid downloading the model weights multiple times",
18 | ),
19 | ),
20 | False,
21 | ],
22 | )
23 | def test_create_model(pretrained: bool) -> None:
24 | model = create_model(pretrained=pretrained)
25 | assert hasattr(model, "forward")
26 | assert hasattr(model, "predict")
27 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import importlib
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | import wandb
8 | from dotenv import load_dotenv
9 |
10 | from utmosv2._settings import configure_args, configure_defaults
11 | from utmosv2._settings._config import Config
12 | from utmosv2.runner import run_train
13 | from utmosv2.utils import (
14 | get_dataloader,
15 | get_dataset,
16 | get_loss,
17 | get_metrics,
18 | get_model,
19 | get_optimizer,
20 | get_scheduler,
21 | get_train_data,
22 | save_oof_preds,
23 | split_data,
24 | )
25 |
26 |
27 | def main(cfg: Config) -> None:
28 | data = get_train_data(cfg)
29 | print(data.head())
30 | oof_preds = np.zeros(data.shape[0])
31 |
32 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33 | print(f"Device: {device}")
34 |
35 | cfg.print_config = True # type: ignore
36 |
37 | for fold, (train_idx, val_idx) in enumerate(split_data(cfg, data)):
38 | if 0 <= cfg.fold < cfg.num_folds and fold != cfg.fold:
39 | continue
40 |
41 | cfg.now_fold = fold # type: ignore
42 |
43 | train_data = data.iloc[train_idx]
44 | val_data = data.iloc[val_idx]
45 |
46 | train_dataset = get_dataset(cfg, train_data, "train")
47 | val_dataset = get_dataset(cfg, val_data, "valid")
48 |
49 | train_dataloader = get_dataloader(cfg, train_dataset, "train")
50 | val_dataloader = get_dataloader(cfg, val_dataset, "valid")
51 |
52 | model = get_model(cfg, device)
53 | criterions = get_loss(cfg)
54 | metrics = get_metrics()
55 | optimizer = get_optimizer(cfg, model)
56 | scheduler = get_scheduler(
57 | cfg, optimizer, len(train_dataloader) * cfg.run.num_epochs
58 | )
59 |
60 | cfg.print_config = False # type: ignore
61 | print(f"+*+*[[Fold {fold + 1}/{cfg.num_folds}]]" + "+*" * 30)
62 | if cfg.wandb:
63 | wandb.init(
64 | project="voice-mos-challenge-2024",
65 | name=cfg.config_name,
66 | config={
67 | "fold": fold,
68 | "seed": cfg.split.seed,
69 | },
70 | )
71 |
72 | run_train(
73 | cfg,
74 | model,
75 | train_dataloader,
76 | val_dataloader,
77 | val_data,
78 | oof_preds,
79 | fold,
80 | criterions,
81 | metrics,
82 | optimizer,
83 | scheduler,
84 | device,
85 | )
86 | if cfg.wandb:
87 | wandb.finish()
88 |
89 | save_oof_preds(cfg, data, oof_preds, cfg.fold)
90 |
91 |
92 | if __name__ == "__main__":
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument(
95 | "-c", "--config", type=str, required=True, help="config file name"
96 | )
97 | parser.add_argument("-f", "--fold", type=int, default=-1, help="fold number")
98 | parser.add_argument(
99 | "-s", "--seed", type=int, default=42, help="random seed for split"
100 | )
101 | parser.add_argument(
102 | "-i", "--input_dir", type=str, default="data/main/DATA", help="data path"
103 | )
104 | parser.add_argument(
105 | "-dc", "--data_config", type=str, help="path to the data config file"
106 | )
107 | parser.add_argument(
108 | "-n",
109 | "--num_workers",
110 | type=int,
111 | default=4,
112 | help="number of workers for dataloader",
113 | )
114 | parser.add_argument(
115 | "-w", "--weight", type=str, help="path to the weight file to load"
116 | )
117 | parser.add_argument(
118 | "-e",
119 | "--reproduce",
120 | action="store_true",
121 | help="Run the experiment as described in the paper, including all necessary steps for reproducibility.",
122 | )
123 | parser.add_argument(
124 | "-wb", "--wandb", action="store_true", help="Use wandb for logging"
125 | )
126 | args = parser.parse_args()
127 |
128 | if not args.reproduce and args.data_config is None:
129 | raise ValueError("Either --reproduce or --data_config must be specified")
130 |
131 | cfg = importlib.import_module("utmosv2.config." + args.config)
132 | configure_args(cfg, args)
133 | configure_defaults(cfg)
134 |
135 | load_dotenv()
136 | if cfg.wandb:
137 | wandb.login(key=os.getenv("WANDB_API_KEY"))
138 |
139 | main(cfg)
140 |
--------------------------------------------------------------------------------
/utmosv2/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2 import config, dataset, loss, model, preprocess, runner, transform, utils
2 | from utmosv2._core import UTMOSv2Model, create_model
3 |
4 | __all__ = [
5 | "config",
6 | "dataset",
7 | "loss",
8 | "model",
9 | "preprocess",
10 | "runner",
11 | "transform",
12 | "utils",
13 | "create_model",
14 | "UTMOSv2Model",
15 | ]
16 |
17 | __version__ = "1.2.0"
18 |
--------------------------------------------------------------------------------
/utmosv2/_core/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2._core.create import create_model
2 | from utmosv2._core.model import UTMOSv2Model
3 | from utmosv2._core.model._common import UTMOSv2ModelMixin
4 |
5 | __all__ = ["UTMOSv2Model", "UTMOSv2ModelMixin", "create_model"]
6 |
--------------------------------------------------------------------------------
/utmosv2/_core/create.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import importlib
4 | from pathlib import Path
5 | from types import SimpleNamespace
6 | from typing import Literal
7 |
8 | import torch
9 |
10 | from utmosv2._core.model import UTMOSv2Model
11 | from utmosv2._settings import configure_execution
12 | from utmosv2.utils._constants import _UTMOSV2_CHACHE
13 | from utmosv2.utils._download import download_pretrained_weights_from_hf
14 |
15 |
16 | def create_model(
17 | pretrained: bool = True,
18 | config: str = "fusion_stage3",
19 | fold: int = 0,
20 | checkpoint_path: Path | str | None = None,
21 | seed: int = 42,
22 | device: torch.device | str | Literal["auto"] = "auto",
23 | ) -> UTMOSv2Model:
24 | """
25 | Create a UTMOSv2 model with the specified configuration and optional pretrained weights.
26 |
27 | Args:
28 | pretrained (bool):
29 | If True, loads pretrained weights. Defaults to True.
30 | config (str):
31 | The configuration name to load for the model. Defaults to "fusion_stage3".
32 | fold (int):
33 | The fold number for the pretrained weights (used for model selection). Defaults to 0.
34 | checkpoint_path (Path | str | None):
35 | Path to a specific model checkpoint. If None, a pretrained checkpoint downloaded from Hugging Face is used. Defaults to None.
36 | seed (int):
37 | The seed used for model training to select the correct checkpoint. Defaults to 42.
38 |
39 | Returns:
40 | UTMOSv2Model: The initialized UTMOSv2 model.
41 |
42 | Raises:
43 | FileNotFoundError: If the specified checkpoint file is not found.
44 |
45 | Notes:
46 | - The configuration is dynamically loaded from `utmosv2.config`.
47 | - If `pretrained` is True and `checkpoint_path` is not provided, the function attempts to download pretrained weights from Hugging Face.
48 | """
49 | _cfg = importlib.import_module(f"utmosv2.config.{config}")
50 | # Avoid issues with pickling `types.ModuleType`,
51 | # making it easier to use with multiprocessing, DDP, etc.
52 | cfg = SimpleNamespace(
53 | **{k: v for k, v in _cfg.__dict__.items() if not k.startswith("__")}
54 | )
55 | configure_execution(cfg)
56 |
57 | model = UTMOSv2Model(cfg)
58 |
59 | if pretrained:
60 | if checkpoint_path is None:
61 | checkpoint_path = (
62 | _UTMOSV2_CHACHE
63 | / "models"
64 | / config
65 | / f"fold{fold}_s{seed}_best_model.pth"
66 | )
67 | if not checkpoint_path.exists():
68 | download_pretrained_weights_from_hf(config, fold)
69 | if isinstance(checkpoint_path, str):
70 | checkpoint_path = Path(checkpoint_path)
71 | if not checkpoint_path.exists():
72 | raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
73 |
74 | device = torch.device(
75 | ("cuda" if torch.cuda.is_available() else "cpu")
76 | if device == "auto"
77 | else device
78 | )
79 | model.load_state_dict(torch.load(checkpoint_path, map_location=device))
80 | print(f"Loaded checkpoint from {checkpoint_path}")
81 |
82 | return model
83 |
--------------------------------------------------------------------------------
/utmosv2/_core/model/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2._core.model._common import UTMOSv2ModelMixin
2 | from utmosv2._core.model._models import UTMOSv2Model
3 |
4 | __all__ = ["UTMOSv2Model", "UTMOSv2ModelMixin"]
5 |
--------------------------------------------------------------------------------
/utmosv2/_core/model/_models.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, Any
4 |
5 | from utmosv2._core.model._common import UTMOSv2ModelMixin
6 | from utmosv2._settings._config import Config
7 | from utmosv2.model import (
8 | MultiSpecExtModel,
9 | MultiSpecModelV2,
10 | SSLExtModel,
11 | SSLMultiSpecExtModelV1,
12 | SSLMultiSpecExtModelV2,
13 | )
14 |
15 | if TYPE_CHECKING:
16 | import torch
17 | import torch.nn as nn
18 |
19 |
20 | class UTMOSv2Model(UTMOSv2ModelMixin):
21 | """
22 | UTMOSv2Model class that wraps different models specified by the configuration.
23 | This class allows for flexible model selection and provides a unified interface for evaluation, calling, and prediction.
24 | """
25 |
26 | def __init__(self, cfg: Config):
27 | """
28 | Initialize the UTMOSv2Model with a specified configuration.
29 |
30 | Args:
31 | cfg (SimpleNamespace | ModuleType): Configuration object that contains the model configuration.
32 |
33 | Raises:
34 | ValueError: If the model name specified in the configuration is not recognized.
35 | """
36 | models = {
37 | "multi_spec_ext": MultiSpecExtModel,
38 | "multi_specv2": MultiSpecModelV2,
39 | "sslext": SSLExtModel,
40 | "ssl_multispec_ext": SSLMultiSpecExtModelV1,
41 | "ssl_multispec_ext_v2": SSLMultiSpecExtModelV2,
42 | }
43 | if cfg.model.name not in models:
44 | raise ValueError(f"Unknown model name: {cfg.model.name}")
45 | self._model = models[cfg.model.name](cfg)
46 | self._cfg_value = cfg
47 |
48 | @property
49 | def _cfg(self) -> Config:
50 | return self._cfg_value
51 |
52 | def eval(self) -> "nn.Module":
53 | return self._model.eval()
54 |
55 | def __call__(self, *args: Any, **kwargs: Any) -> "torch.Tensor":
56 | return self._model(*args, **kwargs)
57 |
58 | def __getattr__(self, name: str) -> Any:
59 | return getattr(self._model, name)
60 |
61 | def __setattr__(self, name: str, value: Any) -> None:
62 | if name == "_model":
63 | super().__setattr__(name, value)
64 | else:
65 | setattr(self._model, name, value)
66 |
67 | def __delattr__(self, name: str) -> None:
68 | delattr(self._model, name)
69 |
70 | def __repr__(self) -> str:
71 | return f"UTMOSv2Model({'('.join(self._model.__repr__().split('(')[1:])}"
72 |
73 | def __str__(self) -> str:
74 | return self.__repr__()
75 |
76 | def __dir__(self) -> list[str]:
77 | return super().__dir__() + self._model.__dir__()
78 |
--------------------------------------------------------------------------------
/utmosv2/_import.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import importlib
4 | import types
5 | from typing import Any
6 |
7 |
8 | class _LazyImport(types.ModuleType):
9 | def __init__(self, name: str):
10 | super().__init__(name)
11 | self._name = name
12 | self._module: types.ModuleType | None = None
13 |
14 | def __getattr__(self, name: str) -> Any:
15 | if self._module is None:
16 | self._module = importlib.import_module(self._name)
17 | self.__dict__.update(self._module.__dict__)
18 | return getattr(self._module, name)
19 |
--------------------------------------------------------------------------------
/utmosv2/_settings/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2._settings._config import (
2 | configure_args,
3 | configure_defaults,
4 | configure_execution,
5 | configure_inference_args,
6 | )
7 |
8 | __all__ = [
9 | "configure_args",
10 | "configure_defaults",
11 | "configure_inference_args",
12 | "configure_execution",
13 | ]
14 |
--------------------------------------------------------------------------------
/utmosv2/_settings/_config.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import argparse
4 | import sys
5 | from pathlib import Path
6 | from types import ModuleType, SimpleNamespace
7 |
8 | if sys.version_info >= (3, 10):
9 | from typing import TypeAlias
10 |
11 | # NOTE: Python 3.12 introduces the type statement, so once Python 3.11 is dropped,
12 | # it should be updated to use that instead.
13 | Config: TypeAlias = SimpleNamespace | ModuleType
14 | else:
15 | from typing import Union
16 |
17 | from typing_extensions import TypeAlias
18 |
19 | Config: TypeAlias = Union[SimpleNamespace, ModuleType]
20 |
21 |
22 | def configure_args(cfg: Config, args: argparse.Namespace) -> None:
23 | cfg.fold = args.fold # type: ignore
24 | cfg.split.seed = args.seed # type: ignore
25 | cfg.config_name = args.config # type: ignore
26 | cfg.input_dir = args.input_dir and Path(args.input_dir) # type: ignore
27 | cfg.num_workers = args.num_workers # type: ignore
28 | cfg.weight = args.weight # type: ignore
29 | cfg.save_path = Path("models") / cfg.config_name # type: ignore
30 | cfg.wandb = args.wandb # type: ignore
31 | cfg.reproduce = args.reproduce # type: ignore
32 | cfg.data_config = args.data_config # type: ignore
33 | cfg.phase = "train" # type: ignore
34 |
35 |
36 | def configure_inference_args(cfg: Config, args: argparse.Namespace) -> None:
37 | cfg.inference.fold = args.fold # type: ignore
38 | cfg.split.seed = args.seed # type: ignore
39 | cfg.config_name = args.config # type: ignore
40 | cfg.input_dir = args.input_dir and Path(args.input_dir) # type: ignore
41 | cfg.input_path = args.input_path and Path(args.input_path) # type: ignore
42 | cfg.num_workers = args.num_workers # type: ignore
43 | cfg.weight = args.weight # type: ignore
44 | if not cfg.weight:
45 | cfg.weight = cfg.config_name # type: ignore
46 | cfg.inference.val_list_path = args.val_list_path and Path(args.val_list_path) # type: ignore
47 | cfg.save_path = Path("models") / cfg.config_name # type: ignore
48 | cfg.predict_dataset = args.predict_dataset # type: ignore
49 | cfg.final = args.final # type: ignore
50 | cfg.inference.num_tta = args.num_repetitions # type: ignore
51 | cfg.reproduce = args.reproduce # type: ignore
52 | cfg.out_path = args.out_path and Path(args.out_path) # type: ignore
53 | cfg.data_config = None # type: ignore
54 | cfg.phase = "inference" # type: ignore
55 |
56 |
57 | def configure_defaults(cfg: Config) -> None:
58 | if cfg.id_name is None:
59 | cfg.id_name = "utt_id" # type: ignore
60 |
61 |
62 | def configure_execution(cfg: Config) -> None:
63 | cfg.data_config = None # type: ignore
64 | cfg.phase = "prediction" # type: ignore
65 | cfg.print_config = False # type: ignore
66 |
--------------------------------------------------------------------------------
/utmosv2/config/c_fusion_stage2.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 16
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="stratified_group",
21 | target="mos",
22 | group="sys_id",
23 | )
24 |
25 | external_data: list[str] | str = "all"
26 | use_bvcc = True
27 |
28 |
29 | validation_dataset = "sarulab"
30 |
31 | dataset = SimpleNamespace(
32 | name="ssl_multispec_ext",
33 | specs=[
34 | SimpleNamespace(
35 | mode="melspec",
36 | n_fft=4096,
37 | hop_length=32,
38 | win_length=4096,
39 | n_mels=512,
40 | shape=(512, 512),
41 | norm=80,
42 | ),
43 | SimpleNamespace(
44 | mode="melspec",
45 | n_fft=4096,
46 | hop_length=32,
47 | win_length=2048,
48 | n_mels=512,
49 | shape=(512, 512),
50 | norm=80,
51 | ),
52 | SimpleNamespace(
53 | mode="melspec",
54 | n_fft=4096,
55 | hop_length=32,
56 | win_length=1024,
57 | n_mels=512,
58 | shape=(512, 512),
59 | norm=80,
60 | ),
61 | SimpleNamespace(
62 | mode="melspec",
63 | n_fft=4096,
64 | hop_length=32,
65 | win_length=512,
66 | n_mels=512,
67 | shape=(512, 512),
68 | norm=80,
69 | ),
70 | ],
71 | spec_frames=SimpleNamespace(
72 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
73 | ),
74 | ssl=SimpleNamespace(
75 | duration=3,
76 | ),
77 | )
78 | transform = dict(
79 | train=transforms.Compose(
80 | [
81 | transforms.Resize((512, 512)),
82 | XYMasking(
83 | num_masks_x=(0, 2),
84 | num_masks_y=(0, 2),
85 | mask_x_length=(10, 40),
86 | mask_y_length=(10, 30),
87 | fill_value=0,
88 | p=0.5,
89 | ),
90 | # transforms.ToTensor(),
91 | ]
92 | ),
93 | valid=transforms.Compose(
94 | [
95 | transforms.Resize((512, 512)),
96 | # transforms.ToTensor()
97 | ]
98 | ),
99 | )
100 |
101 | loss = [
102 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
103 | (SimpleNamespace(name="mse"), 0.2),
104 | ]
105 |
106 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
107 |
108 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5)
109 |
110 | model = SimpleNamespace(
111 | name="ssl_multispec_ext",
112 | multi_spec=SimpleNamespace(
113 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
114 | pretrained=True,
115 | num_classes=1,
116 | pool_type="catavgmax",
117 | # feature_height=16,
118 | atten=True,
119 | # classifier=None,
120 | ),
121 | ssl=SimpleNamespace(
122 | name="facebook/wav2vec2-base",
123 | attn=1,
124 | freeze=True,
125 | num_classes=1,
126 | ),
127 | ssl_spec=SimpleNamespace(
128 | ssl_weight="c_ssl_only_stage2",
129 | spec_weight="c_spec_only_stage2",
130 | num_classes=1,
131 | freeze=True,
132 | ),
133 | )
134 |
135 | run = SimpleNamespace(
136 | mixup=True,
137 | mixup_alpha=0.4,
138 | num_epochs=8,
139 | )
140 |
141 | main_metric = "sys_srcc"
142 | id_name = None
143 |
144 |
145 | inference = SimpleNamespace(
146 | save_path=Path("preds"),
147 | submit_save_path=Path("submissions"),
148 | num_tta=5,
149 | batch_size=8,
150 | extend="tile",
151 | )
152 |
--------------------------------------------------------------------------------
/utmosv2/config/c_fusion_stage3.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 8
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="stratified_group",
21 | target="mos",
22 | group="sys_id",
23 | )
24 |
25 | external_data: list[str] | str = "all"
26 | use_bvcc = True
27 |
28 |
29 | validation_dataset = "sarulab"
30 |
31 | dataset = SimpleNamespace(
32 | name="ssl_multispec_ext",
33 | specs=[
34 | SimpleNamespace(
35 | mode="melspec",
36 | n_fft=4096,
37 | hop_length=32,
38 | win_length=4096,
39 | n_mels=512,
40 | shape=(512, 512),
41 | norm=80,
42 | ),
43 | SimpleNamespace(
44 | mode="melspec",
45 | n_fft=4096,
46 | hop_length=32,
47 | win_length=2048,
48 | n_mels=512,
49 | shape=(512, 512),
50 | norm=80,
51 | ),
52 | SimpleNamespace(
53 | mode="melspec",
54 | n_fft=4096,
55 | hop_length=32,
56 | win_length=1024,
57 | n_mels=512,
58 | shape=(512, 512),
59 | norm=80,
60 | ),
61 | SimpleNamespace(
62 | mode="melspec",
63 | n_fft=4096,
64 | hop_length=32,
65 | win_length=512,
66 | n_mels=512,
67 | shape=(512, 512),
68 | norm=80,
69 | ),
70 | ],
71 | spec_frames=SimpleNamespace(
72 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
73 | ),
74 | ssl=SimpleNamespace(
75 | duration=3,
76 | ),
77 | )
78 | transform = dict(
79 | train=transforms.Compose(
80 | [
81 | transforms.Resize((512, 512)),
82 | XYMasking(
83 | num_masks_x=(0, 2),
84 | num_masks_y=(0, 2),
85 | mask_x_length=(10, 40),
86 | mask_y_length=(10, 30),
87 | fill_value=0,
88 | p=0.5,
89 | ),
90 | # transforms.ToTensor(),
91 | ]
92 | ),
93 | valid=transforms.Compose(
94 | [
95 | transforms.Resize((512, 512)),
96 | # transforms.ToTensor()
97 | ]
98 | ),
99 | )
100 |
101 | loss = [
102 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
103 | (SimpleNamespace(name="mse"), 0.2),
104 | ]
105 |
106 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4)
107 |
108 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8)
109 |
110 | model = SimpleNamespace(
111 | name="ssl_multispec_ext",
112 | multi_spec=SimpleNamespace(
113 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
114 | pretrained=True,
115 | num_classes=1,
116 | pool_type="catavgmax",
117 | # feature_height=16,
118 | atten=True,
119 | # classifier=None,
120 | ),
121 | ssl=SimpleNamespace(
122 | name="facebook/wav2vec2-base",
123 | attn=1,
124 | freeze=False,
125 | num_classes=1,
126 | ),
127 | ssl_spec=SimpleNamespace(
128 | ssl_weight="c_ssl_only_stage2",
129 | spec_weight="c_spec_only_stage2",
130 | num_classes=1,
131 | freeze=False,
132 | ),
133 | )
134 |
135 | run = SimpleNamespace(
136 | mixup=True,
137 | mixup_alpha=0.4,
138 | num_epochs=2,
139 | )
140 |
141 | main_metric = "sys_srcc"
142 | id_name = None
143 |
144 |
145 | inference = SimpleNamespace(
146 | save_path=Path("preds"),
147 | submit_save_path=Path("submissions"),
148 | num_tta=5,
149 | batch_size=8,
150 | extend="tile",
151 | )
152 |
--------------------------------------------------------------------------------
/utmosv2/config/c_spec_only_stage1.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 10
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="stratified_group",
21 | target="mos",
22 | group="sys_id",
23 | )
24 |
25 | external_data: list[str] | str = []
26 | use_bvcc = True
27 |
28 |
29 | validation_dataset = "bvcc"
30 |
31 | dataset = SimpleNamespace(
32 | name="multi_spec",
33 | specs=[
34 | SimpleNamespace(
35 | mode="melspec",
36 | n_fft=4096,
37 | hop_length=32,
38 | win_length=4096,
39 | n_mels=512,
40 | shape=(512, 512),
41 | norm=80,
42 | ),
43 | SimpleNamespace(
44 | mode="melspec",
45 | n_fft=4096,
46 | hop_length=32,
47 | win_length=2048,
48 | n_mels=512,
49 | shape=(512, 512),
50 | norm=80,
51 | ),
52 | SimpleNamespace(
53 | mode="melspec",
54 | n_fft=4096,
55 | hop_length=32,
56 | win_length=1024,
57 | n_mels=512,
58 | shape=(512, 512),
59 | norm=80,
60 | ),
61 | SimpleNamespace(
62 | mode="melspec",
63 | n_fft=4096,
64 | hop_length=32,
65 | win_length=512,
66 | n_mels=512,
67 | shape=(512, 512),
68 | norm=80,
69 | ),
70 | ],
71 | spec_frames=SimpleNamespace(
72 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
73 | ),
74 | )
75 | transform = dict(
76 | train=transforms.Compose(
77 | [
78 | transforms.Resize((512, 512)),
79 | XYMasking(
80 | num_masks_x=(0, 2),
81 | num_masks_y=(0, 2),
82 | mask_x_length=(10, 40),
83 | mask_y_length=(10, 30),
84 | fill_value=0,
85 | p=0.5,
86 | ),
87 | # transforms.ToTensor(),
88 | ]
89 | ),
90 | valid=transforms.Compose(
91 | [
92 | transforms.Resize((512, 512)),
93 | # transforms.ToTensor()
94 | ]
95 | ),
96 | )
97 |
98 | loss = [
99 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
100 | (SimpleNamespace(name="mse"), 0.2),
101 | ]
102 |
103 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
104 |
105 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
106 |
107 | model = SimpleNamespace(
108 | name="multi_specv2",
109 | multi_spec=SimpleNamespace(
110 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
111 | pretrained=True,
112 | num_classes=1,
113 | pool_type="catavgmax",
114 | # feature_height=16,
115 | atten=True,
116 | # classifier=None,
117 | ),
118 | )
119 |
120 | run = SimpleNamespace(
121 | mixup=True,
122 | mixup_alpha=0.4,
123 | num_epochs=20,
124 | )
125 |
126 | main_metric = "sys_srcc"
127 | id_name = None
128 |
129 |
130 | inference = SimpleNamespace(
131 | save_path=Path("preds"),
132 | submit_save_path=Path("submissions"),
133 | num_tta=5,
134 | batch_size=8,
135 | extend="tile",
136 | )
137 |
--------------------------------------------------------------------------------
/utmosv2/config/c_spec_only_stage2.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 10
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="stratified_group",
21 | target="mos",
22 | group="sys_id",
23 | )
24 |
25 | external_data: list[str] | str = ["sarulab"]
26 | use_bvcc = False
27 |
28 |
29 | validation_dataset = "sarulab"
30 |
31 | dataset = SimpleNamespace(
32 | name="multi_spec",
33 | specs=[
34 | SimpleNamespace(
35 | mode="melspec",
36 | n_fft=4096,
37 | hop_length=32,
38 | win_length=4096,
39 | n_mels=512,
40 | shape=(512, 512),
41 | norm=80,
42 | ),
43 | SimpleNamespace(
44 | mode="melspec",
45 | n_fft=4096,
46 | hop_length=32,
47 | win_length=2048,
48 | n_mels=512,
49 | shape=(512, 512),
50 | norm=80,
51 | ),
52 | SimpleNamespace(
53 | mode="melspec",
54 | n_fft=4096,
55 | hop_length=32,
56 | win_length=1024,
57 | n_mels=512,
58 | shape=(512, 512),
59 | norm=80,
60 | ),
61 | SimpleNamespace(
62 | mode="melspec",
63 | n_fft=4096,
64 | hop_length=32,
65 | win_length=512,
66 | n_mels=512,
67 | shape=(512, 512),
68 | norm=80,
69 | ),
70 | ],
71 | spec_frames=SimpleNamespace(
72 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
73 | ),
74 | )
75 | transform = dict(
76 | train=transforms.Compose(
77 | [
78 | transforms.Resize((512, 512)),
79 | XYMasking(
80 | num_masks_x=(0, 2),
81 | num_masks_y=(0, 2),
82 | mask_x_length=(10, 40),
83 | mask_y_length=(10, 30),
84 | fill_value=0,
85 | p=0.5,
86 | ),
87 | # transforms.ToTensor(),
88 | ]
89 | ),
90 | valid=transforms.Compose(
91 | [
92 | transforms.Resize((512, 512)),
93 | # transforms.ToTensor()
94 | ]
95 | ),
96 | )
97 |
98 | loss = [
99 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
100 | (SimpleNamespace(name="mse"), 0.2),
101 | ]
102 |
103 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4)
104 |
105 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9)
106 |
107 | model = SimpleNamespace(
108 | name="multi_specv2",
109 | multi_spec=SimpleNamespace(
110 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
111 | pretrained=True,
112 | num_classes=1,
113 | pool_type="catavgmax",
114 | # feature_height=16,
115 | atten=True,
116 | # classifier=None,
117 | ),
118 | )
119 |
120 | run = SimpleNamespace(
121 | mixup=True,
122 | mixup_alpha=0.4,
123 | num_epochs=5,
124 | )
125 |
126 | main_metric = "sys_srcc"
127 | id_name = None
128 |
129 |
130 | inference = SimpleNamespace(
131 | save_path=Path("preds"),
132 | submit_save_path=Path("submissions"),
133 | num_tta=5,
134 | batch_size=8,
135 | extend="tile",
136 | )
137 |
--------------------------------------------------------------------------------
/utmosv2/config/c_ssl_only_stage1.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="stratified_group",
17 | target="mos",
18 | group="sys_id",
19 | )
20 |
21 | dataset = SimpleNamespace(
22 | name="sslext",
23 | ssl=SimpleNamespace(
24 | duration=3,
25 | ),
26 | )
27 |
28 | external_data: list[str] | str = "all"
29 | use_bvcc = True
30 |
31 |
32 | validation_dataset = "sarulab"
33 |
34 | loss = [
35 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
36 | (SimpleNamespace(name="mse"), 0.2),
37 | ]
38 |
39 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
40 |
41 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
42 |
43 | model_path = "model"
44 | model = SimpleNamespace(
45 | name="sslext",
46 | ssl=SimpleNamespace(
47 | name="facebook/wav2vec2-base",
48 | attn=1,
49 | freeze=True,
50 | num_classes=1,
51 | ),
52 | )
53 |
54 | run = SimpleNamespace(
55 | mixup=True,
56 | mixup_alpha=0.4,
57 | num_epochs=20,
58 | )
59 |
60 | main_metric = "sys_srcc"
61 | id_name = None
62 |
63 |
64 | inference = SimpleNamespace(
65 | save_path=Path("preds"),
66 | submit_save_path=Path("submissions"),
67 | num_tta=5,
68 | batch_size=8,
69 | # extend="tile",
70 | )
71 |
--------------------------------------------------------------------------------
/utmosv2/config/c_ssl_only_stage2.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="stratified_group",
17 | target="mos",
18 | group="sys_id",
19 | )
20 |
21 | dataset = SimpleNamespace(
22 | name="sslext",
23 | ssl=SimpleNamespace(
24 | duration=3,
25 | ),
26 | )
27 |
28 | external_data: list[str] | str = "all"
29 | use_bvcc = True
30 |
31 |
32 | validation_dataset = "sarulab"
33 |
34 | loss = [
35 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
36 | (SimpleNamespace(name="mse"), 0.2),
37 | ]
38 |
39 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4)
40 |
41 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9)
42 |
43 | model_path = "model"
44 | model = SimpleNamespace(
45 | name="sslext",
46 | ssl=SimpleNamespace(
47 | name="facebook/wav2vec2-base",
48 | attn=1,
49 | freeze=False,
50 | num_classes=1,
51 | ),
52 | )
53 |
54 | run = SimpleNamespace(
55 | mixup=True,
56 | mixup_alpha=0.4,
57 | num_epochs=5,
58 | )
59 |
60 | main_metric = "sys_srcc"
61 | id_name = None
62 |
63 |
64 | inference = SimpleNamespace(
65 | save_path=Path("preds"),
66 | submit_save_path=Path("submissions"),
67 | num_tta=5,
68 | batch_size=8,
69 | # extend="tile",
70 | )
71 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_stage2.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from types import SimpleNamespace
4 |
5 | from torchvision import transforms
6 |
7 | from utmosv2.transform import XYMasking
8 |
9 | batch_size = 16
10 | num_folds = 5
11 |
12 | sr = 16000
13 |
14 | preprocess = SimpleNamespace(
15 | top_db=30, min_seconds=None, save_path="preprocessed_data/clip_audio"
16 | )
17 |
18 | split = SimpleNamespace(
19 | type="sgkf_kind",
20 | target="mos",
21 | group="sys_id",
22 | kind="dataset",
23 | )
24 |
25 | external_data: list[str] | str = "all"
26 | use_bvcc = True
27 |
28 | predict_dataset = "ysaito"
29 | # predict_dataset = "bvcc"
30 |
31 | validation_dataset = "each"
32 |
33 | dataset = SimpleNamespace(
34 | name="ssl_multispec_ext",
35 | specs=[
36 | SimpleNamespace(
37 | mode="melspec",
38 | n_fft=4096,
39 | hop_length=32,
40 | win_length=4096,
41 | n_mels=512,
42 | shape=(512, 512),
43 | norm=80,
44 | ),
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=2048,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=1024,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=512,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | ],
73 | spec_frames=SimpleNamespace(
74 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
75 | ),
76 | ssl=SimpleNamespace(
77 | duration=3,
78 | ),
79 | )
80 | transform = dict(
81 | train=transforms.Compose(
82 | [
83 | transforms.Resize((512, 512)),
84 | XYMasking(
85 | num_masks_x=(0, 2),
86 | num_masks_y=(0, 2),
87 | mask_x_length=(10, 40),
88 | mask_y_length=(10, 30),
89 | fill_value=0,
90 | p=0.5,
91 | ),
92 | # transforms.ToTensor(),
93 | ]
94 | ),
95 | valid=transforms.Compose(
96 | [
97 | transforms.Resize((512, 512)),
98 | # transforms.ToTensor()
99 | ]
100 | ),
101 | )
102 |
103 | loss = [
104 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
105 | (SimpleNamespace(name="mse"), 0.2),
106 | ]
107 |
108 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
109 |
110 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5)
111 |
112 | model = SimpleNamespace(
113 | name="ssl_multispec_ext_v2",
114 | multi_spec=SimpleNamespace(
115 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
116 | pretrained=True,
117 | num_classes=1,
118 | pool_type="catavgmax",
119 | # feature_height=16,
120 | atten=True,
121 | # classifier=None,
122 | ),
123 | ssl=SimpleNamespace(
124 | name="facebook/wav2vec2-base",
125 | attn=1,
126 | freeze=True,
127 | num_classes=1,
128 | ),
129 | ssl_spec=SimpleNamespace(
130 | ssl_weight="ssl_only_stage2",
131 | spec_weight="spec_only",
132 | num_classes=1,
133 | freeze=True,
134 | ),
135 | )
136 |
137 | run = SimpleNamespace(
138 | mixup=True,
139 | mixup_alpha=0.4,
140 | num_epochs=8,
141 | )
142 |
143 | main_metric = "sys_srcc"
144 | id_name = None
145 |
146 |
147 | inference = SimpleNamespace(
148 | save_path="preds",
149 | submit_save_path="submissions",
150 | num_tta=5,
151 | batch_size=8,
152 | extend="tile",
153 | )
154 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_stage2_wo_bc.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 16
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | "sarulab",
28 | # "blizzard2008",
29 | # "blizzard2009",
30 | # "blizzard2011",
31 | # "blizzard2010-EH1",
32 | # "blizzard2010-EH2",
33 | # "blizzard2010-ES1",
34 | # "blizzard2010-ES3",
35 | "somos",
36 | ]
37 | use_bvcc = True
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="ssl_multispec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | ssl=SimpleNamespace(
86 | duration=3,
87 | ),
88 | )
89 | transform = dict(
90 | train=transforms.Compose(
91 | [
92 | transforms.Resize((512, 512)),
93 | XYMasking(
94 | num_masks_x=(0, 2),
95 | num_masks_y=(0, 2),
96 | mask_x_length=(10, 40),
97 | mask_y_length=(10, 30),
98 | fill_value=0,
99 | p=0.5,
100 | ),
101 | # transforms.ToTensor(),
102 | ]
103 | ),
104 | valid=transforms.Compose(
105 | [
106 | transforms.Resize((512, 512)),
107 | # transforms.ToTensor()
108 | ]
109 | ),
110 | )
111 |
112 | loss = [
113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
114 | (SimpleNamespace(name="mse"), 0.2),
115 | ]
116 |
117 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
118 |
119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5)
120 |
121 | model = SimpleNamespace(
122 | name="ssl_multispec_ext_v2",
123 | multi_spec=SimpleNamespace(
124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
125 | pretrained=True,
126 | num_classes=1,
127 | pool_type="catavgmax",
128 | # feature_height=16,
129 | atten=True,
130 | # classifier=None,
131 | ),
132 | ssl=SimpleNamespace(
133 | name="facebook/wav2vec2-base",
134 | attn=1,
135 | freeze=True,
136 | num_classes=1,
137 | ),
138 | ssl_spec=SimpleNamespace(
139 | ssl_weight="ssl_only_stage2_wo_bc",
140 | spec_weight="spec_only_wo_bc",
141 | num_classes=1,
142 | freeze=True,
143 | ),
144 | )
145 |
146 | run = SimpleNamespace(
147 | mixup=True,
148 | mixup_alpha=0.4,
149 | num_epochs=8,
150 | )
151 |
152 | main_metric = "sys_srcc"
153 | id_name = None
154 |
155 |
156 | inference = SimpleNamespace(
157 | save_path=Path("preds"),
158 | submit_save_path=Path("submissions"),
159 | num_tta=5,
160 | batch_size=8,
161 | extend="tile",
162 | )
163 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_stage2_wo_bvcc.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 16
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | "sarulab",
28 | "blizzard2008",
29 | "blizzard2009",
30 | "blizzard2011",
31 | "blizzard2010-EH1",
32 | "blizzard2010-EH2",
33 | "blizzard2010-ES1",
34 | "blizzard2010-ES3",
35 | "somos",
36 | ]
37 | use_bvcc = False
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="ssl_multispec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | ssl=SimpleNamespace(
86 | duration=3,
87 | ),
88 | )
89 | transform = dict(
90 | train=transforms.Compose(
91 | [
92 | transforms.Resize((512, 512)),
93 | XYMasking(
94 | num_masks_x=(0, 2),
95 | num_masks_y=(0, 2),
96 | mask_x_length=(10, 40),
97 | mask_y_length=(10, 30),
98 | fill_value=0,
99 | p=0.5,
100 | ),
101 | # transforms.ToTensor(),
102 | ]
103 | ),
104 | valid=transforms.Compose(
105 | [
106 | transforms.Resize((512, 512)),
107 | # transforms.ToTensor()
108 | ]
109 | ),
110 | )
111 |
112 | loss = [
113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
114 | (SimpleNamespace(name="mse"), 0.2),
115 | ]
116 |
117 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
118 |
119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5)
120 |
121 | model = SimpleNamespace(
122 | name="ssl_multispec_ext_v2",
123 | multi_spec=SimpleNamespace(
124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
125 | pretrained=True,
126 | num_classes=1,
127 | pool_type="catavgmax",
128 | # feature_height=16,
129 | atten=True,
130 | # classifier=None,
131 | ),
132 | ssl=SimpleNamespace(
133 | name="facebook/wav2vec2-base",
134 | attn=1,
135 | freeze=True,
136 | num_classes=1,
137 | ),
138 | ssl_spec=SimpleNamespace(
139 | ssl_weight="ssl_only_stage2_wo_bvcc",
140 | spec_weight="spec_only_wo_bvcc",
141 | num_classes=1,
142 | freeze=True,
143 | ),
144 | )
145 |
146 | run = SimpleNamespace(
147 | mixup=True,
148 | mixup_alpha=0.4,
149 | num_epochs=8,
150 | )
151 |
152 | main_metric = "sys_srcc"
153 | id_name = None
154 |
155 |
156 | inference = SimpleNamespace(
157 | save_path=Path("preds"),
158 | submit_save_path=Path("submissions"),
159 | num_tta=5,
160 | batch_size=8,
161 | extend="tile",
162 | )
163 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_stage2_wo_sarulab.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 16
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | # "sarulab",
28 | "blizzard2008",
29 | "blizzard2009",
30 | "blizzard2011",
31 | "blizzard2010-EH1",
32 | "blizzard2010-EH2",
33 | "blizzard2010-ES1",
34 | "blizzard2010-ES3",
35 | "somos",
36 | ]
37 | use_bvcc = True
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="ssl_multispec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | ssl=SimpleNamespace(
86 | duration=3,
87 | ),
88 | )
89 | transform = dict(
90 | train=transforms.Compose(
91 | [
92 | transforms.Resize((512, 512)),
93 | XYMasking(
94 | num_masks_x=(0, 2),
95 | num_masks_y=(0, 2),
96 | mask_x_length=(10, 40),
97 | mask_y_length=(10, 30),
98 | fill_value=0,
99 | p=0.5,
100 | ),
101 | # transforms.ToTensor(),
102 | ]
103 | ),
104 | valid=transforms.Compose(
105 | [
106 | transforms.Resize((512, 512)),
107 | # transforms.ToTensor()
108 | ]
109 | ),
110 | )
111 |
112 | loss = [
113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
114 | (SimpleNamespace(name="mse"), 0.2),
115 | ]
116 |
117 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
118 |
119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5)
120 |
121 | model = SimpleNamespace(
122 | name="ssl_multispec_ext_v2",
123 | multi_spec=SimpleNamespace(
124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
125 | pretrained=True,
126 | num_classes=1,
127 | pool_type="catavgmax",
128 | # feature_height=16,
129 | atten=True,
130 | # classifier=None,
131 | ),
132 | ssl=SimpleNamespace(
133 | name="facebook/wav2vec2-base",
134 | attn=1,
135 | freeze=True,
136 | num_classes=1,
137 | ),
138 | ssl_spec=SimpleNamespace(
139 | ssl_weight="ssl_only_stage2_wo_sarulab",
140 | spec_weight="spec_only_wo_sarulab",
141 | num_classes=1,
142 | freeze=True,
143 | ),
144 | )
145 |
146 | run = SimpleNamespace(
147 | mixup=True,
148 | mixup_alpha=0.4,
149 | num_epochs=8,
150 | )
151 |
152 | main_metric = "sys_srcc"
153 | id_name = None
154 |
155 |
156 | inference = SimpleNamespace(
157 | save_path=Path("preds"),
158 | submit_save_path=Path("submissions"),
159 | num_tta=5,
160 | batch_size=8,
161 | extend="tile",
162 | )
163 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_stage2_wo_somos.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 16
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | "sarulab",
28 | "blizzard2008",
29 | "blizzard2009",
30 | "blizzard2011",
31 | "blizzard2010-EH1",
32 | "blizzard2010-EH2",
33 | "blizzard2010-ES1",
34 | "blizzard2010-ES3",
35 | # "somos",
36 | ]
37 | use_bvcc = True
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="ssl_multispec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | ssl=SimpleNamespace(
86 | duration=3,
87 | ),
88 | )
89 | transform = dict(
90 | train=transforms.Compose(
91 | [
92 | transforms.Resize((512, 512)),
93 | XYMasking(
94 | num_masks_x=(0, 2),
95 | num_masks_y=(0, 2),
96 | mask_x_length=(10, 40),
97 | mask_y_length=(10, 30),
98 | fill_value=0,
99 | p=0.5,
100 | ),
101 | # transforms.ToTensor(),
102 | ]
103 | ),
104 | valid=transforms.Compose(
105 | [
106 | transforms.Resize((512, 512)),
107 | # transforms.ToTensor()
108 | ]
109 | ),
110 | )
111 |
112 | loss = [
113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
114 | (SimpleNamespace(name="mse"), 0.2),
115 | ]
116 |
117 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
118 |
119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-5)
120 |
121 | model = SimpleNamespace(
122 | name="ssl_multispec_ext_v2",
123 | multi_spec=SimpleNamespace(
124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
125 | pretrained=True,
126 | num_classes=1,
127 | pool_type="catavgmax",
128 | # feature_height=16,
129 | atten=True,
130 | # classifier=None,
131 | ),
132 | ssl=SimpleNamespace(
133 | name="facebook/wav2vec2-base",
134 | attn=1,
135 | freeze=True,
136 | num_classes=1,
137 | ),
138 | ssl_spec=SimpleNamespace(
139 | ssl_weight="ssl_only_stage2_wo_somos",
140 | spec_weight="spec_only_wo_somos",
141 | num_classes=1,
142 | freeze=True,
143 | ),
144 | )
145 |
146 | run = SimpleNamespace(
147 | mixup=True,
148 | mixup_alpha=0.4,
149 | num_epochs=8,
150 | )
151 |
152 | main_metric = "sys_srcc"
153 | id_name = None
154 |
155 |
156 | inference = SimpleNamespace(
157 | save_path=Path("preds"),
158 | submit_save_path=Path("submissions"),
159 | num_tta=5,
160 | batch_size=8,
161 | extend="tile",
162 | )
163 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_stage3.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 8
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = "all"
27 | use_bvcc = True
28 |
29 |
30 | validation_dataset = "each"
31 |
32 | dataset = SimpleNamespace(
33 | name="ssl_multispec_ext",
34 | specs=[
35 | SimpleNamespace(
36 | mode="melspec",
37 | n_fft=4096,
38 | hop_length=32,
39 | win_length=4096,
40 | n_mels=512,
41 | shape=(512, 512),
42 | norm=80,
43 | ),
44 | SimpleNamespace(
45 | mode="melspec",
46 | n_fft=4096,
47 | hop_length=32,
48 | win_length=2048,
49 | n_mels=512,
50 | shape=(512, 512),
51 | norm=80,
52 | ),
53 | SimpleNamespace(
54 | mode="melspec",
55 | n_fft=4096,
56 | hop_length=32,
57 | win_length=1024,
58 | n_mels=512,
59 | shape=(512, 512),
60 | norm=80,
61 | ),
62 | SimpleNamespace(
63 | mode="melspec",
64 | n_fft=4096,
65 | hop_length=32,
66 | win_length=512,
67 | n_mels=512,
68 | shape=(512, 512),
69 | norm=80,
70 | ),
71 | ],
72 | spec_frames=SimpleNamespace(
73 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
74 | ),
75 | ssl=SimpleNamespace(
76 | duration=3,
77 | ),
78 | )
79 | transform = dict(
80 | train=transforms.Compose(
81 | [
82 | transforms.Resize((512, 512)),
83 | XYMasking(
84 | num_masks_x=(0, 2),
85 | num_masks_y=(0, 2),
86 | mask_x_length=(10, 40),
87 | mask_y_length=(10, 30),
88 | fill_value=0,
89 | p=0.5,
90 | ),
91 | # transforms.ToTensor(),
92 | ]
93 | ),
94 | valid=transforms.Compose(
95 | [
96 | transforms.Resize((512, 512)),
97 | # transforms.ToTensor()
98 | ]
99 | ),
100 | )
101 |
102 | loss = [
103 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
104 | (SimpleNamespace(name="mse"), 0.2),
105 | ]
106 |
107 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4)
108 |
109 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8)
110 |
111 | model = SimpleNamespace(
112 | name="ssl_multispec_ext_v2",
113 | multi_spec=SimpleNamespace(
114 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
115 | pretrained=True,
116 | num_classes=1,
117 | pool_type="catavgmax",
118 | # feature_height=16,
119 | atten=True,
120 | # classifier=None,
121 | ),
122 | ssl=SimpleNamespace(
123 | name="facebook/wav2vec2-base",
124 | attn=1,
125 | freeze=False,
126 | num_classes=1,
127 | ),
128 | ssl_spec=SimpleNamespace(
129 | ssl_weight="ssl_only_stage2",
130 | spec_weight="spec_only",
131 | num_classes=1,
132 | freeze=False,
133 | ),
134 | )
135 |
136 | run = SimpleNamespace(
137 | mixup=True,
138 | mixup_alpha=0.4,
139 | num_epochs=2,
140 | )
141 |
142 | main_metric = "sys_srcc"
143 | id_name = None
144 |
145 |
146 | inference = SimpleNamespace(
147 | save_path=Path("preds"),
148 | submit_save_path=Path("submissions"),
149 | num_tta=5,
150 | batch_size=8,
151 | extend="tile",
152 | )
153 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_stage3_wo_bc.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 8
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | "sarulab",
28 | # "blizzard2008",
29 | # "blizzard2009",
30 | # "blizzard2011",
31 | # "blizzard2010-EH1",
32 | # "blizzard2010-EH2",
33 | # "blizzard2010-ES1",
34 | # "blizzard2010-ES3",
35 | "somos",
36 | ]
37 | use_bvcc = True
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="ssl_multispec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | ssl=SimpleNamespace(
86 | duration=3,
87 | ),
88 | )
89 | transform = dict(
90 | train=transforms.Compose(
91 | [
92 | transforms.Resize((512, 512)),
93 | XYMasking(
94 | num_masks_x=(0, 2),
95 | num_masks_y=(0, 2),
96 | mask_x_length=(10, 40),
97 | mask_y_length=(10, 30),
98 | fill_value=0,
99 | p=0.5,
100 | ),
101 | # transforms.ToTensor(),
102 | ]
103 | ),
104 | valid=transforms.Compose(
105 | [
106 | transforms.Resize((512, 512)),
107 | # transforms.ToTensor()
108 | ]
109 | ),
110 | )
111 |
112 | loss = [
113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
114 | (SimpleNamespace(name="mse"), 0.2),
115 | ]
116 |
117 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4)
118 |
119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8)
120 |
121 | model = SimpleNamespace(
122 | name="ssl_multispec_ext_v2",
123 | multi_spec=SimpleNamespace(
124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
125 | pretrained=True,
126 | num_classes=1,
127 | pool_type="catavgmax",
128 | # feature_height=16,
129 | atten=True,
130 | # classifier=None,
131 | ),
132 | ssl=SimpleNamespace(
133 | name="facebook/wav2vec2-base",
134 | attn=1,
135 | freeze=False,
136 | num_classes=1,
137 | ),
138 | ssl_spec=SimpleNamespace(
139 | ssl_weight="ssl_only_stage2_wo_bc",
140 | spec_weight="spec_only_wo_bc",
141 | num_classes=1,
142 | freeze=False,
143 | ),
144 | )
145 |
146 | run = SimpleNamespace(
147 | mixup=True,
148 | mixup_alpha=0.4,
149 | num_epochs=2,
150 | )
151 |
152 | main_metric = "sys_srcc"
153 | id_name = None
154 |
155 |
156 | inference = SimpleNamespace(
157 | save_path=Path("preds"),
158 | submit_save_path=Path("submissions"),
159 | num_tta=5,
160 | batch_size=8,
161 | extend="tile",
162 | )
163 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_stage3_wo_bvcc.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 8
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | "sarulab",
28 | "blizzard2008",
29 | "blizzard2009",
30 | "blizzard2011",
31 | "blizzard2010-EH1",
32 | "blizzard2010-EH2",
33 | "blizzard2010-ES1",
34 | "blizzard2010-ES3",
35 | "somos",
36 | ]
37 | use_bvcc = False
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="ssl_multispec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | ssl=SimpleNamespace(
86 | duration=3,
87 | ),
88 | )
89 | transform = dict(
90 | train=transforms.Compose(
91 | [
92 | transforms.Resize((512, 512)),
93 | XYMasking(
94 | num_masks_x=(0, 2),
95 | num_masks_y=(0, 2),
96 | mask_x_length=(10, 40),
97 | mask_y_length=(10, 30),
98 | fill_value=0,
99 | p=0.5,
100 | ),
101 | # transforms.ToTensor(),
102 | ]
103 | ),
104 | valid=transforms.Compose(
105 | [
106 | transforms.Resize((512, 512)),
107 | # transforms.ToTensor()
108 | ]
109 | ),
110 | )
111 |
112 | loss = [
113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
114 | (SimpleNamespace(name="mse"), 0.2),
115 | ]
116 |
117 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4)
118 |
119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8)
120 |
121 | model = SimpleNamespace(
122 | name="ssl_multispec_ext_v2",
123 | multi_spec=SimpleNamespace(
124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
125 | pretrained=True,
126 | num_classes=1,
127 | pool_type="catavgmax",
128 | # feature_height=16,
129 | atten=True,
130 | # classifier=None,
131 | ),
132 | ssl=SimpleNamespace(
133 | name="facebook/wav2vec2-base",
134 | attn=1,
135 | freeze=False,
136 | num_classes=1,
137 | ),
138 | ssl_spec=SimpleNamespace(
139 | ssl_weight="ssl_only_stage2_wo_bvcc",
140 | spec_weight="spec_only_wo_bvcc",
141 | num_classes=1,
142 | freeze=False,
143 | ),
144 | )
145 |
146 | run = SimpleNamespace(
147 | mixup=True,
148 | mixup_alpha=0.4,
149 | num_epochs=2,
150 | )
151 |
152 | main_metric = "sys_srcc"
153 | id_name = None
154 |
155 |
156 | inference = SimpleNamespace(
157 | save_path=Path("preds"),
158 | submit_save_path=Path("submissions"),
159 | num_tta=5,
160 | batch_size=8,
161 | extend="tile",
162 | )
163 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_stage3_wo_sarulab.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 8
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | # "sarulab",
28 | "blizzard2008",
29 | "blizzard2009",
30 | "blizzard2011",
31 | "blizzard2010-EH1",
32 | "blizzard2010-EH2",
33 | "blizzard2010-ES1",
34 | "blizzard2010-ES3",
35 | "somos",
36 | ]
37 | use_bvcc = True
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="ssl_multispec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | ssl=SimpleNamespace(
86 | duration=3,
87 | ),
88 | )
89 | transform = dict(
90 | train=transforms.Compose(
91 | [
92 | transforms.Resize((512, 512)),
93 | XYMasking(
94 | num_masks_x=(0, 2),
95 | num_masks_y=(0, 2),
96 | mask_x_length=(10, 40),
97 | mask_y_length=(10, 30),
98 | fill_value=0,
99 | p=0.5,
100 | ),
101 | # transforms.ToTensor(),
102 | ]
103 | ),
104 | valid=transforms.Compose(
105 | [
106 | transforms.Resize((512, 512)),
107 | # transforms.ToTensor()
108 | ]
109 | ),
110 | )
111 |
112 | loss = [
113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
114 | (SimpleNamespace(name="mse"), 0.2),
115 | ]
116 |
117 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4)
118 |
119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8)
120 |
121 | model = SimpleNamespace(
122 | name="ssl_multispec_ext_v2",
123 | multi_spec=SimpleNamespace(
124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
125 | pretrained=True,
126 | num_classes=1,
127 | pool_type="catavgmax",
128 | # feature_height=16,
129 | atten=True,
130 | # classifier=None,
131 | ),
132 | ssl=SimpleNamespace(
133 | name="facebook/wav2vec2-base",
134 | attn=1,
135 | freeze=False,
136 | num_classes=1,
137 | ),
138 | ssl_spec=SimpleNamespace(
139 | ssl_weight="ssl_only_stage2_wo_sarulab",
140 | spec_weight="spec_only_wo_sarulab",
141 | num_classes=1,
142 | freeze=False,
143 | ),
144 | )
145 |
146 | run = SimpleNamespace(
147 | mixup=True,
148 | mixup_alpha=0.4,
149 | num_epochs=2,
150 | )
151 |
152 | main_metric = "sys_srcc"
153 | id_name = None
154 |
155 |
156 | inference = SimpleNamespace(
157 | save_path=Path("preds"),
158 | submit_save_path=Path("submissions"),
159 | num_tta=5,
160 | batch_size=8,
161 | extend="tile",
162 | )
163 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_stage3_wo_somos.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 8
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | "sarulab",
28 | "blizzard2008",
29 | "blizzard2009",
30 | "blizzard2011",
31 | "blizzard2010-EH1",
32 | "blizzard2010-EH2",
33 | "blizzard2010-ES1",
34 | "blizzard2010-ES3",
35 | # "somos",
36 | ]
37 | use_bvcc = True
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="ssl_multispec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | ssl=SimpleNamespace(
86 | duration=3,
87 | ),
88 | )
89 | transform = dict(
90 | train=transforms.Compose(
91 | [
92 | transforms.Resize((512, 512)),
93 | XYMasking(
94 | num_masks_x=(0, 2),
95 | num_masks_y=(0, 2),
96 | mask_x_length=(10, 40),
97 | mask_y_length=(10, 30),
98 | fill_value=0,
99 | p=0.5,
100 | ),
101 | # transforms.ToTensor(),
102 | ]
103 | ),
104 | valid=transforms.Compose(
105 | [
106 | transforms.Resize((512, 512)),
107 | # transforms.ToTensor()
108 | ]
109 | ),
110 | )
111 |
112 | loss = [
113 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
114 | (SimpleNamespace(name="mse"), 0.2),
115 | ]
116 |
117 | optimizer = SimpleNamespace(name="adamw", lr=5e-5, weight_decay=1e-4)
118 |
119 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-8)
120 |
121 | model = SimpleNamespace(
122 | name="ssl_multispec_ext_v2",
123 | multi_spec=SimpleNamespace(
124 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
125 | pretrained=True,
126 | num_classes=1,
127 | pool_type="catavgmax",
128 | # feature_height=16,
129 | atten=True,
130 | # classifier=None,
131 | ),
132 | ssl=SimpleNamespace(
133 | name="facebook/wav2vec2-base",
134 | attn=1,
135 | freeze=False,
136 | num_classes=1,
137 | ),
138 | ssl_spec=SimpleNamespace(
139 | ssl_weight="ssl_only_stage2_wo_somos",
140 | spec_weight="spec_only_wo_somos",
141 | num_classes=1,
142 | freeze=False,
143 | ),
144 | )
145 |
146 | run = SimpleNamespace(
147 | mixup=True,
148 | mixup_alpha=0.4,
149 | num_epochs=2,
150 | )
151 |
152 | main_metric = "sys_srcc"
153 | id_name = None
154 |
155 |
156 | inference = SimpleNamespace(
157 | save_path=Path("preds"),
158 | submit_save_path=Path("submissions"),
159 | num_tta=5,
160 | batch_size=8,
161 | extend="tile",
162 | )
163 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_wo_stage1and2.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 8
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = "all"
27 | use_bvcc = True
28 |
29 |
30 | validation_dataset = "each"
31 |
32 | dataset = SimpleNamespace(
33 | name="ssl_multispec_ext",
34 | specs=[
35 | SimpleNamespace(
36 | mode="melspec",
37 | n_fft=4096,
38 | hop_length=32,
39 | win_length=4096,
40 | n_mels=512,
41 | shape=(512, 512),
42 | norm=80,
43 | ),
44 | SimpleNamespace(
45 | mode="melspec",
46 | n_fft=4096,
47 | hop_length=32,
48 | win_length=2048,
49 | n_mels=512,
50 | shape=(512, 512),
51 | norm=80,
52 | ),
53 | SimpleNamespace(
54 | mode="melspec",
55 | n_fft=4096,
56 | hop_length=32,
57 | win_length=1024,
58 | n_mels=512,
59 | shape=(512, 512),
60 | norm=80,
61 | ),
62 | SimpleNamespace(
63 | mode="melspec",
64 | n_fft=4096,
65 | hop_length=32,
66 | win_length=512,
67 | n_mels=512,
68 | shape=(512, 512),
69 | norm=80,
70 | ),
71 | ],
72 | spec_frames=SimpleNamespace(
73 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
74 | ),
75 | ssl=SimpleNamespace(
76 | duration=3,
77 | ),
78 | )
79 | transform = dict(
80 | train=transforms.Compose(
81 | [
82 | transforms.Resize((512, 512)),
83 | XYMasking(
84 | num_masks_x=(0, 2),
85 | num_masks_y=(0, 2),
86 | mask_x_length=(10, 40),
87 | mask_y_length=(10, 30),
88 | fill_value=0,
89 | p=0.5,
90 | ),
91 | # transforms.ToTensor(),
92 | ]
93 | ),
94 | valid=transforms.Compose(
95 | [
96 | transforms.Resize((512, 512)),
97 | # transforms.ToTensor()
98 | ]
99 | ),
100 | )
101 |
102 | loss = [
103 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
104 | (SimpleNamespace(name="mse"), 0.2),
105 | ]
106 |
107 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
108 |
109 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
110 |
111 | model = SimpleNamespace(
112 | name="ssl_multispec_ext_v2",
113 | multi_spec=SimpleNamespace(
114 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
115 | pretrained=True,
116 | num_classes=1,
117 | pool_type="catavgmax",
118 | # feature_height=16,
119 | atten=True,
120 | # classifier=None,
121 | ),
122 | ssl=SimpleNamespace(
123 | name="facebook/wav2vec2-base",
124 | attn=1,
125 | freeze=False,
126 | num_classes=1,
127 | ),
128 | ssl_spec=SimpleNamespace(
129 | ssl_weight=None,
130 | spec_weight=None,
131 | num_classes=1,
132 | freeze=False,
133 | ),
134 | )
135 |
136 | run = SimpleNamespace(
137 | mixup=True,
138 | mixup_alpha=0.4,
139 | num_epochs=20,
140 | )
141 |
142 | main_metric = "sys_srcc"
143 | id_name = None
144 |
145 |
146 | inference = SimpleNamespace(
147 | save_path=Path("preds"),
148 | submit_save_path=Path("submissions"),
149 | num_tta=5,
150 | batch_size=8,
151 | extend="tile",
152 | )
153 |
--------------------------------------------------------------------------------
/utmosv2/config/fusion_wo_stage2.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 8
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = "all"
27 | use_bvcc = True
28 |
29 |
30 | validation_dataset = "each"
31 |
32 | dataset = SimpleNamespace(
33 | name="ssl_multispec_ext",
34 | specs=[
35 | SimpleNamespace(
36 | mode="melspec",
37 | n_fft=4096,
38 | hop_length=32,
39 | win_length=4096,
40 | n_mels=512,
41 | shape=(512, 512),
42 | norm=80,
43 | ),
44 | SimpleNamespace(
45 | mode="melspec",
46 | n_fft=4096,
47 | hop_length=32,
48 | win_length=2048,
49 | n_mels=512,
50 | shape=(512, 512),
51 | norm=80,
52 | ),
53 | SimpleNamespace(
54 | mode="melspec",
55 | n_fft=4096,
56 | hop_length=32,
57 | win_length=1024,
58 | n_mels=512,
59 | shape=(512, 512),
60 | norm=80,
61 | ),
62 | SimpleNamespace(
63 | mode="melspec",
64 | n_fft=4096,
65 | hop_length=32,
66 | win_length=512,
67 | n_mels=512,
68 | shape=(512, 512),
69 | norm=80,
70 | ),
71 | ],
72 | spec_frames=SimpleNamespace(
73 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
74 | ),
75 | ssl=SimpleNamespace(
76 | duration=3,
77 | ),
78 | )
79 | transform = dict(
80 | train=transforms.Compose(
81 | [
82 | transforms.Resize((512, 512)),
83 | XYMasking(
84 | num_masks_x=(0, 2),
85 | num_masks_y=(0, 2),
86 | mask_x_length=(10, 40),
87 | mask_y_length=(10, 30),
88 | fill_value=0,
89 | p=0.5,
90 | ),
91 | # transforms.ToTensor(),
92 | ]
93 | ),
94 | valid=transforms.Compose(
95 | [
96 | transforms.Resize((512, 512)),
97 | # transforms.ToTensor()
98 | ]
99 | ),
100 | )
101 |
102 | loss = [
103 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
104 | (SimpleNamespace(name="mse"), 0.2),
105 | ]
106 |
107 | optimizer = SimpleNamespace(name="adamw", lr=1e-4, weight_decay=1e-4)
108 |
109 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
110 |
111 | model = SimpleNamespace(
112 | name="ssl_multispec_ext_v2",
113 | multi_spec=SimpleNamespace(
114 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
115 | pretrained=True,
116 | num_classes=1,
117 | pool_type="catavgmax",
118 | # feature_height=16,
119 | atten=True,
120 | # classifier=None,
121 | ),
122 | ssl=SimpleNamespace(
123 | name="facebook/wav2vec2-base",
124 | attn=1,
125 | freeze=False,
126 | num_classes=1,
127 | ),
128 | ssl_spec=SimpleNamespace(
129 | ssl_weight="ssl_only_stage2",
130 | spec_weight="spec_only",
131 | num_classes=1,
132 | freeze=False,
133 | ),
134 | )
135 |
136 | run = SimpleNamespace(
137 | mixup=True,
138 | mixup_alpha=0.4,
139 | num_epochs=20,
140 | )
141 |
142 | main_metric = "sys_srcc"
143 | id_name = None
144 |
145 |
146 | inference = SimpleNamespace(
147 | save_path=Path("preds"),
148 | submit_save_path=Path("submissions"),
149 | num_tta=5,
150 | batch_size=8,
151 | extend="tile",
152 | )
153 |
--------------------------------------------------------------------------------
/utmosv2/config/spec_only.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 10
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = "all"
27 | use_bvcc = True
28 |
29 |
30 | validation_dataset = "each"
31 |
32 | dataset = SimpleNamespace(
33 | name="multi_spec_ext",
34 | specs=[
35 | SimpleNamespace(
36 | mode="melspec",
37 | n_fft=4096,
38 | hop_length=32,
39 | win_length=4096,
40 | n_mels=512,
41 | shape=(512, 512),
42 | norm=80,
43 | ),
44 | SimpleNamespace(
45 | mode="melspec",
46 | n_fft=4096,
47 | hop_length=32,
48 | win_length=2048,
49 | n_mels=512,
50 | shape=(512, 512),
51 | norm=80,
52 | ),
53 | SimpleNamespace(
54 | mode="melspec",
55 | n_fft=4096,
56 | hop_length=32,
57 | win_length=1024,
58 | n_mels=512,
59 | shape=(512, 512),
60 | norm=80,
61 | ),
62 | SimpleNamespace(
63 | mode="melspec",
64 | n_fft=4096,
65 | hop_length=32,
66 | win_length=512,
67 | n_mels=512,
68 | shape=(512, 512),
69 | norm=80,
70 | ),
71 | ],
72 | spec_frames=SimpleNamespace(
73 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
74 | ),
75 | )
76 | transform = dict(
77 | train=transforms.Compose(
78 | [
79 | transforms.Resize((512, 512)),
80 | XYMasking(
81 | num_masks_x=(0, 2),
82 | num_masks_y=(0, 2),
83 | mask_x_length=(10, 40),
84 | mask_y_length=(10, 30),
85 | fill_value=0,
86 | p=0.5,
87 | ),
88 | # transforms.ToTensor(),
89 | ]
90 | ),
91 | valid=transforms.Compose(
92 | [
93 | transforms.Resize((512, 512)),
94 | # transforms.ToTensor()
95 | ]
96 | ),
97 | )
98 |
99 | loss = [
100 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
101 | (SimpleNamespace(name="mse"), 0.2),
102 | ]
103 |
104 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
105 |
106 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
107 |
108 | model = SimpleNamespace(
109 | name="multi_spec_ext",
110 | multi_spec=SimpleNamespace(
111 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
112 | pretrained=True,
113 | num_classes=1,
114 | pool_type="catavgmax",
115 | # feature_height=16,
116 | atten=True,
117 | # classifier=None,
118 | ),
119 | )
120 |
121 | run = SimpleNamespace(
122 | mixup=True,
123 | mixup_alpha=0.4,
124 | num_epochs=20,
125 | )
126 |
127 | main_metric = "sys_srcc"
128 | id_name = None
129 |
130 |
131 | inference = SimpleNamespace(
132 | save_path=Path("preds"),
133 | submit_save_path=Path("submissions"),
134 | num_tta=5,
135 | batch_size=8,
136 | extend="tile",
137 | )
138 |
--------------------------------------------------------------------------------
/utmosv2/config/spec_only_wo_bc.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 10
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | "sarulab",
28 | # "blizzard2008",
29 | # "blizzard2009",
30 | # "blizzard2011",
31 | # "blizzard2010-EH1",
32 | # "blizzard2010-EH2",
33 | # "blizzard2010-ES1",
34 | # "blizzard2010-ES3",
35 | "somos",
36 | ]
37 | use_bvcc = True
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="multi_spec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | )
86 | transform = dict(
87 | train=transforms.Compose(
88 | [
89 | transforms.Resize((512, 512)),
90 | XYMasking(
91 | num_masks_x=(0, 2),
92 | num_masks_y=(0, 2),
93 | mask_x_length=(10, 40),
94 | mask_y_length=(10, 30),
95 | fill_value=0,
96 | p=0.5,
97 | ),
98 | # transforms.ToTensor(),
99 | ]
100 | ),
101 | valid=transforms.Compose(
102 | [
103 | transforms.Resize((512, 512)),
104 | # transforms.ToTensor()
105 | ]
106 | ),
107 | )
108 |
109 | loss = [
110 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
111 | (SimpleNamespace(name="mse"), 0.2),
112 | ]
113 |
114 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
115 |
116 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
117 |
118 | model = SimpleNamespace(
119 | name="multi_spec_ext",
120 | multi_spec=SimpleNamespace(
121 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
122 | pretrained=True,
123 | num_classes=1,
124 | pool_type="catavgmax",
125 | # feature_height=16,
126 | atten=True,
127 | # classifier=None,
128 | ),
129 | )
130 |
131 | run = SimpleNamespace(
132 | mixup=True,
133 | mixup_alpha=0.4,
134 | num_epochs=20,
135 | )
136 |
137 | main_metric = "sys_srcc"
138 | id_name = None
139 |
140 |
141 | inference = SimpleNamespace(
142 | save_path=Path("preds"),
143 | submit_save_path=Path("submissions"),
144 | num_tta=5,
145 | batch_size=8,
146 | extend="tile",
147 | )
148 |
--------------------------------------------------------------------------------
/utmosv2/config/spec_only_wo_bvcc.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 10
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | "sarulab",
28 | "blizzard2008",
29 | "blizzard2009",
30 | "blizzard2011",
31 | "blizzard2010-EH1",
32 | "blizzard2010-EH2",
33 | "blizzard2010-ES1",
34 | "blizzard2010-ES3",
35 | "somos",
36 | ]
37 | use_bvcc = False
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="multi_spec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | )
86 | transform = dict(
87 | train=transforms.Compose(
88 | [
89 | transforms.Resize((512, 512)),
90 | XYMasking(
91 | num_masks_x=(0, 2),
92 | num_masks_y=(0, 2),
93 | mask_x_length=(10, 40),
94 | mask_y_length=(10, 30),
95 | fill_value=0,
96 | p=0.5,
97 | ),
98 | # transforms.ToTensor(),
99 | ]
100 | ),
101 | valid=transforms.Compose(
102 | [
103 | transforms.Resize((512, 512)),
104 | # transforms.ToTensor()
105 | ]
106 | ),
107 | )
108 |
109 | loss = [
110 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
111 | (SimpleNamespace(name="mse"), 0.2),
112 | ]
113 |
114 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
115 |
116 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
117 |
118 | model = SimpleNamespace(
119 | name="multi_spec_ext",
120 | multi_spec=SimpleNamespace(
121 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
122 | pretrained=True,
123 | num_classes=1,
124 | pool_type="catavgmax",
125 | # feature_height=16,
126 | atten=True,
127 | # classifier=None,
128 | ),
129 | )
130 |
131 | run = SimpleNamespace(
132 | mixup=True,
133 | mixup_alpha=0.4,
134 | num_epochs=20,
135 | )
136 |
137 | main_metric = "sys_srcc"
138 | id_name = None
139 |
140 |
141 | inference = SimpleNamespace(
142 | save_path=Path("preds"),
143 | submit_save_path=Path("submissions"),
144 | num_tta=5,
145 | batch_size=8,
146 | extend="tile",
147 | )
148 |
--------------------------------------------------------------------------------
/utmosv2/config/spec_only_wo_sarulab.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 10
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | # "sarulab",
28 | "blizzard2008",
29 | "blizzard2009",
30 | "blizzard2011",
31 | "blizzard2010-EH1",
32 | "blizzard2010-EH2",
33 | "blizzard2010-ES1",
34 | "blizzard2010-ES3",
35 | "somos",
36 | ]
37 | use_bvcc = True
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="multi_spec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | )
86 | transform = dict(
87 | train=transforms.Compose(
88 | [
89 | transforms.Resize((512, 512)),
90 | XYMasking(
91 | num_masks_x=(0, 2),
92 | num_masks_y=(0, 2),
93 | mask_x_length=(10, 40),
94 | mask_y_length=(10, 30),
95 | fill_value=0,
96 | p=0.5,
97 | ),
98 | # transforms.ToTensor(),
99 | ]
100 | ),
101 | valid=transforms.Compose(
102 | [
103 | transforms.Resize((512, 512)),
104 | # transforms.ToTensor()
105 | ]
106 | ),
107 | )
108 |
109 | loss = [
110 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
111 | (SimpleNamespace(name="mse"), 0.2),
112 | ]
113 |
114 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
115 |
116 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
117 |
118 | model = SimpleNamespace(
119 | name="multi_spec_ext",
120 | multi_spec=SimpleNamespace(
121 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
122 | pretrained=True,
123 | num_classes=1,
124 | pool_type="catavgmax",
125 | # feature_height=16,
126 | atten=True,
127 | # classifier=None,
128 | ),
129 | )
130 |
131 | run = SimpleNamespace(
132 | mixup=True,
133 | mixup_alpha=0.4,
134 | num_epochs=20,
135 | )
136 |
137 | main_metric = "sys_srcc"
138 | id_name = None
139 |
140 |
141 | inference = SimpleNamespace(
142 | save_path=Path("preds"),
143 | submit_save_path=Path("submissions"),
144 | num_tta=5,
145 | batch_size=8,
146 | extend="tile",
147 | )
148 |
--------------------------------------------------------------------------------
/utmosv2/config/spec_only_wo_somos.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | from torchvision import transforms
7 |
8 | from utmosv2.transform import XYMasking
9 |
10 | batch_size = 10
11 | num_folds = 5
12 |
13 | sr = 16000
14 |
15 | preprocess = SimpleNamespace(
16 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
17 | )
18 |
19 | split = SimpleNamespace(
20 | type="sgkf_kind",
21 | target="mos",
22 | group="sys_id",
23 | kind="dataset",
24 | )
25 |
26 | external_data: list[str] | str = [
27 | "sarulab",
28 | "blizzard2008",
29 | "blizzard2009",
30 | "blizzard2011",
31 | "blizzard2010-EH1",
32 | "blizzard2010-EH2",
33 | "blizzard2010-ES1",
34 | "blizzard2010-ES3",
35 | # "somos",
36 | ]
37 | use_bvcc = True
38 |
39 |
40 | validation_dataset = "each"
41 |
42 | dataset = SimpleNamespace(
43 | name="multi_spec_ext",
44 | specs=[
45 | SimpleNamespace(
46 | mode="melspec",
47 | n_fft=4096,
48 | hop_length=32,
49 | win_length=4096,
50 | n_mels=512,
51 | shape=(512, 512),
52 | norm=80,
53 | ),
54 | SimpleNamespace(
55 | mode="melspec",
56 | n_fft=4096,
57 | hop_length=32,
58 | win_length=2048,
59 | n_mels=512,
60 | shape=(512, 512),
61 | norm=80,
62 | ),
63 | SimpleNamespace(
64 | mode="melspec",
65 | n_fft=4096,
66 | hop_length=32,
67 | win_length=1024,
68 | n_mels=512,
69 | shape=(512, 512),
70 | norm=80,
71 | ),
72 | SimpleNamespace(
73 | mode="melspec",
74 | n_fft=4096,
75 | hop_length=32,
76 | win_length=512,
77 | n_mels=512,
78 | shape=(512, 512),
79 | norm=80,
80 | ),
81 | ],
82 | spec_frames=SimpleNamespace(
83 | num_frames=2, frame_sec=1.4, mixup_inner=True, mixup_alpha=0.4, extend="tile"
84 | ),
85 | )
86 | transform = dict(
87 | train=transforms.Compose(
88 | [
89 | transforms.Resize((512, 512)),
90 | XYMasking(
91 | num_masks_x=(0, 2),
92 | num_masks_y=(0, 2),
93 | mask_x_length=(10, 40),
94 | mask_y_length=(10, 30),
95 | fill_value=0,
96 | p=0.5,
97 | ),
98 | # transforms.ToTensor(),
99 | ]
100 | ),
101 | valid=transforms.Compose(
102 | [
103 | transforms.Resize((512, 512)),
104 | # transforms.ToTensor()
105 | ]
106 | ),
107 | )
108 |
109 | loss = [
110 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
111 | (SimpleNamespace(name="mse"), 0.2),
112 | ]
113 |
114 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
115 |
116 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
117 |
118 | model = SimpleNamespace(
119 | name="multi_spec_ext",
120 | multi_spec=SimpleNamespace(
121 | backbone="tf_efficientnetv2_s.in21k_ft_in1k",
122 | pretrained=True,
123 | num_classes=1,
124 | pool_type="catavgmax",
125 | # feature_height=16,
126 | atten=True,
127 | # classifier=None,
128 | ),
129 | )
130 |
131 | run = SimpleNamespace(
132 | mixup=True,
133 | mixup_alpha=0.4,
134 | num_epochs=20,
135 | )
136 |
137 | main_metric = "sys_srcc"
138 | id_name = None
139 |
140 |
141 | inference = SimpleNamespace(
142 | save_path=Path("preds"),
143 | submit_save_path=Path("submissions"),
144 | num_tta=5,
145 | batch_size=8,
146 | extend="tile",
147 | )
148 |
--------------------------------------------------------------------------------
/utmosv2/config/ssl_only_stage1.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="sgkf_kind",
17 | target="mos",
18 | group="sys_id",
19 | kind="dataset",
20 | )
21 |
22 | dataset = SimpleNamespace(
23 | name="sslext",
24 | ssl=SimpleNamespace(
25 | duration=3,
26 | ),
27 | )
28 |
29 | external_data: list[str] | str = "all"
30 | use_bvcc = True
31 |
32 |
33 | validation_dataset = "each"
34 |
35 | loss = [
36 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
37 | (SimpleNamespace(name="mse"), 0.2),
38 | ]
39 |
40 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
41 |
42 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
43 |
44 | model_path = "model"
45 | model = SimpleNamespace(
46 | name="sslext",
47 | ssl=SimpleNamespace(
48 | name="facebook/wav2vec2-base",
49 | attn=1,
50 | freeze=True,
51 | num_classes=1,
52 | ),
53 | )
54 |
55 | run = SimpleNamespace(
56 | mixup=True,
57 | mixup_alpha=0.4,
58 | num_epochs=20,
59 | )
60 |
61 | main_metric = "sys_srcc"
62 | id_name = None
63 |
64 |
65 | inference = SimpleNamespace(
66 | save_path=Path("preds"),
67 | submit_save_path=Path("submissions"),
68 | num_tta=5,
69 | batch_size=8,
70 | # extend="tile",
71 | )
72 |
--------------------------------------------------------------------------------
/utmosv2/config/ssl_only_stage1_wo_bc.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="sgkf_kind",
17 | target="mos",
18 | group="sys_id",
19 | kind="dataset",
20 | )
21 |
22 | dataset = SimpleNamespace(
23 | name="sslext",
24 | ssl=SimpleNamespace(
25 | duration=3,
26 | ),
27 | )
28 |
29 | external_data: list[str] | str = [
30 | "sarulab",
31 | # "blizzard2008",
32 | # "blizzard2009",
33 | # "blizzard2011",
34 | # "blizzard2010-EH1",
35 | # "blizzard2010-EH2",
36 | # "blizzard2010-ES1",
37 | # "blizzard2010-ES3",
38 | "somos",
39 | ]
40 | use_bvcc = True
41 |
42 |
43 | validation_dataset = "each"
44 |
45 | loss = [
46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
47 | (SimpleNamespace(name="mse"), 0.2),
48 | ]
49 |
50 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
51 |
52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
53 |
54 | model_path = "model"
55 | model = SimpleNamespace(
56 | name="sslext",
57 | ssl=SimpleNamespace(
58 | name="facebook/wav2vec2-base",
59 | attn=1,
60 | freeze=True,
61 | num_classes=1,
62 | ),
63 | )
64 |
65 | run = SimpleNamespace(
66 | mixup=True,
67 | mixup_alpha=0.4,
68 | num_epochs=20,
69 | )
70 |
71 | main_metric = "sys_srcc"
72 | id_name = None
73 |
74 |
75 | inference = SimpleNamespace(
76 | save_path=Path("preds"),
77 | submit_save_path=Path("submissions"),
78 | num_tta=5,
79 | batch_size=8,
80 | # extend="tile",
81 | )
82 |
--------------------------------------------------------------------------------
/utmosv2/config/ssl_only_stage1_wo_bvcc.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="sgkf_kind",
17 | target="mos",
18 | group="sys_id",
19 | kind="dataset",
20 | )
21 |
22 | dataset = SimpleNamespace(
23 | name="sslext",
24 | ssl=SimpleNamespace(
25 | duration=3,
26 | ),
27 | )
28 |
29 | external_data: list[str] | str = [
30 | "sarulab",
31 | "blizzard2008",
32 | "blizzard2009",
33 | "blizzard2011",
34 | "blizzard2010-EH1",
35 | "blizzard2010-EH2",
36 | "blizzard2010-ES1",
37 | "blizzard2010-ES3",
38 | "somos",
39 | ]
40 | use_bvcc = False
41 |
42 |
43 | validation_dataset = "each"
44 |
45 | loss = [
46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
47 | (SimpleNamespace(name="mse"), 0.2),
48 | ]
49 |
50 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
51 |
52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
53 |
54 | model_path = "model"
55 | model = SimpleNamespace(
56 | name="sslext",
57 | ssl=SimpleNamespace(
58 | name="facebook/wav2vec2-base",
59 | attn=1,
60 | freeze=True,
61 | num_classes=1,
62 | ),
63 | )
64 |
65 | run = SimpleNamespace(
66 | mixup=True,
67 | mixup_alpha=0.4,
68 | num_epochs=20,
69 | )
70 |
71 | main_metric = "sys_srcc"
72 | id_name = None
73 |
74 |
75 | inference = SimpleNamespace(
76 | save_path=Path("preds"),
77 | submit_save_path=Path("submissions"),
78 | num_tta=5,
79 | batch_size=8,
80 | # extend="tile",
81 | )
82 |
--------------------------------------------------------------------------------
/utmosv2/config/ssl_only_stage1_wo_sarulab.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="sgkf_kind",
17 | target="mos",
18 | group="sys_id",
19 | kind="dataset",
20 | )
21 |
22 | dataset = SimpleNamespace(
23 | name="sslext",
24 | ssl=SimpleNamespace(
25 | duration=3,
26 | ),
27 | )
28 |
29 | external_data: list[str] | str = [
30 | # "sarulab",
31 | "blizzard2008",
32 | "blizzard2009",
33 | "blizzard2011",
34 | "blizzard2010-EH1",
35 | "blizzard2010-EH2",
36 | "blizzard2010-ES1",
37 | "blizzard2010-ES3",
38 | "somos",
39 | ]
40 | use_bvcc = True
41 |
42 |
43 | validation_dataset = "each"
44 |
45 | loss = [
46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
47 | (SimpleNamespace(name="mse"), 0.2),
48 | ]
49 |
50 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
51 |
52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
53 |
54 | model_path = "model"
55 | model = SimpleNamespace(
56 | name="sslext",
57 | ssl=SimpleNamespace(
58 | name="facebook/wav2vec2-base",
59 | attn=1,
60 | freeze=True,
61 | num_classes=1,
62 | ),
63 | )
64 |
65 | run = SimpleNamespace(
66 | mixup=True,
67 | mixup_alpha=0.4,
68 | num_epochs=20,
69 | )
70 |
71 | main_metric = "sys_srcc"
72 | id_name = None
73 |
74 |
75 | inference = SimpleNamespace(
76 | save_path=Path("preds"),
77 | submit_save_path=Path("submissions"),
78 | num_tta=5,
79 | batch_size=8,
80 | # extend="tile",
81 | )
82 |
--------------------------------------------------------------------------------
/utmosv2/config/ssl_only_stage1_wo_somos.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="sgkf_kind",
17 | target="mos",
18 | group="sys_id",
19 | kind="dataset",
20 | )
21 |
22 | dataset = SimpleNamespace(
23 | name="sslext",
24 | ssl=SimpleNamespace(
25 | duration=3,
26 | ),
27 | )
28 |
29 | external_data: list[str] | str = [
30 | "sarulab",
31 | "blizzard2008",
32 | "blizzard2009",
33 | "blizzard2011",
34 | "blizzard2010-EH1",
35 | "blizzard2010-EH2",
36 | "blizzard2010-ES1",
37 | "blizzard2010-ES3",
38 | # "somos",
39 | ]
40 | use_bvcc = True
41 |
42 |
43 | validation_dataset = "each"
44 |
45 | loss = [
46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
47 | (SimpleNamespace(name="mse"), 0.2),
48 | ]
49 |
50 | optimizer = SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4)
51 |
52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7)
53 |
54 | model_path = "model"
55 | model = SimpleNamespace(
56 | name="sslext",
57 | ssl=SimpleNamespace(
58 | name="facebook/wav2vec2-base",
59 | attn=1,
60 | freeze=True,
61 | num_classes=1,
62 | ),
63 | )
64 |
65 | run = SimpleNamespace(
66 | mixup=True,
67 | mixup_alpha=0.4,
68 | num_epochs=20,
69 | )
70 |
71 | main_metric = "sys_srcc"
72 | id_name = None
73 |
74 |
75 | inference = SimpleNamespace(
76 | save_path=Path("preds"),
77 | submit_save_path=Path("submissions"),
78 | num_tta=5,
79 | batch_size=8,
80 | # extend="tile",
81 | )
82 |
--------------------------------------------------------------------------------
/utmosv2/config/ssl_only_stage2.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="sgkf_kind",
17 | target="mos",
18 | group="sys_id",
19 | kind="dataset",
20 | )
21 |
22 | dataset = SimpleNamespace(
23 | name="sslext",
24 | ssl=SimpleNamespace(
25 | duration=3,
26 | ),
27 | )
28 |
29 | external_data: list[str] | str = "all"
30 | use_bvcc = True
31 |
32 |
33 | validation_dataset = "each"
34 |
35 | loss = [
36 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
37 | (SimpleNamespace(name="mse"), 0.2),
38 | ]
39 |
40 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4)
41 |
42 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9)
43 |
44 | model_path = "model"
45 | model = SimpleNamespace(
46 | name="sslext",
47 | ssl=SimpleNamespace(
48 | name="facebook/wav2vec2-base",
49 | attn=1,
50 | freeze=False,
51 | num_classes=1,
52 | ),
53 | )
54 |
55 | run = SimpleNamespace(
56 | mixup=True,
57 | mixup_alpha=0.4,
58 | num_epochs=5,
59 | )
60 |
61 | main_metric = "sys_srcc"
62 | id_name = None
63 |
64 |
65 | inference = SimpleNamespace(
66 | save_path=Path("preds"),
67 | submit_save_path=Path("submissions"),
68 | num_tta=5,
69 | batch_size=8,
70 | # extend="tile",
71 | )
72 |
--------------------------------------------------------------------------------
/utmosv2/config/ssl_only_stage2_wo_bc.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="sgkf_kind",
17 | target="mos",
18 | group="sys_id",
19 | kind="dataset",
20 | )
21 |
22 | dataset = SimpleNamespace(
23 | name="sslext",
24 | ssl=SimpleNamespace(
25 | duration=3,
26 | ),
27 | )
28 |
29 | external_data: list[str] | str = [
30 | "sarulab",
31 | # "blizzard2008",
32 | # "blizzard2009",
33 | # "blizzard2011",
34 | # "blizzard2010-EH1",
35 | # "blizzard2010-EH2",
36 | # "blizzard2010-ES1",
37 | # "blizzard2010-ES3",
38 | "somos",
39 | ]
40 | use_bvcc = True
41 |
42 |
43 | validation_dataset = "each"
44 |
45 | loss = [
46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
47 | (SimpleNamespace(name="mse"), 0.2),
48 | ]
49 |
50 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4)
51 |
52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9)
53 |
54 | model_path = "model"
55 | model = SimpleNamespace(
56 | name="sslext",
57 | ssl=SimpleNamespace(
58 | name="facebook/wav2vec2-base",
59 | attn=1,
60 | freeze=False,
61 | num_classes=1,
62 | ),
63 | )
64 |
65 | run = SimpleNamespace(
66 | mixup=True,
67 | mixup_alpha=0.4,
68 | num_epochs=5,
69 | )
70 |
71 | main_metric = "sys_srcc"
72 | id_name = None
73 |
74 |
75 | inference = SimpleNamespace(
76 | save_path=Path("preds"),
77 | submit_save_path=Path("submissions"),
78 | num_tta=5,
79 | batch_size=8,
80 | # extend="tile",
81 | )
82 |
--------------------------------------------------------------------------------
/utmosv2/config/ssl_only_stage2_wo_bvcc.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="sgkf_kind",
17 | target="mos",
18 | group="sys_id",
19 | kind="dataset",
20 | )
21 |
22 | dataset = SimpleNamespace(
23 | name="sslext",
24 | ssl=SimpleNamespace(
25 | duration=3,
26 | ),
27 | )
28 |
29 | external_data: list[str] | str = [
30 | "sarulab",
31 | "blizzard2008",
32 | "blizzard2009",
33 | "blizzard2011",
34 | "blizzard2010-EH1",
35 | "blizzard2010-EH2",
36 | "blizzard2010-ES1",
37 | "blizzard2010-ES3",
38 | "somos",
39 | ]
40 | use_bvcc = False
41 |
42 |
43 | validation_dataset = "each"
44 |
45 | loss = [
46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
47 | (SimpleNamespace(name="mse"), 0.2),
48 | ]
49 |
50 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4)
51 |
52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9)
53 |
54 | model_path = "model"
55 | model = SimpleNamespace(
56 | name="sslext",
57 | ssl=SimpleNamespace(
58 | name="facebook/wav2vec2-base",
59 | attn=1,
60 | freeze=False,
61 | num_classes=1,
62 | ),
63 | )
64 |
65 | run = SimpleNamespace(
66 | mixup=True,
67 | mixup_alpha=0.4,
68 | num_epochs=5,
69 | )
70 |
71 | main_metric = "sys_srcc"
72 | id_name = None
73 |
74 |
75 | inference = SimpleNamespace(
76 | save_path=Path("preds"),
77 | submit_save_path=Path("submissions"),
78 | num_tta=5,
79 | batch_size=8,
80 | # extend="tile",
81 | )
82 |
--------------------------------------------------------------------------------
/utmosv2/config/ssl_only_stage2_wo_sarulab.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="sgkf_kind",
17 | target="mos",
18 | group="sys_id",
19 | kind="dataset",
20 | )
21 |
22 | dataset = SimpleNamespace(
23 | name="sslext",
24 | ssl=SimpleNamespace(
25 | duration=3,
26 | ),
27 | )
28 |
29 | external_data: list[str] | str = [
30 | # "sarulab",
31 | "blizzard2008",
32 | "blizzard2009",
33 | "blizzard2011",
34 | "blizzard2010-EH1",
35 | "blizzard2010-EH2",
36 | "blizzard2010-ES1",
37 | "blizzard2010-ES3",
38 | "somos",
39 | ]
40 | use_bvcc = True
41 |
42 |
43 | validation_dataset = "each"
44 |
45 | loss = [
46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
47 | (SimpleNamespace(name="mse"), 0.2),
48 | ]
49 |
50 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4)
51 |
52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9)
53 |
54 | model_path = "model"
55 | model = SimpleNamespace(
56 | name="sslext",
57 | ssl=SimpleNamespace(
58 | name="facebook/wav2vec2-base",
59 | attn=1,
60 | freeze=False,
61 | num_classes=1,
62 | ),
63 | )
64 |
65 | run = SimpleNamespace(
66 | mixup=True,
67 | mixup_alpha=0.4,
68 | num_epochs=5,
69 | )
70 |
71 | main_metric = "sys_srcc"
72 | id_name = None
73 |
74 |
75 | inference = SimpleNamespace(
76 | save_path=Path("preds"),
77 | submit_save_path=Path("submissions"),
78 | num_tta=5,
79 | batch_size=8,
80 | # extend="tile",
81 | )
82 |
--------------------------------------------------------------------------------
/utmosv2/config/ssl_only_stage2_wo_somos.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from types import SimpleNamespace
5 |
6 | batch_size = 32
7 | num_folds = 5
8 |
9 | sr = 16000
10 |
11 | preprocess = SimpleNamespace(
12 | top_db=30, min_seconds=None, save_path=Path("preprocessed_data")
13 | )
14 |
15 | split = SimpleNamespace(
16 | type="sgkf_kind",
17 | target="mos",
18 | group="sys_id",
19 | kind="dataset",
20 | )
21 |
22 | dataset = SimpleNamespace(
23 | name="sslext",
24 | ssl=SimpleNamespace(
25 | duration=3,
26 | ),
27 | )
28 |
29 | external_data: list[str] | str = [
30 | "sarulab",
31 | "blizzard2008",
32 | "blizzard2009",
33 | "blizzard2011",
34 | "blizzard2010-EH1",
35 | "blizzard2010-EH2",
36 | "blizzard2010-ES1",
37 | "blizzard2010-ES3",
38 | # "somos",
39 | ]
40 | use_bvcc = True
41 |
42 |
43 | validation_dataset = "each"
44 |
45 | loss = [
46 | (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
47 | (SimpleNamespace(name="mse"), 0.2),
48 | ]
49 |
50 | optimizer = SimpleNamespace(name="adamw", lr=3e-5, weight_decay=1e-4)
51 |
52 | scheduler = SimpleNamespace(name="cosine", T_max=None, eta_min=1e-9)
53 |
54 | model_path = "model"
55 | model = SimpleNamespace(
56 | name="sslext",
57 | ssl=SimpleNamespace(
58 | name="facebook/wav2vec2-base",
59 | attn=1,
60 | freeze=False,
61 | num_classes=1,
62 | ),
63 | )
64 |
65 | run = SimpleNamespace(
66 | mixup=True,
67 | mixup_alpha=0.4,
68 | num_epochs=5,
69 | )
70 |
71 | main_metric = "sys_srcc"
72 | id_name = None
73 |
74 |
75 | inference = SimpleNamespace(
76 | save_path=Path("preds"),
77 | submit_save_path=Path("submissions"),
78 | num_tta=5,
79 | batch_size=8,
80 | # extend="tile",
81 | )
82 |
--------------------------------------------------------------------------------
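The ablation configs above are plain Python modules of `SimpleNamespace` settings rather than YAML files. As a minimal, hypothetical sketch (assuming the package is importable; the actual resolution of a `--config` name is handled by the entry scripts and `utmosv2/_settings`, and may differ), such a module can simply be imported by name and its attributes read directly:

```python
import importlib
from types import SimpleNamespace

# Hypothetical illustration: load one of the config modules above by name.
cfg = importlib.import_module("utmosv2.config.ssl_only_stage2")

print(cfg.batch_size)   # 32
print(cfg.optimizer)    # namespace(name='adamw', lr=3e-05, weight_decay=0.0001)
print(isinstance(cfg.model.ssl, SimpleNamespace))  # True
```
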
/utmosv2/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2.dataset.multi_spec import MultiSpecDataset, MultiSpecExtDataset
2 | from utmosv2.dataset.ssl import SSLDataset, SSLExtDataset
3 | from utmosv2.dataset.ssl_multispec import SSLLMultiSpecExtDataset
4 |
5 | __all__ = [
6 | "MultiSpecDataset",
7 | "MultiSpecExtDataset",
8 | "SSLLMultiSpecExtDataset",
9 | "SSLDataset",
10 | "SSLExtDataset",
11 | ]
12 |
--------------------------------------------------------------------------------
/utmosv2/dataset/_base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import abc
4 | from collections.abc import Callable
5 | from typing import TYPE_CHECKING
6 |
7 | import torch
8 |
9 | from utmosv2._settings._config import Config
10 |
11 | if TYPE_CHECKING:
12 | import pandas as pd
13 |
14 | from utmosv2.dataset._schema import DatasetSchema
15 |
16 |
17 | class _BaseDataset(torch.utils.data.Dataset, abc.ABC):
18 | def __init__(
19 | self,
20 | cfg: Config,
21 | data: "pd.DataFrame" | list[DatasetSchema],
22 | phase: str,
23 | transform: dict[str, Callable[[torch.Tensor], torch.Tensor]] | None = None,
24 | ):
25 | self.cfg = cfg
26 | self.data = data
27 | self.phase = phase
28 | self.transform = transform
29 |
30 | def __len__(self) -> int:
31 | return len(self.data)
32 |
33 | @abc.abstractmethod
34 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]:
35 | pass
36 |
--------------------------------------------------------------------------------
/utmosv2/dataset/_schema.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 |
6 |
7 | @dataclass
8 | class DatasetSchema:
9 | file_path: Path
10 | dataset: str
11 |     mos: float | None = None
12 |
--------------------------------------------------------------------------------
/utmosv2/dataset/_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import librosa
5 | import numpy as np
6 |
7 | from utmosv2._settings._config import Config
8 |
9 |
10 | def load_audio(cfg: Config, file: Path) -> np.ndarray:
11 | if file.suffix in [".wav", ".flac"]:
12 | y, sr = librosa.load(file, sr=None)
13 | y = librosa.resample(y, orig_sr=sr, target_sr=cfg.sr)
14 | else:
15 | y = np.load(file)
16 | return y
17 |
18 |
19 | def extend_audio(cfg: Config, y: np.ndarray, length: int, type: str) -> np.ndarray:
20 | if y.shape[0] > length:
21 | return y
22 | elif type == "tile":
23 | n = length // y.shape[0] + 1
24 | y = np.tile(y, n)
25 | return y
26 | else:
27 | raise NotImplementedError
28 |
29 |
30 | def select_random_start(y: np.ndarray, length: int) -> np.ndarray:
31 | start = np.random.randint(0, y.shape[0] - length)
32 | return y[start : start + length]
33 |
34 |
35 | def get_dataset_map(cfg: Config) -> dict[str, int]:
36 | if cfg.data_config:
37 | with open(cfg.data_config, "r") as f:
38 | datasets = json.load(f)
39 | return {d["name"]: i for i, d in enumerate(datasets["data"])}
40 | else:
41 | return {
42 | "bvcc": 0,
43 | "sarulab": 1,
44 | "blizzard2008": 2,
45 | "blizzard2009": 3,
46 | "blizzard2010-EH1": 4,
47 | "blizzard2010-EH2": 5,
48 | "blizzard2010-ES1": 6,
49 | "blizzard2010-ES3": 7,
50 | "blizzard2011": 8,
51 | "somos": 9,
52 | }
53 |
54 |
55 | def get_dataset_num(cfg: Config) -> int:
56 | return len(get_dataset_map(cfg))
57 |
--------------------------------------------------------------------------------
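The helpers above implement the fixed-length cropping used by all dataset classes: a clip shorter than the target window is tiled and a random window of the target length is then cut out. A minimal sketch (assuming the package is importable), mirroring `dataset.ssl.duration = 3` and `sr = 16000` from the configs above on a stand-in 1-second clip:

```python
import numpy as np
from types import SimpleNamespace

from utmosv2.dataset._utils import extend_audio, select_random_start

cfg = SimpleNamespace(sr=16000)
y = np.random.randn(16000).astype(np.float32)  # 1 s of fake audio

length = 3 * cfg.sr                             # 3 s target window
y_ext = extend_audio(cfg, y, length, type="tile")  # tiled to at least 3 s
crop = select_random_start(y_ext, length)          # random 3 s window
print(y_ext.shape, crop.shape)                     # (64000,) (48000,)
```
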
/utmosv2/dataset/multi_spec.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Callable
4 | from typing import TYPE_CHECKING
5 |
6 | import librosa
7 | import numpy as np
8 | import torch
9 |
10 | from utmosv2._settings._config import Config
11 | from utmosv2.dataset._base import _BaseDataset
12 | from utmosv2.dataset._utils import (
13 | extend_audio,
14 | get_dataset_map,
15 | load_audio,
16 | select_random_start,
17 | )
18 | from utmosv2.preprocess._preprocess import remove_silent_section
19 |
20 | if TYPE_CHECKING:
21 | import pandas as pd
22 |
23 | from utmosv2.dataset._schema import DatasetSchema
24 |
25 |
26 | class MultiSpecDataset(_BaseDataset):
27 | """
28 | Dataset class for mel-spectrogram feature extractor. This class is responsible for
29 | loading audio data, generating multiple spectrograms for each sample, and
30 | applying the necessary transformations.
31 |
32 | Args:
33 | cfg (SimpleNamespace): The configuration object containing dataset and model settings.
34 | data (list[DatasetSchema] | pd.DataFrame): The dataset containing file paths and labels.
35 | phase (str): The phase of the dataset, either "train" or any other phase (e.g., "valid").
36 |         transform (dict[str, Callable[[torch.Tensor], torch.Tensor]] | None): Transformation function to apply to spectrograms.
37 | """
38 |
39 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]:
40 | """
41 | Get the spectrogram and target MOS for a given index.
42 |
43 | Args:
44 | idx (int): Index of the sample.
45 |
46 | Returns:
47 | tuple: The spectrogram (torch.Tensor) and target MOS (torch.Tensor) for the sample.
48 | """
49 | row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx]
50 | file = row.file_path
51 | y = load_audio(self.cfg, file)
52 | if (
53 | hasattr(self.cfg.dataset, "remove_silent_section")
54 | and self.cfg.dataset.remove_silent_section
55 | ):
56 | y = remove_silent_section(y)
57 | specs = []
58 | length = int(self.cfg.dataset.spec_frames.frame_sec * self.cfg.sr)
59 | y = extend_audio(self.cfg, y, length, type=self.cfg.dataset.spec_frames.extend)
60 | for _ in range(self.cfg.dataset.spec_frames.num_frames):
61 | y1 = select_random_start(y, length)
62 | for spec_cfg in self.cfg.dataset.specs:
63 | spec = _make_spctrogram(self.cfg, spec_cfg, y1)
64 | if self.cfg.dataset.spec_frames.mixup_inner:
65 | y2 = select_random_start(y, length)
66 | spec2 = _make_spctrogram(self.cfg, spec_cfg, y2)
67 | lmd = np.random.beta(
68 | self.cfg.dataset.spec_frames.mixup_alpha,
69 | self.cfg.dataset.spec_frames.mixup_alpha,
70 | )
71 | spec = lmd * spec + (1 - lmd) * spec2
72 | spec = np.stack([spec, spec, spec], axis=0)
73 | # spec = np.transpose(spec, (1, 2, 0))
74 | spec_tensor = torch.tensor(spec, dtype=torch.float32)
75 | phase = "train" if self.phase == "train" else "valid"
76 | assert self.transform is not None, "Transform must be provided."
77 | spec_tensor = self.transform[phase](spec_tensor)
78 | specs.append(spec_tensor)
79 | spec_tensor = torch.stack(specs).float()
80 |
81 | target = row.mos or 0.0
82 | target = torch.tensor(target, dtype=torch.float32)
83 |
84 | return spec_tensor, target
85 |
86 |
87 | class MultiSpecExtDataset(MultiSpecDataset):
88 | """
89 | Dataset class for mel-spectrogram feature extractor with data-domain embedding.
90 |
91 | Args:
92 | cfg (SimpleNamespace | ModuleType):
93 | The configuration object containing dataset and model settings.
94 | data (pd.DataFrame | list[DatasetSchema]):
95 | The dataset containing file paths and labels.
96 | phase (str):
97 | The phase of the dataset, either "train" or any other phase (e.g., "valid").
98 | transform (dict[str, Callable[[torch.Tensor], torch.Tensor]] | None):
99 | Transformation function to apply to spectrograms.
100 | """
101 |
102 | def __init__(
103 | self,
104 | cfg: Config,
105 | data: "pd.DataFrame" | list[DatasetSchema],
106 | phase: str,
107 | transform: dict[str, Callable[[torch.Tensor], torch.Tensor]] | None = None,
108 | ):
109 | super().__init__(cfg, data, phase, transform)
110 | self.dataset_map = get_dataset_map(cfg)
111 |
112 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]:
113 | """
114 | Get the spectrogram, data-domain embedding, and target MOS for a given index.
115 |
116 | Args:
117 | idx (int): Index of the sample.
118 |
119 | Returns:
120 | tuple: A tuple containing the generated spectrogram (torch.Tensor), data-domain embedding (torch.Tensor),
121 | and target MOS (torch.Tensor).
122 | """
123 | spec, target = super().__getitem__(idx)
124 | row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx]
125 |
126 | d = np.zeros(len(self.dataset_map))
127 | d[self.dataset_map[row.dataset]] = 1
128 | dt = torch.tensor(d, dtype=torch.float32)
129 |
130 | return spec, dt, target
131 |
132 |
133 | def _make_spctrogram(cfg: Config, spec_cfg: Config, y: np.ndarray) -> np.ndarray:
134 | if spec_cfg.mode == "melspec":
135 | return _make_melspec(cfg, spec_cfg, y)
136 | elif spec_cfg.mode == "stft":
137 | return _make_stft(cfg, spec_cfg, y)
138 | else:
139 | raise NotImplementedError
140 |
141 |
142 | def _make_melspec(cfg: Config, spec_cfg: Config, y: np.ndarray) -> np.ndarray:
143 | spec = librosa.feature.melspectrogram(
144 | y=y,
145 | sr=cfg.sr,
146 | n_fft=spec_cfg.n_fft,
147 | hop_length=spec_cfg.hop_length,
148 | n_mels=spec_cfg.n_mels,
149 | win_length=spec_cfg.win_length,
150 | )
151 | spec = librosa.power_to_db(spec, ref=np.max)
152 | if spec_cfg.norm is not None:
153 | spec = (spec + spec_cfg.norm) / spec_cfg.norm
154 | return spec
155 |
156 |
157 | def _make_stft(cfg: Config, spec_cfg: Config, y: np.ndarray) -> np.ndarray:
158 | spec = librosa.stft(y=y, n_fft=spec_cfg.n_fft, hop_length=spec_cfg.hop_length)
159 | spec = np.abs(spec)
160 | spec = librosa.amplitude_to_db(spec)
161 | return spec
162 |
--------------------------------------------------------------------------------
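`MultiSpecDataset` builds its spectrograms from per-spectrogram settings in `cfg.dataset.specs`. The following sketch shows the shape of such a settings object and the resulting array; the parameter values here are purely illustrative assumptions, not the ones used by the actual fusion configs:

```python
import numpy as np
from types import SimpleNamespace

from utmosv2.dataset.multi_spec import _make_spctrogram

cfg = SimpleNamespace(sr=16000)
# Hypothetical spectrogram settings, only to show the expected fields.
spec_cfg = SimpleNamespace(
    mode="melspec", n_fft=4096, hop_length=32, n_mels=512, win_length=4096, norm=80
)

y = np.random.randn(2 * cfg.sr).astype(np.float32)  # 2 s of fake audio
spec = _make_spctrogram(cfg, spec_cfg, y)
print(spec.shape)  # (n_mels, frames); with these values, (512, 1001)
```
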
/utmosv2/dataset/ssl.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from utmosv2._settings._config import Config
9 | from utmosv2.dataset._base import _BaseDataset
10 | from utmosv2.dataset._utils import (
11 | extend_audio,
12 | get_dataset_map,
13 | load_audio,
14 | select_random_start,
15 | )
16 | from utmosv2.preprocess._preprocess import remove_silent_section
17 |
18 | if TYPE_CHECKING:
19 | import pandas as pd
20 |
21 | from utmosv2.dataset._schema import DatasetSchema
22 |
23 |
24 | class SSLDataset(_BaseDataset):
25 | """
26 | Dataset class for SSL (Self-Supervised Learning) feature extractor.
27 | This class handles audio loading, extending, and random selection of a segment from the audio.
28 |
29 | Args:
30 | cfg (SimpleNamespace | ModuleType):
31 | The configuration object containing dataset and model settings.
32 | data (pd.DataFrame | list[DatasetSchema]):
33 | The dataset containing file paths and MOS labels.
34 | phase (str):
35 | The phase of the dataset, either "train" or any other phase (e.g., "valid").
36 | transform (dict[str, Callable[[torch.Tensor], torch.Tensor]] | None):
37 | Transformation function to apply to spectrograms.
38 | """
39 |
40 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]:
41 | """
42 |         Get the processed audio and target MOS for a given index.
43 |
44 | Args:
45 | idx (int): Index of the sample.
46 | Returns:
47 |             tuple: A tuple containing the processed audio (torch.Tensor) and target MOS (torch.Tensor).
48 | """
49 | row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx]
50 | file = row.file_path
51 | y = load_audio(self.cfg, file)
52 | if (
53 | hasattr(self.cfg.dataset, "remove_silent_section")
54 | and self.cfg.dataset.remove_silent_section
55 | ):
56 | y = remove_silent_section(y)
57 | length = int(self.cfg.dataset.ssl.duration * self.cfg.sr)
58 | y = extend_audio(self.cfg, y, length, type="tile")
59 | y = select_random_start(y, length)
60 |
61 | target = row.mos or 0.0
62 | target = torch.tensor(target, dtype=torch.float32)
63 |
64 | return torch.from_numpy(y), target
65 |
66 |
67 | class SSLExtDataset(SSLDataset):
68 | """
69 |     Dataset class for SSL (Self-Supervised Learning) feature extractor with data-domain embedding.
70 |
71 | Args:
72 | cfg (SimpleNamespace | ModuleType):
73 | The configuration object containing dataset and model settings.
74 | data (pd.DataFrame | list[DatasetSchema]):
75 | The dataset containing file paths and MOS labels.
76 | phase (str):
77 | The phase of the dataset, either "train" or any other phase (e.g., "valid").
78 | """
79 |
80 | def __init__(
81 | self, cfg: Config, data: "pd.DataFrame" | list[DatasetSchema], phase: str
82 | ):
83 | super().__init__(cfg, data, phase)
84 | self.dataset_map = get_dataset_map(cfg)
85 |
86 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]:
87 | """
88 | Get the processed audio, data-domain embedding, and target MOS for a given index.
89 |
90 | Args:
91 | idx (int): Index of the sample.
92 | Returns:
93 | tuple: A tuple containing the processed audio (torch.Tensor), data-domain embedding (torch.Tensor),
94 | and target MOS (torch.Tensor).
95 | """
96 | y, target = super().__getitem__(idx)
97 | row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx]
98 |
99 | d = np.zeros(len(self.dataset_map))
100 | d[self.dataset_map[row.dataset]] = 1
101 | dt = torch.tensor(d, dtype=torch.float32)
102 |
103 | return y, dt, target
104 |
--------------------------------------------------------------------------------
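The data-domain embedding built in `SSLExtDataset.__getitem__` is a one-hot vector indexed by the dataset map. A small sketch using the default map (no `--data_config`), assuming the package is importable:

```python
import numpy as np
from types import SimpleNamespace

from utmosv2.dataset._utils import get_dataset_map

cfg = SimpleNamespace(data_config=None)   # fall back to the built-in map
dataset_map = get_dataset_map(cfg)

d = np.zeros(len(dataset_map))
d[dataset_map["somos"]] = 1
print(d)  # [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]  ("somos" is index 9)
```
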
/utmosv2/dataset/ssl_multispec.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Callable
4 | from typing import TYPE_CHECKING
5 |
6 | import torch
7 |
8 | from utmosv2._settings._config import Config
9 | from utmosv2.dataset import MultiSpecDataset, SSLExtDataset
10 | from utmosv2.dataset._base import _BaseDataset
11 |
12 | if TYPE_CHECKING:
13 | import pandas as pd
14 |
15 | from utmosv2.dataset._schema import DatasetSchema
16 |
17 |
18 | class SSLLMultiSpecExtDataset(_BaseDataset):
19 | """
20 | Dataset class that combines both SSL (Self-Supervised Learning) and Multi-Spectrogram datasets.
21 | This dataset uses both SSLExtDataset and MultiSpecDataset to provide different representations
22 | of the same audio sample.
23 |
24 | Args:
25 | cfg (SimpleNamespace | ModuleType):
26 | The configuration object containing dataset and model settings.
27 | data (pd.DataFrame | list[DatasetSchema]):
28 | The dataset containing file paths and MOS labels.
29 | phase (str):
30 | The phase of the dataset, either "train" or any other phase (e.g., "valid").
31 | transform (dict[str, Callable[[torch.Tensor], torch.Tensor]] | None):
32 | Transformation function to apply to spectrograms.
33 | """
34 |
35 | def __init__(
36 | self,
37 | cfg: Config,
38 | data: "pd.DataFrame" | list[DatasetSchema],
39 | phase: str,
40 | transform: dict[str, Callable[[torch.Tensor], torch.Tensor]] | None = None,
41 | ):
42 | super().__init__(cfg, data, phase, transform)
43 | self.ssl = SSLExtDataset(cfg, data, phase)
44 | self.multi_spec = MultiSpecDataset(cfg, data, phase, transform)
45 |
46 | def __len__(self) -> int:
47 | return len(self.data)
48 |
49 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, ...]:
50 | """
51 | Get data for SSL feature extractor, mel-spectrogram feature extractor, data-domain embedding, and target MOS for a given index.
52 |
53 | Args:
54 | idx (int): Index of the sample.
55 |
56 | Returns:
57 | tuple: data for SSL feature extractor (torch.Tensor), data for mel-spectrogram feature extractor (torch.Tensor),
58 | data-domain id (torch.Tensor), and target MOS (torch.Tensor).
59 | """
60 | x1, d, target = self.ssl[idx]
61 | x2, _ = self.multi_spec[idx]
62 |
63 | return x1, x2, d, target
64 |
--------------------------------------------------------------------------------
/utmosv2/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2.loss._losses import CombinedLoss, PairwizeDiffLoss
2 |
3 | __all__ = ["PairwizeDiffLoss", "CombinedLoss"]
4 |
--------------------------------------------------------------------------------
/utmosv2/loss/_losses.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class PairwizeDiffLoss(nn.Module):
9 | """
10 | Pairwise difference loss function for comparing input and target tensors.
11 | The loss is based on the difference between pairs of inputs and pairs of targets,
12 | with a specified margin and norm ("l1" or "l2_squared").
13 | """
14 |
15 | def __init__(self, margin: float = 0.2, norm: str = "l1"):
16 | """
17 | Initialize the PairwizeDiffLoss with the specified margin and norm.
18 |
19 | Args:
20 | margin (float):
21 | The margin value used for the loss function. Defaults to 0.2.
22 | norm (str):
23 | The norm to use for the difference calculation. Must be "l1" or "l2_squared". Defaults to "l1".
24 | """
25 | super().__init__()
26 | self.margin = margin
27 | self.norm = norm
28 |
29 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
30 | """
31 | Compute the pairwise difference loss between input and target tensors.
32 |
33 | Args:
34 | input (torch.Tensor): The input tensor.
35 | target (torch.Tensor): The target tensor.
36 |
37 | Returns:
38 | torch.Tensor: The computed loss.
39 | """
40 | s = input.unsqueeze(1) - input.unsqueeze(0)
41 | t = target.unsqueeze(1) - target.unsqueeze(0)
42 | if self.norm not in ["l1", "l2_squared"]:
43 | raise ValueError(
44 | f'Unknown norm: {self.norm}. Must be one of ["l1", "l2_squared"]'
45 | )
46 | norm_fn = {
47 | "l1": torch.abs,
48 | "l2_squared": lambda x: x**2,
49 | }[self.norm]
50 | loss = F.relu(norm_fn(s - t) - self.margin) # type: ignore
51 | return loss.mean().div(2)
52 |
53 |
54 | class CombinedLoss(nn.Module):
55 | """
56 | A combined loss function that allows for multiple loss functions to be weighted and combined.
57 |
58 | Args:
59 | weighted_losses (list[tuple[nn.Module, float]]):
60 | A list of loss functions and their associated weights.
61 | """
62 |
63 | def __init__(self, weighted_losses: list[tuple[nn.Module, float]]):
64 | super().__init__()
65 | self.weighted_losses = weighted_losses
66 |
67 | def forward(
68 | self, input: torch.Tensor, target: torch.Tensor
69 | ) -> list[tuple[float, torch.Tensor]]:
70 | """
71 | Compute the weighted loss for each loss function in the list.
72 |
73 | Args:
74 | input (torch.Tensor): The input tensor.
75 | target (torch.Tensor): The target tensor.
76 |
77 | Returns:
78 | list[tuple[float, torch.Tensor]]:
79 | A list of tuples where each contains a weight and the corresponding computed loss.
80 | """
81 | return [(w, loss(input, target)) for loss, w in self.weighted_losses]
82 |
--------------------------------------------------------------------------------
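These two classes implement the weighted loss listed in the configs above (`pairwize_diff` with weight 0.7 plus MSE with weight 0.2). A minimal sketch on a toy batch; note that `CombinedLoss` returns `(weight, loss)` pairs, and how they are reduced is up to the training loop (here a plain weighted sum is assumed for illustration):

```python
import torch
import torch.nn as nn

from utmosv2.loss import CombinedLoss, PairwizeDiffLoss

criterion = CombinedLoss(
    [(PairwizeDiffLoss(margin=0.2, norm="l1"), 0.7), (nn.MSELoss(), 0.2)]
)

pred = torch.tensor([3.1, 2.4, 4.0])  # stand-in predictions
mos = torch.tensor([3.0, 2.5, 4.5])   # stand-in MOS labels

total = sum(w * loss for w, loss in criterion(pred, mos))  # assumed reduction
print(total)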
/utmosv2/model/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2.model.multi_spec import MultiSpecExtModel, MultiSpecModelV2
2 | from utmosv2.model.ssl import SSLExtModel
3 | from utmosv2.model.ssl_multispec import SSLMultiSpecExtModelV1, SSLMultiSpecExtModelV2
4 |
5 | __all__ = [
6 | "MultiSpecExtModel",
7 | "MultiSpecModelV2",
8 | "SSLExtModel",
9 | "SSLMultiSpecExtModelV1",
10 | "SSLMultiSpecExtModelV2",
11 | ]
12 |
--------------------------------------------------------------------------------
/utmosv2/model/ssl.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from transformers import AutoFeatureExtractor, AutoModel
7 |
8 | from utmosv2._settings._config import Config
9 | from utmosv2.dataset._utils import get_dataset_num
10 |
11 |
12 | class _SSLEncoder(nn.Module):
13 | def __init__(self, sr: int, model_name: str, freeze: bool):
14 | super().__init__()
15 | self.sr = sr
16 | self.processor = AutoFeatureExtractor.from_pretrained(model_name)
17 | self.model = AutoModel.from_pretrained(model_name)
18 | if freeze:
19 | for param in self.model.parameters():
20 | param.requires_grad = False
21 |
22 | def forward(self, x: tuple[torch.Tensor]) -> tuple[torch.Tensor]:
23 | x = self.processor(
24 | [t.cpu().numpy() for t in x],
25 | sampling_rate=self.sr,
26 | return_tensors="pt",
27 | ).to(self.model.device)
28 | outputs = self.model(**x, output_hidden_states=True)
29 | return outputs.hidden_states
30 |
31 |
32 | class SSLExtModel(nn.Module):
33 | """
34 | A self-supervised learning (SSL) model extended with data-domain id.
35 | This model uses an encoder to process input data, applies attention layers if configured,
36 | and combines the features with data-domain embeddings before classification.
37 |
38 | Args:
39 | cfg (SimpleNamespace | ModuleType):
40 | Configuration object containing model and dataset settings.
41 | name (str | None):
42 | Optional name for the SSL encoder. Defaults to the name specified in `cfg.model.ssl.name`.
43 | """
44 |
45 | def __init__(self, cfg: Config, name: str | None = None):
46 | super().__init__()
47 | self.cfg = cfg
48 | self.encoder = _SSLEncoder(
49 | cfg.sr, name or cfg.model.ssl.name, cfg.model.ssl.freeze
50 | )
51 | hidden_num, in_features = get_ssl_output_shape(name or cfg.model.ssl.name)
52 | self.weights = nn.Parameter(F.softmax(torch.randn(hidden_num), dim=0))
53 | if cfg.model.ssl.attn:
54 | self.attn = nn.ModuleList(
55 | [
56 | nn.MultiheadAttention(
57 | embed_dim=in_features,
58 | num_heads=8,
59 | dropout=0.2,
60 | batch_first=True,
61 | )
62 | for _ in range(cfg.model.ssl.attn)
63 | ]
64 | )
65 | self.num_dataset = get_dataset_num(cfg)
66 | self.fc: nn.Linear | nn.Identity = nn.Linear(
67 | in_features * 2 + self.num_dataset, cfg.model.ssl.num_classes
68 | )
69 |
70 | def forward(self, xt: tuple[torch.Tensor], d: torch.Tensor) -> torch.Tensor:
71 | """
72 | Forward pass of the SSLExtModel.
73 |
74 | Args:
75 |             xt (torch.Tensor):
76 | Input tensor representing the features to be processed by the SSL encoder.
77 | d (torch.Tensor):
78 | Dataset-specific information tensor.
79 |
80 | Returns:
81 | torch.Tensor:
82 | Output tensor after applying the SSL encoder, attention (if configured), and fully connected layers.
83 | """
84 | xt = self.encoder(xt)
85 | x: torch.Tensor = sum([t * w for t, w in zip(xt, self.weights)])
86 | if self.cfg.model.ssl.attn:
87 | y = x
88 | for attn in self.attn:
89 | y, _ = attn(y, y, y)
90 | x = torch.cat([torch.mean(y, dim=1), torch.max(x, dim=1)[0]], dim=1)
91 | else:
92 | x = torch.cat([torch.mean(x, dim=1), torch.max(x, dim=1)[0]], dim=1)
93 | x = self.fc(torch.cat([x, d], dim=1))
94 | return x
95 |
96 |
97 | def get_ssl_output_shape(name: str) -> tuple[int, int]:
98 | if name in [
99 | "facebook/w2v-bert-2.0",
100 | "facebook/wav2vec2-large",
101 | "facebook/wav2vec2-large-robust",
102 | "facebook/wav2vec2-large-960h",
103 | "microsoft/wavlm-large",
104 | "facebook/wav2vec2-large-xlsr-53",
105 | ]:
106 | return 25, 1024
107 | elif name in [
108 | "facebook/hubert-base-ls960",
109 | "facebook/data2vec-audio-base-960h",
110 | "microsoft/wavlm-base",
111 | "microsoft/wavlm-base-plus",
112 | "microsoft/wavlm-base-plus-sv",
113 | "facebook/wav2vec2-base",
114 | ]:
115 | return 13, 768
116 | else:
117 | raise NotImplementedError
118 |
--------------------------------------------------------------------------------
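A minimal sketch of running `SSLExtModel` on a fake batch, assuming the settings from the `ssl_only` configs above (wav2vec2-base, one attention layer, the default 10 data domains). Instantiating the model downloads the Hugging Face checkpoint, and the waveform and domain vector below are stand-ins:

```python
import torch
from types import SimpleNamespace

from utmosv2.model import SSLExtModel

cfg = SimpleNamespace(
    sr=16000,
    data_config=None,
    model=SimpleNamespace(
        ssl=SimpleNamespace(
            name="facebook/wav2vec2-base", attn=1, freeze=True, num_classes=1
        )
    ),
)
model = SSLExtModel(cfg).eval()

wav = torch.randn(2, 3 * cfg.sr)   # batch of two 3 s clips
d = torch.zeros(2, 10)             # one-hot data-domain id ("bvcc")
d[:, 0] = 1
with torch.no_grad():
    print(model(wav, d).shape)     # torch.Size([2, 1])
```
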
/utmosv2/model/ssl_multispec.py:
--------------------------------------------------------------------------------
1 | from typing import cast
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from utmosv2._settings._config import Config
7 | from utmosv2.dataset._utils import get_dataset_num
8 | from utmosv2.model import MultiSpecExtModel, MultiSpecModelV2, SSLExtModel
9 |
10 |
11 | class SSLMultiSpecExtModelV1(nn.Module):
12 | def __init__(self, cfg: Config):
13 | super().__init__()
14 | self.cfg = cfg
15 | self.ssl = SSLExtModel(cfg)
16 | self.spec_long = MultiSpecModelV2(cfg)
17 | self.ssl.load_state_dict(
18 | torch.load(
19 | f"outputs/{cfg.model.ssl_spec.ssl_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
20 | )
21 | )
22 | self.spec_long.load_state_dict(
23 | torch.load(
24 | f"outputs/{cfg.model.ssl_spec.spec_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
25 | )
26 | )
27 | if cfg.model.ssl_spec.freeze:
28 | for param in self.ssl.parameters():
29 | param.requires_grad = False
30 | for param in self.spec_long.parameters():
31 | param.requires_grad = False
32 |         ssl_input = self.ssl.fc.in_features
33 |         spec_long_input = self.spec_long.fc.in_features
34 |         self.ssl.fc = nn.Identity()
35 |         self.spec_long.fc = nn.Identity()
36 | 
37 |         self.num_dataset = get_dataset_num(cfg)
38 | 
39 |         self.fc = nn.Linear(
40 |             cast(int, ssl_input) + cast(int, spec_long_input) + self.num_dataset,
41 |             cfg.model.ssl_spec.num_classes,
42 |         )
43 |
44 | def forward(
45 | self, x1: torch.Tensor, x2: torch.Tensor, d: torch.Tensor
46 | ) -> torch.Tensor:
47 | x1 = self.ssl(x1, torch.zeros(x1.shape[0], self.num_dataset).to(x1.device))
48 | x2 = self.spec_long(x2)
49 | x = torch.cat([x1, x2, d], dim=1)
50 | x = self.fc(x)
51 | return x
52 |
53 |
54 | class SSLMultiSpecExtModelV2(nn.Module):
55 | def __init__(self, cfg: Config):
56 | super().__init__()
57 | self.cfg = cfg
58 | self.ssl = SSLExtModel(cfg)
59 | self.spec_long = MultiSpecExtModel(cfg)
60 | if cfg.model.ssl_spec.ssl_weight is not None and cfg.phase == "train":
61 | self.ssl.load_state_dict(
62 | torch.load(
63 | f"outputs/{cfg.model.ssl_spec.ssl_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
64 | )
65 | )
66 | if cfg.model.ssl_spec.spec_weight is not None and cfg.phase == "train":
67 | self.spec_long.load_state_dict(
68 | torch.load(
69 | f"outputs/{cfg.model.ssl_spec.spec_weight}/fold{cfg.now_fold}_s{cfg.split.seed}_best_model.pth"
70 | )
71 | )
72 | if cfg.model.ssl_spec.freeze:
73 | for param in self.ssl.parameters():
74 | param.requires_grad = False
75 | for param in self.spec_long.parameters():
76 | param.requires_grad = False
77 | ssl_input = self.ssl.fc.in_features
78 | spec_long_input = self.spec_long.fc.in_features
79 | self.ssl.fc = nn.Identity()
80 | self.spec_long.fc = nn.Identity()
81 |
82 | self.num_dataset = get_dataset_num(cfg)
83 |
84 | self.fc = nn.Linear(
85 | cast(int, ssl_input) + cast(int, spec_long_input) + self.num_dataset,
86 | cfg.model.ssl_spec.num_classes,
87 | )
88 |
89 | def forward(
90 | self, x1: torch.Tensor, x2: torch.Tensor, d: torch.Tensor
91 | ) -> torch.Tensor:
92 | x1 = self.ssl(x1, torch.zeros(x1.shape[0], self.num_dataset).to(x1.device))
93 | x2 = self.spec_long(
94 | x2, torch.zeros(x1.shape[0], self.num_dataset).to(x1.device)
95 | )
96 | x = torch.cat([x1, x2, d], dim=1)
97 | x = self.fc(x)
98 | return x
99 |
--------------------------------------------------------------------------------
/utmosv2/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2.preprocess._preprocess import (
2 | add_sys_mean,
3 | preprocess,
4 | preprocess_test,
5 | remove_silent_section,
6 | )
7 |
8 | __all__ = ["add_sys_mean", "preprocess", "preprocess_test", "remove_silent_section"]
9 |
--------------------------------------------------------------------------------
/utmosv2/runner/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2.runner._inference import run_inference
2 | from utmosv2.runner._train import run_train, train_1epoch, validate_1epoch
3 |
4 | __all__ = [
5 | "run_train",
6 | "train_1epoch",
7 | "validate_1epoch",
8 | "run_inference",
9 | ]
10 |
--------------------------------------------------------------------------------
/utmosv2/runner/_inference.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | import numpy as np
6 | import torch
7 | from torch.cuda.amp import autocast
8 | from tqdm import tqdm
9 |
10 | from utmosv2._settings._config import Config
11 | from utmosv2.utils import calc_metrics, print_metrics
12 |
13 | if TYPE_CHECKING:
14 | import pandas as pd
15 |
16 |
17 | def run_inference(
18 | cfg: Config,
19 | model: torch.nn.Module,
20 | test_dataloader: torch.utils.data.DataLoader,
21 | cycle: int,
22 | test_data: "pd.DataFrame",
23 | device: torch.device,
24 | ) -> tuple[np.ndarray, dict[str, float] | None]:
25 | """
26 | Run inference on the test dataset using the provided model.
27 |
28 | Args:
29 | cfg (SimpleNamespace | ModuleType):
30 | Configuration object containing inference settings.
31 | It includes settings for test-time augmentation (TTA) and reproducibility.
32 | model (torch.nn.Module):
33 | The trained model to be used for inference.
34 | test_dataloader (torch.utils.data.DataLoader):
35 | Dataloader for the test dataset.
36 | cycle (int):
37 | Current cycle of test-time augmentation (TTA) if used.
38 | test_data (pd.DataFrame):
39 | DataFrame containing test data, used for metric calculation if reproducibility is enabled.
40 | device (torch.device):
41 | Device to run inference on (e.g., 'cuda' or 'cpu').
42 |
43 | Returns:
44 | tuple[np.ndarray, dict[str, float] | None]:
45 | - test_preds: Array containing the model's predictions for the test dataset.
46 | - test_metrics: Dictionary containing the calculated metrics if reproducibility is enabled; otherwise, None.
47 | """
48 | model.eval()
49 | test_preds_ls = []
50 | pbar = tqdm(
51 | test_dataloader,
52 | total=len(test_dataloader),
53 | desc=f" [Inference] ({cycle + 1}/{cfg.inference.num_tta})",
54 | )
55 |
56 | with torch.no_grad():
57 | for t in pbar:
58 | x = t[:-1]
59 | x = [t.to(device, non_blocking=True) for t in x]
60 | with autocast():
61 | output = model(*x).squeeze(1)
62 | test_preds_ls.append(output.cpu().numpy())
63 | test_preds = np.concatenate(test_preds_ls)
64 | if cfg.reproduce:
65 | test_metrics = calc_metrics(test_data, test_preds)
66 | print_metrics(test_metrics)
67 | else:
68 | test_metrics = None
69 |
70 | return test_preds, test_metrics
71 |
--------------------------------------------------------------------------------
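`run_inference` is called once per test-time-augmentation cycle and returns that cycle's predictions. A hypothetical illustration of how a caller might combine the `num_tta` cycles; the actual aggregation in this repository may differ, and the arrays below are stand-ins:

```python
import numpy as np

num_tta = 5
per_cycle_preds = [np.random.rand(100) for _ in range(num_tta)]  # stand-in outputs
final_preds = np.mean(np.stack(per_cycle_preds), axis=0)          # assumed averaging
print(final_preds.shape)  # (100,)
```
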
/utmosv2/transform/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2.transform._xymasking import XYMasking
2 |
3 | __all__ = ["XYMasking"]
4 |
--------------------------------------------------------------------------------
/utmosv2/transform/_xymasking.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | import numpy as np
6 |
7 | if TYPE_CHECKING:
8 | import torch
9 |
10 |
11 | class XYMasking:
12 | """
13 | Apply random rectangular masks to an image along the x and y axes. This augmentation
14 | is useful for randomly masking parts of an image during training to improve robustness.
15 |
16 | Args:
17 | num_masks_x (int | tuple[int, int]):
18 | The number of masks to apply along the x-axis.
19 | If a tuple is provided, a random number in the range will be used.
20 | num_masks_y (int | tuple[int, int]):
21 | The number of masks to apply along the y-axis.
22 | If a tuple is provided, a random number in the range will be used.
23 | mask_x_length (int | tuple[int, int]):
24 | The length of each mask along the x-axis.
25 | If a tuple is provided, a random length in the range will be used.
26 | mask_y_length (int | tuple[int, int]):
27 | The length of each mask along the y-axis.
28 | If a tuple is provided, a random length in the range will be used.
29 | fill_value (int):
30 | The value to fill the masked areas with.
31 | p (float):
32 | The probability of applying the masking. Defaults to 1.0 (always apply masking).
33 | """
34 |
35 | def __init__(
36 | self,
37 | num_masks_x: int | tuple[int, int],
38 | num_masks_y: int | tuple[int, int],
39 | mask_x_length: int | tuple[int, int],
40 | mask_y_length: int | tuple[int, int],
41 | fill_value: int,
42 | p: float = 1.0,
43 | ):
44 | self.num_masks_x = num_masks_x
45 | self.num_masks_y = num_masks_y
46 | self.mask_x_length = mask_x_length
47 | self.mask_y_length = mask_y_length
48 | self.fill_value = fill_value
49 | self.p = p
50 |
51 | def __call__(self, img: "torch.Tensor") -> "torch.Tensor":
52 | """
53 | Apply the XY masking to the given image.
54 |
55 | Args:
56 | img (torch.Tensor): The input image tensor of shape (channels, width, height).
57 |
58 | Returns:
59 | torch.Tensor: The image tensor with masks applied along the x and y axes.
60 | """
61 |         if np.random.rand() > self.p:
62 | return img
63 | _, width, height = img.shape
64 | num_masks_x = (
65 | np.random.randint(*self.num_masks_x)
66 | if isinstance(self.num_masks_x, tuple)
67 | else self.num_masks_x
68 | )
69 | for _ in range(num_masks_x):
70 | mask_x_length = (
71 | np.random.randint(*self.mask_x_length)
72 | if isinstance(self.mask_x_length, tuple)
73 | else self.mask_x_length
74 | )
75 | x = np.random.randint(0, width - mask_x_length)
76 | img[:, :, x : x + mask_x_length] = self.fill_value
77 |
78 | num_masks_y = (
79 | np.random.randint(*self.num_masks_y)
80 | if isinstance(self.num_masks_y, tuple)
81 | else self.num_masks_y
82 | )
83 | for _ in range(num_masks_y):
84 | mask_y_length = (
85 | np.random.randint(*self.mask_y_length)
86 | if isinstance(self.mask_y_length, tuple)
87 | else self.mask_y_length
88 | )
89 | y = np.random.randint(0, height - mask_y_length)
90 | img[:, y : y + mask_y_length, :] = self.fill_value
91 |
92 | return img
93 |
--------------------------------------------------------------------------------
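A minimal usage sketch of `XYMasking` on a fake 3-channel spectrogram tensor; the mask counts and lengths here are illustrative assumptions, not the values used by the actual training configs:

```python
import torch

from utmosv2.transform import XYMasking

masking = XYMasking(
    num_masks_x=(1, 3),
    num_masks_y=(1, 3),
    mask_x_length=(5, 20),
    mask_y_length=(5, 20),
    fill_value=0,
    p=0.5,
)
spec = torch.rand(3, 512, 512)   # stand-in (channels, freq, time) spectrogram
masked = masking(spec)           # masks are written in place
print(masked.shape)              # torch.Size([3, 512, 512])
```
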
/utmosv2/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2.utils._pure import (
2 | get_dataloader,
3 | get_loss,
4 | get_optimizer,
5 | get_scheduler,
6 | print_metrics,
7 | save_oof_preds,
8 | split_data,
9 | )
10 | from utmosv2.utils._task_dependents import (
11 | calc_metrics,
12 | get_data,
13 | get_dataset,
14 | get_inference_data,
15 | get_metrics,
16 | get_model,
17 | get_train_data,
18 | make_submission_file,
19 | save_preds,
20 | save_test_preds,
21 | show_inference_data,
22 | )
23 | from utmosv2.utils._download import download_pretrained_weights_from_hf
24 |
25 | __all__ = [
26 | "get_dataloader",
27 | "get_loss",
28 | "get_optimizer",
29 | "get_scheduler",
30 | "print_metrics",
31 | "save_oof_preds",
32 | "split_data",
33 | "calc_metrics",
34 | "get_data",
35 | "get_dataset",
36 | "get_inference_data",
37 | "get_train_data",
38 | "get_metrics",
39 | "get_model",
40 | "make_submission_file",
41 | "save_preds",
42 | "save_test_preds",
43 | "show_inference_data",
44 | "download_pretrained_weights_from_hf",
45 | ]
46 |
--------------------------------------------------------------------------------
/utmosv2/utils/_constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | _UTMOSV2_CHACHE = Path(os.getenv("UTMOSV2_CHACHE", "~/.cache/utmosv2")).expanduser()
5 |
--------------------------------------------------------------------------------
/utmosv2/utils/_download.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | from utmosv2.utils._constants import _UTMOSV2_CHACHE
4 |
5 |
6 | def download_pretrained_weights_from_github(cfg_name: str) -> None:
7 | if cfg_name != "fusion_stage3":
8 | raise ValueError(f"{cfg_name} is not stored.")
9 | print(f"Downloading pretrained weights for `{cfg_name}`...")
10 | try:
11 | subprocess.run(
12 | [
13 | "git",
14 | "clone",
15 | "--filter=blob:none",
16 | "--no-checkout",
17 | "https://github.com/sarulab-speech/UTMOSv2.git",
18 | _UTMOSV2_CHACHE.as_posix(),
19 | ],
20 | check=True,
21 | )
22 | subprocess.run(
23 | ["git", "sparse-checkout", "set", "models"],
24 | cwd=_UTMOSV2_CHACHE,
25 | check=True,
26 | )
27 | subprocess.run(
28 | ["git", "checkout"],
29 | cwd=_UTMOSV2_CHACHE,
30 | check=True,
31 | )
32 | except subprocess.CalledProcessError as e:
33 | print(f"Failed to download pretrained weights: {e}")
34 | print("Done.")
35 |
36 |
37 | def download_pretrained_weights_from_hf(cfg_name: str, now_fold: int) -> None:
38 | if cfg_name != "fusion_stage3":
39 | raise ValueError(f"{cfg_name} is not stored.")
40 | print(f"Downloading pretrained weights for `{cfg_name}`...")
41 | url = f"https://huggingface.co/sarulab-speech/UTMOSv2/resolve/main/fold{now_fold}_s42_best_model.pth"
42 | try:
43 | subprocess.run(
44 | [
45 | "wget",
46 | "-P",
47 | (_UTMOSV2_CHACHE / "models" / cfg_name).as_posix(),
48 | url,
49 | ]
50 | )
51 | except subprocess.CalledProcessError as e:
52 | print(f"Failed to download pretrained weights: {e}")
53 | print("Done.")
54 |
--------------------------------------------------------------------------------
/utmosv2/utils/_pure/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2.utils._pure.initializers import (
2 | get_dataloader,
3 | get_loss,
4 | get_optimizer,
5 | get_scheduler,
6 | )
7 | from utmosv2.utils._pure.metrics import print_metrics
8 | from utmosv2.utils._pure.save import save_oof_preds
9 | from utmosv2.utils._pure.split import split_data
10 |
11 | __all__ = [
12 | "get_dataloader",
13 | "get_loss",
14 | "get_optimizer",
15 | "get_scheduler",
16 | "print_metrics",
17 | "save_oof_preds",
18 | "split_data",
19 | ]
20 |
--------------------------------------------------------------------------------
/utmosv2/utils/_pure/initializers.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 |
7 | from utmosv2._settings._config import Config
8 | from utmosv2.loss import CombinedLoss, PairwizeDiffLoss
9 |
10 |
11 | def get_dataloader(
12 | cfg: Config, dataset: torch.utils.data.Dataset, phase: str
13 | ) -> torch.utils.data.DataLoader:
14 | """
15 | Return a DataLoader for the specified dataset and phase.
16 |
17 | Args:
18 | cfg (SimpleNamespace | ModuleType):
19 | Configuration object containing settings for batch size, number of workers, and pin memory.
20 | dataset (torch.utils.data.Dataset):
21 | The dataset to load data from.
22 | phase (str):
23 | The phase of the training process. Must be one of ["train", "valid", "test"].
24 |
25 | Returns:
26 | torch.utils.data.DataLoader: A DataLoader for the given dataset and phase.
27 |
28 | Raises:
29 | ValueError: If the phase is not one of ["train", "valid", "test"].
30 | """
31 | if phase == "train":
32 | return torch.utils.data.DataLoader(
33 | dataset,
34 | batch_size=cfg.batch_size,
35 | shuffle=True,
36 | num_workers=cfg.num_workers,
37 | pin_memory=True,
38 | )
39 | elif phase == "valid":
40 | return torch.utils.data.DataLoader(
41 | dataset,
42 | batch_size=cfg.batch_size,
43 | shuffle=False,
44 | num_workers=cfg.num_workers,
45 | pin_memory=True,
46 | )
47 | elif phase == "test":
48 | return torch.utils.data.DataLoader(
49 | dataset,
50 | batch_size=cfg.inference.batch_size,
51 | shuffle=False,
52 | num_workers=cfg.num_workers,
53 | pin_memory=True,
54 | )
55 | else:
56 | raise ValueError(f"Phase must be one of [train, valid, test], but got {phase}")
57 |
58 |
59 | def _get_unit_loss(loss_cfg: Config) -> nn.Module:
60 | if loss_cfg.name == "pairwize_diff":
61 | return PairwizeDiffLoss(loss_cfg.margin, loss_cfg.norm)
62 | elif loss_cfg.name == "mse":
63 | return nn.MSELoss()
64 | else:
65 | raise NotImplementedError
66 |
67 |
68 | def _get_combined_loss(cfg: Config) -> nn.Module:
69 | if cfg.print_config:
70 | print(
71 | "Using losses: "
72 | + ", ".join([f"{loss_cfg.name} ({w})" for loss_cfg, w in cfg.loss])
73 | )
74 | weighted_losses = [(_get_unit_loss(loss_cfg), w) for loss_cfg, w in cfg.loss]
75 | return CombinedLoss(weighted_losses)
76 |
77 |
78 | def get_loss(cfg: Config) -> nn.Module:
79 | """
80 | Return the appropriate loss function based on the configuration.
81 |
82 | Args:
83 | cfg (SimpleNamespace | ModuleType):
84 | Configuration object containing the loss settings.
85 | If `cfg.loss` is a list, a combined loss is returned.
86 | Otherwise, a single loss function is returned.
87 |
88 | Returns:
89 | nn.Module: The configured loss function, either a single loss or a combined loss module.
90 | """
91 | if isinstance(cfg.loss, list):
92 | return _get_combined_loss(cfg)
93 | else:
94 | return _get_unit_loss(cfg.loss)
95 |
96 |
97 | def get_optimizer(cfg: Config, model: nn.Module) -> optim.Optimizer:
98 | """
99 | Return the optimizer based on the configuration settings.
100 |
101 | Args:
102 | cfg (SimpleNamespace | ModuleType):
103 | Configuration object containing optimizer settings.
104 | The optimizer name and learning rate are specified in `cfg.optimizer`.
105 | model (nn.Module):
106 | The model whose parameters will be optimized.
107 |
108 | Returns:
109 | optim.Optimizer: The configured optimizer (Adam, AdamW, or SGD).
110 |
111 | Raises:
112 | NotImplementedError: If the specified optimizer is not implemented.
113 | """
114 | if cfg.print_config:
115 | print(f"Using optimizer: {cfg.optimizer.name}")
116 | if cfg.optimizer.name == "adam":
117 | return optim.Adam(model.parameters(), lr=cfg.optimizer.lr)
118 | elif cfg.optimizer.name == "adamw":
119 | return optim.AdamW(
120 | model.parameters(),
121 | lr=cfg.optimizer.lr,
122 | weight_decay=cfg.optimizer.weight_decay,
123 | )
124 | elif cfg.optimizer.name == "sgd":
125 | return optim.SGD(
126 | model.parameters(),
127 | lr=cfg.optimizer.lr,
128 | weight_decay=cfg.optimizer.weight_decay,
129 | )
130 | else:
131 | raise NotImplementedError
132 |
133 |
134 | def get_scheduler(
135 | cfg: Config, optimizer: optim.Optimizer, n_iterations: int
136 | ) -> optim.lr_scheduler.LRScheduler:
137 | """
138 | Return the learning rate scheduler based on the configuration settings.
139 |
140 | Args:
141 | cfg (SimpleNamespace | ModuleType):
142 | Configuration object containing scheduler settings.
143 | The scheduler name, T_max, and eta_min are specified in `cfg.scheduler`.
144 | optimizer (optim.Optimizer):
145 | The optimizer for which the learning rate will be scheduled.
146 | n_iterations (int):
147 | The number of iterations for the scheduler (used in CosineAnnealingLR).
148 |
149 | Returns:
150 | optim.lr_scheduler.LRScheduler: The configured learning rate scheduler.
151 |
152 | Raises:
153 | NotImplementedError: If the specified scheduler is not implemented.
154 | """
155 | if cfg.print_config:
156 | print(f"Using scheduler: {cfg.scheduler}")
157 | if cfg.scheduler is None:
158 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1)
159 | if cfg.scheduler.name == "cosine":
160 | return optim.lr_scheduler.CosineAnnealingLR(
161 | optimizer,
162 | T_max=cfg.scheduler.T_max or n_iterations,
163 | eta_min=cfg.scheduler.eta_min,
164 | )
165 | else:
166 | raise NotImplementedError
167 |
--------------------------------------------------------------------------------
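A minimal sketch of wiring loss, optimizer, and scheduler from a config shaped like the `ssl_only` configs above. The `print_config` flag and the stand-in `nn.Linear` model are assumptions made only so the snippet is self-contained:

```python
import torch.nn as nn
from types import SimpleNamespace

from utmosv2.utils import get_loss, get_optimizer, get_scheduler

cfg = SimpleNamespace(
    print_config=False,
    loss=[
        (SimpleNamespace(name="pairwize_diff", margin=0.2, norm="l1"), 0.7),
        (SimpleNamespace(name="mse"), 0.2),
    ],
    optimizer=SimpleNamespace(name="adamw", lr=1e-3, weight_decay=1e-4),
    scheduler=SimpleNamespace(name="cosine", T_max=None, eta_min=1e-7),
)

model = nn.Linear(10, 1)  # stand-in model
criterion = get_loss(cfg)
optimizer = get_optimizer(cfg, model)
scheduler = get_scheduler(cfg, optimizer, n_iterations=1000)
print(type(criterion).__name__, type(optimizer).__name__, type(scheduler).__name__)
# CombinedLoss AdamW CosineAnnealingLR
```
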
/utmosv2/utils/_pure/metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 |
4 | def print_metrics(metrics: dict[str, float]) -> None:
5 | """
6 | Print the given metrics in a formatted string.
7 |
8 | Args:
9 | metrics (dict[str, float]):
10 | A dictionary of metric names and their corresponding values.
11 |
12 | Returns:
13 | None: This function prints the metrics to the console in the format "metric_name: value".
14 | """
15 | print(", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()]))
16 |
--------------------------------------------------------------------------------
/utmosv2/utils/_pure/save.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING
2 |
3 | from utmosv2._import import _LazyImport
4 | from utmosv2._settings._config import Config
5 |
6 | if TYPE_CHECKING:
7 | import numpy as np
8 | import pandas as pd
9 | else:
10 | pd = _LazyImport("pandas")
11 |
12 |
13 | def save_oof_preds(
14 | cfg: Config, data: "pd.DataFrame", oof_preds: "np.ndarray", fold: int
15 | ) -> None:
16 | """
17 | Save out-of-fold (OOF) predictions to a CSV file.
18 |
19 | Args:
20 | cfg (SimpleNamespace):
21 | Configuration object containing settings for saving OOF predictions.
22 | Includes `id_name` for the ID column and `save_path` for the save directory.
23 | data (pd.DataFrame):
24 | The original dataset containing the ID column.
25 | oof_preds (np.ndarray):
26 | The array of OOF predictions.
27 | fold (int):
28 | The current fold number used in cross-validation.
29 |
30 | Returns:
31 | None: The function saves the OOF predictions to a CSV file in the specified save path.
32 | """
33 | oof_df = pd.DataFrame({cfg.id_name: data[cfg.id_name], "oof_preds": oof_preds})
34 | oof_df.to_csv(
35 | cfg.save_path / f"fold{fold}_s{cfg.split.seed}_oof_preds.csv", index=False
36 | )
37 |
--------------------------------------------------------------------------------
/utmosv2/utils/_pure/split.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Generator
4 | from typing import TYPE_CHECKING
5 |
6 | import numpy as np
7 |
8 | from utmosv2._import import _LazyImport
9 | from utmosv2._settings._config import Config
10 |
11 | if TYPE_CHECKING:
12 | import pandas as pd
13 | from sklearn.model_selection import (
14 | GroupKFold,
15 | KFold,
16 | StratifiedGroupKFold,
17 | StratifiedKFold,
18 | )
19 | else:
20 | _model_selection = _LazyImport("sklearn.model_selection")
21 | GroupKFold = _model_selection.GroupKFold
22 | KFold = _model_selection.KFold
23 | StratifiedGroupKFold = _model_selection.StratifiedGroupKFold
24 | StratifiedKFold = _model_selection.StratifiedKFold
25 |
26 |
27 | def split_data(
28 | cfg: Config, data: "pd.DataFrame"
29 | ) -> Generator[tuple[np.ndarray, np.ndarray], None, None]:
30 | """
31 | Split the data into training and validation sets based on the specified splitting method in the configuration.
32 |
33 | Args:
34 | cfg (SimpleNamespace | ModuleType): Configuration object containing the splitting settings. It includes:
35 | - split.type: Type of split to use ('simple', 'stratified', 'group', 'stratified_group', etc.).
36 | - num_folds: Number of folds for K-Fold cross-validation.
37 | - split.seed: Random seed for shuffling.
38 | - split.target: Target column used for stratification in 'stratified' and 'stratified_group'.
39 | - split.group: Group column used for grouping in 'group' and 'stratified_group'.
40 | - split.kind: Kind of data for splitting in the 'sgkf_kind' case.
41 | data (pd.DataFrame): The dataset to be split.
42 |
43 | Yields:
44 | tuple[np.ndarray, np.ndarray]: Indices of training and validation sets for each fold.
45 |
46 | Raises:
47 | NotImplementedError: If the split type specified in the configuration is not implemented.
48 | """
49 | if cfg.print_config:
50 | print(f"Using split: {cfg.split.type}")
51 | if cfg.split.type == "simple":
52 | kf = KFold(n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed)
53 | for train_idx, valid_idx in kf.split(data):
54 | yield train_idx, valid_idx
55 | elif cfg.split.type == "stratified":
56 | kf = StratifiedKFold(
57 | n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed
58 | )
59 | for train_idx, valid_idx in kf.split(data, data[cfg.split.target].astype(int)):
60 | yield train_idx, valid_idx
61 | elif cfg.split.type == "group":
62 | kf = GroupKFold(n_splits=cfg.num_folds)
63 | for train_idx, valid_idx in kf.split(data, groups=data[cfg.split.group]):
64 | yield train_idx, valid_idx
65 | elif cfg.split.type == "stratified_group":
66 | kf = StratifiedGroupKFold(
67 | n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed
68 | )
69 | for train_idx, valid_idx in kf.split(
70 | data, data[cfg.split.target].astype(int), groups=data[cfg.split.group]
71 | ):
72 | yield train_idx, valid_idx
73 | elif cfg.split.type == "sgkf_kind":
74 | kind = data[cfg.split.kind].unique()
75 | kf = [
76 | StratifiedGroupKFold(
77 | n_splits=cfg.num_folds, shuffle=True, random_state=cfg.split.seed
78 | )
79 | for _ in range(len(kind))
80 | ]
81 | kf = [
82 | kf_i.split(
83 | data[data[cfg.split.kind] == ds],
84 | data[data[cfg.split.kind] == ds][cfg.split.target].astype(int),
85 | groups=data[data[cfg.split.kind] == ds][cfg.split.group],
86 | )
87 | for kf_i, ds in zip(kf, kind)
88 | ]
89 | for ds_idx in zip(*kf):
90 | train_idx = np.concatenate([d[0] for d in ds_idx])
91 | valid_idx = np.concatenate([d[1] for d in ds_idx])
92 | yield train_idx, valid_idx
93 | else:
94 | raise NotImplementedError
95 |
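A minimal sketch of how `split_data` is typically consumed in a cross-validation loop. Here `cfg` and `data` are placeholders for a loaded configuration (with `split.type`, `num_folds`, `split.seed`, and related fields set) and the full training DataFrame; the loop body is illustrative only:

```python
from utmosv2.utils._pure.split import split_data

# Iterate over the folds yielded by split_data; each item is a pair of
# integer index arrays for the training and validation subsets.
for fold, (train_idx, valid_idx) in enumerate(split_data(cfg, data)):
    train_df = data.iloc[train_idx].reset_index(drop=True)
    valid_df = data.iloc[valid_idx].reset_index(drop=True)
    print(f"fold {fold}: {len(train_df)} train / {len(valid_df)} valid samples")
```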
--------------------------------------------------------------------------------
/utmosv2/utils/_task_dependents/__init__.py:
--------------------------------------------------------------------------------
1 | from utmosv2.utils._task_dependents.initializers import (
2 | get_data,
3 | get_dataset,
4 | get_inference_data,
5 | get_metrics,
6 | get_model,
7 | get_train_data,
8 | )
9 | from utmosv2.utils._task_dependents.log import show_inference_data
10 | from utmosv2.utils._task_dependents.metrics import calc_metrics
11 | from utmosv2.utils._task_dependents.save import (
12 | make_submission_file,
13 | save_preds,
14 | save_test_preds,
15 | )
16 |
17 | __all__ = [
18 | "get_data",
19 | "get_dataset",
20 | "get_inference_data",
21 | "get_metrics",
22 | "get_model",
23 | "get_train_data",
24 | "show_inference_data",
25 | "calc_metrics",
26 | "make_submission_file",
27 | "save_preds",
28 | "save_test_preds",
29 | ]
30 |
--------------------------------------------------------------------------------
/utmosv2/utils/_task_dependents/log.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING
2 |
3 | if TYPE_CHECKING:
4 | import pandas as pd
5 |
6 |
7 | def show_inference_data(data: "pd.DataFrame") -> None:
8 | print(
9 | data[[c for c in data.columns if c != "mos"]]
10 | .rename(columns={"dataset": "predict_dataset"})
11 | .head()
12 | )
13 |
--------------------------------------------------------------------------------
/utmosv2/utils/_task_dependents/metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | import numpy as np
6 | import scipy.stats
7 |
8 | if TYPE_CHECKING:
9 | import pandas as pd
10 |
11 |
12 | def calc_metrics(data: "pd.DataFrame", preds: np.ndarray) -> dict[str, float]:
13 | data = data.copy()
14 | data["preds"] = preds
15 | data_sys = data.groupby("sys_id", as_index=False)[["mos", "preds"]].mean()
16 | res = {}
17 | for name, d in {"utt": data, "sys": data_sys}.items():
18 | res[f"{name}_mse"] = np.mean((d["mos"].values - d["preds"].values) ** 2)
19 | res[f"{name}_lcc"] = np.corrcoef(d["mos"].values, d["preds"].values)[0][1]
20 | res[f"{name}_srcc"] = scipy.stats.spearmanr(d["mos"].values, d["preds"].values)[
21 | 0
22 | ]
23 | res[f"{name}_ktau"] = scipy.stats.kendalltau(
24 | d["mos"].values, d["preds"].values
25 | )[0]
26 | return res
27 |
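`calc_metrics` reports MSE, linear correlation (LCC), Spearman rank correlation (SRCC), and Kendall's tau at both the utterance level and the system level (after averaging per `sys_id`). A minimal, self-contained sketch with toy data; all values are hypothetical:

```python
import numpy as np
import pandas as pd

from utmosv2.utils._task_dependents.metrics import calc_metrics

# Toy ground truth: two systems with two utterances each.
toy = pd.DataFrame(
    {
        "sys_id": ["sysA", "sysA", "sysB", "sysB"],
        "mos": [3.0, 3.5, 4.0, 4.5],
    }
)
toy_preds = np.array([3.1, 3.4, 3.9, 4.6])

scores = calc_metrics(toy, toy_preds)
# Keys: "utt_mse", "utt_lcc", "utt_srcc", "utt_ktau" and the matching
# "sys_*" values computed on per-system means.
print(scores)
```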
--------------------------------------------------------------------------------
/utmosv2/utils/_task_dependents/save.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | from typing import TYPE_CHECKING
5 |
6 | import numpy as np
7 |
8 | from utmosv2._import import _LazyImport
9 | from utmosv2._settings._config import Config
10 | from utmosv2.utils._task_dependents.initializers import _get_test_save_name
11 |
12 | if TYPE_CHECKING:
13 | import pandas as pd
14 | else:
15 | pd = _LazyImport("pandas")
16 |
17 |
18 | def save_test_preds(
19 | cfg: Config,
20 | data: "pd.DataFrame",
21 | test_preds: np.ndarray,
22 | test_metrics: dict[str, float],
23 | ) -> None:
24 | test_df = pd.DataFrame({cfg.id_name: data[cfg.id_name], "test_preds": test_preds})
25 | cfg.inference.save_path.mkdir(parents=True, exist_ok=True)
26 | save_path = (
27 | cfg.inference.save_path
28 | / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})_test_preds{'_final' if cfg.final else ''}.csv"
29 | )
30 | test_df.to_csv(save_path, index=False)
31 | save_path = (
32 | cfg.inference.save_path
33 | / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})_val_score{'_final' if cfg.final else ''}.json"
34 | )
35 | with open(save_path, "w") as f:
36 | json.dump(test_metrics, f)
37 | print(f"Test predictions are saved to {save_path}")
38 |
39 |
40 | def make_submission_file(
41 | cfg: Config, data: "pd.DataFrame", test_preds: np.ndarray
42 | ) -> None:
43 | submit = pd.DataFrame({cfg.id_name: data[cfg.id_name], "prediction": test_preds})
44 | (
45 | cfg.inference.submit_save_path
46 | / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})"
47 | ).mkdir(parents=True, exist_ok=True)
48 | sub_file = (
49 | cfg.inference.submit_save_path
50 | / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})"
51 | / "answer.txt"
52 | )
53 | submit.to_csv(
54 | sub_file,
55 | index=False,
56 | header=False,
57 | )
58 | print(f"Submission file is saved to {sub_file}")
59 |
60 |
61 | def save_preds(cfg: Config, data: "pd.DataFrame", test_preds: np.ndarray) -> None:
62 | pred = pd.DataFrame({cfg.id_name: data[cfg.id_name], "mos": test_preds})
63 | if cfg.out_path is None:
64 | print("Predictions:")
65 | print(pred)
66 | else:
67 | pred.to_csv(cfg.out_path, index=False)
68 | print(f"Predictions are saved to {cfg.out_path}")
69 |
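`save_preds` either prints the predictions (when `cfg.out_path` is `None`) or writes a two-column CSV whose first column is named by `cfg.id_name`. A hypothetical example of the written file, assuming `cfg.id_name` is `"utt_id"`; the MOS values are illustrative only:

```text
utt_id,mos
sys64e2f-utt491a78a,2.847
sys64e2f-utt8485f83,3.512
sys7ab3c-utt1417b69,4.102
```

`save_test_preds` and `make_submission_file` follow the same pattern, with the file prefix taken from `_get_test_save_name(cfg)` and the directories from `cfg.inference.save_path` and `cfg.inference.submit_save_path`, respectively.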
--------------------------------------------------------------------------------