├── .gitignore
├── README.md
├── __pycache__
│   ├── tf_mnist_loader.cpython-34.pyc
│   └── tf_mnist_loader.cpython-35.pyc
├── ram.py
├── ram.py.bak
├── ram_modified.py
├── report.txt
├── tf_mnist_loader.py
├── tf_mnist_loader.pyc
└── tf_upgrade.py

/.gitignore:
--------------------------------------------------------------------------------
1 | demo/
2 | mnist_data/
3 | summary/
4 | chckPts/
5 | oldScripts/
6 | __pycache__/
7 |
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## TensorFlow Implementation of the Recurrent Attention Model (RAM)
2 |
3 |
4 | ## Author
5 |
6 | Juntae Kim, jtkim@kaist.ac.kr
7 |
8 | ## Requirements
9 |
10 | TensorFlow 1.1.0-rc0
11 |
12 | ## Description
13 |
14 | Code: `ram_modified.py`
15 |
16 | This project is a modified version of https://github.com/jlindsey15/RAM.
17 | The critical problem with that implementation is that the location network cannot learn, because of how `tf.stop_gradient` is used, so it reaches only about 94% accuracy. That is relatively poor compared to the result reported in the paper.
18 | If the `tf.stop_gradient` call is commented out, the classification result becomes very bad.
19 | I think the problem originates from sharing the gradient flow through the location, core, and glimpse networks.
20 | With gradient sharing, the gradients of the classification part are corrupted by the gradients of the reinforcement part, so the classification result
21 | becomes very bad. (If you want to share gradients, a weighted loss is needed; please refer to https://arxiv.org/pdf/1412.7755.pdf.)
22 | In their follow-up research, 'Multiple Object Recognition with Visual Attention' (https://arxiv.org/pdf/1412.7755.pdf), they
23 | softly separate the location network from the others through a multi-layer RNN. From this, I assume that sharing the gradient through the whole network
24 | is not a good idea, so I separated them and finally got a good result.
25 | In summary, the learning strategy is as follows (a minimal sketch appears after the Result section below).
26 |
27 | 1. Location network, baseline network: learn from the gradients of reinforcement learning only.
28 |
29 | 2. Glimpse network, core network: learn from the gradients of supervised learning only.
30 |
31 | Thank you!
32 |
33 | ## Result
34 |
35 | After 600,000 training iterations, I got about 98% accuracy.
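## Gradient-separation sketch

The strategy above can be illustrated with a minimal, self-contained TensorFlow 1.x sketch. The names below (`h`, `W_loc`, `W_base`, `W_act`) are hypothetical and do not match the variables in `ram_modified.py`; the point is only where `tf.stop_gradient` sits.

```python
# Illustrative sketch only: h, W_loc, W_base, W_act are hypothetical names,
# not the exact variables used in ram_modified.py.
import tensorflow as tf

batch_size, cell_size, n_classes = 20, 256, 10

# core-network state for the final glimpse, and the one-hot labels
h = tf.placeholder(tf.float32, [batch_size, cell_size])
onehot_labels = tf.placeholder(tf.float32, [batch_size, n_classes])

W_loc  = tf.Variable(tf.random_uniform([cell_size, 2], -0.1, 0.1))
W_base = tf.Variable(tf.random_uniform([cell_size, 1], -0.1, 0.1))
W_act  = tf.Variable(tf.random_uniform([cell_size, n_classes], -0.1, 0.1))

# 1. Location / baseline networks read the state through tf.stop_gradient,
#    so the REINFORCE part of the loss cannot push gradients back into the
#    core / glimpse networks.
h_rl     = tf.stop_gradient(h)
mean_loc = tf.clip_by_value(tf.matmul(h_rl, W_loc), -1.0, 1.0)
baseline = tf.sigmoid(tf.matmul(h_rl, W_base))

# 2. The classifier reads the state directly, so only the supervised
#    cross-entropy propagates through the core / glimpse networks.
logits = tf.matmul(h, W_act)
supervised_loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=onehot_labels, logits=logits))
```

In `ram_modified.py` the same separation comes from applying `tf.stop_gradient` to the core-network output inside `get_next_input` (and to the last hidden state for the final baseline), so the REINFORCE term `log(p_loc) * (R - b)` of the combined cost updates only the location and baseline weights, while the cross-entropy term updates the glimpse, core, and action weights.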
36 |
37 | ## Reference
38 |
39 | Recurrent Models of Visual Attention
40 |
41 | http://papers.nips.cc/paper/5542-recurrent-models-of-visual-attention.pdf
42 |
43 | https://arxiv.org/pdf/1412.7755.pdf
44 |
--------------------------------------------------------------------------------
/__pycache__/tf_mnist_loader.cpython-34.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jtkim-kaist/ram_modified/e77fd42c3ad19b8d19d5e3a17d4ac3321f66ac90/__pycache__/tf_mnist_loader.cpython-34.pyc
--------------------------------------------------------------------------------
/__pycache__/tf_mnist_loader.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jtkim-kaist/ram_modified/e77fd42c3ad19b8d19d5e3a17d4ac3321f66ac90/__pycache__/tf_mnist_loader.cpython-35.pyc
--------------------------------------------------------------------------------
/ram.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tf_mnist_loader
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import time
6 | import random
7 | import sys
8 | import os
9 |
10 | try:
11 |     xrange
12 | except NameError:
13 |     xrange = range
14 |
15 | dataset = tf_mnist_loader.read_data_sets("mnist_data")
16 | save_dir = "chckPts/"
17 | save_prefix = "save"
18 | summaryFolderName = "summary/"
19 |
20 |
21 | if len(sys.argv) == 2:
22 |     simulationName = str(sys.argv[1])
23 |     print("Simulation name = " + simulationName)
24 |     summaryFolderName = summaryFolderName + simulationName + "/"
25 |     saveImgs = True
26 |     imgsFolderName = "imgs/" + simulationName + "/"
27 |     if os.path.isdir(summaryFolderName) == False:
28 |         os.mkdir(summaryFolderName)
29 | #     if os.path.isdir(imgsFolderName) == False:
30 | #         os.mkdir(imgsFolderName)
31 | else:
32 |     saveImgs = False
33 |     print("Testing... 
image files will not be saved.")
34 |
35 |
36 | start_step = 0
37 | #load_path = None
38 | load_path = save_dir + save_prefix + str(start_step) + ".ckpt"
39 | # to enable visualization, set draw to True
40 | eval_only = False
41 | draw = 0
42 | animate = 0
43 |
44 | # conditions
45 | translateMnist = 1
46 | eyeCentered = 0
47 |
48 | preTraining = 1
49 | preTraining_epoch = 20000
50 | drawReconsturction = 1
51 |
52 | # about translation
53 | MNIST_SIZE = 28
54 | translated_img_size = 60             # side length of the picture
55 |
56 | if translateMnist:
57 |     print("TRANSLATED MNIST")
58 |     img_size = translated_img_size
59 |     depth = 3  # number of zooms
60 |     sensorBandwidth = 12
61 |     minRadius = 6  # zooms -> minRadius * 2**<depth_level>
62 |
63 |     initLr = 5e-3
64 |     lrDecayRate = .995
65 |     lrDecayFreq = 500
66 |     momentumValue = .9
67 |     batch_size = 20
68 |
69 | else:
70 |     print("CENTERED MNIST")
71 |     img_size = MNIST_SIZE
72 |     depth = 1  # number of zooms
73 |     sensorBandwidth = 8
74 |     minRadius = 4  # zooms -> minRadius * 2**<depth_level>
75 |
76 |     initLr = 5e-3
77 |     lrDecayRate = .99
78 |     lrDecayFreq = 200
79 |     momentumValue = .9
80 |     batch_size = 20
81 |
82 |
83 | # model parameters
84 | channels = 1                # MNIST images are grayscale
85 | totalSensorBandwidth = depth * channels * (sensorBandwidth **2)
86 | nGlimpses = 6               # number of glimpses
87 | loc_sd = 0.11               # std when sampling the location
88 |
89 | # network units
90 | hg_size = 128               #
91 | hl_size = 128               #
92 | g_size = 256                #
93 | cell_size = 256             #
94 | cell_out_size = cell_size   #
95 |
96 | # parameters about the training examples
97 | n_classes = 10              # card(Y)
98 |
99 | # training parameters
100 | max_iters = 1000000
101 | SMALL_NUM = 1e-10
102 |
103 | # resource preallocation
104 | mean_locs = []              # expectation of locations
105 | sampled_locs = []           # sampled locations ~N(mean_locs[.], loc_sd)
106 | baselines = []              # baseline, the value prediction
107 | glimpse_images = []         # to show in window
108 |
109 |
110 | # set the weights to be small random values, drawn from a uniform distribution
111 | def weight_variable(shape, myname, train):
112 |     initial = tf.random_uniform(shape, minval=-0.1, maxval = 0.1)
113 |     return tf.Variable(initial, name=myname, trainable=train)
114 |
115 | # get local glimpses
116 | def glimpseSensor(img, normLoc):
117 |     loc = tf.round(((normLoc + 1) / 2.0) * img_size)  # normLoc coordinates are between -1 and 1
118 |     loc = tf.cast(loc, tf.int32)
119 |
120 |     img = tf.reshape(img, (batch_size, img_size, img_size, channels))
121 |
122 |     # process each image individually
123 |     zooms = []
124 |     for k in range(batch_size):
125 |         imgZooms = []
126 |         one_img = img[k,:,:,:]
127 |         max_radius = minRadius * (2 ** (depth - 1))
128 |         offset = 2 * max_radius
129 |
130 |         # pad image with zeros
131 |         one_img = tf.image.pad_to_bounding_box(one_img, offset, offset, \
132 |             max_radius * 4 + img_size, max_radius * 4 + img_size)
133 |
134 |         for i in range(depth):
135 |             r = int(minRadius * (2 ** (i)))
136 |
137 |             d_raw = 2 * r
138 |             d = tf.constant(d_raw, shape=[1])
139 |             d = tf.tile(d, [2])
140 |             loc_k = loc[k,:]
141 |             adjusted_loc = offset + loc_k - r
142 |             one_img2 = tf.reshape(one_img, (one_img.get_shape()[0].value, one_img.get_shape()[1].value))
143 |
144 |             # crop image to (d x d)
145 |             zoom = tf.slice(one_img2, adjusted_loc, d)
146 |
147 |             # resize cropped image to (sensorBandwidth x sensorBandwidth)
148 |             zoom = tf.image.resize_bilinear(tf.reshape(zoom, (1, d_raw, d_raw, 1)), (sensorBandwidth, sensorBandwidth))
149 |             zoom = tf.reshape(zoom, (sensorBandwidth, sensorBandwidth))
150 | 
imgZooms.append(zoom) 151 | 152 | zooms.append(tf.pack(imgZooms)) 153 | 154 | zooms = tf.pack(zooms) 155 | 156 | glimpse_images.append(zooms) 157 | 158 | return zooms 159 | 160 | # implements the input network 161 | def get_glimpse(loc): 162 | # get input using the previous location 163 | glimpse_input = glimpseSensor(inputs_placeholder, loc) 164 | glimpse_input = tf.reshape(glimpse_input, (batch_size, totalSensorBandwidth)) 165 | 166 | # the hidden units that process location & the input 167 | act_glimpse_hidden = tf.nn.relu(tf.matmul(glimpse_input, Wg_g_h) + Bg_g_h) 168 | act_loc_hidden = tf.nn.relu(tf.matmul(loc, Wg_l_h) + Bg_l_h) 169 | 170 | # the hidden units that integrates the location & the glimpses 171 | glimpseFeature1 = tf.nn.relu(tf.matmul(act_glimpse_hidden, Wg_hg_gf1) + tf.matmul(act_loc_hidden, Wg_hl_gf1) + Bg_hlhg_gf1) 172 | # return g 173 | # glimpseFeature2 = tf.matmul(glimpseFeature1, Wg_gf1_gf2) + Bg_gf1_gf2 174 | return glimpseFeature1 175 | 176 | 177 | def get_next_input(output): 178 | # the next location is computed by the location network 179 | baseline = tf.sigmoid(tf.matmul(output,Wb_h_b) + Bb_h_b) 180 | baselines.append(baseline) 181 | # compute the next location, then impose noise 182 | if eyeCentered: 183 | # add the last sampled glimpse location 184 | # TODO max(-1, min(1, u + N(output, sigma) + prevLoc)) 185 | mean_loc = tf.maximum(-1.0, tf.minimum(1.0, tf.matmul(output, Wl_h_l) + sampled_locs[-1] )) 186 | else: 187 | mean_loc = tf.matmul(output, Wl_h_l) 188 | 189 | mean_loc = tf.stop_gradient(mean_loc) 190 | mean_locs.append(mean_loc) 191 | 192 | # add noise 193 | # sample_loc = tf.tanh(mean_loc + tf.random_normal(mean_loc.get_shape(), 0, loc_sd)) 194 | sample_loc = tf.maximum(-1.0, tf.minimum(1.0, mean_loc + tf.random_normal(mean_loc.get_shape(), 0, loc_sd))) 195 | 196 | # don't propagate throught the locations 197 | sample_loc = tf.stop_gradient(sample_loc) 198 | sampled_locs.append(sample_loc) 199 | 200 | return get_glimpse(sample_loc) 201 | 202 | 203 | def affineTransform(x,output_dim): 204 | """ 205 | affine transformation Wx+b 206 | assumes x.shape = (batch_size, num_features) 207 | """ 208 | w=tf.get_variable("w", [x.get_shape()[1], output_dim]) 209 | b=tf.get_variable("b", [output_dim], initializer=tf.constant_initializer(0.0)) 210 | return tf.matmul(x,w)+b 211 | 212 | 213 | def model(): 214 | # initialize the location under unif[-1,1], for all example in the batch 215 | initial_loc = tf.random_uniform((batch_size, 2), minval=-1, maxval=1) 216 | mean_locs.append(initial_loc) 217 | initial_loc = tf.tanh(initial_loc + tf.random_normal(initial_loc.get_shape(), 0, loc_sd)) 218 | sampled_locs.append(initial_loc) 219 | 220 | # get the input using the input network 221 | initial_glimpse = get_glimpse(initial_loc) 222 | 223 | # set up the recurrent structure 224 | inputs = [0] * nGlimpses 225 | outputs = [0] * nGlimpses 226 | glimpse = initial_glimpse 227 | REUSE = None 228 | for t in range(nGlimpses): 229 | if t == 0: # initialize the hidden state to be the zero vector 230 | hiddenState_prev = tf.zeros((batch_size, cell_size)) 231 | else: 232 | hiddenState_prev = outputs[t-1] 233 | 234 | # forward prop 235 | with tf.variable_scope("coreNetwork", reuse=REUSE): 236 | # the next hidden state is a function of the previous hidden state and the current glimpse 237 | hiddenState = tf.nn.relu(affineTransform(hiddenState_prev, cell_size) + (tf.matmul(glimpse, Wc_g_h) + Bc_g_h)) 238 | 239 | # save the current glimpse and the hidden state 240 | inputs[t] = glimpse 241 
| outputs[t] = hiddenState 242 | # get the next input glimpse 243 | if t != nGlimpses -1: 244 | glimpse = get_next_input(hiddenState) 245 | else: 246 | baseline = tf.sigmoid(tf.matmul(hiddenState, Wb_h_b) + Bb_h_b) 247 | baselines.append(baseline) 248 | REUSE = True # share variables for later recurrence 249 | 250 | return outputs 251 | 252 | 253 | def dense_to_one_hot(labels_dense, num_classes=10): 254 | """Convert class labels from scalars to one-hot vectors.""" 255 | # copied from TensorFlow tutorial 256 | num_labels = labels_dense.shape[0] 257 | index_offset = np.arange(num_labels) * num_classes 258 | labels_one_hot = np.zeros((num_labels, num_classes)) 259 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 260 | return labels_one_hot 261 | 262 | 263 | # to use for maximum likelihood with input location 264 | def gaussian_pdf(mean, sample): 265 | Z = 1.0 / (loc_sd * tf.sqrt(2.0 * np.pi)) 266 | a = -tf.square(sample - mean) / (2.0 * tf.square(loc_sd)) 267 | return Z * tf.exp(a) 268 | 269 | 270 | def calc_reward(outputs): 271 | 272 | # consider the action at the last time step 273 | outputs = outputs[-1] # look at ONLY THE END of the sequence 274 | outputs = tf.reshape(outputs, (batch_size, cell_out_size)) 275 | 276 | # get the baseline 277 | b = tf.pack(baselines) 278 | b = tf.concat(2, [b, b]) 279 | b = tf.reshape(b, (batch_size, (nGlimpses) * 2)) 280 | no_grad_b = tf.stop_gradient(b) 281 | 282 | # get the action(classification) 283 | p_y = tf.nn.softmax(tf.matmul(outputs, Wa_h_a) + Ba_h_a) 284 | max_p_y = tf.arg_max(p_y, 1) 285 | correct_y = tf.cast(labels_placeholder, tf.int64) 286 | 287 | # reward for all examples in the batch 288 | R = tf.cast(tf.equal(max_p_y, correct_y), tf.float32) 289 | reward = tf.reduce_mean(R) # mean reward 290 | R = tf.reshape(R, (batch_size, 1)) 291 | R = tf.tile(R, [1, (nGlimpses)*2]) 292 | 293 | # get the location 294 | p_loc = gaussian_pdf(mean_locs, sampled_locs) 295 | p_loc = tf.tanh(p_loc) 296 | p_loc_orig = p_loc 297 | p_loc = tf.reshape(p_loc, (batch_size, (nGlimpses) * 2)) 298 | 299 | # define the cost function 300 | J = tf.concat(1, [tf.log(p_y + SMALL_NUM) * (onehot_labels_placeholder), tf.log(p_loc + SMALL_NUM) * (R - no_grad_b)]) 301 | J = tf.reduce_sum(J, 1) 302 | J = J - tf.reduce_sum(tf.square(R - b), 1) 303 | J = tf.reduce_mean(J, 0) 304 | cost = -J 305 | 306 | # define the optimizer 307 | optimizer = tf.train.MomentumOptimizer(lr, momentumValue) 308 | train_op = optimizer.minimize(cost, global_step) 309 | 310 | return cost, reward, max_p_y, correct_y, train_op, b, tf.reduce_mean(b), tf.reduce_mean(R - b), lr 311 | 312 | 313 | def preTrain(outputs): 314 | lr_r = 1e-3 315 | # consider the action at the last time step 316 | outputs = outputs[-1] # look at ONLY THE END of the sequence 317 | outputs = tf.reshape(outputs, (batch_size, cell_out_size)) 318 | # if preTraining: 319 | reconstruction = tf.sigmoid(tf.matmul(outputs, Wr_h_r) + Br_h_r) 320 | reconstructionCost = tf.reduce_mean(tf.square(inputs_placeholder - reconstruction)) 321 | 322 | train_op_r = tf.train.RMSPropOptimizer(lr_r).minimize(reconstructionCost) 323 | return reconstructionCost, reconstruction, train_op_r 324 | 325 | 326 | 327 | def evaluate(): 328 | data = dataset.test 329 | batches_in_epoch = len(data._images) // batch_size 330 | accuracy = 0 331 | 332 | for i in range(batches_in_epoch): 333 | nextX, nextY = dataset.test.next_batch(batch_size) 334 | if translateMnist: 335 | nextX, _ = convertTranslated(nextX, MNIST_SIZE, img_size) 336 | feed_dict = 
{inputs_placeholder: nextX, labels_placeholder: nextY, 337 | onehot_labels_placeholder: dense_to_one_hot(nextY)} 338 | r = sess.run(reward, feed_dict=feed_dict) 339 | accuracy += r 340 | 341 | accuracy /= batches_in_epoch 342 | print(("ACCURACY: " + str(accuracy))) 343 | 344 | 345 | def convertTranslated(images, initImgSize, finalImgSize): 346 | size_diff = finalImgSize - initImgSize 347 | newimages = np.zeros([batch_size, finalImgSize*finalImgSize]) 348 | imgCoord = np.zeros([batch_size,2]) 349 | for k in range(batch_size): 350 | image = images[k, :] 351 | image = np.reshape(image, (initImgSize, initImgSize)) 352 | # generate and save random coordinates 353 | randX = random.randint(0, size_diff) 354 | randY = random.randint(0, size_diff) 355 | imgCoord[k,:] = np.array([randX, randY]) 356 | # padding 357 | image = np.lib.pad(image, ((randX, size_diff - randX), (randY, size_diff - randY)), 'constant', constant_values = (0)) 358 | newimages[k, :] = np.reshape(image, (finalImgSize*finalImgSize)) 359 | 360 | return newimages, imgCoord 361 | 362 | 363 | 364 | def toMnistCoordinates(coordinate_tanh): 365 | ''' 366 | Transform coordinate in [-1,1] to mnist 367 | :param coordinate_tanh: vector in [-1,1] x [-1,1] 368 | :return: vector in the corresponding mnist coordinate 369 | ''' 370 | return np.round(((coordinate_tanh + 1) / 2.0) * img_size) 371 | 372 | 373 | def variable_summaries(var, name): 374 | """Attach a lot of summaries to a Tensor.""" 375 | with tf.name_scope('param_summaries'): 376 | mean = tf.reduce_mean(var) 377 | tf.scalar_summary('param_mean/' + name, mean) 378 | with tf.name_scope('param_stddev'): 379 | stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean))) 380 | tf.scalar_summary('param_sttdev/' + name, stddev) 381 | tf.scalar_summary('param_max/' + name, tf.reduce_max(var)) 382 | tf.scalar_summary('param_min/' + name, tf.reduce_min(var)) 383 | tf.histogram_summary(name, var) 384 | 385 | 386 | def plotWholeImg(img, img_size, sampled_locs_fetched): 387 | plt.imshow(np.reshape(img, [img_size, img_size]), 388 | cmap=plt.get_cmap('gray'), interpolation="nearest") 389 | 390 | plt.ylim((img_size - 1, 0)) 391 | plt.xlim((0, img_size - 1)) 392 | 393 | # transform the coordinate to mnist map 394 | sampled_locs_mnist_fetched = toMnistCoordinates(sampled_locs_fetched) 395 | # visualize the trace of successive nGlimpses (note that x and y coordinates are "flipped") 396 | plt.plot(sampled_locs_mnist_fetched[0, :, 1], sampled_locs_mnist_fetched[0, :, 0], '-o', 397 | color='lawngreen') 398 | plt.plot(sampled_locs_mnist_fetched[0, -1, 1], sampled_locs_mnist_fetched[0, -1, 0], 'o', 399 | color='red') 400 | 401 | 402 | 403 | with tf.Graph().as_default(): 404 | 405 | # set the learning rate 406 | global_step = tf.Variable(0, trainable=False) 407 | lr = tf.train.exponential_decay(initLr, global_step, lrDecayFreq, lrDecayRate, staircase=True) 408 | 409 | # preallocate x, y, baseline 410 | labels = tf.placeholder("float32", shape=[batch_size, n_classes]) 411 | labels_placeholder = tf.placeholder(tf.float32, shape=(batch_size), name="labels_raw") 412 | onehot_labels_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 10), name="labels_onehot") 413 | inputs_placeholder = tf.placeholder(tf.float32, shape=(batch_size, img_size * img_size), name="images") 414 | 415 | # declare the model parameters, here're naming rule: 416 | # the 1st captical letter: weights or bias (W = weights, B = bias) 417 | # the 2nd lowercase letter: the network (e.g.: g = glimpse network) 418 | # the 3rd and 4th letter(s): 
input-output mapping, which is clearly written in the variable name argument 419 | 420 | Wg_l_h = weight_variable((2, hl_size), "glimpseNet_wts_location_hidden", True) 421 | Bg_l_h = weight_variable((1,hl_size), "glimpseNet_bias_location_hidden", True) 422 | 423 | Wg_g_h = weight_variable((totalSensorBandwidth, hg_size), "glimpseNet_wts_glimpse_hidden", True) 424 | Bg_g_h = weight_variable((1,hg_size), "glimpseNet_bias_glimpse_hidden", True) 425 | 426 | Wg_hg_gf1 = weight_variable((hg_size, g_size), "glimpseNet_wts_hiddenGlimpse_glimpseFeature1", True) 427 | Wg_hl_gf1 = weight_variable((hl_size, g_size), "glimpseNet_wts_hiddenLocation_glimpseFeature1", True) 428 | Bg_hlhg_gf1 = weight_variable((1,g_size), "glimpseNet_bias_hGlimpse_hLocs_glimpseFeature1", True) 429 | 430 | Wc_g_h = weight_variable((cell_size, g_size), "coreNet_wts_glimpse_hidden", True) 431 | Bc_g_h = weight_variable((1,g_size), "coreNet_bias_glimpse_hidden", True) 432 | 433 | Wr_h_r = weight_variable((cell_out_size, img_size**2), "reconstructionNet_wts_hidden_action", True) 434 | Br_h_r = weight_variable((1, img_size**2), "reconstructionNet_bias_hidden_action", True) 435 | 436 | Wb_h_b = weight_variable((g_size, 1), "baselineNet_wts_hiddenState_baseline", True) 437 | Bb_h_b = weight_variable((1,1), "baselineNet_bias_hiddenState_baseline", True) 438 | 439 | Wl_h_l = weight_variable((cell_out_size, 2), "locationNet_wts_hidden_location", True) 440 | 441 | Wa_h_a = weight_variable((cell_out_size, n_classes), "actionNet_wts_hidden_action", True) 442 | Ba_h_a = weight_variable((1,n_classes), "actionNet_bias_hidden_action", True) 443 | 444 | # query the model ouput 445 | outputs = model() 446 | 447 | # convert list of tensors to one big tensor 448 | sampled_locs = tf.concat(0, sampled_locs) 449 | sampled_locs = tf.reshape(sampled_locs, (nGlimpses, batch_size, 2)) 450 | sampled_locs = tf.transpose(sampled_locs, [1, 0, 2]) 451 | mean_locs = tf.concat(0, mean_locs) 452 | mean_locs = tf.reshape(mean_locs, (nGlimpses, batch_size, 2)) 453 | mean_locs = tf.transpose(mean_locs, [1, 0, 2]) 454 | glimpse_images = tf.concat(0, glimpse_images) 455 | 456 | 457 | 458 | # compute the reward 459 | reconstructionCost, reconstruction, train_op_r = preTrain(outputs) 460 | cost, reward, predicted_labels, correct_labels, train_op, b, avg_b, rminusb, lr = calc_reward(outputs) 461 | 462 | # tensorboard visualization for the parameters 463 | variable_summaries(Wg_l_h, "glimpseNet_wts_location_hidden") 464 | variable_summaries(Bg_l_h, "glimpseNet_bias_location_hidden") 465 | variable_summaries(Wg_g_h, "glimpseNet_wts_glimpse_hidden") 466 | variable_summaries(Bg_g_h, "glimpseNet_bias_glimpse_hidden") 467 | variable_summaries(Wg_hg_gf1, "glimpseNet_wts_hiddenGlimpse_glimpseFeature1") 468 | variable_summaries(Wg_hl_gf1, "glimpseNet_wts_hiddenLocation_glimpseFeature1") 469 | variable_summaries(Bg_hlhg_gf1, "glimpseNet_bias_hGlimpse_hLocs_glimpseFeature1") 470 | 471 | variable_summaries(Wc_g_h, "coreNet_wts_glimpse_hidden") 472 | variable_summaries(Bc_g_h, "coreNet_bias_glimpse_hidden") 473 | 474 | variable_summaries(Wb_h_b, "baselineNet_wts_hiddenState_baseline") 475 | variable_summaries(Bb_h_b, "baselineNet_bias_hiddenState_baseline") 476 | 477 | variable_summaries(Wl_h_l, "locationNet_wts_hidden_location") 478 | 479 | variable_summaries(Wa_h_a, 'actionNet_wts_hidden_action') 480 | variable_summaries(Ba_h_a, 'actionNet_bias_hidden_action') 481 | 482 | # tensorboard visualization for the performance metrics 483 | tf.scalar_summary("reconstructionCost", 
reconstructionCost) 484 | tf.scalar_summary("reward", reward) 485 | tf.scalar_summary("cost", cost) 486 | tf.scalar_summary("mean(b)", avg_b) 487 | tf.scalar_summary(" mean(R - b)", rminusb) 488 | summary_op = tf.merge_all_summaries() 489 | 490 | 491 | ####################################### START RUNNING THE MODEL ####################################### 492 | sess = tf.Session() 493 | saver = tf.train.Saver() 494 | b_fetched = np.zeros((batch_size, (nGlimpses)*2)) 495 | 496 | init = tf.initialize_all_variables() 497 | sess.run(init) 498 | 499 | if eval_only: 500 | evaluate() 501 | else: 502 | summary_writer = tf.train.SummaryWriter(summaryFolderName, graph=sess.graph) 503 | 504 | if draw: 505 | fig = plt.figure(1) 506 | txt = fig.suptitle("-", fontsize=36, fontweight='bold') 507 | plt.ion() 508 | plt.show() 509 | plt.subplots_adjust(top=0.7) 510 | plotImgs = [] 511 | 512 | if drawReconsturction: 513 | fig = plt.figure(2) 514 | txt = fig.suptitle("-", fontsize=36, fontweight='bold') 515 | plt.ion() 516 | plt.show() 517 | 518 | if preTraining: 519 | for epoch_r in range(1,preTraining_epoch): 520 | nextX, _ = dataset.train.next_batch(batch_size) 521 | nextX_orig = nextX 522 | if translateMnist: 523 | nextX, _ = convertTranslated(nextX, MNIST_SIZE, img_size) 524 | 525 | fetches_r = [reconstructionCost, reconstruction, train_op_r] 526 | 527 | reconstructionCost_fetched, reconstruction_fetched, train_op_r_fetched = sess.run(fetches_r, feed_dict={inputs_placeholder: nextX}) 528 | 529 | if epoch_r % 20 == 0: 530 | print(('Step %d: reconstructionCost = %.5f' % (epoch_r, reconstructionCost_fetched))) 531 | if epoch_r % 100 == 0: 532 | if drawReconsturction: 533 | fig = plt.figure(2) 534 | 535 | plt.subplot(1, 2, 1) 536 | plt.imshow(np.reshape(nextX[0, :], [img_size, img_size]), 537 | cmap=plt.get_cmap('gray'), interpolation="nearest") 538 | plt.ylim((img_size - 1, 0)) 539 | plt.xlim((0, img_size - 1)) 540 | 541 | plt.subplot(1, 2, 2) 542 | plt.imshow(np.reshape(reconstruction_fetched[0, :], [img_size, img_size]), 543 | cmap=plt.get_cmap('gray'), interpolation="nearest") 544 | plt.ylim((img_size - 1, 0)) 545 | plt.xlim((0, img_size - 1)) 546 | plt.draw() 547 | plt.pause(0.0001) 548 | # plt.show() 549 | 550 | 551 | # training 552 | for epoch in range(start_step + 1, max_iters): 553 | start_time = time.time() 554 | 555 | # get the next batch of examples 556 | nextX, nextY = dataset.train.next_batch(batch_size) 557 | nextX_orig = nextX 558 | if translateMnist: 559 | nextX, nextX_coord = convertTranslated(nextX, MNIST_SIZE, img_size) 560 | 561 | feed_dict = {inputs_placeholder: nextX, labels_placeholder: nextY, \ 562 | onehot_labels_placeholder: dense_to_one_hot(nextY)} 563 | 564 | fetches = [train_op, cost, reward, predicted_labels, correct_labels, glimpse_images, avg_b, rminusb, \ 565 | mean_locs, sampled_locs, lr] 566 | # feed them to the model 567 | results = sess.run(fetches, feed_dict=feed_dict) 568 | 569 | _, cost_fetched, reward_fetched, prediction_labels_fetched, correct_labels_fetched, glimpse_images_fetched, \ 570 | avg_b_fetched, rminusb_fetched, mean_locs_fetched, sampled_locs_fetched, lr_fetched = results 571 | 572 | 573 | duration = time.time() - start_time 574 | 575 | if epoch % 20 == 0: 576 | print(('Step %d: cost = %.5f reward = %.5f (%.3f sec) b = %.5f R-b = %.5f, LR = %.5f' 577 | % (epoch, cost_fetched, reward_fetched, duration, avg_b_fetched, rminusb_fetched, lr_fetched))) 578 | summary_str = sess.run(summary_op, feed_dict=feed_dict) 579 | summary_writer.add_summary(summary_str, 
epoch) 580 | # if saveImgs: 581 | # plt.savefig(imgsFolderName + simulationName + '_ep%.6d.png' % (epoch)) 582 | 583 | if epoch % 5000 == 0: 584 | saver.save(sess, save_dir + save_prefix + str(epoch) + ".ckpt") 585 | evaluate() 586 | 587 | ##### DRAW WINDOW ################ 588 | f_glimpse_images = np.reshape(glimpse_images_fetched, \ 589 | (nGlimpses, batch_size, depth, sensorBandwidth, sensorBandwidth)) 590 | 591 | if draw: 592 | if animate: 593 | fillList = False 594 | if len(plotImgs) == 0: 595 | fillList = True 596 | 597 | # display the first image in the in mini-batch 598 | nCols = depth+1 599 | plt.subplot2grid((depth, nCols), (0, 1), rowspan=depth, colspan=depth) 600 | # display the entire image 601 | plotWholeImg(nextX[0, :], img_size, sampled_locs_fetched) 602 | 603 | # display the glimpses 604 | for y in range(nGlimpses): 605 | txt.set_text('Epoch: %.6d \nPrediction: %i -- Truth: %i\nStep: %i/%i' 606 | % (epoch, prediction_labels_fetched[0], correct_labels_fetched[0], (y + 1), nGlimpses)) 607 | 608 | for x in range(depth): 609 | plt.subplot(depth, nCols, 1 + nCols * x) 610 | if fillList: 611 | plotImg = plt.imshow(f_glimpse_images[y, 0, x], cmap=plt.get_cmap('gray'), 612 | interpolation="nearest") 613 | plotImg.autoscale() 614 | plotImgs.append(plotImg) 615 | else: 616 | plotImgs[x].set_data(f_glimpse_images[y, 0, x]) 617 | plotImgs[x].autoscale() 618 | fillList = False 619 | 620 | # fig.canvas.draw() 621 | time.sleep(0.1) 622 | plt.pause(0.00005) 623 | 624 | else: 625 | txt.set_text('PREDICTION: %i\nTRUTH: %i' % (prediction_labels_fetched[0], correct_labels_fetched[0])) 626 | for x in range(depth): 627 | for y in range(nGlimpses): 628 | plt.subplot(depth, nGlimpses, x * nGlimpses + y + 1) 629 | plt.imshow(f_glimpse_images[y, 0, x], cmap=plt.get_cmap('gray'), interpolation="nearest") 630 | 631 | plt.draw() 632 | time.sleep(0.05) 633 | plt.pause(0.0001) 634 | 635 | sess.close() -------------------------------------------------------------------------------- /ram.py.bak: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tf_mnist_loader 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import time 6 | import random 7 | import sys 8 | import os 9 | 10 | try: 11 | xrange 12 | except NameError: 13 | xrange = range 14 | 15 | dataset = tf_mnist_loader.read_data_sets("mnist_data") 16 | save_dir = "chckPts/" 17 | save_prefix = "save" 18 | summaryFolderName = "summary/" 19 | 20 | 21 | if len(sys.argv) == 2: 22 | simulationName = str(sys.argv[1]) 23 | print "Simulation name = " + simulationName 24 | summaryFolderName = summaryFolderName + simulationName + "/" 25 | saveImgs = True 26 | imgsFolderName = "imgs/" + simulationName + "/" 27 | if os.path.isdir(summaryFolderName) == False: 28 | os.mkdir(summaryFolderName) 29 | # if os.path.isdir(imgsFolderName) == False: 30 | # os.mkdir(imgsFolderName) 31 | else: 32 | saveImgs = False 33 | print "Testing... image files will not be saved." 
34 | 35 | 36 | start_step = 0 37 | #load_path = None 38 | load_path = save_dir + save_prefix + str(start_step) + ".ckpt" 39 | # to enable visualization, set draw to True 40 | eval_only = False 41 | draw = 0 42 | animate = 0 43 | 44 | # conditions 45 | translateMnist = 1 46 | eyeCentered = 0 47 | 48 | preTraining = 1 49 | preTraining_epoch = 20000 50 | drawReconsturction = 1 51 | 52 | # about translation 53 | MNIST_SIZE = 28 54 | translated_img_size = 60 # side length of the picture 55 | 56 | if translateMnist: 57 | print "TRANSLATED MNIST" 58 | img_size = translated_img_size 59 | depth = 3 # number of zooms 60 | sensorBandwidth = 12 61 | minRadius = 6 # zooms -> minRadius * 2** 62 | 63 | initLr = 5e-3 64 | lrDecayRate = .995 65 | lrDecayFreq = 500 66 | momentumValue = .9 67 | batch_size = 20 68 | 69 | else: 70 | print "CENTERED MNIST" 71 | img_size = MNIST_SIZE 72 | depth = 1 # number of zooms 73 | sensorBandwidth = 8 74 | minRadius = 4 # zooms -> minRadius * 2** 75 | 76 | initLr = 5e-3 77 | lrDecayRate = .99 78 | lrDecayFreq = 200 79 | momentumValue = .9 80 | batch_size = 20 81 | 82 | 83 | # model parameters 84 | channels = 1 # mnist are grayscale images 85 | totalSensorBandwidth = depth * channels * (sensorBandwidth **2) 86 | nGlimpses = 6 # number of glimpses 87 | loc_sd = 0.11 # std when setting the location 88 | 89 | # network units 90 | hg_size = 128 # 91 | hl_size = 128 # 92 | g_size = 256 # 93 | cell_size = 256 # 94 | cell_out_size = cell_size # 95 | 96 | # paramters about the training examples 97 | n_classes = 10 # card(Y) 98 | 99 | # training parameters 100 | max_iters = 1000000 101 | SMALL_NUM = 1e-10 102 | 103 | # resource prellocation 104 | mean_locs = [] # expectation of locations 105 | sampled_locs = [] # sampled locations ~N(mean_locs[.], loc_sd) 106 | baselines = [] # baseline, the value prediction 107 | glimpse_images = [] # to show in window 108 | 109 | 110 | # set the weights to be small random values, with truncated normal distribution 111 | def weight_variable(shape, myname, train): 112 | initial = tf.random_uniform(shape, minval=-0.1, maxval = 0.1) 113 | return tf.Variable(initial, name=myname, trainable=train) 114 | 115 | # get local glimpses 116 | def glimpseSensor(img, normLoc): 117 | loc = tf.round(((normLoc + 1) / 2.0) * img_size) # normLoc coordinates are between -1 and 1 118 | loc = tf.cast(loc, tf.int32) 119 | 120 | img = tf.reshape(img, (batch_size, img_size, img_size, channels)) 121 | 122 | # process each image individually 123 | zooms = [] 124 | for k in xrange(batch_size): 125 | imgZooms = [] 126 | one_img = img[k,:,:,:] 127 | max_radius = minRadius * (2 ** (depth - 1)) 128 | offset = 2 * max_radius 129 | 130 | # pad image with zeros 131 | one_img = tf.image.pad_to_bounding_box(one_img, offset, offset, \ 132 | max_radius * 4 + img_size, max_radius * 4 + img_size) 133 | 134 | for i in xrange(depth): 135 | r = int(minRadius * (2 ** (i))) 136 | 137 | d_raw = 2 * r 138 | d = tf.constant(d_raw, shape=[1]) 139 | d = tf.tile(d, [2]) 140 | loc_k = loc[k,:] 141 | adjusted_loc = offset + loc_k - r 142 | one_img2 = tf.reshape(one_img, (one_img.get_shape()[0].value, one_img.get_shape()[1].value)) 143 | 144 | # crop image to (d x d) 145 | zoom = tf.slice(one_img2, adjusted_loc, d) 146 | 147 | # resize cropped image to (sensorBandwidth x sensorBandwidth) 148 | zoom = tf.image.resize_bilinear(tf.reshape(zoom, (1, d_raw, d_raw, 1)), (sensorBandwidth, sensorBandwidth)) 149 | zoom = tf.reshape(zoom, (sensorBandwidth, sensorBandwidth)) 150 | imgZooms.append(zoom) 151 | 152 
| zooms.append(tf.pack(imgZooms)) 153 | 154 | zooms = tf.pack(zooms) 155 | 156 | glimpse_images.append(zooms) 157 | 158 | return zooms 159 | 160 | # implements the input network 161 | def get_glimpse(loc): 162 | # get input using the previous location 163 | glimpse_input = glimpseSensor(inputs_placeholder, loc) 164 | glimpse_input = tf.reshape(glimpse_input, (batch_size, totalSensorBandwidth)) 165 | 166 | # the hidden units that process location & the input 167 | act_glimpse_hidden = tf.nn.relu(tf.matmul(glimpse_input, Wg_g_h) + Bg_g_h) 168 | act_loc_hidden = tf.nn.relu(tf.matmul(loc, Wg_l_h) + Bg_l_h) 169 | 170 | # the hidden units that integrates the location & the glimpses 171 | glimpseFeature1 = tf.nn.relu(tf.matmul(act_glimpse_hidden, Wg_hg_gf1) + tf.matmul(act_loc_hidden, Wg_hl_gf1) + Bg_hlhg_gf1) 172 | # return g 173 | # glimpseFeature2 = tf.matmul(glimpseFeature1, Wg_gf1_gf2) + Bg_gf1_gf2 174 | return glimpseFeature1 175 | 176 | 177 | def get_next_input(output): 178 | # the next location is computed by the location network 179 | baseline = tf.sigmoid(tf.matmul(output,Wb_h_b) + Bb_h_b) 180 | baselines.append(baseline) 181 | # compute the next location, then impose noise 182 | if eyeCentered: 183 | # add the last sampled glimpse location 184 | # TODO max(-1, min(1, u + N(output, sigma) + prevLoc)) 185 | mean_loc = tf.maximum(-1.0, tf.minimum(1.0, tf.matmul(output, Wl_h_l) + sampled_locs[-1] )) 186 | else: 187 | mean_loc = tf.matmul(output, Wl_h_l) 188 | 189 | mean_loc = tf.stop_gradient(mean_loc) 190 | mean_locs.append(mean_loc) 191 | 192 | # add noise 193 | # sample_loc = tf.tanh(mean_loc + tf.random_normal(mean_loc.get_shape(), 0, loc_sd)) 194 | sample_loc = tf.maximum(-1.0, tf.minimum(1.0, mean_loc + tf.random_normal(mean_loc.get_shape(), 0, loc_sd))) 195 | 196 | # don't propagate throught the locations 197 | sample_loc = tf.stop_gradient(sample_loc) 198 | sampled_locs.append(sample_loc) 199 | 200 | return get_glimpse(sample_loc) 201 | 202 | 203 | def affineTransform(x,output_dim): 204 | """ 205 | affine transformation Wx+b 206 | assumes x.shape = (batch_size, num_features) 207 | """ 208 | w=tf.get_variable("w", [x.get_shape()[1], output_dim]) 209 | b=tf.get_variable("b", [output_dim], initializer=tf.constant_initializer(0.0)) 210 | return tf.matmul(x,w)+b 211 | 212 | 213 | def model(): 214 | # initialize the location under unif[-1,1], for all example in the batch 215 | initial_loc = tf.random_uniform((batch_size, 2), minval=-1, maxval=1) 216 | mean_locs.append(initial_loc) 217 | initial_loc = tf.tanh(initial_loc + tf.random_normal(initial_loc.get_shape(), 0, loc_sd)) 218 | sampled_locs.append(initial_loc) 219 | 220 | # get the input using the input network 221 | initial_glimpse = get_glimpse(initial_loc) 222 | 223 | # set up the recurrent structure 224 | inputs = [0] * nGlimpses 225 | outputs = [0] * nGlimpses 226 | glimpse = initial_glimpse 227 | REUSE = None 228 | for t in range(nGlimpses): 229 | if t == 0: # initialize the hidden state to be the zero vector 230 | hiddenState_prev = tf.zeros((batch_size, cell_size)) 231 | else: 232 | hiddenState_prev = outputs[t-1] 233 | 234 | # forward prop 235 | with tf.variable_scope("coreNetwork", reuse=REUSE): 236 | # the next hidden state is a function of the previous hidden state and the current glimpse 237 | hiddenState = tf.nn.relu(affineTransform(hiddenState_prev, cell_size) + (tf.matmul(glimpse, Wc_g_h) + Bc_g_h)) 238 | 239 | # save the current glimpse and the hidden state 240 | inputs[t] = glimpse 241 | outputs[t] = hiddenState 242 
| # get the next input glimpse 243 | if t != nGlimpses -1: 244 | glimpse = get_next_input(hiddenState) 245 | else: 246 | baseline = tf.sigmoid(tf.matmul(hiddenState, Wb_h_b) + Bb_h_b) 247 | baselines.append(baseline) 248 | REUSE = True # share variables for later recurrence 249 | 250 | return outputs 251 | 252 | 253 | def dense_to_one_hot(labels_dense, num_classes=10): 254 | """Convert class labels from scalars to one-hot vectors.""" 255 | # copied from TensorFlow tutorial 256 | num_labels = labels_dense.shape[0] 257 | index_offset = np.arange(num_labels) * num_classes 258 | labels_one_hot = np.zeros((num_labels, num_classes)) 259 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 260 | return labels_one_hot 261 | 262 | 263 | # to use for maximum likelihood with input location 264 | def gaussian_pdf(mean, sample): 265 | Z = 1.0 / (loc_sd * tf.sqrt(2.0 * np.pi)) 266 | a = -tf.square(sample - mean) / (2.0 * tf.square(loc_sd)) 267 | return Z * tf.exp(a) 268 | 269 | 270 | def calc_reward(outputs): 271 | 272 | # consider the action at the last time step 273 | outputs = outputs[-1] # look at ONLY THE END of the sequence 274 | outputs = tf.reshape(outputs, (batch_size, cell_out_size)) 275 | 276 | # get the baseline 277 | b = tf.pack(baselines) 278 | b = tf.concat(2, [b, b]) 279 | b = tf.reshape(b, (batch_size, (nGlimpses) * 2)) 280 | no_grad_b = tf.stop_gradient(b) 281 | 282 | # get the action(classification) 283 | p_y = tf.nn.softmax(tf.matmul(outputs, Wa_h_a) + Ba_h_a) 284 | max_p_y = tf.arg_max(p_y, 1) 285 | correct_y = tf.cast(labels_placeholder, tf.int64) 286 | 287 | # reward for all examples in the batch 288 | R = tf.cast(tf.equal(max_p_y, correct_y), tf.float32) 289 | reward = tf.reduce_mean(R) # mean reward 290 | R = tf.reshape(R, (batch_size, 1)) 291 | R = tf.tile(R, [1, (nGlimpses)*2]) 292 | 293 | # get the location 294 | p_loc = gaussian_pdf(mean_locs, sampled_locs) 295 | p_loc = tf.tanh(p_loc) 296 | p_loc_orig = p_loc 297 | p_loc = tf.reshape(p_loc, (batch_size, (nGlimpses) * 2)) 298 | 299 | # define the cost function 300 | J = tf.concat(1, [tf.log(p_y + SMALL_NUM) * (onehot_labels_placeholder), tf.log(p_loc + SMALL_NUM) * (R - no_grad_b)]) 301 | J = tf.reduce_sum(J, 1) 302 | J = J - tf.reduce_sum(tf.square(R - b), 1) 303 | J = tf.reduce_mean(J, 0) 304 | cost = -J 305 | 306 | # define the optimizer 307 | optimizer = tf.train.MomentumOptimizer(lr, momentumValue) 308 | train_op = optimizer.minimize(cost, global_step) 309 | 310 | return cost, reward, max_p_y, correct_y, train_op, b, tf.reduce_mean(b), tf.reduce_mean(R - b), lr 311 | 312 | 313 | def preTrain(outputs): 314 | lr_r = 1e-3 315 | # consider the action at the last time step 316 | outputs = outputs[-1] # look at ONLY THE END of the sequence 317 | outputs = tf.reshape(outputs, (batch_size, cell_out_size)) 318 | # if preTraining: 319 | reconstruction = tf.sigmoid(tf.matmul(outputs, Wr_h_r) + Br_h_r) 320 | reconstructionCost = tf.reduce_mean(tf.square(inputs_placeholder - reconstruction)) 321 | 322 | train_op_r = tf.train.RMSPropOptimizer(lr_r).minimize(reconstructionCost) 323 | return reconstructionCost, reconstruction, train_op_r 324 | 325 | 326 | 327 | def evaluate(): 328 | data = dataset.test 329 | batches_in_epoch = len(data._images) // batch_size 330 | accuracy = 0 331 | 332 | for i in xrange(batches_in_epoch): 333 | nextX, nextY = dataset.test.next_batch(batch_size) 334 | if translateMnist: 335 | nextX, _ = convertTranslated(nextX, MNIST_SIZE, img_size) 336 | feed_dict = {inputs_placeholder: nextX, 
labels_placeholder: nextY, 337 | onehot_labels_placeholder: dense_to_one_hot(nextY)} 338 | r = sess.run(reward, feed_dict=feed_dict) 339 | accuracy += r 340 | 341 | accuracy /= batches_in_epoch 342 | print("ACCURACY: " + str(accuracy)) 343 | 344 | 345 | def convertTranslated(images, initImgSize, finalImgSize): 346 | size_diff = finalImgSize - initImgSize 347 | newimages = np.zeros([batch_size, finalImgSize*finalImgSize]) 348 | imgCoord = np.zeros([batch_size,2]) 349 | for k in xrange(batch_size): 350 | image = images[k, :] 351 | image = np.reshape(image, (initImgSize, initImgSize)) 352 | # generate and save random coordinates 353 | randX = random.randint(0, size_diff) 354 | randY = random.randint(0, size_diff) 355 | imgCoord[k,:] = np.array([randX, randY]) 356 | # padding 357 | image = np.lib.pad(image, ((randX, size_diff - randX), (randY, size_diff - randY)), 'constant', constant_values = (0)) 358 | newimages[k, :] = np.reshape(image, (finalImgSize*finalImgSize)) 359 | 360 | return newimages, imgCoord 361 | 362 | 363 | 364 | def toMnistCoordinates(coordinate_tanh): 365 | ''' 366 | Transform coordinate in [-1,1] to mnist 367 | :param coordinate_tanh: vector in [-1,1] x [-1,1] 368 | :return: vector in the corresponding mnist coordinate 369 | ''' 370 | return np.round(((coordinate_tanh + 1) / 2.0) * img_size) 371 | 372 | 373 | def variable_summaries(var, name): 374 | """Attach a lot of summaries to a Tensor.""" 375 | with tf.name_scope('param_summaries'): 376 | mean = tf.reduce_mean(var) 377 | tf.scalar_summary('param_mean/' + name, mean) 378 | with tf.name_scope('param_stddev'): 379 | stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean))) 380 | tf.scalar_summary('param_sttdev/' + name, stddev) 381 | tf.scalar_summary('param_max/' + name, tf.reduce_max(var)) 382 | tf.scalar_summary('param_min/' + name, tf.reduce_min(var)) 383 | tf.histogram_summary(name, var) 384 | 385 | 386 | def plotWholeImg(img, img_size, sampled_locs_fetched): 387 | plt.imshow(np.reshape(img, [img_size, img_size]), 388 | cmap=plt.get_cmap('gray'), interpolation="nearest") 389 | 390 | plt.ylim((img_size - 1, 0)) 391 | plt.xlim((0, img_size - 1)) 392 | 393 | # transform the coordinate to mnist map 394 | sampled_locs_mnist_fetched = toMnistCoordinates(sampled_locs_fetched) 395 | # visualize the trace of successive nGlimpses (note that x and y coordinates are "flipped") 396 | plt.plot(sampled_locs_mnist_fetched[0, :, 1], sampled_locs_mnist_fetched[0, :, 0], '-o', 397 | color='lawngreen') 398 | plt.plot(sampled_locs_mnist_fetched[0, -1, 1], sampled_locs_mnist_fetched[0, -1, 0], 'o', 399 | color='red') 400 | 401 | 402 | 403 | with tf.Graph().as_default(): 404 | 405 | # set the learning rate 406 | global_step = tf.Variable(0, trainable=False) 407 | lr = tf.train.exponential_decay(initLr, global_step, lrDecayFreq, lrDecayRate, staircase=True) 408 | 409 | # preallocate x, y, baseline 410 | labels = tf.placeholder("float32", shape=[batch_size, n_classes]) 411 | labels_placeholder = tf.placeholder(tf.float32, shape=(batch_size), name="labels_raw") 412 | onehot_labels_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 10), name="labels_onehot") 413 | inputs_placeholder = tf.placeholder(tf.float32, shape=(batch_size, img_size * img_size), name="images") 414 | 415 | # declare the model parameters, here're naming rule: 416 | # the 1st captical letter: weights or bias (W = weights, B = bias) 417 | # the 2nd lowercase letter: the network (e.g.: g = glimpse network) 418 | # the 3rd and 4th letter(s): input-output mapping, which 
is clearly written in the variable name argument 419 | 420 | Wg_l_h = weight_variable((2, hl_size), "glimpseNet_wts_location_hidden", True) 421 | Bg_l_h = weight_variable((1,hl_size), "glimpseNet_bias_location_hidden", True) 422 | 423 | Wg_g_h = weight_variable((totalSensorBandwidth, hg_size), "glimpseNet_wts_glimpse_hidden", True) 424 | Bg_g_h = weight_variable((1,hg_size), "glimpseNet_bias_glimpse_hidden", True) 425 | 426 | Wg_hg_gf1 = weight_variable((hg_size, g_size), "glimpseNet_wts_hiddenGlimpse_glimpseFeature1", True) 427 | Wg_hl_gf1 = weight_variable((hl_size, g_size), "glimpseNet_wts_hiddenLocation_glimpseFeature1", True) 428 | Bg_hlhg_gf1 = weight_variable((1,g_size), "glimpseNet_bias_hGlimpse_hLocs_glimpseFeature1", True) 429 | 430 | Wc_g_h = weight_variable((cell_size, g_size), "coreNet_wts_glimpse_hidden", True) 431 | Bc_g_h = weight_variable((1,g_size), "coreNet_bias_glimpse_hidden", True) 432 | 433 | Wr_h_r = weight_variable((cell_out_size, img_size**2), "reconstructionNet_wts_hidden_action", True) 434 | Br_h_r = weight_variable((1, img_size**2), "reconstructionNet_bias_hidden_action", True) 435 | 436 | Wb_h_b = weight_variable((g_size, 1), "baselineNet_wts_hiddenState_baseline", True) 437 | Bb_h_b = weight_variable((1,1), "baselineNet_bias_hiddenState_baseline", True) 438 | 439 | Wl_h_l = weight_variable((cell_out_size, 2), "locationNet_wts_hidden_location", True) 440 | 441 | Wa_h_a = weight_variable((cell_out_size, n_classes), "actionNet_wts_hidden_action", True) 442 | Ba_h_a = weight_variable((1,n_classes), "actionNet_bias_hidden_action", True) 443 | 444 | # query the model ouput 445 | outputs = model() 446 | 447 | # convert list of tensors to one big tensor 448 | sampled_locs = tf.concat(0, sampled_locs) 449 | sampled_locs = tf.reshape(sampled_locs, (nGlimpses, batch_size, 2)) 450 | sampled_locs = tf.transpose(sampled_locs, [1, 0, 2]) 451 | mean_locs = tf.concat(0, mean_locs) 452 | mean_locs = tf.reshape(mean_locs, (nGlimpses, batch_size, 2)) 453 | mean_locs = tf.transpose(mean_locs, [1, 0, 2]) 454 | glimpse_images = tf.concat(0, glimpse_images) 455 | 456 | 457 | 458 | # compute the reward 459 | reconstructionCost, reconstruction, train_op_r = preTrain(outputs) 460 | cost, reward, predicted_labels, correct_labels, train_op, b, avg_b, rminusb, lr = calc_reward(outputs) 461 | 462 | # tensorboard visualization for the parameters 463 | variable_summaries(Wg_l_h, "glimpseNet_wts_location_hidden") 464 | variable_summaries(Bg_l_h, "glimpseNet_bias_location_hidden") 465 | variable_summaries(Wg_g_h, "glimpseNet_wts_glimpse_hidden") 466 | variable_summaries(Bg_g_h, "glimpseNet_bias_glimpse_hidden") 467 | variable_summaries(Wg_hg_gf1, "glimpseNet_wts_hiddenGlimpse_glimpseFeature1") 468 | variable_summaries(Wg_hl_gf1, "glimpseNet_wts_hiddenLocation_glimpseFeature1") 469 | variable_summaries(Bg_hlhg_gf1, "glimpseNet_bias_hGlimpse_hLocs_glimpseFeature1") 470 | 471 | variable_summaries(Wc_g_h, "coreNet_wts_glimpse_hidden") 472 | variable_summaries(Bc_g_h, "coreNet_bias_glimpse_hidden") 473 | 474 | variable_summaries(Wb_h_b, "baselineNet_wts_hiddenState_baseline") 475 | variable_summaries(Bb_h_b, "baselineNet_bias_hiddenState_baseline") 476 | 477 | variable_summaries(Wl_h_l, "locationNet_wts_hidden_location") 478 | 479 | variable_summaries(Wa_h_a, 'actionNet_wts_hidden_action') 480 | variable_summaries(Ba_h_a, 'actionNet_bias_hidden_action') 481 | 482 | # tensorboard visualization for the performance metrics 483 | tf.scalar_summary("reconstructionCost", reconstructionCost) 484 | 
tf.scalar_summary("reward", reward) 485 | tf.scalar_summary("cost", cost) 486 | tf.scalar_summary("mean(b)", avg_b) 487 | tf.scalar_summary(" mean(R - b)", rminusb) 488 | summary_op = tf.merge_all_summaries() 489 | 490 | 491 | ####################################### START RUNNING THE MODEL ####################################### 492 | sess = tf.Session() 493 | saver = tf.train.Saver() 494 | b_fetched = np.zeros((batch_size, (nGlimpses)*2)) 495 | 496 | init = tf.initialize_all_variables() 497 | sess.run(init) 498 | 499 | if eval_only: 500 | evaluate() 501 | else: 502 | summary_writer = tf.train.SummaryWriter(summaryFolderName, graph=sess.graph) 503 | 504 | if draw: 505 | fig = plt.figure(1) 506 | txt = fig.suptitle("-", fontsize=36, fontweight='bold') 507 | plt.ion() 508 | plt.show() 509 | plt.subplots_adjust(top=0.7) 510 | plotImgs = [] 511 | 512 | if drawReconsturction: 513 | fig = plt.figure(2) 514 | txt = fig.suptitle("-", fontsize=36, fontweight='bold') 515 | plt.ion() 516 | plt.show() 517 | 518 | if preTraining: 519 | for epoch_r in xrange(1,preTraining_epoch): 520 | nextX, _ = dataset.train.next_batch(batch_size) 521 | nextX_orig = nextX 522 | if translateMnist: 523 | nextX, _ = convertTranslated(nextX, MNIST_SIZE, img_size) 524 | 525 | fetches_r = [reconstructionCost, reconstruction, train_op_r] 526 | 527 | reconstructionCost_fetched, reconstruction_fetched, train_op_r_fetched = sess.run(fetches_r, feed_dict={inputs_placeholder: nextX}) 528 | 529 | if epoch_r % 20 == 0: 530 | print('Step %d: reconstructionCost = %.5f' % (epoch_r, reconstructionCost_fetched)) 531 | if epoch_r % 100 == 0: 532 | if drawReconsturction: 533 | fig = plt.figure(2) 534 | 535 | plt.subplot(1, 2, 1) 536 | plt.imshow(np.reshape(nextX[0, :], [img_size, img_size]), 537 | cmap=plt.get_cmap('gray'), interpolation="nearest") 538 | plt.ylim((img_size - 1, 0)) 539 | plt.xlim((0, img_size - 1)) 540 | 541 | plt.subplot(1, 2, 2) 542 | plt.imshow(np.reshape(reconstruction_fetched[0, :], [img_size, img_size]), 543 | cmap=plt.get_cmap('gray'), interpolation="nearest") 544 | plt.ylim((img_size - 1, 0)) 545 | plt.xlim((0, img_size - 1)) 546 | plt.draw() 547 | plt.pause(0.0001) 548 | # plt.show() 549 | 550 | 551 | # training 552 | for epoch in xrange(start_step + 1, max_iters): 553 | start_time = time.time() 554 | 555 | # get the next batch of examples 556 | nextX, nextY = dataset.train.next_batch(batch_size) 557 | nextX_orig = nextX 558 | if translateMnist: 559 | nextX, nextX_coord = convertTranslated(nextX, MNIST_SIZE, img_size) 560 | 561 | feed_dict = {inputs_placeholder: nextX, labels_placeholder: nextY, \ 562 | onehot_labels_placeholder: dense_to_one_hot(nextY)} 563 | 564 | fetches = [train_op, cost, reward, predicted_labels, correct_labels, glimpse_images, avg_b, rminusb, \ 565 | mean_locs, sampled_locs, lr] 566 | # feed them to the model 567 | results = sess.run(fetches, feed_dict=feed_dict) 568 | 569 | _, cost_fetched, reward_fetched, prediction_labels_fetched, correct_labels_fetched, glimpse_images_fetched, \ 570 | avg_b_fetched, rminusb_fetched, mean_locs_fetched, sampled_locs_fetched, lr_fetched = results 571 | 572 | 573 | duration = time.time() - start_time 574 | 575 | if epoch % 20 == 0: 576 | print('Step %d: cost = %.5f reward = %.5f (%.3f sec) b = %.5f R-b = %.5f, LR = %.5f' 577 | % (epoch, cost_fetched, reward_fetched, duration, avg_b_fetched, rminusb_fetched, lr_fetched)) 578 | summary_str = sess.run(summary_op, feed_dict=feed_dict) 579 | summary_writer.add_summary(summary_str, epoch) 580 | # if saveImgs: 581 
| # plt.savefig(imgsFolderName + simulationName + '_ep%.6d.png' % (epoch)) 582 | 583 | if epoch % 5000 == 0: 584 | saver.save(sess, save_dir + save_prefix + str(epoch) + ".ckpt") 585 | evaluate() 586 | 587 | ##### DRAW WINDOW ################ 588 | f_glimpse_images = np.reshape(glimpse_images_fetched, \ 589 | (nGlimpses, batch_size, depth, sensorBandwidth, sensorBandwidth)) 590 | 591 | if draw: 592 | if animate: 593 | fillList = False 594 | if len(plotImgs) == 0: 595 | fillList = True 596 | 597 | # display the first image in the in mini-batch 598 | nCols = depth+1 599 | plt.subplot2grid((depth, nCols), (0, 1), rowspan=depth, colspan=depth) 600 | # display the entire image 601 | plotWholeImg(nextX[0, :], img_size, sampled_locs_fetched) 602 | 603 | # display the glimpses 604 | for y in xrange(nGlimpses): 605 | txt.set_text('Epoch: %.6d \nPrediction: %i -- Truth: %i\nStep: %i/%i' 606 | % (epoch, prediction_labels_fetched[0], correct_labels_fetched[0], (y + 1), nGlimpses)) 607 | 608 | for x in xrange(depth): 609 | plt.subplot(depth, nCols, 1 + nCols * x) 610 | if fillList: 611 | plotImg = plt.imshow(f_glimpse_images[y, 0, x], cmap=plt.get_cmap('gray'), 612 | interpolation="nearest") 613 | plotImg.autoscale() 614 | plotImgs.append(plotImg) 615 | else: 616 | plotImgs[x].set_data(f_glimpse_images[y, 0, x]) 617 | plotImgs[x].autoscale() 618 | fillList = False 619 | 620 | # fig.canvas.draw() 621 | time.sleep(0.1) 622 | plt.pause(0.00005) 623 | 624 | else: 625 | txt.set_text('PREDICTION: %i\nTRUTH: %i' % (prediction_labels_fetched[0], correct_labels_fetched[0])) 626 | for x in xrange(depth): 627 | for y in xrange(nGlimpses): 628 | plt.subplot(depth, nGlimpses, x * nGlimpses + y + 1) 629 | plt.imshow(f_glimpse_images[y, 0, x], cmap=plt.get_cmap('gray'), interpolation="nearest") 630 | 631 | plt.draw() 632 | time.sleep(0.05) 633 | plt.pause(0.0001) 634 | 635 | sess.close() -------------------------------------------------------------------------------- /ram_modified.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tf_mnist_loader 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import time 6 | import random 7 | import sys 8 | import os 9 | 10 | 11 | try: 12 | xrange 13 | except NameError: 14 | xrange = range 15 | 16 | dataset = tf_mnist_loader.read_data_sets("mnist_data") 17 | save_dir = "chckPts/" 18 | save_prefix = "save" 19 | summaryFolderName = "summary/" 20 | 21 | 22 | if len(sys.argv) == 2: 23 | simulationName = str(sys.argv[1]) 24 | print("Simulation name = " + simulationName) 25 | summaryFolderName = summaryFolderName + simulationName + "/" 26 | saveImgs = True 27 | imgsFolderName = "imgs/" + simulationName + "/" 28 | if os.path.isdir(summaryFolderName) == False: 29 | os.mkdir(summaryFolderName) 30 | # if os.path.isdir(imgsFolderName) == False: 31 | # os.mkdir(imgsFolderName) 32 | else: 33 | saveImgs = False 34 | print("Testing... 
image files will not be saved.") 35 | 36 | 37 | start_step = 0 38 | #load_path = None 39 | load_path = save_dir + save_prefix + str(start_step) + ".ckpt" 40 | # to enable visualization, set draw to True 41 | eval_only = False 42 | draw = False 43 | animate = False 44 | 45 | # conditions 46 | translateMnist = 1 47 | eyeCentered = 0 48 | 49 | preTraining = 0 50 | preTraining_epoch = 20000 51 | drawReconsturction = 0 52 | 53 | # about translation 54 | MNIST_SIZE = 28 55 | translated_img_size = 60 # side length of the picture 56 | 57 | fixed_learning_rate = 0.001 58 | 59 | 60 | if translateMnist: 61 | print("TRANSLATED MNIST") 62 | img_size = translated_img_size 63 | depth = 3 # number of zooms 64 | sensorBandwidth = 12 65 | minRadius = 8 # zooms -> minRadius * 2** 66 | 67 | initLr = 1e-3 68 | lr_min = 1e-4 69 | lrDecayRate = .999 70 | lrDecayFreq = 200 71 | momentumValue = .9 72 | batch_size = 64 73 | 74 | else: 75 | print("CENTERED MNIST") 76 | img_size = MNIST_SIZE 77 | depth = 1 # number of zooms 78 | sensorBandwidth = 8 79 | minRadius = 4 # zooms -> minRadius * 2** 80 | 81 | initLr = 1e-3 82 | lrDecayRate = .99 83 | lrDecayFreq = 200 84 | momentumValue = .9 85 | batch_size = 20 86 | 87 | 88 | # model parameters 89 | channels = 1 # mnist are grayscale images 90 | totalSensorBandwidth = depth * channels * (sensorBandwidth **2) 91 | nGlimpses = 6 # number of glimpses 92 | loc_sd = 0.22 # std when setting the location 93 | 94 | # network units 95 | hg_size = 128 # 96 | hl_size = 128 # 97 | g_size = 256 # 98 | cell_size = 256 # 99 | cell_out_size = cell_size # 100 | 101 | # paramters about the training examples 102 | n_classes = 10 # card(Y) 103 | 104 | # training parameters 105 | max_iters = 1000000 106 | SMALL_NUM = 1e-10 107 | 108 | # resource prellocation 109 | mean_locs = [] # expectation of locations 110 | sampled_locs = [] # sampled locations ~N(mean_locs[.], loc_sd) 111 | baselines = [] # baseline, the value prediction 112 | glimpse_images = [] # to show in window 113 | 114 | 115 | # set the weights to be small random values, with truncated normal distribution 116 | def weight_variable(shape, myname, train): 117 | initial = tf.random_uniform(shape, minval=-0.1, maxval = 0.1) 118 | return tf.Variable(initial, name=myname, trainable=train) 119 | 120 | # get local glimpses 121 | def glimpseSensor(img, normLoc): 122 | loc = tf.round(((normLoc + 1) / 2.0) * img_size) # normLoc coordinates are between -1 and 1 123 | loc = tf.cast(loc, tf.int32) 124 | 125 | img = tf.reshape(img, (batch_size, img_size, img_size, channels)) 126 | 127 | # process each image individually 128 | zooms = [] 129 | for k in range(batch_size): 130 | imgZooms = [] 131 | one_img = img[k,:,:,:] 132 | max_radius = minRadius * (2 ** (depth - 1)) 133 | offset = 2 * max_radius 134 | 135 | # pad image with zeros 136 | one_img = tf.image.pad_to_bounding_box(one_img, offset, offset, \ 137 | max_radius * 4 + img_size, max_radius * 4 + img_size) 138 | 139 | for i in range(depth): 140 | r = int(minRadius * (2 ** (i))) 141 | 142 | d_raw = 2 * r 143 | d = tf.constant(d_raw, shape=[1]) 144 | d = tf.tile(d, [2]) 145 | loc_k = loc[k,:] 146 | adjusted_loc = offset + loc_k - r 147 | one_img2 = tf.reshape(one_img, (one_img.get_shape()[0].value, one_img.get_shape()[1].value)) 148 | 149 | # crop image to (d x d) 150 | zoom = tf.slice(one_img2, adjusted_loc, d) 151 | 152 | # resize cropped image to (sensorBandwidth x sensorBandwidth) 153 | zoom = tf.image.resize_bilinear(tf.reshape(zoom, (1, d_raw, d_raw, 1)), (sensorBandwidth, 
sensorBandwidth)) 154 | zoom = tf.reshape(zoom, (sensorBandwidth, sensorBandwidth)) 155 | imgZooms.append(zoom) 156 | 157 | zooms.append(tf.stack(imgZooms)) 158 | 159 | zooms = tf.stack(zooms) 160 | 161 | glimpse_images.append(zooms) 162 | 163 | return zooms 164 | 165 | # implements the input network 166 | def get_glimpse(loc): 167 | # get input using the previous location 168 | glimpse_input = glimpseSensor(inputs_placeholder, loc) 169 | glimpse_input = tf.reshape(glimpse_input, (batch_size, totalSensorBandwidth)) 170 | 171 | # the hidden units that process location & the input 172 | act_glimpse_hidden = tf.nn.relu(tf.matmul(glimpse_input, Wg_g_h) + Bg_g_h) 173 | act_loc_hidden = tf.nn.relu(tf.matmul(loc, Wg_l_h) + Bg_l_h) 174 | 175 | # the hidden units that integrates the location & the glimpses 176 | glimpseFeature1 = tf.nn.relu(tf.matmul(act_glimpse_hidden, Wg_hg_gf1) + tf.matmul(act_loc_hidden, Wg_hl_gf1) + Bg_hlhg_gf1) 177 | # return g 178 | # glimpseFeature2 = tf.matmul(glimpseFeature1, Wg_gf1_gf2) + Bg_gf1_gf2 179 | return glimpseFeature1 180 | 181 | 182 | def get_next_input(output): 183 | # the next location is computed by the location network 184 | core_net_out = tf.stop_gradient(output) 185 | 186 | # baseline = tf.sigmoid(tf.matmul(core_net_out, Wb_h_b) + Bb_h_b) 187 | baseline = tf.sigmoid(tf.matmul(core_net_out, Wb_h_b) + Bb_h_b) 188 | baselines.append(baseline) 189 | 190 | # compute the next location, then impose noise 191 | if eyeCentered: 192 | # add the last sampled glimpse location 193 | # TODO max(-1, min(1, u + N(output, sigma) + prevLoc)) 194 | mean_loc = tf.maximum(-1.0, tf.minimum(1.0, tf.matmul(core_net_out, Wl_h_l) + sampled_locs[-1] )) 195 | else: 196 | # mean_loc = tf.clip_by_value(tf.matmul(core_net_out, Wl_h_l) + Bl_h_l, -1, 1) 197 | mean_loc = tf.matmul(core_net_out, Wl_h_l) + Bl_h_l 198 | mean_loc = tf.clip_by_value(mean_loc, -1, 1) 199 | # mean_loc = tf.stop_gradient(mean_loc) 200 | mean_locs.append(mean_loc) 201 | 202 | # add noise 203 | # sample_loc = tf.tanh(mean_loc + tf.random_normal(mean_loc.get_shape(), 0, loc_sd)) 204 | sample_loc = tf.maximum(-1.0, tf.minimum(1.0, mean_loc + tf.random_normal(mean_loc.get_shape(), 0, loc_sd))) 205 | 206 | # don't propagate throught the locations 207 | sample_loc = tf.stop_gradient(sample_loc) 208 | sampled_locs.append(sample_loc) 209 | 210 | return get_glimpse(sample_loc) 211 | 212 | 213 | def affineTransform(x,output_dim): 214 | """ 215 | affine transformation Wx+b 216 | assumes x.shape = (batch_size, num_features) 217 | """ 218 | w=tf.get_variable("w", [x.get_shape()[1], output_dim]) 219 | b=tf.get_variable("b", [output_dim], initializer=tf.constant_initializer(0.0)) 220 | return tf.matmul(x,w)+b 221 | 222 | 223 | def model(): 224 | 225 | # initialize the location under unif[-1,1], for all example in the batch 226 | initial_loc = tf.random_uniform((batch_size, 2), minval=-1, maxval=1) 227 | mean_locs.append(initial_loc) 228 | 229 | # initial_loc = tf.tanh(initial_loc + tf.random_normal(initial_loc.get_shape(), 0, loc_sd)) 230 | initial_loc = tf.clip_by_value(initial_loc + tf.random_normal(initial_loc.get_shape(), 0, loc_sd), -1, 1) 231 | 232 | sampled_locs.append(initial_loc) 233 | 234 | # get the input using the input network 235 | initial_glimpse = get_glimpse(initial_loc) 236 | 237 | # set up the recurrent structure 238 | inputs = [0] * nGlimpses 239 | outputs = [0] * nGlimpses 240 | glimpse = initial_glimpse 241 | REUSE = None 242 | for t in range(nGlimpses): 243 | if t == 0: # initialize the hidden state to be 
the zero vector 244 | hiddenState_prev = tf.zeros((batch_size, cell_size)) 245 | else: 246 | hiddenState_prev = outputs[t-1] 247 | 248 | # forward prop 249 | with tf.variable_scope("coreNetwork", reuse=REUSE): 250 | # the next hidden state is a function of the previous hidden state and the current glimpse 251 | hiddenState = tf.nn.relu(affineTransform(hiddenState_prev, cell_size) + (tf.matmul(glimpse, Wc_g_h) + Bc_g_h)) 252 | 253 | # save the current glimpse and the hidden state 254 | inputs[t] = glimpse 255 | outputs[t] = hiddenState 256 | # get the next input glimpse 257 | if t != nGlimpses -1: 258 | glimpse = get_next_input(hiddenState) 259 | else: 260 | first_hiddenState = tf.stop_gradient(hiddenState) 261 | # baseline = tf.sigmoid(tf.matmul(first_hiddenState, Wb_h_b) + Bb_h_b) 262 | baseline = tf.sigmoid(tf.matmul(first_hiddenState, Wb_h_b) + Bb_h_b) 263 | baselines.append(baseline) 264 | REUSE = True # share variables for later recurrence 265 | 266 | return outputs 267 | 268 | 269 | def dense_to_one_hot(labels_dense, num_classes=10): 270 | """Convert class labels from scalars to one-hot vectors.""" 271 | # copied from TensorFlow tutorial 272 | num_labels = labels_dense.shape[0] 273 | index_offset = np.arange(num_labels) * num_classes 274 | labels_one_hot = np.zeros((num_labels, num_classes)) 275 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 276 | return labels_one_hot 277 | 278 | 279 | # to use for maximum likelihood with input location 280 | def gaussian_pdf(mean, sample): 281 | Z = 1.0 / (loc_sd * tf.sqrt(2.0 * np.pi)) 282 | a = -tf.square(sample - mean) / (2.0 * tf.square(loc_sd)) 283 | return Z * tf.exp(a) 284 | 285 | 286 | def calc_reward(outputs): 287 | 288 | # consider the action at the last time step 289 | outputs = outputs[-1] # look at ONLY THE END of the sequence 290 | outputs = tf.reshape(outputs, (batch_size, cell_out_size)) 291 | 292 | # get the baseline 293 | b = tf.stack(baselines) 294 | b = tf.concat(axis=2, values=[b, b]) 295 | b = tf.reshape(b, (batch_size, (nGlimpses) * 2)) 296 | no_grad_b = tf.stop_gradient(b) 297 | 298 | # get the action(classification) 299 | p_y = tf.nn.softmax(tf.matmul(outputs, Wa_h_a) + Ba_h_a) 300 | max_p_y = tf.arg_max(p_y, 1) 301 | correct_y = tf.cast(labels_placeholder, tf.int64) 302 | 303 | # reward for all examples in the batch 304 | R = tf.cast(tf.equal(max_p_y, correct_y), tf.float32) 305 | reward = tf.reduce_mean(R) # mean reward 306 | R = tf.reshape(R, (batch_size, 1)) 307 | R = tf.tile(R, [1, (nGlimpses)*2]) 308 | 309 | # get the location 310 | 311 | p_loc = gaussian_pdf(mean_locs, sampled_locs) 312 | # p_loc = tf.tanh(p_loc) 313 | 314 | p_loc_orig = p_loc 315 | p_loc = tf.reshape(p_loc, (batch_size, (nGlimpses) * 2)) 316 | 317 | # define the cost function 318 | J = tf.concat(axis=1, values=[tf.log(p_y + SMALL_NUM) * (onehot_labels_placeholder), tf.log(p_loc + SMALL_NUM) * (R - no_grad_b)]) 319 | J = tf.reduce_sum(J, 1) 320 | J = J - tf.reduce_sum(tf.square(R - b), 1) 321 | J = tf.reduce_mean(J, 0) 322 | cost = -J 323 | var_list = tf.trainable_variables() 324 | grads = tf.gradients(cost, var_list) 325 | grads, _ = tf.clip_by_global_norm(grads, 0.5) 326 | # define the optimizer 327 | # lr_max = tf.maximum(lr, lr_min) 328 | optimizer = tf.train.AdamOptimizer(lr) 329 | # optimizer = tf.train.MomentumOptimizer(lr, momentumValue) 330 | # train_op = optimizer.minimize(cost, global_step) 331 | train_op = optimizer.apply_gradients(zip(grads, var_list), global_step=global_step) 332 | 333 | return cost, reward, max_p_y, 
correct_y, train_op, b, tf.reduce_mean(b), tf.reduce_mean(R - b), lr 334 | 335 | 336 | def preTrain(outputs): 337 | lr_r = 1e-3 338 | # consider the action at the last time step 339 | outputs = outputs[-1] # look at ONLY THE END of the sequence 340 | outputs = tf.reshape(outputs, (batch_size, cell_out_size)) 341 | # if preTraining: 342 | reconstruction = tf.sigmoid(tf.matmul(outputs, Wr_h_r) + Br_h_r) 343 | reconstructionCost = tf.reduce_mean(tf.square(inputs_placeholder - reconstruction)) 344 | 345 | train_op_r = tf.train.RMSPropOptimizer(lr_r).minimize(reconstructionCost) 346 | return reconstructionCost, reconstruction, train_op_r 347 | 348 | 349 | def evaluate(): 350 | data = dataset.test 351 | batches_in_epoch = len(data._images) // batch_size 352 | accuracy = 0 353 | 354 | for i in range(batches_in_epoch): 355 | nextX, nextY = dataset.test.next_batch(batch_size) 356 | if translateMnist: 357 | nextX, _ = convertTranslated(nextX, MNIST_SIZE, img_size) 358 | feed_dict = {inputs_placeholder: nextX, labels_placeholder: nextY, 359 | onehot_labels_placeholder: dense_to_one_hot(nextY)} 360 | r = sess.run(reward, feed_dict=feed_dict) 361 | accuracy += r 362 | 363 | accuracy /= batches_in_epoch 364 | print(("ACCURACY: " + str(accuracy))) 365 | 366 | 367 | def convertTranslated(images, initImgSize, finalImgSize): 368 | size_diff = finalImgSize - initImgSize 369 | newimages = np.zeros([batch_size, finalImgSize*finalImgSize]) 370 | imgCoord = np.zeros([batch_size,2]) 371 | for k in range(batch_size): 372 | image = images[k, :] 373 | image = np.reshape(image, (initImgSize, initImgSize)) 374 | # generate and save random coordinates 375 | randX = random.randint(0, size_diff) 376 | randY = random.randint(0, size_diff) 377 | imgCoord[k,:] = np.array([randX, randY]) 378 | # padding 379 | image = np.lib.pad(image, ((randX, size_diff - randX), (randY, size_diff - randY)), 'constant', constant_values = (0)) 380 | newimages[k, :] = np.reshape(image, (finalImgSize*finalImgSize)) 381 | 382 | return newimages, imgCoord 383 | 384 | 385 | 386 | def toMnistCoordinates(coordinate_tanh): 387 | ''' 388 | Transform coordinate in [-1,1] to mnist 389 | :param coordinate_tanh: vector in [-1,1] x [-1,1] 390 | :return: vector in the corresponding mnist coordinate 391 | ''' 392 | return np.round(((coordinate_tanh + 1) / 2.0) * img_size) 393 | 394 | 395 | def variable_summaries(var, name): 396 | """Attach a lot of summaries to a Tensor.""" 397 | with tf.name_scope('param_summaries'): 398 | mean = tf.reduce_mean(var) 399 | tf.summary.scalar('param_mean/' + name, mean) 400 | with tf.name_scope('param_stddev'): 401 | stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean))) 402 | tf.summary.scalar('param_sttdev/' + name, stddev) 403 | tf.summary.scalar('param_max/' + name, tf.reduce_max(var)) 404 | tf.summary.scalar('param_min/' + name, tf.reduce_min(var)) 405 | tf.summary.histogram(name, var) 406 | 407 | 408 | def plotWholeImg(img, img_size, sampled_locs_fetched): 409 | plt.imshow(np.reshape(img, [img_size, img_size]), 410 | cmap=plt.get_cmap('gray'), interpolation="nearest") 411 | 412 | plt.ylim((img_size - 1, 0)) 413 | plt.xlim((0, img_size - 1)) 414 | 415 | # transform the coordinate to mnist map 416 | sampled_locs_mnist_fetched = toMnistCoordinates(sampled_locs_fetched) 417 | # visualize the trace of successive nGlimpses (note that x and y coordinates are "flipped") 418 | plt.plot(sampled_locs_mnist_fetched[0, :, 1], sampled_locs_mnist_fetched[0, :, 0], '-o', 419 | color='lawngreen') 420 | 
plt.plot(sampled_locs_mnist_fetched[0, -1, 1], sampled_locs_mnist_fetched[0, -1, 0], 'o', 421 | color='red') 422 | 423 | 424 | with tf.device('/gpu:1'): 425 | 426 | with tf.Graph().as_default(): 427 | 428 | # set the learning rate 429 | global_step = tf.Variable(0, trainable=False) 430 | lr = tf.train.exponential_decay(initLr, global_step, lrDecayFreq, lrDecayRate, staircase=True) 431 | 432 | # preallocate x, y, baseline 433 | labels = tf.placeholder("float32", shape=[batch_size, n_classes]) 434 | labels_placeholder = tf.placeholder(tf.float32, shape=(batch_size), name="labels_raw") 435 | onehot_labels_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 10), name="labels_onehot") 436 | inputs_placeholder = tf.placeholder(tf.float32, shape=(batch_size, img_size * img_size), name="images") 437 | 438 | # declare the model parameters; the naming rule is: 439 | # the 1st capital letter: weights or bias (W = weights, B = bias) 440 | # the 2nd lowercase letter: the network (e.g.: g = glimpse network) 441 | # the 3rd and 4th letter(s): input-output mapping, which is clearly written in the variable name argument 442 | 443 | Wg_l_h = weight_variable((2, hl_size), "glimpseNet_wts_location_hidden", True) 444 | Bg_l_h = weight_variable((1,hl_size), "glimpseNet_bias_location_hidden", True) 445 | 446 | Wg_g_h = weight_variable((totalSensorBandwidth, hg_size), "glimpseNet_wts_glimpse_hidden", True) 447 | Bg_g_h = weight_variable((1,hg_size), "glimpseNet_bias_glimpse_hidden", True) 448 | 449 | Wg_hg_gf1 = weight_variable((hg_size, g_size), "glimpseNet_wts_hiddenGlimpse_glimpseFeature1", True) 450 | Wg_hl_gf1 = weight_variable((hl_size, g_size), "glimpseNet_wts_hiddenLocation_glimpseFeature1", True) 451 | Bg_hlhg_gf1 = weight_variable((1,g_size), "glimpseNet_bias_hGlimpse_hLocs_glimpseFeature1", True) 452 | 453 | Wc_g_h = weight_variable((cell_size, g_size), "coreNet_wts_glimpse_hidden", True) 454 | Bc_g_h = weight_variable((1,g_size), "coreNet_bias_glimpse_hidden", True) 455 | 456 | Wr_h_r = weight_variable((cell_out_size, img_size**2), "reconstructionNet_wts_hidden_action", True) 457 | Br_h_r = weight_variable((1, img_size**2), "reconstructionNet_bias_hidden_action", True) 458 | 459 | Wb_h_b = weight_variable((g_size, 1), "baselineNet_wts_hiddenState_baseline", True) 460 | Bb_h_b = weight_variable((1,1), "baselineNet_bias_hiddenState_baseline", True) 461 | 462 | Wl_h_l = weight_variable((cell_out_size, 2), "locationNet_wts_hidden_location", True) 463 | Bl_h_l = weight_variable((1, 2), "locationNet_bias_hidden_location", True) 464 | 465 | Wa_h_a = weight_variable((cell_out_size, n_classes), "actionNet_wts_hidden_action", True) 466 | Ba_h_a = weight_variable((1,n_classes), "actionNet_bias_hidden_action", True) 467 | 468 | # query the model output 469 | outputs = model() 470 | 471 | # convert list of tensors to one big tensor 472 | sampled_locs = tf.concat(axis=0, values=sampled_locs) 473 | sampled_locs = tf.reshape(sampled_locs, (nGlimpses, batch_size, 2)) 474 | sampled_locs = tf.transpose(sampled_locs, [1, 0, 2]) 475 | mean_locs = tf.concat(axis=0, values=mean_locs) 476 | mean_locs = tf.reshape(mean_locs, (nGlimpses, batch_size, 2)) 477 | mean_locs = tf.transpose(mean_locs, [1, 0, 2]) 478 | glimpse_images = tf.concat(axis=0, values=glimpse_images) 479 | 480 | # compute the reward 481 | reconstructionCost, reconstruction, train_op_r = preTrain(outputs) 482 | cost, reward, predicted_labels, correct_labels, train_op, b, avg_b, rminusb, lr = calc_reward(outputs) 483 | 484 | # tensorboard visualization
for the parameters 485 | variable_summaries(Wg_l_h, "glimpseNet_wts_location_hidden") 486 | variable_summaries(Bg_l_h, "glimpseNet_bias_location_hidden") 487 | variable_summaries(Wg_g_h, "glimpseNet_wts_glimpse_hidden") 488 | variable_summaries(Bg_g_h, "glimpseNet_bias_glimpse_hidden") 489 | variable_summaries(Wg_hg_gf1, "glimpseNet_wts_hiddenGlimpse_glimpseFeature1") 490 | variable_summaries(Wg_hl_gf1, "glimpseNet_wts_hiddenLocation_glimpseFeature1") 491 | variable_summaries(Bg_hlhg_gf1, "glimpseNet_bias_hGlimpse_hLocs_glimpseFeature1") 492 | 493 | variable_summaries(Wc_g_h, "coreNet_wts_glimpse_hidden") 494 | variable_summaries(Bc_g_h, "coreNet_bias_glimpse_hidden") 495 | 496 | variable_summaries(Wb_h_b, "baselineNet_wts_hiddenState_baseline") 497 | variable_summaries(Bb_h_b, "baselineNet_bias_hiddenState_baseline") 498 | 499 | variable_summaries(Wl_h_l, "locationNet_wts_hidden_location") 500 | 501 | variable_summaries(Wa_h_a, 'actionNet_wts_hidden_action') 502 | variable_summaries(Ba_h_a, 'actionNet_bias_hidden_action') 503 | 504 | # tensorboard visualization for the performance metrics 505 | tf.summary.scalar("reconstructionCost", reconstructionCost) 506 | tf.summary.scalar("reward", reward) 507 | tf.summary.scalar("cost", cost) 508 | tf.summary.scalar("mean(b)", avg_b) 509 | tf.summary.scalar("mean(R - b)", rminusb) 510 | summary_op = tf.summary.merge_all() 511 | 512 | 513 | ####################################### START RUNNING THE MODEL ####################################### 514 | 515 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 516 | sess_config.gpu_options.allow_growth = True 517 | sess = tf.Session(config=sess_config) 518 | 519 | saver = tf.train.Saver() 520 | b_fetched = np.zeros((batch_size, (nGlimpses)*2)) 521 | 522 | init = tf.global_variables_initializer() 523 | sess.run(init) 524 | 525 | if eval_only: 526 | evaluate() 527 | else: 528 | summary_writer = tf.summary.FileWriter(summaryFolderName, graph=sess.graph) 529 | 530 | if draw: 531 | fig = plt.figure(1) 532 | txt = fig.suptitle("-", fontsize=36, fontweight='bold') 533 | plt.ion() 534 | plt.show() 535 | plt.subplots_adjust(top=0.7) 536 | plotImgs = [] 537 | 538 | if drawReconsturction: 539 | fig = plt.figure(2) 540 | txt = fig.suptitle("-", fontsize=36, fontweight='bold') 541 | plt.ion() 542 | plt.show() 543 | 544 | if preTraining: 545 | for epoch_r in range(1,preTraining_epoch): 546 | nextX, _ = dataset.train.next_batch(batch_size) 547 | nextX_orig = nextX 548 | if translateMnist: 549 | nextX, _ = convertTranslated(nextX, MNIST_SIZE, img_size) 550 | 551 | fetches_r = [reconstructionCost, reconstruction, train_op_r] 552 | 553 | reconstructionCost_fetched, reconstruction_fetched, train_op_r_fetched = sess.run(fetches_r, feed_dict={inputs_placeholder: nextX}) 554 | 555 | if epoch_r % 20 == 0: 556 | print(('Step %d: reconstructionCost = %.5f' % (epoch_r, reconstructionCost_fetched))) 557 | if epoch_r % 100 == 0: 558 | if drawReconsturction: 559 | fig = plt.figure(2) 560 | 561 | plt.subplot(1, 2, 1) 562 | plt.imshow(np.reshape(nextX[0, :], [img_size, img_size]), 563 | cmap=plt.get_cmap('gray'), interpolation="nearest") 564 | plt.ylim((img_size - 1, 0)) 565 | plt.xlim((0, img_size - 1)) 566 | 567 | plt.subplot(1, 2, 2) 568 | plt.imshow(np.reshape(reconstruction_fetched[0, :], [img_size, img_size]), 569 | cmap=plt.get_cmap('gray'), interpolation="nearest") 570 | plt.ylim((img_size - 1, 0)) 571 | plt.xlim((0, img_size - 1)) 572 | plt.draw() 573 | plt.pause(0.0001) 574 | # plt.show() 575 | 
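
A minimal NumPy sketch of the objective that calc_reward() above assembles, stripped of the graph plumbing. This is illustrative only: the sizes are example stand-ins (not necessarily this run's settings), the random arrays take the place of real network outputs, and in the actual graph the (R - b) factor inside the REINFORCE term goes through tf.stop_gradient so it does not train the baseline.

import numpy as np

batch_size, nGlimpses, n_classes = 20, 6, 10   # example sizes only
loc_sd, SMALL_NUM = 0.11, 1e-10

rng = np.random.default_rng(0)
p_y = rng.dirichlet(np.ones(n_classes), size=batch_size)            # stand-in for the softmax output
onehot = np.eye(n_classes)[rng.integers(0, n_classes, batch_size)]  # stand-in for the labels
R = np.tile((rng.random((batch_size, 1)) > 0.5).astype(float), (1, nGlimpses * 2))  # 0/1 reward, tiled
b = rng.random((batch_size, nGlimpses * 2))                         # stand-in for the baselines
mean_locs = rng.uniform(-1, 1, (batch_size, nGlimpses * 2))
sampled_locs = np.clip(mean_locs + rng.normal(0, loc_sd, mean_locs.shape), -1, 1)

# the same Gaussian likelihood as gaussian_pdf()
p_loc = np.exp(-np.square(sampled_locs - mean_locs) / (2 * loc_sd ** 2)) / (loc_sd * np.sqrt(2 * np.pi))

# supervised log-likelihood plus the REINFORCE term (baseline treated as a constant),
# minus the squared error that fits the baseline to the reward
J = np.concatenate([np.log(p_y + SMALL_NUM) * onehot,
                    np.log(p_loc + SMALL_NUM) * (R - b)], axis=1).sum(axis=1)
J = J - np.square(R - b).sum(axis=1)
cost = -J.mean()   # this is the scalar the Adam step in train_op minimizes
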
576 | 577 | # training 578 | for epoch in range(start_step + 1, max_iters): 579 | start_time = time.time() 580 | 581 | # get the next batch of examples 582 | nextX, nextY = dataset.train.next_batch(batch_size) 583 | nextX_orig = nextX 584 | if translateMnist: 585 | nextX, nextX_coord = convertTranslated(nextX, MNIST_SIZE, img_size) 586 | 587 | feed_dict = {inputs_placeholder: nextX, labels_placeholder: nextY, \ 588 | onehot_labels_placeholder: dense_to_one_hot(nextY)} 589 | 590 | fetches = [train_op, cost, reward, predicted_labels, correct_labels, glimpse_images, avg_b, rminusb, \ 591 | mean_locs, sampled_locs, lr] 592 | # feed them to the model 593 | results = sess.run(fetches, feed_dict=feed_dict) 594 | 595 | _, cost_fetched, reward_fetched, prediction_labels_fetched, correct_labels_fetched, glimpse_images_fetched, \ 596 | avg_b_fetched, rminusb_fetched, mean_locs_fetched, sampled_locs_fetched, lr_fetched = results 597 | 598 | 599 | duration = time.time() - start_time 600 | 601 | if epoch % 20 == 0: 602 | print(('Step %d: cost = %.5f reward = %.5f (%.3f sec) b = %.5f R-b = %.5f, LR = %.5f' 603 | % (epoch, cost_fetched, reward_fetched, duration, avg_b_fetched, rminusb_fetched, lr_fetched))) 604 | summary_str = sess.run(summary_op, feed_dict=feed_dict) 605 | summary_writer.add_summary(summary_str, epoch) 606 | # if saveImgs: 607 | # plt.savefig(imgsFolderName + simulationName + '_ep%.6d.png' % (epoch)) 608 | 609 | if epoch % 5000 == 0: 610 | saver.save(sess, save_dir + save_prefix + str(epoch) + ".ckpt") 611 | evaluate() 612 | 613 | ##### DRAW WINDOW ################ 614 | f_glimpse_images = np.reshape(glimpse_images_fetched, \ 615 | (nGlimpses, batch_size, depth, sensorBandwidth, sensorBandwidth)) 616 | 617 | if draw: 618 | if animate: 619 | fillList = False 620 | if len(plotImgs) == 0: 621 | fillList = True 622 | 623 | # display the first image in the in mini-batch 624 | nCols = depth+1 625 | plt.subplot2grid((depth, nCols), (0, 1), rowspan=depth, colspan=depth) 626 | # display the entire image 627 | plotWholeImg(nextX[0, :], img_size, sampled_locs_fetched) 628 | 629 | # display the glimpses 630 | for y in range(nGlimpses): 631 | txt.set_text('Epoch: %.6d \nPrediction: %i -- Truth: %i\nStep: %i/%i' 632 | % (epoch, prediction_labels_fetched[0], correct_labels_fetched[0], (y + 1), nGlimpses)) 633 | 634 | for x in range(depth): 635 | plt.subplot(depth, nCols, 1 + nCols * x) 636 | if fillList: 637 | plotImg = plt.imshow(f_glimpse_images[y, 0, x], cmap=plt.get_cmap('gray'), 638 | interpolation="nearest") 639 | plotImg.autoscale() 640 | plotImgs.append(plotImg) 641 | else: 642 | plotImgs[x].set_data(f_glimpse_images[y, 0, x]) 643 | plotImgs[x].autoscale() 644 | fillList = False 645 | 646 | # fig.canvas.draw() 647 | time.sleep(0.1) 648 | plt.pause(0.00005) 649 | 650 | else: 651 | txt.set_text('PREDICTION: %i\nTRUTH: %i' % (prediction_labels_fetched[0], correct_labels_fetched[0])) 652 | for x in range(depth): 653 | for y in range(nGlimpses): 654 | plt.subplot(depth, nGlimpses, x * nGlimpses + y + 1) 655 | plt.imshow(f_glimpse_images[y, 0, x], cmap=plt.get_cmap('gray'), interpolation="nearest") 656 | 657 | plt.draw() 658 | time.sleep(0.05) 659 | plt.pause(0.0001) 660 | 661 | sess.close() 662 | -------------------------------------------------------------------------------- /ram_up.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tf_mnist_loader 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import time 6 | 
import random 7 | import sys 8 | import os 9 | 10 | try: 11 | xrange 12 | except NameError: 13 | xrange = range 14 | 15 | dataset = tf_mnist_loader.read_data_sets("mnist_data") 16 | save_dir = "chckPts/" 17 | save_prefix = "save" 18 | summaryFolderName = "summary/" 19 | 20 | 21 | if len(sys.argv) == 2: 22 | simulationName = str(sys.argv[1]) 23 | print("Simulation name = " + simulationName) 24 | summaryFolderName = summaryFolderName + simulationName + "/" 25 | saveImgs = True 26 | imgsFolderName = "imgs/" + simulationName + "/" 27 | if os.path.isdir(summaryFolderName) == False: 28 | os.mkdir(summaryFolderName) 29 | # if os.path.isdir(imgsFolderName) == False: 30 | # os.mkdir(imgsFolderName) 31 | else: 32 | saveImgs = False 33 | print("Testing... image files will not be saved.") 34 | 35 | 36 | start_step = 0 37 | #load_path = None 38 | load_path = save_dir + save_prefix + str(start_step) + ".ckpt" 39 | # to enable visualization, set draw to True 40 | eval_only = False 41 | draw = 0 42 | animate = 0 43 | 44 | # conditions 45 | translateMnist = 1 46 | eyeCentered = 0 47 | 48 | preTraining = 0 49 | preTraining_epoch = 20000 50 | drawReconsturction = 0 51 | 52 | # about translation 53 | MNIST_SIZE = 28 54 | translated_img_size = 60 # side length of the picture 55 | 56 | if translateMnist: 57 | print("TRANSLATED MNIST") 58 | img_size = translated_img_size 59 | depth = 3 # number of zooms 60 | sensorBandwidth = 12 61 | minRadius = 6 # zooms -> minRadius * 2** 62 | 63 | initLr = 5e-3 64 | lrDecayRate = .995 65 | lrDecayFreq = 500 66 | momentumValue = .9 67 | batch_size = 20 68 | 69 | else: 70 | print("CENTERED MNIST") 71 | img_size = MNIST_SIZE 72 | depth = 1 # number of zooms 73 | sensorBandwidth = 8 74 | minRadius = 4 # zooms -> minRadius * 2** 75 | 76 | initLr = 5e-3 77 | lrDecayRate = .99 78 | lrDecayFreq = 200 79 | momentumValue = .9 80 | batch_size = 20 81 | 82 | 83 | # model parameters 84 | channels = 1 # mnist are grayscale images 85 | totalSensorBandwidth = depth * channels * (sensorBandwidth **2) 86 | nGlimpses = 8 # number of glimpses 87 | loc_sd = 0.11 # std when setting the location 88 | 89 | # network units 90 | hg_size = 128 # 91 | hl_size = 128 # 92 | g_size = 256 # 93 | cell_size = 256 # 94 | cell_out_size = cell_size # 95 | 96 | # paramters about the training examples 97 | n_classes = 10 # card(Y) 98 | 99 | # training parameters 100 | max_iters = 1000000 101 | SMALL_NUM = 1e-10 102 | 103 | # resource prellocation 104 | mean_locs = [] # expectation of locations 105 | sampled_locs = [] # sampled locations ~N(mean_locs[.], loc_sd) 106 | baselines = [] # baseline, the value prediction 107 | glimpse_images = [] # to show in window 108 | 109 | 110 | # set the weights to be small random values, with truncated normal distribution 111 | def weight_variable(shape, myname, train): 112 | initial = tf.random_uniform(shape, minval=-0.1, maxval = 0.1) 113 | return tf.Variable(initial, name=myname, trainable=train) 114 | 115 | 116 | # get local glimpses 117 | def glimpseSensor(img, normLoc): 118 | loc = tf.round(((normLoc + 1) / 2.0) * img_size) # normLoc coordinates are between -1 and 1 119 | loc = tf.cast(loc, tf.int32) 120 | 121 | img = tf.reshape(img, (batch_size, img_size, img_size, channels)) 122 | 123 | # process each image individually 124 | zooms = [] 125 | for k in range(batch_size): 126 | imgZooms = [] 127 | one_img = img[k,:,:,:] 128 | max_radius = minRadius * (2 ** (depth - 1)) 129 | offset = 2 * max_radius 130 | 131 | # pad image with zeros 132 | one_img = 
tf.image.pad_to_bounding_box(one_img, offset, offset, \ 133 | max_radius * 4 + img_size, max_radius * 4 + img_size) 134 | 135 | for i in range(depth): 136 | r = int(minRadius * (2 ** (i))) 137 | 138 | d_raw = 2 * r 139 | d = tf.constant(d_raw, shape=[1]) 140 | d = tf.tile(d, [2]) 141 | loc_k = loc[k,:] 142 | adjusted_loc = offset + loc_k - r 143 | one_img2 = tf.reshape(one_img, (one_img.get_shape()[0].value, one_img.get_shape()[1].value)) 144 | 145 | # crop image to (d x d) 146 | zoom = tf.slice(one_img2, adjusted_loc, d) 147 | 148 | # resize cropped image to (sensorBandwidth x sensorBandwidth) 149 | zoom = tf.image.resize_bilinear(tf.reshape(zoom, (1, d_raw, d_raw, 1)), (sensorBandwidth, sensorBandwidth)) 150 | zoom = tf.reshape(zoom, (sensorBandwidth, sensorBandwidth)) 151 | imgZooms.append(zoom) 152 | 153 | zooms.append(tf.stack(imgZooms)) 154 | 155 | zooms = tf.stack(zooms) 156 | 157 | glimpse_images.append(zooms) 158 | 159 | return zooms 160 | 161 | 162 | # implements the input network 163 | def get_glimpse(loc): 164 | # get input using the previous location 165 | glimpse_input = glimpseSensor(inputs_placeholder, loc) 166 | glimpse_input = tf.reshape(glimpse_input, (batch_size, totalSensorBandwidth)) 167 | 168 | # the hidden units that process location & the input 169 | act_glimpse_hidden = tf.nn.relu(tf.matmul(glimpse_input, Wg_g_h) + Bg_g_h) 170 | act_loc_hidden = tf.nn.relu(tf.matmul(loc, Wg_l_h) + Bg_l_h) 171 | 172 | # the hidden units that integrates the location & the glimpses 173 | glimpseFeature1 = tf.nn.relu(tf.matmul(act_glimpse_hidden, Wg_hg_gf1) + tf.matmul(act_loc_hidden, Wg_hl_gf1) + Bg_hlhg_gf1) 174 | # return g 175 | # glimpseFeature2 = tf.matmul(glimpseFeature1, Wg_gf1_gf2) + Bg_gf1_gf2 176 | return glimpseFeature1 177 | 178 | 179 | def get_next_input(output): 180 | # the next location is computed by the location network 181 | baseline = tf.sigmoid(tf.matmul(output, Wb_h_b) + Bb_h_b) 182 | baselines.append(baseline) 183 | # compute the next location, then impose noise 184 | if eyeCentered: 185 | # add the last sampled glimpse location 186 | # TODO max(-1, min(1, u + N(output, sigma) + prevLoc)) 187 | mean_loc = tf.maximum(-1.0, tf.minimum(1.0, tf.matmul(output, Wl_h_l) + sampled_locs[-1] )) 188 | else: 189 | mean_loc = tf.matmul(output, Wl_h_l) + Bl_h_l 190 | 191 | #mean_loc = tf.stop_gradient(mean_loc) 192 | mean_locs.append(mean_loc) 193 | 194 | # add noise 195 | # sample_loc = tf.tanh(mean_loc + tf.random_normal(mean_loc.get_shape(), 0, loc_sd)) 196 | sample_loc = tf.maximum(-1.0, tf.minimum(1.0, mean_loc + tf.random_normal(mean_loc.get_shape(), 0, loc_sd))) 197 | 198 | # don't propagate throught the locations 199 | sample_loc = tf.stop_gradient(sample_loc) 200 | sampled_locs.append(sample_loc) 201 | 202 | return get_glimpse(sample_loc) 203 | 204 | 205 | def affineTransform(x,output_dim): 206 | """ 207 | affine transformation Wx+b 208 | assumes x.shape = (batch_size, num_features) 209 | """ 210 | w=tf.get_variable("w", [x.get_shape()[1], output_dim]) 211 | b=tf.get_variable("b", [output_dim], initializer=tf.constant_initializer(0.0)) 212 | return tf.matmul(x,w)+b 213 | 214 | 215 | def model(): 216 | # initialize the location under unif[-1,1], for all example in the batch 217 | 218 | initial_loc = tf.random_uniform((batch_size, 2), minval=-1, maxval=1) 219 | mean_locs.append(initial_loc) 220 | initial_loc = tf.tanh(initial_loc + tf.random_normal(initial_loc.get_shape(), 0, loc_sd)) 221 | sampled_locs.append(initial_loc) 222 | 223 | # get the input using the input 
network 224 | initial_glimpse = get_glimpse(initial_loc) 225 | 226 | # set up the recurrent structure 227 | inputs = [0] * nGlimpses 228 | outputs = [0] * nGlimpses 229 | glimpse = initial_glimpse 230 | REUSE = None 231 | for t in range(nGlimpses): 232 | if t == 0: # initialize the hidden state to be the zero vector 233 | hiddenState_prev = tf.zeros((batch_size, cell_size)) 234 | else: 235 | hiddenState_prev = outputs[t-1] 236 | 237 | # forward prop 238 | with tf.variable_scope("coreNetwork", reuse=REUSE): 239 | # the next hidden state is a function of the previous hidden state and the current glimpse 240 | hiddenState = tf.nn.relu(affineTransform(hiddenState_prev, cell_size) + (tf.matmul(glimpse, Wc_g_h) + Bc_g_h)) 241 | 242 | # save the current glimpse and the hidden state 243 | inputs[t] = glimpse 244 | outputs[t] = hiddenState 245 | # get the next input glimpse 246 | if t != nGlimpses -1: 247 | glimpse = get_next_input(hiddenState) 248 | else: 249 | baseline = tf.sigmoid(tf.matmul(hiddenState, Wb_h_b) + Bb_h_b) 250 | #baseline = tf.sigmoid(Bb_h_b) 251 | baselines.append(baseline) 252 | REUSE = True # share variables for later recurrence 253 | 254 | return outputs 255 | 256 | 257 | def dense_to_one_hot(labels_dense, num_classes=10): 258 | """Convert class labels from scalars to one-hot vectors.""" 259 | # copied from TensorFlow tutorial 260 | num_labels = labels_dense.shape[0] 261 | index_offset = np.arange(num_labels) * num_classes 262 | labels_one_hot = np.zeros((num_labels, num_classes)) 263 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 264 | return labels_one_hot 265 | 266 | 267 | # to use for maximum likelihood with input location 268 | def gaussian_pdf(mean, sample): 269 | Z = 1.0 / (loc_sd * tf.sqrt(2.0 * np.pi)) 270 | a = -tf.square(sample - mean) / (2.0 * tf.square(loc_sd)) 271 | return Z * tf.exp(a) 272 | 273 | 274 | def calc_reward(outputs): 275 | 276 | # consider the action at the last time step 277 | outputs = outputs[-1] # look at ONLY THE END of the sequence 278 | outputs = tf.reshape(outputs, (batch_size, cell_out_size)) 279 | 280 | # get the baseline 281 | # b = tf.stack(baselines) 282 | b = baselines[0] 283 | b = tf.concat(axis=1, values=[b, b]) 284 | # b = tf.reshape(b, (batch_size, (nGlimpses) * 2)) 285 | no_grad_b = tf.stop_gradient(b) 286 | # no_grad_b = b 287 | # get the action(classification) 288 | p_y = tf.nn.softmax(tf.matmul(outputs, Wa_h_a) + Ba_h_a) 289 | max_p_y = tf.arg_max(p_y, 1) 290 | correct_y = tf.cast(labels_placeholder, tf.int64) 291 | 292 | # reward for all examples in the batch 293 | R = tf.cast(tf.equal(max_p_y, correct_y), tf.float32) 294 | reward = tf.reduce_mean(R) # mean reward 295 | R = tf.reshape(R, (batch_size, 1)) 296 | # R = tf.tile(R, [1, (nGlimpses)*2]) 297 | R = tf.tile(R, [1, 2]) 298 | # get the location 299 | p_loc = gaussian_pdf(mean_locs, sampled_locs) 300 | p_loc = tf.sigmoid(p_loc) 301 | # p_loc_orig = p_loc 302 | 303 | p_loc_list = tf.unstack(p_loc, axis=1) 304 | p_loc_first = p_loc_list[1] 305 | #p_loc = tf.reshape(p_loc, (batch_size, (nGlimpses) * 2)) 306 | 307 | # define the cost function 308 | J = tf.concat(axis=1, values=[tf.log(p_y + SMALL_NUM) * onehot_labels_placeholder, tf.log(p_loc_first + SMALL_NUM) * (R - no_grad_b)]) 309 | # J = tf.concat(axis=1, values=[tf.log(p_y + SMALL_NUM) * (onehot_labels_placeholder), 310 | # tf.log(p_loc_first + SMALL_NUM) * (R - b)]) 311 | # J = tf.concat(axis=1, values=[tf.log(p_y + SMALL_NUM) * (onehot_labels_placeholder), 312 | # tf.log(p_loc_first + SMALL_NUM) * R]) 
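# Note (added comments, not in the original source): the commented-out variants
# above differ only in how the baseline enters the REINFORCE term.
# tf.stop_gradient(b) makes (R - no_grad_b) a plain scaling factor, so the
# gradient of tf.log(p_loc_first + SMALL_NUM) * (R - no_grad_b) flows into the
# location network only. The baseline itself is trained by the separate
# -tf.reduce_sum(tf.square(R - b), 1) term below: e.g. with R = 1 and b = 0.3,
# the location log-likelihood is scaled by the advantage 0.7, while the
# baseline weights are updated only through the squared-error term.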
313 | 314 | J = tf.reduce_sum(J, 1) 315 | J = J - tf.reduce_sum(tf.square(R - b), 1) 316 | J = tf.reduce_mean(J, 0) 317 | cost = -J 318 | 319 | # define the optimizer 320 | optimizer = tf.train.MomentumOptimizer(lr, momentumValue) 321 | # optimizer = tf.train.AdagradOptimizer(lr) 322 | train_op = optimizer.minimize(cost, global_step) 323 | 324 | return cost, reward, max_p_y, correct_y, train_op, b, tf.reduce_mean(b), tf.reduce_mean(R - b), lr 325 | 326 | 327 | def preTrain(outputs): 328 | lr_r = 1e-3 329 | # consider the action at the last time step 330 | outputs = outputs[-1] # look at ONLY THE END of the sequence 331 | outputs = tf.reshape(outputs, (batch_size, cell_out_size)) 332 | # if preTraining: 333 | reconstruction = tf.sigmoid(tf.matmul(outputs, Wr_h_r) + Br_h_r) 334 | reconstructionCost = tf.reduce_mean(tf.square(inputs_placeholder - reconstruction)) 335 | 336 | train_op_r = tf.train.RMSPropOptimizer(lr_r).minimize(reconstructionCost) 337 | return reconstructionCost, reconstruction, train_op_r 338 | 339 | 340 | def evaluate(): 341 | data = dataset.test 342 | batches_in_epoch = len(data._images) // batch_size 343 | accuracy = 0 344 | 345 | for i in range(batches_in_epoch): 346 | nextX, nextY = dataset.test.next_batch(batch_size) 347 | if translateMnist: 348 | nextX, _ = convertTranslated(nextX, MNIST_SIZE, img_size) 349 | feed_dict = {inputs_placeholder: nextX, labels_placeholder: nextY, 350 | onehot_labels_placeholder: dense_to_one_hot(nextY)} 351 | r = sess.run(reward, feed_dict=feed_dict) 352 | accuracy += r 353 | 354 | accuracy /= batches_in_epoch 355 | print(("ACCURACY: " + str(accuracy))) 356 | 357 | 358 | def convertTranslated(images, initImgSize, finalImgSize): 359 | size_diff = finalImgSize - initImgSize 360 | newimages = np.zeros([batch_size, finalImgSize*finalImgSize]) 361 | imgCoord = np.zeros([batch_size,2]) 362 | for k in range(batch_size): 363 | image = images[k, :] 364 | image = np.reshape(image, (initImgSize, initImgSize)) 365 | # generate and save random coordinates 366 | randX = random.randint(0, size_diff) 367 | randY = random.randint(0, size_diff) 368 | imgCoord[k,:] = np.array([randX, randY]) 369 | # padding 370 | image = np.lib.pad(image, ((randX, size_diff - randX), (randY, size_diff - randY)), 'constant', constant_values = (0)) 371 | newimages[k, :] = np.reshape(image, (finalImgSize*finalImgSize)) 372 | 373 | return newimages, imgCoord 374 | 375 | 376 | def toMnistCoordinates(coordinate_tanh): 377 | ''' 378 | Transform coordinate in [-1,1] to mnist 379 | :param coordinate_tanh: vector in [-1,1] x [-1,1] 380 | :return: vector in the corresponding mnist coordinate 381 | ''' 382 | return np.round(((coordinate_tanh + 1) / 2.0) * img_size) 383 | 384 | 385 | def variable_summaries(var, name): 386 | """Attach a lot of summaries to a Tensor.""" 387 | with tf.name_scope('param_summaries'): 388 | mean = tf.reduce_mean(var) 389 | tf.summary.scalar('param_mean/' + name, mean) 390 | with tf.name_scope('param_stddev'): 391 | stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean))) 392 | tf.summary.scalar('param_sttdev/' + name, stddev) 393 | tf.summary.scalar('param_max/' + name, tf.reduce_max(var)) 394 | tf.summary.scalar('param_min/' + name, tf.reduce_min(var)) 395 | tf.summary.histogram(name, var) 396 | 397 | 398 | def plotWholeImg(img, img_size, sampled_locs_fetched): 399 | plt.imshow(np.reshape(img, [img_size, img_size]), 400 | cmap=plt.get_cmap('gray'), interpolation="nearest") 401 | 402 | plt.ylim((img_size - 1, 0)) 403 | plt.xlim((0, img_size - 1)) 404 | 405 | # 
transform the coordinate to mnist map 406 | sampled_locs_mnist_fetched = toMnistCoordinates(sampled_locs_fetched) 407 | # visualize the trace of successive nGlimpses (note that x and y coordinates are "flipped") 408 | plt.plot(sampled_locs_mnist_fetched[0, :, 1], sampled_locs_mnist_fetched[0, :, 0], '-o', 409 | color='lawngreen') 410 | plt.plot(sampled_locs_mnist_fetched[0, -1, 1], sampled_locs_mnist_fetched[0, -1, 0], 'o', 411 | color='red') 412 | 413 | 414 | g = tf.get_default_graph() 415 | with g.device("/gpu:1"): 416 | # set the learning rate 417 | global_step = tf.Variable(0, trainable=False) 418 | lr = tf.train.exponential_decay(initLr, global_step, lrDecayFreq, lrDecayRate, staircase=True) 419 | 420 | # preallocate x, y, baseline 421 | labels = tf.placeholder("float32", shape=[batch_size, n_classes]) 422 | labels_placeholder = tf.placeholder(tf.float32, shape=(batch_size), name="labels_raw") 423 | onehot_labels_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 10), name="labels_onehot") 424 | inputs_placeholder = tf.placeholder(tf.float32, shape=(batch_size, img_size * img_size), name="images") 425 | 426 | # declare the model parameters; the naming rule is: 427 | # the 1st capital letter: weights or bias (W = weights, B = bias) 428 | # the 2nd lowercase letter: the network (e.g.: g = glimpse network) 429 | # the 3rd and 4th letter(s): input-output mapping, which is clearly written in the variable name argument 430 | 431 | Wg_l_h = weight_variable((2, hl_size), "glimpseNet_wts_location_hidden", True) 432 | Bg_l_h = weight_variable((1,hl_size), "glimpseNet_bias_location_hidden", True) 433 | 434 | Wg_g_h = weight_variable((totalSensorBandwidth, hg_size), "glimpseNet_wts_glimpse_hidden", True) 435 | Bg_g_h = weight_variable((1,hg_size), "glimpseNet_bias_glimpse_hidden", True) 436 | 437 | Wg_hg_gf1 = weight_variable((hg_size, g_size), "glimpseNet_wts_hiddenGlimpse_glimpseFeature1", True) 438 | Wg_hl_gf1 = weight_variable((hl_size, g_size), "glimpseNet_wts_hiddenLocation_glimpseFeature1", True) 439 | Bg_hlhg_gf1 = weight_variable((1,g_size), "glimpseNet_bias_hGlimpse_hLocs_glimpseFeature1", True) 440 | 441 | Wc_g_h = weight_variable((cell_size, g_size), "coreNet_wts_glimpse_hidden", True) 442 | Bc_g_h = weight_variable((1,g_size), "coreNet_bias_glimpse_hidden", True) 443 | 444 | Wr_h_r = weight_variable((cell_out_size, img_size**2), "reconstructionNet_wts_hidden_action", True) 445 | Br_h_r = weight_variable((1, img_size**2), "reconstructionNet_bias_hidden_action", True) 446 | 447 | Wb_h_b = weight_variable((g_size, 1), "baselineNet_wts_hiddenState_baseline", True) 448 | Bb_h_b = weight_variable((1,1), "baselineNet_bias_hiddenState_baseline", True) 449 | 450 | Wl_h_l = weight_variable((cell_out_size, 2), "locationNet_wts_hidden_location", True) 451 | Bl_h_l = weight_variable((1, 2), "locationNet_bias_hidden_location", True) 452 | 453 | 454 | Wa_h_a = weight_variable((cell_out_size, n_classes), "actionNet_wts_hidden_action", True) 455 | Ba_h_a = weight_variable((1,n_classes), "actionNet_bias_hidden_action", True) 456 | 457 | # query the model output 458 | outputs = model() 459 | 460 | # convert list of tensors to one big tensor 461 | sampled_locs = tf.concat(axis=0, values=sampled_locs) 462 | sampled_locs = tf.reshape(sampled_locs, (nGlimpses, batch_size, 2)) 463 | sampled_locs = tf.transpose(sampled_locs, [1, 0, 2]) 464 | mean_locs = tf.concat(axis=0, values=mean_locs) 465 | mean_locs = tf.reshape(mean_locs, (nGlimpses, batch_size, 2)) 466 | mean_locs = tf.transpose(mean_locs,
[1, 0, 2]) 467 | glimpse_images = tf.concat(axis=0, values=glimpse_images) 468 | 469 | 470 | 471 | # compute the reward 472 | reconstructionCost, reconstruction, train_op_r = preTrain(outputs) 473 | cost, reward, predicted_labels, correct_labels, train_op, b, avg_b, rminusb, lr = calc_reward(outputs) 474 | 475 | # tensorboard visualization for the parameters 476 | variable_summaries(Wg_l_h, "glimpseNet_wts_location_hidden") 477 | variable_summaries(Bg_l_h, "glimpseNet_bias_location_hidden") 478 | variable_summaries(Wg_g_h, "glimpseNet_wts_glimpse_hidden") 479 | variable_summaries(Bg_g_h, "glimpseNet_bias_glimpse_hidden") 480 | variable_summaries(Wg_hg_gf1, "glimpseNet_wts_hiddenGlimpse_glimpseFeature1") 481 | variable_summaries(Wg_hl_gf1, "glimpseNet_wts_hiddenLocation_glimpseFeature1") 482 | variable_summaries(Bg_hlhg_gf1, "glimpseNet_bias_hGlimpse_hLocs_glimpseFeature1") 483 | 484 | variable_summaries(Wc_g_h, "coreNet_wts_glimpse_hidden") 485 | variable_summaries(Bc_g_h, "coreNet_bias_glimpse_hidden") 486 | 487 | variable_summaries(Wb_h_b, "baselineNet_wts_hiddenState_baseline") 488 | variable_summaries(Bb_h_b, "baselineNet_bias_hiddenState_baseline") 489 | 490 | variable_summaries(Wl_h_l, "locationNet_wts_hidden_location") 491 | 492 | variable_summaries(Wa_h_a, 'actionNet_wts_hidden_action') 493 | variable_summaries(Ba_h_a, 'actionNet_bias_hidden_action') 494 | 495 | # tensorboard visualization for the performance metrics 496 | tf.summary.scalar("reconstructionCost", reconstructionCost) 497 | tf.summary.scalar("reward", reward) 498 | tf.summary.scalar("cost", cost) 499 | tf.summary.scalar("mean(b)", avg_b) 500 | tf.summary.scalar("mean(R-b)", rminusb) 501 | summary_op = tf.summary.merge_all() 502 | 503 | 504 | ####################################### START RUNNING THE MODEL ####################################### 505 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True) 506 | sess_config.gpu_options.allow_growth = True 507 | sess = tf.Session(graph=g, config=sess_config) 508 | saver = tf.train.Saver() 509 | b_fetched = np.zeros((batch_size, (nGlimpses)*2)) 510 | 511 | init = tf.global_variables_initializer() 512 | sess.run(init) 513 | 514 | if eval_only: 515 | evaluate() 516 | else: 517 | summary_writer = tf.summary.FileWriter(summaryFolderName, graph=sess.graph) 518 | 519 | if draw: 520 | fig = plt.figure(1) 521 | txt = fig.suptitle("-", fontsize=36, fontweight='bold') 522 | plt.ion() 523 | plt.show() 524 | plt.subplots_adjust(top=0.7) 525 | plotImgs = [] 526 | 527 | if drawReconsturction: 528 | fig = plt.figure(2) 529 | txt = fig.suptitle("-", fontsize=36, fontweight='bold') 530 | plt.ion() 531 | plt.show() 532 | 533 | if preTraining: 534 | for epoch_r in range(1,preTraining_epoch): 535 | nextX, _ = dataset.train.next_batch(batch_size) 536 | nextX_orig = nextX 537 | if translateMnist: 538 | nextX, _ = convertTranslated(nextX, MNIST_SIZE, img_size) 539 | 540 | fetches_r = [reconstructionCost, reconstruction, train_op_r] 541 | 542 | reconstructionCost_fetched, reconstruction_fetched, train_op_r_fetched = sess.run(fetches_r, feed_dict={inputs_placeholder: nextX}) 543 | 544 | if epoch_r % 20 == 0: 545 | print(('Step %d: reconstructionCost = %.5f' % (epoch_r, reconstructionCost_fetched))) 546 | if epoch_r % 100 == 0: 547 | if drawReconsturction: 548 | fig = plt.figure(2) 549 | 550 | plt.subplot(1, 2, 1) 551 | plt.imshow(np.reshape(nextX[0, :], [img_size, img_size]), 552 | cmap=plt.get_cmap('gray'), interpolation="nearest") 553 | plt.ylim((img_size - 1, 0)) 554 
| plt.xlim((0, img_size - 1)) 555 | 556 | plt.subplot(1, 2, 2) 557 | plt.imshow(np.reshape(reconstruction_fetched[0, :], [img_size, img_size]), 558 | cmap=plt.get_cmap('gray'), interpolation="nearest") 559 | plt.ylim((img_size - 1, 0)) 560 | plt.xlim((0, img_size - 1)) 561 | plt.draw() 562 | plt.pause(0.0001) 563 | # plt.show() 564 | 565 | 566 | # training 567 | for epoch in range(start_step + 1, max_iters): 568 | start_time = time.time() 569 | 570 | # get the next batch of examples 571 | nextX, nextY = dataset.train.next_batch(batch_size) 572 | nextX_orig = nextX 573 | if translateMnist: 574 | nextX, nextX_coord = convertTranslated(nextX, MNIST_SIZE, img_size) 575 | 576 | feed_dict = {inputs_placeholder: nextX, labels_placeholder: nextY, \ 577 | onehot_labels_placeholder: dense_to_one_hot(nextY)} 578 | 579 | fetches = [train_op, cost, reward, predicted_labels, correct_labels, glimpse_images, avg_b, rminusb, \ 580 | mean_locs, sampled_locs, lr] 581 | # feed them to the model 582 | results = sess.run(fetches, feed_dict=feed_dict) 583 | 584 | _, cost_fetched, reward_fetched, prediction_labels_fetched, correct_labels_fetched, glimpse_images_fetched, \ 585 | avg_b_fetched, rminusb_fetched, mean_locs_fetched, sampled_locs_fetched, lr_fetched = results 586 | 587 | 588 | duration = time.time() - start_time 589 | 590 | if epoch % 20 == 0: 591 | print(('Step %d: cost = %.5f reward = %.5f (%.3f sec) b = %.5f R-b = %.5f, LR = %.5f' 592 | % (epoch, cost_fetched, reward_fetched, duration, avg_b_fetched, rminusb_fetched, lr_fetched))) 593 | summary_str = sess.run(summary_op, feed_dict=feed_dict) 594 | summary_writer.add_summary(summary_str, epoch) 595 | # if saveImgs: 596 | # plt.savefig(imgsFolderName + simulationName + '_ep%.6d.png' % (epoch)) 597 | 598 | if epoch % 5000 == 0: 599 | saver.save(sess, save_dir + save_prefix + str(epoch) + ".ckpt") 600 | evaluate() 601 | 602 | ##### DRAW WINDOW ################ 603 | f_glimpse_images = np.reshape(glimpse_images_fetched, \ 604 | (nGlimpses, batch_size, depth, sensorBandwidth, sensorBandwidth)) 605 | 606 | if draw: 607 | if animate: 608 | fillList = False 609 | if len(plotImgs) == 0: 610 | fillList = True 611 | 612 | # display the first image in the in mini-batch 613 | nCols = depth+1 614 | plt.subplot2grid((depth, nCols), (0, 1), rowspan=depth, colspan=depth) 615 | # display the entire image 616 | plotWholeImg(nextX[0, :], img_size, sampled_locs_fetched) 617 | 618 | # display the glimpses 619 | for y in range(nGlimpses): 620 | txt.set_text('Epoch: %.6d \nPrediction: %i -- Truth: %i\nStep: %i/%i' 621 | % (epoch, prediction_labels_fetched[0], correct_labels_fetched[0], (y + 1), nGlimpses)) 622 | 623 | for x in range(depth): 624 | plt.subplot(depth, nCols, 1 + nCols * x) 625 | if fillList: 626 | plotImg = plt.imshow(f_glimpse_images[y, 0, x], cmap=plt.get_cmap('gray'), 627 | interpolation="nearest") 628 | plotImg.autoscale() 629 | plotImgs.append(plotImg) 630 | else: 631 | plotImgs[x].set_data(f_glimpse_images[y, 0, x]) 632 | plotImgs[x].autoscale() 633 | fillList = False 634 | 635 | # fig.canvas.draw() 636 | time.sleep(0.1) 637 | plt.pause(0.00005) 638 | 639 | else: 640 | txt.set_text('PREDICTION: %i\nTRUTH: %i' % (prediction_labels_fetched[0], correct_labels_fetched[0])) 641 | for x in range(depth): 642 | for y in range(nGlimpses): 643 | plt.subplot(depth, nGlimpses, x * nGlimpses + y + 1) 644 | plt.imshow(f_glimpse_images[y, 0, x], cmap=plt.get_cmap('gray'), interpolation="nearest") 645 | 646 | plt.draw() 647 | time.sleep(0.05) 648 | plt.pause(0.0001) 649 
| 650 | sess.close() -------------------------------------------------------------------------------- /report.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 2 | Processing file 'ram.py' 3 | outputting to 'ram_up.py' 4 | -------------------------------------------------------------------------------- 5 | 6 | 'ram.py' Line 448 7 | -------------------------------------------------------------------------------- 8 | 9 | Added keyword 'concat_dim' to reordered function 'tf.concat' 10 | Added keyword 'values' to reordered function 'tf.concat' 11 | 12 | Old: sampled_locs = tf.concat(0, sampled_locs) 13 | 14 | New: sampled_locs = tf.concat(axis=0, values=sampled_locs) 15 | ~~~~~ ~~~~~~~ 16 | 17 | 'ram.py' Line 451 18 | -------------------------------------------------------------------------------- 19 | 20 | Added keyword 'concat_dim' to reordered function 'tf.concat' 21 | Added keyword 'values' to reordered function 'tf.concat' 22 | 23 | Old: mean_locs = tf.concat(0, mean_locs) 24 | 25 | New: mean_locs = tf.concat(axis=0, values=mean_locs) 26 | ~~~~~ ~~~~~~~ 27 | 28 | 'ram.py' Line 484 29 | -------------------------------------------------------------------------------- 30 | 31 | Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' 32 | 33 | Old: tf.scalar_summary("reward", reward) 34 | ~~~~~~~~~~~~~~~~~ 35 | New: tf.summary.scalar("reward", reward) 36 | ~~~~~~~~~~~~~~~~~ 37 | 38 | 'ram.py' Line 485 39 | -------------------------------------------------------------------------------- 40 | 41 | Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' 42 | 43 | Old: tf.scalar_summary("cost", cost) 44 | ~~~~~~~~~~~~~~~~~ 45 | New: tf.summary.scalar("cost", cost) 46 | ~~~~~~~~~~~~~~~~~ 47 | 48 | 'ram.py' Line 454 49 | -------------------------------------------------------------------------------- 50 | 51 | Added keyword 'concat_dim' to reordered function 'tf.concat' 52 | Added keyword 'values' to reordered function 'tf.concat' 53 | 54 | Old: glimpse_images = tf.concat(0, glimpse_images) 55 | 56 | New: glimpse_images = tf.concat(axis=0, values=glimpse_images) 57 | ~~~~~ ~~~~~~~ 58 | 59 | 'ram.py' Line 487 60 | -------------------------------------------------------------------------------- 61 | 62 | Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' 63 | 64 | Old: tf.scalar_summary(" mean(R - b)", rminusb) 65 | ~~~~~~~~~~~~~~~~~ 66 | New: tf.summary.scalar(" mean(R - b)", rminusb) 67 | ~~~~~~~~~~~~~~~~~ 68 | 69 | 'ram.py' Line 488 70 | -------------------------------------------------------------------------------- 71 | 72 | Renamed function 'tf.merge_all_summaries' to 'tf.summary.merge_all' 73 | 74 | Old: summary_op = tf.merge_all_summaries() 75 | ~~~~~~~~~~~~~~~~~~~~~~ 76 | New: summary_op = tf.summary.merge_all() 77 | ~~~~~~~~~~~~~~~~~~~~ 78 | 79 | 'ram.py' Line 486 80 | -------------------------------------------------------------------------------- 81 | 82 | Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' 83 | 84 | Old: tf.scalar_summary("mean(b)", avg_b) 85 | ~~~~~~~~~~~~~~~~~ 86 | New: tf.summary.scalar("mean(b)", avg_b) 87 | ~~~~~~~~~~~~~~~~~ 88 | 89 | 'ram.py' Line 300 90 | -------------------------------------------------------------------------------- 91 | 92 | Added keyword 'concat_dim' to reordered function 'tf.concat' 93 | Added keyword 'values' to reordered function 'tf.concat' 94 | 95 | Old: J = tf.concat(1, [tf.log(p_y + SMALL_NUM) 
* (onehot_labels_placeholder), tf.log(p_loc + SMALL_NUM) * (R - no_grad_b)]) 96 | 97 | New: J = tf.concat(axis=1, values=[tf.log(p_y + SMALL_NUM) * (onehot_labels_placeholder), tf.log(p_loc + SMALL_NUM) * (R - no_grad_b)]) 98 | ~~~~~ ~~~~~~~ 99 | 100 | 'ram.py' Line 502 101 | -------------------------------------------------------------------------------- 102 | 103 | Renamed function 'tf.train.SummaryWriter' to 'tf.summary.FileWriter' 104 | 105 | Old: summary_writer = tf.train.SummaryWriter(summaryFolderName, graph=sess.graph) 106 | ~~~~~~~~~~~~~~~~~~~~~~ 107 | New: summary_writer = tf.summary.FileWriter(summaryFolderName, graph=sess.graph) 108 | ~~~~~~~~~~~~~~~~~~~~~ 109 | 110 | 'ram.py' Line 496 111 | -------------------------------------------------------------------------------- 112 | 113 | Renamed function 'tf.initialize_all_variables' to 'tf.global_variables_initializer' 114 | 115 | Old: init = tf.initialize_all_variables() 116 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 117 | New: init = tf.global_variables_initializer() 118 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 119 | 120 | 'ram.py' Line 483 121 | -------------------------------------------------------------------------------- 122 | 123 | Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' 124 | 125 | Old: tf.scalar_summary("reconstructionCost", reconstructionCost) 126 | ~~~~~~~~~~~~~~~~~ 127 | New: tf.summary.scalar("reconstructionCost", reconstructionCost) 128 | ~~~~~~~~~~~~~~~~~ 129 | 130 | 'ram.py' Line 277 131 | -------------------------------------------------------------------------------- 132 | 133 | Renamed function 'tf.pack' to 'tf.stack' 134 | 135 | Old: b = tf.pack(baselines) 136 | ~~~~~~~ 137 | New: b = tf.stack(baselines) 138 | ~~~~~~~~ 139 | 140 | 'ram.py' Line 278 141 | -------------------------------------------------------------------------------- 142 | 143 | Added keyword 'concat_dim' to reordered function 'tf.concat' 144 | Added keyword 'values' to reordered function 'tf.concat' 145 | 146 | Old: b = tf.concat(2, [b, b]) 147 | 148 | New: b = tf.concat(axis=2, values=[b, b]) 149 | ~~~~~ ~~~~~~~ 150 | 151 | 'ram.py' Line 152 152 | -------------------------------------------------------------------------------- 153 | 154 | Renamed function 'tf.pack' to 'tf.stack' 155 | 156 | Old: zooms.append(tf.pack(imgZooms)) 157 | ~~~~~~~ 158 | New: zooms.append(tf.stack(imgZooms)) 159 | ~~~~~~~~ 160 | 161 | 'ram.py' Line 377 162 | -------------------------------------------------------------------------------- 163 | 164 | Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' 165 | 166 | Old: tf.scalar_summary('param_mean/' + name, mean) 167 | ~~~~~~~~~~~~~~~~~ 168 | New: tf.summary.scalar('param_mean/' + name, mean) 169 | ~~~~~~~~~~~~~~~~~ 170 | 171 | 'ram.py' Line 154 172 | -------------------------------------------------------------------------------- 173 | 174 | Renamed function 'tf.pack' to 'tf.stack' 175 | 176 | Old: zooms = tf.pack(zooms) 177 | ~~~~~~~ 178 | New: zooms = tf.stack(zooms) 179 | ~~~~~~~~ 180 | 181 | 'ram.py' Line 380 182 | -------------------------------------------------------------------------------- 183 | 184 | Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' 185 | 186 | Old: tf.scalar_summary('param_sttdev/' + name, stddev) 187 | ~~~~~~~~~~~~~~~~~ 188 | New: tf.summary.scalar('param_sttdev/' + name, stddev) 189 | ~~~~~~~~~~~~~~~~~ 190 | 191 | 'ram.py' Line 381 192 | -------------------------------------------------------------------------------- 193 | 194 | Renamed function 'tf.scalar_summary' to 
'tf.summary.scalar' 195 | 196 | Old: tf.scalar_summary('param_max/' + name, tf.reduce_max(var)) 197 | ~~~~~~~~~~~~~~~~~ 198 | New: tf.summary.scalar('param_max/' + name, tf.reduce_max(var)) 199 | ~~~~~~~~~~~~~~~~~ 200 | 201 | 'ram.py' Line 382 202 | -------------------------------------------------------------------------------- 203 | 204 | Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' 205 | 206 | Old: tf.scalar_summary('param_min/' + name, tf.reduce_min(var)) 207 | ~~~~~~~~~~~~~~~~~ 208 | New: tf.summary.scalar('param_min/' + name, tf.reduce_min(var)) 209 | ~~~~~~~~~~~~~~~~~ 210 | 211 | 'ram.py' Line 383 212 | -------------------------------------------------------------------------------- 213 | 214 | Renamed function 'tf.histogram_summary' to 'tf.summary.histogram' 215 | 216 | Old: tf.histogram_summary(name, var) 217 | ~~~~~~~~~~~~~~~~~~~~ 218 | New: tf.summary.histogram(name, var) 219 | ~~~~~~~~~~~~~~~~~~~~ 220 | 221 | 222 | -------------------------------------------------------------------------------- /tf_mnist_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Functions for downloading and reading MNIST data.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | import gzip 20 | import os 21 | import numpy 22 | from six.moves import urllib 23 | from six.moves import xrange # pylint: disable=redefined-builtin 24 | SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' 25 | def maybe_download(filename, work_directory): 26 | """Download the data from Yann's website, unless it's already here.""" 27 | if not os.path.exists(work_directory): 28 | os.mkdir(work_directory) 29 | filepath = os.path.join(work_directory, filename) 30 | if not os.path.exists(filepath): 31 | filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) 32 | statinfo = os.stat(filepath) 33 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 34 | return filepath 35 | def _read32(bytestream): 36 | dt = numpy.dtype(numpy.uint32).newbyteorder('>') 37 | return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 38 | def extract_images(filename): 39 | """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" 40 | print('Extracting', filename) 41 | with gzip.open(filename) as bytestream: 42 | magic = _read32(bytestream) 43 | if magic != 2051: 44 | raise ValueError( 45 | 'Invalid magic number %d in MNIST image file: %s' % 46 | (magic, filename)) 47 | num_images = _read32(bytestream) 48 | rows = _read32(bytestream) 49 | cols = _read32(bytestream) 50 | buf = bytestream.read(rows * cols * num_images) 51 | data = numpy.frombuffer(buf, dtype=numpy.uint8) 52 | data = data.reshape(num_images, rows, cols, 1) 53 | return data 54 | def dense_to_one_hot(labels_dense, num_classes=10): 55 | """Convert class labels from scalars to one-hot vectors.""" 56 | num_labels = labels_dense.shape[0] 57 | index_offset = numpy.arange(num_labels) * num_classes 58 | labels_one_hot = numpy.zeros((num_labels, num_classes)) 59 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 60 | return labels_one_hot 61 | def extract_labels(filename, one_hot=False): 62 | """Extract the labels into a 1D uint8 numpy array [index].""" 63 | print('Extracting', filename) 64 | with gzip.open(filename) as bytestream: 65 | magic = _read32(bytestream) 66 | if magic != 2049: 67 | raise ValueError( 68 | 'Invalid magic number %d in MNIST label file: %s' % 69 | (magic, filename)) 70 | num_items = _read32(bytestream) 71 | buf = bytestream.read(num_items) 72 | labels = numpy.frombuffer(buf, dtype=numpy.uint8) 73 | if one_hot: 74 | return dense_to_one_hot(labels) 75 | return labels 76 | class DataSet(object): 77 | def __init__(self, images, labels, fake_data=False, one_hot=False): 78 | """Construct a DataSet. one_hot arg is used only if fake_data is true.""" 79 | if fake_data: 80 | self._num_examples = 10000 81 | self.one_hot = one_hot 82 | else: 83 | assert images.shape[0] == labels.shape[0], ( 84 | 'images.shape: %s labels.shape: %s' % (images.shape, 85 | labels.shape)) 86 | self._num_examples = images.shape[0] 87 | # Convert shape from [num examples, rows, columns, depth] 88 | # to [num examples, rows*columns] (assuming depth == 1) 89 | assert images.shape[3] == 1 90 | images = images.reshape(images.shape[0], 91 | images.shape[1] * images.shape[2]) 92 | # Convert from [0, 255] -> [0.0, 1.0]. 
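# Note (added comment, not in the original source): the [0, 255] -> [0.0, 1.0]
# rescaling below matters for this repository in particular. preTrain() in
# ram.py/ram_up.py regresses a tf.sigmoid(...) reconstruction against these
# images with a squared error, so the pixel targets must already lie in the
# sigmoid's (0, 1) output range; e.g. a uint8 value of 128 becomes
# 128 / 255.0 ~= 0.502.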
93 | images = images.astype(numpy.float32) 94 | images = numpy.multiply(images, 1.0 / 255.0) 95 | self._images = images 96 | self._labels = labels 97 | self._epochs_completed = 0 98 | self._index_in_epoch = 0 99 | @property 100 | def images(self): 101 | return self._images 102 | @property 103 | def labels(self): 104 | return self._labels 105 | @property 106 | def num_examples(self): 107 | return self._num_examples 108 | @property 109 | def epochs_completed(self): 110 | return self._epochs_completed 111 | def next_batch(self, batch_size, fake_data=False): 112 | """Return the next `batch_size` examples from this data set.""" 113 | if fake_data: 114 | fake_image = [1] * 784 115 | if self.one_hot: 116 | fake_label = [1] + [0] * 9 117 | else: 118 | fake_label = 0 119 | return [fake_image for _ in xrange(batch_size)], [ 120 | fake_label for _ in xrange(batch_size)] 121 | start = self._index_in_epoch 122 | self._index_in_epoch += batch_size 123 | if self._index_in_epoch > self._num_examples: 124 | # Finished epoch 125 | self._epochs_completed += 1 126 | # Shuffle the data 127 | perm = numpy.arange(self._num_examples) 128 | numpy.random.shuffle(perm) 129 | self._images = self._images[perm] 130 | self._labels = self._labels[perm] 131 | # Start next epoch 132 | start = 0 133 | self._index_in_epoch = batch_size 134 | assert batch_size <= self._num_examples 135 | end = self._index_in_epoch 136 | return self._images[start:end], self._labels[start:end] 137 | def read_data_sets(train_dir, fake_data=False, one_hot=False): 138 | class DataSets(object): 139 | pass 140 | data_sets = DataSets() 141 | if fake_data: 142 | data_sets.train = DataSet([], [], fake_data=True, one_hot=one_hot) 143 | data_sets.validation = DataSet([], [], fake_data=True, one_hot=one_hot) 144 | data_sets.test = DataSet([], [], fake_data=True, one_hot=one_hot) 145 | return data_sets 146 | TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 147 | TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 148 | TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 149 | TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 150 | VALIDATION_SIZE = 5000 151 | local_file = maybe_download(TRAIN_IMAGES, train_dir) 152 | train_images = extract_images(local_file) 153 | local_file = maybe_download(TRAIN_LABELS, train_dir) 154 | train_labels = extract_labels(local_file, one_hot=one_hot) 155 | local_file = maybe_download(TEST_IMAGES, train_dir) 156 | test_images = extract_images(local_file) 157 | local_file = maybe_download(TEST_LABELS, train_dir) 158 | test_labels = extract_labels(local_file, one_hot=one_hot) 159 | validation_images = train_images[:VALIDATION_SIZE] 160 | validation_labels = train_labels[:VALIDATION_SIZE] 161 | train_images = train_images[VALIDATION_SIZE:] 162 | train_labels = train_labels[VALIDATION_SIZE:] 163 | data_sets.train = DataSet(train_images, train_labels) 164 | data_sets.validation = DataSet(validation_images, validation_labels) 165 | data_sets.test = DataSet(test_images, test_labels) 166 | return data_sets -------------------------------------------------------------------------------- /tf_mnist_loader.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jtkim-kaist/ram_modified/e77fd42c3ad19b8d19d5e3a17d4ac3321f66ac90/tf_mnist_loader.pyc -------------------------------------------------------------------------------- /tf_upgrade.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
--------------------------------------------------------------------------------
/tf_mnist_loader.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jtkim-kaist/ram_modified/e77fd42c3ad19b8d19d5e3a17d4ac3321f66ac90/tf_mnist_loader.pyc
--------------------------------------------------------------------------------
/tf_upgrade.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Upgrader for Python scripts from pre-1.0 TensorFlow to 1.0 TensorFlow."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import argparse
21 | import ast
22 | import collections
23 | import os
24 | import shutil
25 | import sys
26 | import tempfile
27 | import traceback
28 | 
29 | 
30 | class APIChangeSpec(object):
31 |   """List of maps that describe what changed in the API."""
32 | 
33 |   def __init__(self):
34 |     # Maps from a function name to a dictionary that describes how to
35 |     # map from an old argument keyword to the new argument keyword.
36 |     self.function_keyword_renames = {
37 |         "tf.count_nonzero": {
38 |             "reduction_indices": "axis"
39 |         },
40 |         "tf.reduce_all": {
41 |             "reduction_indices": "axis"
42 |         },
43 |         "tf.reduce_any": {
44 |             "reduction_indices": "axis"
45 |         },
46 |         "tf.reduce_max": {
47 |             "reduction_indices": "axis"
48 |         },
49 |         "tf.reduce_mean": {
50 |             "reduction_indices": "axis"
51 |         },
52 |         "tf.reduce_min": {
53 |             "reduction_indices": "axis"
54 |         },
55 |         "tf.reduce_prod": {
56 |             "reduction_indices": "axis"
57 |         },
58 |         "tf.reduce_sum": {
59 |             "reduction_indices": "axis"
60 |         },
61 |         "tf.reduce_logsumexp": {
62 |             "reduction_indices": "axis"
63 |         },
64 |         "tf.expand_dims": {
65 |             "dim": "axis"
66 |         },
67 |         "tf.argmax": {
68 |             "dimension": "axis"
69 |         },
70 |         "tf.argmin": {
71 |             "dimension": "axis"
72 |         },
73 |         "tf.reduce_join": {
74 |             "reduction_indices": "axis"
75 |         },
76 |         "tf.sparse_concat": {
77 |             "concat_dim": "axis"
78 |         },
79 |         "tf.sparse_split": {
80 |             "split_dim": "axis"
81 |         },
82 |         "tf.sparse_reduce_sum": {
83 |             "reduction_axes": "axis"
84 |         },
85 |         "tf.reverse_sequence": {
86 |             "seq_dim": "seq_axis",
87 |             "batch_dim": "batch_axis"
88 |         },
89 |         "tf.sparse_reduce_sum_sparse": {
90 |             "reduction_axes": "axis"
91 |         },
92 |         "tf.squeeze": {
93 |             "squeeze_dims": "axis"
94 |         },
95 |         "tf.split": {
96 |             "split_dim": "axis",
97 |             "num_split": "num_or_size_splits"
98 |         },
99 |         "tf.concat": {
100 |             "concat_dim": "axis"
101 |         },
102 |     }
103 | 
104 |     # Mapping from function to the new name of the function
105 |     self.function_renames = {
106 |         "tf.inv": "tf.reciprocal",
107 |         "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar",
108 |         "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram",
109 |         "tf.listdiff": "tf.setdiff1d",
110 |         "tf.list_diff": "tf.setdiff1d",
111 |         "tf.mul": "tf.multiply",
112 |         "tf.neg": "tf.negative",
113 |         "tf.sub": "tf.subtract",
114 |         "tf.train.SummaryWriter": "tf.summary.FileWriter",
115 |         "tf.scalar_summary": "tf.summary.scalar",
116 |         "tf.histogram_summary": "tf.summary.histogram",
117 |         "tf.audio_summary": "tf.summary.audio",
118 |         "tf.image_summary": "tf.summary.image",
119 |         "tf.merge_summary": "tf.summary.merge",
120 |         "tf.merge_all_summaries": "tf.summary.merge_all",
| "tf.image.per_image_whitening": "tf.image.per_image_standardization", 122 | "tf.all_variables": "tf.global_variables", 123 | "tf.VARIABLES": "tf.GLOBAL_VARIABLES", 124 | "tf.initialize_all_variables": "tf.global_variables_initializer", 125 | "tf.initialize_variables": "tf.variables_initializer", 126 | "tf.initialize_local_variables": "tf.local_variables_initializer", 127 | "tf.batch_matrix_diag": "tf.matrix_diag", 128 | "tf.batch_band_part": "tf.band_part", 129 | "tf.batch_set_diag": "tf.set_diag", 130 | "tf.batch_matrix_transpose": "tf.matrix_transpose", 131 | "tf.batch_matrix_determinant": "tf.matrix_determinant", 132 | "tf.batch_matrix_inverse": "tf.matrix_inverse", 133 | "tf.batch_cholesky": "tf.cholesky", 134 | "tf.batch_cholesky_solve": "tf.cholesky_solve", 135 | "tf.batch_matrix_solve": "tf.matrix_solve", 136 | "tf.batch_matrix_triangular_solve": "tf.matrix_triangular_solve", 137 | "tf.batch_matrix_solve_ls": "tf.matrix_solve_ls", 138 | "tf.batch_self_adjoint_eig": "tf.self_adjoint_eig", 139 | "tf.batch_self_adjoint_eigvals": "tf.self_adjoint_eigvals", 140 | "tf.batch_svd": "tf.svd", 141 | "tf.batch_fft": "tf.fft", 142 | "tf.batch_ifft": "tf.ifft", 143 | "tf.batch_ifft2d": "tf.ifft2d", 144 | "tf.batch_fft3d": "tf.fft3d", 145 | "tf.batch_ifft3d": "tf.ifft3d", 146 | "tf.select": "tf.where", 147 | "tf.complex_abs": "tf.abs", 148 | "tf.batch_matmul": "tf.matmul", 149 | "tf.pack": "tf.stack", 150 | "tf.unpack": "tf.unstack", 151 | } 152 | 153 | self.change_to_function = { 154 | "tf.ones_initializer", 155 | "tf.zeros_initializer", 156 | } 157 | 158 | # Functions that were reordered should be changed to the new keyword args 159 | # for safety, if positional arguments are used. If you have reversed the 160 | # positional arguments yourself, this could do the wrong thing. 161 | self.function_reorders = { 162 | "tf.split": ["axis", "num_or_size_splits", "value", "name"], 163 | "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"], 164 | "tf.concat": ["concat_dim", "values", "name"], 165 | "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"], 166 | "tf.nn.softmax_cross_entropy_with_logits": [ 167 | "logits", "labels", "dim", "name"], 168 | "tf.nn.sparse_softmax_cross_entropy_with_logits": [ 169 | "logits", "labels", "name"], 170 | "tf.nn.sigmoid_cross_entropy_with_logits": [ 171 | "logits", "labels", "name"] 172 | } 173 | 174 | # Specially handled functions. 175 | self.function_handle = {"tf.reverse": self._reverse_handler} 176 | 177 | @staticmethod 178 | def _reverse_handler(file_edit_recorder, node): 179 | # TODO(aselle): Could check for a literal list of bools and try to convert 180 | # them to indices. 181 | comment = ("ERROR: tf.reverse has had its argument semantics changed\n" 182 | "significantly the converter cannot detect this reliably, so you" 183 | "need to inspect this usage manually.\n") 184 | file_edit_recorder.add(comment, 185 | node.lineno, 186 | node.col_offset, 187 | "tf.reverse", 188 | "tf.reverse", 189 | error="tf.reverse requires manual check.") 190 | 191 | 192 | class FileEditTuple(collections.namedtuple( 193 | "FileEditTuple", ["comment", "line", "start", "old", "new"])): 194 | """Each edit that is recorded by a FileEditRecorder. 195 | 196 | Fields: 197 | comment: A description of the edit and why it was made. 198 | line: The line number in the file where the edit occurs (1-indexed). 199 | start: The line number in the file where the edit occurs (0-indexed). 200 | old: text string to remove (this must match what was in file). 
192 | class FileEditTuple(collections.namedtuple(
193 |     "FileEditTuple", ["comment", "line", "start", "old", "new"])):
194 |   """Each edit that is recorded by a FileEditRecorder.
195 | 
196 |   Fields:
197 |     comment: A description of the edit and why it was made.
198 |     line: The line number in the file where the edit occurs (1-indexed).
199 |     start: The column offset in the line where the edit occurs (0-indexed).
200 |     old: text string to remove (this must match what was in the file).
201 |     new: text string to add in place of `old`.
202 |   """
203 | 
204 |   __slots__ = ()
205 | 
206 | 
207 | class FileEditRecorder(object):
208 |   """Record changes that need to be done to the file."""
209 | 
210 |   def __init__(self, filename):
211 |     # All edits for a line are kept in a list, keyed by 1-indexed line number.
212 |     self._filename = filename
213 | 
214 |     self._line_to_edit = collections.defaultdict(list)
215 |     self._errors = []
216 | 
217 |   def process(self, text):
218 |     """Apply the recorded edits to a list of lines of text.
219 | 
220 |     Args:
221 |       text: A list of lines of text (assumed to contain newlines)
222 |     Returns:
223 |       A tuple of the modified text, a textual change report, and a list of
224 |       errors.
225 |     Raises:
226 |       ValueError: if substitution source location does not have expected text.
227 |     """
228 | 
229 |     change_report = ""
230 | 
231 |     # Iterate over each line that has recorded edits.
232 |     for line, edits in self._line_to_edit.items():
233 |       offset = 0
234 |       # Sort by column so that edits are processed left to right, making the
235 |       # indexing adjustments cumulative for edits that change the string
236 |       # length.
237 |       edits.sort(key=lambda x: x.start)
238 | 
239 |       # Extract each line to a list of characters, because mutable lists
240 |       # are editable, unlike immutable strings.
241 |       char_array = list(text[line - 1])
242 | 
243 |       # Record a description of the change
244 |       change_report += "%r Line %d\n" % (self._filename, line)
245 |       change_report += "-" * 80 + "\n\n"
246 |       for e in edits:
247 |         change_report += "%s\n" % e.comment
248 |       change_report += "\n    Old: %s" % (text[line - 1])
249 | 
250 |       # Make underscore buffers for underlining where in the line the edit was
251 |       change_list = [" "] * len(text[line - 1])
252 |       change_list_new = [" "] * len(text[line - 1])
253 | 
254 |       # Iterate for each edit
255 |       for e in edits:
256 |         # Create effective start, end by accounting for change in length due
257 |         # to previous edits
258 |         start_eff = e.start + offset
259 |         end_eff = start_eff + len(e.old)
260 | 
261 |         # Make sure the edit is changing what it should be changing
262 |         old_actual = "".join(char_array[start_eff:end_eff])
263 |         if old_actual != e.old:
264 |           raise ValueError("Expected text %r but got %r" %
265 |                            ("".join(e.old), "".join(old_actual)))
266 |         # Make the edit
267 |         char_array[start_eff:end_eff] = list(e.new)
268 | 
269 |         # Create the underline highlighting of the before and after
270 |         change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
271 |         change_list_new[start_eff:end_eff] = "~" * len(e.new)
272 | 
273 |         # Keep track of how to generate effective ranges
274 |         offset += len(e.new) - len(e.old)
275 | 
276 |       # Finish the report comment
277 |       change_report += "    %s\n" % "".join(change_list)
278 |       text[line - 1] = "".join(char_array)
279 |       change_report += "    New: %s" % (text[line - 1])
280 |       change_report += "    %s\n\n" % "".join(change_list_new)
281 |     return "".join(text), change_report, self._errors
282 | 
283 |   def add(self, comment, line, start, old, new, error=None):
284 |     """Add a new change that is needed.
285 | 
286 |     Args:
287 |       comment: A description of what was changed
288 |       line: Line number (1 indexed)
289 |       start: Column offset (0 indexed)
290 |       old: old text
291 |       new: new text
292 |       error: if set, marks an "edit" that cannot be fixed automatically
293 |     Returns:
294 |       None
295 |     """
296 | 
297 |     self._line_to_edit[line].append(
298 |         FileEditTuple(comment, line, start, old, new))
299 |     if error:
300 |       self._errors.append("%s:%d: %s" % (self._filename, line, error))
301 | 
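
`FileEditRecorder` is the low-level edit buffer: `add` queues a (line, column, old, new) substitution, and `process` applies every queued edit to the raw lines, verifying that `old` actually occupies the stated span before replacing it. A toy round trip (editor's sketch, not part of the repository; assumes the module is importable as `tf_upgrade`):

```python
from tf_upgrade import FileEditRecorder

recorder = FileEditRecorder("example.py")
lines = ["y = tf.mul(a, b)\n"]
# "tf.mul" occupies columns 4-9 of line 1; queue its rename.
recorder.add("Renamed function 'tf.mul' to 'tf.multiply'",
             line=1, start=4, old="tf.mul", new="tf.multiply")
new_text, report, errors = recorder.process(lines)
print(new_text)  # y = tf.multiply(a, b)
```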
302 | 
303 | class TensorFlowCallVisitor(ast.NodeVisitor):
304 |   """AST Visitor that finds TensorFlow Function calls.
305 | 
306 |   Updates function calls from old API version to new API version.
307 |   """
308 | 
309 |   def __init__(self, filename, lines):
310 |     self._filename = filename
311 |     self._file_edit = FileEditRecorder(filename)
312 |     self._lines = lines
313 |     self._api_change_spec = APIChangeSpec()
314 | 
315 |   def process(self, lines):
316 |     return self._file_edit.process(lines)
317 | 
318 |   def generic_visit(self, node):
319 |     ast.NodeVisitor.generic_visit(self, node)
320 | 
321 |   def _rename_functions(self, node, full_name):
322 |     function_renames = self._api_change_spec.function_renames
323 |     try:
324 |       new_name = function_renames[full_name]
325 |       self._file_edit.add("Renamed function %r to %r" % (full_name,
326 |                                                          new_name),
327 |                           node.lineno, node.col_offset, full_name, new_name)
328 |     except KeyError:
329 |       pass
330 | 
331 |   def _get_attribute_full_path(self, node):
332 |     """Traverse an attribute to generate a full name, e.g. tf.foo.bar.
333 | 
334 |     Args:
335 |       node: A Node of type Attribute.
336 | 
337 |     Returns:
338 |       a '.'-delimited full name, or None if the tree was not a simple form,
339 |       i.e. `(foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
340 |     """
341 |     curr = node
342 |     items = []
343 |     while not isinstance(curr, ast.Name):
344 |       if not isinstance(curr, ast.Attribute):
345 |         return None
346 |       items.append(curr.attr)
347 |       curr = curr.value
348 |     items.append(curr.id)
349 |     return ".".join(reversed(items))
350 | 
351 |   def _find_true_position(self, node):
352 |     """Return correct line number and column offset for a given node.
353 | 
354 |     This is necessary mainly because ListComp's location reporting reports
355 |     the next token after the list comprehension's opening bracket.
356 | 
357 |     Args:
358 |       node: Node for which we wish to know the lineno and col_offset
359 |     """
360 |     import re
361 |     find_open = re.compile(r"^\s*(\[).*$")
362 |     find_string_chars = re.compile(r"['\"]")
363 | 
364 |     if isinstance(node, ast.ListComp):
365 |       # Strangely, ast.ListComp returns the col_offset of the first token
66 |       # after the '[' token, which appears to be a bug. Work around this by
367 |       # explicitly finding the real start of the list comprehension.
368 |       line = node.lineno
369 |       col = node.col_offset
370 |       # Loop over lines
371 |       while True:
372 |         # Reverse the preceding text so the regular expression can search
373 |         # backwards through any whitespace for the opening '['.
374 |         text = self._lines[line - 1]
375 |         reversed_preceding_text = text[:col][::-1]
376 |         # First check whether a '[' can be found with only whitespace
377 |         # between it and col.
378 |         m = find_open.match(reversed_preceding_text)
379 |         if m:
380 |           new_col_offset = col - m.start(1) - 1
381 |           return line, new_col_offset
382 |         else:
383 |           if (reversed_preceding_text == "" or
384 |               reversed_preceding_text.isspace()):
385 |             line = line - 1
386 |             prev_line = self._lines[line - 1]
387 |             # TODO(aselle):
388 |             # This is poor comment detection, but it is good enough for
389 |             # cases where the comment does not contain string literal starting/
390 |             # ending characters. If ast gave us start and end locations of the
391 |             # ast nodes rather than just start, we could use string literal
392 |             # node ranges to filter out spurious #'s that appear in string
393 |             # literals.
394 |             comment_start = prev_line.find("#")
395 |             if comment_start == -1:
396 |               col = len(prev_line) - 1
397 |             elif find_string_chars.search(prev_line[comment_start:]) is None:
398 |               col = comment_start
399 |             else:
400 |               return None, None
401 |           else:
402 |             return None, None
403 |     # Most other nodes return proper locations (the `with` statement notably
404 |     # does not, but it cannot appear inside an argument anyway).
405 |     return node.lineno, node.col_offset
406 | 
407 | 
408 |   def visit_Call(self, node):  # pylint: disable=invalid-name
409 |     """Handle visiting a call node in the AST.
410 | 
411 |     Args:
412 |       node: Current Node
413 |     """
414 | 
415 |     # Find a simple attribute name path e.g. "tf.foo.bar"
416 |     full_name = self._get_attribute_full_path(node.func)
417 | 
418 |     # Make sure the func is marked as being part of a call
419 |     node.func.is_function_for_call = True
420 | 
421 |     if full_name and full_name.startswith("tf."):
422 |       # Call special handlers
423 |       function_handles = self._api_change_spec.function_handle
424 |       if full_name in function_handles:
425 |         function_handles[full_name](self._file_edit, node)
426 | 
427 |       # Examine any non-keyword argument and make it into a keyword argument
428 |       # if reordering is required.
429 |       function_reorders = self._api_change_spec.function_reorders
430 |       function_keyword_renames = (
431 |           self._api_change_spec.function_keyword_renames)
432 | 
433 |       if full_name in function_reorders:
434 |         reordered = function_reorders[full_name]
435 |         for idx, arg in enumerate(node.args):
436 |           lineno, col_offset = self._find_true_position(arg)
437 |           if lineno is None or col_offset is None:
438 |             self._file_edit.add(
439 |                 "Failed to add keyword %r to reordered function %r"
440 |                 % (reordered[idx], full_name), arg.lineno, arg.col_offset,
441 |                 "", "",
442 |                 error="A necessary keyword argument failed to be inserted.")
443 |           else:
444 |             keyword_arg = reordered[idx]
445 |             if (full_name in function_keyword_renames and
446 |                 keyword_arg in function_keyword_renames[full_name]):
447 |               keyword_arg = function_keyword_renames[full_name][keyword_arg]
448 |             self._file_edit.add("Added keyword %r to reordered function %r"
449 |                                 % (reordered[idx], full_name), lineno,
450 |                                 col_offset, "", keyword_arg + "=")
451 | 
452 |       # Examine each keyword argument and convert it to the final renamed form
453 |       renamed_keywords = ({} if full_name not in function_keyword_renames else
454 |                           function_keyword_renames[full_name])
455 |       for keyword in node.keywords:
456 |         argkey = keyword.arg
457 |         argval = keyword.value
458 | 
459 |         if argkey in renamed_keywords:
460 |           argval_lineno, argval_col_offset = self._find_true_position(argval)
461 |           if (argval_lineno is not None and argval_col_offset is not None):
462 |             # TODO(aselle): We should scan backward to find the start of the
463 |             # keyword key. Unfortunately ast does not give you the location of
464 |             # keyword keys, so we are forced to infer it from the keyword arg
465 |             # value.
466 |             key_start = argval_col_offset - len(argkey) - 1
467 |             key_end = key_start + len(argkey) + 1
468 |             if self._lines[argval_lineno - 1][key_start:key_end] == argkey + "=":
469 |               self._file_edit.add("Renamed keyword argument from %r to %r" %
470 |                                   (argkey, renamed_keywords[argkey]),
471 |                                   argval_lineno,
472 |                                   argval_col_offset - len(argkey) - 1,
473 |                                   argkey + "=", renamed_keywords[argkey] + "=")
474 |               continue
475 |           self._file_edit.add(
476 |               "Failed to rename keyword argument from %r to %r" %
477 |               (argkey, renamed_keywords[argkey]),
478 |               argval.lineno,
479 |               argval.col_offset - len(argkey) - 1,
480 |               "", "",
481 |               error="Failed to find keyword lexicographically. Fix manually.")
482 | 
483 |     ast.NodeVisitor.generic_visit(self, node)
484 | 
485 |   def visit_Attribute(self, node):  # pylint: disable=invalid-name
486 |     """Handle bare Attributes i.e. [tf.foo, tf.bar].
487 | 
488 |     Args:
489 |       node: Node that is of type ast.Attribute
490 |     """
491 |     full_name = self._get_attribute_full_path(node)
492 |     if full_name and full_name.startswith("tf."):
493 |       self._rename_functions(node, full_name)
494 |       if full_name in self._api_change_spec.change_to_function:
495 |         if not hasattr(node, "is_function_for_call"):
496 |           new_text = full_name + "()"
497 |           self._file_edit.add("Changed %r to %r" % (full_name, new_text),
498 |                               node.lineno, node.col_offset, full_name, new_text)
499 | 
500 |     ast.NodeVisitor.generic_visit(self, node)
501 | 
502 | 
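
`TensorFlowCallVisitor` ties the pieces together: as `ast` walks the parsed tree, each `Call` and `Attribute` node is matched against the `APIChangeSpec` tables and a corresponding edit is queued on the internal `FileEditRecorder`. Driving it by hand looks roughly like this (editor's sketch mirroring what `process_opened_file` does below; assumes the module is importable as `tf_upgrade`):

```python
import ast
from tf_upgrade import TensorFlowCallVisitor

source_lines = ["y = tf.mul(a, b)\n"]
tree = ast.parse("".join(source_lines))
visitor = TensorFlowCallVisitor("example.py", source_lines)
visitor.visit(tree)                                  # queues the edits
new_source, report, errors = visitor.process(source_lines)
print(new_source)  # y = tf.multiply(a, b)
```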
503 | class TensorFlowCodeUpgrader(object):
504 |   """Class that handles upgrading a set of Python files to TensorFlow 1.0."""
505 | 
506 |   def __init__(self):
507 |     pass
508 | 
509 |   def process_file(self, in_filename, out_filename):
510 |     """Process the given python file for incompatible changes.
511 | 
512 |     Args:
513 |       in_filename: filename to parse
514 |       out_filename: output file to write to
515 |     Returns:
516 |       A tuple representing number of files processed, log of actions, errors
517 |     """
518 | 
519 |     # Write to a temporary file, just in case we are doing an in-place modify.
520 |     with open(in_filename, "r") as in_file, \
521 |         tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
522 |       ret = self.process_opened_file(
523 |           in_filename, in_file, out_filename, temp_file)
524 | 
525 |     shutil.move(temp_file.name, out_filename)
526 |     return ret
527 | 
528 |   # Broad exceptions are required here because ast throws whatever it wants.
529 |   # pylint: disable=broad-except
530 |   def process_opened_file(self, in_filename, in_file, out_filename, out_file):
531 |     """Process the given python file for incompatible changes.
532 | 
533 |     This function is split out to facilitate StringIO testing from
534 |     tf_upgrade_test.py.
535 | 
536 |     Args:
537 |       in_filename: filename to parse
538 |       in_file: opened file (or StringIO)
539 |       out_filename: output file to write to
540 |       out_file: opened file (or StringIO)
541 |     Returns:
542 |       A tuple representing number of files processed, log of actions, errors
543 |     """
544 |     process_errors = []
545 |     text = "-" * 80 + "\n"
546 |     text += "Processing file %r\n outputting to %r\n" % (in_filename,
547 |                                                          out_filename)
548 |     text += "-" * 80 + "\n\n"
549 | 
550 |     parsed_ast = None
551 |     lines = in_file.readlines()
552 |     try:
553 |       parsed_ast = ast.parse("".join(lines))
554 |     except Exception:
555 |       text += "Failed to parse %r\n\n" % in_filename
556 |       text += traceback.format_exc()
557 |     if parsed_ast:
558 |       visitor = TensorFlowCallVisitor(in_filename, lines)
559 |       visitor.visit(parsed_ast)
560 |       out_text, new_text, process_errors = visitor.process(lines)
561 |       text += new_text
562 |       if out_file:
563 |         out_file.write(out_text)
564 |     text += "\n"
565 |     return 1, text, process_errors
566 |   # pylint: enable=broad-except
567 | 
568 |   def process_tree(self, root_directory, output_root_directory):
569 |     """Processes upgrades on an entire tree of python files.
570 | 
571 |     Note that only Python files are processed; if you have custom code in
572 |     other languages, you will need to upgrade it manually.
573 | 
574 |     Args:
575 |       root_directory: Directory to walk and process.
576 |       output_root_directory: Directory to use as the base of the output tree.
577 |     Returns:
578 |       A tuple of files processed, the report string for all files, and errors.
579 |     """
580 | 
581 |     # Make sure the output directory doesn't already exist.
582 |     if output_root_directory and os.path.exists(output_root_directory):
583 |       print("Output directory %r must not already exist." % (
584 |           output_root_directory))
585 |       sys.exit(1)
586 | 
587 |     # Make sure the output directory does not overlap with root_directory.
588 |     norm_root = os.path.split(os.path.normpath(root_directory))
589 |     norm_output = os.path.split(os.path.normpath(output_root_directory))
590 |     if norm_root == norm_output:
591 |       print("Output directory %r same as input directory %r" % (
592 |           root_directory, output_root_directory))
593 |       sys.exit(1)
594 | 
595 |     # Collect list of files to process (we do this to correctly handle if the
596 |     # user puts the output directory in some sub directory of the input dir)
597 |     files_to_process = []
598 |     for dir_name, _, file_list in os.walk(root_directory):
599 |       py_files = [f for f in file_list if f.endswith(".py")]
600 |       for filename in py_files:
601 |         fullpath = os.path.join(dir_name, filename)
602 |         fullpath_output = os.path.join(
603 |             output_root_directory, os.path.relpath(fullpath, root_directory))
604 |         files_to_process.append((fullpath, fullpath_output))
605 | 
606 |     file_count = 0
607 |     tree_errors = []
608 |     report = ""
609 |     report += ("=" * 80) + "\n"
610 |     report += "Input tree: %r\n" % root_directory
611 |     report += ("=" * 80) + "\n"
612 | 
613 |     for input_path, output_path in files_to_process:
614 |       output_directory = os.path.dirname(output_path)
615 |       if not os.path.isdir(output_directory):
616 |         os.makedirs(output_directory)
617 |       file_count += 1
618 |       _, l_report, l_errors = self.process_file(input_path, output_path)
619 |       tree_errors += l_errors
620 |       report += l_report
621 |     return file_count, report, tree_errors
622 | 
623 | 
624 | if __name__ == "__main__":
625 |   parser = argparse.ArgumentParser(
626 |       formatter_class=argparse.RawDescriptionHelpFormatter,
627 |       description="""Convert a TensorFlow Python file to 1.0
628 | 
629 | Simple usage:
630 |   tf_convert.py --infile foo.py --outfile bar.py
631 |   tf_convert.py --intree ~/code/old --outtree ~/code/new
632 | """)
633 |   parser.add_argument(
634 |       "--infile",
635 |       dest="input_file",
636 |       help="If converting a single file, the name of the file "
637 |       "to convert")
638 |   parser.add_argument(
639 |       "--outfile",
640 |       dest="output_file",
641 |       help="If converting a single file, the output filename.")
642 |   parser.add_argument(
643 |       "--intree",
644 |       dest="input_tree",
645 |       help="If converting a whole tree of files, the directory "
646 |       "to read from (relative or absolute).")
647 |   parser.add_argument(
648 |       "--outtree",
649 |       dest="output_tree",
650 |       help="If converting a whole tree of files, the output "
651 |       "directory (relative or absolute).")
652 |   parser.add_argument(
653 |       "--reportfile",
654 |       dest="report_filename",
655 |       help=("The name of the file where the report log is "
656 |             "stored. "
657 |             "(default: %(default)s)"),
658 |       default="report.txt")
659 |   args = parser.parse_args()
660 | 
661 |   upgrade = TensorFlowCodeUpgrader()
662 |   report_text = None
663 |   report_filename = args.report_filename
664 |   files_processed, errors = 0, []
665 |   if args.input_file:
666 |     files_processed, report_text, errors = upgrade.process_file(
667 |         args.input_file, args.output_file)
668 |     files_processed = 1
669 |   elif args.input_tree:
670 |     files_processed, report_text, errors = upgrade.process_tree(
671 |         args.input_tree, args.output_tree)
672 |   else:
673 |     parser.print_help()
674 |   if report_text:
675 |     with open(report_filename, "w") as report_file:
676 |       report_file.write(report_text)
677 |   print("TensorFlow 1.0 Upgrade Script")
678 |   print("-----------------------------")
679 |   print("Converted %d files\n" % files_processed)
680 |   print("Detected %d errors that require attention" % len(errors))
681 |   print("-" * 80)
682 |   print("\n".join(errors))
683 |   print("\nMake sure to read the detailed log %r\n" % report_filename)
--------------------------------------------------------------------------------
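
Besides the command-line entry point above, the upgrader can be driven programmatically; `process_file` returns the same (count, report, errors) triple that the CLI prints. A hypothetical one-file invocation (editor's sketch; the filenames are placeholders, not files from this repository):

```python
from tf_upgrade import TensorFlowCodeUpgrader

upgrade = TensorFlowCodeUpgrader()
count, report, errors = upgrade.process_file("old_model.py", "new_model.py")
print(report)        # per-line Old/New diff with ~ underlines
for err in errors:   # e.g. tf.reverse call sites that need a manual check
    print(err)
```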