├── .github └── workflows │ └── black.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── aria ├── __init__.py ├── config.py ├── datasets.py ├── inference │ ├── __init__.py │ └── model.py ├── model.py ├── run.py ├── sample.py ├── tokenizer.py ├── train.py └── utils.py ├── config ├── accelerate.yaml ├── config.json └── models │ ├── large.json │ └── medium.json ├── models └── placeholder.txt ├── requirements-dev.txt ├── requirements.txt ├── scripts ├── download_data.sh ├── midi_to_audio.py └── upload_data.sh ├── setup.py └── tests ├── __init__.py ├── test_data.py ├── test_data ├── arabesque.mid ├── bach.mid ├── basic.mid ├── beethoven_moonlight.mid ├── beethoven_sonata.mid ├── clean │ ├── 1.mid │ └── 2.mid ├── expressive.mid ├── noisy │ ├── 1.mid │ └── 2.mid ├── pop.mid └── pop_copy.mid └── test_tokenizers.py /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v3 10 | - uses: psf/black@stable 11 | with: 12 | options: "--check --verbose --line-length 80" 13 | src: "./aria" 14 | - name: Check formatting result 15 | run: | 16 | if [ $? -ne 0 ]; then 17 | echo "Formatting check failed. Please run 'make format' to fix formatting issues." 18 | exit 1 19 | fi 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Project specific 163 | tools/ 164 | ./data/ 165 | fluidsynth/ 166 | *.DS_Store 167 | tests/test_results 168 | lightning_logs/ 169 | .vscode/ 170 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include config * 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test 2 | test: 3 | python -m unittest tests/test_*.py 4 | 5 | 6 | .PHONY: format 7 | format: 8 | black --line-length 80 ./aria 9 | black --line-length 80 ./tests 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gpt-aria 2 | 3 | [Discord](https://discord.com/invite/zBGx3azzUn) 4 | 5 | A repository containing resources for pre-training, fine-tuning, and evaluating musical (MIDI) transformer models. 6 | 7 | ***Note that this project is under active development*** 8 | 9 | ## Description 10 | 11 | The main goal of the gpt-aria project is to create a suite of powerful pre-trained generative (symbolic) music models. We want to investigate how modern training (pre-training & fine-tuning) techniques can be used to improve the quality/usefulness of such models. Alongside this we are building various data (MIDI) preprocessing tools, allowing **you** to easily fine-tune our models on your own data. 12 | 13 | If you are new to symbolic music models, a good place to start are the following projects/blogposts by Google Magenta and OpenAI: 14 | 15 | - [Music Transformer](https://magenta.tensorflow.org/music-transformer) 16 | - [MuseNet](https://openai.com/research/musenet) 17 | 18 | Long story short: Transformer + MIDI + GPUs = 🎵 x ∞ 19 | 20 | ## Installation 21 | 22 | Make sure you are using Python 3.10+. Note that I haven't explicitly developed this project for anything other than Linux. If you are using Windows, things might not work properly. In this case I suggest installing using WSL. 23 | 24 | ``` 25 | git clone https://github.com/eleutherai/aria 26 | cd aria 27 | pip install -e . 28 | ``` 29 | 30 | ## Inference 31 | 32 | You can find preliminary checkpoints at the following locations 33 | 34 | Finetuned piano-only checkpoints (improved robustness): 35 | 36 | ``` 37 | large - https://storage.googleapis.com/aria-checkpoints/large-abs-inst.safetensors 38 | ``` 39 | 40 | Pretrained checkpoints: 41 | 42 | ``` 43 | large - https://storage.googleapis.com/aria-checkpoints/large-abs-pt.bin 44 | medium - https://storage.googleapis.com/aria-checkpoints/medium-abs-pt.bin 45 | small - https://storage.googleapis.com/aria-checkpoints/small-abs-pt.bin 46 | ``` 47 | 48 | You can then sample using the cli: 49 | 50 | ``` 51 | aria sample \ 52 | -m large \ 53 | -c \ 54 | -p \ 55 | -var \ 56 | -trunc \ 57 | -l \ 58 | -temp 0.95 \ 59 | -e 60 | ``` 61 | 62 | You can use `aria sample -h` to see a full list of options. If you wish to sample from a pretrained checkpoint, please use the `-pt` flag. 
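For reference, a fully filled-in invocation might look like the following. Every value below is only a placeholder (the checkpoint path assumes you downloaded one of the checkpoints listed above into `models/`); check `aria sample -h` for the exact meaning and expected type of each flag.

```
aria sample \
    -m large \
    -c models/large-abs-inst.safetensors \
    -p prompt.mid \
    -var 3 \
    -trunc 30 \
    -l 1024 \
    -temp 0.95 \
    -e
```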
63 | 64 | 65 | -------------------------------------------------------------------------------- /aria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/aria/__init__.py -------------------------------------------------------------------------------- /aria/config.py: -------------------------------------------------------------------------------- 1 | """Includes functionality for loading config files.""" 2 | 3 | import os 4 | import json 5 | 6 | from functools import lru_cache 7 | 8 | 9 | CONFIG_DIR = os.path.join(os.path.dirname(__file__), "..", "config") 10 | 11 | 12 | @lru_cache(maxsize=1) 13 | def load_config(): 14 | """Returns a dictionary loaded from the config.json file.""" 15 | with open(os.path.join(CONFIG_DIR, "config.json")) as f: 16 | return json.load(f) 17 | 18 | 19 | def load_model_config(name: str): 20 | """Returns a dictionary containing the model config.""" 21 | model_config_path = os.path.join(CONFIG_DIR, "models", f"{name}.json") 22 | assert os.path.isfile( 23 | model_config_path 24 | ), f"Could not find config file for model={name} in config/models" 25 | with open(model_config_path) as f: 26 | return json.load(f) 27 | -------------------------------------------------------------------------------- /aria/datasets.py: -------------------------------------------------------------------------------- 1 | """Contains classes and utilities for building and processing datasets.""" 2 | 3 | import json 4 | import os 5 | import copy 6 | import re 7 | import mmap 8 | import jsonlines 9 | import logging 10 | import random 11 | import torch 12 | import functools 13 | import shutil 14 | import multiprocessing 15 | 16 | from mido.midifiles.units import second2tick 17 | from pathlib import Path 18 | from typing import List 19 | from copy import deepcopy 20 | from typing import Callable, Iterable 21 | from collections import defaultdict 22 | 23 | from aria.config import load_config 24 | from aria.tokenizer import InferenceAbsTokenizer 25 | from ariautils.tokenizer import Tokenizer 26 | from ariautils.midi import ( 27 | MidiDict, 28 | get_test_fn, 29 | get_duration_ms, 30 | get_metadata_fn, 31 | ) 32 | 33 | 34 | def setup_logger(): 35 | # Get logger and reset all handlers 36 | logger = logging.getLogger(__name__) 37 | for h in logger.handlers[:]: 38 | logger.removeHandler(h) 39 | 40 | logger.propagate = False 41 | logger.setLevel(logging.INFO) 42 | formatter = logging.Formatter( 43 | "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", 44 | ) 45 | 46 | ch = logging.StreamHandler() 47 | ch.setLevel(logging.INFO) 48 | ch.setFormatter(formatter) 49 | logger.addHandler(ch) 50 | 51 | return logger 52 | 53 | 54 | # TODO: Change the build settings so that it saves the config used for tests 55 | # as json on the first line. 56 | class MidiDataset: 57 | """Container for datasets of MidiDict objects. 58 | 59 | Can be used to save, load, and build, datasets of MidiDict objects. 60 | 61 | Args: 62 | entries (list[MidiDict] | Iterable): MidiDict objects to be stored. 
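    A minimal usage sketch (paths here are illustrative only):

        dataset = MidiDataset.build(dir="data/raw_midi", recur=True)
        dataset.save("mididataset.jsonl")
        dataset = MidiDataset.load("mididataset.jsonl")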
63 | """ 64 | 65 | def __init__(self, entries: list[MidiDict] | Iterable): 66 | self.entries = entries 67 | 68 | def __len__(self): 69 | if not isinstance(self.entries, list): 70 | self.entries = list(self.entries) 71 | return len(self.entries) 72 | 73 | def __getitem__(self, ind: int): 74 | if not isinstance(self.entries, list): 75 | self.entries = list(self.entries) 76 | return self.entries[ind] 77 | 78 | def __iter__(self): 79 | yield from self.entries 80 | 81 | def shuffle(self): 82 | if not isinstance(self.entries, list): 83 | self.entries = list(self.entries) 84 | random.shuffle(self.entries) 85 | 86 | def save(self, save_path: str): 87 | """Saves dataset to JSON file.""" 88 | 89 | with jsonlines.open(save_path, mode="w") as writer: 90 | for midi_dict in self.entries: 91 | writer.write(midi_dict.get_msg_dict()) 92 | 93 | @classmethod 94 | def load(cls, load_path: str): 95 | """Loads dataset (into memory) from JSONL file.""" 96 | with jsonlines.open(load_path) as reader: 97 | _entries = [MidiDict.from_msg_dict(_) for _ in reader] 98 | 99 | return cls(_entries) 100 | 101 | @classmethod 102 | def get_generator(cls, load_path: str): 103 | """Given a MidiDataset JSONL file, returns a MidiDict generator. 104 | 105 | This generator must be reloaded each time you want to iterate over the 106 | file. Internally it iterating over the jsonl file located at load_path. 107 | """ 108 | 109 | def generator(): 110 | with jsonlines.open(load_path, "r") as midi_dataset: 111 | for entry in midi_dataset: 112 | try: 113 | midi_dict = MidiDict.from_msg_dict(entry) 114 | except Exception as e: 115 | logging.info(f"Failed to load MidiDict from file: {e}") 116 | else: 117 | yield midi_dict 118 | 119 | return generator() 120 | 121 | @classmethod 122 | def split_from_file( 123 | cls, 124 | load_path: str, 125 | train_val_ratio: float = 0.95, 126 | repeatable: bool = False, 127 | overwrite: bool = False, 128 | ): 129 | """Splits MidiDataset JSONL file into train/val split.""" 130 | logger = setup_logger() 131 | path = Path(load_path) 132 | train_save_path = path.with_name(f"{path.stem}_train{path.suffix}") 133 | val_save_path = path.with_name(f"{path.stem}_val{path.suffix}") 134 | 135 | if not overwrite: 136 | if os.path.isfile(train_save_path) is True: 137 | raise FileExistsError( 138 | f"File at {train_save_path} already exists." 139 | ) 140 | if os.path.isfile(val_save_path) is True: 141 | raise FileExistsError( 142 | f"File at {val_save_path} already exists." 
143 | ) 144 | 145 | if repeatable: 146 | random.seed(42) # The answer to the universe 147 | 148 | idx_original, idx_train, idx_val = 0, 0, 0 149 | 150 | logger.info(f"Creating train/val split with ratio {train_val_ratio}") 151 | with ( 152 | jsonlines.open(load_path) as dataset, 153 | jsonlines.open(train_save_path, mode="w") as train_dataset, 154 | jsonlines.open(val_save_path, mode="w") as val_dataset, 155 | ): 156 | for entry in dataset: 157 | idx_original += 1 158 | if random.uniform(0, 1) <= train_val_ratio: 159 | idx_train += 1 160 | train_dataset.write(entry) 161 | else: 162 | idx_val += 1 163 | val_dataset.write(entry) 164 | 165 | logger.info( 166 | f"Succesfully split into train ({idx_train}) and validation ({idx_val}) sets" 167 | ) 168 | 169 | @classmethod 170 | def build( 171 | cls, 172 | dir: str, 173 | recur: bool = False, 174 | manual_metadata: dict = {}, 175 | shuffle: bool = True, 176 | ): 177 | """Builds are returns a MidiDataset - see build_mididict_dataset.""" 178 | valid_metadata = load_config()["data"]["metadata"]["manual"] 179 | for k, v in manual_metadata.items(): 180 | assert k in valid_metadata.keys(), f"{manual_metadata} is invalid" 181 | assert v in valid_metadata[k], f"{manual_metadata} is invalid" 182 | 183 | return cls( 184 | build_mididict_dataset( 185 | dir=dir, 186 | recur=recur, 187 | manual_metadata=manual_metadata, 188 | shuffle=shuffle, 189 | ) 190 | ) 191 | 192 | @classmethod 193 | def build_to_file( 194 | cls, 195 | dir: str, 196 | save_path: str, 197 | recur: bool = False, 198 | overwrite: bool = False, 199 | manual_metadata: dict = {}, 200 | shuffle: bool = True, 201 | ): 202 | """Builds MidiDataset to a JSONL file - see build_mididict_dataset. 203 | 204 | This function will not return a MidiDataset object. It is well suited 205 | for situations where the resulting MidiDataset will not fit in the 206 | system's memory. Other than this difference, it is identical to 207 | MidiDataset.build. 
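        A minimal sketch (paths here are illustrative only):

            MidiDataset.build_to_file(
                dir="data/raw_midi",
                save_path="mididataset.jsonl",
                recur=True,
                overwrite=True,
            )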
208 | """ 209 | valid_metadata = load_config()["data"]["metadata"]["manual"] 210 | for k, v in manual_metadata.items(): 211 | assert k in valid_metadata.keys(), f"{manual_metadata} is invalid" 212 | assert v in valid_metadata[k], f"{manual_metadata} is invalid" 213 | 214 | build_mididict_dataset( 215 | dir=dir, 216 | recur=recur, 217 | stream_save_path=save_path, 218 | overwrite=overwrite, 219 | manual_metadata=manual_metadata, 220 | shuffle=shuffle, 221 | ) 222 | 223 | @classmethod 224 | def combine_datasets_from_file(cls, *args: str, output_path: str): 225 | """Utility for concatenating JSONL files, checking for duplicates""" 226 | logger = setup_logger() 227 | 228 | for input_path in args: 229 | assert os.path.isfile(input_path), f"{input_path} doesn't exist" 230 | 231 | dupe_cnt = 0 232 | hashes = {} 233 | with jsonlines.open(output_path, mode="w") as f_out: 234 | for input_path in args: 235 | assert ( 236 | os.path.splitext(input_path)[-1] == ".jsonl" 237 | ), "invalid dataset path" 238 | 239 | with jsonlines.open(input_path, mode="r") as f_in: 240 | for msg_dict in f_in: 241 | midi_dict = MidiDict.from_msg_dict(msg_dict) 242 | midi_dict_hash = midi_dict.calculate_hash() 243 | if hashes.get(midi_dict_hash, False) is not False: 244 | dupe_cnt += 1 245 | else: 246 | f_out.write(msg_dict) 247 | hashes[midi_dict_hash] = True 248 | logger.info(f"Finished processing: {input_path}") 249 | logger.info( 250 | f"{len(hashes)} unique midi_dicts and {dupe_cnt} duplicates so far" 251 | ) 252 | 253 | logger.info( 254 | f"Found {len(hashes)} unique midi_dicts and {dupe_cnt} duplicates" 255 | ) 256 | 257 | 258 | def _get_mididict(path: Path): 259 | # This function is only intended to be used as a process target during the 260 | # multi-processing in build_mididict_dataset. It returns a tuple of the form 261 | # (bool, (MidiDict, str, Path)) where the first element determines if the 262 | # loaded MidiDict was succesfully preprocessed. 
263 | 264 | def _add_metadata(_mid_dict: MidiDict): 265 | for metadata_process_name, metadata_process_config in config[ 266 | "metadata" 267 | ]["functions"].items(): 268 | if metadata_process_config["run"] is True: 269 | metadata_fn = get_metadata_fn( 270 | metadata_process_name=metadata_process_name 271 | ) 272 | fn_args: dict = metadata_process_config["args"] 273 | 274 | collected_metadata = metadata_fn(_mid_dict, **fn_args) 275 | if collected_metadata: 276 | for k, v in collected_metadata.items(): 277 | _mid_dict.metadata[k] = v 278 | 279 | return _mid_dict 280 | 281 | def _run_tests(_mid_dict: MidiDict): 282 | failed_tests = [] 283 | for test_name, test_config in config["tests"].items(): 284 | if test_config["run"] is True: 285 | test_fn = get_test_fn(test_name) 286 | test_args = test_config["args"] 287 | 288 | test_res, val = test_fn(_mid_dict, **test_args) 289 | if test_res is False: 290 | failed_tests.append((test_name, val)) 291 | 292 | return failed_tests 293 | 294 | def _preprocess_mididict(_mid_dict: MidiDict): 295 | for fn_name, fn_config in config["pre_processing"].items(): 296 | if fn_config["run"] is True: 297 | fn_args = fn_config["args"] 298 | getattr(_mid_dict, fn_name)(fn_args) 299 | 300 | try: 301 | # Note fn_args is passed as a dict, not unpacked as kwargs 302 | getattr(_mid_dict, fn_name)(fn_args) 303 | except: 304 | logger.error( 305 | f"Error finding preprocessing function for {fn_name}" 306 | ) 307 | 308 | return _mid_dict 309 | 310 | logger = setup_logger() 311 | config = load_config()["data"] 312 | 313 | try: 314 | mid_dict = MidiDict.from_midi(mid_path=path) 315 | except Exception as e: 316 | logger.error(f"Failed to load MIDI at {path}: {e}") 317 | return False, None 318 | 319 | failed_tests = _run_tests(mid_dict) 320 | if failed_tests: 321 | logger.info( 322 | f"MIDI at {path} failed preprocessing tests: {failed_tests} " 323 | ) 324 | return False, None 325 | else: 326 | mid_dict = _preprocess_mididict(mid_dict) 327 | mid_dict = _add_metadata(mid_dict) 328 | mid_hash = mid_dict.calculate_hash() 329 | return True, (mid_dict, mid_hash, path) 330 | 331 | 332 | def build_mididict_dataset( 333 | dir: str, 334 | recur: bool = False, 335 | stream_save_path: str = None, 336 | overwrite: bool = False, 337 | manual_metadata: dict = {}, 338 | shuffle: bool = True, 339 | ): 340 | """Builds dataset of MidiDicts. 341 | 342 | During the build process, successfully parsed MidiDicts can be filtered and 343 | preprocessed. This can be customized by modifying the config.json file. 344 | 345 | Args: 346 | dir (str): Directory to index from. 347 | recur (bool): If True, recursively search directories for MIDI files. 348 | Defaults to False. 349 | stream_save_path (str): If True, stream the dictionaries directly to a 350 | JSONL file instead of returning them as a list. This option is 351 | appropriate when processing very large numbers of MIDI files. 352 | overwrite (bool): If True, overwrite file at stream_save_path when 353 | streaming. 354 | manual_metadata (dict): Metadata tags to uniformly apply. 355 | shuffle (dict): Metadata tags to apply uniformly. 356 | 357 | Returns: 358 | list[MidiDict]: List of parsed, filtered, and preprocessed MidiDicts. 359 | This is only returned if stream_save_path is not provided. 
360 | """ 361 | 362 | def _get_mididicts_mp(_paths): 363 | with multiprocessing.Pool() as pool: 364 | results = pool.imap_unordered(_get_mididict, _paths) 365 | seen_hashes = defaultdict(list) 366 | dupe_cnt = 0 367 | failed_cnt = 0 368 | for idx, (success, result) in enumerate(results): 369 | if idx % 50 == 0 and idx != 0: 370 | logger.info(f"Processed MIDI files: {idx}/{num_paths}") 371 | 372 | if not success: 373 | failed_cnt += 1 374 | continue 375 | else: 376 | mid_dict, mid_hash, mid_path = result 377 | 378 | if seen_hashes.get(mid_hash): 379 | logger.info( 380 | f"MIDI located at '{mid_path}' is a duplicate - already" 381 | f" seen at: {seen_hashes[mid_hash][0]}" 382 | ) 383 | seen_hashes[mid_hash].append(str(mid_path)) 384 | dupe_cnt += 1 385 | else: 386 | seen_hashes[mid_hash].append(str(mid_path)) 387 | yield mid_dict 388 | 389 | logger.info(f"Total duplicates: {dupe_cnt}") 390 | logger.info( 391 | f"Total processing fails (tests or otherwise): {failed_cnt}" 392 | ) 393 | 394 | logger = setup_logger() 395 | if multiprocessing.get_start_method() == "spawn": 396 | logger.warning( 397 | 'The current multiprocessing start method is "spawn", this ' 398 | "will slow down dataset building" 399 | ) 400 | 401 | paths = [] 402 | if recur is True: 403 | paths += Path(dir).rglob(f"*.mid") 404 | paths += Path(dir).rglob(f"*.midi") 405 | else: 406 | paths += Path(dir).glob(f"*.mid") 407 | paths += Path(dir).glob(f"*.midi") 408 | 409 | num_paths = len(paths) 410 | if num_paths == 0: 411 | raise FileNotFoundError( 412 | "Directory contains no files matching *.mid or *.midi" 413 | ) 414 | if shuffle is True: 415 | logger.info(f"Shuffling {num_paths} paths") 416 | random.shuffle(paths) 417 | else: 418 | logger.info(f"Ordering {num_paths} paths") 419 | base_path = Path(dir) 420 | paths.sort(key=lambda _path: _path.relative_to(base_path).as_posix()) 421 | 422 | cnt = 0 423 | if stream_save_path is None: 424 | # Not streaming -> return entries directly 425 | entries = [] 426 | for entry in _get_mididicts_mp(_paths=paths): 427 | # manual_metadata should already be validated 428 | for k, v in manual_metadata.items(): 429 | # Only add if it doesn't exist, stops overwriting 430 | if entry.metadata.get(k) is None: 431 | entry.metadata[k] = v 432 | 433 | cnt += 1 434 | entries.append(entry) 435 | 436 | return entries 437 | else: 438 | # Streaming -> write to file instead of returning anything 439 | if overwrite is False and os.path.isfile(stream_save_path) is True: 440 | raise FileExistsError(f"File at {stream_save_path} already exists.") 441 | 442 | with jsonlines.open(stream_save_path, mode="w") as writer: 443 | for entry in _get_mididicts_mp(paths): 444 | # manual_metadata should already be validated 445 | for k, v in manual_metadata.items(): 446 | # Only add if it doesn't exist, stops overwriting 447 | if entry.metadata.get(k) is None: 448 | entry.metadata[k] = v 449 | 450 | cnt += 1 451 | writer.write(entry.get_msg_dict()) 452 | 453 | logger.info( 454 | f"Finished - added {cnt}/{len(paths)} found MIDI files to dataset." 
455 | ) 456 | 457 | 458 | class TrainingDataset(torch.utils.data.Dataset): 459 | def __init__(self, tokenizer: Tokenizer): 460 | self.tokenizer = tokenizer 461 | self.logger = setup_logger() 462 | self._transform = None 463 | self.config = None 464 | self.max_seq_len = None 465 | 466 | self.epoch_files_by_dir = [] 467 | self.dir_paths = [] 468 | self.curr_epoch = None 469 | self.file_buffs = None 470 | self.file_mmaps = None 471 | self.index = None 472 | 473 | def build(**kwargs): 474 | raise NotImplementedError 475 | 476 | def get_loss_mask(self, src_seq: list, tgt_seq: list): 477 | # Should returns a bool Tensor with False indicating a masked loss 478 | raise NotImplementedError 479 | 480 | def init_epoch(self, idx: int | None = None): 481 | # If idx not provided, increment curr_epoch 482 | if idx is None: 483 | assert self.curr_epoch is not None, "curr_epoch not initialized" 484 | self.curr_epoch += 1 485 | else: 486 | self.curr_epoch = idx 487 | 488 | self.close() 489 | self.index = [] 490 | self.file_buffs = [] 491 | self.file_mmaps = [] 492 | for dir_idx, epoch_files in enumerate(self.epoch_files_by_dir): 493 | num_epoch_files = len(epoch_files) 494 | epoch_file_idx = self.curr_epoch % num_epoch_files 495 | if self.curr_epoch >= num_epoch_files: 496 | self.logger.warning( 497 | f"File doesn't exist for epoch={self.curr_epoch} for dir={os.path.dirname(epoch_files[0])}, cycling to to epoch={epoch_file_idx}" 498 | ) 499 | 500 | epoch_file_path = epoch_files[epoch_file_idx] 501 | _buff = open(epoch_file_path, mode="r") 502 | self.file_buffs.append(_buff) 503 | _mmap_obj = mmap.mmap(_buff.fileno(), 0, access=mmap.ACCESS_READ) 504 | self.file_mmaps.append(_mmap_obj) 505 | _index = self._build_index(_mmap_obj) 506 | 507 | self.logger.info( 508 | f"Built index ({len(_index)}) for file {epoch_file_path}" 509 | ) 510 | self.index.extend([(dir_idx, pos) for pos in _index]) 511 | 512 | self.logger.info( 513 | f"Initiated epoch {self.curr_epoch} with index length={len(self.index)}" 514 | ) 515 | 516 | def _get_epoch_files(self, dir_path: str): 517 | """Validates and returns a sorted list of epoch dataset files.""" 518 | file_names = [ 519 | file_name 520 | for file_name in os.listdir(dir_path) 521 | if os.path.isfile(os.path.join(dir_path, file_name)) 522 | ] 523 | file_paths = [ 524 | os.path.join(dir_path, file_name) for file_name in file_names 525 | ] 526 | 527 | present_epochs = [] 528 | for file_name in file_names: 529 | if not re.match(r"^epoch\d+\.jsonl$", file_name): 530 | self.logger.warning( 531 | f"Found file with unexpected name: {file_name}" 532 | ) 533 | else: 534 | present_epochs.append( 535 | int(re.match(r"^epoch(\d+)\.jsonl$", file_name).group(1)) 536 | ) 537 | assert len(present_epochs) >= 1, f"no epoch files found in {dir_path}" 538 | assert set(present_epochs) == set( 539 | range(len(present_epochs)) 540 | ), "epoch files missing" 541 | 542 | # Check files have valid configs 543 | for file_path in file_paths: 544 | self.check_config(epoch_load_path=file_path) 545 | 546 | return [ 547 | os.path.join(dir_path, f"epoch{idx}.jsonl") 548 | for idx in range(len(present_epochs)) 549 | ] 550 | 551 | def get_epoch_files_by_dir(self, dir_paths: str): 552 | for dir_path in dir_paths: 553 | assert os.path.isdir(dir_path), f"Directory not found: {dir_path}" 554 | self.epoch_files_by_dir.append( 555 | self._get_epoch_files(dir_path=dir_path) 556 | ) 557 | 558 | @classmethod 559 | def get_config_from_path(cls, path: str): 560 | """Returns config dict from dataset directory. 
561 | 562 | Note that this will return the config corresponding to epoch0.jsonl. 563 | """ 564 | assert os.path.isdir(path), "directory not found" 565 | assert os.path.isfile( 566 | epoch0_path := os.path.join(path, "epoch0.jsonl") 567 | ), "epoch file not found" 568 | with open(epoch0_path) as f: 569 | return json.loads(f.readline()) 570 | 571 | def close(self): 572 | if self.file_buffs is not None: 573 | for file_buff in self.file_buffs: 574 | file_buff.close() 575 | 576 | if self.file_mmaps is not None: 577 | for file_mmap in self.file_mmaps: 578 | file_mmap.close() 579 | 580 | def __del__(self): 581 | self.close() 582 | 583 | def __len__(self): 584 | raise NotImplementedError 585 | 586 | def __getitem__(self, idx: int): 587 | def _format(tok): 588 | # This is required because json formats tuples into lists 589 | if isinstance(tok, list): 590 | return tuple(tok) 591 | return tok 592 | 593 | file_idx, pos = self.index[idx] 594 | mmap_obj = self.file_mmaps[file_idx] 595 | mmap_obj.seek(pos) 596 | 597 | _debug = mmap_obj.readline() 598 | seq = json.loads(_debug) # Load raw seq 599 | seq = [_format(tok) for tok in seq] # Format into hashable 600 | if self._transform: 601 | seq = self._transform(seq) # Data augmentation 602 | 603 | src = seq 604 | tgt = seq[1:] + [self.tokenizer.pad_tok] 605 | mask = self.get_loss_mask(src_seq=src, tgt_seq=tgt) 606 | 607 | return ( 608 | torch.tensor(self.tokenizer.encode(src)), 609 | torch.tensor(self.tokenizer.encode(tgt)), 610 | mask, 611 | ) 612 | 613 | def check_config(self, epoch_load_path: str): 614 | def _check_config(): 615 | assert self.config["tokenizer_name"] == self.tokenizer.name 616 | for k, v in self.config["tokenizer_config"].items(): 617 | if isinstance(v, dict): 618 | for _k, _v in v.items(): 619 | assert _v == self.tokenizer.config[k][_k] 620 | elif isinstance(v, str) or isinstance(v, int): 621 | assert v == self.tokenizer.config[k] 622 | 623 | # Check self.tokenizers is the same as the one used to generate file 624 | # This logic could use a refactor (maybe use deepdict?) 625 | with open(epoch_load_path, "r") as f: 626 | buffer = f.readline() 627 | 628 | try: 629 | prev_max_seq_len = ( 630 | self.config["max_seq_len"] if self.config is not None else None 631 | ) 632 | self.config = json.loads(buffer) 633 | assert (prev_max_seq_len is None) or ( 634 | self.config["max_seq_len"] == prev_max_seq_len 635 | ) 636 | self.max_seq_len = self.config["max_seq_len"] 637 | _check_config() 638 | except AssertionError as e: 639 | self.logger.error( 640 | "Tokenizer config setting don't match those in file" 641 | ) 642 | raise e 643 | except Exception as e: 644 | self.logger.error( 645 | "Processing tokenizer config resulted in an error" 646 | ) 647 | raise e 648 | 649 | def _build_index(self, mmap_obj: mmap.mmap): 650 | # Skip first line containing config 651 | mmap_obj.seek(0) 652 | mmap_obj.readline() 653 | 654 | index = [] 655 | while True: 656 | pos = mmap_obj.tell() 657 | line_buffer = mmap_obj.readline() 658 | if line_buffer == b"": 659 | break 660 | else: 661 | index.append(pos) 662 | 663 | self.logger.debug(f"Finished indexing {len(index)} sequences") 664 | 665 | return index 666 | 667 | def set_transform(self, transform: Callable | list[Callable]): 668 | """Sets data augmentation transformation functions. 669 | 670 | Args: 671 | transform (Callable | list[Callable]): Transformation function(s). 672 | Provided functions are expected to be list[str | tuple] -> 673 | list[str | tuple]. 
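        A minimal sketch of a custom transform (the function and the dataset
        instance here are purely illustrative):

            def my_augmentation(seq: list):
                # Must return a list[str | tuple] of the same form
                return seq

            dataset.set_transform([my_augmentation])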
674 | """ 675 | print(f"Setting data augmentation transform") 676 | 677 | if isinstance(transform, Callable): 678 | self._transform = transform 679 | elif isinstance(transform, list): 680 | # Check validity 681 | for fn in transform: 682 | assert isinstance(fn, Callable), "Invalid function" 683 | 684 | # Define new transformation function (apply fn in order) 685 | def _new_transform(x): 686 | for fn in transform: 687 | x = fn(x) 688 | return x 689 | 690 | self._transform = _new_transform 691 | else: 692 | raise ValueError("Must provide function or list of functions.") 693 | 694 | 695 | def _get_seqs( 696 | _entry: MidiDict | dict, 697 | _tokenizer: Tokenizer, 698 | _tokenize_fn: Callable | None = None, 699 | ): 700 | logger = setup_logger() 701 | 702 | if isinstance(_entry, str): 703 | _midi_dict = MidiDict.from_msg_dict(json.loads(_entry.rstrip())) 704 | elif isinstance(_entry, dict): 705 | _midi_dict = MidiDict.from_msg_dict(_entry) 706 | elif isinstance(_entry, MidiDict): 707 | _midi_dict = _entry 708 | else: 709 | raise Exception 710 | 711 | try: 712 | if _tokenize_fn is not None: 713 | _tokenized_seq = _tokenize_fn(_midi_dict) 714 | else: 715 | _tokenized_seq = _tokenizer.tokenize(_midi_dict) 716 | except Exception as e: 717 | print(e) 718 | logger.info(f"Skipping midi_dict: {e}") 719 | return 720 | else: 721 | if _tokenizer.unk_tok in _tokenized_seq: 722 | logger.warning("Unknown token seen while tokenizing midi_dict") 723 | return _tokenized_seq 724 | 725 | 726 | def get_seqs( 727 | tokenizer: Tokenizer, 728 | midi_dict_iter: Iterable, 729 | tokenize_fn: Callable | None = None, 730 | ): 731 | # Can't pickle geneator object when start method is spawn 732 | if multiprocessing.get_start_method() == "spawn": 733 | logging.info( 734 | "Converting generator to list due to multiprocessing start method" 735 | ) 736 | midi_dict_iter = [_ for _ in midi_dict_iter] 737 | 738 | with multiprocessing.Pool() as pool: 739 | results = pool.imap_unordered( 740 | functools.partial( 741 | _get_seqs, _tokenizer=tokenizer, _tokenize_fn=tokenize_fn 742 | ), 743 | midi_dict_iter, 744 | ) 745 | 746 | yield from results 747 | 748 | 749 | def reservoir(_iterable: Iterable, k: int): 750 | _reservoir = [] 751 | for entry in _iterable: 752 | if entry is not None: 753 | _reservoir.append(entry) 754 | 755 | if len(_reservoir) >= k: 756 | random.shuffle(_reservoir) 757 | yield from _reservoir 758 | _reservoir = [] 759 | 760 | if _reservoir != []: 761 | yield from _reservoir 762 | 763 | 764 | def random_selection_itt(iterables: list[Iterable]): 765 | iterators = [iter(x) for x in iterables] 766 | active = list(iterators) # Start with all iterators as active 767 | 768 | try: 769 | while active: 770 | selected = random.choice(active) 771 | yield next(selected) 772 | 773 | for it in iterators: 774 | if it is not selected: 775 | next(it, None) 776 | except StopIteration: 777 | pass 778 | 779 | 780 | class PretrainingDataset(TrainingDataset): 781 | """Torch dataset object yielding sequences formatted for pre-training""" 782 | 783 | def __init__(self, dir_paths: List[str] | str, tokenizer: Tokenizer): 784 | super().__init__(tokenizer=tokenizer) 785 | 786 | if isinstance(dir_paths, str): 787 | dir_paths = [dir_paths] 788 | 789 | self.dir_paths = dir_paths 790 | self.get_epoch_files_by_dir(dir_paths) 791 | self.init_epoch(0) 792 | 793 | def __len__(self): 794 | return len(self.index) 795 | 796 | def get_loss_mask(self, src_seq: list, tgt_seq: list): 797 | return torch.tensor( 798 | [tok != self.tokenizer.pad_tok for tok in 
tgt_seq], 799 | dtype=torch.bool, 800 | ) 801 | 802 | @classmethod 803 | def build( 804 | cls, 805 | tokenizer: Tokenizer, 806 | save_dir: str, 807 | max_seq_len: int, 808 | num_epochs: int, 809 | midi_dataset: MidiDataset = None, 810 | midi_dataset_path: str = None, 811 | ): 812 | """Builds and returns PretrainingDataset.""" 813 | 814 | def _build_epoch(_save_path, _midi_dataset): 815 | with jsonlines.open(_save_path, mode="w") as writer: 816 | # Write tokenizer info into json on first line 817 | writer.write( 818 | { 819 | "tokenizer_config": tokenizer.config, 820 | "tokenizer_name": tokenizer.name, 821 | "max_seq_len": max_seq_len, 822 | } 823 | ) 824 | 825 | buffer = [] 826 | _idx = 0 827 | for entry in reservoir(get_seqs(tokenizer, _midi_dataset), 10): 828 | if entry is not None: 829 | buffer += entry 830 | while len(buffer) >= max_seq_len: 831 | writer.write(buffer[:max_seq_len]) 832 | buffer = buffer[max_seq_len:] 833 | 834 | _idx += 1 835 | if _idx % 250 == 0: 836 | logger.info(f"Finished processing {_idx}") 837 | 838 | buffer += [tokenizer.pad_tok] * (max_seq_len - len(buffer)) 839 | writer.write(buffer[:max_seq_len]) 840 | 841 | logger = setup_logger() 842 | assert max_seq_len > 0, "max_seq_len must be greater than 0" 843 | assert num_epochs > 0, "num_epochs must be greater than 0" 844 | if multiprocessing.get_start_method() == "spawn": 845 | logger.warning( 846 | 'The current multiprocessing start method is "spawn", this ' 847 | "will slow down dataset building" 848 | ) 849 | 850 | if os.path.isdir(save_dir) and os.listdir(save_dir): 851 | print( 852 | f"The directory at {save_dir} in non-empty, type [Y/y] to " 853 | "remove and continue:" 854 | ) 855 | if input() not in {"Y", "y"}: 856 | print("Aborting") 857 | return 858 | else: 859 | shutil.rmtree(save_dir) 860 | 861 | if not os.path.exists(save_dir): 862 | os.mkdir(save_dir) 863 | 864 | if not midi_dataset and not midi_dataset_path: 865 | Exception("Must provide either midi_dataset or midi_dataset_path") 866 | if midi_dataset and midi_dataset_path: 867 | Exception("Can't provide both midi_dataset and midi_dataset_path") 868 | 869 | logger.info( 870 | f"Building PretrainingDataset with config: " 871 | f"max_seq_len={max_seq_len}, " 872 | f"tokenizer_name={tokenizer.name}" 873 | ) 874 | for idx in range(num_epochs): 875 | logger.info(f"Building epoch {idx}/{num_epochs - 1}...") 876 | 877 | # Reload the dataset on each iter 878 | if midi_dataset_path: 879 | midi_dataset = MidiDataset.get_generator(midi_dataset_path) 880 | 881 | _build_epoch( 882 | _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), 883 | _midi_dataset=midi_dataset, 884 | ) 885 | 886 | logger.info( 887 | f"Finished building, saved PretrainingDataset to {save_dir}" 888 | ) 889 | 890 | 891 | # TODO: Refactor for readability 892 | def _get_combined_mididict( 893 | clean_midi_dict: MidiDict, 894 | noisy_midi_dict: MidiDict, 895 | min_noisy_ms: int, 896 | max_noisy_ms: int, 897 | min_clean_ms: int, 898 | max_clean_ms: int, 899 | ) -> MidiDict: 900 | # NOTE: We adopt the tempo/ticks_per_beat of the clean_midi_dict, and 901 | # adjust the noisy note messages accordingly. 
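    # The combined result alternates between spans taken from the noisy
    # performance and spans taken from the clean one, e.g. (ms, illustrative):
    #   noisy_intervals = [[0, 7500], [19001, 26000]]
    #   clean_intervals = [[7501, 19000], [26001, 31000]]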
902 | assert len(clean_midi_dict.tempo_msgs) == 1, "Unsupported tempo msgs" 903 | assert len(noisy_midi_dict.tempo_msgs) == 1, "Unsupported tempo msgs" 904 | 905 | total_length_ms = get_duration_ms( 906 | start_tick=0, 907 | end_tick=clean_midi_dict.note_msgs[-1]["data"]["start"], 908 | tempo_msgs=clean_midi_dict.tempo_msgs, 909 | ticks_per_beat=clean_midi_dict.ticks_per_beat, 910 | ) 911 | 912 | # Create intervals 913 | noisy_intervals = [] 914 | clean_intervals = [] 915 | prev_ms = -1 916 | add_noisy_next = random.choice([True, False]) 917 | while True: 918 | if add_noisy_next is True: 919 | # Add noisy interval 920 | noisy_end_ms = random.randint( 921 | prev_ms + min_noisy_ms, prev_ms + max_noisy_ms 922 | ) 923 | noisy_intervals.append([prev_ms + 1, noisy_end_ms]) 924 | prev_ms = noisy_end_ms 925 | if prev_ms > total_length_ms: 926 | break 927 | else: 928 | add_noisy_next = False 929 | else: 930 | # Add clean interval 931 | clean_end_ms = random.randint( 932 | prev_ms + min_clean_ms, prev_ms + max_clean_ms 933 | ) 934 | clean_intervals.append([prev_ms + 1, clean_end_ms]) 935 | prev_ms = clean_end_ms 936 | if prev_ms > total_length_ms: 937 | break 938 | else: 939 | add_noisy_next = True 940 | 941 | # Merge note_msgs 942 | clean_ms_to_tick = (clean_midi_dict.ticks_per_beat * 1e3) / ( 943 | clean_midi_dict.tempo_msgs[0]["data"] 944 | ) 945 | 946 | comb_note_msgs = [] 947 | for _note_msg in noisy_midi_dict.note_msgs: 948 | onset_time_ms = noisy_midi_dict.tick_to_ms(_note_msg["data"]["start"]) 949 | 950 | for _interval_start_ms, _interval_end_ms in noisy_intervals: 951 | if _interval_start_ms < onset_time_ms < _interval_end_ms: 952 | offset_time_ms = noisy_midi_dict.tick_to_ms( 953 | _note_msg["data"]["end"] 954 | ) 955 | _adj_note_msg = copy.deepcopy(_note_msg) 956 | _adj_onset_tick = int(onset_time_ms * clean_ms_to_tick) 957 | _adj_offset_tick = int(offset_time_ms * clean_ms_to_tick) 958 | _adj_note_msg["tick"] = _adj_onset_tick 959 | _adj_note_msg["data"]["start"] = _adj_onset_tick 960 | _adj_note_msg["data"]["end"] = _adj_offset_tick 961 | 962 | comb_note_msgs.append(_adj_note_msg) 963 | break 964 | 965 | for _note_msg in clean_midi_dict.note_msgs: 966 | onset_time_ms = clean_midi_dict.tick_to_ms(_note_msg["data"]["start"]) 967 | 968 | for _interval_start_ms, _interval_end_ms in clean_intervals: 969 | if _interval_start_ms < onset_time_ms < _interval_end_ms: 970 | comb_note_msgs.append(_note_msg) 971 | break 972 | 973 | comb_metadata = deepcopy(clean_midi_dict.metadata) 974 | comb_metadata["noisy_intervals"] = noisy_intervals 975 | 976 | # Maybe using clean pedal msgs here is bad? 
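    # All non-note messages (meta/tempo/pedal/instrument) are taken verbatim
    # from the clean performance; only the note messages are merged.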
977 | return MidiDict( 978 | meta_msgs=clean_midi_dict.meta_msgs, 979 | tempo_msgs=clean_midi_dict.tempo_msgs, 980 | pedal_msgs=clean_midi_dict.pedal_msgs, 981 | instrument_msgs=clean_midi_dict.instrument_msgs, 982 | note_msgs=comb_note_msgs, 983 | ticks_per_beat=clean_midi_dict.ticks_per_beat, 984 | metadata=comb_metadata, 985 | ) 986 | 987 | 988 | # TODO: Refactor this function for readability 989 | def _noise_midi_dict(midi_dict: MidiDict, config: dict): 990 | def _get_velocity_adjusted_msg( 991 | __note_msg: dict, 992 | _max_velocity_adjustment: int, 993 | ): 994 | _temp_note_msg = copy.deepcopy(__note_msg) 995 | _temp_note_msg["data"]["velocity"] = min( 996 | max( 997 | 0, 998 | _temp_note_msg["data"]["velocity"] 999 | + random.randint( 1000 | -_max_velocity_adjustment, _max_velocity_adjustment 1001 | ), 1002 | ), 1003 | 127, 1004 | ) 1005 | 1006 | return _temp_note_msg 1007 | 1008 | def _get_quantized_msg( 1009 | __note_msg: dict, 1010 | _q_delta: int, 1011 | _vel_q_delta: int, 1012 | ): 1013 | _start = __note_msg["data"]["start"] 1014 | _adjusted_start = max(0, _q_delta * round(_start / _q_delta)) 1015 | 1016 | _end = __note_msg["data"]["end"] 1017 | _adjusted_end = max( 1018 | _adjusted_start + _q_delta, 1019 | _q_delta * round(_end / _q_delta), 1020 | ) 1021 | _velocity = __note_msg["data"]["velocity"] 1022 | _adjusted_velocity = min( 1023 | 127, 1024 | max( 1025 | _vel_q_delta, 1026 | _vel_q_delta * round(_velocity / _vel_q_delta), 1027 | ), 1028 | ) 1029 | 1030 | _temp_note_msg = copy.deepcopy(__note_msg) 1031 | _temp_note_msg["data"]["start"] = _adjusted_start 1032 | _temp_note_msg["data"]["end"] = _adjusted_end 1033 | _temp_note_msg["tick"] = _adjusted_start 1034 | _temp_note_msg["data"]["velocity"] = _adjusted_velocity 1035 | 1036 | return _temp_note_msg 1037 | 1038 | def _get_onset_adjusted_msg( 1039 | __note_msg: dict, 1040 | _max_tick_adjustment: int, 1041 | ): 1042 | _adjusted_start = max( 1043 | 0, 1044 | __note_msg["data"]["start"] 1045 | + random.randint(-_max_tick_adjustment, _max_tick_adjustment), 1046 | ) 1047 | _adjusted_end = max( 1048 | _adjusted_start + _max_tick_adjustment, 1049 | __note_msg["data"]["end"] 1050 | + random.randint(-_max_tick_adjustment, _max_tick_adjustment), 1051 | ) 1052 | assert ( 1053 | _adjusted_start < _adjusted_end 1054 | ), f"{_adjusted_start, _adjusted_end}" 1055 | 1056 | _temp_note_msg = copy.deepcopy(__note_msg) 1057 | _temp_note_msg["data"]["start"] = _adjusted_start 1058 | _temp_note_msg["data"]["end"] = _adjusted_end 1059 | _temp_note_msg["tick"] = _adjusted_start 1060 | 1061 | return _temp_note_msg 1062 | 1063 | _note_msgs = copy.deepcopy(midi_dict.note_msgs) 1064 | 1065 | # Remove notes 1066 | if random.random() < config["remove_notes"]["activation_prob"]: 1067 | remove_prob = random.uniform( 1068 | config["remove_notes"]["min_ratio"], 1069 | config["remove_notes"]["max_ratio"], 1070 | ) 1071 | _note_msgs = [ 1072 | msg for msg in _note_msgs if random.random() > remove_prob 1073 | ] 1074 | 1075 | # Adjust velocity 1076 | if random.random() < config["adjust_velocity"]["activation_prob"]: 1077 | max_velocity_adjustment = random.randint( 1078 | config["adjust_velocity"]["min_adjust"], 1079 | config["adjust_velocity"]["max_adjust"], 1080 | ) 1081 | 1082 | _note_msgs = [ 1083 | _get_velocity_adjusted_msg(msg, max_velocity_adjustment) 1084 | for msg in _note_msgs 1085 | ] 1086 | 1087 | # Adjust or quantize onsets/offsets 1088 | if len(midi_dict.tempo_msgs) != 1: 1089 | print("Found more than one tempo message, skipping onset noising") 
1090 | elif random.random() < config["adjust_onsets"]["activation_prob"]: 1091 | # Min/max adjustments stored in seconds (_s) 1092 | max_tick_adjustment = second2tick( 1093 | random.uniform( 1094 | config["adjust_onsets"]["min_adjust_s"], 1095 | config["adjust_onsets"]["max_adjust_s"], 1096 | ), 1097 | ticks_per_beat=midi_dict.ticks_per_beat, 1098 | tempo=midi_dict.tempo_msgs[0]["data"], 1099 | ) 1100 | adjust_prob = random.uniform( 1101 | config["adjust_onsets"]["min_ratio"], 1102 | config["adjust_onsets"]["max_ratio"], 1103 | ) 1104 | 1105 | _note_msgs = [ 1106 | ( 1107 | _get_onset_adjusted_msg( 1108 | msg, 1109 | _max_tick_adjustment=max_tick_adjustment, 1110 | ) 1111 | if random.random() < adjust_prob 1112 | else msg 1113 | ) 1114 | for msg in _note_msgs 1115 | ] 1116 | elif random.random() < config["quantize_onsets"]["activation_prob"]: 1117 | q_delta = second2tick( 1118 | random.uniform( 1119 | config["quantize_onsets"]["min_quant_s"], 1120 | config["quantize_onsets"]["min_quant_s"], 1121 | ), 1122 | ticks_per_beat=midi_dict.ticks_per_beat, 1123 | tempo=midi_dict.tempo_msgs[0]["data"], 1124 | ) 1125 | vel_q_delta = config["quantize_onsets"]["max_vel_delta"] 1126 | 1127 | _note_msgs = [ 1128 | ( 1129 | _get_quantized_msg( 1130 | msg, 1131 | _q_delta=q_delta, 1132 | _vel_q_delta=vel_q_delta, 1133 | ) 1134 | ) 1135 | for msg in _note_msgs 1136 | ] 1137 | 1138 | _note_msgs = sorted(_note_msgs, key=lambda _msg: _msg["tick"]) 1139 | 1140 | return MidiDict( 1141 | meta_msgs=midi_dict.meta_msgs, 1142 | tempo_msgs=midi_dict.tempo_msgs, 1143 | pedal_msgs=midi_dict.pedal_msgs, 1144 | instrument_msgs=midi_dict.instrument_msgs, 1145 | note_msgs=_note_msgs, 1146 | ticks_per_beat=midi_dict.ticks_per_beat, 1147 | metadata=midi_dict.metadata, 1148 | ) 1149 | 1150 | 1151 | def export_inference_abs_build_tokenize_fn( 1152 | midi_dict: MidiDict, tokenizer: InferenceAbsTokenizer 1153 | ): 1154 | finetuning_config = load_config()["data"]["finetuning"] 1155 | GUIDANCE_PROB = finetuning_config["guidance_prob"] 1156 | NOISING_PROB = finetuning_config["noising"]["activation_prob"] 1157 | MIN_NOISY_MS = finetuning_config["min_noisy_interval_ms"] 1158 | MAX_NOISY_MS = finetuning_config["max_noisy_interval_ms"] 1159 | MIN_CLEAN_MS = finetuning_config["min_clean_interval_ms"] 1160 | MAX_CLEAN_MS = finetuning_config["max_clean_interval_ms"] 1161 | 1162 | if random.random() <= NOISING_PROB: 1163 | noisy_midi_dict = _noise_midi_dict( 1164 | midi_dict, config=finetuning_config["noising"] 1165 | ) 1166 | midi_dict_for_tokenization = _get_combined_mididict( 1167 | clean_midi_dict=midi_dict, 1168 | noisy_midi_dict=noisy_midi_dict, 1169 | min_noisy_ms=MIN_NOISY_MS, 1170 | max_noisy_ms=MAX_NOISY_MS, 1171 | min_clean_ms=MIN_CLEAN_MS, 1172 | max_clean_ms=MAX_CLEAN_MS, 1173 | ) 1174 | else: 1175 | midi_dict_for_tokenization = midi_dict 1176 | 1177 | if random.random() <= GUIDANCE_PROB: 1178 | return tokenizer.tokenize( 1179 | midi_dict=midi_dict_for_tokenization, 1180 | prompt_intervals_ms=midi_dict_for_tokenization.metadata.get( 1181 | "noisy_intervals", [] 1182 | ), 1183 | guidance_midi_dict=midi_dict, 1184 | ) 1185 | else: 1186 | return tokenizer.tokenize( 1187 | midi_dict=midi_dict_for_tokenization, 1188 | prompt_intervals_ms=midi_dict_for_tokenization.metadata.get( 1189 | "noisy_intervals", [] 1190 | ), 1191 | ) 1192 | 1193 | 1194 | class FinetuningDataset(TrainingDataset): 1195 | """Torch dataset object yielding sequences formatted for fine-tuning.""" 1196 | 1197 | def __init__( 1198 | self, dir_paths: List[str] | 
str, tokenizer: InferenceAbsTokenizer 1199 | ): 1200 | super().__init__(tokenizer=tokenizer) 1201 | 1202 | assert tokenizer.name == "inference_abs", "invalid tokenizer" 1203 | 1204 | if isinstance(dir_paths, str): 1205 | dir_paths = [dir_paths] 1206 | 1207 | self.dir_paths = dir_paths 1208 | self.get_epoch_files_by_dir(dir_paths) 1209 | self.init_epoch(0) 1210 | 1211 | def __len__(self): 1212 | return len(self.index) 1213 | 1214 | def get_loss_mask(self, src_seq: list, tgt_seq: list): 1215 | mask = [False] * len(tgt_seq) 1216 | inside_target = True 1217 | 1218 | for idx, (src_tok, tgt_tok) in enumerate(zip(src_seq, tgt_seq)): 1219 | if src_tok == self.tokenizer.guidance_start_tok: 1220 | inside_target = False 1221 | elif src_tok == self.tokenizer.guidance_end_tok: 1222 | inside_target = True 1223 | elif tgt_tok == self.tokenizer.prompt_start_tok: 1224 | inside_target = False 1225 | elif src_tok == self.tokenizer.prompt_end_tok: 1226 | inside_target = True 1227 | 1228 | if inside_target is True and tgt_tok != self.tokenizer.pad_tok: 1229 | mask[idx] = True 1230 | 1231 | return torch.tensor(mask, dtype=torch.bool) 1232 | 1233 | @classmethod 1234 | def build( 1235 | cls, 1236 | tokenizer: InferenceAbsTokenizer, 1237 | save_dir: str, 1238 | max_seq_len: int, 1239 | num_epochs: int, 1240 | midi_dataset_path: str, 1241 | ): 1242 | 1243 | def _build_epoch(_save_path, _midi_dataset): 1244 | with jsonlines.open(_save_path, mode="w") as writer: 1245 | # Write tokenizer info into json on first line 1246 | writer.write( 1247 | { 1248 | "tokenizer_config": tokenizer.config, 1249 | "tokenizer_name": tokenizer.name, 1250 | "max_seq_len": max_seq_len, 1251 | } 1252 | ) 1253 | 1254 | _idx = 0 1255 | for entry in reservoir( 1256 | get_seqs( 1257 | tokenizer, 1258 | _midi_dataset, 1259 | tokenize_fn=functools.partial( 1260 | export_inference_abs_build_tokenize_fn, 1261 | tokenizer=tokenizer, 1262 | ), 1263 | ), 1264 | 10, 1265 | ): 1266 | for _entry in tokenizer.split(entry, max_seq_len): 1267 | writer.write(_entry) 1268 | 1269 | _idx += 1 1270 | if _idx % 250 == 0: 1271 | logger.info(f"Finished processing {_idx}") 1272 | 1273 | logger = setup_logger() 1274 | assert max_seq_len > 0, "max_seq_len must be greater than 0" 1275 | assert num_epochs > 0, "num_epochs must be greater than 0" 1276 | assert os.path.isfile(midi_dataset_path), "file not found" 1277 | if multiprocessing.get_start_method() == "spawn": 1278 | logger.warning( 1279 | 'The current multiprocessing start method is "spawn", this ' 1280 | "will slow down dataset building" 1281 | ) 1282 | 1283 | if os.path.isdir(save_dir) and os.listdir(save_dir): 1284 | print( 1285 | f"The directory at {save_dir} in non-empty, type [Y/y] to " 1286 | "remove and continue:" 1287 | ) 1288 | if input() not in {"Y", "y"}: 1289 | print("Aborting") 1290 | return 1291 | else: 1292 | shutil.rmtree(save_dir) 1293 | 1294 | if not os.path.exists(save_dir): 1295 | os.mkdir(save_dir) 1296 | 1297 | logger.info( 1298 | f"Building FinetuningDataset with config: " 1299 | f"max_seq_len={max_seq_len}, " 1300 | f"tokenizer_name={tokenizer.name}" 1301 | ) 1302 | 1303 | for idx in range(num_epochs): 1304 | logger.info(f"Building epoch {idx}/{num_epochs - 1}...") 1305 | 1306 | # Reload the combined dataset for each epoch 1307 | midi_dataset = MidiDataset.get_generator(midi_dataset_path) 1308 | _build_epoch( 1309 | _save_path=os.path.join(save_dir, f"epoch{idx}.jsonl"), 1310 | _midi_dataset=midi_dataset, 1311 | ) 1312 | 1313 | logger.info(f"Finished building, saved FinetuningDataset to 
{save_dir}") 1314 | -------------------------------------------------------------------------------- /aria/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import TransformerLM 2 | -------------------------------------------------------------------------------- /aria/inference/model.py: -------------------------------------------------------------------------------- 1 | """Inference implementation with torch-compiler friendly kv-cache.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch.nn import functional as F 7 | from aria.model import ModelConfig 8 | 9 | 10 | class KVCache(nn.Module): 11 | def __init__( 12 | self, 13 | max_batch_size: int, 14 | max_seq_length: int, 15 | n_heads: int, 16 | head_dim: int, 17 | dtype=torch.bfloat16, 18 | ): 19 | super().__init__() 20 | self.dtype = dtype 21 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) 22 | self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) 23 | self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) 24 | 25 | def update(self, input_pos, k_val, v_val): 26 | # input_pos: [S], k_val: [B, H, S, D] 27 | assert input_pos.shape[0] == k_val.shape[2] 28 | 29 | k_out = self.k_cache 30 | v_out = self.v_cache 31 | k_out[:, :, input_pos] = k_val 32 | v_out[:, :, input_pos] = v_val 33 | 34 | return k_out, v_out 35 | 36 | 37 | class TransformerLM(nn.Module): 38 | def __init__(self, model_config: ModelConfig): 39 | super().__init__() 40 | self.model_config = model_config 41 | self.max_seq_len = model_config.max_seq_len 42 | self.model = Transformer(model_config) 43 | self.lm_head = nn.Linear( 44 | model_config.d_model, model_config.vocab_size, bias=False 45 | ) 46 | 47 | def forward( 48 | self, 49 | idxs: torch.Tensor, 50 | input_pos: torch.Tensor, 51 | pad_idxs: torch.Tensor | None = None, 52 | ): 53 | hidden_states = self.model( 54 | idxs=idxs, 55 | input_pos=input_pos, 56 | pad_idxs=pad_idxs, 57 | ) 58 | logits = self.lm_head(hidden_states) 59 | 60 | return logits 61 | 62 | def setup_cache( 63 | self, 64 | batch_size, 65 | max_seq_len=4096, 66 | dtype=torch.bfloat16, 67 | ): 68 | # Init cache 69 | for b in self.model.encode_layers: 70 | b.kv_cache = KVCache( 71 | max_batch_size=batch_size, 72 | max_seq_length=max_seq_len, 73 | n_heads=self.model_config.n_heads, 74 | head_dim=self.model_config.d_model // self.model_config.n_heads, 75 | dtype=dtype, 76 | ).cuda() 77 | 78 | self.model.freqs_cis = precompute_freqs_cis( 79 | seq_len=max_seq_len, 80 | n_elem=self.model_config.d_model // self.model_config.n_heads, 81 | base=500000, 82 | dtype=dtype, 83 | ).cuda() 84 | self.model.causal_mask = torch.tril( 85 | torch.ones(max_seq_len, max_seq_len, dtype=torch.bool) 86 | ).cuda() 87 | 88 | 89 | class Transformer(nn.Module): 90 | def __init__(self, model_config: ModelConfig) -> None: 91 | super().__init__() 92 | self.model_config = model_config 93 | 94 | self.tok_embeddings = nn.Embedding( 95 | num_embeddings=model_config.vocab_size, 96 | embedding_dim=model_config.d_model, 97 | ) 98 | self.encode_layers = nn.ModuleList( 99 | TransformerBlock(model_config) for _ in range(model_config.n_layers) 100 | ) 101 | self.out_layer_norm = nn.LayerNorm(model_config.d_model) 102 | 103 | self.freqs_cis = None 104 | self.casual_mask = None 105 | 106 | def forward( 107 | self, 108 | idxs: torch.Tensor, 109 | input_pos: torch.Tensor, 110 | pad_idxs: torch.Tensor | None = None, 111 | ): 112 | assert self.freqs_cis is not None, "Caches 
must be initialized first" 113 | 114 | mask = self.causal_mask[None, None, input_pos] 115 | 116 | if pad_idxs is not None: 117 | mask = mask & ~(pad_idxs.unsqueeze(1).unsqueeze(1)) 118 | 119 | freqs_cis = self.freqs_cis[input_pos] 120 | 121 | x = self.tok_embeddings(idxs) 122 | for layer in self.encode_layers: 123 | x = layer(x, input_pos, freqs_cis, mask) 124 | 125 | x = self.out_layer_norm(x) 126 | 127 | return x 128 | 129 | 130 | class TransformerBlock(nn.Module): 131 | def __init__(self, model_config: ModelConfig) -> None: 132 | super().__init__() 133 | 134 | self.d_model = model_config.d_model 135 | self.n_heads = model_config.n_heads 136 | self.d_head = self.d_model // self.n_heads 137 | self.max_seq_len = model_config.max_seq_len 138 | 139 | # Att 140 | self.mixed_qkv = nn.Linear( 141 | in_features=model_config.d_model, 142 | out_features=3 * model_config.d_model, 143 | bias=False, 144 | ) 145 | self.att_proj_linear = nn.Linear( 146 | in_features=model_config.d_model, 147 | out_features=model_config.d_model, 148 | bias=False, 149 | ) 150 | 151 | # FF 152 | self.ff_gate_proj = nn.Linear( 153 | in_features=model_config.d_model, 154 | out_features=model_config.d_model * model_config.ff_mult, 155 | bias=False, 156 | ) 157 | self.ff_up_proj = nn.Linear( 158 | in_features=model_config.d_model, 159 | out_features=model_config.d_model * model_config.ff_mult, 160 | bias=False, 161 | ) 162 | self.ff_down_proj = nn.Linear( 163 | in_features=model_config.d_model * model_config.ff_mult, 164 | out_features=model_config.d_model, 165 | bias=False, 166 | ) 167 | 168 | # Pre layer norms 169 | self.norm1 = nn.LayerNorm(model_config.d_model) 170 | self.norm2 = nn.LayerNorm(model_config.d_model) 171 | 172 | # TODO: Fill in args 173 | self.kv_cache = None 174 | 175 | def forward( 176 | self, 177 | x: torch.Tensor, 178 | input_pos: torch.Tensor, 179 | freqs_cis: torch.Tensor, 180 | mask: torch.Tensor, 181 | ): 182 | assert self.kv_cache is not None, "Cache not initialized" 183 | 184 | x += self._att_block( 185 | x=self.norm1(x), 186 | input_pos=input_pos, 187 | freqs_cis=freqs_cis, 188 | mask=mask, 189 | ) 190 | x = x + self._ff_block(self.norm2(x)) 191 | 192 | return x 193 | 194 | def get_kv(self, k: torch.Tensor, v: torch.Tensor, input_pos: torch.Tensor): 195 | k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos) 196 | 197 | return k, v 198 | 199 | def _att_block( 200 | self, 201 | x: torch.Tensor, 202 | input_pos: torch.Tensor, 203 | freqs_cis: torch.Tensor, 204 | mask: torch.Tensor, 205 | ): 206 | 207 | q, k, v = self.mixed_qkv(x).split( 208 | [self.d_model, self.d_model, self.d_model], dim=-1 209 | ) 210 | 211 | batch_size, seq_len, _ = q.shape 212 | q = q.view(batch_size, seq_len, self.n_heads, self.d_head) 213 | k = k.view(batch_size, seq_len, self.n_heads, self.d_head) 214 | v = v.view(batch_size, seq_len, self.n_heads, self.d_head) 215 | 216 | q = apply_rotary_emb(q, freqs_cis) 217 | k = apply_rotary_emb(k, freqs_cis) 218 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 219 | 220 | k, v = self.get_kv(k, v, input_pos=input_pos) 221 | wv = F.scaled_dot_product_attention( 222 | query=q, 223 | key=k, 224 | value=v, 225 | attn_mask=mask, 226 | ) 227 | 228 | # (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d) 229 | wv = wv.transpose(1, 2).reshape( 230 | batch_size, seq_len, self.n_heads * self.d_head 231 | ) 232 | 233 | return self.att_proj_linear(wv) 234 | 235 | def _ff_block(self, x: torch.Tensor): 236 | return self.ff_down_proj( 237 | F.silu(self.ff_gate_proj(x)) * 
self.ff_up_proj(x) 238 | ) 239 | 240 | 241 | def precompute_freqs_cis( 242 | seq_len: int, 243 | n_elem: int, 244 | base: int = 500000, 245 | dtype: torch.dtype = torch.bfloat16, 246 | ): 247 | freqs = 1.0 / ( 248 | base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) 249 | ) 250 | t = torch.arange(seq_len, device=freqs.device) 251 | freqs = torch.outer(t, freqs) 252 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 253 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) 254 | 255 | return cache.to(dtype=dtype) 256 | 257 | 258 | @torch.jit.script 259 | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: 260 | """ 261 | In-place RoPE. Credits to Katherine Crowson: 262 | x shape (b_sz, s_len, n_head, d_head). 263 | cos, sin shape (s_len, d_head // 2). 264 | """ 265 | 266 | d = x.shape[-1] // 2 267 | cos = freqs_cis[..., 0][None, :, None] 268 | sin = freqs_cis[..., 1][None, :, None] 269 | x1, x2 = x[..., :d], x[..., d : d * 2] 270 | tmp = x1.clone() 271 | x1.mul_(cos).addcmul_(x2, sin, value=-1) 272 | x2.mul_(cos).addcmul_(tmp, sin, value=1) 273 | return x 274 | -------------------------------------------------------------------------------- /aria/model.py: -------------------------------------------------------------------------------- 1 | """Training implementation.""" 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.utils.checkpoint 8 | 9 | from torch import nn as nn 10 | from torch.nn import functional as F 11 | 12 | 13 | @dataclass 14 | class ModelConfig: 15 | d_model: int 16 | n_heads: int 17 | n_layers: int 18 | ff_mult: int 19 | drop_p: float 20 | max_seq_len: int 21 | grad_checkpoint: bool 22 | vocab_size: Optional[int] = None 23 | 24 | def set_vocab_size(self, vocab_size: int): 25 | self.vocab_size = vocab_size 26 | 27 | 28 | class FusedEncoderBlock(nn.Module): 29 | def __init__(self, model_config: ModelConfig): 30 | super().__init__() 31 | 32 | self.drop_p = model_config.drop_p 33 | self.n_heads = model_config.n_heads 34 | self.d_head = model_config.d_model // model_config.n_heads 35 | self.max_seq_len = model_config.max_seq_len 36 | 37 | # Attention 38 | self.mixed_qkv = nn.Linear( 39 | in_features=model_config.d_model, 40 | out_features=3 * model_config.d_model, 41 | bias=False, 42 | ) 43 | self.att_proj_linear = nn.Linear( 44 | in_features=model_config.d_model, 45 | out_features=model_config.d_model, 46 | bias=False, 47 | ) 48 | 49 | # FF Layer 50 | self.ff_gate_proj = nn.Linear( 51 | in_features=model_config.d_model, 52 | out_features=model_config.d_model * model_config.ff_mult, 53 | bias=False, 54 | ) 55 | self.ff_up_proj = nn.Linear( 56 | in_features=model_config.d_model, 57 | out_features=model_config.d_model * model_config.ff_mult, 58 | bias=False, 59 | ) 60 | self.ff_down_proj = nn.Linear( 61 | in_features=model_config.d_model * model_config.ff_mult, 62 | out_features=model_config.d_model, 63 | bias=False, 64 | ) 65 | 66 | # Pre layer norms 67 | self.norm1 = nn.LayerNorm(model_config.d_model) 68 | self.norm2 = nn.LayerNorm(model_config.d_model) 69 | 70 | def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): 71 | x = x + self._att_block(self.norm1(x), freqs_cis) 72 | x = x + self._ff_block(self.norm2(x)) 73 | 74 | return x 75 | 76 | def _att_block(self, x: torch.Tensor, freqs_cis: torch.Tensor): 77 | batch_size, seq_len, _ = x.shape 78 | mixed_qkv = self.mixed_qkv(x) 79 | xq, xk, xv = mixed_qkv.chunk(3, -1) 80 | 81 | # Reshape for rotary 
embeddings 82 | # Need contiguous for q, k since in-place RoPE cannot be applied on a view 83 | xq = xq.reshape( 84 | batch_size, seq_len, self.n_heads, self.d_head 85 | ).contiguous() 86 | xk = xk.reshape( 87 | batch_size, seq_len, self.n_heads, self.d_head 88 | ).contiguous() 89 | xv = xv.view(batch_size, seq_len, self.n_heads, self.d_head) 90 | 91 | # apply_rotary_post_emb expects: (b_sz, s_len, n_head, d_head) 92 | xq = apply_rotary_emb(xq, freqs_cis) 93 | xk = apply_rotary_emb(xk, freqs_cis) 94 | xq, xk, xv = map(lambda t: t.transpose(1, 2), (xq, xk, xv)) 95 | 96 | # scaled_dot_product_attention expects: (b_sz, n_head, s_len, d_head) 97 | att = F.scaled_dot_product_attention( 98 | query=xq, 99 | key=xk, 100 | value=xv, 101 | is_causal=True, 102 | ) 103 | 104 | # Reshape for out: (b_sz, s_len, n_head, d_head) 105 | out = att.transpose(1, 2).contiguous() 106 | out = out.view(batch_size, seq_len, self.n_heads * self.d_head) 107 | 108 | return self.att_proj_linear(out) 109 | 110 | def _ff_block(self, x: torch.Tensor): 111 | 112 | return self.ff_down_proj( 113 | F.silu(self.ff_gate_proj(x)) * self.ff_up_proj(x) 114 | ) 115 | 116 | 117 | class Transformer(nn.Module): 118 | """Transformer decoder with no language model head. 119 | 120 | Args: 121 | model_config (ModelConfig): Model config settings. 122 | """ 123 | 124 | def __init__(self, model_config: ModelConfig): 125 | super().__init__() 126 | self.model_config = model_config 127 | self.freqs_cis = None 128 | 129 | self.tok_embeddings = nn.Embedding( 130 | num_embeddings=model_config.vocab_size, 131 | embedding_dim=model_config.d_model, 132 | ) 133 | 134 | self.out_layer_norm = nn.LayerNorm(model_config.d_model) 135 | self.encode_layers = nn.ModuleList() 136 | for _ in range(model_config.n_layers): 137 | self.encode_layers.append(FusedEncoderBlock(model_config)) 138 | 139 | def forward( 140 | self, 141 | src: torch.Tensor, 142 | ): 143 | """Forward pass of Transformer. 144 | 145 | Args: 146 | src (torch.tensor): Input to encoder block, of shape (batch_size, 147 | seq_len, d_model). 148 | attn_mask (Optional[torch.tensor]): Attention mask of shape 149 | (batch_size, seq_len). Defaults to None. 150 | past_kv (Optional[list[KVCache]]): a list of kv caches. The list index 151 | corresponds to the layer index. 152 | 153 | Returns: 154 | torch.tensor: Model outputs with shape (batch_size, seq_len, 155 | d_model). 156 | """ 157 | hidden_states = self.tok_embeddings(src) 158 | 159 | if self.freqs_cis is None: 160 | self.freqs_cis = precompute_freqs_cis( 161 | seq_len=self.model_config.max_seq_len, 162 | n_elem=self.model_config.d_model // self.model_config.n_heads, 163 | base=500000, 164 | dtype=hidden_states.dtype, 165 | ).to(src.device) 166 | freqs_cis = self.freqs_cis[: src.shape[1]] 167 | 168 | if self.model_config.grad_checkpoint is True: 169 | for layer in self.encode_layers: 170 | 171 | def create_custom_forward(module): 172 | def custom_forward(*args): 173 | return module(*args) 174 | 175 | return custom_forward 176 | 177 | hidden_states = torch.utils.checkpoint.checkpoint( 178 | create_custom_forward(layer), 179 | hidden_states, 180 | freqs_cis, 181 | preserve_rng_state=True, 182 | use_reentrant=True, 183 | ) 184 | else: 185 | for layer in self.encode_layers: 186 | hidden_states = layer(hidden_states, freqs_cis=freqs_cis) 187 | 188 | return self.out_layer_norm(hidden_states) 189 | 190 | 191 | class TransformerLM(nn.Module): 192 | """Transformer decoder with head for language modelling. 
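    A minimal usage sketch (the hyperparameter values below are illustrative
    placeholders; real values come from the model config JSON loaded via
    load_model_config, as done in aria/run.py):

        config = ModelConfig(
            d_model=512, n_heads=8, n_layers=8, ff_mult=4,
            drop_p=0.1, max_seq_len=4096, grad_checkpoint=False,
        )
        config.set_vocab_size(tokenizer.vocab_size)
        model = TransformerLM(config)
        logits = model(src)  # src: (batch_size, seq_len) token ids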
193 | 194 | Args: 195 | model_config (ModelConfig): Model config settings. 196 | """ 197 | 198 | def __init__(self, model_config: ModelConfig): 199 | super().__init__() 200 | 201 | self.max_seq_len = model_config.max_seq_len 202 | self.model = Transformer(model_config) 203 | self.lm_head = nn.Linear( 204 | model_config.d_model, model_config.vocab_size, bias=False 205 | ) 206 | 207 | def forward( 208 | self, 209 | src: torch.Tensor, 210 | ): 211 | """Forward pass of Transformer decoder with LM head. 212 | 213 | Args: 214 | src (torch.tensor): Input token ids of shape (batch_size, 215 | seq_len). 216 | 217 | Note: this training implementation takes no attention mask or 218 | kv-cache arguments; see aria/inference/model.py for the cached 219 | inference version. 220 | 221 | Returns: 222 | torch.tensor: Forward pass of src through Transformer and LM head. 223 | Has shape (batch_size, seq_len, vocab_size). 224 | """ 225 | hidden = self.model(src) 226 | logits = self.lm_head(hidden) 227 | 228 | return logits 229 | 230 | 231 | def precompute_freqs_cis( 232 | seq_len: int, 233 | n_elem: int, 234 | base: int = 500000, 235 | dtype: torch.dtype = torch.bfloat16, 236 | ): 237 | freqs = 1.0 / ( 238 | base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) 239 | ) 240 | t = torch.arange(seq_len, device=freqs.device) 241 | freqs = torch.outer(t, freqs) 242 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 243 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) 244 | 245 | return cache.to(dtype=dtype) 246 | 247 | 248 | @torch.jit.script 249 | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: 250 | """ 251 | In-place RoPE. Credits to Katherine Crowson: 252 | x shape (b_sz, s_len, n_head, d_head). 253 | cos, sin shape (s_len, d_head // 2). 
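    Written out, the in-place update below rotates each feature pair
    (x1, x2) by the cached angles:

        x1_new = x1 * cos - x2 * sin
        x2_new = x2 * cos + x1 * sin

    (x1 is cloned into tmp first, so the second line uses the pre-rotation
    value of x1).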
254 | """ 255 | 256 | d = x.shape[-1] // 2 257 | cos = freqs_cis[..., 0][None, :, None] 258 | sin = freqs_cis[..., 1][None, :, None] 259 | x1, x2 = x[..., :d], x[..., d : d * 2] 260 | tmp = x1.clone() 261 | x1.mul_(cos).addcmul_(x2, sin, value=-1) 262 | x2.mul_(cos).addcmul_(tmp, sin, value=1) 263 | return x 264 | -------------------------------------------------------------------------------- /aria/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | import re 6 | import sys 7 | 8 | 9 | def _parse_sample_args(): 10 | argp = argparse.ArgumentParser(prog="aria sample") 11 | argp.add_argument("-m", help="name of model config file") 12 | argp.add_argument("-c", help="path to model checkpoint") 13 | argp.add_argument("-p", help="path to midi file") 14 | argp.add_argument( 15 | "-temp", 16 | help="sampling temperature value", 17 | type=float, 18 | required=False, 19 | default=0.95, 20 | ) 21 | argp.add_argument( 22 | "-top_p", 23 | help="sampling top_p value", 24 | type=float, 25 | required=False, 26 | default=0.95, 27 | ) 28 | argp.add_argument( 29 | "-cfg", 30 | help="sampling cfg gamma value", 31 | type=float, 32 | required=False, 33 | ) 34 | argp.add_argument( 35 | "-metadata", 36 | nargs=2, 37 | metavar=("KEY", "VALUE"), 38 | action="append", 39 | help="manually add metadata key-value pair when sampling", 40 | ) 41 | argp.add_argument( 42 | "-var", 43 | help="number of variations", 44 | type=int, 45 | default=1, 46 | ) 47 | argp.add_argument( 48 | "-trunc", 49 | help="length (in seconds) of the prompt", 50 | type=int, 51 | default=20, 52 | ) 53 | argp.add_argument("-e", action="store_true", help="enable force end") 54 | argp.add_argument("-l", type=int, help="generation length", default=1024) 55 | argp.add_argument( 56 | "-guidance_path", type=str, help="path to guidance MIDI", required=False 57 | ) 58 | argp.add_argument( 59 | "-guidance_start_ms", 60 | help="guidance interval start (ms)", 61 | type=int, 62 | required=False, 63 | ) 64 | argp.add_argument( 65 | "-guidance_end_ms", 66 | help="guidance interval end (ms)", 67 | type=int, 68 | required=False, 69 | ) 70 | argp.add_argument("-compile", action="store_true", help="compile cudagraph") 71 | 72 | return argp.parse_args(sys.argv[2:]) 73 | 74 | 75 | def sample(args): 76 | """Entrypoint for sampling""" 77 | 78 | from torch.cuda import is_available as cuda_is_available 79 | from aria.inference import TransformerLM 80 | from aria.model import ModelConfig 81 | from aria.config import load_model_config, load_config 82 | from aria.tokenizer import InferenceAbsTokenizer 83 | from aria.sample import ( 84 | sample_batch_cfg, 85 | sample_batch, 86 | get_inference_prompt, 87 | ) 88 | from ariautils.midi import MidiDict 89 | from aria.utils import _load_weight 90 | 91 | if not cuda_is_available(): 92 | raise Exception("CUDA device is not available.") 93 | 94 | model_state = _load_weight(args.c, "cuda") 95 | model_state = { 96 | k.replace("_orig_mod.", ""): v for k, v in model_state.items() 97 | } 98 | 99 | manual_metadata = {k: v for k, v in args.metadata} if args.metadata else {} 100 | valid_metadata = load_config()["data"]["metadata"]["manual"] 101 | for k, v in manual_metadata.copy().items(): 102 | assert k in valid_metadata.keys(), f"{manual_metadata} is invalid" 103 | if v not in valid_metadata[k]: 104 | print(f"Ignoring invalid manual metadata: {k}") 105 | print(f"Please choose from {valid_metadata[k]}") 106 | del manual_metadata[k] 107 | 
108 | num_variations = args.var 109 | truncate_len = args.trunc 110 | force_end = args.e 111 | model_name = args.m 112 | 113 | tokenizer = InferenceAbsTokenizer() 114 | model_config = ModelConfig(**load_model_config(model_name)) 115 | model_config.set_vocab_size(tokenizer.vocab_size) 116 | model_config.grad_checkpoint = False 117 | model = TransformerLM(model_config).cuda() 118 | 119 | try: 120 | model.load_state_dict(model_state) 121 | except Exception as e: 122 | print( 123 | "Failed to load model_state. This is likely due to an incompatibility " 124 | "between the checkpoint file (-c) and model name/config (-m)." 125 | ) 126 | raise e 127 | 128 | assert args.l > 0, "Generation length must be positive." 129 | max_new_tokens = args.l 130 | 131 | # Load and format prompts and metadata 132 | midi_dict = MidiDict.from_midi(mid_path=args.p) 133 | if args.guidance_path: 134 | guidance_midi_dict = MidiDict.from_midi(mid_path=args.guidance_path) 135 | else: 136 | guidance_midi_dict = None 137 | 138 | for k, v in manual_metadata.items(): 139 | midi_dict.metadata[k] = v 140 | 141 | print(f"Extracted metadata: {midi_dict.metadata}") 142 | print( 143 | f"Instruments: {set([MidiDict.get_program_to_instrument()[msg['data']] for msg in midi_dict.instrument_msgs])}" 144 | ) 145 | 146 | prompt_seq, guidance_seq = get_inference_prompt( 147 | tokenizer=tokenizer, 148 | midi_dict=midi_dict, 149 | truncate_len=truncate_len, 150 | guidance_start_ms=args.guidance_start_ms, 151 | guidance_end_ms=args.guidance_end_ms, 152 | guidance_midi_dict=guidance_midi_dict, 153 | ) 154 | 155 | if len(prompt_seq) + args.l > model_config.max_seq_len: 156 | print( 157 | "WARNING: Required context exceeds max_seq_len supported by model" 158 | ) 159 | prompts = [prompt_seq for _ in range(num_variations)] 160 | 161 | samples_dir = os.path.join(os.path.dirname(__file__), "..", "samples") 162 | if os.path.isdir(samples_dir) is False: 163 | os.mkdir(samples_dir) 164 | if guidance_seq: 165 | tokenizer.detokenize(guidance_seq).to_midi().save( 166 | os.path.join(samples_dir, f"guidance.mid") 167 | ) 168 | if args.cfg is not None and guidance_seq is not None: 169 | results = sample_batch_cfg( 170 | model=model, 171 | tokenizer=tokenizer, 172 | prompts=prompts, 173 | max_new_tokens=max_new_tokens, 174 | cfg_gamma=args.cfg, 175 | force_end=force_end, 176 | temperature=args.temp, 177 | top_p=args.top_p, 178 | compile=args.compile, 179 | ) 180 | else: 181 | results = sample_batch( 182 | model=model, 183 | tokenizer=tokenizer, 184 | prompts=prompts, 185 | max_new_tokens=max_new_tokens, 186 | force_end=force_end, 187 | temperature=args.temp, 188 | top_p=args.top_p, 189 | compile=args.compile, 190 | ) 191 | 192 | for idx, tokenized_seq in enumerate(results): 193 | res_midi_dict = tokenizer.detokenize(tokenized_seq) 194 | res_midi = res_midi_dict.to_midi() 195 | res_midi.save(os.path.join(samples_dir, f"res_{idx + 1}.mid")) 196 | 197 | print("Results saved to samples/") 198 | 199 | 200 | def _parse_midi_dataset_args(): 201 | argp = argparse.ArgumentParser(prog="aria midi-dataset") 202 | argp.add_argument("dir", help="directory containing midi files") 203 | argp.add_argument("save_path", help="path to save dataset") 204 | argp.add_argument("-r", action="store_true", help="recursively search dirs") 205 | argp.add_argument( 206 | "-s", action="store_true", help="shuffle dataset", default=False 207 | ) 208 | argp.add_argument( 209 | "-metadata", 210 | nargs=2, 211 | metavar=("KEY", "VALUE"), 212 | action="append", 213 | help="manually add metadata 
key-value pair when building dataset", 214 | ) 215 | argp.add_argument( 216 | "-split", type=float, help="create train/val split", required=False 217 | ) 218 | 219 | return argp.parse_args(sys.argv[2:]) 220 | 221 | 222 | def build_midi_dataset(args): 223 | """Entrypoint for building MidiDatasets from a directory""" 224 | from aria.datasets import MidiDataset 225 | 226 | assert args.dir, "build directory must be provided" 227 | manual_metadata = {k: v for k, v in args.metadata} if args.metadata else {} 228 | MidiDataset.build_to_file( 229 | dir=args.dir, 230 | save_path=args.save_path, 231 | recur=args.r, 232 | overwrite=True, 233 | manual_metadata=manual_metadata, 234 | shuffle=args.s, 235 | ) 236 | 237 | if args.split: 238 | assert 0.0 < args.split < 1.0, "Invalid range given for -split" 239 | MidiDataset.split_from_file( 240 | load_path=args.save_path, 241 | train_val_ratio=args.split, 242 | repeatable=True, 243 | ) 244 | 245 | 246 | def _parse_pretrain_dataset_args(): 247 | argp = argparse.ArgumentParser(prog="aria pretrain-dataset") 248 | argp.add_argument("-load_path", help="path midi_dict dataset") 249 | argp.add_argument("-save_dir", help="path to save dataset") 250 | argp.add_argument( 251 | "-tokenizer_name", help="tokenizer name", choices=["abs", "rel"] 252 | ) 253 | argp.add_argument("-l", help="max sequence length", type=int, default=4096) 254 | argp.add_argument("-e", help="num epochs", type=int, default=1) 255 | 256 | return argp.parse_args(sys.argv[2:]) 257 | 258 | 259 | def build_pretraining_dataset(args): 260 | from ariautils.tokenizer import AbsTokenizer, RelTokenizer 261 | from aria.datasets import PretrainingDataset 262 | 263 | if args.tokenizer_name == "abs": 264 | tokenizer = AbsTokenizer() 265 | elif args.tokenizer_name == "rel": 266 | tokenizer = RelTokenizer() 267 | 268 | PretrainingDataset.build( 269 | tokenizer=tokenizer, 270 | save_dir=args.save_dir, 271 | max_seq_len=args.l, 272 | num_epochs=args.e, 273 | midi_dataset_path=args.load_path, 274 | ) 275 | 276 | 277 | def _parse_finetune_dataset_args(): 278 | argp = argparse.ArgumentParser(prog="aria finetune-dataset") 279 | argp.add_argument( 280 | "-midi_dataset_path", 281 | help="path to midi_dict dataset", 282 | ) 283 | argp.add_argument("-save_dir", help="path to save dataset") 284 | argp.add_argument("-l", help="max sequence length", type=int, default=4096) 285 | argp.add_argument("-e", help="num epochs", type=int, default=1) 286 | 287 | return argp.parse_args(sys.argv[2:]) 288 | 289 | 290 | def build_finetune_dataset(args): 291 | from aria.tokenizer import InferenceAbsTokenizer 292 | from aria.datasets import FinetuningDataset 293 | 294 | tokenizer = InferenceAbsTokenizer() 295 | FinetuningDataset.build( 296 | tokenizer=tokenizer, 297 | save_dir=args.save_dir, 298 | max_seq_len=args.l, 299 | num_epochs=args.e, 300 | midi_dataset_path=args.midi_dataset_path, 301 | ) 302 | 303 | 304 | def main(): 305 | # Nested argparse inspired by - https://shorturl.at/kuKW0 306 | parser = argparse.ArgumentParser(usage="aria []") 307 | parser.add_argument( 308 | "command", 309 | help="command to run", 310 | choices=( 311 | "sample", 312 | "midi-dataset", 313 | "pretrain-dataset", 314 | "finetune-dataset", 315 | ), 316 | ) 317 | 318 | # parse_args defaults to [1:] for args, but you need to 319 | # exclude the rest of the args too, or validation will fail 320 | args = parser.parse_args(sys.argv[1:2]) 321 | 322 | if not hasattr(args, "command"): 323 | parser.print_help() 324 | print("Unrecognized command") 325 | exit(1) 326 | 
elif args.command == "sample": 327 | sample(args=_parse_sample_args()) 328 | elif args.command == "midi-dataset": 329 | build_midi_dataset(args=_parse_midi_dataset_args()) 330 | elif args.command == "pretrain-dataset": 331 | build_pretraining_dataset(args=_parse_pretrain_dataset_args()) 332 | elif args.command == "finetune-dataset": 333 | build_finetune_dataset(args=_parse_finetune_dataset_args()) 334 | else: 335 | print("Unrecognized command") 336 | parser.print_help() 337 | exit(1) 338 | 339 | 340 | if __name__ == "__main__": 341 | main() 342 | -------------------------------------------------------------------------------- /aria/sample.py: -------------------------------------------------------------------------------- 1 | """Contains generation/sampling code""" 2 | 3 | import copy 4 | import torch 5 | import torch._dynamo.config 6 | import torch._inductor.config 7 | 8 | from typing import List 9 | from tqdm import tqdm 10 | 11 | from aria.inference import TransformerLM 12 | from aria.tokenizer import InferenceAbsTokenizer 13 | from ariautils.tokenizer import Tokenizer, AbsTokenizer 14 | from ariautils.midi import MidiDict 15 | 16 | torch._inductor.config.coordinate_descent_tuning = True 17 | torch._inductor.config.triton.unique_kernel_names = True 18 | torch._inductor.config.fx_graph_cache = True 19 | 20 | 21 | def get_cfg_prompt(prompts: list, pad_tok: str, guidance_end_tok: str): 22 | cfg_prompts = [] 23 | for prompt in prompts: 24 | prompt_no_guidance = prompt[prompt.index(guidance_end_tok) + 1 :] 25 | prompt_no_guidance = [pad_tok] * ( 26 | len(prompt) - len(prompt_no_guidance) 27 | ) + prompt_no_guidance 28 | cfg_prompts.append(prompt) 29 | cfg_prompts.append(prompt_no_guidance) 30 | 31 | return cfg_prompts 32 | 33 | 34 | @torch.inference_mode() 35 | def decode_one( 36 | model: TransformerLM, 37 | idxs: torch.Tensor, 38 | input_pos: torch.Tensor, 39 | pad_idxs: torch.Tensor | None = None, 40 | ): 41 | logits = model.forward( 42 | idxs=idxs, 43 | input_pos=input_pos, 44 | pad_idxs=pad_idxs, 45 | )[:, -1] 46 | 47 | return logits 48 | 49 | 50 | @torch.inference_mode() 51 | def prefill( 52 | model: TransformerLM, 53 | idxs: torch.Tensor, 54 | input_pos: torch.Tensor, 55 | pad_idxs: torch.Tensor | None = None, 56 | ): 57 | logits = model.forward(idxs=idxs, input_pos=input_pos, pad_idxs=pad_idxs)[ 58 | :, -1 59 | ] 60 | 61 | return logits 62 | 63 | 64 | def update_seq_ids_( 65 | seq: torch.Tensor, 66 | idx: int, 67 | next_token_ids: torch.Tensor, 68 | dim_tok_inserted: list, 69 | eos_tok_seen: list, 70 | max_len: int, 71 | force_end: bool, 72 | tokenizer: Tokenizer, 73 | ): 74 | # Insert dim and pad toks 75 | for _idx in range(seq.shape[0]): 76 | if eos_tok_seen[_idx] == True: 77 | next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.pad_tok] 78 | elif ( 79 | force_end 80 | and idx >= max_len - 130 81 | and dim_tok_inserted[_idx] is False 82 | and tokenizer.id_to_tok[next_token_ids[_idx].item()][0] 83 | not in ("dur", "onset") 84 | ): 85 | next_token_ids[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok] 86 | 87 | # Update dim_tok_inserted and eos_tok_seen 88 | if next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]: 89 | dim_tok_inserted[_idx] = True 90 | elif next_token_ids[_idx] == tokenizer.tok_to_id[tokenizer.eos_tok]: 91 | eos_tok_seen[_idx] = True 92 | 93 | seq[:, idx] = next_token_ids 94 | 95 | 96 | # TODO: Not working 97 | @torch.autocast( 98 | "cuda", 99 | dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, 100 | ) 101 | @torch.inference_mode() 102 
| def sample_batch( 103 | model: TransformerLM, 104 | tokenizer: Tokenizer, 105 | prompts: List[list], 106 | max_new_tokens: int, 107 | force_end=False, 108 | temperature: float = 0.95, 109 | top_p: float = 0.95, 110 | compile: bool = False, 111 | ): 112 | if force_end: 113 | assert max_new_tokens > 130, "prompt too long to use force_end=True" 114 | 115 | _prompt_len = len(prompts[0]) 116 | _num_prompts = len(prompts) 117 | assert all([len(p) == _prompt_len for p in prompts]) 118 | 119 | model.eval() 120 | dim_tok_inserted = [False for _ in range(_num_prompts)] 121 | eos_tok_seen = [False for _ in range(_num_prompts)] 122 | total_len = _prompt_len + max_new_tokens 123 | seq = torch.stack( 124 | [ 125 | torch.tensor( 126 | tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p))) 127 | ) 128 | for p in prompts 129 | ] 130 | ).cuda() 131 | 132 | if compile is True: 133 | global decode_one 134 | decode_one = torch.compile( 135 | decode_one, 136 | mode="reduce-overhead", 137 | fullgraph=True, 138 | ) 139 | 140 | model.setup_cache( 141 | batch_size=_num_prompts, 142 | max_seq_len=total_len, 143 | dtype=( 144 | torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 145 | ), 146 | ) 147 | 148 | print( 149 | f"Using hyperparams: temp={temperature}, top_p={top_p}, gen_len={max_new_tokens}" 150 | ) 151 | 152 | for idx in ( 153 | pbar := tqdm( 154 | range(_prompt_len, total_len), 155 | total=total_len - _prompt_len, 156 | leave=False, 157 | ) 158 | ): 159 | with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): 160 | if idx == _prompt_len: 161 | logits = prefill( 162 | model, 163 | idxs=seq[:, :idx], 164 | input_pos=torch.arange(0, idx, device=seq.device), 165 | ) 166 | else: 167 | logits = decode_one( 168 | model, 169 | idxs=seq[:, idx - 1 : idx], 170 | input_pos=torch.tensor( 171 | [idx - 1], device=seq.device, dtype=torch.int 172 | ), 173 | ) 174 | 175 | if tokenizer.name == "inference_abs": 176 | logits[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( 177 | "-inf" 178 | ) 179 | 180 | if temperature > 0.0: 181 | probs = torch.softmax(logits / temperature, dim=-1) 182 | next_token_ids = sample_top_p(probs, top_p).flatten() 183 | else: 184 | next_token_ids = torch.argmax(logits, dim=-1).flatten() 185 | 186 | update_seq_ids_( 187 | seq=seq, 188 | idx=idx, 189 | next_token_ids=next_token_ids, 190 | dim_tok_inserted=dim_tok_inserted, 191 | eos_tok_seen=eos_tok_seen, 192 | max_len=total_len, 193 | force_end=force_end, 194 | tokenizer=tokenizer, 195 | ) 196 | 197 | if all(seen_eos is True for seen_eos in eos_tok_seen): 198 | break 199 | 200 | decoded_results = [tokenizer.decode(s) for s in seq.tolist()] 201 | decoded_results = [ 202 | ( 203 | res[: res.index(tokenizer.eos_tok) + 1] 204 | if tokenizer.eos_tok in res 205 | else res 206 | ) 207 | for res in decoded_results 208 | ] 209 | 210 | return decoded_results 211 | 212 | 213 | @torch.autocast( 214 | "cuda", 215 | dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, 216 | ) 217 | @torch.inference_mode() 218 | def sample_batch_cfg( 219 | model: TransformerLM, 220 | tokenizer: InferenceAbsTokenizer, 221 | prompts: List[list], 222 | max_new_tokens: int, 223 | cfg_gamma: float, 224 | force_end=False, 225 | temperature: float = 0.95, 226 | top_p: float = 0.95, 227 | compile: bool = False, 228 | ): 229 | assert 0.0 <= cfg_gamma <= 2.0 230 | assert 0.0 <= temperature <= 2.0 231 | assert 0.5 <= top_p <= 1.0 232 | assert tokenizer.name == "inference_abs" 233 | if force_end: 234 | assert 
max_new_tokens > 130, "prompt too long to use force_end=True" 235 | 236 | prompts = get_cfg_prompt( 237 | prompts, tokenizer.pad_tok, tokenizer.guidance_end_tok 238 | ) 239 | 240 | _prompt_len = len(prompts[0]) 241 | _num_prompts = len(prompts) 242 | assert all([len(p) == _prompt_len for p in prompts]) 243 | 244 | model.eval() 245 | total_len = _prompt_len + max_new_tokens 246 | seq = torch.stack( 247 | [ 248 | torch.tensor( 249 | tokenizer.encode(p + [tokenizer.pad_tok] * (total_len - len(p))) 250 | ) 251 | for p in prompts 252 | ] 253 | ).cuda() 254 | dim_tok_inserted = [False for _ in range(_num_prompts)] 255 | eos_tok_seen = [False for _ in range(_num_prompts)] 256 | 257 | if compile is True: 258 | global decode_one 259 | decode_one = torch.compile( 260 | decode_one, 261 | mode="reduce-overhead", 262 | fullgraph=True, 263 | ) 264 | 265 | model.setup_cache( 266 | batch_size=_num_prompts, 267 | max_seq_len=total_len, 268 | dtype=( 269 | torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 270 | ), 271 | ) 272 | 273 | print( 274 | f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" 275 | ) 276 | 277 | for idx in ( 278 | pbar := tqdm( 279 | range(_prompt_len, total_len), 280 | total=total_len - _prompt_len, 281 | leave=False, 282 | ) 283 | ): 284 | with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): 285 | if idx == _prompt_len: 286 | logits = prefill( 287 | model, 288 | idxs=seq[:, :idx], 289 | input_pos=torch.arange(0, idx, device=seq.device), 290 | pad_idxs=(seq == tokenizer.pad_id), 291 | ) 292 | else: 293 | logits = decode_one( 294 | model, 295 | idxs=seq[:, idx - 1 : idx], 296 | input_pos=torch.tensor( 297 | [idx - 1], device=seq.device, dtype=torch.int 298 | ), 299 | pad_idxs=(seq == tokenizer.pad_id), 300 | ) 301 | 302 | logits_cfg = cfg_gamma * logits[::2] + (1 - cfg_gamma) * logits[1::2] 303 | logits_cfg[:, tokenizer.tok_to_id[tokenizer.prompt_start_tok]] = float( 304 | "-inf" 305 | ) 306 | logits_cfg[:, tokenizer.tok_to_id[tokenizer.dim_tok]] = float("-inf") 307 | 308 | if temperature > 0.0: 309 | probs = torch.softmax(logits_cfg / temperature, dim=-1) 310 | next_token_ids = sample_top_p(probs, top_p).flatten() 311 | else: 312 | next_token_ids = torch.argmax(logits_cfg, dim=-1).flatten() 313 | 314 | next_token_ids = next_token_ids.repeat_interleave(2) 315 | update_seq_ids_( 316 | seq=seq, 317 | idx=idx, 318 | next_token_ids=next_token_ids, 319 | dim_tok_inserted=dim_tok_inserted, 320 | eos_tok_seen=eos_tok_seen, 321 | max_len=total_len, 322 | force_end=force_end, 323 | tokenizer=tokenizer, 324 | ) 325 | 326 | if all(seen_eos is True for seen_eos in eos_tok_seen): 327 | break 328 | 329 | decoded_results = [tokenizer.decode(s) for s in seq.tolist()][::2] 330 | decoded_results = [ 331 | ( 332 | res[: res.index(tokenizer.eos_tok) + 1] 333 | if tokenizer.eos_tok in res 334 | else res 335 | ) 336 | for res in decoded_results 337 | ] 338 | 339 | return decoded_results 340 | 341 | 342 | def sample_top_p(probs, p): 343 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 344 | probs_sum = torch.cumsum(probs_sort, dim=-1) 345 | mask = probs_sum - probs_sort > p 346 | probs_sort[mask] = 0.0 347 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 348 | next_token = torch.multinomial(probs_sort, num_samples=1) 349 | next_token = torch.gather(probs_idx, -1, next_token) 350 | return next_token 351 | 352 | 353 | def get_inference_prompt( 354 | tokenizer: InferenceAbsTokenizer, 355 | 
midi_dict: MidiDict, 356 | truncate_len: int, 357 | guidance_start_ms: int, 358 | guidance_end_ms: int, 359 | guidance_midi_dict: MidiDict | None = None, 360 | ): 361 | assert tokenizer.name == "inference_abs" 362 | 363 | if guidance_midi_dict is not None: 364 | assert guidance_start_ms is not None and guidance_start_ms >= 0 365 | assert guidance_end_ms is not None and guidance_end_ms >= 0 366 | assert ( 367 | tokenizer._config["guidance"]["min_ms"] 368 | <= guidance_end_ms - guidance_start_ms 369 | <= tokenizer._config["guidance"]["max_ms"] 370 | ) 371 | 372 | prompt_seq = tokenizer.tokenize( 373 | midi_dict=midi_dict, 374 | prompt_intervals_ms=( 375 | [[0, truncate_len * 1e3]] if truncate_len > 0 else [] 376 | ), 377 | guidance_midi_dict=guidance_midi_dict, 378 | guidance_start_ms=guidance_start_ms, 379 | guidance_end_ms=guidance_end_ms, 380 | ) 381 | 382 | if tokenizer.prompt_end_tok in prompt_seq: 383 | prompt_seq = prompt_seq[ 384 | : prompt_seq.index(tokenizer.prompt_end_tok) + 1 385 | ] 386 | else: 387 | print("No notes found in prompt region") 388 | prompt_seq = prompt_seq[: prompt_seq.index(tokenizer.bos_tok) + 1] 389 | 390 | if tokenizer.dim_tok in prompt_seq: 391 | prompt_seq.remove(tokenizer.dim_tok) 392 | 393 | if ( 394 | guidance_midi_dict is not None 395 | and tokenizer.guidance_start_tok in prompt_seq 396 | ): 397 | guidance_seq = copy.deepcopy(prompt_seq) 398 | guidance_seq = guidance_seq[ 399 | : guidance_seq.index(tokenizer.guidance_end_tok) 400 | ] 401 | guidance_seq[0] = ("prefix", "instrument", "piano") 402 | else: 403 | guidance_seq = None 404 | 405 | return prompt_seq, guidance_seq 406 | -------------------------------------------------------------------------------- /aria/tokenizer.py: -------------------------------------------------------------------------------- 1 | """Tokenizer for MIDI conditioned completions""" 2 | 3 | import copy 4 | import random 5 | import functools 6 | 7 | from typing import Callable 8 | 9 | from aria.config import load_config 10 | from ariautils.midi import MidiDict 11 | from ariautils.tokenizer import AbsTokenizer as _AbsTokenizer 12 | 13 | 14 | class InferenceAbsTokenizer(_AbsTokenizer): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | self.name = "inference_abs" 19 | self._config = load_config()["tokenizer"]["inference_abs"] 20 | 21 | self.prompt_start_tok = "" 22 | self.prompt_end_tok = "" 23 | self.guidance_start_tok = "" 24 | self.guidance_end_tok = "" 25 | 26 | self.add_tokens_to_vocab( 27 | [ 28 | self.prompt_start_tok, 29 | self.prompt_end_tok, 30 | self.guidance_start_tok, 31 | self.guidance_end_tok, 32 | ] 33 | ) 34 | self.special_tokens.append(self.prompt_start_tok) 35 | self.special_tokens.append(self.prompt_end_tok) 36 | self.special_tokens.append(self.guidance_start_tok) 37 | self.special_tokens.append(self.guidance_end_tok) 38 | 39 | def _get_guidance_interval_ms(self, guidance_midi_dict: MidiDict): 40 | first_note_onset_ms = guidance_midi_dict.tick_to_ms( 41 | guidance_midi_dict.note_msgs[0]["tick"] 42 | ) 43 | last_note_onset_ms = guidance_midi_dict.tick_to_ms( 44 | guidance_midi_dict.note_msgs[-1]["tick"] 45 | ) 46 | guidance_segment_length_ms = random.randint( 47 | self._config["guidance"]["min_ms"], 48 | min(self._config["guidance"]["max_ms"], last_note_onset_ms), 49 | ) 50 | guidance_start_ms = random.randint( 51 | first_note_onset_ms, 52 | last_note_onset_ms - guidance_segment_length_ms, 53 | ) 54 | guidance_end_ms = guidance_start_ms + guidance_segment_length_ms 55 | 56 | return guidance_start_ms, 
guidance_end_ms 57 | 58 | def _get_guidance_seq( 59 | self, 60 | guidance_midi_dict: MidiDict, 61 | guidance_start_ms: int | None = None, 62 | guidance_end_ms: int | None = None, 63 | ): 64 | assert guidance_midi_dict.note_msgs is not None 65 | 66 | # Need to validate these numbers 67 | if guidance_start_ms is None: 68 | assert guidance_end_ms is None 69 | guidance_start_ms, guidance_end_ms = self._get_guidance_interval_ms( 70 | guidance_midi_dict=guidance_midi_dict 71 | ) 72 | 73 | slice_note_msgs = [] 74 | for note_msg in guidance_midi_dict.note_msgs: 75 | start_ms = guidance_midi_dict.tick_to_ms(note_msg["data"]["start"]) 76 | if guidance_start_ms <= start_ms <= guidance_end_ms: 77 | slice_note_msgs.append(note_msg) 78 | 79 | slice_midi_dict = copy.deepcopy(guidance_midi_dict) 80 | slice_midi_dict.note_msgs = slice_note_msgs 81 | 82 | if len(slice_midi_dict.note_msgs) == 0: 83 | # Catches not note in interval 84 | return [] 85 | 86 | guidance_seq = self._tokenize_midi_dict( 87 | midi_dict=slice_midi_dict, 88 | remove_preceding_silence=True, 89 | ) 90 | 91 | if self.dim_tok in guidance_seq: 92 | guidance_seq.remove(self.dim_tok) 93 | 94 | guidance_seq = guidance_seq[ 95 | guidance_seq.index(self.bos_tok) 96 | + 1 : guidance_seq.index(self.eos_tok) 97 | ] 98 | 99 | return ( 100 | [self.guidance_start_tok] + guidance_seq + [self.guidance_end_tok] 101 | ) 102 | 103 | def _add_prompt_tokens( 104 | self, seq: list, prompt_start_ms: int, prompt_end_ms: int 105 | ): 106 | res = copy.deepcopy(seq) 107 | prompt_tok_inserted = False 108 | time_tok_cnt = 0 109 | curr_time_ms = 0 110 | for idx, (tok_1, tok_2) in enumerate(zip(seq, seq[1:])): 111 | if tok_1 == self.time_tok: 112 | time_tok_cnt += 1 113 | elif isinstance(tok_1, tuple) and tok_1[0] in self.instruments_wd: 114 | assert isinstance(tok_2, tuple) and tok_2[0] == "onset" 115 | 116 | # Adjust time 117 | curr_time_ms = (self.abs_time_step_ms * time_tok_cnt) + tok_2[1] 118 | 119 | if ( 120 | curr_time_ms >= prompt_start_ms 121 | and prompt_tok_inserted == False 122 | ): 123 | res.insert(idx, self.prompt_start_tok) 124 | prompt_tok_inserted = True 125 | elif ( 126 | curr_time_ms > prompt_end_ms and prompt_tok_inserted == True 127 | ): 128 | res.insert(idx + 1, self.prompt_end_tok) 129 | break 130 | 131 | return res 132 | 133 | def tokenize( 134 | self, 135 | midi_dict: MidiDict, 136 | prompt_intervals_ms: list[tuple[int, int]], 137 | guidance_midi_dict: MidiDict | None = None, 138 | guidance_start_ms: int | None = None, 139 | guidance_end_ms: int | None = None, 140 | ): 141 | seq = self._tokenize_midi_dict( 142 | midi_dict=midi_dict, remove_preceding_silence=True 143 | ) 144 | first_note_ms = midi_dict.tick_to_ms( 145 | midi_dict.note_msgs[0]["data"]["start"] 146 | ) 147 | 148 | for prompt_start_ms, prompt_end_ms in prompt_intervals_ms: 149 | if prompt_end_ms > first_note_ms: 150 | seq = self._add_prompt_tokens( 151 | seq, 152 | prompt_start_ms=prompt_start_ms - first_note_ms, 153 | prompt_end_ms=prompt_end_ms - first_note_ms, 154 | ) 155 | 156 | if guidance_midi_dict is not None: 157 | guidance_seq = self._get_guidance_seq( 158 | guidance_midi_dict=guidance_midi_dict, 159 | guidance_start_ms=guidance_start_ms, 160 | guidance_end_ms=guidance_end_ms, 161 | ) 162 | else: 163 | guidance_seq = [] 164 | 165 | return guidance_seq + seq 166 | 167 | def detokenize(self, tokenized_seq: list, **kwargs): 168 | if self.guidance_end_tok in tokenized_seq: 169 | seq = tokenized_seq[tokenized_seq.index(self.guidance_end_tok) :] 170 | else: 171 | seq = 
tokenized_seq 172 | 173 | return super()._detokenize_midi_dict(seq, **kwargs) 174 | 175 | def export_data_aug(self): 176 | return [ 177 | self.export_guidance_tempo_aug(max_tempo_aug=0.2, mixup=True), 178 | self.export_guidance_pitch_aug(3), 179 | self.export_guidance_velocity_aug(2), 180 | ] 181 | 182 | def export_guidance_aug_fn(self, aug_fn): 183 | """Transforms augmentation function to only apply to guidance seq""" 184 | 185 | def _guidance_seq_aug_fn( 186 | src: list, 187 | _aug_fn: Callable, 188 | pad_tok: str, 189 | **kwargs, 190 | ) -> list: 191 | 192 | initial_seq_len = len(src) 193 | if self.guidance_start_tok in src and self.guidance_end_tok in src: 194 | guidance_seq = src[ 195 | src.index(self.guidance_start_tok) 196 | + 1 : src.index(self.guidance_end_tok) 197 | ] 198 | seq = src[src.index(self.guidance_end_tok) + 1 :] 199 | 200 | if len(guidance_seq) == 0: 201 | return src 202 | else: 203 | return src 204 | 205 | augmented_guidance_seq = _aug_fn(guidance_seq) 206 | res = ( 207 | [self.guidance_start_tok] 208 | + augmented_guidance_seq 209 | + [self.guidance_end_tok] 210 | + seq 211 | ) 212 | 213 | # Pad or truncate to original sequence length as necessary 214 | res = res[:initial_seq_len] 215 | res += [pad_tok] * (initial_seq_len - len(res)) 216 | 217 | return res 218 | 219 | return functools.partial( 220 | _guidance_seq_aug_fn, 221 | _aug_fn=aug_fn, 222 | pad_tok=self.pad_tok, 223 | ) 224 | 225 | def export_guidance_pitch_aug(self, max_pitch_aug: int): 226 | """Apply pitch augmentation to the guidance sequence""" 227 | 228 | return self.export_guidance_aug_fn( 229 | self.export_pitch_aug(max_pitch_aug=max_pitch_aug) 230 | ) 231 | 232 | def export_guidance_velocity_aug(self, max_num_aug_steps: int): 233 | """Apply velocity augmentation to the guidance sequence""" 234 | 235 | return self.export_guidance_aug_fn( 236 | self.export_velocity_aug(max_num_aug_steps=max_num_aug_steps) 237 | ) 238 | 239 | def export_guidance_tempo_aug(self, max_tempo_aug: int, mixup: bool): 240 | """Apply tempo augmentation to the guidance sequence""" 241 | 242 | return self.export_guidance_aug_fn( 243 | self.export_tempo_aug(max_tempo_aug=max_tempo_aug, mixup=mixup) 244 | ) 245 | 246 | def split(self, seq: list, seq_len: int): 247 | def _process_chunk(_chunk: list): 248 | # Ensure first token is note token 249 | while True: 250 | if _chunk[0] == self.bos_tok: 251 | break 252 | elif ( 253 | isinstance(_chunk[0], tuple) 254 | and _chunk[0][0] in self.instruments_wd 255 | ): 256 | break 257 | else: 258 | _chunk.pop(0) 259 | 260 | # Insert prompt_start_tok if it is missing (but required) 261 | for idx in range(len(_chunk)): 262 | tok = _chunk[idx] 263 | 264 | if tok == self.prompt_start_tok: 265 | break 266 | elif tok == self.prompt_end_tok: 267 | if _chunk[0] == self.bos_tok: 268 | _chunk.insert(1, self.prompt_start_tok) 269 | else: 270 | _chunk.insert(0, self.prompt_start_tok) 271 | break 272 | 273 | return _chunk 274 | 275 | guidance = [] 276 | if self.guidance_start_tok in seq: 277 | guidance_start = seq.index(self.guidance_start_tok) 278 | guidance_end = seq.index(self.guidance_end_tok) 279 | guidance = seq[guidance_start : guidance_end + 1] 280 | seq = seq[guidance_end + 1 :] 281 | 282 | prefix = [] 283 | while seq: 284 | tok = seq[0] 285 | if tok != self.bos_tok and tok[0] == "prefix": 286 | prefix.append(seq.pop(0)) 287 | else: 288 | break 289 | 290 | chunks = [ 291 | _process_chunk(seq[idx : idx + seq_len]) 292 | for idx in range(0, len(seq) - 100, seq_len) 293 | ] 294 | 295 | res = [] 296 | 
for chunk in chunks: 297 | sub_seq = guidance + prefix + chunk 298 | sub_seq = sub_seq[:seq_len] 299 | sub_seq += [self.pad_tok] * (seq_len - len(sub_seq)) 300 | 301 | res.append(sub_seq) 302 | 303 | return res 304 | -------------------------------------------------------------------------------- /aria/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import csv 4 | import argparse 5 | import logging 6 | import torch 7 | import accelerate 8 | 9 | from torch import nn as nn 10 | from torch.utils.data import DataLoader 11 | 12 | from torch.utils.flop_counter import FlopCounterMode 13 | from triton.testing import do_bench 14 | from accelerate.logging import get_logger 15 | from safetensors.torch import load_file 16 | from logging.handlers import RotatingFileHandler 17 | from tqdm import tqdm 18 | from typing import List 19 | 20 | from aria.config import load_model_config 21 | from aria.model import ModelConfig, TransformerLM 22 | from ariautils.tokenizer import Tokenizer, AbsTokenizer, RelTokenizer 23 | from aria.tokenizer import InferenceAbsTokenizer 24 | from aria.datasets import ( 25 | TrainingDataset, 26 | PretrainingDataset, 27 | FinetuningDataset, 28 | ) 29 | from aria.utils import _load_weight 30 | 31 | torch._dynamo.config.optimize_ddp = False 32 | 33 | 34 | # ----- USAGE ----- 35 | # 36 | # This script is meant to be run using the huggingface accelerate cli, see: 37 | # 38 | # https://huggingface.co/docs/accelerate/basic_tutorials/launch 39 | # https://huggingface.co/docs/accelerate/package_reference/cli 40 | # 41 | # For example usage you could run the pre-training script with: 42 | # 43 | # accelerate launch [arguments] aria/train.py train \ 44 | # small \ 45 | # -train_data data/train \ 46 | # -val_data data/val \ 47 | # -epochs 10 \ 48 | # -bs 32 \ 49 | # -workers 8 50 | # 51 | # You could resume a run from an accelerate checkpoint with: 52 | # 53 | # accelerate launch [arguments] aria/train.py resume \ 54 | # small \ 55 | # -train_data data/train \ 56 | # -val_data data/val \ 57 | # -cp_dir models/epoch5_step0 \ 58 | # -r_step 0 \ 59 | # -r_epoch 5 \ 60 | # -epochs 5 \ 61 | # -bs 32 \ 62 | # -workers 8 63 | 64 | 65 | def setup_logger(project_dir: str): 66 | # Get logger and reset all handlers 67 | logger = logging.getLogger(__name__) 68 | for h in logger.handlers[:]: 69 | logger.removeHandler(h) 70 | 71 | logger.propagate = False 72 | logger.setLevel(logging.DEBUG) 73 | formatter = logging.Formatter( 74 | "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s", 75 | ) 76 | 77 | fh = RotatingFileHandler( 78 | os.path.join(project_dir, "logs.txt"), backupCount=5, maxBytes=1024**3 79 | ) 80 | fh.setLevel(logging.DEBUG) 81 | fh.setFormatter(formatter) 82 | logger.addHandler(fh) 83 | 84 | ch = logging.StreamHandler() 85 | ch.setLevel(logging.INFO) 86 | ch.setFormatter(formatter) 87 | logger.addHandler(ch) 88 | 89 | return get_logger(__name__) # using accelerate.logging.get_logger() 90 | 91 | 92 | def get_tokenizer_name( 93 | train_data_paths: str, 94 | val_data_path: str, 95 | ): 96 | """This will throw an error if there is a tokenizer mismatch""" 97 | train_config = TrainingDataset.get_config_from_path(train_data_paths[0]) 98 | val_config = TrainingDataset.get_config_from_path(val_data_path) 99 | 100 | assert ( 101 | train_config["tokenizer_name"] == val_config["tokenizer_name"] 102 | ), "Dataset tokenizers don't match" 103 | 104 | return train_config["tokenizer_name"] 105 | 106 | 107 | def 
setup_project_dir(project_dir: str | None): 108 | if not project_dir: 109 | # Create project directory 110 | if not os.path.isdir("./experiments"): 111 | os.mkdir("./experiments") 112 | 113 | project_dirs = [ 114 | _dir 115 | for _dir in os.listdir("./experiments") 116 | if os.path.isdir(os.path.join("experiments", _dir)) 117 | ] 118 | 119 | ind = 0 120 | while True: 121 | if str(ind) not in project_dirs: 122 | break 123 | else: 124 | ind += 1 125 | 126 | project_dir_abs = os.path.abspath(os.path.join("experiments", str(ind))) 127 | assert not os.path.isdir(project_dir_abs) 128 | os.mkdir(project_dir_abs) 129 | 130 | elif project_dir: 131 | # Run checks on project directory 132 | if os.path.isdir(project_dir): 133 | assert ( 134 | len(os.listdir(project_dir)) == 0 135 | ), "Provided project directory is not empty" 136 | project_dir_abs = os.path.abspath(project_dir) 137 | elif os.path.isfile(project_dir): 138 | raise FileExistsError( 139 | "The provided path points toward an existing file" 140 | ) 141 | else: 142 | try: 143 | os.mkdir(project_dir) 144 | except Exception as e: 145 | raise e(f"Failed to create project directory at {project_dir}") 146 | project_dir_abs = os.path.abspath(project_dir) 147 | 148 | os.mkdir(os.path.join(project_dir_abs, "checkpoints")) 149 | 150 | return project_dir_abs 151 | 152 | 153 | def _get_optim( 154 | lr: float, 155 | model: nn.Module, 156 | num_epochs: int, 157 | steps_per_epoch: int, 158 | warmup: int = 100, 159 | end_ratio: int = 0.1, 160 | ): 161 | optimizer = torch.optim.AdamW( 162 | model.parameters(), 163 | lr=lr, 164 | weight_decay=0.1, 165 | betas=(0.9, 0.95), 166 | eps=1e-5, 167 | ) 168 | 169 | warmup_lrs = torch.optim.lr_scheduler.LinearLR( 170 | optimizer, 171 | start_factor=0.000001, 172 | end_factor=1, 173 | total_iters=warmup, 174 | ) 175 | linear_decay_lrs = torch.optim.lr_scheduler.LinearLR( 176 | optimizer, 177 | start_factor=1, 178 | end_factor=end_ratio, 179 | total_iters=(num_epochs * steps_per_epoch) - warmup, 180 | ) 181 | 182 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR( 183 | optimizer, 184 | schedulers=[warmup_lrs, linear_decay_lrs], 185 | milestones=[warmup], 186 | ) 187 | 188 | return optimizer, lr_scheduler 189 | 190 | 191 | def get_optim( 192 | model: nn.Module, 193 | num_epochs: int, 194 | steps_per_epoch: int, 195 | ): 196 | LR = 3e-4 197 | END_RATIO = 0.1 198 | WARMUP_STEPS = 200 199 | 200 | return _get_optim( 201 | lr=LR, 202 | model=model, 203 | num_epochs=num_epochs, 204 | steps_per_epoch=steps_per_epoch, 205 | warmup=WARMUP_STEPS, 206 | end_ratio=END_RATIO, 207 | ) 208 | 209 | 210 | def get_dataloaders( 211 | train_data_dirs: List[str], 212 | val_data_dir: str, 213 | tokenizer: Tokenizer, 214 | batch_size: int, 215 | num_workers: int, 216 | init_epoch: int | None = None, 217 | apply_aug: bool = True, 218 | finetune: bool = False, 219 | ): 220 | logger = logging.getLogger(__name__) 221 | if finetune == False: 222 | train_dataset = PretrainingDataset( 223 | dir_paths=train_data_dirs, 224 | tokenizer=tokenizer, 225 | ) 226 | val_dataset = PretrainingDataset( 227 | dir_paths=val_data_dir, 228 | tokenizer=tokenizer, 229 | ) 230 | elif finetune == True: 231 | train_dataset = FinetuningDataset( 232 | dir_paths=train_data_dirs, 233 | tokenizer=tokenizer, 234 | ) 235 | val_dataset = FinetuningDataset( 236 | dir_paths=val_data_dir, 237 | tokenizer=tokenizer, 238 | ) 239 | else: 240 | raise ValueError 241 | 242 | if init_epoch: 243 | train_dataset.init_epoch(idx=init_epoch) 244 | 245 | assert ( 246 | 
len(val_dataset.epoch_files_by_dir[0]) == 1 247 | ), "val-data directory should only contain one epoch" 248 | 249 | if apply_aug: 250 | train_dataset.set_transform(tokenizer.export_data_aug()) 251 | 252 | train_dataloader = DataLoader( 253 | train_dataset, 254 | batch_size=batch_size, 255 | num_workers=num_workers, 256 | shuffle=True, 257 | ) 258 | val_dataloader = DataLoader( 259 | val_dataset, 260 | batch_size=batch_size, 261 | num_workers=num_workers, 262 | shuffle=False, 263 | ) 264 | 265 | return train_dataloader, val_dataloader 266 | 267 | 268 | def _train( 269 | epochs: int, 270 | accelerator: accelerate.Accelerator, 271 | model: TransformerLM, 272 | train_dataloader: DataLoader, 273 | val_dataloader: DataLoader, 274 | optimizer: torch.optim.Optimizer, 275 | scheduler: torch.optim.lr_scheduler.LRScheduler = None, 276 | steps_per_checkpoint: int | None = None, 277 | resume_step: int | None = None, 278 | resume_epoch: int | None = None, 279 | project_dir: str | None = None, 280 | ): 281 | def profile_flops(dataloader: DataLoader): 282 | def _bench(): 283 | for batch in dataloader: 284 | src, tgt = batch # (b_sz, s_len), (b_sz, s_len, v_sz) 285 | logits = model(src) # (b_sz, s_len, v_sz) 286 | logits = logits.transpose(1, 2) 287 | loss = loss_fn(logits, tgt) 288 | 289 | # Backwards step - omit optimizer.step() 290 | accelerator.backward(loss) 291 | optimizer.zero_grad() 292 | break 293 | 294 | logger.info( 295 | f"Model has " 296 | f"{'{:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))} " 297 | "parameters" 298 | ) 299 | 300 | # logger.info("Profiling FLOP") 301 | # flop_counter = FlopCounterMode(display=False) 302 | # _bench() 303 | 304 | # with flop_counter: 305 | # _bench() 306 | # total_flop = sum(flop_counter.get_flop_counts()["Global"].values()) 307 | # logger.info(f"Forwards & backwards FLOP: {total_flop / 1e12} TF") 308 | 309 | def make_checkpoint( 310 | _accelerator: accelerate.Accelerator, _epoch: int, _step: int 311 | ): 312 | if accelerator.is_main_process: 313 | checkpoint_dir = os.path.join( 314 | project_dir, 315 | "checkpoints", 316 | f"epoch{_epoch}_step{_step}", 317 | ) 318 | 319 | logger.info( 320 | f"EPOCH {_epoch}/{epochs + start_epoch}: Saving checkpoint - {checkpoint_dir}" 321 | ) 322 | _accelerator.save_state(checkpoint_dir) 323 | 324 | # This is all slightly messy as train_loop and val_loop make use of the 325 | # variables in the wider scope. Perhaps refactor this at some point. 
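# Descriptive note on the loop below: each step computes a masked cross-entropy
# loss (positions where mask == 0 are dropped before averaging), gathers the
# per-process losses with accelerator.gather for the trailing/average
# statistics, steps the optimizer and scheduler under accelerator.accumulate,
# and optionally saves an accelerate checkpoint every steps_per_checkpoint
# steps.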
326 | def train_loop(dataloader: DataLoader, _epoch: int, _resume_step: int = 0):
327 | loss = torch.tensor([0.0])
328 | avg_train_loss = 0
329 | trailing_loss = 0
330 | loss_buffer = []
331 |
332 | try:
333 | lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0])
334 | except Exception:
335 | # Fall back to the optimizer's current lr if the scheduler can't report one
336 | lr_for_print = "{:.2e}".format(optimizer.param_groups[-1]["lr"])
337 |
338 |
339 | model.train()
340 | for __step, batch in (
341 | pbar := tqdm(
342 | enumerate(dataloader),
343 | total=len(dataloader) + _resume_step,
344 | initial=_resume_step,
345 | leave=False,
346 | )
347 | ):
348 | pbar.set_postfix_str(
349 | f"lr={lr_for_print}, "
350 | f"loss={round(loss.item(), 4)}, "
351 | f"trailing={round(trailing_loss, 4)}"
352 | )
353 |
354 | with accelerator.accumulate(model):
355 | step = __step + _resume_step + 1
356 | src, tgt, mask = batch # (b_sz, s_len), (b_sz, s_len, v_sz)
357 | logits = model(src) # (b_sz, s_len, v_sz)
358 | logits = logits.transpose(
359 | 1, 2
360 | ) # Transpose for CrossEntropyLoss
361 | loss = loss_fn(logits, tgt)
362 |
363 | if mask.sum() == 0:
364 | loss = (loss * 0).sum()
365 | else:
366 | loss = loss * mask
367 | loss = loss[loss != 0.0].mean()
368 |
369 | # Calculate statistics
370 | loss_buffer.append(accelerator.gather(loss).mean(dim=0).item())
371 | trailing_loss = sum(loss_buffer[-TRAILING_LOSS_STEPS:]) / len(
372 | loss_buffer[-TRAILING_LOSS_STEPS:]
373 | )
374 | avg_train_loss = sum(loss_buffer) / len(loss_buffer)
375 |
376 | # Logging
377 | logger.debug(
378 | f"EPOCH {_epoch} STEP {step}: "
379 | f"lr={lr_for_print}, "
380 | f"loss={round(loss.item(), 4)}, "
381 | f"trailing_loss={round(trailing_loss, 4)}, "
382 | f"average_loss={round(avg_train_loss, 4)}"
383 | )
384 |
385 | if accelerator.is_main_process:
386 | loss_writer.writerow([_epoch, step, loss.item()])
387 |
388 | accelerator.backward(loss)
389 | optimizer.step()
390 | optimizer.zero_grad()
391 | if scheduler:
392 | scheduler.step()
393 | lr_for_print = "{:.2e}".format(scheduler.get_last_lr()[0])
394 |
395 | if steps_per_checkpoint:
396 | if step % steps_per_checkpoint == 0:
397 | make_checkpoint(
398 | _accelerator=accelerator,
399 | _epoch=_epoch,
400 | _step=step,
401 | )
402 |
403 | logger.info(
404 | f"EPOCH {_epoch}/{epochs + start_epoch}: Finished training - "
405 | f"average_loss={round(avg_train_loss, 4)}"
406 | )
407 |
408 | return avg_train_loss
409 |
410 | @torch.no_grad()
411 | def val_loop(dataloader, _epoch: int):
412 | loss_buffer = []
413 | model.eval()
414 | for step, batch in (
415 | pbar := tqdm(
416 | enumerate(dataloader),
417 | total=len(dataloader),
418 | leave=False,
419 | )
420 | ):
421 | src, tgt, mask = batch # (b_sz, s_len), (b_sz, s_len, v_sz)
422 | logits = model(src) # (b_sz, s_len, v_sz)
423 | logits = logits.transpose(1, 2) # Transpose for CrossEntropyLoss
424 | loss = loss_fn(logits, tgt)
425 |
426 | if mask.sum() == 0:
427 | loss = (loss * 0).sum()
428 | else:
429 | loss = loss * mask
430 | loss = loss[loss != 0.0].mean()
431 |
432 | # Logging
433 | loss_buffer.append(accelerator.gather(loss).mean(dim=0).item())
434 | avg_val_loss = sum(loss_buffer) / len(loss_buffer)
435 | pbar.set_postfix_str(f"average_loss={round(avg_val_loss, 4)}")
436 |
437 | # EPOCH
438 | logger.info(
439 | f"EPOCH {_epoch}/{epochs + start_epoch}: Finished evaluation - "
440 | f"average_loss={round(avg_val_loss, 4)}"
441 | )
442 |
443 | return avg_val_loss
444 |
445 | if steps_per_checkpoint:
446 | assert (
447 | steps_per_checkpoint > 1
448 | ), "Invalid checkpoint mode 
value (too small)" 449 | 450 | TRAILING_LOSS_STEPS = 200 451 | PAD_ID = train_dataloader.dataset.tokenizer.pad_id 452 | logger = get_logger(__name__) # Accelerate logger 453 | loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID, reduction="none") 454 | profile_flops(dataloader=train_dataloader) 455 | 456 | if accelerator.is_main_process: 457 | loss_csv = open(os.path.join(project_dir, "loss.csv"), "w") 458 | loss_writer = csv.writer(loss_csv) 459 | loss_writer.writerow(["epoch", "step", "loss"]) 460 | epoch_csv = open(os.path.join(project_dir, "epoch.csv"), "w") 461 | epoch_writer = csv.writer(epoch_csv) 462 | epoch_writer.writerow(["epoch", "avg_train_loss", "avg_val_loss"]) 463 | 464 | if resume_epoch is not None: 465 | start_epoch = resume_epoch + 1 466 | else: 467 | start_epoch = 0 468 | 469 | if resume_step is not None: 470 | assert resume_epoch is not None, "Must provide resume epoch" 471 | logger.info( 472 | f"Resuming training from step {resume_step} - logging as EPOCH {resume_epoch}" 473 | ) 474 | skipped_dataloader = accelerator.skip_first_batches( 475 | dataloader=train_dataloader, 476 | num_batches=resume_step, 477 | ) 478 | 479 | avg_train_loss = train_loop( 480 | dataloader=skipped_dataloader, 481 | _epoch=resume_epoch, 482 | _resume_step=resume_step, 483 | ) 484 | avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=resume_epoch) 485 | if accelerator.is_main_process: 486 | epoch_writer.writerow([resume_epoch, avg_train_loss, avg_val_loss]) 487 | epoch_csv.flush() 488 | make_checkpoint( 489 | _accelerator=accelerator, _epoch=start_epoch, _step=0 490 | ) 491 | 492 | for epoch in range(start_epoch, epochs + start_epoch): 493 | train_dataloader.dataset.init_epoch(epoch) 494 | avg_train_loss = train_loop(dataloader=train_dataloader, _epoch=epoch) 495 | avg_val_loss = val_loop(dataloader=val_dataloader, _epoch=epoch) 496 | if accelerator.is_main_process: 497 | epoch_writer.writerow([epoch, avg_train_loss, avg_val_loss]) 498 | epoch_csv.flush() 499 | make_checkpoint(_accelerator=accelerator, _epoch=epoch + 1, _step=0) 500 | 501 | logging.shutdown() 502 | if accelerator.is_main_process: 503 | loss_csv.close() 504 | epoch_csv.close() 505 | 506 | 507 | def resume_train( 508 | model_name: str, 509 | train_data_paths: str, 510 | val_data_path: str, 511 | num_workers: int, 512 | batch_size: int, 513 | grad_acc_steps: int, 514 | epochs: int, 515 | checkpoint_dir: str, 516 | resume_epoch: int, 517 | resume_step: int, 518 | steps_per_checkpoint: int | None = None, 519 | project_dir: str = None, 520 | ): 521 | # Validate inputs 522 | assert 0 < num_workers <= 128, "Too many workers" 523 | assert epochs > 0, "Invalid number of epochs" 524 | assert batch_size > 0, "Invalid batch size" 525 | assert torch.cuda.is_available() is True, "CUDA not available" 526 | assert os.path.isdir(checkpoint_dir), f"No dir at {checkpoint_dir}" 527 | for train_data_path in train_data_paths: 528 | assert os.path.isdir( 529 | train_data_path 530 | ), f"No dir found at {train_data_path}" 531 | assert os.path.isdir(val_data_path), f"No dir found at {val_data_path}" 532 | 533 | tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path) 534 | if tokenizer_name == "abs": 535 | tokenizer = AbsTokenizer() 536 | elif tokenizer_name == "inference_abs": 537 | tokenizer = InferenceAbsTokenizer() 538 | elif tokenizer_name == "rel": 539 | tokenizer = RelTokenizer() 540 | else: 541 | raise Exception("Invalid tokenizer name") 542 | 543 | accelerator = accelerate.Accelerator( 544 | project_dir=project_dir, 
gradient_accumulation_steps=grad_acc_steps 545 | ) 546 | if accelerator.is_main_process: 547 | project_dir = setup_project_dir(project_dir) 548 | logger = setup_logger(project_dir) 549 | 550 | logger = get_logger(__name__) 551 | logger.info(f"Using project directory {project_dir} ") 552 | logger.warning( 553 | "Please insure that the training config and resume step are set " 554 | "correctly, the script does not currently check that this is the case. " 555 | "If the previous checkpoint was saved at step n, then resume_step " 556 | "should be n. If there is a mismatch between the batch size then the " 557 | "script will resume at the wrong step. It is also important that the " 558 | "same distributed setup is used for training." 559 | ) 560 | logger.info( 561 | f"Using training config: " 562 | f"model_name={model_name}, " 563 | f"epochs={epochs}, " 564 | f"batch_size={batch_size}, " 565 | f"grad_acc_steps={grad_acc_steps}, " 566 | f"num_workers={num_workers}, " 567 | f"checkpoint_dir={checkpoint_dir}, " 568 | f"resume_step={resume_step}, " 569 | f"resume_epoch={resume_epoch}" 570 | ) 571 | 572 | if steps_per_checkpoint: 573 | logger.info(f"Creating checkpoints every {steps_per_checkpoint}") 574 | 575 | # Init model 576 | model_config = ModelConfig(**load_model_config(model_name)) 577 | model_config.set_vocab_size(tokenizer.vocab_size) 578 | model = TransformerLM(model_config) 579 | model.compile() 580 | 581 | train_dataloader, val_dataloader = get_dataloaders( 582 | train_data_dirs=train_data_paths, 583 | val_data_dir=val_data_path, 584 | tokenizer=tokenizer, 585 | init_epoch=resume_epoch, 586 | batch_size=batch_size, 587 | num_workers=num_workers, 588 | apply_aug=True, 589 | ) 590 | optimizer, scheduler = get_optim( 591 | model, 592 | num_epochs=epochs, 593 | steps_per_epoch=len(train_dataloader), 594 | ) 595 | 596 | ( 597 | model, 598 | train_dataloader, 599 | val_dataloader, 600 | optimizer, 601 | scheduler, 602 | ) = accelerator.prepare( 603 | model, 604 | train_dataloader, 605 | val_dataloader, 606 | optimizer, 607 | scheduler, 608 | ) 609 | 610 | try: 611 | accelerator.load_state(checkpoint_dir) 612 | except Exception as e: 613 | raise Exception( 614 | f"Failed to load checkpoint: {e}\n" 615 | "This could be due to a mismatch between the tokenizer used " 616 | "to build the pre-training and fine-tuning datasets" 617 | ) 618 | logger.info(f"Loaded checkpoint at {checkpoint_dir}") 619 | logger.info("Starting train job") 620 | 621 | _train( 622 | epochs=epochs, 623 | accelerator=accelerator, 624 | model=model, 625 | train_dataloader=train_dataloader, 626 | val_dataloader=val_dataloader, 627 | optimizer=optimizer, 628 | scheduler=scheduler, 629 | steps_per_checkpoint=steps_per_checkpoint, 630 | resume_step=resume_step, 631 | resume_epoch=resume_epoch, 632 | project_dir=project_dir, 633 | ) 634 | 635 | 636 | def train( 637 | model_name: str, 638 | train_data_paths: List[str], 639 | val_data_path: str, 640 | num_workers: int, 641 | batch_size: int, 642 | grad_acc_steps: int, 643 | epochs: int, 644 | checkpoint_path: str | None = None, 645 | steps_per_checkpoint: int | None = None, 646 | project_dir: str = None, 647 | ): 648 | # Validate inputs 649 | assert 0 < num_workers <= 128, "Too many workers" 650 | assert epochs > 0, "Invalid number of epochs" 651 | assert batch_size > 0, "Invalid batch size" 652 | assert torch.cuda.is_available() is True, "CUDA not available" 653 | for train_data_path in train_data_paths: 654 | assert os.path.isdir( 655 | train_data_path 656 | ), f"No dir found at 
{train_data_path}"
657 | assert os.path.isdir(val_data_path), f"No dir found at {val_data_path}"
658 |
659 | tokenizer_name = get_tokenizer_name(train_data_paths, val_data_path)
660 | if tokenizer_name == "abs":
661 | tokenizer = AbsTokenizer()
662 | elif tokenizer_name == "inference_abs":
663 | tokenizer = InferenceAbsTokenizer()
664 | elif tokenizer_name == "rel":
665 | tokenizer = RelTokenizer()
666 | else:
667 | raise Exception("Invalid tokenizer name")
668 |
669 | accelerator = accelerate.Accelerator(
670 | project_dir=project_dir, gradient_accumulation_steps=grad_acc_steps
671 | )
672 | if accelerator.is_main_process:
673 | project_dir = setup_project_dir(project_dir)
674 | logger = setup_logger(project_dir)
675 |
676 | logger = get_logger(__name__)
677 | logger.info(f"Using project directory {project_dir}")
678 | logger.info(
679 | f"Using training config: "
680 | f"model_name={model_name}, "
681 | + (
682 | f"checkpoint_path={checkpoint_path}, " if checkpoint_path else ""
683 | )
684 | + f"epochs={epochs}, "
685 | f"batch_size={batch_size}, "
686 | f"grad_acc_steps={grad_acc_steps}, "
687 | f"num_workers={num_workers}"
688 | )
689 |
690 | if steps_per_checkpoint:
691 | logger.info(f"Creating checkpoints every {steps_per_checkpoint} steps")
692 |
693 | # Init model
694 | model_config = ModelConfig(**load_model_config(model_name))
695 | model_config.set_vocab_size(tokenizer.vocab_size)
696 | model = TransformerLM(model_config)
697 | model.compile()
698 | logger.info(f"Loaded model with config: {load_model_config(model_name)}")
699 | if checkpoint_path:
700 | try:
701 | model.load_state_dict(_load_weight(checkpoint_path))
702 | except Exception as e:
703 | raise Exception(
704 | f"Failed to load checkpoint: {e}\n"
705 | "This could be due to a mismatch between the tokenizer used "
706 | "to build the pre-training and fine-tuning datasets"
707 | )
708 | logger.info(f"Loaded finetune checkpoint located at: {checkpoint_path}")
709 |
710 | train_dataloader, val_dataloader = get_dataloaders(
711 | train_data_dirs=train_data_paths,
712 | val_data_dir=val_data_path,
713 | tokenizer=tokenizer,
714 | batch_size=batch_size,
715 | num_workers=num_workers,
716 | apply_aug=True,
717 | finetune=True if checkpoint_path is not None else False,
718 | )
719 |
720 | assert (
721 | train_dataloader.dataset.config["max_seq_len"]
722 | == model_config.max_seq_len
723 | )
724 | assert (
725 | val_dataloader.dataset.config["max_seq_len"] == model_config.max_seq_len
726 | )
727 |
728 | optimizer, scheduler = get_optim(
729 | model,
730 | num_epochs=epochs,
731 | steps_per_epoch=len(train_dataloader),
732 | )
733 |
734 | (
735 | model,
736 | train_dataloader,
737 | val_dataloader,
738 | optimizer,
739 | scheduler,
740 | ) = accelerator.prepare(
741 | model,
742 | train_dataloader,
743 | val_dataloader,
744 | optimizer,
745 | scheduler,
746 | )
747 |
748 | logger.info(f"Starting {'finetune' if checkpoint_path else 'pretrain'} job")
749 | _train(
750 | epochs=epochs,
751 | accelerator=accelerator,
752 | model=model,
753 | train_dataloader=train_dataloader,
754 | val_dataloader=val_dataloader,
755 | optimizer=optimizer,
756 | scheduler=scheduler,
757 | steps_per_checkpoint=steps_per_checkpoint,
758 | project_dir=project_dir,
759 | )
760 |
761 |
762 | def convert_cp_from_safetensors(checkpoint_path: str, save_path: str):
763 | d = load_file(checkpoint_path)
764 | key = list(d.keys())[0]
765 | gap = len(key.split(".")[0])
766 | d = {s[gap + 1 :]: v for s, v in d.items()}
767 | torch.save(d, save_path)
768 |
769 |
770 | def 
convert_cp_from_accelerate( 771 | model_name: str, tokenizer_name: str, checkpoint_dir: str, save_path: str 772 | ): 773 | def _load_state_dict(_tokenizer: Tokenizer): 774 | model_config = ModelConfig(**load_model_config(model_name)) 775 | model_config.set_vocab_size(_tokenizer.vocab_size) 776 | model = TransformerLM(model_config) 777 | model = accelerator.prepare(model) 778 | accelerator.load_state(checkpoint_dir) 779 | 780 | return model.state_dict() 781 | 782 | accelerator = accelerate.Accelerator() 783 | 784 | # Try both 785 | if tokenizer_name == "abs": 786 | state_dict = _load_state_dict(_tokenizer=AbsTokenizer()) 787 | elif tokenizer_name == "rel": 788 | state_dict = _load_state_dict(_tokenizer=RelTokenizer()) 789 | else: 790 | print("Invalid choice of tokenizer") 791 | 792 | torch.save(state_dict, save_path) 793 | 794 | 795 | def parse_resume_args(): 796 | argp = argparse.ArgumentParser(prog="python aria/train.py resume") 797 | argp.add_argument("model", help="name of model config file") 798 | argp.add_argument("-train_data", nargs="+", help="path to train dir") 799 | argp.add_argument("-val_data", help="path to val dir") 800 | argp.add_argument("-cp_dir", help="checkpoint dir", type=str, required=True) 801 | argp.add_argument("-r_step", help="resume step", type=int, required=True) 802 | argp.add_argument("-r_epoch", help="resume epoch", type=int, required=True) 803 | argp.add_argument("-epochs", help="train epochs", type=int, required=True) 804 | argp.add_argument("-bs", help="batch size", type=int, default=32) 805 | argp.add_argument( 806 | "-grad_acc_steps", 807 | help="gradient accumulation steps", 808 | type=int, 809 | default=1, 810 | ) 811 | argp.add_argument("-workers", help="number workers", type=int, default=1) 812 | argp.add_argument("-pdir", help="project dir", type=str, required=False) 813 | argp.add_argument( 814 | "-spc", help="steps per checkpoint", type=int, required=False 815 | ) 816 | 817 | return argp.parse_args(sys.argv[2:]) 818 | 819 | 820 | def parse_train_args(): 821 | argp = argparse.ArgumentParser(prog="python aria/train.py train") 822 | argp.add_argument("model", help="name of model config file") 823 | argp.add_argument("-train_data", nargs="+", help="path to train dir") 824 | argp.add_argument("-val_data", help="path to val dir") 825 | argp.add_argument( 826 | "-cp_path", help="path to checkpoint", required=False, default=None 827 | ) 828 | argp.add_argument("-epochs", help="train epochs", type=int, required=True) 829 | argp.add_argument("-bs", help="batch size", type=int, default=32) 830 | argp.add_argument( 831 | "-grad_acc_steps", 832 | help="gradient accumulation steps", 833 | type=int, 834 | default=1, 835 | ) 836 | argp.add_argument("-workers", help="number workers", type=int, default=1) 837 | argp.add_argument("-pdir", help="project dir", type=str, required=False) 838 | argp.add_argument( 839 | "-spc", help="steps per checkpoint", type=int, required=False 840 | ) 841 | 842 | return argp.parse_args(sys.argv[2:]) 843 | 844 | 845 | if __name__ == "__main__": 846 | parser = argparse.ArgumentParser( 847 | usage="python aria/train.py []" 848 | ) 849 | parser.add_argument( 850 | "mode", help="training function", choices=("train", "resume") 851 | ) 852 | 853 | args = parser.parse_args(sys.argv[1:2]) 854 | if not hasattr(args, "mode"): 855 | parser.print_help() 856 | print("Unrecognized command") 857 | exit(1) 858 | elif args.mode == "train": 859 | train_args = parse_train_args() 860 | train( 861 | model_name=train_args.model, 862 | 
train_data_paths=train_args.train_data, 863 | val_data_path=train_args.val_data, 864 | num_workers=train_args.workers, 865 | batch_size=train_args.bs, 866 | grad_acc_steps=train_args.grad_acc_steps, 867 | epochs=train_args.epochs, 868 | checkpoint_path=train_args.cp_path, 869 | steps_per_checkpoint=train_args.spc, 870 | project_dir=train_args.pdir, 871 | ) 872 | elif args.mode == "resume": 873 | resume_args = parse_resume_args() 874 | resume_train( 875 | model_name=resume_args.model, 876 | train_data_paths=resume_args.train_data, 877 | val_data_path=resume_args.val_data, 878 | num_workers=resume_args.workers, 879 | batch_size=resume_args.bs, 880 | grad_acc_steps=resume_args.grad_acc_steps, 881 | epochs=resume_args.epochs, 882 | checkpoint_dir=resume_args.cp_dir, 883 | resume_step=resume_args.r_step, 884 | resume_epoch=resume_args.r_epoch, 885 | steps_per_checkpoint=resume_args.spc, 886 | project_dir=resume_args.pdir, 887 | ) 888 | else: 889 | print("Unrecognized command") 890 | parser.print_help() 891 | exit(1) 892 | -------------------------------------------------------------------------------- /aria/utils.py: -------------------------------------------------------------------------------- 1 | """Contains miscellaneous utilities""" 2 | 3 | 4 | def _load_weight(ckpt_path: str, device="cpu"): 5 | if ckpt_path.endswith("safetensors"): 6 | try: 7 | from safetensors.torch import load_file 8 | except ImportError as e: 9 | raise ImportError( 10 | f"Please install safetensors in order to read from the checkpoint: {ckpt_path}" 11 | ) from e 12 | return load_file(ckpt_path, device=device) 13 | else: 14 | import torch 15 | 16 | return torch.load(ckpt_path, map_location=device) 17 | -------------------------------------------------------------------------------- /config/accelerate.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 1 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "tests": { 4 | "note_density_in_interval":{ 5 | "run": false, 6 | "args": { 7 | "test_params_list": 8 | [ 9 | { 10 | "max_notes_per_second": 60, 11 | "max_notes_per_second_per_pitch": 15, 12 | "interval_len_s": 3 13 | }, 14 | { 15 | "max_notes_per_second": 45, 16 | "max_notes_per_second_per_pitch": 10, 17 | "interval_len_s": 5 18 | }, 19 | { 20 | "max_notes_per_second": 40, 21 | "max_notes_per_second_per_pitch": 8, 22 | "interval_len_s": 10 23 | }, 24 | { 25 | "max_notes_per_second": 30, 26 | "max_notes_per_second_per_pitch": 6, 27 | "interval_len_s": 45 28 | } 29 | ] 30 | } 31 | }, 32 | "note_timing_entropy":{ 33 | "run": false, 34 | "args": { 35 | "min_length_entropy": 2.5, 36 | "min_onset_delta_entropy": 0.0 37 | } 38 | }, 39 | "note_pitch_entropy":{ 40 | "run": false, 41 | "args": { 42 | "min_entropy": 3.0 43 | } 44 | }, 45 | "unique_pitch_count_in_interval":{ 46 | "run": false, 47 | "args": { 48 | "test_params_list": 49 | [ 50 | {"min_unique_pitch_cnt": 5, "interval_len_s": 30}, 51 | {"min_unique_pitch_cnt": 8, "interval_len_s": 60}, 52 | 
{"min_unique_pitch_cnt": 10, "interval_len_s": 120} 53 | ] 54 | } 55 | }, 56 | "unique_pitch_count":{ 57 | "run": false, 58 | "args": { 59 | "min_num_unique_pitches": 12 60 | } 61 | }, 62 | "silent_interval":{ 63 | "run": false, 64 | "args": { 65 | "max_silence_s": 20 66 | } 67 | }, 68 | "mean_note_velocity":{ 69 | "run": false, 70 | "args": { 71 | "min_mean_velocity": 20, 72 | "max_mean_velocity": 105 73 | } 74 | }, 75 | "max_programs":{ 76 | "run": false, 77 | "args": { 78 | "max": 12 79 | } 80 | }, 81 | "max_instruments":{ 82 | "run": false, 83 | "args": { 84 | "max": 7 85 | } 86 | }, 87 | "total_note_frequency":{ 88 | "run": false, 89 | "args": { 90 | "min_per_second": 1.5, 91 | "max_per_second": 30 92 | } 93 | }, 94 | "note_frequency_per_instrument":{ 95 | "run": false, 96 | "args": { 97 | "min_per_second": 1.0, 98 | "max_per_second": 25 99 | } 100 | }, 101 | "length":{ 102 | "run": false, 103 | "args": { 104 | "min_length_s": 30, 105 | "max_length_s": 7200 106 | } 107 | }, 108 | "repetitive_content":{ 109 | "run": false, 110 | "args": { 111 | "min_length_m": 20, 112 | "num_chunks": 5, 113 | "kl_tolerance": 0.2 114 | } 115 | } 116 | }, 117 | "pre_processing": { 118 | "remove_instruments": { 119 | "run": true, 120 | "args": { 121 | "piano": false, 122 | "chromatic": true, 123 | "organ": false, 124 | "guitar": false, 125 | "bass": false, 126 | "strings": false, 127 | "ensemble": false, 128 | "brass": false, 129 | "reed": false, 130 | "pipe": false, 131 | "synth_lead": false, 132 | "synth_pad": true, 133 | "synth_effect": true, 134 | "ethnic": true, 135 | "percussive": true, 136 | "sfx": true 137 | } 138 | } 139 | }, 140 | "metadata": { 141 | "functions": { 142 | "aria_midi_json": { 143 | "run": true, 144 | "args": {} 145 | }, 146 | "composer_filename": { 147 | "run": false, 148 | "args": { 149 | "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] 150 | } 151 | }, 152 | "composer_metamsg": { 153 | "run": false, 154 | "args": { 155 | "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] 156 | } 157 | }, 158 | "form_filename": { 159 | "run": false, 160 | "args": { 161 | "form_names": ["sonata", "prelude", "nocturne", "etude", "waltz", "mazurka", "impromptu", "fugue"] 162 | } 163 | }, 164 | "maestro_json": { 165 | "run": false, 166 | "args": { 167 | "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], 168 | "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"] 169 | } 170 | } 171 | }, 172 | "manual": { 173 | "genre": ["classical", "jazz"], 174 | "form": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], 175 | "composer": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"] 176 | } 177 | }, 178 | "finetuning": { 179 | "guidance_prob": 0.5, 180 | "min_noisy_interval_ms": 5000, 181 | "max_noisy_interval_ms": 60000, 182 | "min_clean_interval_ms": 60000, 183 | "max_clean_interval_ms": 200000, 184 | "noising": { 185 | "activation_prob": 0.5, 186 | "remove_notes": { 187 | "activation_prob": 0.25, 188 | "min_ratio": 0.0, 189 | "max_ratio": 0.15 190 | }, 191 | "adjust_velocity": { 192 | "activation_prob": 0.25, 193 | "min_adjust": 1, 194 | "max_adjust": 20 195 
| }, 196 | "adjust_onsets": { 197 | "activation_prob": 0.25, 198 | "min_adjust_s": 0.005, 199 | "max_adjust_s": 0.05, 200 | "max_ratio": 0.0, 201 | "min_ratio": 0.2 202 | }, 203 | "quantize_onsets": { 204 | "activation_prob": 0.05, 205 | "min_quant_s": 0.05, 206 | "max_quant_s": 0.1, 207 | "max_vel_delta": 30 208 | } 209 | } 210 | } 211 | }, 212 | "tokenizer": { 213 | "inference_abs": { 214 | "guidance": { 215 | "min_ms": 5000, 216 | "max_ms": 40000 217 | } 218 | 219 | 220 | } 221 | } 222 | 223 | } 224 | -------------------------------------------------------------------------------- /config/models/large.json: -------------------------------------------------------------------------------- 1 | { 2 | "d_model": 2048, 3 | "n_heads": 32, 4 | "n_layers": 16, 5 | "ff_mult": 4, 6 | "drop_p": 0.0, 7 | "max_seq_len": 8192, 8 | "grad_checkpoint": true 9 | } -------------------------------------------------------------------------------- /config/models/medium.json: -------------------------------------------------------------------------------- 1 | { 2 | "d_model": 1536, 3 | "n_heads": 24, 4 | "n_layers": 16, 5 | "ff_mult": 4, 6 | "drop_p": 0.0, 7 | "max_seq_len": 8192, 8 | "grad_checkpoint": true 9 | } 10 | -------------------------------------------------------------------------------- /models/placeholder.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/models/placeholder.txt -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | black 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ariautils @ git+https://github.com/EleutherAI/aria-utils.git 2 | torch >= 2.3 3 | accelerate 4 | jsonlines 5 | tqdm 6 | safetensors -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | mkdir data 2 | gsutil cp gs://gpt-aria/train_data/train.jsonl data/train.jsonl 3 | gsutil cp gs://gpt-aria/train_data/val.jsonl data/val.jsonl -------------------------------------------------------------------------------- /scripts/midi_to_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from aria.utils import midi_to_audio 4 | 5 | 6 | def main(): 7 | root_dir = "/Users/louis/work/data/mid/prompts/survey" 8 | for dirpath, dirnames, filenames in os.walk(root_dir): 9 | for filename in filenames: 10 | if filename.endswith(".mid"): 11 | midi_path = os.path.join(dirpath, filename) 12 | midi_to_audio(midi_path) 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /scripts/upload_data.sh: -------------------------------------------------------------------------------- 1 | gsutil cp data/train.jsonl gs://gpt-aria/train_data/train.jsonl 2 | gsutil cp data/val.jsonl gs://gpt-aria/train_data/val.jsonl -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import 
pkg_resources 4 | from setuptools import find_packages, setup 5 | 6 | setup( 7 | name="aria", 8 | py_modules=["aria"], 9 | version="0.0.1", 10 | description="", 11 | author="", 12 | packages=find_packages() + ["config"], 13 | include_package_data=True, 14 | entry_points={ 15 | "console_scripts": [ 16 | "aria=aria.run:main", 17 | ], 18 | }, 19 | install_requires=[ 20 | str(r) 21 | for r in pkg_resources.parse_requirements( 22 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 23 | ) 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import shutil 4 | import logging 5 | 6 | from aria import tokenizer 7 | from aria.config import load_config 8 | from aria.data import datasets 9 | from aria.data.datasets import _noise_midi_dict 10 | from ariautils.midi import MidiDict 11 | 12 | logger = logging.getLogger(__name__) 13 | if not os.path.isdir("tests/test_results"): 14 | os.makedirs("tests/test_results") 15 | 16 | 17 | def setup_logger(): 18 | logger = logging.getLogger(__name__) 19 | for h in logger.handlers[:]: 20 | logger.removeHandler(h) 21 | logger.propagate = False 22 | logger.setLevel(logging.INFO) 23 | formatter = logging.Formatter( 24 | "[%(asctime)s] tests.test_data: [%(levelname)s] %(message)s" 25 | ) 26 | ch = logging.StreamHandler() 27 | ch.setLevel(logging.INFO) 28 | ch.setFormatter(formatter) 29 | logger.addHandler(ch) 30 | 31 | 32 | def get_short_seq(): 33 | return [ 34 | ("prefix", "instrument", "piano"), 35 | ("prefix", "instrument", "drum"), 36 | ("prefix", "composer", "bach"), 37 | "", 38 | ("piano", 62, 50), 39 | ("dur", 50), 40 | ("wait", 100), 41 | ("drum", 50), 42 | ("piano", 64, 70), 43 | ("dur", 100), 44 | ("wait", 100), 45 | "", 46 | ] 47 | 48 | 49 | class TestMidiDict(unittest.TestCase): 50 | def test_resolve_pedal(self): 51 | midi_dict = MidiDict.from_midi("tests/test_data/maestro.mid") 52 | midi_dict.resolve_pedal() 53 | self.assertListEqual(midi_dict.pedal_msgs, []) 54 | mid = midi_dict.to_midi() 55 | mid.save("tests/test_results/maestro_npedal.mid") 56 | 57 | 58 | class TestMidiDataset(unittest.TestCase): 59 | def test_build(self): 60 | dataset = datasets.MidiDataset.build( 61 | dir="tests/test_data", 62 | recur=False, 63 | ) 64 | 65 | self.assertEqual(len(dataset), 7) 66 | self.assertEqual(type(dataset[0]), MidiDict) 67 | 68 | def test_save_load(self): 69 | dataset = datasets.MidiDataset.build( 70 | dir="tests/test_data", 71 | recur=False, 72 | ) 73 | dataset.save("tests/test_results/mididict_dataset.jsonl") 74 | 75 | dataset_reloaded = datasets.MidiDataset.load( 76 | "tests/test_results/mididict_dataset.jsonl" 77 | ) 78 | self.assertEqual(len(dataset_reloaded), 7) 79 | self.assertEqual(type(dataset[0]), type(dataset_reloaded[0])) 80 | 81 | def test_build_to_file(self): 82 | datasets.MidiDataset.build_to_file( 83 | dir="tests/test_data", 84 | save_path="tests/test_results/mididict_dataset_direct.jsonl", 85 | recur=False, 86 | overwrite=True, 87 | ) 88 | 89 | dataset_reloaded = datasets.MidiDataset.load( 90 | load_path="tests/test_results/mididict_dataset_direct.jsonl", 91 | ) 
92 | self.assertEqual(len(dataset_reloaded), 7) 93 | self.assertEqual(type(dataset_reloaded[0]), MidiDict) 94 | 95 | def test_split_from_file(self): 96 | datasets.MidiDataset.build_to_file( 97 | dir="tests/test_data", 98 | save_path="tests/test_results/mididict_dataset.jsonl", 99 | recur=False, 100 | overwrite=True, 101 | ) 102 | 103 | datasets.MidiDataset.split_from_file( 104 | load_path="tests/test_results/mididict_dataset.jsonl", 105 | train_val_ratio=0.7, 106 | repeatable=True, 107 | overwrite=True, 108 | ) 109 | 110 | self.assertTrue( 111 | os.path.isfile("tests/test_results/mididict_dataset_train.jsonl") 112 | ) 113 | self.assertTrue( 114 | os.path.isfile("tests/test_results/mididict_dataset_val.jsonl") 115 | ) 116 | 117 | def test_data_hash(self): 118 | mid_1 = MidiDict.from_midi("tests/test_data/pop.mid") 119 | mid_2 = MidiDict.from_midi("tests/test_data/pop_copy.mid") 120 | 121 | self.assertEqual(mid_1.calculate_hash(), mid_2.calculate_hash()) 122 | 123 | def test_concat(self): 124 | if ( 125 | os.path.exists("tests/test_results/mididict_dataset_train.jsonl") 126 | and os.path.exists("tests/test_results/mididict_dataset_val.jsonl") 127 | and os.path.exists("tests/test_results/mididict_dataset.jsonl") 128 | ): 129 | datasets.MidiDataset.combine_datasets_from_file( 130 | "tests/test_results/mididict_dataset_train.jsonl", 131 | "tests/test_results/mididict_dataset_val.jsonl", 132 | "tests/test_results/mididict_dataset.jsonl", 133 | output_path="tests/test_results/mididict_dataset_concat.jsonl", 134 | ) 135 | 136 | self.assertAlmostEqual( 137 | len( 138 | datasets.MidiDataset.load( 139 | "tests/test_results/mididict_dataset_concat.jsonl" 140 | ) 141 | ), 142 | len( 143 | datasets.MidiDataset.load( 144 | "tests/test_results/mididict_dataset.jsonl" 145 | ) 146 | ), 147 | ) 148 | 149 | 150 | class TestPretrainingDataset(unittest.TestCase): 151 | def test_build(self): 152 | MAX_SEQ_LEN = 4096 153 | tknzr = tokenizer.AbsTokenizer(return_tensors=False) 154 | mididict_dataset = datasets.MidiDataset.build( 155 | dir="tests/test_data", 156 | recur=False, 157 | ) 158 | mididict_dataset.save("tests/test_results/mididict_dataset.jsonl") 159 | 160 | if os.path.exists("tests/test_results/pretrain_dataset_buff_1"): 161 | shutil.rmtree("tests/test_results/pretrain_dataset_buff_1") 162 | if os.path.exists("tests/test_results/pretrain_dataset_buff_2"): 163 | shutil.rmtree("tests/test_results/pretrain_dataset_buff_2") 164 | 165 | dataset_from_file = datasets.PretrainingDataset.build( 166 | tokenizer=tknzr, 167 | save_dir="tests/test_results/pretrain_dataset_buff_1", 168 | max_seq_len=MAX_SEQ_LEN, 169 | num_epochs=3, 170 | midi_dataset_path="tests/test_results/mididict_dataset.jsonl", 171 | ) 172 | dataset_from_mdset = datasets.PretrainingDataset.build( 173 | tokenizer=tknzr, 174 | save_dir="tests/test_results/pretrain_dataset_buff_2", 175 | max_seq_len=MAX_SEQ_LEN, 176 | num_epochs=2, 177 | midi_dataset=mididict_dataset, 178 | ) 179 | 180 | def test_multiple_paths(self): 181 | MAX_SEQ_LEN = 4096 182 | tknzr = tokenizer.AbsTokenizer(return_tensors=False) 183 | mididict_dataset = datasets.MidiDataset.build( 184 | dir="tests/test_data", 185 | recur=False, 186 | ) 187 | mididict_dataset.save("tests/test_results/mididict_dataset_1.jsonl") 188 | 189 | if os.path.exists("tests/test_results/pretrain_dataset_buff_1"): 190 | shutil.rmtree("tests/test_results/pretrain_dataset_buff_1") 191 | if os.path.exists("tests/test_results/pretrain_dataset_buff_2"): 192 | 
shutil.rmtree("tests/test_results/pretrain_dataset_buff_2") 193 | 194 | datasets.PretrainingDataset.build( 195 | tokenizer=tknzr, 196 | save_dir="tests/test_results/pretrain_dataset_buff_1", 197 | max_seq_len=MAX_SEQ_LEN, 198 | num_epochs=3, 199 | midi_dataset_path="tests/test_results/mididict_dataset.jsonl", 200 | ) 201 | datasets.PretrainingDataset.build( 202 | tokenizer=tknzr, 203 | save_dir="tests/test_results/pretrain_dataset_buff_2", 204 | max_seq_len=MAX_SEQ_LEN, 205 | num_epochs=5, 206 | midi_dataset_path="tests/test_results/mididict_dataset.jsonl", 207 | ) 208 | 209 | dataset = datasets.PretrainingDataset( 210 | dir_paths=[ 211 | "tests/test_results/pretrain_dataset_buff_1", 212 | "tests/test_results/pretrain_dataset_buff_2", 213 | ], 214 | tokenizer=tknzr, 215 | ) 216 | 217 | for epoch in range(11): 218 | for idx, _ in enumerate(dataset): 219 | pass 220 | 221 | print("-------------") 222 | dataset.init_epoch() 223 | 224 | def test_aug(self): 225 | MAX_SEQ_LEN = 512 226 | tknzr = tokenizer.AbsTokenizer(return_tensors=False) 227 | mididict_dataset = datasets.MidiDataset.build( 228 | dir="tests/test_data", 229 | recur=False, 230 | ) 231 | if os.path.exists("tests/test_results/pretrain_dataset_buff"): 232 | shutil.rmtree("tests/test_results/pretrain_dataset_buff") 233 | pretrain_dataset = datasets.PretrainingDataset.build( 234 | tokenizer=tknzr, 235 | save_dir="tests/test_results/pretrain_dataset_buff", 236 | max_seq_len=MAX_SEQ_LEN, 237 | num_epochs=1, 238 | midi_dataset=mididict_dataset, 239 | ) 240 | pretrain_dataset.set_transform(tknzr.export_data_aug()) 241 | for idx, seq in enumerate(tknzr.decode(pretrain_dataset[0][0])): 242 | for _idx, tok in enumerate(seq): 243 | if tok == tknzr.unk_tok: 244 | logger.warning(f"unk_tok seen at seq={idx}, idx={_idx}") 245 | 246 | logger.info(f"data_aug_1: {tknzr.decode(pretrain_dataset[0][0][:50])}") 247 | logger.info(f"data_aug_2: {tknzr.decode(pretrain_dataset[0][0][:50])}") 248 | 249 | 250 | class TestFinetuningDataset(unittest.TestCase): 251 | def test_noise(self): 252 | config = load_config()["data"]["finetuning"]["noising"] 253 | midi_dict = MidiDict.from_midi("tests/test_data/clean/1.mid") 254 | noisy_midi_dict = _noise_midi_dict(midi_dict, config) 255 | noisy_midi = noisy_midi_dict.to_midi() 256 | noisy_midi.save("tests/test_results/noisy.mid") 257 | 258 | def test_build(self): 259 | MAX_SEQ_LEN = 4096 260 | tknzr = tokenizer.SeparatedAbsTokenizer(return_tensors=False) 261 | clean_mididict_dataset = datasets.MidiDataset.build( 262 | dir="tests/test_data/clean", 263 | recur=True, 264 | shuffle=False, 265 | ) 266 | noisy_mididict_dataset = datasets.MidiDataset.build( 267 | dir="tests/test_data/noisy", 268 | recur=True, 269 | shuffle=False, 270 | ) 271 | if os.path.exists("tests/test_results/clean.jsonl"): 272 | os.remove("tests/test_results/clean.jsonl") 273 | if os.path.exists("tests/test_results/noisy.jsonl"): 274 | os.remove("tests/test_results/noisy.jsonl") 275 | clean_mididict_dataset.save("tests/test_results/clean.jsonl") 276 | noisy_mididict_dataset.save("tests/test_results/noisy.jsonl") 277 | 278 | if os.path.exists("tests/test_results/comb"): 279 | shutil.rmtree("tests/test_results/comb") 280 | 281 | finetuning_dataset = datasets.FinetuningDataset.build( 282 | tokenizer=tknzr, 283 | save_dir="tests/test_results/comb", 284 | max_seq_len=MAX_SEQ_LEN, 285 | num_epochs=2, 286 | clean_dataset_path="tests/test_results/clean.jsonl", 287 | noisy_dataset_paths=["tests/test_results/noisy.jsonl"], 288 | ) 289 | 290 | 
finetuning_dataset.init_epoch(0) 291 | for seq, tgt, mask in finetuning_dataset: 292 | tokenized_seq = tknzr.decode(seq) 293 | if ( 294 | tknzr.inst_start_tok in tokenized_seq 295 | and tknzr.bos_tok not in tokenized_seq 296 | ): 297 | detokenized_midi_dict = tknzr.detokenize(tokenized_seq) 298 | res = detokenized_midi_dict.to_midi() 299 | res.save(f"tests/test_results/comb.mid") 300 | break 301 | 302 | 303 | setup_logger() 304 | if __name__ == "__main__": 305 | unittest.main() 306 | -------------------------------------------------------------------------------- /tests/test_data/arabesque.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/arabesque.mid -------------------------------------------------------------------------------- /tests/test_data/bach.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/bach.mid -------------------------------------------------------------------------------- /tests/test_data/basic.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/basic.mid -------------------------------------------------------------------------------- /tests/test_data/beethoven_moonlight.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/beethoven_moonlight.mid -------------------------------------------------------------------------------- /tests/test_data/beethoven_sonata.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/beethoven_sonata.mid -------------------------------------------------------------------------------- /tests/test_data/clean/1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/clean/1.mid -------------------------------------------------------------------------------- /tests/test_data/clean/2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/clean/2.mid -------------------------------------------------------------------------------- /tests/test_data/expressive.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/expressive.mid -------------------------------------------------------------------------------- /tests/test_data/noisy/1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/noisy/1.mid -------------------------------------------------------------------------------- /tests/test_data/noisy/2.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/noisy/2.mid -------------------------------------------------------------------------------- /tests/test_data/pop.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/pop.mid -------------------------------------------------------------------------------- /tests/test_data/pop_copy.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/aria/fedf7630437318e8733eb5ad1dc441cdf9c07b5b/tests/test_data/pop_copy.mid -------------------------------------------------------------------------------- /tests/test_tokenizers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import logging 3 | import os 4 | import time 5 | 6 | from typing import Callable 7 | 8 | from aria import tokenizer 9 | from aria.config import load_config 10 | from ariautils.midi import MidiDict 11 | from aria.data.datasets import _get_combined_mididict, _noise_midi_dict 12 | from aria.utils import midi_to_audio 13 | 14 | 15 | if not os.path.isdir("tests/test_results"): 16 | os.makedirs("tests/test_results") 17 | 18 | 19 | # TODO: Implement with tokenizer functions 20 | def get_short_seq_abs(tknzr: tokenizer.AbsTokenizer): 21 | return [ 22 | ("prefix", "instrument", "piano"), 23 | ("prefix", "instrument", "drum"), 24 | "", 25 | ("piano", 62, tknzr._quantize_velocity(45)), 26 | ("onset", tknzr._quantize_onset(0)), 27 | ("dur", tknzr._quantize_dur(50)), 28 | ("drum", 50), 29 | ("onset", tknzr._quantize_onset(100)), 30 | ("piano", 64, tknzr._quantize_velocity(75)), 31 | ("onset", tknzr._quantize_onset(100)), 32 | ("dur", tknzr._quantize_dur(5000)), 33 | "", 34 | "", 35 | "", 36 | ("piano", 65, tknzr._quantize_velocity(75)), 37 | ("onset", tknzr._quantize_onset(170)), 38 | ("dur", tknzr._quantize_dur(100)), 39 | "", 40 | ("piano", 60, tknzr._quantize_velocity(45)), 41 | ("onset", tknzr._quantize_onset(270)), 42 | ("dur", tknzr._quantize_dur(60)), 43 | "", 44 | ("onset", tknzr._quantize_onset(270)), 45 | ("dur", tknzr._quantize_dur(70)), 46 | ("drum", 50), 47 | ("onset", tknzr._quantize_onset(270)), 48 | "", 49 | ("piano", 80, tknzr._quantize_velocity(45)), 50 | ("onset", tknzr._quantize_onset(270)), 51 | ("dur", tknzr._quantize_dur(80)), 52 | "", 53 | ] 54 | 55 | 56 | def get_concat_seq_abs(tknzr: tokenizer.AbsTokenizer): 57 | return [ 58 | ("onset", tknzr._quantize_onset(270)), 59 | ("dur", tknzr._quantize_dur(60)), 60 | "", 61 | ("onset", tknzr._quantize_onset(270)), 62 | ("dur", tknzr._quantize_dur(70)), 63 | ("drum", 50), 64 | ("onset", tknzr._quantize_onset(270)), 65 | "", 66 | ("piano", 80, tknzr._quantize_velocity(45)), 67 | ("onset", tknzr._quantize_onset(270)), 68 | ("dur", tknzr._quantize_dur(80)), 69 | "", 70 | ("prefix", "instrument", "piano"), 71 | ("prefix", "instrument", "drum"), 72 | "", 73 | ("piano", 62, tknzr._quantize_velocity(45)), 74 | ("onset", tknzr._quantize_onset(0)), 75 | ("dur", tknzr._quantize_dur(50)), 76 | ("drum", 50), 77 | ("onset", tknzr._quantize_onset(100)), 78 | ("piano", 64, tknzr._quantize_velocity(75)), 79 | ("onset", tknzr._quantize_onset(100)), 80 | ("dur", tknzr._quantize_dur(5000)), 81 | "", 82 | "", 83 | "", 84 | ("piano", 65, tknzr._quantize_velocity(75)), 85 | ("onset", tknzr._quantize_onset(170)), 86 | ("dur", 
tknzr._quantize_dur(100)), 87 | "", 88 | ("piano", 60, tknzr._quantize_velocity(45)), 89 | ("onset", tknzr._quantize_onset(270)), 90 | ("dur", tknzr._quantize_dur(60)), 91 | "", 92 | ("onset", tknzr._quantize_onset(270)), 93 | ("dur", tknzr._quantize_dur(70)), 94 | ("drum", 50), 95 | ("onset", tknzr._quantize_onset(270)), 96 | "", 97 | ("piano", 80, tknzr._quantize_velocity(45)), 98 | ("onset", tknzr._quantize_onset(270)), 99 | ("dur", tknzr._quantize_dur(80)), 100 | "", 101 | ("prefix", "instrument", "piano"), 102 | ("prefix", "instrument", "drum"), 103 | "", 104 | ("piano", 62, tknzr._quantize_velocity(45)), 105 | ("onset", tknzr._quantize_onset(0)), 106 | ("dur", tknzr._quantize_dur(50)), 107 | ("drum", 50), 108 | ("onset", tknzr._quantize_onset(100)), 109 | ("piano", 64, tknzr._quantize_velocity(75)), 110 | ("onset", tknzr._quantize_onset(100)), 111 | ("dur", tknzr._quantize_dur(5000)), 112 | "", 113 | "", 114 | ] 115 | 116 | 117 | def get_short_seq_rel(tknzr: tokenizer.RelTokenizer): 118 | return [ 119 | ("prefix", "instrument", "piano"), 120 | ("prefix", "instrument", "drum"), 121 | ("prefix", "composer", "bach"), 122 | "", 123 | ("piano", 62, tknzr._quantize_velocity(50)), 124 | ("dur", tknzr._quantize_time(50)), 125 | ("wait", tknzr._quantize_time(100)), 126 | ("drum", 50), 127 | ("piano", 64, tknzr._quantize_velocity(70)), 128 | ("dur", tknzr._quantize_time(1000000)), 129 | ("wait", tknzr._quantize_time(1000000)), 130 | ("wait", tknzr._quantize_time(1000000)), 131 | ("wait", tknzr._quantize_time(1000000)), 132 | ("wait", tknzr._quantize_time(100)), 133 | ("piano", 65, tknzr._quantize_velocity(70)), 134 | ("dur", tknzr._quantize_time(100)), 135 | ("wait", tknzr._quantize_time(100)), 136 | ("piano", 60, tknzr._quantize_velocity(50)), 137 | ("dur", tknzr._quantize_time(60)), 138 | ("piano", 70, tknzr._quantize_velocity(50)), 139 | ("dur", tknzr._quantize_time(70)), 140 | ("drum", 50), 141 | ("piano", 80, tknzr._quantize_velocity(50)), 142 | ("dur", tknzr._quantize_time(80)), 143 | ("wait", tknzr._quantize_time(100)), 144 | "", 145 | ] 146 | 147 | 148 | def get_concat_seq_rel(tknzr: tokenizer.RelTokenizer): 149 | return [ 150 | ("dur", tknzr._quantize_time(1000000)), 151 | ("wait", tknzr._quantize_time(1000000)), 152 | ("wait", tknzr._quantize_time(1000000)), 153 | ("wait", tknzr._quantize_time(1000000)), 154 | ("wait", tknzr._quantize_time(100)), 155 | ("piano", 65, tknzr._quantize_velocity(70)), 156 | ("dur", tknzr._quantize_time(100)), 157 | ("wait", tknzr._quantize_time(100)), 158 | ("piano", 60, tknzr._quantize_velocity(50)), 159 | ("dur", tknzr._quantize_time(60)), 160 | ("piano", 70, tknzr._quantize_velocity(50)), 161 | ("dur", tknzr._quantize_time(70)), 162 | ("drum", 50), 163 | ("piano", 80, tknzr._quantize_velocity(50)), 164 | ("dur", tknzr._quantize_time(80)), 165 | ("wait", tknzr._quantize_time(100)), 166 | "", 167 | ("prefix", "instrument", "piano"), 168 | ("prefix", "instrument", "drum"), 169 | ("prefix", "composer", "bach"), 170 | "", 171 | ("piano", 62, tknzr._quantize_velocity(50)), 172 | ("dur", tknzr._quantize_time(50)), 173 | ("wait", tknzr._quantize_time(100)), 174 | ("drum", tknzr._quantize_time(50)), 175 | ("piano", 64, tknzr._quantize_velocity(70)), 176 | ("dur", tknzr._quantize_time(1000000)), 177 | ("wait", tknzr._quantize_time(1000000)), 178 | ("wait", tknzr._quantize_time(1000000)), 179 | ("wait", tknzr._quantize_time(1000000)), 180 | ("wait", tknzr._quantize_time(100)), 181 | ("piano", 65, tknzr._quantize_velocity(70)), 182 | ("dur", 
tknzr._quantize_time(100)), 183 | ("wait", tknzr._quantize_time(100)), 184 | ("piano", 60, tknzr._quantize_velocity(50)), 185 | ("dur", tknzr._quantize_time(60)), 186 | ("piano", 70, tknzr._quantize_velocity(50)), 187 | ("dur", tknzr._quantize_time(70)), 188 | ("drum", 50), 189 | ("piano", 80, tknzr._quantize_velocity(50)), 190 | ("dur", tknzr._quantize_time(80)), 191 | ("wait", tknzr._quantize_time(100)), 192 | "", 193 | ("prefix", "instrument", "piano"), 194 | ("prefix", "instrument", "drum"), 195 | ("prefix", "composer", "bach"), 196 | "", 197 | ("piano", 62, tknzr._quantize_velocity(50)), 198 | ("dur", tknzr._quantize_time(50)), 199 | ("wait", tknzr._quantize_time(100)), 200 | ("drum", tknzr._quantize_time(50)), 201 | ("piano", 64, tknzr._quantize_velocity(70)), 202 | ] 203 | 204 | 205 | class TestAbsTokenizer(unittest.TestCase): 206 | def test_tokenize_detokenize_mididict(self): 207 | def tokenize_detokenize(file_name: str): 208 | mid_path = f"tests/test_data/{file_name}" 209 | midi_dict = MidiDict.from_midi(mid_path=mid_path) 210 | tokenized_seq = tknzr.tokenize(midi_dict) 211 | detokenized_midi_dict = tknzr.detokenize(tokenized_seq) 212 | res = detokenized_midi_dict.to_midi() 213 | res.save(f"tests/test_results/{file_name}") 214 | 215 | tknzr = tokenizer.AbsTokenizer(return_tensors=False) 216 | tokenize_detokenize("basic.mid") 217 | tokenize_detokenize("arabesque.mid") 218 | tokenize_detokenize("beethoven_sonata.mid") 219 | tokenize_detokenize("bach.mid") 220 | tokenize_detokenize("expressive.mid") 221 | tokenize_detokenize("pop.mid") 222 | tokenize_detokenize("beethoven_moonlight.mid") 223 | tokenize_detokenize("maestro.mid") 224 | 225 | def test_aug(self): 226 | def tokenize_aug_detokenize( 227 | file_name: str, 228 | aug_fn: Callable, 229 | aug_name: str, 230 | audio=False, 231 | ): 232 | mid_path = f"tests/test_data/{file_name}" 233 | midi_dict = MidiDict.from_midi(mid_path=mid_path) 234 | tokenized_seq = tknzr.tokenize(midi_dict) 235 | tokenized_seq_aug = aug_fn(tokenized_seq) 236 | detokenized_midi_dict = tknzr.detokenize(tokenized_seq_aug) 237 | res = detokenized_midi_dict.to_midi() 238 | save_path = f"tests/test_results/abs_{aug_name}_{file_name}" 239 | res.save(save_path) 240 | if audio is True: 241 | midi_to_audio(save_path) 242 | 243 | tknzr = tokenizer.AbsTokenizer(return_tensors=False) 244 | seq = get_short_seq_abs(tknzr) 245 | seq_concat = get_concat_seq_abs(tknzr) 246 | pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) 247 | velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) 248 | tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5, mixup=True) 249 | 250 | # Pitch augmentation 251 | seq_pitch_augmented = pitch_aug_fn(get_short_seq_abs(tknzr)) 252 | logging.info(f"pitch_aug_fn:\n{seq} ->\n\n{seq_pitch_augmented}\n") 253 | tokenize_aug_detokenize("basic.mid", pitch_aug_fn, "pitch") 254 | tokenize_aug_detokenize("arabesque.mid", pitch_aug_fn, "pitch") 255 | tokenize_aug_detokenize("beethoven_sonata.mid", pitch_aug_fn, "pitch") 256 | tokenize_aug_detokenize("bach.mid", pitch_aug_fn, "pitch") 257 | tokenize_aug_detokenize("expressive.mid", pitch_aug_fn, "pitch") 258 | tokenize_aug_detokenize("pop.mid", pitch_aug_fn, "pitch") 259 | tokenize_aug_detokenize( 260 | "beethoven_moonlight.mid", pitch_aug_fn, "pitch" 261 | ) 262 | 263 | # Velocity augmentation 264 | seq_velocity_augmented = velocity_aug_fn(get_short_seq_abs(tknzr)) 265 | logging.info( 266 | f"velocity_aug_fn:\n{seq} ->\n\n{seq_velocity_augmented}\n" 267 | ) 268 | 
tokenize_aug_detokenize("basic.mid", velocity_aug_fn, "velocity") 269 | tokenize_aug_detokenize("arabesque.mid", velocity_aug_fn, "velocity") 270 | tokenize_aug_detokenize( 271 | "beethoven_sonata.mid", velocity_aug_fn, "velocity" 272 | ) 273 | tokenize_aug_detokenize("bach.mid", velocity_aug_fn, "velocity") 274 | tokenize_aug_detokenize("expressive.mid", velocity_aug_fn, "velocity") 275 | tokenize_aug_detokenize("pop.mid", velocity_aug_fn, "velocity") 276 | tokenize_aug_detokenize( 277 | "beethoven_moonlight.mid", velocity_aug_fn, "velocity" 278 | ) 279 | 280 | # Tempo augmentation 281 | seq_tempo_augmented = tempo_aug_fn(get_short_seq_abs(tknzr)) 282 | logging.info(f"tempo_aug_fn:\n{seq} ->\n\n{seq_tempo_augmented}\n") 283 | 284 | seq_concat_tempo_augmented = tempo_aug_fn(get_concat_seq_abs(tknzr)) 285 | logging.info( 286 | f"tempo_aug_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n" 287 | ) 288 | 289 | tokenize_aug_detokenize("basic.mid", tempo_aug_fn, "tempo") 290 | tokenize_aug_detokenize("arabesque.mid", tempo_aug_fn, "tempo") 291 | tokenize_aug_detokenize("beethoven_sonata.mid", tempo_aug_fn, "tempo") 292 | tokenize_aug_detokenize("bach.mid", tempo_aug_fn, "tempo") 293 | tokenize_aug_detokenize("expressive.mid", tempo_aug_fn, "tempo") 294 | tokenize_aug_detokenize("pop.mid", tempo_aug_fn, "tempo") 295 | tokenize_aug_detokenize( 296 | "beethoven_moonlight.mid", tempo_aug_fn, "tempo" 297 | ) 298 | 299 | def test_aug_time(self): 300 | tknzr = tokenizer.AbsTokenizer() 301 | mid_dict = MidiDict.from_midi("tests/test_data/beethoven_sonata.mid") 302 | tokenized_seq = tknzr.tokenize(mid_dict)[:4096] 303 | pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) 304 | velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) 305 | tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5, mixup=True) 306 | 307 | # Pitch augmentation 308 | t_start = time.perf_counter() 309 | pitch_aug_fn(tokenized_seq) 310 | t_pitch_aug = (time.perf_counter() - t_start) * 1e3 311 | logging.info(f"pitch_aug_fn took {int(t_pitch_aug)}ms") 312 | self.assertLessEqual(t_pitch_aug, 50) 313 | 314 | # Velocity augmentation 315 | t_start = time.perf_counter() 316 | velocity_aug_fn(tokenized_seq) 317 | t_vel_aug = (time.perf_counter() - t_start) * 1e3 318 | logging.info(f"velocity_aug_fn took {int(t_vel_aug)}ms") 319 | self.assertLessEqual(t_vel_aug, 50) 320 | 321 | # Tempo augmentation 322 | t_start = time.perf_counter() 323 | tempo_aug_fn(tokenized_seq) 324 | t_tempo_aug = (time.perf_counter() - t_start) * 1e3 325 | logging.info(f"tempo_aug_fn took {int(t_tempo_aug)}ms") 326 | self.assertLessEqual(t_tempo_aug, 50) 327 | 328 | def test_no_unk_token(self): 329 | def _test_no_unk_token(file_name: str): 330 | mid_path = f"tests/test_data/{file_name}" 331 | midi_dict = MidiDict.from_midi(mid_path=mid_path) 332 | seq = tknzr.tokenize(midi_dict) 333 | enc_dec_seq = tknzr.decode(tknzr.encode(seq)) 334 | for tok in enc_dec_seq: 335 | self.assertTrue(tok != tknzr.unk_tok) 336 | 337 | tknzr = tokenizer.AbsTokenizer() 338 | _test_no_unk_token("basic.mid") 339 | _test_no_unk_token("arabesque.mid") 340 | _test_no_unk_token("bach.mid") 341 | _test_no_unk_token("expressive.mid") 342 | _test_no_unk_token("pop.mid") 343 | _test_no_unk_token("beethoven_moonlight.mid") 344 | 345 | 346 | # TODO: This example is not working, I'm pretty sure the issue is in _get_combined_mididict somewhere 347 | # Fix this!! 
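# Descriptive note on the test below: it builds a combined MidiDict from a
# clean performance and a noised copy (via _noise_midi_dict and
# _get_combined_mididict), round-trips the result through
# SeparatedAbsTokenizer, and writes the detokenized MIDI plus its 4096-token
# splits to tests/test_results for manual inspection. It reads from absolute
# /mnt/ssd1/... paths, so it only runs on a machine where those files exist.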
348 | class TestSeparatedTokenizer(unittest.TestCase):
349 |     def test_tokenize_detokenize_mididict(self):
350 |         def _find_inst_onsets(_seq: list):
351 |             curr_time_ms = 0
352 |             time_toks = 0
353 |             for tok in _seq:
354 |                 if tok == "<T>":  # absolute time token (5 s step)
355 |                     time_toks += 1
356 |                 elif isinstance(tok, tuple) and tok[0] == "onset":
357 |                     curr_time_ms = 5000 * time_toks + tok[1]
358 |                 elif tok == "<SEP>":  # assumed section-separator token
359 |                     logging.info(f"Seen at {curr_time_ms}ms")
360 | 
361 |         tknzr = tokenizer.SeparatedAbsTokenizer()
362 | 
363 |         clean_midi_dict = MidiDict.from_midi(
364 |             mid_path="tests/test_data/clean/1.mid"  # in-repo stand-in for a local MAESTRO file
365 |         )
366 |         noisy_midi_dict = MidiDict.from_midi(
367 |             mid_path="tests/test_data/noisy/1.mid"  # in-repo stand-in for a noisy transcription
368 |             # mid_path="/mnt/ssd1/amt/transcribed_data/noisy_maestro/small-long-e7/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.mid"
369 |         )
370 | 
371 |         noisy_midi_dict = _noise_midi_dict(
372 |             noisy_midi_dict, load_config()["data"]["finetuning"]["noising"]
373 |         )
374 | 
375 |         clean_mid = clean_midi_dict.to_midi()
376 |         clean_mid.save("tests/test_results/combined_clean.mid")
377 |         noisy_mid = noisy_midi_dict.to_midi()
378 |         noisy_mid.save("tests/test_results/combined_noisy.mid")
379 | 
380 |         comb_midi_dict = _get_combined_mididict(
381 |             clean_midi_dict,
382 |             noisy_midi_dict,
383 |             min_noisy_ms=10000,
384 |             max_noisy_ms=25000,
385 |             min_clean_ms=30000,
386 |             max_clean_ms=60000,
387 |         )
388 | 
389 |         comb_midi = comb_midi_dict.to_midi()
390 |         comb_midi.save("tests/test_results/combined_raw.mid")
391 |         tokenized_seq = tknzr.tokenize(comb_midi_dict)
392 |         detokenized_midi_dict = tknzr.detokenize(tokenized_seq)
393 |         res = detokenized_midi_dict.to_midi()
394 |         res.save("tests/test_results/combined.mid")
395 | 
396 |         for idx, sub_seq in enumerate(tknzr.split(tokenized_seq, 4096)):
397 |             if idx == 3:
398 |                 _find_inst_onsets(sub_seq)
399 |             logging.info(idx)
400 |             logging.info(sub_seq)
401 |             detokenized_midi_dict = tknzr.detokenize(sub_seq)
402 |             res = detokenized_midi_dict.to_midi()
403 |             res.save(f"tests/test_results/combined{idx}.mid")
404 | 
405 | 
406 | class TestRelTokenizer(unittest.TestCase):
407 |     def test_tokenize_detokenize_mididict(self):
408 |         def tokenize_detokenize(file_name: str):
409 |             mid_path = f"tests/test_data/{file_name}"
410 |             midi_dict = MidiDict.from_midi(mid_path=mid_path)
411 |             tokenized_seq = tknzr.tokenize(midi_dict)
412 |             detokenized_midi_dict = tknzr.detokenize(tokenized_seq)
413 |             res = detokenized_midi_dict.to_midi()
414 |             res.save(f"tests/test_results/{file_name}")
415 | 
416 |         tknzr = tokenizer.RelTokenizer(return_tensors=False)
417 | 
418 |         tokenize_detokenize("basic.mid")
419 |         tokenize_detokenize("arabesque.mid")
420 |         tokenize_detokenize("beethoven_sonata.mid")
421 |         tokenize_detokenize("bach.mid")
422 |         tokenize_detokenize("expressive.mid")
423 |         tokenize_detokenize("pop.mid")
424 |         tokenize_detokenize("beethoven_moonlight.mid")
425 | 
426 |     def test_aug(self):
427 |         tknzr = tokenizer.RelTokenizer(return_tensors=False)
428 |         seq = get_short_seq_rel(tknzr)
429 |         seq_concat = get_concat_seq_rel(tknzr)
430 |         pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5)
431 |         velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2)
432 |         tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.8)
433 |         chord_mixup_fn = tknzr.export_chord_mixup()
434 | 
435 |         # Pitch augmentation
436 |         seq_pitch_augmented = 
pitch_aug_fn(get_short_seq_rel(tknzr)) 437 | logging.info(f"pitch_aug_fn:\n{seq} ->\n\n{seq_pitch_augmented}\n") 438 | self.assertEqual( 439 | seq_pitch_augmented[4][1] - seq[4][1], 440 | seq_pitch_augmented[8][1] - seq[8][1], 441 | ) 442 | 443 | # Velocity augmentation 444 | seq_velocity_augmented = velocity_aug_fn(get_short_seq_rel(tknzr)) 445 | logging.info( 446 | f"velocity_aug_fn:\n{seq} ->\n\n{seq_velocity_augmented}\n" 447 | ) 448 | self.assertEqual( 449 | seq_velocity_augmented[4][2] - seq[4][2], 450 | seq_velocity_augmented[8][2] - seq[8][2], 451 | ) 452 | 453 | # Tempo augmentation 454 | seq_tempo_augmented = tempo_aug_fn(get_short_seq_rel(tknzr)) 455 | logging.info(f"tempo_aug_fn:\n{seq} ->\n\n{seq_tempo_augmented}\n") 456 | 457 | seq_concat_tempo_augmented = tempo_aug_fn(get_concat_seq_rel(tknzr)) 458 | logging.info( 459 | f"tempo_aug_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n" 460 | ) 461 | 462 | # Chord mix-up augmentation 463 | seq_mixup_augmented = chord_mixup_fn(get_short_seq_rel(tknzr)) 464 | logging.info(f"chord_mixup_fn:\n{seq} ->\n\n{seq_mixup_augmented}\n") 465 | 466 | seq_concat_tempo_augmented = chord_mixup_fn(get_concat_seq_rel(tknzr)) 467 | logging.info( 468 | f"chord_mixup_fn:\n{seq_concat} ->\n\n{seq_concat_tempo_augmented}\n" 469 | ) 470 | 471 | def test_aug_time(self): 472 | tknzr = tokenizer.RelTokenizer() 473 | mid_dict = MidiDict.from_midi("tests/test_data/beethoven_sonata.mid") 474 | tokenized_seq = tknzr.tokenize(mid_dict)[:4096] 475 | 476 | pitch_aug_fn = tknzr.export_pitch_aug(aug_range=5) 477 | velocity_aug_fn = tknzr.export_velocity_aug(aug_steps_range=2) 478 | tempo_aug_fn = tknzr.export_tempo_aug(tempo_aug_range=0.5) 479 | chord_mixup_fn = tknzr.export_chord_mixup() 480 | 481 | # Pitch augmentation 482 | t_start = time.perf_counter() 483 | pitch_aug_fn(tokenized_seq) 484 | t_pitch_aug = (time.perf_counter() - t_start) * 1e3 485 | logging.info(f"pitch_aug_fn took {int(t_pitch_aug)}ms") 486 | self.assertLessEqual(t_pitch_aug, 50) 487 | 488 | # Velocity augmentation 489 | t_start = time.perf_counter() 490 | velocity_aug_fn(tokenized_seq) 491 | t_vel_aug = (time.perf_counter() - t_start) * 1e3 492 | logging.info(f"velocity_aug_fn took {int(t_vel_aug)}ms") 493 | self.assertLessEqual(t_vel_aug, 50) 494 | 495 | # Tempo augmentation 496 | t_start = time.perf_counter() 497 | tempo_aug_fn(tokenized_seq) 498 | t_tempo_aug = (time.perf_counter() - t_start) * 1e3 499 | logging.info(f"tempo_aug_fn took {int(t_tempo_aug)}ms") 500 | self.assertLessEqual(t_tempo_aug, 50) 501 | 502 | # Chord mixup augmentation 503 | t_start = time.perf_counter() 504 | chord_mixup_fn(tokenized_seq) 505 | t_mixup_aug = (time.perf_counter() - t_start) * 1e3 506 | logging.info(f"mixup_aug_fn took {int(t_mixup_aug)}ms") 507 | self.assertLessEqual(t_mixup_aug, 50) 508 | 509 | def test_encode_decode(self): 510 | tknzr = tokenizer.RelTokenizer(return_tensors=True) 511 | seq = get_short_seq_rel(tknzr) 512 | enc_dec_seq = tknzr.decode(tknzr.encode(seq)) 513 | for x, y in zip(seq, enc_dec_seq): 514 | self.assertEqual(x, y) 515 | 516 | tknzr = tokenizer.RelTokenizer(return_tensors=False) 517 | seq = get_short_seq_rel(tknzr) 518 | enc_dec_seq = tknzr.decode(tknzr.encode(seq)) 519 | for x, y in zip(seq, enc_dec_seq): 520 | self.assertEqual(x, y) 521 | 522 | def test_no_unk_token(self): 523 | tknzr = tokenizer.RelTokenizer() 524 | seq = get_short_seq_rel(tknzr) 525 | enc_dec_seq = tknzr.decode(tknzr.encode(seq)) 526 | for tok in enc_dec_seq: 527 | self.assertTrue(tok != tknzr.unk_tok) 
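# Optional helper (a minimal sketch, not used by the tests themselves): render
# every MIDI file written to tests/test_results/ to audio so the round trips
# above can be audited by ear. It reuses the midi_to_audio helper already used
# in TestAbsTokenizer.test_aug, which assumes a working fluidsynth setup (see
# scripts/midi_to_audio.py).
def _render_test_results_to_audio(results_dir: str = "tests/test_results"):
    for name in sorted(os.listdir(results_dir)):
        if name.endswith(".mid"):
            midi_to_audio(os.path.join(results_dir, name))


# Run the suite from the repository root (e.g. `python tests/test_tokenizers.py`)
# so the relative tests/test_data/ paths resolve and tests/test_results/ is
# created before the tests write to it.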
528 | 
529 | 
530 | if __name__ == "__main__":
531 |     if not os.path.isdir("tests/test_results"):
532 |         os.mkdir("tests/test_results")
533 | 
534 |     logging.basicConfig(level=logging.INFO)
535 |     unittest.main()
536 | 
--------------------------------------------------------------------------------