├── .gitignore ├── imgs ├── piano_roll.png ├── rnn_unfold.jpg ├── piano_roll_2.png ├── character_level_model.jpg ├── sampling_temperature.png ├── piano_roll_early_sample.png └── piano_roll_late_sample.png ├── README.md ├── notebooks ├── visualization.py ├── unconditional_lyrics_sampling.ipynb ├── conditional_lyrics_sampling.ipynb └── truncated_backprop_music_generation.ipynb ├── blog_post.md └── blog_post.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | notebooks/models/ 3 | -------------------------------------------------------------------------------- /imgs/piano_roll.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmspringwinds/pytorch-rnn-sequence-generation-classification/HEAD/imgs/piano_roll.png -------------------------------------------------------------------------------- /imgs/rnn_unfold.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmspringwinds/pytorch-rnn-sequence-generation-classification/HEAD/imgs/rnn_unfold.jpg -------------------------------------------------------------------------------- /imgs/piano_roll_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmspringwinds/pytorch-rnn-sequence-generation-classification/HEAD/imgs/piano_roll_2.png -------------------------------------------------------------------------------- /imgs/character_level_model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmspringwinds/pytorch-rnn-sequence-generation-classification/HEAD/imgs/character_level_model.jpg -------------------------------------------------------------------------------- /imgs/sampling_temperature.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmspringwinds/pytorch-rnn-sequence-generation-classification/HEAD/imgs/sampling_temperature.png -------------------------------------------------------------------------------- /imgs/piano_roll_early_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmspringwinds/pytorch-rnn-sequence-generation-classification/HEAD/imgs/piano_roll_early_sample.png -------------------------------------------------------------------------------- /imgs/piano_roll_late_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmspringwinds/pytorch-rnn-sequence-generation-classification/HEAD/imgs/piano_roll_late_sample.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lyrics and piano music generation in Pytorch 2 | 3 | 4 | Implementation of generative character-level and multi-pitch-level rnn models described in "Learning to generate lyrics and music with Recurrent Neural Networks" [blog post](http://warmspringwinds.github.io/pytorch/rnns/2018/01/27/learning-to-generate-lyrics-and-music-with-recurrent-neural-networks/). 5 | 6 | The original jupyter notebook source of the blog can be found [here](blog_post.ipynb). 
7 | 8 | Trained models can be downloaded via the following [link](https://www.dropbox.com/s/23d9n091jje8sct/music_lyrics_blogpost_models.zip?dl=0). You can skip training and 9 | sample using the provided models (follow the jupyter notebooks below). Some examples of the piano song samples [are available on youtube](https://www.youtube.com/watch?v=EOQQOQYvGnw&list=PLJkMX36nfYD000TG-T59hmEgJ3ojkOlBp), and examples of lyrics samples can be found in the [original blog post](http://warmspringwinds.github.io/pytorch/rnns/2018/01/27/learning-to-generate-lyrics-and-music-with-recurrent-neural-networks/). 10 | 11 | ## Lyrics Generation 12 | 13 | We are providing jupyter notebooks for training and sampling from generative RNN-based models 14 | trained on a [song lyrics dataset](https://www.kaggle.com/mousehead/songlyrics) which features the most 15 | popular/recent artists. Separate notebooks are available for: 16 | 17 | 1. Training of the unconditional RNN-based generative model on the specified lyrics dataset ([notebook file](notebooks/unconditional_lyrics_training.ipynb)). 18 | 2. Sampling from a trained unconditional RNN-based generative model ([notebook file](notebooks/unconditional_lyrics_sampling.ipynb)). 19 | 3. Training of the conditional RNN-based generative model ([notebook file](notebooks/conditional_lyrics_training.ipynb)). 20 | 4. Sampling from a trained conditional RNN-based generative model ([notebook file](notebooks/conditional_lyrics_sampling.ipynb)). 21 | 22 | ## Piano polyphonic MIDI song generation 23 | 24 | We are providing jupyter notebooks for training and sampling from generative RNN-based models 25 | trained on a [piano MIDI songs dataset](http://www-etud.iro.umontreal.ca/~boulanni/icml2012). Separate notebooks are available for: 26 | 27 | 1. Training of the RNN-based generative model on the specified piano MIDI dataset ([notebook file](notebooks/music_generation_training_nottingham.ipynb)). 2. Sampling from a trained RNN-based generative model ([notebook file](notebooks/music_sampling.ipynb)). -------------------------------------------------------------------------------- /notebooks/visualization.py: -------------------------------------------------------------------------------- 1 | class VizList(list): 2 | """Extended list class which can be bound to a matplotlib pyplot axis 3 | and, when a value is appended, automatically updates the figure. 4 | 5 | Originally designed to be used in a jupyter notebook with the 6 | %matplotlib notebook mode activated. 
7 | 8 | Example of usage: 9 | 10 | %matplotlib notebook 11 | from matplotlib import pyplot as plt 12 | f, (loss_axis, validation_axis) = plt.subplots(2, 1) 13 | loss_axis.set_title('Training loss') 14 | validation_axis.set_title('MIoU on validation dataset') 15 | plt.tight_layout() 16 | 17 | loss_list = VizList() 18 | validation_accuracy_res = VizList() 19 | train_accuracy_res = VizList() 20 | loss_axis.plot([], []) 21 | validation_axis.plot([], [], 'b', 22 | [], [], 'r') 23 | loss_list.bind_to_axis(loss_axis) 24 | validation_accuracy_res.bind_to_axis(validation_axis, 0) 25 | train_accuracy_res.bind_to_axis(validation_axis, 1) 26 | 27 | Now everytime the list are updated, the figure are updated 28 | automatically: 29 | 30 | # Run multiple times 31 | loss_list.append(1) 32 | loss_list.append(2) 33 | 34 | 35 | Attributes 36 | ---------- 37 | axis : pyplot axis object 38 | Axis object that is being binded with a list 39 | axis_index : int 40 | Index of the plot in the axis object to bind to 41 | 42 | """ 43 | 44 | def __init__(self, *args): 45 | 46 | super(VizList, self).__init__(*args) 47 | 48 | self.object_count = 0 49 | self.object_count_history = [] 50 | 51 | self.axis = None 52 | self.axis_index = None 53 | 54 | def append(self, object): 55 | 56 | self.object_count += 1 57 | self.object_count_history.append(self.object_count) 58 | super(VizList, self).append(object) 59 | 60 | self.update_axis() 61 | 62 | def bind_to_axis(self, axis, axis_index=0): 63 | 64 | self.axis = axis 65 | self.axis_index = axis_index 66 | 67 | def update_axis(self): 68 | 69 | self.axis.lines[self.axis_index].set_xdata(self.object_count_history) 70 | self.axis.lines[self.axis_index].set_ydata(self) 71 | 72 | self.axis.relim() 73 | self.axis.autoscale_view() 74 | self.axis.figure.canvas.draw() -------------------------------------------------------------------------------- /notebooks/unconditional_lyrics_sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torch.nn as nn\n", 13 | "from torch.autograd import Variable\n", 14 | "\n", 15 | "import pandas as pd\n", 16 | "\n", 17 | "import random\n", 18 | "import string\n", 19 | "import numpy as np\n", 20 | "\n", 21 | "import sys, os\n", 22 | "\n", 23 | "import torch.utils.data as data\n", 24 | "\n", 25 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 26 | "\n", 27 | "all_characters = string.printable\n", 28 | "number_of_characters = len(all_characters)\n", 29 | "\n", 30 | "\n", 31 | "def character_to_label(character):\n", 32 | " \"\"\"Returns a one-hot-encoded tensor given a character.\n", 33 | " \n", 34 | " Uses string.printable as a dictionary.\n", 35 | " \n", 36 | " Parameters\n", 37 | " ----------\n", 38 | " character : str\n", 39 | " A character\n", 40 | " \n", 41 | " Returns\n", 42 | " -------\n", 43 | " one_hot_tensor : Tensor of shape (1, number_of_characters)\n", 44 | " One-hot-encoded tensor\n", 45 | " \"\"\"\n", 46 | " \n", 47 | " character_label = all_characters.find(character)\n", 48 | " \n", 49 | " return character_label\n", 50 | "\n", 51 | "def string_to_labels(character_string):\n", 52 | " \n", 53 | " return map(lambda character: character_to_label(character), character_string)\n", 54 | "\n", 55 | "\n", 56 | "class RNN(nn.Module):\n", 57 | " \n", 58 | " def __init__(self, input_size, hidden_size, num_classes, 
n_layers=2):\n", 59 | " \n", 60 | " super(RNN, self).__init__()\n", 61 | " \n", 62 | " self.input_size = input_size\n", 63 | " self.hidden_size = hidden_size\n", 64 | " self.num_classes = num_classes\n", 65 | " self.n_layers = n_layers\n", 66 | " \n", 67 | " # Converts labels into one-hot encoding and runs a linear\n", 68 | " # layer on each of the converted one-hot encoded elements\n", 69 | " \n", 70 | " # input_size -- size of the dictionary + 1 (accounts for padding constant)\n", 71 | " self.encoder = nn.Embedding(input_size, hidden_size)\n", 72 | " \n", 73 | " self.gru = nn.LSTM(hidden_size, hidden_size, n_layers)\n", 74 | " \n", 75 | " self.logits_fc = nn.Linear(hidden_size, num_classes)\n", 76 | " \n", 77 | " \n", 78 | " def forward(self, input_sequences, input_sequences_lengths, hidden=None):\n", 79 | " \n", 80 | " batch_size = input_sequences.shape[1]\n", 81 | "\n", 82 | " embedded = self.encoder(input_sequences)\n", 83 | "\n", 84 | " # Here we run rnns only on non-padded regions of the batch\n", 85 | " packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_sequences_lengths)\n", 86 | " outputs, hidden = self.gru(packed, hidden)\n", 87 | " outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs) # unpack (back to padded)\n", 88 | " \n", 89 | " logits = self.logits_fc(outputs)\n", 90 | " \n", 91 | " logits = logits.transpose(0, 1).contiguous()\n", 92 | " \n", 93 | " logits_flatten = logits.view(-1, self.num_classes)\n", 94 | " \n", 95 | " return logits_flatten, hidden\n" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 2, 101 | "metadata": { 102 | "collapsed": true 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "rnn = RNN(input_size=len(all_characters) + 1, hidden_size=512, num_classes=len(all_characters))\n", 107 | "rnn.load_state_dict(torch.load('models/unconditional_lyrics_rnn.pth'))\n", 108 | "rnn.cuda()\n", 109 | "\n", 110 | "def sample_from_rnn(starting_sting=\"Why\", sample_length=300, temperature=1):\n", 111 | "\n", 112 | " sampled_string = starting_sting\n", 113 | " hidden = None\n", 114 | "\n", 115 | " first_input = torch.LongTensor( string_to_labels(starting_sting) ).cuda()\n", 116 | " first_input = first_input.unsqueeze(1)\n", 117 | " current_input = Variable(first_input)\n", 118 | "\n", 119 | " output, hidden = rnn(current_input, [len(sampled_string)], hidden=hidden)\n", 120 | "\n", 121 | " output = output[-1, :].unsqueeze(0)\n", 122 | "\n", 123 | " for i in xrange(sample_length):\n", 124 | "\n", 125 | " output_dist = nn.functional.softmax( output.view(-1).div(temperature) ).data\n", 126 | "\n", 127 | " predicted_label = torch.multinomial(output_dist, 1)\n", 128 | "\n", 129 | " sampled_string += all_characters[int(predicted_label[0])]\n", 130 | "\n", 131 | " current_input = Variable(predicted_label.unsqueeze(1))\n", 132 | "\n", 133 | " output, hidden = rnn(current_input, [1], hidden=hidden)\n", 134 | " \n", 135 | " return sampled_string" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 7, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "name": "stderr", 145 | "output_type": "stream", 146 | "text": [ 147 | "/home/daniil/repos/anaconda2/lib/python2.7/site-packages/ipykernel_launcher.py:20: UserWarning: Implicit dimension choice for softmax has been deprecated. 
Change the call to include dim=X as an argument.\n" 148 | ] 149 | }, 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "The end of the day \n", 155 | "I felt the love that I was lost \n", 156 | "I saw the world a shadow \n", 157 | "I saw the stars above \n", 158 | "I saw the friend of my days \n", 159 | "With my soul and the days \n", 160 | "That was my destiny \n", 161 | " \n", 162 | "The story ends \n", 163 | "The sound of the blue \n", 164 | "The tears were shining \n", 165 | "The story of my life \n", 166 | "I still believe \n", 167 | "The story of my life \n", 168 | " \n", 169 | "For the stars above \n", 170 | "The love I feel \n", 171 | "The stars in the sky \n", 172 | "The stars are bright \n", 173 | "The stars above \n", 174 | " \n", 175 | "The stars are bright \n", 176 | "The stars are bright \n", 177 | "The stars are bright \n", 178 | "The stars\n" 179 | ] 180 | } 181 | ], 182 | "source": [ 183 | "print(sample_from_rnn(temperature=0.5, starting_sting=\"The end\", sample_length=500))" 184 | ] 185 | } 186 | ], 187 | "metadata": { 188 | "kernelspec": { 189 | "display_name": "Python 2", 190 | "language": "python", 191 | "name": "python2" 192 | }, 193 | "language_info": { 194 | "codemirror_mode": { 195 | "name": "ipython", 196 | "version": 2 197 | }, 198 | "file_extension": ".py", 199 | "mimetype": "text/x-python", 200 | "name": "python", 201 | "nbconvert_exporter": "python", 202 | "pygments_lexer": "ipython2", 203 | "version": "2.7.14" 204 | } 205 | }, 206 | "nbformat": 4, 207 | "nbformat_minor": 2 208 | } 209 | -------------------------------------------------------------------------------- /notebooks/conditional_lyrics_sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torch.nn as nn\n", 13 | "from torch.autograd import Variable\n", 14 | "\n", 15 | "import pandas as pd\n", 16 | "import random\n", 17 | "import string\n", 18 | "import numpy as np\n", 19 | "\n", 20 | "import sys, os\n", 21 | "\n", 22 | "import torch.utils.data as data\n", 23 | "\n", 24 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 25 | "\n", 26 | "all_characters = string.printable\n", 27 | "number_of_characters = len(all_characters)\n", 28 | "\n", 29 | "artists = [\n", 30 | "'ABBA',\n", 31 | "'Ace Of Base',\n", 32 | "'Aerosmith',\n", 33 | "'Avril Lavigne',\n", 34 | "'Backstreet Boys',\n", 35 | "'Bob Marley',\n", 36 | "'Bon Jovi',\n", 37 | "'Britney Spears',\n", 38 | "'Bruno Mars',\n", 39 | "'Coldplay',\n", 40 | "'Def Leppard',\n", 41 | "'Depeche Mode',\n", 42 | "'Ed Sheeran',\n", 43 | "'Elton John',\n", 44 | "'Elvis Presley',\n", 45 | "'Eminem',\n", 46 | "'Enrique Iglesias',\n", 47 | "'Evanescence',\n", 48 | "'Fall Out Boy',\n", 49 | "'Foo Fighters',\n", 50 | "'Green Day',\n", 51 | " 'HIM',\n", 52 | " 'Imagine Dragons',\n", 53 | " 'Incubus',\n", 54 | " 'Jimi Hendrix',\n", 55 | " 'Justin Bieber',\n", 56 | " 'Justin Timberlake',\n", 57 | "'Kanye West',\n", 58 | " 'Katy Perry',\n", 59 | " 'The Killers',\n", 60 | " 'Kiss',\n", 61 | " 'Lady Gaga',\n", 62 | " 'Lana Del Rey',\n", 63 | " 'Linkin Park',\n", 64 | " 'Madonna',\n", 65 | " 'Marilyn Manson',\n", 66 | " 'Maroon 5',\n", 67 | " 'Metallica',\n", 68 | " 'Michael Bolton',\n", 69 | " 'Michael Jackson',\n", 70 | " 'Miley Cyrus',\n", 71 | " 'Nickelback',\n", 72 | " 'Nightwish',\n", 73 | " 
'Nirvana',\n", 74 | " 'Oasis',\n", 75 | " 'Offspring',\n", 76 | " 'One Direction',\n", 77 | " 'Ozzy Osbourne',\n", 78 | " 'P!nk',\n", 79 | " 'Queen',\n", 80 | " 'Radiohead',\n", 81 | " 'Red Hot Chili Peppers',\n", 82 | " 'Rihanna',\n", 83 | " 'Robbie Williams',\n", 84 | " 'Rolling Stones',\n", 85 | " 'Roxette',\n", 86 | " 'Scorpions',\n", 87 | " 'Snoop Dogg',\n", 88 | " 'Sting',\n", 89 | " 'The Script',\n", 90 | " 'U2',\n", 91 | " 'Weezer',\n", 92 | " 'Yellowcard',\n", 93 | " 'ZZ Top']\n", 94 | "\n", 95 | "\n", 96 | "def character_to_label(character):\n", 97 | " \"\"\"Returns a one-hot-encoded tensor given a character.\n", 98 | " \n", 99 | " Uses string.printable as a dictionary.\n", 100 | " \n", 101 | " Parameters\n", 102 | " ----------\n", 103 | " character : str\n", 104 | " A character\n", 105 | " \n", 106 | " Returns\n", 107 | " -------\n", 108 | " one_hot_tensor : Tensor of shape (1, number_of_characters)\n", 109 | " One-hot-encoded tensor\n", 110 | " \"\"\"\n", 111 | " \n", 112 | " character_label = all_characters.find(character)\n", 113 | " \n", 114 | " return character_label\n", 115 | "\n", 116 | "\n", 117 | "\n", 118 | "def string_to_labels(character_string):\n", 119 | " \n", 120 | " return map(lambda character: character_to_label(character), character_string)\n", 121 | "\n", 122 | "\n", 123 | "class RNN(nn.Module):\n", 124 | " \n", 125 | " def __init__(self, input_size, hidden_size, num_classes, num_conditions, n_layers=2):\n", 126 | " \n", 127 | " super(RNN, self).__init__()\n", 128 | " \n", 129 | " self.input_size = input_size\n", 130 | " self.hidden_size = hidden_size\n", 131 | " self.num_classes = num_classes\n", 132 | " self.n_layers = n_layers\n", 133 | " self.num_conditions = num_conditions\n", 134 | " \n", 135 | " # Converts labels into one-hot encoding and runs a linear\n", 136 | " # layer on each of the converted one-hot encoded elements\n", 137 | " \n", 138 | " # input_size -- size of the dictionary + 1 (accounts for padding constant)\n", 139 | " self.characters_encoder = nn.Embedding(input_size, hidden_size)\n", 140 | " \n", 141 | " self.conditions_encoder = nn.Embedding(num_conditions, hidden_size)\n", 142 | " \n", 143 | " self.lstm = nn.LSTM(hidden_size * 2, hidden_size, n_layers)\n", 144 | " \n", 145 | " self.logits_fc = nn.Linear(hidden_size, num_classes)\n", 146 | " \n", 147 | " \n", 148 | " def forward(self, input_sequences, input_sequences_conditions, input_sequences_lengths, hidden=None):\n", 149 | " \n", 150 | " batch_size = input_sequences.shape[1]\n", 151 | "\n", 152 | " characters_encoded = self.characters_encoder(input_sequences)\n", 153 | " conditions_endoded = self.conditions_encoder(input_sequences_conditions)\n", 154 | " \n", 155 | " encodings_combined = torch.cat((characters_encoded, conditions_endoded), dim=2)\n", 156 | "\n", 157 | " # Here we run rnns only on non-padded regions of the batch\n", 158 | " packed = torch.nn.utils.rnn.pack_padded_sequence(encodings_combined, input_sequences_lengths)\n", 159 | " outputs, hidden = self.lstm(packed, hidden)\n", 160 | " outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs) # unpack (back to padded)\n", 161 | " \n", 162 | " logits = self.logits_fc(outputs)\n", 163 | " \n", 164 | " logits = logits.transpose(0, 1).contiguous()\n", 165 | " \n", 166 | " logits_flatten = logits.view(-1, self.num_classes)\n", 167 | " \n", 168 | " return logits_flatten, hidden" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 5, 174 | "metadata": {}, 175 | "outputs": [], 176 | 
"source": [ 177 | "rnn = RNN(input_size=len(all_characters) + 1,\n", 178 | " hidden_size=512,\n", 179 | " num_classes=len(all_characters),\n", 180 | " num_conditions=len(artists))\n", 181 | "\n", 182 | "\n", 183 | "rnn.load_state_dict(torch.load('models/conditional_lyrics_rnn.pth'))\n", 184 | "\n", 185 | "rnn.cuda()\n", 186 | "\n", 187 | "def sample_from_rnn_conditionally(starting_sting=\"Why\", sample_length=300, temperature=1, artist_label=0):\n", 188 | " \n", 189 | " sampled_string = starting_sting\n", 190 | " hidden = None\n", 191 | "\n", 192 | " first_input = torch.LongTensor( string_to_labels(starting_sting) ).cuda()\n", 193 | " first_input = first_input.unsqueeze(1)\n", 194 | "\n", 195 | " # Expand the artist label to have the same size as input sequence\n", 196 | " # we duplicate it in every input\n", 197 | " artist_label_input = torch.LongTensor([artist_label]).expand_as(first_input)\n", 198 | "\n", 199 | " current_sequence_input = Variable(first_input)\n", 200 | " current_artist_input = Variable(artist_label_input.cuda())\n", 201 | "\n", 202 | " output, hidden = rnn(current_sequence_input, current_artist_input, [len(sampled_string)], hidden=hidden)\n", 203 | "\n", 204 | " output = output[-1, :].unsqueeze(0)\n", 205 | "\n", 206 | " for i in xrange(sample_length):\n", 207 | "\n", 208 | " output_dist = nn.functional.softmax( output.view(-1).div(temperature) ).data\n", 209 | "\n", 210 | " predicted_label = torch.multinomial(output_dist, 1)\n", 211 | "\n", 212 | " sampled_string += all_characters[int(predicted_label[0])]\n", 213 | " current_sequence_input = Variable(predicted_label.unsqueeze(1))\n", 214 | "\n", 215 | " artist_label_input = torch.LongTensor([artist_label]).expand_as(current_sequence_input)\n", 216 | " current_artist_input = Variable(artist_label_input.cuda())\n", 217 | "\n", 218 | " output, hidden = rnn(current_sequence_input, current_artist_input, [1], hidden=hidden)\n", 219 | " \n", 220 | " return sampled_string" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 7, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stderr", 230 | "output_type": "stream", 231 | "text": [ 232 | "/home/daniil/repos/anaconda2/lib/python2.7/site-packages/ipykernel_launcher.py:32: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n" 233 | ] 234 | }, 235 | { 236 | "name": "stdout", 237 | "output_type": "stream", 238 | "text": [ 239 | "Why do you hear my soul? \n", 240 | "If I could change your back \n", 241 | "When you see where I belong? \n", 242 | "If I had the chance I could make any strange? \n", 243 | "When youre not it agree \n", 244 | "I would not live the way you love me \n", 245 | " \n", 246 | "When I was young your drives \n", 247 | "What's your days when I was young \n", 248 | "What a while? 
\n", 249 | "I wouldn't miss \n" 250 | ] 251 | } 252 | ], 253 | "source": [ 254 | "print(sample_from_rnn_conditionally(artist_label=artists.index(\"Queen\"),\n", 255 | " temperature=0.5,\n", 256 | " starting_sting=\"Why\"))" 257 | ] 258 | } 259 | ], 260 | "metadata": { 261 | "kernelspec": { 262 | "display_name": "Python 2", 263 | "language": "python", 264 | "name": "python2" 265 | }, 266 | "language_info": { 267 | "codemirror_mode": { 268 | "name": "ipython", 269 | "version": 2 270 | }, 271 | "file_extension": ".py", 272 | "mimetype": "text/x-python", 273 | "name": "python", 274 | "nbconvert_exporter": "python", 275 | "pygments_lexer": "ipython2", 276 | "version": "2.7.14" 277 | } 278 | }, 279 | "nbformat": 4, 280 | "nbformat_minor": 2 281 | } 282 | -------------------------------------------------------------------------------- /blog_post.md: -------------------------------------------------------------------------------- 1 | 2 | ## Learning to generate lyrics and music with Recurrent Neural Networks 3 | 4 | In this post we will train RNN character-level language model on lyrics dataset of 5 | most popular/recent artists. Having a trained model, we will sample a couple of 6 | songs which will be a funny mixture of different styles of different artists. 7 | After that we will update our model to become a conditional character-level RNN, 8 | making it possible for us to sample songs conditioned on artist. 9 | And finally, we conclude by training our model on midi dataset of piano songs. 10 | While solving all these tasks, we will briefly explore some interesting concepts related to RNN 11 | training and inference like character-level RNN, conditional character-level RNN, 12 | sampling from RNN, truncated backpropagation through time and gradient checkpointing. 13 | 14 | 15 | ### Character-Level language model 16 | 17 |  18 | 19 | Before choosing a model, let's have a closer look at our task. Given current letter and all previous 20 | letters, we will try to predict the next character. During training we will just take a sequence, and use 21 | all its characters except the last one as an input and the same sequence starting from the second character as groundtruth (see the picture above; [Source](https://github.com/spro/practical-pytorch/blob/master/conditional-char-rnn/conditional-char-rnn.ipynb)). We will start from the simplest model that ignores all the previous characters while making a prediction, improve this model to make it take only a certain number of previous characters into account, and conclude with a model that takes all the previous characters into consideration while making a prediction. 22 | 23 | Our language model is defined on a character level. We will create a dictionary which will contain 24 | all English characters plus some special symbols, like period, comma, and end-of-line symbol. Each charecter will be represented as one-hot-encoded tensor. For more information about character-level models and examples, I recommend [this resource](https://github.com/spro/practical-pytorch). 25 | 26 | Having characters, we can now form sequences of characters. We can generate sentences even now just by 27 | randomly sampling character after character with a fixed probability $p(any~letter)=\frac{1}{dictionary~size}$. 28 | That's the most simple character level language model. Can we do better than this? 
Yes, we can compute the probability of occurrence of each letter from our training corpus (the number of times a letter occurs divided by the size of our dataset) and randomly sample letters using these probabilities. This model is better, but it totally ignores the relative positional aspect of each letter. For example, pay attention to how you read any word: you start with the first letter, which is usually hard to predict, but as you reach the end of a word you can sometimes guess the next letter. When you read any word you are implicitly using some rules which you learned by reading other texts: for example, with each additional letter that you read from a word, the probability of a space character increases (really long words are rare), or the probability of any consonant after the letter "r" is low, as it is usually followed by a vowel. There are a lot of similar rules and we hope that our model will be able to learn them from data. To give our model a chance to learn these rules we need to extend it. 29 | 30 | Let's make a small gradual improvement of our model and let the probability of each letter depend 31 | only on the previously occurring letter (the [Markov assumption](https://en.wikipedia.org/wiki/Markov_property)). So, basically, we will have $p(current~letter|previous~letter)$. 32 | This is a [Markov chain model](https://en.wikipedia.org/wiki/Markov_chain) (also try these [interactive visualizations](http://setosa.io/ev/markov-chains/) if you are not familiar with it). We can also estimate the probability distribution $p(current~letter|previous~letter)$ from our training dataset. This model is limited because in most cases the probability of the current letter depends not only on the previous letter. 33 | 34 | What we would like to model is actually $p(current~letter|all~previous~letters)$. At first, the task seems intractable, as the number of previous letters is variable and it might become really large in the case of long 35 | sequences. It turns out that Recurrent Neural Networks can tackle this problem to a certain extent by using shared weights and a fixed-size hidden state. This leads us to the next section, dedicated to RNNs. 36 | 37 | ### Recurrent Neural Networks 38 | 39 |  40 | 41 | Recurrent neural networks are a family of neural networks for processing sequential data. 42 | Unlike feedforward neural networks, RNNs can use their internal memory to process arbitrary sequences of inputs. 43 | Because they accept input sequences of arbitrary size, they are concisely depicted as a graph with a cycle (see the picture; [Source](http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns/)). 44 | But they can be "unfolded" if the size of the input sequence is known. They define a non-linear mapping from a current input $x_t$ and previous hidden state $s_{t-1}$ to the output $o_t$ and current hidden state $s_t$. The hidden state has a predefined size and stores features which are updated on each step and affect the result of the mapping. 45 | 46 | Now align the previous picture of the character-level language model and the unfolded RNN picture to see how 47 | we are using the RNN model to learn a character-level language model. 48 | 49 | While the picture depicts the Vanilla RNN, we will use an LSTM in our work, as it is easier to train and usually achieves better results. 50 | 51 | For a more elaborate introduction to RNNs, we refer the reader to the [following resource](http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns/). 
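To make this mapping concrete, below is a minimal sketch of such a character-level LSTM language model. The hyperparameters are illustrative and the snippet assumes a reasonably recent PyTorch version; the full models used for the experiments in this post live in the accompanying notebooks.

```python
import torch
import torch.nn as nn

class CharRNN(nn.Module):
    """Minimal character-level language model: embedding -> LSTM -> logits."""

    def __init__(self, vocab_size, hidden_size=128, n_layers=2):
        super(CharRNN, self).__init__()
        self.encoder = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers)
        self.logits_fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_sequence, hidden=None):
        # input_sequence: LongTensor of character labels, shape (seq_len, batch)
        embedded = self.encoder(input_sequence)
        outputs, hidden = self.lstm(embedded, hidden)
        # Logits over the next character for every position of the sequence
        return self.logits_fc(outputs), hidden

# Toy usage: run one random sequence of 10 characters through the model.
vocab_size = 100  # e.g. len(string.printable)
model = CharRNN(vocab_size)
dummy_input = torch.randint(0, vocab_size, (10, 1))  # (seq_len, batch)
logits, _ = model(dummy_input)
print(logits.shape)  # torch.Size([10, 1, 100])
```

During training, the logits at every position are compared with the next character of the same sequence using a cross-entropy loss, which is exactly the "input vs. input shifted by one character" setup described above.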
52 | 53 | ### Lyrics dataset 54 | 55 | For our experiments we have chosen [55000+ Song Lyrics Kaggle dataset](https://www.kaggle.com/mousehead/songlyrics) which contains good variety of recent artists and more older ones. It is stored as a pandas file and we wrote a python wrapper around it to be able to use it for training purposes. You will have to download it yourself in order to be able to use our code. 56 | 57 | In order to be able to interpret the results better, I have chosen a subset of artists which I am 58 | more or less familiar with: 59 | 60 | 61 | ```python 62 | artists = [ 63 | 'ABBA', 64 | 'Ace Of Base', 65 | 'Aerosmith', 66 | 'Avril Lavigne', 67 | 'Backstreet Boys', 68 | 'Bob Marley', 69 | 'Bon Jovi', 70 | 'Britney Spears', 71 | 'Bruno Mars', 72 | 'Coldplay', 73 | 'Def Leppard', 74 | 'Depeche Mode', 75 | 'Ed Sheeran', 76 | 'Elton John', 77 | 'Elvis Presley', 78 | 'Eminem', 79 | 'Enrique Iglesias', 80 | 'Evanescence', 81 | 'Fall Out Boy', 82 | 'Foo Fighters', 83 | 'Green Day', 84 | 'HIM', 85 | 'Imagine Dragons', 86 | 'Incubus', 87 | 'Jimi Hendrix', 88 | 'Justin Bieber', 89 | 'Justin Timberlake', 90 | 'Kanye West', 91 | 'Katy Perry', 92 | 'The Killers', 93 | 'Kiss', 94 | 'Lady Gaga', 95 | 'Lana Del Rey', 96 | 'Linkin Park', 97 | 'Madonna', 98 | 'Marilyn Manson', 99 | 'Maroon 5', 100 | 'Metallica', 101 | 'Michael Bolton', 102 | 'Michael Jackson', 103 | 'Miley Cyrus', 104 | 'Nickelback', 105 | 'Nightwish', 106 | 'Nirvana', 107 | 'Oasis', 108 | 'Offspring', 109 | 'One Direction', 110 | 'Ozzy Osbourne', 111 | 'P!nk', 112 | 'Queen', 113 | 'Radiohead', 114 | 'Red Hot Chili Peppers', 115 | 'Rihanna', 116 | 'Robbie Williams', 117 | 'Rolling Stones', 118 | 'Roxette', 119 | 'Scorpions', 120 | 'Snoop Dogg', 121 | 'Sting', 122 | 'The Script', 123 | 'U2', 124 | 'Weezer', 125 | 'Yellowcard', 126 | 'ZZ Top'] 127 | ``` 128 | 129 | ### Training unconditional character-level language model 130 | 131 | Our first experiment consisted of training of our character-level language model RNN 132 | on the whole corpus. We didn't take into consideration the artist information while training. 133 | 134 | ### Sampling from RNN 135 | 136 | Let's try to sample a couple of songs after training our model. Basically, on each 137 | step our RNN will output logits and we can softmax them and sample from that distribution. 138 | Or we can use Gumble-Max trick and [sample using logits directly](https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/) which is equivalent. 139 | 140 | One intersting thing about sampling is that we can partially define the input sequence ourselves and start sampling 141 | with that initial condition. For example, we can sample a song that starts with "Why": 142 | 143 | ``` 144 | Why do you have to leave me? 
145 | I think I know I'm not the only one 146 | I don't know if I'm gonna stay awake 147 | I don't know why I go along 148 | 149 | I don't know why I can't go on 150 | I don't know why I don't know 151 | I don't know why I don't know 152 | I don't know why I keep on dreaming of you 153 | ``` 154 | 155 | Well, that sounds like a possible song :D 156 | 157 | Let's sample a song that starts with "Well": 158 | 159 | ``` 160 | Well, I was a real good time 161 | I was a rolling stone 162 | I was a rock and roller 163 | Well, I never had a rock and roll 164 | There were times I had to do it 165 | I had a feeling that I was found 166 | I was the one who had to go 167 | ``` 168 | 169 | There is a "temperature" parameter that is used during sampling and controls the randomness of the sampling 170 | process. When this parameter approaches zero, 171 | the sampling is equivalent to argmax, and when it is close to infinity the sampling is equivalent to sampling 172 | from a uniform distribution. Have a look at the figure from a [relevant paper by Jang et al.](https://arxiv.org/pdf/1611.01144.pdf): 173 | 174 | 175 |  176 | 177 | When $\tau=1$, the distribution is not affected. If we decrease $\tau$, the distribution 178 | becomes more pronounced, meaning that the value with the biggest probability mass will have it increased. As $\tau$ approaches zero, sampling becomes equivalent to argmax, because the probability of that value will be close to one. As we start to increase $\tau$, the distribution becomes more and more uniform. 179 | 180 | The previous sample was generated with a temperature parameter equal to $0.5$. 181 | Let's see what happens when we increase it to $1.0$ and sample: 182 | 183 | ``` 184 | Why can't we drop out of time? 185 | We were born for words to see. 186 | Won't you love this. You're still so amazing. 187 | This could be that down on Sunday Time. 188 | Oh, Caroline, a lady floor. 189 | I thought of love, oh baby. 190 | ``` 191 | 192 | Let's try increasing it even more: 193 | 194 | 195 | ``` 196 | Why - won't we grow up naked? 197 | We went quietly what we would've still give 198 | That girl you walked before our bedroom room 199 | I see your mind is so small to a freak 200 | Stretching for a cold white-heart of crashing 201 | Truth in the universal daughter 202 | 203 | I lose more and more hard 204 | I love you anytime at all 205 | Ah come let your help remind me 206 | Now I've wanted waste and never noticed 207 | 208 | I swear I saw you today 209 | You needed to get by 210 | But you sold a hurricane 211 | Well out whispered in store 212 | ``` 213 | 214 | Why don't we grow up naked, indeed? :D 215 | Well, you can see the trend: when we increase the temperature, the sampled 216 | sentences become more and more random. 217 | 218 | ### Training conditional character-level language model 219 | 220 | Imagine if we could generate lyrics in the style of some particular artist. 221 | Let's change our model so that it can use this information during training. 222 | 223 | We will do this by adding an additional input to our RNN. So far, our RNN model 224 | was only accepting tensors containing a one-hot encoded character on each step. 225 | 226 | The extension to our model will be very simple: we will have an additional one-hot encoded 227 | tensor which will represent the artist. So on each step the RNN will accept one tensor which will consist of the concatenated tensors representing the character and the artist. 
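A condensed sketch of this conditioning, mirroring the conditional model from the accompanying notebook (the sizes here are illustrative), could look like this:

```python
import torch
import torch.nn as nn

class ConditionalCharRNN(nn.Module):
    """Character-level LSTM conditioned on an artist label via concatenation."""

    def __init__(self, vocab_size, num_artists, hidden_size=128, n_layers=2):
        super(ConditionalCharRNN, self).__init__()
        self.char_encoder = nn.Embedding(vocab_size, hidden_size)
        self.artist_encoder = nn.Embedding(num_artists, hidden_size)
        # The LSTM consumes the concatenation of the two embeddings.
        self.lstm = nn.LSTM(hidden_size * 2, hidden_size, n_layers)
        self.logits_fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, chars, artists, hidden=None):
        # chars, artists: LongTensors of shape (seq_len, batch); the artist
        # label is simply repeated at every timestep of the sequence.
        combined = torch.cat((self.char_encoder(chars),
                              self.artist_encoder(artists)), dim=2)
        outputs, hidden = self.lstm(combined, hidden)
        return self.logits_fc(outputs), hidden
```

During sampling, the same artist label is fed at every step, which is what lets us ask for lyrics "in the style of" a particular artist.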
Look [here for more](https://github.com/spro/practical-pytorch/blob/master/conditional-char-rnn/conditional-char-rnn.ipynb). 228 | 229 | ### Sampling from conditional language model RNN 230 | 231 | After training, we sampled a couple of songs conditioned on artist. 232 | Below you can find some results. 233 | 234 | HIM: 235 | 236 | ``` 237 | My fears 238 | And the moment don't make me sing 239 | So free from you 240 | The pain you love me yeah 241 | 242 | Whatever caused the warmth 243 | You smile you're happy 244 | You sit away 245 | You say it's all in vain 246 | ``` 247 | 248 | Seems really plausible, especially the fact that the word "pain" was used, which is 249 | very common in the lyrics of the artist. 250 | 251 | ABBA: 252 | 253 | ``` 254 | Oh, my love it makes me close a thing 255 | You've been heard, I must have waited 256 | I hear you 257 | So I say 258 | Thank you for the music, that makes me cry 259 | 260 | And you moving my bad as me, ah-hang wind in the hell 261 | I was meant to be with you, I'll never be playing up 262 | ``` 263 | 264 | Bob Marley: 265 | 266 | ``` 267 | Mercy on judgment, we got so much 268 | 269 | Alcohol, cry, cry, cry 270 | Why don't try to find our own 271 | I want to know, Lord, I wanna give you 272 | Just saving it, learned 273 | Is there any more? 274 | 275 | All that damage done 276 | That's all reason, don't worry 277 | Need a hammer 278 | I need you more and more 279 | ``` 280 | 281 | Coldplay: 282 | 283 | ``` 284 | Look at the stars 285 | Into life matter where you lay 286 | Saying no doubt 287 | I don't want to fly 288 | In my dreams and fight today 289 | 290 | I will fall for you 291 | 292 | All I know 293 | And I want you to stay 294 | Into the night 295 | 296 | I want to live waiting 297 | With my love and always 298 | Have I wouldn't wasted 299 | Would it hurt you 300 | ``` 301 | 302 | Kanye West: 303 | 304 | ``` 305 | I'm everywhere for you 306 | The way that it couldn't stop 307 | I mean it too late and love I made in the world 308 | I told you so I took the studs full cold-stop 309 | The hardest stressed growin' 310 | The hustler raisin' on my tears 311 | I know I'm true, one of your love 312 | ``` 313 | 314 | Looks pretty cool, but keep in mind that we didn't track the validation accuracy, so some sampled lines could have been just memorized by our RNN. A better way to do it is to pick the model that gives the best validation score during training (see the code for the next section, where we performed training this way). We also noticed one interesting thing: the unconditional 315 | model usually performs better when you want to sample with a specified starting string. 316 | Our intuition is that when sampling from a conditional model with a specified starting string, 317 | we actually put two conditions on our model -- the starting string and the artist, compared to the one condition 318 | in the case of the previous model that we explored. And we didn't have enough data to model that conditional 319 | distribution well (every artist has a relatively limited number of songs). 320 | 321 | We are making the code and models available, and you can sample songs from our trained models 322 | even without a GPU, as it is not really computationally demanding. 323 | 324 | ### MIDI dataset 325 | 326 | Next, we will work with a [small MIDI dataset](http://www-etud.iro.umontreal.ca/~boulanni/icml2012) consisting 327 | of approximately $700$ piano songs. We have used the ```Nottingham``` piano dataset (training split only). 
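As a concrete preview of the representation explained next, here is a minimal sketch of how a single song could be loaded and turned into a piano-roll matrix. It assumes the `pretty_midi` package is installed, and the file path is purely illustrative:

```python
import numpy as np
import pretty_midi

# The path is illustrative -- point it at any MIDI file from the dataset.
midi = pretty_midi.PrettyMIDI('nottingham/example_song.mid')

# get_piano_roll() returns a (128, time_steps) matrix over all MIDI pitches;
# fs controls how many time slices (columns) we get per second.
piano_roll = midi.get_piano_roll(fs=25)

# Keep only the 88 piano keys (MIDI notes 21..108) and binarize the values:
# 1 = the key is pressed in that time slice, 0 = it is not.
piano_roll = (piano_roll[21:109, :] > 0).astype(np.float32)
print(piano_roll.shape)  # (88, song_length)
```

Each column of this matrix is one time slice, and each row tells whether a particular piano key is pressed during it -- exactly the matrix described below.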
328 | 329 | It turns out that any MIDI file can be [converted to a piano roll](http://nbviewer.jupyter.org/github/craffel/pretty-midi/blob/master/Tutorial.ipynb), which is just a time-frequency matrix where each row is a different MIDI pitch and each column is a different slice in time. So each piano song from our dataset will be represented as a matrix of size $88\times song\_length$, where $88$ is the number of pitches of the piano. Here is an example of a 330 | piano roll matrix: 331 | 332 |  333 | 334 | This representation is very intuitive and easy to interpret even for a person that is not familiar 335 | with music theory. Each row represents a pitch: top rows represent low-frequency pitches and bottom 336 | rows represent high pitches. Plus, we have a horizontal axis which represents time. So if we play a sound 337 | with a certain pitch for a certain period of time, we will see a horizontal line. Overall, this is very 338 | similar to piano tutorials on YouTube. 339 | 340 | Now, let's try to see the similarities between the character-level model and our new task. In the current case, we will have to predict the pitches that will be played on the next timestep, given all the previously played 341 | pitches. So, if you look at the picture of the piano roll, each column represents some kind of a musical character, and given all the previous musical characters, we want to predict the next one. Let's pay attention to the difference between a text character and a musical character. If you recall, each character in our language model was represented by a one-hot vector (meaning that only one value in our vector is $1$ and the others are $0$). 342 | For a musical character, multiple keys can be pressed at one timestep (since we are working with a polyphonic dataset). 343 | In this case, each timestep will be represented by a vector which can contain more than one $1$. 344 | 345 | 346 | 347 | ### Training pitch-level piano music model 348 | 349 | Before starting the training, we will have to adjust the loss that we used for the language model 350 | to account for the different input that we discussed in the previous section. In the language model, 351 | we had a one-hot encoded tensor (character) as an input on each timestep and a one-hot encoded tensor as output (the predicted next character). As we had to make a single exclusive choice for the predicted next character, we used 352 | [cross-entropy loss](https://rdipietro.github.io/friendly-intro-to-cross-entropy-loss/). 353 | 354 | But now our model outputs a vector which is no longer one-hot encoded (multiple keys can be pressed). Of course, we could treat every possible combination of pressed keys as a separate class, but this is intractable. Instead, we will treat each element of the output vector as a binary variable ($1$ -- pressing, $0$ -- not pressing a key). We will define a separate loss 355 | for each element of the output vector to be binary cross-entropy. And our final loss will be the averaged sum of these binary cross-entropies. You can also read the code to get a better understanding. 356 | 357 | After making the aforementioned changes, we trained our model. In the next section, we will perform sampling 358 | and inspect the results. 359 | 360 | ### Sampling from pitch-level RNN 361 | 362 | We sampled piano rolls during the early stages of optimization: 363 | 364 |  365 | 366 | You can see that our model is starting to learn one pattern that is common among the songs from 367 | our dataset: each song consists of two different parts. 
The first part contains a sequence of pitches that are played separately and are very [distinguishable and often singable](https://www.didjshop.com/BasicMusicalHarmony.html) (also known as the melody). If you look at the sampled piano roll, this part can be clearly seen at the bottom. If you also have a look at the top of our piano roll, you can see a group of pitches that are usually played together -- this is the harmony, or a progression of chords (pitches that are played together throughout the song), which accompanies the melody. 368 | 369 | By the end of training, samples drawn from our model started to look like this: 370 | 371 |  372 | 373 | As you can see, they started to look more similar to the picture of the ground-truth piano roll that we showed 374 | in the previous sections. 375 | 376 | After training, we sampled songs and analyzed them. We got one sample with [an interesting introduction](https://www.youtube.com/watch?v=Iz8xQou2OqA), while another sample features [a nice style transition](https://www.youtube.com/watch?v=fUdsWVIOeeU&feature=youtu.be&t=15s). At the same time, we generated a couple of examples with a low temperature parameter, which resulted in songs with a slow tempo: [first one](https://www.youtube.com/watch?v=UoLyeauBsNk) and a [second one here](https://www.youtube.com/watch?v=Iz8xQou2OqA). 377 | You can find the whole playlist [here](https://www.youtube.com/watch?v=EOQQOQYvGnw&list=PLJkMX36nfYD000TG-T59hmEgJ3ojkOlBp). 378 | 379 | 380 | ### Sequence length and related problems 381 | 382 | Now let's look at our problem from the GPU memory consumption and speed point of view. 383 | 384 | We greatly speed up computation by processing our sequences in batches. At the same time, as 385 | our sequences become longer (depending on the dataset), our maximum batch size starts to decrease. 386 | Why is that the case? As we use backpropagation to compute gradients, we need to store all the intermediate activations, which contribute the most to the memory consumption. As our sequences become longer, we need to store more activations and, therefore, can fit fewer examples in a batch. 387 | 388 | Sometimes we either have to work with really long sequences, or we want to increase our batch size, or maybe we just have a GPU with a small amount of memory available. There are multiple possible solutions to reduce memory 389 | consumption in this case, but we will mention two, which have different trade-offs. 390 | 391 | The first one is [truncated backpropagation through time](https://www.quora.com/Whats-the-key-difference-between-backprop-and-truncated-backprop-through-time#sLRGO). The idea is to split the whole sequence into subsequences and treat 392 | them as separate batches, with the exception that we process these batches in the order of the split, and every next batch uses the hidden state of the previous batch as its initial hidden state. We also provide an implementation of this approach, so that you can get a better understanding. This approach is obviously not an exact equivalent of processing the whole sequence, but it makes more frequent updates and consumes less memory. On the other hand, there is a chance that we might not be able to capture long-term dependencies that span beyond the length of one subsequence. 393 | 394 | The second one is [gradient checkpointing](https://medium.com/@yaroslavvb/fitting-larger-networks-into-memory-583e3c758ff9). This method gives us the possibility to use less memory while training our model on the whole sequence, at the expense of performing more computation. 
If you recall, previously we mentioned that the most memory during 395 | training is occupied by activations. The idea of gradient checkpointing consists of storing only every $n$-th activation and recomputing the unsaved activations later. This method is already [implemented in Tensorflow](https://github.com/openai/gradient-checkpointing) and [being implemented in Pytorch](https://github.com/pytorch/pytorch/pull/4594). 396 | 397 | 398 | 399 | ### Conclusion and future work 400 | 401 | In our work we trained simple generative model for text, extended our model to work with 402 | polyphonic music, briefly looked at how sampling works and how the temperature parameter affects our 403 | text and music samples -- low temperature gives more stable results while high temperature adds more 404 | randomness which sometimes gives rise to very interesting samples. 405 | 406 | Future work can include two directions -- more applications or deeper analysis of the already trained models. 407 | Same models can be applied to your spotify listening history, for example. After training on your 408 | listening history data, you can give it a sequence of songs that you have listened to in the previous hour or so, and it will sample a playlist for you for the rest of the day. Well, you can also do the same for your browsing history, which will be just a cool tool to analyze your browsing behaviour patterns. [Capture the accelerometer and gyroscope data](https://www.kaggle.com/uciml/human-activity-recognition-with-smartphones) from your phone while doing different activities (exercising in the gym, working in the office, sleeping) and learn to classify these 409 | activity stages. After that you can change your music playlist automatically, based on your activity (sleeping -- calm music of rain, exercising in the gym -- high intensity music). In terms of medical applications, model can 410 | be applied to detect heart problems based on pulse and other data, similar to [this work](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5391725/). 411 | 412 | It would be very interesting to analyze the neuron firings in our RNN trained 413 | for music generation like [here](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). To see if the model learned some simple music concepts implicitly (like our discussion 414 | of harmony and melody). The hidden representation of RNN can be used to cluster our music dataset to 415 | find similar songs. 416 | 417 | Let's sample one last lyrics from our unconditional model to conclude this post :D : 418 | 419 | ``` 420 | The story ends 421 | The sound of the blue 422 | The tears were shining 423 | The story of my life 424 | I still believe 425 | The story of my life 426 | ``` 427 | -------------------------------------------------------------------------------- /blog_post.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Learning to generate lyrics and music with Recurrent Neural Networks" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "In this post we will train RNN character-level language model on lyrics dataset of\n", 15 | "most popular/recent artists. 
Having a trained model, we will sample a couple of\n", 16 | "songs which will be a funny mixture of different styles of different artists.\n", 17 | "After that we will update our model to become a conditional character-level RNN,\n", 18 | "making it possible for us to sample songs conditioned on artist.\n", 19 | "And finally, we conclude by training our model on midi dataset of piano songs.\n", 20 | "While solving all these tasks, we will briefly explore some interesting concepts related to RNN\n", 21 | "training and inference like character-level RNN, conditional character-level RNN,\n", 22 | "sampling from RNN, truncated backpropagation through time and gradient checkpointing.\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "### Character-Level language model" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "Before choosing a model, let's have a closer look at our task. Given current letter and all previous\n", 44 | "letters, we will try to predict the next character. During training we will just take a sequence, and use\n", 45 | "all its characters except the last one as an input and the same sequence starting from the second character as groundtruth (see the picture above; [Source](https://github.com/spro/practical-pytorch/blob/master/conditional-char-rnn/conditional-char-rnn.ipynb)). We will start from the simplest model that ignores all the previous characters while making a prediction, improve this model to make it take only a certain number of previous characters into account, and conclude with a model that takes all the previous characters into consideration while making a prediction.\n", 46 | "\n", 47 | "Our language model is defined on a character level. We will create a dictionary which will contain\n", 48 | "all English characters plus some special symbols, like period, comma, and end-of-line symbol. Each charecter will be represented as one-hot-encoded tensor. For more information about character-level models and examples, I recommend [this resource](https://github.com/spro/practical-pytorch).\n", 49 | "\n", 50 | "Having characters, we can now form sequences of characters. We can generate sentences even now just by\n", 51 | "randomly sampling character after character with a fixed probability $p(any~letter)=\\frac{1}{dictionary~size}$.\n", 52 | "That's the most simple character level language model. Can we do better than this? Yes, we can compute the probabily of occurance of each letter from our training corpus (number of times a letter occures divided by the size of our dataset) and randomly sample letter using these probabilities. This model is better but it totally ignores the relative positional aspect of each letter. For example, pay attention on how you read any word: you start with the first letter, which is usually hard to predict, but as you reach the end of a word you can sometimes guess the next letter. When you read any word you are implicitly using some rules which you learned by reading other texts: for example, with each additional letter that you read from a word, the probability of a space character increases (really long words are rare) or the probability of any consonant after the letter \"r\" is low as it usually followed by vowel. There are lot of similar rules and we hope that our model will be able to learn them from data. 
To give our model a chance to learn these rules we need to extend it.\n", 53 | "\n", 54 | "Let's make a small gradual improvement of our model and let probability of each letter depend\n", 55 | "only on the previously occured letter ([markov assumption](https://en.wikipedia.org/wiki/Markov_property)). So, basically we will have $p(current~letter|previous~letter)$.\n", 56 | "This is a [Markov chain model](https://en.wikipedia.org/wiki/Markov_chain) (also try these [interactive visualizations](http://setosa.io/ev/markov-chains/) if you are not familiar with it). We can also estimate the probability distribution $p(current~letter|previous~letter)$ from our training dataset. This model is limited because in most cases the probability of the current letter depends not only on the previous letter.\n", 57 | "\n", 58 | "What we would like to model is actually $p(current~letter|all~previous~letters)$. At first, the task seems intractable as the number of previous letters is variable and it might become really large in case of long\n", 59 | "sequences. Turns out Reccurent Neural Netoworks can tackle this problem to a certain extent by using shared weights and fixed size hidden state. This leads us to a next section dedicated to RNNs." 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "### Recurrent Neural Networks" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "Recurrent neural networks are a family of neural networks for processing sequential data.\n", 81 | "Unlike feedforward neural networks, RNNs can use their internal memory to process arbitrary sequences of inputs.\n", 82 | "Because of arbitrary size input sequences, they are concisely depicted as a graph with a cycle (see the picture; [Source](http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns/)).\n", 83 | "But they can be \"unfolded\" if the size of input sequence is known. They define a non-linear mapping from a current input $x_t$ and previous hidden state $s_{t-1}$ to the output $o_t$ and current hidden state $s_t$. Hidden state size has a predefined size and stores features which are updated on each step and affect the result of mapping.\n", 84 | "\n", 85 | "Now align the previous picture of the character-level language model and the ufolded RNN picture to see how\n", 86 | "we are using the RNN model to learn a character level language model.\n", 87 | "\n", 88 | "While the picture depicts the Vanilla RNN, we will use LSTM in our work as it is easier to train usually achieves better results.\n", 89 | "\n", 90 | "For a more elaborate introduction to RNNs, we refer reader to the [following resource](http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns/)." 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "### Lyrics dataset" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "For our experiments we have chosen [55000+ Song Lyrics Kaggle dataset](https://www.kaggle.com/mousehead/songlyrics) which contains good variety of recent artists and more older ones. It is stored as a pandas file and we wrote a python wrapper around it to be able to use it for training purposes. 
You will have to download it yourself in order to be able to use our code.\n", 105 | "\n", 106 | "In order to be able to interpret the results better, I have chosen a subset of artists which I am\n", 107 | "more or less familiar with:" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 3, 113 | "metadata": { 114 | "collapsed": true 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "artists = [\n", 119 | "'ABBA',\n", 120 | "'Ace Of Base',\n", 121 | "'Aerosmith',\n", 122 | "'Avril Lavigne',\n", 123 | "'Backstreet Boys',\n", 124 | "'Bob Marley',\n", 125 | "'Bon Jovi',\n", 126 | "'Britney Spears',\n", 127 | "'Bruno Mars',\n", 128 | "'Coldplay',\n", 129 | "'Def Leppard',\n", 130 | "'Depeche Mode',\n", 131 | "'Ed Sheeran',\n", 132 | "'Elton John',\n", 133 | "'Elvis Presley',\n", 134 | "'Eminem',\n", 135 | "'Enrique Iglesias',\n", 136 | "'Evanescence',\n", 137 | "'Fall Out Boy',\n", 138 | "'Foo Fighters',\n", 139 | "'Green Day',\n", 140 | " 'HIM',\n", 141 | " 'Imagine Dragons',\n", 142 | " 'Incubus',\n", 143 | " 'Jimi Hendrix',\n", 144 | " 'Justin Bieber',\n", 145 | " 'Justin Timberlake',\n", 146 | "'Kanye West',\n", 147 | " 'Katy Perry',\n", 148 | " 'The Killers',\n", 149 | " 'Kiss',\n", 150 | " 'Lady Gaga',\n", 151 | " 'Lana Del Rey',\n", 152 | " 'Linkin Park',\n", 153 | " 'Madonna',\n", 154 | " 'Marilyn Manson',\n", 155 | " 'Maroon 5',\n", 156 | " 'Metallica',\n", 157 | " 'Michael Bolton',\n", 158 | " 'Michael Jackson',\n", 159 | " 'Miley Cyrus',\n", 160 | " 'Nickelback',\n", 161 | " 'Nightwish',\n", 162 | " 'Nirvana',\n", 163 | " 'Oasis',\n", 164 | " 'Offspring',\n", 165 | " 'One Direction',\n", 166 | " 'Ozzy Osbourne',\n", 167 | " 'P!nk',\n", 168 | " 'Queen',\n", 169 | " 'Radiohead',\n", 170 | " 'Red Hot Chili Peppers',\n", 171 | " 'Rihanna',\n", 172 | " 'Robbie Williams',\n", 173 | " 'Rolling Stones',\n", 174 | " 'Roxette',\n", 175 | " 'Scorpions',\n", 176 | " 'Snoop Dogg',\n", 177 | " 'Sting',\n", 178 | " 'The Script',\n", 179 | " 'U2',\n", 180 | " 'Weezer',\n", 181 | " 'Yellowcard',\n", 182 | " 'ZZ Top']" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": {}, 188 | "source": [ 189 | "### Training unconditional character-level language model" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "Our first experiment consisted of training of our character-level language model RNN\n", 197 | "on the whole corpus. We didn't take into consideration the artist information while training." 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "### Sampling from RNN" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "Let's try to sample a couple of songs after training our model. Basically, on each\n", 212 | "step our RNN will output logits and we can softmax them and sample from that distribution.\n", 213 | "Or we can use Gumble-Max trick and [sample using logits directly](https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/) which is equivalent.\n", 214 | "\n", 215 | "One intersting thing about sampling is that we can partially define the input sequence ourselves and start sampling\n", 216 | "with that initial condition. For example, we can sample a song that starts with \"Why\":\n", 217 | "\n", 218 | "```\n", 219 | "Why do you have to leave me? 
\n", 220 | "I think I know I'm not the only one \n", 221 | "I don't know if I'm gonna stay awake \n", 222 | "I don't know why I go along \n", 223 | " \n", 224 | "I don't know why I can't go on \n", 225 | "I don't know why I don't know \n", 226 | "I don't know why I don't know \n", 227 | "I don't know why I keep on dreaming of you \n", 228 | "```\n", 229 | "\n", 230 | "Well, that sounds like a possible song :D\n", 231 | "\n", 232 | "Let's sample with a song that starts with \"Well\":\n", 233 | "\n", 234 | "```\n", 235 | "Well, I was a real good time \n", 236 | "I was a rolling stone \n", 237 | "I was a rock and roller \n", 238 | "Well, I never had a rock and roll \n", 239 | "There were times I had to do it \n", 240 | "I had a feeling that I was found \n", 241 | "I was the one who had to go \n", 242 | "```\n", 243 | "\n", 244 | "There is \"temperature\" parameter that is used during sampling which controls the randomness of sampling\n", 245 | "process. When this parameter approaches zero,\n", 246 | "the sampling is equivalent to argmax and when it is close to infinity the sampling is equivalent to sampling\n", 247 | "from a uniform distribution. Have a look at the figure from a [relevant paper by Jang et al.](https://arxiv.org/pdf/1611.01144.pdf):\n", 248 | "\n", 249 | "\n", 250 | "\n", 251 | "\n", 252 | "When $\\tau=1$, the distribution is not affected. If we decrease $\\tau$, the distribution\n", 253 | "becomes more pronounced, meaning that value with bigger probability mass will have it increased. When $\\tau$ will approach zero, sampling will be equivalent to armax, because the probability of that value will be close to one. When we start to icrease $\\tau$ the distribution becomes more and more uniform.\n", 254 | "\n", 255 | "The previous sample was generated with a temperature paramter equal to $0.5$.\n", 256 | "Let's see what happens when we increase it to $1.0$ and sample:\n", 257 | "\n", 258 | "```\n", 259 | "Why can't we drop out of time? \n", 260 | "We were born for words to see. \n", 261 | "Won't you love this. You're still so amazing. \n", 262 | "This could be that down on Sunday Time. \n", 263 | "Oh, Caroline, a lady floor. \n", 264 | "I thought of love, oh baby. \n", 265 | "```\n", 266 | "\n", 267 | "Let's try increasing it even more:\n", 268 | "\n", 269 | "\n", 270 | "```\n", 271 | "Why - won't we grow up naked? \n", 272 | "We went quietly what we would've still give \n", 273 | "That girl you walked before our bedroom room \n", 274 | "I see your mind is so small to a freak \n", 275 | "Stretching for a cold white-heart of crashing \n", 276 | "Truth in the universal daughter \n", 277 | " \n", 278 | "I lose more and more hard \n", 279 | "I love you anytime at all \n", 280 | "Ah come let your help remind me \n", 281 | "Now I've wanted waste and never noticed \n", 282 | " \n", 283 | "I swear I saw you today \n", 284 | "You needed to get by \n", 285 | "But you sold a hurricane \n", 286 | "Well out whispered in store\n", 287 | "```\n", 288 | "\n", 289 | "Why don't we grow up naked, indeed? :D\n", 290 | "Well, you can see that trend that when we increase the temperature, sampled\n", 291 | "sentences become more and more random." 
292 |    ] 293 |   }, 294 |   { 295 |    "cell_type": "markdown", 296 |    "metadata": {}, 297 |    "source": [ 298 |     "### Training conditional character-level language model" 299 |    ] 300 |   }, 301 |   { 302 |    "cell_type": "markdown", 303 |    "metadata": {}, 304 |    "source": [ 305 |     "Imagine if we could generate lyrics in the style of some particular artist.\n", 306 |     "Let's change our model so that it can use this information during training.\n", 307 |     "\n", 308 |     "We will do this by adding an additional input to our RNN. So far, our RNN model\n", 309 |     "was only accepting tensors containing a one-hot encoded character on each step.\n", 310 |     "\n", 311 |     "The extension to our model will be very simple: we will have an additional one-hot encoded\n", 312 |     "tensor which will represent the artist. So on each step the RNN will accept one tensor which will consist of the concatenated tensors representing the character and the artist. Look [here for more](https://github.com/spro/practical-pytorch/blob/master/conditional-char-rnn/conditional-char-rnn.ipynb)." 313 |    ] 314 |   }, 315 |   { 316 |    "cell_type": "markdown", 317 |    "metadata": {}, 318 |    "source": [ 319 |     "### Sampling from conditional language model RNN" 320 |    ] 321 |   }, 322 |   { 323 |    "cell_type": "markdown", 324 |    "metadata": {}, 325 |    "source": [ 326 |     "After training, we sampled a couple of songs conditioned on the artist.\n", 327 |     "Below you can find some results.\n", 328 |     "\n", 329 |     "HIM:\n", 330 |     "\n", 331 |     "```\n", 332 |     "My fears \n", 333 |     "And the moment don't make me sing \n", 334 |     "So free from you \n", 335 |     "The pain you love me yeah \n", 336 |     " \n", 337 |     "Whatever caused the warmth \n", 338 |     "You smile you're happy \n", 339 |     "You sit away \n", 340 |     "You say it's all in vain \n", 341 |     "```\n", 342 |     "\n", 343 |     "Seems really plausible, especially the fact that the word pain was used, which is\n", 344 |     "very common in the lyrics of this artist.\n", 345 |     "\n", 346 |     "ABBA:\n", 347 |     "\n", 348 |     "```\n", 349 |     "Oh, my love it makes me close a thing \n", 350 |     "You've been heard, I must have waited \n", 351 |     "I hear you \n", 352 |     "So I say \n", 353 |     "Thank you for the music, that makes me cry \n", 354 |     " \n", 355 |     "And you moving my bad as me, ah-hang wind in the hell \n", 356 |     "I was meant to be with you, I'll never be playing up\n", 357 |     "```\n", 358 |     "\n", 359 |     "Bob Marley:\n", 360 |     "\n", 361 |     "```\n", 362 |     "Mercy on judgment, we got so much \n", 363 |     " \n", 364 |     "Alcohol, cry, cry, cry \n", 365 |     "Why don't try to find our own \n", 366 |     "I want to know, Lord, I wanna give you \n", 367 |     "Just saving it, learned \n", 368 |     "Is there any more?
\n", 369 | " \n", 370 | "All that damage done \n", 371 | "That's all reason, don't worry \n", 372 | "Need a hammer \n", 373 | "I need you more and more \n", 374 | "```\n", 375 | "\n", 376 | "Coldplay:\n", 377 | "\n", 378 | "```\n", 379 | "Look at the stars \n", 380 | "Into life matter where you lay \n", 381 | "Saying no doubt \n", 382 | "I don't want to fly \n", 383 | "In my dreams and fight today\n", 384 | "\n", 385 | "I will fall for you \n", 386 | " \n", 387 | "All I know \n", 388 | "And I want you to stay \n", 389 | "Into the night \n", 390 | " \n", 391 | "I want to live waiting \n", 392 | "With my love and always \n", 393 | "Have I wouldn't wasted \n", 394 | "Would it hurt you\n", 395 | "```\n", 396 | "\n", 397 | "Kanye West:\n", 398 | "\n", 399 | "```\n", 400 | "I'm everywhere for you \n", 401 | "The way that it couldn't stop \n", 402 | "I mean it too late and love I made in the world \n", 403 | "I told you so I took the studs full cold-stop \n", 404 | "The hardest stressed growin' \n", 405 | "The hustler raisin' on my tears \n", 406 | "I know I'm true, one of your love\n", 407 | "```\n", 408 | "\n", 409 | "Looks pretty cool but keep in mind that we didn't track the validation accuracy so some sampled lines could have been just memorized by our rnn. A better way to do it is to pick a model that gives best validation score during training (see the code for the next section where we performed training this way). We also noticed one interesting thing: the unconditional\n", 410 | "model usually performes better when you want to sample with a specified starting string.\n", 411 | "Our intuition is that when sampling from a conditional model with a specified starting string, \n", 412 | "we actually put two conditions on our model -- starting string and an artist compared to the one condition\n", 413 | "in the case of previous model that we explored. And we didn't have enough data to model that conditional\n", 414 | "distribution well (every artist has relatively limited number of songs).\n", 415 | "\n", 416 | "We are making the code and models available and you can sample songs from our trained models\n", 417 | "even without gpu as it is not really computationally demanding." 418 | ] 419 | }, 420 | { 421 | "cell_type": "markdown", 422 | "metadata": {}, 423 | "source": [ 424 | "### Midi dataset" 425 | ] 426 | }, 427 | { 428 | "cell_type": "markdown", 429 | "metadata": {}, 430 | "source": [ 431 | "Next, we will work with a [small midi dataset](http://www-etud.iro.umontreal.ca/~boulanni/icml2012) consisting\n", 432 | "from approximately $700$ piano songs. We have used the ```Nottingam``` piano dataset (training split only).\n", 433 | "\n", 434 | "Turns out that any midi file can be [converted to piano roll](http://nbviewer.jupyter.org/github/craffel/pretty-midi/blob/master/Tutorial.ipynb) which is just is a time-frequency matrix where each row is a different MIDI pitch and each column is a different slice in time. So each piano song from our dataset will be represented as a matrix of size $88\\times song\\_length$, where $88$ is a number of pitches of the piano. Here is an example of\n", 435 | "piano roll matrix:\n", 436 | "\n", 437 | "\n", 438 | "\n", 439 | "This representation is very intuitive and easy to interpret even for a person that is not familiar\n", 440 | "with music theory. Each row represents a pitch: top rows represent low frequency pitches and bottom\n", 441 | "rows represent high pitches. Plus, we have a horizontal axis which represents time. 
So if we play a sound\n", 442 |     "with a certain pitch for a certain period of time, we will see a horizontal line. Overall, this is very\n", 443 |     "similar to piano tutorials on youtube.\n", 444 |     "\n", 445 |     "Now, let's try to see the similarities between the character-level model and our new task. In the current case, we will have to predict the pitches that will be played on the next timestep, given all the previously played\n", 446 |     "pitches. So, if you look at the picture of the piano roll, each column represents some kind of a musical character and, given all the previous musical characters, we want to predict the next one. Let's pay attention to the difference between a text character and a musical character. If you recall, each character in our language model was represented by a one-hot vector (meaning that only one value in our vector is $1$ and the others are $0$).\n", 447 |     "For a musical character, multiple keys can be pressed at one timestep (since we are working with a polyphonic dataset).\n", 448 |     "In this case, each timestep will be represented by a vector which can contain more than one $1$.\n", 449 |     "\n" 450 |    ] 451 |   }, 452 |   { 453 |    "cell_type": "markdown", 454 |    "metadata": {}, 455 |    "source": [ 456 |     "### Training pitch-level piano music model" 457 |    ] 458 |   }, 459 |   { 460 |    "cell_type": "markdown", 461 |    "metadata": {}, 462 |    "source": [ 463 |     "Before starting the training, we will have to adjust the loss that we used for the language model\n", 464 |     "to account for the different input that we discussed in the previous section. In the language model,\n", 465 |     "we had a one-hot encoded tensor (character) as an input on each timestep and a one-hot encoded tensor as output (the predicted next character). As we had to make a single exclusive choice for the predicted next character, we used\n", 466 |     "[cross-entropy loss](https://rdipietro.github.io/friendly-intro-to-cross-entropy-loss/).\n", 467 |     "\n", 468 |     "But now our model outputs a vector which is no longer one-hot encoded (multiple keys can be pressed). Of course, we could treat all possible combinations of pressed keys as separate classes, but this is intractable. Instead, we will treat each element of the output vector as a binary variable ($1$ -- pressing, $0$ -- not pressing a key). We will define a separate loss\n", 469 |     "for each element of the output vector to be binary cross-entropy. And our final loss will be the averaged sum of these binary cross-entropies. You can also read the code (and the short sketch further below) to get a better understanding.\n", 470 |     "\n", 471 |     "After making the aforementioned changes, we trained our model. In the next section, we will perform sampling\n", 472 |     "and inspect the results." 473 |    ] 474 |   }, 475 |   { 476 |    "cell_type": "markdown", 477 |    "metadata": {}, 478 |    "source": [ 479 |     "### Sampling from pitch-level RNN" 480 |    ] 481 |   }, 482 |   { 483 |    "cell_type": "markdown", 484 |    "metadata": {}, 485 |    "source": [ 486 |     "We have sampled piano rolls during the early stages of optimization:\n", 487 |     "\n", 488 |     "![Piano roll sampled early in training](imgs/piano_roll_early_sample.png)\n", 489 |     "\n", 490 |     "You can see that our model is starting to learn one pattern that is common among the songs in\n", 491 |     "our dataset: each song consists of two different parts. The first part contains a sequence of pitches that are played separately and are very [distinguishable and often singable](https://www.didjshop.com/BasicMusicalHarmony.html) (also known as the melody). If you look at the sampled piano roll, this part can be clearly seen at the bottom.
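Coming back to the loss for a moment, here is a minimal sketch of the averaged per-pitch binary cross-entropy described above. The tensor shapes are illustrative; our notebooks implement the same idea slightly differently (by stacking the logits with their complements and using a cross-entropy formulation, which makes masking of padded timesteps easier).

```python
import torch
import torch.nn as nn

# Suppose the RNN produced logits for every timestep and every pitch:
sequence_length, number_of_pitches = 100, 128
logits = torch.randn(sequence_length, number_of_pitches)

# The ground truth is a multi-hot piano roll slice of the same shape:
# 1 -- key pressed at that timestep, 0 -- key not pressed.
targets = (torch.rand(sequence_length, number_of_pitches) > 0.9).float()

# BCEWithLogitsLoss applies a sigmoid to every element independently and
# averages the per-element binary cross-entropies.
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, targets)
print(loss.item())
```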
Returning to the sampled piano roll: if you also have a look at its top part, you can see a group of pitches that are usually played together -- this is the harmony, or a progression of chords (pitches that are played together throughout the song), which accompanies the melody.\n", 492 |     "\n", 493 |     "By the end of the training, samples drawn from our model started to look like this:\n", 494 |     "\n", 495 |     "![Piano roll sampled late in training](imgs/piano_roll_late_sample.png)\n", 496 |     "\n", 497 |     "As you can see, they started to look more similar to the picture of the ground-truth piano roll that we showed\n", 498 |     "in the previous sections.\n", 499 |     "\n", 500 |     "After training, we sampled songs and analyzed them. We got one sample with [an interesting introduction](https://www.youtube.com/watch?v=Iz8xQou2OqA), while another sample features [a nice style transition](https://www.youtube.com/watch?v=fUdsWVIOeeU&feature=youtu.be&t=15s). At the same time we generated a couple of examples with a low temperature parameter, which resulted in songs with a slow tempo: [first one](https://www.youtube.com/watch?v=UoLyeauBsNk) and a [second one here](https://www.youtube.com/watch?v=Iz8xQou2OqA).\n", 501 |     "You can find the whole playlist [here](https://www.youtube.com/watch?v=EOQQOQYvGnw&list=PLJkMX36nfYD000TG-T59hmEgJ3ojkOlBp).\n" 502 |    ] 503 |   }, 504 |   { 505 |    "cell_type": "markdown", 506 |    "metadata": {}, 507 |    "source": [ 508 |     "### Sequence length and related problems" 509 |    ] 510 |   }, 511 |   { 512 |    "cell_type": "markdown", 513 |    "metadata": {}, 514 |    "source": [ 515 |     "Now let's look at our problem from the GPU memory consumption and speed point of view.\n", 516 |     "\n", 517 |     "We greatly speed up computation by processing our sequences in batches. At the same time, as\n", 518 |     "our sequences become longer (depending on the dataset), our maximum batch size starts to decrease.\n", 519 |     "Why is this the case? As we use backpropagation to compute gradients, we need to store all the intermediate activations, which contribute the most to the memory consumption. As our sequence becomes longer, we need to store more activations and, therefore, we can fit fewer examples in our batch.\n", 520 |     "\n", 521 |     "Sometimes we either have to work with really long sequences, or we want to increase our batch size, or maybe we just have a GPU with a small amount of memory available. There are multiple possible solutions to reduce memory\n", 522 |     "consumption in this case, but we will mention two, which have different trade-offs.\n", 523 |     "\n", 524 |     "The first one is [truncated backpropagation](https://www.quora.com/Whats-the-key-difference-between-backprop-and-truncated-backprop-through-time#sLRGO). The idea is to split the whole sequence into subsequences and treat\n", 525 |     "them as separate batches, with the exception that we process these batches in the order of the split and every next batch uses the hidden state of the previous batch as its initial hidden state. We also provide an implementation of this approach, so that you can get a better understanding (a small sketch of the idea is shown just below). This approach is obviously not an exact equivalent of processing the whole sequence, but it makes more frequent updates and consumes less memory. On the other hand, there is a chance that we might not be able to capture long-term dependencies that span beyond the length of one subsequence.\n", 526 |     "\n", 527 |     "The second one is [gradient checkpointing](https://medium.com/@yaroslavvb/fitting-larger-networks-into-memory-583e3c758ff9).
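Before turning to that second option, here is a rough sketch of the truncated backpropagation idea from the previous paragraph. The model/criterion interfaces and tensor shapes are illustrative; a complete, runnable version is provided in `notebooks/truncated_backprop_music_generation.ipynb`.

```python
def detach_hidden(hidden):
    """Keep the hidden state values but cut the backprop graph between subsequences."""
    if isinstance(hidden, tuple):                  # an LSTM returns a (h, c) tuple
        return tuple(h.detach() for h in hidden)
    return hidden.detach()

def truncated_bptt_update(model, optimizer, criterion, inputs, targets, chunk_size=200):
    # inputs / targets: tensors of shape (sequence_length, batch_size, features)
    hidden = None
    for start in range(0, inputs.size(0), chunk_size):
        input_chunk = inputs[start:start + chunk_size]
        target_chunk = targets[start:start + chunk_size]

        optimizer.zero_grad()
        output, hidden = model(input_chunk, hidden)
        loss = criterion(output, target_chunk)
        loss.backward()                            # gradients flow only within this chunk
        optimizer.step()

        # Carry the hidden state to the next chunk, but detach it so that
        # backpropagation does not reach into the previous chunks.
        hidden = detach_hidden(hidden)
```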
Gradient checkpointing gives us the possibility to use less memory while training our model on the whole sequence, at the expense of performing more computation. If you recall, we previously mentioned that most of the memory during\n", 528 |     "training is occupied by activations. The idea of gradient checkpointing consists of storing only every $n$-th activation and recomputing the unsaved activations later. This method is already [implemented in Tensorflow](https://github.com/openai/gradient-checkpointing) and [being implemented in Pytorch](https://github.com/pytorch/pytorch/pull/4594).\n", 529 |     "\n" 530 |    ] 531 |   }, 532 |   { 533 |    "cell_type": "markdown", 534 |    "metadata": {}, 535 |    "source": [ 536 |     "### Conclusion and future work" 537 |    ] 538 |   }, 539 |   { 540 |    "cell_type": "markdown", 541 |    "metadata": {}, 542 |    "source": [ 543 |     "In our work we trained a simple generative model for text, extended it to work with\n", 544 |     "polyphonic music, and briefly looked at how sampling works and how the temperature parameter affects our\n", 545 |     "text and music samples -- low temperature gives more stable results, while high temperature adds more\n", 546 |     "randomness, which sometimes gives rise to very interesting samples.\n", 547 |     "\n", 548 |     "Future work can include two directions -- more applications or a deeper analysis of the already trained models.\n", 549 |     "The same models can be applied to your Spotify listening history, for example. After training on your\n", 550 |     "listening history data, you can give the model a sequence of songs that you have listened to in the previous hour or so, and it will sample a playlist for you for the rest of the day. You can also do the same for your browsing history, which would be just a cool tool to analyze your browsing behaviour patterns. [Capture the accelerometer and gyroscope data](https://www.kaggle.com/uciml/human-activity-recognition-with-smartphones) from your phone while doing different activities (exercising in the gym, working in the office, sleeping) and learn to classify these\n", 551 |     "activity stages. After that you can change your music playlist automatically, based on your activity (sleeping -- calm rain-like music, exercising in the gym -- high-intensity music). In terms of medical applications, the model can\n", 552 |     "be applied to detect heart problems based on pulse and other data, similar to [this work](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5391725/).\n", 553 |     "\n", 554 |     "It would be very interesting to analyze the neuron firings in our RNN trained\n", 555 |     "for music generation, like [here](http://karpathy.github.io/2015/05/21/rnn-effectiveness/), to see if the model learned some simple music concepts implicitly (like our discussion\n", 556 |     "of harmony and melody).
The hidden representation of the RNN could also be used to cluster our music dataset to \n", 557 |     "find similar songs.\n", 558 |     "\n", 559 |     "Let's sample one last song from our unconditional model to conclude this post :D :\n", 560 |     "\n", 561 |     "```\n", 562 |     "The story ends \n", 563 |     "The sound of the blue \n", 564 |     "The tears were shining \n", 565 |     "The story of my life \n", 566 |     "I still believe \n", 567 |     "The story of my life \n", 568 |     "```" 569 |    ] 570 |   } 571 |  ], 572 |  "metadata": { 573 |   "kernelspec": { 574 |    "display_name": "Python 2", 575 |    "language": "python", 576 |    "name": "python2" 577 |   }, 578 |   "language_info": { 579 |    "codemirror_mode": { 580 |     "name": "ipython", 581 |     "version": 2 582 |    }, 583 |    "file_extension": ".py", 584 |    "mimetype": "text/x-python", 585 |    "name": "python", 586 |    "nbconvert_exporter": "python", 587 |    "pygments_lexer": "ipython2", 588 |    "version": "2.7.14" 589 |   } 590 |  }, 591 |  "nbformat": 4, 592 |  "nbformat_minor": 2 593 | } 594 | -------------------------------------------------------------------------------- /notebooks/truncated_backprop_music_generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 |  "cells": [ 3 |   { 4 |    "cell_type": "code", 5 |    "execution_count": 1, 6 |    "metadata": { 7 |     "collapsed": true 8 |    }, 9 |    "outputs": [], 10 |    "source": [ 11 |     "import pretty_midi\n", 12 |     "import numpy as np\n", 13 |     "import torch\n", 14 |     "import torch.nn as nn\n", 15 |     "from torch.autograd import Variable\n", 16 |     "import torch.utils.data as data\n", 17 |     "import os\n", 18 |     "import random\n", 19 |     "\n", 20 |     "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 21 |     "\n", 22 |     "def midi_filename_to_piano_roll(midi_filename):\n", 23 |     "    \"\"\"Returns a matrix which represents the piano roll of the specified midi file.\n", 24 |     "    Reads the midi file using pretty_midi, calls get_piano_roll() and binarizes\n", 25 |     "    notes (some notes that are being pressed have different volume -- we don't\n", 26 |     "    really need this information during training)\n", 27 |     "    \n", 28 |     "    Parameters\n", 29 |     "    ----------\n", 30 |     "    midi_filename : string\n", 31 |     "        Full path of the midi file\n", 32 |     "    \n", 33 |     "    Returns\n", 34 |     "    -------\n", 35 |     "    piano_roll : numpy matrix of size 128xseq_length\n", 36 |     "        Matrix that contains the piano roll\n", 37 |     "    \"\"\"\n", 38 |     "\n", 39 |     "    midi_data = pretty_midi.PrettyMIDI(midi_filename)\n", 40 |     "    \n", 41 |     "    piano_roll = midi_data.get_piano_roll()\n", 42 |     "    \n", 43 |     "    # Binarize the pressed notes\n", 44 |     "    piano_roll[piano_roll > 0] = 1\n", 45 |     "    \n", 46 |     "    return piano_roll\n", 47 |     "\n", 48 |     "\n", 49 |     "def pad_piano_roll(piano_roll, max_length=132333, pad_value=0):\n", 50 |     "    \"\"\"Pads the piano roll to have a certain sequence length: max_length with\n", 51 |     "    a special value: pad_value\n", 52 |     "    \n", 53 |     "    Parameters\n", 54 |     "    ----------\n", 55 |     "    piano_roll : numpy matrix shape (128, seq_length) \n", 56 |     "        Matrix representing the piano roll\n", 57 |     "    \n", 58 |     "    Returns\n", 59 |     "    -------\n", 60 |     "    piano_roll : numpy matrix shape (128, max_length)\n", 61 |     "        Matrix representing the padded piano roll\n", 62 |     "    \"\"\"\n", 63 |     "    \n", 64 |     "    # We hardcode 128 -- because we will always use only\n", 65 |     "    # 128 pitches\n", 66 |     "    \n", 67 |     "    original_piano_roll_length = piano_roll.shape[1]\n", 68 |     "    \n", 69 |     "    padded_piano_roll = np.zeros((128, max_length))\n", 70 |     "    padded_piano_roll[:] = pad_value\n", 71 |     "    \n", 72 |     "    padded_piano_roll[:, :original_piano_roll_length] = 
piano_roll\n", 73 | "\n", 74 | " return padded_piano_roll\n", 75 | "\n", 76 | "\n", 77 | "def piano_roll_random_crop(piano_roll, crop_length=10000):\n", 78 | "\n", 79 | " piano_roll_length = piano_roll.shape[1]\n", 80 | "\n", 81 | " max_possible_start_crop = piano_roll_length - crop_length \n", 82 | "\n", 83 | " if max_possible_start_crop < 0:\n", 84 | "\n", 85 | " return piano_roll\n", 86 | "\n", 87 | " random_start_pos = random.randrange(max_possible_start_crop)\n", 88 | "\n", 89 | " piano_roll_crop = piano_roll[:, random_start_pos:random_start_pos+crop_length]\n", 90 | "\n", 91 | " return piano_roll_crop\n", 92 | "\n", 93 | "\n", 94 | "\n", 95 | "class NotesGenerationDataset(data.Dataset):\n", 96 | " \"\"\"Dataloader for note-level (like char-rnn) rnn training\n", 97 | " Dataloader reads midi files in the provided directory and forms\n", 98 | " batches which contain original sequence without the last element,\n", 99 | " original sequence without the first element (typical note/char-level model)\n", 100 | " and length of these sequences.\n", 101 | " \"\"\"\n", 102 | " \n", 103 | " \n", 104 | " def __init__(self, midi_folder_path, longest_sequence_length=74569):\n", 105 | " \"\"\"Initializes the note dataloader.\n", 106 | "\n", 107 | " Reads all the .midi files from the specified directory.\n", 108 | " Important: set longest_sequence_length to None if you are\n", 109 | " using this loader for a different midi dataset.\n", 110 | " \n", 111 | " When being used with torch.utils.data.DataLoader, returns\n", 112 | " batch with training sequence of the size (batch_size, longest_sequence_length, 128),\n", 113 | " corresponding ground truth (batch_size, longest_sequence_length, 128) and a tensor with\n", 114 | " actual (non-padded) length of each sequence in the batch.\n", 115 | " We pad each sequence to become of the size longest_sequence_length, which might seem\n", 116 | " to be an inefficient solution -- we later trim the batch using post_process_sequence_batch()\n", 117 | " helper function -- be sure to use it after you get batches using this dataloader.\n", 118 | "\n", 119 | " Parameters\n", 120 | " ----------\n", 121 | " midi_folder_path : string\n", 122 | " String specifying the path to the midi folder\n", 123 | " longest_sequence_length : int\n", 124 | " Constant specifying the longest midi sequence in the dataset\n", 125 | " Set it to None if you don't know it (it might take some time)\n", 126 | " to initilize then.\n", 127 | " \"\"\"\n", 128 | " \n", 129 | " self.midi_folder_path = midi_folder_path\n", 130 | " \n", 131 | " midi_filenames = os.listdir(midi_folder_path)\n", 132 | " \n", 133 | " midi_full_filenames = map(lambda filename: os.path.join(midi_folder_path, filename),\n", 134 | " midi_filenames)\n", 135 | " \n", 136 | " self.midi_full_filenames = midi_full_filenames\n", 137 | " \n", 138 | " \n", 139 | " if longest_sequence_length is None:\n", 140 | " \n", 141 | " self.update_the_max_length()\n", 142 | " \n", 143 | " else:\n", 144 | " \n", 145 | " self.longest_sequence_length = longest_sequence_length\n", 146 | " \n", 147 | " \n", 148 | " def update_the_max_length(self):\n", 149 | " \"\"\"Recomputes the longest sequence constant of the dataset.\n", 150 | "\n", 151 | " Reads all the midi files from the midi folder and finds the max\n", 152 | " length.\n", 153 | " \"\"\"\n", 154 | " \n", 155 | " sequences_lengths = map(lambda filename: midi_filename_to_piano_roll(filename).shape[1],\n", 156 | " self.midi_full_filenames)\n", 157 | " \n", 158 | " max_length = 
max(sequences_lengths)\n", 159 | " \n", 160 | " self.longest_sequence_length = max_length\n", 161 | " \n", 162 | " \n", 163 | " def __len__(self):\n", 164 | " \n", 165 | " return len(self.midi_full_filenames)\n", 166 | " \n", 167 | " def __getitem__(self, index):\n", 168 | " \n", 169 | " midi_full_filename = self.midi_full_filenames[index]\n", 170 | " \n", 171 | " piano_roll = midi_filename_to_piano_roll(midi_full_filename)\n", 172 | " \n", 173 | " piano_roll = piano_roll_random_crop(piano_roll, crop_length=30000)\n", 174 | " \n", 175 | " # -1 because we will shift it\n", 176 | " sequence_length = piano_roll.shape[1] - 1\n", 177 | " \n", 178 | " # Shifted by one time step\n", 179 | " input_sequence = piano_roll[:, :-1]\n", 180 | " ground_truth_sequence = piano_roll[:, 1:]\n", 181 | " \n", 182 | " # pad sequence so that all of them have the same lenght\n", 183 | " # Otherwise the batching won't work\n", 184 | " input_sequence_padded = pad_piano_roll(input_sequence, max_length=30000)\n", 185 | " \n", 186 | " ground_truth_sequence_padded = pad_piano_roll(ground_truth_sequence,\n", 187 | " max_length=30000,\n", 188 | " pad_value=-100)\n", 189 | " \n", 190 | " input_sequence_padded = input_sequence_padded.transpose()\n", 191 | " ground_truth_sequence_padded = ground_truth_sequence_padded.transpose()\n", 192 | " \n", 193 | " return (torch.FloatTensor(input_sequence_padded),\n", 194 | " torch.LongTensor(ground_truth_sequence_padded),\n", 195 | " torch.LongTensor([sequence_length]) )\n", 196 | "\n", 197 | "\n", 198 | "def post_process_sequence_batch(batch_tuple):\n", 199 | " \"\"\"Performs the post processing of a batch acquired using NotesGenerationDataset object.\n", 200 | " \n", 201 | " The function sorts the sequences in the batch by their actual (non-padded)\n", 202 | " length, trims the batch so that size of the sequence length dim becomes\n", 203 | " equal to the longest sequence in the batch + transposes tensors so that they become\n", 204 | " of size (max_sequence_length, batch_size, 128) where max_sequence_length is the longest\n", 205 | " sequence in the batch.\n", 206 | " \n", 207 | " Parameters\n", 208 | " ----------\n", 209 | " batch_tuple : tuple\n", 210 | " (input_sequence, ground_truth_sequence, sequence_lengths) tuple acquired\n", 211 | " from dataloader created using NotesGenerationDataset object.\n", 212 | " \n", 213 | " Returns\n", 214 | " -------\n", 215 | " processed_tuple : tuple\n", 216 | " Returns the processed batch tuple (read the main description of the function)\n", 217 | " \"\"\"\n", 218 | " \n", 219 | " \n", 220 | " input_sequences, output_sequences, lengths = batch_tuple\n", 221 | " \n", 222 | " splitted_input_sequence_batch = input_sequences.split(split_size=1)\n", 223 | " splitted_output_sequence_batch = output_sequences.split(split_size=1)\n", 224 | " splitted_lengths_batch = lengths.split(split_size=1)\n", 225 | "\n", 226 | " training_data_tuples = zip(splitted_input_sequence_batch,\n", 227 | " splitted_output_sequence_batch,\n", 228 | " splitted_lengths_batch)\n", 229 | "\n", 230 | " training_data_tuples_sorted = sorted(training_data_tuples,\n", 231 | " key=lambda p: int(p[2]),\n", 232 | " reverse=True)\n", 233 | "\n", 234 | " splitted_input_sequence_batch, splitted_output_sequence_batch, splitted_lengths_batch = zip(*training_data_tuples_sorted)\n", 235 | "\n", 236 | " input_sequence_batch_sorted = torch.cat(splitted_input_sequence_batch)\n", 237 | " output_sequence_batch_sorted = torch.cat(splitted_output_sequence_batch)\n", 238 | " lengths_batch_sorted = 
torch.cat(splitted_lengths_batch)\n", 239 | " \n", 240 | " # Here we trim overall data matrix using the size of the longest sequence\n", 241 | " input_sequence_batch_sorted = input_sequence_batch_sorted[:, :lengths_batch_sorted[0, 0], :]\n", 242 | " output_sequence_batch_sorted = output_sequence_batch_sorted[:, :lengths_batch_sorted[0, 0], :]\n", 243 | " \n", 244 | " input_sequence_batch_transposed = input_sequence_batch_sorted.transpose(0, 1)\n", 245 | " \n", 246 | " # pytorch's api for rnns wants lenghts to be list of ints\n", 247 | " lengths_batch_sorted_list = list(lengths_batch_sorted)\n", 248 | " lengths_batch_sorted_list = map(lambda x: int(x), lengths_batch_sorted_list)\n", 249 | " \n", 250 | " return input_sequence_batch_transposed, output_sequence_batch_sorted, lengths_batch_sorted_list\n", 251 | "\n", 252 | "class RNN(nn.Module):\n", 253 | " \n", 254 | " def __init__(self, input_size, hidden_size, num_classes, n_layers=2):\n", 255 | " \n", 256 | " super(RNN, self).__init__()\n", 257 | " \n", 258 | " self.input_size = input_size\n", 259 | " self.hidden_size = hidden_size\n", 260 | " self.num_classes = num_classes\n", 261 | " self.n_layers = n_layers\n", 262 | " \n", 263 | " self.notes_encoder = nn.Linear(in_features=input_size, out_features=hidden_size)\n", 264 | " \n", 265 | " self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers)\n", 266 | " \n", 267 | " self.logits_fc = nn.Linear(hidden_size, num_classes)\n", 268 | " \n", 269 | " \n", 270 | " def forward(self, input_sequences, hidden=None):\n", 271 | " \n", 272 | " batch_size = input_sequences.shape[1]\n", 273 | "\n", 274 | " notes_encoded = self.notes_encoder(input_sequences)\n", 275 | " \n", 276 | " outputs, hidden = self.lstm(notes_encoded, hidden)\n", 277 | " \n", 278 | " logits = self.logits_fc(outputs)\n", 279 | " \n", 280 | " # transpose it so that it becomes aligned with groundtruth later on in the\n", 281 | " # training\n", 282 | " logits = logits.transpose(0, 1).contiguous()\n", 283 | " \n", 284 | " neg_logits = (1 - logits)\n", 285 | " \n", 286 | " # Since the BCE loss doesn't support masking, we use the crossentropy\n", 287 | " binary_logits = torch.stack((logits, neg_logits), dim=3).contiguous()\n", 288 | " \n", 289 | " logits_flatten = binary_logits.view(-1, 2)\n", 290 | " \n", 291 | " return logits_flatten, hidden" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 2, 297 | "metadata": { 298 | "collapsed": true 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "trainset = NotesGenerationDataset('Piano-midi.de/valid/')\n", 303 | "\n", 304 | "trainset_loader = torch.utils.data.DataLoader(trainset, batch_size=6,\n", 305 | " shuffle=True, num_workers=4, drop_last=True)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 3, 311 | "metadata": { 312 | "collapsed": true 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "\n", 317 | "rnn = RNN(input_size=128, hidden_size=256, num_classes=128)\n", 318 | "rnn = rnn.cuda()\n", 319 | "\n", 320 | "criterion = nn.CrossEntropyLoss().cuda()\n", 321 | "\n", 322 | "learning_rate = 0.001\n", 323 | "optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 5, 329 | "metadata": { 330 | "collapsed": true 331 | }, 332 | "outputs": [], 333 | "source": [ 334 | "def sample_from_piano_rnn(sample_length=4, temperature=1, starting_sequence=None):\n", 335 | "\n", 336 | " if starting_sequence is None:\n", 337 | " \n", 338 | " current_sequence_input = 
torch.zeros(1, 1, 128)\n", 339 | " current_sequence_input[0, 0, 64] = 0\n", 340 | " current_sequence_input[0, 0, 50] = 0\n", 341 | " current_sequence_input = Variable(current_sequence_input.cuda())\n", 342 | "\n", 343 | " final_output_sequence = [current_sequence_input.data.squeeze(1)]\n", 344 | " \n", 345 | " hidden = None\n", 346 | "\n", 347 | " for i in xrange(sample_length):\n", 348 | "\n", 349 | " output, hidden = rnn(current_sequence_input, hidden)\n", 350 | "\n", 351 | " probabilities = nn.functional.softmax(output.div(temperature), dim=1)\n", 352 | "\n", 353 | " current_sequence_input = torch.multinomial(probabilities.data, 1).squeeze().unsqueeze(0).unsqueeze(1)\n", 354 | "\n", 355 | " current_sequence_input = Variable(current_sequence_input.float())\n", 356 | "\n", 357 | " final_output_sequence.append(current_sequence_input.data.squeeze(1))\n", 358 | "\n", 359 | " sampled_sequence = torch.cat(final_output_sequence, dim=0).cpu().numpy()\n", 360 | " \n", 361 | " return sampled_sequence" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 6, 367 | "metadata": {}, 368 | "outputs": [ 369 | { 370 | "data": { 371 | "application/javascript": [ 372 | "/* Put everything inside the global mpl namespace */\n", 373 | "window.mpl = {};\n", 374 | "\n", 375 | "\n", 376 | "mpl.get_websocket_type = function() {\n", 377 | " if (typeof(WebSocket) !== 'undefined') {\n", 378 | " return WebSocket;\n", 379 | " } else if (typeof(MozWebSocket) !== 'undefined') {\n", 380 | " return MozWebSocket;\n", 381 | " } else {\n", 382 | " alert('Your browser does not have WebSocket support.' +\n", 383 | " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", 384 | " 'Firefox 4 and 5 are also supported but you ' +\n", 385 | " 'have to enable WebSockets in about:config.');\n", 386 | " };\n", 387 | "}\n", 388 | "\n", 389 | "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", 390 | " this.id = figure_id;\n", 391 | "\n", 392 | " this.ws = websocket;\n", 393 | "\n", 394 | " this.supports_binary = (this.ws.binaryType != undefined);\n", 395 | "\n", 396 | " if (!this.supports_binary) {\n", 397 | " var warnings = document.getElementById(\"mpl-warnings\");\n", 398 | " if (warnings) {\n", 399 | " warnings.style.display = 'block';\n", 400 | " warnings.textContent = (\n", 401 | " \"This browser does not support binary websocket messages. \" +\n", 402 | " \"Performance may be slow.\");\n", 403 | " }\n", 404 | " }\n", 405 | "\n", 406 | " this.imageObj = new Image();\n", 407 | "\n", 408 | " this.context = undefined;\n", 409 | " this.message = undefined;\n", 410 | " this.canvas = undefined;\n", 411 | " this.rubberband_canvas = undefined;\n", 412 | " this.rubberband_context = undefined;\n", 413 | " this.format_dropdown = undefined;\n", 414 | "\n", 415 | " this.image_mode = 'full';\n", 416 | "\n", 417 | " this.root = $('
');\n", 418 | " this._root_extra_style(this.root)\n", 419 | " this.root.attr('style', 'display: inline-block');\n", 420 | "\n", 421 | " $(parent_element).append(this.root);\n", 422 | "\n", 423 | " this._init_header(this);\n", 424 | " this._init_canvas(this);\n", 425 | " this._init_toolbar(this);\n", 426 | "\n", 427 | " var fig = this;\n", 428 | "\n", 429 | " this.waiting = false;\n", 430 | "\n", 431 | " this.ws.onopen = function () {\n", 432 | " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", 433 | " fig.send_message(\"send_image_mode\", {});\n", 434 | " if (mpl.ratio != 1) {\n", 435 | " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", 436 | " }\n", 437 | " fig.send_message(\"refresh\", {});\n", 438 | " }\n", 439 | "\n", 440 | " this.imageObj.onload = function() {\n", 441 | " if (fig.image_mode == 'full') {\n", 442 | " // Full images could contain transparency (where diff images\n", 443 | " // almost always do), so we need to clear the canvas so that\n", 444 | " // there is no ghosting.\n", 445 | " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", 446 | " }\n", 447 | " fig.context.drawImage(fig.imageObj, 0, 0);\n", 448 | " };\n", 449 | "\n", 450 | " this.imageObj.onunload = function() {\n", 451 | " fig.ws.close();\n", 452 | " }\n", 453 | "\n", 454 | " this.ws.onmessage = this._make_on_message_function(this);\n", 455 | "\n", 456 | " this.ondownload = ondownload;\n", 457 | "}\n", 458 | "\n", 459 | "mpl.figure.prototype._init_header = function() {\n", 460 | " var titlebar = $(\n", 461 | " '');\n", 463 | " var titletext = $(\n", 464 | " '');\n", 466 | " titlebar.append(titletext)\n", 467 | " this.root.append(titlebar);\n", 468 | " this.header = titletext[0];\n", 469 | "}\n", 470 | "\n", 471 | "\n", 472 | "\n", 473 | "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", 474 | "\n", 475 | "}\n", 476 | "\n", 477 | "\n", 478 | "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", 479 | "\n", 480 | "}\n", 481 | "\n", 482 | "mpl.figure.prototype._init_canvas = function() {\n", 483 | " var fig = this;\n", 484 | "\n", 485 | " var canvas_div = $('');\n", 486 | "\n", 487 | " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", 488 | "\n", 489 | " function canvas_keyboard_event(event) {\n", 490 | " return fig.key_event(event, event['data']);\n", 491 | " }\n", 492 | "\n", 493 | " canvas_div.keydown('key_press', canvas_keyboard_event);\n", 494 | " canvas_div.keyup('key_release', canvas_keyboard_event);\n", 495 | " this.canvas_div = canvas_div\n", 496 | " this._canvas_extra_style(canvas_div)\n", 497 | " this.root.append(canvas_div);\n", 498 | "\n", 499 | " var canvas = $('');\n", 500 | " canvas.addClass('mpl-canvas');\n", 501 | " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", 502 | "\n", 503 | " this.canvas = canvas[0];\n", 504 | " this.context = canvas[0].getContext(\"2d\");\n", 505 | "\n", 506 | " var backingStore = this.context.backingStorePixelRatio ||\n", 507 | "\tthis.context.webkitBackingStorePixelRatio ||\n", 508 | "\tthis.context.mozBackingStorePixelRatio ||\n", 509 | "\tthis.context.msBackingStorePixelRatio ||\n", 510 | "\tthis.context.oBackingStorePixelRatio ||\n", 511 | "\tthis.context.backingStorePixelRatio || 1;\n", 512 | "\n", 513 | " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", 514 | "\n", 515 | " var rubberband = $('');\n", 516 | " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", 517 | "\n", 518 | " 
var pass_mouse_events = true;\n", 519 | "\n", 520 | " canvas_div.resizable({\n", 521 | " start: function(event, ui) {\n", 522 | " pass_mouse_events = false;\n", 523 | " },\n", 524 | " resize: function(event, ui) {\n", 525 | " fig.request_resize(ui.size.width, ui.size.height);\n", 526 | " },\n", 527 | " stop: function(event, ui) {\n", 528 | " pass_mouse_events = true;\n", 529 | " fig.request_resize(ui.size.width, ui.size.height);\n", 530 | " },\n", 531 | " });\n", 532 | "\n", 533 | " function mouse_event_fn(event) {\n", 534 | " if (pass_mouse_events)\n", 535 | " return fig.mouse_event(event, event['data']);\n", 536 | " }\n", 537 | "\n", 538 | " rubberband.mousedown('button_press', mouse_event_fn);\n", 539 | " rubberband.mouseup('button_release', mouse_event_fn);\n", 540 | " // Throttle sequential mouse events to 1 every 20ms.\n", 541 | " rubberband.mousemove('motion_notify', mouse_event_fn);\n", 542 | "\n", 543 | " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", 544 | " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", 545 | "\n", 546 | " canvas_div.on(\"wheel\", function (event) {\n", 547 | " event = event.originalEvent;\n", 548 | " event['data'] = 'scroll'\n", 549 | " if (event.deltaY < 0) {\n", 550 | " event.step = 1;\n", 551 | " } else {\n", 552 | " event.step = -1;\n", 553 | " }\n", 554 | " mouse_event_fn(event);\n", 555 | " });\n", 556 | "\n", 557 | " canvas_div.append(canvas);\n", 558 | " canvas_div.append(rubberband);\n", 559 | "\n", 560 | " this.rubberband = rubberband;\n", 561 | " this.rubberband_canvas = rubberband[0];\n", 562 | " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", 563 | " this.rubberband_context.strokeStyle = \"#000000\";\n", 564 | "\n", 565 | " this._resize_canvas = function(width, height) {\n", 566 | " // Keep the size of the canvas, canvas container, and rubber band\n", 567 | " // canvas in synch.\n", 568 | " canvas_div.css('width', width)\n", 569 | " canvas_div.css('height', height)\n", 570 | "\n", 571 | " canvas.attr('width', width * mpl.ratio);\n", 572 | " canvas.attr('height', height * mpl.ratio);\n", 573 | " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", 574 | "\n", 575 | " rubberband.attr('width', width);\n", 576 | " rubberband.attr('height', height);\n", 577 | " }\n", 578 | "\n", 579 | " // Set the figure to an initial 600x600px, this will subsequently be updated\n", 580 | " // upon first draw.\n", 581 | " this._resize_canvas(600, 600);\n", 582 | "\n", 583 | " // Disable right mouse context menu.\n", 584 | " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", 585 | " return false;\n", 586 | " });\n", 587 | "\n", 588 | " function set_focus () {\n", 589 | " canvas.focus();\n", 590 | " canvas_div.focus();\n", 591 | " }\n", 592 | "\n", 593 | " window.setTimeout(set_focus, 100);\n", 594 | "}\n", 595 | "\n", 596 | "mpl.figure.prototype._init_toolbar = function() {\n", 597 | " var fig = this;\n", 598 | "\n", 599 | " var nav_element = $('')\n", 600 | " nav_element.attr('style', 'width: 100%');\n", 601 | " this.root.append(nav_element);\n", 602 | "\n", 603 | " // Define a callback function for later on.\n", 604 | " function toolbar_event(event) {\n", 605 | " return fig.toolbar_button_onclick(event['data']);\n", 606 | " }\n", 607 | " function toolbar_mouse_event(event) {\n", 608 | " return fig.toolbar_button_onmouseover(event['data']);\n", 609 | " }\n", 610 | "\n", 611 | " for(var toolbar_ind in mpl.toolbar_items) {\n", 612 | " var name = mpl.toolbar_items[toolbar_ind][0];\n", 613 | " 
var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", 614 | " var image = mpl.toolbar_items[toolbar_ind][2];\n", 615 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", 616 | "\n", 617 | " if (!name) {\n", 618 | " // put a spacer in here.\n", 619 | " continue;\n", 620 | " }\n", 621 | " var button = $('');\n", 622 | " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n", 623 | " 'ui-button-icon-only');\n", 624 | " button.attr('role', 'button');\n", 625 | " button.attr('aria-disabled', 'false');\n", 626 | " button.click(method_name, toolbar_event);\n", 627 | " button.mouseover(tooltip, toolbar_mouse_event);\n", 628 | "\n", 629 | " var icon_img = $('');\n", 630 | " icon_img.addClass('ui-button-icon-primary ui-icon');\n", 631 | " icon_img.addClass(image);\n", 632 | " icon_img.addClass('ui-corner-all');\n", 633 | "\n", 634 | " var tooltip_span = $('');\n", 635 | " tooltip_span.addClass('ui-button-text');\n", 636 | " tooltip_span.html(tooltip);\n", 637 | "\n", 638 | " button.append(icon_img);\n", 639 | " button.append(tooltip_span);\n", 640 | "\n", 641 | " nav_element.append(button);\n", 642 | " }\n", 643 | "\n", 644 | " var fmt_picker_span = $('');\n", 645 | "\n", 646 | " var fmt_picker = $('');\n", 647 | " fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n", 648 | " fmt_picker_span.append(fmt_picker);\n", 649 | " nav_element.append(fmt_picker_span);\n", 650 | " this.format_dropdown = fmt_picker[0];\n", 651 | "\n", 652 | " for (var ind in mpl.extensions) {\n", 653 | " var fmt = mpl.extensions[ind];\n", 654 | " var option = $(\n", 655 | " '', {selected: fmt === mpl.default_extension}).html(fmt);\n", 656 | " fmt_picker.append(option)\n", 657 | " }\n", 658 | "\n", 659 | " // Add hover states to the ui-buttons\n", 660 | " $( \".ui-button\" ).hover(\n", 661 | " function() { $(this).addClass(\"ui-state-hover\");},\n", 662 | " function() { $(this).removeClass(\"ui-state-hover\");}\n", 663 | " );\n", 664 | "\n", 665 | " var status_bar = $('