├── .gitignore
├── LICENSE
├── README.md
├── ci
│   └── format.py
├── config
│   └── spin.yaml
├── figure
│   └── spin.png
├── prepare_data.py
├── requirements.txt
├── run_task.py
├── s3prl_py
│   ├── WavLM.py
│   ├── spin
│   │   ├── __init__.py
│   │   ├── expert.py
│   │   └── hubconf.py
│   └── wav2vec2_model.py
├── script
│   ├── prepare.sh
│   └── train.sh
└── src
    ├── data
    │   ├── __init__.py
    │   ├── audio.py
    │   ├── dataset.py
    │   ├── librispeech.py
    │   └── sampler.py
    ├── model
    │   ├── __init__.py
    │   ├── base.py
    │   └── spin.py
    ├── nn
    │   ├── __init__.py
    │   ├── dnn.py
    │   ├── hubert.py
    │   ├── swav_vq_dis.py
    │   └── wavlm.py
    ├── task
    │   ├── __init__.py
    │   └── train_spin.py
    └── util
        ├── __init__.py
        ├── log.py
        ├── model_utils.py
        ├── padding.py
        ├── pnmi.py
        └── scheduler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | _data/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Heng-Jui Chang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Speaker-invariant Clustering (Spin)
2 |
3 | - [Introduction](#introduction)
4 | - [Citation](#citation)
5 | - [Getting Started](#getting-started)
6 | - [Pre-trained Models](#pre-trained-models)
7 | - [References](#references)
8 | - [Contact](#contact)
9 |
14 |
15 | ## Introduction
16 |
17 |
18 |
19 | This repository is the official PyTorch implementation of the **Speaker-invariant Clustering** (**Spin**) proposed in the **Interspeech 2023** paper [Self-supervised Fine-tuning for Improved Content Representations by Speaker-invariant Clustering](https://arxiv.org/abs/2305.11072) ([Heng-Jui Chang](https://people.csail.mit.edu/hengjui/), [Alexander H. Liu](https://alexander-h-liu.github.io/), [James Glass](https://www.csail.mit.edu/person/jim-glass); [MIT CSAIL](https://www.csail.mit.edu/)).
20 |
21 | Spin is a novel self-supervised learning method that clusters speech representations and performs swapped prediction between the original and speaker-perturbed utterances. Spin *disentangles speaker information* and preserves *content representations* with just 45 minutes of fine-tuning on a single GPU (HuBERT Base models). Spin improves pre-trained networks and outperforms prior methods in speech recognition and acoustic unit discovery.
22 |
23 |
24 | ## Citation
25 | Please cite our paper if you find this repository and/or the paper useful.
26 | ```
27 | @inproceedings{chang2023spin,
28 | author={Heng-Jui Chang and Alexander H. Liu and James Glass},
29 | title={{Self-supervised Fine-tuning for Improved Content Representations by Speaker-invariant Clustering}},
30 | year=2023,
31 | booktitle={Proc. Interspeech}
32 | }
33 | ```
34 |
35 |
36 | ## Getting Started
37 |
38 | ### 1. Environment
39 | Make sure `sox` is installed and your Python version is at least `3.6`.
40 | ```bash
41 | # Create virtual environment
42 | conda create --name spin python=3.8
43 | conda activate spin
44 |
45 | # Install s3prl
46 | git clone https://github.com/s3prl/s3prl.git
47 | cd s3prl
48 | pip install -e ".[all]"
49 | cd ..
50 |
51 | # Clone this repository and install dependencies
52 | git clone https://github.com/vectominist/spin.git
53 | cd spin/
54 | pip install -r requirements.txt
55 |
56 | # Modify some s3prl files
57 | cp s3prl_py/wav2vec2_model.py ../s3prl/s3prl/upstream/wav2vec2/wav2vec2_model.py
58 | cp s3prl_py/WavLM.py ../s3prl/s3prl/upstream/wavlm/WavLM.py
59 | ```
60 |
61 |
62 | ### 2. Prepare Data
63 | Download required data.
64 | ```bash
65 | # Create a directory to save data (or any other path you like)
66 | mkdir data
67 | cd data
68 |
69 | # LibriSpeech (skip if you already have this)
70 | wget https://www.openslr.org/resources/12/train-clean-100.tar.gz
71 | wget https://www.openslr.org/resources/12/dev-clean.tar.gz
72 | wget https://www.openslr.org/resources/12/dev-other.tar.gz
73 | # Decompress
74 | tar zxvf train-clean-100.tar.gz
75 | tar zxvf dev-clean.tar.gz
76 | tar zxvf dev-other.tar.gz
77 | rm train-clean-100.tar.gz dev-clean.tar.gz dev-other.tar.gz
78 |
79 | # LibriSpeech Phoneme Alignments (for monitoring progress only)
80 | wget https://huggingface.co/datasets/vectominist/spin_data/resolve/main/dev-clean.tsv
81 | wget https://huggingface.co/datasets/vectominist/spin_data/resolve/main/dev-other.tsv
82 |
83 | # Speaker Information
84 | # Source: https://github.com/auspicious3000/contentvec
85 | wget https://huggingface.co/datasets/vectominist/spin_data/resolve/main/spk2info.dict
86 | ```
87 |
88 | Prepare the LibriSpeech dataset using [`script/prepare.sh`](https://github.com/vectominist/spin/blob/main/script/prepare.sh):
89 | - `libri_dir`: the directory of the LibriSpeech corpus
90 | - `json_dir`: the directory to save `.json` files generated from `prepare_data.py`
91 | ```bash
92 | bash script/prepare.sh ${libri_dir} ${json_dir}
93 | ```
94 |
95 | ### 3. Customize Configurations
96 | See [`config/spin.yaml`](https://github.com/vectominist/spin/blob/main/config/spin.yaml).
97 | - Modify `json_dir`, `spk2info`, and `phn_dir` to point to the directories containing the downloaded and preprocessed data.
98 | - Modify `logger` to switch to another logger, or simply set it to `False` to disable logging.
99 | ```yaml
100 | data:
101 | json_dir: /path/to/json_dir
102 | spk2info: /path/to/spk2info.dict
103 |
104 | val_data:
105 | json_dir: /path/to/json_dir
106 | phn_dir: /path/to/phoneme/alignments/dir
107 |
108 | trainer:
109 | logger: wandb # specify a pytorch-lightning logger you prefer
110 | ```
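For example, logging can be disabled entirely by setting:
```yaml
trainer:
  logger: False
```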
111 |
112 | ### 4. Training
113 | See [`script/train.sh`](https://github.com/vectominist/spin/blob/main/script/train.sh).
114 | - `exp_dir`: the directory to save checkpoints
115 | - `exp_name`: experiment name
116 | - See [`src/task/train_spin.py`](https://github.com/vectominist/spin/blob/main/src/task/train_spin.py) for details on the available arguments, such as the number of GPUs to use.
117 | ```bash
118 | bash script/train.sh ${exp_dir} ${exp_name}
119 | ```
120 | The trained model checkpoints can be found in `${exp_dir}/${exp_name}`. Note that we use `last.ckpt` for evaluation and downstream tasks.
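To sanity-check a trained checkpoint outside of s3prl, it can also be loaded directly with PyTorch Lightning. The following is a minimal sketch mirroring [`s3prl_py/spin/expert.py`](https://github.com/vectominist/spin/blob/main/s3prl_py/spin/expert.py); both paths are placeholders:
```python
import sys

sys.path.append("/path/to/spin")  # placeholder: absolute path to this repository
from src.model import SpinModel

# Load the fine-tuned Spin model from a Lightning checkpoint
model = SpinModel.load_from_checkpoint("/path/to/exp_dir/exp_name/last.ckpt", strict=False)
model.eval()
```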
121 |
122 | ### 5. Downstream Evaluation
123 | We use the [s3prl](https://github.com/s3prl/s3prl) toolkit for [SUPERB](https://arxiv.org/abs/2105.01051) downstream tasks.
124 | - Modify [line 26](https://github.com/vectominist/spin/blob/main/s3prl_py/spin/expert.py#L26) of [`s3prl_py/spin/expert.py`](https://github.com/vectominist/spin/blob/main/s3prl_py/spin/expert.py) so that it points to the absolute path of `spin/`.
125 | - Copy the `s3prl_py/spin` directory to `s3prl` so that the toolkit can load the models.
126 | ```bash
127 | cp -R s3prl_py/spin ../s3prl/s3prl/upstream/spin
128 | ```
129 | - Finally, add the following line to `../s3prl/s3prl/hub.py`:
130 | ```python
131 | from s3prl.upstream.spin.hubconf import *
132 | ```
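After these steps, the Spin upstream can be loaded through the s3prl hub. Below is a minimal sanity-check sketch (entry names follow [`s3prl_py/spin/hubconf.py`](https://github.com/vectominist/spin/blob/main/s3prl_py/spin/hubconf.py); the checkpoint path is a placeholder):
```python
import torch
import s3prl.hub as hub

# Load a locally trained checkpoint through the generic entry, or use a named
# entry (e.g. spin_hubert_256) to download a released checkpoint automatically.
model = getattr(hub, "spin_custom")(ckpt="/path/to/exp_dir/exp_name/last.ckpt")
# model = getattr(hub, "spin_hubert_256")()

wavs = [torch.randn(16000), torch.randn(32000)]  # a batch of 16 kHz waveforms
with torch.no_grad():
    outputs = model(wavs)
print(outputs["last_hidden_state"].shape)  # (batch, frames, dim)
```
The same upstream entries can then be used with s3prl's `run_downstream.py` recipes for the SUPERB tasks (see the s3prl documentation for the exact downstream commands).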
133 |
134 |
135 | ## Pre-trained Models
136 | All models are trained on a single NVIDIA A5000 GPU with 24 GB of VRAM. To reproduce similar or better performance, we suggest using a GPU with more than 24 GB of memory, or specifying `strategy: ddp` under `trainer` in [`config/spin.yaml`](https://github.com/vectominist/spin/blob/main/config/spin.yaml) to enable multi-GPU training. Note that the following checkpoints were reproduced with the same recipe, so the results differ slightly from those in our paper. The training logs are available at this [link](https://api.wandb.ai/links/vectominist/5254la3b).
137 |
138 | | Base Model | Clusters | PNMI | Checkpoint |
139 | | ---------- | -------- | ----- | ------------------------------------------------------------------------------------------------ |
140 | | HuBERT | 128 | 0.625 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_128.ckpt) |
141 | | HuBERT | 256 | 0.658 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_256.ckpt) |
142 | | HuBERT | 512 | 0.707 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_512.ckpt) |
143 | | HuBERT | 1024 | 0.745 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_1024.ckpt) |
144 | | HuBERT | 2048 | 0.774 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_2048.ckpt) |
145 | | WavLM | 128 | 0.604 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_128.ckpt) |
146 | | WavLM | 256 | 0.658 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_256.ckpt) |
147 | | WavLM | 512 | 0.714 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_512.ckpt) |
148 | | WavLM | 1024 | 0.748 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_1024.ckpt) |
149 | | WavLM | 2048 | 0.775 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_2048.ckpt) |
150 |
151 |
152 | ## References
153 | - [s3prl](https://github.com/s3prl/s3prl)
154 | - [contentvec](https://github.com/auspicious3000/contentvec)
155 | - [fairseq](https://github.com/facebookresearch/fairseq)
156 | - [MiniASR](https://github.com/vectominist/MiniASR)
157 | - [PyTorch](https://pytorch.org/)
158 | - [PyTorch Lightning](https://lightning.ai/docs/pytorch/latest/)
159 |
160 |
161 | ## Contact
162 | If you have any questions, please open an issue or send me an email at hengjui@mit.edu.
163 |
--------------------------------------------------------------------------------
/ci/format.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # ref: https://github.com/s3prl/s3prl/blob/main/ci/format.py
3 |
4 | import argparse
5 | from subprocess import CalledProcessError, check_output
6 |
7 |
8 | def get_third_party():
9 | package_list = []
10 | with open("./requirements.txt", "r") as fp:
11 | for line in fp:
12 | line = line.strip()
13 | if line == "":
14 | continue
15 | package_list.append(line.split(" ")[0])
16 | return package_list
17 |
18 |
19 | def run_command(command: str):
20 | try:
21 | check_output(command.split(" "))
22 | except CalledProcessError as e:
23 | print(e.output.decode("utf-8"))
24 | raise
25 |
26 |
27 | def main():
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument(
30 | "files",
31 | type=str,
32 | nargs="*",
33 | default=["src"],
34 | )
35 | parser.add_argument("--check", action="store_true", help="Only checks the files")
36 | args = parser.parse_args()
37 |
38 | print(f"Formatting files: {args.files}")
39 | args.files = " ".join(args.files)
40 |
41 | print("Run flake8")
42 | # stop the build if there are Python syntax errors or undefined names
43 | run_command(
44 | f"flake8 {args.files} --count --select=E9,F63,F7,F82 --show-source --statistics"
45 | )
46 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
47 | run_command(
48 | f"flake8 {args.files} --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics"
49 | )
50 |
51 | print("Run black")
52 | if args.check:
53 | run_command(f"black --check {args.files}")
54 | else:
55 | run_command(f"black {args.files}")
56 |
57 | print("Run isort")
58 | third_party = get_third_party()
59 | third_party = ",".join(third_party)
60 | if args.check:
61 | run_command(
62 | f"isort --profile black --thirdparty {third_party} --check {args.files}"
63 | )
64 | else:
65 | run_command(f"isort --profile black --thirdparty {third_party} {args.files}")
66 |
67 | if args.check:
68 | print("Successfully passed the format check!")
69 |
70 |
71 | if __name__ == "__main__":
72 | main()
73 |
--------------------------------------------------------------------------------
/config/spin.yaml:
--------------------------------------------------------------------------------
1 | # Interspeech 2023 version
2 |
3 | # Training data
4 | data:
5 | json_dir: /data/sls/r/u/hengjui/home/scratch/spin_test/data
6 | splits:
7 | - train-clean-100
8 | sample_rate: 16000
9 | min_audio_len: 40000 # minimum audio samples per utterance
10 | random_crop_len: 272000 # maximum audio samples per utterance
11 | spk2info: /data/sls/r/u/hengjui/home/scratch/dataset/libri_util/spk2info.dict
12 |
13 | # Validation data (not used for checkpointing, just for monitoring training progress)
14 | val_data:
15 | json_dir: /data/sls/r/u/hengjui/home/scratch/spin_test/data
16 | phn_dir: /data/sls/r/u/hengjui/home/scratch/spin_test/data
17 | splits:
18 | - dev-clean
19 | - dev-other
20 | sample_rate: 16000
21 |
22 | # SpinModel config
23 | model:
24 | encoder:
25 | type: HuBERT # `HuBERT` / `WavLM`
26 |     use_layer: 12              # the layer whose representations are used for clustering
27 | normalize: False
28 | feat_select: x
29 | randomize_all: False
30 | randomize_layers: []
31 | freeze_all: False
32 | freeze_layers: ["pos", 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # `pos`: positional encoding, `0`: CNN extractor
33 | pred_head:
34 | type: DNN
35 | hid_dims: [256]
36 | dropout: 0
37 | activation: ReLU
38 | loss:
39 | type: SwavVQDisentangle
40 | num_vars: 256 # cluster size
41 | epsilon: 0.02
42 | sinkhorn_iters: 3
43 | temp: 0.1
44 | l2_norm: True
45 | prob_ratio: 1.0
46 |
47 | # Optimization
48 | optim:
49 | optimizer:
50 | name: Adam
51 | args:
52 | lr: 1.e-4
53 | weight_decay: 1.e-6
54 | scheduler:
55 | name: linear_warmup_decay # `linear_warmup_decay` / `linear_warmup_cosine_scheduler` / `noam_scheduler`
56 | args:
57 | warmup: 2500
58 | max_step: 5000
59 | final_lr: 1.e-6
60 |
61 | hparam:
62 | batch_len: 4096000 # audio samples per GPU (256 secs ~ batch_size = 12.8k)
63 | val_batch_size: 8
64 |
65 | # pytorch_lightning.Trainer
66 | # ref: https://lightning.ai/docs/pytorch/latest/common/trainer.html
67 | trainer:
68 | max_steps: 5000
69 | gradient_clip_val: 10
70 | accumulate_grad_batches: 1
71 | precision: 16
72 | logger: wandb # use `False` to disable logging
73 | log_every_n_steps: 100
74 | default_root_dir: exp/tmp
75 | accelerator: gpu
76 | # strategy: ddp # uncomment this line to enable DDP training
77 | num_sanity_val_steps: 0
78 | val_check_interval: 1000
79 |
80 | # pytorch_lightning.callbacks.ModelCheckpoint
81 | # ref: https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.ModelCheckpoint.html
82 | checkpoint:
83 | filename: "{epoch}-{step}"
84 | every_n_train_steps: 5000
85 | save_last: true
86 |
87 | # pytorch_lightning.loggers.WandbLogger
88 | # ref: https://lightning.ai/docs/pytorch/latest/extensions/generated/lightning.pytorch.loggers.WandbLogger.html
89 | logger:
90 | project: spin_is2023
91 |
--------------------------------------------------------------------------------
/figure/spin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vectominist/spin/73edb7ae120be67a2d584931eeb388caecda744e/figure/spin.png
--------------------------------------------------------------------------------
/prepare_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import sys
5 |
6 | from src.data.librispeech import find_all_librispeech, save_data_info
7 |
8 | logging.basicConfig(
9 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
10 | datefmt="%Y-%m-%d %H:%M:%S",
11 | level=os.environ.get("LOGLEVEL", "INFO").upper(),
12 | stream=sys.stdout,
13 | )
14 | logger = logging.getLogger("prepare_data")
15 |
16 |
17 | def main():
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("root", type=str, help="Root directory of LibriSpeech")
20 | parser.add_argument("json_dir", type=str, help="Directory to save .json files")
21 | parser.add_argument(
22 | "--split",
23 | "-s",
24 | type=str,
25 | nargs="+",
26 | default=[
27 | "train-clean-100",
28 | "train-clean-360",
29 | "train-other-500",
30 | "dev-clean",
31 | "dev-other",
32 | "test-clean",
33 | "test-other",
34 | ],
35 | help="LibriSpeech partitions to be processed",
36 | )
37 | parser.add_argument(
38 | "--sort-by-len", "-l", action="store_true", help="Sort audio files by length"
39 | )
40 | args = parser.parse_args()
41 |
42 | logger.info(f"Preparing data from LibriSpeech at {args.root}")
43 | logger.info(f"Splits: {args.split}")
44 | logger.info(f"Sort audio files by length = {args.sort_by_len}")
45 | for s in args.split:
46 | logger.info(f"Processing {s} split...")
47 | data = find_all_librispeech(os.path.join(args.root, s), args.sort_by_len)
48 | save_data_info(data, os.path.join(args.json_dir, s + ".json"))
49 |
50 |
51 | if __name__ == "__main__":
52 | main()
53 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | praat-parselmouth==0.4.3
2 |
--------------------------------------------------------------------------------
/run_task.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from src import task
4 |
5 |
6 | def main():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("task")
9 | args, _ = parser.parse_known_args()
10 |
11 | runner = getattr(task, args.task)()
12 | runner.run()
13 |
14 |
15 | if __name__ == "__main__":
16 | main()
17 |
--------------------------------------------------------------------------------
/s3prl_py/WavLM.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # Based on fairseq code bases
7 | # https://github.com/pytorch/fairseq
8 | # --------------------------------------------------------
9 |
10 | import logging
11 | import math
12 | from typing import List, Optional, Tuple
13 |
14 | import numpy as np
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | from torch.nn import LayerNorm
19 |
20 | from .modules import (
21 | Fp32GroupNorm,
22 | Fp32LayerNorm,
23 | GLU_Linear,
24 | GradMultiply,
25 | MultiheadAttention,
26 | SamePad,
27 | TransposeLast,
28 | get_activation_fn,
29 | init_bert_params,
30 | )
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 |
35 | def compute_mask_indices(
36 | shape: Tuple[int, int],
37 | padding_mask: Optional[torch.Tensor],
38 | mask_prob: float,
39 | mask_length: int,
40 | mask_type: str = "static",
41 | mask_other: float = 0.0,
42 | min_masks: int = 0,
43 | no_overlap: bool = False,
44 | min_space: int = 0,
45 | ) -> np.ndarray:
46 | """
47 | Computes random mask spans for a given shape
48 |
49 | Args:
50 |         shape: the shape for which to compute masks.
51 | should be of size 2 where first element is batch size and 2nd is timesteps
52 | padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
53 | mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
54 | number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
55 | however due to overlaps, the actual number will be smaller (unless no_overlap is True)
56 | mask_type: how to compute mask lengths
57 | static = fixed size
58 | uniform = sample from uniform distribution [mask_other, mask_length*2]
59 | normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
60 |             poisson = sample from Poisson distribution with lambda = mask length
61 | min_masks: minimum number of masked spans
62 |         no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
63 | min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
64 | """
65 |
66 | bsz, all_sz = shape
67 | mask = np.full((bsz, all_sz), False)
68 |
69 | all_num_mask = int(
70 | # add a random number for probabilistic rounding
71 | mask_prob * all_sz / float(mask_length)
72 | + np.random.rand()
73 | )
74 |
75 | all_num_mask = max(min_masks, all_num_mask)
76 |
77 | mask_idcs = []
78 | for i in range(bsz):
79 | if padding_mask is not None:
80 | sz = all_sz - padding_mask[i].long().sum().item()
81 | num_mask = int(
82 | # add a random number for probabilistic rounding
83 | mask_prob * sz / float(mask_length)
84 | + np.random.rand()
85 | )
86 | num_mask = max(min_masks, num_mask)
87 | else:
88 | sz = all_sz
89 | num_mask = all_num_mask
90 |
91 | if mask_type == "static":
92 | lengths = np.full(num_mask, mask_length)
93 | elif mask_type == "uniform":
94 | lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
95 | elif mask_type == "normal":
96 | lengths = np.random.normal(mask_length, mask_other, size=num_mask)
97 | lengths = [max(1, int(round(x))) for x in lengths]
98 | elif mask_type == "poisson":
99 | lengths = np.random.poisson(mask_length, size=num_mask)
100 | lengths = [int(round(x)) for x in lengths]
101 | else:
102 | raise Exception("unknown mask selection " + mask_type)
103 |
104 | if sum(lengths) == 0:
105 | lengths[0] = min(mask_length, sz - 1)
106 |
107 | if no_overlap:
108 | mask_idc = []
109 |
110 | def arrange(s, e, length, keep_length):
111 | span_start = np.random.randint(s, e - length)
112 | mask_idc.extend(span_start + i for i in range(length))
113 |
114 | new_parts = []
115 | if span_start - s - min_space >= keep_length:
116 | new_parts.append((s, span_start - min_space + 1))
117 | if e - span_start - keep_length - min_space > keep_length:
118 | new_parts.append((span_start + length + min_space, e))
119 | return new_parts
120 |
121 | parts = [(0, sz)]
122 | min_length = min(lengths)
123 | for length in sorted(lengths, reverse=True):
124 | lens = np.fromiter(
125 | (e - s if e - s >= length + min_space else 0 for s, e in parts),
126 |                     int,  # the np.int alias was removed in recent NumPy; use the builtin int
127 | )
128 | l_sum = np.sum(lens)
129 | if l_sum == 0:
130 | break
131 | probs = lens / np.sum(lens)
132 | c = np.random.choice(len(parts), p=probs)
133 | s, e = parts.pop(c)
134 | parts.extend(arrange(s, e, length, min_length))
135 | mask_idc = np.asarray(mask_idc)
136 | else:
137 | min_len = min(lengths)
138 | if sz - min_len <= num_mask:
139 | min_len = sz - num_mask - 1
140 |
141 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
142 |
143 | mask_idc = np.asarray(
144 | [
145 | mask_idc[j] + offset
146 | for j in range(len(mask_idc))
147 | for offset in range(lengths[j])
148 | ]
149 | )
150 |
151 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
152 |
153 | min_len = min([len(m) for m in mask_idcs])
154 | for i, mask_idc in enumerate(mask_idcs):
155 | if len(mask_idc) > min_len:
156 | mask_idc = np.random.choice(mask_idc, min_len, replace=False)
157 | mask[i, mask_idc] = True
158 |
159 | return mask
160 |
161 |
162 | class WavLMConfig:
163 | def __init__(self, cfg=None):
164 | self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
165 | self.encoder_layers: int = 12 # num encoder layers in the transformer
166 |
167 | self.encoder_embed_dim: int = 768 # encoder embedding dimension
168 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
169 | self.encoder_attention_heads: int = 12 # num encoder attention heads
170 | self.activation_fn: str = "gelu" # activation function to use
171 |
172 | self.layer_norm_first: bool = False # apply layernorm first in the transformer
173 | self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
174 | self.conv_bias: bool = False # include bias in conv encoder
175 | self.feature_grad_mult: float = (
176 | 1.0 # multiply feature extractor var grads by this
177 | )
178 |
179 | self.normalize: bool = (
180 | False # normalize input to have 0 mean and unit variance during training
181 | )
182 |
183 | # dropouts
184 | self.dropout: float = 0.1 # dropout probability for the transformer
185 | self.attention_dropout: float = 0.1 # dropout probability for attention weights
186 | self.activation_dropout: float = (
187 | 0.0 # dropout probability after activation in FFN
188 | )
189 | self.encoder_layerdrop: float = (
190 |             0.0  # probability of dropping a transformer layer
191 | )
192 | self.dropout_input: float = (
193 | 0.0 # dropout to apply to the input (after feat extr)
194 | )
195 | self.dropout_features: float = (
196 | 0.0 # dropout to apply to the features (after feat extr)
197 | )
198 |
199 | # masking
200 | self.mask_length: int = 10 # mask length
201 | self.mask_prob: float = 0.65 # probability of replacing a token with mask
202 | self.mask_selection: str = "static" # how to choose mask length
203 |         self.mask_other: float = 0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
204 | self.no_mask_overlap: bool = False # whether to allow masks to overlap
205 | self.mask_min_space: int = (
206 | 1 # min space between spans (if no overlap is enabled)
207 | )
208 |
209 | # channel masking
210 | self.mask_channel_length: int = 10 # length of the mask for features (channels)
211 | self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
212 | self.mask_channel_selection: str = (
213 | "static" # how to choose mask length for channel masking
214 | )
215 | self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
216 | self.no_mask_channel_overlap: bool = (
217 | False # whether to allow channel masks to overlap
218 | )
219 | self.mask_channel_min_space: int = (
220 | 1 # min space between spans (if no overlap is enabled)
221 | )
222 |
223 | # positional embeddings
224 | self.conv_pos: int = (
225 | 128 # number of filters for convolutional positional embeddings
226 | )
227 | self.conv_pos_groups: int = (
228 | 16 # number of groups for convolutional positional embedding
229 | )
230 |
231 | # relative position embedding
232 | self.relative_position_embedding: bool = (
233 | False # apply relative position embedding
234 | )
235 | self.num_buckets: int = 320 # number of buckets for relative position embedding
236 | self.max_distance: int = (
237 | 1280 # maximum distance for relative position embedding
238 | )
239 | self.gru_rel_pos: bool = False # apply gated relative position embedding
240 |
241 | if cfg is not None:
242 | self.update(cfg)
243 |
244 | def update(self, cfg: dict):
245 | self.__dict__.update(cfg)
246 |
247 |
248 | class WavLM(nn.Module):
249 | def __init__(
250 | self,
251 | cfg: WavLMConfig,
252 | ) -> None:
253 | super().__init__()
254 | logger.info(f"WavLM Config: {cfg.__dict__}")
255 |
256 | self.cfg = cfg
257 | feature_enc_layers = eval(cfg.conv_feature_layers)
258 | self.embed = feature_enc_layers[-1][0]
259 |
260 | self.feature_extractor = ConvFeatureExtractionModel(
261 | conv_layers=feature_enc_layers,
262 | dropout=0.0,
263 | mode=cfg.extractor_mode,
264 | conv_bias=cfg.conv_bias,
265 | )
266 |
267 | self.post_extract_proj = (
268 | nn.Linear(self.embed, cfg.encoder_embed_dim)
269 | if self.embed != cfg.encoder_embed_dim
270 | else None
271 | )
272 |
273 | self.mask_prob = cfg.mask_prob
274 | self.mask_selection = cfg.mask_selection
275 | self.mask_other = cfg.mask_other
276 | self.mask_length = cfg.mask_length
277 | self.no_mask_overlap = cfg.no_mask_overlap
278 | self.mask_min_space = cfg.mask_min_space
279 |
280 | self.mask_channel_prob = cfg.mask_channel_prob
281 | self.mask_channel_selection = cfg.mask_channel_selection
282 | self.mask_channel_other = cfg.mask_channel_other
283 | self.mask_channel_length = cfg.mask_channel_length
284 | self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
285 | self.mask_channel_min_space = cfg.mask_channel_min_space
286 |
287 | self.dropout_input = nn.Dropout(cfg.dropout_input)
288 | self.dropout_features = nn.Dropout(cfg.dropout_features)
289 |
290 | self.feature_grad_mult = cfg.feature_grad_mult
291 |
292 | self.mask_emb = nn.Parameter(
293 | torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
294 | )
295 |
296 | self.encoder = TransformerEncoder(cfg)
297 | self.layer_norm = LayerNorm(self.embed)
298 |
299 | def apply_mask(self, x, padding_mask):
300 | B, T, C = x.shape
301 | if self.mask_prob > 0:
302 | mask_indices = compute_mask_indices(
303 | (B, T),
304 | padding_mask,
305 | self.mask_prob,
306 | self.mask_length,
307 | self.mask_selection,
308 | self.mask_other,
309 | min_masks=2,
310 | no_overlap=self.no_mask_overlap,
311 | min_space=self.mask_min_space,
312 | )
313 | mask_indices = torch.from_numpy(mask_indices).to(x.device)
314 | x[mask_indices] = self.mask_emb
315 | else:
316 | mask_indices = None
317 |
318 | if self.mask_channel_prob > 0:
319 | mask_channel_indices = compute_mask_indices(
320 | (B, C),
321 | None,
322 | self.mask_channel_prob,
323 | self.mask_channel_length,
324 | self.mask_channel_selection,
325 | self.mask_channel_other,
326 | no_overlap=self.no_mask_channel_overlap,
327 | min_space=self.mask_channel_min_space,
328 | )
329 | mask_channel_indices = (
330 | torch.from_numpy(mask_channel_indices)
331 | .to(x.device)
332 | .unsqueeze(1)
333 | .expand(-1, T, -1)
334 | )
335 | x[mask_channel_indices] = 0
336 |
337 | return x, mask_indices
338 |
339 | def forward_padding_mask(
340 | self,
341 | features: torch.Tensor,
342 | padding_mask: torch.Tensor,
343 | ) -> torch.Tensor:
344 | extra = padding_mask.size(1) % features.size(1)
345 | if extra > 0:
346 | padding_mask = padding_mask[:, :-extra]
347 | padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
348 | padding_mask = padding_mask.all(-1)
349 | return padding_mask
350 |
351 | def extract_features(
352 | self,
353 | source: torch.Tensor,
354 | padding_mask: Optional[torch.Tensor] = None,
355 | mask: bool = False,
356 | ret_conv: bool = False,
357 | output_layer: Optional[int] = None,
358 | ret_layer_results: bool = False,
359 | ):
360 |
361 | if self.feature_grad_mult > 0:
362 | features = self.feature_extractor(source)
363 | if self.feature_grad_mult != 1.0:
364 | features = GradMultiply.apply(features, self.feature_grad_mult)
365 | else:
366 | with torch.no_grad():
367 | features = self.feature_extractor(source)
368 |
369 | features = features.transpose(1, 2)
370 | features = self.layer_norm(features)
371 |
372 | if padding_mask is not None:
373 | padding_mask = self.forward_padding_mask(features, padding_mask)
374 |
375 | if self.post_extract_proj is not None:
376 | features = self.post_extract_proj(features)
377 |
378 | features = self.dropout_input(features)
379 |
380 | if mask:
381 | x, mask_indices = self.apply_mask(features, padding_mask)
382 | else:
383 | x = features
384 |
385 | # feature: (B, T, D), float
386 | # target: (B, T), long
387 | # x: (B, T, D), float
388 | # padding_mask: (B, T), bool
389 | # mask_indices: (B, T), bool
390 | x, layer_results = self.encoder(
391 | x,
392 | padding_mask=padding_mask,
393 | layer=None if output_layer is None else output_layer - 1,
394 | )
395 |
396 | res = {
397 | "x": x,
398 | "padding_mask": padding_mask,
399 | "features": features,
400 | "layer_results": layer_results,
401 | }
402 |
403 | feature = res["features"] if ret_conv else res["x"]
404 | if ret_layer_results:
405 | feature = (feature, res["layer_results"])
406 | return feature, res["padding_mask"]
407 |
408 |
409 | class ConvFeatureExtractionModel(nn.Module):
410 | def __init__(
411 | self,
412 | conv_layers: List[Tuple[int, int, int]],
413 | dropout: float = 0.0,
414 | mode: str = "default",
415 | conv_bias: bool = False,
416 | conv_type: str = "default",
417 | ):
418 | super().__init__()
419 |
420 | assert mode in {"default", "layer_norm"}
421 |
422 | def block(
423 | n_in,
424 | n_out,
425 | k,
426 | stride,
427 | is_layer_norm=False,
428 | is_group_norm=False,
429 | conv_bias=False,
430 | ):
431 | def make_conv():
432 | conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
433 | nn.init.kaiming_normal_(conv.weight)
434 | return conv
435 |
436 | assert (
437 | is_layer_norm and is_group_norm
438 | ) == False, "layer norm and group norm are exclusive"
439 |
440 | if is_layer_norm:
441 | return nn.Sequential(
442 | make_conv(),
443 | nn.Dropout(p=dropout),
444 | nn.Sequential(
445 | TransposeLast(),
446 | Fp32LayerNorm(dim, elementwise_affine=True),
447 | TransposeLast(),
448 | ),
449 | nn.GELU(),
450 | )
451 | elif is_group_norm:
452 | return nn.Sequential(
453 | make_conv(),
454 | nn.Dropout(p=dropout),
455 | Fp32GroupNorm(dim, dim, affine=True),
456 | nn.GELU(),
457 | )
458 | else:
459 | return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
460 |
461 | self.conv_type = conv_type
462 | if self.conv_type == "default":
463 | in_d = 1
464 | self.conv_layers = nn.ModuleList()
465 | for i, cl in enumerate(conv_layers):
466 | assert len(cl) == 3, "invalid conv definition: " + str(cl)
467 | (dim, k, stride) = cl
468 |
469 | self.conv_layers.append(
470 | block(
471 | in_d,
472 | dim,
473 | k,
474 | stride,
475 | is_layer_norm=mode == "layer_norm",
476 | is_group_norm=mode == "default" and i == 0,
477 | conv_bias=conv_bias,
478 | )
479 | )
480 | in_d = dim
481 | elif self.conv_type == "conv2d":
482 | in_d = 1
483 | self.conv_layers = nn.ModuleList()
484 | for i, cl in enumerate(conv_layers):
485 | assert len(cl) == 3
486 | (dim, k, stride) = cl
487 |
488 | self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
489 | self.conv_layers.append(torch.nn.ReLU())
490 | in_d = dim
491 | elif self.conv_type == "custom":
492 | in_d = 1
493 | idim = 80
494 | self.conv_layers = nn.ModuleList()
495 | for i, cl in enumerate(conv_layers):
496 | assert len(cl) == 3
497 | (dim, k, stride) = cl
498 | self.conv_layers.append(
499 | torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
500 | )
501 | self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
502 | self.conv_layers.append(torch.nn.ReLU())
503 | in_d = dim
504 | if (i + 1) % 2 == 0:
505 | self.conv_layers.append(
506 | torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
507 | )
508 | idim = int(math.ceil(idim / 2))
509 | else:
510 | pass
511 |
512 | def forward(self, x, mask=None):
513 |
514 | # BxT -> BxCxT
515 | x = x.unsqueeze(1)
516 | if self.conv_type == "custom":
517 | for conv in self.conv_layers:
518 | if isinstance(conv, nn.LayerNorm):
519 | x = x.transpose(1, 2)
520 | x = conv(x).transpose(1, 2)
521 | else:
522 | x = conv(x)
523 | x = x.transpose(2, 3).contiguous()
524 | x = x.view(x.size(0), -1, x.size(-1))
525 | else:
526 | for conv in self.conv_layers:
527 | x = conv(x)
528 | if self.conv_type == "conv2d":
529 | b, c, t, f = x.size()
530 | x = x.transpose(2, 3).contiguous().view(b, c * f, t)
531 | return x
532 |
533 |
534 | class TransformerEncoder(nn.Module):
535 | def __init__(self, args):
536 | super().__init__()
537 |
538 | self.dropout = args.dropout
539 | self.embedding_dim = args.encoder_embed_dim
540 |
541 | self.pos_conv = nn.Conv1d(
542 | self.embedding_dim,
543 | self.embedding_dim,
544 | kernel_size=args.conv_pos,
545 | padding=args.conv_pos // 2,
546 | groups=args.conv_pos_groups,
547 | )
548 | dropout = 0
549 | std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
550 | nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
551 | nn.init.constant_(self.pos_conv.bias, 0)
552 |
553 | self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
554 | self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
555 |
556 | if hasattr(args, "relative_position_embedding"):
557 | self.relative_position_embedding = args.relative_position_embedding
558 | self.num_buckets = args.num_buckets
559 | self.max_distance = args.max_distance
560 | else:
561 | self.relative_position_embedding = False
562 | self.num_buckets = 0
563 | self.max_distance = 0
564 |
565 | self.layers = nn.ModuleList(
566 | [
567 | TransformerSentenceEncoderLayer(
568 | embedding_dim=self.embedding_dim,
569 | ffn_embedding_dim=args.encoder_ffn_embed_dim,
570 | num_attention_heads=args.encoder_attention_heads,
571 | dropout=self.dropout,
572 | attention_dropout=args.attention_dropout,
573 | activation_dropout=args.activation_dropout,
574 | activation_fn=args.activation_fn,
575 | layer_norm_first=args.layer_norm_first,
576 | has_relative_attention_bias=(
577 | self.relative_position_embedding and i == 0
578 | ),
579 | num_buckets=self.num_buckets,
580 | max_distance=self.max_distance,
581 | gru_rel_pos=args.gru_rel_pos,
582 | )
583 | for i in range(args.encoder_layers)
584 | ]
585 | )
586 |
587 | self.layer_norm_first = args.layer_norm_first
588 | self.layer_norm = LayerNorm(self.embedding_dim)
589 | self.layerdrop = args.encoder_layerdrop
590 |
591 | self.apply(init_bert_params)
592 |
593 | def forward(
594 | self,
595 | x,
596 | padding_mask=None,
597 | streaming_mask=None,
598 | layer=None,
599 | ffn_adapters=None,
600 | freeze_pos=False,
601 | freeze_layers=None,
602 | ):
603 | x, layer_results = self.extract_features(
604 | x, padding_mask, streaming_mask, layer, ffn_adapters, freeze_pos=freeze_pos, freeze_layers=freeze_layers
605 | )
606 |
607 | if self.layer_norm_first and layer is None:
608 | x = self.layer_norm(x)
609 |
610 | return x, layer_results
611 |
612 | def extract_features(
613 | self,
614 | x,
615 | padding_mask=None,
616 | streaming_mask=None,
617 | tgt_layer=None,
618 | ffn_adapters=None,
619 | freeze_pos=False,
620 | freeze_layers=None,
621 | ):
622 |
623 | if padding_mask is not None:
624 | x[padding_mask] = 0
625 |
626 | if freeze_pos:
627 | with torch.no_grad():
628 | x_conv = self.pos_conv(x.transpose(1, 2))
629 | x_conv = x_conv.transpose(1, 2)
630 | x = x + x_conv
631 | else:
632 | x_conv = self.pos_conv(x.transpose(1, 2))
633 | x_conv = x_conv.transpose(1, 2)
634 | x = x + x_conv
635 |
636 | if not self.layer_norm_first:
637 | x = self.layer_norm(x)
638 |
639 | x = F.dropout(x, p=self.dropout, training=self.training)
640 |
641 | # B x T x C -> T x B x C
642 | x = x.transpose(0, 1)
643 |
644 | layer_results = []
645 | z = None
646 | # if tgt_layer is not None:
647 | # layer_results.append((x, z))
648 | r = None
649 | pos_bias = None
650 | for i, layer in enumerate(self.layers):
651 | dropout_probability = np.random.random()
652 | if not self.training or (dropout_probability > self.layerdrop):
653 | ffn_adapter = None
654 | if ffn_adapters is not None:
655 | ffn_adapter = ffn_adapters[i]
656 |
657 | if isinstance(freeze_layers, list) and i + 1 in freeze_layers:
658 | with torch.no_grad():
659 | x, z, pos_bias = layer(
660 | x,
661 | self_attn_padding_mask=padding_mask,
662 | need_weights=False,
663 | self_attn_mask=streaming_mask,
664 | pos_bias=pos_bias,
665 | ffn_adapter=ffn_adapter,
666 | )
667 | else:
668 | x, z, pos_bias = layer(
669 | x,
670 | self_attn_padding_mask=padding_mask,
671 | need_weights=False,
672 | self_attn_mask=streaming_mask,
673 | pos_bias=pos_bias,
674 | ffn_adapter=ffn_adapter,
675 | )
676 | # if tgt_layer is not None:
677 | # layer_results.append((x, z))
678 | layer_results.append((x, z))
679 | if i == tgt_layer:
680 | r = x
681 | break
682 |
683 | if r is not None:
684 | x = r
685 |
686 | # T x B x C -> B x T x C
687 | x = x.transpose(0, 1)
688 |
689 | return x, layer_results
690 |
691 |
692 | class TransformerSentenceEncoderLayer(nn.Module):
693 | """
694 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
695 | models.
696 | """
697 |
698 | def __init__(
699 | self,
700 | embedding_dim: float = 768,
701 | ffn_embedding_dim: float = 3072,
702 | num_attention_heads: float = 8,
703 | dropout: float = 0.1,
704 | attention_dropout: float = 0.1,
705 | activation_dropout: float = 0.1,
706 | activation_fn: str = "relu",
707 | layer_norm_first: bool = False,
708 | has_relative_attention_bias: bool = False,
709 | num_buckets: int = 0,
710 | max_distance: int = 0,
711 | rescale_init: bool = False,
712 | gru_rel_pos: bool = False,
713 | ) -> None:
714 |
715 | super().__init__()
716 | # Initialize parameters
717 | self.embedding_dim = embedding_dim
718 | self.dropout = dropout
719 | self.activation_dropout = activation_dropout
720 |
721 | # Initialize blocks
722 | self.activation_name = activation_fn
723 | self.activation_fn = get_activation_fn(activation_fn)
724 | self.self_attn = MultiheadAttention(
725 | self.embedding_dim,
726 | num_attention_heads,
727 | dropout=attention_dropout,
728 | self_attention=True,
729 | has_relative_attention_bias=has_relative_attention_bias,
730 | num_buckets=num_buckets,
731 | max_distance=max_distance,
732 | rescale_init=rescale_init,
733 | gru_rel_pos=gru_rel_pos,
734 | )
735 |
736 | self.dropout1 = nn.Dropout(dropout)
737 | self.dropout2 = nn.Dropout(self.activation_dropout)
738 | self.dropout3 = nn.Dropout(dropout)
739 |
740 | self.layer_norm_first = layer_norm_first
741 |
742 | # layer norm associated with the self attention layer
743 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
744 |
745 | if self.activation_name == "glu":
746 | self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
747 | else:
748 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
749 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
750 |
751 | # layer norm associated with the position wise feed-forward NN
752 | self.final_layer_norm = LayerNorm(self.embedding_dim)
753 |
754 | def forward(
755 | self,
756 | x: torch.Tensor,
757 | self_attn_mask: torch.Tensor = None,
758 | self_attn_padding_mask: torch.Tensor = None,
759 | need_weights: bool = False,
760 | pos_bias=None,
761 | ffn_adapter=None,
762 | ):
763 | """
764 | LayerNorm is applied either before or after the self-attention/ffn
765 |         modules similar to the original Transformer implementation.
766 | """
767 | residual = x
768 |
769 | if self.layer_norm_first:
770 | x = self.self_attn_layer_norm(x)
771 | x, attn, pos_bias = self.self_attn(
772 | query=x,
773 | key=x,
774 | value=x,
775 | key_padding_mask=self_attn_padding_mask,
776 | need_weights=False,
777 | attn_mask=self_attn_mask,
778 | position_bias=pos_bias,
779 | )
780 | x = self.dropout1(x)
781 | x = residual + x
782 |
783 | residual = x
784 | x = self.final_layer_norm(x)
785 | if self.activation_name == "glu":
786 | x = self.fc1(x)
787 | else:
788 | x = self.activation_fn(self.fc1(x))
789 | x = self.dropout2(x)
790 | x = self.fc2(x)
791 |
792 | if ffn_adapter is not None:
793 | x = ffn_adapter(x)
794 |
795 | x = self.dropout3(x)
796 | x = residual + x
797 | else:
798 | x, attn, pos_bias = self.self_attn(
799 | query=x,
800 | key=x,
801 | value=x,
802 | key_padding_mask=self_attn_padding_mask,
803 | need_weights=need_weights,
804 | attn_mask=self_attn_mask,
805 | position_bias=pos_bias,
806 | )
807 |
808 | x = self.dropout1(x)
809 | x = residual + x
810 |
811 | x = self.self_attn_layer_norm(x)
812 |
813 | residual = x
814 | if self.activation_name == "glu":
815 | x = self.fc1(x)
816 | else:
817 | x = self.activation_fn(self.fc1(x))
818 | x = self.dropout2(x)
819 | x = self.fc2(x)
820 |
821 | if ffn_adapter is not None:
822 | x = ffn_adapter(x)
823 |
824 | x = self.dropout3(x)
825 | x = residual + x
826 | x = self.final_layer_norm(x)
827 |
828 | return x, attn, pos_bias
829 |
--------------------------------------------------------------------------------
/s3prl_py/spin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vectominist/spin/73edb7ae120be67a2d584931eeb388caecda744e/s3prl_py/spin/__init__.py
--------------------------------------------------------------------------------
/s3prl_py/spin/expert.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pad_sequence
7 |
8 | from ..interfaces import UpstreamBase
9 |
10 | SAMPLE_RATE = 16000
11 | EXAMPLE_SEC = 5
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | class UpstreamExpert(UpstreamBase):
17 | def __init__(
18 | self,
19 | ckpt,
20 | feat_select: str = "hidden",
21 | **kwargs,
22 | ):
23 | super().__init__(**kwargs)
24 |
25 | # set directory of `spin/`
26 | sys.path.append("/path/to/spin")
27 | from src.model import SpinModel
28 |
29 | self.model = SpinModel.load_from_checkpoint(ckpt, strict=False)
30 | if self.model.encoder_type == "HuBERT":
31 | self.model.encoder.model.feature_grad_mult = 0.0
32 | self.feat_select = feat_select
33 |
34 | assert self.feat_select in {
35 | "all_feats",
36 | "hidden",
37 | "disentangled",
38 | "logits",
39 | "probs",
40 | "last",
41 | }, self.feat_select
42 |
43 | if self.feat_select in {"logits", "probs"}:
44 | assert self.model.loss_type in {"SwavVQDisentangle"}
45 | self.model.loss_module.normalize_codebook()
46 |
47 | logger.info(f"Feature selection: {self.feat_select}")
48 | logger.info(f"Loss type: {self.model.loss_type}")
49 |
50 | def get_downsample_rates(self, key: str) -> int:
51 | return self.model.encoder_rate
52 |
53 | def forward(self, wavs):
54 | device = wavs[0].device
55 | wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
56 | wav_padding_mask = ~torch.lt(
57 | torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
58 | wav_lengths.unsqueeze(1),
59 | )
60 | padded_wav = pad_sequence(wavs, batch_first=True)
61 |
62 | results = self.model(
63 | (padded_wav, wav_lengths, wav_padding_mask),
64 | feat_only=True,
65 | )
66 |
67 | outputs = {}
68 |
69 | if self.model.loss_type in {"SwavVQDisentangle"}:
70 | outputs["code"] = results["codes"]
71 | outputs["logits"] = results["logits"]
72 | outputs["probs"] = F.softmax(results["logits"], dim=-1)
73 | feat = results["repr_list"]
74 | feat_list = results["feat_list"]
75 |
76 | outputs["disentangled"] = feat
77 |
78 | if self.feat_select == "all_feats":
79 | outputs["hidden_states"] = [feat] + feat_list
80 | elif self.feat_select == "hidden":
81 | outputs["hidden_states"] = feat_list
82 | elif self.feat_select == "last":
83 | outputs["hidden_states"] = [feat_list[-1]]
84 | elif self.feat_select == "disentangled":
85 | outputs["hidden_states"] = [feat]
86 | elif self.feat_select == "logits":
87 | outputs["hidden_states"] = [results["logits"]]
88 | elif self.feat_select == "probs":
89 | outputs["hidden_states"] = [outputs["probs"]]
90 |
91 | outputs["last_hidden_state"] = outputs["hidden_states"][-1]
92 |
93 | return outputs
94 |
--------------------------------------------------------------------------------
/s3prl_py/spin/hubconf.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from s3prl.util.download import _urls_to_filepaths
4 |
5 | from .expert import UpstreamExpert as _UpstreamExpert
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def spin_custom(
11 | ckpt: str,
12 | refresh: bool = False,
13 | **kwargs,
14 | ):
15 | if ckpt.startswith("http"):
16 | ckpt = _urls_to_filepaths(ckpt, refresh=refresh)
17 |
18 | return _UpstreamExpert(ckpt, **kwargs)
19 |
20 |
21 | def spin_local(*args, **kwargs):
22 | return spin_custom(*args, **kwargs)
23 |
24 |
25 | def spin_url(*args, **kwargs):
26 | return spin_custom(*args, **kwargs)
27 |
28 |
29 | def spin_hubert_128(refresh=False, **kwargs):
30 | kwargs[
31 | "ckpt"
32 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_128.ckpt"
33 | return spin_custom(refresh=refresh, **kwargs)
34 |
35 |
36 | def spin_hubert_256(refresh=False, **kwargs):
37 | kwargs[
38 | "ckpt"
39 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_256.ckpt"
40 | return spin_custom(refresh=refresh, **kwargs)
41 |
42 |
43 | def spin_hubert_512(refresh=False, **kwargs):
44 | kwargs[
45 | "ckpt"
46 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_512.ckpt"
47 | return spin_custom(refresh=refresh, **kwargs)
48 |
49 |
50 | def spin_hubert_1024(refresh=False, **kwargs):
51 | kwargs[
52 | "ckpt"
53 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_1024.ckpt"
54 | return spin_custom(refresh=refresh, **kwargs)
55 |
56 |
57 | def spin_hubert_2048(refresh=False, **kwargs):
58 | kwargs[
59 | "ckpt"
60 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_2048.ckpt"
61 | return spin_custom(refresh=refresh, **kwargs)
62 |
63 |
64 | def spin_wavlm_128(refresh=False, **kwargs):
65 | kwargs[
66 | "ckpt"
67 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_128.ckpt"
68 | return spin_custom(refresh=refresh, **kwargs)
69 |
70 |
71 | def spin_wavlm_256(refresh=False, **kwargs):
72 | kwargs[
73 | "ckpt"
74 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_256.ckpt"
75 | return spin_custom(refresh=refresh, **kwargs)
76 |
77 |
78 | def spin_wavlm_512(refresh=False, **kwargs):
79 | kwargs[
80 | "ckpt"
81 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_512.ckpt"
82 | return spin_custom(refresh=refresh, **kwargs)
83 |
84 |
85 | def spin_wavlm_1024(refresh=False, **kwargs):
86 | kwargs[
87 | "ckpt"
88 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_1024.ckpt"
89 | return spin_custom(refresh=refresh, **kwargs)
90 |
91 |
92 | def spin_wavlm_2048(refresh=False, **kwargs):
93 | kwargs[
94 | "ckpt"
95 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_2048.ckpt"
96 | return spin_custom(refresh=refresh, **kwargs)
97 |
--------------------------------------------------------------------------------
/script/prepare.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | libri_dir=$1
4 | json_dir=$2
5 |
6 | mkdir -p $json_dir
7 |
8 | python3 prepare_data.py \
9 | $libri_dir \
10 | $json_dir \
11 | --split train-clean-100 \
12 | dev-clean \
13 | dev-other \
14 | --sort-by-len
15 |
--------------------------------------------------------------------------------
/script/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | exp_dir=$1
4 | exp_name=$2
5 | config=config/spin.yaml
6 |
7 | mkdir -p $exp_dir
8 |
9 | echo "Name: $exp_name"
10 | echo "Config: $config"
11 |
12 | python3 run_task.py \
13 | SpinPretrainTask \
14 | --config $config \
15 | --save-path $exp_dir/$exp_name \
16 | --gpus 1 \
17 | --njobs 16
18 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import (
2 | AudioPretrainDataset,
3 | AudioPretrainPnmiValDataset,
4 | collate_fn,
5 | val_collate_fn,
6 | )
7 | from .sampler import MaxLengthBatchSampler, MaxLengthDistributedSampler
8 |
--------------------------------------------------------------------------------
/src/data/audio.py:
--------------------------------------------------------------------------------
1 | # Source: https://github.com/auspicious3000/contentvec/blob/main/contentvec/data/audio/audio_utils_1.py
2 | # Paper: https://arxiv.org/abs/2204.09224
3 |
4 |
5 | import numpy as np
6 | import parselmouth
7 |
8 |
9 | def make_lowshelf(g, fc, Q, fs=44100):
10 | """Generate filter coefficients for 2nd order Lowshelf filter.
11 | This function follows the code from the JUCE DSP library
12 | which can be found in `juce_IIRFilter.cpp`.
13 |
14 | The design equations are based upon those found in the Cookbook
15 | formulae for audio equalizer biquad filter coefficients
16 | by Robert Bristow-Johnson.
17 | https://www.w3.org/2011/audio/audio-eq-cookbook.html
18 | Args:
19 | g (float): Gain factor in dB.
20 | fc (float): Cutoff frequency in Hz.
21 | Q (float): Q factor.
22 | fs (float): Sampling frequency in Hz.
23 | Returns:
24 |         ndarray: second-order section (SOS) coefficients with shape (1, 6).
25 | """
26 | # convert gain from dB to linear
27 | g = np.power(10, (g / 20))
28 |
29 | # initial values
30 | A = np.max([0.0, np.sqrt(g)])
31 | aminus1 = A - 1
32 | aplus1 = A + 1
33 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs
34 | coso = np.cos(omega)
35 | beta = np.sin(omega) * np.sqrt(A) / Q
36 | aminus1TimesCoso = aminus1 * coso
37 |
38 | # coefs calculation
39 | b0 = A * (aplus1 - aminus1TimesCoso + beta)
40 | b1 = A * 2 * (aminus1 - aplus1 * coso)
41 | b2 = A * (aplus1 - aminus1TimesCoso - beta)
42 | a0 = aplus1 + aminus1TimesCoso + beta
43 | a1 = -2 * (aminus1 + aplus1 * coso)
44 | a2 = aplus1 + aminus1TimesCoso - beta
45 |
46 | # output coefs
47 | # b = np.array([b0/a0, b1/a0, b2/a0])
48 | # a = np.array([a0/a0, a1/a0, a2/a0])
49 |
50 | return np.array([[b0 / a0, b1 / a0, b2 / a0, 1.0, a1 / a0, a2 / a0]])
51 |
52 |
53 | def make_highself(g, fc, Q, fs=44100):
54 | """Generate filter coefficients for 2nd order Highshelf filter.
55 | This function follows the code from the JUCE DSP library
56 | which can be found in `juce_IIRFilter.cpp`.
57 |
58 | The design equations are based upon those found in the Cookbook
59 | formulae for audio equalizer biquad filter coefficients
60 | by Robert Bristow-Johnson.
61 | https://www.w3.org/2011/audio/audio-eq-cookbook.html
62 | Args:
63 | g (float): Gain factor in dB.
64 | fc (float): Cutoff frequency in Hz.
65 | Q (float): Q factor.
66 | fs (float): Sampling frequency in Hz.
67 | Returns:
 68 |         ndarray: second order section coefficients, shape (1, 6).
69 | """
70 | # convert gain from dB to linear
71 | g = np.power(10, (g / 20))
72 |
73 | # initial values
74 | A = np.max([0.0, np.sqrt(g)])
75 | aminus1 = A - 1
76 | aplus1 = A + 1
77 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs
78 | coso = np.cos(omega)
79 | beta = np.sin(omega) * np.sqrt(A) / Q
80 | aminus1TimesCoso = aminus1 * coso
81 |
82 | # coefs calculation
83 | b0 = A * (aplus1 + aminus1TimesCoso + beta)
84 | b1 = A * -2 * (aminus1 + aplus1 * coso)
85 | b2 = A * (aplus1 + aminus1TimesCoso - beta)
86 | a0 = aplus1 - aminus1TimesCoso + beta
87 | a1 = 2 * (aminus1 - aplus1 * coso)
88 | a2 = aplus1 - aminus1TimesCoso - beta
89 |
90 | # output coefs
91 | # b = np.array([b0/a0, b1/a0, b2/a0])
92 | # a = np.array([a0/a0, a1/a0, a2/a0])
93 |
94 | return np.array([[b0 / a0, b1 / a0, b2 / a0, 1.0, a1 / a0, a2 / a0]])
95 |
96 |
97 | def make_peaking(g, fc, Q, fs=44100):
98 | """Generate filter coefficients for 2nd order Peaking EQ.
99 | This function follows the code from the JUCE DSP library
100 | which can be found in `juce_IIRFilter.cpp`.
101 |
102 | The design equations are based upon those found in the Cookbook
103 | formulae for audio equalizer biquad filter coefficients
104 | by Robert Bristow-Johnson.
105 | https://www.w3.org/2011/audio/audio-eq-cookbook.html
106 | Args:
107 | g (float): Gain factor in dB.
108 | fc (float): Cutoff frequency in Hz.
109 | Q (float): Q factor.
110 | fs (float): Sampling frequency in Hz.
111 | Returns:
112 |         ndarray: second order section coefficients, shape (1, 6).
113 | """
114 | # convert gain from dB to linear
115 | g = np.power(10, (g / 20))
116 |
117 | # initial values
118 | A = np.max([0.0, np.sqrt(g)])
119 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs
120 | alpha = np.sin(omega) / (Q * 2)
121 | c2 = -2 * np.cos(omega)
122 | alphaTimesA = alpha * A
123 | alphaOverA = alpha / A
124 |
125 | # coefs calculation
126 | b0 = 1 + alphaTimesA
127 | b1 = c2
128 | b2 = 1 - alphaTimesA
129 | a0 = 1 + alphaOverA
130 | a1 = c2
131 | a2 = 1 - alphaOverA
132 |
133 | # output coefs
134 | # b = np.array([b0/a0, b1/a0, b2/a0])
135 | # a = np.array([a0/a0, a1/a0, a2/a0])
136 |
137 | return np.array([[b0 / a0, b1 / a0, b2 / a0, 1.0, a1 / a0, a2 / a0]])
138 |
139 |
140 | def params2sos(G, Fc, Q, fs):
141 |     """Convert 10-band EQ parameters to 2nd order sections.
142 |     Takes the gain, center frequency, and Q factor of each of the ten
143 |     filters and calculates the filter coefficients for each one.
144 |     These coefficients (2nd order sections) are then stored into a
145 |     single (10, 6) matrix. This matrix can be fed to `scipy.signal.sosfilt()`
146 |     to apply the cascade of all ten biquad filters, or to
147 |     `scipy.signal.sosfreqz()` to inspect its frequency response.
148 |     Args:
149 |         G (ndarray): Gain factors in dB, shape (10,).
150 |         Fc (ndarray): Center/cutoff frequencies in Hz, shape (10,).
151 |         Q (ndarray): Q factors, shape (10,).
152 |         fs (float): Sampling frequency in Hz.
153 |     Returns:
154 |         ndarray: filter coefficients for the 10-band EQ stored in a (10, 6) matrix.
155 |         [[b1_0, b1_1, b1_2, a1_0, a1_1, a1_2],        # lowshelf coefficients
156 |          ...                                           # eight peaking band coefficients
157 |          [b10_0, b10_1, b10_2, a10_0, a10_1, a10_2]]  # highshelf coefficients
158 |     """
159 | # generate filter coefficients from eq params
160 | c0 = make_lowshelf(G[0], Fc[0], Q[0], fs=fs)
161 | c1 = make_peaking(G[1], Fc[1], Q[1], fs=fs)
162 | c2 = make_peaking(G[2], Fc[2], Q[2], fs=fs)
163 | c3 = make_peaking(G[3], Fc[3], Q[3], fs=fs)
164 | c4 = make_peaking(G[4], Fc[4], Q[4], fs=fs)
165 | c5 = make_peaking(G[5], Fc[5], Q[5], fs=fs)
166 | c6 = make_peaking(G[6], Fc[6], Q[6], fs=fs)
167 | c7 = make_peaking(G[7], Fc[7], Q[7], fs=fs)
168 | c8 = make_peaking(G[8], Fc[8], Q[8], fs=fs)
169 | c9 = make_highself(G[9], Fc[9], Q[9], fs=fs)
170 |
171 | # stuff coefficients into second order sections structure
172 | sos = np.concatenate([c0, c1, c2, c3, c4, c5, c6, c7, c8, c9], axis=0)
173 |
174 | return sos
175 |
176 |
177 | def change_gender(x, fs, lo, hi, ratio_fs, ratio_ps, ratio_pr):
178 | s = parselmouth.Sound(x, sampling_frequency=fs)
179 | f0 = s.to_pitch_ac(pitch_floor=lo, pitch_ceiling=hi, time_step=0.8 / lo)
180 | f0_np = f0.selected_array["frequency"]
181 | f0_med = np.median(f0_np[f0_np != 0]).item()
182 | ss = parselmouth.praat.call(
183 | [s, f0], "Change gender", ratio_fs, f0_med * ratio_ps, ratio_pr, 1.0
184 | )
185 | return ss.values.squeeze(0)
186 |
187 |
188 | def change_gender_f0(x, fs, lo, hi, ratio_fs, new_f0_med, ratio_pr):
189 | s = parselmouth.Sound(x, sampling_frequency=fs)
190 | ss = parselmouth.praat.call(
191 | s, "Change gender", lo, hi, ratio_fs, new_f0_med, ratio_pr, 1.0
192 | )
193 | return ss.values.squeeze(0)
194 |
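195 | # A minimal usage sketch (not part of the original file; the sampling rate and
196 | # parameter ranges below are illustrative assumptions): draw random gains and Q
197 | # factors, build the 10-band second order sections, and filter a waveform.
198 | #
199 | #   from scipy.signal import sosfilt
200 | #   rng = np.random.default_rng()
201 | #   Fc = np.exp(np.linspace(np.log(60), np.log(7600), 10))  # 10 center frequencies
202 | #   G = rng.uniform(-12, 12, size=(10,))                    # gains in dB
203 | #   Q = 2 * (5 / 2) ** rng.uniform(0, 1, size=(10,))        # Q factors in [2, 5]
204 | #   sos = params2sos(G, Fc, Q, fs=16_000)
205 | #   wav_eq = sosfilt(sos, wav)  # wav: 1-D float waveform array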
--------------------------------------------------------------------------------
/src/data/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import pickle
5 | import random
6 | import re
7 | from pathlib import Path
8 | from typing import List, Tuple
9 |
10 | import numpy as np
11 | import soundfile as sf
12 | import torch
13 | import torch.nn.functional as F
14 | from scipy.signal import sosfilt
15 | from torch.nn.utils.rnn import pad_sequence
16 | from torch.utils.data import Dataset
17 |
18 | from src.util import len_to_padding
19 |
20 | from .audio import change_gender, change_gender_f0, params2sos
21 |
22 | logger = logging.getLogger("dataset")
23 |
24 | Qmin, Qmax = 2, 5
25 |
26 |
27 | def read_phn(tsv_path, rm_stress=True):
28 | uid2phns = {}
29 | with open(tsv_path) as f:
30 | for line in f:
31 | uid, phns = line.rstrip().split("\t")
32 | phns = phns.split(",")
33 | if rm_stress:
34 | phns = [re.sub("[0-9]", "", phn) for phn in phns]
35 | uid2phns[uid] = phns
36 | return uid2phns
37 |
38 |
39 | class AudioPretrainDataset(Dataset):
40 | def __init__(
41 | self,
42 | json_dir: str,
43 | splits: List[str],
44 | spk2info: str = None,
45 | min_audio_len: int = 1,
46 | max_audio_len: int = 1_000_000,
47 | sample_rate: int = 16_000,
48 | random_crop_len: int = -1,
49 | use_ratio: float = 1.0,
50 | normalize: bool = False,
51 | nansy_both: bool = False,
52 | ) -> None:
53 | super().__init__()
54 |
55 | logger.info(f"Loading audio from: {json_dir}")
56 | logger.info(f"Splits : {splits}")
57 | self.split_type = "train" if splits[0].startswith("train") else "valid"
58 | self.return_augmented = True
59 |
60 | # Data augmentation config
61 | self.random_crop_len = random_crop_len
62 | self.random_crop_len_lab = random_crop_len // 320
63 | self.use_ratio = use_ratio
64 | self.normalize = normalize
65 | self.sample_rate = sample_rate
66 | self.nansy_aug = spk2info is not None
67 |
68 | if self.nansy_aug:
69 | # Load speaker information
70 | logger.info("Apply NANSY speaker perturbation")
71 | # ref: https://arxiv.org/abs/2110.14513
72 | self.num_view = 2
73 | self.nansy_both = nansy_both
74 |
75 | with open(spk2info, "rb") as fp:
76 | self.spk2info = pickle.load(fp)
77 | self.spk2info = self.spk2info[self.split_type]
78 |
79 | self.rng = np.random.default_rng()
80 | self.Fc = np.exp(np.linspace(np.log(60), np.log(7600), 10))
81 |
82 | self.num_view = 2 if self.nansy_aug else 1
83 |
84 | # Load from .json files
85 | data_list = []
86 | for s in splits:
87 | with open(os.path.join(json_dir, s + ".json"), "r") as fp:
88 | data_list += json.load(fp)
89 |
90 | # Preserve certain ratio of data
91 | if self.use_ratio < 1.0:
92 | logger.info(
93 | f"Using only {self.use_ratio * 100:.0f}% of randomly chosen data"
94 | )
95 | random.shuffle(data_list)
96 | data_list = data_list[: int(len(data_list) * self.use_ratio)]
97 |
98 | # Remove files that are too long or too short
99 | logger.info(
100 | f"Removing files shorter than {min_audio_len} or longer than {max_audio_len} frames"
101 | )
102 | orig_tot_len = sum([l for _, l in data_list]) / sample_rate / 3600.0
103 | orig_num_file = len(data_list)
104 | data_list = [
105 | (p, l) for p, l in data_list if l >= min_audio_len and l <= max_audio_len
106 | ]
107 | new_tot_len = sum([l for _, l in data_list]) / sample_rate / 3600.0
108 | logger.info(f"Original audio files: {orig_num_file} ({orig_tot_len:.2f} hrs)")
109 | logger.info(f"Final audio files: {len(data_list)} ({new_tot_len:.2f} hrs)")
110 |
111 | # Sort by length (long to short)
112 | data_list = sorted(data_list, key=lambda x: x[1], reverse=True)
113 |
114 | # Extract speaker
115 | spk_list = [Path(p).stem.split("-")[0] for p, _ in data_list]
116 | name_list = [Path(p).stem for p, _ in data_list]
117 | self.data = [
118 | (p, l, s, n) for (p, l), s, n in zip(data_list, spk_list, name_list)
119 | ]
120 | self.data_lens = [l for _, l, _, _ in self.data]
121 |
122 | if self.nansy_aug:
123 | # Check speaker
124 | data_spk_set = set(spk_list)
125 | avail_spk_set = set(self.spk2info.keys())
126 | assert all(s in avail_spk_set for s in data_spk_set), "Missing speakers!"
127 | logger.info(f"Total {len(data_spk_set)} speakers")
128 |
129 | def __len__(self) -> int:
130 | return len(self.data)
131 |
132 | def __getitem__(self, index: int) -> List[torch.FloatTensor]:
133 |         # each sample: (audio file path, number of samples, speaker id, utterance id)
134 | path, num_frames, spk, uid = self.data[index]
135 |
136 | wav, sr = sf.read(path)
137 | if wav.ndim == 2:
138 | wav = wav.mean(-1)
139 |
140 | wav = wav.astype(np.float32)
141 |
142 | if not self.return_augmented:
143 | return torch.from_numpy(wav)
144 |
145 | # Randomly crop wave length
146 | if self.random_crop_len > 0 and len(wav) > self.random_crop_len:
147 | idx = np.random.randint(0, len(wav) - self.random_crop_len)
148 | wav = wav[idx : idx + self.random_crop_len]
149 |
150 | # Apply NANSY speaker perturbation
151 | if self.nansy_aug:
152 | if self.nansy_both:
153 | wavs = [
154 | self.perturb_speaker(path, wav, sr, spk)
155 | for _ in range(self.num_view)
156 | ]
157 | else:
158 | wavs = [wav] + [
159 | self.perturb_speaker(path, wav, sr, spk)
160 | for _ in range(self.num_view - 1)
161 | ]
162 | else:
163 | wavs = [wav for _ in range(self.num_view)]
164 |
165 | wavs = [torch.FloatTensor(w) for w in wavs]
166 | if self.normalize:
167 | wavs = [F.layer_norm(w, w.shape) for w in wavs]
168 |
169 | return wavs
170 |
171 | def get_spk_info(self, spk: str):
172 | _, (lo, hi, _) = self.spk2info[spk]
173 | if lo == 50:
174 | lo = 75
175 | if spk == "1447":
176 | lo, hi = 60, 400
177 | return lo, hi
178 |
179 | def random_eq(self, wav, sr):
180 | z = self.rng.uniform(0, 1, size=(10,))
181 | Q = Qmin * (Qmax / Qmin) ** z
182 | G = self.rng.uniform(-12, 12, size=(10,))
183 | sos = params2sos(G, self.Fc, Q, sr)
184 | wav = sosfilt(sos, wav)
185 | return wav
186 |
187 | def random_formant_f0(self, wav, sr, spk):
188 | lo, hi = self.get_spk_info(spk)
189 |
190 | ratio_fs = self.rng.uniform(1, 1.4)
191 | coin = self.rng.random() > 0.5
192 | ratio_fs = coin * ratio_fs + (1 - coin) * (1 / ratio_fs)
193 |
194 | ratio_ps = self.rng.uniform(1, 2)
195 | coin = self.rng.random() > 0.5
196 | ratio_ps = coin * ratio_ps + (1 - coin) * (1 / ratio_ps)
197 |
198 | ratio_pr = self.rng.uniform(1, 1.5)
199 | coin = self.rng.random() > 0.5
200 | ratio_pr = coin * ratio_pr + (1 - coin) * (1 / ratio_pr)
201 |
202 | ss = change_gender(wav, sr, lo, hi, ratio_fs, ratio_ps, ratio_pr)
203 |
204 | return ss
205 |
206 | def fixed_formant_f0(self, wav, sr, spk):
207 | _, (lo, hi, _) = self.spk2info[spk]
208 |
209 | if lo == 50:
210 | lo = 75
211 | ratio_fs, f0_med, ratio_pr = 1.2, 300, 1.2
212 | else:
213 | ratio_fs, f0_med, ratio_pr = 0.8, 100, 0.8
214 |
215 | ss = change_gender_f0(wav, sr, lo, hi, ratio_fs, f0_med, ratio_pr)
216 | return ss
217 |
218 | def perturb_speaker(self, path, wav, sr, spk):
219 | if self.split_type == "train":
220 | # Speaker perturbation
221 | try:
222 | wav_p = self.random_formant_f0(wav, sr, spk)
223 | except UserWarning:
224 | wav_p = np.copy(wav)
225 | logger.info(f"Praat warning - {path}")
226 | except RuntimeError:
227 | wav_p = np.copy(wav)
228 | logger.info(f"Praat Error - {path}")
229 | wav_p = self.random_eq(wav_p, sr)
230 | else:
231 | try:
232 | wav_p = self.fixed_formant_f0(wav, sr, spk)
233 | except UserWarning:
234 | wav_p = np.copy(wav)
235 | logger.info(f"Praat warning - {path}")
236 | except RuntimeError:
237 | wav_p = np.copy(wav)
238 | logger.info(f"Praat Error - {path}")
239 |
240 | wav_p = np.clip(wav_p, -1.0, 1.0)
241 | return wav_p
242 |
243 |
244 | class AudioPretrainPnmiValDataset(Dataset):
245 | def __init__(
246 | self,
247 | json_dir: str,
248 | phn_dir: str,
249 | splits: List[str],
250 | min_audio_len: int = 1,
251 | max_audio_len: int = 1_000_000,
252 | sample_rate: int = 16_000,
253 | crop_len: int = 160_000,
254 | normalize: bool = False,
255 | **kwargs,
256 | ) -> None:
257 | super().__init__()
258 |
259 | logger.info(f"Loading audio from: {json_dir}")
260 | logger.info(f"Loading phoneme alignments from: {phn_dir}")
261 | logger.info(f"Splits : {splits}")
262 | self.split_type = "train" if splits[0].startswith("train") else "valid"
263 |
264 | # Data augmentation config
265 | self.crop_len = crop_len
266 | self.normalize = normalize
267 | self.sample_rate = sample_rate
268 |
269 | # Load from .json files
270 | data_list = []
271 | for s in splits:
272 | with open(os.path.join(json_dir, s + ".json"), "r") as fp:
273 | data_list += json.load(fp)
274 |
275 | # Load from .tsv files
276 | self.uid2refs = {}
277 | for s in splits:
278 | self.uid2refs.update(read_phn(os.path.join(phn_dir, s + ".tsv")))
279 |
280 | # Remove files that are too long or too short
281 | logger.info(
282 | f"Removing files shorter than {min_audio_len} or longer than {max_audio_len} frames"
283 | )
284 | orig_tot_len = sum([l for _, l in data_list]) / sample_rate / 3600.0
285 | orig_num_file = len(data_list)
286 | data_list = [
287 | (p, l) for p, l in data_list if l >= min_audio_len and l <= max_audio_len
288 | ]
289 | new_tot_len = sum([l for _, l in data_list]) / sample_rate / 3600.0
290 | logger.info(f"Original audio files: {orig_num_file} ({orig_tot_len:.2f} hrs)")
291 | logger.info(f"Final audio files: {len(data_list)} ({new_tot_len:.2f} hrs)")
292 |
293 | # Sort by length (long to short)
294 | data_list = sorted(data_list, key=lambda x: x[1], reverse=True)
295 |
296 | # Extract speaker
297 | spk_list = [Path(p).stem.split("-")[0] for p, _ in data_list]
298 | name_list = [Path(p).stem for p, _ in data_list]
299 | self.data = [
300 | (p, l, s, n) for (p, l), s, n in zip(data_list, spk_list, name_list)
301 | ]
302 | self.data_lens = [l for _, l, _, _ in self.data]
303 |
304 | def __len__(self) -> int:
305 | return len(self.data)
306 |
307 | def __getitem__(self, index: int) -> List[torch.FloatTensor]:
308 |         # each sample: (audio file path, number of samples, speaker id, utterance id)
309 | path, num_frames, spk, uid = self.data[index]
310 |
311 | wav, sr = sf.read(path)
312 | if wav.ndim == 2:
313 | wav = wav.mean(-1)
314 |
315 | wav = wav.astype(np.float32)
316 |
317 | # Crop wave length
318 | if self.crop_len > 0 and len(wav) > self.crop_len:
319 | wav = wav[: self.crop_len]
320 |
321 | wav = torch.from_numpy(wav)
322 |
323 | if self.normalize:
324 | wav = F.layer_norm(wav, wav.shape)
325 |
326 | return wav, uid
327 |
328 |
329 | def collate_fn(
330 | batch: List[List[Tuple[torch.FloatTensor, List[int]]]],
331 | ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.BoolTensor,]:
332 | wav_list = []
333 | wav_len = []
334 |
335 | for wavs in batch:
336 | for _, w in enumerate(wavs):
337 | wav_list.append(w)
338 | wav_len.append(len(w))
339 |
340 | wav_list = pad_sequence(wav_list, batch_first=True)
341 | wav_len = torch.LongTensor(wav_len)
342 | padding_mask = len_to_padding(wav_len)
343 |     # padding_mask: False (0) at valid positions, True (1) at padded positions
344 |
345 | return wav_list, wav_len, padding_mask
346 |
347 |
348 | def val_collate_fn(
349 | batch: List[List[Tuple[torch.FloatTensor, List[int]]]],
350 | ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.BoolTensor, torch.LongTensor]:
351 | wav_list = []
352 | wav_len = []
353 | uid_list = []
354 |
355 | for wav, uid in batch:
356 | wav_list.append(wav)
357 | wav_len.append(len(wav))
358 | uid_list.append(uid)
359 |
360 | wav_list = pad_sequence(wav_list, batch_first=True)
361 | wav_len = torch.LongTensor(wav_len)
362 | padding_mask = len_to_padding(wav_len)
363 |     # padding_mask: False (0) at valid positions, True (1) at padded positions
364 |
365 | return wav_list, wav_len, padding_mask, uid_list
366 |
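367 | # A minimal usage sketch (argument values are illustrative assumptions):
368 | #
369 | #   from torch.utils.data import DataLoader
370 | #   dataset = AudioPretrainDataset("_data/libri_json", ["train-clean-100"])
371 | #   loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn, shuffle=True)
372 | #   wavs, wav_len, padding_mask = next(iter(loader))
373 | #   # wavs: (batch * num_view, max_len); padding_mask is True at padded positions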
--------------------------------------------------------------------------------
/src/data/librispeech.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import List, Tuple
4 |
5 | import torchaudio
6 | from tqdm import tqdm
7 |
8 |
9 | def find_all_librispeech(root: str, sort_by_len: bool = False) -> List[Tuple[str, int]]:
10 | files = list(Path(root).rglob("*.flac"))
11 | files = [str(f) for f in files]
12 | file_lens = [torchaudio.info(f).num_frames for f in tqdm(files)]
13 | assert len(files) == len(file_lens), (len(files), len(file_lens))
14 | data = sorted(
15 | zip(files, file_lens), key=lambda x: x[1 if sort_by_len else 0], reverse=True
16 | )
17 | return data
18 |
19 |
20 | def save_data_info(data: List[Tuple[str, int]], path: str) -> None:
21 | with open(path, "w") as fp:
22 | json.dump(data, fp, indent=2)
23 |
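24 | # A minimal usage sketch (paths are placeholders): each entry is a
25 | # (flac path, number of samples) pair, the format consumed by src/data/dataset.py.
26 | #
27 | #   data = find_all_librispeech("/path/to/LibriSpeech/dev-clean", sort_by_len=True)
28 | #   save_data_info(data, "_data/libri_json/dev-clean.json")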
--------------------------------------------------------------------------------
/src/data/sampler.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 | from .dataset import AudioPretrainDataset
7 |
8 |
9 | class MaxLengthBatchSampler:
10 | def __init__(
11 | self,
12 | lengths: List[int],
13 | max_length: int,
14 | cropped_length: int = 160_000,
15 | shuffle: bool = False,
16 | drop_last: bool = False,
17 | seed: int = 7122,
18 | ) -> None:
19 | self.lengths = lengths
20 | self.max_length = max_length
21 | self.cropped_length = cropped_length if cropped_length > 0 else 1000000
22 | self.shuffle = shuffle
23 | self.drop_last = drop_last
24 | self.seed = seed
25 | self.epoch = 0
26 |
27 | def set_epoch(self, epoch: int):
28 | self.epoch = epoch
29 |
30 | def __iter__(self):
31 | batch_list = []
32 | batch = []
33 | cur_length = 0
34 | for i in range(len(self.lengths)):
35 | new_batch = batch + [i]
36 | cur_length += min(self.lengths[i], self.cropped_length)
37 |
38 | if cur_length <= self.max_length:
39 | batch = new_batch
40 | elif len(batch) == 0:
41 | raise ValueError(
42 | f"There is a single length {self.lengths[i]} larger than "
43 | f"max_length {self.max_length}. Please increase "
44 | "the max_length."
45 | )
46 | else:
47 | batch_list.append(batch)
48 | batch = [i]
49 | cur_length = min(self.lengths[i], self.cropped_length)
50 |
51 | if len(batch) > 0 and not self.drop_last:
52 | batch_list.append(batch)
53 |
54 | if self.shuffle:
55 | generator = torch.Generator()
56 | generator.manual_seed(self.epoch + self.seed)
57 | indices = torch.randperm(len(batch_list), generator=generator).tolist()
58 | else:
59 | indices = list(range(len(batch_list)))
60 |
61 | for i in indices:
62 | yield batch_list[i]
63 |
64 | def __len__(self):
65 | return len(list(iter(self)))
66 |
67 |
68 | class MaxLengthDistributedSampler:
69 | def __init__(
70 | self,
71 | dataset: AudioPretrainDataset,
72 | lengths: List[int],
73 | max_length: int,
74 | cropped_length: int = 160_000,
75 | num_replicas: Optional[int] = None,
76 | rank: Optional[int] = None,
77 | shuffle: bool = True,
78 | seed: int = 7122,
79 | drop_last: bool = False,
80 | ) -> None:
81 |
82 | if num_replicas is None:
83 | if not dist.is_available():
84 | raise RuntimeError("Requires distributed package to be available")
85 | num_replicas = dist.get_world_size()
86 | if rank is None:
87 | if not dist.is_available():
88 | raise RuntimeError("Requires distributed package to be available")
89 | rank = dist.get_rank()
90 | if rank >= num_replicas or rank < 0:
91 | raise ValueError(
92 | "Invalid rank {}, rank should be in the interval"
93 | " [0, {}]".format(rank, num_replicas - 1)
94 | )
95 |
96 | print(f"- Rank: {rank} / Num replicas: {num_replicas}")
97 |
98 | self.dataset = dataset
99 | self.shuffle = shuffle
100 | self.seed = seed
101 | self.drop_last = drop_last
102 | self.epoch = 0
103 | self.num_replicas = num_replicas
104 | self.rank = rank
105 |
106 | self.lengths = lengths
107 | self.max_length = max_length
108 | self.cropped_length = cropped_length if cropped_length > 0 else 1000000
109 |
110 | def set_epoch(self, epoch: int):
111 | self.epoch = epoch
112 |
113 | def __iter__(self):
114 | batch_list = []
115 | batch = []
116 | cur_length = 0
117 | for i in range(len(self.lengths)):
118 | new_batch = batch + [i]
119 | cur_length += min(self.lengths[i], self.cropped_length)
120 |
121 | if cur_length <= self.max_length:
122 | batch = new_batch
123 | elif len(batch) == 0:
124 | raise ValueError(
125 | f"There is a single length {self.lengths[i]} larger than "
126 | f"max_length {self.max_length}. Please increase "
127 | "the max_length."
128 | )
129 | else:
130 | batch_list.append(batch)
131 | batch = [i]
132 | cur_length = min(self.lengths[i], self.cropped_length)
133 |
134 | if len(batch) > 0 and not self.drop_last:
135 | batch_list.append(batch)
136 |
137 | if self.shuffle:
138 | generator = torch.Generator()
139 | generator.manual_seed(self.epoch + self.seed)
140 | indices = torch.randperm(len(batch_list), generator=generator).tolist()
141 | else:
142 | indices = list(range(len(batch_list)))
143 |
144 | max_index = len(indices) - len(indices) % self.num_replicas
145 | indices = indices[self.rank : max_index : self.num_replicas]
146 | for i in indices:
147 | yield batch_list[i]
148 |
149 | def __len__(self):
150 | return len(list(iter(self)))
151 |
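152 | # A minimal usage sketch (values are illustrative assumptions): the sampler yields
153 | # lists of indices whose summed (cropped) lengths stay under max_length, so it is
154 | # passed to DataLoader as a batch_sampler instead of a fixed batch_size.
155 | #
156 | #   from torch.utils.data import DataLoader
157 | #   sampler = MaxLengthBatchSampler(dataset.data_lens, max_length=1_600_000, shuffle=True)
158 | #   loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)
159 | #   sampler.set_epoch(epoch)  # reshuffle batches differently every epoch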
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .spin import SpinModel
2 |
--------------------------------------------------------------------------------
/src/model/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import pytorch_lightning as pl
4 | import yaml
5 |
6 |
7 | class BaseModel(pl.LightningModule):
8 | def __init__(self, config) -> None:
9 | super().__init__()
10 |
11 | if isinstance(config, str) and config.split(".")[-1] in {"yaml", "yml"}:
12 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
13 |
14 | self.config = config
15 | self.save_hyperparameters(config)
16 |
17 | @abc.abstractmethod
18 | def forward(self, batch):
19 | raise NotImplementedError
20 |
21 | @abc.abstractmethod
22 | def training_step(self, batch, batch_idx):
23 | raise NotImplementedError
24 |
25 | @abc.abstractmethod
26 | def configure_optimizers(self):
27 | raise NotImplementedError
28 |
--------------------------------------------------------------------------------
/src/model/spin.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | from collections import defaultdict
4 | from typing import Union
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 | from torch.utils.data import DataLoader
10 |
11 | from src.data import (
12 | AudioPretrainDataset,
13 | MaxLengthBatchSampler,
14 | MaxLengthDistributedSampler,
15 | collate_fn,
16 | )
17 | from src.nn import DNN, HuBERT, SwavVQDisentangle, WavLM
18 | from src.util import compute_show_pnmi, get_scheduler, update_padding_mask
19 |
20 | from .base import BaseModel
21 |
22 | logger = logging.getLogger("spin")
23 |
24 |
25 | def get_pred_head(type_name: str, hid_dim: int, config: dict) -> Union[None, nn.Module]:
26 | if type_name == "None":
27 | return None
28 | if type_name == "DNN":
29 | return DNN(hid_dim, **config)
30 | raise NotImplementedError(type_name)
31 |
32 |
33 | def get_loss(type_name: str, hid_dim: int, config: dict) -> nn.Module:
34 | if type_name == "SwavVQDisentangle":
35 | return SwavVQDisentangle(hid_dim, **config)
36 | raise NotImplementedError(type_name)
37 |
38 |
39 | class SpinModel(BaseModel):
40 | def __init__(self, config, num_view: int = 2) -> None:
41 | super().__init__(config)
42 |
43 | config = copy.deepcopy(config)
44 | config = config["model"]
45 | logger.info(f"Model config: {config}")
46 |
47 | self.encoder_type = config["encoder"].pop("type", "HuBERT")
48 | self.pred_head_type = config["pred_head"].pop("type", "DNN")
49 | self.loss_type = config["loss"].pop("type", "SwavVQDisentangle")
50 |
51 | logger.info(f"Encoder: {self.encoder_type}")
52 | logger.info(f"Prediction head: {self.pred_head_type}")
53 | logger.info(f"Loss: {self.loss_type}")
54 |
55 | # Setup number of views
56 | self.normalize = config.get("normalize", False)
57 | self.num_view = num_view
58 | assert num_view == 2, num_view # NOTE: currently we support 2 views only
59 |
60 | # Setup encoder model
61 | if self.encoder_type in {"HuBERT", "WavLM"}:
62 | self.use_layer = config["encoder"].pop("use_layer", 12)
63 | self.encoder = eval(self.encoder_type)(**config["encoder"])
64 | hid_dim = self.encoder.hidden_sizes[self.use_layer]
65 | self.encoder_rate = 320
66 | logger.info(f"Taking features from layer {self.use_layer}")
67 | else:
68 | raise NotImplementedError(self.encoder_type)
69 |
70 | # All layers to be processed
71 | self.use_layers = [self.use_layer]
72 | logger.info(f"All selected layers: {self.use_layers}")
73 |
74 | # Setup prediction head
75 | if len(self.use_layers) == 1:
76 | self.pred_head = get_pred_head(
77 | self.pred_head_type, hid_dim, config["pred_head"]
78 | )
79 | hid_dim = self.pred_head.out_dim
80 | else:
81 | self.pred_head = nn.ModuleList(
82 | [
83 | get_pred_head(self.pred_head_type, hid_dim, config["pred_head"])
84 | for _ in self.use_layers
85 | ]
86 | )
87 | hid_dim = self.pred_head[0].out_dim
88 |
89 | # Setup loss function
90 | self.loss_module = get_loss(self.loss_type, hid_dim, config["loss"])
91 |
92 | # Validation
93 | self.val_uid2hyp = {}
94 |
95 | def normalize_wavs(
96 | self,
97 | wavs: torch.Tensor,
98 | wavs_len: torch.LongTensor,
99 | ) -> torch.Tensor:
100 | with torch.no_grad():
101 |             for i in range(len(wavs)):
102 | wavs[i, : wavs_len[i]] = F.layer_norm(
103 |                     wavs[i, : wavs_len[i]], (wavs_len[i],)
104 | )
105 | return wavs
106 |
107 | def forward_features(
108 | self,
109 | wavs: torch.Tensor,
110 | padding_mask: torch.BoolTensor,
111 | ):
112 | if self.encoder_type in {"HuBERT", "WavLM"}:
113 | feat_list, feat_len, padding_mask = self.encoder(wavs, padding_mask)
114 | repr_list = [feat_list[l] for l in self.use_layers]
115 |
116 | return repr_list, feat_list, feat_len, padding_mask
117 |
118 | def forward_pred_head(
119 | self, feat: torch.Tensor, feat_len: torch.LongTensor, i: int = None
120 | ):
121 | if len(self.use_layers) == 1:
122 | return self.pred_head(feat, feat_len)
123 | else:
124 | assert isinstance(i, int), i
125 | return self.pred_head[i](feat, feat_len)
126 |
127 | def forward(self, batch, feat_only: bool = False):
128 | # NOTE: padding_mask is 1 when the position is padded
129 | wavs, wavs_len, padding_mask = batch
130 |
131 | wavs.masked_fill_(padding_mask, 0.0)
132 |
133 | # Normalize wavs
134 | if self.normalize:
135 | wavs = self.normalize_wavs(wavs, wavs_len)
136 |
137 | # Extract features
138 | repr_list, feat_list, feat_len, padding_mask = self.forward_features(
139 | wavs, padding_mask
140 | )
141 | padding_mask = update_padding_mask(padding_mask, repr_list[0].shape[1])
142 |
143 | # Prediction head
144 | repr_list = [
145 | self.forward_pred_head(repr, feat_len, i)
146 | for i, repr in enumerate(repr_list)
147 | ]
148 |
149 | # Return results
150 | if feat_only:
151 | outputs = {
152 | "repr_list": repr_list,
153 | "feat_len": feat_len,
154 | "feat_list": feat_list,
155 | "padding_mask": padding_mask,
156 | }
157 | if self.loss_type == "SwavVQDisentangle":
158 | if self.loss_module.l2_norm:
159 | outputs["repr_list"] = F.normalize(repr_list[0], dim=-1)
160 | logits, codes = self.loss_module.produce_targets(
161 | outputs["repr_list"], normalized=True
162 | )
163 | outputs["logits"] = logits
164 | outputs["codes"] = codes
165 |
166 | return outputs
167 |
168 | # feat: (Batch * View, Time, Dim)
169 | # Split batch into views
170 | repr_views = [
171 | [
172 | r[i :: self.num_view][~padding_mask[i :: self.num_view]]
173 | for r in repr_list
174 | ]
175 | for i in range(self.num_view)
176 | ]
177 |
178 | # Computes loss from each pair of views
179 | total_loss = 0
180 | loss_res = defaultdict(list)
181 |
182 | # Main loss
183 | if self.loss_type in {"SwavVQDisentangle"}:
184 | res = self.loss_module.cal_loss(repr_views[0][0], repr_views[1][0])
185 | total_loss += res.pop("loss")
186 | for k in res:
187 | loss_res[f"{k}"].append(res[k])
188 |
189 | for k in loss_res:
190 | loss_res[k] = sum(loss_res[k]) / len(loss_res[k])
191 |
192 | return total_loss, loss_res
193 |
194 | def training_step(self, batch, batch_idx):
195 | total_loss, loss_res = self.forward(batch)
196 |
197 | self.log("loss", total_loss)
198 | for k, v in loss_res.items():
199 | if k == "acc":
200 | self.log(k, v, prog_bar=True)
201 | else:
202 | self.log(k, v)
203 |
204 | return total_loss
205 |
206 | def validation_step(self, batch, batch_idx):
207 | wav_list, wav_len, padding_mask, uid_list = batch
208 | results = self.forward(
209 | (wav_list, wav_len, padding_mask),
210 | feat_only=True,
211 | )
212 | codes = results.get("codes", None)
213 | if codes is None:
214 | return
215 |
216 | code = codes.cpu().numpy()
217 | feat_len = results["feat_len"]
218 | for i, (uid, c) in enumerate(zip(uid_list, code)):
219 | self.val_uid2hyp[uid] = c[: feat_len[i]]
220 |
221 | def on_validation_epoch_end(self) -> None:
222 | check_1, check_2 = True, True
223 | try:
224 | uid2ref = self.trainer.val_dataloaders[0].dataset.uid2refs
225 | except:
226 | check_1 = False
227 |
228 | if not check_1:
229 | try:
230 | uid2ref = self.trainer.val_dataloaders.dataset.uid2refs
231 | except:
232 | check_2 = False
233 |
234 | if not check_1 and not check_2:
235 | logger.info("Cannot find uid2ref in validation dataloader, skip PNMI")
236 | return
237 |
238 | if len(self.val_uid2hyp) == 0:
239 | return
240 |
241 | res = compute_show_pnmi(
242 | uid2ref, self.val_uid2hyp, upsample=self.encoder_rate // 160
243 | )
244 | self.log("cls_pur", res["cls_pur"])
245 | self.log("phn_pur", res["phn_pur"])
246 | self.log("pnmi", res["pnmi"])
247 | self.val_uid2hyp.clear()
248 |
249 | def on_before_zero_grad(self, optimizer) -> None:
250 | if self.loss_type == "SwavVQDisentangle":
251 | self.loss_module.normalize_codebook()
252 |
253 | return super().on_before_zero_grad(optimizer)
254 |
255 | def on_train_epoch_start(self) -> None:
256 | super().on_train_epoch_start()
257 |
258 | try:
259 | self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)
260 | logger.info(f"Update epoch for batch sampler to {self.current_epoch}")
261 | except:
262 |             logger.warning(
263 | "Unable to update epoch for batch sampler (possibly using fixed batch_size)"
264 | )
265 |
266 | def configure_optimizers(self):
267 | params = []
268 | if self.encoder_type in {"HuBERT", "WavLM"}:
269 | params += self.encoder.trainable_parameters()
270 | else:
271 | params += list(self.encoder.parameters())
272 |
273 | if self.pred_head:
274 | params += list(self.pred_head.parameters())
275 | params += list(self.loss_module.parameters())
276 |
277 | optimizer = getattr(torch.optim, self.config["optim"]["optimizer"]["name"])(
278 | params, **self.config["optim"]["optimizer"]["args"]
279 | )
280 |
281 | if self.config["optim"].get("scheduler", None):
282 | scheduler = get_scheduler(
283 | self.config["optim"]["scheduler"]["name"],
284 | optimizer,
285 | **self.config["optim"]["scheduler"]["args"],
286 | )
287 |
288 | return {
289 | "optimizer": optimizer,
290 | "lr_scheduler": {
291 | "scheduler": scheduler,
292 | "interval": "step",
293 | "frequency": 1,
294 | },
295 | }
296 | else:
297 | return optimizer
298 |
299 | def set_random_seed(self, seed: int = 7122):
300 | self.seed = seed
301 |
302 | def set_njobs(self, njobs: int = 0):
303 | self.njobs = njobs
304 |
305 | def set_use_ddp(self, use_ddp: bool = False):
306 | self.use_ddp = use_ddp
307 |
308 | def train_dataloader(self):
309 | dataset = AudioPretrainDataset(**self.config["data"])
310 | if "batch_len" in self.config["hparam"]:
311 | if self.use_ddp:
312 | sampler = MaxLengthDistributedSampler(
313 | dataset,
314 | dataset.data_lens,
315 | max_length=self.config["hparam"]["batch_len"],
316 | cropped_length=self.config["data"]["random_crop_len"],
317 | shuffle=True,
318 | drop_last=True,
319 | seed=self.seed,
320 | )
321 | else:
322 | sampler = MaxLengthBatchSampler(
323 | dataset.data_lens,
324 | max_length=self.config["hparam"]["batch_len"],
325 | cropped_length=self.config["data"]["random_crop_len"],
326 | shuffle=True,
327 | drop_last=True,
328 | seed=self.seed,
329 | )
330 | loader = DataLoader(
331 | dataset,
332 | batch_sampler=sampler,
333 | num_workers=self.njobs,
334 | pin_memory=True,
335 | collate_fn=collate_fn,
336 | )
337 | elif "batch_size" in self.config["hparam"]:
338 | loader = DataLoader(
339 | dataset,
340 | batch_size=self.config["hparam"]["batch_size"],
341 | num_workers=self.njobs,
342 | pin_memory=True,
343 | collate_fn=collate_fn,
344 | shuffle=True,
345 | drop_last=True,
346 | )
347 |
348 | return loader
349 |
--------------------------------------------------------------------------------
/src/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .dnn import DNN
2 | from .hubert import HuBERT
3 | from .swav_vq_dis import SwavVQDisentangle
4 | from .wavlm import WavLM
5 |
--------------------------------------------------------------------------------
/src/nn/dnn.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class DNN(nn.Module):
8 | def __init__(
9 | self,
10 | in_dim: int,
11 | hid_dims: List[int],
12 | dropout: float = 0.0,
13 | activation: str = "ReLU",
14 | activate_last: bool = False,
15 | ) -> None:
16 | super().__init__()
17 |
18 | self.in_dim = in_dim
19 | self.out_dim = hid_dims[-1]
20 | self.activate_last = activate_last
21 |
22 | assert len(hid_dims) > 0, len(hid_dims)
23 | hid_dims = [in_dim] + hid_dims
24 |
25 | self.layers = nn.ModuleList(
26 | [nn.Linear(hid_dims[i], hid_dims[i + 1]) for i in range(len(hid_dims) - 1)]
27 | )
28 | self.num_layer = len(self.layers)
29 | self.dropout = nn.Dropout(dropout)
30 | n_acts = self.num_layer - (0 if self.activate_last else 1)
31 | self.acts = nn.ModuleList([getattr(nn, activation)() for _ in range(n_acts)])
32 |
33 | def forward(self, x: torch.Tensor, x_len: torch.LongTensor = None) -> torch.Tensor:
34 | for i in range(self.num_layer):
35 | x = self.layers[i](x)
36 | if i < self.num_layer - 1 or self.activate_last:
37 | x = self.dropout(x)
38 | x = self.acts[i](x)
39 | return x
40 |
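41 | # A minimal usage sketch (dimensions are illustrative assumptions): a two-layer
42 | # MLP projecting 768-dim encoder features down to a 256-dim space.
43 | #
44 | #   head = DNN(in_dim=768, hid_dims=[2048, 256], dropout=0.1)
45 | #   y = head(torch.randn(4, 100, 768))  # -> (4, 100, 256)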
--------------------------------------------------------------------------------
/src/nn/hubert.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any, List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from s3prl.upstream.hubert.convert import load_converted_model
7 | from s3prl.upstream.hubert.hubert_model import (
8 | HubertConfig,
9 | HubertModel,
10 | HubertPretrainingConfig,
11 | )
12 | from s3prl.upstream.utils import merge_with_parent
13 | from s3prl.upstream.wav2vec2.wav2vec2_model import GradMultiply
14 | from s3prl.util.download import _urls_to_filepaths
15 | from torch import nn
16 |
17 | from src.util import (
18 | freeze_module,
19 | init_module_bert,
20 | init_module_cnn,
21 | init_module_pos_conv,
22 | padding_to_len,
23 | )
24 |
25 | logger = logging.getLogger("hubert")
26 |
27 |
28 | def get_hubert_configs(ckpt_path):
29 | ckpt_state = torch.load(ckpt_path, map_location="cpu")
30 |
31 | for required_key in [
32 | "task_cfg",
33 | "model_cfg",
34 | "model_weight",
35 | "dictionaries_symbols",
36 | ]:
37 | if required_key not in ckpt_state:
38 | raise ValueError(
39 | f"{ckpt_path} is not a valid checkpoint since the required key: {required_key} is missing"
40 | )
41 |
42 | task_cfg = merge_with_parent(HubertPretrainingConfig, ckpt_state["task_cfg"])
43 | model_cfg = merge_with_parent(HubertConfig, ckpt_state["model_cfg"])
44 | dictionaries = ckpt_state["dictionaries_symbols"]
45 | return model_cfg, task_cfg, dictionaries
46 |
47 |
48 | def random_hubert(
49 | model_cfg: HubertConfig,
50 | task_cfg: HubertPretrainingConfig,
51 | dictionaries: List[Any],
52 | ):
53 | model = HubertModel(model_cfg, task_cfg, dictionaries)
54 | return model, model_cfg, task_cfg
55 |
56 |
57 | class HuBERT(nn.Module):
58 | def __init__(
59 | self,
60 | path_or_url: str = None,
61 | refresh: bool = False,
62 | pre_normalize: bool = False,
63 | normalize: bool = False,
64 | feat_select: str = "x",
65 | randomize_all: bool = False,
66 | randomize_layers: List[int] = [],
67 | freeze_all: bool = False,
68 | freeze_layers: List[int] = [],
69 | disable_pos: bool = False,
70 | masking: bool = False,
71 | ):
72 | super().__init__()
73 |
74 | ckpt = "https://huggingface.co/s3prl/converted_ckpts/resolve/main/hubert_base_ls960.pt"
75 | if path_or_url is not None:
76 | ckpt = path_or_url
77 | if ckpt.startswith("https"):
78 | ckpt = _urls_to_filepaths(ckpt, refresh=refresh)
79 |
80 | if not randomize_all:
81 | model, task_cfg = load_converted_model(ckpt)
82 | else:
83 | model_cfg, task_cfg, dictionaries = get_hubert_configs(ckpt)
84 | model, model_cfg, task_cfg = random_hubert(
85 | model_cfg, task_cfg, dictionaries
86 | )
87 | logger.info("Random HuBERT used")
88 |
89 | self.model: HubertModel = model
90 | self.task_cfg = task_cfg
91 | self.wav_normalize = task_cfg.normalize
92 | self.pre_normalize = pre_normalize
93 | self.normalize = normalize
94 | self.masking = masking
95 | self.num_layers = len(self.model.encoder.layers) + 1 # CNN + 12 Transformer
96 | self.hidden_sizes = [self.model.encoder.embedding_dim] * self.num_layers
97 | self.model.encoder.layerdrop = 0.0
98 |
99 | self.feat_select = 0
100 | if feat_select == "att":
101 | self.feat_select = 1
102 | if feat_select == "ffn":
103 | self.feat_select = 2
104 |
105 | logger.info(f"Feature selection: {feat_select} (index: {self.feat_select})")
106 | logger.info(
107 | f"Randomize all = {randomize_all} (randomize layers = {randomize_layers})"
108 | )
109 | logger.info(f"Freeze all = {freeze_all} (freeze layers = {freeze_layers})")
110 | logger.info(f"Apply masking to features = {masking}")
111 |
112 | self.randomize_all = randomize_all
113 | self.randomize_layers = randomize_layers
114 | self.freeze_all = freeze_all
115 | self.freeze_layers = freeze_layers
116 |
117 | self.freeze_cnn = (0 in freeze_layers) or self.freeze_all
118 | self.freeze_pos = ("pos" in freeze_layers) or self.freeze_all
119 | self.disable_pos = disable_pos
120 | if disable_pos:
121 | logger.info("Disabled CNN positional encoding")
122 |
123 | # Randomize weights
124 | if randomize_all:
125 | randomize_layers = []
126 | if len(randomize_layers) > 0:
127 | for i in randomize_layers:
128 | if i == 0:
129 | self.model.feature_extractor.apply(init_module_cnn)
130 | self.model.layer_norm.reset_parameters()
131 | if self.model.post_extract_proj is not None:
132 | self.model.post_extract_proj.reset_parameters()
133 | elif i == "pos":
134 | self.model.encoder.pos_conv.apply(init_module_pos_conv)
135 | else:
136 | self.model.encoder.layers[i - 1].apply(init_module_bert)
137 | if i == self.num_layers - 1 and self.model.encoder.layer_norm_first:
138 | self.model.encoder.layer_norm.reset_parameters()
139 |
140 | # Freeze weights
141 | if freeze_all:
142 | freeze_module(self.model)
143 | elif len(freeze_layers) > 0:
144 | for i in freeze_layers:
145 | if i == 0:
146 | self.model.feature_grad_mult = 0.0
147 | freeze_module(self.model.feature_extractor)
148 | freeze_module(self.model.layer_norm)
149 | if self.model.post_extract_proj is not None:
150 | freeze_module(self.model.post_extract_proj)
151 | elif i == "pos":
152 | freeze_module(self.model.encoder.pos_conv)
153 | else:
154 | assert isinstance(i, int), i
155 | freeze_module(self.model.encoder.layers[i - 1])
156 |
157 | if self.freeze_cnn:
158 | self.model.feature_grad_mult = 0.0
159 |
160 | self.model.remove_pretraining_modules()
161 |
162 | def trainable_parameters(self):
163 | params = []
164 |
165 | if self.freeze_all:
166 | return []
167 |
168 | if not self.freeze_all and len(self.freeze_layers) == 0:
169 | logger.info("Trains the entire model")
170 | return self.model.parameters()
171 |
172 | params = []
173 | for i in ["pos"] + list(range(self.num_layers)):
174 | if i in self.freeze_layers:
175 | continue
176 | if i == 0:
177 | params += list(self.model.feature_extractor.parameters())
178 | params += list(self.model.layer_norm.parameters())
179 | if self.model.post_extract_proj is not None:
180 | params += list(self.model.post_extract_proj.parameters())
181 | elif i == "pos":
182 | params += list(self.model.encoder.pos_conv.parameters())
183 | else:
184 | params += list(self.model.encoder.layers[i - 1].parameters())
185 | if i == self.num_layers - 1 and self.model.encoder.layer_norm_first:
186 | params += list(self.model.encoder.layer_norm.parameters())
187 |
188 | if self.masking:
189 |             params += [self.model.mask_emb]  # mask_emb is an nn.Parameter, not a Module
190 |
191 | return params
192 |
193 | def forward_features(self, x: torch.Tensor) -> torch.Tensor:
194 | if self.model.feature_grad_mult > 0:
195 | features = self.model.feature_extractor(x)
196 | if self.model.feature_grad_mult != 1.0:
197 | features = GradMultiply.apply(features, self.model.feature_grad_mult)
198 | else:
199 | with torch.no_grad():
200 | features = self.model.feature_extractor(x)
201 |
202 | return features
203 |
204 | def forward(
205 | self,
206 | wavs: torch.FloatTensor,
207 | padding_mask: torch.BoolTensor,
208 | ):
209 | if self.pre_normalize:
210 | with torch.no_grad():
211 | wavs = (wavs - wavs[~padding_mask].mean()) / (
212 | wavs[~padding_mask].std() + 1e-5
213 | )
214 |
215 | if self.wav_normalize:
216 | with torch.no_grad():
217 | wav_len = padding_to_len(padding_mask)
218 | for i in range(len(wavs)):
219 | wavs[i, : wav_len[i]] = F.layer_norm(
220 | wavs[i, : wav_len[i]], (wav_len[i],)
221 | )
222 |
223 | features = self.forward_features(wavs).transpose(1, 2)
224 | if self.freeze_cnn:
225 | with torch.no_grad():
226 | features = self.model.layer_norm(features)
227 | if self.model.post_extract_proj is not None:
228 | features = self.model.post_extract_proj(features)
229 | else:
230 | features = self.model.layer_norm(features)
231 | if self.model.post_extract_proj is not None:
232 | features = self.model.post_extract_proj(features)
233 |
234 | if padding_mask is not None:
235 | padding_mask = self.model.forward_padding_mask(features, padding_mask)
236 |
237 | x = self.model.dropout_input(features)
238 | if self.masking and self.training:
239 | x, mask_indices = self.model.apply_mask(x, padding_mask, None)
240 |
241 | # feature: (B, T, D), float
242 | # x: (B, T, D), float
243 | # padding_mask: (B, T), bool
244 |
245 | if self.freeze_all:
246 | with torch.no_grad():
247 | _, layer_results = self.model.encoder(
248 | x,
249 | padding_mask=padding_mask,
250 | disable_pos=self.disable_pos,
251 | )
252 | else:
253 | _, layer_results = self.model.encoder(
254 | x,
255 | padding_mask=padding_mask,
256 | freeze_pos=self.freeze_pos,
257 | freeze_layers=self.freeze_layers,
258 | disable_pos=self.disable_pos,
259 | )
260 |
261 | feat_list = [features] + [
262 | feat[self.feat_select].transpose(0, 1) for feat in layer_results
263 | ]
264 |
265 | if self.normalize:
266 | feat_list = [F.layer_norm(feat, feat.shape[-1:]) for feat in feat_list]
267 |
268 | feat_len = padding_to_len(padding_mask)
269 |
270 | return feat_list, feat_len, padding_mask
271 |
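272 | # A minimal usage sketch (illustrative; downloads the default checkpoint above):
273 | #
274 | #   encoder = HuBERT(freeze_layers=[0, "pos"])  # freeze CNN and positional conv
275 | #   wavs = torch.zeros(2, 16_000)
276 | #   padding_mask = torch.zeros(2, 16_000, dtype=torch.bool)
277 | #   feat_list, feat_len, padding_mask = encoder(wavs, padding_mask)
278 | #   # feat_list[0]: CNN features; feat_list[1:]: the 12 transformer layer outputs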
--------------------------------------------------------------------------------
/src/nn/swav_vq_dis.py:
--------------------------------------------------------------------------------
1 | # Ref: https://github.com/facebookresearch/swav/blob/main/main_swav.py
2 |
3 | import logging
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import torch.nn.functional as F
8 | from torch import nn
9 |
10 | logger = logging.getLogger("swav_vq_dis")
11 |
12 |
13 | @torch.no_grad()
14 | def compute_sinkhorn(
15 | out: torch.Tensor, epsilon: float, sinkhorn_iterations: int
16 | ) -> torch.Tensor:
17 | # out: (B, K)
18 | B, K = out.shape
19 | Q = out.div(epsilon).exp().t()
20 | # Q is K-by-B for consistency with notations from our paper
21 |
22 | # make the matrix sums to 1
23 | if dist.is_initialized():
24 | sum_Q = Q.sum()
25 | dist.all_reduce(sum_Q)
26 | Q.div_(sum_Q)
27 | else:
28 | Q.div_(Q.sum())
29 |
30 | for _ in range(sinkhorn_iterations):
31 | # normalize each row: total weight per prototype must be 1/K
32 | if dist.is_initialized():
33 | sum_of_rows = Q.sum(dim=1, keepdim=True)
34 | dist.all_reduce(sum_of_rows)
35 | Q.div_(sum_of_rows * K)
36 | else:
37 | Q.div_(Q.sum(dim=1, keepdim=True) * K)
38 |
39 | # normalize each column: total weight per sample must be 1/B
40 | Q.div_(Q.sum(dim=0, keepdim=True) * B)
41 |
 42 |     Q.mul_(B)  # the columns must sum to 1 so that Q is an assignment
43 | return Q.t()
44 |
45 |
46 | class SwavVQDisentangle(nn.Module):
47 | def __init__(
48 | self,
49 | dim: int,
50 | num_vars: int,
51 | epsilon: float = 0.05,
52 | sinkhorn_iters: int = 3,
53 | temp: float = 0.1,
54 | l2_norm: bool = True,
55 | hard_target: bool = False,
56 | prob_ratio: float = 1.0,
57 | ) -> None:
58 | super().__init__()
59 |
60 | self.dim = dim
61 | self.num_vars = num_vars
62 | self.epsilon = epsilon
63 | self.sinkhorn_iters = sinkhorn_iters
64 | self.temp = temp
65 | self.l2_norm = l2_norm
66 | self.hard_target = hard_target
67 | self.prob_ratio = prob_ratio
68 |
69 | logger.info(f"Codebook size: {num_vars}")
70 | self.codebook = nn.Linear(dim, num_vars, bias=False)
71 |
72 | def produce_targets(self, z: torch.Tensor, normalized: bool = False):
73 | if self.l2_norm and not normalized:
74 | z = F.normalize(z, dim=-1)
75 | logits = self.codebook(z) / self.temp
76 | codes = torch.argmax(logits, -1)
77 | return logits, codes
78 |
79 | @torch.no_grad()
80 | def normalize_codebook(self) -> None:
81 | w = self.codebook.weight.data.clone()
82 | w = F.normalize(w, dim=1, p=2)
83 | self.codebook.weight.copy_(w)
84 |
85 | @torch.no_grad()
86 | def zero_grad_codebook(self) -> None:
87 | self.codebook.zero_grad()
88 |
89 | @torch.no_grad()
90 | def copy_codebook(self) -> None:
91 | self.normalize_codebook()
92 | self.codebook_copy = self.codebook.weight.data.detach()
93 |
94 | @torch.no_grad()
95 | def restore_codebook(self) -> None:
96 | self.codebook.weight.copy_(self.codebook_copy)
97 |
98 | def forward(
99 | self,
100 | z_1: torch.Tensor,
101 | z_2: torch.Tensor,
102 | ):
103 | B = len(z_1)
104 |
105 | if self.l2_norm:
106 | z_1 = F.normalize(z_1, dim=1)
107 | z_2 = F.normalize(z_2, dim=1)
108 |
109 | logits_1: torch.Tensor = self.codebook(z_1) # (Batch, Num Codes)
110 | logits_2: torch.Tensor = self.codebook(z_2) # (Batch, Num Codes)
111 |
112 | # Compute targets
113 | with torch.no_grad():
114 | tgt_logits_w_1 = logits_1 * self.prob_ratio + logits_2 * (
115 | 1 - self.prob_ratio
116 | )
117 | tgt_logits_w_2 = logits_2 * self.prob_ratio + logits_1 * (
118 | 1 - self.prob_ratio
119 | )
120 |
121 | tgt_probs_1 = compute_sinkhorn(
122 | tgt_logits_w_1.detach(), self.epsilon, self.sinkhorn_iters
123 | )
124 | tgt_probs_2 = compute_sinkhorn(
125 | tgt_logits_w_2.detach(), self.epsilon, self.sinkhorn_iters
126 | )
127 |
128 | # Compute cross-entropy loss
129 | logits_1.div_(self.temp)
130 | logits_2.div_(self.temp)
131 | log_prob_1 = logits_1.log_softmax(1)
132 | log_prob_2 = logits_2.log_softmax(1)
133 |
134 | loss = 0
135 | if self.hard_target:
136 | loss_ce = 0.5 * (
137 | F.cross_entropy(
138 | logits_2,
139 | tgt_probs_1.argmax(-1),
140 | )
141 | + F.cross_entropy(
142 | logits_1,
143 | tgt_probs_2.argmax(-1),
144 | )
145 | )
146 | else:
147 | loss_ce = -0.5 * (
148 | (tgt_probs_1 * log_prob_2).sum(1).mean()
149 | + (tgt_probs_2 * log_prob_1).sum(1).mean()
150 | )
151 |
152 | loss += loss_ce
153 | result = {"loss_ce": loss_ce, "batch_size": B}
154 | result["loss"] = loss
155 |
156 | with torch.no_grad():
157 | logits = torch.cat([logits_1, logits_2], dim=0)
158 | _, k = logits.max(-1)
159 | hard_x = logits.new_zeros(*logits.shape).scatter_(1, k.view(-1, 1), 1.0)
160 | hard_probs = hard_x.float().mean(0)
161 | result["code_perplexity"] = (
162 | torch.exp(-torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1))
163 | .sum()
164 | .cpu()
165 | .detach()
166 | )
167 |
168 | avg_probs = logits.float().softmax(-1).mean(0)
169 | result["prob_perplexity"] = (
170 | torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1))
171 | .sum()
172 | .cpu()
173 | .detach()
174 | )
175 | acc_1 = (
176 | (torch.argmax(logits_1, dim=1) == torch.argmax(tgt_probs_2, dim=1))
177 | .float()
178 | .mean()
179 | .cpu()
180 | .detach()
181 | .item()
182 | )
183 | acc_2 = (
184 | (torch.argmax(logits_2, dim=1) == torch.argmax(tgt_probs_1, dim=1))
185 | .float()
186 | .mean()
187 | .cpu()
188 | .detach()
189 | .item()
190 | )
191 | result["acc"] = float((acc_1 + acc_2) / 2)
192 | result["acc_1"] = float(acc_1)
193 | result["acc_2"] = float(acc_2)
194 |
195 | return result
196 |
197 | def cal_loss(
198 | self,
199 | z_1: torch.Tensor,
200 | z_2: torch.Tensor,
201 | ):
202 | return self.forward(z_1, z_2)
203 |
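204 | # A minimal usage sketch (shapes and sizes are illustrative assumptions): the loss
205 | # takes two frame-level views of shape (num_frames, dim) and returns a dict with
206 | # "loss", "acc", and perplexity statistics; the codebook is re-normalized to unit
207 | # norm before each optimizer step (see SpinModel.on_before_zero_grad).
208 | #
209 | #   loss_fn = SwavVQDisentangle(dim=256, num_vars=2048)
210 | #   res = loss_fn.cal_loss(torch.randn(500, 256), torch.randn(500, 256))
211 | #   res["loss"].backward()
212 | #   loss_fn.normalize_codebook()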
--------------------------------------------------------------------------------
/src/nn/wavlm.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from s3prl.upstream.wavlm.modules import GradMultiply
7 | from s3prl.upstream.wavlm.WavLM import WavLM as WavLMModel
8 | from s3prl.upstream.wavlm.WavLM import WavLMConfig
9 | from s3prl.util.download import _urls_to_filepaths
10 | from torch import nn
11 |
12 | from src.util import (
13 | freeze_module,
14 | init_module_bert,
15 | init_module_cnn,
16 | init_module_pos_conv,
17 | padding_to_len,
18 | )
19 |
20 | logger = logging.getLogger("wavlm")
21 |
22 |
23 | class WavLM(nn.Module):
24 | def __init__(
25 | self,
26 | path_or_url: str = None,
27 | refresh: bool = False,
28 | pre_normalize: bool = False,
29 | normalize: bool = False,
30 | feat_select: str = "x",
31 | randomize_all: bool = False,
32 | randomize_layers: List[int] = [],
33 | freeze_all: bool = False,
34 | freeze_layers: List[int] = [],
35 | ):
36 | super().__init__()
37 |
38 | ckpt = "https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_base.pt"
39 | if path_or_url is not None:
40 | ckpt = path_or_url
41 | if ckpt.startswith("https"):
42 | ckpt = _urls_to_filepaths(ckpt, refresh=refresh)
43 |
44 | checkpoint = torch.load(ckpt)
45 | self.cfg = WavLMConfig(checkpoint["cfg"])
46 | self.model = WavLMModel(self.cfg)
47 | self.model.load_state_dict(checkpoint["model"])
48 |
49 | self.wav_normalize = self.cfg.normalize
50 | self.pre_normalize = pre_normalize
51 | self.normalize = normalize
52 | self.num_layers = self.cfg.encoder_layers + 1 # CNN + 12 Transformer
53 | self.hidden_sizes = [self.cfg.encoder_embed_dim] * self.num_layers
54 |
55 | self.model.feature_grad_mult = 0.0
56 | self.model.encoder.layerdrop = 0.0
57 |
58 | self.feat_select = 0
59 | if feat_select == "att":
60 | self.feat_select = 1
61 |
62 | logger.info(f"Feature selection: {feat_select}")
63 | logger.info(
64 | f"Randomize all = {randomize_all} (randomize layers = {randomize_layers})"
65 | )
66 | logger.info(f"Freeze all = {freeze_all} (freeze layers = {freeze_layers})")
67 |
68 | self.randomize_all = randomize_all
69 | self.randomize_layers = randomize_layers
70 | self.freeze_all = freeze_all
71 | self.freeze_layers = freeze_layers
72 |
73 | self.freeze_cnn = (0 in freeze_layers) or self.freeze_all
74 | self.freeze_pos = ("pos" in freeze_layers) or self.freeze_all
75 |
76 | # Randomize weights
77 | if randomize_all:
78 | randomize_layers = ["pos"] + list(range(self.num_layers))
79 | if len(randomize_layers) > 0:
80 | for i in randomize_layers:
81 | if i == 0:
82 | self.model.feature_extractor.apply(init_module_cnn)
83 | self.model.layer_norm.reset_parameters()
84 | if self.model.post_extract_proj is not None:
85 | self.model.post_extract_proj.reset_parameters()
86 | elif i == "pos":
87 | self.model.encoder.pos_conv.apply(init_module_pos_conv)
88 | else:
89 | self.model.encoder.layers[i - 1].apply(init_module_bert)
90 | if i == self.num_layers - 1 and self.model.encoder.layer_norm_first:
91 | self.model.encoder.layer_norm.reset_parameters()
92 |
93 | # Freeze weights
94 | if freeze_all:
95 | freeze_module(self.model)
96 | elif len(freeze_layers) > 0:
97 | for i in freeze_layers:
98 | if i == 0:
99 | self.model.feature_grad_mult = 0.0
100 | freeze_module(self.model.feature_extractor)
101 | freeze_module(self.model.layer_norm)
102 | if self.model.post_extract_proj is not None:
103 | freeze_module(self.model.post_extract_proj)
104 | elif i == "pos":
105 | freeze_module(self.model.encoder.pos_conv)
106 | else:
107 | assert isinstance(i, int), i
108 | freeze_module(self.model.encoder.layers[i - 1])
109 |
110 | if not self.freeze_cnn:
111 | self.model.feature_grad_mult = 1.0
112 |
113 | def trainable_parameters(self):
114 | params = []
115 |
116 | if self.freeze_all:
117 | return []
118 |
119 | if not self.freeze_all and len(self.freeze_layers) == 0:
120 | logger.info("Trains the entire model")
121 | return self.model.parameters()
122 |
123 | params = []
124 | for i in ["pos"] + list(range(self.num_layers)):
125 | if i in self.freeze_layers:
126 | continue
127 | if i == 0:
128 | params += list(self.model.feature_extractor.parameters())
129 | params += list(self.model.layer_norm.parameters())
130 | if self.model.post_extract_proj is not None:
131 | params += list(self.model.post_extract_proj.parameters())
132 | elif i == "pos":
133 | params += list(self.model.encoder.pos_conv.parameters())
134 | else:
135 | params += list(self.model.encoder.layers[i - 1].parameters())
136 | if i == self.num_layers - 1 and self.model.encoder.layer_norm_first:
137 | params += list(self.model.encoder.layer_norm.parameters())
138 |
139 | return params
140 |
141 | def forward_features(self, x: torch.Tensor) -> torch.Tensor:
142 | if self.model.feature_grad_mult > 0:
143 | features = self.model.feature_extractor(x)
144 |             if self.model.feature_grad_mult != 1.0:
145 | features = GradMultiply.apply(features, self.model.feature_grad_mult)
146 | else:
147 | with torch.no_grad():
148 | features = self.model.feature_extractor(x)
149 |
150 | return features
151 |
152 | def forward(
153 | self,
154 | wavs: torch.FloatTensor,
155 | padding_mask: torch.BoolTensor,
156 | ):
157 | if self.pre_normalize:
158 | with torch.no_grad():
159 | wavs = (wavs - wavs[~padding_mask].mean()) / (
160 | wavs[~padding_mask].std() + 1e-5
161 | )
162 |
163 | if self.wav_normalize:
164 | with torch.no_grad():
165 | wav_len = padding_to_len(padding_mask)
166 | for i in range(len(wavs)):
167 | wavs[i, : wav_len[i]] = F.layer_norm(
168 | wavs[i, : wav_len[i]], (wav_len[i],)
169 | )
170 |
171 | features = self.forward_features(wavs).transpose(1, 2)
172 | if self.freeze_cnn:
173 | with torch.no_grad():
174 | features = self.model.layer_norm(features)
175 | if self.model.post_extract_proj is not None:
176 | features = self.model.post_extract_proj(features)
177 | else:
178 | features = self.model.layer_norm(features)
179 | if self.model.post_extract_proj is not None:
180 | features = self.model.post_extract_proj(features)
181 |
182 | if padding_mask is not None:
183 | padding_mask = self.model.forward_padding_mask(features, padding_mask)
184 |
185 | x = self.model.dropout_input(features)
186 |
187 | # feature: (B, T, D), float
188 | # x: (B, T, D), float
189 | # padding_mask: (B, T), bool
190 |
191 | if self.freeze_all:
192 | with torch.no_grad():
193 | _, layer_results = self.model.encoder(
194 | x,
195 | padding_mask=padding_mask,
196 | )
197 | else:
198 | _, layer_results = self.model.encoder(
199 | x,
200 | padding_mask=padding_mask,
201 | freeze_pos=self.freeze_pos,
202 | freeze_layers=self.freeze_layers,
203 | )
204 |
205 | feat_list = [features] + [
206 | feat[self.feat_select].transpose(0, 1) for feat in layer_results
207 | ]
208 |
209 | if self.normalize:
210 | feat_list = [F.layer_norm(feat, feat.shape[-1:]) for feat in feat_list]
211 |
212 | feat_len = padding_to_len(padding_mask)
213 |
214 | return feat_list, feat_len, padding_mask
215 |
--------------------------------------------------------------------------------
/src/task/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_spin import SpinPretrainTask
2 |
--------------------------------------------------------------------------------
/src/task/train_spin.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import yaml
5 | from pytorch_lightning import Trainer, seed_everything
6 | from pytorch_lightning.callbacks import (
7 | LearningRateMonitor,
8 | ModelCheckpoint,
9 | TQDMProgressBar,
10 | )
11 | from torch.utils.data import DataLoader
12 |
13 | from src.data import AudioPretrainPnmiValDataset, val_collate_fn
14 | from src.model import SpinModel
15 | from src.util import set_logging, set_pl_logger
16 |
17 |
18 | class SpinPretrainTask:
19 | def __init__(self):
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("task", help="Task name")
22 | parser.add_argument("--config", "-c", help="Config .yaml file")
23 | parser.add_argument("--save-path", "-s", help="Path to save exp")
24 | parser.add_argument("--resume", "-r", default="", help="Resume training")
25 | parser.add_argument("--gpus", "-g", type=int, default=1, help="Number of GPUs")
26 | parser.add_argument(
27 | "--njobs", "-j", type=int, default=8, help="Number of workers"
28 | )
29 | parser.add_argument("--seed", type=int, default=7122, help="Random seed")
30 | parser.add_argument("--log-level", default="info", help="Logging level")
31 | args = parser.parse_args()
32 |
33 | if not torch.cuda.is_available():
34 | args.device = "cpu"
35 | args.gpus = 0
36 | else:
37 | args.device = "cuda" if args.gpus > 0 else "cpu"
38 |
39 | self.args = args
40 | set_logging(args.log_level)
41 |
42 | def run(self, model_cls=SpinModel):
43 | assert isinstance(self.args, argparse.Namespace)
44 |
45 | config = yaml.load(open(self.args.config, "r"), Loader=yaml.FullLoader)
46 | self.config = config
47 |
48 | use_ddp = (
49 | config["trainer"].get("strategy", "").startswith("ddp")
50 | and self.args.gpus > 1
51 | )
52 |
53 |         if self.args.save_path:  # skip the override when -s is not given
54 | config["trainer"]["default_root_dir"] = self.args.save_path
55 |
56 | model_checkpoint = ModelCheckpoint(
57 | dirpath=config["trainer"]["default_root_dir"], **config["checkpoint"]
58 | )
59 |
60 | config["trainer"]["logger"] = set_pl_logger(
61 | config["trainer"]["logger"],
62 | config["logger"]["project"],
63 | config["trainer"]["default_root_dir"].split("/")[-1],
64 | )
65 |
66 | trainer = Trainer(
67 | callbacks=[
68 | TQDMProgressBar(),
69 | model_checkpoint,
70 | LearningRateMonitor("step"),
71 | ],
72 | enable_progress_bar=True,
73 | devices=self.args.gpus,
74 | check_val_every_n_epoch=None,
75 | use_distributed_sampler=False,
76 | sync_batchnorm=use_ddp,
77 | **config["trainer"],
78 | )
79 |
80 | seed_everything(self.args.seed)
81 |
82 | if config.get("val_data", None) is not None:
83 | val_dataset = AudioPretrainPnmiValDataset(**config["val_data"])
84 | val_loader = DataLoader(
85 | val_dataset,
86 | batch_size=config["hparam"]["val_batch_size"],
87 | num_workers=self.args.njobs,
88 | pin_memory=True,
89 | collate_fn=val_collate_fn,
90 | shuffle=False,
91 | drop_last=False,
92 | )
93 | else:
94 | val_dataset = None
95 | val_loader = None
96 |
97 | if self.args.resume != "":
98 | model = model_cls.load_from_checkpoint(self.args.resume)
99 | else:
100 | self.args.resume = None
101 | model = model_cls(config, 2)
102 |
103 | model.set_random_seed(self.args.seed)
104 | model.set_njobs(self.args.njobs)
105 | model.set_use_ddp(use_ddp)
106 |
107 | trainer.fit(model, val_dataloaders=val_loader, ckpt_path=self.args.resume)
108 |
--------------------------------------------------------------------------------
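A minimal sketch of how this task is typically launched from the repository root. The wrapper below is hypothetical (run_task.py itself is not shown here); only SpinPretrainTask().run() and the CLI flags follow from the argument parser above.

    # hypothetical run_task.py-style entry point
    from src.task import SpinPretrainTask

    if __name__ == "__main__":
        # e.g. python run_task.py spin -c config/spin.yaml -s exp/spin_base -g 1 -j 8
        SpinPretrainTask().run()
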
/src/util/__init__.py:
--------------------------------------------------------------------------------
1 | from .log import set_logging, set_pl_logger
2 | from .model_utils import (
3 | count_parameters,
4 | freeze_module,
5 | init_module,
6 | init_module_bert,
7 | init_module_cnn,
8 | init_module_pos_conv,
9 | unfreeze_module,
10 | )
11 | from .padding import (
12 | add_front_padding_mask,
13 | len_to_padding,
14 | padding_to_len,
15 | update_padding_mask,
16 | )
17 | from .pnmi import compute_show_pnmi, compute_snmi
18 | from .scheduler import get_scheduler
19 |
--------------------------------------------------------------------------------
/src/util/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Union
3 |
4 | from pytorch_lightning.loggers import WandbLogger
5 |
6 |
7 | def set_logging(log_level: str = "info") -> None:
8 | level = getattr(logging, str(log_level).upper())
9 | logging.basicConfig(
10 | level=level,
11 | format="%(asctime)s %(filename)s.%(funcName)s %(message)s",
12 | datefmt="%m-%d %H:%M",
13 | )
14 |
15 |
16 | def set_pl_logger(
17 | logger_type: Union[bool, str],
18 | project: str = "speech_disentangle",
19 | name: str = "example",
20 | ):
21 | if isinstance(logger_type, bool):
22 | return logger_type
23 | elif logger_type == "wandb":
24 | logger = WandbLogger(project=project, name=name)
25 | return logger
26 | else:
27 | raise NotImplementedError(f"Unknown logger type = {logger_type}")
28 |
--------------------------------------------------------------------------------
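A short usage sketch for the logging helpers above; the project and run names are made-up examples, not values used by the repository.

    from src.util import set_logging, set_pl_logger

    set_logging("info")                      # configure the root logger format and level
    logger = set_pl_logger(True)             # booleans are passed through unchanged to the Trainer
    # wandb_logger = set_pl_logger("wandb", project="spin", name="demo-run")  # requires wandb
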
/src/util/model_utils.py:
--------------------------------------------------------------------------------
1 | from s3prl.upstream.wav2vec2.wav2vec2_model import MultiheadAttention
2 | from torch import nn
3 |
4 |
5 | def freeze_module(m: nn.Module) -> None:
6 | for p in m.parameters():
7 | p.requires_grad = False
8 |
9 |
10 | def unfreeze_module(m: nn.Module) -> None:
11 | for p in m.parameters():
12 | p.requires_grad = True
13 |
14 |
15 | def init_module(m: nn.Module):
16 | for p in m.parameters():
17 | nn.init.normal_(p, mean=0, std=0.02)
18 |
19 |
20 | def init_module_bert(m: nn.Module):
21 | def normal_(data):
22 | # with FSDP, module params will be on CUDA, so we cast them back to CPU
23 | # so that the RNG is consistent with and without FSDP
24 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
25 |
26 | if isinstance(m, nn.Linear):
27 | normal_(m.weight.data)
28 | if m.bias is not None:
29 | m.bias.data.zero_()
30 | if isinstance(m, nn.Embedding):
31 | normal_(m.weight.data)
32 | if m.padding_idx is not None:
33 | m.weight.data[m.padding_idx].zero_()
34 | if isinstance(m, MultiheadAttention):
35 | normal_(m.q_proj.weight.data)
36 | normal_(m.k_proj.weight.data)
37 | normal_(m.v_proj.weight.data)
38 |
39 |
40 | def init_module_cnn(m: nn.Module):
41 | if isinstance(m, nn.Conv1d):
42 | nn.init.kaiming_normal_(m.weight)
43 | if isinstance(m, nn.LayerNorm):
44 | m.reset_parameters()
45 |
46 |
47 | def init_module_pos_conv(m: nn.Module):
48 | if isinstance(m, nn.Conv1d):
49 | m.reset_parameters()
50 | if isinstance(m, nn.LayerNorm):
51 | m.reset_parameters()
52 |
53 |
54 | def count_parameters(model: nn.Module) -> int:
55 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
56 |
--------------------------------------------------------------------------------
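A toy example of the freezing and counting helpers above, using a hypothetical stand-in module rather than anything from the repository:

    from torch import nn
    from src.util import count_parameters, freeze_module, unfreeze_module

    mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    print(count_parameters(mlp))   # all parameters trainable initially
    freeze_module(mlp)
    print(count_parameters(mlp))   # 0: requires_grad cleared on every parameter
    unfreeze_module(mlp)           # restores trainability
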
/src/util/padding.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | @torch.no_grad()
5 | def len_to_padding(x_len: torch.LongTensor, max_len: int = 0) -> torch.BoolTensor:
6 | if max_len == 0:
7 | max_len = max(x_len)
8 | idxs = torch.arange(max_len, dtype=torch.long).to(x_len.device)
9 | padding_mask = idxs.unsqueeze(0) >= x_len.unsqueeze(1)
10 | return padding_mask
11 |
12 |
13 | @torch.no_grad()
14 | def padding_to_len(padding_mask: torch.BoolTensor) -> torch.LongTensor:
15 | x_len = (~padding_mask).long().sum(-1)
16 | return x_len
17 |
18 |
19 | @torch.no_grad()
20 | def update_padding_mask(
21 | padding_mask: torch.BoolTensor, new_len: int
22 | ) -> torch.BoolTensor:
23 | extra = padding_mask.shape[1] % new_len
24 | if extra > 0:
25 | padding_mask = padding_mask[:, :-extra]
26 | padding_mask = padding_mask.view(padding_mask.shape[0], new_len, -1)
27 | padding_mask = padding_mask.all(-1)
28 | return padding_mask
29 |
30 |
31 | @torch.no_grad()
32 | def add_front_padding_mask(
33 | padding_mask: torch.BoolTensor, pad_front_lens: torch.LongTensor
34 | ) -> None:
35 | for i in range(len(padding_mask)):
36 | if pad_front_lens[i] > 0:
37 | padding_mask[i, : pad_front_lens[i]] = True
38 |
--------------------------------------------------------------------------------
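A small worked example of the padding utilities above (toy lengths, not repository data):

    import torch
    from src.util import len_to_padding, padding_to_len, update_padding_mask

    x_len = torch.tensor([5, 3])
    mask = len_to_padding(x_len)            # shape (2, 5); True marks padded frames
    # mask[1] == [False, False, False, True, True]
    assert torch.equal(padding_to_len(mask), x_len)

    # shrink the mask to 2 frames per utterance (e.g. after temporal downsampling);
    # a merged frame stays padded only if every frame it covers was padded
    short_mask = update_padding_mask(mask, 2)   # shape (2, 2)
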
/src/util/pnmi.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from collections import Counter
7 |
8 | import numpy as np
9 | from tabulate import tabulate
10 |
11 |
12 | def comp_purity(p_xy, axis):
13 | max_p = p_xy.max(axis=axis)
14 | marg_p = p_xy.sum(axis=axis)
15 | indv_pur = max_p / marg_p
16 | aggr_pur = max_p.sum()
17 | return indv_pur, aggr_pur
18 |
19 |
20 | def comp_entropy(p):
21 | return (-p * np.log(p + 1e-8)).sum()
22 |
23 |
24 | def comp_norm_mutual_info(p_xy):
25 | p_x = p_xy.sum(axis=1, keepdims=True)
26 | p_y = p_xy.sum(axis=0, keepdims=True)
27 | pmi = np.log(p_xy / np.matmul(p_x, p_y) + 1e-8)
28 | mi = (p_xy * pmi).sum()
29 | h_x = comp_entropy(p_x)
30 | h_y = comp_entropy(p_y)
31 | return mi, mi / h_x, mi / h_y, h_x, h_y
32 |
33 |
34 | def pad(labs, n):
35 | if n == 0:
36 | return np.array(labs)
37 | return np.concatenate([[labs[0]] * n, labs, [labs[-1]] * n])
38 |
39 |
40 | def comp_avg_seg_dur(labs_list):
41 | n_frms = 0
42 | n_segs = 0
43 | for labs in labs_list:
44 | labs = np.array(labs)
45 | edges = np.zeros(len(labs)).astype(bool)
46 | edges[0] = True
47 | edges[1:] = labs[1:] != labs[:-1]
48 | n_frms += len(edges)
49 | n_segs += edges.astype(int).sum()
50 | return n_frms / n_segs
51 |
52 |
53 | def comp_joint_prob(uid2refs, uid2hyps):
54 | cnts = Counter()
55 | skipped = []
56 | abs_frmdiff = 0
57 | for uid in uid2refs:
58 | if uid not in uid2hyps:
59 | skipped.append(uid)
60 | continue
61 | refs = uid2refs[uid]
62 | hyps = uid2hyps[uid]
63 | abs_frmdiff += abs(len(refs) - len(hyps))
64 | min_len = min(len(refs), len(hyps))
65 | refs = refs[:min_len]
66 | hyps = hyps[:min_len]
67 | cnts.update(zip(refs, hyps))
68 | tot = sum(cnts.values())
69 |
70 | ref_set = sorted({ref for ref, _ in cnts.keys()})
71 | hyp_set = sorted({hyp for _, hyp in cnts.keys()})
72 | ref2pid = dict(zip(ref_set, range(len(ref_set))))
73 | hyp2lid = dict(zip(hyp_set, range(len(hyp_set))))
74 |
75 | p_xy = np.zeros((len(ref2pid), len(hyp2lid)), dtype=float)
76 | for (ref, hyp), cnt in cnts.items():
77 | p_xy[ref2pid[ref], hyp2lid[hyp]] = cnt
78 | freq_xy = p_xy
79 | full_freq_xy = np.zeros((len(ref2pid), 4096), dtype=float)
80 | for (ref, hyp), cnt in cnts.items():
81 | full_freq_xy[ref2pid[ref], int(hyp)] = cnt
82 | p_xy = p_xy / p_xy.sum()
83 | return (
84 | freq_xy,
85 | full_freq_xy,
86 | p_xy,
87 | ref2pid,
88 | hyp2lid,
89 | tot,
90 | abs_frmdiff,
91 | skipped,
92 | ref_set,
93 | hyp_set,
94 | )
95 |
96 |
97 | def comp_phone2code(p_xy):
98 | p_x = p_xy.sum(axis=1, keepdims=True) # ref (phone)
99 | p_y = p_xy.sum(axis=0, keepdims=True) # hyp (code)
100 |
101 | p_x_y = p_xy / p_y # P(x | y) = P(phone | code)
102 |
103 | y_order = np.argsort(p_x_y.argmax(0))
104 | p_x_y_sorted_y = np.take_along_axis(p_x_y, y_order.reshape((1, -1)), axis=1)
105 |
106 | x_order = np.argsort(p_x[:, 0])
107 | x_order = np.flip(x_order)
108 | p_x_y_sorted_x = np.take_along_axis(p_x_y, x_order.reshape((-1, 1)), axis=0)
109 | y_order = np.argsort(p_x_y_sorted_x.argmax(0))
110 | p_x_y_sorted_xy = np.take_along_axis(
111 | p_x_y_sorted_x, y_order.reshape((1, -1)), axis=1
112 | )
113 |
114 | return p_x_y, p_x_y_sorted_xy, p_x_y_sorted_y, x_order
115 |
116 |
117 | def compute_show_pnmi(uid2refs, uid2hyps, upsample=1, show_results: bool = False):
118 | for k, v in uid2hyps.items():
119 | uid2hyps[k] = pad(v, 0).repeat(upsample)
120 |
121 | (
122 | freq_xy,
123 | full_freq_xy,
124 | p_xy,
125 | ref2pid,
126 | hyp2lid,
127 | tot,
128 | frmdiff,
129 | skipped,
130 | ref_set,
131 | hyp_set,
132 | ) = comp_joint_prob(uid2refs, uid2hyps)
133 | ref_pur_by_hyp, ref_pur = comp_purity(p_xy, axis=0)
134 | hyp_pur_by_ref, hyp_pur = comp_purity(p_xy, axis=1)
135 | (mi, mi_norm_by_ref, mi_norm_by_hyp, h_ref, h_hyp) = comp_norm_mutual_info(p_xy)
136 |
137 | if show_results:
138 | print(
139 | tabulate(
140 | [[hyp_pur, ref_pur, mi_norm_by_ref]],
141 | ["Cls Pur", "Phn Pur", "PNMI"],
142 | floatfmt=".3f",
143 | tablefmt="fancy_grid",
144 | )
145 | )
146 |
147 | return {
148 | "cls_pur": hyp_pur,
149 | "phn_pur": ref_pur,
150 | "pnmi": mi_norm_by_ref,
151 | }
152 |
153 |
154 | def compute_snmi(p_xy):
155 | _, ref_pur = comp_purity(p_xy, axis=0)
156 | _, hyp_pur = comp_purity(p_xy, axis=1)
157 | (_, mi_norm_by_ref, _, _, _) = comp_norm_mutual_info(p_xy)
158 |
159 | return {
160 | "cls_pur": hyp_pur,
161 | "spk_pur": ref_pur,
162 | "snmi": mi_norm_by_ref,
163 | }
164 |
--------------------------------------------------------------------------------
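A toy example of the PNMI computation above; the frame-level labels are fabricated for illustration (references are phone labels, hypotheses are cluster IDs):

    import numpy as np
    from src.util import compute_show_pnmi

    uid2refs = {"utt1": np.array(["a", "a", "b", "b"])}
    uid2hyps = {"utt1": np.array([3, 3, 7, 7])}
    scores = compute_show_pnmi(uid2refs, uid2hyps, show_results=True)
    # scores == {"cls_pur": ..., "phn_pur": ..., "pnmi": ...}; this perfectly aligned toy case gives purities of 1.0
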
/src/util/scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from torch.optim import Optimizer
4 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
5 |
6 |
7 | def get_lr(optimizer: Optimizer) -> float:
8 | for param_group in optimizer.param_groups:
9 | return param_group["lr"]
10 |
11 |
12 | def noam_scheduler(
13 | optimizer: Optimizer, warmup: int = 4000, last_epoch: int = -1
14 | ) -> _LRScheduler:
15 | def func(step: int):
16 | if step < warmup:
17 | return (step + 1) / warmup
18 | else:
19 | return (warmup / (step + 1)) ** 0.5
20 |
21 | return LambdaLR(optimizer, func, last_epoch)
22 |
23 |
24 | def linear_warmup_decay_scheduler(
25 | optimizer: Optimizer,
26 | warmup: int = 4000,
27 | max_step: int = 1000000,
28 | init_lr: float = 1e-6,
29 | final_lr: float = 1e-6,
30 | ) -> _LRScheduler:
31 | func_list = []
32 |
33 | for param_group in optimizer.param_groups:
34 | base_lr = param_group["lr"]
35 | rate_i = init_lr / base_lr
36 | rate_f = final_lr / base_lr
37 |
38 |         def func(step: int, rate_i: float = rate_i, rate_f: float = rate_f) -> float:  # bind per-group rates (avoid late-binding closure)
39 | if step <= warmup:
40 | return rate_i + (1.0 - rate_i) * step / warmup
41 | else:
42 | return 1.0 - (1.0 - rate_f) * (step - warmup) / (max_step - warmup - 1)
43 |
44 | func_list.append(func)
45 |
46 | return LambdaLR(optimizer, func_list)
47 |
48 |
49 | def linear_warmup_cosine_scheduler(
50 | optimizer: Optimizer,
51 | warmup: int = 4000,
52 | max_step: int = 1000000,
53 | final_lr: float = 1e-6,
54 | ) -> _LRScheduler:
55 | func_list = []
56 |
57 | for param_group in optimizer.param_groups:
58 | base_lr = param_group["lr"]
59 | rate = final_lr / base_lr
60 |
61 |         def func(step: int, rate: float = rate) -> float:  # bind per-group rate (avoid late-binding closure)
62 | if step < warmup:
63 | return (step + 1) / warmup
64 | else:
65 | q = 0.5 * (
66 | 1 + math.cos(math.pi * (step + 1 - warmup) / (max_step - warmup))
67 | )
68 | return (1.0 - rate) * q + rate
69 |
70 | func_list.append(func)
71 |
72 | return LambdaLR(optimizer, func_list)
73 |
74 |
75 | def get_scheduler(name: str, optimizer: Optimizer, **kwargs) -> _LRScheduler:
76 | if name == "noam":
77 | return noam_scheduler(optimizer, **kwargs)
78 | elif name == "linear_warmup_decay":
79 | return linear_warmup_decay_scheduler(optimizer, **kwargs)
80 | elif name == "linear_warmup_cosine":
81 | return linear_warmup_cosine_scheduler(optimizer, **kwargs)
82 | else:
83 | raise NotImplementedError(f"Unknown lr scheduler {name}")
84 |
--------------------------------------------------------------------------------
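A minimal sketch of driving these schedulers through get_scheduler, with a toy model and hypothetical hyperparameters:

    import torch
    from src.util import get_scheduler

    model = torch.nn.Linear(8, 8)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    sched = get_scheduler("linear_warmup_decay", opt, warmup=100, max_step=1000)

    for step in range(1000):
        opt.step()    # (loss.backward() would precede this in real training)
        sched.step()  # LR ramps from init_lr to 1e-3 over 100 steps, then decays linearly toward final_lr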