├── .gitignore ├── LICENSE ├── README.md ├── benchmarks.z01 ├── benchmarks.zip ├── code ├── README.md ├── config.py ├── extract_clamp2.py ├── extract_m3.py ├── logs_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.txt ├── logs_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.txt ├── train_clamp2.py ├── train_m3.py └── utils.py ├── environment.yml ├── music_classification ├── README.md ├── config.py ├── inference_cls.py ├── train_cls.py └── utils.py ├── overview.jpg ├── process_data ├── README.md ├── batch_abc2xml.py ├── batch_interleaved_abc.py ├── batch_midi2mtf.py ├── batch_mtf2midi.py ├── batch_xml2abc.py ├── gpt4_summarize.py └── utils │ ├── abc2xml.py │ ├── pyparsing.py │ └── xml2abc.py ├── requirements.txt └── semantic_search ├── README.md ├── clamp2_score.py ├── semantic_search.py └── semantic_search_metrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sander Wood 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CLaMP 2: Multimodal Music Information Retrieval Across 101 Languages Using Large Language Models 2 | 3 | **CLaMP 3 is out!** This repo is no longer maintained—check it out here: [CLaMP 3](https://github.com/sanderwood/clamp3). 4 | 5 | ![CLaMP 2 Overview](overview.jpg) 6 | 7 | ## Overview 8 | CLaMP 2 is a music information retrieval model compatible with 101 languages, designed to support both ABC notation (a text-based musical notation format) and MIDI (Musical Instrument Digital Interface). 
This repository provides a comprehensive suite of scripts for training models, extracting features, converting various musical data formats, generating multilingual summaries of music metadata using GPT-4, and performing music classification and semantic search tasks. By leveraging the multilingual capabilities of GPT-4, CLaMP 2 aims to enhance the accuracy and inclusivity of music retrieval across diverse linguistic and musical modalities. 9 | 10 | ### Links 11 | - [CLaMP 2 Code](https://github.com/sanderwood/clamp2) 12 | - [CLaMP 2 Paper](https://arxiv.org/pdf/2410.13267) 13 | - [CLaMP 2 Model Weights](https://huggingface.co/sander-wood/clamp2/blob/main/weights_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.pth) 14 | - [M3 Model Weights](https://huggingface.co/sander-wood/clamp2/blob/main/weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth) 15 | 16 | Note: The model weights for both CLaMP 2 and M3 should be placed under the `code/` folder to ensure proper loading. Make sure the config hyperparameters are correctly set. 17 | 18 | ## Repository Structure 19 | The repository is organized into the following main directories: 20 | 21 | - **`code/`**: Includes scripts for training the CLaMP 2 and M3 models and extracting features from music and text data. You can modify hyperparameters and file paths in the configuration files to suit your training needs. 22 | 23 | - **`music_classification/`**: Contains scripts for performing classification tasks via linear probe using the extracted features. This directory includes utilities for training linear classification models and making predictions on new feature data. 24 | 25 | - **`process_data/`**: Provides tools to convert between various musical data formats (ABC notation, MusicXML, MIDI, and MTF) and to summarize music metadata with GPT-4. Before using CLaMP 2 or M3, you should use these scripts to convert your files into interleaved ABC notation or MTF compatible with these models. 26 | 27 | - **`semantic_search/`**: Provides scripts for evaluating model performance, conducting semantic searches, and calculating similarity metrics based on extracted feature vectors. 28 | 29 | ## Getting Started 30 | ### Environment Setup 31 | To set up the environment for CLaMP 2, run the following commands: 32 | 33 | ```bash 34 | conda env create -f environment.yml 35 | conda activate clamp2 36 | ``` 37 | 38 | ### Data Preparation 39 | 1. **Convert Files**: Navigate to the `process_data/` folder and convert your music files to a compatible format (interleaved ABC notation or MTF) suitable for CLaMP 2 and M3. 40 | - Use the conversion scripts in this folder for tasks like converting MusicXML to ABC and MIDI to MTF. 41 | 42 | - After collecting MusicXML (sheet music) or MIDI (performance data), perform the following operations to convert them into interleaved ABC notation or MTF respectively for model training: 43 | 1. **Obtain Interleaved ABC Notation**: 44 | - Convert MusicXML files to ABC using `batch_xml2abc.py`. 45 | - Process the ABC files into interleaved notation using `batch_interleaved_abc.py`. 46 | 2. **Obtain MTF**: 47 | - Convert MIDI files to MTF format using `batch_midi2mtf.py`. 48 | 3. **Convert Interleaved ABC Back to XML (Optional)**: 49 | - Use `batch_abc2xml.py` to convert interleaved ABC files back to MusicXML. 50 | 4.
**Convert MTF Back to MIDI (Optional)**: 51 | - Use `batch_mtf2midi.py` to convert MTF files back to MIDI format. 52 | 53 | 2. **Generate Multilingual Metadata Summaries**: After converting the music files, the next step is to generate multilingual summaries of the music metadata. This is done using the `gpt4_summarize.py` script, which leverages the GPT-4 API to create structured summaries in both English and a randomly selected non-English language. 54 | 55 | **Input Example**: The input to the summarization script consists of a JSON file representing the music metadata. Here’s an example of a music entry in JSON format: 56 | 57 | ```json 58 | { 59 | "title": "Hard Times Come Again No More", 60 | "composer": "Stephen Foster", 61 | "genres": ["Children's Music", "Folk"], 62 | "description": "\"Hard Times Come Again No More\" (sometimes referred to as \"Hard Times\") is an American parlor song written by Stephen Foster, reflecting themes of sorrow and hope.", 63 | "lyrics": "Let us pause in life's pleasures and count its many tears,\nWhile we all sup sorrow with the poor;\nThere's a song that will linger forever in our ears;\nOh! Hard times come again no more.\n\nChorus:\n'Tis the song, the sigh of the weary,\nHard Times, hard times, come again no more.\nMany days you have lingered around my cabin door;\nOh! Hard times come again no more.\n\nWhile we seek mirth and beauty and music light and gay,\nThere are frail forms fainting at the door;\nThough their voices are silent, their pleading looks will say\nOh! Hard times come again no more.\nChorus\n\nThere's a pale weeping maiden who toils her life away,\nWith a worn heart whose better days are o'er:\nThough her voice would be merry, 'tis sighing all the day,\nOh! Hard times come again no more.\nChorus\n\n'Tis a sigh that is wafted across the troubled wave,\n'Tis a wail that is heard upon the shore\n'Tis a dirge that is murmured around the lowly grave\nOh! Hard times come again no more.\nChorus", 64 | "tags": ["folk", "traditional", "bluegrass", "nostalgic", "heartfelt", "acoustic", "melancholic", "storytelling", "American roots", "resilience"], 65 | "ensembles": ["Folk Ensemble"], 66 | "instruments": ["Vocal", "Violin", "Tin whistle", "Guitar", "Banjo", "Tambourine"], 67 | "filepaths": [ 68 | "abc/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.abc", 69 | "mtf/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.mtf" 70 | ] 71 | } 72 | ``` 73 | The filepaths field contains relative paths starting from the shortest common root directory (e.g., abc/ or mtf/). This ensures that only the minimal shared part of the path is included, and each file is represented with a concise relative path from this root. 74 | 75 | **Output Example**: The output will be a JSON file containing the structured summary in both English and a selected non-English language. Here’s an example of the expected output: 76 | 77 | ```json 78 | { 79 | "title": "Hard Times Come Again No More", 80 | "composer": "Stephen Foster", 81 | "genres": ["Children's Music", "Folk"], 82 | "description": "\"Hard Times Come Again No More\" (sometimes referred to as \"Hard Times\") is an American parlor song written by Stephen Foster, reflecting themes of sorrow and hope.", 83 | "lyrics": "Let us pause in life's pleasures and count its many tears,\nWhile we all sup sorrow with the poor;\nThere's a song that will linger forever in our ears;\nOh! 
Hard times come again no more.\n\nChorus:\n'Tis the song, the sigh of the weary,\nHard Times, hard times, come again no more.\nMany days you have lingered around my cabin door;\nOh! Hard times come again no more.\n\nWhile we seek mirth and beauty and music light and gay,\nThere are frail forms fainting at the door;\nThough their voices are silent, their pleading looks will say\nOh! Hard times come again no more.\nChorus\n\nThere's a pale weeping maiden who toils her life away,\nWith a worn heart whose better days are o'er:\nThough her voice would be merry, 'tis sighing all the day,\nOh! Hard times come again no more.\nChorus\n\n'Tis a sigh that is wafted across the troubled wave,\n'Tis a wail that is heard upon the shore\n'Tis a dirge that is murmured around the lowly grave\nOh! Hard times come again no more.\nChorus", 84 | "tags": ["folk", "traditional", "bluegrass", "nostalgic", "heartfelt", "acoustic", "melancholic", "storytelling", "American roots", "resilience"], 85 | "ensembles": ["Folk Ensemble"], 86 | "instruments": ["Vocal", "Violin", "Tin whistle", "Guitar", "Banjo", "Tambourine"], 87 | "summary_en": "\"Hard Times Come Again No More,\" composed by Stephen Foster, is a poignant American parlor song that explores themes of sorrow and hope. The lyrics reflect on the contrast between life's pleasures and its hardships, inviting listeners to acknowledge both joy and suffering. With a heartfelt chorus that repeats the line \"Hard times come again no more,\" the song resonates with nostalgia and resilience. It is often performed by folk ensembles and features a variety of instruments, including vocals, violin, guitar, and banjo, encapsulating the spirit of American roots music.", 88 | "summary_nen": { 89 | "language": "Chinese (Simplified)", 90 | "summary": "《艰难时光再无来临》是斯蒂芬·福斯特创作的一首感人至深的美国小歌厅歌曲,探讨了悲伤与希望的主题。歌词展现了生活的乐趣与艰辛之间的对比,邀请听众去感受快乐与痛苦的交织。歌曲中那句反复吟唱的“艰难时光再无来临”深情地表达了怀旧与坚韧。它常常由民谣乐队演奏,伴随着人声、小提琴、吉他和班卓琴等多种乐器,生动地展现了美国根源音乐的独特魅力。" 91 | }, 92 | "filepaths": [ 93 | "abc/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.abc", 94 | "mtf/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.mtf" 95 | ] 96 | } 97 | ``` 98 | 99 | After generating the individual JSON files: 100 | 101 | 1. Merge all JSON files into a single JSONL file. 102 | 103 | 2. Place the merged JSONL file and the shortest common root directories (e.g., abc/ and/or mtf/) in the same folder, structured like this: 104 | 105 | ``` 106 | /your-target-folder/ 107 | ├── abc/ 108 | ├── mtf/ 109 | ├── merged_output.jsonl 110 | ``` 111 | 112 | ### Training and Feature Extraction 113 | 3. **Training Models**: If you want to train CLaMP 2 or M3 models, check the scripts in the `code/` folder. 114 | - Modify the `config.py` files to set your training hyperparameters and paths. 115 | 116 | 4. **Extracting Features**: After training, or if you have pre-trained models, you can extract features from your data using the respective scripts in the `code/` folder. 117 | 118 | ### Classification and Retrieval 119 | 5. **Classification**: If you need to classify the extracted features, navigate to the `music_classification/` directory. 120 | - Here, you'll find scripts to train linear classification models and perform inference on new data. 121 | 122 | 6. **Semantic Search**: To perform semantic searches using the extracted features, refer to the scripts in the `semantic_search/` folder.
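For reference, here is a minimal sketch of how ranking by cosine similarity over the extracted features could look. It is not the repository's `semantic_search.py`; the `rank_by_similarity` helper and the file paths are hypothetical, and it assumes the features were extracted with `--normalize`:

```python
# Hypothetical helper (not part of this repo): rank extracted .npy features
# against a query feature by cosine similarity.
import os
import numpy as np

def rank_by_similarity(query_npy, feature_dir, top_k=5):
    query = np.load(query_npy).squeeze()        # query feature, shape (hidden_size,)
    query = query / np.linalg.norm(query)
    scores = []
    for root, _, names in os.walk(feature_dir):
        for name in names:
            if not name.endswith(".npy"):
                continue
            path = os.path.join(root, name)
            feat = np.load(path).squeeze()
            feat = feat / np.linalg.norm(feat)  # cosine similarity = dot product of unit vectors
            scores.append((float(np.dot(query, feat)), path))
    return sorted(scores, reverse=True)[:top_k]

# e.g. rank_by_similarity("query_features/query.npy", "music_features/")  # placeholder paths
```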
123 |  124 | ## Benchmarks 125 | Benchmark datasets related to the experiments conducted with CLaMP 2 and M3, including data used for classification and semantic search tasks, are available in the `benchmarks.zip` file. Note that the `benchmarks.z01` file is required for proper extraction of the contents from `benchmarks.zip`. 126 | 127 | ## Citation 128 | 129 | If you use CLaMP 2 or M3 in your research, please cite the following paper: 130 | 131 | ```bibtex 132 | @misc{wu2024clamp2multimodalmusic, 133 | title={CLaMP 2: Multimodal Music Information Retrieval Across 101 Languages Using Large Language Models}, 134 | author={Shangda Wu and Yashan Wang and Ruibin Yuan and Zhancheng Guo and Xu Tan and Ge Zhang and Monan Zhou and Jing Chen and Xuefeng Mu and Yuejie Gao and Yuanliang Dong and Jiafeng Liu and Xiaobing Li and Feng Yu and Maosong Sun}, 135 | year={2024}, 136 | eprint={2410.13267}, 137 | archivePrefix={arXiv}, 138 | primaryClass={cs.SD}, 139 | url={https://arxiv.org/abs/2410.13267}, 140 | } 141 | ``` -------------------------------------------------------------------------------- /benchmarks.z01: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanderwood/clamp2/8404488ef36a61735c0295a5afcd6d3a3d74bbcd/benchmarks.z01 -------------------------------------------------------------------------------- /benchmarks.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanderwood/clamp2/8404488ef36a61735c0295a5afcd6d3a3d74bbcd/benchmarks.zip -------------------------------------------------------------------------------- /code/README.md: -------------------------------------------------------------------------------- 1 | # CLaMP 2 Codebase 2 | 3 | ## Overview 4 | CLaMP 2 is a state-of-the-art multimodal music information retrieval system designed to work with 101 languages. This codebase includes scripts for training models, extracting features, and utility functions for processing music and text data. Below is a description of the scripts contained in the `code/` folder. 5 | 6 | ## Repository Structure 7 | The `code/` folder contains the following scripts: 8 | 9 | ### 1. `config.py` 10 | This script contains the training hyperparameters and file paths used in the `train_clamp2.py` and `train_m3.py` scripts. You can modify parameters such as learning rates, batch sizes, and file locations for training data. 11 | 12 | ### 2. `extract_clamp2.py` 13 | This script utilizes the pre-trained CLaMP 2 model to extract representations of text (.txt) or music (.abc or .mtf) from a specified input folder and save the features to a target output folder in `.npy` format. The extracted features can be normalized for semantic search or retain temporal information for classification tasks. 14 | 15 | **Usage:** 16 | 17 | It supports multi-GPU processing with `accelerate launch` for efficient extraction across multiple GPUs. 18 | 19 | ```bash 20 | accelerate launch extract_clamp2.py <input_dir> <output_dir> [--normalize] 21 | ``` 22 | - `input_dir`: Directory containing input data files. 23 | - `output_dir`: Directory to save the output features. 24 | - `--normalize`: (Optional) Normalize the extracted features. Normalization is not required for music classification tasks, but it is required for semantic search tasks. 25 | 26 | ### 3.
`extract_m3.py` 27 | This script employs the pre-trained M3 model to extract representations of interleaved ABC notation and MIDI Text Format (MTF) files from the specified input folder, saving the features to the target folder as `.npy` files. 28 | 29 | **Usage:** 30 | 31 | It supports multi-GPU processing with `accelerate launch` for efficient extraction across multiple GPUs. 32 | 33 | ```bash 34 | accelerate launch extract_m3.py <input_dir> <output_dir> 35 | ``` 36 | - `input_dir`: Directory with input files (in .abc or .mtf format). 37 | - `output_dir`: Directory to save extracted features. 38 | 39 | ### 4. `train_clamp2.py` 40 | This script manages the training process for the CLaMP 2 model. It prepares training data from a path specified in the `TRAIN_JSONL` variable, which is defined in the `config.py` file. If `EVAL_JSONL` is provided in the configuration, it will be used for validation. By default, 1% of the training data is reserved for validation. 41 | 42 | CLaMP 2 utilizes the multilingual text encoder `FacebookAI/xlm-roberta-base` for processing text data. Additionally, it employs the M3 model, pre-trained on both ABC and MIDI data, as the multimodal music encoder. If the pre-trained weights for M3 are available and the configuration variable `CLAMP2_LOAD_M3` is set to True, the training script will automatically load the M3 weights. 43 | 44 | **Training Command:** 45 | To start the training process, use the following command: 46 | 47 | ```bash 48 | python -m torch.distributed.launch --nproc_per_node=<number_of_gpus> --use_env train_clamp2.py 49 | ``` 50 | 51 | Replace `<number_of_gpus>` with the number of GPUs you want to use for training. 52 | 53 | **Input Data Format** 54 | The input training data should be in JSONL format, where each line contains a single JSON object with the following structure. Fields that do not apply should be set to `None`: 55 | 56 | ```json 57 | { 58 | "title": "Song Title", 59 | "composer": "Composer Name", 60 | "genres": ["Genre1", "Genre2"], 61 | "description": "Song description.", 62 | "lyrics": "Song lyrics.", 63 | "tags": ["tag1", "tag2"], 64 | "ensembles": ["Ensemble Name"], 65 | "instruments": ["Instrument1", "Instrument2"], 66 | "summary_en": "English summary.", 67 | "summary_nen": { 68 | "language": "Language Name", 69 | "summary": "Summary in specified language." 70 | }, 71 | "filepaths": [ 72 | "path/to/abc/file.abc", 73 | "path/to/mtf/file.mtf" 74 | ] 75 | } 76 | ``` 77 | 78 | For obtaining the English and non-English summaries generated by GPT-4, refer to the `process_data/gpt4_summarize.py` script. 79 | 80 | ### 5. `train_m3.py` 81 | This script is dedicated to training the M3 model using interleaved ABC and MTF files. The directories for training and optional evaluation data should be specified in the `TRAIN_FOLDERS` and `EVAL_FOLDERS` variables, respectively. 82 | 83 | **Training Command:** 84 | To start the training process for the M3 model, use the following command: 85 | 86 | ```bash 87 | python -m torch.distributed.launch --nproc_per_node=<number_of_gpus> --use_env train_m3.py 88 | ``` 89 | 90 | Replace `<number_of_gpus>` with the number of GPUs you want to use for training. 91 | 92 | **Data Preparation:** 93 | The data should be structured in interleaved ABC (.abc) and MTF (.mtf) formats. Please refer to the `process_data/` folder for instructions on how to prepare these formats. 94 | 95 | ### 6. `utils.py` 96 | This utility script contains various classes for model definitions and functions used for training.
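As a quick sanity check after extraction, a saved feature file can be inspected with NumPy. This is a minimal sketch, not a script from this repository; the path below is a placeholder:

```python
# Minimal sketch: load one feature file written by extract_clamp2.py or extract_m3.py.
import numpy as np

features = np.load("path/to/output_dir/example.npy")  # placeholder path
# With --normalize, extract_clamp2.py saves a single global vector of shape (1, hidden_size);
# without it (and for extract_m3.py), the array keeps the temporal dimension (1, seq_len, hidden_size).
print(features.shape, features.dtype)
```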
97 | -------------------------------------------------------------------------------- /code/config.py: -------------------------------------------------------------------------------- 1 | EVAL_SPLIT = 0.01 # Fraction of training data used for evaluation 2 | WANDB_KEY = "" # Set M3/CLaMP2_WANDB_LOG=False if no API key for Weights and Biases logging 3 | 4 | # -------------------- Configuration for M3 Training -------------------- 5 | TRAIN_FOLDERS = [ 6 | "" # Directory containing training data 7 | ] 8 | 9 | EVAL_FOLDERS = [ 10 | "" # (Optional) Directory containing evaluation data 11 | ] 12 | 13 | PATCH_SIZE = 64 # Size of each patch 14 | PATCH_LENGTH = 512 # Length of the patches 15 | PATCH_NUM_LAYERS = 12 # Number of layers in the encoder 16 | TOKEN_NUM_LAYERS = 3 # Number of layers in the decoder 17 | M3_HIDDEN_SIZE = 768 # Size of the hidden layer 18 | 19 | M3_NUM_EPOCH = 100 # Maximum number of epochs for training 20 | M3_LEARNING_RATE = 1e-4 # Learning rate for the optimizer 21 | M3_BATCH_SIZE = 16 # Batch size per GPU (single card) during training 22 | M3_MASK_RATIO = 0.45 # Ratio of masked elements during training 23 | M3_DETERMINISTIC = True # Ensures deterministic results with random seeds 24 | M3_WANDB_LOG = True # Enable logging to Weights and Biases 25 | M3_LOAD_CKPT = True # Load model weights from a checkpoint if available 26 | 27 | M3_WEIGHTS_PATH = ( 28 | "weights_m3_p_size_" + str(PATCH_SIZE) + 29 | "_p_length_" + str(PATCH_LENGTH) + 30 | "_t_layers_" + str(TOKEN_NUM_LAYERS) + 31 | "_p_layers_" + str(PATCH_NUM_LAYERS) + 32 | "_h_size_" + str(M3_HIDDEN_SIZE) + 33 | "_lr_" + str(M3_LEARNING_RATE) + 34 | "_batch_" + str(M3_BATCH_SIZE) + 35 | "_mask_" + str(M3_MASK_RATIO) + ".pth" 36 | ) # Path to store the model weights 37 | M3_LOGS_PATH = M3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs 38 | 39 | # -------------------- Configuration for CLaMP2 Training ---------------- 40 | TRAIN_JSONL = "" # Path to the JSONL file with training data 41 | EVAL_JSONL = "" # (Optional) Path to the JSONL file with evaluation data 42 | 43 | CLAMP2_HIDDEN_SIZE = 768 # Size of the hidden layer 44 | TEXT_MODEL_NAME = "FacebookAI/xlm-roberta-base" # Name of the pre-trained text model 45 | 46 | CLAMP2_NUM_EPOCH = 100 # Maximum number of epochs for training 47 | CLAMP2_LEARNING_RATE = 5e-5 # Learning rate for the optimizer 48 | CLAMP2_BATCH_SIZE = 128 # Batch size per GPU (single card) during training 49 | LOGIT_SCALE = 1 # Scaling factor for contrastive loss 50 | MAX_TEXT_LENGTH = 128 # Maximum allowed length for text input 51 | TEXT_DROPOUT = True # Whether to apply dropout during text processing 52 | CLAMP2_DETERMINISTIC = True # Ensures deterministic results with random seeds 53 | CLAMP2_LOAD_M3 = True # Load weights from the M3 model 54 | CLAMP2_WANDB_LOG = True # Enable logging to Weights and Biases 55 | CLAMP2_LOAD_CKPT = True # Load weights from a checkpoint if available 56 | 57 | CLAMP2_WEIGHTS_PATH = ( 58 | "weights_clamp2_h_size_" + str(CLAMP2_HIDDEN_SIZE) + 59 | "_lr_" + str(CLAMP2_LEARNING_RATE) + 60 | "_batch_" + str(CLAMP2_BATCH_SIZE) + 61 | "_scale_" + str(LOGIT_SCALE) + 62 | "_t_length_" + str(MAX_TEXT_LENGTH) + 63 | "_t_model_" + TEXT_MODEL_NAME.replace("/", "_") + 64 | "_t_dropout_" + str(TEXT_DROPOUT) + 65 | "_m3_" + str(CLAMP2_LOAD_M3) + ".pth" 66 | ) # Path to store CLaMP2 model weights 67 | CLAMP2_LOGS_PATH = CLAMP2_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs 68 | 
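# Note: with the default values above, these path templates resolve to the released
# checkpoint filenames, i.e. M3_WEIGHTS_PATH evaluates to
# "weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth"
# and CLAMP2_WEIGHTS_PATH to
# "weights_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.pth",
# so the downloaded weights can be placed under code/ without renaming as long as
# these hyperparameters are left unchanged.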
-------------------------------------------------------------------------------- /code/extract_clamp2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from config import * 8 | from utils import * 9 | from samplings import * 10 | from accelerate import Accelerator 11 | from transformers import BertConfig, AutoTokenizer 12 | import argparse 13 | 14 | # Parse command-line arguments 15 | parser = argparse.ArgumentParser(description="Feature extraction for CLaMP2.") 16 | parser.add_argument("input_dir", type=str, help="Directory containing input data files.") 17 | parser.add_argument("output_dir", type=str, help="Directory to save the output features.") 18 | parser.add_argument("--normalize", action="store_true", help="Normalize the extracted features.") 19 | 20 | args = parser.parse_args() 21 | 22 | # Retrieve arguments 23 | input_dir = args.input_dir 24 | output_dir = args.output_dir 25 | normalize = args.normalize 26 | 27 | os.makedirs("logs", exist_ok=True) 28 | for file in ["logs/files_extract_clamp2.json", 29 | "logs/files_shuffle_extract_clamp2.json", 30 | "logs/log_extract_clamp2.txt", 31 | "logs/pass_extract_clamp2.txt", 32 | "logs/skip_extract_clamp2.txt"]: 33 | if os.path.exists(file): 34 | os.remove(file) 35 | 36 | files = [] 37 | for root, dirs, fs in os.walk(input_dir): 38 | for f in fs: 39 | if f.endswith(".txt") or f.endswith(".abc") or f.endswith(".mtf"): 40 | files.append(os.path.join(root, f)) 41 | print(f"Found {len(files)} files in total") 42 | with open("logs/files_extract_clamp2.json", "w", encoding="utf-8") as f: 43 | json.dump(files, f) 44 | random.shuffle(files) 45 | with open("logs/files_shuffle_extract_clamp2.json", "w", encoding="utf-8") as f: 46 | json.dump(files, f) 47 | 48 | accelerator = Accelerator() 49 | device = accelerator.device 50 | print("Using device:", device) 51 | with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f: 52 | f.write("Using device: " + str(device) + "\n") 53 | 54 | m3_config = BertConfig(vocab_size=1, 55 | hidden_size=M3_HIDDEN_SIZE, 56 | num_hidden_layers=PATCH_NUM_LAYERS, 57 | num_attention_heads=M3_HIDDEN_SIZE//64, 58 | intermediate_size=M3_HIDDEN_SIZE*4, 59 | max_position_embeddings=PATCH_LENGTH) 60 | model = CLaMP2Model(m3_config, 61 | text_model_name=TEXT_MODEL_NAME, 62 | hidden_size=CLAMP2_HIDDEN_SIZE, 63 | load_m3=CLAMP2_LOAD_M3) 64 | model = model.to(device) 65 | tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME) 66 | patchilizer = M3Patchilizer() 67 | 68 | # print parameter number 69 | print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad))) 70 | 71 | model.eval() 72 | checkpoint = torch.load(CLAMP2_WEIGHTS_PATH, map_location='cpu', weights_only=True) 73 | print(f"Successfully Loaded CLaMP 2 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}") 74 | model.load_state_dict(checkpoint['model']) 75 | 76 | def extract_feature(filename, get_normalized=normalize): 77 | with open(filename, "r", encoding="utf-8") as f: 78 | item = f.read() 79 | 80 | if filename.endswith(".txt"): 81 | item = list(set(item.split("\n"))) 82 | item = "\n".join(item) 83 | item = item.split("\n") 84 | item = [c for c in item if len(c) > 0] 85 | item = tokenizer.sep_token.join(item) 86 | input_data = tokenizer(item, return_tensors="pt") 87 | input_data = input_data['input_ids'].squeeze(0) 88 | max_input_length = 
MAX_TEXT_LENGTH 89 | else: 90 | input_data = patchilizer.encode(item, add_special_patches=True) 91 | input_data = torch.tensor(input_data) 92 | max_input_length = PATCH_LENGTH 93 | 94 | segment_list = [] 95 | for i in range(0, len(input_data), max_input_length): 96 | segment_list.append(input_data[i:i+max_input_length]) 97 | segment_list[-1] = input_data[-max_input_length:] 98 | 99 | last_hidden_states_list = [] 100 | 101 | for input_segment in segment_list: 102 | input_masks = torch.tensor([1]*input_segment.size(0)) 103 | if filename.endswith(".txt"): 104 | pad_indices = torch.ones(MAX_TEXT_LENGTH - input_segment.size(0)).long() * tokenizer.pad_token_id 105 | else: 106 | pad_indices = torch.ones((PATCH_LENGTH - input_segment.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id 107 | input_masks = torch.cat((input_masks, torch.zeros(max_input_length - input_segment.size(0))), 0) 108 | input_segment = torch.cat((input_segment, pad_indices), 0) 109 | 110 | if filename.endswith(".txt"): 111 | last_hidden_states = model.get_text_features(text_inputs=input_segment.unsqueeze(0).to(device), 112 | text_masks=input_masks.unsqueeze(0).to(device), 113 | get_normalized=get_normalized) 114 | else: 115 | last_hidden_states = model.get_music_features(music_inputs=input_segment.unsqueeze(0).to(device), 116 | music_masks=input_masks.unsqueeze(0).to(device), 117 | get_normalized=get_normalized) 118 | if not get_normalized: 119 | last_hidden_states = last_hidden_states[:, :input_masks.sum().long().item(), :] 120 | last_hidden_states_list.append(last_hidden_states) 121 | 122 | if not get_normalized: 123 | last_hidden_states_list = [last_hidden_states[0] for last_hidden_states in last_hidden_states_list] 124 | last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(input_data)%max_input_length):] 125 | last_hidden_states_list = torch.concat(last_hidden_states_list, 0) 126 | else: 127 | full_chunk_cnt = len(input_data) // max_input_length 128 | remain_chunk_len = len(input_data) % max_input_length 129 | if remain_chunk_len == 0: 130 | feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=device).view(-1, 1) 131 | else: 132 | feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=device).view(-1, 1) 133 | 134 | last_hidden_states_list = torch.concat(last_hidden_states_list, 0) 135 | last_hidden_states_list = last_hidden_states_list * feature_weights 136 | last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum() 137 | 138 | return last_hidden_states_list 139 | 140 | def process_directory(input_dir, output_dir, files): 141 | print(f"Found {len(files)} files in total") 142 | with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f: 143 | f.write("Found " + str(len(files)) + " files in total\n") 144 | 145 | # calculate the number of files to process per GPU 146 | num_files_per_gpu = len(files) // accelerator.num_processes 147 | 148 | # calculate the start and end index for the current GPU 149 | start_idx = accelerator.process_index * num_files_per_gpu 150 | end_idx = start_idx + num_files_per_gpu 151 | if accelerator.process_index == accelerator.num_processes - 1: 152 | end_idx = len(files) 153 | 154 | files_to_process = files[start_idx:end_idx] 155 | 156 | # process the files 157 | for file in tqdm(files_to_process): 158 | output_subdir = output_dir + os.path.dirname(file)[len(input_dir):] 159 | try: 160 | os.makedirs(output_subdir, exist_ok=True) 161 | except Exception as e: 162 | print(output_subdir + 
" can not be created\n" + str(e)) 163 | with open("logs/log_extract_clamp2.txt", "a") as f: 164 | f.write(output_subdir + " can not be created\n" + str(e) + "\n") 165 | 166 | output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy") 167 | 168 | if os.path.exists(output_file): 169 | print(f"Skipping {file}, output already exists") 170 | with open("logs/skip_extract_clamp2.txt", "a", encoding="utf-8") as f: 171 | f.write(file + "\n") 172 | continue 173 | 174 | try: 175 | with torch.no_grad(): 176 | features = extract_feature(file).unsqueeze(0) 177 | np.save(output_file, features.detach().cpu().numpy()) 178 | with open("logs/pass_extract_clamp2.txt", "a", encoding="utf-8") as f: 179 | f.write(file + "\n") 180 | except Exception as e: 181 | print(f"Failed to process {file}: {e}") 182 | with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f: 183 | f.write("Failed to process " + file + ": " + str(e) + "\n") 184 | 185 | with open("logs/files_shuffle_extract_clamp2.json", "r", encoding="utf-8") as f: 186 | files = json.load(f) 187 | 188 | # process the files 189 | process_directory(input_dir, output_dir, files) 190 | 191 | with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f: 192 | f.write("GPU ID: " + str(device) + "\n") 193 | -------------------------------------------------------------------------------- /code/extract_m3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from config import * 8 | from utils import * 9 | from samplings import * 10 | from accelerate import Accelerator 11 | from transformers import BertConfig, GPT2Config 12 | import argparse 13 | 14 | # Parse command-line arguments for input_dir and output_dir 15 | parser = argparse.ArgumentParser(description="Process files to extract features.") 16 | parser.add_argument("input_dir", type=str, help="Directory with input files") 17 | parser.add_argument("output_dir", type=str, help="Directory to save extracted features") 18 | args = parser.parse_args() 19 | 20 | # Use args for input and output directories 21 | input_dir = args.input_dir 22 | output_dir = args.output_dir 23 | 24 | # Create logs directory if it doesn't exist 25 | os.makedirs("logs", exist_ok=True) 26 | 27 | # Remove existing log files if present 28 | for file in [ 29 | "logs/files_extract_m3.json", 30 | "logs/files_shuffle_extract_m3.json", 31 | "logs/log_extract_m3.txt", 32 | "logs/pass_extract_m3.txt", 33 | "logs/skip_extract_m3.txt", 34 | ]: 35 | if os.path.exists(file): 36 | os.remove(file) 37 | 38 | # Collect input files 39 | files = [] 40 | for root, dirs, fs in os.walk(input_dir): 41 | for f in fs: 42 | if f.endswith(".abc") or f.endswith(".mtf"): 43 | files.append(os.path.join(root, f)) 44 | 45 | print(f"Found {len(files)} files in total") 46 | with open("logs/files_extract_m3.json", "w", encoding="utf-8") as f: 47 | json.dump(files, f) 48 | 49 | # Shuffle files and save the shuffled order 50 | random.shuffle(files) 51 | with open("logs/files_shuffle_extract_m3.json", "w", encoding="utf-8") as f: 52 | json.dump(files, f) 53 | 54 | # Initialize accelerator and device 55 | accelerator = Accelerator() 56 | device = accelerator.device 57 | print("Using device:", device) 58 | with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f: 59 | f.write("Using device: " + str(device) + "\n") 60 | 61 | # Model and configuration setup 62 | patchilizer 
= M3Patchilizer() 63 | encoder_config = BertConfig( 64 | vocab_size=1, 65 | hidden_size=M3_HIDDEN_SIZE, 66 | num_hidden_layers=PATCH_NUM_LAYERS, 67 | num_attention_heads=M3_HIDDEN_SIZE // 64, 68 | intermediate_size=M3_HIDDEN_SIZE * 4, 69 | max_position_embeddings=PATCH_LENGTH, 70 | ) 71 | decoder_config = GPT2Config( 72 | vocab_size=128, 73 | n_positions=PATCH_SIZE, 74 | n_embd=M3_HIDDEN_SIZE, 75 | n_layer=TOKEN_NUM_LAYERS, 76 | n_head=M3_HIDDEN_SIZE // 64, 77 | n_inner=M3_HIDDEN_SIZE * 4, 78 | ) 79 | model = M3Model(encoder_config, decoder_config).to(device) 80 | 81 | # Print parameter count 82 | print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad))) 83 | 84 | # Load model weights 85 | model.eval() 86 | checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True) 87 | print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}") 88 | model.load_state_dict(checkpoint['model']) 89 | 90 | def extract_feature(item): 91 | """Extracts features from input data.""" 92 | target_patches = patchilizer.encode(item, add_special_patches=True) 93 | target_patches_list = [target_patches[i:i + PATCH_LENGTH] for i in range(0, len(target_patches), PATCH_LENGTH)] 94 | target_patches_list[-1] = target_patches[-PATCH_LENGTH:] 95 | 96 | last_hidden_states_list = [] 97 | for input_patches in target_patches_list: 98 | input_masks = torch.tensor([1] * len(input_patches)) 99 | input_patches = torch.tensor(input_patches) 100 | last_hidden_states = model.encoder( 101 | input_patches.unsqueeze(0).to(device), input_masks.unsqueeze(0).to(device) 102 | )["last_hidden_state"][0] 103 | last_hidden_states_list.append(last_hidden_states) 104 | 105 | # Handle the last segment padding correctly 106 | last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(target_patches) % PATCH_LENGTH):] 107 | return torch.concat(last_hidden_states_list, 0) 108 | 109 | def process_directory(input_dir, output_dir, files): 110 | """Processes files in the input directory and saves features to the output directory.""" 111 | print(f"Found {len(files)} files in total") 112 | with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f: 113 | f.write("Found " + str(len(files)) + " files in total\n") 114 | 115 | # Distribute files across processes for parallel processing 116 | num_files_per_gpu = len(files) // accelerator.num_processes 117 | start_idx = accelerator.process_index * num_files_per_gpu 118 | end_idx = start_idx + num_files_per_gpu if accelerator.process_index < accelerator.num_processes - 1 else len(files) 119 | files_to_process = files[start_idx:end_idx] 120 | 121 | # Process each file 122 | for file in tqdm(files_to_process): 123 | output_subdir = output_dir + os.path.dirname(file)[len(input_dir):] 124 | try: 125 | os.makedirs(output_subdir, exist_ok=True) 126 | except Exception as e: 127 | print(f"{output_subdir} cannot be created\n{e}") 128 | with open("logs/log_extract_m3.txt", "a") as f: 129 | f.write(f"{output_subdir} cannot be created\n{e}\n") 130 | 131 | output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy") 132 | 133 | if os.path.exists(output_file): 134 | print(f"Skipping {file}, output already exists") 135 | with open("logs/skip_extract_m3.txt", "a", encoding="utf-8") as f: 136 | f.write(file + "\n") 137 | continue 138 | 139 | try: 140 | with open(file, "r", encoding="utf-8") as f: 141 | item = f.read() 142 | if not item.startswith("ticks_per_beat"): 143 | item = 
item.replace("L:1/8\n", "") 144 | with torch.no_grad(): 145 | features = extract_feature(item).unsqueeze(0) 146 | np.save(output_file, features.detach().cpu().numpy()) 147 | with open("logs/pass_extract_m3.txt", "a", encoding="utf-8") as f: 148 | f.write(file + "\n") 149 | except Exception as e: 150 | print(f"Failed to process {file}: {e}") 151 | with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f: 152 | f.write(f"Failed to process {file}: {e}\n") 153 | 154 | # Load shuffled files list and start processing 155 | with open("logs/files_shuffle_extract_m3.json", "r", encoding="utf-8") as f: 156 | files = json.load(f) 157 | 158 | # Process the directory 159 | process_directory(input_dir, output_dir, files) 160 | 161 | with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f: 162 | f.write("GPU ID: " + str(device) + "\n") 163 | -------------------------------------------------------------------------------- /code/logs_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.txt: -------------------------------------------------------------------------------- 1 | Epoch 1 2 | train_loss: 4.135209383230448 3 | eval_loss: 1.9609466393788655 4 | time: Sun Sep 15 16:53:25 2024 5 | 6 | Epoch 2 7 | train_loss: 2.9014642586344244 8 | eval_loss: 1.518966309229533 9 | time: Sun Sep 15 18:43:11 2024 10 | 11 | Epoch 3 12 | train_loss: 2.5554833216573805 13 | eval_loss: 1.2929850498835245 14 | time: Sun Sep 15 20:35:14 2024 15 | 16 | Epoch 4 17 | train_loss: 2.3372960954320985 18 | eval_loss: 1.1918509085973104 19 | time: Sun Sep 15 22:23:17 2024 20 | 21 | Epoch 5 22 | train_loss: 2.183457492896627 23 | eval_loss: 1.0725035190582275 24 | time: Mon Sep 16 00:11:33 2024 25 | 26 | Epoch 6 27 | train_loss: 2.0555087922796216 28 | eval_loss: 0.9798878033955892 29 | time: Mon Sep 16 01:59:41 2024 30 | 31 | Epoch 7 32 | train_loss: 1.9556777615308922 33 | eval_loss: 0.9242686867713928 34 | time: Mon Sep 16 03:47:45 2024 35 | 36 | Epoch 8 37 | train_loss: 1.8689270445336208 38 | eval_loss: 0.8476833502451578 39 | time: Mon Sep 16 05:35:55 2024 40 | 41 | Epoch 9 42 | train_loss: 1.7932923043361795 43 | eval_loss: 0.8043208519617716 44 | time: Mon Sep 16 07:24:40 2024 45 | 46 | Epoch 10 47 | train_loss: 1.726651672090251 48 | eval_loss: 0.7419429302215577 49 | time: Mon Sep 16 09:12:24 2024 50 | 51 | Epoch 11 52 | train_loss: 1.6657210920084013 53 | eval_loss: 0.7209376732508341 54 | time: Mon Sep 16 11:01:02 2024 55 | 56 | Epoch 12 57 | train_loss: 1.6131336856822078 58 | eval_loss: 0.6950737277666728 59 | time: Mon Sep 16 12:50:07 2024 60 | 61 | Epoch 13 62 | train_loss: 1.5601606700647368 63 | eval_loss: 0.652581254641215 64 | time: Mon Sep 16 14:40:19 2024 65 | 66 | Epoch 14 67 | train_loss: 1.5198849670231784 68 | eval_loss: 0.5868058403333029 69 | time: Mon Sep 16 16:30:09 2024 70 | 71 | Epoch 15 72 | train_loss: 1.4755114693644882 73 | eval_loss: 0.5867449243863424 74 | time: Mon Sep 16 18:20:11 2024 75 | 76 | Epoch 16 77 | train_loss: 1.4336211351749464 78 | eval_loss: 0.5479268968105316 79 | time: Mon Sep 16 20:10:42 2024 80 | 81 | Epoch 17 82 | train_loss: 1.4039166571722088 83 | eval_loss: 0.5280438164869944 84 | time: Mon Sep 16 22:01:01 2024 85 | 86 | Epoch 18 87 | train_loss: 1.365380759842085 88 | eval_loss: 0.5008598109086354 89 | time: Mon Sep 16 23:51:36 2024 90 | 91 | Epoch 19 92 | train_loss: 1.332988672848894 93 | eval_loss: 0.46479900081952413 94 | time: Tue Sep 17 01:41:49 2024 95 | 96 | Epoch 20 97 | train_loss: 
1.3014001791981087 98 | eval_loss: 0.45230263272921245 99 | time: Tue Sep 17 03:33:12 2024 100 | 101 | Epoch 21 102 | train_loss: 1.2688752540577755 103 | eval_loss: 0.4297992348670959 104 | time: Tue Sep 17 05:23:40 2024 105 | 106 | Epoch 22 107 | train_loss: 1.2425967695381415 108 | eval_loss: 0.4219102164109548 109 | time: Tue Sep 17 07:14:26 2024 110 | 111 | Epoch 23 112 | train_loss: 1.216824040410488 113 | eval_loss: 0.40282649795214337 114 | time: Tue Sep 17 09:05:53 2024 115 | 116 | Epoch 24 117 | train_loss: 1.1875996505747286 118 | eval_loss: 0.36659018794695536 119 | time: Tue Sep 17 10:56:46 2024 120 | 121 | Epoch 25 122 | train_loss: 1.1670776548255222 123 | eval_loss: 0.36906688412030536 124 | time: Tue Sep 17 12:47:39 2024 125 | 126 | Epoch 26 127 | train_loss: 1.1426405137974536 128 | eval_loss: 0.3478178918361664 129 | time: Tue Sep 17 14:38:34 2024 130 | 131 | Epoch 27 132 | train_loss: 1.1208335824466733 133 | eval_loss: 0.33407697081565857 134 | time: Tue Sep 17 16:28:45 2024 135 | 136 | Epoch 28 137 | train_loss: 1.0998876758880667 138 | eval_loss: 0.33792892495791116 139 | time: Tue Sep 17 18:19:59 2024 140 | 141 | Epoch 29 142 | train_loss: 1.0769698478377083 143 | eval_loss: 0.3026650925477346 144 | time: Tue Sep 17 20:11:16 2024 145 | 146 | Epoch 30 147 | train_loss: 1.0587592209657248 148 | eval_loss: 0.2914476583401362 149 | time: Tue Sep 17 22:03:30 2024 150 | 151 | Epoch 31 152 | train_loss: 1.0384011404245468 153 | eval_loss: 0.27578969597816466 154 | time: Tue Sep 17 23:55:15 2024 155 | 156 | Epoch 32 157 | train_loss: 1.0233595809527622 158 | eval_loss: 0.2651842157046 159 | time: Wed Sep 18 01:46:45 2024 160 | 161 | Epoch 33 162 | train_loss: 1.001824217418977 163 | eval_loss: 0.2630385269721349 164 | time: Wed Sep 18 03:39:08 2024 165 | 166 | Epoch 34 167 | train_loss: 0.9853754720520442 168 | eval_loss: 0.25253995358943937 169 | time: Wed Sep 18 05:30:33 2024 170 | 171 | Epoch 35 172 | train_loss: 0.9676362536067821 173 | eval_loss: 0.24096360007921855 174 | time: Wed Sep 18 07:22:09 2024 175 | 176 | Epoch 36 177 | train_loss: 0.9507065269691086 178 | eval_loss: 0.2413844664891561 179 | time: Wed Sep 18 09:12:59 2024 180 | 181 | Epoch 37 182 | train_loss: 0.9362979678186832 183 | eval_loss: 0.23412639300028484 184 | time: Wed Sep 18 11:04:09 2024 185 | 186 | Epoch 38 187 | train_loss: 0.9174621180856977 188 | eval_loss: 0.21386308073997498 189 | time: Wed Sep 18 12:54:52 2024 190 | 191 | Epoch 39 192 | train_loss: 0.9090870427650668 193 | eval_loss: 0.19962686796983084 194 | time: Wed Sep 18 14:45:52 2024 195 | 196 | Epoch 40 197 | train_loss: 0.8918763521271409 198 | eval_loss: 0.20026112000147503 199 | time: Wed Sep 18 16:36:37 2024 200 | 201 | Epoch 41 202 | train_loss: 0.8786202421428222 203 | eval_loss: 0.18366556564966838 204 | time: Wed Sep 18 18:27:31 2024 205 | 206 | Epoch 42 207 | train_loss: 0.8670675420604148 208 | eval_loss: 0.17908457616964976 209 | time: Wed Sep 18 20:18:16 2024 210 | 211 | Epoch 43 212 | train_loss: 0.8505593872931582 213 | eval_loss: 0.17053016225496928 214 | time: Wed Sep 18 22:10:39 2024 215 | 216 | Epoch 44 217 | train_loss: 0.8421949260766888 218 | eval_loss: 0.17344878117243448 219 | time: Thu Sep 19 00:02:24 2024 220 | 221 | Epoch 45 222 | train_loss: 0.8267569324702205 223 | eval_loss: 0.1591893643140793 224 | time: Thu Sep 19 01:53:48 2024 225 | 226 | Epoch 46 227 | train_loss: 0.8144617894466949 228 | eval_loss: 0.15313500861326854 229 | time: Thu Sep 19 03:44:58 2024 230 | 231 | Epoch 47 232 | train_loss: 
0.8041844731303666 233 | eval_loss: 0.14998503575722377 234 | time: Thu Sep 19 05:36:50 2024 235 | 236 | Epoch 48 237 | train_loss: 0.7938160687423412 238 | eval_loss: 0.1401842971642812 239 | time: Thu Sep 19 07:28:21 2024 240 | 241 | Epoch 49 242 | train_loss: 0.7808867423096515 243 | eval_loss: 0.1368137091398239 244 | time: Thu Sep 19 09:20:09 2024 245 | 246 | Epoch 50 247 | train_loss: 0.7702171771933628 248 | eval_loss: 0.13333487262328467 249 | time: Thu Sep 19 11:12:37 2024 250 | 251 | Epoch 51 252 | train_loss: 0.7604444062967384 253 | eval_loss: 0.13119754443566004 254 | time: Thu Sep 19 13:04:26 2024 255 | 256 | Epoch 52 257 | train_loss: 0.7496546459894258 258 | eval_loss: 0.1236343190073967 259 | time: Thu Sep 19 14:55:53 2024 260 | 261 | Epoch 53 262 | train_loss: 0.7406523988345118 263 | eval_loss: 0.12237562835216523 264 | time: Thu Sep 19 16:47:51 2024 265 | 266 | Epoch 54 267 | train_loss: 0.7331518270251398 268 | eval_loss: 0.11441469887892405 269 | time: Thu Sep 19 18:38:48 2024 270 | 271 | Epoch 55 272 | train_loss: 0.7238280263746373 273 | eval_loss: 0.10651812156041464 274 | time: Thu Sep 19 20:29:18 2024 275 | 276 | Epoch 56 277 | train_loss: 0.7141688125488486 278 | eval_loss: 0.10959143290917078 279 | time: Thu Sep 19 22:19:28 2024 280 | 281 | Epoch 57 282 | train_loss: 0.7053173944645842 283 | eval_loss: 0.10957898745934168 284 | time: Fri Sep 20 00:10:06 2024 285 | 286 | Epoch 58 287 | train_loss: 0.6992166797548109 288 | eval_loss: 0.09759224901596705 289 | time: Fri Sep 20 02:01:02 2024 290 | 291 | Epoch 59 292 | train_loss: 0.6855367768623795 293 | eval_loss: 0.10631066560745239 294 | time: Fri Sep 20 03:51:25 2024 295 | 296 | Epoch 60 297 | train_loss: 0.6812366953699432 298 | eval_loss: 0.08681503732999166 299 | time: Fri Sep 20 05:41:32 2024 300 | 301 | Epoch 61 302 | train_loss: 0.6744320154854127 303 | eval_loss: 0.08995070978999138 304 | time: Fri Sep 20 07:32:33 2024 305 | 306 | Epoch 62 307 | train_loss: 0.6627048003782218 308 | eval_loss: 0.08492780551314354 309 | time: Fri Sep 20 09:22:52 2024 310 | 311 | Epoch 63 312 | train_loss: 0.6554694614403961 313 | eval_loss: 0.09110054125388463 314 | time: Fri Sep 20 11:15:14 2024 315 | 316 | Epoch 64 317 | train_loss: 0.6519363358224428 318 | eval_loss: 0.08603844990332922 319 | time: Fri Sep 20 13:05:45 2024 320 | 321 | Epoch 65 322 | train_loss: 0.6432196787488694 323 | eval_loss: 0.07920929342508316 324 | time: Fri Sep 20 14:56:27 2024 325 | 326 | Epoch 66 327 | train_loss: 0.6355774498505016 328 | eval_loss: 0.08108622878789902 329 | time: Fri Sep 20 16:47:00 2024 330 | 331 | Epoch 67 332 | train_loss: 0.628098195042665 333 | eval_loss: 0.0835166151324908 334 | time: Fri Sep 20 18:37:19 2024 335 | 336 | Epoch 68 337 | train_loss: 0.6229319736150211 338 | eval_loss: 0.08126899500687917 339 | time: Fri Sep 20 20:27:49 2024 340 | 341 | Epoch 69 342 | train_loss: 0.6162204064685376 343 | eval_loss: 0.07405624414483707 344 | time: Fri Sep 20 22:18:28 2024 345 | 346 | Epoch 70 347 | train_loss: 0.6093617768645045 348 | eval_loss: 0.07916868552565574 349 | time: Sat Sep 21 00:10:02 2024 350 | 351 | Epoch 71 352 | train_loss: 0.603765148576412 353 | eval_loss: 0.07368899683157602 354 | time: Sat Sep 21 02:00:29 2024 355 | 356 | Epoch 72 357 | train_loss: 0.5988557130088281 358 | eval_loss: 0.06763924509286881 359 | time: Sat Sep 21 03:51:46 2024 360 | 361 | Epoch 73 362 | train_loss: 0.590835969827209 363 | eval_loss: 0.07139033873875936 364 | time: Sat Sep 21 05:43:51 2024 365 | 366 | Epoch 74 367 | 
train_loss: 0.5864904869113879 368 | eval_loss: 0.06859012718002001 369 | time: Sat Sep 21 07:34:23 2024 370 | 371 | Epoch 75 372 | train_loss: 0.5819329118342274 373 | eval_loss: 0.07611284777522087 374 | time: Sat Sep 21 09:25:24 2024 375 | 376 | Epoch 76 377 | train_loss: 0.5750655913014898 378 | eval_loss: 0.06813529431819916 379 | time: Sat Sep 21 11:16:26 2024 380 | 381 | Epoch 77 382 | train_loss: 0.5703848759963817 383 | eval_loss: 0.07192744488517443 384 | time: Sat Sep 21 13:07:32 2024 385 | 386 | Epoch 78 387 | train_loss: 0.5666614368024667 388 | eval_loss: 0.06931692684690158 389 | time: Sat Sep 21 14:59:16 2024 390 | 391 | Epoch 79 392 | train_loss: 0.5610024514409998 393 | eval_loss: 0.06487631574273109 394 | time: Sat Sep 21 16:50:56 2024 395 | 396 | Epoch 80 397 | train_loss: 0.5552226794301296 398 | eval_loss: 0.06034566586216291 399 | time: Sat Sep 21 18:43:49 2024 400 | 401 | Epoch 81 402 | train_loss: 0.5512203840912394 403 | eval_loss: 0.05962909683585167 404 | time: Sat Sep 21 20:36:01 2024 405 | 406 | Epoch 82 407 | train_loss: 0.5477618443893468 408 | eval_loss: 0.05546447386344274 409 | time: Sat Sep 21 22:28:13 2024 410 | 411 | Epoch 83 412 | train_loss: 0.5428704522615506 413 | eval_loss: 0.05013169844945272 414 | time: Sun Sep 22 00:21:20 2024 415 | 416 | Epoch 84 417 | train_loss: 0.5396500316264258 418 | eval_loss: 0.062498694161574046 419 | time: Sun Sep 22 02:13:07 2024 420 | 421 | Epoch 85 422 | train_loss: 0.5349479554715307 423 | eval_loss: 0.06073434228698413 424 | time: Sun Sep 22 04:05:17 2024 425 | 426 | Epoch 86 427 | train_loss: 0.5292192482811466 428 | eval_loss: 0.05734321524699529 429 | time: Sun Sep 22 05:57:05 2024 430 | 431 | Epoch 87 432 | train_loss: 0.5249555090607058 433 | eval_loss: 0.05274935985604922 434 | time: Sun Sep 22 07:48:52 2024 435 | 436 | Epoch 88 437 | train_loss: 0.523276918144503 438 | eval_loss: 0.05601314604282379 439 | time: Sun Sep 22 09:41:05 2024 440 | 441 | Epoch 89 442 | train_loss: 0.5179934711230115 443 | eval_loss: 0.057493301729361214 444 | time: Sun Sep 22 11:33:47 2024 445 | 446 | Epoch 90 447 | train_loss: 0.5129834874146376 448 | eval_loss: 0.05289425750573476 449 | time: Sun Sep 22 13:25:54 2024 450 | 451 | Epoch 91 452 | train_loss: 0.5104886514866054 453 | eval_loss: 0.0586332509915034 454 | time: Sun Sep 22 15:18:13 2024 455 | 456 | Epoch 92 457 | train_loss: 0.5067275374282622 458 | eval_loss: 0.0489634457975626 459 | time: Sun Sep 22 17:10:39 2024 460 | 461 | Epoch 93 462 | train_loss: 0.5038576471461468 463 | eval_loss: 0.05257208868861198 464 | time: Sun Sep 22 19:04:46 2024 465 | 466 | Epoch 94 467 | train_loss: 0.5013840998762528 468 | eval_loss: 0.05249967947602272 469 | time: Sun Sep 22 20:57:55 2024 470 | 471 | Epoch 95 472 | train_loss: 0.4949465335763684 473 | eval_loss: 0.048154672731955846 474 | time: Sun Sep 22 22:50:30 2024 475 | 476 | Epoch 96 477 | train_loss: 0.4925781255166608 478 | eval_loss: 0.052830965568621956 479 | time: Mon Sep 23 00:43:13 2024 480 | 481 | Epoch 97 482 | train_loss: 0.4875780233282 483 | eval_loss: 0.04684837857882182 484 | time: Mon Sep 23 02:35:38 2024 485 | 486 | Epoch 98 487 | train_loss: 0.4858591078021573 488 | eval_loss: 0.04507673804958661 489 | time: Mon Sep 23 04:28:25 2024 490 | 491 | Epoch 99 492 | train_loss: 0.4804891498405977 493 | eval_loss: 0.048148307204246524 494 | time: Mon Sep 23 06:21:11 2024 495 | 496 | Epoch 100 497 | train_loss: 0.4782898508661265 498 | eval_loss: 0.044317328557372096 499 | time: Mon Sep 23 08:13:38 2024 500 | 501 | 
-------------------------------------------------------------------------------- /code/logs_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.txt: -------------------------------------------------------------------------------- 1 | Epoch 1 2 | train_loss: 0.3055062843047765 3 | eval_loss: 0.03727418192925584 4 | time: Wed Aug 7 08:54:27 2024 5 | 6 | Epoch 2 7 | train_loss: 0.1718286834194018 8 | eval_loss: 0.020085958587123625 9 | time: Wed Aug 7 14:38:40 2024 10 | 11 | Epoch 3 12 | train_loss: 0.1476379437283353 13 | eval_loss: 0.013794219702922342 14 | time: Wed Aug 7 20:24:50 2024 15 | 16 | Epoch 4 17 | train_loss: 0.13554848474242498 18 | eval_loss: 0.011902455668817844 19 | time: Thu Aug 8 02:13:14 2024 20 | 21 | Epoch 5 22 | train_loss: 0.12781724531702496 23 | eval_loss: 0.008929020740909163 24 | time: Thu Aug 8 08:00:32 2024 25 | 26 | Epoch 6 27 | train_loss: 0.12264176163121285 28 | eval_loss: 0.008453744098166877 29 | time: Thu Aug 8 13:47:12 2024 30 | 31 | Epoch 7 32 | train_loss: 0.11872020949762974 33 | eval_loss: 0.007165850172573819 34 | time: Thu Aug 8 19:34:09 2024 35 | 36 | Epoch 8 37 | train_loss: 0.1153576639058103 38 | eval_loss: 0.006601243383027142 39 | time: Fri Aug 9 01:22:19 2024 40 | 41 | Epoch 9 42 | train_loss: 0.11312788720856465 43 | eval_loss: 0.005973297544609645 44 | time: Fri Aug 9 07:11:25 2024 45 | 46 | Epoch 10 47 | train_loss: 0.11096722687304313 48 | eval_loss: 0.005796642405204946 49 | time: Fri Aug 9 12:59:13 2024 50 | 51 | Epoch 11 52 | train_loss: 0.10913465011501206 53 | eval_loss: 0.005249736892483845 54 | time: Fri Aug 9 18:46:14 2024 55 | 56 | Epoch 12 57 | train_loss: 0.10780615577682347 58 | eval_loss: 0.005115668955435858 59 | time: Sat Aug 10 00:34:28 2024 60 | 61 | Epoch 13 62 | train_loss: 0.10650418949283817 63 | eval_loss: 0.00475350690255028 64 | time: Sat Aug 10 06:22:51 2024 65 | 66 | Epoch 14 67 | train_loss: 0.10524643352798381 68 | eval_loss: 0.004583307054632575 69 | time: Sat Aug 10 12:11:18 2024 70 | 71 | Epoch 15 72 | train_loss: 0.1041887047117438 73 | eval_loss: 0.004289783886142609 74 | time: Sat Aug 10 17:59:48 2024 75 | 76 | Epoch 16 77 | train_loss: 0.10343191375801945 78 | eval_loss: 0.004421581262192111 79 | time: Sat Aug 10 23:47:37 2024 80 | 81 | Epoch 17 82 | train_loss: 0.10256196519161385 83 | eval_loss: 0.004104017818401634 84 | time: Sun Aug 11 05:35:49 2024 85 | 86 | Epoch 18 87 | train_loss: 0.10170993055767087 88 | eval_loss: 0.0039769458375234585 89 | time: Sun Aug 11 11:23:39 2024 90 | 91 | Epoch 19 92 | train_loss: 0.1011880517951369 93 | eval_loss: 0.0039005329324529833 94 | time: Sun Aug 11 17:11:19 2024 95 | 96 | Epoch 20 97 | train_loss: 0.10030771156829077 98 | eval_loss: 0.0036845325137673237 99 | time: Sun Aug 11 22:59:49 2024 100 | 101 | Epoch 21 102 | train_loss: 0.09972109616302548 103 | eval_loss: 0.0038503893043940205 104 | time: Mon Aug 12 04:48:03 2024 105 | 106 | Epoch 22 107 | train_loss: 0.09932596696744844 108 | eval_loss: 0.00370702411211194 109 | time: Mon Aug 12 10:36:32 2024 110 | 111 | Epoch 23 112 | train_loss: 0.09888291950362459 113 | eval_loss: 0.0034573812171313834 114 | time: Mon Aug 12 16:24:55 2024 115 | 116 | Epoch 24 117 | train_loss: 0.09852503939581284 118 | eval_loss: 0.003370235667697582 119 | time: Mon Aug 12 22:12:14 2024 120 | 121 | Epoch 25 122 | train_loss: 0.09825884147004627 123 | eval_loss: 0.00346387299475209 124 | time: Tue Aug 13 04:00:42 2024 125 | 126 | Epoch 26 127 | train_loss: 0.09756856258879791 128 | eval_loss: 
0.0033276399650575615 129 | time: Tue Aug 13 09:49:22 2024 130 | 131 | Epoch 27 132 | train_loss: 0.09730380131801182 133 | eval_loss: 0.003326884365762399 134 | time: Tue Aug 13 15:36:05 2024 135 | 136 | Epoch 28 137 | train_loss: 0.09687296288584166 138 | eval_loss: 0.0034621171255395573 139 | time: Tue Aug 13 21:23:24 2024 140 | 141 | Epoch 29 142 | train_loss: 0.09668537175198876 143 | eval_loss: 0.003284947640647648 144 | time: Wed Aug 14 03:10:21 2024 145 | 146 | Epoch 30 147 | train_loss: 0.09628572566572022 148 | eval_loss: 0.003119471057549999 149 | time: Wed Aug 14 08:56:45 2024 150 | 151 | Epoch 31 152 | train_loss: 0.09617123452549026 153 | eval_loss: 0.003124797866062776 154 | time: Wed Aug 14 14:43:12 2024 155 | 156 | Epoch 32 157 | train_loss: 0.09578377932399237 158 | eval_loss: 0.0030736677601092537 159 | time: Wed Aug 14 20:31:01 2024 160 | 161 | Epoch 33 162 | train_loss: 0.09558304869954821 163 | eval_loss: 0.003178201471396451 164 | time: Thu Aug 15 02:19:14 2024 165 | 166 | Epoch 34 167 | train_loss: 0.0952804450174092 168 | eval_loss: 0.0030847328114775225 169 | time: Thu Aug 15 08:06:29 2024 170 | 171 | Epoch 35 172 | train_loss: 0.09513826066486042 173 | eval_loss: 0.00303873973446682 174 | time: Thu Aug 15 13:52:17 2024 175 | 176 | Epoch 36 177 | train_loss: 0.09466769916316405 178 | eval_loss: 0.0030122215467611258 179 | time: Thu Aug 15 19:38:29 2024 180 | 181 | Epoch 37 182 | train_loss: 0.09465687754501316 183 | eval_loss: 0.00289094522015785 184 | time: Fri Aug 16 01:25:14 2024 185 | 186 | Epoch 38 187 | train_loss: 0.09435585222324992 188 | eval_loss: 0.0030173959307773393 189 | time: Fri Aug 16 07:11:56 2024 190 | 191 | Epoch 39 192 | train_loss: 0.09413478592045794 193 | eval_loss: 0.002968058454507435 194 | time: Fri Aug 16 12:59:16 2024 195 | 196 | Epoch 40 197 | train_loss: 0.09393180562734375 198 | eval_loss: 0.0030673167865746948 199 | time: Fri Aug 16 18:45:23 2024 200 | 201 | Epoch 41 202 | train_loss: 0.09365266143982799 203 | eval_loss: 0.00287582161187937 204 | time: Sat Aug 17 00:31:47 2024 205 | 206 | Epoch 42 207 | train_loss: 0.09359205519747489 208 | eval_loss: 0.0027280030162997134 209 | time: Sat Aug 17 06:18:32 2024 210 | 211 | Epoch 43 212 | train_loss: 0.09349238520961266 213 | eval_loss: 0.0029261269570300787 214 | time: Sat Aug 17 12:05:26 2024 215 | 216 | Epoch 44 217 | train_loss: 0.09324607778116949 218 | eval_loss: 0.002691730654519444 219 | time: Sat Aug 17 17:52:20 2024 220 | 221 | Epoch 45 222 | train_loss: 0.09310021795996155 223 | eval_loss: 0.0028863806760858132 224 | time: Sat Aug 17 23:38:57 2024 225 | 226 | Epoch 46 227 | train_loss: 0.09307358593283441 228 | eval_loss: 0.002793597210717352 229 | time: Sun Aug 18 05:25:42 2024 230 | 231 | Epoch 47 232 | train_loss: 0.09299390690766882 233 | eval_loss: 0.0027052024821456098 234 | time: Sun Aug 18 11:12:32 2024 235 | 236 | Epoch 48 237 | train_loss: 0.09253486422624911 238 | eval_loss: 0.0027312307396534247 239 | time: Sun Aug 18 16:59:16 2024 240 | 241 | Epoch 49 242 | train_loss: 0.09243107154309635 243 | eval_loss: 0.002648197562936772 244 | time: Sun Aug 18 22:46:41 2024 245 | 246 | Epoch 50 247 | train_loss: 0.09237845186490301 248 | eval_loss: 0.0026844193827840284 249 | time: Mon Aug 19 04:35:10 2024 250 | 251 | Epoch 51 252 | train_loss: 0.09231985249015236 253 | eval_loss: 0.002708845011956738 254 | time: Mon Aug 19 10:24:17 2024 255 | 256 | Epoch 52 257 | train_loss: 0.0922615721153286 258 | eval_loss: 0.0035362059711223225 259 | time: Mon Aug 19 16:11:39 2024 260 
| 261 | Epoch 53 262 | train_loss: 0.09200190843071623 263 | eval_loss: 0.0025848455890180064 264 | time: Mon Aug 19 21:58:31 2024 265 | 266 | Epoch 54 267 | train_loss: 0.09200848002425245 268 | eval_loss: 0.0026311414897881983 269 | time: Tue Aug 20 03:45:36 2024 270 | 271 | Epoch 55 272 | train_loss: 0.09154813869071807 273 | eval_loss: 0.0025586662145983823 274 | time: Tue Aug 20 09:34:48 2024 275 | 276 | Epoch 56 277 | train_loss: 0.09162745474034129 278 | eval_loss: 0.0026280648907143545 279 | time: Tue Aug 20 15:23:23 2024 280 | 281 | Epoch 57 282 | train_loss: 0.09156280245772795 283 | eval_loss: 0.002539119078534093 284 | time: Tue Aug 20 21:11:25 2024 285 | 286 | Epoch 58 287 | train_loss: 0.09142590950099329 288 | eval_loss: 0.0026369429265152866 289 | time: Wed Aug 21 02:59:19 2024 290 | 291 | Epoch 59 292 | train_loss: 0.09139848643851392 293 | eval_loss: 0.0024354966580356916 294 | time: Wed Aug 21 08:46:23 2024 295 | 296 | Epoch 60 297 | train_loss: 0.09131192888740647 298 | eval_loss: 0.0024594995301248277 299 | time: Wed Aug 21 14:33:28 2024 300 | 301 | Epoch 61 302 | train_loss: 0.09122042933562911 303 | eval_loss: 0.002616936316367883 304 | time: Wed Aug 21 20:20:57 2024 305 | 306 | Epoch 62 307 | train_loss: 0.09109125168796305 308 | eval_loss: 0.0025555431279884297 309 | time: Thu Aug 22 02:08:45 2024 310 | 311 | Epoch 63 312 | train_loss: 0.09106527324403817 313 | eval_loss: 0.0025145284593781213 314 | time: Thu Aug 22 07:56:26 2024 315 | 316 | Epoch 64 317 | train_loss: 0.09095406525682191 318 | eval_loss: 0.0025151555842959678 319 | time: Thu Aug 22 13:45:57 2024 320 | 321 | Epoch 65 322 | train_loss: 0.09102793501718281 323 | eval_loss: 0.0024135450126194563 324 | time: Thu Aug 22 19:54:28 2024 325 | 326 | Epoch 66 327 | train_loss: 0.0908411063853937 328 | eval_loss: 0.002460922076728368 329 | time: Fri Aug 23 01:59:41 2024 330 | 331 | Epoch 67 332 | train_loss: 0.09070221083785855 333 | eval_loss: 0.002453409551882543 334 | time: Fri Aug 23 07:52:30 2024 335 | 336 | Epoch 68 337 | train_loss: 0.0906545008953897 338 | eval_loss: 0.0024080786435031784 339 | time: Fri Aug 23 13:41:28 2024 340 | 341 | Epoch 69 342 | train_loss: 0.0907353380525871 343 | eval_loss: 0.0024573436347799147 344 | time: Fri Aug 23 19:27:14 2024 345 | 346 | Epoch 70 347 | train_loss: 0.09040538104085095 348 | eval_loss: 0.0023765437401249566 349 | time: Sat Aug 24 01:14:45 2024 350 | 351 | Epoch 71 352 | train_loss: 0.09036114065518137 353 | eval_loss: 0.0023877528348234226 354 | time: Sat Aug 24 07:02:04 2024 355 | 356 | Epoch 72 357 | train_loss: 0.09037455027205546 358 | eval_loss: 0.002315233082103814 359 | time: Sat Aug 24 12:49:24 2024 360 | 361 | Epoch 73 362 | train_loss: 0.09026183628343257 363 | eval_loss: 0.0024284060419643228 364 | time: Sat Aug 24 18:35:36 2024 365 | 366 | Epoch 74 367 | train_loss: 0.09019025581511034 368 | eval_loss: 0.002393116130206718 369 | time: Sun Aug 25 00:21:29 2024 370 | 371 | Epoch 75 372 | train_loss: 0.089901714783446 373 | eval_loss: 0.002298152916632467 374 | time: Sun Aug 25 06:08:01 2024 375 | 376 | Epoch 76 377 | train_loss: 0.09018262871273484 378 | eval_loss: 0.002273971366672482 379 | time: Sun Aug 25 11:54:02 2024 380 | 381 | Epoch 77 382 | train_loss: 0.08998425874228 383 | eval_loss: 0.002317420323379338 384 | time: Sun Aug 25 17:44:05 2024 385 | 386 | Epoch 78 387 | train_loss: 0.08983653943919646 388 | eval_loss: 0.0024391192159878743 389 | time: Sun Aug 25 23:31:34 2024 390 | 391 | Epoch 79 392 | train_loss: 0.08981405456901183 393 | 
eval_loss: 0.002319374949895317 394 | time: Mon Aug 26 05:24:56 2024 395 | 396 | Epoch 80 397 | train_loss: 0.08974534569690559 398 | eval_loss: 0.0023008979344151066 399 | time: Mon Aug 26 11:28:33 2024 400 | 401 | Epoch 81 402 | train_loss: 0.08972110153310983 403 | eval_loss: 0.002406696710865237 404 | time: Mon Aug 26 17:33:30 2024 405 | 406 | Epoch 82 407 | train_loss: 0.0895689915361898 408 | eval_loss: 0.002241936448434926 409 | time: Mon Aug 26 23:39:15 2024 410 | 411 | Epoch 83 412 | train_loss: 0.08950625452328584 413 | eval_loss: 0.002408353965493697 414 | time: Tue Aug 27 05:37:59 2024 415 | 416 | Epoch 84 417 | train_loss: 0.08959725393084628 418 | eval_loss: 0.0023435966142665455 419 | time: Tue Aug 27 11:34:29 2024 420 | 421 | Epoch 85 422 | train_loss: 0.08970333726515986 423 | eval_loss: 0.0023965956810233086 424 | time: Tue Aug 27 17:27:31 2024 425 | 426 | Epoch 86 427 | train_loss: 0.08948115523227308 428 | eval_loss: 0.002325803569256709 429 | time: Tue Aug 27 23:19:43 2024 430 | 431 | Epoch 87 432 | train_loss: 0.08933937688654775 433 | eval_loss: 0.0023552257988114647 434 | time: Wed Aug 28 05:11:34 2024 435 | 436 | Epoch 88 437 | train_loss: 0.08938353908107184 438 | eval_loss: 0.0024397599904794043 439 | time: Wed Aug 28 11:01:23 2024 440 | 441 | Epoch 89 442 | train_loss: 0.08921640703096091 443 | eval_loss: 0.002223708766084243 444 | time: Wed Aug 28 16:50:21 2024 445 | 446 | Epoch 90 447 | train_loss: 0.08929300930090782 448 | eval_loss: 0.0022849828316260303 449 | time: Wed Aug 28 22:38:53 2024 450 | 451 | Epoch 91 452 | train_loss: 0.08910525214309825 453 | eval_loss: 0.0022257193633186227 454 | time: Thu Aug 29 04:35:16 2024 455 | 456 | Epoch 92 457 | train_loss: 0.08905495976636461 458 | eval_loss: 0.0022299331251850137 459 | time: Thu Aug 29 10:29:46 2024 460 | 461 | Epoch 93 462 | train_loss: 0.08890526102100955 463 | eval_loss: 0.0022962711695463786 464 | time: Thu Aug 29 16:23:49 2024 465 | 466 | Epoch 94 467 | train_loss: 0.08908289874104246 468 | eval_loss: 0.002243622880820028 469 | time: Thu Aug 29 22:15:42 2024 470 | 471 | Epoch 95 472 | train_loss: 0.08908785978677156 473 | eval_loss: 0.0022457318524397784 474 | time: Fri Aug 30 04:06:57 2024 475 | 476 | Epoch 96 477 | train_loss: 0.08888098475318565 478 | eval_loss: 0.002224675611787346 479 | time: Fri Aug 30 09:58:43 2024 480 | 481 | Epoch 97 482 | train_loss: 0.08888529259134526 483 | eval_loss: 0.0021844924980664493 484 | time: Fri Aug 30 15:50:16 2024 485 | 486 | Epoch 98 487 | train_loss: 0.08885388837534758 488 | eval_loss: 0.0022109088076294288 489 | time: Fri Aug 30 21:41:59 2024 490 | 491 | Epoch 99 492 | train_loss: 0.08873902663868657 493 | eval_loss: 0.0022606451996653202 494 | time: Sat Aug 31 03:34:46 2024 495 | 496 | Epoch 100 497 | train_loss: 0.08877080098666765 498 | eval_loss: 0.002279470525367602 499 | time: Sat Aug 31 09:28:38 2024 500 | 501 | -------------------------------------------------------------------------------- /code/train_clamp2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import wandb 5 | import torch 6 | import random 7 | import numpy as np 8 | from utils import * 9 | from config import * 10 | from tqdm import tqdm 11 | from copy import deepcopy 12 | import torch.distributed as dist 13 | from torch.amp import autocast, GradScaler 14 | from torch.utils.data import Dataset, DataLoader 15 | from torch.utils.data.distributed import DistributedSampler 16 | from torch.nn.parallel 
import DistributedDataParallel as DDP 17 | from transformers import AutoTokenizer, BertConfig, get_constant_schedule_with_warmup 18 | 19 | def list_files_in_json(json_path): 20 | file_list = [] 21 | 22 | if os.path.exists(json_path): 23 | with open(json_path, 'r', encoding='utf-8') as f: 24 | for line in f: 25 | item = json.loads(line) 26 | file_list.append(item) 27 | 28 | return file_list 29 | 30 | def collate_batch(batch): 31 | text_inputs, text_masks, music_inputs, music_masks = zip(*batch) 32 | 33 | text_inputs = torch.stack(text_inputs) 34 | text_masks = torch.stack(text_masks) 35 | music_inputs = torch.stack(music_inputs) 36 | music_masks = torch.stack(music_masks) 37 | 38 | return text_inputs, text_masks, music_inputs, music_masks 39 | 40 | class TextMusicDataset(Dataset): 41 | def __init__(self, items, mode): 42 | print("The number of "+mode+" data: "+str(len(items))) 43 | self.items = items 44 | self.mode = mode 45 | if self.mode == 'train' or not EVAL_JSONL: 46 | self.datapath = os.path.dirname(TRAIN_JSONL) 47 | elif self.mode == 'eval': 48 | self.datapath = os.path.dirname(EVAL_JSONL) 49 | 50 | def text_dropout(self, item): 51 | if random.random() < 0.5: 52 | candidates = [] 53 | for key in item.keys(): 54 | if key not in ["summary_en", "summary_nen", "filepaths"]: 55 | if item[key] == None: 56 | continue 57 | elif isinstance(item[key], str): 58 | candidates.append(item[key]) 59 | elif isinstance(item[key], list): 60 | candidates.extend(item[key]) 61 | candidates = list(set(candidates)) 62 | candidates = "\n".join(candidates) 63 | candidates = candidates.split("\n") 64 | selected_candidates = [c for c in candidates if len(c) > 0 and random.random() < 0.5] 65 | if len(selected_candidates) == 0: 66 | selected_candidates = candidates 67 | random.shuffle(selected_candidates) 68 | text = tokenizer.sep_token.join(selected_candidates) 69 | else: 70 | if random.random() < 0.5: 71 | text = random.choice(item["summary_en"]) 72 | else: 73 | text = random.choice(item["summary_nen"])["summary"] 74 | 75 | return text 76 | 77 | def random_truncate(self, input_tensor, max_length): 78 | choices = ["head", "tail", "middle"] 79 | choice = random.choice(choices) 80 | if choice == "head" or self.mode == 'eval': 81 | input_tensor = input_tensor[:max_length] 82 | elif choice == "tail": 83 | input_tensor = input_tensor[-max_length:] 84 | elif choice == "middle": 85 | start = random.randint(1, input_tensor.size(0)-max_length) 86 | input_tensor = input_tensor[start:start+max_length] 87 | 88 | return input_tensor 89 | 90 | def __len__(self): 91 | return len(self.items) 92 | 93 | def __getitem__(self, idx): 94 | item = self.items[idx] 95 | 96 | # randomly select text from the item 97 | if self.mode == 'train' and TEXT_DROPOUT: 98 | text = self.text_dropout(item) 99 | else: 100 | text = item["summary_en"][0] 101 | 102 | # tokenize text and build mask for text tokens 103 | text_inputs = tokenizer(text, return_tensors='pt') 104 | text_inputs = text_inputs['input_ids'].squeeze(0) 105 | if text_inputs.size(0) > MAX_TEXT_LENGTH: 106 | text_inputs = self.random_truncate(text_inputs, MAX_TEXT_LENGTH) 107 | text_masks = torch.ones(text_inputs.size(0)) 108 | 109 | # load music file 110 | if self.mode == 'train': 111 | filepath = random.choice(item["filepaths"]) 112 | else: 113 | if item["filepaths"][0].endswith(".abc"): 114 | filepath = item["filepaths"][0] 115 | else: 116 | filepath = item["filepaths"][1] 117 | filepath = self.datapath + '/' + filepath 118 | 119 | with open(filepath, "r", encoding="utf-8") as f: 
120 | item = f.read().replace("L:1/8\n", "") if filepath.endswith(".abc") else f.read() 121 | 122 | # randomly remove instrument info from the music file 123 | if random.random() < 0.9 and self.mode == 'train': 124 | item = remove_instrument_info(item) 125 | 126 | # mask music inputs 127 | music_inputs = patchilizer.encode(item, add_special_patches=True, truncate=True, random_truncate=(self.mode=="train")) 128 | music_inputs = torch.tensor(music_inputs) 129 | music_masks = torch.ones(music_inputs.size(0)) 130 | 131 | # pad text inputs and masks 132 | pad_indices = torch.ones(MAX_TEXT_LENGTH - text_inputs.size(0)).long() * tokenizer.pad_token_id 133 | text_inputs = torch.cat((text_inputs, pad_indices), 0) 134 | text_masks = torch.cat((text_masks, torch.zeros(MAX_TEXT_LENGTH - text_masks.size(0))), 0) 135 | 136 | # pad music inputs and masks 137 | pad_indices = torch.ones((PATCH_LENGTH - music_inputs.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id 138 | music_inputs = torch.cat((music_inputs, pad_indices), 0) 139 | music_masks = torch.cat((music_masks, torch.zeros(PATCH_LENGTH - music_masks.size(0))), 0) 140 | 141 | return text_inputs, text_masks, music_inputs, music_masks 142 | 143 | # call model with a batch of input 144 | def process_one_batch(batch): 145 | text_inputs, text_masks, music_inputs, music_masks = batch 146 | 147 | loss = model(text_inputs, 148 | text_masks, 149 | music_inputs, 150 | music_masks) 151 | 152 | # Reduce the loss on GPU 0 153 | if world_size > 1: 154 | loss = loss.unsqueeze(0) 155 | dist.reduce(loss, dst=0) 156 | loss = loss / world_size 157 | dist.broadcast(loss, src=0) 158 | 159 | return loss.mean() 160 | 161 | # do one epoch for training 162 | def train_epoch(epoch): 163 | tqdm_train_set = tqdm(train_set) 164 | total_train_loss = 0 165 | iter_idx = 1 166 | model.train() 167 | train_steps = (epoch-1)*len(train_set) 168 | 169 | for batch in tqdm_train_set: 170 | with autocast(device_type='cuda'): 171 | loss = process_one_batch(batch) 172 | scaler.scale(loss).backward() 173 | total_train_loss += loss.item() 174 | scaler.step(optimizer) 175 | scaler.update() 176 | 177 | lr_scheduler.step() 178 | model.zero_grad(set_to_none=True) 179 | tqdm_train_set.set_postfix({str(global_rank)+'_train_loss': total_train_loss / iter_idx}) 180 | train_steps += 1 181 | 182 | # Log the training loss to wandb 183 | if global_rank==0 and CLAMP2_WANDB_LOG: 184 | wandb.log({"train_loss": total_train_loss / iter_idx}, step=train_steps) 185 | 186 | iter_idx += 1 187 | 188 | return total_train_loss / (iter_idx-1) 189 | 190 | # do one epoch for eval 191 | def eval_epoch(): 192 | tqdm_eval_set = tqdm(eval_set) 193 | total_eval_loss = 0 194 | iter_idx = 1 195 | model.eval() 196 | 197 | # Evaluate data for one epoch 198 | for batch in tqdm_eval_set: 199 | with torch.no_grad(): 200 | loss = process_one_batch(batch) 201 | 202 | total_eval_loss += loss.item() 203 | tqdm_eval_set.set_postfix({str(global_rank)+'_eval_loss': total_eval_loss / iter_idx}) 204 | iter_idx += 1 205 | 206 | return total_eval_loss / (iter_idx-1) 207 | 208 | # train and eval 209 | if __name__ == "__main__": 210 | 211 | # Set up distributed training 212 | world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 213 | global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0 214 | local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0 215 | 216 | if world_size > 1: 217 | torch.cuda.set_device(local_rank) 218 | device = torch.device("cuda", local_rank) 219 | 
dist.init_process_group(backend='nccl') 220 | else: 221 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 222 | 223 | if CLAMP2_DETERMINISTIC: 224 | seed = 42 + global_rank 225 | random.seed(seed) 226 | np.random.seed(seed) 227 | torch.manual_seed(seed) 228 | torch.cuda.manual_seed_all(seed) 229 | torch.backends.cudnn.deterministic = True 230 | torch.backends.cudnn.benchmark = False 231 | 232 | m3_config = BertConfig(vocab_size=1, 233 | hidden_size=M3_HIDDEN_SIZE, 234 | num_hidden_layers=PATCH_NUM_LAYERS, 235 | num_attention_heads=M3_HIDDEN_SIZE//64, 236 | intermediate_size=M3_HIDDEN_SIZE*4, 237 | max_position_embeddings=PATCH_LENGTH) 238 | model = CLaMP2Model(m3_config, 239 | global_rank, 240 | world_size, 241 | TEXT_MODEL_NAME, 242 | CLAMP2_HIDDEN_SIZE, 243 | CLAMP2_LOAD_M3) 244 | model = model.to(device) 245 | tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME) 246 | patchilizer = M3Patchilizer() 247 | 248 | # print parameter number 249 | print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad))) 250 | 251 | if world_size > 1: 252 | model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) 253 | 254 | scaler = GradScaler() 255 | optimizer = torch.optim.AdamW(model.parameters(), lr=CLAMP2_LEARNING_RATE) 256 | 257 | if CLAMP2_WANDB_LOG and global_rank==0: 258 | # Initialize wandb 259 | if WANDB_KEY: 260 | wandb.login(key=WANDB_KEY) 261 | wandb.init(project="clamp2", 262 | name=CLAMP2_WEIGHTS_PATH.replace("weights_", "").replace(".pth", "")) 263 | 264 | # load filenames under train and eval folder 265 | train_files = list_files_in_json(TRAIN_JSONL) 266 | eval_files = list_files_in_json(EVAL_JSONL) 267 | 268 | if len(eval_files)==0: 269 | train_files, eval_files = split_data(train_files) 270 | 271 | train_batch_nums = int(len(train_files) / CLAMP2_BATCH_SIZE) 272 | eval_batch_nums = int(len(eval_files) / CLAMP2_BATCH_SIZE) 273 | 274 | train_files = train_files[:train_batch_nums*CLAMP2_BATCH_SIZE] 275 | eval_files = eval_files[:eval_batch_nums*CLAMP2_BATCH_SIZE] 276 | 277 | train_set = TextMusicDataset(train_files, 'train') 278 | eval_set = TextMusicDataset(eval_files, 'eval') 279 | 280 | train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=global_rank) 281 | eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=global_rank) 282 | 283 | train_set = DataLoader(train_set, batch_size=CLAMP2_BATCH_SIZE, collate_fn=collate_batch, sampler=train_sampler, shuffle = (train_sampler is None)) 284 | eval_set = DataLoader(eval_set, batch_size=CLAMP2_BATCH_SIZE, collate_fn=collate_batch, sampler=eval_sampler, shuffle = (train_sampler is None)) 285 | 286 | lr_scheduler = get_constant_schedule_with_warmup(optimizer = optimizer, num_warmup_steps = 1000) 287 | 288 | if CLAMP2_LOAD_CKPT and os.path.exists(CLAMP2_WEIGHTS_PATH): 289 | # Load checkpoint to CPU 290 | checkpoint = torch.load(CLAMP2_WEIGHTS_PATH, map_location='cpu', weights_only=True) 291 | 292 | # Here, model is assumed to be on GPU 293 | # Load state dict to CPU model first, then move the model to GPU 294 | if torch.cuda.device_count() > 1: 295 | # If you have a DataParallel model, you need to load to model.module instead 296 | cpu_model = deepcopy(model.module) 297 | cpu_model.load_state_dict(checkpoint['model']) 298 | model.module.load_state_dict(cpu_model.state_dict()) 299 | else: 300 | # Load to a CPU clone of the model, then load back 301 | cpu_model = deepcopy(model) 302 | 
cpu_model.load_state_dict(checkpoint['model']) 303 | model.load_state_dict(cpu_model.state_dict()) 304 | optimizer.load_state_dict(checkpoint['optimizer']) 305 | lr_scheduler.load_state_dict(checkpoint['lr_sched']) 306 | pre_epoch = checkpoint['epoch'] 307 | best_epoch = checkpoint['best_epoch'] 308 | min_eval_loss = checkpoint['min_eval_loss'] 309 | print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}") 310 | checkpoint = None 311 | 312 | else: 313 | pre_epoch = 0 314 | best_epoch = 0 315 | min_eval_loss = float('inf') 316 | 317 | model = model.to(device) 318 | optimizer = torch.optim.AdamW(model.parameters(), lr=CLAMP2_LEARNING_RATE) 319 | 320 | for epoch in range(1+pre_epoch, CLAMP2_NUM_EPOCH+1): 321 | train_sampler.set_epoch(epoch) 322 | eval_sampler.set_epoch(epoch) 323 | print('-' * 21 + "Epoch " + str(epoch) + '-' * 21) 324 | train_loss = train_epoch(epoch) 325 | eval_loss = eval_epoch() 326 | if global_rank==0: 327 | with open(CLAMP2_LOGS_PATH,'a') as f: 328 | f.write("Epoch " + str(epoch) + "\ntrain_loss: " + str(train_loss) + "\neval_loss: " +str(eval_loss) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n") 329 | if eval_loss < min_eval_loss: 330 | best_epoch = epoch 331 | min_eval_loss = eval_loss 332 | checkpoint = { 333 | 'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(), 334 | 'optimizer': optimizer.state_dict(), 335 | 'lr_sched': lr_scheduler.state_dict(), 336 | 'epoch': epoch, 337 | 'best_epoch': best_epoch, 338 | 'min_eval_loss': min_eval_loss 339 | } 340 | torch.save(checkpoint, CLAMP2_WEIGHTS_PATH) 341 | checkpoint = { 342 | 'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(), 343 | 'optimizer': optimizer.state_dict(), 344 | 'lr_sched': lr_scheduler.state_dict(), 345 | 'epoch': epoch, 346 | 'best_epoch': best_epoch, 347 | 'min_eval_loss': min_eval_loss 348 | } 349 | torch.save(checkpoint, "latest_"+CLAMP2_WEIGHTS_PATH) 350 | 351 | if world_size > 1: 352 | dist.barrier() 353 | 354 | if global_rank==0: 355 | print("Best Eval Epoch : "+str(best_epoch)) 356 | print("Min Eval Loss : "+str(min_eval_loss)) 357 | -------------------------------------------------------------------------------- /code/train_m3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import time 4 | import wandb 5 | import torch 6 | import random 7 | import weakref 8 | import numpy as np 9 | from utils import * 10 | from config import * 11 | from tqdm import tqdm 12 | from copy import deepcopy 13 | import torch.distributed as dist 14 | from torch.amp import autocast, GradScaler 15 | from torch.utils.data import Dataset, DataLoader 16 | from torch.utils.data.distributed import DistributedSampler 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | from transformers import BertConfig, GPT2Config, get_constant_schedule_with_warmup 19 | 20 | patchilizer = M3Patchilizer() 21 | 22 | def clear_unused_tensors(): 23 | gc.disable() # Temporarily disable garbage collection 24 | try: 25 | # Get the set of tensor ids used by the model 26 | if hasattr(model, "module"): 27 | model_tensors = {id(p) for p in model.module.parameters()} 28 | else: 29 | model_tensors = {id(p) for p in model.parameters()} 30 | 31 | # Get the set of tensor ids used by the optimizer 32 | optimizer_tensors = { 33 | id(state) 34 | for state_dict in optimizer.state.values() 35 | for state in state_dict.values() 36 | if 
isinstance(state, torch.Tensor) # Ensure only tensors are considered 37 | } 38 | 39 | # List of all CUDA tensors currently in memory 40 | tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor) and obj.is_cuda] 41 | 42 | # Create weak references to avoid interfering with garbage collection 43 | tensor_refs = [weakref.ref(tensor) for tensor in tensors] 44 | 45 | for tensor_ref in tensor_refs: 46 | tensor = tensor_ref() # Dereference the weak reference 47 | if tensor is not None and id(tensor) not in model_tensors and id(tensor) not in optimizer_tensors: 48 | # Mark the tensor for deletion 49 | tensor.detach_() # Detach from computation graph 50 | del tensor # Delete the tensor reference 51 | except: 52 | pass 53 | 54 | finally: 55 | gc.enable() # Re-enable garbage collection 56 | gc.collect() # Force a garbage collection 57 | torch.cuda.empty_cache() # Clear the CUDA cache 58 | 59 | def list_files_in_directory(directories, extensions=["abc", "mtf"]): 60 | file_list = [] 61 | 62 | for directory in directories: 63 | for root, dirs, files in os.walk(directory): 64 | for file in files: 65 | if any(file.endswith(ext) for ext in extensions): 66 | file_path = os.path.join(root, file) 67 | file_list.append(file_path) 68 | 69 | return file_list 70 | 71 | def collate_batch(batch): 72 | input_patches, input_masks, selected_indices, target_patches = zip(*batch) 73 | 74 | input_patches = torch.nn.utils.rnn.pad_sequence(input_patches, batch_first=True, padding_value=patchilizer.pad_token_id) 75 | input_masks = torch.nn.utils.rnn.pad_sequence(input_masks, batch_first=True, padding_value=0) 76 | selected_indices = torch.nn.utils.rnn.pad_sequence(selected_indices, batch_first=True, padding_value=0) 77 | target_patches = torch.nn.utils.rnn.pad_sequence(target_patches, batch_first=True, padding_value=patchilizer.pad_token_id) 78 | 79 | return input_patches, input_masks, selected_indices, target_patches 80 | 81 | class M3Dataset(Dataset): 82 | def __init__(self, filenames, mode): 83 | print("The number of "+mode+" data: "+str(len(filenames))) 84 | self.filenames = filenames 85 | self.mode = mode 86 | 87 | def __len__(self): 88 | return len(self.filenames) 89 | 90 | def __getitem__(self, idx): 91 | filename = self.filenames[idx] 92 | try: 93 | with open(filename, "r", encoding="utf-8") as f: 94 | item = f.read().replace("L:1/8\n", "") if filename.endswith(".abc") else f.read() 95 | except Exception as e: 96 | print(e) 97 | print("Failed to load: "+filename) 98 | item = "" 99 | 100 | target_patches = patchilizer.encode(item, add_special_patches=True, truncate=True, random_truncate=(self.mode=="train")) 101 | input_masks = torch.tensor([1]*len(target_patches)) 102 | input_patches, selected_indices = mask_patches(target_patches, patchilizer, self.mode) 103 | input_patches = input_patches.reshape(-1) 104 | target_patches = torch.tensor(target_patches).reshape(-1) 105 | return input_patches, input_masks, selected_indices, target_patches 106 | 107 | # call model with a batch of input 108 | def process_one_batch(batch): 109 | input_patches, input_masks, selected_indices, target_patches = batch 110 | 111 | loss = model(input_patches, 112 | input_masks, 113 | selected_indices, 114 | target_patches).loss 115 | 116 | # Reduce the loss on GPU 0 117 | if world_size > 1: 118 | loss = loss.unsqueeze(0) 119 | dist.reduce(loss, dst=0) 120 | loss = loss / world_size 121 | dist.broadcast(loss, src=0) 122 | 123 | return loss.mean() 124 | 125 | # do one epoch for training 126 | def train_epoch(epoch): 127 | 
tqdm_train_set = tqdm(train_set) 128 | total_train_loss = 0 129 | iter_idx = 1 130 | model.train() 131 | train_steps = (epoch-1)*len(train_set) 132 | 133 | for batch in tqdm_train_set: 134 | with autocast(device_type='cuda'): 135 | loss = process_one_batch(batch) 136 | scaler.scale(loss).backward() 137 | total_train_loss += loss.item() 138 | scaler.step(optimizer) 139 | scaler.update() 140 | 141 | lr_scheduler.step() 142 | model.zero_grad(set_to_none=True) 143 | tqdm_train_set.set_postfix({str(global_rank)+'_train_loss': total_train_loss / iter_idx}) 144 | train_steps += 1 145 | 146 | # Log the training loss to wandb 147 | if global_rank==0 and M3_WANDB_LOG: 148 | wandb.log({"train_loss": total_train_loss / iter_idx}, step=train_steps) 149 | 150 | iter_idx += 1 151 | if iter_idx % 1000 == 0: 152 | clear_unused_tensors() 153 | 154 | return total_train_loss / (iter_idx-1) 155 | 156 | # do one epoch for eval 157 | def eval_epoch(): 158 | tqdm_eval_set = tqdm(eval_set) 159 | total_eval_loss = 0 160 | iter_idx = 1 161 | model.eval() 162 | 163 | # Evaluate data for one epoch 164 | for batch in tqdm_eval_set: 165 | with torch.no_grad(): 166 | loss = process_one_batch(batch) 167 | 168 | total_eval_loss += loss.item() 169 | tqdm_eval_set.set_postfix({str(global_rank)+'_eval_loss': total_eval_loss / iter_idx}) 170 | iter_idx += 1 171 | 172 | return total_eval_loss / (iter_idx-1) 173 | 174 | # train and eval 175 | if __name__ == "__main__": 176 | 177 | # Set up distributed training 178 | world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 179 | global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0 180 | local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0 181 | 182 | if world_size > 1: 183 | torch.cuda.set_device(local_rank) 184 | device = torch.device("cuda", local_rank) 185 | dist.init_process_group(backend='nccl') 186 | else: 187 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 188 | 189 | if M3_DETERMINISTIC: 190 | seed = 42 + global_rank 191 | random.seed(seed) 192 | np.random.seed(seed) 193 | torch.manual_seed(seed) 194 | torch.cuda.manual_seed_all(seed) 195 | torch.backends.cudnn.deterministic = True 196 | torch.backends.cudnn.benchmark = False 197 | 198 | encoder_config = BertConfig(vocab_size=1, 199 | hidden_size=M3_HIDDEN_SIZE, 200 | num_hidden_layers=PATCH_NUM_LAYERS, 201 | num_attention_heads=M3_HIDDEN_SIZE//64, 202 | intermediate_size=M3_HIDDEN_SIZE*4, 203 | max_position_embeddings=PATCH_LENGTH) 204 | decoder_config = GPT2Config(vocab_size=128, 205 | n_positions=PATCH_SIZE, 206 | n_embd=M3_HIDDEN_SIZE, 207 | n_layer=TOKEN_NUM_LAYERS, 208 | n_head=M3_HIDDEN_SIZE//64, 209 | n_inner=M3_HIDDEN_SIZE*4) 210 | model = M3Model(encoder_config, decoder_config) 211 | model = model.to(device) 212 | 213 | # print parameter number 214 | print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad))) 215 | 216 | if world_size > 1: 217 | model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) 218 | 219 | scaler = GradScaler() 220 | optimizer = torch.optim.AdamW(model.parameters(), lr=M3_LEARNING_RATE) 221 | 222 | if M3_WANDB_LOG and global_rank==0: 223 | # Initialize wandb 224 | if WANDB_KEY: 225 | wandb.login(key=WANDB_KEY) 226 | wandb.init(project="m3", 227 | name=M3_WEIGHTS_PATH.replace("weights_", "").replace(".pth", "")) 228 | 229 | # load filenames under train and eval folder 230 | train_files = 
list_files_in_directory(TRAIN_FOLDERS) 231 | eval_files = list_files_in_directory(EVAL_FOLDERS) 232 | 233 | if len(eval_files)==0: 234 | train_files, eval_files = split_data(train_files) 235 | 236 | train_batch_nums = int(len(train_files) / M3_BATCH_SIZE) 237 | eval_batch_nums = int(len(eval_files) / M3_BATCH_SIZE) 238 | 239 | train_files = train_files[:train_batch_nums*M3_BATCH_SIZE] 240 | eval_files = eval_files[:eval_batch_nums*M3_BATCH_SIZE] 241 | 242 | train_set = M3Dataset(train_files, 'train') 243 | eval_set = M3Dataset(eval_files, 'eval') 244 | 245 | train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=global_rank) 246 | eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=global_rank) 247 | 248 | train_set = DataLoader(train_set, batch_size=M3_BATCH_SIZE, collate_fn=collate_batch, sampler=train_sampler, shuffle = (train_sampler is None)) 249 | eval_set = DataLoader(eval_set, batch_size=M3_BATCH_SIZE, collate_fn=collate_batch, sampler=eval_sampler, shuffle = (train_sampler is None)) 250 | 251 | lr_scheduler = get_constant_schedule_with_warmup(optimizer = optimizer, num_warmup_steps = 1000) 252 | 253 | if M3_LOAD_CKPT and os.path.exists(M3_WEIGHTS_PATH): 254 | # Load checkpoint to CPU 255 | checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True) 256 | 257 | # Here, model is assumed to be on GPU 258 | # Load state dict to CPU model first, then move the model to GPU 259 | if torch.cuda.device_count() > 1: 260 | # If you have a DataParallel model, you need to load to model.module instead 261 | cpu_model = deepcopy(model.module) 262 | cpu_model.load_state_dict(checkpoint['model']) 263 | model.module.load_state_dict(cpu_model.state_dict()) 264 | else: 265 | # Load to a CPU clone of the model, then load back 266 | cpu_model = deepcopy(model) 267 | cpu_model.load_state_dict(checkpoint['model']) 268 | model.load_state_dict(cpu_model.state_dict()) 269 | optimizer.load_state_dict(checkpoint['optimizer']) 270 | lr_scheduler.load_state_dict(checkpoint['lr_sched']) 271 | pre_epoch = checkpoint['epoch'] 272 | best_epoch = checkpoint['best_epoch'] 273 | min_eval_loss = checkpoint['min_eval_loss'] 274 | print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}") 275 | checkpoint = None 276 | 277 | else: 278 | pre_epoch = 0 279 | best_epoch = 0 280 | min_eval_loss = float('inf') 281 | 282 | model = model.to(device) 283 | optimizer = torch.optim.AdamW(model.parameters(), lr=M3_LEARNING_RATE) 284 | 285 | for epoch in range(1+pre_epoch, M3_NUM_EPOCH+1): 286 | train_sampler.set_epoch(epoch) 287 | eval_sampler.set_epoch(epoch) 288 | print('-' * 21 + "Epoch " + str(epoch) + '-' * 21) 289 | train_loss = train_epoch(epoch) 290 | eval_loss = eval_epoch() 291 | if global_rank==0: 292 | with open(M3_LOGS_PATH,'a') as f: 293 | f.write("Epoch " + str(epoch) + "\ntrain_loss: " + str(train_loss) + "\neval_loss: " +str(eval_loss) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n") 294 | if eval_loss < min_eval_loss: 295 | best_epoch = epoch 296 | min_eval_loss = eval_loss 297 | checkpoint = { 298 | 'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(), 299 | 'optimizer': optimizer.state_dict(), 300 | 'lr_sched': lr_scheduler.state_dict(), 301 | 'epoch': epoch, 302 | 'best_epoch': best_epoch, 303 | 'min_eval_loss': min_eval_loss 304 | } 305 | torch.save(checkpoint, M3_WEIGHTS_PATH) 306 | checkpoint = { 307 | 'model': model.module.state_dict() if 
hasattr(model, "module") else model.state_dict(), 308 | 'optimizer': optimizer.state_dict(), 309 | 'lr_sched': lr_scheduler.state_dict(), 310 | 'epoch': epoch, 311 | 'best_epoch': best_epoch, 312 | 'min_eval_loss': min_eval_loss 313 | } 314 | torch.save(checkpoint, "latest_"+M3_WEIGHTS_PATH) 315 | 316 | if world_size > 1: 317 | dist.barrier() 318 | 319 | if global_rank==0: 320 | print("Best Eval Epoch : "+str(best_epoch)) 321 | print("Min Eval Loss : "+str(min_eval_loss)) 322 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import math 4 | import torch 5 | import random 6 | from config import * 7 | from unidecode import unidecode 8 | from torch.nn import functional as F 9 | from transformers import AutoModel, BertModel, GPT2LMHeadModel, PreTrainedModel, GPT2Config 10 | 11 | try: 12 | import torch.distributed.nn 13 | from torch import distributed as dist 14 | 15 | has_distributed = True 16 | except ImportError: 17 | has_distributed = False 18 | 19 | try: 20 | import horovod.torch as hvd 21 | except ImportError: 22 | hvd = None 23 | 24 | class ClipLoss(torch.nn.Module): 25 | 26 | def __init__( 27 | self, 28 | local_loss=False, 29 | gather_with_grad=False, 30 | cache_labels=False, 31 | rank=0, 32 | world_size=1, 33 | use_horovod=False, 34 | ): 35 | super().__init__() 36 | self.local_loss = local_loss 37 | self.gather_with_grad = gather_with_grad 38 | self.cache_labels = cache_labels 39 | self.rank = rank 40 | self.world_size = world_size 41 | self.use_horovod = use_horovod 42 | 43 | # cache state 44 | self.prev_num_logits = 0 45 | self.labels = {} 46 | 47 | def gather_features( 48 | self, 49 | image_features, 50 | text_features, 51 | local_loss=False, 52 | gather_with_grad=False, 53 | rank=0, 54 | world_size=1, 55 | use_horovod=False 56 | ): 57 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
58 | if use_horovod: 59 | assert hvd is not None, 'Please install horovod' 60 | if gather_with_grad: 61 | all_image_features = hvd.allgather(image_features) 62 | all_text_features = hvd.allgather(text_features) 63 | else: 64 | with torch.no_grad(): 65 | all_image_features = hvd.allgather(image_features) 66 | all_text_features = hvd.allgather(text_features) 67 | if not local_loss: 68 | # ensure grads for local rank when all_* features don't have a gradient 69 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 70 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 71 | gathered_image_features[rank] = image_features 72 | gathered_text_features[rank] = text_features 73 | all_image_features = torch.cat(gathered_image_features, dim=0) 74 | all_text_features = torch.cat(gathered_text_features, dim=0) 75 | else: 76 | # We gather tensors from all gpus 77 | if gather_with_grad: 78 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 79 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 80 | else: 81 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 82 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 83 | dist.all_gather(gathered_image_features, image_features) 84 | dist.all_gather(gathered_text_features, text_features) 85 | if not local_loss: 86 | # ensure grads for local rank when all_* features don't have a gradient 87 | gathered_image_features[rank] = image_features 88 | gathered_text_features[rank] = text_features 89 | all_image_features = torch.cat(gathered_image_features, dim=0) 90 | all_text_features = torch.cat(gathered_text_features, dim=0) 91 | 92 | return all_image_features, all_text_features 93 | 94 | def get_ground_truth(self, device, num_logits) -> torch.Tensor: 95 | # calculated ground-truth and cache if enabled 96 | if self.prev_num_logits != num_logits or device not in self.labels: 97 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 98 | if self.world_size > 1 and self.local_loss: 99 | labels = labels + num_logits * self.rank 100 | if self.cache_labels: 101 | self.labels[device] = labels 102 | self.prev_num_logits = num_logits 103 | else: 104 | labels = self.labels[device] 105 | return labels 106 | 107 | def get_logits(self, image_features, text_features, logit_scale): 108 | if self.world_size > 1: 109 | all_image_features, all_text_features = self.gather_features( 110 | image_features, text_features, 111 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 112 | 113 | if self.local_loss: 114 | logits_per_image = logit_scale * image_features @ all_text_features.T 115 | logits_per_text = logit_scale * text_features @ all_image_features.T 116 | else: 117 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 118 | logits_per_text = logits_per_image.T 119 | else: 120 | logits_per_image = logit_scale * image_features @ text_features.T 121 | logits_per_text = logit_scale * text_features @ image_features.T 122 | 123 | return logits_per_image, logits_per_text 124 | 125 | def forward(self, image_features, text_features, logit_scale, output_dict=False): 126 | device = image_features.device 127 | logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) 128 | 129 | labels = self.get_ground_truth(device, logits_per_image.shape[0]) 130 | 131 | total_loss = ( 132 | 
F.cross_entropy(logits_per_image, labels) + 133 | F.cross_entropy(logits_per_text, labels) 134 | ) / 2 135 | 136 | return {"contrastive_loss": total_loss} if output_dict else total_loss 137 | 138 | class M3Patchilizer: 139 | def __init__(self): 140 | self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"] 141 | self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')' 142 | self.pad_token_id = 0 143 | self.bos_token_id = 1 144 | self.eos_token_id = 2 145 | self.mask_token_id = 3 146 | 147 | def split_bars(self, body): 148 | bars = re.split(self.regexPattern, ''.join(body)) 149 | bars = list(filter(None, bars)) # remove empty strings 150 | if bars[0] in self.delimiters: 151 | bars[1] = bars[0] + bars[1] 152 | bars = bars[1:] 153 | bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)] 154 | return bars 155 | 156 | def bar2patch(self, bar, patch_size=PATCH_SIZE): 157 | patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id] 158 | patch = patch[:patch_size] 159 | patch += [self.pad_token_id] * (patch_size - len(patch)) 160 | return patch 161 | 162 | def patch2bar(self, patch): 163 | return ''.join(chr(idx) if idx > self.mask_token_id else '' for idx in patch) 164 | 165 | def encode(self, 166 | item, 167 | patch_size=PATCH_SIZE, 168 | add_special_patches=False, 169 | truncate=False, 170 | random_truncate=False): 171 | 172 | item = unidecode(item) 173 | lines = re.findall(r'.*?\n|.*$', item) 174 | lines = list(filter(None, lines)) # remove empty lines 175 | 176 | patches = [] 177 | 178 | if lines[0].split(" ")[0] == "ticks_per_beat": 179 | patch = "" 180 | for line in lines: 181 | if patch.startswith(line.split(" ")[0]) and (len(patch) + len(" ".join(line.split(" ")[1:])) <= patch_size-2): 182 | patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:]) 183 | else: 184 | if patch: 185 | patches.append(patch) 186 | patch = line 187 | if patch!="": 188 | patches.append(patch) 189 | else: 190 | for line in lines: 191 | if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%')): 192 | patches.append(line) 193 | else: 194 | bars = self.split_bars(line) 195 | if bars: 196 | bars[-1] += '\n' 197 | patches.extend(bars) 198 | 199 | if add_special_patches: 200 | bos_patch = chr(self.bos_token_id) * patch_size 201 | eos_patch = chr(self.eos_token_id) * patch_size 202 | patches = [bos_patch] + patches + [eos_patch] 203 | 204 | if len(patches) > PATCH_LENGTH and truncate: 205 | choices = ["head", "tail", "middle"] 206 | choice = random.choice(choices) 207 | if choice=="head" or random_truncate==False: 208 | patches = patches[:PATCH_LENGTH] 209 | elif choice=="tail": 210 | patches = patches[-PATCH_LENGTH:] 211 | else: 212 | start = random.randint(1, len(patches)-PATCH_LENGTH) 213 | patches = patches[start:start+PATCH_LENGTH] 214 | 215 | patches = [self.bar2patch(patch) for patch in patches] 216 | 217 | return patches 218 | 219 | def decode(self, patches): 220 | return ''.join(self.patch2bar(patch) for patch in patches) 221 | 222 | class M3PatchEncoder(PreTrainedModel): 223 | def __init__(self, config): 224 | super(M3PatchEncoder, self).__init__(config) 225 | self.patch_embedding = torch.nn.Linear(PATCH_SIZE*128, M3_HIDDEN_SIZE) 226 | torch.nn.init.normal_(self.patch_embedding.weight, std=0.02) 227 | self.base = BertModel(config=config) 228 | self.pad_token_id = 0 229 | self.bos_token_id = 1 230 | self.eos_token_id = 2 231 | self.mask_token_id = 3 232 | 233 | def forward(self, 234 | input_patches, # [batch_size, seq_length, 
hidden_size] 235 | input_masks): # [batch_size, seq_length] 236 | # Transform input_patches into embeddings 237 | input_patches = torch.nn.functional.one_hot(input_patches, num_classes=128) 238 | input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE*128).type(torch.FloatTensor) 239 | input_patches = self.patch_embedding(input_patches.to(self.device)) 240 | 241 | # Apply BERT model to input_patches and input_masks 242 | return self.base(inputs_embeds=input_patches, attention_mask=input_masks) 243 | 244 | class M3TokenDecoder(PreTrainedModel): 245 | def __init__(self, config): 246 | super(M3TokenDecoder, self).__init__(config) 247 | self.base = GPT2LMHeadModel(config=config) 248 | self.pad_token_id = 0 249 | self.bos_token_id = 1 250 | self.eos_token_id = 2 251 | self.mask_token_id = 3 252 | 253 | def forward(self, 254 | patch_features, # [batch_size, hidden_size] 255 | target_patches): # [batch_size, seq_length] 256 | # get input embeddings 257 | inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight) 258 | 259 | # concatenate the encoded patches with the input embeddings 260 | inputs_embeds = torch.cat((patch_features.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1) 261 | 262 | # preparing the labels for model training 263 | target_masks = target_patches == self.pad_token_id 264 | target_patches = target_patches.clone().masked_fill_(target_masks, -100) 265 | 266 | # get the attention mask 267 | target_masks = ~target_masks 268 | target_masks = target_masks.type(torch.int) 269 | 270 | return self.base(inputs_embeds=inputs_embeds, 271 | attention_mask=target_masks, 272 | labels=target_patches) 273 | 274 | def generate(self, 275 | patch_feature, 276 | tokens): 277 | # reshape the patch_feature and tokens 278 | patch_feature = patch_feature.reshape(1, 1, -1) 279 | tokens = tokens.reshape(1, -1) 280 | 281 | # get input embeddings 282 | tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight) 283 | 284 | # concatenate the encoded patches with the input embeddings 285 | tokens = torch.cat((patch_feature, tokens[:,1:,:]), dim=1) 286 | 287 | # get the outputs from the model 288 | outputs = self.base(inputs_embeds=tokens) 289 | 290 | # get the probabilities of the next token 291 | probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1) 292 | 293 | return probs.detach().cpu().numpy() 294 | 295 | class M3Model(PreTrainedModel): 296 | def __init__(self, encoder_config, decoder_config): 297 | super(M3Model, self).__init__(encoder_config) 298 | self.encoder = M3PatchEncoder(encoder_config) 299 | self.decoder = M3TokenDecoder(decoder_config) 300 | self.pad_token_id = 0 301 | self.bos_token_id = 1 302 | self.eos_token_id = 2 303 | self.mask_token_id = 3 304 | 305 | def forward(self, 306 | input_patches, # [batch_size, seq_length, hidden_size] 307 | input_masks, # [batch_size, seq_length] 308 | selected_indices, # [batch_size, seq_length] 309 | target_patches): # [batch_size, seq_length, hidden_size] 310 | input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE).to(self.device) 311 | input_masks = input_masks.to(self.device) 312 | selected_indices = selected_indices.to(self.device) 313 | target_patches = target_patches.reshape(len(target_patches), -1, PATCH_SIZE).to(self.device) 314 | 315 | # Pass the input_patches and input_masks through the encoder 316 | outputs = self.encoder(input_patches, input_masks)["last_hidden_state"] 317 | 318 | # Use selected_indices to form target_patches 319 | 
target_patches = target_patches[selected_indices.bool()] 320 | patch_features = outputs[selected_indices.bool()] 321 | 322 | # Pass patch_features and target_patches through the decoder 323 | return self.decoder(patch_features, target_patches) 324 | 325 | class CLaMP2Model(PreTrainedModel): 326 | def __init__(self, 327 | music_config, 328 | global_rank=None, 329 | world_size=None, 330 | text_model_name=TEXT_MODEL_NAME, 331 | hidden_size=CLAMP2_HIDDEN_SIZE, 332 | load_m3=CLAMP2_LOAD_M3): 333 | super(CLaMP2Model, self).__init__(music_config) 334 | 335 | self.text_model = AutoModel.from_pretrained(text_model_name) # Load the text model 336 | self.text_proj = torch.nn.Linear(self.text_model.config.hidden_size, hidden_size) # Linear layer for text projections 337 | torch.nn.init.normal_(self.text_proj.weight, std=0.02) # Initialize weights with normal distribution 338 | 339 | self.music_model = M3PatchEncoder(music_config) # Initialize the music model 340 | self.music_proj = torch.nn.Linear(M3_HIDDEN_SIZE, hidden_size) # Linear layer for music projections 341 | torch.nn.init.normal_(self.music_proj.weight, std=0.02) # Initialize weights with normal distribution 342 | 343 | if global_rank==None or world_size==None: 344 | global_rank = 0 345 | world_size = 1 346 | 347 | self.loss_fn = ClipLoss(local_loss=False, 348 | gather_with_grad=True, 349 | cache_labels=False, 350 | rank=global_rank, 351 | world_size=world_size, 352 | use_horovod=False) 353 | 354 | if load_m3 and os.path.exists(M3_WEIGHTS_PATH): 355 | checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True) 356 | decoder_config = GPT2Config(vocab_size=128, 357 | n_positions=PATCH_SIZE, 358 | n_embd=M3_HIDDEN_SIZE, 359 | n_layer=TOKEN_NUM_LAYERS, 360 | n_head=M3_HIDDEN_SIZE//64, 361 | n_inner=M3_HIDDEN_SIZE*4) 362 | model = M3Model(music_config, decoder_config) 363 | model.load_state_dict(checkpoint['model']) 364 | self.music_model = model.encoder 365 | model = None 366 | print(f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}") 367 | 368 | def avg_pooling(self, input_features, input_masks): 369 | input_masks = input_masks.unsqueeze(-1).to(self.device) # add a dimension to match the feature dimension 370 | input_features = input_features * input_masks # apply mask to input_features 371 | avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1) # calculate average pooling 372 | 373 | return avg_pool 374 | 375 | def get_text_features(self, 376 | text_inputs, 377 | text_masks, 378 | get_normalized=False): 379 | text_features = self.text_model(text_inputs.to(self.device), 380 | attention_mask=text_masks.to(self.device))['last_hidden_state'] 381 | 382 | if get_normalized: 383 | text_features = self.avg_pooling(text_features, text_masks) 384 | text_features = self.text_proj(text_features) 385 | 386 | return text_features 387 | 388 | def get_music_features(self, 389 | music_inputs, 390 | music_masks, 391 | get_normalized=False): 392 | music_features = self.music_model(music_inputs.to(self.device), 393 | music_masks.to(self.device))['last_hidden_state'] 394 | 395 | if get_normalized: 396 | music_features = self.avg_pooling(music_features, music_masks) 397 | music_features = self.music_proj(music_features) 398 | 399 | return music_features 400 | 401 | def forward(self, 402 | text_inputs, # [batch_size, seq_length] 403 | text_masks, # [batch_size, seq_length] 404 | music_inputs, # [batch_size, seq_length, hidden_size] 405 | music_masks): # [batch_size, seq_length] 
406 | # Compute the text features 407 | text_features = self.get_text_features(text_inputs, text_masks, get_normalized=True) 408 | 409 | # Compute the music features 410 | music_features = self.get_music_features(music_inputs, music_masks, get_normalized=True) 411 | 412 | return self.loss_fn(text_features, 413 | music_features, 414 | LOGIT_SCALE, 415 | output_dict=False) 416 | 417 | def split_data(data, eval_ratio=EVAL_SPLIT): 418 | random.shuffle(data) 419 | split_idx = int(len(data)*eval_ratio) 420 | eval_set = data[:split_idx] 421 | train_set = data[split_idx:] 422 | return train_set, eval_set 423 | 424 | def mask_patches(target_patches, patchilizer, mode): 425 | indices = list(range(len(target_patches))) 426 | random.shuffle(indices) 427 | selected_indices = indices[:math.ceil(M3_MASK_RATIO*len(indices))] 428 | sorted_indices = sorted(selected_indices) 429 | input_patches = torch.tensor(target_patches) 430 | 431 | if mode=="eval": 432 | choice = "original" 433 | else: 434 | choice = random.choices(["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1])[0] 435 | 436 | if choice=="mask": 437 | input_patches[sorted_indices] = torch.tensor([patchilizer.mask_token_id]*PATCH_SIZE) 438 | elif choice=="shuffle": 439 | for idx in sorted_indices: 440 | patch = input_patches[idx] 441 | try: 442 | index_eos = (patch == patchilizer.eos_token_id).nonzero().item() 443 | except: 444 | index_eos = len(patch) 445 | 446 | indices = list(range(1, index_eos)) 447 | random.shuffle(indices) 448 | indices = [0] + indices + list(range(index_eos, len(patch))) 449 | input_patches[idx] = patch[indices] 450 | 451 | selected_indices = torch.zeros(len(target_patches)) 452 | selected_indices[sorted_indices] = 1. 453 | 454 | return input_patches, selected_indices 455 | 456 | def remove_instrument_info(item): 457 | # remove instrument information from symbolic music 458 | lines = re.findall(r'.*?\n|.*$', item) 459 | lines = list(filter(None, lines)) 460 | if lines[0].split(" ")[0] == "ticks_per_beat": 461 | type = "mtf" 462 | else: 463 | type = "abc" 464 | 465 | cleaned_lines = [] 466 | for line in lines: 467 | if type=="abc" and line.startswith("V:"): 468 | # find the position of " nm=" or " snm=" 469 | nm_pos = line.find(" nm=") 470 | snm_pos = line.find(" snm=") 471 | # keep the part before " nm=" or " snm=" 472 | if nm_pos != -1: 473 | line = line[:nm_pos] 474 | elif snm_pos != -1: 475 | line = line[:snm_pos] 476 | if nm_pos != -1 or snm_pos != -1: 477 | line += "\n" 478 | elif type=="mtf" and line.startswith("program_change"): 479 | line = " ".join(line.split(" ")[:-1]) + " 0\n" 480 | 481 | cleaned_lines.append(line) 482 | 483 | return ''.join(cleaned_lines) 484 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: clamp2 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - blas=1.0=mkl 8 | - brotli-python=1.0.9=py310hd77b12b_8 9 | - bzip2=1.0.8=h2bbff1b_6 10 | - ca-certificates=2024.7.2=haa95532_0 11 | - cuda-cccl=12.6.37=0 12 | - cuda-cccl_win-64=12.6.37=0 13 | - cuda-cudart=11.8.89=0 14 | - cuda-cudart-dev=11.8.89=0 15 | - cuda-cupti=11.8.87=0 16 | - cuda-libraries=11.8.0=0 17 | - cuda-libraries-dev=11.8.0=0 18 | - cuda-nvrtc=11.8.89=0 19 | - cuda-nvrtc-dev=11.8.89=0 20 | - cuda-nvtx=11.8.86=0 21 | - cuda-profiler-api=12.6.68=0 22 | - cuda-runtime=11.8.0=0 23 | - cuda-version=12.6=3 24 | - freetype=2.12.1=ha860e81_0 25 | - 
gmpy2=2.1.2=py310h7f96b67_0 26 | - intel-openmp=2023.1.0=h59b6b97_46320 27 | - jinja2=3.1.4=py310haa95532_0 28 | - jpeg=9e=h827c3e9_3 29 | - lcms2=2.12=h83e58a3_0 30 | - lerc=3.0=hd77b12b_0 31 | - libcublas=11.11.3.6=0 32 | - libcublas-dev=11.11.3.6=0 33 | - libcufft=10.9.0.58=0 34 | - libcufft-dev=10.9.0.58=0 35 | - libcurand=10.3.7.68=0 36 | - libcurand-dev=10.3.7.68=0 37 | - libcusolver=11.4.1.48=0 38 | - libcusolver-dev=11.4.1.48=0 39 | - libcusparse=11.7.5.86=0 40 | - libcusparse-dev=11.7.5.86=0 41 | - libdeflate=1.17=h2bbff1b_1 42 | - libffi=3.4.4=hd77b12b_1 43 | - libjpeg-turbo=2.0.0=h196d8e1_0 44 | - libnpp=11.8.0.86=0 45 | - libnpp-dev=11.8.0.86=0 46 | - libnvjpeg=11.9.0.86=0 47 | - libnvjpeg-dev=11.9.0.86=0 48 | - libpng=1.6.39=h8cc25b3_0 49 | - libtiff=4.5.1=hd77b12b_0 50 | - libuv=1.48.0=h827c3e9_0 51 | - libwebp-base=1.3.2=h2bbff1b_0 52 | - lz4-c=1.9.4=h2bbff1b_1 53 | - mkl=2023.1.0=h6b88ed4_46358 54 | - mkl-service=2.4.0=py310h2bbff1b_1 55 | - mkl_fft=1.3.8=py310h2bbff1b_0 56 | - mkl_random=1.2.4=py310h59b6b97_0 57 | - mpc=1.1.0=h7edee0f_1 58 | - mpfr=4.0.2=h62dcd97_1 59 | - mpir=3.0.0=hec2e145_1 60 | - mpmath=1.3.0=py310haa95532_0 61 | - networkx=3.3=py310haa95532_0 62 | - numpy=1.26.4=py310h055cbcc_0 63 | - numpy-base=1.26.4=py310h65a83cf_0 64 | - openjpeg=2.5.2=hae555c5_0 65 | - openssl=3.0.14=h827c3e9_0 66 | - pip=24.2=py310haa95532_0 67 | - pysocks=1.7.1=py310haa95532_0 68 | - python=3.10.14=he1021f5_1 69 | - pytorch=2.4.0=py3.10_cuda11.8_cudnn9_0 70 | - pytorch-cuda=11.8=h24eeafa_5 71 | - pytorch-mutex=1.0=cuda 72 | - pyyaml=6.0.1=py310h2bbff1b_0 73 | - requests=2.32.3=py310haa95532_0 74 | - setuptools=72.1.0=py310haa95532_0 75 | - sqlite=3.45.3=h2bbff1b_0 76 | - sympy=1.13.2=py310haa95532_0 77 | - tbb=2021.8.0=h59b6b97_0 78 | - tk=8.6.14=h0416ee5_0 79 | - typing_extensions=4.11.0=py310haa95532_0 80 | - tzdata=2024a=h04d1e81_0 81 | - vc=14.40=h2eaa2aa_0 82 | - vs2015_runtime=14.40.33807=h98bb1dd_0 83 | - wheel=0.43.0=py310haa95532_0 84 | - win_inet_pton=1.1.0=py310haa95532_0 85 | - xz=5.4.6=h8cc25b3_1 86 | - yaml=0.2.5=he774522_0 87 | - zlib=1.2.13=h8cc25b3_1 88 | - zstd=1.5.5=hd43e919_2 89 | - pip: 90 | - abctoolkit==0.0.4 91 | - accelerate==0.34.0 92 | - aiohappyeyeballs==2.4.0 93 | - aiohttp==3.10.5 94 | - aiosignal==1.3.1 95 | - annotated-types==0.7.0 96 | - anyio==4.6.2.post1 97 | - async-timeout==4.0.3 98 | - attrs==24.2.0 99 | - audioread==3.0.1 100 | - certifi==2023.7.22 101 | - cffi==1.17.0 102 | - chardet==5.2.0 103 | - charset-normalizer==3.2.0 104 | - click==8.1.7 105 | - colorama==0.4.6 106 | - coloredlogs==15.0.1 107 | - cycler==0.11.0 108 | - datasets==2.21.0 109 | - decorator==5.1.1 110 | - dill==0.3.8 111 | - distro==1.9.0 112 | - docker-pycreds==0.4.0 113 | - exceptiongroup==1.2.2 114 | - filelock==3.12.2 115 | - fonttools==4.38.0 116 | - frozenlist==1.4.1 117 | - fsspec==2024.6.1 118 | - gitdb==4.0.11 119 | - gitpython==3.1.43 120 | - h11==0.14.0 121 | - httpcore==1.0.6 122 | - httpx==0.27.2 123 | - huggingface-hub==0.24.6 124 | - humanfriendly==10.0 125 | - idna==3.4 126 | - importlib-metadata==6.7.0 127 | - jellyfish==1.0.0 128 | - jiter==0.6.1 129 | - joblib==1.3.2 130 | - jsonpickle==3.0.2 131 | - kiwisolver==1.4.4 132 | - langcodes==3.4.0 133 | - langdetect==1.0.9 134 | - langid==1.1.6 135 | - language-data==1.2.0 136 | - lazy-loader==0.4 137 | - levenshtein==0.25.1 138 | - librosa==0.10.1 139 | - llvmlite==0.43.0 140 | - lxml==5.3.0 141 | - marisa-trie==1.2.0 142 | - markupsafe==2.1.5 143 | - matplotlib==3.5.3 144 | - mido==1.3.0 145 | - 
more-itertools==9.1.0 146 | - msgpack==1.0.8 147 | - multidict==6.0.5 148 | - multiprocess==0.70.16 149 | - music21==7.3.3 150 | - nltk==3.8.1 151 | - numba==0.60.0 152 | - openai==1.51.2 153 | - optimum==1.21.4 154 | - packaging==23.1 155 | - pandas==1.3.5 156 | - pillow==9.5.0 157 | - platformdirs==4.2.2 158 | - pooch==1.8.2 159 | - portalocker==2.10.1 160 | - protobuf==5.28.0 161 | - psutil==6.0.0 162 | - pyarrow==17.0.0 163 | - pycparser==2.22 164 | - pydantic==2.9.2 165 | - pydantic-core==2.23.4 166 | - pydub==0.25.1 167 | - pyparsing==3.1.1 168 | - pyreadline3==3.4.1 169 | - python-dateutil==2.8.2 170 | - pytz==2023.3 171 | - pywin32==306 172 | - rapidfuzz==3.9.7 173 | - rarfile==4.1 174 | - regex==2023.8.8 175 | - sacrebleu==2.4.3 176 | - sacremoses==0.0.53 177 | - safetensors==0.4.4 178 | - samplings==0.1.7 179 | - scikit-learn==1.5.1 180 | - scipy==1.14.1 181 | - sentencepiece==0.2.0 182 | - sentry-sdk==2.13.0 183 | - setproctitle==1.3.3 184 | - six==1.16.0 185 | - smmap==5.0.1 186 | - sniffio==1.3.1 187 | - soundfile==0.12.1 188 | - soxr==0.5.0.post1 189 | - tabulate==0.9.0 190 | - threadpoolctl==3.5.0 191 | - tokenizers==0.19.1 192 | - torch==2.4.0 193 | - torchaudio==2.4.0 194 | - torchvision==0.19.0 195 | - tqdm==4.66.5 196 | - transformers==4.40.0 197 | - typing-extensions==4.12.2 198 | - unidecode==1.3.6 199 | - urllib3==2.0.4 200 | - wandb==0.17.8 201 | - webcolors==1.13 202 | - xxhash==3.5.0 203 | - yarl==1.9.7 204 | - zipp==3.15.0 -------------------------------------------------------------------------------- /music_classification/README.md: -------------------------------------------------------------------------------- 1 | # Music Classification Codebase 2 | 3 | ## Overview 4 | Linear Probe is a powerful classification tool that leverages feature representations for supervised learning tasks. This codebase includes scripts for training a linear classification model, performing classification on new feature data. The features utilized can be extracted from the M3 or CLaMP 2 models, ensuring that the time dimension information is preserved and **not normalized**. Below is a description of the scripts contained in the `music_classification/` folder. 5 | 6 | ## Repository Structure 7 | The `music_classification/` folder contains the following scripts: 8 | 9 | ### 1. `config.py` 10 | This script defines configurations for the linear probe training and inference, specifying training data paths and parameters like learning rate, number of epochs, and hidden size. 11 | 12 | ### 2. `inference_cls.py` 13 | This script enables the classification of feature vectors using a pre-trained linear probe model. 14 | 15 | #### JSON Output Format 16 | The resulting JSON file contains a dictionary with the following structure: 17 | ```json 18 | { 19 | "path/to/feature1.npy": "class_A", 20 | "path/to/feature2.npy": "class_B", 21 | "path/to/feature3.npy": "class_A" 22 | } 23 | ``` 24 | - **Key**: The path to the input feature file (e.g., `feature1.npy`). 25 | - **Value**: The predicted class label assigned by the linear probe model (e.g., `class_A`). 26 | 27 | #### Usage 28 | ```bash 29 | python inference_cls.py 30 | ``` 31 | - `feature_folder`: Directory containing input feature files (in `.npy` format). 32 | - `output_file`: File path to save the classification results (in JSON format). 33 | 34 | ### 3. `train_cls.py` 35 | This script is designed for training the linear classification model. 36 | 37 | #### Usage 38 | ```bash 39 | python train_cls.py 40 | ``` 41 | 42 | ### 4. 
`utils.py` 43 | The utility script defines the architecture of the linear classification model. 44 | 45 | ## Naming Convention 46 | All `.npy` files used in this codebase must follow the naming convention of `label_filename.npy`, where the filename should not contain any underscores (`_`). 47 | -------------------------------------------------------------------------------- /music_classification/config.py: -------------------------------------------------------------------------------- 1 | # Configuration for generative modelling and classification 2 | TRAIN_FOLDERS = [ 3 | "" # Directory containing training data 4 | ] 5 | 6 | EVAL_FOLDERS = [ 7 | "" # (Optional) Directory containing evaluation data 8 | ] 9 | 10 | EVAL_SPLIT = 0.2 # Fraction of training data to use for evaluation 11 | 12 | # Weights and Biases configuration 13 | WANDB_KEY = "" # Set M3/CLaMP2_WANDB_LOG=False if no API key for Weights and Biases logging 14 | 15 | # Model Configuration 16 | INPUT_HIDDEN_SIZE = 768 # Input hidden size 17 | HIDDEN_SIZE = 768 # Model hidden size 18 | NUM_EPOCHS = 1000 # Max number of epochs to train (early stopping can terminate earlier) 19 | LEARNING_RATE = 1e-5 # Optimizer learning rate 20 | BALANCED_TRAINING = False # Set to True to balance labels in training data 21 | WANDB_LOG = False # Set to True to log training metrics to WANDB 22 | 23 | # Paths Configuration 24 | last_folder_name = TRAIN_FOLDERS[-1].split('/')[-1] 25 | WEIGHTS_PATH = f"weights-{last_folder_name}.pth" # Weights file path 26 | LOGS_PATH = f"logs-{last_folder_name}.txt" # Log file path 27 | -------------------------------------------------------------------------------- /music_classification/inference_cls.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import random 5 | import numpy as np 6 | from utils import * 7 | from tqdm import tqdm 8 | from samplings import * 9 | import argparse 10 | 11 | def list_files_in_directory(directories, extensions=["npy"]): 12 | file_list = [] 13 | 14 | for directory in directories: 15 | for root, dirs, files in os.walk(directory): 16 | for file in files: 17 | if any(file.endswith(ext) for ext in extensions): 18 | file_path = os.path.join(root, file) 19 | file_list.append(file_path) 20 | 21 | return file_list 22 | 23 | if __name__ == "__main__": 24 | # Setup argument parser 25 | parser = argparse.ArgumentParser(description="Feature extraction and classification with CLaMP2.") 26 | parser.add_argument("feature_folder", type=str, help="Directory containing input feature files.") 27 | parser.add_argument("output_file", type=str, help="File to save the classification results. 
(format: json)") 28 | 29 | # Parse arguments 30 | args = parser.parse_args() 31 | feature_folder = args.feature_folder 32 | output_file = args.output_file 33 | 34 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 35 | seed = 42 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu') 44 | print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with acc {checkpoint['max_eval_acc']}") 45 | label2idx = checkpoint['labels'] 46 | idx2label = {idx: label for label, idx in label2idx.items()} # Create reverse mapping 47 | model = LinearClassification(num_classes=len(label2idx)) 48 | model = model.to(device) 49 | 50 | # print parameter number 51 | print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad))) 52 | 53 | model.eval() 54 | model.load_state_dict(checkpoint['model']) 55 | 56 | # load filenames under train and eval folder 57 | feature_files = list_files_in_directory([feature_folder]) 58 | cls_results = {} 59 | 60 | for filepath in tqdm(feature_files): 61 | outputs = np.load(filepath)[0] 62 | outputs = torch.from_numpy(outputs).to(device) 63 | outputs = outputs.unsqueeze(0) 64 | cls_list = model(outputs)[0].tolist() 65 | max_prob = max(cls_list) 66 | cls_idx = cls_list.index(max_prob) 67 | cls_label = idx2label[cls_idx] 68 | cls_results[filepath] = cls_label 69 | 70 | with open(output_file, "w", encoding="utf-8") as f: 71 | json.dump(cls_results, f) 72 | -------------------------------------------------------------------------------- /music_classification/train_cls.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import wandb 5 | import torch 6 | import random 7 | import numpy as np 8 | from utils import * 9 | from config import * 10 | from tqdm import tqdm 11 | from sklearn.metrics import f1_score 12 | from torch.amp import autocast, GradScaler 13 | from torch.utils.data import Dataset, DataLoader 14 | from transformers import get_constant_schedule_with_warmup 15 | import torch.distributed as dist 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | from torch.utils.data.distributed import DistributedSampler 18 | 19 | # Set up distributed training 20 | world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 21 | global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0 22 | local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0 23 | 24 | if world_size > 1: 25 | torch.cuda.set_device(local_rank) 26 | device = torch.device("cuda", local_rank) 27 | dist.init_process_group(backend='nccl') if world_size > 1 else None 28 | else: 29 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 30 | 31 | # Set random seed 32 | seed = 42 + global_rank 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | torch.manual_seed(seed) 36 | torch.cuda.manual_seed_all(seed) 37 | torch.backends.cudnn.deterministic = True 38 | torch.backends.cudnn.benchmark = False 39 | 40 | batch_size = 1 41 | 42 | def collate_batch(input_tensors): 43 | 44 | input_tensors, labels = zip(*input_tensors) 45 | input_tensors = torch.stack(input_tensors, dim=0) 46 | labels = torch.stack(labels, dim=0) 47 | 48 | return input_tensors.to(device), 
labels.to(device) 49 | 50 | def list_files_in_directory(directories): 51 | file_list = [] 52 | 53 | for directory in directories: 54 | for root, dirs, files in os.walk(directory): 55 | for file in files: 56 | if file.endswith(".npy"): 57 | file_path = os.path.join(root, file) 58 | file_list.append(file_path) 59 | return file_list 60 | 61 | class TensorDataset(Dataset): 62 | def __init__(self, filenames): 63 | print(f"Loading {len(filenames)} files for classification") 64 | self.filenames = [] 65 | self.label2idx = {} 66 | 67 | for filename in tqdm(filenames): 68 | label = os.path.basename(filename).split('_')[0] 69 | 70 | self.filenames.append(filename) 71 | if label not in self.label2idx: 72 | self.label2idx[label] = len(self.label2idx) 73 | print(f"Found {len(self.label2idx)} classes") 74 | 75 | def __len__(self): 76 | return len(self.filenames) 77 | 78 | def __getitem__(self, idx): 79 | 80 | filename = self.filenames[idx] 81 | label = os.path.basename(filename).split('_')[0] 82 | label = self.label2idx[label] 83 | 84 | # load numpy file 85 | data = np.load(filename) 86 | data = torch.from_numpy(data)[0] 87 | label = torch.tensor(label) 88 | 89 | return data, label 90 | 91 | class BalancedTensorDataset(Dataset): 92 | def __init__(self, filenames): 93 | print(f"Loading {len(filenames)} files for classification") 94 | self.filenames = filenames 95 | self.label2idx = {} 96 | self.label2files = {} 97 | 98 | for filename in tqdm(filenames): 99 | label = os.path.basename(filename).split('_')[0] 100 | if label not in self.label2idx: 101 | self.label2idx[label] = len(self.label2idx) 102 | if label not in self.label2files: 103 | self.label2files[label] = [] 104 | self.label2files[label].append(filename) 105 | print(f"Found {len(self.label2idx)} classes") 106 | 107 | self.min_samples = min(len(files) for files in self.label2files.values()) 108 | 109 | self._update_epoch_filenames() 110 | 111 | def _update_epoch_filenames(self): 112 | self.epoch_filenames = [] 113 | for label, files in self.label2files.items(): 114 | sampled_files = random.sample(files, self.min_samples) 115 | self.epoch_filenames.extend(sampled_files) 116 | 117 | random.shuffle(self.epoch_filenames) 118 | 119 | def __len__(self): 120 | return len(self.epoch_filenames) 121 | 122 | def __getitem__(self, idx): 123 | filename = self.epoch_filenames[idx] 124 | label = os.path.basename(filename).split('_')[0] 125 | label = self.label2idx[label] 126 | 127 | data = np.load(filename) 128 | data = torch.from_numpy(data)[0] 129 | label = torch.tensor(label) 130 | 131 | return data, label 132 | 133 | def on_epoch_end(self): 134 | self._update_epoch_filenames() 135 | 136 | # load filenames under train and eval folder 137 | train_files = list_files_in_directory(TRAIN_FOLDERS) 138 | eval_files = list_files_in_directory(EVAL_FOLDERS) 139 | 140 | if len(eval_files)==0: 141 | random.shuffle(train_files) 142 | eval_files = train_files[:math.ceil(len(train_files)*EVAL_SPLIT)] 143 | train_files = train_files[math.ceil(len(train_files)*EVAL_SPLIT):] 144 | if BALANCED_TRAINING: 145 | train_set = BalancedTensorDataset(train_files) 146 | else: 147 | train_set = TensorDataset(train_files) 148 | eval_set = TensorDataset(eval_files) 149 | eval_set.label2idx = train_set.label2idx 150 | 151 | model = LinearClassification(num_classes=len(train_set.label2idx)) 152 | model = model.to(device) 153 | 154 | # print parameter number 155 | print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad))) 156 | 157 | if world_size > 1: 158 | 
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) 159 | 160 | scaler = GradScaler() 161 | is_autocast = True 162 | optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) 163 | loss_fn = torch.nn.CrossEntropyLoss() 164 | 165 | # call model with a batch of input 166 | def process_one_batch(batch): 167 | input_tensors, labels = batch 168 | logits = model(input_tensors) 169 | loss = loss_fn(logits, labels) 170 | prediction = torch.argmax(logits, dim=1) 171 | acc_num = torch.sum(prediction==labels) 172 | 173 | return loss, acc_num, prediction, labels 174 | 175 | # do one epoch for training 176 | def train_epoch(): 177 | tqdm_train_set = tqdm(train_set) 178 | total_train_loss = 0 179 | total_acc_num = 0 180 | iter_idx = 1 181 | model.train() 182 | 183 | for batch in tqdm_train_set: 184 | if is_autocast: 185 | with autocast(device_type='cuda'): 186 | loss, acc_num, prediction, labels = process_one_batch(batch) 187 | scaler.scale(loss).backward() 188 | scaler.step(optimizer) 189 | scaler.update() 190 | else: 191 | loss, acc_num, prediction, labels = process_one_batch(batch) 192 | loss.backward() 193 | optimizer.step() 194 | 195 | lr_scheduler.step() 196 | model.zero_grad(set_to_none=True) 197 | total_train_loss += loss.item() 198 | total_acc_num += acc_num.item() 199 | tqdm_train_set.set_postfix({str(global_rank)+'_train_acc': total_acc_num / (iter_idx*batch_size)}) 200 | # Log the training loss to wandb 201 | if global_rank==0 and WANDB_LOG: 202 | wandb.log({"acc": total_acc_num / (iter_idx*batch_size)}) 203 | 204 | iter_idx += 1 205 | 206 | if BALANCED_TRAINING: 207 | train_set.dataset.on_epoch_end() 208 | 209 | return total_acc_num / ((iter_idx-1)*batch_size) 210 | 211 | # do one epoch for eval 212 | def eval_epoch(): 213 | tqdm_eval_set = tqdm(eval_set) 214 | total_eval_loss = 0 215 | total_acc_num = 0 216 | iter_idx = 1 217 | model.eval() 218 | 219 | all_predictions = [] 220 | all_labels = [] 221 | 222 | # Evaluate data for one epoch 223 | for batch in tqdm_eval_set: 224 | with torch.no_grad(): 225 | loss, acc_num, prediction, labels = process_one_batch(batch) 226 | total_eval_loss += loss.item() 227 | total_acc_num += acc_num.item() 228 | 229 | # Accumulate predictions and labels 230 | all_predictions.extend(prediction.cpu().numpy()) 231 | all_labels.extend(labels.cpu().numpy()) 232 | 233 | tqdm_eval_set.set_postfix({str(global_rank)+'_eval_acc': total_acc_num / (iter_idx*batch_size)}) 234 | iter_idx += 1 235 | 236 | # Compute F1 Macro 237 | f1_macro = f1_score(all_labels, all_predictions, average='macro') 238 | return total_acc_num / ((iter_idx - 1) * batch_size), f1_macro 239 | 240 | # train and eval 241 | if __name__ == "__main__": 242 | 243 | label2idx = train_set.label2idx 244 | max_eval_acc = 0 245 | train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=global_rank) 246 | eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=global_rank) 247 | 248 | train_set = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_batch, sampler=train_sampler, shuffle = (train_sampler is None)) 249 | eval_set = DataLoader(eval_set, batch_size=batch_size, collate_fn=collate_batch, sampler=eval_sampler, shuffle = (train_sampler is None)) 250 | 251 | lr_scheduler = get_constant_schedule_with_warmup(optimizer = optimizer, num_warmup_steps = len(train_set)) 252 | 253 | model = model.to(device) 254 | optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) 255 | 256 | if WANDB_LOG and 
global_rank==0: 257 | # Initialize wandb 258 | if WANDB_KEY: 259 | wandb.login(key=WANDB_KEY) 260 | wandb.init(project="linear", 261 | name=WEIGHTS_PATH.replace("weights_", "").replace(".pth", "")) 262 | 263 | for epoch in range(1, NUM_EPOCHS+1): 264 | train_sampler.set_epoch(epoch) 265 | eval_sampler.set_epoch(epoch) 266 | print('-' * 21 + "Epoch " + str(epoch) + '-' * 21) 267 | train_acc = train_epoch() 268 | eval_acc, eval_f1_macro = eval_epoch() 269 | if global_rank==0: 270 | with open(LOGS_PATH,'a') as f: 271 | f.write("Epoch " + str(epoch) + "\ntrain_acc: " + str(train_acc) + "\neval_acc: " +str(eval_acc) + "\neval_f1_macro: " +str(eval_f1_macro) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n") 272 | if eval_acc > max_eval_acc: 273 | best_epoch = epoch 274 | max_eval_acc = eval_acc 275 | checkpoint = { 276 | 'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(), 277 | 'optimizer': optimizer.state_dict(), 278 | 'lr_sched': lr_scheduler.state_dict(), 279 | 'epoch': epoch, 280 | 'best_epoch': best_epoch, 281 | 'max_eval_acc': max_eval_acc, 282 | "labels": label2idx 283 | } 284 | torch.save(checkpoint, WEIGHTS_PATH) 285 | with open(LOGS_PATH,'a') as f: 286 | f.write("Best Epoch so far!\n\n\n") 287 | 288 | if world_size > 1: 289 | dist.barrier() 290 | 291 | if global_rank==0: 292 | print("Best Eval Epoch : "+str(best_epoch)) 293 | print("Max Eval Accuracy : "+str(max_eval_acc)) 294 | -------------------------------------------------------------------------------- /music_classification/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from config import * 3 | 4 | class LinearClassification(torch.nn.Module): 5 | def __init__(self, num_classes): 6 | super(LinearClassification, self).__init__() 7 | self.fc1 = torch.nn.Linear(INPUT_HIDDEN_SIZE, HIDDEN_SIZE) 8 | self.relu = torch.nn.ReLU() 9 | self.fc2 = torch.nn.Linear(HIDDEN_SIZE, num_classes) 10 | self.softmax = torch.nn.Softmax(dim=1) 11 | 12 | def forward(self, x): 13 | # Apply the linear layer and ReLU to each time step 14 | x = self.fc1(x) # x shape (B, L, H) -> (B, L, hidden_size) 15 | x = self.relu(x) 16 | 17 | # Average over the time steps (L dimension) 18 | x = x.mean(dim=1) # Now x has shape (B, hidden_size) 19 | 20 | x = self.fc2(x) # Now applying the final layer (B, hidden_size) -> (B, num_classes) 21 | x = self.softmax(x) 22 | return x -------------------------------------------------------------------------------- /overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanderwood/clamp2/8404488ef36a61735c0295a5afcd6d3a3d74bbcd/overview.jpg -------------------------------------------------------------------------------- /process_data/README.md: -------------------------------------------------------------------------------- 1 | # Data Processing Database 2 | 3 | ## Overview 4 | This codebase contains scripts and utilities for converting between various musical data formats, including ABC notation, MusicXML, MIDI, and MTF (MIDI Text Format). Additionally, it includes a script for summarizing music metadata, which is represented in JSON format containing textual information, using the OpenAI GPT-4 API. The GPT-4 model processes this metadata to generate concise summaries in multiple languages to boost multilingual MIR. 
These tools are designed to facilitate the transformation and manipulation of musical files, as well as to provide concise multilingual summaries of music metadata for use with CLaMP 2. 5 | 6 | 7 | ## About ABC notation 8 | ### Standard ABC Notation 9 | ABC notation (sheet music), a text-based sheet music representation like stave notation, is theory-oriented and ideal for presenting complex musical concepts to musicians for study and analysis. Standard ABC notation encodes each voice separately, which often results in corresponding bars being spaced far apart. This separation makes it difficult for models to accurately understand the interactions between voices in sheet music that are meant to align musically. 10 | 11 | Example Standard ABC notation representation: 12 | ``` 13 | %%score { 1 | 2 } 14 | L:1/8 15 | Q:1/4=120 16 | M:3/4 17 | K:G 18 | V:1 treble nm="Piano" snm="Pno." 19 | V:2 bass 20 | V:1 21 | !mf!"^Allegro" d2 (GA Bc | d2) .G2 .G2 |] 22 | V:2 23 | [G,B,D]4 A,2 | B,6 |] 24 | ``` 25 | 26 | ### Interleaved ABC Notation 27 | In contrast, interleaved ABC notation effectively aligns multi-track music by integrating multiple voices of the same bar into a single line, ensuring that all parts remain synchronized. This format combines voices in-line and tags each bar with its corresponding voice (e.g., `[V:1]` for treble and `[V:2]` for bass). By directly aligning related bars, interleaved ABC notation enhances the model’s understanding of how different voices interact within the same bar. 28 | 29 | Below is the same data optimized with M3 encoding, where each bar or header corresponds to a patch: 30 | ``` 31 | %%score { 1 | 2 } 32 | L:1/8 33 | Q:1/4=120 34 | M:3/4 35 | K:G 36 | V:1 treble nm="Piano" snm="Pno." 37 | V:2 bass 38 | [V:1]!mf!"^Allegro" d2 (GA Bc|[V:2][G,B,D]4 A,2| 39 | [V:1]d2) .G2 .G2|][V:2]B,6|] 40 | ``` 41 | 42 | ## About MTF 43 | ### Raw MIDI Messages 44 | MIDI (performance data) precisely encodes performance information related to timing and dynamics, thus suitable for music production and live performance. Raw MIDI messages contain essential musical instructions and metadata, extracted directly from a MIDI file. These include events like note on/off, tempo changes, key signatures, and control changes, which define how the music is performed. The [mido library](https://mido.readthedocs.io/) allows for reading these messages in their native format, as seen below. Each message can include multiple parameters, making the output comprehensive but sometimes redundant. 
45 | 46 | ``` 47 | MetaMessage ('time_signature', numerator=3, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0) 48 | MetaMessage('key_signature', key='G', time=0) 49 | MetaMessage('set_tempo', tempo=500000, time=0) 50 | control_change channel=0 control=121 value=0 time=0 51 | program_change channel=0 program=0 time=0 52 | control_change channel=0 control=7 value=100 time=0 53 | control_change channel=0 control=10 value=64 time=0 54 | control_change channel=0 control=91 value=0 time=0 55 | control_change channel=0 control=93 value=0 time=0 56 | MetaMessage('midi_port', port=0, time=0) 57 | note_on channel=0 note=74 velocity=80 time=0 58 | MetaMessage('key_signature', key='G', time=0) 59 | MetaMessage('midi_port', port=0, time=0) 60 | note_on channel=0 note=55 velocity=80 time=0 61 | note_on channel=0 note=59 velocity=80 time=0 62 | note_on channel=0 note=62 velocity=80 time=0 63 | note_on channel=0 note=74 velocity=0 time=455 64 | note_on channel=0 note=67 velocity=80 time=25 65 | note_on channel=0 note=67 velocity=0 time=239 66 | note_on channel=0 note=69 velocity=80 time=1 67 | note_on channel=0 note=55 velocity=0 time=191 68 | note_on channel=0 note=59 velocity=0 time=0 69 | note_on channel=0 note=62 velocity=0 time=0 70 | note_on channel=0 note=69 velocity=0 time=48 71 | note_on channel=0 note=71 velocity=80 time=1 72 | note_on channel=0 note=57 velocity=80 time=0 73 | note_on channel=0 note=71 velocity=0 time=239 74 | note_on channel=0 note=72 velocity=80 time=1 75 | note_on channel=0 note=57 velocity=0 time=215 76 | note_on channel=0 note=72 velocity=0 time=24 77 | note_on channel=0 note=74 velocity=80 time=1 78 | note_on channel=0 note=59 velocity=80 time=0 79 | note_on channel=0 note=74 velocity=0 time=455 80 | note_on channel=0 note=67 velocity=80 time=25 81 | note_on channel=0 note=67 velocity=0 time=239 82 | note_on channel=0 note=67 velocity=80 time=241 83 | note_on channel=0 note=67 velocity=0 time=239 84 | note_on channel=0 note=59 velocity=0 time=168 85 | MetaMessage('end_of_track', time=1) 86 | ``` 87 | ### MIDI Text Format (MTF) 88 | The MIDI Text Format (MTF) provides a structured, textual representation of MIDI data that preserves all original information without loss. Each MIDI message is accurately represented, allowing full reconstruction, ensuring no musical nuances are overlooked during conversion. 89 | 90 | To generate MTF, the mido library reads raw MIDI messages from MIDI files. The output retains all essential information but can be lengthy and redundant. To simplify the representation, parameter values are read in a fixed order and separated by spaces. For example, the raw time signature message, which contains several parameters—numerator, denominator, clocks per click, notated 32nd notes per beat, and time—is represented in MTF as: 91 | 92 | ``` 93 | time_signature 3 4 24 8 0 94 | ``` 95 | 96 | Other messages, such as control changes and note events, follow a similar compact format while preserving all relevant musical details. This structured simplification improves computational performance and maintains precise control over musical elements, including timing and dynamics. 
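To make this mapping concrete, here is a minimal sketch of the conversion using the mido library. It only illustrates the idea behind the `batch_midi2mtf.py` script described later in this README: the real script additionally filters out text-bearing messages and escapes string values, and the MIDI path below is just a placeholder.

```python
import mido

def midi_to_mtf(path):
    """Sketch: flatten a MIDI file into MTF-style lines (no filtering or escaping)."""
    mid = mido.MidiFile(path)
    # ticks_per_beat is emitted as the first line of the MTF text
    lines = ["ticks_per_beat " + str(mid.ticks_per_beat)]
    for msg in mid.merged_track:
        # msg.dict() yields the message type followed by its parameter values;
        # joining them with spaces gives the compact one-line form, e.g. a
        # time_signature message becomes "time_signature 3 4 24 8 0"
        lines.append(" ".join(str(value) for value in msg.dict().values()))
    return "\n".join(lines)

print(midi_to_mtf("example.mid"))  # "example.mid" is a placeholder path
```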
97 | 98 | Example MTF representation: 99 | ``` 100 | ticks_per_beat 480 101 | time_signature 3 4 24 8 0 102 | key_signature G 0 103 | set_tempo 500000 0 104 | control_change 0 0 121 0 105 | program_change 0 0 0 106 | control_change 0 0 7 100 107 | control_change 0 0 10 64 108 | control_change 0 0 91 0 109 | control_change 0 0 93 0 110 | midi_port 0 0 111 | note_on 0 0 74 80 112 | key_signature G 0 113 | midi_port 0 0 114 | note_on 0 0 55 80 115 | note_on 0 0 59 80 116 | note_on 0 0 62 80 117 | note_on 455 0 74 0 118 | note_on 25 0 67 80 119 | note_on 239 0 67 0 120 | note_on 1 0 69 80 121 | note_on 191 0 55 0 122 | note_on 0 0 59 0 123 | note_on 0 0 62 0 124 | note_on 48 0 69 0 125 | note_on 1 0 71 80 126 | note_on 0 0 57 80 127 | note_on 239 0 71 0 128 | note_on 1 0 72 80 129 | note_on 215 0 57 0 130 | note_on 24 0 72 0 131 | note_on 1 0 74 80 132 | note_on 0 0 59 80 133 | note_on 455 0 74 0 134 | note_on 25 0 67 80 135 | note_on 239 0 67 0 136 | note_on 241 0 67 80 137 | note_on 239 0 67 0 138 | note_on 168 0 59 0 139 | end_of_track 1 140 | ``` 141 | For simplicity, `ticks_per_beat`, though originally an attribute of MIDI objects in mido, is included as the first message at the beginning of the MTF representation. 142 | 143 | ### M3-Encoded MTF 144 | When processed using M3 encoding, consecutive messages of the same type that fit within a 64-character limit (the patch size of M3) are combined into a single line. Only the first message in each group specifies the type, with subsequent messages listing only the parameter values separated by tabs. This further simplifies the representation and improves processing efficiency. 145 | 146 | Below is the same data optimized with M3 encoding, where each line corresponds to a patch: 147 | ``` 148 | ticks_per_beat 480 149 | time_signature 3 4 24 8 0 150 | key_signature G 0 151 | set_tempo 500000 0 152 | control_change 0 0 121 0 153 | program_change 0 0 0 154 | control_change 0 0 7 100\t0 0 10 64\t0 0 91 0\t0 0 93 0 155 | midi_port 0 0 156 | note_on 0 0 74 80 157 | key_signature G 0 158 | midi_port 0 0 159 | note_on 0 0 55 80\t0 0 59 80\t0 0 62 80\t455 0 74 0\t25 0 67 80 160 | note_on 239 0 67 0\t1 0 69 80\t191 0 55 0\t0 0 59 0\t0 0 62 0 161 | note_on 48 0 69 0\t1 0 71 80\t0 0 57 80\t239 0 71 0\t1 0 72 80 162 | note_on 215 0 57 0\t24 0 72 0\t1 0 74 80\t0 0 59 80\t455 0 74 0 163 | note_on 25 0 67 80\t239 0 67 0\t0 67 80\t239 0 67 0\t168 0 59 0 164 | end_of_track 1 165 | ``` 166 | 167 | By reducing redundancy, M3 encoding ensures improved computational performance while maintaining precise timing and musical control, making it an ideal choice for efficient MIDI processing. 168 | 169 | ## Repository Structure 170 | The `process_data/` folder includes the following scripts and utility files: 171 | 172 | ### 1. **Conversion Scripts** 173 | 174 | #### `batch_abc2xml.py` 175 | - **Purpose**: Converts ABC notation files into MusicXML format. 176 | - **Input**: Directory of interleaved ABC files (modify the `input_dir` variable in the code). 177 | - **Output**: MusicXML files saved in a newly created `_xml` directory. 178 | - **Logging**: Errors are logged to `logs/abc2xml_error_log.txt`. 179 | 180 | #### `batch_xml2abc.py` 181 | - **Purpose**: Converts MusicXML files into standard ABC notation format. 182 | - **Input**: Directory of MusicXML files (e.g., `.xml`, `.mxl`, `.musicxml`) (modify the `input_dir` variable in the code). 183 | - **Output**: Standard ABC files saved in a newly created `_abc` directory. 
184 | - **Logging**: Errors are logged to `logs/xml2abc_error_log.txt`. 185 | 186 | #### `batch_interleaved_abc.py` 187 | - **Purpose**: Processes standard ABC notation files into interleaved ABC notation. 188 | - **Input**: Directory of ABC files (modify the `input_dir` variable in the code). 189 | - **Output**: Interleaved ABC files saved in a newly created `_interleaved` directory. 190 | - **Logging**: Any processing errors are printed to the console. 191 | 192 | #### `batch_midi2mtf.py` 193 | - **Purpose**: Converts MIDI files into MIDI Text Format (MTF). 194 | - **Input**: Directory of MIDI files (e.g., `.mid`, `.midi`) (modify the `input_dir` variable in the code). 195 | - **Output**: MTF files saved in a newly created `_mtf` directory. 196 | - **Logging**: Errors are logged to `logs/midi2mtf_error_log.txt`. 197 | - **Note**: The script includes an `m3_compatible` variable, which is set to `True` by default. When `True`, the conversion omits messages whose parameters are strings or lists to eliminate potential natural language information. This ensures that the converted MTF files align with the data format used for training the M3 and CLaMP 2 pretrained weights. 198 | 199 | #### `batch_mtf2midi.py` 200 | - **Purpose**: Converts MTF files into MIDI format. 201 | - **Input**: Directory of MTF files (modify the `input_dir` variable in the code). 202 | - **Output**: MIDI files saved in a newly created `_midi` directory. 203 | - **Logging**: Errors are logged to `logs/mtf2midi_error_log.txt`. 204 | 205 | ### 2. **Summarization Script** 206 | 207 | #### `gpt4_summarize.py` 208 | - **Purpose**: Utilizes the OpenAI GPT-4 API to generate concise summaries of music metadata in multiple languages. The script filters out any entries that lack sufficient musical information to ensure meaningful summaries are produced. 209 | - **Input**: Directory of JSON files containing music metadata (modify the `input_dir` variable in the code). For any missing metadata fields, the corresponding keys can be set to `None`. Each JSON file corresponds to a single musical composition and can be linked to both ABC notation and MTF formats. Here’s an example of the required metadata format: 210 | 211 | ```json 212 | { 213 | "title": "Hard Times Come Again No More", 214 | "composer": "Stephen Foster", 215 | "genres": ["Children's Music", "Folk"], 216 | "description": "\"Hard Times Come Again No More\" (sometimes referred to as \"Hard Times\") is an American parlor song written by Stephen Foster, reflecting themes of sorrow and hope.", 217 | "lyrics": "Let us pause in life's pleasures and count its many tears,\nWhile we all sup sorrow with the poor;\nThere's a song that will linger forever in our ears;\nOh! Hard times come again no more.\n\nChorus:\n'Tis the song, the sigh of the weary,\nHard Times, hard times, come again no more.\nMany days you have lingered around my cabin door;\nOh! Hard times come again no more.\n\nWhile we seek mirth and beauty and music light and gay,\nThere are frail forms fainting at the door;\nThough their voices are silent, their pleading looks will say\nOh! Hard times come again no more.\nChorus\n\nThere's a pale weeping maiden who toils her life away,\nWith a worn heart whose better days are o'er:\nThough her voice would be merry, 'tis sighing all the day,\nOh! Hard times come again no more.\nChorus\n\n'Tis a sigh that is wafted across the troubled wave,\n'Tis a wail that is heard upon the shore\n'Tis a dirge that is murmured around the lowly grave\nOh! 
Hard times come again no more.\nChorus", 218 | "tags": ["folk", "traditional", "bluegrass", "nostalgic", "heartfelt", "acoustic", "melancholic", "storytelling", "American roots", "resilience"], 219 | "ensembles": ["Folk Ensemble"], 220 | "instruments": ["Vocal", "Violin", "Tin whistle", "Guitar", "Banjo", "Tambourine"], 221 | "filepaths": [ 222 | "abc/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.abc", 223 | "mtf/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.mtf" 224 | ] 225 | } 226 | ``` 227 | 228 | - **Output**: JSON files containing structured summaries in both English and a randomly selected non-English language, chosen from a selection of 100 different non-English languages (in this case, Simplified Chinese). Here’s an example of the expected output format: 229 | 230 | ```json 231 | { 232 | "title": "Hard Times Come Again No More", 233 | "composer": "Stephen Foster", 234 | "genres": ["Children's Music", "Folk"], 235 | "description": "\"Hard Times Come Again No More\" (sometimes referred to as \"Hard Times\") is an American parlor song written by Stephen Foster, reflecting themes of sorrow and hope.", 236 | "lyrics": "Let us pause in life's pleasures and count its many tears,\nWhile we all sup sorrow with the poor;\nThere's a song that will linger forever in our ears;\nOh! Hard times come again no more.\n\nChorus:\n'Tis the song, the sigh of the weary,\nHard Times, hard times, come again no more.\nMany days you have lingered around my cabin door;\nOh! Hard times come again no more.\n\nWhile we seek mirth and beauty and music light and gay,\nThere are frail forms fainting at the door;\nThough their voices are silent, their pleading looks will say\nOh! Hard times come again no more.\nChorus\n\nThere's a pale weeping maiden who toils her life away,\nWith a worn heart whose better days are o'er:\nThough her voice would be merry, 'tis sighing all the day,\nOh! Hard times come again no more.\nChorus\n\n'Tis a sigh that is wafted across the troubled wave,\n'Tis a wail that is heard upon the shore\n'Tis a dirge that is murmured around the lowly grave\nOh! Hard times come again no more.\nChorus", 237 | "tags": ["folk", "traditional", "bluegrass", "nostalgic", "heartfelt", "acoustic", "melancholic", "storytelling", "American roots", "resilience"], 238 | "ensembles": ["Folk Ensemble"], 239 | "instruments": ["Vocal", "Violin", "Tin whistle", "Guitar", "Banjo", "Tambourine"], 240 | "summary_en": "\"Hard Times Come Again No More,\" composed by Stephen Foster, is a poignant American parlor song that explores themes of sorrow and hope. The lyrics reflect on the contrast between life's pleasures and its hardships, inviting listeners to acknowledge both joy and suffering. With a heartfelt chorus that repeats the line \"Hard times come again no more,\" the song resonates with nostalgia and resilience. 
It is often performed by folk ensembles and features a variety of instruments, including vocals, violin, guitar, and banjo, encapsulating the spirit of American roots music.", 241 | "summary_nen": { 242 | "language": "Chinese (Simplified)", 243 | "summary": "《艰难时光再无来临》是斯蒂芬·福斯特创作的一首感人至深的美国小歌厅歌曲,探讨了悲伤与希望的主题。歌词展现了生活的乐趣与艰辛之间的对比,邀请听众去感受快乐与痛苦的交织。歌曲中那句反复吟唱的“艰难时光再无来临”深情地表达了怀旧与坚韧。它常常由民谣乐队演奏,伴随着人声、小提琴、吉他和班卓琴等多种乐器,生动地展现了美国根源音乐的独特魅力。" 244 | }, 245 | "filepaths": [ 246 | "abc/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.abc", 247 | "mtf/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.mtf" 248 | ] 249 | } 250 | ``` 251 | 252 | - **Logging**: Errors are logged to `logs/gpt4_summarize_error_log.txt`. 253 | 254 | ### 3. **Utilities** 255 | - **`utils/`**: Contains utility files required for the conversion processes. 256 | 257 | ## Usage 258 | To use the scripts, modify the `input_dir` variable in each script to point to the directory containing your input files. Then run the script from the command line. Below are example commands for each script: 259 | 260 | ### Example Commands 261 | ```bash 262 | # Modify the input_dir variable in the script before running 263 | python batch_abc2xml.py 264 | python batch_xml2abc.py 265 | python batch_interleaved_abc.py 266 | python batch_midi2mtf.py 267 | python batch_mtf2midi.py 268 | python gpt4_summarize.py 269 | ``` 270 | 271 | ### Execution Order 272 | To achieve specific conversions, follow the order below: 273 | 274 | 1. **To obtain interleaved ABC notation**: 275 | - First, run `batch_xml2abc.py` to convert MusicXML files to ABC notation. 276 | - Then, run `batch_interleaved_abc.py` to process the ABC files into interleaved ABC notation. 277 | 278 | 2. **To obtain MTF**: 279 | - Run `batch_midi2mtf.py` to convert MIDI files into MTF. 280 | 281 | 3. **To convert interleaved ABC back to XML**: 282 | - Run `batch_xml2abc.py` on the interleaved ABC files to convert them back to MusicXML format. 283 | 284 | 4. **To convert MTF back to MIDI**: 285 | - Run `batch_mtf2midi.py` to convert MTF files back to MIDI format. 286 | 287 | 5. **To summarize music metadata**: 288 | - Run `gpt4_summarize.py` to generate summaries for the music metadata files in JSON format. This assumes you have a directory of JSON files that includes a `filepaths` key, which connects to the corresponding interleaved ABC and MTF files. 289 | 290 | ### Parameters 291 | To run the scripts, you need to configure the following parameters: 292 | 293 | - **`input_dir`**: This variable should be set to the directory containing the input files to be processed (such as ABC, MusicXML, MIDI, MTF, or JSON files), which is shared across all scripts. 294 | 295 | In addition to **`input_dir`**, the following parameters are specific to certain scripts: 296 | 297 | - **`m3_compatible`** (specific to `batch_midi2mtf.py`): 298 | - Default is `True`, which omits messages with parameters that are strings or lists to avoid including potential natural language information. 299 | - Setting this to `False` retains all MIDI messages, which is crucial for those planning to retrain models on custom datasets or needing precise MIDI reproduction. 300 | 301 | For **`gpt4_summarize.py`**, you also need to configure these parameters: 302 | 303 | 1. **`base_url`**: The base URL for the OpenAI API, used to initialize the client. 304 | 2. **`api_key`**: Your API key for authenticating requests, required for client initialization. 305 | 3. 
**`model`**: The GPT-4 model to use, specified when generating summaries. 306 | 307 | **Important**: When `m3_compatible` is set to `True`, the conversion back from MTF to MIDI using `batch_mtf2midi.py` may produce MIDI files that do not exactly match the original MIDI files. This discrepancy is unexpected; however, retraining both M3 and CLaMP 2 to address this issue would require approximately 6000 hours of H800 GPU hours. Considering that M3 and CLaMP 2 have already achieved state-of-the-art results on MIDI tasks, we have opted not to retrain. Therefore, if consistency with original MIDI files is critical for your application, it is advisable to set `m3_compatible` to `False`. 308 | -------------------------------------------------------------------------------- /process_data/batch_abc2xml.py: -------------------------------------------------------------------------------- 1 | input_dir = "" # Replace with the path to your folder containing interleaved ABC (.abc) files 2 | 3 | import os 4 | import math 5 | import random 6 | import subprocess 7 | from tqdm import tqdm 8 | from multiprocessing import Pool 9 | 10 | def convert_abc2xml(file_list): 11 | cmd = 'cmd /u /c python utils/abc2xml.py ' 12 | for file in tqdm(file_list): 13 | filename = file.split('/')[-1] # Extract file name 14 | output_dir = file.split('/')[:-1] # Extract directory path 15 | output_dir[0] = output_dir[0] + '_xml' # Create new output folder 16 | output_dir = '/'.join(output_dir) 17 | os.makedirs(output_dir, exist_ok=True) 18 | 19 | try: 20 | p = subprocess.Popen(cmd + '"' + file + '"', stdout=subprocess.PIPE, shell=True) 21 | result = p.communicate() 22 | output = result[0].decode('utf-8') 23 | 24 | if output == '': 25 | with open("logs/abc2xml_error_log.txt", "a", encoding="utf-8") as f: 26 | f.write(file + '\n') 27 | continue 28 | else: 29 | output_path = f"{output_dir}/" + ".".join(filename.split(".")[:-1]) + ".xml" 30 | with open(output_path, 'w', encoding='utf-8') as f: 31 | f.write(output) 32 | except Exception as e: 33 | with open("logs/abc2xml_error_log.txt", "a", encoding="utf-8") as f: 34 | f.write(file + ' ' + str(e) + '\n') 35 | pass 36 | 37 | if __name__ == '__main__': 38 | file_list = [] 39 | os.makedirs("logs", exist_ok=True) 40 | 41 | # Traverse the specified folder for ABC files 42 | for root, dirs, files in os.walk(input_dir): 43 | for file in files: 44 | if not file.endswith(".abc"): 45 | continue 46 | filename = os.path.join(root, file).replace("\\", "/") 47 | file_list.append(filename) 48 | 49 | # Prepare for multiprocessing 50 | file_lists = [] 51 | random.shuffle(file_list) 52 | for i in range(os.cpu_count()): 53 | start_idx = int(math.floor(i * len(file_list) / os.cpu_count())) 54 | end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count())) 55 | file_lists.append(file_list[start_idx:end_idx]) 56 | 57 | pool = Pool(processes=os.cpu_count()) 58 | pool.map(convert_abc2xml, file_lists) 59 | -------------------------------------------------------------------------------- /process_data/batch_interleaved_abc.py: -------------------------------------------------------------------------------- 1 | input_dir = "" # Replace with the path to your folder containing standard ABC (.abc) files 2 | 3 | import os 4 | import re 5 | import random 6 | from multiprocessing import Pool 7 | from tqdm import tqdm 8 | from abctoolkit.utils import ( 9 | find_all_abc, 10 | remove_information_field, 11 | remove_bar_no_annotations, 12 | Quote_re, 13 | Barlines, 14 | strip_empty_bars 15 | ) 16 | from 
abctoolkit.rotate import rotate_abc 17 | from abctoolkit.check import check_alignment_unrotated 18 | 19 | def abc_pipeline(abc_path, input_dir, output_dir): 20 | """ 21 | Converts standard ABC notation to interleaved ABC notation. 22 | """ 23 | with open(abc_path, 'r', encoding='utf-8') as f: 24 | abc_lines = f.readlines() 25 | 26 | abc_lines = [line for line in abc_lines if line.strip() != ''] 27 | abc_lines = remove_information_field(abc_lines=abc_lines, 28 | info_fields=['X:', 'T:', 'C:', 'W:', 'w:', 'Z:', '%%MIDI']) 29 | abc_lines = remove_bar_no_annotations(abc_lines) 30 | 31 | # Remove escaped quotes and clean up barlines inside quotes 32 | for i, line in enumerate(abc_lines): 33 | if not (re.search(r'^[A-Za-z]:', line) or line.startswith('%')): 34 | abc_lines[i] = line.replace(r'\"', '') 35 | quote_contents = re.findall(Quote_re, line) 36 | for quote_content in quote_contents: 37 | for barline in Barlines: 38 | if barline in quote_content: 39 | line = line.replace(quote_content, '') 40 | abc_lines[i] = line 41 | 42 | try: 43 | stripped_abc_lines, bar_counts = strip_empty_bars(abc_lines) 44 | except Exception as e: 45 | print(abc_path, 'Error in stripping empty bars:', e) 46 | return 47 | 48 | if stripped_abc_lines is None: 49 | print(abc_path, 'Failed to strip') 50 | return 51 | 52 | # Check alignment 53 | _, bar_no_equal_flag, bar_dur_equal_flag = check_alignment_unrotated(stripped_abc_lines) 54 | if not bar_no_equal_flag: 55 | print(abc_path, 'Unequal bar number') 56 | if not bar_dur_equal_flag: 57 | print(abc_path, 'Unequal bar duration (unaligned)') 58 | 59 | # Construct the output path, maintaining input folder structure 60 | relative_path = os.path.relpath(abc_path, input_dir) # Get relative path from input dir 61 | output_file_path = os.path.join(output_dir, os.path.normpath(relative_path)) # Recreate output path 62 | os.makedirs(os.path.dirname(output_file_path), exist_ok=True) # Ensure output folder exists 63 | 64 | try: 65 | rotated_abc_lines = rotate_abc(stripped_abc_lines) 66 | except Exception as e: 67 | print(abc_path, 'Error in rotating:', e) 68 | return 69 | 70 | if rotated_abc_lines is None: 71 | print(abc_path, 'Failed to rotate') 72 | return 73 | 74 | with open(output_file_path, 'w', encoding='utf-8') as w: 75 | w.writelines(rotated_abc_lines) 76 | 77 | 78 | def abc_pipeline_list(abc_path_list, input_dir, output_dir): 79 | for abc_path in tqdm(abc_path_list): 80 | try: 81 | abc_pipeline(abc_path, input_dir, output_dir) 82 | except Exception as e: 83 | print(abc_path, e) 84 | pass 85 | 86 | 87 | def batch_abc_pipeline(input_dir): 88 | """ 89 | Batch process all ABC files from `input_dir`, converting them to interleaved notation. 
90 | """ 91 | output_dir = input_dir + "_interleaved" 92 | os.makedirs(output_dir, exist_ok=True) 93 | 94 | abc_path_list = [abc_path for abc_path in find_all_abc(input_dir) if os.path.getsize(abc_path) > 0] 95 | random.shuffle(abc_path_list) 96 | print(f"Found {len(abc_path_list)} ABC files.") 97 | 98 | num_cpus = os.cpu_count() 99 | split_lists = [abc_path_list[i::num_cpus] for i in range(num_cpus)] 100 | 101 | with Pool(processes=num_cpus) as pool: 102 | pool.starmap(abc_pipeline_list, [(split, input_dir, output_dir) for split in split_lists]) 103 | 104 | 105 | if __name__ == '__main__': 106 | batch_abc_pipeline(input_dir) 107 | 108 | -------------------------------------------------------------------------------- /process_data/batch_midi2mtf.py: -------------------------------------------------------------------------------- 1 | input_dir = "" # Replace with the path to your folder containing MIDI (.midi, .mid) files 2 | m3_compatible = True # Set to True for M3 compatibility; set to False to retain all MIDI information during conversion. 3 | 4 | import os 5 | import math 6 | import mido 7 | import random 8 | from tqdm import tqdm 9 | from multiprocessing import Pool 10 | 11 | def msg_to_str(msg): 12 | str_msg = "" 13 | for key, value in msg.dict().items(): 14 | str_msg += " " + str(value) 15 | return str_msg.strip().encode('unicode_escape').decode('utf-8') 16 | 17 | 18 | def load_midi(filename): 19 | # Load a MIDI file 20 | mid = mido.MidiFile(filename) 21 | msg_list = ["ticks_per_beat " + str(mid.ticks_per_beat)] 22 | 23 | # Traverse the MIDI file 24 | for msg in mid.merged_track: 25 | if m3_compatible: 26 | if msg.is_meta: 27 | if msg.type in [ 28 | "text", "copyright", "track_name", "instrument_name", "lyrics", "marker", "cue_marker", 29 | "device_name", "sequencer_specific" 30 | ]: 31 | continue 32 | else: 33 | if msg.type in ["sysex"]: 34 | continue 35 | str_msg = msg_to_str(msg) 36 | msg_list.append(str_msg) 37 | 38 | return "\n".join(msg_list) 39 | 40 | 41 | def convert_midi2mtf(file_list): 42 | for file in tqdm(file_list): 43 | filename = os.path.basename(file) 44 | output_dir = os.path.join(os.path.dirname(file) + '_mtf') 45 | os.makedirs(output_dir, exist_ok=True) 46 | 47 | try: 48 | output = load_midi(file) 49 | log_path = os.path.join("logs", "midi2mtf_error_log.txt") 50 | if output == '': 51 | with open(log_path, 'a', encoding='utf-8') as f: 52 | f.write(file + '\n') 53 | continue 54 | else: 55 | output_file = os.path.join(output_dir, ".".join(filename.split(".")[:-1]) + '.mtf') 56 | with open(output_file, 'w', encoding='utf-8') as f: 57 | f.write(output) 58 | except Exception as e: 59 | log_path = os.path.join("logs", "midi2mtf_error_log.txt") 60 | with open(log_path, 'a', encoding='utf-8') as f: 61 | f.write(file + " " + str(e) + '\n') 62 | pass 63 | 64 | 65 | if __name__ == '__main__': 66 | file_list = [] 67 | os.makedirs("logs", exist_ok=True) 68 | 69 | # Traverse the specified folder for MIDI files 70 | for root, dirs, files in os.walk(input_dir): 71 | for file in files: 72 | if not file.endswith(".mid") and not file.endswith(".midi"): 73 | continue 74 | filename = os.path.join(root, file) 75 | file_list.append(filename) 76 | 77 | # Prepare for multiprocessing 78 | file_lists = [] 79 | random.shuffle(file_list) 80 | for i in range(os.cpu_count()): 81 | start_idx = int(math.floor(i * len(file_list) / os.cpu_count())) 82 | end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count())) 83 | file_lists.append(file_list[start_idx:end_idx]) 84 | 85 | pool = 
Pool(processes=os.cpu_count()) 86 | pool.map(convert_midi2mtf, file_lists) 87 | -------------------------------------------------------------------------------- /process_data/batch_mtf2midi.py: -------------------------------------------------------------------------------- 1 | input_dir = "" # Replace with the path to your folder containing MTF (.mtf) files 2 | 3 | import os 4 | import math 5 | import mido 6 | import random 7 | from tqdm import tqdm 8 | from multiprocessing import Pool 9 | 10 | def str_to_msg(str_msg): 11 | type = str_msg.split(" ")[0] 12 | try: 13 | msg = mido.Message(type) 14 | except: 15 | msg = mido.MetaMessage(type) 16 | 17 | if type in ["text", "copyright", "track_name", "instrument_name", 18 | "lyrics", "marker", "cue_marker", "device_name"]: 19 | values = [type, " ".join(str_msg.split(" ")[1:-1]).encode('utf-8').decode('unicode_escape'), str_msg.split(" ")[-1]] 20 | elif "[" in str_msg or "(" in str_msg: 21 | is_bracket = "[" in str_msg 22 | left_idx = str_msg.index("[") if is_bracket else str_msg.index("(") 23 | right_idx = str_msg.index("]") if is_bracket else str_msg.index(")") 24 | list_str = [int(num) for num in str_msg[left_idx+1:right_idx].split(", ")] 25 | if not is_bracket: 26 | list_str = tuple(list_str) 27 | values = str_msg[:left_idx].split(" ") + [list_str] + str_msg[right_idx+1:].split(" ") 28 | values = [value for value in values if value != ""] 29 | else: 30 | values = str_msg.split(" ") 31 | 32 | if len(values) != 1: 33 | for idx, (key, content) in enumerate(msg.__dict__.items()): 34 | if key == "type": 35 | continue 36 | value = values[idx] 37 | if isinstance(content, int) or isinstance(content, float): 38 | float_value = float(value) 39 | value = float_value 40 | if value % 1 == 0: 41 | value = int(value) 42 | setattr(msg, key, value) 43 | 44 | return msg 45 | 46 | def convert_mtf2midi(file_list): 47 | for file in tqdm(file_list): 48 | filename = file.split('/')[-1] 49 | output_dir = file.split('/')[:-1] 50 | output_dir[0] = output_dir[0] + '_midi' 51 | output_dir = '/'.join(output_dir) 52 | os.makedirs(output_dir, exist_ok=True) 53 | try: 54 | with open(file, 'r', encoding='utf-8') as f: 55 | msg_list = f.read().splitlines() 56 | 57 | # Build a new MIDI file based on the MIDI messages 58 | new_mid = mido.MidiFile() 59 | new_mid.ticks_per_beat = int(msg_list[0].split(" ")[1]) 60 | 61 | track = mido.MidiTrack() 62 | new_mid.tracks.append(track) 63 | 64 | for msg in msg_list[1:]: 65 | if "unknown_meta" in msg: 66 | continue 67 | new_msg = str_to_msg(msg) 68 | track.append(new_msg) 69 | 70 | output_file_path = os.path.join(output_dir, os.path.basename(file).replace('.mtf', '.mid')) 71 | new_mid.save(output_file_path) 72 | except Exception as e: 73 | with open('logs/mtf2midi_error_log.txt', 'a', encoding='utf-8') as f: 74 | f.write(f"Error processing {file}: {str(e)}\n") 75 | 76 | if __name__ == '__main__': 77 | file_list = [] 78 | os.makedirs("logs", exist_ok=True) 79 | 80 | # Traverse the specified folder for MTF files 81 | for root, dirs, files in os.walk(input_dir): 82 | for file in files: 83 | if not file.endswith(".mtf"): 84 | continue 85 | filename = os.path.join(root, file).replace("\\", "/") 86 | file_list.append(filename) 87 | 88 | # Prepare for multiprocessing 89 | file_lists = [] 90 | random.shuffle(file_list) 91 | for i in range(os.cpu_count()): 92 | start_idx = int(math.floor(i * len(file_list) / os.cpu_count())) 93 | end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count())) 94 | 
file_lists.append(file_list[start_idx:end_idx]) 95 | 96 | pool = Pool(processes=os.cpu_count()) 97 | pool.map(convert_mtf2midi, file_lists) 98 | -------------------------------------------------------------------------------- /process_data/batch_xml2abc.py: -------------------------------------------------------------------------------- 1 | input_dir = "" # Replace with the path to your folder containing XML (.xml, .mxl, .musicxml) files 2 | 3 | import os 4 | import math 5 | import random 6 | import subprocess 7 | from tqdm import tqdm 8 | from multiprocessing import Pool 9 | 10 | 11 | def convert_xml2abc(file_list): 12 | cmd = 'python utils/xml2abc.py -d 8 -x ' 13 | for file in tqdm(file_list): 14 | filename = os.path.basename(file) 15 | output_dir = os.path.dirname(file) 16 | output_dir = os.path.join(output_dir + '_abc') # Add '_abc' to the output directory 17 | os.makedirs(output_dir, exist_ok=True) 18 | 19 | try: 20 | p = subprocess.Popen(cmd + '"' + file + '"', stdout=subprocess.PIPE, shell=True) 21 | result = p.communicate() 22 | output = result[0].decode('utf-8') 23 | 24 | if output == '': 25 | with open("logs/xml2abc_error_log.txt", "a", encoding="utf-8") as f: 26 | f.write(file + '\n') 27 | continue 28 | else: 29 | with open(os.path.join(output_dir, filename.rsplit('.', 1)[0] + '.abc'), 'w', encoding='utf-8') as f: 30 | f.write(output) 31 | except Exception as e: 32 | with open("logs/xml2abc_error_log.txt", "a", encoding="utf-8") as f: 33 | f.write(file + ' ' + str(e) + '\n') 34 | 35 | 36 | if __name__ == '__main__': 37 | file_list = [] 38 | os.makedirs("logs", exist_ok=True) 39 | 40 | # Traverse the specified folder for XML/MXL files 41 | for root, dirs, files in os.walk(os.path.abspath(input_dir)): 42 | for file in files: 43 | if file.endswith((".mxl", ".xml", ".musicxml")): 44 | filename = os.path.join(root, file).replace("\\", "/") 45 | file_list.append(filename) 46 | 47 | # Shuffle and prepare for multiprocessing 48 | random.shuffle(file_list) 49 | num_files = len(file_list) 50 | num_processes = os.cpu_count() 51 | file_lists = [file_list[i::num_processes] for i in range(num_processes)] 52 | 53 | # Create a pool for processing 54 | with Pool(processes=num_processes) as pool: 55 | pool.map(convert_xml2abc, file_lists) -------------------------------------------------------------------------------- /process_data/gpt4_summarize.py: -------------------------------------------------------------------------------- 1 | input_dir = "" # Replace with the path to your folder containing metadata (.json) files 2 | base_url = "" # Replace with the base URL for the API 3 | api_key = "" # Replace with your API key 4 | model = "" # Replace with your model name 5 | 6 | import os 7 | import json 8 | import random 9 | from openai import OpenAI 10 | 11 | # Initialize the OpenAI client 12 | client = OpenAI(base_url=base_url, api_key=api_key) 13 | 14 | def log_error(file_path, error_message): 15 | """Logs error messages to a specified log file.""" 16 | os.makedirs("logs", exist_ok=True) 17 | with open("logs/gpt4_summarize_error_log.txt", 'a', encoding='utf-8') as log_file: 18 | log_file.write(f"Error processing {file_path}: {error_message}\n") 19 | 20 | def process_json(metadata, language): 21 | """ 22 | Processes the given metadata of a music piece using GPT-4 API. 23 | 24 | This function sends the metadata and target language to the GPT-4 model to generate 25 | a structured summary. 
The summary is provided in both English and the specified 26 | non-English language from the 'nen_language' field. 27 | 28 | If the provided metadata lacks sufficient music-related details, the function returns `None`. 29 | 30 | Parameters: 31 | - metadata (dict): A dictionary containing the metadata of the music piece. 32 | - language (str): The target non-English language for the summary. 33 | 34 | Returns: 35 | - str: A JSON-formatted string containing the English and non-English summaries, 36 | or `None` if there is insufficient information. 37 | """ 38 | system = """Your task is to provide a concise, comprehensive, and coherent summary of the music piece using the provided metadata. Please write the summary in English first, and then write an equivalent summary in the specified non-English language from the "nen_language" field. Use this JSON format: 39 | { 40 | "summary_en": "Your English summary here.", 41 | "summary_nen": { 42 | "language": "Specified non-English language.", 43 | "summary": "Your non-English summary here." 44 | } 45 | If there is not enough music-related information, return `None` instead. 46 | } 47 | """ 48 | user1 = """{ 49 | "title": "Brejeiro", 50 | "composer": "Ernesto Nazareth", 51 | "genres": ["Choro", "Classical", "Instrumental"], 52 | "description": "\"Brejeiro\" is in A major and 2/4 time. A joyful melody begins at bar six, and a lively tango rhythm starts at bar fourteen. It has a D.C. al Fine at bar fifty-three and ends on two quarter notes in bar thirty-seven. The piece, with its vibrant melodies and rhythms, reflects celebration and carefreeness, embodying the spirit of Brazilian music.", 53 | "tags": ["Brazilian", "Choro", "Piano"], 54 | "ensembles": ["Solo Piano", "Small Ensemble"], 55 | "instruments": ["Piano"], 56 | "nen_language": "Japanese" 57 | } 58 | """ 59 | assistant1 = """{ 60 | "summary_en": "Brejeiro, composed by Ernesto Nazareth, is a lively choro piece in A major and 2/4 time. It features a joyful melody that begins at bar six and a vibrant tango rhythm introduced at bar fourteen. The piece includes a D.C. al Fine at bar fifty-three, concluding on two quarter notes in bar thirty-seven. With its themes of celebration and carefreeness, Brejeiro beautifully captures the essence of Brazilian music and is well-suited for solo piano and small ensembles.", 61 | "summary_nen": { 62 | "language": "Japanese", 63 | "summary": "「ブレジェイロ」は、エルネスト・ナザレが作曲した活気あふれるショーロの作品で、イ長調の2/4拍子で書かれています。第6小節から始まる喜びに満ちたメロディーと、第14小節で導入される活気あるタンゴのリズムが特徴です。この曲には、第53小節でのD.C. 
al Fineが含まれ、また第37小節で二つの四分音符で締めくくられています。「ブレジェイロ」は、お祝いと無邪気さのテーマを持ち、ブラジル音楽の本質を美しく捉えており、ソロピアノや小編成のアンサンブルにぴったりの作品です。" 64 | } 65 | } 66 | """ 67 | user2 = """{ 68 | "title": "Untitled", 69 | "composer": "Unknown", 70 | "description": "This is a good song.", 71 | "nen_language": "Russian" 72 | } 73 | """ 74 | assistant2 = "None" 75 | filepaths = metadata.pop('filepaths') 76 | metadata = {k: v for k, v in metadata.items() if v is not None} 77 | 78 | metadata["nen_language"] = language 79 | metadata = json.dumps(metadata, ensure_ascii=False, indent=4) 80 | summaries = client.chat.completions.create( 81 | model=model, 82 | messages=[ 83 | {"role": "system", "content": system}, 84 | {"role": "user", "content": user1}, 85 | {"role": "assistant", "content": assistant1}, 86 | {"role": "user", "content": user2}, 87 | {"role": "assistant", "content": assistant2}, 88 | {"role": "user", "content": metadata}, 89 | ] 90 | ).choices[0].message.content 91 | 92 | if summaries == "None": 93 | raise ValueError("Received 'None' as summaries response") 94 | 95 | metadata = json.loads(metadata) 96 | summaries = json.loads(summaries) 97 | 98 | if metadata["nen_language"] == summaries["summary_nen"]["language"]: 99 | metadata.pop("nen_language") 100 | metadata["summary_en"] = summaries["summary_en"] 101 | metadata["summary_nen"] = summaries["summary_nen"] 102 | metadata["filepaths"] = filepaths 103 | return metadata 104 | else: 105 | raise ValueError("Language mismatch: nen_language does not match summary_nen language") 106 | 107 | def process_files(input_dir): 108 | # Create output directory with _summarized suffix 109 | output_dir = input_dir + "_summarized" 110 | 111 | # Define available languages 112 | languages = """Afrikaans 113 | Amharic 114 | Arabic 115 | Assamese 116 | Azerbaijani 117 | Belarusian 118 | Bulgarian 119 | Bengali 120 | Bengali (Romanized) 121 | Breton 122 | Bosnian 123 | Catalan 124 | Czech 125 | Welsh 126 | Danish 127 | German 128 | Greek 129 | Esperanto 130 | Spanish 131 | Estonian 132 | Basque 133 | Persian 134 | Finnish 135 | French 136 | Western Frisian 137 | Irish 138 | Scottish Gaelic 139 | Galician 140 | Gujarati 141 | Hausa 142 | Hebrew 143 | Hindi 144 | Hindi (Romanized) 145 | Croatian 146 | Hungarian 147 | Armenian 148 | Indonesian 149 | Icelandic 150 | Italian 151 | Japanese 152 | Javanese 153 | Georgian 154 | Kazakh 155 | Khmer 156 | Kannada 157 | Korean 158 | Kurdish (Kurmanji) 159 | Kyrgyz 160 | Latin 161 | Lao 162 | Lithuanian 163 | Latvian 164 | Malagasy 165 | Macedonian 166 | Malayalam 167 | Mongolian 168 | Marathi 169 | Malay 170 | Burmese 171 | Burmese (Romanized) 172 | Nepali 173 | Dutch 174 | Norwegian 175 | Oromo 176 | Oriya 177 | Punjabi 178 | Polish 179 | Pashto 180 | Portuguese 181 | Romanian 182 | Russian 183 | Sanskrit 184 | Sindhi 185 | Sinhala 186 | Slovak 187 | Slovenian 188 | Somali 189 | Albanian 190 | Serbian 191 | Sundanese 192 | Swedish 193 | Swahili 194 | Tamil 195 | Tamil (Romanized) 196 | Telugu 197 | Telugu (Romanized) 198 | Thai 199 | Filipino 200 | Turkish 201 | Uyghur 202 | Ukrainian 203 | Urdu 204 | Urdu (Romanized) 205 | Uzbek 206 | Vietnamese 207 | Xhosa 208 | Yiddish 209 | Chinese (Simplified) 210 | Chinese (Traditional) 211 | Cantonese""" 212 | languages = [language.strip() for language in languages.split("\n")] 213 | 214 | # Walk through the input directory 215 | for root, _, files in os.walk(input_dir): 216 | # Construct the corresponding path in the output folder 217 | relative_path = os.path.relpath(root, input_dir) 218 | output_path = 
os.path.join(output_dir, relative_path) 219 | 220 | # Create the output directory if it doesn't exist 221 | os.makedirs(output_path, exist_ok=True) 222 | 223 | for file in files: 224 | if file.endswith('.json'): 225 | input_file = os.path.join(root, file) 226 | output_file = os.path.join(output_path, file) 227 | 228 | try: 229 | # Read the JSON file 230 | with open(input_file, 'r', encoding='utf-8') as f: 231 | metadata = json.load(f) 232 | 233 | # Randomly select a language from the list of languages 234 | language = random.choice(languages) 235 | 236 | # Process the JSON data 237 | processed_metadata = process_json(metadata, language) 238 | 239 | # Write the processed JSON to the output file 240 | with open(output_file, 'w', encoding='utf-8') as f: 241 | json.dump(processed_metadata, f, indent=4, ensure_ascii=False) 242 | 243 | print(f"Processed: {input_file} -> {output_file}") 244 | 245 | except Exception as e: 246 | print(f"Failed to process {input_file}: {e}") 247 | log_error(input_file, str(e)) 248 | 249 | if __name__ == "__main__": 250 | process_files(input_dir) 251 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # torch>=2.5.0 2 | # torchvision 3 | # torchaudio 4 | horovod # remove this on ARM-based Macs 5 | abctoolkit 6 | accelerate 7 | mido 8 | numpy<=2.0.0 9 | openai>=1.52.2 10 | samplings==0.1.7 11 | scikit_learn 12 | tqdm 13 | transformers>=4.46.0 14 | Unidecode 15 | wandb 16 | -------------------------------------------------------------------------------- /semantic_search/README.md: -------------------------------------------------------------------------------- 1 | # Semantic Search Codebase 2 | 3 | ## Overview 4 | CLaMP 2 is a state-of-the-art multimodal music information retrieval system designed to work with 101 languages. This codebase includes scripts for evaluating model performance, performing semantic searches, and calculating similarity metrics based on CLaMP2-extracted **normalized** feature vectors from music or text data. Below is a description of the scripts contained in the `semantic_search/` folder. 5 | 6 | ## Repository Structure 7 | The `semantic_search/` folder contains the following scripts: 8 | 9 | ### 1. `clamp2_score.py` 10 | This script calculates the cosine similarity between the average feature vectors extracted from two sets of `.npy` files, serving as a measure of similarity between the reference and test datasets. 11 | 12 | It can be used to validate the semantic similarity between generated music and ground truth, providing an objective metric. Through empirical observation, we found that this metric aligns well with subjective judgments made by individuals with professional music expertise. 13 | 14 | **Usage:** 15 | ```bash 16 | python clamp2_score.py <reference_folder> <test_folder> 17 | ``` 18 | - `reference_folder`: Path to the folder containing reference `.npy` files. 19 | - `test_folder`: Path to the folder containing test `.npy` files. 20 | 21 | **Functionality:** 22 | - Loads all `.npy` files from the specified folders. 23 | - Computes the average feature vector for each folder. 24 | - Calculates the cosine similarity between the two averaged vectors. 25 | - Outputs the similarity score rounded to four decimal places. 26 | 27 | ### 2. `semantic_search.py` 28 | This script performs semantic search by calculating the cosine similarity between a query feature and a set of features stored in `.npy` files.
29 | 30 | **Usage:** 31 | ```bash 32 | python semantic_search.py <query_file> <features_folder> [--top_k TOP_K] 33 | ``` 34 | - `query_file`: Path to the query feature file (e.g., `ballad.npy`). 35 | - `features_folder`: Path to the folder containing feature files for comparison. 36 | - `--top_k`: (Optional) Number of top similar items to display. Defaults to 10 if not specified. 37 | 38 | **Functionality:** 39 | - Loads a query feature from the specified file. 40 | - Loads feature vectors from the given folder. 41 | - Computes cosine similarity between the query feature and each loaded feature vector. 42 | - Displays the top K most similar features along with their similarity scores. 43 | 44 | ### 3. `semantic_search_metrics.py` 45 | This script calculates evaluation metrics for semantic search by comparing query features to reference features. 46 | 47 | **Usage:** 48 | ```bash 49 | python semantic_search_metrics.py <query_folder> <reference_folder> 50 | ``` 51 | - `query_folder`: Path to the folder containing query features (in `.npy` format). 52 | - `reference_folder`: Path to the folder containing reference features (in `.npy` format). 53 | 54 | **Functionality:** 55 | - Loads query features from the specified folder. 56 | - Loads reference features from the given folder. 57 | - Computes the following metrics based on cosine similarity: 58 | - **Mean Reciprocal Rank (MRR)** 59 | - **Hit@1** 60 | - **Hit@10** 61 | - **Hit@100** 62 | - Outputs the calculated metrics to the console. 63 | -------------------------------------------------------------------------------- /semantic_search/clamp2_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | 5 | def load_npy_files(folder_path): 6 | """ 7 | Load all .npy files from a specified folder and return a list of numpy arrays. 8 | """ 9 | npy_list = [] 10 | for file_name in os.listdir(folder_path): 11 | if file_name.endswith('.npy'): 12 | file_path = os.path.join(folder_path, file_name) 13 | np_array = np.load(file_path)[0] 14 | npy_list.append(np_array) 15 | return npy_list 16 | 17 | def average_npy(npy_list): 18 | """ 19 | Compute the average of a list of numpy arrays. 20 | """ 21 | return np.mean(npy_list, axis=0) 22 | 23 | def cosine_similarity(vec1, vec2): 24 | """ 25 | Compute cosine similarity between two numpy arrays.
26 | """ 27 | dot_product = np.dot(vec1, vec2) 28 | 29 | norm_vec1 = np.linalg.norm(vec1) 30 | norm_vec2 = np.linalg.norm(vec2) 31 | 32 | cosine_sim = dot_product / (norm_vec1 * norm_vec2) 33 | 34 | return cosine_sim 35 | 36 | if __name__ == '__main__': 37 | # Set up argument parsing for input folders 38 | parser = argparse.ArgumentParser(description="Calculate cosine similarity between average feature vectors.") 39 | parser.add_argument('reference', type=str, help='Path to the reference folder containing .npy files.') 40 | parser.add_argument('test', type=str, help='Path to the test folder containing .npy files.') 41 | args = parser.parse_args() 42 | 43 | reference = args.reference 44 | test = args.test 45 | # Load .npy files 46 | ref_npy = load_npy_files(reference) 47 | test_npy = load_npy_files(test) 48 | 49 | # Compute the average of each list of numpy arrays 50 | avg_ref = average_npy(ref_npy) 51 | avg_test = average_npy(test_npy) 52 | 53 | # Compute the cosine similarity between the two averaged numpy arrays 54 | similarity = cosine_similarity(avg_ref, avg_test) 55 | 56 | # Output the cosine similarity rounded to four decimal places 57 | print(f"Cosine similarity between '{reference}' and '{test}': {similarity:.4f}") 58 | -------------------------------------------------------------------------------- /semantic_search/semantic_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import argparse 5 | 6 | def get_info(folder_path): 7 | """ 8 | Load all .npy files from a specified folder and return a dictionary of features. 9 | """ 10 | files = sorted(os.listdir(folder_path)) 11 | features = {} 12 | 13 | for file in files: 14 | if file.endswith(".npy"): 15 | key = file.split(".")[0] 16 | features[key] = np.load(os.path.join(folder_path, file))[0] 17 | 18 | return features 19 | 20 | def main(query_file, features_folder, top_k=10): 21 | # Load query feature from the specified file 22 | query_feature = np.load(query_file)[0] # Load directly from the query file 23 | query_tensor = torch.tensor(query_feature).unsqueeze(dim=0) 24 | 25 | # Load key features from the specified folder 26 | key_features = get_info(features_folder) 27 | 28 | # Prepare tensor for key features 29 | key_feats_tensor = torch.tensor(np.array([key_features[k] for k in key_features.keys()])) 30 | 31 | # Calculate cosine similarity 32 | similarities = torch.cosine_similarity(query_tensor, key_feats_tensor) 33 | ranked_indices = torch.argsort(similarities, descending=True) 34 | 35 | # Get the keys for the features 36 | keys = list(key_features.keys()) 37 | 38 | print(f"Top {top_k} similar items:") 39 | for i in range(top_k): 40 | print(keys[ranked_indices[i]], similarities[ranked_indices[i]].item()) 41 | 42 | if __name__ == '__main__': 43 | # Set up argument parsing for input paths 44 | parser = argparse.ArgumentParser(description="Find top similar features based on cosine similarity.") 45 | parser.add_argument('query_file', type=str, help='Path to the query feature file (e.g., ballad.npy).') 46 | parser.add_argument('features_folder', type=str, help='Path to the folder containing feature files for comparison.') 47 | parser.add_argument('--top_k', type=int, default=10, help='Number of top similar items to display (default: 10).') 48 | args = parser.parse_args() 49 | 50 | # Execute the main functionality 51 | main(args.query_file, args.features_folder, args.top_k) 52 | 
-------------------------------------------------------------------------------- /semantic_search/semantic_search_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import argparse 5 | 6 | def get_features(path): 7 | """ 8 | Load and return feature data from .npy files in the given directory. 9 | Each feature is stored in a dictionary with the filename (without extension) as the key. 10 | """ 11 | files = sorted(os.listdir(path)) 12 | features = {} 13 | 14 | for file in files: 15 | if file.endswith(".npy"): 16 | key = file.split(".")[0] 17 | features[key] = np.load(os.path.join(path, file))[0] 18 | 19 | return features 20 | 21 | def calculate_metrics(query_features, reference_features): 22 | """ 23 | Calculate MRR, Hit@1, Hit@10, and Hit@100 metrics based on the similarity 24 | between query and reference features. 25 | """ 26 | common_keys = set(query_features.keys()) & set(reference_features.keys()) 27 | mrr, hit_1, hit_10, hit_100 = 0, 0, 0, 0 28 | 29 | for idx, key in enumerate(common_keys): 30 | # Convert query feature to tensor and add batch dimension 31 | query_feat = torch.tensor(query_features[key]).unsqueeze(dim=0) 32 | 33 | # Collect all reference features for common keys 34 | ref_feats = torch.tensor(np.array([reference_features[k] for k in common_keys])) 35 | 36 | # Compute cosine similarity between the query and all reference features 37 | similarities = torch.cosine_similarity(query_feat, ref_feats) 38 | 39 | # Create a list of (similarity, index) pairs 40 | indexed_sims = list(enumerate(similarities.tolist())) 41 | 42 | # Sort by similarity in descending order, with idx-based tie-breaking 43 | sorted_indices = sorted(indexed_sims, key=lambda x: (x[1], x[0] == idx), reverse=True) 44 | 45 | # Extract the sorted rank list 46 | ranks = [x[0] for x in sorted_indices] 47 | 48 | # Calculate MRR 49 | mrr += 1 / (ranks.index(idx) + 1) 50 | 51 | # Calculate Hit@1, Hit@10, Hit@100 52 | if idx in ranks[:100]: 53 | hit_100 += 1 54 | if idx in ranks[:10]: 55 | hit_10 += 1 56 | if idx in ranks[:1]: 57 | hit_1 += 1 58 | 59 | # Compute the final metrics 60 | total_keys = len(common_keys) 61 | print(f"MRR: {round(mrr / total_keys, 4)}") 62 | print(f"Hit@1: {round(hit_1 / total_keys, 4)}") 63 | print(f"Hit@10: {round(hit_10 / total_keys, 4)}") 64 | print(f"Hit@100: {round(hit_100 / total_keys, 4)}") 65 | 66 | if __name__ == '__main__': 67 | # Set up argument parsing for input directories 68 | parser = argparse.ArgumentParser(description="Calculate similarity metrics between query and reference features.") 69 | parser.add_argument('query_folder', type=str, help='Path to the folder containing query features (.npy files).') 70 | parser.add_argument('reference_folder', type=str, help='Path to the folder containing reference features (.npy files).') 71 | args = parser.parse_args() 72 | 73 | # Load features from the specified folders 74 | query_features = get_features(args.query_folder) 75 | reference_features = get_features(args.reference_folder) 76 | 77 | # Calculate and print the metrics 78 | calculate_metrics(query_features, reference_features) 79 | --------------------------------------------------------------------------------
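The metrics above reward each query feature for ranking its own reference feature highly among all references. Below is a minimal sketch of that idea on toy data; the vectors, piece names, and the simplified rank rule (count of strictly higher similarities plus one, with no tie-breaking) are illustrative assumptions, not the exact logic of `semantic_search_metrics.py`.

```python
import numpy as np
import torch

# Toy paired features: one query vector (e.g. from text) and one reference vector
# (e.g. from music) per piece. All values are made up for illustration.
query_features = {
    "piece_a": np.array([1.0, 0.0, 0.0], dtype=np.float32),
    "piece_b": np.array([0.0, 1.0, 0.0], dtype=np.float32),
    "piece_c": np.array([0.7, 0.7, 0.0], dtype=np.float32),
}
reference_features = {
    "piece_a": np.array([0.9, 0.1, 0.0], dtype=np.float32),
    "piece_b": np.array([0.1, 0.9, 0.0], dtype=np.float32),
    "piece_c": np.array([0.6, 0.8, 0.0], dtype=np.float32),
}

keys = sorted(query_features)  # fixed order, so query i matches reference i
ref_mat = torch.tensor(np.stack([reference_features[k] for k in keys]))

mrr, hit_1 = 0.0, 0
for i, k in enumerate(keys):
    q = torch.tensor(query_features[k]).unsqueeze(0)
    sims = torch.cosine_similarity(q, ref_mat)       # similarity of this query to every reference
    rank = int((sims > sims[i]).sum().item()) + 1    # 1-based rank of the matching reference
    mrr += 1.0 / rank
    hit_1 += int(rank == 1)

n = len(keys)
print(f"MRR: {mrr / n:.4f}, Hit@1: {hit_1 / n:.4f}")  # both 1.0 here: every query ranks its own reference first
```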