├── .gitignore
├── LICENSE
├── README.md
├── dasheng
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-39.pyc
│   ├── prepare
│   │   └── wavlist_to_tar.py
│   ├── pretrained
│   │   ├── __pycache__
│   │   │   └── pretrained.cpython-39.pyc
│   │   └── pretrained.py
│   └── train
│       ├── audiowebdataset.py
│       ├── config
│       │   ├── dasheng_06B.yaml
│       │   ├── dasheng_12B.yaml
│       │   └── dasheng_base.yaml
│       ├── models.py
│       ├── train.py
│       └── utils.py
├── metadata
│   └── hear_capabilities.png
└── pyproject.toml
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env
128 | venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # VSCode
160 | .vscode/
161 |
162 | # PyCharm
163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | # and can be added to the global gitignore or merged into this file. For a more nuclear
166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | #.idea/
168 |
169 | # Debug
170 | debug*
171 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright (C) 2024 Xiaomi Corporation.
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Dasheng (大声)
3 |
4 | Official PyTorch code for Deep Audio-Signal Holistic Embeddings:
5 | *Scaling up masked audio encoder learning for general audio classification*
6 |
17 | # TL;DR
18 |
19 | ```bash
20 | python3 -m pip install dasheng
21 | python3 -c "from dasheng import dasheng_base; import torch; model = dasheng_base().eval(); features=model(torch.randn(1, 16000))"
22 | ```
23 |
24 |
25 | This repo provides checkpoints for the Interspeech 2024 paper [Scaling up masked audio encoder learning for general audio classification](https://arxiv.org/abs/2406.06992).
26 | The goal of this work is to investigate the scalability of masked autoencoders for audio.
27 | Prior work did not scale beyond 10,000 hours of audio, while Dasheng used 272,000 hours of training data.
28 |
29 |
30 | ## Huggingface 🤗
31 |
32 |
33 |
38 |
39 | Please see [here](https://huggingface.co/mispeech/dasheng-base) for usage instructions.
40 |
41 |
42 | # Models
43 |
44 | Dasheng models have been trained on 272k hours of general audio, mainly [VGGSound](https://www.robots.ox.ac.uk/~vgg/data/vggsound/), [Audioset](https://research.google.com/audioset/), [MTG-Jamendo](https://mtg.github.io/mtg-jamendo-dataset/) and [ACAV100M](https://acav100m.github.io/).
45 |
46 | The table below lists the models together with their evaluation results on the [HEAR benchmark](https://hearbenchmark.com/), averaged across the three domains.
47 |
48 | | Model | Parameters (M) | Environment Sounds | Speech | Music |
49 | |------|-------|-------|-------| ------ |
50 | | Dasheng-Base | 86 | 80.2 | 72.5 | 84.0 |
51 | | Dasheng-0.6B | 600 | 82.4 | 74.9 | 84.0 |
52 | | Dasheng-1.2B | 1200 | **83.2** | **75.7** | **84.9** |
53 | | [AudioMAE](https://github.com/facebookresearch/AudioMAE) | 86 | 61.7 | 38.7 | 72.7 |
54 | | [Whisper-Base-V1](https://github.com/openai/whisper) | 74 | 52.5 | 73.1 | 69.1 |
55 | | [WavLM-Large](https://github.com/microsoft/unilm/tree/master/wavlm) | 330 | 71.4 | 72.2 | 65.0 |
56 | | [Wav2vec-large-100k-voxpopuli](https://huggingface.co/facebook/wav2vec2-large-100k-voxpopuli) | 300 | 62.5 | 63.6 | 69.5 |
57 | | [Data2Vec-Audio-Large](https://huggingface.co/facebook/data2vec-audio-large) | 300 |41.1 | 60.5 | 55.0 |
58 |
59 |
60 |
61 | ## K-Nearest Neighbor results
62 |
63 | Performance of the frozen features with a k-nearest-neighbor classifier, i.e., without any parameterized training.
64 |
65 | | | ESC50 | FSDKaggle18 | NSynth Instrument | Speech Commands 1 | Speech Commands 2 | US8k | VoxCeleb1 | RAVDESS-Speech | FluentSpeechCommands |
66 | |--------------------------|-------|--------|-------------|-------|-------|-------|-----------|---------|-------|
67 | | [MSM-MAE](https://github.com/nttcslab/msm-mae) | 2 | 2.18 | 20.58 | 3.7 | 1.5 | 11.5 | 0.12 | 6.77 | 1.85 |
68 | | MelSpec | 18.4 | 38.5 | 35.5 | 3.7 | 1.5 | 40.39 | 5.26 | 29.65 | 9.97 |
69 | | [CED-Base](https://github.com/RicherMans/CED) | 95.35 | 85.06 | 74.41 | 79.78 | 62.66 | 87.06 | 7.02 | 52.78 | 16.61 |
70 | | [AudioMAE](https://github.com/facebookresearch/AudioMAE) | 53.05 | 43.38 | 67.21 | 56.87 | 5.9 | 58.18 | 2.9 | 28.68 | 7.59 |
71 | | [WavLM-Large](https://github.com/microsoft/unilm/tree/master/wavlm) | 51.3 | 60.87 | | 96.97 | 92.69 | 58.67 | 28.54 | 51.39 | 83.28 |
72 | | [Wav2vec-large-100k-voxpopuli](https://huggingface.co/facebook/wav2vec2-large-100k-voxpopuli) | 44 | 59.5 | 60.42 | 80.86 | 66.61 | 59.84 | 18.22 | 45.76 | 30.48 |
73 | | Dasheng-Base | 61.9 | 70.31 | 70.02 | 93.55 | 86 | 73.87 | 34.21 | 58.12 | 52.33 |
74 | | Dasheng-0.6B | 66.55 | 72.06 | 70.87 | 93.36 | 87.27 | 75.92 | 37.78 | 61.81 | 57.63 |
75 | | Dasheng-1.2B | 68.55 | 72.06 | 71.19 | 95.9 | 90.9 | 77.71 | 39.39 | 61.94 | 62.38 |
76 |
77 | ## 1. Installation (Recommended for inference)
78 |
79 | Install the package.
80 |
81 | ```bash
82 | python3 -m pip install dasheng
83 | ```
84 |
85 | ### 1.2 Installation for Training
86 |
87 | ```bash
88 | python3 -m pip install dasheng[train]
89 | ```
90 |
91 | ## 2. Usage
92 |
93 | ```python
94 | # The three models of the paper
95 | from dasheng import dasheng_base, dasheng_06B, dasheng_12B
96 |
97 | model = dasheng_base()
98 | ```
99 |
100 | Forward some audio data (note: input audio should be sampled at 16 kHz):
101 |
102 | ```python
103 | import torch
104 | model = model.eval()
105 | features = model(torch.randn(1, 16000))
106 | print(features.shape)
107 | ```
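
The model returns a sequence of frame-level embeddings, one `embed_dim`-dimensional vector (768 for `dasheng_base`) per 40 ms of audio, so roughly 25 frames for the 1-second input above. If a single clip-level embedding is needed (e.g., for retrieval or a linear probe), a common choice is to average over the time axis, as the Audioset classifier in the FAQ below also does. A minimal sketch:

```python
import torch
from dasheng import dasheng_base

model = dasheng_base().eval()
with torch.inference_mode():
    features = model(torch.randn(1, 16000))  # (batch, time, embed_dim) frame-level features
    clip_embedding = features.mean(dim=1)    # (batch, embed_dim) clip-level embedding
print(features.shape, clip_embedding.shape)
```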
108 |
109 | ## 3. Training
110 |
111 | Install dependencies:
112 |
113 | ```bash
114 | python3 -m pip install dasheng[train]
115 | ```
116 |
117 | ### 3.1 Prepare data
118 |
119 | We rely on the excellent [webdataset](https://github.com/webdataset) library for I/O.
120 | Thus, one simply needs to pack the training data into a set of `.tar` shards.
121 |
122 | A simple way to create such a shard is:
123 |
124 | ```bash
125 | find DIR -type f -name '*.flac' | tar -rvf data.tar -T -
126 | ```
127 |
128 | We also provide a simple script, `wavlist_to_tar`, which automates this process and is installed with the package:
129 |
130 | ```bash
131 | wavlist_to_tar your_data.tsv shards/
132 | ```
133 |
134 | Creating `your_data.tsv` is simple:
135 |
136 | ```bash
137 | find data -type f | awk 'BEGIN{print "filename"} {print}' > your_data.tsv
138 | ```
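
The resulting shards are read back through the webdataset pipelines in `dasheng/train/audiowebdataset.py`. As a quick sanity check of your data, you can iterate over a freshly written shard with the `create_dataloader` helper from that module (requires the `dasheng[train]` extras); a minimal sketch, where the shard path, crop size and batch size are illustrative:

```python
from dasheng.train.audiowebdataset import create_dataloader

# Assumes the shard written above, e.g. shards/your_data_0000000.tar;
# brace patterns such as "shards/your_data_{0000000..0000009}.tar" also work.
dataloader = create_dataloader(
    ["shards/your_data_0000000.tar"],
    crop_size=160000,  # 10 s at 16 kHz, matching chunk_length: 10.0 in the configs
    batch_size=4,
    num_workers=0,
)
audio, keys = next(iter(dataloader))
print(audio.shape)  # e.g. torch.Size([4, 160000]); keys holds the corresponding sample keys
```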
139 |
140 | ### 3.2 Training from source
141 |
142 | To train from source, first adjust the respective config in `dasheng/train/config/*.yaml`, in particular by pointing `train_data` and `cv_data` to your own shards.
143 |
144 | ```bash
145 | python3 dasheng/train/train.py dasheng/train/config/dasheng_base.yaml
146 | ```
147 |
148 | Multi-GPU training is supported via [Accelerate](https://huggingface.co/docs/accelerate/index):
149 |
150 | ```bash
151 | accelerate launch --mixed_precision='bf16' dasheng/train/train.py dasheng/train/config/dasheng_base.yaml
152 | ```
153 |
154 | ## FAQ
155 |
156 | ### Is there an Audioset-finetuned Dasheng?
157 |
158 | Yes. The Audioset-finetuned base model achieves 49.7 mAP and can be used as follows:
159 |
160 | ```python
161 | from typing import Any, Mapping
162 | import dasheng
163 | import torch
164 |
165 | class DashengAudiosetClassifier(torch.nn.Module):
166 |
167 | def __init__(self) -> None:
168 | super().__init__()
169 | self.dashengmodel = dasheng.dasheng_base()
170 | self.classifier = torch.nn.Sequential(torch.nn.LayerNorm(self.dashengmodel.embed_dim), torch.nn.Linear(self.dashengmodel.embed_dim, 527))
171 |
172 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
173 | self.dashengmodel.load_state_dict(state_dict, strict=False)
174 | for_classifier_dict = {}
175 | for k,v in state_dict.items():
176 | if 'outputlayer' in k:
177 | for_classifier_dict[k.replace('outputlayer.','')] = v
178 | self.classifier.load_state_dict(for_classifier_dict)
179 | return self
180 |
181 | def forward(self, x):
182 | x = self.dashengmodel(x).mean(1)
183 | return self.classifier(x).sigmoid()
184 |
185 |
186 | mdl = DashengAudiosetClassifier()
187 | check = torch.hub.load_state_dict_from_url('https://zenodo.org/records/13315686/files/dasheng_audioset_mAP497.pt?download=1',map_location='cpu')
188 | mdl.load_state_dict(check)
189 |
190 | prediction = mdl(torch.randn(1,16000))
191 | ```
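
The returned `prediction` is a `(batch, 527)` tensor of per-class probabilities over the Audioset ontology (sigmoid outputs, since Audioset is a multi-label task). A minimal sketch for inspecting the most confident classes; mapping the indices to human-readable labels requires the external Audioset `class_labels_indices.csv`, which is not bundled here:

```python
# prediction: shape (1, 527), values in [0, 1]
probs, class_indices = prediction.squeeze(0).topk(5)
for p, idx in zip(probs.tolist(), class_indices.tolist()):
    print(f"class {idx}: p={p:.3f}")
```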
192 |
193 |
194 | ## Citation
195 |
196 | ```bibtex
197 | @inproceedings{dinkel2024dasheng,
198 | title={Scaling up masked audio encoder learning for general audio classification},
199 | author={Dinkel, Heinrich and Yan, Zhiyong and Wang, Yongqing and Zhang, Junbo and Wang, Yujun and Wang, Bin},
200 | booktitle={Interspeech 2024},
201 | year={2024}
202 | }
203 | ```
204 |
--------------------------------------------------------------------------------
/dasheng/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 | __version__ = importlib.metadata.version("dasheng")
3 |
4 | from .pretrained.pretrained import dasheng_base, dasheng_06B, dasheng_12B
5 |
--------------------------------------------------------------------------------
/dasheng/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Dasheng/db7309358edbeea1b1cca37739f442c4139ac8e9/dasheng/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/dasheng/prepare/wavlist_to_tar.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Any, Dict, Iterable
3 | import json
4 | from pathlib import Path
5 | import pandas as pd
6 | import argparse
7 | import multiprocessing
8 | from webdataset import TarWriter
9 | from tqdm import tqdm
10 |
11 |
12 | def proxy_read(data: Dict, filename_column: str):
13 | filename = data.pop(filename_column)
14 | with open(filename, 'rb') as buf:
15 | raw_data = buf.read()
16 | fpath = Path(filename)
17 | stem_name = str(fpath.stem).replace('.', '_')
18 | suffix = fpath.suffix.replace('.', '')
19 | ret_data = {
20 | suffix: raw_data,
21 | '__key__': f"{stem_name}", # Just cast to str
22 | }
23 | # If we have some labels, also dump a .json file
24 | if len(data) > 0:
25 | ret_data['json'] = json.dumps(data).encode('utf-8')
26 | return ret_data
27 |
28 |
29 | def main():
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument(
32 | 'input_filelist',
33 | type=Path,
34 | help=
35 | "Some input filelist. We will expect a column named for the data and every other column will be dumped to json format. The filenames (basepath) need to be unique. Please first shuffle the list before processing."
36 | )
37 | parser.add_argument('outputdir', type=Path)
38 | parser.add_argument('-s', '--size_per_file', type=int, default=10000)
39 | parser.add_argument('-n', '--n_workers', type=int, default=4)
40 | parser.add_argument(
41 | '--filename_column',
42 | default='filename',
43 | type=str,
44 | help="The column name that identifies the files to extract")
45 | parser.add_argument('-d', '--delim', default='\t', type=str)
46 | parser.add_argument('--compress',
47 | action='store_true',
48 | default=False,
49 | help="Using tar.gz instead of .tar")
50 | parser.add_argument(
51 | '--write_json',
52 | default=None,
53 | type=str,
54 | help=
55 | "Also writes a json to the target directory. Useful with the 'wids' library to read in random."
56 | )
57 | parser.set_defaults(stereo=False)
58 | args = parser.parse_args()
59 | df_iterator: Iterable[pd.DataFrame] = pd.read_csv(
60 | args.input_filelist, sep=args.delim, chunksize=args.size_per_file)
61 |
62 | shards_base_path = args.outputdir
63 | shards_base_path.mkdir(parents=True, exist_ok=True)
64 |
65 | suffix = '.tar' if args.compress is False else '.tar.gz'
66 |
67 | output_json: Dict[str, Any] = dict(wids_version=1)
68 | tar_file_outputs = []
69 | with multiprocessing.Pool(processes=args.n_workers) as pool:
70 | for file_num, df in enumerate(
71 | tqdm(df_iterator,
72 | leave=True,
73 | desc='Dumping to file',
74 | unit='shard')):
75 | #Locally sample
76 | data = df.sample(frac=1.0).to_dict('records')
77 | output_file_iter = str(
78 | shards_base_path /
79 | f'{args.input_filelist.stem}_{file_num:07d}{suffix}')
80 | n_samples = len(data)
81 | tar_file_outputs.append(
82 | dict(url=str(output_file_iter), nsamples=n_samples))
83 | with TarWriter(output_file_iter,
84 | encoder=False,
85 | compress=args.compress) as dst:
86 | for return_values in tqdm(pool.imap_unordered(
87 | partial(proxy_read,
88 | filename_column=args.filename_column), data),
89 | unit='file',
90 | total=len(data),
91 | leave=False):
92 | dst.write(return_values)
93 | print(f"Finished, final data can be found at {args.outputdir}")
94 | if args.write_json is not None:
95 | import json
96 | output_json['shardlist'] = tar_file_outputs
97 | with open(args.write_json, 'w') as f:
98 | json.dump(output_json, f)
99 | print(f"Dumped Json for wids usage at {args.write_json}")
100 |
--------------------------------------------------------------------------------
/dasheng/pretrained/__pycache__/pretrained.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Dasheng/db7309358edbeea1b1cca37739f442c4139ac8e9/dasheng/pretrained/__pycache__/pretrained.cpython-39.pyc
--------------------------------------------------------------------------------
/dasheng/pretrained/pretrained.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 | from einops import rearrange
4 | from ..train.models import AudioTransformerMAE_Encoder
5 |
6 | PRETRAINED_CHECKPOINTS = {
7 | "dasheng_base": "https://zenodo.org/records/11511780/files/dasheng_base.pt?download=1",
8 | "dasheng_06B": "https://zenodo.org/records/11511780/files/dasheng_06b.pt?download=1",
9 | "dasheng_12B": "https://zenodo.org/records/11511780/files/dasheng_12b.pt?download=1",
10 | }
11 |
12 |
13 | # Using the pretrained encoders, remove all masking
14 | class Dasheng(AudioTransformerMAE_Encoder):
15 | # need the *args, **kwargs otherwise we get a linter warning
16 | def forward_features(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
17 | *_, t = x.shape
18 | x = x + self.time_pos_embed[:, :, :, :t]
19 | x = x + self.freq_pos_embed[:, :, :, :]
20 | x = rearrange(x, "b c f t -> b (f t) c")
21 | if self.pooling == "token":
22 | cls_token = self.cls_token.expand(x.shape[0], -1, -1)
23 | cls_token = cls_token + self.token_pos_embed[:, :]
24 | x = torch.cat((cls_token, x), dim=1)
25 | x = self.pos_drop(x)
26 | x = self.blocks(x, **kwargs)
27 | x = self.norm(x)
28 | return x
29 |
30 | def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor:
31 | batch_size = len(lengths)
32 | idx = torch.arange(max_length, device=lengths.device)
33 | idx = idx.repeat(batch_size).view(batch_size, max_length)
34 | mask = (idx >= lengths.unsqueeze(-1)).bool()
35 | return mask
36 |
37 | def forward_spectrogram(self, x: torch.Tensor, x_length:Optional[torch.Tensor] = None) -> torch.Tensor:
38 | # For dasheng, target-length is 40 ms
39 | target_length_in_patches = self.target_length // self.patch_stride[-1]
40 | x = self.patch_embed(x)
41 | b, c, f, t = x.shape
42 | input_splits = x.split(target_length_in_patches, dim=-1)
43 | mask = None # Single mask
44 | masks = [None for _ in range(len(input_splits))]
45 | if x_length is not None:
46 | assert len(x_length) == len(x),"batchsizes of input x and x_length need to be same"
47 | assert x_length.ndim == 1, "Lengths are of size (B,)"
48 | scaled_lengths = (x_length / (self.hop_size * 4)).long() # 40ms for all dasheng models
49 | # Note that the mask is in (t f) format, but transformers here use (f t) format
50 | mask = self._to_mask(
51 | max_length=t,
52 | lengths=scaled_lengths,
53 | )
54 | # Trim mask to only use valid "patches", since x.shape[-1] is based on the possibly padded input
55 | masks = mask.split(target_length_in_patches, dim=-1)
56 | outputs = []
57 |
58 | for split_x,mask in zip(input_splits, masks):
59 | forward_kwargs = dict(mask = mask)
60 | split_x = self.forward_features(split_x, **forward_kwargs)
61 | outputs.append(split_x)
62 | x = torch.cat(outputs, dim =1 )
63 | return x
64 |
65 |
66 | def forward(self, x, x_length : Optional[torch.Tensor] = None) -> torch.Tensor:
67 | x = self.forward_to_spec(x)
68 | return self.forward_spectrogram(x,x_length=x_length)
69 |
70 | @classmethod
71 | def from_pretrained(
72 | cls, pretrained_url: str, **additional_model_kwargs
73 | ) -> AudioTransformerMAE_Encoder:
74 | """
75 | Class method to create a new Dasheng model instance from a pre-trained model stored in the Hugging Face model hub.
76 | """
77 | if "http" in pretrained_url:
78 | dump = torch.hub.load_state_dict_from_url(
79 | pretrained_url, map_location="cpu"
80 | )
81 | else:
82 | dump = torch.load(pretrained_url, map_location="cpu")
83 |         model_parameters, model_config = dump["model"], dump["config"]
84 |         instance = cls(**{**model_config, **additional_model_kwargs})
85 |         instance.load_state_dict(model_parameters, strict=True)
86 | return instance
87 |
88 |
89 | def dasheng_base(**model_kwargs):
90 | model_kwargs["embed_dim"] = 768
91 | model_kwargs["depth"] = 12
92 | model_kwargs["num_heads"] = 12
93 | return Dasheng.from_pretrained(
94 | PRETRAINED_CHECKPOINTS["dasheng_base"], **model_kwargs
95 | )
96 |
97 |
98 | def dasheng_06B(**model_kwargs):
99 | model_kwargs["embed_dim"] = 1280
100 | model_kwargs["depth"] = 32
101 | model_kwargs["num_heads"] = 16
102 | return Dasheng.from_pretrained(
103 | PRETRAINED_CHECKPOINTS["dasheng_06B"], **model_kwargs
104 | )
105 |
106 |
107 | def dasheng_12B(**model_kwargs):
108 | model_kwargs["embed_dim"] = 1536
109 | model_kwargs["depth"] = 40
110 | model_kwargs["num_heads"] = 24
111 | return Dasheng.from_pretrained(
112 | PRETRAINED_CHECKPOINTS["dasheng_12B"], **model_kwargs
113 | )
114 |
115 |
116 | if __name__ == "__main__":
117 | mdl = dasheng_base()
118 | print(mdl(torch.randn(1, 168499)).shape)
119 |
--------------------------------------------------------------------------------
/dasheng/train/audiowebdataset.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | import random
3 | from functools import partial
4 | import webdataset as wds
5 | import braceexpand # Dependency in wds
6 | import torch
7 | import numpy as np
8 |
9 | def crop_or_pad(wav: torch.Tensor, crop_size: int, pad_last: bool = False):
10 | n_samples, *_ = wav.shape
11 | available_crops = n_samples // crop_size
12 | for i in range(available_crops):
13 | crop = wav[i * crop_size:(i + 1) * crop_size, ...]
14 | yield crop
15 |
16 | if (available_crops == 0) or (pad_last):
17 | last_crop = wav[available_crops * crop_size:, ...]
18 | padded = torch.zeros((crop_size, *last_crop.shape[1:]))
19 | padded[:last_crop.shape[0]] = last_crop
20 | yield padded
21 |
22 |
23 | def convert_decibels_to_amplitude_ratio(decibels):
24 | return 10**(decibels / 20)
25 |
26 |
27 | def _audio_gain(data_stream, min_gain_db: float = -6, max_gain_db=10):
28 | for sample in data_stream:
29 | audio, *extra = sample
30 | scale_factor = convert_decibels_to_amplitude_ratio(
31 | random.uniform(min_gain_db, max_gain_db))
32 | yield (audio * scale_factor, *extra)
33 |
34 | def _seq_crop(data, crop_size: int, mono: bool = True, pad_last: bool = False, drop_crops: bool = False, handler=None):
35 | """WebDataset crop filter, yields sequential crops"""
36 | for sample in data:
37 | audio, *extra = sample
38 | if isinstance(audio, tuple):
39 | audio = audio[0]
40 | if mono and audio.ndim == 2:
41 | audio = audio.mean(0)
42 | if drop_crops and audio.shape[-1] < int(crop_size * 0.8):
43 | continue
44 | crops = crop_or_pad(audio.float(),
45 | crop_size=crop_size,
46 | pad_last=pad_last)
47 | for crop in crops:
48 | yield (crop, *extra)
49 |
50 |
51 | class Audiowebdataset_Fluid(wds.DataPipeline):
52 |
53 | def __init__(self,
54 | urls,
55 | shuffle: Optional[int] = None,
56 | crop_size: int = 16000,
57 | resample: bool = False,
58 | crop_shuffle: Optional[int] = None,
59 | batch_size: Optional[int] = None,
60 | add_gain: bool = False,
61 | drop_crops: bool = False,
62 | with_json: bool = False,
63 |
64 | ):
65 | pipeline: List = [
66 | wds.SimpleShardList(urls)
67 | if resample is False else wds.ResampledShards(urls)
68 | ]
69 | if shuffle is not None:
70 | # Tar wise shuffle
71 | pipeline.extend([
72 | wds.detshuffle(
73 | bufsize=shuffle,
74 | initial=shuffle // 4,
75 | ),
76 | wds.split_by_node,
77 | wds.split_by_worker,
78 | # at this point, we have an iterator over the shards assigned to each worker at each node
79 | wds.tarfile_to_samples(handler=wds.warn_and_continue),
80 | wds.shuffle(
81 | bufsize=shuffle,
82 | initial=shuffle // 4,
83 | ),
84 | ])
85 | else:
86 | pipeline.extend([wds.split_by_worker, wds.tarfile_to_samples()])
87 | pipeline.extend([
88 | wds.decode(wds.torch_audio, handler=wds.warn_and_continue),
89 | wds.to_tuple("mp3;wav;flac", "json", "__key__") if with_json else wds.to_tuple("mp3;wav;flac", "__key__"),
90 | partial(_seq_crop, crop_size=crop_size, drop_crops = drop_crops)
91 | ])
92 | if add_gain:
93 | pipeline.extend([_audio_gain])
94 | if crop_shuffle is not None:
95 | pipeline.append(wds.shuffle(crop_shuffle))
96 | if batch_size is not None:
97 | pipeline.append(wds.batched(batch_size))
98 | super().__init__(pipeline)
99 |
100 |
101 | # Can also be replaced with wds.RandomMix
102 | class SampleDatasets(wds.DataPipeline, wds.compat.FluidInterface):
103 |
104 | def __init__(self, datasets, probability: Optional[List[float]] = None):
105 | super().__init__()
106 | self.datasets = datasets
107 | if probability is None:
108 | probability = [1.0] * len(self.datasets)
109 | self.prob = probability
110 |
111 | def __iter__(self):
112 | sources = [iter(ds) for ds in self.datasets]
113 | while True:
114 | for source in random.choices(sources, weights=self.prob):
115 | try:
116 | yield next(source)
117 | except StopIteration:
118 | return
119 |
120 |
121 | def create_dataloader(data_urls: List[str],
122 | crop_size: int,
123 | batch_size: int = 32,
124 | crop_shuffle: Optional[int] = None,
125 | resampled: bool = False,
126 | num_workers: int = 4,
127 | *args,
128 | **kwargs):
129 | train_lists: List[str] = []
130 | for train_data_url in data_urls:
131 | train_lists.extend(braceexpand.braceexpand(train_data_url))
132 | ds = Audiowebdataset_Fluid(
133 | train_lists,
134 | crop_size=crop_size,
135 | resample=resampled,
136 | batch_size=batch_size,
137 | crop_shuffle=crop_shuffle,
138 | shuffle=crop_shuffle,
139 | )
140 | dataloader = wds.WebLoader(ds, batch_size=None, num_workers=num_workers)
141 | if crop_shuffle is not None:
142 | dataloader = dataloader.unbatched().shuffle(crop_shuffle).batched(
143 | batch_size)
144 | return dataloader
145 |
--------------------------------------------------------------------------------
/dasheng/train/config/dasheng_06B.yaml:
--------------------------------------------------------------------------------
1 | outputpath: experiments
2 | num_workers: 16
3 | batch_size: 256
4 | train_data:
5 | - acav100M/shards/acav_100M_split{1..1824}_0000000.tar
6 | - audioset/full/shards/full_train_16k_filenames_0000{000..190}.tar
7 | - vggsound/train/shards/train_dev_audio_00000{00..17}.tar
8 | - Jamendo/train/shards/audio_30s_16k_0000{000..055}.tar
9 | cv_data:
10 | - vggsound/test/shards/test_audio_00000{00..14}.tar
11 | chunk_length: 10.0
12 | epochs: 100
13 | epoch_length: 15000
14 | decay_frac: 0.1
15 | mask_ratio: 0.75
16 | warmup_epochs: 3
17 | warmup_iters: null
18 | model: dasheng_06B
19 | model_args:
20 | target_length: 1008 #frames
21 | group_masking: True
22 | optimizer: AdamW8bit
23 | optimizer_args:
24 | lr: 0.0003
25 | weight_decay: 0.01
26 |
--------------------------------------------------------------------------------
/dasheng/train/config/dasheng_12B.yaml:
--------------------------------------------------------------------------------
1 | outputpath: experiments
2 | num_workers: 16
3 | batch_size: 256
4 | train_data:
5 | - acav100M/shards/acav_100M_split{1..1824}_0000000.tar
6 | - audioset/full/shards/full_train_16k_filenames_0000{000..190}.tar
7 | - vggsound/train/shards/train_dev_audio_00000{00..17}.tar
8 | - Jamendo/train/shards/audio_30s_16k_0000{000..055}.tar
9 | cv_data:
10 | - vggsound/test/shards/test_audio_00000{00..14}.tar
11 | chunk_length: 10.0
12 | epochs: 100
13 | epoch_length: 15000
14 | decay_frac: 0.1
15 | mask_ratio: 0.75
16 | warmup_epochs: 3
17 | warmup_iters: null
18 | model: dasheng_12B
19 | model_args:
20 | target_length: 1008 #frames
21 | group_masking: True
22 | optimizer: AdamW8bit
23 | optimizer_args:
24 | lr: 0.0003
25 | weight_decay: 0.01
26 |
27 |
--------------------------------------------------------------------------------
/dasheng/train/config/dasheng_base.yaml:
--------------------------------------------------------------------------------
1 | outputpath: experiments
2 | num_workers: 16
3 | batch_size: 256
4 | train_data:
5 | - acav100M/shards/acav_100M_split{1..1824}_0000000.tar
6 | - audioset/full/shards/full_train_16k_filenames_0000{000..190}.tar
7 | - vggsound/train/shards/train_dev_audio_00000{00..17}.tar
8 | - Jamendo/train/shards/audio_30s_16k_0000{000..055}.tar
9 | cv_data:
10 | - vggsound/test/shards/test_audio_00000{00..14}.tar
11 | chunk_length: 10.0
12 | epochs: 100
13 | epoch_length: 15000
14 | decay_frac: 0.1
15 | mask_ratio: 0.75
16 | warmup_epochs: 3
17 | warmup_iters: null
18 | model: dasheng_base
19 | model_args:
20 | target_length: 1008 #frames
21 | group_masking: True
22 | optimizer: AdamW8bit
23 | optimizer_args:
24 | lr: 0.0003
25 | weight_decay: 0.01
26 |
27 |
--------------------------------------------------------------------------------
/dasheng/train/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 | from torch.amp import autocast
5 | from functools import partial
6 | from typing import Callable, Optional, Tuple, Union
7 | import torchaudio.transforms as audio_transforms
8 | from einops.layers.torch import Rearrange
9 | from itertools import repeat
10 | import collections
11 |
12 |
13 | def _ntuple(n):
14 |
15 | def parse(x) -> Tuple:
16 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
17 | return tuple(x)
18 | return tuple(repeat(x, n))
19 |
20 | return parse
21 |
22 |
23 | to_2tuple = _ntuple(2)
24 |
25 | class KwargsSequential(nn.Sequential):
26 |
27 | def forward(self, x, **kwargs):
28 | for module in self._modules.values():
29 | x = module(x, **kwargs)
30 | return x
31 |
32 |
33 |
34 | class MAELoss(torch.nn.Module):
35 |
36 | def __init__(self, norm_pix_loss: bool = True):
37 | super().__init__()
38 | self.norm_pix_loss = norm_pix_loss
39 |
40 | @autocast('cuda', enabled=False)
41 | def forward(self, pred: torch.Tensor, target: torch.Tensor,
42 | mask: torch.Tensor) -> torch.Tensor:
43 | if self.norm_pix_loss is True:
44 | mean = target.mean(dim=-1, keepdim=True)
45 | var = target.var(dim=-1, keepdim=True)
46 | target = (target - mean) / (var + 1.e-6)**.5
47 | elif self.norm_pix_loss == 'global':
48 | mean = target.mean()
49 | var = target.var()
50 | target = (target - mean) / (var + 1.e-6)**.5
51 | loss = (pred - target)**2
52 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch
53 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
54 | return loss
55 |
56 |
57 | class AudioPatchEmbed(nn.Module):
58 |
59 | def __init__(self,
60 | input_size: Union[int, Tuple[int, int]] = (64, 100),
61 | patch_size: Tuple[int, int] = (64, 4),
62 | patch_stride: Tuple[int, int] = (64, 4),
63 | in_chans=1,
64 | embed_dim=768,
65 | norm_layer=None,
66 | flatten=False):
67 | super().__init__()
68 | patch_size = to_2tuple(patch_size)
69 | patch_stride = to_2tuple(patch_stride)
70 | self.input_size: Tuple[int, int] = to_2tuple(input_size)
71 | self.patch_size: Tuple[int, int] = to_2tuple(patch_size)
72 | self.patch_stride: Tuple[int, int] = to_2tuple(patch_stride)
73 | self.grid_size = (self.input_size[0] // self.patch_stride[0],
74 | self.input_size[1] // self.patch_stride[1])
75 | self.num_patches = self.grid_size[0] * self.grid_size[1]
76 | self.flatten = flatten
77 |
78 | self.proj = nn.Conv2d(in_chans,
79 | embed_dim,
80 | kernel_size=patch_size,
81 | stride=patch_stride)
82 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
83 |
84 | def forward(self, x):
85 | x = self.proj(x)
86 | if self.flatten:
87 | x = rearrange(x, 'b c f t -> b (f t) c')
88 | x = self.norm(x)
89 | return x
90 |
91 |
92 | class LayerScale(nn.Module):
93 |
94 | def __init__(self, dim: int, init_values=1e-5, inplace=False):
95 | super().__init__()
96 | self.inplace = inplace
97 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
98 |
99 | def forward(self, x):
100 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
101 |
102 |
103 | class Attention(nn.Module):
104 |
105 | def __init__(self,
106 | dim,
107 | num_heads=8,
108 | qkv_bias=False,
109 | attn_drop=0.,
110 | proj_drop=0.):
111 | super().__init__()
112 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
113 | self.num_heads = num_heads
114 | head_dim = dim // num_heads
115 | self.scale = head_dim**-0.5
116 |
117 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
118 | self.attn_drop = nn.Dropout(attn_drop)
119 | self.proj = nn.Linear(dim, dim)
120 | self.proj_drop = nn.Dropout(proj_drop)
121 |
122 | def forward(self, x, mask:Optional[torch.Tensor] = None):
123 | B, N, C = x.shape
124 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
125 | C // self.num_heads).permute(2, 0, 3, 1, 4)
126 | q, k, v = qkv.unbind(
127 | 0) # make torchscript happy (cannot use tensor as tuple)
128 |
129 | attn = (q @ k.transpose(-2, -1)) * self.scale
130 | if mask is not None:
131 | # mask value as the lowest possible value in fp32
132 | mask_value = torch.finfo(attn.dtype).min
133 | # Mask is of shape [1, SRC_LEN]
134 | attn_mask = mask[:, None, None, :].expand(B, 1, N, N)
135 | #Mask should be of shape
136 | #[B,1,Target_len, Source_len]
137 | attn = attn.masked_fill(attn_mask, mask_value)
138 | attn = attn.softmax(dim=-1)
139 | attn = torch.nan_to_num(attn)
140 | attn = self.attn_drop(attn)
141 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
142 | x = self.proj(x)
143 | x = self.proj_drop(x)
144 | return x
145 |
146 |
147 | class Mlp(nn.Module):
148 |
149 | def __init__(self,
150 | in_features,
151 | hidden_features:Optional[int]=None,
152 | out_features:Optional[int]=None,
153 | act_layer:Callable=nn.GELU,
154 | drop=0.):
155 | super().__init__()
156 | out_features = out_features or in_features
157 | hidden_features = hidden_features or in_features
158 | self.fc1 = nn.Linear(in_features, hidden_features)
159 | self.act = act_layer()
160 | self.fc2 = nn.Linear(hidden_features, out_features)
161 | self.drop = nn.Dropout(drop)
162 |
163 | def forward(self, x):
164 | x = self.fc1(x)
165 | x = self.act(x)
166 | x = self.drop(x)
167 | x = self.fc2(x)
168 | x = self.drop(x)
169 | return x
170 |
171 |
172 | class Block(nn.Module):
173 |
174 | def __init__(
175 | self,
176 | dim,
177 | num_heads,
178 | mlp_ratio=4.,
179 | qkv_bias=False,
180 | drop=0.,
181 | attn_drop=0.,
182 | init_values=None,
183 | act_layer: Callable = nn.GELU,
184 | norm_layer: Callable = nn.LayerNorm,
185 | attention_type='Attention',
186 | ):
187 | super().__init__()
188 | self.norm1 = norm_layer(dim)
189 | attn_type = globals()[attention_type]
190 | self.attn = attn_type(dim,
191 | num_heads=num_heads,
192 | qkv_bias=qkv_bias,
193 | attn_drop=attn_drop,
194 | proj_drop=drop)
195 | self.ls1 = LayerScale(
196 | dim, init_values=init_values) if init_values else nn.Identity()
197 |
198 | self.norm2 = norm_layer(dim)
199 | self.mlp = Mlp(in_features=dim,
200 | hidden_features=int(dim * mlp_ratio),
201 | act_layer=act_layer,
202 | drop=drop)
203 | self.ls2 = LayerScale(
204 | dim, init_values=init_values) if init_values else nn.Identity()
205 |
206 | # kwargs is usually a mask
207 | def forward(self, x, **kwargs):
208 | x = x + self.ls1(self.attn(self.norm1(x), **kwargs))
209 | x = x + self.ls2(self.mlp(self.norm2(x)))
210 | return x
211 |
212 |
213 | class AudioTransformerMAE_Encoder(nn.Module):
214 |
215 | def __init__(self,
216 | patch_size: Tuple[int, int] = (64, 4),
217 | patch_stride: Tuple[int, int] = (64, 4),
218 | embed_dim: int = 768,
219 | depth: int = 12,
220 | num_heads=8,
221 | mlp_ratio=4.,
222 | qkv_bias=True,
223 | drop_rate=0.,
224 | attn_drop_rate=0.,
225 | norm_layer=None,
226 | act_layer=None,
227 | init_values=None,
228 | target_length=1008,
229 | pooling='mean',
230 | time_patch_out: Optional[float] = None,
231 | freq_patch_out: Optional[float] = None,
232 | block_type='Block',
233 | attention_type='Attention',
234 | eval_avg='cat',
235 | n_fft: int = 512,
236 | n_mels: int = 64,
237 | hop_size: int = 160,
238 | win_size: int = 512,
239 | f_min: int = 0,
240 | f_max: int = 8000,
241 | center: bool = True,
242 | **kwargs):
243 | super().__init__()
244 | self.pooling = pooling
245 | self.embed_dim = embed_dim
246 | self.patch_stride = patch_stride
247 | self.patch_size = patch_size
248 | self.hop_size = hop_size
249 | self.win_size = win_size
250 | self.n_mels = n_mels
251 | self.eval_avg = eval_avg
252 | self.time_patch_out = time_patch_out
253 | self.freq_patch_out = freq_patch_out
254 |
255 | self.front_end = nn.Sequential(
256 | audio_transforms.MelSpectrogram(f_min=f_min,
257 | sample_rate=16000,
258 | win_length=win_size,
259 | center=center,
260 | n_fft=n_fft,
261 | f_max=f_max,
262 | hop_length=hop_size,
263 | n_mels=self.n_mels),
264 | audio_transforms.AmplitudeToDB(top_db=kwargs.get('top_db', 120)))
265 |
266 | self.init_bn = nn.Sequential(
267 | Rearrange('b c f t -> b f c t'),
268 | nn.BatchNorm2d(self.n_mels, momentum=0.01),
269 | Rearrange('b f c t -> b c f t'))
270 |
271 | self.target_length = target_length
272 | self.patch_embed = AudioPatchEmbed(input_size=(self.n_mels,
273 | target_length),
274 | embed_dim=self.embed_dim,
275 | patch_size=self.patch_size,
276 | flatten=False,
277 | patch_stride=self.patch_stride)
278 | self.num_patches = self.patch_embed.num_patches
279 |
280 | if pooling == 'token':
281 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
282 | self.token_pos_embed = nn.Parameter(
283 | torch.randn(1, embed_dim) * .02)
284 |
285 | self.time_pos_embed = nn.Parameter(
286 | torch.randn(1, embed_dim, 1, self.patch_embed.grid_size[1]) * .02)
287 | self.freq_pos_embed = nn.Parameter(
288 | torch.randn(1, embed_dim, self.patch_embed.grid_size[0], 1) * .02)
289 |
290 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
291 | act_layer = act_layer or nn.GELU
292 | self.pos_drop = nn.Dropout(p=drop_rate)
293 | block_function = globals()[block_type]
294 | self.blocks = KwargsSequential(*[
295 | block_function(
296 | dim=embed_dim,
297 | num_heads=num_heads,
298 | mlp_ratio=mlp_ratio,
299 | qkv_bias=qkv_bias,
300 | init_values=init_values,
301 | drop=drop_rate,
302 | attn_drop=attn_drop_rate,
303 | norm_layer=norm_layer,
304 | act_layer=act_layer,
305 | attention_type=attention_type,
306 | ) for _ in range(depth)
307 | ])
308 | self.norm = norm_layer(embed_dim)
309 | self.apply(self.init_weights)
310 | if hasattr(self, 'cls_token') and self.cls_token is not None:
311 | nn.init.normal_(self.cls_token, std=1e-6)
312 | group_masking = kwargs.get('group_masking', False)
313 | if isinstance(group_masking, bool):
314 | if group_masking is True:
315 | self.masking_func = self.random_masking_group
316 | else:
317 | self.masking_func = self.random_masking
318 | elif isinstance(group_masking, int):
319 | self.masking_func = partial(self.random_masking_group,
320 | group_factor=group_masking)
321 |
322 | @torch.jit.ignore
323 | def no_weight_decay(self):
324 | return {
325 | 'time_pos_embed', 'cls_token', 'freq_pos_embed', 'token_pos_embed'
326 | }
327 |
328 | def init_weights(self, module):
329 | if isinstance(module, nn.Linear):
330 | torch.nn.init.xavier_uniform_(module.weight)
331 | if module.bias is not None:
332 | nn.init.zeros_(module.bias)
333 | elif isinstance(module, nn.LayerNorm):
334 | nn.init.constant_(module.bias, 0)
335 | nn.init.constant_(module.weight, 1.0)
336 |
337 | def random_masking_group(self, x, mask_ratio, group_factor: int = 2):
338 | """
339 | Perform per-sample random masking by per-sample shuffling.
340 | Per-sample shuffling is done by argsort random noise.
341 | x: [N, L, D], sequence
342 | """
343 | N, L, D = x.shape # batch, length, dim
344 | len_keep = int(L * (1 - mask_ratio))
345 |
346 | noise = torch.rand(N, L // group_factor,
347 | device=x.device) # noise in [0, 1]
348 | # indices = torch.arange(L).view(1, 5, 4).repeat(N, 1, 1)
349 | indices = torch.arange(L, device=x.device).view(-1, group_factor)
350 |
351 | # sort noise for each sample
352 | ids_shuffle = torch.argsort(
353 | noise, dim=1) # ascend: small is keep, large is remove
354 | ids_shuffle = indices[ids_shuffle].flatten(-2)
355 | ids_restore = torch.argsort(ids_shuffle, dim=1)
356 |
357 | # keep the first subset
358 | ids_keep = ids_shuffle[:, :len_keep]
359 | x_masked = torch.gather(x,
360 | dim=1,
361 | index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
362 |
363 | # generate the binary mask: 0 is keep, 1 is remove
364 | mask = torch.ones([N, L], device=x.device)
365 | mask[:, :len_keep] = 0
366 | # unshuffle to get the binary mask
367 | mask = torch.gather(mask, dim=1, index=ids_restore)
368 |
369 | return x_masked, mask, ids_restore
370 |
371 | def random_masking(self, x, mask_ratio):
372 | """
373 | Perform per-sample random masking by per-sample shuffling.
374 | Per-sample shuffling is done by argsort random noise.
375 | x: [N, L, D], sequence
376 | """
377 | N, L, D = x.shape # batch, length, dim
378 | len_keep = int(L * (1 - mask_ratio))
379 |
380 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
381 |
382 | # sort noise for each sample
383 | ids_shuffle = torch.argsort(
384 | noise, dim=1) # ascend: small is keep, large is remove
385 | ids_restore = torch.argsort(ids_shuffle, dim=1)
386 |
387 | # keep the first subset
388 | ids_keep = ids_shuffle[:, :len_keep]
389 | x_masked = torch.gather(x,
390 | dim=1,
391 | index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
392 |
393 | # generate the binary mask: 0 is keep, 1 is remove
394 | mask = torch.ones([N, L], device=x.device)
395 | mask[:, :len_keep] = 0
396 | # unshuffle to get the binary mask
397 | mask = torch.gather(mask, dim=1, index=ids_restore)
398 |
399 | return x_masked, mask, ids_restore
400 |
401 | def forward_features(self, x, mask_ratio):
402 | x = self.patch_embed(x)
403 | b, c, f, t = x.shape
404 | x = x + self.time_pos_embed[:, :, :, :t]
405 | x = x + self.freq_pos_embed[:, :, :, :] # Just for sin pos embed
406 | x = rearrange(x, 'b c f t -> b (f t) c')
407 | # x, mask, ids_restore = self.random_masking(x, mask_ratio)
408 | x, mask, ids_restore = self.masking_func(x, mask_ratio)
409 | if self.pooling == 'token':
410 | cls_token = self.cls_token.expand(x.shape[0], -1, -1)
411 | cls_token = cls_token + self.token_pos_embed[:, :]
412 | x = torch.cat((cls_token, x), dim=1)
413 | x = self.pos_drop(x)
414 | x = self.blocks(x)
415 | x = self.norm(x)
416 | return x, mask, ids_restore
417 |
418 | def load_state_dict(self, state_dict, strict=True, **kwargs):
419 | if 'time_pos_embed' in state_dict and self.time_pos_embed.shape != state_dict[
420 | 'time_pos_embed'].shape:
421 | print(
422 | "Positional Embedding shape not the same with model, resizing!"
423 | )
424 | self.change_pos_embedding(state_dict)
425 | super().load_state_dict(state_dict, strict=strict, **kwargs)
426 |
427 | def change_pos_embedding(self, state_dict):
428 | target_time_pos_embed_length = self.time_pos_embed.shape[-1]
429 | target_freq_pos_embed_length = self.freq_pos_embed.shape[-2]
430 |
431 | pretrained_time_pos_embed = state_dict['time_pos_embed']
432 | pretrained_freq_pos_embed = state_dict['freq_pos_embed']
433 |
434 | if target_time_pos_embed_length <= pretrained_time_pos_embed.shape[-1]:
435 | state_dict['time_pos_embed'] = pretrained_time_pos_embed[
436 | ..., :target_time_pos_embed_length]
437 | else:
438 | state_dict['time_pos_embed'] = torch.nn.functional.interpolate(
439 | pretrained_time_pos_embed,
440 | size=(1, target_time_pos_embed_length),
441 | align_corners=False,
442 | mode='bilinear')
443 | if target_freq_pos_embed_length <= pretrained_freq_pos_embed.shape[-2]:
444 | state_dict[
445 | 'freq_pos_embed'] = pretrained_freq_pos_embed[:, :, :
446 | target_freq_pos_embed_length, :]
447 | else:
448 | state_dict['freq_pos_embed'] = torch.nn.functional.interpolate(
449 | pretrained_freq_pos_embed,
450 | size=(target_freq_pos_embed_length, 1),
451 | align_corners=False,
452 | mode='bilinear')
453 |
454 | def forward_to_spec(self, x):
455 | # Do not use fp16 for feature extraction, that is likely to get nan
456 | with autocast('cuda', enabled=False):
457 | X = self.front_end(x)
458 | X = rearrange(X, 'b f t -> b 1 f t')
459 | X = self.init_bn(X)
460 | return X
461 |
462 | def forward(self, x, mask_ratio: float = 0.75):
463 | x = self.forward_to_spec(x)
464 | x, mask, restore_idxs = self.forward_features(x, mask_ratio=mask_ratio)
465 | return x, mask, restore_idxs
466 |
467 |
468 | class AudioTransformerMAE_Decoder(nn.Module):
469 |
470 | def __init__(self,
471 | input_dim: int,
472 | outputdim: int,
473 | patch_size: int = 16,
474 | patch_stride: int = 16,
475 | embed_dim: int = 768,
476 | num_patches: int = 100,
477 | depth: int = 12,
478 | num_heads: int = 12,
479 | mlp_ratio: float = 4.,
480 | qkv_bias: bool = True,
481 | drop_rate: float = 0.,
482 | attn_drop_rate: float = 0.,
483 | norm_layer: Optional[torch.nn.Module] = None,
484 | act_layer: Optional[torch.nn.Module] = None,
485 | cls_token: bool = False,
486 | attention_type='Attention',
487 | init_values=None,
488 | **kwargs):
489 | super().__init__()
490 | self.embed_dim = embed_dim
491 | self.patch_stride = patch_stride
492 | self.patch_size = patch_size
493 | self.input_dim = input_dim
494 |
495 | self.input_proj = nn.Linear(input_dim, embed_dim)
496 |
497 | self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim) * .02)
498 | _norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
499 | _act_layer = act_layer or nn.GELU
500 | self.use_cls = cls_token
501 |         num_patches_total = num_patches + 1  # always reserve a slot for the optional cls token
502 | self.pos_embed = nn.Parameter(
503 | torch.zeros(1, num_patches_total, embed_dim))
504 | self.pos_drop = nn.Dropout(p=drop_rate)
505 | self.blocks = nn.Sequential(*[
506 | Block(
507 | dim=embed_dim,
508 | num_heads=num_heads,
509 | mlp_ratio=mlp_ratio,
510 | qkv_bias=qkv_bias,
511 | init_values=init_values,
512 | drop=drop_rate,
513 | attn_drop=attn_drop_rate,
514 | norm_layer=_norm_layer,
515 | act_layer=_act_layer,
516 | attention_type=attention_type,
517 | ) for i in range(depth)
518 | ])
519 | self.norm = _norm_layer(embed_dim)
520 | self.outputlayer = nn.Linear(self.embed_dim, outputdim)
521 | self.apply(self.init_weights)
522 | torch.nn.init.normal_(self.mask_token, std=.02)
523 |
524 | @torch.jit.ignore
525 |     def no_weight_decay(self):
526 |         # Exclude the decoder's own positional embedding and mask token from weight
527 |         # decay; the names listed previously belong to the encoder and do not exist here.
528 |         return {'pos_embed', 'mask_token'}
529 |
530 | def init_weights(self, module):
531 | if isinstance(module, nn.Linear):
532 | nn.init.trunc_normal_(module.weight, std=.02)
533 | if module.bias is not None:
534 | nn.init.zeros_(module.bias)
535 | elif isinstance(module, nn.LayerNorm):
536 | nn.init.constant_(module.bias, 0)
537 | nn.init.constant_(module.weight, 1.0)
538 |
539 | def forward_features(self, x, ids_restore):
540 | x = self.input_proj(x)
541 | mask_tokens = self.mask_token.repeat(
542 | x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
543 | if self.use_cls:
544 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
545 | else:
546 |             x_ = torch.cat([x[:, :, :], mask_tokens], dim=1)  # the one surplus mask token is dropped by the gather below
547 | x_ = torch.gather(x_,
548 | dim=1,
549 | index=ids_restore.unsqueeze(-1).repeat(
550 | 1, 1, x.shape[2])) # unshuffle
551 | if self.use_cls:
552 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
553 | else:
554 | x = x_
555 | t = x.shape[1]
556 |
557 | x = x + self.pos_embed[:, :t, :]
558 | x = self.pos_drop(x)
559 | x = self.blocks(x)
560 | x = self.norm(x)
561 | return x
562 |
563 | def forward(self, x, restore_idxs):
564 | x = self.forward_features(x, restore_idxs)
565 | x = self.outputlayer(x)
566 | return x
567 |
568 |
569 | class AudioTransformerMAE(nn.Module):
570 |
571 | def __init__(self,
572 | encoder: AudioTransformerMAE_Encoder,
573 | decoder: AudioTransformerMAE_Decoder,
574 | loss_fn: Optional[torch.nn.Module] = None):
575 | super().__init__()
576 | self.encoder = encoder
577 | self.decoder = decoder
578 | self.unfold = nn.Unfold(
579 | kernel_size=self.encoder.patch_embed.patch_size,
580 | stride=self.encoder.patch_embed.patch_size)
581 | self.loss_fn = MAELoss() if loss_fn is None else loss_fn
582 |
583 | def forward(self,
584 | x: torch.Tensor,
585 | mask_ratio: float = 0.75,
586 | return_loss: bool = False):
587 | latent, mask, restore_ids = self.encoder(x, mask_ratio=mask_ratio)
588 | pred = self.decoder(latent, restore_ids)
589 | with autocast('cuda', enabled=False):
590 | targets = self.encoder.front_end(x)
591 | targets = self.patchify(targets)
592 | if return_loss:
593 | return self.loss_fn(pred, targets, mask)
594 | return pred, targets, mask
595 |
596 | def patchify(self, x):
597 | return self.unfold(x.unsqueeze(1)).transpose(-2, -1)
598 |
599 |
600 | def dasheng_base(**kwargs):
601 | encoder_kwargs = dict(embed_dim=768,
602 | depth=12,
603 | num_heads=12,
604 | target_length=1008,
605 | patch_size=[64, 4],
606 | patch_stride=[64, 4])
607 | encoder_kwargs.update(
608 | (k, kwargs[k]) for k in set(kwargs).intersection(encoder_kwargs))
609 | encoder_kwargs = {**encoder_kwargs, **kwargs}
610 | encoder = AudioTransformerMAE_Encoder(**encoder_kwargs)
611 |
612 | decoder_kwargs = dict(embed_dim=512,
613 | depth=8,
614 | num_heads=16,
615 | input_dim=encoder_kwargs['embed_dim'],
616 | outputdim=encoder.patch_embed.patch_size[0] *
617 | encoder.patch_embed.patch_size[1],
618 | num_patches=encoder.patch_embed.num_patches)
619 | decoder = AudioTransformerMAE_Decoder(**decoder_kwargs)
620 | return AudioTransformerMAE(encoder, decoder)
621 |
622 |
623 | def dasheng_06B(**kwargs):
624 | encoder_kwargs = dict(
625 | patch_size=[64, 4],
626 | patch_stride=[64, 4],
627 | embed_dim=1536,
628 | depth=24,
629 | num_heads=24,
630 | mlp_ratio=4,
631 | )
632 | encoder_kwargs.update(
633 | (k, kwargs[k]) for k in set(kwargs).intersection(encoder_kwargs))
634 | encoder_kwargs = {**encoder_kwargs, **kwargs}
635 | encoder = AudioTransformerMAE_Encoder(**encoder_kwargs)
636 |
637 | decoder_kwargs = dict(embed_dim=512,
638 | depth=8,
639 | num_heads=16,
640 | input_dim=encoder_kwargs['embed_dim'],
641 | outputdim=encoder.patch_embed.patch_size[0] *
642 | encoder.patch_embed.patch_size[1],
643 | num_patches=encoder.patch_embed.num_patches)
644 | decoder = AudioTransformerMAE_Decoder(**decoder_kwargs)
645 | return AudioTransformerMAE(encoder, decoder)
646 |
647 |
648 | def dasheng_12B(**kwargs):
649 | encoder_kwargs = dict(
650 | patch_size=[64, 4],
651 | patch_stride=[64, 4],
652 | embed_dim=1536,
653 | depth=40,
654 | num_heads=24,
655 | mlp_ratio=4,
656 | )
657 | encoder_kwargs.update(
658 | (k, kwargs[k]) for k in set(kwargs).intersection(encoder_kwargs))
659 | encoder_kwargs = {**encoder_kwargs, **kwargs}
660 | encoder = AudioTransformerMAE_Encoder(**encoder_kwargs)
661 |
662 | decoder_kwargs = dict(embed_dim=768,
663 | depth=8,
664 | num_heads=24,
665 | input_dim=encoder_kwargs['embed_dim'],
666 | outputdim=encoder.patch_embed.patch_size[0] *
667 | encoder.patch_embed.patch_size[1],
668 | num_patches=encoder.patch_embed.num_patches)
669 | decoder = AudioTransformerMAE_Decoder(**decoder_kwargs)
670 | return AudioTransformerMAE(encoder, decoder)
671 |
--------------------------------------------------------------------------------
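The three factory functions above differ only in encoder and decoder size; each returns an AudioTransformerMAE whose forward pass takes raw waveforms and yields either the (pred, targets, mask) triple or, with return_loss=True, the scalar reconstruction loss. Below is a minimal smoke-test sketch (not part of the repository), assuming it is run from dasheng/train/ so that "import models" resolves the same way as in train.py, and that 16 kHz mono waveforms are used as in the shipped configs:

import torch

import models  # run from dasheng/train/, mirroring train.py's import style

model = models.dasheng_base()
wav = torch.randn(2, 16000 * 10)  # two 10-second mono clips at 16 kHz

# The encoder masks 75% of the patches, the decoder predicts the spectrogram patches.
pred, targets, mask = model(wav, mask_ratio=0.75)
print(pred.shape, targets.shape, mask.shape)

# With return_loss=True the model returns the scalar loss that train.py backpropagates.
loss = model(wav, mask_ratio=0.75, return_loss=True)
loss.backward()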
/dasheng/train/train.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | from fire import Fire
3 | import numpy as np
4 | from audiowebdataset import create_dataloader
5 |
6 | import models
7 | import utils
8 | import torch
9 | import sys
10 | import ignite
11 | from ignite.contrib.handlers import ProgressBar
12 | from ignite.engine import (Engine, Events)
13 | from ignite.handlers import (Checkpoint, DiskSaver, global_step_from_engine,
14 | create_lr_scheduler_with_warmup)
15 | from accelerate import Accelerator
16 |
17 | logger.configure(handlers=[{
18 | "sink": sys.stderr,
19 | "format": "[{time:YYYY-MM-DD HH:mm:ss}] {message}",
20 | 'level': 'DEBUG',
21 | }])
22 |
23 |
24 | def transfer_to_device(batch, device):
25 | return (x.to(device, non_blocking=True)
26 | if isinstance(x, torch.Tensor) else x for x in batch)
27 |
28 |
29 | def create_engine(engine_function,
30 | local_rank: int = 0,
31 | output_transform=lambda x: x):
32 | engine = Engine(engine_function)
33 | if local_rank == 0:
34 | ProgressBar().attach(engine, output_transform=output_transform)
35 | return engine
36 |
37 |
38 | class Runner(object):
39 |
40 | def __init__(self, seed: int = 42, nthreads: int = 1):
41 | super().__init__()
42 | self.seed = seed
43 | torch.manual_seed(seed)
44 | np.random.seed(seed)
45 | torch.set_num_threads(nthreads)
46 | logger.info(f"Using seed {seed}")
47 |
48 | def __create_dir(self, config: utils.MAEConfig):
49 | config.outputdir.mkdir(exist_ok=True, parents=True)
50 | logger.add(
51 | config.outputdir / config.logfile,
52 | enqueue=True,
53 | level='INFO',
54 | format=
55 | "[{level} {time:YYYY-MM-DD HH:mm:ss}] {message}"
56 | )
57 |
58 | def log_basic_info(self, config_parameters: utils.MAEConfig, device):
59 | logger.info(f"Running on device {device}")
60 | logger.info(f"Storing output in {config_parameters.outputdir}")
61 | logger.info(f"- PyTorch version: {torch.__version__}")
62 | logger.info(f"- Ignite version: {ignite.__version__}")
63 | if torch.cuda.is_available():
64 | logger.info(f"- GPU Device: {torch.cuda.current_device()}")
65 | logger.info(f"- CUDA version: {torch.version.cuda}")
66 | for k, v in config_parameters.to_dict().items():
67 | logger.info(f"{k} : {v}")
68 |
69 | def train(self, config, **overwrite_kwargs):
70 | config_parameters = utils.MAEConfig.from_config_file(
71 | config, **overwrite_kwargs)
72 | accelerator = Accelerator()
73 |
74 | def log(message: str):
75 | if accelerator.is_main_process:
76 | logger.info(message)
77 |
78 | if accelerator.is_main_process:
79 | self.__create_dir(config_parameters)
80 | self.log_basic_info(config_parameters, device=accelerator.device)
81 | train_dataloader = create_dataloader(
82 | config_parameters.train_data,
83 | crop_size=int(config_parameters.chunk_length *
84 | config_parameters.sample_rate),
85 | batch_size=config_parameters.batch_size,
86 | crop_shuffle=config_parameters.crop_shuffle,
87 | resampled=True)
88 |
89 | test_dataloader = create_dataloader(
90 | config_parameters.cv_data,
91 | crop_size=int(config_parameters.chunk_length *
92 | config_parameters.sample_rate),
93 | batch_size=config_parameters.batch_size)
94 |
95 | model = getattr(
96 | models,
97 | config_parameters.model)(**config_parameters.model_args).train()
98 | log(model)
99 |
100 | if '8bit' in config_parameters.optimizer:
101 | import bitsandbytes as bnb
102 | optimizer = getattr(bnb.optim, config_parameters.optimizer)(
103 | model.parameters(),
104 | **config_parameters.optimizer_args) # add bnb optimizer
105 | else:
106 | optimizer = getattr(torch.optim, config_parameters.optimizer)(
107 | model.parameters(), **config_parameters.optimizer_args)
108 |
109 | def _inference(engine, batch):
110 | model.eval()
111 | with torch.no_grad():
112 | with accelerator.autocast():
113 | x, *_ = transfer_to_device(batch, accelerator.device)
114 | loss = model(x,
115 | mask_ratio=config_parameters.mask_ratio,
116 | return_loss=True)
117 | return loss
118 |
119 | def train_batch(engine, batch):
120 | model.train()
121 | with torch.enable_grad():
122 | x, *_ = transfer_to_device(batch, accelerator.device)
123 | optimizer.zero_grad()
124 | with accelerator.autocast():
125 | loss = model(x,
126 | mask_ratio=config_parameters.mask_ratio,
127 | return_loss=True)
128 | accelerator.backward(loss)
129 | optimizer.step()
130 | return {
131 | 'loss': loss.item(),
132 | 'lr': optimizer.param_groups[0]['lr']
133 | }
134 |
135 | def run_validation(engine, title=None):
136 | if accelerator.is_main_process:
137 | results = engine.state.metrics
138 | output_str_list = [
139 | f"{title:<10} Results - Epoch : {train_engine.state.epoch:<4}"
140 | ] + [
141 | f"{metric} {results[metric]:<5.4f}" for metric in results
142 | ] + [f"LR: {optimizer.param_groups[0]['lr']:.2e}"]
143 | log(" ".join(output_str_list))
144 |
145 | train_engine = create_engine(train_batch)
146 | inference_engine = create_engine(_inference, output_transform=None)
147 | ignite.metrics.Average().attach(inference_engine, 'Loss')
148 |
149 | score_function = Checkpoint.get_default_score_fn(
150 | *config_parameters.score_function)
151 | checkpoint_saver = Checkpoint(
152 | {
153 | 'model': model.encoder,
154 | 'config': config_parameters,
155 | },
156 | DiskSaver(config_parameters.outputdir),
157 | n_saved=config_parameters.n_saved,
158 | global_step_transform=global_step_from_engine(train_engine),
159 | filename_prefix='best',
160 | score_function=score_function)
161 | last_checkpoint_saver = Checkpoint(
162 | {
163 | 'model': model.encoder,
164 | 'config': config_parameters
165 | },
166 | DiskSaver(config_parameters.outputdir),
167 | n_saved=1,
168 | global_step_transform=global_step_from_engine(train_engine))
169 |
170 | train_length = config_parameters.epoch_length * config_parameters.epochs
171 | decay_steps = train_length
172 |
173 | if config_parameters.use_scheduler:
174 | scheduler = ignite.handlers.param_scheduler.CosineAnnealingScheduler(
175 | optimizer, 'lr', optimizer.param_groups[0]['lr'],
176 | optimizer.param_groups[0]['lr'] * config_parameters.decay_frac,
177 | decay_steps)
178 | warmup_time_in_iters = None
179 | if config_parameters.warmup_iters is not None:
180 | warmup_time_in_iters = config_parameters.warmup_iters
181 | elif config_parameters.warmup_epochs is not None:
182 | warmup_time_in_iters = config_parameters.epoch_length * config_parameters.warmup_epochs
183 | if warmup_time_in_iters is not None:
184 | log(f"Using warmup with {warmup_time_in_iters} iters")
185 | scheduler = create_lr_scheduler_with_warmup(
186 | scheduler,
187 | warmup_start_value=0.0,
188 | warmup_duration=warmup_time_in_iters)
189 |
190 | train_engine.add_event_handler(Events.ITERATION_STARTED, scheduler)
191 | inference_engine.add_event_handler(Events.COMPLETED, checkpoint_saver)
192 | inference_engine.add_event_handler(Events.COMPLETED,
193 | last_checkpoint_saver)
194 |
195 | @train_engine.on(
196 | Events.EPOCH_COMPLETED(every=config_parameters.valid_every))
197 | def valid_eval(train_engine):
198 | with inference_engine.add_event_handler(Events.COMPLETED,
199 | run_validation,
200 | "Validation"):
201 | inference_engine.run(test_dataloader)
202 |
203 | model, optimizer, train_dataloader, test_dataloader = accelerator.prepare(
204 | model, optimizer, train_dataloader, test_dataloader)
205 |
206 | train_engine.run(
207 | train_dataloader,
208 | max_epochs=config_parameters.epochs,
209 | epoch_length=config_parameters.epoch_length,
210 | )
211 | output_model = config_parameters.outputdir / checkpoint_saver.last_checkpoint
212 | if config_parameters.average_final_model:
213 | log("Averaging best models ...")
214 | output_model = config_parameters.outputdir / 'averaged.pt'
215 |
216 | averaged_state_dict = utils.average_models([
217 | config_parameters.outputdir / f.filename
218 | for f in checkpoint_saver._saved
219 | ])
220 | torch.save(averaged_state_dict, output_model)
221 |
222 |
223 | if __name__ == "__main__":
224 | Fire(Runner().train)
225 |
--------------------------------------------------------------------------------
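train.py exposes Runner.train through Fire, so the YAML config is passed as the first positional argument and any MAEConfig field can be overwritten with a flag of the same name. The following is a hedged sketch of the equivalent programmatic call, assuming the working directory is dasheng/train/ and that the train_data/cv_data globs inside the YAML point to existing webdataset tar shards:

from train import Runner  # run from dasheng/train/

runner = Runner(seed=42, nthreads=1)
# The first argument is the YAML config file; keyword arguments overwrite config values,
# just like `--batch_size 8 --epochs 1` would on the Fire command line.
runner.train('config/dasheng_base.yaml', batch_size=8, epochs=1, epoch_length=10)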
/dasheng/train/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from pathlib import Path
3 | import uuid
4 | from typing import Dict, List, Optional, Tuple, Any, Type
5 |
6 | import yaml
7 | import datetime
8 | import torch
9 | from dataclasses import dataclass, field, asdict
10 |
11 |
12 | @dataclass
13 | class MAEConfig:
14 | train_data: List[str]
15 | cv_data: List[str]
16 | config_file: str = '' # Will be overwritten during parsing
17 | logfile: str = 'train.log'
18 | outputpath: str = 'experiments'
19 |     # Train args
20 | mask_ratio: float = 0.75
21 | use_scheduler: bool = True
22 | warmup_iters: Optional[int] = None
23 | warmup_epochs: Optional[int] = None
24 | model: str = 'dasheng_base'
25 | model_args: Dict[str, Any] = field(default_factory=lambda: dict())
26 | decay_frac: float = 0.01 # Decay fraction of learning rate
27 |
28 | optimizer: str = 'AdamW8bit'
29 |     optimizer_args: Dict[str, Any] = field(
30 |         default_factory=lambda: dict(lr=0.0003, weight_decay=0.01))
31 | epochs: int = 100
32 | epoch_length: int = 15000
33 | # Dataloader args
34 | batch_size: int = 32
35 | n_saved: int = 4 # Num models saved
36 | num_workers: int = 4
37 | resampled: bool = True
38 | crop_shuffle: int = 512
39 | chunk_length: float = 10.0 # Sample length during training/testing
40 | sample_rate: int = 16000 # Sampling rate of audio
41 | valid_every: int = 1 # When to run validation
42 | score_function: Tuple[str, float] = ('Loss', -1.0) # Save best loss on CV
43 | average_final_model: bool = True
44 | outputdir: Path = field(init=False)
45 |
46 | def __post_init__(self):
47 | self.outputdir = Path(self.outputpath) / Path(
48 | self.config_file
49 | ).stem / self.model / f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')}_{uuid.uuid1().hex}"
50 |
51 | def to_dict(self):
52 | return asdict(self)
53 |
54 | def state_dict(self):
55 | return self.to_dict()
56 |
57 | @classmethod
58 | def load_state_dict(cls, state):
59 | return cls(**state)
60 |
61 | @classmethod
62 | def from_config_file(cls: Type[MAEConfig], config_file: str,
63 | **kwargs) -> MAEConfig:
64 |         """Parse a YAML config file into an MAEConfig.
65 |
66 |         :param config_file: Path to a YAML file containing the configuration parameters
67 |         :param **kwargs: Additional parameters that overwrite values from the config file
68 | """
69 | with open(config_file) as con_read:
70 | yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
71 | # values from config file are all possible params
72 | return cls(**dict(yaml_config, config_file=config_file, **kwargs))
73 |
74 |
75 | def average_models(models: List[str]):
76 | model_res_state_dict = {}
77 | state_dict = {}
78 | has_new_structure = False
79 | for m in models:
80 | cur_state = torch.load(m, map_location='cpu')
81 | if 'model' in cur_state:
82 | has_new_structure = True
83 | model_params = cur_state.pop('model')
84 | # Append non "model" items, encoder, optimizer etc ...
85 | for k in cur_state:
86 | state_dict[k] = cur_state[k]
87 | # Accumulate statistics
88 | for k in model_params:
89 | if k in model_res_state_dict:
90 | model_res_state_dict[k] += model_params[k]
91 | else:
92 | model_res_state_dict[k] = model_params[k]
93 | else:
94 | for k in cur_state:
95 | if k in model_res_state_dict:
96 | model_res_state_dict[k] += cur_state[k]
97 | else:
98 | model_res_state_dict[k] = cur_state[k]
99 |
100 | # Average
101 | for k in model_res_state_dict:
102 |         # Only average tensors with at least one dimension; zero-dim entries keep their accumulated value
103 | if model_res_state_dict[k].ndim > 0:
104 | model_res_state_dict[k] /= float(len(models))
105 | if has_new_structure:
106 | state_dict['model'] = model_res_state_dict
107 | else:
108 | state_dict = model_res_state_dict
109 | return state_dict
110 |
--------------------------------------------------------------------------------
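MAEConfig turns a YAML file plus optional keyword overwrites into a dataclass and derives a unique output directory in __post_init__. A small illustration follows (a sketch; the YAML written here is hypothetical and only uses fields MAEConfig actually defines, and the import assumes the working directory is dasheng/train/):

import yaml

from utils import MAEConfig  # run from dasheng/train/

# Hypothetical minimal config: the two required fields plus one value to read back.
with open('/tmp/mini.yaml', 'w') as f:
    yaml.safe_dump({'train_data': ['train-*.tar'],
                    'cv_data': ['cv-*.tar'],
                    'batch_size': 8}, f)

# Keyword arguments overwrite values from the file, as train.py does with **overwrite_kwargs.
cfg = MAEConfig.from_config_file('/tmp/mini.yaml', epochs=1)
print(cfg.batch_size, cfg.epochs)   # 8 1
print(cfg.outputdir)                # experiments/mini/dasheng_base/<timestamp>_<uuid>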
/metadata/hear_capabilities.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Dasheng/db7309358edbeea1b1cca37739f442c4139ac8e9/metadata/hear_capabilities.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = 'dasheng'
3 | version = '0.0.8'
4 | dependencies = [
5 | "einops",
6 | "numpy",
7 | "pytorch_ignite",
8 | "torch",
9 | "torchaudio",
10 | ]
11 | authors = [
12 | {name = "Heinrich Dinkel", email = "dinkelheinrich@xiaomi.com"},
13 | {name = "Junbo Zhang", email = "zhangjunbo1@xiaomi.com"},
14 | ]
15 | maintainers = [
16 | {name = "Heinrich Dinkel", email = "dinkelheinrich@xiaomi.com"},
17 | {name = "Junbo Zhang", email = "zhangjunbo1@xiaomi.com"},
18 | ]
19 | readme = "README.md"
20 | license = {file = "LICENSE"}
21 | requires-python = ">=3.9"
22 |
23 | classifiers = [
24 | "Development Status :: 3 - Alpha",
25 |
26 | # Indicate who your project is intended for
27 | "Intended Audience :: Developers",
28 | "Topic :: Software Development :: Build Tools",
29 |
30 | # Pick your license as you wish (see also "license" above)
31 | "License :: OSI Approved :: MIT License",
32 |
33 | # Specify the Python versions you support here.
34 | "Programming Language :: Python :: 3",
35 | "Programming Language :: Python :: 3.8",
36 | "Programming Language :: Python :: 3.9",
37 | "Programming Language :: Python :: 3.10",
38 | "Programming Language :: Python :: 3.11",
39 | "Environment :: GPU :: NVIDIA CUDA :: 11.4",
40 | "Environment :: GPU :: NVIDIA CUDA :: 12",
41 | "Topic :: Multimedia :: Sound/Audio :: Speech",
42 | ]
43 |
44 |
45 | [project.urls]
46 | Homepage = "https://github.com/Richermans/dasheng"
47 | Documentation = "https://github.com/Richermans/dasheng"
48 | Repository = "https://github.com/Richermans/dasheng"
49 | Issues = "https://github.com/Richermans/dasheng/issues"
50 |
51 | # Build backend; packaging metadata (authors, readme, requires-python) lives in the
52 | # PEP 621 [project] table above, and the package is built with setuptools.
53 | [build-system]
54 | requires = ["setuptools>=61.0"]
55 | build-backend = "setuptools.build_meta"
56 |
57 | [tool.setuptools.packages.find]
58 | where = ['.']
59 |
60 | [project.scripts]
61 | wavlist_to_tar = "dasheng.prepare.wavlist_to_tar:main"
62 |
63 | [project.optional-dependencies]
64 | train = [
65 | 'accelerate>=0.28.0',
66 | 'bitsandbytes>=0.35.4',
67 | 'webdataset>=0.2.86',
68 | 'braceexpand>=0.1.7',
69 | 'fire>=0.5.0',
70 | 'loguru>=0.7.2',
71 | 'numpy>=1.24.1',
72 | 'pytorch_ignite>=0.4.13',
73 | 'PyYAML>=6.0.1',
74 | 'torch>=2.1.1',
75 | 'torchaudio>=2.1.1',
76 | 'tqdm>=4.66.1',
77 | 'pandas>=2.0',
78 | ]
79 | all = [
80 | "dasheng[train]"
81 | ]
82 |
--------------------------------------------------------------------------------