├── .gitignore ├── LICENSE ├── README.md ├── ci └── format.py ├── config └── spin.yaml ├── figure └── spin.png ├── prepare_data.py ├── requirements.txt ├── run_task.py ├── s3prl_py ├── WavLM.py ├── spin │ ├── __init__.py │ ├── expert.py │ └── hubconf.py └── wav2vec2_model.py ├── script ├── prepare.sh └── train.sh └── src ├── data ├── __init__.py ├── audio.py ├── dataset.py ├── librispeech.py └── sampler.py ├── model ├── __init__.py ├── base.py └── spin.py ├── nn ├── __init__.py ├── dnn.py ├── hubert.py ├── swav_vq_dis.py └── wavlm.py ├── task ├── __init__.py └── train_spin.py └── util ├── __init__.py ├── log.py ├── model_utils.py ├── padding.py ├── pnmi.py └── scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | _data/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Heng-Jui Chang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Speaker-invariant Clustering (Spin) 2 | 3 | - [Introduction](#Introduction) 4 | - [Citation](#Citation) 5 | - [Getting Started](#Getting-Started) 6 | - [Pre-trained Models](#Pre-trained-Models) 7 | - [References](#References) 8 | - [Contact](#Contact) 9 | 14 | 15 | ## Introduction 16 | 17 |

![Spin framework.](figure/spin.png)

18 | 19 | This repository is the official PyTorch implementation of the **Speaker-invariant Clustering** (**Spin**) proposed in the **Interspeech 2023** paper [Self-supervised Fine-tuning for Improved Content Representations by Speaker-invariant Clustering](https://arxiv.org/abs/2305.11072) ([Heng-Jui Chang](https://people.csail.mit.edu/hengjui/), [Alexander H. Liu](https://alexander-h-liu.github.io/), [James Glass](https://www.csail.mit.edu/person/jim-glass); [MIT CSAIL](https://www.csail.mit.edu/)). 20 | 21 | Spin is a novel self-supervised learning method that clusters speech representations and performs swapped prediction between the original and speaker-perturbed utterances. Spin *disentangles speaker information* and preserves *content representations* with just 45 minutes of fine-tuning on a single GPU (HuBERT Base models). Spin improves pre-trained networks and outperforms prior methods in speech recognition and acoustic unit discovery. 22 | 23 | 24 | ## Citation 25 | Please cite our paper if you find this repository and/or the paper useful. 26 | ``` 27 | @inproceedings{chang2023spin, 28 | author={Heng-Jui Chang and Alexander H. Liu and James Glass}, 29 | title={{Self-supervised Fine-tuning for Improved Content Representations by Speaker-invariant Clustering}}, 30 | year=2023, 31 | booktitle={Proc. Interspeech} 32 | } 33 | ``` 34 | 35 | 36 | ## Getting Started 37 | 38 | ### 1. Environment 39 | Make sure `sox` is installed and your Python version is at least `3.6`. 40 | ```bash 41 | # Create virtual environment 42 | conda create --name spin python=3.8 43 | conda activate spin 44 | 45 | # Install s3prl 46 | git clone https://github.com/s3prl/s3prl.git 47 | cd s3prl 48 | pip install -e ".[all]" 49 | cd .. 50 | 51 | # Clone this repository and install dependencies 52 | git clone https://github.com/vectominist/spin.git 53 | cd spin/ 54 | pip install -r requirements.txt 55 | 56 | # Modify some s3prl files 57 | cp s3prl_py/wav2vec2_model.py ../s3prl/s3prl/upstream/wav2vec2/wav2vec2_model.py 58 | cp s3prl_py/WavLM.py ../s3prl/s3prl/upstream/wavlm/WavLM.py 59 | ``` 60 | 61 | 62 | ### 2. Prepare Data 63 | Download the required data. 64 | ```bash 65 | # Create a directory to save data (or any other path you like) 66 | mkdir data 67 | cd data 68 | 69 | # LibriSpeech (skip if you already have this) 70 | wget https://www.openslr.org/resources/12/train-clean-100.tar.gz 71 | wget https://www.openslr.org/resources/12/dev-clean.tar.gz 72 | wget https://www.openslr.org/resources/12/dev-other.tar.gz 73 | # Decompress 74 | tar zxvf train-clean-100.tar.gz 75 | tar zxvf dev-clean.tar.gz 76 | tar zxvf dev-other.tar.gz 77 | rm train-clean-100.tar.gz dev-clean.tar.gz dev-other.tar.gz 78 | 79 | # LibriSpeech Phoneme Alignments (for monitoring progress only) 80 | wget https://huggingface.co/datasets/vectominist/spin_data/resolve/main/dev-clean.tsv 81 | wget https://huggingface.co/datasets/vectominist/spin_data/resolve/main/dev-other.tsv 82 | 83 | # Speaker Information 84 | # Source: https://github.com/auspicious3000/contentvec 85 | wget https://huggingface.co/datasets/vectominist/spin_data/resolve/main/spk2info.dict 86 | ``` 87 | 88 | Prepare the LibriSpeech dataset; see [`script/prepare.sh`](https://github.com/vectominist/spin/blob/main/script/prepare.sh). 89 | - `libri_dir`: the directory of the LibriSpeech corpus 90 | - `json_dir`: the directory to save the `.json` files generated by `prepare_data.py` 91 | ```bash 92 | bash script/prepare.sh ${libri_dir} ${json_dir} 93 | ``` 94 | 95 | ### 3. 
Customize Configurations 96 | See [`config/spin.yaml`](https://github.com/vectominist/spin/blob/main/config/spin.yaml). 97 | - Modify `json_dir`, `spk2info`, and `phn_dir` so that they point to the directories containing the downloaded and preprocessed data. 98 | - Modify `logger` to switch to another logger, or simply set it to `False` to disable logging. 99 | ```yaml 100 | data: 101 | json_dir: /path/to/json_dir 102 | spk2info: /path/to/spk2info.dict 103 | 104 | val_data: 105 | json_dir: /path/to/json_dir 106 | phn_dir: /path/to/phoneme/alignments/dir 107 | 108 | trainer: 109 | logger: wandb # specify a pytorch-lightning logger you prefer 110 | ``` 111 | 112 | ### 4. Training 113 | See [`script/train.sh`](https://github.com/vectominist/spin/blob/main/script/train.sh). 114 | - `exp_name`: experiment name 115 | - `exp_dir`: the directory to save checkpoints 116 | - See [`src/task/train_spin.py`](https://github.com/vectominist/spin/blob/main/src/task/train_spin.py) for details on the available arguments, such as the number of GPUs to use. 117 | ```bash 118 | bash script/train.sh ${exp_name} ${exp_dir} 119 | ``` 120 | The trained model checkpoints can be found in `${exp_dir}/${exp_name}`. Note that we use `last.ckpt` for evaluation and downstream tasks. 121 | 122 | ### 5. Downstream Evaluation 123 | We use the [s3prl](https://github.com/s3prl/s3prl) toolkit for [SUPERB](https://arxiv.org/abs/2105.01051) downstream tasks. 124 | - Modify [line 26](https://github.com/vectominist/spin/blob/main/s3prl_py/spin/expert.py#L26) of [`s3prl_py/spin/expert.py`](https://github.com/vectominist/spin/blob/main/s3prl_py/spin/expert.py) so that it points to the absolute path of `spin/`. 125 | - Copy the `s3prl_py/spin` directory to `s3prl` so that the toolkit can load the models. 126 | ```bash 127 | cp -R s3prl_py/spin ../s3prl/s3prl/upstream/spin 128 | ``` 129 | - Finally, add the following line to `../s3prl/s3prl/hub.py`: 130 | ```python 131 | from s3prl.upstream.spin.hubconf import * 132 | ``` 133 | 134 | 135 | ## Pre-trained Models 136 | All models are trained on a single NVIDIA A5000 GPU with 24GB of VRAM. To reproduce similar or better performance, we suggest using GPUs with more than 24GB of memory or specifying `strategy: ddp` under `trainer` in [`config/spin.yaml`](https://github.com/vectominist/spin/blob/main/config/spin.yaml) to enable multi-GPU training. Note that the following checkpoints were reproduced with the same recipe, so the results differ slightly from those in our paper. The training logs can be found at this [link](https://api.wandb.ai/links/vectominist/5254la3b). 
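After completing step 5 above, the checkpoints below can be loaded directly through the s3prl hub. Here is a minimal sketch (the entry names such as `spin_hubert_256` follow [`s3prl_py/spin/hubconf.py`](https://github.com/vectominist/spin/blob/main/s3prl_py/spin/hubconf.py); the random waveforms are placeholders for real 16 kHz audio):

```python
import torch
import s3prl.hub as hub

# Prerequisite: step 5 above (copy `s3prl_py/spin` into s3prl, register the hubconf
# in `s3prl/hub.py`, and point line 26 of `expert.py` to this repository).
model = getattr(hub, "spin_hubert_256")()  # downloads the corresponding checkpoint below
model.eval()

# One 16 kHz waveform tensor per utterance (dummy data for illustration only).
wavs = [torch.randn(16000), torch.randn(24000)]
with torch.no_grad():
    outputs = model(wavs)

print(len(outputs["hidden_states"]))       # per-layer representations
print(outputs["last_hidden_state"].shape)  # (batch, frames, hidden dim)
```

By default the expert returns the encoder hidden states; the other features defined in [`s3prl_py/spin/expert.py`](https://github.com/vectominist/spin/blob/main/s3prl_py/spin/expert.py) (e.g., `disentangled`, `logits`, or `probs`) can be selected by passing `feat_select` to the hub entry.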
137 | 138 | | Base Model | Clusters | PNMI | Checkpoint | 139 | | ---------- | -------- | ----- | ------------------------------------------------------------------------------------------------ | 140 | | HuBERT | 128 | 0.625 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_128.ckpt) | 141 | | HuBERT | 256 | 0.658 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_256.ckpt) | 142 | | HuBERT | 512 | 0.707 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_512.ckpt) | 143 | | HuBERT | 1024 | 0.745 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_1024.ckpt) | 144 | | HuBERT | 2048 | 0.774 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_2048.ckpt) | 145 | | WavLM | 128 | 0.604 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_128.ckpt) | 146 | | WavLM | 256 | 0.658 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_256.ckpt) | 147 | | WavLM | 512 | 0.714 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_512.ckpt) | 148 | | WavLM | 1024 | 0.748 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_1024.ckpt) | 149 | | WavLM | 2048 | 0.775 | [link](https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_2048.ckpt) | 150 | 151 | 152 | ## References 153 | - [s3prl](https://github.com/s3prl/s3prl) 154 | - [contentvec](https://github.com/auspicious3000/contentvec) 155 | - [fairseq](https://github.com/facebookresearch/fairseq) 156 | - [MiniASR](https://github.com/vectominist/MiniASR) 157 | - [PyTorch](https://pytorch.org/) 158 | - [PyTorch Lightning](https://lightning.ai/docs/pytorch/latest/) 159 | 160 | 161 | ## Contact 162 | If you have any questions, please open an issue or send me an email hengjui@mit.edu. 163 | -------------------------------------------------------------------------------- /ci/format.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # ref: https://github.com/s3prl/s3prl/blob/main/ci/format.py 3 | 4 | import argparse 5 | from subprocess import CalledProcessError, check_output 6 | 7 | 8 | def get_third_party(): 9 | package_list = [] 10 | with open("./requirements.txt", "r") as fp: 11 | for line in fp: 12 | line = line.strip() 13 | if line == "": 14 | continue 15 | package_list.append(line.split(" ")[0]) 16 | return package_list 17 | 18 | 19 | def run_command(command: str): 20 | try: 21 | check_output(command.split(" ")) 22 | except CalledProcessError as e: 23 | print(e.output.decode("utf-8")) 24 | raise 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | "files", 31 | type=str, 32 | nargs="*", 33 | default=["src"], 34 | ) 35 | parser.add_argument("--check", action="store_true", help="Only checks the files") 36 | args = parser.parse_args() 37 | 38 | print(f"Formatting files: {args.files}") 39 | args.files = " ".join(args.files) 40 | 41 | print("Run flake8") 42 | # stop the build if there are Python syntax errors or undefined names 43 | run_command( 44 | f"flake8 {args.files} --count --select=E9,F63,F7,F82 --show-source --statistics" 45 | ) 46 | # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide 47 | run_command( 48 | f"flake8 {args.files} --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics" 49 | ) 50 | 51 | print("Run black") 52 | if args.check: 53 | run_command(f"black --check {args.files}") 54 | else: 55 | run_command(f"black {args.files}") 56 | 57 | print("Run isort") 58 | third_party = get_third_party() 59 | third_party = ",".join(third_party) 60 | if args.check: 61 | run_command( 62 | f"isort --profile black --thirdparty {third_party} --check {args.files}" 63 | ) 64 | else: 65 | run_command(f"isort --profile black --thirdparty {third_party} {args.files}") 66 | 67 | if args.check: 68 | print("Successfully passed the format check!") 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | -------------------------------------------------------------------------------- /config/spin.yaml: -------------------------------------------------------------------------------- 1 | # Interspeech 2023 version 2 | 3 | # Training data 4 | data: 5 | json_dir: /data/sls/r/u/hengjui/home/scratch/spin_test/data 6 | splits: 7 | - train-clean-100 8 | sample_rate: 16000 9 | min_audio_len: 40000 # minimum audio samples per utterance 10 | random_crop_len: 272000 # maximum audio samples per utterance 11 | spk2info: /data/sls/r/u/hengjui/home/scratch/dataset/libri_util/spk2info.dict 12 | 13 | # Validation data (not used for checkpointing, just for monitoring training progress) 14 | val_data: 15 | json_dir: /data/sls/r/u/hengjui/home/scratch/spin_test/data 16 | phn_dir: /data/sls/r/u/hengjui/home/scratch/spin_test/data 17 | splits: 18 | - dev-clean 19 | - dev-other 20 | sample_rate: 16000 21 | 22 | # SpinModel config 23 | model: 24 | encoder: 25 | type: HuBERT # `HuBERT` / `WavLM` 26 | use_layer: 12 # the layer which its representations are used for clustering 27 | normalize: False 28 | feat_select: x 29 | randomize_all: False 30 | randomize_layers: [] 31 | freeze_all: False 32 | freeze_layers: ["pos", 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # `pos`: positional encoding, `0`: CNN extractor 33 | pred_head: 34 | type: DNN 35 | hid_dims: [256] 36 | dropout: 0 37 | activation: ReLU 38 | loss: 39 | type: SwavVQDisentangle 40 | num_vars: 256 # cluster size 41 | epsilon: 0.02 42 | sinkhorn_iters: 3 43 | temp: 0.1 44 | l2_norm: True 45 | prob_ratio: 1.0 46 | 47 | # Optimization 48 | optim: 49 | optimizer: 50 | name: Adam 51 | args: 52 | lr: 1.e-4 53 | weight_decay: 1.e-6 54 | scheduler: 55 | name: linear_warmup_decay # `linear_warmup_decay` / `linear_warmup_cosine_scheduler` / `noam_scheduler` 56 | args: 57 | warmup: 2500 58 | max_step: 5000 59 | final_lr: 1.e-6 60 | 61 | hparam: 62 | batch_len: 4096000 # audio samples per GPU (256 secs ~ batch_size = 12.8k) 63 | val_batch_size: 8 64 | 65 | # pytorch_lightning.Trainer 66 | # ref: https://lightning.ai/docs/pytorch/latest/common/trainer.html 67 | trainer: 68 | max_steps: 5000 69 | gradient_clip_val: 10 70 | accumulate_grad_batches: 1 71 | precision: 16 72 | logger: wandb # use `False` to disable logging 73 | log_every_n_steps: 100 74 | default_root_dir: exp/tmp 75 | accelerator: gpu 76 | # strategy: ddp # uncomment this line to enable DDP training 77 | num_sanity_val_steps: 0 78 | val_check_interval: 1000 79 | 80 | # pytorch_lightning.callbacks.ModelCheckpoint 81 | # ref: https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.ModelCheckpoint.html 82 | checkpoint: 83 | filename: "{epoch}-{step}" 84 | every_n_train_steps: 5000 85 | save_last: true 86 | 87 | # 
pytorch_lightning.loggers.WandbLogger 88 | # ref: https://lightning.ai/docs/pytorch/latest/extensions/generated/lightning.pytorch.loggers.WandbLogger.html 89 | logger: 90 | project: spin_is2023 91 | -------------------------------------------------------------------------------- /figure/spin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vectominist/spin/73edb7ae120be67a2d584931eeb388caecda744e/figure/spin.png -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import sys 5 | 6 | from src.data.librispeech import find_all_librispeech, save_data_info 7 | 8 | logging.basicConfig( 9 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 10 | datefmt="%Y-%m-%d %H:%M:%S", 11 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 12 | stream=sys.stdout, 13 | ) 14 | logger = logging.getLogger("prepare_data") 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("root", type=str, help="Root directory of LibriSpeech") 20 | parser.add_argument("json_dir", type=str, help="Directory to save .json files") 21 | parser.add_argument( 22 | "--split", 23 | "-s", 24 | type=str, 25 | nargs="+", 26 | default=[ 27 | "train-clean-100", 28 | "train-clean-360", 29 | "train-other-500", 30 | "dev-clean", 31 | "dev-other", 32 | "test-clean", 33 | "test-other", 34 | ], 35 | help="LibriSpeech partitions to be processed", 36 | ) 37 | parser.add_argument( 38 | "--sort-by-len", "-l", action="store_true", help="Sort audio files by length" 39 | ) 40 | args = parser.parse_args() 41 | 42 | logger.info(f"Preparing data from LibriSpeech at {args.root}") 43 | logger.info(f"Splits: {args.split}") 44 | logger.info(f"Sort audio files by length = {args.sort_by_len}") 45 | for s in args.split: 46 | logger.info(f"Processing {s} split...") 47 | data = find_all_librispeech(os.path.join(args.root, s), args.sort_by_len) 48 | save_data_info(data, os.path.join(args.json_dir, s + ".json")) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | praat-parselmouth==0.4.3 2 | -------------------------------------------------------------------------------- /run_task.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from src import task 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("task") 9 | args, _ = parser.parse_known_args() 10 | 11 | runner = getattr(task, args.task)() 12 | runner.run() 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /s3prl_py/WavLM.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/wavlm 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # 
-------------------------------------------------------- 9 | 10 | import logging 11 | import math 12 | from typing import List, Optional, Tuple 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from torch.nn import LayerNorm 19 | 20 | from .modules import ( 21 | Fp32GroupNorm, 22 | Fp32LayerNorm, 23 | GLU_Linear, 24 | GradMultiply, 25 | MultiheadAttention, 26 | SamePad, 27 | TransposeLast, 28 | get_activation_fn, 29 | init_bert_params, 30 | ) 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def compute_mask_indices( 36 | shape: Tuple[int, int], 37 | padding_mask: Optional[torch.Tensor], 38 | mask_prob: float, 39 | mask_length: int, 40 | mask_type: str = "static", 41 | mask_other: float = 0.0, 42 | min_masks: int = 0, 43 | no_overlap: bool = False, 44 | min_space: int = 0, 45 | ) -> np.ndarray: 46 | """ 47 | Computes random mask spans for a given shape 48 | 49 | Args: 50 | shape: the the shape for which to compute masks. 51 | should be of size 2 where first element is batch size and 2nd is timesteps 52 | padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements 53 | mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by 54 | number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 55 | however due to overlaps, the actual number will be smaller (unless no_overlap is True) 56 | mask_type: how to compute mask lengths 57 | static = fixed size 58 | uniform = sample from uniform distribution [mask_other, mask_length*2] 59 | normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element 60 | poisson = sample from possion distribution with lambda = mask length 61 | min_masks: minimum number of masked spans 62 | no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping 63 | min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans 64 | """ 65 | 66 | bsz, all_sz = shape 67 | mask = np.full((bsz, all_sz), False) 68 | 69 | all_num_mask = int( 70 | # add a random number for probabilistic rounding 71 | mask_prob * all_sz / float(mask_length) 72 | + np.random.rand() 73 | ) 74 | 75 | all_num_mask = max(min_masks, all_num_mask) 76 | 77 | mask_idcs = [] 78 | for i in range(bsz): 79 | if padding_mask is not None: 80 | sz = all_sz - padding_mask[i].long().sum().item() 81 | num_mask = int( 82 | # add a random number for probabilistic rounding 83 | mask_prob * sz / float(mask_length) 84 | + np.random.rand() 85 | ) 86 | num_mask = max(min_masks, num_mask) 87 | else: 88 | sz = all_sz 89 | num_mask = all_num_mask 90 | 91 | if mask_type == "static": 92 | lengths = np.full(num_mask, mask_length) 93 | elif mask_type == "uniform": 94 | lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) 95 | elif mask_type == "normal": 96 | lengths = np.random.normal(mask_length, mask_other, size=num_mask) 97 | lengths = [max(1, int(round(x))) for x in lengths] 98 | elif mask_type == "poisson": 99 | lengths = np.random.poisson(mask_length, size=num_mask) 100 | lengths = [int(round(x)) for x in lengths] 101 | else: 102 | raise Exception("unknown mask selection " + mask_type) 103 | 104 | if sum(lengths) == 0: 105 | lengths[0] = min(mask_length, sz - 1) 106 | 107 | if no_overlap: 108 | mask_idc = [] 109 | 110 | def arrange(s, e, length, keep_length): 
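# arrange() marks one span of `length` masked frames starting at a random position in [s, e)
# and returns the leftover sub-intervals that are still large enough (given keep_length and
# min_space) to host further spans.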
111 | span_start = np.random.randint(s, e - length) 112 | mask_idc.extend(span_start + i for i in range(length)) 113 | 114 | new_parts = [] 115 | if span_start - s - min_space >= keep_length: 116 | new_parts.append((s, span_start - min_space + 1)) 117 | if e - span_start - keep_length - min_space > keep_length: 118 | new_parts.append((span_start + length + min_space, e)) 119 | return new_parts 120 | 121 | parts = [(0, sz)] 122 | min_length = min(lengths) 123 | for length in sorted(lengths, reverse=True): 124 | lens = np.fromiter( 125 | (e - s if e - s >= length + min_space else 0 for s, e in parts), 126 | np.int, 127 | ) 128 | l_sum = np.sum(lens) 129 | if l_sum == 0: 130 | break 131 | probs = lens / np.sum(lens) 132 | c = np.random.choice(len(parts), p=probs) 133 | s, e = parts.pop(c) 134 | parts.extend(arrange(s, e, length, min_length)) 135 | mask_idc = np.asarray(mask_idc) 136 | else: 137 | min_len = min(lengths) 138 | if sz - min_len <= num_mask: 139 | min_len = sz - num_mask - 1 140 | 141 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 142 | 143 | mask_idc = np.asarray( 144 | [ 145 | mask_idc[j] + offset 146 | for j in range(len(mask_idc)) 147 | for offset in range(lengths[j]) 148 | ] 149 | ) 150 | 151 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 152 | 153 | min_len = min([len(m) for m in mask_idcs]) 154 | for i, mask_idc in enumerate(mask_idcs): 155 | if len(mask_idc) > min_len: 156 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 157 | mask[i, mask_idc] = True 158 | 159 | return mask 160 | 161 | 162 | class WavLMConfig: 163 | def __init__(self, cfg=None): 164 | self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) 165 | self.encoder_layers: int = 12 # num encoder layers in the transformer 166 | 167 | self.encoder_embed_dim: int = 768 # encoder embedding dimension 168 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN 169 | self.encoder_attention_heads: int = 12 # num encoder attention heads 170 | self.activation_fn: str = "gelu" # activation function to use 171 | 172 | self.layer_norm_first: bool = False # apply layernorm first in the transformer 173 | self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] 
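# With the default layers above, the extractor stacks 7 convolutions with strides 5, 2, 2, 2, 2, 2, 2,
# i.e. one 512-dim frame every 320 samples (20 ms at 16 kHz).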
174 | self.conv_bias: bool = False # include bias in conv encoder 175 | self.feature_grad_mult: float = ( 176 | 1.0 # multiply feature extractor var grads by this 177 | ) 178 | 179 | self.normalize: bool = ( 180 | False # normalize input to have 0 mean and unit variance during training 181 | ) 182 | 183 | # dropouts 184 | self.dropout: float = 0.1 # dropout probability for the transformer 185 | self.attention_dropout: float = 0.1 # dropout probability for attention weights 186 | self.activation_dropout: float = ( 187 | 0.0 # dropout probability after activation in FFN 188 | ) 189 | self.encoder_layerdrop: float = ( 190 | 0.0 # probability of dropping a tarnsformer layer 191 | ) 192 | self.dropout_input: float = ( 193 | 0.0 # dropout to apply to the input (after feat extr) 194 | ) 195 | self.dropout_features: float = ( 196 | 0.0 # dropout to apply to the features (after feat extr) 197 | ) 198 | 199 | # masking 200 | self.mask_length: int = 10 # mask length 201 | self.mask_prob: float = 0.65 # probability of replacing a token with mask 202 | self.mask_selection: str = "static" # how to choose mask length 203 | self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh 204 | self.no_mask_overlap: bool = False # whether to allow masks to overlap 205 | self.mask_min_space: int = ( 206 | 1 # min space between spans (if no overlap is enabled) 207 | ) 208 | 209 | # channel masking 210 | self.mask_channel_length: int = 10 # length of the mask for features (channels) 211 | self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 212 | self.mask_channel_selection: str = ( 213 | "static" # how to choose mask length for channel masking 214 | ) 215 | self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices 216 | self.no_mask_channel_overlap: bool = ( 217 | False # whether to allow channel masks to overlap 218 | ) 219 | self.mask_channel_min_space: int = ( 220 | 1 # min space between spans (if no overlap is enabled) 221 | ) 222 | 223 | # positional embeddings 224 | self.conv_pos: int = ( 225 | 128 # number of filters for convolutional positional embeddings 226 | ) 227 | self.conv_pos_groups: int = ( 228 | 16 # number of groups for convolutional positional embedding 229 | ) 230 | 231 | # relative position embedding 232 | self.relative_position_embedding: bool = ( 233 | False # apply relative position embedding 234 | ) 235 | self.num_buckets: int = 320 # number of buckets for relative position embedding 236 | self.max_distance: int = ( 237 | 1280 # maximum distance for relative position embedding 238 | ) 239 | self.gru_rel_pos: bool = False # apply gated relative position embedding 240 | 241 | if cfg is not None: 242 | self.update(cfg) 243 | 244 | def update(self, cfg: dict): 245 | self.__dict__.update(cfg) 246 | 247 | 248 | class WavLM(nn.Module): 249 | def __init__( 250 | self, 251 | cfg: WavLMConfig, 252 | ) -> None: 253 | super().__init__() 254 | logger.info(f"WavLM Config: {cfg.__dict__}") 255 | 256 | self.cfg = cfg 257 | feature_enc_layers = eval(cfg.conv_feature_layers) 258 | self.embed = feature_enc_layers[-1][0] 259 | 260 | self.feature_extractor = ConvFeatureExtractionModel( 261 | conv_layers=feature_enc_layers, 262 | dropout=0.0, 263 | mode=cfg.extractor_mode, 264 | conv_bias=cfg.conv_bias, 265 | ) 266 | 267 | self.post_extract_proj = ( 268 | nn.Linear(self.embed, cfg.encoder_embed_dim) 269 | if self.embed != cfg.encoder_embed_dim 
270 | else None 271 | ) 272 | 273 | self.mask_prob = cfg.mask_prob 274 | self.mask_selection = cfg.mask_selection 275 | self.mask_other = cfg.mask_other 276 | self.mask_length = cfg.mask_length 277 | self.no_mask_overlap = cfg.no_mask_overlap 278 | self.mask_min_space = cfg.mask_min_space 279 | 280 | self.mask_channel_prob = cfg.mask_channel_prob 281 | self.mask_channel_selection = cfg.mask_channel_selection 282 | self.mask_channel_other = cfg.mask_channel_other 283 | self.mask_channel_length = cfg.mask_channel_length 284 | self.no_mask_channel_overlap = cfg.no_mask_channel_overlap 285 | self.mask_channel_min_space = cfg.mask_channel_min_space 286 | 287 | self.dropout_input = nn.Dropout(cfg.dropout_input) 288 | self.dropout_features = nn.Dropout(cfg.dropout_features) 289 | 290 | self.feature_grad_mult = cfg.feature_grad_mult 291 | 292 | self.mask_emb = nn.Parameter( 293 | torch.FloatTensor(cfg.encoder_embed_dim).uniform_() 294 | ) 295 | 296 | self.encoder = TransformerEncoder(cfg) 297 | self.layer_norm = LayerNorm(self.embed) 298 | 299 | def apply_mask(self, x, padding_mask): 300 | B, T, C = x.shape 301 | if self.mask_prob > 0: 302 | mask_indices = compute_mask_indices( 303 | (B, T), 304 | padding_mask, 305 | self.mask_prob, 306 | self.mask_length, 307 | self.mask_selection, 308 | self.mask_other, 309 | min_masks=2, 310 | no_overlap=self.no_mask_overlap, 311 | min_space=self.mask_min_space, 312 | ) 313 | mask_indices = torch.from_numpy(mask_indices).to(x.device) 314 | x[mask_indices] = self.mask_emb 315 | else: 316 | mask_indices = None 317 | 318 | if self.mask_channel_prob > 0: 319 | mask_channel_indices = compute_mask_indices( 320 | (B, C), 321 | None, 322 | self.mask_channel_prob, 323 | self.mask_channel_length, 324 | self.mask_channel_selection, 325 | self.mask_channel_other, 326 | no_overlap=self.no_mask_channel_overlap, 327 | min_space=self.mask_channel_min_space, 328 | ) 329 | mask_channel_indices = ( 330 | torch.from_numpy(mask_channel_indices) 331 | .to(x.device) 332 | .unsqueeze(1) 333 | .expand(-1, T, -1) 334 | ) 335 | x[mask_channel_indices] = 0 336 | 337 | return x, mask_indices 338 | 339 | def forward_padding_mask( 340 | self, 341 | features: torch.Tensor, 342 | padding_mask: torch.Tensor, 343 | ) -> torch.Tensor: 344 | extra = padding_mask.size(1) % features.size(1) 345 | if extra > 0: 346 | padding_mask = padding_mask[:, :-extra] 347 | padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) 348 | padding_mask = padding_mask.all(-1) 349 | return padding_mask 350 | 351 | def extract_features( 352 | self, 353 | source: torch.Tensor, 354 | padding_mask: Optional[torch.Tensor] = None, 355 | mask: bool = False, 356 | ret_conv: bool = False, 357 | output_layer: Optional[int] = None, 358 | ret_layer_results: bool = False, 359 | ): 360 | 361 | if self.feature_grad_mult > 0: 362 | features = self.feature_extractor(source) 363 | if self.feature_grad_mult != 1.0: 364 | features = GradMultiply.apply(features, self.feature_grad_mult) 365 | else: 366 | with torch.no_grad(): 367 | features = self.feature_extractor(source) 368 | 369 | features = features.transpose(1, 2) 370 | features = self.layer_norm(features) 371 | 372 | if padding_mask is not None: 373 | padding_mask = self.forward_padding_mask(features, padding_mask) 374 | 375 | if self.post_extract_proj is not None: 376 | features = self.post_extract_proj(features) 377 | 378 | features = self.dropout_input(features) 379 | 380 | if mask: 381 | x, mask_indices = self.apply_mask(features, padding_mask) 382 | else: 
383 | x = features 384 | 385 | # feature: (B, T, D), float 386 | # target: (B, T), long 387 | # x: (B, T, D), float 388 | # padding_mask: (B, T), bool 389 | # mask_indices: (B, T), bool 390 | x, layer_results = self.encoder( 391 | x, 392 | padding_mask=padding_mask, 393 | layer=None if output_layer is None else output_layer - 1, 394 | ) 395 | 396 | res = { 397 | "x": x, 398 | "padding_mask": padding_mask, 399 | "features": features, 400 | "layer_results": layer_results, 401 | } 402 | 403 | feature = res["features"] if ret_conv else res["x"] 404 | if ret_layer_results: 405 | feature = (feature, res["layer_results"]) 406 | return feature, res["padding_mask"] 407 | 408 | 409 | class ConvFeatureExtractionModel(nn.Module): 410 | def __init__( 411 | self, 412 | conv_layers: List[Tuple[int, int, int]], 413 | dropout: float = 0.0, 414 | mode: str = "default", 415 | conv_bias: bool = False, 416 | conv_type: str = "default", 417 | ): 418 | super().__init__() 419 | 420 | assert mode in {"default", "layer_norm"} 421 | 422 | def block( 423 | n_in, 424 | n_out, 425 | k, 426 | stride, 427 | is_layer_norm=False, 428 | is_group_norm=False, 429 | conv_bias=False, 430 | ): 431 | def make_conv(): 432 | conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) 433 | nn.init.kaiming_normal_(conv.weight) 434 | return conv 435 | 436 | assert ( 437 | is_layer_norm and is_group_norm 438 | ) == False, "layer norm and group norm are exclusive" 439 | 440 | if is_layer_norm: 441 | return nn.Sequential( 442 | make_conv(), 443 | nn.Dropout(p=dropout), 444 | nn.Sequential( 445 | TransposeLast(), 446 | Fp32LayerNorm(dim, elementwise_affine=True), 447 | TransposeLast(), 448 | ), 449 | nn.GELU(), 450 | ) 451 | elif is_group_norm: 452 | return nn.Sequential( 453 | make_conv(), 454 | nn.Dropout(p=dropout), 455 | Fp32GroupNorm(dim, dim, affine=True), 456 | nn.GELU(), 457 | ) 458 | else: 459 | return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) 460 | 461 | self.conv_type = conv_type 462 | if self.conv_type == "default": 463 | in_d = 1 464 | self.conv_layers = nn.ModuleList() 465 | for i, cl in enumerate(conv_layers): 466 | assert len(cl) == 3, "invalid conv definition: " + str(cl) 467 | (dim, k, stride) = cl 468 | 469 | self.conv_layers.append( 470 | block( 471 | in_d, 472 | dim, 473 | k, 474 | stride, 475 | is_layer_norm=mode == "layer_norm", 476 | is_group_norm=mode == "default" and i == 0, 477 | conv_bias=conv_bias, 478 | ) 479 | ) 480 | in_d = dim 481 | elif self.conv_type == "conv2d": 482 | in_d = 1 483 | self.conv_layers = nn.ModuleList() 484 | for i, cl in enumerate(conv_layers): 485 | assert len(cl) == 3 486 | (dim, k, stride) = cl 487 | 488 | self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride)) 489 | self.conv_layers.append(torch.nn.ReLU()) 490 | in_d = dim 491 | elif self.conv_type == "custom": 492 | in_d = 1 493 | idim = 80 494 | self.conv_layers = nn.ModuleList() 495 | for i, cl in enumerate(conv_layers): 496 | assert len(cl) == 3 497 | (dim, k, stride) = cl 498 | self.conv_layers.append( 499 | torch.nn.Conv2d(in_d, dim, k, stride, padding=1) 500 | ) 501 | self.conv_layers.append(torch.nn.LayerNorm([dim, idim])) 502 | self.conv_layers.append(torch.nn.ReLU()) 503 | in_d = dim 504 | if (i + 1) % 2 == 0: 505 | self.conv_layers.append( 506 | torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) 507 | ) 508 | idim = int(math.ceil(idim / 2)) 509 | else: 510 | pass 511 | 512 | def forward(self, x, mask=None): 513 | 514 | # BxT -> BxCxT 515 | x = x.unsqueeze(1) 516 | if self.conv_type == "custom": 
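# "custom" mode expects 2-D inputs (idim starts at 80): Conv2d + LayerNorm + ReLU blocks with
# max-pooling after every second layer; the output is flattened to (B, C * F, T) below.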
517 | for conv in self.conv_layers: 518 | if isinstance(conv, nn.LayerNorm): 519 | x = x.transpose(1, 2) 520 | x = conv(x).transpose(1, 2) 521 | else: 522 | x = conv(x) 523 | x = x.transpose(2, 3).contiguous() 524 | x = x.view(x.size(0), -1, x.size(-1)) 525 | else: 526 | for conv in self.conv_layers: 527 | x = conv(x) 528 | if self.conv_type == "conv2d": 529 | b, c, t, f = x.size() 530 | x = x.transpose(2, 3).contiguous().view(b, c * f, t) 531 | return x 532 | 533 | 534 | class TransformerEncoder(nn.Module): 535 | def __init__(self, args): 536 | super().__init__() 537 | 538 | self.dropout = args.dropout 539 | self.embedding_dim = args.encoder_embed_dim 540 | 541 | self.pos_conv = nn.Conv1d( 542 | self.embedding_dim, 543 | self.embedding_dim, 544 | kernel_size=args.conv_pos, 545 | padding=args.conv_pos // 2, 546 | groups=args.conv_pos_groups, 547 | ) 548 | dropout = 0 549 | std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) 550 | nn.init.normal_(self.pos_conv.weight, mean=0, std=std) 551 | nn.init.constant_(self.pos_conv.bias, 0) 552 | 553 | self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) 554 | self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) 555 | 556 | if hasattr(args, "relative_position_embedding"): 557 | self.relative_position_embedding = args.relative_position_embedding 558 | self.num_buckets = args.num_buckets 559 | self.max_distance = args.max_distance 560 | else: 561 | self.relative_position_embedding = False 562 | self.num_buckets = 0 563 | self.max_distance = 0 564 | 565 | self.layers = nn.ModuleList( 566 | [ 567 | TransformerSentenceEncoderLayer( 568 | embedding_dim=self.embedding_dim, 569 | ffn_embedding_dim=args.encoder_ffn_embed_dim, 570 | num_attention_heads=args.encoder_attention_heads, 571 | dropout=self.dropout, 572 | attention_dropout=args.attention_dropout, 573 | activation_dropout=args.activation_dropout, 574 | activation_fn=args.activation_fn, 575 | layer_norm_first=args.layer_norm_first, 576 | has_relative_attention_bias=( 577 | self.relative_position_embedding and i == 0 578 | ), 579 | num_buckets=self.num_buckets, 580 | max_distance=self.max_distance, 581 | gru_rel_pos=args.gru_rel_pos, 582 | ) 583 | for i in range(args.encoder_layers) 584 | ] 585 | ) 586 | 587 | self.layer_norm_first = args.layer_norm_first 588 | self.layer_norm = LayerNorm(self.embedding_dim) 589 | self.layerdrop = args.encoder_layerdrop 590 | 591 | self.apply(init_bert_params) 592 | 593 | def forward( 594 | self, 595 | x, 596 | padding_mask=None, 597 | streaming_mask=None, 598 | layer=None, 599 | ffn_adapters=None, 600 | freeze_pos=False, 601 | freeze_layers=None, 602 | ): 603 | x, layer_results = self.extract_features( 604 | x, padding_mask, streaming_mask, layer, ffn_adapters, freeze_pos=freeze_pos, freeze_layers=freeze_layers 605 | ) 606 | 607 | if self.layer_norm_first and layer is None: 608 | x = self.layer_norm(x) 609 | 610 | return x, layer_results 611 | 612 | def extract_features( 613 | self, 614 | x, 615 | padding_mask=None, 616 | streaming_mask=None, 617 | tgt_layer=None, 618 | ffn_adapters=None, 619 | freeze_pos=False, 620 | freeze_layers=None, 621 | ): 622 | 623 | if padding_mask is not None: 624 | x[padding_mask] = 0 625 | 626 | if freeze_pos: 627 | with torch.no_grad(): 628 | x_conv = self.pos_conv(x.transpose(1, 2)) 629 | x_conv = x_conv.transpose(1, 2) 630 | x = x + x_conv 631 | else: 632 | x_conv = self.pos_conv(x.transpose(1, 2)) 633 | x_conv = x_conv.transpose(1, 2) 634 | x = x + x_conv 635 | 
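# Post-LN models normalize right after adding the convolutional positional embedding;
# pre-LN models (layer_norm_first) apply self.layer_norm in forward() after extract_features() returns.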
636 | if not self.layer_norm_first: 637 | x = self.layer_norm(x) 638 | 639 | x = F.dropout(x, p=self.dropout, training=self.training) 640 | 641 | # B x T x C -> T x B x C 642 | x = x.transpose(0, 1) 643 | 644 | layer_results = [] 645 | z = None 646 | # if tgt_layer is not None: 647 | # layer_results.append((x, z)) 648 | r = None 649 | pos_bias = None 650 | for i, layer in enumerate(self.layers): 651 | dropout_probability = np.random.random() 652 | if not self.training or (dropout_probability > self.layerdrop): 653 | ffn_adapter = None 654 | if ffn_adapters is not None: 655 | ffn_adapter = ffn_adapters[i] 656 | 657 | if isinstance(freeze_layers, list) and i + 1 in freeze_layers: 658 | with torch.no_grad(): 659 | x, z, pos_bias = layer( 660 | x, 661 | self_attn_padding_mask=padding_mask, 662 | need_weights=False, 663 | self_attn_mask=streaming_mask, 664 | pos_bias=pos_bias, 665 | ffn_adapter=ffn_adapter, 666 | ) 667 | else: 668 | x, z, pos_bias = layer( 669 | x, 670 | self_attn_padding_mask=padding_mask, 671 | need_weights=False, 672 | self_attn_mask=streaming_mask, 673 | pos_bias=pos_bias, 674 | ffn_adapter=ffn_adapter, 675 | ) 676 | # if tgt_layer is not None: 677 | # layer_results.append((x, z)) 678 | layer_results.append((x, z)) 679 | if i == tgt_layer: 680 | r = x 681 | break 682 | 683 | if r is not None: 684 | x = r 685 | 686 | # T x B x C -> B x T x C 687 | x = x.transpose(0, 1) 688 | 689 | return x, layer_results 690 | 691 | 692 | class TransformerSentenceEncoderLayer(nn.Module): 693 | """ 694 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained 695 | models. 696 | """ 697 | 698 | def __init__( 699 | self, 700 | embedding_dim: float = 768, 701 | ffn_embedding_dim: float = 3072, 702 | num_attention_heads: float = 8, 703 | dropout: float = 0.1, 704 | attention_dropout: float = 0.1, 705 | activation_dropout: float = 0.1, 706 | activation_fn: str = "relu", 707 | layer_norm_first: bool = False, 708 | has_relative_attention_bias: bool = False, 709 | num_buckets: int = 0, 710 | max_distance: int = 0, 711 | rescale_init: bool = False, 712 | gru_rel_pos: bool = False, 713 | ) -> None: 714 | 715 | super().__init__() 716 | # Initialize parameters 717 | self.embedding_dim = embedding_dim 718 | self.dropout = dropout 719 | self.activation_dropout = activation_dropout 720 | 721 | # Initialize blocks 722 | self.activation_name = activation_fn 723 | self.activation_fn = get_activation_fn(activation_fn) 724 | self.self_attn = MultiheadAttention( 725 | self.embedding_dim, 726 | num_attention_heads, 727 | dropout=attention_dropout, 728 | self_attention=True, 729 | has_relative_attention_bias=has_relative_attention_bias, 730 | num_buckets=num_buckets, 731 | max_distance=max_distance, 732 | rescale_init=rescale_init, 733 | gru_rel_pos=gru_rel_pos, 734 | ) 735 | 736 | self.dropout1 = nn.Dropout(dropout) 737 | self.dropout2 = nn.Dropout(self.activation_dropout) 738 | self.dropout3 = nn.Dropout(dropout) 739 | 740 | self.layer_norm_first = layer_norm_first 741 | 742 | # layer norm associated with the self attention layer 743 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim) 744 | 745 | if self.activation_name == "glu": 746 | self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") 747 | else: 748 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) 749 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) 750 | 751 | # layer norm associated with the position wise feed-forward NN 752 | self.final_layer_norm = LayerNorm(self.embedding_dim) 753 | 
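# forward() below applies self-attention followed by the position-wise FFN, using pre-LayerNorm
# ordering when layer_norm_first is True and post-LayerNorm ordering otherwise; an optional
# ffn_adapter is applied to the FFN output (after fc2) in both branches.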
754 | def forward( 755 | self, 756 | x: torch.Tensor, 757 | self_attn_mask: torch.Tensor = None, 758 | self_attn_padding_mask: torch.Tensor = None, 759 | need_weights: bool = False, 760 | pos_bias=None, 761 | ffn_adapter=None, 762 | ): 763 | """ 764 | LayerNorm is applied either before or after the self-attention/ffn 765 | modules similar to the original Transformer imlementation. 766 | """ 767 | residual = x 768 | 769 | if self.layer_norm_first: 770 | x = self.self_attn_layer_norm(x) 771 | x, attn, pos_bias = self.self_attn( 772 | query=x, 773 | key=x, 774 | value=x, 775 | key_padding_mask=self_attn_padding_mask, 776 | need_weights=False, 777 | attn_mask=self_attn_mask, 778 | position_bias=pos_bias, 779 | ) 780 | x = self.dropout1(x) 781 | x = residual + x 782 | 783 | residual = x 784 | x = self.final_layer_norm(x) 785 | if self.activation_name == "glu": 786 | x = self.fc1(x) 787 | else: 788 | x = self.activation_fn(self.fc1(x)) 789 | x = self.dropout2(x) 790 | x = self.fc2(x) 791 | 792 | if ffn_adapter is not None: 793 | x = ffn_adapter(x) 794 | 795 | x = self.dropout3(x) 796 | x = residual + x 797 | else: 798 | x, attn, pos_bias = self.self_attn( 799 | query=x, 800 | key=x, 801 | value=x, 802 | key_padding_mask=self_attn_padding_mask, 803 | need_weights=need_weights, 804 | attn_mask=self_attn_mask, 805 | position_bias=pos_bias, 806 | ) 807 | 808 | x = self.dropout1(x) 809 | x = residual + x 810 | 811 | x = self.self_attn_layer_norm(x) 812 | 813 | residual = x 814 | if self.activation_name == "glu": 815 | x = self.fc1(x) 816 | else: 817 | x = self.activation_fn(self.fc1(x)) 818 | x = self.dropout2(x) 819 | x = self.fc2(x) 820 | 821 | if ffn_adapter is not None: 822 | x = ffn_adapter(x) 823 | 824 | x = self.dropout3(x) 825 | x = residual + x 826 | x = self.final_layer_norm(x) 827 | 828 | return x, attn, pos_bias 829 | -------------------------------------------------------------------------------- /s3prl_py/spin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vectominist/spin/73edb7ae120be67a2d584931eeb388caecda744e/s3prl_py/spin/__init__.py -------------------------------------------------------------------------------- /s3prl_py/spin/expert.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn.utils.rnn import pad_sequence 7 | 8 | from ..interfaces import UpstreamBase 9 | 10 | SAMPLE_RATE = 16000 11 | EXAMPLE_SEC = 5 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class UpstreamExpert(UpstreamBase): 17 | def __init__( 18 | self, 19 | ckpt, 20 | feat_select: str = "hidden", 21 | **kwargs, 22 | ): 23 | super().__init__(**kwargs) 24 | 25 | # set directory of `spin/` 26 | sys.path.append("/path/to/spin") 27 | from src.model import SpinModel 28 | 29 | self.model = SpinModel.load_from_checkpoint(ckpt, strict=False) 30 | if self.model.encoder_type == "HuBERT": 31 | self.model.encoder.model.feature_grad_mult = 0.0 32 | self.feat_select = feat_select 33 | 34 | assert self.feat_select in { 35 | "all_feats", 36 | "hidden", 37 | "disentangled", 38 | "logits", 39 | "probs", 40 | "last", 41 | }, self.feat_select 42 | 43 | if self.feat_select in {"logits", "probs"}: 44 | assert self.model.loss_type in {"SwavVQDisentangle"} 45 | self.model.loss_module.normalize_codebook() 46 | 47 | logger.info(f"Feature selection: {self.feat_select}") 48 | logger.info(f"Loss type: 
{self.model.loss_type}") 49 | 50 | def get_downsample_rates(self, key: str) -> int: 51 | return self.model.encoder_rate 52 | 53 | def forward(self, wavs): 54 | device = wavs[0].device 55 | wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device) 56 | wav_padding_mask = ~torch.lt( 57 | torch.arange(max(wav_lengths)).unsqueeze(0).to(device), 58 | wav_lengths.unsqueeze(1), 59 | ) 60 | padded_wav = pad_sequence(wavs, batch_first=True) 61 | 62 | results = self.model( 63 | (padded_wav, wav_lengths, wav_padding_mask), 64 | feat_only=True, 65 | ) 66 | 67 | outputs = {} 68 | 69 | if self.model.loss_type in {"SwavVQDisentangle"}: 70 | outputs["code"] = results["codes"] 71 | outputs["logits"] = results["logits"] 72 | outputs["probs"] = F.softmax(results["logits"], dim=-1) 73 | feat = results["repr_list"] 74 | feat_list = results["feat_list"] 75 | 76 | outputs["disentangled"] = feat 77 | 78 | if self.feat_select == "all_feats": 79 | outputs["hidden_states"] = [feat] + feat_list 80 | elif self.feat_select == "hidden": 81 | outputs["hidden_states"] = feat_list 82 | elif self.feat_select == "last": 83 | outputs["hidden_states"] = [feat_list[-1]] 84 | elif self.feat_select == "disentangled": 85 | outputs["hidden_states"] = [feat] 86 | elif self.feat_select == "logits": 87 | outputs["hidden_states"] = [results["logits"]] 88 | elif self.feat_select == "probs": 89 | outputs["hidden_states"] = [outputs["probs"]] 90 | 91 | outputs["last_hidden_state"] = outputs["hidden_states"][-1] 92 | 93 | return outputs 94 | -------------------------------------------------------------------------------- /s3prl_py/spin/hubconf.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from s3prl.util.download import _urls_to_filepaths 4 | 5 | from .expert import UpstreamExpert as _UpstreamExpert 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def spin_custom( 11 | ckpt: str, 12 | refresh: bool = False, 13 | **kwargs, 14 | ): 15 | if ckpt.startswith("http"): 16 | ckpt = _urls_to_filepaths(ckpt, refresh=refresh) 17 | 18 | return _UpstreamExpert(ckpt, **kwargs) 19 | 20 | 21 | def spin_local(*args, **kwargs): 22 | return spin_custom(*args, **kwargs) 23 | 24 | 25 | def spin_url(*args, **kwargs): 26 | return spin_custom(*args, **kwargs) 27 | 28 | 29 | def spin_hubert_128(refresh=False, **kwargs): 30 | kwargs[ 31 | "ckpt" 32 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_128.ckpt" 33 | return spin_custom(refresh=refresh, **kwargs) 34 | 35 | 36 | def spin_hubert_256(refresh=False, **kwargs): 37 | kwargs[ 38 | "ckpt" 39 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_256.ckpt" 40 | return spin_custom(refresh=refresh, **kwargs) 41 | 42 | 43 | def spin_hubert_512(refresh=False, **kwargs): 44 | kwargs[ 45 | "ckpt" 46 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_512.ckpt" 47 | return spin_custom(refresh=refresh, **kwargs) 48 | 49 | 50 | def spin_hubert_1024(refresh=False, **kwargs): 51 | kwargs[ 52 | "ckpt" 53 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_1024.ckpt" 54 | return spin_custom(refresh=refresh, **kwargs) 55 | 56 | 57 | def spin_hubert_2048(refresh=False, **kwargs): 58 | kwargs[ 59 | "ckpt" 60 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_hubert_2048.ckpt" 61 | return spin_custom(refresh=refresh, **kwargs) 62 | 63 | 64 | def spin_wavlm_128(refresh=False, 
**kwargs): 65 | kwargs[ 66 | "ckpt" 67 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_128.ckpt" 68 | return spin_custom(refresh=refresh, **kwargs) 69 | 70 | 71 | def spin_wavlm_256(refresh=False, **kwargs): 72 | kwargs[ 73 | "ckpt" 74 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_256.ckpt" 75 | return spin_custom(refresh=refresh, **kwargs) 76 | 77 | 78 | def spin_wavlm_512(refresh=False, **kwargs): 79 | kwargs[ 80 | "ckpt" 81 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_512.ckpt" 82 | return spin_custom(refresh=refresh, **kwargs) 83 | 84 | 85 | def spin_wavlm_1024(refresh=False, **kwargs): 86 | kwargs[ 87 | "ckpt" 88 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_1024.ckpt" 89 | return spin_custom(refresh=refresh, **kwargs) 90 | 91 | 92 | def spin_wavlm_2048(refresh=False, **kwargs): 93 | kwargs[ 94 | "ckpt" 95 | ] = "https://huggingface.co/datasets/vectominist/spin_ckpt/resolve/main/spin_wavlm_2048.ckpt" 96 | return spin_custom(refresh=refresh, **kwargs) 97 | -------------------------------------------------------------------------------- /script/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | libri_dir=$1 4 | json_dir=$2 5 | 6 | mkdir -p $json_dir 7 | 8 | python3 prepare_data.py \ 9 | $libri_dir \ 10 | $json_dir \ 11 | --split train-clean-100 \ 12 | dev-clean \ 13 | dev-other \ 14 | --sort-by-len 15 | -------------------------------------------------------------------------------- /script/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | exp_name=$1 4 | exp_dir=$2 5 | config=config/spin.yaml 6 | 7 | mkdir -p $exp_dir 8 | 9 | echo "Name: $exp_name" 10 | echo "Config: $config" 11 | 12 | python3 run_task.py \ 13 | SpinPretrainTask \ 14 | --config $config \ 15 | --save-path $exp_dir/$exp_name \ 16 | --gpus 1 \ 17 | --njobs 16 18 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import ( 2 | AudioPretrainDataset, 3 | AudioPretrainPnmiValDataset, 4 | collate_fn, 5 | val_collate_fn, 6 | ) 7 | from .sampler import MaxLengthBatchSampler, MaxLengthDistributedSampler 8 | -------------------------------------------------------------------------------- /src/data/audio.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/auspicious3000/contentvec/blob/main/contentvec/data/audio/audio_utils_1.py 2 | # Paper: https://arxiv.org/abs/2204.09224 3 | 4 | 5 | import numpy as np 6 | import parselmouth 7 | 8 | 9 | def make_lowshelf(g, fc, Q, fs=44100): 10 | """Generate filter coefficients for 2nd order Lowshelf filter. 11 | This function follows the code from the JUCE DSP library 12 | which can be found in `juce_IIRFilter.cpp`. 13 | 14 | The design equations are based upon those found in the Cookbook 15 | formulae for audio equalizer biquad filter coefficients 16 | by Robert Bristow-Johnson. 17 | https://www.w3.org/2011/audio/audio-eq-cookbook.html 18 | Args: 19 | g (float): Gain factor in dB. 20 | fc (float): Cutoff frequency in Hz. 21 | Q (float): Q factor. 22 | fs (float): Sampling frequency in Hz. 
23 | Returns: 24 | tuple: (b, a) filter coefficients 25 | """ 26 | # convert gain from dB to linear 27 | g = np.power(10, (g / 20)) 28 | 29 | # initial values 30 | A = np.max([0.0, np.sqrt(g)]) 31 | aminus1 = A - 1 32 | aplus1 = A + 1 33 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs 34 | coso = np.cos(omega) 35 | beta = np.sin(omega) * np.sqrt(A) / Q 36 | aminus1TimesCoso = aminus1 * coso 37 | 38 | # coefs calculation 39 | b0 = A * (aplus1 - aminus1TimesCoso + beta) 40 | b1 = A * 2 * (aminus1 - aplus1 * coso) 41 | b2 = A * (aplus1 - aminus1TimesCoso - beta) 42 | a0 = aplus1 + aminus1TimesCoso + beta 43 | a1 = -2 * (aminus1 + aplus1 * coso) 44 | a2 = aplus1 + aminus1TimesCoso - beta 45 | 46 | # output coefs 47 | # b = np.array([b0/a0, b1/a0, b2/a0]) 48 | # a = np.array([a0/a0, a1/a0, a2/a0]) 49 | 50 | return np.array([[b0 / a0, b1 / a0, b2 / a0, 1.0, a1 / a0, a2 / a0]]) 51 | 52 | 53 | def make_highself(g, fc, Q, fs=44100): 54 | """Generate filter coefficients for 2nd order Highshelf filter. 55 | This function follows the code from the JUCE DSP library 56 | which can be found in `juce_IIRFilter.cpp`. 57 | 58 | The design equations are based upon those found in the Cookbook 59 | formulae for audio equalizer biquad filter coefficients 60 | by Robert Bristow-Johnson. 61 | https://www.w3.org/2011/audio/audio-eq-cookbook.html 62 | Args: 63 | g (float): Gain factor in dB. 64 | fc (float): Cutoff frequency in Hz. 65 | Q (float): Q factor. 66 | fs (float): Sampling frequency in Hz. 67 | Returns: 68 | tuple: (b, a) filter coefficients 69 | """ 70 | # convert gain from dB to linear 71 | g = np.power(10, (g / 20)) 72 | 73 | # initial values 74 | A = np.max([0.0, np.sqrt(g)]) 75 | aminus1 = A - 1 76 | aplus1 = A + 1 77 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs 78 | coso = np.cos(omega) 79 | beta = np.sin(omega) * np.sqrt(A) / Q 80 | aminus1TimesCoso = aminus1 * coso 81 | 82 | # coefs calculation 83 | b0 = A * (aplus1 + aminus1TimesCoso + beta) 84 | b1 = A * -2 * (aminus1 + aplus1 * coso) 85 | b2 = A * (aplus1 + aminus1TimesCoso - beta) 86 | a0 = aplus1 - aminus1TimesCoso + beta 87 | a1 = 2 * (aminus1 - aplus1 * coso) 88 | a2 = aplus1 - aminus1TimesCoso - beta 89 | 90 | # output coefs 91 | # b = np.array([b0/a0, b1/a0, b2/a0]) 92 | # a = np.array([a0/a0, a1/a0, a2/a0]) 93 | 94 | return np.array([[b0 / a0, b1 / a0, b2 / a0, 1.0, a1 / a0, a2 / a0]]) 95 | 96 | 97 | def make_peaking(g, fc, Q, fs=44100): 98 | """Generate filter coefficients for 2nd order Peaking EQ. 99 | This function follows the code from the JUCE DSP library 100 | which can be found in `juce_IIRFilter.cpp`. 101 | 102 | The design equations are based upon those found in the Cookbook 103 | formulae for audio equalizer biquad filter coefficients 104 | by Robert Bristow-Johnson. 105 | https://www.w3.org/2011/audio/audio-eq-cookbook.html 106 | Args: 107 | g (float): Gain factor in dB. 108 | fc (float): Cutoff frequency in Hz. 109 | Q (float): Q factor. 110 | fs (float): Sampling frequency in Hz. 
111 | Returns: 112 | tuple: (b, a) filter coefficients 113 | """ 114 | # convert gain from dB to linear 115 | g = np.power(10, (g / 20)) 116 | 117 | # initial values 118 | A = np.max([0.0, np.sqrt(g)]) 119 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs 120 | alpha = np.sin(omega) / (Q * 2) 121 | c2 = -2 * np.cos(omega) 122 | alphaTimesA = alpha * A 123 | alphaOverA = alpha / A 124 | 125 | # coefs calculation 126 | b0 = 1 + alphaTimesA 127 | b1 = c2 128 | b2 = 1 - alphaTimesA 129 | a0 = 1 + alphaOverA 130 | a1 = c2 131 | a2 = 1 - alphaOverA 132 | 133 | # output coefs 134 | # b = np.array([b0/a0, b1/a0, b2/a0]) 135 | # a = np.array([a0/a0, a1/a0, a2/a0]) 136 | 137 | return np.array([[b0 / a0, b1 / a0, b2 / a0, 1.0, a1 / a0, a2 / a0]]) 138 | 139 | 140 | def params2sos(G, Fc, Q, fs): 141 | """Convert 5 band EQ paramaters to 2nd order sections. 142 | Takes a vector with shape (13,) of denormalized EQ parameters 143 | and calculates filter coefficients for each of the 5 filters. 144 | These coefficients (2nd order sections) are then stored into a 145 | single (5,6) matrix. This matrix can be fed to `scipy.signal.sosfreqz()` 146 | in order to determine the frequency response of the cascasd of 147 | all five biquad filters. 148 | Args: 149 | x (float): Gain factor in dB. 150 | fs (float): Sampling frequency in Hz. 151 | Returns: 152 | ndarray: filter coefficients for 5 band EQ stored in (5,6) matrix. 153 | [[b1_0, b1_1, b1_2, a1_0, a1_1, a1_2], # lowshelf coefficients 154 | [b2_0, b2_1, b2_2, a2_0, a2_1, a2_2], # first band coefficients 155 | [b3_0, b3_1, b3_2, a3_0, a3_1, a3_2], # second band coefficients 156 | [b4_0, b4_1, b4_2, a4_0, a4_1, a4_2], # third band coefficients 157 | [b5_0, b5_1, b5_2, a5_0, a5_1, a5_2]] # highshelf coefficients 158 | """ 159 | # generate filter coefficients from eq params 160 | c0 = make_lowshelf(G[0], Fc[0], Q[0], fs=fs) 161 | c1 = make_peaking(G[1], Fc[1], Q[1], fs=fs) 162 | c2 = make_peaking(G[2], Fc[2], Q[2], fs=fs) 163 | c3 = make_peaking(G[3], Fc[3], Q[3], fs=fs) 164 | c4 = make_peaking(G[4], Fc[4], Q[4], fs=fs) 165 | c5 = make_peaking(G[5], Fc[5], Q[5], fs=fs) 166 | c6 = make_peaking(G[6], Fc[6], Q[6], fs=fs) 167 | c7 = make_peaking(G[7], Fc[7], Q[7], fs=fs) 168 | c8 = make_peaking(G[8], Fc[8], Q[8], fs=fs) 169 | c9 = make_highself(G[9], Fc[9], Q[9], fs=fs) 170 | 171 | # stuff coefficients into second order sections structure 172 | sos = np.concatenate([c0, c1, c2, c3, c4, c5, c6, c7, c8, c9], axis=0) 173 | 174 | return sos 175 | 176 | 177 | def change_gender(x, fs, lo, hi, ratio_fs, ratio_ps, ratio_pr): 178 | s = parselmouth.Sound(x, sampling_frequency=fs) 179 | f0 = s.to_pitch_ac(pitch_floor=lo, pitch_ceiling=hi, time_step=0.8 / lo) 180 | f0_np = f0.selected_array["frequency"] 181 | f0_med = np.median(f0_np[f0_np != 0]).item() 182 | ss = parselmouth.praat.call( 183 | [s, f0], "Change gender", ratio_fs, f0_med * ratio_ps, ratio_pr, 1.0 184 | ) 185 | return ss.values.squeeze(0) 186 | 187 | 188 | def change_gender_f0(x, fs, lo, hi, ratio_fs, new_f0_med, ratio_pr): 189 | s = parselmouth.Sound(x, sampling_frequency=fs) 190 | ss = parselmouth.praat.call( 191 | s, "Change gender", lo, hi, ratio_fs, new_f0_med, ratio_pr, 1.0 192 | ) 193 | return ss.values.squeeze(0) 194 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pickle 5 | import random 6 | import re 7 | from 
pathlib import Path 8 | from typing import List, Tuple 9 | 10 | import numpy as np 11 | import soundfile as sf 12 | import torch 13 | import torch.nn.functional as F 14 | from scipy.signal import sosfilt 15 | from torch.nn.utils.rnn import pad_sequence 16 | from torch.utils.data import Dataset 17 | 18 | from src.util import len_to_padding 19 | 20 | from .audio import change_gender, change_gender_f0, params2sos 21 | 22 | logger = logging.getLogger("dataset") 23 | 24 | Qmin, Qmax = 2, 5 25 | 26 | 27 | def read_phn(tsv_path, rm_stress=True): 28 | uid2phns = {} 29 | with open(tsv_path) as f: 30 | for line in f: 31 | uid, phns = line.rstrip().split("\t") 32 | phns = phns.split(",") 33 | if rm_stress: 34 | phns = [re.sub("[0-9]", "", phn) for phn in phns] 35 | uid2phns[uid] = phns 36 | return uid2phns 37 | 38 | 39 | class AudioPretrainDataset(Dataset): 40 | def __init__( 41 | self, 42 | json_dir: str, 43 | splits: List[str], 44 | spk2info: str = None, 45 | min_audio_len: int = 1, 46 | max_audio_len: int = 1_000_000, 47 | sample_rate: int = 16_000, 48 | random_crop_len: int = -1, 49 | use_ratio: float = 1.0, 50 | normalize: bool = False, 51 | nansy_both: bool = False, 52 | ) -> None: 53 | super().__init__() 54 | 55 | logger.info(f"Loading audio from: {json_dir}") 56 | logger.info(f"Splits : {splits}") 57 | self.split_type = "train" if splits[0].startswith("train") else "valid" 58 | self.return_augmented = True 59 | 60 | # Data augmentation config 61 | self.random_crop_len = random_crop_len 62 | self.random_crop_len_lab = random_crop_len // 320 63 | self.use_ratio = use_ratio 64 | self.normalize = normalize 65 | self.sample_rate = sample_rate 66 | self.nansy_aug = spk2info is not None 67 | 68 | if self.nansy_aug: 69 | # Load speaker information 70 | logger.info("Apply NANSY speaker perturbation") 71 | # ref: https://arxiv.org/abs/2110.14513 72 | self.num_view = 2 73 | self.nansy_both = nansy_both 74 | 75 | with open(spk2info, "rb") as fp: 76 | self.spk2info = pickle.load(fp) 77 | self.spk2info = self.spk2info[self.split_type] 78 | 79 | self.rng = np.random.default_rng() 80 | self.Fc = np.exp(np.linspace(np.log(60), np.log(7600), 10)) 81 | 82 | self.num_view = 2 if self.nansy_aug else 1 83 | 84 | # Load from .json files 85 | data_list = [] 86 | for s in splits: 87 | with open(os.path.join(json_dir, s + ".json"), "r") as fp: 88 | data_list += json.load(fp) 89 | 90 | # Preserve certain ratio of data 91 | if self.use_ratio < 1.0: 92 | logger.info( 93 | f"Using only {self.use_ratio * 100:.0f}% of randomly chosen data" 94 | ) 95 | random.shuffle(data_list) 96 | data_list = data_list[: int(len(data_list) * self.use_ratio)] 97 | 98 | # Remove files that are too long or too short 99 | logger.info( 100 | f"Removing files shorter than {min_audio_len} or longer than {max_audio_len} frames" 101 | ) 102 | orig_tot_len = sum([l for _, l in data_list]) / sample_rate / 3600.0 103 | orig_num_file = len(data_list) 104 | data_list = [ 105 | (p, l) for p, l in data_list if l >= min_audio_len and l <= max_audio_len 106 | ] 107 | new_tot_len = sum([l for _, l in data_list]) / sample_rate / 3600.0 108 | logger.info(f"Original audio files: {orig_num_file} ({orig_tot_len:.2f} hrs)") 109 | logger.info(f"Final audio files: {len(data_list)} ({new_tot_len:.2f} hrs)") 110 | 111 | # Sort by length (long to short) 112 | data_list = sorted(data_list, key=lambda x: x[1], reverse=True) 113 | 114 | # Extract speaker 115 | spk_list = [Path(p).stem.split("-")[0] for p, _ in data_list] 116 | name_list = [Path(p).stem for p, _ in 
data_list] 117 | self.data = [ 118 | (p, l, s, n) for (p, l), s, n in zip(data_list, spk_list, name_list) 119 | ] 120 | self.data_lens = [l for _, l, _, _ in self.data] 121 | 122 | if self.nansy_aug: 123 | # Check speaker 124 | data_spk_set = set(spk_list) 125 | avail_spk_set = set(self.spk2info.keys()) 126 | assert all(s in avail_spk_set for s in data_spk_set), "Missing speakers!" 127 | logger.info(f"Total {len(data_spk_set)} speakers") 128 | 129 | def __len__(self) -> int: 130 | return len(self.data) 131 | 132 | def __getitem__(self, index: int) -> List[torch.FloatTensor]: 133 | # each sample: (audio file path, number of samples, speaker id) 134 | path, num_frames, spk, uid = self.data[index] 135 | 136 | wav, sr = sf.read(path) 137 | if wav.ndim == 2: 138 | wav = wav.mean(-1) 139 | 140 | wav = wav.astype(np.float32) 141 | 142 | if not self.return_augmented: 143 | return torch.from_numpy(wav) 144 | 145 | # Randomly crop wave length 146 | if self.random_crop_len > 0 and len(wav) > self.random_crop_len: 147 | idx = np.random.randint(0, len(wav) - self.random_crop_len) 148 | wav = wav[idx : idx + self.random_crop_len] 149 | 150 | # Apply NANSY speaker perturbation 151 | if self.nansy_aug: 152 | if self.nansy_both: 153 | wavs = [ 154 | self.perturb_speaker(path, wav, sr, spk) 155 | for _ in range(self.num_view) 156 | ] 157 | else: 158 | wavs = [wav] + [ 159 | self.perturb_speaker(path, wav, sr, spk) 160 | for _ in range(self.num_view - 1) 161 | ] 162 | else: 163 | wavs = [wav for _ in range(self.num_view)] 164 | 165 | wavs = [torch.FloatTensor(w) for w in wavs] 166 | if self.normalize: 167 | wavs = [F.layer_norm(w, w.shape) for w in wavs] 168 | 169 | return wavs 170 | 171 | def get_spk_info(self, spk: str): 172 | _, (lo, hi, _) = self.spk2info[spk] 173 | if lo == 50: 174 | lo = 75 175 | if spk == "1447": 176 | lo, hi = 60, 400 177 | return lo, hi 178 | 179 | def random_eq(self, wav, sr): 180 | z = self.rng.uniform(0, 1, size=(10,)) 181 | Q = Qmin * (Qmax / Qmin) ** z 182 | G = self.rng.uniform(-12, 12, size=(10,)) 183 | sos = params2sos(G, self.Fc, Q, sr) 184 | wav = sosfilt(sos, wav) 185 | return wav 186 | 187 | def random_formant_f0(self, wav, sr, spk): 188 | lo, hi = self.get_spk_info(spk) 189 | 190 | ratio_fs = self.rng.uniform(1, 1.4) 191 | coin = self.rng.random() > 0.5 192 | ratio_fs = coin * ratio_fs + (1 - coin) * (1 / ratio_fs) 193 | 194 | ratio_ps = self.rng.uniform(1, 2) 195 | coin = self.rng.random() > 0.5 196 | ratio_ps = coin * ratio_ps + (1 - coin) * (1 / ratio_ps) 197 | 198 | ratio_pr = self.rng.uniform(1, 1.5) 199 | coin = self.rng.random() > 0.5 200 | ratio_pr = coin * ratio_pr + (1 - coin) * (1 / ratio_pr) 201 | 202 | ss = change_gender(wav, sr, lo, hi, ratio_fs, ratio_ps, ratio_pr) 203 | 204 | return ss 205 | 206 | def fixed_formant_f0(self, wav, sr, spk): 207 | _, (lo, hi, _) = self.spk2info[spk] 208 | 209 | if lo == 50: 210 | lo = 75 211 | ratio_fs, f0_med, ratio_pr = 1.2, 300, 1.2 212 | else: 213 | ratio_fs, f0_med, ratio_pr = 0.8, 100, 0.8 214 | 215 | ss = change_gender_f0(wav, sr, lo, hi, ratio_fs, f0_med, ratio_pr) 216 | return ss 217 | 218 | def perturb_speaker(self, path, wav, sr, spk): 219 | if self.split_type == "train": 220 | # Speaker perturbation 221 | try: 222 | wav_p = self.random_formant_f0(wav, sr, spk) 223 | except UserWarning: 224 | wav_p = np.copy(wav) 225 | logger.info(f"Praat warning - {path}") 226 | except RuntimeError: 227 | wav_p = np.copy(wav) 228 | logger.info(f"Praat Error - {path}") 229 | wav_p = self.random_eq(wav_p, sr) 230 | else: 231 | 
try: 232 | wav_p = self.fixed_formant_f0(wav, sr, spk) 233 | except UserWarning: 234 | wav_p = np.copy(wav) 235 | logger.info(f"Praat warning - {path}") 236 | except RuntimeError: 237 | wav_p = np.copy(wav) 238 | logger.info(f"Praat Error - {path}") 239 | 240 | wav_p = np.clip(wav_p, -1.0, 1.0) 241 | return wav_p 242 | 243 | 244 | class AudioPretrainPnmiValDataset(Dataset): 245 | def __init__( 246 | self, 247 | json_dir: str, 248 | phn_dir: str, 249 | splits: List[str], 250 | min_audio_len: int = 1, 251 | max_audio_len: int = 1_000_000, 252 | sample_rate: int = 16_000, 253 | crop_len: int = 160_000, 254 | normalize: bool = False, 255 | **kwargs, 256 | ) -> None: 257 | super().__init__() 258 | 259 | logger.info(f"Loading audio from: {json_dir}") 260 | logger.info(f"Loading phoneme alignments from: {phn_dir}") 261 | logger.info(f"Splits : {splits}") 262 | self.split_type = "train" if splits[0].startswith("train") else "valid" 263 | 264 | # Data augmentation config 265 | self.crop_len = crop_len 266 | self.normalize = normalize 267 | self.sample_rate = sample_rate 268 | 269 | # Load from .json files 270 | data_list = [] 271 | for s in splits: 272 | with open(os.path.join(json_dir, s + ".json"), "r") as fp: 273 | data_list += json.load(fp) 274 | 275 | # Load from .tsv files 276 | self.uid2refs = {} 277 | for s in splits: 278 | self.uid2refs.update(read_phn(os.path.join(phn_dir, s + ".tsv"))) 279 | 280 | # Remove files that are too long or too short 281 | logger.info( 282 | f"Removing files shorter than {min_audio_len} or longer than {max_audio_len} frames" 283 | ) 284 | orig_tot_len = sum([l for _, l in data_list]) / sample_rate / 3600.0 285 | orig_num_file = len(data_list) 286 | data_list = [ 287 | (p, l) for p, l in data_list if l >= min_audio_len and l <= max_audio_len 288 | ] 289 | new_tot_len = sum([l for _, l in data_list]) / sample_rate / 3600.0 290 | logger.info(f"Original audio files: {orig_num_file} ({orig_tot_len:.2f} hrs)") 291 | logger.info(f"Final audio files: {len(data_list)} ({new_tot_len:.2f} hrs)") 292 | 293 | # Sort by length (long to short) 294 | data_list = sorted(data_list, key=lambda x: x[1], reverse=True) 295 | 296 | # Extract speaker 297 | spk_list = [Path(p).stem.split("-")[0] for p, _ in data_list] 298 | name_list = [Path(p).stem for p, _ in data_list] 299 | self.data = [ 300 | (p, l, s, n) for (p, l), s, n in zip(data_list, spk_list, name_list) 301 | ] 302 | self.data_lens = [l for _, l, _, _ in self.data] 303 | 304 | def __len__(self) -> int: 305 | return len(self.data) 306 | 307 | def __getitem__(self, index: int) -> List[torch.FloatTensor]: 308 | # each sample: (audio file path, number of samples, speaker id) 309 | path, num_frames, spk, uid = self.data[index] 310 | 311 | wav, sr = sf.read(path) 312 | if wav.ndim == 2: 313 | wav = wav.mean(-1) 314 | 315 | wav = wav.astype(np.float32) 316 | 317 | # Crop wave length 318 | if self.crop_len > 0 and len(wav) > self.crop_len: 319 | wav = wav[: self.crop_len] 320 | 321 | wav = torch.from_numpy(wav) 322 | 323 | if self.normalize: 324 | wav = F.layer_norm(wav, wav.shape) 325 | 326 | return wav, uid 327 | 328 | 329 | def collate_fn( 330 | batch: List[List[Tuple[torch.FloatTensor, List[int]]]], 331 | ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.BoolTensor,]: 332 | wav_list = [] 333 | wav_len = [] 334 | 335 | for wavs in batch: 336 | for _, w in enumerate(wavs): 337 | wav_list.append(w) 338 | wav_len.append(len(w)) 339 | 340 | wav_list = pad_sequence(wav_list, batch_first=True) 341 | wav_len = 
torch.LongTensor(wav_len) 342 | padding_mask = len_to_padding(wav_len) 343 | # padding_mask: has value <=> 0 else 1 344 | 345 | return wav_list, wav_len, padding_mask 346 | 347 | 348 | def val_collate_fn( 349 | batch: List[List[Tuple[torch.FloatTensor, List[int]]]], 350 | ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.BoolTensor, torch.LongTensor]: 351 | wav_list = [] 352 | wav_len = [] 353 | uid_list = [] 354 | 355 | for wav, uid in batch: 356 | wav_list.append(wav) 357 | wav_len.append(len(wav)) 358 | uid_list.append(uid) 359 | 360 | wav_list = pad_sequence(wav_list, batch_first=True) 361 | wav_len = torch.LongTensor(wav_len) 362 | padding_mask = len_to_padding(wav_len) 363 | # padding_mask: has value <=> 0 else 1 364 | 365 | return wav_list, wav_len, padding_mask, uid_list 366 | -------------------------------------------------------------------------------- /src/data/librispeech.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List, Tuple 4 | 5 | import torchaudio 6 | from tqdm import tqdm 7 | 8 | 9 | def find_all_librispeech(root: str, sort_by_len: bool = False) -> List[Tuple[str, int]]: 10 | files = list(Path(root).rglob("*.flac")) 11 | files = [str(f) for f in files] 12 | file_lens = [torchaudio.info(f).num_frames for f in tqdm(files)] 13 | assert len(files) == len(file_lens), (len(files), len(file_lens)) 14 | data = sorted( 15 | zip(files, file_lens), key=lambda x: x[1 if sort_by_len else 0], reverse=True 16 | ) 17 | return data 18 | 19 | 20 | def save_data_info(data: List[Tuple[str, int]], path: str) -> None: 21 | with open(path, "w") as fp: 22 | json.dump(data, fp, indent=2) 23 | -------------------------------------------------------------------------------- /src/data/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from .dataset import AudioPretrainDataset 7 | 8 | 9 | class MaxLengthBatchSampler: 10 | def __init__( 11 | self, 12 | lengths: List[int], 13 | max_length: int, 14 | cropped_length: int = 160_000, 15 | shuffle: bool = False, 16 | drop_last: bool = False, 17 | seed: int = 7122, 18 | ) -> None: 19 | self.lengths = lengths 20 | self.max_length = max_length 21 | self.cropped_length = cropped_length if cropped_length > 0 else 1000000 22 | self.shuffle = shuffle 23 | self.drop_last = drop_last 24 | self.seed = seed 25 | self.epoch = 0 26 | 27 | def set_epoch(self, epoch: int): 28 | self.epoch = epoch 29 | 30 | def __iter__(self): 31 | batch_list = [] 32 | batch = [] 33 | cur_length = 0 34 | for i in range(len(self.lengths)): 35 | new_batch = batch + [i] 36 | cur_length += min(self.lengths[i], self.cropped_length) 37 | 38 | if cur_length <= self.max_length: 39 | batch = new_batch 40 | elif len(batch) == 0: 41 | raise ValueError( 42 | f"There is a single length {self.lengths[i]} larger than " 43 | f"max_length {self.max_length}. Please increase " 44 | "the max_length." 
45 | ) 46 | else: 47 | batch_list.append(batch) 48 | batch = [i] 49 | cur_length = min(self.lengths[i], self.cropped_length) 50 | 51 | if len(batch) > 0 and not self.drop_last: 52 | batch_list.append(batch) 53 | 54 | if self.shuffle: 55 | generator = torch.Generator() 56 | generator.manual_seed(self.epoch + self.seed) 57 | indices = torch.randperm(len(batch_list), generator=generator).tolist() 58 | else: 59 | indices = list(range(len(batch_list))) 60 | 61 | for i in indices: 62 | yield batch_list[i] 63 | 64 | def __len__(self): 65 | return len(list(iter(self))) 66 | 67 | 68 | class MaxLengthDistributedSampler: 69 | def __init__( 70 | self, 71 | dataset: AudioPretrainDataset, 72 | lengths: List[int], 73 | max_length: int, 74 | cropped_length: int = 160_000, 75 | num_replicas: Optional[int] = None, 76 | rank: Optional[int] = None, 77 | shuffle: bool = True, 78 | seed: int = 7122, 79 | drop_last: bool = False, 80 | ) -> None: 81 | 82 | if num_replicas is None: 83 | if not dist.is_available(): 84 | raise RuntimeError("Requires distributed package to be available") 85 | num_replicas = dist.get_world_size() 86 | if rank is None: 87 | if not dist.is_available(): 88 | raise RuntimeError("Requires distributed package to be available") 89 | rank = dist.get_rank() 90 | if rank >= num_replicas or rank < 0: 91 | raise ValueError( 92 | "Invalid rank {}, rank should be in the interval" 93 | " [0, {}]".format(rank, num_replicas - 1) 94 | ) 95 | 96 | print(f"- Rank: {rank} / Num replicas: {num_replicas}") 97 | 98 | self.dataset = dataset 99 | self.shuffle = shuffle 100 | self.seed = seed 101 | self.drop_last = drop_last 102 | self.epoch = 0 103 | self.num_replicas = num_replicas 104 | self.rank = rank 105 | 106 | self.lengths = lengths 107 | self.max_length = max_length 108 | self.cropped_length = cropped_length if cropped_length > 0 else 1000000 109 | 110 | def set_epoch(self, epoch: int): 111 | self.epoch = epoch 112 | 113 | def __iter__(self): 114 | batch_list = [] 115 | batch = [] 116 | cur_length = 0 117 | for i in range(len(self.lengths)): 118 | new_batch = batch + [i] 119 | cur_length += min(self.lengths[i], self.cropped_length) 120 | 121 | if cur_length <= self.max_length: 122 | batch = new_batch 123 | elif len(batch) == 0: 124 | raise ValueError( 125 | f"There is a single length {self.lengths[i]} larger than " 126 | f"max_length {self.max_length}. Please increase " 127 | "the max_length." 
128 | ) 129 | else: 130 | batch_list.append(batch) 131 | batch = [i] 132 | cur_length = min(self.lengths[i], self.cropped_length) 133 | 134 | if len(batch) > 0 and not self.drop_last: 135 | batch_list.append(batch) 136 | 137 | if self.shuffle: 138 | generator = torch.Generator() 139 | generator.manual_seed(self.epoch + self.seed) 140 | indices = torch.randperm(len(batch_list), generator=generator).tolist() 141 | else: 142 | indices = list(range(len(batch_list))) 143 | 144 | max_index = len(indices) - len(indices) % self.num_replicas 145 | indices = indices[self.rank : max_index : self.num_replicas] 146 | for i in indices: 147 | yield batch_list[i] 148 | 149 | def __len__(self): 150 | return len(list(iter(self))) 151 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .spin import SpinModel 2 | -------------------------------------------------------------------------------- /src/model/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import pytorch_lightning as pl 4 | import yaml 5 | 6 | 7 | class BaseModel(pl.LightningModule): 8 | def __init__(self, config) -> None: 9 | super().__init__() 10 | 11 | if isinstance(config, str) and config.split(".")[-1] in {"yaml", "yml"}: 12 | config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) 13 | 14 | self.config = config 15 | self.save_hyperparameters(config) 16 | 17 | @abc.abstractmethod 18 | def forward(self, batch): 19 | raise NotImplementedError 20 | 21 | @abc.abstractmethod 22 | def training_step(self, batch, batch_idx): 23 | raise NotImplementedError 24 | 25 | @abc.abstractmethod 26 | def configure_optimizers(self): 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /src/model/spin.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | from collections import defaultdict 4 | from typing import Union 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | 11 | from src.data import ( 12 | AudioPretrainDataset, 13 | MaxLengthBatchSampler, 14 | MaxLengthDistributedSampler, 15 | collate_fn, 16 | ) 17 | from src.nn import DNN, HuBERT, SwavVQDisentangle, WavLM 18 | from src.util import compute_show_pnmi, get_scheduler, update_padding_mask 19 | 20 | from .base import BaseModel 21 | 22 | logger = logging.getLogger("spin") 23 | 24 | 25 | def get_pred_head(type_name: str, hid_dim: int, config: dict) -> Union[None, nn.Module]: 26 | if type_name == "None": 27 | return None 28 | if type_name == "DNN": 29 | return DNN(hid_dim, **config) 30 | raise NotImplementedError(type_name) 31 | 32 | 33 | def get_loss(type_name: str, hid_dim: int, config: dict) -> nn.Module: 34 | if type_name == "SwavVQDisentangle": 35 | return SwavVQDisentangle(hid_dim, **config) 36 | raise NotImplementedError(type_name) 37 | 38 | 39 | class SpinModel(BaseModel): 40 | def __init__(self, config, num_view: int = 2) -> None: 41 | super().__init__(config) 42 | 43 | config = copy.deepcopy(config) 44 | config = config["model"] 45 | logger.info(f"Model config: {config}") 46 | 47 | self.encoder_type = config["encoder"].pop("type", "HuBERT") 48 | self.pred_head_type = config["pred_head"].pop("type", "DNN") 49 | self.loss_type = config["loss"].pop("type", "SwavVQDisentangle") 50 | 
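        # Illustrative shape of config["model"] consumed above (values here are
        # hypothetical; the actual settings live in config/spin.yaml):
        #   encoder:   {type: "HuBERT", use_layer: 12, ...}
        #   pred_head: {type: "DNN", hid_dims: [2048, 256], ...}
        #   loss:      {type: "SwavVQDisentangle", num_vars: 256, ...}
        # Each "type" key is popped here and the remaining sub-dict is forwarded to
        # the matching constructor (the encoder class, get_pred_head, get_loss).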
51 | logger.info(f"Encoder: {self.encoder_type}") 52 | logger.info(f"Prediction head: {self.pred_head_type}") 53 | logger.info(f"Loss: {self.loss_type}") 54 | 55 | # Setup number of views 56 | self.normalize = config.get("normalize", False) 57 | self.num_view = num_view 58 | assert num_view == 2, num_view # NOTE: currently we support 2 views only 59 | 60 | # Setup encoder model 61 | if self.encoder_type in {"HuBERT", "WavLM"}: 62 | self.use_layer = config["encoder"].pop("use_layer", 12) 63 | self.encoder = eval(self.encoder_type)(**config["encoder"]) 64 | hid_dim = self.encoder.hidden_sizes[self.use_layer] 65 | self.encoder_rate = 320 66 | logger.info(f"Taking features from layer {self.use_layer}") 67 | else: 68 | raise NotImplementedError(self.encoder_type) 69 | 70 | # All layers to be processed 71 | self.use_layers = [self.use_layer] 72 | logger.info(f"All selected layers: {self.use_layers}") 73 | 74 | # Setup prediction head 75 | if len(self.use_layers) == 1: 76 | self.pred_head = get_pred_head( 77 | self.pred_head_type, hid_dim, config["pred_head"] 78 | ) 79 | hid_dim = self.pred_head.out_dim 80 | else: 81 | self.pred_head = nn.ModuleList( 82 | [ 83 | get_pred_head(self.pred_head_type, hid_dim, config["pred_head"]) 84 | for _ in self.use_layers 85 | ] 86 | ) 87 | hid_dim = self.pred_head[0].out_dim 88 | 89 | # Setup loss function 90 | self.loss_module = get_loss(self.loss_type, hid_dim, config["loss"]) 91 | 92 | # Validation 93 | self.val_uid2hyp = {} 94 | 95 | def normalize_wavs( 96 | self, 97 | wavs: torch.Tensor, 98 | wavs_len: torch.LongTensor, 99 | ) -> torch.Tensor: 100 | with torch.no_grad(): 101 | for i in len(wavs): 102 | wavs[i, : wavs_len[i]] = F.layer_norm( 103 | wavs[i, : wavs_len[i]], wavs_len[i] 104 | ) 105 | return wavs 106 | 107 | def forward_features( 108 | self, 109 | wavs: torch.Tensor, 110 | padding_mask: torch.BoolTensor, 111 | ): 112 | if self.encoder_type in {"HuBERT", "WavLM"}: 113 | feat_list, feat_len, padding_mask = self.encoder(wavs, padding_mask) 114 | repr_list = [feat_list[l] for l in self.use_layers] 115 | 116 | return repr_list, feat_list, feat_len, padding_mask 117 | 118 | def forward_pred_head( 119 | self, feat: torch.Tensor, feat_len: torch.LongTensor, i: int = None 120 | ): 121 | if len(self.use_layers) == 1: 122 | return self.pred_head(feat, feat_len) 123 | else: 124 | assert isinstance(i, int), i 125 | return self.pred_head[i](feat, feat_len) 126 | 127 | def forward(self, batch, feat_only: bool = False): 128 | # NOTE: padding_mask is 1 when the position is padded 129 | wavs, wavs_len, padding_mask = batch 130 | 131 | wavs.masked_fill_(padding_mask, 0.0) 132 | 133 | # Normalize wavs 134 | if self.normalize: 135 | wavs = self.normalize_wavs(wavs, wavs_len) 136 | 137 | # Extract features 138 | repr_list, feat_list, feat_len, padding_mask = self.forward_features( 139 | wavs, padding_mask 140 | ) 141 | padding_mask = update_padding_mask(padding_mask, repr_list[0].shape[1]) 142 | 143 | # Prediction head 144 | repr_list = [ 145 | self.forward_pred_head(repr, feat_len, i) 146 | for i, repr in enumerate(repr_list) 147 | ] 148 | 149 | # Return results 150 | if feat_only: 151 | outputs = { 152 | "repr_list": repr_list, 153 | "feat_len": feat_len, 154 | "feat_list": feat_list, 155 | "padding_mask": padding_mask, 156 | } 157 | if self.loss_type == "SwavVQDisentangle": 158 | if self.loss_module.l2_norm: 159 | outputs["repr_list"] = F.normalize(repr_list[0], dim=-1) 160 | logits, codes = self.loss_module.produce_targets( 161 | outputs["repr_list"], 
normalized=True 162 | ) 163 | outputs["logits"] = logits 164 | outputs["codes"] = codes 165 | 166 | return outputs 167 | 168 | # feat: (Batch * View, Time, Dim) 169 | # Split batch into views 170 | repr_views = [ 171 | [ 172 | r[i :: self.num_view][~padding_mask[i :: self.num_view]] 173 | for r in repr_list 174 | ] 175 | for i in range(self.num_view) 176 | ] 177 | 178 | # Computes loss from each pair of views 179 | total_loss = 0 180 | loss_res = defaultdict(list) 181 | 182 | # Main loss 183 | if self.loss_type in {"SwavVQDisentangle"}: 184 | res = self.loss_module.cal_loss(repr_views[0][0], repr_views[1][0]) 185 | total_loss += res.pop("loss") 186 | for k in res: 187 | loss_res[f"{k}"].append(res[k]) 188 | 189 | for k in loss_res: 190 | loss_res[k] = sum(loss_res[k]) / len(loss_res[k]) 191 | 192 | return total_loss, loss_res 193 | 194 | def training_step(self, batch, batch_idx): 195 | total_loss, loss_res = self.forward(batch) 196 | 197 | self.log("loss", total_loss) 198 | for k, v in loss_res.items(): 199 | if k == "acc": 200 | self.log(k, v, prog_bar=True) 201 | else: 202 | self.log(k, v) 203 | 204 | return total_loss 205 | 206 | def validation_step(self, batch, batch_idx): 207 | wav_list, wav_len, padding_mask, uid_list = batch 208 | results = self.forward( 209 | (wav_list, wav_len, padding_mask), 210 | feat_only=True, 211 | ) 212 | codes = results.get("codes", None) 213 | if codes is None: 214 | return 215 | 216 | code = codes.cpu().numpy() 217 | feat_len = results["feat_len"] 218 | for i, (uid, c) in enumerate(zip(uid_list, code)): 219 | self.val_uid2hyp[uid] = c[: feat_len[i]] 220 | 221 | def on_validation_epoch_end(self) -> None: 222 | check_1, check_2 = True, True 223 | try: 224 | uid2ref = self.trainer.val_dataloaders[0].dataset.uid2refs 225 | except: 226 | check_1 = False 227 | 228 | if not check_1: 229 | try: 230 | uid2ref = self.trainer.val_dataloaders.dataset.uid2refs 231 | except: 232 | check_2 = False 233 | 234 | if not check_1 and not check_2: 235 | logger.info("Cannot find uid2ref in validation dataloader, skip PNMI") 236 | return 237 | 238 | if len(self.val_uid2hyp) == 0: 239 | return 240 | 241 | res = compute_show_pnmi( 242 | uid2ref, self.val_uid2hyp, upsample=self.encoder_rate // 160 243 | ) 244 | self.log("cls_pur", res["cls_pur"]) 245 | self.log("phn_pur", res["phn_pur"]) 246 | self.log("pnmi", res["pnmi"]) 247 | self.val_uid2hyp.clear() 248 | 249 | def on_before_zero_grad(self, optimizer) -> None: 250 | if self.loss_type == "SwavVQDisentangle": 251 | self.loss_module.normalize_codebook() 252 | 253 | return super().on_before_zero_grad(optimizer) 254 | 255 | def on_train_epoch_start(self) -> None: 256 | super().on_train_epoch_start() 257 | 258 | try: 259 | self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch) 260 | logger.info(f"Update epoch for batch sampler to {self.current_epoch}") 261 | except: 262 | logger.warn( 263 | "Unable to update epoch for batch sampler (possibly using fixed batch_size)" 264 | ) 265 | 266 | def configure_optimizers(self): 267 | params = [] 268 | if self.encoder_type in {"HuBERT", "WavLM"}: 269 | params += self.encoder.trainable_parameters() 270 | else: 271 | params += list(self.encoder.parameters()) 272 | 273 | if self.pred_head: 274 | params += list(self.pred_head.parameters()) 275 | params += list(self.loss_module.parameters()) 276 | 277 | optimizer = getattr(torch.optim, self.config["optim"]["optimizer"]["name"])( 278 | params, **self.config["optim"]["optimizer"]["args"] 279 | ) 280 | 281 | if 
self.config["optim"].get("scheduler", None): 282 | scheduler = get_scheduler( 283 | self.config["optim"]["scheduler"]["name"], 284 | optimizer, 285 | **self.config["optim"]["scheduler"]["args"], 286 | ) 287 | 288 | return { 289 | "optimizer": optimizer, 290 | "lr_scheduler": { 291 | "scheduler": scheduler, 292 | "interval": "step", 293 | "frequency": 1, 294 | }, 295 | } 296 | else: 297 | return optimizer 298 | 299 | def set_random_seed(self, seed: int = 7122): 300 | self.seed = seed 301 | 302 | def set_njobs(self, njobs: int = 0): 303 | self.njobs = njobs 304 | 305 | def set_use_ddp(self, use_ddp: bool = False): 306 | self.use_ddp = use_ddp 307 | 308 | def train_dataloader(self): 309 | dataset = AudioPretrainDataset(**self.config["data"]) 310 | if "batch_len" in self.config["hparam"]: 311 | if self.use_ddp: 312 | sampler = MaxLengthDistributedSampler( 313 | dataset, 314 | dataset.data_lens, 315 | max_length=self.config["hparam"]["batch_len"], 316 | cropped_length=self.config["data"]["random_crop_len"], 317 | shuffle=True, 318 | drop_last=True, 319 | seed=self.seed, 320 | ) 321 | else: 322 | sampler = MaxLengthBatchSampler( 323 | dataset.data_lens, 324 | max_length=self.config["hparam"]["batch_len"], 325 | cropped_length=self.config["data"]["random_crop_len"], 326 | shuffle=True, 327 | drop_last=True, 328 | seed=self.seed, 329 | ) 330 | loader = DataLoader( 331 | dataset, 332 | batch_sampler=sampler, 333 | num_workers=self.njobs, 334 | pin_memory=True, 335 | collate_fn=collate_fn, 336 | ) 337 | elif "batch_size" in self.config["hparam"]: 338 | loader = DataLoader( 339 | dataset, 340 | batch_size=self.config["hparam"]["batch_size"], 341 | num_workers=self.njobs, 342 | pin_memory=True, 343 | collate_fn=collate_fn, 344 | shuffle=True, 345 | drop_last=True, 346 | ) 347 | 348 | return loader 349 | -------------------------------------------------------------------------------- /src/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .dnn import DNN 2 | from .hubert import HuBERT 3 | from .swav_vq_dis import SwavVQDisentangle 4 | from .wavlm import WavLM 5 | -------------------------------------------------------------------------------- /src/nn/dnn.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class DNN(nn.Module): 8 | def __init__( 9 | self, 10 | in_dim: int, 11 | hid_dims: List[int], 12 | dropout: float = 0.0, 13 | activation: str = "ReLU", 14 | activate_last: bool = False, 15 | ) -> None: 16 | super().__init__() 17 | 18 | self.in_dim = in_dim 19 | self.out_dim = hid_dims[-1] 20 | self.activate_last = activate_last 21 | 22 | assert len(hid_dims) > 0, len(hid_dims) 23 | hid_dims = [in_dim] + hid_dims 24 | 25 | self.layers = nn.ModuleList( 26 | [nn.Linear(hid_dims[i], hid_dims[i + 1]) for i in range(len(hid_dims) - 1)] 27 | ) 28 | self.num_layer = len(self.layers) 29 | self.dropout = nn.Dropout(dropout) 30 | n_acts = self.num_layer - (0 if self.activate_last else 1) 31 | self.acts = nn.ModuleList([getattr(nn, activation)() for _ in range(n_acts)]) 32 | 33 | def forward(self, x: torch.Tensor, x_len: torch.LongTensor = None) -> torch.Tensor: 34 | for i in range(self.num_layer): 35 | x = self.layers[i](x) 36 | if i < self.num_layer - 1 or self.activate_last: 37 | x = self.dropout(x) 38 | x = self.acts[i](x) 39 | return x 40 | -------------------------------------------------------------------------------- 
/src/nn/hubert.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from s3prl.upstream.hubert.convert import load_converted_model 7 | from s3prl.upstream.hubert.hubert_model import ( 8 | HubertConfig, 9 | HubertModel, 10 | HubertPretrainingConfig, 11 | ) 12 | from s3prl.upstream.utils import merge_with_parent 13 | from s3prl.upstream.wav2vec2.wav2vec2_model import GradMultiply 14 | from s3prl.util.download import _urls_to_filepaths 15 | from torch import nn 16 | 17 | from src.util import ( 18 | freeze_module, 19 | init_module_bert, 20 | init_module_cnn, 21 | init_module_pos_conv, 22 | padding_to_len, 23 | ) 24 | 25 | logger = logging.getLogger("hubert") 26 | 27 | 28 | def get_hubert_configs(ckpt_path): 29 | ckpt_state = torch.load(ckpt_path, map_location="cpu") 30 | 31 | for required_key in [ 32 | "task_cfg", 33 | "model_cfg", 34 | "model_weight", 35 | "dictionaries_symbols", 36 | ]: 37 | if required_key not in ckpt_state: 38 | raise ValueError( 39 | f"{ckpt_path} is not a valid checkpoint since the required key: {required_key} is missing" 40 | ) 41 | 42 | task_cfg = merge_with_parent(HubertPretrainingConfig, ckpt_state["task_cfg"]) 43 | model_cfg = merge_with_parent(HubertConfig, ckpt_state["model_cfg"]) 44 | dictionaries = ckpt_state["dictionaries_symbols"] 45 | return model_cfg, task_cfg, dictionaries 46 | 47 | 48 | def random_hubert( 49 | model_cfg: HubertConfig, 50 | task_cfg: HubertPretrainingConfig, 51 | dictionaries: List[Any], 52 | ): 53 | model = HubertModel(model_cfg, task_cfg, dictionaries) 54 | return model, model_cfg, task_cfg 55 | 56 | 57 | class HuBERT(nn.Module): 58 | def __init__( 59 | self, 60 | path_or_url: str = None, 61 | refresh: bool = False, 62 | pre_normalize: bool = False, 63 | normalize: bool = False, 64 | feat_select: str = "x", 65 | randomize_all: bool = False, 66 | randomize_layers: List[int] = [], 67 | freeze_all: bool = False, 68 | freeze_layers: List[int] = [], 69 | disable_pos: bool = False, 70 | masking: bool = False, 71 | ): 72 | super().__init__() 73 | 74 | ckpt = "https://huggingface.co/s3prl/converted_ckpts/resolve/main/hubert_base_ls960.pt" 75 | if path_or_url is not None: 76 | ckpt = path_or_url 77 | if ckpt.startswith("https"): 78 | ckpt = _urls_to_filepaths(ckpt, refresh=refresh) 79 | 80 | if not randomize_all: 81 | model, task_cfg = load_converted_model(ckpt) 82 | else: 83 | model_cfg, task_cfg, dictionaries = get_hubert_configs(ckpt) 84 | model, model_cfg, task_cfg = random_hubert( 85 | model_cfg, task_cfg, dictionaries 86 | ) 87 | logger.info("Random HuBERT used") 88 | 89 | self.model: HubertModel = model 90 | self.task_cfg = task_cfg 91 | self.wav_normalize = task_cfg.normalize 92 | self.pre_normalize = pre_normalize 93 | self.normalize = normalize 94 | self.masking = masking 95 | self.num_layers = len(self.model.encoder.layers) + 1 # CNN + 12 Transformer 96 | self.hidden_sizes = [self.model.encoder.embedding_dim] * self.num_layers 97 | self.model.encoder.layerdrop = 0.0 98 | 99 | self.feat_select = 0 100 | if feat_select == "att": 101 | self.feat_select = 1 102 | if feat_select == "ffn": 103 | self.feat_select = 2 104 | 105 | logger.info(f"Feature selection: {feat_select} (index: {self.feat_select})") 106 | logger.info( 107 | f"Randomize all = {randomize_all} (randomize layers = {randomize_layers})" 108 | ) 109 | logger.info(f"Freeze all = {freeze_all} (freeze layers = {freeze_layers})") 110 | 
logger.info(f"Apply masking to features = {masking}") 111 | 112 | self.randomize_all = randomize_all 113 | self.randomize_layers = randomize_layers 114 | self.freeze_all = freeze_all 115 | self.freeze_layers = freeze_layers 116 | 117 | self.freeze_cnn = (0 in freeze_layers) or self.freeze_all 118 | self.freeze_pos = ("pos" in freeze_layers) or self.freeze_all 119 | self.disable_pos = disable_pos 120 | if disable_pos: 121 | logger.info("Disabled CNN positional encoding") 122 | 123 | # Randomize weights 124 | if randomize_all: 125 | randomize_layers = [] 126 | if len(randomize_layers) > 0: 127 | for i in randomize_layers: 128 | if i == 0: 129 | self.model.feature_extractor.apply(init_module_cnn) 130 | self.model.layer_norm.reset_parameters() 131 | if self.model.post_extract_proj is not None: 132 | self.model.post_extract_proj.reset_parameters() 133 | elif i == "pos": 134 | self.model.encoder.pos_conv.apply(init_module_pos_conv) 135 | else: 136 | self.model.encoder.layers[i - 1].apply(init_module_bert) 137 | if i == self.num_layers - 1 and self.model.encoder.layer_norm_first: 138 | self.model.encoder.layer_norm.reset_parameters() 139 | 140 | # Freeze weights 141 | if freeze_all: 142 | freeze_module(self.model) 143 | elif len(freeze_layers) > 0: 144 | for i in freeze_layers: 145 | if i == 0: 146 | self.model.feature_grad_mult = 0.0 147 | freeze_module(self.model.feature_extractor) 148 | freeze_module(self.model.layer_norm) 149 | if self.model.post_extract_proj is not None: 150 | freeze_module(self.model.post_extract_proj) 151 | elif i == "pos": 152 | freeze_module(self.model.encoder.pos_conv) 153 | else: 154 | assert isinstance(i, int), i 155 | freeze_module(self.model.encoder.layers[i - 1]) 156 | 157 | if self.freeze_cnn: 158 | self.model.feature_grad_mult = 0.0 159 | 160 | self.model.remove_pretraining_modules() 161 | 162 | def trainable_parameters(self): 163 | params = [] 164 | 165 | if self.freeze_all: 166 | return [] 167 | 168 | if not self.freeze_all and len(self.freeze_layers) == 0: 169 | logger.info("Trains the entire model") 170 | return self.model.parameters() 171 | 172 | params = [] 173 | for i in ["pos"] + list(range(self.num_layers)): 174 | if i in self.freeze_layers: 175 | continue 176 | if i == 0: 177 | params += list(self.model.feature_extractor.parameters()) 178 | params += list(self.model.layer_norm.parameters()) 179 | if self.model.post_extract_proj is not None: 180 | params += list(self.model.post_extract_proj.parameters()) 181 | elif i == "pos": 182 | params += list(self.model.encoder.pos_conv.parameters()) 183 | else: 184 | params += list(self.model.encoder.layers[i - 1].parameters()) 185 | if i == self.num_layers - 1 and self.model.encoder.layer_norm_first: 186 | params += list(self.model.encoder.layer_norm.parameters()) 187 | 188 | if self.masking: 189 | params += list(self.model.mask_emb.parameters()) 190 | 191 | return params 192 | 193 | def forward_features(self, x: torch.Tensor) -> torch.Tensor: 194 | if self.model.feature_grad_mult > 0: 195 | features = self.model.feature_extractor(x) 196 | if self.model.feature_grad_mult != 1.0: 197 | features = GradMultiply.apply(features, self.model.feature_grad_mult) 198 | else: 199 | with torch.no_grad(): 200 | features = self.model.feature_extractor(x) 201 | 202 | return features 203 | 204 | def forward( 205 | self, 206 | wavs: torch.FloatTensor, 207 | padding_mask: torch.BoolTensor, 208 | ): 209 | if self.pre_normalize: 210 | with torch.no_grad(): 211 | wavs = (wavs - wavs[~padding_mask].mean()) / ( 212 | 
wavs[~padding_mask].std() + 1e-5 213 | ) 214 | 215 | if self.wav_normalize: 216 | with torch.no_grad(): 217 | wav_len = padding_to_len(padding_mask) 218 | for i in range(len(wavs)): 219 | wavs[i, : wav_len[i]] = F.layer_norm( 220 | wavs[i, : wav_len[i]], (wav_len[i],) 221 | ) 222 | 223 | features = self.forward_features(wavs).transpose(1, 2) 224 | if self.freeze_cnn: 225 | with torch.no_grad(): 226 | features = self.model.layer_norm(features) 227 | if self.model.post_extract_proj is not None: 228 | features = self.model.post_extract_proj(features) 229 | else: 230 | features = self.model.layer_norm(features) 231 | if self.model.post_extract_proj is not None: 232 | features = self.model.post_extract_proj(features) 233 | 234 | if padding_mask is not None: 235 | padding_mask = self.model.forward_padding_mask(features, padding_mask) 236 | 237 | x = self.model.dropout_input(features) 238 | if self.masking and self.training: 239 | x, mask_indices = self.model.apply_mask(x, padding_mask, None) 240 | 241 | # feature: (B, T, D), float 242 | # x: (B, T, D), float 243 | # padding_mask: (B, T), bool 244 | 245 | if self.freeze_all: 246 | with torch.no_grad(): 247 | _, layer_results = self.model.encoder( 248 | x, 249 | padding_mask=padding_mask, 250 | disable_pos=self.disable_pos, 251 | ) 252 | else: 253 | _, layer_results = self.model.encoder( 254 | x, 255 | padding_mask=padding_mask, 256 | freeze_pos=self.freeze_pos, 257 | freeze_layers=self.freeze_layers, 258 | disable_pos=self.disable_pos, 259 | ) 260 | 261 | feat_list = [features] + [ 262 | feat[self.feat_select].transpose(0, 1) for feat in layer_results 263 | ] 264 | 265 | if self.normalize: 266 | feat_list = [F.layer_norm(feat, feat.shape[-1:]) for feat in feat_list] 267 | 268 | feat_len = padding_to_len(padding_mask) 269 | 270 | return feat_list, feat_len, padding_mask 271 | -------------------------------------------------------------------------------- /src/nn/swav_vq_dis.py: -------------------------------------------------------------------------------- 1 | # Ref: https://github.com/facebookresearch/swav/blob/main/main_swav.py 2 | 3 | import logging 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | logger = logging.getLogger("swav_vq_dis") 11 | 12 | 13 | @torch.no_grad() 14 | def compute_sinkhorn( 15 | out: torch.Tensor, epsilon: float, sinkhorn_iterations: int 16 | ) -> torch.Tensor: 17 | # out: (B, K) 18 | B, K = out.shape 19 | Q = out.div(epsilon).exp().t() 20 | # Q is K-by-B for consistency with notations from our paper 21 | 22 | # make the matrix sums to 1 23 | if dist.is_initialized(): 24 | sum_Q = Q.sum() 25 | dist.all_reduce(sum_Q) 26 | Q.div_(sum_Q) 27 | else: 28 | Q.div_(Q.sum()) 29 | 30 | for _ in range(sinkhorn_iterations): 31 | # normalize each row: total weight per prototype must be 1/K 32 | if dist.is_initialized(): 33 | sum_of_rows = Q.sum(dim=1, keepdim=True) 34 | dist.all_reduce(sum_of_rows) 35 | Q.div_(sum_of_rows * K) 36 | else: 37 | Q.div_(Q.sum(dim=1, keepdim=True) * K) 38 | 39 | # normalize each column: total weight per sample must be 1/B 40 | Q.div_(Q.sum(dim=0, keepdim=True) * B) 41 | 42 | Q.mul_(B) # the colomns must sum to 1 so that Q is an assignment 43 | return Q.t() 44 | 45 | 46 | class SwavVQDisentangle(nn.Module): 47 | def __init__( 48 | self, 49 | dim: int, 50 | num_vars: int, 51 | epsilon: float = 0.05, 52 | sinkhorn_iters: int = 3, 53 | temp: float = 0.1, 54 | l2_norm: bool = True, 55 | hard_target: bool = False, 56 | prob_ratio: 
float = 1.0, 57 | ) -> None: 58 | super().__init__() 59 | 60 | self.dim = dim 61 | self.num_vars = num_vars 62 | self.epsilon = epsilon 63 | self.sinkhorn_iters = sinkhorn_iters 64 | self.temp = temp 65 | self.l2_norm = l2_norm 66 | self.hard_target = hard_target 67 | self.prob_ratio = prob_ratio 68 | 69 | logger.info(f"Codebook size: {num_vars}") 70 | self.codebook = nn.Linear(dim, num_vars, bias=False) 71 | 72 | def produce_targets(self, z: torch.Tensor, normalized: bool = False): 73 | if self.l2_norm and not normalized: 74 | z = F.normalize(z, dim=-1) 75 | logits = self.codebook(z) / self.temp 76 | codes = torch.argmax(logits, -1) 77 | return logits, codes 78 | 79 | @torch.no_grad() 80 | def normalize_codebook(self) -> None: 81 | w = self.codebook.weight.data.clone() 82 | w = F.normalize(w, dim=1, p=2) 83 | self.codebook.weight.copy_(w) 84 | 85 | @torch.no_grad() 86 | def zero_grad_codebook(self) -> None: 87 | self.codebook.zero_grad() 88 | 89 | @torch.no_grad() 90 | def copy_codebook(self) -> None: 91 | self.normalize_codebook() 92 | self.codebook_copy = self.codebook.weight.data.detach() 93 | 94 | @torch.no_grad() 95 | def restore_codebook(self) -> None: 96 | self.codebook.weight.copy_(self.codebook_copy) 97 | 98 | def forward( 99 | self, 100 | z_1: torch.Tensor, 101 | z_2: torch.Tensor, 102 | ): 103 | B = len(z_1) 104 | 105 | if self.l2_norm: 106 | z_1 = F.normalize(z_1, dim=1) 107 | z_2 = F.normalize(z_2, dim=1) 108 | 109 | logits_1: torch.Tensor = self.codebook(z_1) # (Batch, Num Codes) 110 | logits_2: torch.Tensor = self.codebook(z_2) # (Batch, Num Codes) 111 | 112 | # Compute targets 113 | with torch.no_grad(): 114 | tgt_logits_w_1 = logits_1 * self.prob_ratio + logits_2 * ( 115 | 1 - self.prob_ratio 116 | ) 117 | tgt_logits_w_2 = logits_2 * self.prob_ratio + logits_1 * ( 118 | 1 - self.prob_ratio 119 | ) 120 | 121 | tgt_probs_1 = compute_sinkhorn( 122 | tgt_logits_w_1.detach(), self.epsilon, self.sinkhorn_iters 123 | ) 124 | tgt_probs_2 = compute_sinkhorn( 125 | tgt_logits_w_2.detach(), self.epsilon, self.sinkhorn_iters 126 | ) 127 | 128 | # Compute cross-entropy loss 129 | logits_1.div_(self.temp) 130 | logits_2.div_(self.temp) 131 | log_prob_1 = logits_1.log_softmax(1) 132 | log_prob_2 = logits_2.log_softmax(1) 133 | 134 | loss = 0 135 | if self.hard_target: 136 | loss_ce = 0.5 * ( 137 | F.cross_entropy( 138 | logits_2, 139 | tgt_probs_1.argmax(-1), 140 | ) 141 | + F.cross_entropy( 142 | logits_1, 143 | tgt_probs_2.argmax(-1), 144 | ) 145 | ) 146 | else: 147 | loss_ce = -0.5 * ( 148 | (tgt_probs_1 * log_prob_2).sum(1).mean() 149 | + (tgt_probs_2 * log_prob_1).sum(1).mean() 150 | ) 151 | 152 | loss += loss_ce 153 | result = {"loss_ce": loss_ce, "batch_size": B} 154 | result["loss"] = loss 155 | 156 | with torch.no_grad(): 157 | logits = torch.cat([logits_1, logits_2], dim=0) 158 | _, k = logits.max(-1) 159 | hard_x = logits.new_zeros(*logits.shape).scatter_(1, k.view(-1, 1), 1.0) 160 | hard_probs = hard_x.float().mean(0) 161 | result["code_perplexity"] = ( 162 | torch.exp(-torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)) 163 | .sum() 164 | .cpu() 165 | .detach() 166 | ) 167 | 168 | avg_probs = logits.float().softmax(-1).mean(0) 169 | result["prob_perplexity"] = ( 170 | torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)) 171 | .sum() 172 | .cpu() 173 | .detach() 174 | ) 175 | acc_1 = ( 176 | (torch.argmax(logits_1, dim=1) == torch.argmax(tgt_probs_2, dim=1)) 177 | .float() 178 | .mean() 179 | .cpu() 180 | .detach() 181 | .item() 182 | ) 183 | 
acc_2 = ( 184 | (torch.argmax(logits_2, dim=1) == torch.argmax(tgt_probs_1, dim=1)) 185 | .float() 186 | .mean() 187 | .cpu() 188 | .detach() 189 | .item() 190 | ) 191 | result["acc"] = float((acc_1 + acc_2) / 2) 192 | result["acc_1"] = float(acc_1) 193 | result["acc_2"] = float(acc_2) 194 | 195 | return result 196 | 197 | def cal_loss( 198 | self, 199 | z_1: torch.Tensor, 200 | z_2: torch.Tensor, 201 | ): 202 | return self.forward(z_1, z_2) 203 | -------------------------------------------------------------------------------- /src/nn/wavlm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from s3prl.upstream.wavlm.modules import GradMultiply 7 | from s3prl.upstream.wavlm.WavLM import WavLM as WavLMModel 8 | from s3prl.upstream.wavlm.WavLM import WavLMConfig 9 | from s3prl.util.download import _urls_to_filepaths 10 | from torch import nn 11 | 12 | from src.util import ( 13 | freeze_module, 14 | init_module_bert, 15 | init_module_cnn, 16 | init_module_pos_conv, 17 | padding_to_len, 18 | ) 19 | 20 | logger = logging.getLogger("wavlm") 21 | 22 | 23 | class WavLM(nn.Module): 24 | def __init__( 25 | self, 26 | path_or_url: str = None, 27 | refresh: bool = False, 28 | pre_normalize: bool = False, 29 | normalize: bool = False, 30 | feat_select: str = "x", 31 | randomize_all: bool = False, 32 | randomize_layers: List[int] = [], 33 | freeze_all: bool = False, 34 | freeze_layers: List[int] = [], 35 | ): 36 | super().__init__() 37 | 38 | ckpt = "https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_base.pt" 39 | if path_or_url is not None: 40 | ckpt = path_or_url 41 | if ckpt.startswith("https"): 42 | ckpt = _urls_to_filepaths(ckpt, refresh=refresh) 43 | 44 | checkpoint = torch.load(ckpt) 45 | self.cfg = WavLMConfig(checkpoint["cfg"]) 46 | self.model = WavLMModel(self.cfg) 47 | self.model.load_state_dict(checkpoint["model"]) 48 | 49 | self.wav_normalize = self.cfg.normalize 50 | self.pre_normalize = pre_normalize 51 | self.normalize = normalize 52 | self.num_layers = self.cfg.encoder_layers + 1 # CNN + 12 Transformer 53 | self.hidden_sizes = [self.cfg.encoder_embed_dim] * self.num_layers 54 | 55 | self.model.feature_grad_mult = 0.0 56 | self.model.encoder.layerdrop = 0.0 57 | 58 | self.feat_select = 0 59 | if feat_select == "att": 60 | self.feat_select = 1 61 | 62 | logger.info(f"Feature selection: {feat_select}") 63 | logger.info( 64 | f"Randomize all = {randomize_all} (randomize layers = {randomize_layers})" 65 | ) 66 | logger.info(f"Freeze all = {freeze_all} (freeze layers = {freeze_layers})") 67 | 68 | self.randomize_all = randomize_all 69 | self.randomize_layers = randomize_layers 70 | self.freeze_all = freeze_all 71 | self.freeze_layers = freeze_layers 72 | 73 | self.freeze_cnn = (0 in freeze_layers) or self.freeze_all 74 | self.freeze_pos = ("pos" in freeze_layers) or self.freeze_all 75 | 76 | # Randomize weights 77 | if randomize_all: 78 | randomize_layers = ["pos"] + list(range(self.num_layers)) 79 | if len(randomize_layers) > 0: 80 | for i in randomize_layers: 81 | if i == 0: 82 | self.model.feature_extractor.apply(init_module_cnn) 83 | self.model.layer_norm.reset_parameters() 84 | if self.model.post_extract_proj is not None: 85 | self.model.post_extract_proj.reset_parameters() 86 | elif i == "pos": 87 | self.model.encoder.pos_conv.apply(init_module_pos_conv) 88 | else: 89 | self.model.encoder.layers[i - 1].apply(init_module_bert) 90 | if i 
== self.num_layers - 1 and self.model.encoder.layer_norm_first: 91 | self.model.encoder.layer_norm.reset_parameters() 92 | 93 | # Freeze weights 94 | if freeze_all: 95 | freeze_module(self.model) 96 | elif len(freeze_layers) > 0: 97 | for i in freeze_layers: 98 | if i == 0: 99 | self.model.feature_grad_mult = 0.0 100 | freeze_module(self.model.feature_extractor) 101 | freeze_module(self.model.layer_norm) 102 | if self.model.post_extract_proj is not None: 103 | freeze_module(self.model.post_extract_proj) 104 | elif i == "pos": 105 | freeze_module(self.model.encoder.pos_conv) 106 | else: 107 | assert isinstance(i, int), i 108 | freeze_module(self.model.encoder.layers[i - 1]) 109 | 110 | if not self.freeze_cnn: 111 | self.model.feature_grad_mult = 1.0 112 | 113 | def trainable_parameters(self): 114 | params = [] 115 | 116 | if self.freeze_all: 117 | return [] 118 | 119 | if not self.freeze_all and len(self.freeze_layers) == 0: 120 | logger.info("Trains the entire model") 121 | return self.model.parameters() 122 | 123 | params = [] 124 | for i in ["pos"] + list(range(self.num_layers)): 125 | if i in self.freeze_layers: 126 | continue 127 | if i == 0: 128 | params += list(self.model.feature_extractor.parameters()) 129 | params += list(self.model.layer_norm.parameters()) 130 | if self.model.post_extract_proj is not None: 131 | params += list(self.model.post_extract_proj.parameters()) 132 | elif i == "pos": 133 | params += list(self.model.encoder.pos_conv.parameters()) 134 | else: 135 | params += list(self.model.encoder.layers[i - 1].parameters()) 136 | if i == self.num_layers - 1 and self.model.encoder.layer_norm_first: 137 | params += list(self.model.encoder.layer_norm.parameters()) 138 | 139 | return params 140 | 141 | def forward_features(self, x: torch.Tensor) -> torch.Tensor: 142 | if self.model.feature_grad_mult > 0: 143 | features = self.model.feature_extractor(x) 144 | if self.feature_grad_mult != 1.0: 145 | features = GradMultiply.apply(features, self.model.feature_grad_mult) 146 | else: 147 | with torch.no_grad(): 148 | features = self.model.feature_extractor(x) 149 | 150 | return features 151 | 152 | def forward( 153 | self, 154 | wavs: torch.FloatTensor, 155 | padding_mask: torch.BoolTensor, 156 | ): 157 | if self.pre_normalize: 158 | with torch.no_grad(): 159 | wavs = (wavs - wavs[~padding_mask].mean()) / ( 160 | wavs[~padding_mask].std() + 1e-5 161 | ) 162 | 163 | if self.wav_normalize: 164 | with torch.no_grad(): 165 | wav_len = padding_to_len(padding_mask) 166 | for i in range(len(wavs)): 167 | wavs[i, : wav_len[i]] = F.layer_norm( 168 | wavs[i, : wav_len[i]], (wav_len[i],) 169 | ) 170 | 171 | features = self.forward_features(wavs).transpose(1, 2) 172 | if self.freeze_cnn: 173 | with torch.no_grad(): 174 | features = self.model.layer_norm(features) 175 | if self.model.post_extract_proj is not None: 176 | features = self.model.post_extract_proj(features) 177 | else: 178 | features = self.model.layer_norm(features) 179 | if self.model.post_extract_proj is not None: 180 | features = self.model.post_extract_proj(features) 181 | 182 | if padding_mask is not None: 183 | padding_mask = self.model.forward_padding_mask(features, padding_mask) 184 | 185 | x = self.model.dropout_input(features) 186 | 187 | # feature: (B, T, D), float 188 | # x: (B, T, D), float 189 | # padding_mask: (B, T), bool 190 | 191 | if self.freeze_all: 192 | with torch.no_grad(): 193 | _, layer_results = self.model.encoder( 194 | x, 195 | padding_mask=padding_mask, 196 | ) 197 | else: 198 | _, layer_results = 
self.model.encoder( 199 | x, 200 | padding_mask=padding_mask, 201 | freeze_pos=self.freeze_pos, 202 | freeze_layers=self.freeze_layers, 203 | ) 204 | 205 | feat_list = [features] + [ 206 | feat[self.feat_select].transpose(0, 1) for feat in layer_results 207 | ] 208 | 209 | if self.normalize: 210 | feat_list = [F.layer_norm(feat, feat.shape[-1:]) for feat in feat_list] 211 | 212 | feat_len = padding_to_len(padding_mask) 213 | 214 | return feat_list, feat_len, padding_mask 215 | -------------------------------------------------------------------------------- /src/task/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_spin import SpinPretrainTask 2 | -------------------------------------------------------------------------------- /src/task/train_spin.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import yaml 5 | from pytorch_lightning import Trainer, seed_everything 6 | from pytorch_lightning.callbacks import ( 7 | LearningRateMonitor, 8 | ModelCheckpoint, 9 | TQDMProgressBar, 10 | ) 11 | from torch.utils.data import DataLoader 12 | 13 | from src.data import AudioPretrainPnmiValDataset, val_collate_fn 14 | from src.model import SpinModel 15 | from src.util import set_logging, set_pl_logger 16 | 17 | 18 | class SpinPretrainTask: 19 | def __init__(self): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("task", help="Task name") 22 | parser.add_argument("--config", "-c", help="Config .yaml file") 23 | parser.add_argument("--save-path", "-s", help="Path to save exp") 24 | parser.add_argument("--resume", "-r", default="", help="Resume training") 25 | parser.add_argument("--gpus", "-g", type=int, default=1, help="Number of GPUs") 26 | parser.add_argument( 27 | "--njobs", "-j", type=int, default=8, help="Number of workers" 28 | ) 29 | parser.add_argument("--seed", type=int, default=7122, help="Random seed") 30 | parser.add_argument("--log-level", default="info", help="Logging level") 31 | args = parser.parse_args() 32 | 33 | if not torch.cuda.is_available(): 34 | args.device = "cpu" 35 | args.gpus = 0 36 | else: 37 | args.device = "cuda" if args.gpus > 0 else "cpu" 38 | 39 | self.args = args 40 | set_logging(args.log_level) 41 | 42 | def run(self, model_cls=SpinModel): 43 | assert isinstance(self.args, argparse.Namespace) 44 | 45 | config = yaml.load(open(self.args.config, "r"), Loader=yaml.FullLoader) 46 | self.config = config 47 | 48 | use_ddp = ( 49 | config["trainer"].get("strategy", "").startswith("ddp") 50 | and self.args.gpus > 1 51 | ) 52 | 53 | if self.args.save_path != "": 54 | config["trainer"]["default_root_dir"] = self.args.save_path 55 | 56 | model_checkpoint = ModelCheckpoint( 57 | dirpath=config["trainer"]["default_root_dir"], **config["checkpoint"] 58 | ) 59 | 60 | config["trainer"]["logger"] = set_pl_logger( 61 | config["trainer"]["logger"], 62 | config["logger"]["project"], 63 | config["trainer"]["default_root_dir"].split("/")[-1], 64 | ) 65 | 66 | trainer = Trainer( 67 | callbacks=[ 68 | TQDMProgressBar(), 69 | model_checkpoint, 70 | LearningRateMonitor("step"), 71 | ], 72 | enable_progress_bar=True, 73 | devices=self.args.gpus, 74 | check_val_every_n_epoch=None, 75 | use_distributed_sampler=False, 76 | sync_batchnorm=use_ddp, 77 | **config["trainer"], 78 | ) 79 | 80 | seed_everything(self.args.seed) 81 | 82 | if config.get("val_data", None) is not None: 83 | val_dataset = AudioPretrainPnmiValDataset(**config["val_data"]) 
84 | val_loader = DataLoader( 85 | val_dataset, 86 | batch_size=config["hparam"]["val_batch_size"], 87 | num_workers=self.args.njobs, 88 | pin_memory=True, 89 | collate_fn=val_collate_fn, 90 | shuffle=False, 91 | drop_last=False, 92 | ) 93 | else: 94 | val_dataset = None 95 | val_loader = None 96 | 97 | if self.args.resume != "": 98 | model = model_cls.load_from_checkpoint(self.args.resume) 99 | else: 100 | self.args.resume = None 101 | model = model_cls(config, 2) 102 | 103 | model.set_random_seed(self.args.seed) 104 | model.set_njobs(self.args.njobs) 105 | model.set_use_ddp(use_ddp) 106 | 107 | trainer.fit(model, val_dataloaders=val_loader, ckpt_path=self.args.resume) 108 | -------------------------------------------------------------------------------- /src/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .log import set_logging, set_pl_logger 2 | from .model_utils import ( 3 | count_parameters, 4 | freeze_module, 5 | init_module, 6 | init_module_bert, 7 | init_module_cnn, 8 | init_module_pos_conv, 9 | unfreeze_module, 10 | ) 11 | from .padding import ( 12 | add_front_padding_mask, 13 | len_to_padding, 14 | padding_to_len, 15 | update_padding_mask, 16 | ) 17 | from .pnmi import compute_show_pnmi, compute_snmi 18 | from .scheduler import get_scheduler 19 | -------------------------------------------------------------------------------- /src/util/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Union 3 | 4 | from pytorch_lightning.loggers import WandbLogger 5 | 6 | 7 | def set_logging(log_level: str = "info") -> None: 8 | level = getattr(logging, str(log_level).upper()) 9 | logging.basicConfig( 10 | level=level, 11 | format="%(asctime)s %(filename)s.%(funcName)s %(message)s", 12 | datefmt="%m-%d %H:%M", 13 | ) 14 | 15 | 16 | def set_pl_logger( 17 | logger_type: Union[bool, str], 18 | project: str = "speech_disentangle", 19 | name: str = "example", 20 | ): 21 | if isinstance(logger_type, bool): 22 | return logger_type 23 | elif logger_type == "wandb": 24 | logger = WandbLogger(project=project, name=name) 25 | return logger 26 | else: 27 | raise NotImplementedError(f"Unknown logger type = {logger_type}") 28 | -------------------------------------------------------------------------------- /src/util/model_utils.py: -------------------------------------------------------------------------------- 1 | from s3prl.upstream.wav2vec2.wav2vec2_model import MultiheadAttention 2 | from torch import nn 3 | 4 | 5 | def freeze_module(m: nn.Module) -> None: 6 | for p in m.parameters(): 7 | p.requires_grad = False 8 | 9 | 10 | def unfreeze_module(m: nn.Module) -> None: 11 | for p in m.parameters(): 12 | p.requires_grad = True 13 | 14 | 15 | def init_module(m: nn.Module): 16 | for p in m.parameters(): 17 | nn.init.normal_(p, mean=0, std=0.02) 18 | 19 | 20 | def init_module_bert(m: nn.Module): 21 | def normal_(data): 22 | # with FSDP, module params will be on CUDA, so we cast them back to CPU 23 | # so that the RNG is consistent with and without FSDP 24 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) 25 | 26 | if isinstance(m, nn.Linear): 27 | normal_(m.weight.data) 28 | if m.bias is not None: 29 | m.bias.data.zero_() 30 | if isinstance(m, nn.Embedding): 31 | normal_(m.weight.data) 32 | if m.padding_idx is not None: 33 | m.weight.data[m.padding_idx].zero_() 34 | if isinstance(m, MultiheadAttention): 35 | normal_(m.q_proj.weight.data) 36 | 
normal_(m.k_proj.weight.data) 37 | normal_(m.v_proj.weight.data) 38 | 39 | 40 | def init_module_cnn(m: nn.Module): 41 | if isinstance(m, nn.Conv1d): 42 | nn.init.kaiming_normal_(m.weight) 43 | if isinstance(m, nn.LayerNorm): 44 | m.reset_parameters() 45 | 46 | 47 | def init_module_pos_conv(m: nn.Module): 48 | if isinstance(m, nn.Conv1d): 49 | m.reset_parameters() 50 | if isinstance(m, nn.LayerNorm): 51 | m.reset_parameters() 52 | 53 | 54 | def count_parameters(model: nn.Module) -> int: 55 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 56 | -------------------------------------------------------------------------------- /src/util/padding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.no_grad() 5 | def len_to_padding(x_len: torch.LongTensor, max_len: int = 0) -> torch.BoolTensor: 6 | if max_len == 0: 7 | max_len = max(x_len) 8 | idxs = torch.arange(max_len, dtype=torch.long).to(x_len.device) 9 | padding_mask = idxs.unsqueeze(0) >= x_len.unsqueeze(1) 10 | return padding_mask 11 | 12 | 13 | @torch.no_grad() 14 | def padding_to_len(padding_mask: torch.BoolTensor) -> torch.LongTensor: 15 | x_len = (~padding_mask).long().sum(-1) 16 | return x_len 17 | 18 | 19 | @torch.no_grad() 20 | def update_padding_mask( 21 | padding_mask: torch.BoolTensor, new_len: int 22 | ) -> torch.BoolTensor: 23 | extra = padding_mask.shape[1] % new_len 24 | if extra > 0: 25 | padding_mask = padding_mask[:, :-extra] 26 | padding_mask = padding_mask.view(padding_mask.shape[0], new_len, -1) 27 | padding_mask = padding_mask.all(-1) 28 | return padding_mask 29 | 30 | 31 | @torch.no_grad() 32 | def add_front_padding_mask( 33 | padding_mask: torch.BoolTensor, pad_front_lens: torch.LongTensor 34 | ) -> None: 35 | for i in range(len(padding_mask)): 36 | if pad_front_lens[i] > 0: 37 | padding_mask[i, : pad_front_lens[i]] = True 38 | -------------------------------------------------------------------------------- /src/util/pnmi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from collections import Counter 7 | 8 | import numpy as np 9 | from tabulate import tabulate 10 | 11 | 12 | def comp_purity(p_xy, axis): 13 | max_p = p_xy.max(axis=axis) 14 | marg_p = p_xy.sum(axis=axis) 15 | indv_pur = max_p / marg_p 16 | aggr_pur = max_p.sum() 17 | return indv_pur, aggr_pur 18 | 19 | 20 | def comp_entropy(p): 21 | return (-p * np.log(p + 1e-8)).sum() 22 | 23 | 24 | def comp_norm_mutual_info(p_xy): 25 | p_x = p_xy.sum(axis=1, keepdims=True) 26 | p_y = p_xy.sum(axis=0, keepdims=True) 27 | pmi = np.log(p_xy / np.matmul(p_x, p_y) + 1e-8) 28 | mi = (p_xy * pmi).sum() 29 | h_x = comp_entropy(p_x) 30 | h_y = comp_entropy(p_y) 31 | return mi, mi / h_x, mi / h_y, h_x, h_y 32 | 33 | 34 | def pad(labs, n): 35 | if n == 0: 36 | return np.array(labs) 37 | return np.concatenate([[labs[0]] * n, labs, [labs[-1]] * n]) 38 | 39 | 40 | def comp_avg_seg_dur(labs_list): 41 | n_frms = 0 42 | n_segs = 0 43 | for labs in labs_list: 44 | labs = np.array(labs) 45 | edges = np.zeros(len(labs)).astype(bool) 46 | edges[0] = True 47 | edges[1:] = labs[1:] != labs[:-1] 48 | n_frms += len(edges) 49 | n_segs += edges.astype(int).sum() 50 | return n_frms / n_segs 51 | 52 | 53 | def comp_joint_prob(uid2refs, uid2hyps): 54 | cnts = Counter() 55 | skipped = [] 56 | abs_frmdiff = 0 57 | for uid in uid2refs: 58 | if uid not in uid2hyps: 59 | skipped.append(uid) 60 | continue 61 | refs = uid2refs[uid] 62 | hyps = uid2hyps[uid] 63 | abs_frmdiff += abs(len(refs) - len(hyps)) 64 | min_len = min(len(refs), len(hyps)) 65 | refs = refs[:min_len] 66 | hyps = hyps[:min_len] 67 | cnts.update(zip(refs, hyps)) 68 | tot = sum(cnts.values()) 69 | 70 | ref_set = sorted({ref for ref, _ in cnts.keys()}) 71 | hyp_set = sorted({hyp for _, hyp in cnts.keys()}) 72 | ref2pid = dict(zip(ref_set, range(len(ref_set)))) 73 | hyp2lid = dict(zip(hyp_set, range(len(hyp_set)))) 74 | 75 | p_xy = np.zeros((len(ref2pid), len(hyp2lid)), dtype=float) 76 | for (ref, hyp), cnt in cnts.items(): 77 | p_xy[ref2pid[ref], hyp2lid[hyp]] = cnt 78 | freq_xy = p_xy 79 | full_freq_xy = np.zeros((len(ref2pid), 4096), dtype=float) 80 | for (ref, hyp), cnt in cnts.items(): 81 | full_freq_xy[ref2pid[ref], int(hyp)] = cnt 82 | p_xy = p_xy / p_xy.sum() 83 | return ( 84 | freq_xy, 85 | full_freq_xy, 86 | p_xy, 87 | ref2pid, 88 | hyp2lid, 89 | tot, 90 | abs_frmdiff, 91 | skipped, 92 | ref_set, 93 | hyp_set, 94 | ) 95 | 96 | 97 | def comp_phone2code(p_xy): 98 | p_x = p_xy.sum(axis=1, keepdims=True) # ref (phone) 99 | p_y = p_xy.sum(axis=0, keepdims=True) # hyp (code) 100 | 101 | p_x_y = p_xy / p_y # P(x | y) = P(phone | code) 102 | 103 | y_order = np.argsort(p_x_y.argmax(0)) 104 | p_x_y_sorted_y = np.take_along_axis(p_x_y, y_order.reshape((1, -1)), axis=1) 105 | 106 | x_order = np.argsort(p_x[:, 0]) 107 | x_order = np.flip(x_order) 108 | p_x_y_sorted_x = np.take_along_axis(p_x_y, x_order.reshape((-1, 1)), axis=0) 109 | y_order = np.argsort(p_x_y_sorted_x.argmax(0)) 110 | p_x_y_sorted_xy = np.take_along_axis( 111 | p_x_y_sorted_x, y_order.reshape((1, -1)), axis=1 112 | ) 113 | 114 | return p_x_y, p_x_y_sorted_xy, p_x_y_sorted_y, x_order 115 | 116 | 117 | def compute_show_pnmi(uid2refs, uid2hyps, upsample=1, show_results: bool = False): 118 | for k, v in uid2hyps.items(): 119 | uid2hyps[k] = pad(v, 0).repeat(upsample) 120 | 121 | ( 122 | freq_xy, 123 | full_freq_xy, 124 | p_xy, 125 | ref2pid, 126 | hyp2lid, 127 | tot, 128 | frmdiff, 129 | skipped, 130 | ref_set, 131 | hyp_set, 132 | ) = comp_joint_prob(uid2refs, uid2hyps) 133 | ref_pur_by_hyp, ref_pur = 
comp_purity(p_xy, axis=0) 134 | hyp_pur_by_ref, hyp_pur = comp_purity(p_xy, axis=1) 135 | (mi, mi_norm_by_ref, mi_norm_by_hyp, h_ref, h_hyp) = comp_norm_mutual_info(p_xy) 136 | 137 | if show_results: 138 | print( 139 | tabulate( 140 | [[hyp_pur, ref_pur, mi_norm_by_ref]], 141 | ["Cls Pur", "Phn Pur", "PNMI"], 142 | floatfmt=".3f", 143 | tablefmt="fancy_grid", 144 | ) 145 | ) 146 | 147 | return { 148 | "cls_pur": hyp_pur, 149 | "phn_pur": ref_pur, 150 | "pnmi": mi_norm_by_ref, 151 | } 152 | 153 | 154 | def compute_snmi(p_xy): 155 | _, ref_pur = comp_purity(p_xy, axis=0) 156 | _, hyp_pur = comp_purity(p_xy, axis=1) 157 | (_, mi_norm_by_ref, _, _, _) = comp_norm_mutual_info(p_xy) 158 | 159 | return { 160 | "cls_pur": hyp_pur, 161 | "spk_pur": ref_pur, 162 | "snmi": mi_norm_by_ref, 163 | } 164 | -------------------------------------------------------------------------------- /src/util/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch.optim import Optimizer 4 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler 5 | 6 | 7 | def get_lr(optimizer: Optimizer) -> float: 8 | for param_group in optimizer.param_groups: 9 | return param_group["lr"] 10 | 11 | 12 | def noam_scheduler( 13 | optimizer: Optimizer, warmup: int = 4000, last_epoch: int = -1 14 | ) -> _LRScheduler: 15 | def func(step: int): 16 | if step < warmup: 17 | return (step + 1) / warmup 18 | else: 19 | return (warmup / (step + 1)) ** 0.5 20 | 21 | return LambdaLR(optimizer, func, last_epoch) 22 | 23 | 24 | def linear_warmup_decay_scheduler( 25 | optimizer: Optimizer, 26 | warmup: int = 4000, 27 | max_step: int = 1000000, 28 | init_lr: float = 1e-6, 29 | final_lr: float = 1e-6, 30 | ) -> _LRScheduler: 31 | func_list = [] 32 | 33 | for param_group in optimizer.param_groups: 34 | base_lr = param_group["lr"] 35 | rate_i = init_lr / base_lr 36 | rate_f = final_lr / base_lr 37 | 38 | def func(step: int) -> float: 39 | if step <= warmup: 40 | return rate_i + (1.0 - rate_i) * step / warmup 41 | else: 42 | return 1.0 - (1.0 - rate_f) * (step - warmup) / (max_step - warmup - 1) 43 | 44 | func_list.append(func) 45 | 46 | return LambdaLR(optimizer, func_list) 47 | 48 | 49 | def linear_warmup_cosine_scheduler( 50 | optimizer: Optimizer, 51 | warmup: int = 4000, 52 | max_step: int = 1000000, 53 | final_lr: float = 1e-6, 54 | ) -> _LRScheduler: 55 | func_list = [] 56 | 57 | for param_group in optimizer.param_groups: 58 | base_lr = param_group["lr"] 59 | rate = final_lr / base_lr 60 | 61 | def func(step: int) -> float: 62 | if step < warmup: 63 | return (step + 1) / warmup 64 | else: 65 | q = 0.5 * ( 66 | 1 + math.cos(math.pi * (step + 1 - warmup) / (max_step - warmup)) 67 | ) 68 | return (1.0 - rate) * q + rate 69 | 70 | func_list.append(func) 71 | 72 | return LambdaLR(optimizer, func_list) 73 | 74 | 75 | def get_scheduler(name: str, optimizer: Optimizer, **kwargs) -> _LRScheduler: 76 | if name == "noam": 77 | return noam_scheduler(optimizer, **kwargs) 78 | elif name == "linear_warmup_decay": 79 | return linear_warmup_decay_scheduler(optimizer, **kwargs) 80 | elif name == "linear_warmup_cosine": 81 | return linear_warmup_cosine_scheduler(optimizer, **kwargs) 82 | else: 83 | raise NotImplementedError(f"Unknown lr scheduler {name}") 84 | --------------------------------------------------------------------------------
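The schedulers in `src/util/scheduler.py` are thin `LambdaLR` wrappers, so they can be exercised in isolation. Below is a minimal usage sketch (not part of the repository) that wires `get_scheduler` to a throwaway `AdamW` optimizer and steps it a few times; the model, base learning rate, and step counts are placeholder values chosen purely for illustration.

```python
import torch
from torch import nn

from src.util.scheduler import get_lr, get_scheduler

# Placeholder model and optimizer; the base lr (1e-4) is what the schedule
# multiplies, while init_lr/final_lr bound the warmup start and decay end.
model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

scheduler = get_scheduler(
    "linear_warmup_decay",
    optimizer,
    warmup=100,      # linear warmup from init_lr to the base lr
    max_step=1000,   # linear decay down to final_lr afterwards
    init_lr=1e-6,
    final_lr=1e-6,
)

for step in range(300):
    optimizer.step()   # update parameters first, as PyTorch expects
    scheduler.step()   # then advance the lr schedule by one step
    if step % 100 == 0:
        print(f"step {step:4d}  lr = {get_lr(optimizer):.2e}")
```

Since each factory builds one lambda per optimizer parameter group, the same call works unchanged when the optimizer is constructed from `trainable_parameters()` with several groups, as done in the training task above.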