├── musicautobot
│   ├── utils
│   │   ├── __init__.py
│   │   ├── attention_mask.py
│   │   ├── setup_musescore.py
│   │   ├── top_k_top_p.py
│   │   ├── file_processing.py
│   │   ├── stacked_dataloader.py
│   │   ├── midifile.py
│   │   └── lamb.py
│   ├── __init__.py
│   ├── multitask_transformer
│   │   ├── __init__.py
│   │   ├── transform.py
│   │   ├── dataloader.py
│   │   ├── model.py
│   │   └── learner.py
│   ├── music_transformer
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── learner.py
│   │   ├── transform.py
│   │   └── dataloader.py
│   ├── config.py
│   ├── vocab.py
│   └── numpy_encode.py
├── serve
│   ├── api
│   │   ├── api.cfg
│   │   ├── config.py
│   │   ├── __init__.py
│   │   ├── save.py
│   │   ├── predict.py
│   │   └── predict_multitask.py
│   ├── .flaskenv
│   ├── run.py
│   ├── run_guni.py
│   ├── environment.yml
│   ├── README.md
│   └── app.json
├── images
│   └── musicautobot_screenshot.png
├── data
│   └── midi
│       ├── notebook_examples
│       │   ├── example.mid
│       │   └── single_bar_example.mid
│       └── examples
│           ├── Levels - Avicii - Verse.mid
│           ├── Middle - Zedd - Pre-Chorus.mid
│           ├── La Bamba - Ritchie Valen - Chorus.mid
│           ├── Let It Go - Idina Menzel - Chorus.mid
│           ├── Colors Of The Wind - Disney - Chorus.mid
│           ├── I Want You Back - Jackson 5 - Intro.mid
│           ├── Just Give Me A Reason - Pink - Chorus.mid
│           ├── Call Me Maybe - Carly Rae Jepsen - Chorus.mid
│           ├── Fuer Elise - Ludwig Van Beethoven - Verse.mid
│           ├── Locked Out Of Heaven - Bruno Mars - Chorus.mid
│           ├── Roses Ft Rozes - The Chainsmokers - Chorus.mid
│           ├── Canon In D Major - Johann Pachelbel - Chorus.mid
│           ├── Where Is The Love - Black Eyed Peas - Chorus.mid
│           ├── Can You Feel The Love Tonight - Elton John - Verse.mid
│           ├── In The Hall Of The Mountain King - Edvard Grieg - Intro.mid
│           ├── Scary Monsters And Nice Sprites - Skrillex - Pre-Chorus.mid
│           ├── A Thousand Miles - Vanessa Carlton - Verse-And-Pre-Chorus.mid
│           ├── All I Want For Christmas Is You - Mariah Carey - Pre-Chorus-And-Chorus.mid
│           └── The Four Seasons Concerto No 4 Winter - Antonio Vivaldi - Instrumental.mid
├── scripts
│   ├── run_ddp.sh
│   ├── run_music_transformer.py
│   └── run_multitask.py
├── environment.yml
├── LICENSE.md
├── .gitignore
├── notebooks
│   ├── music_transformer
│   │   ├── Train-Simple.ipynb
│   │   ├── Train-Advanced.ipynb
│   │   ├── Train.ipynb
│   │   └── Generate_colab.ipynb
│   └── multitask_transformer
│       └── Generate_colab.ipynb
└── README.md
/musicautobot/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /musicautobot/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils.setup_musescore import setup_musescore 2 | 3 | setup_musescore() -------------------------------------------------------------------------------- /serve/api/api.cfg: -------------------------------------------------------------------------------- 1 | # Input bucket name only.
Not the whole s3 URL 2 | # S3_BUCKET_NAME = 's3-bucket-name' -------------------------------------------------------------------------------- /musicautobot/multitask_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import * 2 | from .model import * 3 | from .learner import * -------------------------------------------------------------------------------- /musicautobot/music_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import * 2 | from .model import * 3 | from .learner import * -------------------------------------------------------------------------------- /images/musicautobot_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/images/musicautobot_screenshot.png -------------------------------------------------------------------------------- /data/midi/notebook_examples/example.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/notebook_examples/example.mid -------------------------------------------------------------------------------- /data/midi/examples/Levels - Avicii - Verse.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Levels - Avicii - Verse.mid -------------------------------------------------------------------------------- /data/midi/examples/Middle - Zedd - Pre-Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Middle - Zedd - Pre-Chorus.mid -------------------------------------------------------------------------------- /data/midi/notebook_examples/single_bar_example.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/notebook_examples/single_bar_example.mid -------------------------------------------------------------------------------- /data/midi/examples/La Bamba - Ritchie Valen - Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/La Bamba - Ritchie Valen - Chorus.mid -------------------------------------------------------------------------------- /data/midi/examples/Let It Go - Idina Menzel - Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Let It Go - Idina Menzel - Chorus.mid -------------------------------------------------------------------------------- /data/midi/examples/Colors Of The Wind - Disney - Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Colors Of The Wind - Disney - Chorus.mid -------------------------------------------------------------------------------- /data/midi/examples/I Want You Back - Jackson 5 - Intro.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/I Want You Back - Jackson 5 - Intro.mid -------------------------------------------------------------------------------- /serve/.flaskenv: -------------------------------------------------------------------------------- 1 | # Production Environment should be set to 'production' 2 | FLASK_ENV = "development" 3 | FLASK_APP = "app" 4 | # Uncomment this to debug: 5 | # FLASK_DEBUG=1 6 | -------------------------------------------------------------------------------- /data/midi/examples/Just Give Me A Reason - Pink - Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Just Give Me A Reason - Pink - Chorus.mid -------------------------------------------------------------------------------- /data/midi/examples/Call Me Maybe - Carly Rae Jepsen - Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Call Me Maybe - Carly Rae Jepsen - Chorus.mid -------------------------------------------------------------------------------- /data/midi/examples/Fuer Elise - Ludwig Van Beethoven - Verse.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Fuer Elise - Ludwig Van Beethoven - Verse.mid -------------------------------------------------------------------------------- /data/midi/examples/Locked Out Of Heaven - Bruno Mars - Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Locked Out Of Heaven - Bruno Mars - Chorus.mid -------------------------------------------------------------------------------- /data/midi/examples/Roses Ft Rozes - The Chainsmokers - Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Roses Ft Rozes - The Chainsmokers - Chorus.mid -------------------------------------------------------------------------------- /data/midi/examples/Canon In D Major - Johann Pachelbel - Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Canon In D Major - Johann Pachelbel - Chorus.mid -------------------------------------------------------------------------------- /data/midi/examples/Where Is The Love - Black Eyed Peas - Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Where Is The Love - Black Eyed Peas - Chorus.mid -------------------------------------------------------------------------------- /serve/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | from api import app 3 | 4 | # app.run(host="0.0.0.0", port=80) 5 | app.run(port=5000) 6 | 7 | # To Run: 8 | # python run.py 9 | # or 10 | # python -m flask run 11 | -------------------------------------------------------------------------------- /data/midi/examples/Can You Feel The Love Tonight - Elton John - Verse.mid:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Can You Feel The Love Tonight - Elton John - Verse.mid -------------------------------------------------------------------------------- /serve/run_guni.py: -------------------------------------------------------------------------------- 1 | import os 2 | from api import app 3 | 4 | if __name__ == "__main__": 5 | app.run() 6 | 7 | # To Run: 8 | # yarn build 9 | # gunicorn -w 8 run_guni:app -b 127.0.0.1:5000 10 | -------------------------------------------------------------------------------- /data/midi/examples/In The Hall Of The Mountain King - Edvard Grieg - Intro.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/In The Hall Of The Mountain King - Edvard Grieg - Intro.mid -------------------------------------------------------------------------------- /data/midi/examples/Scary Monsters And Nice Sprites - Skrillex - Pre-Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/Scary Monsters And Nice Sprites - Skrillex - Pre-Chorus.mid -------------------------------------------------------------------------------- /serve/environment.yml: -------------------------------------------------------------------------------- 1 | name: musicautobot 2 | channels: 3 | - defaults 4 | dependencies: 5 | - gunicorn 6 | - flask 7 | - boto3 8 | - pip: 9 | - flask-restplus 10 | - python-dotenv 11 | - flask_cors 12 | -------------------------------------------------------------------------------- /data/midi/examples/A Thousand Miles - Vanessa Carlton - Verse-And-Pre-Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/A Thousand Miles - Vanessa Carlton - Verse-And-Pre-Chorus.mid -------------------------------------------------------------------------------- /data/midi/examples/All I Want For Christmas Is You - Mariah Carey - Pre-Chorus-And-Chorus.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/All I Want For Christmas Is You - Mariah Carey - Pre-Chorus-And-Chorus.mid -------------------------------------------------------------------------------- /data/midi/examples/The Four Seasons Concerto No 4 Winter - Antonio Vivaldi - Instrumental.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bearpelican/musicautobot/HEAD/data/midi/examples/The Four Seasons Concerto No 4 Winter - Antonio Vivaldi - Instrumental.mid -------------------------------------------------------------------------------- /scripts/run_ddp.sh: -------------------------------------------------------------------------------- 1 | QUERY="$(nvidia-smi --query-gpu=gpu_name --format=csv | wc -l)" 2 | NUM_GPUS=$((QUERY-1)) 3 | 4 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} ${SCRIPT} "$@" 5 | #python -m torch.distributed.launch --nproc_per_node=1 ${SCRIPT} "$@" 6 | -------------------------------------------------------------------------------- /environment.yml: 
-------------------------------------------------------------------------------- 1 | # To Run: conda env create -f environment.yml 2 | name: musicautobot 3 | channels: 4 | - fastai 5 | - pytorch 6 | - defaults 7 | dependencies: 8 | - pytorch 9 | - fastai==1.0.61 10 | - jupyter 11 | - ipyparallel 12 | - pip 13 | - python>=3.6 14 | - pip: 15 | - music21 16 | - pebble 17 | # - "--editable=git+https://github.com/fastai/fastai@master" 18 | -------------------------------------------------------------------------------- /serve/README.md: -------------------------------------------------------------------------------- 1 | # Flask API Endpoint for Music Generation 2 | 3 | This API is built specifically for the front-end app - musicautobot.com 4 | 5 | See: https://github.com/bearpelican/musicautobot_vueapp for the client code 6 | 7 | Installation: 8 | 9 | *Make sure you have already created the musicautobot conda environment* 10 | 11 | cd serve 12 | conda env update -f environment.yml 13 | 14 | Set S3_BUCKET_NAME in api/api.cfg 15 | 16 | 17 | Running server: 18 | 19 | conda activate musicautobot 20 | 21 | Local Host: 22 | python run.py 23 | 24 | Production: 25 | gunicorn --certfile SSL_CERT --keyfile SSL_KEY -b 127.0.0.1:5000 run_guni:app --timeout 180 --workers 16 -------------------------------------------------------------------------------- /serve/api/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global Flask Application Settings 3 | 4 | See `.flaskenv` for default settings. 5 | """ 6 | 7 | import os 8 | # from app import app 9 | from . import app 10 | from pathlib import Path 11 | 12 | class Config(object): 13 | project_path = Path(__file__).parents[2] 14 | LIB_PATH = project_path 15 | DATA_PATH = project_path/'data/numpy' 16 | DATA_SAVE_NAME = 'musicitem_data_save.pkl' 17 | MULTITASK_MODEL_PATH = DATA_PATH/'pretrained/MultitaskSmallKeyC.pth' 18 | MUSIC_MODEL_PATH = DATA_PATH/'pretrained/MusicTransformerKeyC.pth' 19 | 20 | app.config.from_object('api.config.Config') 21 | app.config.from_pyfile('api.cfg') 22 | -------------------------------------------------------------------------------- /serve/app.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "name": "Midi Generator", 4 | "description": "", 5 | "repository": "https://github.com/bearpelican/vue_midi_generator", 6 | "logo": "https://static1.squarespace.com/static/5b7ec8e536099bcb6fd1b54f/t/5b989a2bc2241bdafc1f5a39/1536806734803/?format=50w", 7 | "keywords": ["flask", "vue"], 8 | "env": { 9 | "FLASK_ENV": { 10 | "description": "Flask Environment", 11 | "value": "production" 12 | }, 13 | "SECRET": { 14 | "description": "Flask Secret Key", 15 | "value": "YourKeyHere" 16 | } 17 | }, 18 | "addons": [ 19 | ], 20 | "buildpacks": [ 21 | { 22 | "url": "heroku/nodejs" 23 | }, 24 | { 25 | "url": "heroku/python" 26 | } 27 | ] 28 | } 29 | -------------------------------------------------------------------------------- /serve/api/__init__.py: -------------------------------------------------------------------------------- 1 | """ API Blueprint Application """ 2 | 3 | import os 4 | from flask import Flask 5 | # from flask_restplus import Api 6 | from flask_cors import CORS 7 | from flask import Blueprint, current_app 8 | 9 | # from .api import api_bp 10 | # from .client import client_bp 11 | 12 | app = Flask(__name__) 13 | CORS(app) 14 | # api = Api(app) 15 | 16 | # app.logger.info('>>> {}'.format(Config.FLASK_ENV)) 17 | 18 | @app.route('/hello') 19 | def
hello(): return 'hello' 20 | 21 | # api_bp = Blueprint('api_bp', __name__, url_prefix='/api') 22 | 23 | 24 | # @api_bp.after_request 25 | # def add_header(response): 26 | # response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization' 27 | # return response 28 | 29 | 30 | from .config import Config 31 | 32 | # Import prediction api (choose only one) 33 | 34 | # MusicTransformer API 35 | # from .predict import * 36 | 37 | # Multitask API 38 | from .predict_multitask import * -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Andrew Shaw 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /musicautobot/utils/attention_mask.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def window_mask(x_len, device, m_len=0, size=(1,1)): 5 | win_size,k = size 6 | mem_mask = torch.zeros((x_len,m_len), device=device) 7 | tri_mask = torch.triu(torch.ones((x_len//win_size+1,x_len//win_size+1), device=device),diagonal=k) 8 | window_mask = tri_mask.repeat_interleave(win_size,dim=0).repeat_interleave(win_size,dim=1)[:x_len,:x_len] 9 | if x_len: window_mask[...,0] = 0 # Always allowing first index to see. 
Otherwise you'll get NaN loss 10 | mask = torch.cat((mem_mask, window_mask), dim=1)[None,None] 11 | return mask.bool() if hasattr(mask, 'bool') else mask.byte() 12 | 13 | def rand_window_mask(x_len,m_len,device,max_size:int=None,p:float=0.2,is_eval:bool=False): 14 | if is_eval or np.random.rand() >= p or max_size is None: 15 | win_size,k = (1,1) 16 | else: win_size,k = (np.random.randint(0,max_size)+1,0) 17 | return window_mask(x_len, device, m_len, size=(win_size,k)) 18 | 19 | def lm_mask(x_len, device): 20 | mask = torch.triu(torch.ones((x_len, x_len), device=device), diagonal=1)[None,None] 21 | return mask.bool() if hasattr(mask, 'bool') else mask.byte() 22 | -------------------------------------------------------------------------------- /musicautobot/config.py: -------------------------------------------------------------------------------- 1 | from fastai.text.models.transformer import tfmerXL_lm_config, Activation 2 | # from .vocab import MusicVocab 3 | 4 | def default_config(): 5 | config = tfmerXL_lm_config.copy() 6 | config['act'] = Activation.GeLU 7 | 8 | config['mem_len'] = 512 9 | config['d_model'] = 512 10 | config['d_inner'] = 2048 11 | config['n_layers'] = 16 12 | 13 | config['n_heads'] = 8 14 | config['d_head'] = 64 15 | 16 | return config 17 | 18 | def music_config(): 19 | config = default_config() 20 | config['encode_position'] = True 21 | return config 22 | 23 | def musicm_config(): 24 | config = music_config() 25 | config['d_model'] = 768 26 | config['d_inner'] = 3072 27 | config['n_heads'] = 12 28 | config['d_head'] = 64 29 | config['n_layers'] = 12 30 | return config 31 | 32 | def multitask_config(): 33 | config = default_config() 34 | config['bias'] = True 35 | config['enc_layers'] = 8 36 | config['dec_layers'] = 8 37 | del config['n_layers'] 38 | return config 39 | 40 | def multitaskm_config(): 41 | config = musicm_config() 42 | config['bias'] = True 43 | config['enc_layers'] = 12 44 | config['dec_layers'] = 12 45 | del config['n_layers'] 46 | return config 47 | 48 | -------------------------------------------------------------------------------- /serve/api/save.py: -------------------------------------------------------------------------------- 1 | 2 | import uuid 3 | import boto3 4 | import json 5 | from pathlib import Path 6 | from . import app 7 | 8 | s3 = boto3.client('s3') 9 | bucket = app.config['S3_BUCKET_NAME'] 10 | 11 | def to_s3(file, args): 12 | s3_id = str(uuid.uuid4()).replace('-', '') 13 | base_dir = 'generated/' 14 | s3_file = base_dir + s3_id + '.mid' 15 | s3_json = base_dir + s3_id + '.json' 16 | 17 | if not isinstance(file, (str, Path)): 18 | tmp_midi = '/tmp/' + s3_id + '.mid' 19 | with open(tmp_midi, 'wb') as f: 20 | f.write(file) 21 | else: 22 | tmp_midi = file 23 | 24 | if not isinstance(args, (str, Path)): 25 | tmp_json = '/tmp/' + s3_id + '.json' 26 | with open(tmp_json, 'w') as f: 27 | json.dump(args, f) 28 | else: tmp_json = args 29 | 30 | # Uploads the given file using a managed uploader, which will split up large 31 | # files automatically and upload parts in parallel. 
32 | s3.upload_file(str(tmp_midi), bucket, s3_file) 33 | s3.upload_file(str(tmp_json), bucket, s3_json) 34 | print('Saved IDS:', s3_id, s3_id[::-1]) 35 | return s3_id[::-1] 36 | 37 | # @app.route('/store/save', methods=['POST']) 38 | # def save_store(): 39 | # args = request.form.to_dict() 40 | # midi = request.files['midi'].read() 41 | # print('Saving store:', args) 42 | # s3_id = to_s3(midi, args) 43 | # result = { 44 | # 'result': s3_id 45 | # } 46 | # return jsonify(result) 47 | -------------------------------------------------------------------------------- /musicautobot/utils/setup_musescore.py: -------------------------------------------------------------------------------- 1 | def setup_musescore(musescore_path=None): 2 | if not is_ipython(): return 3 | 4 | import platform 5 | from music21 import environment 6 | from pathlib import Path 7 | 8 | system = platform.system() 9 | if system == 'Linux': 10 | import os 11 | os.environ['QT_QPA_PLATFORM']='offscreen' # https://musescore.org/en/node/29041 12 | 13 | existing_path = environment.get('musicxmlPath') 14 | if existing_path: return 15 | if musescore_path is None: 16 | if system == 'Darwin': 17 | app_paths = list(Path('/Applications').glob('MuseScore *.app')) 18 | if len(app_paths): musescore_path = app_paths[-1]/'Contents/MacOS/mscore' 19 | elif system == 'Linux': 20 | musescore_path = '/usr/bin/musescore' 21 | 22 | if musescore_path is None or not Path(musescore_path).exists(): 23 | print('Warning: Could not find musescore installation. Please install musescore (see README) and/or update music21 environment paths') 24 | else : 25 | environment.set('musicxmlPath', musescore_path) 26 | environment.set('musescoreDirectPNGPath', musescore_path) 27 | 28 | def is_ipython(): 29 | try: get_ipython 30 | except: return False 31 | return True 32 | 33 | def is_colab(): 34 | try: import google.colab 35 | except: return False 36 | return True 37 | 38 | def setup_fluidsynth(): 39 | from midi2audio import FluidSynth 40 | from IPython.display import Audio 41 | 42 | def play_wav(stream): 43 | out_midi = stream.write('midi') 44 | out_wav = str(Path(out_midi).with_suffix('.wav')) 45 | FluidSynth("font.sf2").midi_to_audio(out_midi, out_wav) 46 | return Audio(out_wav) 47 | -------------------------------------------------------------------------------- /musicautobot/utils/top_k_top_p.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | __all__ = ['top_k_top_p'] 5 | 6 | # top_k + nucleus filter - https://twitter.com/thom_wolf/status/1124263861727760384?lang=en 7 | # https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 8 | def top_k_top_p(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 9 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 10 | Args: 11 | logits: logits distribution shape (vocabulary size) 12 | top_k >0: keep only top k tokens with highest probability (top-k filtering). 13 | top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 
14 | """ 15 | logits = logits.clone() 16 | assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear 17 | top_k = min(top_k, logits.size(-1)) # Safety check 18 | if top_k > 0: 19 | # Remove all tokens with a probability less than the last token of the top-k 20 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 21 | logits[indices_to_remove] = filter_value 22 | 23 | if top_p > 0.0: 24 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 25 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 26 | 27 | # Remove tokens with cumulative probability above the threshold 28 | sorted_indices_to_remove = cumulative_probs > top_p 29 | # Shift the indices to the right to keep also the first token above the threshold 30 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 31 | sorted_indices_to_remove[..., 0] = 0 32 | 33 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 34 | logits[indices_to_remove] = filter_value 35 | return logits 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Added files 2 | data/* 3 | !data/midi 4 | data/midi/* 5 | !data/midi/examples 6 | !data/midi/notebook_examples 7 | 8 | 9 | models/ 10 | tmp_out/ 11 | data_serve/ 12 | notebooks/data_collection/musescore/*.json 13 | notebooks/develop/ 14 | .vscode 15 | .DS_Store 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | -------------------------------------------------------------------------------- /musicautobot/utils/file_processing.py: -------------------------------------------------------------------------------- 1 | "Parallel processing for midi files" 2 | import csv 3 | from fastprogress.fastprogress import master_bar, progress_bar 4 | from pathlib import Path 5 | from pebble import ProcessPool 6 | from concurrent.futures import TimeoutError 7 | import numpy as np 8 | 9 | # https://stackoverflow.com/questions/20991968/asynchronous-multiprocessing-with-a-worker-pool-in-python-how-to-keep-going-aft 10 | def process_all(func, arr, timeout_func=None, total=None, max_workers=None, timeout=None): 11 | with ProcessPool() as pool: 12 | future = pool.map(func, arr, timeout=timeout) 13 | 14 | iterator = future.result() 15 | results = [] 16 | for i in progress_bar(range(len(arr)), total=len(arr)): 17 | try: 18 | result = next(iterator) 19 | if result: results.append(result) 20 | except StopIteration: 21 | break 22 | except TimeoutError as error: 23 | if timeout_func: timeout_func(arr[i], error.args[1]) 24 | return results 25 | 26 | def process_file(file_path, tfm_func=None, src_path=None, dest_path=None): 27 | "Utility function that transforms midi file to numpy array." 
28 | output_file = Path(str(file_path).replace(str(src_path), str(dest_path))).with_suffix('.npy') 29 | if output_file.exists(): return output_file 30 | output_file.parent.mkdir(parents=True, exist_ok=True) 31 | 32 | # Call tfm_func and save file 33 | npenc = tfm_func(file_path) 34 | if npenc is not None: 35 | np.save(output_file, npenc) 36 | return output_file 37 | 38 | def arr2csv(arr, out_file): 39 | "Convert metadata array to csv" 40 | all_keys = {k for d in arr for k in d.keys()} 41 | arr = [format_values(x) for x in arr] 42 | with open(out_file, 'w') as f: 43 | dict_writer = csv.DictWriter(f, list(all_keys)) 44 | dict_writer.writeheader() 45 | dict_writer.writerows(arr) 46 | 47 | def format_values(d): 48 | "Format array values for csv encoding" 49 | def format_value(v): 50 | if isinstance(v, list): return ','.join(v) 51 | return v 52 | return {k:format_value(v) for k,v in d.items()} -------------------------------------------------------------------------------- /musicautobot/utils/stacked_dataloader.py: -------------------------------------------------------------------------------- 1 | "Dataloader wrapper that can combine and handle multiple dataloaders for multitask training" 2 | from fastai.callback import Callback 3 | from typing import Callable 4 | 5 | __all__ = ['StackedDataBunch'] 6 | 7 | # DataLoading 8 | class StackedDataBunch(): 9 | def __init__(self, dbs, num_it=100): 10 | self.dbs = dbs 11 | self.train_dl = StackedDataloader([db.train_dl for db in self.dbs], num_it) 12 | self.valid_dl = StackedDataloader([db.valid_dl for db in self.dbs], num_it) 13 | self.train_ds = None 14 | self.path = dbs[0].path 15 | self.device = dbs[0].device 16 | self.vocab = dbs[0].vocab 17 | self.empty_val = False 18 | 19 | def add_tfm(self,tfm:Callable)->None: 20 | for dl in self.dbs: dl.add_tfm(tfm) 21 | 22 | def remove_tfm(self,tfm:Callable)->None: 23 | for dl in self.dbs: dl.remove_tfm(tfm) 24 | 25 | # Helper functions 26 | class StackedDataset(Callback): 27 | def __init__(self, dss): 28 | self.dss = dss 29 | def __getattribute__(self, attr): 30 | if attr == 'dss': return super().__getattribute__(attr) 31 | def redirected(*args, **kwargs): 32 | for ds in self.dss: 33 | if hasattr(ds, attr): getattr(ds, attr)(*args, **kwargs) 34 | return redirected 35 | def __len__(self)->int: return sum([len(ds) for ds in self.dss]) 36 | def __repr__(self): return '\n'.join([self.__class__.__name__] + [repr(ds) for ds in self.dss]) 37 | 38 | class StackedDataloader(): 39 | def __init__(self, dls, num_it=100): 40 | self.dls = dls 41 | self.dataset = StackedDataset([dl.dataset for dl in dls if hasattr(dl, 'dataset')]) 42 | self.num_it = num_it 43 | self.dl_idx = -1 44 | 45 | def __len__(self)->int: return sum([len(dl) for dl in self.dls]) 46 | def __getattr__(self, attr): 47 | def redirected(*args, **kwargs): 48 | for dl in self.dls: 49 | if hasattr(dl, attr): 50 | getattr(dl, attr)(*args, **kwargs) 51 | return redirected 52 | 53 | def __iter__(self): 54 | "Process and returns items from `DataLoader`." 55 | iters = [iter(dl) for dl in self.dls] 56 | self.dl_idx = -1 57 | while len(iters): 58 | self.dl_idx = (self.dl_idx+1) % len(iters) 59 | for b in range(self.num_it): 60 | try: 61 | yield next(iters[self.dl_idx]) 62 | except StopIteration as e: 63 | iters.remove(iters[self.dl_idx]) 64 | break 65 | # raise StopIteration 66 | 67 | def new(self, **kwargs): 68 | "Create a new copy of `self` with `kwargs` replacing current values." 
69 | new_dls = [dl.new(**kwargs) for dl in self.dls] 70 | return StackedDataloader(new_dls, self.num_it) 71 | -------------------------------------------------------------------------------- /musicautobot/music_transformer/model.py: -------------------------------------------------------------------------------- 1 | from fastai.basics import * 2 | from fastai.text.models.transformer import TransformerXL 3 | from ..utils.attention_mask import rand_window_mask 4 | 5 | class MusicTransformerXL(TransformerXL): 6 | "Exactly like fastai's TransformerXL, but with more aggressive attention mask: see `rand_window_mask`" 7 | def __init__(self, *args, encode_position=True, mask_steps=1, **kwargs): 8 | import inspect 9 | sig = inspect.signature(TransformerXL) 10 | arg_params = { k:kwargs[k] for k in sig.parameters if k in kwargs } 11 | super().__init__(*args, **arg_params) 12 | 13 | self.encode_position = encode_position 14 | if self.encode_position: self.beat_enc = BeatPositionEncoder(kwargs['d_model']) 15 | 16 | self.mask_steps=mask_steps 17 | 18 | 19 | def forward(self, x): 20 | #The hidden state has to be initiliazed in the forward pass for nn.DataParallel 21 | if self.mem_len > 0 and not self.init: 22 | self.reset() 23 | self.init = True 24 | 25 | benc = 0 26 | if self.encode_position: 27 | x,pos = x['x'], x['pos'] 28 | benc = self.beat_enc(pos) 29 | 30 | bs,x_len = x.size() 31 | inp = self.drop_emb(self.encoder(x) + benc) #.mul_(self.d_model ** 0.5) 32 | m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0 33 | seq_len = m_len + x_len 34 | 35 | mask = rand_window_mask(x_len, m_len, inp.device, max_size=self.mask_steps, is_eval=not self.training) if self.mask else None 36 | if m_len == 0: mask[...,0,0] = 0 37 | #[None,:,:None] for einsum implementation of attention 38 | hids = [] 39 | pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype) 40 | pos_enc = self.pos_enc(pos) 41 | hids.append(inp) 42 | for i, layer in enumerate(self.layers): 43 | mem = self.hidden[i] if self.mem_len > 0 else None 44 | inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem) 45 | hids.append(inp) 46 | core_out = inp[:,-x_len:] 47 | if self.mem_len > 0 : self._update_mems(hids) 48 | return (self.hidden if self.mem_len > 0 else [core_out]),[core_out] 49 | 50 | 51 | # Beat encoder 52 | class BeatPositionEncoder(nn.Module): 53 | "Embedding + positional encoding + dropout" 54 | def __init__(self, emb_sz:int, beat_len=32, max_bar_len=1024): 55 | super().__init__() 56 | 57 | self.beat_len, self.max_bar_len = beat_len, max_bar_len 58 | self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0) 59 | self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0) 60 | 61 | def forward(self, pos): 62 | beat_enc = self.beat_enc(pos % self.beat_len) 63 | bar_pos = pos // self.beat_len % self.max_bar_len 64 | bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1 65 | bar_enc = self.bar_enc((bar_pos)) 66 | return beat_enc + bar_enc -------------------------------------------------------------------------------- /musicautobot/multitask_transformer/transform.py: -------------------------------------------------------------------------------- 1 | from ..music_transformer.transform import * 2 | 3 | class MultitrackItem(): 4 | def __init__(self, melody:MusicItem, chords:MusicItem, stream=None): 5 | self.melody,self.chords = melody, chords 6 | self.vocab = melody.vocab 7 | self._stream = stream 8 | 9 | @classmethod 10 | def from_file(cls, midi_file, 
vocab): 11 | return cls.from_stream(file2stream(midi_file), vocab) 12 | 13 | @classmethod 14 | def from_stream(cls, stream, vocab): 15 | if not isinstance(stream, music21.stream.Score): stream = stream.voicesToParts() 16 | num_parts = len(stream.parts) 17 | sort_pitch = False 18 | if num_parts > 2: 19 | raise ValueError('Could not extract melody and chords from midi file. Please make sure file contains exactly 2 tracks') 20 | elif num_parts == 1: 21 | print('Warning: only 1 track found. Inferring melody/chords') 22 | stream = separate_melody_chord(stream) 23 | sort_pitch = False 24 | 25 | mpart, cpart = stream2npenc_parts(stream, sort_pitch=sort_pitch) 26 | return cls.from_npenc_parts(mpart, cpart, vocab, stream) 27 | 28 | @classmethod 29 | def from_npenc_parts(cls, mpart, cpart, vocab, stream=None): 30 | mpart = npenc2idxenc(mpart, seq_type=SEQType.Melody, vocab=vocab, add_eos=False) 31 | cpart = npenc2idxenc(cpart, seq_type=SEQType.Chords, vocab=vocab, add_eos=False) 32 | return MultitrackItem(MusicItem(mpart, vocab), MusicItem(cpart, vocab), stream) 33 | 34 | @classmethod 35 | def from_idx(cls, item, vocab): 36 | m, c = item 37 | return MultitrackItem(MusicItem.from_idx(m, vocab), MusicItem.from_idx(c, vocab)) 38 | def to_idx(self): return np.array((self.melody.to_idx(), self.chords.to_idx())) 39 | 40 | @property 41 | def stream(self): 42 | self._stream = self.to_stream() if self._stream is None else self._stream 43 | return self._stream 44 | 45 | def to_stream(self, bpm=120): 46 | ps = self.melody.to_npenc(), self.chords.to_npenc() 47 | ps = [npenc2chordarr(p) for p in ps] 48 | chordarr = chordarr_combine_parts(ps) 49 | return chordarr2stream(chordarr, bpm=bpm) 50 | 51 | 52 | def show(self, format:str=None): 53 | return self.stream.show(format) 54 | def play(self): self.stream.show('midi') 55 | 56 | def transpose(self, val): 57 | return MultitrackItem(self.melody.transpose(val), self.chords.transpose(val)) 58 | def pad_to(self, val): 59 | return MultitrackItem(self.melody.pad_to(val), self.chords.pad_to(val)) 60 | def trim_to_beat(self, beat): 61 | return MultitrackItem(self.melody.trim_to_beat(beat), self.chords.trim_to_beat(beat)) 62 | 63 | def combine2chordarr(np1, np2, vocab): 64 | if len(np1.shape) == 1: np1 = idxenc2npenc(np1, vocab) 65 | if len(np2.shape) == 1: np2 = idxenc2npenc(np2, vocab) 66 | p1 = npenc2chordarr(np1) 67 | p2 = npenc2chordarr(np2) 68 | return chordarr_combine_parts((p1, p2)) 69 | -------------------------------------------------------------------------------- /serve/api/predict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from . 
import app 3 | sys.path.append(str(app.config['LIB_PATH'])) 4 | 5 | from musicautobot.music_transformer import * 6 | from musicautobot.config import * 7 | 8 | from flask import Response, send_from_directory, send_file, request, jsonify 9 | from .save import to_s3 10 | 11 | import torch 12 | import traceback 13 | torch.set_num_threads(4) 14 | 15 | data = load_data(app.config['DATA_PATH'], app.config['DATA_SAVE_NAME'], num_workers=1) 16 | learn = music_model_learner(data, pretrained_path=app.config['MUSIC_MODEL_PATH']) 17 | 18 | if torch.cuda.is_available(): learn.model.cuda() 19 | # learn.to_fp16(loss_scale=512) # fp16 not supported for cpu - https://github.com/pytorch/pytorch/issues/17699 20 | 21 | @app.route('/predict/midi', methods=['POST']) 22 | def predict_midi(): 23 | args = request.form.to_dict() 24 | midi = request.files['midi'].read() 25 | print('THE ARGS PASSED:', args) 26 | bpm = float(args['bpm']) # (AS) TODO: get bpm from midi file instead 27 | temperatures = (float(args.get('noteTemp', 1.2)), float(args.get('durationTemp', 0.8))) 28 | n_words = int(args.get('nSteps', 200)) 29 | seed_len = int(args.get('seedLen', 12)) 30 | # debugging 1 - send exact midi back 31 | # with open('/tmp/test.mid', 'wb') as f: 32 | # f.write(midi) 33 | # return send_from_directory('/tmp', 'test.mid', mimetype='audio/midi') 34 | 35 | # debugging 2 - test music21 conversion 36 | # stream = file2stream(midi) # 1. 37 | 38 | # debugging 3 - test npenc conversion 39 | # seed_np = midi2npenc(midi) # music21 can handle bytes directly 40 | # stream = npenc2stream(seed_np, bpm=bpm) 41 | 42 | # debugging 4 - midi in, convert, midi out 43 | # stream = file2stream(midi) # 1. 44 | # midi_in = Path(stream.write("musicxml")) 45 | # print('Midi in:', midi_in) 46 | # stream_sep = separate_melody_chord(stream) 47 | # midi_out = Path(stream_sep.write("midi")) 48 | # print('Midi out:', midi_out) 49 | # s3_id = to_s3(midi_out, args) 50 | # result = { 51 | # 'result': s3_id 52 | # } 53 | # return jsonify(result) 54 | 55 | # Main logic 56 | try: 57 | full = predict_from_midi(learn, midi=midi, n_words=n_words, seed_len=seed_len, temperatures=temperatures) 58 | stream = separate_melody_chord(full.to_stream(bpm=bpm)) 59 | midi_out = Path(stream.write("midi")) 60 | print('Wrote to temporary file:', midi_out) 61 | except Exception as e: 62 | traceback.print_exc() 63 | return jsonify({'error': f'Failed to predict: {e}'}) 64 | 65 | s3_id = to_s3(midi_out, args) 66 | result = { 67 | 'result': s3_id 68 | } 69 | return jsonify(result) 70 | 71 | # return send_from_directory(midi_out.parent, midi_out.name, mimetype='audio/midi') 72 | 73 | # @app.route('/midi/song/') 74 | # def get_song_midi(sid): 75 | # return send_from_directory(file_path/data_dir, htlist[sid]['midi'], mimetype='audio/midi') 76 | 77 | @app.route('/midi/convert', methods=['POST']) 78 | def convert_midi(): 79 | args = request.form.to_dict() 80 | if 'midi' in request.files: 81 | midi = request.files['midi'].read() 82 | elif 'midi_path'in args: 83 | midi = args['midi_path'] 84 | 85 | stream = file2stream(midi) # 1. 86 | # stream = file2stream(midi).chordify() # 1. 87 | stream_out = Path(stream.write('musicxml')) 88 | return send_from_directory(stream_out.parent, stream_out.name, mimetype='xml') 89 | 90 | -------------------------------------------------------------------------------- /serve/api/predict_multitask.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | from . 
import app 3 | sys.path.append(str(app.config['LIB_PATH'])) 4 | 5 | from musicautobot.multitask_transformer import * 6 | from musicautobot.music_transformer import * 7 | from musicautobot.config import * 8 | from flask import Response, send_from_directory, send_file, request, jsonify 9 | 10 | from .save import to_s3 11 | 12 | import torch 13 | import traceback 14 | torch.set_num_threads(4) 15 | 16 | data = load_data(app.config['DATA_PATH'], app.config['DATA_SAVE_NAME'], num_workers=1) 17 | learn = multitask_model_learner(data, pretrained_path=app.config['MULTITASK_MODEL_PATH']) 18 | 19 | if torch.cuda.is_available(): learn.model.cuda() 20 | 21 | 22 | @app.route('/predict/midi', methods=['POST']) 23 | def predict_midi(): 24 | args = request.form.to_dict() 25 | midi = request.files['midi'].read() 26 | print('Prediction Args:', args) 27 | 28 | # Universal parameters 29 | bpm = float(args['bpm']) # (AS) TODO: get bpm from midi file instead 30 | prediction_type = args.get('predictionType', 'next') 31 | temperatures = (float(args.get('noteTemp', 1.2)), float(args.get('durationTemp', 0.8))) 32 | top_k, top_p = (int(args.get('topK', 20)), float(args.get('topP', 0.9))) 33 | 34 | # Parameters for NextSeq and Melody/Chords 35 | n_words = int(args.get('nSteps', 200)) 36 | seed_len = int(args.get('seedLen', 12)) 37 | 38 | # Parameters for Masking 39 | mask_start, mask_end = None, None 40 | try: 41 | mask_start = int(args['maskStart']) 42 | mask_end = int(args['maskEnd']) 43 | except: pass 44 | 45 | # Main logic 46 | try: 47 | if prediction_type == 'next': 48 | full = nw_predict_from_midi(learn, midi=midi, n_words=n_words, seed_len=seed_len, temperatures=temperatures, top_k=top_k, top_p=top_p) 49 | stream = separate_melody_chord(full.to_stream(bpm=bpm)) 50 | elif prediction_type in ['melody', 'chords']: 51 | full = s2s_predict_from_midi(learn, midi=midi, n_words=n_words, temperatures=temperatures, seed_len=seed_len, 52 | pred_melody=(prediction_type == 'melody'), use_memory=True, top_k=top_k, top_p=top_p) 53 | stream = full.to_stream(bpm=bpm) 54 | elif prediction_type in ['pitch', 'rhythm']: 55 | full = mask_predict_from_midi(learn, midi=midi, temperatures=temperatures, predict_notes=(prediction_type == 'pitch'), section=(mask_start, mask_end), top_k=top_k, top_p=top_p) 56 | stream = separate_melody_chord(full.to_stream(bpm=bpm)) 57 | midi_out = Path(stream.write("midi")) 58 | print('Wrote to temporary file:', midi_out) 59 | except Exception as e: 60 | traceback.print_exc() 61 | return jsonify({'error': f'Failed to predict: {e}'}) 62 | 63 | s3_id = to_s3(midi_out, args) 64 | result = { 65 | 'result': s3_id 66 | } 67 | return jsonify(result) 68 | # return send_from_directory(midi_out.parent, midi_out.name, mimetype='audio/midi') 69 | 70 | @app.route('/midi/convert', methods=['POST']) 71 | def convert_midi(): 72 | args = request.form.to_dict() 73 | if 'midi' in request.files: 74 | midi = request.files['midi'].read() 75 | elif 'midi_path'in args: 76 | midi = args['midi_path'] 77 | 78 | stream = file2stream(midi) # 1. 79 | # stream = file2stream(midi).chordify() # 1. 
80 | stream_out = Path(stream.write('musicxml')) 81 | return send_from_directory(stream_out.parent, stream_out.name, mimetype='xml') -------------------------------------------------------------------------------- /musicautobot/vocab.py: -------------------------------------------------------------------------------- 1 | from fastai.basics import * 2 | from .numpy_encode import * 3 | from .music_transformer import transform 4 | 5 | BOS = 'xxbos' 6 | PAD = 'xxpad' 7 | EOS = 'xxeos' 8 | MASK = 'xxmask' # Used for BERT masked language modeling. 9 | CSEQ = 'xxcseq' # Used for Seq2Seq translation - denotes start of chord sequence 10 | MSEQ = 'xxmseq' # Used for Seq2Seq translation - denotes start of melody sequence 11 | 12 | # Deprecated tokens. Kept for compatibility 13 | S2SCLS = 'xxs2scls' # deprecated 14 | NSCLS = 'xxnscls' # deprecated 15 | 16 | SEP = 'xxsep' # Used to denote end of timestep (required for polyphony). separator idx = -1 (part of notes) 17 | 18 | SPECIAL_TOKS = [BOS, PAD, EOS, S2SCLS, MASK, CSEQ, MSEQ, NSCLS, SEP] # Important: SEP token must be last 19 | 20 | NOTE_TOKS = [f'n{i}' for i in range(NOTE_SIZE)] 21 | DUR_TOKS = [f'd{i}' for i in range(DUR_SIZE)] 22 | NOTE_START, NOTE_END = NOTE_TOKS[0], NOTE_TOKS[-1] 23 | DUR_START, DUR_END = DUR_TOKS[0], DUR_TOKS[-1] 24 | 25 | MTEMPO_SIZE = 10 26 | MTEMPO_OFF = 'mt0' 27 | MTEMPO_TOKS = [f'mt{i}' for i in range(MTEMPO_SIZE)] 28 | 29 | # Vocab - token to index mapping 30 | class MusicVocab(): 31 | "Contain the correspondence between numbers and tokens and numericalize." 32 | def __init__(self, itos:Collection[str]): 33 | self.itos = itos 34 | self.stoi = {v:k for k,v in enumerate(self.itos)} 35 | 36 | def numericalize(self, t:Collection[str]) -> List[int]: 37 | "Convert a list of tokens `t` to their ids." 38 | return [self.stoi[w] for w in t] 39 | 40 | def textify(self, nums:Collection[int], sep=' ') -> List[str]: 41 | "Convert a list of `nums` to their tokens." 42 | items = [self.itos[i] for i in nums] 43 | return sep.join(items) if sep is not None else items 44 | 45 | def to_music_item(self, idxenc): 46 | return transform.MusicItem(idxenc, self) 47 | 48 | @property 49 | def mask_idx(self): return self.stoi[MASK] 50 | @property 51 | def pad_idx(self): return self.stoi[PAD] 52 | @property 53 | def bos_idx(self): return self.stoi[BOS] 54 | @property 55 | def sep_idx(self): return self.stoi[SEP] 56 | @property 57 | def npenc_range(self): return (self.stoi[SEP], self.stoi[DUR_END]+1) 58 | @property 59 | def note_range(self): return self.stoi[NOTE_START], self.stoi[NOTE_END]+1 60 | @property 61 | def dur_range(self): return self.stoi[DUR_START], self.stoi[DUR_END]+1 62 | 63 | def is_duration(self, idx): 64 | return idx >= self.dur_range[0] and idx < self.dur_range[1] 65 | def is_duration_or_pad(self, idx): 66 | return idx == self.pad_idx or self.is_duration(idx) 67 | 68 | def __getstate__(self): 69 | return {'itos':self.itos} 70 | 71 | def __setstate__(self, state:dict): 72 | self.itos = state['itos'] 73 | self.stoi = {v:k for k,v in enumerate(self.itos)} 74 | 75 | def __len__(self): return len(self.itos) 76 | 77 | def save(self, path): 78 | "Save `self.itos` in `path`" 79 | pickle.dump(self.itos, open(path, 'wb')) 80 | 81 | @classmethod 82 | def create(cls) -> 'Vocab': 83 | "Create a vocabulary from a set of `tokens`." 
84 | itos = SPECIAL_TOKS + NOTE_TOKS + DUR_TOKS + MTEMPO_TOKS 85 | if len(itos)%8 != 0: 86 | itos = itos + [f'dummy{i}' for i in range(len(itos)%8)] 87 | return cls(itos) 88 | 89 | @classmethod 90 | def load(cls, path): 91 | "Load the `Vocab` contained in `path`" 92 | itos = pickle.load(open(path, 'rb')) 93 | return cls(itos) 94 | -------------------------------------------------------------------------------- /scripts/run_music_transformer.py: -------------------------------------------------------------------------------- 1 | 2 | import music21 3 | import torch 4 | import numpy as np 5 | try: from apex.optimizers import FusedAdam 6 | except: from torch.optim import Adam as FusedAdam 7 | 8 | from fastai.distributed import * 9 | from fastai.callbacks import SaveModelCallback 10 | from fastai.text.models.transformer import * 11 | 12 | 13 | import sys 14 | sys.path.insert(0, '..') 15 | 16 | from musicautobot.music_transformer import * 17 | 18 | 19 | import argparse 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--path', type=str, default='../data/numpy/') 22 | parser.add_argument('--data_file', type=str, default='musicitem_data_save.pkl') 23 | parser.add_argument('--save', type=str, default='first_run') 24 | parser.add_argument('--load', type=str, default=None) 25 | parser.add_argument("--local_rank", type=int, default=0) 26 | parser.add_argument("--batch_size", type=int, default=12) 27 | parser.add_argument("--mem_len", type=int, default=512) 28 | parser.add_argument("--bptt", type=int, default=512) 29 | parser.add_argument("--num_workers", type=int, default=1) 30 | parser.add_argument('--half', action='store_true', help='Use half precision') 31 | parser.add_argument('--lamb', action='store_true', help='Use lamb optimizer') 32 | parser.add_argument('--wd', type=float, default=1e-3, help='weight decay for adam') 33 | parser.add_argument('--epochs', type=int, default=5, help='num epochs') 34 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 35 | parser.add_argument('--div_factor', type=int, default=10, help='learning rate div factor') 36 | parser.add_argument('--config', type=str, default='default_config', help='serve.py config name') 37 | parser.add_argument('--no_transpose', action='store_true', help='No transpose data augmentation') 38 | parser.add_argument('--parallel', action='store_true', help='Run in dataparallel') 39 | parser.add_argument('--mask_steps', type=int, default=1, help='Attention mask - max number of random steps. 
Basically teacher forcing') 40 | 41 | args = parser.parse_args() 42 | is_distributed = num_distrib() > 0 43 | if args.local_rank != 0: 44 | f = open('/dev/null', 'w') 45 | sys.stdout = f 46 | 47 | if is_distributed: 48 | torch.cuda.set_device(args.local_rank) 49 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 50 | 51 | 52 | path = Path(args.path) 53 | 54 | from musicautobot import config 55 | config = getattr(config, args.config)() 56 | config['encode_position'] = True 57 | config['mask_steps'] = args.mask_steps 58 | 59 | transpose_range = None if args.no_transpose else (0,12) 60 | data = load_data(path, args.data_file, encode_position=config['encode_position'], dl_tfms=[batch_position_tfm], 61 | bs=args.batch_size, bptt=args.bptt, transpose_range=transpose_range, num_workers=args.num_workers) 62 | 63 | eps = 1e-2 if args.half else 1e-6 64 | opt_func = partial(FusedAdam, betas=(0.9,0.99), eps=eps) 65 | if args.lamb: 66 | from musicautobot.utils.lamb import Lamb 67 | opt_func = partial(Lamb, eps=eps) 68 | 69 | load_path = path/args.load if args.load else None 70 | learn = music_model_learner(data, config=config, drop_mult=1.5, opt_func=opt_func, pretrained_path=load_path) 71 | if not args.half: learn.clip_grad(1.0) 72 | 73 | if args.save: 74 | save_path = path/learn.model_dir/args.save 75 | save_path.parent.mkdir(parents=True, exist_ok=True) 76 | if args.half: learn = learn.to_fp16(clip=1.0, dynamic=True, max_scale=2**18) 77 | if is_distributed: learn = learn.to_distributed(args.local_rank, cache_dir=path/'dist_logs') 78 | if args.parallel: learn = learn.to_parallel() 79 | if args.local_rank == 0: learn.callbacks.append(SaveModelCallback(learn, name=f'{args.save}_best')) 80 | 81 | learn.fit_one_cycle(args.epochs, args.lr, div_factor=args.div_factor, pct_start=0.2, final_div=200, wd=args.wd) 82 | 83 | if args.local_rank == 0: learn.save(f'{args.save}', config=config) 84 | -------------------------------------------------------------------------------- /musicautobot/utils/midifile.py: -------------------------------------------------------------------------------- 1 | "Transform functions for raw midi files" 2 | from enum import Enum 3 | import music21 4 | 5 | PIANO_TYPES = list(range(24)) + list(range(80, 96)) # Piano, Synths 6 | PLUCK_TYPES = list(range(24, 40)) + list(range(104, 112)) # Guitar, Bass, Ethnic 7 | BRIGHT_TYPES = list(range(40, 56)) + list(range(56, 80)) 8 | 9 | PIANO_RANGE = (21, 109) # https://en.wikipedia.org/wiki/Scientific_pitch_notation 10 | 11 | class Track(Enum): 12 | PIANO = 0 # discrete instruments - keyboard, woodwinds 13 | PLUCK = 1 # continuous instruments with pitch bend: violin, trombone, synths 14 | BRIGHT = 2 15 | PERC = 3 16 | UNDEF = 4 17 | 18 | type2inst = { 19 | # use print_music21_instruments() to see supported types 20 | Track.PIANO: 0, # Piano 21 | Track.PLUCK: 24, # Guitar 22 | Track.BRIGHT: 40, # Violin 23 | Track.PERC: 114, # Steel Drum 24 | } 25 | 26 | # INFO_TYPES = set(['TIME_SIGNATURE', 'KEY_SIGNATURE']) 27 | INFO_TYPES = set(['TIME_SIGNATURE', 'KEY_SIGNATURE', 'SET_TEMPO']) 28 | 29 | def file2mf(fp): 30 | mf = music21.midi.MidiFile() 31 | if isinstance(fp, bytes): 32 | mf.readstr(fp) 33 | else: 34 | mf.open(fp) 35 | mf.read() 36 | mf.close() 37 | return mf 38 | 39 | def mf2stream(mf): return music21.midi.translate.midiFileToStream(mf) 40 | 41 | def is_empty_midi(fp): 42 | if fp is None: return False 43 | mf = file2mf(fp) 44 | return not any([t.hasNotes() for t in mf.tracks]) 45 | 46 | def num_piano_tracks(fp): 47 | 
music_file = file2mf(fp) 48 | note_tracks = [t for t in music_file.tracks if t.hasNotes() and get_track_type(t) == Track.PIANO] 49 | return len(note_tracks) 50 | 51 | def is_channel(t, c_val): 52 | return any([c == c_val for c in t.getChannels()]) 53 | 54 | def track_sort(t): # sort by 1. variation of pitch, 2. number of notes 55 | return len(unique_track_notes(t)), len(t.events) 56 | 57 | def is_piano_note(pitch): 58 | return (pitch >= PIANO_RANGE[0]) and (pitch < PIANO_RANGE[1]) 59 | 60 | def unique_track_notes(t): 61 | return { e.pitch for e in t.events if e.pitch is not None } 62 | 63 | def compress_midi_file(fp, cutoff=6, min_variation=3, supported_types=set([Track.PIANO, Track.PLUCK, Track.BRIGHT])): 64 | music_file = file2mf(fp) 65 | 66 | info_tracks = [t for t in music_file.tracks if not t.hasNotes()] 67 | note_tracks = [t for t in music_file.tracks if t.hasNotes()] 68 | 69 | if len(note_tracks) > cutoff: 70 | note_tracks = sorted(note_tracks, key=track_sort, reverse=True) 71 | 72 | supported_tracks = [] 73 | for idx,t in enumerate(note_tracks): 74 | if len(supported_tracks) >= cutoff: break 75 | track_type = get_track_type(t) 76 | if track_type not in supported_types: continue 77 | pitch_set = unique_track_notes(t) 78 | if (len(pitch_set) < min_variation): continue # must have more than x unique notes 79 | if not all(map(is_piano_note, pitch_set)): continue # must not contain midi notes outside of piano range 80 | # if track_type == Track.UNDEF: print('Could not designate track:', fp, t) 81 | change_track_instrument(t, type2inst[track_type]) 82 | supported_tracks.append(t) 83 | if not supported_tracks: return None 84 | music_file.tracks = info_tracks + supported_tracks 85 | return music_file 86 | 87 | def get_track_type(t): 88 | if is_channel(t, 10): return Track.PERC 89 | i = get_track_instrument(t) 90 | if i in PIANO_TYPES: return Track.PIANO 91 | if i in PLUCK_TYPES: return Track.PLUCK 92 | if i in BRIGHT_TYPES: return Track.BRIGHT 93 | return Track.UNDEF 94 | 95 | def get_track_instrument(t): 96 | for idx,e in enumerate(t.events): 97 | if e.type == 'PROGRAM_CHANGE': return e.data 98 | return None 99 | 100 | def change_track_instrument(t, value): 101 | for idx,e in enumerate(t.events): 102 | if e.type == 'PROGRAM_CHANGE': e.data = value 103 | 104 | def print_music21_instruments(): 105 | for i in range(200): 106 | try: print(i, music21.instrument.instrumentFromMidiProgram(i)) 107 | except: pass -------------------------------------------------------------------------------- /musicautobot/utils/lamb.py: -------------------------------------------------------------------------------- 1 | # SOURCE: https://github.com/cybertronai/pytorch-lamb/ 2 | 3 | import collections 4 | import math 5 | 6 | import torch 7 | from torch.optim import Optimizer 8 | 9 | 10 | class Lamb(Optimizer): 11 | r"""Implements Lamb algorithm. 12 | 13 | It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`_. 
14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 24 | adam (bool, optional): always use trust ratio = 1, which turns this into 25 | Adam. Useful for comparison purposes. 26 | 27 | .. _Reducing BERT Pre-Training Time from 3 Days to 76 Minutes: 28 | https://arxiv.org/abs/1904.00962 29 | """ 30 | 31 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-4, 32 | weight_decay=0, adam=False): 33 | if not 0.0 <= lr: 34 | raise ValueError("Invalid learning rate: {}".format(lr)) 35 | if not 0.0 <= eps: 36 | raise ValueError("Invalid epsilon value: {}".format(eps)) 37 | if not 0.0 <= betas[0] < 1.0: 38 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 39 | if not 0.0 <= betas[1] < 1.0: 40 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 41 | defaults = dict(lr=lr, betas=betas, eps=eps, 42 | weight_decay=weight_decay) 43 | self.adam = adam 44 | super(Lamb, self).__init__(params, defaults) 45 | 46 | def step(self, closure=None): 47 | """Performs a single optimization step. 48 | 49 | Arguments: 50 | closure (callable, optional): A closure that reevaluates the model 51 | and returns the loss. 52 | """ 53 | loss = None 54 | if closure is not None: 55 | loss = closure() 56 | 57 | for group in self.param_groups: 58 | for p in group['params']: 59 | if p.grad is None: 60 | continue 61 | grad = p.grad.data 62 | if grad.is_sparse: 63 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 64 | 65 | state = self.state[p] 66 | 67 | # State initialization 68 | if len(state) == 0: 69 | state['step'] = 0 70 | # Exponential moving average of gradient values 71 | state['exp_avg'] = torch.zeros_like(p.data) 72 | # Exponential moving average of squared gradient values 73 | state['exp_avg_sq'] = torch.zeros_like(p.data) 74 | 75 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 76 | beta1, beta2 = group['betas'] 77 | 78 | state['step'] += 1 79 | 80 | if group['weight_decay'] != 0: 81 | grad.add_(group['weight_decay'], p.data) 82 | 83 | # Decay the first and second moment running average coefficient 84 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 85 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 86 | denom = exp_avg_sq.sqrt().add_(group['eps']) 87 | 88 | bias_correction1 = 1 - beta1 ** state['step'] 89 | bias_correction2 = 1 - beta2 ** state['step'] 90 | # Apply bias to lr to avoid broadcast. 91 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 92 | 93 | adam_step = exp_avg / denom 94 | # L2 norm uses sum, but here since we're dividing, use mean to avoid overflow. 
95 | r1 = p.data.pow(2).mean().sqrt() 96 | r2 = adam_step.pow(2).mean().sqrt() 97 | r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10) 98 | state['r1'] = r1 99 | state['r2'] = r2 100 | state['r'] = r 101 | if self.adam: 102 | r = 1 103 | 104 | p.data.add_(-step_size * r, adam_step) 105 | 106 | return loss 107 | -------------------------------------------------------------------------------- /scripts/run_multitask.py: -------------------------------------------------------------------------------- 1 | import music21 2 | import torch 3 | 4 | from fastai.distributed import * 5 | from fastai.callbacks import SaveModelCallback 6 | try: from apex.optimizers import FusedAdam 7 | except: from torch.optim import Adam as FusedAdam 8 | 9 | import numpy as np 10 | 11 | import sys 12 | sys.path.insert(0, '..') 13 | 14 | from musicautobot.music_transformer import * 15 | from musicautobot.multitask_transformer import * 16 | from musicautobot.utils.stacked_dataloader import StackedDataBunch 17 | 18 | import argparse 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--path', type=str, default='../data/numpy/') 21 | parser.add_argument('--data_file', type=str, default='musicitem_data_save.pkl') 22 | parser.add_argument('--s2s_data_file', type=str, default='multiitem_data_save.pkl') 23 | parser.add_argument('--save', type=str, default='first_run') 24 | parser.add_argument('--load', type=str, default=None) 25 | parser.add_argument("--local_rank", type=int, default=0) 26 | parser.add_argument("--batch_size", type=int, default=4) 27 | parser.add_argument("--num_workers", type=int, default=12) 28 | parser.add_argument("--bptt", type=int, default=1024) 29 | parser.add_argument('--half', action='store_true', help='Use half precision') 30 | parser.add_argument('--lamb', action='store_true', help='Use lamb optimizer') 31 | parser.add_argument('--wd', type=float, default=1e-3, help='weight decay for adam') 32 | parser.add_argument('--epochs', type=int, default=5, help='num epochs') 33 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 34 | parser.add_argument('--div_factor', type=int, default=10, help='learning rate div factor') 35 | parser.add_argument('--save_every', action='store_true', help='Save every epoch') 36 | parser.add_argument('--config', type=str, default='multitask_config', help='serve.py config name') 37 | parser.add_argument('--no_transpose', action='store_true', help='No transpose data augmentation') 38 | parser.add_argument('--data_parallel', action='store_true', help='DataParallel instead of DDP') 39 | parser.add_argument('--mask_steps', type=int, default=1, help='Attention mask - max number of random steps. 
Basically teacher forcing') 40 | parser.add_argument('--mask_pitchdur', action='store_true', help='Mask either pitch or duration') 41 | 42 | args = parser.parse_args() 43 | args.path = Path(args.path) 44 | 45 | 46 | if args.local_rank != 0: 47 | f = open('/dev/null', 'w') 48 | sys.stdout = f 49 | 50 | # if is_distributed: 51 | # torch.cuda.set_device(args.local_rank) 52 | # torch.distributed.init_process_group(backend='nccl', init_method='env://') 53 | is_distributed = num_distrib() > 0 54 | setup_distrib(args.local_rank) 55 | 56 | path = Path(args.path) 57 | 58 | from musicautobot import config 59 | config = getattr(config, args.config)() 60 | config['mask_steps'] = args.mask_steps 61 | 62 | 63 | datasets = [] 64 | transpose_range = None if args.no_transpose else (0,12) 65 | 66 | mlm_tfm = mask_lm_tfm_pitchdur if args.mask_pitchdur else partial(mask_lm_tfm_default, mask_p=0.4) 67 | data = load_data(args.path, Path('piano_duet')/args.data_file, 68 | bs=args.batch_size, bptt=args.bptt, transpose_range=transpose_range, 69 | dl_tfms=mlm_tfm, num_workers=args.num_workers) 70 | 71 | datasets.append(data) 72 | 73 | s2s_data = load_data(args.path, Path('s2s_encode')/args.data_file, 74 | bs=args.batch_size//4, bptt=args.bptt, transpose_range=transpose_range, 75 | preloader_cls=S2SPreloader, dl_tfms=melody_chord_tfm, num_workers=args.num_workers) 76 | 77 | datasets.append(s2s_data) 78 | 79 | combined_data = StackedDataBunch(datasets) 80 | 81 | # Load Optimizer 82 | eps = 1e-2 if args.half else 1e-6 83 | opt_func = partial(FusedAdam, betas=(0.9,0.99), eps=eps) 84 | if args.lamb: 85 | from musicautobot.utils.lamb import Lamb 86 | opt_func = partial(Lamb, eps=eps) 87 | 88 | # Load Learner 89 | load_path = path/args.load if args.load else None 90 | learn = multitask_model_learner(combined_data, config.copy(), opt_func=opt_func, pretrained_path=load_path) 91 | 92 | if not args.half: learn.clip_grad(1.0) 93 | if args.save: 94 | save_path = path/learn.model_dir/args.save 95 | save_path.parent.mkdir(parents=True, exist_ok=True) 96 | if args.half: learn = learn.to_fp16(clip=1.0, dynamic=True, max_scale=2**18) 97 | if is_distributed: learn = learn.to_distributed(args.local_rank, cache_dir=path/'dist_logs') 98 | if args.data_parallel: learn = learn.to_parallel() 99 | if args.local_rank == 0: learn.callbacks.append(SaveModelCallback(learn, name=f'{args.save}_best')) 100 | 101 | learn.fit_one_cycle(args.epochs, args.lr, div_factor=args.div_factor, pct_start=.3, final_div=50, wd=args.wd) 102 | 103 | if args.local_rank == 0: learn.save(f'{args.save}', config=config) 104 | -------------------------------------------------------------------------------- /notebooks/music_transformer/Train-Simple.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%matplotlib inline" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "os.chdir('../../')" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "Warning: Could not find musescore installation. 
Please install musescore (see README) and/or update music21 environment paths\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "from musicautobot.numpy_encode import *\n", 39 | "from musicautobot.utils.file_processing import process_all, process_file\n", 40 | "from musicautobot.config import *\n", 41 | "from musicautobot.music_transformer import *" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 4, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# Location of your midi filesfiles\n", 51 | "midi_path = Path('data/midi/examples')\n", 52 | "data_path = Path('data/numpy')\n", 53 | "data_save_name = 'musicitem_data_save.pkl'" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 5, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "midi_files = get_files(midi_path, '.mid', recurse=True)\n", 63 | "data = MusicDataBunch.from_files(midi_files, data_path, processors=[Midi2ItemProcessor()], bs=4, bptt=128, encode_position=False)\n", 64 | "\n", 65 | "learn = music_model_learner(data, arch=TransformerXL, config=default_config())" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 8, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "import warnings\n", 75 | "warnings.simplefilter(\"ignore\", UserWarning)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 9, 81 | "metadata": { 82 | "scrolled": false 83 | }, 84 | "outputs": [ 85 | { 86 | "data": { 87 | "text/html": [ 88 | "\n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | "
epoch   train_loss   valid_loss   accuracy   time
0       2.488876     2.262955     0.536133   00:01
1       2.493775     2.646077     0.209961   00:01
2       2.507915     2.296916     0.536133   00:01
3       2.485955     2.183362     0.536133   00:01
" 129 | ], 130 | "text/plain": [ 131 | "" 132 | ] 133 | }, 134 | "metadata": {}, 135 | "output_type": "display_data" 136 | } 137 | ], 138 | "source": [ 139 | "learn.fit_one_cycle(4)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "## 5. Predict" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "midi_file = Path('data/midi/notebook_examples/single_bar_example.mid')\n", 156 | "item = MusicItem.from_file(midi_file, data.vocab);\n", 157 | "pred = learn.predict(item, n_words=100)\n", 158 | "pred.show()" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "pred.play()" 168 | ] 169 | } 170 | ], 171 | "metadata": { 172 | "kernelspec": { 173 | "display_name": "Python 3", 174 | "language": "python", 175 | "name": "python3" 176 | }, 177 | "language_info": { 178 | "codemirror_mode": { 179 | "name": "ipython", 180 | "version": 3 181 | }, 182 | "file_extension": ".py", 183 | "mimetype": "text/x-python", 184 | "name": "python", 185 | "nbconvert_exporter": "python", 186 | "pygments_lexer": "ipython3", 187 | "version": "3.7.6" 188 | } 189 | }, 190 | "nbformat": 4, 191 | "nbformat_minor": 2 192 | } 193 | -------------------------------------------------------------------------------- /musicautobot/multitask_transformer/dataloader.py: -------------------------------------------------------------------------------- 1 | from fastai.basics import * 2 | from .transform import * 3 | from ..music_transformer.dataloader import MusicDataBunch, MusicItemList 4 | # Sequence 2 Sequence Translate 5 | 6 | class S2SFileProcessor(PreProcessor): 7 | "`PreProcessor` that opens the filenames and read the texts." 8 | def process_one(self,item): 9 | out = np.load(item, allow_pickle=True) 10 | if out.shape != (2,): return None 11 | if not 16 < len(out[0]) < 2048: return None 12 | if not 16 < len(out[1]) < 2048: return None 13 | return out 14 | 15 | def process(self, ds:Collection): 16 | ds.items = [self.process_one(item) for item in ds.items] 17 | ds.items = [i for i in ds.items if i is not None] # filter out None 18 | 19 | class S2SPartsProcessor(PreProcessor): 20 | "Encodes midi file into 2 separate parts - melody and chords." 
21 | 22 | def process_one(self, item): 23 | m, c = item 24 | mtrack = MultitrackItem.from_npenc_parts(m, c, vocab=self.vocab) 25 | return mtrack.to_idx() 26 | 27 | def process(self, ds): 28 | self.vocab = ds.vocab 29 | ds.items = [self.process_one(item) for item in ds.items] 30 | 31 | class Midi2MultitrackProcessor(PreProcessor): 32 | "Converts midi files to multitrack items" 33 | def process_one(self, midi_file): 34 | try: 35 | item = MultitrackItem.from_file(midi_file, vocab=self.vocab) 36 | except Exception as e: 37 | print(e) 38 | return None 39 | return item.to_idx() 40 | 41 | def process(self, ds): 42 | self.vocab = ds.vocab 43 | ds.items = [self.process_one(item) for item in ds.items] 44 | ds.items = [i for i in ds.items if i is not None] 45 | 46 | class S2SPreloader(Callback): 47 | def __init__(self, dataset:LabelList, bptt:int=512, 48 | transpose_range=None, **kwargs): 49 | self.dataset,self.bptt = dataset,bptt 50 | self.vocab = self.dataset.vocab 51 | self.transpose_range = transpose_range 52 | self.rand_transpose = partial(rand_transpose_value, rand_range=transpose_range) if transpose_range is not None else None 53 | 54 | def __getitem__(self, k:int): 55 | item,empty_label = self.dataset[k] 56 | 57 | if self.rand_transpose is not None: 58 | val = self.rand_transpose() 59 | item = item.transpose(val) 60 | item = item.pad_to(self.bptt+1) 61 | ((m_x, m_pos), (c_x, c_pos)) = item.to_idx() 62 | return m_x, m_pos, c_x, c_pos 63 | 64 | def __len__(self): 65 | return len(self.dataset) 66 | 67 | def rand_transpose_value(rand_range=(0,24), p=0.5): 68 | if np.random.rand() < p: return np.random.randint(*rand_range)-rand_range[1]//2 69 | return 0 70 | 71 | class S2SItemList(MusicItemList): 72 | _bunch = MusicDataBunch 73 | def get(self, i): 74 | return MultitrackItem.from_idx(self.items[i], self.vocab) 75 | 76 | # DATALOADING AND TRANSFORMATIONS 77 | # These transforms happen on batch 78 | 79 | def mask_tfm(b, mask_range, mask_idx, pad_idx, p=0.3): 80 | # mask range (min, max) 81 | # replacement vals - [x_replace, y_replace]. Usually [mask_idx, pad_idx] 82 | # p = replacement probability 83 | x,y = b 84 | x,y = x.clone(),y.clone() 85 | rand = torch.rand(x.shape, device=x.device) 86 | rand[x < mask_range[0]] = 1.0 87 | rand[x >= mask_range[1]] = 1.0 88 | 89 | # p(15%) of words are replaced. Of those p(15%) - 80% are masked. 10% wrong word. 10% unchanged 90 | y[rand > p] = pad_idx # pad unchanged 80%. Remove these from loss/acc metrics 91 | x[rand <= (p*.8)] = mask_idx # 80% = mask 92 | wrong_word = (rand > (p*.8)) & (rand <= (p*.9)) # 10% = wrong word 93 | x[wrong_word] = torch.randint(*mask_range, [wrong_word.sum().item()], device=x.device) 94 | return x, y 95 | 96 | def mask_lm_tfm_default(b, vocab, mask_p=0.3): 97 | return mask_lm_tfm(b, mask_range=vocab.npenc_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p) 98 | 99 | def mask_lm_tfm_pitchdur(b, vocab, mask_p=0.9): 100 | mask_range = vocab.dur_range if np.random.rand() < 0.5 else vocab.note_range 101 | return mask_lm_tfm(b, mask_range=mask_range, mask_idx=vocab.mask_idx, pad_idx=vocab.pad_idx, mask_p=mask_p) 102 | 103 | def mask_lm_tfm(b, mask_range, mask_idx, pad_idx, mask_p): 104 | x,y = b 105 | x_lm,x_pos = x[...,0], x[...,1] 106 | y_lm,y_pos = y[...,0], y[...,1] 107 | 108 | # Note: masking y_lm instead of x_lm. 
Just in case we ever do sequential s2s training 109 | x_msk, y_msk = mask_tfm((y_lm, y_lm), mask_range=mask_range, mask_idx=mask_idx, pad_idx=pad_idx, p=mask_p) 110 | msk_pos = y_pos 111 | 112 | x_dict = { 113 | 'msk': { 'x': x_msk, 'pos': msk_pos }, 114 | 'lm': { 'x': x_lm, 'pos': msk_pos } 115 | } 116 | y_dict = { 'msk': y_msk, 'lm': y_lm } 117 | return x_dict, y_dict 118 | 119 | def melody_chord_tfm(b): 120 | m,m_pos,c,c_pos = b 121 | 122 | # offset x and y for next word prediction 123 | y_m = m[:,1:] 124 | x_m, m_pos = m[:,:-1], m_pos[:,:-1] 125 | 126 | y_c = c[:,1:] 127 | x_c, c_pos = c[:,:-1], c_pos[:,:-1] 128 | 129 | x_dict = { 130 | 'c2m': { 131 | 'enc': x_c, 132 | 'enc_pos': c_pos, 133 | 'dec': x_m, 134 | 'dec_pos': m_pos 135 | }, 136 | 'm2c': { 137 | 'enc': x_m, 138 | 'enc_pos': m_pos, 139 | 'dec': x_c, 140 | 'dec_pos': c_pos 141 | } 142 | } 143 | y_dict = { 144 | 'c2m': y_m, 'm2c': y_c 145 | } 146 | return x_dict, y_dict 147 | -------------------------------------------------------------------------------- /notebooks/music_transformer/Train-Advanced.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%matplotlib inline" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "os.chdir('../../')" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from musicautobot.numpy_encode import *\n", 31 | "from musicautobot.utils.file_processing import process_all, process_file\n", 32 | "from musicautobot.config import *\n", 33 | "from musicautobot.music_transformer import *" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 4, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Location of your midi filesfiles\n", 43 | "midi_path = Path('data/midi/examples')\n", 44 | "data_path = Path('data/numpy')\n", 45 | "data_save_name = 'musicitem_data_save.pkl'" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 5, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "midi_files = get_files(midi_path, '.mid', recurse=True)\n", 55 | "data = MusicDataBunch.from_files(midi_files, data_path, processors=[Midi2ItemProcessor()], bs=4, bptt=128,\n", 56 | " encode_position=True, dl_tfms=[batch_position_tfm])" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 6, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "config = default_config()\n", 66 | "config['encode_position'] = True\n", 67 | "config['transpose_range'] = (0, 12)\n", 68 | "config['mask_steps'] = 4" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 7, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "model = get_language_model(arch=MusicTransformerXL, vocab_sz=len(data.vocab), config=config.copy())" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 8, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "learn = MusicLearner(data, model)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 9, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "learn.to_fp16(dynamic=True, clip=0.5);" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | 
"execution_count": 10, 101 | "metadata": { 102 | "scrolled": false 103 | }, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/html": [ 108 | "\n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | "
epoch   train_loss   valid_loss   accuracy   time
0       3.891204     2.633192     0.375000   00:02
1       3.757607     2.668247     0.386719   00:01
2       3.506466     2.503491     0.125000   00:01
3       3.403756     2.606918     0.125000   00:01
" 149 | ], 150 | "text/plain": [ 151 | "" 152 | ] 153 | }, 154 | "metadata": {}, 155 | "output_type": "display_data" 156 | } 157 | ], 158 | "source": [ 159 | "learn.fit_one_cycle(4)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "## 5. Predict" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "midi_file = Path('data/midi/notebook_examples/single_bar_example.mid')\n", 176 | "item = MusicItem.from_file(midi_file, data.vocab);\n", 177 | "pred, full = learn.predict(item, n_words=100)\n", 178 | "pred.show()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "pred.play()" 188 | ] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.7.4" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 2 212 | } 213 | -------------------------------------------------------------------------------- /musicautobot/music_transformer/learner.py: -------------------------------------------------------------------------------- 1 | from fastai.basics import * 2 | from fastai.text.learner import LanguageLearner, get_language_model, _model_meta 3 | from .model import * 4 | from .transform import MusicItem 5 | from ..numpy_encode import SAMPLE_FREQ 6 | from ..utils.top_k_top_p import top_k_top_p 7 | from ..utils.midifile import is_empty_midi 8 | 9 | _model_meta[MusicTransformerXL] = _model_meta[TransformerXL] # copy over fastai's model metadata 10 | 11 | def music_model_learner(data:DataBunch, arch=MusicTransformerXL, config:dict=None, drop_mult:float=1., 12 | pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner': 13 | "Create a `Learner` with a language model from `data` and `arch`." 14 | meta = _model_meta[arch] 15 | 16 | if pretrained_path: 17 | state = torch.load(pretrained_path, map_location='cpu') 18 | if config is None: config = state['config'] 19 | 20 | model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult) 21 | learn = MusicLearner(data, model, split_func=meta['split_lm'], **learn_kwargs) 22 | 23 | if pretrained_path: 24 | get_model(model).load_state_dict(state['model'], strict=False) 25 | if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd) 26 | try: learn.opt.load_state_dict(state['opt']) 27 | except: pass 28 | del state 29 | gc.collect() 30 | 31 | return learn 32 | 33 | # Predictions 34 | from fastai import basic_train # for predictions 35 | class MusicLearner(LanguageLearner): 36 | def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None): 37 | "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. 
`file` can be file-like (file or buffer)" 38 | out_path = super().save(file, return_path=True, with_opt=with_opt) 39 | if config and out_path: 40 | state = torch.load(out_path) 41 | state['config'] = config 42 | torch.save(state, out_path) 43 | del state 44 | gc.collect() 45 | return out_path 46 | 47 | def beam_search(self, xb:Tensor, n_words:int, top_k:int=10, beam_sz:int=10, temperature:float=1., 48 | ): 49 | "Return the `n_words` that come after `text` using beam search." 50 | self.model.reset() 51 | self.model.eval() 52 | xb_length = xb.shape[-1] 53 | if xb.shape[0] > 1: xb = xb[0][None] 54 | yb = torch.ones_like(xb) 55 | 56 | nodes = None 57 | xb = xb.repeat(top_k, 1) 58 | nodes = xb.clone() 59 | scores = xb.new_zeros(1).float() 60 | with torch.no_grad(): 61 | for k in progress_bar(range(n_words), leave=False): 62 | out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1) 63 | values, indices = out.topk(top_k, dim=-1) 64 | scores = (-values + scores[:,None]).view(-1) 65 | indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1) 66 | sort_idx = scores.argsort()[:beam_sz] 67 | scores = scores[sort_idx] 68 | nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)), 69 | indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2) 70 | nodes = nodes.view(-1, nodes.size(2))[sort_idx] 71 | self.model[0].select_hidden(indices_idx[sort_idx]) 72 | xb = nodes[:,-1][:,None] 73 | if temperature != 1.: scores.div_(temperature) 74 | node_idx = torch.multinomial(torch.exp(-scores), 1).item() 75 | return [i.item() for i in nodes[node_idx][xb_length:] ] 76 | 77 | def predict(self, item:MusicItem, n_words:int=128, 78 | temperatures:float=(1.0,1.0), min_bars=4, 79 | top_k=30, top_p=0.6): 80 | "Return the `n_words` that come after `text`." 
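        # Generation loop overview:
        # - `temperatures` holds two values: the first applies when the previous token was a
        #   duration/pad (i.e. a pitch is being predicted), the second otherwise; the temperature is
        #   nudged upward when sampling keeps landing on only one or two viable tokens, to break repetition.
        # - BOS is suppressed until at least `min_bars` bars have been generated, note/duration
        #   alternation is enforced by `filter_invalid_indexes`, and the next index is drawn via
        #   top-k/top-p sampling.
        # - The beat position advances whenever a separator token is followed by a duration; the loop
        #   stops early on a predicted BOS, or at a 4-bar boundary once ~80% of `n_words` is reached.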
81 | self.model.reset() 82 | new_idx = [] 83 | vocab = self.data.vocab 84 | x, pos = item.to_tensor(), item.get_pos_tensor() 85 | last_pos = pos[-1] if len(pos) else 0 86 | y = torch.tensor([0]) 87 | 88 | start_pos = last_pos 89 | 90 | sep_count = 0 91 | bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time 92 | vocab = self.data.vocab 93 | 94 | repeat_count = 0 95 | if hasattr(self.model[0], 'encode_position'): 96 | encode_position = self.model[0].encode_position 97 | else: encode_position = False 98 | 99 | for i in progress_bar(range(n_words), leave=True): 100 | with torch.no_grad(): 101 | if encode_position: 102 | batch = { 'x': x[None], 'pos': pos[None] } 103 | logits = self.model(batch)[0][-1][-1] 104 | else: 105 | logits = self.model(x[None])[0][-1][-1] 106 | 107 | prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx 108 | 109 | # Temperature 110 | # Use first temperatures value if last prediction was duration 111 | temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1] 112 | repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature 113 | temperature += repeat_penalty 114 | if temperature != 1.: logits = logits / temperature 115 | 116 | 117 | # Filter 118 | # bar = 16 beats 119 | filter_value = -float('Inf') 120 | if ((last_pos - start_pos) // 16) <= min_bars: logits[vocab.bos_idx] = filter_value 121 | 122 | logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value) 123 | logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value) 124 | 125 | # Sample 126 | probs = F.softmax(logits, dim=-1) 127 | idx = torch.multinomial(probs, 1).item() 128 | 129 | # Update repeat count 130 | num_choices = len(probs.nonzero().view(-1)) 131 | if num_choices <= 2: repeat_count += 1 132 | else: repeat_count = repeat_count // 2 133 | 134 | if prev_idx==vocab.sep_idx: 135 | duration = idx - vocab.dur_range[0] 136 | last_pos = last_pos + duration 137 | 138 | bars_pred = (last_pos - start_pos) // 16 139 | abs_bar = last_pos // 16 140 | # if (bars % 8 == 0) and (bars_pred > min_bars): break 141 | if (i / n_words > 0.80) and (abs_bar % 4 == 0): break 142 | 143 | 144 | if idx==vocab.bos_idx: 145 | print('Predicted BOS token. Returning prediction...') 146 | break 147 | 148 | new_idx.append(idx) 149 | x = x.new_tensor([idx]) 150 | pos = pos.new_tensor([last_pos]) 151 | 152 | pred = vocab.to_music_item(np.array(new_idx)) 153 | full = item.append(pred) 154 | return pred, full 155 | 156 | # High level prediction functions from midi file 157 | def predict_from_midi(learn, midi=None, n_words=400, 158 | temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs): 159 | vocab = learn.data.vocab 160 | seed = MusicItem.from_file(midi, vocab) if not is_empty_midi(midi) else MusicItem.empty(vocab) 161 | if seed_len is not None: seed = seed.trim_to_beat(seed_len) 162 | 163 | pred, full = learn.predict(seed, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs) 164 | return full 165 | 166 | def filter_invalid_indexes(res, prev_idx, vocab, filter_value=-float('Inf')): 167 | if vocab.is_duration_or_pad(prev_idx): 168 | res[list(range(*vocab.dur_range))] = filter_value 169 | else: 170 | res[list(range(*vocab.note_range))] = filter_value 171 | return res 172 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MusicAutobot 2 | 3 | Using Deep Learning to generate pop music! 
4 | 5 | You can also experiment through the web app - [musicautobot.com](http://musicautobot.com) 6 | 7 | ![Screenshot](images/musicautobot_screenshot.png) 8 | 9 | ## Overview 10 | 11 | Recent advances in NLP have produced amazing [results](https://transformer.huggingface.co/) in generating text. 12 | [Transformer](http://jalammar.github.io/illustrated-transformer/) architecture is a big reason behind this. 13 | 14 | This project aims to leverage these powerful language models and apply them to music. It's built on top of the fast.ai [library](https://github.com/fastai/fastai) 15 | 16 | ## Implementation 17 | 18 | **MusicTransformer** - This basic model uses [Transformer-XL](https://github.com/kimiyoung/transformer-xl) to take a sequence of music notes and predict the next note. 19 | 20 | **MultitaskTransformer** - Built on top of MusicTransformer, this model is trained on multiple tasks. 21 | * Next Note Prediction (same as MusicTransformer) 22 | * [BERT](https://github.com/google-research/bert) Token Masking 23 | * Sequence To Sequence Translation - Using chords to predict melody and vice versa. 24 | 25 | Training on multiple tasks means we can generate some really cool predictions (Check out this [Notebook](notebooks/multitask_transformer/Generate.ipynb)): 26 | 1. [Harmonization](http://musicautobot.com/#/predict/2b4f5e6613f366bad7b4f39c61be32b9) - generate accompanying chords 27 | 2. [Melody](http://musicautobot.com/#/predict/3087b73963aaa2bae62424808a251628) - new melody from existing chord progression 28 | 3. Remix [tune](http://musicautobot.com/#/predict/1bbfcb942133414a5664a35a7e7b5612) - new song in the rhythm of a reference song 29 | 4. Remix [beat](http://musicautobot.com/#/predict/71d7ff59f67fffa98614c841101e1b6b) - same tune, different rhythm 30 | 31 | 32 | ## How it works 33 | 34 | Details are explained in this 4 part series: 35 | * [Part I](https://towardsdatascience.com/creating-a-pop-music-generator-with-the-transformer-5867511b382a) - Creating a Pop Music Generator 36 | * [Part II](https://towardsdatascience.com/practical-tips-for-training-a-music-model-755c62560ec2) - Implementation details 37 | * [Part III](https://towardsdatascience.com/a-multitask-music-model-with-bert-transformer-xl-and-seq2seq-3d80bd2ea08e) - Multitask Transformer 38 | * [Part IV](https://towardsdatascience.com/how-to-remix-the-chainsmokers-with-a-music-bot-6b920359248c) - Composing a song with Multitask 39 | 40 | 41 | ## Example Notebooks 42 | 43 | 1. Play with predictions on Google Colab 44 | * [MusicTransformer Generate](https://colab.research.google.com/github/bearpelican/musicautobot/blob/master/notebooks/music_transformer/Generate_colab.ipynb) - Loads a pretrained model and shows how to generate/predict new notes 45 | * [MultitaskTransformer Generate](https://colab.research.google.com/github/bearpelican/musicautobot/blob/master/notebooks/multitask_transformer/Generate_colab.ipynb) - Loads a pretrained model and shows how to harmonize, generate new melodies, and remix existing songs. 46 | 47 | 2. MusicTransformer 48 | * [Train](notebooks/music_transformer/Train.ipynb) - End to end example on how to create a dataset from midi files and train a model from scratch 49 | * [Generate](notebooks/music_tranformer/Generate.ipynb) - Loads a pretrained model and shows how to generate/predict new notes 50 | 51 | 3. MultitaskTransformer 52 | * [Train](notebooks/multitask_transformer/Train.ipynb) - End to end example on creating a seq2seq and masked dataset for multitask training. 
53 | * [Generate](notebooks/multitask_transformer/Generate.ipynb) - Loads a pretrained model and shows how to harmonize, generate new melodies, and remix existing songs. 54 | 55 | 4. Data Encoding 56 | * [Midi2Tensor](notebooks/data_encoding/Midi2Tensor.ipynb) - Shows how the library internally encodes midi files to tensors for training. 57 | * [MusicItem](notebooks/data_encoding/MusicItem-Transforms.ipynb) - MusicItem is a wrapper that makes it easy to manipulate midi data. Convert midi to tensor, apply data transformations, even play music or display the notes within the browser. 58 | 59 | ## Pretrained Models 60 | 61 | Pretrained models are available as MusicTransformer and MultitaskTransformer (small and large). 62 | 63 | Each model has an additional `keyC` version. `keyC` means that the model has been trained solely on music transposed to the key of C (all white keys). These models produce better results, but expect all input to be in the key of C. 64 | 65 | 1. MusicTransformer (600 MB) - [AnyKey](https://ashaw-midi-web-server.s3-us-west-2.amazonaws.com/pretrained/MusicTransformer.pth) | [KeyC](https://ashaw-midi-web-server.s3-us-west-2.amazonaws.com/pretrained/MusicTransformerKeyC.pth) 66 | 67 | 2. MultitaskTransformer 68 | * Small (700 MB) - [AnyKey](https://ashaw-midi-web-server.s3-us-west-2.amazonaws.com/pretrained/MultitaskSmall.pth) | [KeyC](https://ashaw-midi-web-server.s3-us-west-2.amazonaws.com/pretrained/MultitaskSmallKeyC.pth) 69 | * Large (2.1 GB) - [AnyKey](https://ashaw-midi-web-server.s3-us-west-2.amazonaws.com/pretrained/MultitaskLarge.pth) | [KeyC](https://ashaw-midi-web-server.s3-us-west-2.amazonaws.com/pretrained/MultitaskLargeKeyC.pth) 70 | 71 | For details on how to load these models, follow the [Generate](notebooks/music_transformer/Generate.ipynb) and [Multitask Generate](notebooks/multitask_transformer/Generate.ipynb) notebooks. 72 | 73 | ## Source Code 74 | 75 | * [musicautobot/](musicautobot) 76 | * [numpy_encode.py](musicautobot/numpy_encode.py) - Submodule for encoding midi to tensor 77 | * [music_transformer.py](musicautobot/music_transformer) - Submodule structure similar to fastai's library. 78 | * Learner, Model, Transform - MusicItem, Dataloader 79 | * [multitask_transformer.py](musicautobot/multitask_transformer) - Submodule structure similar to fastai's library. 80 | * Learner, Model, Transform - MusicItem, Dataloader 81 | 82 | ## Scripts 83 | 84 | CLI scripts for training models: 85 | **[run_multitask.py](scripts/run_multitask.py)** - multitask training 86 | ``` 87 | python run_multitask.py --epochs 14 --save multitask_model --batch_size=16 --bptt=512 --lamb --data_parallel --lr 1e-4 88 | ``` 89 | **[run_music_transformer.py](scripts/run_music_transformer.py)** - music model training 90 | ``` 91 | python run_music_transformer.py --epochs 14 --save music_model --batch_size=16 --bptt=512 --lr 1e-4 92 | ``` 93 | **[run_ddp.sh](scripts/run_ddp.sh)** - Helper script to train with multiple GPUs (DistributedDataParallel). Only works with run_music_transformer.py 94 | ``` 95 | SCRIPT=run_multitask.py bash run_ddp.sh --epochs 14 --save music_model --batch_size=16 --bptt=512 --lr 1e-4 96 | ``` 97 | 98 | **Commands must be run inside the `scripts/` folder** 99 | 100 | ## Installation 101 | 102 | 1. Install anaconda: https://www.anaconda.com/distribution/ 103 | 104 | 105 | 2.
Run: 106 | 107 | ```bash 108 | git clone https://github.com/bearpelican/musicautobot.git 109 | 110 | cd musicautobot 111 | 112 | conda env update -f environment.yml 113 | 114 | source activate musicautobot 115 | ``` 116 | 117 | 3. Install Musescore - to view sheet music within a jupyter notebook 118 | 119 | Ubuntu: 120 | ```bash 121 | sudo apt-get install musescore 122 | ``` 123 | 124 | MacOS - [download](https://musescore.org/en/download) 125 | 126 | ## Flask Server 127 | 128 | Installation: 129 | ```bash 130 | cd serve 131 | 132 | conda env update -f environment.yml 133 | ``` 134 | 135 | #### S3 Bucket 136 | You need to setup an s3 bucket to save your predictions. 137 | After you've created a bucket, update the config [api/api.cfg](api/api.cfg) with the new bucket name. 138 | 139 | Development: 140 | ```bash 141 | python run.py 142 | ``` 143 | 144 | Production: 145 | ```bash 146 | gunicorn -b 127.0.0.1:5000 run_guni:app --timeout 180 --workers 8 147 | ``` 148 | 149 | ## Data 150 | 151 | Unfortunately I cannot provide the dataset used for training the model. 152 | 153 | Here's some suggestions: 154 | 155 | * [Classical Archives](https://www.classicalarchives.com/) - incredible catalog of high quality classical midi 156 | * [HookTheory](https://www.hooktheory.com/) - great data for sequence to sequence predictions. Need to manually copy files into hookpad 157 | * [Reddit](https://www.reddit.com/r/datasets/comments/3akhxy/the_largest_midi_collection_on_the_internet/) - 130k files 158 | * [Lakh](https://colinraffel.com/projects/lmd/) - great research dataset 159 | 160 | 161 | ## Acknowledgements 162 | 163 | This project is built on top of [fast.ai's](https://github.com/fastai/fastai) deep learning library and music21's incredible musicology [library](https://web.mit.edu/music21/). 164 | 165 | Inspired by [bachbot](https://github.com/feynmanliang/bachbot) and [clara](http://christinemcleavey.com/clara-a-neural-net-music-generator/) 166 | 167 | Special thanks to [SPC](https://southparkcommons.com) and [PalapaVC](https://www.palapavc.com/) 168 | -------------------------------------------------------------------------------- /musicautobot/music_transformer/transform.py: -------------------------------------------------------------------------------- 1 | from ..numpy_encode import * 2 | import numpy as np 3 | from enum import Enum 4 | import torch 5 | from ..vocab import * 6 | from functools import partial 7 | 8 | SEQType = Enum('SEQType', 'Mask, Sentence, Melody, Chords, Empty') 9 | 10 | class MusicItem(): 11 | def __init__(self, data, vocab, stream=None, position=None): 12 | self.data = data 13 | self.vocab = vocab 14 | self._stream = stream 15 | self._position = position 16 | def __repr__(self): return '\n'.join([ 17 | f'\n{self.__class__.__name__} - {self.data.shape}', 18 | f'{self.vocab.textify(self.data[:10])}...']) 19 | def __len__(self): return len(self.data) 20 | 21 | @classmethod 22 | def from_file(cls, midi_file, vocab): 23 | return cls.from_stream(file2stream(midi_file), vocab) 24 | @classmethod 25 | def from_stream(cls, stream, vocab): 26 | if not isinstance(stream, music21.stream.Score): stream = stream.voicesToParts() 27 | chordarr = stream2chordarr(stream) # 2. 28 | npenc = chordarr2npenc(chordarr) # 3. 
29 | return cls.from_npenc(npenc, vocab, stream) 30 | @classmethod 31 | def from_npenc(cls, npenc, vocab, stream=None): return MusicItem(npenc2idxenc(npenc, vocab), vocab, stream) 32 | 33 | @classmethod 34 | def from_idx(cls, item, vocab): 35 | idx,pos = item 36 | return MusicItem(idx, vocab=vocab, position=pos) 37 | def to_idx(self): return self.data, self.position 38 | 39 | @classmethod 40 | def empty(cls, vocab, seq_type=SEQType.Sentence): 41 | return MusicItem(seq_prefix(seq_type, vocab), vocab) 42 | 43 | @property 44 | def stream(self): 45 | self._stream = self.to_stream() if self._stream is None else self._stream 46 | return self._stream 47 | 48 | def to_stream(self, bpm=120): 49 | return idxenc2stream(self.data, self.vocab, bpm=bpm) 50 | 51 | def to_tensor(self, device=None): 52 | return to_tensor(self.data, device) 53 | 54 | def to_text(self, sep=' '): return self.vocab.textify(self.data, sep) 55 | 56 | @property 57 | def position(self): 58 | self._position = position_enc(self.data, self.vocab) if self._position is None else self._position 59 | return self._position 60 | 61 | def get_pos_tensor(self, device=None): return to_tensor(self.position, device) 62 | 63 | def to_npenc(self): 64 | return idxenc2npenc(self.data, self.vocab) 65 | 66 | def show(self, format:str=None): 67 | return self.stream.show(format) 68 | def play(self): self.stream.show('midi') 69 | 70 | @property 71 | def new(self): 72 | return partial(type(self), vocab=self.vocab) 73 | 74 | def trim_to_beat(self, beat, include_last_sep=False): 75 | return self.new(trim_to_beat(self.data, self.position, self.vocab, beat, include_last_sep)) 76 | 77 | def transpose(self, interval): 78 | return self.new(tfm_transpose(self.data, interval, self.vocab), position=self._position) 79 | 80 | def append(self, item): 81 | return self.new(np.concatenate((self.data, item.data), axis=0)) 82 | 83 | def mask_pitch(self, section=None): 84 | return self.new(self.mask(self.vocab.note_range, section), position=self.position) 85 | 86 | def mask_duration(self, section=None, keep_position_enc=True): 87 | masked_data = self.mask(self.vocab.dur_range, section) 88 | if keep_position_enc: return self.new(masked_data, position=self.position) 89 | return self.new(masked_data) 90 | 91 | def mask(self, token_range, section_range=None): 92 | return mask_section(self.data, self.position, token_range, self.vocab.mask_idx, section_range=section_range) 93 | 94 | def pad_to(self, bptt): 95 | data = pad_seq(self.data, bptt, self.vocab.pad_idx) 96 | pos = pad_seq(self.position, bptt, 0) 97 | return self.new(data, stream=self._stream, position=pos) 98 | 99 | def split_stream_parts(self): 100 | self._stream = separate_melody_chord(self.stream) 101 | return self.stream 102 | 103 | def remove_eos(self): 104 | if self.data[-1] == self.vocab.stoi[EOS]: return self.new(self.data, stream=self.stream) 105 | return self 106 | 107 | def split_parts(self): 108 | return self.new(self.data, stream=separate_melody_chord(self.stream), position=self.position) 109 | 110 | def pad_seq(seq, bptt, value): 111 | pad_len = max(bptt-seq.shape[0], 0) 112 | return np.pad(seq, (0, pad_len), 'constant', constant_values=value)[:bptt] 113 | 114 | def to_tensor(t, device=None): 115 | t = t if isinstance(t, torch.Tensor) else torch.tensor(t) 116 | if device is None and torch.cuda.is_available(): t = t.cuda() 117 | else: t.to(device) 118 | return t.long() 119 | 120 | def midi2idxenc(midi_file, vocab): 121 | "Converts midi file to index encoding for training" 122 | npenc = midi2npenc(midi_file) 
# 3. 123 | return npenc2idxenc(npenc, vocab) 124 | 125 | def idxenc2stream(arr, vocab, bpm=120): 126 | "Converts index encoding to music21 stream" 127 | npenc = idxenc2npenc(arr, vocab) 128 | return npenc2stream(npenc, bpm=bpm) 129 | 130 | # single stream instead of note,dur 131 | def npenc2idxenc(t, vocab, seq_type=SEQType.Sentence, add_eos=False): 132 | "Transforms numpy array from 2 column (note, duration) matrix to a single column" 133 | "[[n1, d1], [n2, d2], ...] -> [n1, d1, n2, d2]" 134 | if isinstance(t, (list, tuple)) and len(t) == 2: 135 | return [npenc2idxenc(x, vocab, start_seq) for x in t] 136 | t = t.copy() 137 | 138 | t[:, 0] = t[:, 0] + vocab.note_range[0] 139 | t[:, 1] = t[:, 1] + vocab.dur_range[0] 140 | 141 | prefix = seq_prefix(seq_type, vocab) 142 | suffix = np.array([vocab.stoi[EOS]]) if add_eos else np.empty(0, dtype=int) 143 | return np.concatenate([prefix, t.reshape(-1), suffix]) 144 | 145 | def seq_prefix(seq_type, vocab): 146 | if seq_type == SEQType.Empty: return np.empty(0, dtype=int) 147 | start_token = vocab.bos_idx 148 | if seq_type == SEQType.Chords: start_token = vocab.stoi[CSEQ] 149 | if seq_type == SEQType.Melody: start_token = vocab.stoi[MSEQ] 150 | return np.array([start_token, vocab.pad_idx]) 151 | 152 | def idxenc2npenc(t, vocab, validate=True): 153 | if validate: t = to_valid_idxenc(t, vocab.npenc_range) 154 | t = t.copy().reshape(-1, 2) 155 | if t.shape[0] == 0: return t 156 | 157 | t[:, 0] = t[:, 0] - vocab.note_range[0] 158 | t[:, 1] = t[:, 1] - vocab.dur_range[0] 159 | 160 | if validate: return to_valid_npenc(t) 161 | return t 162 | 163 | def to_valid_idxenc(t, valid_range): 164 | r = valid_range 165 | t = t[np.where((t >= r[0]) & (t < r[1]))] 166 | if t.shape[-1] % 2 == 1: t = t[..., :-1] 167 | return t 168 | 169 | def to_valid_npenc(t): 170 | is_note = (t[:, 0] < VALTSEP) | (t[:, 0] >= NOTE_SIZE) 171 | invalid_note_idx = is_note.argmax() 172 | invalid_dur_idx = (t[:, 1] < 0).argmax() 173 | 174 | invalid_idx = max(invalid_dur_idx, invalid_note_idx) 175 | if invalid_idx > 0: 176 | if invalid_note_idx > 0 and invalid_dur_idx > 0: invalid_idx = min(invalid_dur_idx, invalid_note_idx) 177 | print('Non midi note detected. Only returning valid portion. Index, seed', invalid_idx, t.shape) 178 | return t[:invalid_idx] 179 | return t 180 | 181 | def position_enc(idxenc, vocab): 182 | "Calculates positional beat encoding." 
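    # Every separator token is followed by a duration token: the duration right after each SEP is
    # read off (masked durations count as zero), written at the token two steps after the SEP, and a
    # cumulative sum then tags every index in the sequence with its absolute beat position.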
183 | sep_idxs = (idxenc == vocab.sep_idx).nonzero()[0] 184 | sep_idxs = sep_idxs[sep_idxs+2 < idxenc.shape[0]] # remove any indexes right before out of bounds (sep_idx+2) 185 | dur_vals = idxenc[sep_idxs+1] 186 | dur_vals[dur_vals == vocab.mask_idx] = vocab.dur_range[0] # make sure masked durations are 0 187 | dur_vals -= vocab.dur_range[0] 188 | 189 | posenc = np.zeros_like(idxenc) 190 | posenc[sep_idxs+2] = dur_vals 191 | return posenc.cumsum() 192 | 193 | def beat2index(idxenc, pos, vocab, beat, include_last_sep=False): 194 | cutoff = find_beat(pos, beat) 195 | if cutoff < 2: return 2 # always leave starter tokens 196 | if len(idxenc) < 2 or include_last_sep: return cutoff 197 | if idxenc[cutoff - 2] == vocab.sep_idx: return cutoff - 2 198 | return cutoff 199 | 200 | def find_beat(pos, beat, sample_freq=SAMPLE_FREQ, side='left'): 201 | return np.searchsorted(pos, beat * sample_freq, side=side) 202 | 203 | # TRANSFORMS 204 | 205 | def tfm_transpose(x, value, vocab): 206 | x = x.copy() 207 | x[(x >= vocab.note_range[0]) & (x < vocab.note_range[1])] += value 208 | return x 209 | 210 | def trim_to_beat(idxenc, pos, vocab, to_beat=None, include_last_sep=True): 211 | if to_beat is None: return idxenc 212 | cutoff = beat2index(idxenc, pos, vocab, to_beat, include_last_sep=include_last_sep) 213 | return idxenc[:cutoff] 214 | 215 | def mask_input(xb, mask_range, replacement_idx): 216 | xb = xb.copy() 217 | xb[(xb >= mask_range[0]) & (xb < mask_range[1])] = replacement_idx 218 | return xb 219 | 220 | def mask_section(xb, pos, token_range, replacement_idx, section_range=None): 221 | xb = xb.copy() 222 | token_mask = (xb >= token_range[0]) & (xb < token_range[1]) 223 | 224 | if section_range is None: section_range = (None, None) 225 | section_mask = np.zeros_like(xb, dtype=bool) 226 | start_idx = find_beat(pos, section_range[0]) if section_range[0] is not None else 0 227 | end_idx = find_beat(pos, section_range[1]) if section_range[1] is not None else xb.shape[0] 228 | section_mask[start_idx:end_idx] = True 229 | 230 | xb[token_mask & section_mask] = replacement_idx 231 | return xb 232 | -------------------------------------------------------------------------------- /notebooks/music_transformer/Train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%matplotlib inline" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "os.chdir('../../')" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from musicautobot.numpy_encode import *\n", 31 | "from musicautobot.config import *\n", 32 | "from musicautobot.music_transformer import *" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## MusicTransformer Training\n", 40 | "\n", 41 | "MusicTransformer takes the basic idea of [Language Models](https://en.wikipedia.org/wiki/Language_model) and applies it to Music. 
\n", 42 | "\n", 43 | "Given a sequence of notes, predict the next most likely set of notes.\n", 44 | "\n", 45 | "This model is based off of [transformer-XL](https://arxiv.org/abs/1901.02860) and uses fast.ai's [implementation](https://github.com/fastai/fastai/blob/master/fastai/text/models/transformer.py) of it." 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# Location of your midi filesfiles\n", 55 | "midi_path = Path('data/midi/examples')\n", 56 | "midi_path.mkdir(parents=True, exist_ok=True)\n", 57 | "\n", 58 | "# Location to save dataset\n", 59 | "data_path = Path('data/numpy')\n", 60 | "data_path.mkdir(parents=True, exist_ok=True)\n", 61 | "\n", 62 | "data_save_name = 'musicitem_data_save.pkl'" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "## 1. Gather midi dataset" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "Make sure all your midi data is in `musicautobot/data/midi` directory" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "Here's a pretty good dataset with lots of midi data: \n", 84 | "https://www.reddit.com/r/datasets/comments/3akhxy/the_largest_midi_collection_on_the_internet/\n", 85 | "\n", 86 | "Download the folder and unzip it to `data/midi`" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "## 2. Create dataset from MIDI files" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 5, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/plain": [ 104 | "19" 105 | ] 106 | }, 107 | "execution_count": 5, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "midi_files = get_files(midi_path, '.mid', recurse=True); len(midi_files)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 6, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "processors = [Midi2ItemProcessor()]\n", 123 | "data = MusicDataBunch.from_files(midi_files, data_path, processors=processors, bs=2, bptt=12)\n", 124 | "data.save(data_save_name)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "# Show Data\n", 134 | "data.train_dl.on_epoch_begin()\n", 135 | "x, y = data.one_batch();\n", 136 | "x, y" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "## 3. Load Model" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 8, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "batch_size = 1\n", 153 | "encode_position = True\n", 154 | "dl_tfms = [batch_position_tfm] if encode_position else []\n", 155 | "data = load_data(data_path, data_save_name, bs=batch_size, encode_position=encode_position, dl_tfms=dl_tfms)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 10, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "config = default_config()\n", 165 | "config['encode_position'] = encode_position\n", 166 | "learn = music_model_learner(data, config=config.copy())" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "## 4. 
Train" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 12, 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "data": { 183 | "text/html": [ 184 | "\n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | "
epoch   train_loss   valid_loss   accuracy   time
0       3.361226     3.096030     0.397959   00:08
1       3.253198     3.165267     0.210204   00:08
2       3.242048     2.709940     0.397959   00:08
3       3.159635     2.727157     0.397959   00:08
" 225 | ], 226 | "text/plain": [ 227 | "" 228 | ] 229 | }, 230 | "metadata": {}, 231 | "output_type": "display_data" 232 | } 233 | ], 234 | "source": [ 235 | "learn.fit_one_cycle(4)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "learn.save('example')" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": {}, 250 | "source": [ 251 | "## 5. Predict" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "---\n", 259 | "See [Generate.ipynb](Generate.ipynb) to use a pretrained model and generate better predictions\n", 260 | "\n", 261 | "---" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 11, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "midi_file = Path('data/midi/notebook_examples/single_bar_example.mid'); midi_file\n", 271 | "item = MusicItem.from_file(midi_file, data.vocab);" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "item.show()" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "Here's what the seed sounds like:" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": {}, 293 | "source": [ 294 | "### Start Predictions:" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 16, 300 | "metadata": {}, 301 | "outputs": [ 302 | { 303 | "data": { 304 | "text/html": [ 305 | "\n", 306 | "
\n", 307 | " \n", 319 | " \n", 320 | " 100.00% [100/100 00:02<00:00]\n", 321 | "
\n", 322 | " " 323 | ], 324 | "text/plain": [ 325 | "" 326 | ] 327 | }, 328 | "metadata": {}, 329 | "output_type": "display_data" 330 | } 331 | ], 332 | "source": [ 333 | "pred, full = learn.predict(item, n_words=100)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "metadata": {}, 339 | "source": [ 340 | "Prediction" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "# Prediction\n", 350 | "pred.show()" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "pred.play()" 360 | ] 361 | } 362 | ], 363 | "metadata": { 364 | "kernelspec": { 365 | "display_name": "Python 3", 366 | "language": "python", 367 | "name": "python3" 368 | }, 369 | "language_info": { 370 | "codemirror_mode": { 371 | "name": "ipython", 372 | "version": 3 373 | }, 374 | "file_extension": ".py", 375 | "mimetype": "text/x-python", 376 | "name": "python", 377 | "nbconvert_exporter": "python", 378 | "pygments_lexer": "ipython3", 379 | "version": "3.7.4" 380 | } 381 | }, 382 | "nbformat": 4, 383 | "nbformat_minor": 2 384 | } 385 | -------------------------------------------------------------------------------- /musicautobot/music_transformer/dataloader.py: -------------------------------------------------------------------------------- 1 | "Fastai Language Model Databunch modified to work with music" 2 | from fastai.basics import * 3 | # from fastai.basic_data import DataBunch 4 | from fastai.text.data import LMLabelList 5 | from .transform import * 6 | from ..vocab import MusicVocab 7 | 8 | 9 | class MusicDataBunch(DataBunch): 10 | "Create a `TextDataBunch` suitable for training a language model." 11 | @classmethod 12 | def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', no_check:bool=False, bs=64, val_bs:int=None, 13 | num_workers:int=0, device:torch.device=None, collate_fn:Callable=data_collate, 14 | dl_tfms:Optional[Collection[Callable]]=None, bptt:int=70, 15 | preloader_cls=None, shuffle_dl=False, transpose_range=(0,12), **kwargs) -> DataBunch: 16 | "Create a `TextDataBunch` in `path` from the `datasets` for language modelling." 
17 | datasets = cls._init_ds(train_ds, valid_ds, test_ds) 18 | preloader_cls = MusicPreloader if preloader_cls is None else preloader_cls 19 | val_bs = ifnone(val_bs, bs) 20 | datasets = [preloader_cls(ds, shuffle=(i==0), bs=(bs if i==0 else val_bs), bptt=bptt, transpose_range=transpose_range, **kwargs) 21 | for i,ds in enumerate(datasets)] 22 | val_bs = bs 23 | dl_tfms = [partially_apply_vocab(tfm, train_ds.vocab) for tfm in listify(dl_tfms)] 24 | dls = [DataLoader(d, b, shuffle=shuffle_dl) for d,b in zip(datasets, (bs,val_bs,val_bs,val_bs)) if d is not None] 25 | return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check) 26 | 27 | @classmethod 28 | def from_folder(cls, path:PathOrStr, extensions='.npy', **kwargs): 29 | files = get_files(path, extensions=extensions, recurse=True); 30 | return cls.from_files(files, path, **kwargs) 31 | 32 | @classmethod 33 | def from_files(cls, files, path, processors=None, split_pct=0.1, 34 | vocab=None, list_cls=None, **kwargs): 35 | if vocab is None: vocab = MusicVocab.create() 36 | if list_cls is None: list_cls = MusicItemList 37 | src = (list_cls(items=files, path=path, processor=processors, vocab=vocab) 38 | .split_by_rand_pct(split_pct, seed=6) 39 | .label_const(label_cls=LMLabelList)) 40 | return src.databunch(**kwargs) 41 | 42 | @classmethod 43 | def empty(cls, path, **kwargs): 44 | vocab = MusicVocab.create() 45 | src = MusicItemList([], path=path, vocab=vocab, ignore_empty=True).split_none() 46 | return src.label_const(label_cls=LMLabelList).databunch() 47 | 48 | def partially_apply_vocab(tfm, vocab): 49 | if 'vocab' in inspect.getfullargspec(tfm).args: 50 | return partial(tfm, vocab=vocab) 51 | return tfm 52 | 53 | class MusicItemList(ItemList): 54 | _bunch = MusicDataBunch 55 | 56 | def __init__(self, items:Iterator, vocab:MusicVocab=None, **kwargs): 57 | super().__init__(items, **kwargs) 58 | self.vocab = vocab 59 | self.copy_new += ['vocab'] 60 | 61 | def get(self, i): 62 | o = super().get(i) 63 | if is_pos_enc(o): 64 | return MusicItem.from_idx(o, self.vocab) 65 | return MusicItem(o, self.vocab) 66 | 67 | def is_pos_enc(idxenc): 68 | if len(idxenc.shape) == 2 and idxenc.shape[0] == 2: return True 69 | return idxenc.dtype == np.object and idxenc.shape == (2,) 70 | 71 | class MusicItemProcessor(PreProcessor): 72 | "`PreProcessor` that transforms numpy files to indexes for training" 73 | def process_one(self,item): 74 | item = MusicItem.from_npenc(item, vocab=self.vocab) 75 | return item.to_idx() 76 | 77 | def process(self, ds): 78 | self.vocab = ds.vocab 79 | super().process(ds) 80 | 81 | class OpenNPFileProcessor(PreProcessor): 82 | "`PreProcessor` that opens the filenames and read the texts." 83 | def process_one(self,item): 84 | return np.load(item, allow_pickle=True) if isinstance(item, Path) else item 85 | 86 | class Midi2ItemProcessor(PreProcessor): 87 | "Skips midi preprocessing step. And encodes midi files to MusicItems" 88 | def process_one(self,item): 89 | item = MusicItem.from_file(item, vocab=self.vocab) 90 | return item.to_idx() 91 | 92 | def process(self, ds): 93 | self.vocab = ds.vocab 94 | super().process(ds) 95 | 96 | ## For npenc dataset 97 | class MusicPreloader(Callback): 98 | "Transforms the tokens in `dataset` to a stream of contiguous batches for language modelling." 
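# Implementation notes for the preloader below: a batch buffer of shape
# (bs, bptt + y_offset[, 2]) is preallocated once; `batch_x` and `batch_y` are
# overlapping views of it shifted by `y_offset`. When `encode_position` is True,
# token indexes and beat positions are stacked in the last dimension and split
# back apart by `batch_position_tfm`.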
99 | 100 | class CircularIndex(): 101 | "Handles shuffle, direction of indexing, wraps around to head tail in the ragged array as needed" 102 | def __init__(self, length:int, forward:bool): self.idx, self.forward = np.arange(length), forward 103 | def __getitem__(self, i): 104 | return self.idx[ i%len(self.idx) if self.forward else len(self.idx)-1-i%len(self.idx)] 105 | def __len__(self) -> int: return len(self.idx) 106 | def shuffle(self): np.random.shuffle(self.idx) 107 | 108 | def __init__(self, dataset:LabelList, lengths:Collection[int]=None, bs:int=32, bptt:int=70, backwards:bool=False, 109 | shuffle:bool=False, y_offset:int=1, 110 | transpose_range=None, transpose_p=0.5, 111 | encode_position=True, 112 | **kwargs): 113 | self.dataset,self.bs,self.bptt,self.shuffle,self.backwards,self.lengths = dataset,bs,bptt,shuffle,backwards,lengths 114 | self.vocab = self.dataset.vocab 115 | self.bs *= num_distrib() or 1 116 | self.totalToks,self.ite_len,self.idx = int(0),None,None 117 | self.y_offset = y_offset 118 | 119 | self.transpose_range,self.transpose_p = transpose_range,transpose_p 120 | self.encode_position = encode_position 121 | self.bptt_len = self.bptt 122 | 123 | self.allocate_buffers() # needed for valid_dl on distributed training - otherwise doesn't get initialized on first epoch 124 | 125 | def __len__(self): 126 | if self.ite_len is None: 127 | if self.lengths is None: self.lengths = np.array([len(item) for item in self.dataset.x]) 128 | self.totalToks = self.lengths.sum() 129 | self.ite_len = self.bs*int( math.ceil( self.totalToks/(self.bptt*self.bs) )) if self.item is None else 1 130 | return self.ite_len 131 | 132 | def __getattr__(self,k:str)->Any: return getattr(self.dataset, k) 133 | 134 | def allocate_buffers(self): 135 | "Create the ragged array that will be filled when we ask for items." 136 | if self.ite_len is None: len(self) 137 | self.idx = MusicPreloader.CircularIndex(len(self.dataset.x), not self.backwards) 138 | 139 | # batch shape = (bs, bptt, 2 - [index, pos]) if encode_position. Else - (bs, bptt) 140 | buffer_len = (2,) if self.encode_position else () 141 | self.batch = np.zeros((self.bs, self.bptt+self.y_offset) + buffer_len, dtype=np.int64) 142 | self.batch_x, self.batch_y = self.batch[:,0:self.bptt], self.batch[:,self.y_offset:self.bptt+self.y_offset] 143 | #ro: index of the text we're at inside our datasets for the various batches 144 | self.ro = np.zeros(self.bs, dtype=np.int64) 145 | #ri: index of the token we're at inside our current text for the various batches 146 | self.ri = np.zeros(self.bs, dtype=np.int) 147 | 148 | # allocate random transpose values. Need to allocate this before hand. 
149 | self.transpose_values = self.get_random_transpose_values() 150 | 151 | def get_random_transpose_values(self): 152 | if self.transpose_range is None: return None 153 | n = len(self.dataset) 154 | rt_arr = torch.randint(*self.transpose_range, (n,))-self.transpose_range[1]//2 155 | mask = torch.rand(rt_arr.shape) > self.transpose_p 156 | rt_arr[mask] = 0 157 | return rt_arr 158 | 159 | def on_epoch_begin(self, **kwargs): 160 | if self.idx is None: self.allocate_buffers() 161 | elif self.shuffle: 162 | self.ite_len = None 163 | self.idx.shuffle() 164 | self.transpose_values = self.get_random_transpose_values() 165 | self.bptt_len = self.bptt 166 | self.idx.forward = not self.backwards 167 | 168 | step = self.totalToks / self.bs 169 | ln_rag, countTokens, i_rag = 0, 0, -1 170 | for i in range(0,self.bs): 171 | #Compute the initial values for ro and ri 172 | while ln_rag + countTokens <= int(step * i): 173 | countTokens += ln_rag 174 | i_rag += 1 175 | ln_rag = self.lengths[self.idx[i_rag]] 176 | self.ro[i] = i_rag 177 | self.ri[i] = ( ln_rag - int(step * i - countTokens) ) if self.backwards else int(step * i - countTokens) 178 | 179 | #Training dl gets on_epoch_begin called, val_dl, on_epoch_end 180 | def on_epoch_end(self, **kwargs): self.on_epoch_begin() 181 | 182 | def __getitem__(self, k:int): 183 | j = k % self.bs 184 | if j==0: 185 | if self.item is not None: return self.dataset[0] 186 | if self.idx is None: self.on_epoch_begin() 187 | 188 | self.ro[j],self.ri[j] = self.fill_row(not self.backwards, self.dataset.x, self.idx, self.batch[j][:self.bptt_len+self.y_offset], 189 | self.ro[j], self.ri[j], overlap=1, lengths=self.lengths) 190 | return self.batch_x[j][:self.bptt_len], self.batch_y[j][:self.bptt_len] 191 | 192 | def fill_row(self, forward, items, idx, row, ro, ri, overlap, lengths): 193 | "Fill the row with tokens from the ragged array. --OBS-- overlap != 1 has not been implemented" 194 | ibuf = n = 0 195 | ro -= 1 196 | while ibuf < row.shape[0]: 197 | ro += 1 198 | ix = idx[ro] 199 | 200 | item = items[ix] 201 | if self.transpose_values is not None: 202 | item = item.transpose(self.transpose_values[ix].item()) 203 | 204 | if self.encode_position: 205 | # Positions are colomn stacked with indexes. 
This makes it easier to keep in sync 206 | rag = np.stack([item.data, item.position], axis=1) 207 | else: 208 | rag = item.data 209 | 210 | if forward: 211 | ri = 0 if ibuf else ri 212 | n = min(lengths[ix] - ri, row.shape[0] - ibuf) 213 | row[ibuf:ibuf+n] = rag[ri:ri+n] 214 | else: 215 | ri = lengths[ix] if ibuf else ri 216 | n = min(ri, row.size - ibuf) 217 | row[ibuf:ibuf+n] = rag[ri-n:ri][::-1] 218 | ibuf += n 219 | return ro, ri + ((n-overlap) if forward else -(n-overlap)) 220 | 221 | def batch_position_tfm(b): 222 | "Batch transform for training with positional encoding" 223 | x,y = b 224 | x = { 225 | 'x': x[...,0], 226 | 'pos': x[...,1] 227 | } 228 | return x, y[...,0] 229 | -------------------------------------------------------------------------------- /musicautobot/numpy_encode.py: -------------------------------------------------------------------------------- 1 | "Encoding music21 streams -> numpy array -> text" 2 | 3 | # import re 4 | import music21 5 | import numpy as np 6 | # from pathlib import Path 7 | 8 | BPB = 4 # beats per bar 9 | TIMESIG = f'{BPB}/4' # default time signature 10 | PIANO_RANGE = (21, 108) 11 | VALTSEP = -1 # separator value for numpy encoding 12 | VALTCONT = -2 # numpy value for TCONT - needed for compressing chord array 13 | 14 | SAMPLE_FREQ = 4 15 | NOTE_SIZE = 128 16 | DUR_SIZE = (10*BPB*SAMPLE_FREQ)+1 # Max length - 8 bars. Or 16 beats/quarternotes 17 | MAX_NOTE_DUR = (8*BPB*SAMPLE_FREQ) 18 | 19 | # Encoding process 20 | # 1. midi -> music21.Stream 21 | # 2. Stream -> numpy chord array (timestep X instrument X noterange) 22 | # 3. numpy array -> List[Timestep][NoteEnc] 23 | def midi2npenc(midi_file, skip_last_rest=True): 24 | "Converts midi file to numpy encoding for language model" 25 | stream = file2stream(midi_file) # 1. 26 | chordarr = stream2chordarr(stream) # 2. 27 | return chordarr2npenc(chordarr, skip_last_rest=skip_last_rest) # 3. 28 | 29 | # Decoding process 30 | # 1. NoteEnc -> numpy chord array 31 | # 2. numpy array -> music21.Stream 32 | def npenc2stream(arr, bpm=120): 33 | "Converts numpy encoding to music21 stream" 34 | chordarr = npenc2chordarr(np.array(arr)) # 1. 35 | return chordarr2stream(chordarr, bpm=bpm) # 2. 36 | 37 | ##### ENCODING ###### 38 | 39 | # 1. File To STream 40 | 41 | def file2stream(fp): 42 | if isinstance(fp, music21.midi.MidiFile): return music21.midi.translate.midiFileToStream(fp) 43 | return music21.converter.parse(fp) 44 | 45 | # 2. 
46 | def stream2chordarr(s, note_size=NOTE_SIZE, sample_freq=SAMPLE_FREQ, max_note_dur=MAX_NOTE_DUR): 47 | "Converts music21.Stream to 1-hot numpy array" 48 | # assuming 4/4 time 49 | # note x instrument x pitch 50 | # FYI: midi middle C value=60 51 | 52 | # (AS) TODO: need to order by instruments most played and filter out percussion or include the channel 53 | highest_time = max(s.flat.getElementsByClass('Note').highestTime, s.flat.getElementsByClass('Chord').highestTime) 54 | maxTimeStep = round(highest_time * sample_freq)+1 55 | score_arr = np.zeros((maxTimeStep, len(s.parts), NOTE_SIZE)) 56 | 57 | def note_data(pitch, note): 58 | return (pitch.midi, int(round(note.offset*sample_freq)), int(round(note.duration.quarterLength*sample_freq))) 59 | 60 | for idx,part in enumerate(s.parts): 61 | notes=[] 62 | for elem in part.flat: 63 | if isinstance(elem, music21.note.Note): 64 | notes.append(note_data(elem.pitch, elem)) 65 | if isinstance(elem, music21.chord.Chord): 66 | for p in elem.pitches: 67 | notes.append(note_data(p, elem)) 68 | 69 | # sort notes by offset (1), duration (2) so that hits are not overwritten and longer notes have priority 70 | notes_sorted = sorted(notes, key=lambda x: (x[1], x[2])) 71 | for n in notes_sorted: 72 | if n is None: continue 73 | pitch,offset,duration = n 74 | if max_note_dur is not None and duration > max_note_dur: duration = max_note_dur 75 | score_arr[offset, idx, pitch] = duration 76 | score_arr[offset+1:offset+duration, idx, pitch] = VALTCONT # Continue holding note 77 | return score_arr 78 | 79 | def chordarr2npenc(chordarr, skip_last_rest=True): 80 | # combine instruments 81 | result = [] 82 | wait_count = 0 83 | for idx,timestep in enumerate(chordarr): 84 | flat_time = timestep2npenc(timestep) 85 | if len(flat_time) == 0: 86 | wait_count += 1 87 | else: 88 | # pitch, octave, duration, instrument 89 | if wait_count > 0: result.append([VALTSEP, wait_count]) 90 | result.extend(flat_time) 91 | wait_count = 1 92 | if wait_count > 0 and not skip_last_rest: result.append([VALTSEP, wait_count]) 93 | return np.array(result, dtype=int).reshape(-1, 2) # reshaping. Just in case result is empty 94 | 95 | # Note: not worrying about overlaps - as notes will still play. just look tied 96 | # http://web.mit.edu/music21/doc/moduleReference/moduleStream.html#music21.stream.Stream.getOverlaps 97 | def timestep2npenc(timestep, note_range=PIANO_RANGE, enc_type=None): 98 | # inst x pitch 99 | notes = [] 100 | for i,n in zip(*timestep.nonzero()): 101 | d = timestep[i,n] 102 | if d < 0: continue # only supporting short duration encoding for now 103 | if n < note_range[0] or n >= note_range[1]: continue # must be within midi range 104 | notes.append([n,d,i]) 105 | 106 | notes = sorted(notes, key=lambda x: x[0], reverse=True) # sort by note (highest to lowest) 107 | 108 | if enc_type is None: 109 | # note, duration 110 | return [n[:2] for n in notes] 111 | if enc_type == 'parts': 112 | # note, duration, part 113 | return [n for n in notes] 114 | if enc_type == 'full': 115 | # note_class, duration, octave, instrument 116 | return [[n%12, d, n//12, i] for n,d,i in notes] 117 | 118 | ##### DECODING ##### 119 | 120 | # 1. 
121 | def npenc2chordarr(npenc, note_size=NOTE_SIZE): 122 | num_instruments = 1 if len(npenc.shape) <= 2 else npenc.max(axis=0)[-1] 123 | 124 | max_len = npenc_len(npenc) 125 | # score_arr = (steps, inst, note) 126 | score_arr = np.zeros((max_len, num_instruments, note_size)) 127 | 128 | idx = 0 129 | for step in npenc: 130 | n,d,i = (step.tolist()+[0])[:3] # or n,d,i 131 | if n < VALTSEP: continue # special token 132 | if n == VALTSEP: 133 | idx += d 134 | continue 135 | score_arr[idx,i,n] = d 136 | return score_arr 137 | 138 | def npenc_len(npenc): 139 | duration = 0 140 | for t in npenc: 141 | if t[0] == VALTSEP: duration += t[1] 142 | return duration + 1 143 | 144 | 145 | # 2. 146 | def chordarr2stream(arr, sample_freq=SAMPLE_FREQ, bpm=120): 147 | duration = music21.duration.Duration(1. / sample_freq) 148 | stream = music21.stream.Score() 149 | stream.append(music21.meter.TimeSignature(TIMESIG)) 150 | stream.append(music21.tempo.MetronomeMark(number=bpm)) 151 | stream.append(music21.key.KeySignature(0)) 152 | for inst in range(arr.shape[1]): 153 | p = partarr2stream(arr[:,inst,:], duration) 154 | stream.append(p) 155 | stream = stream.transpose(0) 156 | return stream 157 | 158 | # 2b. 159 | def partarr2stream(partarr, duration): 160 | "convert instrument part to music21 chords" 161 | part = music21.stream.Part() 162 | part.append(music21.instrument.Piano()) 163 | part_append_duration_notes(partarr, duration, part) # notes already have duration calculated 164 | 165 | return part 166 | 167 | def part_append_duration_notes(partarr, duration, stream): 168 | "convert instrument part to music21 chords" 169 | for tidx,t in enumerate(partarr): 170 | note_idxs = np.where(t > 0)[0] # filter out any negative values (continuous mode) 171 | if len(note_idxs) == 0: continue 172 | notes = [] 173 | for nidx in note_idxs: 174 | note = music21.note.Note(nidx) 175 | note.duration = music21.duration.Duration(partarr[tidx,nidx]*duration.quarterLength) 176 | notes.append(note) 177 | for g in group_notes_by_duration(notes): 178 | if len(g) == 1: 179 | stream.insert(tidx*duration.quarterLength, g[0]) 180 | else: 181 | chord = music21.chord.Chord(g) 182 | stream.insert(tidx*duration.quarterLength, chord) 183 | return stream 184 | 185 | from itertools import groupby 186 | # combining notes with different durations into a single chord may overwrite conflicting durations. 
Example: aylictal/still-waters-run-deep 187 | def group_notes_by_duration(notes): 188 | "separate notes into chord groups" 189 | keyfunc = lambda n: n.duration.quarterLength 190 | notes = sorted(notes, key=keyfunc) 191 | return [list(g) for k,g in groupby(notes, keyfunc)] 192 | 193 | 194 | # Midi -> npenc Conversion helpers 195 | def is_valid_npenc(npenc, note_range=PIANO_RANGE, max_dur=DUR_SIZE, 196 | min_notes=32, input_path=None, verbose=True): 197 | if len(npenc) < min_notes: 198 | if verbose: print('Sequence too short:', len(npenc), input_path) 199 | return False 200 | if (npenc[:,1] >= max_dur).any(): 201 | if verbose: print(f'npenc exceeds max {max_dur} duration:', npenc[:,1].max(), input_path) 202 | return False 203 | # https://en.wikipedia.org/wiki/Scientific_pitch_notation - 88 key range - 21 = A0, 108 = C8 204 | if ((npenc[...,0] > VALTSEP) & ((npenc[...,0] < note_range[0]) | (npenc[...,0] >= note_range[1]))).any(): 205 | print(f'npenc out of piano note range {note_range}:', input_path) 206 | return False 207 | return True 208 | 209 | # seperates overlapping notes to different tracks 210 | def remove_overlaps(stream, separate_chords=True): 211 | if not separate_chords: 212 | return stream.flat.makeVoices().voicesToParts() 213 | return separate_melody_chord(stream) 214 | 215 | # seperates notes and chords to different tracks 216 | def separate_melody_chord(stream): 217 | new_stream = music21.stream.Score() 218 | if stream.timeSignature: new_stream.append(stream.timeSignature) 219 | new_stream.append(stream.metronomeMarkBoundaries()[0][-1]) 220 | if stream.keySignature: new_stream.append(stream.keySignature) 221 | 222 | melody_part = music21.stream.Part(stream.flat.getElementsByClass('Note')) 223 | melody_part.insert(0, stream.getInstrument()) 224 | chord_part = music21.stream.Part(stream.flat.getElementsByClass('Chord')) 225 | chord_part.insert(0, stream.getInstrument()) 226 | new_stream.append(melody_part) 227 | new_stream.append(chord_part) 228 | return new_stream 229 | 230 | # processing functions for sanitizing data 231 | 232 | def compress_chordarr(chordarr): 233 | return shorten_chordarr_rests(trim_chordarr_rests(chordarr)) 234 | 235 | def trim_chordarr_rests(arr, max_rests=4, sample_freq=SAMPLE_FREQ): 236 | # max rests is in quarter notes 237 | # max 1 bar between song start and end 238 | start_idx = 0 239 | max_sample = max_rests*sample_freq 240 | for idx,t in enumerate(arr): 241 | if (t != 0).any(): break 242 | start_idx = idx+1 243 | 244 | end_idx = 0 245 | for idx,t in enumerate(reversed(arr)): 246 | if (t != 0).any(): break 247 | end_idx = idx+1 248 | start_idx = start_idx - start_idx % max_sample 249 | end_idx = end_idx - end_idx % max_sample 250 | # if start_idx > 0 or end_idx > 0: print('Trimming rests. 
Start, end:', start_idx, len(arr)-end_idx, end_idx) 251 | return arr[start_idx:(len(arr)-end_idx)] 252 | 253 | def shorten_chordarr_rests(arr, max_rests=8, sample_freq=SAMPLE_FREQ): 254 | # max rests is in quarter notes 255 | # max 2 bar pause 256 | rest_count = 0 257 | result = [] 258 | max_sample = max_rests*sample_freq 259 | for timestep in arr: 260 | if (timestep==0).all(): 261 | rest_count += 1 262 | else: 263 | if rest_count > max_sample: 264 | # old_count = rest_count 265 | rest_count = (rest_count % sample_freq) + max_sample 266 | # print(f'Compressing rests: {old_count} -> {rest_count}') 267 | for i in range(rest_count): result.append(np.zeros(timestep.shape)) 268 | rest_count = 0 269 | result.append(timestep) 270 | for i in range(rest_count): result.append(np.zeros(timestep.shape)) 271 | return np.array(result) 272 | 273 | # sequence 2 sequence convenience functions 274 | 275 | def stream2npenc_parts(stream, sort_pitch=True): 276 | chordarr = stream2chordarr(stream) 277 | _,num_parts,_ = chordarr.shape 278 | parts = [part_enc(chordarr, i) for i in range(num_parts)] 279 | return sorted(parts, key=avg_pitch, reverse=True) if sort_pitch else parts 280 | 281 | def chordarr_combine_parts(parts): 282 | max_ts = max([p.shape[0] for p in parts]) 283 | parts_padded = [pad_part_to(p, max_ts) for p in parts] 284 | chordarr_comb = np.concatenate(parts_padded, axis=1) 285 | return chordarr_comb 286 | 287 | def pad_part_to(p, target_size): 288 | pad_width = ((0,target_size-p.shape[0]),(0,0),(0,0)) 289 | return np.pad(p, pad_width, 'constant') 290 | 291 | def part_enc(chordarr, part): 292 | partarr = chordarr[:,part:part+1,:] 293 | npenc = chordarr2npenc(partarr) 294 | return npenc 295 | 296 | def avg_tempo(t, sep_idx=VALTSEP): 297 | avg = t[t[:, 0] == sep_idx][:, 1].sum()/t.shape[0] 298 | avg = int(round(avg/SAMPLE_FREQ)) 299 | return 'mt'+str(min(avg, MTEMPO_SIZE-1)) 300 | 301 | def avg_pitch(t, sep_idx=VALTSEP): 302 | return t[t[:, 0] > sep_idx][:, 0].mean() 303 | -------------------------------------------------------------------------------- /musicautobot/multitask_transformer/model.py: -------------------------------------------------------------------------------- 1 | from fastai.basics import * 2 | from fastai.text.models.transformer import Activation, PositionalEncoding, feed_forward, init_transformer, _line_shift 3 | from fastai.text.models.awd_lstm import RNNDropout 4 | from ..utils.attention_mask import * 5 | 6 | def get_multitask_model(vocab_size:int, config:dict=None, drop_mult:float=1., pad_idx=None): 7 | "Create a language model from `arch` and its `config`, maybe `pretrained`." 
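# Usage sketch, mirroring the call in multitask_model_learner (values are illustrative):
#   config = multitask_config()
#   model = get_multitask_model(len(vocab), config=config, drop_mult=1., pad_idx=vocab.pad_idx)
# Note that `config` is modified in place below: every '*_p' dropout key is scaled by
# `drop_mult` and 'mem_len' is popped.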
8 | for k in config.keys(): 9 | if k.endswith('_p'): config[k] *= drop_mult 10 | n_hid = config['d_model'] 11 | mem_len = config.pop('mem_len') 12 | embed = TransformerEmbedding(vocab_size, n_hid, embed_p=config['embed_p'], mem_len=mem_len, pad_idx=pad_idx) 13 | encoder = MTEncoder(embed, n_hid, n_layers=config['enc_layers'], mem_len=0, **config) # encoder doesn't need memory 14 | decoder = MTEncoder(embed, n_hid, is_decoder=True, n_layers=config['dec_layers'], mem_len=mem_len, **config) 15 | head = MTLinearDecoder(n_hid, vocab_size, tie_encoder=embed.embed, **config) 16 | model = MultiTransformer(encoder, decoder, head, mem_len=mem_len) 17 | return model.apply(init_transformer) 18 | 19 | class MultiTransformer(nn.Module): 20 | "Multitask Transformer for training mask, next word, and sequence 2 sequence" 21 | def __init__(self, encoder, decoder, head, mem_len): 22 | super().__init__() 23 | self.encoder = encoder 24 | self.decoder = decoder 25 | self.head = head 26 | self.default_mem_len = mem_len 27 | self.current_mem_len = None 28 | 29 | def forward(self, inp): 30 | # data order: mask, next word, melody, chord 31 | outputs = {} 32 | msk, lm, c2m, m2c = [inp.get(key) for key in ['msk', 'lm', 'c2m', 'm2c']] 33 | 34 | if msk is not None: 35 | outputs['msk'] = self.head(self.encoder(msk['x'], msk['pos'])) 36 | if lm is not None: 37 | outputs['lm'] = self.head(self.decoder(lm['x'], lm['pos'])) 38 | 39 | if c2m is not None: 40 | self.reset() 41 | c2m_enc = self.encoder(c2m['enc'], c2m['enc_pos']) 42 | c2m_dec = self.decoder(c2m['dec'], c2m['dec_pos'], c2m_enc) 43 | outputs['c2m'] = self.head(c2m_dec) 44 | 45 | if m2c is not None: 46 | self.reset() 47 | m2c_enc = self.encoder(m2c['enc'], m2c['enc_pos']) 48 | m2c_dec = self.decoder(m2c['dec'], m2c['dec_pos'], m2c_enc) 49 | outputs['m2c'] = self.head(m2c_dec) 50 | 51 | return outputs 52 | 53 | "A sequential module that passes the reset call to its children." 
54 | def reset(self): 55 | for module in self.children(): 56 | reset_children(module) 57 | 58 | def reset_children(mod): 59 | if hasattr(mod, 'reset'): mod.reset() 60 | for module in mod.children(): 61 | reset_children(module) 62 | 63 | # COMPONENTS 64 | class TransformerEmbedding(nn.Module): 65 | "Embedding + positional encoding + dropout" 66 | def __init__(self, vocab_size:int, emb_sz:int, embed_p:float=0., mem_len=512, beat_len=32, max_bar_len=1024, pad_idx=None): 67 | super().__init__() 68 | self.emb_sz = emb_sz 69 | self.pad_idx = pad_idx 70 | 71 | self.embed = nn.Embedding(vocab_size, emb_sz, padding_idx=pad_idx) 72 | self.pos_enc = PositionalEncoding(emb_sz) 73 | self.beat_len, self.max_bar_len = beat_len, max_bar_len 74 | self.beat_enc = nn.Embedding(beat_len, emb_sz, padding_idx=0) 75 | self.bar_enc = nn.Embedding(max_bar_len, emb_sz, padding_idx=0) 76 | 77 | self.drop = nn.Dropout(embed_p) 78 | self.mem_len = mem_len 79 | 80 | def forward(self, inp, pos): 81 | beat_enc = self.beat_enc(pos % self.beat_len) 82 | bar_pos = pos // self.beat_len % self.max_bar_len 83 | bar_pos[bar_pos >= self.max_bar_len] = self.max_bar_len - 1 84 | bar_enc = self.bar_enc((bar_pos)) 85 | emb = self.drop(self.embed(inp) + beat_enc + bar_enc) 86 | return emb 87 | 88 | def relative_pos_enc(self, emb): 89 | # return torch.arange(640-1, -1, -1).float().cuda() 90 | seq_len = emb.shape[1] + self.mem_len 91 | pos = torch.arange(seq_len-1, -1, -1, device=emb.device, dtype=emb.dtype) # backwards (txl pos encoding) 92 | return self.pos_enc(pos) 93 | 94 | class MTLinearDecoder(nn.Module): 95 | "To go on top of a RNNCore module and create a Language Model." 96 | initrange=0.1 97 | 98 | def __init__(self, n_hid:int, n_out:int, output_p:float, tie_encoder:nn.Module=None, out_bias:bool=True, **kwargs): 99 | super().__init__() 100 | self.decoder = nn.Linear(n_hid, n_out, bias=out_bias) 101 | self.decoder.weight.data.uniform_(-self.initrange, self.initrange) 102 | self.output_dp = RNNDropout(output_p) 103 | if out_bias: self.decoder.bias.data.zero_() 104 | if tie_encoder: self.decoder.weight = tie_encoder.weight 105 | 106 | def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]: 107 | output = self.output_dp(input) 108 | decoded = self.decoder(output) 109 | return decoded 110 | 111 | 112 | # DECODER TRANSLATE BLOCK 113 | class MTEncoder(nn.Module): 114 | def __init__(self, embed:nn.Module, n_hid:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int, 115 | resid_p:float=0., attn_p:float=0., ff_p:float=0., bias:bool=True, scale:bool=True, 116 | act:Activation=Activation.ReLU, double_drop:bool=True, mem_len:int=512, is_decoder=False, 117 | mask_steps=1, mask_p=0.3, **kwargs): 118 | super().__init__() 119 | self.embed = embed 120 | self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention 121 | self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention 122 | self.n_layers,self.d_model = n_layers,d_model 123 | self.layers = nn.ModuleList([MTEncoderBlock(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p, 124 | ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, mem_len=mem_len, 125 | ) for k in range(n_layers)]) 126 | 127 | self.mask_steps, self.mask_p = mask_steps, mask_p 128 | self.is_decoder = is_decoder 129 | 130 | nn.init.normal_(self.u, 0., 0.02) 131 | nn.init.normal_(self.v, 0., 0.02) 132 | 133 | def forward(self, x_lm, lm_pos, msk_emb=None): 134 | bs,lm_len = 
x_lm.size() 135 | 136 | lm_emb = self.embed(x_lm, lm_pos) 137 | if msk_emb is not None and msk_emb.shape[1] > lm_emb.shape[1]: 138 | pos_enc = self.embed.relative_pos_enc(msk_emb) 139 | else: 140 | pos_enc = self.embed.relative_pos_enc(lm_emb) 141 | 142 | # Masks 143 | if self.is_decoder: 144 | lm_mask = rand_window_mask(lm_len, self.embed.mem_len, x_lm.device, 145 | max_size=self.mask_steps, p=self.mask_p, is_eval=not self.training) 146 | else: 147 | lm_mask = None 148 | 149 | for i, layer in enumerate(self.layers): 150 | lm_emb = layer(lm_emb, msk_emb, lm_mask=lm_mask, 151 | r=pos_enc, g_u=self.u, g_v=self.v) 152 | return lm_emb 153 | 154 | class MTEncoderBlock(nn.Module): 155 | "Decoder block of a Transformer model." 156 | #Can't use Sequential directly cause more than one input... 157 | def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0., 158 | bias:bool=True, scale:bool=True, double_drop:bool=True, mem_len:int=512, mha2_mem_len=0, **kwargs): 159 | super().__init__() 160 | attn_cls = MemMultiHeadRelativeAttentionKV 161 | self.mha1 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mem_len, r_mask=False) 162 | self.mha2 = attn_cls(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale, mem_len=mha2_mem_len, r_mask=True) 163 | self.ff = feed_forward(d_model, d_inner, ff_p=ff_p, double_drop=double_drop) 164 | 165 | def forward(self, enc_lm:Tensor, enc_msk:Tensor, 166 | r=None, g_u=None, g_v=None, 167 | msk_mask:Tensor=None, lm_mask:Tensor=None): 168 | 169 | y_lm = self.mha1(enc_lm, enc_lm, enc_lm, r, g_u, g_v, mask=lm_mask) 170 | if enc_msk is None: return y_lm 171 | return self.ff(self.mha2(y_lm, enc_msk, enc_msk, r, g_u, g_v, mask=msk_mask)) 172 | 173 | 174 | # Attention Layer 175 | 176 | 177 | # Attn 178 | 179 | class MemMultiHeadRelativeAttentionKV(nn.Module): 180 | "Attention Layer monster - relative positioning, keeps track of own memory, separate kv weights to support sequence2sequence decoding." 
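# Memory notes: `prev_k`/`prev_v` cache up to `mem_len` of the most recent keys/values
# (Transformer-XL style) so generation can attend beyond the current window. `reset()`
# clears the cache, and it is also rebuilt automatically whenever the batch size changes.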
181 | def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True, 182 | scale:bool=True, mem_len:int=512, r_mask=True): 183 | super().__init__() 184 | d_head = ifnone(d_head, d_model//n_heads) 185 | self.n_heads,self.d_head,self.scale = n_heads,d_head,scale 186 | 187 | assert(d_model == d_head * n_heads) 188 | self.q_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias) 189 | self.k_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias) 190 | self.v_wgt = nn.Linear(d_model, n_heads * d_head, bias=bias) 191 | 192 | self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p) 193 | self.ln = nn.LayerNorm(d_model) 194 | self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias) 195 | self.r_mask = r_mask 196 | 197 | self.mem_len = mem_len 198 | self.prev_k = None 199 | self.prev_v = None 200 | 201 | def forward(self, q:Tensor, k:Tensor=None, v:Tensor=None, 202 | r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None, 203 | mask:Tensor=None, **kwargs): 204 | if k is None: k = q 205 | if v is None: v = q 206 | return self.ln(q + self.drop_res(self._apply_attention(q, k, v, r, g_u, g_v, mask=mask, **kwargs))) 207 | 208 | def mem_k(self, k): 209 | if self.mem_len == 0: return k 210 | if self.prev_k is None or (self.prev_k.shape[0] != k.shape[0]): # reset if wrong batch size 211 | self.prev_k = k[:, -self.mem_len:] 212 | return k 213 | with torch.no_grad(): 214 | k_ext = torch.cat([self.prev_k, k], dim=1) 215 | self.prev_k = k_ext[:, -self.mem_len:] 216 | return k_ext.detach() 217 | 218 | def mem_v(self, v): 219 | if self.mem_len == 0: return v 220 | if self.prev_v is None or (self.prev_v.shape[0] != v.shape[0]): # reset if wrong batch size 221 | self.prev_v = v[:, -self.mem_len:] 222 | return v 223 | with torch.no_grad(): 224 | v_ext = torch.cat([self.prev_v, v], dim=1) 225 | self.prev_v = v_ext[:, -self.mem_len:] 226 | return v_ext.detach() 227 | 228 | def reset(self): 229 | self.prev_v = None 230 | self.prev_k = None 231 | 232 | def _apply_attention(self, q:Tensor, k:Tensor, v:Tensor, 233 | r:Tensor=None, g_u:Tensor=None, g_v:Tensor=None, 234 | mask:Tensor=None, **kwargs): 235 | #Notations from the paper: x input, r vector of relative distance between two elements, u et v learnable 236 | #parameters of the model common between all layers, mask to avoid cheating and mem the previous hidden states. 
237 | # bs,x_len,seq_len = q.size(0),q.size(1),r.size(0) 238 | k = self.mem_k(k) 239 | v = self.mem_v(v) 240 | bs,x_len,seq_len = q.size(0),q.size(1),k.size(1) 241 | wq,wk,wv = self.q_wgt(q),self.k_wgt(k),self.v_wgt(v) 242 | wq = wq[:,-x_len:] 243 | wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv)) 244 | wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3) 245 | wkr = self.r_attn(r[-seq_len:]) 246 | wkr = wkr.view(seq_len, self.n_heads, self.d_head) 247 | wkr = wkr.permute(1,2,0) 248 | #### compute attention score (AC is (a) + (c) and BS is (b) + (d) in the paper) 249 | AC = torch.matmul(wq+g_u,wk) 250 | BD = _line_shift(torch.matmul(wq+g_v, wkr), mask=self.r_mask) 251 | if self.scale: attn_score = (AC + BD).mul_(1/(self.d_head ** 0.5)) 252 | if mask is not None: 253 | mask = mask[...,-seq_len:] 254 | if hasattr(mask, 'bool'): mask = mask.bool() 255 | attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score) 256 | attn_prob = self.drop_att(F.softmax(attn_score, dim=-1)) 257 | attn_vec = torch.matmul(attn_prob, wv) 258 | return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1) 259 | -------------------------------------------------------------------------------- /musicautobot/multitask_transformer/learner.py: -------------------------------------------------------------------------------- 1 | from fastai.basics import * 2 | from ..vocab import * 3 | from ..utils.top_k_top_p import top_k_top_p 4 | from ..utils.midifile import is_empty_midi 5 | from ..music_transformer.transform import * 6 | from ..music_transformer.learner import filter_invalid_indexes 7 | from .model import get_multitask_model 8 | from .dataloader import * 9 | 10 | def multitask_model_learner(data:DataBunch, config:dict=None, drop_mult:float=1., 11 | pretrained_path:PathOrStr=None, **learn_kwargs) -> 'LanguageLearner': 12 | "Create a `Learner` with a language model from `data` and `arch`." 13 | vocab = data.vocab 14 | vocab_size = len(vocab) 15 | 16 | if pretrained_path: 17 | state = torch.load(pretrained_path, map_location='cpu') 18 | if config is None: config = state['config'] 19 | 20 | model = get_multitask_model(vocab_size, config=config, drop_mult=drop_mult, pad_idx=vocab.pad_idx) 21 | metrics = [AverageMultiMetric(partial(m, pad_idx=vocab.pad_idx)) for m in [mask_acc, lm_acc, c2m_acc, m2c_acc]] 22 | loss_func = MultiLoss(ignore_index=data.vocab.pad_idx) 23 | learn = MultitaskLearner(data, model, loss_func=loss_func, metrics=metrics, **learn_kwargs) 24 | 25 | if pretrained_path: 26 | get_model(model).load_state_dict(state['model'], strict=False) 27 | if not hasattr(learn, 'opt'): learn.create_opt(defaults.lr, learn.wd) 28 | try: learn.opt.load_state_dict(state['opt']) 29 | except: pass 30 | del state 31 | gc.collect() 32 | 33 | return learn 34 | 35 | class MultitaskLearner(Learner): 36 | def save(self, file:PathLikeOrBinaryStream=None, with_opt:bool=True, config=None): 37 | "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. 
`file` can be file-like (file or buffer)" 38 | out_path = super().save(file, return_path=True, with_opt=with_opt) 39 | if config and out_path: 40 | state = torch.load(out_path) 41 | state['config'] = config 42 | torch.save(state, out_path) 43 | del state 44 | gc.collect() 45 | return out_path 46 | 47 | def predict_nw(self, item:MusicItem, n_words:int=128, 48 | temperatures:float=(1.0,1.0), min_bars=4, 49 | top_k=30, top_p=0.6): 50 | "Return the `n_words` that come after `text`." 51 | self.model.reset() 52 | new_idx = [] 53 | vocab = self.data.vocab 54 | x, pos = item.to_tensor(), item.get_pos_tensor() 55 | last_pos = pos[-1] if len(pos) else 0 56 | y = torch.tensor([0]) 57 | 58 | start_pos = last_pos 59 | 60 | sep_count = 0 61 | bar_len = SAMPLE_FREQ * 4 # assuming 4/4 time 62 | vocab = self.data.vocab 63 | 64 | repeat_count = 0 65 | 66 | for i in progress_bar(range(n_words), leave=True): 67 | batch = { 'lm': { 'x': x[None], 'pos': pos[None] } }, y 68 | logits = self.pred_batch(batch=batch)['lm'][-1][-1] 69 | 70 | prev_idx = new_idx[-1] if len(new_idx) else vocab.pad_idx 71 | 72 | # Temperature 73 | # Use first temperatures value if last prediction was duration 74 | temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1] 75 | repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature 76 | temperature += repeat_penalty 77 | if temperature != 1.: logits = logits / temperature 78 | 79 | 80 | # Filter 81 | # bar = 16 beats 82 | filter_value = -float('Inf') 83 | if ((last_pos - start_pos) // 16) <= min_bars: logits[vocab.bos_idx] = filter_value 84 | 85 | logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value) 86 | logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value) 87 | 88 | # Sample 89 | probs = F.softmax(logits, dim=-1) 90 | idx = torch.multinomial(probs, 1).item() 91 | 92 | # Update repeat count 93 | num_choices = len(probs.nonzero().view(-1)) 94 | if num_choices <= 2: repeat_count += 1 95 | else: repeat_count = repeat_count // 2 96 | 97 | if prev_idx==vocab.sep_idx: 98 | duration = idx - vocab.dur_range[0] 99 | last_pos = last_pos + duration 100 | 101 | bars_pred = (last_pos - start_pos) // 16 102 | abs_bar = last_pos // 16 103 | # if (bars % 8 == 0) and (bars_pred > min_bars): break 104 | if (i / n_words > 0.80) and (abs_bar % 4 == 0): break 105 | 106 | 107 | if idx==vocab.bos_idx: 108 | print('Predicted BOS token. 
Returning prediction...') 109 | break 110 | 111 | new_idx.append(idx) 112 | x = x.new_tensor([idx]) 113 | pos = pos.new_tensor([last_pos]) 114 | 115 | pred = vocab.to_music_item(np.array(new_idx)) 116 | full = item.append(pred) 117 | return pred, full 118 | 119 | def predict_mask(self, masked_item:MusicItem, 120 | temperatures:float=(1.0,1.0), 121 | top_k=20, top_p=0.8): 122 | x = masked_item.to_tensor() 123 | pos = masked_item.get_pos_tensor() 124 | y = torch.tensor([0]) 125 | vocab = self.data.vocab 126 | self.model.reset() 127 | mask_idxs = (x == vocab.mask_idx).nonzero().view(-1) 128 | 129 | repeat_count = 0 130 | 131 | for midx in progress_bar(mask_idxs, leave=True): 132 | prev_idx = x[midx-1] 133 | 134 | # Using original positions, otherwise model gets too off track 135 | # pos = torch.tensor(-position_enc(xb[0].cpu().numpy()), device=xb.device)[None] 136 | 137 | # Next Word 138 | logits = self.pred_batch(batch=({ 'msk': { 'x': x[None], 'pos': pos[None] } }, y) )['msk'][0][midx] 139 | 140 | # Temperature 141 | # Use first temperatures value if last prediction was duration 142 | temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1] 143 | repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature 144 | temperature += repeat_penalty 145 | if temperature != 1.: logits = logits / temperature 146 | 147 | # Filter 148 | filter_value = -float('Inf') 149 | special_idxs = [vocab.bos_idx, vocab.sep_idx, vocab.stoi[EOS]] 150 | logits[special_idxs] = filter_value # Don't allow any special tokens (as we are only removing notes and durations) 151 | logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value) 152 | logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value) 153 | 154 | # Sampling 155 | probs = F.softmax(logits, dim=-1) 156 | idx = torch.multinomial(probs, 1).item() 157 | 158 | # Update repeat count 159 | num_choices = len(probs.nonzero().view(-1)) 160 | if num_choices <= 2: repeat_count += 1 161 | else: repeat_count = repeat_count // 2 162 | 163 | x[midx] = idx 164 | 165 | return vocab.to_music_item(x.cpu().numpy()) 166 | 167 | def predict_s2s(self, input_item:MusicItem, target_item:MusicItem, n_words:int=256, 168 | temperatures:float=(1.0,1.0), top_k=30, top_p=0.8, 169 | use_memory=True): 170 | vocab = self.data.vocab 171 | 172 | # Input doesn't change. 
We can reuse the encoder output on each prediction 173 | with torch.no_grad(): 174 | inp, inp_pos = input_item.to_tensor(), input_item.get_pos_tensor() 175 | x_enc = self.model.encoder(inp[None], inp_pos[None]) 176 | 177 | # target 178 | targ = target_item.data.tolist() 179 | targ_pos = target_item.position.tolist() 180 | last_pos = targ_pos[-1] 181 | self.model.reset() 182 | 183 | repeat_count = 0 184 | 185 | max_pos = input_item.position[-1] + SAMPLE_FREQ * 4 # Only predict until both tracks/parts have the same length 186 | x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos) 187 | 188 | for i in progress_bar(range(n_words), leave=True): 189 | # Predict 190 | with torch.no_grad(): 191 | dec = self.model.decoder(x[None], pos[None], x_enc) 192 | logits = self.model.head(dec)[-1, -1] 193 | 194 | # Temperature 195 | # Use first temperatures value if last prediction was duration 196 | prev_idx = targ[-1] if len(targ) else vocab.pad_idx 197 | temperature = temperatures[0] if vocab.is_duration_or_pad(prev_idx) else temperatures[1] 198 | repeat_penalty = max(0, np.log((repeat_count+1)/4)/5) * temperature 199 | temperature += repeat_penalty 200 | if temperature != 1.: logits = logits / temperature 201 | 202 | # Filter 203 | filter_value = -float('Inf') 204 | logits = filter_invalid_indexes(logits, prev_idx, vocab, filter_value=filter_value) 205 | logits = top_k_top_p(logits, top_k=top_k, top_p=top_p, filter_value=filter_value) 206 | 207 | # Sample 208 | probs = F.softmax(logits, dim=-1) 209 | idx = torch.multinomial(probs, 1).item() 210 | 211 | # Update repeat count 212 | num_choices = len(probs.nonzero().view(-1)) 213 | if num_choices <= 2: repeat_count += 1 214 | else: repeat_count = repeat_count // 2 215 | 216 | if idx == vocab.bos_idx | idx == vocab.stoi[EOS]: 217 | print('Predicting BOS/EOS') 218 | break 219 | 220 | if prev_idx == vocab.sep_idx: 221 | duration = idx - vocab.dur_range[0] 222 | last_pos = last_pos + duration 223 | if last_pos > max_pos: 224 | print('Predicted past counter-part length. Returning early') 225 | break 226 | 227 | targ_pos.append(last_pos) 228 | targ.append(idx) 229 | 230 | if use_memory: 231 | # Relying on memory for kv. 
Only need last prediction index 232 | x, pos = inp.new_tensor([targ[-1]]), inp_pos.new_tensor([targ_pos[-1]]) 233 | else: 234 | # Reset memory after each prediction, since we feeding the whole sequence every time 235 | self.model.reset() 236 | x, pos = inp.new_tensor(targ), inp_pos.new_tensor(targ_pos) 237 | 238 | return vocab.to_music_item(np.array(targ)) 239 | 240 | # High level prediction functions from midi file 241 | def nw_predict_from_midi(learn, midi=None, n_words=400, 242 | temperatures=(1.0,1.0), top_k=30, top_p=0.6, seed_len=None, **kwargs): 243 | vocab = learn.data.vocab 244 | seed = MusicItem.from_file(midi, vocab) if not is_empty_midi(midi) else MusicItem.empty(vocab) 245 | if seed_len is not None: seed = seed.trim_to_beat(seed_len) 246 | 247 | pred, full = learn.predict_nw(seed, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs) 248 | return full 249 | 250 | def s2s_predict_from_midi(learn, midi=None, n_words=200, 251 | temperatures=(1.0,1.0), top_k=24, top_p=0.7, seed_len=None, pred_melody=True, **kwargs): 252 | multitrack_item = MultitrackItem.from_file(midi, learn.data.vocab) 253 | melody, chords = multitrack_item.melody, multitrack_item.chords 254 | inp, targ = (chords, melody) if pred_melody else (melody, chords) 255 | 256 | # if seed_len is passed, cutoff sequence so we can predict the rest 257 | if seed_len is not None: targ = targ.trim_to_beat(seed_len) 258 | targ = targ.remove_eos() 259 | 260 | pred = learn.predict_s2s(inp, targ, n_words=n_words, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs) 261 | 262 | part_order = (pred, inp) if pred_melody else (inp, pred) 263 | return MultitrackItem(*part_order) 264 | 265 | def mask_predict_from_midi(learn, midi=None, predict_notes=True, 266 | temperatures=(1.0,1.0), top_k=30, top_p=0.7, section=None, **kwargs): 267 | item = MusicItem.from_file(midi, learn.data.vocab) 268 | masked_item = item.mask_pitch(section) if predict_notes else item.mask_duration(section) 269 | pred = learn.predict_mask(masked_item, temperatures=temperatures, top_k=top_k, top_p=top_p, **kwargs) 270 | return pred 271 | 272 | # LOSS AND METRICS 273 | 274 | class MultiLoss(): 275 | def __init__(self, ignore_index=None): 276 | "Loss mult - Mask, NextWord, Seq2Seq" 277 | self.loss = CrossEntropyFlat(ignore_index=ignore_index) 278 | 279 | def __call__(self, inputs:Dict[str,Tensor], targets:Dict[str,Tensor])->Rank0Tensor: 280 | losses = [self.loss(inputs[key], target) for key,target in targets.items()] 281 | return sum(losses) 282 | 283 | def acc_ignore_pad(input:Tensor, targ:Tensor, pad_idx)->Rank0Tensor: 284 | if input is None or targ is None: return None 285 | n = targ.shape[0] 286 | input = input.argmax(dim=-1).view(n,-1) 287 | targ = targ.view(n,-1) 288 | mask = targ != pad_idx 289 | return (input[mask]==targ[mask]).float().mean() 290 | 291 | def acc_index(inputs, targets, key, pad_idx): 292 | return acc_ignore_pad(inputs.get(key), targets.get(key), pad_idx) 293 | 294 | def mask_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'msk', pad_idx) 295 | def lm_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'lm', pad_idx) 296 | def c2m_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'c2m', pad_idx) 297 | def m2c_acc(inputs, targets, pad_idx): return acc_index(inputs, targets, 'm2c', pad_idx) 298 | 299 | 300 | class AverageMultiMetric(AverageMetric): 301 | "Updated fastai.AverageMetric to support multi task metrics." 
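# A per-task metric may return None when its task is absent from the batch
# (see `acc_ignore_pad`); such batches are skipped so the running average only
# covers batches that actually contained the task.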
302 | def on_batch_end(self, last_output, last_target, **kwargs): 303 | "Update metric computation with `last_output` and `last_target`." 304 | if not is_listy(last_target): last_target=[last_target] 305 | val = self.func(last_output, *last_target) 306 | if val is None: return 307 | self.count += first_el(last_target).size(0) 308 | if self.world: 309 | val = val.clone() 310 | dist.all_reduce(val, op=dist.ReduceOp.SUM) 311 | val /= self.world 312 | self.val += first_el(last_target).size(0) * val.detach().cpu() 313 | 314 | def on_epoch_end(self, last_metrics, **kwargs): 315 | "Set the final result in `last_metrics`." 316 | if self.count == 0: return add_metrics(last_metrics, 0) 317 | return add_metrics(last_metrics, self.val/self.count) 318 | 319 | 320 | # MODEL LOADING 321 | class MTTrainer(LearnerCallback): 322 | "`Callback` that regroups lr adjustment to seq_len, AR and TAR." 323 | def __init__(self, learn:Learner, dataloaders=None, starting_mask_window=1): 324 | super().__init__(learn) 325 | self.count = 1 326 | self.mw_start = starting_mask_window 327 | self.dataloaders = dataloaders 328 | 329 | def on_epoch_begin(self, **kwargs): 330 | "Reset the hidden state of the model." 331 | model = get_model(self.learn.model) 332 | model.reset() 333 | model.encoder.mask_steps = max(self.count+self.mw_start, 100) 334 | 335 | def on_epoch_end(self, last_metrics, **kwargs): 336 | "Finish the computation and sends the result to the Recorder." 337 | if self.dataloaders is not None: 338 | self.learn.data = self.dataloaders[self.count % len(self.dataloaders)] 339 | self.count += 1 340 | 341 | -------------------------------------------------------------------------------- /notebooks/music_transformer/Generate_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | "language_info": { 10 | "codemirror_mode": { 11 | "name": "ipython", 12 | "version": 3 13 | }, 14 | "file_extension": ".py", 15 | "mimetype": "text/x-python", 16 | "name": "python", 17 | "nbconvert_exporter": "python", 18 | "pygments_lexer": "ipython3", 19 | "version": "3.7.2" 20 | }, 21 | "colab": { 22 | "name": "Generate_colab.ipynb", 23 | "provenance": [], 24 | "collapsed_sections": [], 25 | "include_colab_link": true 26 | }, 27 | "accelerator": "GPU" 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "view-in-github", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "\"Open" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "z7ZG7C0jMf7d", 44 | "colab_type": "code", 45 | "colab": {} 46 | }, 47 | "source": [ 48 | "!git clone https://github.com/bearpelican/musicautobot.git" 49 | ], 50 | "execution_count": 0, 51 | "outputs": [] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "metadata": { 56 | "id": "ALIKjODvJdFa", 57 | "colab_type": "code", 58 | "colab": {} 59 | }, 60 | "source": [ 61 | "import os\n", 62 | "os.chdir('musicautobot')" 63 | ], 64 | "execution_count": 0, 65 | "outputs": [] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "metadata": { 70 | "id": "Sklh_SE1SE8a", 71 | "colab_type": "code", 72 | "colab": {} 73 | }, 74 | "source": [ 75 | "!apt install musescore fluidsynth\n", 76 | "!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2" 77 | ], 78 | "execution_count": 0, 79 | "outputs": [] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "metadata": { 84 | "id": "rEWuEC9_M1pi", 85 | 
"colab_type": "code", 86 | "colab": {} 87 | }, 88 | "source": [ 89 | "!pip install torch fastai music21 pebble fluidsynth midi2audio" 90 | ], 91 | "execution_count": 0, 92 | "outputs": [] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "metadata": { 97 | "id": "dzrRdTvlJdFd", 98 | "colab_type": "code", 99 | "colab": {} 100 | }, 101 | "source": [ 102 | "from musicautobot.numpy_encode import *\n", 103 | "from musicautobot.utils.file_processing import process_all, process_file\n", 104 | "from musicautobot.config import *\n", 105 | "from musicautobot.music_transformer import *\n", 106 | "from musicautobot.utils.setup_musescore import setup_musescore\n", 107 | "setup_musescore()" 108 | ], 109 | "execution_count": 0, 110 | "outputs": [] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "metadata": { 115 | "id": "qWjLAgmVXcoB", 116 | "colab_type": "code", 117 | "colab": {} 118 | }, 119 | "source": [ 120 | "from midi2audio import FluidSynth\n", 121 | "from IPython.display import Audio" 122 | ], 123 | "execution_count": 0, 124 | "outputs": [] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "metadata": { 129 | "id": "j48poJKnXjZp", 130 | "colab_type": "code", 131 | "colab": {} 132 | }, 133 | "source": [ 134 | "# Colab cannot play music directly from music21 - must convert to .wav first\n", 135 | "def play_wav(stream):\n", 136 | " out_midi = stream.write('midi')\n", 137 | " out_wav = str(Path(out_midi).with_suffix('.wav'))\n", 138 | " FluidSynth(\"font.sf2\").midi_to_audio(out_midi, out_wav)\n", 139 | " return Audio(out_wav)\n" 140 | ], 141 | "execution_count": 0, 142 | "outputs": [] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": { 147 | "id": "EckFLJkjJdFg", 148 | "colab_type": "text" 149 | }, 150 | "source": [ 151 | "# Generate Music with Pretrained Model" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": { 157 | "id": "FmhNm2TcJdFh", 158 | "colab_type": "text" 159 | }, 160 | "source": [ 161 | "### Load Pretrained" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "metadata": { 167 | "id": "hH7drZxkJdFi", 168 | "colab_type": "code", 169 | "colab": {} 170 | }, 171 | "source": [ 172 | "# Location of your midi files\n", 173 | "midi_path = Path('data/midi/examples')\n", 174 | "\n", 175 | "# Location of saved datset\n", 176 | "data_path = Path('data/numpy')" 177 | ], 178 | "execution_count": 0, 179 | "outputs": [] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "metadata": { 184 | "id": "eoDJWuSaJdFk", 185 | "colab_type": "code", 186 | "colab": {} 187 | }, 188 | "source": [ 189 | "# Data\n", 190 | "data = MusicDataBunch.empty(data_path)\n", 191 | "vocab = data.vocab\n", 192 | "\n", 193 | "# For Saved Data:\n", 194 | "# data = load_data(data_path, 'musicitem_data_save.pkl')" 195 | ], 196 | "execution_count": 0, 197 | "outputs": [] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "metadata": { 202 | "id": "JZeSORnPJdFm", 203 | "colab_type": "code", 204 | "colab": {} 205 | }, 206 | "source": [ 207 | "# Pretrained Model\n", 208 | "# Download pretrained model if you haven't already\n", 209 | "pretrained_url = 'https://ashaw-midi-web-server.s3-us-west-2.amazonaws.com/pretrained/MusicTransformerKeyC.pth'\n", 210 | "# pretrained_url = 'https://ashaw-midi-web-server.s3-us-west-2.amazonaws.com/pretrained/MusicTransformer.pth'\n", 211 | "\n", 212 | "pretrained_path = data_path/'pretrained'/Path(pretrained_url).name\n", 213 | "pretrained_path.parent.mkdir(parents=True, exist_ok=True)\n", 214 | "download_url(pretrained_url, dest=pretrained_path)" 215 | ], 216 | 
"execution_count": 0, 217 | "outputs": [] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "metadata": { 222 | "id": "CJqx0wAGJdFp", 223 | "colab_type": "code", 224 | "colab": {} 225 | }, 226 | "source": [ 227 | "# Learner\n", 228 | "learn = music_model_learner(data, pretrained_path=pretrained_path)" 229 | ], 230 | "execution_count": 0, 231 | "outputs": [] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": { 236 | "id": "G3HDp9EMJdFr", 237 | "colab_type": "text" 238 | }, 239 | "source": [ 240 | "## Prediction" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": { 246 | "id": "UDWCdsHFJdFr", 247 | "colab_type": "text" 248 | }, 249 | "source": [ 250 | "#### Choose existing midi file as a starting point" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "metadata": { 256 | "id": "61MAPpw0JdFs", 257 | "colab_type": "code", 258 | "colab": {} 259 | }, 260 | "source": [ 261 | "midi_files = get_files(midi_path, recurse=True, extensions='.mid'); midi_files[:4]" 262 | ], 263 | "execution_count": 0, 264 | "outputs": [] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "metadata": { 269 | "id": "OJA121IEJdFu", 270 | "colab_type": "code", 271 | "colab": {} 272 | }, 273 | "source": [ 274 | "idx = 1\n", 275 | "f = midi_files[idx]; f" 276 | ], 277 | "execution_count": 0, 278 | "outputs": [] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": { 283 | "id": "nJShT9HKJdFw", 284 | "colab_type": "text" 285 | }, 286 | "source": [ 287 | "#### NextWord/Autocomplete\n", 288 | "\n", 289 | "Trim the song to only a few notes. Model will use these notes a seed and continue the idea" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "metadata": { 295 | "id": "WUPmoGqcJdFx", 296 | "colab_type": "code", 297 | "colab": {} 298 | }, 299 | "source": [ 300 | "cutoff_beat = 10\n", 301 | "\n", 302 | "item = MusicItem.from_file(f, data.vocab)\n", 303 | "seed_item = item.trim_to_beat(cutoff_beat)" 304 | ], 305 | "execution_count": 0, 306 | "outputs": [] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "metadata": { 311 | "id": "eUwtyXdkJdFz", 312 | "colab_type": "code", 313 | "colab": {} 314 | }, 315 | "source": [ 316 | "seed_item.show()" 317 | ], 318 | "execution_count": 0, 319 | "outputs": [] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "metadata": { 324 | "id": "wp3IetalYDM8", 325 | "colab_type": "code", 326 | "colab": {} 327 | }, 328 | "source": [ 329 | "# seed_item.play()\n", 330 | "play_wav(seed_item.stream)" 331 | ], 332 | "execution_count": 0, 333 | "outputs": [] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": { 338 | "id": "JtYbbqALJdF5", 339 | "colab_type": "text" 340 | }, 341 | "source": [ 342 | "#### Use seed to predict next sequence" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "metadata": { 348 | "id": "gurscr8IJdF5", 349 | "colab_type": "code", 350 | "colab": {} 351 | }, 352 | "source": [ 353 | "pred, full = learn.predict(seed_item, n_words=400, temperatures=(1.1,0.4), min_bars=12, top_k=24, top_p=0.7)" 354 | ], 355 | "execution_count": 0, 356 | "outputs": [] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "metadata": { 361 | "id": "8wQCDsQ4JdF7", 362 | "colab_type": "code", 363 | "colab": {} 364 | }, 365 | "source": [ 366 | "pred.show()" 367 | ], 368 | "execution_count": 0, 369 | "outputs": [] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "metadata": { 374 | "scrolled": true, 375 | "id": "8biSV173JdF9", 376 | "colab_type": "code", 377 | "colab": {} 378 | }, 379 | "source": [ 380 | "play_wav(pred.stream)" 
381 | ], 382 | "execution_count": 0, 383 | "outputs": [] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "metadata": { 388 | "id": "UZg6f8mVJdF_", 389 | "colab_type": "code", 390 | "colab": {} 391 | }, 392 | "source": [ 393 | "full_song = seed_item.append(pred); full_song.show()" 394 | ], 395 | "execution_count": 0, 396 | "outputs": [] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "metadata": { 401 | "scrolled": true, 402 | "id": "-48trKgYJdGB", 403 | "colab_type": "code", 404 | "colab": {} 405 | }, 406 | "source": [ 407 | "play_wav(full_song.stream)" 408 | ], 409 | "execution_count": 0, 410 | "outputs": [] 411 | }, 412 | { 413 | "cell_type": "markdown", 414 | "metadata": { 415 | "id": "wQnTf4TxJdGD", 416 | "colab_type": "text" 417 | }, 418 | "source": [ 419 | "#### Add More Randomness to prediction" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "metadata": { 425 | "id": "s-xmD-avJdGD", 426 | "colab_type": "code", 427 | "colab": {} 428 | }, 429 | "source": [ 430 | "note_temp = 1.4 # Determines amount of variation in note pitches\n", 431 | "dur_temp = 0.8 # Amount of randomness in rhythm\n", 432 | "top_k = 30\n", 433 | "pred, full = learn.predict(seed_item, n_words=400, temperatures=(note_temp, dur_temp), min_bars=12, top_k=top_k, top_p=0.7)" 434 | ], 435 | "execution_count": 0, 436 | "outputs": [] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "metadata": { 441 | "id": "ypaK9xZQJdGG", 442 | "colab_type": "code", 443 | "colab": {} 444 | }, 445 | "source": [ 446 | "pred.show()" 447 | ], 448 | "execution_count": 0, 449 | "outputs": [] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "metadata": { 454 | "id": "wNQGCfyKYjfe", 455 | "colab_type": "code", 456 | "colab": {} 457 | }, 458 | "source": [ 459 | "play_wav(pred.stream)" 460 | ], 461 | "execution_count": 0, 462 | "outputs": [] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": { 467 | "id": "mU1pTyKiJdGJ", 468 | "colab_type": "text" 469 | }, 470 | "source": [ 471 | "### Pop Music Theory" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": { 477 | "id": "yDLbz6KGJdGK", 478 | "colab_type": "text" 479 | }, 480 | "source": [ 481 | "According to hooktheory, the most popular chord progression is I-V-vi-IV \n", 482 | "https://www.hooktheory.com/theorytab/common-chord-progressions" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "metadata": { 488 | "id": "IJTfsTapJdGK", 489 | "colab_type": "code", 490 | "colab": {} 491 | }, 492 | "source": [ 493 | "# Let's create a partial progression I-V-vi\n", 494 | "p = music21.stream.Part()\n", 495 | "p.append(music21.chord.Chord('C4 E4 G4', type='half')) # I\n", 496 | "p.append(music21.chord.Chord('G3 B3 D4', type='half')) # V\n", 497 | "p.append(music21.chord.Chord('A3 C4 E4', type='half')) # vi\n", 498 | "s = music21.stream.Score([p])\n", 499 | "chord_item = MusicItem.from_stream(s, vocab)\n", 500 | "chord_item.show()" 501 | ], 502 | "execution_count": 0, 503 | "outputs": [] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "metadata": { 508 | "id": "S2DXA4MaJdGM", 509 | "colab_type": "code", 510 | "colab": {} 511 | }, 512 | "source": [ 513 | "temperaturs = (0.5,0.5) # Let's lower the note randomness for this test\n", 514 | "pred, full = learn.predict(chord_item, n_words=10, temperatures=(0.5,0.5))\n", 515 | "full.show()" 516 | ], 517 | "execution_count": 0, 518 | "outputs": [] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "metadata": { 523 | "id": "A-H2KOqmJdGO", 524 | "colab_type": "code", 525 | "colab": {} 526 | }, 527 | "source": 
[ 528 | "# Predicted chords - IV\n", 529 | "play_wav(pred.stream)" 530 | ], 531 | "execution_count": 0, 532 | "outputs": [] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "metadata": { 537 | "id": "HP_Ofs_ZJdGP", 538 | "colab_type": "code", 539 | "colab": {} 540 | }, 541 | "source": [ 542 | "# Full sequence\n", 543 | "chord_item.append(pred).show()" 544 | ], 545 | "execution_count": 0, 546 | "outputs": [] 547 | }, 548 | { 549 | "cell_type": "markdown", 550 | "metadata": { 551 | "id": "dG3J6hM_JdGR", 552 | "colab_type": "text" 553 | }, 554 | "source": [ 555 | "Looks like it predicted the most popular progression!" 556 | ] 557 | }, 558 | { 559 | "cell_type": "markdown", 560 | "metadata": { 561 | "id": "b-hAoe5HJdGS", 562 | "colab_type": "text" 563 | }, 564 | "source": [ 565 | "#### Predict without a starting sequence" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "metadata": { 571 | "id": "63TFLmE6JdGS", 572 | "colab_type": "code", 573 | "colab": {} 574 | }, 575 | "source": [ 576 | "empty_item = MusicItem.empty(vocab)" 577 | ], 578 | "execution_count": 0, 579 | "outputs": [] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "metadata": { 584 | "id": "h9gF6rt7JdGU", 585 | "colab_type": "code", 586 | "colab": {} 587 | }, 588 | "source": [ 589 | "pred, full = learn.predict(empty_item, n_words=200)" 590 | ], 591 | "execution_count": 0, 592 | "outputs": [] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "metadata": { 597 | "id": "Khe7AV8HJdGV", 598 | "colab_type": "code", 599 | "colab": {} 600 | }, 601 | "source": [ 602 | "pred.show()" 603 | ], 604 | "execution_count": 0, 605 | "outputs": [] 606 | } 607 | ] 608 | } -------------------------------------------------------------------------------- /notebooks/multitask_transformer/Generate_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "name": "python3", 7 | "display_name": "Python 3" 8 | }, 9 | "language_info": { 10 | "codemirror_mode": { 11 | "name": "ipython", 12 | "version": 3 13 | }, 14 | "file_extension": ".py", 15 | "mimetype": "text/x-python", 16 | "name": "python", 17 | "nbconvert_exporter": "python", 18 | "pygments_lexer": "ipython3", 19 | "version": "3.7.2" 20 | }, 21 | "colab": { 22 | "name": "Generate_colab.ipynb", 23 | "provenance": [], 24 | "collapsed_sections": [], 25 | "include_colab_link": true 26 | }, 27 | "accelerator": "GPU" 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "view-in-github", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "\"Open" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "id": "NLW5T8KdcX3E", 44 | "colab_type": "code", 45 | "colab": {} 46 | }, 47 | "source": [ 48 | "!git clone https://github.com/bearpelican/musicautobot.git" 49 | ], 50 | "execution_count": 0, 51 | "outputs": [] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "metadata": { 56 | "id": "TCAs5NW5cX3L", 57 | "colab_type": "code", 58 | "colab": {} 59 | }, 60 | "source": [ 61 | "import os\n", 62 | "os.chdir('musicautobot')" 63 | ], 64 | "execution_count": 0, 65 | "outputs": [] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "metadata": { 70 | "id": "L9etuKmTdT_m", 71 | "colab_type": "code", 72 | "colab": {} 73 | }, 74 | "source": [ 75 | "!nvidia-smi" 76 | ], 77 | "execution_count": 0, 78 | "outputs": [] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "metadata": { 83 | "id": "0fYlE3becc1f", 84 | "colab_type": "code", 85 | 
"colab": {} 86 | }, 87 | "source": [ 88 | "!apt install musescore fluidsynth\n", 89 | "!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2\n", 90 | "!pip install torch fastai music21 pebble fluidsynth midi2audio" 91 | ], 92 | "execution_count": 0, 93 | "outputs": [] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "metadata": { 98 | "id": "Wlug8tPUcX3O", 99 | "colab_type": "code", 100 | "colab": {} 101 | }, 102 | "source": [ 103 | "from musicautobot.numpy_encode import *\n", 104 | "from musicautobot.utils.file_processing import process_all, process_file\n", 105 | "from musicautobot.config import *\n", 106 | "from musicautobot.music_transformer import *\n", 107 | "from musicautobot.multitask_transformer import *\n", 108 | "from musicautobot.numpy_encode import stream2npenc_parts\n", 109 | "from musicautobot.utils.setup_musescore import setup_musescore\n", 110 | "setup_musescore()" 111 | ], 112 | "execution_count": 0, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "metadata": { 118 | "id": "StzS1R1mcjhp", 119 | "colab_type": "code", 120 | "colab": {} 121 | }, 122 | "source": [ 123 | "from midi2audio import FluidSynth\n", 124 | "from IPython.display import Audio" 125 | ], 126 | "execution_count": 0, 127 | "outputs": [] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "metadata": { 132 | "id": "Wz0j0TcvclAw", 133 | "colab_type": "code", 134 | "colab": {} 135 | }, 136 | "source": [ 137 | "# Colab cannot play music directly from music21 - must convert to .wav first\n", 138 | "def play_wav(stream):\n", 139 | " out_midi = stream.write('midi')\n", 140 | " out_wav = str(Path(out_midi).with_suffix('.wav'))\n", 141 | " FluidSynth(\"font.sf2\").midi_to_audio(out_midi, out_wav)\n", 142 | " return Audio(out_wav)\n" 143 | ], 144 | "execution_count": 0, 145 | "outputs": [] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": { 150 | "id": "5u3Sup04cX3S", 151 | "colab_type": "text" 152 | }, 153 | "source": [ 154 | "# Generate Music with Pretrained Model" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": { 160 | "id": "fqdq9pHOcX3T", 161 | "colab_type": "text" 162 | }, 163 | "source": [ 164 | "### Load Pretrained" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "metadata": { 170 | "id": "IE8TC8lfcX3U", 171 | "colab_type": "code", 172 | "colab": {} 173 | }, 174 | "source": [ 175 | "# Config\n", 176 | "config = multitask_config();\n", 177 | "\n", 178 | "# Location of your midi files\n", 179 | "midi_path = Path('data/midi')\n", 180 | "\n", 181 | "# Location of saved datset\n", 182 | "data_path = Path('data/numpy')\n", 183 | "data_save_name = 'musicitem_data_save.pkl'" 184 | ], 185 | "execution_count": 0, 186 | "outputs": [] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "metadata": { 191 | "id": "R88krbh1cX3W", 192 | "colab_type": "code", 193 | "colab": {} 194 | }, 195 | "source": [ 196 | "# Data\n", 197 | "data = MusicDataBunch.empty(data_path)\n", 198 | "vocab = data.vocab" 199 | ], 200 | "execution_count": 0, 201 | "outputs": [] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "metadata": { 206 | "id": "hC531JPKcX3Y", 207 | "colab_type": "code", 208 | "colab": {} 209 | }, 210 | "source": [ 211 | "# Pretrained Model\n", 212 | "\n", 213 | "# Download pretrained model if you haven't already\n", 214 | "pretrained_url = 'https://ashaw-midi-web-server.s3-us-west-2.amazonaws.com/pretrained/MultitaskSmallKeyC.pth'\n", 215 | "# pretrained_url = 'https://ashaw-midi-web-server.s3-us-west-2.amazonaws.com/pretrained/MultitaskSmall.pth'\n", 216 | 
"\n", 217 | "pretrained_path = data_path/'pretrained'/Path(pretrained_url).name\n", 218 | "pretrained_path.parent.mkdir(parents=True, exist_ok=True)\n", 219 | "download_url(pretrained_url, dest=pretrained_path)" 220 | ], 221 | "execution_count": 0, 222 | "outputs": [] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "metadata": { 227 | "id": "ujbf2949cX3b", 228 | "colab_type": "code", 229 | "colab": {} 230 | }, 231 | "source": [ 232 | "# Learner\n", 233 | "learn = multitask_model_learner(data, pretrained_path=pretrained_path)\n", 234 | "# learn.to_fp16();" 235 | ], 236 | "execution_count": 0, 237 | "outputs": [] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": { 242 | "id": "lpoQP5ZVcX3d", 243 | "colab_type": "text" 244 | }, 245 | "source": [ 246 | "### Choose existing midi file as a starting point" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "metadata": { 252 | "id": "ovaLLyFbcX3e", 253 | "colab_type": "code", 254 | "colab": {} 255 | }, 256 | "source": [ 257 | "example_dir = midi_path/'examples'\n", 258 | "midi_files = get_files(example_dir, recurse=True, extensions='.mid'); midi_files[:5]" 259 | ], 260 | "execution_count": 0, 261 | "outputs": [] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "metadata": { 266 | "id": "9AShtRpNcX3h", 267 | "colab_type": "code", 268 | "colab": {} 269 | }, 270 | "source": [ 271 | "file = midi_files[3]; file" 272 | ], 273 | "execution_count": 0, 274 | "outputs": [] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "metadata": { 279 | "id": "CQd84OekcX3k", 280 | "colab_type": "code", 281 | "colab": {} 282 | }, 283 | "source": [ 284 | "# Encode file \n", 285 | "item = MusicItem.from_file(file, data.vocab)\n", 286 | "\n", 287 | "x = item.to_tensor()\n", 288 | "x_pos = item.get_pos_tensor()" 289 | ], 290 | "execution_count": 0, 291 | "outputs": [] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "metadata": { 296 | "id": "57DlVAp1cX3n", 297 | "colab_type": "code", 298 | "colab": {} 299 | }, 300 | "source": [ 301 | "item.show()" 302 | ], 303 | "execution_count": 0, 304 | "outputs": [] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "metadata": { 309 | "id": "7NAwJ3A6cX3q", 310 | "colab_type": "code", 311 | "colab": {} 312 | }, 313 | "source": [ 314 | "# item.play()\n", 315 | "play_wav(item.stream)" 316 | ], 317 | "execution_count": 0, 318 | "outputs": [] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": { 323 | "id": "yQ-ZyQX9cX3s", 324 | "colab_type": "text" 325 | }, 326 | "source": [ 327 | "## Generate" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": { 333 | "id": "A08P9D2JcX3t", 334 | "colab_type": "text" 335 | }, 336 | "source": [ 337 | "MultitaskTransformer trains on 3 separate tasks. \n", 338 | "1. NextWord\n", 339 | "2. Mask\n", 340 | "3. Sequence to Sequence" 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": { 346 | "id": "dF9uAEdYcX3u", 347 | "colab_type": "text" 348 | }, 349 | "source": [ 350 | "Because we train on 3 separate tasks, we can actually generate some really cool note sequences.\n", 351 | "\n", 352 | "1. NextWord/Autocomplete - Take a sequence of notes and predict the next note\n", 353 | " * 1a. Vanilla Language Model predictions - See [MusicTransformer](../music_transformer) project\n", 354 | "\n", 355 | "\n", 356 | "2. Mask/Remix - Mask certain parts of song and remix those portions.\n", 357 | " * 2a. Note Masking - Mask all the note pitches and create a new sequence with different notes, but same exact rhythm\n", 358 | " * 2b. 
Duration Masking - Mask the note durations. Generate a new sequence with the same melody, but with a different rhythm\n", 359 | "\n", 360 | "\n", 361 | "3. Seq2Seq/Translation - Generate melody from chords or vice versa. \n", 362 | " * 3a. New Melody - Generate a new melody from existing chords\n", 363 | " * 3b. Harmonization - Generate chords to acompany an existing melody" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "metadata": { 369 | "id": "4WFB3lIscX3v", 370 | "colab_type": "text" 371 | }, 372 | "source": [ 373 | "## 1. NextWord/Autocomplete\n", 374 | "\n", 375 | "Trim the song to only a few notes. Model will use these notes a seed and continue the idea" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "metadata": { 381 | "id": "i6RpYxmIcX3w", 382 | "colab_type": "code", 383 | "colab": {} 384 | }, 385 | "source": [ 386 | "seed_len = 6 # 4 beats = 1 bar\n", 387 | "seed = item.trim_to_beat(seed_len)" 388 | ], 389 | "execution_count": 0, 390 | "outputs": [] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "metadata": { 395 | "id": "BQTGsKEccX3y", 396 | "colab_type": "code", 397 | "colab": {} 398 | }, 399 | "source": [ 400 | "seed.show()" 401 | ], 402 | "execution_count": 0, 403 | "outputs": [] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "metadata": { 408 | "id": "sY1S6x_FcX30", 409 | "colab_type": "code", 410 | "colab": {} 411 | }, 412 | "source": [ 413 | "pred_nw, full = learn.predict_nw(seed, n_words=200)" 414 | ], 415 | "execution_count": 0, 416 | "outputs": [] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "metadata": { 421 | "id": "Z88xHG_qcX33", 422 | "colab_type": "code", 423 | "colab": {} 424 | }, 425 | "source": [ 426 | "pred_nw.show()" 427 | ], 428 | "execution_count": 0, 429 | "outputs": [] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "metadata": { 434 | "scrolled": true, 435 | "id": "3LAOILA8cX35", 436 | "colab_type": "code", 437 | "colab": {} 438 | }, 439 | "source": [ 440 | "play_wav(pred_nw.stream)" 441 | ], 442 | "execution_count": 0, 443 | "outputs": [] 444 | }, 445 | { 446 | "cell_type": "markdown", 447 | "metadata": { 448 | "id": "lMTkgf2OcX38", 449 | "colab_type": "text" 450 | }, 451 | "source": [ 452 | "Add more randomness" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "metadata": { 458 | "id": "Gytv3l22cX39", 459 | "colab_type": "code", 460 | "colab": {} 461 | }, 462 | "source": [ 463 | "pitch_temp = 1.4 # randomness of melody\n", 464 | "tempo_temp = 1.0 # randomness or rhythm\n", 465 | "top_k = 40\n", 466 | "pred_nw_rand, full = learn.predict_nw(seed, temperatures=(pitch_temp, tempo_temp), top_k=top_k, top_p=0.5)\n", 467 | "pred_nw_rand.show()" 468 | ], 469 | "execution_count": 0, 470 | "outputs": [] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "metadata": { 475 | "id": "QXQtaJc0d_n_", 476 | "colab_type": "code", 477 | "colab": {} 478 | }, 479 | "source": [ 480 | "play_wav(pred_nw_rand.stream)" 481 | ], 482 | "execution_count": 0, 483 | "outputs": [] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "metadata": { 488 | "id": "byUvCf4McX4C", 489 | "colab_type": "code", 490 | "colab": {} 491 | }, 492 | "source": [ 493 | "# Convenience function\n", 494 | "# out = nw_predict_from_midi(learn, file, seed_len=seed_len, top_k=30, top_p=0.5); out.show()" 495 | ], 496 | "execution_count": 0, 497 | "outputs": [] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": { 502 | "id": "nQRnIDhocX4o", 503 | "colab_type": "text" 504 | }, 505 | "source": [ 506 | "## 2. 
Seq2Seq/Translation" 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "metadata": { 512 | "id": "gtjVp2M4cX4p", 513 | "colab_type": "text" 514 | }, 515 | "source": [ 516 | "Load a MultitrackItem.\n", 517 | "\n", 518 | "MultitrackItem keeps track of which notes are part of the melody and which notes are part of the chords. \n", 519 | "This info is needed for the translation task" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "metadata": { 525 | "id": "cux2aE7YcX4p", 526 | "colab_type": "code", 527 | "colab": {} 528 | }, 529 | "source": [ 530 | "multitrack_item = MultitrackItem.from_file(file, vocab)" 531 | ], 532 | "execution_count": 0, 533 | "outputs": [] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "metadata": { 538 | "id": "fR4zixCEcX4r", 539 | "colab_type": "code", 540 | "colab": {} 541 | }, 542 | "source": [ 543 | "melody, chords = multitrack_item.melody, multitrack_item.chords" 544 | ], 545 | "execution_count": 0, 546 | "outputs": [] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "metadata": { 551 | "id": "dl9nGobNcX4u", 552 | "colab_type": "code", 553 | "colab": {} 554 | }, 555 | "source": [ 556 | "melody.show()" 557 | ], 558 | "execution_count": 0, 559 | "outputs": [] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "metadata": { 564 | "id": "J_AQG8QCcX4w", 565 | "colab_type": "code", 566 | "colab": {} 567 | }, 568 | "source": [ 569 | "chords.show()" 570 | ], 571 | "execution_count": 0, 572 | "outputs": [] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "metadata": { 577 | "id": "ITC7MqBHcX4y", 578 | "colab_type": "code", 579 | "colab": {} 580 | }, 581 | "source": [ 582 | "multitrack_item.play()" 583 | ], 584 | "execution_count": 0, 585 | "outputs": [] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "metadata": { 590 | "id": "DxdegKd9c9Ji", 591 | "colab_type": "code", 592 | "colab": {} 593 | }, 594 | "source": [ 595 | "play_wav(multitrack_item.stream)" 596 | ], 597 | "execution_count": 0, 598 | "outputs": [] 599 | }, 600 | { 601 | "cell_type": "markdown", 602 | "metadata": { 603 | "colab_type": "text", 604 | "id": "nRnZH289e97c" 605 | }, 606 | "source": [ 607 | "## 2a. 
Create Melody\n", 608 | "\n", 609 | "Use existing chord progression to generate a new melody" 610 | ] 611 | }, 612 | { 613 | "cell_type": "code", 614 | "metadata": { 615 | "colab_type": "code", 616 | "id": "RnSIGuCVe97X", 617 | "colab": {} 618 | }, 619 | "source": [ 620 | "# Use a seed for the melody\n", 621 | "partial_melody = melody.trim_to_beat(4)\n", 622 | "\n", 623 | "# Or generate from an empty sequence\n", 624 | "empty_melody = MusicItem.empty(vocab, seq_type=SEQType.Melody)" 625 | ], 626 | "execution_count": 0, 627 | "outputs": [] 628 | }, 629 | { 630 | "cell_type": "code", 631 | "metadata": { 632 | "colab_type": "code", 633 | "id": "9i_AWscAe97T", 634 | "colab": {} 635 | }, 636 | "source": [ 637 | "seed_melody = empty_melody; seed_melody.show()" 638 | ], 639 | "execution_count": 0, 640 | "outputs": [] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "metadata": { 645 | "colab_type": "code", 646 | "id": "9srCEkhLe97P", 647 | "colab": {} 648 | }, 649 | "source": [ 650 | "pred_melody = learn.predict_s2s(chords, seed_melody, use_memory=True)\n", 651 | "pred_melody.show()" 652 | ], 653 | "execution_count": 0, 654 | "outputs": [] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "metadata": { 659 | "colab_type": "code", 660 | "id": "UYGaz2Gme97E", 661 | "colab": {} 662 | }, 663 | "source": [ 664 | "play_wav(pred_melody.stream)" 665 | ], 666 | "execution_count": 0, 667 | "outputs": [] 668 | }, 669 | { 670 | "cell_type": "code", 671 | "metadata": { 672 | "colab_type": "code", 673 | "id": "67VVYrmle97B", 674 | "colab": {} 675 | }, 676 | "source": [ 677 | "combined = MultitrackItem(pred_melody, chords)\n", 678 | "combined.show()" 679 | ], 680 | "execution_count": 0, 681 | "outputs": [] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "metadata": { 686 | "colab_type": "code", 687 | "id": "opwBpBmbe964", 688 | "colab": {} 689 | }, 690 | "source": [ 691 | "play_wav(combined.stream)" 692 | ], 693 | "execution_count": 0, 694 | "outputs": [] 695 | }, 696 | { 697 | "cell_type": "markdown", 698 | "metadata": { 699 | "colab_type": "text", 700 | "id": "WSoL0V_We963" 701 | }, 702 | "source": [ 703 | "## 2b. 
Harmonization\n", 704 | "\n", 705 | "Generate chords to accompany an existing melody" 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "metadata": { 711 | "colab_type": "code", 712 | "id": "zu5kct73e96y", 713 | "colab": {} 714 | }, 715 | "source": [ 716 | "# partial_chords = chords.trim_to_beat(3);\n", 717 | "# partial_chords.show()\n", 718 | "\n", 719 | "empty_chords = MusicItem.empty(vocab, seq_type=SEQType.Chords); empty_chords.show()" 720 | ], 721 | "execution_count": 0, 722 | "outputs": [] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "metadata": { 727 | "colab_type": "code", 728 | "id": "2jr7znaQe96t", 729 | "colab": {} 730 | }, 731 | "source": [ 732 | "pred_chord = learn.predict_s2s(input_item=melody, target_item=empty_chords)" 733 | ], 734 | "execution_count": 0, 735 | "outputs": [] 736 | }, 737 | { 738 | "cell_type": "code", 739 | "metadata": { 740 | "colab_type": "code", 741 | "id": "4BX5Qhbee96n", 742 | "colab": {} 743 | }, 744 | "source": [ 745 | "pred_chord.show()" 746 | ], 747 | "execution_count": 0, 748 | "outputs": [] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "metadata": { 753 | "colab_type": "code", 754 | "id": "efzCw_kfe96i", 755 | "colab": {} 756 | }, 757 | "source": [ 758 | "combined = MultitrackItem(melody, pred_chord)\n", 759 | "combined.show()" 760 | ], 761 | "execution_count": 0, 762 | "outputs": [] 763 | }, 764 | { 765 | "cell_type": "code", 766 | "metadata": { 767 | "colab_type": "code", 768 | "id": "J8JZo71Pe96S", 769 | "colab": {} 770 | }, 771 | "source": [ 772 | "play_wav(combined.stream)" 773 | ], 774 | "execution_count": 0, 775 | "outputs": [] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "metadata": { 780 | "colab_type": "code", 781 | "id": "8itZPefde96M", 782 | "colab": {} 783 | }, 784 | "source": [ 785 | "# Convenience Function\n", 786 | "\n", 787 | "# out = s2s_predict_from_midi(learn, file, seed_len=10); out.show()" 788 | ], 789 | "execution_count": 0, 790 | "outputs": [] 791 | }, 792 | { 793 | "cell_type": "markdown", 794 | "metadata": { 795 | "id": "Td9sxdvocX4F", 796 | "colab_type": "text" 797 | }, 798 | "source": [ 799 | "## 3. Mask/Remix" 800 | ] 801 | }, 802 | { 803 | "cell_type": "markdown", 804 | "metadata": { 805 | "id": "chz6k0HUcX4G", 806 | "colab_type": "text" 807 | }, 808 | "source": [ 809 | "### 3a. Remix Notes\n", 810 | "\n", 811 | "Mask all the note pitches. 
Model will create a new song with the same rhythm" 812 | ] 813 | }, 814 | { 815 | "cell_type": "code", 816 | "metadata": { 817 | "id": "QJaWCnM0cX4H", 818 | "colab_type": "code", 819 | "colab": {} 820 | }, 821 | "source": [ 822 | "### Mask notes\n", 823 | "note_item = item.mask_pitch();" 824 | ], 825 | "execution_count": 0, 826 | "outputs": [] 827 | }, 828 | { 829 | "cell_type": "code", 830 | "metadata": { 831 | "id": "AMhdMYDIcX4J", 832 | "colab_type": "code", 833 | "colab": {} 834 | }, 835 | "source": [ 836 | "# Mask vs Original\n", 837 | "list(zip(note_item.to_text(None)[:20], item.to_text(None)[:20]))" 838 | ], 839 | "execution_count": 0, 840 | "outputs": [] 841 | }, 842 | { 843 | "cell_type": "code", 844 | "metadata": { 845 | "id": "lQJ0GbqDcX4L", 846 | "colab_type": "code", 847 | "colab": {} 848 | }, 849 | "source": [ 850 | "pred_note = learn.predict_mask(note_item, temperatures=(1.4, 1.0))" 851 | ], 852 | "execution_count": 0, 853 | "outputs": [] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "metadata": { 858 | "id": "DxDg7aDZcX4P", 859 | "colab_type": "code", 860 | "colab": {} 861 | }, 862 | "source": [ 863 | "pred_note.show()" 864 | ], 865 | "execution_count": 0, 866 | "outputs": [] 867 | }, 868 | { 869 | "cell_type": "code", 870 | "metadata": { 871 | "id": "vAWuXYQqcX4a", 872 | "colab_type": "code", 873 | "colab": {} 874 | }, 875 | "source": [ 876 | "play_wav(pred_note.stream)" 877 | ], 878 | "execution_count": 0, 879 | "outputs": [] 880 | }, 881 | { 882 | "cell_type": "markdown", 883 | "metadata": { 884 | "id": "x5-pfAv2cX4d", 885 | "colab_type": "text" 886 | }, 887 | "source": [ 888 | "### 3b. Remix rhythm\n", 889 | "\n", 890 | "Mask note durations. Same notes, different rhythm" 891 | ] 892 | }, 893 | { 894 | "cell_type": "code", 895 | "metadata": { 896 | "id": "w3U-r2YBcX4d", 897 | "colab_type": "code", 898 | "colab": {} 899 | }, 900 | "source": [ 901 | "# duration mask\n", 902 | "dur_item = item.mask_duration()" 903 | ], 904 | "execution_count": 0, 905 | "outputs": [] 906 | }, 907 | { 908 | "cell_type": "code", 909 | "metadata": { 910 | "id": "rsNtNDpNcX4f", 911 | "colab_type": "code", 912 | "colab": {} 913 | }, 914 | "source": [ 915 | "# Mask vs Original\n", 916 | "list(zip(dur_item.to_text(None)[:10], item.to_text(None)[:10]))" 917 | ], 918 | "execution_count": 0, 919 | "outputs": [] 920 | }, 921 | { 922 | "cell_type": "code", 923 | "metadata": { 924 | "id": "MaTeilzpcX4h", 925 | "colab_type": "code", 926 | "colab": {} 927 | }, 928 | "source": [ 929 | "dur_pred = learn.predict_mask(dur_item, temperatures=(0.8,0.8), top_k=40, top_p=0.6)" 930 | ], 931 | "execution_count": 0, 932 | "outputs": [] 933 | }, 934 | { 935 | "cell_type": "code", 936 | "metadata": { 937 | "id": "_sDkCHA9cX4k", 938 | "colab_type": "code", 939 | "colab": {} 940 | }, 941 | "source": [ 942 | "dur_pred.show()" 943 | ], 944 | "execution_count": 0, 945 | "outputs": [] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "metadata": { 950 | "id": "ecJL5whdc4-8", 951 | "colab_type": "code", 952 | "colab": {} 953 | }, 954 | "source": [ 955 | "play_wav(dur_pred.stream)" 956 | ], 957 | "execution_count": 0, 958 | "outputs": [] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "metadata": { 963 | "id": "9AVn6QrHcX4m", 964 | "colab_type": "code", 965 | "colab": {} 966 | }, 967 | "source": [ 968 | "# Convenience function\n", 969 | "# out = mask_predict_from_midi(learn, file, predict_notes=True)" 970 | ], 971 | "execution_count": 0, 972 | "outputs": [] 973 | } 974 | ] 975 | } 
--------------------------------------------------------------------------------