├── README.md ├── training └── .gitignore └── train_lstm.py /README.md: -------------------------------------------------------------------------------- 1 | # BinarySearchLSTM 2 | -------------------------------------------------------------------------------- /training/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /train_lstm.py: -------------------------------------------------------------------------------- 1 | import lasagne 2 | from lasagne.layers import * 3 | from lasagne import nonlinearities 4 | from lasagne import init 5 | 6 | import numpy as np 7 | import pickle 8 | import skimage.transform 9 | import scipy 10 | 11 | import theano 12 | import theano.tensor as T 13 | import theano.sandbox.neighbours as TSN 14 | 15 | from lasagne.utils import floatX 16 | 17 | import matplotlib.pyplot as plt 18 | 19 | import sys 20 | import os 21 | from math import exp,log 22 | 23 | import subprocess as sp 24 | 25 | PARAM_EXTENSION = 'params' 26 | NETWORK = "network_lstm" 27 | 28 | # Parameters of the model 29 | 30 | HIDDEN = 16 31 | OUTPUTS = 256 32 | ROUNDS = 8 33 | 34 | LEARN_SUPERVISED = 1e-2 35 | LEARN_REINFORCE = 1e-2 36 | 37 | BATCHSIZE = OUTPUTS*2 38 | 39 | def leakyReLU(x): 40 | return T.maximum(x,0.04*x) 41 | 42 | def build_model(): 43 | net = {} 44 | 45 | net["targetVar"] = T.ivector('targ') 46 | net["reinforceWeight"] = T.vector() 47 | 48 | net["input"] = lasagne.layers.InputLayer(shape=(None,None,4) ) # Inputs are last-guess-low, last-guess-high, last guess 49 | net["batchsize"] = net["input"].input_var.shape[0] 50 | 51 | net["lstm"] = lasagne.layers.LSTMLayer(incoming = net["input"], num_units = HIDDEN, grad_clipping = 1.0) 52 | net["slice"] = lasagne.layers.SliceLayer( net["lstm"], -1, 1) 53 | net["output"] = lasagne.layers.DenseLayer(incoming = net["slice"], num_units=OUTPUTS, nonlinearity=lasagne.nonlinearities.softmax) 54 | 55 | net["getOutput"] = lasagne.layers.get_output( net["output"] ) 56 | 57 | net["params"] = lasagne.layers.get_all_params( net["output"], trainable = True) 58 | net["loss"] = -T.log( net["getOutput"][T.arange(net["batchsize"]),net["targetVar"]] ).mean() 59 | net["rloss"] = -(net["reinforceWeight"] * T.log( net["getOutput"][T.arange(net["batchsize"]),net["targetVar"]] )).mean() 60 | 61 | net["updates"] = lasagne.updates.adam(net["loss"], net["params"], learning_rate = LEARN_SUPERVISED) 62 | net["reinforce_updates"] = lasagne.updates.adam(net["rloss"], net["params"], learning_rate = LEARN_REINFORCE) 63 | 64 | net["process"] = theano.function([net["input"].input_var], net["getOutput"]) 65 | net["train"] = theano.function([net["input"].input_var, net["targetVar"]], [net["loss"]], updates=net["updates"]) 66 | 67 | net["reinforce"] = theano.function([net["input"].input_var, net["targetVar"], net["reinforceWeight"]], [net["getOutput"]], updates=net["reinforce_updates"]) 68 | return net 69 | 70 | net = build_model() 71 | 72 | def runGame(net, epoch): 73 | gameinputs = np.zeros( (BATCHSIZE, ROUNDS, 4) ).astype(np.float32) 74 | winloss = np.zeros( BATCHSIZE ) 75 | gametargets = np.zeros( BATCHSIZE ).astype(np.int32) 76 | guesses = np.zeros( (BATCHSIZE, ROUNDS) ) 77 | 78 | turns = 0 79 | avgerr = 0 80 | for runidx in range(BATCHSIZE): 81 | numbers = np.arange(0,OUTPUTS,1) 82 | correctidx = runidx%OUTPUTS # We play the games in order rather than randomly, for training stability 83 | correctnum = numbers[correctidx] 84 | gametargets[runidx] = correctidx 85 | inputs = np.zeros( (1,0,4) ).astype(np.float32) 86 | guess = np.zeros( 4 ).astype(np.float32) 87 | 88 | for iter in range(ROUNDS): 89 | inputs = np.concatenate( [inputs, guess.reshape( (1,1,4) )], axis = 1 ) 90 | out = net["process"]( inputs ) 91 | 92 | guess[0] = guess[1] = guess[2] = 0 93 | 94 | gn = np.random.choice(numbers, p=out[0]) # Sample from the output distribution to determine the guess 95 | 96 | guess[3] = gn/float(OUTPUTS) # We want to tell the network what it guessed, but no reason to one-hot encode this 97 | guesses[runidx, iter] = gn 98 | 99 | if gn == correctidx: # Let the network know it guessed right 100 | guess[2] = 1 101 | else: 102 | if (numbers[gn] > correctnum): # High 103 | turns += 1 104 | guess[1] = 1 105 | else: # Low 106 | turns += 1 107 | guess[0] = 1 108 | 109 | gameinputs[runidx,:,:] = inputs[0,:,:] # Record the sequence of inputs to the network for this game, for later use in training 110 | 111 | # Accumulate performance statistics for the REINFORCE algorithm for each game 112 | winloss[runidx] = 2*out[0,correctidx]-1 113 | 114 | # Log loss of the final guess 115 | avgerr -= log(out[0,correctidx]+1e-16)/float(BATCHSIZE) 116 | 117 | avgwin = winloss.mean() 118 | wlstd = winloss.std() + 1e-3 119 | 120 | # Log the reinforcement learning parameters 121 | f = open("rlparams.txt","a") 122 | f.write("%d %.6g, %.6g\n" % (epoch, avgwin, wlstd)) 123 | f.close() 124 | 125 | # This is the weight applied to reinforcement updates for each game 126 | winloss = (winloss - avgwin)/wlstd 127 | 128 | # Log a single example play sequence 129 | f = open("examples.txt","wb") 130 | for runidx in range(BATCHSIZE): 131 | f.write("%d " % gametargets[runidx]) 132 | for j in range(ROUNDS): 133 | f.write(" %d" % guesses[runidx,j]) 134 | f.write("\n") 135 | f.close() 136 | 137 | # Make a plot of this batch of games 138 | plt.imshow(np.hstack([gametargets.reshape( (BATCHSIZE, 1) ), guesses]),interpolation='nearest') 139 | plt.axes().set_aspect(1.0/128) 140 | plt.savefig("training/%.6d.png" % epoch) 141 | plt.clf() 142 | 143 | net["train"]( gameinputs, gametargets ) # Supervised training on the final round 144 | 145 | # Pick a random round during the game, and reinforce based on how it turned out 146 | movepoint = np.random.randint(ROUNDS) 147 | net["reinforce"]( gameinputs[:,0:(movepoint+1),:], guesses[:,movepoint].astype(np.int32), winloss.astype(np.float32) ) 148 | 149 | return avgerr, turns/float(BATCHSIZE) 150 | 151 | epoch = 0 152 | while epoch<=6000: 153 | err, err2 = runGame(net,epoch) 154 | 155 | # Log the error 156 | f = open("error.txt","a") 157 | f.write("%d %.6g %.6g\n" % (epoch, err, err2)) # Epoch, log-loss, average number of turns till correct guess 158 | f.close() 159 | 160 | epoch += 1 161 | --------------------------------------------------------------------------------