├── .DS_Store
├── img
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   └── 5.png
├── src
│   ├── .DS_Store
│   ├── ds_files.pt
│   ├── lib
│   │   ├── .DS_Store
│   │   ├── model
│   │   │   ├── __pycache__
│   │   │   │   ├── rpr.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── transformer.cpython-37.pyc
│   │   │   │   ├── transformer_bpe.cpython-37.pyc
│   │   │   │   └── positional_encoding.cpython-37.pyc
│   │   │   ├── positional_encoding.py
│   │   │   ├── transformer.py
│   │   │   └── rpr.py
│   │   ├── constants.py
│   │   ├── encoded_dataset.py
│   │   ├── inverse_power_with_warmup_sheduler.py
│   │   ├── colab_utils.py
│   │   ├── augmentations.py
│   │   ├── midi_processing.py
│   │   └── generation.py
│   ├── encoded_dataset
│   │   ├── .DS_Store
│   │   ├── pop
│   │   │   ├── pop_0_80612d513c2c998c87cca766a76be91a_0.pt
│   │   │   └── pop_0_80612d513c2c998c87cca766a76be91a_1.pt
│   │   ├── calm
│   │   │   ├── calm_0_7bfc0a94983dd5eb495ae0555efa4521_0.pt
│   │   │   ├── calm_0_7bfc0a94983dd5eb495ae0555efa4521_1.pt
│   │   │   └── calm_0_7bfc0a94983dd5eb495ae0555efa4521_2.pt
│   │   ├── jazz
│   │   │   ├── jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_0.pt
│   │   │   ├── jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_1.pt
│   │   │   └── jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_2.pt
│   │   └── classic
│   │       ├── classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_0.pt
│   │       ├── classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_1.pt
│   │       ├── classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_2.pt
│   │       └── classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_3.pt
│   ├── test_dataset
│   │   ├── pop
│   │   │   └── pop_0.mid
│   │   ├── calm
│   │   │   └── calm_0.mid
│   │   ├── jazz
│   │   │   └── jazz_0.mid
│   │   └── classic
│   │       └── classic_0.mid
│   ├── encode_dataset.py
│   ├── generate.py
│   ├── train.py
│   ├── Music_Composer_Demo_Colab_en.ipynb
│   └── Music_Composer_Demo_Colab_ru.ipynb
├── gpt2-rga
│   ├── README.md
│   └── [GPT2RGA] Quantum_Music.ipynb
├── README.md
└── README_ru.md
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/.DS_Store
--------------------------------------------------------------------------------
/img/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/img/1.png
--------------------------------------------------------------------------------
/img/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/img/2.png
--------------------------------------------------------------------------------
/img/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/img/3.png
--------------------------------------------------------------------------------
/img/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/img/4.png
--------------------------------------------------------------------------------
/img/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/img/5.png
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/.DS_Store
--------------------------------------------------------------------------------
/src/ds_files.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/ds_files.pt
--------------------------------------------------------------------------------
/src/lib/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/.DS_Store
--------------------------------------------------------------------------------
/src/encoded_dataset/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/.DS_Store
--------------------------------------------------------------------------------
/src/test_dataset/pop/pop_0.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/test_dataset/pop/pop_0.mid
--------------------------------------------------------------------------------
/src/test_dataset/calm/calm_0.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/test_dataset/calm/calm_0.mid
--------------------------------------------------------------------------------
/src/test_dataset/jazz/jazz_0.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/test_dataset/jazz/jazz_0.mid
--------------------------------------------------------------------------------
/src/test_dataset/classic/classic_0.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/test_dataset/classic/classic_0.mid
--------------------------------------------------------------------------------
/src/lib/model/__pycache__/rpr.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/model/__pycache__/rpr.cpython-37.pyc
--------------------------------------------------------------------------------
/src/lib/model/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/model/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/src/lib/model/__pycache__/transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/model/__pycache__/transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/src/lib/model/__pycache__/transformer_bpe.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/model/__pycache__/transformer_bpe.cpython-37.pyc
--------------------------------------------------------------------------------
/src/lib/model/__pycache__/positional_encoding.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/lib/model/__pycache__/positional_encoding.cpython-37.pyc
--------------------------------------------------------------------------------
/src/encoded_dataset/pop/pop_0_80612d513c2c998c87cca766a76be91a_0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/pop/pop_0_80612d513c2c998c87cca766a76be91a_0.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/pop/pop_0_80612d513c2c998c87cca766a76be91a_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/pop/pop_0_80612d513c2c998c87cca766a76be91a_1.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_0.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_1.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_2.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_2.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_0.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_1.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_2.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_2.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_0.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_1.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_2.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_2.pt
--------------------------------------------------------------------------------
/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_3.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/music-composer/main/src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_3.pt
--------------------------------------------------------------------------------
/gpt2-rga/README.md:
--------------------------------------------------------------------------------
1 | GPT2 RGA Version
2 |
3 | GPT2 is significantly superior to Transformer-XL, so it is best to use GPT2 with RGA for the best results.
4 |
5 | In particular, pay attention to continuations, because GPT2 is the only architecture that can handle them properly.
6 |
7 | Alex
8 |
--------------------------------------------------------------------------------
/src/lib/constants.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | RANGE_NOTE_ON = 128
4 | RANGE_NOTE_OFF = 128
5 | RANGE_VEL = 32
6 | RANGE_TIME_SHIFT = 100
7 |
8 | TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT
9 | TOKEN_PAD = TOKEN_END + 1
10 | VOCAB_SIZE = TOKEN_PAD + 1 + 4
11 |
12 | TORCH_FLOAT = torch.float32
13 | TORCH_INT = torch.int32
14 |
15 | TORCH_LABEL_TYPE = torch.long
16 |
17 | PREPEND_ZEROS_WIDTH = 4
18 |
19 |
--------------------------------------------------------------------------------
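The constants above fully determine the event vocabulary. Below is a minimal sketch of the resulting token-id layout; the values are derived from the definitions in this file, and the four genre tokens at 390-393 follow the convention documented in lib/generation.py and used in generate.py (they are not defined in constants.py itself).

```python
# Token-id layout implied by constants.py (derived values; genre tokens per generate.py).
RANGE_NOTE_ON, RANGE_NOTE_OFF, RANGE_VEL, RANGE_TIME_SHIFT = 128, 128, 32, 100

NOTE_ON      = range(0, RANGE_NOTE_ON)                                        # 0..127
NOTE_OFF     = range(NOTE_ON.stop, NOTE_ON.stop + RANGE_NOTE_OFF)             # 128..255
SET_VELOCITY = range(NOTE_OFF.stop, NOTE_OFF.stop + RANGE_VEL)                # 256..287
TIME_SHIFT   = range(SET_VELOCITY.stop, SET_VELOCITY.stop + RANGE_TIME_SHIFT) # 288..387

TOKEN_END  = TIME_SHIFT.stop                      # 388
TOKEN_PAD  = TOKEN_END + 1                        # 389
GENRE      = range(TOKEN_PAD + 1, TOKEN_PAD + 5)  # 390..393: classic, jazz, calm, pop
VOCAB_SIZE = GENRE.stop                           # 394
```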
/src/lib/model/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 | # PositionalEncoding
6 | # Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
7 | class PositionalEncoding(nn.Module):
8 |
9 | def __init__(self, d_model, dropout=0.1, max_len=5000):
10 | super(PositionalEncoding, self).__init__()
11 | self.dropout = nn.Dropout(p=dropout)
12 |
13 | pe = torch.zeros(max_len, d_model)
14 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
16 | pe[:, 0::2] = torch.sin(position * div_term)
17 | pe[:, 1::2] = torch.cos(position * div_term)
18 | pe = pe.unsqueeze(0).transpose(0, 1)
19 | self.register_buffer('pe', pe)
20 |
21 | def forward(self, x):
22 | x = x + self.pe[:x.size(0), :]
23 | return self.dropout(x)
24 |
--------------------------------------------------------------------------------
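A minimal shape sketch, assuming the module is importable as lib.model.positional_encoding and run from src/; the sizes below are arbitrary. It only illustrates the (seq_len, batch, d_model) convention that transformer.py uses when calling this layer.

```python
import torch
from lib.model.positional_encoding import PositionalEncoding

pe = PositionalEncoding(d_model=512, dropout=0.1, max_len=2048)
x = torch.zeros(128, 4, 512)   # (seq_len, batch_size, d_model), as in transformer.py
print(pe(x).shape)             # torch.Size([128, 4, 512])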
/src/lib/encoded_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class EncodedDataset(Dataset):
7 | """
8 | Dataset class for training and evaluating the model.
9 |
10 | Parameters
11 | ----------
12 | ds_files : str
13 | path to the file 'ds_files.pt'. The file contains the list of paths to encoded sequences (the samples of the dataset).
14 | prefix_path : str
15 | prefix_path is prepended to the paths in 'ds_files.pt'. Sometimes used for convenience.
16 | transform : MusicAugmentations
17 | on-the-fly augmentations applied to the sequences.
18 | """
19 | def __init__(self, ds_files, prefix_path='', transform=None):
20 | self.transform = transform
21 | self.files = torch.load(ds_files)
22 | self.prefix_path = prefix_path
23 | self.genre2id = {'classic':0, 'jazz':1, 'calm':2, 'pop':3}
24 | self.genre = [self.genre2id.get(f.split('/')[1], 0) for f in self.files] # genre is path component 1, e.g. 'encoded_dataset/GENRE/xxx.pt'
25 |
26 | def __len__(self):
27 | return len(self.files)
28 |
29 | def __getitem__(self, idx):
30 | x = torch.load(self.prefix_path + self.files[idx])
31 | if self.transform:
32 | x = torch.from_numpy(self.transform(x))
33 | genre = self.genre[idx]
34 | return x, genre, idx
35 |
--------------------------------------------------------------------------------
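A minimal usage sketch, assuming encode_dataset.py has already produced ./ds_files.pt and the encoded_dataset/ folder, and that it is run from src/ so the lib package is importable.

```python
import torch
from torch.utils.data import DataLoader
from lib.encoded_dataset import EncodedDataset
from lib.augmentations import MusicAugmentations

dataset = EncodedDataset('ds_files.pt', transform=MusicAugmentations())
loader = DataLoader(dataset, batch_size=2, shuffle=True)

x, genre, idx = next(iter(loader))
print(x.shape, genre, idx)   # (2, MAX_LEN) token batch, genre ids, dataset indices
```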
/src/lib/inverse_power_with_warmup_sheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class InversePowerWithWarmupLRScheduler(torch.optim.lr_scheduler._LRScheduler):
5 | """
6 | Warm up the learning rate until `warmup_steps`, then apply an inverse power function.
7 | It is a more flexible version of the inverse square-root schedule.
8 | base formula after warmup: 1 / (i + shift) ** power
9 |
10 | Parameters
11 | ----------
12 | optimizer : torch.optim
13 | PyTorch optimizer for model weights updates.
14 | peak_lr : float
15 | maximum learning rate at peak.
16 | warmup_steps : int
17 | Number of warmup steps.
18 | power : float
19 | power of the LR function. Set to 0.5 for the inverse sqrt schedule.
20 | shift : int
21 | shift helps control the duration of decay.
22 | last_epoch : int
23 | last_epoch is treated as last_step in this scheduler.
24 | """
25 | def __init__(self, optimizer, peak_lr, warmup_steps, power=0.5, shift=0, last_epoch=-1):
26 | self.peak_lr = peak_lr
27 | self.warmup_steps = warmup_steps
28 | self.power = power
29 | self.shift = shift
30 | self.warmup_rate = self.peak_lr / (self.warmup_steps)
31 | self.decay_factor = self.peak_lr * (self.warmup_steps + self.shift) ** self.power
32 | super(InversePowerWithWarmupLRScheduler, self).__init__(optimizer, last_epoch=last_epoch)
33 |
34 | def get_lr(self):
35 | i = self.last_epoch
36 | if i < self.warmup_steps:
37 | lr = self.warmup_rate * (i+1)
38 | else:
39 | lr = self.decay_factor / (i + self.shift) ** self.power
40 |
41 | lrs = [lr for _ in self.optimizer.param_groups]
42 | return lrs
--------------------------------------------------------------------------------
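A toy sketch of the resulting LR curve with the same hyperparameters as train.py (note that the module file in this repo is named inverse_power_with_warmup_sheduler.py, with the spelling as listed in the tree).

```python
import torch
from lib.inverse_power_with_warmup_sheduler import InversePowerWithWarmupLRScheduler

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-4)
sched = InversePowerWithWarmupLRScheduler(opt, peak_lr=1e-4, warmup_steps=4000,
                                          power=2, shift=100000)

for step in range(20001):
    opt.step()
    sched.step()
    if step % 5000 == 0:
        print(step, opt.param_groups[0]['lr'])   # warmup to 1e-4, then slow inverse-power decay
```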
/src/lib/colab_utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import torch
3 | import base64
4 | import hashlib
5 | import ipywidgets
6 | import numpy as np
7 | import lib.midi_processing
8 | from midi2audio import FluidSynth
9 | from torch.utils.data import Dataset
10 | from IPython.display import Audio, display, FileLink, HTML
11 |
12 |
13 | id2genre = {0:'classic', 1:'jazz', 2:'calm', 3:'pop'}
14 | rugenre = {'classic': 'Классика', 'jazz': 'Джаз', 'calm': 'Эмбиент', 'pop': 'Поп'}
15 | genre2id = dict([[x[1],x[0]] for x in id2genre.items()])
16 | tuned_params = {
17 | 0: 1.1,
18 | 1: 0.95,
19 | 2: 0.9,
20 | 3: 1.0
21 | }
22 |
23 | def decode_and_write(generated, primer, genre, out_dir):
24 | '''Decodes midi files from event-based format and writes them to disk'''
25 | if len(glob.glob(out_dir + '/*.mid')) != 0:
26 | ids = [int(path.split('_')[-2]) for path in glob.glob(out_dir + '/*.mid')]
27 | start_from = max(ids)
28 | else:
29 | start_from = 0
30 |
31 | for i, (gen, g) in enumerate(zip(generated, genre)):
32 | midi = lib.midi_processing.decode(gen)
33 | midi.write(f'{out_dir}/gen_{i + start_from:>02}_{id2genre[g]}.mid')
34 |
35 | def convert_midi_to_wav(midi_path):
36 | '''Converts MIDI to WAV format for listening in Colab'''
37 | FluidSynth("font.sf2").midi_to_audio(midi_path, midi_path.replace('.mid', '.wav'))
38 |
39 | class DownloadButton(ipywidgets.Button):
40 | """Download button with dynamic content
41 |
42 | The content is generated using a callback when the button is clicked.
43 | """
44 |
45 | def __init__(self, filename: str, **kwargs):
46 | super(DownloadButton, self).__init__(**kwargs)
47 | self.filename = filename
48 | self.on_click(self.__on_click)
49 |
50 | def __on_click(self, b):
51 | with open(self.filename, 'rb') as f:
52 | b64 = base64.b64encode(f.read())
53 | payload = b64.decode()
54 | digest = hashlib.md5(self.filename.encode()).hexdigest() # bypass browser cache
55 | id = f'dl_{digest}'
56 |
57 | display(HTML(f"""
58 |
59 |
60 |
61 |
62 |
63 |
68 |
69 |
70 |
71 | """))
72 |
--------------------------------------------------------------------------------
/src/lib/augmentations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lib.midi_processing import RANGES  # absolute import via the lib package, matching the other modules
3 |
4 | RANGES_SUM = np.cumsum(RANGES)
5 |
6 | class MusicAugmentations:
7 | def __init__(self, transpose=(-3,3), time_stretch=(0.95,0.975,1.0,1.025,1.05)):
8 | """
9 | Class for applying random transpose and time_stretch augmentations for encoded sequences.
10 |
11 | Parameters
12 | ----------
13 | transpose : tuple(min, max)
14 | inclusive range of transposition in semitones.
15 | time_stretch : list
16 | list of time_stretch multipliers to sample from.
17 | """
18 | self.transpose = range(transpose[0], transpose[1]+1)
19 | self.time_stretch = time_stretch
20 |
21 | def __call__(self, encoded):
22 | """encoded: list or 1D np.ndarray"""
23 | transpose = np.random.choice(self.transpose)
24 | # time_stretch = np.random.uniform(*self.time_stretch)
25 | time_stretch = np.random.choice(self.time_stretch)
26 | return augment(encoded, transpose, time_stretch)
27 |
28 | def augment(encoded, transpose, time_stretch):
29 | """
30 | Applies transpose and time_stretch augmentations to an encoded sequence. Operates in place.
31 |
32 | Parameters
33 | ----------
34 | encoded : np.ndarray or list
35 | encoded sequence (input for model).
36 | transpose : int
37 | transposition offset in semitones.
38 | time_stretch : float
39 | time_stretch multiplier.
40 |
41 | Returns
42 | -------
43 | encoded : np.array or list
44 | augmented sequence.
45 | """
46 | for i,ev in enumerate(encoded):
47 | if ev < RANGES_SUM[0]:
48 | # NOTE_ON
49 | encoded[i] = min(RANGES_SUM[0]-1, max(0, ev+transpose))
50 | elif ev < RANGES_SUM[1]:
51 | # NOTE_OFF
52 | encoded[i] = min(RANGES_SUM[1]-1, max(RANGES_SUM[0], ev+transpose))
53 | elif ev < RANGES_SUM[2]:
54 | # SET_VELOCITY
55 | pass
56 | elif ev < RANGES_SUM[3] and time_stretch != 1.0:
57 | # TIME_SHIFT
58 | t = ev - RANGES_SUM[2] + 1 # since 0 = 10ms
59 | t = max(min(RANGES[3], int(round(t*time_stretch))), 1)
60 | encoded[i] = t + RANGES_SUM[2] - 1
61 | else:
62 | continue
63 | return encoded
--------------------------------------------------------------------------------
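A toy example of the augmentation on a hand-built sequence; the token values follow RANGES = [128, 128, 32, 100] (NOTE_ON 0-127, NOTE_OFF 128-255, SET_VELOCITY 256-287, TIME_SHIFT 288-387). A sketch, assuming it is run from src/ so lib is importable.

```python
import numpy as np
from lib.augmentations import augment

seq = np.array([
    256 + 20,   # SET_VELOCITY
    60,         # NOTE_ON, pitch 60
    288 + 39,   # TIME_SHIFT of 40 steps (400 ms)
    128 + 60,   # NOTE_OFF, pitch 60
])
out = augment(seq.copy(), transpose=2, time_stretch=1.05)
print(out.tolist())   # [276, 62, 329, 190]: pitches up 2 semitones, time shift -> 42 steps (420 ms)
```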
/src/encode_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import joblib
4 | import hashlib
5 | import pretty_midi
6 | import numpy as np
7 | from tqdm import tqdm
8 | from pathlib import Path
9 | from concurrent.futures import ProcessPoolExecutor
10 |
11 | from lib import constants
12 | from lib import midi_processing
13 |
14 | DATA_DIR = 'test_dataset'
15 | OUTPUT_DIR = 'encoded_dataset'
16 | DS_FILE_PATH = './ds_files.pt' # path where ds_files.pt will be created
17 |
18 | GENRES = ['classic', 'jazz', 'calm', 'pop']
19 | MAX_LEN = 2048
20 |
21 | print('creating dirs...')
22 | [os.makedirs(OUTPUT_DIR+'/'+g, exist_ok=True) for g in GENRES]
23 |
24 | print('collecting *.mid files...')
25 | FILES = list(map(str, Path(DATA_DIR).rglob('*.mid')))
26 |
27 | def encode_fn(i):
28 | """wrapper for loading i-th midi-file, encoding, padding and saving encoded tensor on disk"""
29 | file = FILES[i]
30 | max_len = MAX_LEN
31 |
32 | path, fname = os.path.split(file)
33 | try:
34 | midi = pretty_midi.PrettyMIDI(file)
35 | genre = path.split('/')[1] # take GENRE from 'test_dataset/GENRE/xxx.mid'
36 | except Exception:
37 | print(f'{i} not loaded')
38 | return -1
39 |
40 | assert genre in GENRES, f'{genre} is not in {GENRES}'
41 |
42 | fname, ext = os.path.splitext(fname)
43 | h = hashlib.md5(file.encode()).hexdigest()
44 | save_name = f'{OUTPUT_DIR}/{genre}/{fname}_{h}'
45 |
46 | events = midi_processing.encode(midi, use_piano_range=True)
47 | events = np.array(events)
48 | split_idxs = np.cumsum([max_len]*(events.shape[0]//max_len))
49 | splits = np.split(events, split_idxs, axis=0)
50 | n_last = splits[-1].shape[0]
51 | if n_last < 256:
52 | splits.pop(-1)
53 | drop_last = 1
54 | else:
55 | drop_last = 0
56 |
57 | for i, split in enumerate(splits):
58 | keep_idxs = midi_processing.filter_bad_note_offs(split)
59 | split = split[keep_idxs]
60 | eos_idx = min(max_len - 1, len(split))
61 | split = np.pad(split, [[0,max_len - len(split)]])
62 | split[eos_idx] = constants.TOKEN_END
63 | try:
64 | torch.save(split, f'{save_name}_{i}.pt')
65 | except OSError: # if fname is too long
66 | save_name = f'{OUTPUT_DIR}/{genre}/{h}'
67 | torch.save(split, f'{save_name}_{i}.pt')
68 | return drop_last
69 |
70 | cpu_count = joblib.cpu_count()
71 | print(f'starting encoding in {cpu_count} processes...')
72 | with ProcessPoolExecutor(cpu_count) as pool:
73 | x = list(tqdm(pool.map(encode_fn, range(len(FILES))), position=0, total=len(FILES)))
74 |
75 | print('collecting encoded (*.pt) files...')
76 | ds_files = list(map(str, Path(OUTPUT_DIR).rglob('*.pt')))
77 | print('total encoded files:', len(ds_files))
78 |
79 | torch.save(ds_files, DS_FILE_PATH)
80 | print('ds_files.pt saved to', os.path.abspath(DS_FILE_PATH))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Music Composer
2 | This repository is dedicated to synthesizing symbolic music in MIDI format using the Music Transformer model (103M parameters). In this repository you can find a demo notebook for generation on a GPU Google Colab instance, as well as the data preparation and model training code.
3 |
4 | ## Table of Contents
5 | 1. [Demo notebook](#demo-notebook)
6 | 2. [Model code](#model-code)
7 | 3. [Data](#data)
8 | 4. [Training](#training)
9 |
10 |
11 | ## Demo notebook
12 |
13 | The Jupyter notebook can be opened in Colab by clicking the button:
14 |
15 | [](https://colab.research.google.com/github/sberbank-ai/music-composer/blob/main/src/Music_Composer_Demo_Colab_en.ipynb)
16 |
17 | It sets up the environment and loads the code and weights for synthesis. Generation parameters are set in the generation control panel, and you can listen to and download the results in the last cell.
18 |
19 | ❗ Make sure a GPU instance is used at startup. It is possible to synthesize on a CPU, but it takes significantly more time.
20 |
21 | ## Model code
22 | Located in this [folder](https://github.com/sberbank-ai/music-composer/tree/main/src/lib/model).
23 | It consists of three main parts:
24 | - Positional encoding - the usual positional encoding for transformer models
25 | - Relative Positional Representation - a module implementing Relative Attention
26 | - Transformer - the transformer model itself
27 |
28 | The model code and relative attention are taken from this [repository](https://github.com/gwinndr/MusicTransformer-Pytorch).
29 |
30 | ## Data
31 | To demonstrate the encoding script, we provide several MIDI files from our training sample. They are located in the src/test_dataset folder and are split into folders by genre. Each folder contains one file to check. You can start preparing event-based versions of these files with the command:
32 | ```python encode_dataset.py```
33 |
34 | The folder with the source MIDI and the folder for the results are set inside the script through the variables `DATA_DIR` and `OUTPUT_DIR`. The dataset file with file paths will be created at `DS_FILE_PATH`. The genre list is specified using `GENRES`, and the maximum record length in event tokens is `MAX_LEN`.
35 |
36 | For demonstration, we also provide the output of this command in the encoded_dataset folder. It contains tensors with MIDI converted to the event-based format. They can be loaded using the standard `torch.load(file_path)`.
37 | The following publicly available MIDI datasets can be used for training:
38 | [MAESTRO Dataset](https://magenta.tensorflow.org/datasets/maestro)
39 | [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/)
40 | [GiantMIDI-Piano Dataset](https://github.com/bytedance/GiantMIDI-Piano)
41 |
42 | There is another way to get MIDI files - transcribing audio recordings of music. An approach like [Onsets and Frames](https://magenta.tensorflow.org/onsets-frames) can help with this.
43 | As music for transcription, you can use, for example, the [Free Music Archive](https://github.com/mdeff/fma).
44 | ❗ Transcription may require significant resources, but this is exactly what makes it possible to get around the main limitation of current symbolic music generation models - the absence of large corpora with notes.
45 | ❗ After transcription, it is recommended to analyze the results and filter out bad recordings.
46 | ## Training
47 | A script for training a model on prepared data can be run using:
48 | ```python train.py```
49 | Training parameters are set inside the script in the params variable. A description of each of the parameters will be given later in this section.
50 |
--------------------------------------------------------------------------------
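A minimal sketch complementing the Data section above: loading one of the bundled demo tensors from the repository root with `torch.load` (the file name is one of the encoded splits shipped in src/encoded_dataset).

```python
import torch

x = torch.load('src/encoded_dataset/pop/pop_0_80612d513c2c998c87cca766a76be91a_0.pt')
print(type(x), x.shape)   # numpy.ndarray of event tokens, zero-padded up to MAX_LEN (2048)
```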
/README_ru.md:
--------------------------------------------------------------------------------
1 | # Music Composer
2 | This repository is dedicated to synthesizing symbolic music in MIDI format using the Music Transformer model. In this repository you can find a demo notebook for generation on a GPU Google Colab instance, as well as the data preparation and model training code.
3 |
4 | ## Table of Contents
5 | 1. [Demo notebook](#demo-notebook)
6 | 2. [Model code](#model-code)
7 | 3. [Data](#data)
8 | 4. [Training](#training)
9 |
10 |
11 | ## Demo notebook
12 |
13 | The Jupyter notebook can be opened in Colab by clicking the button:
14 |
15 | [](https://colab.research.google.com/github/sberbank-ai/music-composer/blob/main/src/Music_Composer_Demo_Colab_ru.ipynb)
16 |
17 | It sets up the environment and loads the code and weights for synthesis. Generation parameters are set in the generation control panel, and you can listen to and download the results in the last cell.
18 |
19 | ❗ Make sure a GPU instance is used at startup. It is possible to synthesize on a CPU, but it takes significantly more time.
20 |
21 | ## Model code
22 | Located in this [folder](https://github.com/sberbank-ai/music-composer/tree/main/src/lib/model).
23 | It consists of three main parts:
24 | - Positional encoding - the usual positional encoding for transformer models
25 | - Relative Positional Representation - a module implementing Relative Attention
26 | - Transformer - the transformer model itself
27 |
28 | The model code and relative attention are taken from this [repository](https://github.com/gwinndr/MusicTransformer-Pytorch).
29 |
30 | ## Data
31 | To demonstrate the encoding script, we provide several MIDI files from our training sample. They are located in the src/test_dataset folder and are split into folders by genre. Each folder contains one file to check. You can start preparing event-based versions of these files with the command:
32 | ```python encode_dataset.py```
33 |
34 | The folder with the source MIDI and the folder for the results are set inside the script through the variables `DATA_DIR` and `OUTPUT_DIR`. The dataset file with file paths will be created at `DS_FILE_PATH`. The genre list is specified using `GENRES`, and the maximum record length in event tokens is `MAX_LEN`.
35 |
36 | For demonstration, we also provide the output of this command in the encoded_dataset folder. It contains tensors with MIDI converted to the event-based format. They can be loaded using the standard `torch.load(file_path)`.
37 | The following publicly available MIDI datasets can be used for training:
38 | [MAESTRO Dataset](https://magenta.tensorflow.org/datasets/maestro)
39 | [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/)
40 | [GiantMIDI-Piano Dataset](https://github.com/bytedance/GiantMIDI-Piano)
41 |
42 | There is another way to get MIDI files - transcribing audio recordings of music. An approach like [Onsets and Frames](https://magenta.tensorflow.org/onsets-frames) can help with this.
43 | As music for transcription, you can use, for example, the [Free Music Archive](https://github.com/mdeff/fma).
44 | ❗ Transcription may require significant resources, but this is exactly what makes it possible to get around the main limitation of current symbolic music generation models - the absence of large corpora with notes.
45 | ❗ After transcription, it is recommended to analyze the results and filter out bad recordings.
46 | ## Training
47 | A script for training a model on prepared data can be run using:
48 | ```python train.py```
49 | Training parameters are set inside the script in the params variable. A description of each parameter will be given later in this section.
50 |
--------------------------------------------------------------------------------
/src/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import argparse
5 | import pretty_midi
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from lib import constants
10 | from lib import midi_processing
11 | from lib import generation
12 | from lib.midi_processing import PIANO_RANGE
13 | from lib.model.transformer import MusicTransformer
14 |
15 |
16 | def decode_and_write(generated, primer, genre, out_dir):
17 | '''Decodes event-based format to midi and writes resulting file to disk'''
18 | for i, (gen, g) in enumerate(zip(generated, genre)):
19 | midi = midi_processing.decode(gen)
20 | midi.write(f'{out_dir}/gen_{i:>02}_{id2genre[g]}.mid')
21 |
22 |
23 | id2genre = {0:'classic',1:'jazz',2:'calm',3:'pop'}
24 | genre2id = dict([[x[1],x[0]] for x in id2genre.items()])
25 | tuned_params = {
26 | 0: 1.1,
27 | 1: 0.95,
28 | 2: 0.9,
29 | 3: 1.0
30 | }
31 |
32 |
33 | if __name__ == '__main__':
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--genre')
36 | parser.add_argument('--target_seq_length', default=512, type=int)
37 | parser.add_argument('--temperature', default=None, type=float)
38 | parser.add_argument('--topk', default=40, type=int)
39 | parser.add_argument('--topp', default=0.99, type=float)
40 | parser.add_argument('--topp_temperature', default=1.0, type=float)
41 | parser.add_argument('--at_least_k', default=1, type=int)
42 | parser.add_argument('--use_rp', action='store_true')
43 | parser.add_argument('--rp_penalty', default=0.05, type=float)
44 | parser.add_argument('--rp_restore_speed', default=0.7, type=float)
45 | parser.add_argument('--seed', default=None, type=int)
46 | parser.add_argument('--device', default='cuda:0')
47 | parser.add_argument('--keep_bad_generations', action='store_true')
48 | parser.add_argument('--out_dir', default=None)
49 | parser.add_argument('--load_path', default=None)
50 | parser.add_argument('--batch_size', default=8, type=int)
51 | args = parser.parse_args()
52 |
53 |
54 | try:
55 | genre_id = genre2id[args.genre]
56 | except KeyError:
57 | raise KeyError("Invalid genre name. Use one of ['classic', 'jazz', 'calm', 'pop']")
58 |
59 | load_path = args.load_path or '../checkpoints/model_big_v3_378k.pt'
60 | out_dir = args.out_dir or ('generated_' + time.strftime('%d-%m-%Y_%H-%M-%S'))
61 | batch_size = args.batch_size
62 | device = torch.device(args.device)
63 | remove_bad_generations = not args.keep_bad_generations
64 |
65 | default_params = dict(
66 | target_seq_length = 512,
67 | temperature = tuned_params[genre_id],
68 | topk = 40,
69 | topp = 0.99,
70 | topp_temperature = 1.0,
71 | at_least_k = 1,
72 | use_rp = False,
73 | rp_penalty = 0.05,
74 | rp_restore_speed = 0.7,
75 | seed = None,
76 | )
77 |
78 | params = {k:args.__dict__[k] if args.__dict__[k] else default_params[k] for k in default_params}
79 |
80 | os.makedirs(out_dir, exist_ok=True)
81 |
82 | # init model
83 | print('loading model...')
84 | model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval()
85 | model.load_state_dict(torch.load(load_path, map_location=device))
86 |
87 | # add information about genre (first token)
88 | primer_genre = np.repeat([genre_id], batch_size)
89 | primer = torch.tensor(primer_genre)[:,None] + constants.VOCAB_SIZE - 4
90 |
91 | print('generating to:', os.path.abspath(out_dir))
92 | generated = generation.generate(model, primer, **params)
93 | generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations)
94 |
95 | decode_and_write(generated, primer, primer_genre, out_dir)
--------------------------------------------------------------------------------
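A small sketch of how the genre-conditioning primer above is built; it mirrors the script's own logic, and the genre token ids 390-393 follow the mapping documented in lib/generation.py.

```python
import numpy as np
import torch
from lib import constants

genre2id = {'classic': 0, 'jazz': 1, 'calm': 2, 'pop': 3}
batch_size, genre_id = 4, genre2id['jazz']

primer_genre = np.repeat([genre_id], batch_size)
primer = torch.tensor(primer_genre)[:, None] + constants.VOCAB_SIZE - 4
print(primer.squeeze(1).tolist())   # [391, 391, 391, 391] -> one genre token per sample
```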
/src/lib/midi_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pretty_midi
3 |
4 |
5 | NOTE_ON = 0
6 | NOTE_OFF = 1
7 | SET_VELOCITY = 2
8 | TIME_SHIFT = 3
9 |
10 | MAX_TIME_SHIFT = 1.0
11 | TIME_SHIFT_STEP = 0.01
12 | RANGES = [128,128,32,100]
13 |
14 | PIANO_RANGE = [21,96] # 76 piano keys
15 |
16 |
17 | def encode(midi, use_piano_range=True):
18 | """
19 | Encodes midi to event-based sequences for MusicTransformer.
20 |
21 | Parameters
22 | ----------
23 | midi : PrettyMIDI object
24 | MIDI to encode.
25 | use_piano_range : bool
26 | if True, pitches outside PIANO_RANGE (the classical piano range) are skipped.
27 |
28 | Returns
29 | -------
30 | encoded : list
31 | encoded sequence of event tokens.
32 | """
33 | events = get_events(midi, use_piano_range=use_piano_range)
34 | if len(events) == 0:
35 | return []
36 | quantize_(events)
37 | add_time_shifts(events)
38 | encoded = encode_events(events)
39 | return encoded
40 |
41 |
42 | def decode(encoded):
43 | """
44 | Decode event-based encoded sequence into MIDI object.
45 |
46 | Parameters
47 | ----------
48 | encoded : np.array or list
49 | encoded sequence to decode.
50 |
51 | Returns
52 | -------
53 | midi_out: PrettyMIDI object
54 | decoded MIDI.
55 | """
56 | midi_out = pretty_midi.PrettyMIDI()
57 | midi_out.instruments.append(pretty_midi.Instrument(0, name='piano'))
58 | notes = midi_out.instruments[0].notes
59 |
60 | notes_tmp = {} # pitch: [vel, start, end]
61 | cur_time = 0
62 | cur_velocity = 100
63 | for ev in encoded:
64 | if ev < RANGES[0]:
65 | # NOTE_ON
66 | pitch = ev
67 | if notes_tmp.get(pitch) is None:
68 | notes_tmp[pitch] = [cur_velocity, cur_time]
69 | elif ev >= RANGES[0] and ev < sum(RANGES[:2]):
70 | # NOTE_OFF
71 | pitch = ev - RANGES[0]
72 | note = notes_tmp.get(pitch)
73 | if note is not None: # check for overlaps (first-OFF mode)
74 | notes.append(pretty_midi.Note(note[0], pitch, note[1], cur_time))
75 | notes_tmp.pop(pitch)
76 | elif ev >= sum(RANGES[:2]) and ev < sum(RANGES[:3]):
77 | # SET_VELOCITY
78 | cur_velocity = max(1,(ev - sum(RANGES[:2]))*128//RANGES[2])
79 | elif ev >= sum(RANGES[:3]) and ev < sum(RANGES[:]):
80 | # TIME_SHIFT
81 | cur_time += (ev - sum(RANGES[:3]) + 1)*TIME_SHIFT_STEP
82 | else:
83 | continue
84 |
85 | for pitch, note in notes_tmp.items():
86 | if note[1] != cur_time:
87 | notes.append(pretty_midi.Note(note[0], pitch, note[1], cur_time))
88 |
89 | return midi_out
90 |
91 |
92 | def round_step(x, step=0.01):
93 | return round(x/step)*step
94 |
95 |
96 | def get_events(midi, use_piano_range=False):
97 | # helper function used in encode()
98 | # time, type, value
99 | events = []
100 | for inst in midi.instruments:
101 | if inst.is_drum:
102 | continue
103 | for note in inst.notes:
104 | if use_piano_range and not (PIANO_RANGE[0] <= note.pitch <= PIANO_RANGE[1]):
105 | continue
106 | start = note.start
107 | end = note.end
108 | events.append([start, SET_VELOCITY, note.velocity])
109 | events.append([start, NOTE_ON, note.pitch])
110 | events.append([end, NOTE_OFF, note.pitch])
111 | events = sorted(events, key=lambda x: x[0])
112 | return events
113 |
114 |
115 | def quantize_(events):
116 | for ev in events:
117 | ev[0] = round_step(ev[0])
118 |
119 |
120 | def add_time_shifts(events):
121 | # populate time_shifts, helper function used in encode()
122 | times = np.array(list(zip(*events)))[0]
123 | diff = np.diff(times, prepend=0)
124 | idxs = diff.nonzero()[0]
125 | for i in reversed(idxs):
126 | if i == 0:
127 | continue
128 | t0 = events[i-1][0] # if i != 0 else 0
129 | t1 = events[i][0]
130 | dt = t1-t0
131 | events.insert(i, [t0, TIME_SHIFT, dt])
132 |
133 |
134 | def encode_events(events):
135 | # helper function used in encode()
136 | out = []
137 | types = []
138 | for time, typ, value in events:
139 | offset = sum(RANGES[:typ])
140 |
141 | if typ == SET_VELOCITY:
142 | value = value*RANGES[SET_VELOCITY]//128
143 | out.append(offset+value)
144 | types.append(typ)
145 |
146 | elif typ == TIME_SHIFT:
147 | dt = value
148 | n = RANGES[TIME_SHIFT]
149 | enc = lambda x: int(x*n)-1
150 | for _ in range(int(dt//MAX_TIME_SHIFT)):
151 | out.append(offset+enc(MAX_TIME_SHIFT))
152 | types.append(typ)
153 | r = round_step(dt%MAX_TIME_SHIFT, TIME_SHIFT_STEP)
154 | if r > 0:
155 | out.append(offset+enc(r))
156 | types.append(typ)
157 |
158 | else:
159 | out.append(offset+value)
160 | types.append(typ)
161 |
162 | return out
163 |
164 |
165 | RANGES_SUM = np.cumsum(RANGES)
166 |
167 |
168 | def get_type(ev):
169 | if ev < RANGES_SUM[0]:
170 | # NOTE_ON
171 | return 0
172 | elif ev < RANGES_SUM[1]:
173 | # NOTE_OFF
174 | return 1
175 | elif ev < RANGES_SUM[2]:
176 | # VEL
177 | return 2
178 | elif ev < RANGES_SUM[3]:
179 | # TS
180 | return 3
181 | else:
182 | return -1
183 |
184 |
185 | def filter_bad_note_offs(events):
186 | """Clear NOTE_OFF events for which the corresponding NOTE_ON event is missing."""
187 | notes_down = {} # pitch: 1
188 | keep_idxs = set(range(len(events)))
189 |
190 | for i,ev in enumerate(events):
191 | typ = get_type(ev)
192 |
193 | if typ == NOTE_ON:
194 | pitch = ev
195 | notes_down[pitch] = 1
196 | if typ == NOTE_OFF:
197 | pitch = ev-128
198 | if notes_down.get(pitch) is None:
199 | # if NOTE_OFF without NOTE_ON, then remove the event
200 | keep_idxs.remove(i)
201 | else:
202 | notes_down.pop(pitch)
203 |
204 | return sorted(keep_idxs)  # sort so indexing with the result preserves the original event order
--------------------------------------------------------------------------------
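A round-trip sketch: build a one-note PrettyMIDI object, encode it to event tokens, and decode it back (the token values follow RANGES above). Assumes it is run from src/ so lib is importable.

```python
import pretty_midi
from lib import midi_processing

midi = pretty_midi.PrettyMIDI()
inst = pretty_midi.Instrument(0, name='piano')
inst.notes.append(pretty_midi.Note(velocity=100, pitch=60, start=0.0, end=0.5))
midi.instruments.append(inst)

events = midi_processing.encode(midi, use_piano_range=True)
print(events)                         # [281, 60, 337, 188]: SET_VELOCITY, NOTE_ON 60, TIME_SHIFT 0.5 s, NOTE_OFF 60
restored = midi_processing.decode(events)
print(restored.instruments[0].notes)  # one Note, pitch 60, 0.0 -> 0.5 s
```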
/src/lib/model/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.modules.normalization import LayerNorm
4 | import random
5 |
6 | from lib.constants import *
7 | from .positional_encoding import PositionalEncoding
8 | from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR
9 |
10 |
11 | # MusicTransformer
12 | class MusicTransformer(nn.Module):
13 | """
14 | ----------
15 | Author: Damon Gwinn
16 | ----------
17 | Music Transformer reproduction from https://arxiv.org/abs/1809.04281. Arguments allow for
18 | tweaking the transformer architecture (https://arxiv.org/abs/1706.03762) and the rpr argument
19 | toggles Relative Position Representations (RPR - https://arxiv.org/abs/1803.02155).
20 |
21 | Supports training and generation using Pytorch's nn.Transformer class with dummy decoder to
22 | make a decoder-only transformer architecture
23 |
24 | For RPR support, there is modified Pytorch 1.2.0 code in rpr.py. Modified source will be
25 | kept up to date with Pytorch revisions only as necessary.
26 | ----------
27 | """
28 |
29 | def __init__(self, device, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024, dropout=0.1,
30 | max_sequence=2048, rpr=False, vocab_size=VOCAB_SIZE, cond_vocab_size=None, reduce_qk=False):
31 | super(MusicTransformer, self).__init__()
32 |
33 | self.vocab_size = vocab_size
34 | self.cond_vocab_size = cond_vocab_size
35 | self.device = device
36 | self.dummy = DummyDecoder()
37 |
38 | self.nlayers = n_layers
39 | self.nhead = num_heads
40 | self.d_model = d_model
41 | self.d_ff = dim_feedforward
42 | self.dropout = dropout
43 | self.max_seq = max_sequence
44 | self.rpr = rpr
45 | self.reduce_qk = reduce_qk
46 |
47 | # Input embedding
48 | self.embedding = nn.Embedding(self.vocab_size, self.d_model)
49 | if self.cond_vocab_size is not None:
50 | self.cond_embedding = nn.Embedding(self.cond_vocab_size, self.d_model)
51 | else:
52 | self.cond_embedding = None
53 |
54 | # Positional encoding
55 | self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq)
56 |
57 | # Base transformer
58 | if(not self.rpr):
59 | # To make a decoder-only transformer we need to use masked encoder layers
60 | # Dummy decoder to essentially just return the encoder output
61 | self.transformer = nn.Transformer(
62 | d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
63 | num_decoder_layers=0, dropout=self.dropout, dim_feedforward=self.d_ff, custom_decoder=self.dummy
64 | )
65 | # RPR Transformer
66 | else:
67 | encoder_norm = LayerNorm(self.d_model)
68 | encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout,
69 | er_len=self.max_seq, reduce_qk=self.reduce_qk, device=self.device)
70 | encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm)
71 | self.transformer = nn.Transformer(
72 | d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
73 | num_decoder_layers=0, dropout=self.dropout, dim_feedforward=self.d_ff, custom_decoder=self.dummy,
74 | custom_encoder=encoder
75 | )
76 |
77 | # Final output is a softmaxed linear layer
78 | self.Wout = nn.Linear(self.d_model, self.vocab_size)
79 | self.softmax = nn.Softmax(dim=-1)
80 | self.mask = self.transformer.generate_square_subsequent_mask(max_sequence).to(self.device)
81 |
82 | # forward
83 | def forward(self, x, condition=None, mask=True):
84 | """
85 | ----------
86 | Author: Damon Gwinn
87 | ----------
88 | Takes an input sequence and outputs predictions using a sequence to sequence method.
89 |
90 | A prediction at one index is the "next" prediction given all information seen previously.
91 | ----------
92 | """
93 | if (mask is True):
94 | mask = self.mask[..., :x.shape[1], :x.shape[1]]
95 | else:
96 | mask = None
97 | x = self.embedding(x)
98 | if condition is not None and self.cond_embedding is not None:
99 | x_cond = self.cond_embedding(condition)
100 | x = x + x_cond[:, None]
101 |
102 | # Input shape is (max_seq, batch_size, d_model)
103 | x = x.permute(1,0,2)
104 |
105 | x = self.positional_encoding(x)
106 |
107 | # Since there are no true decoder layers, the tgt is unused
108 | # Pytorch wants src and tgt to have some equal dims however
109 | x_out = self.transformer(src=x, tgt=x, src_mask=mask)
110 |
111 | # Back to (batch_size, max_seq, d_model)
112 | x_out = x_out.permute(1,0,2)
113 |
114 | y = self.Wout(x_out)
115 |
116 | # They are trained to predict the next note in sequence (we don't need the last one)
117 | return y
118 |
119 | def get_norms(self):
120 | norm_dict = {'embedding_weight_norm': torch.norm(self.embedding.weight).item(),
121 | 'embedding_grad_norm': torch.norm(self.embedding.weight.grad).item(),
122 | 'output_weight_norm': torch.norm(self.Wout.weight).item(),
123 | 'output_grad_norm': torch.norm(self.Wout.weight.grad).item()}
124 | return norm_dict
125 |
126 | def get_parameters(self):
127 | return {'device': self.device,
128 | 'n_layers': self.nlayers,
129 | 'num_heads': self.nhead,
130 | 'd_model': self.d_model,
131 | 'dim_feedforward': self.d_ff,
132 | 'dropout': self.dropout,
133 | 'max_sequence': self.max_seq,
134 | 'rpr': self.rpr,
135 | 'vocab_size': self.vocab_size,
136 | 'cond_vocab_size': self.cond_vocab_size,
137 | 'reduce_qk': self.reduce_qk,
138 | }
139 |
140 | # Used as a dummy to nn.Transformer
141 | # DummyDecoder
142 | class DummyDecoder(nn.Module):
143 | """
144 | ----------
145 | Author: Damon Gwinn
146 | ----------
147 | A dummy decoder that returns its input. Used to make the Pytorch transformer into a decoder-only
148 | architecture (stacked encoders with dummy decoder fits the bill)
149 | ----------
150 | """
151 |
152 | def __init__(self):
153 | super(DummyDecoder, self).__init__()
154 |
155 | def forward(self, tgt, memory, tgt_mask, memory_mask,tgt_key_padding_mask,memory_key_padding_mask):
156 | """
157 | ----------
158 | Author: Damon Gwinn
159 | ----------
160 | Returns the input (memory)
161 | ----------
162 | """
163 |
164 | return memory
165 |
--------------------------------------------------------------------------------
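A tiny-config smoke test for MusicTransformer. The released checkpoints use n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, rpr=True (see generate.py / train.py); this sketch only checks shapes on CPU and assumes a PyTorch version compatible with this code (the bundled .pyc files target Python 3.7-era torch).

```python
import torch
from lib import constants
from lib.model.transformer import MusicTransformer

device = torch.device('cpu')
model = MusicTransformer(device, n_layers=2, num_heads=4, d_model=128,
                         dim_feedforward=256, max_sequence=512,
                         vocab_size=constants.VOCAB_SIZE, rpr=False).eval()

x = torch.randint(0, constants.VOCAB_SIZE, (2, 64))   # (batch, seq) of event tokens
with torch.no_grad():
    logits = model(x)
print(logits.shape)   # torch.Size([2, 64, 394])
```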
/src/train.py:
--------------------------------------------------------------------------------
1 | import os, sys, shutil
2 | import time
3 | import json
4 | import math
5 | import argparse
6 | import itertools
7 | import numpy as np
8 | import pandas as pd
9 | from tqdm import tqdm
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.utils.tensorboard import SummaryWriter
15 | from torch.nn.parallel import DistributedDataParallel
16 | from torch.utils.data import DataLoader, Subset, DistributedSampler, Dataset
17 |
18 |
19 | from lib import constants
20 | from lib.model.transformer import MusicTransformer
21 | from lib.inverse_power_with_warmup_sheduler import InversePowerWithWarmupLRScheduler  # module file name is spelled 'sheduler' in this repo
22 | from lib.encoded_dataset import EncodedDataset
23 | from lib.augmentations import MusicAugmentations
24 |
25 | PAD_TOKEN = constants.TOKEN_PAD
26 |
27 | params = dict(
28 | NAME = 'model_name',
29 | DS_FILE_PATH = 'ds_files.pt',
30 | SEED = 0,
31 | num_epochs = 100,
32 | batch_size = 2,
33 | num_workers = 0,
34 | val_every = 6000,
35 | save_every = 6000,
36 | lr = 1e-4,
37 | use_scheduler = True,
38 | peak_lr = 1e-4,
39 | warmup_steps = 4000,
40 | power = 2,
41 | shift = 100000,
42 | LOAD_NAME = '',
43 | LOG_TOTAL_NORM = True,
44 | CLIPPING = False,
45 | gpus = [0,1,2,3],
46 | )
47 |
48 | globals().update(params)
49 |
50 |
51 | def create_dataloaders(batch_size, num_workers=0):
52 | '''Initializes augmentations, loads file lists to datasets and loaders and returns them'''
53 | print('loading data...')
54 |
55 | aug = MusicAugmentations()
56 |
57 | tr_dataset = EncodedDataset(DS_FILE_PATH, transform=aug)
58 | vl_dataset = EncodedDataset(DS_FILE_PATH, transform=None)
59 | np.random.seed(0)
60 | idxs = np.random.permutation(len(tr_dataset))
61 | vl, tr = np.split(idxs, [4000])
62 | train_dataset = Subset(tr_dataset, tr)
63 | val_dataset = Subset(vl_dataset, vl)
64 |
65 | sampler = DistributedSampler(train_dataset, world_size, rank, True)
66 | train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, pin_memory=False, num_workers=num_workers)
67 | sampler = DistributedSampler(val_dataset, world_size, rank, False)
68 | val_loader = DataLoader(val_dataset, batch_size=batch_size*4, sampler=sampler, pin_memory=False, num_workers=num_workers)
69 |
70 | return train_loader, val_loader
71 |
72 |
73 | def init_model(lr, seed=0):
74 | '''Initializes model, loads weights if necessary and creates optimizer'''
75 | torch.manual_seed(seed)
76 | model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=390+4, rpr=True).to(device)
77 | if LOAD_NAME != '':
78 | model.load_state_dict(torch.load(LOAD_NAME, map_location=device))
79 | print(f'Loaded model from {LOAD_NAME}')
80 | model = DistributedDataParallel(model, device_ids=[gpus[rank]])
81 | print(sum((torch.numel(x) for x in model.parameters()))/1e6, 'M parameters')
82 | optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=1e-5)
83 | return model, optimizer
84 |
85 | def validate(model, val_loader):
86 | CE = 0
87 | ACC = 0
88 | n = 0
89 | model.eval()
90 | with torch.no_grad():
91 | for x, genre, idxs in val_loader:
92 | x[x==0] = PAD_TOKEN
93 | tgt = x.clone()
94 | x[:,-1] = constants.VOCAB_SIZE - 4 + genre
95 | x = torch.roll(x, 1, -1)
96 | x, tgt = x.to(device), tgt.to(device)
97 |
98 | logits = model(x)
99 | pred = logits.argmax(-1)
100 |
101 | mask = tgt != PAD_TOKEN
102 | n += mask.sum().item()
103 | CE += F.cross_entropy(logits.view(-1, logits.shape[-1]), tgt.flatten(), ignore_index=PAD_TOKEN, reduction='sum').item()
104 | ACC += (pred[mask] == tgt[mask]).sum().item()
105 |
106 | model.train()
107 | return CE/n, ACC/n
108 |
109 | def train_ddp(rank_, world_size_):
110 | global device, NAME, SEED, rank, world_size
111 | rank, world_size = rank_, world_size_
112 |
113 | os.environ['MASTER_ADDR'] = 'localhost'
114 | os.environ['MASTER_PORT'] = '12355'
115 | torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size)
116 |
117 | device = torch.device(f'cuda:{gpus[rank]}')
118 | print(rank, gpus[rank], device)
119 |
120 | train_loader, val_loader = create_dataloaders(batch_size, num_workers)
121 |
122 | model, optimizer = init_model(lr, SEED)
123 | if use_scheduler:
124 | scheduler = InversePowerWithWarmupLRScheduler(optimizer, peak_lr=peak_lr, warmup_steps=warmup_steps, power=power, shift=shift)
125 |
126 | if rank == 0:
127 | save_dir = f'output/{NAME}'
128 | save_name = f'{NAME}'
129 | if os.path.exists(save_dir):
130 | print(f'WARNING: {save_dir} exists! It may overwrite useful files')
131 | os.makedirs(save_dir, exist_ok=True)
132 | writer = SummaryWriter(f'runs/{save_name}')
133 |
134 | # TRAIN
135 | LS = {'loss':[], 'lr':[], 'val_ce':[], 'val_acc':[]}
136 |
137 | i_val = 0
138 | i_step = -1
139 | best_ce = float('inf')
140 | patience = 0
141 | for ep in range(num_epochs):
142 | model.train()
143 | train_loader.sampler.set_epoch(ep)
144 | if rank == 0:
145 | bar = tqdm(train_loader, position=rank)
146 | else:
147 | bar = train_loader
148 | for x, genre, idxs in bar:
149 | i_step += 1
150 | x[x==0] = PAD_TOKEN
151 | tgt = x.clone()
152 | x[:,-1] = constants.VOCAB_SIZE - 4 + genre
153 | x = torch.roll(x, 1, -1)
154 | x, tgt = x.to(device), tgt.to(device)
155 | logits = model(x)
156 | loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), tgt.flatten(), ignore_index=PAD_TOKEN)
157 |
158 | optimizer.zero_grad()
159 | loss.backward()
160 |
161 | if CLIPPING:
162 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CLIPPING).item()
163 | else:
164 | total_norm = 0
165 |
166 | optimizer.step()
167 |
168 | if use_scheduler:
169 | scheduler.step()
170 |
171 | if i_step == warmup_steps - 1 and rank == 0:
172 | torch.save(model.module.state_dict(), f'{save_dir}/model_{save_name}_after_warmup.pt')
173 |
174 | if rank == 0:
175 | # logs
176 | LS['loss'] += [loss.item()]
177 | LS['lr'] += [optimizer.param_groups[0]['lr']]
178 | writer.add_scalar(f'Train/embedding_weight_norm', torch.norm(model.module.embedding.weight).item(), i_step)
179 | writer.add_scalar(f'Train/embedding_grad_norm', torch.norm(model.module.embedding.weight.grad).item(), i_step)
180 | writer.add_scalar(f'Train/output_weight_norm', torch.norm(model.module.Wout.weight).item(), i_step)
181 | writer.add_scalar(f'Train/output_grad_norm', torch.norm(model.module.Wout.weight.grad).item(), i_step)
182 | writer.add_scalar(f'Train/loss', loss.item(), i_step)
183 | writer.add_scalar(f'Train/perplexity', math.exp(loss.item()), i_step)
184 | writer.add_scalar(f'Train/lr', optimizer.param_groups[0]['lr'], i_step)
185 | if LOG_TOTAL_NORM:
186 | total_norm = 0.
187 | for p in model.module.parameters():
188 | param_norm = p.grad.detach().data.norm(2)
189 | total_norm += param_norm.item() ** 2
190 | total_norm = total_norm ** 0.5
191 | writer.add_scalar(f'Train/total_grad_norm', total_norm, i_step)
192 | bar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'], norm=total_norm)
193 |
194 |
195 | # VALIDATION
196 | if i_step % val_every == val_every-1:
197 | val_ce, val_acc = validate(model, val_loader)
198 | if world_size > 1:
199 | ce_all, acc_all = [[torch.zeros(1,device=device) for i in range(world_size)] for _ in range(2)]
200 | [torch.distributed.all_gather(a, torch.tensor(x, dtype=torch.float32, device=device)) for a,x in zip([ce_all,acc_all], [val_ce,val_acc])]
201 | val_ce, val_acc = [torch.cat(a).mean().item() for a in [ce_all,acc_all]]
202 | if rank == 0:
203 | # log, save, patience tracking
204 | LS['val_ce'] += [val_ce]
205 | LS['val_acc'] += [val_acc]
206 | writer.add_scalar(f'Val/ce', val_ce, i_val)
207 | writer.add_scalar(f'Val/acc', val_acc, i_val)
208 | writer.add_scalar(f'Val/perplexity', math.exp(val_ce), i_val)
209 | if val_ce < best_ce:
210 | patience = 0
211 | best_ce = val_ce
212 | torch.save({'history':LS,'epoch':ep,'params':params}, f'{save_dir}/hist_{save_name}_best.pt')
213 | torch.save(model.module.state_dict(), f'{save_dir}/model_{save_name}_best.pt')
214 | else:
215 | patience += 1
216 | print(f'{ep}: val_ce={val_ce}, val_acc={val_acc}, patience={patience}')
217 | i_val += 1
218 |
219 | # CHECKPOINT
220 | if (i_step % save_every == save_every-1) and rank == 0:
221 | torch.save({'history':LS,'epoch':ep,'params':params}, f'{save_dir}/hist_{save_name}.pt')
222 | torch.save(model.module.state_dict(), f'{save_dir}/model_{save_name}_{(i_step+1)//1000}k.pt')
223 |
224 | torch.distributed.destroy_process_group()
225 |
226 |
227 | if __name__ == "__main__":
228 | print(NAME, SEED)
229 | world_size = len(gpus)
230 | torch.multiprocessing.spawn(train_ddp, args=(world_size,), nprocs=world_size, join=True)
231 |
--------------------------------------------------------------------------------
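A small sketch of what the input/target shift inside the training loop does: the genre token overwrites the last position and the sequence is rolled right, so the model receives `[genre, t0, t1, ...]` as input and is trained to predict `[t0, t1, ...]`. The token values below are toy placeholders.

```python
import torch

VOCAB_SIZE = 394                          # constants.VOCAB_SIZE
genre = torch.tensor([1])                 # jazz
x = torch.tensor([[10, 20, 30, 40]])      # toy encoded sequence
tgt = x.clone()
x[:, -1] = VOCAB_SIZE - 4 + genre
x = torch.roll(x, 1, -1)
print(x.tolist())     # [[391, 10, 20, 30]]  -> model input
print(tgt.tolist())   # [[10, 20, 30, 40]]   -> training target
```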
/src/lib/generation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 |
5 | from lib import midi_processing
6 | from lib import constants
7 | from lib.midi_processing import RANGES_SUM, get_type, NOTE_ON, NOTE_OFF
8 | from lib.midi_processing import PIANO_RANGE
9 |
10 |
11 | def generate(model, primer, target_seq_length=1024, temperature=1.0, topk=40, topp=0.99, topp_temperature=1.0, at_least_k=1, use_rp=False, rp_penalty=0.05, rp_restore_speed=0.7, seed=None, **forward_args):
12 | """
13 | Generate a batch of samples conditioned on `primer`. Several techniques are used to obtain better samples:
14 | - temperature skewing to control the entropy of the output distributions
15 | - top-k sampling
16 | - top-p (nucleus) sampling (https://arxiv.org/abs/1904.09751)
17 | - DynamicRepetitionPenaltyProcessor, which discourages repeated notes
18 | The default values are usually suitable for our models.
19 |
20 | Parameters
21 | ----------
22 | model : MusicTransformer
23 | trained model.
24 | primer : torch.Tensor (B x N)
25 |         primer to condition on.
26 |         B = batch_size, N = seq_length.
27 |         We use a primer consisting of a single genre token. These tokens are {390:'classic', 391:'jazz', 392:'calm', 393:'pop'}.
28 | target_seq_length : int
29 | desired length of generated sequences.
30 | temperature : float
31 |         temperature alters the output distribution of the model. Higher values ( > 1.0) lead to more stochastic sampling, lower values lead to more expected and predictable sequences (ending up with endlessly repeating musical patterns).
32 |     topk : int
33 |         restricts sampling from low probabilities. It is the number of most probable tokens from which sampling is done.
34 | topp : float
35 |         restricts sampling from low probabilities, but more adaptively than topk. See (https://arxiv.org/abs/1904.09751).
36 |     topp_temperature : float
37 |         temperature used when computing the cumulative sum for top-p sampling.
38 |     at_least_k : int
39 |         like topk, but forces sampling from at least the k most probable tokens.
40 |     use_rp : bool
41 |         whether to use the DynamicRepetitionPenaltyProcessor (RP), which tries to prevent the generation of repeated notes.
42 | rp_penalty : float
43 |         coefficient for RP. Higher values lead to a stronger RP effect.
44 |     rp_restore_speed : float
45 |         how fast the penalty is lifted. Lower values lead to a stronger RP effect.
46 | seed : int
47 | fixes seed for deterministic generation.
48 | forward_args : dict
49 | args for model's forward.
50 |
51 | Returns
52 | -------
53 | generated : torch.Tensor (B x target_seq_length)
54 | generated batch of sequences.
55 | """
56 | device = model.device
57 | if seed is not None:
58 | torch.manual_seed(seed)
59 | np.random.seed(seed)
60 |
61 | if at_least_k < 1:
62 | at_least_k = 1
63 | B,N = primer.shape
64 | generated = torch.full((B,target_seq_length), constants.TOKEN_PAD, dtype=torch.int64, device=device)
65 | generated[..., :N] = primer.to(device)
66 |
67 | if use_rp:
68 | RP_processor = DynamicRepetitionPenaltyProcessor(B, penalty=rp_penalty, restore_speed=rp_restore_speed, device=device)
69 | whitelist_mask = make_whitelist_mask()
70 |
71 | model.eval()
72 | with torch.no_grad():
73 | for i in tqdm(range(N, target_seq_length)):
74 | logits = model(generated[:, :i], **forward_args)[:, i-1, :]
75 | logits[:,~whitelist_mask] = float('-inf')
76 | p = torch.softmax(logits/topp_temperature, -1)
77 |
78 | # apply topk:
79 | if topk == 0:
80 | topk = p.shape[-1]
81 | p_topk, idxs = torch.topk(p, topk, -1, sorted=True)
82 |
83 | # apply topp:
84 | mask = p_topk.cumsum(-1) < topp
85 | mask[:,:at_least_k] = True
86 | logits_masked = logits.gather(-1, idxs)
87 | logits_masked[~mask] = float('-inf')
88 | p_topp = torch.softmax(logits_masked/temperature, -1)
89 |
90 | # apply penalty:
91 | if use_rp:
92 | p_penalized = RP_processor.apply_penalty(p_topp, idxs)
93 | ib = p_penalized.sum(-1) == 0
94 | if ib.sum() > 0:
95 |                     # if all topp tokens are zeroed out by RP_processor, fall back to topk sampling
96 | p_fallback = p_topk[ib].clone()
97 | p_fallback[mask[ib]] = 0. # zeroing topp
98 | p_penalized[ib] = p_fallback
99 |
100 | ib = p_penalized.sum(-1) == 0
101 | if ib.sum() > 0:
102 |                     # if the topk tokens are also zeroed out, fall back to topp without RP
103 | print('fallback-2')
104 | p_penalized = p_topp
105 | p_topp = p_penalized
106 |
107 | # sample:
108 | next_token = idxs.gather(-1, torch.multinomial(p_topp, 1))
109 | generated[:, i] = next_token.squeeze(-1)
110 |
111 | # update penalty:
112 | if use_rp:
113 | RP_processor.update(next_token)
114 |
115 | return generated[:, :i+1]
116 |
117 |
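# Minimal usage sketch (comments only), assuming a trained MusicTransformer `model` on the
# same device; the genre ids follow the mapping in the docstring above, and the parameter
# values mirror the Colab defaults:
#
#   primer = torch.tensor([[392]])          # 392 = 'calm' genre token, shape (B=1, N=1)
#   seqs = generate(model, primer, target_seq_length=512,
#                   temperature=1.0, topk=60, topp=0.99, seed=0)
#   seqs = post_process(seqs)               # see below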
118 | def post_process(generated, remove_bad_generations=True):
119 | """
120 |     Post-processing performs 3 routines:
121 |     1) removes long pauses (3+ seconds)
122 |     2) clips velocities to range(30,100) to avoid dramatically loud notes, which are not suitable for our case.
123 |     3) removes badly generated samples. The model may sometimes generate music that consists only of many repeating notes. We try to detect such samples and remove them from the batch.
124 |
125 | Parameters
126 | ----------
127 | generated : torch.Tensor (B x N)
128 | batch of generated samples
129 |
130 | Returns
131 | -------
132 | filtered_generated : cleaner and slightly better sounding generated batch
133 | """
134 | generated = generated.cpu().numpy()
135 | remove_pauses(generated, 3)
136 | clip_velocity(generated)
137 |
138 | bad_filter = np.ones(len(generated), dtype=bool)
139 |
140 | if remove_bad_generations:
141 | for i, gen in enumerate(generated):
142 | midi = midi_processing.decode(gen)
143 | if detect_note_repetition(midi) > 0.9:
144 | bad_filter[i] = False
145 |
146 | if np.sum(bad_filter) != len(bad_filter):
147 | print(f'{np.sum(~bad_filter)} bad samples will be removed.')
148 |
149 | return generated[bad_filter]
150 |
151 |
152 | def make_whitelist_mask():
153 | """Generate mask for PIANO_RANGE"""
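    # Assumed vocabulary layout (see midi_processing and the repetition-penalty code below):
    # tokens 0-127 are NOTE_ON pitches, 128-255 are NOTE_OFF pitches, and everything from
    # 128*2 upwards (velocity, time-shift and special tokens) stays allowed; only pitches
    # outside PIANO_RANGE are masked out.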
154 | whitelist_mask = np.zeros(constants.VOCAB_SIZE, dtype=bool)
155 | whitelist_mask[PIANO_RANGE[0]:PIANO_RANGE[1]+1] = True
156 | whitelist_mask[128+PIANO_RANGE[0]:128+PIANO_RANGE[1]+1] = True
157 | whitelist_mask[128*2:] = True
158 | return whitelist_mask
159 |
160 |
161 | class DynamicRepetitionPenaltyProcessor:
162 | """
163 |     The class tries to prevent cases where the model generates repetitive notes or musical patterns that degrade quality.
164 |     It dynamically reduces and restores the probabilities of generated notes.
165 |     Each generated note has its probability reduced by `penalty` (a hyperparameter) for the next step. If the note is generated again, its probability keeps being reduced; otherwise it is gradually restored (at a speed controlled by the restore_speed parameter).
166 |
167 | Parameters
168 | ----------
169 | bs : int
170 | batch_size. We need to know batch_size in advance to create the penalty_matrix.
171 | penalty : float
172 | value by which the probability will be reduced.
173 | restore_speed : float
174 |         the inverse of the number of seconds needed to fully restore a probability from 0 to 1.
175 |         for restore_speed equal to 1.0 we need 1.0 sec to restore, for 2.0 - 0.5 sec, and so on.
176 | """
177 | def __init__(self, bs, device, penalty=0.3, restore_speed=1.0) :
178 | self.bs = bs
179 | self.penalty = penalty
180 | self.restore_speed = restore_speed
181 | self.penalty_matrix = torch.ones(bs,128).to(device)
182 |
183 | def apply_penalty(self, p, idxs):
184 | p = p.clone()
185 | for b in range(len(p)):
186 | i = idxs[b]
187 | pi = p[b]
188 | mask = i < 128
189 | if len(i) > 0:
190 | pi[mask] = pi[mask]*self.penalty_matrix[b,i[mask]]
191 | return p
192 |
193 | def update(self, next_token):
194 | restoring = next_token - (128+128+32) # only TS do restore
195 | restoring = torch.clamp(restoring.float(), 0, 100)/100*self.restore_speed
196 | self.penalty_matrix += restoring
197 | nt = next_token.squeeze(-1)
198 | nt = next_token[next_token < 128]
199 | self.penalty_matrix[:, nt] -= restoring + self.penalty
200 | torch.clamp(self.penalty_matrix, 0, 1.0, out=self.penalty_matrix)
201 | return restoring, nt
202 |
203 |
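# Worked example of the dynamics above (comments only): with penalty=0.3 and
# restore_speed=1.0, sampling NOTE_ON pitch p drops penalty_matrix[:, p] from 1.0 to 0.7,
# and an immediate second occurrence drops it to 0.4. Each sampled time-shift token adds
# roughly its duration in seconds times restore_speed back to every entry, so about one
# second of generated time restores a fully penalised pitch; values are clamped to [0, 1].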
204 | def detect_note_repetition(midi, threshold_sec=0.01):
205 | """
206 |     Returns the fraction of note repetitions. Counts cases where prev_note_end == next_note_start at the same pitch ('glued' notes). Used to detect badly generated samples.
207 |
208 | Parameters
209 | ----------
210 | midi : prettyMIDI object
211 | threshold_sec : float
212 |         intervals smaller than threshold_sec are treated as 'glued' notes.
213 |
214 | Returns
215 | -------
216 |     fraction of note repetitions relative to the total number of notes.
217 | """
218 | all_notes = [x for inst in midi.instruments for x in inst.notes if not inst.is_drum]
219 | if len(all_notes) == 0:
220 | return 0
221 | all_notes_np = np.array([[x.start,x.end,x.pitch,x.velocity] for x in all_notes])
222 |
223 | i_sort = np.lexsort([all_notes_np[:,0], all_notes_np[:,2]])
224 |
225 | s = []
226 | cur_p = -1
227 | cur_t = -1
228 | for t in all_notes_np[i_sort]:
229 | a,b,p,v = t
230 | if cur_p != p:
231 | cur_p = p
232 | else:
233 | s.append(a-cur_t)
234 | cur_t = b
235 | s = np.array(s)
236 |     return (s < threshold_sec).sum()/len(s) if len(s) > 0 else 0
237 |
238 |
239 | def remove_pauses(generated, threshold=3):
240 | """
241 |     Fills pauses with constants.TOKEN_PAD values. Only pauses longer than `threshold` seconds are considered.
242 | Inplace operation. `generated` is a tensor (batch of sequences).
243 |
244 | Parameters
245 | ----------
246 | generated : torch.Tensor (B x N)
247 | generated batch of sequences.
248 | threshold : int/float
249 |         the minimum number of seconds of silence treated as a pause.
250 | """
251 |     mask = (generated>=RANGES_SUM[2]) & (generated<RANGES_SUM[3])
270 |                 if s >= threshold and notes_down.sum() == 0:
271 | res_ab[ib].append([a,i,s])
272 | s = 0
273 | a = i+1
274 | s += t
275 | if s >= threshold and notes_down.sum() == 0:
276 | res_ab[ib].append([a,len(i_seconds),s])
277 |
278 | # remove inplace
279 | for ib,t in enumerate(res_ab):
280 | for a,b,s in t:
281 | generated[ib, a:b] = constants.TOKEN_PAD
282 | print(f'pause removed:',ib,f'n={b-a}',a,b,s)
283 |
284 |
285 | def clip_velocity(generated, min_velocity=30, max_velocity=100):
286 | """
287 |     Clip velocity to range(min_velocity, max_velocity). Since the model sometimes generates overly loud sequences, we try to neutralize this effect.
288 | Inplace operation. `generated` is a tensor (batch of sequences).
289 |
290 | Parameters
291 | ----------
292 | generated : torch.Tensor (B x N)
293 | generated batch of sequences.
294 | min_velocity : int
295 | max_velocity : int
296 | """
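    # Velocities 0-127 appear to be quantised into 32 bins offset by RANGES_SUM[1],
    # hence the *32//128 rescaling below.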
297 | max_velocity_encoded = max_velocity*32//128 + RANGES_SUM[1]
298 | min_velocity_encoded = min_velocity*32//128 + RANGES_SUM[1]
299 |
300 |     mask = (generated>=RANGES_SUM[1]) & (generated<RANGES_SUM[2])
--------------------------------------------------------------------------------
/src/lib/model/rpr.py:
--------------------------------------------------------------------------------
241 |     # type: (...) -> Tuple[Tensor, Optional[Tensor]]
242 |
243 | qkv_same = torch.equal(query, key) and torch.equal(key, value)
244 | if reduce_qk:
245 | #print('Using reduced qk')
246 | qkv_same = False
247 | use_separate_proj_weight = True
248 | kv_same = torch.equal(key, value)
249 |
250 | tgt_len, bsz, embed_dim = query.size()
251 | assert embed_dim == embed_dim_to_check
252 | assert list(query.size()) == [tgt_len, bsz, embed_dim]
253 | assert key.size() == value.size()
254 |
255 | head_dim = embed_dim // num_heads
256 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
257 | scaling = float(head_dim) ** -0.5
258 |
259 | if use_separate_proj_weight is not True:
260 | if qkv_same:
261 | # self-attention
262 | q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
263 |
264 | elif kv_same:
265 | # encoder-decoder attention
266 | # This is inline in_proj function with in_proj_weight and in_proj_bias
267 | _b = in_proj_bias
268 | _start = 0
269 | _end = embed_dim
270 | _w = in_proj_weight[_start:_end, :]
271 | if _b is not None:
272 | _b = _b[_start:_end]
273 | q = linear(query, _w, _b)
274 |
275 | if key is None:
276 | assert value is None
277 | k = None
278 | v = None
279 | else:
280 |
281 | # This is inline in_proj function with in_proj_weight and in_proj_bias
282 | _b = in_proj_bias
283 | _start = embed_dim
284 | _end = None
285 | _w = in_proj_weight[_start:, :]
286 | if _b is not None:
287 | _b = _b[_start:]
288 | k, v = linear(key, _w, _b).chunk(2, dim=-1)
289 |
290 | else:
291 | # This is inline in_proj function with in_proj_weight and in_proj_bias
292 | _b = in_proj_bias
293 | _start = 0
294 | _end = embed_dim
295 | _w = in_proj_weight[_start:_end, :]
296 | if _b is not None:
297 | _b = _b[_start:_end]
298 | q = linear(query, _w, _b)
299 |
300 | # This is inline in_proj function with in_proj_weight and in_proj_bias
301 | _b = in_proj_bias
302 | _start = embed_dim
303 | _end = embed_dim * 2
304 | _w = in_proj_weight[_start:_end, :]
305 | if _b is not None:
306 | _b = _b[_start:_end]
307 | k = linear(key, _w, _b)
308 |
309 | # This is inline in_proj function with in_proj_weight and in_proj_bias
310 | _b = in_proj_bias
311 | _start = embed_dim * 2
312 | _end = None
313 | _w = in_proj_weight[_start:, :]
314 | if _b is not None:
315 | _b = _b[_start:]
316 | v = linear(value, _w, _b)
317 | else:
318 | q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
319 | if not reduce_qk:
320 | len1, len2 = q_proj_weight_non_opt.size()
321 | assert len1 == embed_dim and len2 == query.size(-1)
322 |
323 | k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
324 | if not reduce_qk:
325 | len1, len2 = k_proj_weight_non_opt.size()
326 | assert len1 == embed_dim and len2 == key.size(-1)
327 |
328 | v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
329 | len1, len2 = v_proj_weight_non_opt.size()
330 | assert len1 == embed_dim and len2 == value.size(-1)
331 |
332 | if in_proj_bias is not None:
333 | if reduce_qk:
334 | q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:int(embed_dim/2)])
335 | k = linear(key, k_proj_weight_non_opt, in_proj_bias[int(embed_dim/2):embed_dim])
336 | v = linear(value, v_proj_weight_non_opt, in_proj_bias[embed_dim:])
337 | else:
338 | q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
339 | k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
340 | v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
341 | else:
342 | q = linear(query, q_proj_weight_non_opt, in_proj_bias)
343 | k = linear(key, k_proj_weight_non_opt, in_proj_bias)
344 | v = linear(value, v_proj_weight_non_opt, in_proj_bias)
345 | q = q * scaling
346 |
347 | if bias_k is not None and bias_v is not None:
348 | if static_k is None and static_v is None:
349 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
350 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
351 | if attn_mask is not None:
352 | attn_mask = torch.cat([attn_mask,
353 | torch.zeros((attn_mask.size(0), 1),
354 | dtype=attn_mask.dtype,
355 | device=attn_mask.device)], dim=1)
356 | if key_padding_mask is not None:
357 | key_padding_mask = torch.cat(
358 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
359 | dtype=key_padding_mask.dtype,
360 | device=key_padding_mask.device)], dim=1)
361 | else:
362 | assert static_k is None, "bias cannot be added to static key."
363 | assert static_v is None, "bias cannot be added to static value."
364 | else:
365 | assert bias_k is None
366 | assert bias_v is None
367 |
368 | if reduce_qk:
369 | q = q.contiguous().view(tgt_len, bsz * num_heads, int(head_dim / 2)).transpose(0, 1)
370 | if k is not None:
371 | k = k.contiguous().view(-1, bsz * num_heads, int(head_dim / 2)).transpose(0, 1)
372 | if v is not None:
373 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
374 | else:
375 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
376 | if k is not None:
377 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
378 | if v is not None:
379 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
380 |
381 | if static_k is not None:
382 | assert static_k.size(0) == bsz * num_heads
383 | assert static_k.size(2) == head_dim
384 | k = static_k
385 |
386 | if static_v is not None:
387 | assert static_v.size(0) == bsz * num_heads
388 | assert static_v.size(2) == head_dim
389 | v = static_v
390 |
391 | src_len = k.size(1)
392 |
393 | if key_padding_mask is not None:
394 | assert key_padding_mask.size(0) == bsz
395 | assert key_padding_mask.size(1) == src_len
396 |
397 | if add_zero_attn:
398 | src_len += 1
399 | k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
400 | v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
401 | if attn_mask is not None:
402 | attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1),
403 | dtype=attn_mask.dtype,
404 | device=attn_mask.device)], dim=1)
405 | if key_padding_mask is not None:
406 | key_padding_mask = torch.cat(
407 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
408 | dtype=key_padding_mask.dtype,
409 | device=key_padding_mask.device)], dim=1)
410 |
411 | attn_output_weights = torch.bmm(q, k.transpose(1, 2))
412 | assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
413 |
414 | ######### ADDITION OF RPR ###########
415 | if(rpr_mat is not None):
416 | rpr_mat = _get_valid_embedding(rpr_mat, q.shape[1], k.shape[1])
417 | qe = torch.einsum("hld,md->hlm", q, rpr_mat)
418 | srel = _skew(qe, skew_mask)
419 |
420 | attn_output_weights += srel
421 |
422 | if attn_mask is not None:
423 | attn_mask = attn_mask.unsqueeze(0)
424 | attn_output_weights += attn_mask
425 |
426 | if key_padding_mask is not None:
427 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
428 | attn_output_weights = attn_output_weights.masked_fill(
429 | key_padding_mask.unsqueeze(1).unsqueeze(2),
430 | float('-inf'),
431 | )
432 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
433 |
434 | attn_output_weights = softmax(attn_output_weights, dim=-1)
435 |
436 | attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
437 |
438 | attn_output = torch.bmm(attn_output_weights, v)
439 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
440 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
441 | attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
442 |
443 | if need_weights:
444 | # average attention weights over heads
445 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
446 | return attn_output, attn_output_weights.sum(dim=1) / num_heads
447 | else:
448 | return attn_output, None
449 |
450 | def _get_valid_embedding(Er, len_q, len_k):
451 | """
452 | ----------
453 | Author: Damon Gwinn
454 | ----------
455 | Gets valid embeddings based on max length of RPR attention
456 | ----------
457 | """
458 |
459 | len_e = Er.shape[0]
460 | start = max(0, len_e - len_q)
461 | return Er[start:, :]
462 |
463 | def _skew(qe, skew_mask=None):
464 | """
465 | ----------
466 | Author: Damon Gwinn
467 | ----------
468 | Performs the skew optimized RPR computation (https://arxiv.org/abs/1809.04281)
469 | ----------
470 | """
471 |
472 | sz = qe.shape[1]
473 | # If mask is generated on the fly performance might degrade
474 | if skew_mask is None:
475 | mask = (torch.triu(torch.ones(sz, sz).to(qe.device)) == 1).float().flip(0)
476 | else:
477 | mask = skew_mask[..., skew_mask.shape[0] - sz:, :sz]
478 |
479 | qe = mask * qe
480 | qe = F.pad(qe, (1,0, 0,0, 0,0))
481 | qe = torch.reshape(qe, (qe.shape[0], qe.shape[2], qe.shape[1]))
482 |
483 | srel = qe[:, 1:, :]
484 | return srel
485 |
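# Note on _skew above (comments only): qe holds Q·Er^T with relative positions along the
# last axis. Masking invalid positions, left-padding one column and reshaping shifts each
# row so that relative indices line up with absolute key positions; dropping the first row
# of the reshaped tensor yields S_rel, as described in https://arxiv.org/abs/1809.04281.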
--------------------------------------------------------------------------------
/src/Music_Composer_Demo_Colab_en.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "cellView": "form",
8 | "id": "S_i1BkSBikEs"
9 | },
10 | "outputs": [],
11 | "source": [
12 | "#@title 0. Preparing the environment\n",
13 | "#@markdown Libraries are simply downloaded and imported here, nothing interesting\n",
14 | "!pip install gdown > /dev/null\n",
15 | "!apt install fluidsynth > /dev/null\n",
16 | "!pip install midi2audio > /dev/null\n",
17 | "!pip install pretty_midi > /dev/null\n",
18 | "!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2\n",
19 | "!git clone https://github.com/sberbank-ai/music-composer.git\n",
20 | "\n",
21 | "import os\n",
22 | "import sys\n",
23 | "import time\n",
24 | "import glob\n",
25 | "import torch\n",
26 | "import gdown\n",
27 | "import webbrowser\n",
28 | "import ipywidgets\n",
29 | "import pretty_midi\n",
30 | "import numpy as np\n",
31 | "\n",
32 | "from tqdm import tqdm\n",
33 | "from pathlib import Path\n",
34 | "from midi2audio import FluidSynth\n",
35 | "from google.colab import widgets, files\n",
36 | "from IPython.display import Audio, display, FileLink, HTML\n",
37 | "from ipywidgets import AppLayout, HBox, VBox, Label\n",
38 | "\n",
39 | "url = 'https://drive.google.com/u/0/uc?id=1lcNp0y4IZMIos0ASSsERG25WVDuEmLks'\n",
40 | "output = 'model_finetune_700k.pt'\n",
41 | "gdown.download(url, output, quiet=True)\n",
42 | "\n",
43 | "\n",
44 | "sys.path.append('/content/music-composer/src/')\n",
45 | "from lib import constants\n",
46 | "from lib import generation\n",
47 | "from lib import midi_processing\n",
48 | "from lib.midi_processing import PIANO_RANGE\n",
49 | "from lib.model.transformer import MusicTransformer\n",
50 | "from lib.colab_utils import id2genre, rugenre, genre2id, decode_and_write, convert_midi_to_wav, DownloadButton"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {
57 | "cellView": "form",
58 | "id": "Yn4cWWoOJHfi"
59 | },
60 | "outputs": [],
61 | "source": [
62 | "#@title 1. Generation control panel\n",
63 | "#@markdown In this panel, you can customize the music generation parameters for yourself. Let's go through all the points:\n",
64 |     "#@markdown * You can choose __primer__ - it's your MIDI file, which will be taken as the beginning of the song. If it is too long, it will be trimmed. Specify the desired length (in seconds) and how to trim it - from the beginning or from the end (From start / From end). __From start__ means that we keep the first N seconds, __From end__ - the last N seconds.\n",
65 | "#@markdown * __Genre__ - allows you to select the genre in which the tracks will be generated. For each of the genres, we have selected our own set of parameters, which allows you to best generate it, but more on that later.\n",
66 | "#@markdown * __Seed__ - allows you to fix the seed with which the tracks are generated. Fixing it can be useful if you want to study the influence of a particular parameter on generation - set a numerical seed and generate several times changing the studied parameter. In other cases, it is recommended to leave this field blank. \n",
67 | "#@markdown * __Batch_size__ - the number of simultaneously generated tracks. If you overdo it, you may not have enough memory. Here you need to find a balance between the number of simultaneous tracks and their length. \n",
68 |     "#@markdown * __Sequence length__ - track length. Due to the peculiarities of generation, we cannot say in advance how long the track will last, but increasing this parameter potentially increases the track length. Batch_size must be balanced against this parameter so that you don't run out of GPU memory.\n",
69 | "#@markdown * __Remove bad generations__ - sometimes the model starts generating garbage. If this flag is enabled, we will try to detect it, filter it out and generate new compositions instead. The generation time increases accordingly.\n",
70 | "#@markdown * __Temperature__ - temperature scaling. Affects the probability distribution itself, making it more or less equiprobable. This regulates the variety of generation.\n",
71 |     "#@markdown * __TopK__ - upper bound on the sampling pool. Sampling is done from the set of the k most likely tokens.\n",
72 |     "#@markdown * __At least K__ - lower bound on the sampling pool. Sampling is guaranteed to come from at least the at_least_k most likely tokens.\n",
73 |     "#@markdown * __TopP__ - the top-p (nucleus sampling) parameter, which you can read about in more detail in the article.\n",
74 |     "#@markdown * __TopP Temperature__ - temperature scaling used when computing the cumulative probabilities for the top-p criterion.\n",
75 |     "#@markdown * __Use Repetition Penalty__ - flag to use note repetition penalties. It is recommended to turn it on only if the model generates cyclic, boring tracks.\n",
76 |     "#@markdown * __RP Penalty__ - the size of the penalty for repeating notes. The higher it is, the less the model gets stuck in the same musical phrases.\n",
77 |     "#@markdown * __Restore speed__ - the speed at which the repetition penalty module's penalties are lifted.\n",
78 | "\n",
79 |     "#@markdown There are also per-genre buttons in the lower right corner. Using them, you can return to the generation parameters we have selected for each of the genres.\n",
80 | "\n",
81 | "\n",
82 | "def truncate_midi(midi, primer_len_sec=15.0, from_end=False):\n",
83 | " time0 = max([inst.notes[-1].end for inst in midi.instruments]) if from_end else 0\n",
84 | "\n",
85 | " for inst in midi.instruments:\n",
86 | " notes = sorted(inst.notes, key=lambda x: x.start, reverse=from_end)\n",
87 | " for i,note in enumerate(notes):\n",
88 | " if np.abs(note.start - time0) > primer_len_sec:\n",
89 | " break\n",
90 | " if not from_end and note.end > primer_len_sec:\n",
91 | " note.end = time0 - (int(from_end)*2-1) * primer_len_sec\n",
92 | " inst.notes = notes[:i]\n",
93 | " if from_end:\n",
94 | " inst.notes = inst.notes[::-1]\n",
95 | "\n",
96 | "\n",
97 | "style = {'description_width': 'initial'}\n",
98 | "genre_to_generate = ipywidgets.Dropdown(\n",
99 | " options=['calm', 'jazz', 'pop', 'classic'],\n",
100 | " value='calm',\n",
101 | " description='Genre:',\n",
102 | " disabled=False,\n",
103 | ")\n",
104 | "genre_to_generate.default_value = 'calm'\n",
105 | "\n",
106 | "seed = ipywidgets.Text(\n",
107 | " value='',\n",
108 | " placeholder='leave blank for random seed',\n",
109 | " description='Seed:',\n",
110 | " disabled=False\n",
111 | ")\n",
112 | "seed.default_value = ''\n",
113 | "\n",
114 | "temp = ipywidgets.FloatSlider( \n",
115 | " value=1.0,\n",
116 | " min=0.01,\n",
117 | " max=5.0,\n",
118 | " step=0.01,\n",
119 | " description='Temperature:',\n",
120 | " disabled=False,\n",
121 | " continuous_update=False,\n",
122 | " orientation='horizontal',\n",
123 | " readout=True,\n",
124 | " readout_format='.2f',\n",
125 | " style=style\n",
126 | ")\n",
127 | "temp.default_value = 1.0\n",
128 | "\n",
129 | "b_size = ipywidgets.IntSlider(\n",
130 | " value=8,\n",
131 | " min=1,\n",
132 | " max=16,\n",
133 | " step=1,\n",
134 | " description='Batch size:',\n",
135 | " disabled=False,\n",
136 | " continuous_update=False,\n",
137 | " orientation='horizontal',\n",
138 | " readout=True,\n",
139 | " readout_format='d',\n",
140 | " style=style\n",
141 | ")\n",
142 | "b_size.default_value = 8\n",
143 | "\n",
144 | "seq_length = ipywidgets.IntSlider(\n",
145 | " value=512,\n",
146 | " min=256,\n",
147 | " max=2048,\n",
148 | " step=256,\n",
149 | " description='Sequence length:',\n",
150 | " disabled=False,\n",
151 | " continuous_update=False,\n",
152 | " orientation='horizontal',\n",
153 | " readout=True,\n",
154 | " readout_format='d',\n",
155 | " style=style\n",
156 | ")\n",
157 | "seq_length.default_value = 512\n",
158 | "\n",
159 | "topk = ipywidgets.IntSlider(\n",
160 | " value=60,\n",
161 | " min=1,\n",
162 | " max=300,\n",
163 | " step=1,\n",
164 | " description='Top k:',\n",
165 | " disabled=False,\n",
166 | " continuous_update=False,\n",
167 | " orientation='horizontal',\n",
168 | " readout=True,\n",
169 | " readout_format='d',\n",
170 | " style=style\n",
171 | ")\n",
172 | "topk.default_value = 60\n",
173 | "\n",
174 | "at_least_k = ipywidgets.IntSlider(\n",
175 | " value=1,\n",
176 | " min=1,\n",
177 | " max=300,\n",
178 | " step=1,\n",
179 | " description='At least k:',\n",
180 | " disabled=False,\n",
181 | " continuous_update=False,\n",
182 | " orientation='horizontal',\n",
183 | " readout=True,\n",
184 | " readout_format='d',\n",
185 | " style=style\n",
186 | ")\n",
187 | "at_least_k.default_value = 1\n",
188 | "\n",
189 | "topp = ipywidgets.FloatSlider( \n",
190 | " value=0.99,\n",
191 | " min=0.5,\n",
192 | " max=1.0,\n",
193 | " step=0.01,\n",
194 | " description='Topp:',\n",
195 | " disabled=False,\n",
196 | " continuous_update=False,\n",
197 | " orientation='horizontal',\n",
198 | " readout=True,\n",
199 | " readout_format='.2f',\n",
200 | " style=style\n",
201 | ")\n",
202 | "topp.default_value = 0.99\n",
203 | "\n",
204 | "topp_temperature = ipywidgets.FloatSlider( \n",
205 | " value=1.0,\n",
206 | " min=0.01,\n",
207 | " max=5.0,\n",
208 | " step=0.01,\n",
209 | " description='Topp Temperature:',\n",
210 | " disabled=False,\n",
211 | " continuous_update=False,\n",
212 | " orientation='horizontal',\n",
213 | " readout=True,\n",
214 | " readout_format='.2f',\n",
215 | " style=style\n",
216 | ")\n",
217 | "topp_temperature.default_value = 1.0\n",
218 | "\n",
219 | "use_rp = ipywidgets.Checkbox(\n",
220 | " value=False,\n",
221 | " description='Use Repetition Penalty',\n",
222 | " disabled=False,\n",
223 | " style=style\n",
224 | ")\n",
225 | "use_rp.default_value = False\n",
226 | "\n",
227 | "rp_penalty = ipywidgets.FloatSlider( \n",
228 | " value=0.05,\n",
229 | " min=0.,\n",
230 | " max=1.0,\n",
231 | " step=0.05,\n",
232 | " description='RP penalty:',\n",
233 | " disabled=False,\n",
234 | " continuous_update=False,\n",
235 | " orientation='horizontal',\n",
236 | " readout=True,\n",
237 | " readout_format='.2f',\n",
238 | " style=style\n",
239 | ")\n",
240 | "rp_penalty.default_value = 0.05\n",
241 | "\n",
242 | "restore_speed = ipywidgets.FloatSlider( \n",
243 | " value=0.7,\n",
244 | " min=0.,\n",
245 | " max=1.0,\n",
246 | " step=0.05,\n",
247 | " description='Restore speed:',\n",
248 | " disabled=False,\n",
249 | " continuous_update=False,\n",
250 | " orientation='horizontal',\n",
251 | " readout=True,\n",
252 | " readout_format='.2f',\n",
253 | " style=style\n",
254 | ")\n",
255 | "restore_speed.default_value = 0.7\n",
256 | "\n",
257 | "remove_bad_generations = ipywidgets.Checkbox(\n",
258 | " value=True,\n",
259 | " description='Remove bad generations',\n",
260 | " disabled=False,\n",
261 | " style=style\n",
262 | ")\n",
263 | "remove_bad_generations.default_value = True\n",
264 | "\n",
265 | "defaulting_widgets = (genre_to_generate, seed, b_size, seq_length, remove_bad_generations,\n",
266 | " temp, topk, at_least_k, topp, topp_temperature, use_rp, rp_penalty, restore_speed)\n",
267 | "def set_classic(button):\n",
268 | " for idx, widget in enumerate(defaulting_widgets):\n",
269 | " if idx == 0:\n",
270 | " widget.value = 'classic'\n",
271 | " elif idx == 5:\n",
272 | " widget.value = 1.0\n",
273 | " elif idx == 7:\n",
274 | " widget.value = 1\n",
275 | " elif idx == 8:\n",
276 | " widget.value = 0.99\n",
277 | " else:\n",
278 | " widget.value = widget.default_value\n",
279 | "\n",
280 | "def set_calm(button):\n",
281 | " for idx, widget in enumerate(defaulting_widgets):\n",
282 | " if idx == 0:\n",
283 | " widget.value = 'calm'\n",
284 | " elif idx == 5:\n",
285 | " widget.value = 1.03\n",
286 | " elif idx == 7:\n",
287 | " widget.value = 4\n",
288 | " elif idx == 8:\n",
289 | " widget.value = 0.98\n",
290 | " else:\n",
291 | " widget.value = widget.default_value\n",
292 | "\n",
293 | "def set_jazz(button):\n",
294 | " for idx, widget in enumerate(defaulting_widgets):\n",
295 | " if idx == 0:\n",
296 | " widget.value = 'jazz'\n",
297 | " elif idx == 5:\n",
298 | " widget.value = 0.99\n",
299 | " elif idx == 7:\n",
300 | " widget.value = 1\n",
301 | " elif idx == 8:\n",
302 | " widget.value = 0.99\n",
303 | " else:\n",
304 | " widget.value = widget.default_value\n",
305 | "\n",
306 | "def set_pop(button):\n",
307 | " for idx, widget in enumerate(defaulting_widgets):\n",
308 | " if idx == 0:\n",
309 | " widget.value = 'pop'\n",
310 | " elif idx == 5:\n",
311 | " widget.value = 0.98\n",
312 | " elif idx == 7:\n",
313 | " widget.value = 4\n",
314 | " elif idx == 8:\n",
315 | " widget.value = 0.99\n",
316 | " else:\n",
317 | " widget.value = widget.default_value\n",
318 | "\n",
319 | "calm_value_button = ipywidgets.Button(description='Calm')\n",
320 | "calm_value_button.on_click(set_calm)\n",
321 | "jazz_value_button = ipywidgets.Button(description='Jazz')\n",
322 | "jazz_value_button.on_click(set_jazz)\n",
323 | "pop_value_button = ipywidgets.Button(description='Pop')\n",
324 | "pop_value_button.on_click(set_pop)\n",
325 | "classic_value_button = ipywidgets.Button(description='Classic')\n",
326 | "classic_value_button.on_click(set_classic)\n",
327 | "buttons = (calm_value_button, jazz_value_button, pop_value_button, classic_value_button)\n",
328 | "left_box = ipywidgets.VBox((buttons[0], buttons[1]))\n",
329 | "right_box = ipywidgets.VBox((buttons[2], buttons[3]))\n",
330 | "\n",
331 | "def update_dropdown():\n",
332 | " select_primer_widget.options = ['None'] + sorted(map(str, Path('primers').glob('*.*')))\n",
333 | "\n",
334 | "def on_upload(x):\n",
335 | " x = x['new']\n",
336 | " names = x.keys()\n",
337 | " os.makedirs('primers', exist_ok=True)\n",
338 | " for name in names:\n",
339 | " content = x[name]['content']\n",
340 | " if content:\n",
341 | " path = 'primers/'+name\n",
342 | " with open(path, \"wb\") as fp:\n",
343 | " fp.write(content)\n",
344 | " try:\n",
345 | " pretty_midi.PrettyMIDI(path)\n",
346 | " except:\n",
347 | " print(f'file \"{name}\" is corrupted or not a MIDI!')\n",
348 | " os.remove(path)\n",
349 | " update_dropdown()\n",
350 | "\n",
351 | "select_primer_widget = ipywidgets.Dropdown()\n",
352 | "update_dropdown()\n",
353 | "\n",
354 | "uploader = ipywidgets.FileUpload(description='Upload MIDI', multiple=True)\n",
355 | "uploader.observe(on_upload, names='value')\n",
356 | "\n",
357 | "primer_len = ipywidgets.FloatSlider(\n",
358 | " value=15,\n",
359 | " min=0,\n",
360 | " max=60,\n",
361 | " step=0.1,\n",
362 | " description='Seconds:',\n",
363 | " continuous_update=False,\n",
364 | " style=style\n",
365 | ")\n",
366 | "\n",
367 | "primer_position = ipywidgets.RadioButtons(options=[['From start',0],['From end',1]], orientation='horizontal')\n",
368 | "\n",
369 | "AppLayout(header=ipywidgets.HBox([Label('Select primer:'), select_primer_widget, uploader, primer_len, primer_position]),\n",
370 | " left_sidebar=VBox([genre_to_generate, seed, b_size, seq_length, remove_bad_generations]),\n",
371 | " center=VBox([temp, topk, at_least_k, topp, topp_temperature]),\n",
372 | " right_sidebar=VBox([use_rp, rp_penalty, restore_speed, ipywidgets.HBox((left_box, right_box))]),\n",
373 | " footer=None,\n",
374 | " pane_widths=[1, 1, 1],\n",
375 | " pane_heights=[1, 5, '60px'])"
376 | ]
377 | },
378 | {
379 | "cell_type": "code",
380 | "execution_count": null,
381 | "metadata": {
382 | "cellView": "form",
383 | "id": "XPdirMXYoign"
384 | },
385 | "outputs": [],
386 | "source": [
387 | "#@title 2. Start generating\n",
388 | "load_path = '/content/model_finetune_700k.pt'\n",
389 | "out_dir = 'generated_' + time.strftime('%d-%m-%Y_%H-%M-%S') \n",
390 | "device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'\n",
391 | "\n",
392 | "if device == 'cpu':\n",
393 | " print('Generating on CPU. Expect lower generation speed.')\n",
394 | "\n",
395 | "os.makedirs(out_dir, exist_ok=True)\n",
396 | "genre_id = genre2id[genre_to_generate.value]\n",
397 | "\n",
398 | "params = dict(\n",
399 | " target_seq_length = seq_length.value,\n",
400 | " temperature = temp.value,\n",
401 | " topk = topk.value,\n",
402 | " topp = topp.value,\n",
403 | " topp_temperature = topp_temperature.value,\n",
404 | " at_least_k = at_least_k.value,\n",
405 | " use_rp = use_rp.value,\n",
406 | " rp_penalty = rp_penalty.value,\n",
407 | " rp_restore_speed = restore_speed.value,\n",
408 | " seed = int(seed.value) if seed.value else None\n",
409 | ")\n",
410 | "max_primer_tokens = 512\n",
411 | "\n",
412 | "# Init model\n",
413 | "print('loading model...')\n",
414 | "model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval()\n",
415 | "model.load_state_dict(torch.load(load_path, map_location=device))\n",
416 | "\n",
417 | "# Add genre and primer\n",
418 | "primer_genre = np.repeat([genre_id], b_size.value)[:,None] + constants.VOCAB_SIZE - 4\n",
419 | "if primer_len.value > 0 and select_primer_widget.value != 'None':\n",
420 | " file = select_primer_widget.value\n",
421 | " midi = pretty_midi.PrettyMIDI(file)\n",
422 | " from_end = primer_position.value == 1\n",
423 | " truncate_midi(midi, primer_len.value, from_end) # truncate to specified length (in seconds)\n",
424 | " encoded = midi_processing.encode(midi)\n",
425 | " # truncate to max_primer_tokens (in tokens)\n",
426 | " l = len(encoded)\n",
427 | " if l > max_primer_tokens:\n",
428 | " import warnings\n",
429 | " warnings.warn('Primer MIDI is too long (length > 512), it will be truncated to 512!')\n",
430 | " if from_end:\n",
431 | " encoded = encoded[l-max_primer_tokens:]\n",
432 | " else:\n",
433 | " encoded = encoded[:max_primer_tokens]\n",
434 | " primer_seq = np.repeat(np.array(encoded)[None], b_size.value, 0)\n",
435 | " primer = np.concatenate([primer_genre, primer_seq], -1)\n",
436 | "else:\n",
437 | " primer = primer_genre\n",
438 | "primer = torch.tensor(primer, dtype=torch.int64)\n",
439 | "\n",
440 | "# Generation\n",
441 | "if primer.shape[-1] >= seq_length.value:\n",
442 | " print('Nothing to generate. Try to set larger \"Sequence length\" parameter!')\n",
443 | "else:\n",
444 | " while len(glob.glob(out_dir + '/*.mid')) != b_size.value:\n",
445 | " print('generating to:', os.path.abspath(out_dir))\n",
446 | " generated = generation.generate(model, primer, **params)\n",
447 | " generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations.value)\n",
448 | " decode_and_write(generated, primer, primer_genre.squeeze(-1)-390, out_dir)\n",
449 | " files_to_delete = len(glob.glob(out_dir + '/*.mid')) - b_size.value\n",
450 | " if files_to_delete > 0:\n",
451 | " for idx in range(files_to_delete):\n",
452 | " os.remove(sorted(glob.glob(out_dir + '/*.mid'), key=lambda x: int(x.split('_')[-2]))[-idx])\n",
453 | " for midi_name in glob.glob(out_dir + '/*.mid'):\n",
454 | " convert_midi_to_wav(midi_name)"
455 | ]
456 | },
457 | {
458 | "cell_type": "code",
459 | "execution_count": null,
460 | "metadata": {
461 | "cellView": "form",
462 | "id": "IHRrMDeyk2GI"
463 | },
464 | "outputs": [],
465 | "source": [
466 | "#@title 3. Listen and download the generation results\n",
467 | "wav_files = glob.glob(out_dir + '/*.wav')\n",
468 | "if len(wav_files) > 3:\n",
469 | " rows = round(len(wav_files) / 3) + 1\n",
470 | " columns = 3\n",
471 | "else:\n",
472 | " rows = 1\n",
473 | " columns = len(wav_files)\n",
474 | "grid = widgets.Grid(rows, columns)\n",
475 | "current_position = 0\n",
476 | "for row in range(rows):\n",
477 | " if current_position > len(wav_files) - 1:\n",
478 | " break\n",
479 | " for column in range(columns):\n",
480 | " if current_position > len(wav_files) - 1:\n",
481 | " break\n",
482 | " with grid.output_to(row, column):\n",
483 | " genre = rugenre[wav_files[current_position].split('.wav')[0].split('_')[-1]]\n",
484 |     "      print(f'Track number: {current_position + 1}')\n",
485 |     "      print(f'Genre: {genre}')\n",
486 |     "      display(Audio(wav_files[current_position]))\n",
487 |     "      display(DownloadButton(filename=wav_files[current_position], description='Download .wav'))\n",
488 |     "      display(DownloadButton(filename=wav_files[current_position].replace('.wav', '.mid'), description='Download .midi'))\n",
489 | " current_position += 1"
490 | ]
491 | }
492 | ],
493 | "metadata": {
494 | "accelerator": "GPU",
495 | "colab": {
496 | "collapsed_sections": [],
497 | "name": "Music_Composer_Demo_Colab.ipynb",
498 | "private_outputs": true,
499 | "provenance": []
500 | },
501 | "kernelspec": {
502 | "display_name": "Python 3",
503 | "language": "python",
504 | "name": "python3"
505 | },
506 | "language_info": {
507 | "codemirror_mode": {
508 | "name": "ipython",
509 | "version": 3
510 | },
511 | "file_extension": ".py",
512 | "mimetype": "text/x-python",
513 | "name": "python",
514 | "nbconvert_exporter": "python",
515 | "pygments_lexer": "ipython3",
516 | "version": "3.8.5"
517 | }
518 | },
519 | "nbformat": 4,
520 | "nbformat_minor": 1
521 | }
522 |
--------------------------------------------------------------------------------
/src/Music_Composer_Demo_Colab_ru.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Music_Composer_Demo_Colab.ipynb",
7 | "private_outputs": true,
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "language": "python",
14 | "name": "python3"
15 | },
16 | "language_info": {
17 | "codemirror_mode": {
18 | "name": "ipython",
19 | "version": 3
20 | },
21 | "file_extension": ".py",
22 | "mimetype": "text/x-python",
23 | "name": "python",
24 | "nbconvert_exporter": "python",
25 | "pygments_lexer": "ipython3",
26 | "version": "3.8.5"
27 | },
28 | "accelerator": "GPU"
29 | },
30 | "cells": [
31 | {
32 | "cell_type": "code",
33 | "metadata": {
34 | "id": "S_i1BkSBikEs",
35 | "cellView": "form"
36 | },
37 | "source": [
38 | "#@title 0. Подготовка окружения к работе\n",
39 | "#@markdown Тут просто скачиваются и импортируются библиотеки, ничего интересного\n",
40 | "!pip install gdown > /dev/null\n",
41 | "!apt install fluidsynth > /dev/null\n",
42 | "!pip install midi2audio > /dev/null\n",
43 | "!pip install pretty_midi > /dev/null\n",
44 | "!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 ./font.sf2\n",
45 | "!git clone https://github.com/sberbank-ai/music-composer.git\n",
46 | "\n",
47 | "import os\n",
48 | "import sys\n",
49 | "import time\n",
50 | "import glob\n",
51 | "import torch\n",
52 | "import gdown\n",
53 | "import webbrowser\n",
54 | "import ipywidgets\n",
55 | "import pretty_midi\n",
56 | "import numpy as np\n",
57 | "\n",
58 | "from tqdm import tqdm\n",
59 | "from pathlib import Path\n",
60 | "from midi2audio import FluidSynth\n",
61 | "from google.colab import widgets, files\n",
62 | "from IPython.display import Audio, display, FileLink, HTML\n",
63 | "from ipywidgets import AppLayout, HBox, VBox, Label\n",
64 | "\n",
65 | "url = 'https://drive.google.com/u/0/uc?id=1lcNp0y4IZMIos0ASSsERG25WVDuEmLks'\n",
66 | "output = 'model_finetune_700k.pt'\n",
67 | "gdown.download(url, output, quiet=True)\n",
68 | "\n",
69 | "\n",
70 | "sys.path.append('/content/music-composer/src/')\n",
71 | "from lib import constants\n",
72 | "from lib import generation\n",
73 | "from lib import midi_processing\n",
74 | "from lib.midi_processing import PIANO_RANGE\n",
75 | "from lib.model.transformer import MusicTransformer\n",
76 | "from lib.colab_utils import id2genre, rugenre, genre2id, decode_and_write, convert_midi_to_wav, DownloadButton"
77 | ],
78 | "execution_count": null,
79 | "outputs": []
80 | },
81 | {
82 | "cell_type": "code",
83 | "metadata": {
84 | "id": "Yn4cWWoOJHfi",
85 | "cellView": "form"
86 | },
87 | "source": [
88 | "#@title 1. Панель управления генерацией\n",
89 | "#@markdown В данной панели можно настроить параметры генерации музыки под себя. Пройдемся по всем пунктам:\n",
90 | "#@markdown * Вы можете выбрать __primer__ - это ваш MIDI файл, который будет взят за начало композиции. Если он длинный, то произойдет его обрезка. Укажите желаемую длину (в секундах) и способ его обрезки - с начала, или с конца (From start / From end). __From start__ означает, что мы оставим первые N сек, __From end__ - посление N сек.\n",
91 | "#@markdown * __Genre__ - позволяет выбрать жанр в котором будут сгенерированы треки. Для каждого из жанров мы подобрали свой набор параметров, который позволяет лучше всего его генерировать, но об этом позднее. \n",
92 | "#@markdown * __Seed__ - позволяет зафиксировать сид с которым генерируются треки. Его фиксация может быть полезна если вы хотите изучить влияние того или иного параметра на генерацию - задайте численный сид и генерируйте несколько раз изменяя изучаемый параметр. В остальных случаях рекомендуется не заполнять данное поле. \n",
93 | "#@markdown * __Batch_size__ - количество одновременно генерируемых треков. Если переборщить - может не хватить памяти. Тут нужно подбирать баланс между кол-вом одновременных треков и их длиной. \n",
94 | "#@markdown * __Sequence length__ - длина трека. В связи с особенностями генерации мы не можем заранее сказать сколько по времени будет длиться трек, но увеличение этого параметра потенциально увеличивает длину трека. Именно с этим параметром надо балансировать Batch_size чтобы не закончилась видеопамять\n",
95 | "#@markdown * __Remove bad generations__ - иногда модель начинает генерировать мусор. Если данный флаг включен - мы постараемся его обнаружить, отсеять и сгенерировать вместо него новые композиции. Соответственно увеличивает время генерации.\n",
96 | "#@markdown * __Temperature__ - температурный скейлинг. Влияет на само распределение вероятностей, делая его более либо менее равновероятным. Тем самым регулируется разнообразие генерации.\n",
97 | "#@markdown * __TopK__ - ограничение набора по верхней границе. Сэмплирование происходит из набора k наиболее вероятных токенов.\n",
98 | "#@markdown * __At least K__ - ограничение набора по нижней границе. Сэмплирование будет гарантированно происходить как минимум из at_least_k наиболее вероятных токенов.\n",
99 | "#@markdown * __TopP__ - параметр topp о котором можно подробнее прочитать в статье.\n",
100 | "#@markdown * __TopP Temperature__ - температурный скейлинг применяющийся после отбора по критерию topp.\n",
101 | "#@markdown * __Use Repetition Penalty__ - флаг для использования штрафов за повторы нот. Рекомендуется включать только если модель генерирует цикличные скучные треки.\n",
102 | "#@markdown * __RP Penalty__ - размер штрафов за повтор нот. Чем выше тем меньше модель зацикливается в одних музыкальных фразах.\n",
103 | "#@markdown * __Restore speed__ - скорость восстановления после штрафов модуля repetition penalty \n",
104 | "\n",
105 | "#@markdown Также в правом нижнем углу доступны кнопки по жанрам. С помощью них можно вернуться к подобранным нами параметрам генерации под каждый из жанров.\n",
106 | "\n",
107 | "\n",
108 | "def truncate_midi(midi, primer_len_sec=15.0, from_end=False):\n",
109 | " time0 = max([inst.notes[-1].end for inst in midi.instruments]) if from_end else 0\n",
110 | "\n",
111 | " for inst in midi.instruments:\n",
112 | " notes = sorted(inst.notes, key=lambda x: x.start, reverse=from_end)\n",
113 | " for i,note in enumerate(notes):\n",
114 | " if np.abs(note.start - time0) > primer_len_sec:\n",
115 | " break\n",
116 | " if not from_end and note.end > primer_len_sec:\n",
117 | " note.end = time0 - (int(from_end)*2-1) * primer_len_sec\n",
118 | " inst.notes = notes[:i]\n",
119 | " if from_end:\n",
120 | " inst.notes = inst.notes[::-1]\n",
121 | "\n",
122 | "\n",
123 | "style = {'description_width': 'initial'}\n",
124 | "genre_to_generate = ipywidgets.Dropdown(\n",
125 | " options=['calm', 'jazz', 'pop', 'classic'],\n",
126 | " value='calm',\n",
127 | " description='Genre:',\n",
128 | " disabled=False,\n",
129 | ")\n",
130 | "genre_to_generate.default_value = 'calm'\n",
131 | "\n",
132 | "seed = ipywidgets.Text(\n",
133 | " value='',\n",
134 | " placeholder='leave blank for random seed',\n",
135 | " description='Seed:',\n",
136 | " disabled=False\n",
137 | ")\n",
138 | "seed.default_value = ''\n",
139 | "\n",
140 | "temp = ipywidgets.FloatSlider( \n",
141 | " value=1.0,\n",
142 | " min=0.01,\n",
143 | " max=5.0,\n",
144 | " step=0.01,\n",
145 | " description='Temperature:',\n",
146 | " disabled=False,\n",
147 | " continuous_update=False,\n",
148 | " orientation='horizontal',\n",
149 | " readout=True,\n",
150 | " readout_format='.2f',\n",
151 | " style=style\n",
152 | ")\n",
153 | "temp.default_value = 1.0\n",
154 | "\n",
155 | "b_size = ipywidgets.IntSlider(\n",
156 | " value=8,\n",
157 | " min=1,\n",
158 | " max=16,\n",
159 | " step=1,\n",
160 | " description='Batch size:',\n",
161 | " disabled=False,\n",
162 | " continuous_update=False,\n",
163 | " orientation='horizontal',\n",
164 | " readout=True,\n",
165 | " readout_format='d',\n",
166 | " style=style\n",
167 | ")\n",
168 | "b_size.default_value = 8\n",
169 | "\n",
170 | "seq_length = ipywidgets.IntSlider(\n",
171 | " value=512,\n",
172 | " min=256,\n",
173 | " max=2048,\n",
174 | " step=256,\n",
175 | " description='Sequence length:',\n",
176 | " disabled=False,\n",
177 | " continuous_update=False,\n",
178 | " orientation='horizontal',\n",
179 | " readout=True,\n",
180 | " readout_format='d',\n",
181 | " style=style\n",
182 | ")\n",
183 | "seq_length.default_value = 512\n",
184 | "\n",
185 | "topk = ipywidgets.IntSlider(\n",
186 | " value=60,\n",
187 | " min=1,\n",
188 | " max=300,\n",
189 | " step=1,\n",
190 | " description='Top k:',\n",
191 | " disabled=False,\n",
192 | " continuous_update=False,\n",
193 | " orientation='horizontal',\n",
194 | " readout=True,\n",
195 | " readout_format='d',\n",
196 | " style=style\n",
197 | ")\n",
198 | "topk.default_value = 60\n",
199 | "\n",
200 | "at_least_k = ipywidgets.IntSlider(\n",
201 | " value=1,\n",
202 | " min=1,\n",
203 | " max=300,\n",
204 | " step=1,\n",
205 | " description='At least k:',\n",
206 | " disabled=False,\n",
207 | " continuous_update=False,\n",
208 | " orientation='horizontal',\n",
209 | " readout=True,\n",
210 | " readout_format='d',\n",
211 | " style=style\n",
212 | ")\n",
213 | "at_least_k.default_value = 1\n",
214 | "\n",
215 | "topp = ipywidgets.FloatSlider( \n",
216 | " value=0.99,\n",
217 | " min=0.5,\n",
218 | " max=1.0,\n",
219 | " step=0.01,\n",
220 | " description='Topp:',\n",
221 | " disabled=False,\n",
222 | " continuous_update=False,\n",
223 | " orientation='horizontal',\n",
224 | " readout=True,\n",
225 | " readout_format='.2f',\n",
226 | " style=style\n",
227 | ")\n",
228 | "topp.default_value = 0.99\n",
229 | "\n",
230 | "topp_temperature = ipywidgets.FloatSlider( \n",
231 | " value=1.0,\n",
232 | " min=0.01,\n",
233 | " max=5.0,\n",
234 | " step=0.01,\n",
235 | " description='Topp Temperature:',\n",
236 | " disabled=False,\n",
237 | " continuous_update=False,\n",
238 | " orientation='horizontal',\n",
239 | " readout=True,\n",
240 | " readout_format='.2f',\n",
241 | " style=style\n",
242 | ")\n",
243 | "topp_temperature.default_value = 1.0\n",
244 | "\n",
245 | "use_rp = ipywidgets.Checkbox(\n",
246 | " value=False,\n",
247 | " description='Use Repetition Penalty',\n",
248 | " disabled=False,\n",
249 | " style=style\n",
250 | ")\n",
251 | "use_rp.default_value = False\n",
252 | "\n",
253 | "rp_penalty = ipywidgets.FloatSlider( \n",
254 | " value=0.05,\n",
255 | " min=0.,\n",
256 | " max=1.0,\n",
257 | " step=0.05,\n",
258 | " description='RP penalty:',\n",
259 | " disabled=False,\n",
260 | " continuous_update=False,\n",
261 | " orientation='horizontal',\n",
262 | " readout=True,\n",
263 | " readout_format='.2f',\n",
264 | " style=style\n",
265 | ")\n",
266 | "rp_penalty.default_value = 0.05\n",
267 | "\n",
268 | "restore_speed = ipywidgets.FloatSlider( \n",
269 | " value=0.7,\n",
270 | " min=0.,\n",
271 | " max=1.0,\n",
272 | " step=0.05,\n",
273 | " description='Restore speed:',\n",
274 | " disabled=False,\n",
275 | " continuous_update=False,\n",
276 | " orientation='horizontal',\n",
277 | " readout=True,\n",
278 | " readout_format='.2f',\n",
279 | " style=style\n",
280 | ")\n",
281 | "restore_speed.default_value = 0.7\n",
282 | "\n",
283 | "remove_bad_generations = ipywidgets.Checkbox(\n",
284 | " value=True,\n",
285 | " description='Remove bad generations',\n",
286 | " disabled=False,\n",
287 | " style=style\n",
288 | ")\n",
289 | "remove_bad_generations.default_value = True\n",
290 | "\n",
291 | "defaulting_widgets = (genre_to_generate, seed, b_size, seq_length, remove_bad_generations,\n",
292 | " temp, topk, at_least_k, topp, topp_temperature, use_rp, rp_penalty, restore_speed)\n",
293 | "def set_classic(button):\n",
294 | " for idx, widget in enumerate(defaulting_widgets):\n",
295 | " if idx == 0:\n",
296 | " widget.value = 'classic'\n",
297 | " elif idx == 5:\n",
298 | " widget.value = 1.0\n",
299 | " elif idx == 7:\n",
300 | " widget.value = 1\n",
301 | " elif idx == 8:\n",
302 | " widget.value = 0.99\n",
303 | " else:\n",
304 | " widget.value = widget.default_value\n",
305 | "\n",
306 | "def set_calm(button):\n",
307 | " for idx, widget in enumerate(defaulting_widgets):\n",
308 | " if idx == 0:\n",
309 | " widget.value = 'calm'\n",
310 | " elif idx == 5:\n",
311 | " widget.value = 1.03\n",
312 | " elif idx == 7:\n",
313 | " widget.value = 4\n",
314 | " elif idx == 8:\n",
315 | " widget.value = 0.98\n",
316 | " else:\n",
317 | " widget.value = widget.default_value\n",
318 | "\n",
319 | "def set_jazz(button):\n",
320 | " for idx, widget in enumerate(defaulting_widgets):\n",
321 | " if idx == 0:\n",
322 | " widget.value = 'jazz'\n",
323 | " elif idx == 5:\n",
324 | " widget.value = 0.99\n",
325 | " elif idx == 7:\n",
326 | " widget.value = 1\n",
327 | " elif idx == 8:\n",
328 | " widget.value = 0.99\n",
329 | " else:\n",
330 | " widget.value = widget.default_value\n",
331 | "\n",
332 | "def set_pop(button):\n",
333 | " for idx, widget in enumerate(defaulting_widgets):\n",
334 | " if idx == 0:\n",
335 | " widget.value = 'pop'\n",
336 | " elif idx == 5:\n",
337 | " widget.value = 0.98\n",
338 | " elif idx == 7:\n",
339 | " widget.value = 4\n",
340 | " elif idx == 8:\n",
341 | " widget.value = 0.99\n",
342 | " else:\n",
343 | " widget.value = widget.default_value\n",
344 | "\n",
345 | "calm_value_button = ipywidgets.Button(description='Calm')\n",
346 | "calm_value_button.on_click(set_calm)\n",
347 | "jazz_value_button = ipywidgets.Button(description='Jazz')\n",
348 | "jazz_value_button.on_click(set_jazz)\n",
349 | "pop_value_button = ipywidgets.Button(description='Pop')\n",
350 | "pop_value_button.on_click(set_pop)\n",
351 | "classic_value_button = ipywidgets.Button(description='Classic')\n",
352 | "classic_value_button.on_click(set_classic)\n",
353 | "buttons = (calm_value_button, jazz_value_button, pop_value_button, classic_value_button)\n",
354 | "left_box = ipywidgets.VBox((buttons[0], buttons[1]))\n",
355 | "right_box = ipywidgets.VBox((buttons[2], buttons[3]))\n",
356 | "\n",
357 | "def update_dropdown():\n",
358 | " select_primer_widget.options = ['None'] + sorted(map(str, Path('primers').glob('*.*')))\n",
359 | "\n",
360 | "def on_upload(x):\n",
361 | " x = x['new']\n",
362 | " names = x.keys()\n",
363 | " os.makedirs('primers', exist_ok=True)\n",
364 | " for name in names:\n",
365 | " content = x[name]['content']\n",
366 | " if content:\n",
367 | " path = 'primers/'+name\n",
368 | " with open(path, \"wb\") as fp:\n",
369 | " fp.write(content)\n",
370 | " try:\n",
371 | " pretty_midi.PrettyMIDI(path)\n",
372 | " except:\n",
373 | " print(f'file \"{name}\" is corrupted or not a MIDI!')\n",
374 | " os.remove(path)\n",
375 | " update_dropdown()\n",
376 | "\n",
377 | "select_primer_widget = ipywidgets.Dropdown()\n",
378 | "update_dropdown()\n",
379 | "\n",
380 | "uploader = ipywidgets.FileUpload(description='Upload MIDI', multiple=True)\n",
381 | "uploader.observe(on_upload, names='value')\n",
382 | "\n",
383 | "primer_len = ipywidgets.FloatSlider(\n",
384 | " value=15,\n",
385 | " min=0,\n",
386 | " max=60,\n",
387 | " step=0.1,\n",
388 | " description='Seconds:',\n",
389 | " continuous_update=False,\n",
390 | " style=style\n",
391 | ")\n",
392 | "\n",
393 | "primer_position = ipywidgets.RadioButtons(options=[['From start',0],['From end',1]], orientation='horizontal')\n",
394 | "\n",
395 | "AppLayout(header=ipywidgets.HBox([Label('Select primer:'), select_primer_widget, uploader, primer_len, primer_position]),\n",
396 | " left_sidebar=VBox([genre_to_generate, seed, b_size, seq_length, remove_bad_generations]),\n",
397 | " center=VBox([temp, topk, at_least_k, topp, topp_temperature]),\n",
398 | " right_sidebar=VBox([use_rp, rp_penalty, restore_speed, ipywidgets.HBox((left_box, right_box))]),\n",
399 | " footer=None,\n",
400 | " pane_widths=[1, 1, 1],\n",
401 | " pane_heights=[1, 5, '60px'])"
402 | ],
403 | "execution_count": null,
404 | "outputs": []
405 | },
406 | {
407 | "cell_type": "code",
408 | "metadata": {
409 | "id": "XPdirMXYoign",
410 | "cellView": "form"
411 | },
412 | "source": [
413 | "#@title 2. Запуск процесса генерации\n",
414 | "load_path = '/content/model_finetune_700k.pt'\n",
415 | "out_dir = 'generated_' + time.strftime('%d-%m-%Y_%H-%M-%S') \n",
416 | "device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'\n",
417 | "\n",
418 | "if device == 'cpu':\n",
419 | " print('Generating on CPU. Expect lower generation speed.')\n",
420 | "\n",
421 | "os.makedirs(out_dir, exist_ok=True)\n",
422 | "genre_id = genre2id[genre_to_generate.value]\n",
423 | "\n",
424 | "params = dict(\n",
425 | " target_seq_length = seq_length.value,\n",
426 | " temperature = temp.value,\n",
427 | " topk = topk.value,\n",
428 | " topp = topp.value,\n",
429 | " topp_temperature = topp_temperature.value,\n",
430 | " at_least_k = at_least_k.value,\n",
431 | " use_rp = use_rp.value,\n",
432 | " rp_penalty = rp_penalty.value,\n",
433 | " rp_restore_speed = restore_speed.value,\n",
434 | " seed = int(seed.value) if seed.value else None\n",
435 | ")\n",
436 | "max_primer_tokens = 512\n",
437 | "\n",
438 | "# Init model\n",
439 | "print('loading model...')\n",
440 | "model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval()\n",
441 | "model.load_state_dict(torch.load(load_path, map_location=device))\n",
442 | "\n",
443 | "# Add genre and primer\n",
444 | "primer_genre = np.repeat([genre_id], b_size.value)[:,None] + constants.VOCAB_SIZE - 4\n",
445 | "if primer_len.value > 0 and select_primer_widget.value != 'None':\n",
446 | " file = select_primer_widget.value\n",
447 | " midi = pretty_midi.PrettyMIDI(file)\n",
448 | " from_end = primer_position.value == 1\n",
449 | " truncate_midi(midi, primer_len.value, from_end) # truncate to specified length (in seconds)\n",
450 | " encoded = midi_processing.encode(midi)\n",
451 | " # truncate to max_primer_tokens (in tokens)\n",
452 | " l = len(encoded)\n",
453 | " if l > max_primer_tokens:\n",
454 | " import warnings\n",
455 | " warnings.warn('Primer MIDI is too long (length > 512), it will be truncated to 512!')\n",
456 | " if from_end:\n",
457 | " encoded = encoded[l-max_primer_tokens:]\n",
458 | " else:\n",
459 | " encoded = encoded[:max_primer_tokens]\n",
460 | " primer_seq = np.repeat(np.array(encoded)[None], b_size.value, 0)\n",
461 | " primer = np.concatenate([primer_genre, primer_seq], -1)\n",
462 | "else:\n",
463 | " primer = primer_genre\n",
464 | "primer = torch.tensor(primer, dtype=torch.int64)\n",
465 | "\n",
466 | "# Generation\n",
467 | "if primer.shape[-1] >= seq_length.value:\n",
468 | " print('Nothing to generate. Try to set larger \"Sequence length\" parameter!')\n",
469 | "else:\n",
470 | " while len(glob.glob(out_dir + '/*.mid')) != b_size.value:\n",
471 | " print('generating to:', os.path.abspath(out_dir))\n",
472 | " generated = generation.generate(model, primer, **params)\n",
473 | " generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations.value)\n",
474 | " decode_and_write(generated, primer, primer_genre.squeeze(-1)-390, out_dir)\n",
475 | " files_to_delete = len(glob.glob(out_dir + '/*.mid')) - b_size.value\n",
476 | " if files_to_delete > 0:\n",
477 | " for idx in range(files_to_delete):\n",
478 | " os.remove(sorted(glob.glob(out_dir + '/*.mid'), key=lambda x: int(x.split('_')[-2]))[-idx])\n",
479 | " for midi_name in glob.glob(out_dir + '/*.mid'):\n",
480 | " convert_midi_to_wav(midi_name)"
481 | ],
482 | "execution_count": null,
483 | "outputs": []
484 | },
485 | {
486 | "cell_type": "code",
487 | "metadata": {
488 | "cellView": "form",
489 | "id": "IHRrMDeyk2GI"
490 | },
491 | "source": [
492 | "#@title 3. Прослушать и скачать результаты генерации\n",
493 | "wav_files = glob.glob(out_dir + '/*.wav')\n",
494 | "if len(wav_files) > 3:\n",
495 | " rows = round(len(wav_files) / 3) + 1\n",
496 | " columns = 3\n",
497 | "else:\n",
498 | " rows = 1\n",
499 | " columns = len(wav_files)\n",
500 | "grid = widgets.Grid(rows, columns)\n",
501 | "current_position = 0\n",
502 | "for row in range(rows):\n",
503 | " if current_position > len(wav_files) - 1:\n",
504 | " break\n",
505 | " for column in range(columns):\n",
506 | " if current_position > len(wav_files) - 1:\n",
507 | " break\n",
508 | " with grid.output_to(row, column):\n",
509 | " genre = rugenre[wav_files[current_position].split('.wav')[0].split('_')[-1]]\n",
510 | " print(f'Номер трека: {current_position + 1}')\n",
511 | " print(f'Жанр: {genre}')\n",
512 | " display(Audio(wav_files[current_position]))\n",
513 | " display(DownloadButton(filename=wav_files[current_position], description='Скачать .wav'))\n",
514 | " display(DownloadButton(filename=wav_files[current_position].replace('.wav', '.mid'), description='Скачать .midi'))\n",
515 | " current_position += 1"
516 | ],
517 | "execution_count": null,
518 | "outputs": []
519 | }
520 | ]
521 | }
--------------------------------------------------------------------------------
/gpt2-rga/[GPT2RGA] Quantum_Music.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "gradient": {
7 | "editing": false,
8 | "id": "ac5a4cf0-d9d2-47b5-9633-b53f8d99a4d2",
9 | "kernelId": ""
10 | },
11 | "id": "SiTIpPjArIyr"
12 | },
13 | "source": [
14 | "# Quantum Music (ver. 1.0)\n",
15 | "\n",
16 | "## Tokenized Sparse Time Quantization Example\n",
17 | "\n",
18 | "***\n",
19 | "\n",
20 | "Powered by tegridy-tools TMIDIX Optimus Processors: https://github.com/asigalov61/tegridy-tools\n",
21 | "\n",
22 | "***\n",
23 | "\n",
24 | "Credit for GPT2-RGA code used in this colab goes out @ Sashmark97 https://github.com/Sashmark97/midigen and @ Damon Gwinn https://github.com/gwinndr/MusicTransformer-Pytorch\n",
25 | "\n",
26 | "***\n",
27 | "\n",
28 | "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please excercise great humility, care, and respect. https://www.nscai.gov/\n",
29 | "\n",
30 | "***\n",
31 | "\n",
32 | "#### Project Los Angeles\n",
33 | "\n",
34 | "#### Tegridy Code 2021\n",
35 | "\n",
36 | "***"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {
42 | "gradient": {
43 | "editing": false,
44 | "id": "fa0a611c-1803-42ae-bdf6-a49b5a4e781b",
45 | "kernelId": ""
46 | },
47 | "id": "gOd93yV0sGd2"
48 | },
49 | "source": [
50 | "# (Setup Environment)"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {
57 | "cellView": "form",
58 | "gradient": {
59 | "editing": false,
60 | "id": "39411b40-9e39-416e-8fe4-d40f733e7956",
61 | "kernelId": ""
62 | },
63 | "id": "lw-4aqV3sKQG"
64 | },
65 | "outputs": [],
66 | "source": [
67 | "#@title nvidia-smi gpu check\n",
68 | "!nvidia-smi"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {
75 | "cellView": "form",
76 | "gradient": {
77 | "editing": false,
78 | "id": "a1a45a91-d909-4fd4-b67a-5e16b971d179",
79 | "kernelId": ""
80 | },
81 | "id": "fX12Yquyuihc"
82 | },
83 | "outputs": [],
84 | "source": [
85 | "#@title Install all dependencies (run only once per session)\n",
86 | "\n",
87 | "!git clone https://github.com/asigalov61/tegridy-tools\n",
88 | "!pip install torch\n",
89 | "!pip install tqdm\n",
90 | "!pip install matplotlib"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "metadata": {
97 | "cellView": "form",
98 | "gradient": {
99 | "editing": false,
100 | "id": "b8207b76-9514-4c07-95db-95a4742e52c5",
101 | "kernelId": ""
102 | },
103 | "id": "z7n9vnKmug1J"
104 | },
105 | "outputs": [],
106 | "source": [
107 | "#@title Import all needed modules\n",
108 | "\n",
109 | "print('Loading needed modules. Please wait...')\n",
110 | "import os\n",
111 | "from datetime import datetime\n",
112 | "import secrets\n",
113 | "import copy\n",
114 | "import tqdm\n",
115 | "from tqdm import tqdm\n",
116 | "\n",
117 | "if not os.path.exists('/notebooks/Dataset'):\n",
118 | " os.makedirs('/notebooks/Dataset')\n",
119 | "\n",
120 | "print('Loading TMIDIX module...')\n",
121 | "os.chdir('/notebooks/tegridy-tools/tegridy-tools')\n",
122 | "import TMIDIX\n",
123 | "\n",
124 | "os.chdir('/notebooks/tegridy-tools/tegridy-tools')\n",
125 | "from GPT2RGAX import *\n",
126 | "\n",
127 | "import matplotlib.pyplot as plt\n",
128 | "\n",
129 | "os.chdir('/notebooks/')"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {
135 | "id": "ObPxlEutsQBj"
136 | },
137 | "source": [
138 | "# (FROM SCRATCH) Download and process MIDI dataset"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "metadata": {
145 | "cellView": "form",
146 | "gradient": {
147 | "id": "ffbb7a2a-d91a-477f-ac89-56d77d6cdf42",
148 | "kernelId": ""
149 | },
150 | "id": "snIZ3xKPsPgB"
151 | },
152 | "outputs": [],
153 | "source": [
154 | "#@title Download Endless Violin Carousel MIDI dataset (Recommended)\n",
155 | "\n",
156 | "#@markdown Piano Violin Duo\n",
157 | "\n",
158 | "#@markdown Works best stand-alone/as-is for the optimal results\n",
159 | "%cd /notebooks/Dataset/\n",
160 | "\n",
161 | "!wget 'https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Endless-Violin-Carousel-CC-BY-NC-SA.zip'\n",
162 | "!unzip -j '/notebooks/Dataset/Endless-Violin-Carousel-CC-BY-NC-SA.zip'\n",
163 | "!rm '/notebooks/Dataset/Endless-Violin-Carousel-CC-BY-NC-SA.zip'\n",
164 | "\n",
165 | "%cd /notebooks/"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "metadata": {
172 | "cellView": "form",
173 | "gradient": {
174 | "id": "ed07b44f-07fe-45fb-a64f-adba8df1bdcb",
175 | "kernelId": ""
176 | },
177 | "id": "on7sgKEP3Yc8"
178 | },
179 | "outputs": [],
180 | "source": [
181 | "#@title Process MIDIs to special MIDI dataset with Tegridy MIDI Processor\n",
182 | "\n",
183 | "#@markdown IMPORTANT NOTES:\n",
184 | "\n",
185 | "#@markdown 1) Best results are achieved with the single-track, single-channel, single-instrument MIDI 0 files with plain English names (avoid special or sys/foreign chars)\n",
186 | "\n",
187 | "#@markdown 2) MIDI Channel = -1 means all MIDI channels except the drums. MIDI Channel = 16 means all channels will be processed. Otherwise, only single indicated MIDI channel will be processed\n",
188 | "\n",
189 | "desired_dataset_name = \"Quantum-Music-Dataset\" #@param {type:\"string\"}\n",
190 | "file_name_to_output_dataset_to = \"/notebooks/Quantum-Music-Dataset\" #@param {type:\"string\"}\n",
191 | "desired_MIDI_channel_to_process = -1 #@param {type:\"slider\", min:-1, max:16, step:1}\n",
192 | "sorted_or_random_file_loading_order = False #@param {type:\"boolean\"}\n",
193 | "encode_velocities = True #@param {type:\"boolean\"}\n",
194 | "encode_MIDI_channels = True #@param {type:\"boolean\"}\n",
195 | "add_transposed_dataset_by_this_many_pitches = 0 #@param {type:\"slider\", min:-12, max:12, step:1}\n",
196 | "add_transposed_and_flipped_dataset = False #@param {type:\"boolean\"}\n",
197 | "chordify_input_MIDIs = False #@param {type:\"boolean\"}\n",
198 | "melody_conditioned_chords = False #@param {type:\"boolean\"}\n",
199 | "melody_pitch_baseline = 60 #@param {type:\"slider\", min:0, max:127, step:1}\n",
200 | "time_denominator = 1 #@param {type:\"slider\", min:1, max:50, step:1}\n",
201 | "transform_to_pitch = 0 #@param {type:\"slider\", min:0, max:127, step:1}\n",
202 | "perfect_timings = True #@param {type:\"boolean\"}\n",
203 | "MuseNet_encoding = True #@param {type:\"boolean\"}\n",
204 | "chars_encoding_offset = 0 #@param {type:\"number\"}\n",
205 | "\n",
206 | "print('TMIDI Optimus MIDI Processor')\n",
207 | "print('Starting up...')\n",
208 | "###########\n",
209 | "\n",
210 | "average_note_pitch = 0\n",
211 | "min_note = 127\n",
212 | "max_note = 0\n",
213 | "\n",
214 | "files_count = 0\n",
215 | "\n",
216 | "gfiles = 0\n",
217 | "\n",
218 | "chords_list_f = []\n",
219 | "melody_list_f = []\n",
220 | "\n",
221 | "chords_list = []\n",
222 | "chords_count = 0\n",
223 | "\n",
224 | "melody_chords = []\n",
225 | "melody_count = 0\n",
226 | "\n",
227 | "TXT_String = ''\n",
228 | "\n",
229 | "TXT = ''\n",
230 | "melody = []\n",
231 | "chords = []\n",
232 | "INTS_f = []\n",
233 | "\n",
234 | "flist = []\n",
235 | "\n",
236 | "###########\n",
237 | "\n",
238 | "print('Loading MIDI files...')\n",
239 | "print('This may take a while on a large dataset in particular.')\n",
240 | "\n",
241 | "dataset_addr = \"/notebooks/Dataset/\"\n",
242 | "os.chdir(dataset_addr)\n",
243 | "filez = list()\n",
244 | "for (dirpath, dirnames, filenames) in os.walk(dataset_addr):\n",
245 | " filez += [os.path.join(dirpath, file) for file in filenames]\n",
246 | "print('=' * 70)\n",
247 | "\n",
248 | "if filez == []:\n",
249 | " print('Could not find any MIDI files. Please check Dataset dir...')\n",
250 | " print('=' * 70)\n",
251 | "\n",
252 | "if sorted_or_random_file_loading_order:\n",
253 | " print('Sorting files...')\n",
254 | " filez.sort()\n",
255 | " print('Done!')\n",
256 | " print('=' * 70)\n",
257 | "\n",
258 | "else:\n",
259 | " random.shuffle(filez)\n",
260 | "\n",
261 | "# Stamping the dataset info\n",
262 | "print('Stamping the dataset info...')\n",
263 | "\n",
264 | "TXT_String += 'DATASET=' + str(desired_dataset_name) + chr(10)\n",
265 | "TXT_String += 'CREATED_ON=' + str(datetime.now()).replace(' ', '-').replace(':', '-').replace('.', '-') + chr(10)\n",
266 | "\n",
267 | "TXT_String += 'CHARS_ENCODING_OFFSET=' + str(chars_encoding_offset) + chr(10)\n",
268 | "TXT_String += 'TIME_DENOMINATOR=' + str(time_denominator) + chr(10)\n",
269 | "TXT_String += 'TRANSFORM=' + str(transform_to_pitch) + chr(10)\n",
270 | "TXT_String += 'PERFECT_TIMINGS=' + str(perfect_timings) + chr(10)\n",
271 | "TXT_String += 'MUSENET_ENCODING=' + str(MuseNet_encoding) + chr(10)\n",
272 | "TXT_String += 'TRANSPOSED_BY=' + str(add_transposed_dataset_by_this_many_pitches) + chr(10)\n",
273 | "TXT_String += 'TRANSPOSED_AND_FLIPPED=' + str(add_transposed_and_flipped_dataset) + chr(10)\n",
274 | "\n",
275 | "TXT_String += 'LEGEND=STA-DUR-PTC'\n",
276 | "if encode_velocities:\n",
277 | " TXT_String += '-VEL'\n",
278 | "if encode_MIDI_channels:\n",
279 | " TXT_String += '-CHA'\n",
280 | "TXT_String += chr(10)\n",
281 | "\n",
282 | "print('Processing MIDI files. Please wait...')\n",
283 | "for f in tqdm(filez):\n",
284 | " try:\n",
285 | " fn = os.path.basename(f)\n",
286 | " fn1 = fn.split('.')[0]\n",
287 | "\n",
288 | " files_count += 1\n",
289 | " TXT, melody, chords, bass_melody, karaokez, INTS, aux1, aux2 = TMIDIX.Optimus_MIDI_TXT_Processor(f, chordify_TXT=chordify_input_MIDIs, output_MIDI_channels=encode_MIDI_channels, char_offset=chars_encoding_offset, dataset_MIDI_events_time_denominator=time_denominator, output_velocity=encode_velocities, MIDI_channel=desired_MIDI_channel_to_process, MIDI_patch=range(0, 127), melody_conditioned_encoding=melody_conditioned_chords, melody_pitch_baseline=melody_pitch_baseline, perfect_timings=perfect_timings, musenet_encoding=MuseNet_encoding, transform=transform_to_pitch)\n",
290 | " TXT_String += TXT\n",
291 | " melody_list_f += melody\n",
292 | " chords_list_f.append(chords)\n",
293 | " INTS_f.append(INTS)\n",
294 | " flist.append([f, fn1])\n",
295 | " gfiles += 1\n",
296 | "\n",
297 | " if add_transposed_dataset_by_this_many_pitches != 0:\n",
298 | "\n",
299 | " TXT, melody, chords, bass_melody, karaokez, INTS, aux1, aux2 = TMIDIX.Optimus_MIDI_TXT_Processor(f, chordify_TXT=chordify_input_MIDIs, output_MIDI_channels=encode_MIDI_channels, char_offset=chars_encoding_offset, dataset_MIDI_events_time_denominator=time_denominator, output_velocity=encode_velocities, MIDI_channel=desired_MIDI_channel_to_process, transpose_by=add_transposed_dataset_by_this_many_pitches, MIDI_patch=range(0, 127), melody_conditioned_encoding=melody_conditioned_chords, melody_pitch_baseline=melody_pitch_baseline, perfect_timings=perfect_timings, musenet_encoding=MuseNet_encoding, transform=transform_to_pitch)\n",
300 | " TXT_String += TXT\n",
301 | " melody_list_f += melody\n",
302 | " chords_list_f.append(chords)\n",
303 | " INTS_f.append(INTS)\n",
304 | " gfiles += 1\n",
305 | "\n",
306 | " if add_transposed_and_flipped_dataset == True:\n",
307 | "\n",
308 | " TXT, melody, chords, bass_melody, karaokez, INTS, aux1, aux2 = TMIDIX.Optimus_MIDI_TXT_Processor(f, chordify_TXT=chordify_input_MIDIs, output_MIDI_channels=encode_MIDI_channels, char_offset=chars_encoding_offset, dataset_MIDI_events_time_denominator=time_denominator, output_velocity=encode_velocities, MIDI_channel=desired_MIDI_channel_to_process, transpose_by=-12, MIDI_patch=range(0, 127), flip=True, melody_conditioned_encoding=melody_conditioned_chords, melody_pitch_baseline=melody_pitch_baseline, perfect_timings=perfect_timings, musenet_encoding=MuseNet_encoding, transform=transform_to_pitch)\n",
309 | " TXT_String += TXT\n",
310 | " melody_list_f += melody\n",
311 | " chords_list_f += chords\n",
312 | " INTS_f.append(INTS)\n",
313 | " gfiles += 1\n",
314 | "\n",
315 | " except KeyboardInterrupt:\n",
316 | " print('Saving current progress and quitting...')\n",
317 | " break \n",
318 | " \n",
319 | " except:\n",
320 | " print('Bad MIDI:', f)\n",
321 | " continue\n",
322 | "\n",
323 | "TXT_String += 'TOTAL_SONGS_IN_DATASET=' + str(gfiles)\n",
324 | "\n",
325 | "try:\n",
326 | " print('Task complete :)')\n",
327 | " print('==================================================')\n",
328 | " if add_transposed_dataset_by_this_many_pitches != 0:\n",
329 | " print('NOTE: Transposed dataset was added per users request.')\n",
330 | " print('==================================================')\n",
331 | " if add_transposed_and_flipped_dataset == True:\n",
332 | " print('NOTE: Flipped dataset was added per users request.') \n",
333 | " print('==================================================')\n",
334 | " print('Number of processed dataset MIDI files:', files_count)\n",
335 | " print('Number of MIDI chords recorded:', len(chords_list_f))\n",
336 | " print('First chord event:', chords_list_f[0], 'Last chord event:', chords_list_f[-1]) \n",
337 | " print('Number of recorded melody events:', len(melody_list_f))\n",
338 | " print('First melody event:', melody_list_f[0], 'Last Melody event:', melody_list_f[-1])\n",
339 | " print('Total number of MIDI events recorded:', len(chords_list_f) + len(melody_list_f))\n",
340 | " print('==================================================')\n",
341 | "\n",
342 | " # Writing dataset to TXT file\n",
343 | " with open(file_name_to_output_dataset_to + '.txt', 'wb') as f:\n",
344 | " f.write(TXT_String.encode('utf-8', 'replace'))\n",
345 | " f.close\n",
346 | "\n",
347 | " # Dataset\n",
348 | " MusicDataset = [chords_list_f, melody_list_f, INTS_f]\n",
349 | "\n",
350 | " # Writing dataset to pickle file\n",
351 | " TMIDIX.Tegridy_Any_Pickle_File_Writer(MusicDataset, file_name_to_output_dataset_to)\n",
352 | "\n",
353 | "except:\n",
354 | " print('=' * 70)\n",
355 | " print('IO Error!')\n",
356 | " print('Please check that Dataset dir is not empty/check other IO code.')\n",
357 | " print('=' * 70)\n",
358 | " print('Shutting down...')\n",
359 | " print('=' * 70)"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": null,
365 | "metadata": {
366 | "gradient": {
367 | "id": "0826f622-2edc-4f09-9a01-58df049738d4",
368 | "kernelId": ""
369 | }
370 | },
371 | "outputs": [],
372 | "source": [
373 | "INTS_f1 = []\n",
374 | "\n",
375 | "\n",
376 | "for chords_list in tqdm(chords_list_f):\n",
377 | " INTS_f1.append([-1, -1, -1]) # Intro\n",
378 | " pe = chords_list[0]\n",
379 | " for i in chords_list:\n",
380 | "\n",
381 | " INTS_f1.append([int(abs(i[1]-pe[1])/ 10), int(i[2] / 10), i[4] ])\n",
382 | " \n",
383 | " if chords_list.index(i) == len(chords_list)-50:\n",
384 | " INTS_f1.append([-2, -2, -2]) # Outro\n",
385 | " \n",
386 | " \n",
387 | " pe = i\n",
388 | " INTS_f1.append([-3, -3, -3]) # End"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": null,
394 | "metadata": {
395 | "cellView": "form",
396 | "gradient": {
397 | "id": "53252e52-5e68-4e60-8e4d-a584667749a4",
398 | "kernelId": ""
399 | },
400 | "id": "lT0TyqUnpxu_"
401 | },
402 | "outputs": [],
403 | "source": [
404 | "#@title Load processed INTs datasets\n",
405 | "number_of_batches = 16 #@param {type:\"slider\", min:2, max:32, step:2}\n",
406 | "n_workers = 6\n",
407 | "\n",
408 | "print('=' * 50)\n",
409 | "print('Prepping INTs datasets...')\n",
410 | "\n",
411 | "\n",
412 | "train_data1 = []\n",
413 | "for i in INTS_f1:\n",
414 | " if max(i) < 256 and min(i) >= 0:\n",
415 | "\n",
416 | " if i[0] < 16:\n",
417 | " train_data1.extend([i[0]])\n",
418 | " else:\n",
419 | " train_data1.extend([16, i[0]-16])\n",
420 | " \n",
421 | " train_data1.extend([256+i[2], 512+i[1]-4 ])\n",
422 | " \n",
423 | " if max(i) == -1 and min(i) == -1: # Intro\n",
424 | " train_data1.extend([256+512-3])\n",
425 | " \n",
426 | " if max(i) == -2 and min(i) == -2: # Outro\n",
427 | " train_data1.extend([256+512-2])\n",
428 | " \n",
429 | " if max(i) == -3 and min(i) == -3: # End\n",
430 | " train_data1.extend([256+512-1])\n",
431 | "\n",
432 | "train_data = train_data1[:int(len(train_data1) / 3)]\n",
433 | "\n",
434 | "val_dataset = train_data[:int(len(train_data) * 0.03)]\n",
435 | "test_dataset = train_data[:int(len(train_data) * 0.03)]\n",
436 | "\n",
437 | "train_list = train_data\n",
438 | "val_list = val_dataset\n",
439 | "test_list = []\n",
440 | "print('=' * 50)\n",
441 | "\n",
442 | "print('Processing INTs datasets...')\n",
443 | "train_dataset = EPianoDataset(train_list, max_seq, random_seq)\n",
444 | "val_dataset = EPianoDataset(val_list, max_seq)\n",
445 | "test_dataset = EPianoDataset(test_list, max_seq)\n",
446 | "print('=' * 50)\n",
447 | "\n",
448 | "print('Loading INTs datasets...')\n",
449 | "batch_size = number_of_batches\n",
450 | "train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=n_workers, shuffle=True)\n",
451 | "val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=n_workers)\n",
452 | "test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=n_workers)\n",
453 | "print('=' * 50)\n",
454 | "\n",
455 | "print('Total INTs in the dataset', len(train_data))\n",
456 | "print('Total unique INTs in the dataset', len(set(train_data)))\n",
457 | "print('Max INT in the dataset', max(train_data))\n",
458 | "print('Min INT in the dataset', min(train_data))\n",
459 | "print('=' * 50)\n",
460 | "\n",
461 | "print('Checking datasets shapes...')\n",
462 | "print('=' * 50)\n",
463 | "\n",
464 | "print('Train loader')\n",
465 | "for x, tgt in train_loader:\n",
466 | " print(f'X shape: {x.shape}')\n",
467 | " print(f'Target shape: {tgt.shape}')\n",
468 | " break\n",
469 | "print('=' * 50)\n",
470 | "\n",
471 | "print('Validation loader')\n",
472 | "for x, tgt in val_loader:\n",
473 | " print(f'X shape: {x.shape}')\n",
474 | " print(f'Target shape: {tgt.shape}')\n",
475 | " break\n",
476 | "print('=' * 50)\n",
477 | "\n",
478 | "print('Test loader')\n",
479 | "for x, tgt in test_loader:\n",
480 | " print(f'X shape: {x.shape}')\n",
481 | " print(f'Target shape: {tgt.shape}')\n",
482 | " break\n",
483 | "print('=' * 50)\n",
484 | "\n",
485 | "print('Done! Enjoy! :)')\n",
486 | "print('=' * 50)"
487 | ]
488 | },
489 | {
490 | "cell_type": "markdown",
491 | "metadata": {},
492 | "source": [
493 | "# Test the resulting INTs dataset..."
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": null,
499 | "metadata": {
500 | "gradient": {
501 | "id": "708f16d3-1747-4e72-bcc9-7504cdd963d4",
502 | "kernelId": ""
503 | }
504 | },
505 | "outputs": [],
506 | "source": [
507 | "train_data"
508 | ]
509 | },
510 | {
511 | "cell_type": "code",
512 | "execution_count": null,
513 | "metadata": {
514 | "gradient": {
515 | "execution_count": 6,
516 | "id": "dd411e56-532f-47dd-8283-ecb57126a3ae",
517 | "kernelId": ""
518 | }
519 | },
520 | "outputs": [],
521 | "source": [
522 | "out = train_data[:10000]\n",
523 | "if len(out) != 0:\n",
524 | " song = []\n",
525 | " song = out\n",
526 | " song_f = []\n",
527 | " time = 0\n",
528 | " pitch = 0\n",
529 | " duration = 0\n",
530 | " for s in song:\n",
531 | " if s >= 0 and s <= 256:\n",
532 | " time += s\n",
533 | " if s >= 256 and s < 512:\n",
534 | " pitch = s-256\n",
535 | " if s >= 512 and s < 256+512-4:\n",
536 | " duration = s-512\n",
537 | " song_f.append(['note', (abs(time))*10, (duration*10), 0, pitch, pitch ])\n",
538 | " \n",
539 | " detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,\n",
540 | " output_signature = 'Quantum Music', \n",
541 | " output_file_name = '/notebooks/Quantum-Music-Composition', \n",
542 | " track_name='Project Los Angeles', \n",
543 | " number_of_ticks_per_quarter=500)\n",
544 | "\n",
545 | " print('Done!')\n"
546 | ]
547 | },
548 | {
549 | "cell_type": "markdown",
550 | "metadata": {
551 | "id": "fkVqviDzJOrv"
552 | },
553 | "source": [
554 | "# (TRAIN)"
555 | ]
556 | },
557 | {
558 | "cell_type": "markdown",
559 | "metadata": {
560 | "id": "Y9CBW8xYupH8"
561 | },
562 | "source": [
563 | "# Train the model"
564 | ]
565 | },
566 | {
567 | "cell_type": "code",
568 | "execution_count": null,
569 | "metadata": {
570 | "cellView": "form",
571 | "gradient": {
572 | "id": "4aa21407-a3e9-4ed2-9bf1-83c295482b8a",
573 | "kernelId": ""
574 | },
575 | "id": "2moo7uUmpxvC"
576 | },
577 | "outputs": [],
578 | "source": [
579 | "#@title Train\n",
580 | "config = GPTConfig(VOCAB_SIZE, \n",
581 | " max_seq,\n",
582 | " dim_feedforward=dim_feedforward,\n",
583 | " n_layer=6, \n",
584 | " n_head=8, \n",
585 | " n_embd=512,\n",
586 | " enable_rpr=True,\n",
587 | " er_len=max_seq)\n",
588 | "model = GPT(config).to(get_device())\n",
589 | "\n",
590 | "#=====\n",
591 | "\n",
592 | "init_step = 0\n",
593 | "lr = LR_DEFAULT_START\n",
594 | "lr_stepper = LrStepTracker(d_model, SCHEDULER_WARMUP_STEPS, init_step)\n",
595 | "eval_loss_func = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD)\n",
596 | "train_loss_func = eval_loss_func\n",
597 | "\n",
598 | "opt = Adam(model.parameters(), lr=lr, betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)\n",
599 | "lr_scheduler = LambdaLR(opt, lr_stepper.step)\n",
600 | "\n",
601 | "\n",
602 | "#===\n",
603 | "\n",
604 | "best_eval_acc = 0.0\n",
605 | "best_eval_acc_epoch = -1\n",
606 | "best_eval_loss = float(\"inf\")\n",
607 | "best_eval_loss_epoch = -1\n",
608 | "best_acc_file = '/notebooks/gpt2_rpr_acc.pth'\n",
609 | "best_loss_file = '/notebooks/gpt2_rpr_loss.pth'\n",
610 | "loss_train, loss_val, acc_val = [], [], []\n",
611 | "\n",
612 | "for epoch in range(0, epochs):\n",
613 | " new_best = False\n",
614 | " \n",
615 | " loss = train(epoch+1, model, train_loader, train_loss_func, opt, lr_scheduler, num_iters=-1)\n",
616 | " loss_train.append(loss)\n",
617 | " \n",
618 | " eval_loss, eval_acc = eval_model(model, val_loader, eval_loss_func, num_iters=-1)\n",
619 | " loss_val.append(eval_loss)\n",
620 | " acc_val.append(eval_acc)\n",
621 | " \n",
622 | " if(eval_acc > best_eval_acc):\n",
623 | " best_eval_acc = eval_acc\n",
624 | " best_eval_acc_epoch = epoch+1\n",
625 | " torch.save(model.state_dict(), best_acc_file)\n",
626 | " new_best = True\n",
627 | "\n",
628 | " if(eval_loss < best_eval_loss):\n",
629 | " best_eval_loss = eval_loss\n",
630 | " best_eval_loss_epoch = epoch+1\n",
631 | " torch.save(model.state_dict(), best_loss_file)\n",
632 | " new_best = True\n",
633 | " \n",
634 | " if(new_best):\n",
635 | " print(\"Best eval acc epoch:\", best_eval_acc_epoch)\n",
636 | " print(\"Best eval acc:\", best_eval_acc)\n",
637 | " print(\"\")\n",
638 | " print(\"Best eval loss epoch:\", best_eval_loss_epoch)\n",
639 | " print(\"Best eval loss:\", best_eval_loss)"
640 | ]
641 | },
642 | {
643 | "cell_type": "code",
644 | "execution_count": null,
645 | "metadata": {},
646 | "outputs": [],
647 | "source": [
648 | "eval_loss, eval_acc = eval_model(model, val_loader, eval_loss_func, num_iters=-1)"
649 | ]
650 | },
651 | {
652 | "cell_type": "code",
653 | "execution_count": null,
654 | "metadata": {},
655 | "outputs": [],
656 | "source": [
657 | "train_data"
658 | ]
659 | },
660 | {
661 | "cell_type": "code",
662 | "execution_count": null,
663 | "metadata": {
664 | "cellView": "form",
665 | "gradient": {
666 | "id": "0e338550-f170-44a6-9479-ba0ddbc64608",
667 | "kernelId": ""
668 | },
669 | "id": "NNqmcFdRyC2M"
670 | },
671 | "outputs": [],
672 | "source": [
673 | "#@title Plot resulting training loss graph\n",
674 | "\n",
675 | "tr_loss_list = [item for sublist in loss_train for item in sublist]\n",
676 | "plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')\n",
677 | "plt.savefig('/notebooks/training-loss.png')"
678 | ]
679 | },
680 | {
681 | "cell_type": "markdown",
682 | "metadata": {
683 | "id": "mdKFoeke9L7H"
684 | },
685 | "source": [
686 | "# (SAVE/LOAD)"
687 | ]
688 | },
689 | {
690 | "cell_type": "code",
691 | "execution_count": null,
692 | "metadata": {
693 | "cellView": "form",
694 | "gradient": {
695 | "id": "73bea62d-084b-4f9a-9e55-2b34a932a7a4",
696 | "kernelId": ""
697 | },
698 | "id": "gqyDatHC9X1z"
699 | },
700 | "outputs": [],
701 | "source": [
702 | "#@title Save the model\n",
703 | "\n",
704 | "print('Saving the model...')\n",
705 | "full_path_to_model_checkpoint = \"/notebooks/Quantum-Music-Trained-Model-6.pth\" #@param {type:\"string\"}\n",
706 | "torch.save(model.state_dict(), full_path_to_model_checkpoint)\n",
707 | "print('Done!')"
708 | ]
709 | },
710 | {
711 | "cell_type": "code",
712 | "execution_count": null,
713 | "metadata": {
714 | "cellView": "form",
715 | "gradient": {
716 | "id": "c83edd89-9a36-430a-9fa7-3a967417c88e",
717 | "kernelId": ""
718 | },
719 | "id": "OaNkGcFo9UP_"
720 | },
721 | "outputs": [],
722 | "source": [
723 | "#@title Load/Reload the model\n",
724 | "full_path_to_model_checkpoint = \"/notebooks/Quantum-Music-Trained-Model-6.pth\" #@param {type:\"string\"}\n",
725 | "\n",
726 | "print('Loading the model...')\n",
727 | "config = GPTConfig(256+512+2, \n",
728 | " max_seq,\n",
729 | " dim_feedforward=dim_feedforward,\n",
730 | " n_layer=6, \n",
731 | " n_head=8, \n",
732 | " n_embd=512,\n",
733 | " enable_rpr=True,\n",
734 | " er_len=max_seq)\n",
735 | "\n",
736 | "model = GPT(config).to(get_device())\n",
737 | "\n",
738 | "model.load_state_dict(torch.load(full_path_to_model_checkpoint))\n",
739 | "print('Done!')"
740 | ]
741 | },
742 | {
743 | "cell_type": "markdown",
744 | "metadata": {},
745 | "source": [
746 | "# Custom MIDI option"
747 | ]
748 | },
749 | {
750 | "cell_type": "code",
751 | "execution_count": null,
752 | "metadata": {
753 | "gradient": {
754 | "id": "5f771604-39e7-431d-b1dd-86d7437b8872",
755 | "kernelId": ""
756 | }
757 | },
758 | "outputs": [],
759 | "source": [
760 | "data = TMIDIX.Optimus_MIDI_TXT_Processor('/notebooks/seed97-super.mid', \n",
761 | " dataset_MIDI_events_time_denominator=10, \n",
762 | " perfect_timings=True, \n",
763 | " musenet_encoding=True, \n",
764 | " char_offset=0, \n",
765 | " MIDI_channel=-1, \n",
766 | " MIDI_patch=range(0, 127)\n",
767 | " )\n",
768 | "\n",
769 | "SONG = data[5]\n",
770 | "inputs = []\n",
771 | "for i in SONG:\n",
772 | " if max(i) < 256 and max(i) >= 0:\n",
773 | " if i[0] < 16:\n",
774 | " inputs.extend([i[0]])\n",
775 | " else:\n",
776 | " \n",
777 | " inputs.extend([16, i[0]-16])\n",
778 | " \n",
779 | " inputs.extend([256+i[3], 512+i[1] ]) "
780 | ]
781 | },
782 | {
783 | "cell_type": "markdown",
784 | "metadata": {
785 | "id": "UX1_5y5Fu8AH"
786 | },
787 | "source": [
788 | "# (GENERATE MUSIC)"
789 | ]
790 | },
791 | {
792 | "cell_type": "code",
793 | "execution_count": null,
794 | "metadata": {
795 | "cellView": "form",
796 | "gradient": {
797 | "id": "97793d01-6a74-4e34-be95-ea337277b38d",
798 | "kernelId": ""
799 | },
800 | "id": "M_K93hWWv2Yx"
801 | },
802 | "outputs": [],
803 | "source": [
804 | "#@title Generate and download a MIDI file\n",
805 | "\n",
806 | "number_of_tokens_to_generate = 1024 #@param {type:\"slider\", min:8, max:1024, step:8}\n",
807 | "use_random_primer = False #@param {type:\"boolean\"}\n",
808 | "start_with_zero_token = False #@param {type:\"boolean\"}\n",
809 | "number_of_ticks_per_quarter = 500 #@param {type:\"slider\", min:50, max:1000, step:50}\n",
810 | "dataset_time_denominator = 10\n",
811 | "melody_conditioned_encoding = False\n",
812 | "encoding_has_MIDI_channels = False \n",
813 | "encoding_has_velocities = False\n",
814 | "simulate_velocity = True #@param {type:\"boolean\"}\n",
815 | "save_only_first_composition = True\n",
816 | "chars_encoding_offset_used_for_dataset = 33\n",
817 | "\n",
818 | "fname = '/notebooks/Quantum-Music-Composition'\n",
819 | "\n",
820 | "print('Quantum Music Model Generator')\n",
821 | "\n",
822 | "output_signature = 'Quantum Music'\n",
823 | "song_name = 'RGA Composition'\n",
824 | "\n",
825 | "model.eval()\n",
826 | "\n",
827 | "if use_random_primer:\n",
828 | " sequence = [random.randint(10, 387) for i in range(64)]\n",
829 | " idx = secrets.randbelow(len(sequence))\n",
830 | " rand_seq = model.generate(torch.Tensor(sequence[idx:idx+120]), target_seq_length=number_of_tokens_to_generate)\n",
831 | " out = rand_seq[0].cpu().numpy().tolist()\n",
832 | "\n",
833 | "else:\n",
834 | " out = []\n",
835 | " \n",
836 | " try:\n",
837 | " if start_with_zero_token:\n",
838 | " sequence = inputs[-512:] #[256+512 - 2, 0]# inputs[-512:]\n",
839 | " rand_seq = model.generate(torch.Tensor(sequence), target_seq_length=number_of_tokens_to_generate, stop_token=256+512)\n",
840 | " out = rand_seq[0].cpu().numpy().tolist()\n",
841 | " else:\n",
842 | " idx = secrets.randbelow(len(train_data))\n",
843 | " sequence = train_data[idx:idx+512]\n",
844 | " rand_seq = model.generate(torch.Tensor(sequence), target_seq_length=number_of_tokens_to_generate, stop_token=256+512)\n",
845 | " out = rand_seq[0].cpu().numpy().tolist()\n",
846 | " \n",
847 | " except:\n",
848 | " print('=' * 50)\n",
849 | " print('Error! Try random priming instead!')\n",
850 | " print('Shutting down...')\n",
851 | " print('=' * 50)\n",
852 | "\n",
853 | "if len(out) != 0:\n",
854 | " song = []\n",
855 | " song = out\n",
856 | " song_f = []\n",
857 | " time = 0\n",
858 | " pitch = 0\n",
859 | " duration = 0\n",
860 | " once = True\n",
861 | " for s in song:\n",
862 | " if s >= 0 and s < 256:\n",
863 | " time += s\n",
864 | " if s >= 256 and s < 512:\n",
865 | " pitch = s-256\n",
866 | " if s >= 512 and s < 256+512-4:\n",
867 | " duration = s-512\n",
868 | " song_f.append(['note', (abs(time))*10, (duration*10), 0, pitch, pitch ])\n",
869 | " \n",
870 | " if song.index(s) >= len(sequence) and once:\n",
871 | " song_f.append(['text_event', abs(time) * 10, 'Continuation Start Here'])\n",
872 | " once = False\n",
873 | " \n",
874 | " detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,\n",
875 | " output_signature = 'Quantum Music', \n",
876 | " output_file_name = '/notebooks/Quantum-Music-Composition', \n",
877 | " track_name='Project Los Angeles', \n",
878 | " number_of_ticks_per_quarter=500)\n",
879 | " \n",
880 | " print('Done!')\n",
881 | "\n",
882 | "\n",
883 | " #print('Downloading your composition now...')\n",
884 | " #from google.colab import files\n",
885 | " #files.download(fname + '.mid')\n",
886 | "\n",
887 | " print('=' * 70)\n",
888 | " print('Detailed MIDI stats:')\n",
889 | " for key, value in detailed_stats.items():\n",
890 | " print('=' * 70)\n",
891 | " print(key, '|', value)\n",
892 | "\n",
893 | " print('=' * 70)\n",
894 | "\n",
895 | "else:\n",
896 | " print('Models output is empty! Check the code...')\n",
897 | " print('Shutting down...')"
898 | ]
899 | },
900 | {
901 | "cell_type": "code",
902 | "execution_count": null,
903 | "metadata": {},
904 | "outputs": [],
905 | "source": [
906 | "len(out)"
907 | ]
908 | },
909 | {
910 | "cell_type": "code",
911 | "execution_count": null,
912 | "metadata": {},
913 | "outputs": [],
914 | "source": [
915 | "out[-64:]"
916 | ]
917 | },
918 | {
919 | "cell_type": "code",
920 | "execution_count": null,
921 | "metadata": {
922 | "cellView": "form",
923 | "gradient": {
924 | "id": "c8149763-2e09-4fcf-9823-85ca778b9e80",
925 | "kernelId": ""
926 | },
927 | "id": "STtGgBsf4-tA"
928 | },
929 | "outputs": [],
930 | "source": [
931 | "#@title Auto-Regressive Generator\n",
932 | "\n",
933 | "#@markdown NOTE: You much generate a seed composition first or it is not going to start\n",
934 | "\n",
935 | "number_of_cycles_to_run = 5 #@param {type:\"slider\", min:1, max:50, step:1}\n",
936 | "number_of_prime_tokens = 128 #@param {type:\"slider\", min:64, max:256, step:64}\n",
937 | "\n",
938 | "print('=' * 70)\n",
939 | "print('Quantum Music Auto-Regressive Model Generator')\n",
940 | "print('=' * 70)\n",
941 | "print('Starting up...')\n",
942 | "print('=' * 70)\n",
943 | "print('Prime length:', len(out))\n",
944 | "print('Prime tokens:', number_of_prime_tokens)\n",
945 | "print('Prime input sequence', out[-8:])\n",
946 | "\n",
947 | "if len(out) != 0:\n",
948 | " print('=' * 70)\n",
949 | " out_all = []\n",
950 | " out_all.append(out)\n",
951 | " for i in tqdm(range(number_of_cycles_to_run)):\n",
952 | " rand_seq1 = model.generate(torch.Tensor(out[-number_of_prime_tokens:]), target_seq_length=1024, stop_token=256+512)\n",
953 | " out1 = rand_seq1[0].cpu().numpy().tolist()\n",
954 | " out_all.append(out1[number_of_prime_tokens:])\n",
955 | " out = out1[number_of_prime_tokens:]\n",
956 | " \n",
957 | " print(chr(10))\n",
958 | " print('=' * 70)\n",
959 | " print('Block number:', i+1)\n",
960 | " print('Composition length so far:', (i+1) * 1024, 'notes')\n",
961 | " print('=' * 70)\n",
962 | "\n",
963 | " print('Done!' * 70)\n",
964 | " print('Total blocks:', i+1)\n",
965 | " print('Final omposition length:', (i+1) * 1024, 'notes')\n",
966 | " print('=' * 70)\n",
967 | " \n",
968 | " out2 = []\n",
969 | " for o in out_all:\n",
970 | " out2.extend(o)\n",
971 | "\n",
972 | " if len(out2) != 0:\n",
973 | " song = []\n",
974 | " song = out2\n",
975 | " song_f = []\n",
976 | " time = 0\n",
977 | " pitch = 0\n",
978 | " duration = 0\n",
979 | " once = True\n",
980 | " for s in song:\n",
981 | " if s >= 0 and s < 256:\n",
982 | " time += s\n",
983 | " if s >= 256 and s < 512:\n",
984 | " pitch = s-256\n",
985 | " if s >= 512 and s < 256+512-4:\n",
986 | " duration = s-512\n",
987 | " song_f.append(['note', (abs(time))*10, (duration*10), 0, pitch, pitch ])\n",
988 | "\n",
989 | " detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,\n",
990 | " output_signature = 'Quantum Music', \n",
991 | " output_file_name = '/notebooks/Quantum-Music-Composition', \n",
992 | " track_name='Project Los Angeles', \n",
993 | " number_of_ticks_per_quarter=500)\n",
994 | "\n",
995 | " print('Done!')\n",
996 | "\n",
997 | " \n",
998 | "\n",
999 | "else:\n",
1000 | " print('=' * 70)\n",
1001 | " print('INPUT ERROR !!!')\n",
1002 | " print('Prime sequence is empty...')\n",
1003 | " print('Please generate prime sequence and retry')\n",
1004 | "\n",
1005 | "print('=' * 70)"
1006 | ]
1007 | },
1008 | {
1009 | "cell_type": "markdown",
1010 | "metadata": {
1011 | "id": "YzCMd94Tu_gz"
1012 | },
1013 | "source": [
1014 | "# Congrats! You did it! :)"
1015 | ]
1016 | }
1017 | ],
1018 | "metadata": {
1019 | "accelerator": "GPU",
1020 | "colab": {
1021 | "collapsed_sections": [],
1022 | "machine_shape": "hm",
1023 | "name": "Optimus_VIRTUOSO_Multi_Instrumental_RGA_Edition.ipynb",
1024 | "private_outputs": true,
1025 | "provenance": []
1026 | },
1027 | "kernelspec": {
1028 | "display_name": "Python 3 (ipykernel)",
1029 | "language": "python",
1030 | "name": "python3"
1031 | },
1032 | "language_info": {
1033 | "codemirror_mode": {
1034 | "name": "ipython",
1035 | "version": 3
1036 | },
1037 | "file_extension": ".py",
1038 | "mimetype": "text/x-python",
1039 | "name": "python",
1040 | "nbconvert_exporter": "python",
1041 | "pygments_lexer": "ipython3",
1042 | "version": "3.8.12"
1043 | }
1044 | },
1045 | "nbformat": 4,
1046 | "nbformat_minor": 4
1047 | }
1048 |
--------------------------------------------------------------------------------