├── .gitignore
├── README.md
├── config
│   ├── __init__.py
│   └── nn_config.py
├── convert_directory.py
├── data_utils
│   ├── __init__.py
│   └── parse_files.py
├── datasets
│   ├── .gitignore
│   └── YourMusicLibrary
│       └── Test.mp3
├── gen_utils
│   ├── __init__.py
│   ├── seed_generator.py
│   └── sequence_generator.py
├── generate.py
├── nn_utils
│   ├── __init__.py
│   └── network_utils.py
└── train.py


/.gitignore:
--------------------------------------------------------------------------------
*.pyc
datasets/YourMusicLibrary/*
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# GRUV
GRUV is a Python project for algorithmic music generation using recurrent neural networks.

Note: This code is written for Python 2 and works with Keras v. 0.1.0; later versions of Keras may not work.

For a demonstration of our project on raw audio waveforms (as opposed to the standard MIDI), see here: https://www.youtube.com/watch?v=0VTI1BBLydE

Copyright (C) 2015 Matt Vitelli matthew.vitelli@gmail.com and Aran Nayebi aran.nayebi@gmail.com

# Dependencies
In order to use GRUV, you will first need to install the following dependencies:

Theano: http://deeplearning.net/software/theano/#download

Keras: https://github.com/fchollet/keras.git

NumPy: http://www.numpy.org/

SciPy: http://www.scipy.org/

LAME (for MP3 source files): http://lame.sourceforge.net/

SoX (for FLAC source files): http://sox.sourceforge.net/

h5py (for serializing the model): http://www.h5py.org/

Once that's taken care of, you can try training a model of your own as follows:
# Step 1. Prepare the data
Copy your music into ./datasets/YourMusicLibrary/ and type the following command into Terminal:
> python convert_directory.py

This will convert all MP3 and FLAC files in ./datasets/YourMusicLibrary/ into mono WAVs (written to ./datasets/YourMusicLibrary/wave/) and then convert the WAVs into normalized frequency-domain blocks for the deep learning algorithms. By default, only the first 20 WAV files are used (see max_files in data_utils/parse_files.py).

# Step 2. Train your model
At this point, you should have four files in ./datasets/ named YourMusicLibraryNP_x.npy, YourMusicLibraryNP_y.npy, YourMusicLibraryNP_var.npy, and YourMusicLibraryNP_mean.npy:

- YourMusicLibraryNP_x contains the input sequences for training
- YourMusicLibraryNP_y contains the output sequences for training: the target for each input block is the block that immediately follows it (the last target is an all-zero end block)
- YourMusicLibraryNP_mean contains the mean of each feature computed from the training set
- YourMusicLibraryNP_var contains the standard deviation (despite the _var suffix) of each feature computed from the training set

You can train your very first model by typing the following command into Terminal:
> python train.py

Training will take a while, depending on the length and number of songs used.
If you get an error of the following form:

    Error allocating X bytes of device memory (out of memory). Driver report Y bytes free and Z bytes total

you must adjust the parameters in train.py; specifically, decrease batch_size. If you still get out-of-memory errors, you can also decrease hidden_dimension_size in config/nn_config.py (read by both train.py and generate.py as hidden_dims), although this will have a significant impact on the quality of the generated music.

# Step 3. Generation
After you've finished training your model, it's time to generate some music!
Type the following command into Terminal:
> python generate.py

After a while, you should have a file called generated_song.wav in the current directory. generate.py loads the weights file selected by its cur_iter variable (./YourMusicLibraryNPWeights25 by default); if it reports that the model file could not be found, set cur_iter to match a weights file written by train.py.

Future work:
Improve generation algorithms. Our current generation scheme uses the training/testing data as a seed sequence, which tends to produce verbatim copies of the original songs.
One might imagine that we could improve these results by taking linear combinations of the hidden states for different songs, projecting the combinations back into the frequency space, and using those as seed sequences. You can find the core components of the generation algorithms in gen_utils/seed_generator.py and gen_utils/sequence_generator.py.
--------------------------------------------------------------------------------

/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MattVitelli/GRUV/71998b14fda01d58a6ad98a82b4f65455a13a18a/config/__init__.py
--------------------------------------------------------------------------------

/config/nn_config.py:
--------------------------------------------------------------------------------
def get_neural_net_configuration():
    nn_params = {}
    nn_params['sampling_frequency'] = 44100
    #Number of hidden dimensions.
    #For best results, this should be >= freq_space_dims, but most consumer GPUs can't handle large sizes
    nn_params['hidden_dimension_size'] = 1024
    #Prefix of the weights files for saving/loading trained models (the epoch count is appended)
    nn_params['model_basename'] = './YourMusicLibraryNPWeights'
    #Prefix of the .npy training-data files (_x, _y, _mean and _var are appended)
    nn_params['model_file'] = './datasets/YourMusicLibraryNP'
    #The dataset directory
    nn_params['dataset_directory'] = './datasets/YourMusicLibrary/'
    return nn_params
--------------------------------------------------------------------------------

/convert_directory.py:
--------------------------------------------------------------------------------
from data_utils.parse_files import *
import config.nn_config as nn_config

config = nn_config.get_neural_net_configuration()
input_directory = config['dataset_directory']
output_filename = config['model_file']

freq = config['sampling_frequency'] #sample frequency in Hz
clip_len = 10 #length of clips for training. Defined in seconds
block_size = freq // 4 #samples per block (a quarter second); this defines the size of our input state
max_seq_len = int(round((freq * clip_len) / block_size)) #blocks per training sequence; songs are cut into chunks of this length
#Step 1 - convert MP3s and FLACs to mono WAVs
new_directory = convert_folder_to_wav(input_directory, freq)
#Step 2 - convert WAVs to frequency domain with mean 0 and standard deviation of 1
convert_wav_files_to_nptensor(new_directory, block_size, max_seq_len, output_filename)
--------------------------------------------------------------------------------

/data_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MattVitelli/GRUV/71998b14fda01d58a6ad98a82b4f65455a13a18a/data_utils/__init__.py
--------------------------------------------------------------------------------

/data_utils/parse_files.py:
--------------------------------------------------------------------------------
import os
import scipy.io.wavfile as wav
import numpy as np
from pipes import quote
from config import nn_config

def convert_mp3_to_wav(filename, sample_frequency):
    ext = filename[-4:]
    if(ext != '.mp3'):
        return
    files = filename.split('/')
    orig_filename = files[-1][0:-4]
    orig_path = filename[0:-len(files[-1])]
    new_path = ''
    if(filename[0] == '/'):
        new_path = '/'
    for i in xrange(len(files)-1):
        new_path += files[i]+'/'
    tmp_path = new_path + 'tmp'
    new_path += 'wave'
    if not os.path.exists(new_path):
        os.makedirs(new_path)
    if not os.path.exists(tmp_path):
        os.makedirs(tmp_path)
    filename_tmp = tmp_path + '/' + orig_filename + '.mp3'
    new_name = new_path + '/' + orig_filename + '.wav'
    sample_freq_str = "{0:.1f}".format(float(sample_frequency)/1000.0) #LAME takes the rate in kHz
    #Downmix to mono into a temporary MP3
    cmd = 'lame -a -m m {0} {1}'.format(quote(filename), quote(filename_tmp))
    os.system(cmd)
    #Decode the mono MP3 to WAV
    cmd = 'lame --decode {0} {1} --resample {2}'.format(quote(filename_tmp), quote(new_name), sample_freq_str)
    os.system(cmd)
    return new_name

def convert_flac_to_wav(filename, sample_frequency):
    ext = filename[-5:]
    if(ext != '.flac'):
        return
    files = filename.split('/')
    orig_filename = files[-1][0:-5]
    orig_path = filename[0:-len(files[-1])]
    new_path = ''
    if(filename[0] == '/'):
        new_path = '/'
    for i in xrange(len(files)-1):
        new_path += files[i]+'/'
    new_path += 'wave'
    if not os.path.exists(new_path):
        os.makedirs(new_path)
    new_name = new_path + '/' + orig_filename + '.wav'
    #Downmix to one channel and resample to sample_frequency
    cmd = 'sox {0} {1} channels 1 rate {2}'.format(quote(filename), quote(new_name), sample_frequency)
    os.system(cmd)
    return new_name


def convert_folder_to_wav(directory, sample_rate=44100):
    for file in os.listdir(directory):
        fullfilename = directory+file
        if file.endswith('.mp3'):
            convert_mp3_to_wav(filename=fullfilename, sample_frequency=sample_rate)
        if file.endswith('.flac'):
            convert_flac_to_wav(filename=fullfilename, sample_frequency=sample_rate)
    return directory + 'wave/'

def read_wav_as_np(filename):
    data = wav.read(filename)
    np_arr = data[1].astype('float32') / 32767.0 #Normalize 16-bit input to [-1, 1] range
    #np_arr = np.array(np_arr)
    return np_arr, data[0]

def write_np_as_wav(X, sample_rate, filename):
    #Keep the real part (inverse FFTs return complex arrays) and clip to [-1, 1]
    #so out-of-range values saturate instead of wrapping around in int16
    Xnew = np.clip(np.real(X), -1.0, 1.0) * 32767.0
    Xnew = Xnew.astype('int16')
    wav.write(filename, sample_rate, Xnew)
    return

def convert_np_audio_to_sample_blocks(song_np, block_size):
    block_lists = []
    total_samples = song_np.shape[0]
    num_samples_so_far = 0
    while(num_samples_so_far < total_samples):
        block = song_np[num_samples_so_far:num_samples_so_far+block_size]
        if(block.shape[0] < block_size):
            #Zero-pad the final partial block
            padding = np.zeros((block_size - block.shape[0],))
            block = np.concatenate((block, padding))
        block_lists.append(block)
        num_samples_so_far += block_size
    return block_lists

def convert_sample_blocks_to_np_audio(blocks):
    song_np = np.concatenate(blocks)
    return song_np

def time_blocks_to_fft_blocks(blocks_time_domain):
    fft_blocks = []
    for block in blocks_time_domain:
        fft_block = np.fft.fft(block)
        #Real parts followed by imaginary parts: 2 * block_size features per block
        new_block = np.concatenate((np.real(fft_block), np.imag(fft_block)))
        fft_blocks.append(new_block)
    return fft_blocks

def fft_blocks_to_time_blocks(blocks_ft_domain):
    time_blocks = []
    for block in blocks_ft_domain:
        num_elems = block.shape[0] // 2
        real_chunk = block[0:num_elems]
        imag_chunk = block[num_elems:]
        new_block = real_chunk + 1.0j * imag_chunk
        time_block = np.fft.ifft(new_block)
        time_blocks.append(time_block)
    return time_blocks

def convert_wav_files_to_nptensor(directory, block_size, max_seq_len, out_file, max_files=20, useTimeDomain=False):
    files = []
    for file in os.listdir(directory):
        if file.endswith('.wav'):
            files.append(directory+file)
    chunks_X = []
    chunks_Y = []
    num_files = len(files)
    if(num_files > max_files):
        num_files = max_files
    for file_idx in xrange(num_files):
        file = files[file_idx]
        print 'Processing: ', (file_idx+1),'/',num_files
        print 'Filename: ', file
        X, Y = load_training_example(file, block_size, useTimeDomain=useTimeDomain)
        cur_seq = 0
        total_seq = len(X)
        print total_seq
        print max_seq_len
        #Cut the song into non-overlapping chunks of exactly max_seq_len blocks
        while cur_seq + max_seq_len <= total_seq:
            chunks_X.append(X[cur_seq:cur_seq+max_seq_len])
            chunks_Y.append(Y[cur_seq:cur_seq+max_seq_len])
            cur_seq += max_seq_len
    num_examples = len(chunks_X)
    num_dims_out = block_size * 2
    if(useTimeDomain):
        num_dims_out = block_size
    out_shape = (num_examples, max_seq_len, num_dims_out)
    x_data = np.zeros(out_shape)
    y_data = np.zeros(out_shape)
    for n in xrange(num_examples):
        for i in xrange(max_seq_len):
            x_data[n][i] = chunks_X[n][i]
            y_data[n][i] = chunks_Y[n][i]
        print 'Saved example ', (n+1), ' / ',num_examples
    print 'Flushing to disk...'
    mean_x = np.mean(np.mean(x_data, axis=0), axis=0) #Mean across num examples and num timesteps
    std_x = np.sqrt(np.mean(np.mean(np.abs(x_data-mean_x)**2, axis=0), axis=0)) # STD across num examples and num timesteps
    std_x = np.maximum(1.0e-8, std_x) #Clamp tiny standard deviations to avoid dividing by zero
    x_data -= mean_x #Mean 0
    x_data /= std_x #Variance 1
    y_data -= mean_x #Mean 0
    y_data /= std_x #Variance 1

    np.save(out_file+'_mean', mean_x)
    np.save(out_file+'_var', std_x) #Note: '_var' holds the standard deviation, not the variance
    np.save(out_file+'_x', x_data)
    np.save(out_file+'_y', y_data)
    print 'Done!'

def convert_nptensor_to_wav_files(tensor, indices, filename, useTimeDomain=False):
    num_seqs = tensor.shape[1]
    for i in indices:
        chunks = []
        for x in xrange(num_seqs):
            chunks.append(tensor[i][x])
        save_generated_example(filename+str(i)+'.wav', chunks, useTimeDomain=useTimeDomain)

def load_training_example(filename, block_size=2048, useTimeDomain=False):
    data, sample_rate = read_wav_as_np(filename)
    x_t = convert_np_audio_to_sample_blocks(data, block_size)
    y_t = x_t[1:] #Each target is the next block in time
    y_t.append(np.zeros(block_size)) #Add special end block composed of all zeros
    if useTimeDomain:
        return x_t, y_t
    X = time_blocks_to_fft_blocks(x_t)
    Y = time_blocks_to_fft_blocks(y_t)
    return X, Y

def save_generated_example(filename, generated_sequence, useTimeDomain=False, sample_frequency=44100):
    if useTimeDomain:
        time_blocks = generated_sequence
    else:
        time_blocks = fft_blocks_to_time_blocks(generated_sequence)
    song = convert_sample_blocks_to_np_audio(time_blocks)
    write_np_as_wav(song, sample_frequency, filename)
    return

def audio_unit_test(filename, filename2):
    data, sample_rate = read_wav_as_np(filename)
    time_blocks = convert_np_audio_to_sample_blocks(data, 1024)
    ft_blocks = time_blocks_to_fft_blocks(time_blocks)
    time_blocks = fft_blocks_to_time_blocks(ft_blocks)
    song = convert_sample_blocks_to_np_audio(time_blocks)
    write_np_as_wav(song, sample_rate, filename2)
    return
--------------------------------------------------------------------------------

/datasets/.gitignore:
--------------------------------------------------------------------------------
# mp3 --> wav --> npy
*.npy
--------------------------------------------------------------------------------

/datasets/YourMusicLibrary/Test.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MattVitelli/GRUV/71998b14fda01d58a6ad98a82b4f65455a13a18a/datasets/YourMusicLibrary/Test.mp3
--------------------------------------------------------------------------------

/gen_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MattVitelli/GRUV/71998b14fda01d58a6ad98a82b4f65455a13a18a/gen_utils/__init__.py
--------------------------------------------------------------------------------

/gen_utils/seed_generator.py:
--------------------------------------------------------------------------------
import numpy as np

#A very simple seed generator
#Copies seed_length consecutive training examples, starting at a random index, as input to the generation algorithm
#The result has shape (1, seed_length * num_timesteps, num_frequency_dims)
def generate_copy_seed_sequence(seed_length, training_data):
    num_examples = training_data.shape[0]
    example_len = training_data.shape[1]
    #Pick a start index so that randIdx + seed_length - 1 stays inside the training set
    randIdx = np.random.randint(num_examples - seed_length + 1, size=1)[0]
    randSeed = np.concatenate(tuple([training_data[randIdx + i] for i in xrange(seed_length)]), axis=0)
    seedSeq = np.reshape(randSeed, (1, randSeed.shape[0], randSeed.shape[1]))
    return seedSeq
--------------------------------------------------------------------------------

/gen_utils/sequence_generator.py:
--------------------------------------------------------------------------------
import numpy as np

#Extrapolates from a given seed sequence
def generate_from_seed(model, seed, sequence_length, data_variance, data_mean):
    seedSeq = seed.copy()
    output = []

    #The generation algorithm is simple:
    #Step 1 - Given A = [X_0, X_1, ... X_n], generate X_{n+1}
    #Step 2 - Concatenate X_{n+1} onto A
    #Step 3 - Repeat sequence_length times
    for it in xrange(sequence_length):
        seedSeqNew = model.predict(seedSeq) #Step 1. Generate X_{n+1}
        #Step 2. Append it to the sequence
        if it == 0:
            #On the first step, keep the prediction for every seed timestep
            for i in xrange(seedSeqNew.shape[1]):
                output.append(seedSeqNew[0][i].copy())
        else:
            #Afterwards, keep only the newest predicted block
            output.append(seedSeqNew[0][seedSeqNew.shape[1]-1].copy())
        newSeq = seedSeqNew[0][seedSeqNew.shape[1]-1]
        newSeq = np.reshape(newSeq, (1, 1, newSeq.shape[0]))
        seedSeq = np.concatenate((seedSeq, newSeq), axis=1)

    #Finally, post-process the generated sequence so that we have valid frequencies
    #We're essentially just undoing the normalization (data_variance holds the per-feature standard deviation)
    for i in xrange(len(output)):
        output[i] *= data_variance
        output[i] += data_mean
    return output
--------------------------------------------------------------------------------

/generate.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import os
import nn_utils.network_utils as network_utils
import gen_utils.seed_generator as seed_generator
import gen_utils.sequence_generator as sequence_generator
from data_utils.parse_files import *
import config.nn_config as nn_config

config = nn_config.get_neural_net_configuration()
sample_frequency = config['sampling_frequency']
inputFile = config['model_file']
model_basename = config['model_basename']
cur_iter = 25 #Which checkpoint written by train.py to load (model_basename + cur_iter)
model_filename = model_basename + str(cur_iter)
output_filename = './generated_song.wav'

#Load up the training data
print ('Loading training data')
#X_train is a tensor of size (num_train_examples, num_timesteps, num_frequency_dims)
#y_train is a tensor of size (num_train_examples, num_timesteps, num_frequency_dims)
#X_mean is a vector of size (num_frequency_dims,) containing the mean of each frequency dimension
#X_var is a vector of size (num_frequency_dims,) containing the standard deviation of each frequency dimension
X_train = np.load(inputFile + '_x.npy')
y_train = np.load(inputFile + '_y.npy')
X_mean = np.load(inputFile + '_mean.npy')
X_var = np.load(inputFile + '_var.npy')
print ('Finished loading training data')

#Figure out how many frequencies we have in the data
freq_space_dims = X_train.shape[2]
hidden_dims = config['hidden_dimension_size']

#Create an LSTM network
model = network_utils.create_lstm_network(num_frequency_dimensions=freq_space_dims, num_hidden_dimensions=hidden_dims)
#You could also substitute a GRU network (it must match the architecture used in train.py)
#model = network_utils.create_gru_network(num_frequency_dimensions=freq_space_dims, num_hidden_dimensions=hidden_dims)

#Load existing weights if available
if os.path.isfile(model_filename):
    model.load_weights(model_filename)
else:
    print('Model filename ' + model_filename + ' could not be found!')

print ('Starting generation!')
#Here's the interesting part
#We need to create some seed sequence for the algorithm to start with
#Currently, we just grab an existing seed sequence from our training data and use that
#However, this will generally produce verbatim copies of the original songs
#In a sense, choosing good seed sequences = how you get interesting compositions
#There are many, many ways we can pick these seed sequences such as taking linear combinations of certain songs
#We could even provide a uniformly random sequence, but that is highly unlikely to produce good results
seed_len = 1 #Number of consecutive training examples copied into the seed
seed_seq = seed_generator.generate_copy_seed_sequence(seed_length=seed_len, training_data=X_train)

#Number of generation steps. The output holds (seed timesteps + max_seq_len - 1) blocks,
#each block_size samples long (block_size is set in convert_directory.py)
max_seq_len = 10
output = sequence_generator.generate_from_seed(model=model, seed=seed_seq,
    sequence_length=max_seq_len, data_variance=X_var, data_mean=X_mean)
print ('Finished generation!')

#Save the generated sequence to a WAV file
save_generated_example(output_filename, output, sample_frequency=sample_frequency)
--------------------------------------------------------------------------------

/nn_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MattVitelli/GRUV/71998b14fda01d58a6ad98a82b4f65455a13a18a/nn_utils/__init__.py
--------------------------------------------------------------------------------

/nn_utils/network_utils.py:
--------------------------------------------------------------------------------
from keras.models import Sequential
from keras.layers.core import TimeDistributedDense
from keras.layers.recurrent import LSTM, GRU

def create_lstm_network(num_frequency_dimensions, num_hidden_dimensions, num_recurrent_units=1):
    model = Sequential()
    #This layer converts frequency space to hidden space
    model.add(TimeDistributedDense(input_dim=num_frequency_dimensions, output_dim=num_hidden_dimensions))
    for cur_unit in xrange(num_recurrent_units):
        model.add(LSTM(input_dim=num_hidden_dimensions, output_dim=num_hidden_dimensions, return_sequences=True))
    #This layer converts hidden space back to frequency space
    model.add(TimeDistributedDense(input_dim=num_hidden_dimensions, output_dim=num_frequency_dimensions))
    model.compile(loss='mean_squared_error', optimizer='rmsprop')
    return model

def create_gru_network(num_frequency_dimensions, num_hidden_dimensions, num_recurrent_units=1):
    model = Sequential()
    #This layer converts frequency space to hidden space
    model.add(TimeDistributedDense(input_dim=num_frequency_dimensions, output_dim=num_hidden_dimensions))
    for cur_unit in xrange(num_recurrent_units):
        model.add(GRU(input_dim=num_hidden_dimensions, output_dim=num_hidden_dimensions, return_sequences=True))
    #This layer converts hidden space back to frequency space
    model.add(TimeDistributedDense(input_dim=num_hidden_dimensions, output_dim=num_frequency_dimensions))
    model.compile(loss='mean_squared_error', optimizer='rmsprop')
    return model
--------------------------------------------------------------------------------

/train.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import os
import nn_utils.network_utils as network_utils
import config.nn_config as nn_config

config = nn_config.get_neural_net_configuration()
inputFile = config['model_file']
cur_iter = 0 #Set this to an existing checkpoint's epoch count to resume training from it
model_basename = config['model_basename']
model_filename = model_basename + str(cur_iter)

#Load up the training data
print ('Loading training data')
#X_train is a tensor of size (num_train_examples, num_timesteps, num_frequency_dims)
#y_train is a tensor of size (num_train_examples, num_timesteps, num_frequency_dims)
X_train = np.load(inputFile + '_x.npy')
y_train = np.load(inputFile + '_y.npy')
print ('Finished loading training data')

#Figure out how many frequencies we have in the data
freq_space_dims = X_train.shape[2]
hidden_dims = config['hidden_dimension_size']

#Create an LSTM network
model = network_utils.create_lstm_network(num_frequency_dimensions=freq_space_dims, num_hidden_dimensions=hidden_dims)
#You could also substitute a GRU network
#model = network_utils.create_gru_network(num_frequency_dimensions=freq_space_dims, num_hidden_dimensions=hidden_dims)

#Load existing weights if available
if os.path.isfile(model_filename):
    model.load_weights(model_filename)

num_iters = 50 #Total number of training epochs (cur_iter counts epochs)
epochs_per_iter = 25 #Number of epochs between saved checkpoints
batch_size = 5 #Number of training examples pushed to the GPU per batch.
#Larger batch sizes require more memory, but training will be faster
print ('Starting training!')
while cur_iter < num_iters:
    print('Iteration: ' + str(cur_iter))
    #We set cross-validation to 0,
    #as cross-validation will be on different datasets
    #if we reload our model between runs
    #The proper way to handle this is to manually split
    #your data into two sets and run cross-validation after
    #you've trained the model for some number of epochs
    history = model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=epochs_per_iter, verbose=1, validation_split=0.0)
    cur_iter += epochs_per_iter
    #Save a checkpoint named after the epoch count (e.g. YourMusicLibraryNPWeights25), which generate.py loads via its cur_iter
    model.save_weights(model_basename + str(cur_iter))
print ('Training complete!')
--------------------------------------------------------------------------------
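The tensor sizes used by convert_directory.py and generate.py follow from a few lines of arithmetic. The Python 3 sketch below reproduces them, assuming the default configuration (44100 Hz, 10-second clips, quarter-second blocks, seed_len = 1 and max_seq_len = 10 in generate.py). The helper name `gruv_sizes` is hypothetical and is not part of the repository.

```python
def gruv_sizes(freq=44100, clip_len=10, seed_examples=1, gen_steps=10):
    """Hypothetical helper reproducing GRUV's size arithmetic (not part of the repo)."""
    block_size = freq // 4                                    # samples per block (0.25 s)
    max_seq_len = int(round((freq * clip_len) / block_size))  # blocks per training example
    freq_space_dims = 2 * block_size                          # real + imaginary FFT parts
    seed_timesteps = seed_examples * max_seq_len              # length of the copied seed
    # generate_from_seed keeps the prediction for every seed timestep on its first
    # step, then appends one new block on each remaining step
    output_blocks = seed_timesteps + gen_steps - 1
    output_seconds = output_blocks * block_size / freq
    return block_size, max_seq_len, freq_space_dims, output_blocks, output_seconds

print(gruv_sizes())  # (11025, 40, 22050, 49, 12.25)
```

So with the defaults, each timestep the network sees is a 22050-dimensional vector, and generated_song.wav comes out at roughly 12 seconds.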
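load_training_example and convert_wav_files_to_nptensor frame training as next-block prediction: the song is cut into zero-padded blocks, each input block's target is the block after it, and the block lists are then cut into fixed-length sequences. The following is a minimal Python 3 / NumPy sketch of that framing, with hypothetical function names. It uses the inclusive chunking condition, so a song of exactly max_seq_len blocks still yields one sequence.

```python
import numpy as np

def to_blocks(samples, block_size):
    # Cut the signal into block_size chunks, zero-padding the final partial chunk
    blocks = []
    for start in range(0, len(samples), block_size):
        block = samples[start:start + block_size]
        if block.shape[0] < block_size:
            block = np.concatenate((block, np.zeros(block_size - block.shape[0])))
        blocks.append(block)
    return blocks

def next_block_pairs(samples, block_size):
    x = to_blocks(samples, block_size)
    # The target for block t is block t + 1; an all-zero block marks the end of the song
    y = x[1:] + [np.zeros(block_size)]
    return x, y

def chunk(seq, max_seq_len):
    # Non-overlapping windows of exactly max_seq_len blocks; a trailing partial window is dropped
    return [seq[i:i + max_seq_len] for i in range(0, len(seq) - max_seq_len + 1, max_seq_len)]

x, y = next_block_pairs(np.arange(10, dtype=float), block_size=4)
print(len(x), len(chunk(x, 2)))  # 3 1
```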
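The frequency-domain encoding in parse_files.py stores each block's FFT as its real parts followed by its imaginary parts, and the decoder splits the vector in half and applies an inverse FFT. The Python 3 / NumPy sketch below checks that this round trip recovers the original blocks. It is a port, not the repository code: it uses `//` for the split point and takes the real part of the inverse FFT explicitly.

```python
import numpy as np

def time_blocks_to_fft_blocks(blocks):
    fft_blocks = []
    for block in blocks:
        spectrum = np.fft.fft(block)
        # Store real and imaginary parts side by side: 2 * block_size real features
        fft_blocks.append(np.concatenate((np.real(spectrum), np.imag(spectrum))))
    return fft_blocks

def fft_blocks_to_time_blocks(blocks):
    time_blocks = []
    for block in blocks:
        half = block.shape[0] // 2  # '//' keeps this an int on Python 3
        spectrum = block[:half] + 1j * block[half:]
        # Keep the real part explicitly instead of relying on an int16 cast to drop it
        time_blocks.append(np.real(np.fft.ifft(spectrum)))
    return time_blocks

rng = np.random.default_rng(0)
blocks = [rng.uniform(-1.0, 1.0, 1024) for _ in range(3)]
encoded = time_blocks_to_fft_blocks(blocks)
restored = fft_blocks_to_time_blocks(encoded)
print(encoded[0].shape, all(np.allclose(a, b) for a, b in zip(blocks, restored)))  # (2048,) True
```

For a real input block the spectrum is conjugate-symmetric, so this encoding carries about twice as many numbers as it strictly needs. An `np.fft.rfft`-based encoding would be one way to shrink the feature size, but that is an alternative design, not what GRUV does. For network outputs the predicted spectrum need not be symmetric, so the discarded imaginary part of the inverse FFT is not necessarily negligible.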
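convert_wav_files_to_nptensor standardizes every feature using statistics pooled over examples and timesteps, scales the targets with the input statistics, and saves the standard deviation under the `_var` suffix. generate_from_seed later inverts this by multiplying by that value and adding the mean. The following is a Python 3 / NumPy sketch of both directions; `normalize` and `denormalize` are hypothetical names, and the single-call `axis=(0, 1)` reductions are assumed to match the repository's nested means because every example has the same number of timesteps.

```python
import numpy as np

def normalize(x_data, y_data):
    # One statistic per feature, pooled over examples (axis 0) and timesteps (axis 1)
    mean_x = x_data.mean(axis=(0, 1))
    # Population standard deviation, clamped like parse_files.py to avoid dividing by zero
    std_x = np.maximum(1.0e-8, x_data.std(axis=(0, 1)))
    # Targets are scaled with the input statistics, as in convert_wav_files_to_nptensor
    return (x_data - mean_x) / std_x, (y_data - mean_x) / std_x, mean_x, std_x

def denormalize(blocks, mean_x, std_x):
    # Mirrors the end of generate_from_seed: scale by the standard deviation, then add the mean
    return [block * std_x + mean_x for block in blocks]

rng = np.random.default_rng(0)
x = rng.normal(3.0, 2.0, size=(4, 5, 6))  # (examples, timesteps, features)
x_n, y_n, mean_x, std_x = normalize(x, x.copy())
restored = np.array(denormalize(list(x_n[0]), mean_x, std_x))
print(np.allclose(restored, x[0]))  # True
```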
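The generation loop in gen_utils/sequence_generator.py is autoregressive: predict on the whole sequence so far, keep the newest predicted block, append it as the next input, and repeat. The Python 3 / NumPy sketch below replays that loop with a hypothetical stand-in model (`EchoModel`, which just halves its input) in place of the trained Keras network, and omits the final denormalization step. It shows why the output holds the seed's timesteps plus one block per remaining step.

```python
import numpy as np

class EchoModel:
    """Hypothetical stand-in for the Keras model: predicts 0.5 * each input timestep."""
    def predict(self, seq):
        return 0.5 * seq  # same shape in and out, like return_sequences=True

def generate_from_seed(model, seed, sequence_length):
    seq = seed.copy()  # shape (1, timesteps, dims)
    output = []
    for step in range(sequence_length):
        pred = model.predict(seq)  # one prediction per input timestep
        if step == 0:
            output.extend(p.copy() for p in pred[0])  # keep predictions for every seed timestep
        else:
            output.append(pred[0, -1].copy())  # keep only the newest block
        # Feed the newest prediction back in as the next input timestep
        seq = np.concatenate((seq, pred[:, -1:, :]), axis=1)
    return output

seed = np.ones((1, 4, 3))
out = generate_from_seed(EchoModel(), seed, sequence_length=5)
print(len(out))  # 8, i.e. 4 seed timesteps + (5 - 1) new blocks
```

Because each step re-runs the model on the entire growing sequence, the cost of a step grows with the output length. That keeps the loop simple with a stateless `predict` call, at the price of repeated computation.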