├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── VERSION
├── clap.png
├── configs
│   ├── model
│   │   ├── musiclm_large.json
│   │   ├── musiclm_large_small_context.json
│   │   └── musiclm_small.json
│   └── training
│       ├── train_fma_preprocess.json
│       └── train_musiclm_fma.json
├── environment.yaml
├── musiclm.png
├── notebooks
│   └── analyze_fma.ipynb
├── open_musiclm
│   ├── __init__.py
│   ├── clap_quantized.py
│   ├── config.py
│   ├── data.py
│   ├── encodec_wrapper.py
│   ├── hf_hubert_kmeans.py
│   ├── laion_clap
│   │   ├── __init__.py
│   │   ├── clap_module
│   │   │   ├── __init__.py
│   │   │   ├── bert.py
│   │   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   │   ├── factory.py
│   │   │   ├── feature_fusion.py
│   │   │   ├── htsat.py
│   │   │   ├── linear_probe.py
│   │   │   ├── loss.py
│   │   │   ├── model.py
│   │   │   ├── model_configs
│   │   │   │   ├── HTSAT-base.json
│   │   │   │   ├── HTSAT-large.json
│   │   │   │   ├── HTSAT-tiny-win-1536.json
│   │   │   │   ├── HTSAT-tiny.json
│   │   │   │   ├── PANN-10.json
│   │   │   │   ├── PANN-14-fmax-18k.json
│   │   │   │   ├── PANN-14-fmax-8k-20s.json
│   │   │   │   ├── PANN-14-tiny-transformer.json
│   │   │   │   ├── PANN-14-win-1536.json
│   │   │   │   ├── PANN-14.json
│   │   │   │   ├── PANN-6.json
│   │   │   │   ├── RN101-quickgelu.json
│   │   │   │   ├── RN101.json
│   │   │   │   ├── RN50-quickgelu.json
│   │   │   │   ├── RN50.json
│   │   │   │   ├── RN50x16.json
│   │   │   │   ├── RN50x4.json
│   │   │   │   ├── ViT-B-16.json
│   │   │   │   ├── ViT-B-32-quickgelu.json
│   │   │   │   ├── ViT-B-32.json
│   │   │   │   └── ViT-L-14.json
│   │   │   ├── openai.py
│   │   │   ├── pann_model.py
│   │   │   ├── pretrained.py
│   │   │   ├── timm_model.py
│   │   │   ├── tokenizer.py
│   │   │   ├── transform.py
│   │   │   ├── utils.py
│   │   │   └── version.py
│   │   └── hook.py
│   ├── model_types.py
│   ├── open_musiclm.py
│   ├── optimizer.py
│   ├── preprocess.py
│   ├── trainer.py
│   ├── transformer.py
│   └── utils.py
├── scripts
│   ├── __init__.py
│   ├── download_checkpoints.sh
│   ├── download_fma_large.sh
│   ├── download_fma_metadata.sh
│   ├── infer.py
│   ├── infer_coarse.py
│   ├── infer_fine.py
│   ├── infer_top_match.py
│   ├── preprocess_data.py
│   ├── test
│   │   ├── __init__.py
│   │   ├── test_clap.py
│   │   ├── test_config.py
│   │   ├── test_dataloader.py
│   │   ├── test_encodec.py
│   │   ├── test_hubert_clustering.py
│   │   ├── test_load_preprocessed.py
│   │   └── test_rvq.py
│   ├── train_clap_rvq.py
│   ├── train_coarse_stage.py
│   ├── train_fine_stage.py
│   ├── train_hubert_kmeans.py
│   ├── train_semantic_stage.py
│   └── train_utils.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | data/
3 | checkpoints/
4 | results/
5 | test.wav
6 | logs/
7 | .vscode
8 | wandb
9 |
10 | **/630k-*.pt
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | pip-wheel-metadata/
35 | share/python-wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py,cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106 | __pypackages__/
107 |
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 |
142 | .DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Allen Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include open_musiclm/laion_clap/clap_module/model_configs *.json
2 | recursive-include open_musiclm/laion_clap/clap_module bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Open MusicLM
2 | Pytorch implementation of [MusicLM](https://arxiv.org/abs/2301.11325), a SOTA text to music model published by Google, with a few modifications. We use [CLAP](https://github.com/LAION-AI/CLAP) as a replacement for MuLan, [Encodec](https://github.com/facebookresearch/encodec) as a replacement for SoundStream, and [MERT](https://huggingface.co/m-a-p/MERT-v0) as a replacement for w2v-BERT.
3 |
4 |
5 |
6 |
7 |
8 |
9 | ## Why CLAP?
10 | CLAP is a joint audio-text model trained on [LAION-Audio-630K](https://github.com/LAION-AI/audio-dataset). Similar to MuLan, it consists of an audio tower and a text tower that project their respective media onto a shared latent space (512 dimensions in CLAP vs 128 dimensions in MuLan).
11 |
12 | MuLan was trained on 50 million text-music pairs. Unfortunately I don't have the data to replicate this, so I'm relying on CLAP's pretrained checkpoints to come close. CLAP was trained on 2.6 million total text-audio pairs from LAION-630k (~633k text-audio pairs) and AudioSet (2 million samples with captions generated by a keyword-to-caption model). Although this is a fraction of the data used to train MuLan, we have successfully used CLAP to generate diverse music samples, which you can listen to [here](https://drive.google.com/drive/folders/1pGY8EP2EZlE2pPpXn5E3YAkoCgATWw3_) (keep in mind these are very early results). In the event that CLAP's latent space is not expressive enough for music generation, we can train CLAP on music or swap the model out for @lucidrains' [MuLan implementation](https://github.com/lucidrains/musiclm-pytorch) once it is trained.
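
For a rough idea of what this looks like in code, here is a minimal sketch of embedding audio and text with the vendored CLAP module, using the same `CLAP_Module` calls as `open_musiclm/clap_quantized.py` (the checkpoint path is a placeholder and return types may vary slightly between checkpoints):

```python
# Minimal sketch: embed audio and text with CLAP and compare them in the shared latent space.
import torch
from open_musiclm.laion_clap import CLAP_Module

clap = CLAP_Module(enable_fusion=False, device='cpu', amodel='HTSAT-tiny')
clap.load_ckpt(ckpt='./checkpoints/clap.ckpt')  # placeholder path to a CLAP checkpoint

audio = [torch.randn(48000 * 10)]  # list of 1D waveforms (the public checkpoints expect 48kHz audio)
audio_embed = torch.as_tensor(clap.get_audio_embedding_from_data(audio))
text_embed = torch.as_tensor(clap.get_text_embedding(['a calm piano melody']))

# higher cosine similarity means the text is a better description of the audio
print(torch.nn.functional.cosine_similarity(audio_embed, text_embed))
```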
13 |
14 | ## Why Encodec?
15 | SoundStream and Encodec are both neural audio codecs that encode any waveform to a sequence of acoustic tokens, which can then be decoded into a waveform resembling the original. These intermediate tokens can then be modeled as a seq2seq task. [Encodec](https://github.com/facebookresearch/encodec) is released by Facebook and pretrained checkpoints are publicly available, whereas this is not the case with SoundStream.
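
Here is a minimal sketch of that round trip with the pretrained 24kHz Encodec model, using the same calls that `open_musiclm/encodec_wrapper.py` wraps (the shapes and bandwidth below are illustrative):

```python
# Rough sketch: waveform -> discrete acoustic codes -> reconstructed waveform.
import torch
from encodec import EncodecModel

encodec = EncodecModel.encodec_model_24khz()
encodec.set_target_bandwidth(6.0)  # 6 kbps -> 8 quantizers per frame for the 24kHz model

wave = torch.randn(1, 1, 24000 * 4)  # [batch, channels, samples]: 4 seconds of mono audio
with torch.no_grad():
    frames = encodec.encode(wave)                              # list of (codes, scale) frames
    codes = torch.cat([codes for codes, _ in frames], dim=-1)  # [B, n_q, T] token ids
    reconstruction = encodec.decode(frames)                    # waveform resembling the input

print(codes.shape, reconstruction.shape)
```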
16 |
17 | ## Differences from @lucidrains implementation
18 | - Autoregressively models the CLAP/MuLan conditioning signal by passing it into the transformers as discrete tokens, as mentioned in section 3.1 of the paper. Musiclm-pytorch conditions on them with cross attention.
19 | - TokenConditionedTransformer can support variable token sequences, which makes it easy to do further experimentation (e.g. combining multiple conditioning signals, stereo waveform generation, etc.)
20 | - Uses existing open source models instead of training MuLan and SoundStream.
21 | - Some modifications to increase the chance of successfully training the model.
22 |
23 | # End Goal
24 | The goal of this project is to replicate the results of MusicLM as quickly as possible without necessarily sticking to the architecture in the paper. For those looking for a more true-to-form implementation, check out [musiclm-pytorch](https://github.com/lucidrains/musiclm-pytorch).
25 |
26 | We also seek to gain a better understanding of CLAP's latent space.
27 |
28 | Join us on discord if you'd like to get involved! [Join the Discord](https://discord.gg/jN8jADShX5)
29 |
30 | # Usage
31 | ## Install
32 | ```shell
33 | conda env create -f environment.yaml
34 | conda activate open-musiclm
35 | ```
36 |
37 | ## Configs
38 | A "model config" contains information about the model architecture such as the number of layers, number of quantizers, target audio lengths for each stage, etc. It is used to instantiate the model during training and inference.
39 |
40 | A "training config" contains hyperparameters for training the model. It is used to instantiate the trainer classes during training.
41 |
42 | See the `./configs` directory for example configs.
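
For a quick look at what a model config contains, you can inspect the JSON directly (the keys in the comment below are those of `configs/model/musiclm_small.json`):

```python
# Peek at the structure of a model config.
import json

with open('./configs/model/musiclm_small.json') as f:
    model_cfg = json.load(f)

print(list(model_cfg.keys()))
# ['global_cfg', 'clap_rvq_cfg', 'hubert_kmeans_cfg', 'encodec_cfg', 'semantic_cfg', 'coarse_cfg', 'fine_cfg']
print(model_cfg['semantic_cfg']['dim'], model_cfg['semantic_cfg']['depth'])  # 1024 6
```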
43 |
44 | ## Training
45 | ### CLAP RVQ
46 | The first step is to train the residual vector quantizer that maps continuous CLAP embeds to a discrete token sequence.
47 | ```shell
48 | python ./scripts/train_clap_rvq.py \
49 | --results_folder ./results/clap_rvq \ # where to save results and checkpoints
50 | --model_config ./configs/model/musiclm_small.json \ # path to model config
51 | --training_config ./configs/training/train_musiclm_fma.json # path to training config
52 | ```
53 |
54 | ### Hubert K-means
55 | Next, we learn a K-means layer that we use to quantize our MERT embeddings into semantic tokens.
56 | ```shell
57 | python ./scripts/train_hubert_kmeans.py \
58 | --results_folder ./results/hubert_kmeans \ # where to save results and checkpoints
59 | --model_config ./configs/model/musiclm_small.json \
60 | --training_config ./configs/training/train_musiclm_fma.json
61 | ```
62 |
63 | ### Semantic Stage + Coarse Stage + Fine Stage
64 | Once we have a working K-means and RVQ, we can train the semantic, coarse and fine stages. These stages can be trained concurrently.
65 | ```shell
66 | python ./scripts/train_semantic_stage.py \
67 | --results_folder ./results/semantic \ # where to save results and checkpoints
68 | --model_config ./configs/model/musiclm_small.json \
69 | --training_config ./configs/training/train_musiclm_fma.json \
70 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq
71 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans
72 | ```
73 | ```shell
74 | python ./scripts/train_coarse_stage.py \
75 | --results_folder ./results/coarse \ # where to save results and checkpoints
76 | --model_config ./configs/model/musiclm_small.json \
77 | --training_config ./configs/training/train_musiclm_fma.json \
78 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq
79 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans
80 | ```
81 | ```shell
82 | python ./scripts/train_fine_stage.py \
83 | --results_folder ./results/fine \ # where to save results and checkpoints
84 | --model_config ./configs/model/musiclm_small.json \
85 | --training_config ./configs/training/train_musiclm_fma.json \
86 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq
87 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans
88 | ```
89 |
90 | ## Preprocessing
91 | In the above case, we are using CLAP, Hubert and Encodec to generate clap, semantic and acoustic tokens live during training. However, these models take up space on the GPU, and it is inefficient to recompute these tokens if we're making multiple runs on the same data. We can instead compute these tokens ahead of time and iterate over them during training.
92 |
93 | To do this, fill in the `data_preprocessor_cfg` field in the config and set `use_preprocessed_data` to True in the trainer configs (look at train_fma_preprocess.json for inspiration). Then run the following to preprocess the dataset, followed by your training script.
94 |
95 | ```shell
96 | python ./scripts/preprocess_data.py \
97 | --model_config ./configs/model/musiclm_small.json \
98 | --training_config ./configs/training/train_fma_preprocess.json \
99 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq
100 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans
101 | ```
102 |
103 | ## Inference
104 | Generate multiple samples and use CLAP to select the best ones:
105 | ```shell
106 | python scripts/infer_top_match.py \
107 | "your text prompt"
108 | --num_samples 4 # number of samples to generate
109 | --num_top_matches 1 # number of top matches to return
110 | --semantic_path PATH_TO_SEMANTIC_CHECKPOINT \ # path to previously trained semantic stage
111 | --coarse_path PATH_TO_COARSE_CHECKPOINT \ # path to previously trained coarse stage
112 | --fine_path PATH_TO_FINE_CHECKPOINT \ # path to previously trained fine stage
113 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq
114 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans
115 | --model_config ./configs/model/musiclm_small.json \
116 | --duration 4
117 | ```
118 |
119 | Generate samples for various test prompts:
120 | ```shell
121 | python scripts/infer.py \
122 | --semantic_path PATH_TO_SEMANTIC_CHECKPOINT \ # path to previously trained semantic stage
123 | --coarse_path PATH_TO_COARSE_CHECKPOINT \ # path to previously trained coarse stage
124 | --fine_path PATH_TO_FINE_CHECKPOINT \ # path to previously trained fine stage
125 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq
126 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans
127 | --model_config ./configs/model/musiclm_small.json \
128 | --duration 4
129 | ```
130 |
131 | You can use the `--return_coarse_wave` flag to skip the fine stage and reconstruct audio from coarse tokens alone.
132 |
133 | ## Checkpoints
134 | You can download experimental checkpoints for the musiclm_large_small_context model [here](https://drive.google.com/drive/u/0/folders/1347glwEc-6XWulfU7NGrFrYTvTnjeVJE). To fine-tune the model, call the train scripts with the `--fine_tune_from` flag.
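
For example, to fine-tune the semantic stage from one of these checkpoints, pass the flag alongside the usual training arguments (the checkpoint paths below are placeholders):

```shell
python ./scripts/train_semantic_stage.py \
    --results_folder ./results/semantic_finetune \
    --model_config ./configs/model/musiclm_large_small_context.json \
    --training_config ./configs/training/train_musiclm_fma.json \
    --rvq_path PATH_TO_RVQ_CHECKPOINT \
    --kmeans_path PATH_TO_KMEANS_CHECKPOINT \
    --fine_tune_from PATH_TO_SEMANTIC_CHECKPOINT # checkpoint downloaded above
```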
135 |
136 | # Thank you
137 | * [Okio](https://okio.ai/) for providing compute to train the model! Okio is a startup that is developing Nendo - an open source generative music tool-suite
138 | that re-imagines music. If you're interested, check them out at [okio.ai](https://okio.ai/)
139 | * [@lucidrains](https://github.com/lucidrains/) for the [audiolm-pytorch](https://github.com/lucidrains/audiolm-pytorch) implementation. This repo contains a refactored version of a lot of the code in [audiolm-pytorch](https://github.com/lucidrains/audiolm-pytorch).
140 | * [LAION](https://laion.ai/) for [CLAP](https://github.com/LAION-AI/CLAP)
141 | * [Music Audio Pretrain team](https://huggingface.co/m-a-p) for [MERT](https://huggingface.co/m-a-p/MERT-v0)
142 |
143 | # Citations
144 | ```bibtex
145 | @inproceedings{Agostinelli2023MusicLMGM,
146 | title = {MusicLM: Generating Music From Text},
147 | author = {Andrea Agostinelli and Timo I. Denk and Zal{\'a}n Borsos and Jesse Engel and Mauro Verzetti and Antoine Caillon and Qingqing Huang and Aren Jansen and Adam Roberts and Marco Tagliasacchi and Matthew Sharifi and Neil Zeghidour and C. Frank},
148 | year = {2023}
149 | }
150 | ```
151 | ```bibtex
152 | @article{wu2022large,
153 | title = {Large-scale Contrastive Language-Audio Pretraining with Feature Fusion and Keyword-to-Caption Augmentation},
154 | author = {Wu, Yusong and Chen, Ke and Zhang, Tianyu and Hui, Yuchen and Berg-Kirkpatrick, Taylor and Dubnov, Shlomo},
155 | journal = {arXiv preprint arXiv:2211.06687},
156 | year = {2022},
157 | }
158 | ```
159 | ```bibtex
160 | @article{defossez2022highfi,
161 | title = {High Fidelity Neural Audio Compression},
162 | author = {Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi},
163 | journal = {arXiv preprint arXiv:2210.13438},
164 | year = {2022}
165 | }
166 | ```
167 | ```bibtex
168 | @misc{li2023mert,
169 | title = {MERT: Acoustic Music Understanding Model with Large-Scale Self-supervised Training},
170 | author = {Yizhi Li and Ruibin Yuan and Ge Zhang and Yinghao Ma and Xingran Chen and Hanzhi Yin and Chenghua Lin and Anton Ragni and Emmanouil Benetos and Norbert Gyenge and Roger Dannenberg and Ruibo Liu and Wenhu Chen and Gus Xia and Yemin Shi and Wenhao Huang and Yike Guo and Jie Fu},
171 | year = {2023},
172 | eprint = {2306.00107},
173 | archivePrefix = {arXiv},
174 | primaryClass = {cs.SD}
175 | }
176 | ```
177 |
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 0.2.5
--------------------------------------------------------------------------------
/clap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/clap.png
--------------------------------------------------------------------------------
/configs/model/musiclm_large.json:
--------------------------------------------------------------------------------
1 | {
2 | "global_cfg": {
3 | "semantic_audio_length_seconds": 30.0,
4 | "coarse_audio_length_seconds": 10.0,
5 | "fine_audio_length_seconds": 3.0,
6 | "clap_audio_length_seconds": 30.0,
7 | "num_coarse_quantizers": 3,
8 | "num_fine_quantizers": 5
9 | },
10 | "clap_rvq_cfg": {
11 | "enable_fusion": true,
12 | "rq_num_quantizers": 12,
13 | "codebook_size": 1024,
14 | "rq_ema_decay": 0.95,
15 | "threshold_ema_dead_code": 0.5
16 | },
17 | "hubert_kmeans_cfg": {
18 | "model_name": "m-a-p/MERT-v0",
19 | "normalize_embeds": true,
20 | "embed_layer": 7,
21 | "target_sample_hz": 16000,
22 | "seq_len_multiple_of": 320,
23 | "codebook_size": 1024,
24 | "output_hz": 50
25 | },
26 | "encodec_cfg": {
27 | "bandwidth": 6.0,
28 | "codebook_size": 1024,
29 | "output_hz": 75
30 | },
31 | "semantic_cfg": {
32 | "dim": 1024,
33 | "depth": 24,
34 | "heads": 16,
35 | "attn_dropout": 0.0,
36 | "ff_dropout": 0.1,
37 | "grad_shrink_alpha": 0.1,
38 | "non_causal_prefix_size": 0,
39 | "relative_position_bias_type": "continuous",
40 | "use_memory_efficient_attention": false
41 | },
42 | "coarse_cfg": {
43 | "dim": 1024,
44 | "depth": 24,
45 | "heads": 16,
46 | "attn_dropout": 0.0,
47 | "ff_dropout": 0.1,
48 | "grad_shrink_alpha": 0.1,
49 | "non_causal_prefix_size": 0,
50 | "relative_position_bias_type": "continuous",
51 | "use_memory_efficient_attention": false
52 | },
53 | "fine_cfg": {
54 | "dim": 1024,
55 | "depth": 24,
56 | "heads": 16,
57 | "attn_dropout": 0.0,
58 | "ff_dropout": 0.1,
59 | "grad_shrink_alpha": 0.1,
60 | "non_causal_prefix_size": 0,
61 | "relative_position_bias_type": "continuous",
62 | "use_memory_efficient_attention": false
63 | }
64 | }
--------------------------------------------------------------------------------
/configs/model/musiclm_large_small_context.json:
--------------------------------------------------------------------------------
1 | {
2 | "global_cfg": {
3 | "semantic_audio_length_seconds": 10.0,
4 | "coarse_audio_length_seconds": 4.0,
5 | "fine_audio_length_seconds": 2.0,
6 | "clap_audio_length_seconds": 10.0,
7 | "num_coarse_quantizers": 3,
8 | "num_fine_quantizers": 5
9 | },
10 | "clap_rvq_cfg": {
11 | "enable_fusion": false,
12 | "rq_num_quantizers": 12,
13 | "codebook_size": 1024,
14 | "rq_ema_decay": 0.95,
15 | "threshold_ema_dead_code": 0.5
16 | },
17 | "hubert_kmeans_cfg": {
18 | "model_name": "m-a-p/MERT-v0",
19 | "normalize_embeds": true,
20 | "embed_layer": 7,
21 | "target_sample_hz": 16000,
22 | "seq_len_multiple_of": 320,
23 | "codebook_size": 1024,
24 | "output_hz": 50
25 | },
26 | "encodec_cfg": {
27 | "bandwidth": 6.0,
28 | "codebook_size": 1024,
29 | "output_hz": 75
30 | },
31 | "semantic_cfg": {
32 | "dim": 1024,
33 | "depth": 24,
34 | "heads": 16,
35 | "attn_dropout": 0.0,
36 | "ff_dropout": 0.1,
37 | "grad_shrink_alpha": 0.1,
38 | "non_causal_prefix_size": 0,
39 | "relative_position_bias_type": "continuous",
40 | "use_memory_efficient_attention": false
41 | },
42 | "coarse_cfg": {
43 | "dim": 1024,
44 | "depth": 24,
45 | "heads": 16,
46 | "attn_dropout": 0.0,
47 | "ff_dropout": 0.1,
48 | "grad_shrink_alpha": 0.1,
49 | "non_causal_prefix_size": 0,
50 | "relative_position_bias_type": "continuous",
51 | "use_memory_efficient_attention": false
52 | },
53 | "fine_cfg": {
54 | "dim": 1024,
55 | "depth": 24,
56 | "heads": 16,
57 | "attn_dropout": 0.0,
58 | "ff_dropout": 0.1,
59 | "grad_shrink_alpha": 0.1,
60 | "non_causal_prefix_size": 0,
61 | "relative_position_bias_type": "continuous",
62 | "use_memory_efficient_attention": false
63 | }
64 | }
--------------------------------------------------------------------------------
/configs/model/musiclm_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "global_cfg": {
3 | "semantic_audio_length_seconds": 10.0,
4 | "coarse_audio_length_seconds": 4.0,
5 | "fine_audio_length_seconds": 2.0,
6 | "clap_audio_length_seconds": 10.0,
7 | "num_coarse_quantizers": 3,
8 | "num_fine_quantizers": 5
9 | },
10 | "clap_rvq_cfg": {
11 | "enable_fusion": false,
12 | "rq_num_quantizers": 12,
13 | "codebook_size": 1024,
14 | "rq_ema_decay": 0.95,
15 | "threshold_ema_dead_code": 0.5
16 | },
17 | "hubert_kmeans_cfg": {
18 | "model_name": "m-a-p/MERT-v0",
19 | "normalize_embeds": true,
20 | "embed_layer": 7,
21 | "target_sample_hz": 16000,
22 | "seq_len_multiple_of": 320,
23 | "codebook_size": 1024,
24 | "output_hz": 50
25 | },
26 | "encodec_cfg": {
27 | "bandwidth": 6.0,
28 | "codebook_size": 1024,
29 | "output_hz": 75
30 | },
31 | "semantic_cfg": {
32 | "dim": 1024,
33 | "depth": 6,
34 | "heads": 8,
35 | "attn_dropout": 0.0,
36 | "ff_dropout": 0.1,
37 | "grad_shrink_alpha": 0.1,
38 | "non_causal_prefix_size": 0,
39 | "relative_position_bias_type": "continuous",
40 | "use_memory_efficient_attention": false
41 | },
42 | "coarse_cfg": {
43 | "dim": 1024,
44 | "depth": 6,
45 | "heads": 8,
46 | "attn_dropout": 0.0,
47 | "ff_dropout": 0.1,
48 | "grad_shrink_alpha": 0.1,
49 | "non_causal_prefix_size": 0,
50 | "relative_position_bias_type": "continuous",
51 | "use_memory_efficient_attention": false
52 | },
53 | "fine_cfg": {
54 | "dim": 1024,
55 | "depth": 6,
56 | "heads": 8,
57 | "attn_dropout": 0.0,
58 | "ff_dropout": 0.1,
59 | "grad_shrink_alpha": 0.1,
60 | "non_causal_prefix_size": 0,
61 | "relative_position_bias_type": "continuous",
62 | "use_memory_efficient_attention": false
63 | }
64 | }
--------------------------------------------------------------------------------
/configs/training/train_fma_preprocess.json:
--------------------------------------------------------------------------------
1 | {
2 | "clap_rvq_trainer_cfg": {
3 | "folder": "./data/fma_large",
4 | "num_train_steps": 1000,
5 | "batch_size": 64,
6 | "accumulate_batches": 32,
7 | "save_model_every": 10,
8 | "save_results_every": 5
9 | },
10 | "hubert_kmeans_trainer_cfg": {
11 | "folder": "./data/fma_large",
12 | "feature_extraction_num_steps": 320,
13 | "feature_extraction_batch_size": 32
14 | },
15 | "semantic_trainer_cfg": {
16 | "stage": "semantic",
17 | "folder": "./data/fma_preprocessed",
18 | "valid_frac": 0.05,
19 | "lr": 0.0003,
20 | "lr_warmup": 3000,
21 | "batch_size": 4,
22 | "grad_accum_every": 8,
23 | "wd": 0.01,
24 | "max_grad_norm": 0.5,
25 | "cross_entropy_loss_weights": [0.0, 1.0],
26 | "num_train_steps": 200001,
27 | "save_results_every": 250,
28 | "save_model_every": 1000,
29 | "save_predicted_tokens": true,
30 | "save_reconstructed_wave": true,
31 | "use_preprocessed_data": true
32 | },
33 | "coarse_trainer_cfg": {
34 | "stage": "coarse",
35 | "folder": "./data/fma_preprocessed",
36 | "valid_frac": 0.05,
37 | "lr": 0.0003,
38 | "lr_warmup": 6000,
39 | "batch_size": 2,
40 | "grad_accum_every": 8,
41 | "wd": 0.01,
42 | "max_grad_norm": 0.5,
43 | "cross_entropy_loss_weights": [0.0, 0.0, 1.0],
44 | "num_train_steps": 200001,
45 | "save_results_every": 250,
46 | "save_model_every": 1000,
47 | "save_predicted_tokens": true,
48 | "save_reconstructed_wave": true,
49 | "use_preprocessed_data": true
50 | },
51 | "fine_trainer_cfg": {
52 | "stage": "fine",
53 | "folder": "./data/fma_preprocessed",
54 | "valid_frac": 0.05,
55 | "lr": 0.0003,
56 | "lr_warmup": 0,
57 | "batch_size": 2,
58 | "grad_accum_every": 8,
59 | "wd": 0.01,
60 | "max_grad_norm": 0.5,
61 | "cross_entropy_loss_weights": [0.0, 0.0, 1.0],
62 | "num_train_steps": 200001,
63 | "save_results_every": 250,
64 | "save_model_every": 1000,
65 | "save_predicted_tokens": true,
66 | "save_reconstructed_wave": true,
67 | "use_preprocessed_data": true
68 | },
69 | "data_preprocessor_cfg": {
70 | "folder": "./data/fma_large",
71 | "metadata_folder": "./data/fma_metadata",
72 | "results_folder": "./data/fma_preprocessed",
73 | "max_audio_length_seconds": 30,
74 | "random_crop": true,
75 | "num_crops": 1,
76 | "clap_batch_size": 32
77 | }
78 | }
--------------------------------------------------------------------------------
/configs/training/train_musiclm_fma.json:
--------------------------------------------------------------------------------
1 | {
2 | "clap_rvq_trainer_cfg": {
3 | "folder": "./data/fma_large",
4 | "num_train_steps": 1000,
5 | "batch_size": 64,
6 | "accumulate_batches": 32,
7 | "save_model_every": 10,
8 | "save_results_every": 5
9 | },
10 | "hubert_kmeans_trainer_cfg": {
11 | "folder": "./data/fma_large",
12 | "feature_extraction_num_steps": 320,
13 | "feature_extraction_batch_size": 32
14 | },
15 | "semantic_trainer_cfg": {
16 | "stage": "semantic",
17 | "folder": "./data/fma_large",
18 | "valid_frac": 0.05,
19 | "lr": 0.0003,
20 | "lr_warmup": 3000,
21 | "batch_size": 4,
22 | "grad_accum_every": 8,
23 | "wd": 0.01,
24 | "max_grad_norm": 0.5,
25 | "cross_entropy_loss_weights": [0.0, 1.0],
26 | "num_train_steps": 200001,
27 | "save_results_every": 250,
28 | "save_model_every": 1000,
29 | "save_predicted_tokens": true,
30 | "save_reconstructed_wave": true,
31 | "use_preprocessed_data": false
32 | },
33 | "coarse_trainer_cfg": {
34 | "stage": "coarse",
35 | "folder": "./data/fma_large",
36 | "valid_frac": 0.05,
37 | "lr": 0.0003,
38 | "lr_warmup": 6000,
39 | "batch_size": 2,
40 | "grad_accum_every": 8,
41 | "wd": 0.01,
42 | "max_grad_norm": 0.5,
43 | "cross_entropy_loss_weights": [0.0, 0.0, 1.0],
44 | "num_train_steps": 200001,
45 | "save_results_every": 250,
46 | "save_model_every": 1000,
47 | "save_predicted_tokens": true,
48 | "save_reconstructed_wave": true,
49 | "use_preprocessed_data": false
50 | },
51 | "fine_trainer_cfg": {
52 | "stage": "fine",
53 | "folder": "./data/fma_large",
54 | "valid_frac": 0.05,
55 | "lr": 0.0003,
56 | "lr_warmup": 0,
57 | "batch_size": 2,
58 | "grad_accum_every": 8,
59 | "wd": 0.01,
60 | "max_grad_norm": 0.5,
61 | "cross_entropy_loss_weights": [0.0, 0.0, 1.0],
62 | "num_train_steps": 200001,
63 | "save_results_every": 250,
64 | "save_model_every": 1000,
65 | "save_predicted_tokens": true,
66 | "save_reconstructed_wave": true,
67 | "use_preprocessed_data": false
68 | },
69 | "data_preprocessor_cfg": {}
70 | }
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: open-musiclm
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - python=3.10
8 | - pip=22.3.1
9 | - pip:
10 | - --find-links https://download.pytorch.org/whl/torch_stable.html
11 | - torch==2.0.0+cu117
12 | - torchvision==0.15.1+cu117
13 | - torchaudio==2.0.1+cu117
14 | - einops>=0.4
15 | - vector-quantize-pytorch>=0.10.15
16 | - librosa==0.10.0
17 | - torchlibrosa==0.1.0
18 | - ftfy
19 | - tqdm
20 | - transformers
21 | - encodec==0.1.1
22 | - gdown
23 | - accelerate>=0.17.0
24 | - beartype
25 | - joblib
26 | - h5py
27 | - scikit-learn
28 | - wget
29 |
--------------------------------------------------------------------------------
/musiclm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/musiclm.png
--------------------------------------------------------------------------------
/open_musiclm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/open_musiclm/__init__.py
--------------------------------------------------------------------------------
/open_musiclm/clap_quantized.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import torchaudio
5 | import torchvision.transforms
6 | from beartype.typing import Dict, List, Optional, Union
7 | from einops import rearrange
8 | from torch import nn
9 | from transformers import RobertaTokenizer
10 | from vector_quantize_pytorch import ResidualVQ
11 |
12 | from .laion_clap import CLAP_Module
13 | from .utils import exists, beartype_jit
14 |
15 |
16 | @beartype_jit
17 | class ClapQuantized(nn.Module):
18 | def __init__(self,
19 | *,
20 | clap: CLAP_Module,
21 | codebook_size: int = 1024,
22 | rq_num_quantizers: int = 12,
23 | rq_ema_decay: float = 0.95,
24 | learn_rvq: bool = False,
25 | threshold_ema_dead_code: float = 0.0,
26 | ):
27 | super().__init__()
28 |
29 | self.clap = clap
30 | self.codebook_size = codebook_size
31 | self.learn_rvq = learn_rvq
32 |
33 | self.sample_rate = self.clap.model_cfg['audio_cfg']['sample_rate']
34 |
35 | for param in self.clap.parameters():
36 | param.requires_grad = False
37 |
38 | self.rq = ResidualVQ(
39 | dim=clap.model.joint_embed_shape,
40 | num_quantizers=rq_num_quantizers, # specify number of quantizers
41 | codebook_size=codebook_size, # codebook size
42 | commitment_weight=0, # embeddings are frozen so no need for commitment loss
43 | decay=rq_ema_decay,
44 | kmeans_init=True,
45 | threshold_ema_dead_code=threshold_ema_dead_code,
46 | )
47 |
48 | def forward(self,
49 | *,
50 | audio_input: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
51 | text_input: Optional[List[str]] = None,
52 | return_embedding: Optional[bool] = False,
53 | return_rvq_loss = False,
54 | ):
55 | """
56 | Wrapper for clap module that takes in audio or text and returns the quantized embedding from the respective tower
57 | """
58 |
59 | assert exists(audio_input) ^ exists(text_input), "either audio or text must be provided, but not both"
60 | if exists(audio_input):
61 | assert all(wave.dim() == 1 for wave in audio_input), f"audio_input must be a list of 1D tensors, but got {audio_input[0].shape}"
62 |
63 | with torch.no_grad():
64 | self.clap.eval()
65 | if exists(audio_input):
66 | embedding = self.clap.get_audio_embedding_from_data(audio_input)
67 | else:
68 | embedding = self.clap.get_text_embedding(text_input)
69 |
70 | if return_embedding:
71 | return embedding
72 |
73 | return self.quantize(embedding, return_rvq_loss=return_rvq_loss)
74 |
75 | def quantize(self, embedding, return_rvq_loss=False):
76 | """
77 | Quantize an embedding and optionally return the loss
78 | """
79 | with torch.set_grad_enabled(self.learn_rvq):
80 | self.rq.train(self.learn_rvq)
81 | q, indices, _ = self.rq(rearrange(embedding, 'n c -> n 1 c'))
82 |
83 | if return_rvq_loss:
84 | return F.mse_loss(q, rearrange(embedding, 'n c -> n 1 c')).item()
85 |
86 | indices = rearrange(indices, 'n 1 c -> n c 1')
87 | return indices
88 |
89 |
90 | def create_clap_quantized(
91 | device=None,
92 | learn_rvq=False,
93 | enable_fusion=False,
94 | rvq_checkpoint_path=None,
95 | checkpoint_path: Optional[str] = None,
96 | amodel_type: str = 'HTSAT-tiny',
97 | **kwargs
98 | ):
99 | if device is None:
100 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
101 |
102 | clap = CLAP_Module(enable_fusion=enable_fusion, device=device, amodel=amodel_type)
103 | clap.load_ckpt(ckpt=checkpoint_path)
104 |
105 | clap_quantized = ClapQuantized(clap=clap, learn_rvq=learn_rvq, **kwargs)
106 |
107 | if exists(rvq_checkpoint_path):
108 | rvq = torch.load(rvq_checkpoint_path, map_location=device)
109 | clap_quantized.rq.load_state_dict(rvq)
110 |
111 | return clap_quantized
112 |
--------------------------------------------------------------------------------
/open_musiclm/encodec_wrapper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 | from encodec import EncodecModel
6 | from torch import nn
7 |
8 | from .utils import exists, beartype_jit
9 |
10 |
11 | @beartype_jit
12 | class EncodecWrapper(nn.Module):
13 | def __init__(self,
14 | *,
15 | encodec: EncodecModel,
16 | output_hz: int = 75,
17 | ):
18 | super().__init__()
19 |
20 | self.encodec = encodec
21 | self.sample_rate = encodec.sample_rate
22 | self.output_hz = output_hz
23 |
24 | assert exists(encodec.bandwidth)
25 | total_quantizers = encodec.quantizer.n_q
26 | self.num_quantizers = int(encodec.bandwidth / 24 * total_quantizers) # output quantizers per frame
27 | self.codebook_size = encodec.quantizer.bins
28 |
29 | def forward(self, x: torch.Tensor, return_encoded = True, **kwargs):
30 | assert return_encoded == True
31 |
32 | if x.dim() == 2:
33 | x = rearrange(x, 'b t -> b 1 t') # add in "mono" dimension
34 |
35 | with torch.no_grad():
36 | self.encodec.eval()
37 | encoded_frames = self.encodec.encode(x)
38 | codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T]
39 | codes = rearrange(codes, 'b n_q t -> b t n_q')
40 |
41 | return None, codes, None # [B, T, n_q]
42 |
43 | def decode_from_codebook_indices(self, quantized_indices):
44 | """
45 | Args:
46 | quantized_indices: [B, T, n_q]
47 | """
48 | quantized_indices = rearrange(quantized_indices, 'b t n_q -> b n_q t')
49 |
50 | frames = [(quantized_indices, None)] # 1 frame for now
51 | with torch.no_grad():
52 | self.encodec.eval()
53 | wave = self.encodec.decode(frames)
54 | return wave
55 |
56 | def create_encodec_24khz(bandwidth: float = 6.0, codebook_size: int = 1024, **kwargs):
57 | """
58 | Create a pretrained EnCodec model.
59 | Args:
60 | bandwidth: float, target bandwidth in kbps"""
61 | assert bandwidth in [1.5, 3., 6., 12., 24.], "invalid bandwidth. must be one of [1.5, 3., 6., 12., 24.]"
62 |
63 | encodec = EncodecModel.encodec_model_24khz()
64 | encodec.set_target_bandwidth(bandwidth)
65 | encodec_wrapper = EncodecWrapper(encodec=encodec, **kwargs)
66 |
67 | assert encodec_wrapper.codebook_size == codebook_size, "encodec codebook size must be 1024 for now"
68 |
69 | return encodec_wrapper
--------------------------------------------------------------------------------
/open_musiclm/hf_hubert_kmeans.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from torch import nn
5 | import numpy as np
6 | from einops import rearrange, pack, unpack
7 | from beartype.typing import Optional
8 |
9 | from torchaudio.functional import resample
10 | from .utils import exists, curtail_to_multiple, zero_mean_unit_var_norm
11 | from transformers import HubertModel
12 | from sklearn.cluster import MiniBatchKMeans
13 |
14 | import joblib
15 | import logging
16 | logging.root.setLevel(logging.ERROR)
17 |
18 |
19 | class HfHubertWithKmeans(nn.Module):
20 | """
21 | Hugging Face HubertModel + a k-means layer on top. Pretrained checkpoint for music: https://huggingface.co/m-a-p/MERT-v0
22 | Note: MERT-v0 outputs features at 50 Hz, while w2v-BERT (used in the paper) outputs at 25 Hz.
23 | """
24 |
25 | def __init__(
26 | self,
27 | *,
28 | hubert: HubertModel,
29 | kmeans: Optional[MiniBatchKMeans] = None,
30 | embed_layer: int=7,
31 | target_sample_hz=16000,
32 | seq_len_multiple_of=int(16000 / 50),
33 | normalize_embeds=True,
34 | codebook_size: int=1024,
35 | output_hz: int=50
36 | ):
37 | super().__init__()
38 | self.target_sample_hz = target_sample_hz
39 | self.output_hz = output_hz
40 | self.seq_len_multiple_of = seq_len_multiple_of
41 | self.codebook_size = kmeans.n_clusters if exists(kmeans) else None
42 |
43 | self.codebook_size = codebook_size
44 | if exists(kmeans):
45 | assert self.codebook_size == kmeans.n_clusters, "codebook_size must match kmeans.n_clusters"
46 |
47 | self.normalize_embeds = normalize_embeds
48 |
49 | self.embed_layer = embed_layer
50 |
51 | self.hubert = hubert
52 | self.kmeans = kmeans
53 |
54 | @torch.no_grad()
55 | def forward(
56 | self,
57 | wav_input: torch.Tensor,
58 | flatten=True,
59 | return_embed=False,
60 | input_sample_hz=None
61 | ):
62 | assert return_embed or exists(self.kmeans), "kmeans model must be provided if return_embed==False"
63 |
64 | device = wav_input.device
65 |
66 | if exists(input_sample_hz):
67 | wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)
68 |
69 | if exists(self.seq_len_multiple_of):
70 | wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
71 |
72 | hubert_args = {
73 | 'input_values': wav_input,
74 | 'attention_mask': torch.ones_like(wav_input, device=device), # TODO: handle padding
75 | }
76 |
77 | outputs = self.hubert(**hubert_args, output_hidden_states = True)
78 | embed = outputs.hidden_states[self.embed_layer]
79 |
80 | if self.normalize_embeds:
81 | embed = zero_mean_unit_var_norm(embed)
82 |
83 | if return_embed:
84 | return embed
85 |
86 | embed, packed_shape = pack([embed], '* d')
87 | codebook_indices = self.kmeans.predict(embed.detach().cpu().numpy())
88 | codebook_indices = torch.from_numpy(codebook_indices).to(device).long()
89 |
90 | if flatten:
91 | return codebook_indices
92 |
93 | codebook_indices, = unpack(codebook_indices, packed_shape, '*')
94 | return codebook_indices
95 |
96 |
97 | def get_kmeans_model(
98 | n_clusters,
99 | init,
100 | max_iter,
101 | batch_size,
102 | tol,
103 | max_no_improvement,
104 | n_init,
105 | reassignment_ratio,
106 | ):
107 | return MiniBatchKMeans(
108 | n_clusters=n_clusters,
109 | init=init,
110 | max_iter=max_iter,
111 | batch_size=batch_size,
112 | verbose=1,
113 | compute_labels=False,
114 | tol=tol,
115 | max_no_improvement=max_no_improvement,
116 | init_size=None,
117 | n_init=n_init,
118 | reassignment_ratio=reassignment_ratio,
119 | )
120 |
121 |
122 | def learn_kmeans(
123 | feat,
124 | seed,
125 | km_path='./results/kmeans.joblib',
126 | n_clusters=1024,
127 | init="k-means++",
128 | max_iter=100,
129 | batch_size=10000,
130 | tol=0.0,
131 | n_init=20,
132 | reassignment_ratio=0.0,
133 | max_no_improvement=100,
134 | ):
135 | np.random.seed(seed)
136 | km_model = get_kmeans_model(
137 | n_clusters,
138 | init,
139 | max_iter,
140 | batch_size,
141 | tol,
142 | max_no_improvement,
143 | n_init,
144 | reassignment_ratio,
145 | )
146 | km_model.fit(feat)
147 | joblib.dump(km_model, km_path)
148 |
149 | inertia = -km_model.score(feat) / len(feat)
150 | print("total inertia: %.5f" % inertia)
151 | print("finished successfully")
152 |
153 |
154 | def get_hubert_kmeans(model_name: str="m-a-p/MERT-v0", kmeans_path: Optional[str]='./checkpoints/kmeans.joblib', **kwargs):
155 | wav2vec = HubertModel.from_pretrained(model_name)
156 | kmeans = joblib.load(kmeans_path) if exists(kmeans_path) else None
157 |
158 | return HfHubertWithKmeans(hubert=wav2vec, kmeans=kmeans, **kwargs)
159 |
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | dir_path = os.path.dirname(os.path.abspath(__file__))
4 | sys.path.append(dir_path)
5 | from hook import CLAP_Module
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/__init__.py:
--------------------------------------------------------------------------------
1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config
2 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
3 | from .model import CLAP, CLAPTextCfg, CLAPVisionCfg, CLAPAudioCfp, convert_weights_to_fp16, trace_model
4 | from .openai import load_openai_model, list_openai_models
5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
6 | get_pretrained_url, download_pretrained
7 | from .tokenizer import SimpleTokenizer, tokenize
8 | from .transform import image_transform
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/bert.py:
--------------------------------------------------------------------------------
1 | from transformers import BertTokenizer, BertModel
2 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
3 | model = BertModel.from_pretrained("bert-base-uncased")
4 | text = "Replace me by any text you'd like."
5 |
6 | def bert_embeddings(text):
7 | # text = "Replace me by any text you'd like."
8 | encoded_input = tokenizer(text, return_tensors='pt')
9 | output = model(**encoded_input)
10 | return output
11 |
12 | from transformers import RobertaTokenizer, RobertaModel
13 |
14 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
15 | model = RobertaModel.from_pretrained('roberta-base')
16 | text = "Replace me by any text you'd like."
17 | def Roberta_embeddings(text):
18 | # text = "Replace me by any text you'd like."
19 | encoded_input = tokenizer(text, return_tensors='pt')
20 | output = model(**encoded_input)
21 | return output
22 |
23 | from transformers import BartTokenizer, BartModel
24 |
25 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
26 | model = BartModel.from_pretrained('facebook/bart-base')
27 | text = "Replace me by any text you'd like."
28 | def bart_embeddings(text):
29 | # text = "Replace me by any text you'd like."
30 | encoded_input = tokenizer(text, return_tensors='pt')
31 | output = model(**encoded_input)
32 | return output
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/open_musiclm/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/factory.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import pathlib
5 | import re
6 | from copy import deepcopy
7 | from pathlib import Path
8 |
9 | import torch
10 |
11 | from .model import CLAP, convert_weights_to_fp16
12 | from .openai import load_openai_model
13 | from .pretrained import get_pretrained_url, download_pretrained
14 | from .transform import image_transform
15 |
16 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
17 | _MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
18 |
19 |
20 | def _natural_key(string_):
21 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
22 |
23 |
24 | def _rescan_model_configs():
25 | global _MODEL_CONFIGS
26 |
27 | config_ext = (".json",)
28 | config_files = []
29 | for config_path in _MODEL_CONFIG_PATHS:
30 | if config_path.is_file() and config_path.suffix in config_ext:
31 | config_files.append(config_path)
32 | elif config_path.is_dir():
33 | for ext in config_ext:
34 | config_files.extend(config_path.glob(f"*{ext}"))
35 |
36 | for cf in config_files:
37 | with open(cf, "r") as f:
38 | model_cfg = json.load(f)
39 | if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
40 | _MODEL_CONFIGS[cf.stem] = model_cfg
41 |
42 | _MODEL_CONFIGS = {
43 | k: v
44 | for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
45 | }
46 |
47 |
48 | _rescan_model_configs() # initial populate of model config registry
49 |
50 |
51 | def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
52 | checkpoint = torch.load(checkpoint_path, map_location=map_location)
53 | if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
54 | state_dict = checkpoint["state_dict"]
55 | else:
56 | state_dict = checkpoint
57 | if skip_params:
58 | if next(iter(state_dict.items()))[0].startswith("module"):
59 | state_dict = {k[7:]: v for k, v in state_dict.items()}
60 | # for k in state_dict:
61 | # if k.startswith('transformer'):
62 | # v = state_dict.pop(k)
63 | # state_dict['text_branch.' + k[12:]] = v
64 | return state_dict
65 |
66 |
67 | def create_model(
68 | amodel_name: str,
69 | tmodel_name: str,
70 | pretrained: str = "",
71 | precision: str = "fp32",
72 | device: torch.device = torch.device("cpu"),
73 | jit: bool = False,
74 | force_quick_gelu: bool = False,
75 | openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
76 | skip_params=True,
77 | pretrained_audio: str = "",
78 | pretrained_text: str = "",
79 | enable_fusion: bool = False,
80 | fusion_type: str = 'None'
81 | # pretrained_image: bool = False,
82 | ):
83 | amodel_name = amodel_name.replace(
84 | "/", "-"
85 | ) # for callers using old naming with / in ViT names
86 | pretrained_orig = pretrained
87 | pretrained = pretrained.lower()
88 | if pretrained == "openai":
89 | if amodel_name in _MODEL_CONFIGS:
90 | logging.info(f"Loading {amodel_name} model config.")
91 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
92 | else:
93 | logging.error(
94 | f"Model config for {amodel_name} not found; available models {list_models()}."
95 | )
96 | raise RuntimeError(f"Model config for {amodel_name} not found.")
97 |
98 | logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
99 | # Hard Code in model name
100 | model_cfg["text_cfg"]["model_type"] = tmodel_name
101 | model = load_openai_model(
102 | "ViT-B-16",
103 | model_cfg,
104 | device=device,
105 | jit=jit,
106 | cache_dir=openai_model_cache_dir,
107 | enable_fusion=enable_fusion,
108 | fusion_type=fusion_type
109 | )
110 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
111 | if precision == "amp" or precision == "fp32":
112 | model = model.float()
113 | else:
114 | if amodel_name in _MODEL_CONFIGS:
115 | logging.info(f"Loading {amodel_name} model config.")
116 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
117 | else:
118 | logging.error(
119 | f"Model config for {amodel_name} not found; available models {list_models()}."
120 | )
121 | raise RuntimeError(f"Model config for {amodel_name} not found.")
122 |
123 | if force_quick_gelu:
124 | # override for use of QuickGELU on non-OpenAI transformer models
125 | model_cfg["quick_gelu"] = True
126 |
127 | # if pretrained_image:
128 | # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
129 | # # pretrained weight loading for timm models set via vision_cfg
130 | # model_cfg['vision_cfg']['timm_model_pretrained'] = True
131 | # else:
132 | # assert False, 'pretrained image towers currently only supported for timm models'
133 | model_cfg["text_cfg"]["model_type"] = tmodel_name
134 | model_cfg["enable_fusion"] = enable_fusion
135 | model_cfg["fusion_type"] = fusion_type
136 | model = CLAP(**model_cfg)
137 |
138 | if pretrained:
139 | checkpoint_path = ""
140 | url = get_pretrained_url(amodel_name, pretrained)
141 | if url:
142 | checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
143 | elif os.path.exists(pretrained_orig):
144 | checkpoint_path = pretrained_orig
145 | if checkpoint_path:
146 | logging.info(f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained}).")
147 | ckpt = load_state_dict(checkpoint_path, skip_params=True)
148 | model.load_state_dict(ckpt)
149 | param_names = [n for n, p in model.named_parameters()]
150 | for n in param_names:
151 | print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
152 | else:
153 | logging.warning(
154 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
155 | )
156 | raise RuntimeError(
157 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
158 | )
159 |
160 | if pretrained_audio:
161 | if amodel_name.startswith('PANN'):
162 | if 'Cnn14_mAP' in pretrained_audio: # official checkpoint
163 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
164 | audio_ckpt = audio_ckpt['model']
165 | keys = list(audio_ckpt.keys())
166 | for key in keys:
167 | if 'spectrogram_extractor' not in key and 'logmel_extractor' not in key:
168 | v = audio_ckpt.pop(key)
169 | audio_ckpt['audio_branch.' + key] = v
170 | elif os.path.basename(pretrained_audio).startswith('PANN'): # checkpoint trained via HTSAT codebase
171 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
172 | audio_ckpt = audio_ckpt['state_dict']
173 | keys = list(audio_ckpt.keys())
174 | for key in keys:
175 | if key.startswith('sed_model'):
176 | v = audio_ckpt.pop(key)
177 | audio_ckpt['audio_branch.' + key[10:]] = v
178 | elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase
179 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
180 | else:
181 | raise ValueError('Unknown audio checkpoint')
182 | elif amodel_name.startswith('HTSAT'):
183 | if 'HTSAT_AudioSet_Saved' in pretrained_audio: # official checkpoint
184 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
185 | audio_ckpt = audio_ckpt['state_dict']
186 | keys = list(audio_ckpt.keys())
187 | for key in keys:
188 | if key.startswith('sed_model') and ('spectrogram_extractor' not in key
189 | and 'logmel_extractor' not in key):
190 | v = audio_ckpt.pop(key)
191 | audio_ckpt['audio_branch.' + key[10:]] = v
192 | elif os.path.basename(pretrained_audio).startswith('HTSAT'): # checkpoint trained via HTSAT codebase
193 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
194 | audio_ckpt = audio_ckpt['state_dict']
195 | keys = list(audio_ckpt.keys())
196 | for key in keys:
197 | if key.startswith('sed_model'):
198 | v = audio_ckpt.pop(key)
199 | audio_ckpt['audio_branch.' + key[10:]] = v
200 | elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase
201 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
202 | else:
203 | raise ValueError('Unknown audio checkpoint')
204 | else:
205 | raise ValueError('this audio encoder pretrained checkpoint is not supported')
206 |
207 | model.load_state_dict(audio_ckpt, strict=False)
208 | logging.info(f"Loading pretrained {amodel_name} weights ({pretrained_audio}).")
209 | param_names = [n for n, p in model.named_parameters()]
210 | for n in param_names:
211 | print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
212 |
213 | model.to(device=device)
214 | if precision == "fp16":
215 | assert device.type != "cpu"
216 | convert_weights_to_fp16(model)
217 |
218 | if jit:
219 | model = torch.jit.script(model)
220 |
221 | return model, model_cfg
222 |
223 |
224 | def create_model_and_transforms(
225 | model_name: str,
226 | pretrained: str = "",
227 | precision: str = "fp32",
228 | device: torch.device = torch.device("cpu"),
229 | jit: bool = False,
230 | force_quick_gelu: bool = False,
231 | # pretrained_image: bool = False,
232 | ):
233 | model = create_model(
234 | model_name,
235 | pretrained,
236 | precision,
237 | device,
238 | jit,
239 | force_quick_gelu=force_quick_gelu,
240 | # pretrained_image=pretrained_image
241 | )
242 | preprocess_train = image_transform(model.visual.image_size, is_train=True)
243 | preprocess_val = image_transform(model.visual.image_size, is_train=False)
244 | return model, preprocess_train, preprocess_val
245 |
246 |
247 | def list_models():
248 | """enumerate available model architectures based on config files"""
249 | return list(_MODEL_CONFIGS.keys())
250 |
251 |
252 | def add_model_config(path):
253 | """add model config path or file and update registry"""
254 | if not isinstance(path, Path):
255 | path = Path(path)
256 | _MODEL_CONFIG_PATHS.append(path)
257 | _rescan_model_configs()
258 |
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/feature_fusion.py:
--------------------------------------------------------------------------------
1 | '''
2 | Feature Fusion for Variable-Length Data Processing
3 | AFF/iAFF are adapted from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4 | According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5 | '''
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class DAF(nn.Module):
12 | '''
13 | Direct addition (DirectAddFuse)
14 | '''
15 |
16 | def __init__(self):
17 | super(DAF, self).__init__()
18 |
19 | def forward(self, x, residual):
20 | return x + residual
21 |
22 |
23 | class iAFF(nn.Module):
24 | '''
25 | Multi-feature fusion (iAFF)
26 | '''
27 |
28 | def __init__(self, channels=64, r=4, type='2D'):
29 | super(iAFF, self).__init__()
30 | inter_channels = int(channels // r)
31 |
32 | if type == '1D':
33 | # local attention
34 | self.local_att = nn.Sequential(
35 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36 | nn.BatchNorm1d(inter_channels),
37 | nn.ReLU(inplace=True),
38 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39 | nn.BatchNorm1d(channels),
40 | )
41 |
42 | # global attention
43 | self.global_att = nn.Sequential(
44 | nn.AdaptiveAvgPool1d(1),
45 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46 | nn.BatchNorm1d(inter_channels),
47 | nn.ReLU(inplace=True),
48 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49 | nn.BatchNorm1d(channels),
50 | )
51 |
52 | # second local attention
53 | self.local_att2 = nn.Sequential(
54 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55 | nn.BatchNorm1d(inter_channels),
56 | nn.ReLU(inplace=True),
57 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58 | nn.BatchNorm1d(channels),
59 | )
60 | # second global attention
61 | self.global_att2 = nn.Sequential(
62 | nn.AdaptiveAvgPool1d(1),
63 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64 | nn.BatchNorm1d(inter_channels),
65 | nn.ReLU(inplace=True),
66 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67 | nn.BatchNorm1d(channels),
68 | )
69 | elif type == '2D':
70 | # local attention
71 | self.local_att = nn.Sequential(
72 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73 | nn.BatchNorm2d(inter_channels),
74 | nn.ReLU(inplace=True),
75 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76 | nn.BatchNorm2d(channels),
77 | )
78 |
79 | # global attention
80 | self.global_att = nn.Sequential(
81 | nn.AdaptiveAvgPool2d(1),
82 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83 | nn.BatchNorm2d(inter_channels),
84 | nn.ReLU(inplace=True),
85 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86 | nn.BatchNorm2d(channels),
87 | )
88 |
89 |             # Second local attention pass
90 | self.local_att2 = nn.Sequential(
91 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92 | nn.BatchNorm2d(inter_channels),
93 | nn.ReLU(inplace=True),
94 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95 | nn.BatchNorm2d(channels),
96 | )
97 |             # Second global attention pass
98 | self.global_att2 = nn.Sequential(
99 | nn.AdaptiveAvgPool2d(1),
100 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101 | nn.BatchNorm2d(inter_channels),
102 | nn.ReLU(inplace=True),
103 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104 | nn.BatchNorm2d(channels),
105 | )
106 | else:
107 |             raise ValueError('the type is not supported')
108 |
109 | self.sigmoid = nn.Sigmoid()
110 |
111 | def forward(self, x, residual):
112 | flag = False
113 | xa = x + residual
114 | if xa.size(0) == 1:
115 | xa = torch.cat([xa,xa],dim=0)
116 | flag = True
117 | xl = self.local_att(xa)
118 | xg = self.global_att(xa)
119 | xlg = xl + xg
120 | wei = self.sigmoid(xlg)
121 | xi = x * wei + residual * (1 - wei)
122 |
123 | xl2 = self.local_att2(xi)
124 |         xg2 = self.global_att2(xi)
125 | xlg2 = xl2 + xg2
126 | wei2 = self.sigmoid(xlg2)
127 | xo = x * wei2 + residual * (1 - wei2)
128 | if flag:
129 | xo = xo[0].unsqueeze(0)
130 | return xo
131 |
132 |
133 | class AFF(nn.Module):
134 | '''
135 |     Attentional feature fusion (AFF)
136 | '''
137 |
138 | def __init__(self, channels=64, r=4, type='2D'):
139 | super(AFF, self).__init__()
140 | inter_channels = int(channels // r)
141 |
142 | if type == '1D':
143 | self.local_att = nn.Sequential(
144 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145 | nn.BatchNorm1d(inter_channels),
146 | nn.ReLU(inplace=True),
147 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148 | nn.BatchNorm1d(channels),
149 | )
150 | self.global_att = nn.Sequential(
151 | nn.AdaptiveAvgPool1d(1),
152 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153 | nn.BatchNorm1d(inter_channels),
154 | nn.ReLU(inplace=True),
155 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156 | nn.BatchNorm1d(channels),
157 | )
158 | elif type == '2D':
159 | self.local_att = nn.Sequential(
160 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161 | nn.BatchNorm2d(inter_channels),
162 | nn.ReLU(inplace=True),
163 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164 | nn.BatchNorm2d(channels),
165 | )
166 | self.global_att = nn.Sequential(
167 | nn.AdaptiveAvgPool2d(1),
168 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169 | nn.BatchNorm2d(inter_channels),
170 | nn.ReLU(inplace=True),
171 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172 | nn.BatchNorm2d(channels),
173 | )
174 | else:
175 |             raise ValueError('the type is not supported.')
176 |
177 | self.sigmoid = nn.Sigmoid()
178 |
179 | def forward(self, x, residual):
180 | flag = False
181 | xa = x + residual
182 | if xa.size(0) == 1:
183 | xa = torch.cat([xa,xa],dim=0)
184 | flag = True
185 | xl = self.local_att(xa)
186 | xg = self.global_att(xa)
187 | xlg = xl + xg
188 | wei = self.sigmoid(xlg)
189 | xo = 2 * x * wei + 2 * residual * (1 - wei)
190 | if flag:
191 | xo = xo[0].unsqueeze(0)
192 | return xo
193 |
194 |
--------------------------------------------------------------------------------
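
A minimal sketch of how the fusion blocks above are used, assuming the module is importable from the vendored package; shapes are illustrative. AFF and iAFF both take two tensors of identical shape and return a fused tensor of the same shape.

import torch
from open_musiclm.laion_clap.clap_module.feature_fusion import AFF, iAFF

x = torch.randn(4, 64, 32, 32)          # main features: (batch, channels, H, W)
residual = torch.randn(4, 64, 32, 32)   # features to be fused in

aff = AFF(channels=64, r=4, type='2D')
fused = aff(x, residual)                # -> (4, 64, 32, 32)

iaff = iAFF(channels=64, r=4, type='2D')  # adds a second local/global attention pass
fused_iterative = iaff(x, residual)       # -> (4, 64, 32, 32)

The batch-size-1 duplication in forward() works around BatchNorm requiring more than one sample per channel in training mode.
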
/open_musiclm/laion_clap/clap_module/linear_probe.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from .model import MLPLayers
5 |
6 |
7 | class LinearProbe(nn.Module):
8 | def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
9 | """
10 | Args:
11 | model: nn.Module
12 | mlp: bool, if True, then use the MLP layer as the linear probe module
13 |             freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe
14 | in_ch: int, the output channel from CLAP model
15 | out_ch: int, the output channel from linear probe (class_num)
16 | act: torch.nn.functional, the activation function before the loss function
17 | """
18 | super().__init__()
19 | in_ch = 512
20 | self.clap_model = model
21 | self.clap_model.text_branch = None # to save memory
22 | self.freeze = freeze
23 | if mlp:
24 | self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
25 | else:
26 | self.lp_layer = nn.Linear(in_ch, out_ch)
27 |
28 | if self.freeze:
29 | for param in self.clap_model.parameters():
30 | param.requires_grad = False
31 |
32 | if act == 'None':
33 | self.act = None
34 | elif act == 'relu':
35 | self.act = nn.ReLU()
36 | elif act == 'elu':
37 | self.act = nn.ELU()
38 | elif act == 'prelu':
39 | self.act = nn.PReLU(num_parameters=in_ch)
40 | elif act == 'softmax':
41 | self.act = nn.Softmax(dim=-1)
42 | elif act == 'sigmoid':
43 | self.act = nn.Sigmoid()
44 |
45 | def forward(self, x, mix_lambda=None, device=None):
46 | """
47 | Args:
48 | x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
49 | mix_lambda: torch.tensor [batch], the mixup lambda
50 | Returns:
51 | class_prob: torch.tensor [batch, class_num]
52 |
53 | """
54 |         # keep BatchNorm layers in eval mode so the frozen CLAP's running statistics are not updated
55 | if self.freeze:
56 | self.clap_model.eval()
57 |
58 | x = self.clap_model.audio_projection(
59 | self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)["embedding"])
60 | out = self.lp_layer(x)
61 | if self.act is not None:
62 | out = self.act(out)
63 | return out
64 |
--------------------------------------------------------------------------------
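
A hedged sketch of the interface LinearProbe expects. DummyCLAP below is a hypothetical stand-in (not part of this repo) exposing only the audio_branch/audio_projection/text_branch members that the probe touches; in practice you would pass a real CLAP model.

import torch
from torch import nn
from open_musiclm.laion_clap.clap_module.linear_probe import LinearProbe

class DummyCLAP(nn.Module):
    """Hypothetical stand-in mimicking the pieces of CLAP that LinearProbe uses."""
    def __init__(self, embed_dim=512):
        super().__init__()
        self.text_branch = nn.Identity()             # dropped by LinearProbe to save memory
        self.audio_projection = nn.Linear(embed_dim, embed_dim)
        self._encoder = nn.Linear(1000, embed_dim)   # toy "audio branch" backbone

    def audio_branch(self, x, mixup_lambda=None, device=None):
        return {"embedding": self._encoder(x)}

probe = LinearProbe(DummyCLAP(), mlp=False, freeze=True, in_ch=512, out_ch=10, act='softmax')
class_prob = probe(torch.randn(8, 1000))   # -> (8, 10) class probabilities

Note that __init__ overrides the in_ch argument with 512, the CLAP projection size, so the probe always expects 512-dimensional audio embeddings.
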
/open_musiclm/laion_clap/clap_module/model_configs/HTSAT-base.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "base"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
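
Editor's note: the model_configs/*.json files are plain JSON, presumably picked up by the _MODEL_CONFIGS registry in factory.py above. A quick way to inspect one (the path assumes you are at the repo root):

import json

with open("open_musiclm/laion_clap/clap_module/model_configs/HTSAT-base.json") as f:
    cfg = json.load(f)

print(cfg["embed_dim"])                   # 1024: joint embedding size
print(cfg["audio_cfg"]["sample_rate"])    # 48000: audio is expected at 48 kHz
print(cfg["audio_cfg"]["model_type"])     # 'HTSAT'
print(cfg["text_cfg"]["context_length"])  # 77: CLIP-style text context
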
/open_musiclm/laion_clap/clap_module/model_configs/HTSAT-large.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "large"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1536,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "tiny"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/HTSAT-tiny.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "HTSAT",
14 | "model_name": "tiny"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/PANN-10.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn10"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 18000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 960000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 360,
10 | "fmin": 50,
11 | "fmax": 8000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 4
22 | }
23 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/PANN-14-win-1536.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1536,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/PANN-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 2048,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn14"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/PANN-6.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "audio_cfg": {
4 | "audio_length": 1024,
5 | "clip_samples": 480000,
6 | "mel_bins": 64,
7 | "sample_rate": 48000,
8 | "window_size": 1024,
9 | "hop_size": 480,
10 | "fmin": 50,
11 | "fmax": 14000,
12 | "class_num": 527,
13 | "model_type": "PANN",
14 | "model_name": "Cnn6"
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 512,
20 | "heads": 8,
21 | "layers": 12
22 | }
23 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/RN101-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 23,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/RN101.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 23,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/RN50-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 6,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/RN50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 6,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/RN50x16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 384,
5 | "layers": [
6 | 6,
7 | 8,
8 | 18,
9 | 8
10 | ],
11 | "width": 96,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 768,
18 | "heads": 12,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/RN50x4.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 288,
5 | "layers": [
6 | 4,
7 | 6,
8 | 10,
9 | 6
10 | ],
11 | "width": 80,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 640,
18 | "heads": 10,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/ViT-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 32
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/model_configs/ViT-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import Union, List
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict
13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained
14 |
15 | __all__ = ["list_openai_models", "load_openai_model"]
16 |
17 |
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_tag_models('openai')
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | model_cfg,
26 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
27 | jit=True,
28 | cache_dir=os.path.expanduser("~/.cache/clip"),
29 | enable_fusion: bool = False,
30 | fusion_type: str = 'None'
31 | ):
32 | """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model
33 |
34 | Parameters
35 | ----------
36 | name : str
37 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
38 | device : Union[str, torch.device]
39 | The device to put the loaded model
40 | jit : bool
41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42 |
43 | Returns
44 | -------
45 | model : torch.nn.Module
46 | The CLAP model
47 | preprocess : Callable[[PIL.Image], torch.Tensor]
48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
49 | """
50 | if get_pretrained_url(name, 'openai'):
51 | model_path = download_pretrained(get_pretrained_url(name, 'openai'), root=cache_dir)
52 | elif os.path.isfile(name):
53 | model_path = name
54 | else:
55 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
56 |
57 | try:
58 | # loading JIT archive
59 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
60 | state_dict = None
61 | except RuntimeError:
62 | # loading saved state dict
63 | if jit:
64 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
65 | jit = False
66 | state_dict = torch.load(model_path, map_location="cpu")
67 |
68 | if not jit:
69 | try:
70 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type).to(device)
71 | except KeyError:
72 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
73 | model = build_model_from_openai_state_dict(sd, model_cfg, enable_fusion, fusion_type).to(device)
74 |
75 | if str(device) == "cpu":
76 | model.float()
77 | return model
78 |
79 | # patch the device names
80 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
81 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
82 |
83 | def patch_device(module):
84 | try:
85 | graphs = [module.graph] if hasattr(module, "graph") else []
86 | except RuntimeError:
87 | graphs = []
88 |
89 | if hasattr(module, "forward1"):
90 | graphs.append(module.forward1.graph)
91 |
92 | for graph in graphs:
93 | for node in graph.findAllNodes("prim::Constant"):
94 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
95 | node.copyAttributes(device_node)
96 |
97 | model.apply(patch_device)
98 | patch_device(model.encode_audio)
99 | patch_device(model.encode_text)
100 |
101 | # patch dtype to float32 on CPU
102 | if str(device) == "cpu":
103 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
104 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
105 | float_node = float_input.node()
106 |
107 | def patch_float(module):
108 | try:
109 | graphs = [module.graph] if hasattr(module, "graph") else []
110 | except RuntimeError:
111 | graphs = []
112 |
113 | if hasattr(module, "forward1"):
114 | graphs.append(module.forward1.graph)
115 |
116 | for graph in graphs:
117 | for node in graph.findAllNodes("aten::to"):
118 | inputs = list(node.inputs())
119 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
120 | if inputs[i].node()["value"] == 5:
121 | inputs[i].node().copyAttributes(float_node)
122 |
123 | model.apply(patch_float)
124 | patch_float(model.encode_audio)
125 | patch_float(model.encode_text)
126 | model.float()
127 |
128 | model.audio_branch.audio_length = model.audio_cfg.audio_length
129 | return model
130 |
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/clap_module/pretrained.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 |
6 | from tqdm import tqdm
7 |
8 | _RN50 = dict(
9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"
12 | )
13 |
14 | _RN50_quickgelu = dict(
15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"
18 | )
19 |
20 | _RN101 = dict(
21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"
23 | )
24 |
25 | _RN101_quickgelu = dict(
26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"
28 | )
29 |
30 | _RN50x4 = dict(
31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32 | )
33 |
34 | _RN50x16 = dict(
35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36 | )
37 |
38 | _RN50x64 = dict(
39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40 | )
41 |
42 | _VITB32 = dict(
43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
44 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
45 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
46 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
47 | )
48 |
49 | _VITB32_quickgelu = dict(
50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
53 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
54 | )
55 |
56 | _VITB16 = dict(
57 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
58 | )
59 |
60 | _VITL14 = dict(
61 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
62 | )
63 |
64 | _PRETRAINED = {
65 | "RN50": _RN50,
66 | "RN50-quickgelu": _RN50_quickgelu,
67 | "RN101": _RN101,
68 | "RN101-quickgelu": _RN101_quickgelu,
69 | "RN50x4": _RN50x4,
70 | "RN50x16": _RN50x16,
71 | "ViT-B-32": _VITB32,
72 | "ViT-B-32-quickgelu": _VITB32_quickgelu,
73 | "ViT-B-16": _VITB16,
74 | "ViT-L-14": _VITL14,
75 | }
76 |
77 |
78 | def list_pretrained(as_str: bool = False):
79 | """ returns list of pretrained models
80 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
81 | """
82 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
83 |
84 |
85 | def list_pretrained_tag_models(tag: str):
86 | """ return all models having the specified pretrain tag """
87 | models = []
88 | for k in _PRETRAINED.keys():
89 | if tag in _PRETRAINED[k]:
90 | models.append(k)
91 | return models
92 |
93 |
94 | def list_pretrained_model_tags(model: str):
95 | """ return all pretrain tags for the specified model architecture """
96 | tags = []
97 | if model in _PRETRAINED:
98 | tags.extend(_PRETRAINED[model].keys())
99 | return tags
100 |
101 |
102 | def get_pretrained_url(model: str, tag: str):
103 | if model not in _PRETRAINED:
104 | return ''
105 | model_pretrained = _PRETRAINED[model]
106 | if tag not in model_pretrained:
107 | return ''
108 | return model_pretrained[tag]
109 |
110 |
111 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
112 | os.makedirs(root, exist_ok=True)
113 | filename = os.path.basename(url)
114 |
115 | if 'openaipublic' in url:
116 | expected_sha256 = url.split("/")[-2]
117 | else:
118 | expected_sha256 = ''
119 |
120 | download_target = os.path.join(root, filename)
121 |
122 | if os.path.exists(download_target) and not os.path.isfile(download_target):
123 | raise RuntimeError(f"{download_target} exists and is not a regular file")
124 |
125 | if os.path.isfile(download_target):
126 | if expected_sha256:
127 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
128 | return download_target
129 | else:
130 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
131 | else:
132 | return download_target
133 |
134 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
135 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
136 | while True:
137 | buffer = source.read(8192)
138 | if not buffer:
139 | break
140 |
141 | output.write(buffer)
142 | loop.update(len(buffer))
143 |
144 | if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
145 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
146 |
147 | return download_target
148 |
--------------------------------------------------------------------------------
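
A brief sketch of the helpers above, assuming the vendored module is importable; the printed values follow from the _PRETRAINED table.

from open_musiclm.laion_clap.clap_module.pretrained import (
    list_pretrained, list_pretrained_model_tags, get_pretrained_url, download_pretrained)

print(list_pretrained(as_str=True)[:3])        # ['RN50:openai', 'RN50:yfcc15m', 'RN50:cc12m']
print(list_pretrained_model_tags("ViT-B-32"))  # ['openai', 'laion400m_e31', 'laion400m_e32', 'laion400m_avg']

url = get_pretrained_url("ViT-B-32", "openai")
# download_pretrained(url) would fetch the checkpoint into ~/.cache/clip and verify its SHA256.
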
/open_musiclm/laion_clap/clap_module/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | from collections import OrderedDict
6 |
7 | import torch.nn as nn
8 |
9 | try:
10 | import timm
11 | from timm.models.layers import Mlp, to_2tuple
12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
14 | except ImportError as e:
15 | timm = None
16 |
17 | from .utils import freeze_batch_norm_2d
18 |
19 |
20 | class TimmModel(nn.Module):
21 | """ timm model adapter
22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
23 | """
24 |
25 | def __init__(
26 | self,
27 | model_name,
28 | embed_dim,
29 | image_size=224,
30 | pool='avg',
31 | proj='linear',
32 | drop=0.,
33 | pretrained=False):
34 | super().__init__()
35 | if timm is None:
36 | raise RuntimeError("Please `pip install timm` to use timm models.")
37 |
38 | self.image_size = to_2tuple(image_size)
39 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
40 | feat_size = self.trunk.default_cfg.get('pool_size', None)
41 | feature_ndim = 1 if not feat_size else 2
42 | if pool in ('abs_attn', 'rot_attn'):
43 | assert feature_ndim == 2
44 | # if attn pooling used, remove both classifier and default pool
45 | self.trunk.reset_classifier(0, global_pool='')
46 | else:
47 | # reset global pool if pool config set, otherwise leave as network default
48 | reset_kwargs = dict(global_pool=pool) if pool else {}
49 | self.trunk.reset_classifier(0, **reset_kwargs)
50 | prev_chs = self.trunk.num_features
51 |
52 | head_layers = OrderedDict()
53 | if pool == 'abs_attn':
54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
55 | prev_chs = embed_dim
56 | elif pool == 'rot_attn':
57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
58 | prev_chs = embed_dim
59 | else:
60 | assert proj, 'projection layer needed if non-attention pooling is used.'
61 |
62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
63 | if proj == 'linear':
64 | head_layers['drop'] = nn.Dropout(drop)
65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim)
66 | elif proj == 'mlp':
67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
68 |
69 | self.head = nn.Sequential(head_layers)
70 |
71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
72 | """ lock modules
73 | Args:
74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
75 | """
76 | if not unlocked_groups:
77 | # lock full model
78 | for param in self.trunk.parameters():
79 | param.requires_grad = False
80 | if freeze_bn_stats:
81 | freeze_batch_norm_2d(self.trunk)
82 | else:
83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
84 | try:
85 | # FIXME import here until API stable and in an official release
86 | from timm.models.helpers import group_parameters, group_modules
87 | except ImportError:
88 | raise RuntimeError(
89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
90 | matcher = self.trunk.group_matcher()
91 | gparams = group_parameters(self.trunk, matcher)
92 | max_layer_id = max(gparams.keys())
93 | max_layer_id = max_layer_id - unlocked_groups
94 | for group_idx in range(max_layer_id + 1):
95 | group = gparams[group_idx]
96 | for param in group:
97 | self.trunk.get_parameter(param).requires_grad = False
98 | if freeze_bn_stats:
99 | gmodules = group_modules(self.trunk, matcher, reverse=True)
100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
101 | freeze_batch_norm_2d(self.trunk, gmodules)
102 |
103 | def forward(self, x):
104 | x = self.trunk(x)
105 | x = self.head(x)
106 | return x
107 |
--------------------------------------------------------------------------------
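
A hedged sketch of wrapping a timm backbone as a vision tower. It assumes a timm version that still exposes timm.models.layers.attention_pool2d (otherwise the module-level import above sets timm to None and the constructor raises); 'resnet50' is only an example model name.

import torch
from open_musiclm.laion_clap.clap_module.timm_model import TimmModel

tower = TimmModel('resnet50', embed_dim=512, image_size=224, pool='avg', proj='linear')
tower.lock(unlocked_groups=0, freeze_bn_stats=True)   # freeze the trunk and its BatchNorm statistics

images = torch.randn(2, 3, 224, 224)
embeddings = tower(images)    # -> (2, 512): pooled trunk features projected by the linear head
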
/open_musiclm/laion_clap/clap_module/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 |
16 | @lru_cache()
17 | def default_bpe():
18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
19 |
20 |
21 | @lru_cache()
22 | def bytes_to_unicode():
23 | """
24 | Returns list of utf-8 byte and a corresponding list of unicode strings.
25 | The reversible bpe codes work on unicode strings.
26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
28 |     This is a significant percentage of your normal, say, 32K bpe vocab.
29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
30 | And avoids mapping to whitespace/control characters the bpe code barfs on.
31 | """
32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
33 | cs = bs[:]
34 | n = 0
35 | for b in range(2**8):
36 | if b not in bs:
37 | bs.append(b)
38 | cs.append(2**8+n)
39 | n += 1
40 | cs = [chr(n) for n in cs]
41 | return dict(zip(bs, cs))
42 |
43 |
44 | def get_pairs(word):
45 | """Return set of symbol pairs in a word.
46 | Word is represented as tuple of symbols (symbols being variable-length strings).
47 | """
48 | pairs = set()
49 | prev_char = word[0]
50 | for char in word[1:]:
51 | pairs.add((prev_char, char))
52 | prev_char = char
53 | return pairs
54 |
55 |
56 | def basic_clean(text):
57 | text = ftfy.fix_text(text)
58 | text = html.unescape(html.unescape(text))
59 | return text.strip()
60 |
61 |
62 | def whitespace_clean(text):
63 | text = re.sub(r'\s+', ' ', text)
64 | text = text.strip()
65 | return text
66 |
67 |
68 | class SimpleTokenizer(object):
69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
70 | self.byte_encoder = bytes_to_unicode()
71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
73 | merges = merges[1:49152-256-2+1]
74 | merges = [tuple(merge.split()) for merge in merges]
75 | vocab = list(bytes_to_unicode().values())
76 |         vocab = vocab + [v+'</w>' for v in vocab]
77 | for merge in merges:
78 | vocab.append(''.join(merge))
79 | if not special_tokens:
80 |             special_tokens = ['<start_of_text>', '<end_of_text>']
81 |         else:
82 |             special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
83 | vocab.extend(special_tokens)
84 | self.encoder = dict(zip(vocab, range(len(vocab))))
85 | self.decoder = {v: k for k, v in self.encoder.items()}
86 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
87 | self.cache = {t:t for t in special_tokens}
88 | special = "|".join(special_tokens)
89 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
90 |
91 | self.vocab_size = len(self.encoder)
92 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
93 |
94 | def bpe(self, token):
95 | if token in self.cache:
96 | return self.cache[token]
97 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
98 | pairs = get_pairs(word)
99 |
100 | if not pairs:
101 |             return token+'</w>'
102 |
103 | while True:
104 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
105 | if bigram not in self.bpe_ranks:
106 | break
107 | first, second = bigram
108 | new_word = []
109 | i = 0
110 | while i < len(word):
111 | try:
112 | j = word.index(first, i)
113 | new_word.extend(word[i:j])
114 | i = j
115 |             except ValueError:
116 | new_word.extend(word[i:])
117 | break
118 |
119 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
120 | new_word.append(first+second)
121 | i += 2
122 | else:
123 | new_word.append(word[i])
124 | i += 1
125 | new_word = tuple(new_word)
126 | word = new_word
127 | if len(word) == 1:
128 | break
129 | else:
130 | pairs = get_pairs(word)
131 | word = ' '.join(word)
132 | self.cache[token] = word
133 | return word
134 |
135 | def encode(self, text):
136 | bpe_tokens = []
137 | text = whitespace_clean(basic_clean(text)).lower()
138 | for token in re.findall(self.pat, text):
139 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
140 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
141 | return bpe_tokens
142 |
143 | def decode(self, tokens):
144 | text = ''.join([self.decoder[token] for token in tokens])
145 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
146 | return text
147 |
148 |
149 | _tokenizer = SimpleTokenizer()
150 |
151 |
152 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
153 | """
154 | Returns the tokenized representation of given input string(s)
155 |
156 | Parameters
157 | ----------
158 | texts : Union[str, List[str]]
159 | An input string or a list of input strings to tokenize
160 | context_length : int
161 | The context length to use; all CLIP models use 77 as the context length
162 |
163 | Returns
164 | -------
165 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
166 | """
167 | if isinstance(texts, str):
168 | texts = [texts]
169 |
170 |     sot_token = _tokenizer.encoder["<start_of_text>"]
171 |     eot_token = _tokenizer.encoder["<end_of_text>"]
172 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
173 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
174 |
175 | for i, tokens in enumerate(all_tokens):
176 | if len(tokens) > context_length:
177 | tokens = tokens[:context_length] # Truncate
178 | result[i, :len(tokens)] = torch.tensor(tokens)
179 |
180 | return result
181 |
--------------------------------------------------------------------------------
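
A short sketch of the tokenizer above: tokenize() wraps each string with the start/end-of-text tokens, BPE-encodes it, and pads or truncates to the 77-token context.

from open_musiclm.laion_clap.clap_module.tokenizer import tokenize

tokens = tokenize(["a recording of jazz piano", "heavy metal with distorted guitars"])
print(tokens.shape)    # torch.Size([2, 77]); rows are zero-padded after the end-of-text token
print(tokens[0, :5])   # start-of-text id followed by the first BPE ids
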
/open_musiclm/laion_clap/clap_module/transform.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
2 | CenterCrop
3 |
4 |
5 | def _convert_to_rgb(image):
6 | return image.convert('RGB')
7 |
8 |
9 | def image_transform(
10 | image_size: int,
11 | is_train: bool,
12 | mean=(0.48145466, 0.4578275, 0.40821073),
13 | std=(0.26862954, 0.26130258, 0.27577711)
14 | ):
15 | normalize = Normalize(mean=mean, std=std)
16 | if is_train:
17 | return Compose([
18 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
19 | _convert_to_rgb,
20 | ToTensor(),
21 | normalize,
22 | ])
23 | else:
24 | return Compose([
25 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
26 | CenterCrop(image_size),
27 | _convert_to_rgb,
28 | ToTensor(),
29 | normalize,
30 | ])
31 |
--------------------------------------------------------------------------------
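
A minimal sketch of the image transforms above (an editorial addition); 'example.jpg' is a placeholder path. The train variant swaps the Resize/CenterCrop pair for a RandomResizedCrop.

from PIL import Image
from open_musiclm.laion_clap.clap_module.transform import image_transform

preprocess_val = image_transform(image_size=224, is_train=False)
img = Image.open("example.jpg")     # placeholder image path
tensor = preprocess_val(img)        # -> float tensor of shape (3, 224, 224), CLIP-normalized
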
/open_musiclm/laion_clap/clap_module/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn as nn
4 | from torchvision.ops.misc import FrozenBatchNorm2d
5 | import logging
6 | import h5py
7 | from tqdm import tqdm
8 | import random
9 | import json
10 | import os
11 | import pathlib
12 |
13 | # TODO: (yusong) this is not a good place to store this information and it does not scale; needs to be fixed later.
14 | dataset_split = {
15 | "audiocaps": ["train", "valid", "test"],
16 | "audioset": ["balanced_train", "unbalanced_train", "eval"],
17 | "BBCSoundEffects": ["train", "test"],
18 | "Clotho": ["train", "test", "valid"],
19 | "free_to_use_sounds": ["train", "test"],
20 | "paramount_motion": ["train", "test"],
21 | "sonniss_game_effects": ["train", "test"],
22 | "wesoundeffects": ["train", "test"],
23 | "MACS": ["train", "test"],
24 | "freesound": ["train", "test"],
25 | "FSD50K": ["train", "test", "valid"],
26 | "fsd50k_class_label": ["train", "test", "valid"],
27 | "esc50": ["train", "test"],
28 | "ESC50_1": ["train", "test"],
29 | "ESC50_2": ["train", "test"],
30 | "ESC50_3": ["train", "test"],
31 | "ESC50_4": ["train", "test"],
32 | "ESC50_5": ["train", "test"],
33 | "audiostock": ["train", "test"],
34 | "freesound_no_overlap_noesc50": ["train", "test"],
35 | "epidemic_sound_effects": ["train", "test"],
36 | "VGGSound": ["train", "test"],
37 | "urbansound8k_class_label": ["train", "test"],
38 | "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
39 | "audioset_t5_debiased": ["balanced_train", "unbalanced_train", "eval"],
40 | "epidemic_sound_effects_t5": ["train", "test"],
41 | "epidemic_sound_effects_t5_debiased": ["train", "test"],
42 | "WavText5K": ["train", "test"],
43 | "esc50_no_overlap": ["train", "test"],
44 | "usd8k_no_overlap": ["train", "test"],
45 | "fsd50k_200_class_label": ["train", "test", "valid"],
46 | "fma_full": ["train", "test"],
47 | "Genius": ["train", "test"],
48 | "Jamendo": ["train", "test"],
49 | "juno": ["train", "test"],
50 | "CMU_Arctic": ["train", "test"],
51 | "ravdess": ["train", "test"],
52 | "Europarl-st": ["train", "test"],
53 | "common_voice": ["train", "test"],
54 | "Jamendo_16bit": ["train", "test"],
55 | "genius_16bit_128": ["train", "test"],
56 | "juno_16bit": ["train", "test"],
57 | "fma_full_16bit_128": ["train", "test"],
58 | }
59 |
60 |
61 | def freeze_batch_norm_2d(module, module_match={}, name=""):
62 | """
63 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
64 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
65 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
66 |
67 | Args:
68 | module (torch.nn.Module): Any PyTorch module.
69 | module_match (dict): Dictionary of full module names to freeze (all if empty)
70 | name (str): Full module name (prefix)
71 |
72 | Returns:
73 | torch.nn.Module: Resulting module
74 |
75 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
76 | """
77 | res = module
78 | is_match = True
79 | if module_match:
80 | is_match = name in module_match
81 | if is_match and isinstance(
82 | module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
83 | ):
84 | res = FrozenBatchNorm2d(module.num_features)
85 | res.num_features = module.num_features
86 | res.affine = module.affine
87 | if module.affine:
88 | res.weight.data = module.weight.data.clone().detach()
89 | res.bias.data = module.bias.data.clone().detach()
90 | res.running_mean.data = module.running_mean.data
91 | res.running_var.data = module.running_var.data
92 | res.eps = module.eps
93 | else:
94 | for child_name, child in module.named_children():
95 | full_child_name = ".".join([name, child_name]) if name else child_name
96 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
97 | if new_child is not child:
98 | res.add_module(child_name, new_child)
99 | return res
100 |
101 |
102 | def exist(dataset_name, dataset_type):
103 | """
104 | Check if dataset exists
105 | """
106 | if dataset_type in dataset_split[dataset_name]:
107 | return True
108 | else:
109 | return False
110 |
111 |
112 | def get_tar_path_from_dataset_name(
113 | dataset_names,
114 | dataset_types,
115 | islocal,
116 | dataset_path,
117 | proportion=1,
118 | full_dataset=None
119 | ):
120 | """
121 | Get tar path from dataset name and type
122 | """
123 | output = []
124 | for n in dataset_names:
125 | if full_dataset is not None and n in full_dataset:
126 | current_dataset_types = dataset_split[n]
127 | else:
128 | current_dataset_types = dataset_types
129 | for s in current_dataset_types:
130 | tmp = []
131 | if islocal:
132 | sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
133 | if not os.path.exists(sizefilepath_):
134 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
135 | else:
136 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
137 | if not os.path.exists(sizefilepath_):
138 | continue
139 | sizes = json.load(open(sizefilepath_, "r"))
140 | for k in sizes.keys():
141 | if islocal:
142 | tmp.append(f"{dataset_path}/{n}/{s}/{k}")
143 | else:
144 | tmp.append(
145 | f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
146 | )
147 | if proportion != 1:
148 | tmp = random.sample(tmp, int(proportion * len(tmp)))
149 | output.append(tmp)
150 | return sum(output, [])
151 |
152 |
153 | def get_tar_path_from_txts(txt_path, islocal, proportion=1):
154 | """
155 | Get tar path from txt path
156 | """
157 | if isinstance(txt_path, (list, tuple)):
158 | return sum(
159 | [
160 | get_tar_path_from_txts(
161 | txt_path[i], islocal=islocal, proportion=proportion
162 | )
163 | for i in range(len(txt_path))
164 | ],
165 | [],
166 | )
167 | if isinstance(txt_path, str):
168 | with open(txt_path) as f:
169 | lines = f.readlines()
170 | if islocal:
171 | lines = [
172 | lines[i]
173 | .split("\n")[0]
174 | .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
175 | for i in range(len(lines))
176 | ]
177 | else:
178 | lines = [
179 | lines[i].split("\n")[0].replace(".tar", ".tar -")
180 | for i in range(len(lines))
181 | ]
182 | if proportion != 1:
183 | print("Sampling tars with proportion of {}".format(proportion))
184 | lines = random.sample(lines, int(proportion * len(lines)))
185 | return lines
186 |
187 |
188 | def get_mix_lambda(mixup_alpha, batch_size):
189 | mixup_lambdas = [
190 | np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
191 | ]
192 | return np.array(mixup_lambdas).astype(np.float32)
193 |
194 |
195 | def do_mixup(x, mixup_lambda):
196 | """
197 | Args:
198 | x: (batch_size , ...)
199 | mixup_lambda: (batch_size,)
200 | Returns:
201 | out: (batch_size, ...)
202 | """
203 | out = (
204 | x.transpose(0, -1) * mixup_lambda
205 | + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
206 | ).transpose(0, -1)
207 | return out
208 |
209 |
210 | def interpolate(x, ratio):
211 | """Interpolate data in time domain. This is used to compensate the
212 | resolution reduction in downsampling of a CNN.
213 |
214 | Args:
215 | x: (batch_size, time_steps, classes_num)
216 | ratio: int, ratio to interpolate
217 | Returns:
218 | upsampled: (batch_size, time_steps * ratio, classes_num)
219 | """
220 | (batch_size, time_steps, classes_num) = x.shape
221 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
222 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
223 | return upsampled
224 |
225 |
226 | def pad_framewise_output(framewise_output, frames_num):
227 | """Pad framewise_output to the same length as input frames. The pad value
228 | is the same as the value of the last frame.
229 | Args:
230 | framewise_output: (batch_size, frames_num, classes_num)
231 | frames_num: int, number of frames to pad
232 | Outputs:
233 | output: (batch_size, frames_num, classes_num)
234 | """
235 | pad = framewise_output[:, -1:, :].repeat(
236 | 1, frames_num - framewise_output.shape[1], 1
237 | )
238 | """tensor for padding"""
239 |
240 | output = torch.cat((framewise_output, pad), dim=1)
241 | """(batch_size, frames_num, classes_num)"""
242 |     return output
243 |
244 | def process_ipc(index_path, classes_num, filename):
245 | # load data
246 | logging.info("Load Data...............")
247 | ipc = [[] for _ in range(classes_num)]
248 | with h5py.File(index_path, "r") as f:
249 | for i in tqdm(range(len(f["target"]))):
250 | t_class = np.where(f["target"][i])[0]
251 | for t in t_class:
252 | ipc[t].append(i)
253 | print(ipc)
254 | np.save(filename, ipc)
255 | logging.info("Load Data Succeed...............")
256 |
257 |
258 | def save_to_dict(s, o_={}):
259 | sp = s.split(": ")
260 | o_.update({sp[0]: float(sp[1])})
261 | return o_
262 |
263 |
264 | def get_data_from_log(txt_path):
265 | """
266 | Output dictionary from out.txt log file
267 | """
268 | with open(txt_path) as f:
269 | lines = f.readlines()
270 | val_data = {}
271 | train_data = {}
272 | train_losses = []
273 | train_losses_epoch = []
274 | for i in range(len(lines)):
275 | if "| INFO |" in lines[i]:
276 | if "Eval Epoch" in lines[i]:
277 | if "val_loss" in lines[i]:
278 | # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
279 | line = lines[i].split("Eval Epoch: ")[-1]
280 | num_epoch = int(line.split(" ")[0].split(" ")[0])
281 | d = {
282 | line.split(" ")[0]
283 | .split(" ")[1]
284 | .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
285 | }
286 | for i in range(1, len(line.split(" "))):
287 | d = save_to_dict(line.split(" ")[i], d)
288 | val_data[num_epoch] = d
289 | elif "Train Epoch" in lines[i]:
290 | num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
291 | loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
292 | train_losses.append(loss)
293 | train_losses_epoch.append(num_epoch)
294 | for i in range(len(train_losses)):
295 | train_data[i] = {
296 | "num_epoch": train_losses_epoch[i],
297 | "train_loss": train_losses[i],
298 | }
299 | return train_data, val_data
300 |
301 |
302 | def save_p(obj, filename):
303 | import pickle
304 |
305 | try:
306 | from deepdiff import DeepDiff
307 | except:
308 | os.system("pip install deepdiff")
309 | from deepdiff import DeepDiff
310 | with open(filename, "wb") as file:
311 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
312 | with open(filename, "rb") as file:
313 | z = pickle.load(file)
314 | assert (
315 | DeepDiff(obj, z, ignore_string_case=True) == {}
316 | ), "there is something wrong with the saving process"
317 | return
318 |
319 |
320 | def load_p(filename):
321 | import pickle
322 |
323 | with open(filename, "rb") as file:
324 | z = pickle.load(file)
325 | return z
326 |
327 |
328 | def save_json(data, name="data.json"):
329 | import json
330 | with open(name, 'w') as fp:
331 | json.dump(data, fp)
332 | return
333 |
334 |
335 | def load_json(name):
336 | import json
337 | with open(name, 'r') as fp:
338 | data = json.load(fp)
339 | return data
340 |
341 |
342 | from multiprocessing import Process, Manager
343 | from multiprocessing import Process, Value, Array
344 | from ctypes import c_wchar
345 |
346 |
347 | def load_class_label(path):
348 | # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
349 | # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
350 | out = None
351 | if path is not None:
352 | if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
353 | out = load_p(path)
354 | elif pathlib.Path(path).suffix in [".json", ".txt"]:
355 | out = load_json(path)
356 | elif pathlib.Path(path).suffix in [".npy", ".npz"]:
357 | out = np.load(path)
358 | elif pathlib.Path(path).suffix in [".csv"]:
359 | import pandas as pd
360 | out = pd.read_csv(path)
361 | return out
362 | # if out is None:
363 | # return None
364 | # else:
365 | # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
366 | # val = Array('i', out.values(), lock=False)
367 | # return (key, val)
368 |
369 |
370 | from torch import optim
371 |
372 |
373 | def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
374 | if optimizer_name.lower() == "adamw":
375 | optimizer = optim.AdamW(
376 | params, lr=lr, betas=betas, eps=eps
377 | )
378 | elif optimizer_name.lower() == "sgd":
379 | optimizer = optim.SGD(
380 | params, lr=lr, momentum=momentum
381 | )
382 | elif optimizer_name.lower() == "adam":
383 | optimizer = optim.Adam(
384 | params, lr=lr, betas=betas, eps=eps
385 | )
386 | else:
387 | raise ValueError("optimizer name is not correct")
388 | return optimizer
389 |
--------------------------------------------------------------------------------
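
A small sketch of the mixup helpers above, assuming the vendored module (and its h5py dependency, imported at module top) is available.

import torch
from open_musiclm.laion_clap.clap_module.utils import get_mix_lambda, do_mixup

x = torch.randn(4, 480000)                        # batch of 10 s waveforms at 48 kHz
lam = torch.from_numpy(get_mix_lambda(0.5, 4))    # one Beta(0.5, 0.5) sample per example
mixed = do_mixup(x, lam)                          # each example blended with the batch reversed
print(mixed.shape)                                # torch.Size([4, 480000])
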
/open_musiclm/laion_clap/clap_module/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.2.1'
2 |
--------------------------------------------------------------------------------
/open_musiclm/laion_clap/hook.py:
--------------------------------------------------------------------------------
1 | """
2 | Contrastive Language-Audio Pretraining Model from LAION
3 | --------------------------------------------------------
4 | Paper: https://arxiv.org/abs/2211.06687
5 | Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui
6 | Support: LAION
7 | """
8 | import os
9 | import torch
10 | import torch.nn.functional as F
11 | import torchaudio
12 | import torchvision.transforms
13 | from contextlib import suppress
14 | import numpy as np
15 | from clap_module import create_model
16 |
17 | from transformers import RobertaTokenizer
18 | import wget
19 | from clap_module.factory import load_state_dict
20 |
21 |
22 | def int16_to_float32_torch(x):
23 | return (x / 32767.0).type(torch.float32)
24 |
25 |
26 | def float32_to_int16_torch(x):
27 | x = torch.clamp(x, min=-1., max=1.)
28 | return (x * 32767.).type(torch.int16)
29 |
30 | class CLAP_Module(torch.nn.Module):
31 | def __init__(self, enable_fusion=False, device=None, amodel= 'HTSAT-tiny', tmodel='roberta') -> None:
32 | """Initialize CLAP Model
33 |
34 | Parameters
35 | ----------
36 | enable_fusion: bool
37 | if true, it will create the fusion clap model, otherwise non-fusion clap model (default: false)
38 | device: str
39 | if None, it will automatically detect the device (gpu or cpu)
40 | amodel: str
41 | audio encoder architecture, default: HTSAT-tiny
42 | tmodel: str
43 | text encoder architecture, default: roberta
44 | """
45 | super(CLAP_Module, self).__init__()
46 | if device is None:
47 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
48 | precision = 'fp32'
49 |
50 | if enable_fusion:
51 | fusion_type = 'aff_2d'
52 | model, model_cfg = create_model(
53 | amodel,
54 | tmodel,
55 | precision=precision,
56 | device=device,
57 | enable_fusion=enable_fusion,
58 | fusion_type=fusion_type
59 | )
60 | else:
61 | model, model_cfg = create_model(
62 | amodel,
63 | tmodel,
64 | precision=precision,
65 | device=device,
66 | enable_fusion=enable_fusion
67 | )
68 | self.enable_fusion = enable_fusion
69 | self.model = model
70 | self.model_cfg = model_cfg
71 | self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
72 |
73 | audio_cfg = model_cfg['audio_cfg']
74 | self.mel_transform = torchaudio.transforms.MelSpectrogram(
75 | sample_rate=audio_cfg['sample_rate'],
76 | n_fft=audio_cfg['window_size'],
77 | win_length=audio_cfg['window_size'],
78 | hop_length=audio_cfg['hop_size'],
79 | center=True,
80 | pad_mode="reflect",
81 | power=2.0,
82 | norm=None,
83 | onesided=True,
84 | n_mels=audio_cfg['mel_bins'],
85 | f_min=audio_cfg['fmin'],
86 | f_max=audio_cfg['fmax']
87 | )
88 | self.log_mel_transform = torchaudio.transforms.AmplitudeToDB(top_db=None)
89 |
90 | def tokenizer(self, text):
91 | result = self.tokenize(
92 | text,
93 | padding="max_length",
94 | truncation=True,
95 | max_length=77,
96 | return_tensors="pt",
97 | )
98 | return result
99 |
100 | def load_ckpt(self, ckpt = None, model_id = -1):
101 | """Load the pretrained checkpoint of CLAP model
102 |
103 | Parameters
104 | ----------
105 | ckpt: str
106 |             if ckpt is specified, the model will load this ckpt, otherwise the model will download the ckpt from Hugging Face. \n
107 | For fusion model, it will download the 630k+audioset fusion model (id=3). For non-fusion model, it will download the 630k+audioset model (id=1).
108 | model_id:
109 |             if model_id is specified, the corresponding pretrained ckpt will be downloaded:
110 | id = 0 --> 630k non-fusion ckpt \n
111 | id = 1 --> 630k+audioset non-fusion ckpt \n
112 | id = 2 --> 630k fusion ckpt \n
113 | id = 3 --> 630k+audioset fusion ckpt \n
114 |             Note that if your model is specified as a non-fusion model but you download a fusion model ckpt, you will face an error.
115 | """
116 | download_link = 'https://huggingface.co/lukewys/laion_clap/resolve/main/'
117 | download_names = [
118 | '630k-best.pt',
119 | '630k-audioset-best.pt',
120 | '630k-fusion-best.pt',
121 | '630k-audioset-fusion-best.pt'
122 | ]
123 | if ckpt is not None:
124 | print(f'Load the specified checkpoint {ckpt} from users.')
125 | else:
126 | print(f'Load our best checkpoint in the paper.')
127 | if model_id == -1:
128 | model_id = 3 if self.enable_fusion else 1
129 | package_dir = os.path.dirname(os.path.realpath(__file__))
130 | weight_file_name = download_names[model_id]
131 | ckpt = os.path.join(package_dir, weight_file_name)
132 | if os.path.exists(ckpt):
133 | print(f'The checkpoint is already downloaded')
134 | else:
135 | print('Downloading laion_clap weight files...')
136 | ckpt = wget.download(download_link + weight_file_name, os.path.dirname(ckpt))
137 | print('Download completed!')
138 | print('Load Checkpoint...')
139 | ckpt = load_state_dict(ckpt, skip_params=True)
140 | self.model.load_state_dict(ckpt, strict=False)
141 | param_names = [n for n, p in self.model.named_parameters()]
142 | for n in param_names:
143 | print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
144 |
145 | def get_mel(self, audio_data):
146 | mel = self.mel_transform(audio_data)
147 | mel = self.log_mel_transform(mel)
148 | return mel.T # (T, n_mels)
149 |
150 | def get_audio_features(self, sample, audio_data, max_len, data_truncating, data_filling, audio_cfg, require_grad=False):
151 | """
152 | Calculate and add audio features to sample.
153 | Sample: a dict containing all the data of current sample.
154 | audio_data: a tensor of shape (T) containing audio data.
155 | max_len: the maximum length of audio data.
156 | data_truncating: the method of truncating data.
157 | data_filling: the method of filling data.
158 | audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
159 | require_grad: whether to require gradient for audio data.
160 | This is useful when we want to apply gradient-based classifier-guidance.
161 | """
162 | grad_fn = suppress if require_grad else torch.no_grad
163 | with grad_fn():
164 | if len(audio_data) > max_len:
165 | if data_truncating == "rand_trunc":
166 | longer = torch.tensor([True])
167 | elif data_truncating == "fusion":
168 | # fusion
169 | mel = self.get_mel(audio_data)
170 | # split to three parts
171 | chunk_frames = max_len // audio_cfg['hop_size'] + 1 # the +1 related to how the spectrogram is computed
172 | total_frames = mel.shape[0]
173 | if chunk_frames == total_frames:
174 | # there is a corner case where the audio length is
175 | # larger than max_len but smaller than max_len+hop_size.
176 | # In this case, we just use the whole audio.
177 | mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
178 | sample["mel_fusion"] = mel_fusion
179 | longer = torch.tensor([False])
180 | else:
181 | ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
182 | # print('total_frames-chunk_frames:', total_frames-chunk_frames,
183 | # 'len(audio_data):', len(audio_data),
184 | # 'chunk_frames:', chunk_frames,
185 | # 'total_frames:', total_frames)
186 | if len(ranges[1]) == 0:
187 | # if the audio is too short, we just use the first chunk
188 | ranges[1] = [0]
189 | if len(ranges[2]) == 0:
190 | # if the audio is too short, we just use the first chunk
191 | ranges[2] = [0]
192 | # randomly choose index for each part
193 | idx_front = np.random.choice(ranges[0])
194 | idx_middle = np.random.choice(ranges[1])
195 | idx_back = np.random.choice(ranges[2])
196 | # select mel
197 | mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
198 | mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
199 | mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
200 |
201 | # shrink the mel
202 | mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, audio_cfg['mel_bins']])(mel[None])[0]
203 | # logging.info(f"mel_shrink.shape: {mel_shrink.shape}")
204 |
205 | # stack
206 | mel_fusion = torch.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
207 | sample["mel_fusion"] = mel_fusion
208 | longer = torch.tensor([True])
209 | else:
210 | raise NotImplementedError(
211 | f"data_truncating {data_truncating} not implemented"
212 | )
213 | # random crop to max_len (for compatibility)
214 | overflow = len(audio_data) - max_len
215 | idx = np.random.randint(0, overflow + 1)
216 | audio_data = audio_data[idx: idx + max_len]
217 |
218 | else: # padding if too short
219 | if len(audio_data) < max_len: # do nothing if equal
220 | if data_filling == "repeatpad":
221 | n_repeat = int(max_len / len(audio_data))
222 | audio_data = audio_data.repeat(n_repeat)
223 | # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0)
224 | # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0]
225 | audio_data = F.pad(
226 | audio_data,
227 | (0, max_len - len(audio_data)),
228 | mode="constant",
229 | value=0,
230 | )
231 | elif data_filling == "pad":
232 | audio_data = F.pad(
233 | audio_data,
234 | (0, max_len - len(audio_data)),
235 | mode="constant",
236 | value=0,
237 | )
238 | elif data_filling == "repeat":
239 | n_repeat = int(max_len / len(audio_data))
240 | audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
241 | else:
242 | raise NotImplementedError(
243 | f"data_filling {data_filling} not implemented"
244 | )
245 | if data_truncating == 'fusion':
246 | mel = self.get_mel(audio_data)
247 | mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
248 | sample["mel_fusion"] = mel_fusion
249 | longer = torch.tensor([False])
250 |
251 | sample["longer"] = longer
252 | sample["waveform"] = audio_data
253 |
254 | return sample
255 |
256 | def get_audio_embedding_from_data(self, x):
257 | """get audio embeddings from the audio data
258 |
259 | Parameters
260 | ----------
261 | x: torch.Tensor (N,T):
262 | audio data, must be mono audio tracks.
263 | Returns
264 | ----------
265 | audio embed: torch.Tensor (N,D):
266 |             audio embeddings extracted from the audio data
267 | """
268 | self.model.eval()
269 | audio_input = []
270 | for audio_waveform in x:
271 | # quantize
272 | audio_waveform = int16_to_float32_torch(float32_to_int16_torch(audio_waveform))
273 | temp_dict = {}
274 |             # the 'fusion' truncate mode can be changed to 'rand_trunc' when running in non-fusion mode
275 | temp_dict = self.get_audio_features(
276 | temp_dict, audio_waveform, 480000,
277 | data_truncating='fusion' if self.enable_fusion else 'rand_trunc',
278 | data_filling='repeatpad',
279 | audio_cfg=self.model_cfg['audio_cfg'],
280 | require_grad=audio_waveform.requires_grad
281 | )
282 | audio_input.append(temp_dict)
283 | audio_embed = self.model.get_audio_embedding(audio_input)
284 | return audio_embed
285 |
286 | def get_text_embedding(self, x, tokenizer = None):
287 | """get text embeddings from texts
288 |
289 | Parameters
290 | ----------
291 | x: List[str] (N,):
292 | text list
293 | tokenizer: func:
294 | the tokenizer function, if not provided (None), will use the default Roberta tokenizer.
295 |
296 | Returns
297 | ----------
298 | text_embed : torch.Tensor (N,D):
299 |             text embeddings extracted from the texts
300 | """
301 | self.model.eval()
302 | if tokenizer is not None:
303 | text_input = tokenizer(x)
304 | else:
305 | text_input = self.tokenizer(x)
306 | text_embed = self.model.get_text_embedding(text_input)
307 | text_embed = text_embed
308 | return text_embed
309 |
310 |
311 |
--------------------------------------------------------------------------------
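A minimal usage sketch of the CLAP_Module wrapper above. The import path, the 48 kHz sample rate and the 10 second clip length are assumptions for illustration (they match the default HTSAT audio config and the 480000-sample max_len used in get_audio_embedding_from_data); the wrapper expects mono waveforms already at the model's sample rate:

    import torch
    from open_musiclm.laion_clap.hook import CLAP_Module  # import path is an assumption

    clap = CLAP_Module(enable_fusion=False)  # non-fusion HTSAT-tiny audio tower + roberta text tower
    clap.load_ckpt(model_id=1)               # downloads the 630k+audioset non-fusion ckpt if not present

    text_embed = clap.get_text_embedding(['a calm piano melody', 'an energetic techno track'])  # (2, D)

    # two mono clips; 48 kHz * 10 s = 480000 samples, matching max_len above (values are illustrative)
    audio = torch.rand(2, 48000 * 10) * 2 - 1
    audio_embed = clap.get_audio_embedding_from_data(audio)  # (2, D)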
/open_musiclm/model_types.py:
--------------------------------------------------------------------------------
1 | from beartype.typing import Union
2 |
3 | from .hf_hubert_kmeans import HfHubertWithKmeans
4 | from .encodec_wrapper import EncodecWrapper
5 |
6 | Wav2Vec = HfHubertWithKmeans
7 | NeuralCodec = EncodecWrapper
8 |
--------------------------------------------------------------------------------
/open_musiclm/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch.optim import AdamW, Adam, lr_scheduler
2 |
3 | def separate_weight_decayable_params(params):
4 | wd_params, no_wd_params = [], []
5 | for param in params:
6 | param_list = no_wd_params if param.ndim < 2 else wd_params
7 | param_list.append(param)
8 | return wd_params, no_wd_params
9 |
10 | def get_optimizer(
11 | params,
12 | lr = 1e-4,
13 | wd = 1e-2,
14 | betas = (0.9, 0.99),
15 | eps = 1e-8,
16 | filter_by_requires_grad = False,
17 | group_wd_params = True,
18 | **kwargs
19 | ):
20 | if filter_by_requires_grad:
21 | params = list(filter(lambda t: t.requires_grad, params))
22 |
23 | if wd == 0:
24 | return Adam(params, lr = lr, betas = betas, eps = eps)
25 |
26 | if group_wd_params:
27 | wd_params, no_wd_params = separate_weight_decayable_params(params)
28 |
29 | params = [
30 | {'params': wd_params},
31 | {'params': no_wd_params, 'weight_decay': 0},
32 | ]
33 |
34 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
35 |
36 | def get_linear_scheduler(
37 | optimizer,
38 | total_iters=10000,
39 | start_factor=1e-7,
40 | ):
41 | return lr_scheduler.LinearLR(optimizer=optimizer, start_factor=start_factor, end_factor=1., total_iters=total_iters)
--------------------------------------------------------------------------------
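A minimal sketch of how the optimizer and scheduler helpers above might be wired together; the toy model and step count are illustrative:

    import torch
    from torch import nn
    from open_musiclm.optimizer import get_optimizer, get_linear_scheduler

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))  # placeholder model

    # parameters with ndim < 2 (biases, norm scales) are grouped without weight decay
    optimizer = get_optimizer(model.parameters(), lr=1e-4, wd=1e-2)
    scheduler = get_linear_scheduler(optimizer, total_iters=10000)  # linear ramp of the lr

    for _ in range(3):  # dummy training steps
        loss = model(torch.randn(8, 16)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()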
/open_musiclm/preprocess.py:
--------------------------------------------------------------------------------
1 | import io
2 | import itertools
3 | import math
4 | import time
5 | from dataclasses import dataclass
6 | from pathlib import Path
7 | from shutil import rmtree
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | import torchaudio
13 | from accelerate import Accelerator, DistributedType
14 | from beartype.door import is_bearable
15 | from beartype.typing import Dict, List, Literal, Optional, Union
16 | from beartype.vale import Is
17 | from einops import rearrange, reduce, repeat
18 | from einops.layers.torch import Rearrange
19 | from torch import einsum, nn
20 | from torch.utils.data import DataLoader, Dataset, random_split
21 | from tqdm import tqdm
22 | from typing_extensions import Annotated
23 |
24 | from .clap_quantized import ClapQuantized
25 | from .data import (SoundDatasetForPreprocessing,
26 | get_sound_preprocessing_dataloader, init_sqlite)
27 | from .hf_hubert_kmeans import HfHubertWithKmeans, learn_kmeans
28 | from .model_types import NeuralCodec, Wav2Vec
29 | from .open_musiclm import (get_or_compute_acoustic_token_ids,
30 | get_or_compute_clap_token_ids,
31 | get_or_compute_semantic_token_ids)
32 | from .optimizer import get_linear_scheduler, get_optimizer
33 | from .utils import (all_rows_have_eos_id, append_eos_id,
34 | batch_unique_consecutive, beartype_jit, ceil_div,
35 | copy_file_to_folder, default, eval_decorator, exists,
36 | generate_mask_with_prob, get_embeds, gumbel_sample,
37 | mask_out_after_eos_id, round_down_nearest_multiple, top_k)
38 |
39 |
40 | def cycle(dl):
41 | while True:
42 | for data in dl:
43 | yield data
44 |
45 |
46 | def yes_or_no(question):
47 | answer = input(f'{question} (y/n) ')
48 | return answer.lower() in ('yes', 'y')
49 |
50 |
51 | # auto data to module keyword argument routing functions
52 |
53 | def has_duplicates(tup):
54 | counts = dict()
55 | for el in tup:
56 | if el not in counts:
57 | counts[el] = 0
58 | counts[el] += 1
59 | return any(filter(lambda count: count > 1, counts.values()))
60 |
61 |
62 | def determine_types(data, config):
63 | output = []
64 | for el in data:
65 | for name, data_type in config.items():
66 | if is_bearable(el, data_type):
67 | output.append(name)
68 | break
69 | else:
70 | raise TypeError(f'unable to determine type of {data}')
71 |
72 | return tuple(output)
73 |
74 |
75 | def noop(*args, **kwargs):
76 | pass
77 |
78 | def without_none(arr):
79 | return list(filter(lambda x: x is not None, arr))
80 |
81 | @beartype_jit
82 | class DataPreprocessor(nn.Module):
83 | """
84 | Class to preprocess audio files for the single stage transformer trainer.
85 |
86 | Load audio and compute:
87 | 1) clap tokens for the entire audio file, computed for 10 second sliding windows with 1 second interval
88 | 2) semantic tokens for the entire audio file
89 | 3) coarse+fine tokens for the entire audio file
90 | Run this once over the dataset and then use the preprocessed data for training.
91 | """
92 |
93 | def __init__(
94 | self,
95 | *,
96 | num_coarse_quantizers=3,
97 | wav2vec: Optional[Wav2Vec] = None,
98 | neural_codec: Optional[NeuralCodec] = None,
99 | audio_conditioner: Optional[ClapQuantized] = None,
100 | max_audio_length_seconds=180,
101 | random_crop=True,
102 | clap_audio_length_seconds=10,
103 | semantic_audio_length_seconds=10,
104 | clap_batch_size=32,
105 | num_crops=1,
106 | ignore_files: Optional[List[str]]=None,
107 | ignore_load_errors=True,
108 | replace_existing=False,
109 | folder=None,
110 | results_folder='./data/fma_preprocessed',
111 | accelerate_kwargs: dict = {},
112 | config_paths: Optional[List[str]] = None,
113 | **kwargs,
114 | ):
115 | super().__init__()
116 |
117 | self.accelerator = Accelerator(**accelerate_kwargs)
118 |
119 | self.wav2vec = wav2vec
120 | self.audio_conditioner = audio_conditioner
121 | self.neural_codec = neural_codec
122 | self.num_coarse_quantizers = num_coarse_quantizers
123 | self.max_audio_length_seconds = max_audio_length_seconds
124 | self.clap_audio_length_seconds = int(clap_audio_length_seconds)
125 | self.semantic_audio_length_seconds = int(semantic_audio_length_seconds)
126 | # TODO: allow a smaller clap length than semantic length, and average the clap embeddings over the time period as in the paper
127 | assert self.clap_audio_length_seconds == self.semantic_audio_length_seconds, 'clap window must be equal to semantic window for now'
128 | self.clap_batch_size = clap_batch_size
129 | self.num_crops = num_crops
130 | self.replace_existing = replace_existing
131 |
132 | self.register_buffer('steps', torch.Tensor([0]))
133 |
134 | # create dataset
135 |
136 | assert exists(wav2vec) and exists(audio_conditioner) and exists(neural_codec)
137 |
138 | self.ds_fields = ('raw_wave_for_clap', 'raw_wave_for_semantic', 'raw_wave_for_acoustic')
139 |
140 | target_sample_hz = (audio_conditioner.sample_rate, wav2vec.target_sample_hz, neural_codec.sample_rate)
141 |
142 | normalize = (False, True, False)
143 |
144 | seq_len_multiple_of = (None, wav2vec.seq_len_multiple_of, None)
145 |
146 | data_max_length_seconds = (max_audio_length_seconds, max_audio_length_seconds, max_audio_length_seconds)
147 |
148 | assert exists(folder), 'audio folder must be passed in for preprocessing'
149 |
150 | self.ds = SoundDatasetForPreprocessing(
151 | folder,
152 | pad_to_seconds=self.semantic_audio_length_seconds,
153 | max_length_seconds=data_max_length_seconds,
154 | random_crop=random_crop,
155 | normalize=normalize,
156 | target_sample_hz=target_sample_hz,
157 | seq_len_multiple_of=seq_len_multiple_of,
158 | ignore_load_errors=ignore_load_errors,
159 | ignore_files=ignore_files,
160 | )
161 |
162 | # dataloader
163 |
164 | self.dl = get_sound_preprocessing_dataloader(self.ds, batch_size=1, shuffle=False)
165 |
166 | # prepare
167 |
168 | (
169 | self.audio_conditioner,
170 | self.wav2vec,
171 | self.neural_codec,
172 | self.dl
173 | ) = self.accelerator.prepare(
174 | self.audio_conditioner,
175 | self.wav2vec,
176 | self.neural_codec,
177 | self.dl
178 | )
179 |
180 | # dataloader iterators
181 |
182 | self.dl_iter = cycle(self.dl)
183 |
184 | self.results_folder = Path(results_folder)
185 |
186 | if self.is_main:
187 | if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
188 | rmtree(str(self.results_folder))
189 |
190 | self.results_folder.mkdir(parents=True, exist_ok=True)
191 |
192 | if self.is_main and exists(config_paths):
193 | configs_folder = self.results_folder / "configs"
194 | configs_folder.mkdir(parents=True, exist_ok=True)
195 | for config_path in config_paths:
196 | copy_file_to_folder(config_path, configs_folder)
197 |
198 | if self.is_main:
199 | self.conn, self.cursor = init_sqlite(str(self.results_folder / 'preprocessed.db'))
200 | self.cursor.execute("CREATE TABLE IF NOT EXISTS tokens(idx integer primary key, path text, clap array, semantic array, coarse array, fine array)")
201 |
202 | self.accelerator.wait_for_everyone()
203 |
204 | if not self.is_main:
205 | self.conn, self.cursor = init_sqlite(str(self.results_folder / 'preprocessed.db'))
206 |
207 | def print(self, msg):
208 | self.accelerator.print(msg)
209 |
210 | @property
211 | def device(self):
212 | return self.accelerator.device
213 |
214 | @property
215 | def is_distributed(self):
216 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
217 |
218 | @property
219 | def is_main(self):
220 | return self.accelerator.is_main_process
221 |
222 | @property
223 | def is_local_main(self):
224 | return self.accelerator.is_local_main_process
225 |
226 | @property
227 | def device(self):
228 | return next(self.parameters()).device
229 |
230 | def generate_tokens_from_batch(self, raw_wave_for_clap, raw_wave_for_semantic, raw_wave_for_acoustic):
231 |         # split the clap waveform into clap_audio_length_seconds sliding windows with a 1 second interval. sample rate is self.audio_conditioner.sample_rate
232 | clap_split = raw_wave_for_clap.unfold(
233 | -1,
234 | self.audio_conditioner.sample_rate * self.clap_audio_length_seconds,
235 | self.audio_conditioner.sample_rate
236 | ).squeeze(0)
237 |
238 | batch_size = self.clap_batch_size
239 | clap_token_ids_all = []
240 | for i in range(0, clap_split.shape[0], batch_size):
241 |
242 | batch = clap_split[i:i+batch_size, :]
243 | clap_token_ids = get_or_compute_clap_token_ids(None, self.accelerator.unwrap_model(self.audio_conditioner), batch, None)
244 | clap_token_ids_all.append(clap_token_ids)
245 |
246 | clap_token_ids = torch.cat(clap_token_ids_all, dim=0)
247 |
248 | semantic_token_ids = get_or_compute_semantic_token_ids(None, raw_wave_for_semantic, self.accelerator.unwrap_model(self.wav2vec))
249 |
250 | coarse_token_ids, fine_token_ids = get_or_compute_acoustic_token_ids(None, None, raw_wave_for_acoustic, self.accelerator.unwrap_model(self.neural_codec), self.num_coarse_quantizers)
251 |
252 | return clap_token_ids, semantic_token_ids, (coarse_token_ids, fine_token_ids)
253 |
254 | def process(self, log_fn=noop):
255 | iters = math.ceil(self.num_crops * len(self.ds) / self.accelerator.num_processes)
256 | for idx in tqdm(range(iters), desc='processing data', mininterval=5):
257 | inputs = next(self.dl_iter)
258 | if exists(inputs):
259 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index
260 | if not self.replace_existing:
261 | self.cursor.execute("SELECT * FROM tokens WHERE idx=?", (idx,))
262 | if len(self.cursor.fetchall()) > 0:
263 | continue
264 |
265 | data_kwargs = dict(zip(self.ds_fields, inputs['data']))
266 | clap_token_ids, semantic_token_ids, (coarse_token_ids, fine_token_ids) = self.generate_tokens_from_batch(**data_kwargs)
267 |
268 | clap_token_ids = clap_token_ids.detach().cpu().numpy()
269 | semantic_token_ids = semantic_token_ids.detach().cpu().numpy()
270 | coarse_token_ids = coarse_token_ids.detach().cpu().numpy()
271 | fine_token_ids = fine_token_ids.detach().cpu().numpy()
272 |
273 |                 # convert to uint16 to save space
274 | clap_token_ids = clap_token_ids.astype(np.uint16)
275 | semantic_token_ids = semantic_token_ids.astype(np.uint16)
276 | coarse_token_ids = coarse_token_ids.astype(np.uint16)
277 | fine_token_ids = fine_token_ids.astype(np.uint16)
278 | # add tokens to sqlite db
279 | self.cursor.execute("INSERT INTO tokens VALUES (?, ?, ?, ?, ?, ?)", (idx, inputs['file_path'][0], clap_token_ids, semantic_token_ids, coarse_token_ids, fine_token_ids))
280 | self.conn.commit()
281 |
282 | self.steps += 1
283 |
284 | self.print('processing complete')
285 |
--------------------------------------------------------------------------------
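A hedged sketch of driving DataPreprocessor directly. In practice scripts/preprocess_data.py is the usual entry point (driven by a training config such as train_fma_preprocess.json); the config helpers below are the ones imported by the inference scripts in this repo, and the checkpoint/data paths are the defaults used elsewhere, not guaranteed to exist:

    import torch
    from open_musiclm.config import (create_clap_quantized_from_config,
                                     create_encodec_from_config,
                                     create_hubert_kmeans_from_config,
                                     load_model_config)
    from open_musiclm.preprocess import DataPreprocessor

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_config = load_model_config('./configs/model/musiclm_small.json')

    clap = create_clap_quantized_from_config(model_config, './checkpoints/clap.rvq.350.pt', device)
    wav2vec = create_hubert_kmeans_from_config(model_config, './results/hubert_kmeans/kmeans.joblib', device)
    encodec = create_encodec_from_config(model_config, device)

    preprocessor = DataPreprocessor(
        wav2vec=wav2vec,
        neural_codec=encodec,
        audio_conditioner=clap,
        folder='./data/fma_large',
        results_folder='./data/fma_preprocessed',
    )
    preprocessor.process()  # writes clap/semantic/coarse/fine tokens into preprocessed.db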
/open_musiclm/transformer.py:
--------------------------------------------------------------------------------
1 | # Transformer implementation from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py
2 |
3 | from functools import partial
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from einops import rearrange, repeat
8 | from torch import einsum, nn
9 | import math
10 |
11 | from .utils import default, exists, grad_shrink, l2norm
12 |
13 | try:
14 | import xformers.ops as xops
15 |
16 | is_xformers_available = True
17 | except ImportError:
18 | is_xformers_available = False
19 |
20 | # bias-less layernorm, being used in more recent T5s, PaLM, also in @borisdayma 's experiments shared with me
21 | # greater stability
22 |
23 |
24 | class LayerNorm(nn.Module):
25 | def __init__(self, dim):
26 | super().__init__()
27 | self.gamma = nn.Parameter(torch.ones(dim))
28 | self.register_buffer("beta", torch.zeros(dim))
29 |
30 | def forward(self, x):
31 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
32 |
33 | # relative positional bias
34 |
35 |
36 | class RelativePositionBias(nn.Module):
37 | """ from https://arxiv.org/abs/2111.09883 """
38 |
39 | def __init__(
40 | self,
41 | *,
42 | dim,
43 | heads,
44 | layers=3
45 | ):
46 | super().__init__()
47 | self.net = nn.ModuleList([])
48 | self.net.append(nn.Sequential(nn.Linear(1, dim), nn.SiLU()))
49 |
50 | for _ in range(layers - 1):
51 | self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))
52 |
53 | self.net.append(nn.Linear(dim, heads))
54 |
55 | def forward(self, n, device=torch.device('cpu')):
56 | pos = torch.arange(n, device=device)
57 | rel_pos = (rearrange(pos, 'i -> i 1') - rearrange(pos, 'j -> 1 j'))
58 | rel_pos += (n - 1)
59 |
60 | x = torch.arange(-n + 1, n, device=device).float()
61 | x = rearrange(x, '... -> ... 1')
62 |
63 | for layer in self.net:
64 | x = layer(x)
65 |
66 | x = x[rel_pos]
67 | return rearrange(x, 'i j h -> h i j')
68 |
69 | class T5RelativePositionBias(nn.Module):
70 | def __init__(
71 | self,
72 | *,
73 | heads,
74 | num_buckets=32,
75 | max_distance=128,
76 | causal=True
77 | ):
78 | super().__init__()
79 | self.num_buckets = num_buckets
80 | self.max_distance = max_distance
81 | self.causal = causal
82 |
83 | self.relative_attention_bias = nn.Embedding(num_buckets, heads)
84 |
85 | @staticmethod
86 | def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
87 | ret = 0
88 | n = -relative_position
89 | if not causal:
90 | num_buckets //= 2
91 | ret += (n < 0).long() * num_buckets
92 | n = torch.abs(n)
93 | else:
94 | n = torch.max(n, torch.zeros_like(n))
95 |
96 | max_exact = num_buckets // 2
97 | is_small = n < max_exact
98 |
99 | val_if_large = max_exact + (
100 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
101 | ).long()
102 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
103 |
104 | ret += torch.where(is_small, n, val_if_large)
105 | return ret
106 |
107 | def forward(
108 | self,
109 | n,
110 | device=torch.device('cpu')
111 | ):
112 | pos = torch.arange(n, device=device)
113 | rel_pos = (rearrange(pos, 'i -> i 1') - rearrange(pos, 'j -> 1 j'))
114 | rel_pos = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance)
115 |
116 | bias = self.relative_attention_bias(rel_pos)
117 | return rearrange(bias, 'i j h -> h i j')
118 |
119 | # feedforward
120 |
121 |
122 | class CausalDSConv(nn.Module):
123 | def __init__(self, dim):
124 | super().__init__()
125 | self.ds_conv = nn.Conv1d(dim, dim, 3, bias=False, groups=dim)
126 |
127 | def forward(self, x):
128 | x = rearrange(x, 'b n c -> b c n')
129 | x = F.pad(x, (2, 0))
130 | x = self.ds_conv(x)
131 | return rearrange(x, 'b c n -> b n c')
132 |
133 |
134 | class GEGLU(nn.Module):
135 | def forward(self, x):
136 | x, gate = x.chunk(2, dim=-1)
137 | return F.gelu(gate) * x
138 |
139 |
140 | def ConvFeedForward(dim, mult=4, dropout=0.1):
141 | inner_dim = int(dim * 2 * mult / 3)
142 | return nn.Sequential(
143 | LayerNorm(dim),
144 | nn.Linear(dim, inner_dim * 2, bias=False),
145 | CausalDSConv(inner_dim * 2),
146 | GEGLU(),
147 | LayerNorm(inner_dim),
148 | nn.Dropout(dropout),
149 | nn.Linear(inner_dim, dim, bias=False)
150 | )
151 |
152 | def FeedForward(dim, mult=4, dropout=0.1):
153 | inner_dim = int(dim * mult)
154 | return nn.Sequential(
155 | LayerNorm(dim),
156 | nn.Linear(dim, inner_dim * 2, bias=False),
157 | GEGLU(),
158 | LayerNorm(inner_dim),
159 | nn.Dropout(dropout),
160 | nn.Linear(inner_dim, dim, bias=False)
161 | )
162 |
163 | # attention
164 |
165 |
166 | class Attention(nn.Module):
167 | def __init__(
168 | self,
169 | dim,
170 | causal=False,
171 | non_causal_prefix=0,
172 | dim_head=64,
173 | dim_context=None,
174 | heads=8,
175 | norm_context=False,
176 | num_null_kv=0,
177 | dropout=0.1,
178 | scale=8,
179 | use_memory_efficient_attention=False
180 | ):
181 | super().__init__()
182 | self.heads = heads
183 | self.scale = scale
184 | self.causal = causal
185 | self.non_causal_prefix = non_causal_prefix
186 | self.dropout = dropout
187 | self.use_memory_efficient_attention = use_memory_efficient_attention
188 | if self.use_memory_efficient_attention and not is_xformers_available:
189 | raise ImportError("Please install xformers to use memory efficient attention")
190 |
191 | inner_dim = dim_head * heads
192 |
193 | dim_context = default(dim_context, dim)
194 |
195 | self.norm = LayerNorm(dim)
196 | self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity()
197 |
198 | self.attn_dropout = nn.Dropout(dropout) if not self.use_memory_efficient_attention else None
199 |
200 | self.num_null_kv = num_null_kv
201 | self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head)) if num_null_kv > 0 else None
202 |
203 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
204 | self.to_kv = nn.Linear(dim_context, dim_head * 2, bias=False)
205 |
206 | self.q_scale = nn.Parameter(torch.ones(dim_head))
207 | self.k_scale = nn.Parameter(torch.ones(dim_head))
208 |
209 | self.to_out = nn.Sequential(
210 | nn.Linear(inner_dim, dim, bias=False),
211 | nn.Dropout(dropout)
212 | )
213 |
214 | def forward(
215 | self,
216 | x,
217 | context=None,
218 | mask=None,
219 | attn_bias=None,
220 | prefix_context=None,
221 | prefix_context_mask=None
222 | ):
223 | b, n, _, device = *x.shape, x.device
224 |
225 | if exists(context):
226 | context = self.context_norm(context)
227 |
228 | kv_input = default(context, x)
229 |
230 | # take care of prefix-based self attention conditioning
231 |         # make sure to either concat the prefix mask to the self attention mask or lengthen it accordingly
232 |
233 | if exists(prefix_context):
234 | kv_input = torch.cat((prefix_context, kv_input), dim=-2)
235 | prefix_seq_len = prefix_context.shape[-2]
236 |
237 | if not exists(mask):
238 | mask = torch.ones((b, n), device=device, dtype=torch.bool)
239 |
240 | if exists(prefix_context_mask):
241 | mask = torch.cat((prefix_context_mask, mask), dim=-1)
242 | else:
243 | mask = F.pad(mask, (prefix_seq_len, 0), value=True)
244 |
245 | if exists(attn_bias):
246 | attn_bias = F.pad(attn_bias, (prefix_seq_len, 0), value=0.)
247 |
248 | # prenorm
249 |
250 | x = self.norm(x)
251 |
252 | # project for queries, keys, values
253 |
254 | q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)
255 |
256 | # null key / values
257 |
258 | if self.num_null_kv > 0:
259 | null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b=b).unbind(dim=0)
260 | k = torch.cat((null_k, k), dim=-2)
261 | v = torch.cat((null_v, v), dim=-2)
262 |
263 | # split for multi-headed attention
264 |
265 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
266 |
267 |         # new technique, rmsnormed queries and keys, first used successfully by a 22B parameter model https://arxiv.org/abs/2302.05442
268 |
269 | q, k = map(l2norm, (q, k))
270 | q = q * self.q_scale
271 | k = k * self.k_scale
272 |
273 | # attention
274 |
275 | if self.use_memory_efficient_attention:
276 | if exists(attn_bias):
277 | attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value=0.)
278 |
279 | if exists(mask):
280 | mask = F.pad(mask, (self.num_null_kv, 0), value=True)
281 | mask = rearrange(mask, 'b j -> b 1 1 j')
282 | attn_bias = attn_bias.masked_fill(~mask, -torch.finfo(attn_bias.dtype).max)
283 |
284 | if self.causal:
285 | i, j = attn_bias.shape[-2:]
286 | causal_mask = torch.ones((i, j), dtype=torch.bool, device=x.device).triu(j - i + 1)
287 |
288 | if self.non_causal_prefix > 0:
289 | causal_mask[:self.non_causal_prefix, :(self.non_causal_prefix + j - i)] = False
290 |
291 | attn_bias = attn_bias.masked_fill(causal_mask, -torch.finfo(attn_bias.dtype).max)
292 |
293 | q = rearrange(q, 'b h n d -> b n h d')
294 | k = repeat(k, 'b n d -> b n h d', h=self.heads)
295 | v = repeat(v, 'b n d -> b n h d', h=self.heads)
296 |
297 | # compute attention
298 | out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=self.dropout)
299 |
300 | # merge heads
301 | out = rearrange(out, 'b n h d -> b n (h d)')
302 |
303 | else:
304 | sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale
305 |
306 | if exists(attn_bias):
307 | attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value=0.)
308 | sim = sim + attn_bias
309 |
310 | if exists(mask):
311 | mask = F.pad(mask, (self.num_null_kv, 0), value=True)
312 | mask = rearrange(mask, 'b j -> b 1 1 j')
313 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
314 |
315 | if self.causal:
316 | i, j = sim.shape[-2:]
317 | causal_mask = torch.ones((i, j), dtype=torch.bool, device=x.device).triu(j - i + 1)
318 |
319 | if self.non_causal_prefix > 0:
320 | causal_mask[:self.non_causal_prefix, :(self.non_causal_prefix + j - i)] = False
321 |
322 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
323 |
324 | attn = sim.softmax(dim=-1)
325 | attn = self.attn_dropout(attn)
326 |
327 | # aggregate
328 | out = einsum('b h i j, b j d -> b h i d', attn, v)
329 |
330 | # merge heads
331 | out = rearrange(out, 'b h n d -> b n (h d)')
332 |
333 | return self.to_out(out)
334 |
335 | # transformer
336 |
337 |
338 | class Transformer(nn.Module):
339 | def __init__(
340 | self,
341 | *,
342 | dim,
343 | depth,
344 | heads,
345 | dim_context=None,
346 | cross_attend=False,
347 | attn_dropout=0.,
348 | ff_dropout=0.,
349 | use_conv_ff=True,
350 | grad_shrink_alpha=0.1,
351 | cond_as_self_attn_prefix=False,
352 | non_causal_prefix_size=0,
353 | relative_position_bias_type='continuous',
354 | **kwargs
355 | ):
356 | super().__init__()
357 | assert not (cross_attend and cond_as_self_attn_prefix)
358 | self.dim_context = default(dim_context, dim)
359 |
360 | self.cond_as_self_attn_prefix = cond_as_self_attn_prefix
361 |
362 | self.grad_shrink = partial(grad_shrink, alpha=grad_shrink_alpha)
363 |
364 | self.layers = nn.ModuleList([])
365 |
366 | if relative_position_bias_type == 'continuous':
367 | self.rel_pos_bias = RelativePositionBias(dim=dim // 2, heads=heads)
368 | elif relative_position_bias_type == 't5':
369 | self.rel_pos_bias = T5RelativePositionBias(heads=heads, num_buckets=32, max_distance=128)
370 | elif relative_position_bias_type == 'none':
371 | self.rel_pos_bias = None
372 | else:
373 | raise ValueError(f'invalid relative position bias type: {relative_position_bias_type}')
374 |
375 | for _ in range(depth):
376 | self.layers.append(nn.ModuleList([
377 | Attention(dim=dim, heads=heads, dropout=attn_dropout, causal=True, non_causal_prefix=non_causal_prefix_size, **kwargs),
378 | Attention(dim=dim, heads=heads, dropout=attn_dropout, dim_context=dim_context,
379 | num_null_kv=1, norm_context=True, **kwargs) if cross_attend else None,
380 | ConvFeedForward(dim=dim, dropout=ff_dropout) if use_conv_ff else FeedForward(dim=dim, dropout=ff_dropout),
381 | ]))
382 |
383 | self.norm = LayerNorm(dim)
384 |
385 | def forward(
386 | self,
387 | x,
388 | self_attn_mask=None,
389 | context=None,
390 | context_mask=None,
391 | attn_bias=None,
392 | ):
393 | assert not (self.cond_as_self_attn_prefix and not exists(context))
394 | assert not (exists(
395 | context) and context.shape[-1] != self.dim_context), f'you had specified a conditioning dimension of {self.dim_context}, yet what was received by the transformer has dimension of {context.shape[-1]}'
396 |
397 | n, device = x.shape[1], x.device
398 |
399 | # from cogview paper, adopted by GLM 130B LLM, decreases likelihood of attention net instability
400 | x = self.grad_shrink(x)
401 |
402 | if exists(attn_bias):
403 | rel_pos_bias = attn_bias
404 | else:
405 | rel_pos_bias = self.rel_pos_bias(n, device = device) if exists(self.rel_pos_bias) else None
406 |
407 | self_attn_kwargs = dict()
408 | if self.cond_as_self_attn_prefix:
409 | self_attn_kwargs = dict(
410 | prefix_context=context,
411 | prefix_context_mask=context_mask
412 | )
413 |
414 | for attn, cross_attn, ff in self.layers:
415 | x = attn(x, attn_bias=rel_pos_bias, mask=self_attn_mask, **self_attn_kwargs) + x
416 |
417 | if exists(cross_attn):
418 | assert exists(context)
419 |
420 | x = cross_attn(x, context=context, mask=context_mask) + x
421 |
422 | x = ff(x) + x
423 |
424 | return self.norm(x)
425 |
--------------------------------------------------------------------------------
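A small self-contained sketch of instantiating the Transformer above with cross attention; the dimensions, depth and sequence lengths are arbitrary illustration values:

    import torch
    from open_musiclm.transformer import Transformer

    transformer = Transformer(
        dim=256,
        depth=4,
        heads=4,
        dim_context=512,
        cross_attend=True,
    )

    x = torch.randn(2, 128, 256)       # (batch, seq, dim) token embeddings
    context = torch.randn(2, 16, 512)  # conditioning sequence attended to by the cross attention layers
    out = transformer(x, context=context)  # (2, 128, 256)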
/open_musiclm/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils.rnn import pad_sequence
5 | from beartype import beartype
6 | from pathlib import Path
7 | import shutil
8 | import os
9 | from torchaudio.functional import resample
10 |
11 | from einops import rearrange, repeat, reduce
12 |
13 | def beartype_jit(func):
14 | """decorator to enable beartype only if USE_BEARTYPE is set to 1"""
15 | return beartype(func) if os.environ.get('USE_BEARTYPE', '0') == '1' else func
16 |
17 | # helper functions
18 |
19 | def exists(val):
20 | return val is not None
21 |
22 | def default(val, d):
23 | return val if exists(val) else d
24 |
25 | def ceil_div(numer, denom):
26 | return (numer + denom - 1) // denom
27 |
28 | def remainder_needed_until_multiple(n, mult):
29 | return (ceil_div(n, mult) * mult) - n
30 |
31 | def round_down_nearest_multiple(val, mult):
32 | return (val // mult) * mult
33 |
34 | def curtail_to_multiple(t, mult):
35 | data_len = t.shape[-1]
36 | return t[..., :round_down_nearest_multiple(data_len, mult)]
37 |
38 | def eval_decorator(fn):
39 | def inner(model, *args, **kwargs):
40 | was_training = model.training
41 | model.eval()
42 | out = fn(model, *args, **kwargs)
43 | model.train(was_training)
44 | return out
45 | return inner
46 |
47 | # tensor helpers
48 |
49 | def generate_mask_with_prob(shape, mask_prob, device):
50 | seq = shape[-1]
51 | rand = torch.randn(shape, device = device)
52 | rand[:, 0] = -torch.finfo(rand.dtype).max
53 | num_mask = min(int(seq * mask_prob), seq - 1)
54 | indices = rand.topk(num_mask, dim = -1).indices
55 | mask = ~torch.zeros(shape, device = device).scatter(1, indices, 1.).bool()
56 | return mask
57 |
58 | # attention related utils
59 |
60 | def grad_shrink(t, alpha = 0.1):
61 | return t * alpha + t.detach() * (1 - alpha)
62 |
63 | # sampling helpers
64 |
65 | def log(t, eps = 1e-20):
66 | return torch.log(t + eps)
67 |
68 | def l2norm(t):
69 | return F.normalize(t, dim = -1)
70 |
71 | def gumbel_noise(t):
72 | noise = torch.zeros_like(t).uniform_(0, 1)
73 | return -log(-log(noise))
74 |
75 | def gumbel_sample(t, temperature = 1., dim = -1):
76 | return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
77 |
78 | def top_k(logits, thres = 0.5):
79 | num_logits = logits.shape[-1]
80 | k = max(int((1 - thres) * num_logits), 1)
81 | val, ind = torch.topk(logits, k)
82 | probs = torch.full_like(logits, float('-inf'))
83 | probs.scatter_(1, ind, val)
84 | return probs
85 |
86 | def mask_out_after_eos_id(t, eos_id, mask_value = -1, keep_eos = True):
87 | eos_mask = (t == eos_id).float()
88 |
89 | if keep_eos:
90 | eos_mask = F.pad(eos_mask, (1, -1))
91 |
92 | after_eos_mask = eos_mask.cumsum(dim = -1) > 0
93 | return t.masked_fill(after_eos_mask, mask_value)
94 |
95 | def all_rows_have_eos_id(t, eos_id):
96 | eos_mask = (t == eos_id)
97 | return torch.any(eos_mask, dim = -1).all()
98 |
99 | # classifier free guidance functions
100 |
101 | def prob_mask_like(shape, prob, device):
102 | if prob == 1:
103 | return torch.ones(shape, device = device, dtype = torch.bool)
104 | elif prob == 0:
105 | return torch.zeros(shape, device = device, dtype = torch.bool)
106 | else:
107 | return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
108 |
109 | # removing unique consecutives in the semantic token ids
110 | # important detail noted by @eonglints
111 |
112 | def append_eos_id(ids, eos_id):
113 | b, device = ids.shape[0], ids.device
114 | eos_ids = torch.ones(1, device = device).long() * eos_id
115 | eos_ids = repeat(eos_ids, '1 -> b 1', b = b)
116 | ids = torch.cat((ids, eos_ids), dim = -1)
117 | return ids
118 |
119 | def batch_unique_consecutive(t, pad_value = 0.):
120 | unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)]
121 | return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value)
122 |
123 | # to get embedding from sequence with padding token
124 |
125 | @beartype_jit
126 | def get_embeds(
127 | embeddings: nn.Embedding,
128 | codes: torch.Tensor,
129 | pad_id = -1,
130 | return_mask = False,
131 | mask_pad_pos_to = 0
132 | ):
133 | pad_mask = codes == pad_id
134 | codes_without_pad = codes.masked_fill(pad_mask, 0) # just retrieve first code as dummy
135 | embeds = embeddings(codes_without_pad)
136 |
137 | if exists(mask_pad_pos_to):
138 | embeds = embeds.masked_fill(rearrange(pad_mask, '... -> ... 1'), mask_pad_pos_to)
139 |
140 | if return_mask:
141 | return embeds, ~pad_mask
142 |
143 | return embeds
144 |
145 | # audio processing helpers
146 |
147 | def int16_to_float32(x):
148 | return (x / 32767.0).type(torch.float32)
149 |
150 | def float32_to_int16(x):
151 | x = torch.clamp(x, min=-1., max=1.)
152 | return (x * 32767.).type(torch.int16)
153 |
154 | def zero_mean_unit_var_norm(x):
155 | return (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(x.var(dim=-1, keepdim=True) + 1e-7)
156 |
157 | def prepare_audio(data, sample_hz, target_sample_hz, normalize=True, target_length_seconds=None):
158 | if data.shape[0] > 1:
159 | data = torch.mean(data, dim=0).unsqueeze(0)
160 | if normalize:
161 | data = zero_mean_unit_var_norm(data)
162 | if exists(target_length_seconds) and data.shape[1] > target_length_seconds * sample_hz:
163 | data = data[: , :int(target_length_seconds * sample_hz)]
164 | audio_for_wav2vec = resample(data, sample_hz, target_sample_hz)
165 | audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec))
166 | return audio_for_wav2vec
167 |
168 | # helper for saving config
169 |
170 | def copy_file_to_folder(file_path: str, folder_path: str):
171 | config_file = Path(file_path)
172 | folder = Path(folder_path)
173 |
174 | shutil.copy(str(config_file), str(folder / config_file.name))
--------------------------------------------------------------------------------
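A short sketch of the sampling helpers above, combined the way they would typically be used when decoding tokens; the vocabulary size, threshold and token ids are illustrative:

    import torch
    from open_musiclm.utils import top_k, gumbel_sample, mask_out_after_eos_id

    logits = torch.randn(4, 1024)                        # (batch, vocab) logits for the next token
    filtered = top_k(logits, thres=0.9)                  # keep roughly the top 10% of logits, rest -> -inf
    sampled = gumbel_sample(filtered, temperature=0.95)  # (4,) sampled token ids

    # replace everything after the eos token with mask_value in a finished sequence
    seq = torch.tensor([[5, 7, 2, 9, 9]])
    masked = mask_out_after_eos_id(seq, eos_id=2, mask_value=-1)  # tensor([[ 5,  7,  2, -1, -1]])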
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/download_checkpoints.sh:
--------------------------------------------------------------------------------
1 | # new clap checkpoint
2 | if [ -e ./checkpoints/music_speech_audioset_epoch_15_esc_89.98.pt ]
3 | then
4 | echo "clap checkpoint already downloaded"
5 | else
6 | wget -P ./checkpoints 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt'
7 | fi
8 |
--------------------------------------------------------------------------------
/scripts/download_fma_large.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | if [ -e ./data/fma_large.zip ]
4 | then
5 | echo "fma_large already downloaded"
6 | else
7 | echo "downloading fma_large.zip, may take a while..."
8 | wget -P ./data https://os.unil.cloud.switch.ch/fma/fma_large.zip
9 | fi
10 |
11 | if [ -e ./data/fma_large ]
12 | then
13 | echo "fma_large already unzipped"
14 | else
15 | echo "unzipping fma_large.zip, may take a while..."
16 | echo "497109f4dd721066b5ce5e5f250ec604dc78939e data/fma_large.zip" | sha1sum -c -
17 | cd data
18 | unzip fma_large.zip
19 | fi
--------------------------------------------------------------------------------
/scripts/download_fma_metadata.sh:
--------------------------------------------------------------------------------
1 | wget -P ./data https://os.unil.cloud.switch.ch/fma/fma_metadata.zip
2 | echo "f0df49ffe5f2a6008d7dc83c6915b31835dfe733 data/fma_metadata.zip" | sha1sum -c -
3 | cd data
4 | unzip fma_metadata.zip
--------------------------------------------------------------------------------
/scripts/infer.py:
--------------------------------------------------------------------------------
1 | '''
2 | example usage:
3 |
4 | python3 scripts/infer.py \
5 | --semantic_path ./results/semantic/semantic.transformer.10000.pt \
6 | --coarse_path ./results/coarse/coarse.transformer.10000.pt \
7 | --fine_path ./results/fine/fine.transformer.10000.pt \
8 | --model_config ./configs/model/musiclm_small.json \
9 | --return_coarse_wave
10 | '''
11 |
12 | import os
13 | import sys
14 |
15 | import torch
16 | import torchaudio
17 | from einops import rearrange
18 | import argparse
19 |
20 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
21 |
22 | from open_musiclm.config import load_model_config, create_musiclm_from_config
23 |
24 | prompts = [
25 | [
26 | 'The main soundtrack of an arcade game. It is fast-paced and upbeat, with a catchy electric guitar riff. The music is repetitive and easy to remember, but with unexpected sounds, like cymbal crashes or drum rolls.',
27 | 'A fusion of reggaeton and electronic dance music, with a spacey, otherworldly sound. Induces the experience of being lost in space, and the music would be designed to evoke a sense of wonder and awe, while being danceable.',
28 | 'A rising synth is playing an arpeggio with a lot of reverb. It is backed by pads, sub bass line and soft drums. This song is full of synth sounds creating a soothing and adventurous atmosphere. It may be playing at a festival during two songs for a buildup.',
29 | 'Slow tempo, bass-and-drums-led reggae song. Sustained electric guitar. High-pitched bongos with ringing tones. Vocals are relaxed with a laid-back feel, very expressive.',
30 | ],
31 | ['song with synths and flute', 'crowd cheering', 'piano sonata waltz, glittery', 'house song, 4 on the floor, rhythm'],
32 | ['chirping of birds and the distant echos of bells', 'cat meowing', 'saxophone with drums', 'beethoven piano sonata']
33 | ]
34 |
35 | if __name__ == '__main__':
36 | parser = argparse.ArgumentParser(description='run inference on trained musiclm model')
37 |
38 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json', help='path to model config')
39 | parser.add_argument('--semantic_path', required=True, help='path to semantic stage checkpoint')
40 | parser.add_argument('--coarse_path', required=True, help='path to coarse stage checkpoint')
41 | parser.add_argument('--fine_path', required=True, help='path to fine stage checkpoint')
42 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt')
43 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib')
44 | parser.add_argument('--return_coarse_wave', default=False, action=argparse.BooleanOptionalAction)
45 | parser.add_argument('--duration', default=4, type=float, help='duration of audio to generate in seconds')
46 | parser.add_argument('--seed', default=0)
47 |
48 | args = parser.parse_args()
49 |
50 | model_config = load_model_config(args.model_config)
51 |
52 | semantic_path = args.semantic_path
53 | coarse_path = args.coarse_path
54 | fine_path = args.fine_path
55 | return_coarse_wave = args.return_coarse_wave
56 | duration = args.duration
57 | kmeans_path = args.kmeans_path
58 | rvq_path = args.rvq_path
59 | seed = args.seed
60 |
61 | print(f'semantic checkpoint {semantic_path}, coarse checkpoint {coarse_path}, fine checkpoint {fine_path}')
62 | print(f'kmeans path {kmeans_path}, rvq path {rvq_path}, return_coarse_wave {return_coarse_wave}')
63 |
64 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
65 |
66 | musiclm = create_musiclm_from_config(
67 | model_config=model_config,
68 | semantic_path=semantic_path,
69 | coarse_path=coarse_path,
70 | fine_path=fine_path,
71 | rvq_path=rvq_path,
72 | kmeans_path=kmeans_path,
73 | device=device)
74 |
75 | torch.manual_seed(seed)
76 |
77 | print('generating...')
78 |
79 | for prompt in prompts:
80 | generated_wave = musiclm.forward(
81 | text=prompt,
82 | output_seconds=duration,
83 | semantic_window_seconds=model_config.global_cfg.semantic_audio_length_seconds,
84 | coarse_window_seconds=model_config.global_cfg.coarse_audio_length_seconds,
85 | fine_window_seconds=model_config.global_cfg.fine_audio_length_seconds,
86 | semantic_steps_per_second=model_config.hubert_kmeans_cfg.output_hz,
87 | acoustic_steps_per_second=model_config.encodec_cfg.output_hz,
88 | return_coarse_generated_wave=return_coarse_wave,
89 | ).detach().cpu()
90 |
91 | print(generated_wave.shape)
92 |
93 | generated_wave = rearrange(generated_wave, 'b n -> b 1 n')
94 | for i, wave in enumerate(generated_wave):
95 | torchaudio.save(f'results/{prompt[i][:25]}_generated.wav', wave, musiclm.neural_codec.sample_rate)
96 |
97 |
--------------------------------------------------------------------------------
/scripts/infer_coarse.py:
--------------------------------------------------------------------------------
1 | '''
2 | 1) load audio samples
3 | 2) compute clap tokens and semantic tokens
4 | 3) run them through coarse stage to predict coarse tokens
5 | 4) reconstruct audio from coarse tokens
6 | Reconstructed audio should be semantically similar to the original audio if hubert-kmeans and coarse stage are working correctly
7 |
8 | example usage:
9 |
10 | python scripts/infer_coarse.py \
11 | ./data/fma_large/000/000005.mp3 \
12 | ./data/fma_large/000/000010.mp3 \
13 | --model_config ./configs/model/musiclm_small.json \
14 | --coarse_path ./results/coarse_continue_1/coarse.transformer.10000.pt
15 |
16 | '''
17 |
18 | import argparse
19 | import os
20 | import sys
21 | from pathlib import Path
22 |
23 | import torch
24 | import torchaudio
25 | from einops import rearrange
26 | from torchaudio.functional import resample
27 |
28 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
29 |
30 | from open_musiclm.config import (create_clap_quantized_from_config,
31 | create_coarse_transformer_from_config,
32 | create_encodec_from_config,
33 | create_hubert_kmeans_from_config,
34 | load_model_config)
35 | from open_musiclm.open_musiclm import (CoarseStage,
36 | get_or_compute_clap_token_ids,
37 | get_or_compute_semantic_token_ids)
38 | from open_musiclm.utils import int16_to_float32, float32_to_int16, zero_mean_unit_var_norm
39 | from scripts.train_utils import disable_print
40 |
41 | if __name__ == '__main__':
42 | parser = argparse.ArgumentParser(description='run inference on coarse stage')
43 | parser.add_argument('audio_files', type=str, nargs='+')
44 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json', help='path to model config')
45 | parser.add_argument('--coarse_path', required=True, help='path to coarse stage checkpoint')
46 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt')
47 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib')
48 | parser.add_argument('--seed', default=0)
49 |
50 | args = parser.parse_args()
51 |
52 | model_config = load_model_config(args.model_config)
53 |
54 | audio_files = args.audio_files
55 | coarse_path = args.coarse_path
56 | kmeans_path = args.kmeans_path
57 | rvq_path = args.rvq_path
58 | seed = args.seed
59 |
60 | print(f'running inference on {audio_files}')
61 | print(f'coarse_path: {coarse_path}, kmeans_path: {kmeans_path}, rvq_path: {rvq_path}, seed: {seed}')
62 |
63 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
64 |
65 | print('loading clap...')
66 | clap = create_clap_quantized_from_config(model_config, args.rvq_path, device)
67 |
68 | print('loading wav2vec...')
69 | wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device)
70 |
71 | print('loading encodec...')
72 | encodec_wrapper = create_encodec_from_config(model_config, device)
73 |
74 | print('loading coarse stage...')
75 | coarse_transformer = create_coarse_transformer_from_config(model_config, coarse_path, device)
76 |
77 | coarse_stage = CoarseStage(
78 | coarse_transformer=coarse_transformer,
79 | neural_codec=encodec_wrapper,
80 | wav2vec=wav2vec,
81 | clap=clap
82 | )
83 |
84 | torch.manual_seed(args.seed)
85 |
86 | print('loading audio from dataset')
87 |
88 | audios_for_clap = []
89 | audios_for_wav2vec = []
90 | for audio_path in audio_files:
91 | data, sample_hz = torchaudio.load(audio_path)
92 |
93 | if data.shape[0] > 1:
94 | data = torch.mean(data, dim=0).unsqueeze(0)
95 |
96 | target_length = int(4 * sample_hz)
97 | normalized_data = zero_mean_unit_var_norm(data)
98 |
99 | data = data[:, :target_length]
100 | normalized_data = normalized_data[: , :target_length]
101 | audio_for_clap = resample(data, sample_hz, clap.sample_rate)
102 | audio_for_wav2vec = resample(normalized_data, sample_hz, wav2vec.target_sample_hz)
103 |
104 | audio_for_clap = int16_to_float32(float32_to_int16(audio_for_clap))
105 | audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec))
106 |
107 | audios_for_clap.append(audio_for_clap)
108 | audios_for_wav2vec.append(audio_for_wav2vec)
109 |
110 | audios_for_clap = torch.cat(audios_for_clap, dim=0).to(device)
111 | audios_for_wav2vec = torch.cat(audios_for_wav2vec, dim=0).to(device)
112 |
113 | clap_token_ids = get_or_compute_clap_token_ids(None, clap, audios_for_clap, None)
114 | semantic_token_ids = get_or_compute_semantic_token_ids(None, audios_for_wav2vec, wav2vec)
115 |
116 | generated_wave = coarse_stage.generate(
117 | clap_token_ids=clap_token_ids,
118 | semantic_token_ids=semantic_token_ids,
119 | coarse_token_ids=None,
120 | max_time_steps=int(model_config.global_cfg.coarse_audio_length_seconds * 75),
121 | reconstruct_wave=True,
122 | include_eos_in_output=False,
123 | append_eos_to_conditioning_tokens=True,
124 | temperature=0.95,
125 | )
126 |
127 | generated_wave = rearrange(generated_wave, 'b n -> b 1 n').detach().cpu()
128 | for i, wave in enumerate(generated_wave):
129 | torchaudio.save(f'results/{i}.wav', wave, encodec_wrapper.sample_rate)
--------------------------------------------------------------------------------
/scripts/infer_fine.py:
--------------------------------------------------------------------------------
1 | '''
2 | 1) load audio samples
3 | 2) compute clap tokens and coarse tokens
4 | 3) run them through fine stage to predict fine tokens
5 | 4) reconstruct audio from coarse + fine tokens
6 | Reconstructed audio should be similar to the original audio if fine stage is working correctly
7 |
8 | example usage:
9 |
10 | python scripts/infer_fine.py \
11 | ./data/fma_large/000/000005.mp3 \
12 | ./data/fma_large/000/000010.mp3 \
13 | --model_config ./configs/model/musiclm_small.json \
14 |     --fine_path ./results/fine/fine.transformer.10000.pt
15 | '''
16 |
17 | import argparse
18 | import os
19 | import sys
20 | from pathlib import Path
21 |
22 | import torch
23 | import torchaudio
24 | from einops import rearrange
25 | from torchaudio.functional import resample
26 |
27 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
28 |
29 | from open_musiclm.config import (create_clap_quantized_from_config,
30 | create_fine_transformer_from_config,
31 | create_encodec_from_config,
32 | create_hubert_kmeans_from_config,
33 | load_model_config)
34 | from open_musiclm.open_musiclm import (FineStage,
35 | get_or_compute_clap_token_ids,
36 | get_or_compute_acoustic_token_ids)
37 | from open_musiclm.utils import int16_to_float32, float32_to_int16, zero_mean_unit_var_norm
38 | from scripts.train_utils import disable_print
39 |
40 | if __name__ == '__main__':
41 | parser = argparse.ArgumentParser(description='run inference on fine stage')
42 | parser.add_argument('audio_files', type=str, nargs='+')
43 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json', help='path to model config')
44 | parser.add_argument('--fine_path', required=True, help='path to fine stage checkpoint')
45 | parser.add_argument('--temperature', default=0.4, type=float)
46 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt')
47 | parser.add_argument('--seed', default=0)
48 |
49 | args = parser.parse_args()
50 |
51 | model_config = load_model_config(args.model_config)
52 |
53 | audio_files = args.audio_files
54 | fine_path = args.fine_path
55 | rvq_path = args.rvq_path
56 | seed = args.seed
57 | temperature = args.temperature
58 |
59 | print(f'running inference on {audio_files}')
60 | print(f'fine_path: {fine_path}, rvq_path: {rvq_path}, seed: {seed}')
61 |
62 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
63 |
64 | print('loading clap...')
65 | clap = create_clap_quantized_from_config(model_config, rvq_path, device)
66 |
67 | print('loading encodec...')
68 | encodec_wrapper = create_encodec_from_config(model_config, device)
69 |
70 | print('loading fine stage...')
71 | fine_transformer = create_fine_transformer_from_config(model_config, fine_path, device)
72 |
73 | fine_stage = FineStage(
74 | fine_transformer=fine_transformer,
75 | neural_codec=encodec_wrapper,
76 | clap=clap
77 | )
78 |
79 | torch.manual_seed(seed)
80 |
81 | print('loading audio from dataset')
82 |
83 | audios_for_clap = []
84 | audios_for_encodec = []
85 | for audio_path in audio_files:
86 | data, sample_hz = torchaudio.load(audio_path)
87 |
88 | if data.shape[0] > 1:
89 | data = torch.mean(data, dim=0).unsqueeze(0)
90 |
91 | target_length = int(model_config.global_cfg.fine_audio_length_seconds * sample_hz)
92 |
93 | data = data[:, :target_length]
94 | audio_for_clap = resample(data, sample_hz, clap.sample_rate)
95 | audio_for_encodec = resample(data, sample_hz, encodec_wrapper.sample_rate)
96 |
97 | audio_for_clap = int16_to_float32(float32_to_int16(audio_for_clap))
98 | audio_for_encodec = int16_to_float32(float32_to_int16(audio_for_encodec))
99 |
100 | audios_for_clap.append(audio_for_clap)
101 | audios_for_encodec.append(audio_for_encodec)
102 |
103 | audios_for_clap = torch.cat(audios_for_clap, dim=0).to(device)
104 | audios_for_encodec = torch.cat(audios_for_encodec, dim=0).to(device)
105 |
106 | clap_token_ids = get_or_compute_clap_token_ids(None, clap, audios_for_clap, None)
107 | coarse_token_ids, fine_token_ids = get_or_compute_acoustic_token_ids(None, None, audios_for_encodec, encodec_wrapper, model_config.global_cfg.num_coarse_quantizers)
108 |
109 | generated_wave = fine_stage.generate(
110 | clap_token_ids=clap_token_ids,
111 | coarse_token_ids=coarse_token_ids,
112 | max_time_steps=int(model_config.global_cfg.fine_audio_length_seconds * model_config.encodec_cfg.output_hz),
113 | reconstruct_wave=True,
114 | include_eos_in_output=False,
115 | append_eos_to_conditioning_tokens=True,
116 | temperature=temperature,
117 | )
118 |
119 |     generated_wave = rearrange(generated_wave, 'b n -> b 1 n').detach().cpu()
120 |     Path('results').mkdir(parents=True, exist_ok=True)  # output folder may not exist yet
121 |     for i, wave in enumerate(generated_wave):
122 |         torchaudio.save(f'results/fine_reconstruct_{i}.wav', wave, encodec_wrapper.sample_rate)
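123 |
124 |     # optional extra output (a sketch, not part of the original script): also save the cropped
125 |     # input audio at the encodec sample rate, so each reconstruction can be A/B'd against the
126 |     # original by ear, as suggested in the docstring above.
127 |     original_wave = rearrange(audios_for_encodec, 'b n -> b 1 n').detach().cpu()
128 |     for i, wave in enumerate(original_wave):
129 |         torchaudio.save(f'results/fine_original_{i}.wav', wave, encodec_wrapper.sample_rate)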
--------------------------------------------------------------------------------
/scripts/infer_top_match.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import torchaudio
6 | from einops import rearrange
7 | from pathlib import Path
8 | import argparse
9 |
10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
11 |
12 | from open_musiclm.config import load_model_config, create_musiclm_from_config
13 |
14 | if __name__ == '__main__':
15 | parser = argparse.ArgumentParser(description='run inference on trained musiclm model')
16 |
17 | parser.add_argument('prompt', help='prompts to generate audio for', type=str, nargs='+')
18 | parser.add_argument('--num_samples', default=4, type=int)
19 | parser.add_argument('--num_top_matches', default=1, type=int)
20 | parser.add_argument('--input_audio', default=None, type=str, help='input audio to condition on and generate continuations from')
21 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json', help='path to model config')
22 | parser.add_argument('--semantic_path', required=True, help='path to semantic stage checkpoint')
23 | parser.add_argument('--coarse_path', required=True, help='path to coarse stage checkpoint')
24 | parser.add_argument('--fine_path', required=True, help='path to fine stage checkpoint')
25 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt')
26 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib')
27 | parser.add_argument('--results_folder', default='./results', type=str)
28 | parser.add_argument('--return_coarse_wave', default=False, action=argparse.BooleanOptionalAction)
29 | parser.add_argument('--duration', default=4, type=float, help='duration of audio to generate in seconds')
30 |     parser.add_argument('--seed', default=0, type=int)
31 |
32 | args = parser.parse_args()
33 |
34 | model_config = load_model_config(args.model_config)
35 |
36 | semantic_path = args.semantic_path
37 | coarse_path = args.coarse_path
38 | fine_path = args.fine_path
39 | input_audio = args.input_audio
40 | return_coarse_wave = args.return_coarse_wave
41 | duration = args.duration
42 | kmeans_path = args.kmeans_path
43 | rvq_path = args.rvq_path
44 | seed = args.seed
45 | results_folder = args.results_folder
46 |
47 | Path(results_folder).mkdir(parents=True, exist_ok=True)
48 |
49 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
50 |
51 | musiclm = create_musiclm_from_config(
52 | model_config=model_config,
53 | semantic_path=semantic_path,
54 | coarse_path=coarse_path,
55 | fine_path=fine_path,
56 | rvq_path=rvq_path,
57 | kmeans_path=kmeans_path,
58 | device=device)
59 |
60 | torch.manual_seed(seed)
61 |
62 | print(f'prompt: {args.prompt}')
63 |
64 | prime_wave, prime_wave_sample_hz = None, None
65 | if input_audio is not None:
66 | prime_wave, prime_wave_sample_hz = torchaudio.load(input_audio)
67 | prime_wave = prime_wave.to(device)
68 |
69 | generated_wave, similarities = musiclm.generate_top_match(
70 | text=args.prompt,
71 | prime_wave=prime_wave,
72 | prime_wave_sample_hz=prime_wave_sample_hz,
73 | num_samples=args.num_samples,
74 | num_top_matches=args.num_top_matches,
75 | output_seconds=duration,
76 | semantic_window_seconds=model_config.global_cfg.semantic_audio_length_seconds,
77 | coarse_window_seconds=model_config.global_cfg.coarse_audio_length_seconds,
78 | fine_window_seconds=model_config.global_cfg.fine_audio_length_seconds,
79 | semantic_steps_per_second=model_config.hubert_kmeans_cfg.output_hz,
80 | acoustic_steps_per_second=model_config.encodec_cfg.output_hz,
81 | return_coarse_generated_wave=return_coarse_wave,
82 | )
83 |
84 | for i, (wave, sim) in enumerate(zip(generated_wave, similarities)):
85 | wave = rearrange(wave, 'b n -> b 1 n').detach().cpu()
86 | print(f'prompt: {args.prompt[i]}')
87 | print(f'topk similarities: {sim}')
88 | for j, w in enumerate(wave):
89 | torchaudio.save(Path(results_folder) / Path(f'{args.prompt[i][:35]}_top_match_{j}.wav'), w, musiclm.neural_codec.sample_rate)
90 |
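91 |
92 | # example usage (a sketch; the prompt and checkpoint paths are illustrative and simply follow
93 | # the default results folders of the training scripts):
94 | #
95 | # python scripts/infer_top_match.py \
96 | #     "chill synth melody with a slow drum beat" \
97 | #     --num_samples 4 \
98 | #     --num_top_matches 1 \
99 | #     --semantic_path ./results/semantic/semantic.transformer.10000.pt \
100 | #     --coarse_path ./results/coarse/coarse.transformer.10000.pt \
101 | #     --fine_path ./results/fine/fine.transformer.10000.pt \
102 | #     --duration 4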
--------------------------------------------------------------------------------
/scripts/preprocess_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | from pathlib import Path
5 |
6 | import torch
7 |
8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
9 |
10 | from open_musiclm.config import (create_clap_quantized_from_config,
11 | create_encodec_from_config,
12 | create_hubert_kmeans_from_config,
13 | load_model_config, load_training_config,
14 | create_data_preprocessor_from_config)
15 |
16 | if __name__ == '__main__':
17 | parser = argparse.ArgumentParser(description='preprocess data')
18 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json')
19 | parser.add_argument('--training_config', default='./configs/training/train_fma_preprocess.json')
20 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt')
21 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib')
22 | parser.add_argument('--filter_fma', default=False, action=argparse.BooleanOptionalAction)
23 |
24 | args = parser.parse_args()
25 |
26 | print(f'using model config {args.model_config}, training config {args.training_config}, rvq checkpoint {args.rvq_path}, kmeans checkpoint {args.kmeans_path}')
27 |
28 | model_config = load_model_config(args.model_config)
29 | training_config = load_training_config(args.training_config)
30 |
31 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
32 |
33 | print('loading clap...')
34 | clap = create_clap_quantized_from_config(model_config, args.rvq_path, device)
35 |
36 | print('loading wav2vec...')
37 | wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device)
38 |
39 | print('loading encodec...')
40 | encodec_wrapper = create_encodec_from_config(model_config, device)
41 |
42 |
43 | # get rid of some experimental tracks, see notebooks/analyze_fma.ipynb
44 | if args.filter_fma:
45 | try:
46 | import pandas as pd
47 | import ast
48 | except ImportError:
49 | pd = None
50 |
51 | assert pd is not None, 'pandas not found, please install pandas to filter fma'
52 |
53 | metadata_folder = training_config.data_preprocessor_cfg.metadata_folder
54 |
55 | tracks = pd.read_csv(os.path.join(metadata_folder, 'tracks.csv'), index_col=0, header=[0, 1])
56 | experimental_genre = 38
57 | experimental_tracks = tracks.loc[tracks['track', 'genres_all'].apply(lambda x: experimental_genre in ast.literal_eval(x))]
58 | ignore_files = list(experimental_tracks.loc[(experimental_tracks['track', 'listens'] <= 1000) | (experimental_tracks['track', 'favorites'] <= 5)].index)
59 | ignore_files = [f'{i:06d}.mp3' for i in ignore_files]
60 | else:
61 | ignore_files = None
62 |
63 | processor = create_data_preprocessor_from_config(
64 | model_config,
65 | training_config,
66 | clap,
67 | wav2vec,
68 | encodec_wrapper,
69 | device,
70 | config_paths=[args.model_config, args.training_config],
71 | ignore_files=ignore_files)
72 |
73 | processor.process()
74 |
75 |
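76 |
77 |     # example usage (a sketch; the paths mirror the argparse defaults above):
78 |     #
79 |     # python scripts/preprocess_data.py \
80 |     #     --model_config ./configs/model/musiclm_small.json \
81 |     #     --training_config ./configs/training/train_fma_preprocess.json \
82 |     #     --rvq_path ./checkpoints/clap.rvq.350.pt \
83 |     #     --kmeans_path ./results/hubert_kmeans/kmeans.joblib \
84 |     #     --filter_fma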
--------------------------------------------------------------------------------
/scripts/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/scripts/test/__init__.py
--------------------------------------------------------------------------------
/scripts/test/test_clap.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torchaudio
5 | from torchaudio.functional import resample
6 | import numpy as np
7 | import torch
8 | from transformers import RobertaTokenizer
9 |
10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
11 |
12 | from open_musiclm.clap_quantized import ClapQuantized, create_clap_quantized
13 |
14 | text_data = ['male rap',
15 | 'male rapping over a synth pad',
16 | 'female singing',
17 | 'female singing over a synth pad',
18 | 'male singing over a synth pad',
19 | 'male rapping chill voice',
20 | 'male with deep voice',
21 | 'male singing then rapping, synth in the background',
22 | 'producer tag then drake rapping',
23 | 'calming melody',
24 | 'upbeat, hype',
25 | 'pause and then the beat drops, and a male rapper is rapping over a trap beat',
26 | 'male rapping over a hip hop beat',
27 | 'house music',
28 | 'rock song with piano',
29 | 'groovy melody with piano and a male singing']
30 |
31 | def infer_text(clap_wrapper, return_embedding=False):
32 |     # embed the text prompts defined above; with return_embedding=True the wrapper
33 |     # returns continuous clap embeddings rather than rvq token ids, which is what the
34 |     # cosine-similarity comparison below needs
35 | text_embed = clap_wrapper(text_input=text_data, return_embedding=return_embedding)
36 |
37 | return text_embed
38 |
39 | def infer_audio(clap_wrapper: ClapQuantized, return_embedding: bool = False, device: str = 'cuda'):
40 |
41 | print('inferring audio...')
42 |
43 |     # load the waveforms (channels, samples), convert to mono and resample to 48 kHz for clap
44 | audio_waveform, sr = torchaudio.load('/u/zhvng/projects/audio_files/jumpman.mp3')
45 |
46 | wave_2, sr_2 = torchaudio.load('/u/zhvng/projects/open-musiclm/data/fma_large/000/000048.mp3')
47 |
48 | if audio_waveform.shape[0] > 1:
49 | # the audio has more than 1 channel, convert to mono
50 | audio_waveform = torch.mean(audio_waveform, dim=0, keepdim=True)
51 | if wave_2.shape[0] > 1:
52 | wave_2 = torch.mean(wave_2, dim=0, keepdim=True)
53 |
54 | audio_waveform = resample(audio_waveform, sr, 48000)
55 | wave_2 = resample(wave_2, sr_2, 48000)
56 |
57 | # audio_waveform = audio_waveform[:, :48000 * 30]
58 | audio_waveform_1 = audio_waveform[:, :48000 * 10]
59 | audio_waveform_2 = wave_2[:, :48000 * 10]
60 | # audio_waveform_3 = audio_waveform[:, 48000 * 20 : 48000 * 50]
61 | audio_waveform = torch.cat([audio_waveform_1, audio_waveform_2], dim=0)
62 |
63 | audio_embed = clap_wrapper(audio_input=audio_waveform.to(device), return_embedding=return_embedding)
64 |
65 | return audio_embed
66 |
67 |
68 | if __name__ == "__main__":
69 |
70 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
71 |
72 | clap_wrapper = create_clap_quantized(device=device, learn_rvq=False).to(device)
73 |
74 | text_embeds = infer_text(clap_wrapper, return_embedding=True)
75 | audio_embed = infer_audio(clap_wrapper, return_embedding=True, device=device)
76 |
77 | # print(text_embeds)
78 | print(text_embeds.shape)
79 |
80 | # print(audio_embed)
81 | print(audio_embed.shape)
82 |
83 | for i, text_embed in enumerate(text_embeds):
84 | # get cosine similarity with audio_embed
85 | cos_sim = torch.nn.functional.cosine_similarity(
86 | audio_embed, text_embed, dim=-1)
87 | print(text_data[i], cos_sim.cpu().numpy())
88 |
--------------------------------------------------------------------------------
/scripts/test/test_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
5 |
6 | from open_musiclm.config import load_model_config, load_training_config
7 |
8 |
9 | model_config = load_model_config('./configs/model/musiclm_small.json')
10 | print(model_config)
11 |
12 | training_config = load_training_config('./configs/training/train_musiclm_fma.json')
13 | print(training_config)
14 |
15 | print('\nok!')
--------------------------------------------------------------------------------
/scripts/test/test_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import torchaudio
6 | from torchaudio.functional import resample
7 | import argparse
8 | from pathlib import Path
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 |
12 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
13 |
14 | from open_musiclm.data import SoundDataset, SoundDatasetForPreprocessing, PreprocessedDataset, get_dataloader, get_sound_preprocessing_dataloader, get_preprocessed_dataloader
15 |
16 |
17 | folder = './data/fma_large/000'
18 |
19 | # test random crop
20 |
21 | dataset = SoundDataset(
22 | folder,
23 | max_length_seconds=(1, 5),
24 | normalize=(True, False),
25 | target_sample_hz=(16000, 24000),
26 | seq_len_multiple_of=None,
27 | ignore_load_errors=True
28 | )
29 |
30 | dl = get_dataloader(dataset, batch_size=4, shuffle=False)
31 | dl_iter = iter(dl)
32 |
33 | test_steps = 2
34 | for i in range(test_steps):
35 | batch = next(dl_iter)
36 | # print(batch)
37 | for e in batch:
38 | print(e.shape)
39 |
40 | # test preprocessing
41 |
42 | dataset = SoundDatasetForPreprocessing(
43 | folder,
44 | max_length_seconds=(None, 1),
45 | normalize=(True, False),
46 | target_sample_hz=(16000, 24000),
47 | seq_len_multiple_of=None,
48 | ignore_load_errors=True
49 | )
50 |
51 | dl = get_sound_preprocessing_dataloader(dataset, shuffle=False)
52 | dl_iter = iter(dl)
53 |
54 | test_steps = 2
55 | for i in range(test_steps):
56 | batch = next(dl_iter)
57 | print(batch)
58 |
59 | # # test preprocessed
60 | dataset = PreprocessedDataset(
61 | './data/fma_preprocessed',
62 | stage='coarse',
63 | semantic_window_seconds=10,
64 | coarse_window_seconds=4,
65 | fine_window_seconds=2,
66 | semantic_steps_per_second=50,
67 | acoustic_steps_per_second=75,
68 | )
69 |
70 | dl = get_preprocessed_dataloader(dataset, batch_size=4, shuffle=True)
71 | dl_iter = iter(dl)
72 |
73 | test_steps = 2
74 | for i in range(test_steps):
75 | batch = next(dl_iter)
76 | for d in batch:
77 | print(d.shape)
78 |
--------------------------------------------------------------------------------
/scripts/test/test_encodec.py:
--------------------------------------------------------------------------------
1 | from encodec import EncodecModel
2 | from encodec.utils import convert_audio
3 |
4 | import torchaudio
5 | import torch
6 |
7 |
8 | if __name__ == "__main__":
9 |
10 | # Instantiate a pretrained EnCodec model
11 | model = EncodecModel.encodec_model_24khz()
12 | model.set_target_bandwidth(6.0)
13 |
14 | # Load and pre-process the audio waveform
15 | wav, sr = torchaudio.load("/u/zhvng/projects/audio_files/jumpman.mp3")
16 | wav = convert_audio(wav, sr, model.sample_rate, model.channels)
17 | print(model.channels, model.sample_rate, model.quantizer.n_q)
18 | print(model.segment_stride)
19 | wav = torch.stack([wav, wav], dim=0)
20 |
21 | # Extract discrete codes from EnCodec
22 | with torch.no_grad():
23 | encoded_frames = model.encode(wav)
24 | codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T]
25 |
26 | print(codes)
27 | print(codes.shape)
28 | print(encoded_frames[0][0].shape)
29 | print(len(encoded_frames))
30 |
31 | new_encoded_frames = [(encoded[0], None) for encoded in encoded_frames]
32 |
33 | wave = model.decode(new_encoded_frames)
34 | torchaudio.save('test.wav', wave[0], model.sample_rate)
35 |
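36 |
37 |     # sanity check (an assumption based on EnCodec's published bitrate settings): at a target
38 |     # bandwidth of 6.0 kbps the 24 kHz model uses 8 codebooks per frame at 75 frames/sec, so the
39 |     # concatenated codes should have shape [B, 8, T] with T roughly 75 * seconds of audio.
40 |     assert codes.shape[1] == 8, f'expected 8 quantizers at 6 kbps, got {codes.shape[1]}'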
--------------------------------------------------------------------------------
/scripts/test/test_hubert_clustering.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import torchaudio
6 | from torchaudio.functional import resample
7 | import argparse
8 | from pathlib import Path
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 |
12 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
13 |
14 | from open_musiclm.config import load_model_config, load_training_config, create_hubert_kmeans_from_config, create_hubert_kmeans_trainer_from_config
15 | from open_musiclm.utils import zero_mean_unit_var_norm, int16_to_float32, float32_to_int16, exists
16 | from open_musiclm.open_musiclm import get_or_compute_semantic_token_ids
17 |
18 | if __name__ == '__main__':
19 | parser = argparse.ArgumentParser(description='test hubert kmeans to see the difference in sequences')
20 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json')
21 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib')
22 | parser.add_argument('--folder', default='./data/fma_large')
23 |
24 | args = parser.parse_args()
25 |
26 | model_config = load_model_config(args.model_config)
27 |
28 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
29 |
30 | print('loading hubert...')
31 | wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device)
32 |
33 | path = Path(args.folder)
34 | assert path.exists(), 'folder does not exist'
35 |
36 | files = []
37 | for ext in ['mp3', 'wav', 'flac']:
38 | for file in path.glob(f'**/*.{ext}'):
39 | files.append(file)
40 | assert len(files) > 0, 'no sound files found'
41 |
42 | start_audio = 20000
43 | audio_lengths = [4, 10, 15, 25]
44 | batch_size = 16
45 | shortest_length = None
46 | cropped_semantic_tokens_for_each_length = []
47 | for audio_seconds in audio_lengths:
48 | audios_for_wav2vec = []
49 |         for audio_path in files[start_audio: start_audio + batch_size]:
50 | data, sample_hz = torchaudio.load(audio_path)
51 |
52 | if data.shape[0] > 1:
53 | data = torch.mean(data, dim=0).unsqueeze(0)
54 |
55 | target_length = int(audio_seconds * sample_hz)
56 | normalized_data = zero_mean_unit_var_norm(data)
57 |
58 | normalized_data = normalized_data[: , :target_length]
59 |
60 | audio_for_wav2vec = resample(normalized_data, sample_hz, wav2vec.target_sample_hz)
61 |
62 | audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec))
63 |
64 | audios_for_wav2vec.append(audio_for_wav2vec)
65 |
66 | audios_for_wav2vec = torch.cat(audios_for_wav2vec, dim=0).to(device)
67 | semantic_token_ids = get_or_compute_semantic_token_ids(None, audios_for_wav2vec, wav2vec)
68 | print(semantic_token_ids.shape)
69 |
70 | if not exists(shortest_length):
71 | shortest_length = semantic_token_ids.shape[1]
72 | else:
73 | l = semantic_token_ids.shape[1]
74 | if l < shortest_length:
75 | shortest_length = l
76 |
77 | cropped_semantic_tokens_for_each_length.append(semantic_token_ids[:, :shortest_length])
78 |
79 | print(cropped_semantic_tokens_for_each_length[0][0])
80 | print(cropped_semantic_tokens_for_each_length[1][0])
81 | # get accuracy compared to last elem in cropped_semantic_tokens_for_each_length
82 |
83 | side_length = len(cropped_semantic_tokens_for_each_length)
84 | accuracy_matrix = np.zeros((side_length, side_length))
85 |
86 | for i in range(side_length):
87 | for j in range(side_length):
88 | accuracy = torch.mean((cropped_semantic_tokens_for_each_length[i] == cropped_semantic_tokens_for_each_length[j]).float())
89 | print(f'% similar between {audio_lengths[i]} and {audio_lengths[j]} second audio: {accuracy}')
90 | accuracy_matrix[i][j] = accuracy
91 |
92 | # plot the accuracy matrix in a grid with a title and axis labels
93 |
94 | # create a heatmap with darker colors representing higher accuracy
95 | fig, ax = plt.subplots()
96 | im = ax.imshow(accuracy_matrix, cmap='Blues', vmin=0, vmax=1)
97 |
98 | # remove ticks from the plot
99 |     ax.tick_params(axis='both', which='both', length=0)
100 |
101 | # move the x-axis ticks to the top of the grid
102 | ax.xaxis.tick_top()
103 |
104 | # add a colorbar legend
105 | cbar = ax.figure.colorbar(im, ax=ax)
106 |
107 | # set axis labels
108 | ax.set_xticks(np.arange(len(audio_lengths)))
109 | ax.set_yticks(np.arange(len(audio_lengths)))
110 | ax.set_xticklabels(audio_lengths)
111 | ax.set_yticklabels(audio_lengths)
112 |
113 | # add text annotations for each cell
114 |     for i in range(side_length):
115 |         for j in range(side_length):
116 | text = ax.text(j, i, round(accuracy_matrix[i, j], 2),
117 | ha="center", va="center", color="w")
118 |
119 | # set plot title
120 | ax.set_title("Semantic Token Similarity Between Various Total Audio Lengths")
121 | ax.set_xlabel("total audio length (seconds)")
122 | ax.set_ylabel("total audio length (seconds)")
123 |
124 | # show the plot
125 | plt.savefig('./results/accuracy_matrix.png')
126 |
127 |
128 |
--------------------------------------------------------------------------------
/scripts/test/test_load_preprocessed.py:
--------------------------------------------------------------------------------
1 |
2 | import io
3 | import numpy as np
4 | import sqlite3
5 |
6 | def adapt_array(arr):
7 | """
8 | http://stackoverflow.com/a/31312102/190597 (SoulNibbler)
9 | """
10 | out = io.BytesIO()
11 | np.save(out, arr)
12 | out.seek(0)
13 | return sqlite3.Binary(out.read())
14 |
15 | def convert_array(text):
16 | out = io.BytesIO(text)
17 | out.seek(0)
18 | return np.load(out)
19 |
20 | sqlite3.register_adapter(np.ndarray, adapt_array)
21 | sqlite3.register_converter("array", convert_array)
22 |
23 | conn = sqlite3.connect('./data/fma_preprocessed/preprocessed.db', detect_types=sqlite3.PARSE_DECLTYPES)
24 | cursor = conn.cursor()
25 | cursor.execute("SELECT idx, path FROM tokens WHERE idx=?", (0,))
26 | print(cursor.fetchone())
27 |
28 | for i in [3,-1]:
29 | cursor.execute("SELECT clap, semantic, coarse FROM tokens WHERE idx=?", (i,))
30 |
31 | data = cursor.fetchone()
32 |
33 | if data is None:
34 | print("No data found")
35 | else:
36 | for datum in data:
37 | print(datum.shape)
38 |
39 |         print(data[1])  # e.g. the semantic token ids for this entry
40 |
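41 | # quick round-trip check (not from the original script): adapt_array/convert_array should
42 | # encode and decode an arbitrary numpy array through the sqlite BLOB representation unchanged
43 | example = np.arange(12).reshape(3, 4)
44 | assert np.array_equal(convert_array(bytes(adapt_array(example))), example)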
--------------------------------------------------------------------------------
/scripts/test/test_rvq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from vector_quantize_pytorch import ResidualVQ
3 |
4 | print('running rvq')
5 | rq = ResidualVQ(dim=512, num_quantizers=12, codebook_size=1024, commitment_weight=0, decay=0.95, kmeans_init=True, threshold_ema_dead_code=0)
6 |
7 | # rq.load_state_dict(torch.load('./results/semantic/semantic.conditioner_rvq.6000.pt', map_location='cpu'))
8 |
9 | for i in range(10):
10 |     q, indices, loss = rq(torch.randn(1, 512))
11 |     print(indices[0:2])
12 |
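13 |
14 | # note: ResidualVQ's forward pass returns (quantized_output, codebook_indices, commitment_loss);
15 | # the loop above feeds random vectors, so the printed indices are just a sanity check that the
16 | # module runs end to end.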
--------------------------------------------------------------------------------
/scripts/train_clap_rvq.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | import torch
6 |
7 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
8 |
9 | from open_musiclm.clap_quantized import create_clap_quantized
10 | from open_musiclm.config import (create_clap_quantized_from_config,
11 | create_clap_rvq_trainer_from_config,
12 | load_model_config, load_training_config)
13 | from open_musiclm.trainer import ClapRVQTrainer
14 | from scripts.train_utils import disable_print
15 |
16 | if __name__ == '__main__':
17 | parser = argparse.ArgumentParser(description='train rvq to quantize clap embeddings')
18 | parser.add_argument('--results_folder', default='./results/clap_rvq')
19 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json')
20 | parser.add_argument('--training_config', default='./configs/training/train_musiclm_fma.json')
21 | parser.add_argument('--continue_from', default=None, type=str)
22 |
23 | args = parser.parse_args()
24 |
25 | model_config = load_model_config(args.model_config)
26 | training_config = load_training_config(args.training_config)
27 |
28 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
29 |
30 | print('loading clap...')
31 | clap = create_clap_quantized_from_config(model_config, args.continue_from, device)
32 |
33 | trainer = create_clap_rvq_trainer_from_config(
34 | model_config=model_config,
35 | training_config=training_config,
36 | clap=clap,
37 | results_folder=args.results_folder,
38 | device=device,
39 | accelerate_kwargs={
40 | 'log_with': "tensorboard",
41 | 'logging_dir': './logs/clap_rvq'
42 | },
43 | config_paths=[args.model_config, args.training_config])
44 |
45 | trainer.train()
--------------------------------------------------------------------------------
/scripts/train_coarse_stage.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | from pathlib import Path
5 |
6 | import torch
7 |
8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
9 |
10 | from open_musiclm.config import (create_clap_quantized_from_config,
11 | create_coarse_transformer_from_config,
12 | create_encodec_from_config,
13 | create_hubert_kmeans_from_config,
14 | create_single_stage_trainer_from_config,
15 | load_model_config, load_training_config)
16 | from scripts.train_utils import load_checkpoint_from_args, validate_train_args
17 |
18 | if __name__ == '__main__':
19 | parser = argparse.ArgumentParser(description='train coarse stage')
20 | parser.add_argument('--results_folder', default='./results/coarse')
21 | parser.add_argument('--continue_from_dir', default=None, type=str)
22 | parser.add_argument('--continue_from_step', default=None, type=int)
23 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json')
24 | parser.add_argument('--training_config', default='./configs/training/train_musiclm_fma.json')
25 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt')
26 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib')
27 | parser.add_argument('--fine_tune_from', default=None, type=str)
28 |
29 | args = parser.parse_args()
30 |
31 | validate_train_args(args)
32 |
33 | model_config = load_model_config(args.model_config)
34 | training_config = load_training_config(args.training_config)
35 |
36 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
37 |
38 | use_preprocessed_data = training_config.coarse_trainer_cfg.use_preprocessed_data
39 |
40 | if use_preprocessed_data:
41 | clap = None
42 | wav2vec = None
43 | print(f'training from preprocessed data {training_config.coarse_trainer_cfg.folder}')
44 | else:
45 | print('loading clap...')
46 | clap = create_clap_quantized_from_config(model_config, args.rvq_path, device)
47 |
48 | print('loading wav2vec...')
49 | wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device)
50 |
51 | print('loading encodec...')
52 | encodec_wrapper = create_encodec_from_config(model_config, device)
53 |
54 | print('loading coarse stage...')
55 | coarse_transformer = create_coarse_transformer_from_config(model_config, args.fine_tune_from, device)
56 |
57 | trainer = create_single_stage_trainer_from_config(
58 | model_config=model_config,
59 | training_config=training_config,
60 | stage='coarse',
61 | results_folder=args.results_folder,
62 | transformer=coarse_transformer,
63 | clap=clap,
64 | wav2vec=wav2vec,
65 | encodec_wrapper=encodec_wrapper,
66 | device=device,
67 | accelerate_kwargs={
68 | 'log_with': "tensorboard",
69 | 'logging_dir': './logs/coarse'
70 | },
71 | config_paths=[args.model_config, args.training_config])
72 |
73 | load_checkpoint_from_args(trainer, args)
74 |
75 | trainer.train()
76 |
--------------------------------------------------------------------------------
/scripts/train_fine_stage.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import argparse
6 | from pathlib import Path
7 |
8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
9 |
10 | from open_musiclm.config import (create_clap_quantized_from_config,
11 | create_encodec_from_config,
12 | create_hubert_kmeans_from_config,
13 | create_fine_transformer_from_config,
14 | create_single_stage_trainer_from_config,
15 | load_model_config, load_training_config)
16 | from scripts.train_utils import load_checkpoint_from_args, validate_train_args
17 |
18 | if __name__ == '__main__':
19 | parser = argparse.ArgumentParser(description='train fine stage')
20 | parser.add_argument('--results_folder', default='./results/fine')
21 | parser.add_argument('--continue_from_dir', default=None, type=str)
22 | parser.add_argument('--continue_from_step', default=None, type=int)
23 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json')
24 | parser.add_argument('--training_config', default='./configs/training/train_musiclm_fma.json')
25 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt')
26 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib')
27 | parser.add_argument('--fine_tune_from', default=None, type=str)
28 |
29 | args = parser.parse_args()
30 |
31 | validate_train_args(args)
32 |
33 | model_config = load_model_config(args.model_config)
34 | training_config = load_training_config(args.training_config)
35 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
36 |
37 | use_preprocessed_data = training_config.fine_trainer_cfg.use_preprocessed_data
38 |
39 | if use_preprocessed_data:
40 | clap = None
41 | print(f'training from preprocessed data {training_config.fine_trainer_cfg.folder}')
42 | else:
43 | print('loading clap...')
44 | clap = create_clap_quantized_from_config(model_config, args.rvq_path, device)
45 |
46 | print('loading encodec...')
47 | encodec_wrapper = create_encodec_from_config(model_config, device)
48 |
49 | print('loading fine stage...')
50 | fine_transformer = create_fine_transformer_from_config(model_config, args.fine_tune_from, device)
51 |
52 | trainer = create_single_stage_trainer_from_config(
53 | model_config=model_config,
54 | training_config=training_config,
55 | stage='fine',
56 | results_folder=args.results_folder,
57 | transformer=fine_transformer,
58 | clap=clap,
59 | wav2vec=None,
60 | encodec_wrapper=encodec_wrapper,
61 | device=device,
62 | accelerate_kwargs={
63 | 'log_with': "tensorboard",
64 | 'logging_dir': './logs/fine'
65 | },
66 | config_paths=[args.model_config, args.training_config])
67 |
68 | load_checkpoint_from_args(trainer, args)
69 |
70 | trainer.train()
71 |
--------------------------------------------------------------------------------
/scripts/train_hubert_kmeans.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import argparse
6 |
7 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
8 |
9 | from open_musiclm.config import load_model_config, load_training_config, create_hubert_kmeans_from_config, create_hubert_kmeans_trainer_from_config
10 |
11 | if __name__ == '__main__':
12 | parser = argparse.ArgumentParser(description='train kmeans to quantize hubert embeddings')
13 | parser.add_argument('--results_folder', default='./results/hubert_kmeans')
14 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json')
15 | parser.add_argument('--training_config', default='./configs/training/train_musiclm_fma.json')
16 |
17 | args = parser.parse_args()
18 |
19 | model_config = load_model_config(args.model_config)
20 | training_config = load_training_config(args.training_config)
21 |
22 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
23 |
24 | print('loading hubert...')
25 | hubert_kmeans = create_hubert_kmeans_from_config(model_config, None, device)
26 |
27 | trainer = create_hubert_kmeans_trainer_from_config(
28 | model_config=model_config,
29 | training_config=training_config,
30 | hubert_kmeans=hubert_kmeans,
31 | results_folder=args.results_folder,
32 | device=device,
33 | config_paths=[args.model_config, args.training_config]
34 | )
35 |
36 | trainer.train()
--------------------------------------------------------------------------------
/scripts/train_semantic_stage.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | from pathlib import Path
5 |
6 | import torch
7 |
8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
9 |
10 | from open_musiclm.config import (create_clap_quantized_from_config,
11 | create_encodec_from_config,
12 | create_hubert_kmeans_from_config,
13 | create_semantic_transformer_from_config,
14 | create_single_stage_trainer_from_config,
15 | load_model_config, load_training_config)
16 | from scripts.train_utils import validate_train_args, load_checkpoint_from_args
17 |
18 | if __name__ == '__main__':
19 | parser = argparse.ArgumentParser(description='train semantic stage')
20 | parser.add_argument('--results_folder', default='./results/semantic')
21 | parser.add_argument('--continue_from_dir', default=None, type=str)
22 | parser.add_argument('--continue_from_step', default=None, type=int)
23 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json')
24 | parser.add_argument('--training_config', default='./configs/training/train_musiclm_fma.json')
25 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt')
26 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib')
27 | parser.add_argument('--fine_tune_from', default=None, type=str)
28 |
29 | args = parser.parse_args()
30 |
31 | validate_train_args(args)
32 |
33 | model_config = load_model_config(args.model_config)
34 | training_config = load_training_config(args.training_config)
35 |
36 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
37 |
38 | use_preprocessed_data = training_config.semantic_trainer_cfg.use_preprocessed_data
39 |
40 | if use_preprocessed_data:
41 | clap = None
42 | wav2vec = None
43 | print(f'training from preprocessed data {training_config.semantic_trainer_cfg.folder}')
44 | else:
45 | print('loading clap...')
46 | clap = create_clap_quantized_from_config(model_config, args.rvq_path, device)
47 |
48 | print('loading wav2vec...')
49 | wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device)
50 |
51 | print('loading semantic stage...')
52 | semantic_transformer = create_semantic_transformer_from_config(model_config, args.fine_tune_from, device)
53 |
54 | trainer = create_single_stage_trainer_from_config(
55 | model_config=model_config,
56 | training_config=training_config,
57 | stage='semantic',
58 | results_folder=args.results_folder,
59 | transformer=semantic_transformer,
60 | clap=clap,
61 | wav2vec=wav2vec,
62 | encodec_wrapper=None,
63 | device=device,
64 | accelerate_kwargs={
65 | 'log_with': "tensorboard",
66 | 'logging_dir': './logs/semantic'
67 | },
68 | config_paths=[args.model_config, args.training_config])
69 |
70 | load_checkpoint_from_args(trainer, args)
71 |
72 | trainer.train()
73 |
--------------------------------------------------------------------------------
/scripts/train_utils.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | import logging
3 | import sys
4 | import os
5 | from pathlib import Path
6 |
7 | def exists(val):
8 | return val is not None
9 |
10 | class disable_print:
11 | def __enter__(self):
12 | self._original_stdout = sys.stdout
13 | sys.stdout = open(os.devnull, 'w')
14 |
15 | def __exit__(self, exc_type, exc_val, exc_tb):
16 | sys.stdout.close()
17 | sys.stdout = self._original_stdout
18 |
19 | def get_latest_checkpoints(results_folder, max_step=None):
20 | highest_transformer_step = -1
21 | highest_optimizer_step = -1
22 | highest_scheduler_step = -1
23 | transformer_path = None
24 | optimizer_path = None
25 | scheduler_path = None
26 | max_step = float('inf') if max_step is None else max_step
27 | for file in os.listdir(results_folder):
28 | if file.endswith('.pt'):
29 | if 'transformer' in file:
30 | step = int(file.split('.')[2])
31 | if step > highest_transformer_step and step <= max_step:
32 | highest_transformer_step = step
33 | transformer_path = os.path.join(results_folder, file)
34 | elif 'optimizer' in file:
35 | step = int(file.split('.')[2])
36 | if step > highest_optimizer_step and step <= max_step:
37 | highest_optimizer_step = step
38 | optimizer_path = os.path.join(results_folder, file)
39 | elif 'scheduler' in file:
40 | step = int(file.split('.')[2])
41 | if step > highest_scheduler_step and step <= max_step:
42 | highest_scheduler_step = step
43 | scheduler_path = os.path.join(results_folder, file)
44 |
45 | assert highest_transformer_step == highest_optimizer_step, 'transformer and optimizer checkpoints are not aligned'
46 | if scheduler_path is not None:
47 | assert highest_transformer_step == highest_scheduler_step, 'transformer and scheduler checkpoints are not aligned'
48 |
49 | return (transformer_path, optimizer_path, scheduler_path), highest_transformer_step
50 |
51 | def validate_train_args(args):
52 | assert not(exists(args.fine_tune_from) and exists(args.continue_from_dir)), 'choose one: fine tune from a checkpoint or continue from a directory'
53 |
54 | print(f'saving results to {args.results_folder}, using model config {args.model_config} and training config {args.training_config}, using rvq checkpoint {args.rvq_path} and kmeans checkpoint {args.kmeans_path}')
55 | if exists(args.continue_from_dir):
56 | print(f'continuing from latest checkpoint in {args.continue_from_dir}')
57 | assert not Path(args.continue_from_dir) == Path(args.results_folder), 'continue_from_dir must be different from results_folder'
58 | elif exists(args.fine_tune_from):
59 | print(f'fine tuning from checkpoint {args.fine_tune_from}. Make sure to use the same model config as the base model.')
60 |
61 | def load_checkpoint_from_args(trainer, args):
62 | if exists(args.continue_from_dir):
63 | checkpoints, steps = get_latest_checkpoints(args.continue_from_dir, args.continue_from_step)
64 | print(f'loading checkpoints: {checkpoints}')
65 | trainer.load(*checkpoints, steps=steps+1)
66 |
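67 | # usage sketch for get_latest_checkpoints (file names are assumptions following the naming
68 | # convention seen elsewhere in the repo, e.g. coarse.transformer.10000.pt):
69 | #
70 | #   (transformer_ckpt, optimizer_ckpt, scheduler_ckpt), step = get_latest_checkpoints('./results/coarse')
71 | #
72 | # the step is parsed from the third dot-separated field of each file name, so renaming
73 | # checkpoints by hand will break resumption via --continue_from_dir.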
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from setuptools import setup, find_packages
4 |
5 | # Load the version from file
6 | __version__ = Path("VERSION").read_text().strip()
7 |
8 | setup(
9 | name = 'open-musiclm',
10 | packages = find_packages(exclude=[]),
11 | include_package_data=True,
12 | version = __version__,
13 | license='MIT',
14 | description = 'Open MusicLM - Implementation of MusicLM, a text to music model published by Google Research, with a few modifications',
15 | author = 'Allen Zhang',
16 | long_description_content_type = 'text/markdown',
17 | url = 'https://github.com/zhvng/open-musiclm',
18 | keywords = [
19 | 'artificial intelligence',
20 | 'deep learning',
21 | 'transformers',
22 | 'attention mechanism',
23 | 'audio generation',
24 | 'musiclm',
25 | ],
26 | install_requires=[
27 | 'torch',
28 | 'torchvision',
29 | 'torchaudio',
30 | 'einops>=0.6.1',
31 | 'vector-quantize-pytorch>=1.2.2',
32 | 'librosa',
33 | 'torchlibrosa',
34 | 'ftfy',
35 | 'tqdm',
36 | 'transformers',
37 | 'encodec',
38 | 'gdown',
39 | 'accelerate>=0.17.0',
40 | 'beartype',
41 | 'joblib',
42 | 'h5py',
43 | 'scikit-learn',
44 | 'wget',
45 | ],
46 | classifiers=[
47 | 'Development Status :: 4 - Beta',
48 | 'Intended Audience :: Developers',
49 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
50 | 'License :: OSI Approved :: MIT License',
51 | 'Programming Language :: Python :: 3.10',
52 | ],
53 | )
--------------------------------------------------------------------------------