├── .gitignore ├── README.md ├── arc.py ├── carc.py ├── data │   ├── setup_lfw.py │   └── setup_omniglot.py ├── data_workers.py ├── full_context.py ├── glimpse_ablation_arc_omniglot.py ├── image_augmenter.py ├── layers.py ├── main.py ├── one_shot_tests.py ├── vis_attn_arc_omniglot.py └── wrn.py /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | *.pyc 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ARC 2 | Code repository for reproducing the results in the paper - [Attentive Recurrent Comparators](https://arxiv.org/abs/1703.00767) 3 | -------------------------------------------------------------------------------- /arc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import theano 4 | import theano.tensor as T 5 | 6 | import lasagne 7 | from lasagne.layers import InputLayer, DenseLayer, DropoutLayer 8 | from lasagne.nonlinearities import sigmoid 9 | from lasagne.layers import get_all_params, get_output 10 | from lasagne.objectives import binary_crossentropy, binary_accuracy 11 | from lasagne.updates import adam 12 | from lasagne.layers import helper 13 | 14 | from layers import SimpleARC 15 | from data_workers import OmniglotOS, LFWVerif 16 | from main import train, test, serialize, deserialize 17 | 18 | import argparse 19 | 20 | 21 | parser = argparse.ArgumentParser(description="CLI for specifying hyper-parameters") 22 | parser.add_argument("-n", "--expt-name", type=str, default="", help="experiment name (for logging purposes)") 23 | parser.add_argument("--dataset", type=str, default="omniglot", help="omniglot/LFW") 24 | 25 | meta_data = vars(parser.parse_args()) 26 | meta_data["expt_name"] = "ARC_" + meta_data["dataset"] + meta_data["expt_name"] 27 | 28 | for md in meta_data.keys(): 29 | print md, meta_data[md] 30 | 31 | expt_name = meta_data["expt_name"] 32 | learning_rate = 1e-4 33 | image_size = 64 # 32 34 | attn_win = 6 # 4 35 | glimpses = 4 # 8 36 | lstm_states = 512 37 | fg_bias_init = 0.0 # 0.2 38 | dropout = 0.3 # 0.2 39 | meta_data["n_iter"] = n_iter = 1500000 40 | batch_size = 128 41 | meta_data["num_output"] = 2 42 | 43 | print "... setting up the network" 44 | X = T.tensor4("input") 45 | y = T.imatrix("target") 46 | 47 | l_in = InputLayer(shape=(None, 1, image_size, image_size), input_var=X) 48 | l_noise = DropoutLayer(l_in, p=dropout) 49 | l_arc = SimpleARC(l_noise, lstm_states=lstm_states, image_size=image_size, attn_win=attn_win, 50 | glimpses=glimpses, fg_bias_init=fg_bias_init) 51 | l_y = DenseLayer(l_arc, 1, nonlinearity=sigmoid) 52 | 53 | prediction = get_output(l_y) 54 | prediction_clean = get_output(l_y, deterministic=True) 55 | embedding = get_output(l_arc, deterministic=True) 56 | 57 | loss = T.mean(binary_crossentropy(prediction, y)) 58 | accuracy = T.mean(binary_accuracy(prediction_clean, y)) 59 | 60 | params = get_all_params(l_y) 61 | updates = adam(loss, params, learning_rate=learning_rate) 62 | 63 | meta_data["num_param"] = lasagne.layers.count_params(l_y) 64 | print "number of parameters: ", meta_data["num_param"] 65 | 66 | print "... compiling" 67 | train_fn = theano.function([X, y], outputs=loss, updates=updates) 68 | val_fn = theano.function([X, y], outputs=[loss, accuracy]) 69 | embed_fn = theano.function([X], outputs=embedding) 70 | op_fn = theano.function([X], outputs=prediction_clean) 71 | 72 | print "... 
loading dataset" 73 | if meta_data["dataset"] == 'omniglot': 74 | worker = OmniglotOS(image_size=image_size, batch_size=batch_size) 75 | elif meta_data["dataset"] == 'lfw': 76 | worker = LFWVerif(image_size=image_size, batch_size=batch_size) 77 | 78 | meta_data, params = train(train_fn, val_fn, worker, meta_data, \ 79 | get_params=lambda: helper.get_all_param_values(l_y)) 80 | 81 | print "... testing" 82 | helper.set_all_param_values(l_y, params) 83 | meta_data = test(val_fn, worker, meta_data) 84 | 85 | serialize(params, expt_name + '.params') 86 | serialize(meta_data, expt_name + '.mtd') 87 | serialize(embed_fn, expt_name + '.emf') 88 | serialize(op_fn, expt_name + '.opf') 89 | -------------------------------------------------------------------------------- /carc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import theano 4 | import theano.tensor as T 5 | 6 | import lasagne 7 | from lasagne.layers import InputLayer, DenseLayer, DropoutLayer 8 | from lasagne.layers import batch_norm, BatchNormLayer, ExpressionLayer 9 | from lasagne.layers import Conv2DLayer as ConvLayer 10 | from lasagne.layers import ElemwiseSumLayer, NonlinearityLayer, GlobalPoolLayer 11 | from lasagne.nonlinearities import rectify, sigmoid 12 | from lasagne.init import HeNormal 13 | from lasagne.layers import get_all_params, get_all_layers, get_output 14 | from lasagne.regularization import regularize_layer_params 15 | from lasagne.objectives import binary_crossentropy, binary_accuracy 16 | from lasagne.updates import adam 17 | from lasagne.layers import helper 18 | 19 | from layers import ConvARC 20 | from data_workers import OmniglotOS, LFWVerif 21 | from main import train, test, serialize, deserialize 22 | 23 | import sys 24 | sys.setrecursionlimit(10000) 25 | 26 | import argparse 27 | 28 | 29 | def residual_block(l, increase_dim=False, projection=True, first=False, filters=16): 30 | if increase_dim: 31 | first_stride = (2, 2) 32 | else: 33 | first_stride = (1, 1) 34 | 35 | if first: 36 | bn_pre_relu = l 37 | else: 38 | bn_pre_conv = BatchNormLayer(l) 39 | bn_pre_relu = NonlinearityLayer(bn_pre_conv, rectify) 40 | 41 | conv_1 = batch_norm(ConvLayer(bn_pre_relu, num_filters=filters, filter_size=(3,3), stride=first_stride, nonlinearity=rectify, pad='same', W=HeNormal(gain='relu'))) 42 | dropout = DropoutLayer(conv_1, p=0.3) 43 | conv_2 = ConvLayer(dropout, num_filters=filters, filter_size=(3,3), stride=(1,1), nonlinearity=None, pad='same', W=HeNormal(gain='relu')) 44 | 45 | if increase_dim: 46 | projection = ConvLayer(l, num_filters=filters, filter_size=(1,1), stride=(2,2), nonlinearity=None, pad='same', b=None) 47 | block = ElemwiseSumLayer([conv_2, projection]) 48 | elif first: 49 | projection = ConvLayer(l, num_filters=filters, filter_size=(1,1), stride=(1,1), nonlinearity=None, pad='same', b=None) 50 | block = ElemwiseSumLayer([conv_2, projection]) 51 | else: 52 | block = ElemwiseSumLayer([conv_2, l]) 53 | 54 | return block 55 | 56 | 57 | parser = argparse.ArgumentParser(description="CLI for specifying hyper-parameters") 58 | parser.add_argument("-n", "--expt-name", type=str, default="", help="experiment name(for logging purposes)") 59 | 60 | parser.add_argument("--dataset", type=str, default="omniglot", help="omniglot/LFW") 61 | 62 | parser.add_argument("--wrn-depth", type=int, default=4, help="the resnet has depth equal to 4d+7") 63 | parser.add_argument("--wrn-width", type=int, default=2, help="width multiplier for each WRN block") 64 | 65 | 
meta_data = vars(parser.parse_args()) 66 | meta_data["expt_name"] = "ConvARC_" + meta_data["dataset"] + meta_data["expt_name"] 67 | 68 | for md in meta_data.keys(): 69 | print md, meta_data[md] 70 | 71 | expt_name = meta_data["expt_name"] 72 | learning_rate = 1e-4 73 | image_size = 32 74 | attn_win = 4 75 | glimpses = 8 76 | lstm_states = 256 77 | fg_bias_init = 0.2 78 | batch_size = 128 79 | meta_data["n_iter"] = n_iter = 100000 80 | wrn_n = meta_data["wrn_depth"] 81 | wrn_k = meta_data["wrn_width"] 82 | 83 | meta_data["num_output"] = 2 84 | 85 | print "... setting up the network" 86 | n_filters = {0: 16, 1: 16 * wrn_k, 2: 32 * wrn_k} 87 | 88 | X = T.tensor4("input") 89 | y = T.imatrix("target") 90 | 91 | l_in = InputLayer(shape=(None, 1, image_size, image_size), input_var=X) 92 | 93 | # first layer, output is 16 x 32 x 32 | (1) 94 | l = batch_norm(ConvLayer(l_in, num_filters=n_filters[0], filter_size=(3, 3), \ 95 | stride=(1, 1), nonlinearity=rectify, pad='same', W=HeNormal(gain='relu'))) 96 | 97 | # first stack of residual blocks, output is (16 * wrn_k) x 32 x 32 | (3 + 2 * (n - 1)) 98 | l = residual_block(l, first=True, filters=n_filters[1]) 99 | for _ in range(1, wrn_n): 100 | l = residual_block(l, filters=n_filters[1]) 101 | 102 | # second stack of residual blocks, output is (32 * wrn_k) x 16 x 16 | (3 + 2 * (n + 1)) 103 | l = residual_block(l, increase_dim=True, filters=n_filters[2]) 104 | for _ in range(1, (wrn_n+2)): 105 | l = residual_block(l, filters=n_filters[2]) 106 | 107 | bn_post_conv = BatchNormLayer(l) 108 | bn_post_relu = NonlinearityLayer(bn_post_conv, rectify) 109 | 110 | l_carc = ConvARC(bn_post_relu, num_filters=n_filters[2], lstm_states=lstm_states, image_size=16, 111 | attn_win=attn_win, glimpses=glimpses, fg_bias_init=fg_bias_init) 112 | l_y = DenseLayer(l_carc, num_units=1, nonlinearity=sigmoid) 113 | 114 | prediction = get_output(l_y) 115 | prediction_clean = get_output(l_y, deterministic=True) 116 | embedding = get_output(l_carc, deterministic=True) 117 | 118 | loss = T.mean(binary_crossentropy(prediction, y)) 119 | accuracy = T.mean(binary_accuracy(prediction_clean, y)) 120 | 121 | all_layers = get_all_layers(l_y) 122 | l2_penalty = 0.0001 * regularize_layer_params(all_layers, lasagne.regularization.l2) 123 | loss = loss + l2_penalty 124 | 125 | params = get_all_params(l_y, trainable=True) 126 | updates = adam(loss, params, learning_rate=learning_rate) 127 | 128 | meta_data["num_param"] = lasagne.layers.count_params(l_y) 129 | print "number of parameters: ", meta_data["num_param"] 130 | 131 | print "... compiling" 132 | train_fn = theano.function(inputs=[X, y], outputs=loss, updates=updates) 133 | val_fn = theano.function(inputs=[X, y], outputs=[loss, accuracy]) 134 | embed_fn = theano.function([X], outputs=embedding) 135 | op_fn = theano.function([X], outputs=prediction_clean) 136 | 137 | print "... loading dataset" 138 | if meta_data["dataset"] == 'omniglot': 139 | worker = OmniglotOS(image_size=image_size, batch_size=batch_size) 140 | elif meta_data["dataset"] == 'lfw': 141 | worker = LFWVerif(image_size=image_size, batch_size=batch_size) 142 | 143 | meta_data, params = train(train_fn, val_fn, worker, meta_data, \ 144 | get_params=lambda: helper.get_all_param_values(l_y)) 145 | 146 | print "... 
testing" 147 | helper.set_all_param_values(l_y, params) 148 | meta_data = test(val_fn, worker, meta_data) 149 | 150 | serialize(params, expt_name + '.params') 151 | serialize(meta_data, expt_name + '.mtd') 152 | serialize(embed_fn, expt_name + '.emf') 153 | serialize(op_fn, expt_name + '.opf') 154 | -------------------------------------------------------------------------------- /data/setup_lfw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | 5 | import PIL 6 | from PIL import Image 7 | 8 | 9 | def image2pixelarray(filepath): 10 | F = PIL.Image.open(filepath).resize((64, 64), resample=PIL.Image.LANCZOS).convert('L') 11 | F = list(F.getdata()) 12 | F = np.array(F) 13 | F = F.reshape((64, 64)) 14 | return F 15 | 16 | 17 | os.system('wget http://vis-www.cs.umass.edu/lfw/lfw.tgz') 18 | os.system('tar -xvf lfw.tgz') 19 | 20 | path = 'lfw' 21 | 22 | FACES = [] 23 | COUNTS = [] 24 | 25 | people = os.listdir(path) 26 | for person in people: 27 | person_path = os.path.join(path, person) 28 | faces = os.listdir(person_path) 29 | for face in faces: 30 | face_path = os.path.join(person_path, face) 31 | FACES.append(image2pixelarray(face_path)) 32 | COUNTS.append(len(faces)) 33 | 34 | FACES = np.array(FACES) 35 | COUNTS = np.array(COUNTS) 36 | 37 | os.system('rm -rf lfw') 38 | os.system('rm lfw.tgz') 39 | 40 | os.mkdir('LFW') 41 | np.save('LFW/faces', FACES) 42 | np.save('LFW/counts', COUNTS) 43 | -------------------------------------------------------------------------------- /data/setup_omniglot.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from scipy.ndimage import imread 6 | 7 | 8 | os.system('wget https://github.com/brendenlake/omniglot/archive/master.zip') 9 | os.system('unzip -a master.zip') 10 | 11 | path_bg = os.path.join('omniglot-master', 'python', 'images_background.zip') 12 | path_ev = os.path.join('omniglot-master', 'python', 'images_evaluation.zip') 13 | path_os = os.path.join('omniglot-master', 'python', 'one-shot-classification', 'all_runs.zip') 14 | 15 | os.system('unzip -a ' + path_bg) 16 | os.system('unzip -a ' + path_ev) 17 | os.system('unzip -a ' + path_os) 18 | 19 | 20 | def omniglot_folder_to_NDarray(path_im): 21 | alphbts = os.listdir(path_im) 22 | ALL_IMGS = [] 23 | 24 | for alphbt in alphbts: 25 | chars = os.listdir(os.path.join(path_im, alphbt)) 26 | for char in chars: 27 | img_filenames = os.listdir(os.path.join(path_im, alphbt, char)) 28 | char_imgs = [] 29 | for img_fn in img_filenames: 30 | fn = os.path.join(path_im, alphbt, char, img_fn) 31 | I = imread(fn) 32 | I = np.invert(I) 33 | char_imgs.append(I) 34 | ALL_IMGS.append(char_imgs) 35 | 36 | return np.array(ALL_IMGS) 37 | 38 | 39 | all_bg = omniglot_folder_to_NDarray('images_background') 40 | all_ev = omniglot_folder_to_NDarray('images_evaluation') 41 | all_imgs = np.concatenate([all_bg, all_ev], axis=0) 42 | 43 | np.save('omniglot', all_imgs) 44 | 45 | TRIALS = [] 46 | LABELS = [] 47 | for run_num in range(1, 21): 48 | path = str(run_num) 49 | if len(path) == 1: 50 | path = '0' + path 51 | path = 'run' + path 52 | support_set = np.zeros((20, 105, 105), dtype='uint8') 53 | for img_num in range(1, 21): 54 | name = str(img_num) 55 | if len(name) == 1: 56 | name = '0' + name 57 | name = 'training/' + 'class' + name + '.png' 58 | filename = os.path.join(path, name) 59 | I = imread(filename) 60 | I = np.invert(I) 61 | support_set[img_num - 1] = I 62 | 63 | 
test_set = np.zeros((20, 105, 105), dtype='uint8') 64 | for img_num in range(1, 21): 65 | name = str(img_num) 66 | if len(name) == 1: 67 | name = '0' + name 68 | name = 'test/' + 'item' + name + '.png' 69 | filename = os.path.join(path, name) 70 | I = imread(filename) 71 | I = np.invert(I) 72 | test_set[img_num - 1] = I 73 | 74 | key_f = open(path + '/class_labels.txt', 'r') 75 | keys = key_f.readlines() 76 | key_f.close() 77 | matches = [int(key[-7:-5]) - 1 for key in keys] 78 | 79 | run = np.zeros((40 * 20, 105, 105), dtype='uint8') 80 | for i in range(20): 81 | k = i * 20 82 | run[k:k+20] = support_set 83 | k = 400 + i * 20 84 | run[k:k+20] = test_set[i] 85 | TRIALS.append(run) 86 | LABELS.append(matches) 87 | 88 | os.system('rm -rf ' + path) 89 | 90 | X_OS = np.array(TRIALS) 91 | y_OS = np.array(LABELS, dtype='int32') 92 | 93 | os.mkdir('one_shot') 94 | np.save('one_shot/X', X_OS) 95 | np.save('one_shot/y', y_OS) 96 | 97 | # clean up. just leave what is required. 98 | os.system('rm master.zip') 99 | os.system('rm -rf omniglot-master/') 100 | os.system('rm -rf images_background/') 101 | os.system('rm -rf images_evaluation/') 102 | 103 | print 'omniglot data preparation done, exiting ...' 104 | -------------------------------------------------------------------------------- /data_workers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.random import choice 3 | 4 | import theano 5 | 6 | from scipy.misc import imresize as resize 7 | 8 | from image_augmenter import ImageAugmenter 9 | 10 | from main import serialize, deserialize 11 | 12 | 13 | class Omniglot(object): 14 | def __init__(self, path='data/omniglot.npy', batch_size=128, image_size=32): 15 | """ 16 | path: path to omniglot.npy file produced by "data/setup_omniglot.py" script 17 | batch_size: the output is (2 * batch size, 1, image_size, image_size) 18 | X[i] & X[i + batch_size] are the pair 19 | image_size: size of the image 20 | data_split: in number of alphabets, e.g. 
[30, 10] means out of 50 Omniglot alphabets, 21 | 30 are for training, 10 for validation and the remaining (10) for testing 22 | within_alphabet: for the verification task, when 2 characters are sampled to form a pair, 23 | this flag specifies whether they should be from the same alphabet/language 24 | --------------------- 25 | Data Augmentation Parameters: 26 | flip: here, flipping applies to both the images in a pair 27 | scale: a value of x scales the image by + or - x% 28 | rotation_deg 29 | shear_deg 30 | translation_px: in both x and y directions 31 | """ 32 | chars = np.load(path) 33 | 34 | # resize the images 35 | resized_chars = np.zeros((1623, 20, image_size, image_size), dtype='uint8') 36 | for i in xrange(1623): 37 | for j in xrange(20): 38 | resized_chars[i, j] = resize(chars[i, j], (image_size, image_size)) 39 | chars = resized_chars 40 | 41 | self.mean_pixel = chars.mean() / 255.0 # used later for mean subtraction 42 | 43 | # starting index of each alphabet in a list of chars 44 | a_start = [0, 20, 49, 75, 116, 156, 180, 226, 240, 266, 300, 333, 355, 381, 45 | 424, 448, 496, 518, 534, 586, 633, 673, 699, 739, 780, 813, 46 | 827, 869, 892, 909, 964, 984, 1010, 1036, 1062, 1088, 1114, 47 | 1159, 1204, 1245, 1271, 1318, 1358, 1388, 1433, 1479, 1507, 48 | 1530, 1555, 1597] 49 | 50 | # size of each alphabet (num of chars) 51 | a_size = [20, 29, 26, 41, 40, 24, 46, 14, 26, 34, 33, 22, 26, 43, 24, 48, 22, 52 | 16, 52, 47, 40, 26, 40, 41, 33, 14, 42, 23, 17, 55, 20, 26, 26, 26, 53 | 26, 26, 45, 45, 41, 26, 47, 40, 30, 45, 46, 28, 23, 25, 42, 26] 54 | 55 | # each alphabet/language has a different number of characters. 56 | # in order to uniformly sample all characters, we need to weigh the probability 57 | # of sampling an alphabet by its size. p is that probability 58 | def size2p(size): 59 | s = np.array(size).astype('float64') 60 | return s / s.sum() 61 | 62 | self.size2p = size2p 63 | 64 | self.data = chars 65 | self.a_start = a_start 66 | self.a_size = a_size 67 | self.image_size = image_size 68 | self.batch_size = batch_size 69 | 70 | flip = True 71 | scale = 0.2 72 | rotation_deg = 20 73 | shear_deg = 10 74 | translation_px = 5 75 | self.augmentor = ImageAugmenter(image_size, image_size, 76 | hflip=flip, vflip=flip, 77 | scale_to_percent=1.0+scale, rotation_deg=rotation_deg, shear_deg=shear_deg, 78 | translation_x_px=translation_px, translation_y_px=translation_px) 79 | 80 | def fetch_batch(self, part): 81 | """ 82 | This outputs batch_size pairs 83 | Thus the actual number of images output is 2 * batch_size 84 | Say A & B form the two halves of a pair 85 | The batch is divided into 4 parts: 86 | Dissimilar A Dissimilar B 87 | Similar A Similar B 88 | 89 | Corresponding images in Similar A and Similar B form the similar pair 90 | similarly, Dissimilar A and Dissimilar B form the dissimilar pair 91 | 92 | When flattened, the batch has 4 parts with indices: 93 | Dissimilar A 0 - batch_size / 2 94 | Similar A batch_size / 2 - batch_size 95 | Dissimilar B batch_size - 3 * batch_size / 2 96 | Similar B 3 * batch_size / 2 - 2 * batch_size 97 | 98 | """ 99 | pass 100 | 101 | 102 | class OmniglotVerif(Omniglot): 103 | def __init__(self, path='data/omniglot.npy', batch_size=128, image_size=32): 104 | Omniglot.__init__(self, path, batch_size, image_size) 105 | 106 | a_start = self.a_start 107 | a_size = self.a_size 108 | 109 | # slicing indices for splitting a_start & a_size 110 | i = 30 111 | j = 40 112 | starts = {} 113 | starts['train'], starts['val'], starts['test'] = a_start[:i], a_start[i:j], a_start[j:] 114 | 
sizes = {} 115 | sizes['train'], sizes['val'], sizes['test'] = a_size[:i], a_size[i:j], a_size[j:] 116 | 117 | size2p = self.size2p 118 | 119 | p = {} 120 | p['train'], p['val'], p['test'] = size2p(sizes['train']), size2p(sizes['val']), size2p(sizes['test']) 121 | 122 | self.starts = starts 123 | self.sizes = sizes 124 | self.p = p 125 | 126 | def fetch_batch(self, part): 127 | data = self.data 128 | starts = self.starts[part] 129 | sizes = self.sizes[part] 130 | p = self.p[part] 131 | image_size = self.image_size 132 | batch_size = self.batch_size 133 | 134 | num_alphbts = len(starts) 135 | 136 | X = np.zeros((2 * batch_size, image_size, image_size), dtype='uint8') 137 | for i in xrange(batch_size / 2): 138 | # choose similar chars 139 | same_idx = choice(range(starts[0], starts[-1] + sizes[-1])) 140 | 141 | # choose dissimilar chars within alphabet 142 | alphbt_idx = choice(num_alphbts, p=p) 143 | char_offset = choice(sizes[alphbt_idx], 2, replace=False) 144 | diff_idx = starts[alphbt_idx] + char_offset 145 | 146 | X[i], X[i + batch_size] = data[diff_idx, choice(20, 2)] 147 | X[i + batch_size / 2], X[i + 3 * batch_size / 2] = data[same_idx, choice(20, 2, replace=False)] 148 | 149 | y = np.zeros((batch_size, 1), dtype='int32') 150 | y[:batch_size / 2] = 0 151 | y[batch_size / 2:] = 1 152 | 153 | if part == 'train': 154 | X = self.augmentor.augment_batch(X) 155 | else: 156 | X = X / 255.0 157 | 158 | X = X - self.mean_pixel 159 | X = X[:, np.newaxis] 160 | X = X.astype(theano.config.floatX) 161 | 162 | return X, y 163 | 164 | 165 | class OmniglotOS(Omniglot): 166 | def __init__(self, path='data/omniglot.npy', batch_size=128, image_size=32): 167 | Omniglot.__init__(self, path, batch_size, image_size) 168 | 169 | a_start = self.a_start 170 | a_size = self.a_size 171 | 172 | num_train_chars = a_start[29] + a_size[29] 173 | 174 | train = self.data[:num_train_chars, :16] # (964, 16, H, W) 175 | val = self.data[:num_train_chars, 16:] # (964, 4, H, W) 176 | test = self.data[num_train_chars:] # (659, 20, H, W) 177 | 178 | # slicing indices for splitting a_start & a_size 179 | i = 30 180 | starts = {} 181 | starts['train'], starts['val'], starts['test'] = a_start[:i], a_start[:i], a_start[i:] 182 | sizes = {} 183 | sizes['train'], sizes['val'], sizes['test'] = a_size[:i], a_size[:i], a_size[i:] 184 | 185 | size2p = self.size2p 186 | 187 | p = {} 188 | p['train'], p['val'], p['test'] = size2p(sizes['train']), size2p(sizes['val']), size2p(sizes['test']) 189 | 190 | data = {} 191 | data['train'], data['val'], data['test'] = train, val, test 192 | 193 | num_drawers = {} 194 | num_drawers['train'], num_drawers['val'], num_drawers['test'] = 16, 4, 20 195 | 196 | self.data = data 197 | self.starts = starts 198 | self.sizes = sizes 199 | self.p = p 200 | self.num_drawers = num_drawers 201 | 202 | def fetch_batch(self, part): 203 | data = self.data[part] 204 | starts = self.starts[part] 205 | sizes = self.sizes[part] 206 | p = self.p[part] 207 | num_drawers = self.num_drawers[part] 208 | 209 | image_size = self.image_size 210 | batch_size = self.batch_size 211 | 212 | num_alphbts = len(starts) 213 | 214 | X = np.zeros((2 * batch_size, image_size, image_size), dtype='uint8') 215 | for i in xrange(batch_size / 2): 216 | # choose similar chars 217 | same_idx = choice(range(data.shape[0])) 218 | 219 | # choose dissimilar chars within alphabet 220 | alphbt_idx = choice(num_alphbts, p=p) 221 | char_offset = choice(sizes[alphbt_idx], 2, replace=False) 222 | diff_idx = starts[alphbt_idx] + char_offset - starts[0] 
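# note: a_start holds offsets into the full 1623-character array, while data
# (= self.data[part]) has already been sliced for this split, so subtracting
# starts[0] re-bases the global character indices to row indices of the slice.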
223 | 224 | X[i], X[i + batch_size] = data[diff_idx, choice(num_drawers, 2)] 225 | X[i + batch_size / 2], X[i + 3 * batch_size / 2] = data[same_idx, choice(num_drawers, 2, replace=False)] 226 | 227 | y = np.zeros((batch_size, 1), dtype='int32') 228 | y[:batch_size / 2] = 0 229 | y[batch_size / 2:] = 1 230 | 231 | if part == 'train': 232 | X = self.augmentor.augment_batch(X) 233 | else: 234 | X = X / 255.0 235 | 236 | X = X - self.mean_pixel 237 | X = X[:, np.newaxis] 238 | X = X.astype(theano.config.floatX) 239 | 240 | return X, y 241 | 242 | 243 | class OmniglotOSLake(object): 244 | def __init__(self, image_size=32): 245 | X = np.load('data/one_shot/X.npy') 246 | y = np.load('data/one_shot/y.npy') 247 | 248 | # resize the images 249 | resized_X = np.zeros((20, 800, image_size, image_size), dtype='uint8') 250 | for i in xrange(20): 251 | for j in xrange(800): 252 | resized_X[i, j] = resize(X[i, j], (image_size, image_size)) 253 | X = resized_X 254 | 255 | self.mean_pixel = 0.0805 # dataset mean pixel 256 | 257 | self.X = X 258 | self.y = y 259 | 260 | def fetch_batch(self): 261 | X = self.X 262 | y = self.y 263 | 264 | X = X / 255.0 265 | X = X - self.mean_pixel 266 | X = X[:, :, np.newaxis] 267 | X = X.astype(theano.config.floatX) 268 | 269 | y = y.astype('int32') 270 | 271 | return X, y 272 | 273 | 274 | class OmniglotVinyals(Omniglot): 275 | def __init__(self, path='data/omniglot.npy', num_trials=128, image_size=32): 276 | Omniglot.__init__(self, path, 0, image_size) 277 | del self.batch_size 278 | self.num_trials = num_trials 279 | 280 | def fetch_batch(self): 281 | data = self.data 282 | image_size = self.image_size 283 | num_trials = self.num_trials 284 | 285 | X = np.zeros((num_trials * 40, image_size, image_size), dtype='uint8') 286 | y = np.zeros(num_trials, dtype='int32') 287 | 288 | for t in range(num_trials): 289 | trial = np.zeros((2 * 20, image_size, image_size), dtype='uint8') 290 | char_choices = range(1200, 1623) # set of all possible chars 291 | key_char_idx = choice(char_choices) # this will be the char to be matched 292 | 293 | # sample 19 other chars excluding key 294 | char_choices.remove(key_char_idx) 295 | other_char_idxs = choice(char_choices, 19, replace=False) 296 | 297 | pos = range(20) 298 | key_char_pos = choice(pos) # position of the key char out of 20 pairs 299 | pos.remove(key_char_pos) 300 | other_char_pos = np.array(pos, dtype='int32') 301 | 302 | drawers = choice(20, 2, replace=False) 303 | trial[key_char_pos] = data[key_char_idx, drawers[0]] 304 | trial[other_char_pos] = data[other_char_idxs, drawers[0]] 305 | trial[20:] = data[key_char_idx, drawers[1]] 306 | 307 | k = t * 20 308 | X[k:k+20] = trial[:20] 309 | k = k + num_trials * 20 310 | X[k:k+20] = trial[20:] 311 | 312 | y[t] = key_char_pos 313 | 314 | X = X / 255.0 315 | X = X - self.mean_pixel 316 | X = X[:, np.newaxis] 317 | X = X.astype(theano.config.floatX) 318 | 319 | return X, y 320 | 321 | 322 | class LFWVerif(object): 323 | def __init__(self, batch_size = 128, split=[80, 10], image_size=64): 324 | faces = np.load('data/LFW/faces.npy') 325 | counts = np.load('data/LFW/counts.npy') 326 | 327 | num_people = len(counts) 328 | num_train = int(np.round(split[0] / 100.0 * num_people)) 329 | num_val = int(np.round(split[1] / 100.0 * num_people)) 330 | num_test = num_people - num_train - num_val 331 | 332 | i = num_train 333 | j = i + num_val 334 | k = j + num_test 335 | 336 | num_faces_so_far = np.cumsum(counts) 337 | train_faces = faces[:num_faces_so_far[i]] 338 | val_faces = 
faces[num_faces_so_far[i]:num_faces_so_far[j]] 339 | test_faces = faces[num_faces_so_far[j]:] 340 | 341 | self.mean_pixel = faces.mean() / 255.0 342 | 343 | train_counts = counts[:i] 344 | val_counts = counts[i:j] 345 | test_counts = counts[j:] 346 | 347 | faces = {} 348 | faces['train'], faces['val'], faces['test'] = train_faces, val_faces, test_faces 349 | 350 | counts = {} 351 | counts['train'], counts['val'], counts['test'] = train_counts, val_counts, test_counts 352 | 353 | self.i = i 354 | self.j = j 355 | self.k = k 356 | 357 | self.faces = faces 358 | self.counts = counts 359 | self.batch_size = batch_size 360 | 361 | vflip = False 362 | hflip = True 363 | scale = 0.2 364 | rotation_deg = 15 365 | shear_deg = 5 366 | translation_px = 10 367 | self.augmentor = ImageAugmenter(image_size, image_size, 368 | hflip=hflip, vflip=vflip, 369 | scale_to_percent=1.0+scale, rotation_deg=rotation_deg, shear_deg=shear_deg, 370 | translation_x_px=translation_px, translation_y_px=translation_px) 371 | 372 | 373 | def fetch_batch(self, part): 374 | faces = self.faces[part] 375 | counts = self.counts[part] 376 | batch_size = self.batch_size 377 | 378 | num_people = counts.shape[0] 379 | person_start_idx = np.cumsum(counts) - counts 380 | 381 | X = np.zeros((batch_size * 2, 64, 64), dtype='uint8') 382 | 383 | while(1): 384 | try: 385 | person_idxs = choice(num_people, batch_size, replace=False) 386 | face_sub_idxs = np.array([choice(counts[idx]) for idx in person_idxs]) 387 | face_idxs = person_start_idx[person_idxs] 388 | net_index = face_idxs + face_sub_idxs 389 | X[:batch_size/2] = faces[net_index[:batch_size/2]] 390 | X[batch_size:3*batch_size/2] = faces[net_index[-batch_size/2:]] 391 | 392 | # sample similar 393 | similar_p = np.array(counts >= 2, dtype='float64') 394 | similar_p /= similar_p.sum() 395 | 396 | person_idxs = choice(num_people, batch_size/2, replace=False, p=similar_p) 397 | face_sub_idxs = np.array([choice(counts[idx], 2) for idx in person_idxs]) 398 | face_idxs = person_start_idx[person_idxs] 399 | faces_idxsA = face_idxs + face_sub_idxs[:, 0] 400 | faces_idxsB = face_idxs + face_sub_idxs[:, 1] 401 | X[batch_size/2:batch_size] = faces[faces_idxsA] 402 | X[-batch_size/2:] = faces[faces_idxsB] 403 | 404 | except IndexError: 405 | continue 406 | break 407 | 408 | y = np.zeros((batch_size, 1), dtype='int32') 409 | y[:batch_size / 2] = 0 410 | y[batch_size / 2:] = 1 411 | 412 | if part == 'train': 413 | X = self.augmentor.augment_batch(X) 414 | else: 415 | X = X / 255.0 416 | 417 | X = X - self.mean_pixel 418 | X = X[:, np.newaxis] 419 | X = X.astype(theano.config.floatX) 420 | 421 | return X, y 422 | -------------------------------------------------------------------------------- /full_context.py: -------------------------------------------------------------------------------- 1 | expt_name = 'ConvARC_OSFC' 2 | emf = 'ConvARC_OS.emf' 3 | embedding_size = 256 4 | 5 | num_trials = 32 6 | image_size = 32 7 | num_states = 128 8 | 9 | import numpy as np 10 | from numpy.random import choice 11 | 12 | import theano 13 | import theano.tensor as T 14 | 15 | import lasagne 16 | from lasagne.layers import InputLayer, DenseLayer, NonlinearityLayer 17 | from lasagne.layers import LSTMLayer 18 | from lasagne.layers import ConcatLayer, ReshapeLayer 19 | from lasagne.layers import Gate 20 | from lasagne.nonlinearities import tanh, elu, softmax 21 | from lasagne.init import HeNormal, Orthogonal, Constant 22 | from lasagne.layers import get_all_params, get_output 23 | from lasagne.objectives import 
categorical_crossentropy, categorical_accuracy 24 | from lasagne.updates import adam 25 | from lasagne.layers import helper 26 | 27 | from scipy.misc import imresize as resize 28 | from image_augmenter import ImageAugmenter 29 | from main import train, test, serialize, deserialize 30 | 31 | from data_workers import OmniglotOS 32 | 33 | 34 | def fetch_batch(self, part): 35 | data = self.data[part] 36 | starts = self.starts[part] 37 | sizes = self.sizes[part] 38 | p = self.p[part] 39 | num_drawers = self.num_drawers[part] 40 | image_size = self.image_size 41 | num_trials = self.num_trials 42 | num_alphbts = len(starts) 43 | 44 | X = np.zeros((num_trials * 40, image_size, image_size), dtype='uint8') 45 | y = np.zeros(num_trials, dtype='int32') 46 | 47 | for t in range(num_trials): 48 | trial = np.zeros((2 * 20, image_size, image_size), dtype='uint8') 49 | alphbt_idx = choice(num_alphbts) # choose an alphabet 50 | char_choices = range(sizes[alphbt_idx]) # set of all possible chars 51 | key_char_idx = choice(char_choices) # this will be the char to be matched 52 | 53 | # sample 19 other chars excluding key 54 | char_choices.pop(key_char_idx) 55 | other_char_idxs = choice(char_choices, 19) 56 | 57 | key_char_idx = starts[alphbt_idx] + key_char_idx - starts[0] 58 | other_char_idxs = starts[alphbt_idx] + other_char_idxs - starts[0] 59 | 60 | pos = range(20) 61 | key_char_pos = choice(pos) # position of the key char out of 20 pairs 62 | pos.pop(key_char_pos) 63 | other_char_pos = np.array(pos, dtype='int32') 64 | 65 | trial[key_char_pos] = data[key_char_idx, choice(num_drawers)] 66 | trial[other_char_pos] = data[other_char_idxs, choice(num_drawers)] 67 | trial[20:] = data[key_char_idx, choice(num_drawers)] 68 | 69 | k = t * 20 70 | X[k:k+20] = trial[:20] 71 | k = k + num_trials * 20 72 | X[k:k+20] = trial[20:] 73 | 74 | y[t] = key_char_pos 75 | 76 | if part == 'train': 77 | X = self.augmentor.augment_batch(X) 78 | else: 79 | X = X / 255.0 80 | 81 | X = X - self.mean_pixel 82 | X = X[:, np.newaxis] 83 | X = X.astype(theano.config.floatX) 84 | 85 | E = embedding_fn(X) 86 | E = E.reshape(num_trials, 20, embedding_size) 87 | return E, y 88 | 89 | OmniglotOS.fetch_batch = fetch_batch 90 | 91 | worker = OmniglotOS(image_size=image_size) 92 | del worker.batch_size 93 | worker.num_trials = num_trials 94 | 95 | embedding_fn = deserialize(emf) 96 | 97 | 98 | X = T.tensor3("input") 99 | y = T.ivector("target") 100 | 101 | l_in = InputLayer(shape=(None, 20, embedding_size), input_var=X) 102 | 103 | gate_parameters = Gate(W_in=Orthogonal(), W_hid=Orthogonal(), b=Constant(0.)) 104 | forget_gate_parameters = Gate(W_in=Orthogonal(), W_hid=Orthogonal(), b=Constant(1.)) 105 | cell_parameters = Gate(W_in=Orthogonal(), W_hid=Orthogonal(), b=Constant(0.), W_cell=None, nonlinearity=tanh) 106 | 107 | l_lstm_up = LSTMLayer(l_in, num_states, 108 | ingate=gate_parameters, forgetgate=forget_gate_parameters, 109 | cell=cell_parameters, outgate=gate_parameters, 110 | learn_init=True, grad_clipping=100.0) 111 | l_lstm_down = LSTMLayer(l_in, num_states, backwards=True, 112 | ingate=gate_parameters, forgetgate=forget_gate_parameters, 113 | cell=cell_parameters, outgate=gate_parameters, 114 | learn_init=True, grad_clipping=100.0) 115 | 116 | l_merge = ConcatLayer([l_lstm_up, l_lstm_down]) 117 | l_rshp1 = ReshapeLayer(l_merge, (-1, 2 * num_states)) 118 | l_dense = DenseLayer(l_rshp1, 1, W=HeNormal(gain='relu'), nonlinearity=elu) 119 | l_rshp2 = ReshapeLayer(l_dense, (-1, 20)) 120 | l_y = NonlinearityLayer(l_rshp2, softmax) 121 | 
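# The bidirectional LSTM above reads a trial's 20 pairwise (support, test)
# embeddings as a sequence; each timestep's forward and backward states are
# scored by a shared 1-unit dense layer, and the 20 scores are softmax-normalized,
# so the network chooses the matching support character with full context of
# the whole support set.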
122 | prediction = get_output(l_y) 123 | 124 | loss = T.mean(categorical_crossentropy(prediction, y)) 125 | accuracy = T.mean(categorical_accuracy(prediction, y)) 126 | 127 | params = get_all_params(l_y, trainable=True) 128 | updates = adam(loss, params, learning_rate=3e-4) 129 | 130 | print "... compiling" 131 | train_fn = theano.function(inputs=[X, y], outputs=loss, updates=updates) 132 | val_fn = theano.function(inputs=[X, y], outputs=[loss, accuracy]) 133 | op_fn = theano.function([X], outputs=prediction) 134 | 135 | meta_data = {} 136 | meta_data["n_iter"] = 50000 137 | meta_data["num_output"] = 20 138 | 139 | meta_data, params = train(train_fn, val_fn, worker, meta_data, get_params=lambda: helper.get_all_param_values(l_y)) 140 | meta_data = test(val_fn, worker, meta_data) 141 | 142 | serialize(params, expt_name + '.params') 143 | serialize(meta_data, expt_name + '.mtd') 144 | serialize(op_fn, expt_name + '.opf') 145 | -------------------------------------------------------------------------------- /glimpse_ablation_arc_omniglot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import theano 4 | import theano.tensor as T 5 | 6 | import lasagne 7 | from lasagne.layers import InputLayer, DenseLayer 8 | from lasagne.nonlinearities import sigmoid 9 | from lasagne.layers import get_all_params, get_output 10 | from lasagne.objectives import binary_crossentropy, binary_accuracy 11 | from lasagne.updates import adam 12 | from lasagne.layers import helper 13 | 14 | from layers import SimpleARC 15 | from data_workers import OmniglotOS 16 | from main import serialize, deserialize 17 | 18 | 19 | def create_embedder_fn(glimpses): 20 | X = T.tensor4("input") 21 | l_in = InputLayer(shape=(None, 1, 32, 32), input_var=X) 22 | l_arc = SimpleARC(l_in, lstm_states=512, image_size=32, attn_win=4, 23 | glimpses=glimpses, fg_bias_init=0.0) 24 | embedding = get_output(l_arc, deterministic=True) 25 | embedding_fn = theano.function([X], outputs=embedding) 26 | 27 | params = deserialize('ARC_OS' + '.params') 28 | helper.set_all_param_values(l_arc, params[:2]) 29 | 30 | return embedding_fn 31 | 32 | 33 | worker = OmniglotOS(image_size=32, batch_size=1024) 34 | 35 | X_test, y_test = worker.fetch_batch('test') 36 | 37 | for glimpses in range(1, 9): 38 | embedding_fn = create_embedder_fn(glimpses) 39 | 40 | X = T.matrix("embedding") 41 | y = T.imatrix("target") 42 | l_in = InputLayer(shape=(None, 512), input_var=X) 43 | l_y = DenseLayer(l_in, 1, nonlinearity=sigmoid) 44 | prediction = get_output(l_y) 45 | loss = T.mean(binary_crossentropy(prediction, y)) 46 | accuracy = T.mean(binary_accuracy(prediction, y)) 47 | params = get_all_params(l_y) 48 | updates = adam(loss, params, learning_rate=1e-3) 49 | train_fn = theano.function([X, y], outputs=loss, updates=updates) 50 | val_fn = theano.function([X, y], outputs=[loss, accuracy]) 51 | 52 | for i in range(250): 53 | X_train, y_train = worker.fetch_batch('train') 54 | train_fn(embedding_fn(X_train), y_train) 55 | 56 | X_train, y_train = worker.fetch_batch('train') 57 | train_loss = train_fn(embedding_fn(X_train), y_train) 58 | val_loss, val_acc = val_fn(embedding_fn(X_test), y_test) 59 | print "number of glimpses per image: ", glimpses 60 | print "\ttraining performance:" 61 | print "\t\t loss:", train_loss 62 | print "\ttesting performance:" 63 | print "\t\t loss:", val_loss 64 | print "\t\t accuracy:", val_acc 65 | 66 | params = helper.get_all_param_values(l_y) 67 | serialize(params, str(glimpses) + 
'glimpses' + '.params') 68 | -------------------------------------------------------------------------------- /image_augmenter.py: -------------------------------------------------------------------------------- 1 | """ 2 | taken and modified from: https://github.com/aleju/ImageAugmenter 3 | """ 4 | 5 | # -*- coding: utf-8 -*- 6 | """Wrapper functions and classes around scikit-image's AffineTransform. 7 | Simplifies augmentation of images in machine learning. 8 | 9 | Example usage: 10 | img_width = 32 # width of the images 11 | img_height = 32 # height of the images 12 | images = ... # e.g. load via scipy.misc.imread(filename) 13 | 14 | # For each image: randomly flip it horizontally (50% chance), 15 | # randomly rotate it between -20 and +20 degrees, randomly translate 16 | # it on the x-axis between -5 and +5 pixels. 17 | ia = ImageAugmenter(img_width, img_height, hflip=True, rotation_deg=20, 18 | translation_x_px=5) 19 | augmented_images = ia.augment_batch(images) 20 | """ 21 | from skimage import transform as tf 22 | import numpy as np 23 | import random 24 | 25 | def is_minmax_tuple(param): 26 | """Returns whether the parameter is a tuple containing two values. 27 | 28 | Used in create_aug_matrices() and probably useless everywhere else. 29 | 30 | Args: 31 | param: The parameter to check (whether it is a tuple of length 2). 32 | 33 | Returns: 34 | Boolean 35 | """ 36 | return type(param) is tuple and len(param) == 2 37 | 38 | def create_aug_matrices(nb_matrices, img_width_px, img_height_px, 39 | scale_to_percent=1.0, scale_axis_equally=False, 40 | rotation_deg=0, shear_deg=0, 41 | translation_x_px=0, translation_y_px=0, 42 | seed=None): 43 | """Creates the augmentation matrices that may later be used to transform 44 | images. 45 | 46 | This is a wrapper around scikit-image's transform.AffineTransform class. 47 | You can apply those matrices to images using the apply_aug_matrices() 48 | function. 49 | 50 | Args: 51 | nb_matrices: How many matrices to return, e.g. 100 returns 100 different 52 | random-generated matrices (= 100 different transformations). 53 | img_width_px: Width of the images that will be transformed later 54 | on (same as the width of each of the matrices). 55 | img_height_px: Height of the images that will be transformed later 56 | on (same as the height of each of the matrices). 57 | scale_to_percent: Same as in ImageAugmenter.__init__(). 58 | Up to which percentage the images may be 59 | scaled/zoomed. The negative scaling is automatically derived 60 | from this value. A value of 1.1 allows scaling by any value 61 | between -10% and +10%. You may set min and max values yourself 62 | by using a tuple instead, like (1.1, 1.2) to scale between 63 | +10% and +20%. Default is 1.0 (no scaling). 64 | scale_axis_equally: Same as in ImageAugmenter.__init__(). 65 | Whether to always scale both axes (x and y) 66 | in the same way. If set to False, then e.g. the Augmenter 67 | might scale the x-axis by 20% and the y-axis by -5%. 68 | Default is False. 69 | rotation_deg: Same as in ImageAugmenter.__init__(). 70 | By how much the image may be rotated around its 71 | center (in degrees). The negative rotation will automatically 72 | be derived from this value. E.g. a value of 20 allows any 73 | rotation between -20 degrees and +20 degrees. You may set min 74 | and max values yourself by using a tuple instead, e.g. (5, 20) 75 | to rotate between +5 and +20 degrees. Default is 0 (no 76 | rotation). 77 | shear_deg: Same as in ImageAugmenter.__init__(). 
78 | By how much the image may be sheared (in degrees). The 79 | negative value will automatically be derived from this value. 80 | E.g. a value of 20 allows any shear between -20 degrees and 81 | +20 degrees. You may set min and max values yourself by using a 82 | tuple instead, e.g. (5, 20) to shear between +5 and +20 83 | degrees. Default is 0 (no shear). 84 | translation_x_px: Same as in ImageAugmenter.__init__(). 85 | By up to how many pixels the image may be 86 | translated (moved) on the x-axis. The negative value will 87 | automatically be derived from this value. E.g. a value of +7 88 | allows any translation between -7 and +7 pixels on the x-axis. 89 | You may set min and max values yourself by using a tuple 90 | instead, e.g. (5, 20) to translate between +5 and +20 pixels. 91 | Default is 0 (no translation on the x-axis). 92 | translation_y_px: Same as in ImageAugmenter.__init__(). 93 | See translation_x_px, just for the y-axis. 94 | seed: Seed to use for python's and numpy's random functions. 95 | 96 | Returns: 97 | List of augmentation matrices. 98 | """ 99 | assert nb_matrices > 0 100 | assert img_width_px > 0 101 | assert img_height_px > 0 102 | assert is_minmax_tuple(scale_to_percent) or scale_to_percent >= 1.0 103 | assert is_minmax_tuple(rotation_deg) or rotation_deg >= 0 104 | assert is_minmax_tuple(shear_deg) or shear_deg >= 0 105 | assert is_minmax_tuple(translation_x_px) or translation_x_px >= 0 106 | assert is_minmax_tuple(translation_y_px) or translation_y_px >= 0 107 | 108 | if seed is not None: 109 | random.seed(seed) 110 | np.random.seed(seed) 111 | 112 | result = [] 113 | 114 | shift_x = int(img_width_px / 2.0) 115 | shift_y = int(img_height_px / 2.0) 116 | 117 | # prepare min and max values for 118 | # scaling/zooming (min/max values) 119 | if is_minmax_tuple(scale_to_percent): 120 | scale_x_min = scale_to_percent[0] 121 | scale_x_max = scale_to_percent[1] 122 | else: 123 | scale_x_min = scale_to_percent 124 | scale_x_max = 1.0 - (scale_to_percent - 1.0) 125 | assert scale_x_min > 0.0 126 | #if scale_x_max >= 2.0: 127 | # warnings.warn("Scaling by more than 100 percent (%.2f)." 
% (scale_x_max,)) 128 | scale_y_min = scale_x_min # scale_axis_equally affects the random value generation 129 | scale_y_max = scale_x_max 130 | 131 | # rotation (min/max values) 132 | if is_minmax_tuple(rotation_deg): 133 | rotation_deg_min = rotation_deg[0] 134 | rotation_deg_max = rotation_deg[1] 135 | else: 136 | rotation_deg_min = (-1) * int(rotation_deg) 137 | rotation_deg_max = int(rotation_deg) 138 | 139 | # shear (min/max values) 140 | if is_minmax_tuple(shear_deg): 141 | shear_deg_min = shear_deg[0] 142 | shear_deg_max = shear_deg[1] 143 | else: 144 | shear_deg_min = (-1) * int(shear_deg) 145 | shear_deg_max = int(shear_deg) 146 | 147 | # translation x-axis (min/max values) 148 | if is_minmax_tuple(translation_x_px): 149 | translation_x_px_min = translation_x_px[0] 150 | translation_x_px_max = translation_x_px[1] 151 | else: 152 | translation_x_px_min = (-1) * translation_x_px 153 | translation_x_px_max = translation_x_px 154 | 155 | # translation y-axis (min/max values) 156 | if is_minmax_tuple(translation_y_px): 157 | translation_y_px_min = translation_y_px[0] 158 | translation_y_px_max = translation_y_px[1] 159 | else: 160 | translation_y_px_min = (-1) * translation_y_px 161 | translation_y_px_max = translation_y_px 162 | 163 | # create nb_matrices randomized affine transformation matrices 164 | for _ in range(nb_matrices): 165 | # generate random values for scaling, rotation, shear, translation 166 | scale_x = random.uniform(scale_x_min, scale_x_max) 167 | scale_y = random.uniform(scale_y_min, scale_y_max) 168 | if not scale_axis_equally: 169 | scale_y = random.uniform(scale_y_min, scale_y_max) 170 | else: 171 | scale_y = scale_x 172 | rotation = np.deg2rad(random.randint(rotation_deg_min, rotation_deg_max)) 173 | shear = np.deg2rad(random.randint(shear_deg_min, shear_deg_max)) 174 | translation_x = random.randint(translation_x_px_min, translation_x_px_max) 175 | translation_y = random.randint(translation_y_px_min, translation_y_px_max) 176 | 177 | # create three affine transformation matrices 178 | # 1st one moves the image to the top left, 2nd one transforms it, 3rd one 179 | # moves it back to the center. 180 | # The movement is necessary, because rotation is applied to the top left 181 | # and not to the image's center (same for scaling and shear). 182 | matrix_to_topleft = tf.SimilarityTransform(translation=[-shift_x, -shift_y]) 183 | matrix_transforms = tf.AffineTransform(scale=(scale_x, scale_y), 184 | rotation=rotation, shear=shear, 185 | translation=(translation_x, 186 | translation_y)) 187 | matrix_to_center = tf.SimilarityTransform(translation=[shift_x, shift_y]) 188 | 189 | # Combine the three matrices into one affine transformation (one matrix) 190 | matrix = matrix_to_topleft + matrix_transforms + matrix_to_center 191 | 192 | # one matrix is ready, add it to the result 193 | result.append(matrix.inverse) 194 | 195 | return result 196 | 197 | def apply_aug_matrices(images, matrices, transform_channels_equally=True, 198 | channel_is_first_axis=False, random_order=True, 199 | mode="constant", cval=0.0, interpolation_order=1, 200 | seed=None): 201 | """Augment the given images using the given augmentation matrices. 202 | 203 | This function is a wrapper around scikit-image's transform.warp(). 204 | It is expected to be called by ImageAugmenter.augment_batch(). 205 | The matrices may be generated by create_aug_matrices(). 206 | 207 | Args: 208 | images: Same as in ImageAugmenter.augment_batch(). 209 | Numpy array (dtype: uint8, i.e. values 0-255) with the images. 
210 | Expected shape is either (image-index, height, width) for 211 | grayscale images or (image-index, channel, height, width) for 212 | images with channels (e.g. RGB) where the channel has the first 213 | index or (image-index, height, width, channel) for images with 214 | channels, where the channel is the last index. 215 | If your shape is (image-index, channel, width, height) then 216 | you must also set channel_is_first_axis=True in the constructor. 217 | matrices: A list of augmentation matrices as produced by 218 | create_aug_matrices(). 219 | transform_channels_equally: Same as in ImageAugmenter.__init__(). 220 | Whether to apply the exactly same 221 | transformations to each channel of an image (True). Setting 222 | it to False allows different transformations per channel, 223 | e.g. the red-channel might be rotated by +20 degrees, while 224 | the blue channel (of the same image) might be rotated 225 | by -5 degrees. If you don't have any channels (2D grayscale), 226 | you can simply ignore this setting. 227 | Default is True (transform all equally). 228 | channel_is_first_axis: Same as in ImageAugmenter.__init__(). 229 | Whether the channel (e.g. RGB) is the first 230 | axis of each image (True) or the last axis (False). 231 | False matches the scipy and PIL implementation and is the 232 | default. If your images are 2D-grayscale then you can ignore 233 | this setting (as the augmenter will ignore it too). 234 | random_order: Whether to apply the augmentation matrices in a random 235 | order (True, e.g. the 2nd matrix might be applied to the 236 | 5th image) or in the given order (False, e.g. the 2nd matrix might 237 | be applied to the 2nd image). 238 | Notice that for multi-channel images (e.g. RGB) this function 239 | will use a different matrix for each channel, unless 240 | transform_channels_equally is set to True. 241 | mode: Parameter used for the transform.warp-function of scikit-image. 242 | Can usually be ignored. 243 | cval: Parameter used for the transform.warp-function of scikit-image. 244 | Defines the fill color for "new" pixels, e.g. for empty areas 245 | after rotations. (0.0 is black, 1.0 is white.) 246 | interpolation_order: Parameter used for the transform.warp-function of 247 | scikit-image. Defines the order of all interpolations used to 248 | generate the new/augmented image. See their documentation for 249 | further details. 250 | seed: Seed to use for python's and numpy's random functions. 251 | """ 252 | # images must be numpy array 253 | assert type(images).__module__ == np.__name__, "Expected numpy array for " \ 254 | "parameter 'images'." 255 | 256 | # images must have uint8 as dtype (0-255) 257 | assert images.dtype.name == "uint8", "Expected numpy.uint8 as image dtype." 
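# Minimal usage sketch (hypothetical values and array name; shapes as per the
# docstring above):
#   mats = create_aug_matrices(1000, img_width_px=32, img_height_px=32,
#                              rotation_deg=20, translation_x_px=5)
#   augmented = apply_aug_matrices(batch_uint8, mats)  # float32 result in [0, 1]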
258 | 259 | # 3 axis total (2 per image) for grayscale, 260 | # 4 axis total (3 per image) for RGB (usually) 261 | assert len(images.shape) in [3, 4], """Expected 'images' parameter to have 262 | either shape (image index, y, x) for greyscale 263 | or (image index, channel, y, x) / (image index, y, x, channel) 264 | for multi-channel (usually color) images.""" 265 | 266 | if seed: 267 | np.random.seed(seed) 268 | 269 | nb_images = images.shape[0] 270 | 271 | # estimate number of channels, set to 1 if there is no axis channel, 272 | # otherwise it will usually be 3 273 | has_channels = False 274 | nb_channels = 1 275 | if len(images.shape) == 4: 276 | has_channels = True 277 | if channel_is_first_axis: 278 | nb_channels = images.shape[1] # first axis within each image 279 | else: 280 | nb_channels = images.shape[3] # last axis within each image 281 | 282 | # whether to apply the transformations directly to the whole image 283 | # array (True) or for each channel individually (False) 284 | apply_directly = not has_channels or (transform_channels_equally 285 | and not channel_is_first_axis) 286 | 287 | # We generate here the order in which the matrices may be applied. 288 | # At the end, order_indices will contain the index of the matrix to use 289 | # for each image, e.g. [15, 2] would mean that the 15th matrix will be 290 | # applied to the 0th image, the 2nd matrix to the 1st image. 291 | # If the images have multiple channels (e.g. RGB) and 292 | # transform_channels_equally has been set to False, we will need one 293 | # matrix per channel instead of per image. 294 | 295 | # 0 to nb_images, but restart at 0 if index is beyond number of matrices 296 | len_indices = nb_images if apply_directly else nb_images * nb_channels 297 | if random_order: 298 | # Notice: This way to choose random matrices is concise, but can create 299 | # problems if there is a low amount of images and matrices. 300 | # E.g. suppose that 2 images ought to be transformed by either 301 | # 0px translation on the x-axis or 1px translation. So 50% of all 302 | # matrices translate by 0px and 50% by 1px. The following method 303 | # will randomly choose a combination of the two matrices for the 304 | # two images (matrix 0 for image 0 and matrix 0 for image 1, 305 | # matrix 0 for image 0 and matrix 1 for image 1, ...). 306 | # In 50% of these cases, a different matrix will be chosen for image 0 307 | # and image 1 (matrices 0, 1 or matrices 1, 0). But 50% of these 308 | # "different" matrices (different index) will be the same, as 50% 309 | # translate by 1px and 50% by 0px. As a result, 75% of all augmentations 310 | # will transform both images in the same way. 311 | # The effect decreases if more matrices or images are chosen. 312 | order_indices = np.random.random_integers(0, len(matrices) - 1, len_indices) 313 | else: 314 | # monotonically growing indices (each by +1), but none of them may be 315 | # higher than or equal to the number of matrices 316 | order_indices = np.arange(0, len_indices) % len(matrices) 317 | 318 | result = np.zeros(images.shape, dtype=np.float32) 319 | matrix_number = 0 320 | 321 | # iterate over every image, find out which matrix to apply and then use 322 | # that matrix to augment the image 323 | for img_idx, image in enumerate(images): 324 | if apply_directly: 325 | # we can apply the matrix to the whole numpy array of the image 326 | # at the same time, so we do that to save time (instead of e.g. 
three 327 | # steps for three channels as in the else-part) 328 | matrix = matrices[order_indices[matrix_number]] 329 | result[img_idx, ...] = tf.warp(image, matrix, mode=mode, cval=cval, 330 | order=interpolation_order) 331 | matrix_number += 1 332 | else: 333 | # we can't apply the matrix to the whole image in one step, instead 334 | # we have to apply it to each channel individually. that happens 335 | # if the channel is the first axis of each image (incompatible with 336 | # tf.warp()) or if it was explicitly requested via 337 | # transform_channels_equally=False. 338 | for channel_idx in range(nb_channels): 339 | matrix = matrices[order_indices[matrix_number]] 340 | if channel_is_first_axis: 341 | warped = tf.warp(image[channel_idx], matrix, mode=mode, 342 | cval=cval, order=interpolation_order) 343 | result[img_idx, channel_idx, ...] = warped 344 | else: 345 | warped = tf.warp(image[..., channel_idx], matrix, mode=mode, 346 | cval=cval, order=interpolation_order) 347 | result[img_idx, ..., channel_idx] = warped 348 | 349 | if not transform_channels_equally: 350 | matrix_number += 1 351 | if transform_channels_equally: 352 | matrix_number += 1 353 | 354 | return result 355 | 356 | class ImageAugmenter(object): 357 | """Helper class to randomly augment images, usually for neural networks. 358 | 359 | Example usage: 360 | img_width = 32 # width of the images 361 | img_height = 32 # height of the images 362 | images = ... # e.g. load via scipy.misc.imread(filename) 363 | 364 | # For each image: randomly flip it horizontally (50% chance), 365 | # randomly rotate it between -20 and +20 degrees, randomly translate 366 | # it on the x-axis between -5 and +5 pixels. 367 | ia = ImageAugmenter(img_width, img_height, hflip=True, rotation_deg=20, 368 | translation_x_px=5) 369 | augmented_images = ia.augment_batch(images) 370 | """ 371 | def __init__(self, img_width_px, img_height_px, channel_is_first_axis=False, 372 | hflip=False, vflip=False, 373 | scale_to_percent=1.0, scale_axis_equally=False, 374 | rotation_deg=0, shear_deg=0, 375 | translation_x_px=0, translation_y_px=0, 376 | transform_channels_equally=True): 377 | """ 378 | Args: 379 | img_width_px: The intended width of each image in pixels. 380 | img_height_px: The intended height of each image in pixels. 381 | channel_is_first_axis: Whether the channel (e.g. RGB) is the first 382 | axis of each image (True) or the last axis (False). 383 | False matches the scipy and PIL implementation and is the 384 | default. If your images are 2D-grayscale then you can ignore 385 | this setting (as the augmenter will ignore it too). 386 | hflip: Whether to randomly flip images horizontally (on the y-axis). 387 | You may choose either False (no horizontal flipping), 388 | True (flip with probability 0.5) or use a float 389 | value (probability) between 0.0 and 1.0. Default is False. 390 | vflip: Whether to randomly flip images vertically (on the x-axis). 391 | You may choose either False (no vertical flipping), 392 | True (flip with probability 0.5) or use a float 393 | value (probability) between 0.0 and 1.0. Default is False. 394 | scale_to_percent: Up to which percentage the images may be 395 | scaled/zoomed. The negative scaling is automatically derived 396 | from this value. A value of 1.1 allows scaling by any value 397 | between -10% and +10%. You may set min and max values yourself 398 | by using a tuple instead, like (1.1, 1.2) to scale between 399 | +10% and +20%. Default is 1.0 (no scaling). 
400 | scale_axis_equally: Whether to always scale both axes (x and y) 401 | in the same way. If set to False, then e.g. the Augmenter 402 | might scale the x-axis by 20% and the y-axis by -5%. 403 | Default is False. 404 | rotation_deg: By how much the image may be rotated around its 405 | center (in degrees). The negative rotation will automatically 406 | be derived from this value. E.g. a value of 20 allows any 407 | rotation between -20 degrees and +20 degrees. You may set min 408 | and max values yourself by using a tuple instead, e.g. (5, 20) 409 | to rotate between +5 and +20 degrees. Default is 0 (no 410 | rotation). 411 | shear_deg: By how much the image may be sheared (in degrees). The 412 | negative value will automatically be derived from this value. 413 | E.g. a value of 20 allows any shear between -20 degrees and 414 | +20 degrees. You may set min and max values yourself by using a 415 | tuple instead, e.g. (5, 20) to shear between +5 and +20 416 | degrees. Default is 0 (no shear). 417 | translation_x_px: By up to how many pixels the image may be 418 | translated (moved) on the x-axis. The negative value will 419 | automatically be derived from this value. E.g. a value of +7 420 | allows any translation between -7 and +7 pixels on the x-axis. 421 | You may set min and max values yourself by using a tuple 422 | instead, e.g. (5, 20) to translate between +5 and +20 pixels. 423 | Default is 0 (no translation on the x-axis). 424 | translation_y_px: See translation_x_px, just for the y-axis. 425 | transform_channels_equally: Whether to apply the exactly same 426 | transformations to each channel of an image (True). Setting 427 | it to False allows different transformations per channel, 428 | e.g. the red-channel might be rotated by +20 degrees, while 429 | the blue channel (of the same image) might be rotated 430 | by -5 degrees. If you don't have any channels (2D grayscale), 431 | you can simply ignore this setting. 432 | Default is True (transform all equally). 433 | """ 434 | self.img_width_px = img_width_px 435 | self.img_height_px = img_height_px 436 | self.channel_is_first_axis = channel_is_first_axis 437 | 438 | self.hflip_prob = 0.0 439 | # note: we have to check first for floats, otherwise "hflip == True" 440 | # will evaluate to true if hflip is 1.0. So choosing 1.0 (100%) would 441 | # result in hflip_prob being set to 0.5 (50%). 
442 |         if isinstance(hflip, float):
443 |             assert 0.0 <= hflip <= 1.0
444 |             self.hflip_prob = hflip
445 |         elif hflip == True:
446 |             self.hflip_prob = 0.5
447 |         elif hflip == False:
448 |             self.hflip_prob = 0.0
449 |         else:
450 |             raise Exception("Unexpected value for parameter 'hflip'.")
451 |
452 |         self.vflip_prob = 0.0
453 |         if isinstance(vflip, float):
454 |             assert 0.0 <= vflip <= 1.0
455 |             self.vflip_prob = vflip
456 |         elif vflip == True:
457 |             self.vflip_prob = 0.5
458 |         elif vflip == False:
459 |             self.vflip_prob = 0.0
460 |         else:
461 |             raise Exception("Unexpected value for parameter 'vflip'.")
462 |
463 |         self.scale_to_percent = scale_to_percent
464 |         self.scale_axis_equally = scale_axis_equally
465 |         self.rotation_deg = rotation_deg
466 |         self.shear_deg = shear_deg
467 |         self.translation_x_px = translation_x_px
468 |         self.translation_y_px = translation_y_px
469 |         self.transform_channels_equally = transform_channels_equally
470 |         self.cval = 0.0
471 |         self.interpolation_order = 1
472 |         self.pregenerated_matrices = None
473 |
474 |     def pregenerate_matrices(self, nb_matrices, seed=None):
475 |         """Pregenerate/cache augmentation matrices.
476 |
477 |         If matrices are pregenerated, augment_batch() will reuse them on
478 |         each call. The augmentations will not always be the same,
479 |         as the order of the matrices is randomized (when
480 |         they are applied to the images). This does require, though, that
481 |         you pregenerate enough of them (e.g. a couple thousand).
482 |
483 |         Note that generating the augmentation matrices is usually fast and
484 |         caching them only starts to pay off if you process millions of small
485 |         images or many tens of thousands of big images.
486 |
487 |         Each call to this method pregenerates a new set of matrices,
488 |         e.g. to replace a set that has already been reused many times.
489 |
490 |         Calling this method with nb_matrices set to 0 removes the
491 |         pregenerated matrices, and augment_batch() returns to its default
492 |         behaviour of generating new matrices on each call.
493 |
494 |         Args:
495 |             nb_matrices: The number of matrices to pregenerate. E.g. a few
496 |                 thousand. If set to 0, the matrices will be generated again on
497 |                 each call of augment_batch().
498 |             seed: A random seed to use.
499 |         """
500 |         assert nb_matrices >= 0
501 |         if nb_matrices == 0:
502 |             self.pregenerated_matrices = None
503 |         else:
504 |             matrices = create_aug_matrices(nb_matrices,
505 |                                            self.img_width_px,
506 |                                            self.img_height_px,
507 |                                            scale_to_percent=self.scale_to_percent,
508 |                                            scale_axis_equally=self.scale_axis_equally,
509 |                                            rotation_deg=self.rotation_deg,
510 |                                            shear_deg=self.shear_deg,
511 |                                            translation_x_px=self.translation_x_px,
512 |                                            translation_y_px=self.translation_y_px,
513 |                                            seed=seed)
514 |             self.pregenerated_matrices = matrices
515 |
516 |     def augment_batch(self, images, seed=None):
517 |         """Augments a batch of images.
518 |
519 |         Applies all settings (rotation, shear, translation, ...) that
520 |         have been chosen in the constructor.
521 |
522 |         Args:
523 |             images: Numpy array (dtype: uint8, i.e. values 0-255) with the images.
524 |                 Expected shape is either (image-index, height, width) for
525 |                 grayscale images or (image-index, channel, height, width) for
526 |                 images with channels (e.g. RGB) where the channel has the first
527 |                 index or (image-index, height, width, channel) for images with
528 |                 channels, where the channel is the last index.
529 |                 If your shape is (image-index, channel, width, height) then
530 |                 you must also set channel_is_first_axis=True in the constructor.
531 |             seed: Seed to use for python's and numpy's random functions.
532 |                 Default is None (don't use a seed).
533 |
534 |         Returns:
535 |             Augmented images as numpy array of dtype float32 (i.e. values
536 |             are between 0.0 and 1.0).
537 |         """
538 |         shape = images.shape
539 |         nb_channels = 0
540 |         if len(shape) == 3:
541 |             # shape like (image_index, y-axis, x-axis)
542 |             assert shape[1] == self.img_height_px
543 |             assert shape[2] == self.img_width_px
544 |             nb_channels = 1
545 |         elif len(shape) == 4:
546 |             if not self.channel_is_first_axis:
547 |                 # shape like (image-index, y-axis, x-axis, channel-index)
548 |                 assert shape[1] == self.img_height_px
549 |                 assert shape[2] == self.img_width_px
550 |                 nb_channels = shape[3]
551 |             else:
552 |                 # shape like (image-index, channel-index, y-axis, x-axis)
553 |                 assert shape[2] == self.img_height_px
554 |                 assert shape[3] == self.img_width_px
555 |                 nb_channels = shape[1]
556 |         else:
557 |             msg = "Mismatch between images shape %s and " \
558 |                   "predefined image width/height (%d/%d)."
559 |             raise Exception(msg % (str(shape), self.img_width_px, self.img_height_px))
560 |
561 |         if seed is not None:  # seed=0 is a valid seed, so don't test truthiness here
562 |             random.seed(seed)
563 |             np.random.seed(seed)
564 |
565 |         # --------------------------------
566 |         # horizontal and vertical flipping/mirroring
567 |         # --------------------------------
568 |         # This should be done before applying the affine matrices, as otherwise
569 |         # contents of the image might already be rotated/translated out of the image.
570 |         # It is done with numpy instead of the affine matrices, because
571 |         # scikit-image doesn't offer a nice interface to add mirroring/flipping
572 |         # to affine transformations. The numpy operations are O(1), so they
573 |         # shouldn't have a noticeable effect on runtimes. They also won't suffer
574 |         # from interpolation problems.
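        # Note on the loop below: images i and i+batch_size are flipped with
        # the same coin toss, i.e. both members of a verification pair are
        # flipped identically. This assumes the batch is laid out as two
        # stacked halves (one pair member per half), which is the layout the
        # pair-based layers in this repository expect.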
575 |         if self.hflip_prob > 0 or self.vflip_prob > 0:
576 |             # TODO this currently ignores the setting in
577 |             # transform_channels_equally and will instead always flip all
578 |             # channels equally
579 |
580 |             # if this were simply a view, the input array would get flipped
581 |             # too, as the flipped rows/columns are assigned back in place
582 |             images_flipped = np.copy(images)
583 |             #images_flipped = images.view()
584 |
585 |             if len(shape) == 4 and self.channel_is_first_axis:
586 |                 # roll channel to the last axis
587 |                 # swapaxes doesn't work here, because
588 |                 # (image index, channel, y, x)
589 |                 # would be turned into
590 |                 # (image index, x, y, channel)
591 |                 # and y needs to come before x
592 |                 images_flipped = np.rollaxis(images_flipped, 1, 4)
593 |
594 |             y_p = self.hflip_prob
595 |             x_p = self.vflip_prob
596 |             batch_size = images.shape[0] // 2  # number of pairs in the batch
597 |             for i in range(batch_size):
598 |                 if y_p > 0 and random.random() < y_p:
599 |                     images_flipped[i] = np.fliplr(images_flipped[i])
600 |                     images_flipped[i+batch_size] = np.fliplr(images_flipped[i+batch_size])
601 |                 if x_p > 0 and random.random() < x_p:
602 |                     images_flipped[i] = np.flipud(images_flipped[i])
603 |                     images_flipped[i+batch_size] = np.flipud(images_flipped[i+batch_size])
604 |
605 |             if len(shape) == 4 and self.channel_is_first_axis:
606 |                 # roll channel back to the second axis (index 1)
607 |                 images_flipped = np.rollaxis(images_flipped, 3, 1)
608 |             images = images_flipped
609 |
610 |         # --------------------------------
611 |         # if no augmentation has been chosen, stop early
612 |         # for improved performance (avoid applying the matrices)
613 |         # --------------------------------
614 |         if self.pregenerated_matrices is None \
615 |                 and self.scale_to_percent == 1.0 and self.rotation_deg == 0 \
616 |                 and self.shear_deg == 0 \
617 |                 and self.translation_x_px == 0 and self.translation_y_px == 0:
618 |             return np.array(images, dtype=np.float32) / 255
619 |
620 |         # --------------------------------
621 |         # generate transformation matrices
622 |         # --------------------------------
623 |         if self.pregenerated_matrices is not None:
624 |             matrices = self.pregenerated_matrices
625 |         else:
626 |             # estimate the number of matrices required
627 |             if self.transform_channels_equally:
628 |                 nb_matrices = shape[0]
629 |             else:
630 |                 nb_matrices = shape[0] * nb_channels
631 |
632 |             # generate matrices
633 |             matrices = create_aug_matrices(nb_matrices,
634 |                                            self.img_width_px,
635 |                                            self.img_height_px,
636 |                                            scale_to_percent=self.scale_to_percent,
637 |                                            scale_axis_equally=self.scale_axis_equally,
638 |                                            rotation_deg=self.rotation_deg,
639 |                                            shear_deg=self.shear_deg,
640 |                                            translation_x_px=self.translation_x_px,
641 |                                            translation_y_px=self.translation_y_px,
642 |                                            seed=seed)
643 |
644 |         # --------------------------------
645 |         # apply transformation matrices (i.e. augment images)
646 |         # --------------------------------
647 |         return apply_aug_matrices(images, matrices,
648 |                                   transform_channels_equally=self.transform_channels_equally,
649 |                                   channel_is_first_axis=self.channel_is_first_axis,
650 |                                   cval=self.cval, interpolation_order=self.interpolation_order,
651 |                                   seed=seed)
652 |
653 |     def plot_image(self, image, nb_repeat=40, show_plot=True):
654 |         """Plot augmented variations of an image.
655 |
656 |         This method takes an image and plots it by default in 40 differently
657 |         augmented versions.
658 |
659 |         This method is intended to visualize the strength of your chosen
660 |         augmentations (so for debugging).
661 |
662 |         Args:
663 |             image: The image to plot.
664 |             nb_repeat: How often to plot the image. Each time it is plotted,
665 |                 the chosen augmentation will be different. (Default: 40).
666 |             show_plot: Whether to show the plot. False makes sense if you
667 |                 don't have a graphical user interface on the machine.
668 |                 (Default: True)
669 |
670 |         Returns:
671 |             The figure of the plot.
672 |             Use figure.savefig() to save the image.
673 |         """
674 |         if len(image.shape) == 2:
675 |             images = np.resize(image, (nb_repeat, image.shape[0], image.shape[1]))
676 |         else:
677 |             images = np.resize(image, (nb_repeat, image.shape[0], image.shape[1],
678 |                                        image.shape[2]))
679 |         return self.plot_images(images, True, show_plot=show_plot)
680 |
681 |     def plot_images(self, images, augment, show_plot=True, figure=None):
682 |         """Plot augmented variations of images.
683 |
684 |         The images will all be shown in the same window.
685 |         It is recommended to not plot too many of them (i.e. stay below 100).
686 |
687 |         This method is intended to visualize the strength of your chosen
688 |         augmentations (so for debugging).
689 |
690 |         Args:
691 |             images: A numpy array of images. See augment_batch().
692 |             augment: Whether to augment the images (True) or just display
693 |                 them in the way they are (False).
694 |             show_plot: Whether to show the plot. False makes sense if you
695 |                 don't have a graphical user interface on the machine.
696 |                 (Default: True)
697 |             figure: The figure of the plot in which to draw the images.
698 |                 Provide the return value of this function (from a prior call)
699 |                 to draw in the same plot window again. Choosing None will
700 |                 create a new figure. (Default is None.)
701 |
702 |         Returns:
703 |             The figure of the plot.
704 |             Use figure.savefig() to save the image.
705 |         """
706 |         import matplotlib.pyplot as plt
707 |         import matplotlib.cm as cm
708 |
709 |         if augment:
710 |             images = self.augment_batch(images)
711 |
712 |         # (Lists of) grayscale images have the shape (image index, y, x).
713 |         # Multi-channel images therefore must have 4 or more axes here.
714 |         if len(images.shape) >= 4:
715 |             # The color-channel is expected to be the last axis by matplotlib,
716 |             # therefore exchange the axes if it's the first one here
717 |             if self.channel_is_first_axis:
718 |                 images = np.rollaxis(images, 1, 4)
719 |
720 |         nb_cols = 10
721 |         nb_rows = 1 + int(images.shape[0] / nb_cols)
722 |         if figure is not None:
723 |             fig = figure
724 |             plt.figure(fig.number)
725 |             fig.clear()
726 |         else:
727 |             fig = plt.figure(figsize=(10, 10))
728 |
729 |         for i, image in enumerate(images):
730 |             plot_number = i + 1
731 |             ax = fig.add_subplot(nb_rows, nb_cols, plot_number, xticklabels=[],
732 |                                  yticklabels=[])
733 |             ax.set_axis_off()
734 |             # "cmap" should restrict the color map to grayscale, but strangely
735 |             # also works well with color images
736 |             plt.imshow(image, cmap=cm.Greys_r, aspect="equal")
737 |
738 |         # not showing the plot might be useful e.g. on clusters
739 |         if show_plot:
740 |             plt.show()
741 |
742 |         return fig
743 |
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import theano
4 | import theano.tensor as T
5 | from theano.ifelse import ifelse
6 |
7 | import lasagne
8 |
9 | dtype = theano.config.floatX
10 | PI = np.pi
11 |
12 |
13 | def ortho_init(shape):
14 |     """
15 |     taken from: https://github.com/Lasagne/Lasagne/blob/master/lasagne/init.py#L327-L367
16 |     """
17 |     a = np.random.normal(0.0, 1.0, shape)
18 |     u, _, v = np.linalg.svd(a, full_matrices=False)
19 |     W = u if u.shape == shape else v # pick the one with the correct shape
20 |     return W.astype(dtype)
21 |
22 |
23 | def normal_init(shape, sigma):
24 |     W = np.random.normal(0.0, sigma, shape)
25 |     return W.astype(dtype)
26 |
27 |
28 | def batched_dot(A, B):  # batched matrix product: (B, M, N) x (B, N, K) -> (B, M, K)
29 |     C = A.dimshuffle([0, 1, 2, 'x']) * B.dimshuffle([0, 'x', 1, 2])
30 |     return C.sum(axis=-2)
31 |
32 |
33 | class BaseARC(lasagne.layers.Layer):
34 |     def __init__(self, incoming, num_input, num_glimpse_params, lstm_states, image_size, attn_win, \
35 |                  glimpses, fg_bias_init, final_state_only=True, **kwargs):
36 |         super(BaseARC, self).__init__(incoming, **kwargs)
37 |
38 |         W_lstm = np.zeros((4 * lstm_states, num_input + lstm_states + 1), dtype=dtype)
39 |         for i in range(4):
40 |             W_lstm[i*lstm_states:(i+1)*lstm_states, :num_input] = ortho_init(shape=(lstm_states, num_input))
41 |             W_lstm[i*lstm_states:(i+1)*lstm_states, num_input:-1] = ortho_init(shape=(lstm_states, lstm_states))
42 |         W_lstm[2*lstm_states:3*lstm_states, -1] = fg_bias_init  # forget gate bias column
43 |
44 |         W_g = normal_init(shape=(num_glimpse_params, lstm_states), sigma=0.01)
45 |
46 |         self.W_lstm = self.add_param(W_lstm, (4 * lstm_states, num_input + lstm_states + 1), name='W_lstm')
47 |         self.W_g = self.add_param(W_g, (num_glimpse_params, lstm_states), name='W_g')
48 |
49 |         self.lstm_states = lstm_states
50 |         self.image_size = image_size
51 |         self.attn_win = attn_win
52 |         self.glimpses = glimpses
53 |         self.final_state_only = final_state_only
54 |
55 |     def get_filterbanks(self, gp):
56 |         attn_win = self.attn_win
57 |         image_size = self.image_size
58 |
59 |         # gp: (B, 3) attention parameters per pair: (center_y, center_x, zoom)
60 |         center_y = gp[:, 0].dimshuffle([0, 'x'])
61 |         center_x = gp[:, 1].dimshuffle([0, 'x'])
62 |         delta = 1.0 - T.abs_(gp[:, 2]).dimshuffle([0, 'x'])
63 |         gamma = T.exp(1.0 - 2 * T.abs_(gp[:, 2])).dimshuffle([0, 'x', 'x'])
64 |
65 |         center_y = (image_size - 1) * (center_y + 1.0) / 2.0
66 |         center_x = (image_size - 1) * (center_x + 1.0) / 2.0
67 |         delta = image_size / attn_win * delta
68 |
69 |         rng = T.arange(attn_win, dtype=dtype) - attn_win / 2.0 + 0.5
70 |         cX = center_x + delta * rng
71 |         cY = center_y + delta * rng
72 |
73 |         a = T.arange(image_size, dtype=dtype)
74 |         b = T.arange(image_size, dtype=dtype)
75 |
76 |         F_X = 1.0 + ((a - cX.dimshuffle([0, 1, 'x'])) / gamma) ** 2.0
77 |         F_Y = 1.0 + ((b - cY.dimshuffle([0, 1, 'x'])) / gamma) ** 2.0
78 |         F_X = 1.0 / (PI * gamma * F_X)  # rows of Cauchy densities over the x-axis
79 |         F_Y = 1.0 / (PI * gamma * F_Y)  # rows of Cauchy densities over the y-axis
80 |         F_X = F_X / (F_X.sum(axis=-1).dimshuffle([0, 1, 'x']) + 1e-4)
81 |         F_Y = F_Y / (F_Y.sum(axis=-1).dimshuffle([0, 1, 'x']) + 1e-4)
82 |
83 |         return F_X, F_Y
84 |
85 |     def attend(self, I, H, W):
86 |         raise NotImplementedError('This method must be implemented by subclassed layers')
87 |
88 |     def get_output_for(self, input, **kwargs):
89 |         # input is a 4D tensor: (batch_size, num_channels, height, width)
90 |         image_size = self.image_size
91 |         lstm_states = self.lstm_states
92 |         attn_win = self.attn_win
93 |
94 |         B = input.shape[0] / 2 # number of pairs in the batch
95 |         odd_input = input[:B]
96 |         even_input = input[B:]
97 |
98 |         def step(glimpse_count, c_tm1, h_tm1, odd_input, even_input, W_lstm, W_g):
99 |             turn = T.eq(glimpse_count % 2, 0)  # alternate between the two images of each pair
100 |             I = ifelse(turn, even_input, odd_input)
101 |
102 |             glimpse = self.attend(I, h_tm1, W_g) # (B, attn_win, attn_win)
103 |             flat_glimpse = glimpse.reshape((B, -1))
104 |
105 |             lstm_ip = T.concatenate([flat_glimpse, h_tm1, T.ones((B, 1))], axis=1) # (B, num_input + states + 1)
106 |             pre_activation = T.dot(W_lstm, lstm_ip.T) # result: (4 * states, B)
107 |
108 |             z = T.tanh(pre_activation[0*lstm_states:1*lstm_states])          # candidate
109 |             i = T.nnet.sigmoid(pre_activation[1*lstm_states:2*lstm_states])  # input gate
110 |             f = T.nnet.sigmoid(pre_activation[2*lstm_states:3*lstm_states])  # forget gate
111 |             o = T.nnet.sigmoid(pre_activation[3*lstm_states:4*lstm_states])  # output gate
112 |
113 |             c_t = f * c_tm1.T + i * z # all in (states, B)
114 |             h_t = o * T.tanh(c_t)
115 |
116 |             return glimpse_count + 1, c_t.T, h_t.T # c, h in (B, states)
117 |
118 |         glimpse_count_0 = 0
119 |         c_0 = T.zeros((B, lstm_states))
120 |         h_0 = T.zeros((B, lstm_states))
121 |
122 |         _, cells, hiddens = theano.scan(fn=step, non_sequences=[odd_input, even_input, self.W_lstm, self.W_g],
123 |                                         outputs_info=[glimpse_count_0, c_0, h_0], n_steps=self.glimpses * 2)[0]
124 |
125 |         if self.final_state_only:
126 |             return hiddens[-1]
127 |         else:
128 |             return hiddens
129 |
130 |     def get_output_shape_for(self, input_shape):
131 |         # the batch size of the output is really input_shape[0] / 2 (one row
132 |         # per pair), but since input_shape[0] is None here, we leave it as it is
133 |         if self.final_state_only:
134 |             return (input_shape[0], self.lstm_states)
135 |         else:
136 |             return (2 * self.glimpses, input_shape[0], self.lstm_states)
137 |
138 |
139 | class SimpleARC(BaseARC):
140 |     def __init__(self, incoming, lstm_states, image_size, attn_win, glimpses, \
141 |                  fg_bias_init, final_state_only=True, **kwargs):
142 |
143 |         BaseARC.__init__(self, incoming, attn_win**2, 3, lstm_states, image_size, \
144 |                          attn_win, glimpses, fg_bias_init, final_state_only, **kwargs)
145 |
146 |     def attend(self, I, H, W):
147 |         I = I[:, 0]
148 |         gp = T.dot(W, H.T).T
149 |         F_X, F_Y = self.get_filterbanks(gp)
150 |         G = batched_dot(batched_dot(F_Y, I), F_X.transpose([0, 2, 1]))
151 |         return G
152 |
153 |
154 | class ConvARC(BaseARC):
155 |     def __init__(self, incoming, num_filters, lstm_states, image_size, attn_win, glimpses, \
156 |                  fg_bias_init, final_state_only=True, **kwargs):
157 |
158 |         self.num_filters = num_filters
159 |
160 |         BaseARC.__init__(self, incoming, num_filters * attn_win ** 2, 3, lstm_states, \
161 |                          image_size, attn_win, glimpses, fg_bias_init, final_state_only, **kwargs)
162 |
163 |     def attend(self, I, H, W):
164 |         I = I.reshape((-1, self.image_size, self.image_size))
165 |         num_filters = self.num_filters
166 |         gp = T.dot(W, H.T).T
167 |         F_X, F_Y = self.get_filterbanks(gp)
168 |         F_X = F_X.repeat(num_filters, axis=0)
169 |         F_Y = F_Y.repeat(num_filters, axis=0)
170 |         G = batched_dot(batched_dot(F_Y, I), F_X.transpose([0, 2, 1]))
171 |         return G
172 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import time
4 | import gzip
5 | import cPickle
6 |
7 |
8 | def train(train_fn, val_fn, worker, meta_data, get_params):
9 |     n_iter = meta_data["n_iter"]
10 |     val_freq = 1000
11 |     val_num_batches = 250
12 |     patience = 0.1  # stop once the val loss exceeds the best seen so far by this margin
13 |
14 |     meta_data["training_loss"] = []
15 |     meta_data["validation_loss"] = []
16 |     meta_data["validation_accuracy"] = []
17 |
18 |     best_val_loss = np.inf
19 |     best_val_acc = 0.0
20 |     best_iter_n = 0
21 |     best_params = get_params()
22 |
23 |     print "... training"
24 |     try:
25 |         smooth_loss = np.log(meta_data["num_output"])  # chance-level cross-entropy
26 |         iter_n = 0
27 |         while iter_n < n_iter:
28 |             iter_n += 1
29 |             tick = time.clock()
30 |             X_train, y_train = worker.fetch_batch('train')
31 |             batch_loss = train_fn(X_train, y_train)
32 |             tock = time.clock()
33 |             meta_data["training_loss"].append((iter_n, batch_loss))
34 |
35 |             smooth_loss = 0.99 * smooth_loss + 0.01 * batch_loss  # exponential moving average
36 |             print "iteration: ", iter_n, " train loss: ", np.round(smooth_loss, 4), "\t", np.round((tock - tick), 3) * 1000, "ms"
37 |
38 |             if np.isnan(batch_loss):
39 |                 print "... NaN detected, terminating"
40 |                 break
41 |
42 |             if iter_n % val_freq == 0:
43 |                 net_val_loss, net_val_acc = 0.0, 0.0
44 |                 for i in xrange(val_num_batches):
45 |                     X_val, y_val = worker.fetch_batch('val')
46 |                     val_loss, val_acc = val_fn(X_val, y_val)
47 |                     net_val_loss += val_loss
48 |                     net_val_acc += val_acc
49 |                 val_loss = net_val_loss / val_num_batches
50 |                 val_acc = net_val_acc / val_num_batches
51 |
52 |                 meta_data["validation_loss"].append((iter_n, val_loss))
53 |                 meta_data["validation_accuracy"].append((iter_n, val_acc))
54 |
55 |                 print "====" * 20, "\n", "validation loss: ", val_loss, ", validation accuracy: ", val_acc * 100.0, "\n", "====" * 20
56 |
57 |                 if val_acc > best_val_acc:
58 |                     best_val_acc = val_acc
59 |                     best_iter_n = iter_n
60 |
61 |                 if val_loss < best_val_loss:
62 |                     best_val_loss = val_loss
63 |                     best_params = get_params()
64 |
65 |                 if val_loss > best_val_loss + patience:
66 |                     break
67 |
68 |
69 |     except KeyboardInterrupt:
70 |         pass
71 |
72 |     print "... training done"
73 |     print "best validation accuracy: ", best_val_acc * 100.0, " at iteration number: ", best_iter_n
74 |     print "... exiting training regime"
75 |
76 |     return meta_data, best_params
77 |
78 |
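# Example wiring (mirrors the call at the bottom of wrn.py): train() expects
# compiled Theano functions, a data worker exposing fetch_batch('train'/'val'),
# and a closure that snapshots the current parameter values, e.g.
#
#   meta_data, best_params = train(train_fn, val_fn, worker, meta_data,
#                                  get_params=lambda: helper.get_all_param_values(l_y))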
testing" 81 | 82 | test_num_batches = 500 83 | 84 | net_test_loss, net_test_acc = 0.0, 0.0 85 | for i in range(test_num_batches): 86 | X_test, y_test = worker.fetch_batch('test') 87 | test_loss, test_acc = test_fn(X_test, y_test) 88 | print "\t", i, test_loss, test_acc 89 | net_test_loss += test_loss 90 | net_test_acc += test_acc 91 | test_loss = net_test_loss / test_num_batches 92 | test_acc = net_test_acc / test_num_batches 93 | 94 | print "====" * 20, "\n", "test loss: ", test_loss, ", test accuracy: ", test_acc * 100.0, "\n", "====" * 20 95 | 96 | meta_data["testing_loss"] = test_loss 97 | meta_data["testing_accuracy"] = test_acc 98 | 99 | return meta_data 100 | 101 | 102 | def serialize(obj, filename): 103 | with gzip.open("results/" + filename, "wb") as f: 104 | cPickle.dump(obj, f) 105 | f.close() 106 | 107 | 108 | def deserialize(filename): 109 | with gzip.open("results/" + filename, "rb") as f: 110 | obj = cPickle.load(f) 111 | f.close() 112 | return obj 113 | -------------------------------------------------------------------------------- /one_shot_tests.py: -------------------------------------------------------------------------------- 1 | predictor_file = 'ConvARC_OSFC.opf' 2 | embedder_file = 'ConvARC_OS.emf' 3 | embedding_size = 256 4 | 5 | 6 | import numpy as np 7 | 8 | from main import deserialize 9 | from data_workers import OmniglotOSLake, OmniglotVinyals 10 | 11 | np.random.seed(1969) 12 | 13 | 14 | if embedder_file is None: 15 | predictor = deserialize(predictor_file) 16 | else: 17 | def predictor(X): 18 | predictor = deserialize(predictor_file) 19 | embedder = deserialize(embedder_file) 20 | embeddings = embedder(X).reshape(-1, 20, embedding_size) 21 | return predictor(embeddings) 22 | 23 | print "\n ... testing on the set by Brenden Lake et al" 24 | 25 | worker = OmniglotOSLake() 26 | X_OS, t_OS = worker.fetch_batch() 27 | 28 | all_acc = [] 29 | for run in range(20): 30 | X = X_OS[run] 31 | t = t_OS[run] 32 | 33 | y = predictor(X).reshape(20, 20).argmax(axis=1) 34 | run_acc = np.mean(y == t) * 100.0 35 | print "run ", run + 1, ": ", run_acc 36 | all_acc.append(run_acc) 37 | 38 | print "accuracy: ", np.mean(all_acc), "%" 39 | 40 | 41 | print "\n\n ... 
44 | print "\n\n ... testing on the method of Vinyals et al"
45 |
46 | worker = OmniglotVinyals(num_trials=20)
47 |
48 | all_acc = []
49 | for run in range(20):
50 |     X, t = worker.fetch_batch()
51 |
52 |     y = predictor(X).reshape(20, 20).argmax(axis=1)
53 |     run_acc = np.mean(y == t) * 100.0
54 |     print "run ", run + 1, ": ", run_acc
55 |     all_acc.append(run_acc)
56 |
57 | print "accuracy: ", np.mean(all_acc), "%"
--------------------------------------------------------------------------------
/vis_attn_arc_omniglot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import theano
4 | import theano.tensor as T
5 |
6 | import lasagne
7 | from lasagne.layers import InputLayer, DenseLayer
8 | from lasagne.nonlinearities import sigmoid
9 | from lasagne.layers import get_output, helper
10 |
11 | from layers import SimpleARC
12 | from data_workers import OmniglotOS
13 | from main import deserialize
14 |
15 | import matplotlib.pyplot as plt
16 | import matplotlib.patches as patches
17 |
18 | plt.ion()
19 | plt.style.use('fivethirtyeight')
20 | plt.rcParams["figure.figsize"] = (16, 8)
21 |
22 |
23 | expt_name = "ARC_OS"
24 | image_size = 32
25 | attn_win = 4
26 | glimpses = 8
27 | lstm_states = 512
28 | batch_size = 2
29 |
30 | X = T.tensor4("input")
31 | y = T.imatrix("target")
32 |
33 | l_in = InputLayer(shape=(None, 1, image_size, image_size), input_var=X)
34 | l_arc = SimpleARC(l_in, lstm_states=lstm_states, image_size=image_size, attn_win=attn_win, glimpses=glimpses, fg_bias_init=0.0, final_state_only=False)
35 |
36 | embeddings = get_output(l_arc, deterministic=True)
37 |
38 | GPs = []
39 | for i in range(-1, 2 * glimpses - 1):
40 |     if i == -1:
41 |         gp = T.dot(l_arc.W_g, T.zeros_like(embeddings[0].T)).T
42 |     else:
43 |         gp = T.dot(l_arc.W_g, embeddings[i].T).T
44 |
45 |     center_y = gp[:, 0].dimshuffle([0, 'x'])
46 |     center_x = gp[:, 1].dimshuffle([0, 'x'])
47 |     delta = 1.0 - T.abs_(gp[:, 2]).dimshuffle([0, 'x'])
48 |     gamma = T.exp(1.0 - 2 * T.abs_(gp[:, 2])).dimshuffle([0, 'x', 'x'])
49 |
50 |     center_y = (image_size - 1) * (center_y + 1.0) / 2.0
51 |     center_x = (image_size - 1) * (center_x + 1.0) / 2.0
52 |     delta = image_size / attn_win * delta
53 |
54 |     GPs.extend([center_y, center_x, delta])
55 |
56 | embedding_fn = theano.function([X], outputs=GPs)
57 |
58 | params = deserialize(expt_name + '.params')
59 | helper.set_all_param_values(l_arc, params[:2])
60 |
61 | worker = OmniglotOS(image_size=image_size, batch_size=batch_size)
62 |
63 | while True:
64 |     X_sample, _ = worker.fetch_batch('val')
65 |     G = embedding_fn(X_sample)
66 |
67 |     G = np.array(G)
68 |     G = G.reshape(2 * glimpses, 3, batch_size)
69 |
70 |     g = G[:, :, 0]
71 |     I1 = X_sample[0, 0]
72 |     I2 = X_sample[2, 0]
73 |
74 |     fig, axs = plt.subplots(2, glimpses)
75 |
76 |     for i in range(glimpses):
77 |         ax = axs[0, i]
78 |         ax.imshow(I1, cmap="Greys_r")
79 |         ax.set_xticklabels([])
80 |         ax.set_yticklabels([])
81 |         ax.xaxis.grid(False)
82 |         ax.yaxis.grid(False)
83 |         ax.set_title(str(i + 1))
84 |         x, y, w = g[2*i]
85 |         w *= attn_win
86 |         x = x - w / 2.0
87 |         y = image_size - y - w / 2.0
88 |         rect = patches.Rectangle((x, y), w, w, linewidth=(2*w-1)/8, edgecolor='b', facecolor='none')
89 |         ax.add_patch(rect)
90 |
91 |         ax = axs[1, i]
92 |         ax.imshow(I2, cmap="Greys_r")
93 |         ax.set_xticklabels([])
94 |         ax.set_yticklabels([])
95 |         ax.xaxis.grid(False)
96 |         ax.yaxis.grid(False)
97 |         x, y, w = g[2*i + 1]
98 |         w *= attn_win
99 |         x = x - w / 2.0
100 |         y = image_size - y - w / 2.0
101 |         rect = patches.Rectangle((x, y), w, w, linewidth=(2*w-1)/8, edgecolor='b', facecolor='none')
102 |         ax.add_patch(rect)
103 |
104 |     g = G[:, :, 1]
105 |     I1 = X_sample[1, 0]
106 |     I2 = X_sample[3, 0]
107 |
108 |     fig, axs = plt.subplots(2, glimpses)
109 |     plt.subplots_adjust(wspace=0, hspace=0)
110 |
111 |     for i in range(glimpses):
112 |         ax = axs[0, i]
113 |         ax.imshow(I1, cmap="Greys_r")
114 |         ax.set_xticklabels([])
115 |         ax.set_yticklabels([])
116 |         ax.xaxis.grid(False)
117 |         ax.yaxis.grid(False)
118 |         ax.set_title(str(i + 1))
119 |         x, y, w = g[2*i]
120 |         w *= attn_win
121 |         x = x - w / 2.0
122 |         y = image_size - y - w / 2.0
123 |         rect = patches.Rectangle((x, y), w, w, linewidth=(2*w-1)/8, edgecolor='b', facecolor='none')
124 |         ax.add_patch(rect)
125 |
126 |         ax = axs[1, i]
127 |         ax.imshow(I2, cmap="Greys_r")
128 |         ax.set_xticklabels([])
129 |         ax.set_yticklabels([])
130 |         ax.xaxis.grid(False)
131 |         ax.yaxis.grid(False)
132 |         x, y, w = g[2*i + 1]
133 |         w *= attn_win
134 |         x = x - w / 2.0
135 |         y = image_size - y - w / 2.0
136 |         rect = patches.Rectangle((x, y), w, w, linewidth=(2*w-1)/8, edgecolor='b', facecolor='none')
137 |         ax.add_patch(rect)
138 |
139 |     raw_input()  # wait for a key press before visualizing the next batch
--------------------------------------------------------------------------------
/wrn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import theano
4 | import theano.tensor as T
5 |
6 | import lasagne
7 | from lasagne.layers import InputLayer
8 | from lasagne.layers import DenseLayer, DropoutLayer
9 | from lasagne.layers import batch_norm, BatchNormLayer, ExpressionLayer
10 | from lasagne.layers import Conv2DLayer as ConvLayer
11 | from lasagne.layers import ElemwiseSumLayer, NonlinearityLayer, GlobalPoolLayer
12 | from lasagne.nonlinearities import rectify, sigmoid
13 | from lasagne.init import HeNormal
14 | from lasagne.layers import get_all_params, get_all_layers, get_output
15 | from lasagne.regularization import regularize_layer_params
16 | from lasagne.objectives import binary_crossentropy, binary_accuracy
17 | from lasagne.updates import adam
18 | from lasagne.layers import helper
19 |
20 | from data_workers import OmniglotVerif, LFWVerif
21 | from main import train, test, serialize, deserialize
22 |
23 | import sys
24 | sys.setrecursionlimit(10000)
25 |
26 |
27 | import argparse
28 |
29 |
30 | def residual_block(l, increase_dim=False, projection=True, first=False, filters=16):
31 |     if increase_dim:
32 |         first_stride = (2, 2)
33 |     else:
34 |         first_stride = (1, 1)
35 |     if first:
36 |         bn_pre_relu = l
37 |     else:
38 |         bn_pre_conv = BatchNormLayer(l)
39 |         bn_pre_relu = NonlinearityLayer(bn_pre_conv, rectify)
40 |     conv_1 = batch_norm(ConvLayer(bn_pre_relu, num_filters=filters, filter_size=(3,3), stride=first_stride, nonlinearity=rectify, pad='same', W=HeNormal(gain='relu')))
41 |     dropout = DropoutLayer(conv_1, p=0.3)
42 |     conv_2 = ConvLayer(dropout, num_filters=filters, filter_size=(3,3), stride=(1,1), nonlinearity=None, pad='same', W=HeNormal(gain='relu'))
43 |     if increase_dim:
44 |         projection = ConvLayer(l, num_filters=filters, filter_size=(1,1), stride=(2,2), nonlinearity=None, pad='same', b=None)
45 |         block = ElemwiseSumLayer([conv_2, projection])
46 |     elif first:
47 |         projection = ConvLayer(l, num_filters=filters, filter_size=(1,1), stride=(1,1), nonlinearity=None, pad='same', b=None)
48 |         block = ElemwiseSumLayer([conv_2, projection])
49 |     else:
50 |         block = ElemwiseSumLayer([conv_2, l])
51 |     return block
52 |
53 |
54 | parser = argparse.ArgumentParser(description="CLI for specifying hyper-parameters")
55 | parser.add_argument("-n", "--expt-name", type=str, default="", help="experiment name (for logging purposes)")
56 | parser.add_argument("--dataset", type=str, default="omniglot", help="omniglot/LFW")
57 |
58 | parser.add_argument("--wrn-depth", type=int, default=3, help="the resnet has depth equal to 6d+12")
59 | parser.add_argument("--wrn-width", type=int, default=2, help="width multiplier for each WRN block")
60 |
61 | meta_data = vars(parser.parse_args())
62 | meta_data["expt_name"] = "WRN_VERIF_" + meta_data["dataset"] + "_" + meta_data["expt_name"]
63 |
64 | for md in meta_data.keys():
65 |     print md, meta_data[md]
66 |
67 | expt_name = meta_data["expt_name"]
68 | learning_rate = 1e-3
69 | image_size = 64 # 32
70 | batch_size = 128
71 | meta_data["n_iter"] = n_iter = 100000
72 | wrn_n = meta_data["wrn_depth"]
73 | wrn_k = meta_data["wrn_width"]
74 | meta_data["num_output"] = 2
75 |
76 |
77 | print "... setting up the network"
78 | n_filters = {0: 16, 1: 16 * wrn_k, 2: 32 * wrn_k, 3: 64 * wrn_k}
79 |
80 | X = T.tensor4("input")
81 | y = T.imatrix("target")
82 |
83 | l_in = InputLayer(shape=(None, 1, image_size, image_size), input_var=X)
84 | l = batch_norm(ConvLayer(l_in, num_filters=n_filters[0], filter_size=(3, 3), \
85 |                          stride=(1, 1), nonlinearity=rectify, pad='same', W=HeNormal(gain='relu')))
86 | l = residual_block(l, first=True, filters=n_filters[1])
87 | for _ in range(1, wrn_n):
88 |     l = residual_block(l, filters=n_filters[1])
89 | l = residual_block(l, increase_dim=True, filters=n_filters[2])
90 | for _ in range(1, (wrn_n+2)):
91 |     l = residual_block(l, filters=n_filters[2])
92 | l = residual_block(l, increase_dim=True, filters=n_filters[3])
93 | for _ in range(1, (wrn_n+2)):
94 |     l = residual_block(l, filters=n_filters[3])
95 |
96 | bn_post_conv = BatchNormLayer(l)
97 | bn_post_relu = NonlinearityLayer(bn_post_conv, rectify)
98 | avg_pool = GlobalPoolLayer(bn_post_relu)
99 | dense_layer = DenseLayer(avg_pool, num_units=128, W=HeNormal(gain='relu'), nonlinearity=rectify)
100 | # siamese head: the batch stacks the two members of each pair as two halves,
101 | # so the L1 distance between their embeddings is a per-pair feature vector
102 | dist_layer = ExpressionLayer(dense_layer, lambda I: T.abs_(I[:I.shape[0]/2] - I[I.shape[0]/2:]), output_shape='auto')
103 | l_y = DenseLayer(dist_layer, num_units=1, nonlinearity=sigmoid)
104 |
105 | prediction = get_output(l_y)
106 | prediction_clean = get_output(l_y, deterministic=True)
107 |
108 | loss = T.mean(binary_crossentropy(prediction, y))
109 | accuracy = T.mean(binary_accuracy(prediction_clean, y))
110 |
111 | all_layers = get_all_layers(l_y)
112 | l2_penalty = 0.0001 * regularize_layer_params(all_layers, lasagne.regularization.l2)
113 | loss = loss + l2_penalty
114 |
115 | params = get_all_params(l_y, trainable=True)
116 | updates = adam(loss, params, learning_rate=learning_rate)
117 |
118 | meta_data["num_param"] = lasagne.layers.count_params(l_y)
119 | print "number of parameters: ", meta_data["num_param"]
120 |
121 | print "... compiling"
122 | train_fn = theano.function(inputs=[X, y], outputs=loss, updates=updates)
123 | val_fn = theano.function(inputs=[X, y], outputs=[loss, accuracy])
124 | op_fn = theano.function([X], outputs=prediction_clean)
125 |
126 |
loading dataset" 126 | if meta_data["dataset"] == 'omniglot': 127 | worker = OmniglotOS(image_size=image_size, batch_size=batch_size) 128 | elif meta_data["dataset"] == 'lfw': 129 | worker = LFWVerif(image_size=image_size, batch_size=batch_size) 130 | 131 | meta_data, best_params = train(train_fn, val_fn, worker, meta_data, \ 132 | get_params=lambda: helper.get_all_param_values(l_y)) 133 | 134 | if meta_data["testing"]: 135 | print "... testing" 136 | helper.set_all_param_values(l_y, best_params) 137 | meta_data = test(val_fn, worker, meta_data) 138 | 139 | serialize(params, expt_name + '.params') 140 | serialize(meta_data, expt_name + '.mtd') 141 | serialize(op_fn, expt_name + '.opf') 142 | --------------------------------------------------------------------------------