├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── VERSION ├── clap.png ├── configs ├── model │ ├── musiclm_large.json │ ├── musiclm_large_small_context.json │ └── musiclm_small.json └── training │ ├── train_fma_preprocess.json │ └── train_musiclm_fma.json ├── environment.yaml ├── musiclm.png ├── notebooks └── analyze_fma.ipynb ├── open_musiclm ├── __init__.py ├── clap_quantized.py ├── config.py ├── data.py ├── encodec_wrapper.py ├── hf_hubert_kmeans.py ├── laion_clap │ ├── __init__.py │ ├── clap_module │ │ ├── __init__.py │ │ ├── bert.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── factory.py │ │ ├── feature_fusion.py │ │ ├── htsat.py │ │ ├── linear_probe.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── model_configs │ │ │ ├── HTSAT-base.json │ │ │ ├── HTSAT-large.json │ │ │ ├── HTSAT-tiny-win-1536.json │ │ │ ├── HTSAT-tiny.json │ │ │ ├── PANN-10.json │ │ │ ├── PANN-14-fmax-18k.json │ │ │ ├── PANN-14-fmax-8k-20s.json │ │ │ ├── PANN-14-tiny-transformer.json │ │ │ ├── PANN-14-win-1536.json │ │ │ ├── PANN-14.json │ │ │ ├── PANN-6.json │ │ │ ├── RN101-quickgelu.json │ │ │ ├── RN101.json │ │ │ ├── RN50-quickgelu.json │ │ │ ├── RN50.json │ │ │ ├── RN50x16.json │ │ │ ├── RN50x4.json │ │ │ ├── ViT-B-16.json │ │ │ ├── ViT-B-32-quickgelu.json │ │ │ ├── ViT-B-32.json │ │ │ └── ViT-L-14.json │ │ ├── openai.py │ │ ├── pann_model.py │ │ ├── pretrained.py │ │ ├── timm_model.py │ │ ├── tokenizer.py │ │ ├── transform.py │ │ ├── utils.py │ │ └── version.py │ └── hook.py ├── model_types.py ├── open_musiclm.py ├── optimizer.py ├── preprocess.py ├── trainer.py ├── transformer.py └── utils.py ├── scripts ├── __init__.py ├── download_checkpoints.sh ├── download_fma_large.sh ├── download_fma_metadata.sh ├── infer.py ├── infer_coarse.py ├── infer_fine.py ├── infer_top_match.py ├── preprocess_data.py ├── test │ ├── __init__.py │ ├── test_clap.py │ ├── test_config.py │ ├── test_dataloader.py │ ├── test_encodec.py │ ├── test_hubert_clustering.py │ ├── test_load_preprocessed.py │ └── test_rvq.py ├── train_clap_rvq.py ├── train_coarse_stage.py ├── train_fine_stage.py ├── train_hubert_kmeans.py ├── train_semantic_stage.py └── train_utils.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | data/ 3 | checkpoints/ 4 | results/ 5 | test.wav 6 | logs/ 7 | .vscode 8 | wandb 9 | 10 | **/630k-*.pt 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Allen Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include open_musiclm/laion_clap/clap_module/model_configs *.json 2 | recursive-include open_musiclm/laion_clap/clap_module bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Open MusicLM 2 | Pytorch implementation of [MusicLM](https://arxiv.org/abs/2301.11325), a SOTA text to music model published by Google, with a few modifications. We use [CLAP](https://github.com/LAION-AI/CLAP) as a replacement for MuLan, [Encodec](https://github.com/facebookresearch/encodec) as a replacement for SoundStream, and [MERT](https://huggingface.co/m-a-p/MERT-v0) as a replacement for w2v-BERT. 3 | 4 |

5 | ![diagram of MusicLM](./musiclm.png) 6 | ![diagram of CLAP](./clap.png) 7 |

8 | 9 | ## Why CLAP? 10 | CLAP is a joint audio-text model trained on [LAION-Audio-630K](https://github.com/LAION-AI/audio-dataset). Similar to MuLan, it consists of an audio tower and a text tower that project their respective media onto a shared latent space (512 dimensions in CLAP vs 128 dimensions in MuLan). 11 | 12 | MuLan was trained on 50 million text-music pairs. Unfortunately I don't have the data to replicate this, so I'm relying on CLAP's pretrained checkpoints to come close. CLAP was trained on 2.6 million total text-audio pairs from LAION-630k (~633k text-audio pairs) and AudioSet (2 million samples with captions generated by a keyword-to-caption model). Although this is a fraction of the data used to train MuLan, we have successfully used CLAP to generate diverse music samples, which you can listen to [here](https://drive.google.com/drive/folders/1pGY8EP2EZlE2pPpXn5E3YAkoCgATWw3_) (keep in mind these are very early results). In the event that CLAP's latent space is not expressive enough for music generation, we can train CLAP on music or substitute the model for @lucidrain's [MuLan implementation](https://github.com/lucidrains/musiclm-pytorch) once it is trained. 13 | 14 | ## Why Encodec? 15 | SoundStream and Encodec are both neural audio codecs that encode any waveform to a sequence of acoustic tokens, which can then be decoded into a waveform resembling the original. These intermediate tokens can then be modeled as a seq2seq task. [Encodec](https://github.com/facebookresearch/encodec) is released by Facebook and pretrained checkpoints are publicly available, whereas this is not the case with SoundStream. 16 | 17 | ## Differences from @lucidrains implementation 18 | - Autoregressively models the CLAP/MuLan conditioning signal by passing it into the transformers as discrete tokens, as mentioned in section 3.1 of the paper. Musiclm-pytorch conditions on them with cross attention. 19 | - TokenConditionedTransformer can support variable token sequences, which makes it easy to do further experimentation (e.g. combining multiple conditioning signals, stereo waveform generation, etc.) 20 | - Uses existing open source models instead of training MuLan and SoundStream. 21 | - Some modifications to increase the chance of successfully training the model. 22 | 23 | # End Goal 24 | The goal of this project is to replicate the results of MusicLM as quickly as possible without necessarily sticking to the architecture in the paper. For those looking for a more true-to-form implementation, check out [musiclm-pytorch](https://github.com/lucidrains/musiclm-pytorch). 25 | 26 | We also seek to gain a better understanding of CLAP's latent space. 27 | 28 | Join us on discord if you'd like to get involved! [join discord](https://discord.gg/jN8jADShX5) 29 | 30 | # Usage 31 | ## Install 32 | ```shell 33 | conda env create -f environment.yaml 34 | conda activate open-musiclm 35 | ``` 36 | 37 | ## Configs 38 | A "model config" contains information about the model architecture such as the number of layers, number of quantizers, target audio lengths for each stage, etc. It is used to instantiate the model during training and inference. 39 | 40 | A "training config" contains hyperparameters for training the model. It is used to instantiate the trainer classes during training. 41 | 42 | See the `./configs` directory for example configs. 43 | 44 | ## Training 45 | ### CLAP RVQ 46 | The first step is to train the residual vector quantizer that maps continuous CLAP embeds to a discrete token sequence. 
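For intuition, here is a minimal, hypothetical sketch (not one of the repo's scripts) of the mapping this stage learns. It mirrors `open_musiclm/clap_quantized.py` and the `clap_rvq_cfg` defaults in the model configs: a `ResidualVQ` from `vector-quantize-pytorch` that turns a continuous 512-dimensional CLAP embedding into 12 discrete token ids, one per quantizer.

```python
import torch
from vector_quantize_pytorch import ResidualVQ

# 12 quantizers with 1024-entry codebooks over the 512-d CLAP joint embedding,
# matching rq_num_quantizers / codebook_size in configs/model/musiclm_small.json
rvq = ResidualVQ(
    dim=512,
    num_quantizers=12,
    codebook_size=1024,
    commitment_weight=0,  # CLAP is frozen, so no commitment loss is needed
    kmeans_init=True,
)

clap_embeds = torch.randn(4, 1, 512)  # dummy (batch, 1, dim) CLAP embeddings for illustration
_, clap_tokens, _ = rvq(clap_embeds)  # clap_tokens: (4, 1, 12) discrete conditioning tokens
```

The `train_clap_rvq.py` script below fits this kind of quantizer on CLAP embeddings computed from your dataset and saves the checkpoint that the later stages load via `--rvq_path`.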
47 | ```shell 48 | python ./scripts/train_clap_rvq.py \ 49 | --results_folder ./results/clap_rvq \ # where to save results and checkpoints 50 | --model_config ./configs/model/musiclm_small.json \ # path to model config 51 | --training_config ./configs/training/train_musiclm_fma.json # path to training config 52 | ``` 53 | 54 | ### Hubert K-means 55 | Next, we learn a K-means layer that we use to quantize our MERT embeddings into semantic tokens. 56 | ```shell 57 | python ./scripts/train_hubert_kmeans.py \ 58 | --results_folder ./results/hubert_kmeans \ # where to save results and checkpoints 59 | --model_config ./configs/model/musiclm_small.json \ 60 | --training_config ./configs/training/train_musiclm_fma.json 61 | ``` 62 | 63 | ### Semantic Stage + Coarse Stage + Fine Stage 64 | Once we have a working K-means and RVQ, we can now train the semantic, coarse and fine stages. These stages can be trained concurrently. 65 | ```shell 66 | python ./scripts/train_semantic_stage.py \ 67 | --results_folder ./results/semantic \ # where to save results and checkpoints 68 | --model_config ./configs/model/musiclm_small.json \ 69 | --training_config ./configs/training/train_musiclm_fma.json \ 70 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq 71 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans 72 | ``` 73 | ```shell 74 | python ./scripts/train_coarse_stage.py \ 75 | --results_folder ./results/coarse \ # where to save results and checkpoints 76 | --model_config ./configs/model/musiclm_small.json \ 77 | --training_config ./configs/training/train_musiclm_fma.json \ 78 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq 79 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans 80 | ``` 81 | ```shell 82 | python ./scripts/train_fine_stage.py \ 83 | --results_folder ./results/fine \ # where to save results and checkpoints 84 | --model_config ./configs/model/musiclm_small.json \ 85 | --training_config ./configs/training/train_musiclm_fma.json \ 86 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq 87 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans 88 | ``` 89 | 90 | ## Preprocessing 91 | In the above case, we are using CLAP, Hubert and Encodec to generate clap, semantic and acoustic tokens live during training. However, these models take up space on the GPU, and it is inefficient to recompute these tokens if we're making multiple runs on the same data. We can instead compute these tokens ahead of time and iterate over them during training. 92 | 93 | To do this, fill in the `data_preprocessor_cfg` field in the config and set `use_preprocessed_data` to True in the trainer configs (look at train_fma_preprocess.json for inspiration). Then run the following to preprocess the dataset, followed by your training script. 
94 | 95 | ```shell 96 | python ./scripts/preprocess_data.py \ 97 | --model_config ./configs/model/musiclm_small.json \ 98 | --training_config ./configs/training/train_fma_preprocess.json \ 99 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq 100 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans 101 | ``` 102 | 103 | ## Inference 104 | Generate multiple samples and use CLAP to select the best ones: 105 | ```shell 106 | python scripts/infer_top_match.py \ 107 | "your text prompt" 108 | --num_samples 4 # number of samples to generate 109 | --num_top_matches 1 # number of top matches to return 110 | --semantic_path PATH_TO_SEMANTIC_CHECKPOINT \ # path to previously trained semantic stage 111 | --coarse_path PATH_TO_COARSE_CHECKPOINT \ # path to previously trained coarse stage 112 | --fine_path PATH_TO_FINE_CHECKPOINT \ # path to previously trained fine stage 113 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq 114 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans 115 | --model_config ./configs/model/musiclm_small.json \ 116 | --duration 4 117 | ``` 118 | 119 | Generate samples for various test prompts: 120 | ```shell 121 | python scripts/infer.py \ 122 | --semantic_path PATH_TO_SEMANTIC_CHECKPOINT \ # path to previously trained semantic stage 123 | --coarse_path PATH_TO_COARSE_CHECKPOINT \ # path to previously trained coarse stage 124 | --fine_path PATH_TO_FINE_CHECKPOINT \ # path to previously trained fine stage 125 | --rvq_path PATH_TO_RVQ_CHECKPOINT \ # path to previously trained rvq 126 | --kmeans_path PATH_TO_KMEANS_CHECKPOINT # path to previously trained kmeans 127 | --model_config ./configs/model/musiclm_small.json \ 128 | --duration 4 129 | ``` 130 | 131 | You can use the `--return_coarse_wave` flag to skip the fine stage and reconstruct audio from coarse tokens alone. 132 | 133 | ## Checkpoints 134 | You can download experimental checkpoints for the musiclm_large_small_context model [here](https://drive.google.com/drive/u/0/folders/1347glwEc-6XWulfU7NGrFrYTvTnjeVJE). To fine tune the model, call the train scripts with the `--fine_tune_from` flag. 135 | 136 | # Thank you 137 | * [Okio](https://okio.ai/) for providing compute to train the model! Okio is a startup that is developing Nendo - an open source generative music tool-suite 138 | that re-imagines music. If you're interested check them out at [okio.ai](https://okio.ai/) 139 | * [@lucidrains](https://github.com/lucidrains/) for the [audiolm-pytorch](https://github.com/lucidrains/audiolm-pytorch) implementation. This repo contains a refactored version of a lot of the code in [audiolm-pytorch](https://github.com/lucidrains/audiolm-pytorch). 140 | * [LAION](https://laion.ai/) for [CLAP](https://github.com/LAION-AI/CLAP) 141 | * [Music Audio Pretrain team](https://huggingface.co/m-a-p) for [MERT](https://huggingface.co/m-a-p/MERT-v0) 142 | 143 | # Citations 144 | ```bibtex 145 | @inproceedings{Agostinelli2023MusicLMGM, 146 | title = {MusicLM: Generating Music From Text}, 147 | author = {Andrea Agostinelli and Timo I. Denk and Zal{\'a}n Borsos and Jesse Engel and Mauro Verzetti and Antoine Caillon and Qingqing Huang and Aren Jansen and Adam Roberts and Marco Tagliasacchi and Matthew Sharifi and Neil Zeghidour and C. 
Frank}, 148 | year = {2023} 149 | } 150 | ``` 151 | ```bibtex 152 | @article{wu2022large, 153 | title = {Large-scale Contrastive Language-Audio Pretraining with Feature Fusion and Keyword-to-Caption Augmentation}, 154 | author = {Wu, Yusong and Chen, Ke and Zhang, Tianyu and Hui, Yuchen and Berg-Kirkpatrick, Taylor and Dubnov, Shlomo}, 155 | journal = {arXiv preprint arXiv:2211:06687}, 156 | year = {2022}, 157 | } 158 | ``` 159 | ```bibtex 160 | @article{defossez2022highfi, 161 | title = {High Fidelity Neural Audio Compression}, 162 | author = {Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi}, 163 | journal = {arXiv preprint arXiv:2210.13438}, 164 | year = {2022} 165 | } 166 | ``` 167 | ```bibtex 168 | @misc{li2023mert, 169 | title = {MERT: Acoustic Music Understanding Model with Large-Scale Self-supervised Training}, 170 | author = {Yizhi Li and Ruibin Yuan and Ge Zhang and Yinghao Ma and Xingran Chen and Hanzhi Yin and Chenghua Lin and Anton Ragni and Emmanouil Benetos and Norbert Gyenge and Roger Dannenberg and Ruibo Liu and Wenhu Chen and Gus Xia and Yemin Shi and Wenhao Huang and Yike Guo and Jie Fu}, 171 | year = {2023}, 172 | eprint = {2306.00107}, 173 | archivePrefix = {arXiv}, 174 | primaryClass = {cs.SD} 175 | } 176 | ``` 177 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.2.5 -------------------------------------------------------------------------------- /clap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/clap.png -------------------------------------------------------------------------------- /configs/model/musiclm_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "global_cfg": { 3 | "semantic_audio_length_seconds": 30.0, 4 | "coarse_audio_length_seconds": 10.0, 5 | "fine_audio_length_seconds": 3.0, 6 | "clap_audio_length_seconds": 30.0, 7 | "num_coarse_quantizers": 3, 8 | "num_fine_quantizers": 5 9 | }, 10 | "clap_rvq_cfg": { 11 | "enable_fusion": true, 12 | "rq_num_quantizers": 12, 13 | "codebook_size": 1024, 14 | "rq_ema_decay": 0.95, 15 | "threshold_ema_dead_code": 0.5 16 | }, 17 | "hubert_kmeans_cfg": { 18 | "model_name": "m-a-p/MERT-v0", 19 | "normalize_embeds": true, 20 | "embed_layer": 7, 21 | "target_sample_hz": 16000, 22 | "seq_len_multiple_of": 320, 23 | "codebook_size": 1024, 24 | "output_hz": 50 25 | }, 26 | "encodec_cfg": { 27 | "bandwidth": 6.0, 28 | "codebook_size": 1024, 29 | "output_hz": 75 30 | }, 31 | "semantic_cfg": { 32 | "dim": 1024, 33 | "depth": 24, 34 | "heads": 16, 35 | "attn_dropout": 0.0, 36 | "ff_dropout": 0.1, 37 | "grad_shrink_alpha": 0.1, 38 | "non_causal_prefix_size": 0, 39 | "relative_position_bias_type": "continuous", 40 | "use_memory_efficient_attention": false 41 | }, 42 | "coarse_cfg": { 43 | "dim": 1024, 44 | "depth": 24, 45 | "heads": 16, 46 | "attn_dropout": 0.0, 47 | "ff_dropout": 0.1, 48 | "grad_shrink_alpha": 0.1, 49 | "non_causal_prefix_size": 0, 50 | "relative_position_bias_type": "continuous", 51 | "use_memory_efficient_attention": false 52 | }, 53 | "fine_cfg": { 54 | "dim": 1024, 55 | "depth": 24, 56 | "heads": 16, 57 | "attn_dropout": 0.0, 58 | "ff_dropout": 0.1, 59 | "grad_shrink_alpha": 0.1, 60 | "non_causal_prefix_size": 0, 61 | "relative_position_bias_type": "continuous", 62 | 
"use_memory_efficient_attention": false 63 | } 64 | } -------------------------------------------------------------------------------- /configs/model/musiclm_large_small_context.json: -------------------------------------------------------------------------------- 1 | { 2 | "global_cfg": { 3 | "semantic_audio_length_seconds": 10.0, 4 | "coarse_audio_length_seconds": 4.0, 5 | "fine_audio_length_seconds": 2.0, 6 | "clap_audio_length_seconds": 10.0, 7 | "num_coarse_quantizers": 3, 8 | "num_fine_quantizers": 5 9 | }, 10 | "clap_rvq_cfg": { 11 | "enable_fusion": false, 12 | "rq_num_quantizers": 12, 13 | "codebook_size": 1024, 14 | "rq_ema_decay": 0.95, 15 | "threshold_ema_dead_code": 0.5 16 | }, 17 | "hubert_kmeans_cfg": { 18 | "model_name": "m-a-p/MERT-v0", 19 | "normalize_embeds": true, 20 | "embed_layer": 7, 21 | "target_sample_hz": 16000, 22 | "seq_len_multiple_of": 320, 23 | "codebook_size": 1024, 24 | "output_hz": 50 25 | }, 26 | "encodec_cfg": { 27 | "bandwidth": 6.0, 28 | "codebook_size": 1024, 29 | "output_hz": 75 30 | }, 31 | "semantic_cfg": { 32 | "dim": 1024, 33 | "depth": 24, 34 | "heads": 16, 35 | "attn_dropout": 0.0, 36 | "ff_dropout": 0.1, 37 | "grad_shrink_alpha": 0.1, 38 | "non_causal_prefix_size": 0, 39 | "relative_position_bias_type": "continuous", 40 | "use_memory_efficient_attention": false 41 | }, 42 | "coarse_cfg": { 43 | "dim": 1024, 44 | "depth": 24, 45 | "heads": 16, 46 | "attn_dropout": 0.0, 47 | "ff_dropout": 0.1, 48 | "grad_shrink_alpha": 0.1, 49 | "non_causal_prefix_size": 0, 50 | "relative_position_bias_type": "continuous", 51 | "use_memory_efficient_attention": false 52 | }, 53 | "fine_cfg": { 54 | "dim": 1024, 55 | "depth": 24, 56 | "heads": 16, 57 | "attn_dropout": 0.0, 58 | "ff_dropout": 0.1, 59 | "grad_shrink_alpha": 0.1, 60 | "non_causal_prefix_size": 0, 61 | "relative_position_bias_type": "continuous", 62 | "use_memory_efficient_attention": false 63 | } 64 | } -------------------------------------------------------------------------------- /configs/model/musiclm_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "global_cfg": { 3 | "semantic_audio_length_seconds": 10.0, 4 | "coarse_audio_length_seconds": 4.0, 5 | "fine_audio_length_seconds": 2.0, 6 | "clap_audio_length_seconds": 10.0, 7 | "num_coarse_quantizers": 3, 8 | "num_fine_quantizers": 5 9 | }, 10 | "clap_rvq_cfg": { 11 | "enable_fusion": false, 12 | "rq_num_quantizers": 12, 13 | "codebook_size": 1024, 14 | "rq_ema_decay": 0.95, 15 | "threshold_ema_dead_code": 0.5 16 | }, 17 | "hubert_kmeans_cfg": { 18 | "model_name": "m-a-p/MERT-v0", 19 | "normalize_embeds": true, 20 | "embed_layer": 7, 21 | "target_sample_hz": 16000, 22 | "seq_len_multiple_of": 320, 23 | "codebook_size": 1024, 24 | "output_hz": 50 25 | }, 26 | "encodec_cfg": { 27 | "bandwidth": 6.0, 28 | "codebook_size": 1024, 29 | "output_hz": 75 30 | }, 31 | "semantic_cfg": { 32 | "dim": 1024, 33 | "depth": 6, 34 | "heads": 8, 35 | "attn_dropout": 0.0, 36 | "ff_dropout": 0.1, 37 | "grad_shrink_alpha": 0.1, 38 | "non_causal_prefix_size": 0, 39 | "relative_position_bias_type": "continuous", 40 | "use_memory_efficient_attention": false 41 | }, 42 | "coarse_cfg": { 43 | "dim": 1024, 44 | "depth": 6, 45 | "heads": 8, 46 | "attn_dropout": 0.0, 47 | "ff_dropout": 0.1, 48 | "grad_shrink_alpha": 0.1, 49 | "non_causal_prefix_size": 0, 50 | "relative_position_bias_type": "continuous", 51 | "use_memory_efficient_attention": false 52 | }, 53 | "fine_cfg": { 54 | "dim": 1024, 55 | "depth": 6, 56 | "heads": 
8, 57 | "attn_dropout": 0.0, 58 | "ff_dropout": 0.1, 59 | "grad_shrink_alpha": 0.1, 60 | "non_causal_prefix_size": 0, 61 | "relative_position_bias_type": "continuous", 62 | "use_memory_efficient_attention": false 63 | } 64 | } -------------------------------------------------------------------------------- /configs/training/train_fma_preprocess.json: -------------------------------------------------------------------------------- 1 | { 2 | "clap_rvq_trainer_cfg": { 3 | "folder": "./data/fma_large", 4 | "num_train_steps": 1000, 5 | "batch_size": 64, 6 | "accumulate_batches": 32, 7 | "save_model_every": 10, 8 | "save_results_every": 5 9 | }, 10 | "hubert_kmeans_trainer_cfg": { 11 | "folder": "./data/fma_large", 12 | "feature_extraction_num_steps": 320, 13 | "feature_extraction_batch_size": 32 14 | }, 15 | "semantic_trainer_cfg": { 16 | "stage": "semantic", 17 | "folder": "./data/fma_preprocessed", 18 | "valid_frac": 0.05, 19 | "lr": 0.0003, 20 | "lr_warmup": 3000, 21 | "batch_size": 4, 22 | "grad_accum_every": 8, 23 | "wd": 0.01, 24 | "max_grad_norm": 0.5, 25 | "cross_entropy_loss_weights": [0.0, 1.0], 26 | "num_train_steps": 200001, 27 | "save_results_every": 250, 28 | "save_model_every": 1000, 29 | "save_predicted_tokens": true, 30 | "save_reconstructed_wave": true, 31 | "use_preprocessed_data": true 32 | }, 33 | "coarse_trainer_cfg": { 34 | "stage": "coarse", 35 | "folder": "./data/fma_preprocessed", 36 | "valid_frac": 0.05, 37 | "lr": 0.0003, 38 | "lr_warmup": 6000, 39 | "batch_size": 2, 40 | "grad_accum_every": 8, 41 | "wd": 0.01, 42 | "max_grad_norm": 0.5, 43 | "cross_entropy_loss_weights": [0.0, 0.0, 1.0], 44 | "num_train_steps": 200001, 45 | "save_results_every": 250, 46 | "save_model_every": 1000, 47 | "save_predicted_tokens": true, 48 | "save_reconstructed_wave": true, 49 | "use_preprocessed_data": true 50 | }, 51 | "fine_trainer_cfg": { 52 | "stage": "fine", 53 | "folder": "./data/fma_preprocessed", 54 | "valid_frac": 0.05, 55 | "lr": 0.0003, 56 | "lr_warmup": 0, 57 | "batch_size": 2, 58 | "grad_accum_every": 8, 59 | "wd": 0.01, 60 | "max_grad_norm": 0.5, 61 | "cross_entropy_loss_weights": [0.0, 0.0, 1.0], 62 | "num_train_steps": 200001, 63 | "save_results_every": 250, 64 | "save_model_every": 1000, 65 | "save_predicted_tokens": true, 66 | "save_reconstructed_wave": true, 67 | "use_preprocessed_data": true 68 | }, 69 | "data_preprocessor_cfg": { 70 | "folder": "./data/fma_large", 71 | "metadata_folder": "./data/fma_metadata", 72 | "results_folder": "./data/fma_preprocessed", 73 | "max_audio_length_seconds": 30, 74 | "random_crop": true, 75 | "num_crops": 1, 76 | "clap_batch_size": 32 77 | } 78 | } -------------------------------------------------------------------------------- /configs/training/train_musiclm_fma.json: -------------------------------------------------------------------------------- 1 | { 2 | "clap_rvq_trainer_cfg": { 3 | "folder": "./data/fma_large", 4 | "num_train_steps": 1000, 5 | "batch_size": 64, 6 | "accumulate_batches": 32, 7 | "save_model_every": 10, 8 | "save_results_every": 5 9 | }, 10 | "hubert_kmeans_trainer_cfg": { 11 | "folder": "./data/fma_large", 12 | "feature_extraction_num_steps": 320, 13 | "feature_extraction_batch_size": 32 14 | }, 15 | "semantic_trainer_cfg": { 16 | "stage": "semantic", 17 | "folder": "./data/fma_large", 18 | "valid_frac": 0.05, 19 | "lr": 0.0003, 20 | "lr_warmup": 3000, 21 | "batch_size": 4, 22 | "grad_accum_every": 8, 23 | "wd": 0.01, 24 | "max_grad_norm": 0.5, 25 | "cross_entropy_loss_weights": [0.0, 1.0], 26 | 
"num_train_steps": 200001, 27 | "save_results_every": 250, 28 | "save_model_every": 1000, 29 | "save_predicted_tokens": true, 30 | "save_reconstructed_wave": true, 31 | "use_preprocessed_data": false 32 | }, 33 | "coarse_trainer_cfg": { 34 | "stage": "coarse", 35 | "folder": "./data/fma_large", 36 | "valid_frac": 0.05, 37 | "lr": 0.0003, 38 | "lr_warmup": 6000, 39 | "batch_size": 2, 40 | "grad_accum_every": 8, 41 | "wd": 0.01, 42 | "max_grad_norm": 0.5, 43 | "cross_entropy_loss_weights": [0.0, 0.0, 1.0], 44 | "num_train_steps": 200001, 45 | "save_results_every": 250, 46 | "save_model_every": 1000, 47 | "save_predicted_tokens": true, 48 | "save_reconstructed_wave": true, 49 | "use_preprocessed_data": false 50 | }, 51 | "fine_trainer_cfg": { 52 | "stage": "fine", 53 | "folder": "./data/fma_large", 54 | "valid_frac": 0.05, 55 | "lr": 0.0003, 56 | "lr_warmup": 0, 57 | "batch_size": 2, 58 | "grad_accum_every": 8, 59 | "wd": 0.01, 60 | "max_grad_norm": 0.5, 61 | "cross_entropy_loss_weights": [0.0, 0.0, 1.0], 62 | "num_train_steps": 200001, 63 | "save_results_every": 250, 64 | "save_model_every": 1000, 65 | "save_predicted_tokens": true, 66 | "save_reconstructed_wave": true, 67 | "use_preprocessed_data": false 68 | }, 69 | "data_preprocessor_cfg": {} 70 | } -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: open-musiclm 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - python=3.10 8 | - pip=22.3.1 9 | - pip: 10 | - --find-links https://download.pytorch.org/whl/torch_stable.html 11 | - torch==2.0.0+cu117 12 | - torchvision==0.15.1+cu117 13 | - torchaudio==2.0.1+cu117 14 | - einops>=0.4 15 | - vector-quantize-pytorch>=0.10.15 16 | - librosa==0.10.0 17 | - torchlibrosa==0.1.0 18 | - ftfy 19 | - tqdm 20 | - transformers 21 | - encodec==0.1.1 22 | - gdown 23 | - accelerate>=0.17.0 24 | - beartype 25 | - joblib 26 | - h5py 27 | - scikit-learn 28 | - wget 29 | -------------------------------------------------------------------------------- /musiclm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/musiclm.png -------------------------------------------------------------------------------- /open_musiclm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/open_musiclm/__init__.py -------------------------------------------------------------------------------- /open_musiclm/clap_quantized.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torchaudio 5 | import torchvision.transforms 6 | from beartype.typing import Dict, List, Optional, Union 7 | from einops import rearrange 8 | from torch import nn 9 | from transformers import RobertaTokenizer 10 | from vector_quantize_pytorch import ResidualVQ 11 | 12 | from .laion_clap import CLAP_Module 13 | from .utils import exists, beartype_jit 14 | 15 | 16 | @beartype_jit 17 | class ClapQuantized(nn.Module): 18 | def __init__(self, 19 | *, 20 | clap: CLAP_Module, 21 | codebook_size: int = 1024, 22 | rq_num_quantizers: int = 12, 23 | rq_ema_decay: float = 0.95, 24 | learn_rvq: bool = False, 25 | 
threshold_ema_dead_code: float = 0.0, 26 | ): 27 | super().__init__() 28 | 29 | self.clap = clap 30 | self.codebook_size = codebook_size 31 | self.learn_rvq = learn_rvq 32 | 33 | self.sample_rate = self.clap.model_cfg['audio_cfg']['sample_rate'] 34 | 35 | for param in self.clap.parameters(): 36 | param.requires_grad = False 37 | 38 | self.rq = ResidualVQ( 39 | dim=clap.model.joint_embed_shape, 40 | num_quantizers=rq_num_quantizers, # specify number of quantizers 41 | codebook_size=codebook_size, # codebook size 42 | commitment_weight=0, # embeddings are frozen so no need for commitment loss 43 | decay=rq_ema_decay, 44 | kmeans_init=True, 45 | threshold_ema_dead_code=threshold_ema_dead_code, 46 | ) 47 | 48 | def forward(self, 49 | *, 50 | audio_input: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, 51 | text_input: Optional[List[str]] = None, 52 | return_embedding: Optional[bool] = False, 53 | return_rvq_loss = False, 54 | ): 55 | """ 56 | Wrapper for clap module that takes in audio or text and returns the quantized embedding from the respective tower 57 | """ 58 | 59 | assert exists(audio_input) ^ exists(text_input), "either audio or text must be provided, but not both" 60 | if exists(audio_input): 61 | assert all(wave.dim() == 1 for wave in audio_input), f"audio_input must be a list of 1D tensors, but got {audio_input[0].shape}" 62 | 63 | with torch.no_grad(): 64 | self.clap.eval() 65 | if exists(audio_input): 66 | embedding = self.clap.get_audio_embedding_from_data(audio_input) 67 | else: 68 | embedding = self.clap.get_text_embedding(text_input) 69 | 70 | if return_embedding: 71 | return embedding 72 | 73 | return self.quantize(embedding, return_rvq_loss=return_rvq_loss) 74 | 75 | def quantize(self, embedding, return_rvq_loss=False): 76 | """ 77 | Quantize an embedding and optionally return the loss 78 | """ 79 | with torch.set_grad_enabled(self.learn_rvq): 80 | self.rq.train(self.learn_rvq) 81 | q, indices, _ = self.rq(rearrange(embedding, 'n c -> n 1 c')) 82 | 83 | if return_rvq_loss: 84 | return F.mse_loss(q, rearrange(embedding, 'n c -> n 1 c')).item() 85 | 86 | indices = rearrange(indices, 'n 1 c -> n c 1') 87 | return indices 88 | 89 | 90 | def create_clap_quantized( 91 | device=None, 92 | learn_rvq=False, 93 | enable_fusion=False, 94 | rvq_checkpoint_path=None, 95 | checkpoint_path: Optional[str] = None, 96 | amodel_type: str = 'HTSAT-tiny', 97 | **kwargs 98 | ): 99 | if device is None: 100 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 101 | 102 | clap = CLAP_Module(enable_fusion=enable_fusion, device=device, amodel=amodel_type) 103 | clap.load_ckpt(ckpt=checkpoint_path) 104 | 105 | clap_quantized = ClapQuantized(clap=clap, learn_rvq=learn_rvq, **kwargs) 106 | 107 | if exists(rvq_checkpoint_path): 108 | rvq = torch.load(rvq_checkpoint_path, map_location=device) 109 | clap_quantized.rq.load_state_dict(rvq) 110 | 111 | return clap_quantized 112 | -------------------------------------------------------------------------------- /open_musiclm/encodec_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from encodec import EncodecModel 6 | from torch import nn 7 | 8 | from .utils import exists, beartype_jit 9 | 10 | 11 | @beartype_jit 12 | class EncodecWrapper(nn.Module): 13 | def __init__(self, 14 | *, 15 | encodec: EncodecModel, 16 | output_hz: int = 75, 17 | ): 18 | super().__init__() 19 | 20 | self.encodec = 
encodec 21 | self.sample_rate = encodec.sample_rate 22 | self.output_hz = output_hz 23 | 24 | assert exists(encodec.bandwidth) 25 | total_quantizers = encodec.quantizer.n_q 26 | self.num_quantizers = int(encodec.bandwidth / 24 * total_quantizers) # output quantizers per frame 27 | self.codebook_size = encodec.quantizer.bins 28 | 29 | def forward(self, x: torch.Tensor, return_encoded = True, **kwargs): 30 | assert return_encoded == True 31 | 32 | if x.dim() == 2: 33 | x = rearrange(x, 'b t -> b 1 t') # add in "mono" dimension 34 | 35 | with torch.no_grad(): 36 | self.encodec.eval() 37 | encoded_frames = self.encodec.encode(x) 38 | codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T] 39 | codes = rearrange(codes, 'b n_q t -> b t n_q') 40 | 41 | return None, codes, None # [B, T, n_q] 42 | 43 | def decode_from_codebook_indices(self, quantized_indices): 44 | """ 45 | Args: 46 | quantized_indices: [B, T, n_q] 47 | """ 48 | quantized_indices = rearrange(quantized_indices, 'b t n_q -> b n_q t') 49 | 50 | frames = [(quantized_indices, None)] # 1 frame for now 51 | with torch.no_grad(): 52 | self.encodec.eval() 53 | wave = self.encodec.decode(frames) 54 | return wave 55 | 56 | def create_encodec_24khz(bandwidth: float = 6.0, codebook_size: int = 1024, **kwargs): 57 | """ 58 | Create a pretrained EnCodec model. 59 | Args: 60 | bandwidth: float, target bandwidth in kHz""" 61 | assert bandwidth in [1.5, 3., 6., 12., 24.], "invalid bandwidth. must be one of [1.5, 3., 6., 12., 24.]" 62 | 63 | encodec = EncodecModel.encodec_model_24khz() 64 | encodec.set_target_bandwidth(bandwidth) 65 | encodec_wrapper = EncodecWrapper(encodec=encodec, **kwargs) 66 | 67 | assert encodec_wrapper.codebook_size == codebook_size, "encodec codebook size must be 1024 for now" 68 | 69 | return encodec_wrapper -------------------------------------------------------------------------------- /open_musiclm/hf_hubert_kmeans.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | from einops import rearrange, pack, unpack 7 | from beartype.typing import Optional 8 | 9 | from torchaudio.functional import resample 10 | from .utils import exists, curtail_to_multiple, zero_mean_unit_var_norm 11 | from transformers import HubertModel 12 | from sklearn.cluster import MiniBatchKMeans 13 | 14 | import joblib 15 | import logging 16 | logging.root.setLevel(logging.ERROR) 17 | 18 | 19 | class HfHubertWithKmeans(nn.Module): 20 | """ 21 | Hugging Face HubertModel + a k-means layer on top. Pretrained checkpoint for music: https://huggingface.co/m-a-p/MERT-v0 22 | Note: MERT-v0 outputs features at 50Hz while Wav2Vec-BERT (used in the paper) outputs at 25 Hz. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | *, 28 | hubert: HubertModel, 29 | kmeans: Optional[MiniBatchKMeans] = None, 30 | embed_layer: int=7, 31 | target_sample_hz=16000, 32 | seq_len_multiple_of=int(16000 / 50), 33 | normalize_embeds=True, 34 | codebook_size: int=1024, 35 | output_hz: int=50 36 | ): 37 | super().__init__() 38 | self.target_sample_hz = target_sample_hz 39 | self.output_hz = output_hz 40 | self.seq_len_multiple_of = seq_len_multiple_of 41 | self.codebook_size = kmeans.n_clusters if exists(kmeans) else None 42 | 43 | self.codebook_size = codebook_size 44 | if exists(kmeans): 45 | assert self.codebook_size == kmeans.n_clusters, "codebook_size must match kmeans.n_clusters" 46 | 47 | self.normalize_embeds = normalize_embeds 48 | 49 | self.embed_layer = embed_layer 50 | 51 | self.hubert = hubert 52 | self.kmeans = kmeans 53 | 54 | @torch.no_grad() 55 | def forward( 56 | self, 57 | wav_input: torch.Tensor, 58 | flatten=True, 59 | return_embed=False, 60 | input_sample_hz=None 61 | ): 62 | assert return_embed or exists(self.kmeans), "kmeans model must be provided if return_embed==False" 63 | 64 | device = wav_input.device 65 | 66 | if exists(input_sample_hz): 67 | wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz) 68 | 69 | if exists(self.seq_len_multiple_of): 70 | wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) 71 | 72 | hubert_args = { 73 | 'input_values': wav_input, 74 | 'attention_mask': torch.ones_like(wav_input, device=device), # TODO: handle padding 75 | } 76 | 77 | outputs = self.hubert(**hubert_args, output_hidden_states = True) 78 | embed = outputs.hidden_states[self.embed_layer] 79 | 80 | if self.normalize_embeds: 81 | embed = zero_mean_unit_var_norm(embed) 82 | 83 | if return_embed: 84 | return embed 85 | 86 | embed, packed_shape = pack([embed], '* d') 87 | codebook_indices = self.kmeans.predict(embed.detach().cpu().numpy()) 88 | codebook_indices = torch.from_numpy(codebook_indices).to(device).long() 89 | 90 | if flatten: 91 | return codebook_indices 92 | 93 | codebook_indices, = unpack(codebook_indices, packed_shape, '*') 94 | return codebook_indices 95 | 96 | 97 | def get_kmeans_model( 98 | n_clusters, 99 | init, 100 | max_iter, 101 | batch_size, 102 | tol, 103 | max_no_improvement, 104 | n_init, 105 | reassignment_ratio, 106 | ): 107 | return MiniBatchKMeans( 108 | n_clusters=n_clusters, 109 | init=init, 110 | max_iter=max_iter, 111 | batch_size=batch_size, 112 | verbose=1, 113 | compute_labels=False, 114 | tol=tol, 115 | max_no_improvement=max_no_improvement, 116 | init_size=None, 117 | n_init=n_init, 118 | reassignment_ratio=reassignment_ratio, 119 | ) 120 | 121 | 122 | def learn_kmeans( 123 | feat, 124 | seed, 125 | km_path='./results/kmeans.joblib', 126 | n_clusters=1024, 127 | init="k-means++", 128 | max_iter=100, 129 | batch_size=10000, 130 | tol=0.0, 131 | n_init=20, 132 | reassignment_ratio=0.0, 133 | max_no_improvement=100, 134 | ): 135 | np.random.seed(seed) 136 | km_model = get_kmeans_model( 137 | n_clusters, 138 | init, 139 | max_iter, 140 | batch_size, 141 | tol, 142 | max_no_improvement, 143 | n_init, 144 | reassignment_ratio, 145 | ) 146 | km_model.fit(feat) 147 | joblib.dump(km_model, km_path) 148 | 149 | inertia = -km_model.score(feat) / len(feat) 150 | print("total intertia: %.5f", inertia) 151 | print("finished successfully") 152 | 153 | 154 | def get_hubert_kmeans(model_name: str="m-a-p/MERT-v0", kmeans_path: Optional[str]='./checkpoints/kmeans.joblib', **kwargs): 155 | wav2vec = 
HubertModel.from_pretrained(model_name) 156 | kmeans = joblib.load(kmeans_path) if exists(kmeans_path) else None 157 | 158 | return HfHubertWithKmeans(hubert=wav2vec, kmeans=kmeans, **kwargs) 159 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | dir_path = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(dir_path) 5 | from hook import CLAP_Module -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config 2 | from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics 3 | from .model import CLAP, CLAPTextCfg, CLAPVisionCfg, CLAPAudioCfp, convert_weights_to_fp16, trace_model 4 | from .openai import load_openai_model, list_openai_models 5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\ 6 | get_pretrained_url, download_pretrained 7 | from .tokenizer import SimpleTokenizer, tokenize 8 | from .transform import image_transform -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/bert.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, BertModel 2 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 3 | model = BertModel.from_pretrained("bert-base-uncased") 4 | text = "Replace me by any text you'd like." 5 | 6 | def bert_embeddings(text): 7 | # text = "Replace me by any text you'd like." 8 | encoded_input = tokenizer(text, return_tensors='pt') 9 | output = model(**encoded_input) 10 | return output 11 | 12 | from transformers import RobertaTokenizer, RobertaModel 13 | 14 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 15 | model = RobertaModel.from_pretrained('roberta-base') 16 | text = "Replace me by any text you'd like." 17 | def Roberta_embeddings(text): 18 | # text = "Replace me by any text you'd like." 19 | encoded_input = tokenizer(text, return_tensors='pt') 20 | output = model(**encoded_input) 21 | return output 22 | 23 | from transformers import BartTokenizer, BartModel 24 | 25 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') 26 | model = BartModel.from_pretrained('facebook/bart-base') 27 | text = "Replace me by any text you'd like." 28 | def bart_embeddings(text): 29 | # text = "Replace me by any text you'd like." 
30 | encoded_input = tokenizer(text, return_tensors='pt') 31 | output = model(**encoded_input) 32 | return output -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/open_musiclm/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | from .model import CLAP, convert_weights_to_fp16 12 | from .openai import load_openai_model 13 | from .pretrained import get_pretrained_url, download_pretrained 14 | from .transform import image_transform 15 | 16 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 17 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 18 | 19 | 20 | def _natural_key(string_): 21 | return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] 22 | 23 | 24 | def _rescan_model_configs(): 25 | global _MODEL_CONFIGS 26 | 27 | config_ext = (".json",) 28 | config_files = [] 29 | for config_path in _MODEL_CONFIG_PATHS: 30 | if config_path.is_file() and config_path.suffix in config_ext: 31 | config_files.append(config_path) 32 | elif config_path.is_dir(): 33 | for ext in config_ext: 34 | config_files.extend(config_path.glob(f"*{ext}")) 35 | 36 | for cf in config_files: 37 | with open(cf, "r") as f: 38 | model_cfg = json.load(f) 39 | if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): 40 | _MODEL_CONFIGS[cf.stem] = model_cfg 41 | 42 | _MODEL_CONFIGS = { 43 | k: v 44 | for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) 45 | } 46 | 47 | 48 | _rescan_model_configs() # initial populate of model config registry 49 | 50 | 51 | def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): 52 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 53 | if isinstance(checkpoint, dict) and "state_dict" in checkpoint: 54 | state_dict = checkpoint["state_dict"] 55 | else: 56 | state_dict = checkpoint 57 | if skip_params: 58 | if next(iter(state_dict.items()))[0].startswith("module"): 59 | state_dict = {k[7:]: v for k, v in state_dict.items()} 60 | # for k in state_dict: 61 | # if k.startswith('transformer'): 62 | # v = state_dict.pop(k) 63 | # state_dict['text_branch.' 
+ k[12:]] = v 64 | return state_dict 65 | 66 | 67 | def create_model( 68 | amodel_name: str, 69 | tmodel_name: str, 70 | pretrained: str = "", 71 | precision: str = "fp32", 72 | device: torch.device = torch.device("cpu"), 73 | jit: bool = False, 74 | force_quick_gelu: bool = False, 75 | openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"), 76 | skip_params=True, 77 | pretrained_audio: str = "", 78 | pretrained_text: str = "", 79 | enable_fusion: bool = False, 80 | fusion_type: str = 'None' 81 | # pretrained_image: bool = False, 82 | ): 83 | amodel_name = amodel_name.replace( 84 | "/", "-" 85 | ) # for callers using old naming with / in ViT names 86 | pretrained_orig = pretrained 87 | pretrained = pretrained.lower() 88 | if pretrained == "openai": 89 | if amodel_name in _MODEL_CONFIGS: 90 | logging.info(f"Loading {amodel_name} model config.") 91 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) 92 | else: 93 | logging.error( 94 | f"Model config for {amodel_name} not found; available models {list_models()}." 95 | ) 96 | raise RuntimeError(f"Model config for {amodel_name} not found.") 97 | 98 | logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") 99 | # Hard Code in model name 100 | model_cfg["text_cfg"]["model_type"] = tmodel_name 101 | model = load_openai_model( 102 | "ViT-B-16", 103 | model_cfg, 104 | device=device, 105 | jit=jit, 106 | cache_dir=openai_model_cache_dir, 107 | enable_fusion=enable_fusion, 108 | fusion_type=fusion_type 109 | ) 110 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 111 | if precision == "amp" or precision == "fp32": 112 | model = model.float() 113 | else: 114 | if amodel_name in _MODEL_CONFIGS: 115 | logging.info(f"Loading {amodel_name} model config.") 116 | model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) 117 | else: 118 | logging.error( 119 | f"Model config for {amodel_name} not found; available models {list_models()}." 120 | ) 121 | raise RuntimeError(f"Model config for {amodel_name} not found.") 122 | 123 | if force_quick_gelu: 124 | # override for use of QuickGELU on non-OpenAI transformer models 125 | model_cfg["quick_gelu"] = True 126 | 127 | # if pretrained_image: 128 | # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): 129 | # # pretrained weight loading for timm models set via vision_cfg 130 | # model_cfg['vision_cfg']['timm_model_pretrained'] = True 131 | # else: 132 | # assert False, 'pretrained image towers currently only supported for timm models' 133 | model_cfg["text_cfg"]["model_type"] = tmodel_name 134 | model_cfg["enable_fusion"] = enable_fusion 135 | model_cfg["fusion_type"] = fusion_type 136 | model = CLAP(**model_cfg) 137 | 138 | if pretrained: 139 | checkpoint_path = "" 140 | url = get_pretrained_url(amodel_name, pretrained) 141 | if url: 142 | checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) 143 | elif os.path.exists(pretrained_orig): 144 | checkpoint_path = pretrained_orig 145 | if checkpoint_path: 146 | logging.info(f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained}).") 147 | ckpt = load_state_dict(checkpoint_path, skip_params=True) 148 | model.load_state_dict(ckpt) 149 | param_names = [n for n, p in model.named_parameters()] 150 | for n in param_names: 151 | print(n, "\t", "Loaded" if n in ckpt else "Unloaded") 152 | else: 153 | logging.warning( 154 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 
155 | ) 156 | raise RuntimeError( 157 | f"Pretrained weights ({pretrained}) not found for model {amodel_name}." 158 | ) 159 | 160 | if pretrained_audio: 161 | if amodel_name.startswith('PANN'): 162 | if 'Cnn14_mAP' in pretrained_audio: # official checkpoint 163 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 164 | audio_ckpt = audio_ckpt['model'] 165 | keys = list(audio_ckpt.keys()) 166 | for key in keys: 167 | if 'spectrogram_extractor' not in key and 'logmel_extractor' not in key: 168 | v = audio_ckpt.pop(key) 169 | audio_ckpt['audio_branch.' + key] = v 170 | elif os.path.basename(pretrained_audio).startswith('PANN'): # checkpoint trained via HTSAT codebase 171 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 172 | audio_ckpt = audio_ckpt['state_dict'] 173 | keys = list(audio_ckpt.keys()) 174 | for key in keys: 175 | if key.startswith('sed_model'): 176 | v = audio_ckpt.pop(key) 177 | audio_ckpt['audio_branch.' + key[10:]] = v 178 | elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase 179 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 180 | else: 181 | raise ValueError('Unknown audio checkpoint') 182 | elif amodel_name.startswith('HTSAT'): 183 | if 'HTSAT_AudioSet_Saved' in pretrained_audio: # official checkpoint 184 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 185 | audio_ckpt = audio_ckpt['state_dict'] 186 | keys = list(audio_ckpt.keys()) 187 | for key in keys: 188 | if key.startswith('sed_model') and ('spectrogram_extractor' not in key 189 | and 'logmel_extractor' not in key): 190 | v = audio_ckpt.pop(key) 191 | audio_ckpt['audio_branch.' + key[10:]] = v 192 | elif os.path.basename(pretrained_audio).startswith('HTSAT'): # checkpoint trained via HTSAT codebase 193 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 194 | audio_ckpt = audio_ckpt['state_dict'] 195 | keys = list(audio_ckpt.keys()) 196 | for key in keys: 197 | if key.startswith('sed_model'): 198 | v = audio_ckpt.pop(key) 199 | audio_ckpt['audio_branch.' 
+ key[10:]] = v 200 | elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase 201 | audio_ckpt = torch.load(pretrained_audio, map_location='cpu') 202 | else: 203 | raise ValueError('Unknown audio checkpoint') 204 | else: 205 | raise f'this audio encoder pretrained checkpoint is not support' 206 | 207 | model.load_state_dict(audio_ckpt, strict=False) 208 | logging.info(f"Loading pretrained {amodel_name} weights ({pretrained_audio}).") 209 | param_names = [n for n, p in model.named_parameters()] 210 | for n in param_names: 211 | print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") 212 | 213 | model.to(device=device) 214 | if precision == "fp16": 215 | assert device.type != "cpu" 216 | convert_weights_to_fp16(model) 217 | 218 | if jit: 219 | model = torch.jit.script(model) 220 | 221 | return model, model_cfg 222 | 223 | 224 | def create_model_and_transforms( 225 | model_name: str, 226 | pretrained: str = "", 227 | precision: str = "fp32", 228 | device: torch.device = torch.device("cpu"), 229 | jit: bool = False, 230 | force_quick_gelu: bool = False, 231 | # pretrained_image: bool = False, 232 | ): 233 | model = create_model( 234 | model_name, 235 | pretrained, 236 | precision, 237 | device, 238 | jit, 239 | force_quick_gelu=force_quick_gelu, 240 | # pretrained_image=pretrained_image 241 | ) 242 | preprocess_train = image_transform(model.visual.image_size, is_train=True) 243 | preprocess_val = image_transform(model.visual.image_size, is_train=False) 244 | return model, preprocess_train, preprocess_val 245 | 246 | 247 | def list_models(): 248 | """enumerate available model architectures based on config files""" 249 | return list(_MODEL_CONFIGS.keys()) 250 | 251 | 252 | def add_model_config(path): 253 | """add model config path or file and update registry""" 254 | if not isinstance(path, Path): 255 | path = Path(path) 256 | _MODEL_CONFIG_PATHS.append(path) 257 | _rescan_model_configs() 258 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/feature_fusion.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Feature Fusion for Varible-Length Data Processing 3 | AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py 4 | According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 5 | ''' 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class DAF(nn.Module): 12 | ''' 13 | 直接相加 DirectAddFuse 14 | ''' 15 | 16 | def __init__(self): 17 | super(DAF, self).__init__() 18 | 19 | def forward(self, x, residual): 20 | return x + residual 21 | 22 | 23 | class iAFF(nn.Module): 24 | ''' 25 | 多特征融合 iAFF 26 | ''' 27 | 28 | def __init__(self, channels=64, r=4, type='2D'): 29 | super(iAFF, self).__init__() 30 | inter_channels = int(channels // r) 31 | 32 | if type == '1D': 33 | # 本地注意力 34 | self.local_att = nn.Sequential( 35 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 36 | nn.BatchNorm1d(inter_channels), 37 | nn.ReLU(inplace=True), 38 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 39 | nn.BatchNorm1d(channels), 40 | ) 41 | 42 | # 全局注意力 43 | self.global_att = nn.Sequential( 44 | nn.AdaptiveAvgPool1d(1), 45 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 46 | nn.BatchNorm1d(inter_channels), 47 | 
nn.ReLU(inplace=True), 48 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 49 | nn.BatchNorm1d(channels), 50 | ) 51 | 52 | # 第二次本地注意力 53 | self.local_att2 = nn.Sequential( 54 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 55 | nn.BatchNorm1d(inter_channels), 56 | nn.ReLU(inplace=True), 57 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 58 | nn.BatchNorm1d(channels), 59 | ) 60 | # 第二次全局注意力 61 | self.global_att2 = nn.Sequential( 62 | nn.AdaptiveAvgPool1d(1), 63 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 64 | nn.BatchNorm1d(inter_channels), 65 | nn.ReLU(inplace=True), 66 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 67 | nn.BatchNorm1d(channels), 68 | ) 69 | elif type == '2D': 70 | # 本地注意力 71 | self.local_att = nn.Sequential( 72 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 73 | nn.BatchNorm2d(inter_channels), 74 | nn.ReLU(inplace=True), 75 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 76 | nn.BatchNorm2d(channels), 77 | ) 78 | 79 | # 全局注意力 80 | self.global_att = nn.Sequential( 81 | nn.AdaptiveAvgPool2d(1), 82 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 83 | nn.BatchNorm2d(inter_channels), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 86 | nn.BatchNorm2d(channels), 87 | ) 88 | 89 | # 第二次本地注意力 90 | self.local_att2 = nn.Sequential( 91 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 92 | nn.BatchNorm2d(inter_channels), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 95 | nn.BatchNorm2d(channels), 96 | ) 97 | # 第二次全局注意力 98 | self.global_att2 = nn.Sequential( 99 | nn.AdaptiveAvgPool2d(1), 100 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 101 | nn.BatchNorm2d(inter_channels), 102 | nn.ReLU(inplace=True), 103 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 104 | nn.BatchNorm2d(channels), 105 | ) 106 | else: 107 | raise f'the type is not supported' 108 | 109 | self.sigmoid = nn.Sigmoid() 110 | 111 | def forward(self, x, residual): 112 | flag = False 113 | xa = x + residual 114 | if xa.size(0) == 1: 115 | xa = torch.cat([xa,xa],dim=0) 116 | flag = True 117 | xl = self.local_att(xa) 118 | xg = self.global_att(xa) 119 | xlg = xl + xg 120 | wei = self.sigmoid(xlg) 121 | xi = x * wei + residual * (1 - wei) 122 | 123 | xl2 = self.local_att2(xi) 124 | xg2 = self.global_att(xi) 125 | xlg2 = xl2 + xg2 126 | wei2 = self.sigmoid(xlg2) 127 | xo = x * wei2 + residual * (1 - wei2) 128 | if flag: 129 | xo = xo[0].unsqueeze(0) 130 | return xo 131 | 132 | 133 | class AFF(nn.Module): 134 | ''' 135 | 多特征融合 AFF 136 | ''' 137 | 138 | def __init__(self, channels=64, r=4, type='2D'): 139 | super(AFF, self).__init__() 140 | inter_channels = int(channels // r) 141 | 142 | if type == '1D': 143 | self.local_att = nn.Sequential( 144 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 145 | nn.BatchNorm1d(inter_channels), 146 | nn.ReLU(inplace=True), 147 | nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 148 | nn.BatchNorm1d(channels), 149 | ) 150 | self.global_att = nn.Sequential( 151 | nn.AdaptiveAvgPool1d(1), 152 | nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 153 | nn.BatchNorm1d(inter_channels), 154 | nn.ReLU(inplace=True), 155 
| nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 156 | nn.BatchNorm1d(channels), 157 | ) 158 | elif type == '2D': 159 | self.local_att = nn.Sequential( 160 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 161 | nn.BatchNorm2d(inter_channels), 162 | nn.ReLU(inplace=True), 163 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 164 | nn.BatchNorm2d(channels), 165 | ) 166 | self.global_att = nn.Sequential( 167 | nn.AdaptiveAvgPool2d(1), 168 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), 169 | nn.BatchNorm2d(inter_channels), 170 | nn.ReLU(inplace=True), 171 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 172 | nn.BatchNorm2d(channels), 173 | ) 174 | else: 175 | raise ValueError(f'the type {type} is not supported.') 176 | 177 | self.sigmoid = nn.Sigmoid() 178 | 179 | def forward(self, x, residual): 180 | flag = False 181 | xa = x + residual 182 | if xa.size(0) == 1: 183 | xa = torch.cat([xa,xa],dim=0) 184 | flag = True 185 | xl = self.local_att(xa) 186 | xg = self.global_att(xa) 187 | xlg = xl + xg 188 | wei = self.sigmoid(xlg) 189 | xo = 2 * x * wei + 2 * residual * (1 - wei) 190 | if flag: 191 | xo = xo[0].unsqueeze(0) 192 | return xo 193 | 194 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/linear_probe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from .model import MLPLayers 5 | 6 | 7 | class LinearProbe(nn.Module): 8 | def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): 9 | """ 10 | Args: 11 | model: nn.Module 12 | mlp: bool, if True, then use the MLP layer as the linear probe module 13 | freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe 14 | in_ch: int, the output channel from CLAP model 15 | out_ch: int, the output channel from linear probe (class_num) 16 | act: torch.nn.functional, the activation function before the loss function 17 | """ 18 | super().__init__() 19 | in_ch = 512 # NOTE: hard-coded, overrides the in_ch argument 20 | self.clap_model = model 21 | self.clap_model.text_branch = None # to save memory 22 | self.freeze = freeze 23 | if mlp: 24 | self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) 25 | else: 26 | self.lp_layer = nn.Linear(in_ch, out_ch) 27 | 28 | if self.freeze: 29 | for param in self.clap_model.parameters(): 30 | param.requires_grad = False 31 | 32 | if act == 'None': 33 | self.act = None 34 | elif act == 'relu': 35 | self.act = nn.ReLU() 36 | elif act == 'elu': 37 | self.act = nn.ELU() 38 | elif act == 'prelu': 39 | self.act = nn.PReLU(num_parameters=in_ch) 40 | elif act == 'softmax': 41 | self.act = nn.Softmax(dim=-1) 42 | elif act == 'sigmoid': 43 | self.act = nn.Sigmoid() 44 | 45 | def forward(self, x, mix_lambda=None, device=None): 46 | """ 47 | Args: 48 | x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list 49 | mix_lambda: torch.tensor [batch], the mixup lambda 50 | Returns: 51 | class_prob: torch.tensor [batch, class_num] 52 | 53 | """ 54 | # keep batchnorm frozen (no gradient / running-stat updates) 55 | if self.freeze: 56 | self.clap_model.eval() 57 | 58 | x = self.clap_model.audio_projection( 59 | self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)["embedding"]) 60 | out = self.lp_layer(x) 61 | if self.act is not None: 62 | out = self.act(out) 63 | return out 64 | 
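The attentional fusion modules defined in feature_fusion.py above both expose a two-input forward(x, residual) that blends two equally shaped feature maps with learned local and global attention weights. A minimal usage sketch follows (illustrative only, not part of the repository; the import path, the tensor shapes, and the name iAFF for the iterative class are assumptions):

import torch
from open_musiclm.laion_clap.clap_module.feature_fusion import AFF, iAFF  # assumed import path

# two feature maps of identical shape: (batch, channels, height, width) for the '2D' variant
x = torch.randn(2, 64, 32, 32)
residual = torch.randn(2, 64, 32, 32)

aff = AFF(channels=64, r=4, type='2D')
fused = aff(x, residual)        # attention-weighted blend, same shape as the inputs

# the iterative variant refines the fusion weights a second time;
# the '1D' variant expects (batch, channels, time)
iaff = iAFF(channels=64, r=4, type='1D')
fused_seq = iaff(torch.randn(2, 64, 128), torch.randn(2, 64, 128))

Note that AFF scales its blended output by a factor of 2 while the iterative variant returns an unscaled blend, and both duplicate a batch of size 1 internally so the BatchNorm layers after global pooling still work.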
-------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/HTSAT-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "base" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/HTSAT-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "large" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/HTSAT-tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "HTSAT", 14 | "model_name": "tiny" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/PANN-10.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn10" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- 
/open_musiclm/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 18000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 960000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 360, 10 | "fmin": 50, 11 | "fmax": 8000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 4 22 | } 23 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/PANN-14-win-1536.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1536, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/PANN-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 2048, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn14" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/PANN-6.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "audio_cfg": { 4 | "audio_length": 1024, 5 | "clip_samples": 480000, 6 | "mel_bins": 64, 7 | "sample_rate": 48000, 8 | "window_size": 1024, 9 | "hop_size": 480, 10 | "fmin": 50, 11 | "fmax": 14000, 12 | "class_num": 527, 13 | "model_type": "PANN", 14 | "model_name": "Cnn6" 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 512, 20 | "heads": 8, 21 | "layers": 12 22 | } 23 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 
3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_tag_models('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | model_cfg, 26 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 27 | jit=True, 28 | cache_dir=os.path.expanduser("~/.cache/clip"), 29 | enable_fusion: bool = False, 30 | fusion_type: str = 'None' 31 | ): 32 | """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model 33 | 34 | Parameters 35 | ---------- 36 | name : str 37 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLAP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if get_pretrained_url(name, 'openai'): 51 | model_path = download_pretrained(get_pretrained_url(name, 'openai'), root=cache_dir) 52 | elif os.path.isfile(name): 53 | model_path = name 54 | else: 55 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 56 | 57 | try: 58 | # loading JIT archive 59 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 60 | state_dict = None 61 | except RuntimeError: 62 | # loading saved state dict 63 | if jit: 64 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 65 | jit = False 66 | state_dict = torch.load(model_path, map_location="cpu") 67 | 68 | if not jit: 69 | try: 70 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type).to(device) 71 | except KeyError: 72 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 73 | model = build_model_from_openai_state_dict(sd, model_cfg, enable_fusion, fusion_type).to(device) 74 | 75 | if str(device) == "cpu": 76 | model.float() 77 | return model 78 | 79 | # patch the device names 80 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 81 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 82 | 83 | def patch_device(module): 84 | try: 85 | graphs = [module.graph] if hasattr(module, "graph") else [] 86 | except RuntimeError: 87 | graphs = [] 88 | 89 | if hasattr(module, "forward1"): 90 | graphs.append(module.forward1.graph) 91 | 92 | for graph in graphs: 93 | for node in graph.findAllNodes("prim::Constant"): 94 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 95 | node.copyAttributes(device_node) 96 | 97 | model.apply(patch_device) 98 | patch_device(model.encode_audio) 99 | patch_device(model.encode_text) 100 | 101 | # patch dtype to float32 on CPU 102 | if str(device) == "cpu": 103 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 104 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 105 | float_node = float_input.node() 106 | 107 | def patch_float(module): 108 | try: 109 | graphs = [module.graph] if hasattr(module, "graph") else [] 110 | except RuntimeError: 111 | graphs = [] 112 | 113 | if hasattr(module, "forward1"): 114 | graphs.append(module.forward1.graph) 115 | 116 | for graph in graphs: 117 | for node in graph.findAllNodes("aten::to"): 118 | inputs = list(node.inputs()) 119 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 120 | if inputs[i].node()["value"] == 5: 121 | inputs[i].node().copyAttributes(float_node) 122 | 123 | model.apply(patch_float) 124 | patch_float(model.encode_audio) 125 | patch_float(model.encode_text) 126 | model.float() 127 | 128 | model.audio_branch.audio_length = model.audio_cfg.audio_length 129 | return model 130 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | 6 | from tqdm import tqdm 7 | 8 | _RN50 = dict( 9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 12 | ) 13 | 14 | _RN50_quickgelu = dict( 15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 18 | ) 19 | 20 | _RN101 = 
dict( 21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 23 | ) 24 | 25 | _RN101_quickgelu = dict( 26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 28 | ) 29 | 30 | _RN50x4 = dict( 31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 32 | ) 33 | 34 | _RN50x16 = dict( 35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | ) 37 | 38 | _RN50x64 = dict( 39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 40 | ) 41 | 42 | _VITB32 = dict( 43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 44 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 45 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 46 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 47 | ) 48 | 49 | _VITB32_quickgelu = dict( 50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 53 | laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", 54 | ) 55 | 56 | _VITB16 = dict( 57 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 58 | ) 59 | 60 | _VITL14 = dict( 61 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 62 | ) 63 | 64 | _PRETRAINED = { 65 | "RN50": _RN50, 66 | "RN50-quickgelu": _RN50_quickgelu, 67 | "RN101": _RN101, 68 | "RN101-quickgelu": _RN101_quickgelu, 69 | "RN50x4": _RN50x4, 70 | "RN50x16": _RN50x16, 71 | "ViT-B-32": _VITB32, 72 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 73 | "ViT-B-16": _VITB16, 74 | "ViT-L-14": _VITL14, 75 | } 76 | 77 | 78 | def list_pretrained(as_str: bool = False): 79 | """ returns list of pretrained models 80 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 81 | """ 82 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 83 | 84 | 85 | def list_pretrained_tag_models(tag: str): 86 | """ return all models having the specified pretrain tag """ 87 | models = [] 88 | for k in _PRETRAINED.keys(): 89 | if tag in _PRETRAINED[k]: 90 | models.append(k) 91 | return models 92 | 93 | 94 | def 
list_pretrained_model_tags(model: str): 95 | """ return all pretrain tags for the specified model architecture """ 96 | tags = [] 97 | if model in _PRETRAINED: 98 | tags.extend(_PRETRAINED[model].keys()) 99 | return tags 100 | 101 | 102 | def get_pretrained_url(model: str, tag: str): 103 | if model not in _PRETRAINED: 104 | return '' 105 | model_pretrained = _PRETRAINED[model] 106 | if tag not in model_pretrained: 107 | return '' 108 | return model_pretrained[tag] 109 | 110 | 111 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): 112 | os.makedirs(root, exist_ok=True) 113 | filename = os.path.basename(url) 114 | 115 | if 'openaipublic' in url: 116 | expected_sha256 = url.split("/")[-2] 117 | else: 118 | expected_sha256 = '' 119 | 120 | download_target = os.path.join(root, filename) 121 | 122 | if os.path.exists(download_target) and not os.path.isfile(download_target): 123 | raise RuntimeError(f"{download_target} exists and is not a regular file") 124 | 125 | if os.path.isfile(download_target): 126 | if expected_sha256: 127 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 128 | return download_target 129 | else: 130 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 131 | else: 132 | return download_target 133 | 134 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 135 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 136 | while True: 137 | buffer = source.read(8192) 138 | if not buffer: 139 | break 140 | 141 | output.write(buffer) 142 | loop.update(len(buffer)) 143 | 144 | if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 145 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") 146 | 147 | return download_target 148 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 14 | except ImportError as e: 15 | timm = None 16 | 17 | from .utils import freeze_batch_norm_2d 18 | 19 | 20 | class TimmModel(nn.Module): 21 | """ timm model adapter 22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model_name, 28 | embed_dim, 29 | image_size=224, 30 | pool='avg', 31 | proj='linear', 32 | drop=0., 33 | pretrained=False): 34 | super().__init__() 35 | if timm is None: 36 | raise RuntimeError("Please `pip install timm` to use timm models.") 37 | 38 | self.image_size = to_2tuple(image_size) 39 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 40 | feat_size = self.trunk.default_cfg.get('pool_size', None) 41 | feature_ndim = 1 if not feat_size else 2 42 | if pool in ('abs_attn', 'rot_attn'): 43 | assert feature_ndim == 2 44 | # if attn pooling used, remove both classifier and default pool 45 | self.trunk.reset_classifier(0, global_pool='') 46 | else: 47 | # reset global pool if pool config set, otherwise leave as network default 48 | reset_kwargs = dict(global_pool=pool) if pool else {} 49 | self.trunk.reset_classifier(0, **reset_kwargs) 50 | prev_chs = self.trunk.num_features 51 | 52 | head_layers = OrderedDict() 53 | if pool == 'abs_attn': 54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 55 | prev_chs = embed_dim 56 | elif pool == 'rot_attn': 57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 58 | prev_chs = embed_dim 59 | else: 60 | assert proj, 'projection layer needed if non-attention pooling is used.' 
61 | 62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 63 | if proj == 'linear': 64 | head_layers['drop'] = nn.Dropout(drop) 65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim) 66 | elif proj == 'mlp': 67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 68 | 69 | self.head = nn.Sequential(head_layers) 70 | 71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 72 | """ lock modules 73 | Args: 74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 75 | """ 76 | if not unlocked_groups: 77 | # lock full model 78 | for param in self.trunk.parameters(): 79 | param.requires_grad = False 80 | if freeze_bn_stats: 81 | freeze_batch_norm_2d(self.trunk) 82 | else: 83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 84 | try: 85 | # FIXME import here until API stable and in an official release 86 | from timm.models.helpers import group_parameters, group_modules 87 | except ImportError: 88 | raise RuntimeError( 89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 90 | matcher = self.trunk.group_matcher() 91 | gparams = group_parameters(self.trunk, matcher) 92 | max_layer_id = max(gparams.keys()) 93 | max_layer_id = max_layer_id - unlocked_groups 94 | for group_idx in range(max_layer_id + 1): 95 | group = gparams[group_idx] 96 | for param in group: 97 | self.trunk.get_parameter(param).requires_grad = False 98 | if freeze_bn_stats: 99 | gmodules = group_modules(self.trunk, matcher, reverse=True) 100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 101 | freeze_batch_norm_2d(self.trunk, gmodules) 102 | 103 | def forward(self, x): 104 | x = self.trunk(x) 105 | x = self.head(x) 106 | return x 107 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 19 | 20 | 21 | @lru_cache() 22 | def bytes_to_unicode(): 23 | """ 24 | Returns list of utf-8 byte and a corresponding list of unicode strings. 25 | The reversible bpe codes work on unicode strings. 26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 28 | This is a signficant percentage of your normal, say, 32K bpe vocab. 29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 30 | And avoids mapping to whitespace/control characters the bpe code barfs on. 31 | """ 32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2**8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2**8+n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 
46 | Word is represented as tuple of symbols (symbols being variable-length strings). 47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r'\s+', ' ', text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 73 | merges = merges[1:49152-256-2+1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v+'</w>' for v in vocab] 77 | for merge in merges: 78 | vocab.append(''.join(merge)) 79 | if not special_tokens: 80 | special_tokens = ['<start_of_text>', '<end_of_text>'] 81 | else: 82 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 83 | vocab.extend(special_tokens) 84 | self.encoder = dict(zip(vocab, range(len(vocab)))) 85 | self.decoder = {v: k for k, v in self.encoder.items()} 86 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 87 | self.cache = {t:t for t in special_tokens} 88 | special = "|".join(special_tokens) 89 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 90 | 91 | self.vocab_size = len(self.encoder) 92 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 93 | 94 | def bpe(self, token): 95 | if token in self.cache: 96 | return self.cache[token] 97 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 98 | pairs = get_pairs(word) 99 | 100 | if not pairs: 101 | return token+'</w>' 102 | 103 | while True: 104 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 105 | if bigram not in self.bpe_ranks: 106 | break 107 | first, second = bigram 108 | new_word = [] 109 | i = 0 110 | while i < len(word): 111 | try: 112 | j = word.index(first, i) 113 | new_word.extend(word[i:j]) 114 | i = j 115 | except: 116 | new_word.extend(word[i:]) 117 | break 118 | 119 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 120 | new_word.append(first+second) 121 | i += 2 122 | else: 123 | new_word.append(word[i]) 124 | i += 1 125 | new_word = tuple(new_word) 126 | word = new_word 127 | if len(word) == 1: 128 | break 129 | else: 130 | pairs = get_pairs(word) 131 | word = ' '.join(word) 132 | self.cache[token] = word 133 | return word 134 | 135 | def encode(self, text): 136 | bpe_tokens = [] 137 | text = whitespace_clean(basic_clean(text)).lower() 138 | for token in re.findall(self.pat, text): 139 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 140 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = ''.join([self.decoder[token] for token in tokens]) 145 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 146 | return text 147 | 148 | 149 | _tokenizer = SimpleTokenizer() 150 | 151 | 152 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 153 | """ 154 | Returns the tokenized representation of given input string(s) 155 | 156 | Parameters 
157 | ---------- 158 | texts : Union[str, List[str]] 159 | An input string or a list of input strings to tokenize 160 | context_length : int 161 | The context length to use; all CLIP models use 77 as the context length 162 | 163 | Returns 164 | ------- 165 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 166 | """ 167 | if isinstance(texts, str): 168 | texts = [texts] 169 | 170 | sot_token = _tokenizer.encoder["<start_of_text>"] 171 | eot_token = _tokenizer.encoder["<end_of_text>"] 172 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 173 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 174 | 175 | for i, tokens in enumerate(all_tokens): 176 | if len(tokens) > context_length: 177 | tokens = tokens[:context_length] # Truncate 178 | result[i, :len(tokens)] = torch.tensor(tokens) 179 | 180 | return result 181 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 2 | CenterCrop 3 | 4 | 5 | def _convert_to_rgb(image): 6 | return image.convert('RGB') 7 | 8 | 9 | def image_transform( 10 | image_size: int, 11 | is_train: bool, 12 | mean=(0.48145466, 0.4578275, 0.40821073), 13 | std=(0.26862954, 0.26130258, 0.27577711) 14 | ): 15 | normalize = Normalize(mean=mean, std=std) 16 | if is_train: 17 | return Compose([ 18 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 19 | _convert_to_rgb, 20 | ToTensor(), 21 | normalize, 22 | ]) 23 | else: 24 | return Compose([ 25 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 26 | CenterCrop(image_size), 27 | _convert_to_rgb, 28 | ToTensor(), 29 | normalize, 30 | ]) 31 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | from torchvision.ops.misc import FrozenBatchNorm2d 5 | import logging 6 | import h5py 7 | from tqdm import tqdm 8 | import random 9 | import json 10 | import os 11 | import pathlib 12 | 13 | # TODO: (yusong) this is not a good place to store this information and it does not scale. Need to be fixed later. 
14 | dataset_split = { 15 | "audiocaps": ["train", "valid", "test"], 16 | "audioset": ["balanced_train", "unbalanced_train", "eval"], 17 | "BBCSoundEffects": ["train", "test"], 18 | "Clotho": ["train", "test", "valid"], 19 | "free_to_use_sounds": ["train", "test"], 20 | "paramount_motion": ["train", "test"], 21 | "sonniss_game_effects": ["train", "test"], 22 | "wesoundeffects": ["train", "test"], 23 | "MACS": ["train", "test"], 24 | "freesound": ["train", "test"], 25 | "FSD50K": ["train", "test", "valid"], 26 | "fsd50k_class_label": ["train", "test", "valid"], 27 | "esc50": ["train", "test"], 28 | "ESC50_1": ["train", "test"], 29 | "ESC50_2": ["train", "test"], 30 | "ESC50_3": ["train", "test"], 31 | "ESC50_4": ["train", "test"], 32 | "ESC50_5": ["train", "test"], 33 | "audiostock": ["train", "test"], 34 | "freesound_no_overlap_noesc50": ["train", "test"], 35 | "epidemic_sound_effects": ["train", "test"], 36 | "VGGSound": ["train", "test"], 37 | "urbansound8k_class_label": ["train", "test"], 38 | "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], 39 | "audioset_t5_debiased": ["balanced_train", "unbalanced_train", "eval"], 40 | "epidemic_sound_effects_t5": ["train", "test"], 41 | "epidemic_sound_effects_t5_debiased": ["train", "test"], 42 | "WavText5K": ["train", "test"], 43 | "esc50_no_overlap": ["train", "test"], 44 | "usd8k_no_overlap": ["train", "test"], 45 | "fsd50k_200_class_label": ["train", "test", "valid"], 46 | "fma_full": ["train", "test"], 47 | "Genius": ["train", "test"], 48 | "Jamendo": ["train", "test"], 49 | "juno": ["train", "test"], 50 | "CMU_Arctic": ["train", "test"], 51 | "ravdess": ["train", "test"], 52 | "Europarl-st": ["train", "test"], 53 | "common_voice": ["train", "test"], 54 | "Jamendo_16bit": ["train", "test"], 55 | "genius_16bit_128": ["train", "test"], 56 | "juno_16bit": ["train", "test"], 57 | "fma_full_16bit_128": ["train", "test"], 58 | } 59 | 60 | 61 | def freeze_batch_norm_2d(module, module_match={}, name=""): 62 | """ 63 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 64 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 65 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 66 | 67 | Args: 68 | module (torch.nn.Module): Any PyTorch module. 
69 | module_match (dict): Dictionary of full module names to freeze (all if empty) 70 | name (str): Full module name (prefix) 71 | 72 | Returns: 73 | torch.nn.Module: Resulting module 74 | 75 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 76 | """ 77 | res = module 78 | is_match = True 79 | if module_match: 80 | is_match = name in module_match 81 | if is_match and isinstance( 82 | module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) 83 | ): 84 | res = FrozenBatchNorm2d(module.num_features) 85 | res.num_features = module.num_features 86 | res.affine = module.affine 87 | if module.affine: 88 | res.weight.data = module.weight.data.clone().detach() 89 | res.bias.data = module.bias.data.clone().detach() 90 | res.running_mean.data = module.running_mean.data 91 | res.running_var.data = module.running_var.data 92 | res.eps = module.eps 93 | else: 94 | for child_name, child in module.named_children(): 95 | full_child_name = ".".join([name, child_name]) if name else child_name 96 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 97 | if new_child is not child: 98 | res.add_module(child_name, new_child) 99 | return res 100 | 101 | 102 | def exist(dataset_name, dataset_type): 103 | """ 104 | Check if dataset exists 105 | """ 106 | if dataset_type in dataset_split[dataset_name]: 107 | return True 108 | else: 109 | return False 110 | 111 | 112 | def get_tar_path_from_dataset_name( 113 | dataset_names, 114 | dataset_types, 115 | islocal, 116 | dataset_path, 117 | proportion=1, 118 | full_dataset=None 119 | ): 120 | """ 121 | Get tar path from dataset name and type 122 | """ 123 | output = [] 124 | for n in dataset_names: 125 | if full_dataset is not None and n in full_dataset: 126 | current_dataset_types = dataset_split[n] 127 | else: 128 | current_dataset_types = dataset_types 129 | for s in current_dataset_types: 130 | tmp = [] 131 | if islocal: 132 | sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" 133 | if not os.path.exists(sizefilepath_): 134 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" 135 | else: 136 | sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" 137 | if not os.path.exists(sizefilepath_): 138 | continue 139 | sizes = json.load(open(sizefilepath_, "r")) 140 | for k in sizes.keys(): 141 | if islocal: 142 | tmp.append(f"{dataset_path}/{n}/{s}/{k}") 143 | else: 144 | tmp.append( 145 | f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" 146 | ) 147 | if proportion != 1: 148 | tmp = random.sample(tmp, int(proportion * len(tmp))) 149 | output.append(tmp) 150 | return sum(output, []) 151 | 152 | 153 | def get_tar_path_from_txts(txt_path, islocal, proportion=1): 154 | """ 155 | Get tar path from txt path 156 | """ 157 | if isinstance(txt_path, (list, tuple)): 158 | return sum( 159 | [ 160 | get_tar_path_from_txts( 161 | txt_path[i], islocal=islocal, proportion=proportion 162 | ) 163 | for i in range(len(txt_path)) 164 | ], 165 | [], 166 | ) 167 | if isinstance(txt_path, str): 168 | with open(txt_path) as f: 169 | lines = f.readlines() 170 | if islocal: 171 | lines = [ 172 | lines[i] 173 | .split("\n")[0] 174 | .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") 175 | for i in range(len(lines)) 176 | ] 177 | else: 178 | lines = [ 179 | lines[i].split("\n")[0].replace(".tar", ".tar -") 180 | for i in range(len(lines)) 181 | ] 182 | if proportion != 1: 183 | print("Sampling tars with proportion of 
{}".format(proportion)) 184 | lines = random.sample(lines, int(proportion * len(lines))) 185 | return lines 186 | 187 | 188 | def get_mix_lambda(mixup_alpha, batch_size): 189 | mixup_lambdas = [ 190 | np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) 191 | ] 192 | return np.array(mixup_lambdas).astype(np.float32) 193 | 194 | 195 | def do_mixup(x, mixup_lambda): 196 | """ 197 | Args: 198 | x: (batch_size , ...) 199 | mixup_lambda: (batch_size,) 200 | Returns: 201 | out: (batch_size, ...) 202 | """ 203 | out = ( 204 | x.transpose(0, -1) * mixup_lambda 205 | + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) 206 | ).transpose(0, -1) 207 | return out 208 | 209 | 210 | def interpolate(x, ratio): 211 | """Interpolate data in time domain. This is used to compensate the 212 | resolution reduction in downsampling of a CNN. 213 | 214 | Args: 215 | x: (batch_size, time_steps, classes_num) 216 | ratio: int, ratio to interpolate 217 | Returns: 218 | upsampled: (batch_size, time_steps * ratio, classes_num) 219 | """ 220 | (batch_size, time_steps, classes_num) = x.shape 221 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 222 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 223 | return upsampled 224 | 225 | 226 | def pad_framewise_output(framewise_output, frames_num): 227 | """Pad framewise_output to the same length as input frames. The pad value 228 | is the same as the value of the last frame. 229 | Args: 230 | framewise_output: (batch_size, frames_num, classes_num) 231 | frames_num: int, number of frames to pad 232 | Outputs: 233 | output: (batch_size, frames_num, classes_num) 234 | """ 235 | pad = framewise_output[:, -1:, :].repeat( 236 | 1, frames_num - framewise_output.shape[1], 1 237 | ) 238 | """tensor for padding""" 239 | 240 | output = torch.cat((framewise_output, pad), dim=1) 241 | """(batch_size, frames_num, classes_num)""" 242 | 243 | 244 | def process_ipc(index_path, classes_num, filename): 245 | # load data 246 | logging.info("Load Data...............") 247 | ipc = [[] for _ in range(classes_num)] 248 | with h5py.File(index_path, "r") as f: 249 | for i in tqdm(range(len(f["target"]))): 250 | t_class = np.where(f["target"][i])[0] 251 | for t in t_class: 252 | ipc[t].append(i) 253 | print(ipc) 254 | np.save(filename, ipc) 255 | logging.info("Load Data Succeed...............") 256 | 257 | 258 | def save_to_dict(s, o_={}): 259 | sp = s.split(": ") 260 | o_.update({sp[0]: float(sp[1])}) 261 | return o_ 262 | 263 | 264 | def get_data_from_log(txt_path): 265 | """ 266 | Output dictionary from out.txt log file 267 | """ 268 | with open(txt_path) as f: 269 | lines = f.readlines() 270 | val_data = {} 271 | train_data = {} 272 | train_losses = [] 273 | train_losses_epoch = [] 274 | for i in range(len(lines)): 275 | if "| INFO |" in lines[i]: 276 | if "Eval Epoch" in lines[i]: 277 | if "val_loss" in lines[i]: 278 | # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) 279 | line = lines[i].split("Eval Epoch: ")[-1] 280 | num_epoch = int(line.split(" ")[0].split(" ")[0]) 281 | d = { 282 | line.split(" ")[0] 283 | .split(" ")[1] 284 | .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) 285 | } 286 | for i in range(1, len(line.split(" "))): 287 | d = save_to_dict(line.split(" ")[i], d) 288 | val_data[num_epoch] = d 289 | elif "Train Epoch" in lines[i]: 290 | num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) 291 | loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) 292 | train_losses.append(loss) 293 | 
train_losses_epoch.append(num_epoch) 294 | for i in range(len(train_losses)): 295 | train_data[i] = { 296 | "num_epoch": train_losses_epoch[i], 297 | "train_loss": train_losses[i], 298 | } 299 | return train_data, val_data 300 | 301 | 302 | def save_p(obj, filename): 303 | import pickle 304 | 305 | try: 306 | from deepdiff import DeepDiff 307 | except: 308 | os.system("pip install deepdiff") 309 | from deepdiff import DeepDiff 310 | with open(filename, "wb") as file: 311 | pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol 312 | with open(filename, "rb") as file: 313 | z = pickle.load(file) 314 | assert ( 315 | DeepDiff(obj, z, ignore_string_case=True) == {} 316 | ), "there is something wrong with the saving process" 317 | return 318 | 319 | 320 | def load_p(filename): 321 | import pickle 322 | 323 | with open(filename, "rb") as file: 324 | z = pickle.load(file) 325 | return z 326 | 327 | 328 | def save_json(data, name="data.json"): 329 | import json 330 | with open(name, 'w') as fp: 331 | json.dump(data, fp) 332 | return 333 | 334 | 335 | def load_json(name): 336 | import json 337 | with open(name, 'r') as fp: 338 | data = json.load(fp) 339 | return data 340 | 341 | 342 | from multiprocessing import Process, Manager 343 | from multiprocessing import Process, Value, Array 344 | from ctypes import c_wchar 345 | 346 | 347 | def load_class_label(path): 348 | # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing 349 | # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array 350 | out = None 351 | if path is not None: 352 | if pathlib.Path(path).suffix in [".pkl", ".pickle"]: 353 | out = load_p(path) 354 | elif pathlib.Path(path).suffix in [".json", ".txt"]: 355 | out = load_json(path) 356 | elif pathlib.Path(path).suffix in [".npy", ".npz"]: 357 | out = np.load(path) 358 | elif pathlib.Path(path).suffix in [".csv"]: 359 | import pandas as pd 360 | out = pd.read_csv(path) 361 | return out 362 | # if out is None: 363 | # return None 364 | # else: 365 | # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) 366 | # val = Array('i', out.values(), lock=False) 367 | # return (key, val) 368 | 369 | 370 | from torch import optim 371 | 372 | 373 | def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): 374 | if optimizer_name.lower() == "adamw": 375 | optimizer = optim.AdamW( 376 | params, lr=lr, betas=betas, eps=eps 377 | ) 378 | elif optimizer_name.lower() == "sgd": 379 | optimizer = optim.SGD( 380 | params, lr=lr, momentum=momentum 381 | ) 382 | elif optimizer_name.lower() == "adam": 383 | optimizer = optim.Adam( 384 | params, lr=lr, betas=betas, eps=eps 385 | ) 386 | else: 387 | raise ValueError("optimizer name is not correct") 388 | return optimizer 389 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/clap_module/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.1' 2 | -------------------------------------------------------------------------------- /open_musiclm/laion_clap/hook.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contrastive Language-Audio Pretraining Model from LAION 3 | -------------------------------------------------------- 4 | Paper: https://arxiv.org/abs/2211.06687 5 | Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui 6 
| Support: LAION 7 | """ 8 | import os 9 | import torch 10 | import torch.nn.functional as F 11 | import torchaudio 12 | import torchvision.transforms 13 | from contextlib import suppress 14 | import numpy as np 15 | from clap_module import create_model 16 | 17 | from transformers import RobertaTokenizer 18 | import wget 19 | from clap_module.factory import load_state_dict 20 | 21 | 22 | def int16_to_float32_torch(x): 23 | return (x / 32767.0).type(torch.float32) 24 | 25 | 26 | def float32_to_int16_torch(x): 27 | x = torch.clamp(x, min=-1., max=1.) 28 | return (x * 32767.).type(torch.int16) 29 | 30 | class CLAP_Module(torch.nn.Module): 31 | def __init__(self, enable_fusion=False, device=None, amodel= 'HTSAT-tiny', tmodel='roberta') -> None: 32 | """Initialize CLAP Model 33 | 34 | Parameters 35 | ---------- 36 | enable_fusion: bool 37 | if true, it will create the fusion clap model, otherwise non-fusion clap model (default: false) 38 | device: str 39 | if None, it will automatically detect the device (gpu or cpu) 40 | amodel: str 41 | audio encoder architecture, default: HTSAT-tiny 42 | tmodel: str 43 | text encoder architecture, default: roberta 44 | """ 45 | super(CLAP_Module, self).__init__() 46 | if device is None: 47 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 48 | precision = 'fp32' 49 | 50 | if enable_fusion: 51 | fusion_type = 'aff_2d' 52 | model, model_cfg = create_model( 53 | amodel, 54 | tmodel, 55 | precision=precision, 56 | device=device, 57 | enable_fusion=enable_fusion, 58 | fusion_type=fusion_type 59 | ) 60 | else: 61 | model, model_cfg = create_model( 62 | amodel, 63 | tmodel, 64 | precision=precision, 65 | device=device, 66 | enable_fusion=enable_fusion 67 | ) 68 | self.enable_fusion = enable_fusion 69 | self.model = model 70 | self.model_cfg = model_cfg 71 | self.tokenize = RobertaTokenizer.from_pretrained('roberta-base') 72 | 73 | audio_cfg = model_cfg['audio_cfg'] 74 | self.mel_transform = torchaudio.transforms.MelSpectrogram( 75 | sample_rate=audio_cfg['sample_rate'], 76 | n_fft=audio_cfg['window_size'], 77 | win_length=audio_cfg['window_size'], 78 | hop_length=audio_cfg['hop_size'], 79 | center=True, 80 | pad_mode="reflect", 81 | power=2.0, 82 | norm=None, 83 | onesided=True, 84 | n_mels=audio_cfg['mel_bins'], 85 | f_min=audio_cfg['fmin'], 86 | f_max=audio_cfg['fmax'] 87 | ) 88 | self.log_mel_transform = torchaudio.transforms.AmplitudeToDB(top_db=None) 89 | 90 | def tokenizer(self, text): 91 | result = self.tokenize( 92 | text, 93 | padding="max_length", 94 | truncation=True, 95 | max_length=77, 96 | return_tensors="pt", 97 | ) 98 | return result 99 | 100 | def load_ckpt(self, ckpt = None, model_id = -1): 101 | """Load the pretrained checkpoint of CLAP model 102 | 103 | Parameters 104 | ---------- 105 | ckpt: str 106 | if ckpt is specified, the model will load this ckpt, otherwise the model will download the ckpt from zenodo. \n 107 | For fusion model, it will download the 630k+audioset fusion model (id=3). For non-fusion model, it will download the 630k+audioset model (id=1). 108 | model_id: 109 | if model_id is specified, you can download our best ckpt, as: 110 | id = 0 --> 630k non-fusion ckpt \n 111 | id = 1 --> 630k+audioset non-fusion ckpt \n 112 | id = 2 --> 630k fusion ckpt \n 113 | id = 3 --> 630k+audioset fusion ckpt \n 114 | Note that if your model is specied as non-fusion model but you download a fusion model ckpt, you will face an error. 
115 | """ 116 | download_link = 'https://huggingface.co/lukewys/laion_clap/resolve/main/' 117 | download_names = [ 118 | '630k-best.pt', 119 | '630k-audioset-best.pt', 120 | '630k-fusion-best.pt', 121 | '630k-audioset-fusion-best.pt' 122 | ] 123 | if ckpt is not None: 124 | print(f'Load the specified checkpoint {ckpt} from users.') 125 | else: 126 | print(f'Load our best checkpoint in the paper.') 127 | if model_id == -1: 128 | model_id = 3 if self.enable_fusion else 1 129 | package_dir = os.path.dirname(os.path.realpath(__file__)) 130 | weight_file_name = download_names[model_id] 131 | ckpt = os.path.join(package_dir, weight_file_name) 132 | if os.path.exists(ckpt): 133 | print(f'The checkpoint is already downloaded') 134 | else: 135 | print('Downloading laion_clap weight files...') 136 | ckpt = wget.download(download_link + weight_file_name, os.path.dirname(ckpt)) 137 | print('Download completed!') 138 | print('Load Checkpoint...') 139 | ckpt = load_state_dict(ckpt, skip_params=True) 140 | self.model.load_state_dict(ckpt, strict=False) 141 | param_names = [n for n, p in self.model.named_parameters()] 142 | for n in param_names: 143 | print(n, "\t", "Loaded" if n in ckpt else "Unloaded") 144 | 145 | def get_mel(self, audio_data): 146 | mel = self.mel_transform(audio_data) 147 | mel = self.log_mel_transform(mel) 148 | return mel.T # (T, n_mels) 149 | 150 | def get_audio_features(self, sample, audio_data, max_len, data_truncating, data_filling, audio_cfg, require_grad=False): 151 | """ 152 | Calculate and add audio features to sample. 153 | Sample: a dict containing all the data of current sample. 154 | audio_data: a tensor of shape (T) containing audio data. 155 | max_len: the maximum length of audio data. 156 | data_truncating: the method of truncating data. 157 | data_filling: the method of filling data. 158 | audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg']. 159 | require_grad: whether to require gradient for audio data. 160 | This is useful when we want to apply gradient-based classifier-guidance. 161 | """ 162 | grad_fn = suppress if require_grad else torch.no_grad 163 | with grad_fn(): 164 | if len(audio_data) > max_len: 165 | if data_truncating == "rand_trunc": 166 | longer = torch.tensor([True]) 167 | elif data_truncating == "fusion": 168 | # fusion 169 | mel = self.get_mel(audio_data) 170 | # split to three parts 171 | chunk_frames = max_len // audio_cfg['hop_size'] + 1 # the +1 related to how the spectrogram is computed 172 | total_frames = mel.shape[0] 173 | if chunk_frames == total_frames: 174 | # there is a corner case where the audio length is 175 | # larger than max_len but smaller than max_len+hop_size. 176 | # In this case, we just use the whole audio. 
177 | mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) 178 | sample["mel_fusion"] = mel_fusion 179 | longer = torch.tensor([False]) 180 | else: 181 | ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3) 182 | # print('total_frames-chunk_frames:', total_frames-chunk_frames, 183 | # 'len(audio_data):', len(audio_data), 184 | # 'chunk_frames:', chunk_frames, 185 | # 'total_frames:', total_frames) 186 | if len(ranges[1]) == 0: 187 | # if the audio is too short, we just use the first chunk 188 | ranges[1] = [0] 189 | if len(ranges[2]) == 0: 190 | # if the audio is too short, we just use the first chunk 191 | ranges[2] = [0] 192 | # randomly choose index for each part 193 | idx_front = np.random.choice(ranges[0]) 194 | idx_middle = np.random.choice(ranges[1]) 195 | idx_back = np.random.choice(ranges[2]) 196 | # select mel 197 | mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :] 198 | mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :] 199 | mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :] 200 | 201 | # shrink the mel 202 | mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, audio_cfg['mel_bins']])(mel[None])[0] 203 | # logging.info(f"mel_shrink.shape: {mel_shrink.shape}") 204 | 205 | # stack 206 | mel_fusion = torch.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0) 207 | sample["mel_fusion"] = mel_fusion 208 | longer = torch.tensor([True]) 209 | else: 210 | raise NotImplementedError( 211 | f"data_truncating {data_truncating} not implemented" 212 | ) 213 | # random crop to max_len (for compatibility) 214 | overflow = len(audio_data) - max_len 215 | idx = np.random.randint(0, overflow + 1) 216 | audio_data = audio_data[idx: idx + max_len] 217 | 218 | else: # padding if too short 219 | if len(audio_data) < max_len: # do nothing if equal 220 | if data_filling == "repeatpad": 221 | n_repeat = int(max_len / len(audio_data)) 222 | audio_data = audio_data.repeat(n_repeat) 223 | # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0) 224 | # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0] 225 | audio_data = F.pad( 226 | audio_data, 227 | (0, max_len - len(audio_data)), 228 | mode="constant", 229 | value=0, 230 | ) 231 | elif data_filling == "pad": 232 | audio_data = F.pad( 233 | audio_data, 234 | (0, max_len - len(audio_data)), 235 | mode="constant", 236 | value=0, 237 | ) 238 | elif data_filling == "repeat": 239 | n_repeat = int(max_len / len(audio_data)) 240 | audio_data = audio_data.repeat(n_repeat + 1)[:max_len] 241 | else: 242 | raise NotImplementedError( 243 | f"data_filling {data_filling} not implemented" 244 | ) 245 | if data_truncating == 'fusion': 246 | mel = self.get_mel(audio_data) 247 | mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) 248 | sample["mel_fusion"] = mel_fusion 249 | longer = torch.tensor([False]) 250 | 251 | sample["longer"] = longer 252 | sample["waveform"] = audio_data 253 | 254 | return sample 255 | 256 | def get_audio_embedding_from_data(self, x): 257 | """get audio embeddings from the audio data 258 | 259 | Parameters 260 | ---------- 261 | x: torch.Tensor (N,T): 262 | audio data, must be mono audio tracks. 
263 | Returns 264 | ---------- 265 | audio embed: torch.Tensor (N,D): 266 | audio embeddings that extracted from audio files 267 | """ 268 | self.model.eval() 269 | audio_input = [] 270 | for audio_waveform in x: 271 | # quantize 272 | audio_waveform = int16_to_float32_torch(float32_to_int16_torch(audio_waveform)) 273 | temp_dict = {} 274 | # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode 275 | temp_dict = self.get_audio_features( 276 | temp_dict, audio_waveform, 480000, 277 | data_truncating='fusion' if self.enable_fusion else 'rand_trunc', 278 | data_filling='repeatpad', 279 | audio_cfg=self.model_cfg['audio_cfg'], 280 | require_grad=audio_waveform.requires_grad 281 | ) 282 | audio_input.append(temp_dict) 283 | audio_embed = self.model.get_audio_embedding(audio_input) 284 | return audio_embed 285 | 286 | def get_text_embedding(self, x, tokenizer = None): 287 | """get text embeddings from texts 288 | 289 | Parameters 290 | ---------- 291 | x: List[str] (N,): 292 | text list 293 | tokenizer: func: 294 | the tokenizer function, if not provided (None), will use the default Roberta tokenizer. 295 | 296 | Returns 297 | ---------- 298 | text_embed : torch.Tensor (N,D): 299 | text embeddings that extracted from texts 300 | """ 301 | self.model.eval() 302 | if tokenizer is not None: 303 | text_input = tokenizer(x) 304 | else: 305 | text_input = self.tokenizer(x) 306 | text_embed = self.model.get_text_embedding(text_input) 307 | text_embed = text_embed 308 | return text_embed 309 | 310 | 311 | -------------------------------------------------------------------------------- /open_musiclm/model_types.py: -------------------------------------------------------------------------------- 1 | from beartype.typing import Union 2 | 3 | from .hf_hubert_kmeans import HfHubertWithKmeans 4 | from .encodec_wrapper import EncodecWrapper 5 | 6 | Wav2Vec = HfHubertWithKmeans 7 | NeuralCodec = EncodecWrapper 8 | -------------------------------------------------------------------------------- /open_musiclm/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import AdamW, Adam, lr_scheduler 2 | 3 | def separate_weight_decayable_params(params): 4 | wd_params, no_wd_params = [], [] 5 | for param in params: 6 | param_list = no_wd_params if param.ndim < 2 else wd_params 7 | param_list.append(param) 8 | return wd_params, no_wd_params 9 | 10 | def get_optimizer( 11 | params, 12 | lr = 1e-4, 13 | wd = 1e-2, 14 | betas = (0.9, 0.99), 15 | eps = 1e-8, 16 | filter_by_requires_grad = False, 17 | group_wd_params = True, 18 | **kwargs 19 | ): 20 | if filter_by_requires_grad: 21 | params = list(filter(lambda t: t.requires_grad, params)) 22 | 23 | if wd == 0: 24 | return Adam(params, lr = lr, betas = betas, eps = eps) 25 | 26 | if group_wd_params: 27 | wd_params, no_wd_params = separate_weight_decayable_params(params) 28 | 29 | params = [ 30 | {'params': wd_params}, 31 | {'params': no_wd_params, 'weight_decay': 0}, 32 | ] 33 | 34 | return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps) 35 | 36 | def get_linear_scheduler( 37 | optimizer, 38 | total_iters=10000, 39 | start_factor=1e-7, 40 | ): 41 | return lr_scheduler.LinearLR(optimizer=optimizer, start_factor=start_factor, end_factor=1., total_iters=total_iters) -------------------------------------------------------------------------------- /open_musiclm/preprocess.py: -------------------------------------------------------------------------------- 1 | import io 2 
| import itertools 3 | import math 4 | import time 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from shutil import rmtree 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | import torchaudio 13 | from accelerate import Accelerator, DistributedType 14 | from beartype.door import is_bearable 15 | from beartype.typing import Dict, List, Literal, Optional, Union 16 | from beartype.vale import Is 17 | from einops import rearrange, reduce, repeat 18 | from einops.layers.torch import Rearrange 19 | from torch import einsum, nn 20 | from torch.utils.data import DataLoader, Dataset, random_split 21 | from tqdm import tqdm 22 | from typing_extensions import Annotated 23 | 24 | from .clap_quantized import ClapQuantized 25 | from .data import (SoundDatasetForPreprocessing, 26 | get_sound_preprocessing_dataloader, init_sqlite) 27 | from .hf_hubert_kmeans import HfHubertWithKmeans, learn_kmeans 28 | from .model_types import NeuralCodec, Wav2Vec 29 | from .open_musiclm import (get_or_compute_acoustic_token_ids, 30 | get_or_compute_clap_token_ids, 31 | get_or_compute_semantic_token_ids) 32 | from .optimizer import get_linear_scheduler, get_optimizer 33 | from .utils import (all_rows_have_eos_id, append_eos_id, 34 | batch_unique_consecutive, beartype_jit, ceil_div, 35 | copy_file_to_folder, default, eval_decorator, exists, 36 | generate_mask_with_prob, get_embeds, gumbel_sample, 37 | mask_out_after_eos_id, round_down_nearest_multiple, top_k) 38 | 39 | 40 | def cycle(dl): 41 | while True: 42 | for data in dl: 43 | yield data 44 | 45 | 46 | def yes_or_no(question): 47 | answer = input(f'{question} (y/n) ') 48 | return answer.lower() in ('yes', 'y') 49 | 50 | 51 | # auto data to module keyword argument routing functions 52 | 53 | def has_duplicates(tup): 54 | counts = dict() 55 | for el in tup: 56 | if el not in counts: 57 | counts[el] = 0 58 | counts[el] += 1 59 | return any(filter(lambda count: count > 1, counts.values())) 60 | 61 | 62 | def determine_types(data, config): 63 | output = [] 64 | for el in data: 65 | for name, data_type in config.items(): 66 | if is_bearable(el, data_type): 67 | output.append(name) 68 | break 69 | else: 70 | raise TypeError(f'unable to determine type of {data}') 71 | 72 | return tuple(output) 73 | 74 | 75 | def noop(*args, **kwargs): 76 | pass 77 | 78 | def without_none(arr): 79 | return list(filter(lambda x: x is not None, arr)) 80 | 81 | @beartype_jit 82 | class DataPreprocessor(nn.Module): 83 | """ 84 | Class to preprocess audio files for the single stage transformer trainer. 85 | 86 | Load audio and compute: 87 | 1) clap tokens for the entire audio file, computed for 10 second sliding windows with 1 second interval 88 | 2) semantic tokens for the entire audio file 89 | 3) coarse+fine tokens for the entire audio file 90 | Run this once over the dataset and then use the preprocessed data for training. 
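A minimal invocation sketch for this class follows. It assumes the CLAP, HuBERT-kmeans and Encodec instances have already been built (in this repo that wiring is done by scripts/preprocess_data.py via create_data_preprocessor_from_config), and the folder paths are illustrative.

# clap, wav2vec and encodec_wrapper are assumed to come from the create_*_from_config helpers
preprocessor = DataPreprocessor(
    wav2vec=wav2vec,                       # HfHubertWithKmeans
    neural_codec=encodec_wrapper,          # EncodecWrapper
    audio_conditioner=clap,                # ClapQuantized
    num_coarse_quantizers=3,
    folder='./data/fma_large',             # raw audio to preprocess (illustrative path)
    results_folder='./data/fma_preprocessed',
    max_audio_length_seconds=180,
)
preprocessor.process()  # writes clap/semantic/coarse/fine token arrays into preprocessed.db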
91 | """ 92 | 93 | def __init__( 94 | self, 95 | *, 96 | num_coarse_quantizers=3, 97 | wav2vec: Optional[Wav2Vec] = None, 98 | neural_codec: Optional[NeuralCodec] = None, 99 | audio_conditioner: Optional[ClapQuantized] = None, 100 | max_audio_length_seconds=180, 101 | random_crop=True, 102 | clap_audio_length_seconds=10, 103 | semantic_audio_length_seconds=10, 104 | clap_batch_size=32, 105 | num_crops=1, 106 | ignore_files: Optional[List[str]]=None, 107 | ignore_load_errors=True, 108 | replace_existing=False, 109 | folder=None, 110 | results_folder='./data/fma_preprocessed', 111 | accelerate_kwargs: dict = {}, 112 | config_paths: Optional[List[str]] = None, 113 | **kwargs, 114 | ): 115 | super().__init__() 116 | 117 | self.accelerator = Accelerator(**accelerate_kwargs) 118 | 119 | self.wav2vec = wav2vec 120 | self.audio_conditioner = audio_conditioner 121 | self.neural_codec = neural_codec 122 | self.num_coarse_quantizers = num_coarse_quantizers 123 | self.max_audio_length_seconds = max_audio_length_seconds 124 | self.clap_audio_length_seconds = int(clap_audio_length_seconds) 125 | self.semantic_audio_length_seconds = int(semantic_audio_length_seconds) 126 | # TODO: allow a smaller clap length than semantic length, and average the clap embeddings over the time period as in the paper 127 | assert self.clap_audio_length_seconds == self.semantic_audio_length_seconds, 'clap window must be equal to semantic window for now' 128 | self.clap_batch_size = clap_batch_size 129 | self.num_crops = num_crops 130 | self.replace_existing = replace_existing 131 | 132 | self.register_buffer('steps', torch.Tensor([0])) 133 | 134 | # create dataset 135 | 136 | assert exists(wav2vec) and exists(audio_conditioner) and exists(neural_codec) 137 | 138 | self.ds_fields = ('raw_wave_for_clap', 'raw_wave_for_semantic', 'raw_wave_for_acoustic') 139 | 140 | target_sample_hz = (audio_conditioner.sample_rate, wav2vec.target_sample_hz, neural_codec.sample_rate) 141 | 142 | normalize = (False, True, False) 143 | 144 | seq_len_multiple_of = (None, wav2vec.seq_len_multiple_of, None) 145 | 146 | data_max_length_seconds = (max_audio_length_seconds, max_audio_length_seconds, max_audio_length_seconds) 147 | 148 | assert exists(folder), 'audio folder must be passed in for preprocessing' 149 | 150 | self.ds = SoundDatasetForPreprocessing( 151 | folder, 152 | pad_to_seconds=self.semantic_audio_length_seconds, 153 | max_length_seconds=data_max_length_seconds, 154 | random_crop=random_crop, 155 | normalize=normalize, 156 | target_sample_hz=target_sample_hz, 157 | seq_len_multiple_of=seq_len_multiple_of, 158 | ignore_load_errors=ignore_load_errors, 159 | ignore_files=ignore_files, 160 | ) 161 | 162 | # dataloader 163 | 164 | self.dl = get_sound_preprocessing_dataloader(self.ds, batch_size=1, shuffle=False) 165 | 166 | # prepare 167 | 168 | ( 169 | self.audio_conditioner, 170 | self.wav2vec, 171 | self.neural_codec, 172 | self.dl 173 | ) = self.accelerator.prepare( 174 | self.audio_conditioner, 175 | self.wav2vec, 176 | self.neural_codec, 177 | self.dl 178 | ) 179 | 180 | # dataloader iterators 181 | 182 | self.dl_iter = cycle(self.dl) 183 | 184 | self.results_folder = Path(results_folder) 185 | 186 | if self.is_main: 187 | if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'): 188 | rmtree(str(self.results_folder)) 189 | 190 | self.results_folder.mkdir(parents=True, exist_ok=True) 191 | 192 | if self.is_main and exists(config_paths): 193 | configs_folder = 
self.results_folder / "configs" 194 | configs_folder.mkdir(parents=True, exist_ok=True) 195 | for config_path in config_paths: 196 | copy_file_to_folder(config_path, configs_folder) 197 | 198 | if self.is_main: 199 | self.conn, self.cursor = init_sqlite(str(self.results_folder / 'preprocessed.db')) 200 | self.cursor.execute("CREATE TABLE IF NOT EXISTS tokens(idx integer primary key, path text, clap array, semantic array, coarse array, fine array)") 201 | 202 | self.accelerator.wait_for_everyone() 203 | 204 | if not self.is_main: 205 | self.conn, self.cursor = init_sqlite(str(self.results_folder / 'preprocessed.db')) 206 | 207 | def print(self, msg): 208 | self.accelerator.print(msg) 209 | 210 | @property 211 | def device(self): 212 | return self.accelerator.device 213 | 214 | @property 215 | def is_distributed(self): 216 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 217 | 218 | @property 219 | def is_main(self): 220 | return self.accelerator.is_main_process 221 | 222 | @property 223 | def is_local_main(self): 224 | return self.accelerator.is_local_main_process 225 | 226 | @property 227 | def device(self): 228 | return next(self.parameters()).device 229 | 230 | def generate_tokens_from_batch(self, raw_wave_for_clap, raw_wave_for_semantic, raw_wave_for_acoustic): 231 | # split clap waveform into a clap_audio_length_seconds sliding window with 1 second interval. sample rate is self.audio_conditioner.sample_rate 232 | clap_split = raw_wave_for_clap.unfold( 233 | -1, 234 | self.audio_conditioner.sample_rate * self.clap_audio_length_seconds, 235 | self.audio_conditioner.sample_rate 236 | ).squeeze(0) 237 | 238 | batch_size = self.clap_batch_size 239 | clap_token_ids_all = [] 240 | for i in range(0, clap_split.shape[0], batch_size): 241 | 242 | batch = clap_split[i:i+batch_size, :] 243 | clap_token_ids = get_or_compute_clap_token_ids(None, self.accelerator.unwrap_model(self.audio_conditioner), batch, None) 244 | clap_token_ids_all.append(clap_token_ids) 245 | 246 | clap_token_ids = torch.cat(clap_token_ids_all, dim=0) 247 | 248 | semantic_token_ids = get_or_compute_semantic_token_ids(None, raw_wave_for_semantic, self.accelerator.unwrap_model(self.wav2vec)) 249 | 250 | coarse_token_ids, fine_token_ids = get_or_compute_acoustic_token_ids(None, None, raw_wave_for_acoustic, self.accelerator.unwrap_model(self.neural_codec), self.num_coarse_quantizers) 251 | 252 | return clap_token_ids, semantic_token_ids, (coarse_token_ids, fine_token_ids) 253 | 254 | def process(self, log_fn=noop): 255 | iters = math.ceil(self.num_crops * len(self.ds) / self.accelerator.num_processes) 256 | for idx in tqdm(range(iters), desc='processing data', mininterval=5): 257 | inputs = next(self.dl_iter) 258 | if exists(inputs): 259 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index 260 | if not self.replace_existing: 261 | self.cursor.execute("SELECT * FROM tokens WHERE idx=?", (idx,)) 262 | if len(self.cursor.fetchall()) > 0: 263 | continue 264 | 265 | data_kwargs = dict(zip(self.ds_fields, inputs['data'])) 266 | clap_token_ids, semantic_token_ids, (coarse_token_ids, fine_token_ids) = self.generate_tokens_from_batch(**data_kwargs) 267 | 268 | clap_token_ids = clap_token_ids.detach().cpu().numpy() 269 | semantic_token_ids = semantic_token_ids.detach().cpu().numpy() 270 | coarse_token_ids = coarse_token_ids.detach().cpu().numpy() 271 | fine_token_ids = fine_token_ids.detach().cpu().numpy() 272 | 273 | # convert to int16 to save space 274 
| clap_token_ids = clap_token_ids.astype(np.uint16) 275 | semantic_token_ids = semantic_token_ids.astype(np.uint16) 276 | coarse_token_ids = coarse_token_ids.astype(np.uint16) 277 | fine_token_ids = fine_token_ids.astype(np.uint16) 278 | # add tokens to sqlite db 279 | self.cursor.execute("INSERT INTO tokens VALUES (?, ?, ?, ?, ?, ?)", (idx, inputs['file_path'][0], clap_token_ids, semantic_token_ids, coarse_token_ids, fine_token_ids)) 280 | self.conn.commit() 281 | 282 | self.steps += 1 283 | 284 | self.print('processing complete') 285 | -------------------------------------------------------------------------------- /open_musiclm/transformer.py: -------------------------------------------------------------------------------- 1 | # Transformer implementation from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from einops import rearrange, repeat 8 | from torch import einsum, nn 9 | import math 10 | 11 | from .utils import default, exists, grad_shrink, l2norm 12 | 13 | try: 14 | import xformers.ops as xops 15 | 16 | is_xformers_available = True 17 | except ImportError: 18 | is_xformers_available = False 19 | 20 | # bias-less layernorm, being used in more recent T5s, PaLM, also in @borisdayma 's experiments shared with me 21 | # greater stability 22 | 23 | 24 | class LayerNorm(nn.Module): 25 | def __init__(self, dim): 26 | super().__init__() 27 | self.gamma = nn.Parameter(torch.ones(dim)) 28 | self.register_buffer("beta", torch.zeros(dim)) 29 | 30 | def forward(self, x): 31 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 32 | 33 | # relative positional bias 34 | 35 | 36 | class RelativePositionBias(nn.Module): 37 | """ from https://arxiv.org/abs/2111.09883 """ 38 | 39 | def __init__( 40 | self, 41 | *, 42 | dim, 43 | heads, 44 | layers=3 45 | ): 46 | super().__init__() 47 | self.net = nn.ModuleList([]) 48 | self.net.append(nn.Sequential(nn.Linear(1, dim), nn.SiLU())) 49 | 50 | for _ in range(layers - 1): 51 | self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU())) 52 | 53 | self.net.append(nn.Linear(dim, heads)) 54 | 55 | def forward(self, n, device=torch.device('cpu')): 56 | pos = torch.arange(n, device=device) 57 | rel_pos = (rearrange(pos, 'i -> i 1') - rearrange(pos, 'j -> 1 j')) 58 | rel_pos += (n - 1) 59 | 60 | x = torch.arange(-n + 1, n, device=device).float() 61 | x = rearrange(x, '... -> ... 
1') 62 | 63 | for layer in self.net: 64 | x = layer(x) 65 | 66 | x = x[rel_pos] 67 | return rearrange(x, 'i j h -> h i j') 68 | 69 | class T5RelativePositionBias(nn.Module): 70 | def __init__( 71 | self, 72 | *, 73 | heads, 74 | num_buckets=32, 75 | max_distance=128, 76 | causal=True 77 | ): 78 | super().__init__() 79 | self.num_buckets = num_buckets 80 | self.max_distance = max_distance 81 | self.causal = causal 82 | 83 | self.relative_attention_bias = nn.Embedding(num_buckets, heads) 84 | 85 | @staticmethod 86 | def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128): 87 | ret = 0 88 | n = -relative_position 89 | if not causal: 90 | num_buckets //= 2 91 | ret += (n < 0).long() * num_buckets 92 | n = torch.abs(n) 93 | else: 94 | n = torch.max(n, torch.zeros_like(n)) 95 | 96 | max_exact = num_buckets // 2 97 | is_small = n < max_exact 98 | 99 | val_if_large = max_exact + ( 100 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 101 | ).long() 102 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 103 | 104 | ret += torch.where(is_small, n, val_if_large) 105 | return ret 106 | 107 | def forward( 108 | self, 109 | n, 110 | device=torch.device('cpu') 111 | ): 112 | pos = torch.arange(n, device=device) 113 | rel_pos = (rearrange(pos, 'i -> i 1') - rearrange(pos, 'j -> 1 j')) 114 | rel_pos = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance) 115 | 116 | bias = self.relative_attention_bias(rel_pos) 117 | return rearrange(bias, 'i j h -> h i j') 118 | 119 | # feedforward 120 | 121 | 122 | class CausalDSConv(nn.Module): 123 | def __init__(self, dim): 124 | super().__init__() 125 | self.ds_conv = nn.Conv1d(dim, dim, 3, bias=False, groups=dim) 126 | 127 | def forward(self, x): 128 | x = rearrange(x, 'b n c -> b c n') 129 | x = F.pad(x, (2, 0)) 130 | x = self.ds_conv(x) 131 | return rearrange(x, 'b c n -> b n c') 132 | 133 | 134 | class GEGLU(nn.Module): 135 | def forward(self, x): 136 | x, gate = x.chunk(2, dim=-1) 137 | return F.gelu(gate) * x 138 | 139 | 140 | def ConvFeedForward(dim, mult=4, dropout=0.1): 141 | inner_dim = int(dim * 2 * mult / 3) 142 | return nn.Sequential( 143 | LayerNorm(dim), 144 | nn.Linear(dim, inner_dim * 2, bias=False), 145 | CausalDSConv(inner_dim * 2), 146 | GEGLU(), 147 | LayerNorm(inner_dim), 148 | nn.Dropout(dropout), 149 | nn.Linear(inner_dim, dim, bias=False) 150 | ) 151 | 152 | def FeedForward(dim, mult=4, dropout=0.1): 153 | inner_dim = int(dim * mult) 154 | return nn.Sequential( 155 | LayerNorm(dim), 156 | nn.Linear(dim, inner_dim * 2, bias=False), 157 | GEGLU(), 158 | LayerNorm(inner_dim), 159 | nn.Dropout(dropout), 160 | nn.Linear(inner_dim, dim, bias=False) 161 | ) 162 | 163 | # attention 164 | 165 | 166 | class Attention(nn.Module): 167 | def __init__( 168 | self, 169 | dim, 170 | causal=False, 171 | non_causal_prefix=0, 172 | dim_head=64, 173 | dim_context=None, 174 | heads=8, 175 | norm_context=False, 176 | num_null_kv=0, 177 | dropout=0.1, 178 | scale=8, 179 | use_memory_efficient_attention=False 180 | ): 181 | super().__init__() 182 | self.heads = heads 183 | self.scale = scale 184 | self.causal = causal 185 | self.non_causal_prefix = non_causal_prefix 186 | self.dropout = dropout 187 | self.use_memory_efficient_attention = use_memory_efficient_attention 188 | if self.use_memory_efficient_attention and not is_xformers_available: 189 | raise 
ImportError("Please install xformers to use memory efficient attention") 190 | 191 | inner_dim = dim_head * heads 192 | 193 | dim_context = default(dim_context, dim) 194 | 195 | self.norm = LayerNorm(dim) 196 | self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity() 197 | 198 | self.attn_dropout = nn.Dropout(dropout) if not self.use_memory_efficient_attention else None 199 | 200 | self.num_null_kv = num_null_kv 201 | self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head)) if num_null_kv > 0 else None 202 | 203 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 204 | self.to_kv = nn.Linear(dim_context, dim_head * 2, bias=False) 205 | 206 | self.q_scale = nn.Parameter(torch.ones(dim_head)) 207 | self.k_scale = nn.Parameter(torch.ones(dim_head)) 208 | 209 | self.to_out = nn.Sequential( 210 | nn.Linear(inner_dim, dim, bias=False), 211 | nn.Dropout(dropout) 212 | ) 213 | 214 | def forward( 215 | self, 216 | x, 217 | context=None, 218 | mask=None, 219 | attn_bias=None, 220 | prefix_context=None, 221 | prefix_context_mask=None 222 | ): 223 | b, n, _, device = *x.shape, x.device 224 | 225 | if exists(context): 226 | context = self.context_norm(context) 227 | 228 | kv_input = default(context, x) 229 | 230 | # take care of prefix-based self attention conditioning 231 | # make sure to either concat the to the self attention mask or lengthen it accordingly 232 | 233 | if exists(prefix_context): 234 | kv_input = torch.cat((prefix_context, kv_input), dim=-2) 235 | prefix_seq_len = prefix_context.shape[-2] 236 | 237 | if not exists(mask): 238 | mask = torch.ones((b, n), device=device, dtype=torch.bool) 239 | 240 | if exists(prefix_context_mask): 241 | mask = torch.cat((prefix_context_mask, mask), dim=-1) 242 | else: 243 | mask = F.pad(mask, (prefix_seq_len, 0), value=True) 244 | 245 | if exists(attn_bias): 246 | attn_bias = F.pad(attn_bias, (prefix_seq_len, 0), value=0.) 247 | 248 | # prenorm 249 | 250 | x = self.norm(x) 251 | 252 | # project for queries, keys, values 253 | 254 | q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1) 255 | 256 | # null key / values 257 | 258 | if self.num_null_kv > 0: 259 | null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b=b).unbind(dim=0) 260 | k = torch.cat((null_k, k), dim=-2) 261 | v = torch.cat((null_v, v), dim=-2) 262 | 263 | # split for multi-headed attention 264 | 265 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads) 266 | 267 | # new technique, rmsnormed queries and keys, first used by 22B parameter model successfully https://arxiv.org/abs/2302.05442 268 | 269 | q, k = map(l2norm, (q, k)) 270 | q = q * self.q_scale 271 | k = k * self.k_scale 272 | 273 | # attention 274 | 275 | if self.use_memory_efficient_attention: 276 | if exists(attn_bias): 277 | attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value=0.) 
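# the mask below is padded the same way so both line up with the learned null key/value slots
# prepended to k and v above; since xformers' memory_efficient_attention only accepts a single
# additive attn_bias, the padding mask and the causal mask are folded into attn_bias as large
# negative values instead of being applied separately.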
278 | 279 | if exists(mask): 280 | mask = F.pad(mask, (self.num_null_kv, 0), value=True) 281 | mask = rearrange(mask, 'b j -> b 1 1 j') 282 | attn_bias = attn_bias.masked_fill(~mask, -torch.finfo(attn_bias.dtype).max) 283 | 284 | if self.causal: 285 | i, j = attn_bias.shape[-2:] 286 | causal_mask = torch.ones((i, j), dtype=torch.bool, device=x.device).triu(j - i + 1) 287 | 288 | if self.non_causal_prefix > 0: 289 | causal_mask[:self.non_causal_prefix, :(self.non_causal_prefix + j - i)] = False 290 | 291 | attn_bias = attn_bias.masked_fill(causal_mask, -torch.finfo(attn_bias.dtype).max) 292 | 293 | q = rearrange(q, 'b h n d -> b n h d') 294 | k = repeat(k, 'b n d -> b n h d', h=self.heads) 295 | v = repeat(v, 'b n d -> b n h d', h=self.heads) 296 | 297 | # compute attention 298 | out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=self.dropout) 299 | 300 | # merge heads 301 | out = rearrange(out, 'b n h d -> b n (h d)') 302 | 303 | else: 304 | sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale 305 | 306 | if exists(attn_bias): 307 | attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value=0.) 308 | sim = sim + attn_bias 309 | 310 | if exists(mask): 311 | mask = F.pad(mask, (self.num_null_kv, 0), value=True) 312 | mask = rearrange(mask, 'b j -> b 1 1 j') 313 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 314 | 315 | if self.causal: 316 | i, j = sim.shape[-2:] 317 | causal_mask = torch.ones((i, j), dtype=torch.bool, device=x.device).triu(j - i + 1) 318 | 319 | if self.non_causal_prefix > 0: 320 | causal_mask[:self.non_causal_prefix, :(self.non_causal_prefix + j - i)] = False 321 | 322 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 323 | 324 | attn = sim.softmax(dim=-1) 325 | attn = self.attn_dropout(attn) 326 | 327 | # aggregate 328 | out = einsum('b h i j, b j d -> b h i d', attn, v) 329 | 330 | # merge heads 331 | out = rearrange(out, 'b h n d -> b n (h d)') 332 | 333 | return self.to_out(out) 334 | 335 | # transformer 336 | 337 | 338 | class Transformer(nn.Module): 339 | def __init__( 340 | self, 341 | *, 342 | dim, 343 | depth, 344 | heads, 345 | dim_context=None, 346 | cross_attend=False, 347 | attn_dropout=0., 348 | ff_dropout=0., 349 | use_conv_ff=True, 350 | grad_shrink_alpha=0.1, 351 | cond_as_self_attn_prefix=False, 352 | non_causal_prefix_size=0, 353 | relative_position_bias_type='continuous', 354 | **kwargs 355 | ): 356 | super().__init__() 357 | assert not (cross_attend and cond_as_self_attn_prefix) 358 | self.dim_context = default(dim_context, dim) 359 | 360 | self.cond_as_self_attn_prefix = cond_as_self_attn_prefix 361 | 362 | self.grad_shrink = partial(grad_shrink, alpha=grad_shrink_alpha) 363 | 364 | self.layers = nn.ModuleList([]) 365 | 366 | if relative_position_bias_type == 'continuous': 367 | self.rel_pos_bias = RelativePositionBias(dim=dim // 2, heads=heads) 368 | elif relative_position_bias_type == 't5': 369 | self.rel_pos_bias = T5RelativePositionBias(heads=heads, num_buckets=32, max_distance=128) 370 | elif relative_position_bias_type == 'none': 371 | self.rel_pos_bias = None 372 | else: 373 | raise ValueError(f'invalid relative position bias type: {relative_position_bias_type}') 374 | 375 | for _ in range(depth): 376 | self.layers.append(nn.ModuleList([ 377 | Attention(dim=dim, heads=heads, dropout=attn_dropout, causal=True, non_causal_prefix=non_causal_prefix_size, **kwargs), 378 | Attention(dim=dim, heads=heads, dropout=attn_dropout, dim_context=dim_context, 379 | num_null_kv=1, norm_context=True, 
**kwargs) if cross_attend else None, 380 | ConvFeedForward(dim=dim, dropout=ff_dropout) if use_conv_ff else FeedForward(dim=dim, dropout=ff_dropout), 381 | ])) 382 | 383 | self.norm = LayerNorm(dim) 384 | 385 | def forward( 386 | self, 387 | x, 388 | self_attn_mask=None, 389 | context=None, 390 | context_mask=None, 391 | attn_bias=None, 392 | ): 393 | assert not (self.cond_as_self_attn_prefix and not exists(context)) 394 | assert not (exists( 395 | context) and context.shape[-1] != self.dim_context), f'you had specified a conditioning dimension of {self.dim_context}, yet what was received by the transformer has dimension of {context.shape[-1]}' 396 | 397 | n, device = x.shape[1], x.device 398 | 399 | # from cogview paper, adopted by GLM 130B LLM, decreases likelihood of attention net instability 400 | x = self.grad_shrink(x) 401 | 402 | if exists(attn_bias): 403 | rel_pos_bias = attn_bias 404 | else: 405 | rel_pos_bias = self.rel_pos_bias(n, device = device) if exists(self.rel_pos_bias) else None 406 | 407 | self_attn_kwargs = dict() 408 | if self.cond_as_self_attn_prefix: 409 | self_attn_kwargs = dict( 410 | prefix_context=context, 411 | prefix_context_mask=context_mask 412 | ) 413 | 414 | for attn, cross_attn, ff in self.layers: 415 | x = attn(x, attn_bias=rel_pos_bias, mask=self_attn_mask, **self_attn_kwargs) + x 416 | 417 | if exists(cross_attn): 418 | assert exists(context) 419 | 420 | x = cross_attn(x, context=context, mask=context_mask) + x 421 | 422 | x = ff(x) + x 423 | 424 | return self.norm(x) 425 | -------------------------------------------------------------------------------- /open_musiclm/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.rnn import pad_sequence 5 | from beartype import beartype 6 | from pathlib import Path 7 | import shutil 8 | import os 9 | from torchaudio.functional import resample 10 | 11 | from einops import rearrange, repeat, reduce 12 | 13 | def beartype_jit(func): 14 | """decorator to enable beartype only if USE_BEARTYPE is set to 1""" 15 | return beartype(func) if os.environ.get('USE_BEARTYPE', '0') == '1' else func 16 | 17 | # helper functions 18 | 19 | def exists(val): 20 | return val is not None 21 | 22 | def default(val, d): 23 | return val if exists(val) else d 24 | 25 | def ceil_div(numer, denom): 26 | return (numer + denom - 1) // denom 27 | 28 | def remainder_needed_until_multiple(n, mult): 29 | return (ceil_div(n, mult) * mult) - n 30 | 31 | def round_down_nearest_multiple(val, mult): 32 | return (val // mult) * mult 33 | 34 | def curtail_to_multiple(t, mult): 35 | data_len = t.shape[-1] 36 | return t[..., :round_down_nearest_multiple(data_len, mult)] 37 | 38 | def eval_decorator(fn): 39 | def inner(model, *args, **kwargs): 40 | was_training = model.training 41 | model.eval() 42 | out = fn(model, *args, **kwargs) 43 | model.train(was_training) 44 | return out 45 | return inner 46 | 47 | # tensor helpers 48 | 49 | def generate_mask_with_prob(shape, mask_prob, device): 50 | seq = shape[-1] 51 | rand = torch.randn(shape, device = device) 52 | rand[:, 0] = -torch.finfo(rand.dtype).max 53 | num_mask = min(int(seq * mask_prob), seq - 1) 54 | indices = rand.topk(num_mask, dim = -1).indices 55 | mask = ~torch.zeros(shape, device = device).scatter(1, indices, 1.).bool() 56 | return mask 57 | 58 | # attention related utils 59 | 60 | def grad_shrink(t, alpha = 0.1): 61 | return t * alpha + t.detach() * (1 - 
alpha) 62 | 63 | # sampling helpers 64 | 65 | def log(t, eps = 1e-20): 66 | return torch.log(t + eps) 67 | 68 | def l2norm(t): 69 | return F.normalize(t, dim = -1) 70 | 71 | def gumbel_noise(t): 72 | noise = torch.zeros_like(t).uniform_(0, 1) 73 | return -log(-log(noise)) 74 | 75 | def gumbel_sample(t, temperature = 1., dim = -1): 76 | return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim) 77 | 78 | def top_k(logits, thres = 0.5): 79 | num_logits = logits.shape[-1] 80 | k = max(int((1 - thres) * num_logits), 1) 81 | val, ind = torch.topk(logits, k) 82 | probs = torch.full_like(logits, float('-inf')) 83 | probs.scatter_(1, ind, val) 84 | return probs 85 | 86 | def mask_out_after_eos_id(t, eos_id, mask_value = -1, keep_eos = True): 87 | eos_mask = (t == eos_id).float() 88 | 89 | if keep_eos: 90 | eos_mask = F.pad(eos_mask, (1, -1)) 91 | 92 | after_eos_mask = eos_mask.cumsum(dim = -1) > 0 93 | return t.masked_fill(after_eos_mask, mask_value) 94 | 95 | def all_rows_have_eos_id(t, eos_id): 96 | eos_mask = (t == eos_id) 97 | return torch.any(eos_mask, dim = -1).all() 98 | 99 | # classifier free guidance functions 100 | 101 | def prob_mask_like(shape, prob, device): 102 | if prob == 1: 103 | return torch.ones(shape, device = device, dtype = torch.bool) 104 | elif prob == 0: 105 | return torch.zeros(shape, device = device, dtype = torch.bool) 106 | else: 107 | return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob 108 | 109 | # removing unique consecutives in the semantic token ids 110 | # important detail noted by @eonglints 111 | 112 | def append_eos_id(ids, eos_id): 113 | b, device = ids.shape[0], ids.device 114 | eos_ids = torch.ones(1, device = device).long() * eos_id 115 | eos_ids = repeat(eos_ids, '1 -> b 1', b = b) 116 | ids = torch.cat((ids, eos_ids), dim = -1) 117 | return ids 118 | 119 | def batch_unique_consecutive(t, pad_value = 0.): 120 | unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)] 121 | return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value) 122 | 123 | # to get embedding from sequence with padding token 124 | 125 | @beartype_jit 126 | def get_embeds( 127 | embeddings: nn.Embedding, 128 | codes: torch.Tensor, 129 | pad_id = -1, 130 | return_mask = False, 131 | mask_pad_pos_to = 0 132 | ): 133 | pad_mask = codes == pad_id 134 | codes_without_pad = codes.masked_fill(pad_mask, 0) # just retrieve first code as dummy 135 | embeds = embeddings(codes_without_pad) 136 | 137 | if exists(mask_pad_pos_to): 138 | embeds = embeds.masked_fill(rearrange(pad_mask, '... -> ... 1'), mask_pad_pos_to) 139 | 140 | if return_mask: 141 | return embeds, ~pad_mask 142 | 143 | return embeds 144 | 145 | # audio processing helpers 146 | 147 | def int16_to_float32(x): 148 | return (x / 32767.0).type(torch.float32) 149 | 150 | def float32_to_int16(x): 151 | x = torch.clamp(x, min=-1., max=1.) 
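# scale the clamped signal into the int16 range; elsewhere in the repo audio is round-tripped
# through float32_to_int16 / int16_to_float32 (see prepare_audio below) to mimic 16-bit
# quantization before it is fed to the models.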
152 | return (x * 32767.).type(torch.int16) 153 | 154 | def zero_mean_unit_var_norm(x): 155 | return (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(x.var(dim=-1, keepdim=True) + 1e-7) 156 | 157 | def prepare_audio(data, sample_hz, target_sample_hz, normalize=True, target_length_seconds=None): 158 | if data.shape[0] > 1: 159 | data = torch.mean(data, dim=0).unsqueeze(0) 160 | if normalize: 161 | data = zero_mean_unit_var_norm(data) 162 | if exists(target_length_seconds) and data.shape[1] > target_length_seconds * sample_hz: 163 | data = data[: , :int(target_length_seconds * sample_hz)] 164 | audio_for_wav2vec = resample(data, sample_hz, target_sample_hz) 165 | audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec)) 166 | return audio_for_wav2vec 167 | 168 | # helper for saving config 169 | 170 | def copy_file_to_folder(file_path: str, folder_path: str): 171 | config_file = Path(file_path) 172 | folder = Path(folder_path) 173 | 174 | shutil.copy(str(config_file), str(folder / config_file.name)) -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | # new clap checkpoint 2 | if [ -e ./checkpoints/music_speech_audioset_epoch_15_esc_89.98.pt ] 3 | then 4 | echo "clap checkpoint already downloaded" 5 | else 6 | wget -P ./checkpoints 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt' 7 | fi 8 | -------------------------------------------------------------------------------- /scripts/download_fma_large.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | if [ -e ./data/fma_large.zip ] 4 | then 5 | echo "fma_large already downloaded" 6 | else 7 | echo "downloading fma_large.zip, may take a while..." 8 | wget -P ./data https://os.unil.cloud.switch.ch/fma/fma_large.zip 9 | fi 10 | 11 | if [ -e ./data/fma_large ] 12 | then 13 | echo "fma_large already unzipped" 14 | else 15 | echo "unzipping fma_large.zip, may take a while..." 
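# verify the archive against the published SHA1 checksum before extracting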
16 | echo "497109f4dd721066b5ce5e5f250ec604dc78939e data/fma_large.zip" | sha1sum -c - 17 | cd data 18 | unzip fma_large.zip 19 | fi -------------------------------------------------------------------------------- /scripts/download_fma_metadata.sh: -------------------------------------------------------------------------------- 1 | wget -P ./data https://os.unil.cloud.switch.ch/fma/fma_metadata.zip 2 | echo "f0df49ffe5f2a6008d7dc83c6915b31835dfe733 fma_metadata.zip" | sha1sum -c - 3 | cd data 4 | unzip fma_metadata.zip -------------------------------------------------------------------------------- /scripts/infer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | example usage: 3 | 4 | python3 scripts/infer.py \ 5 | --semantic_path ./results/semantic/semantic.transformer.10000.pt \ 6 | --coarse_path ./results/coarse/coarse.transformer.10000.pt \ 7 | --fine_path ./results/fine/fine.transformer.10000.pt \ 8 | --model_config ./configs/model/musiclm_small.json \ 9 | --return_coarse_wave 10 | ''' 11 | 12 | import os 13 | import sys 14 | 15 | import torch 16 | import torchaudio 17 | from einops import rearrange 18 | import argparse 19 | 20 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 21 | 22 | from open_musiclm.config import load_model_config, create_musiclm_from_config 23 | 24 | prompts = [ 25 | [ 26 | 'The main soundtrack of an arcade game. It is fast-paced and upbeat, with a catchy electric guitar riff. The music is repetitive and easy to remember, but with unexpected sounds, like cymbal crashes or drum rolls.', 27 | 'A fusion of reggaeton and electronic dance music, with a spacey, otherworldly sound. Induces the experience of being lost in space, and the music would be designed to evoke a sense of wonder and awe, while being danceable.', 28 | 'A rising synth is playing an arpeggio with a lot of reverb. It is backed by pads, sub bass line and soft drums. This song is full of synth sounds creating a soothing and adventurous atmosphere. It may be playing at a festival during two songs for a buildup.', 29 | 'Slow tempo, bass-and-drums-led reggae song. Sustained electric guitar. High-pitched bongos with ringing tones. 
Vocals are relaxed with a laid-back feel, very expressive.', 30 | ], 31 | ['song with synths and flute', 'crowd cheering', 'piano sonata waltz, glittery', 'house song, 4 on the floor, rhythm'], 32 | ['chirping of birds and the distant echos of bells', 'cat meowing', 'saxophone with drums', 'beethoven piano sonata'] 33 | ] 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser(description='run inference on trained musiclm model') 37 | 38 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json', help='path to model config') 39 | parser.add_argument('--semantic_path', required=True, help='path to semantic stage checkpoint') 40 | parser.add_argument('--coarse_path', required=True, help='path to coarse stage checkpoint') 41 | parser.add_argument('--fine_path', required=True, help='path to fine stage checkpoint') 42 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt') 43 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib') 44 | parser.add_argument('--return_coarse_wave', default=False, action=argparse.BooleanOptionalAction) 45 | parser.add_argument('--duration', default=4, type=float, help='duration of audio to generate in seconds') 46 | parser.add_argument('--seed', default=0) 47 | 48 | args = parser.parse_args() 49 | 50 | model_config = load_model_config(args.model_config) 51 | 52 | semantic_path = args.semantic_path 53 | coarse_path = args.coarse_path 54 | fine_path = args.fine_path 55 | return_coarse_wave = args.return_coarse_wave 56 | duration = args.duration 57 | kmeans_path = args.kmeans_path 58 | rvq_path = args.rvq_path 59 | seed = args.seed 60 | 61 | print(f'semantic checkpoint {semantic_path}, coarse checkpoint {coarse_path}, fine checkpoint {fine_path}') 62 | print(f'kmeans path {kmeans_path}, rvq path {rvq_path}, return_coarse_wave {return_coarse_wave}') 63 | 64 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 65 | 66 | musiclm = create_musiclm_from_config( 67 | model_config=model_config, 68 | semantic_path=semantic_path, 69 | coarse_path=coarse_path, 70 | fine_path=fine_path, 71 | rvq_path=rvq_path, 72 | kmeans_path=kmeans_path, 73 | device=device) 74 | 75 | torch.manual_seed(seed) 76 | 77 | print('generating...') 78 | 79 | for prompt in prompts: 80 | generated_wave = musiclm.forward( 81 | text=prompt, 82 | output_seconds=duration, 83 | semantic_window_seconds=model_config.global_cfg.semantic_audio_length_seconds, 84 | coarse_window_seconds=model_config.global_cfg.coarse_audio_length_seconds, 85 | fine_window_seconds=model_config.global_cfg.fine_audio_length_seconds, 86 | semantic_steps_per_second=model_config.hubert_kmeans_cfg.output_hz, 87 | acoustic_steps_per_second=model_config.encodec_cfg.output_hz, 88 | return_coarse_generated_wave=return_coarse_wave, 89 | ).detach().cpu() 90 | 91 | print(generated_wave.shape) 92 | 93 | generated_wave = rearrange(generated_wave, 'b n -> b 1 n') 94 | for i, wave in enumerate(generated_wave): 95 | torchaudio.save(f'results/{prompt[i][:25]}_generated.wav', wave, musiclm.neural_codec.sample_rate) 96 | 97 | -------------------------------------------------------------------------------- /scripts/infer_coarse.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 1) load audio samples 3 | 2) compute clap tokens and semantic tokens 4 | 3) run them through coarse stage to predict coarse tokens 5 | 4) reconstruct audio from coarse tokens 6 | Reconstructed audio should be semantically similar to 
the original audio if hubert-kmeans and coarse stage are working correctly 7 | 8 | example usage: 9 | 10 | python scripts/infer_coarse.py \ 11 | ./data/fma_large/000/000005.mp3 \ 12 | ./data/fma_large/000/000010.mp3 \ 13 | --model_config ./configs/model/musiclm_small.json \ 14 | --coarse_path ./results/coarse_continue_1/coarse.transformer.10000.pt 15 | 16 | ''' 17 | 18 | import argparse 19 | import os 20 | import sys 21 | from pathlib import Path 22 | 23 | import torch 24 | import torchaudio 25 | from einops import rearrange 26 | from torchaudio.functional import resample 27 | 28 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 29 | 30 | from open_musiclm.config import (create_clap_quantized_from_config, 31 | create_coarse_transformer_from_config, 32 | create_encodec_from_config, 33 | create_hubert_kmeans_from_config, 34 | load_model_config) 35 | from open_musiclm.open_musiclm import (CoarseStage, 36 | get_or_compute_clap_token_ids, 37 | get_or_compute_semantic_token_ids) 38 | from open_musiclm.utils import int16_to_float32, float32_to_int16, zero_mean_unit_var_norm 39 | from scripts.train_utils import disable_print 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser(description='run inference on coarse stage') 43 | parser.add_argument('audio_files', type=str, nargs='+') 44 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json', help='path to model config') 45 | parser.add_argument('--coarse_path', required=True, help='path to coarse stage checkpoint') 46 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt') 47 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib') 48 | parser.add_argument('--seed', default=0) 49 | 50 | args = parser.parse_args() 51 | 52 | model_config = load_model_config(args.model_config) 53 | 54 | audio_files = args.audio_files 55 | coarse_path = args.coarse_path 56 | kmeans_path = args.kmeans_path 57 | rvq_path = args.rvq_path 58 | seed = args.seed 59 | 60 | print(f'running inference on {audio_files}') 61 | print(f'coarse_path: {coarse_path}, kmeans_path: {kmeans_path}, rvq_path: {rvq_path}, seed: {seed}') 62 | 63 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 64 | 65 | print('loading clap...') 66 | clap = create_clap_quantized_from_config(model_config, args.rvq_path, device) 67 | 68 | print('loading wav2vec...') 69 | wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device) 70 | 71 | print('loading encodec...') 72 | encodec_wrapper = create_encodec_from_config(model_config, device) 73 | 74 | print('loading coarse stage...') 75 | coarse_transformer = create_coarse_transformer_from_config(model_config, coarse_path, device) 76 | 77 | coarse_stage = CoarseStage( 78 | coarse_transformer=coarse_transformer, 79 | neural_codec=encodec_wrapper, 80 | wav2vec=wav2vec, 81 | clap=clap 82 | ) 83 | 84 | torch.manual_seed(args.seed) 85 | 86 | print('loading audio from dataset') 87 | 88 | audios_for_clap = [] 89 | audios_for_wav2vec = [] 90 | for audio_path in audio_files: 91 | data, sample_hz = torchaudio.load(audio_path) 92 | 93 | if data.shape[0] > 1: 94 | data = torch.mean(data, dim=0).unsqueeze(0) 95 | 96 | target_length = int(4 * sample_hz) 97 | normalized_data = zero_mean_unit_var_norm(data) 98 | 99 | data = data[:, :target_length] 100 | normalized_data = normalized_data[: , :target_length] 101 | audio_for_clap = resample(data, sample_hz, clap.sample_rate) 102 | audio_for_wav2vec = resample(normalized_data, sample_hz, 
wav2vec.target_sample_hz) 103 | 104 | audio_for_clap = int16_to_float32(float32_to_int16(audio_for_clap)) 105 | audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec)) 106 | 107 | audios_for_clap.append(audio_for_clap) 108 | audios_for_wav2vec.append(audio_for_wav2vec) 109 | 110 | audios_for_clap = torch.cat(audios_for_clap, dim=0).to(device) 111 | audios_for_wav2vec = torch.cat(audios_for_wav2vec, dim=0).to(device) 112 | 113 | clap_token_ids = get_or_compute_clap_token_ids(None, clap, audios_for_clap, None) 114 | semantic_token_ids = get_or_compute_semantic_token_ids(None, audios_for_wav2vec, wav2vec) 115 | 116 | generated_wave = coarse_stage.generate( 117 | clap_token_ids=clap_token_ids, 118 | semantic_token_ids=semantic_token_ids, 119 | coarse_token_ids=None, 120 | max_time_steps=int(model_config.global_cfg.coarse_audio_length_seconds * 75), 121 | reconstruct_wave=True, 122 | include_eos_in_output=False, 123 | append_eos_to_conditioning_tokens=True, 124 | temperature=0.95, 125 | ) 126 | 127 | generated_wave = rearrange(generated_wave, 'b n -> b 1 n').detach().cpu() 128 | for i, wave in enumerate(generated_wave): 129 | torchaudio.save(f'results/{i}.wav', wave, encodec_wrapper.sample_rate) -------------------------------------------------------------------------------- /scripts/infer_fine.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 1) load audio samples 3 | 2) compute clap tokens and coarse tokens 4 | 3) run them through fine stage to predict fine tokens 5 | 4) reconstruct audio from coarse + fine tokens 6 | Reconstructed audio should be similar to the original audio if fine stage is working correctly 7 | 8 | example usage: 9 | 10 | python scripts/infer_fine.py \ 11 | ./data/fma_large/000/000005.mp3 \ 12 | ./data/fma_large/000/000010.mp3 \ 13 | --model_config ./configs/model/musiclm_small.json \ 14 | --fine_path ./results/fine/fine.transformer.10000.pt 15 | ''' 16 | 17 | import argparse 18 | import os 19 | import sys 20 | from pathlib import Path 21 | 22 | import torch 23 | import torchaudio 24 | from einops import rearrange 25 | from torchaudio.functional import resample 26 | 27 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 28 | 29 | from open_musiclm.config import (create_clap_quantized_from_config, 30 | create_fine_transformer_from_config, 31 | create_encodec_from_config, 32 | create_hubert_kmeans_from_config, 33 | load_model_config) 34 | from open_musiclm.open_musiclm import (FineStage, 35 | get_or_compute_clap_token_ids, 36 | get_or_compute_acoustic_token_ids) 37 | from open_musiclm.utils import int16_to_float32, float32_to_int16, zero_mean_unit_var_norm 38 | from scripts.train_utils import disable_print 39 | 40 | if __name__ == '__main__': 41 | parser = argparse.ArgumentParser(description='run inference on fine stage') 42 | parser.add_argument('audio_files', type=str, nargs='+') 43 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json', help='path to model config') 44 | parser.add_argument('--fine_path', required=True, help='path to fine stage checkpoint') 45 | parser.add_argument('--temperature', default=0.4, type=float) 46 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt') 47 | parser.add_argument('--seed', default=0) 48 | 49 | args = parser.parse_args() 50 | 51 | model_config = load_model_config(args.model_config) 52 | 53 | audio_files = args.audio_files 54 | fine_path = args.fine_path 55 | rvq_path = args.rvq_path 56 |
seed = args.seed 57 | temperature = args.temperature 58 | 59 | print(f'running inference on {audio_files}') 60 | print(f'fine_path: {fine_path}, rvq_path: {rvq_path}, seed: {seed}') 61 | 62 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 63 | 64 | print('loading clap...') 65 | clap = create_clap_quantized_from_config(model_config, rvq_path, device) 66 | 67 | print('loading encodec...') 68 | encodec_wrapper = create_encodec_from_config(model_config, device) 69 | 70 | print('loading fine stage...') 71 | fine_transformer = create_fine_transformer_from_config(model_config, fine_path, device) 72 | 73 | fine_stage = FineStage( 74 | fine_transformer=fine_transformer, 75 | neural_codec=encodec_wrapper, 76 | clap=clap 77 | ) 78 | 79 | torch.manual_seed(seed) 80 | 81 | print('loading audio from dataset') 82 | 83 | audios_for_clap = [] 84 | audios_for_encodec = [] 85 | for audio_path in audio_files: 86 | data, sample_hz = torchaudio.load(audio_path) 87 | 88 | if data.shape[0] > 1: 89 | data = torch.mean(data, dim=0).unsqueeze(0) 90 | 91 | target_length = int(model_config.global_cfg.fine_audio_length_seconds * sample_hz) 92 | 93 | data = data[:, :target_length] 94 | audio_for_clap = resample(data, sample_hz, clap.sample_rate) 95 | audio_for_encodec = resample(data, sample_hz, encodec_wrapper.sample_rate) 96 | 97 | audio_for_clap = int16_to_float32(float32_to_int16(audio_for_clap)) 98 | audio_for_encodec = int16_to_float32(float32_to_int16(audio_for_encodec)) 99 | 100 | audios_for_clap.append(audio_for_clap) 101 | audios_for_encodec.append(audio_for_encodec) 102 | 103 | audios_for_clap = torch.cat(audios_for_clap, dim=0).to(device) 104 | audios_for_encodec = torch.cat(audios_for_encodec, dim=0).to(device) 105 | 106 | clap_token_ids = get_or_compute_clap_token_ids(None, clap, audios_for_clap, None) 107 | coarse_token_ids, fine_token_ids = get_or_compute_acoustic_token_ids(None, None, audios_for_encodec, encodec_wrapper, model_config.global_cfg.num_coarse_quantizers) 108 | 109 | generated_wave = fine_stage.generate( 110 | clap_token_ids=clap_token_ids, 111 | coarse_token_ids=coarse_token_ids, 112 | max_time_steps=int(model_config.global_cfg.fine_audio_length_seconds * model_config.encodec_cfg.output_hz), 113 | reconstruct_wave=True, 114 | include_eos_in_output=False, 115 | append_eos_to_conditioning_tokens=True, 116 | temperature=temperature, 117 | ) 118 | 119 | generated_wave = rearrange(generated_wave, 'b n -> b 1 n').detach().cpu() 120 | for i, wave in enumerate(generated_wave): 121 | torchaudio.save(f'results/fine_reconstruct_{i}.wav', wave, encodec_wrapper.sample_rate) -------------------------------------------------------------------------------- /scripts/infer_top_match.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torchaudio 6 | from einops import rearrange 7 | from pathlib import Path 8 | import argparse 9 | 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 11 | 12 | from open_musiclm.config import load_model_config, create_musiclm_from_config 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser(description='run inference on trained musiclm model') 16 | 17 | parser.add_argument('prompt', help='prompts to generate audio for', type=str, nargs='+') 18 | parser.add_argument('--num_samples', default=4, type=int) 19 | parser.add_argument('--num_top_matches', default=1, type=int) 20 | parser.add_argument('--input_audio', default=None, type=str, 
help='input audio to condition on and generate continuations from') 21 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json', help='path to model config') 22 | parser.add_argument('--semantic_path', required=True, help='path to semantic stage checkpoint') 23 | parser.add_argument('--coarse_path', required=True, help='path to coarse stage checkpoint') 24 | parser.add_argument('--fine_path', required=True, help='path to fine stage checkpoint') 25 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt') 26 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib') 27 | parser.add_argument('--results_folder', default='./results', type=str) 28 | parser.add_argument('--return_coarse_wave', default=False, action=argparse.BooleanOptionalAction) 29 | parser.add_argument('--duration', default=4, type=float, help='duration of audio to generate in seconds') 30 | parser.add_argument('--seed', default=0) 31 | 32 | args = parser.parse_args() 33 | 34 | model_config = load_model_config(args.model_config) 35 | 36 | semantic_path = args.semantic_path 37 | coarse_path = args.coarse_path 38 | fine_path = args.fine_path 39 | input_audio = args.input_audio 40 | return_coarse_wave = args.return_coarse_wave 41 | duration = args.duration 42 | kmeans_path = args.kmeans_path 43 | rvq_path = args.rvq_path 44 | seed = args.seed 45 | results_folder = args.results_folder 46 | 47 | Path(results_folder).mkdir(parents=True, exist_ok=True) 48 | 49 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 50 | 51 | musiclm = create_musiclm_from_config( 52 | model_config=model_config, 53 | semantic_path=semantic_path, 54 | coarse_path=coarse_path, 55 | fine_path=fine_path, 56 | rvq_path=rvq_path, 57 | kmeans_path=kmeans_path, 58 | device=device) 59 | 60 | torch.manual_seed(seed) 61 | 62 | print(f'prompt: {args.prompt}') 63 | 64 | prime_wave, prime_wave_sample_hz = None, None 65 | if input_audio is not None: 66 | prime_wave, prime_wave_sample_hz = torchaudio.load(input_audio) 67 | prime_wave = prime_wave.to(device) 68 | 69 | generated_wave, similarities = musiclm.generate_top_match( 70 | text=args.prompt, 71 | prime_wave=prime_wave, 72 | prime_wave_sample_hz=prime_wave_sample_hz, 73 | num_samples=args.num_samples, 74 | num_top_matches=args.num_top_matches, 75 | output_seconds=duration, 76 | semantic_window_seconds=model_config.global_cfg.semantic_audio_length_seconds, 77 | coarse_window_seconds=model_config.global_cfg.coarse_audio_length_seconds, 78 | fine_window_seconds=model_config.global_cfg.fine_audio_length_seconds, 79 | semantic_steps_per_second=model_config.hubert_kmeans_cfg.output_hz, 80 | acoustic_steps_per_second=model_config.encodec_cfg.output_hz, 81 | return_coarse_generated_wave=return_coarse_wave, 82 | ) 83 | 84 | for i, (wave, sim) in enumerate(zip(generated_wave, similarities)): 85 | wave = rearrange(wave, 'b n -> b 1 n').detach().cpu() 86 | print(f'prompt: {args.prompt[i]}') 87 | print(f'topk similarities: {sim}') 88 | for j, w in enumerate(wave): 89 | torchaudio.save(Path(results_folder) / Path(f'{args.prompt[i][:35]}_top_match_{j}.wav'), w, musiclm.neural_codec.sample_rate) 90 | -------------------------------------------------------------------------------- /scripts/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from pathlib import Path 5 | 6 | import torch 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), 
'..')) 9 | 10 | from open_musiclm.config import (create_clap_quantized_from_config, 11 | create_encodec_from_config, 12 | create_hubert_kmeans_from_config, 13 | load_model_config, load_training_config, 14 | create_data_preprocessor_from_config) 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser(description='preprocess data') 18 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json') 19 | parser.add_argument('--training_config', default='./configs/training/train_fma_preprocess.json') 20 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt') 21 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib') 22 | parser.add_argument('--filter_fma', default=False, action=argparse.BooleanOptionalAction) 23 | 24 | args = parser.parse_args() 25 | 26 | print(f'using model config {args.model_config}, training config {args.training_config}, rvq checkpoint {args.rvq_path}, kmeans checkpoint {args.kmeans_path}') 27 | 28 | model_config = load_model_config(args.model_config) 29 | training_config = load_training_config(args.training_config) 30 | 31 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 32 | 33 | print('loading clap...') 34 | clap = create_clap_quantized_from_config(model_config, args.rvq_path, device) 35 | 36 | print('loading wav2vec...') 37 | wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device) 38 | 39 | print('loading encodec...') 40 | encodec_wrapper = create_encodec_from_config(model_config, device) 41 | 42 | 43 | # get rid of some experimental tracks, see notebooks/analyze_fma.ipynb 44 | if args.filter_fma: 45 | try: 46 | import pandas as pd 47 | import ast 48 | except ImportError: 49 | pd = None 50 | 51 | assert pd is not None, 'pandas not found, please install pandas to filter fma' 52 | 53 | metadata_folder = training_config.data_preprocessor_cfg.metadata_folder 54 | 55 | tracks = pd.read_csv(os.path.join(metadata_folder, 'tracks.csv'), index_col=0, header=[0, 1]) 56 | experimental_genre = 38 57 | experimental_tracks = tracks.loc[tracks['track', 'genres_all'].apply(lambda x: experimental_genre in ast.literal_eval(x))] 58 | ignore_files = list(experimental_tracks.loc[(experimental_tracks['track', 'listens'] <= 1000) | (experimental_tracks['track', 'favorites'] <= 5)].index) 59 | ignore_files = [f'{i:06d}.mp3' for i in ignore_files] 60 | else: 61 | ignore_files = None 62 | 63 | processor = create_data_preprocessor_from_config( 64 | model_config, 65 | training_config, 66 | clap, 67 | wav2vec, 68 | encodec_wrapper, 69 | device, 70 | config_paths=[args.model_config, args.training_config], 71 | ignore_files=ignore_files) 72 | 73 | processor.process() 74 | 75 | -------------------------------------------------------------------------------- /scripts/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhvng/open-musiclm/8e2c6a8d65a33d407072bab8f5af64d9cb49f2e4/scripts/test/__init__.py -------------------------------------------------------------------------------- /scripts/test/test_clap.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torchaudio 5 | from torchaudio.functional import resample 6 | import numpy as np 7 | import torch 8 | from transformers import RobertaTokenizer 9 | 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 11 | 12 | from open_musiclm.clap_quantized 
import ClapQuantized, create_clap_quantized 13 | 14 | text_data = ['male rap', 15 | 'male rapping over a synth pad', 16 | 'female singing', 17 | 'female singing over a synth pad', 18 | 'male singing over a synth pad', 19 | 'male rapping chill voice', 20 | 'male with deep voice', 21 | 'male singing then rapping, synth in the background', 22 | 'producer tag then drake rapping', 23 | 'calming melody', 24 | 'upbeat, hype', 25 | 'pause and then the beat drops, and a male rapper is rapping over a trap beat', 26 | 'male rapping over a hip hop beat', 27 | 'house music', 28 | 'rock song with piano', 29 | 'groovy melody with piano and a male singing'] 30 | 31 | def infer_text(clap_wrapper, return_embedding=False): 32 | 33 | 34 | 35 | text_embed = clap_wrapper(text_input=text_data, return_embedding=return_embedding) 36 | 37 | return text_embed 38 | 39 | def infer_audio(clap_wrapper: ClapQuantized, return_embedding: bool = False, device: str = 'cuda'): 40 | 41 | print('inferring audio...') 42 | 43 | # load the waveform of the shape (T,), should resample to 48000 44 | audio_waveform, sr = torchaudio.load('/u/zhvng/projects/audio_files/jumpman.mp3') 45 | 46 | wave_2, sr_2 = torchaudio.load('/u/zhvng/projects/open-musiclm/data/fma_large/000/000048.mp3') 47 | 48 | if audio_waveform.shape[0] > 1: 49 | # the audio has more than 1 channel, convert to mono 50 | audio_waveform = torch.mean(audio_waveform, dim=0, keepdim=True) 51 | if wave_2.shape[0] > 1: 52 | wave_2 = torch.mean(wave_2, dim=0, keepdim=True) 53 | 54 | audio_waveform = resample(audio_waveform, sr, 48000) 55 | wave_2 = resample(wave_2, sr_2, 48000) 56 | 57 | # audio_waveform = audio_waveform[:, :48000 * 30] 58 | audio_waveform_1 = audio_waveform[:, :48000 * 10] 59 | audio_waveform_2 = wave_2[:, :48000 * 10] 60 | # audio_waveform_3 = audio_waveform[:, 48000 * 20 : 48000 * 50] 61 | audio_waveform = torch.cat([audio_waveform_1, audio_waveform_2], dim=0) 62 | 63 | audio_embed = clap_wrapper(audio_input=audio_waveform.to(device), return_embedding=return_embedding) 64 | 65 | return audio_embed 66 | 67 | 68 | if __name__ == "__main__": 69 | 70 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 71 | 72 | clap_wrapper = create_clap_quantized(device=device, learn_rvq=False).to(device) 73 | 74 | text_embeds = infer_text(clap_wrapper, return_embedding=True) 75 | audio_embed = infer_audio(clap_wrapper, return_embedding=True, device=device) 76 | 77 | # print(text_embeds) 78 | print(text_embeds.shape) 79 | 80 | # print(audio_embed) 81 | print(audio_embed.shape) 82 | 83 | for i, text_embed in enumerate(text_embeds): 84 | # get cosine similarity with audio_embed 85 | cos_sim = torch.nn.functional.cosine_similarity( 86 | audio_embed, text_embed, dim=-1) 87 | print(text_data[i], cos_sim.cpu().numpy()) 88 | -------------------------------------------------------------------------------- /scripts/test/test_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 5 | 6 | from open_musiclm.config import load_model_config, load_training_config 7 | 8 | 9 | model_config = load_model_config('./configs/model/musiclm_small.json') 10 | print(model_config) 11 | 12 | training_config = load_training_config('./configs/training/train_musiclm_fma.json') 13 | print(training_config) 14 | 15 | print('\nok!') -------------------------------------------------------------------------------- /scripts/test/test_dataloader.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torchaudio 6 | from torchaudio.functional import resample 7 | import argparse 8 | from pathlib import Path 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | 12 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 13 | 14 | from open_musiclm.data import SoundDataset, SoundDatasetForPreprocessing, PreprocessedDataset, get_dataloader, get_sound_preprocessing_dataloader, get_preprocessed_dataloader 15 | 16 | 17 | folder = './data/fma_large/000' 18 | 19 | # test random crop 20 | 21 | dataset = SoundDataset( 22 | folder, 23 | max_length_seconds=(1, 5), 24 | normalize=(True, False), 25 | target_sample_hz=(16000, 24000), 26 | seq_len_multiple_of=None, 27 | ignore_load_errors=True 28 | ) 29 | 30 | dl = get_dataloader(dataset, batch_size=4, shuffle=False) 31 | dl_iter = iter(dl) 32 | 33 | test_steps = 2 34 | for i in range(test_steps): 35 | batch = next(dl_iter) 36 | # print(batch) 37 | for e in batch: 38 | print(e.shape) 39 | 40 | # test preprocessing 41 | 42 | dataset = SoundDatasetForPreprocessing( 43 | folder, 44 | max_length_seconds=(None, 1), 45 | normalize=(True, False), 46 | target_sample_hz=(16000, 24000), 47 | seq_len_multiple_of=None, 48 | ignore_load_errors=True 49 | ) 50 | 51 | dl = get_sound_preprocessing_dataloader(dataset, shuffle=False) 52 | dl_iter = iter(dl) 53 | 54 | test_steps = 2 55 | for i in range(test_steps): 56 | batch = next(dl_iter) 57 | print(batch) 58 | 59 | # # test preprocessed 60 | dataset = PreprocessedDataset( 61 | './data/fma_preprocessed', 62 | stage='coarse', 63 | semantic_window_seconds=10, 64 | coarse_window_seconds=4, 65 | fine_window_seconds=2, 66 | semantic_steps_per_second=50, 67 | acoustic_steps_per_second=75, 68 | ) 69 | 70 | dl = get_preprocessed_dataloader(dataset, batch_size=4, shuffle=True) 71 | dl_iter = iter(dl) 72 | 73 | test_steps = 2 74 | for i in range(test_steps): 75 | batch = next(dl_iter) 76 | for d in batch: 77 | print(d.shape) 78 | -------------------------------------------------------------------------------- /scripts/test/test_encodec.py: -------------------------------------------------------------------------------- 1 | from encodec import EncodecModel 2 | from encodec.utils import convert_audio 3 | 4 | import torchaudio 5 | import torch 6 | 7 | 8 | if __name__ == "__main__": 9 | 10 | # Instantiate a pretrained EnCodec model 11 | model = EncodecModel.encodec_model_24khz() 12 | model.set_target_bandwidth(6.0) 13 | 14 | # Load and pre-process the audio waveform 15 | wav, sr = torchaudio.load("/u/zhvng/projects/audio_files/jumpman.mp3") 16 | wav = convert_audio(wav, sr, model.sample_rate, model.channels) 17 | print(model.channels, model.sample_rate, model.quantizer.n_q) 18 | print(model.segment_stride) 19 | wav = torch.stack([wav, wav], dim=0) 20 | 21 | # Extract discrete codes from EnCodec 22 | with torch.no_grad(): 23 | encoded_frames = model.encode(wav) 24 | codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T] 25 | 26 | print(codes) 27 | print(codes.shape) 28 | print(encoded_frames[0][0].shape) 29 | print(len(encoded_frames)) 30 | 31 | new_encoded_frames = [(encoded[0], None) for encoded in encoded_frames] 32 | 33 | wave = model.decode(new_encoded_frames) 34 | torchaudio.save('test.wav', wave[0], model.sample_rate) 35 | -------------------------------------------------------------------------------- 
/scripts/test/test_hubert_clustering.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torchaudio 6 | from torchaudio.functional import resample 7 | import argparse 8 | from pathlib import Path 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | 12 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) 13 | 14 | from open_musiclm.config import load_model_config, load_training_config, create_hubert_kmeans_from_config, create_hubert_kmeans_trainer_from_config 15 | from open_musiclm.utils import zero_mean_unit_var_norm, int16_to_float32, float32_to_int16, exists 16 | from open_musiclm.open_musiclm import get_or_compute_semantic_token_ids 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser(description='test hubert kmeans to see the difference in sequences') 20 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json') 21 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib') 22 | parser.add_argument('--folder', default='./data/fma_large') 23 | 24 | args = parser.parse_args() 25 | 26 | model_config = load_model_config(args.model_config) 27 | 28 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 29 | 30 | print('loading hubert...') 31 | wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device) 32 | 33 | path = Path(args.folder) 34 | assert path.exists(), 'folder does not exist' 35 | 36 | files = [] 37 | for ext in ['mp3', 'wav', 'flac']: 38 | for file in path.glob(f'**/*.{ext}'): 39 | files.append(file) 40 | assert len(files) > 0, 'no sound files found' 41 | 42 | start_audio = 20000 43 | audio_lengths = [4, 10, 15, 25] 44 | batch_size = 16 45 | shortest_length = None 46 | cropped_semantic_tokens_for_each_length = [] 47 | for audio_seconds in audio_lengths: 48 | audios_for_wav2vec = [] 49 | for audio_path in files[start_audio: start_audio + 16]: 50 | data, sample_hz = torchaudio.load(audio_path) 51 | 52 | if data.shape[0] > 1: 53 | data = torch.mean(data, dim=0).unsqueeze(0) 54 | 55 | target_length = int(audio_seconds * sample_hz) 56 | normalized_data = zero_mean_unit_var_norm(data) 57 | 58 | normalized_data = normalized_data[: , :target_length] 59 | 60 | audio_for_wav2vec = resample(normalized_data, sample_hz, wav2vec.target_sample_hz) 61 | 62 | audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec)) 63 | 64 | audios_for_wav2vec.append(audio_for_wav2vec) 65 | 66 | audios_for_wav2vec = torch.cat(audios_for_wav2vec, dim=0).to(device) 67 | semantic_token_ids = get_or_compute_semantic_token_ids(None, audios_for_wav2vec, wav2vec) 68 | print(semantic_token_ids.shape) 69 | 70 | if not exists(shortest_length): 71 | shortest_length = semantic_token_ids.shape[1] 72 | else: 73 | l = semantic_token_ids.shape[1] 74 | if l < shortest_length: 75 | shortest_length = l 76 | 77 | cropped_semantic_tokens_for_each_length.append(semantic_token_ids[:, :shortest_length]) 78 | 79 | print(cropped_semantic_tokens_for_each_length[0][0]) 80 | print(cropped_semantic_tokens_for_each_length[1][0]) 81 | # get accuracy compared to last elem in cropped_semantic_tokens_for_each_length 82 | 83 | side_length = len(cropped_semantic_tokens_for_each_length) 84 | accuracy_matrix = np.zeros((side_length, side_length)) 85 | 86 | for i in range(side_length): 87 | for j in range(side_length): 88 | accuracy = torch.mean((cropped_semantic_tokens_for_each_length[i] == 
cropped_semantic_tokens_for_each_length[j]).float()) 89 | print(f'% similar between {audio_lengths[i]} and {audio_lengths[j]} second audio: {accuracy}') 90 | accuracy_matrix[i][j] = accuracy 91 | 92 | # plot the accuracy matrix in a grid with a title and axis labels 93 | 94 | # create a heatmap with darker colors representing higher accuracy 95 | fig, ax = plt.subplots() 96 | im = ax.imshow(accuracy_matrix, cmap='Blues', vmin=0, vmax=1) 97 | 98 | # remove ticks from the plot 99 | ax.tick_params(axis=u'both', which=u'both', length=0) 100 | 101 | # move the x-axis ticks to the top of the grid 102 | ax.xaxis.tick_top() 103 | 104 | # add a colorbar legend 105 | cbar = ax.figure.colorbar(im, ax=ax) 106 | 107 | # set axis labels 108 | ax.set_xticks(np.arange(len(audio_lengths))) 109 | ax.set_yticks(np.arange(len(audio_lengths))) 110 | ax.set_xticklabels(audio_lengths) 111 | ax.set_yticklabels(audio_lengths) 112 | 113 | # add text annotations for each cell 114 | for i in range(4): 115 | for j in range(4): 116 | text = ax.text(j, i, round(accuracy_matrix[i, j], 2), 117 | ha="center", va="center", color="w") 118 | 119 | # set plot title 120 | ax.set_title("Semantic Token Similarity Between Various Total Audio Lengths") 121 | ax.set_xlabel("total audio length (seconds)") 122 | ax.set_ylabel("total audio length (seconds)") 123 | 124 | # show the plot 125 | plt.savefig('./results/accuracy_matrix.png') 126 | 127 | 128 | -------------------------------------------------------------------------------- /scripts/test/test_load_preprocessed.py: -------------------------------------------------------------------------------- 1 | 2 | import io 3 | import numpy as np 4 | import sqlite3 5 | 6 | def adapt_array(arr): 7 | """ 8 | http://stackoverflow.com/a/31312102/190597 (SoulNibbler) 9 | """ 10 | out = io.BytesIO() 11 | np.save(out, arr) 12 | out.seek(0) 13 | return sqlite3.Binary(out.read()) 14 | 15 | def convert_array(text): 16 | out = io.BytesIO(text) 17 | out.seek(0) 18 | return np.load(out) 19 | 20 | sqlite3.register_adapter(np.ndarray, adapt_array) 21 | sqlite3.register_converter("array", convert_array) 22 | 23 | conn = sqlite3.connect('./data/fma_preprocessed/preprocessed.db', detect_types=sqlite3.PARSE_DECLTYPES) 24 | cursor = conn.cursor() 25 | cursor.execute("SELECT idx, path FROM tokens WHERE idx=?", (0,)) 26 | print(cursor.fetchone()) 27 | 28 | for i in [3,-1]: 29 | cursor.execute("SELECT clap, semantic, coarse FROM tokens WHERE idx=?", (i,)) 30 | 31 | data = cursor.fetchone() 32 | 33 | if data is None: 34 | print("No data found") 35 | else: 36 | for datum in data: 37 | print(datum.shape) 38 | 39 | print(data[1]) 40 | -------------------------------------------------------------------------------- /scripts/test/test_rvq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from vector_quantize_pytorch import ResidualVQ 3 | 4 | print('running rvq') 5 | rq = ResidualVQ(dim=512, num_quantizers=12, codebook_size=1024, commitment_weight=0, decay=0.95, kmeans_init=True, threshold_ema_dead_code=0) 6 | 7 | # rq.load_state_dict(torch.load('./results/semantic/semantic.conditioner_rvq.6000.pt', map_location='cpu')) 8 | 9 | for i in range(10): 10 | q, i, loss = rq(torch.randn(1, 512)) 11 | print(i[0:2]) 12 | -------------------------------------------------------------------------------- /scripts/train_clap_rvq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import 
torch 6 | 7 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 8 | 9 | from open_musiclm.clap_quantized import create_clap_quantized 10 | from open_musiclm.config import (create_clap_quantized_from_config, 11 | create_clap_rvq_trainer_from_config, 12 | load_model_config, load_training_config) 13 | from open_musiclm.trainer import ClapRVQTrainer 14 | from scripts.train_utils import disable_print 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser(description='train rvq to quantize clap embeddings') 18 | parser.add_argument('--results_folder', default='./results/clap_rvq') 19 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json') 20 | parser.add_argument('--training_config', default='./configs/training/train_musiclm_fma.json') 21 | parser.add_argument('--continue_from', default=None, type=str) 22 | 23 | args = parser.parse_args() 24 | 25 | model_config = load_model_config(args.model_config) 26 | training_config = load_training_config(args.training_config) 27 | 28 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 29 | 30 | print('loading clap...') 31 | clap = create_clap_quantized_from_config(model_config, args.continue_from, device) 32 | 33 | trainer = create_clap_rvq_trainer_from_config( 34 | model_config=model_config, 35 | training_config=training_config, 36 | clap=clap, 37 | results_folder=args.results_folder, 38 | device=device, 39 | accelerate_kwargs={ 40 | 'log_with': "tensorboard", 41 | 'logging_dir': './logs/clap_rvq' 42 | }, 43 | config_paths=[args.model_config, args.training_config]) 44 | 45 | trainer.train() -------------------------------------------------------------------------------- /scripts/train_coarse_stage.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from pathlib import Path 5 | 6 | import torch 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 9 | 10 | from open_musiclm.config import (create_clap_quantized_from_config, 11 | create_coarse_transformer_from_config, 12 | create_encodec_from_config, 13 | create_hubert_kmeans_from_config, 14 | create_single_stage_trainer_from_config, 15 | load_model_config, load_training_config) 16 | from scripts.train_utils import load_checkpoint_from_args, validate_train_args 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser(description='train coarse stage') 20 | parser.add_argument('--results_folder', default='./results/coarse') 21 | parser.add_argument('--continue_from_dir', default=None, type=str) 22 | parser.add_argument('--continue_from_step', default=None, type=int) 23 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json') 24 | parser.add_argument('--training_config', default='./configs/training/train_musiclm_fma.json') 25 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt') 26 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib') 27 | parser.add_argument('--fine_tune_from', default=None, type=str) 28 | 29 | args = parser.parse_args() 30 | 31 | validate_train_args(args) 32 | 33 | model_config = load_model_config(args.model_config) 34 | training_config = load_training_config(args.training_config) 35 | 36 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 37 | 38 | use_preprocessed_data = training_config.coarse_trainer_cfg.use_preprocessed_data 39 | 40 | if use_preprocessed_data: 41 | clap = None 42 | wav2vec = None 43 | 
print(f'training from preprocessed data {training_config.coarse_trainer_cfg.folder}') 44 | else: 45 | print('loading clap...') 46 | clap = create_clap_quantized_from_config(model_config, args.rvq_path, device) 47 | 48 | print('loading wav2vec...') 49 | wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device) 50 | 51 | print('loading encodec...') 52 | encodec_wrapper = create_encodec_from_config(model_config, device) 53 | 54 | print('loading coarse stage...') 55 | coarse_transformer = create_coarse_transformer_from_config(model_config, args.fine_tune_from, device) 56 | 57 | trainer = create_single_stage_trainer_from_config( 58 | model_config=model_config, 59 | training_config=training_config, 60 | stage='coarse', 61 | results_folder=args.results_folder, 62 | transformer=coarse_transformer, 63 | clap=clap, 64 | wav2vec=wav2vec, 65 | encodec_wrapper=encodec_wrapper, 66 | device=device, 67 | accelerate_kwargs={ 68 | 'log_with': "tensorboard", 69 | 'logging_dir': './logs/coarse' 70 | }, 71 | config_paths=[args.model_config, args.training_config]) 72 | 73 | load_checkpoint_from_args(trainer, args) 74 | 75 | trainer.train() 76 | -------------------------------------------------------------------------------- /scripts/train_fine_stage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import argparse 6 | from pathlib import Path 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 9 | 10 | from open_musiclm.config import (create_clap_quantized_from_config, 11 | create_encodec_from_config, 12 | create_hubert_kmeans_from_config, 13 | create_fine_transformer_from_config, 14 | create_single_stage_trainer_from_config, 15 | load_model_config, load_training_config) 16 | from scripts.train_utils import load_checkpoint_from_args, validate_train_args 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser(description='train fine stage') 20 | parser.add_argument('--results_folder', default='./results/fine') 21 | parser.add_argument('--continue_from_dir', default=None, type=str) 22 | parser.add_argument('--continue_from_step', default=None, type=int) 23 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json') 24 | parser.add_argument('--training_config', default='./configs/training/train_musiclm_fma.json') 25 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt') 26 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib') 27 | parser.add_argument('--fine_tune_from', default=None, type=str) 28 | 29 | args = parser.parse_args() 30 | 31 | validate_train_args(args) 32 | 33 | model_config = load_model_config(args.model_config) 34 | training_config = load_training_config(args.training_config) 35 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 36 | 37 | use_preprocessed_data = training_config.fine_trainer_cfg.use_preprocessed_data 38 | 39 | if use_preprocessed_data: 40 | clap = None 41 | print(f'training from preprocessed data {training_config.fine_trainer_cfg.folder}') 42 | else: 43 | print('loading clap...') 44 | clap = create_clap_quantized_from_config(model_config, args.rvq_path, device) 45 | 46 | print('loading encodec...') 47 | encodec_wrapper = create_encodec_from_config(model_config, device) 48 | 49 | print('loading fine stage...') 50 | fine_transformer = create_fine_transformer_from_config(model_config, args.fine_tune_from, device) 51 | 52 | trainer = 
create_single_stage_trainer_from_config( 53 | model_config=model_config, 54 | training_config=training_config, 55 | stage='fine', 56 | results_folder=args.results_folder, 57 | transformer=fine_transformer, 58 | clap=clap, 59 | wav2vec=None, 60 | encodec_wrapper=encodec_wrapper, 61 | device=device, 62 | accelerate_kwargs={ 63 | 'log_with': "tensorboard", 64 | 'logging_dir': './logs/fine' 65 | }, 66 | config_paths=[args.model_config, args.training_config]) 67 | 68 | load_checkpoint_from_args(trainer, args) 69 | 70 | trainer.train() 71 | -------------------------------------------------------------------------------- /scripts/train_hubert_kmeans.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import argparse 6 | 7 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 8 | 9 | from open_musiclm.config import load_model_config, load_training_config, create_hubert_kmeans_from_config, create_hubert_kmeans_trainer_from_config 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser(description='train kmeans to quantize hubert embeddings') 13 | parser.add_argument('--results_folder', default='./results/hubert_kmeans') 14 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json') 15 | parser.add_argument('--training_config', default='./configs/training/train_musiclm_fma.json') 16 | 17 | args = parser.parse_args() 18 | 19 | model_config = load_model_config(args.model_config) 20 | training_config = load_training_config(args.training_config) 21 | 22 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 23 | 24 | print('loading hubert...') 25 | hubert_kmeans = create_hubert_kmeans_from_config(model_config, None, device) 26 | 27 | trainer = create_hubert_kmeans_trainer_from_config( 28 | model_config=model_config, 29 | training_config=training_config, 30 | hubert_kmeans=hubert_kmeans, 31 | results_folder=args.results_folder, 32 | device=device, 33 | config_paths=[args.model_config, args.training_config] 34 | ) 35 | 36 | trainer.train() -------------------------------------------------------------------------------- /scripts/train_semantic_stage.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from pathlib import Path 5 | 6 | import torch 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 9 | 10 | from open_musiclm.config import (create_clap_quantized_from_config, 11 | create_encodec_from_config, 12 | create_hubert_kmeans_from_config, 13 | create_semantic_transformer_from_config, 14 | create_single_stage_trainer_from_config, 15 | load_model_config, load_training_config) 16 | from scripts.train_utils import validate_train_args, load_checkpoint_from_args 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser(description='train semantic stage') 20 | parser.add_argument('--results_folder', default='./results/semantic') 21 | parser.add_argument('--continue_from_dir', default=None, type=str) 22 | parser.add_argument('--continue_from_step', default=None, type=int) 23 | parser.add_argument('--model_config', default='./configs/model/musiclm_small.json') 24 | parser.add_argument('--training_config', default='./configs/training/train_musiclm_fma.json') 25 | parser.add_argument('--rvq_path', default='./checkpoints/clap.rvq.350.pt') 26 | parser.add_argument('--kmeans_path', default='./results/hubert_kmeans/kmeans.joblib') 27 | 
parser.add_argument('--fine_tune_from', default=None, type=str) 28 | 29 | args = parser.parse_args() 30 | 31 | validate_train_args(args) 32 | 33 | model_config = load_model_config(args.model_config) 34 | training_config = load_training_config(args.training_config) 35 | 36 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 37 | 38 | use_preprocessed_data = training_config.semantic_trainer_cfg.use_preprocessed_data 39 | 40 | if use_preprocessed_data: 41 | clap = None 42 | wav2vec = None 43 | print(f'training from preprocessed data {training_config.semantic_trainer_cfg.folder}') 44 | else: 45 | print('loading clap...') 46 | clap = create_clap_quantized_from_config(model_config, args.rvq_path, device) 47 | 48 | print('loading wav2vec...') 49 | wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device) 50 | 51 | print('loading semantic stage...') 52 | semantic_transformer = create_semantic_transformer_from_config(model_config, args.fine_tune_from, device) 53 | 54 | trainer = create_single_stage_trainer_from_config( 55 | model_config=model_config, 56 | training_config=training_config, 57 | stage='semantic', 58 | results_folder=args.results_folder, 59 | transformer=semantic_transformer, 60 | clap=clap, 61 | wav2vec=wav2vec, 62 | encodec_wrapper=None, 63 | device=device, 64 | accelerate_kwargs={ 65 | 'log_with': "tensorboard", 66 | 'logging_dir': './logs/semantic' 67 | }, 68 | config_paths=[args.model_config, args.training_config]) 69 | 70 | load_checkpoint_from_args(trainer, args) 71 | 72 | trainer.train() 73 | -------------------------------------------------------------------------------- /scripts/train_utils.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | import logging 3 | import sys 4 | import os 5 | from pathlib import Path 6 | 7 | def exists(val): 8 | return val is not None 9 | 10 | class disable_print: 11 | def __enter__(self): 12 | self._original_stdout = sys.stdout 13 | sys.stdout = open(os.devnull, 'w') 14 | 15 | def __exit__(self, exc_type, exc_val, exc_tb): 16 | sys.stdout.close() 17 | sys.stdout = self._original_stdout 18 | 19 | def get_latest_checkpoints(results_folder, max_step=None): 20 | highest_transformer_step = -1 21 | highest_optimizer_step = -1 22 | highest_scheduler_step = -1 23 | transformer_path = None 24 | optimizer_path = None 25 | scheduler_path = None 26 | max_step = float('inf') if max_step is None else max_step 27 | for file in os.listdir(results_folder): 28 | if file.endswith('.pt'): 29 | if 'transformer' in file: 30 | step = int(file.split('.')[2]) 31 | if step > highest_transformer_step and step <= max_step: 32 | highest_transformer_step = step 33 | transformer_path = os.path.join(results_folder, file) 34 | elif 'optimizer' in file: 35 | step = int(file.split('.')[2]) 36 | if step > highest_optimizer_step and step <= max_step: 37 | highest_optimizer_step = step 38 | optimizer_path = os.path.join(results_folder, file) 39 | elif 'scheduler' in file: 40 | step = int(file.split('.')[2]) 41 | if step > highest_scheduler_step and step <= max_step: 42 | highest_scheduler_step = step 43 | scheduler_path = os.path.join(results_folder, file) 44 | 45 | assert highest_transformer_step == highest_optimizer_step, 'transformer and optimizer checkpoints are not aligned' 46 | if scheduler_path is not None: 47 | assert highest_transformer_step == highest_scheduler_step, 'transformer and scheduler checkpoints are not aligned' 48 | 49 | return (transformer_path, optimizer_path, 
scheduler_path), highest_transformer_step 50 | 51 | def validate_train_args(args): 52 | assert not(exists(args.fine_tune_from) and exists(args.continue_from_dir)), 'choose one: fine tune from a checkpoint or continue from a directory' 53 | 54 | print(f'saving results to {args.results_folder}, using model config {args.model_config} and training config {args.training_config}, using rvq checkpoint {args.rvq_path} and kmeans checkpoint {args.kmeans_path}') 55 | if exists(args.continue_from_dir): 56 | print(f'continuing from latest checkpoint in {args.continue_from_dir}') 57 | assert not Path(args.continue_from_dir) == Path(args.results_folder), 'continue_from_dir must be different from results_folder' 58 | elif exists(args.fine_tune_from): 59 | print(f'fine tuning from checkpoint {args.fine_tune_from}. Make sure to use the same model config as the base model.') 60 | 61 | def load_checkpoint_from_args(trainer, args): 62 | if exists(args.continue_from_dir): 63 | checkpoints, steps = get_latest_checkpoints(args.continue_from_dir, args.continue_from_step) 64 | print(f'loading checkpoints: {checkpoints}') 65 | trainer.load(*checkpoints, steps=steps+1) 66 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import setup, find_packages 4 | 5 | # Load the version from file 6 | __version__ = Path("VERSION").read_text().strip() 7 | 8 | setup( 9 | name = 'open-musiclm', 10 | packages = find_packages(exclude=[]), 11 | include_package_data=True, 12 | version = __version__, 13 | license='MIT', 14 | description = 'Open MusicLM - Implementation of MusicLM, a text to music model published by Google Research, with a few modifications', 15 | author = 'Allen Zhang', 16 | long_description_content_type = 'text/markdown', 17 | url = 'https://github.com/zhvng/open-musiclm', 18 | keywords = [ 19 | 'artificial intelligence', 20 | 'deep learning', 21 | 'transformers', 22 | 'attention mechanism', 23 | 'audio generation', 24 | 'musiclm', 25 | ], 26 | install_requires=[ 27 | 'torch', 28 | 'torchvision', 29 | 'torchaudio', 30 | 'einops>=0.6.1', 31 | 'vector-quantize-pytorch>=1.2.2', 32 | 'librosa', 33 | 'torchlibrosa', 34 | 'ftfy', 35 | 'tqdm', 36 | 'transformers', 37 | 'encodec', 38 | 'gdown', 39 | 'accelerate>=0.17.0', 40 | 'beartype', 41 | 'joblib', 42 | 'h5py', 43 | 'scikit-learn', 44 | 'wget', 45 | ], 46 | classifiers=[ 47 | 'Development Status :: 4 - Beta', 48 | 'Intended Audience :: Developers', 49 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 50 | 'License :: OSI Approved :: MIT License', 51 | 'Programming Language :: Python :: 3.10', 52 | ], 53 | ) --------------------------------------------------------------------------------
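A typical end-to-end run, pieced together from the argument defaults in the scripts above (a sketch only; the <...> placeholders stand for the checkpoint files produced by the earlier steps and are not real paths in the repo):

    python scripts/train_clap_rvq.py --results_folder ./results/clap_rvq
    python scripts/train_hubert_kmeans.py --results_folder ./results/hubert_kmeans
    python scripts/preprocess_data.py --rvq_path <clap rvq checkpoint> --kmeans_path ./results/hubert_kmeans/kmeans.joblib
    python scripts/train_semantic_stage.py --rvq_path <clap rvq checkpoint> --kmeans_path ./results/hubert_kmeans/kmeans.joblib
    python scripts/train_coarse_stage.py --rvq_path <clap rvq checkpoint> --kmeans_path ./results/hubert_kmeans/kmeans.joblib
    python scripts/train_fine_stage.py --rvq_path <clap rvq checkpoint> --kmeans_path ./results/hubert_kmeans/kmeans.joblib
    python scripts/infer_top_match.py "a text prompt" --semantic_path <semantic checkpoint> --coarse_path <coarse checkpoint> --fine_path <fine checkpoint> --rvq_path <clap rvq checkpoint> --kmeans_path ./results/hubert_kmeans/kmeans.joblib

preprocess_data.py is only needed when the training config enables use_preprocessed_data for a stage; in that case the stage trainer does not load CLAP or the HuBERT k-means model and instead reads the precomputed clap/semantic/coarse token sequences from the preprocessed dataset.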