├── README.md
├── chord_sequence_generation.py
├── data
│   └── music
│       ├── data.json
│       └── data.mat
├── eval
│   └── README.md
├── generate_midi.py
├── imagernn
│   ├── Readme.md
│   ├── __init__.py
│   ├── data_provider.py
│   ├── generic_batch_generator.py
│   ├── imagernn_utils.py
│   ├── lstm_generator.py
│   ├── rnn_generator.py
│   ├── solver.py
│   └── utils.py
├── reg_range.py
├── status
│   └── Readme.md
└── train.py

/README.md:
--------------------------------------------------------------------------------
1 | ## Overview
2 |
3 | - This repository implements the melody generation model proposed in [this paper](https://arxiv.org/abs/1710.11549).
4 |
5 | - The **input** is a two-hot vector in which the first 1 corresponds to a particular 2-bar chord progression (e.g., C - Am) and the second 1 corresponds to the part annotation (verse, chorus, etc.).
6 |
7 | - The **output** is a MIDI file whose melody is converted from the generated strings, which currently take the form **pitch;pos;duration**.
8 |
9 | - This repository is a modification of [NeuralTalk](https://github.com/karpathy/neuraltalk).
10 |
11 |
12 | ## Dependencies
13 | - **[pretty-midi](https://github.com/craffel/pretty-midi)**
14 |
15 | - **[hmmlearn](https://github.com/hmmlearn/hmmlearn)**
16 |
17 | - **[mido](http://mido.readthedocs.io/en/latest/installing.html)**
18 |
19 | ## Usage
20 | - To train:
21 |
22 | `python train.py`
23 |
24 | - To deactivate regularization on pitch range:
25 |
26 | `python train.py --reg_range_coeff 0`
27 |
28 | - To set the pitch range for regularization (default is 60 to 72):
29 |
30 | `python train.py --reg_range_min your_min_val --reg_range_max your_max_val`
31 |
32 | - To generate a MIDI file:
33 |
34 | `python generate_midi.py cv/checkpoint_file`
35 |
36 | - To generate a MIDI file with HMM-generated input (by default, the song is generated from our pre-set test input):
37 |
38 | `python generate_midi.py cv/checkpoint_file --gen_chords True`
39 |
40 |
41 | - Notes are inserted into the MIDI file at real-valued times rather than at discrete musical lengths, so make sure to quantize them in a sequencer (e.g., GarageBand); 1/16 resolution is recommended.
42 |
43 | - Check out our [demos](https://soundcloud.com/iclr2018eval).
44 |
45 | ## Citation
46 | `@article{andrew2017neuralmelody,
47 | author={Andrew Shin and Leopold Crestel and Hiroharu Kato and Kuniaki Saito and Katsunori Ohnishi and Masataka Yamaguchi and Masahiro Nakawaki and Yoshitaka Ushiku and Tatsuya Harada},
48 | title={Melody Generation for Pop Music via Word Representation of Musical Properties},
49 | journal={arXiv preprint arXiv:1710.11549},
50 | year={2017}
51 | }`
52 |
53 | ## License
54 | BSD license.
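
## Input encoding (sketch)

For reference, a minimal sketch of the two-hot input described in the Overview, mirroring `two_hot_encoding` in `generate_midi.py`. The sizes (56 chord progressions + 4 part annotations = 60 dimensions) follow the data in `chord_sequence_generation.py`; the concrete indices below are only an example:

```python
import numpy as np

def two_hot(chord_ix, part_ix, num_chords=56, num_parts=4):
    """One input frame: a 1 at the chord index and a 1 at the part index."""
    v = np.zeros(num_chords + num_parts)
    v[chord_ix] = 1                # which 2-bar chord progression (e.g. 'C;Am')
    v[num_chords + part_ix] = 1    # which part annotation (verse, chorus, ...)
    return v

frame = two_hot(29, 1)  # progression index 29 is 'C;Am' in chord_names
```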
55 |
--------------------------------------------------------------------------------
/chord_sequence_generation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf8 -*-
3 |
4 | import json
5 | import os
6 | import argparse
7 | import scipy.io as sio
8 | import numpy as np
9 | from hmmlearn.hmm import MultinomialHMM
10 | from sklearn.utils import check_random_state
11 |
12 | # import matplotlib.pylab as plt
13 |
14 | def post_processing_parts(matrix, ratio):
15 |     # Repeat each part label `ratio` times (one copy per position in the bar); cast to int at the end
16 |     A = np.repeat(matrix, ratio)
17 |     # Add bar-position information, which is used to condition the chord generation
18 |     bar_counter = np.mod(np.arange(len(A)), ratio)
19 |     B = A * ratio + bar_counter
20 |     return B.astype(int)
21 |
22 |
23 | def build_proba(var, cond):
24 |     # Allocate the (var, cond) count table
25 |     dim = (int(np.max(var))+1, int(np.max(cond))+1)
26 |     proba = np.zeros(dim)
27 |     # Count occurrences
28 |     for (v, c) in zip(var, cond):
29 |         proba[int(v), int(c)] += 1
30 |     # Normalize along the var axis
31 |     return np.transpose(proba / proba.sum(axis=0))
32 |
33 |
34 | class MultinomialHMM_prod(MultinomialHMM):
35 |     def __init__(self, n_components=1,
36 |                  startprob_prior=1.0, transmat_prior=1.0,
37 |                  algorithm="viterbi", random_state=None,
38 |                  n_iter=10, tol=1e-2, verbose=False,
39 |                  params="ste", init_params="ste"):
40 |         MultinomialHMM.__init__(self, n_components=n_components,
41 |                                 startprob_prior=startprob_prior, transmat_prior=transmat_prior,
42 |                                 algorithm=algorithm, random_state=random_state,
43 |                                 n_iter=n_iter, tol=tol, verbose=verbose,
44 |                                 params=params, init_params=init_params)
45 |         return
46 |
47 |     def _generate_sample_from_state_PROD(self, state, cond_matrix, cond, random_state=None):
48 |         cum_prod = np.cumsum(self.emissionprob_[state, :] * cond_matrix[cond, :])  # emission * conditional
49 |         cdf = cum_prod / np.max(cum_prod)  # renormalize into a CDF
50 |         random_state = check_random_state(random_state)
51 |         return [(cdf > random_state.rand()).argmax()]  # inverse-transform sampling
52 |
53 |     def sampling_prod_hmm(self, cond_matrix, cond_variable, random_state=None):
54 |         n_samples = len(cond_variable)
55 |         if random_state is None:
56 |             random_state = self.random_state
57 |         random_state = check_random_state(random_state)
58 |
59 |         startprob_cdf = np.cumsum(self.startprob_)
60 |         transmat_cdf = np.cumsum(self.transmat_, axis=1)
61 |
62 |         currstate = (startprob_cdf > random_state.rand()).argmax()
63 |         curr_cond = cond_variable[0]
64 |         state_sequence = [currstate]
65 |
66 |         X = [self._generate_sample_from_state_PROD(
67 |             currstate, cond_matrix, curr_cond, random_state=random_state)]
68 |
69 |         for t in range(n_samples - 1):
70 |             currstate = (transmat_cdf[currstate] > random_state.rand()) \
71 |                 .argmax()
72 |             curr_cond = cond_variable[t+1]
73 |             state_sequence.append(currstate)
74 |             X.append(self._generate_sample_from_state_PROD(
75 |                 currstate, cond_matrix, curr_cond, random_state=random_state))
76 |
77 |         return np.atleast_2d(X), np.array(state_sequence, dtype=int)
78 |
79 | def main(params):
80 |     DEBUG = params['DEBUG']
81 |     dataset = params['dataset']
82 |     nh_part = params['nh_part']
83 |     nh_chords = params['nh_chords']
84 |     num_gen = params['num_gen']
85 |
86 |     ##################################################################
87 |     # DATA PROCESSING
88 |     # Index of the last chord of each song in the concatenated data
89 |     song_indices = [43,85,133,183,225,265,309,349,413,471,519,560,590,628,670,712,764,792,836,872,918,966,1018,1049,1091,1142,1174,1222,1266,1278,1304,1340,1372,1416,1456,1484,1536,1576,1632,1683,1707,1752,1805,1857,1891,1911]
90 |     # Chords
mapping 91 | chord_names = ['C;Em', 'A#;F', 'Dm;Em', 'Dm;G', 'Dm;C', 'Am;Em', 'F;C', 'F;G', 'Dm;F', 'C;C', 'C;E', 'Am;G', 'F;Em', 'F;F', 'G;G', 'Am;Am', 'Dm;Dm', 'C;A#', 'Em;F', 'C;G', 'G#;A#', 'F;Am', 'G#;Fm', 'Am;Gm', 'F;E', 'Dm;Am', 'Em;Em', 'G#;G#', 'Em;Am', 'C;Am', 'F;Dm', 'G#;G', 'F;A#', 'Am;G#', 'C;D', 'G;Am', 'Am;C', 'Am;A#', 'A#;G', 'Am;F', 'A#;Am', 'E;Am', 'Dm;E', 'A;G', 'Am;Dm', 'Em;Dm', 'C;F#m', 'Am;D', 'G#;Em', 'C;Dm', 'C;F', 'G;C', 'A#;A#', 'Am;Caug', 'Fm;G', 'A;A'] 92 | 93 | # Import .mat file 94 | dataset_root = os.path.join('data', dataset) 95 | mat_path = os.path.join(dataset_root, 'data.mat') 96 | data_mat = sio.loadmat(mat_path) 97 | chords_per_part = 2 98 | chords_per_bar = 4 99 | num_chords = 56 100 | num_parts = 4 101 | sub_sampling_ratio_parts = chords_per_bar/chords_per_part 102 | 103 | # Get parts 104 | parts_data_ = (np.dot(np.transpose(data_mat["feats"][-num_parts:]), np.asarray(range(num_parts))).astype(int)).reshape(-1, 1) 105 | # Group by bar 106 | parts_data = parts_data_[::sub_sampling_ratio_parts] 107 | # Parts with position in bar. Used condition chords generation 108 | parts_bar_data = post_processing_parts(parts_data, sub_sampling_ratio_parts) 109 | # Get chords transitions 110 | chords_data = (np.dot(np.transpose(data_mat["feats"][:-num_parts]), np.asarray(range(num_chords))).astype(int)).reshape(-1, 1) 111 | 112 | 113 | ################################# 114 | # Group by song 115 | parts_length = [] 116 | chords_length = [] 117 | start_ind = 0 118 | for end_ind in song_indices: 119 | chords_length.append(end_ind - start_ind + 1) 120 | start_ind = end_ind + 1 121 | parts_length = [e/2 for e in chords_length] 122 | ################################################################## 123 | 124 | ################################################################## 125 | # PARTS 126 | # Compute HMM for part modeling 127 | hmm_part = MultinomialHMM(n_components=nh_part, n_iter=20) 128 | hmm_part.fit(parts_data, parts_length) 129 | 130 | # def plot_mat(matrix, name): 131 | # fig = plt.figure() 132 | # ax = fig.add_subplot(1,1,1) 133 | # ax.set_aspect('equal') 134 | # plt.imshow(matrix, interpolation='nearest', cmap=plt.cm.ocean) 135 | # plt.colorbar() 136 | # plt.savefig(name, format='pdf') 137 | 138 | # plot_mat(hmm_part.transmat_, 'part_transmat.pdf') 139 | # plot_mat(np.reshape(hmm_part.startprob_, [-1, 1]), 'part_startprob.pdf') 140 | # plot_mat(hmm_part.emissionprob_, 'part_emissionprob.pdf') 141 | ################################################################## 142 | 143 | ################################################################## 144 | # CHORDS 145 | hmm_chords = MultinomialHMM_prod(n_components=nh_chords, n_iter=20) 146 | hmm_chords.fit(chords_data, chords_length) 147 | # plot_mat(hmm_chords.transmat_, 'chords_transmat.pdf') 148 | # plot_mat(np.reshape(hmm_chords.startprob_, [-1, 1]), 'chords_startprob.pdf') 149 | # plot_mat(hmm_chords.emissionprob_, 'chords_emissionprob.pdf') 150 | ################################################################## 151 | 152 | ################################# 153 | # GENERATION 154 | # Sample sequence 155 | for n in range(num_gen): 156 | gen_part_sequence_, _ = hmm_part.sample(params["gen_seq_length"]) 157 | gen_part_sequence = post_processing_parts(gen_part_sequence_, sub_sampling_ratio_parts) 158 | # Compute conditioning on parts 159 | p_chords_given_partBar = build_proba(chords_data, parts_bar_data) 160 | gen_chord_sequence, _ = hmm_chords.sampling_prod_hmm(p_chords_given_partBar, gen_part_sequence) 161 | 
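        # NOTE: sampling_prod_hmm draws each chord by inverse-transform sampling from the
        # renormalized elementwise product of the HMM emission distribution for the current
        # hidden state and the empirical conditional P(chord | part, bar position) built
        # above by build_proba (see _generate_sample_from_state_PROD).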
######## T E S T ################
162 |         # Independent HMM ?
163 |         # gen_chord_sequence, _ = hmm_chords.sampling(n_samples=44)
164 |         ##################################
165 |         if params["DEBUG"]:
166 |             with open("results_chords/" + str(n), 'wb') as f:
167 |                 for count, (part, chord) in enumerate(zip(gen_part_sequence, gen_chord_sequence)):
168 |                     if count % 2 == 0:
169 |                         f.write(str(part/2) + " ; " + chord_names[chord[0]] + "\n")
170 |                     else:
171 |                         f.write(" ; " + chord_names[chord[0]] + "\n")
172 |                     if count % 8 == 7:
173 |                         f.write("\n")
174 |     gen_part_sequence = [e/2 for e in gen_part_sequence]
175 |     return gen_part_sequence, gen_chord_sequence, num_chords, num_parts
176 |
177 | if __name__ == "__main__":
178 |
179 |     parser = argparse.ArgumentParser()
180 |
181 |     # Data
182 |     parser.add_argument('-d', '--dataset', dest='dataset', default='music', help='dataset folder under data/ (default: music)')
183 |     # Parts' HMM
184 |     parser.add_argument('--nh_part', dest='nh_part', type=int, default=20, help='number of hidden states for the part\'s HMM')
185 |     parser.add_argument('--nh_chords', dest='nh_chords', type=int, default=40, help='number of hidden states for the chords\' HMM')
186 |     # Generation
187 |     parser.add_argument('--gen_seq_length', type=int, default=8, help='length of the generated sequences')
188 |     parser.add_argument('--num_gen', dest='num_gen', type=int, default=10, help='number of sequences generated (i.e. sampling n times from the hmm)')
189 |     parser.add_argument('--DEBUG', dest='DEBUG', type=bool, default=False, help='True = debug mode on')
190 |
191 |     args = parser.parse_args()
192 |     params = vars(args) # convert to ordinary dict
193 |     print 'parsed parameters:'
194 |     print json.dumps(params, indent=2)
195 |
196 |     main(params)
197 |
--------------------------------------------------------------------------------
/data/music/data.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mil-tokyo/NeuralMelody/f5ac26c8acb01d167e602deefb65d464f94850b8/data/music/data.mat
--------------------------------------------------------------------------------
/eval/README.md:
--------------------------------------------------------------------------------
1 | Generated strings are recorded here.
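
Each generated token has the form `pitch;pos;duration`: `pos` may contain a mixed fraction such as `1&1/2` and `duration` may be a fraction such as `1/4` (see `convert_pos` / `convert_dur` in `generate_midi.py`). A minimal decoding sketch; the token below is only an illustrative example:

```python
def decode_token(token, chunk_index=0):
    """Split 'pitch;pos;duration'; offset the start by 2 units per generated chunk."""
    p, pos, dur = token.split(';')
    if '&' in pos:                       # mixed fraction, e.g. '1&1/2'
        nat, frac = pos.split('&')
        nu, de = frac.split('/')
        start = int(nat) + 2 * chunk_index + float(nu) / float(de)
    else:
        start = float(pos) + 2 * chunk_index
    if '/' in dur:                       # fractional duration, e.g. '1/4'
        nu, de = dur.split('/')
        d = float(nu) / float(de)
    else:
        d = float(dur)
    return int(p), start, d

print(decode_token('62;1&1/2;1/4'))      # -> (62, 1.5, 0.25)
```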
2 | -------------------------------------------------------------------------------- /generate_midi.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import time 4 | import datetime 5 | import numpy as np 6 | import code 7 | import socket 8 | import os 9 | import cPickle as pickle 10 | import math 11 | import pretty_midi 12 | import scipy.io 13 | import re 14 | 15 | from imagernn.data_provider import getDataProvider 16 | from imagernn.solver import Solver 17 | from imagernn.imagernn_utils import decodeGenerator, eval_split 18 | 19 | import chord_sequence_generation 20 | 21 | program_num = 80 22 | base_tempo = 120 23 | new_track = pretty_midi.Instrument(program_num,is_drum=False,name='melody') 24 | bass_track = pretty_midi.Instrument(38,is_drum=False,name='bass') 25 | 26 | def convert_pos(p,c): 27 | if '&' not in p: 28 | return float(p)+float(2*c) 29 | nat,frac = p.split('&') 30 | nu, de = frac.split('/') 31 | nat = int(nat) 32 | nat += 2*c 33 | return float(nat)+(float(nu) / float(de)) 34 | 35 | def convert_dur(d): 36 | if '/' not in d: 37 | return float(d) 38 | nu,de = d.split('/') 39 | return float(nu)/float(de) 40 | 41 | def two_hot_encoding(parts, chords, num_chords, num_parts): 42 | T = len(parts) 43 | output = np.zeros((T, num_chords+num_parts)) 44 | for index, (part, chord) in enumerate(zip(parts, chords)): 45 | output[index, num_chords + part] = 1 46 | output[index, chord] = 1 47 | return output 48 | 49 | def adjust_tempo(new_midi_data): 50 | bpm = new_midi_data.get_tempo_changes()[-1][0] 51 | min_length = 60. / (bpm * 4) 52 | 53 | for instrument in new_midi_data.instruments: 54 | for note in instrument.notes: 55 | note.start *= base_tempo / bpm 56 | note.end *= base_tempo / bpm 57 | 58 | def quantize(new_midi_data): 59 | bpm = new_midi_data.get_tempo_changes()[-1][0] 60 | min_length = 60. 
/ (bpm * 4)
61 |
62 |     for instrument in new_midi_data.instruments:
63 |         for note in instrument.notes:
64 |             note.start = round(note.start / min_length) * min_length
65 |             note.end = round(note.end / min_length) * min_length
66 |             if note.end - note.start == 0:
67 |                 note.end += min_length
68 |
69 | def gen_from_scratch(params):
70 |     # load the checkpoint
71 |     checkpoint_path = params['checkpoint_path']
72 |     max_images = params['max_images']
73 |     fout = params['output_file']
74 |     tempo = params['tempo']
75 |
76 |     print 'loading checkpoint %s' % (checkpoint_path, )
77 |     checkpoint = pickle.load(open(checkpoint_path, 'rb'))
78 |     checkpoint_params = checkpoint['params']
79 |     dataset = checkpoint_params['dataset']
80 |     params['dataset'] = dataset
81 |     model = checkpoint['model']
82 |     dump_folder = params['dump_folder']
83 |     ixtoword = checkpoint['ixtoword']
84 |
85 |     if dump_folder:
86 |         print 'creating dump folder ' + dump_folder
87 |         os.system('mkdir -p ' + dump_folder)
88 |
89 |     # Generate the chord sequence
90 |     parts, chords, num_chords, num_parts = chord_sequence_generation.main(params)
91 |     imgs = two_hot_encoding(parts, chords, num_chords, num_parts)
92 |
93 |     blob = {} # output blob which we will dump to JSON for visualizing the results
94 |     blob['params'] = params
95 |     blob['checkpoint_params'] = checkpoint_params
96 |     blob['imgblobs'] = []
97 |
98 |     # iterate over the generated two-hot frames and predict melody strings
99 |     BatchGenerator = decodeGenerator(checkpoint_params)
100 |     n = 0
101 |     candidates = []
102 |     for img in imgs:
103 |         n += 1
104 |         print 'frame %d/%d:' % (n, len(imgs))
105 |         kwparams = { 'beam_size' : params['beam_size'] }
106 |         img_dict = {'feat': img}
107 |         Ys = BatchGenerator.predict([{'image':img_dict}], model, checkpoint_params, **kwparams)
108 |
109 |         # now evaluate and encode the top prediction
110 |         top_predictions = Ys[0] # take predictions for the first (and only) image we passed in
111 |         top_prediction = top_predictions[0] # these are sorted with highest on top
112 |         candidate = ' '.join([ixtoword[ix] for ix in top_prediction[1] if ix > 0]) # ix 0 is the END token, skip that
113 |         candidates.append(candidate)
114 |         print 'PRED: (%f) %s' % (top_prediction[0], candidate)
115 |
116 |     # Write midi
117 |     for idx,c in enumerate(candidates):
118 |         cs = c.split()
119 |         for e in cs:
120 |             es = e.split(';')
121 |             pitch = int(es[0])
122 |             pos = es[1]
123 |             pos = convert_pos(pos,idx)
124 |             dur = es[2]
125 |             dur = convert_dur(dur)
126 |             note = pretty_midi.Note(90,pitch,pos,pos+dur)
127 |             new_track.notes.append(note)
128 |
129 |     new_midi_data = pretty_midi.PrettyMIDI(initial_tempo=tempo)
130 |     new_midi_data.instruments.append(new_track)
131 |
132 |     # bass track built from the generated chord progression
133 |     chord_names = ['C;Em', 'A#;F', 'Dm;Em', 'Dm;G', 'Dm;C', 'Am;Em', 'F;C', 'F;G', 'Dm;F', 'C;C', 'C;E', 'Am;G', 'F;Em', 'F;F', 'G;G', 'Am;Am', 'Dm;Dm', 'C;A#', 'Em;F', 'C;G', 'G#;A#', 'F;Am', 'G#;Fm', 'Am;Gm', 'F;E', 'Dm;Am', 'Em;Em', 'G#;G#', 'Em;Am', 'C;Am', 'F;Dm', 'G#;G', 'F;A#', 'Am;G#', 'C;D', 'G;Am', 'Am;C', 'Am;A#', 'A#;G', 'Am;F', 'A#;Am', 'E;Am', 'Dm;E', 'A;G', 'Am;Dm', 'Em;Dm', 'C;F#m', 'Am;D', 'G#;Em', 'C;Dm', 'C;F', 'G;C', 'A#;A#', 'Am;Caug', 'Fm;G', 'A;A']
134 |     chord_to_pitch = {'C':36, 'C#':37, 'D':38, 'D#':39, 'E':40, 'F':41, 'F#':42, 'G':43, 'G#':44, 'A':45, 'A#':46, 'B':47}
135 |     for time, chord in enumerate(chords):
136 |         n1, n2 = re.split(";", chord_names[chord[0]])
137 |         n1, n2 = re.sub("m", "", n1), re.sub("m", "", n2)  # strip minor marker to get the root for the bass
138 |
bass_track.notes.append(pretty_midi.Note(90,chord_to_pitch[n1],2*time,2*time+1)) 139 | bass_track.notes.append(pretty_midi.Note(90,chord_to_pitch[n2],2*time+1,2*(time+1))) 140 | new_midi_data.instruments.append(bass_track) 141 | adjust_tempo(new_midi_data) 142 | if params['quantize']: 143 | quantize(new_midi_data) 144 | new_midi_data.write(fout) 145 | 146 | def gen_from_test(params): 147 | # load the checkpoint 148 | checkpoint_path = params['checkpoint_path'] 149 | max_images = params['max_images'] 150 | fout = params['output_file'] 151 | tempo = params['tempo'] 152 | 153 | print 'loading checkpoint %s' % (checkpoint_path, ) 154 | checkpoint = pickle.load(open(checkpoint_path, 'rb')) 155 | checkpoint_params = checkpoint['params'] 156 | dataset = checkpoint_params['dataset'] 157 | model = checkpoint['model'] 158 | dump_folder = params['dump_folder'] 159 | 160 | if dump_folder: 161 | print 'creating dump folder ' + dump_folder 162 | os.system('mkdir -p ' + dump_folder) 163 | 164 | # fetch the data provider 165 | dp = getDataProvider(dataset) 166 | 167 | misc = {} 168 | misc['wordtoix'] = checkpoint['wordtoix'] 169 | ixtoword = checkpoint['ixtoword'] 170 | 171 | blob = {} # output blob which we will dump to JSON for visualizing the results 172 | blob['params'] = params 173 | blob['checkpoint_params'] = checkpoint_params 174 | blob['imgblobs'] = [] 175 | 176 | # iterate over all images in test set and predict sentences 177 | BatchGenerator = decodeGenerator(checkpoint_params) 178 | n = 0 179 | all_references = [] 180 | all_candidates = [] 181 | candidates=[] 182 | for img in dp.iterImages(split = 'test', max_images = max_images): 183 | n+=1 184 | print 'image %d/%d:' % (n, max_images) 185 | references = [' '.join(x['tokens']) for x in img['sentences']] # as list of lists of tokens 186 | kwparams = { 'beam_size' : params['beam_size'] } 187 | 188 | Ys = BatchGenerator.predict([{'image':img}], model, checkpoint_params, **kwparams) 189 | 190 | img_blob = {} # we will build this up 191 | img_blob['img_path'] = img['local_file_path'] 192 | img_blob['imgid'] = img['imgid'] 193 | 194 | if dump_folder: 195 | # copy source file to some folder. 
This makes it easier to distribute results 196 | # into a webpage, because all images that were predicted on are in a single folder 197 | source_file = img['local_file_path'] 198 | target_file = os.path.join(dump_folder, os.path.basename(img['local_file_path'])) 199 | os.system('cp %s %s' % (source_file, target_file)) 200 | 201 | # encode the human-provided references 202 | img_blob['references'] = [] 203 | for gtsent in references: 204 | print 'GT: ' + gtsent 205 | img_blob['references'].append({'text': gtsent}) 206 | 207 | # now evaluate and encode the top prediction 208 | top_predictions = Ys[0] # take predictions for the first (and only) image we passed in 209 | top_prediction = top_predictions[0] # these are sorted with highest on top 210 | candidate = ' '.join([ixtoword[ix] for ix in top_prediction[1] if ix > 0]) # ix 0 is the END token, skip that 211 | candidates.append(candidate) 212 | print 'PRED: (%f) %s' % (top_prediction[0], candidate) 213 | 214 | # save for later eval 215 | all_references.append(references) 216 | all_candidates.append(candidate) 217 | 218 | img_blob['candidate'] = {'text': candidate, 'logprob': top_prediction[0]} 219 | blob['imgblobs'].append(img_blob) 220 | 221 | # use perl script to eval BLEU score for fair comparison to other research work 222 | # first write intermediate files 223 | print 'writing intermediate files into eval/' 224 | open('eval/output', 'w').write('\n'.join(all_candidates)) 225 | for q in xrange(1): 226 | open('eval/reference'+`q`, 'w').write('\n'.join([x[q] for x in all_references])) 227 | # invoke the perl script to get BLEU scores 228 | print 'invoking eval/multi-bleu.perl script...' 229 | owd = os.getcwd() 230 | os.chdir('eval') 231 | os.system('./multi-bleu.perl reference < output') 232 | os.chdir(owd) 233 | 234 | # now also evaluate test split perplexity 235 | gtppl = eval_split('test', dp, model, checkpoint_params, misc, eval_max_images = max_images) 236 | print 'perplexity of ground truth words based on dictionary of %d words: %f' % (len(ixtoword), gtppl) 237 | blob['gtppl'] = gtppl 238 | 239 | # dump result struct to file 240 | # print 'saving result struct to %s' % (params['result_struct_filename'], ) 241 | # json.dump(blob, open(params['result_struct_filename'], 'w')) 242 | 243 | for idx,c in enumerate(candidates): 244 | cs = c.split() 245 | for e in cs: 246 | es=e.split(';') 247 | pitch=int(es[0]) 248 | pos=es[1] 249 | pos=convert_pos(pos,idx) 250 | dur=es[2] 251 | dur=convert_dur(dur) 252 | note=pretty_midi.Note(90,pitch,pos,pos+dur) 253 | new_track.notes.append(note) 254 | 255 | new_midi_data = pretty_midi.PrettyMIDI(initial_tempo=tempo) 256 | new_midi_data.instruments.append(new_track) 257 | 258 | # pre-set chord preogression 259 | bass_track.notes.append(pretty_midi.Note(90,36,0,1)) 260 | bass_track.notes.append(pretty_midi.Note(90,47,1,2)) 261 | bass_track.notes.append(pretty_midi.Note(90,45,2,3)) 262 | bass_track.notes.append(pretty_midi.Note(90,43,3,4)) 263 | bass_track.notes.append(pretty_midi.Note(90,41,4,5)) 264 | bass_track.notes.append(pretty_midi.Note(90,40,5,6)) 265 | bass_track.notes.append(pretty_midi.Note(90,38,6,7)) 266 | bass_track.notes.append(pretty_midi.Note(90,43,7,8)) 267 | 268 | bass_track.notes.append(pretty_midi.Note(90,36,8,9)) 269 | bass_track.notes.append(pretty_midi.Note(90,47,9,10)) 270 | bass_track.notes.append(pretty_midi.Note(90,45,10,11)) 271 | bass_track.notes.append(pretty_midi.Note(90,43,11,12)) 272 | bass_track.notes.append(pretty_midi.Note(90,41,12,13)) 273 | 
bass_track.notes.append(pretty_midi.Note(90,40,13,14)) 274 | bass_track.notes.append(pretty_midi.Note(90,38,14,15)) 275 | bass_track.notes.append(pretty_midi.Note(90,43,15,16)) 276 | 277 | bass_track.notes.append(pretty_midi.Note(90,45,16,17)) 278 | bass_track.notes.append(pretty_midi.Note(90,41,17,18)) 279 | bass_track.notes.append(pretty_midi.Note(90,36,18,19)) 280 | bass_track.notes.append(pretty_midi.Note(90,43,19,20)) 281 | bass_track.notes.append(pretty_midi.Note(90,45,20,21)) 282 | bass_track.notes.append(pretty_midi.Note(90,41,21,22)) 283 | bass_track.notes.append(pretty_midi.Note(90,43,22,23)) 284 | bass_track.notes.append(pretty_midi.Note(90,43,23,24)) 285 | 286 | bass_track.notes.append(pretty_midi.Note(90,36,24,25)) 287 | bass_track.notes.append(pretty_midi.Note(90,47,25,26)) 288 | bass_track.notes.append(pretty_midi.Note(90,45,26,27)) 289 | bass_track.notes.append(pretty_midi.Note(90,43,27,28)) 290 | bass_track.notes.append(pretty_midi.Note(90,41,28,29)) 291 | bass_track.notes.append(pretty_midi.Note(90,40,29,30)) 292 | bass_track.notes.append(pretty_midi.Note(90,38,30,31)) 293 | bass_track.notes.append(pretty_midi.Note(90,43,31,32)) 294 | 295 | bass_track.notes.append(pretty_midi.Note(90,36,32,33)) 296 | bass_track.notes.append(pretty_midi.Note(90,47,33,34)) 297 | bass_track.notes.append(pretty_midi.Note(90,45,34,35)) 298 | bass_track.notes.append(pretty_midi.Note(90,43,35,36)) 299 | bass_track.notes.append(pretty_midi.Note(90,41,36,37)) 300 | bass_track.notes.append(pretty_midi.Note(90,40,37,38)) 301 | bass_track.notes.append(pretty_midi.Note(90,38,38,39)) 302 | bass_track.notes.append(pretty_midi.Note(90,43,39,40)) 303 | 304 | new_midi_data.instruments.append(bass_track) 305 | adjust_tempo(new_midi_data) 306 | if params['quantize']: 307 | quantize(new_midi_data) 308 | new_midi_data.write(fout) 309 | 310 | 311 | def main(params): 312 | if params["gen_chords"]: 313 | gen_from_scratch(params) 314 | else: 315 | gen_from_test(params) 316 | 317 | if __name__ == "__main__": 318 | 319 | parser = argparse.ArgumentParser() 320 | parser.add_argument('checkpoint_path', type=str, help='the input checkpoint') 321 | parser.add_argument('-b', '--beam_size', type=int, default=1, help='beam size in inference. 1 indicates greedy per-word max procedure. Good value is approx 20 or so, and more = better.') 322 | parser.add_argument('-m', '--max_images', type=int, default=-1, help='max images to use') 323 | parser.add_argument('-d', '--dump_folder', type=str, default="", help='dump the relevant images to a separate folder with this name?') 324 | parser.add_argument('-o', '--output_file', type=str, default="generate_test.mid",help='name of the midi file generated') 325 | 326 | # Chords sequence generation ? 
327 |     parser.add_argument('--gen_chords', type=bool, default=False, help='whether the chords and parts are automatically generated or picked from the test set')
328 |     parser.add_argument('--gen_seq_length', type=int, default=8, help='length of the generated sequences')
329 |     parser.add_argument('--nh_part', dest='nh_part', type=int, default=20, help='number of hidden states for the part\'s HMM')
330 |     parser.add_argument('--nh_chords', dest='nh_chords', type=int, default=40, help='number of hidden states for the chords\' HMM')
331 |     parser.add_argument('--num_gen', dest='num_gen', type=int, default=1, help='number of sequences generated')
332 |     parser.add_argument('--quantize', dest='quantize', type=int, default=0, help='1 to snap note boundaries to a 1/16 grid, 0 to keep real-valued times')
333 |     parser.add_argument('--tempo', dest='tempo', type=int, default=120, help='beats per minute')
334 |
335 |     parser.add_argument('--DEBUG', type=bool, default=False, help='debug mode')
336 |
337 |     args = parser.parse_args()
338 |     params = vars(args) # convert to ordinary dict
339 |     print 'parsed parameters:'
340 |     print json.dumps(params, indent = 2)
341 |     main(params)
342 |
--------------------------------------------------------------------------------
/imagernn/Readme.md:
--------------------------------------------------------------------------------
1 | The code is organized as follows:
2 |
3 | - `data_provider.py` abstracts away the datasets and provides a uniform API for the rest of the code.
4 | - `utils.py` is what it sounds like it is :)
5 | - `solver.py`: the solver class doesn't know anything about images or sentences; it gets a model and the gradients and performs a step update
6 | - `generic_batch_generator.py` handles batching across image/sentence pairs that need to be forwarded through the networks. It calls
7 | - `lstm_generator.py`, which is an implementation of the Google LSTM for generating sentences.
8 | - `imagernn_utils.py` contains some image-rnn specific utilities, such as the evaluation function. These come in handy when we want to share functionality across different scripts (e.g. driver and evaluator)
9 | - `rnn_generator.py` has a simple RNN implementation, an alternative to the LSTM
10 |
--------------------------------------------------------------------------------
/imagernn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mil-tokyo/NeuralMelody/f5ac26c8acb01d167e602deefb65d464f94850b8/imagernn/__init__.py
--------------------------------------------------------------------------------
/imagernn/data_provider.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import scipy.io
5 | import codecs
6 | from collections import defaultdict
7 |
8 | class BasicDataProvider:
9 |   def __init__(self, dataset):
10 |     print 'Initializing data provider for dataset %s...'
% (dataset, ) 11 | 12 | # !assumptions on folder structure 13 | self.dataset_root = os.path.join('data', dataset) 14 | self.image_root = os.path.join('data', dataset, 'imgs') 15 | 16 | # load the dataset into memory 17 | dataset_path = os.path.join(self.dataset_root, 'data.json') 18 | print 'BasicDataProvider: reading %s' % (dataset_path, ) 19 | self.dataset = json.load(open(dataset_path, 'r')) 20 | 21 | # load the image features into memory 22 | features_path = os.path.join(self.dataset_root, 'data.mat') 23 | print 'BasicDataProvider: reading %s' % (features_path, ) 24 | features_struct = scipy.io.loadmat(features_path) 25 | self.features = features_struct['feats'] 26 | 27 | # group images by their train/val/test split into a dictionary -> list structure 28 | self.split = defaultdict(list) 29 | for img in self.dataset['images']: 30 | self.split[img['split']].append(img) 31 | 32 | # "PRIVATE" FUNCTIONS 33 | # in future we may want to create copies here so that we don't touch the 34 | # data provider class data, but for now lets do the simple thing and 35 | # just return raw internal img sent structs. This also has the advantage 36 | # that the driver could store various useful caching stuff in these structs 37 | # and they will be returned in the future with the cache present 38 | def _getImage(self, img): 39 | """ create an image structure for the driver """ 40 | 41 | # lazily fill in some attributes 42 | if not 'local_file_path' in img: img['local_file_path'] = os.path.join(self.image_root, img['filename']) 43 | if not 'feat' in img: # also fill in the features 44 | feature_index = img['imgid'] # NOTE: imgid is an integer, and it indexes into features 45 | img['feat'] = self.features[:,feature_index] 46 | return img 47 | 48 | def _getSentence(self, sent): 49 | """ create a sentence structure for the driver """ 50 | # NOOP for now 51 | return sent 52 | 53 | # PUBLIC FUNCTIONS 54 | 55 | def getSplitSize(self, split, ofwhat = 'sentences'): 56 | """ return size of a split, either number of sentences or number of images """ 57 | if ofwhat == 'sentences': 58 | return sum(len(img['sentences']) for img in self.split[split]) 59 | else: # assume images 60 | return len(self.split[split]) 61 | 62 | def sampleImageSentencePair(self, split = 'train'): 63 | """ sample image sentence pair from a split """ 64 | images = self.split[split] 65 | 66 | img = random.choice(images) 67 | sent = random.choice(img['sentences']) 68 | 69 | out = {} 70 | out['image'] = self._getImage(img) 71 | out['sentence'] = self._getSentence(sent) 72 | return out 73 | 74 | def iterImageSentencePair(self, split = 'train', max_images = -1): 75 | for i,img in enumerate(self.split[split]): 76 | if max_images >= 0 and i >= max_images: break 77 | for sent in img['sentences']: 78 | out = {} 79 | out['image'] = self._getImage(img) 80 | out['sentence'] = self._getSentence(sent) 81 | yield out 82 | 83 | def iterImageSentencePairBatch(self, split = 'train', max_images = -1, max_batch_size = 100): 84 | batch = [] 85 | for i,img in enumerate(self.split[split]): 86 | if max_images >= 0 and i >= max_images: break 87 | for sent in img['sentences']: 88 | out = {} 89 | out['image'] = self._getImage(img) 90 | out['sentence'] = self._getSentence(sent) 91 | batch.append(out) 92 | if len(batch) >= max_batch_size: 93 | yield batch 94 | batch = [] 95 | if batch: 96 | yield batch 97 | 98 | def iterSentences(self, split = 'train'): 99 | for img in self.split[split]: 100 | for sent in img['sentences']: 101 | yield self._getSentence(sent) 102 | 103 | def 
iterImages(self, split = 'train', shuffle = False, max_images = -1):
104 |     imglist = self.split[split]
105 |     ix = range(len(imglist))
106 |     if shuffle:
107 |       random.shuffle(ix)
108 |     if max_images > 0:
109 |       ix = ix[:min(len(ix),max_images)] # crop the list
110 |     for i in ix:
111 |       yield self._getImage(imglist[i])
112 |
113 | def getDataProvider(dataset):
114 |   """ we could intercept a special dataset and return different data providers """
115 |   assert dataset in ['flickr8k', 'flickr30k', 'coco','music'], 'dataset %s unknown' % (dataset, )
116 |   return BasicDataProvider(dataset)
117 |
--------------------------------------------------------------------------------
/imagernn/generic_batch_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import code
3 | from imagernn.utils import merge_init_structs, initw, accumNpDicts
4 | from imagernn.lstm_generator import LSTMGenerator
5 | from imagernn.rnn_generator import RNNGenerator
6 |
7 | def decodeGenerator(generator):
8 |   if generator == 'lstm':
9 |     return LSTMGenerator
10 |   elif generator == 'rnn':
11 |     return RNNGenerator
12 |   else:
13 |     raise Exception('generator %s is not yet supported' % (generator,))
14 |
15 | class GenericBatchGenerator:
16 |   """
17 |   Base batch generator class.
18 |   This class is aware of the fact that we are generating
19 |   sentences from images.
20 |   """
21 |
22 |   @staticmethod
23 |   def init(params, misc):
24 |
25 |     # inputs
26 |     image_encoding_size = params.get('image_encoding_size', 128)
27 |     word_encoding_size = params.get('word_encoding_size', 128)
28 |     hidden_size = params.get('hidden_size', 128)
29 |     generator = params.get('generator', 'lstm')
30 |     vocabulary_size = len(misc['wordtoix'])
31 |     output_size = len(misc['ixtoword']) # these should match though
32 |     image_size = 60 # size of the two-hot input vectors (56 chords + 4 parts), hardcoded here
33 |
34 |     if generator == 'lstm':
35 |       assert image_encoding_size == word_encoding_size, 'this implementation does not support different sizes for these parameters'
36 |
37 |     # initialize the encoder models
38 |     model = {}
39 |     model['We'] = initw(image_size, image_encoding_size) # image encoder
40 |     model['be'] = np.zeros((1,image_encoding_size))
41 |     model['Ws'] = initw(vocabulary_size, word_encoding_size) # word encoder
42 |     update = ['We', 'be', 'Ws']
43 |     regularize = ['We', 'Ws']
44 |     init_struct = { 'model' : model, 'update' : update, 'regularize' : regularize}
45 |
46 |     # descend into the specific Generator and initialize it
47 |     Generator = decodeGenerator(generator)
48 |     generator_init_struct = Generator.init(word_encoding_size, hidden_size, output_size)
49 |     merge_init_structs(init_struct, generator_init_struct)
50 |     return init_struct
51 |
52 |   @staticmethod
53 |   def forward(batch, model, params, misc, predict_mode = False):
54 |     """ iterates over items in the batch and calls generators on them """
55 |     # we do the encoding here across all images/words in batch in single matrix
56 |     # multiplies to gain efficiency. The RNNs are then called individually
57 |     # in for loop on per-image-sentence pair and all they are concerned about is
58 |     # taking single matrix of vectors and doing the forward/backward pass without
59 |     # knowing anything about images, sentences or anything of that sort.
60 |
61 |     # encode all images
62 |     # concatenate as rows.
If N is number of image-sentence pairs, 63 | # F will be N x image_size 64 | F = np.row_stack(x['image']['feat'] for x in batch) 65 | We = model['We'] 66 | be = model['be'] 67 | Xe = F.dot(We) + be # Xe becomes N x image_encoding_size 68 | 69 | # decode the generator we wish to use 70 | generator_str = params.get('generator', 'lstm') 71 | Generator = decodeGenerator(generator_str) 72 | 73 | # encode all words in all sentences (which exist in our vocab) 74 | wordtoix = misc['wordtoix'] 75 | Ws = model['Ws'] 76 | gen_caches = [] 77 | Ys = [] # outputs 78 | for i,x in enumerate(batch): 79 | # take all words in this sentence and pluck out their word vectors 80 | # from Ws. Then arrange them in a single matrix Xs 81 | # Note that we are setting the start token as first vector 82 | # and then all the words afterwards. And start token is the first row of Ws 83 | ix = [0] + [ wordtoix[w] for w in x['sentence']['tokens'] if w in wordtoix ] 84 | Xs = np.row_stack( [Ws[j, :] for j in ix] ) 85 | Xi = Xe[i,:] 86 | 87 | # forward prop through the RNN 88 | gen_Y, gen_cache = Generator.forward(Xi, Xs, model, params, predict_mode = predict_mode) 89 | gen_caches.append((ix, gen_cache)) 90 | Ys.append(gen_Y) 91 | 92 | # back up information we need for efficient backprop 93 | cache = {} 94 | if not predict_mode: 95 | # ok we need cache as well because we'll do backward pass 96 | cache['gen_caches'] = gen_caches 97 | cache['Xe'] = Xe 98 | cache['Ws_shape'] = Ws.shape 99 | cache['F'] = F 100 | cache['generator_str'] = generator_str 101 | 102 | return Ys, cache 103 | 104 | @staticmethod 105 | def backward(dY, cache): 106 | Xe = cache['Xe'] 107 | generator_str = cache['generator_str'] 108 | dWs = np.zeros(cache['Ws_shape']) 109 | gen_caches = cache['gen_caches'] 110 | F = cache['F'] 111 | dXe = np.zeros(Xe.shape) 112 | 113 | Generator = decodeGenerator(generator_str) 114 | 115 | # backprop each item in the batch 116 | grads = {} 117 | for i in xrange(len(gen_caches)): 118 | ix, gen_cache = gen_caches[i] # unpack 119 | local_grads = Generator.backward(dY[i], gen_cache) 120 | dXs = local_grads['dXs'] # intercept the gradients wrt Xi and Xs 121 | del local_grads['dXs'] 122 | dXi = local_grads['dXi'] 123 | del local_grads['dXi'] 124 | accumNpDicts(grads, local_grads) # add up the gradients wrt model parameters 125 | 126 | # now backprop from dXs to the image vector and word vectors 127 | dXe[i,:] += dXi # image vector 128 | for n,j in enumerate(ix): # and now all the other words 129 | dWs[j,:] += dXs[n,:] 130 | 131 | # finally backprop into the image encoder 132 | dWe = F.transpose().dot(dXe) 133 | dbe = np.sum(dXe, axis=0, keepdims = True) 134 | 135 | accumNpDicts(grads, { 'We':dWe, 'be':dbe, 'Ws':dWs }) 136 | return grads 137 | 138 | @staticmethod 139 | def predict(batch, model, params, **kwparams): 140 | """ some code duplication here with forward pass, but I think we want the freedom in future """ 141 | F = np.row_stack(x['image']['feat'] for x in batch) 142 | We = model['We'] 143 | be = model['be'] 144 | Xe = F.dot(We) + be # Xe becomes N x image_encoding_size 145 | generator_str = params['generator'] 146 | Generator = decodeGenerator(generator_str) 147 | Ys = [] 148 | for i,x in enumerate(batch): 149 | gen_Y = Generator.predict(Xe[i, :], model, model['Ws'], params, **kwparams) 150 | Ys.append(gen_Y) 151 | return Ys 152 | 153 | 154 | -------------------------------------------------------------------------------- /imagernn/imagernn_utils.py: 
-------------------------------------------------------------------------------- 1 | from imagernn.generic_batch_generator import GenericBatchGenerator 2 | import numpy as np 3 | 4 | def decodeGenerator(params): 5 | """ 6 | in the future we may want to have different classes 7 | and options for them. For now there is this one generator 8 | implemented and simply returned here. 9 | """ 10 | return GenericBatchGenerator 11 | 12 | def eval_split(split, dp, model, params, misc, **kwargs): 13 | """ evaluate performance on a given split """ 14 | # allow kwargs to override what is inside params 15 | eval_batch_size = kwargs.get('eval_batch_size', params.get('eval_batch_size',100)) 16 | eval_max_images = kwargs.get('eval_max_images', params.get('eval_max_images', -1)) 17 | BatchGenerator = decodeGenerator(params) 18 | wordtoix = misc['wordtoix'] 19 | 20 | print 'evaluating %s performance in batches of %d' % (split, eval_batch_size) 21 | logppl = 0 22 | logppln = 0 23 | nsent = 0 24 | for batch in dp.iterImageSentencePairBatch(split = split, max_batch_size = eval_batch_size, max_images = eval_max_images): 25 | Ys, gen_caches = BatchGenerator.forward(batch, model, params, misc, predict_mode = True) 26 | 27 | for i,pair in enumerate(batch): 28 | gtix = [ wordtoix[w] for w in pair['sentence']['tokens'] if w in wordtoix ] 29 | gtix.append(0) # we expect END token at the end 30 | Y = Ys[i] 31 | maxes = np.amax(Y, axis=1, keepdims=True) 32 | e = np.exp(Y - maxes) # for numerical stability shift into good numerical range 33 | P = e / np.sum(e, axis=1, keepdims=True) 34 | logppl += - np.sum(np.log2(1e-20 + P[range(len(gtix)),gtix])) # also accumulate log2 perplexities 35 | logppln += len(gtix) 36 | nsent += 1 37 | 38 | ppl2 = 2 ** (logppl / logppln) 39 | print 'evaluated %d sentences and got perplexity = %f' % (nsent, ppl2) 40 | return ppl2 # return the perplexity 41 | -------------------------------------------------------------------------------- /imagernn/lstm_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import code 3 | 4 | from imagernn.utils import initw 5 | 6 | class LSTMGenerator: 7 | """ 8 | A multimodal long short-term memory (LSTM) generator 9 | """ 10 | 11 | @staticmethod 12 | def init(input_size, hidden_size, output_size): 13 | 14 | model = {} 15 | # Recurrent weights: take x_t, h_{t-1}, and bias unit 16 | # and produce the 3 gates and the input to cell signal 17 | model['WLSTM'] = initw(input_size + hidden_size + 1, 4 * hidden_size) 18 | # Decoder weights (e.g. mapping to vocabulary) 19 | model['Wd'] = initw(hidden_size, output_size) # decoder 20 | model['bd'] = np.zeros((1, output_size)) 21 | 22 | update = ['WLSTM', 'Wd', 'bd'] 23 | regularize = ['WLSTM', 'Wd'] 24 | return { 'model' : model, 'update' : update, 'regularize' : regularize } 25 | 26 | @staticmethod 27 | def forward(Xi, Xs, model, params, **kwargs): 28 | """ 29 | Xi is 1-d array of size D (containing the image representation) 30 | Xs is N x D (N time steps, rows are data containng word representations), and 31 | it is assumed that the first row is already filled in as the start token. So a 32 | sentence with 10 words will be of size 11xD in Xs. 33 | """ 34 | predict_mode = kwargs.get('predict_mode', False) 35 | 36 | # Google paper concatenates the image to the word vectors as the first word vector 37 | X = np.row_stack([Xi, Xs]) 38 | 39 | # options 40 | # use the version of LSTM with tanh? 
Otherwise dont use tanh (Google style) 41 | # following http://arxiv.org/abs/1409.3215 42 | tanhC_version = params.get('tanhC_version', 0) 43 | drop_prob_encoder = params.get('drop_prob_encoder', 0.0) 44 | drop_prob_decoder = params.get('drop_prob_decoder', 0.0) 45 | 46 | if drop_prob_encoder > 0: # if we want dropout on the encoder 47 | # inverted version of dropout here. Suppose the drop_prob is 0.5, then during training 48 | # we are going to drop half of the units. In this inverted version we also boost the activations 49 | # of the remaining 50% by 2.0 (scale). The nice property of this is that during prediction time 50 | # we don't have to do any scailing, since all 100% of units will be active, but at their base 51 | # firing rate, giving 100% of the "energy". So the neurons later in the pipeline dont't change 52 | # their expected firing rate magnitudes 53 | if not predict_mode: # and we are in training mode 54 | scale = 1.0 / (1.0 - drop_prob_encoder) 55 | U = (np.random.rand(*(X.shape)) < (1 - drop_prob_encoder)) * scale # generate scaled mask 56 | X *= U # drop! 57 | 58 | # follows http://arxiv.org/pdf/1409.2329.pdf 59 | WLSTM = model['WLSTM'] 60 | n = X.shape[0] 61 | d = model['Wd'].shape[0] # size of hidden layer 62 | Hin = np.zeros((n, WLSTM.shape[0])) # xt, ht-1, bias 63 | Hout = np.zeros((n, d)) 64 | IFOG = np.zeros((n, d * 4)) 65 | IFOGf = np.zeros((n, d * 4)) # after nonlinearity 66 | C = np.zeros((n, d)) 67 | for t in xrange(n): 68 | # set input 69 | prev = np.zeros(d) if t == 0 else Hout[t-1] 70 | Hin[t,0] = 1 71 | Hin[t,1:1+d] = X[t] 72 | Hin[t,1+d:] = prev 73 | 74 | # compute all gate activations. dots: 75 | IFOG[t] = Hin[t].dot(WLSTM) 76 | 77 | # non-linearities 78 | IFOGf[t,:3*d] = 1.0/(1.0+np.exp(-IFOG[t,:3*d])) # sigmoids; these are the gates 79 | IFOGf[t,3*d:] = np.tanh(IFOG[t, 3*d:]) # tanh 80 | 81 | # compute the cell activation 82 | C[t] = IFOGf[t,:d] * IFOGf[t, 3*d:] 83 | if t > 0: C[t] += IFOGf[t,d:2*d] * C[t-1] 84 | if tanhC_version: 85 | Hout[t] = IFOGf[t,2*d:3*d] * np.tanh(C[t]) 86 | else: 87 | Hout[t] = IFOGf[t,2*d:3*d] * C[t] 88 | 89 | if drop_prob_decoder > 0: # if we want dropout on the decoder 90 | if not predict_mode: # and we are in training mode 91 | scale2 = 1.0 / (1.0 - drop_prob_decoder) 92 | U2 = (np.random.rand(*(Hout.shape)) < (1 - drop_prob_decoder)) * scale2 # generate scaled mask 93 | Hout *= U2 # drop! 94 | 95 | # decoder at the end 96 | Wd = model['Wd'] 97 | bd = model['bd'] 98 | # NOTE1: we are leaving out the first prediction, which was made for the image 99 | # and is meaningless. 
100 | Y = Hout[1:, :].dot(Wd) + bd 101 | 102 | cache = {} 103 | if not predict_mode: 104 | # we can expect to do a backward pass 105 | cache['WLSTM'] = WLSTM 106 | cache['Hout'] = Hout 107 | cache['Wd'] = Wd 108 | cache['IFOGf'] = IFOGf 109 | cache['IFOG'] = IFOG 110 | cache['C'] = C 111 | cache['X'] = X 112 | cache['Hin'] = Hin 113 | cache['tanhC_version'] = tanhC_version 114 | cache['drop_prob_encoder'] = drop_prob_encoder 115 | cache['drop_prob_decoder'] = drop_prob_decoder 116 | if drop_prob_encoder > 0: cache['U'] = U # keep the dropout masks around for backprop 117 | if drop_prob_decoder > 0: cache['U2'] = U2 118 | 119 | return Y, cache 120 | 121 | @staticmethod 122 | def backward(dY, cache): 123 | 124 | Wd = cache['Wd'] 125 | Hout = cache['Hout'] 126 | IFOG = cache['IFOG'] 127 | IFOGf = cache['IFOGf'] 128 | C = cache['C'] 129 | Hin = cache['Hin'] 130 | WLSTM = cache['WLSTM'] 131 | X = cache['X'] 132 | tanhC_version = cache['tanhC_version'] 133 | drop_prob_encoder = cache['drop_prob_encoder'] 134 | drop_prob_decoder = cache['drop_prob_decoder'] 135 | n,d = Hout.shape 136 | 137 | # we have to add back a row of zeros, since in the forward pass 138 | # this information was not used. See NOTE1 above. 139 | dY = np.row_stack([np.zeros(dY.shape[1]), dY]) 140 | 141 | # backprop the decoder 142 | dWd = Hout.transpose().dot(dY) 143 | dbd = np.sum(dY, axis=0, keepdims = True) 144 | dHout = dY.dot(Wd.transpose()) 145 | 146 | # backprop dropout, if it was applied 147 | if drop_prob_decoder > 0: 148 | dHout *= cache['U2'] 149 | 150 | # backprop the LSTM 151 | dIFOG = np.zeros(IFOG.shape) 152 | dIFOGf = np.zeros(IFOGf.shape) 153 | dWLSTM = np.zeros(WLSTM.shape) 154 | dHin = np.zeros(Hin.shape) 155 | dC = np.zeros(C.shape) 156 | dX = np.zeros(X.shape) 157 | for t in reversed(xrange(n)): 158 | 159 | if tanhC_version: 160 | tanhCt = np.tanh(C[t]) # recompute this here 161 | dIFOGf[t,2*d:3*d] = tanhCt * dHout[t] 162 | # backprop tanh non-linearity first then continue backprop 163 | dC[t] += (1-tanhCt**2) * (IFOGf[t,2*d:3*d] * dHout[t]) 164 | else: 165 | dIFOGf[t,2*d:3*d] = C[t] * dHout[t] 166 | dC[t] += IFOGf[t,2*d:3*d] * dHout[t] 167 | 168 | if t > 0: 169 | dIFOGf[t,d:2*d] = C[t-1] * dC[t] 170 | dC[t-1] += IFOGf[t,d:2*d] * dC[t] 171 | dIFOGf[t,:d] = IFOGf[t, 3*d:] * dC[t] 172 | dIFOGf[t, 3*d:] = IFOGf[t,:d] * dC[t] 173 | 174 | # backprop activation functions 175 | dIFOG[t,3*d:] = (1 - IFOGf[t, 3*d:] ** 2) * dIFOGf[t,3*d:] 176 | y = IFOGf[t,:3*d] 177 | dIFOG[t,:3*d] = (y*(1.0-y)) * dIFOGf[t,:3*d] 178 | 179 | # backprop matrix multiply 180 | dWLSTM += np.outer(Hin[t], dIFOG[t]) 181 | dHin[t] = dIFOG[t].dot(WLSTM.transpose()) 182 | 183 | # backprop the identity transforms into Hin 184 | dX[t] = dHin[t,1:1+d] 185 | if t > 0: 186 | dHout[t-1] += dHin[t,1+d:] 187 | 188 | if drop_prob_encoder > 0: # backprop encoder dropout 189 | dX *= cache['U'] 190 | 191 | return { 'WLSTM': dWLSTM, 'Wd': dWd, 'bd': dbd, 'dXi': dX[0,:], 'dXs': dX[1:,:] } 192 | 193 | @staticmethod 194 | def predict(Xi, model, Ws, params, **kwargs): 195 | """ 196 | Run in prediction mode with beam search. The input is the vector Xi, which 197 | should be a 1-D array that contains the encoded image vector. We go from there. 198 | Ws should be NxD array where N is size of vocabulary + 1. So there should be exactly 199 | as many rows in Ws as there are outputs in the decoder Y. We are passing in Ws like 200 | this because we may not want it to be exactly model['Ws']. For example it could be 201 | fixed word vectors from somewhere else. 
202 | """ 203 | tanhC_version = params['tanhC_version'] 204 | beam_size = kwargs.get('beam_size', 1) 205 | 206 | WLSTM = model['WLSTM'] 207 | d = model['Wd'].shape[0] # size of hidden layer 208 | Wd = model['Wd'] 209 | bd = model['bd'] 210 | 211 | # lets define a helper function that does a single LSTM tick 212 | def LSTMtick(x, h_prev, c_prev): 213 | t = 0 214 | 215 | # setup the input vector 216 | Hin = np.zeros((1,WLSTM.shape[0])) # xt, ht-1, bias 217 | Hin[t,0] = 1 218 | Hin[t,1:1+d] = x 219 | Hin[t,1+d:] = h_prev 220 | 221 | # LSTM tick forward 222 | IFOG = np.zeros((1, d * 4)) 223 | IFOGf = np.zeros((1, d * 4)) 224 | C = np.zeros((1, d)) 225 | Hout = np.zeros((1, d)) 226 | IFOG[t] = Hin[t].dot(WLSTM) 227 | IFOGf[t,:3*d] = 1.0/(1.0+np.exp(-IFOG[t,:3*d])) 228 | IFOGf[t,3*d:] = np.tanh(IFOG[t, 3*d:]) 229 | C[t] = IFOGf[t,:d] * IFOGf[t, 3*d:] + IFOGf[t,d:2*d] * c_prev 230 | if tanhC_version: 231 | Hout[t] = IFOGf[t,2*d:3*d] * np.tanh(C[t]) 232 | else: 233 | Hout[t] = IFOGf[t,2*d:3*d] * C[t] 234 | Y = Hout.dot(Wd) + bd 235 | return (Y, Hout, C) # return output, new hidden, new cell 236 | 237 | # forward prop the image 238 | (y0, h, c) = LSTMtick(Xi, np.zeros(d), np.zeros(d)) 239 | 240 | # perform BEAM search. NOTE: I am not very confident in this implementation since I don't have 241 | # a lot of experience with these models. This implements my current understanding but I'm not 242 | # sure how to handle beams that predict END tokens. TODO: research this more. 243 | if beam_size > 1: 244 | # log probability, indices of words predicted in this beam so far, and the hidden and cell states 245 | beams = [(0.0, [], h, c)] 246 | nsteps = 0 247 | while True: 248 | beam_candidates = [] 249 | for b in beams: 250 | ixprev = b[1][-1] if b[1] else 0 # start off with the word where this beam left off 251 | if ixprev == 0 and b[1]: 252 | # this beam predicted end token. Keep in the candidates but don't expand it out any more 253 | beam_candidates.append(b) 254 | continue 255 | (y1, h1, c1) = LSTMtick(Ws[ixprev], b[2], b[3]) 256 | y1 = y1.ravel() # make into 1D vector 257 | maxy1 = np.amax(y1) 258 | e1 = np.exp(y1 - maxy1) # for numerical stability shift into good numerical range 259 | p1 = e1 / np.sum(e1) 260 | y1 = np.log(1e-20 + p1) # and back to log domain 261 | top_indices = np.argsort(-y1) # we do -y because we want decreasing order 262 | for i in xrange(beam_size): 263 | wordix = top_indices[i] 264 | beam_candidates.append((b[0] + y1[wordix], b[1] + [wordix], h1, c1)) 265 | beam_candidates.sort(reverse = True) # decreasing order 266 | beams = beam_candidates[:beam_size] # truncate to get new beams 267 | nsteps += 1 268 | if nsteps >= 20: # bad things are probably happening, break out 269 | break 270 | # strip the intermediates 271 | predictions = [(b[0], b[1]) for b in beams] 272 | else: 273 | # greedy inference. 
lets write it up independently, should be bit faster and simpler 274 | ixprev = 0 275 | nsteps = 0 276 | predix = [] 277 | predlogprob = 0.0 278 | while True: 279 | (y1, h, c) = LSTMtick(Ws[ixprev], h, c) 280 | ixprev, ixlogprob = ymax(y1) 281 | predix.append(ixprev) 282 | predlogprob += ixlogprob 283 | nsteps += 1 284 | if ixprev == 0 or nsteps >= 20: 285 | break 286 | predictions = [(predlogprob, predix)] 287 | 288 | return predictions 289 | 290 | def ymax(y): 291 | """ simple helper function here that takes unnormalized logprobs """ 292 | y1 = y.ravel() # make sure 1d 293 | maxy1 = np.amax(y1) 294 | e1 = np.exp(y1 - maxy1) # for numerical stability shift into good numerical range 295 | p1 = e1 / np.sum(e1) 296 | y1 = np.log(1e-20 + p1) # guard against zero probabilities just in case 297 | ix = np.argmax(y1) 298 | return (ix, y1[ix]) 299 | -------------------------------------------------------------------------------- /imagernn/rnn_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import code 3 | 4 | from imagernn.utils import initw 5 | 6 | class RNNGenerator: 7 | """ 8 | An RNN generator. 9 | This class is as stupid as possible. It gets some conditioning vector, 10 | a sequence of input vectors, and produces a sequence of output vectors 11 | """ 12 | 13 | @staticmethod 14 | def init(input_size, hidden_size, output_size): 15 | 16 | model = {} 17 | # connections to x_t 18 | model['Wxh'] = initw(input_size, hidden_size) 19 | model['bxh'] = np.zeros((1, hidden_size)) 20 | # connections to h_{t-1} 21 | model['Whh'] = initw(hidden_size, hidden_size) 22 | model['bhh'] = np.zeros((1, hidden_size)) 23 | # Decoder weights (e.g. mapping to vocabulary) 24 | model['Wd'] = initw(hidden_size, output_size) * 0.1 # decoder 25 | model['bd'] = np.zeros((1, output_size)) 26 | 27 | update = ['Whh', 'bhh', 'Wxh', 'bxh', 'Wd', 'bd'] 28 | regularize = ['Whh', 'Wxh', 'Wd'] 29 | return { 'model' : model, 'update' : update, 'regularize' : regularize } 30 | 31 | @staticmethod 32 | def forward(Xi, Xs, model, params, **kwargs): 33 | """ 34 | Xi is 1-d array of size D1 (containing the image representation) 35 | Xs is N x D2 (N time steps, rows are data containng word representations), and 36 | it is assumed that the first row is already filled in as the start token. So a 37 | sentence with 10 words will be of size 11xD2 in Xs. 38 | """ 39 | predict_mode = kwargs.get('predict_mode', False) 40 | 41 | # options 42 | drop_prob_encoder = params.get('drop_prob_encoder', 0.0) 43 | drop_prob_decoder = params.get('drop_prob_decoder', 0.0) 44 | relu_encoders = params.get('rnn_relu_encoders', 0) 45 | rnn_feed_once = params.get('rnn_feed_once', 0) 46 | 47 | if drop_prob_encoder > 0: # if we want dropout on the encoder 48 | # inverted version of dropout here. Suppose the drop_prob is 0.5, then during training 49 | # we are going to drop half of the units. In this inverted version we also boost the activations 50 | # of the remaining 50% by 2.0 (scale). The nice property of this is that during prediction time 51 | # we don't have to do any scailing, since all 100% of units will be active, but at their base 52 | # firing rate, giving 100% of the "energy". 
So the neurons later in the pipeline dont't change 53 | # their expected firing rate magnitudes 54 | if not predict_mode: # and we are in training mode 55 | scale = 1.0 / (1.0 - drop_prob_encoder) 56 | Us = (np.random.rand(*(Xs.shape)) < (1 - drop_prob_encoder)) * scale # generate scaled mask 57 | Xs *= Us # drop! 58 | Ui = (np.random.rand(*(Xi.shape)) < (1 - drop_prob_encoder)) * scale 59 | Xi *= Ui # drop! 60 | 61 | # encode input vectors 62 | Wxh = model['Wxh'] 63 | bxh = model['bxh'] 64 | Xsh = Xs.dot(Wxh) + bxh 65 | 66 | if relu_encoders: 67 | Xsh = np.maximum(Xsh, 0) 68 | Xi = np.maximum(Xi, 0) 69 | 70 | # recurrence iteration for the Multimodal RNN similar to one described in Karpathy et al. 71 | d = model['Wd'].shape[0] # size of hidden layer 72 | n = Xs.shape[0] 73 | H = np.zeros((n, d)) # hidden layer representation 74 | Whh = model['Whh'] 75 | bhh = model['bhh'] 76 | for t in xrange(n): 77 | 78 | prev = np.zeros(d) if t == 0 else H[t-1] 79 | if not rnn_feed_once or t == 0: 80 | # feed the image in if feedonce is false. And it it is true, then 81 | # only feed the image in if its the first iteration 82 | H[t] = np.maximum(Xi + Xsh[t] + prev.dot(Whh) + bhh, 0) # also ReLU 83 | else: 84 | H[t] = np.maximum(Xsh[t] + prev.dot(Whh) + bhh, 0) # also ReLU 85 | 86 | if drop_prob_decoder > 0: # if we want dropout on the decoder 87 | if not predict_mode: # and we are in training mode 88 | scale2 = 1.0 / (1.0 - drop_prob_decoder) 89 | U2 = (np.random.rand(*(H.shape)) < (1 - drop_prob_decoder)) * scale2 # generate scaled mask 90 | H *= U2 # drop! 91 | 92 | # decoder at the end 93 | Wd = model['Wd'] 94 | bd = model['bd'] 95 | Y = H.dot(Wd) + bd 96 | 97 | cache = {} 98 | if not predict_mode: 99 | # we can expect to do a backward pass 100 | cache['Whh'] = Whh 101 | cache['H'] = H 102 | cache['Wd'] = Wd 103 | cache['Xs'] = Xs 104 | cache['Xsh'] = Xsh 105 | cache['Wxh'] = Wxh 106 | cache['Xi'] = Xi 107 | cache['relu_encoders'] = relu_encoders 108 | cache['drop_prob_encoder'] = drop_prob_encoder 109 | cache['drop_prob_decoder'] = drop_prob_decoder 110 | cache['rnn_feed_once'] = rnn_feed_once 111 | if drop_prob_encoder > 0: 112 | cache['Us'] = Us # keep the dropout masks around for backprop 113 | cache['Ui'] = Ui 114 | if drop_prob_decoder > 0: cache['U2'] = U2 115 | 116 | return Y, cache 117 | 118 | @staticmethod 119 | def backward(dY, cache): 120 | 121 | Wd = cache['Wd'] 122 | H = cache['H'] 123 | Xs = cache['Xs'] 124 | Xsh = cache['Xsh'] 125 | Whh = cache['Whh'] 126 | Wxh = cache['Wxh'] 127 | Xi = cache['Xi'] 128 | drop_prob_encoder = cache['drop_prob_encoder'] 129 | drop_prob_decoder = cache['drop_prob_decoder'] 130 | relu_encoders = cache['relu_encoders'] 131 | rnn_feed_once = cache['rnn_feed_once'] 132 | n,d = H.shape 133 | 134 | # backprop the decoder 135 | dWd = H.transpose().dot(dY) 136 | dbd = np.sum(dY, axis=0, keepdims = True) 137 | dH = dY.dot(Wd.transpose()) 138 | 139 | # backprop dropout, if it was applied 140 | if drop_prob_decoder > 0: 141 | dH *= cache['U2'] 142 | 143 | # backprop the recurrent connections 144 | dXsh = np.zeros(Xsh.shape) 145 | dXi = np.zeros(d) 146 | dWhh = np.zeros(Whh.shape) 147 | dbhh = np.zeros((1,d)) 148 | for t in reversed(xrange(n)): 149 | dht = (H[t] > 0) * dH[t] # backprop ReLU 150 | 151 | if not rnn_feed_once or t == 0: 152 | dXi += dht # backprop to Xi 153 | 154 | dXsh[t] += dht # backprop to word encodings 155 | dbhh[0] += dht # backprop to bias 156 | 157 | if t > 0: 158 | dH[t-1] += dht.dot(Whh.transpose()) 159 | dWhh += np.outer(H[t-1], dht) 160 | 
161 |     if relu_encoders:
162 |       # backprop relu
163 |       dXsh[Xsh <= 0] = 0
164 |       dXi[Xi <= 0] = 0
165 | 
166 |     # backprop the word encoder
167 |     dWxh = Xs.transpose().dot(dXsh)
168 |     dbxh = np.sum(dXsh, axis=0, keepdims = True)
169 |     dXs = dXsh.dot(Wxh.transpose())
170 | 
171 |     if drop_prob_encoder > 0: # backprop encoder dropout
172 |       dXi *= cache['Ui']
173 |       dXs *= cache['Us']
174 | 
175 |     return { 'Whh': dWhh, 'bhh': dbhh, 'Wd': dWd, 'bd': dbd, 'Wxh':dWxh, 'bxh':dbxh, 'dXs' : dXs, 'dXi': dXi }
176 | 
177 |   @staticmethod
178 |   def predict(Xi, model, Ws, params, **kwargs):
179 | 
180 |     beam_size = kwargs.get('beam_size', 1)
181 |     relu_encoders = params.get('rnn_relu_encoders', 0)
182 |     rnn_feed_once = params.get('rnn_feed_once', 0)
183 | 
184 |     d = model['Wd'].shape[0] # size of hidden layer
185 |     Whh = model['Whh']
186 |     bhh = model['bhh']
187 |     Wd = model['Wd']
188 |     bd = model['bd']
189 |     Wxh = model['Wxh']
190 |     bxh = model['bxh']
191 | 
192 |     if relu_encoders:
193 |       Xi = np.maximum(Xi, 0)
194 | 
195 |     if beam_size > 1:
196 |       # perform beam search
197 |       # NOTE: code duplication here with lstm_generator
198 |       # ideally the beam search would be abstracted away nicely and would take
199 |       # a TICK function or something, but for now let's save time & copy code around. Sorry ;\
200 |       beams = [(0.0, [], np.zeros(d))]
201 |       nsteps = 0
202 |       while True:
203 |         beam_candidates = []
204 |         for b in beams:
205 |           ixprev = b[1][-1] if b[1] else 0
206 |           if ixprev == 0 and b[1]:
207 |             # this beam predicted end token. Keep in the candidates but don't expand it out any more
208 |             beam_candidates.append(b)
209 |             continue
210 |           # tick the RNN for this beam
211 |           Xsh = Ws[ixprev].dot(Wxh) + bxh
212 |           if relu_encoders:
213 |             Xsh = np.maximum(Xsh, 0)
214 | 
215 |           if (not rnn_feed_once) or (not b[1]):
216 |             h1 = np.maximum(Xi + Xsh + b[2].dot(Whh) + bhh, 0)
217 |           else:
218 |             h1 = np.maximum(Xsh + b[2].dot(Whh) + bhh, 0)
219 | 
220 |           y1 = h1.dot(Wd) + bd
221 | 
222 |           # compute new candidates that expand out from this beam
223 |           y1 = y1.ravel() # make into 1D vector
224 |           maxy1 = np.amax(y1)
225 |           e1 = np.exp(y1 - maxy1) # for numerical stability shift into good numerical range
226 |           p1 = e1 / np.sum(e1)
227 |           y1 = np.log(1e-20 + p1) # and back to log domain
228 |           top_indices = np.argsort(-y1) # we do -y because we want decreasing order
229 |           for i in xrange(beam_size):
230 |             wordix = top_indices[i]
231 |             beam_candidates.append((b[0] + y1[wordix], b[1] + [wordix], h1))
232 | 
233 |         beam_candidates.sort(reverse = True) # decreasing order
234 |         beams = beam_candidates[:beam_size] # truncate to get new beams
235 |         nsteps += 1
236 |         if nsteps >= 20: # bad things are probably happening, break out
237 |           break
238 |       # strip the intermediates
239 |       predictions = [(b[0], b[1]) for b in beams]
240 | 
241 |     else:
242 |       ixprev = 0 # start out on start token
243 |       nsteps = 0
244 |       predix = []
245 |       predlogprob = 0.0
246 |       hprev = np.zeros((1, d)) # hidden layer representation
247 |       xsprev = Ws[0] # start token
248 |       while True:
249 |         Xsh = Ws[ixprev].dot(Wxh) + bxh
250 |         if relu_encoders:
251 |           Xsh = np.maximum(Xsh, 0)
252 | 
253 |         if (not rnn_feed_once) or (nsteps == 0):
254 |           ht = np.maximum(Xi + Xsh + hprev.dot(Whh) + bhh, 0)
255 |         else:
256 |           ht = np.maximum(Xsh + hprev.dot(Whh) + bhh, 0)
257 | 
258 |         Y = ht.dot(Wd) + bd
259 |         hprev = ht
260 | 
261 |         ixprev, ixlogprob = ymax(Y)
262 |         predix.append(ixprev)
263 |         predlogprob += ixlogprob
264 | 
265 |         nsteps += 1
266 |         if ixprev == 0 or nsteps >= 20:
267 |           break
268 |       predictions = [(predlogprob, predix)]
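      # (The greedy branch above is equivalent to beam search with beam_size = 1;
      # it is kept separate because it skips the candidate bookkeeping and is faster.)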
269 |     return predictions
270 | 
271 | 
272 | def ymax(y):
273 |   """ simple helper function here that takes unnormalized logprobs """
274 |   y1 = y.ravel() # make sure 1d
275 |   maxy1 = np.amax(y1)
276 |   e1 = np.exp(y1 - maxy1) # for numerical stability shift into good numerical range
277 |   p1 = e1 / np.sum(e1)
278 |   y1 = np.log(1e-20 + p1) # guard against zero probabilities just in case
279 |   ix = np.argmax(y1)
280 |   return (ix, y1[ix])
281 | 
--------------------------------------------------------------------------------
/imagernn/solver.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | from imagernn.utils import randi
4 | 
5 | class Solver:
6 |   """
7 |   solver worries about:
8 |   - different optimization methods, updates, weight decays
9 |   - it can also perform gradient check
10 |   """
11 |   def __init__(self):
12 |     self.step_cache_ = {} # might need this
13 |     self.step_cache2_ = {} # might need this
14 | 
15 |   def step(self, batch, model, cost_function, **kwargs):
16 |     """
17 |     perform a single batch update. Takes as input:
18 |     - batch of data (X)
19 |     - model (W)
20 |     - cost function which takes batch, model
21 |     """
22 | 
23 |     learning_rate = kwargs.get('learning_rate', 0.0)
24 |     update = kwargs.get('update', model.keys())
25 |     grad_clip = kwargs.get('grad_clip', -1)
26 |     solver = kwargs.get('solver', 'vanilla')
27 |     momentum = kwargs.get('momentum', 0)
28 |     smooth_eps = kwargs.get('smooth_eps', 1e-8)
29 |     decay_rate = kwargs.get('decay_rate', 0.999)
30 | 
31 |     if not (solver == 'vanilla' and momentum == 0):
32 |       # lazily make sure we initialize step cache if needed
33 |       for u in update:
34 |         if not u in self.step_cache_:
35 |           self.step_cache_[u] = np.zeros(model[u].shape)
36 |           if solver == 'adadelta':
37 |             self.step_cache2_[u] = np.zeros(model[u].shape) # adadelta needs one more cache
38 | 
39 |     # compute cost and gradient
40 |     cg = cost_function(batch, model)
41 |     cost = cg['cost']
42 |     grads = cg['grad']
43 |     stats = cg['stats']
44 | 
45 |     # clip gradients if needed, simplest possible version
46 |     # todo later: maybe implement the gradient direction conserving version
47 |     if grad_clip > 0:
48 |       for p in update:
49 |         if p in grads:
50 |           grads[p] = np.minimum(grads[p], grad_clip)
51 |           grads[p] = np.maximum(grads[p], -grad_clip)
52 | 
53 |     # perform parameter update
54 |     for p in update:
55 |       if p in grads:
56 | 
57 |         if solver == 'vanilla': # vanilla sgd, optionally with momentum
58 |           if momentum > 0:
59 |             dx = momentum * self.step_cache_[p] - learning_rate * grads[p]
60 |             self.step_cache_[p] = dx
61 |           else:
62 |             dx = - learning_rate * grads[p]
63 | 
64 |         elif solver == 'rmsprop':
65 |           self.step_cache_[p] = self.step_cache_[p] * decay_rate + (1.0 - decay_rate) * grads[p] ** 2
66 |           dx = -(learning_rate * grads[p]) / np.sqrt(self.step_cache_[p] + smooth_eps)
67 | 
68 |         elif solver == 'adagrad':
69 |           self.step_cache_[p] += grads[p] ** 2
70 |           dx = -(learning_rate * grads[p]) / np.sqrt(self.step_cache_[p] + smooth_eps)
71 | 
72 |         elif solver == 'adadelta':
73 |           self.step_cache_[p] = self.step_cache_[p] * decay_rate + (1.0 - decay_rate) * grads[p] ** 2
74 |           dx = - np.sqrt( (self.step_cache2_[p] + smooth_eps) / (self.step_cache_[p] + smooth_eps) ) * grads[p]
75 |           self.step_cache2_[p] = self.step_cache2_[p] * decay_rate + (1.0 - decay_rate) * (dx ** 2)
76 | 
77 |         else:
78 |           raise Exception("solver %s not supported" % (solver, ))
79 | 
80 |         # perform the parameter update
81 |         model[p] += dx
82 | 
83 |     # create output dict and return
84 |     out = {}
85 |     out['cost'] = cost
86 |     out['stats'] = stats
87 |     return out
88 | 
89 |   def gradCheck(self, batch, model, cost_function, **kwargs):
90 |     """
91 |     perform gradient check.
92 |     since gradcheck can be tricky (especially with relus involved)
93 |     this function prints to console for visual inspection
94 |     """
95 | 
96 |     num_checks = kwargs.get('num_checks', 10)
97 |     delta = kwargs.get('delta', 1e-5)
98 |     rel_error_thr_warning = kwargs.get('rel_error_thr_warning', 1e-2)
99 |     rel_error_thr_error = kwargs.get('rel_error_thr_error', 1)
100 | 
101 |     cg = cost_function(batch, model)
102 | 
103 |     print 'running gradient check...'
104 |     for p in model.keys():
105 |       print 'checking gradient on parameter %s of shape %s...' % (p, `model[p].shape`)
106 |       mat = model[p]
107 | 
108 |       s0 = cg['grad'][p].shape
109 |       s1 = mat.shape
110 |       assert s0 == s1, "Error: dims don't match: %s and %s." % (`s0`, `s1`)
111 | 
112 |       for i in xrange(num_checks):
113 |         ri = randi(mat.size)
114 | 
115 |         # evaluate cost at [x + delta] and [x - delta]
116 |         old_val = mat.flat[ri]
117 |         mat.flat[ri] = old_val + delta
118 |         cg0 = cost_function(batch, model)
119 |         mat.flat[ri] = old_val - delta
120 |         cg1 = cost_function(batch, model)
121 |         mat.flat[ri] = old_val # reset old value for this parameter
122 | 
123 |         # fetch both numerical and analytic gradient
124 |         grad_analytic = cg['grad'][p].flat[ri]
125 |         grad_numerical = (cg0['cost']['total_cost'] - cg1['cost']['total_cost']) / ( 2 * delta )
126 | 
127 |         # compare them
128 |         if grad_numerical == 0 and grad_analytic == 0:
129 |           rel_error = 0 # both are zero, OK.
130 |           status = 'OK'
131 |         elif abs(grad_numerical) < 1e-7 and abs(grad_analytic) < 1e-7:
132 |           rel_error = 0 # not enough precision to check this
133 |           status = 'VAL SMALL WARNING'
134 |         else:
135 |           rel_error = abs(grad_analytic - grad_numerical) / abs(grad_numerical + grad_analytic)
136 |           status = 'OK'
137 |           if rel_error > rel_error_thr_warning: status = 'WARNING'
138 |           if rel_error > rel_error_thr_error: status = '!!!!! NOTOK'
139 | 
140 |         # print stats
141 |         print '%s checking param %s index %8d (val = %+8f), analytic = %+8f, numerical = %+8f, relative error = %+8f' \
142 |               % (status, p, ri, old_val, grad_analytic, grad_numerical, rel_error)
143 | 
144 | 
145 | 
146 | 
147 | 
148 | 
149 | 
150 | 
151 | 
--------------------------------------------------------------------------------
/imagernn/utils.py:
--------------------------------------------------------------------------------
1 | from random import uniform
2 | import numpy as np
3 | 
4 | def randi(N):
5 |   """ get random integer in range [0, N) """
6 |   return int(uniform(0, N))
7 | 
8 | def merge_init_structs(s0, s1):
9 |   """ merge struct s1 into s0 """
10 |   for k in s1['model']:
11 |     assert (not k in s0['model']), 'Error: looks like parameter %s is trying to be initialized twice!' % (k, )
12 |     s0['model'][k] = s1['model'][k] # copy over the pointer
13 |   s0['update'].extend(s1['update'])
14 |   s0['regularize'].extend(s1['regularize'])
15 | 
16 | def initw(n,d): # initialize matrix of this size
17 |   magic_number = 0.1
18 |   return (np.random.rand(n,d) * 2 - 1) * magic_number # U[-0.1, 0.1]
19 | 
20 | def accumNpDicts(d0, d1):
21 |   """ forall k in d1: d0[k] += d1[k]. d's are dictionaries of key -> numpy array """
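  # (A hypothetical usage sketch: calling
  #   accumNpDicts({'Wd': np.ones((2, 2))}, {'Wd': np.ones((2, 2))})
  # leaves d0['Wd'] filled with 2.0 everywhere; handy for summing per-example
  # gradient dictionaries into a single batch gradient.)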
22 |   for k in d1:
23 |     if k in d0:
24 |       d0[k] += d1[k]
25 |     else:
26 |       d0[k] = d1[k]
--------------------------------------------------------------------------------
/reg_range.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | from math import log
4 | 
5 | 
6 | def freq_list_to_dict(data):
7 |     dico = {}
8 |     for x in data:
9 |         if x in dico.keys():
10 |             dico[x] += 1
11 |         else:
12 |             dico[x] = 1
13 |     return dico
14 | 
15 | def barplot_range(data, filename):
16 |     import plotly.plotly as py
17 |     import plotly.graph_objs as go
18 |     dico = freq_list_to_dict(data)
19 |     data_plot = [go.Bar(
20 |         x=dico.keys(),
21 |         y=dico.values()
22 |     )]
23 |     py.plot(data_plot, filename=filename)
24 |     return
25 | 
26 | 
27 | def get_reg_range_weights(ixtoword, min_range, max_range):
28 |     # Define measure outside range
29 |     # def f(x):
30 |     #     return x
31 |     def f(x):
32 |         return log(10+x)
33 |     ixtocost = []
34 |     ranges_note = []
35 |     for ix, word in ixtoword.items():
36 |         if word == '.':
37 |             ixtocost.append(0)
38 |             continue
39 |         note = int(re.split(r";", word)[0])
40 |         ranges_note.append(note)
41 |         if note >= min_range:
42 |             if note <= max_range:
43 |                 ixtocost.append(0)
44 |             else:
45 |                 ixtocost.append(f(note-max_range))
46 |         else:
47 |             ixtocost.append(f(min_range-note))
48 |     # barplot_range(ranges_note)
49 |     return np.asarray(ixtocost)
50 | 
51 | def get_reg_range_cost(P, W):
52 |     # P = preds
53 |     # C = \sum_j Wj Pj
54 |     reg_range_cost = np.sum(np.multiply(P, W), axis=1)
55 |     return reg_range_cost
56 | 
57 | def get_reg_range_derivative(P, W, C):
58 |     # P = preds
59 |     # W = reg_range_weights
60 |     # C = reg_range_cost
61 |     #
62 |     # dE / dYi = Pi * [Wi - \sum_j Pj Wj]
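    # (Derivation sketch: with the softmax P over logits Y, dPj/dYi = Pj * (delta_ij - Pi),
    # so dC/dYi = \sum_j Wj Pj (delta_ij - Pi) = Pi*Wi - Pi*C, which is the line below.)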
63 |     reg_range_d = np.multiply(P, W) - C.reshape([-1,1])*P
64 |     return reg_range_d
--------------------------------------------------------------------------------
/status/Readme.md:
--------------------------------------------------------------------------------
1 | This contains status JSON files of running optimizations
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import time
4 | import datetime
5 | import numpy as np
6 | import code
7 | import socket
8 | import os
9 | import sys
10 | import cPickle as pickle
11 | import re
12 | 
13 | from imagernn.data_provider import getDataProvider
14 | from imagernn.solver import Solver
15 | from imagernn.imagernn_utils import decodeGenerator, eval_split
16 | 
17 | from reg_range import get_reg_range_weights, get_reg_range_cost, get_reg_range_derivative
18 | 
19 | 
20 | def preProBuildWordVocab(sentence_iterator, word_count_threshold):
21 |   # count up all word counts so that we can threshold
22 |   # this shouldn't be too expensive an operation
23 |   print 'preprocessing word counts and creating vocab based on word count threshold %d' % (word_count_threshold, )
24 |   t0 = time.time()
25 |   word_counts = {}
26 |   nsents = 0
27 |   for sent in sentence_iterator:
28 |     nsents += 1
29 |     for w in sent['tokens']:
30 |       word_counts[w] = word_counts.get(w, 0) + 1
31 |   vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
32 |   print 'filtered words from %d to %d in %.2fs' % (len(word_counts), len(vocab), time.time() - t0)
33 | 
34 |   # with K distinct words:
35 |   # - there are K+1 possible inputs (START token and all the words)
36 |   # - there are K+1 possible outputs (END token and all the words)
37 |   # we use ixtoword to take predicted indices and map them to words for output visualization
38 |   # we use wordtoix to take raw words and get their index in word vector matrix
39 |   ixtoword = {}
40 |   ixtoword[0] = '.' # period at the end of the sentence. make first dimension be end token
41 |   wordtoix = {}
42 |   wordtoix['#START#'] = 0 # make first vector be the start token
43 |   ix = 1
44 |   for w in vocab:
45 |     wordtoix[w] = ix
46 |     ixtoword[ix] = w
47 |     ix += 1
48 | 
49 |   # compute bias vector, which is related to the log probability of the distribution
50 |   # of the labels (words) and how often they occur. We will use this vector to initialize
51 |   # the decoder weights, so that the loss function doesn't show a huge increase in performance
52 |   # very quickly (which is just the network learning this anyway, for the most part). This makes
53 |   # the visualizations of the cost function nicer because it doesn't look like a hockey stick.
54 |   # for example on Flickr8K, doing this brings down initial perplexity from ~2500 to ~170.
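  # (A tiny illustration with made-up counts: for counts {'.': 10, '60;0;1': 5, '62;2;1': 5}
  # the normalized frequencies are [0.5, 0.25, 0.25], their logs are [-0.69, -1.39, -1.39],
  # and after subtracting the max the bias vector is [0.0, -0.69, -0.69], so more
  # frequent words start out with a larger decoder bias.)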
55 |   word_counts['.'] = nsents
56 |   bias_init_vector = np.array([1.0*word_counts[ixtoword[i]] for i in ixtoword])
57 |   bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies
58 |   bias_init_vector = np.log(bias_init_vector)
59 |   bias_init_vector -= np.max(bias_init_vector) # shift to nice numeric range
60 |   return wordtoix, ixtoword, bias_init_vector
61 | 
62 | def RNNGenCost(batch, model, params, misc):
63 |   """ cost function, returns cost and gradients for model """
64 |   regc = params['regc'] # regularization cost
65 |   BatchGenerator = decodeGenerator(params)
66 |   wordtoix = misc['wordtoix']
67 |   reg_range_weights = misc['reg_range_weights']
68 |   reg_range_coeff = params['reg_range_coeff']
69 | 
70 |   # forward the RNN on each image sentence pair
71 |   # the generator returns a list of matrices that have word probabilities
72 |   # and a list of cache objects that will be needed for backprop
73 |   Ys, gen_caches = BatchGenerator.forward(batch, model, params, misc, predict_mode = False)
74 | 
75 |   # compute softmax costs for all generated sentences, and the gradients on top
76 |   loss_cost = 0.0
77 |   dYs = []
78 |   logppl = 0.0
79 |   logppln = 0
80 |   for i,pair in enumerate(batch):
81 |     img = pair['image']
82 |     # ground truth indices for this sentence we expect to see
83 |     gtix = [ wordtoix[w] for w in pair['sentence']['tokens'] if w in wordtoix ]
84 |     gtix.append(0) # don't forget: the END token must be predicted at the end!
85 |     # fetch the predicted probabilities, as rows
86 |     Y = Ys[i]
87 |     maxes = np.amax(Y, axis=1, keepdims=True)
88 |     e = np.exp(Y - maxes) # for numerical stability shift into good numerical range
89 |     P = e / np.sum(e, axis=1, keepdims=True) # Softmax
90 |     loss_cost += - np.sum(np.log(1e-20 + P[range(len(gtix)),gtix])) # note: add smoothing to not get infs
91 |     logppl += - np.sum(np.log2(1e-20 + P[range(len(gtix)),gtix])) # also accumulate log2 perplexities
92 |     logppln += len(gtix)
93 | 
94 |     ###########################################################################################
95 |     # compute range regularization
96 |     reg_range_cost_ = get_reg_range_cost(P, reg_range_weights)
97 |     reg_range_derivative = get_reg_range_derivative(P, reg_range_weights, reg_range_cost_)
98 |     reg_range_cost = reg_range_coeff * reg_range_cost_.sum() # sum reg cost along time axis for display
99 |     # IMPORTANT REMINDER: do not multiply by coeff before computing the gradient!!
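    # (Sanity check with hypothetical numbers: for one timestep with P = [0.5, 0.5]
    # and W = [0.0, 2.0], the cost is C = 1.0 and the gradient on the logits is
    # P*W - C*P = [0.0, 1.0] - [0.5, 0.5] = [-0.5, 0.5]: gradient descent pushes the
    # out-of-range word down and pulls the in-range word up.)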
100 | 
101 |     # let's be clever and optimize for speed here to derive the gradient in place quickly
102 |     for iy,y in enumerate(gtix):
103 |       P[iy,y] -= 1 # softmax derivatives are pretty simple
104 | 
105 |     P += reg_range_coeff * reg_range_derivative # add range regularization gradient
106 |     dYs.append(P)
107 | 
108 |   # backprop the RNN
109 |   grads = BatchGenerator.backward(dYs, gen_caches)
110 | 
111 |   # add L2 regularization cost and gradients
112 |   reg_cost = 0.0
113 |   if regc > 0:
114 |     for p in misc['regularize']:
115 |       mat = model[p]
116 |       reg_cost += 0.5 * regc * np.sum(mat * mat)
117 |       grads[p] += regc * mat
118 | 
119 |   # normalize the cost and gradient by the batch size
120 |   batch_size = len(batch)
121 |   reg_cost /= batch_size
122 |   loss_cost /= batch_size
123 |   for k in grads: grads[k] /= batch_size
124 | 
125 |   # return output in json
126 |   out = {}
127 |   out['cost'] = {'reg_range_cost' : reg_range_cost, 'reg_cost' : reg_cost, 'loss_cost' : loss_cost, 'total_cost' : loss_cost + reg_cost + reg_range_cost}
128 |   out['grad'] = grads
129 |   out['stats'] = { 'ppl2' : 2 ** (logppl / logppln)}
130 |   return out
131 | 
132 | def main(params):
133 |   batch_size = params['batch_size']
134 |   dataset = params['dataset']
135 |   word_count_threshold = params['word_count_threshold']
136 |   do_grad_check = params['do_grad_check']
137 |   max_epochs = params['max_epochs']
138 |   host = socket.gethostname() # get computer hostname
139 | 
140 |   # fetch the data provider
141 |   dp = getDataProvider(dataset)
142 | 
143 |   misc = {} # stores various misc items that need to be passed around the framework
144 | 
145 |   # go over all training sentences and find the vocabulary we want to use, i.e. the words that occur
146 |   # at least word_count_threshold number of times
147 |   misc['wordtoix'], misc['ixtoword'], bias_init_vector = preProBuildWordVocab(dp.iterSentences('train'), word_count_threshold)
148 | 
149 |   misc['reg_range_weights'] = get_reg_range_weights(misc['ixtoword'], params['reg_range_min'], params['reg_range_max'])
150 | 
151 |   # delegate the initialization of the model to the Generator class
152 |   BatchGenerator = decodeGenerator(params)
153 |   init_struct = BatchGenerator.init(params, misc)
154 |   model, misc['update'], misc['regularize'] = (init_struct['model'], init_struct['update'], init_struct['regularize'])
155 | 
156 |   # force overwrite here. This is a bit of a hack, not happy about it
157 |   model['bd'] = bias_init_vector.reshape(1, bias_init_vector.size)
158 | 
159 |   print 'model init done.'
160 |   print 'model has keys: ' + ', '.join(model.keys())
161 |   print 'updating: ' + ', '.join( '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1]) for k in misc['update'])
162 |   print 'regularizing: ' + ', '.join( '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1]) for k in misc['regularize'])
163 |   print 'number of learnable parameters total: %d' % (sum(model[k].shape[0] * model[k].shape[1] for k in misc['update']), )
164 | 
165 |   if params.get('init_model_from', ''):
166 |     # load checkpoint
167 |     checkpoint = pickle.load(open(params['init_model_from'], 'rb'))
168 |     model = checkpoint['model'] # overwrite the model
169 | 
170 |   # initialize the Solver and the cost function
171 |   solver = Solver()
172 |   def costfun(batch, model):
173 |     # wrap the cost function to abstract some things away from the Solver
174 |     return RNNGenCost(batch, model, params, misc)
175 | 
176 |   # calculate how many iterations we need
177 |   num_sentences_total = dp.getSplitSize('train', ofwhat = 'sentences')
178 |   num_iters_one_epoch = num_sentences_total / batch_size
179 |   max_iters = max_epochs * num_iters_one_epoch
180 |   eval_period_in_epochs = params['eval_period']
181 |   eval_period_in_iters = max(1, int(num_iters_one_epoch * eval_period_in_epochs))
182 |   abort = False
183 |   top_val_ppl2 = -1
184 |   smooth_train_ppl2 = len(misc['ixtoword']) # initially size of dictionary of confusion
185 |   val_ppl2 = len(misc['ixtoword'])
186 |   last_status_write_time = 0 # for writing worker job status reports
187 |   json_worker_status = {}
188 |   json_worker_status['params'] = params
189 |   json_worker_status['history'] = []
190 | 
191 |   for it in xrange(max_iters):
192 |     if abort: break
193 |     t0 = time.time()
194 |     # fetch a batch of data
195 |     batch = [dp.sampleImageSentencePair() for i in xrange(batch_size)]
196 |     # evaluate cost, gradient and perform parameter update
197 |     step_struct = solver.step(batch, model, costfun, **params)
198 |     cost = step_struct['cost']
199 |     dt = time.time() - t0
200 | 
201 |     # print training statistics
202 |     train_ppl2 = step_struct['stats']['ppl2']
203 |     smooth_train_ppl2 = 0.99 * smooth_train_ppl2 + 0.01 * train_ppl2 # smooth exponentially decaying moving average
204 |     if it == 0: smooth_train_ppl2 = train_ppl2 # start out where we start out
205 |     epoch = it * 1.0 / num_iters_one_epoch
206 |     print '%d/%d batch done in %.3fs. at epoch %.2f. loss cost = %f, reg cost = %f, range cost = %f, ppl2 = %.2f (smooth %.2f)' \
207 |           % (it, max_iters, dt, epoch, cost['loss_cost'], cost['reg_cost'], cost['reg_range_cost'], \
208 |              train_ppl2, smooth_train_ppl2)
209 | 
210 |     # perform gradient check if desired, with a bit of a burnin time (10 iterations)
211 |     if it == 10 and do_grad_check:
212 |       print 'disabling dropout for gradient check...'
213 |       params['drop_prob_encoder'] = 0
214 |       params['drop_prob_decoder'] = 0
215 |       solver.gradCheck(batch, model, costfun)
216 |       print 'done gradcheck, exiting.'
217 |       sys.exit() # hmmm. probably should exit here
218 | 
219 |     # detect if loss is exploding and kill the job if so
220 |     total_cost = cost['total_cost']
221 |     if it == 0:
222 |       total_cost0 = total_cost # store this initial cost
223 |     if total_cost > total_cost0 * 2:
224 |       print 'Aborting, cost seems to be exploding. Run gradcheck? Lower the learning rate?'
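      # (This check is deliberately crude: e.g. if the cost of the very first batch
      # were 50.0, any later batch with total cost above 100.0 would trip it --
      # hypothetical numbers, the actual scale depends on the data and the model.)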
225 |       # abort = True # set the abort flag, we'll break out
226 | 
227 |     # logging: write JSON files for visual inspection of the training
228 |     tnow = time.time()
229 |     if tnow > last_status_write_time + 60*1: # every now and then let's write a report
230 |       last_status_write_time = tnow
231 |       jstatus = {}
232 |       jstatus['time'] = datetime.datetime.now().isoformat()
233 |       jstatus['iter'] = (it, max_iters)
234 |       jstatus['epoch'] = (epoch, max_epochs)
235 |       jstatus['time_per_batch'] = dt
236 |       jstatus['smooth_train_ppl2'] = smooth_train_ppl2
237 |       jstatus['val_ppl2'] = val_ppl2 # just write the last available one
238 |       jstatus['train_ppl2'] = train_ppl2
239 |       json_worker_status['history'].append(jstatus)
240 |       status_file = os.path.join(params['worker_status_output_directory'], host + '_status.json')
241 |       try:
242 |         json.dump(json_worker_status, open(status_file, 'w'))
243 |       except Exception, e: # todo be more clever here
244 |         print 'tried to write worker status into %s but got error:' % (status_file, )
245 |         print e
246 | 
247 |     # perform perplexity evaluation on the validation set and save a model checkpoint if it's good
248 |     is_last_iter = (it+1) == max_iters
249 |     if (((it+1) % eval_period_in_iters) == 0 and it < max_iters - 5) or is_last_iter:
250 |       val_ppl2 = eval_split('val', dp, model, params, misc) # perform the evaluation on VAL set
251 |       print 'validation perplexity = %f' % (val_ppl2, )
252 | 
253 |       # abort training if the perplexity is no good
254 |       min_ppl_or_abort = params['min_ppl_or_abort']
255 |       if val_ppl2 > min_ppl_or_abort and min_ppl_or_abort > 0:
256 |         print 'aborting job because validation perplexity %f > %f' % (val_ppl2, min_ppl_or_abort)
257 |         abort = True # abort the job
258 | 
259 |       write_checkpoint_ppl_threshold = params['write_checkpoint_ppl_threshold']
260 |       #if val_ppl2 < top_val_ppl2 or top_val_ppl2 < 0:
261 |       if val_ppl2 < write_checkpoint_ppl_threshold or write_checkpoint_ppl_threshold < 0:
262 |         # if we beat a previous record or if this is the first time
263 |         # AND we also beat the user-defined threshold or it doesn't exist
264 |         top_val_ppl2 = val_ppl2
265 |         filename = 'model_checkpoint_%s_%s_%s_%.2f_%.2f.p' % (dataset, host, params['fappend'], val_ppl2, epoch)
266 |         filepath = os.path.join(params['checkpoint_output_directory'], filename)
267 |         checkpoint = {}
268 |         checkpoint['it'] = it
269 |         checkpoint['epoch'] = epoch
270 |         checkpoint['model'] = model
271 |         checkpoint['params'] = params
272 |         checkpoint['perplexity'] = val_ppl2
273 |         checkpoint['wordtoix'] = misc['wordtoix']
274 |         checkpoint['ixtoword'] = misc['ixtoword']
275 |         try:
276 |           if not os.path.exists(params['checkpoint_output_directory']):
277 |             os.makedirs(params['checkpoint_output_directory'])
278 |           pickle.dump(checkpoint, open(filepath, "wb"))
279 |           print 'saved checkpoint in %s' % (filepath, )
280 |         except Exception, e: # todo be more clever here
281 |           print 'tried to write checkpoint into %s but got error: ' % (filepath, )
282 |           print e
283 | 
284 | 
285 | if __name__ == "__main__":
286 | 
287 |   parser = argparse.ArgumentParser()
288 | 
289 |   # global setup settings, and checkpoints
290 |   parser.add_argument('-d', '--dataset', dest='dataset', default='music', help='dataset to use (default: music)')
291 |   parser.add_argument('-a', '--do_grad_check', dest='do_grad_check', type=int, default=0, help='perform gradcheck? program will block for visual inspection and will need manual user input')
292 |   parser.add_argument('--fappend', dest='fappend', type=str, default='baseline', help='append this string to checkpoint filenames')
293 |   parser.add_argument('-o', '--checkpoint_output_directory', dest='checkpoint_output_directory', type=str, default='cv/', help='output directory to write checkpoints to')
294 |   parser.add_argument('--worker_status_output_directory', dest='worker_status_output_directory', type=str, default='status/', help='directory to write worker status JSON blobs to')
295 |   parser.add_argument('--write_checkpoint_ppl_threshold', dest='write_checkpoint_ppl_threshold', type=float, default=-1, help="ppl threshold above which we don't bother writing a checkpoint to save space")
296 |   parser.add_argument('--init_model_from', dest='init_model_from', type=str, default='', help='initialize the model parameters from some specific checkpoint?')
297 | 
298 |   # model parameters
299 |   parser.add_argument('--generator', dest='generator', type=str, default='lstm', help='generator to use')
300 |   parser.add_argument('--image_encoding_size', dest='image_encoding_size', type=int, default=256, help='size of the image encoding')
301 |   parser.add_argument('--word_encoding_size', dest='word_encoding_size', type=int, default=256, help='size of word encoding')
302 |   parser.add_argument('--hidden_size', dest='hidden_size', type=int, default=256, help='size of hidden layer in generator RNNs')
303 |   # lstm-specific params
304 |   parser.add_argument('--tanhC_version', dest='tanhC_version', type=int, default=0, help='use tanh version of LSTM?')
305 |   # rnn-specific params
306 |   parser.add_argument('--rnn_relu_encoders', dest='rnn_relu_encoders', type=int, default=0, help='relu encoders before going to RNN?')
307 |   parser.add_argument('--rnn_feed_once', dest='rnn_feed_once', type=int, default=0, help='feed image to the rnn only a single time?')
308 | 
309 |   # pitch range regularization
310 |   parser.add_argument('--reg_range_coeff', dest='reg_range_coeff', type=float, default=1, help='range regularization strength. set it to 0 to disable regularization')
311 |   parser.add_argument('--reg_range_min', dest='reg_range_min', type=float, default=60, help='lowest pitch allowed')
312 |   parser.add_argument('--reg_range_max', dest='reg_range_max', type=float, default=72, help='highest pitch allowed')
313 | 
314 |   # optimization parameters
315 |   parser.add_argument('-c', '--regc', dest='regc', type=float, default=1e-8, help='regularization strength')
316 |   parser.add_argument('-m', '--max_epochs', dest='max_epochs', type=int, default=50, help='number of epochs to train for')
317 |   parser.add_argument('--solver', dest='solver', type=str, default='rmsprop', help='solver type: vanilla/adagrad/adadelta/rmsprop')
318 |   parser.add_argument('--momentum', dest='momentum', type=float, default=0.0, help='momentum for vanilla sgd')
319 |   parser.add_argument('--decay_rate', dest='decay_rate', type=float, default=0.999, help='decay rate for adadelta/rmsprop')
320 |   parser.add_argument('--smooth_eps', dest='smooth_eps', type=float, default=1e-8, help='epsilon smoothing for rmsprop/adagrad/adadelta')
321 |   parser.add_argument('-l', '--learning_rate', dest='learning_rate', type=float, default=1e-3, help='solver learning rate')
322 |   parser.add_argument('-b', '--batch_size', dest='batch_size', type=int, default=1, help='batch size')
323 |   parser.add_argument('--grad_clip', dest='grad_clip', type=float, default=5, help='clip gradients (normalized by batch size)? elementwise. if positive, at what threshold?')
324 |   parser.add_argument('--drop_prob_encoder', dest='drop_prob_encoder', type=float, default=0.5, help='what dropout to apply right after the encoder to an RNN/LSTM')
325 |   parser.add_argument('--drop_prob_decoder', dest='drop_prob_decoder', type=float, default=0.5, help='what dropout to apply right before the decoder in an RNN/LSTM')
326 | 
327 |   # data preprocessing parameters
328 |   parser.add_argument('--word_count_threshold', dest='word_count_threshold', type=int, default=1, help='if a word occurs less than this number of times in training data, it is discarded')
329 | 
330 |   # evaluation parameters
331 |   parser.add_argument('-p', '--eval_period', dest='eval_period', type=float, default=5, help='in units of epochs, how often do we evaluate on val set?')
332 |   parser.add_argument('--eval_batch_size', dest='eval_batch_size', type=int, default=100, help='for faster validation performance evaluation, what batch size to use on val img/sentences?')
333 |   parser.add_argument('--eval_max_images', dest='eval_max_images', type=int, default=-1, help='for efficiency we can use a smaller number of images to get validation error')
334 |   parser.add_argument('--min_ppl_or_abort', dest='min_ppl_or_abort', type=float, default=-1, help='if validation perplexity is above this threshold the job will abort')
335 | 
336 |   args = parser.parse_args()
337 |   params = vars(args) # convert to ordinary dict
338 |   print 'parsed parameters:'
339 |   print json.dumps(params, indent = 2)
340 | 
341 |   main(params)
342 | 
--------------------------------------------------------------------------------