├── media
│   ├── ChatPDF_Schematics.png
│   ├── AttentionSchematics_white-01.png
│   ├── AttentionSchematics_white-06.png
│   ├── AttentionSchematics_white-07.png
│   ├── AttentionSchematics_white_Artboard 2.png
│   └── AttentionSchematics_white_Artboard 3.png
├── main.py
├── LICENSE
├── midi_utils.py
├── README.md
├── .gitignore
├── maestro_visualize.py
├── audio_classifying_transformer.py
├── CIFAR_classifying_transformer.py
├── MNIST_generative_transformer.py
├── attention_to_Transformer.py
└── train_maestro_GPT2.py
/media/ChatPDF_Schematics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Animadversio/TransformerFromScratch/HEAD/media/ChatPDF_Schematics.png --------------------------------------------------------------------------------
/media/AttentionSchematics_white-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Animadversio/TransformerFromScratch/HEAD/media/AttentionSchematics_white-01.png --------------------------------------------------------------------------------
/media/AttentionSchematics_white-06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Animadversio/TransformerFromScratch/HEAD/media/AttentionSchematics_white-06.png --------------------------------------------------------------------------------
/media/AttentionSchematics_white-07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Animadversio/TransformerFromScratch/HEAD/media/AttentionSchematics_white-07.png --------------------------------------------------------------------------------
/media/AttentionSchematics_white_Artboard 2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Animadversio/TransformerFromScratch/HEAD/media/AttentionSchematics_white_Artboard 2.png --------------------------------------------------------------------------------
/media/AttentionSchematics_white_Artboard 3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Animadversio/TransformerFromScratch/HEAD/media/AttentionSchematics_white_Artboard 3.png --------------------------------------------------------------------------------
/main.py: -------------------------------------------------------------------------------- 1 | # This is a sample Python script. 2 | 3 | # Press Shift+F10 to execute it or replace it with your code. 4 | # Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings. 5 | 6 | 7 | def print_hi(name): 8 | # Use a breakpoint in the code line below to debug your script. 9 | print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint. 10 | 11 | 12 | # Press the green button in the gutter to run the script.
13 | if __name__ == '__main__': 14 | print_hi('PyCharm') 15 | 16 | # See PyCharm help at https://www.jetbrains.com/help/pycharm/ 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Binxu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /midi_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from mido import MidiFile 3 | from pydub import AudioSegment 4 | from pydub.generators import Sine 5 | 6 | 7 | def note_to_freq(note, concert_A=440.0): 8 | ''' 9 | from wikipedia: http://en.wikipedia.org/wiki/MIDI_Tuning_Standard#Frequency_values 10 | ''' 11 | return (2.0 ** ((note - 69) / 12.0)) * concert_A 12 | 13 | 14 | def ticks_to_ms(ticks): 15 | tick_ms = (60000.0 / tempo) / mid.ticks_per_beat 16 | return ticks * tick_ms 17 | 18 | 19 | 20 | # mid = MidiFile("./maroon_5-animals.mid") 21 | mid = MidiFile(r"E:\Datasets\maestro-v3.0.0-midi\maestro-v3.0.0\2011\MIDI-Unprocessed_01_R1_2011_MID--AUDIO_R1-D1_02_Track02_wav.midi") 22 | 23 | output = AudioSegment.silent(mid.length * 1000.0) 24 | 25 | tempo = 100 # bpm 26 | 27 | 28 | for track in mid.tracks: 29 | # position of rendering in ms 30 | current_pos = 0.0 31 | 32 | current_notes = defaultdict(dict) 33 | # current_notes = { 34 | # channel: { 35 | # note: (start_time, message) 36 | # } 37 | # } 38 | 39 | for msg in track: 40 | current_pos += ticks_to_ms(msg.time) 41 | 42 | if msg.type == 'note_on': 43 | current_notes[msg.channel][msg.note] = (current_pos, msg) 44 | 45 | if msg.type == 'note_off': 46 | start_pos, start_msg = current_notes[msg.channel].pop(msg.note) 47 | 48 | duration = current_pos - start_pos 49 | 50 | signal_generator = Sine(note_to_freq(msg.note)) 51 | rendered = signal_generator.to_audio_segment(duration=duration - 50, volume=-20).fade_out(100).fade_in(30) 52 | 53 | output = output.overlay(rendered, start_pos) 54 | 55 | output.export("piano.wav", format="wav") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransformerFromScratch Tutorial Series 2 | 3 | Binxu Wang (binxu_wang@hms.harvard.edu) 4 | 5 | April. 
17th, 2023 6 | 7 | [Tutorial website](https://scholar.harvard.edu/binxuw/classes/machine-learning-scratch/materials/transformers) 8 | 9 | [Lecture slides (PDF)](https://scholar.harvard.edu/sites/scholar.harvard.edu/files/binxuw/files/mlfs_tutorial_nlp_transformer_ssl_updated.pdf) 10 | 11 | ![](media/AttentionSchematics_white-01.png) 12 | 13 | **Transformer Jupyter Notebook Tutorial Series** 14 | 15 | * Fundamentals 16 | * [Understanding Attention & Transformer](https://colab.research.google.com/drive/1ZuhA6khlWm57WGZ8i38JH-gc5aJrvpvs?usp=sharing) (no GPU required) 17 | * [Tutorial on Einstein summation rules](https://colab.research.google.com/drive/1mizzN7iRlS2Du5TJvv7Wz7ecKOnpHzrQ?usp=sharing) 18 | * [Language modelling with transformer](https://colab.research.google.com/drive/1zZYzAopL__LW4glruSF9lnZYlEmSVI8j?usp=sharing) (CPU or GPU) 19 | * Beyond Language 20 | * **All the following notebooks involve training a transformer and should be run with a GPU runtime; otherwise training takes too long.** 21 | * [Learn to do arithmetic by sequence modelling.](https://colab.research.google.com/drive/1vO71-o-8-3IrOe44Ha0nsHmUsEGVSC37?usp=sharing) (Simple, GPU Training ~ 10 min) 22 | * [Image generation by sequence modelling.](https://colab.research.google.com/drive/1UHlEbepqdvk68cYV1fvkmWl2TBKXfm8E?usp=sharing) (Simple, GPU Training ~ 10-20 min) 23 | * ~~[Audio signal classification](https://colab.research.google.com/drive/1O4XHOJyOu3_lyaPHAKJM_XTztrAb7VFP?usp=sharing) (Medium, GPU Training ~ 20 min)~~ (WARNING: there is currently a dependency install error; don't run it on Colab.) 24 | * [Image classification](https://colab.research.google.com/drive/1JDQQlLMGzo675AfrtkFn1kbuADtVemJz?usp=sharing) (Medium, GPU Training ~ 30 min) 25 | * [Music generation by sequence modelling.](https://colab.research.google.com/drive/14zpzLpR4UBIzEQmeaXlMv_mDFYIv3Vht?usp=sharing) (Difficult, GPU Training takes hours) 26 | * Large Language Model 27 | * [OpenAI API and Chat with PDF](https://colab.research.google.com/drive/19mYEyavBhOnAbEQJQuztXAxWxyYbsQzi?usp=sharing) (Simple, no GPU needed, ~5 min) 28 | 29 | ![](media/ChatPDF_Schematics.png) 30 | 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | runs/* 9 | .idea/* 10 | *.wav 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /maestro_visualize.py: -------------------------------------------------------------------------------- 1 | maestro_v3_midi_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip" 2 | maestro_v3_meta_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.csv" 3 | maestro_v3_json_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.json" 4 | #%% 5 | # download and unzip the urls above 6 | # !wget $maestro_v3_midi_url 7 | # !wget $maestro_v3_meta_url 8 | # !wget $maestro_v3_json_url 9 | # !unzip maestro-v3.0.0-midi.zip 10 | #%% 11 | import os 12 | from os.path import join 13 | import numpy as np 14 | import pandas as pd 15 | import torch 16 | import torch.nn as nn 17 | from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config 18 | from torch.optim import AdamW, Adam 19 | # from transformers import AdamW, get_linear_schedule_with_warmup 20 | import pretty_midi 21 | import librosa.display 22 | import librosa 23 | import matplotlib.pyplot as plt 24 | maestro_root = r"E:\Datasets\maestro-v3.0.0-midi\maestro-v3.0.0" 25 | # visualize the midi file as piano roll 26 | def plot_piano_roll(pm, start_pitch, end_pitch, fs=100): 27 | # Use librosa's specshow function for displaying the piano roll 28 | librosa.display.specshow(pm.get_piano_roll(fs)[start_pitch:end_pitch], 29 | hop_length=1, sr=fs, x_axis='time', y_axis='cqt_note', 30 | fmin=pretty_midi.note_number_to_hz(start_pitch)) 31 | 32 | # load the metadata 33 | maestro_meta = pd.read_csv(join(maestro_root, "maestro-v3.0.0.csv")) 34 | #%% 35 | # load sample 
midi file 36 | midi_file = join(maestro_root, maestro_meta["midi_filename"][0].replace("/", os.path.sep)) 37 | midi_data = pretty_midi.PrettyMIDI(midi_file) 38 | 39 | print(midi_data.instruments) 40 | len(midi_data.instruments[0].notes) 41 | #%% 42 | # create a new figure 43 | fig, ax = plt.subplots(figsize=(12, 4)) 44 | # plot the piano roll 45 | plot_piano_roll(midi_data, 24, 84) 46 | plt.tight_layout() 47 | plt.show() 48 | #%% 49 | note_dt_dist = [] 50 | note_len_dist = [] 51 | note_velo_dist = [] 52 | note_pitch_dist = [] 53 | for i, note_sample in enumerate(midi_data.instruments[0].notes[1:]): 54 | note_dt_dist.append(note_sample.start - midi_data.instruments[0].notes[i-1].end) 55 | note_len_dist.append(note_sample.duration) 56 | note_velo_dist.append(note_sample.velocity) 57 | note_pitch_dist.append(note_sample.pitch) 58 | #%% 59 | 60 | #%% 61 | plt.subplots(1, 3, figsize=(12, 4)) 62 | plt.subplot(1, 3, 1) 63 | plt.hist(note_len_dist, bins=5000) 64 | plt.xlim(0, 3) 65 | plt.title("Note Length Distribution") 66 | plt.subplot(1, 3, 2) 67 | plt.hist(note_velo_dist, bins=100) 68 | plt.title("Note Velocity Distribution") 69 | plt.subplot(1, 3, 3) 70 | plt.hist(note_pitch_dist, bins=100) 71 | plt.title("Note Pitch Distribution") 72 | plt.show() 73 | 74 | #%% 75 | print(len(maestro_meta.canonical_composer.unique())) 76 | print(len(maestro_meta.canonical_title.unique())) 77 | 78 | #%% 79 | notes_str = [librosa.midi_to_note(note.pitch) 80 | for note in midi_data.instruments[0].notes] 81 | #%% 82 | midi_data.time_to_tick(midi_data.instruments[0].notes[-1].end) 83 | #%% 84 | # https://notebook.community/craffel/pretty-midi/Tutorial 85 | # Each note has a start time, end time, pitch, and velocity. 86 | # 87 | # Velocity / volume of a note: 1-127 88 | # Pitch of a note: 0-127 C1-G9 89 | # Duration of a note: seconds, float number 90 | # Start time of a note: seconds, float number 91 | 92 | #%% 93 | 94 | #%% Scratch 95 | # midi to mp3 96 | # !pip install mido 97 | # !pip install midi2audio 98 | #%% 99 | import fluidsynth 100 | # from pydub import AudioSegment 101 | 102 | # Load the MIDI file using FluidSynth 103 | fs = fluidsynth.Synth() 104 | fs.start(driver="coreaudio") # Use appropriate driver for your system 105 | sfid = fs.sfload(midi_file) 106 | fs.program_select(0, sfid, 0, 0) 107 | # Render the MIDI file to a WAV file 108 | fs.midi_to_audio("example.wav", "example.mid") -------------------------------------------------------------------------------- /audio_classifying_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from tqdm import tqdm, trange 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.optim import AdamW 9 | from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, GPT2Model 10 | from transformers import get_linear_schedule_with_warmup 11 | from torch.utils.data import Dataset, DataLoader 12 | from torch.utils.data import random_split 13 | from torch.nn.utils.rnn import pad_sequence 14 | #%% 15 | from torchfsdd import TorchFSDDGenerator, TrimSilence 16 | from torchaudio.transforms import MFCC 17 | from torchvision.transforms import Compose 18 | 19 | # Create a transformation pipeline to apply to the recordings 20 | transforms = Compose([ 21 | TrimSilence(threshold=1e-6), 22 | MFCC(sample_rate=8e3, n_mfcc=64) 23 | ]) 24 | 25 | # Fetch the latest version of FSDD and initialize a generator with those 
files 26 | fsdd = TorchFSDDGenerator(version='master', transforms=transforms, 27 | ) #path="/home/binxu/Datasets" 28 | #%% 29 | # Create a Torch dataset for the entire dataset from the generator 30 | full_set = fsdd.full() 31 | # Create two Torch datasets for a train-test split from the generator 32 | train_set, test_set = fsdd.train_test_split(test_size=0.1) 33 | # Create three Torch datasets for a train-validation-test split from the generator 34 | train_set, val_set, test_set = fsdd.train_val_test_split(test_size=0.15, val_size=0.15) 35 | #%% 36 | plt.figure() 37 | plt.imshow(np.log(np.abs(train_set[100][0]))) 38 | plt.show() 39 | 40 | #%% 41 | def collate_fn(batch): 42 | # batch is a list of tuples, where each tuple is (audio_tensor, label_scalar) 43 | audios = [] 44 | labels = [] 45 | for audio, label in batch: 46 | audios.append(audio.T) # time, freq features 47 | labels.append(label) 48 | # pad audio tensors to ensure they have the same length 49 | audios = pad_sequence(audios, batch_first=True, padding_value=0) 50 | # convert the labels list to a tensor 51 | labels = torch.tensor(labels) 52 | return audios, labels 53 | 54 | 55 | #%% 56 | # audio_tsrs, labels = next(iter(dataloaders)) 57 | 58 | #%% 59 | from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, GPT2Model 60 | config = GPT2Config(n_embd=128, n_layer=12, n_head=16, n_positions=256, 61 | vocab_size=100, bos_token_id=101, eos_token_id=102, 62 | cls_token_id=103, ) 63 | MF_emb = nn.Linear(64, config.n_embd).cuda() 64 | model = GPT2Model(config).cuda() 65 | classifier_head = nn.Linear(config.n_embd, 10).cuda() 66 | CLS_token = torch.randn(1, 1, config.n_embd).cuda() / math.sqrt(config.n_embd) 67 | CLS_token = nn.Parameter(CLS_token) 68 | optimizer = AdamW([*model.parameters(), 69 | *MF_emb.parameters(), 70 | *classifier_head.parameters(), 71 | CLS_token], lr=1e-4) 72 | #%% 73 | 74 | dataloaders = DataLoader(train_set, batch_size=128, shuffle=True, 75 | collate_fn=collate_fn) 76 | test_loader = DataLoader(test_set, batch_size=256, shuffle=True, 77 | collate_fn=collate_fn) 78 | for epoch in trange(20): 79 | model.train() 80 | pbar = tqdm(dataloaders) 81 | for i, (audio, label) in enumerate(pbar): 82 | audio = audio.cuda() 83 | audio = MF_emb(audio) 84 | audio = torch.cat([audio, CLS_token.repeat(audio.shape[0], 1, 1)], dim=1) 85 | output = model(inputs_embeds=audio) 86 | last_hidden_state = output.last_hidden_state 87 | pooled_output = last_hidden_state[:, -1] 88 | logits = classifier_head(pooled_output) 89 | loss = F.cross_entropy(logits, label.cuda()) 90 | optimizer.zero_grad() 91 | loss.backward() 92 | optimizer.step() 93 | pbar.set_description(f"loss: {loss.item():.4f}") 94 | model.eval() 95 | with torch.no_grad(): 96 | test_corr_num = 0 97 | test_loss = 0 98 | for i, (audio, label) in enumerate(test_loader): 99 | audio = audio.cuda() 100 | audio = MF_emb(audio) 101 | audio = torch.cat([audio, CLS_token.repeat(audio.shape[0], 1, 1)], dim=1) 102 | output = model(inputs_embeds=audio) 103 | last_hidden_state = output.last_hidden_state 104 | pooled_output = last_hidden_state[:, -1] 105 | logits = classifier_head(pooled_output) 106 | loss = F.cross_entropy(logits, label.cuda()) 107 | pbar.set_description(f"test loss: {loss.item():.4f}") 108 | test_corr_num += (logits.argmax(dim=1) == label.cuda()).float().sum() 109 | test_loss += loss.item() 110 | print(f"test acc: {test_corr_num / len(test_set):.4f}") 111 | #%% 112 | from transformers import BertModel, BertTokenizer, BertConfig 113 | config = 
BertConfig(hidden_size=64, intermediate_size=256, num_hidden_layers=12, 114 | num_attention_heads=8, max_position_embeddings=256, 115 | vocab_size=100, bos_token_id=101, eos_token_id=102, 116 | cls_token_id=103, ) 117 | model = BertModel(config).cuda() 118 | # MF_emb = nn.Linear(64, config.hidden_size).cuda() 119 | MF_emb = nn.Sequential(nn.Conv1d(64, config.hidden_size, 3, 1, 1), 120 | nn.ReLU(), 121 | nn.Conv1d(config.hidden_size, config.hidden_size, 3, 1, 1), 122 | ).cuda() 123 | classifier_head = nn.Linear(config.hidden_size, 10).cuda() 124 | CLS_token = torch.randn(1, 1, config.hidden_size).cuda() / math.sqrt(config.hidden_size) 125 | CLS_token = nn.Parameter(CLS_token) 126 | optimizer = AdamW([*model.parameters(), 127 | *MF_emb.parameters(), 128 | *classifier_head.parameters(), 129 | CLS_token], lr=1e-4) 130 | # https://datasets.activeloop.ai/docs/ml/datasets/free-spoken-digit-dataset-fsdd/ 131 | # https://github.com/adhishthite/sound-mnist 132 | #%% 133 | dataloaders = DataLoader(train_set, batch_size=128, shuffle=True, 134 | collate_fn=collate_fn) 135 | val_loader = DataLoader(val_set, batch_size=256, shuffle=True, 136 | collate_fn=collate_fn) 137 | test_loader = DataLoader(test_set, batch_size=256, shuffle=True, 138 | collate_fn=collate_fn) 139 | for epoch in trange(40): 140 | model.train() 141 | pbar = tqdm(dataloaders) 142 | for i, (audio, label) in enumerate(pbar): 143 | audio = audio.cuda() 144 | audio = MF_emb(audio.permute(0, 2, 1)).permute(0, 2, 1) 145 | audio = torch.cat([CLS_token.repeat(audio.shape[0], 1, 1), audio, ], dim=1) 146 | output = model(inputs_embeds=audio) 147 | last_hidden_state = output.last_hidden_state 148 | pooled_output = last_hidden_state[:, 0] 149 | logits = classifier_head(pooled_output) 150 | loss = F.cross_entropy(logits, label.cuda()) 151 | optimizer.zero_grad() 152 | loss.backward() 153 | optimizer.step() 154 | pbar.set_description(f"loss: {loss.item():.4f}") 155 | model.eval() 156 | with torch.no_grad(): 157 | val_corr_num = 0 158 | val_loss = 0 159 | for i, (audio, label) in enumerate(val_loader): 160 | audio = audio.cuda() 161 | audio = MF_emb(audio.permute(0, 2, 1)).permute(0, 2, 1) 162 | audio = torch.cat([CLS_token.repeat(audio.shape[0], 1, 1), audio, ], dim=1) 163 | output = model(inputs_embeds=audio) 164 | last_hidden_state = output.last_hidden_state 165 | pooled_output = last_hidden_state[:, 0] 166 | logits = classifier_head(pooled_output) 167 | loss = F.cross_entropy(logits, label.cuda()) 168 | val_corr_num += (logits.argmax(dim=1) == label.cuda()).float().sum() 169 | val_loss += loss.item() 170 | print(f"val acc: {val_corr_num / len(val_set):.4f}") 171 | 172 | test_corr_num = 0 173 | test_loss = 0 174 | for i, (audio, label) in enumerate(test_loader): 175 | audio = audio.cuda() 176 | audio = MF_emb(audio.permute(0, 2, 1)).permute(0, 2, 1) 177 | audio = torch.cat([CLS_token.repeat(audio.shape[0], 1, 1), audio, ], dim=1) 178 | output = model(inputs_embeds=audio) 179 | last_hidden_state = output.last_hidden_state 180 | pooled_output = last_hidden_state[:, 0] 181 | logits = classifier_head(pooled_output) 182 | loss = F.cross_entropy(logits, label.cuda()) 183 | test_corr_num += (logits.argmax(dim=1) == label.cuda()).float().sum() 184 | test_loss += loss.item() 185 | print(f"test acc: {test_corr_num / len(test_set):.4f}") 186 | # loss: 0.0476: 100%|██████████| 17/17 [00:28<00:00, 1.66s/it] 187 | # val acc: 0.9833 188 | # test acc: 0.9714 189 | # 100%|██████████| 40/40 [25:56<00:00, 38.92s/it] 
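#%%
# Added usage sketch (not part of the original script): classify a single FSDD clip with the
# BERT-variant modules trained in the cells above. It assumes `model`, `MF_emb`, `CLS_token`,
# `classifier_head`, and `test_set` defined above are still in memory and on the GPU; the
# (mfcc_tensor, label) indexing of `test_set` follows how `train_set[100][0]` is used earlier.
model.eval()
with torch.no_grad():
    mfcc, label = test_set[0]                              # MFCC tensor of shape (n_mfcc=64, time) and digit label
    audio = mfcc.T.unsqueeze(0).cuda()                     # -> (1, time, 64), same layout collate_fn produces
    emb = MF_emb(audio.permute(0, 2, 1)).permute(0, 2, 1)  # Conv1d embedder expects (batch, channels, time)
    emb = torch.cat([CLS_token, emb], dim=1)               # prepend the learned CLS token (batch size 1)
    logits = classifier_head(model(inputs_embeds=emb).last_hidden_state[:, 0])
    print(f"predicted digit: {logits.argmax(-1).item()}, true label: {label}")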
-------------------------------------------------------------------------------- /CIFAR_classifying_transformer.py: -------------------------------------------------------------------------------- 1 | ## Import transformers 2 | from transformers import get_linear_schedule_with_warmup 3 | from transformers import BertTokenizer, BertForSequenceClassification 4 | from transformers import BertModel, BertConfig 5 | from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, GPT2Model 6 | ## import MNIST from torchvision package 7 | from torchvision import datasets, transforms 8 | ## import torch 9 | import os 10 | from os.path import join 11 | from tqdm import tqdm, trange 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.optim import AdamW, Adam 17 | from torch.utils.data import Dataset, DataLoader 18 | from torchvision.utils import make_grid, save_image 19 | import matplotlib.pyplot as plt 20 | #%% 21 | from torchvision.datasets import MNIST, CIFAR10 22 | dataset = CIFAR10(root='/home/binxu/Datasets', train=True, download=True, transform= 23 | transforms.Compose([ 24 | transforms.RandomHorizontalFlip(), 25 | transforms.RandomCrop(32, padding=4), 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 28 | ])) 29 | # augmentations are super important for CNN trainings, or it will overfit very fast without achieving good generalization accuracy 30 | val_dataset = CIFAR10(root='/home/binxu/Datasets', train=False, download=True, transform=transforms.Compose( 31 | [transforms.ToTensor(), 32 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])) 33 | #%% 34 | from transformers import BertModel, BertTokenizer, BertConfig 35 | config = BertConfig(hidden_size=256, intermediate_size=1024, num_hidden_layers=12, 36 | num_attention_heads=8, max_position_embeddings=256, 37 | vocab_size=100, bos_token_id=101, eos_token_id=102, 38 | cls_token_id=103, ) 39 | model = BertModel(config).cuda() 40 | patch_embed = nn.Conv2d(3, config.hidden_size, kernel_size=4, stride=4).cuda() 41 | CLS_token = nn.Parameter(torch.randn(1, 1, config.hidden_size, device="cuda") / math.sqrt(config.hidden_size)) 42 | readout = nn.Sequential(nn.Linear(config.hidden_size, config.hidden_size), 43 | nn.GELU(), 44 | nn.Linear(config.hidden_size, 10) 45 | ).cuda() 46 | for module in [patch_embed, readout, model, CLS_token]: 47 | module.cuda() 48 | 49 | optimizer = AdamW([*model.parameters(), 50 | *patch_embed.parameters(), 51 | *readout.parameters(), 52 | CLS_token], lr=5e-4) 53 | #%% 54 | batch_size = 128 # 96 55 | train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 56 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) 57 | model.train() 58 | for epoch in trange(50, leave=False): 59 | pbar = tqdm(train_loader, leave=False) 60 | for i, (imgs, labels) in enumerate(pbar): 61 | patch_embs = patch_embed(imgs.cuda()) 62 | patch_embs = patch_embs.flatten(2).permute(0, 2, 1) # (batch_size, HW, hidden) 63 | # print(patch_embs.shape) 64 | input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1) 65 | # print(input_embs.shape) 66 | output = model(inputs_embeds=input_embs) 67 | logit = readout(output.last_hidden_state[:, 0, :]) 68 | loss = F.cross_entropy(logit, labels.cuda()) 69 | # print(loss) 70 | loss.backward() 71 | optimizer.step() 72 | optimizer.zero_grad() 73 | pbar.set_description(f"loss: {loss.item():.4f}") 74 | 75 | # test 
on validation set 76 | model.eval() 77 | correct_cnt = 0 78 | total_loss = 0 79 | for i, (imgs, labels) in enumerate(val_loader): 80 | patch_embs = patch_embed(imgs.cuda()) 81 | patch_embs = patch_embs.flatten(2).permute(0, 2, 1) # (batch_size, HW, hidden) 82 | input_embs = torch.cat([CLS_token.expand(imgs.shape[0], 1, -1), patch_embs], dim=1) 83 | output = model(inputs_embeds=input_embs) 84 | logit = readout(output.last_hidden_state[:, 0, :]) 85 | loss = F.cross_entropy(logit, labels.cuda()) 86 | total_loss += loss.item() * imgs.shape[0] 87 | correct_cnt += (logit.argmax(dim=1) == labels.cuda()).sum().item() 88 | 89 | print(f"val loss: {total_loss / len(val_dataset):.4f}, val acc: {correct_cnt / len(val_dataset):.4f}") 90 | 91 | # over fitting, validation accuracy 92 | # patch size: 4 93 | # loss: 0.2320: 100%|███████████████████████████| 521/521 [00:36<00:00, 14.45it/s] 94 | # val loss: 1.3631, val acc: 0.6644 95 | # 100%|███████████████████████████████████████████| 25/25 [16:14<00:00, 38.97s/it] 96 | 97 | # over fitting, validation accuracy 98 | # patch size: 8 99 | # loss: 0.3418: 100%|██████████████████████████▉| 520/521 [00:36<00:00, 14.23it/s] 100 | # val loss: 1.7761, val acc: 0.6016 101 | # 96%|█████████████████████████████████████████▎ | 24/25 [15:06<00:38, 38.13s/it] 102 | 103 | # with data augmentation, patch size: 4 104 | # batch_size = 128, lr=5e-4 (AdamW) 105 | # loss: 0.1783: 100%|███████████████████████████| 391/391 [00:40<00:00, 9.71it/s] 106 | # val loss: 0.5719, val acc: 0.8309 107 | # 86%|████████████████████████████████████▉ | 43/50 [29:29<05:07, 43.88s/it] 108 | 109 | # loss: 0.1398: 100%|██████████████████████████▉| 390/391 [00:37<00:00, 10.69it/s] 110 | # val loss: 0.6580, val acc: 0.8181 111 | #%% 112 | # Visual Transformers. Despite some previous work in which attention is used inside the convolutional layers of a CNN [57, 26], the first fully-transformer architectures for vision are iGPT [8] and ViT [17]. The former is trained using a "masked-pixel" self-supervised approach, similar in spirit to the common masked-word task used, for instance, in BERT [15] and in GPT [45] (see below). On the other hand, ViT is trained in a supervised way, using a special "class token" and a classification head attached to the final embedding of this token. Both methods are computationally expensive and, despite their very good results when trained on huge datasets, they underperform ResNet architectures when trained from scratch using only ImageNet-1K [17, 8]. VideoBERT [51] is conceptually similar to iGPT, but, rather than using pixels as tokens, each frame of a video is holistically represented by a feature vector, which is quantized using an off-the-shelf pretrained video classification model. 
DeiT [53] trains ViT using distillation information provided by a pretrained CNN 113 | #%% 114 | # https://openreview.net/pdf?id=SCN8UaetXx 115 | from torchvision.models import resnet18 116 | model_cnn = resnet18(pretrained=False) 117 | model_cnn.fc = nn.Linear(512, 10) 118 | model_cnn.cuda() 119 | optimizer = AdamW(model_cnn.parameters(), lr=10e-4) 120 | #%% 121 | batch_size = 512 122 | train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 123 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) 124 | for epoch in trange(50, leave=False): 125 | model_cnn.train() 126 | pbar = tqdm(train_loader, leave=False) 127 | for i, (imgs, labels) in enumerate(pbar): 128 | output = model_cnn(imgs.cuda()) 129 | loss = F.cross_entropy(output, labels.cuda()) 130 | loss.backward() 131 | optimizer.step() 132 | optimizer.zero_grad() 133 | pbar.set_description(f"loss: {loss.item():.4f}") 134 | 135 | # test on validation set 136 | model_cnn.eval() 137 | correct_cnt = 0 138 | total_loss = 0 139 | for i, (imgs, labels) in enumerate(val_loader): 140 | output = model_cnn(imgs.cuda()) 141 | loss = F.cross_entropy(output, labels.cuda()) 142 | total_loss += loss.item() * imgs.shape[0] 143 | correct_cnt += (output.argmax(dim=1) == labels.cuda()).sum().item() 144 | 145 | print(f"val loss: {total_loss / len(val_dataset):.4f}, val acc: {correct_cnt / len(val_dataset):.4f}") 146 | 147 | # batch size: 96, lr 1E-4 148 | # loss: 0.1894: 100%|██████████████████████████▉| 519/521 [00:13<00:00, 37.36it/s] 149 | # 96%|█████████████████████████████████████████▎ | 24/25 [06:03<00:15, 15.17s/it] 150 | # val loss: 1.9748, val acc: 0.6292 151 | 152 | # batch size: 256, lr 10E-4 153 | # 96%|█████████████████████████████████████████▎ | 24/25 [03:41<00:09, 9.04s/it] 154 | # loss: 0.0179: 99%|██████████████████████████▊| 195/196 [00:07<00:00, 24.92it/s] 155 | # val loss: 1.6367, val acc: 0.7034 156 | 157 | # batch size: 512, lr 10E-4 with data augmentation 158 | # batch size: 512, lr 10E-4 with data augmentation 159 | # loss: 0.4298: 100%|█████████████████████████████| 98/98 [00:16<00:00, 6.22it/s] 160 | # val loss: 0.6143, val acc: 0.8060 161 | # 56%|████████████████████████ | 28/50 [08:37<06:40, 18.21s/it] 162 | # loss: 0.2113: 100%|█████████████████████████████| 98/98 [00:16<00:00, 6.28it/s] 163 | # val loss: 0.6069, val acc: 0.8352 164 | # 98%|██████████████████████████████████████████▏| 49/50 [14:58<00:18, 18.13s/it] 165 | 166 | 167 | #%% 168 | """ 169 | CNN is an architecture with strong inductive bias. 170 | It will learn the data faster than transformer. But it's also faster to overfit. 171 | When proper data augmentation is applied, CNN will outperform transformer. 172 | 173 | Transformer is an architecture with weak inductive bias 174 | It will learn the data slower / needs more compute than CNN. 
175 | """ 176 | torch.save(model_cnn.state_dict(),"cnn_resnet18.pth") 177 | #%% 178 | torch.save(model.state_dict(),"bert.pth") -------------------------------------------------------------------------------- /MNIST_generative_transformer.py: -------------------------------------------------------------------------------- 1 | ## Import transformers 2 | from transformers import get_linear_schedule_with_warmup 3 | from transformers import BertTokenizer, BertForSequenceClassification 4 | from transformers import BertModel, BertConfig 5 | from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, GPT2Model 6 | ## import MNIST from torchvision package 7 | from torchvision import datasets, transforms 8 | ## import torch 9 | import os 10 | from os.path import join 11 | from tqdm import tqdm, trange 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.optim import AdamW, Adam 17 | from torch.utils.data import Dataset, DataLoader 18 | from torchvision.utils import make_grid, save_image 19 | import matplotlib.pyplot as plt 20 | #%% 21 | from torchvision.datasets import MNIST 22 | 23 | dataset = MNIST(root='/home/binxu/Datasets', download=True, transform=transforms.ToTensor()) 24 | #%% 25 | # image to patch sequence by convolution 26 | def img2patch(img, patch_size=2): 27 | # img: (batch_size, channel, height, width) 28 | # patch_size: int 29 | # return: (batch_size, channel, height//patch_size, width//patch_size, patch_size**2) 30 | batch_size, channel, height, width = img.shape 31 | img = img.reshape(batch_size, channel, height//patch_size, patch_size, width//patch_size, patch_size) 32 | img = img.permute(0, 2, 4, 1, 3, 5) 33 | img = img.reshape(batch_size, height//patch_size, width//patch_size, channel * patch_size**2) 34 | return img 35 | 36 | 37 | def patch2seq(patches): 38 | batch_size, height, width, channel = patches.shape 39 | return patches.reshape(batch_size, height*width, channel) 40 | 41 | 42 | def seq2img(patch_seq, patch_size=2): 43 | # patch_seq: (batch_size, channel, height//patch_size, width//patch_size, patch_size**2) 44 | # patch_size: int 45 | # return: (batch_size, channel, height, width) 46 | batch_size, HW, hidden = patch_seq.shape 47 | height = width = int(math.sqrt(HW)) 48 | channel = hidden // patch_size**2 49 | patch_seq = patch_seq.reshape(batch_size, height, width, channel, patch_size, patch_size) 50 | patch_seq = patch_seq.permute(0, 3, 1, 4, 2, 5) 51 | imgtsr = patch_seq.reshape(batch_size, channel, height*patch_size, width*patch_size) 52 | return imgtsr 53 | #%% 54 | # GPT2Config() 55 | dataloaders = DataLoader(dataset, batch_size=64, shuffle=True) 56 | imgs, labels = next(iter(dataloaders)) 57 | print(patch2seq(img2patch(imgs, patch_size=4)).shape) 58 | assert torch.allclose(seq2img(patch2seq(img2patch(imgs,)),), imgs) 59 | assert torch.allclose(seq2img(patch2seq(img2patch(imgs, patch_size=4)), patch_size=4), imgs) 60 | #%% 61 | 62 | #%% Conditional model 63 | patch_size = 4 64 | config = GPT2Config(n_embd=128, n_layer=24, n_head=16, n_positions=256, 65 | vocab_size=100, bos_token_id=101, eos_token_id=102, 66 | add_cross_attention=True, ) 67 | 68 | model = GPT2Model(config).cuda() 69 | patch_emb = nn.Linear(patch_size * patch_size, config.n_embd).cuda() 70 | patch_readout = nn.Linear(config.n_embd, patch_size * patch_size).cuda() 71 | digit_emb = nn.Embedding(10, config.n_embd).cuda() 72 | optimizer = AdamW([*model.parameters(), 73 | *patch_emb.parameters(), 74 | *patch_readout.parameters(), 75 | 
*digit_emb.parameters()], lr=5e-4) 76 | #%% 77 | def generate_img(prompt_digit, prompt_patch, model, digit_emb, patch_emb, patch_readout, 78 | patch_size=4, pixel_size=28): 79 | max_seq_len = (pixel_size // patch_size) ** 2 - 1 80 | prompt_digit_emb = digit_emb(torch.tensor([*prompt_digit]).cuda())[:, None, :] 81 | patch_seq = [prompt_patch] 82 | input_patch_emb = patch_emb(prompt_patch) 83 | with torch.no_grad(): 84 | for i in range(max_seq_len): 85 | output = model(inputs_embeds=input_patch_emb, 86 | encoder_hidden_states=prompt_digit_emb) 87 | output_hiddens = output.last_hidden_state 88 | next_patch = patch_readout(output_hiddens[:, -1:, :]) 89 | input_patch_emb = torch.cat([input_patch_emb, patch_emb(next_patch)], dim=1) 90 | patch_seq.append(next_patch) 91 | 92 | patch_seq_tsr = torch.cat(patch_seq, dim=1) 93 | gen_imgs = seq2img(patch_seq_tsr, patch_size=patch_size) 94 | return gen_imgs 95 | #%% 96 | # saveroot = r"D:\DL_Projects\Vision\pixel_GPT" 97 | saveroot = r"/home/binxu/DL_Projects/patchGPT" 98 | runname = "conditional" 99 | os.makedirs(join(saveroot, runname), exist_ok=True) 100 | batch_size = 512 101 | dataloaders = DataLoader(dataset, batch_size=batch_size, shuffle=True) 102 | for epoch in trange(1,50): 103 | pbar = tqdm(dataloaders) 104 | model.train() 105 | for ibatch, (imgs, labels) in enumerate(pbar): 106 | digit_hiddens = digit_emb(labels.cuda())[:, None, :] 107 | patch_seq = patch2seq(img2patch(imgs.cuda(), patch_size=patch_size)) 108 | input_embeds = patch_emb(patch_seq) 109 | output = model(inputs_embeds=input_embeds, encoder_hidden_states=digit_hiddens) 110 | output_hiddens = output.last_hidden_state 111 | output_patches = patch_readout(output_hiddens) 112 | loss = F.mse_loss(output_patches[:, :-1, :], patch_seq[:, 1:, :]) 113 | optimizer.zero_grad() 114 | loss.backward() 115 | optimizer.step() 116 | pbar.set_description(f"loss: {loss.item():.4f}") 117 | # print(loss.item()) 118 | 119 | prompt_digit = range(10) 120 | prompt_patch = torch.zeros(10, 1, patch_size * patch_size).cuda() 121 | model.eval() 122 | gen_imgs = generate_img(prompt_digit, prompt_patch, model, digit_emb, patch_emb, patch_readout, 123 | patch_size=patch_size, pixel_size=28) 124 | save_image(make_grid(gen_imgs, nrow=5), join(saveroot, runname, f'{epoch}_genimages.png')) 125 | 126 | model.save_pretrained(join(saveroot, runname, "model")) 127 | patch_emb.cpu().save(join(saveroot, runname, "patch_emb.pth")) 128 | patch_readout.cpu().save(join(saveroot, runname, "patch_readout.pth")) 129 | #%% 130 | 131 | #%% 132 | prompt_digit = range(10) 133 | prompt_patch = torch.zeros(10, 1, patch_size * patch_size).cuda() 134 | gen_imgs = generate_img(prompt_digit, prompt_patch, model, digit_emb, patch_emb, patch_readout, 135 | patch_size=patch_size, pixel_size=28) 136 | plt.figure() 137 | plt.imshow(make_grid(gen_imgs, nrow=5).permute(1, 2, 0).cpu()) 138 | plt.show() 139 | 140 | #%% 141 | def generate_img_uncond(prompt_patch, model, patch_emb, patch_readout, 142 | patch_size=4, pixel_size=28): 143 | max_seq_len = (pixel_size // patch_size) ** 2 - 1 144 | patch_seq = [prompt_patch] 145 | input_patch_emb = patch_emb(prompt_patch) 146 | with torch.no_grad(): 147 | for i in range(max_seq_len): 148 | output = model(inputs_embeds=input_patch_emb,) 149 | output_hiddens = output.last_hidden_state 150 | next_patch = patch_readout(output_hiddens[:, -1:, :]) 151 | input_patch_emb = torch.cat([input_patch_emb, patch_emb(next_patch)], dim=1) 152 | patch_seq.append(next_patch) 153 | 154 | patch_seq_tsr = 
torch.cat(patch_seq, dim=1) 155 | gen_imgs = seq2img(patch_seq_tsr, patch_size=patch_size) 156 | return gen_imgs 157 | #%% 158 | patch_size = 4 159 | config = GPT2Config(n_embd=128, n_layer=24, n_head=16, n_positions=256, 160 | vocab_size=100, bos_token_id=101, eos_token_id=102, ) 161 | model = GPT2Model(config).cuda() 162 | patch_emb = nn.Linear(patch_size * patch_size, config.n_embd).cuda() 163 | patch_readout = nn.Linear(config.n_embd, patch_size * patch_size).cuda() 164 | optimizer = AdamW([*model.parameters(), 165 | *patch_emb.parameters(), 166 | *patch_readout.parameters()], lr=5e-4) 167 | saveroot = r"/home/binxu/DL_Projects/patchGPT" 168 | runname = "unconditional" 169 | os.makedirs(join(saveroot, runname), exist_ok=True) 170 | #%% 171 | batch_size = 512 172 | dataloaders = DataLoader(dataset, batch_size=batch_size, shuffle=True) 173 | for epoch in range(50): 174 | pbar = tqdm(dataloaders) 175 | model.train() 176 | for ibatch, (imgs, labels) in enumerate(pbar): 177 | patch_seq = patch2seq(img2patch(imgs.cuda(), patch_size=patch_size)) 178 | input_embeds = patch_emb(patch_seq) 179 | output = model(inputs_embeds=input_embeds) 180 | output_hiddens = output.last_hidden_state 181 | output_patches = patch_readout(output_hiddens) 182 | loss = F.mse_loss(output_patches[:, :-1, :], patch_seq[:, 1:, :]) 183 | optimizer.zero_grad() 184 | loss.backward() 185 | optimizer.step() 186 | pbar.set_description(f"loss: {loss.item():.4f}") 187 | # print(loss.item()) 188 | 189 | prompt_digit = range(10) 190 | prompt_patch = torch.zeros(10, 1, patch_size * patch_size).cuda() 191 | model.eval() 192 | gen_imgs = generate_img_uncond(prompt_patch, model, patch_emb, patch_readout, 193 | patch_size=patch_size, pixel_size=28) 194 | save_image(make_grid(gen_imgs, nrow=5), join(saveroot, runname, f'{epoch}_genimages.png')) 195 | 196 | model.save_pretrained(join(saveroot, runname, "model")) 197 | patch_emb.cpu().save(join(saveroot, runname, "patch_emb.pth")) 198 | patch_readout.cpu().save(join(saveroot, runname, "patch_readout.pth")) 199 | 200 | #%% 201 | make_grid(seq2img(output_patches, patch_size=patch_size)) 202 | # plt.imshow(make_grid(imgs[:16, :, :, :]).permute(1, 2, 0)) 203 | plt.figure() 204 | plt.imshow(make_grid(seq2img(output_patches.detach().cpu(), patch_size=patch_size)).permute(1, 2, 0)) 205 | plt.show() -------------------------------------------------------------------------------- /attention_to_Transformer.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import math 7 | seed = 42 8 | np.random.seed(seed) 9 | torch.manual_seed(seed) 10 | torch.cuda.manual_seed(seed) 11 | 12 | #%% Multi head attention 13 | embdim = 768 14 | headcnt = 12 15 | headdim = embdim // headcnt 16 | assert headdim * headcnt == embdim 17 | tokens = torch.randn(1, 5, embdim) # batch, tokens, embedding 18 | Wq = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim 19 | Wk = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim 20 | Wv = torch.randn(embdim, headcnt * headdim) / math.sqrt(embdim) # heads packed in a single dim 21 | 22 | batch, token_num, _ = tokens.shape 23 | qis = torch.einsum("BSE,EH->BSH", tokens, Wq) 24 | kis = torch.einsum("BTE,EH->BTH", tokens, Wk) 25 | vis = torch.einsum("BTE,EH->BTH", tokens, Wv) 26 | qis_mh = qis.view(batch, token_num, headcnt, headdim) 27 
| kis_mh = kis.view(batch, token_num, headcnt, headdim) 28 | vis_mh = vis.view(batch, token_num, headcnt, headdim) 29 | 30 | scoremat_mh = torch.einsum("BSCH,BTCH->BCST", qis_mh, kis_mh) # batch x headcnt x seqlen (query) x seqlen (key) 31 | attmat_mh = F.softmax(scoremat_mh / math.sqrt(headdim), dim=-1) 32 | zis_mh = torch.einsum("BCST,BTCH->BSCH", attmat_mh, vis_mh) # batch x seqlen (query) x headcnt x headdim 33 | # zis_mh = torch.einsum("BSCT,BTCH->BSCH", attmat_mh, vis_mh) 34 | zis = zis_mh.reshape(batch, token_num, headcnt * headdim) 35 | #%% 36 | mha = nn.MultiheadAttention(embdim, headcnt, batch_first=True,) 37 | # print(mha.in_proj_weight.shape) # 3 * embdim x embdim 38 | mha.in_proj_weight.data = torch.cat([Wq, Wk, Wv], dim=1).T 39 | attn_out, attn_weights = mha(tokens, tokens, tokens, average_attn_weights=False,) 40 | assert torch.allclose(attmat_mh, attn_weights, atol=1e-6, rtol=1e-6) 41 | assert torch.allclose(attn_out, mha.out_proj(zis), atol=1e-6, rtol=1e-6) 42 | 43 | #%% 44 | plt.figure() 45 | for head in range(headcnt): 46 | plt.subplot(3, 4, head + 1) 47 | plt.imshow(attmat_mh[0, head].detach().numpy()) 48 | plt.title(f"head {head}") 49 | plt.show() 50 | 51 | #%% Causal attention mask 52 | attn_mask = torch.ones(token_num,token_num,) 53 | attn_mask = -1E4 * torch.triu(attn_mask,1) 54 | attn_mask 55 | 56 | #%% 57 | scoremat_mh_msk = torch.einsum("BSCH,BTCH->BCST", qis_mh, kis_mh) # batch x headcnt x seqlen (query) x seqlen (key) 58 | scoremat_mh_msk += attn_mask # add the attn mask to the scores before SoftMax normalization 59 | attmat_mh_msk = F.softmax(scoremat_mh_msk / math.sqrt(headdim), dim=-1) 60 | zis_mh_msk = torch.einsum("BCST,BTCH->BSCH", attmat_mh_msk, vis_mh) # batch x seqlen (query) x headcnt x headdim 61 | # zis_mh = torch.einsum("BSCT,BTCH->BSCH", attmat_mh, vis_mh) 62 | zis_msk = zis_mh_msk.reshape(batch, token_num, headcnt * headdim) 63 | 64 | #%% 65 | attn_out_causal, attn_weights_causal = mha(tokens, tokens, tokens, average_attn_weights=False, attn_mask=attn_mask) 66 | assert torch.allclose(attn_weights_causal, attmat_mh_msk, atol=1e-6, rtol=1e-6) 67 | assert torch.allclose(attn_out_causal, mha.out_proj(zis_msk), atol=1e-6, rtol=1e-6) 68 | 69 | #%% 70 | class TransformerBlock_simple(nn.Module): 71 | 72 | def __init__(self, embdim, headcnt, *args, dropout=0.0, **kwargs) -> None: 73 | super().__init__(*args, **kwargs) 74 | self.ln1 = nn.LayerNorm(embdim) 75 | self.ln2 = nn.LayerNorm(embdim) 76 | self.attn = nn.MultiheadAttention(embdim, headcnt, batch_first=True,) 77 | self.ffn = nn.Sequential( 78 | nn.Linear(embdim, 4 * embdim), 79 | nn.GELU(), 80 | nn.Linear(4 * embdim, embdim), 81 | nn.Dropout(dropout), 82 | ) 83 | 84 | def forward(self, x, is_causal=True): 85 | batch, token_num, _ = x.shape 86 | if is_causal: 87 | attn_mask = torch.ones(token_num, token_num,) 88 | attn_mask = -1E4 * torch.triu(attn_mask,1) 89 | else: 90 | attn_mask = None 91 | 92 | residue = x 93 | x = self.ln1(x) 94 | attn_output, attn_weights = self.attn(x, x, x, attn_mask=attn_mask) # first output is the output latent states 95 | x = residue + attn_output 96 | 97 | residue = x 98 | x = self.ln2(x) 99 | ffn_output = self.ffn(x) 100 | output = residue + ffn_output 101 | return output 102 | 103 | #%% 104 | from transformers import GPT2Model, GPT2Tokenizer 105 | from transformers.activations import NewGELUActivation 106 | model = GPT2Model.from_pretrained("gpt2") 107 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 108 | model.eval() 109 | #%% 110 | #%% 111 | inputs = tokenizer("Hi, 
I have a cat, her name is", return_tensors="pt") 112 | outputs = model(**inputs, output_attentions=True, output_hidden_states=True) 113 | #%% 114 | token_strs = tokenizer.tokenize("Hi, I have a cat, her name is") 115 | #%% 116 | print("Shape of final output token vectors", outputs.last_hidden_state.shape) 117 | # attention of each GPTBlock: 118 | print("num of attention outputs", len(outputs.attentions)) 119 | # shape of each attention tensor: [batch, heads, token (source), token (target)] 120 | print("shape of each attention tensor", outputs.attentions[-1].shape) 121 | print("num of hidden states (input embed included.) ", len(outputs.hidden_states)) 122 | print("shape of each hidden states tensor", outputs.hidden_states[-1].shape) #[batch, token, hidden] 123 | assert torch.allclose(outputs.hidden_states[-1], outputs.last_hidden_state) 124 | #%% 125 | 126 | plt.figure(figsize=(10, 10)) 127 | for head in range(12): 128 | plt.subplot(3, 4, head + 1) 129 | plt.imshow(outputs.attentions[-1][0, head, :, :].detach().numpy()) 130 | plt.yticks(range(len(token_strs)), token_strs) 131 | plt.xticks(range(len(token_strs)), token_strs) 132 | plt.show() 133 | 134 | 135 | #%% GPT2 from scratch 136 | embdim = 768 137 | headcnt = 12 138 | tfmblock = TransformerBlock_simple(embdim, headcnt) 139 | 140 | #%% 141 | model.h[0].attn.c_attn.weight.shape 142 | tfmblock.attn.in_proj_weight.shape 143 | tfmblock.ln1.weight.data = model.h[0].ln_1.weight 144 | tfmblock.ln1.bias.data = model.h[0].ln_1.bias 145 | tfmblock.ln2.weight.data = model.h[0].ln_2.weight 146 | tfmblock.ln2.bias.data = model.h[0].ln_2.bias 147 | tfmblock.attn.in_proj_weight.data = model.h[0].attn.c_attn.weight.T 148 | tfmblock.attn.in_proj_bias.data = model.h[0].attn.c_attn.bias 149 | tfmblock.attn.out_proj.weight.data = model.h[0].attn.c_proj.weight.T 150 | tfmblock.attn.out_proj.bias.data = model.h[0].attn.c_proj.bias 151 | tfmblock.ffn[0].weight.data = model.h[0].mlp.c_fc.weight.T 152 | tfmblock.ffn[0].bias.data = model.h[0].mlp.c_fc.bias 153 | tfmblock.ffn[1] = NewGELUActivation() # mlp in GPT2 and BERT used a new GELU activation, using nn.GeLU() will cause a small error around 1E-3 154 | tfmblock.ffn[2].weight.data = model.h[0].mlp.c_proj.weight.T 155 | tfmblock.ffn[2].bias.data = model.h[0].mlp.c_proj.bias 156 | #%% 157 | def GPT2block_to_TransformerBlock_simple(tfmblock, gpt2block, ): 158 | """copy the weights from a GPT2 block to a TransformerBlock_simple""" 159 | tfmblock.ln1.weight.data = gpt2block.ln_1.weight 160 | tfmblock.ln1.bias.data = gpt2block.ln_1.bias 161 | tfmblock.ln2.weight.data = gpt2block.ln_2.weight 162 | tfmblock.ln2.bias.data = gpt2block.ln_2.bias 163 | tfmblock.attn.in_proj_weight.data = gpt2block.attn.c_attn.weight.T 164 | tfmblock.attn.in_proj_bias.data = gpt2block.attn.c_attn.bias 165 | tfmblock.attn.out_proj.weight.data = gpt2block.attn.c_proj.weight.T 166 | tfmblock.attn.out_proj.bias.data = gpt2block.attn.c_proj.bias 167 | tfmblock.ffn[0].weight.data = gpt2block.mlp.c_fc.weight.T 168 | tfmblock.ffn[0].bias.data = gpt2block.mlp.c_fc.bias 169 | tfmblock.ffn[1] = NewGELUActivation() 170 | # mlp in GPT2 and BERT used a new GELU activation, using nn.GeLU() will cause a small error around 1E-3 171 | tfmblock.ffn[2].weight.data = gpt2block.mlp.c_proj.weight.T 172 | tfmblock.ffn[2].bias.data = gpt2block.mlp.c_proj.bias 173 | return tfmblock 174 | 175 | 176 | def test_TransformerBlock_simple_GPT(block): 177 | tfmblock = TransformerBlock_simple(768, 12) 178 | GPT2block_to_TransformerBlock_simple(tfmblock, block) 179 | 
tokens_embs = torch.randn(2, 5, 768) 180 | tfmblock_out = tfmblock(tokens_embs, is_causal=True) 181 | block_out, = block(tokens_embs) 182 | assert torch.allclose(tfmblock_out, block_out, atol=1e-5, rtol=1e-5) 183 | 184 | 185 | GPT2block_to_TransformerBlock_simple(tfmblock, model.h[0]) 186 | tokens_embs = torch.randn(2, 5, 768) 187 | #%% 188 | tfmblock_out = tfmblock(tokens_embs, is_causal=True) 189 | modelblock_out, = model.h[0](tokens_embs) 190 | #%% 191 | assert torch.allclose(tfmblock_out, modelblock_out, atol=1e-5, rtol=1e-5) 192 | 193 | 194 | #%% 195 | class GPT2Model_simple(nn.Module): 196 | 197 | def __init__(self): 198 | super().__init__() 199 | self.wte = nn.Embedding(50257, 768) 200 | self.wpe = nn.Embedding(1024, 768) 201 | self.blocks = nn.ModuleList([TransformerBlock_simple(768, 12) for _ in range(12)]) 202 | self.ln_f = nn.LayerNorm(768) 203 | 204 | def forward(self, input_ids, input_embeds=None, is_causal=True): 205 | embeds = self.wte(input_ids) if input_embeds is None else input_embeds 206 | embeds = embeds + self.wpe(torch.arange(embeds.shape[1], device=embeds.device)) 207 | for block in self.blocks: 208 | embeds = block(embeds, is_causal=is_causal) 209 | return self.ln_f(embeds) 210 | 211 | 212 | def GPT2Model_to_GPT2Model_simple(gpt2modelsimple: GPT2Model_simple, gpt2model: GPT2Model): 213 | """copy the weights from a GPT2 model to a GPT2Model_simple""" 214 | gpt2modelsimple.wte.weight.data = gpt2model.wte.weight 215 | gpt2modelsimple.wpe.weight.data = gpt2model.wpe.weight 216 | gpt2modelsimple.ln_f.weight.data = gpt2model.ln_f.weight 217 | gpt2modelsimple.ln_f.bias.data = gpt2model.ln_f.bias 218 | for i in range(12): 219 | GPT2block_to_TransformerBlock_simple(gpt2modelsimple.blocks[i], gpt2model.h[i]) 220 | return gpt2modelsimple 221 | 222 | 223 | def test_our_GPT2(model: GPT2Model, tokenizer: GPT2Tokenizer, 224 | text: str = "I have a cat, her name is"): 225 | model_ours = GPT2Model_simple() 226 | GPT2Model_to_GPT2Model_simple(model_ours, model) 227 | inputs = tokenizer(text, return_tensors="pt") 228 | outputs = model(**inputs, ) 229 | hidden_last_ours = model_ours(inputs['input_ids']) 230 | assert torch.allclose(outputs.last_hidden_state, hidden_last_ours, atol=1e-5, rtol=1e-5) 231 | return model_ours 232 | 233 | model_ours = test_our_GPT2(model, tokenizer) -------------------------------------------------------------------------------- /train_maestro_GPT2.py: -------------------------------------------------------------------------------- 1 | maestro_v3_midi_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip" 2 | maestro_v3_meta_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.csv" 3 | maestro_v3_json_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.json" 4 | # %% 5 | # download and unzip the urls above 6 | # !wget $maestro_v3_midi_url 7 | # !wget $maestro_v3_meta_url 8 | # !wget $maestro_v3_json_url 9 | # !unzip maestro-v3.0.0-midi.zip 10 | # %% 11 | import os 12 | from os.path import join 13 | import numpy as np 14 | import pandas as pd 15 | import torch 16 | import torch.nn as nn 17 | from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config 18 | from torch.optim import AdamW, Adam 19 | from transformers import AdamW, get_linear_schedule_with_warmup 20 | import librosa.display 21 | import librosa 22 | import matplotlib.pyplot as plt 23 | from tqdm import tqdm, trange 24 | from typing import Union 25 | import pretty_midi 26 | 
import pickle 27 | 28 | # maestro_root = r"E:\Datasets\maestro-v3.0.0-midi\maestro-v3.0.0" 29 | maestro_root = r"/home/binxu/Datasets/maestro-v3.0.0" 30 | 31 | 32 | # visualize the midi file as piano roll 33 | def plot_piano_roll(pm, start_pitch, end_pitch, fs=100): 34 | # Use librosa's specshow function for displaying the piano roll 35 | librosa.display.specshow(pm.get_piano_roll(fs)[start_pitch:end_pitch], 36 | hop_length=1, sr=fs, x_axis='time', y_axis='cqt_note', 37 | fmin=pretty_midi.note_number_to_hz(start_pitch)) 38 | 39 | 40 | # load the metadata 41 | maestro_meta = pd.read_csv(join(maestro_root, "maestro-v3.0.0.csv")) 42 | # %% 43 | # https://notebook.community/craffel/pretty-midi/Tutorial 44 | # Each note has a start time, end time, pitch, and velocity. 45 | # 46 | # Velocity / volume of a note: 1-127 47 | # Pitch of a note: 0-127 C1-G9 48 | # Duration of a note: seconds, float number 49 | # Start time of a note: seconds, float number 50 | 51 | # %% 52 | # midi to note sequence 53 | def midi2notes(midi_file: Union[str, pretty_midi.PrettyMIDI]): 54 | if isinstance(midi_file, str): 55 | midi_data = pretty_midi.PrettyMIDI(midi_file) 56 | else: 57 | midi_data = midi_file 58 | 59 | note_t_seq = [] 60 | note_dt_seq = [] 61 | note_duration_seq = [] 62 | note_velo_seq = [] 63 | note_pitch_seq = [] 64 | prev_note_end = 0.0 65 | for i, note_sample in enumerate(midi_data.instruments[0].notes[1:]): 66 | note_dt_seq.append(note_sample.start - prev_note_end) 67 | note_t_seq.append(note_sample.start) 68 | note_duration_seq.append(note_sample.duration) 69 | note_velo_seq.append(note_sample.velocity) 70 | note_pitch_seq.append(note_sample.pitch) 71 | prev_note_end = note_sample.end 72 | 73 | note_seq = pd.DataFrame({"t": note_t_seq, "dt": note_dt_seq, "duration": note_duration_seq, "velo": note_velo_seq, 74 | "pitch": note_pitch_seq}) # .to_csv("note_seq.csv", index=False) 75 | return note_seq 76 | 77 | 78 | midi_file = join(maestro_root, maestro_meta["midi_filename"][0].replace("/", os.path.sep)) 79 | note_seq = midi2notes(midi_file) 80 | # %% 81 | class MusicScoreDataset(torch.utils.data.Dataset): 82 | 83 | def __init__(self, maestro_root, maestro_meta=None, ): 84 | self.maestro_root = maestro_root 85 | if maestro_meta is None: 86 | self.maestro_meta = pd.read_csv(join(maestro_root, "maestro-v3.0.0.csv")) 87 | else: 88 | self.maestro_meta = maestro_meta 89 | 90 | self.dataset = {} 91 | 92 | def load_dataset(self): 93 | for i in tqdm(range(len(self.maestro_meta))): 94 | row = self.maestro_meta.iloc[i] 95 | midi_path = join(self.maestro_root, row["midi_filename"].replace("/", os.path.sep)) 96 | # load sample midi file 97 | midi_data = pretty_midi.PrettyMIDI(midi_path) 98 | note_seq = midi2notes(midi_data) 99 | self.dataset[i] = (row[["canonical_composer", 'canonical_title', 'year', 'duration']], note_seq) 100 | 101 | return 102 | 103 | def __len__(self): 104 | return len(self.dataset) 105 | 106 | def __getitem__(self, idx): 107 | meta, note_seq = self.dataset[idx] 108 | return meta, note_seq 109 | 110 | 111 | dataset = MusicScoreDataset(maestro_root, maestro_meta) 112 | # %% 113 | dataset.dataset = pickle.load(open(join(maestro_root, "dataset.pkl"), "rb")) 114 | # %% 115 | # import pickle 116 | # pickle.dump(dataset.dataset, open(join(maestro_root, "dataset.pkl"), "wb")) 117 | # %% 118 | meta, note_seq = dataset[0] 119 | # %% 120 | # %% 121 | from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence 122 | # from torch.utils.data import DataLoader 123 | # loader = 
DataLoader(dataset, batch_size=128, collate_fn=batch_sampler, shuffle=True) 124 | # next(iter(loader)) 125 | token_num = 128 # 128 pitch classes 126 | BOS_id = token_num 127 | EOS_id = token_num + 1 128 | PAD_ID = EOS_id + 1 129 | MASK_LABEL_ID = - 100 130 | 131 | def notedf2tensor(note_df, device="cuda"): 132 | pitch_ids = torch.tensor(note_df.pitch.values).unsqueeze(0).long().to(device) 133 | velo = torch.tensor(note_df.velo.values).unsqueeze(0).float().to(device) / 128.0 134 | dur = torch.tensor(note_df.duration.values).unsqueeze(0).float().to(device) 135 | dt = torch.tensor(note_df.dt.values).unsqueeze(0).float().to(device) 136 | return pitch_ids, velo, dur, dt 137 | 138 | 139 | def batch_sampler(data_source, batch_idxs, max_seq_len=1024): 140 | batch_pitch = [] # integer value 141 | batch_velo = [] # integer value 142 | batch_dt = [] # float value 143 | batch_duration = [] # float value 144 | label_batch = [] 145 | if isinstance(batch_idxs, int): 146 | batch_idxs = range(batch_idxs) 147 | for i in batch_idxs: 148 | meta, note_seq = data_source[i] 149 | # sample note sequence with max_seq_len 150 | if len(note_seq) <= max_seq_len: 151 | note_seq_subsamp = note_seq 152 | else: 153 | seq_start = np.random.randint(0, len(note_seq) - max_seq_len) 154 | note_seq_subsamp = note_seq.iloc[seq_start:seq_start + max_seq_len] 155 | 156 | pitch, velo, dur, dt = notedf2tensor(note_seq_subsamp, device="cpu") 157 | batch_pitch.append(pitch[0]) 158 | batch_velo.append(velo[0]) 159 | batch_dt.append(dt[0]) 160 | batch_duration.append(dur[0]) 161 | # batch_pitch.append(torch.tensor(note_seq_subsamp["pitch"].values).long()) 162 | # batch_velo.append(torch.tensor(note_seq_subsamp["velo"].values).long()) 163 | # batch_dt.append(torch.tensor(note_seq_subsamp["dt"].values)) 164 | # batch_duration.append(torch.tensor(note_seq_subsamp["duration"].values)) 165 | 166 | batch_pitch = pad_sequence(batch_pitch, batch_first=True, padding_value=PAD_ID) 167 | batch_velo = pad_sequence(batch_velo, batch_first=True, padding_value=0.5) 168 | batch_dt = pad_sequence(batch_dt, batch_first=True, padding_value=0.0) 169 | batch_duration = pad_sequence(batch_duration, batch_first=True, padding_value=0.0) 170 | # label_batch = pad_sequence(label_batch, batch_first=True, padding_value=MASK_LABEL_ID) 171 | return batch_pitch, batch_velo, batch_dt, batch_duration 172 | 173 | 174 | batch_pitch, batch_velo, batch_dt, batch_duration = batch_sampler(dataset, range(32), max_seq_len=512) 175 | 176 | # %% Submodules for a Music Note Transformer 177 | class NoteReadoutHeads(nn.Module): 178 | def __init__(self, config): 179 | super().__init__() 180 | self.config = config 181 | self.velocity_head = nn.Sequential( 182 | nn.Linear(config.n_embd, config.n_embd), 183 | nn.Tanh(), 184 | nn.Linear(config.n_embd, 1), 185 | ) 186 | self.duration_head = nn.Sequential( 187 | nn.Linear(config.n_embd, config.n_embd), 188 | nn.Tanh(), 189 | nn.Linear(config.n_embd, 1), 190 | nn.Softplus(), # duration should be positive 191 | ) 192 | self.dt_head = nn.Sequential( 193 | nn.Linear(config.n_embd, config.n_embd), 194 | nn.Tanh(), 195 | nn.Linear(config.n_embd, 1), 196 | ) 197 | 198 | def forward(self, hidden_states): 199 | velocity = self.velocity_head(hidden_states).squeeze(-1) 200 | duration = self.duration_head(hidden_states).squeeze(-1) 201 | dt = self.dt_head(hidden_states).squeeze(-1) 202 | return velocity, duration, dt 203 | 204 | 205 | class ScalarEmbedding(nn.Module): 206 | def __init__(self, config, ): 207 | super().__init__() 208 | self.n_embd = 
config.n_embd 209 | self.weights = nn.Parameter(torch.randn(1, 1, self.n_embd) * np.sqrt(1 / self.n_embd)) 210 | 211 | def forward(self, x): 212 | return self.weights * x[:, :, None] 213 | 214 | # %% 215 | 216 | def naive_greedy_decode(model, prompt_ids, prompt_velos, prompt_durations, prompt_dts, max_gen_tokens=500, 217 | temperature=None): 218 | model.eval() 219 | readout.eval() 220 | dt_emb.eval() 221 | duration_emb.eval() 222 | velocity_emb.eval() 223 | pitch_embed = model.get_input_embeddings()(prompt_ids) 224 | velo_embed = velocity_emb(prompt_velos) 225 | dur_embed = duration_emb(prompt_durations) 226 | dt_embed = dt_emb(prompt_dts) 227 | input_embed = pitch_embed + velo_embed + dur_embed + dt_embed 228 | input_embed_full = input_embed.clone() 229 | generated_note_seq = [] 230 | with torch.no_grad(): 231 | for i in trange(max_gen_tokens): 232 | out = model(inputs_embeds=input_embed_full, output_hidden_states=True) 233 | last_hidden = out.hidden_states[-1] 234 | last_logits = out.logits[:, -1, :] 235 | # pred_velos, pred_durations, pred_dts = readout(last_hidden) 236 | pred_velo, pred_duration, pred_dt = readout(last_hidden[:, -1, :]) 237 | if temperature is None: 238 | # greedy decoding 239 | pitch_maxprob = torch.argmax(last_logits, dim=-1) 240 | else: 241 | # Sample from the distribution with temperature 242 | pitch_maxprob = torch.multinomial(F.softmax(last_logits / temperature, dim=-1), 243 | num_samples=1).squeeze(1) 244 | input_embed_new = model.get_input_embeddings()(pitch_maxprob) + \ 245 | velocity_emb(pred_velo.unsqueeze(0)) + \ 246 | duration_emb(pred_duration.unsqueeze(0)) + \ 247 | dt_emb(pred_dt.unsqueeze(0)) 248 | generated_note_seq.append((pitch_maxprob, pred_velo, pred_duration, pred_dt)) 249 | input_embed_full = torch.concat((input_embed_full, input_embed_new), dim=1) 250 | 251 | pitch_gen = torch.stack([pitch for pitch, velo, dur, dt in generated_note_seq], dim=1) 252 | velo_gen = torch.stack([velo for pitch, velo, dur, dt in generated_note_seq], dim=1) 253 | dur_gen = torch.stack([dur for pitch, velo, dur, dt in generated_note_seq], dim=1) 254 | dt_gen = torch.stack([dt for pitch, velo, dur, dt in generated_note_seq], dim=1) 255 | pitch_gen = torch.concat((prompt_ids, pitch_gen), dim=1).cpu() 256 | velo_gen = torch.concat((prompt_velos, velo_gen), dim=1).cpu() # this is 0, 1 normalized 257 | dur_gen = torch.concat((prompt_durations, dur_gen), dim=1).cpu() 258 | dt_gen = torch.concat((prompt_dts, dt_gen), dim=1).cpu() 259 | velo_gen = (velo_gen * 128).long() # convert to integer 260 | return pitch_gen, velo_gen, dur_gen, dt_gen 261 | 262 | 263 | def tensor2midi(pitch, velo, dur, dt, savedir, filename, instrument="Acoustic Grand Piano"): 264 | generated_midi = pretty_midi.PrettyMIDI() 265 | piano_program = pretty_midi.instrument_name_to_program(instrument) 266 | piano = pretty_midi.Instrument(program=piano_program) 267 | cur_time = 0 268 | for i, pitch in enumerate(pitch): 269 | cur_time = cur_time + dt[i].item() 270 | note = pretty_midi.Note(velocity=min(127, max(0, velo[i].item())), 271 | pitch=min(127, max(0, pitch.item())), 272 | start=cur_time, end=cur_time + dur[i].item()) 273 | cur_time = cur_time + dur[i].item() 274 | piano.notes.append(note) 275 | 276 | generated_midi.instruments.append(piano) 277 | generated_midi.write(join(savedir, filename)) 278 | return generated_midi 279 | 280 | 281 | 282 | # %% 283 | import torch.nn.functional as F 284 | from torch.optim import AdamW 285 | from torch.utils.tensorboard import SummaryWriter 286 | saveroot = 
"/home/binxu/DL_Projects/Maestro-GPT" #"runs" 287 | 288 | # config = GPT2Config(n_embd=128, n_layer=12, n_head=8, n_positions=512, n_ctx=128, 289 | # vocab_size=132, bos_token_id=BOS_id, eos_token_id=EOS_id, ) 290 | # config = GPT2Config(n_embd=128, n_layer=18, n_head=8, n_positions=512, n_ctx=128, 291 | # vocab_size=token_num + 4, bos_token_id=BOS_id, eos_token_id=EOS_id, ) 292 | config = GPT2Config(n_embd=128, n_layer=36, n_head=8, n_positions=256, 293 | vocab_size=token_num + 4, bos_token_id=BOS_id, eos_token_id=EOS_id, ) 294 | model = GPT2LMHeadModel(config) 295 | readout = NoteReadoutHeads(config) 296 | dt_emb = ScalarEmbedding(config).cuda() 297 | duration_emb = ScalarEmbedding(config).cuda() 298 | velocity_emb = ScalarEmbedding(config).cuda() 299 | model.cuda() 300 | readout.cuda() 301 | optimizer = AdamW([*model.parameters(), 302 | *readout.parameters(), 303 | *dt_emb.parameters(), 304 | *duration_emb.parameters(), 305 | *velocity_emb.parameters(), 306 | ], lr=10e-4) 307 | scheduler = get_linear_schedule_with_warmup(optimizer, 50, 3000) 308 | # savedir = join(saveroot, "runs_L18") 309 | # savedir = join(saveroot, "runs_L18_lrsched_nexttok") 310 | # savedir = join(saveroot, "runs_L36_ctx256_lrsched_nexttok") 311 | savedir = join(saveroot, "runs_L36_ctx256_lrsched_nexttok_fixeval") 312 | os.makedirs(savedir, exist_ok=True) 313 | os.makedirs(join(savedir, "ckpt"), exist_ok=True) 314 | # %% 315 | save_per_epoch = 10 316 | synth_per_epoch = 10 317 | batch_size = 24 318 | max_seq_len = 256 319 | writer = SummaryWriter(savedir) 320 | for epoch in trange(0, 3000): 321 | for module in [model, readout, dt_emb, duration_emb, velocity_emb]: 322 | module.train() 323 | rand_idx_seq = np.random.permutation(len(dataset)) 324 | for i, csr in enumerate(trange(0, len(dataset), batch_size)): 325 | batch_idxs = rand_idx_seq[csr:csr + batch_size] 326 | batch_pitch, batch_velo, batch_dt, batch_duration = batch_sampler(dataset, 327 | batch_idxs, max_seq_len=max_seq_len) 328 | pitch_embed = model.transformer.wte(batch_pitch.cuda()) 329 | note_embed = pitch_embed + velocity_emb(batch_velo.cuda()) + \ 330 | dt_emb(batch_dt.cuda()) + duration_emb(batch_duration.cuda()) # 331 | 332 | # out = model(batch_pitch.cuda(), labels=batch_pitch.cuda(), output_hidden_states=True) 333 | out = model(inputs_embeds=note_embed, labels=batch_pitch.cuda(), output_hidden_states=True) 334 | loss = out.loss 335 | last_hidden = out.hidden_states[-1] # (batch, seq_len, embd) 336 | velocity_pred, duration_pred, dt_pred = readout(last_hidden) 337 | # note this is wrong, need to predict the next note so shift by 1 338 | # loss_velo = F.mse_loss(velocity_pred, batch_velo.cuda()) 339 | # loss_duration = F.mse_loss(duration_pred, batch_duration.cuda()) 340 | # loss_dt = F.mse_loss(dt_pred, batch_dt.cuda()) 341 | # this is correct prediction 342 | loss_velo = F.mse_loss(velocity_pred[:, :-1], batch_velo.cuda()[:, 1:]) 343 | loss_duration = F.mse_loss(duration_pred[:, :-1], batch_duration.cuda()[:, 1:]) 344 | loss_dt = F.mse_loss(dt_pred[:, :-1], batch_dt.cuda()[:, 1:]) 345 | loss += loss_velo + loss_duration + loss_dt 346 | # compute additional loss based on last hidden state 347 | loss.backward() 348 | optimizer.step() 349 | optimizer.zero_grad() 350 | print( 351 | f"epoch{epoch}-step{i:03d} {loss.item():.5f} vel {loss_velo.item():.5f} dur {loss_duration.item():.5f} dt {loss_dt.item():.5f}") 352 | writer.add_scalar("loss", loss.item(), epoch * len(dataset) // batch_size + i) 353 | writer.add_scalar("loss_velo", loss_velo.item(), epoch * 
len(dataset) // batch_size + i) 354 | writer.add_scalar("loss_duration", loss_duration.item(), epoch * len(dataset) // batch_size + i) 355 | writer.add_scalar("loss_dt", loss_dt.item(), epoch * len(dataset) // batch_size + i) 356 | writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch * len(dataset) // batch_size + i) 357 | writer.add_scalar("epoch", epoch, epoch * len(dataset) // batch_size + i) 358 | # writer.add_scalar("grad_norm", torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0),) 359 | scheduler.step() 360 | if epoch % save_per_epoch == 0: 361 | torch.save(model.state_dict(), join(savedir, "ckpt", f"model_{epoch:03d}.pt")) 362 | torch.save(readout.state_dict(), join(savedir, "ckpt", f"readout_{epoch:03d}.pt")) 363 | torch.save(dt_emb.state_dict(), join(savedir, "ckpt", f"dt_emb_{epoch:03d}.pt")) 364 | torch.save(duration_emb.state_dict(), join(savedir, "ckpt", f"duration_emb_{epoch:03d}.pt")) 365 | torch.save(velocity_emb.state_dict(), join(savedir, "ckpt", f"velocity_emb_{epoch:03d}.pt")) 366 | 367 | if epoch % synth_per_epoch == 0: 368 | prompt_ids, prompt_velos, prompt_durations, prompt_dts = notedf2tensor(note_seq[:10], device="cuda") 369 | pitch_gen, velo_gen, dur_gen, dt_gen = naive_greedy_decode(model, 370 | prompt_ids, prompt_velos, prompt_durations, prompt_dts, 371 | temperature=.2, max_gen_tokens=max_seq_len - 12) 372 | tensor2midi(pitch_gen[0], velo_gen[0], dur_gen[0], dt_gen[0], 373 | savedir, f"generated_midi_with_tempo_T02_{epoch:03d}.mid") 374 | #%% 375 | 376 | torch.save(readout.state_dict(), join(savedir, "ckpt", f"readout_{epoch:03d}.pt")) 377 | torch.save(dt_emb.state_dict(), join(savedir, "ckpt", f"dt_emb_{epoch:03d}.pt")) 378 | torch.save(duration_emb.state_dict(), join(savedir, "ckpt", f"duration_emb_{epoch:03d}.pt")) 379 | torch.save(velocity_emb.state_dict(), join(savedir, "ckpt", f"velocity_emb_{epoch:03d}.pt")) 380 | model.save_pretrained(join(savedir, "ckpt", f"model_{epoch:03d}.pt")) 381 | config.save_pretrained(join(savedir, "ckpt","config.json")) 382 | #%% 383 | # sampling from the model 384 | model.eval() 385 | readout.eval() 386 | dt_emb.eval() 387 | duration_emb.eval() 388 | velocity_emb.eval() 389 | # %% 390 | prompt_ids = torch.tensor(note_seq.pitch[:10]).unsqueeze(0).long().cuda() 391 | answers = model.generate(prompt_ids, max_length=128, do_sample=True, 392 | top_k=0, top_p=0.90, num_return_sequences=3, 393 | bos_token_id=BOS_id, eos_token_id=EOS_id, pad_token_id=PAD_ID, 394 | output_hidden_states=True, return_dict_in_generate=True) 395 | 396 | 397 | 398 | #%% Generate sequence 399 | prompt_ids = torch.tensor(note_seq.pitch[:10]).unsqueeze(0).long().cuda() 400 | answers = model.generate(prompt_ids, max_length=100, do_sample=True, 401 | top_k=0, top_p=0.90, num_return_sequences=3, 402 | bos_token_id=BOS_id, eos_token_id=EOS_id, pad_token_id=PAD_ID, 403 | output_hidden_states=True, return_dict_in_generate=True) 404 | #%% 405 | #%% 406 | # generate sequence with readout and embedding 407 | 408 | #%% 409 | 410 | 411 | # prompt_ids = torch.tensor(note_seq.pitch[:12]).unsqueeze(0).long().cuda() 412 | # prompt_velos = torch.tensor(note_seq.velo[:12]).unsqueeze(0).cuda().float() / 128.0 413 | # prompt_durations = torch.tensor(note_seq.duration[:12]).unsqueeze(0).cuda().float() 414 | # prompt_dts = torch.tensor(note_seq.dt[:12]).unsqueeze(0).cuda().float() 415 | prompt_ids, prompt_velos, prompt_durations, prompt_dts = notedf2tensor(note_seq[:10], device="cuda") 416 | pitch_gen, velo_gen, dur_gen, dt_gen = naive_greedy_decode(model, 417 | 
prompt_ids, prompt_velos, prompt_durations, prompt_dts, 418 | temperature=.5, max_gen_tokens=500) 419 | tensor2midi(pitch_gen[0], velo_gen[0], dur_gen[0], dt_gen[0], 420 | savedir, "generated_midi_with_tempo_T0_5.mid") 421 | 422 | 423 | #%% 424 | note_seq_full = answers.sequences[0].cpu() 425 | note_seq_full = note_seq_full[note_seq_full != PAD_ID] 426 | for pitch in note_seq_full: 427 | print(librosa.midi_to_note(pitch), end=" ") 428 | 429 | 430 | 431 | 432 | #%% save to midi and play 433 | # create midi file based on the generated sequence 434 | generated_midi = pretty_midi.PrettyMIDI() 435 | piano_program = pretty_midi.instrument_name_to_program('Acoustic Grand Piano') 436 | piano = pretty_midi.Instrument(program=piano_program) 437 | delta = 0.25 438 | for i, pitch in enumerate(note_seq_full): 439 | note = pretty_midi.Note(velocity=100, pitch=pitch.item(), start=i*delta, end=(i+1)*delta) 440 | piano.notes.append(note) 441 | #%% 442 | generated_midi.instruments.append(piano) 443 | generated_midi.write(join(savedir, "generated_midi.mid")) 444 | 445 | 446 | # %% Old version 447 | # dataset = MusicScoreDataset(maestro_root, maestro_meta) 448 | batch_size = 12 449 | writer = SummaryWriter(savedir) 450 | for epoch in trange(0, 100): 451 | rand_idx_seq = np.random.permutation(len(dataset)) 452 | for i, csr in enumerate(trange(0, len(dataset), batch_size)): 453 | batch_idxs = rand_idx_seq[csr:csr + batch_size] 454 | batch_pitch, batch_velo, batch_dt, batch_duration = batch_sampler(dataset, 455 | batch_idxs, max_seq_len=512) 456 | out = model(batch_pitch.cuda(), labels=batch_pitch.cuda(), output_hidden_states=True) 457 | loss = out.loss 458 | last_hidden = out.hidden_states[-1] # (batch, seq_len, embd) 459 | velocity_pred, duration_pred, dt_pred = readout(last_hidden) 460 | # note this is wrong, need to predict the next note so shift by 1 461 | loss_velo = F.mse_loss(velocity_pred, batch_velo.cuda()) 462 | loss_duration = F.mse_loss(duration_pred, batch_duration.cuda()) 463 | loss_dt = F.mse_loss(dt_pred, batch_dt.cuda()) 464 | loss += loss_velo + loss_duration + loss_dt 465 | # compute additional loss based on last hidden state 466 | loss.backward() 467 | optimizer.step() 468 | optimizer.zero_grad() 469 | print( 470 | f"epoch{epoch}-step{i:03d} {loss.item():.5f} vel {loss_velo.item():.5f} dur {loss_duration.item():.5f} dt {loss_dt.item():.5f}") 471 | writer.add_scalar("loss", loss.item(), epoch * len(dataset) // batch_size + i) 472 | writer.add_scalar("loss_velo", loss_velo.item(), epoch * len(dataset) // batch_size + i) 473 | writer.add_scalar("loss_duration", loss_duration.item(), epoch * len(dataset) // batch_size + i) 474 | writer.add_scalar("loss_dt", loss_dt.item(), epoch * len(dataset) // batch_size + i) 475 | writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch * len(dataset) // batch_size + i) 476 | writer.add_scalar("epoch", epoch, epoch * len(dataset) // batch_size + i) 477 | # writer.add_scalar("grad_norm", torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0),) 478 | 479 | torch.save(model.state_dict(), join(savedir, "ckpt", f"model_{epoch:03d}.pt")) 480 | 481 | #%% 482 | 483 | 484 | 485 | 486 | 487 | #%%# %% 488 | # # load sample midi file 489 | # midi_file = join(maestro_root, maestro_meta["midi_filename"][0].replace("/", os.path.sep)) 490 | # midi_data = pretty_midi.PrettyMIDI(midi_file) 491 | # 492 | # print(midi_data.instruments) 493 | # len(midi_data.instruments[0].notes) 494 | # # %% 495 | # # create a new figure 496 | # fig, ax = plt.subplots(figsize=(12, 4)) 
497 | # # plot the piano roll 498 | # plot_piano_roll(midi_data, 24, 84) 499 | # plt.tight_layout() 500 | # plt.show() 501 | # # %% 502 | # note_dt_dist = [] 503 | # note_len_dist = [] 504 | # note_velo_dist = [] 505 | # note_pitch_dist = [] 506 | # for i, note_sample in enumerate(midi_data.instruments[0].notes[1:]): 507 | # note_dt_dist.append(note_sample.start - midi_data.instruments[0].notes[i].end)  # previous note is notes[i] when iterating over notes[1:] 508 | # note_len_dist.append(note_sample.duration) 509 | # note_velo_dist.append(note_sample.velocity) 510 | # note_pitch_dist.append(note_sample.pitch) 511 | # # %% 512 | # 513 | # # %% 514 | # plt.subplots(1, 3, figsize=(12, 4)) 515 | # plt.subplot(1, 3, 1) 516 | # plt.hist(note_len_dist, bins=5000) 517 | # plt.xlim(0, 3) 518 | # plt.title("Note Length Distribution") 519 | # plt.subplot(1, 3, 2) 520 | # plt.hist(note_velo_dist, bins=100) 521 | # plt.title("Note Velocity Distribution") 522 | # plt.subplot(1, 3, 3) 523 | # plt.hist(note_pitch_dist, bins=100) 524 | # plt.title("Note Pitch Distribution") 525 | # plt.show() 526 | # 527 | # # %% 528 | # print(len(maestro_meta.canonical_composer.unique())) 529 | # print(len(maestro_meta.canonical_title.unique())) 530 | # 531 | # # %% 532 | # notes_str = [librosa.midi_to_note(note.pitch) 533 | # for note in midi_data.instruments[0].notes] 534 | # # %% 535 | # midi_data.time_to_tick(midi_data.instruments[0].notes[-1].end) 536 | # # %% --------------------------------------------------------------------------------
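# --- Added sketch (not part of the original training script) ------------------
# The training loop above regresses velocity / duration / dt with plain MSE over
# every position, including the PAD_ID positions introduced by pad_sequence.
# Below is a minimal sketch of a padding-aware, next-step MSE; `masked_shifted_mse`
# is a hypothetical helper, and it assumes batch_pitch uses PAD_ID for padding.
import torch

def masked_shifted_mse(pred: torch.Tensor, target: torch.Tensor,
                       pitch_ids: torch.Tensor, pad_id: int) -> torch.Tensor:
    """MSE between pred[:, :-1] and target[:, 1:], ignoring padded targets."""
    valid = (pitch_ids[:, 1:] != pad_id).float()      # (batch, seq_len-1) mask of real notes
    sq_err = (pred[:, :-1] - target[:, 1:]) ** 2      # next-step prediction error
    return (sq_err * valid).sum() / valid.sum().clamp(min=1.0)

# Possible drop-in inside the loop above (names follow the script):
# loss_velo = masked_shifted_mse(velocity_pred, batch_velo.cuda(), batch_pitch.cuda(), PAD_ID)
# loss_duration = masked_shifted_mse(duration_pred, batch_duration.cuda(), batch_pitch.cuda(), PAD_ID)
# loss_dt = masked_shifted_mse(dt_pred, batch_dt.cuda(), batch_pitch.cuda(), PAD_ID)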