├── .DS_Store ├── img ├── 1.png ├── 2.png ├── 3.png ├── 4.png └── 5.png ├── src ├── .DS_Store ├── ds_files.pt ├── lib │ ├── .DS_Store │ ├── model │ │ ├── __pycache__ │ │ │ ├── rpr.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── transformer.cpython-37.pyc │ │ │ ├── transformer_bpe.cpython-37.pyc │ │ │ └── positional_encoding.cpython-37.pyc │ │ ├── positional_encoding.py │ │ ├── transformer.py │ │ └── rpr.py │ ├── constants.py │ ├── encoded_dataset.py │ ├── inverse_power_with_warmup_sheduler.py │ ├── colab_utils.py │ ├── augmentations.py │ ├── midi_processing.py │ └── generation.py ├── encoded_dataset │ ├── .DS_Store │ ├── pop │ │ ├── pop_0_80612d513c2c998c87cca766a76be91a_0.pt │ │ └── pop_0_80612d513c2c998c87cca766a76be91a_1.pt │ ├── calm │ │ ├── calm_0_7bfc0a94983dd5eb495ae0555efa4521_0.pt │ │ ├── calm_0_7bfc0a94983dd5eb495ae0555efa4521_1.pt │ │ └── calm_0_7bfc0a94983dd5eb495ae0555efa4521_2.pt │ ├── jazz │ │ ├── jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_0.pt │ │ ├── jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_1.pt │ │ └── jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_2.pt │ └── classic │ │ ├── classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_0.pt │ │ ├── classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_1.pt │ │ ├── classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_2.pt │ │ └── classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_3.pt ├── test_dataset │ ├── pop │ │ └── pop_0.mid │ ├── calm │ │ └── calm_0.mid │ ├── jazz │ │ └── jazz_0.mid │ └── classic │ │ └── classic_0.mid ├── encode_dataset.py ├── generate.py ├── train.py ├── Music_Composer_Demo_Colab_en.ipynb └── Music_Composer_Demo_Colab_ru.ipynb ├── gpt2-rga ├── README.md └── [GPT2RGA] Quantum_Music.ipynb ├── README.md └── README_ru.md /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/.DS_Store -------------------------------------------------------------------------------- /img/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/img/1.png -------------------------------------------------------------------------------- /img/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/img/2.png -------------------------------------------------------------------------------- /img/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/img/3.png -------------------------------------------------------------------------------- /img/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/img/4.png -------------------------------------------------------------------------------- /img/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/img/5.png -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/.DS_Store -------------------------------------------------------------------------------- /src/ds_files.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/ds_files.pt -------------------------------------------------------------------------------- /src/lib/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/.DS_Store -------------------------------------------------------------------------------- /src/encoded_dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/.DS_Store -------------------------------------------------------------------------------- /src/test_dataset/pop/pop_0.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/test_dataset/pop/pop_0.mid -------------------------------------------------------------------------------- /src/test_dataset/calm/calm_0.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/test_dataset/calm/calm_0.mid -------------------------------------------------------------------------------- /src/test_dataset/jazz/jazz_0.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/test_dataset/jazz/jazz_0.mid -------------------------------------------------------------------------------- /src/test_dataset/classic/classic_0.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/test_dataset/classic/classic_0.mid -------------------------------------------------------------------------------- /src/lib/model/__pycache__/rpr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/model/__pycache__/rpr.cpython-37.pyc -------------------------------------------------------------------------------- /src/lib/model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/lib/model/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/model/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /src/lib/model/__pycache__/transformer_bpe.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/model/__pycache__/transformer_bpe.cpython-37.pyc -------------------------------------------------------------------------------- /src/lib/model/__pycache__/positional_encoding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/model/__pycache__/positional_encoding.cpython-37.pyc 
-------------------------------------------------------------------------------- /src/encoded_dataset/pop/pop_0_80612d513c2c998c87cca766a76be91a_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/pop/pop_0_80612d513c2c998c87cca766a76be91a_0.pt -------------------------------------------------------------------------------- /src/encoded_dataset/pop/pop_0_80612d513c2c998c87cca766a76be91a_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/pop/pop_0_80612d513c2c998c87cca766a76be91a_1.pt -------------------------------------------------------------------------------- /src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_0.pt -------------------------------------------------------------------------------- /src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_1.pt -------------------------------------------------------------------------------- /src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_2.pt -------------------------------------------------------------------------------- /src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_0.pt -------------------------------------------------------------------------------- /src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_1.pt -------------------------------------------------------------------------------- /src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_2.pt -------------------------------------------------------------------------------- /src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_0.pt -------------------------------------------------------------------------------- /src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_1.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_1.pt -------------------------------------------------------------------------------- /src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_2.pt -------------------------------------------------------------------------------- /src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_3.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_3.pt -------------------------------------------------------------------------------- /gpt2-rga/README.md: -------------------------------------------------------------------------------- 1 | GPT2 RGA Version 2 | 3 | GPT2 is significantly superior to Transformer XL, so it is best to use GPT2 with RGA for the best results. 4 | 5 | In particular, pay attention to continuations, because GPT2 is the only architecture that can handle them properly. 6 | 7 | Alex 8 | -------------------------------------------------------------------------------- /src/lib/constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | RANGE_NOTE_ON = 128 4 | RANGE_NOTE_OFF = 128 5 | RANGE_VEL = 32 6 | RANGE_TIME_SHIFT = 100 7 | 8 | TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT 9 | TOKEN_PAD = TOKEN_END + 1 10 | VOCAB_SIZE = TOKEN_PAD + 1 + 4 11 | 12 | TORCH_FLOAT = torch.float32 13 | TORCH_INT = torch.int32 14 | 15 | TORCH_LABEL_TYPE = torch.long 16 | 17 | PREPEND_ZEROS_WIDTH = 4 18 | 19 | -------------------------------------------------------------------------------- /src/lib/model/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | # PositionalEncoding 6 | # Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html 7 | class PositionalEncoding(nn.Module): 8 | 9 | def __init__(self, d_model, dropout=0.1, max_len=5000): 10 | super(PositionalEncoding, self).__init__() 11 | self.dropout = nn.Dropout(p=dropout) 12 | 13 | pe = torch.zeros(max_len, d_model) 14 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | pe = pe.unsqueeze(0).transpose(0, 1) 19 | self.register_buffer('pe', pe) 20 | 21 | def forward(self, x): 22 | x = x + self.pe[:x.size(0), :] 23 | return self.dropout(x) 24 | -------------------------------------------------------------------------------- /src/lib/encoded_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class EncodedDataset(Dataset): 7 | """ 8 | Dataset class for training and evaluating the model.
9 | 10 | Parameters 11 | ---------- 12 | ds_files : str 13 | path to file 'ds_files.pt'. The file contains the list of paths to encoded sequences (samples of dataset). 14 | prefix_path : str 15 | prefix_path will be added to paths in 'ds_files.pt'. Used sometimes for convenience. 16 | transform : MusicAugmentations 17 | in-fly augmentations for sequences. 18 | """ 19 | def __init__(self, ds_files, prefix_path='', transform=None): 20 | self.transform = transform 21 | self.files = torch.load(ds_files) 22 | self.prefix_path = prefix_path 23 | self.genre2id = {'classic':0, 'jazz':1, 'calm':2, 'pop':3} 24 | self.genre = [self.genre2id.get(f.split('/')[1], 0) for f in self.files] # 1 for 'encoded_data/GENRE/xxx.pt' 25 | 26 | def __len__(self): 27 | return len(self.files) 28 | 29 | def __getitem__(self, idx): 30 | x = torch.load(self.prefix_path + self.files[idx]) 31 | if self.transform: 32 | x = torch.from_numpy(self.transform(x)) 33 | genre = self.genre[idx] 34 | return x, genre, idx 35 | -------------------------------------------------------------------------------- /src/lib/inverse_power_with_warmup_sheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class InversePowerWithWarmupLRScheduler(torch.optim.lr_scheduler._LRScheduler): 5 | """ 6 | Warmup learning rate until `warmup_steps` then apply inverse power function. 7 | It is more flexible version of inverse sqrt function. 8 | base formula after warmup: 1 / (i + shift) ** power 9 | 10 | Parameters 11 | ---------- 12 | optimizer : torch.optim 13 | PyTorch optimizer for model weights updates. 14 | peak_lr : float 15 | maximum learning rate at peak. 16 | warmup_steps : int 17 | Number of warmup steps. 18 | power : float 19 | power for LR function. Set to 1/2 for inverse sqrt function. 20 | shift : int 21 | shift helps control the duration of decay. 22 | last_epoch : int 23 | last_epoch is treated as last_step in this scheduler. 
24 | """ 25 | def __init__(self, optimizer, peak_lr, warmup_steps, power=0.5, shift=0, last_epoch=-1): 26 | self.peak_lr = peak_lr 27 | self.warmup_steps = warmup_steps 28 | self.power = power 29 | self.shift = shift 30 | self.warmup_rate = self.peak_lr / (self.warmup_steps) 31 | self.decay_factor = self.peak_lr * (self.warmup_steps + self.shift) ** self.power 32 | super(InversePowerWithWarmupLRScheduler, self).__init__(optimizer, last_epoch=last_epoch) 33 | 34 | def get_lr(self): 35 | i = self.last_epoch 36 | if i < self.warmup_steps: 37 | lr = self.warmup_rate * (i+1) 38 | else: 39 | lr = self.decay_factor / (i + self.shift) ** self.power 40 | 41 | lrs = [lr for _ in self.optimizer.param_groups] 42 | return lrs -------------------------------------------------------------------------------- /src/lib/colab_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | import base64 4 | import hashlib 5 | import ipywidgets 6 | import numpy as np 7 | import lib.midi_processing 8 | from midi2audio import FluidSynth 9 | from torch.utils.data import Dataset 10 | from IPython.display import Audio, display, FileLink, HTML 11 | 12 | 13 | id2genre = {0:'classic', 1:'jazz', 2:'calm', 3:'pop'} 14 | rugenre = {'classic': 'Классика', 'jazz': 'Джаз', 'calm': 'Эмбиент', 'pop': 'Поп'} 15 | genre2id = dict([[x[1],x[0]] for x in id2genre.items()]) 16 | tuned_params = { 17 | 0: 1.1, 18 | 1: 0.95, 19 | 2: 0.9, 20 | 3: 1.0 21 | } 22 | 23 | def decode_and_write(generated, primer, genre, out_dir): 24 | '''Decodes midi files from event-based format and writes them to disk''' 25 | if len(glob.glob(out_dir + '/*.mid')) != 0: 26 | ids = [int(path.split('_')[-2]) for path in glob.glob(out_dir + '/*.mid')] 27 | start_from = max(ids) 28 | else: 29 | start_from = 0 30 | 31 | for i, (gen, g) in enumerate(zip(generated, genre)): 32 | midi = lib.midi_processing.decode(gen) 33 | midi.write(f'{out_dir}/gen_{i + start_from:>02}_{id2genre[g]}.mid') 34 | 35 | def convert_midi_to_wav(midi_path): 36 | '''Converts MIDI to WAV format for listening in Colab''' 37 | FluidSynth("font.sf2").midi_to_audio(midi_path, midi_path.replace('.mid', '.wav')) 38 | 39 | class DownloadButton(ipywidgets.Button): 40 | """Download button with dynamic content 41 | 42 | The content is generated using a callback when the button is clicked. 43 | """ 44 | 45 | def __init__(self, filename: str, **kwargs): 46 | super(DownloadButton, self).__init__(**kwargs) 47 | self.filename = filename 48 | self.on_click(self.__on_click) 49 | 50 | def __on_click(self, b): 51 | with open(self.filename, 'rb') as f: 52 | b64 = base64.b64encode(f.read()) 53 | payload = b64.decode() 54 | digest = hashlib.md5(self.filename.encode()).hexdigest() # bypass browser cache 55 | id = f'dl_{digest}' 56 | 57 | display(HTML(f""" 58 | 59 | 60 | 61 | 62 | 63 | 68 | 69 | 70 | 71 | """)) 72 | -------------------------------------------------------------------------------- /src/lib/augmentations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from midi_processing import RANGES 3 | 4 | RANGES_SUM = np.cumsum(RANGES) 5 | 6 | class MusicAugmentations: 7 | def __init__(self, transpose=(-3,3), time_stretch=(0.95,0.975,1.0,1.025,1.05)): 8 | """ 9 | Class for applying random transpose and time_stretch augmentations for encoded sequences. 10 | 11 | Parameters 12 | ---------- 13 | transpose : tuple(min, max) 14 | range for transpose in pitches. 
15 | time_stretch : list 16 | list of time_stretch multipliers to sample from. 17 | """ 18 | self.transpose = range(transpose[0], transpose[1]+1) 19 | self.time_stretch = time_stretch 20 | 21 | def __call__(self, encoded): 22 | """encoded: list or 1D np.ndarray""" 23 | transpose = np.random.choice(self.transpose) 24 | # time_stretch = np.random.uniform(*self.time_stretch) 25 | time_stretch = np.random.choice(self.time_stretch) 26 | return augment(encoded, transpose, time_stretch) 27 | 28 | def augment(encoded, transpose, time_stretch): 29 | """ 30 | Applies transpose and time_stretch augmentation for encoded sequence. Inplace operation. 31 | 32 | Parameters 33 | ---------- 34 | encoded : np.ndarray or list 35 | encoded sequence (input for model). 36 | transpose : int 37 | bias for transpose in pitches. 38 | time_stretch : float 39 | time_stretch multiplier. 40 | 41 | Returns 42 | ------- 43 | encoded : np.array or list 44 | augmented sequence. 45 | """ 46 | for i,ev in enumerate(encoded): 47 | if ev < RANGES_SUM[0]: 48 | # NOTE_ON 49 | encoded[i] = min(RANGES_SUM[0]-1, max(0, ev+transpose)) 50 | elif ev < RANGES_SUM[1]: 51 | # NOTE_OFF 52 | encoded[i] = min(RANGES_SUM[1]-1, max(RANGES_SUM[0], ev+transpose)) 53 | elif ev < RANGES_SUM[2]: 54 | # SET_VELOCITY 55 | pass 56 | elif ev < RANGES_SUM[3] and time_stretch != 1.0: 57 | # TIME_SHIFT 58 | t = ev - RANGES_SUM[2] + 1 # since 0 = 10ms 59 | t = max(min(RANGES[3], int(round(t*time_stretch))), 1) 60 | encoded[i] = t + RANGES_SUM[2] - 1 61 | else: 62 | continue 63 | return encoded -------------------------------------------------------------------------------- /src/encode_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import joblib 4 | import hashlib 5 | import pretty_midi 6 | import numpy as np 7 | from tqdm import tqdm 8 | from pathlib import Path 9 | from concurrent.futures import ProcessPoolExecutor 10 | 11 | from lib import constants 12 | from lib import midi_processing 13 | 14 | DATA_DIR = 'test_dataset' 15 | OUTPUT_DIR = 'encoded_dataset' 16 | DS_FILE_PATH = './ds_files.pt' # path where ds_files.pt will be created 17 | 18 | GENRES = ['classic', 'jazz', 'calm', 'pop'] 19 | MAX_LEN = 2048 20 | 21 | print('creating dirs...') 22 | [os.makedirs(OUTPUT_DIR+'/'+g, exist_ok=True) for g in GENRES] 23 | 24 | print('collecting *.mid files...') 25 | FILES = list(map(str, Path(DATA_DIR).rglob('*.mid'))) 26 | 27 | def encode_fn(i): 28 | """wrapper for loading i-th midi-file, encoding, padding and saving encoded tensor on disk""" 29 | file = FILES[i] 30 | max_len = MAX_LEN 31 | 32 | path, fname = os.path.split(file) 33 | try: 34 | midi = pretty_midi.PrettyMIDI(file) 35 | genre = path.split('/')[1] # take GENRE from 'data/GENRE/xxx.mid' 36 | except: 37 | print(f'{i} not loaded') 38 | return -1 39 | 40 | assert genre in GENRES, f'{genre} is not in {GENRES}' 41 | 42 | fname, ext = os.path.splitext(fname) 43 | h = hashlib.md5(file.encode()).hexdigest() 44 | save_name = f'{OUTPUT_DIR}/{genre}/{fname}_{h}' 45 | 46 | events = midi_processing.encode(midi, use_piano_range=True) 47 | events = np.array(events) 48 | split_idxs = np.cumsum([max_len]*(events.shape[0]//max_len)) 49 | splits = np.split(events, split_idxs, axis=0) 50 | n_last = splits[-1].shape[0] 51 | if n_last < 256: 52 | splits.pop(-1) 53 | drop_last = 1 54 | else: 55 | drop_last = 0 56 | 57 | for i, split in enumerate(splits): 58 | keep_idxs = midi_processing.filter_bad_note_offs(split) 59 | split = split[keep_idxs] 
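# Pad each chunk to MAX_LEN with zeros and write TOKEN_END right after the last real event (or into the final slot if the chunk is already full).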
60 | eos_idx = min(max_len - 1, len(split)) 61 | split = np.pad(split, [[0,max_len - len(split)]]) 62 | split[eos_idx] = constants.TOKEN_END 63 | try: 64 | torch.save(split, f'{save_name}_{i}.pt') 65 | except OSError: # if fname is too long 66 | save_name = f'{OUTPUT_DIR}/{genre}/{h}' 67 | torch.save(split, f'{save_name}_{i}.pt') 68 | return drop_last 69 | 70 | cpu_count = joblib.cpu_count() 71 | print(f'starting encoding in {cpu_count} processes...') 72 | with ProcessPoolExecutor(cpu_count) as pool: 73 | x = list(tqdm(pool.map(encode_fn, range(len(FILES))), position=0, total=len(FILES))) 74 | 75 | print('collecting encoded (*.pt) files...') 76 | ds_files = list(map(str, Path(OUTPUT_DIR).rglob('*.pt'))) 77 | print('total encoded files:', len(ds_files)) 78 | 79 | torch.save(ds_files, DS_FILE_PATH) 80 | print('ds_files.pt saved to', os.path.abspath(DS_FILE_PATH)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Music Composer 2 | This repository is dedicated to synthesizing symbolic music in MIDI format using the Music Transformer model (103M parameters). In the repository, you can find a demo notebook for generation on a GPU Google Colab instance, as well as data preparation and model training code. 3 | 4 | ## Table of Contents 5 | 1. [Demo notebook](#demo-notebook) 6 | 2. [Model code](#model-code) 7 | 3. [Data](#data) 8 | 4. [Training](#training) 9 | 10 | 11 | ## Demo notebook 12 | 13 | The Jupyter Notebook can be opened on Colab by clicking the button: 14 | 15 | [![Open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sberbank-ai/music-composer/blob/main/src/Music_Composer_Demo_Colab_en.ipynb) 16 | 17 | It sets up the environment and loads the code and weights for synthesis. Generation parameters are set in the generation control panel, and you can listen to and download the results in the last cell. 18 | 19 | ❗ Make sure a GPU instance is being used at startup. It is possible to synthesize on a CPU, but it takes significantly more time. 20 | 21 | ## Model code 22 | Located in this [folder](https://github.com/sberbank-ai/music-composer/tree/main/src/lib/model). 23 | It consists of three main parts: 24 | - Positional encoding - standard positional encoding for transformer models 25 | - Relative Positional Representation - a module with the implementation of Relative Attention 26 | - Transformer - the transformer model itself 27 | 28 | The model code and relative attention are taken from this [repository](https://github.com/gwinndr/MusicTransformer-Pytorch). 29 | 30 | ## Data 31 | To demonstrate the encoding script, we provide several MIDI files from our training set. They are located in the src/test_dataset folder and are divided into folders by genre. Each folder contains one file for checking. You can start preparing the event-based versions of these files with the command: 32 | ```python encode_dataset.py``` 33 | 34 | The folder with the source MIDI files and the folder for the results are set inside the script through the variables `DATA_DIR`, `OUTPUT_DIR`. The dataset file with the paths to the encoded files will be created at `DS_FILE_PATH`. The genre list is specified with `GENRES`, and the maximum record length in event tokens is `MAX_LEN`. 35 | 36 | For demonstration, we also provide the output of this command in the encoded_dataset folder. It contains tensors with the MIDI files converted to the event-based format; they can be loaded with the standard `torch.load(file_path)`.
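For illustration, here is a minimal sketch (an example, not one of the repository scripts; it assumes the interpreter is started from the `src` folder so that `lib` and `encoded_dataset` are reachable) that loads one of the provided tensors and decodes it back to a MIDI file:

```python
import numpy as np
import torch

from lib import constants, midi_processing

# One of the provided example tensors: a zero-padded array of event tokens.
events = torch.load('encoded_dataset/pop/pop_0_80612d513c2c998c87cca766a76be91a_0.pt')

# Keep only the tokens before the end-of-sequence marker; the rest is padding.
end = np.where(events == constants.TOKEN_END)[0]
if len(end) > 0:
    events = events[:end[0]]

midi = midi_processing.decode(events)   # PrettyMIDI object
midi.write('decoded_example.mid')
```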
37 | The following public MIDI datasets can be used for training: 38 | [MAESTRO Dataset](https://magenta.tensorflow.org/datasets/maestro) 39 | [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/) 40 | [GiantMIDI-Piano Dataset](https://github.com/bytedance/GiantMIDI-Piano) 41 | 42 | There is another way to get MIDI files - transcribing audio recordings of music. An approach like [Onsets and Frames](https://magenta.tensorflow.org/onsets-frames) can help with this. 43 | As music for transcription, you can use, for example, the [Free Music Archive](https://github.com/mdeff/fma). 44 | ❗ Transcription may require significant resources, but it is exactly what makes it possible to get around the main limitation of current symbolic music generation models - the absence of large corpora of notes. 45 | ❗ After transcription, it is recommended to analyze the results and filter out bad recordings. 46 | ## Training 47 | The script for training a model on the prepared data can be run with: 48 | ```python train.py``` 49 | Training parameters are set inside the script in the `params` variable. A description of each of the parameters will be given later in this section. 50 | -------------------------------------------------------------------------------- /README_ru.md: -------------------------------------------------------------------------------- 1 | # Music Composer 2 | Данный репозиторий посвящен синтезу символьной музыки в MIDI формате с помощью модели Music Transformer. В репозитории можно найти демонстрационный ноутбук для генерации на GPU инстансе Google Colab, код подготовки данных и обучения модели. 3 | 4 | ## Оглавление 5 | 1. [Демонстрационный ноутбук](#демонстрационный-ноутбук) 6 | 2. [Код модели](#код-модели) 7 | 3. [Данные](#данные) 8 | 4. [Обучение](#обучение) 9 | 10 | 11 | ## Демонстрационный ноутбук 12 | 13 | Jupyter Notebook можно открыть на Colab, нажав на кнопку: 14 | 15 | [![Open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sberbank-ai/music-composer/blob/main/src/Music_Composer_Demo_Colab_ru.ipynb) 16 | 17 | В нем производится разворачивание окружения, подгрузка кода и весов для синтеза. Параметры генерации задаются в панели управления генерацией, а прослушать и скачать результаты можно в последней ячейке. 18 | 19 | ❗При запуске убедитесь в том, что используется GPU инстанс. Можно синтезировать и на CPU, но это занимает ощутимо больше времени. 20 | 21 | ## Код модели 22 | Расположен в [папке](https://github.com/sberbank-ai/music-composer/tree/main/src/lib/model). 23 | Состоит из трех основных частей: 24 | - Positional encoding - обычное позиционное кодирование для трансформерных моделей 25 | - Relative Positional Representation - модуль с реализацией Relative Attention 26 | - Transformer - сама модель трансформер 27 | 28 | Код модели и relative attention взят из [репозитория](https://github.com/gwinndr/MusicTransformer-Pytorch). 29 | 30 | ## Данные 31 | Для демонстрации скрипта энкодинга мы предоставляем несколько MIDI файлов из нашей обучающей выборки. Они находятся в папке src/test_dataset и разбиты по папкам на жанры. В каждой папке по одному файлу для проверки. Запустить подготовку закодированных в event-based формате версий этих файлов можно с помощью команды: 32 | ```python encode_dataset.py``` 33 | 34 | Папка с исходными MIDI и папка для результатов задаются внутри скрипта через переменные `DATA_DIR`, `OUTPUT_DIR`.
Файлы датасетов с путями до файлов будут созданы в `DS_FILE_PATH`. Список жанров задается через `GENRES`, а максимальная длина записи в event токенах - `MAX_LEN`. 35 | 36 | Для демонстрации мы также предоставляем результат работы данной команды в папке encoded_dataset. В нем находятся тензоры с MIDI, переведенными в event-based формат. Их можно загрузить с помощью стандартного `torch.load(file_path)` 37 | В качестве общедоступных MIDI для обучения можно использовать датасеты: 38 | [MAESTRO Dataset](https://magenta.tensorflow.org/datasets/maestro) 39 | [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/) 40 | [GiantMIDI-Piano Dataset](https://github.com/bytedance/GiantMIDI-Piano) 41 | 42 | Есть еще один способ получения MIDI файлов - транскрибирование волновых файлов с музыкой. В этом может помочь подход наподобие [Onsets and Frames](https://magenta.tensorflow.org/onsets-frames). 43 | В качестве музыки для транскрибирования можно использовать, например, [Free Music Archive](https://github.com/mdeff/fma). 44 | ❗Для транскрибирования могут потребоваться значительные ресурсы, однако именно это позволит обойти основное ограничение текущих моделей генерации символьной музыки - отсутствие крупных корпусов с нотами. 45 | ❗ После транскрибирования рекомендуется проанализировать результаты и отфильтровать плохие записи. 46 | ## Обучение 47 | Скрипт для обучения модели на подготовленных данных можно запустить с помощью: 48 | ```python train.py``` 49 | Параметры обучения задаются внутри скрипта в переменной `params`. Позднее в данном разделе будет дано описание каждого из параметров. 50 | -------------------------------------------------------------------------------- /src/generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import argparse 5 | import pretty_midi 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from lib import constants 10 | from lib import midi_processing 11 | from lib import generation 12 | from lib.midi_processing import PIANO_RANGE 13 | from lib.model.transformer import MusicTransformer 14 | 15 | 16 | def decode_and_write(generated, primer, genre, out_dir): 17 | '''Decodes sequences from the event-based format to MIDI and writes the resulting files to disk''' 18 | for i, (gen, g) in enumerate(zip(generated, genre)): 19 | midi = midi_processing.decode(gen) 20 | midi.write(f'{out_dir}/gen_{i:>02}_{id2genre[g]}.mid') 21 | 22 | 23 | id2genre = {0:'classic',1:'jazz',2:'calm',3:'pop'} 24 | genre2id = dict([[x[1],x[0]] for x in id2genre.items()]) 25 | tuned_params = { 26 | 0: 1.1, 27 | 1: 0.95, 28 | 2: 0.9, 29 | 3: 1.0 30 | } 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--genre') 36 | parser.add_argument('--target_seq_length', default=512, type=int) 37 | parser.add_argument('--temperature', default=None, type=float) 38 | parser.add_argument('--topk', default=40, type=int) 39 | parser.add_argument('--topp', default=0.99, type=float) 40 | parser.add_argument('--topp_temperature', default=1.0, type=float) 41 | parser.add_argument('--at_least_k', default=1, type=int) 42 | parser.add_argument('--use_rp', action='store_true') 43 | parser.add_argument('--rp_penalty', default=0.05, type=float) 44 | parser.add_argument('--rp_restore_speed', default=0.7, type=float) 45 | parser.add_argument('--seed', default=None, type=int) 46 | parser.add_argument('--device', default='cuda:0') 47 | parser.add_argument('--keep_bad_generations', action='store_true') 48 |
parser.add_argument('--out_dir', default=None) 49 | parser.add_argument('--load_path', default=None) 50 | parser.add_argument('--batch_size', default=8, type=int) 51 | args = parser.parse_args() 52 | 53 | 54 | try: 55 | genre_id = genre2id[args.genre] 56 | except KeyError: 57 | raise KeyError("Invalid genre name. Use one of ['classic', 'jazz', 'calm', 'pop']") 58 | 59 | load_path = args.load_path or '../checkpoints/model_big_v3_378k.pt' 60 | out_dir = args.out_dir or ('generated_' + time.strftime('%d-%m-%Y_%H-%M-%S')) 61 | batch_size = args.batch_size 62 | device = torch.device(args.device) 63 | remove_bad_generations = not args.keep_bad_generations 64 | 65 | default_params = dict( 66 | target_seq_length = 512, 67 | temperature = tuned_params[genre_id], 68 | topk = 40, 69 | topp = 0.99, 70 | topp_temperature = 1.0, 71 | at_least_k = 1, 72 | use_rp = False, 73 | rp_penalty = 0.05, 74 | rp_restore_speed = 0.7, 75 | seed = None, 76 | ) 77 | 78 | params = {k:args.__dict__[k] if args.__dict__[k] else default_params[k] for k in default_params} 79 | 80 | os.makedirs(out_dir, exist_ok=True) 81 | 82 | # init model 83 | print('loading model...') 84 | model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval() 85 | model.load_state_dict(torch.load(load_path, map_location=device)) 86 | 87 | # add information about genre (first token) 88 | primer_genre = np.repeat([genre_id], batch_size) 89 | primer = torch.tensor(primer_genre)[:,None] + constants.VOCAB_SIZE - 4 90 | 91 | print('generating to:', os.path.abspath(out_dir)) 92 | generated = generation.generate(model, primer, **params) 93 | generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations) 94 | 95 | decode_and_write(generated, primer, primer_genre, out_dir) -------------------------------------------------------------------------------- /src/lib/midi_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pretty_midi 3 | 4 | 5 | NOTE_ON = 0 6 | NOTE_OFF = 1 7 | SET_VELOCITY = 2 8 | TIME_SHIFT = 3 9 | 10 | MAX_TIME_SHIFT = 1.0 11 | TIME_SHIFT_STEP = 0.01 12 | RANGES = [128,128,32,100] 13 | 14 | PIANO_RANGE = [21,96] # 76 piano keys 15 | 16 | 17 | def encode(midi, use_piano_range=True): 18 | """ 19 | Encodes midi to event-based sequences for MusicTransformer. 20 | 21 | Parameters 22 | ---------- 23 | midi : prettyMIDI object 24 | MIDI to encode. 25 | use_piano_range : bool 26 | if True, classical piano range will be used for skip pitches. Pitches which are not in range PIANO_RANGE will be skipped. 27 | 28 | Returns 29 | ------- 30 | encoded_splits : list(list()) 31 | splits of encoded sequences. 32 | """ 33 | events = get_events(midi, use_piano_range=use_piano_range) 34 | if len(events) == 0: 35 | return [] 36 | quantize_(events) 37 | add_time_shifts(events) 38 | encoded = encode_events(events) 39 | return encoded 40 | 41 | 42 | def decode(encoded): 43 | """ 44 | Decode event-based encoded sequence into MIDI object. 45 | 46 | Parameters 47 | ---------- 48 | encoded : np.array or list 49 | encoded sequence to decode. 50 | 51 | Returns 52 | ------- 53 | midi_out: PrettyMIDI object 54 | decoded MIDI. 
55 | """ 56 | midi_out = pretty_midi.PrettyMIDI() 57 | midi_out.instruments.append(pretty_midi.Instrument(0, name='piano')) 58 | notes = midi_out.instruments[0].notes 59 | 60 | notes_tmp = {} # pitch: [vel, start, end] 61 | cur_time = 0 62 | cur_velocity = 100 63 | for ev in encoded: 64 | if ev < RANGES[0]: 65 | # NOTE_ON 66 | pitch = ev 67 | if notes_tmp.get(pitch) is None: 68 | notes_tmp[pitch] = [cur_velocity, cur_time] 69 | elif ev >= RANGES[0] and ev < sum(RANGES[:2]): 70 | # NOTE_OFF 71 | pitch = ev - RANGES[0] 72 | note = notes_tmp.get(pitch) 73 | if note is not None: # check for overlaps (first-OFF mode) 74 | notes.append(pretty_midi.Note(note[0], pitch, note[1], cur_time)) 75 | notes_tmp.pop(pitch) 76 | elif ev >= sum(RANGES[:2]) and ev < sum(RANGES[:3]): 77 | # SET_VELOCITY 78 | cur_velocity = max(1,(ev - sum(RANGES[:2]))*128//RANGES[2]) 79 | elif ev >= sum(RANGES[:3]) and ev < sum(RANGES[:]): 80 | # TIME_SHIFT 81 | cur_time += (ev - sum(RANGES[:3]) + 1)*TIME_SHIFT_STEP 82 | else: 83 | continue 84 | 85 | for pitch, note in notes_tmp.items(): 86 | if note[1] != cur_time: 87 | notes.append(pretty_midi.Note(note[0], pitch, note[1], cur_time)) 88 | 89 | return midi_out 90 | 91 | 92 | def round_step(x, step=0.01): 93 | return round(x/step)*step 94 | 95 | 96 | def get_events(midi, use_piano_range=False): 97 | # helper function used in encode() 98 | # time, type, value 99 | events = [] 100 | for inst in midi.instruments: 101 | if inst.is_drum: 102 | continue 103 | for note in inst.notes: 104 | if use_piano_range and not (PIANO_RANGE[0] <= note.pitch <= PIANO_RANGE[1]): 105 | continue 106 | start = note.start 107 | end = note.end 108 | events.append([start, SET_VELOCITY, note.velocity]) 109 | events.append([start, NOTE_ON, note.pitch]) 110 | events.append([end, NOTE_OFF, note.pitch]) 111 | events = sorted(events, key=lambda x: x[0]) 112 | return events 113 | 114 | 115 | def quantize_(events): 116 | for ev in events: 117 | ev[0] = round_step(ev[0]) 118 | 119 | 120 | def add_time_shifts(events): 121 | # populate time_shifts, helper function used in encode() 122 | times = np.array(list(zip(*events)))[0] 123 | diff = np.diff(times, prepend=0) 124 | idxs = diff.nonzero()[0] 125 | for i in reversed(idxs): 126 | if i == 0: 127 | continue 128 | t0 = events[i-1][0] # if i != 0 else 0 129 | t1 = events[i][0] 130 | dt = t1-t0 131 | events.insert(i, [t0, TIME_SHIFT, dt]) 132 | 133 | 134 | def encode_events(events): 135 | # helper function used in encode() 136 | out = [] 137 | types = [] 138 | for time, typ, value in events: 139 | offset = sum(RANGES[:typ]) 140 | 141 | if typ == SET_VELOCITY: 142 | value = value*RANGES[SET_VELOCITY]//128 143 | out.append(offset+value) 144 | types.append(typ) 145 | 146 | elif typ == TIME_SHIFT: 147 | dt = value 148 | n = RANGES[TIME_SHIFT] 149 | enc = lambda x: int(x*n)-1 150 | for _ in range(int(dt//MAX_TIME_SHIFT)): 151 | out.append(offset+enc(MAX_TIME_SHIFT)) 152 | types.append(typ) 153 | r = round_step(dt%MAX_TIME_SHIFT, TIME_SHIFT_STEP) 154 | if r > 0: 155 | out.append(offset+enc(r)) 156 | types.append(typ) 157 | 158 | else: 159 | out.append(offset+value) 160 | types.append(typ) 161 | 162 | return out 163 | 164 | 165 | RANGES_SUM = np.cumsum(RANGES) 166 | 167 | 168 | def get_type(ev): 169 | if ev < RANGES_SUM[0]: 170 | # NOTE_ON 171 | return 0 172 | elif ev < RANGES_SUM[1]: 173 | # NOTE_OFF 174 | return 1 175 | elif ev < RANGES_SUM[2]: 176 | # VEL 177 | return 2 178 | elif ev < RANGES_SUM[3]: 179 | # TS 180 | return 3 181 | else: 182 | return -1 183 | 184 | 185 | 
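# --- illustrative sketch (added for documentation; not part of the original module) ---
# How one note is tokenized by encode() above: a pitch-60 note held for 0.5 s at
# velocity 100, starting at t = 0, maps to the event tokens
#   SET_VELOCITY : 256 + 100*32//128 = 281
#   NOTE_ON      : 60
#   TIME_SHIFT   : 288 + int(0.5*100) - 1 = 337   (one step = 10 ms)
#   NOTE_OFF     : 128 + 60 = 188
def _encode_decode_example():
    """Tiny round-trip check of the event vocabulary described above (illustration only)."""
    demo = pretty_midi.PrettyMIDI()
    demo.instruments.append(pretty_midi.Instrument(0))
    demo.instruments[0].notes.append(pretty_midi.Note(100, 60, 0.0, 0.5))
    tokens = encode(demo)                      # -> [281, 60, 337, 188]
    note = decode(tokens).instruments[0].notes[0]
    return tokens, (note.pitch, note.velocity, note.start, note.end)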
def filter_bad_note_offs(events): 186 | """Clear NOTE_OFF events for which the corresponding NOTE_ON event is missing.""" 187 | notes_down = {} # pitch: 1 188 | keep_idxs = set(range(len(events))) 189 | 190 | for i,ev in enumerate(events): 191 | typ = get_type(ev) 192 | 193 | if typ == NOTE_ON: 194 | pitch = ev 195 | notes_down[pitch] = 1 196 | if typ == NOTE_OFF: 197 | pitch = ev-128 198 | if notes_down.get(pitch) is None: 199 | # if NOTE_OFF without NOTE_ON, then remove the event 200 | keep_idxs.remove(i) 201 | else: 202 | notes_down.pop(pitch) 203 | 204 | return list(keep_idxs) -------------------------------------------------------------------------------- /src/lib/model/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.normalization import LayerNorm 4 | import random 5 | 6 | from lib.constants import * 7 | from .positional_encoding import PositionalEncoding 8 | from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR 9 | 10 | 11 | # MusicTransformer 12 | class MusicTransformer(nn.Module): 13 | """ 14 | ---------- 15 | Author: Damon Gwinn 16 | ---------- 17 | Music Transformer reproduction from https://arxiv.org/abs/1809.04281. Arguments allow for 18 | tweaking the transformer architecture (https://arxiv.org/abs/1706.03762) and the rpr argument 19 | toggles Relative Position Representations (RPR - https://arxiv.org/abs/1803.02155). 20 | 21 | Supports training and generation using Pytorch's nn.Transformer class with dummy decoder to 22 | make a decoder-only transformer architecture 23 | 24 | For RPR support, there is modified Pytorch 1.2.0 code in rpr.py. Modified source will be 25 | kept up to date with Pytorch revisions only as necessary. 
26 | ---------- 27 | """ 28 | 29 | def __init__(self, device, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024, dropout=0.1, 30 | max_sequence=2048, rpr=False, vocab_size=VOCAB_SIZE, cond_vocab_size=None, reduce_qk=False): 31 | super(MusicTransformer, self).__init__() 32 | 33 | self.vocab_size = vocab_size 34 | self.cond_vocab_size = cond_vocab_size 35 | self.device = device 36 | self.dummy = DummyDecoder() 37 | 38 | self.nlayers = n_layers 39 | self.nhead = num_heads 40 | self.d_model = d_model 41 | self.d_ff = dim_feedforward 42 | self.dropout = dropout 43 | self.max_seq = max_sequence 44 | self.rpr = rpr 45 | self.reduce_qk = reduce_qk 46 | 47 | # Input embedding 48 | self.embedding = nn.Embedding(self.vocab_size, self.d_model) 49 | if self.cond_vocab_size is not None: 50 | self.cond_embedding = nn.Embedding(self.cond_vocab_size, self.d_model) 51 | else: 52 | self.cond_embedding = None 53 | 54 | # Positional encoding 55 | self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq) 56 | 57 | # Base transformer 58 | if(not self.rpr): 59 | # To make a decoder-only transformer we need to use masked encoder layers 60 | # Dummy decoder to essentially just return the encoder output 61 | self.transformer = nn.Transformer( 62 | d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers, 63 | num_decoder_layers=0, dropout=self.dropout, dim_feedforward=self.d_ff, custom_decoder=self.dummy 64 | ) 65 | # RPR Transformer 66 | else: 67 | encoder_norm = LayerNorm(self.d_model) 68 | encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout, 69 | er_len=self.max_seq, reduce_qk=self.reduce_qk, device=self.device) 70 | encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm) 71 | self.transformer = nn.Transformer( 72 | d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers, 73 | num_decoder_layers=0, dropout=self.dropout, dim_feedforward=self.d_ff, custom_decoder=self.dummy, 74 | custom_encoder=encoder 75 | ) 76 | 77 | # Final output is a softmaxed linear layer 78 | self.Wout = nn.Linear(self.d_model, self.vocab_size) 79 | self.softmax = nn.Softmax(dim=-1) 80 | self.mask = self.transformer.generate_square_subsequent_mask(max_sequence).to(self.device) 81 | 82 | # forward 83 | def forward(self, x, condition=None, mask=True): 84 | """ 85 | ---------- 86 | Author: Damon Gwinn 87 | ---------- 88 | Takes an input sequence and outputs predictions using a sequence to sequence method. 89 | 90 | A prediction at one index is the "next" prediction given all information seen previously. 
91 | ---------- 92 | """ 93 | if (mask is True): 94 | mask = self.mask[..., :x.shape[1], :x.shape[1]] 95 | else: 96 | mask = None 97 | x = self.embedding(x) 98 | if condition is not None and self.cond_embedding is not None: 99 | x_cond = self.cond_embedding(condition) 100 | x = x + x_cond[:, None] 101 | 102 | # Input shape is (max_seq, batch_size, d_model) 103 | x = x.permute(1,0,2) 104 | 105 | x = self.positional_encoding(x) 106 | 107 | # Since there are no true decoder layers, the tgt is unused 108 | # Pytorch wants src and tgt to have some equal dims however 109 | x_out = self.transformer(src=x, tgt=x, src_mask=mask) 110 | 111 | # Back to (batch_size, max_seq, d_model) 112 | x_out = x_out.permute(1,0,2) 113 | 114 | y = self.Wout(x_out) 115 | 116 | # They are trained to predict the next note in sequence (we don't need the last one) 117 | return y 118 | 119 | def get_norms(self): 120 | norm_dict = {'embedding_weight_norm': torch.norm(self.embedding.weight).item(), 121 | 'embedding_grad_norm': torch.norm(self.embedding.weight.grad).item(), 122 | 'output_weight_norm': torch.norm(self.Wout.weight).item(), 123 | 'output_grad_norm': torch.norm(self.Wout.weight.grad).item()} 124 | return norm_dict 125 | 126 | def get_parameters(self): 127 | return {'device': self.device, 128 | 'n_layers': self.nlayers, 129 | 'num_heads': self.nhead, 130 | 'd_model': self.d_model, 131 | 'dim_feedforward': self.d_ff, 132 | 'dropout': self.dropout, 133 | 'max_sequence': self.max_seq, 134 | 'rpr': self.rpr, 135 | 'vocab_size': self.vocab_size, 136 | 'cond_vocab_size': self.cond_vocab_size, 137 | 'reduce_qk': self.reduce_qk, 138 | } 139 | 140 | # Used as a dummy to nn.Transformer 141 | # DummyDecoder 142 | class DummyDecoder(nn.Module): 143 | """ 144 | ---------- 145 | Author: Damon Gwinn 146 | ---------- 147 | A dummy decoder that returns its input. 
Used to make the Pytorch transformer into a decoder-only 148 | architecture (stacked encoders with dummy decoder fits the bill) 149 | ---------- 150 | """ 151 | 152 | def __init__(self): 153 | super(DummyDecoder, self).__init__() 154 | 155 | def forward(self, tgt, memory, tgt_mask, memory_mask,tgt_key_padding_mask,memory_key_padding_mask): 156 | """ 157 | ---------- 158 | Author: Damon Gwinn 159 | ---------- 160 | Returns the input (memory) 161 | ---------- 162 | """ 163 | 164 | return memory 165 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os, sys, shutil 2 | import time 3 | import json 4 | import math 5 | import argparse 6 | import itertools 7 | import numpy as np 8 | import pandas as pd 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.utils.tensorboard import SummaryWriter 15 | from torch.nn.parallel import DistributedDataParallel 16 | from torch.utils.data import DataLoader, Subset, DistributedSampler, Dataset 17 | 18 | 19 | from lib import constants 20 | from lib.model.transformer import MusicTransformer 21 | from lib.inverse_power_with_warmup_sheduler import InversePowerWithWarmupLRScheduler  # the module file in lib/ is named 'inverse_power_with_warmup_sheduler.py' 22 | from lib.encoded_dataset import EncodedDataset 23 | from lib.augmentations import MusicAugmentations 24 | 25 | PAD_TOKEN = constants.TOKEN_PAD 26 | 27 | params = dict( 28 | NAME = 'model_name', 29 | DS_FILE_PATH = 'ds_files.pt', 30 | SEED = 0, 31 | num_epochs = 100, 32 | batch_size = 2, 33 | num_workers = 0, 34 | val_every = 6000, 35 | save_every = 6000, 36 | lr = 1e-4, 37 | use_scheduler = True, 38 | peak_lr = 1e-4, 39 | warmup_steps = 4000, 40 | power = 2, 41 | shift = 100000, 42 | LOAD_NAME = '', 43 | LOG_TOTAL_NORM = True, 44 | CLIPPING = False, 45 | gpus = [0,1,2,3], 46 | ) 47 | 48 | globals().update(params) 49 | 50 | 51 | def create_dataloaders(batch_size, num_workers=0): 52 | '''Initializes augmentations, loads file lists to datasets and loaders and returns them''' 53 | print('loading data...') 54 | 55 | aug = MusicAugmentations() 56 | 57 | tr_dataset = EncodedDataset(DS_FILE_PATH, transform=aug) 58 | vl_dataset = EncodedDataset(DS_FILE_PATH, transform=None) 59 | np.random.seed(0) 60 | idxs = np.random.permutation(len(tr_dataset)) 61 | vl, tr = np.split(idxs, [4000]) 62 | train_dataset = Subset(tr_dataset, tr) 63 | val_dataset = Subset(vl_dataset, vl) 64 | 65 | sampler = DistributedSampler(train_dataset, world_size, rank, True) 66 | train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, pin_memory=False, num_workers=num_workers) 67 | sampler = DistributedSampler(val_dataset, world_size, rank, False) 68 | val_loader = DataLoader(val_dataset, batch_size=batch_size*4, sampler=sampler, pin_memory=False, num_workers=num_workers) 69 | 70 | return train_loader, val_loader 71 | 72 | 73 | def init_model(lr, seed=0): 74 | '''Initializes model, loads weights if necessary and creates optimizer''' 75 | torch.manual_seed(seed) 76 | model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=390+4, rpr=True).to(device) 77 | if LOAD_NAME != '': 78 | model.load_state_dict(torch.load(LOAD_NAME, map_location=device)) 79 | print(f'Loaded model from {LOAD_NAME}') 80 | model = DistributedDataParallel(model, device_ids=[gpus[rank]]) 81 | print(sum((torch.numel(x) for x in model.parameters()))/1e6, 'M parameters')
82 | optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=1e-5) 83 | return model, optimizer 84 | 85 | def validate(model, val_loader): 86 | CE = 0 87 | ACC = 0 88 | n = 0 89 | model.eval() 90 | with torch.no_grad(): 91 | for x, genre, idxs in val_loader: 92 | x[x==0] = PAD_TOKEN 93 | tgt = x.clone() 94 | x[:,-1] = constants.VOCAB_SIZE - 4 + genre 95 | x = torch.roll(x, 1, -1) 96 | x, tgt = x.to(device), tgt.to(device) 97 | 98 | logits = model(x) 99 | pred = logits.argmax(-1) 100 | 101 | mask = tgt != PAD_TOKEN 102 | n += mask.sum().item() 103 | CE += F.cross_entropy(logits.view(-1, logits.shape[-1]), tgt.flatten(), ignore_index=PAD_TOKEN, reduction='sum').item() 104 | ACC += (pred[mask] == tgt[mask]).sum().item() 105 | 106 | model.train() 107 | return CE/n, ACC/n 108 | 109 | def train_ddp(rank_, world_size_): 110 | global device, NAME, SEED, rank, world_size 111 | rank, world_size = rank_, world_size_ 112 | 113 | os.environ['MASTER_ADDR'] = 'localhost' 114 | os.environ['MASTER_PORT'] = '12355' 115 | torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size) 116 | 117 | device = torch.device(f'cuda:{gpus[rank]}') 118 | print(rank, gpus[rank], device) 119 | 120 | train_loader, val_loader = create_dataloaders(batch_size, num_workers) 121 | 122 | model, optimizer = init_model(lr, SEED) 123 | if use_scheduler: 124 | scheduler = InversePowerWithWarmupLRScheduler(optimizer, peak_lr=peak_lr, warmup_steps=warmup_steps, power=power, shift=shift) 125 | 126 | if rank == 0: 127 | save_dir = f'output/{NAME}' 128 | save_name = f'{NAME}' 129 | if os.path.exists(save_dir): 130 | print(f'WARNING: {save_dir} exists! It may rewrite useful files') 131 | os.makedirs(save_dir, exist_ok=True) 132 | writer = SummaryWriter(f'runs/{save_name}') 133 | 134 | # TRAIN 135 | LS = {'loss':[], 'lr':[], 'val_ce':[], 'val_acc':[]} 136 | 137 | i_val = 0 138 | i_step = -1 139 | best_ce = float('inf') 140 | patience = 0 141 | for ep in range(num_epochs): 142 | model.train() 143 | train_loader.sampler.set_epoch(ep) 144 | if rank == 0: 145 | bar = tqdm(train_loader, position=rank) 146 | else: 147 | bar = train_loader 148 | for x, genre, idxs in bar: 149 | i_step += 1 150 | x[x==0] = PAD_TOKEN 151 | tgt = x.clone() 152 | x[:,-1] = constants.VOCAB_SIZE - 4 + genre 153 | x = torch.roll(x, 1, -1) 154 | x, tgt = x.to(device), tgt.to(device) 155 | logits = model(x) 156 | loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), tgt.flatten(), ignore_index=PAD_TOKEN) 157 | 158 | optimizer.zero_grad() 159 | loss.backward() 160 | 161 | if CLIPPING: 162 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CLIPPING).item() 163 | else: 164 | total_norm = 0 165 | 166 | optimizer.step() 167 | 168 | if use_scheduler: 169 | scheduler.step() 170 | 171 | if i_step == warmup_steps - 1 and rank == 0: 172 | torch.save(model.module.state_dict(), f'{save_dir}/model_{save_name}_after_warmup.pt') 173 | 174 | if rank == 0: 175 | # logs 176 | LS['loss'] += [loss.item()] 177 | LS['lr'] += [optimizer.param_groups[0]['lr']] 178 | writer.add_scalar(f'Train/embedding_weight_norm', torch.norm(model.module.embedding.weight).item(), i_step) 179 | writer.add_scalar(f'Train/embedding_grad_norm', torch.norm(model.module.embedding.weight.grad).item(), i_step) 180 | writer.add_scalar(f'Train/output_weight_norm', torch.norm(model.module.Wout.weight).item(), i_step) 181 | writer.add_scalar(f'Train/output_grad_norm', torch.norm(model.module.Wout.weight.grad).item(), i_step) 182 | writer.add_scalar(f'Train/loss', loss.item(), 
i_step) 183 | writer.add_scalar(f'Train/perplexity', math.exp(loss.item()), i_step) 184 | writer.add_scalar(f'Train/lr', optimizer.param_groups[0]['lr'], i_step) 185 | if LOG_TOTAL_NORM: 186 | total_norm = 0. 187 | for p in model.module.parameters(): 188 | param_norm = p.grad.detach().data.norm(2) 189 | total_norm += param_norm.item() ** 2 190 | total_norm = total_norm ** 0.5 191 | writer.add_scalar(f'Train/total_grad_norm', total_norm, i_step) 192 | bar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'], norm=total_norm) 193 | 194 | 195 | # VALIDATION 196 | if i_step % val_every == val_every-1: 197 | val_ce, val_acc = validate(model, val_loader) 198 | if world_size > 1: 199 | ce_all, acc_all = [[torch.zeros(1,device=device) for i in range(world_size)] for _ in range(2)] 200 | [torch.distributed.all_gather(a, torch.tensor(x, dtype=torch.float32, device=device)) for a,x in zip([ce_all,acc_all], [val_ce,val_acc])] 201 | val_ce, val_acc = [torch.cat(a).mean().item() for a in [ce_all,acc_all]] 202 | if rank == 0: 203 | # log, save, patience tracking 204 | LS['val_ce'] += [val_ce] 205 | LS['val_acc'] += [val_acc] 206 | writer.add_scalar(f'Val/ce', val_ce, i_val) 207 | writer.add_scalar(f'Val/acc', val_acc, i_val) 208 | writer.add_scalar(f'Val/perplexity', math.exp(val_ce), i_val) 209 | if val_ce < best_ce: 210 | patience = 0 211 | best_ce = val_ce 212 | torch.save({'history':LS,'epoch':ep,'params':params}, f'{save_dir}/hist_{save_name}_best.pt') 213 | torch.save(model.module.state_dict(), f'{save_dir}/model_{save_name}_best.pt') 214 | else: 215 | patience += 1 216 | print(f'{ep}: val_ce={val_ce}, val_acc={val_acc}, patience={patience}') 217 | i_val += 1 218 | 219 | # CHECKPOINT 220 | if (i_step % save_every == save_every-1) and rank == 0: 221 | torch.save({'history':LS,'epoch':ep,'params':params}, f'{save_dir}/hist_{save_name}.pt') 222 | torch.save(model.module.state_dict(), f'{save_dir}/model_{save_name}_{(i_step+1)//1000}k.pt') 223 | 224 | torch.distributed.destroy_process_group() 225 | 226 | 227 | if __name__ == "__main__": 228 | print(NAME, SEED) 229 | world_size = len(gpus) 230 | torch.multiprocessing.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True) 231 | -------------------------------------------------------------------------------- /src/lib/generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | from lib import midi_processing 6 | from lib import constants 7 | from lib.midi_processing import RANGES_SUM, get_type, NOTE_ON, NOTE_OFF 8 | from lib.midi_processing import PIANO_RANGE 9 | 10 | 11 | def generate(model, primer, target_seq_length=1024, temperature=1.0, topk=40, topp=0.99, topp_temperature=1.0, at_least_k=1, use_rp=False, rp_penalty=0.05, rp_restore_speed=0.7, seed=None, **forward_args): 12 | """ 13 | Generate batch of samples, conditioned on `primer`. There are used several techniques for acquiring better generated samples such as: 14 | - temperature skewing for controlling entropy of distribuitions 15 | - top-k sampling 16 | - top-p (nucleus) sampling (https://arxiv.org/abs/1904.09751) 17 | - DynamicRepetitionPenaltyProcessor that prevents notes repeating 18 | values by default usualy are suitable for our models 19 | 20 | Parameters 21 | ---------- 22 | model : MusicTransformer 23 | trained model. 24 | primer : torch.Tensor (B x N) 25 | primer for condition on. 26 | B = batch_size, N = seq_lenght. 
27 | We are using the primer consisted of one token - genre. These tokens are {390:'classic', 391:'jazz', 392:'calm', 393:'pop'}. 28 | target_seq_length : int 29 | desired length of generated sequences. 30 | temperature : float 31 | temperature alters the output distribuition of the model. Higher values ( > 1.0) lead to more stohastic sampling, lower values lead to more expected and predictable sequences (ending up with endlessly repeating musical patterns). 32 | topk : int 33 | restricts sampling from lower probabilities. It is the length of set of tokens from which sampling will be. 34 | topp : float 35 | restricts sampling from lower probabilities, but more adaptive then topk. see (https://arxiv.org/abs/1904.09751). 36 | topp_temperature : float 37 | temperature for counting cumulative sum doing topp sampling. 38 | at_least_k : int 39 | like topk, but force to sample from at least k tokens of higher probabilities. 40 | use_rp : bool 41 | use or not the DynamicRepetitionPenaltyProcessor (RP). Trying to prevent the generation of repeated notes. 42 | rp_penalty : float 43 | coef for RP. Higher values lead to more RP impact. 44 | rp_restore_speed : float 45 | how fast the penalty will be lifted. Lower values lead to more RP impact. 46 | seed : int 47 | fixes seed for deterministic generation. 48 | forward_args : dict 49 | args for model's forward. 50 | 51 | Returns 52 | ------- 53 | generated : torch.Tensor (B x target_seq_length) 54 | generated batch of sequences. 55 | """ 56 | device = model.device 57 | if seed is not None: 58 | torch.manual_seed(seed) 59 | np.random.seed(seed) 60 | 61 | if at_least_k < 1: 62 | at_least_k = 1 63 | B,N = primer.shape 64 | generated = torch.full((B,target_seq_length), constants.TOKEN_PAD, dtype=torch.int64, device=device) 65 | generated[..., :N] = primer.to(device) 66 | 67 | if use_rp: 68 | RP_processor = DynamicRepetitionPenaltyProcessor(B, penalty=rp_penalty, restore_speed=rp_restore_speed, device=device) 69 | whitelist_mask = make_whitelist_mask() 70 | 71 | model.eval() 72 | with torch.no_grad(): 73 | for i in tqdm(range(N, target_seq_length)): 74 | logits = model(generated[:, :i], **forward_args)[:, i-1, :] 75 | logits[:,~whitelist_mask] = float('-inf') 76 | p = torch.softmax(logits/topp_temperature, -1) 77 | 78 | # apply topk: 79 | if topk == 0: 80 | topk = p.shape[-1] 81 | p_topk, idxs = torch.topk(p, topk, -1, sorted=True) 82 | 83 | # apply topp: 84 | mask = p_topk.cumsum(-1) < topp 85 | mask[:,:at_least_k] = True 86 | logits_masked = logits.gather(-1, idxs) 87 | logits_masked[~mask] = float('-inf') 88 | p_topp = torch.softmax(logits_masked/temperature, -1) 89 | 90 | # apply penalty: 91 | if use_rp: 92 | p_penalized = RP_processor.apply_penalty(p_topp, idxs) 93 | ib = p_penalized.sum(-1) == 0 94 | if ib.sum() > 0: 95 | # if all topp tokens get zeroes due RP_processor, then fallback to topk-sampling 96 | p_fallback = p_topk[ib].clone() 97 | p_fallback[mask[ib]] = 0. 
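# p_fallback keeps the top-k probabilities of the tokens *outside* the top-p nucleus
# (the nucleus entries, which the repetition penalty has just zeroed for these rows, are masked out),
# so the affected rows are sampled from the remaining top-k candidates instead.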
# zeroing topp 98 | p_penalized[ib] = p_fallback 99 | 100 | ib = p_penalized.sum(-1) == 0 101 | if ib.sum() > 0: 102 | # if topk tokens get zeroes, fallback to topp without RP 103 | print('fallback-2') 104 | p_penalized = p_topp 105 | p_topp = p_penalized 106 | 107 | # sample: 108 | next_token = idxs.gather(-1, torch.multinomial(p_topp, 1)) 109 | generated[:, i] = next_token.squeeze(-1) 110 | 111 | # update penalty: 112 | if use_rp: 113 | RP_processor.update(next_token) 114 | 115 | return generated[:, :i+1] 116 | 117 | 118 | def post_process(generated, remove_bad_generations=True): 119 | """ 120 | Post-process does 3 routines: 121 | 1) removes long pauses (3+ seconds) 122 | 2) clips velocities to range(30,100) to avoid dramaticly loud notes, which are not suitable for our case. 123 | 3) removes bad generated samples. The model sometimes may generate music that consists only of many repeating notes. We try to detect them and remove from batch. 124 | 125 | Parameters 126 | ---------- 127 | generated : torch.Tensor (B x N) 128 | batch of generated samples 129 | 130 | Returns 131 | ------- 132 | filtered_generated : cleaner and slightly better sounding generated batch 133 | """ 134 | generated = generated.cpu().numpy() 135 | remove_pauses(generated, 3) 136 | clip_velocity(generated) 137 | 138 | bad_filter = np.ones(len(generated), dtype=bool) 139 | 140 | if remove_bad_generations: 141 | for i, gen in enumerate(generated): 142 | midi = midi_processing.decode(gen) 143 | if detect_note_repetition(midi) > 0.9: 144 | bad_filter[i] = False 145 | 146 | if np.sum(bad_filter) != len(bad_filter): 147 | print(f'{np.sum(~bad_filter)} bad samples will be removed.') 148 | 149 | return generated[bad_filter] 150 | 151 | 152 | def make_whitelist_mask(): 153 | """Generate mask for PIANO_RANGE""" 154 | whitelist_mask = np.zeros(constants.VOCAB_SIZE, dtype=bool) 155 | whitelist_mask[PIANO_RANGE[0]:PIANO_RANGE[1]+1] = True 156 | whitelist_mask[128+PIANO_RANGE[0]:128+PIANO_RANGE[1]+1] = True 157 | whitelist_mask[128*2:] = True 158 | return whitelist_mask 159 | 160 | 161 | class DynamicRepetitionPenaltyProcessor: 162 | """ 163 | The class is trying to prevent cases where the model generates repetitive notes or musical patterns that degrade quality. 164 | It dynamically reduces and restores the probabilities of generatied notes. 165 | Each generated note will reduce its probability for the next step by `penalty` value (which is hyperparameter). If this note has been generated again, then we continue to reduce its probability, else we will gradually restore its probability (speed is controlled by restore_speed parameter). 166 | 167 | Parameters 168 | ---------- 169 | bs : int 170 | batch_size. We need to know batch_size in advance to create the penalty_matrix. 171 | penalty : float 172 | value by which the probability will be reduced. 173 | restore_speed : float 174 | the number inversed to the number of seconds needs to fully restore probability from 0 to 1. 175 | for restore_speed equal to 1.0 we need 1.0 sec to restore, for 2.0 - 0.5 sec and so on. 
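    Examples
    --------
    A minimal usage sketch (illustrative only, not taken from the repository; the batch size,
    top-k width and token values below are invented). Inside `generate()` the processor is
    enabled with `use_rp=True` and is driven with the top-p probabilities and the sampled
    tokens, roughly like this:

        >>> import torch
        >>> rp = DynamicRepetitionPenaltyProcessor(bs=2, device='cpu', penalty=0.3, restore_speed=1.0)
        >>> p = torch.rand(2, 40)                   # probabilities over the 40 kept (top-k) candidates
        >>> idxs = torch.randint(0, 388, (2, 40))   # vocabulary ids of those candidates
        >>> p = rp.apply_penalty(p, idxs)           # NOTE-ON candidates are scaled by the penalty matrix
        >>> next_token = idxs.gather(-1, torch.multinomial(p, 1))
        >>> _ = rp.update(next_token)               # sampled NOTE-ONs are penalized, TIME-SHIFT tokens restore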
176 | """ 177 | def __init__(self, bs, device, penalty=0.3, restore_speed=1.0) : 178 | self.bs = bs 179 | self.penalty = penalty 180 | self.restore_speed = restore_speed 181 | self.penalty_matrix = torch.ones(bs,128).to(device) 182 | 183 | def apply_penalty(self, p, idxs): 184 | p = p.clone() 185 | for b in range(len(p)): 186 | i = idxs[b] 187 | pi = p[b] 188 | mask = i < 128 189 | if len(i) > 0: 190 | pi[mask] = pi[mask]*self.penalty_matrix[b,i[mask]] 191 | return p 192 | 193 | def update(self, next_token): 194 | restoring = next_token - (128+128+32) # only TS do restore 195 | restoring = torch.clamp(restoring.float(), 0, 100)/100*self.restore_speed 196 | self.penalty_matrix += restoring 197 | nt = next_token.squeeze(-1) 198 | nt = next_token[next_token < 128] 199 | self.penalty_matrix[:, nt] -= restoring + self.penalty 200 | torch.clamp(self.penalty_matrix, 0, 1.0, out=self.penalty_matrix) 201 | return restoring, nt 202 | 203 | 204 | def detect_note_repetition(midi, threshold_sec=0.01): 205 | """ 206 | Returns the fraction of note repetitions. Counts cases where prev_note_end == next_note_start at the same pitch ('glued' notes). Used in detection bad generated samples. 207 | 208 | Parameters 209 | ---------- 210 | midi : prettyMIDI object 211 | threshold_sec : float 212 | intervals smaller then threshold_sec are treated as 'glued' notes. 213 | 214 | Returns 215 | ------- 216 | fraction of notes repetitions relative to the number of all notes. 217 | """ 218 | all_notes = [x for inst in midi.instruments for x in inst.notes if not inst.is_drum] 219 | if len(all_notes) == 0: 220 | return 0 221 | all_notes_np = np.array([[x.start,x.end,x.pitch,x.velocity] for x in all_notes]) 222 | 223 | i_sort = np.lexsort([all_notes_np[:,0], all_notes_np[:,2]]) 224 | 225 | s = [] 226 | cur_p = -1 227 | cur_t = -1 228 | for t in all_notes_np[i_sort]: 229 | a,b,p,v = t 230 | if cur_p != p: 231 | cur_p = p 232 | else: 233 | s.append(a-cur_t) 234 | cur_t = b 235 | s = np.array(s) 236 | return (s < threshold_sec).sum()/len(s) 237 | 238 | 239 | def remove_pauses(generated, threshold=3): 240 | """ 241 | Fills pauses by constants.TOKEN_PAD values. Only pauses that longer than `threshold` seconds are considered. 242 | Inplace operation. `generated` is a tensor (batch of sequences). 243 | 244 | Parameters 245 | ---------- 246 | generated : torch.Tensor (B x N) 247 | generated batch of sequences. 248 | threshold : int/float 249 | the minimum seconds of silence to treat them as a pause. 250 | """ 251 | mask = (generated>=RANGES_SUM[2]) & (generated= threshold and notes_down.sum() == 0: 271 | res_ab[ib].append([a,i,s]) 272 | s = 0 273 | a = i+1 274 | s += t 275 | if s >= threshold and notes_down.sum() == 0: 276 | res_ab[ib].append([a,len(i_seconds),s]) 277 | 278 | # remove inplace 279 | for ib,t in enumerate(res_ab): 280 | for a,b,s in t: 281 | generated[ib, a:b] = constants.TOKEN_PAD 282 | print(f'pause removed:',ib,f'n={b-a}',a,b,s) 283 | 284 | 285 | def clip_velocity(generated, min_velocity=30, max_velocity=100): 286 | """ 287 | Clip velocity to range(min_velocity, max_velocity). Since the model sometimes generate overloud sequences, we try to neutralize this effect. 288 | Inplace operation. `generated` is a tensor (batch of sequences). 289 | 290 | Parameters 291 | ---------- 292 | generated : torch.Tensor (B x N) 293 | generated batch of sequences. 
294 | min_velocity : int 295 | max_velocity : int 296 | """ 297 | max_velocity_encoded = max_velocity*32//128 + RANGES_SUM[1] 298 | min_velocity_encoded = min_velocity*32//128 + RANGES_SUM[1] 299 | 300 | mask = (generated>=RANGES_SUM[1]) & (generated Tuple[Tensor, Optional[Tensor]] 242 | 243 | qkv_same = torch.equal(query, key) and torch.equal(key, value) 244 | if reduce_qk: 245 | #print('Using reduced qk') 246 | qkv_same = False 247 | use_separate_proj_weight = True 248 | kv_same = torch.equal(key, value) 249 | 250 | tgt_len, bsz, embed_dim = query.size() 251 | assert embed_dim == embed_dim_to_check 252 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 253 | assert key.size() == value.size() 254 | 255 | head_dim = embed_dim // num_heads 256 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" 257 | scaling = float(head_dim) ** -0.5 258 | 259 | if use_separate_proj_weight is not True: 260 | if qkv_same: 261 | # self-attention 262 | q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) 263 | 264 | elif kv_same: 265 | # encoder-decoder attention 266 | # This is inline in_proj function with in_proj_weight and in_proj_bias 267 | _b = in_proj_bias 268 | _start = 0 269 | _end = embed_dim 270 | _w = in_proj_weight[_start:_end, :] 271 | if _b is not None: 272 | _b = _b[_start:_end] 273 | q = linear(query, _w, _b) 274 | 275 | if key is None: 276 | assert value is None 277 | k = None 278 | v = None 279 | else: 280 | 281 | # This is inline in_proj function with in_proj_weight and in_proj_bias 282 | _b = in_proj_bias 283 | _start = embed_dim 284 | _end = None 285 | _w = in_proj_weight[_start:, :] 286 | if _b is not None: 287 | _b = _b[_start:] 288 | k, v = linear(key, _w, _b).chunk(2, dim=-1) 289 | 290 | else: 291 | # This is inline in_proj function with in_proj_weight and in_proj_bias 292 | _b = in_proj_bias 293 | _start = 0 294 | _end = embed_dim 295 | _w = in_proj_weight[_start:_end, :] 296 | if _b is not None: 297 | _b = _b[_start:_end] 298 | q = linear(query, _w, _b) 299 | 300 | # This is inline in_proj function with in_proj_weight and in_proj_bias 301 | _b = in_proj_bias 302 | _start = embed_dim 303 | _end = embed_dim * 2 304 | _w = in_proj_weight[_start:_end, :] 305 | if _b is not None: 306 | _b = _b[_start:_end] 307 | k = linear(key, _w, _b) 308 | 309 | # This is inline in_proj function with in_proj_weight and in_proj_bias 310 | _b = in_proj_bias 311 | _start = embed_dim * 2 312 | _end = None 313 | _w = in_proj_weight[_start:, :] 314 | if _b is not None: 315 | _b = _b[_start:] 316 | v = linear(value, _w, _b) 317 | else: 318 | q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) 319 | if not reduce_qk: 320 | len1, len2 = q_proj_weight_non_opt.size() 321 | assert len1 == embed_dim and len2 == query.size(-1) 322 | 323 | k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) 324 | if not reduce_qk: 325 | len1, len2 = k_proj_weight_non_opt.size() 326 | assert len1 == embed_dim and len2 == key.size(-1) 327 | 328 | v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) 329 | len1, len2 = v_proj_weight_non_opt.size() 330 | assert len1 == embed_dim and len2 == value.size(-1) 331 | 332 | if in_proj_bias is not None: 333 | if reduce_qk: 334 | q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:int(embed_dim/2)]) 335 | k = linear(key, k_proj_weight_non_opt, in_proj_bias[int(embed_dim/2):embed_dim]) 336 | v = linear(value, v_proj_weight_non_opt, in_proj_bias[embed_dim:]) 337 | else: 338 | q = linear(query, 
q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) 339 | k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) 340 | v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) 341 | else: 342 | q = linear(query, q_proj_weight_non_opt, in_proj_bias) 343 | k = linear(key, k_proj_weight_non_opt, in_proj_bias) 344 | v = linear(value, v_proj_weight_non_opt, in_proj_bias) 345 | q = q * scaling 346 | 347 | if bias_k is not None and bias_v is not None: 348 | if static_k is None and static_v is None: 349 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 350 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 351 | if attn_mask is not None: 352 | attn_mask = torch.cat([attn_mask, 353 | torch.zeros((attn_mask.size(0), 1), 354 | dtype=attn_mask.dtype, 355 | device=attn_mask.device)], dim=1) 356 | if key_padding_mask is not None: 357 | key_padding_mask = torch.cat( 358 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 359 | dtype=key_padding_mask.dtype, 360 | device=key_padding_mask.device)], dim=1) 361 | else: 362 | assert static_k is None, "bias cannot be added to static key." 363 | assert static_v is None, "bias cannot be added to static value." 364 | else: 365 | assert bias_k is None 366 | assert bias_v is None 367 | 368 | if reduce_qk: 369 | q = q.contiguous().view(tgt_len, bsz * num_heads, int(head_dim / 2)).transpose(0, 1) 370 | if k is not None: 371 | k = k.contiguous().view(-1, bsz * num_heads, int(head_dim / 2)).transpose(0, 1) 372 | if v is not None: 373 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 374 | else: 375 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 376 | if k is not None: 377 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 378 | if v is not None: 379 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 380 | 381 | if static_k is not None: 382 | assert static_k.size(0) == bsz * num_heads 383 | assert static_k.size(2) == head_dim 384 | k = static_k 385 | 386 | if static_v is not None: 387 | assert static_v.size(0) == bsz * num_heads 388 | assert static_v.size(2) == head_dim 389 | v = static_v 390 | 391 | src_len = k.size(1) 392 | 393 | if key_padding_mask is not None: 394 | assert key_padding_mask.size(0) == bsz 395 | assert key_padding_mask.size(1) == src_len 396 | 397 | if add_zero_attn: 398 | src_len += 1 399 | k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) 400 | v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) 401 | if attn_mask is not None: 402 | attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), 403 | dtype=attn_mask.dtype, 404 | device=attn_mask.device)], dim=1) 405 | if key_padding_mask is not None: 406 | key_padding_mask = torch.cat( 407 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 408 | dtype=key_padding_mask.dtype, 409 | device=key_padding_mask.device)], dim=1) 410 | 411 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 412 | assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] 413 | 414 | ######### ADDITION OF RPR ########### 415 | if(rpr_mat is not None): 416 | rpr_mat = _get_valid_embedding(rpr_mat, q.shape[1], k.shape[1]) 417 | qe = torch.einsum("hld,md->hlm", q, rpr_mat) 418 | srel = _skew(qe, skew_mask) 419 | 420 | attn_output_weights += srel 421 | 422 | if attn_mask is not None: 423 | attn_mask = attn_mask.unsqueeze(0) 424 | 
attn_output_weights += attn_mask 425 | 426 | if key_padding_mask is not None: 427 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 428 | attn_output_weights = attn_output_weights.masked_fill( 429 | key_padding_mask.unsqueeze(1).unsqueeze(2), 430 | float('-inf'), 431 | ) 432 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) 433 | 434 | attn_output_weights = softmax(attn_output_weights, dim=-1) 435 | 436 | attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) 437 | 438 | attn_output = torch.bmm(attn_output_weights, v) 439 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] 440 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 441 | attn_output = linear(attn_output, out_proj_weight, out_proj_bias) 442 | 443 | if need_weights: 444 | # average attention weights over heads 445 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 446 | return attn_output, attn_output_weights.sum(dim=1) / num_heads 447 | else: 448 | return attn_output, None 449 | 450 | def _get_valid_embedding(Er, len_q, len_k): 451 | """ 452 | ---------- 453 | Author: Damon Gwinn 454 | ---------- 455 | Gets valid embeddings based on max length of RPR attention 456 | ---------- 457 | """ 458 | 459 | len_e = Er.shape[0] 460 | start = max(0, len_e - len_q) 461 | return Er[start:, :] 462 | 463 | def _skew(qe, skew_mask=None): 464 | """ 465 | ---------- 466 | Author: Damon Gwinn 467 | ---------- 468 | Performs the skew optimized RPR computation (https://arxiv.org/abs/1809.04281) 469 | ---------- 470 | """ 471 | 472 | sz = qe.shape[1] 473 | # If mask is generated on the fly performance might degrade 474 | if skew_mask is None: 475 | mask = (torch.triu(torch.ones(sz, sz).to(qe.device)) == 1).float().flip(0) 476 | else: 477 | mask = skew_mask[..., skew_mask.shape[0] - sz:, :sz] 478 | 479 | qe = mask * qe 480 | qe = F.pad(qe, (1,0, 0,0, 0,0)) 481 | qe = torch.reshape(qe, (qe.shape[0], qe.shape[2], qe.shape[1])) 482 | 483 | srel = qe[:, 1:, :] 484 | return srel 485 | -------------------------------------------------------------------------------- /src/Music_Composer_Demo_Colab_en.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "cellView": "form", 8 | "id": "S_i1BkSBikEs" 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "#@title 0. 
Preparing the environment\n", 13 | "#@markdown Libraries are simply downloaded and imported here, nothing interesting\n", 14 | "!pip install gdown > /dev/null\n", 15 | "!apt install fluidsynth > /dev/null\n", 16 | "!pip install midi2audio > /dev/null\n", 17 | "!pip install pretty_midi > /dev/null\n", 18 | "!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2\n", 19 | "!git clone https://github.com/sberbank-ai/music-composer.git\n", 20 | "\n", 21 | "import os\n", 22 | "import sys\n", 23 | "import time\n", 24 | "import glob\n", 25 | "import torch\n", 26 | "import gdown\n", 27 | "import webbrowser\n", 28 | "import ipywidgets\n", 29 | "import pretty_midi\n", 30 | "import numpy as np\n", 31 | "\n", 32 | "from tqdm import tqdm\n", 33 | "from pathlib import Path\n", 34 | "from midi2audio import FluidSynth\n", 35 | "from google.colab import widgets, files\n", 36 | "from IPython.display import Audio, display, FileLink, HTML\n", 37 | "from ipywidgets import AppLayout, HBox, VBox, Label\n", 38 | "\n", 39 | "url = 'https://drive.google.com/u/0/uc?id=1lcNp0y4IZMIos0ASSsERG25WVDuEmLks'\n", 40 | "output = 'model_finetune_700k.pt'\n", 41 | "gdown.download(url, output, quiet=True)\n", 42 | "\n", 43 | "\n", 44 | "sys.path.append('/content/music-composer/src/')\n", 45 | "from lib import constants\n", 46 | "from lib import generation\n", 47 | "from lib import midi_processing\n", 48 | "from lib.midi_processing import PIANO_RANGE\n", 49 | "from lib.model.transformer import MusicTransformer\n", 50 | "from lib.colab_utils import id2genre, rugenre, genre2id, decode_and_write, convert_midi_to_wav, DownloadButton" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "cellView": "form", 58 | "id": "Yn4cWWoOJHfi" 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "#@title 1. Generation control panel\n", 63 | "#@markdown In this panel, you can customize the music generation parameters for yourself. Let's go through all the points:\n", 64 | "#@markdown * You can choose __primer__ - it's your MIDI file, which will be taken as the beginning of the song. If it is long, then it will be trimmed. Specify the desired length (in seconds) and how to trim it - from the beginning, or from the end (From start / From end). __From start__ means that we will leave the first N seconds, __From end__ - after N seconds.\n", 65 | "#@markdown * __Genre__ - allows you to select the genre in which the tracks will be generated. For each of the genres, we have selected our own set of parameters, which allows you to best generate it, but more on that later.\n", 66 | "#@markdown * __Seed__ - allows you to fix the seed with which the tracks are generated. Fixing it can be useful if you want to study the influence of a particular parameter on generation - set a numerical seed and generate several times changing the studied parameter. In other cases, it is recommended to leave this field blank. \n", 67 | "#@markdown * __Batch_size__ - the number of simultaneously generated tracks. If you overdo it, you may not have enough memory. Here you need to find a balance between the number of simultaneous tracks and their length. \n", 68 | "#@markdown * __Sequence length__ - track length. Due to the peculiarities of generation, we cannot say in advance how long the track will last, but increasing this parameter potentially increases the track length. 
It is with this parameter that the Batch_size must be balanced so that the video memory doesn't end.\n", 69 | "#@markdown * __Remove bad generations__ - sometimes the model starts generating garbage. If this flag is enabled, we will try to detect it, filter it out and generate new compositions instead. The generation time increases accordingly.\n", 70 | "#@markdown * __Temperature__ - temperature scaling. Affects the probability distribution itself, making it more or less equiprobable. This regulates the variety of generation.\n", 71 | "#@markdown * __TopK__ - restriction of dialing on the upper border. Sampling comes from a set of k most likely tokens.\n", 72 | "#@markdown * __At least K__ - limitation of dialing on the lower border. Sampling will be guaranteed to come from at least the most likely tokens at_least_k.\n", 73 | "#@markdown * __TopP__ - the topp parameter about which you can read in more detail in the article.\n", 74 | "#@markdown * __TopP Temperature__ - temperature scaling applied after selection by topp criterion.\n", 75 | "#@markdown * __Use Repetition Penalty__ - flag to use note repetition penalties. It is recommended to turn it on only if the model generates cyclic boring tracks.\n", 76 | "#@markdown * __RP Penalty__ - the amount of penalties for repeating notes. The higher, the less the model gets stuck in some musical phrases.\n", 77 | "#@markdown * __Restore speed__ - speed of recovery after fines of the repetition penalty module\n", 78 | "\n", 79 | "#@markdown There are also buttons by genre in the lower right corner. With the help of them, you can return to the generation parameters we have selected for each of the genres.\n", 80 | "\n", 81 | "\n", 82 | "def truncate_midi(midi, primer_len_sec=15.0, from_end=False):\n", 83 | " time0 = max([inst.notes[-1].end for inst in midi.instruments]) if from_end else 0\n", 84 | "\n", 85 | " for inst in midi.instruments:\n", 86 | " notes = sorted(inst.notes, key=lambda x: x.start, reverse=from_end)\n", 87 | " for i,note in enumerate(notes):\n", 88 | " if np.abs(note.start - time0) > primer_len_sec:\n", 89 | " break\n", 90 | " if not from_end and note.end > primer_len_sec:\n", 91 | " note.end = time0 - (int(from_end)*2-1) * primer_len_sec\n", 92 | " inst.notes = notes[:i]\n", 93 | " if from_end:\n", 94 | " inst.notes = inst.notes[::-1]\n", 95 | "\n", 96 | "\n", 97 | "style = {'description_width': 'initial'}\n", 98 | "genre_to_generate = ipywidgets.Dropdown(\n", 99 | " options=['calm', 'jazz', 'pop', 'classic'],\n", 100 | " value='calm',\n", 101 | " description='Genre:',\n", 102 | " disabled=False,\n", 103 | ")\n", 104 | "genre_to_generate.default_value = 'calm'\n", 105 | "\n", 106 | "seed = ipywidgets.Text(\n", 107 | " value='',\n", 108 | " placeholder='leave blank for random seed',\n", 109 | " description='Seed:',\n", 110 | " disabled=False\n", 111 | ")\n", 112 | "seed.default_value = ''\n", 113 | "\n", 114 | "temp = ipywidgets.FloatSlider( \n", 115 | " value=1.0,\n", 116 | " min=0.01,\n", 117 | " max=5.0,\n", 118 | " step=0.01,\n", 119 | " description='Temperature:',\n", 120 | " disabled=False,\n", 121 | " continuous_update=False,\n", 122 | " orientation='horizontal',\n", 123 | " readout=True,\n", 124 | " readout_format='.2f',\n", 125 | " style=style\n", 126 | ")\n", 127 | "temp.default_value = 1.0\n", 128 | "\n", 129 | "b_size = ipywidgets.IntSlider(\n", 130 | " value=8,\n", 131 | " min=1,\n", 132 | " max=16,\n", 133 | " step=1,\n", 134 | " description='Batch size:',\n", 135 | " disabled=False,\n", 136 | " 
continuous_update=False,\n", 137 | " orientation='horizontal',\n", 138 | " readout=True,\n", 139 | " readout_format='d',\n", 140 | " style=style\n", 141 | ")\n", 142 | "b_size.default_value = 8\n", 143 | "\n", 144 | "seq_length = ipywidgets.IntSlider(\n", 145 | " value=512,\n", 146 | " min=256,\n", 147 | " max=2048,\n", 148 | " step=256,\n", 149 | " description='Sequence length:',\n", 150 | " disabled=False,\n", 151 | " continuous_update=False,\n", 152 | " orientation='horizontal',\n", 153 | " readout=True,\n", 154 | " readout_format='d',\n", 155 | " style=style\n", 156 | ")\n", 157 | "seq_length.default_value = 512\n", 158 | "\n", 159 | "topk = ipywidgets.IntSlider(\n", 160 | " value=60,\n", 161 | " min=1,\n", 162 | " max=300,\n", 163 | " step=1,\n", 164 | " description='Top k:',\n", 165 | " disabled=False,\n", 166 | " continuous_update=False,\n", 167 | " orientation='horizontal',\n", 168 | " readout=True,\n", 169 | " readout_format='d',\n", 170 | " style=style\n", 171 | ")\n", 172 | "topk.default_value = 60\n", 173 | "\n", 174 | "at_least_k = ipywidgets.IntSlider(\n", 175 | " value=1,\n", 176 | " min=1,\n", 177 | " max=300,\n", 178 | " step=1,\n", 179 | " description='At least k:',\n", 180 | " disabled=False,\n", 181 | " continuous_update=False,\n", 182 | " orientation='horizontal',\n", 183 | " readout=True,\n", 184 | " readout_format='d',\n", 185 | " style=style\n", 186 | ")\n", 187 | "at_least_k.default_value = 1\n", 188 | "\n", 189 | "topp = ipywidgets.FloatSlider( \n", 190 | " value=0.99,\n", 191 | " min=0.5,\n", 192 | " max=1.0,\n", 193 | " step=0.01,\n", 194 | " description='Topp:',\n", 195 | " disabled=False,\n", 196 | " continuous_update=False,\n", 197 | " orientation='horizontal',\n", 198 | " readout=True,\n", 199 | " readout_format='.2f',\n", 200 | " style=style\n", 201 | ")\n", 202 | "topp.default_value = 0.99\n", 203 | "\n", 204 | "topp_temperature = ipywidgets.FloatSlider( \n", 205 | " value=1.0,\n", 206 | " min=0.01,\n", 207 | " max=5.0,\n", 208 | " step=0.01,\n", 209 | " description='Topp Temperature:',\n", 210 | " disabled=False,\n", 211 | " continuous_update=False,\n", 212 | " orientation='horizontal',\n", 213 | " readout=True,\n", 214 | " readout_format='.2f',\n", 215 | " style=style\n", 216 | ")\n", 217 | "topp_temperature.default_value = 1.0\n", 218 | "\n", 219 | "use_rp = ipywidgets.Checkbox(\n", 220 | " value=False,\n", 221 | " description='Use Repetition Penalty',\n", 222 | " disabled=False,\n", 223 | " style=style\n", 224 | ")\n", 225 | "use_rp.default_value = False\n", 226 | "\n", 227 | "rp_penalty = ipywidgets.FloatSlider( \n", 228 | " value=0.05,\n", 229 | " min=0.,\n", 230 | " max=1.0,\n", 231 | " step=0.05,\n", 232 | " description='RP penalty:',\n", 233 | " disabled=False,\n", 234 | " continuous_update=False,\n", 235 | " orientation='horizontal',\n", 236 | " readout=True,\n", 237 | " readout_format='.2f',\n", 238 | " style=style\n", 239 | ")\n", 240 | "rp_penalty.default_value = 0.05\n", 241 | "\n", 242 | "restore_speed = ipywidgets.FloatSlider( \n", 243 | " value=0.7,\n", 244 | " min=0.,\n", 245 | " max=1.0,\n", 246 | " step=0.05,\n", 247 | " description='Restore speed:',\n", 248 | " disabled=False,\n", 249 | " continuous_update=False,\n", 250 | " orientation='horizontal',\n", 251 | " readout=True,\n", 252 | " readout_format='.2f',\n", 253 | " style=style\n", 254 | ")\n", 255 | "restore_speed.default_value = 0.7\n", 256 | "\n", 257 | "remove_bad_generations = ipywidgets.Checkbox(\n", 258 | " value=True,\n", 259 | " description='Remove bad generations',\n", 
260 | " disabled=False,\n", 261 | " style=style\n", 262 | ")\n", 263 | "remove_bad_generations.default_value = True\n", 264 | "\n", 265 | "defaulting_widgets = (genre_to_generate, seed, b_size, seq_length, remove_bad_generations,\n", 266 | " temp, topk, at_least_k, topp, topp_temperature, use_rp, rp_penalty, restore_speed)\n", 267 | "def set_classic(button):\n", 268 | " for idx, widget in enumerate(defaulting_widgets):\n", 269 | " if idx == 0:\n", 270 | " widget.value = 'classic'\n", 271 | " elif idx == 5:\n", 272 | " widget.value = 1.0\n", 273 | " elif idx == 7:\n", 274 | " widget.value = 1\n", 275 | " elif idx == 8:\n", 276 | " widget.value = 0.99\n", 277 | " else:\n", 278 | " widget.value = widget.default_value\n", 279 | "\n", 280 | "def set_calm(button):\n", 281 | " for idx, widget in enumerate(defaulting_widgets):\n", 282 | " if idx == 0:\n", 283 | " widget.value = 'calm'\n", 284 | " elif idx == 5:\n", 285 | " widget.value = 1.03\n", 286 | " elif idx == 7:\n", 287 | " widget.value = 4\n", 288 | " elif idx == 8:\n", 289 | " widget.value = 0.98\n", 290 | " else:\n", 291 | " widget.value = widget.default_value\n", 292 | "\n", 293 | "def set_jazz(button):\n", 294 | " for idx, widget in enumerate(defaulting_widgets):\n", 295 | " if idx == 0:\n", 296 | " widget.value = 'jazz'\n", 297 | " elif idx == 5:\n", 298 | " widget.value = 0.99\n", 299 | " elif idx == 7:\n", 300 | " widget.value = 1\n", 301 | " elif idx == 8:\n", 302 | " widget.value = 0.99\n", 303 | " else:\n", 304 | " widget.value = widget.default_value\n", 305 | "\n", 306 | "def set_pop(button):\n", 307 | " for idx, widget in enumerate(defaulting_widgets):\n", 308 | " if idx == 0:\n", 309 | " widget.value = 'pop'\n", 310 | " elif idx == 5:\n", 311 | " widget.value = 0.98\n", 312 | " elif idx == 7:\n", 313 | " widget.value = 4\n", 314 | " elif idx == 8:\n", 315 | " widget.value = 0.99\n", 316 | " else:\n", 317 | " widget.value = widget.default_value\n", 318 | "\n", 319 | "calm_value_button = ipywidgets.Button(description='Calm')\n", 320 | "calm_value_button.on_click(set_calm)\n", 321 | "jazz_value_button = ipywidgets.Button(description='Jazz')\n", 322 | "jazz_value_button.on_click(set_jazz)\n", 323 | "pop_value_button = ipywidgets.Button(description='Pop')\n", 324 | "pop_value_button.on_click(set_pop)\n", 325 | "classic_value_button = ipywidgets.Button(description='Classic')\n", 326 | "classic_value_button.on_click(set_classic)\n", 327 | "buttons = (calm_value_button, jazz_value_button, pop_value_button, classic_value_button)\n", 328 | "left_box = ipywidgets.VBox((buttons[0], buttons[1]))\n", 329 | "right_box = ipywidgets.VBox((buttons[2], buttons[3]))\n", 330 | "\n", 331 | "def update_dropdown():\n", 332 | " select_primer_widget.options = ['None'] + sorted(map(str, Path('primers').glob('*.*')))\n", 333 | "\n", 334 | "def on_upload(x):\n", 335 | " x = x['new']\n", 336 | " names = x.keys()\n", 337 | " os.makedirs('primers', exist_ok=True)\n", 338 | " for name in names:\n", 339 | " content = x[name]['content']\n", 340 | " if content:\n", 341 | " path = 'primers/'+name\n", 342 | " with open(path, \"wb\") as fp:\n", 343 | " fp.write(content)\n", 344 | " try:\n", 345 | " pretty_midi.PrettyMIDI(path)\n", 346 | " except:\n", 347 | " print(f'file \"{name}\" is corrupted or not a MIDI!')\n", 348 | " os.remove(path)\n", 349 | " update_dropdown()\n", 350 | "\n", 351 | "select_primer_widget = ipywidgets.Dropdown()\n", 352 | "update_dropdown()\n", 353 | "\n", 354 | "uploader = ipywidgets.FileUpload(description='Upload MIDI', multiple=True)\n", 
355 | "uploader.observe(on_upload, names='value')\n", 356 | "\n", 357 | "primer_len = ipywidgets.FloatSlider(\n", 358 | " value=15,\n", 359 | " min=0,\n", 360 | " max=60,\n", 361 | " step=0.1,\n", 362 | " description='Seconds:',\n", 363 | " continuous_update=False,\n", 364 | " style=style\n", 365 | ")\n", 366 | "\n", 367 | "primer_position = ipywidgets.RadioButtons(options=[['From start',0],['From end',1]], orientation='horizontal')\n", 368 | "\n", 369 | "AppLayout(header=ipywidgets.HBox([Label('Select primer:'), select_primer_widget, uploader, primer_len, primer_position]),\n", 370 | " left_sidebar=VBox([genre_to_generate, seed, b_size, seq_length, remove_bad_generations]),\n", 371 | " center=VBox([temp, topk, at_least_k, topp, topp_temperature]),\n", 372 | " right_sidebar=VBox([use_rp, rp_penalty, restore_speed, ipywidgets.HBox((left_box, right_box))]),\n", 373 | " footer=None,\n", 374 | " pane_widths=[1, 1, 1],\n", 375 | " pane_heights=[1, 5, '60px'])" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": { 382 | "cellView": "form", 383 | "id": "XPdirMXYoign" 384 | }, 385 | "outputs": [], 386 | "source": [ 387 | "#@title 2. Start generating\n", 388 | "load_path = '/content/model_finetune_700k.pt'\n", 389 | "out_dir = 'generated_' + time.strftime('%d-%m-%Y_%H-%M-%S') \n", 390 | "device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'\n", 391 | "\n", 392 | "if device == 'cpu':\n", 393 | " print('Generating on CPU. Expect lower generation speed.')\n", 394 | "\n", 395 | "os.makedirs(out_dir, exist_ok=True)\n", 396 | "genre_id = genre2id[genre_to_generate.value]\n", 397 | "\n", 398 | "params = dict(\n", 399 | " target_seq_length = seq_length.value,\n", 400 | " temperature = temp.value,\n", 401 | " topk = topk.value,\n", 402 | " topp = topp.value,\n", 403 | " topp_temperature = topp_temperature.value,\n", 404 | " at_least_k = at_least_k.value,\n", 405 | " use_rp = use_rp.value,\n", 406 | " rp_penalty = rp_penalty.value,\n", 407 | " rp_restore_speed = restore_speed.value,\n", 408 | " seed = int(seed.value) if seed.value else None\n", 409 | ")\n", 410 | "max_primer_tokens = 512\n", 411 | "\n", 412 | "# Init model\n", 413 | "print('loading model...')\n", 414 | "model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval()\n", 415 | "model.load_state_dict(torch.load(load_path, map_location=device))\n", 416 | "\n", 417 | "# Add genre and primer\n", 418 | "primer_genre = np.repeat([genre_id], b_size.value)[:,None] + constants.VOCAB_SIZE - 4\n", 419 | "if primer_len.value > 0 and select_primer_widget.value != 'None':\n", 420 | " file = select_primer_widget.value\n", 421 | " midi = pretty_midi.PrettyMIDI(file)\n", 422 | " from_end = primer_position.value == 1\n", 423 | " truncate_midi(midi, primer_len.value, from_end) # truncate to specified length (in seconds)\n", 424 | " encoded = midi_processing.encode(midi)\n", 425 | " # truncate to max_primer_tokens (in tokens)\n", 426 | " l = len(encoded)\n", 427 | " if l > max_primer_tokens:\n", 428 | " import warnings\n", 429 | " warnings.warn('Primer MIDI is too long (length > 512), it will be truncated to 512!')\n", 430 | " if from_end:\n", 431 | " encoded = encoded[l-max_primer_tokens:]\n", 432 | " else:\n", 433 | " encoded = encoded[:max_primer_tokens]\n", 434 | " primer_seq = np.repeat(np.array(encoded)[None], b_size.value, 0)\n", 435 | " primer = np.concatenate([primer_genre, primer_seq], -1)\n", 
436 | "else:\n", 437 | " primer = primer_genre\n", 438 | "primer = torch.tensor(primer, dtype=torch.int64)\n", 439 | "\n", 440 | "# Generation\n", 441 | "if primer.shape[-1] >= seq_length.value:\n", 442 | " print('Nothing to generate. Try to set larger \"Sequence length\" parameter!')\n", 443 | "else:\n", 444 | " while len(glob.glob(out_dir + '/*.mid')) != b_size.value:\n", 445 | " print('generating to:', os.path.abspath(out_dir))\n", 446 | " generated = generation.generate(model, primer, **params)\n", 447 | " generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations.value)\n", 448 | " decode_and_write(generated, primer, primer_genre.squeeze(-1)-390, out_dir)\n", 449 | " files_to_delete = len(glob.glob(out_dir + '/*.mid')) - b_size.value\n", 450 | " if files_to_delete > 0:\n", 451 | " for idx in range(files_to_delete):\n", 452 | " os.remove(sorted(glob.glob(out_dir + '/*.mid'), key=lambda x: int(x.split('_')[-2]))[-idx])\n", 453 | " for midi_name in glob.glob(out_dir + '/*.mid'):\n", 454 | " convert_midi_to_wav(midi_name)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "metadata": { 461 | "cellView": "form", 462 | "id": "IHRrMDeyk2GI" 463 | }, 464 | "outputs": [], 465 | "source": [ 466 | "#@title 3. Listen and download the generation results\n", 467 | "wav_files = glob.glob(out_dir + '/*.wav')\n", 468 | "if len(wav_files) > 3:\n", 469 | " rows = round(len(wav_files) / 3) + 1\n", 470 | " columns = 3\n", 471 | "else:\n", 472 | " rows = 1\n", 473 | " columns = len(wav_files)\n", 474 | "grid = widgets.Grid(rows, columns)\n", 475 | "current_position = 0\n", 476 | "for row in range(rows):\n", 477 | " if current_position > len(wav_files) - 1:\n", 478 | " break\n", 479 | " for column in range(columns):\n", 480 | " if current_position > len(wav_files) - 1:\n", 481 | " break\n", 482 | " with grid.output_to(row, column):\n", 483 | " genre = rugenre[wav_files[current_position].split('.wav')[0].split('_')[-1]]\n", 484 | " print(f'Номер трека: {current_position + 1}')\n", 485 | " print(f'Жанр: {genre}')\n", 486 | " display(Audio(wav_files[current_position]))\n", 487 | " display(DownloadButton(filename=wav_files[current_position], description='Скачать .wav'))\n", 488 | " display(DownloadButton(filename=wav_files[current_position].replace('.wav', '.mid'), description='Скачать .midi'))\n", 489 | " current_position += 1" 490 | ] 491 | } 492 | ], 493 | "metadata": { 494 | "accelerator": "GPU", 495 | "colab": { 496 | "collapsed_sections": [], 497 | "name": "Music_Composer_Demo_Colab.ipynb", 498 | "private_outputs": true, 499 | "provenance": [] 500 | }, 501 | "kernelspec": { 502 | "display_name": "Python 3", 503 | "language": "python", 504 | "name": "python3" 505 | }, 506 | "language_info": { 507 | "codemirror_mode": { 508 | "name": "ipython", 509 | "version": 3 510 | }, 511 | "file_extension": ".py", 512 | "mimetype": "text/x-python", 513 | "name": "python", 514 | "nbconvert_exporter": "python", 515 | "pygments_lexer": "ipython3", 516 | "version": "3.8.5" 517 | } 518 | }, 519 | "nbformat": 4, 520 | "nbformat_minor": 1 521 | } 522 | -------------------------------------------------------------------------------- /src/Music_Composer_Demo_Colab_ru.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Music_Composer_Demo_Colab.ipynb", 7 | "private_outputs": true, 8 | "provenance": [], 9 | "collapsed_sections": 
[] 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "language": "python", 14 | "name": "python3" 15 | }, 16 | "language_info": { 17 | "codemirror_mode": { 18 | "name": "ipython", 19 | "version": 3 20 | }, 21 | "file_extension": ".py", 22 | "mimetype": "text/x-python", 23 | "name": "python", 24 | "nbconvert_exporter": "python", 25 | "pygments_lexer": "ipython3", 26 | "version": "3.8.5" 27 | }, 28 | "accelerator": "GPU" 29 | }, 30 | "cells": [ 31 | { 32 | "cell_type": "code", 33 | "metadata": { 34 | "id": "S_i1BkSBikEs", 35 | "cellView": "form" 36 | }, 37 | "source": [ 38 | "#@title 0. Подготовка окружения к работе\n", 39 | "#@markdown Тут просто скачиваются и импортируются библиотеки, ничего интересного\n", 40 | "!pip install gdown > /dev/null\n", 41 | "!apt install fluidsynth > /dev/null\n", 42 | "!pip install midi2audio > /dev/null\n", 43 | "!pip install pretty_midi > /dev/null\n", 44 | "!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2\n", 45 | "!git clone https://github.com/sberbank-ai/music-composer.git\n", 46 | "\n", 47 | "import os\n", 48 | "import sys\n", 49 | "import time\n", 50 | "import glob\n", 51 | "import torch\n", 52 | "import gdown\n", 53 | "import webbrowser\n", 54 | "import ipywidgets\n", 55 | "import pretty_midi\n", 56 | "import numpy as np\n", 57 | "\n", 58 | "from tqdm import tqdm\n", 59 | "from pathlib import Path\n", 60 | "from midi2audio import FluidSynth\n", 61 | "from google.colab import widgets, files\n", 62 | "from IPython.display import Audio, display, FileLink, HTML\n", 63 | "from ipywidgets import AppLayout, HBox, VBox, Label\n", 64 | "\n", 65 | "url = 'https://drive.google.com/u/0/uc?id=1lcNp0y4IZMIos0ASSsERG25WVDuEmLks'\n", 66 | "output = 'model_finetune_700k.pt'\n", 67 | "gdown.download(url, output, quiet=True)\n", 68 | "\n", 69 | "\n", 70 | "sys.path.append('/content/music-composer/src/')\n", 71 | "from lib import constants\n", 72 | "from lib import generation\n", 73 | "from lib import midi_processing\n", 74 | "from lib.midi_processing import PIANO_RANGE\n", 75 | "from lib.model.transformer import MusicTransformer\n", 76 | "from lib.colab_utils import id2genre, rugenre, genre2id, decode_and_write, convert_midi_to_wav, DownloadButton" 77 | ], 78 | "execution_count": null, 79 | "outputs": [] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "metadata": { 84 | "id": "Yn4cWWoOJHfi", 85 | "cellView": "form" 86 | }, 87 | "source": [ 88 | "#@title 1. Панель управления генерацией\n", 89 | "#@markdown В данной панели можно настроить параметры генерации музыки под себя. Пройдемся по всем пунктам:\n", 90 | "#@markdown * Вы можете выбрать __primer__ - это ваш MIDI файл, который будет взят за начало композиции. Если он длинный, то произойдет его обрезка. Укажите желаемую длину (в секундах) и способ его обрезки - с начала, или с конца (From start / From end). __From start__ означает, что мы оставим первые N сек, __From end__ - посление N сек.\n", 91 | "#@markdown * __Genre__ - позволяет выбрать жанр в котором будут сгенерированы треки. Для каждого из жанров мы подобрали свой набор параметров, который позволяет лучше всего его генерировать, но об этом позднее. \n", 92 | "#@markdown * __Seed__ - позволяет зафиксировать сид с которым генерируются треки. Его фиксация может быть полезна если вы хотите изучить влияние того или иного параметра на генерацию - задайте численный сид и генерируйте несколько раз изменяя изучаемый параметр. В остальных случаях рекомендуется не заполнять данное поле. 
\n", 93 | "#@markdown * __Batch_size__ - количество одновременно генерируемых треков. Если переборщить - может не хватить памяти. Тут нужно подбирать баланс между кол-вом одновременных треков и их длиной. \n", 94 | "#@markdown * __Sequence length__ - длина трека. В связи с особенностями генерации мы не можем заранее сказать сколько по времени будет длиться трек, но увеличение этого параметра потенциально увеличивает длину трека. Именно с этим параметром надо балансировать Batch_size чтобы не закончилась видеопамять\n", 95 | "#@markdown * __Remove bad generations__ - иногда модель начинает генерировать мусор. Если данный флаг включен - мы постараемся его обнаружить, отсеять и сгенерировать вместо него новые композиции. Соответственно увеличивает время генерации.\n", 96 | "#@markdown * __Temperature__ - температурный скейлинг. Влияет на само распределение вероятностей, делая его более либо менее равновероятным. Тем самым регулируется разнообразие генерации.\n", 97 | "#@markdown * __TopK__ - ограничение набора по верхней границе. Сэмплирование происходит из набора k наиболее вероятных токенов.\n", 98 | "#@markdown * __At least K__ - ограничение набора по нижней границе. Сэмплирование будет гарантированно происходить как минимум из at_least_k наиболее вероятных токенов.\n", 99 | "#@markdown * __TopP__ - параметр topp о котором можно подробнее прочитать в статье.\n", 100 | "#@markdown * __TopP Temperature__ - температурный скейлинг применяющийся после отбора по критерию topp.\n", 101 | "#@markdown * __Use Repetition Penalty__ - флаг для использования штрафов за повторы нот. Рекомендуется включать только если модель генерирует цикличные скучные треки.\n", 102 | "#@markdown * __RP Penalty__ - размер штрафов за повтор нот. Чем выше тем меньше модель зацикливается в одних музыкальных фразах.\n", 103 | "#@markdown * __Restore speed__ - скорость восстановления после штрафов модуля repetition penalty \n", 104 | "\n", 105 | "#@markdown Также в правом нижнем углу доступны кнопки по жанрам. 
С помощью них можно вернуться к подобранным нами параметрам генерации под каждый из жанров.\n", 106 | "\n", 107 | "\n", 108 | "def truncate_midi(midi, primer_len_sec=15.0, from_end=False):\n", 109 | " time0 = max([inst.notes[-1].end for inst in midi.instruments]) if from_end else 0\n", 110 | "\n", 111 | " for inst in midi.instruments:\n", 112 | " notes = sorted(inst.notes, key=lambda x: x.start, reverse=from_end)\n", 113 | " for i,note in enumerate(notes):\n", 114 | " if np.abs(note.start - time0) > primer_len_sec:\n", 115 | " break\n", 116 | " if not from_end and note.end > primer_len_sec:\n", 117 | " note.end = time0 - (int(from_end)*2-1) * primer_len_sec\n", 118 | " inst.notes = notes[:i]\n", 119 | " if from_end:\n", 120 | " inst.notes = inst.notes[::-1]\n", 121 | "\n", 122 | "\n", 123 | "style = {'description_width': 'initial'}\n", 124 | "genre_to_generate = ipywidgets.Dropdown(\n", 125 | " options=['calm', 'jazz', 'pop', 'classic'],\n", 126 | " value='calm',\n", 127 | " description='Genre:',\n", 128 | " disabled=False,\n", 129 | ")\n", 130 | "genre_to_generate.default_value = 'calm'\n", 131 | "\n", 132 | "seed = ipywidgets.Text(\n", 133 | " value='',\n", 134 | " placeholder='leave blank for random seed',\n", 135 | " description='Seed:',\n", 136 | " disabled=False\n", 137 | ")\n", 138 | "seed.default_value = ''\n", 139 | "\n", 140 | "temp = ipywidgets.FloatSlider( \n", 141 | " value=1.0,\n", 142 | " min=0.01,\n", 143 | " max=5.0,\n", 144 | " step=0.01,\n", 145 | " description='Temperature:',\n", 146 | " disabled=False,\n", 147 | " continuous_update=False,\n", 148 | " orientation='horizontal',\n", 149 | " readout=True,\n", 150 | " readout_format='.2f',\n", 151 | " style=style\n", 152 | ")\n", 153 | "temp.default_value = 1.0\n", 154 | "\n", 155 | "b_size = ipywidgets.IntSlider(\n", 156 | " value=8,\n", 157 | " min=1,\n", 158 | " max=16,\n", 159 | " step=1,\n", 160 | " description='Batch size:',\n", 161 | " disabled=False,\n", 162 | " continuous_update=False,\n", 163 | " orientation='horizontal',\n", 164 | " readout=True,\n", 165 | " readout_format='d',\n", 166 | " style=style\n", 167 | ")\n", 168 | "b_size.default_value = 8\n", 169 | "\n", 170 | "seq_length = ipywidgets.IntSlider(\n", 171 | " value=512,\n", 172 | " min=256,\n", 173 | " max=2048,\n", 174 | " step=256,\n", 175 | " description='Sequence length:',\n", 176 | " disabled=False,\n", 177 | " continuous_update=False,\n", 178 | " orientation='horizontal',\n", 179 | " readout=True,\n", 180 | " readout_format='d',\n", 181 | " style=style\n", 182 | ")\n", 183 | "seq_length.default_value = 512\n", 184 | "\n", 185 | "topk = ipywidgets.IntSlider(\n", 186 | " value=60,\n", 187 | " min=1,\n", 188 | " max=300,\n", 189 | " step=1,\n", 190 | " description='Top k:',\n", 191 | " disabled=False,\n", 192 | " continuous_update=False,\n", 193 | " orientation='horizontal',\n", 194 | " readout=True,\n", 195 | " readout_format='d',\n", 196 | " style=style\n", 197 | ")\n", 198 | "topk.default_value = 60\n", 199 | "\n", 200 | "at_least_k = ipywidgets.IntSlider(\n", 201 | " value=1,\n", 202 | " min=1,\n", 203 | " max=300,\n", 204 | " step=1,\n", 205 | " description='At least k:',\n", 206 | " disabled=False,\n", 207 | " continuous_update=False,\n", 208 | " orientation='horizontal',\n", 209 | " readout=True,\n", 210 | " readout_format='d',\n", 211 | " style=style\n", 212 | ")\n", 213 | "at_least_k.default_value = 1\n", 214 | "\n", 215 | "topp = ipywidgets.FloatSlider( \n", 216 | " value=0.99,\n", 217 | " min=0.5,\n", 218 | " max=1.0,\n", 219 | " 
step=0.01,\n", 220 | " description='Topp:',\n", 221 | " disabled=False,\n", 222 | " continuous_update=False,\n", 223 | " orientation='horizontal',\n", 224 | " readout=True,\n", 225 | " readout_format='.2f',\n", 226 | " style=style\n", 227 | ")\n", 228 | "topp.default_value = 0.99\n", 229 | "\n", 230 | "topp_temperature = ipywidgets.FloatSlider( \n", 231 | " value=1.0,\n", 232 | " min=0.01,\n", 233 | " max=5.0,\n", 234 | " step=0.01,\n", 235 | " description='Topp Temperature:',\n", 236 | " disabled=False,\n", 237 | " continuous_update=False,\n", 238 | " orientation='horizontal',\n", 239 | " readout=True,\n", 240 | " readout_format='.2f',\n", 241 | " style=style\n", 242 | ")\n", 243 | "topp_temperature.default_value = 1.0\n", 244 | "\n", 245 | "use_rp = ipywidgets.Checkbox(\n", 246 | " value=False,\n", 247 | " description='Use Repetition Penalty',\n", 248 | " disabled=False,\n", 249 | " style=style\n", 250 | ")\n", 251 | "use_rp.default_value = False\n", 252 | "\n", 253 | "rp_penalty = ipywidgets.FloatSlider( \n", 254 | " value=0.05,\n", 255 | " min=0.,\n", 256 | " max=1.0,\n", 257 | " step=0.05,\n", 258 | " description='RP penalty:',\n", 259 | " disabled=False,\n", 260 | " continuous_update=False,\n", 261 | " orientation='horizontal',\n", 262 | " readout=True,\n", 263 | " readout_format='.2f',\n", 264 | " style=style\n", 265 | ")\n", 266 | "rp_penalty.default_value = 0.05\n", 267 | "\n", 268 | "restore_speed = ipywidgets.FloatSlider( \n", 269 | " value=0.7,\n", 270 | " min=0.,\n", 271 | " max=1.0,\n", 272 | " step=0.05,\n", 273 | " description='Restore speed:',\n", 274 | " disabled=False,\n", 275 | " continuous_update=False,\n", 276 | " orientation='horizontal',\n", 277 | " readout=True,\n", 278 | " readout_format='.2f',\n", 279 | " style=style\n", 280 | ")\n", 281 | "restore_speed.default_value = 0.7\n", 282 | "\n", 283 | "remove_bad_generations = ipywidgets.Checkbox(\n", 284 | " value=True,\n", 285 | " description='Remove bad generations',\n", 286 | " disabled=False,\n", 287 | " style=style\n", 288 | ")\n", 289 | "remove_bad_generations.default_value = True\n", 290 | "\n", 291 | "defaulting_widgets = (genre_to_generate, seed, b_size, seq_length, remove_bad_generations,\n", 292 | " temp, topk, at_least_k, topp, topp_temperature, use_rp, rp_penalty, restore_speed)\n", 293 | "def set_classic(button):\n", 294 | " for idx, widget in enumerate(defaulting_widgets):\n", 295 | " if idx == 0:\n", 296 | " widget.value = 'classic'\n", 297 | " elif idx == 5:\n", 298 | " widget.value = 1.0\n", 299 | " elif idx == 7:\n", 300 | " widget.value = 1\n", 301 | " elif idx == 8:\n", 302 | " widget.value = 0.99\n", 303 | " else:\n", 304 | " widget.value = widget.default_value\n", 305 | "\n", 306 | "def set_calm(button):\n", 307 | " for idx, widget in enumerate(defaulting_widgets):\n", 308 | " if idx == 0:\n", 309 | " widget.value = 'calm'\n", 310 | " elif idx == 5:\n", 311 | " widget.value = 1.03\n", 312 | " elif idx == 7:\n", 313 | " widget.value = 4\n", 314 | " elif idx == 8:\n", 315 | " widget.value = 0.98\n", 316 | " else:\n", 317 | " widget.value = widget.default_value\n", 318 | "\n", 319 | "def set_jazz(button):\n", 320 | " for idx, widget in enumerate(defaulting_widgets):\n", 321 | " if idx == 0:\n", 322 | " widget.value = 'jazz'\n", 323 | " elif idx == 5:\n", 324 | " widget.value = 0.99\n", 325 | " elif idx == 7:\n", 326 | " widget.value = 1\n", 327 | " elif idx == 8:\n", 328 | " widget.value = 0.99\n", 329 | " else:\n", 330 | " widget.value = widget.default_value\n", 331 | "\n", 332 | "def 
set_pop(button):\n", 333 | " for idx, widget in enumerate(defaulting_widgets):\n", 334 | " if idx == 0:\n", 335 | " widget.value = 'pop'\n", 336 | " elif idx == 5:\n", 337 | " widget.value = 0.98\n", 338 | " elif idx == 7:\n", 339 | " widget.value = 4\n", 340 | " elif idx == 8:\n", 341 | " widget.value = 0.99\n", 342 | " else:\n", 343 | " widget.value = widget.default_value\n", 344 | "\n", 345 | "calm_value_button = ipywidgets.Button(description='Calm')\n", 346 | "calm_value_button.on_click(set_calm)\n", 347 | "jazz_value_button = ipywidgets.Button(description='Jazz')\n", 348 | "jazz_value_button.on_click(set_jazz)\n", 349 | "pop_value_button = ipywidgets.Button(description='Pop')\n", 350 | "pop_value_button.on_click(set_pop)\n", 351 | "classic_value_button = ipywidgets.Button(description='Classic')\n", 352 | "classic_value_button.on_click(set_classic)\n", 353 | "buttons = (calm_value_button, jazz_value_button, pop_value_button, classic_value_button)\n", 354 | "left_box = ipywidgets.VBox((buttons[0], buttons[1]))\n", 355 | "right_box = ipywidgets.VBox((buttons[2], buttons[3]))\n", 356 | "\n", 357 | "def update_dropdown():\n", 358 | " select_primer_widget.options = ['None'] + sorted(map(str, Path('primers').glob('*.*')))\n", 359 | "\n", 360 | "def on_upload(x):\n", 361 | " x = x['new']\n", 362 | " names = x.keys()\n", 363 | " os.makedirs('primers', exist_ok=True)\n", 364 | " for name in names:\n", 365 | " content = x[name]['content']\n", 366 | " if content:\n", 367 | " path = 'primers/'+name\n", 368 | " with open(path, \"wb\") as fp:\n", 369 | " fp.write(content)\n", 370 | " try:\n", 371 | " pretty_midi.PrettyMIDI(path)\n", 372 | " except:\n", 373 | " print(f'file \"{name}\" is corrupted or not a MIDI!')\n", 374 | " os.remove(path)\n", 375 | " update_dropdown()\n", 376 | "\n", 377 | "select_primer_widget = ipywidgets.Dropdown()\n", 378 | "update_dropdown()\n", 379 | "\n", 380 | "uploader = ipywidgets.FileUpload(description='Upload MIDI', multiple=True)\n", 381 | "uploader.observe(on_upload, names='value')\n", 382 | "\n", 383 | "primer_len = ipywidgets.FloatSlider(\n", 384 | " value=15,\n", 385 | " min=0,\n", 386 | " max=60,\n", 387 | " step=0.1,\n", 388 | " description='Seconds:',\n", 389 | " continuous_update=False,\n", 390 | " style=style\n", 391 | ")\n", 392 | "\n", 393 | "primer_position = ipywidgets.RadioButtons(options=[['From start',0],['From end',1]], orientation='horizontal')\n", 394 | "\n", 395 | "AppLayout(header=ipywidgets.HBox([Label('Select primer:'), select_primer_widget, uploader, primer_len, primer_position]),\n", 396 | " left_sidebar=VBox([genre_to_generate, seed, b_size, seq_length, remove_bad_generations]),\n", 397 | " center=VBox([temp, topk, at_least_k, topp, topp_temperature]),\n", 398 | " right_sidebar=VBox([use_rp, rp_penalty, restore_speed, ipywidgets.HBox((left_box, right_box))]),\n", 399 | " footer=None,\n", 400 | " pane_widths=[1, 1, 1],\n", 401 | " pane_heights=[1, 5, '60px'])" 402 | ], 403 | "execution_count": null, 404 | "outputs": [] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "metadata": { 409 | "id": "XPdirMXYoign", 410 | "cellView": "form" 411 | }, 412 | "source": [ 413 | "#@title 2. Запуск процесса генерации\n", 414 | "load_path = '/content/model_finetune_700k.pt'\n", 415 | "out_dir = 'generated_' + time.strftime('%d-%m-%Y_%H-%M-%S') \n", 416 | "device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'\n", 417 | "\n", 418 | "if device == 'cpu':\n", 419 | " print('Generating on CPU. 
Expect lower generation speed.')\n", 420 | "\n", 421 | "os.makedirs(out_dir, exist_ok=True)\n", 422 | "genre_id = genre2id[genre_to_generate.value]\n", 423 | "\n", 424 | "params = dict(\n", 425 | " target_seq_length = seq_length.value,\n", 426 | " temperature = temp.value,\n", 427 | " topk = topk.value,\n", 428 | " topp = topp.value,\n", 429 | " topp_temperature = topp_temperature.value,\n", 430 | " at_least_k = at_least_k.value,\n", 431 | " use_rp = use_rp.value,\n", 432 | " rp_penalty = rp_penalty.value,\n", 433 | " rp_restore_speed = restore_speed.value,\n", 434 | " seed = int(seed.value) if seed.value else None\n", 435 | ")\n", 436 | "max_primer_tokens = 512\n", 437 | "\n", 438 | "# Init model\n", 439 | "print('loading model...')\n", 440 | "model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval()\n", 441 | "model.load_state_dict(torch.load(load_path, map_location=device))\n", 442 | "\n", 443 | "# Add genre and primer\n", 444 | "primer_genre = np.repeat([genre_id], b_size.value)[:,None] + constants.VOCAB_SIZE - 4\n", 445 | "if primer_len.value > 0 and select_primer_widget.value != 'None':\n", 446 | " file = select_primer_widget.value\n", 447 | " midi = pretty_midi.PrettyMIDI(file)\n", 448 | " from_end = primer_position.value == 1\n", 449 | " truncate_midi(midi, primer_len.value, from_end) # truncate to specified length (in seconds)\n", 450 | " encoded = midi_processing.encode(midi)\n", 451 | " # truncate to max_primer_tokens (in tokens)\n", 452 | " l = len(encoded)\n", 453 | " if l > max_primer_tokens:\n", 454 | " import warnings\n", 455 | " warnings.warn('Primer MIDI is too long (length > 512), it will be truncated to 512!')\n", 456 | " if from_end:\n", 457 | " encoded = encoded[l-max_primer_tokens:]\n", 458 | " else:\n", 459 | " encoded = encoded[:max_primer_tokens]\n", 460 | " primer_seq = np.repeat(np.array(encoded)[None], b_size.value, 0)\n", 461 | " primer = np.concatenate([primer_genre, primer_seq], -1)\n", 462 | "else:\n", 463 | " primer = primer_genre\n", 464 | "primer = torch.tensor(primer, dtype=torch.int64)\n", 465 | "\n", 466 | "# Generation\n", 467 | "if primer.shape[-1] >= seq_length.value:\n", 468 | " print('Nothing to generate. Try to set larger \"Sequence length\" parameter!')\n", 469 | "else:\n", 470 | " while len(glob.glob(out_dir + '/*.mid')) != b_size.value:\n", 471 | " print('generating to:', os.path.abspath(out_dir))\n", 472 | " generated = generation.generate(model, primer, **params)\n", 473 | " generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations.value)\n", 474 | " decode_and_write(generated, primer, primer_genre.squeeze(-1)-390, out_dir)\n", 475 | " files_to_delete = len(glob.glob(out_dir + '/*.mid')) - b_size.value\n", 476 | " if files_to_delete > 0:\n", 477 | " for idx in range(files_to_delete):\n", 478 | " os.remove(sorted(glob.glob(out_dir + '/*.mid'), key=lambda x: int(x.split('_')[-2]))[-idx])\n", 479 | " for midi_name in glob.glob(out_dir + '/*.mid'):\n", 480 | " convert_midi_to_wav(midi_name)" 481 | ], 482 | "execution_count": null, 483 | "outputs": [] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "metadata": { 488 | "cellView": "form", 489 | "id": "IHRrMDeyk2GI" 490 | }, 491 | "source": [ 492 | "#@title 3. 
Прослушать и скачать результаты генерации\n", 493 | "wav_files = glob.glob(out_dir + '/*.wav')\n", 494 | "if len(wav_files) > 3:\n", 495 | " rows = round(len(wav_files) / 3) + 1\n", 496 | " columns = 3\n", 497 | "else:\n", 498 | " rows = 1\n", 499 | " columns = len(wav_files)\n", 500 | "grid = widgets.Grid(rows, columns)\n", 501 | "current_position = 0\n", 502 | "for row in range(rows):\n", 503 | " if current_position > len(wav_files) - 1:\n", 504 | " break\n", 505 | " for column in range(columns):\n", 506 | " if current_position > len(wav_files) - 1:\n", 507 | " break\n", 508 | " with grid.output_to(row, column):\n", 509 | " genre = rugenre[wav_files[current_position].split('.wav')[0].split('_')[-1]]\n", 510 | " print(f'Номер трека: {current_position + 1}')\n", 511 | " print(f'Жанр: {genre}')\n", 512 | " display(Audio(wav_files[current_position]))\n", 513 | " display(DownloadButton(filename=wav_files[current_position], description='Скачать .wav'))\n", 514 | " display(DownloadButton(filename=wav_files[current_position].replace('.wav', '.mid'), description='Скачать .midi'))\n", 515 | " current_position += 1" 516 | ], 517 | "execution_count": null, 518 | "outputs": [] 519 | } 520 | ] 521 | } -------------------------------------------------------------------------------- /gpt2-rga/[GPT2RGA] Quantum_Music.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "gradient": { 7 | "editing": false, 8 | "id": "ac5a4cf0-d9d2-47b5-9633-b53f8d99a4d2", 9 | "kernelId": "" 10 | }, 11 | "id": "SiTIpPjArIyr" 12 | }, 13 | "source": [ 14 | "# Quantum Music (ver. 1.0)\n", 15 | "\n", 16 | "## Tokenized Sparse Time Quantization Example\n", 17 | "\n", 18 | "***\n", 19 | "\n", 20 | "Powered by tegridy-tools TMIDIX Optimus Processors: https://github.com/asigalov61/tegridy-tools\n", 21 | "\n", 22 | "***\n", 23 | "\n", 24 | "Credit for GPT2-RGA code used in this colab goes out @ Sashmark97 https://github.com/Sashmark97/midigen and @ Damon Gwinn https://github.com/gwinndr/MusicTransformer-Pytorch\n", 25 | "\n", 26 | "***\n", 27 | "\n", 28 | "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please excercise great humility, care, and respect. 
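# (Hedged stand-in, not the repo's implementation.) The playback cell above uses a
# DownloadButton widget defined elsewhere in this repository; a minimal substitute with the
# same call shape can render a base64 data-URI download link directly in Colab:
import base64, os
from IPython.display import HTML

def DownloadButton(filename, description):
    # Embed the file contents in the link so no separate file server is needed.
    with open(filename, 'rb') as f:
        payload = base64.b64encode(f.read()).decode()
    return HTML(f'<a download="{os.path.basename(filename)}" '
                f'href="data:application/octet-stream;base64,{payload}">{description}</a>')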
https://www.nscai.gov/\n", 29 | "\n", 30 | "***\n", 31 | "\n", 32 | "#### Project Los Angeles\n", 33 | "\n", 34 | "#### Tegridy Code 2021\n", 35 | "\n", 36 | "***" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "gradient": { 43 | "editing": false, 44 | "id": "fa0a611c-1803-42ae-bdf6-a49b5a4e781b", 45 | "kernelId": "" 46 | }, 47 | "id": "gOd93yV0sGd2" 48 | }, 49 | "source": [ 50 | "# (Setup Environment)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "cellView": "form", 58 | "gradient": { 59 | "editing": false, 60 | "id": "39411b40-9e39-416e-8fe4-d40f733e7956", 61 | "kernelId": "" 62 | }, 63 | "id": "lw-4aqV3sKQG" 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "#@title nvidia-smi gpu check\n", 68 | "!nvidia-smi" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "cellView": "form", 76 | "gradient": { 77 | "editing": false, 78 | "id": "a1a45a91-d909-4fd4-b67a-5e16b971d179", 79 | "kernelId": "" 80 | }, 81 | "id": "fX12Yquyuihc" 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "#@title Install all dependencies (run only once per session)\n", 86 | "\n", 87 | "!git clone https://github.com/asigalov61/tegridy-tools\n", 88 | "!pip install torch\n", 89 | "!pip install tqdm\n", 90 | "!pip install matplotlib" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": { 97 | "cellView": "form", 98 | "gradient": { 99 | "editing": false, 100 | "id": "b8207b76-9514-4c07-95db-95a4742e52c5", 101 | "kernelId": "" 102 | }, 103 | "id": "z7n9vnKmug1J" 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "#@title Import all needed modules\n", 108 | "\n", 109 | "print('Loading needed modules. Please wait...')\n", 110 | "import os\n", 111 | "from datetime import datetime\n", 112 | "import secrets\n", 113 | "import copy\n", 114 | "import tqdm\n", 115 | "from tqdm import tqdm\n", 116 | "\n", 117 | "if not os.path.exists('/notebooks/Dataset'):\n", 118 | " os.makedirs('/notebooks/Dataset')\n", 119 | "\n", 120 | "print('Loading TMIDIX module...')\n", 121 | "os.chdir('/notebooks/tegridy-tools/tegridy-tools')\n", 122 | "import TMIDIX\n", 123 | "\n", 124 | "os.chdir('/notebooks/tegridy-tools/tegridy-tools')\n", 125 | "from GPT2RGAX import *\n", 126 | "\n", 127 | "import matplotlib.pyplot as plt\n", 128 | "\n", 129 | "os.chdir('/notebooks/')" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "id": "ObPxlEutsQBj" 136 | }, 137 | "source": [ 138 | "# (FROM SCRATCH) Download and process MIDI dataset" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "cellView": "form", 146 | "gradient": { 147 | "id": "ffbb7a2a-d91a-477f-ac89-56d77d6cdf42", 148 | "kernelId": "" 149 | }, 150 | "id": "snIZ3xKPsPgB" 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "#@title Download Endless Violin Carousel MIDI dataset (Recommended)\n", 155 | "\n", 156 | "#@markdown Piano Violin Duo\n", 157 | "\n", 158 | "#@markdown Works best stand-alone/as-is for the optimal results\n", 159 | "%cd /notebooks/Dataset/\n", 160 | "\n", 161 | "!wget 'https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Endless-Violin-Carousel-CC-BY-NC-SA.zip'\n", 162 | "!unzip -j '/notebooks/Dataset/Endless-Violin-Carousel-CC-BY-NC-SA.zip'\n", 163 | "!rm '/notebooks/Dataset/Endless-Violin-Carousel-CC-BY-NC-SA.zip'\n", 164 | "\n", 165 | "%cd /notebooks/" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 
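# (Hedged alternative, same effect as the import cell above.) The notebook chdir()s into
# tegridy-tools twice so that `import TMIDIX` and `from GPT2RGAX import *` resolve; appending
# the folder to sys.path does the same without touching the working directory:
import sys
sys.path.append('/notebooks/tegridy-tools/tegridy-tools')
import TMIDIX
from GPT2RGAX import *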
170 | "execution_count": null, 171 | "metadata": { 172 | "cellView": "form", 173 | "gradient": { 174 | "id": "ed07b44f-07fe-45fb-a64f-adba8df1bdcb", 175 | "kernelId": "" 176 | }, 177 | "id": "on7sgKEP3Yc8" 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "#@title Process MIDIs to special MIDI dataset with Tegridy MIDI Processor\n", 182 | "\n", 183 | "#@markdown IMPORTANT NOTES:\n", 184 | "\n", 185 | "#@markdown 1) Best results are achieved with the single-track, single-channel, single-instrument MIDI 0 files with plain English names (avoid special or sys/foreign chars)\n", 186 | "\n", 187 | "#@markdown 2) MIDI Channel = -1 means all MIDI channels except the drums. MIDI Channel = 16 means all channels will be processed. Otherwise, only single indicated MIDI channel will be processed\n", 188 | "\n", 189 | "desired_dataset_name = \"Quantum-Music-Dataset\" #@param {type:\"string\"}\n", 190 | "file_name_to_output_dataset_to = \"/notebooks/Quantum-Music-Dataset\" #@param {type:\"string\"}\n", 191 | "desired_MIDI_channel_to_process = -1 #@param {type:\"slider\", min:-1, max:16, step:1}\n", 192 | "sorted_or_random_file_loading_order = False #@param {type:\"boolean\"}\n", 193 | "encode_velocities = True #@param {type:\"boolean\"}\n", 194 | "encode_MIDI_channels = True #@param {type:\"boolean\"}\n", 195 | "add_transposed_dataset_by_this_many_pitches = 0 #@param {type:\"slider\", min:-12, max:12, step:1}\n", 196 | "add_transposed_and_flipped_dataset = False #@param {type:\"boolean\"}\n", 197 | "chordify_input_MIDIs = False #@param {type:\"boolean\"}\n", 198 | "melody_conditioned_chords = False #@param {type:\"boolean\"}\n", 199 | "melody_pitch_baseline = 60 #@param {type:\"slider\", min:0, max:127, step:1}\n", 200 | "time_denominator = 1 #@param {type:\"slider\", min:1, max:50, step:1}\n", 201 | "transform_to_pitch = 0 #@param {type:\"slider\", min:0, max:127, step:1}\n", 202 | "perfect_timings = True #@param {type:\"boolean\"}\n", 203 | "MuseNet_encoding = True #@param {type:\"boolean\"}\n", 204 | "chars_encoding_offset = 0 #@param {type:\"number\"}\n", 205 | "\n", 206 | "print('TMIDI Optimus MIDI Processor')\n", 207 | "print('Starting up...')\n", 208 | "###########\n", 209 | "\n", 210 | "average_note_pitch = 0\n", 211 | "min_note = 127\n", 212 | "max_note = 0\n", 213 | "\n", 214 | "files_count = 0\n", 215 | "\n", 216 | "gfiles = 0\n", 217 | "\n", 218 | "chords_list_f = []\n", 219 | "melody_list_f = []\n", 220 | "\n", 221 | "chords_list = []\n", 222 | "chords_count = 0\n", 223 | "\n", 224 | "melody_chords = []\n", 225 | "melody_count = 0\n", 226 | "\n", 227 | "TXT_String = ''\n", 228 | "\n", 229 | "TXT = ''\n", 230 | "melody = []\n", 231 | "chords = []\n", 232 | "INTS_f = []\n", 233 | "\n", 234 | "flist = []\n", 235 | "\n", 236 | "###########\n", 237 | "\n", 238 | "print('Loading MIDI files...')\n", 239 | "print('This may take a while on a large dataset in particular.')\n", 240 | "\n", 241 | "dataset_addr = \"/notebooks/Dataset/\"\n", 242 | "os.chdir(dataset_addr)\n", 243 | "filez = list()\n", 244 | "for (dirpath, dirnames, filenames) in os.walk(dataset_addr):\n", 245 | " filez += [os.path.join(dirpath, file) for file in filenames]\n", 246 | "print('=' * 70)\n", 247 | "\n", 248 | "if filez == []:\n", 249 | " print('Could not find any MIDI files. 
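# (Optional hardening, not in the original cell.) os.walk() collects every file under
# /notebooks/Dataset/, so stray non-MIDI files would only be caught later by the
# "Bad MIDI" exception handler; filtering by extension up front keeps the scan clean:
midi_exts = ('.mid', '.midi', '.kar')
filez = [f for f in filez if f.lower().endswith(midi_exts)]
print('MIDI files found:', len(filez))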
Please check Dataset dir...')\n", 250 | " print('=' * 70)\n", 251 | "\n", 252 | "if sorted_or_random_file_loading_order:\n", 253 | " print('Sorting files...')\n", 254 | " filez.sort()\n", 255 | " print('Done!')\n", 256 | " print('=' * 70)\n", 257 | "\n", 258 | "else:\n", 259 | " random.shuffle(filez)\n", 260 | "\n", 261 | "# Stamping the dataset info\n", 262 | "print('Stamping the dataset info...')\n", 263 | "\n", 264 | "TXT_String += 'DATASET=' + str(desired_dataset_name) + chr(10)\n", 265 | "TXT_String += 'CREATED_ON=' + str(datetime.now()).replace(' ', '-').replace(':', '-').replace('.', '-') + chr(10)\n", 266 | "\n", 267 | "TXT_String += 'CHARS_ENCODING_OFFSET=' + str(chars_encoding_offset) + chr(10)\n", 268 | "TXT_String += 'TIME_DENOMINATOR=' + str(time_denominator) + chr(10)\n", 269 | "TXT_String += 'TRANSFORM=' + str(transform_to_pitch) + chr(10)\n", 270 | "TXT_String += 'PERFECT_TIMINGS=' + str(perfect_timings) + chr(10)\n", 271 | "TXT_String += 'MUSENET_ENCODING=' + str(MuseNet_encoding) + chr(10)\n", 272 | "TXT_String += 'TRANSPOSED_BY=' + str(add_transposed_dataset_by_this_many_pitches) + chr(10)\n", 273 | "TXT_String += 'TRANSPOSED_AND_FLIPPED=' + str(add_transposed_and_flipped_dataset) + chr(10)\n", 274 | "\n", 275 | "TXT_String += 'LEGEND=STA-DUR-PTC'\n", 276 | "if encode_velocities:\n", 277 | " TXT_String += '-VEL'\n", 278 | "if encode_MIDI_channels:\n", 279 | " TXT_String += '-CHA'\n", 280 | "TXT_String += chr(10)\n", 281 | "\n", 282 | "print('Processing MIDI files. Please wait...')\n", 283 | "for f in tqdm(filez):\n", 284 | " try:\n", 285 | " fn = os.path.basename(f)\n", 286 | " fn1 = fn.split('.')[0]\n", 287 | "\n", 288 | " files_count += 1\n", 289 | " TXT, melody, chords, bass_melody, karaokez, INTS, aux1, aux2 = TMIDIX.Optimus_MIDI_TXT_Processor(f, chordify_TXT=chordify_input_MIDIs, output_MIDI_channels=encode_MIDI_channels, char_offset=chars_encoding_offset, dataset_MIDI_events_time_denominator=time_denominator, output_velocity=encode_velocities, MIDI_channel=desired_MIDI_channel_to_process, MIDI_patch=range(0, 127), melody_conditioned_encoding=melody_conditioned_chords, melody_pitch_baseline=melody_pitch_baseline, perfect_timings=perfect_timings, musenet_encoding=MuseNet_encoding, transform=transform_to_pitch)\n", 290 | " TXT_String += TXT\n", 291 | " melody_list_f += melody\n", 292 | " chords_list_f.append(chords)\n", 293 | " INTS_f.append(INTS)\n", 294 | " flist.append([f, fn1])\n", 295 | " gfiles += 1\n", 296 | "\n", 297 | " if add_transposed_dataset_by_this_many_pitches != 0:\n", 298 | "\n", 299 | " TXT, melody, chords, bass_melody, karaokez, INTS, aux1, aux2 = TMIDIX.Optimus_MIDI_TXT_Processor(f, chordify_TXT=chordify_input_MIDIs, output_MIDI_channels=encode_MIDI_channels, char_offset=chars_encoding_offset, dataset_MIDI_events_time_denominator=time_denominator, output_velocity=encode_velocities, MIDI_channel=desired_MIDI_channel_to_process, transpose_by=add_transposed_dataset_by_this_many_pitches, MIDI_patch=range(0, 127), melody_conditioned_encoding=melody_conditioned_chords, melody_pitch_baseline=melody_pitch_baseline, perfect_timings=perfect_timings, musenet_encoding=MuseNet_encoding, transform=transform_to_pitch)\n", 300 | " TXT_String += TXT\n", 301 | " melody_list_f += melody\n", 302 | " chords_list_f.append(chords)\n", 303 | " INTS_f.append(INTS)\n", 304 | " gfiles += 1\n", 305 | "\n", 306 | " if add_transposed_and_flipped_dataset == True:\n", 307 | "\n", 308 | " TXT, melody, chords, bass_melody, karaokez, INTS, aux1, aux2 = 
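# (Hedged reading of the LEGEND stamped above.) STA-DUR-PTC with optional -VEL and -CHA
# presumably abbreviate Start-time, DURation, PiTCh, VELocity and CHAnnel, so with both
# checkboxes enabled one stamped note event carries five fields, e.g.:
#   STA=480  DUR=240  PTC=64  VEL=90  CHA=0
# (illustrative numbers; the time unit depends on the time_denominator form value above).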
TMIDIX.Optimus_MIDI_TXT_Processor(f, chordify_TXT=chordify_input_MIDIs, output_MIDI_channels=encode_MIDI_channels, char_offset=chars_encoding_offset, dataset_MIDI_events_time_denominator=time_denominator, output_velocity=encode_velocities, MIDI_channel=desired_MIDI_channel_to_process, transpose_by=-12, MIDI_patch=range(0, 127), flip=True, melody_conditioned_encoding=melody_conditioned_chords, melody_pitch_baseline=melody_pitch_baseline, perfect_timings=perfect_timings, musenet_encoding=MuseNet_encoding, transform=transform_to_pitch)\n", 309 | " TXT_String += TXT\n", 310 | " melody_list_f += melody\n", 311 | " chords_list_f += chords\n", 312 | " INTS_f.append(INTS)\n", 313 | " gfiles += 1\n", 314 | "\n", 315 | " except KeyboardInterrupt:\n", 316 | " print('Saving current progress and quitting...')\n", 317 | " break \n", 318 | " \n", 319 | " except:\n", 320 | " print('Bad MIDI:', f)\n", 321 | " continue\n", 322 | "\n", 323 | "TXT_String += 'TOTAL_SONGS_IN_DATASET=' + str(gfiles)\n", 324 | "\n", 325 | "try:\n", 326 | " print('Task complete :)')\n", 327 | " print('==================================================')\n", 328 | " if add_transposed_dataset_by_this_many_pitches != 0:\n", 329 | " print('NOTE: Transposed dataset was added per users request.')\n", 330 | " print('==================================================')\n", 331 | " if add_transposed_and_flipped_dataset == True:\n", 332 | " print('NOTE: Flipped dataset was added per users request.') \n", 333 | " print('==================================================')\n", 334 | " print('Number of processed dataset MIDI files:', files_count)\n", 335 | " print('Number of MIDI chords recorded:', len(chords_list_f))\n", 336 | " print('First chord event:', chords_list_f[0], 'Last chord event:', chords_list_f[-1]) \n", 337 | " print('Number of recorded melody events:', len(melody_list_f))\n", 338 | " print('First melody event:', melody_list_f[0], 'Last Melody event:', melody_list_f[-1])\n", 339 | " print('Total number of MIDI events recorded:', len(chords_list_f) + len(melody_list_f))\n", 340 | " print('==================================================')\n", 341 | "\n", 342 | " # Writing dataset to TXT file\n", 343 | " with open(file_name_to_output_dataset_to + '.txt', 'wb') as f:\n", 344 | " f.write(TXT_String.encode('utf-8', 'replace'))\n", 345 | " f.close\n", 346 | "\n", 347 | " # Dataset\n", 348 | " MusicDataset = [chords_list_f, melody_list_f, INTS_f]\n", 349 | "\n", 350 | " # Writing dataset to pickle file\n", 351 | " TMIDIX.Tegridy_Any_Pickle_File_Writer(MusicDataset, file_name_to_output_dataset_to)\n", 352 | "\n", 353 | "except:\n", 354 | " print('=' * 70)\n", 355 | " print('IO Error!')\n", 356 | " print('Please check that Dataset dir is not empty/check other IO code.')\n", 357 | " print('=' * 70)\n", 358 | " print('Shutting down...')\n", 359 | " print('=' * 70)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": { 366 | "gradient": { 367 | "id": "0826f622-2edc-4f09-9a01-58df049738d4", 368 | "kernelId": "" 369 | } 370 | }, 371 | "outputs": [], 372 | "source": [ 373 | "INTS_f1 = []\n", 374 | "\n", 375 | "\n", 376 | "for chords_list in tqdm(chords_list_f):\n", 377 | " INTS_f1.append([-1, -1, -1]) # Intro\n", 378 | " pe = chords_list[0]\n", 379 | " for i in chords_list:\n", 380 | "\n", 381 | " INTS_f1.append([int(abs(i[1]-pe[1])/ 10), int(i[2] / 10), i[4] ])\n", 382 | " \n", 383 | " if chords_list.index(i) == len(chords_list)-50:\n", 384 | " INTS_f1.append([-2, -2, -2]) # Outro\n", 385 | 
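# (Hedged sketch: reading the dataset back.) Tegridy_Any_Pickle_File_Writer above stores
# MusicDataset = [chords_list_f, melody_list_f, INTS_f]; recent tegridy-tools versions expose
# a matching reader (name inferred from the writer used above - verify it exists in your
# TMIDIX copy before relying on it):
MusicDataset = TMIDIX.Tegridy_Any_Pickle_File_Reader(file_name_to_output_dataset_to)
chords_list_f, melody_list_f, INTS_f = MusicDataset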
" \n", 386 | " \n", 387 | " pe = i\n", 388 | " INTS_f1.append([-3, -3, -3]) # End" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": { 395 | "cellView": "form", 396 | "gradient": { 397 | "id": "53252e52-5e68-4e60-8e4d-a584667749a4", 398 | "kernelId": "" 399 | }, 400 | "id": "lT0TyqUnpxu_" 401 | }, 402 | "outputs": [], 403 | "source": [ 404 | "#@title Load processed INTs datasets\n", 405 | "number_of_batches = 16 #@param {type:\"slider\", min:2, max:32, step:2}\n", 406 | "n_workers = 6\n", 407 | "\n", 408 | "print('=' * 50)\n", 409 | "print('Prepping INTs datasets...')\n", 410 | "\n", 411 | "\n", 412 | "train_data1 = []\n", 413 | "for i in INTS_f1:\n", 414 | " if max(i) < 256 and min(i) >= 0:\n", 415 | "\n", 416 | " if i[0] < 16:\n", 417 | " train_data1.extend([i[0]])\n", 418 | " else:\n", 419 | " train_data1.extend([16, i[0]-16])\n", 420 | " \n", 421 | " train_data1.extend([256+i[2], 512+i[1]-4 ])\n", 422 | " \n", 423 | " if max(i) == -1 and min(i) == -1: # Intro\n", 424 | " train_data1.extend([256+512-3])\n", 425 | " \n", 426 | " if max(i) == -2 and min(i) == -2: # Outro\n", 427 | " train_data1.extend([256+512-2])\n", 428 | " \n", 429 | " if max(i) == -3 and min(i) == -3: # End\n", 430 | " train_data1.extend([256+512-1])\n", 431 | "\n", 432 | "train_data = train_data1[:int(len(train_data1) / 3)]\n", 433 | "\n", 434 | "val_dataset = train_data[:int(len(train_data) * 0.03)]\n", 435 | "test_dataset = train_data[:int(len(train_data) * 0.03)]\n", 436 | "\n", 437 | "train_list = train_data\n", 438 | "val_list = val_dataset\n", 439 | "test_list = []\n", 440 | "print('=' * 50)\n", 441 | "\n", 442 | "print('Processing INTs datasets...')\n", 443 | "train_dataset = EPianoDataset(train_list, max_seq, random_seq)\n", 444 | "val_dataset = EPianoDataset(val_list, max_seq)\n", 445 | "test_dataset = EPianoDataset(test_list, max_seq)\n", 446 | "print('=' * 50)\n", 447 | "\n", 448 | "print('Loading INTs datasets...')\n", 449 | "batch_size = number_of_batches\n", 450 | "train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=n_workers, shuffle=True)\n", 451 | "val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=n_workers)\n", 452 | "test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=n_workers)\n", 453 | "print('=' * 50)\n", 454 | "\n", 455 | "print('Total INTs in the dataset', len(train_data))\n", 456 | "print('Total unique INTs in the dataset', len(set(train_data)))\n", 457 | "print('Max INT in the dataset', max(train_data))\n", 458 | "print('Min INT in the dataset', min(train_data))\n", 459 | "print('=' * 50)\n", 460 | "\n", 461 | "print('Checking datasets shapes...')\n", 462 | "print('=' * 50)\n", 463 | "\n", 464 | "print('Train loader')\n", 465 | "for x, tgt in train_loader:\n", 466 | " print(f'X shape: {x.shape}')\n", 467 | " print(f'Target shape: {tgt.shape}')\n", 468 | " break\n", 469 | "print('=' * 50)\n", 470 | "\n", 471 | "print('Validation loader')\n", 472 | "for x, tgt in val_loader:\n", 473 | " print(f'X shape: {x.shape}')\n", 474 | " print(f'Target shape: {tgt.shape}')\n", 475 | " break\n", 476 | "print('=' * 50)\n", 477 | "\n", 478 | "print('Test loader')\n", 479 | "for x, tgt in test_loader:\n", 480 | " print(f'X shape: {x.shape}')\n", 481 | " print(f'Target shape: {tgt.shape}')\n", 482 | " break\n", 483 | "print('=' * 50)\n", 484 | "\n", 485 | "print('Done! Enjoy! 
:)')\n", 486 | "print('=' * 50)" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": {}, 492 | "source": [ 493 | "# Test the resulting INTs dataset..." 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": null, 499 | "metadata": { 500 | "gradient": { 501 | "id": "708f16d3-1747-4e72-bcc9-7504cdd963d4", 502 | "kernelId": "" 503 | } 504 | }, 505 | "outputs": [], 506 | "source": [ 507 | "train_data" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "metadata": { 514 | "gradient": { 515 | "execution_count": 6, 516 | "id": "dd411e56-532f-47dd-8283-ecb57126a3ae", 517 | "kernelId": "" 518 | } 519 | }, 520 | "outputs": [], 521 | "source": [ 522 | "out = train_data[:10000]\n", 523 | "if len(out) != 0:\n", 524 | " song = []\n", 525 | " song = out\n", 526 | " song_f = []\n", 527 | " time = 0\n", 528 | " pitch = 0\n", 529 | " duration = 0\n", 530 | " for s in song:\n", 531 | " if s >= 0 and s <= 256:\n", 532 | " time += s\n", 533 | " if s >= 256 and s < 512:\n", 534 | " pitch = s-256\n", 535 | " if s >= 512 and s < 256+512-4:\n", 536 | " duration = s-512\n", 537 | " song_f.append(['note', (abs(time))*10, (duration*10), 0, pitch, pitch ])\n", 538 | " \n", 539 | " detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,\n", 540 | " output_signature = 'Quantum Music', \n", 541 | " output_file_name = '/notebooks/Quantum-Music-Composition', \n", 542 | " track_name='Project Los Angeles', \n", 543 | " number_of_ticks_per_quarter=500)\n", 544 | "\n", 545 | " print('Done!')\n" 546 | ] 547 | }, 548 | { 549 | "cell_type": "markdown", 550 | "metadata": { 551 | "id": "fkVqviDzJOrv" 552 | }, 553 | "source": [ 554 | "# (TRAIN)" 555 | ] 556 | }, 557 | { 558 | "cell_type": "markdown", 559 | "metadata": { 560 | "id": "Y9CBW8xYupH8" 561 | }, 562 | "source": [ 563 | "# Train the model" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "metadata": { 570 | "cellView": "form", 571 | "gradient": { 572 | "id": "4aa21407-a3e9-4ed2-9bf1-83c295482b8a", 573 | "kernelId": "" 574 | }, 575 | "id": "2moo7uUmpxvC" 576 | }, 577 | "outputs": [], 578 | "source": [ 579 | "#@title Train\n", 580 | "config = GPTConfig(VOCAB_SIZE, \n", 581 | " max_seq,\n", 582 | " dim_feedforward=dim_feedforward,\n", 583 | " n_layer=6, \n", 584 | " n_head=8, \n", 585 | " n_embd=512,\n", 586 | " enable_rpr=True,\n", 587 | " er_len=max_seq)\n", 588 | "model = GPT(config).to(get_device())\n", 589 | "\n", 590 | "#=====\n", 591 | "\n", 592 | "init_step = 0\n", 593 | "lr = LR_DEFAULT_START\n", 594 | "lr_stepper = LrStepTracker(d_model, SCHEDULER_WARMUP_STEPS, init_step)\n", 595 | "eval_loss_func = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD)\n", 596 | "train_loss_func = eval_loss_func\n", 597 | "\n", 598 | "opt = Adam(model.parameters(), lr=lr, betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)\n", 599 | "lr_scheduler = LambdaLR(opt, lr_stepper.step)\n", 600 | "\n", 601 | "\n", 602 | "#===\n", 603 | "\n", 604 | "best_eval_acc = 0.0\n", 605 | "best_eval_acc_epoch = -1\n", 606 | "best_eval_loss = float(\"inf\")\n", 607 | "best_eval_loss_epoch = -1\n", 608 | "best_acc_file = '/notebooks/gpt2_rpr_acc.pth'\n", 609 | "best_loss_file = '/notebooks/gpt2_rpr_loss.pth'\n", 610 | "loss_train, loss_val, acc_val = [], [], []\n", 611 | "\n", 612 | "for epoch in range(0, epochs):\n", 613 | " new_best = False\n", 614 | " \n", 615 | " loss = train(epoch+1, model, train_loader, train_loss_func, opt, lr_scheduler, num_iters=-1)\n", 616 | " 
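# (Hedged note on the optimizer schedule configured above.) LrStepTracker driving LambdaLR
# presumably implements the inverse-square-root warmup schedule from "Attention Is All You
# Need", as in the MusicTransformer-Pytorch code this notebook credits, i.e. per step:
#   lr(step) = d_model**-0.5 * min(step**-0.5, step * SCHEDULER_WARMUP_STEPS**-1.5)
# so LR_DEFAULT_START is only a starting placeholder that the scheduler rescales immediately.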
loss_train.append(loss)\n", 617 | " \n", 618 | " eval_loss, eval_acc = eval_model(model, val_loader, eval_loss_func, num_iters=-1)\n", 619 | " loss_val.append(eval_loss)\n", 620 | " acc_val.append(eval_acc)\n", 621 | " \n", 622 | " if(eval_acc > best_eval_acc):\n", 623 | " best_eval_acc = eval_acc\n", 624 | " best_eval_acc_epoch = epoch+1\n", 625 | " torch.save(model.state_dict(), best_acc_file)\n", 626 | " new_best = True\n", 627 | "\n", 628 | " if(eval_loss < best_eval_loss):\n", 629 | " best_eval_loss = eval_loss\n", 630 | " best_eval_loss_epoch = epoch+1\n", 631 | " torch.save(model.state_dict(), best_loss_file)\n", 632 | " new_best = True\n", 633 | " \n", 634 | " if(new_best):\n", 635 | " print(\"Best eval acc epoch:\", best_eval_acc_epoch)\n", 636 | " print(\"Best eval acc:\", best_eval_acc)\n", 637 | " print(\"\")\n", 638 | " print(\"Best eval loss epoch:\", best_eval_loss_epoch)\n", 639 | " print(\"Best eval loss:\", best_eval_loss)" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": null, 645 | "metadata": {}, 646 | "outputs": [], 647 | "source": [ 648 | "eval_loss, eval_acc = eval_model(model, val_loader, eval_loss_func, num_iters=-1)" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "metadata": {}, 655 | "outputs": [], 656 | "source": [ 657 | "train_data" 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": null, 663 | "metadata": { 664 | "cellView": "form", 665 | "gradient": { 666 | "id": "0e338550-f170-44a6-9479-ba0ddbc64608", 667 | "kernelId": "" 668 | }, 669 | "id": "NNqmcFdRyC2M" 670 | }, 671 | "outputs": [], 672 | "source": [ 673 | "#@title Plot resulting training loss graph\n", 674 | "\n", 675 | "tr_loss_list = [item for sublist in loss_train for item in sublist]\n", 676 | "plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')\n", 677 | "plt.savefig('/notebooks/training-loss.png')" 678 | ] 679 | }, 680 | { 681 | "cell_type": "markdown", 682 | "metadata": { 683 | "id": "mdKFoeke9L7H" 684 | }, 685 | "source": [ 686 | "# (SAVE/LOAD)" 687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": null, 692 | "metadata": { 693 | "cellView": "form", 694 | "gradient": { 695 | "id": "73bea62d-084b-4f9a-9e55-2b34a932a7a4", 696 | "kernelId": "" 697 | }, 698 | "id": "gqyDatHC9X1z" 699 | }, 700 | "outputs": [], 701 | "source": [ 702 | "#@title Save the model\n", 703 | "\n", 704 | "print('Saving the model...')\n", 705 | "full_path_to_model_checkpoint = \"/notebooks/Quantum-Music-Trained-Model-6.pth\" #@param {type:\"string\"}\n", 706 | "torch.save(model.state_dict(), full_path_to_model_checkpoint)\n", 707 | "print('Done!')" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": null, 713 | "metadata": { 714 | "cellView": "form", 715 | "gradient": { 716 | "id": "c83edd89-9a36-430a-9fa7-3a967417c88e", 717 | "kernelId": "" 718 | }, 719 | "id": "OaNkGcFo9UP_" 720 | }, 721 | "outputs": [], 722 | "source": [ 723 | "#@title Load/Reload the model\n", 724 | "full_path_to_model_checkpoint = \"/notebooks/Quantum-Music-Trained-Model-6.pth\" #@param {type:\"string\"}\n", 725 | "\n", 726 | "print('Loading the model...')\n", 727 | "config = GPTConfig(256+512+2, \n", 728 | " max_seq,\n", 729 | " dim_feedforward=dim_feedforward,\n", 730 | " n_layer=6, \n", 731 | " n_head=8, \n", 732 | " n_embd=512,\n", 733 | " enable_rpr=True,\n", 734 | " er_len=max_seq)\n", 735 | "\n", 736 | "model = GPT(config).to(get_device())\n", 737 | "\n", 738 | 
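# (Standard PyTorch usage note, not specific to this repo.) torch.load() in the load cell
# below restores tensors onto the device they were saved from; on a CPU-only machine pass
# map_location explicitly so a GPU-trained checkpoint still loads:
#   model.load_state_dict(torch.load(full_path_to_model_checkpoint, map_location='cpu'))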
"model.load_state_dict(torch.load(full_path_to_model_checkpoint))\n", 739 | "print('Done!')" 740 | ] 741 | }, 742 | { 743 | "cell_type": "markdown", 744 | "metadata": {}, 745 | "source": [ 746 | "# Custom MIDI option" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "execution_count": null, 752 | "metadata": { 753 | "gradient": { 754 | "id": "5f771604-39e7-431d-b1dd-86d7437b8872", 755 | "kernelId": "" 756 | } 757 | }, 758 | "outputs": [], 759 | "source": [ 760 | "data = TMIDIX.Optimus_MIDI_TXT_Processor('/notebooks/seed97-super.mid', \n", 761 | " dataset_MIDI_events_time_denominator=10, \n", 762 | " perfect_timings=True, \n", 763 | " musenet_encoding=True, \n", 764 | " char_offset=0, \n", 765 | " MIDI_channel=-1, \n", 766 | " MIDI_patch=range(0, 127)\n", 767 | " )\n", 768 | "\n", 769 | "SONG = data[5]\n", 770 | "inputs = []\n", 771 | "for i in SONG:\n", 772 | " if max(i) < 256 and max(i) >= 0:\n", 773 | " if i[0] < 16:\n", 774 | " inputs.extend([i[0]])\n", 775 | " else:\n", 776 | " \n", 777 | " inputs.extend([16, i[0]-16])\n", 778 | " \n", 779 | " inputs.extend([256+i[3], 512+i[1] ]) " 780 | ] 781 | }, 782 | { 783 | "cell_type": "markdown", 784 | "metadata": { 785 | "id": "UX1_5y5Fu8AH" 786 | }, 787 | "source": [ 788 | "# (GENERATE MUSIC)" 789 | ] 790 | }, 791 | { 792 | "cell_type": "code", 793 | "execution_count": null, 794 | "metadata": { 795 | "cellView": "form", 796 | "gradient": { 797 | "id": "97793d01-6a74-4e34-be95-ea337277b38d", 798 | "kernelId": "" 799 | }, 800 | "id": "M_K93hWWv2Yx" 801 | }, 802 | "outputs": [], 803 | "source": [ 804 | "#@title Generate and download a MIDI file\n", 805 | "\n", 806 | "number_of_tokens_to_generate = 1024 #@param {type:\"slider\", min:8, max:1024, step:8}\n", 807 | "use_random_primer = False #@param {type:\"boolean\"}\n", 808 | "start_with_zero_token = False #@param {type:\"boolean\"}\n", 809 | "number_of_ticks_per_quarter = 500 #@param {type:\"slider\", min:50, max:1000, step:50}\n", 810 | "dataset_time_denominator = 10\n", 811 | "melody_conditioned_encoding = False\n", 812 | "encoding_has_MIDI_channels = False \n", 813 | "encoding_has_velocities = False\n", 814 | "simulate_velocity = True #@param {type:\"boolean\"}\n", 815 | "save_only_first_composition = True\n", 816 | "chars_encoding_offset_used_for_dataset = 33\n", 817 | "\n", 818 | "fname = '/notebooks/Quantum-Music-Composition'\n", 819 | "\n", 820 | "print('Quantum Music Model Generator')\n", 821 | "\n", 822 | "output_signature = 'Quantum Music'\n", 823 | "song_name = 'RGA Composition'\n", 824 | "\n", 825 | "model.eval()\n", 826 | "\n", 827 | "if use_random_primer:\n", 828 | " sequence = [random.randint(10, 387) for i in range(64)]\n", 829 | " idx = secrets.randbelow(len(sequence))\n", 830 | " rand_seq = model.generate(torch.Tensor(sequence[idx:idx+120]), target_seq_length=number_of_tokens_to_generate)\n", 831 | " out = rand_seq[0].cpu().numpy().tolist()\n", 832 | "\n", 833 | "else:\n", 834 | " out = []\n", 835 | " \n", 836 | " try:\n", 837 | " if start_with_zero_token:\n", 838 | " sequence = inputs[-512:] #[256+512 - 2, 0]# inputs[-512:]\n", 839 | " rand_seq = model.generate(torch.Tensor(sequence), target_seq_length=number_of_tokens_to_generate, stop_token=256+512)\n", 840 | " out = rand_seq[0].cpu().numpy().tolist()\n", 841 | " else:\n", 842 | " idx = secrets.randbelow(len(train_data))\n", 843 | " sequence = train_data[idx:idx+512]\n", 844 | " rand_seq = model.generate(torch.Tensor(sequence), target_seq_length=number_of_tokens_to_generate, stop_token=256+512)\n", 845 | " out = 
rand_seq[0].cpu().numpy().tolist()\n", 846 | " \n", 847 | " except:\n", 848 | " print('=' * 50)\n", 849 | " print('Error! Try random priming instead!')\n", 850 | " print('Shutting down...')\n", 851 | " print('=' * 50)\n", 852 | "\n", 853 | "if len(out) != 0:\n", 854 | " song = []\n", 855 | " song = out\n", 856 | " song_f = []\n", 857 | " time = 0\n", 858 | " pitch = 0\n", 859 | " duration = 0\n", 860 | " once = True\n", 861 | " for s in song:\n", 862 | " if s >= 0 and s < 256:\n", 863 | " time += s\n", 864 | " if s >= 256 and s < 512:\n", 865 | " pitch = s-256\n", 866 | " if s >= 512 and s < 256+512-4:\n", 867 | " duration = s-512\n", 868 | " song_f.append(['note', (abs(time))*10, (duration*10), 0, pitch, pitch ])\n", 869 | " \n", 870 | " if song.index(s) >= len(sequence) and once:\n", 871 | " song_f.append(['text_event', abs(time) * 10, 'Continuation Start Here'])\n", 872 | " once = False\n", 873 | " \n", 874 | " detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,\n", 875 | " output_signature = 'Quantum Music', \n", 876 | " output_file_name = '/notebooks/Quantum-Music-Composition', \n", 877 | " track_name='Project Los Angeles', \n", 878 | " number_of_ticks_per_quarter=500)\n", 879 | " \n", 880 | " print('Done!')\n", 881 | "\n", 882 | "\n", 883 | " #print('Downloading your composition now...')\n", 884 | " #from google.colab import files\n", 885 | " #files.download(fname + '.mid')\n", 886 | "\n", 887 | " print('=' * 70)\n", 888 | " print('Detailed MIDI stats:')\n", 889 | " for key, value in detailed_stats.items():\n", 890 | " print('=' * 70)\n", 891 | " print(key, '|', value)\n", 892 | "\n", 893 | " print('=' * 70)\n", 894 | "\n", 895 | "else:\n", 896 | " print('Models output is empty! Check the code...')\n", 897 | " print('Shutting down...')" 898 | ] 899 | }, 900 | { 901 | "cell_type": "code", 902 | "execution_count": null, 903 | "metadata": {}, 904 | "outputs": [], 905 | "source": [ 906 | "len(out)" 907 | ] 908 | }, 909 | { 910 | "cell_type": "code", 911 | "execution_count": null, 912 | "metadata": {}, 913 | "outputs": [], 914 | "source": [ 915 | "out[-64:]" 916 | ] 917 | }, 918 | { 919 | "cell_type": "code", 920 | "execution_count": null, 921 | "metadata": { 922 | "cellView": "form", 923 | "gradient": { 924 | "id": "c8149763-2e09-4fcf-9823-85ca778b9e80", 925 | "kernelId": "" 926 | }, 927 | "id": "STtGgBsf4-tA" 928 | }, 929 | "outputs": [], 930 | "source": [ 931 | "#@title Auto-Regressive Generator\n", 932 | "\n", 933 | "#@markdown NOTE: You much generate a seed composition first or it is not going to start\n", 934 | "\n", 935 | "number_of_cycles_to_run = 5 #@param {type:\"slider\", min:1, max:50, step:1}\n", 936 | "number_of_prime_tokens = 128 #@param {type:\"slider\", min:64, max:256, step:64}\n", 937 | "\n", 938 | "print('=' * 70)\n", 939 | "print('Quantum Music Auto-Regressive Model Generator')\n", 940 | "print('=' * 70)\n", 941 | "print('Starting up...')\n", 942 | "print('=' * 70)\n", 943 | "print('Prime length:', len(out))\n", 944 | "print('Prime tokens:', number_of_prime_tokens)\n", 945 | "print('Prime input sequence', out[-8:])\n", 946 | "\n", 947 | "if len(out) != 0:\n", 948 | " print('=' * 70)\n", 949 | " out_all = []\n", 950 | " out_all.append(out)\n", 951 | " for i in tqdm(range(number_of_cycles_to_run)):\n", 952 | " rand_seq1 = model.generate(torch.Tensor(out[-number_of_prime_tokens:]), target_seq_length=1024, stop_token=256+512)\n", 953 | " out1 = rand_seq1[0].cpu().numpy().tolist()\n", 954 | " out_all.append(out1[number_of_prime_tokens:])\n", 955 | " out = 
out1[number_of_prime_tokens:]\n", 956 | " \n", 957 | " print(chr(10))\n", 958 | " print('=' * 70)\n", 959 | " print('Block number:', i+1)\n", 960 | " print('Composition length so far:', (i+1) * 1024, 'notes')\n", 961 | " print('=' * 70)\n", 962 | "\n", 963 | " print('Done!' * 70)\n", 964 | " print('Total blocks:', i+1)\n", 965 | " print('Final omposition length:', (i+1) * 1024, 'notes')\n", 966 | " print('=' * 70)\n", 967 | " \n", 968 | " out2 = []\n", 969 | " for o in out_all:\n", 970 | " out2.extend(o)\n", 971 | "\n", 972 | " if len(out2) != 0:\n", 973 | " song = []\n", 974 | " song = out2\n", 975 | " song_f = []\n", 976 | " time = 0\n", 977 | " pitch = 0\n", 978 | " duration = 0\n", 979 | " once = True\n", 980 | " for s in song:\n", 981 | " if s >= 0 and s < 256:\n", 982 | " time += s\n", 983 | " if s >= 256 and s < 512:\n", 984 | " pitch = s-256\n", 985 | " if s >= 512 and s < 256+512-4:\n", 986 | " duration = s-512\n", 987 | " song_f.append(['note', (abs(time))*10, (duration*10), 0, pitch, pitch ])\n", 988 | "\n", 989 | " detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,\n", 990 | " output_signature = 'Quantum Music', \n", 991 | " output_file_name = '/notebooks/Quantum-Music-Composition', \n", 992 | " track_name='Project Los Angeles', \n", 993 | " number_of_ticks_per_quarter=500)\n", 994 | "\n", 995 | " print('Done!')\n", 996 | "\n", 997 | " \n", 998 | "\n", 999 | "else:\n", 1000 | " print('=' * 70)\n", 1001 | " print('INPUT ERROR !!!')\n", 1002 | " print('Prime sequence is empty...')\n", 1003 | " print('Please generate prime sequence and retry')\n", 1004 | "\n", 1005 | "print('=' * 70)" 1006 | ] 1007 | }, 1008 | { 1009 | "cell_type": "markdown", 1010 | "metadata": { 1011 | "id": "YzCMd94Tu_gz" 1012 | }, 1013 | "source": [ 1014 | "# Congrats! You did it! :)" 1015 | ] 1016 | } 1017 | ], 1018 | "metadata": { 1019 | "accelerator": "GPU", 1020 | "colab": { 1021 | "collapsed_sections": [], 1022 | "machine_shape": "hm", 1023 | "name": "Optimus_VIRTUOSO_Multi_Instrumental_RGA_Edition.ipynb", 1024 | "private_outputs": true, 1025 | "provenance": [] 1026 | }, 1027 | "kernelspec": { 1028 | "display_name": "Python 3 (ipykernel)", 1029 | "language": "python", 1030 | "name": "python3" 1031 | }, 1032 | "language_info": { 1033 | "codemirror_mode": { 1034 | "name": "ipython", 1035 | "version": 3 1036 | }, 1037 | "file_extension": ".py", 1038 | "mimetype": "text/x-python", 1039 | "name": "python", 1040 | "nbconvert_exporter": "python", 1041 | "pygments_lexer": "ipython3", 1042 | "version": "3.8.12" 1043 | } 1044 | }, 1045 | "nbformat": 4, 1046 | "nbformat_minor": 4 1047 | } 1048 | --------------------------------------------------------------------------------