├── .gitignore ├── LICENSE ├── README.md ├── convert_hubert_from_fairseq.py ├── convert_hubert_from_hf.py ├── convert_hubert_large_from_fairseq.py ├── convert_wav2vec2_from_fairseq.py ├── convert_wavlm_from_hf.py ├── dataset ├── __init__.py └── audio_dataset.py ├── distill.py ├── final_distill.py ├── imgs ├── dphubert-results.png ├── dphubert-superb-score.png ├── dphubert-superb.png └── dphubert-train.png ├── lightning.py ├── prepare_data.py ├── prune.py ├── run.sh ├── save_final_ckpt.py └── wav2vec2 ├── __init__.py ├── components.py ├── hardconcrete.py ├── model.py ├── pruning_utils.py └── utils ├── __init__.py └── import_huggingface_wavlm.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Data and exp directories 163 | data/ 164 | exp/ 165 | pretrained/ 166 | 167 | # SLURM output 168 | slurm-*.out 169 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Carnegie Mellon University (Yifan Peng) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DPHuBERT 2 | 3 | ![GitHub Repo stars](https://img.shields.io/github/stars/pyf98/DPHuBERT) 4 | 5 | This repo contains the code and models for our paper: 6 | 7 | Yifan Peng, Yui Sudo, Shakeel Muhammad, and Shinji Watanabe, “DPHuBERT: Joint Distillation and Pruning of Self-Supervised Speech Models,” in Proc. INTERSPEECH, 2023. 8 | [[arXiv](https://arxiv.org/abs/2305.17651)] [[paper](https://www.isca-speech.org/archive/interspeech_2023/peng23c_interspeech.html)] 9 | 10 | 11 | ## Overview 12 | 13 | DPHuBERT is a task-agnostic compression method based on **joint distillation and structured pruning**. DPHuBERT outperforms pure distillation methods in most [SUPERB](https://superbbenchmark.org/leaderboard) tasks. 
It also performs well with limited training data. Our method can be directly applied to various speech SSL models like HuBERT (either Base or Large) and WavLM. 14 | 15 | The training procedure is illustrated in the figure below: 16 | 17 | ![Training procedure of DPHuBERT](imgs/dphubert-train.png) 18 | 19 | --- 20 | 21 | The main results are summarized in this table: 22 | 23 | ![DPHuBERT results](imgs/dphubert-results.png) 24 | 25 | --- 26 | 27 | Our models are also shown in the [SUPERB leaderboard](https://superbbenchmark.org/leaderboard). Here are the results sorted by Rank and Score, respectively. 28 | 29 | ![SUPERB sorted by Rank](imgs/dphubert-superb.png) 30 | 31 | --- 32 | 33 | ![SUPERB sorted by Score](imgs/dphubert-superb-score.png) 34 | 35 | ## Requirements 36 | 37 | Our code is based on PyTorch, TorchAudio, and PyTorch Lightning. Please install these required packages from their official sources. We include our versions below for reference, but other versions might also work. 38 | 39 | ``` 40 | # Main packages for training 41 | pytorch=1.13.1 42 | cuda=11.6.2 43 | pytorch-lightning=1.8.1 44 | torchaudio=0.13.1 45 | 46 | # Other packages for obtaining pre-trained SSL 47 | fairseq=0.12.2 48 | transformers=4.24.0 49 | ``` 50 | 51 | 52 | ## Usage 53 | 54 | Please follow these steps to train DPHuBERT. 55 | 56 | ### 1. Download and prepare audio data 57 | 58 | The following script creates file lists for LibriSpeech in tsv format. `LibriSpeech_PATH` is the path to the downloaded raw data. 59 | 60 | ```bash 61 | python prepare_data.py --data LibriSpeech_PATH --out data/librispeech 62 | ``` 63 | 64 | The output directory has this structure: 65 | 66 | ``` 67 | data 68 | └── librispeech 69 | ├── train100.tsv 70 | ├── train960.tsv 71 | └── valid.tsv 72 | ``` 73 | 74 | ### 2. Download pre-trained SSL (e.g., HuBERT Base) and convert it to our format 75 | 76 | We need to download pre-trained SSL checkpoints from fairseq or Hugging Face and then convert them to our own format. These models will be used as the teacher for compression. For example, we can obtain HuBERT Base by executing: 77 | 78 | ```bash 79 | mkdir -p pretrained 80 | python convert_hubert_from_hf.py 81 | ``` 82 | 83 | The converted checkpoint will be saved as `pretrained/hubert-base-ls960.hf.pth`. The output path can be changed in the Python script. 84 | 85 | ### 3. Start training 86 | 87 | After preparing the data and the pre-trained model, we can start training by sequentially executing the four Python scripts: `distill.py`, `prune.py`, `final_distill.py`, and `save_final_ckpt.py`. We provide a shell script `run.sh` to record the hyper-parameters. By default, we request 4 NVIDIA A100 (40GB) GPUs via the SLURM job scheduler. It takes around 6 hours to compress HuBERT Base. Please modify the hyper-parameters if your environment is different. For example, one can reduce the number of GPUs but enable gradient accumulation to keep the total batch size in a similar range, as illustrated below. 88 | 89 | ```bash 90 | sbatch run.sh 91 | ``` 92 | 93 | After training, the compressed model parameters and configurations will be saved in the corresponding experiment directory.
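As a concrete illustration of the GPU/accumulation trade-off, the two commands below have the same effective batch size. The flags are taken from the argument parser of `distill.py` in this repo; the data path, experiment directory, and checkpoint paths are placeholders that should match the outputs of steps 1 and 2.

```bash
# Default setting: 4 GPUs per node, no gradient accumulation.
python distill.py --tsv_dir data/librispeech --exp_dir exp/dphubert \
    --teacher_ckpt pretrained/hubert-base-ls960.hf.pth \
    --student_ckpt pretrained/hubert-base-ls960.hf.pth \
    --gpus 4 --accum_grad 1

# Same effective batch size on 2 GPUs: accumulate gradients over 2 mini-batches.
python distill.py --tsv_dir data/librispeech --exp_dir exp/dphubert \
    --teacher_ckpt pretrained/hubert-base-ls960.hf.pth \
    --student_ckpt pretrained/hubert-base-ls960.hf.pth \
    --gpus 2 --accum_grad 2
```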
We can easily load a compressed model as follows: 94 | 95 | ```python 96 | import torch 97 | from wav2vec2.model import wav2vec2_model 98 | 99 | ckpt_path = "path/to/ckpt" 100 | ckpt = torch.load(ckpt_path) 101 | model = wav2vec2_model(**ckpt["config"]) 102 | result = model.load_state_dict(ckpt["state_dict"], strict=False) 103 | print(f"missing: {result.missing_keys}, unexpected: {result.unexpected_keys}") 104 | print(f"{sum(p.numel() for p in model.parameters())} params") 105 | ``` 106 | 107 | 108 | ## Pre-trained models 109 | 110 | We also provide some pre-trained models. 111 | 112 | | Name | Teacher | Sparsity | Params | Link | 113 | |:---:|:---:|:---:|:---:|:---:| 114 | | DPHuBERT | HuBERT Base | 0.75 | 23,585,946 | [Hugging Face](https://huggingface.co/pyf98/DPHuBERT/blob/main/DPHuBERT-sp0.75.pth) | 115 | | DPWavLM | WavLM Base+ | 0.75 | 23,586,325 | [Hugging Face](https://huggingface.co/pyf98/DPHuBERT/blob/main/DPWavLM-sp0.75.pth) | 116 | 117 | 118 | 119 | ## Citation 120 | 121 | Please cite related papers if you use DPHuBERT. 122 | 123 | ``` 124 | @inproceedings{peng23c_interspeech, 125 | author={Yifan Peng and Yui Sudo and Shakeel Muhammad and Shinji Watanabe}, 126 | title={{DPHuBERT: Joint Distillation and Pruning of Self-Supervised Speech Models}}, 127 | year=2023, 128 | booktitle={Proc. INTERSPEECH 2023}, 129 | pages={62--66}, 130 | doi={10.21437/Interspeech.2023-1213} 131 | } 132 | @INPROCEEDINGS{10095780, 133 | author={Peng, Yifan and Kim, Kwangyoun and Wu, Felix and Sridhar, Prashant and Watanabe, Shinji}, 134 | booktitle={ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 135 | title={Structured Pruning of Self-Supervised Pre-Trained Models for Speech Recognition and Understanding}, 136 | year={2023}, 137 | pages={1-5}, 138 | doi={10.1109/ICASSP49357.2023.10095780}} 139 | ``` 140 | 141 | ## Acknowledgments 142 | 143 | We thank the authors of the following projects for open-sourcing their code: 144 | - [TorchAudio](https://github.com/pytorch/audio): Our speech SSL models and training pipelines are based on TorchAudio. 145 | - [FLOP](https://github.com/asappresearch/flop): Our implementation of the Hard Concrete Distribution is from FLOP. 146 | - [CoFiPruning](https://github.com/princeton-nlp/CoFiPruning): Some of our training hyper-parameters follow CoFiPruning. 
147 | 148 | Our method is inspired by prior studies: 149 | - Distillation: [DistilHuBERT](https://arxiv.org/abs/2110.01900), [FitHuBERT](https://arxiv.org/abs/2207.00555), [Deep versus Wide](https://arxiv.org/abs/2207.06867) 150 | - Pruning: [FLOP](https://arxiv.org/abs/1910.04732), [CoFiPruning](https://arxiv.org/abs/2204.00408), [HJ-Pruning](https://arxiv.org/abs/2302.14132) 151 | -------------------------------------------------------------------------------- /convert_hubert_from_fairseq.py: -------------------------------------------------------------------------------- 1 | """Convert fairseq's HuBERT to our format.""" 2 | 3 | import torch 4 | import fairseq 5 | from torchaudio.models.wav2vec2.utils import import_fairseq_model 6 | 7 | from wav2vec2.model import wav2vec2_model 8 | 9 | 10 | if __name__ == "__main__": 11 | out_name = "pretrained/hubert-base-ls960.fairseq.pth" 12 | 13 | fairseq_ckpt = "pretrained/fairseq/hubert_base_ls960.pt" 14 | ensemble, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_ckpt]) 15 | original = ensemble[0] 16 | imported = import_fairseq_model(original) 17 | print(imported) 18 | 19 | # default config of hubert base 20 | hubert_base_config = dict( 21 | extractor_mode="group_norm", # hubert base only uses a group norm at the first conv layer 22 | extractor_conv_layer_config=[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2, 23 | extractor_conv_bias=False, 24 | encoder_embed_dim=768, 25 | encoder_projection_dropout=0.1, 26 | encoder_pos_conv_kernel=128, 27 | encoder_pos_conv_groups=16, 28 | encoder_num_layers=12, 29 | encoder_use_attention=[True] * 12, 30 | encoder_use_feed_forward=[True] * 12, 31 | encoder_num_heads=[12] * 12, 32 | encoder_head_dim=64, 33 | encoder_attention_dropout=0.1, 34 | encoder_ff_interm_features=[3072] * 12, 35 | encoder_ff_interm_dropout=0.0, 36 | encoder_dropout=0.1, 37 | encoder_layer_norm_first=False, # hubert base uses post norm 38 | encoder_layer_drop=0.05, 39 | aux_num_out=None, 40 | normalize_waveform=False, 41 | extractor_prune_conv_channels=False, 42 | encoder_prune_attention_heads=False, 43 | encoder_prune_attention_layer=False, 44 | encoder_prune_feed_forward_intermediate=False, 45 | encoder_prune_feed_forward_layer=False, 46 | ) 47 | 48 | torch.save( 49 | { 50 | 'state_dict': imported.state_dict(), 51 | 'config': hubert_base_config, 52 | }, 53 | out_name 54 | ) 55 | 56 | # verify the saved ckpt 57 | ckpt = torch.load(out_name, map_location="cpu") 58 | model = wav2vec2_model(**ckpt['config']) 59 | res = model.load_state_dict(ckpt['state_dict'], strict=False) 60 | print(f"Missing: {res.missing_keys}\nUnexpected: {res.unexpected_keys}") 61 | -------------------------------------------------------------------------------- /convert_hubert_from_hf.py: -------------------------------------------------------------------------------- 1 | """Convert Hugging Face's HuBERT to our format.""" 2 | 3 | import torch 4 | from transformers import HubertModel 5 | from torchaudio.models.wav2vec2.utils import import_huggingface_model 6 | 7 | from wav2vec2.model import wav2vec2_model 8 | 9 | 10 | if __name__ == "__main__": 11 | out_name = "pretrained/hubert-base-ls960.hf.pth" 12 | 13 | original = HubertModel.from_pretrained("facebook/hubert-base-ls960") 14 | imported = import_huggingface_model(original) 15 | print(imported) 16 | 17 | # default config of hubert base 18 | hubert_base_config = dict( 19 | extractor_mode="group_norm", # hubert base only uses a group norm at the first conv layer 20 | 
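# Each tuple below is (output channels, kernel size, stride); the seven conv layers have an overall stride of 320 samples, i.e. one feature frame per 20 ms of 16 kHz audio.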
extractor_conv_layer_config=[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2, 21 | extractor_conv_bias=False, 22 | encoder_embed_dim=768, 23 | encoder_projection_dropout=0.1, 24 | encoder_pos_conv_kernel=128, 25 | encoder_pos_conv_groups=16, 26 | encoder_num_layers=12, 27 | encoder_use_attention=[True] * 12, 28 | encoder_use_feed_forward=[True] * 12, 29 | encoder_num_heads=[12] * 12, 30 | encoder_head_dim=64, 31 | encoder_attention_dropout=0.1, 32 | encoder_ff_interm_features=[3072] * 12, 33 | encoder_ff_interm_dropout=0.0, 34 | encoder_dropout=0.1, 35 | encoder_layer_norm_first=False, # hubert base uses post norm 36 | encoder_layer_drop=0.05, 37 | aux_num_out=None, 38 | normalize_waveform=False, 39 | extractor_prune_conv_channels=False, 40 | encoder_prune_attention_heads=False, 41 | encoder_prune_attention_layer=False, 42 | encoder_prune_feed_forward_intermediate=False, 43 | encoder_prune_feed_forward_layer=False, 44 | ) 45 | 46 | torch.save( 47 | { 48 | 'state_dict': imported.state_dict(), 49 | 'config': hubert_base_config, 50 | }, 51 | out_name 52 | ) 53 | 54 | # verify the saved ckpt 55 | ckpt = torch.load(out_name, map_location="cpu") 56 | model = wav2vec2_model(**ckpt['config']) 57 | res = model.load_state_dict(ckpt['state_dict'], strict=False) 58 | print(f"Missing: {res.missing_keys}\nUnexpected: {res.unexpected_keys}") 59 | -------------------------------------------------------------------------------- /convert_hubert_large_from_fairseq.py: -------------------------------------------------------------------------------- 1 | """Convert fairseq's HuBERT Large to our format.""" 2 | 3 | import torch 4 | import fairseq 5 | from torchaudio.models.wav2vec2.utils import import_fairseq_model 6 | 7 | from wav2vec2.model import wav2vec2_model 8 | 9 | 10 | if __name__ == "__main__": 11 | out_name = "pretrained/hubert-large-ll60k.fairseq.pth" 12 | 13 | fairseq_ckpt = "pretrained/fairseq/hubert_large_ll60k.pt" 14 | ensemble, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_ckpt]) 15 | original = ensemble[0] 16 | imported = import_fairseq_model(original) 17 | print(imported) 18 | 19 | # default config of hubert large 20 | hubert_large_config = dict( 21 | extractor_mode="layer_norm", 22 | extractor_conv_layer_config=[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2, 23 | extractor_conv_bias=False, 24 | encoder_embed_dim=1024, 25 | encoder_projection_dropout=0.1, 26 | encoder_pos_conv_kernel=128, 27 | encoder_pos_conv_groups=16, 28 | encoder_num_layers=24, 29 | encoder_use_attention=[True] * 24, 30 | encoder_use_feed_forward=[True] * 24, 31 | encoder_num_heads=[16] * 24, 32 | encoder_head_dim=64, 33 | encoder_attention_dropout=0.1, 34 | encoder_ff_interm_features=[4096] * 24, 35 | encoder_ff_interm_dropout=0.0, 36 | encoder_dropout=0.1, 37 | encoder_layer_norm_first=True, # hubert large uses pre norm 38 | encoder_layer_drop=0.05, 39 | aux_num_out=None, 40 | normalize_waveform=True, 41 | extractor_prune_conv_channels=False, 42 | encoder_prune_attention_heads=False, 43 | encoder_prune_attention_layer=False, 44 | encoder_prune_feed_forward_intermediate=False, 45 | encoder_prune_feed_forward_layer=False, 46 | ) 47 | 48 | torch.save( 49 | { 50 | 'state_dict': imported.state_dict(), 51 | 'config': hubert_large_config, 52 | }, 53 | out_name 54 | ) 55 | 56 | # verify the saved ckpt 57 | ckpt = torch.load(out_name, map_location="cpu") 58 | model = wav2vec2_model(**ckpt['config']) 59 | res = model.load_state_dict(ckpt['state_dict'], strict=False) 60 | print(f"Missing: 
{res.missing_keys}\nUnexpected: {res.unexpected_keys}") 61 | -------------------------------------------------------------------------------- /convert_wav2vec2_from_fairseq.py: -------------------------------------------------------------------------------- 1 | """Convert fairseq's wav2vec2 to our format.""" 2 | 3 | import torch 4 | import fairseq 5 | from torchaudio.models.wav2vec2.utils import import_fairseq_model 6 | 7 | from wav2vec2.model import wav2vec2_model 8 | 9 | 10 | if __name__ == "__main__": 11 | out_name = "pretrained/wav2vec2-base-ls960.fairseq.pth" 12 | 13 | fairseq_ckpt = "pretrained/fairseq/wav2vec_small.pt" 14 | ensemble, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_ckpt]) 15 | original = ensemble[0] 16 | imported = import_fairseq_model(original) 17 | print(imported) 18 | 19 | # default config of wav2vec2 base 20 | wav2vec2_base_config = dict( 21 | extractor_mode="group_norm", # hubert/w2v2 base only uses a group norm at the first conv layer 22 | extractor_conv_layer_config=[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2, 23 | extractor_conv_bias=False, 24 | encoder_embed_dim=768, 25 | encoder_projection_dropout=0.1, 26 | encoder_pos_conv_kernel=128, 27 | encoder_pos_conv_groups=16, 28 | encoder_num_layers=12, 29 | encoder_use_attention=[True] * 12, 30 | encoder_use_feed_forward=[True] * 12, 31 | encoder_num_heads=[12] * 12, 32 | encoder_head_dim=64, 33 | encoder_attention_dropout=0.1, 34 | encoder_ff_interm_features=[3072] * 12, 35 | encoder_ff_interm_dropout=0.0, 36 | encoder_dropout=0.1, 37 | encoder_layer_norm_first=False, # hubert/w2v2 base uses post norm 38 | encoder_layer_drop=0.05, 39 | aux_num_out=None, 40 | normalize_waveform=False, 41 | extractor_prune_conv_channels=False, 42 | encoder_prune_attention_heads=False, 43 | encoder_prune_attention_layer=False, 44 | encoder_prune_feed_forward_intermediate=False, 45 | encoder_prune_feed_forward_layer=False, 46 | ) 47 | 48 | torch.save( 49 | { 50 | 'state_dict': imported.state_dict(), 51 | 'config': wav2vec2_base_config, 52 | }, 53 | out_name 54 | ) 55 | 56 | # verify the saved ckpt 57 | ckpt = torch.load(out_name, map_location="cpu") 58 | model = wav2vec2_model(**ckpt['config']) 59 | res = model.load_state_dict(ckpt['state_dict'], strict=False) 60 | print(f"Missing: {res.missing_keys}\nUnexpected: {res.unexpected_keys}") 61 | -------------------------------------------------------------------------------- /convert_wavlm_from_hf.py: -------------------------------------------------------------------------------- 1 | """Convert Hugging Face's WavLM to our format.""" 2 | 3 | import torch 4 | from transformers import WavLMModel 5 | 6 | from wav2vec2.model import wav2vec2_model 7 | from wav2vec2.utils.import_huggingface_wavlm import import_huggingface_model 8 | 9 | 10 | if __name__ == "__main__": 11 | out_name = "pretrained/wavlm-base-plus.hf.pth" 12 | 13 | original = WavLMModel.from_pretrained("microsoft/wavlm-base-plus") 14 | imported = import_huggingface_model(original) 15 | imported.eval() 16 | print(imported) 17 | 18 | # default config of wavlm base 19 | wavlm_base_plus_config = dict( 20 | extractor_mode="group_norm", # wavlm base only uses a group norm at the first conv layer 21 | extractor_conv_layer_config=[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2, 22 | extractor_conv_bias=False, 23 | encoder_embed_dim=768, 24 | encoder_projection_dropout=0.1, 25 | encoder_pos_conv_kernel=128, 26 | encoder_pos_conv_groups=16, 27 | encoder_num_layers=12, 28 | 
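# Compared with the HuBERT/wav2vec2 configs above, WavLM adds relative-position-bias settings (encoder_num_buckets, encoder_max_distance) and lists attention heads explicitly (encoder_total_num_heads, encoder_remaining_heads) instead of encoder_num_heads.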
encoder_use_attention=[True] * 12, 29 | encoder_use_feed_forward=[True] * 12, 30 | encoder_total_num_heads=[12] * 12, 31 | encoder_remaining_heads=[list(range(12)) for _ in range(12)], 32 | encoder_num_buckets=320, 33 | encoder_max_distance=800, 34 | encoder_attention_dropout=0.1, 35 | encoder_ff_interm_features=[3072] * 12, 36 | encoder_ff_interm_dropout=0.0, 37 | encoder_dropout=0.1, 38 | encoder_layer_norm_first=False, # wavlm base uses post norm 39 | encoder_layer_drop=0.05, 40 | aux_num_out=None, 41 | normalize_waveform=False, 42 | extractor_prune_conv_channels=False, 43 | encoder_prune_attention_heads=False, 44 | encoder_prune_attention_layer=False, 45 | encoder_prune_feed_forward_intermediate=False, 46 | encoder_prune_feed_forward_layer=False, 47 | ) 48 | 49 | torch.save( 50 | { 51 | 'state_dict': imported.state_dict(), 52 | 'config': wavlm_base_plus_config, 53 | }, 54 | out_name 55 | ) 56 | 57 | # verify the saved ckpt 58 | ckpt = torch.load(out_name, map_location="cpu") 59 | model = wav2vec2_model(**ckpt['config']) 60 | print(model.load_state_dict(ckpt['state_dict'], strict=False)) 61 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyf98/DPHuBERT/c18093fe4b56a0027a80bf9b9b1a23f932cbf14c/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/audio_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class for audio. 2 | 3 | Originally from TorchAudio's hubert_dataset: 4 | https://github.com/pytorch/audio/blob/main/examples/hubert/dataset/hubert_dataset.py 5 | 6 | """ 7 | 8 | from pathlib import Path 9 | from typing import Dict, Iterator, List, Optional, Tuple, Union 10 | 11 | import numpy as np 12 | import torch 13 | import torch.distributed as dist 14 | import torchaudio 15 | from torch import Tensor 16 | from torch.utils.data import BatchSampler, Dataset, DistributedSampler 17 | 18 | 19 | class BucketizeBatchSampler(BatchSampler): 20 | """Buketized BatchSampler for sequential data with different lengths to reduce number of paddings. 21 | 22 | Args: 23 | lengths (List[int]): The lengths of the samples in the dataset. 24 | num_buckets (int): The number of buckets to split the data samples. 25 | min_len (int, optional): The minimum sample lengths to keep. 26 | (Default: 0) 27 | max_len (int or None, optional): The maximum sample lengths to keep. Inferred if not provided. 28 | (Default ``None``) 29 | max_token_count (int or None, optional): The max number of tokens in one mini-batch. 30 | (Default: ``None``) 31 | batch_size (int or None, optional): The number of samples in one mini-batch. 32 | (Default: ``None``) 33 | shuffle (bool, optional): Whether to shuffle buckets for non-monotonic length sampling. 34 | (Default: True) 35 | drop_last (bool, optional): If ``True``, the sampler will drop the last batch if 36 | its size would be less than ``batch_size`` 37 | (Default: False) 38 | 39 | Note: 40 | ``max_token_count`` and ``batch_size`` are mutually exclusive. Only one argument of the two 41 | should have value. 42 | 43 | Note: 44 | ``drop_last`` is only valid when ``batch_size`` argument is given. 45 | 46 | Note: 47 | if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1`` 48 | in pytorch_lightning Trainer to enable shuffling every epoch. 
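Example (illustrative sketch; the dataset path, bucket count, and token budget are placeholder values, not the settings used by this repo's training scripts):
        >>> dataset = AudioDataset("data/librispeech", "train100")
        >>> sampler = BucketizeBatchSampler(dataset.len_list, num_buckets=1000, max_token_count=1400000)
        >>> loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler, collate_fn=CollateFnAudio(pad=False, rand_crop=True))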
49 | """ 50 | 51 | def __init__( 52 | self, 53 | lengths: List[int], 54 | num_buckets: int, 55 | min_len: int = 0, 56 | max_len: Optional[int] = None, 57 | max_token_count: Optional[int] = None, 58 | batch_size: Optional[int] = None, 59 | shuffle: bool = True, 60 | drop_last: bool = False, 61 | ) -> None: 62 | if max_len is None: 63 | max_len = max(lengths) 64 | 65 | if not (0 <= min_len <= max_len): 66 | raise AssertionError("``min_len`` should be non-negative and smaller than ``max_len``") 67 | if max_token_count is not None and batch_size is not None: 68 | raise AssertionError("The ``max_token_count`` and ``batch_size`` can't be both set.") 69 | if max_token_count is None and batch_size is None: 70 | raise AssertionError("One of ``max_token_count`` or ``batch_size`` must be set.") 71 | if max_token_count is not None: 72 | assert ( 73 | max_len <= max_token_count 74 | ), "The ``max_token_count`` must be greater than or equal to the maximum value of ``lengths``." 75 | # Filter out samples which are outside the bounds of [min_len, max_len] 76 | filtered_length_idx = [(length, i) for i, length in enumerate(lengths) if min_len <= length <= max_len] 77 | if len(filtered_length_idx) == 0: 78 | raise AssertionError("``lengths`` cannot be empty after filtering.") 79 | sorted_filtered_length_idx = sorted(filtered_length_idx, key=lambda x: x[0]) 80 | self.lengths = [e[0] for e in sorted_filtered_length_idx] 81 | self.indices = [e[1] for e in sorted_filtered_length_idx] 82 | self.max_token_count = max_token_count 83 | self.batch_size = batch_size 84 | self.shuffle = shuffle 85 | self.drop_last = drop_last 86 | self.buckets = self._get_buckets(self.lengths, num_buckets, min_len, max_len) 87 | self._update_iter_list() 88 | 89 | def _get_buckets(self, lengths: List[int], num_buckets: int, min_len: int, max_len: int) -> Dict[int, Tensor]: 90 | """Generate buckets based on the dataset. 91 | Args: 92 | lengths (List[int]): The lengths of the samples in the dataset. 93 | num_buckets (int): The number of buckets. 94 | min_len (int): The lower bound of the evenly spaced length intervals to determine bucket width. 95 | max_len (int): The upper bound of the evenly spaced length intervals to determine bucket width. 96 | 97 | Returns: 98 | (dict[int, Tensor]): A dictionary in which the key is the bucket index, the value is 99 | the Tensor of corresponding sample indices. 
100 | """ 101 | buckets = {} 102 | boundaries = torch.linspace(min_len - 1, max_len + 1, num_buckets + 1) 103 | bucket_ids = torch.bucketize(torch.tensor(lengths), boundaries) 104 | for i in range(bucket_ids.size(0)): 105 | bucket_id = int(bucket_ids[i]) 106 | if bucket_id in buckets: 107 | buckets[bucket_id].append(i) 108 | else: 109 | buckets[bucket_id] = [i] 110 | for k in buckets: 111 | buckets[k] = torch.as_tensor(buckets[k], dtype=torch.int) 112 | buckets = {k: v for k, v in sorted(buckets.items())} 113 | return buckets 114 | 115 | def _update_iter_list(self) -> None: 116 | if self.shuffle: 117 | for k in self.buckets: 118 | self.buckets[k] = self.buckets[k][torch.randperm(self.buckets[k].size(0))] 119 | self.iter_list = [] 120 | total_len = 0 121 | batch = [] 122 | max_batch_size = self.max_token_count if self.max_token_count else self.batch_size 123 | for k in self.buckets: 124 | for i in range(self.buckets[k].size(0)): 125 | index = int(self.buckets[k][i]) 126 | sample_length = self.lengths[index] if self.max_token_count else 1 127 | if total_len + sample_length <= max_batch_size: 128 | batch.append(self.indices[index]) 129 | total_len += sample_length 130 | else: 131 | self.iter_list.append(batch) 132 | batch = [self.indices[index]] 133 | total_len = sample_length 134 | if len(batch) > 0 and (self.max_token_count or not self.drop_last): 135 | self.iter_list.append(batch) 136 | 137 | def __iter__(self) -> Iterator[List[int]]: 138 | return iter(self.iter_list) 139 | 140 | def __len__(self): 141 | if self.batch_size or (self.max_token_count and not self.shuffle): 142 | return len(self.iter_list) 143 | 144 | 145 | class DistributedBatchSampler(DistributedSampler): 146 | """`BucketizeBatchSampler` wrapper that distributes across each processor. 147 | 148 | Args: 149 | batch_sampler (BucketizeBatchSampler): the initialized bucketize batch sampler. 150 | num_replicas (int, optional): Number of processes participating in 151 | distributed training. By default, :attr:`world_size` is retrieved from the 152 | current distributed group. 153 | rank (int, optional): Rank of the current process within :attr:`num_replicas`. 154 | By default, :attr:`rank` is retrieved from the current distributed 155 | group. 156 | shuffle (bool, optional): if ``True``, the list of batch indices will be shuffled. 157 | (Default: ``True``) 158 | seed (int, optional): random seed used to shuffle the batch_sampler if 159 | :attr:`shuffle=True`. This number should be identical across all 160 | processes in the distributed group. (Default: ``0``) 161 | drop_last (bool, optional): if ``True``, then the sampler will drop the 162 | tail of the data to make it evenly divisible across the number of 163 | replicas. If ``False``, the sampler will add extra indices to make 164 | the data evenly divisible across the replicas. (Default: ``False``) 165 | 166 | Note: 167 | if ``shuffle`` is True, it will only shuffle the data once. Please set ``reload_dataloaders_every_n_epochs=1`` 168 | in pytorch_lightning Trainer, and set `sampler.set_epoch(self.current_epoch)` before DataLoader initialization 169 | in `train_dataloader` method to enable shuffling every epoch. 
170 | """ 171 | 172 | def __init__( 173 | self, 174 | batch_sampler: BucketizeBatchSampler, 175 | num_replicas: Optional[int] = None, 176 | rank: Optional[int] = None, 177 | shuffle: bool = True, 178 | seed: int = 0, 179 | drop_last: bool = False, 180 | ) -> None: 181 | self.batch_sampler = batch_sampler 182 | if num_replicas is None: 183 | if not dist.is_available(): 184 | raise RuntimeError("Requires distributed package to be available") 185 | num_replicas = dist.get_world_size() 186 | if rank is None: 187 | if not dist.is_available(): 188 | raise RuntimeError("Requires distributed package to be available") 189 | rank = dist.get_rank() 190 | self.num_replicas = num_replicas 191 | self.rank = rank 192 | self.shuffle = shuffle 193 | self.epoch = 0 194 | self.seed = seed 195 | self.drop_last = drop_last 196 | if shuffle: 197 | g = torch.Generator() 198 | g.manual_seed(self.seed + self.epoch) 199 | perm = torch.randperm(len(self.batch_sampler.iter_list), generator=g).tolist() 200 | indices = [self.batch_sampler.iter_list[i] for i in perm] 201 | else: 202 | indices = self.batch_sampler.iter_list 203 | if self.drop_last: 204 | self.total_size = len(indices) - len(indices) % self.num_replicas 205 | else: 206 | padding_size = self.num_replicas - len(indices) % self.num_replicas 207 | indices += indices[:padding_size] 208 | self.total_size = len(indices) 209 | self.num_samples = self.total_size // self.num_replicas 210 | self.subset = indices[self.rank : self.total_size : self.num_replicas] 211 | assert len(self.subset) == self.num_samples 212 | 213 | def __iter__(self): 214 | return iter(self.subset) 215 | 216 | def __len__(self): 217 | return self.num_samples 218 | 219 | 220 | class AudioDataset(Dataset): 221 | """Create a Dataset for HuBERT model training and fine-tuning. 222 | 223 | Args: 224 | tsv_dir (str or Path): The root directory of the ``.tsv`` file list. 225 | subset (str): The subset of the dataset. Options: [``train100``, ``train960``, ``valid``]. 226 | """ 227 | 228 | def __init__( 229 | self, 230 | tsv_dir: Union[str, Path], 231 | subset: str, 232 | ) -> None: 233 | self.f_list, self.ind_list, self.len_list = self._get_lists(Path(tsv_dir), subset) 234 | 235 | def __len__(self): 236 | return len(self.f_list) 237 | 238 | def _get_lists( 239 | self, 240 | tsv_dir: Path, 241 | subset: str, 242 | ) -> Tuple[List[Path], List[int], List[int]]: 243 | """Get the list of paths for iteration. 244 | Args: 245 | tsv_dir (Path): The root directory of the ``.tsv`` file list. 246 | subset (str): The subset of the dataset. Options: [``train100``, ``train960``, ``valid``]. 247 | 248 | Returns: 249 | (numpy.array) List of file paths. 250 | (numpy.array) List of indices. 251 | (numpy.array) List of waveform lengths. 252 | """ 253 | f_ind_len_list = [] 254 | with open(tsv_dir / f"{subset}.tsv") as f: 255 | root = f.readline().rstrip() 256 | for index, line in enumerate(f): 257 | path, nsample = line.split("\t") 258 | path = f"{root}/{path}" 259 | nsample = int(nsample) 260 | f_ind_len_list.append((path, index, nsample)) 261 | f_list, ind_list, len_list = zip(*f_ind_len_list) 262 | return np.asarray(f_list), np.asarray(ind_list), np.asarray(len_list) 263 | 264 | def _load_audio(self, index: int) -> Tensor: 265 | """Load waveform given the sample index of the dataset. 266 | Args: 267 | index (int): The sample index. 268 | 269 | Returns: 270 | (Tensor): The corresponding waveform Tensor. 
shape: (channel, time) 271 | """ 272 | wav_path = self.f_list[index] 273 | waveform, sample_rate = torchaudio.load(wav_path) 274 | assert waveform.shape[1] == self.len_list[index] 275 | return waveform 276 | 277 | def __getitem__(self, index): 278 | waveform = self._load_audio(index) # (channel, time) 279 | length = waveform.shape[1] 280 | return waveform, length 281 | 282 | 283 | def _crop_audio( 284 | waveform: Tensor, 285 | length: Tensor, 286 | num_frames: int, 287 | rand_crop: bool, 288 | ) -> Tuple[Tensor, Tensor]: 289 | """Crop the audio. 290 | Args: 291 | waveform (Tensor): The waveform Tensor with dimensions `(1, time)`. 292 | length (Tensor): The length Tensor with dimension `(1,)`. 293 | num_frames (int): The final length of the waveform. 294 | rand_crop (bool): if ``rand_crop`` is True, the starting index of the 295 | waveform and label is random if the length is longer than the minimum 296 | length in the mini-batch. 297 | 298 | Returns: 299 | (Tuple(Tensor, Tensor)): 300 | Returns the Tensors for the waveform and the waveform length. 301 | """ 302 | frame_offset = 0 303 | waveform = waveform[0] 304 | if waveform.size(0) > num_frames and rand_crop: 305 | diff = waveform.size(0) - num_frames 306 | frame_offset = torch.randint(diff, size=(1,)) 307 | elif waveform.size(0) < num_frames: 308 | num_frames = waveform.size(0) 309 | waveform = waveform[frame_offset : frame_offset + num_frames] 310 | length = num_frames 311 | 312 | return waveform, length 313 | 314 | 315 | class CollateFnAudio: 316 | """The collate class for HuBERT pre-training and fine-tuning. 317 | Args: 318 | pad (bool): If ``True``, the waveforms and labels will be padded to the 319 | max length in the mini-batch. If ``pad`` is False, the waveforms 320 | and labels will be cropped to the minimum length in the mini-batch. 321 | (Default: False) 322 | rand_crop (bool): if ``True``, the starting index of the waveform 323 | and label is random if the length is longer than the minimum 324 | length in the mini-batch. 325 | """ 326 | 327 | def __init__( 328 | self, 329 | pad: bool = False, 330 | rand_crop: bool = True, 331 | ) -> None: 332 | self.pad = pad 333 | self.rand_crop = rand_crop 334 | 335 | def __call__(self, batch: List[Tuple[Tensor, int]]) -> Tuple[Tensor, Tensor]: 336 | """ 337 | Args: 338 | batch (List[Tuple(Tensor, int)]): 339 | The list of tuples that contains the waveforms and audio lengths. 340 | 341 | Returns: 342 | (Tuple(Tensor, Tensor)): 343 | The Tensor of waveforms with dimensions `(batch, time)`. 344 | The Tensor of audio lengths with dimension `(batch,)`. 345 | """ 346 | if self.pad: 347 | num_frames = max([sample[0].shape[1] for sample in batch]) 348 | else: 349 | num_frames = min([sample[0].shape[1] for sample in batch]) 350 | waveforms, lengths = [], [] 351 | for sample in batch: 352 | waveform, length = sample # waveform has shape (channel, time) 353 | waveform, length = _crop_audio(waveform, length, num_frames, self.rand_crop) # waveform has shape (time,) 354 | waveforms.append(waveform) 355 | lengths.append(length) 356 | # make sure the shapes are the same if not apply zero-padding 357 | if not self.pad: 358 | assert all( 359 | [waveform.shape[0] == waveforms[0].shape[0] for waveform in waveforms] 360 | ), "The dimensions of the waveforms should be identical in the same batch." 
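# If self.pad is True, zero-pad every waveform to the longest one in the batch; otherwise all waveforms were already cropped to the same (minimum) length above, so pad_sequence just stacks them.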
361 | waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) 362 | lengths = torch.tensor(lengths) 363 | return waveforms, lengths 364 | -------------------------------------------------------------------------------- /distill.py: -------------------------------------------------------------------------------- 1 | """Perform distillation and pruning.""" 2 | 3 | import logging 4 | import pathlib 5 | from argparse import ArgumentParser 6 | 7 | import torch 8 | import torch.nn as nn 9 | import pytorch_lightning as pl 10 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 11 | from lightning_lite.utilities.rank_zero import _get_rank 12 | 13 | from lightning import ( 14 | DistillModule, 15 | DistillLoss, 16 | ) 17 | from wav2vec2.model import ( 18 | wav2vec2_model, 19 | ) 20 | 21 | _LG = logging.getLogger(f"{__name__}:{_get_rank()}") 22 | 23 | 24 | def _init_layer_transform(module: nn.Linear): 25 | module.weight.data.copy_(torch.eye(len(module.weight))) 26 | module.bias.data.fill_(0) 27 | 28 | 29 | def run_train(args): 30 | pl.seed_everything(2022) 31 | 32 | # Callbacks 33 | lr_monitor = LearningRateMonitor() # log learning rates for all param groups 34 | model_checkpoint = ModelCheckpoint(dirpath=args.exp_dir / "ckpts", verbose=True) # only save the latest epoch 35 | callbacks = [lr_monitor, model_checkpoint] 36 | 37 | trainer = pl.Trainer( 38 | default_root_dir=args.exp_dir, 39 | callbacks=callbacks, 40 | max_steps=args.max_updates, 41 | strategy="ddp", 42 | accelerator="gpu", 43 | num_nodes=args.num_nodes, 44 | devices=args.gpus, 45 | accumulate_grad_batches=args.accum_grad, 46 | replace_sampler_ddp=False, # we use the custom distributed sampler for ddp 47 | reload_dataloaders_every_n_epochs=1, 48 | gradient_clip_val=args.clip_norm, 49 | log_every_n_steps=args.log_interval, 50 | precision=args.precision, 51 | ) 52 | 53 | # Create teacher model 54 | teacher_ckpt = torch.load(args.teacher_ckpt, map_location="cpu") 55 | teacher_model = wav2vec2_model(**teacher_ckpt['config']) 56 | _LG.info(f"Teacher model:\n{teacher_model}") 57 | teacher_result = teacher_model.load_state_dict(teacher_ckpt['state_dict'], strict=False) 58 | _LG.info(f"Load pretrained ckpt to teacher: missing {teacher_result.missing_keys}, unexpected {teacher_result.unexpected_keys}") 59 | # Freeze teacher model 60 | for p in teacher_model.parameters(): 61 | p.requires_grad = False 62 | _LG.info("Freeze parameters of the teacher model by setting requires_grad=False") 63 | teacher_model.eval() 64 | 65 | # Create student model 66 | student_ckpt = torch.load(args.student_ckpt, map_location="cpu") 67 | pruning_units = args.pruning_units.split(",") 68 | _LG.info(f"Pruning units: {pruning_units}") 69 | student_config = student_ckpt['config'] 70 | student_config.update( 71 | dict( 72 | extractor_prune_conv_channels = "conv" in pruning_units, 73 | encoder_prune_attention_heads = "head" in pruning_units, 74 | encoder_prune_attention_layer = "attlayer" in pruning_units, 75 | encoder_prune_feed_forward_intermediate = "interm" in pruning_units, 76 | encoder_prune_feed_forward_layer = "ffnlayer" in pruning_units, 77 | ) 78 | ) 79 | student_model = wav2vec2_model(**student_config) 80 | _LG.info(f"Student model:\n{student_model}") 81 | student_result = student_model.load_state_dict(student_ckpt['state_dict'], strict=False) 82 | _LG.info(f"Load pretrained ckpt to student: missing {student_result.missing_keys}, unexpected {student_result.unexpected_keys}") 83 | 84 | # Create linear layers which 
transform student hiddens to teacher hiddens 85 | distill_layer_groups = [[int(l) for l in g.split(",")] for g in args.distill_layers.split(".")] 86 | _LG.info(f"Distill transformer layers: {distill_layer_groups}") 87 | distill_layers = [] 88 | for g in distill_layer_groups: 89 | distill_layers.extend(g) 90 | student_embed_dim = student_model.encoder.feature_projection.projection.out_features 91 | teacher_embed_dim = teacher_model.encoder.feature_projection.projection.out_features 92 | 93 | if args.distill_mode == "layer2layer": 94 | distill_linear_projs = nn.ModuleList() 95 | for g in distill_layer_groups: # layers in the same group share a linear layer 96 | tmp_linear = nn.Linear(student_embed_dim, teacher_embed_dim) 97 | _init_layer_transform(tmp_linear) 98 | for _ in range(len(g)): 99 | distill_linear_projs.append(tmp_linear) 100 | elif args.distill_mode == "predlayer": # same as DistilHuBERT 101 | # use independent linear layers, cannot be shared 102 | distill_linear_projs = nn.ModuleList( 103 | nn.Sequential( 104 | nn.Linear(student_embed_dim, teacher_embed_dim), 105 | nn.GELU(), 106 | ) for _ in range(len(distill_layers)) 107 | ) 108 | else: 109 | raise ValueError(f"Invalid distill mode: {args.distill_mode}") 110 | 111 | # Create DistillLoss module 112 | distill_loss_criterion = DistillLoss( 113 | l2_weight=args.l2_weight, 114 | l1_weight=args.l1_weight, 115 | cos_weight=args.cos_weight, 116 | cos_type=args.cos_type, 117 | ) 118 | _LG.info(f"Distill loss module:\n{distill_loss_criterion}") 119 | 120 | distill_module = DistillModule( 121 | teacher_model=teacher_model, 122 | student_model=student_model, 123 | distill_mode=args.distill_mode, 124 | distill_layers=distill_layers, 125 | distill_linear_projs=distill_linear_projs, 126 | distill_loss=distill_loss_criterion, 127 | learning_rate=args.learning_rate, 128 | weight_decay=args.weight_decay, 129 | warmup_updates=args.warmup_updates, 130 | max_updates=args.max_updates, 131 | use_reg=True, 132 | reg_learning_rate=args.reg_learning_rate, 133 | target_sparsity=args.target_sparsity, 134 | sparsity_warmup_updates=args.sparsity_warmup_updates, 135 | tsv_dir=args.tsv_dir, 136 | train_subset=args.train_subset, 137 | seconds_per_batch=args.seconds_per_batch, 138 | num_workers=args.num_workers, 139 | ) 140 | 141 | trainer.fit( 142 | distill_module, 143 | ckpt_path=args.resume_checkpoint, 144 | ) 145 | 146 | 147 | def _parse_args(): 148 | parser = ArgumentParser( 149 | description="Joint distillation and pruning of HuBERT", 150 | ) 151 | 152 | # dataset and dataloader related 153 | parser.add_argument( 154 | "--tsv_dir", 155 | type=pathlib.Path, 156 | required=True, 157 | help="Path to the directory containing tsv files.", 158 | ) 159 | parser.add_argument( 160 | "--train_subset", 161 | default="train100", 162 | choices=["train100", "train960"], 163 | type=str, 164 | help="The subset name for training. (Default: 'train100')", 165 | ) 166 | parser.add_argument( 167 | "--seconds_per_batch", 168 | default=87.5, 169 | type=float, 170 | help="Number of seconds of audio in a mini-batch. (Default: 87.5)", 171 | ) 172 | parser.add_argument( 173 | "--num_workers", 174 | default=1, 175 | type=int, 176 | help="Number of workers in DataLoader." 177 | ) 178 | 179 | # general training related 180 | parser.add_argument( 181 | "--resume_checkpoint", 182 | type=pathlib.Path, 183 | default=None, 184 | help="Path to the feature and label directories. 
(Default: None)", 185 | ) 186 | parser.add_argument( 187 | "--exp_dir", 188 | default=pathlib.Path("./exp"), 189 | type=pathlib.Path, 190 | help="Directory to save checkpoints and logs to. (Default: './exp')", 191 | ) 192 | parser.add_argument( 193 | "--log_interval", 194 | default=50, 195 | type=int, 196 | help="Log interval in steps." 197 | ) 198 | parser.add_argument( 199 | "--learning_rate", 200 | default=0.0002, 201 | type=float, 202 | help="The peak learning rate. (Default: 0.0002)", 203 | ) 204 | parser.add_argument( 205 | "--weight_decay", 206 | default=0.0, 207 | type=float, 208 | help="Weight decay (L2 penalty) (Default: 0.0)", 209 | ) 210 | parser.add_argument( 211 | "--warmup_updates", 212 | default=15000, 213 | type=int, 214 | help="Number of steps for warm up the learning rate. (Default: 15000)", 215 | ) 216 | parser.add_argument( 217 | "--max_updates", 218 | default=50000, 219 | type=int, 220 | help="Total number of training steps. (Default: 50000)", 221 | ) 222 | parser.add_argument( 223 | "--clip_norm", 224 | default=10.0, 225 | type=float, 226 | help="The gradient norm value to clip. (Default: 10.0)", 227 | ) 228 | parser.add_argument( 229 | "--num_nodes", 230 | default=1, 231 | type=int, 232 | help="Number of nodes to use for training. (Default: 1)", 233 | ) 234 | parser.add_argument( 235 | "--gpus", 236 | default=4, 237 | type=int, 238 | help="Number of GPUs per node to use for training. (Default: 4)", 239 | ) 240 | parser.add_argument( 241 | "--accum_grad", 242 | default=1, 243 | type=int, 244 | help="Gradient accumulation steps." 245 | ) 246 | parser.add_argument( 247 | "--precision", 248 | default=32, 249 | type=int, 250 | help="Precision for training." 251 | ) 252 | 253 | # distillation related 254 | parser.add_argument( 255 | "--teacher_ckpt", 256 | default=pathlib.Path("pretrained_ckpts/hubert-base-ls960.pth"), 257 | type=pathlib.Path, 258 | help="Path to the teacher model checkpoint." 259 | ) 260 | parser.add_argument( 261 | "--student_ckpt", 262 | default=pathlib.Path("pretrained_ckpts/hubert-base-ls960.pth"), 263 | type=pathlib.Path, 264 | help="Path to the student model checkpoint (for initialization)." 265 | ) 266 | parser.add_argument( 267 | "--distill_layers", 268 | default="0.4,8,12", 269 | type=str, 270 | help="Distill layer indices (use period to separate groups and comma to separate layers within a group)." 271 | ) 272 | parser.add_argument( 273 | "--distill_mode", 274 | type=str, 275 | default="layer2layer", 276 | choices=["layer2layer", "predlayer"], 277 | help="Distill mode, either layer2layer or predlayer." 278 | ) 279 | parser.add_argument( 280 | "--l2_weight", 281 | default=0.0, 282 | type=float, 283 | help="Weight of MSE loss." 284 | ) 285 | parser.add_argument( 286 | "--l1_weight", 287 | default=1.0, 288 | type=float, 289 | help="Weight of L1 loss." 290 | ) 291 | parser.add_argument( 292 | "--cos_weight", 293 | default=1.0, 294 | type=float, 295 | help="Weight of cosine similarity loss." 296 | ) 297 | parser.add_argument( 298 | "--cos_type", 299 | default="raw", 300 | type=str, 301 | choices=["raw", "log_sig"], 302 | help="Type of the cosine similarity loss." 303 | ) 304 | 305 | # pruning related 306 | parser.add_argument( 307 | "--pruning_units", 308 | default="conv,head,interm,attlayer,ffnlayer", 309 | type=str, 310 | help="Pruning units as a comma-separated list." 311 | ) 312 | parser.add_argument( 313 | "--reg_learning_rate", 314 | default=0.02, 315 | type=float, 316 | help="Regularization learning rate." 
317 | ) 318 | parser.add_argument( 319 | "--target_sparsity", 320 | default=0.75, 321 | type=float, 322 | help="Target sparsity." 323 | ) 324 | parser.add_argument( 325 | "--sparsity_warmup_updates", 326 | default=5000, 327 | type=int, 328 | help="Warmup updates for the target sparsity." 329 | ) 330 | 331 | return parser.parse_args() 332 | 333 | 334 | def _init_logger(): 335 | logging.basicConfig( 336 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 337 | datefmt="%Y-%m-%d %H:%M:%S", 338 | level=logging.INFO, 339 | ) 340 | if _get_rank() == 0: 341 | _LG.setLevel(logging.INFO) 342 | else: 343 | _LG.setLevel(logging.WARN) 344 | 345 | 346 | def cli_main(): 347 | _init_logger() 348 | args = _parse_args() 349 | run_train(args) 350 | 351 | 352 | if __name__ == "__main__": 353 | cli_main() 354 | -------------------------------------------------------------------------------- /final_distill.py: -------------------------------------------------------------------------------- 1 | """Perform distillation for the pruned model.""" 2 | 3 | import logging 4 | import pathlib 5 | from argparse import ArgumentParser 6 | 7 | import torch 8 | import torch.nn as nn 9 | import pytorch_lightning as pl 10 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 11 | from lightning_lite.utilities.rank_zero import _get_rank 12 | 13 | from lightning import ( 14 | DistillModule, 15 | DistillLoss, 16 | ) 17 | from wav2vec2.model import ( 18 | wav2vec2_model, 19 | ) 20 | 21 | _LG = logging.getLogger(f"{__name__}:{_get_rank()}") 22 | 23 | 24 | def run_train(args): 25 | pl.seed_everything(2022) 26 | 27 | # Callbacks 28 | lr_monitor = LearningRateMonitor() # log learning rates for all param groups 29 | model_checkpoint = ModelCheckpoint(dirpath=args.exp_dir / "ckpts", verbose=True) # only save the latest epoch 30 | callbacks = [lr_monitor, model_checkpoint] 31 | 32 | trainer = pl.Trainer( 33 | default_root_dir=args.exp_dir, 34 | callbacks=callbacks, 35 | max_steps=args.max_updates, 36 | strategy="ddp", 37 | accelerator="gpu", 38 | num_nodes=args.num_nodes, 39 | devices=args.gpus, 40 | accumulate_grad_batches=args.accum_grad, 41 | replace_sampler_ddp=False, # we use the custom distributed sampler for ddp 42 | reload_dataloaders_every_n_epochs=1, 43 | gradient_clip_val=args.clip_norm, 44 | log_every_n_steps=args.log_interval, 45 | precision=args.precision, 46 | ) 47 | 48 | # Create teacher model 49 | teacher_ckpt = torch.load(args.teacher_ckpt, map_location='cpu') 50 | teacher_model = wav2vec2_model(**teacher_ckpt['config']) 51 | _LG.info(f"Teacher model:\n{teacher_model}") 52 | teacher_result = teacher_model.load_state_dict(teacher_ckpt['state_dict'], strict=False) 53 | _LG.info(f"Load pretrained ckpt to teacher: missing {teacher_result.missing_keys}, unexpected {teacher_result.unexpected_keys}") 54 | # Freeze teacher model 55 | for p in teacher_model.parameters(): 56 | p.requires_grad = False 57 | _LG.info("Freeze parameters of the teacher model by setting requires_grad=False") 58 | teacher_model.eval() 59 | 60 | # Create student model 61 | student_ckpt = torch.load(args.student_ckpt, map_location='cpu') 62 | student_model = wav2vec2_model(**student_ckpt["config"]) 63 | _LG.info(f"Student model:\n{student_model}") 64 | student_result = student_model.load_state_dict(student_ckpt["state_dict"], strict=False) 65 | _LG.info(f"Load pretrained ckpt to student: missing {student_result.missing_keys}, unexpected {student_result.unexpected_keys}") 66 | 67 | # Load weights to create linear 
layers which transform student hiddens to teacher hiddens 68 | distill_layer_groups = [[int(l) for l in g.split(",")] for g in args.distill_layers.split(".")] 69 | _LG.info(f"Distill transformer layers: {distill_layer_groups}") 70 | distill_layers = [] 71 | for g in distill_layer_groups: 72 | distill_layers.extend(g) 73 | student_embed_dim = student_model.encoder.feature_projection.projection.out_features 74 | teacher_embed_dim = teacher_model.encoder.feature_projection.projection.out_features 75 | 76 | if args.distill_mode == "layer2layer": 77 | distill_linear_projs = nn.ModuleList() 78 | for g in distill_layer_groups: # layers in the same group share a linear layer 79 | tmp_linear = nn.Linear(student_embed_dim, teacher_embed_dim) 80 | for _ in range(len(g)): 81 | distill_linear_projs.append(tmp_linear) 82 | elif args.distill_mode == "predlayer": # same as DistilHuBERT 83 | # use independent linear layers, cannot be shared 84 | distill_linear_projs = nn.ModuleList( 85 | nn.Sequential( 86 | nn.Linear(student_embed_dim, teacher_embed_dim), 87 | nn.GELU(), 88 | ) for _ in range(len(distill_layers)) 89 | ) 90 | else: 91 | raise ValueError(f"Invalid distill mode: {args.distill_mode}") 92 | 93 | distill_linear_projs.load_state_dict(student_ckpt["distill_linear_projs"]) 94 | 95 | # Create DistillLoss module 96 | distill_loss_criterion = DistillLoss( 97 | l2_weight=args.l2_weight, 98 | l1_weight=args.l1_weight, 99 | cos_weight=args.cos_weight, 100 | cos_type=args.cos_type, 101 | ) 102 | _LG.info(f"Distill loss module:\n{distill_loss_criterion}") 103 | 104 | distill_module = DistillModule( 105 | teacher_model=teacher_model, 106 | student_model=student_model, 107 | distill_mode=args.distill_mode, 108 | distill_layers=distill_layers, 109 | distill_linear_projs=distill_linear_projs, 110 | distill_loss=distill_loss_criterion, 111 | learning_rate=args.learning_rate, 112 | weight_decay=args.weight_decay, 113 | warmup_updates=args.warmup_updates, 114 | max_updates=args.max_updates, 115 | use_reg=False, # no pruning, only distillation 116 | reg_learning_rate=None, 117 | target_sparsity=None, 118 | sparsity_warmup_updates=None, 119 | tsv_dir=args.tsv_dir, 120 | train_subset=args.train_subset, 121 | seconds_per_batch=args.seconds_per_batch, 122 | num_workers=args.num_workers, 123 | ) 124 | 125 | trainer.fit( 126 | distill_module, 127 | ckpt_path=args.resume_checkpoint, 128 | ) 129 | 130 | 131 | def _parse_args(): 132 | parser = ArgumentParser( 133 | description="Distill the pruned model.", 134 | ) 135 | 136 | # dataset and dataloader related 137 | parser.add_argument( 138 | "--tsv_dir", 139 | type=pathlib.Path, 140 | required=True, 141 | help="Path to the directory containing tsv files.", 142 | ) 143 | parser.add_argument( 144 | "--train_subset", 145 | default="train100", 146 | choices=["train100", "train960"], 147 | type=str, 148 | help="The subset name for training. (Default: 'train100')", 149 | ) 150 | parser.add_argument( 151 | "--seconds_per_batch", 152 | default=87.5, 153 | type=float, 154 | help="Number of seconds of audio in a mini-batch. (Default: 87.5)", 155 | ) 156 | parser.add_argument( 157 | "--num_workers", 158 | default=1, 159 | type=int, 160 | help="Number of workers in DataLoader." 161 | ) 162 | 163 | # general training related 164 | parser.add_argument( 165 | "--resume_checkpoint", 166 | type=pathlib.Path, 167 | default=None, 168 | help="Path to the feature and label directories. 
(Default: None)", 169 | ) 170 | parser.add_argument( 171 | "--exp_dir", 172 | type=pathlib.Path, 173 | help="Suffix of the exp directory name." 174 | ) 175 | parser.add_argument( 176 | "--log_interval", 177 | default=50, 178 | type=int, 179 | help="Log interval in steps." 180 | ) 181 | parser.add_argument( 182 | "--learning_rate", 183 | default=0.0001, 184 | type=float, 185 | help="The peak learning rate. (Default: 0.0001)", 186 | ) 187 | parser.add_argument( 188 | "--weight_decay", 189 | default=0.0, 190 | type=float, 191 | help="Weight decay (L2 penalty) (Default: 0.0)", 192 | ) 193 | parser.add_argument( 194 | "--warmup_updates", 195 | default=5000, 196 | type=int, 197 | help="Number of steps for warm up the learning rate. (Default: 5000)", 198 | ) 199 | parser.add_argument( 200 | "--max_updates", 201 | default=25000, 202 | type=int, 203 | help="Total number of training steps. (Default: 25000)", 204 | ) 205 | parser.add_argument( 206 | "--clip_norm", 207 | default=10.0, 208 | type=float, 209 | help="The gradient norm value to clip. (Default: 10.0)", 210 | ) 211 | parser.add_argument( 212 | "--num_nodes", 213 | default=1, 214 | type=int, 215 | help="Number of nodes to use for training. (Default: 1)", 216 | ) 217 | parser.add_argument( 218 | "--gpus", 219 | default=4, 220 | type=int, 221 | help="Number of GPUs per node to use for training. (Default: 4)", 222 | ) 223 | parser.add_argument( 224 | "--accum_grad", 225 | default=1, 226 | type=int, 227 | help="Gradient accumulation steps." 228 | ) 229 | parser.add_argument( 230 | "--precision", 231 | default=32, 232 | type=int, 233 | help="Precision for training." 234 | ) 235 | 236 | # distillation related 237 | parser.add_argument( 238 | "--teacher_ckpt", 239 | default=pathlib.Path("pretrained_ckpts/hubert-base-ls960.pth"), 240 | type=pathlib.Path, 241 | help="Path to the teacher model checkpoint." 242 | ) 243 | parser.add_argument( 244 | "--student_ckpt", 245 | type=pathlib.Path, 246 | help="Path to the student model checkpoint (for initialization)." 247 | ) 248 | parser.add_argument( 249 | "--distill_layers", 250 | default="0.4,8,12", 251 | type=str, 252 | help="Distill layer indices (use period to separate groups and comma to separate layers within a group)." 253 | ) 254 | parser.add_argument( 255 | "--distill_mode", 256 | type=str, 257 | default="layer2layer", 258 | choices=["layer2layer", "predlayer"], 259 | help="Distill mode, either layer2layer or predlayer." 260 | ) 261 | parser.add_argument( 262 | "--l2_weight", 263 | default=0.0, 264 | type=float, 265 | help="Weight of MSE loss." 266 | ) 267 | parser.add_argument( 268 | "--l1_weight", 269 | default=1.0, 270 | type=float, 271 | help="Weight of L1 loss." 272 | ) 273 | parser.add_argument( 274 | "--cos_weight", 275 | default=1.0, 276 | type=float, 277 | help="Weight of cosine similarity loss." 278 | ) 279 | parser.add_argument( 280 | "--cos_type", 281 | default="raw", 282 | type=str, 283 | choices=["raw", "log_sig"], 284 | help="Type of the cosine similarity loss." 
285 | ) 286 | 287 | return parser.parse_args() 288 | 289 | 290 | def _init_logger(): 291 | logging.basicConfig( 292 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 293 | datefmt="%Y-%m-%d %H:%M:%S", 294 | level=logging.INFO, 295 | ) 296 | if _get_rank() == 0: 297 | _LG.setLevel(logging.INFO) 298 | else: 299 | _LG.setLevel(logging.WARN) 300 | 301 | 302 | def cli_main(): 303 | _init_logger() 304 | args = _parse_args() 305 | run_train(args) 306 | 307 | 308 | if __name__ == "__main__": 309 | cli_main() 310 | -------------------------------------------------------------------------------- /imgs/dphubert-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyf98/DPHuBERT/c18093fe4b56a0027a80bf9b9b1a23f932cbf14c/imgs/dphubert-results.png -------------------------------------------------------------------------------- /imgs/dphubert-superb-score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyf98/DPHuBERT/c18093fe4b56a0027a80bf9b9b1a23f932cbf14c/imgs/dphubert-superb-score.png -------------------------------------------------------------------------------- /imgs/dphubert-superb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyf98/DPHuBERT/c18093fe4b56a0027a80bf9b9b1a23f932cbf14c/imgs/dphubert-superb.png -------------------------------------------------------------------------------- /imgs/dphubert-train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyf98/DPHuBERT/c18093fe4b56a0027a80bf9b9b1a23f932cbf14c/imgs/dphubert-train.png -------------------------------------------------------------------------------- /lightning.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pathlib 3 | from typing import Optional, List, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.optim.optimizer import Optimizer 8 | from torch.utils.data import DataLoader 9 | import pytorch_lightning as pl 10 | 11 | from wav2vec2.model import ( 12 | Wav2Vec2Model, 13 | ) 14 | from dataset.audio_dataset import ( 15 | BucketizeBatchSampler, 16 | DistributedBatchSampler, 17 | CollateFnAudio, 18 | AudioDataset, 19 | ) 20 | 21 | 22 | class LinearDecayLRScheduler(torch.optim.lr_scheduler._LRScheduler): 23 | """Linear learning rate scheduler with warm up.""" 24 | 25 | def __init__( 26 | self, 27 | optimizer: Optimizer, 28 | warmup_updates: int, 29 | max_updates: int, 30 | last_epoch: int = -1, 31 | verbose: bool = False, 32 | ): 33 | self.warmup_updates = warmup_updates 34 | self.max_updates = max_updates 35 | super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) 36 | 37 | def get_lr(self): 38 | if self._step_count <= self.warmup_updates: 39 | return [self._step_count / self.warmup_updates * base_lr for base_lr in self.base_lrs] 40 | elif self._step_count >= self.max_updates: 41 | return [0.0 for _ in self.base_lrs] 42 | else: 43 | pct_remaining = (self.max_updates - self._step_count) / (self.max_updates - self.warmup_updates) 44 | return [base_lr * pct_remaining for base_lr in self.base_lrs] 45 | 46 | 47 | class TriStageLRScheduler(torch.optim.lr_scheduler._LRScheduler): 48 | """Linear learning rate scheduler with warmup, hold, and decay.""" 49 | 50 | def __init__( 51 | self, 52 | optimizer: Optimizer, 53 | warmup_updates: int, 54 | 
hold_updates: int, 55 | decay_updates: int, 56 | init_lr_scale: float = 0.01, 57 | final_lr_scale: float = 0.05, 58 | last_epoch: int = -1, 59 | verbose: bool = False, 60 | ): 61 | self.warmup_updates = warmup_updates 62 | self.hold_updates = hold_updates 63 | self.decay_updates = decay_updates 64 | self.init_lr_scale = init_lr_scale 65 | self.final_lr_scale = final_lr_scale 66 | 67 | super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) 68 | 69 | def get_lr(self): 70 | if self._step_count <= self.warmup_updates: 71 | return [ 72 | base_lr * (self.init_lr_scale + self._step_count / self.warmup_updates * (1 - self.init_lr_scale)) 73 | for base_lr in self.base_lrs 74 | ] 75 | elif self.warmup_updates < self._step_count <= (self.warmup_updates + self.hold_updates): 76 | return list(self.base_lrs) 77 | elif self._step_count <= (self.warmup_updates + self.hold_updates + self.decay_updates): 78 | return [ 79 | base_lr 80 | * math.exp( 81 | math.log(self.final_lr_scale) 82 | * (self._step_count - self.warmup_updates - self.hold_updates) 83 | / self.decay_updates 84 | ) 85 | for base_lr in self.base_lrs 86 | ] 87 | else: 88 | return [base_lr * self.final_lr_scale for base_lr in self.base_lrs] 89 | 90 | 91 | class DistillLoss(nn.Module): 92 | def __init__(self, l2_weight, l1_weight, cos_weight, cos_type): 93 | super().__init__() 94 | self.l2_weight = l2_weight 95 | self.l1_weight = l1_weight 96 | self.cos_weight = cos_weight 97 | self.cos_type = cos_type 98 | assert cos_type in ["raw", "log_sig"], cos_type 99 | 100 | if l2_weight != 0: 101 | self.mse_loss = nn.MSELoss() 102 | if l1_weight != 0: 103 | self.l1_loss = nn.L1Loss() 104 | if cos_weight != 0: 105 | self.cos_sim = nn.CosineSimilarity(dim=-1) 106 | 107 | def __repr__(self) -> str: 108 | return "{}(l2={}, l1={}, {}_cos={})".format( 109 | self.__class__.__name__, 110 | self.l2_weight, 111 | self.l1_weight, 112 | self.cos_type, 113 | self.cos_weight, 114 | ) 115 | 116 | def forward(self, input: torch.Tensor, target: torch.Tensor): 117 | """ 118 | Args: 119 | input: (batch, layer, time, feature) 120 | target: same shape as input 121 | """ 122 | loss_mse = 0 123 | loss_l1 = 0 124 | loss_cos = 0 125 | if self.l2_weight != 0: 126 | loss_mse = self.mse_loss(input, target) 127 | if self.l1_weight != 0: 128 | loss_l1 = self.l1_loss(input, target) 129 | if self.cos_weight != 0: # maximize cosine similarity 130 | if self.cos_type == "raw": 131 | loss_cos = -self.cos_sim(input, target).mean() 132 | elif self.cos_type == "log_sig": 133 | loss_cos = -self.cos_sim(input, target).sigmoid().log().mean() 134 | else: 135 | raise ValueError 136 | 137 | loss = self.l2_weight * loss_mse + self.l1_weight * loss_l1 + self.cos_weight * loss_cos 138 | 139 | return loss, (loss_mse, loss_l1, loss_cos) 140 | 141 | 142 | class DistillModule(pl.LightningModule): 143 | def __init__( 144 | self, 145 | *, 146 | teacher_model: Wav2Vec2Model, 147 | student_model: Wav2Vec2Model, 148 | distill_mode: str, # "layer2layer", "predlayer" 149 | distill_layers: List[int], # layer indices to align, from 0 to num_layers 150 | distill_linear_projs: nn.ModuleList, # list of linear layers which transform student to teacher 151 | distill_loss: DistillLoss, 152 | learning_rate: float, 153 | weight_decay: float, 154 | warmup_updates: int, 155 | max_updates: int, 156 | use_reg: bool, # whether to use the L0 regularization 157 | reg_learning_rate: Optional[float], # lr for loga and lambda 158 | target_sparsity: Optional[float], 159 | sparsity_warmup_updates: Optional[int], # 
linearly increase the target sparsity 160 | tsv_dir: Union[str, pathlib.Path], 161 | train_subset: str, 162 | seconds_per_batch: float, 163 | num_workers: int, 164 | ): 165 | super().__init__() 166 | 167 | self.teacher_model = teacher_model 168 | self.student_model = student_model 169 | 170 | self.original_num_params = sum(p.numel() for p in teacher_model.parameters()) 171 | 172 | assert distill_mode in ["layer2layer", "predlayer"], distill_mode 173 | assert len(distill_layers) == len(distill_linear_projs) 174 | self.distill_mode = distill_mode 175 | self.distill_layers = distill_layers 176 | self.distill_linear_projs = distill_linear_projs 177 | self.distill_loss = distill_loss 178 | 179 | self.learning_rate = learning_rate 180 | self.weight_decay = weight_decay 181 | self.warmup_updates = warmup_updates 182 | self.max_updates = max_updates 183 | 184 | self.use_reg = use_reg 185 | self.reg_learning_rate = reg_learning_rate 186 | self.target_sparsity = target_sparsity 187 | self.sparsity_warmup_updates = sparsity_warmup_updates 188 | 189 | # lambdas for Lagrangian 190 | if self.use_reg: 191 | self.lambda1 = nn.Parameter(torch.tensor(0.0)) 192 | self.lambda2 = nn.Parameter(torch.tensor(0.0)) 193 | 194 | # dataset related 195 | self.tsv_dir = tsv_dir 196 | self.train_subset = train_subset 197 | self.seconds_per_batch = seconds_per_batch 198 | self.num_workers = num_workers 199 | 200 | def configure_optimizers(self): 201 | main_params = [p for n, p in self.student_model.named_parameters() if "log_alpha" not in n] 202 | main_params.extend(list(self.distill_linear_projs.parameters())) 203 | pgs = [ 204 | { 205 | 'params': main_params, 206 | 'lr': self.learning_rate, 207 | 'weight_decay': self.weight_decay, 208 | 'name': 'main_params', 209 | }, 210 | ] 211 | if self.use_reg: 212 | pgs.extend( 213 | [ 214 | { 215 | 'params': [p for n, p in self.student_model.named_parameters() if "log_alpha" in n], 216 | 'lr': self.reg_learning_rate, 217 | 'weight_decay': 0.0, 218 | 'name': 'log_alpha', 219 | }, 220 | { 221 | 'params': [self.lambda1, self.lambda2], 222 | 'lr': -self.reg_learning_rate, 223 | 'weight_decay': 0.0, 224 | 'name': 'lambda', 225 | }, 226 | ] 227 | ) 228 | optimizer = torch.optim.AdamW(pgs) 229 | lr_scheduler = LinearDecayLRScheduler( 230 | optimizer, warmup_updates=self.warmup_updates, max_updates=self.max_updates 231 | ) 232 | return { 233 | 'optimizer': optimizer, 234 | 'lr_scheduler': { 235 | "scheduler": lr_scheduler, 236 | "interval": "step", 237 | }, 238 | } 239 | 240 | def _get_target_sparsity(self): 241 | if self.global_step >= self.sparsity_warmup_updates: 242 | return self.target_sparsity 243 | return self.target_sparsity * (self.global_step / self.sparsity_warmup_updates) 244 | 245 | def _step(self, batch, batch_idx, mode): 246 | waveforms, lengths = batch 247 | self.teacher_model.eval() 248 | with torch.no_grad(): 249 | teacher_hiddens, teacher_lengths = self.teacher_model.extract_features(waveforms, lengths) 250 | teacher_hiddens = torch.stack( 251 | [teacher_hiddens[idx] for idx in self.distill_layers], dim=1 252 | ) # (batch, layer, time, feature) 253 | 254 | student_hiddens, student_lengths = self.student_model.extract_features(waveforms, lengths) 255 | new_student_hiddens = [] 256 | for idx, proj in zip(self.distill_layers, self.distill_linear_projs): 257 | if self.distill_mode == "layer2layer": 258 | new_student_hiddens.append(proj(student_hiddens[idx])) 259 | elif self.distill_mode == "predlayer": 260 | new_student_hiddens.append(proj(student_hiddens[-1])) 261 | else: 
262 | raise ValueError(f"Invalid distill mode: {self.distill_mode}") 263 | student_hiddens = torch.stack(new_student_hiddens, dim=1) # (batch, layer, time, feature) 264 | 265 | loss_distill, (loss_mse, loss_l1, loss_cos) = self.distill_loss(student_hiddens, teacher_hiddens) 266 | 267 | if self.use_reg: 268 | cur_target_sparsity = self._get_target_sparsity() 269 | cur_expected_sparsity = 1. - self.student_model.get_num_params() / self.original_num_params 270 | loss_reg = self.lambda1 * (cur_expected_sparsity - cur_target_sparsity) \ 271 | + self.lambda2 * (cur_expected_sparsity - cur_target_sparsity)**2 272 | else: 273 | loss_reg = 0 274 | 275 | loss = loss_distill + loss_reg 276 | 277 | self.log_dict( 278 | { 279 | f"{mode}_loss": loss, # total loss 280 | f"{mode}_loss_distill": loss_distill, # distill total loss 281 | f"{mode}_loss_mse": loss_mse, 282 | f"{mode}_loss_l1": loss_l1, 283 | f"{mode}_loss_cos": loss_cos, 284 | f"{mode}_loss_reg": loss_reg, # sparsity loss 285 | } 286 | ) 287 | if mode == "train" and self.use_reg: 288 | self.log_dict( 289 | { 290 | 'sparsity_expected': cur_expected_sparsity, 291 | 'sparsity_target': cur_target_sparsity, 292 | 'lambda1': self.lambda1, 293 | 'lambda2': self.lambda2, 294 | }, 295 | ) 296 | return loss 297 | 298 | def training_step(self, batch, batch_idx): 299 | loss = self._step(batch, batch_idx, mode="train") 300 | return loss 301 | 302 | def validation_step(self, batch, batch_idx): 303 | loss = self._step(batch, batch_idx, mode="valid") 304 | return loss 305 | 306 | def train_dataloader(self): 307 | dataset = AudioDataset(self.tsv_dir, self.train_subset) 308 | sampler = BucketizeBatchSampler( 309 | dataset.len_list, 310 | num_buckets=1000, 311 | max_token_count=self.seconds_per_batch * 16000, 312 | min_len=32000, 313 | max_len=250000, 314 | shuffle=False, 315 | ) 316 | sampler = DistributedBatchSampler(sampler, shuffle=True) 317 | sampler.set_epoch(self.current_epoch) 318 | dataloader = DataLoader( 319 | dataset, 320 | batch_sampler=sampler, 321 | collate_fn=CollateFnAudio(pad=False, rand_crop=True), # crop to the min length in a mini-batch 322 | num_workers=self.num_workers, 323 | ) 324 | return dataloader 325 | 326 | def val_dataloader(self): 327 | dataset = AudioDataset(self.tsv_dir, "valid") 328 | sampler = BucketizeBatchSampler( 329 | dataset.len_list, 330 | num_buckets=1000, 331 | max_token_count=self.seconds_per_batch * 16000, 332 | min_len=32000, 333 | max_len=250000, 334 | shuffle=False, 335 | ) 336 | dataloader = DataLoader( 337 | dataset, 338 | batch_sampler=sampler, 339 | collate_fn=CollateFnAudio(pad=False, rand_crop=True), 340 | num_workers=self.num_workers, 341 | ) 342 | return dataloader 343 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | """Prepare audio data for compressing speech SSL.""" 2 | 3 | from argparse import ArgumentParser 4 | from pathlib import Path 5 | from typing import Union 6 | from tqdm import tqdm 7 | 8 | import torchaudio 9 | 10 | 11 | def create_tsv( 12 | root_dir: Union[str, Path], 13 | out_dir: Union[str, Path], 14 | extension: str = "flac", 15 | ) -> None: 16 | """Create file lists for training and validation. 17 | Args: 18 | root_dir (str or Path): The directory of the dataset. 19 | out_dir (str or Path): The directory to store the file lists. 20 | extension (str, optional): The extension of audio files. 
(Default: ``flac``) 21 | 22 | Returns: 23 | None 24 | """ 25 | root_dir = Path(root_dir) 26 | out_dir = Path(out_dir) 27 | 28 | if not out_dir.exists(): 29 | out_dir.mkdir() 30 | 31 | with open( 32 | out_dir / "train100.tsv", "w" 33 | ) as train100_f, open( 34 | out_dir / "train960.tsv", "w" 35 | ) as train960_f, open( 36 | out_dir / "valid.tsv", "w" 37 | ) as valid_f: 38 | print(root_dir, file=train100_f) 39 | print(root_dir, file=train960_f) 40 | print(root_dir, file=valid_f) 41 | 42 | for fname in tqdm(root_dir.glob(f"**/*.{extension}")): 43 | line = f"{fname.relative_to(root_dir)}\t{torchaudio.info(fname).num_frames}" 44 | 45 | if "train-clean-100" in str(fname): 46 | print(line, file=train100_f) 47 | if "train" in str(fname): 48 | print(line, file=train960_f) 49 | if "dev" in str(fname): 50 | print(line, file=valid_f) 51 | 52 | print("Finished creating the file lists successfully") 53 | 54 | 55 | def parse_args(): 56 | parser = ArgumentParser( 57 | description="Prepare audio data." 58 | ) 59 | parser.add_argument( 60 | "--data", 61 | type=Path, 62 | required=True, 63 | help="Path to the original dataset." 64 | ) 65 | parser.add_argument( 66 | "--out", 67 | type=Path, 68 | default=Path("data/librispeech"), 69 | help="Path to save the output." 70 | ) 71 | args = parser.parse_args() 72 | return args 73 | 74 | 75 | if __name__ == "__main__": 76 | args = parse_args() 77 | 78 | assert args.data.is_dir(), args.data 79 | args.out.mkdir(parents=True, exist_ok=True) 80 | 81 | create_tsv( 82 | root_dir=args.data, 83 | out_dir=args.out, 84 | ) 85 | -------------------------------------------------------------------------------- /prune.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | import torch 4 | from argparse import ArgumentParser 5 | 6 | from wav2vec2.model import ( 7 | wav2vec2_model, 8 | ) 9 | 10 | 11 | def prune_from_ckpt(distilled_ckpt, original_ckpt): 12 | ckpt = torch.load(distilled_ckpt, map_location='cpu') 13 | student_model_state_dict = { 14 | k[len("student_model."):]: v for k, v in ckpt["state_dict"].items() if k.startswith("student_model.") 15 | } 16 | distill_linear_projs_state_dict = { 17 | k[len("distill_linear_projs."):]: v for k, v in ckpt["state_dict"].items() if k.startswith("distill_linear_projs.") 18 | } 19 | config = torch.load(original_ckpt, map_location='cpu')['config'] 20 | config.update( 21 | dict( 22 | extractor_prune_conv_channels="feature_extractor.conv_layers.0.hard_concrete.log_alpha" in student_model_state_dict, 23 | encoder_prune_attention_heads="encoder.transformer.layers.0.attention.hard_concrete_for_heads.log_alpha" in student_model_state_dict, 24 | encoder_prune_attention_layer="encoder.transformer.layers.0.attention.hard_concrete_for_layer.log_alpha" in student_model_state_dict, 25 | encoder_prune_feed_forward_intermediate="encoder.transformer.layers.0.feed_forward.hard_concrete_for_intermediate.log_alpha" in student_model_state_dict, 26 | encoder_prune_feed_forward_layer="encoder.transformer.layers.0.feed_forward.hard_concrete_for_layer.log_alpha" in student_model_state_dict, 27 | ) 28 | ) 29 | model = wav2vec2_model(**config) 30 | model.load_state_dict(student_model_state_dict, strict=True) 31 | 32 | conv_config, use_attention, use_feed_forward, num_heads, remaining_heads, ff_interm_features = model.prune() 33 | pruned_config = config.copy() 34 | if len(num_heads) == 0: # for wavlm 35 | assert len(remaining_heads) > 0 36 | pruned_config.update( 37 | { 38 | 
"encoder_remaining_heads": remaining_heads, 39 | } 40 | ) 41 | else: 42 | pruned_config.update( 43 | { 44 | "encoder_num_heads": num_heads, 45 | } 46 | ) 47 | pruned_config.update( 48 | { 49 | "extractor_conv_layer_config": conv_config, 50 | "encoder_use_attention": use_attention, 51 | "encoder_use_feed_forward": use_feed_forward, 52 | "encoder_ff_interm_features": ff_interm_features, 53 | "extractor_prune_conv_channels": False, 54 | "encoder_prune_attention_heads": False, 55 | "encoder_prune_attention_layer": False, 56 | "encoder_prune_feed_forward_intermediate": False, 57 | "encoder_prune_feed_forward_layer": False, 58 | } 59 | ) 60 | print(json.dumps(pruned_config, indent=4)) 61 | 62 | ret = { 63 | "state_dict": model.state_dict(), 64 | "config": pruned_config, 65 | "distill_linear_projs": distill_linear_projs_state_dict, 66 | } 67 | return ret 68 | 69 | 70 | def load_pruned_model(ckpt_path): 71 | ckpt = torch.load(ckpt_path, map_location='cpu') 72 | model = wav2vec2_model(**ckpt["config"]) 73 | model.load_state_dict(ckpt["state_dict"], strict=True) 74 | return model 75 | 76 | 77 | def parse_args(): 78 | parser = ArgumentParser(description="Prune and save distilled model.") 79 | parser.add_argument( 80 | "--distilled_ckpt", 81 | type=pathlib.Path, 82 | help="Path to the distilled model checkpoint." 83 | ) 84 | parser.add_argument( 85 | "--original_ckpt", 86 | type=pathlib.Path, 87 | help="Path to the original checkpoint." 88 | ) 89 | args = parser.parse_args() 90 | return args 91 | 92 | 93 | if __name__ == "__main__": 94 | args = parse_args() 95 | out_path = args.distilled_ckpt.parent / "pruned_hubert_base.pth" 96 | torch.save( 97 | prune_from_ckpt( 98 | distilled_ckpt=args.distilled_ckpt, 99 | original_ckpt=args.original_ckpt 100 | ), 101 | out_path 102 | ) 103 | 104 | # Check if loading from ckpt works 105 | load_pruned_model(out_path) 106 | 107 | print(f"Successfully saved pruned model weights and config to: {out_path}") 108 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --gres=gpu:4 4 | #SBATCH --ntasks-per-node=4 5 | #SBATCH --cpus-per-task=16 6 | #SBATCH --mem=240000M 7 | #SBATCH --partition=gpuA100x4 8 | #SBATCH --job-name=dphubert 9 | #SBATCH --time=2-00:00:00 10 | 11 | # first source conda.sh, and then 12 | # activate your conda environment 13 | 14 | set -x 15 | 16 | # shared config 17 | tsv_dir=data/librispeech # data path 18 | train_subset=train960 # train subset name: train960, train100 19 | teacher_ckpt=pretrained/hubert-base-ls960.hf.pth # checkpoint path 20 | student_ckpt=${teacher_ckpt} # student initialization, same as teacher 21 | distill_layers=0.4,8,12 # use period to separate groups where each group shares the same linear layer: [0], [4, 8, 12] 22 | distill_mode=layer2layer # "layer2layer", "predlayer" 23 | l2_weight=0 # weight for L2 loss 24 | l1_weight=1 # weight for L1 loss 25 | cos_weight=1 # weight for cosine similarity 26 | cos_type=raw # "raw", "log_sig" 27 | 28 | # distill config 29 | lr=0.0002 # learning rate 30 | warmup=15000 # warmup steps 31 | max=50000 # max update steps 32 | pruning_units=conv,head,interm # conv,head,interm,attlayer,ffnlayer 33 | reg_lr=0.02 # learning rate for regularization params 34 | target_sparsity=0.75 # final target sparsity 35 | sparsity_warmup=5000 # warmup steps for sparsity; sparsity will linearly increase from 0 to target 36 | 
root_dir=exp/hubert-base_${train_subset}_sp${target_sparsity}_spup${sparsity_warmup}_lr${lr}_up${warmup}_max${max}_${distill_mode}${distill_layers}_reglr${reg_lr}_${pruning_units} 37 | 38 | # final distill config 39 | final_lr=0.0001 # learning rate for final distillation (training step 2) 40 | final_warmup=5000 # warmup steps 41 | final_max=25000 # max update steps 42 | final_exp_dir=${root_dir}/lr${final_lr}_up${final_warmup}_max${final_max} 43 | 44 | 45 | # Training step 1: distill 46 | mkdir -p ${root_dir} 47 | 48 | srun python distill.py \ 49 | --tsv_dir ${tsv_dir} \ 50 | --train_subset ${train_subset} \ 51 | --seconds_per_batch 160 \ 52 | --num_workers 12 \ 53 | --exp_dir ${root_dir} \ 54 | --log_interval 50 \ 55 | --learning_rate ${lr} \ 56 | --weight_decay 0.0 \ 57 | --warmup_updates ${warmup} \ 58 | --max_updates ${max} \ 59 | --clip_norm 10.0 \ 60 | --num_nodes 1 \ 61 | --gpus 4 \ 62 | --accum_grad 1 \ 63 | --precision 16 \ 64 | --teacher_ckpt ${teacher_ckpt} \ 65 | --student_ckpt ${student_ckpt} \ 66 | --distill_layers ${distill_layers} \ 67 | --distill_mode ${distill_mode} \ 68 | --l2_weight ${l2_weight} \ 69 | --l1_weight ${l1_weight} \ 70 | --cos_weight ${cos_weight} \ 71 | --cos_type ${cos_type} \ 72 | --pruning_units ${pruning_units} \ 73 | --reg_learning_rate ${reg_lr} \ 74 | --target_sparsity ${target_sparsity} \ 75 | --sparsity_warmup_updates ${sparsity_warmup} 2>&1 | tee ${root_dir}/distill.log || exit 1; 76 | 77 | # prune and save model 78 | python prune.py \ 79 | --distilled_ckpt ${root_dir}/ckpts/*.ckpt \ 80 | --original_ckpt ${student_ckpt} || exit 1; 81 | 82 | 83 | # Training step 2: final distill 84 | pruned_ckpt=${root_dir}/ckpts/pruned_hubert_base.pth 85 | mkdir -p ${final_exp_dir} 86 | 87 | srun python final_distill.py \ 88 | --tsv_dir ${tsv_dir} \ 89 | --train_subset ${train_subset} \ 90 | --seconds_per_batch 160 \ 91 | --num_workers 12 \ 92 | --exp_dir ${final_exp_dir} \ 93 | --log_interval 50 \ 94 | --learning_rate ${final_lr} \ 95 | --weight_decay 0.0 \ 96 | --warmup_updates ${final_warmup} \ 97 | --max_updates ${final_max} \ 98 | --clip_norm 10.0 \ 99 | --num_nodes 1 \ 100 | --gpus 4 \ 101 | --accum_grad 1 \ 102 | --precision 16 \ 103 | --teacher_ckpt ${teacher_ckpt} \ 104 | --student_ckpt ${pruned_ckpt} \ 105 | --distill_layers ${distill_layers} \ 106 | --distill_mode ${distill_mode} \ 107 | --l2_weight ${l2_weight} \ 108 | --l1_weight ${l1_weight} \ 109 | --cos_weight ${cos_weight} \ 110 | --cos_type ${cos_type} 2>&1 | tee ${final_exp_dir}/final_distill.log || exit 1; 111 | 112 | # save final model and config 113 | python save_final_ckpt.py \ 114 | --config_path ${pruned_ckpt} \ 115 | --ckpt_after_final_distill ${final_exp_dir}/ckpts/*.ckpt || exit 1; 116 | -------------------------------------------------------------------------------- /save_final_ckpt.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import json 3 | import pathlib 4 | import torch 5 | 6 | from prune import load_pruned_model 7 | 8 | 9 | def parse_args(): 10 | parser = ArgumentParser(description="Save ckpt and config after final distill.") 11 | parser.add_argument( 12 | "--config_path", 13 | type=pathlib.Path, 14 | help="Path to the checkpoint file containing the pruned config." 15 | ) 16 | parser.add_argument( 17 | "--ckpt_after_final_distill", 18 | type=pathlib.Path, 19 | help="Path to the checkpoint file after final distill." 
20 | ) 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | if __name__ == "__main__": 26 | args = parse_args() 27 | config = torch.load(args.config_path, map_location="cpu")["config"] 28 | print(json.dumps(config, indent=4)) 29 | 30 | ckpt = torch.load(args.ckpt_after_final_distill, map_location="cpu") 31 | student_model_state_dict = { 32 | k[len("student_model."):]: v for k, v in ckpt["state_dict"].items() if k.startswith("student_model.") 33 | } 34 | distill_linear_projs_state_dict = { 35 | k[len("distill_linear_projs."):]: v for k, v in ckpt["state_dict"].items() if k.startswith("distill_linear_projs.") 36 | } 37 | 38 | out_path = args.ckpt_after_final_distill.parent / "pruned_hubert_base.pth" 39 | torch.save( 40 | { 41 | "state_dict": student_model_state_dict, 42 | "config": config, 43 | "distill_linear_projs": distill_linear_projs_state_dict, 44 | }, 45 | out_path 46 | ) 47 | 48 | load_pruned_model(out_path) # verify if it works 49 | print(f"Successfully saved pruned model weights and config to: {out_path}") 50 | -------------------------------------------------------------------------------- /wav2vec2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyf98/DPHuBERT/c18093fe4b56a0027a80bf9b9b1a23f932cbf14c/wav2vec2/__init__.py -------------------------------------------------------------------------------- /wav2vec2/components.py: -------------------------------------------------------------------------------- 1 | """Building blocks for speech SSL models supporting pruning. 2 | 3 | Originally from: 4 | https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components.py 5 | 6 | """ 7 | 8 | from collections import defaultdict 9 | from typing import List, Optional, Tuple 10 | import math 11 | 12 | import torch 13 | from torch import nn, Tensor 14 | from torch.nn import Module, Parameter 15 | 16 | from .hardconcrete import HardConcrete 17 | from .pruning_utils import ( 18 | prune_linear_layer, 19 | prune_conv1d_layer, 20 | prune_layer_norm, 21 | ) 22 | 23 | 24 | def _init_transformer_params(module): 25 | """ 26 | Initialize the weights of Transformer module in Wav2Vec2/HuBERT. 27 | 28 | If the module is ``nn.Linear``, normalize the weight with mean 0 and standard deviation 0.02. 29 | If ``bias`` is set to ``True`` in the module, set ``bias`` to 0. 30 | 31 | If the module is ``nn.Embedding``, normalize the weight with mean 0 and standard deviation 0.02. 32 | If ``padding_idx`` is not None, set the weight of padding to 0. 33 | 34 | Note: 35 | Ths method corresponds to 36 | `init_bert_params 37 | `__ 38 | in the original ``fairseq`` implementation. 
39 | """ 40 | 41 | def normal_(data): 42 | data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) 43 | 44 | if isinstance(module, nn.Linear): 45 | normal_(module.weight.data) 46 | if module.bias is not None: 47 | module.bias.data.zero_() 48 | if isinstance(module, nn.Embedding): 49 | normal_(module.weight.data) 50 | if module.padding_idx is not None: 51 | module.weight.data[module.padding_idx].zero_() 52 | 53 | 54 | class LayerNorm(nn.LayerNorm): 55 | """Layer norm with transpose""" 56 | 57 | def forward(self, input: Tensor) -> Tensor: 58 | x = input.transpose(-2, -1) 59 | x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 60 | x = x.transpose(-2, -1) 61 | return x 62 | 63 | 64 | class ConvLayerBlock(Module): 65 | """Convolution unit of FeatureExtractor""" 66 | 67 | def __init__( 68 | self, 69 | in_channels: int, 70 | out_channels: int, 71 | kernel_size: int, 72 | stride: int, 73 | bias: bool, 74 | layer_norm: Optional[Module], 75 | prune_conv_channels: bool = False, 76 | ): 77 | super().__init__() 78 | self.kernel_size = kernel_size 79 | self.stride = stride 80 | self.layer_norm = layer_norm 81 | self.conv = nn.Conv1d( 82 | in_channels=in_channels, 83 | out_channels=out_channels, 84 | kernel_size=kernel_size, 85 | stride=stride, 86 | bias=bias, 87 | ) 88 | 89 | if prune_conv_channels: 90 | self.hard_concrete = HardConcrete(n_in=out_channels, init_mean=0.01) 91 | else: 92 | self.hard_concrete = None 93 | 94 | def forward( 95 | self, 96 | x: Tensor, 97 | length: Optional[Tensor], 98 | ) -> Tuple[Tensor, Optional[Tensor]]: 99 | """ 100 | Args: 101 | x (Tensor): Shape: ``[batch, in_channels, in_frame]``. 102 | length (Tensor or None, optional): Shape ``[batch, ]``. 103 | Returns: 104 | Tensor: Shape ``[batch, out_channels, out_frames]``. 105 | Optional[Tensor]: Shape ``[batch, ]``. 106 | """ 107 | x = self.conv(x) 108 | if self.layer_norm is not None: 109 | x = self.layer_norm(x) 110 | x = nn.functional.gelu(x) 111 | 112 | if self.hard_concrete is not None: 113 | channel_mask = self.hard_concrete() # hard concrete mask, (out_channels,) 114 | x = x * channel_mask.unsqueeze(-1) 115 | 116 | if length is not None: 117 | length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1 118 | # When input length is 0, the resulting length can be negative. So fix it here. 
119 | length = torch.max(torch.zeros_like(length), length) 120 | return x, length 121 | 122 | def get_num_params_and_out_channels(self, in_channels): 123 | if self.hard_concrete is not None: 124 | out_channels = self.hard_concrete.l0_norm() 125 | else: 126 | out_channels = self.conv.out_channels 127 | 128 | num_params = in_channels * out_channels * self.kernel_size 129 | if self.conv.bias is not None: 130 | num_params += out_channels 131 | if self.layer_norm is not None: 132 | num_params += out_channels * 2 133 | 134 | return num_params, out_channels 135 | 136 | 137 | class FeatureExtractor(Module): 138 | """Extract features from audio 139 | 140 | Args: 141 | conv_layers (nn.ModuleList): 142 | convolution layers 143 | """ 144 | 145 | def __init__( 146 | self, 147 | conv_layers: nn.ModuleList, 148 | ): 149 | super().__init__() 150 | self.conv_layers = conv_layers 151 | 152 | # NOTE: a dummy weight used to save the soft mask of the last conv layer 153 | self.dummy_weight = nn.Parameter( 154 | torch.ones(conv_layers[-1].conv.out_channels, dtype=torch.float32), 155 | requires_grad=False 156 | ) 157 | 158 | def forward( 159 | self, 160 | x: Tensor, 161 | length: Optional[Tensor], 162 | ) -> Tuple[Tensor, Optional[Tensor]]: 163 | """ 164 | Args: 165 | x (Tensor): 166 | Input Tensor representing a batch of audio, 167 | shape: ``[batch, time]``. 168 | length (Tensor or None, optional): 169 | Valid length of each input sample. shape: ``[batch, ]``. 170 | 171 | Returns: 172 | Tensor: 173 | The resulting feature, shape: ``[batch, frame, feature]`` 174 | Optional[Tensor]: 175 | Valid length of each output sample. shape: ``[batch, ]``. 176 | """ 177 | if x.ndim != 2: 178 | raise ValueError("Expected the input Tensor to be 2D (batch, time), " "but received {list(x.shape)}") 179 | 180 | x = x.unsqueeze(1) # (batch, channel==1, frame) 181 | for layer in self.conv_layers: 182 | x, length = layer(x, length) # (batch, feature, frame) 183 | x = x.transpose(1, 2) # (batch, frame, feature) 184 | x = x * self.dummy_weight 185 | return x, length 186 | 187 | def get_num_params_and_final_out_channels(self): 188 | in_channels = 1 189 | num_params = 0 190 | for layer in self.conv_layers: 191 | layer_params, in_channels = layer.get_num_params_and_out_channels(in_channels) 192 | num_params += layer_params 193 | 194 | num_params += in_channels # dummy weight 195 | 196 | return num_params, in_channels 197 | 198 | def prune(self): 199 | """"Prune conv layers and dummy weight based on hardconcrete parameters. 200 | This is an in-place operation. 201 | """ 202 | new_config = [] # [(output_channel, kernel_size, stride), ...] 
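        # The loop below keeps, for each conv layer, only the channels whose
        # HardConcrete mask is nonzero: it prunes the layer's output channels
        # (and its layer norm), then the next layer's input channels (or
        # dummy_weight for the last layer). A hypothetical result for a
        # 512-channel extractor could look like:
        #   new_config == [(384, 10, 5), (410, 3, 2), ...]  # (out_channels, kernel_size, stride)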
203 | for idx, layer in enumerate(self.conv_layers): 204 | if layer.hard_concrete is not None: 205 | assert not layer.hard_concrete.training 206 | mask = layer.hard_concrete() # (out_features,) 207 | index = mask.nonzero().squeeze(-1) # 2D -> 1D 208 | assert len(index) > 0, f"Conv channels pruned to zero at index {idx}" 209 | new_config.append( 210 | (len(index), layer.kernel_size, layer.stride) 211 | ) 212 | 213 | # prune the current layer 214 | prune_conv1d_layer(layer.conv, index, "output") 215 | if layer.layer_norm is not None: 216 | prune_layer_norm(layer.layer_norm, index) 217 | 218 | # prune the next layer 219 | if idx == len(self.conv_layers) - 1: 220 | self.dummy_weight.data *= mask 221 | self.dummy_weight = nn.Parameter( 222 | self.dummy_weight.index_select(0, index).clone().detach(), requires_grad=False 223 | ) 224 | else: 225 | self.conv_layers[idx+1].conv.weight.data *= mask.unsqueeze(-1) 226 | prune_conv1d_layer(self.conv_layers[idx+1].conv, index, dim="input") 227 | 228 | layer.hard_concrete = None 229 | else: 230 | new_config.append( 231 | (layer.conv.out_channels, layer.kernel_size, layer.stride) 232 | ) 233 | index = torch.arange(layer.conv.out_channels, dtype=torch.long) 234 | 235 | return new_config, index 236 | 237 | 238 | class FeatureProjection(Module): 239 | """Layer that connects FeatureExtractor and Encoder 240 | 241 | Projects features to encoder dimension. 242 | 243 | Args: 244 | in_features (int): Input feature dim. 245 | out_features (int): Output feature dim. 246 | dropout (float): Dropout probability. 247 | """ 248 | 249 | def __init__( 250 | self, 251 | in_features: int, 252 | out_features: int, 253 | dropout: float, 254 | ): 255 | super().__init__() 256 | self.layer_norm = nn.LayerNorm(in_features) 257 | self.projection = nn.Linear( 258 | in_features, 259 | out_features, 260 | ) 261 | self.dropout = nn.Dropout(dropout) 262 | 263 | def forward(self, x): 264 | """ 265 | Args: 266 | x (Tensor): 267 | Feature Tensor. shape: ``[batch, frame, in_feature]`` 268 | Returns: 269 | Tensor: Projected features. ``[batch, frame, out_feature]``. 270 | """ 271 | x = self.layer_norm(x) 272 | x = self.projection(x) 273 | x = self.dropout(x) 274 | return x 275 | 276 | def get_num_params(self, in_features): 277 | return in_features * 2 + (in_features + 1) * self.projection.out_features 278 | 279 | 280 | class ConvolutionalPositionalEmbedding(Module): 281 | """Positional embedding which is placed at the beginning of Transformer. 282 | 283 | Args: 284 | embed_dim (int): Feature dimension of the input Tensor. 285 | kernel_size (int): The number of frames to be use. 286 | groups (int): The number of groups in feature dimensions. 287 | """ 288 | 289 | def __init__( 290 | self, 291 | embed_dim: int, 292 | kernel_size: int, 293 | groups: int, 294 | ): 295 | super().__init__() 296 | self.embed_dim = embed_dim 297 | self.kernel_size = kernel_size 298 | self.conv = nn.Conv1d( 299 | in_channels=embed_dim, 300 | out_channels=embed_dim, 301 | kernel_size=kernel_size, 302 | padding=kernel_size // 2, 303 | groups=groups, 304 | ) 305 | 306 | self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) 307 | self.num_remove: int = 1 if kernel_size % 2 == 0 else 0 308 | 309 | def __prepare_scriptable__(self): 310 | for hook in self.conv._forward_pre_hooks.values(): 311 | # The hook we want to remove is an instance of WeightNorm class, so 312 | # normally we would do `if isinstance(...)` but this class is not accessible 313 | # because of shadowing, so we check the module name directly. 
314 | # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 315 | if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm": 316 | torch.nn.utils.remove_weight_norm(self.conv) 317 | return self 318 | 319 | def forward(self, x): 320 | """ 321 | Args: 322 | x (Tensor): shape ``[batch, frame, feature]``. 323 | 324 | Returns: 325 | Tensor: The resulting feature. Shape ``[batch, frame, feature]``. 326 | """ 327 | x = x.transpose(-2, -1) 328 | x = self.conv(x) 329 | if self.num_remove > 0: 330 | x = x[..., : -self.num_remove] 331 | x = torch.nn.functional.gelu(x) 332 | x = x.transpose(-2, -1) 333 | return x 334 | 335 | 336 | class SelfAttention(Module): 337 | """Multihead Self Attention module 338 | 339 | Args: 340 | embed_dim (int): Total dimension of the model. 341 | num_heads (int): The number of heads. 342 | dropout (float, optional): 343 | Dropout probability on attn_output_weights. Default: ``0.0`` 344 | """ 345 | 346 | def __init__( 347 | self, 348 | embed_dim: int, 349 | num_heads: int, 350 | head_dim: int, 351 | dropout: float = 0.0, 352 | prune_heads: bool = False, # whether to prune attention heads 353 | prune_layer: bool = False, # whether to prune entire attention layers 354 | ): 355 | super().__init__() 356 | 357 | self.embed_dim = embed_dim 358 | self.num_heads = num_heads 359 | self.head_dim = head_dim 360 | self.dropout = torch.nn.Dropout(dropout) 361 | 362 | self.scaling = self.head_dim**-0.5 363 | 364 | self.k_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) 365 | self.v_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) 366 | self.q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) 367 | self.out_proj = nn.Linear(num_heads * head_dim, embed_dim, bias=True) 368 | 369 | if prune_heads: 370 | self.hard_concrete_for_heads = HardConcrete(n_in=num_heads, init_mean=0.01) 371 | else: 372 | self.hard_concrete_for_heads = None 373 | 374 | if prune_layer: 375 | self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) 376 | else: 377 | self.hard_concrete_for_layer = None 378 | 379 | def forward( 380 | self, 381 | x: Tensor, 382 | attention_mask: Optional[Tensor] = None, 383 | position_bias: Optional[Tensor] = None, 384 | key_padding_mask: Optional[Tensor] = None, 385 | ) -> Tuple[Tensor, Optional[Tensor]]: 386 | """ 387 | Args: 388 | x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``. 389 | attention_mask (Tensor or ``None``, optional): 390 | shape: ``[batch_size, 1, sequence_length, sequence_length]`` 391 | position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`. 392 | key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with 393 | :py:class:`WavLMSelfAttention`. 394 | Returns: 395 | (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility 396 | with :py:class:`WavLMSelAttention`). 397 | Attention output shape: ``[batch, sequence_length, embed_dim]``. 398 | """ 399 | if x.ndim != 3 or x.shape[2] != self.embed_dim: 400 | raise ValueError( 401 | f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}." 
402 | ) 403 | batch_size, length, embed_dim = x.size() 404 | 405 | shape = (batch_size, length, self.num_heads, self.head_dim) 406 | q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd 407 | k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L 408 | v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd 409 | 410 | # scale down q to avoid value overflow. 411 | weights = (self.scaling * q) @ k # B, nH, L, L 412 | if attention_mask is not None: 413 | weights += attention_mask 414 | # subtracting a constant value from the tensor won't change the output of softmax. 415 | # apply the subtraction to avoid value overflow in torch.nn.functional.softmax. 416 | # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778 417 | weights = weights - weights.max(dim=-1, keepdim=True)[0] 418 | 419 | weights = torch.nn.functional.softmax(weights, dim=-1) 420 | weights = self.dropout(weights) 421 | 422 | output = weights @ v # B, nH, L, Hd 423 | 424 | if self.hard_concrete_for_heads is not None: 425 | head_mask = self.hard_concrete_for_heads() # (nH,) 426 | output = output * head_mask.unsqueeze(-1).unsqueeze(-1) 427 | 428 | output = output.transpose(2, 1).reshape(batch_size, length, self.num_heads * self.head_dim) 429 | 430 | output = self.out_proj(output) 431 | 432 | if self.hard_concrete_for_layer is not None: 433 | layer_mask = self.hard_concrete_for_layer() # (1,) 434 | output = output * layer_mask 435 | 436 | return output, None # Necessary for compatibility with WavLMSelAttention 437 | 438 | def get_num_params(self): 439 | if self.hard_concrete_for_heads is not None: 440 | num_heads = self.hard_concrete_for_heads.l0_norm() 441 | else: 442 | num_heads = self.num_heads 443 | num_params = (self.embed_dim + 1) * num_heads * self.head_dim * 3 \ 444 | + (num_heads * self.head_dim + 1) * self.embed_dim 445 | 446 | if self.hard_concrete_for_layer is not None: 447 | num_params *= self.hard_concrete_for_layer.l0_norm() 448 | 449 | return num_params 450 | 451 | def prune(self): 452 | new_config = { 453 | "use_attention": True, 454 | "num_heads": self.num_heads, 455 | } 456 | if self.hard_concrete_for_layer is not None: 457 | assert not self.hard_concrete_for_layer.training 458 | layer_mask = self.hard_concrete_for_layer() # (1,) 459 | self.out_proj.weight.data *= layer_mask 460 | self.out_proj.bias.data *= layer_mask 461 | if layer_mask == 0: 462 | new_config["use_attention"] = False 463 | self.hard_concrete_for_layer = None 464 | 465 | if self.hard_concrete_for_heads is not None: 466 | assert not self.hard_concrete_for_heads.training 467 | head_mask = self.hard_concrete_for_heads() # (num_heads,) 468 | new_config["num_heads"] = len(head_mask.nonzero()) 469 | if new_config["num_heads"] == 0: 470 | new_config["use_attention"] = False 471 | else: 472 | full_mask = head_mask.repeat_interleave(self.head_dim) 473 | full_index = full_mask.nonzero().squeeze(-1) # 1D 474 | 475 | prune_linear_layer(self.k_proj, full_index, "output") 476 | prune_linear_layer(self.v_proj, full_index, "output") 477 | prune_linear_layer(self.q_proj, full_index, "output") 478 | 479 | self.out_proj.weight.data *= full_mask 480 | prune_linear_layer(self.out_proj, full_index, "input") 481 | self.hard_concrete_for_heads = None 482 | 483 | return new_config 484 | 485 | 486 | class WavLMSelfAttention(SelfAttention): 487 | """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`. 488 | 489 | Args: 490 | embed_dim (int): Total dimension of the model. 
491 | num_heads (int): The number of heads. 492 | dropout (float, optional): Dropout probability on attn_output_weights. (Default: to ``0.0``) 493 | bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``) 494 | has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding. 495 | Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``) 496 | num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``) 497 | max_distance (int, optional): Naximum distance for relative position embedding. (Default: ``128``) 498 | gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``False``) 499 | """ 500 | 501 | def __init__( 502 | self, 503 | embed_dim: int, 504 | total_num_heads: int, 505 | remaining_heads: Optional[List[int]] = None, 506 | dropout: float = 0.0, 507 | bias: bool = True, 508 | has_relative_attention_bias: bool = False, 509 | num_buckets: int = 32, 510 | max_distance: int = 128, 511 | gru_rel_pos: bool = True, 512 | prune_heads: bool = False, 513 | prune_layer: bool = False, 514 | ): 515 | self.total_num_heads = total_num_heads 516 | if remaining_heads is None: 517 | self.remaining_heads = list(range(total_num_heads)) 518 | else: 519 | self.remaining_heads = remaining_heads # list of indices 520 | 521 | self.head_dim = embed_dim // total_num_heads 522 | 523 | super().__init__(embed_dim, len(self.remaining_heads), self.head_dim, dropout, prune_heads, prune_layer) 524 | 525 | self.has_relative_attention_bias = has_relative_attention_bias 526 | self.num_buckets = num_buckets 527 | self.max_distance = max_distance 528 | 529 | if has_relative_attention_bias: 530 | self.rel_attn_embed = nn.Embedding(num_buckets, total_num_heads) 531 | else: 532 | self.rel_attn_embed = None 533 | 534 | # override linear layers to customize bias 535 | self.k_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) 536 | self.v_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) 537 | self.q_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) 538 | self.out_proj = nn.Linear(len(self.remaining_heads) * self.head_dim, embed_dim, bias=bias) 539 | 540 | self.gru_rel_pos = gru_rel_pos 541 | if self.gru_rel_pos: 542 | self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8) 543 | self.gru_rel_pos_const = nn.Parameter(torch.ones(1, total_num_heads, 1, 1)) 544 | self.has_position_bias = True 545 | 546 | def compute_bias(self, query_length: int, key_length: int) -> Tensor: 547 | """Compute relative position embeddings for WavLM model. 548 | Args: 549 | query_length (int): Query position can take values between 0 and ``query_length - 1``. 550 | key_length (int): Key position can take values between 0 and ``key_length - 1``. 
551 | Returns: 552 | Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings 553 | """ 554 | context_position = torch.arange(query_length, dtype=torch.long)[:, None] 555 | memory_position = torch.arange(key_length, dtype=torch.long)[None, :] 556 | relative_position = memory_position - context_position # Shape (query_length, key_length) 557 | relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True) 558 | relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device) 559 | values = self.rel_attn_embed(relative_position_bucket) # Shape (query_length, key_length, num_heads) 560 | values = values.permute([2, 0, 1]) 561 | return values 562 | 563 | def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True): 564 | """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM 565 | paper :cite:`chen2022wavlm`. 566 | Args: 567 | relative_positions (Tensor): Relative offsets between query and key positions, 568 | of shape ``(query_length, key_length)``. 569 | bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the resulting 570 | matrix. If ``False``, the elements above the diagonal (i.e. with negative relative offsets) will be set 571 | to zero. (Default ``True``) 572 | Returns: 573 | Tensor of shape ``(query_length, key_length)`` filled bucketed values of with relative positions. 574 | """ 575 | num_buckets = self.num_buckets 576 | max_distance = self.max_distance 577 | # Shape (query_length, key_length) 578 | relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long) 579 | 580 | if bidirectional: 581 | num_buckets = num_buckets // 2 582 | relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets 583 | relative_positions = torch.abs(relative_positions) 584 | else: 585 | relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) 586 | 587 | max_exact = num_buckets // 2 588 | is_small = relative_positions < max_exact 589 | 590 | relative_postion_if_large = max_exact + ( 591 | torch.log(relative_positions.float() / max_exact) 592 | / math.log(max_distance / max_exact) 593 | * (num_buckets - max_exact) 594 | ).to(torch.long) 595 | relative_postion_if_large = torch.min( 596 | relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) 597 | ) 598 | 599 | relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) 600 | return relative_buckets 601 | 602 | def forward( 603 | self, 604 | query: Tensor, 605 | attention_mask: Optional[Tensor] = None, 606 | position_bias: Optional[Tensor] = None, 607 | key_padding_mask: Optional[Tensor] = None, 608 | ) -> Tuple[Tensor, Optional[Tensor]]: 609 | """ 610 | Args: 611 | query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``. 612 | key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape 613 | `(batch, src_len)`, where padding elements are indicated by 1s. (Default: ``None``) 614 | attn_mask: Needs to be ``None``. The argument exists for compatibility with 615 | ``EncoderLayer``. (Default: ``None``) 616 | position_bias (Tensor or None, optional): Position bias of shape 617 | ``(batch_size * num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be 618 | generated in the first layer and then passed from each encoder layer to the next one. 
619 | (Default: ``None``) 620 | Returns: 621 | attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``. 622 | position_bias (Tensor or None): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``. 623 | """ 624 | bsz, seq_len, embed_dim = query.size() 625 | assert embed_dim == self.embed_dim 626 | assert key_padding_mask is None 627 | 628 | # only for the first layer 629 | if self.rel_attn_embed is not None and position_bias is None: 630 | position_bias = self.compute_bias(seq_len, seq_len) 631 | position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.total_num_heads, seq_len, seq_len) 632 | 633 | attn_mask_rel_pos: Optional[Tensor] = None 634 | if position_bias is not None: 635 | attn_mask_rel_pos = position_bias 636 | if self.gru_rel_pos: # Apply gating on relative position bias 637 | query_layer = query.view(bsz, seq_len, self.total_num_heads, -1) 638 | query_layer = query_layer.permute(0, 2, 1, 3) 639 | 640 | gate_a, gate_b = torch.sigmoid( 641 | self.gru_rel_pos_linear(query_layer).view(bsz, self.total_num_heads, seq_len, 2, 4).sum(-1, keepdim=False) 642 | ).chunk(2, dim=-1) 643 | gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0 644 | attn_mask_rel_pos = gate_a_1.view(bsz * self.total_num_heads, -1, 1) * position_bias 645 | 646 | attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len)) 647 | attn_mask_rel_pos = attn_mask_rel_pos.reshape(bsz, self.total_num_heads, seq_len, seq_len)[:, self.remaining_heads, :, :] 648 | 649 | attn_mask = attn_mask_rel_pos 650 | if attention_mask is not None: 651 | attn_mask = attn_mask + attention_mask 652 | if key_padding_mask is not None: 653 | attn_mask = attn_mask.masked_fill( 654 | key_padding_mask.reshape(bsz, 1, 1, seq_len), 655 | float("-inf") 656 | ) 657 | attn_output, _ = super().forward(query, attention_mask=attn_mask) 658 | 659 | return attn_output, position_bias 660 | 661 | def prune(self): 662 | new_config = { 663 | "use_attention": True, 664 | "remaining_heads": self.remaining_heads, 665 | } 666 | if self.hard_concrete_for_layer is not None: 667 | assert not self.hard_concrete_for_layer.training 668 | layer_mask = self.hard_concrete_for_layer() # (1,) 669 | self.out_proj.weight.data *= layer_mask 670 | self.out_proj.bias.data *= layer_mask 671 | if layer_mask == 0: 672 | new_config["use_attention"] = False 673 | self.hard_concrete_for_layer = None 674 | 675 | if self.hard_concrete_for_heads is not None: 676 | assert not self.hard_concrete_for_heads.training 677 | head_mask = self.hard_concrete_for_heads() # (num_heads,) 678 | new_config["remaining_heads"] = head_mask.nonzero().squeeze(-1).tolist() 679 | if len(new_config["remaining_heads"]) == 0: 680 | new_config["use_attention"] = False 681 | else: 682 | full_mask = head_mask.repeat_interleave(self.head_dim) 683 | full_index = full_mask.nonzero().squeeze(-1) # 1D 684 | 685 | prune_linear_layer(self.k_proj, full_index, "output") 686 | prune_linear_layer(self.v_proj, full_index, "output") 687 | prune_linear_layer(self.q_proj, full_index, "output") 688 | 689 | self.out_proj.weight.data *= full_mask 690 | prune_linear_layer(self.out_proj, full_index, "input") 691 | self.hard_concrete_for_heads = None 692 | 693 | return new_config 694 | 695 | 696 | class FeedForward(Module): 697 | """Layer that follows attention layer in encoder layer.""" 698 | 699 | def __init__( 700 | self, 701 | io_features: int, 702 | intermediate_features: int, 703 | intermediate_dropout: float, 704 | output_dropout: float, 705 
| prune_intermediate: bool = False, 706 | prune_layer: bool = False, 707 | ): 708 | super().__init__() 709 | self.intermediate_dense = nn.Linear(io_features, intermediate_features) 710 | self.intermediate_dropout = nn.Dropout(intermediate_dropout) 711 | self.output_dense = nn.Linear(intermediate_features, io_features) 712 | self.output_dropout = nn.Dropout(output_dropout) 713 | 714 | if prune_intermediate: 715 | self.hard_concrete_for_intermediate = HardConcrete( 716 | n_in=intermediate_features, init_mean=0.5 717 | ) 718 | else: 719 | self.hard_concrete_for_intermediate = None 720 | 721 | if prune_layer: 722 | self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) 723 | else: 724 | self.hard_concrete_for_layer = None 725 | 726 | def forward(self, x): 727 | """ 728 | Args: 729 | x (Tensor): shape: `(batch, sequence_length, io_features)` 730 | Returns: 731 | x (Tensor): shape: `(batch, sequence_length, io_features)` 732 | """ 733 | x = self.intermediate_dense(x) 734 | x = torch.nn.functional.gelu(x) 735 | x = self.intermediate_dropout(x) 736 | 737 | if self.hard_concrete_for_intermediate is not None: 738 | intermediate_mask = self.hard_concrete_for_intermediate() # (intermediate_features,) 739 | x = x * intermediate_mask 740 | 741 | x = self.output_dense(x) 742 | x = self.output_dropout(x) 743 | 744 | if self.hard_concrete_for_layer is not None: 745 | layer_mask = self.hard_concrete_for_layer() # (1,) 746 | x = x * layer_mask 747 | 748 | return x 749 | 750 | def get_num_params(self): 751 | io_features = self.intermediate_dense.in_features 752 | if self.hard_concrete_for_intermediate is not None: 753 | intermediate_features = self.hard_concrete_for_intermediate.l0_norm() 754 | else: 755 | intermediate_features = self.intermediate_dense.out_features 756 | num_params = (io_features + 1) * intermediate_features + (intermediate_features + 1) * io_features 757 | 758 | if self.hard_concrete_for_layer is not None: 759 | num_params *= self.hard_concrete_for_layer.l0_norm() 760 | 761 | return num_params 762 | 763 | def prune(self): 764 | new_config = { 765 | "use_feed_forward": True, 766 | "ff_interm_features": self.intermediate_dense.out_features 767 | } 768 | if self.hard_concrete_for_layer is not None: 769 | assert not self.hard_concrete_for_layer.training 770 | layer_mask = self.hard_concrete_for_layer() 771 | self.output_dense.weight.data *= layer_mask 772 | self.output_dense.bias.data *= layer_mask 773 | if layer_mask == 0: 774 | new_config["use_feed_forward"] = False 775 | self.hard_concrete_for_layer = None 776 | 777 | if self.hard_concrete_for_intermediate is not None: 778 | assert not self.hard_concrete_for_intermediate.training 779 | interm_mask = self.hard_concrete_for_intermediate() 780 | interm_index = interm_mask.nonzero().squeeze(-1) # NOTE: must specify dim=-1 781 | new_config["ff_interm_features"] = len(interm_index) 782 | if new_config["ff_interm_features"] == 0: 783 | new_config["use_feed_forward"] = False 784 | else: 785 | prune_linear_layer(self.intermediate_dense, interm_index, "output") 786 | 787 | self.output_dense.weight.data *= interm_mask 788 | prune_linear_layer(self.output_dense, interm_index, "input") 789 | self.hard_concrete_for_intermediate = None 790 | 791 | return new_config 792 | 793 | 794 | class EncoderLayer(Module): 795 | """A layer unit in encoder. 
Combines multihead self attention and feed forward.""" 796 | 797 | def __init__( 798 | self, 799 | attention: Optional[Module], # can be None if the entire layer is pruned 800 | dropout: float, 801 | layer_norm_first: bool, 802 | feed_forward: Optional[Module], # can be None if the entire layer is pruned 803 | embed_dim: int, 804 | ): 805 | super().__init__() 806 | self.attention = attention 807 | self.dropout = nn.Dropout(dropout) 808 | self.layer_norm = nn.LayerNorm(embed_dim) 809 | self.layer_norm_first = layer_norm_first 810 | self.feed_forward = feed_forward 811 | self.final_layer_norm = nn.LayerNorm(embed_dim) 812 | self.embed_dim = embed_dim 813 | 814 | def forward( 815 | self, 816 | x: Tensor, 817 | attention_mask: Optional[Tensor] = None, 818 | position_bias: Optional[Tensor] = None, 819 | key_padding_mask: Optional[Tensor] = None, 820 | ) -> Tuple[Tensor, Optional[Tensor]]: 821 | """ 822 | Args: 823 | x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``. 824 | attention_mask (Tensor or ``None``, optional): attention mask 825 | of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``) 826 | position_bias (Tensor or ``None``, optional): position bias of shape 827 | ``(batch_size * num_heads, src_len, src_len)``. 828 | Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``) 829 | key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``. 830 | Only used for WavLM model, ignored otherwise. (Default: ``None``) 831 | Returns: 832 | (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for WaLM model, 833 | ``None`` otherwise. 834 | """ 835 | if self.attention is not None: 836 | residual = x 837 | 838 | if self.layer_norm_first: 839 | x = self.layer_norm(x) 840 | 841 | x, position_bias = self.attention( 842 | x, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask 843 | ) 844 | 845 | x = self.dropout(x) 846 | x = residual + x 847 | 848 | if self.layer_norm_first: 849 | if self.feed_forward is not None: 850 | x = x + self.feed_forward(self.final_layer_norm(x)) 851 | else: 852 | # NOTE: for post norm, the layer norms should always be applied even if the layers are pruned. 
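            # In equations (dropout omitted): the pre-norm path above computes
            #   x = x + Attn(LN(x));  x = x + FF(final_LN(x))
            # whereas this post-norm path computes
            #   x = LN(x + Attn(x));  x = final_LN(x + FF(x))
            # so both LayerNorms still run even when attention or feed-forward
            # has been pruned away, which is what the NOTE above guarantees.
            # HuBERT Base uses this post-norm variant (layer_norm_first=False).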
853 | x = self.layer_norm(x) 854 | if self.feed_forward is not None: 855 | x = x + self.feed_forward(x) 856 | x = self.final_layer_norm(x) 857 | return x, position_bias 858 | 859 | def get_num_params(self): 860 | num_params = self.embed_dim * 2 * 2 # two layer norms 861 | if self.attention is not None: 862 | num_params += self.attention.get_num_params() 863 | if self.feed_forward is not None: 864 | num_params += self.feed_forward.get_num_params() 865 | return num_params 866 | 867 | 868 | class Transformer(Module): 869 | def __init__( 870 | self, 871 | pos_conv_embed: Module, 872 | dropout: float, 873 | layers: Module, 874 | layer_norm_first: bool, 875 | layer_drop: float, 876 | ): 877 | super().__init__() 878 | self.pos_conv_embed = pos_conv_embed 879 | self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim) 880 | self.layer_norm_first = layer_norm_first 881 | self.layer_drop = layer_drop 882 | self.dropout = nn.Dropout(dropout) 883 | self.layers = layers 884 | 885 | def _preprocess(self, x: Tensor): 886 | x = x + self.pos_conv_embed(x) 887 | 888 | if self.layer_norm_first: 889 | x = self.layer_norm(x) 890 | 891 | x = self.dropout(x) 892 | return x 893 | 894 | def forward( 895 | self, 896 | x: Tensor, 897 | attention_mask: Optional[Tensor] = None, 898 | position_bias: Optional[Tensor] = None, 899 | ) -> Tensor: 900 | x = self._preprocess(x) 901 | for layer in self.layers: 902 | if not (self.training and torch.rand(1).item() <= self.layer_drop): 903 | x, position_bias = layer(x, attention_mask, position_bias=position_bias) 904 | 905 | if not self.layer_norm_first: 906 | x = self.layer_norm(x) 907 | return x 908 | 909 | def get_intermediate_outputs( 910 | self, 911 | x: Tensor, 912 | attention_mask: Optional[Tensor] = None, 913 | num_layers: Optional[int] = None, 914 | position_bias: Optional[Tensor] = None, 915 | ) -> List[Tensor]: 916 | if num_layers is not None: 917 | if not 0 < num_layers <= len(self.layers): 918 | raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]") 919 | 920 | ret: List[Tensor] = [] 921 | x = self._preprocess(x) 922 | for layer in self.layers: 923 | x, position_bias = layer(x, attention_mask, position_bias=position_bias) 924 | ret.append(x) 925 | if num_layers is not None and len(ret) >= num_layers: 926 | return ret 927 | return ret 928 | 929 | def get_num_params(self): 930 | # pos_conv_embed and layer_norm 931 | num_params = sum(p.numel() for p in self.pos_conv_embed.parameters()) + self.pos_conv_embed.embed_dim * 2 932 | for layer in self.layers: 933 | num_params += layer.get_num_params() 934 | return num_params 935 | 936 | def prune(self): 937 | new_config = defaultdict(list) 938 | for layer in self.layers: 939 | attention_config = layer.attention.prune() 940 | new_config["use_attention"].append(attention_config["use_attention"]) 941 | if "remaining_heads" in attention_config: 942 | new_config["remaining_heads"].append(attention_config["remaining_heads"]) 943 | else: 944 | new_config["num_heads"].append(attention_config["num_heads"]) 945 | 946 | if not attention_config["use_attention"]: 947 | layer.attention = None 948 | 949 | ff_config = layer.feed_forward.prune() 950 | new_config["use_feed_forward"].append(ff_config["use_feed_forward"]) 951 | new_config["ff_interm_features"].append(ff_config["ff_interm_features"]) 952 | if not ff_config["use_feed_forward"]: 953 | layer.feed_forward = None 954 | 955 | return new_config 956 | 957 | 958 | class Encoder(Module): 959 | def __init__( 960 | self, 961 | feature_projection: Module, 962 | transformer: 
Module, 963 | ): 964 | super().__init__() 965 | self.feature_projection = feature_projection 966 | self.transformer = transformer 967 | 968 | def _preprocess( 969 | self, 970 | features: Tensor, 971 | lengths: Optional[Tensor] = None, 972 | ) -> Tuple[Tensor, Optional[Tensor]]: 973 | x = self.feature_projection(features) 974 | 975 | mask: Optional[Tensor] = None 976 | if lengths is not None: 977 | batch_size, max_len, _ = x.shape 978 | # create mask for padded elements and zero-out them 979 | mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] 980 | x[mask] = 0.0 981 | # extend the mask to attention shape and set weight 982 | mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype) 983 | mask = mask.expand(batch_size, 1, max_len, max_len) 984 | return x, mask 985 | 986 | def forward( 987 | self, 988 | features: Tensor, 989 | lengths: Optional[Tensor] = None, 990 | ) -> Tensor: 991 | x, mask = self._preprocess(features, lengths) 992 | x = self.transformer(x, attention_mask=mask) 993 | return x 994 | 995 | def extract_features( 996 | self, 997 | features: Tensor, 998 | lengths: Optional[Tensor] = None, 999 | num_layers: Optional[int] = None, 1000 | ) -> List[Tensor]: 1001 | x, masks = self._preprocess(features, lengths) 1002 | interm = self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers) 1003 | return [x] + interm 1004 | 1005 | def get_num_params(self, in_features): 1006 | """Calculate the current model size.""" 1007 | feature_projection_size = self.feature_projection.get_num_params(in_features) 1008 | transformer_size = self.transformer.get_num_params() 1009 | return feature_projection_size + transformer_size 1010 | 1011 | def prune(self, conv_out_index): 1012 | """In-place pruning of submodules.""" 1013 | prune_layer_norm(self.feature_projection.layer_norm, conv_out_index) 1014 | prune_linear_layer(self.feature_projection.projection, conv_out_index, "input") 1015 | transformer_config = self.transformer.prune() 1016 | return transformer_config 1017 | 1018 | 1019 | ################################################################################ 1020 | def _get_feature_extractor( 1021 | norm_mode: str, 1022 | shapes: List[Tuple[int, int, int]], 1023 | bias: bool, 1024 | prune_conv_channels: bool = False, 1025 | ) -> FeatureExtractor: 1026 | """ 1027 | Args: 1028 | norm_mode (str): 1029 | Either "group_norm" or "layer_norm". 1030 | If "group_norm", then a single normalization is applied 1031 | in the first convolution block. Otherwise, all the convolution 1032 | blocks will have layer normalization. 1033 | This option corresponds to "extractor_mode" from fairseq. 1034 | Expected values are "group_norm" for Base arch, and 1035 | "layer_norm" for Large arch. 1036 | shapes (list of tuple of int): 1037 | Configuration of convolution layers. List of convolution configuration, 1038 | i.e. ``[(output_channel, kernel_size, stride), ...]`` 1039 | This option corresponds to "conv_feature_layers" from fairseq. 1040 | Expected values are 1041 | ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2`` 1042 | for all the architectures. 1043 | bias (bool): 1044 | Whether to include bias term to each convolution operation. 1045 | This option corresponds to "conv_bias" from fairseq. 1046 | Expected values are False for Base arch, and True for Large arch. 
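    Example:
        An illustrative sketch of building the Base-style extractor and running a batch of
        raw waveforms through it (a 16000-sample input yields 49 frames):

        .. code-block:: python

            shapes = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
            extractor = _get_feature_extractor("group_norm", shapes, bias=False)
            waveforms = torch.randn(2, 16000)              # (batch, time)
            lengths = torch.tensor([16000, 8000])
            features, out_lengths = extractor(waveforms, lengths)
            # features: (batch, frame, channel) = (2, 49, 512)
            # out_lengths: number of valid frames for each utterance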
1047 | 1048 | See Also: 1049 | * Original implementation 1050 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733 1051 | * "extractor_mode" 1052 | - Def and base: 1053 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45 1054 | - Large: 1055 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52 1056 | * "conv_feature_layers" 1057 | - Def, base and large: 1058 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100 1059 | * "conv_bias" 1060 | - Def and base: 1061 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103 1062 | - Large: 1063 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61 1064 | """ 1065 | if norm_mode not in ["group_norm", "layer_norm"]: 1066 | raise ValueError("Invalid norm mode") 1067 | blocks = [] 1068 | in_channels = 1 1069 | for i, (out_channels, kernel_size, stride) in enumerate(shapes): 1070 | normalization = None 1071 | if norm_mode == "group_norm" and i == 0: 1072 | normalization = nn.GroupNorm( 1073 | num_groups=out_channels, 1074 | num_channels=out_channels, 1075 | affine=True, 1076 | ) 1077 | elif norm_mode == "layer_norm": 1078 | normalization = LayerNorm( 1079 | normalized_shape=out_channels, 1080 | elementwise_affine=True, 1081 | ) 1082 | blocks.append( 1083 | ConvLayerBlock( 1084 | in_channels=in_channels, 1085 | out_channels=out_channels, 1086 | kernel_size=kernel_size, 1087 | stride=stride, 1088 | bias=bias, 1089 | layer_norm=normalization, 1090 | prune_conv_channels=prune_conv_channels, 1091 | ) 1092 | ) 1093 | in_channels = out_channels 1094 | return FeatureExtractor(nn.ModuleList(blocks)) 1095 | 1096 | 1097 | def _get_encoder( 1098 | in_features: int, 1099 | embed_dim: int, 1100 | dropout_input: float, 1101 | pos_conv_kernel: int, 1102 | pos_conv_groups: int, 1103 | num_layers: int, 1104 | use_attention: List[bool], 1105 | use_feed_forward: List[bool], 1106 | num_heads: List[int], 1107 | head_dim: int, 1108 | attention_dropout: float, 1109 | ff_interm_features: List[int], 1110 | ff_interm_dropout: float, 1111 | dropout: float, 1112 | layer_norm_first: bool, 1113 | layer_drop: float, 1114 | prune_attention_heads: bool = False, 1115 | prune_attention_layer: bool = False, 1116 | prune_feed_forward_intermediate: bool = False, 1117 | prune_feed_forward_layer: bool = False, 1118 | ) -> Encoder: 1119 | """ 1120 | Args: 1121 | in_features (int): The number of input features. 1122 | embed_dim (int): 1123 | The dimension of embedding. 1124 | This option corresponds to "encoder_embed_dim" from fairseq. 1125 | Expected values are 768 for Base arch, and 1024 for Large arch. 1126 | dropout_input (float): 1127 | The dropout probability applied after the input feature is projected 1128 | to ``embed_dim``. 1129 | This option corresponds to "dropout_input" from fairseq. 1130 | Expected values are 0.1 for both Base and Large arch. 1131 | pos_conv_kernel (int): 1132 | The kernel size of convolutional positional embeddings. 1133 | This option corresponds to "conv_pos" from fairseq. 1134 | Expected values are 128 for both Base and Large arch. 
1135 |         pos_conv_groups (int):
1136 |             The number of groups of convolutional positional embeddings.
1137 |             This option corresponds to "conv_pos_groups" from fairseq.
1138 |             Expected values are 16 for both Base and Large arch.
1139 |         num_layers (int):
1140 |             The number of self attention layers in transformer block.
1141 |             This option corresponds to "encoder_layers" from fairseq.
1142 |             Expected values are 12 for Base and 24 for Large arch.
1143 |         num_heads (List[int]):
1144 |             The number of heads in self attention layers, given per layer.
1145 |             This option corresponds to "encoder_attention_heads" from fairseq.
1146 |             Expected values are 12 for Base and 16 for Large arch.
1147 |         attention_dropout (float):
1148 |             The dropout probability applied after softmax in self-attention layer.
1149 |             This option corresponds to "attention_dropout" from fairseq.
1150 |             Expected values are 0.1 for Base and 0.0 for Large arch.
1151 |         ff_interm_features (List[int]):
1152 |             The dimension of hidden features in feed forward layer, given per layer.
1153 |             This option corresponds to "encoder_ffn_embed_dim" from fairseq.
1154 |             Expected values are 3072 for Base and 4096 for Large arch.
1155 |         ff_interm_dropout (float):
1156 |             The dropout probability applied in feedforward layer.
1157 |             This option corresponds to "activation_dropout" from fairseq.
1158 |             Expected values are 0.1 for both Base and Large arch.
1159 |         dropout (float):
1160 |             The dropout probability applied at the end of feed forward layer.
1161 |             This option corresponds to "dropout" from fairseq.
1162 |             Expected values are 0.1 for Base and 0.0 for Large arch.
1163 |         layer_norm_first (bool):
1164 |             Control the order of layer norm in transformer layer and each encoder layer.
1165 |             If True, in transformer layer, layer norm is applied before features are fed
1166 |             to encoder layers. In encoder layer, two layer norms are applied before and after
1167 |             self attention.
1168 |             If False, in transformer layer, layer norm is applied after features are fed
1169 |             to encoder layers. In encoder layer, two layer norms are applied after self
1170 |             attention, before and after feed forward.
1171 |             This option corresponds to "layer_norm_first" from fairseq.
1172 |             Expected values are False for Base and True for Large arch.
1173 |         layer_drop (float):
1174 |             Probability to drop each encoder layer during training.
1175 |             This option corresponds to "layerdrop" from fairseq.
1176 |             Expected values are 0.1 for both Base and Large arch.
1177 | 
1178 |     See Also:
1179 |         * "encoder_embed_dim"
1180 |             - Def and base
1181 |             https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51
1182 |             - Large
1183 |             https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64
1184 |         * "dropout_input"
1185 |             - Def, base and large
1186 |             https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78
1187 |         * "conv_pos"
1188 |             - Def, base and large
1189 |             NOTE: The description is wrong.
1190 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207 1191 | - Usage 1192 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756 1193 | * "conv_pos_groups" 1194 | - Def, base and large 1195 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211 1196 | * "encoder_layers" 1197 | - Def and base 1198 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48 1199 | - Large 1200 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63 1201 | * "encoder_attention_heads" 1202 | - Def and base 1203 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57 1204 | - Large 1205 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66 1206 | * "attention_dropout" 1207 | - Def and base 1208 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68 1209 | - Large 1210 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60 1211 | * "encoder_ffn_embed_dim" 1212 | - Def and base 1213 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54 1214 | - Large 1215 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65 1216 | * "activation_dropout" 1217 | - Def 1218 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71 1219 | - Base 1220 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55 1221 | - Large 1222 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55 1223 | * "dropout" 1224 | - Def and base 1225 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65 1226 | - Large 1227 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59 1228 | * "layer_norm_first" 1229 | - Def and base 1230 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93 1231 | - Large 1232 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53 1233 | * "layerdrop" 1234 | - Def 1235 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74 1236 | - Base 1237 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54 1238 | - Large 1239 | https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54 1240 | 
""" 1241 | feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) 1242 | pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) 1243 | 1244 | # Original impl 1245 | # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 1246 | encoder_layers = nn.ModuleList() 1247 | for idx in range(num_layers): 1248 | if use_attention[idx]: 1249 | attention = SelfAttention( 1250 | embed_dim=embed_dim, 1251 | num_heads=num_heads[idx], 1252 | head_dim=head_dim, 1253 | dropout=attention_dropout, 1254 | prune_heads=prune_attention_heads, 1255 | prune_layer=prune_attention_layer, 1256 | ) 1257 | else: 1258 | attention = None 1259 | if use_feed_forward[idx]: 1260 | feed_forward = FeedForward( 1261 | io_features=embed_dim, 1262 | intermediate_features=ff_interm_features[idx], 1263 | intermediate_dropout=ff_interm_dropout, 1264 | output_dropout=dropout, 1265 | prune_intermediate=prune_feed_forward_intermediate, 1266 | prune_layer=prune_feed_forward_layer, 1267 | ) 1268 | else: 1269 | feed_forward = None 1270 | encoder_layers.append( 1271 | EncoderLayer( 1272 | attention=attention, 1273 | dropout=dropout, 1274 | layer_norm_first=layer_norm_first, 1275 | feed_forward=feed_forward, 1276 | embed_dim=embed_dim, 1277 | ) 1278 | ) 1279 | transformer = Transformer( 1280 | pos_conv_embed=pos_conv, 1281 | dropout=dropout, 1282 | layers=encoder_layers, 1283 | layer_norm_first=not layer_norm_first, 1284 | layer_drop=layer_drop, 1285 | ) 1286 | return Encoder(feature_projection, transformer) 1287 | 1288 | 1289 | def _get_wavlm_encoder( 1290 | in_features: int, 1291 | embed_dim: int, 1292 | dropout_input: float, 1293 | pos_conv_kernel: int, 1294 | pos_conv_groups: int, 1295 | num_layers: int, 1296 | use_attention: List[bool], 1297 | use_feed_forward: List[bool], 1298 | total_num_heads: List[int], 1299 | remaining_heads: List[List[int]], 1300 | num_buckets: int, 1301 | max_distance: int, 1302 | attention_dropout: float, 1303 | ff_interm_features: List[int], 1304 | ff_interm_dropout: float, 1305 | dropout: float, 1306 | layer_norm_first: bool, 1307 | layer_drop: float, 1308 | prune_attention_heads: bool = False, 1309 | prune_attention_layer: bool = False, 1310 | prune_feed_forward_intermediate: bool = False, 1311 | prune_feed_forward_layer: bool = False, 1312 | ) -> Encoder: 1313 | """ 1314 | Construct encoder for WavLM model :cite:`chen2022wavlm`. The structure of the encoder and most of the argments are 1315 | the same as in :py:func:`_get_encoder` so refer there for documentation. The only difference from Wav2Vec2 encoder 1316 | is usage of `WavLMSelfAttention` instead of `SelfAttention` and two additional parameters: `num_buckets` and 1317 | `max_distance`. 1318 | Args: 1319 | in_features (int): See :py:func:`_get_encoder`. 1320 | embed_dim (int): See :py:func:`_get_encoder`. 1321 | dropout_input (float): See :py:func:`_get_encoder`. 1322 | pos_conv_kernel (int): See :py:func:`_get_encoder`. 1323 | pos_conv_groups (int): See :py:func:`_get_encoder`. 1324 | num_layers (int): See :py:func:`_get_encoder`. 1325 | num_heads (int): See :py:func:`_get_encoder`. 1326 | num_buckets (int): Number of buckets for relative position embedding. 1327 | max_distance (int): Maximum distance for relative position embedding. 1328 | attention_dropout (float): See :py:func:`_get_encoder`. 1329 | ff_interm_features (int): See :py:func:`_get_encoder`. 1330 | ff_interm_dropout (float): See :py:func:`_get_encoder`. 
1331 | dropout (float): See :py:func:`_get_encoder`. 1332 | layer_norm_first (bool): See :py:func:`_get_encoder`. 1333 | layer_drop (float): See :py:func:`_get_encoder`. 1334 | 1335 | """ 1336 | feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) 1337 | pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) 1338 | 1339 | # Original impl 1340 | # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 1341 | encoder_layers = nn.ModuleList() 1342 | for i in range(num_layers): 1343 | if use_attention[i]: 1344 | attention = WavLMSelfAttention( 1345 | embed_dim=embed_dim, 1346 | total_num_heads=total_num_heads[i], 1347 | remaining_heads=remaining_heads[i], 1348 | dropout=attention_dropout, 1349 | has_relative_attention_bias=(i == 0), # Position embedding is only necessary in the first layer. 1350 | num_buckets=num_buckets, 1351 | max_distance=max_distance, 1352 | prune_heads=prune_attention_heads, 1353 | prune_layer=prune_attention_layer, 1354 | ) 1355 | else: 1356 | attention = None 1357 | if use_feed_forward[i]: 1358 | feed_forward = FeedForward( 1359 | io_features=embed_dim, 1360 | intermediate_features=ff_interm_features[i], 1361 | intermediate_dropout=ff_interm_dropout, 1362 | output_dropout=dropout, 1363 | prune_intermediate=prune_feed_forward_intermediate, 1364 | prune_layer=prune_feed_forward_layer, 1365 | ) 1366 | else: 1367 | feed_forward = None 1368 | encoder_layers.append( 1369 | EncoderLayer( 1370 | attention=attention, 1371 | dropout=dropout, 1372 | layer_norm_first=layer_norm_first, 1373 | feed_forward=feed_forward, 1374 | embed_dim=embed_dim, 1375 | ) 1376 | ) 1377 | transformer = Transformer( 1378 | pos_conv_embed=pos_conv, 1379 | dropout=dropout, 1380 | layers=encoder_layers, 1381 | layer_norm_first=not layer_norm_first, 1382 | layer_drop=layer_drop, 1383 | ) 1384 | return Encoder(feature_projection, transformer) 1385 | 1386 | 1387 | def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor: 1388 | """Generate the padding mask given the padded input and the lengths Tensors. 1389 | Args: 1390 | input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`. 1391 | lengths (Tensor): The lengths Tensor of dimension `[batch,]`. 1392 | 1393 | Returns: 1394 | (Tensor): The padding mask. 1395 | """ 1396 | batch_size, max_len, _ = input.shape 1397 | mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] 1398 | return mask 1399 | 1400 | 1401 | class GradMultiply(torch.autograd.Function): 1402 | @staticmethod 1403 | def forward(ctx, x, scale): 1404 | ctx.scale = scale 1405 | res = x.new(x) 1406 | return res 1407 | 1408 | @staticmethod 1409 | def backward(ctx, grad): 1410 | return grad * ctx.scale, None 1411 | -------------------------------------------------------------------------------- /wav2vec2/hardconcrete.py: -------------------------------------------------------------------------------- 1 | """Implementation of the hard Concrete distribution. 2 | 3 | Originally from: 4 | https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py 5 | 6 | """ 7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class HardConcrete(nn.Module): 15 | """A HarcConcrete module. 16 | Use this module to create a mask of size N, which you can 17 | then use to perform L0 regularization. 
18 | 
19 |     To obtain a mask, simply run a forward pass through the module
20 |     with no input data. The mask is sampled in training mode, and
21 |     fixed during evaluation mode, e.g.:
22 | 
23 |     >>> module = HardConcrete(n_in=100)
24 |     >>> mask = module()
25 |     >>> norm = module.l0_norm()
26 |     """
27 | 
28 |     def __init__(
29 |         self,
30 |         n_in: int,
31 |         init_mean: float = 0.5,
32 |         init_std: float = 0.01,
33 |         temperature: float = 2/3,  # from CoFi
34 |         stretch: float = 0.1,
35 |         eps: float = 1e-6
36 |     ) -> None:
37 |         """Initialize the HardConcrete module.
38 |         Parameters
39 |         ----------
40 |         n_in : int
41 |             The number of hard concrete variables in this mask.
42 |         init_mean : float, optional
43 |             Initial drop rate for hard concrete parameter,
44 |             by default 0.5.
45 |         init_std : float, optional
46 |             Used to initialize the hard concrete parameters,
47 |             by default 0.01.
48 |         temperature : float, optional
49 |             Temperature used to control the sharpness of the
50 |             distribution, by default 2/3 (following CoFi).
51 |         stretch : float, optional
52 |             Stretch the sampled value from [0, 1] to the interval
53 |             [-stretch, 1 + stretch], by default 0.1.
54 |         """
55 |         super().__init__()
56 | 
57 |         self.n_in = n_in
58 |         self.limit_l = -stretch
59 |         self.limit_r = 1.0 + stretch
60 |         self.log_alpha = nn.Parameter(torch.zeros(n_in))
61 |         self.beta = temperature
62 |         self.init_mean = init_mean
63 |         self.init_std = init_std
64 |         self.bias = -self.beta * math.log(-self.limit_l / self.limit_r)
65 | 
66 |         self.eps = eps
67 |         self.compiled_mask = None
68 |         self.reset_parameters()
69 | 
70 |     def reset_parameters(self):
71 |         """Reset the parameters of this module."""
72 |         self.compiled_mask = None
73 |         mean = math.log(1 - self.init_mean) - math.log(self.init_mean)
74 |         self.log_alpha.data.normal_(mean, self.init_std)
75 | 
76 |     def l0_norm(self) -> torch.Tensor:
77 |         """Compute the expected L0 norm of this mask.
78 |         Returns
79 |         -------
80 |         torch.Tensor
81 |             The expected L0 norm.
82 |         """
83 |         return (self.log_alpha + self.bias).sigmoid().sum()
84 | 
85 |     def forward(self) -> torch.Tensor:
86 |         """Sample a hard concrete mask.
87 |         Returns
88 |         -------
89 |         torch.Tensor
90 |             The sampled binary mask
91 |         """
92 |         if self.training:
93 |             # Reset the compiled mask
94 |             self.compiled_mask = None
95 |             # Sample mask dynamically
96 |             u = self.log_alpha.new(self.n_in).uniform_(self.eps, 1 - self.eps)
97 |             s = torch.sigmoid((torch.log(u / (1 - u)) + self.log_alpha) / self.beta)
98 |             s = s * (self.limit_r - self.limit_l) + self.limit_l
99 |             mask = s.clamp(min=0., max=1.)
100 | 
101 |         else:
102 |             # Compile new mask if not cached
103 |             if self.compiled_mask is None:
104 |                 # Get expected sparsity
105 |                 expected_num_zeros = self.n_in - self.l0_norm().item()
106 |                 num_zeros = round(expected_num_zeros)
107 |                 # Approximate expected value of each mask variable z;
108 |                 # We use an empirically validated magic number 0.8
109 |                 soft_mask = torch.sigmoid(self.log_alpha / self.beta * 0.8)
110 |                 # Prune small values to set to 0
111 |                 _, indices = torch.topk(soft_mask, k=num_zeros, largest=False)
112 |                 soft_mask[indices] = 0.
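                # Cache the deterministic mask: later eval-mode calls reuse it until
                # reset_parameters() or the next training-mode forward pass clears it.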
113 | self.compiled_mask = soft_mask 114 | mask = self.compiled_mask 115 | 116 | return mask 117 | 118 | def extra_repr(self) -> str: 119 | return str(self.n_in) 120 | 121 | def __repr__(self) -> str: 122 | return "{}({})".format(self.__class__.__name__, self.extra_repr()) 123 | -------------------------------------------------------------------------------- /wav2vec2/model.py: -------------------------------------------------------------------------------- 1 | """Speech SSL models supporting pruning. 2 | 3 | Originally from: 4 | https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/model.py 5 | 6 | """ 7 | 8 | import math 9 | from typing import List, Optional, Tuple 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import Tensor 14 | from torch.nn import Module 15 | 16 | from . import components 17 | 18 | 19 | class Wav2Vec2Model(Module): 20 | """Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`. 21 | 22 | Note: 23 | To build the model, please use one of the factory functions. 24 | :py:func:`wav2vec2_model`, :py:func:`wav2vec2_base`, :py:func:`wav2vec2_large`, 25 | :py:func:`wav2vec2_large_lv60k`, :py:func:`hubert_base`, :py:func:`hubert_large`, 26 | and :py:func:`hubert_xlarge`. 27 | 28 | See Also: 29 | * :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning) 30 | * :class:`torchaudio.pipelines.Wav2Vec2ASRBundle`: ASR pipelines with pretrained models. 31 | 32 | Args: 33 | feature_extractor (torch.nn.Module): 34 | Feature extractor that extracts feature vectors from raw audio Tensor. 35 | 36 | encoder (torch.nn.Module): 37 | Encoder that converts the audio features into the sequence of probability 38 | distribution (in negative log-likelihood) over labels. 39 | 40 | aux (torch.nn.Module or None, optional): 41 | Auxiliary module. If provided, the output from encoder is passed to this module. 42 | """ # noqa: E501 43 | 44 | def __init__( 45 | self, 46 | normalize_waveform: bool, 47 | feature_extractor: Module, 48 | encoder: Module, 49 | aux: Optional[Module] = None, 50 | ): 51 | super().__init__() 52 | self.normalize_waveform = normalize_waveform 53 | self.feature_extractor = feature_extractor 54 | self.encoder = encoder 55 | self.aux = aux 56 | 57 | @torch.jit.export 58 | def extract_features( 59 | self, 60 | waveforms: Tensor, 61 | lengths: Optional[Tensor] = None, 62 | num_layers: Optional[int] = None, 63 | ) -> Tuple[List[Tensor], Optional[Tensor]]: 64 | """Extract feature vectors from raw waveforms 65 | 66 | This returns the list of outputs from the intermediate layers of 67 | transformer block in encoder. 68 | 69 | Args: 70 | waveforms (Tensor): Audio tensor of shape `(batch, frames)`. 71 | lengths (Tensor or None, optional): 72 | Indicates the valid length of each audio in the batch. 73 | Shape: `(batch, )`. 74 | When the ``waveforms`` contains audios with different durations, 75 | by providing ``lengths`` argument, the model will compute 76 | the corresponding valid output lengths and apply proper mask in 77 | transformer attention layer. 78 | If ``None``, it is assumed that the entire audio waveform 79 | length is valid. 80 | num_layers (int or None, optional): 81 | If given, limit the number of intermediate layers to go through. 82 | Providing `1` will stop the computation after going through one 83 | intermediate layers. If not given, the outputs from all the 84 | intermediate layers are returned. 
85 | 86 | Returns: 87 | (List[Tensor], Optional[Tensor]): 88 | List of Tensors 89 | Features from requested layers. 90 | Each Tensor is of shape: `(batch, time frame, feature dimension)` 91 | Tensor or None 92 | If ``lengths`` argument was provided, a Tensor of shape `(batch, )` 93 | is returned. 94 | It indicates the valid length in time axis of each feature Tensor. 95 | """ 96 | if self.normalize_waveform: 97 | if lengths is not None: 98 | waveforms = [ 99 | F.layer_norm(wave[:length], (length,)) for wave, length in zip(waveforms, lengths) 100 | ] 101 | waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) 102 | else: 103 | waveforms = F.layer_norm(waveforms, waveforms.shape[-1:]) 104 | 105 | x, lengths = self.feature_extractor(waveforms, lengths) 106 | x = self.encoder.extract_features(x, lengths, num_layers) # (num_layers+1,), including the input 107 | return x, lengths 108 | 109 | def get_num_params(self): 110 | """Calculate the current size.""" 111 | feature_extractor_size, encoder_in_features = self.feature_extractor.get_num_params_and_final_out_channels() 112 | encoder_size = self.encoder.get_num_params(encoder_in_features) 113 | return feature_extractor_size + encoder_size 114 | 115 | def prune(self): 116 | self.eval() # must be in eval mode 117 | conv_config, conv_out_index = self.feature_extractor.prune() # [(output_channel, kernel_size, stride), ...] 118 | transformer_config = self.encoder.prune(conv_out_index) # NOTE: this is a defaultdict(list) 119 | use_attention = transformer_config["use_attention"] 120 | use_feed_forward = transformer_config["use_feed_forward"] 121 | num_heads = transformer_config["num_heads"] # can be [] 122 | remaining_heads = transformer_config["remaining_heads"] # can be [] 123 | ff_interm_features = transformer_config["ff_interm_features"] 124 | 125 | return conv_config, use_attention, use_feed_forward, num_heads, remaining_heads, ff_interm_features 126 | 127 | def forward( 128 | self, 129 | waveforms: Tensor, 130 | lengths: Optional[Tensor] = None, 131 | ) -> Tuple[Tensor, Optional[Tensor]]: 132 | """Compute the sequence of probability distribution over labels. 133 | 134 | Args: 135 | waveforms (Tensor): Audio tensor of shape `(batch, frames)`. 136 | lengths (Tensor or None, optional): 137 | Indicates the valid length of each audio in the batch. 138 | Shape: `(batch, )`. 139 | When the ``waveforms`` contains audios with different durations, 140 | by providing ``lengths`` argument, the model will compute 141 | the corresponding valid output lengths and apply proper mask in 142 | transformer attention layer. 143 | If ``None``, it is assumed that all the audio in ``waveforms`` 144 | have valid length. Default: ``None``. 145 | 146 | Returns: 147 | (Tensor, Optional[Tensor]): 148 | Tensor 149 | The sequences of probability distribution (in logit) over labels. 150 | Shape: `(batch, frames, num labels)`. 151 | Tensor or None 152 | If ``lengths`` argument was provided, a Tensor of shape `(batch, )` 153 | is returned. 154 | It indicates the valid length in time axis of the output Tensor. 
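        Example:
            A minimal sketch, assuming ``model`` is a :class:`Wav2Vec2Model` built with
            :py:func:`wav2vec2_model` below and ``aux_num_out=None`` (so the encoder
            output is returned directly):

            .. code-block:: python

                waveforms = torch.randn(2, 16000)        # (batch, frames)
                lengths = torch.tensor([16000, 12000])
                hidden, valid_lengths = model(waveforms, lengths)
                # hidden: (2, num_frames, encoder_embed_dim); valid_lengths: (2,)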
155 | """ 156 | if self.normalize_waveform: 157 | if lengths is not None: 158 | waveforms = [ 159 | F.layer_norm(wave[:length], (length,)) for wave, length in zip(waveforms, lengths) 160 | ] 161 | waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) 162 | else: 163 | waveforms = F.layer_norm(waveforms, waveforms.shape[-1:]) 164 | 165 | x, lengths = self.feature_extractor(waveforms, lengths) 166 | x = self.encoder(x, lengths) 167 | if self.aux is not None: 168 | x = self.aux(x) 169 | return x, lengths 170 | 171 | 172 | def wav2vec2_model(**configs) -> Wav2Vec2Model: 173 | """Wraps the original wav2vec2_model and wavlm_model.""" 174 | 175 | if "encoder_remaining_heads" in configs: 176 | return wavlm_model(**configs) 177 | 178 | return wav2vec2_model_original(**configs) 179 | 180 | 181 | def wav2vec2_model_original( 182 | extractor_mode: str, 183 | extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], 184 | extractor_conv_bias: bool, 185 | encoder_embed_dim: int, 186 | encoder_projection_dropout: float, 187 | encoder_pos_conv_kernel: int, 188 | encoder_pos_conv_groups: int, 189 | encoder_num_layers: int, 190 | encoder_use_attention: List[bool], 191 | encoder_use_feed_forward: List[bool], 192 | encoder_num_heads: List[int], 193 | encoder_head_dim: int, 194 | encoder_attention_dropout: float, 195 | encoder_ff_interm_features: List[int], 196 | encoder_ff_interm_dropout: float, 197 | encoder_dropout: float, 198 | encoder_layer_norm_first: bool, 199 | encoder_layer_drop: float, 200 | aux_num_out: Optional[int], 201 | normalize_waveform: bool, 202 | extractor_prune_conv_channels: bool = False, 203 | encoder_prune_attention_heads: bool = False, 204 | encoder_prune_attention_layer: bool = False, 205 | encoder_prune_feed_forward_intermediate: bool = False, 206 | encoder_prune_feed_forward_layer: bool = False, 207 | ) -> Wav2Vec2Model: 208 | """Builds custom :class:`~torchaudio.models.Wav2Vec2Model`. 209 | 210 | Note: 211 | The "feature extractor" below corresponds to 212 | `ConvFeatureExtractionModel `__ 213 | in the original ``fairseq`` implementation. 214 | This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0* 215 | :cite:`baevski2020wav2vec` paper. 216 | 217 | The "encoder" below corresponds to `TransformerEncoder `__, 218 | and this is referred as "Transformer" in the paper. 219 | 220 | Args: 221 | extractor_mode (str): Operation mode of feature extractor. 222 | Valid values are ``"group_norm"`` or ``"layer_norm"``. 223 | If ``"group_norm"``, then a single normalization is applied 224 | in the first convolution block. Otherwise, all the convolution 225 | blocks will have layer normalization. 226 | 227 | This option corresponds to ``extractor_mode`` from ``fairseq``. 228 | extractor_conv_layer_config (list of integer tuples or None): 229 | Configuration of convolution layers in feature extractor. 230 | List of convolution configuration, 231 | i.e. ``[(output_channel, kernel_size, stride), ...]`` 232 | 233 | If ``None`` is provided, then the following default value is used. 234 | 235 | .. code-block:: python 236 | 237 | [ 238 | (512, 10, 5), 239 | (512, 3, 2), 240 | (512, 3, 2), 241 | (512, 3, 2), 242 | (512, 3, 2), 243 | (512, 2, 2), 244 | (512, 2, 2), 245 | ] 246 | 247 | This option corresponds to ``conv_feature_layers`` from ``fairseq``. 248 | 249 | extractor_conv_bias (bool): 250 | Whether to include bias term to each convolution operation. 251 | 252 | This option corresponds to ``conv_bias`` from ``fairseq``. 
253 | 254 | encoder_embed_dim (int): 255 | The dimension of embedding in encoder. 256 | 257 | This option corresponds to ``encoder_embed_dim`` from ``fairseq``. 258 | 259 | encoder_projection_dropout (float): 260 | The dropout probability applied after the input feature is projected 261 | to ``encoder_embed_dim``. 262 | 263 | This option corresponds to ``dropout_input`` from ``fairseq``. 264 | 265 | encoder_pos_conv_kernel (int): 266 | The kernel size of convolutional positional embeddings. 267 | 268 | This option corresponds to ``conv_pos`` from ``fairseq``. 269 | 270 | encoder_pos_conv_groups (int): 271 | The number of groups of convolutional positional embeddings. 272 | 273 | This option corresponds to ``conv_pos_groups`` from ``fairseq``. 274 | 275 | encoder_num_layers (int): 276 | The number of self attention layers in transformer block. 277 | 278 | This option corresponds to ``encoder_layers`` from ``fairseq``. 279 | 280 | encoder_num_heads (int): 281 | The number of heads in self attention layers. 282 | 283 | This option corresponds to ``encoder_attention_heads`` from ``fairseq``. 284 | 285 | encoder_attention_dropout (float): 286 | The dropout probability applied after softmax in self-attention layer. 287 | 288 | This option corresponds to ``attention_dropout`` from ``fairseq``. 289 | 290 | encoder_ff_interm_features (int): 291 | The dimension of hidden features in feed forward layer. 292 | 293 | This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``. 294 | 295 | encoder_ff_interm_dropout (float): 296 | The dropout probability applied in feedforward layer. 297 | 298 | This option correspinds to ``activation_dropout`` from ``fairseq``. 299 | 300 | encoder_dropout (float): 301 | The dropout probability applied at the end of feed forward layer. 302 | 303 | This option corresponds to ``dropout`` from ``fairseq``. 304 | 305 | encoder_layer_norm_first (bool): 306 | Control the order of layer norm in transformer layer and each encoder layer. 307 | If True, in transformer layer, layer norm is applied before features are fed 308 | to encoder layers. In encoder layer, two layer norms are applied before and after 309 | self attention. 310 | If False, in transformer layer, layer norm is applied after features are fed 311 | to encoder layers. In encoder layer, two layer norms are applied after self 312 | attention, before and after feed forward. 313 | 314 | This option corresponds to ``layer_norm_first`` from ``fairseq``. 315 | 316 | encoder_layer_drop (float): 317 | Probability to drop each encoder layer during training. 318 | 319 | This option corresponds to ``layerdrop`` from ``fairseq``. 320 | 321 | aux_num_out (int or None): 322 | When provided, attach an extra linear layer on top of encoder, which can be 323 | used for fine-tuning. 324 | 325 | Returns: 326 | Wav2Vec2Model: 327 | The resulting model. 
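    Example:
        An illustrative HuBERT-Base-style configuration with every pruning switch enabled;
        the per-layer values mirror :py:func:`hubert_base` below and are given only for
        illustration:

        .. code-block:: python

            model = wav2vec2_model(
                extractor_mode="group_norm",
                extractor_conv_layer_config=None,
                extractor_conv_bias=False,
                encoder_embed_dim=768,
                encoder_projection_dropout=0.1,
                encoder_pos_conv_kernel=128,
                encoder_pos_conv_groups=16,
                encoder_num_layers=12,
                encoder_use_attention=[True] * 12,
                encoder_use_feed_forward=[True] * 12,
                encoder_num_heads=[12] * 12,
                encoder_head_dim=64,
                encoder_attention_dropout=0.1,
                encoder_ff_interm_features=[3072] * 12,
                encoder_ff_interm_dropout=0.0,
                encoder_dropout=0.1,
                encoder_layer_norm_first=False,
                encoder_layer_drop=0.05,
                aux_num_out=None,
                normalize_waveform=False,
                extractor_prune_conv_channels=True,
                encoder_prune_attention_heads=True,
                encoder_prune_attention_layer=True,
                encoder_prune_feed_forward_intermediate=True,
                encoder_prune_feed_forward_layer=True,
            )
            num_params = model.get_num_params()  # expected (differentiable) parameter count
            pruned_config = model.prune()        # in-place pruning; returns the slimmed-down config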
328 | """ # noqa: E501 329 | if extractor_conv_layer_config is None: 330 | extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 331 | 332 | feature_extractor = components._get_feature_extractor( 333 | extractor_mode, extractor_conv_layer_config, extractor_conv_bias, 334 | prune_conv_channels=extractor_prune_conv_channels, 335 | ) 336 | encoder = components._get_encoder( 337 | in_features=extractor_conv_layer_config[-1][0], 338 | embed_dim=encoder_embed_dim, 339 | dropout_input=encoder_projection_dropout, 340 | pos_conv_kernel=encoder_pos_conv_kernel, 341 | pos_conv_groups=encoder_pos_conv_groups, 342 | num_layers=encoder_num_layers, 343 | use_attention=encoder_use_attention, 344 | use_feed_forward=encoder_use_feed_forward, 345 | num_heads=encoder_num_heads, 346 | head_dim=encoder_head_dim, 347 | attention_dropout=encoder_attention_dropout, 348 | ff_interm_features=encoder_ff_interm_features, 349 | ff_interm_dropout=encoder_ff_interm_dropout, 350 | dropout=encoder_dropout, 351 | layer_norm_first=encoder_layer_norm_first, 352 | layer_drop=encoder_layer_drop, 353 | prune_attention_heads=encoder_prune_attention_heads, 354 | prune_attention_layer=encoder_prune_attention_layer, 355 | prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, 356 | prune_feed_forward_layer=encoder_prune_feed_forward_layer, 357 | ) 358 | aux = None 359 | if aux_num_out is not None: 360 | aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) 361 | return Wav2Vec2Model(normalize_waveform, feature_extractor, encoder, aux) 362 | 363 | 364 | def wav2vec2_base( 365 | encoder_projection_dropout: float = 0.1, 366 | encoder_attention_dropout: float = 0.1, 367 | encoder_ff_interm_dropout: float = 0.1, 368 | encoder_dropout: float = 0.1, 369 | encoder_layer_drop: float = 0.1, 370 | aux_num_out: Optional[int] = None, 371 | extractor_prune_conv_channels: bool = False, 372 | encoder_prune_attention_heads: bool = False, 373 | encoder_prune_attention_layer: bool = False, 374 | encoder_prune_feed_forward_intermediate: bool = False, 375 | encoder_prune_feed_forward_layer: bool = False, 376 | ) -> Wav2Vec2Model: 377 | """Builds "base" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` 378 | 379 | Args: 380 | encoder_projection_dropout (float): 381 | See :py:func:`wav2vec2_model`. 382 | encoder_attention_dropout (float): 383 | See :py:func:`wav2vec2_model`. 384 | encoder_ff_interm_dropout (float): 385 | See :py:func:`wav2vec2_model`. 386 | encoder_dropout (float): 387 | See :py:func:`wav2vec2_model`. 388 | encoder_layer_drop (float): 389 | See :py:func:`wav2vec2_model`. 390 | aux_num_out (int or None, optional): 391 | See :py:func:`wav2vec2_model`. 392 | 393 | Returns: 394 | Wav2Vec2Model: 395 | The resulting model. 
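    Note:
        The pruning-aware :py:func:`wav2vec2_model` in this file additionally expects
        ``normalize_waveform`` and per-layer lists for ``encoder_use_attention``,
        ``encoder_use_feed_forward``, ``encoder_num_heads`` and ``encoder_ff_interm_features``
        (plus ``encoder_head_dim``), which this convenience wrapper does not fill in;
        see :py:func:`hubert_base` below for the per-layer form and the example in
        :py:func:`wav2vec2_model` for a complete call.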
396 | """ # noqa: E501 397 | return wav2vec2_model( 398 | extractor_mode="group_norm", 399 | extractor_conv_layer_config=None, 400 | extractor_conv_bias=False, 401 | encoder_embed_dim=768, 402 | encoder_projection_dropout=encoder_projection_dropout, 403 | encoder_pos_conv_kernel=128, 404 | encoder_pos_conv_groups=16, 405 | encoder_num_layers=12, 406 | encoder_num_heads=12, 407 | encoder_attention_dropout=encoder_attention_dropout, 408 | encoder_ff_interm_features=3072, 409 | encoder_ff_interm_dropout=encoder_ff_interm_dropout, 410 | encoder_dropout=encoder_dropout, 411 | encoder_layer_norm_first=False, 412 | encoder_layer_drop=encoder_layer_drop, 413 | aux_num_out=aux_num_out, 414 | extractor_prune_conv_channels=extractor_prune_conv_channels, 415 | encoder_prune_attention_heads=encoder_prune_attention_heads, 416 | encoder_prune_attention_layer=encoder_prune_attention_layer, 417 | encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, 418 | encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, 419 | ) 420 | 421 | 422 | def wav2vec2_large( 423 | encoder_projection_dropout: float = 0.1, 424 | encoder_attention_dropout: float = 0.1, 425 | encoder_ff_interm_dropout: float = 0.1, 426 | encoder_dropout: float = 0.1, 427 | encoder_layer_drop: float = 0.1, 428 | aux_num_out: Optional[int] = None, 429 | extractor_prune_conv_channels: bool = False, 430 | encoder_prune_attention_heads: bool = False, 431 | encoder_prune_attention_layer: bool = False, 432 | encoder_prune_feed_forward_intermediate: bool = False, 433 | encoder_prune_feed_forward_layer: bool = False, 434 | ) -> Wav2Vec2Model: 435 | """Builds "large" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` 436 | 437 | Args: 438 | encoder_projection_dropout (float): 439 | See :py:func:`wav2vec2_model`. 440 | encoder_attention_dropout (float): 441 | See :py:func:`wav2vec2_model`. 442 | encoder_ff_interm_dropout (float): 443 | See :py:func:`wav2vec2_model`. 444 | encoder_dropout (float): 445 | See :py:func:`wav2vec2_model`. 446 | encoder_layer_drop (float): 447 | See :py:func:`wav2vec2_model`. 448 | aux_num_out (int or None, optional): 449 | See :py:func:`wav2vec2_model`. 450 | 451 | Returns: 452 | Wav2Vec2Model: 453 | The resulting model. 
454 | """ # noqa: E501 455 | return wav2vec2_model( 456 | extractor_mode="group_norm", 457 | extractor_conv_layer_config=None, 458 | extractor_conv_bias=False, 459 | encoder_embed_dim=1024, 460 | encoder_projection_dropout=encoder_projection_dropout, 461 | encoder_pos_conv_kernel=128, 462 | encoder_pos_conv_groups=16, 463 | encoder_num_layers=24, 464 | encoder_num_heads=16, 465 | encoder_attention_dropout=encoder_attention_dropout, 466 | encoder_ff_interm_features=4096, 467 | encoder_ff_interm_dropout=encoder_ff_interm_dropout, 468 | encoder_dropout=encoder_dropout, 469 | encoder_layer_norm_first=False, 470 | encoder_layer_drop=encoder_layer_drop, 471 | aux_num_out=aux_num_out, 472 | extractor_prune_conv_channels=extractor_prune_conv_channels, 473 | encoder_prune_attention_heads=encoder_prune_attention_heads, 474 | encoder_prune_attention_layer=encoder_prune_attention_layer, 475 | encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, 476 | encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, 477 | ) 478 | 479 | 480 | def wav2vec2_large_lv60k( 481 | encoder_projection_dropout: float = 0.1, 482 | encoder_attention_dropout: float = 0.0, 483 | encoder_ff_interm_dropout: float = 0.1, 484 | encoder_dropout: float = 0.0, 485 | encoder_layer_drop: float = 0.1, 486 | aux_num_out: Optional[int] = None, 487 | extractor_prune_conv_channels: bool = False, 488 | encoder_prune_attention_heads: bool = False, 489 | encoder_prune_attention_layer: bool = False, 490 | encoder_prune_feed_forward_intermediate: bool = False, 491 | encoder_prune_feed_forward_layer: bool = False, 492 | ) -> Wav2Vec2Model: 493 | """Builds "large lv-60k" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` 494 | 495 | Args: 496 | encoder_projection_dropout (float): 497 | See :py:func:`wav2vec2_model`. 498 | encoder_attention_dropout (float): 499 | See :py:func:`wav2vec2_model`. 500 | encoder_ff_interm_dropout (float): 501 | See :py:func:`wav2vec2_model`. 502 | encoder_dropout (float): 503 | See :py:func:`wav2vec2_model`. 504 | encoder_layer_drop (float): 505 | See :py:func:`wav2vec2_model`. 506 | aux_num_out (int or None, optional): 507 | See :py:func:`wav2vec2_model`. 508 | 509 | Returns: 510 | Wav2Vec2Model: 511 | The resulting model. 
512 | """ # noqa: E501 513 | return wav2vec2_model( 514 | extractor_mode="layer_norm", 515 | extractor_conv_layer_config=None, 516 | extractor_conv_bias=True, 517 | encoder_embed_dim=1024, 518 | encoder_projection_dropout=encoder_projection_dropout, 519 | encoder_pos_conv_kernel=128, 520 | encoder_pos_conv_groups=16, 521 | encoder_num_layers=24, 522 | encoder_num_heads=16, 523 | encoder_attention_dropout=encoder_attention_dropout, 524 | encoder_ff_interm_features=4096, 525 | encoder_ff_interm_dropout=encoder_ff_interm_dropout, 526 | encoder_dropout=encoder_dropout, 527 | encoder_layer_norm_first=True, 528 | encoder_layer_drop=encoder_layer_drop, 529 | aux_num_out=aux_num_out, 530 | extractor_prune_conv_channels=extractor_prune_conv_channels, 531 | encoder_prune_attention_heads=encoder_prune_attention_heads, 532 | encoder_prune_attention_layer=encoder_prune_attention_layer, 533 | encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, 534 | encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, 535 | ) 536 | 537 | 538 | def hubert_base( 539 | encoder_projection_dropout: float = 0.1, 540 | encoder_attention_dropout: float = 0.1, 541 | encoder_ff_interm_dropout: float = 0.0, 542 | encoder_dropout: float = 0.1, 543 | encoder_layer_drop: float = 0.05, 544 | aux_num_out: Optional[int] = None, 545 | extractor_prune_conv_channels: bool = False, 546 | encoder_prune_attention_heads: bool = False, 547 | encoder_prune_attention_layer: bool = False, 548 | encoder_prune_feed_forward_intermediate: bool = False, 549 | encoder_prune_feed_forward_layer: bool = False, 550 | ) -> Wav2Vec2Model: 551 | """Builds "base" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` 552 | 553 | Args: 554 | encoder_projection_dropout (float): 555 | See :py:func:`wav2vec2_model`. 556 | encoder_attention_dropout (float): 557 | See :py:func:`wav2vec2_model`. 558 | encoder_ff_interm_dropout (float): 559 | See :py:func:`wav2vec2_model`. 560 | encoder_dropout (float): 561 | See :py:func:`wav2vec2_model`. 562 | encoder_layer_drop (float): 563 | See :py:func:`wav2vec2_model`. 564 | aux_num_out (int or None, optional): 565 | See :py:func:`wav2vec2_model`. 566 | 567 | Returns: 568 | Wav2Vec2Model: 569 | The resulting model. 
570 | """ # noqa: E501 571 | return wav2vec2_model( 572 | extractor_mode="group_norm", 573 | extractor_conv_layer_config=None, 574 | extractor_conv_bias=False, 575 | encoder_embed_dim=768, 576 | encoder_projection_dropout=encoder_projection_dropout, 577 | encoder_pos_conv_kernel=128, 578 | encoder_pos_conv_groups=16, 579 | encoder_num_layers=12, 580 | encoder_use_attention=[True] * 12, 581 | encoder_use_feed_forward=[True] * 12, 582 | encoder_num_heads=[12] * 12, 583 | encoder_head_dim=64, 584 | encoder_attention_dropout=encoder_attention_dropout, 585 | encoder_ff_interm_features=[3072] * 12, 586 | encoder_ff_interm_dropout=encoder_ff_interm_dropout, 587 | encoder_dropout=encoder_dropout, 588 | encoder_layer_norm_first=False, 589 | encoder_layer_drop=encoder_layer_drop, 590 | aux_num_out=aux_num_out, 591 | extractor_prune_conv_channels=extractor_prune_conv_channels, 592 | encoder_prune_attention_heads=encoder_prune_attention_heads, 593 | encoder_prune_attention_layer=encoder_prune_attention_layer, 594 | encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, 595 | encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, 596 | ) 597 | 598 | 599 | def hubert_large( 600 | encoder_projection_dropout: float = 0.0, 601 | encoder_attention_dropout: float = 0.0, 602 | encoder_ff_interm_dropout: float = 0.0, 603 | encoder_dropout: float = 0.0, 604 | encoder_layer_drop: float = 0.0, 605 | aux_num_out: Optional[int] = None, 606 | extractor_prune_conv_channels: bool = False, 607 | encoder_prune_attention_heads: bool = False, 608 | encoder_prune_attention_layer: bool = False, 609 | encoder_prune_feed_forward_intermediate: bool = False, 610 | encoder_prune_feed_forward_layer: bool = False, 611 | ) -> Wav2Vec2Model: 612 | """Builds "large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` 613 | 614 | Args: 615 | encoder_projection_dropout (float): 616 | See :py:func:`wav2vec2_model`. 617 | encoder_attention_dropout (float): 618 | See :py:func:`wav2vec2_model`. 619 | encoder_ff_interm_dropout (float): 620 | See :py:func:`wav2vec2_model`. 621 | encoder_dropout (float): 622 | See :py:func:`wav2vec2_model`. 623 | encoder_layer_drop (float): 624 | See :py:func:`wav2vec2_model`. 625 | aux_num_out (int or None, optional): 626 | See :py:func:`wav2vec2_model`. 627 | 628 | Returns: 629 | Wav2Vec2Model: 630 | The resulting model. 
631 | """ # noqa: E501 632 | return wav2vec2_model( 633 | extractor_mode="layer_norm", 634 | extractor_conv_layer_config=None, 635 | extractor_conv_bias=False, 636 | encoder_embed_dim=1024, 637 | encoder_projection_dropout=encoder_projection_dropout, 638 | encoder_pos_conv_kernel=128, 639 | encoder_pos_conv_groups=16, 640 | encoder_num_layers=24, 641 | encoder_num_heads=16, 642 | encoder_attention_dropout=encoder_attention_dropout, 643 | encoder_ff_interm_features=4096, 644 | encoder_ff_interm_dropout=encoder_ff_interm_dropout, 645 | encoder_dropout=encoder_dropout, 646 | encoder_layer_norm_first=True, 647 | encoder_layer_drop=encoder_layer_drop, 648 | aux_num_out=aux_num_out, 649 | extractor_prune_conv_channels=extractor_prune_conv_channels, 650 | encoder_prune_attention_heads=encoder_prune_attention_heads, 651 | encoder_prune_attention_layer=encoder_prune_attention_layer, 652 | encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, 653 | encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, 654 | ) 655 | 656 | 657 | def hubert_xlarge( 658 | encoder_projection_dropout: float = 0.0, 659 | encoder_attention_dropout: float = 0.0, 660 | encoder_ff_interm_dropout: float = 0.0, 661 | encoder_dropout: float = 0.0, 662 | encoder_layer_drop: float = 0.0, 663 | aux_num_out: Optional[int] = None, 664 | extractor_prune_conv_channels: bool = False, 665 | encoder_prune_attention_heads: bool = False, 666 | encoder_prune_attention_layer: bool = False, 667 | encoder_prune_feed_forward_intermediate: bool = False, 668 | encoder_prune_feed_forward_layer: bool = False, 669 | ) -> Wav2Vec2Model: 670 | """Builds "extra large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` 671 | 672 | Args: 673 | encoder_projection_dropout (float): 674 | See :py:func:`wav2vec2_model`. 675 | encoder_attention_dropout (float): 676 | See :py:func:`wav2vec2_model`. 677 | encoder_ff_interm_dropout (float): 678 | See :py:func:`wav2vec2_model`. 679 | encoder_dropout (float): 680 | See :py:func:`wav2vec2_model`. 681 | encoder_layer_drop (float): 682 | See :py:func:`wav2vec2_model`. 683 | aux_num_out (int or None, optional): 684 | See :py:func:`wav2vec2_model`. 685 | 686 | Returns: 687 | Wav2Vec2Model: 688 | The resulting model. 
689 | """ # noqa: E501 690 | return wav2vec2_model( 691 | extractor_mode="layer_norm", 692 | extractor_conv_layer_config=None, 693 | extractor_conv_bias=False, 694 | encoder_embed_dim=1280, 695 | encoder_projection_dropout=encoder_projection_dropout, 696 | encoder_pos_conv_kernel=128, 697 | encoder_pos_conv_groups=16, 698 | encoder_num_layers=48, 699 | encoder_num_heads=16, 700 | encoder_attention_dropout=encoder_attention_dropout, 701 | encoder_ff_interm_features=5120, 702 | encoder_ff_interm_dropout=encoder_ff_interm_dropout, 703 | encoder_dropout=encoder_dropout, 704 | encoder_layer_norm_first=True, 705 | encoder_layer_drop=encoder_layer_drop, 706 | aux_num_out=aux_num_out, 707 | extractor_prune_conv_channels=extractor_prune_conv_channels, 708 | encoder_prune_attention_heads=encoder_prune_attention_heads, 709 | encoder_prune_attention_layer=encoder_prune_attention_layer, 710 | encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, 711 | encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, 712 | ) 713 | 714 | 715 | def _init_hubert_pretrain_model(module): 716 | if isinstance(module, components.LayerNorm): 717 | torch.nn.init.kaiming_normal_(module.conv.weight) 718 | elif isinstance(module, components.ConvolutionalPositionalEmbedding): 719 | # normalize the weight to normal distribution. 720 | std = math.sqrt(4.0 / (module.embed_dim * module.kernel_size)) 721 | torch.nn.init.normal_(module.conv.weight, mean=0.0, std=std) 722 | torch.nn.init.constant_(module.conv.bias, 0.0) 723 | elif isinstance(module, components.SelfAttention): 724 | # normalize the query, key, value, and out_proj parameters in self attention module. 725 | torch.nn.init.xavier_uniform_(module.k_proj.weight, gain=1 / math.sqrt(2)) 726 | torch.nn.init.xavier_uniform_(module.v_proj.weight, gain=1 / math.sqrt(2)) 727 | torch.nn.init.xavier_uniform_(module.q_proj.weight, gain=1 / math.sqrt(2)) 728 | torch.nn.init.xavier_uniform_(module.out_proj.weight) 729 | torch.nn.init.constant_(module.out_proj.bias, 0.0) 730 | elif isinstance(module, components.Transformer): 731 | module.apply(components._init_transformer_params) 732 | else: 733 | pass 734 | 735 | 736 | def wavlm_model( 737 | extractor_mode: str, 738 | extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], 739 | extractor_conv_bias: bool, 740 | encoder_embed_dim: int, 741 | encoder_projection_dropout: float, 742 | encoder_pos_conv_kernel: int, 743 | encoder_pos_conv_groups: int, 744 | encoder_num_layers: int, 745 | encoder_use_attention: List[bool], 746 | encoder_use_feed_forward: List[bool], 747 | encoder_total_num_heads: List[int], 748 | encoder_remaining_heads: List[List[int]], 749 | encoder_num_buckets: int, 750 | encoder_max_distance: int, 751 | encoder_attention_dropout: float, 752 | encoder_ff_interm_features: List[int], 753 | encoder_ff_interm_dropout: float, 754 | encoder_dropout: float, 755 | encoder_layer_norm_first: bool, 756 | encoder_layer_drop: float, 757 | aux_num_out: Optional[int], 758 | normalize_waveform: bool, 759 | extractor_prune_conv_channels: bool = False, 760 | encoder_prune_attention_heads: bool = False, 761 | encoder_prune_attention_layer: bool = False, 762 | encoder_prune_feed_forward_intermediate: bool = False, 763 | encoder_prune_feed_forward_layer: bool = False, 764 | ) -> Wav2Vec2Model: 765 | """Builds custom WaveLM model :cite:`chen2022wavlm`. 
771 |     with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output object is
772 |     :class:`~torchaudio.models.Wav2Vec2Model`. Most of the arguments have the same meaning
773 |     as in :py:func:`wav2vec2_model` so please refer there for documentation.
774 |
775 |     Args:
776 |         extractor_mode (str): Operation mode of feature extractor.
777 |             See :py:func:`wav2vec2_model`.
778 |
779 |         extractor_conv_layer_config (list of integer tuples or None):
780 |             See :py:func:`wav2vec2_model`.
781 |
782 |         extractor_conv_bias (bool):
783 |             See :py:func:`wav2vec2_model`.
784 |
785 |         encoder_embed_dim (int):
786 |             See :py:func:`wav2vec2_model`.
787 |
788 |         encoder_projection_dropout (float):
789 |             See :py:func:`wav2vec2_model`.
790 |
791 |         encoder_pos_conv_kernel (int):
792 |             See :py:func:`wav2vec2_model`.
793 |
794 |         encoder_pos_conv_groups (int):
795 |             See :py:func:`wav2vec2_model`.
796 |
797 |         encoder_num_layers (int):
798 |             See :py:func:`wav2vec2_model`.
799 |
800 |         encoder_use_attention (list of bool):
801 |             Whether each Transformer layer keeps its self-attention block. One entry per layer.
802 |         encoder_use_feed_forward (list of bool):
803 |             Whether each Transformer layer keeps its feed-forward block. One entry per layer.
804 |         encoder_total_num_heads (list of int):
805 |             Total number of attention heads in each Transformer layer (before pruning).
806 |         encoder_remaining_heads (list of list of int):
807 |             Indices of the attention heads kept in each Transformer layer.
808 |
809 |         encoder_num_buckets (int):
810 |             Number of buckets for relative position embedding.
811 |         encoder_max_distance (int):
812 |             Maximum distance for relative position embedding.
813 |
814 |         encoder_attention_dropout (float):
815 |             See :py:func:`wav2vec2_model`.
816 |
817 |         encoder_ff_interm_features (list of int):
818 |             Dimension of the feed-forward intermediate layer in each Transformer layer.
819 |
820 |         encoder_ff_interm_dropout (float):
821 |             See :py:func:`wav2vec2_model`.
822 |
823 |         encoder_dropout (float):
824 |             See :py:func:`wav2vec2_model`.
825 |
826 |         encoder_layer_norm_first (bool):
827 |             See :py:func:`wav2vec2_model`.
828 |
829 |         encoder_layer_drop (float):
830 |             See :py:func:`wav2vec2_model`.
831 |
832 |         aux_num_out (int or None):
833 |             See :py:func:`wav2vec2_model`.
834 |
835 |         normalize_waveform (bool):
836 |             Whether to normalize the input waveform before feeding it to the feature extractor.
837 |
838 |     Returns:
839 |         Wav2Vec2Model:
840 |             The resulting model.
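841 |
842 |     Example (illustrative; how the per-layer list arguments look for an unpruned 12-layer model with 12 heads per layer):
843 |         >>> num_layers = 12
844 |         >>> encoder_use_attention = [True] * num_layers
845 |         >>> encoder_use_feed_forward = [True] * num_layers
846 |         >>> encoder_total_num_heads = [12] * num_layers
847 |         >>> encoder_remaining_heads = [list(range(12)) for _ in range(num_layers)]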
827 | """ 828 | if extractor_conv_layer_config is None: 829 | extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 830 | 831 | feature_extractor = components._get_feature_extractor( 832 | extractor_mode, extractor_conv_layer_config, extractor_conv_bias, 833 | prune_conv_channels=extractor_prune_conv_channels, 834 | ) 835 | encoder = components._get_wavlm_encoder( 836 | in_features=extractor_conv_layer_config[-1][0], 837 | embed_dim=encoder_embed_dim, 838 | dropout_input=encoder_projection_dropout, 839 | pos_conv_kernel=encoder_pos_conv_kernel, 840 | pos_conv_groups=encoder_pos_conv_groups, 841 | num_layers=encoder_num_layers, 842 | use_attention=encoder_use_attention, 843 | use_feed_forward=encoder_use_feed_forward, 844 | total_num_heads=encoder_total_num_heads, 845 | remaining_heads=encoder_remaining_heads, 846 | num_buckets=encoder_num_buckets, 847 | max_distance=encoder_max_distance, 848 | attention_dropout=encoder_attention_dropout, 849 | ff_interm_features=encoder_ff_interm_features, 850 | ff_interm_dropout=encoder_ff_interm_dropout, 851 | dropout=encoder_dropout, 852 | layer_norm_first=encoder_layer_norm_first, 853 | layer_drop=encoder_layer_drop, 854 | prune_attention_heads=encoder_prune_attention_heads, 855 | prune_attention_layer=encoder_prune_attention_layer, 856 | prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, 857 | prune_feed_forward_layer=encoder_prune_feed_forward_layer, 858 | ) 859 | aux = None 860 | if aux_num_out is not None: 861 | aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) 862 | return Wav2Vec2Model(normalize_waveform, feature_extractor, encoder, aux) 863 | 864 | 865 | def wavlm_base( 866 | encoder_projection_dropout: float = 0.1, 867 | encoder_attention_dropout: float = 0.1, 868 | encoder_ff_interm_dropout: float = 0.1, 869 | encoder_dropout: float = 0.1, 870 | encoder_layer_drop: float = 0.1, 871 | aux_num_out: Optional[int] = None, 872 | ) -> Wav2Vec2Model: 873 | """Builds "base" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible 874 | with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is 875 | :class:`~torchaudio.models.Wav2Vec2Model`. 876 | 877 | Args: 878 | encoder_projection_dropout (float): 879 | See :py:func:`wav2vec2_model`. 880 | encoder_attention_dropout (float): 881 | See :py:func:`wav2vec2_model`. 882 | encoder_ff_interm_dropout (float): 883 | See :py:func:`wav2vec2_model`. 884 | encoder_dropout (float): 885 | See :py:func:`wav2vec2_model`. 886 | encoder_layer_drop (float): 887 | See :py:func:`wav2vec2_model`. 888 | aux_num_out (int, optional): 889 | See :py:func:`wav2vec2_model`. 890 | 891 | Returns: 892 | Wav2Vec2Model: 893 | The resulting model. 
894 | """ 895 | return wavlm_model( 896 | extractor_mode="group_norm", 897 | extractor_conv_layer_config=None, 898 | extractor_conv_bias=False, 899 | encoder_embed_dim=768, 900 | encoder_projection_dropout=encoder_projection_dropout, 901 | encoder_pos_conv_kernel=128, 902 | encoder_pos_conv_groups=16, 903 | encoder_num_layers=12, 904 | encoder_num_heads=12, 905 | encoder_num_buckets=320, 906 | encoder_max_distance=800, 907 | encoder_attention_dropout=encoder_attention_dropout, 908 | encoder_ff_interm_features=3072, 909 | encoder_ff_interm_dropout=encoder_ff_interm_dropout, 910 | encoder_dropout=encoder_dropout, 911 | encoder_layer_norm_first=False, 912 | encoder_layer_drop=encoder_layer_drop, 913 | aux_num_out=aux_num_out, 914 | ) 915 | 916 | 917 | def wavlm_large( 918 | encoder_projection_dropout: float = 0.1, 919 | encoder_attention_dropout: float = 0.1, 920 | encoder_ff_interm_dropout: float = 0.0, 921 | encoder_dropout: float = 0.1, 922 | encoder_layer_drop: float = 0.1, 923 | aux_num_out: Optional[int] = None, 924 | ) -> Wav2Vec2Model: 925 | """Builds "large" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible 926 | with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is 927 | :class:`~torchaudio.models.Wav2Vec2Model`. 928 | 929 | Args: 930 | encoder_projection_dropout (float): 931 | See :py:func:`wav2vec2_model`. 932 | encoder_attention_dropout (float): 933 | See :py:func:`wav2vec2_model`. 934 | encoder_ff_interm_dropout (float): 935 | See :py:func:`wav2vec2_model`. 936 | encoder_dropout (float): 937 | See :py:func:`wav2vec2_model`. 938 | encoder_layer_drop (float): 939 | See :py:func:`wav2vec2_model`. 940 | aux_num_out (int, optional): 941 | See :py:func:`wav2vec2_model`. 942 | 943 | Returns: 944 | Wav2Vec2Model: 945 | The resulting model. 946 | """ 947 | return wavlm_model( 948 | extractor_mode="layer_norm", 949 | extractor_conv_layer_config=None, 950 | extractor_conv_bias=False, 951 | encoder_embed_dim=1024, 952 | encoder_projection_dropout=encoder_projection_dropout, 953 | encoder_pos_conv_kernel=128, 954 | encoder_pos_conv_groups=16, 955 | encoder_num_layers=24, 956 | encoder_num_heads=16, 957 | encoder_num_buckets=320, 958 | encoder_max_distance=800, 959 | encoder_attention_dropout=encoder_attention_dropout, 960 | encoder_ff_interm_features=4096, 961 | encoder_ff_interm_dropout=encoder_ff_interm_dropout, 962 | encoder_dropout=encoder_dropout, 963 | encoder_layer_norm_first=True, 964 | encoder_layer_drop=encoder_layer_drop, 965 | aux_num_out=aux_num_out, 966 | ) 967 | -------------------------------------------------------------------------------- /wav2vec2/pruning_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for pruning.""" 2 | 3 | from typing import Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: str): 10 | "Prune linear layer in place." 
15 |     # NOTE: weight: (out_features, in_features), bias: (out_features,)
16 |     if dim == "input":
17 |         dim = 1
18 |         layer.in_features = len(index)
19 |     elif dim == "output":
20 |         dim = 0
21 |         layer.out_features = len(index)
22 |     else:
23 |         raise ValueError(f"Invalid dim: {dim}. Expected 'input' or 'output'.")
24 |
25 |     layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach())
26 |     if layer.bias is not None and dim == 0:
27 |         layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach())
28 |
29 |
30 | def prune_conv1d_layer(layer: nn.Conv1d, index: torch.LongTensor, dim: str):
31 |     """Prune conv1d in place."""
32 |     # NOTE: weight: (out_channels, in_channels, kernel_size), bias: (out_channels,)
33 |     if dim == "input":
34 |         dim = 1
35 |         layer.in_channels = len(index)
36 |     elif dim == "output":
37 |         dim = 0
38 |         layer.out_channels = len(index)
39 |     else:
40 |         raise ValueError(f"Invalid dim: {dim}. Expected 'input' or 'output'.")
41 |
42 |     layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach())
43 |     if layer.bias is not None and dim == 0:
44 |         layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach())
45 |
46 |
47 | def prune_layer_norm(layernorm: Union[nn.LayerNorm, nn.GroupNorm], index: torch.LongTensor):
48 |     """Prune layer norm or group norm in place."""
49 |     layernorm.weight = nn.Parameter(layernorm.weight.index_select(0, index).clone().detach())
50 |     layernorm.bias = nn.Parameter(layernorm.bias.index_select(0, index).clone().detach())
51 |     if isinstance(layernorm, nn.LayerNorm):
52 |         layernorm.normalized_shape = (len(index),)
53 |     elif isinstance(layernorm, nn.GroupNorm):
54 |         layernorm.num_groups = len(index)
55 |         layernorm.num_channels = len(index)
56 |
--------------------------------------------------------------------------------
/wav2vec2/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyf98/DPHuBERT/c18093fe4b56a0027a80bf9b9b1a23f932cbf14c/wav2vec2/utils/__init__.py
--------------------------------------------------------------------------------
/wav2vec2/utils/import_huggingface_wavlm.py:
--------------------------------------------------------------------------------
1 | """Import Hugging Face Transformers' wav2vec 2.0 / WavLM pretrained weights into torchaudio's format.
2 |
3 | Originally from:
4 | https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/utils/import_huggingface.py
5 |
6 | """
7 |
8 | import logging
9 | from typing import Any, Dict
10 |
11 | from torch.nn import Module
12 |
13 | from ..model import wav2vec2_model, Wav2Vec2Model, wavlm_model
14 |
15 | _LG = logging.getLogger(__name__)
16 |
17 |
18 | def _get_config(cfg):
19 |     config = {
20 |         "extractor_mode": f"{cfg.feat_extract_norm}_norm",
21 |         "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
22 |         "extractor_conv_bias": cfg.conv_bias,
23 |         "encoder_embed_dim": cfg.hidden_size,
24 |         "encoder_projection_dropout": cfg.feat_proj_dropout,
25 |         "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings,
26 |         "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups,
27 |         "encoder_num_layers": cfg.num_hidden_layers,
28 |         "encoder_num_heads": cfg.num_attention_heads,
29 |         "encoder_attention_dropout": cfg.attention_dropout,
30 |         "encoder_ff_interm_features": cfg.intermediate_size,
31 |         "encoder_ff_interm_dropout": cfg.activation_dropout,
32 |         "encoder_dropout": cfg.hidden_dropout,
33 |         "encoder_layer_norm_first": cfg.do_stable_layer_norm,
34 |         "encoder_layer_drop": cfg.layerdrop,
35 |     }
36 |     return config
37 |
38 |
39 | def _get_config_wavlm(cfg):
40 |     config = {
41 |         "extractor_mode": f"{cfg.feat_extract_norm}_norm",
42 |         "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)),
43 |         "extractor_conv_bias": cfg.conv_bias,
44 |         "encoder_embed_dim": cfg.hidden_size,
45 |         "encoder_projection_dropout": cfg.feat_proj_dropout,
46 |         "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings,
47 |         "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups,
48 |         "encoder_num_layers": cfg.num_hidden_layers,
49 |         "encoder_use_attention": [True] * cfg.num_hidden_layers,
50 |         "encoder_use_feed_forward": [True] * cfg.num_hidden_layers,
51 |         "encoder_total_num_heads": [cfg.num_attention_heads for _ in range(cfg.num_hidden_layers)],
52 |         "encoder_remaining_heads": [list(range(cfg.num_attention_heads)) for _ in range(cfg.num_hidden_layers)],
53 |         "encoder_num_buckets": cfg.num_buckets,
54 |         "encoder_max_distance": cfg.max_bucket_distance,
55 |         "encoder_attention_dropout": cfg.attention_dropout,
56 |         "encoder_ff_interm_features": [cfg.intermediate_size for _ in range(cfg.num_hidden_layers)],
57 |         "encoder_ff_interm_dropout": cfg.activation_dropout,
58 |         "encoder_dropout": cfg.hidden_dropout,
59 |         "encoder_layer_norm_first": cfg.do_stable_layer_norm,
60 |         "encoder_layer_drop": cfg.layerdrop,
61 |         "normalize_waveform": cfg.feat_extract_norm == "layer",
62 |     }
63 |     return config
64 |
65 |
66 | def _build(config, original):
67 |     is_for_ctc = original.__class__.__name__ in ["Wav2Vec2ForCTC", "WavLMForCTC"]
68 |     if is_for_ctc:
69 |         aux_num_out = original.config.vocab_size
70 |         wav2vec2 = original.wav2vec2
71 |     else:
72 |         _LG.warning(
73 |             "The model is not an instance of Wav2Vec2ForCTC or WavLMForCTC. " '"lm_head" module is not imported.'
74 |         )
75 |         aux_num_out = None
76 |         wav2vec2 = original
77 |     is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"]
78 |     if is_wavlm:
79 |         imported = wavlm_model(**config, aux_num_out=aux_num_out)
80 |     else:
81 |         imported = wav2vec2_model(**config, aux_num_out=aux_num_out)
82 |     print(imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict(), strict=False))
83 |     print(imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict(), strict=False))
84 |     encoder_state_dict = wav2vec2.encoder.state_dict()
85 |     if is_wavlm:  # Adjust the encoder state dict for compatibility with the HF parameter layout (currently a no-op)
86 |         transform_wavlm_encoder_state(encoder_state_dict, config["encoder_num_layers"])
87 |     print(imported.encoder.transformer.load_state_dict(encoder_state_dict, strict=False))
88 |     if is_for_ctc:
89 |         imported.aux.load_state_dict(original.lm_head.state_dict())
90 |     return imported
91 |
92 |
93 | def transform_wavlm_encoder_state(state: Dict[str, Any], encoder_num_layers: int):
94 |     """Converts WavLM encoder state from the HuggingFace format. The attention projections are kept separate
95 |     in this implementation (rather than concatenated as in ``torch.nn.MultiheadAttention``), so the state dict is loaded as is and this function is currently a no-op.
96 |     """
97 |     pass
98 |
99 |
100 | def import_huggingface_model(original: Module) -> Wav2Vec2Model:
101 |     """Builds :class:`Wav2Vec2Model` from the corresponding model object of
102 |     `Transformers `_.
103 |
104 |     Args:
105 |         original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC``/``Wav2Vec2Model`` or ``WavLMForCTC``/``WavLMModel`` from ``transformers``.
106 |
107 |     Returns:
108 |         Wav2Vec2Model: Imported model.
109 |
110 |     Example
111 |         >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model
112 |         >>>
113 |         >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
114 |         >>> model = import_huggingface_model(original)
115 |         >>>
116 |         >>> waveforms, _ = torchaudio.load("audio.wav")
117 |         >>> logits, _ = model(waveforms)
118 |     """
119 |     _LG.info("Importing model.")
120 |     _LG.info("Loading model configuration.")
121 |     is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"]
122 |     if is_wavlm:
123 |         config = _get_config_wavlm(original.config)
124 |     else:
125 |         config = _get_config(original.config)
126 |     _LG.debug(" - config: %s", config)
127 |     _LG.info("Building model.")
128 |     imported = _build(config, original)
129 |     return imported
130 |
--------------------------------------------------------------------------------
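A minimal end-to-end sketch of the WavLM import path above (illustrative only, not part of the repository; it assumes the `transformers` package is installed, is run from the repository root, and uses the public `microsoft/wavlm-base-plus` checkpoint purely as an example; cf. convert_wavlm_from_hf.py in the file tree):

    import torch
    from transformers import WavLMModel
    from wav2vec2.utils.import_huggingface_wavlm import import_huggingface_model

    # Load the Hugging Face WavLM encoder and convert it to the torchaudio-style
    # Wav2Vec2Model defined in wav2vec2/model.py.
    original = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
    model = import_huggingface_model(original).eval()

    # Run a dummy 1-second, 16 kHz waveform through the converted model.
    waveform = torch.randn(1, 16000)
    features, _ = model(waveform)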