├── .gitignore ├── CODE_OF_CONDUCT.md ├── Data ├── CP_data │ ├── ASAP.npy │ ├── ASAP_CP.npy │ ├── composer_cp_test.npy │ ├── composer_cp_test_ans.npy │ ├── composer_cp_train.npy │ ├── composer_cp_train_ans.npy │ ├── composer_cp_valid.npy │ ├── composer_cp_valid_ans.npy │ ├── composer_test.npy │ ├── composer_test_ans.npy │ ├── composer_train.npy │ ├── composer_train_ans.npy │ ├── composer_valid.npy │ ├── composer_valid_ans.npy │ ├── emopia_cp_test.npy │ ├── emopia_cp_test_ans.npy │ ├── emopia_cp_train.npy │ ├── emopia_cp_train_ans.npy │ ├── emopia_cp_valid.npy │ ├── emopia_cp_valid_ans.npy │ ├── emopia_test.npy │ ├── emopia_test_ans.npy │ ├── emopia_train.npy │ ├── emopia_train_ans.npy │ ├── emopia_valid.npy │ ├── emopia_valid_ans.npy │ ├── pop1k7.npy │ ├── pop909_test.npy │ ├── pop909_test_melans.npy │ ├── pop909_test_velans.npy │ ├── pop909_train.npy │ ├── pop909_train_melans.npy │ ├── pop909_train_velans.npy │ ├── pop909_valid.npy │ ├── pop909_valid_melans.npy │ └── pop909_valid_velans.npy └── Dataset │ ├── ASAP_song.pkl │ ├── emopia.py │ ├── emopia_test.pkl │ ├── emopia_train.pkl │ ├── emopia_valid.pkl │ ├── pianist8.py │ ├── pianist8_test.pkl │ ├── pianist8_train.pkl │ └── pianist8_valid.pkl ├── LICENSE ├── MidiBERT ├── README.md ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── cm_fig.cpython-38.pyc │ ├── finetune_dataset.cpython-38.pyc │ ├── finetune_model.cpython-38.pyc │ ├── finetune_trainer.cpython-38.pyc │ ├── midi_dataset.cpython-38.pyc │ ├── model.cpython-38.pyc │ ├── modelLM.cpython-38.pyc │ └── trainer.cpython-38.pyc ├── cm_fig.py ├── eval.py ├── finetune.py ├── finetune_dataset.py ├── finetune_model.py ├── finetune_trainer.py ├── main.py ├── midi_dataset.py ├── model.py ├── modelLM.py └── trainer.py ├── README.md ├── data_creation ├── README.md ├── prepare_data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── model.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── dict │ │ ├── CP.pkl │ │ ├── make_dict.py │ │ └── remi.pkl │ ├── main.py │ ├── model.py │ └── utils.py └── preprocess_pop909 │ ├── exploratory.py │ ├── preprocess.py │ ├── qual_pieces.pkl │ └── split.pkl ├── melody_extraction ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── midi2CP.cpython-38.pyc │ └── utils.cpython-38.pyc ├── audio │ ├── README.md │ ├── pianist8 │ │ ├── Clayderman_I_Have_A_Dream.mid │ │ ├── Clayderman_I_Like_Chopin.mid │ │ ├── Clayderman_Yesterday_Once_More.mid │ │ ├── Yiruma_Love_Hurts.mid │ │ └── Yiruma_River_Flows_in_You.mid │ ├── pianist8_melody │ │ ├── Clayderman_I_Have_A_Dream_cnn.mid │ │ ├── Clayderman_I_Have_A_Dream_ours.mid │ │ ├── Clayderman_I_Have_A_Dream_skyline.mid │ │ ├── Clayderman_I_Like_Chopin_cnn.mid │ │ ├── Clayderman_I_Like_Chopin_ours.mid │ │ ├── Clayderman_I_Like_Chopin_skyline.mid │ │ ├── Clayderman_Yesterday_Once_More_cnn.mid │ │ ├── Clayderman_Yesterday_Once_More_ours.mid │ │ ├── Clayderman_Yesterday_Once_More_skyline.mid │ │ ├── Yiruma_Love_Hurts_cnn.mid │ │ ├── Yiruma_Love_Hurts_ours.mid │ │ ├── Yiruma_Love_Hurts_skyline.mid │ │ ├── Yiruma_River_Flows_in_You_cnn.mid │ │ ├── Yiruma_River_Flows_in_You_ours.mid │ │ └── Yiruma_River_Flows_in_You_skyline.mid │ ├── pop909 │ │ ├── 018.mid │ │ ├── 067.mid │ │ ├── 395.mid │ │ ├── 596.mid │ │ ├── 828.mid │ │ └── extract.py │ └── pop909_melody │ │ ├── 018_cnn.mid │ │ ├── 018_gt.mid │ │ ├── 018_ours.mid │ │ ├── 018_skyline.mid │ │ ├── 067_cnn.mid │ │ ├── 067_gt.mid │ │ ├── 067_ours.mid │ │ ├── 067_skyline.mid │ │ ├── 395_cnn.mid │ │ ├── 395_gt.mid │ │ ├── 395_ours.mid │ 
│ ├── 395_skyline.mid │ │ ├── 596_cnn.mid │ │ ├── 596_gt.mid │ │ ├── 596_ours.mid │ │ ├── 596_skyline.mid │ │ ├── 828_cnn.mid │ │ ├── 828_gt.mid │ │ ├── 828_ours.mid │ │ └── 828_skyline.mid ├── midibert │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── midi2CP.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── extract.py │ ├── midi2CP.py │ └── utils.py ├── pianoroll │ ├── 018_pianoroll.png │ ├── 596_pianoroll.png │ ├── README.md │ ├── analyzer.py │ ├── plot.py │ └── setting.py └── skyline │ ├── README.md │ ├── __pycache__ │ └── analyzer.cpython-36.pyc │ ├── analyzer.py │ ├── cal_acc.py │ └── test.py ├── requirements.txt ├── resources ├── Adele.mid └── fig │ ├── midibert.png │ └── result.png └── scripts ├── eval.sh ├── finetune.sh ├── melody_extraction.sh ├── prepare_data.sh └── pretrain.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 
50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | CODE_OF_CONDUCT.md. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 
129 | -------------------------------------------------------------------------------- /Data/CP_data/ASAP.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/ASAP.npy -------------------------------------------------------------------------------- /Data/CP_data/ASAP_CP.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/ASAP_CP.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_cp_test.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_cp_test.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_cp_test_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_cp_test_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_cp_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_cp_train.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_cp_train_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_cp_train_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_cp_valid.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_cp_valid.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_cp_valid_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_cp_valid_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_test.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_test.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_test_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_test_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_train.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_train.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_train_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_train_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_valid.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_valid.npy -------------------------------------------------------------------------------- /Data/CP_data/composer_valid_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/composer_valid_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_cp_test.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_cp_test.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_cp_test_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_cp_test_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_cp_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_cp_train.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_cp_train_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_cp_train_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_cp_valid.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_cp_valid.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_cp_valid_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_cp_valid_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_test.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_test.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_test_ans.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_test_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_train.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_train_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_train_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_valid.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_valid.npy -------------------------------------------------------------------------------- /Data/CP_data/emopia_valid_ans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/emopia_valid_ans.npy -------------------------------------------------------------------------------- /Data/CP_data/pop1k7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/pop1k7.npy -------------------------------------------------------------------------------- /Data/CP_data/pop909_test.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/pop909_test.npy -------------------------------------------------------------------------------- /Data/CP_data/pop909_test_melans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/pop909_test_melans.npy -------------------------------------------------------------------------------- /Data/CP_data/pop909_test_velans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/pop909_test_velans.npy -------------------------------------------------------------------------------- /Data/CP_data/pop909_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/pop909_train.npy -------------------------------------------------------------------------------- /Data/CP_data/pop909_train_melans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/pop909_train_melans.npy -------------------------------------------------------------------------------- /Data/CP_data/pop909_train_velans.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/pop909_train_velans.npy -------------------------------------------------------------------------------- /Data/CP_data/pop909_valid.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/pop909_valid.npy -------------------------------------------------------------------------------- /Data/CP_data/pop909_valid_melans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/pop909_valid_melans.npy -------------------------------------------------------------------------------- /Data/CP_data/pop909_valid_velans.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/CP_data/pop909_valid_velans.npy -------------------------------------------------------------------------------- /Data/Dataset/ASAP_song.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/Dataset/ASAP_song.pkl -------------------------------------------------------------------------------- /Data/Dataset/emopia.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import shutil 4 | 5 | root = 'EMOPIA_1.0/' 6 | 7 | def move(files, subset): 8 | for f in files: 9 | piece = f.split('/')[-1] 10 | src = os.path.join(root, 'midis', piece) 11 | shutil.move(src, os.path.join(root, subset, piece)) 12 | 13 | 14 | if __name__ == '__main__': 15 | train = pickle.load(open('emopia_train.pkl','rb')) 16 | valid = pickle.load(open('emopia_valid.pkl','rb')) 17 | test = pickle.load(open('emopia_test.pkl','rb')) 18 | 19 | dest = os.path.join(root, 'train') 20 | os.makedirs(dest, exist_ok=True) 21 | dest = os.path.join(root, 'valid') 22 | os.makedirs(dest, exist_ok=True) 23 | dest = os.path.join(root, 'test') 24 | os.makedirs(dest, exist_ok=True) 25 | 26 | move(train, 'train') 27 | move(valid, 'valid') 28 | move(test, 'test') 29 | -------------------------------------------------------------------------------- /Data/Dataset/emopia_test.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/Dataset/emopia_test.pkl -------------------------------------------------------------------------------- /Data/Dataset/emopia_train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/Dataset/emopia_train.pkl -------------------------------------------------------------------------------- /Data/Dataset/emopia_valid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/Dataset/emopia_valid.pkl -------------------------------------------------------------------------------- /Data/Dataset/pianist8.py: 
-------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import shutil 4 | 5 | root = 'joann8512-Pianist8-ab9f541/' 6 | 7 | def move(files, subset): 8 | for f in files: 9 | composer, piece = f.split('/') 10 | dest = os.path.join(root, subset, composer) 11 | os.makedirs(dest, exist_ok=True) 12 | 13 | src = os.path.join(root, 'midi', f) 14 | shutil.move(src, os.path.join(dest, piece)) 15 | 16 | 17 | if __name__ == '__main__': 18 | train = pickle.load(open('pianist8_train.pkl','rb')) 19 | valid = pickle.load(open('pianist8_valid.pkl','rb')) 20 | test = pickle.load(open('pianist8_test.pkl','rb')) 21 | 22 | move(train, 'train') 23 | move(valid, 'valid') 24 | move(test, 'test') 25 | 26 | -------------------------------------------------------------------------------- /Data/Dataset/pianist8_test.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/Dataset/pianist8_test.pkl -------------------------------------------------------------------------------- /Data/Dataset/pianist8_train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/Dataset/pianist8_train.pkl -------------------------------------------------------------------------------- /Data/Dataset/pianist8_valid.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/Data/Dataset/pianist8_valid.pkl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) MidiBERT-Piano development team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MidiBERT/README.md: -------------------------------------------------------------------------------- 1 | # MidiBERT 2 | 3 | ## 1. Pre-train a MidiBERT-Piano 4 | ### Pre-train (default) 5 | ```./scripts/pretrain.sh``` 6 | You can change the output directory name by specifying `--name={name}`. 7 | 8 | A folder named ```result/pretrain/{name}/``` will be created, with checkpoint & log inside. 
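For reference, below is a minimal sketch of how the saved pre-trained checkpoint can be reloaded into `MidiBert` afterwards. It mirrors what `MidiBERT/finetune.py` does; the dictionary and checkpoint paths are the repository defaults and may need adjusting for your own run.

```python
# Sketch: rebuild MidiBert and load a pre-trained checkpoint (paths are repo defaults).
import pickle
import torch
from transformers import BertConfig
from MidiBERT.model import MidiBert

# token dictionary (event <-> word mappings) used by pre-training and fine-tuning
with open('data_creation/prepare_data/dict/CP.pkl', 'rb') as f:
    e2w, w2e = pickle.load(f)

# same configuration as in MidiBERT/main.py and MidiBERT/finetune.py
configuration = BertConfig(max_position_embeddings=512,
                           position_embedding_type='relative_key_query',
                           hidden_size=768)
midibert = MidiBert(bertConfig=configuration, e2w=e2w, w2e=w2e)

# load the best checkpoint produced by pre-training
checkpoint = torch.load('result/pretrain/default/model_best.ckpt', map_location='cpu')
midibert.load_state_dict(checkpoint['state_dict'])
```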
9 | 10 | ### Customize your own pre-training dataset 11 | 12 | Feel free to select any subset of the provided datasets or add your own. To do this, pass ```--datasets```, and specify the respective path in the ```load_data()``` function. 13 | 14 | For example, 15 | ```bash 16 | # To pre-train a model with only 2 datasets 17 | export PYTHONPATH='.' 18 | python3 main.py --name=default --datasets pop1k7 ASAP 19 | ``` 20 | 21 | Acknowledgement: [HuggingFace](https://github.com/huggingface/transformers), [codertimo/BERT-pytorch](https://github.com/codertimo/BERT-pytorch) 22 | 23 | Special thanks to Chin-Jui Chang 24 | 25 | ## 2. Fine-tune on Downstream Tasks 26 | ```./scripts/finetune.sh``` 27 | 28 | A folder named ```result/finetune/{name}/``` will be created, with checkpoint & log inside. 29 | 30 | ## 3. Evaluation 31 | ```./scripts/eval.sh``` 32 | 33 | ```bash 34 | python3 eval.py --task=melody --cpu --ckpt=[ckpt_path] 35 | ``` 36 | 37 | Test loss & accuracy will be printed, and a figure of the confusion matrix will be saved in the same directory as the checkpoint. 38 | -------------------------------------------------------------------------------- /MidiBERT/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/MidiBERT/__init__.py -------------------------------------------------------------------------------- /MidiBERT/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/MidiBERT/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /MidiBERT/__pycache__/cm_fig.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/MidiBERT/__pycache__/cm_fig.cpython-38.pyc -------------------------------------------------------------------------------- /MidiBERT/__pycache__/finetune_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/MidiBERT/__pycache__/finetune_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /MidiBERT/__pycache__/finetune_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/MidiBERT/__pycache__/finetune_model.cpython-38.pyc -------------------------------------------------------------------------------- /MidiBERT/__pycache__/finetune_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/MidiBERT/__pycache__/finetune_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /MidiBERT/__pycache__/midi_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/MidiBERT/__pycache__/midi_dataset.cpython-38.pyc
-------------------------------------------------------------------------------- /MidiBERT/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/MidiBERT/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /MidiBERT/__pycache__/modelLM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/MidiBERT/__pycache__/modelLM.cpython-38.pyc -------------------------------------------------------------------------------- /MidiBERT/__pycache__/trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/MidiBERT/__pycache__/trainer.cpython-38.pyc -------------------------------------------------------------------------------- /MidiBERT/cm_fig.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import itertools 3 | import numpy as np 4 | 5 | 6 | def save_cm_fig(cm, classes, normalize, title, outdir, seq): 7 | if not seq: 8 | cm = cm[1:,1:] # exclude padding 9 | 10 | if normalize: 11 | cm = cm.astype('float')*100/cm.sum(axis=1)[:,None] 12 | 13 | # print(cm) 14 | plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) 15 | plt.title(title) 16 | tick_marks = np.arange(len(classes)) 17 | plt.xticks(tick_marks, classes, fontsize=20) 18 | plt.yticks(tick_marks, classes, fontsize=20) 19 | 20 | fmt = '.2f' if normalize else 'd' 21 | threshold = cm.max()/2. 22 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 23 | plt.text(j, i, format(cm[i,j],fmt), fontsize=15, 24 | horizontalalignment='center', 25 | color='white' if cm[i,j] > threshold else 'black') 26 | plt.xlabel('predicted', fontsize=18) 27 | plt.ylabel('true', fontsize=18) 28 | plt.tight_layout() 29 | 30 | plt.savefig(f'{outdir}/cm_{title.split()[2]}.jpg') 31 | return 32 | -------------------------------------------------------------------------------- /MidiBERT/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate the model on fine-tuning task (melody, velocity, composer, emotion) 3 | Return loss, accuracy, confusion matrix. 
4 | """ 5 | import argparse 6 | import numpy as np 7 | import random 8 | import pickle 9 | import os 10 | import copy 11 | import shutil 12 | import json 13 | from sklearn.metrics import confusion_matrix 14 | from cm_fig import save_cm_fig 15 | 16 | from torch.utils.data import DataLoader 17 | import torch 18 | import torch.nn as nn 19 | from transformers import BertConfig 20 | 21 | from MidiBERT.model import MidiBert 22 | from MidiBERT.finetune_trainer import FinetuneTrainer 23 | from MidiBERT.finetune_dataset import FinetuneDataset 24 | from MidiBERT.finetune_model import TokenClassification, SequenceClassification 25 | 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser(description='') 29 | 30 | ### mode ### 31 | parser.add_argument('--task', choices=['melody', 'velocity','composer', 'emotion'], required=True) 32 | 33 | ### path setup ### 34 | parser.add_argument('--dict_file', type=str, default='data_creation/prepare_data/dict/CP.pkl') 35 | parser.add_argument('--ckpt', type=str, default='') 36 | 37 | ### parameter setting ### 38 | parser.add_argument('--num_workers', type=int, default=5) 39 | parser.add_argument('--class_num', type=int) 40 | parser.add_argument('--batch_size', type=int, default=12) 41 | parser.add_argument('--max_seq_len', type=int, default=512, help='all sequences are padded to `max_seq_len`') 42 | parser.add_argument('--hs', type=int, default=768) 43 | parser.add_argument("--index_layer", type=int, default=12, help="number of layers") 44 | parser.add_argument('--lr', type=float, default=2e-5, help='initial learning rate') 45 | 46 | ### cuda ### 47 | parser.add_argument('--cpu', action="store_true") # default: false 48 | parser.add_argument("--cuda_devices", type=int, nargs='+', default=[0,1,2,3], help="CUDA device ids") 49 | 50 | args = parser.parse_args() 51 | 52 | root = 'result/finetune/' 53 | 54 | if args.task == 'melody': 55 | args.class_num = 4 56 | args.ckpt = root + 'melody_default/model_best.ckpt' if args.ckpt=='' else args.ckpt 57 | elif args.task == 'velocity': 58 | args.class_num = 7 59 | args.ckpt = root + 'velocity_default/model_best.ckpt' if args.ckpt=='' else args.ckpt 60 | elif args.task == 'composer': 61 | args.class_num = 8 62 | args.ckpt = root + 'composer_default/model_best.ckpt' if args.ckpt=='' else args.ckpt 63 | elif args.task == 'emotion': 64 | args.class_num = 4 65 | args.ckpt = root + 'emotion_default/model_best.ckpt' if args.ckpt=='' else args.ckpt 66 | 67 | return args 68 | 69 | 70 | def load_data(dataset, task): 71 | data_root = 'Data/CP_data' 72 | 73 | if dataset == 'emotion': 74 | dataset = 'emopia' 75 | 76 | if dataset not in ['pop909', 'composer', 'emopia']: 77 | print('dataset {} not supported'.format(dataset)) 78 | exit(1) 79 | 80 | X_train = np.load(os.path.join(data_root, f'{dataset}_train.npy'), allow_pickle=True) 81 | X_val = np.load(os.path.join(data_root, f'{dataset}_valid.npy'), allow_pickle=True) 82 | X_test = np.load(os.path.join(data_root, f'{dataset}_test.npy'), allow_pickle=True) 83 | 84 | print('X_train: {}, X_valid: {}, X_test: {}'.format(X_train.shape, X_val.shape, X_test.shape)) 85 | 86 | if dataset == 'pop909': 87 | y_train = np.load(os.path.join(data_root, f'{dataset}_train_{task[:3]}ans.npy'), allow_pickle=True) 88 | y_val = np.load(os.path.join(data_root, f'{dataset}_valid_{task[:3]}ans.npy'), allow_pickle=True) 89 | y_test = np.load(os.path.join(data_root, f'{dataset}_test_{task[:3]}ans.npy'), allow_pickle=True) 90 | else: 91 | y_train = np.load(os.path.join(data_root, f'{dataset}_train_ans.npy'), 
allow_pickle=True) 92 | y_val = np.load(os.path.join(data_root, f'{dataset}_valid_ans.npy'), allow_pickle=True) 93 | y_test = np.load(os.path.join(data_root, f'{dataset}_test_ans.npy'), allow_pickle=True) 94 | 95 | print('y_train: {}, y_valid: {}, y_test: {}'.format(y_train.shape, y_val.shape, y_test.shape)) 96 | 97 | return X_train, X_val, X_test, y_train, y_val, y_test 98 | 99 | 100 | def conf_mat(_y, output, task, outdir): 101 | if task == 'melody': 102 | target_names = ['M','B','A'] 103 | seq = False 104 | elif task == 'velocity': 105 | target_names = ['pp','p','mp','mf','f','ff'] 106 | seq = False 107 | elif task == 'composer': 108 | target_names = ['M', 'C', 'E','H','W','J','S','Y'] 109 | seq = True 110 | elif task == 'emotion': 111 | target_names = ['HAHV', 'HALV', 'LALV', 'LAHV'] 112 | seq = True 113 | 114 | output = output.detach().cpu().numpy() 115 | output = output.reshape(-1,1) 116 | _y = _y.reshape(-1,1) 117 | 118 | cm = confusion_matrix(_y, output) 119 | 120 | _title = 'BERT (CP): ' + task + ' task' 121 | 122 | save_cm_fig(cm, classes=target_names, normalize=False, 123 | title=_title, outdir=outdir, seq=seq) 124 | 125 | 126 | def main(): 127 | args = get_args() 128 | 129 | print("Loading Dictionary") 130 | with open(args.dict_file, 'rb') as f: 131 | e2w, w2e = pickle.load(f) 132 | 133 | print("\nBuilding BERT model") 134 | configuration = BertConfig(max_position_embeddings=args.max_seq_len, 135 | position_embedding_type='relative_key_query', 136 | hidden_size=args.hs) 137 | 138 | midibert = MidiBert(bertConfig=configuration, e2w=e2w, w2e=w2e) 139 | 140 | print("\nLoading Dataset") 141 | if args.task == 'melody' or args.task == 'velocity': 142 | dataset = 'pop909' 143 | model = TokenClassification(midibert, args.class_num, args.hs) 144 | seq_class = False 145 | elif args.task == 'composer' or args.task == 'emotion': 146 | dataset = args.task 147 | model = SequenceClassification(midibert, args.class_num, args.hs) 148 | seq_class = True 149 | 150 | X_train, X_val, X_test, y_train, y_val, y_test = load_data(dataset, args.task) 151 | 152 | trainset = FinetuneDataset(X=X_train, y=y_train) 153 | validset = FinetuneDataset(X=X_val, y=y_val) 154 | testset = FinetuneDataset(X=X_test, y=y_test) 155 | 156 | train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) 157 | print(" len of train_loader",len(train_loader)) 158 | valid_loader = DataLoader(validset, batch_size=args.batch_size, num_workers=args.num_workers) 159 | print(" len of valid_loader",len(valid_loader)) 160 | test_loader = DataLoader(testset, batch_size=args.batch_size, num_workers=args.num_workers) 161 | print(" len of test_loader",len(test_loader)) 162 | 163 | 164 | print('\nLoad ckpt from', args.ckpt) 165 | best_mdl = args.ckpt 166 | checkpoint = torch.load(best_mdl, map_location='cpu') 167 | model.load_state_dict(checkpoint['state_dict']) 168 | 169 | # remove module 170 | #from collections import OrderedDict 171 | #new_state_dict = OrderedDict() 172 | #for k, v in checkpoint['state_dict'].items(): 173 | # name = k[7:] 174 | # new_state_dict[name] = v 175 | #model.load_state_dict(new_state_dict) 176 | 177 | index_layer = int(args.index_layer)-13 178 | print("\nCreating Finetune Trainer using index layer", index_layer) 179 | trainer = FinetuneTrainer(midibert, train_loader, valid_loader, test_loader, index_layer, args.lr, args.class_num, 180 | args.hs, y_test.shape, args.cpu, args.cuda_devices, model, seq_class) 181 | 182 | 183 | test_loss, test_acc, all_output = 
trainer.test() 184 | print('test loss: {}, test_acc: {}'.format(test_loss, test_acc)) 185 | 186 | outdir = os.path.dirname(args.ckpt) 187 | conf_mat(y_test, all_output, args.task, outdir) 188 | 189 | if __name__ == '__main__': 190 | main() 191 | -------------------------------------------------------------------------------- /MidiBERT/finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pickle 4 | import os 5 | import random 6 | 7 | from torch.utils.data import DataLoader 8 | import torch 9 | from transformers import BertConfig 10 | 11 | from MidiBERT.model import MidiBert 12 | from MidiBERT.finetune_trainer import FinetuneTrainer 13 | from MidiBERT.finetune_dataset import FinetuneDataset 14 | 15 | from matplotlib import pyplot as plt 16 | 17 | def get_args(): 18 | parser = argparse.ArgumentParser(description='') 19 | 20 | ### mode ### 21 | parser.add_argument('--task', choices=['melody', 'velocity', 'composer', 'emotion'], required=True) 22 | ### path setup ### 23 | parser.add_argument('--dict_file', type=str, default='data_creation/prepare_data/dict/CP.pkl') 24 | parser.add_argument('--name', type=str, default='') 25 | parser.add_argument('--ckpt', default='result/pretrain/default/model_best.ckpt') 26 | 27 | ### parameter setting ### 28 | parser.add_argument('--num_workers', type=int, default=5) 29 | parser.add_argument('--class_num', type=int) 30 | parser.add_argument('--batch_size', type=int, default=12) 31 | parser.add_argument('--max_seq_len', type=int, default=512, help='all sequences are padded to `max_seq_len`') 32 | parser.add_argument('--hs', type=int, default=768) 33 | parser.add_argument("--index_layer", type=int, default=12, help="number of layers") 34 | parser.add_argument('--epochs', type=int, default=10, help='number of training epochs') 35 | parser.add_argument('--lr', type=float, default=2e-5, help='initial learning rate') 36 | parser.add_argument('--nopretrain', action="store_true") # default: false 37 | 38 | ### cuda ### 39 | parser.add_argument("--cpu", action="store_true") # default=False 40 | parser.add_argument("--cuda_devices", type=int, nargs='+', default=[0,1,2,3], help="CUDA device ids") 41 | 42 | args = parser.parse_args() 43 | 44 | if args.task == 'melody': 45 | args.class_num = 4 46 | elif args.task == 'velocity': 47 | args.class_num = 7 48 | elif args.task == 'composer': 49 | args.class_num = 8 50 | elif args.task == 'emotion': 51 | args.class_num = 4 52 | 53 | return args 54 | 55 | 56 | def load_data(dataset, task): 57 | data_root = 'Data/CP_data' 58 | 59 | if dataset == 'emotion': 60 | dataset = 'emopia' 61 | 62 | if dataset not in ['pop909', 'composer', 'emopia']: 63 | print(f'Dataset {dataset} not supported') 64 | exit(1) 65 | 66 | X_train = np.load(os.path.join(data_root, f'{dataset}_train.npy'), allow_pickle=True) 67 | X_val = np.load(os.path.join(data_root, f'{dataset}_valid.npy'), allow_pickle=True) 68 | X_test = np.load(os.path.join(data_root, f'{dataset}_test.npy'), allow_pickle=True) 69 | 70 | print('X_train: {}, X_valid: {}, X_test: {}'.format(X_train.shape, X_val.shape, X_test.shape)) 71 | 72 | if dataset == 'pop909': 73 | y_train = np.load(os.path.join(data_root, f'{dataset}_train_{task[:3]}ans.npy'), allow_pickle=True) 74 | y_val = np.load(os.path.join(data_root, f'{dataset}_valid_{task[:3]}ans.npy'), allow_pickle=True) 75 | y_test = np.load(os.path.join(data_root, f'{dataset}_test_{task[:3]}ans.npy'), allow_pickle=True) 76 | else: 77 | y_train = 
np.load(os.path.join(data_root, f'{dataset}_train_ans.npy'), allow_pickle=True) 78 | y_val = np.load(os.path.join(data_root, f'{dataset}_valid_ans.npy'), allow_pickle=True) 79 | y_test = np.load(os.path.join(data_root, f'{dataset}_test_ans.npy'), allow_pickle=True) 80 | 81 | print('y_train: {}, y_valid: {}, y_test: {}'.format(y_train.shape, y_val.shape, y_test.shape)) 82 | 83 | return X_train, X_val, X_test, y_train, y_val, y_test 84 | 85 | 86 | def main(): 87 | # set seed 88 | seed = 2021 89 | torch.manual_seed(seed) # cpu 90 | torch.cuda.manual_seed(seed) # current gpu 91 | torch.cuda.manual_seed_all(seed) # all gpu 92 | np.random.seed(seed) 93 | random.seed(seed) 94 | 95 | # argument 96 | args = get_args() 97 | 98 | print("Loading Dictionary") 99 | with open(args.dict_file, 'rb') as f: 100 | e2w, w2e = pickle.load(f) 101 | 102 | print("\nLoading Dataset") 103 | if args.task == 'melody' or args.task == 'velocity': 104 | dataset = 'pop909' 105 | seq_class = False 106 | elif args.task == 'composer': 107 | dataset = 'composer' 108 | seq_class = True 109 | elif args.task == 'emotion': 110 | dataset = 'emopia' 111 | seq_class = True 112 | X_train, X_val, X_test, y_train, y_val, y_test = load_data(dataset, args.task) 113 | 114 | trainset = FinetuneDataset(X=X_train, y=y_train) 115 | validset = FinetuneDataset(X=X_val, y=y_val) 116 | testset = FinetuneDataset(X=X_test, y=y_test) 117 | 118 | train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) 119 | print(" len of train_loader",len(train_loader)) 120 | valid_loader = DataLoader(validset, batch_size=args.batch_size, num_workers=args.num_workers) 121 | print(" len of valid_loader",len(valid_loader)) 122 | test_loader = DataLoader(testset, batch_size=args.batch_size, num_workers=args.num_workers) 123 | print(" len of valid_loader",len(test_loader)) 124 | 125 | 126 | print("\nBuilding BERT model") 127 | configuration = BertConfig(max_position_embeddings=args.max_seq_len, 128 | position_embedding_type='relative_key_query', 129 | hidden_size=args.hs) 130 | 131 | midibert = MidiBert(bertConfig=configuration, e2w=e2w, w2e=w2e) 132 | best_mdl = '' 133 | if not args.nopretrain: 134 | best_mdl = args.ckpt 135 | print(" Loading pre-trained model from", best_mdl.split('/')[-1]) 136 | checkpoint = torch.load(best_mdl, map_location='cpu') 137 | midibert.load_state_dict(checkpoint['state_dict']) 138 | 139 | index_layer = int(args.index_layer)-13 140 | print("\nCreating Finetune Trainer using index layer", index_layer) 141 | trainer = FinetuneTrainer(midibert, train_loader, valid_loader, test_loader, index_layer, args.lr, args.class_num, 142 | args.hs, y_test.shape, args.cpu, args.cuda_devices, None, seq_class) 143 | 144 | 145 | print("\nTraining Start") 146 | save_dir = os.path.join('result/finetune/', args.task + '_' + args.name) 147 | os.makedirs(save_dir, exist_ok=True) 148 | filename = os.path.join(save_dir, 'model.ckpt') 149 | print(" save model at {}".format(filename)) 150 | 151 | best_acc, best_epoch = 0, 0 152 | bad_cnt = 0 153 | 154 | # train_accs, valid_accs = [], [] 155 | with open(os.path.join(save_dir, 'log'), 'a') as outfile: 156 | outfile.write("Loading pre-trained model from " + best_mdl.split('/')[-1] + '\n') 157 | for epoch in range(args.epochs): 158 | train_loss, train_acc = trainer.train() 159 | valid_loss, valid_acc = trainer.valid() 160 | test_loss, test_acc, _ = trainer.test() 161 | 162 | is_best = valid_acc >= best_acc 163 | best_acc = max(valid_acc, best_acc) 164 | 165 | if is_best: 
166 | bad_cnt, best_epoch = 0, epoch 167 | else: 168 | bad_cnt += 1 169 | 170 | print('epoch: {}/{} | Train Loss: {} | Train acc: {} | Valid Loss: {} | Valid acc: {} | Test loss: {} | Test acc: {}'.format( 171 | epoch+1, args.epochs, train_loss, train_acc, valid_loss, valid_acc, test_loss, test_acc)) 172 | 173 | # train_accs.append(train_acc) 174 | # valid_accs.append(valid_acc) 175 | trainer.save_checkpoint(epoch, train_acc, valid_acc, 176 | valid_loss, train_loss, is_best, filename) 177 | 178 | 179 | outfile.write('Epoch {}: train_loss={}, valid_loss={}, test_loss={}, train_acc={}, valid_acc={}, test_acc={}\n'.format( 180 | epoch+1, train_loss, valid_loss, test_loss, train_acc, valid_acc, test_acc)) 181 | 182 | if bad_cnt > 3: 183 | print('valid acc not improving for 3 epochs') 184 | break 185 | 186 | # draw figure valid_acc & train_acc 187 | '''plt.figure() 188 | plt.plot(train_accs) 189 | plt.plot(valid_accs) 190 | plt.title(f'{args.task} task accuracy (w/o pre-training)') 191 | plt.xlabel('epoch') 192 | plt.ylabel('accuracy') 193 | plt.legend(['train','valid'], loc='upper left') 194 | plt.savefig(f'acc_{args.task}_scratch.jpg')''' 195 | 196 | if __name__ == '__main__': 197 | main() 198 | -------------------------------------------------------------------------------- /MidiBERT/finetune_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | 4 | class FinetuneDataset(Dataset): 5 | """ 6 | Expected data shape: (data_num, data_len) 7 | """ 8 | def __init__(self, X, y): 9 | self.data = X 10 | self.label = y 11 | 12 | def __len__(self): 13 | return(len(self.data)) 14 | 15 | def __getitem__(self, index): 16 | return torch.tensor(self.data[index]), torch.tensor(self.label[index]) 17 | -------------------------------------------------------------------------------- /MidiBERT/finetune_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import random 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from MidiBERT.model import MidiBert 10 | 11 | 12 | class TokenClassification(nn.Module): 13 | def __init__(self, midibert, class_num, hs): 14 | super().__init__() 15 | 16 | self.midibert = midibert 17 | self.classifier = nn.Sequential( 18 | nn.Dropout(0.1), 19 | nn.Linear(hs, 256), 20 | nn.ReLU(), 21 | nn.Linear(256, class_num) 22 | ) 23 | 24 | def forward(self, y, attn, layer): 25 | # feed to bert 26 | y = self.midibert(y, attn, output_hidden_states=True) 27 | #y = y.last_hidden_state # (batch_size, seq_len, 768) 28 | y = y.hidden_states[layer] 29 | return self.classifier(y) 30 | 31 | 32 | class SequenceClassification(nn.Module): 33 | def __init__(self, midibert, class_num, hs, da=128, r=4): 34 | super(SequenceClassification, self).__init__() 35 | self.midibert = midibert 36 | self.attention = SelfAttention(hs, da, r) 37 | self.classifier = nn.Sequential( 38 | nn.Linear(hs*r, 256), 39 | nn.ReLU(), 40 | nn.Linear(256, class_num) 41 | ) 42 | 43 | def forward(self, x, attn, layer): # x: (batch, 512, 4) 44 | x = self.midibert(x, attn, output_hidden_states=True) # (batch, 512, 768) 45 | #y = y.last_hidden_state # (batch_size, seq_len, 768) 46 | x = x.hidden_states[layer] 47 | attn_mat = self.attention(x) # attn_mat: (batch, r, 512) 48 | m = torch.bmm(attn_mat, x) # m: (batch, r, 768) 49 | flatten = m.view(m.size()[0], -1) # flatten: (batch, r*768) 50 | res = self.classifier(flatten) # 
res: (batch, class_num) 51 | return res 52 | 53 | 54 | class SelfAttention(nn.Module): 55 | def __init__(self, input_dim, da, r): 56 | ''' 57 | Args: 58 | input_dim (int): batch, seq, input_dim 59 | da (int): number of features in hidden layer from self-attn 60 | r (int): number of aspects of self-attn 61 | ''' 62 | super(SelfAttention, self).__init__() 63 | self.ws1 = nn.Linear(input_dim, da, bias=False) 64 | self.ws2 = nn.Linear(da, r, bias=False) 65 | 66 | def forward(self, h): 67 | attn_mat = F.softmax(self.ws2(torch.tanh(self.ws1(h))), dim=1) 68 | attn_mat = attn_mat.permute(0,2,1) 69 | return attn_mat 70 | -------------------------------------------------------------------------------- /MidiBERT/finetune_trainer.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import numpy as np 3 | import tqdm 4 | import torch 5 | import torch.nn as nn 6 | from transformers import AdamW 7 | from torch.nn.utils import clip_grad_norm_ 8 | 9 | from MidiBERT.finetune_model import TokenClassification, SequenceClassification 10 | 11 | 12 | class FinetuneTrainer: 13 | def __init__(self, midibert, train_dataloader, valid_dataloader, test_dataloader, layer, 14 | lr, class_num, hs, testset_shape, cpu, cuda_devices=None, model=None, SeqClass=False): 15 | self.device = torch.device("cuda" if torch.cuda.is_available() and not cpu else 'cpu') 16 | print(' device:',self.device) 17 | self.midibert = midibert 18 | self.SeqClass = SeqClass 19 | self.layer = layer 20 | 21 | if model != None: # load model 22 | print('load a fine-tuned model') 23 | self.model = model.to(self.device) 24 | else: 25 | print('init a fine-tune model, sequence-level task?', SeqClass) 26 | if SeqClass: 27 | self.model = SequenceClassification(self.midibert, class_num, hs).to(self.device) 28 | else: 29 | self.model = TokenClassification(self.midibert, class_num, hs).to(self.device) 30 | 31 | # for name, param in self.model.named_parameters(): 32 | # if 'midibert.bert' in name: 33 | # param.requires_grad = False 34 | # print(name, param.requires_grad) 35 | 36 | 37 | if torch.cuda.device_count() > 1 and not cpu: 38 | print("Use %d GPUS" % torch.cuda.device_count()) 39 | self.model = nn.DataParallel(self.model, device_ids=cuda_devices) 40 | 41 | self.train_data = train_dataloader 42 | self.valid_data = valid_dataloader 43 | self.test_data = test_dataloader 44 | 45 | self.optim = AdamW(self.model.parameters(), lr=lr, weight_decay=0.01) 46 | self.loss_func = nn.CrossEntropyLoss(reduction='none') 47 | 48 | self.testset_shape = testset_shape 49 | 50 | def compute_loss(self, predict, target, loss_mask, seq): 51 | loss = self.loss_func(predict, target) 52 | if not seq: 53 | loss = loss * loss_mask 54 | loss = torch.sum(loss) / torch.sum(loss_mask) 55 | else: 56 | loss = torch.sum(loss)/loss.shape[0] 57 | return loss 58 | 59 | 60 | def train(self): 61 | self.model.train() 62 | train_loss, train_acc = self.iteration(self.train_data, 0, self.SeqClass) 63 | return train_loss, train_acc 64 | 65 | def valid(self): 66 | self.model.eval() 67 | valid_loss, valid_acc = self.iteration(self.valid_data, 1, self.SeqClass) 68 | return valid_loss, valid_acc 69 | 70 | def test(self): 71 | self.model.eval() 72 | test_loss, test_acc, all_output = self.iteration(self.test_data, 2, self.SeqClass) 73 | return test_loss, test_acc, all_output 74 | 75 | def iteration(self, training_data, mode, seq): 76 | pbar = tqdm.tqdm(training_data, disable=False) 77 | 78 | total_acc, total_cnt, total_loss = 0, 0, 0 79 | 80 | if mode 
== 2: # testing 81 | all_output = torch.empty(self.testset_shape) 82 | cnt = 0 83 | 84 | for x, y in pbar: # (batch, 512, 768) 85 | batch = x.shape[0] 86 | x, y = x.to(self.device), y.to(self.device) # seq: (batch, 512, 4), (batch) / token: , (batch, 512) 87 | 88 | # avoid attend to pad word 89 | if not seq: 90 | attn = (y != 0).float().to(self.device) # (batch,512) 91 | else: 92 | attn = torch.ones((batch, 512)).to(self.device) # attend each of them 93 | 94 | y_hat = self.model.forward(x, attn, self.layer) # seq: (batch, class_num) / token: (batch, 512, class_num) 95 | 96 | # get the most likely choice with max 97 | output = np.argmax(y_hat.cpu().detach().numpy(), axis=-1) 98 | output = torch.from_numpy(output).to(self.device) 99 | if mode == 2: 100 | all_output[cnt : cnt+batch] = output 101 | cnt += batch 102 | 103 | # accuracy 104 | if not seq: 105 | acc = torch.sum((y == output).float() * attn) 106 | total_acc += acc 107 | total_cnt += torch.sum(attn).item() 108 | else: 109 | acc = torch.sum((y == output).float()) 110 | total_acc += acc 111 | total_cnt += y.shape[0] 112 | 113 | # calculate losses 114 | if not seq: 115 | y_hat = y_hat.permute(0,2,1) 116 | loss = self.compute_loss(y_hat, y, attn, seq) 117 | total_loss += loss.item() 118 | 119 | # udpate only in train 120 | if mode == 0: 121 | self.model.zero_grad() 122 | loss.backward() 123 | self.optim.step() 124 | 125 | if mode == 2: 126 | return round(total_loss/len(training_data),4), round(total_acc.item()/total_cnt,4), all_output 127 | return round(total_loss/len(training_data),4), round(total_acc.item()/total_cnt,4) 128 | 129 | 130 | def save_checkpoint(self, epoch, train_acc, valid_acc, 131 | valid_loss, train_loss, is_best, filename): 132 | state = { 133 | 'epoch': epoch + 1, 134 | 'state_dict': self.model.module.state_dict(), 135 | 'valid_acc': valid_acc, 136 | 'valid_loss': valid_loss, 137 | 'train_loss': train_loss, 138 | 'train_acc': train_acc, 139 | 'optimizer' : self.optim.state_dict() 140 | } 141 | torch.save(state, filename) 142 | 143 | best_mdl = filename.split('.')[0]+'_best.ckpt' 144 | 145 | if is_best: 146 | shutil.copyfile(filename, best_mdl) 147 | 148 | -------------------------------------------------------------------------------- /MidiBERT/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import random 4 | import pickle 5 | import os 6 | import json 7 | 8 | from torch.utils.data import DataLoader 9 | from transformers import BertConfig 10 | from model import MidiBert 11 | from trainer import BERTTrainer 12 | from midi_dataset import MidiDataset 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser(description='') 17 | 18 | ### path setup ### 19 | parser.add_argument('--dict_file', type=str, default='data_creation/prepare_data/dict/CP.pkl') 20 | parser.add_argument('--name', type=str, default='MidiBert') 21 | 22 | ### pre-train dataset ### 23 | parser.add_argument("--datasets", type=str, nargs='+', default=['pop909','composer', 'pop1k7', 'ASAP', 'emopia']) 24 | 25 | ### parameter setting ### 26 | parser.add_argument('--num_workers', type=int, default=5) 27 | parser.add_argument('--batch_size', type=int, default=12) 28 | parser.add_argument('--mask_percent', type=float, default=0.15, help="Up to `valid_seq_len * target_max_percent` tokens will be masked out for prediction") 29 | parser.add_argument('--max_seq_len', type=int, default=512, help='all sequences are padded to `max_seq_len`') 30 | 
parser.add_argument('--hs', type=int, default=768) # hidden state 31 | parser.add_argument('--epochs', type=int, default=500, help='number of training epochs') 32 | parser.add_argument('--lr', type=float, default=2e-5, help='initial learning rate') 33 | 34 | ### cuda ### 35 | parser.add_argument("--cpu", action="store_true") # default: False 36 | parser.add_argument("--cuda_devices", type=int, nargs='+', default=[0,1,2,3], help="CUDA device ids") 37 | 38 | args = parser.parse_args() 39 | 40 | return args 41 | 42 | 43 | def load_data(datasets): 44 | to_concat = [] 45 | root = 'Data/CP_data' 46 | 47 | for dataset in datasets: 48 | if dataset in {'pop909', 'composer', 'emopia'}: 49 | X_train = np.load(os.path.join(root, f'{dataset}_train.npy'), allow_pickle=True) 50 | X_valid = np.load(os.path.join(root, f'{dataset}_valid.npy'), allow_pickle=True) 51 | X_test = np.load(os.path.join(root, f'{dataset}_test.npy'), allow_pickle=True) 52 | data = np.concatenate((X_train, X_valid, X_test), axis=0) 53 | 54 | elif dataset == 'pop1k7' or dataset == 'ASAP': 55 | data = np.load(os.path.join(root, f'{dataset}.npy'), allow_pickle=True) 56 | 57 | print(f' {dataset}: {data.shape}') 58 | to_concat.append(data) 59 | 60 | 61 | training_data = np.vstack(to_concat) 62 | print(' > all training data:', training_data.shape) 63 | 64 | # shuffle during training phase 65 | index = np.arange(len(training_data)) 66 | np.random.shuffle(index) 67 | training_data = training_data[index] 68 | split = int(len(training_data)*0.85) 69 | X_train, X_val = training_data[:split], training_data[split:] 70 | 71 | return X_train, X_val 72 | 73 | 74 | def main(): 75 | args = get_args() 76 | 77 | print("Loading Dictionary") 78 | with open(args.dict_file, 'rb') as f: 79 | e2w, w2e = pickle.load(f) 80 | 81 | print("\nLoading Dataset", args.datasets) 82 | X_train, X_val = load_data(args.datasets) 83 | 84 | trainset = MidiDataset(X=X_train) 85 | validset = MidiDataset(X=X_val) 86 | 87 | train_loader = DataLoader(trainset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) 88 | print(" len of train_loader",len(train_loader)) 89 | valid_loader = DataLoader(validset, batch_size=args.batch_size, num_workers=args.num_workers) 90 | print(" len of valid_loader",len(valid_loader)) 91 | 92 | print("\nBuilding BERT model") 93 | configuration = BertConfig(max_position_embeddings=args.max_seq_len, 94 | position_embedding_type='relative_key_query', 95 | hidden_size=args.hs) 96 | midibert = MidiBert(bertConfig=configuration, e2w=e2w, w2e=w2e) 97 | 98 | print("\nCreating BERT Trainer") 99 | trainer = BERTTrainer(midibert, train_loader, valid_loader, args.lr, args.batch_size, args.max_seq_len, args.mask_percent, args.cpu, args.cuda_devices) 100 | 101 | print("\nTraining Start") 102 | save_dir = 'MidiBERT/result/pretrain/' + args.name 103 | os.makedirs(save_dir, exist_ok=True) 104 | filename = os.path.join(save_dir, 'model.ckpt') 105 | print(" save model at {}".format(filename)) 106 | 107 | best_acc, best_epoch = 0, 0 108 | bad_cnt = 0 109 | 110 | for epoch in range(args.epochs): 111 | if bad_cnt >= 30: 112 | print('valid acc not improving for 30 epochs') 113 | break 114 | train_loss, train_acc = trainer.train() 115 | valid_loss, valid_acc = trainer.valid() 116 | 117 | weighted_score = [x*y for (x,y) in zip(valid_acc, midibert.n_tokens)] 118 | avg_acc = sum(weighted_score)/sum(midibert.n_tokens) 119 | 120 | is_best = avg_acc > best_acc 121 | best_acc = max(avg_acc, best_acc) 122 | 123 | if is_best: 124 | bad_cnt, best_epoch = 0, epoch 
125 | else: 126 | bad_cnt += 1 127 | 128 | print('epoch: {}/{} | Train Loss: {} | Train acc: {} | Valid Loss: {} | Valid acc: {}'.format( 129 | epoch+1, args.epochs, train_loss, train_acc, valid_loss, valid_acc)) 130 | 131 | trainer.save_checkpoint(epoch, best_acc, valid_acc, 132 | valid_loss, train_loss, is_best, filename) 133 | 134 | 135 | with open(os.path.join(save_dir, 'log'), 'a') as outfile: 136 | outfile.write('Epoch {}: train_loss={}, train_acc={}, valid_loss={}, valid_acc={}\n'.format( 137 | epoch+1, train_loss, train_acc, valid_loss, valid_acc)) 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /MidiBERT/midi_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | 4 | class MidiDataset(Dataset): 5 | """ 6 | Expected data shape: (data_num, data_len) 7 | """ 8 | def __init__(self, X): 9 | self.data = X 10 | 11 | def __len__(self): 12 | return(len(self.data)) 13 | 14 | def __getitem__(self, index): 15 | return torch.tensor(self.data[index]) 16 | -------------------------------------------------------------------------------- /MidiBERT/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import random 4 | 5 | import torch 6 | import torch.nn as nn 7 | from transformers import BertModel 8 | 9 | class Embeddings(nn.Module): 10 | def __init__(self, n_token, d_model): 11 | super().__init__() 12 | self.lut = nn.Embedding(n_token, d_model) 13 | self.d_model = d_model 14 | 15 | def forward(self, x): 16 | return self.lut(x) * math.sqrt(self.d_model) 17 | 18 | 19 | # BERT model: similar approach to "felix" 20 | class MidiBert(nn.Module): 21 | def __init__(self, bertConfig, e2w, w2e): 22 | super().__init__() 23 | 24 | self.bert = BertModel(bertConfig) 25 | bertConfig.d_model = bertConfig.hidden_size 26 | self.hidden_size = bertConfig.hidden_size 27 | self.bertConfig = bertConfig 28 | 29 | # token types: [Bar, Position, Pitch, Duration] 30 | self.n_tokens = [] # [3,18,88,66] 31 | self.classes = ['Bar', 'Position', 'Pitch', 'Duration'] 32 | for key in self.classes: 33 | self.n_tokens.append(len(e2w[key])) 34 | self.emb_sizes = [256, 256, 256, 256] 35 | self.e2w = e2w 36 | self.w2e = w2e 37 | 38 | # for deciding whether the current input_ids is a token 39 | self.bar_pad_word = self.e2w['Bar']['Bar '] 40 | self.mask_word_np = np.array([self.e2w[etype]['%s ' % etype] for etype in self.classes], dtype=np.long) 41 | self.pad_word_np = np.array([self.e2w[etype]['%s ' % etype] for etype in self.classes], dtype=np.long) 42 | 43 | # word_emb: embeddings to change token ids into embeddings 44 | self.word_emb = [] 45 | for i, key in enumerate(self.classes): 46 | self.word_emb.append(Embeddings(self.n_tokens[i], self.emb_sizes[i])) 47 | self.word_emb = nn.ModuleList(self.word_emb) 48 | 49 | # linear layer to merge embeddings from different token types 50 | self.in_linear = nn.Linear(np.sum(self.emb_sizes), bertConfig.d_model) 51 | 52 | 53 | def forward(self, input_ids, attn_mask=None, output_hidden_states=True): 54 | # convert input_ids into embeddings and merge them through linear layer 55 | embs = [] 56 | for i, key in enumerate(self.classes): 57 | embs.append(self.word_emb[i](input_ids[..., i])) 58 | embs = torch.cat([*embs], dim=-1) 59 | emb_linear = self.in_linear(embs) 60 | 61 | # feed to bert 62 | y = 
self.bert(inputs_embeds=emb_linear, attention_mask=attn_mask, output_hidden_states=output_hidden_states) 63 | #y = y.last_hidden_state # (batch_size, seq_len, 768) 64 | return y 65 | 66 | def get_rand_tok(self): 67 | c1,c2,c3,c4 = self.n_tokens[0], self.n_tokens[1], self.n_tokens[2], self.n_tokens[3] 68 | return np.array([random.choice(range(c1)),random.choice(range(c2)),random.choice(range(c3)),random.choice(range(c4))]) 69 | -------------------------------------------------------------------------------- /MidiBERT/modelLM.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import random 4 | 5 | import torch 6 | import torch.nn as nn 7 | from transformers import BertModel 8 | 9 | from MidiBERT.model import MidiBert 10 | 11 | 12 | class MidiBertLM(nn.Module): 13 | def __init__(self, midibert: MidiBert): 14 | super().__init__() 15 | 16 | self.midibert = midibert 17 | self.mask_lm = MLM(self.midibert.e2w, self.midibert.n_tokens, self.midibert.hidden_size) 18 | 19 | def forward(self, x, attn): 20 | x = self.midibert(x, attn) 21 | return self.mask_lm(x) 22 | 23 | 24 | class MLM(nn.Module): 25 | def __init__(self, e2w, n_tokens, hidden_size): 26 | super().__init__() 27 | 28 | # proj: project embeddings to logits for prediction 29 | self.proj = [] 30 | for i, etype in enumerate(e2w): 31 | self.proj.append(nn.Linear(hidden_size, n_tokens[i])) 32 | self.proj = nn.ModuleList(self.proj) 33 | 34 | self.e2w = e2w 35 | 36 | def forward(self, y): 37 | # feed to bert 38 | y = y.hidden_states[-1] 39 | 40 | # convert embeddings back to logits for prediction 41 | ys = [] 42 | for i, etype in enumerate(self.e2w): 43 | ys.append(self.proj[i](y)) # (batch_size, seq_len, dict_size) 44 | return ys 45 | 46 | -------------------------------------------------------------------------------- /MidiBERT/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import AdamW 4 | from torch.nn.utils import clip_grad_norm_ 5 | 6 | import numpy as np 7 | import random 8 | import tqdm 9 | import sys 10 | import shutil 11 | import copy 12 | 13 | from MidiBERT.model import MidiBert 14 | from MidiBERT.modelLM import MidiBertLM 15 | 16 | 17 | class BERTTrainer: 18 | def __init__(self, midibert: MidiBert, train_dataloader, valid_dataloader, 19 | lr, batch, max_seq_len, mask_percent, cpu, cuda_devices=None): 20 | self.device = torch.device("cuda" if torch.cuda.is_available() and not cpu else 'cpu') 21 | self.midibert = midibert # save this for ckpt 22 | self.model = MidiBertLM(midibert).to(self.device) 23 | self.total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) 24 | print('# total parameters:', self.total_params) 25 | 26 | if torch.cuda.device_count() > 1 and not cpu: 27 | print("Use %d GPUS" % torch.cuda.device_count()) 28 | self.model = nn.DataParallel(self.model, device_ids=cuda_devices) 29 | 30 | self.train_data = train_dataloader 31 | self.valid_data = valid_dataloader 32 | 33 | self.optim = AdamW(self.model.parameters(), lr=lr, weight_decay=0.01) 34 | self.batch = batch 35 | self.max_seq_len = max_seq_len 36 | self.mask_percent = mask_percent 37 | self.Lseq = [i for i in range(self.max_seq_len)] 38 | self.loss_func = nn.CrossEntropyLoss(reduction='none') 39 | 40 | def compute_loss(self, predict, target, loss_mask): 41 | loss = self.loss_func(predict, target) 42 | loss = loss * loss_mask 43 | loss = torch.sum(loss) / 
torch.sum(loss_mask) 44 | return loss 45 | 46 | def get_mask_ind(self): 47 | mask_ind = random.sample(self.Lseq, round(self.max_seq_len * self.mask_percent)) 48 | mask80 = random.sample(mask_ind, round(len(mask_ind)*0.8)) 49 | left = list(set(mask_ind)-set(mask80)) 50 | rand10 = random.sample(left, round(len(mask_ind)*0.1)) 51 | cur10 = list(set(left)-set(rand10)) 52 | return mask80, rand10, cur10 53 | 54 | 55 | def train(self): 56 | self.model.train() 57 | train_loss, train_acc = self.iteration(self.train_data, self.max_seq_len) 58 | return train_loss, train_acc 59 | 60 | def valid(self): 61 | self.model.eval() 62 | valid_loss, valid_acc = self.iteration(self.valid_data, self.max_seq_len, train=False) 63 | return valid_loss, valid_acc 64 | 65 | def iteration(self, training_data, max_seq_len, train=True): 66 | pbar = tqdm.tqdm(training_data, disable=False) 67 | 68 | total_acc, total_losses = [0]*len(self.midibert.e2w), 0 69 | 70 | for ori_seq_batch in pbar: 71 | batch = ori_seq_batch.shape[0] 72 | ori_seq_batch = ori_seq_batch.to(self.device) # (batch, seq_len, 4) 73 | input_ids = copy.deepcopy(ori_seq_batch) 74 | loss_mask = torch.zeros(batch, max_seq_len) 75 | 76 | for b in range(batch): 77 | # get index for masking 78 | mask80, rand10, cur10 = self.get_mask_ind() 79 | # apply mask, random, remain current token 80 | for i in mask80: 81 | mask_word = torch.tensor(self.midibert.mask_word_np).to(self.device) 82 | input_ids[b][i] = mask_word 83 | loss_mask[b][i] = 1 84 | for i in rand10: 85 | rand_word = torch.tensor(self.midibert.get_rand_tok()).to(self.device) 86 | input_ids[b][i] = rand_word 87 | loss_mask[b][i] = 1 88 | for i in cur10: 89 | loss_mask[b][i] = 1 90 | 91 | loss_mask = loss_mask.to(self.device) 92 | 93 | # avoid attend to pad word 94 | attn_mask = (input_ids[:, :, 0] != self.midibert.bar_pad_word).float().to(self.device) # (batch, seq_len) 95 | 96 | y = self.model.forward(input_ids, attn_mask) 97 | 98 | # get the most likely choice with max 99 | outputs = [] 100 | for i, etype in enumerate(self.midibert.e2w): 101 | output = np.argmax(y[i].cpu().detach().numpy(), axis=-1) 102 | outputs.append(output) 103 | outputs = np.stack(outputs, axis=-1) 104 | outputs = torch.from_numpy(outputs).to(self.device) # (batch, seq_len) 105 | 106 | # accuracy 107 | all_acc = [] 108 | for i in range(4): 109 | acc = torch.sum((ori_seq_batch[:,:,i] == outputs[:,:,i]).float() * loss_mask) 110 | acc /= torch.sum(loss_mask) 111 | all_acc.append(acc) 112 | total_acc = [sum(x) for x in zip(total_acc, all_acc)] 113 | 114 | # reshape (b, s, f) -> (b, f, s) 115 | for i, etype in enumerate(self.midibert.e2w): 116 | #print('before',y[i][:,...].shape) # each: (4,512,5), (4,512,20), (4,512,90), (4,512,68) 117 | y[i] = y[i][:, ...].permute(0, 2, 1) 118 | 119 | # calculate losses 120 | losses, n_tok = [], [] 121 | for i, etype in enumerate(self.midibert.e2w): 122 | n_tok.append(len(self.midibert.e2w[etype])) 123 | losses.append(self.compute_loss(y[i], ori_seq_batch[..., i], loss_mask)) 124 | total_loss_all = [x*y for x, y in zip(losses, n_tok)] 125 | total_loss = sum(total_loss_all)/sum(n_tok) # weighted 126 | 127 | # udpate only in train 128 | if train: 129 | self.model.zero_grad() 130 | total_loss.backward() 131 | clip_grad_norm_(self.model.parameters(), 3.0) 132 | self.optim.step() 133 | 134 | # acc 135 | accs = list(map(float, all_acc)) 136 | sys.stdout.write('Loss: {:06f} | loss: {:03f}, {:03f}, {:03f}, {:03f} | acc: {:03f}, {:03f}, {:03f}, {:03f} \r'.format( 137 | total_loss, *losses, *accs)) 138 | 139 | 
losses = list(map(float, losses)) 140 | total_losses += total_loss.item() 141 | 142 | return round(total_losses/len(training_data),3), [round(x.item()/len(training_data),3) for x in total_acc] 143 | 144 | def save_checkpoint(self, epoch, best_acc, valid_acc, 145 | valid_loss, train_loss, is_best, filename): 146 | state = { 147 | 'epoch': epoch + 1, 148 | 'state_dict': self.midibert.state_dict(), 149 | 'best_acc': best_acc, 150 | 'valid_acc': valid_acc, 151 | 'valid_loss': valid_loss, 152 | 'train_loss': train_loss, 153 | 'optimizer' : self.optim.state_dict() 154 | } 155 | 156 | torch.save(state, filename) 157 | 158 | best_mdl = filename.split('.')[0]+'_best.ckpt' 159 | if is_best: 160 | shutil.copyfile(filename, best_mdl) 161 | 162 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MidiBERT-Piano 2 |

3 | 4 |
5 | [MIT License badge] 6 | [arXiv badge] 7 | [GitHub stars badge] 8 | [GitHub issues badge] 9 | 10 |

11 | Authors: Yi-Hui (Sophia) Chou, I-Chun (Bronwin) Chen 12 | 13 | ## Introduction 14 | This is the official repository for the paper, [MidiBERT-Piano: Large-scale Pre-training for Symbolic Music Understanding](https://arxiv.org/pdf/2107.05223.pdf). 15 | 16 | With this repository, you can 17 | * pre-train a MidiBERT-Piano with your customized pre-trained dataset 18 | * fine-tune & evaluate on 4 downstream tasks 19 | * extract melody (mid to mid) using pre-trained MidiBERT-Piano 20 | 21 | All the datasets employed in this work are publicly available. 22 | 23 | 24 | ## Quick Start 25 | ### For programmers 26 | If you'd like to reproduce the results (MidiBERT) shown in the paper, 27 | ![image-20210710185007453](resources/fig/result.png) 28 | 29 | 1. Please download the [checkpoints](https://drive.google.com/drive/folders/1ceIfC1UugZQHPgpEEMkdAF0VhZ1EeLl3?usp=sharing), and rename files like the following 30 | 31 | (Note: we only provide checkpoints for models in CP representations) 32 | ``` 33 | result/ 34 | └── finetune/ 35 | └── melody_default/ 36 | └── model_best.ckpt 37 | └── velocity_default/ 38 | └── model_best.ckpt 39 | └── composer_default/ 40 | └── model_best.ckpt 41 | └── emotion_default/ 42 | └── model_best.ckpt 43 | ``` 44 | 45 | 2. Run `./scripts/eval.sh` 46 | 47 | Or refer to Readme in MidiBERT folder for more details. 48 | 49 | *No gpu is needed for evaluation* 50 | 51 | ### For musicians who want to test melody extraction 52 | Edit `scripts/melody_extraction.sh` and modify `song_path` to your midi path. 53 | The midi file to predicted melody will be saved at the root folder. 54 | ``` 55 | ./scripts/melody_extraction.sh 56 | ``` 57 | #### Windows Users 58 | ``` 59 | # modify this line (export PYTHONPATH='.') to the following 60 | set PYTHONPATH='.' 61 | # print the environment variable to make sure it's working 62 | echo %PYTHONPATH% 63 | ``` 64 | I've experimented this on Adele hello (piano cover), and I think it's good. 65 | But for non-pop music like Mozart sonata, I feel like the model is pretty confused. This is expected. As the training data is POP909 Dataset, the model knows very little about classical music. 66 | 67 | Side note: I try to make it more friendly for non-programmers. Feel free to open an issue if there's any problem. 68 | 69 | ## Installation 70 | * Python3 71 | * Install generally used packages for MidiBERT-Piano: 72 | ```python 73 | git clone https://github.com/wazenmai/MIDI-BERT.git 74 | cd MIDI-BERT 75 | pip install -r requirements.txt 76 | ``` 77 | 78 | ## Usage 79 | Please see `scripts` folder, which includes bash file for 80 | * prepare data 81 | * pretrain 82 | * finetune 83 | * evaluation 84 | * melody extraction 85 | 86 | You may need to change the folder/file name or any config settings you prefer. 87 | 88 | 89 | ## Repo Structure 90 | ``` 91 | Data/ 92 | └── Dataset/ 93 | └── pop909/ 94 | └── .../ 95 | └── CP_data/ 96 | └── pop909_train.npy 97 | └── *.npy 98 | 99 | data_creation/ 100 | └── preprocess_pop909/ 101 | └── prepare_data/ # convert midi to CP_data 102 | └── dict/ # CP dictionary 103 | 104 | melody_extraction/ 105 | └── skyline/ 106 | └── midibert/ 107 | 108 | MidiBERT/ 109 | └── *py 110 | 111 | ``` 112 | 113 | ## More 114 | For more details on 115 | * data preparation, please go to `data_creation` and follow Readme 116 | * MidiBERT pretraining, finetuning, evaluation, please go to `MidiBERT` and follow Readme. 117 | * skyline, please go to `melody_extraction/skyline` and follow Readme. 
118 | * pianoroll figure generation, please go to `melody_extraction/pianoroll` and follow Readme. We also provide clearer versions of the pianoroll figures from the paper. 119 | * listening to melody extraction results, please go to `melody_extraction/audio` and read Readme for more details. 120 | 121 | Note that the Baseline (LSTM) and the remi-version code are removed to keep the repository clean, but you can still find them in the `main` branch. 122 | 123 | ## Citation 124 | 125 | If you find this useful, please cite our paper. 126 | 127 | ``` 128 | @article{midibertpiano, 129 | title={{MidiBERT-Piano}: Large-scale Pre-training for Symbolic Music Understanding}, 130 | author={Yi-Hui Chou and I-Chun Chen and Chin-Jui Chang and Joann Ching and Yi-Hsuan Yang}, 131 | journal={arXiv preprint arXiv:2107.05223}, 132 | year={2021} 133 | } 134 | ``` 135 | 136 | -------------------------------------------------------------------------------- /data_creation/README.md: -------------------------------------------------------------------------------- 1 | # Data Creation 2 | 3 | All data in CP tokens is already in `Data/CP_data`, including the train, valid, and test splits. 4 | 5 | You can also preprocess the data yourself, as described below. 6 | 7 | ## 1. Download Dataset and Preprocess 8 | Save the following datasets in `Dataset/` 9 | * [Pop1K7](https://github.com/YatingMusic/compound-word-transformer) 10 | * [ASAP](https://github.com/fosfrancesco/asap-dataset) 11 | * Download the ASAP dataset from the link 12 | * [POP909](https://github.com/music-x-lab/POP909-Dataset) 13 | * preprocess to obtain the 865 pieces in a qualified 4/4 time signature 14 | * ```cd data_creation/preprocess_pop909``` 15 | * ```exploratory.py``` to get the pieces qualified in 4/4 time signature and save them to ```qual_pieces.pkl``` 16 | * ```preprocess.py``` to realign and preprocess 17 | * Special thanks to Shih-Lun (Sean) Wu 18 | * [Pianist8](https://zenodo.org/record/5089279) 19 | * Step 1: Download the Pianist8 dataset from the link 20 | * Step 2: Run `python3 pianist8.py` to split the data by `Dataset/pianist8_(split).pkl` 21 | * [EMOPIA](https://annahung31.github.io/EMOPIA/) 22 | * Step 1: Download the Emopia dataset from the link 23 | * Step 2: Run `python3 emopia.py` to split the data by `Dataset/emopia_(split).pkl` 24 | 25 | ## 2. Prepare Dictionary 26 | 27 | ```cd data_creation/prepare_data/dict/``` 28 | Run ```python make_dict.py``` to customize the events & words you'd like to add. 29 | 30 | In this paper, we only use *Bar*, *Position*, *Pitch*, and *Duration*, and we provide our dictionary in CP representation (```data_creation/prepare_data/dict/CP.pkl```). 31 | 32 | ## 3. Prepare CP 33 | Note that the CP tokens here only contain Bar, Position, Pitch, and Duration. Please look into the repo linked in the acknowledgement below if you prefer the original definition of CP tokens. 34 | 35 | All the commands are in ```scripts/prepare_data.sh```. You can directly edit the script and run it. 36 | 37 | (Note that `export PYTHONPATH='.'` is necessary.)
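If you would rather call the conversion from Python than from the shell script, below is a minimal sketch that uses the same `CP` class that `data_creation/prepare_data/main.py` wraps; the `my_midis/` input folder and the output filename are hypothetical, and `task=''` simply means no answer labels are produced.

```python
# Minimal sketch (hypothetical paths): convert a folder of MIDI files to CP tokens.
# Run from the repo root with PYTHONPATH='.' so the package imports resolve.
import glob
import numpy as np
from data_creation.prepare_data.model import CP

model = CP(dict='data_creation/prepare_data/dict/CP.pkl')  # provided CP dictionary
files = glob.glob('my_midis/*.mid')                        # hypothetical input folder

# task='' -> no answer labels; max_len=512 matches the default segment length
segments, _ = model.prepare_data(files, task='', max_len=512)

np.save('Data/CP_data/my_midis.npy', segments)             # hypothetical output name
print(segments.shape)  # (n_segments, 512, 4): Bar, Position, Pitch, Duration tokens
```

This mirrors what `main.py` does with `--input_dir`, minus the argument parsing and the task-specific answer files.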
38 | 39 | ### Melody task 40 | ``` 41 | python3 data_creation/prepare_data/main.py --dataset=pop909 --task=melody 42 | ``` 43 | 44 | ### Velocity task 45 | ``` 46 | python3 data_creation/prepare_data/main.py --dataset=pop909 --task=velocity 47 | ``` 48 | 49 | ### Composer task 50 | ``` 51 | python3 data_creation/prepare_data/main.py --dataset=pianist8 --task=composer 52 | ``` 53 | 54 | ### Emotion task 55 | ``` 56 | python3 data_creation/prepare_data/main.py --dataset=emopia --task=emotion 57 | ``` 58 | 59 | ### Custom input path 60 | * A directory to many midi files 61 | ``` 62 | python3 data_creation/prepare_data/main.py --input_dir=$input_dir 63 | ``` 64 | 65 | * A single midi file 66 | ``` 67 | python3 data_creation/prepare_data/main.py --input_file="${input_dir}/pop.mid" 68 | ``` 69 | 70 | You can also specify the filename by adding `--name={name}`. 71 | 72 | The CP tokens will be saved in ```Data/CP_data/``` 73 | 74 | Acknowledgement: [CP repo](https://github.com/YatingMusic/compound-word-transformer) 75 | -------------------------------------------------------------------------------- /data_creation/prepare_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/data_creation/prepare_data/__init__.py -------------------------------------------------------------------------------- /data_creation/prepare_data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/data_creation/prepare_data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /data_creation/prepare_data/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/data_creation/prepare_data/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /data_creation/prepare_data/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/data_creation/prepare_data/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /data_creation/prepare_data/dict/CP.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/data_creation/prepare_data/dict/CP.pkl -------------------------------------------------------------------------------- /data_creation/prepare_data/dict/make_dict.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | event2word = {'Bar': {}, 'Position': {}, 'Pitch': {}, 'Duration': {}} 4 | word2event = {'Bar': {}, 'Position': {}, 'Pitch': {}, 'Duration': {}} 5 | 6 | def special_tok(cnt, cls): 7 | '''event2word[cls][cls+' '] = cnt 8 | word2event[cls][cnt] = cls+' ' 9 | cnt += 1 10 | 11 | event2word[cls][cls+' '] = cnt 12 | word2event[cls][cnt] = cls+' ' 13 | cnt += 1''' 14 | 15 | event2word[cls][cls+' '] = cnt 16 | word2event[cls][cnt] = cls+' ' 17 | cnt += 1 18 | 19 | 
event2word[cls][cls+' '] = cnt 20 | word2event[cls][cnt] = cls+' ' 21 | cnt += 1 22 | 23 | 24 | # Bar 25 | cnt, cls = 0, 'Bar' 26 | event2word[cls]['Bar New'] = cnt 27 | word2event[cls][cnt] = 'Bar New' 28 | cnt += 1 29 | 30 | event2word[cls]['Bar Continue'] = cnt 31 | word2event[cls][cnt] = 'Bar Continue' 32 | cnt += 1 33 | special_tok(cnt, cls) 34 | 35 | # Position 36 | cnt, cls = 0, 'Position' 37 | for i in range(1, 17): 38 | event2word[cls][f'Position {i}/16'] = cnt 39 | word2event[cls][cnt]= f'Position {i}/16' 40 | cnt += 1 41 | 42 | special_tok(cnt, cls) 43 | 44 | # Note On 45 | cnt, cls = 0, 'Pitch' 46 | for i in range(22, 108): 47 | event2word[cls][f'Pitch {i}'] = cnt 48 | word2event[cls][cnt] = f'Pitch {i}' 49 | cnt += 1 50 | 51 | special_tok(cnt, cls) 52 | 53 | # Note Duration 54 | cnt, cls = 0, 'Duration' 55 | for i in range(64): 56 | event2word[cls][f'Duration {i}'] = cnt 57 | word2event[cls][cnt] = f'Duration {i}' 58 | cnt += 1 59 | 60 | special_tok(cnt, cls) 61 | 62 | print(event2word) 63 | print(word2event) 64 | t = (event2word, word2event) 65 | 66 | with open('CP.pkl', 'wb') as f: 67 | pickle.dump(t, f) 68 | 69 | -------------------------------------------------------------------------------- /data_creation/prepare_data/dict/remi.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/data_creation/prepare_data/dict/remi.pkl -------------------------------------------------------------------------------- /data_creation/prepare_data/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import glob 4 | import pickle 5 | import pathlib 6 | import argparse 7 | import numpy as np 8 | from data_creation.prepare_data.model import * 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser(description='') 13 | ### mode ### 14 | parser.add_argument('-t', '--task', default='', choices=['melody', 'velocity', 'composer', 'emotion']) 15 | 16 | ### path ### 17 | parser.add_argument('--dict', type=str, default='data_creation/prepare_data/dict/CP.pkl') 18 | parser.add_argument('--dataset', type=str, choices=["pop909", "pop1k7", "ASAP", "pianist8", "emopia"]) 19 | parser.add_argument('--input_dir', type=str, default='') 20 | parser.add_argument('--input_file', type=str, default='') 21 | 22 | ### parameter ### 23 | parser.add_argument('--max_len', type=int, default=512) 24 | 25 | ### output ### 26 | parser.add_argument('--output_dir', default="Data/CP_data/tmp") 27 | parser.add_argument('--name', default="") # will be saved as "{output_dir}/{name}.npy" 28 | 29 | args = parser.parse_args() 30 | 31 | if args.task == 'melody' and args.dataset != 'pop909': 32 | print('[error] melody task is only supported for pop909 dataset') 33 | exit(1) 34 | elif args.task == 'composer' and args.dataset != 'pianist8': 35 | print('[error] composer task is only supported for pianist8 dataset') 36 | exit(1) 37 | elif args.task == 'emotion' and args.dataset != 'emopia': 38 | print('[error] emotion task is only supported for emopia dataset') 39 | exit(1) 40 | elif args.dataset == None and args.input_dir == None and args.input_file == None: 41 | print('[error] Please specify the input directory or dataset') 42 | exit(1) 43 | 44 | return args 45 | 46 | 47 | def extract(files, args, model, mode=''): 48 | ''' 49 | files: list of midi path 50 | mode: 'train', 'valid', 'test', '' 51 | args.input_dir: '' or 
the directory to your custom data 52 | args.output_dir: the directory to store the data (and answer data) in CP representation 53 | ''' 54 | assert len(files) 55 | 56 | print(f'Number of {mode} files: {len(files)}') 57 | 58 | segments, ans = model.prepare_data(files, args.task, int(args.max_len)) 59 | 60 | dataset = args.dataset if args.dataset != 'pianist8' else 'composer' 61 | 62 | if args.input_dir != '' or args.input_file != '': 63 | name = args.input_dir or args.input_file 64 | if args.name == '': 65 | args.name = Path(name).stem 66 | output_file = os.path.join(args.output_dir, f'{args.name}.npy') 67 | elif dataset == 'composer' or dataset == 'emopia' or dataset == 'pop909': 68 | output_file = os.path.join(args.output_dir, f'{dataset}_{mode}.npy') 69 | elif dataset == 'pop1k7' or dataset == 'ASAP': 70 | output_file = os.path.join(args.output_dir, f'{dataset}.npy') 71 | 72 | np.save(output_file, segments) 73 | print(f'Data shape: {segments.shape}, saved at {output_file}') 74 | 75 | if args.task != '': 76 | if args.task == 'melody' or args.task == 'velocity': 77 | ans_file = os.path.join(args.output_dir, f'{dataset}_{mode}_{args.task[:3]}ans.npy') 78 | elif args.task == 'composer' or args.task == 'emotion': 79 | ans_file = os.path.join(args.output_dir, f'{dataset}_{mode}_ans.npy') 80 | np.save(ans_file, ans) 81 | print(f'Answer shape: {ans.shape}, saved at {ans_file}') 82 | 83 | 84 | def main(): 85 | args = get_args() 86 | pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) 87 | 88 | # initialize model 89 | model = CP(dict=args.dict) 90 | 91 | if args.dataset == 'pop909': 92 | dataset = 'pop909_processed' 93 | elif args.dataset == 'emopia': 94 | dataset = 'EMOPIA_1.0' 95 | elif args.dataset == 'pianist8': 96 | dataset = 'joann8512-Pianist8-ab9f541' 97 | 98 | if args.dataset == 'pop909' or args.dataset == 'emopia': 99 | train_files = glob.glob(f'Data/Dataset/{dataset}/train/*.mid') 100 | valid_files = glob.glob(f'Data/Dataset/{dataset}/valid/*.mid') 101 | test_files = glob.glob(f'Data/Dataset/{dataset}/test/*.mid') 102 | 103 | elif args.dataset == 'pianist8': 104 | train_files = glob.glob(f'Data/Dataset/{dataset}/train/*/*.mid') 105 | valid_files = glob.glob(f'Data/Dataset/{dataset}/valid/*/*.mid') 106 | test_files = glob.glob(f'Data/Dataset/{dataset}/test/*/*.mid') 107 | 108 | elif args.dataset == 'pop1k7': 109 | files = glob.glob('Data/Dataset/dataset/midi_transcribed/*/*.midi') 110 | 111 | elif args.dataset == 'ASAP': 112 | files = pickle.load(open('Data/Dataset/ASAP_song.pkl', 'rb')) 113 | files = [f'Dataset/asap-dataset/{file}' for file in files] 114 | 115 | elif args.input_dir: 116 | files = glob.glob(f'{args.input_dir}/*.mid') 117 | 118 | elif args.input_file: 119 | files = [args.input_file] 120 | 121 | else: 122 | print('not supported') 123 | exit(1) 124 | 125 | 126 | if args.dataset in {'pop909', 'emopia', 'pianist8'}: 127 | extract(train_files, args, model, 'train') 128 | extract(valid_files, args, model, 'valid') 129 | extract(test_files, args, model, 'test') 130 | else: 131 | # in one single file 132 | extract(files, args, model) 133 | 134 | if __name__ == '__main__': 135 | main() 136 | -------------------------------------------------------------------------------- /data_creation/prepare_data/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from tqdm import tqdm 4 | import data_creation.prepare_data.utils as utils 5 | 6 | Composer = { 7 | "Bethel": 0, 8 | "Clayderman": 1, 9 | 
"Einaudi": 2, 10 | "Hancock": 3, 11 | "Hillsong": 4, 12 | "Hisaishi": 5, 13 | "Ryuichi": 6, 14 | "Yiruma": 7, 15 | "Padding": 8, 16 | } 17 | 18 | Emotion = { 19 | "Q1": 0, 20 | "Q2": 1, 21 | "Q3": 2, 22 | "Q4": 3, 23 | } 24 | 25 | class CP(object): 26 | def __init__(self, dict): 27 | # load dictionary 28 | self.event2word, self.word2event = pickle.load(open(dict, 'rb')) 29 | # pad word: ['Bar ', 'Position ', 'Pitch ', 'Duration '] 30 | self.pad_word = [self.event2word[etype]['%s ' % etype] for etype in self.event2word] 31 | 32 | def extract_events(self, input_path, task): 33 | note_items, tempo_items = utils.read_items(input_path) 34 | if len(note_items) == 0: # if the midi contains nothing 35 | return None 36 | note_items = utils.quantize_items(note_items) 37 | max_time = note_items[-1].end 38 | items = tempo_items + note_items 39 | 40 | groups = utils.group_items(items, max_time) 41 | events = utils.item2event(groups, task) 42 | return events 43 | 44 | def padding(self, data, max_len, ans): 45 | pad_len = max_len - len(data) 46 | for _ in range(pad_len): 47 | if not ans: 48 | data.append(self.pad_word) 49 | else: 50 | data.append(0) 51 | 52 | return data 53 | 54 | def prepare_data(self, midi_paths, task, max_len): 55 | all_words, all_ys = [], [] 56 | 57 | for path in tqdm(midi_paths): 58 | # extract events 59 | events = self.extract_events(path, task) 60 | if not events: # if midi contains nothing 61 | print(f'skip {path} because it is empty') 62 | continue 63 | # events to words 64 | words, ys = [], [] 65 | for note_tuple in events: 66 | nts, to_class = [], -1 67 | for e in note_tuple: 68 | e_text = '{} {}'.format(e.name, e.value) 69 | nts.append(self.event2word[e.name][e_text]) 70 | if e.name == 'Pitch': 71 | to_class = e.Type 72 | words.append(nts) 73 | if task == 'melody' or task == 'velocity': 74 | ys.append(to_class+1) 75 | 76 | # slice to chunks so that max length = max_len (default: 512) 77 | slice_words, slice_ys = [], [] 78 | for i in range(0, len(words), max_len): 79 | slice_words.append(words[i:i+max_len]) 80 | if task == "composer": 81 | name = path.split('/')[-2] 82 | slice_ys.append(Composer[name]) 83 | elif task == "emotion": 84 | name = path.split('/')[-1].split('_')[0] 85 | slice_ys.append(Emotion[name]) 86 | else: 87 | slice_ys.append(ys[i:i+max_len]) 88 | 89 | # padding or drop 90 | # drop only when the task is 'composer' and the data length < max_len//2 91 | if len(slice_words[-1]) < max_len: 92 | if task == 'composer' and len(slice_words[-1]) < max_len//2: 93 | slice_words.pop() 94 | slice_ys.pop() 95 | else: 96 | slice_words[-1] = self.padding(slice_words[-1], max_len, ans=False) 97 | 98 | if (task == 'melody' or task == 'velocity') and len(slice_ys[-1]) < max_len: 99 | slice_ys[-1] = self.padding(slice_ys[-1], max_len, ans=True) 100 | 101 | all_words = all_words + slice_words 102 | all_ys = all_ys + slice_ys 103 | 104 | all_words = np.array(all_words) 105 | all_ys = np.array(all_ys) 106 | 107 | return all_words, all_ys 108 | -------------------------------------------------------------------------------- /data_creation/prepare_data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import miditoolkit 3 | import copy 4 | 5 | # parameters for input 6 | DEFAULT_VELOCITY_BINS = np.array([ 0, 32, 48, 64, 80, 96, 128]) # np.linspace(0, 128, 32+1, dtype=np.int) 7 | DEFAULT_FRACTION = 16 8 | DEFAULT_DURATION_BINS = np.arange(60, 3841, 60, dtype=int) 9 | DEFAULT_TEMPO_INTERVALS = [range(30, 90), range(90, 
150), range(150, 210)] 10 | 11 | # parameters for output 12 | DEFAULT_RESOLUTION = 480 13 | 14 | # define "Item" for general storage 15 | class Item(object): 16 | def __init__(self, name, start, end, velocity, pitch, Type): 17 | self.name = name 18 | self.start = start 19 | self.end = end 20 | self.velocity = velocity 21 | self.pitch = pitch 22 | self.Type = Type 23 | 24 | def __repr__(self): 25 | return 'Item(name={}, start={}, end={}, velocity={}, pitch={}, Type={})'.format( 26 | self.name, self.start, self.end, self.velocity, self.pitch, self.Type) 27 | 28 | # read notes and tempo changes from midi (assume there is only one track) 29 | def read_items(file_path): 30 | midi_obj = miditoolkit.midi.parser.MidiFile(file_path) 31 | # note 32 | note_items = [] 33 | num_of_instr = len(midi_obj.instruments) 34 | 35 | for i in range(num_of_instr): 36 | notes = midi_obj.instruments[i].notes 37 | notes.sort(key=lambda x: (x.start, x.pitch)) 38 | 39 | for note in notes: 40 | note_items.append(Item( 41 | name='Note', 42 | start=note.start, 43 | end=note.end, 44 | velocity=note.velocity, 45 | pitch=note.pitch, 46 | Type=i)) 47 | 48 | note_items.sort(key=lambda x: x.start) 49 | 50 | # tempo 51 | tempo_items = [] 52 | for tempo in midi_obj.tempo_changes: 53 | tempo_items.append(Item( 54 | name='Tempo', 55 | start=tempo.time, 56 | end=None, 57 | velocity=None, 58 | pitch=int(tempo.tempo), 59 | Type=-1)) 60 | tempo_items.sort(key=lambda x: x.start) 61 | 62 | # expand to all beat 63 | max_tick = tempo_items[-1].start 64 | existing_ticks = {item.start: item.pitch for item in tempo_items} 65 | wanted_ticks = np.arange(0, max_tick+1, DEFAULT_RESOLUTION) 66 | output = [] 67 | for tick in wanted_ticks: 68 | if tick in existing_ticks: 69 | output.append(Item( 70 | name='Tempo', 71 | start=tick, 72 | end=None, 73 | velocity=None, 74 | pitch=existing_ticks[tick], 75 | Type=-1)) 76 | else: 77 | output.append(Item( 78 | name='Tempo', 79 | start=tick, 80 | end=None, 81 | velocity=None, 82 | pitch=output[-1].pitch, 83 | Type=-1)) 84 | tempo_items = output 85 | return note_items, tempo_items 86 | 87 | 88 | class Event(object): 89 | def __init__(self, name, time, value, text, Type): 90 | self.name = name 91 | self.time = time 92 | self.value = value 93 | self.text = text 94 | self.Type = Type 95 | 96 | def __repr__(self): 97 | return 'Event(name={}, time={}, value={}, text={}, Type={})'.format( 98 | self.name, self.time, self.value, self.text, self.Type) 99 | 100 | 101 | def item2event(groups, task): 102 | events = [] 103 | n_downbeat = 0 104 | for i in range(len(groups)): 105 | if 'Note' not in [item.name for item in groups[i][1:-1]]: 106 | continue 107 | bar_st, bar_et = groups[i][0], groups[i][-1] 108 | n_downbeat += 1 109 | new_bar = True 110 | 111 | for item in groups[i][1:-1]: 112 | if item.name != 'Note': 113 | continue 114 | note_tuple = [] 115 | 116 | # Bar 117 | if new_bar: 118 | BarValue = 'New' 119 | new_bar = False 120 | else: 121 | BarValue = "Continue" 122 | note_tuple.append(Event( 123 | name='Bar', 124 | time=None, 125 | value=BarValue, 126 | text='{}'.format(n_downbeat), 127 | Type=-1)) 128 | 129 | # Position 130 | flags = np.linspace(bar_st, bar_et, DEFAULT_FRACTION, endpoint=False) 131 | index = np.argmin(abs(flags-item.start)) 132 | note_tuple.append(Event( 133 | name='Position', 134 | time=item.start, 135 | value='{}/{}'.format(index+1, DEFAULT_FRACTION), 136 | text='{}'.format(item.start), 137 | Type=-1)) 138 | 139 | # Pitch 140 | velocity_index = np.searchsorted(DEFAULT_VELOCITY_BINS, 
item.velocity, side='right') - 1 141 | 142 | if task == 'melody': 143 | pitchType = item.Type 144 | elif task == 'velocity': 145 | pitchType = velocity_index 146 | else: 147 | pitchType = -1 148 | 149 | note_tuple.append(Event( 150 | name='Pitch', 151 | time=item.start, 152 | value=item.pitch, 153 | text='{}'.format(item.pitch), 154 | Type=pitchType)) 155 | 156 | # Duration 157 | duration = item.end - item.start 158 | index = np.argmin(abs(DEFAULT_DURATION_BINS-duration)) 159 | note_tuple.append(Event( 160 | name='Duration', 161 | time=item.start, 162 | value=index, 163 | text='{}/{}'.format(duration, DEFAULT_DURATION_BINS[index]), 164 | Type=-1)) 165 | 166 | events.append(note_tuple) 167 | 168 | return events 169 | 170 | 171 | def quantize_items(items, ticks=120): 172 | grids = np.arange(0, items[-1].start, ticks, dtype=int) 173 | # process 174 | for item in items: 175 | index = np.argmin(abs(grids - item.start)) 176 | shift = grids[index] - item.start 177 | item.start += shift 178 | item.end += shift 179 | return items 180 | 181 | 182 | def group_items(items, max_time, ticks_per_bar=DEFAULT_RESOLUTION*4): 183 | items.sort(key=lambda x: x.start) 184 | downbeats = np.arange(0, max_time+ticks_per_bar, ticks_per_bar) 185 | groups = [] 186 | for db1, db2 in zip(downbeats[:-1], downbeats[1:]): 187 | insiders = [] 188 | for item in items: 189 | if (item.start >= db1) and (item.start < db2): 190 | insiders.append(item) 191 | overall = [db1] + insiders + [db2] 192 | groups.append(overall) 193 | return groups 194 | -------------------------------------------------------------------------------- /data_creation/preprocess_pop909/exploratory.py: -------------------------------------------------------------------------------- 1 | import miditoolkit 2 | import os, pickle 3 | import matplotlib.pyplot as plt 4 | from collections import Counter 5 | 6 | root_dir = '../../Data/Dataset/POP909' 7 | 8 | def read_info_file(fpath, tgt_cols): 9 | with open(fpath, 'r') as f: 10 | lines = f.read().splitlines() 11 | 12 | ret_dict = {col: [] for col in tgt_cols} 13 | for l in lines: 14 | l = l.split() 15 | for col in tgt_cols: 16 | ret_dict[col].append( float(l[col]) ) 17 | 18 | return ret_dict 19 | 20 | if __name__ == '__main__': 21 | pieces_dir = [ i for i in os.listdir(root_dir) 22 | if os.path.isdir( os.path.join(root_dir, i) ) ] 23 | 24 | qualified_quad = 0 25 | qualified_triple = 0 26 | qualified_pieces = [] 27 | 28 | for pdir in pieces_dir: 29 | audio_beat_path = os.path.join(root_dir, pdir, 'beat_audio.txt') 30 | audio_beats = read_info_file(audio_beat_path, [1])[1] 31 | 32 | midi_beat_path = os.path.join(root_dir, pdir, 'beat_midi.txt') 33 | midi_beats = read_info_file(midi_beat_path, [1, 2]) 34 | midi_beats_minor, midi_beats_major = midi_beats[1], midi_beats[2] 35 | 36 | if max(audio_beats) == 4.: 37 | try: 38 | assert abs(0.25 - sum(midi_beats_major) / len(midi_beats_major)) < 0.03 39 | qualified_quad += 1 40 | qualified_pieces.append(pdir) 41 | except: 42 | print (pdir, '[error] 4-beat !!') 43 | elif max(audio_beats) == 3.: 44 | try: 45 | assert abs(0.33 - sum(midi_beats_minor) / len(midi_beats_major)) < 0.03, sum(midi_beats_minor) / len(midi_beats_major) 46 | qualified_triple += 1 47 | except: 48 | print (pdir, '[error] 3-beat !!') 49 | 50 | 51 | print('qualified quad: {}; qualified triple: {}'.format(qualified_quad, qualified_triple)) 52 | 53 | pickle.dump( 54 | qualified_pieces, 55 | open('qual_pieces.pkl', 'wb'), 56 | protocol=pickle.HIGHEST_PROTOCOL 57 | ) 58 | 
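# Summary of the filtering heuristic above: a piece is kept only when its audio
# beat annotation reaches beat index 4 (i.e. 4/4) and roughly one MIDI beat in
# four is flagged as a major (down)beat (mean of the major-beat column close to
# 0.25, tolerance 0.03). 3/4 pieces are counted but not kept. The qualified piece
# ids are pickled to qual_pieces.pkl, which preprocess.py loads to realign and
# split the pieces.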
-------------------------------------------------------------------------------- /data_creation/preprocess_pop909/preprocess.py: -------------------------------------------------------------------------------- 1 | import miditoolkit 2 | import os, pickle 3 | from copy import deepcopy 4 | import numpy as np 5 | import glob 6 | from exploratory import read_info_file 7 | from tqdm import tqdm 8 | 9 | from collections import Counter 10 | from itertools import chain 11 | import matplotlib.pyplot as plt 12 | 13 | DEFAULT_TICKS_PER_BEAT = 480 14 | DEFAULT_RESOLUTION = 120 15 | 16 | root_dir = '../../Data/Dataset/POP909' 17 | melody_out_dir = '../../Data/Dataset/pop909_processed' 18 | 19 | downbeat_records = [] 20 | all_bpms = [] 21 | 22 | def justify_tick(n_beats): 23 | n_ticks = n_beats * DEFAULT_TICKS_PER_BEAT 24 | return int(DEFAULT_RESOLUTION * round(n_ticks / DEFAULT_RESOLUTION)) 25 | 26 | def bpm2sec(bpm): 27 | return 60. / bpm 28 | 29 | def calc_accum_secs(bpm, n_ticks, ticks_per_beat): 30 | return bpm2sec(bpm) * n_ticks / ticks_per_beat 31 | 32 | def find_downbeat_idx_audio(audio_dbt): 33 | for st_idx in range(4): 34 | if audio_dbt[ st_idx ] == 1.: 35 | return st_idx 36 | 37 | def get_note_time_sec(note, tempo_bpms, ticks_per_beat, tempo_change_ticks, tempo_accum_times): 38 | st_seg = np.searchsorted(tempo_change_ticks, note.start, side='left') - 1 39 | ed_seg = np.searchsorted(tempo_change_ticks, note.end, side='left') - 1 40 | # print (note.start, tempo_change_ticks[ st_seg ]) 41 | 42 | start_sec = tempo_accum_times[ st_seg ] +\ 43 | calc_accum_secs( 44 | tempo_bpms[ st_seg ], 45 | note.start - tempo_change_ticks[ st_seg ], 46 | ticks_per_beat 47 | ) 48 | end_sec = tempo_accum_times[ ed_seg ] +\ 49 | calc_accum_secs( 50 | tempo_bpms[ ed_seg ], 51 | note.end - tempo_change_ticks[ ed_seg ], 52 | ticks_per_beat 53 | ) 54 | 55 | return start_sec, end_sec 56 | 57 | def align_notes_to_secs(midi_obj): 58 | tempo_bpms = [] 59 | tempo_change_ticks = [] 60 | tempo_accum_times = [] 61 | for tc in midi_obj.tempo_changes: 62 | # print (tc.tempo, tc.time) 63 | if tc.time == 0: 64 | tempo_accum_times.append( 0. 
) 65 | else: 66 | tempo_accum_times.append( 67 | tempo_accum_times[-1] + \ 68 | calc_accum_secs( 69 | tempo_bpms[-1], 70 | tc.time - tempo_change_ticks[-1], 71 | midi_obj.ticks_per_beat 72 | ) 73 | ) 74 | 75 | tempo_bpms.append(tc.tempo) 76 | tempo_change_ticks.append(tc.time) 77 | 78 | # print (tempo_accum_times) 79 | 80 | vocal_notes = [] 81 | for note in midi_obj.instruments[0].notes: 82 | note_st_sec, note_ed_sec = get_note_time_sec( 83 | note, tempo_bpms, 84 | midi_obj.ticks_per_beat, tempo_change_ticks, 85 | tempo_accum_times 86 | ) 87 | # print (note_st_sec, note_ed_sec) 88 | vocal_notes.append( 89 | {'st_sec': note_st_sec, 'ed_sec': note_ed_sec, 'pitch': note.pitch, 'velocity': note.velocity} 90 | ) 91 | 92 | bridge_notes = [] 93 | for note in midi_obj.instruments[1].notes: 94 | note_st_sec, note_ed_sec = get_note_time_sec( 95 | note, tempo_bpms, 96 | midi_obj.ticks_per_beat, tempo_change_ticks, 97 | tempo_accum_times 98 | ) 99 | # print (note_st_sec, note_ed_sec) 100 | bridge_notes.append( 101 | {'st_sec': note_st_sec, 'ed_sec': note_ed_sec, 'pitch': note.pitch, 'velocity': note.velocity} 102 | ) 103 | 104 | piano_notes = [] 105 | for note in midi_obj.instruments[2].notes: 106 | note_st_sec, note_ed_sec = get_note_time_sec( 107 | note, tempo_bpms, 108 | midi_obj.ticks_per_beat, tempo_change_ticks, 109 | tempo_accum_times 110 | ) 111 | # print (note_st_sec, note_ed_sec) 112 | piano_notes.append( 113 | {'st_sec': note_st_sec, 'ed_sec': note_ed_sec, 'pitch': note.pitch, 'velocity': note.velocity} 114 | ) 115 | 116 | return vocal_notes, bridge_notes, piano_notes 117 | 118 | def group_notes_per_beat(notes, beat_times): 119 | n_beats = len(beat_times) 120 | note_groups = [[] for _ in range(n_beats)] 121 | cur_beat = 0 122 | 123 | notes = sorted(notes, key=lambda x: (x['st_sec'], -x['pitch'])) 124 | 125 | for note in notes: 126 | while cur_beat < (n_beats - 1) and note['st_sec'] > beat_times[ cur_beat + 1 ]: 127 | # print (cur_beat, note['st_sec'], beat_times[ cur_beat + 1 ]) 128 | cur_beat += 1 129 | 130 | if cur_beat == 0 and note['st_sec'] < beat_times[0]: 131 | if note['st_sec'] >= (beat_times[0] - 0.1) and note['ed_sec'] - note['st_sec'] > 0.2: 132 | note['st_sec'] = beat_times[0] 133 | else: 134 | continue 135 | 136 | if cur_beat == n_beats - 1: 137 | if note['st_sec'] - beat_times[-1] > beat_times[-1] - beat_times[-2]: 138 | continue 139 | 140 | note_groups[ cur_beat ].append( deepcopy(note) ) 141 | 142 | return note_groups 143 | 144 | def remove_piano_notes_collision(vocal_notes, piano_notes): 145 | n_beats = len(vocal_notes) 146 | 147 | for beat in range(n_beats): 148 | if (beat - 1 >= 0 and len(vocal_notes[ beat - 1 ])) or \ 149 | len (vocal_notes[ beat ]) or \ 150 | (beat + 1 < n_beats and len(vocal_notes[ beat + 1 ])) or \ 151 | (beat + 2 < n_beats and len(vocal_notes[ beat + 2 ])): 152 | piano_notes[ beat ] = [] 153 | 154 | return piano_notes 155 | 156 | def quantize_notes(notes, beat_times, downbeat_idx): 157 | quantized = [[] for _ in range(len(beat_times))] 158 | 159 | if downbeat_idx == 1: 160 | cur_tick = 3 * DEFAULT_TICKS_PER_BEAT 161 | elif downbeat_idx == 2: 162 | cur_tick = 2 * DEFAULT_TICKS_PER_BEAT 163 | elif downbeat_idx == 3: 164 | cur_tick = DEFAULT_TICKS_PER_BEAT 165 | else: 166 | cur_tick = 0 167 | 168 | for b_idx, beat_notes in enumerate(notes): 169 | beat_dur = beat_times[b_idx + 1] - beat_times[b_idx]\ 170 | if b_idx < len(notes) - 1 else beat_times[-1] - beat_times[-2] 171 | beat_st_sec = beat_times[b_idx] 172 | 173 | for note in beat_notes: 174 | 
note_dur_tick = justify_tick( (note['ed_sec'] - note['st_sec']) / beat_dur ) 175 | if note_dur_tick == 0: 176 | continue 177 | note_st_tick = cur_tick +\ 178 | justify_tick( (note['st_sec'] - beat_st_sec) / beat_dur ) 179 | 180 | if note_st_tick < 0: 181 | # print (note['st_sec'], beat_st_sec, b_idx, cur_tick) 182 | print ('[violation]', note_st_tick) 183 | 184 | note['st_tick'] = note_st_tick 185 | note['dur_tick'] = note_dur_tick 186 | quantized[ b_idx ].append( deepcopy(note) ) 187 | 188 | cur_tick += DEFAULT_TICKS_PER_BEAT 189 | 190 | return quantized 191 | 192 | 193 | def merge_notes(vocal_notes, bridge_notes, piano_notes): 194 | vocal_notes = list(chain(*vocal_notes)) 195 | bridge_notes = list(chain(*bridge_notes)) 196 | piano_notes = list(chain(*piano_notes)) 197 | 198 | vocal_notes = sorted( 199 | vocal_notes, 200 | key=lambda x : (x['st_tick'], -x['pitch']) 201 | ) 202 | bridge_notes = sorted( 203 | bridge_notes, 204 | key=lambda x : (x['st_tick'], -x['pitch']) 205 | ) 206 | piano_notes = sorted( 207 | piano_notes, 208 | key=lambda x : (x['st_tick'], -x['pitch']) 209 | ) 210 | 211 | return vocal_notes, bridge_notes, piano_notes 212 | 213 | 214 | def dump_melody_midi(vocal_notes, bridge_notes, piano_notes, bpm_changes, midi_out_path): 215 | midi_obj = miditoolkit.midi.MidiFile() 216 | midi_obj.time_signature_changes = [ 217 | miditoolkit.midi.containers.TimeSignature(4, 4, 0) 218 | ] 219 | midi_obj.tempo_changes = bpm_changes 220 | midi_obj.instruments = [ 221 | miditoolkit.midi.Instrument(0, name='vocal'), 222 | miditoolkit.midi.Instrument(1, name='bridge'), 223 | miditoolkit.midi.Instrument(2, name='piano'), 224 | ] 225 | 226 | for n in vocal_notes: 227 | midi_obj.instruments[0].notes.append( 228 | miditoolkit.midi.containers.Note( 229 | n['velocity'], n['pitch'], n['st_tick'], n['st_tick'] + n['dur_tick'] 230 | ) 231 | ) 232 | 233 | for n in bridge_notes: 234 | midi_obj.instruments[1].notes.append( 235 | miditoolkit.midi.containers.Note( 236 | n['velocity'], n['pitch'], n['st_tick'], n['st_tick'] + n['dur_tick'] 237 | ) 238 | ) 239 | 240 | for n in piano_notes: 241 | midi_obj.instruments[2].notes.append( 242 | miditoolkit.midi.containers.Note( 243 | n['velocity'], n['pitch'], n['st_tick'], n['st_tick'] + n['dur_tick'] 244 | ) 245 | ) 246 | midi_obj.dump(midi_out_path) 247 | return 248 | 249 | def align_midi_beats(piece_dir, subfolder): 250 | audio_beat_path = os.path.join(piece_dir, 'beat_audio.txt') 251 | midi_beat_times = read_info_file(audio_beat_path, [0])[0] 252 | midi_beat_idx = read_info_file(audio_beat_path, [1])[1] 253 | 254 | # find the 1st down beat 255 | downbeat_idx = find_downbeat_idx_audio(midi_beat_idx) 256 | 257 | midi_obj = miditoolkit.midi.MidiFile( 258 | os.path.join(root_dir, pdir, pdir + '.mid') 259 | ) 260 | 261 | vocal_notes, bridge_notes, piano_notes = align_notes_to_secs(midi_obj) 262 | 263 | vocal_notes = group_notes_per_beat(vocal_notes, midi_beat_times) 264 | bridge_notes = group_notes_per_beat(bridge_notes, midi_beat_times) 265 | piano_notes = group_notes_per_beat(piano_notes, midi_beat_times) 266 | 267 | vocal_notes = quantize_notes(vocal_notes, midi_beat_times, downbeat_idx) 268 | bridge_notes = quantize_notes(bridge_notes, midi_beat_times, downbeat_idx) 269 | piano_notes = quantize_notes(piano_notes, midi_beat_times, downbeat_idx) 270 | 271 | vocal_notes, bridge_notes, piano_notes = merge_notes(vocal_notes, bridge_notes, piano_notes) 272 | 273 | 274 | # recalculate bpm 275 | # change index 0 to 4 276 | if downbeat_idx == 0: 277 | downbeat_idx = 4 
278 | 279 | first_beat_tick = (4-downbeat_idx) * DEFAULT_TICKS_PER_BEAT 280 | first_bpm = np.round( (60./(midi_beat_times[1]-midi_beat_times[0])), 2 ) 281 | bpm_changes = [ miditoolkit.midi.containers.TempoChange(first_bpm, 0) ] 282 | 283 | beat_diff = [np.round(j-i, 2) for i,j in zip(midi_beat_times[:-1], midi_beat_times[1:])] 284 | bd_tmp = beat_diff[0] 285 | 286 | for i, bd in enumerate(beat_diff[1:]): 287 | if abs(bd-bd_tmp) > 0.05: 288 | bd_tmp = bd 289 | neighbor_avg = (beat_diff[i-1] + bd + beat_diff[i+1])/3 290 | _bpm = 60. / (bd) 291 | _st = first_beat_tick + ((i+2) * DEFAULT_TICKS_PER_BEAT) 292 | bpm_changes.append(miditoolkit.midi.containers.TempoChange(_bpm,_st)) 293 | 294 | 295 | dump_melody_midi( 296 | vocal_notes, 297 | bridge_notes, 298 | piano_notes, 299 | bpm_changes, 300 | os.path.join(melody_out_dir, subfolder, piece_dir.split('/')[-1] + '.mid') 301 | ) 302 | 303 | 304 | if __name__ == '__main__': 305 | pieces_dir = pickle.load(open('qual_pieces.pkl', 'rb')) 306 | data_split = pickle.load(open('split.pkl','rb')) 307 | train_data = set(data_split['train_data']) 308 | valid_data = set(data_split['valid_data']) 309 | test_data = set(data_split['test_data']) 310 | 311 | os.makedirs(f'{melody_out_dir}/train', exist_ok=True) 312 | os.makedirs(f'{melody_out_dir}/valid', exist_ok=True) 313 | os.makedirs(f'{melody_out_dir}/test', exist_ok=True) 314 | 315 | for pdir in tqdm(pieces_dir): 316 | piece_dir = os.path.join(root_dir, pdir) 317 | mid = f'{pdir}.mid' 318 | if mid in train_data: 319 | subfolder = 'train' 320 | elif mid in valid_data: 321 | subfolder = 'valid' 322 | elif mid in test_data: 323 | subfolder = 'test' 324 | else: 325 | print(f'invalid midi {mid}') 326 | exit(1) 327 | align_midi_beats(piece_dir, subfolder) 328 | print(f"Preprocessed data saved at {melody_out_dir}") 329 | -------------------------------------------------------------------------------- /data_creation/preprocess_pop909/qual_pieces.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/data_creation/preprocess_pop909/qual_pieces.pkl -------------------------------------------------------------------------------- /data_creation/preprocess_pop909/split.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/data_creation/preprocess_pop909/split.pkl -------------------------------------------------------------------------------- /melody_extraction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/__init__.py -------------------------------------------------------------------------------- /melody_extraction/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /melody_extraction/__pycache__/midi2CP.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/__pycache__/midi2CP.cpython-38.pyc -------------------------------------------------------------------------------- /melody_extraction/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /melody_extraction/audio/README.md: -------------------------------------------------------------------------------- 1 | # Melody Extraction on POP909 and Pianist8 2 | 3 | We have run melody extraction algorithms on randomly selected songs from POP909 and Pianist8. The algorithms are listed below: 4 | - skyline 5 | - A Convolutional Approach to Melody Line Identification in Symbolic Scores (https://arxiv.org/abs/1906.10547) (https://github.com/sophia1488/symbolic-melody-identification); we trained the model on the POP909 train split. 6 | - MidiBERT finetuned on POP909 for melody extraction; here we treat POP909's "melody" and "bridge" tracks as melody. For details of this categorization, please see https://github.com/music-x-lab/POP909-Dataset/tree/master. 7 | 8 | ## :headphones: 2024/04 Update 9 | Uploaded corresponding MP3 files in directories with the postfix `_mp3`. 10 | Since the files are too large, please access them via this link: https://drive.google.com/file/d/1cRcTAV43n-1VeJciCxQXSx5TLyChOXHI/view 11 | 12 | 13 | ## In-domain Dataset - POP909 14 | - 018.mid 15 | - 067.mid 16 | - 395.mid 17 | - 596.mid 18 | - 828.mid 19 | 20 | | Model | Accuracy (%) | Precision (%) | Recall (%) | F1 (%) | 21 | |---|---|---|---|---| 22 | | Skyline | 79.52 | 81.42 | 56.57 | 66.76 | 23 | | CNN | 92.08 | 88.95 | 89.3 | 89.13 | 24 | | MidiBERT | 99.06 | 98.68 | 98.72 | 98.7 | 25 | 26 | ## Out-of-domain Dataset - Pianist8 27 | - Clayderman_I_Have_A_Dream 28 | - Clayderman_I_Like_Chopin 29 | - Clayderman_Yesterday_Once_More 30 | - Yiruma_Love_Hurts 31 | - Yiruma_River_Flows_In_You 32 | 33 | ## Filename Description 34 | - `_gt`: ground truth 35 | - `_skyline`: melody extracted by the skyline algorithm 36 | - `_cnn`: melody extracted by A Convolutional Approach to Melody Line Identification in Symbolic Scores 37 | - `_ours`: melody extracted by MidiBERT 38 | -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8/Clayderman_I_Have_A_Dream.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8/Clayderman_I_Have_A_Dream.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8/Clayderman_I_Like_Chopin.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8/Clayderman_I_Like_Chopin.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8/Clayderman_Yesterday_Once_More.mid: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8/Clayderman_Yesterday_Once_More.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8/Yiruma_Love_Hurts.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8/Yiruma_Love_Hurts.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8/Yiruma_River_Flows_in_You.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8/Yiruma_River_Flows_in_You.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Clayderman_I_Have_A_Dream_cnn.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Clayderman_I_Have_A_Dream_cnn.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Clayderman_I_Have_A_Dream_ours.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Clayderman_I_Have_A_Dream_ours.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Clayderman_I_Have_A_Dream_skyline.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Clayderman_I_Have_A_Dream_skyline.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Clayderman_I_Like_Chopin_cnn.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Clayderman_I_Like_Chopin_cnn.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Clayderman_I_Like_Chopin_ours.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Clayderman_I_Like_Chopin_ours.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Clayderman_I_Like_Chopin_skyline.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Clayderman_I_Like_Chopin_skyline.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Clayderman_Yesterday_Once_More_cnn.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Clayderman_Yesterday_Once_More_cnn.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Clayderman_Yesterday_Once_More_ours.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Clayderman_Yesterday_Once_More_ours.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Clayderman_Yesterday_Once_More_skyline.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Clayderman_Yesterday_Once_More_skyline.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Yiruma_Love_Hurts_cnn.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Yiruma_Love_Hurts_cnn.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Yiruma_Love_Hurts_ours.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Yiruma_Love_Hurts_ours.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Yiruma_Love_Hurts_skyline.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Yiruma_Love_Hurts_skyline.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Yiruma_River_Flows_in_You_cnn.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Yiruma_River_Flows_in_You_cnn.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Yiruma_River_Flows_in_You_ours.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Yiruma_River_Flows_in_You_ours.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pianist8_melody/Yiruma_River_Flows_in_You_skyline.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pianist8_melody/Yiruma_River_Flows_in_You_skyline.mid 
-------------------------------------------------------------------------------- /melody_extraction/audio/pop909/018.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909/018.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909/067.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909/067.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909/395.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909/395.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909/596.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909/596.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909/828.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909/828.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909/extract.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extract Ground Truth from POP909 songs 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import miditoolkit 9 | 10 | files = [] 11 | for f in os.listdir(sys.argv[1]): 12 | if f.endswith('.mid'): 13 | files.append(f) 14 | 15 | for f in files: 16 | midi = miditoolkit.midi.parser.MidiFile(os.path.join(sys.argv[1], f)) 17 | midi.instruments = midi.instruments[:2] 18 | num = f.split(".")[0] 19 | midi.dump(os.path.join(sys.argv[2], num + "_gt.mid")) 20 | -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/018_cnn.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/018_cnn.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/018_gt.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/018_gt.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/018_ours.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/018_ours.mid -------------------------------------------------------------------------------- 
/melody_extraction/audio/pop909_melody/018_skyline.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/018_skyline.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/067_cnn.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/067_cnn.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/067_gt.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/067_gt.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/067_ours.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/067_ours.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/067_skyline.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/067_skyline.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/395_cnn.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/395_cnn.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/395_gt.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/395_gt.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/395_ours.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/395_ours.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/395_skyline.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/395_skyline.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/596_cnn.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/596_cnn.mid 
-------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/596_gt.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/596_gt.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/596_ours.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/596_ours.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/596_skyline.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/596_skyline.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/828_cnn.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/828_cnn.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/828_gt.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/828_gt.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/828_ours.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/828_ours.mid -------------------------------------------------------------------------------- /melody_extraction/audio/pop909_melody/828_skyline.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/audio/pop909_melody/828_skyline.mid -------------------------------------------------------------------------------- /melody_extraction/midibert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/midibert/__init__.py -------------------------------------------------------------------------------- /melody_extraction/midibert/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/midibert/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /melody_extraction/midibert/__pycache__/midi2CP.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/midibert/__pycache__/midi2CP.cpython-38.pyc -------------------------------------------------------------------------------- /melody_extraction/midibert/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/midibert/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /melody_extraction/midibert/extract.py: -------------------------------------------------------------------------------- 1 | """ 2 | [ Melody Extraction ] 3 | Given path to input midi file, save the predicted melody midi file. 4 | Please note that the model is trained on pop909 dataset (containing 3 classes: melody, bridge, accompaniment), 5 | so there are 2 interpretations: view `bridge` as `melody` or view it as `accompaniment`. 6 | You could choose the mode - `bridge` is viewed as `melody` by default. 7 | 8 | Also, the sequence is zero-padded so that the shape (length) is the same, but it won't affect the results, 9 | as zero-padded tokens will be excluded in post-processing. 10 | """ 11 | 12 | import argparse 13 | import numpy as np 14 | import random 15 | import pickle 16 | import os 17 | import copy 18 | import shutil 19 | import json 20 | import miditoolkit 21 | 22 | from torch.utils.data import DataLoader 23 | import torch 24 | import torch.nn as nn 25 | from transformers import BertConfig 26 | 27 | from melody_extraction.midibert.midi2CP import CP 28 | from melody_extraction.midibert.utils import DEFAULT_VELOCITY_BINS, DEFAULT_FRACTION, DEFAULT_DURATION_BINS, DEFAULT_TEMPO_INTERVALS, DEFAULT_RESOLUTION 29 | from MidiBERT.model import MidiBert 30 | from MidiBERT.finetune_model import TokenClassification 31 | 32 | 33 | def boolean_string(s): 34 | if s not in ['False', 'True']: 35 | raise ValueError('Not a valid boolean string') 36 | return s == 'True' 37 | 38 | def get_args(): 39 | parser = argparse.ArgumentParser(description='') 40 | 41 | ### path setup ### 42 | parser.add_argument('--input_path', type=str, required=True, help="path to the input midi file") 43 | parser.add_argument('--output_path', type=str, default=None, help="path to the output midi file") 44 | parser.add_argument('--dict_file', type=str, default='data_creation/prepare_data/dict/CP.pkl') 45 | parser.add_argument('--ckpt', type=str, default='') 46 | 47 | ### parameter setting ### 48 | parser.add_argument('--max_seq_len', type=int, default=512, help='all sequences are padded to `max_seq_len`') 49 | parser.add_argument('--hs', type=int, default=768) 50 | parser.add_argument('--bridge', default=True, type=boolean_string, help='View bridge as melody (True) or accompaniment (False)') 51 | 52 | ### cuda ### 53 | parser.add_argument('--cpu', action="store_true") # default: false 54 | 55 | args = parser.parse_args() 56 | 57 | root = 'result/finetune/' 58 | args.ckpt = root + 'melody_default/model_best.ckpt' if args.ckpt=='' else args.ckpt 59 | 60 | if not args.output_path: 61 | basename = args.input_path.split('/')[-1].split('.')[0] 62 | args.output_path = f'{basename}_melody.mid' 63 | 64 | return args 65 | 66 | 67 | def load_model(args, e2w, w2e): 68 | print("\nBuilding BERT model") 69 | configuration = BertConfig(max_position_embeddings=args.max_seq_len, 70 | position_embedding_type='relative_key_query', 71 | 
hidden_size=args.hs) 72 | 73 | midibert = MidiBert(bertConfig=configuration, e2w=e2w, w2e=w2e) 74 | 75 | model = TokenClassification(midibert, 4, args.hs) 76 | 77 | print('\nLoading ckpt from', args.ckpt) 78 | checkpoint = torch.load(args.ckpt, map_location='cpu') 79 | 80 | # remove module 81 | #from collections import OrderedDict 82 | #new_state_dict = OrderedDict() 83 | #for k, v in checkpoint['state_dict'].items(): 84 | # name = k[7:] 85 | # new_state_dict[name] = v 86 | #model.load_state_dict(new_state_dict) 87 | model.load_state_dict(checkpoint['state_dict']) 88 | 89 | return model 90 | 91 | 92 | def inference(model, tokens, pad_CP, device): 93 | """ 94 | Given `model`, `tokens` (input), `pad_CP` (to indicate which notes are padded) 95 | Return inference output 96 | """ 97 | tokens = torch.from_numpy(tokens).to(device) 98 | pad_CP = torch.tensor(pad_CP).to(device) 99 | attn = torch.all(tokens != pad_CP, dim=2).float().to(device) 100 | 101 | # forward (input, attn, layer idx) 102 | pred = model.forward(tokens, attn, -1) # pred: (batch, seq_len, class_num) 103 | output = np.argmax(pred.cpu().detach().numpy(), axis=-1) # (batch, seq_len) 104 | 105 | return torch.from_numpy(output) 106 | 107 | 108 | def get_melody_events(events, inputs, preds, pad_CP, bridge=True): 109 | """ 110 | Filter out predicted melody events. 111 | Arguments: 112 | - events: complete events, including tempo changes and velocity 113 | - inputs: input compact_CP tokens (batch, seq, CP_class), np.array 114 | - preds: predicted classes (batch, seq), torch.tensor 115 | Note for predictions: 1 is melody, 2 is bridge, 3 is piano/accompaniment 116 | - pad_CP: padded CP representation (list) 117 | - bridge (bool): whether bridge is viewed as melody 118 | """ 119 | numClass = inputs.shape[-1] 120 | inputs = inputs.reshape(-1, numClass) 121 | preds = preds.reshape(-1) 122 | pad_CP = np.array(pad_CP) 123 | 124 | melody_events = [] 125 | note_ind = 0 126 | for event in events: 127 | if len(event) == 5: # filter out melody events 128 | is_melody = preds[note_ind] == 1 or (bridge and preds[note_ind] == 2) 129 | is_valid_note = np.all(inputs[note_ind] != pad_CP) 130 | if is_valid_note and is_melody: 131 | melody_events.append(event) 132 | note_ind += 1 133 | else: 134 | melody_events.append(event) 135 | 136 | return melody_events 137 | 138 | 139 | def events2midi(events, output_path, prompt_path=None): 140 | """ 141 | Given melody events, convert back to midi 142 | """ 143 | temp_notes, temp_tempos = [], [] 144 | 145 | for event in events: 146 | if len(event) == 1: # [Bar] 147 | temp_notes.append('Bar') 148 | temp_tempos.append('Bar') 149 | 150 | elif len(event) == 5: # [Bar, Position, Pitch, Duration, Velocity] 151 | # start time and end time from position 152 | position = int(event[1].value.split('/')[0]) - 1 153 | # pitch 154 | pitch = int(event[2].value) 155 | # duration 156 | index = int(event[3].value) 157 | duration = DEFAULT_DURATION_BINS[index] 158 | # velocity 159 | index = int(event[4].value) 160 | velocity = int(DEFAULT_VELOCITY_BINS[index]) 161 | # adding 162 | temp_notes.append([position, velocity, pitch, duration]) 163 | 164 | else: # [Position, Tempo Class, Tempo Value] 165 | position = int(event[0].value.split('/')[0]) - 1 166 | if event[1].value == 'slow': 167 | tempo = DEFAULT_TEMPO_INTERVALS[0].start + int(event[2].value) 168 | elif event[1].value == 'mid': 169 | tempo = DEFAULT_TEMPO_INTERVALS[1].start + int(event[2].value) 170 | elif event[1].value == 'fast': 171 | tempo = DEFAULT_TEMPO_INTERVALS[2].start + 
int(event[2].value) 172 | temp_tempos.append([position, tempo]) 173 | 174 | # get specific time for notes 175 | ticks_per_beat = DEFAULT_RESOLUTION 176 | ticks_per_bar = DEFAULT_RESOLUTION * 4 # assume 4/4 177 | notes = [] 178 | current_bar = 0 179 | for note in temp_notes: 180 | if note == 'Bar': 181 | current_bar += 1 182 | else: 183 | position, velocity, pitch, duration = note 184 | # position (start time) 185 | current_bar_st = current_bar * ticks_per_bar 186 | current_bar_et = (current_bar + 1) * ticks_per_bar 187 | flags = np.linspace(current_bar_st, current_bar_et, DEFAULT_FRACTION, endpoint=False, dtype=int) 188 | st = flags[position] 189 | # duration (end time) 190 | et = st + duration 191 | notes.append(miditoolkit.Note(velocity, pitch, st, et)) 192 | 193 | # get specific time for tempos 194 | tempos = [] 195 | current_bar = 0 196 | for tempo in temp_tempos: 197 | if tempo == 'Bar': 198 | current_bar += 1 199 | else: 200 | position, value = tempo 201 | # position (start time) 202 | current_bar_st = current_bar * ticks_per_bar 203 | current_bar_et = (current_bar + 1) * ticks_per_bar 204 | flags = np.linspace(current_bar_st, current_bar_et, DEFAULT_FRACTION, endpoint=False, dtype=int) 205 | st = flags[position] 206 | tempos.append([int(st), value]) 207 | # write 208 | if prompt_path: 209 | midi = miditoolkit.midi.parser.MidiFile(prompt_path) 210 | # 211 | last_time = DEFAULT_RESOLUTION * 4 * 4 212 | # note shift 213 | for note in notes: 214 | note.start += last_time 215 | note.end += last_time 216 | midi.instruments[0].notes.extend(notes) 217 | # tempo changes 218 | temp_tempos = [] 219 | for tempo in midi.tempo_changes: 220 | if tempo.time < DEFAULT_RESOLUTION*4*4: 221 | temp_tempos.append(tempo) 222 | else: 223 | break 224 | for st, bpm in tempos: 225 | st += last_time 226 | temp_tempos.append(miditoolkit.midi.containers.TempoChange(bpm, st)) 227 | midi.tempo_changes = temp_tempos 228 | else: 229 | midi = miditoolkit.midi.parser.MidiFile() 230 | midi.ticks_per_beat = DEFAULT_RESOLUTION 231 | # write instrument 232 | inst = miditoolkit.midi.containers.Instrument(0, is_drum=False) 233 | inst.notes = notes 234 | midi.instruments.append(inst) 235 | # write tempo 236 | tempo_changes = [] 237 | for st, bpm in tempos: 238 | tempo_changes.append(miditoolkit.midi.containers.TempoChange(bpm, st)) 239 | midi.tempo_changes = tempo_changes 240 | 241 | # write 242 | midi.dump(output_path) 243 | print(f"predicted melody midi file is saved at {output_path}") 244 | 245 | return 246 | 247 | 248 | def main(): 249 | args = get_args() 250 | with open(args.dict_file, 'rb') as f: 251 | e2w, w2e = pickle.load(f) 252 | 253 | compact_classes = ['Bar', 'Position', 'Pitch', 'Duration'] 254 | pad_CP = [e2w[subclass][f"{subclass} "] for subclass in compact_classes] 255 | 256 | # preprocess input file 257 | CP_model = CP(dict=args.dict_file) 258 | events, tokens = CP_model.prepare_data(args.input_path, args.max_seq_len) # files, task, seq_len 259 | filename = args.input_path.split('/')[-1] 260 | print(f"'{filename}' is preprocessed to CP repr. 
with shape {tokens.shape}") 261 | 262 | # load pre-trained model 263 | model = load_model(args, e2w, w2e) 264 | 265 | # inference 266 | device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else 'cpu') 267 | print("Using", device) 268 | predictions = inference(model, tokens, pad_CP, device) 269 | print(f"predicted melody shape {predictions.shape}") 270 | #np.save("input.npy", tokens) 271 | #np.save("pred.npy", predictions) 272 | 273 | # post-process 274 | melody_events = get_melody_events(events, tokens, predictions, pad_CP, bridge=args.bridge) 275 | print(f"Melody Events: {len(melody_events)}/{len(events)}") 276 | 277 | # save melody midi 278 | melody_midi = events2midi(melody_events, args.output_path) 279 | 280 | 281 | if __name__ == '__main__': 282 | main() 283 | -------------------------------------------------------------------------------- /melody_extraction/midibert/midi2CP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from tqdm import tqdm 4 | import melody_extraction.midibert.utils as utils 5 | 6 | 7 | class CP(object): 8 | def __init__(self, dict): 9 | # load dictionary 10 | self.event2word, self.word2event = pickle.load(open(dict, 'rb')) 11 | # pad word: ['Bar ', 'Position ', 'Pitch ', 'Duration '] 12 | classes = ['Bar', 'Position', 'Pitch', 'Duration'] 13 | self.pad_word = [self.event2word[etype]['%s ' % etype] for etype in classes] 14 | 15 | def extract_events(self, input_path): 16 | note_items, tempo_items = utils.read_items(input_path) 17 | if len(note_items) == 0: # if the midi contains nothing 18 | return None 19 | note_items = utils.quantize_items(note_items) 20 | max_time = note_items[-1].end 21 | items = tempo_items + note_items 22 | 23 | groups = utils.group_items(items, max_time) 24 | events = utils.item2event(groups) 25 | return events 26 | 27 | def padding(self, data, max_len): 28 | pad_len = max_len - len(data) 29 | for _ in range(pad_len): 30 | data.append(self.pad_word) 31 | 32 | return data 33 | 34 | def prepare_data(self, midi_path, max_len): 35 | """ 36 | Prepare data for a single midi 37 | """ 38 | # extract events 39 | events = self.extract_events(midi_path) 40 | if not events: # if midi contains nothing 41 | raise ValueError(f'The given {midi_path} is empty') 42 | 43 | # events to words 44 | # 1. Bar, Position, Pitch, Duration, Velocity ---> we only convert note events to words 45 | # 2. Position, Tempo Style, Tempo Class 46 | # 3. 
Bar 47 | words = [] 48 | 49 | for tup in events: 50 | nts = [] 51 | if len(tup) == 5: # Note 52 | for e in tup: 53 | if e.name == 'Velocity': 54 | continue 55 | e_text = '{} {}'.format(e.name, e.value) 56 | nts.append(self.event2word[e.name][e_text]) 57 | words.append(nts) 58 | 59 | # slice to chunks so that max length = max_len (default: 512) 60 | slice_words = [] 61 | for i in range(0, len(words), max_len): 62 | slice_words.append(words[i:i+max_len]) 63 | 64 | # padding or drop 65 | if len(slice_words[-1]) < max_len: 66 | slice_words[-1] = self.padding(slice_words[-1], max_len) 67 | 68 | slice_words = np.array(slice_words) 69 | 70 | return events, slice_words 71 | -------------------------------------------------------------------------------- /melody_extraction/midibert/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import miditoolkit 3 | import copy 4 | 5 | # parameters for input 6 | DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 32+1, dtype=np.int) 7 | DEFAULT_FRACTION = 16 8 | DEFAULT_DURATION_BINS = np.arange(60, 3841, 60, dtype=int) 9 | DEFAULT_TEMPO_INTERVALS = [range(30, 90), range(90, 150), range(150, 210)] 10 | 11 | # parameters for output 12 | DEFAULT_RESOLUTION = 480 13 | 14 | # define "Item" for general storage 15 | class Item(object): 16 | def __init__(self, name, start, end, velocity, pitch): 17 | self.name = name 18 | self.start = start 19 | self.end = end 20 | self.velocity = velocity 21 | self.pitch = pitch 22 | 23 | def __repr__(self): 24 | return 'Item(name={}, start={}, end={}, velocity={}, pitch={})'.format( 25 | self.name, self.start, self.end, self.velocity, self.pitch) 26 | 27 | # read notes and tempo changes from midi (assume there is only one track) 28 | def read_items(file_path): 29 | midi_obj = miditoolkit.midi.parser.MidiFile(file_path) 30 | # note 31 | note_items = [] 32 | num_of_instr = len(midi_obj.instruments) 33 | 34 | for i in range(num_of_instr): 35 | notes = midi_obj.instruments[i].notes 36 | notes.sort(key=lambda x: (x.start, x.pitch)) 37 | 38 | for note in notes: 39 | note_items.append(Item( 40 | name='Note', 41 | start=note.start, 42 | end=note.end, 43 | velocity=note.velocity, 44 | pitch=note.pitch)) 45 | 46 | note_items.sort(key=lambda x: x.start) 47 | 48 | # tempo 49 | tempo_items = [] 50 | for tempo in midi_obj.tempo_changes: 51 | tempo_items.append(Item( 52 | name='Tempo', 53 | start=tempo.time, 54 | end=None, 55 | velocity=None, 56 | pitch=int(tempo.tempo))) 57 | 58 | tempo_items.sort(key=lambda x: x.start) 59 | 60 | # expand to all beat 61 | max_tick = tempo_items[-1].start 62 | existing_ticks = {item.start: item.pitch for item in tempo_items} 63 | wanted_ticks = np.arange(0, max_tick+1, DEFAULT_RESOLUTION) 64 | output = [] 65 | for tick in wanted_ticks: 66 | if tick in existing_ticks: 67 | output.append(Item( 68 | name='Tempo', 69 | start=tick, 70 | end=None, 71 | velocity=None, 72 | pitch=existing_ticks[tick])) 73 | else: 74 | output.append(Item( 75 | name='Tempo', 76 | start=tick, 77 | end=None, 78 | velocity=None, 79 | pitch=output[-1].pitch)) 80 | tempo_items = output 81 | return note_items, tempo_items 82 | 83 | 84 | class Event(object): 85 | def __init__(self, name, time, value, text): 86 | self.name = name 87 | self.time = time 88 | self.value = value 89 | self.text = text 90 | 91 | def __repr__(self): 92 | return 'Event(name={}, time={}, value={}, text={})'.format( 93 | self.name, self.time, self.value, self.text) 94 | 95 | 96 | def item2event(groups): 97 | events 
= [] 98 | n_downbeat = 0 99 | for i in range(len(groups)): 100 | if 'Note' not in [item.name for item in groups[i][1:-1]]: 101 | continue 102 | bar_st, bar_et = groups[i][0], groups[i][-1] 103 | n_downbeat += 1 104 | new_bar = True 105 | events.append([Event( 106 | name='Bar', 107 | time=None, 108 | value=None, 109 | text='{}'.format(n_downbeat))]) 110 | 111 | for item in groups[i][1:-1]: 112 | if item.name == 'Note': 113 | note_tuple = [] 114 | # Bar 115 | if new_bar: 116 | BarValue = 'New' 117 | new_bar = False 118 | else: 119 | BarValue = "Continue" 120 | note_tuple.append(Event( 121 | name='Bar', 122 | time=None, 123 | value=BarValue, 124 | text='{}'.format(n_downbeat))) 125 | 126 | # Position 127 | flags = np.linspace(bar_st, bar_et, DEFAULT_FRACTION, endpoint=False) 128 | index = np.argmin(abs(flags-item.start)) 129 | note_tuple.append(Event( 130 | name='Position', 131 | time=item.start, 132 | value='{}/{}'.format(index+1, DEFAULT_FRACTION), 133 | text='{}'.format(item.start))) 134 | 135 | # Pitch 136 | note_tuple.append(Event( 137 | name='Pitch', 138 | time=item.start, 139 | value=item.pitch, 140 | text='{}'.format(item.pitch))) 141 | 142 | # Duration 143 | duration = item.end - item.start 144 | index = np.argmin(abs(DEFAULT_DURATION_BINS-duration)) 145 | note_tuple.append(Event( 146 | name='Duration', 147 | time=item.start, 148 | value=index, 149 | text='{}/{}'.format(duration, DEFAULT_DURATION_BINS[index]))) 150 | 151 | # Velocity 152 | velocity_index = np.searchsorted(DEFAULT_VELOCITY_BINS, item.velocity, side='right') - 1 153 | note_tuple.append(Event( 154 | name='Velocity', 155 | time=item.start, 156 | value=velocity_index, 157 | text='{}/{}'.format(item.velocity, DEFAULT_VELOCITY_BINS[velocity_index]))) 158 | 159 | events.append(note_tuple) 160 | 161 | elif item.name == 'Tempo': 162 | # Position 163 | flags = np.linspace(bar_st, bar_et, DEFAULT_FRACTION, endpoint=False) 164 | index = np.argmin(abs(flags-item.start)) 165 | position = Event( 166 | name='Position', 167 | time=item.start, 168 | value='{}/{}'.format(index+1, DEFAULT_FRACTION), 169 | text='{}'.format(item.start)) 170 | # tempo 171 | tempo = item.pitch 172 | if tempo in DEFAULT_TEMPO_INTERVALS[0]: 173 | tempo_style = Event('Tempo Class', item.start, 'slow', None) 174 | tempo_value = Event('Tempo Value', item.start, 175 | tempo-DEFAULT_TEMPO_INTERVALS[0].start, None) 176 | elif tempo in DEFAULT_TEMPO_INTERVALS[1]: 177 | tempo_style = Event('Tempo Class', item.start, 'mid', None) 178 | tempo_value = Event('Tempo Value', item.start, 179 | tempo-DEFAULT_TEMPO_INTERVALS[1].start, None) 180 | elif tempo in DEFAULT_TEMPO_INTERVALS[2]: 181 | tempo_style = Event('Tempo Class', item.start, 'fast', None) 182 | tempo_value = Event('Tempo Value', item.start, 183 | tempo-DEFAULT_TEMPO_INTERVALS[2].start, None) 184 | elif tempo < DEFAULT_TEMPO_INTERVALS[0].start: 185 | tempo_style = Event('Tempo Class', item.start, 'slow', None) 186 | tempo_value = Event('Tempo Value', item.start, 0, None) 187 | elif tempo > DEFAULT_TEMPO_INTERVALS[2].stop: 188 | tempo_style = Event('Tempo Class', item.start, 'fast', None) 189 | tempo_value = Event('Tempo Value', item.start, 59, None) 190 | 191 | tempo_tuple = [position, tempo_style, tempo_value] 192 | events.append(tempo_tuple) 193 | 194 | return events 195 | 196 | 197 | def quantize_items(items, ticks=120): 198 | grids = np.arange(0, items[-1].start, ticks, dtype=int) 199 | # process 200 | for item in items: 201 | index = np.argmin(abs(grids - item.start)) 202 | shift = grids[index] - item.start 
203 | item.start += shift 204 | item.end += shift 205 | return items 206 | 207 | 208 | def group_items(items, max_time, ticks_per_bar=DEFAULT_RESOLUTION*4): 209 | items.sort(key=lambda x: x.start) 210 | downbeats = np.arange(0, max_time+ticks_per_bar, ticks_per_bar) 211 | groups = [] 212 | for db1, db2 in zip(downbeats[:-1], downbeats[1:]): 213 | insiders = [] 214 | for item in items: 215 | if (item.start >= db1) and (item.start < db2): 216 | insiders.append(item) 217 | overall = [db1] + insiders + [db2] 218 | groups.append(overall) 219 | return groups 220 | -------------------------------------------------------------------------------- /melody_extraction/pianoroll/018_pianoroll.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/pianoroll/018_pianoroll.png -------------------------------------------------------------------------------- /melody_extraction/pianoroll/596_pianoroll.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/pianoroll/596_pianoroll.png -------------------------------------------------------------------------------- /melody_extraction/pianoroll/README.md: -------------------------------------------------------------------------------- 1 | # Generate Pianoroll 2 | 3 | The pianoroll diagram can give us a more intuitive understanding of each melody-extraction method. 4 | 5 | ## Ground Truth 6 | Use the following code to generate CP data for the ground truth and MidiBERT. 7 | ```python 8 | import os, numpy as np 9 | from data_creation.prepare_data.model import * 10 | 11 | dir_path = "melody_extraction/pianoroll/audio/" # modify to your own directory path 12 | song_path = os.path.join(dir_path, "018.mid") # modify the filename 13 | output_name = "pop909_018" # modify the output name 14 | 15 | songs = [song_path] 16 | model = CP(dict="data_creation/prepare_data/dict/CP.pkl") 17 | word, ys = model.prepare_data(songs, 'melody', 512) 18 | np.save(os.path.join(dir_path, output_name + "_cp.npy"), word) 19 | np.save(os.path.join(dir_path, output_name + "_groundtruth.npy"), ys) 20 | ``` 21 | 22 | Set the data path and mode in `setting.py`, and run the script from the main directory of MidiBERT via `bash scripts/pianoroll.sh`. 23 | 24 | ## MidiBERT-CP 25 | 1. Get the prediction by running `MidiBERT/eval.py` via `scripts/eval.sh`. You need to modify the code to change the test data path and to save the prediction. 26 | 2. Run the script `bash scripts/pianoroll.sh` and make sure the settings are correct. 27 | 28 | ## Skyline 29 | ### Skyline from Midi 30 | Simply uncomment the corresponding code in `plot.py` and modify the midi file's name. 31 | ```python 32 | sky_melody, sky_accomp = skyline_from_midi("018.mid") 33 | plot_roll([sky_melody, sky_accomp], 48*4*4, 'skyline') 34 | ``` 35 | ### Skyline from CP 36 | To get a fair comparison with our MidiBERT, we use the CP data to align the skyline-generated melody. After generating the CP data, set the mode to `skyline` and run the script `bash scripts/pianoroll.sh`; a condensed sketch of this mode is shown below.
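For reference, the `skyline` mode is handled inside `plot.py`'s `main()` (the full file appears later in this folder). Condensed, and assuming `setting.data_path` points at the CP `.npy` produced in the Ground Truth step above, it does the following:

```python
# condensed from main() in melody_extraction/pianoroll/plot.py (normally run via scripts/pianoroll.sh)
data = np.load(setting.data_path)                        # CP tokens saved in the Ground Truth step
sky_melody, sky_accomp = skyline_from_cp(data)           # rebuild notes from CP, then apply skyline
plot_roll([sky_melody, sky_accomp], 48*4*4, 'skyline')   # 48 bars x 16 pixels per bar, saved as pianoroll_skyline
```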
37 | 38 | -------------------------------------------------------------------------------- /melody_extraction/pianoroll/analyzer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | 5 | import miditoolkit 6 | from miditoolkit.midi import parser as mid_parser 7 | from miditoolkit.pianoroll import parser as pr_parser 8 | from miditoolkit.midi.containers import Marker, Instrument, TempoChange 9 | 10 | from chorder import Dechorder 11 | 12 | 13 | num2pitch = { 14 | 0: 'C', 15 | 1: 'C#', 16 | 2: 'D', 17 | 3: 'D#', 18 | 4: 'E', 19 | 5: 'F', 20 | 6: 'F#', 21 | 7: 'G', 22 | 8: 'G#', 23 | 9: 'A', 24 | 10: 'A#', 25 | 11: 'B', 26 | } 27 | 28 | def traverse_dir( 29 | root_dir, 30 | extension=('mid', 'MID', 'midi'), 31 | amount=None, 32 | str_=None, 33 | is_pure=False, 34 | verbose=False, 35 | is_sort=False, 36 | is_ext=True): 37 | if verbose: 38 | print('[*] Scanning...') 39 | file_list = [] 40 | cnt = 0 41 | for root, _, files in os.walk(root_dir): 42 | for file in files: 43 | if file.endswith(extension): 44 | if (amount is not None) and (cnt == amount): 45 | break 46 | if str_ is not None: 47 | if str_ not in file: 48 | continue 49 | mix_path = os.path.join(root, file) 50 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 51 | if not is_ext: 52 | ext = pure_path.split('.')[-1] 53 | pure_path = pure_path[:-(len(ext)+1)] 54 | if verbose: 55 | print(pure_path) 56 | file_list.append(pure_path) 57 | cnt += 1 58 | if verbose: 59 | print('Total: %d files' % len(file_list)) 60 | print('Done!!!') 61 | if is_sort: 62 | file_list.sort() 63 | return file_list 64 | 65 | 66 | def quantize_melody(notes, tick_resol=240): 67 | melody_notes = [] 68 | for note in notes: 69 | # cut too long notes 70 | if note.end - note.start > tick_resol * 8: 71 | note.end = note.start + tick_resol * 4 72 | 73 | # quantize 74 | note.start = int(np.round(note.start / tick_resol) * tick_resol) 75 | note.end = int(np.round(note.end / tick_resol) * tick_resol) 76 | 77 | # append 78 | melody_notes.append(note) 79 | return melody_notes 80 | 81 | 82 | def extract_melody(notes): 83 | # quantize 84 | melody_notes = quantize_melody(notes) 85 | # print(melody_notes[0].start, melody_notes[0].pitch) 86 | 87 | # sort by start, pitch from high to low 88 | melody_notes.sort(key=lambda x: (x.start, -x.pitch)) 89 | print("melody_notes: ", len(melody_notes)) 90 | 91 | # exclude notes < 60 92 | bins = [] 93 | prev = None 94 | tmp_list = [] 95 | for nidx in range(len(melody_notes)): 96 | note = melody_notes[nidx] 97 | if note.pitch >= 60: 98 | if note.start != prev: 99 | if tmp_list: 100 | bins.append(tmp_list) 101 | tmp_list = [note] 102 | else: 103 | tmp_list.append(note) 104 | prev = note.start 105 | 106 | # preserve only highest one at each step 107 | notes_out = [] 108 | for b in bins: 109 | notes_out.append(b[0]) 110 | 111 | # avoid overlapping 112 | notes_out.sort(key=lambda x:x.start) 113 | for idx in range(len(notes_out) - 1): 114 | if notes_out[idx].end >= notes_out[idx+1].start: 115 | notes_out[idx].end = notes_out[idx+1].start 116 | 117 | print("notes_out: ", len(notes_out)) 118 | # delete note having no duration 119 | notes_clean = [] 120 | for note in notes_out: 121 | if note.start != note.end: 122 | notes_clean.append(note) 123 | 124 | # filtered by interval 125 | notes_final = [notes_clean[0]] 126 | for i in range(1, len(notes_clean) -1): 127 | if ((notes_clean[i].pitch - notes_clean[i-1].pitch) <= -9) and \ 128 | ((notes_clean[i].pitch 
- notes_clean[i+1].pitch) <= -9): 129 | continue 130 | else: 131 | notes_final.append(notes_clean[i]) 132 | notes_final += [notes_clean[-1]] 133 | print("notes_final: ", len(notes_final)) 134 | return notes_final 135 | 136 | 137 | '''def proc_one(path_infile, path_outfile): 138 | # load 139 | midi_obj = miditoolkit.midi.parser.MidiFile(path_infile) 140 | midi_obj_out = copy.deepcopy(midi_obj) 141 | notes = midi_obj.instruments[0].notes 142 | notes = sorted(notes, key=lambda x: (x.start, x.pitch)) 143 | 144 | # --- chord --- # 145 | # exctract chord 146 | chords = Dechorder.dechord(midi_obj) 147 | markers = [] 148 | for cidx, chord in enumerate(chords): 149 | if chord.is_complete(): 150 | chord_text = num2pitch[chord.root_pc] + '_' + chord.quality + '_' + num2pitch[chord.bass_pc] 151 | else: 152 | chord_text = 'N_N_N' 153 | markers.append(Marker(time=int(cidx*480), text=chord_text)) 154 | 155 | # dedup 156 | prev_chord = None 157 | dedup_chords = [] 158 | for m in markers: 159 | if m.text != prev_chord: 160 | prev_chord = m.text 161 | dedup_chords.append(m) 162 | 163 | # --- structure --- # 164 | # structure analysis 165 | bounds, labs = segmenter.proc_midi(path_infile) 166 | bounds = np.round(bounds / 4) 167 | bounds = np.unique(bounds) 168 | print(' > [structure] bars:', bounds) 169 | print(' > [structure] labs:', labs) 170 | 171 | bounds_marker = [] 172 | for i in range(len(labs)): 173 | b = bounds[i] 174 | l = int(labs[i]) 175 | bounds_marker.append( 176 | Marker(time=int(b*4*480), text='Boundary_'+str(l))) 177 | 178 | # --- melody --- # 179 | melody_notes = extract_melody(notes) 180 | melody_notes = quantize_melody(melody_notes) 181 | 182 | # --- global properties --- # 183 | # global tempo 184 | tempos = [b.tempo for b in midi_obj.tempo_changes][:40] 185 | tempo_median = np.median(tempos) 186 | global_bpm =int(tempo_median) 187 | print(' > [global] bpm:', global_bpm) 188 | 189 | # === save === # 190 | # mkdir 191 | fn = os.path.basename(path_outfile) 192 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 193 | 194 | # save piano (0) and melody (1) 195 | melody_track = Instrument(program=0, is_drum=False, name='melody') 196 | melody_track.notes = melody_notes 197 | midi_obj_out.instruments.append(melody_track) 198 | 199 | # markers 200 | midi_obj_out.markers = dedup_chords + bounds_marker 201 | midi_obj_out.markers.insert(0, Marker(text='global_bpm_'+str(int(global_bpm)), time=0)) 202 | 203 | # save 204 | midi_obj_out.instruments[0].name = 'piano' 205 | midi_obj_out.dump(path_outfile) 206 | ''' 207 | 208 | if __name__ == '__main__': 209 | # paths 210 | path_indir = './midi_synchronized' 211 | path_outdir = './midi_analyzed' 212 | os.makedirs(path_outdir, exist_ok=True) 213 | 214 | # list files 215 | midifiles = traverse_dir( 216 | path_indir, 217 | is_pure=True, 218 | is_sort=True) 219 | n_files = len(midifiles) 220 | print('num fiels:', n_files) 221 | 222 | # run 223 | for fidx in range(n_files): 224 | path_midi = midifiles[fidx] 225 | print('{}/{}'.format(fidx, n_files)) 226 | 227 | # paths 228 | path_infile = os.path.join(path_indir, path_midi) 229 | path_outfile = os.path.join(path_outdir, path_midi) 230 | 231 | # proc 232 | #proc_one(path_infile, path_outfile) 233 | -------------------------------------------------------------------------------- /melody_extraction/pianoroll/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as 
plt 6 | import matplotlib.patches as mpatches 7 | from matplotlib.colors import colorConverter 8 | import miditoolkit.midi.parser as midparser 9 | import setting 10 | 11 | from analyzer import extract_melody 12 | from melody_extraction.midibert.utils import DEFAULT_VELOCITY_BINS, DEFAULT_FRACTION, DEFAULT_DURATION_BINS, DEFAULT_TEMPO_INTERVALS, DEFAULT_RESOLUTION 13 | 14 | e2w, w2e = pickle.load(open("data_creation/prepare_data/dict/CP.pkl", "rb")) 15 | keys = list(w2e.keys()) 16 | ans_dict = { 17 | 0: "padding", 18 | 1: "melody", # melody & bridge 19 | 2: "non-melody" # piano 20 | } 21 | bar_tick = 16 22 | seg_len = 512 23 | 24 | # DEFAULT_RESOLUTION = 480 25 | # DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 64+1, dtype=np.int) 26 | ticks_per_beat = 480 27 | ticks_per_bar = DEFAULT_RESOLUTION * 4 # assume 4/4 28 | 29 | class NOTE(object): 30 | def __init__(self, start, end, pitch, velocity, Type): 31 | self.velocity = velocity 32 | self.pitch = pitch 33 | self.start = start 34 | self.end = end 35 | self.Type = Type 36 | def get_duration(self): 37 | return self.end - self.start 38 | def __repr__(self): 39 | return f'Note(start={self.start}, end={self.end}, pitch={self.pitch}, velocity={self.velocity}, Type={self.Type})' 40 | 41 | def make_note_dict(): 42 | note_dict = {} 43 | note = ['C', 'D', 'E', 'F', 'G', 'A', 'B'] 44 | note_dict[21] = 'A0' 45 | note_dict[23] = 'B0' 46 | base_number = 24 47 | for i in range(1, 10): 48 | for n in note: 49 | note_dict[base_number] = n + str(i) 50 | if n == 'E' or n == 'B': 51 | base_number += 1 52 | else: 53 | base_number += 2 54 | return note_dict 55 | 56 | def events2notes(events, verbose=True): 57 | notes = [] 58 | if len(events.shape) > 2: 59 | events = events.reshape(-1, events.shape[-1]) 60 | 61 | current_bar = -1 62 | for event in events: 63 | # 2.1 Turn CP token to events 64 | # [Bar, Position, Pitch, Velocity, Duration, Tempo] 65 | bar = w2e['Bar'][event[0]] 66 | pos = w2e['Position'][event[1]] 67 | pitch = w2e['Pitch'][event[2]] 68 | velocity = w2e['Velocity'][event[3]] 69 | duration = w2e['Duration'][event[4]] 70 | 71 | if bar == "Bar ": 72 | continue 73 | if verbose: 74 | print(bar, pos, pitch, velocity, duration) 75 | # 2.2 Turn events to value 76 | pos = int(pos.split(" ")[1].split("/")[0]) - 1 77 | pitch = int(pitch.split(" ")[1]) 78 | duration = DEFAULT_DURATION_BINS[int(duration.split(" ")[1])] 79 | 80 | # 2.3 Turn value to note 81 | if bar == "Bar New": 82 | current_bar += 1 83 | current_bar_st = current_bar * ticks_per_bar 84 | current_bar_et = (current_bar + 1) * ticks_per_bar 85 | flags = np.linspace(current_bar_st, current_bar_et, DEFAULT_FRACTION, endpoint=False, dtype=int) 86 | st = flags[pos] 87 | et = st + duration 88 | notes.append(NOTE(st, et, pitch, velocity, 0)) 89 | if verbose: 90 | print(st, et, pitch, velocity) 91 | return notes 92 | 93 | 94 | def data2pr(data, ans, max_bar, verbose=False): 95 | 96 | # 24 bars = 24 * 4 (beat)* 4(semiquater) = 384 97 | melody_pr = np.zeros((128, max_bar * 16)) 98 | accomp_pr = np.zeros((128, max_bar * 16)) 99 | 100 | bars = -1 101 | for i in range(data.shape[0]): # segment 102 | for j in range(data.shape[1]): # note 103 | # [Bar, Position, Pitch, Velocity, Duration, Tempo] 104 | if data[i][j][0] == 0: 105 | bars += 1 106 | if bars >= max_bar: 107 | break 108 | 109 | duration = data[i][j][4] # +1 110 | pitch = data[i][j][2] + 22 111 | pos = data[i][j][1] 112 | 113 | start = bars * bar_tick + pos 114 | duration = DEFAULT_DURATION_BINS[duration] 115 | duration = duration // 120 116 | end = 
min(start + duration, max_bar*16 - 1) 117 | 118 | if ans[i][j] == 1: 119 | melody_pr[pitch, start: end] = np.ones((1, end - start)) 120 | elif ans[i][j] == 2: 121 | accomp_pr[pitch, start: end] = np.ones((1, end - start)) 122 | if verbose: 123 | print("{} {} {} {}".format(pos, pitch, duration, ans_dict[ans[i][j]])) 124 | if bars >= max_bar: 125 | break 126 | return melody_pr, accomp_pr 127 | 128 | def plot_roll(pianoroll, pixels, filename): 129 | # pianoroll[idx, msg['pitch'], note_on_start_time:note_on_end_time] = intensity # idx => 0: melody, 1: non-melody 130 | 131 | # build and set figure object 132 | plt.ioff() 133 | fig = plt.figure(figsize=(17, 11)) 134 | a1 = fig.add_subplot(111) 135 | a1.axis("equal") 136 | a1.set_facecolor("white") 137 | a1.set_axisbelow(True) 138 | a1.yaxis.grid(color='gray', linestyle='dashed') 139 | 140 | # set colors 141 | channel_nb = 2 142 | transparent = colorConverter.to_rgba('white') 143 | colors = [mpl.colors.to_rgba('lightcoral'), mpl.colors.to_rgba('cornflowerblue')] 144 | cmaps = [mpl.colors.LinearSegmentedColormap.from_list('my_cmap', [transparent, colors[i]], 128) for i in range(channel_nb)] 145 | 146 | # build color maps 147 | for i in range(channel_nb): 148 | cmaps[i]._init() 149 | # create your alpha array and fill the colormap with them 150 | alphas = np.linspace(0, 1, cmaps[i].N + 3) 151 | # create the _lut array, with rgba value 152 | cmaps[i]._lut[:, -1] = alphas 153 | 154 | label_name = ['melody', 'non-melody'] 155 | a1.imshow(pianoroll[1], origin="lower", interpolation="nearest", cmap=cmaps[1], aspect='auto', label=label_name[1]) 156 | a1.imshow(pianoroll[0], origin="lower", interpolation="nearest", cmap=cmaps[0], aspect='auto', label=label_name[0]) 157 | note_dict = make_note_dict() 158 | 159 | # set scale and limit of axis 160 | interval = 64 161 | plt.xticks([i*interval for i in range(13)], [i*4 for i in range(13)]) 162 | plt.yticks([(24 + (y)*12) for y in range(8)], [note_dict[24 + (y)*12] for y in range(8)]) 163 | plt.ylim([36, 96]) # C2 to C8 164 | 165 | # show legend, and create a patch (proxy artist) for every color 166 | patches = [ mpatches.Patch(color=colors[i], label=label_name[i] ) for i in range(channel_nb) ] 167 | # put those patched as legend-handles into the legend 168 | first_legend = plt.legend(handles=[patches[0]], loc=2, fontsize=40) 169 | ax = plt.gca().add_artist(first_legend) 170 | plt.legend(handles=[patches[1]], loc=1, fontsize=40) 171 | 172 | # save pianoroll to figure 173 | plt.xticks(fontsize=40) 174 | plt.yticks(fontsize=40) 175 | plt.xlabel("bars", fontsize=40) 176 | plt.ylabel("note name", fontsize=40) 177 | 178 | plt.savefig('pianoroll_' + filename) 179 | return 180 | 181 | def skyline_from_midi(file): 182 | obj = midparser.MidiFile(file) 183 | melody = obj.instruments[0].notes 184 | bridge = obj.instruments[1].notes 185 | piano = obj.instruments[2].notes 186 | melody = [NOTE(i.start, i.end, i.pitch, i.velocity, 0) for i in melody] 187 | bridge = [NOTE(i.start, i.end, i.pitch, i.velocity, 1) for i in bridge] 188 | piano = [NOTE(i.start, i.end, i.pitch, i.velocity, 2) for i in piano] 189 | all_notes = [] 190 | all_notes.extend(melody) 191 | all_notes.extend(bridge) 192 | all_notes.extend(piano) 193 | all_notes.sort(key=lambda x:(x.start)) 194 | pred_m = extract_melody(all_notes) 195 | print(pred_m, len(pred_m)) 196 | melody_pr = notes_to_pianoroll(pred_m, 48*4*4) 197 | accomp_pr = notes_to_pianoroll(all_notes, 48*4*4) 198 | return melody_pr, accomp_pr 199 | 200 | def skyline_from_cp(data): 201 | """ 202 | 
2023/11/26 203 | """ 204 | notes = events2notes(data, verbose=False) 205 | pred_m = extract_melody(notes) 206 | 207 | melody_pr = notes_to_pianoroll(pred_m, 48*4*4) # 48 bars * 4 beats per bar * 4 pixels per beat 208 | accomp_pr = notes_to_pianoroll(notes, 48*4*4) 209 | return melody_pr, accomp_pr 210 | 211 | 212 | def notes_to_pianoroll(notes, length): 213 | melody_pr = np.zeros((128, length)) 214 | 215 | for i, note in enumerate(notes): 216 | s, e = int(note.start/120), int(note.end/120) # 4 pixel per beat 217 | print(s, e) 218 | if s > length: 219 | break 220 | e = min(length, e) 221 | melody_pr[note.pitch, s:e] = np.ones((1, e-s)) 222 | 223 | return melody_pr 224 | 225 | def main(): 226 | data = np.load(setting.data_path) 227 | if 'gt' in setting.mode: 228 | # ground truth 229 | ans = np.load(setting.ground_truth_path) 230 | melody_pr, accomp_pr = data2pr(data, ans, 48, verbose=True) 231 | plot_roll([melody_pr, accomp_pr], 48*4*4, 'gt') 232 | print("-------") 233 | 234 | if 'skyline' in setting.mode: 235 | # skyline_from_midi 236 | # sky_melody, sky_accomp = skyline_from_midi("018.mid") 237 | # print(sky_melody.shape, sky_accomp.shape) 238 | # plot_roll([sky_melody, sky_accomp], 48*4*4, 'skyline') 239 | 240 | # skyline_from_cp 2023/11/26 241 | sky_melody, sky_accomp = skyline_from_cp(data) 242 | plot_roll([sky_melody, sky_accomp], 48*4*4, 'skyline') 243 | print("-------") 244 | 245 | if 'bert' in setting.mode: 246 | # bert 247 | bert_pred = np.load(setting.bert_ans_path) 248 | melody_pr, accomp_pr = data2pr(data, bert_pred, 48, verbose=True) 249 | plot_roll([melody_pr, accomp_pr], 48*4*4, 'bert') 250 | 251 | 252 | if __name__ == '__main__': 253 | main() 254 | -------------------------------------------------------------------------------- /melody_extraction/pianoroll/setting.py: -------------------------------------------------------------------------------- 1 | # settings for data 2 | dir_path = 'melody_extraction/pianoroll/data/' 3 | data_path = dir_path + 'pop909_596_cp_6.npy' 4 | ground_truth_path = dir_path + 'pop909_596_groundtruth_6.npy' 5 | bert_ans_path = dir_path + 'pop909_596_bert_6.npy' 6 | 7 | # indicate which pianoroll images we want to generate 8 | mode = ['gt', 'skyline', 'bert'] 9 | -------------------------------------------------------------------------------- /melody_extraction/skyline/README.md: -------------------------------------------------------------------------------- 1 | # Skyline 2 | 3 | Get the accuracy on POP909 using the skyline algorithm: 4 | ``` 5 | cd melody_extraction/skyline 6 | python3 cal_acc.py 7 | ``` 8 | 9 | POP909 labels notes as *melody*, *bridge*, or *accompaniment*, but the skyline algorithm cannot distinguish melody from bridge. 10 | 11 | There are two ways to report its accuracy (a sketch of the binarization is given below): 12 | 13 | 1. Consider *Bridge* as *Accompaniment*, which attains 78.54% accuracy 14 | 2. Consider *Bridge* as *Melody*, which attains 79.51% accuracy 15 | 16 | Special thanks to Wen-Yi Hsiao for providing the code for the skyline algorithm.
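The two reporting modes above differ only in how the *bridge* label is binarized before comparing with the skyline output. The snippet below is an illustrative sketch rather than the actual `cal_acc.py`; it assumes per-note ground-truth labels in the 4-class convention used elsewhere in this repo (0 = padding, 1 = melody, 2 = bridge, 3 = accompaniment) and a hypothetical boolean array `pred_is_melody` giving the skyline prediction for each note:

```python
import numpy as np

def melody_accuracy(gt_labels, pred_is_melody, bridge_as_melody=True):
    """Note-level accuracy of a binary melody / non-melody prediction.

    gt_labels:      int array, 0=padding, 1=melody, 2=bridge, 3=accompaniment
    pred_is_melody: bool array, True where the skyline algorithm predicts melody
    """
    mask = gt_labels != 0                                # ignore padded positions
    if bridge_as_melody:
        target = (gt_labels == 1) | (gt_labels == 2)     # way 2: bridge counted as melody
    else:
        target = gt_labels == 1                          # way 1: bridge counted as accompaniment
    return float((target[mask] == pred_is_melody[mask]).mean())
```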
17 | -------------------------------------------------------------------------------- /melody_extraction/skyline/__pycache__/analyzer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/melody_extraction/skyline/__pycache__/analyzer.cpython-36.pyc -------------------------------------------------------------------------------- /melody_extraction/skyline/analyzer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | 5 | import miditoolkit 6 | from miditoolkit.midi import parser as mid_parser 7 | from miditoolkit.pianoroll import parser as pr_parser 8 | from miditoolkit.midi.containers import Marker, Instrument, TempoChange 9 | 10 | from chorder import Dechorder 11 | 12 | 13 | num2pitch = { 14 | 0: 'C', 15 | 1: 'C#', 16 | 2: 'D', 17 | 3: 'D#', 18 | 4: 'E', 19 | 5: 'F', 20 | 6: 'F#', 21 | 7: 'G', 22 | 8: 'G#', 23 | 9: 'A', 24 | 10: 'A#', 25 | 11: 'B', 26 | } 27 | 28 | def traverse_dir( 29 | root_dir, 30 | extension=('mid', 'MID', 'midi'), 31 | amount=None, 32 | str_=None, 33 | is_pure=False, 34 | verbose=False, 35 | is_sort=False, 36 | is_ext=True): 37 | if verbose: 38 | print('[*] Scanning...') 39 | file_list = [] 40 | cnt = 0 41 | for root, _, files in os.walk(root_dir): 42 | for file in files: 43 | if file.endswith(extension): 44 | if (amount is not None) and (cnt == amount): 45 | break 46 | if str_ is not None: 47 | if str_ not in file: 48 | continue 49 | mix_path = os.path.join(root, file) 50 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path 51 | if not is_ext: 52 | ext = pure_path.split('.')[-1] 53 | pure_path = pure_path[:-(len(ext)+1)] 54 | if verbose: 55 | print(pure_path) 56 | file_list.append(pure_path) 57 | cnt += 1 58 | if verbose: 59 | print('Total: %d files' % len(file_list)) 60 | print('Done!!!') 61 | if is_sort: 62 | file_list.sort() 63 | return file_list 64 | 65 | 66 | def quantize_melody(notes, tick_resol=240): 67 | melody_notes = [] 68 | for note in notes: 69 | # cut too long notes 70 | if note.end - note.start > tick_resol * 8: 71 | note.end = note.start + tick_resol * 4 72 | 73 | # quantize 74 | note.start = int(np.round(note.start / tick_resol) * tick_resol) 75 | note.end = int(np.round(note.end / tick_resol) * tick_resol) 76 | 77 | # append 78 | melody_notes.append(note) 79 | return melody_notes 80 | 81 | 82 | def extract_melody(notes): 83 | # quantize 84 | melody_notes = quantize_melody(notes) 85 | 86 | # sort by start, pitch from high to low 87 | melody_notes.sort(key=lambda x: (x.start, -x.pitch)) 88 | 89 | # exclude notes < 60 90 | bins = [] 91 | prev = None 92 | tmp_list = [] 93 | for nidx in range(len(melody_notes)): 94 | note = melody_notes[nidx] 95 | if note.pitch >= 60: 96 | if note.start != prev: 97 | if tmp_list: 98 | bins.append(tmp_list) 99 | tmp_list = [note] 100 | else: 101 | tmp_list.append(note) 102 | prev = note.start 103 | 104 | # preserve only highest one at each step 105 | notes_out = [] 106 | for b in bins: 107 | notes_out.append(b[0]) 108 | 109 | # avoid overlapping 110 | notes_out.sort(key=lambda x:x.start) 111 | for idx in range(len(notes_out) - 1): 112 | if notes_out[idx].end >= notes_out[idx+1].start: 113 | notes_out[idx].end = notes_out[idx+1].start 114 | 115 | # delete note having no duration 116 | notes_clean = [] 117 | for note in notes_out: 118 | if note.start != note.end: 119 | 
notes_clean.append(note) 120 | 121 | # filtered by interval 122 | notes_final = [notes_clean[0]] 123 | for i in range(1, len(notes_clean) -1): 124 | if ((notes_clean[i].pitch - notes_clean[i-1].pitch) <= -9) and \ 125 | ((notes_clean[i].pitch - notes_clean[i+1].pitch) <= -9): 126 | continue 127 | else: 128 | notes_final.append(notes_clean[i]) 129 | notes_final += [notes_clean[-1]] 130 | return notes_final 131 | 132 | 133 | '''def proc_one(path_infile, path_outfile): 134 | # load 135 | midi_obj = miditoolkit.midi.parser.MidiFile(path_infile) 136 | midi_obj_out = copy.deepcopy(midi_obj) 137 | notes = midi_obj.instruments[0].notes 138 | notes = sorted(notes, key=lambda x: (x.start, x.pitch)) 139 | 140 | # --- chord --- # 141 | # exctract chord 142 | chords = Dechorder.dechord(midi_obj) 143 | markers = [] 144 | for cidx, chord in enumerate(chords): 145 | if chord.is_complete(): 146 | chord_text = num2pitch[chord.root_pc] + '_' + chord.quality + '_' + num2pitch[chord.bass_pc] 147 | else: 148 | chord_text = 'N_N_N' 149 | markers.append(Marker(time=int(cidx*480), text=chord_text)) 150 | 151 | # dedup 152 | prev_chord = None 153 | dedup_chords = [] 154 | for m in markers: 155 | if m.text != prev_chord: 156 | prev_chord = m.text 157 | dedup_chords.append(m) 158 | 159 | # --- structure --- # 160 | # structure analysis 161 | bounds, labs = segmenter.proc_midi(path_infile) 162 | bounds = np.round(bounds / 4) 163 | bounds = np.unique(bounds) 164 | print(' > [structure] bars:', bounds) 165 | print(' > [structure] labs:', labs) 166 | 167 | bounds_marker = [] 168 | for i in range(len(labs)): 169 | b = bounds[i] 170 | l = int(labs[i]) 171 | bounds_marker.append( 172 | Marker(time=int(b*4*480), text='Boundary_'+str(l))) 173 | 174 | # --- melody --- # 175 | melody_notes = extract_melody(notes) 176 | melody_notes = quantize_melody(melody_notes) 177 | 178 | # --- global properties --- # 179 | # global tempo 180 | tempos = [b.tempo for b in midi_obj.tempo_changes][:40] 181 | tempo_median = np.median(tempos) 182 | global_bpm =int(tempo_median) 183 | print(' > [global] bpm:', global_bpm) 184 | 185 | # === save === # 186 | # mkdir 187 | fn = os.path.basename(path_outfile) 188 | os.makedirs(path_outfile[:-len(fn)], exist_ok=True) 189 | 190 | # save piano (0) and melody (1) 191 | melody_track = Instrument(program=0, is_drum=False, name='melody') 192 | melody_track.notes = melody_notes 193 | midi_obj_out.instruments.append(melody_track) 194 | 195 | # markers 196 | midi_obj_out.markers = dedup_chords + bounds_marker 197 | midi_obj_out.markers.insert(0, Marker(text='global_bpm_'+str(int(global_bpm)), time=0)) 198 | 199 | # save 200 | midi_obj_out.instruments[0].name = 'piano' 201 | midi_obj_out.dump(path_outfile) 202 | ''' 203 | 204 | if __name__ == '__main__': 205 | # paths 206 | path_indir = './midi_synchronized' 207 | path_outdir = './midi_analyzed' 208 | os.makedirs(path_outdir, exist_ok=True) 209 | 210 | # list files 211 | midifiles = traverse_dir( 212 | path_indir, 213 | is_pure=True, 214 | is_sort=True) 215 | n_files = len(midifiles) 216 | print('num fiels:', n_files) 217 | 218 | # run 219 | for fidx in range(n_files): 220 | path_midi = midifiles[fidx] 221 | print('{}/{}'.format(fidx, n_files)) 222 | 223 | # paths 224 | path_infile = os.path.join(path_indir, path_midi) 225 | path_outfile = os.path.join(path_outdir, path_midi) 226 | 227 | # proc 228 | #proc_one(path_infile, path_outfile) 229 | -------------------------------------------------------------------------------- /melody_extraction/skyline/cal_acc.py: 
-------------------------------------------------------------------------------- 1 | from analyzer import extract_melody 2 | import glob 3 | from operator import itemgetter 4 | import miditoolkit.midi.parser as midparser 5 | 6 | class NOTE(object): 7 | def __init__(self, start, end, velocity, pitch, Type): 8 | self.start = start 9 | self.end = end 10 | self.velocity = velocity 11 | self.pitch = pitch 12 | self.Type = Type 13 | def __repr__(self): 14 | return 'NOTE(start={}, end={}, velocity={}, pitch={}, Type={})'.format( 15 | self.start, self.end, self.velocity, self.pitch, self.Type) 16 | 17 | 18 | def extract(file, f): 19 | obj = midparser.MidiFile(file) 20 | melody = obj.instruments[0].notes 21 | bridge = obj.instruments[1].notes 22 | piano = obj.instruments[2].notes 23 | melody = [NOTE(i.start, i.end, i.velocity, i.pitch, 0) for i in melody] 24 | bridge = [NOTE(i.start, i.end, i.velocity, i.pitch, 1) for i in bridge] 25 | piano = [NOTE(i.start, i.end, i.velocity, i.pitch, 2) for i in piano] 26 | all_notes = [] 27 | all_notes.extend(melody) 28 | all_notes.extend(bridge) 29 | all_notes.extend(piano) 30 | 31 | all_notes.sort(key=lambda x:(x.start)) 32 | 33 | print('# melody: {}, # bridge: {}, # piano: {}'.format(len(melody), len(bridge), len(piano))) 34 | f.write('# melody: {}, # bridge: {}, # piano: {}\n'.format(len(melody), len(bridge), len(piano))) 35 | 36 | # extract 37 | pred_m = extract_melody(all_notes) 38 | 39 | all_notes = set(all_notes) 40 | pred_m = set(pred_m) 41 | 42 | TP1, TP2, FP1, FP2 = 0, 0, 0, 0 43 | for i in pred_m: 44 | if i.Type == 0 or i.Type == 1: 45 | # predict correctly 46 | TP1 += 1 47 | else: 48 | FP1 += 1 49 | 50 | FN1 = len(melody) + len(bridge) - TP1 51 | TN1 = len(all_notes) - TP1 - FN1 - FP1 52 | 53 | for i in pred_m: 54 | if i.Type == 0: 55 | # predict correctly 56 | TP2 += 1 57 | else: 58 | FP2 += 1 59 | 60 | FN2 = len(melody) - TP2 61 | TN2 = len(all_notes) - TP2 - FP2 - FN2 62 | 63 | print('accuracy (melody only):', (TP2+TN2)/(TP2+TN2+FP2+FN2)) 64 | print('accuracy (melody & bridge):', (TP1+TN1)/(TP1+TN1+FP1+FN1)) 65 | 66 | f.write('TN2:{}, TP2:{}, FN2:{}, FP2:{}\n'.format(TN2, TP2, FN2, FP2)) 67 | f.write('accuracy (melody only):' + str((TP2+TN2)/(TP2+TN2+FP2+FN2)) + '\n') 68 | 69 | f.write('TN1:{}, TP1:{}, FN1:{}, FP1:{}\n'.format(TN1, TP1, FN1, FP1)) 70 | f.write('accuracy (melody & bridge):' + str((TP1+TN1)/(TP1+TN1+FP1+FN1)) + '\n') 71 | return TP1, TN1, FP1, FN1, TP2, TN2, FP2, FN2 72 | 73 | def main(): 74 | # use test set ONLY 75 | # please change the root_dir to `your path to pop909 test set` 76 | root_dir = '/home/user/Dataset/pop909_aligned/test' 77 | files = glob.glob(f'{root_dir}/*.mid') 78 | 79 | TP1, TN1, FP1, FN1, TP2, TN2, FP2, FN2 = 0,0,0,0,0,0,0,0 80 | 81 | with open('acc.log','a') as f: 82 | f.write('root_dir: ' + root_dir + '\n') 83 | f.write('# file: ' + str(len(files)) + '\n') 84 | for file in files: 85 | print('\n[',file,']') 86 | f.write('\n[ ' + file.split('/')[-1] + ' ]\n') 87 | tp1, tn1, fp1, fn1, tp2, tn2, fp2, fn2 = extract(file, f) 88 | TP1 += tp1; TN1 += tn1; FP1 += fp1; FN1 += fn1; 89 | TP2 += tp2; TN2 += tn2; FP2 += fp2; FN2 += fn2; 90 | 91 | f.write('\navg accuracy (melody only):' + str((TP2+TN2)/(TP2+TN2+FP2+FN2)) + '\n') 92 | f.write(f'(melody only) TP: {TP2}, FP: {FP2}, FN: {FN2}, TN:{TN2} \n') 93 | f.write('avg accuracy (melody & bridge):' + str((TP1+TN1)/(TP1+TN1+FP1+FN1)) + '\n') 94 | 95 | 96 | if __name__ == '__main__': 97 | main() 98 | 
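`cal_acc.py` only reports accuracy numbers; to audition the extracted melody, the same `extract_melody` routine can be applied to a single file and the result written back as an extra MIDI track. The sketch below is a minimal illustration under assumed file names and track name (none of them are part of the repository); it mirrors the miditoolkit calls already used in `analyzer.py`.

```python
# Minimal sketch: run the skyline extractor on one MIDI file and save the
# predicted melody as an additional track. 'input.mid' / 'output_with_melody.mid'
# are placeholder paths.
import copy

from miditoolkit.midi import parser as mid_parser
from miditoolkit.midi.containers import Instrument

from analyzer import extract_melody

midi_obj = mid_parser.MidiFile('input.mid')
out_obj = copy.deepcopy(midi_obj)          # keep the original tracks untouched

# pool the notes of every track; extract_melody sorts and quantizes them itself
notes = [note for inst in midi_obj.instruments for note in inst.notes]
pred_melody = extract_melody(notes)

melody_track = Instrument(program=0, is_drum=False, name='melody (skyline)')
melody_track.notes = pred_melody
out_obj.instruments.append(melody_track)
out_obj.dump('output_with_melody.mid')
```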
-------------------------------------------------------------------------------- /melody_extraction/skyline/test.py: -------------------------------------------------------------------------------- 1 | from analyzer import extract_melody 2 | import glob 3 | from operator import itemgetter 4 | import miditoolkit.midi.parser as midparser 5 | 6 | class NOTE(object): 7 | def __init__(self, start, end, velocity, pitch, Type): 8 | self.start = start 9 | self.end = end 10 | self.velocity = velocity 11 | self.pitch = pitch 12 | self.Type = Type 13 | def __repr__(self): 14 | return 'NOTE(start={}, end={}, velocity={}, pitch={}, Type={})'.format( 15 | self.start, self.end, self.velocity, self.pitch, self.Type) 16 | 17 | 18 | def extract(file, f): 19 | obj = midparser.MidiFile(file) 20 | melody = obj.instruments[0].notes 21 | bridge = obj.instruments[1].notes 22 | piano = obj.instruments[2].notes 23 | melody = [NOTE(i.start, i.end, i.velocity, i.pitch, 0) for i in melody] 24 | bridge = [NOTE(i.start, i.end, i.velocity, i.pitch, 1) for i in bridge] 25 | piano = [NOTE(i.start, i.end, i.velocity, i.pitch, 2) for i in piano] 26 | all_notes = [] 27 | all_notes.extend(melody) 28 | all_notes.extend(bridge) 29 | all_notes.extend(piano) 30 | 31 | all_notes.sort(key=lambda x:(x.start)) 32 | 33 | print('# melody: {}, # bridge: {}, # piano: {}'.format(len(melody), len(bridge), len(piano))) 34 | f.write('# melody: {}, # bridge: {}, # piano: {}\n'.format(len(melody), len(bridge), len(piano))) 35 | 36 | # extract 37 | pred_m = extract_melody(all_notes) 38 | 39 | notesL = len(all_notes) 40 | print('# all_notes', notesL) 41 | all_notes = set(all_notes) 42 | if (notesL != len(all_notes)): 43 | print('error!') 44 | exit(1) 45 | pred_m_L = len(pred_m) 46 | print('pred melody (list):', pred_m_L) 47 | pred_m = set(pred_m) 48 | if (pred_m_L != len(pred_m)): 49 | print('error!') 50 | exit(1) 51 | 52 | TP1, TP2, FP1, FP2 = 0, 0, 0, 0 53 | for i in pred_m: 54 | if i.Type == 0 or i.Type == 1: 55 | # predict correctly 56 | TP1 += 1 57 | else: 58 | FP1 += 1 59 | 60 | TN1 = len(melody) + len(bridge) - TP1 61 | FN1 = len(all_notes) - TP1 - FP1 - TN1 62 | 63 | for i in pred_m: 64 | if i.Type == 0: 65 | # predict correctly 66 | TP2 += 1 67 | else: 68 | FP2 += 1 69 | 70 | TN2 = len(melody) - TP2 71 | FN2 = len(all_notes) - TP2 - FP2 - TN2 72 | '''pred_p = all_notes - pred_m 73 | 74 | for i in pred_p: 75 | if i.Type == 2: 76 | # predict correctly 77 | acc1 += 1 78 | else: 79 | wrong1 += 1 80 | 81 | for i in pred_p: 82 | if i.Type == 1 or i.Type == 2: 83 | # predict correctly 84 | acc2 += 1 85 | else: 86 | wrong2 += 1 87 | ' 88 | print('accuracy (melody only):', acc2/(acc2+wrong2)) 89 | print('accuracy (melody & bridge):', acc1/(acc1+wrong1)) 90 | f.write('acc(melody only): ' + str(acc2/(acc2+wrong2)) + '\n') 91 | f.write('acc(melody + bridge): ' + str(acc1/(acc1+wrong1)) + '\n') 92 | return acc1, acc2, wrong1, wrong2 93 | ''' 94 | 95 | print('accuracy (melody only):', (TP2+FN2)/(TP2+TN2+FP2+FN2)) 96 | print('accuracy (melody & bridge):', (TP1+FN1)/(TP1+TN1+FP1+FN1)) 97 | f.write('TN2:{}, TP2:{}, FN2:{}, FP2:{}\n'.format(TN2, TP2, FN2, FP2)) 98 | f.write('accuracy (melody only):' + str((TP2+FN2)/(TP2+TN2+FP2+FN2)) + '\n') 99 | f.write('TN1:{}, TP1:{}, FN1:{}, FP1:{}\n'.format(TN1, TP1, FN1, FP1)) 100 | f.write('accuracy (melody & bridge):' + str((TP1+FN1)/(TP1+TN1+FP1+FN1)) + '\n') 101 | return TP1, TN1, FP1, FN1, TP2, TN2, FP2, FN2 102 | 103 | def main(): 104 | root_dir = '/home/yh1488/NAS-189/home/Dataset/pop909_aligned/' 105 | files 
= glob.glob(root_dir+'*.mid')
106 |     # use test set ONLY
107 |     files = files[-86:]
108 |     # all_acc1, all_acc2, all_wrong1, all_wrong2 = 0,0,0,0
109 |     TP1, TN1, FP1, FN1, TP2, TN2, FP2, FN2 = 0,0,0,0,0,0,0,0
110 | 
111 |     with open('acc.log','a') as f:
112 |         f.write('root_dir: ' + root_dir + '\n')
113 |         f.write('# file: ' + str(len(files)) + '\n')
114 |         for file in files:
115 |             print('\n[',file,']')
116 |             f.write('\n[ ' + file.split('/')[-1] + ' ]\n')
117 |             # acc1, acc2, wrong1, wrong2 = extract(file, f)
118 |             tp1, tn1, fp1, fn1, tp2, tn2, fp2, fn2 = extract(file, f)
119 |             '''all_acc1 += acc1
120 |             all_acc2 += acc2
121 |             all_wrong1 += wrong1
122 |             all_wrong2 += wrong2'''
123 |             TP1 += tp1; TN1 += tn1; FP1 += fp1; FN1 += fn1;
124 |             TP2 += tp2; TN2 += tn2; FP2 += fp2; FN2 += fn2;
125 | 
126 |         #f.write('\navg_acc (melody only): ' + str(all_acc1/(all_acc1+all_wrong1)) + '\n')
127 |         #f.write('avg_acc (melody & bridge): ' + str(all_acc2/(all_acc2+all_wrong2)))
128 |         f.write('avg accuracy (melody only):' + str((TP2+FN2)/(TP2+TN2+FP2+FN2)) + '\n')
129 |         f.write('avg accuracy (melody & bridge):' + str((TP1+FN1)/(TP1+TN1+FP1+FN1)) + '\n')
130 | 
131 | 
132 | if __name__ == '__main__':
133 |     main()
134 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.13.3
2 | matplotlib>=3.3.3
3 | mido==1.2.10
4 | torch>=1.3.1
5 | chorder==0.1.2
6 | miditoolkit==0.1.14
7 | scikit_learn==0.24.2
8 | torchaudio==0.9.0
9 | transformers==4.8.2
10 | SoundFile
11 | tqdm
12 | pypianoroll
13 | 
--------------------------------------------------------------------------------
/resources/Adele.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/resources/Adele.mid
--------------------------------------------------------------------------------
/resources/fig/midibert.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/resources/fig/midibert.png
--------------------------------------------------------------------------------
/resources/fig/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wazenmai/MIDI-BERT/0b935584f641d3d59a8e6aff2f334b425ce1542d/resources/fig/result.png
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH="."
2 | 
3 | # melody
4 | python3 MidiBERT/eval.py --task=melody --cpu
5 | 
6 | # velocity
7 | python3 MidiBERT/eval.py --task=velocity --cpu
8 | 
9 | # composer
10 | python3 MidiBERT/eval.py --task=composer --cpu
11 | 
12 | # emotion
13 | python3 MidiBERT/eval.py --task=emotion --cpu
14 | 
15 | 
--------------------------------------------------------------------------------
/scripts/finetune.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH="."
2 | 
3 | # melody; the output folder name will be {task}_{name}
4 | python3 MidiBERT/finetune.py --task=melody --name=default
5 | 
6 | # velocity
7 | python3 MidiBERT/finetune.py --task=velocity --name=default
8 | 
9 | # composer
10 | python3 MidiBERT/finetune.py --task=composer --name=default
11 | 
12 | # emotion
13 | python3 MidiBERT/finetune.py --task=emotion --name=default
14 | 
15 | 
--------------------------------------------------------------------------------
/scripts/melody_extraction.sh:
--------------------------------------------------------------------------------
1 | song_path="resources/Adele.mid"
2 | model_path="result/finetune/melody_default/model_best.ckpt"
3 | 
4 | export PYTHONPATH='.'
5 | python3 melody_extraction/midibert/extract.py --input=$song_path --ckpt=$model_path --cpu --bridge=False
6 | 
7 | 
--------------------------------------------------------------------------------
/scripts/prepare_data.sh:
--------------------------------------------------------------------------------
1 | input_dir="../example_midis"
2 | 
3 | export PYTHONPATH='.'
4 | 
5 | # melody
6 | python3 data_creation/prepare_data/main.py --dataset=pop909 --task=melody
7 | 
8 | # velocity
9 | python3 data_creation/prepare_data/main.py --dataset=pop909 --task=velocity
10 | 
11 | # composer
12 | python3 data_creation/prepare_data/main.py --dataset=pianist8 --task=composer
13 | 
14 | # emotion
15 | python3 data_creation/prepare_data/main.py --dataset=emopia --task=emotion
16 | 
17 | # custom directory
18 | python3 data_creation/prepare_data/main.py --input_dir=$input_dir
19 | 
20 | # custom single file
21 | python3 data_creation/prepare_data/main.py --input_file="${input_dir}/pop.mid"
22 | 
--------------------------------------------------------------------------------
/scripts/pretrain.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH="."
2 | 
3 | python3 MidiBERT/main.py --name=default
4 | 
--------------------------------------------------------------------------------