├── AM3_TADAM.py ├── AM3_protonet++.py ├── LICENSE ├── NOTICE ├── common ├── gen_experiments.py └── util.py ├── datasets ├── create_dataset_miniImagenet.py ├── create_dataset_tieredimagenet.py ├── data.py ├── mini_imagenet_class_label_dict3.txt ├── test.csv ├── train.csv └── val.csv ├── protonet++.py ├── readme.md └── tadam.py /AM3_protonet++.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Training and evaluation entry point.""" 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import numpy as np 10 | import argparse 11 | import tensorflow as tf 12 | import tensorflow.contrib.slim as slim 13 | from tensorflow.python.ops import init_ops 14 | from tensorflow.python.ops import random_ops 15 | from tensorflow.python.framework import dtypes 16 | from scipy.spatial import KDTree 17 | from common.util import Dataset 18 | from common.util import ACTIVATION_MAP 19 | from tqdm import trange 20 | import pathlib 21 | import logging 22 | from common.util import summary_writer 23 | from common.gen_experiments import load_and_save_params 24 | import time 25 | import pickle as pkl 26 | 27 | 28 | tf.logging.set_verbosity(tf.logging.INFO) 29 | logging.basicConfig(level=logging.INFO) 30 | 31 | 32 | 33 | def _load_mini_imagenet(data_dir, split): 34 | """Load mini-imagenet from numpy's npz file format.""" 35 | _split_tag = {'sources': 'train', 'target_val': 'val', 'target_tst': 'test'}[split] 36 | dataset_path = os.path.join(data_dir, 'few-shot-{}.npz'.format(_split_tag)) 37 | logging.info("Loading mini-imagenet...") 38 | data = np.load(dataset_path) 39 | fields = data['features'], data['targets'] 40 | logging.info("Done loading.") 41 | return fields 42 | 43 | def get_image_size(data_dir): 44 | if 'mini-imagenet' or 'tiered' in data_dir: 45 | image_size = 84 46 | elif 'cifar' in data_dir: 47 | image_size = 32 48 | else: 
49 | raise Exception('Unknown dataset: %s' % data_dir) 50 | return image_size 51 | 52 | 53 | class Namespace(object): 54 | def __init__(self, adict): 55 | self.__dict__.update(adict) 56 | 57 | 58 | def get_arguments(): 59 | parser = argparse.ArgumentParser() 60 | 61 | parser.add_argument('--mode', type=str, default='train', 62 | choices=['train', 'eval', 'test', 'train_classifier', 'create_embedding']) 63 | # Dataset parameters 64 | parser.add_argument('--data_dir', type=str, default=None, help='Path to the data.') 65 | parser.add_argument('--data_split', type=str, default='sources', choices=['sources', 'target_val', 'target_tst'], 66 | help='Split of the data to be used to perform operation.') 67 | 68 | # Training parameters 69 | parser.add_argument('--number_of_steps', type=int, default=int(30000), 70 | help="Number of training steps (number of Epochs in Hugo's paper)") 71 | parser.add_argument('--number_of_steps_to_early_stop', type=int, default=int(1000000), 72 | help="Number of training steps after half way to early stop the training") 73 | parser.add_argument('--log_dir', type=str, default='', help='Base log dir') 74 | parser.add_argument('--num_classes_train', type=int, default=5, 75 | help='Number of classes in the train phase, this is coming from the prototypical networks') 76 | parser.add_argument('--num_shots_train', type=int, default=5, 77 | help='Number of shots in a few shot meta-train scenario') 78 | parser.add_argument('--train_batch_size', type=int, default=32, help='Training batch size.') 79 | parser.add_argument('--num_tasks_per_batch', type=int, default=2, 80 | help='Number of few shot tasks per batch, so the task encoding batch is num_tasks_per_batch x num_classes_test x num_shots_train .') 81 | parser.add_argument('--init_learning_rate', type=float, default=0.1, help='Initial learning rate.') 82 | parser.add_argument('--save_summaries_secs', type=int, default=60, help='Time between saving summaries') 83 | 
parser.add_argument('--save_interval_secs', type=int, default=60, help='Time between saving model?') 84 | parser.add_argument('--optimizer', type=str, default='sgd', choices=['sgd', 'adam']) 85 | parser.add_argument('--augment', type=bool, default=False) 86 | # Learning rate paramteres 87 | parser.add_argument('--lr_anneal', type=str, default='pwc', choices=['const', 'pwc', 'cos', 'exp']) 88 | parser.add_argument('--n_lr_decay', type=int, default=3) 89 | parser.add_argument('--lr_decay_rate', type=float, default=10.0) 90 | parser.add_argument('--num_steps_decay_pwc', type=int, default=2500, 91 | help='Decay learning rate every num_steps_decay_pwc') 92 | 93 | parser.add_argument('--clip_gradient_norm', type=float, default=1.0, help='gradient clip norm.') 94 | parser.add_argument('--weights_initializer_factor', type=float, default=0.1, 95 | help='multiplier in the variance of the initialization noise.') 96 | # Evaluation parameters 97 | parser.add_argument('--max_number_of_evaluations', type=float, default=float('inf')) 98 | parser.add_argument('--eval_interval_secs', type=int, default=120, help='Time between evaluating model?') 99 | parser.add_argument('--eval_interval_steps', type=int, default=1000, 100 | help='Number of train steps between evaluating model in the training loop') 101 | parser.add_argument('--eval_interval_fine_steps', type=int, default=250, 102 | help='Number of train steps between evaluating model in the training loop in the final phase') 103 | # Test parameters 104 | parser.add_argument('--num_classes_test', type=int, default=5, help='Number of classes in the test phase') 105 | parser.add_argument('--num_shots_test', type=int, default=5, 106 | help='Number of shots in a few shot meta-test scenario') 107 | parser.add_argument('--num_cases_test', type=int, default=100000, 108 | help='Number of few-shot cases to compute test accuracy') 109 | # Architecture parameters 110 | parser.add_argument('--dropout', type=float, default=1.0) 111 | 
parser.add_argument('--conv_dropout', type=float, default=None) 112 | parser.add_argument('--feature_dropout_p', type=float, default=None) 113 | 114 | parser.add_argument('--weight_decay', type=float, default=0.0005) 115 | parser.add_argument('--num_filters', type=int, default=64) 116 | parser.add_argument('--num_units_in_block', type=int, default=3) 117 | parser.add_argument('--num_blocks', type=int, default=4) 118 | parser.add_argument('--num_max_pools', type=int, default=3) 119 | parser.add_argument('--block_size_growth', type=float, default=2.0) 120 | parser.add_argument('--activation', type=str, default='swish-1', choices=['relu', 'selu', 'swish-1']) 121 | 122 | parser.add_argument('--feature_expansion_size', type=int, default=None) 123 | parser.add_argument('--feature_bottleneck_size', type=int, default=None) 124 | 125 | parser.add_argument('--feature_extractor', type=str, default='simple_res_net', 126 | choices=['simple_res_net'], help='Which feature extractor to use') 127 | 128 | 129 | parser.add_argument('--encoder_sharing', type=str, default='shared', 130 | choices=['shared'], 131 | help='How to link fetaure extractors in task encoder and classifier') 132 | parser.add_argument('--encoder_classifier_link', type=str, default='prototypical', 133 | choices=['prototypical'], 134 | help='How to link fetaure extractors in task encoder and classifier') 135 | parser.add_argument('--embedding_pooled', type=bool, default=True, 136 | help='Whether to use avg pooling to create embedding') 137 | parser.add_argument('--task_encoder', type=str, default='self_att_mlp', 138 | choices=['class_mean', 'fixed_alpha','fixed_alpha_mlp','self_att_mlp']) 139 | 140 | # 141 | parser.add_argument('--num_batches_neg_mining', type=int, default=0) 142 | parser.add_argument('--eval_batch_size', type=int, default=100, help='Evaluation batch size') 143 | 144 | parser.add_argument('--alpha', type=float, default=1.0) 145 | parser.add_argument('--mlp_weight_decay', type=float, default=0.0) 
146 | parser.add_argument('--mlp_dropout', type=float, default=0.0) 147 | parser.add_argument('--mlp_type', type=str, default='non-linear') 148 | parser.add_argument('--att_input', type=str, default='word') 149 | 150 | args = parser.parse_args() 151 | 152 | print(args) 153 | return args 154 | 155 | 156 | def get_logdir_name(flags): 157 | """Generates the name of the log directory from the values of flags 158 | Parameters 159 | ---------- 160 | flags: neural net architecture generated by get_arguments() 161 | Outputs 162 | ------- 163 | the name of the directory to store the training and evaluation results 164 | """ 165 | logdir = flags.log_dir 166 | 167 | return logdir 168 | 169 | 170 | class ScaledVarianceRandomNormal(init_ops.Initializer): 171 | """Initializer that generates tensors with a normal distribution scaled as per https://arxiv.org/pdf/1502.01852.pdf. 172 | Args: 173 | mean: a python scalar or a scalar tensor. Mean of the random values 174 | to generate. 175 | stddev: a python scalar or a scalar tensor. Standard deviation of the 176 | random values to generate. 177 | seed: A Python integer. Used to create random seeds. See 178 | @{tf.set_random_seed} 179 | for behavior. 180 | dtype: The data type. Only floating point types are supported. 
181 | """ 182 | 183 | def __init__(self, mean=0.0, factor=1.0, seed=None, dtype=dtypes.float32): 184 | self.mean = mean 185 | self.factor = factor 186 | self.seed = seed 187 | self.dtype = dtypes.as_dtype(dtype) 188 | 189 | def __call__(self, shape, dtype=None, partition_info=None): 190 | if dtype is None: 191 | dtype = self.dtype 192 | 193 | if shape: 194 | n = float(shape[-1]) 195 | else: 196 | n = 1.0 197 | for dim in shape[:-2]: 198 | n *= float(dim) 199 | 200 | self.stddev = np.sqrt(self.factor * 2.0 / n) 201 | return random_ops.random_normal(shape, self.mean, self.stddev, 202 | dtype, seed=self.seed) 203 | 204 | 205 | def _get_scope(is_training, flags): 206 | normalizer_params = { 207 | 'epsilon': 0.001, 208 | 'momentum': .95, 209 | 'trainable': is_training, 210 | 'training': is_training, 211 | } 212 | conv2d_arg_scope = slim.arg_scope( 213 | [slim.conv2d, slim.fully_connected], 214 | activation_fn=ACTIVATION_MAP[flags.activation], 215 | normalizer_fn=tf.layers.batch_normalization, 216 | normalizer_params=normalizer_params, 217 | # padding='SAME', 218 | trainable=is_training, 219 | weights_regularizer=tf.contrib.layers.l2_regularizer(scale=flags.weight_decay), 220 | weights_initializer=ScaledVarianceRandomNormal(factor=flags.weights_initializer_factor), 221 | biases_initializer=tf.constant_initializer(0.0) 222 | ) 223 | dropout_arg_scope = slim.arg_scope( 224 | [slim.dropout], 225 | keep_prob=flags.dropout, 226 | is_training=is_training) 227 | return conv2d_arg_scope, dropout_arg_scope 228 | 229 | 230 | def build_simple_conv_net(images, flags, is_training, reuse=None, scope=None): 231 | conv2d_arg_scope, dropout_arg_scope = _get_scope(is_training, flags) 232 | with conv2d_arg_scope, dropout_arg_scope: 233 | with tf.variable_scope(scope or 'feature_extractor', reuse=reuse): 234 | h = images 235 | for i in range(4): 236 | h = slim.conv2d(h, num_outputs=flags.num_filters, kernel_size=3, stride=1, 237 | scope='conv' + str(i), padding='SAME', 238 | 
weights_initializer=ScaledVarianceRandomNormal(factor=flags.weights_initializer_factor)) 239 | h = slim.max_pool2d(h, kernel_size=2, stride=2, padding='VALID', scope='max_pool' + str(i)) 240 | 241 | if flags.embedding_pooled == True: 242 | kernel_size = h.shape.as_list()[-2] 243 | h = slim.avg_pool2d(h, kernel_size=kernel_size, scope='avg_pool') 244 | h = slim.flatten(h) 245 | return h 246 | 247 | 248 | def leaky_relu(x, alpha=0.1, name=None): 249 | return tf.maximum(x, alpha * x, name=name) 250 | 251 | 252 | 253 | 254 | def build_simple_res_net(images, flags, num_filters, beta=None, gamma=None, is_training=False, reuse=None, scope=None): 255 | conv2d_arg_scope, dropout_arg_scope = _get_scope(is_training, flags) 256 | activation_fn = ACTIVATION_MAP[flags.activation] 257 | with conv2d_arg_scope, dropout_arg_scope: 258 | with tf.variable_scope(scope or 'feature_extractor', reuse=reuse): 259 | h = images 260 | for i in range(len(num_filters)): 261 | # make shortcut 262 | shortcut = slim.conv2d(h, num_outputs=num_filters[i], kernel_size=1, stride=1, 263 | activation_fn=None, 264 | scope='shortcut' + str(i), padding='SAME') 265 | 266 | for j in range(flags.num_units_in_block): 267 | h = slim.conv2d(h, num_outputs=num_filters[i], kernel_size=3, stride=1, 268 | scope='conv' + str(i) + '_' + str(j), padding='SAME', activation_fn=None) 269 | if flags.conv_dropout: 270 | h = slim.dropout(h, keep_prob=1.0 - flags.conv_dropout) 271 | 272 | if j < (flags.num_units_in_block - 1): 273 | h = activation_fn(h, name='activation_' + str(i) + '_' + str(j)) 274 | h = h + shortcut 275 | 276 | h = activation_fn(h, name='activation_' + str(i) + '_' + str(flags.num_units_in_block - 1)) 277 | if i < flags.num_max_pools: 278 | h = slim.max_pool2d(h, kernel_size=2, stride=2, padding='SAME', scope='max_pool' + str(i)) 279 | 280 | if flags.feature_expansion_size: 281 | if flags.feature_dropout_p: 282 | h = slim.dropout(h, scope='feature_expansion_dropout', keep_prob=1.0 - 
flags.feature_dropout_p) 283 | h = slim.conv2d(slim.dropout(h), num_outputs=flags.feature_expansion_size, kernel_size=1, stride=1, 284 | scope='feature_expansion', padding='SAME') 285 | 286 | if flags.embedding_pooled == True: 287 | kernel_size = h.shape.as_list()[-2] 288 | h = slim.avg_pool2d(h, kernel_size=kernel_size, scope='avg_pool') 289 | h = slim.flatten(h) 290 | 291 | if flags.feature_dropout_p: 292 | h = slim.dropout(h, scope='feature_bottleneck_dropout', keep_prob=1.0 - flags.feature_dropout_p) 293 | # Bottleneck layer 294 | if flags.feature_bottleneck_size: 295 | h = slim.fully_connected(h, num_outputs=flags.feature_bottleneck_size, 296 | activation_fn=activation_fn, normalizer_fn=None, 297 | scope='feature_bottleneck') 298 | 299 | return h 300 | 301 | 302 | 303 | def build_wordemb_transformer(embeddings, flags, is_training=False, reuse=None, scope=None): 304 | with tf.variable_scope(scope or 'mlp_transformer', reuse=reuse): 305 | h = embeddings 306 | if flags.mlp_type=='linear': 307 | h = slim.fully_connected(h, 512, reuse=False, scope='mlp_layer', 308 | activation_fn=None, trainable=is_training, 309 | weights_regularizer=tf.contrib.layers.l2_regularizer(scale=flags.mlp_weight_decay), 310 | weights_initializer=ScaledVarianceRandomNormal(factor=flags.weights_initializer_factor), 311 | biases_initializer=tf.constant_initializer(0.0)) 312 | elif flags.mlp_type=='non-linear': 313 | h = slim.fully_connected(h, 300, reuse=False, scope='mlp_layer', 314 | activation_fn=tf.nn.relu, trainable=is_training, 315 | weights_regularizer=tf.contrib.layers.l2_regularizer( 316 | scale=flags.mlp_weight_decay), 317 | weights_initializer=ScaledVarianceRandomNormal( 318 | factor=flags.weights_initializer_factor), 319 | biases_initializer=tf.constant_initializer(0.0)) 320 | h = slim.dropout(h, scope='mlp_dropout', keep_prob=1.0 - flags.mlp_dropout, is_training=is_training) 321 | h = slim.fully_connected(h, 512, reuse=False, scope='mlp_layer_1', 322 | activation_fn=None, 
trainable=is_training, 323 | weights_regularizer=tf.contrib.layers.l2_regularizer( 324 | scale=flags.mlp_weight_decay), 325 | weights_initializer=ScaledVarianceRandomNormal( 326 | factor=flags.weights_initializer_factor), 327 | biases_initializer=tf.constant_initializer(0.0)) 328 | 329 | return h 330 | 331 | def build_self_attention(embeddings, flags, is_training=False, reuse=None, scope=None): 332 | with tf.variable_scope(scope or 'self_attention', reuse=reuse): 333 | h = embeddings 334 | if flags.mlp_type=='linear': 335 | h = slim.fully_connected(h, 1, reuse=False, scope='self_att_layer', 336 | activation_fn=None, trainable=is_training, 337 | weights_regularizer=tf.contrib.layers.l2_regularizer(scale=flags.mlp_weight_decay), 338 | weights_initializer=ScaledVarianceRandomNormal(factor=flags.weights_initializer_factor), 339 | biases_initializer=tf.constant_initializer(0.0)) 340 | elif flags.mlp_type=='non-linear': 341 | h = slim.fully_connected(h, 300, reuse=False, scope='self_att_layer', 342 | activation_fn=tf.nn.relu, trainable=is_training, 343 | weights_regularizer=tf.contrib.layers.l2_regularizer( 344 | scale=flags.mlp_weight_decay), 345 | weights_initializer=ScaledVarianceRandomNormal( 346 | factor=flags.weights_initializer_factor), 347 | biases_initializer=tf.constant_initializer(0.0)) 348 | h = slim.dropout(h, scope='self_att_dropout', keep_prob=1.0 - flags.mlp_dropout, is_training=is_training) 349 | h = slim.fully_connected(h, 1, reuse=False, scope='self_att_layer_1', 350 | activation_fn=None, trainable=is_training, 351 | weights_regularizer=tf.contrib.layers.l2_regularizer( 352 | scale=flags.mlp_weight_decay), 353 | weights_initializer=ScaledVarianceRandomNormal( 354 | factor=flags.weights_initializer_factor), 355 | biases_initializer=tf.constant_initializer(0.0)) 356 | h = tf.sigmoid(h) 357 | 358 | return h 359 | 360 | def get_res_net_block(h, flags, num_filters, num_units, pool=False, is_training=False, 361 | reuse=None, scope=None): 362 | 
conv2d_arg_scope, dropout_arg_scope = _get_scope(is_training, flags) 363 | activation_fn = ACTIVATION_MAP[flags.activation] 364 | with conv2d_arg_scope, dropout_arg_scope: 365 | with tf.variable_scope(scope, reuse=reuse): 366 | # make shortcut 367 | shortcut = slim.conv2d(h, num_outputs=num_filters, kernel_size=1, stride=1, 368 | activation_fn=None, 369 | scope='shortcut', padding='SAME') 370 | 371 | for j in range(num_units): 372 | h = slim.conv2d(h, num_outputs=num_filters, kernel_size=3, stride=1, 373 | scope='conv_' + str(j), padding='SAME', activation_fn=None) 374 | if flags.conv_dropout: 375 | h = slim.dropout(h, keep_prob=1.0 - flags.conv_dropout) 376 | if j < (num_units - 1): 377 | h = activation_fn(h, name='activation_' + str(j)) 378 | h = h + shortcut 379 | h = activation_fn(h, name='activation_' + '_' + str(flags.num_units_in_block - 1)) 380 | if pool: 381 | h = slim.max_pool2d(h, kernel_size=2, stride=2, padding='SAME', scope='max_pool') 382 | return h 383 | 384 | 385 | 386 | def build_feature_extractor_graph(images, flags, num_filters, beta=None, gamma=None, is_training=False, 387 | scope='feature_extractor_task_encoder', reuse=None, is_64way=False): 388 | if flags.feature_extractor == 'simple_conv_net': 389 | h = build_simple_conv_net(images, flags=flags, is_training=is_training, reuse=reuse, scope=scope) 390 | elif flags.feature_extractor == 'simple_res_net': 391 | h = build_simple_res_net(images, flags=flags, num_filters=num_filters, beta=beta, gamma=gamma, 392 | is_training=is_training, reuse=reuse, scope=scope) 393 | else: 394 | h = None 395 | 396 | embedding_shape = h.get_shape().as_list() 397 | if is_training and is_64way is False: 398 | h = tf.reshape(h, shape=(flags.num_tasks_per_batch, embedding_shape[0] // flags.num_tasks_per_batch, -1), 399 | name='reshape_to_separate_tasks_generic_features') 400 | else: 401 | h = tf.reshape(h, shape=(1, embedding_shape[0], -1), 402 | name='reshape_to_separate_tasks_generic_features') 403 | 404 | return h 
def _average_support_per_class(embeddings, label_embeddings, flags, is_training):
    """Split flat support-set tensors by task and average the shots of each class.

    Reshaping:
      - ``embeddings`` becomes ``[num_tasks, num_classes, num_shots, -1]`` and is
        then averaged over the shot axis;
      - ``label_embeddings`` (when not ``None``) becomes
        ``[num_tasks, num_classes, -1]``.

    Training uses ``flags.num_tasks_per_batch`` tasks with the ``*_train``
    class/shot counts; otherwise a single task with the ``*_test`` counts is
    used. The reshape ops are named ``reshape_to_separate_tasks_task_encoding``
    and ``reshape_to_separate_tasks_label_embedding`` within the caller's
    current scope.

    Returns:
      ``(class_means, label_embeddings)``. The second element is ``None`` when
      ``label_embeddings`` is ``None``.
    """
    if is_training:
        num_tasks, num_classes, num_shots = (
            flags.num_tasks_per_batch, flags.num_classes_train, flags.num_shots_train)
    else:
        num_tasks, num_classes, num_shots = 1, flags.num_classes_test, flags.num_shots_test

    task_encoding = tf.reshape(embeddings, shape=(num_tasks, num_classes, num_shots, -1),
                               name='reshape_to_separate_tasks_task_encoding')
    if label_embeddings is not None:
        label_embeddings = tf.reshape(label_embeddings, shape=(num_tasks, num_classes, -1),
                                      name='reshape_to_separate_tasks_label_embedding')
    task_encoding = tf.reduce_mean(task_encoding, axis=2, keep_dims=False)
    return task_encoding, label_embeddings


def build_task_encoder(embeddings, label_embeddings, flags, is_training, querys=None, reuse=None,
                       scope='class_encoder', threshold=None):
    """Build the class representations (prototypes) used by the prototypical head.

    ``flags.task_encoder`` selects how support embeddings and label embeddings
    are combined:

    * ``'talkthrough'``: the support embeddings are returned unchanged.
    * ``'class_mean'``: the per-class mean of the support embeddings.
    * ``'fixed_alpha'``: ``flags.alpha * mean + (1 - flags.alpha) * label_embedding``.
    * ``'fixed_alpha_mlp'``: as ``'fixed_alpha'``, but the label embeddings first
      go through ``build_wordemb_transformer``.
    * ``'self_att_mlp'``: as ``'fixed_alpha_mlp'``, but the mixing weight
      ``alpha`` comes from ``build_self_attention``. Its input is selected by
      ``flags.att_input``: ``'proto'``, ``'word'``, ``'combined'``,
      ``'queryword'`` or ``'queryproto'``. The ``query*`` variants build one
      prototype set per query, giving ``[tasks, queries, classes, dim]``.

    Any other ``task_encoder`` value yields ``task_encoding = None``.

    Args:
      embeddings: flat support-set features.
      label_embeddings: one embedding per support class.
      flags: parsed command-line flags.
      is_training: selects the train or test task/class/shot counts.
      querys: query features; needed by the ``query*`` attention inputs.
      reuse: variable-reuse flag for ``scope``.
      scope: variable scope name.
      threshold: accepted for backward compatibility with callers that pass
        it; not used (the fixed mixing weight is read from ``flags.alpha``).

    Returns:
      ``(task_encoding, alpha)``. ``alpha`` is ``None`` unless
      ``task_encoder == 'self_att_mlp'``.

    Raises:
      ValueError: if ``task_encoder == 'self_att_mlp'`` and ``flags.att_input``
        is not a supported value.
    """
    conv2d_arg_scope, dropout_arg_scope = _get_scope(is_training, flags)
    alpha = None

    with conv2d_arg_scope, dropout_arg_scope:
        with tf.variable_scope(scope, reuse=reuse):

            if flags.task_encoder == 'talkthrough':
                task_encoding = embeddings
            elif flags.task_encoder == 'class_mean':
                task_encoding, _ = _average_support_per_class(embeddings, None, flags, is_training)
            elif flags.task_encoder == 'fixed_alpha':
                print("entered the word embedding task encoder...")
                task_encoding, label_embeddings = _average_support_per_class(
                    embeddings, label_embeddings, flags, is_training)
                task_encoding = flags.alpha * task_encoding + (1 - flags.alpha) * label_embeddings
            elif flags.task_encoder == 'fixed_alpha_mlp':
                print("entered the word embedding task encoder...")
                label_embeddings = build_wordemb_transformer(label_embeddings, flags, is_training)
                task_encoding, label_embeddings = _average_support_per_class(
                    embeddings, label_embeddings, flags, is_training)
                task_encoding = flags.alpha * task_encoding + (1 - flags.alpha) * label_embeddings
            elif flags.task_encoder == 'self_att_mlp':
                print("entered the word embedding task encoder...")
                label_embeddings = build_wordemb_transformer(label_embeddings, flags, is_training)
                task_encoding, label_embeddings = _average_support_per_class(
                    embeddings, label_embeddings, flags, is_training)

                att_input = flags.att_input
                if att_input == 'proto':
                    alpha = build_self_attention(task_encoding, flags, is_training)
                elif att_input == 'word':
                    alpha = build_self_attention(label_embeddings, flags, is_training)
                elif att_input == 'combined':
                    combined = tf.concat([task_encoding, label_embeddings], axis=2)
                    alpha = build_self_attention(combined, flags, is_training)
                elif att_input in ('queryword', 'queryproto'):
                    # One prototype set per query. The class-level tensors are
                    # tiled along a new query axis (1); the queries are tiled
                    # along a new class axis (2).
                    per_class = label_embeddings if att_input == 'queryword' else task_encoding
                    num_classes = per_class.get_shape().as_list()[1]
                    num_queries = querys.get_shape().as_list()[1]
                    task_encoding_tile = tf.tile(tf.expand_dims(task_encoding, axis=1), (1, num_queries, 1, 1))
                    querys_tile = tf.tile(tf.expand_dims(querys, axis=2), (1, 1, num_classes, 1))
                    label_embeddings_tile = tf.tile(tf.expand_dims(label_embeddings, axis=1),
                                                    (1, num_queries, 1, 1))
                    key_tile = label_embeddings_tile if att_input == 'queryword' else task_encoding_tile
                    alpha = build_self_attention(tf.concat([key_tile, querys_tile], axis=3), flags, is_training)
                else:
                    # Otherwise alpha would stay None and the mix below would
                    # fail with an unhelpful TypeError.
                    raise ValueError('Unknown att_input: %s' % att_input)

                if querys is None:
                    task_encoding = alpha * task_encoding + (1 - alpha) * label_embeddings
                else:
                    task_encoding = alpha * task_encoding_tile + (1 - alpha) * label_embeddings_tile

            else:
                task_encoding = None

    return task_encoding, alpha


def build_prototypical_head(features_generic, task_encoding, flags, is_training, scope='prototypical_head'):
    """
    Implements the prototypical networks few-shot head.

    Logits are the negative Euclidean distances between every query feature
    and every class prototype.

    :param features_generic: query features; rank 2 inputs get a leading task axis
    :param task_encoding: class prototypes; rank 2 inputs get a leading task axis
    :param flags: parsed flags (num_tasks_per_batch, train_batch_size)
    :param is_training: selects the output reshape
    :param scope: variable scope name
    :return: logits of shape [num_tasks_per_batch * train_batch_size, num_classes]
             when training, else [num_queries, num_classes]
    """
    with tf.variable_scope(scope):

        if len(features_generic.get_shape().as_list()) == 2:
            features_generic = tf.expand_dims(features_generic, axis=0)
        if len(task_encoding.get_shape().as_list()) == 2:
            task_encoding = tf.expand_dims(task_encoding, axis=0)

        # j is the number of prototypes (task_encoding axis 1);
        # i is the number of queries (features_generic axis 1).
        j = task_encoding.get_shape().as_list()[1]
        i = features_generic.get_shape().as_list()[1]

        # Tile both to [tasks, i, j, dim] so every (query, prototype) pair is compared.
        features_generic = tf.expand_dims(features_generic, axis=2)
        task_encoding = tf.expand_dims(task_encoding, axis=1)
        # features_generic changes over i and is constant over j;
        # task_encoding changes over j and is constant over i.
        task_encoding_tile = tf.tile(task_encoding, (1, i, 1, 1))
        features_generic_tile = tf.tile(features_generic, (1, 1, j, 1))
        # implement equation (4)
        euclidian = -tf.norm(task_encoding_tile - features_generic_tile, name='neg_euclidian_distance', axis=-1)

        if is_training:
            euclidian = tf.reshape(euclidian, shape=(flags.num_tasks_per_batch * flags.train_batch_size, -1))
        else:
            euclidian_shape = euclidian.get_shape().as_list()
            euclidian = tf.reshape(euclidian, shape=(euclidian_shape[1], -1))

    return euclidian


def build_prototypical_head_protoperquery(features_generic, task_encoding, flags, is_training, scope='prototypical_head'):
    """
    Prototypical head for query-conditioned prototypes (one prototype set per query).

    :param features_generic: query features; rank 2 inputs get a leading task axis
    :param task_encoding: per-query prototypes, [num_tasks, num_queries, num_classes, dim]
                          (as produced by the query* attention inputs of build_task_encoder)
    :param flags: parsed flags (num_tasks_per_batch, train_batch_size)
    :param is_training: selects the output reshape
    :param scope: variable scope name
    :return: negative Euclidean distance logits, one row per query
    """
    with tf.variable_scope(scope):

        if len(features_generic.get_shape().as_list()) == 2:
            features_generic = tf.expand_dims(features_generic, axis=0)
        if len(task_encoding.get_shape().as_list()) == 2:
            task_encoding = tf.expand_dims(task_encoding, axis=0)

        # j is the number of prototypes per query (task_encoding axis 2);
        # i is the number of queries (features_generic axis 1).
        j = task_encoding.get_shape().as_list()[2]
        i = features_generic.get_shape().as_list()[1]

        # task_encoding is already [tasks, i, j, dim]; only the queries need tiling over j.
        features_generic = tf.expand_dims(features_generic, axis=2)
        features_generic_tile = tf.tile(features_generic, (1, 1, j, 1))
        # implement equation (4)
        euclidian = -tf.norm(task_encoding - features_generic_tile, name='neg_euclidian_distance', axis=-1)

        if is_training:
            euclidian = tf.reshape(euclidian, shape=(flags.num_tasks_per_batch * flags.train_batch_size, -1))
        else:
            euclidian_shape = euclidian.get_shape().as_list()
            euclidian = tf.reshape(euclidian, shape=(euclidian_shape[1], -1))

    return euclidian


def build_regularizer_head(embeddings, label_embeddings, flags, is_training, scope='regularizer_head'):
    """
    Negative Euclidean distances between class-mean prototypes and label embeddings.

    :param embeddings: flat support-set features
    :param label_embeddings: one embedding per support class
    :param flags: parsed flags (task/class/shot counts)
    :param is_training: selects the train or test counts and the output reshape
    :param scope: variable scope name
    :return: logits with one row per (task, prototype) and one column per label embedding
    """
    with tf.variable_scope(scope):
        task_encoding, label_embeddings = _average_support_per_class(
            embeddings, label_embeddings, flags, is_training)

        # j is the number of prototypes (task_encoding axis 1);
        # i is the number of label embeddings (label_embeddings axis 1).
        j = task_encoding.get_shape().as_list()[1]
        i = label_embeddings.get_shape().as_list()[1]

        # Tile both to [tasks, prototypes, labels, dim].
        task_encoding = tf.expand_dims(task_encoding, axis=2)
        label_embeddings = tf.expand_dims(label_embeddings, axis=1)
        label_embeddings_tile = tf.tile(label_embeddings, (1, i, 1, 1))
        task_encoding_tile = tf.tile(task_encoding, (1, 1, j, 1))
        # implement equation (4)
        euclidian = -tf.norm(task_encoding_tile - label_embeddings_tile, name='neg_euclidian_distance_regularizer',
                             axis=-1)

        if is_training:
            euclidian = tf.reshape(euclidian, shape=(flags.num_tasks_per_batch * flags.num_classes_train, -1))
        else:
            euclidian_shape = euclidian.get_shape().as_list()
            euclidian = tf.reshape(euclidian, shape=(euclidian_shape[1], -1))

    return euclidian


def placeholder_inputs(batch_size, image_size, scope):
    """
    :param batch_size: leading dimension of both placeholders
    :param image_size: side length of the square RGB images
    :param scope: variable scope the placeholders are created in
    :return: (images_placeholder float32 [batch, size, size, 3], labels_placeholder int64)
    """
    with tf.variable_scope(scope):
        images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, image_size, image_size, 3), name='images')
        # shape is the bare batch_size value (not a 1-tuple), exactly as before.
        labels_placeholder = tf.placeholder(tf.int64, shape=batch_size, name='labels')
    return images_placeholder, labels_placeholder


def get_batch(data_set, images_placeholder, labels_placeholder, batch_size):
    """
    :param data_set: object providing next_batch(batch_size) -> (images, labels)
    :param images_placeholder: placeholder fed with the images (cast to float32)
    :param labels_placeholder: placeholder fed with the labels
    :param batch_size: number of samples to draw
    :return: feed_dict for session.run
    """
    images_feed, labels_feed = data_set.next_batch(batch_size)

    feed_dict = {
        images_placeholder: images_feed.astype(dtype=np.float32),
        labels_placeholder: labels_feed,
    }
    return feed_dict


def preprocess(images):
    """Scale images by dividing them by a constant per-channel std of 0.5 (no mean subtraction)."""
    std = tf.constant(np.asarray([0.5, 0.5, 0.5]).reshape([1, 1, 3]), dtype=tf.float32, name='image_std')
    return tf.div(images, std)


def get_nearest_neighbour_acc(flags, embeddings, labels):
    """Estimate few-shot accuracy of a 1-nearest-neighbour classifier on fixed embeddings.

    Each of the ``flags.num_cases_test`` episodes proceeds as follows:
      1. sample ``flags.num_classes_test`` classes;
      2. draw support/query indices with ``get_few_shot_idxs``
         (``flags.num_shots_test`` shots);
      3. keep at most 100 queries;
      4. divide every embedding by its own standard deviation;
      5. give each query the label of its nearest support embedding, found
         with a KD-tree.

    Args:
      flags: parsed flags providing num_cases_test, num_classes_test, num_shots_test.
      embeddings: numpy array indexed by sample on axis 0; normalised along
        axis 1 (presumably [num_samples, dim] -- TODO confirm).
      labels: numpy array of class labels aligned with ``embeddings``.

    Returns:
      Accuracy in percent over all evaluated queries.
    """
    # NOTE(review): get_few_shot_idxs is neither defined nor imported in the
    # visible part of this file -- confirm it is provided elsewhere.
    candidate_classes = np.unique(labels)  # loop-invariant
    num_correct = 0
    num_tot = 0
    for _ in trange(flags.num_cases_test):
        test_classes = np.random.choice(candidate_classes, size=flags.num_classes_test, replace=False)
        train_idxs, test_idxs = get_few_shot_idxs(labels=labels, classes=test_classes, num_shots=flags.num_shots_test)
        # TODO: subsampling avoids an OOM error and can be removed once embed()
        # supports batch processing. Sampling without replacement needs
        # size <= population, so cap it at the number of available queries.
        test_idxs = np.random.choice(test_idxs, size=min(100, len(test_idxs)), replace=False)

        np_embedding_train = embeddings[train_idxs]
        # Using the np.std instead of np.linalg.norm improves results by around 1-1.5%
        np_embedding_train = np_embedding_train / np.std(np_embedding_train, axis=1, keepdims=True)
        labels_train = labels[train_idxs]

        np_embedding_test = embeddings[test_idxs]
        np_embedding_test = np_embedding_test / np.std(np_embedding_test, axis=1, keepdims=True)
        labels_test = labels[test_idxs]

        kdtree = KDTree(np_embedding_train)
        _, nn_idxs = kdtree.query(np_embedding_test, k=1)
        labels_predicted = labels_train[nn_idxs]

        num_correct += np.count_nonzero(labels_predicted == labels_test)
        num_tot += len(labels_predicted)

    return (100.0 * num_correct) / num_tot


def build_inference_graph(images_deploy_pl, images_task_encode_pl, flags, is_training,
                          is_primary, label_embeddings):
    """Build the few-shot inference graph under the ``Model`` variable scope.

    Steps:
      1. Embed the support images (task encoding) and the query images
         (deploy). The two feature extractors share weights when
         ``flags.encoder_sharing == 'shared'``.
      2. Build class prototypes with ``build_task_encoder``.
      3. Score the queries with a prototypical head.

    Returns:
      ``(logits, None, features_task_encode, features_generic, alpha)``.

    Raises:
      Exception: for an unsupported ``encoder_sharing`` or
        ``encoder_classifier_link`` option.
    """
    num_filters = [round(flags.num_filters * pow(flags.block_size_growth, i)) for i in range(flags.num_blocks)]
    reuse = not is_primary
    alpha = None

    with tf.variable_scope('Model'):
        feature_extractor_encoding_scope = 'feature_extractor_encoder'

        features_task_encode = build_feature_extractor_graph(images=images_task_encode_pl, flags=flags,
                                                             is_training=is_training,
                                                             num_filters=num_filters,
                                                             scope=feature_extractor_encoding_scope,
                                                             reuse=False)
        if flags.encoder_sharing == 'shared':
            encoder_reuse = True
            feature_extractor_classifier_scope = feature_extractor_encoding_scope
        elif flags.encoder_sharing == 'siamese':
            # TODO: in the case of pretrained feature extractor this is not good,
            # because the classfier part will be randomly initialized
            encoder_reuse = False
            feature_extractor_classifier_scope = 'feature_extractor_classifier'
        else:
            raise Exception('Option not implemented')

        if flags.encoder_classifier_link == 'prototypical':
            features_generic = build_feature_extractor_graph(images=images_deploy_pl, flags=flags,
                                                             is_training=is_training,
                                                             scope=feature_extractor_classifier_scope,
                                                             num_filters=num_filters,
                                                             reuse=encoder_reuse)
            querys = features_generic if 'query' in flags.att_input else None
            # No ``threshold`` argument is passed: build_task_encoder reads the
            # fixed mixing weight from flags.alpha itself.
            task_encoding, alpha = build_task_encoder(embeddings=features_task_encode,
                                                      label_embeddings=label_embeddings,
                                                      flags=flags, is_training=is_training, reuse=reuse,
                                                      querys=querys)
            if 'query' in flags.att_input:
                logits = build_prototypical_head_protoperquery(features_generic, task_encoding, flags,
                                                               is_training=is_training)
            else:
                logits = build_prototypical_head(features_generic, task_encoding, flags, is_training=is_training)
        else:
            raise Exception('Option not implemented')

    return logits, None, features_task_encode, features_generic, alpha


def get_train_datasets(flags):
    mini_imagenet = _load_mini_imagenet(data_dir=flags.data_dir, split='sources')
    few_shot_data_train = Dataset(mini_imagenet)
    pretrain_data_train, pretrain_data_test = None, None
802 | return few_shot_data_train, pretrain_data_train, pretrain_data_test 803 | 804 | 805 | def get_pwc_learning_rate(global_step, flags): 806 | learning_rate = tf.train.piecewise_constant(global_step, [np.int64(flags.number_of_steps / 2), 807 | np.int64( 808 | flags.number_of_steps / 2 + flags.num_steps_decay_pwc), 809 | np.int64( 810 | flags.number_of_steps / 2 + 2 * flags.num_steps_decay_pwc)], 811 | [flags.init_learning_rate, flags.init_learning_rate * 0.1, 812 | flags.init_learning_rate * 0.01, 813 | flags.init_learning_rate * 0.001]) 814 | return learning_rate 815 | 816 | 817 | def create_hard_negative_batch(misclass, feed_dict, sess, few_shot_data_train, flags, 818 | images_deploy_pl, labels_deploy_pl, images_task_encode_pl, labels_task_encode_pl): 819 | """ 820 | 821 | :param logits: 822 | :param feed_dict: 823 | :param sess: 824 | :param few_shot_data_train: 825 | :param flags: 826 | :param images_deploy_pl: 827 | :param labels_deploy_pl: 828 | :param images_task_encode_pl: 829 | :param labels_task_encode_pl: 830 | :return: 831 | """ 832 | feed_dict_test = dict(feed_dict) 833 | misclass_test_final = 0.0 834 | misclass_history = np.zeros(flags.num_batches_neg_mining) 835 | for i in range(flags.num_batches_neg_mining): 836 | images_deploy, labels_deploy, images_task_encode, labels_task_encode = \ 837 | few_shot_data_train.next_few_shot_batch(deploy_batch_size=flags.train_batch_size, 838 | num_classes_test=flags.num_classes_train, 839 | num_shots=flags.num_shots_train, 840 | num_tasks=flags.num_tasks_per_batch) 841 | 842 | feed_dict_test[images_deploy_pl] = images_deploy.astype(dtype=np.float32) 843 | feed_dict_test[labels_deploy_pl] = labels_deploy 844 | feed_dict_test[images_task_encode_pl] = images_task_encode.astype(dtype=np.float32) 845 | feed_dict_test[labels_task_encode_pl] = labels_task_encode 846 | 847 | # logits 848 | misclass_test = sess.run(misclass, feed_dict=feed_dict_test) 849 | misclass_history[i] = misclass_test 850 | if misclass_test > 
misclass_test_final: 851 | misclass_test_final = misclass_test 852 | feed_dict = dict(feed_dict_test) 853 | 854 | return feed_dict 855 | 856 | 857 | def train(flags): 858 | log_dir = get_logdir_name(flags) 859 | flags.pretrained_model_dir = log_dir 860 | fout=open(log_dir+'/out','a') 861 | log_dir = os.path.join(log_dir, 'train') 862 | # This is setting to run evaluation loop only once 863 | flags.max_number_of_evaluations = 1 864 | flags.eval_interval_secs = 0 865 | image_size = get_image_size(flags.data_dir) 866 | 867 | with tf.Graph().as_default(): 868 | global_step = tf.Variable(0, trainable=False, name='global_step', dtype=tf.int64) 869 | global_step_pretrain = tf.Variable(0, trainable=False, name='global_step_pretrain', dtype=tf.int64) 870 | 871 | images_deploy_pl, labels_deploy_pl = placeholder_inputs( 872 | batch_size=flags.num_tasks_per_batch * flags.train_batch_size, 873 | image_size=image_size, scope='inputs/deploy') 874 | images_task_encode_pl, _ = placeholder_inputs( 875 | batch_size=flags.num_tasks_per_batch * flags.num_classes_train * flags.num_shots_train, 876 | image_size=image_size, scope='inputs/task_encode') 877 | with tf.variable_scope('inputs/task_encode'): 878 | labels_task_encode_pl_real = tf.placeholder(tf.int64, 879 | shape=(flags.num_tasks_per_batch * flags.num_classes_train), name='labels_real') 880 | labels_task_encode_pl = tf.placeholder(tf.int64, 881 | shape=(flags.num_tasks_per_batch * flags.num_classes_train), 882 | name='labels') 883 | 884 | #here is the word embedding layer for training 885 | 886 | emb_path = os.path.join(flags.data_dir, 'few-shot-wordemb-{}.npz'.format("train")) 887 | embedding_train = np.load(emb_path)["features"].astype(np.float32) 888 | print(embedding_train.dtype) 889 | logging.info("Loading mini-imagenet...") 890 | W_train = tf.constant(embedding_train, name="W_train") 891 | label_embeddings_train = tf.nn.embedding_lookup(W_train, labels_task_encode_pl_real) 892 | 893 | # Primary task operations 894 | 
logits, regularizer_logits, _, _, alpha = build_inference_graph(images_deploy_pl=images_deploy_pl, 895 | images_task_encode_pl=images_task_encode_pl, 896 | flags=flags, is_training=True, is_primary=True, 897 | label_embeddings=label_embeddings_train) 898 | loss = tf.reduce_mean( 899 | tf.nn.softmax_cross_entropy_with_logits(logits=logits, 900 | labels=tf.one_hot(labels_deploy_pl, flags.num_classes_train))) 901 | # Losses and optimizer 902 | regu_losses = slim.losses.get_regularization_losses() 903 | loss = tf.add_n([loss] + regu_losses) 904 | misclass = 1.0 - slim.metrics.accuracy(tf.argmax(logits, 1), labels_deploy_pl) 905 | 906 | # Learning rate 907 | if flags.lr_anneal == 'const': 908 | learning_rate = flags.init_learning_rate 909 | elif flags.lr_anneal == 'pwc': 910 | learning_rate = get_pwc_learning_rate(global_step, flags) 911 | elif flags.lr_anneal == 'exp': 912 | lr_decay_step = flags.number_of_steps // flags.n_lr_decay 913 | learning_rate = tf.train.exponential_decay(flags.init_learning_rate, global_step, lr_decay_step, 914 | 1.0 / flags.lr_decay_rate, staircase=True) 915 | else: 916 | raise Exception('Not implemented') 917 | 918 | # Optimizer 919 | if flags.optimizer == 'sgd': 920 | optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9) 921 | else: 922 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 923 | 924 | train_op = slim.learning.create_train_op(total_loss=loss, optimizer=optimizer, global_step=global_step, 925 | clip_gradient_norm=flags.clip_gradient_norm) 926 | 927 | tf.summary.scalar('loss', loss) 928 | tf.summary.scalar('misclassification', misclass) 929 | tf.summary.scalar('learning_rate', learning_rate) 930 | # Merge all summaries except for pretrain 931 | summary = tf.summary.merge(tf.get_collection('summaries', scope='(?!pretrain).*')) 932 | 933 | 934 | # Get datasets 935 | few_shot_data_train, pretrain_data_train, pretrain_data_test = get_train_datasets(flags) 936 | # Define session and logging 
937 | summary_writer = tf.summary.FileWriter(log_dir, flush_secs=1) 938 | saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True) 939 | run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 940 | run_metadata = tf.RunMetadata() 941 | supervisor = tf.train.Supervisor(logdir=log_dir, init_feed_dict=None, 942 | summary_op=None, 943 | init_op=tf.global_variables_initializer(), 944 | summary_writer=summary_writer, 945 | saver=saver, 946 | global_step=global_step, save_summaries_secs=flags.save_summaries_secs, 947 | save_model_secs=0) # flags.save_interval_secs 948 | 949 | with supervisor.managed_session() as sess: 950 | checkpoint_step = sess.run(global_step) 951 | if checkpoint_step > 0: 952 | checkpoint_step += 1 953 | 954 | eval_interval_steps = flags.eval_interval_steps 955 | for step in range(checkpoint_step, flags.number_of_steps): 956 | # get batch of data to compute classification loss 957 | images_deploy, labels_deploy, images_task_encode, labels_task_encode_real, labels_task_encode = \ 958 | few_shot_data_train.next_few_shot_batch_wordemb(deploy_batch_size=flags.train_batch_size, 959 | num_classes_test=flags.num_classes_train, 960 | num_shots=flags.num_shots_train, 961 | num_tasks=flags.num_tasks_per_batch) 962 | if flags.augment: 963 | images_deploy = image_augment(images_deploy) 964 | images_task_encode = image_augment(images_task_encode) 965 | 966 | feed_dict = {images_deploy_pl: images_deploy.astype(dtype=np.float32), labels_deploy_pl: labels_deploy, 967 | images_task_encode_pl: images_task_encode.astype(dtype=np.float32), 968 | labels_task_encode_pl_real: labels_task_encode_real, 969 | labels_task_encode_pl: labels_task_encode} 970 | 971 | 972 | t_batch = time.time() 973 | feed_dict = create_hard_negative_batch(misclass, feed_dict, sess, few_shot_data_train, flags, 974 | images_deploy_pl, labels_deploy_pl, images_task_encode_pl, 975 | labels_task_encode_pl_real) 976 | dt_batch = time.time() - t_batch 977 | 978 | t_train = time.time() 
979 | loss,alpha_np = sess.run([train_op,alpha], feed_dict=feed_dict) 980 | dt_train = time.time() - t_train 981 | 982 | if step % 100 == 0: 983 | summary_str = sess.run(summary, feed_dict=feed_dict) 984 | summary_writer.add_summary(summary_str, step) 985 | summary_writer.flush() 986 | logging.info("step %d, loss : %.4g, dt: %.3gs, dt_batch: %.3gs" % (step, loss, dt_train, dt_batch)) 987 | fout.write("step: "+str(step)+' loss: '+str(loss)+'\n') 988 | 989 | if float(step) / flags.number_of_steps > 0.5: 990 | eval_interval_steps = flags.eval_interval_fine_steps 991 | 992 | if eval_interval_steps > 0 and step % eval_interval_steps == 0: 993 | saver.save(sess, os.path.join(log_dir, 'model'), global_step=step) 994 | eval(flags=flags, is_primary=True, fout=fout) 995 | 996 | if float(step) > 0.5 * flags.number_of_steps + flags.number_of_steps_to_early_stop: 997 | break 998 | 999 | 1000 | 1001 | class ModelLoader: 1002 | def __init__(self, model_path, batch_size, is_primary, split): 1003 | self.batch_size = batch_size 1004 | 1005 | latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=os.path.join(model_path, 'train')) 1006 | step = int(os.path.basename(latest_checkpoint).split('-')[1]) 1007 | 1008 | flags = Namespace(load_and_save_params(default_params=dict(), exp_dir=model_path)) 1009 | image_size = get_image_size(flags.data_dir) 1010 | 1011 | with tf.Graph().as_default(): 1012 | images_deploy_pl, labels_deploy_pl = placeholder_inputs(batch_size=batch_size, 1013 | image_size=image_size, scope='inputs/deploy') 1014 | if is_primary: 1015 | task_encode_batch_size = flags.num_classes_test * flags.num_shots_test 1016 | images_task_encode_pl, _ = placeholder_inputs(batch_size=task_encode_batch_size, 1017 | image_size=image_size, 1018 | scope='inputs/task_encode') 1019 | with tf.variable_scope('inputs/task_encode'): 1020 | labels_task_encode_pl_real = tf.placeholder(tf.int64, 1021 | shape=(flags.num_classes_test), name='labels_real') 1022 | labels_task_encode_pl = 
tf.placeholder(tf.int64, 1023 | shape=(flags.num_classes_test), 1024 | name='labels') 1025 | self.vocab_size = tf.placeholder(tf.float32, shape=(), name='vocab_size') 1026 | self.tensor_images_deploy = images_deploy_pl 1027 | self.tensor_labels_deploy = labels_deploy_pl 1028 | self.tensor_labels_task_encode_real = labels_task_encode_pl_real 1029 | self.tensor_labels_task_encode = labels_task_encode_pl 1030 | self.tensor_images_task_encode = images_task_encode_pl 1031 | 1032 | emb_path = os.path.join(flags.data_dir, 'few-shot-wordemb-{}.npz'.format(split)) 1033 | embedding_train = np.load(emb_path)["features"].astype(np.float32) 1034 | print(embedding_train.dtype) 1035 | logging.info("Loading mini-imagenet...") 1036 | W = tf.constant(embedding_train, name="W_"+split) 1037 | 1038 | 1039 | label_embeddings_train = tf.nn.embedding_lookup(W, labels_task_encode_pl_real) 1040 | 1041 | # Primary task operations 1042 | logits, regularizer_logits, features_sample, features_query, self.alpha = build_inference_graph(images_deploy_pl=images_deploy_pl, 1043 | images_task_encode_pl=images_task_encode_pl, 1044 | flags=flags, is_training=False, is_primary=True, 1045 | label_embeddings=label_embeddings_train) 1046 | loss = tf.reduce_mean( 1047 | tf.nn.softmax_cross_entropy_with_logits(logits=logits, 1048 | labels=tf.one_hot(labels_deploy_pl, flags.num_classes_test))) 1049 | regularizer_loss = 0.0 1050 | 1051 | # Losses and optimizer 1052 | regu_losses = slim.losses.get_regularization_losses() 1053 | 1054 | loss = tf.add_n([loss] + regu_losses + [regularizer_loss]) 1055 | 1056 | init_fn = slim.assign_from_checkpoint_fn( 1057 | latest_checkpoint, 1058 | slim.get_model_variables('Model')) 1059 | 1060 | config = tf.ConfigProto(allow_soft_placement=True) 1061 | config.gpu_options.allow_growth = True 1062 | self.sess = tf.Session(config=config) 1063 | 1064 | # Run init before loading the weights 1065 | self.sess.run(tf.global_variables_initializer()) 1066 | # Load weights 1067 | 
init_fn(self.sess) 1068 | 1069 | self.flags = flags 1070 | self.logits = logits 1071 | self.loss = loss 1072 | self.features_sample = features_sample 1073 | self.features_query = features_query 1074 | self.logits_size = self.logits.get_shape().as_list()[-1] 1075 | self.step = step 1076 | self.is_primary = is_primary 1077 | 1078 | log_dir = get_logdir_name(flags) 1079 | graphpb_txt = str(tf.get_default_graph().as_graph_def()) 1080 | pathlib.Path(os.path.join(log_dir, 'eval')).mkdir(parents=True, exist_ok=True) 1081 | with open(os.path.join(log_dir, 'eval', 'graph.pbtxt'), 'w') as f: 1082 | f.write(graphpb_txt) 1083 | 1084 | def eval(self, data_dir, num_cases_test, split='target_val'): 1085 | data_set = Dataset(_load_mini_imagenet(data_dir=data_dir, split=split)) 1086 | 1087 | num_batches = num_cases_test // self.batch_size 1088 | num_correct = 0.0 1089 | num_tot = 0.0 1090 | loss_tot = 0.0 1091 | final_alpha=[] 1092 | for i in range(num_batches): 1093 | num_classes, num_shots = self.flags.num_classes_test, self.flags.num_shots_test 1094 | 1095 | images_deploy, labels_deploy, images_task_encode, labels_task_encode_real, labels_task_encode = \ 1096 | data_set.next_few_shot_batch_wordemb(deploy_batch_size=self.batch_size, 1097 | num_classes_test=num_classes, num_shots=num_shots, 1098 | num_tasks=1) 1099 | 1100 | 1101 | feed_dict = {self.tensor_images_deploy: images_deploy.astype(dtype=np.float32), 1102 | self.tensor_labels_task_encode_real: labels_task_encode_real, 1103 | self.tensor_labels_deploy: labels_deploy, 1104 | self.tensor_labels_task_encode: labels_task_encode, 1105 | self.tensor_images_task_encode: images_task_encode.astype(dtype=np.float32)} 1106 | [logits, loss, alpha] = self.sess.run([self.logits, self.loss, self.alpha], feed_dict) 1107 | final_alpha.append(alpha) 1108 | labels_deploy_pred = np.argmax(logits, axis=-1) 1109 | 1110 | num_matches = sum(labels_deploy_pred == labels_deploy) 1111 | num_correct += num_matches 1112 | num_tot += 
len(labels_deploy_pred) 1113 | loss_tot += loss 1114 | if split=='target_tst': 1115 | log_dir = get_logdir_name(self.flags) 1116 | pathlib.Path(os.path.join(log_dir, 'eval')).mkdir(parents=True, exist_ok=True) 1117 | pkl.dump(final_alpha,open(os.path.join(log_dir, 'eval', 'lambdas.pkl'), "wb")) 1118 | 1119 | return num_correct / num_tot, loss_tot / num_batches 1120 | 1121 | 1122 | def get_few_shot_idxs(labels, classes, num_shots): 1123 | train_idxs, test_idxs = [], [] 1124 | idxs = np.arange(len(labels)) 1125 | for cl in classes: 1126 | class_idxs = idxs[labels == cl] 1127 | class_idxs_train = np.random.choice(class_idxs, size=num_shots, replace=False) 1128 | class_idxs_test = np.setxor1d(class_idxs, class_idxs_train) 1129 | 1130 | train_idxs.extend(class_idxs_train) 1131 | test_idxs.extend(class_idxs_test) 1132 | 1133 | assert set(class_idxs_train).isdisjoint(test_idxs) 1134 | 1135 | return np.array(train_idxs), np.array(test_idxs) 1136 | 1137 | 1138 | def test(flags): 1139 | test_dataset = _load_mini_imagenet(data_dir=flags.data_dir, split='target_val') 1140 | 1141 | # test_dataset = _load_mini_imagenet(data_dir=flags.data_dir, split='sources') 1142 | images = test_dataset[0] 1143 | labels = test_dataset[1] 1144 | 1145 | embedding_model = ModelLoader(flags.pretrained_model_dir, batch_size=100) 1146 | embeddings = embedding_model.embed(images=test_dataset[0]) 1147 | embedding_model = None 1148 | print("Accuracy test raw embedding: ", get_nearest_neighbour_acc(flags, embeddings, labels)) 1149 | 1150 | 1151 | def get_agg_misclassification(logits_dict, labels_dict): 1152 | summary_ops = [] 1153 | update_ops = {} 1154 | for key, logits in logits_dict.items(): 1155 | accuracy, update = slim.metrics.streaming_accuracy(tf.argmax(logits, 1), labels_dict[key]) 1156 | 1157 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map( 1158 | {'misclassification_' + key: (1.0 - accuracy, update)}) 1159 | 1160 | for metric_name, metric_value in 
names_to_values.items(): 1161 | op = tf.summary.scalar(metric_name, metric_value) 1162 | op = tf.Print(op, [metric_value], metric_name) 1163 | summary_ops.append(op) 1164 | 1165 | for update_name, update_op in names_to_updates.items(): 1166 | update_ops[update_name] = update_op 1167 | return summary_ops, update_ops 1168 | 1169 | 1170 | def eval(flags, is_primary, fout): 1171 | log_dir = get_logdir_name(flags) 1172 | if is_primary: 1173 | aux_prefix = '' 1174 | else: 1175 | aux_prefix = 'aux/' 1176 | 1177 | eval_writer = summary_writer(log_dir + '/eval') 1178 | i = 0 1179 | last_step = -1 1180 | while i < flags.max_number_of_evaluations: 1181 | latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=flags.pretrained_model_dir) 1182 | model_step = int(os.path.basename(latest_checkpoint or '0-0').split('-')[1]) 1183 | if last_step < model_step: 1184 | results = {} 1185 | model_train = ModelLoader(model_path=flags.pretrained_model_dir, batch_size=flags.eval_batch_size, 1186 | is_primary=is_primary,split='train') 1187 | acc_trn, loss_trn = model_train.eval(data_dir=flags.data_dir, num_cases_test=flags.num_cases_test, 1188 | split='sources') 1189 | 1190 | model_val = ModelLoader(model_path=flags.pretrained_model_dir, batch_size=flags.eval_batch_size, 1191 | is_primary=is_primary, split='val') 1192 | acc_val, loss_val = model_val.eval(data_dir=flags.data_dir, num_cases_test=flags.num_cases_test, 1193 | split='target_val') 1194 | 1195 | model_test = ModelLoader(model_path=flags.pretrained_model_dir, batch_size=flags.eval_batch_size, 1196 | is_primary=is_primary, split='test') 1197 | acc_tst, loss_tst = model_test.eval(data_dir=flags.data_dir, num_cases_test=flags.num_cases_test, 1198 | split='target_tst') 1199 | 1200 | results[aux_prefix + "accuracy_target_tst"] = acc_tst 1201 | results[aux_prefix + "accuracy_target_val"] = acc_val 1202 | results[aux_prefix + "accuracy_sources"] = acc_trn 1203 | 1204 | results[aux_prefix + "loss_target_tst"] = loss_tst 1205 | 
results[aux_prefix + "loss_target_val"] = loss_val 1206 | results[aux_prefix + "loss_sources"] = loss_trn 1207 | 1208 | last_step = model_train.step 1209 | eval_writer(model_train.step, **results) 1210 | logging.info("accuracy_%s: %.3g, accuracy_%s: %.3g, accuracy_%s: %.3g, loss_%s: %.3g, loss_%s: %.3g, loss_%s: %.3g." 1211 | % ( 1212 | aux_prefix + "target_tst", acc_tst, aux_prefix + "target_val", acc_val, aux_prefix + "sources", 1213 | acc_trn, aux_prefix + "target_tst", loss_tst, aux_prefix + "target_val", loss_val, aux_prefix + "sources", 1214 | loss_trn)) 1215 | fout.write("accuracy_test: "+str(acc_tst)+" accuracy_val: "+str(acc_val)+" accuracy_test: "+str(acc_trn)) 1216 | if flags.eval_interval_secs > 0: 1217 | time.sleep(flags.eval_interval_secs) 1218 | i = i + 1 1219 | 1220 | 1221 | 1222 | 1223 | 1224 | def image_augment(images): 1225 | """ 1226 | 1227 | :param images: 1228 | :return: 1229 | """ 1230 | pad_percent = 0.125 1231 | flip_proba = 0.5 1232 | image_size = images.shape[1] 1233 | pad_size = int(pad_percent * image_size) 1234 | max_crop = 2 * pad_size 1235 | 1236 | images_aug = np.pad(images, ((0, 0), (pad_size, pad_size), (pad_size, pad_size), (0, 0)), mode='constant') 1237 | output = [] 1238 | for image in images_aug: 1239 | if np.random.rand() < flip_proba: 1240 | image = np.flip(image, axis=1) 1241 | crop_val = np.random.randint(0, max_crop) 1242 | image = image[crop_val:crop_val + image_size, crop_val:crop_val + image_size, :] 1243 | output.append(image) 1244 | return np.asarray(output) 1245 | 1246 | 1247 | def main(argv=None): 1248 | config = tf.ConfigProto(allow_soft_placement=True) 1249 | config.gpu_options.per_process_gpu_memory_fraction = 1.0 1250 | config.gpu_options.allow_growth = True 1251 | sess = tf.Session(config=config) 1252 | 1253 | print(os.getcwd()) 1254 | 1255 | default_params = get_arguments() 1256 | log_dir = get_logdir_name(flags=default_params) 1257 | 1258 | pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) 1259 | # 
This makes sure that we can store a json and recove a namespace back 1260 | flags = Namespace(load_and_save_params(vars(default_params), log_dir)) 1261 | 1262 | if flags.mode == 'train': 1263 | train(flags=flags) 1264 | elif flags.mode == 'eval': 1265 | eval(flags=flags, is_primary=True) 1266 | elif flags.mode == 'test': 1267 | test(flags=flags) 1268 | 1269 | 1270 | if __name__ == '__main__': 1271 | tf.app.run() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 ServiceNow 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2021 ServiceNow, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
# =============================================================================
# /common/gen_experiments.py
# =============================================================================
from itertools import product
import os
import json
from subprocess import Popen
import logging


def find_variables(param_dict):
    """Find items in the dictionary that are lists and consider them as variables.

    Returns:
        The keys of ``param_dict`` whose value is a ``list``, in iteration order.
    """
    variables = []
    for key, val in param_dict.items():
        if isinstance(val, list):
            variables.append(key)
    return variables


def grid_search(param_dict):
    """Find the variables in param_dict and yield every instance of their cartesian product.

    Args:
        param_dict: dictionary of parameters. Every item whose value is a list is
            treated as a variable to cross-validate over.

    Yields:
        A dictionary mapping each list-valued key to one of its elements. Keys whose
        value is not a list are not included.
    """
    variables = []
    for key, val in param_dict.items():
        if isinstance(val, list):
            variables.append([(key, element) for element in val])

    for experiment in product(*variables):
        yield dict(experiment)


def make_experiment_name(experiment):
    """Create a readable ``name=value;name=value`` string from an experiment dict.

    Float values are formatted with ``%.4g``; everything else with ``%s``.
    """
    args = []
    for name, value in experiment.items():
        if isinstance(value, float):
            args.append("%s=%.4g" % (name, value))
        else:
            args.append("%s=%s" % (name, value))
    return ';'.join(args)


def gen_experiments_dir(param_dict, root_dir, exp_description, cmd=None, blocking=False, borgy_args=None):
    """Generate one directory (with params.json) per experiment and optionally launch cmd in it.

    For every experiment yielded by ``grid_search(param_dict)``, the directory
    ``'/mnt/' + os.path.join(root_dir, <experiment name>)`` is created and the
    fully resolved parameters are written to ``params.json`` inside it. When ``cmd``
    is given it is started with ``--exp_dir=<dir>``, either through ``borgy submit``
    (when ``borgy_args`` is truthy, using the path without the ``/mnt/`` prefix) or
    as a local subprocess whose stdout/stderr go to files in the experiment dir.

    Args:
        param_dict: parameters; list-valued entries are swept over. Not modified.
        root_dir: experiment root, joined after the ``/mnt/`` prefix.
        exp_description: prefix of the borgy job name.
        cmd: executable to launch, or None to only create the directories.
        blocking: wait for every launched process before returning.
        borgy_args: extra ``borgy submit`` arguments; falsy means run locally.
    """
    process_list = []
    for i, experiment in enumerate(grid_search(param_dict)):
        name = make_experiment_name(experiment)
        print("Exp %d: %s." % (i, name))
        # Resolve this experiment's parameters on a copy so the caller's dict keeps
        # its list-valued entries instead of ending up with the last experiment's values.
        exp_params = dict(param_dict)
        exp_params.update(experiment)

        exp_dir_borgy = os.path.join(root_dir, name)
        exp_dir = '/mnt/' + exp_dir_borgy
        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)

        param_path = os.path.join(exp_dir, 'params.json')

        with open(param_path, 'w') as fd:
            json.dump(exp_params, fd, indent=4)

        if cmd is None:
            continue

        if borgy_args:
            cmd_ = """cd '%s'; stdbuf -oL '%s' --exp_dir='%s' 1>>stdout 2>>stderr""" % (
                exp_dir_borgy, cmd, exp_dir_borgy)
            args = ['borgy', 'submit', '--name', "%s_(%s)" % (exp_description, name)] + borgy_args + [
                '--', 'bash', '-c', cmd_, ]
            str_cmd = ' '.join(['"' + arg + '"' for arg in args])
            print(str_cmd)

            # Keep a copy of the submit command next to the experiment for re-submission.
            with open(os.path.join(exp_dir, 'borgy_submit.cmd'), 'w') as fd:
                fd.write(str_cmd)

            process = Popen(args)
            if blocking:
                process.wait()
            process_list.append(process)

        else:
            args = [cmd, '--exp_dir=%s' % exp_dir]
            with open(os.path.join(exp_dir, 'stderr'), 'w') as err_fd:
                with open(os.path.join(exp_dir, 'stdout'), 'w') as out_fd:
                    process_list.append(Popen(args, stderr=err_fd, stdout=out_fd))

    if blocking:
        for process in process_list:
            process.wait()


def re_run(root_dir, cmd=None, blocking=False, borgy_args=None, exp_dir_list=None):
    """Relaunch ``cmd`` for existing experiment directories under ``root_dir``.

    Args:
        root_dir: directory containing one sub-directory per experiment.
        cmd: executable launched with ``--exp_dir=<dir>``; nothing runs if None.
        blocking: wait for every launched process before returning.
        borgy_args: extra ``borgy submit`` arguments; falsy means run locally.
        exp_dir_list: sub-directory names to relaunch; defaults to every entry
            of ``root_dir``.
    """
    process_list = []

    for exp in exp_dir_list or os.listdir(root_dir):
        exp_dir = os.path.join(root_dir, exp)
        # Remove the "mnt" part: drops the first path component, "/mnt/a/b" -> "/a/b".
        # NOTE(review): the local (non-borgy) branch also uses this stripped path,
        # whereas gen_experiments_dir keeps the "/mnt/" prefix for local runs.
        exp_dir = "/" + "/".join((exp_dir.split("/")[2:]))

        if cmd is not None:

            if borgy_args:
                cmd_ = """cd '%s'; stdbuf -oL '%s' --exp_dir='%s' 1>>stdout 2>>stderr""" % (exp_dir, cmd, exp_dir)
                args = ['borgy', 'submit', '--name', "%s" % exp] + borgy_args + ['--', 'bash', '-c', cmd_, ]
                print(' '.join(args))
                process_list.append(Popen(args))

            else:
                args = [cmd, '--exp_dir=%s' % exp_dir]
                with open(os.path.join(exp_dir, 'stderr'), 'w') as err_fd:
                    with open(os.path.join(exp_dir, 'stdout'), 'w') as out_fd:
                        process_list.append(Popen(args, stderr=err_fd, stdout=out_fd))

    if blocking:
        for process in process_list:
            process.wait()


def load_and_save_params(default_params, exp_dir, ignore_existing=False):
    """Update default_params with params.json from exp_dir and overwrite params.json with the result.

    Args:
        default_params: JSON-serialisable defaults. Not modified: a deep copy is
            taken through a JSON round trip.
        exp_dir: experiment directory; created if it does not exist.
        ignore_existing: if True, an existing params.json is not read and is
            overwritten with the defaults.

    Returns:
        The merged parameter dictionary, exactly as written to ``exp_dir/params.json``.
    """
    default_params = json.loads(json.dumps(default_params))
    param_path = os.path.join(exp_dir, 'params.json')
    logging.info("Searching for '%s'", param_path)
    if os.path.exists(param_path) and not ignore_existing:
        logging.info("Loading existing params.")
        with open(param_path, 'r') as fd:
            params = json.load(fd)
        default_params.update(params)

    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)

    with open(param_path, 'w') as fd:
        json.dump(default_params, fd, indent=4)

    return default_params


# =============================================================================
# /common/util.py
# =============================================================================
import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
import os
import fnmatch


def variable_report(report_non_trainable=True):
    """Create a small report, showing the shapes of all trainable variables.

    Args:
        report_non_trainable: also list global variables that are not trainable.

    Returns:
        A newline-joined string with one line per variable and the total number
        of trainable parameters.
    """
    total_params = 0
    lines = ['Trainable Variables Report',
             '--------------------------']

    trainable_variables = tf.trainable_variables()

    for var in trainable_variables:
        shape = var.get_shape().as_list()
        num_params = np.prod(shape)
        total_params += num_params
        lines.append("shape: %15s, %5d, %s, %s" % (shape, num_params, var.name, var.dtype))
    lines.append("Total number of trainable parameters: %d" % total_params)

    if report_non_trainable:
        lines.extend(['', 'Non-Trainable Variables', '---------------------'])
        for var in tf.global_variables():
            if var in trainable_variables:
                continue
            shape = var.get_shape().as_list()
            num_params = np.prod(shape)
            lines.append("shape: %15s, %5d, %s, %s" % (shape, num_params, var.name, var.dtype))

    return '\n'.join(lines)


def variables_by_name(pattern, variable_list=None):
    """Return the variables whose ``name`` matches the shell-style ``pattern`` (fnmatch).

    ``variable_list`` defaults to ``tf.global_variables()``.
    """
    if variable_list is None:
        variable_list = tf.global_variables()
    return [var for var in variable_list if fnmatch.fnmatch(var.name, pattern)]


def unique_variable_by_name(pattern, variable_list=None):
    """Return the single variable whose name matches ``pattern``.

    Raises:
        ValueError: if no variable, or more than one variable, matches.
    """
    var_list = variables_by_name(pattern, variable_list)
    if not var_list:
        raise ValueError("No variable matches pattern %r." % (pattern,))
    if len(var_list) != 1:
        raise ValueError("Non unique variable. list = %s" % str(var_list))
    return var_list[0]


def profiled_run(sess, ops, feed_dict, is_profiling=False, log_dir=None):
    """Run ``ops`` in ``sess``; when profiling, also dump a Chrome trace to ``log_dir/timeline.json``.

    Raises:
        ValueError: if ``is_profiling`` is set but ``log_dir`` is None.
    """
    if not is_profiling:
        return sess.run(ops, feed_dict=feed_dict)
    else:
        if log_dir is None:
            raise ValueError("You need to provide a log_dir for profiling.")
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        outputs = sess.run(ops, feed_dict=feed_dict, options=run_options, run_metadata=run_metadata)

        # Create the Timeline object, and write it to a json
        tl = timeline.Timeline(run_metadata.step_stats)
        ctf = tl.generate_chrome_trace_format()
        with open(os.path.join(log_dir, 'timeline.json'), 'w') as f:
            f.write(ctf)

        return outputs


def summary_writer(log_dir):
    """Convenient wrapper for writing summaries.

    Returns:
        ``call(step, **value_dict)``, which writes each keyword as a scalar summary
        (tag = keyword) at ``step`` to a FileWriter on ``log_dir`` and flushes it.
    """
    writer = tf.summary.FileWriter(log_dir)

    def call(step, **value_dict):
        summary = tf.Summary()
        for tag, value in value_dict.items():
            summary.value.add(tag=tag, simple_value=value)
        writer.add_summary(summary, step)
        writer.flush()

    return call
def uniform(n):
    """Return a sampler drawing indices uniformly, with replacement, from ``range(n)``.

    The sampler has signature ``sampler(n_samples, rng=np.random)`` and returns an
    integer array of length ``n_samples``.
    """

    def sampler(n_samples, rng=np.random):
        return rng.choice(n, n_samples)

    return sampler


def categorical(probs):
    """Return a sampler drawing indices with probability proportional to ``probs``.

    Every entry of ``probs`` must be strictly positive (enforced with
    ``np.testing.assert_array_less``, which raises ``AssertionError``); the
    weights do not need to sum to one.
    """
    probs = np.asarray(probs)
    np.testing.assert_array_less(0, probs)
    cumsum = np.cumsum(probs)

    def sampler(n_samples, rng=np.random):
        # Inverse-CDF sampling on the un-normalised cumulative weights.
        return cumsum.searchsorted(rng.uniform(0, cumsum[-1], size=n_samples))

    return sampler


class Dataset(object):
    """Basic dataset interface with random, few-shot and triplet batch samplers."""

    def __init__(self, fields, fn=None, sampler=None):
        """Store a tuple of fields and access it through next_batch interface.

        By default, field[0] and field[1] are considered to be x and y. More fields can be
        stored but they are unnamed.

        Args:
            fields: tuple of indexable arrays sharing the same length.
            fn: stored as ``self.fn``; not used by this class.
            sampler: callable ``(n_samples, rng) -> indices`` used by ``next_batch``;
                defaults to ``uniform(len(fields[0]))``.
        """
        self.fn = fn
        self.n_samples = len(fields[0])
        self.fields = fields
        if sampler is None:
            self.sampler = uniform(self.n_samples)
        else:
            self.sampler = sampler

    @property
    def x(self):
        return self.fields[0]

    @property
    def y(self):
        return self.fields[1]

    def next_batch(self, n, rng=np.random):
        """Return ``n`` rows of every field, at indices drawn by ``self.sampler``."""
        idx = self.sampler(n, rng)
        return tuple(field[idx] for field in self.fields)

    @staticmethod
    def _class_map_fn(test_classes):
        """Return a vectorised function mapping an original label to its position in ``test_classes``."""
        class_map = {c: i for i, c in enumerate(test_classes)}
        # otypes lets the mapping run on empty label arrays: without it, np.vectorize
        # must call the function on a first element to infer the output dtype.
        return np.vectorize(lambda t: class_map[t], otypes=[int])

    def get_few_shot_idxs(self, labels, classes, num_shots):
        """Split the examples of ``classes`` into support and query indices.

        For each class (in the given order), ``num_shots`` indices are drawn without
        replacement with the global ``np.random`` state; every other index of that
        class goes to the query set.

        Returns:
            ``(train_idxs, test_idxs)`` as integer arrays; ``train_idxs`` is grouped
            by class in the order of ``classes``.
        """
        train_idxs, test_idxs = [], []
        idxs = np.arange(len(labels))
        for cl in classes:
            class_idxs = idxs[labels == cl]
            class_idxs_train = np.random.choice(class_idxs, size=num_shots, replace=False)
            class_idxs_test = np.setxor1d(class_idxs, class_idxs_train)

            train_idxs.extend(class_idxs_train)
            test_idxs.extend(class_idxs_test)

        assert set(class_idxs_train).isdisjoint(test_idxs)

        return np.array(train_idxs), np.array(test_idxs)

    def next_few_shot_batch(self, deploy_batch_size, num_classes_test, num_shots, num_tasks):
        """Sample ``num_tasks`` few-shot episodes and concatenate them along axis 0.

        For each task, ``num_classes_test`` distinct classes are drawn; ``num_shots``
        support ("task encode") examples are taken per class and ``deploy_batch_size``
        query ("deploy") examples are drawn from the remaining examples of those
        classes. Labels are remapped to ``0..num_classes_test-1`` in the order the
        classes were drawn.

        Returns:
            ``(deploy_images, deploy_labels, task_encode_images, task_encode_labels)``.
        """
        labels = self.y
        classes = np.unique(labels)

        deploy_images = []
        deploy_labels = []
        task_encode_images = []
        task_encode_labels = []
        for _ in range(num_tasks):
            test_classes = np.random.choice(classes, size=num_classes_test, replace=False)

            task_encode_idxs, deploy_idxs = self.get_few_shot_idxs(labels, classes=test_classes, num_shots=num_shots)
            deploy_idxs = np.random.choice(deploy_idxs, size=deploy_batch_size, replace=False)

            class_map_fn = self._class_map_fn(test_classes)

            deploy_images.append(self.x[deploy_idxs])
            deploy_labels.append(class_map_fn(labels[deploy_idxs]))
            task_encode_images.append(self.x[task_encode_idxs])
            task_encode_labels.append(class_map_fn(labels[task_encode_idxs]))

        return np.concatenate(deploy_images, axis=0), np.concatenate(deploy_labels, axis=0), \
            np.concatenate(task_encode_images, axis=0), np.concatenate(task_encode_labels, axis=0)

    def next_few_shot_batch_wordemb(self, deploy_batch_size, num_classes_test, num_shots, num_tasks):
        """Like ``next_few_shot_batch``, but also return the original class ids of each task.

        Returns:
            ``(deploy_images, deploy_labels, task_encode_images, task_encode_labels_real,
            task_encode_labels)``. The last two hold, per task, the ``num_classes_test``
            drawn class ids and ``0..num_classes_test-1`` respectively, concatenated
            over tasks.
        """
        labels = self.y
        classes = np.unique(labels)

        deploy_images = []
        deploy_labels = []
        task_encode_images = []
        task_encode_labels_real = []
        task_encode_labels = []
        for _ in range(num_tasks):
            test_classes = np.random.choice(classes, size=num_classes_test, replace=False)

            task_encode_idxs, deploy_idxs = self.get_few_shot_idxs(labels, classes=test_classes, num_shots=num_shots)
            deploy_idxs = np.random.choice(deploy_idxs, size=deploy_batch_size, replace=False)

            class_map_fn = self._class_map_fn(test_classes)

            deploy_images.append(self.x[deploy_idxs])
            deploy_labels.append(class_map_fn(labels[deploy_idxs]))
            task_encode_images.append(self.x[task_encode_idxs])
            task_encode_labels_real.append(test_classes)
            task_encode_labels.append(list(range(num_classes_test)))

        return np.concatenate(deploy_images, axis=0), np.concatenate(deploy_labels, axis=0), \
            np.concatenate(task_encode_images, axis=0), np.concatenate(task_encode_labels_real, axis=0), \
            np.concatenate(task_encode_labels, axis=0)

    def next_triplet_batch_with_hard_negative(self, sess, triplet_logits_neg_mine, anchor_features_placeholder,
                                              positive_features_placeholder, negative_features_placeholder, batch_size,
                                              num_negatives_mining):
        """Generator for the triplet batches (anchor, positive, negative) based on the facenet paper.

        For each triplet, an anchor/positive pair is drawn from a random class, and
        ``num_negatives_mining`` candidate negatives are drawn from the other classes.
        ``triplet_logits_neg_mine`` is evaluated on those candidates, and the candidate
        with the smallest output is kept. This assumes a lower output means a harder
        negative — TODO confirm against the graph.

        Returns:
            ``(anchors, positives, negatives)`` rows of ``self.x``, ``batch_size`` each.
        """
        labels = self.y
        classes = np.unique(labels)
        all_idxs = np.arange(self.n_samples)

        anchor_idxs = np.zeros(shape=(batch_size,), dtype=np.int32)
        positive_idxs = np.zeros(shape=(batch_size,), dtype=np.int32)
        negative_idxs_hard = np.zeros(shape=(batch_size,), dtype=np.int32)

        anchor_classes = np.random.choice(classes, size=batch_size, replace=True)
        for i in range(batch_size):
            pos_and_anchor_idxs = np.random.choice(all_idxs[labels == anchor_classes[i]], size=2, replace=False)

            anchor_idxs[i] = pos_and_anchor_idxs[0]
            positive_idxs[i] = pos_and_anchor_idxs[1]
            negative_idxs = np.random.choice(all_idxs[labels != anchor_classes[i]], size=num_negatives_mining,
                                             replace=False)

            feed_dict = {
                anchor_features_placeholder: np.expand_dims(self.x[anchor_idxs[i]], axis=0),
                positive_features_placeholder: np.expand_dims(self.x[positive_idxs[i]], axis=0),
                negative_features_placeholder: self.x[negative_idxs],
            }
            triplet_logits_neg_mine_np = sess.run(triplet_logits_neg_mine, feed_dict=feed_dict)

            # argmin is a position within the candidate list; translate it back to a
            # dataset index so the selected negative really is one of the candidates.
            negative_idxs_hard[i] = negative_idxs[np.argmin(triplet_logits_neg_mine_np)]

        return self.x[anchor_idxs], self.x[positive_idxs], self.x[negative_idxs_hard]

    def next_triplet_batch(self, batch_size):
        """Generator for the triplet batches (anchor, positive, negative) based on the facenet paper.

        Returns:
            ``(anchors, positives, negatives)`` rows of ``self.x``, ``batch_size`` each.
            Each negative comes from a class other than the anchor's.
        """
        labels = self.y
        classes = np.unique(labels)
        all_idxs = np.arange(self.n_samples)
        anchor_idxs = np.zeros(shape=(batch_size,), dtype=np.int32)
        positive_idxs = np.zeros(shape=(batch_size,), dtype=np.int32)
        negative_idxs = np.zeros(shape=(batch_size,), dtype=np.int32)

        # NOTE(review): the anchor class and the anchor/positive pair are drawn once,
        # outside the loop, so every triplet in the batch shares the same anchor and
        # positive and only the negatives vary. Per-triplet sampling would move these
        # two draws into the loop; confirm which behaviour is intended.
        chosen_classes = np.random.choice(classes, size=2, replace=False)
        pos_and_anchor_idxs = np.random.choice(all_idxs[labels == chosen_classes[0]], size=2, replace=False)
        for i in range(batch_size):
            anchor_idxs[i] = pos_and_anchor_idxs[0]
            positive_idxs[i] = pos_and_anchor_idxs[1]
            # [0] takes the scalar out of the size-1 draw (same RNG consumption).
            negative_idxs[i] = np.random.choice(all_idxs[labels != chosen_classes[0]], size=1, replace=False)[0]

        images = self.x
        return images[anchor_idxs], images[positive_idxs], images[negative_idxs]

    def sequential_batches(self, batch_size, n_batches, rng=np.random):
        """Generator for a random sequence of minibatches with no overlap.

        Yields at most ``n_batches`` tuples of field slices, stopping early once the
        permuted dataset is exhausted (the last batch may be shorter).
        """
        permutation = rng.permutation(self.n_samples)
        for batch_idx in range(n_batches):
            start = batch_idx * batch_size
            end = np.minimum((start + batch_size), self.n_samples)
            idx = permutation[start:end]
            yield tuple(field[idx] for field in self.fields)
            if end == self.n_samples:
                break
class Bunch: 273 | def __init__(self, **kwargs): 274 | self.__init__ = kwargs 275 | 276 | 277 | ACTIVATION_MAP = {"relu": tf.nn.relu, 278 | "selu": tf.nn.selu, 279 | "swish-1": lambda x, name='swish-1': tf.multiply(x, tf.nn.sigmoid(x), name=name), 280 | } -------------------------------------------------------------------------------- /datasets/create_dataset_miniImagenet.py: -------------------------------------------------------------------------------- 1 | """Creates the mini-ImageNet dataset.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import argparse 7 | import csv 8 | import os 9 | import sys 10 | 11 | import zipfile 12 | import numpy as np 13 | import scipy.misc 14 | import io 15 | 16 | # Make train, validation and test splits deterministic from one run to another 17 | np.random.seed(2017 + 5 + 17) 18 | 19 | def get_class_label_dict(class_label_addr): 20 | lines = [x.strip() for x in open(class_label_addr, 'r').readlines()] 21 | cld={} 22 | for l in lines: 23 | tl=l.split(' ') 24 | if tl[0] not in cld.keys(): 25 | cld[tl[0]]=tl[2].lower() 26 | return cld 27 | # def load_embedding_dict(emb_addr): 28 | # lines = [x.strip() for x in open(emb_addr, 'r', encoding="utf-8").readlines()] 29 | # emb_dict = {} 30 | # for l in lines: 31 | # w = l.split(' ') 32 | # print(l) 33 | # if w[0] not in emb_dict.keys(): 34 | # tmpv = [float(w[i]) for i in range(1, len(w))] 35 | # emb_dict[w[0]] = tmpv 36 | # return emb_dict 37 | 38 | def load_embedding_dict(emb_addr): 39 | fin = io.open(emb_addr, 'r', encoding='utf-8', newline='\n', errors='ignore') 40 | n, d = map(int, fin.readline().split()) 41 | data = {} 42 | for line in fin: 43 | #print(line) 44 | tokens = line.rstrip().split(' ') 45 | tmpv = [float(tokens[i]) for i in range(1, len(tokens))] 46 | data[tokens[0]] = tmpv 47 | return data 48 | 49 | def get_embeddings_for_labels(all_classes, cld, emb_dict): 50 | label_list = [] 51 | 
emb_list = [] 52 | no_emb = 0 53 | print(all_classes) 54 | print(len(all_classes)) 55 | for c in all_classes: 56 | label_list.append(cld[c]) 57 | print(label_list) 58 | print(len(label_list)) 59 | for v in label_list: 60 | # check the embeddings of labels 61 | #print(v) 62 | labels = v.split('_') 63 | tmpv = np.zeros(300) 64 | tmpl = [] 65 | c = 0 66 | for l in labels: 67 | if l in emb_dict.keys(): 68 | tmpv += emb_dict[l] 69 | tmpl.append(l) 70 | c += 1 71 | if len(labels) != 1: 72 | if c != len(labels): 73 | print(v, c, tmpl) 74 | if c != 0: 75 | emb_list.append(tmpv / c) 76 | else: 77 | emb_list.append(tmpv) 78 | no_emb += 1 79 | print("no embedding for " + v) 80 | print(no_emb) 81 | return emb_list 82 | 83 | 84 | def main(data_dir, output_dir, emb_addr, class_label_addr): 85 | print("loading the embedding dictionary....") 86 | cld = get_class_label_dict(class_label_addr) 87 | emb_dict = load_embedding_dict(emb_addr) 88 | for split in ('val', 'test', 'train'): 89 | # List of selected image files for the current split 90 | file_paths = [] 91 | 92 | with open('{}.csv'.format(split), 'r') as csv_file: 93 | # Read the CSV file for that split, and get all classes present in 94 | # that split. 
95 | reader = csv.DictReader(csv_file, delimiter=',') 96 | file_paths, labels = zip( 97 | *((os.path.join('images', row['filename']), row['label']) 98 | for row in reader)) 99 | all_labels = sorted(list(set(labels))) 100 | print("getting word embeddings....") 101 | emb_list = get_embeddings_for_labels(all_labels,cld, emb_dict) 102 | print("saving word embeddings...") 103 | np.savez( 104 | os.path.join(output_dir, 'few-shot-wordemb-{}.npz'.format(split)), 105 | features=np.asarray(emb_list)) 106 | 107 | archive = zipfile.ZipFile(os.path.join(data_dir, 'images.zip'), 'r') 108 | 109 | # Processing loop over examples 110 | features, targets = [], [] 111 | for i, (file_path, label) in enumerate(zip(file_paths, labels)): 112 | # Write progress to stdout 113 | sys.stdout.write( 114 | '\r>> Processing {} image {}/{}'.format( 115 | split, i + 1, len(file_paths))) 116 | sys.stdout.flush() 117 | 118 | # Load image in RGB mode to ensure image.ndim == 3 119 | file_path = archive.open(file_path) 120 | image = scipy.misc.imread(file_path, mode='RGB') 121 | # Infer class from filename. 122 | label = all_labels.index(label) 123 | 124 | # Central square crop of size equal to the image's smallest side. 125 | height, width, channels = image.shape 126 | crop_size = min(height, width) 127 | start_height = (height // 2) - (crop_size // 2) 128 | start_width = (width // 2) - (crop_size // 2) 129 | image = image[ 130 | start_height: start_height + crop_size, 131 | start_width: start_width + crop_size, :] 132 | 133 | # Resize image to 84 x 84. 
134 | image = scipy.misc.imresize(image, (84, 84), interp='bilinear') 135 | 136 | features.append(image) 137 | targets.append(label) 138 | 139 | sys.stdout.write('\n') 140 | sys.stdout.flush() 141 | 142 | # Save dataset to disk 143 | features = np.stack(features, axis=0) 144 | targets = np.stack(targets, axis=0) 145 | permutation = np.random.permutation(len(features)) 146 | features = features[permutation] 147 | targets = targets[permutation] 148 | np.savez( 149 | os.path.join(output_dir, 'few-shot-{}.npz'.format(split)), 150 | features=features, targets=targets) 151 | 152 | 153 | if __name__ == '__main__': 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument( 156 | '--data-dir', type=str, 157 | default=os.path.join(os.sep, 'mnt', 'datasets', 'public', 'mini-imagenet', 'raw-data'), 158 | help='Path to the raw data') 159 | parser.add_argument( 160 | '--output-dir', type=str, default=os.path.join(os.sep, 'mnt', 'datasets', 'public', 'mini-imagenet'), 161 | help='Output directory') 162 | parser.add_argument( 163 | '--emb_addr', type=str, 164 | default=os.path.join(os.sep, 'mnt', 'datasets', 'public', 'mini-imagenet', 'raw-data'), 165 | help='Path to the raw data') 166 | parser.add_argument( 167 | '--class_label_addr', type=str, default=os.path.join(os.sep, 'mnt', 'datasets', 'public', 'mini-imagenet'), 168 | help='Output directory') 169 | 170 | args = parser.parse_args() 171 | main(args.data_dir, args.output_dir, args.emb_addr, args.class_label_addr) 172 | -------------------------------------------------------------------------------- /datasets/create_dataset_tieredimagenet.py: -------------------------------------------------------------------------------- 1 | """Creates the tiered-ImageNet dataset.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import argparse 7 | import csv 8 | import os 9 | import sys 10 | 11 | import zipfile 12 | import numpy as np 13 | import 
scipy.misc 14 | import pickle as pkl 15 | import cv2 16 | from tqdm import trange 17 | 18 | # Make train, validation and test splits deterministic from one run to another 19 | np.random.seed(2017 + 5 + 17) 20 | 21 | def get_class_label_dict(class_label_addr): 22 | lines = [x.strip() for x in open(class_label_addr, 'r').readlines()] 23 | cld={} 24 | for l in lines: 25 | tl=l.split(' ') 26 | if tl[0] not in cld.keys(): 27 | cld[tl[0]]=tl[2].lower() 28 | return cld 29 | def load_embedding_dict(emb_addr): 30 | lines = [x.strip() for x in open(emb_addr, 'r', encoding="utf-8").readlines()] 31 | emb_dict = {} 32 | for l in lines: 33 | w = l.split(' ') 34 | if w[0] not in emb_dict.keys(): 35 | tmpv = [float(w[i]) for i in range(1, len(w))] 36 | emb_dict[w[0]] = tmpv 37 | return emb_dict 38 | 39 | def get_embeddings_for_labels(all_classes, emb_dict): 40 | emb_list = [] 41 | no_emb = 0 42 | print(all_classes) 43 | print(len(all_classes)) 44 | for v in all_classes: 45 | # check the embeddings of labels 46 | #print(v) 47 | labels = v.split(', ') 48 | tmpv = np.zeros(300) 49 | tmpl = [] 50 | c = 0 51 | lw=labels[0].split(' ') 52 | if labels[0]=='limpkin' or labels[0]=='otterhound' or labels[0]=='barracouta' or labels[0]=='lycaenid': 53 | lw = labels[1].strip().split(' ') 54 | for l in lw: 55 | if l in emb_dict.keys(): 56 | tmpv += emb_dict[l] 57 | tmpl.append(l) 58 | c += 1 59 | if len(lw) != 1: 60 | if c != len(lw): 61 | print(v, c, tmpl) 62 | if c != 0: 63 | emb_list.append(tmpv / c) 64 | else: 65 | emb_list.append(np.random.rand(300)*2-1) 66 | no_emb += 1 67 | print("no embedding for " + v) 68 | print(no_emb) 69 | return emb_list 70 | 71 | 72 | def main(data_dir, output_dir, emb_addr): 73 | print("loading the embedding dictionary....") 74 | emb_dict = load_embedding_dict(emb_addr) 75 | for split in ('val', 'test', 'train'): 76 | # List of selected image files for the current split 77 | with open(data_dir + '/' + split + '_images_png.pkl', 'rb') as f: 78 | raw_data = 
pkl.load(f, encoding='latin1') 79 | data = np.zeros([len(raw_data), 84, 84, 3], dtype=np.uint8) 80 | for ii in trange(len(raw_data)): 81 | item = raw_data[ii] 82 | im = cv2.imdecode(item, 1) 83 | # print(im) 84 | data[ii] = im 85 | f = open(data_dir + '/' + split + '_labels.pkl', 'rb') 86 | label_set = pkl.load(f, encoding='latin1') 87 | labels = label_set['label_specific'] 88 | all_labels = label_set['label_specific_str'] 89 | print("getting word embeddings....") 90 | emb_list = get_embeddings_for_labels(all_labels, emb_dict) 91 | print("saving word embeddings...") 92 | np.savez( 93 | os.path.join(output_dir, 'few-shot-wordemb-{}.npz'.format(split)), 94 | features=np.asarray(emb_list)) 95 | 96 | # Processing loop over examples 97 | features, targets = [], [] 98 | for i, (image, label) in enumerate(zip(data, labels)): 99 | # Write progress to stdout 100 | sys.stdout.write( 101 | '\r>> Processing {} image {}/{}'.format( 102 | split, i + 1, label)) 103 | # Infer class from filename. 104 | 105 | # Central square crop of size equal to the image's smallest side. 
106 | height, width, channels = image.shape 107 | crop_size = min(height, width) 108 | start_height = (height // 2) - (crop_size // 2) 109 | start_width = (width // 2) - (crop_size // 2) 110 | image = image[ 111 | start_height: start_height + crop_size, 112 | start_width: start_width + crop_size, :] 113 | 114 | features.append(image) 115 | targets.append(label) 116 | 117 | sys.stdout.write('\n') 118 | sys.stdout.flush() 119 | 120 | # Save dataset to disk 121 | features = np.stack(features, axis=0) 122 | targets = np.stack(targets, axis=0) 123 | permutation = np.random.permutation(len(features)) 124 | features = features[permutation] 125 | targets = targets[permutation] 126 | np.savez( 127 | os.path.join(output_dir, 'few-shot-{}.npz'.format(split)), 128 | features=features, targets=targets) 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument( 134 | '--data-dir', type=str, 135 | default=os.path.join(os.sep, 'mnt', 'datasets', 'public', 'mini-imagenet', 'raw-data'), 136 | help='Path to the raw data') 137 | parser.add_argument( 138 | '--output-dir', type=str, default=os.path.join(os.sep, 'mnt', 'datasets', 'public', 'mini-imagenet'), 139 | help='Output directory') 140 | parser.add_argument( 141 | '--emb_addr', type=str, 142 | default=os.path.join(os.sep, 'mnt', 'datasets', 'public', 'mini-imagenet', 'raw-data'), 143 | help='Path to the raw data') 144 | 145 | args = parser.parse_args() 146 | main(args.data_dir, args.output_dir, args.emb_addr) 147 | -------------------------------------------------------------------------------- /datasets/data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | import numpy as np 8 | from common.util import Dataset 9 | from collections import defaultdict 10 | import pickle as pkl 11 | 
import cv2 12 | from tqdm import trange 13 | #import sys 14 | #reload(sys) 15 | #sys.setdefaultencoding("utf-8") 16 | 17 | 18 | def _load_mini_imagenet(data_dir, split): 19 | """Load mini-imagenet from numpy's npz file format.""" 20 | _split_tag = {'sources': 'train', 'target_val': 'val', 'target_tst': 'test'}[split] 21 | dataset_path = os.path.join(data_dir, 'few-shot-{}.npz'.format(_split_tag)) 22 | logging.info("Loading mini-imagenet...") 23 | data = np.load(dataset_path) 24 | fields = data['features'], data['targets'] 25 | logging.info("Done loading.") 26 | print(data['features'][0]) 27 | return fields 28 | 29 | def _load_tiered_imagenet(data_dir, split): 30 | """Load tiered-imagenet from numpy's npz file format.""" 31 | logging.info("Loading tiered-imagenet...") 32 | _split_tag = {'sources': 'train', 'target_val': 'val', 'target_tst': 'test'}[split] 33 | with open(data_dir+'/'+_split_tag+'_images_png.pkl', 'rb') as f: 34 | raw_data = pkl.load(f,encoding='latin1') 35 | data=np.zeros([len(raw_data), 84, 84, 3], dtype=np.uint8) 36 | for ii in trange(len(raw_data)): 37 | item=raw_data[ii] 38 | im = cv2.imdecode(item, 1) 39 | #print(im) 40 | data[ii] = im 41 | #data = None 42 | print(data[0]) 43 | f = open(data_dir+'/'+_split_tag+'_labels.pkl', 'rb') 44 | print(data_dir+'/'+_split_tag+'_labels.pkl') 45 | label = pkl.load(f,encoding='latin1') 46 | fields = data, label 47 | logging.info("Done loading.") 48 | return fields 49 | 50 | def make_all_datasets(data_dir, n_tasks_dict, n_samples_per_classes, n_classes_per_task=5): 51 | """High level function creating a dictionary of datasest for (sources, targets) x (train, test). 52 | 53 | Args: 54 | data_dir: The directory containing the mini-imagnet dataset in npz format. 55 | n_tasks_dict: A dictionary with keys in ('sources', 'target_val', 'target_tst'), mapping to the number of 56 | tasks required for that partition. 57 | n_samples_per_class: determines the number of samples each class will have in each task i.e. 
this imposes 58 | that each class have the same number of samples and the size of each tasks is: 59 | n_classes_per_task * n_samples_per_class 60 | n_classes_per_task: numer of classes per tasks. 61 | 62 | Returns: 63 | A dictionary of multitask dataset, where keys are tuples of strings generated from the cartesian product of keys 64 | in n_tasks_dict and ('trn', 'tst'). e.g.: ('sources', 'trn'), ('target_val', 'tst') ... 65 | """ 66 | dataset_dict = {} 67 | task_id_start = 0 68 | for split_name, task_size in n_tasks_dict.items(): 69 | images, labels = _load_mini_imagenet(data_dir, split_name) 70 | task_ids = range(task_id_start, task_id_start + task_size) 71 | task_id_start += task_size 72 | trn_dataset, tst_dataset = make_multitask_dataset( 73 | images, labels, task_ids, n_samples_per_classes, n_classes_per_task) 74 | dataset_dict[(split_name, 'trn')] = trn_dataset 75 | dataset_dict[(split_name, 'tst')] = tst_dataset 76 | 77 | return dataset_dict 78 | 79 | 80 | def make_multitask_dataset(images_db, labels, task_ids, n_samples_per_class, n_classes_per_task): 81 | """Make train and test datasets containing multiple tasks merged into a single dataset. 82 | 83 | Args: 84 | images_db: array of shape (n_total_samples, width, height, depth) containing the images in mini-imagenet 85 | labels: array of shape (n_total_samples,) representing original labels 86 | task_ids: array of pre-specified task_ids. The lenght determines the number of tasks created. 87 | the actual values don't matter as long as they are unique across all sources and target datasets. 88 | n_samples_per_class: determines the number of samples each class will have in each task i.e. this imposes 89 | that each class have the same number of samples and the size of each tasks is: 90 | n_classes_per_task * n_samples_per_class 91 | n_classes_per_task: numer of classes per tasks. 
92 | """ 93 | trn_fields = [] 94 | tst_fields = [] 95 | class_maps = {} 96 | 97 | for task_id in task_ids: 98 | trn_indices, tst_indices, class_map = make_task(labels, n_samples_per_class, n_classes_per_task) 99 | class_maps[task_id] = class_map 100 | trn_fields.append((trn_indices, class_map.to_new_class(labels[trn_indices]), [task_id] * len(trn_indices))) 101 | tst_fields.append((tst_indices, class_map.to_new_class(labels[tst_indices]), [task_id] * len(tst_indices))) 102 | 103 | def make_dataset(fields): 104 | return ImagesDataset(tuple(np.hstack(fields).astype(np.int32)), images_db, class_maps) 105 | 106 | return make_dataset(trn_fields), make_dataset(tst_fields) 107 | 108 | 109 | def make_task(labels, n_samples_per_class, n_classes_per_task, rng=np.random): 110 | """Create a new task and make its train and test partition. 111 | 112 | Args: 113 | labels: 1d array of labels from the original classes of mini-imagenet 114 | n_samples_per_class: integer 115 | n_classes_per_task: integer 116 | rng: Random number generator 117 | 118 | Returns: 119 | trn_indices: 1d array of pointers in the mini-imagenet dataset, reserved for training 120 | tst_indices: 1d array of pointers in the mini-imagenet dataset, reserved for tests 121 | """ 122 | unique_labels = np.unique(labels) 123 | classes = rng.choice(unique_labels, n_classes_per_task, replace=False) 124 | 125 | trn_indices = [] 126 | tst_indices = [] 127 | for c in classes: 128 | indices = np.random.permutation(np.flatnonzero(labels == c)).astype(np.int32) 129 | trn_indices.append(indices[:n_samples_per_class]) 130 | tst_indices.append(indices[n_samples_per_class:]) 131 | trn_indices = np.hstack(trn_indices) 132 | tst_indices = np.hstack(tst_indices) 133 | 134 | return trn_indices, tst_indices, ClassMap(classes) 135 | 136 | 137 | def grouped_sampler(task_ids, n_task_per_batch): 138 | groups = defaultdict(list) 139 | 140 | for idx, task_id in enumerate(task_ids): 141 | groups[task_id].append(idx) 142 | 143 | for key, val 
in groups.items(): 144 | groups[key] = np.array(val) 145 | 146 | unique_tasks = np.unique(task_ids) 147 | 148 | def sampler(n_samples, rng=np.random): 149 | n_samples_per_task = n_samples // n_task_per_batch 150 | assert n_samples_per_task * n_task_per_batch == n_samples 151 | active_task_ids = rng.choice(unique_tasks, n_task_per_batch) 152 | return np.hstack([rng.choice(groups[task_id], n_samples_per_task) for task_id in active_task_ids]) 153 | 154 | return sampler 155 | 156 | 157 | class ClassMap: 158 | 159 | def __init__(self, classes): 160 | """Simple class to help map back and forth the classes ids. 161 | 162 | Args: 163 | classes: array of classes such that classes[new_class] == original_class 164 | """ 165 | self.classes = np.asarray(classes, dtype=np.int32) 166 | self.reverse_map = np.zeros(np.max(classes) + 1, dtype=np.int32) - 1 # not the most memory efficient but it doesn't matter 167 | for i, c in enumerate(classes): 168 | self.reverse_map[c] = i 169 | 170 | def to_new_class(self, original_classes): 171 | return self.reverse_map[original_classes] 172 | 173 | def to_original_class(self, new_classes): 174 | return self.classes[new_classes] 175 | 176 | 177 | class ImagesDataset(Dataset): 178 | 179 | def __init__(self, fields, images_db, class_map): 180 | super().__init__(fields) 181 | self.images_db = images_db 182 | self.class_map = class_map 183 | 184 | def _from_ptr_to_images(self, indices, new_labels, task_ids): 185 | images = self.images_db[indices] 186 | return images, new_labels, task_ids 187 | 188 | def next_batch(self, n, rng=np.random): 189 | return self._from_ptr_to_images(*super().next_batch(n, rng)) 190 | 191 | def sequential_batches(self, batch_size, n_batches, rng=np.random): 192 | for fields in super().sequential_batches(batch_size, n_batches, rng): 193 | yield self._from_ptr_to_images(*fields) 194 | 195 | -------------------------------------------------------------------------------- /datasets/mini_imagenet_class_label_dict3.txt: 
-------------------------------------------------------------------------------- 1 | n02119789 1 kit_fox 2 | n02100735 2 English_setter 3 | n02110185 3 Siberian_husky 4 | n02096294 4 Australian_terrier 5 | n02102040 5 English_springer 6 | n02066245 6 grey_whale 7 | n02509815 7 lesser_panda 8 | n02124075 8 Egyptian_cat 9 | n02417914 9 ibex 10 | n02123394 10 Persian_cat 11 | n02125311 11 cougar 12 | n02423022 12 gazelle 13 | n02346627 13 porcupine 14 | n02077923 14 sea_lion 15 | n02110063 15 malamute 16 | n02447366 16 badger 17 | n02109047 17 Great_Dane 18 | n02089867 18 Walker_hound 19 | n02102177 19 Welsh_springer_spaniel 20 | n02091134 20 whippet 21 | n02092002 21 Scottish_deerhound 22 | n02071294 22 killer_whale 23 | n02442845 23 mink 24 | n02504458 24 African_elephant 25 | n02092339 25 Weimaraner 26 | n02098105 26 soft-coated_wheaten_terrier 27 | n02096437 27 Dandie_Dinmont_terrier 28 | n02114712 28 red_wolf 29 | n02105641 29 Old_English_sheepdog 30 | n02128925 30 jaguar 31 | n02091635 31 otter_hound 32 | n02088466 32 bloodhound 33 | n02096051 33 Airedale 34 | n02117135 34 hyena 35 | n02138441 35 meerkat 36 | n02097130 36 giant_schnauzer 37 | n02493509 37 titi 38 | n02457408 38 three-toed_sloth 39 | n02389026 39 sorrel 40 | n02443484 40 black-footed_ferret 41 | n02110341 41 dalmatian 42 | n02089078 42 black-and-tan_coonhound 43 | n02086910 43 papillon 44 | n02445715 44 skunk 45 | n02093256 45 Staffordshire_bullterrier 46 | n02113978 46 Mexican_hairless 47 | n02106382 47 Bouvier_des_Flandres 48 | n02441942 48 weasel 49 | n02113712 49 miniature_poodle 50 | n02113186 50 Cardigan 51 | n02105162 51 malinois 52 | n02415577 52 bighorn 53 | n02356798 53 fox_squirrel 54 | n02488702 54 colobus 55 | n02123159 55 tiger_cat 56 | n02098413 56 Lhasa 57 | n02422699 57 impala 58 | n02114855 58 coyote 59 | n02094433 59 Yorkshire_terrier 60 | n02111277 60 Newfoundland 61 | n02132136 61 brown_bear 62 | n02119022 62 red_fox 63 | n02091467 63 Norwegian_elkhound 64 | n02106550 64 
Rottweiler 65 | n02422106 65 hartebeest 66 | n02091831 66 Saluki 67 | n02120505 67 grey_fox 68 | n02104365 68 schipperke 69 | n02086079 69 Pekinese 70 | n02112706 70 Brabancon_griffon 71 | n02098286 71 West_Highland_white_terrier 72 | n02095889 72 Sealyham_terrier 73 | n02484975 73 guenon 74 | n02137549 74 mongoose 75 | n02500267 75 indri 76 | n02129604 76 tiger 77 | n02090721 77 Irish_wolfhound 78 | n02396427 78 wild_boar 79 | n02108000 79 EntleBucher 80 | n02391049 80 zebra 81 | n02412080 81 ram 82 | n02108915 82 French_bulldog 83 | n02480495 83 orangutan 84 | n02110806 84 basenji 85 | n02128385 85 leopard 86 | n02107683 86 Bernese_mountain_dog 87 | n02085936 87 Maltese_dog 88 | n02094114 88 Norfolk_terrier 89 | n02087046 89 toy_terrier 90 | n02100583 90 vizsla 91 | n02096177 91 cairn 92 | n02494079 92 squirrel_monkey 93 | n02105056 93 groenendael 94 | n02101556 94 clumber 95 | n02123597 95 Siamese_cat 96 | n02481823 96 chimpanzee 97 | n02105505 97 komondor 98 | n02088094 98 Afghan_hound 99 | n02085782 99 Japanese_spaniel 100 | n02489166 100 proboscis_monkey 101 | n02364673 101 guinea_pig 102 | n02114548 102 white_wolf 103 | n02134084 103 ice_bear 104 | n02480855 104 gorilla 105 | n02090622 105 borzoi 106 | n02113624 106 toy_poodle 107 | n02093859 107 Kerry_blue_terrier 108 | n02403003 108 ox 109 | n02097298 109 Scotch_terrier 110 | n02108551 110 Tibetan_mastiff 111 | n02493793 111 spider_monkey 112 | n02107142 112 Doberman 113 | n02096585 113 Boston_bull 114 | n02107574 114 Greater_Swiss_Mountain_dog 115 | n02107908 115 Appenzeller 116 | n02086240 116 Shih-Tzu 117 | n02102973 117 Irish_water_spaniel 118 | n02112018 118 Pomeranian 119 | n02093647 119 Bedlington_terrier 120 | n02397096 120 warthog 121 | n02437312 121 Arabian_camel 122 | n02483708 122 siamang 123 | n02097047 123 miniature_schnauzer 124 | n02106030 124 collie 125 | n02099601 125 golden_retriever 126 | n02093991 126 Irish_terrier 127 | n02110627 127 affenpinscher 128 | n02106166 128 Border_collie 129 
| n02326432 129 hare 130 | n02108089 130 boxer 131 | n02097658 131 silky_terrier 132 | n02088364 132 beagle 133 | n02111129 133 Leonberg 134 | n02100236 134 German_short-haired_pointer 135 | n02486261 135 patas 136 | n02115913 136 dhole 137 | n02486410 137 baboon 138 | n02487347 138 macaque 139 | n02099849 139 Chesapeake_Bay_retriever 140 | n02108422 140 bull_mastiff 141 | n02104029 141 kuvasz 142 | n02492035 142 capuchin 143 | n02110958 143 pug 144 | n02099429 144 curly-coated_retriever 145 | n02094258 145 Norwich_terrier 146 | n02099267 146 flat-coated_retriever 147 | n02395406 147 hog 148 | n02112350 148 keeshond 149 | n02109961 149 Eskimo_dog 150 | n02101388 150 Brittany_spaniel 151 | n02113799 151 standard_poodle 152 | n02095570 152 Lakeland_terrier 153 | n02128757 153 snow_leopard 154 | n02101006 154 Gordon_setter 155 | n02115641 155 dingo 156 | n02097209 156 standard_schnauzer 157 | n02342885 157 hamster 158 | n02097474 158 Tibetan_terrier 159 | n02120079 159 Arctic_fox 160 | n02095314 160 wire-haired_fox_terrier 161 | n02088238 161 basset 162 | n02408429 162 water_buffalo 163 | n02133161 163 American_black_bear 164 | n02328150 164 Angora 165 | n02410509 165 bison 166 | n02492660 166 howler_monkey 167 | n02398521 167 hippopotamus 168 | n02112137 168 chow 169 | n02510455 169 giant_panda 170 | n02093428 170 American_Staffordshire_terrier 171 | n02105855 171 Shetland_sheepdog 172 | n02111500 172 Great_Pyrenees 173 | n02085620 173 Chihuahua 174 | n02123045 174 tabby 175 | n02490219 175 marmoset 176 | n02099712 176 Labrador_retriever 177 | n02109525 177 Saint_Bernard 178 | n02454379 178 armadillo 179 | n02111889 179 Samoyed 180 | n02088632 180 bluetick 181 | n02090379 181 redbone 182 | n02443114 182 polecat 183 | n02361337 183 marmot 184 | n02105412 184 kelpie 185 | n02483362 185 gibbon 186 | n02437616 186 llama 187 | n02107312 187 miniature_pinscher 188 | n02325366 188 wood_rabbit 189 | n02091032 189 Italian_greyhound 190 | n02129165 190 lion 191 | n02102318 191 
cocker_spaniel 192 | n02100877 192 Irish_setter 193 | n02074367 193 dugong 194 | n02504013 194 Indian_elephant 195 | n02363005 195 beaver 196 | n02102480 196 Sussex_spaniel 197 | n02113023 197 Pembroke 198 | n02086646 198 Blenheim_spaniel 199 | n02497673 199 Madagascar_cat 200 | n02087394 200 Rhodesian_ridgeback 201 | n02127052 201 lynx 202 | n02116738 202 African_hunting_dog 203 | n02488291 203 langur 204 | n02091244 204 Ibizan_hound 205 | n02114367 205 timber_wolf 206 | n02130308 206 cheetah 207 | n02089973 207 English_foxhound 208 | n02105251 208 briard 209 | n02134418 209 sloth_bear 210 | n02093754 210 Border_terrier 211 | n02106662 211 German_shepherd 212 | n02444819 212 otter 213 | n01882714 213 koala 214 | n01871265 214 tusker 215 | n01872401 215 echidna 216 | n01877812 216 wallaby 217 | n01873310 217 platypus 218 | n01883070 218 wombat 219 | n04086273 219 revolver 220 | n04507155 220 umbrella 221 | n04147183 221 schooner 222 | n04254680 222 soccer_ball 223 | n02672831 223 accordion 224 | n02219486 224 ant 225 | n02317335 225 starfish 226 | n01968897 226 chambered_nautilus 227 | n03452741 227 grand_piano 228 | n03642806 228 laptop 229 | n07745940 229 strawberry 230 | n02690373 230 airliner 231 | n04552348 231 warplane 232 | n02692877 232 airship 233 | n02782093 233 balloon 234 | n04266014 234 space_shuttle 235 | n03344393 235 fireboat 236 | n03447447 236 gondola 237 | n04273569 237 speedboat 238 | n03662601 238 lifeboat 239 | n02951358 239 canoe 240 | n04612504 240 yawl 241 | n02981792 241 catamaran 242 | n04483307 242 trimaran 243 | n03095699 243 container_ship 244 | n03673027 244 liner 245 | n03947888 245 pirate 246 | n02687172 246 aircraft_carrier 247 | n04347754 247 submarine 248 | n04606251 248 wreck 249 | n03478589 249 half_track 250 | n04389033 250 tank 251 | n03773504 251 missile 252 | n02860847 252 bobsled 253 | n03218198 253 dogsled 254 | n02835271 254 tandem_bicycle 255 | n03792782 255 mountain_bike 256 | n03393912 256 freight_car 257 | n03895866 
257 passenger_car 258 | n02797295 258 barrow 259 | n04204347 259 shopping_cart 260 | n03791053 260 motor_scooter 261 | n03384352 261 forklift 262 | n03272562 262 electric_locomotive 263 | n04310018 263 steam_locomotive 264 | n02704792 264 amphibian 265 | n02701002 265 ambulance 266 | n02814533 266 beach_wagon 267 | n02930766 267 cab 268 | n03100240 268 convertible 269 | n03594945 269 jeep 270 | n03670208 270 limousine 271 | n03770679 271 minivan 272 | n03777568 272 Model_T 273 | n04037443 273 racer 274 | n04285008 274 sports_car 275 | n03444034 275 go-kart 276 | n03445924 276 golfcart 277 | n03785016 277 moped 278 | n04252225 278 snowplow 279 | n03345487 279 fire_engine 280 | n03417042 280 garbage_truck 281 | n03930630 281 pickup 282 | n04461696 282 tow_truck 283 | n04467665 283 trailer_truck 284 | n03796401 284 moving_van 285 | n03977966 285 police_van 286 | n04065272 286 recreational_vehicle 287 | n04335435 287 streetcar 288 | n04252077 288 snowmobile 289 | n04465501 289 tractor 290 | n03776460 290 mobile_home 291 | n04482393 291 tricycle 292 | n04509417 292 unicycle 293 | n03538406 293 horse_cart 294 | n03599486 294 ricksha_rickshaw 295 | n03868242 295 oxcart 296 | n02804414 296 bassinet 297 | n03125729 297 cradle 298 | n03131574 298 crib 299 | n03388549 299 four-poster 300 | n02870880 300 bookcase 301 | n03018349 301 china_cabinet 302 | n03742115 302 medicine_chest 303 | n03016953 303 chiffonier 304 | n04380533 304 table_lamp 305 | n03337140 305 file 306 | n03891251 306 park_bench 307 | n02791124 307 barber_chair 308 | n04429376 308 throne 309 | n03376595 309 folding_chair 310 | n04099969 310 rocking_chair 311 | n04344873 311 studio_couch 312 | n04447861 312 toilet_seat 313 | n03179701 313 desk 314 | n03982430 314 pool_table 315 | n03201208 315 dining_table 316 | n03290653 316 entertainment_center 317 | n04550184 317 wardrobe 318 | n07742313 318 Granny_Smith 319 | n07747607 319 orange 320 | n07749582 320 lemon 321 | n07753113 321 fig 322 | n07753275 322 
pineapple 323 | n07753592 323 banana 324 | n07754684 324 jackfruit 325 | n07760859 325 custard_apple 326 | n07768694 326 pomegranate 327 | n12267677 327 acorn 328 | n12620546 328 hip 329 | n13133613 329 ear 330 | n11879895 330 rapeseed 331 | n12144580 331 corn 332 | n12768682 332 buckeye 333 | n03854065 333 organ 334 | n04515003 334 upright 335 | n03017168 335 chime 336 | n03249569 336 drum 337 | n03447721 337 gong 338 | n03720891 338 maraca 339 | n03721384 339 marimba 340 | n04311174 340 steel_drum 341 | n02787622 341 banjo 342 | n02992211 342 cello 343 | n04536866 343 violin 344 | n03495258 344 harp 345 | n02676566 345 acoustic_guitar 346 | n03272010 346 electric_guitar 347 | n03110669 347 cornet 348 | n03394916 348 French_horn 349 | n04487394 349 trombone 350 | n03494278 350 harmonica 351 | n03840681 351 ocarina 352 | n03884397 352 panpipe 353 | n02804610 353 bassoon 354 | n03838899 354 oboe 355 | n04141076 355 sax 356 | n03372029 356 flute 357 | n11939491 357 daisy 358 | n12057211 358 yellow_lady's_slipper 359 | n09246464 359 cliff 360 | n09468604 360 valley 361 | n09193705 361 alp 362 | n09472597 362 volcano 363 | n09399592 363 promontory 364 | n09421951 364 sandbar 365 | n09256479 365 coral_reef 366 | n09332890 366 lakeside 367 | n09428293 367 seashore 368 | n09288635 368 geyser 369 | n03498962 369 hatchet 370 | n03041632 370 cleaver 371 | n03658185 371 letter_opener 372 | n03954731 372 plane 373 | n03995372 373 power_drill 374 | n03649909 374 lawn_mower 375 | n03481172 375 hammer 376 | n03109150 376 corkscrew 377 | n02951585 377 can_opener 378 | n03970156 378 plunger 379 | n04154565 379 screwdriver 380 | n04208210 380 shovel 381 | n03967562 381 plow 382 | n03000684 382 chain_saw 383 | n01514668 383 cock 384 | n01514859 384 hen 385 | n01518878 385 ostrich 386 | n01530575 386 brambling 387 | n01531178 387 goldfinch 388 | n01532829 388 house_finch 389 | n01534433 389 junco 390 | n01537544 390 indigo_bunting 391 | n01558993 391 robin 392 | n01560419 392 bulbul 
393 | n01580077 393 jay 394 | n01582220 394 magpie 395 | n01592084 395 chickadee 396 | n01601694 396 water_ouzel 397 | n01608432 397 kite 398 | n01614925 398 bald_eagle 399 | n01616318 399 vulture 400 | n01622779 400 great_grey_owl 401 | n01795545 401 black_grouse 402 | n01796340 402 ptarmigan 403 | n01797886 403 ruffed_grouse 404 | n01798484 404 prairie_chicken 405 | n01806143 405 peacock 406 | n01806567 406 quail 407 | n01807496 407 partridge 408 | n01817953 408 African_grey 409 | n01818515 409 macaw 410 | n01819313 410 sulphur-crested_cockatoo 411 | n01820546 411 lorikeet 412 | n01824575 412 coucal 413 | n01828970 413 bee_eater 414 | n01829413 414 hornbill 415 | n01833805 415 hummingbird 416 | n01843065 416 jacamar 417 | n01843383 417 toucan 418 | n01847000 418 drake 419 | n01855032 419 red-breasted_merganser 420 | n01855672 420 goose 421 | n01860187 421 black_swan 422 | n02002556 422 white_stork 423 | n02002724 423 black_stork 424 | n02006656 424 spoonbill 425 | n02007558 425 flamingo 426 | n02009912 426 American_egret 427 | n02009229 427 little_blue_heron 428 | n02011460 428 bittern 429 | n02012849 429 crane 430 | n02013706 430 Aramus_pictus 431 | n02018207 431 American_coot 432 | n02018795 432 bustard 433 | n02025239 433 ruddy_turnstone 434 | n02027492 434 red-backed_sandpiper 435 | n02028035 435 redshank 436 | n02033041 436 dowitcher 437 | n02037110 437 oystercatcher 438 | n02017213 438 European_gallinule 439 | n02051845 439 pelican 440 | n02056570 440 king_penguin 441 | n02058221 441 albatross 442 | n01484850 442 great_white_shark 443 | n01491361 443 tiger_shark 444 | n01494475 444 hammerhead 445 | n01496331 445 electric_ray 446 | n01498041 446 stingray 447 | n02514041 447 snoek 448 | n02536864 448 coho 449 | n01440764 449 tench 450 | n01443537 450 goldfish 451 | n02526121 451 eel 452 | n02606052 452 rock_beauty 453 | n02607072 453 anemone_fish 454 | n02643566 454 lionfish 455 | n02655020 455 puffer 456 | n02640242 456 sturgeon 457 | n02641379 457 gar 458 | 
n01664065 458 loggerhead 459 | n01665541 459 leatherback_turtle 460 | n01667114 460 mud_turtle 461 | n01667778 461 terrapin 462 | n01669191 462 box_turtle 463 | n01675722 463 banded_gecko 464 | n01677366 464 common_iguana 465 | n01682714 465 American_chameleon 466 | n01685808 466 whiptail 467 | n01687978 467 agama 468 | n01688243 468 frilled_lizard 469 | n01689811 469 alligator_lizard 470 | n01692333 470 Gila_monster 471 | n01693334 471 green_lizard 472 | n01694178 472 African_chameleon 473 | n01695060 473 Komodo_dragon 474 | n01704323 474 triceratops 475 | n01697457 475 African_crocodile 476 | n01698640 476 American_alligator 477 | n01728572 477 thunder_snake 478 | n01728920 478 ringneck_snake 479 | n01729322 479 hognose_snake 480 | n01729977 480 green_snake 481 | n01734418 481 king_snake 482 | n01735189 482 garter_snake 483 | n01737021 483 water_snake 484 | n01739381 484 vine_snake 485 | n01740131 485 night_snake 486 | n01742172 486 boa_constrictor 487 | n01744401 487 rock_python 488 | n01748264 488 Indian_cobra 489 | n01749939 489 green_mamba 490 | n01751748 490 sea_snake 491 | n01753488 491 horned_viper 492 | n01755581 492 diamondback 493 | n01756291 493 sidewinder 494 | n01629819 494 European_fire_salamander 495 | n01630670 495 common_newt 496 | n01631663 496 eft 497 | n01632458 497 spotted_salamander 498 | n01632777 498 axolotl 499 | n01641577 499 bullfrog 500 | n01644373 500 tree_frog 501 | n01644900 501 tailed_frog 502 | n04579432 502 whistle 503 | n04592741 503 wing 504 | n03876231 504 paintbrush 505 | n03483316 505 hand_blower 506 | n03868863 506 oxygen_mask 507 | n04251144 507 snorkel 508 | n03691459 508 loudspeaker 509 | n03759954 509 microphone 510 | n04152593 510 screen 511 | n03793489 511 mouse 512 | n03271574 512 electric_fan 513 | n03843555 513 oil_filter 514 | n04332243 514 strainer 515 | n04265275 515 space_heater 516 | n04330267 516 stove 517 | n03467068 517 guillotine 518 | n02794156 518 barometer 519 | n04118776 519 rule 520 | n03841143 520 
odometer 521 | n04141975 521 scale 522 | n02708093 522 analog_clock 523 | n03196217 523 digital_clock 524 | n04548280 524 wall_clock 525 | n03544143 525 hourglass 526 | n04355338 526 sundial 527 | n03891332 527 parking_meter 528 | n04328186 528 stopwatch 529 | n03197337 529 digital_watch 530 | n04317175 530 stethoscope 531 | n04376876 531 syringe 532 | n03706229 532 magnetic_compass 533 | n02841315 533 binoculars 534 | n04009552 534 projector 535 | n04356056 535 sunglasses 536 | n03692522 536 loupe 537 | n04044716 537 radio_telescope 538 | n02879718 538 bow 539 | n02950826 539 cannon 540 | n02749479 540 assault_rifle 541 | n04090263 541 rifle 542 | n04008634 542 projectile 543 | n03085013 543 computer_keyboard 544 | n04505470 544 typewriter_keyboard 545 | n03126707 545 crane 546 | n03666591 546 lighter 547 | n02666196 547 abacus 548 | n02977058 548 cash_machine 549 | n04238763 549 slide_rule 550 | n03180011 550 desktop_computer 551 | n03485407 551 hand-held_computer 552 | n03832673 552 notebook 553 | n06359193 553 web_site 554 | n03496892 554 harvester 555 | n04428191 555 thresher 556 | n04004767 556 printer 557 | n04243546 557 slot 558 | n04525305 558 vending_machine 559 | n04179913 559 sewing_machine 560 | n03602883 560 joystick 561 | n04372370 561 switch 562 | n03532672 562 hook 563 | n02974003 563 car_wheel 564 | n03874293 564 paddlewheel 565 | n03944341 565 pinwheel 566 | n03992509 566 potter's_wheel 567 | n03425413 567 gas_pump 568 | n02966193 568 carousel 569 | n04371774 569 swing 570 | n04067472 570 reel 571 | n04040759 571 radiator 572 | n04019541 572 puck 573 | n03492542 573 hard_disc 574 | n04355933 574 sunglass 575 | n03929660 575 pick 576 | n02965783 576 car_mirror 577 | n04258138 577 solar_dish 578 | n04074963 578 remote_control 579 | n03208938 579 disk_brake 580 | n02910353 580 buckle 581 | n03476684 581 hair_slide 582 | n03627232 582 knot 583 | n03075370 583 combination_lock 584 | n03874599 584 padlock 585 | n03804744 585 nail 586 | n04127249 586 
safety_pin 587 | n04153751 587 screw 588 | n03803284 588 muzzle 589 | n04162706 589 seat_belt 590 | n04228054 590 ski 591 | n02948072 591 candle 592 | n03590841 592 jack-o'-lantern 593 | n04286575 593 spotlight 594 | n04456115 594 torch 595 | n03814639 595 neck_brace 596 | n03933933 596 pier 597 | n04485082 597 tripod 598 | n03733131 598 maypole 599 | n03794056 599 mousetrap 600 | n04275548 600 spider_web 601 | n01768244 601 trilobite 602 | n01770081 602 harvestman 603 | n01770393 603 scorpion 604 | n01773157 604 black_and_gold_garden_spider 605 | n01773549 605 barn_spider 606 | n01773797 606 garden_spider 607 | n01774384 607 black_widow 608 | n01774750 608 tarantula 609 | n01775062 609 wolf_spider 610 | n01776313 610 tick 611 | n01784675 611 centipede 612 | n01990800 612 isopod 613 | n01978287 613 Dungeness_crab 614 | n01978455 614 rock_crab 615 | n01980166 615 fiddler_crab 616 | n01981276 616 king_crab 617 | n01983481 617 American_lobster 618 | n01984695 618 spiny_lobster 619 | n01985128 619 crayfish 620 | n01986214 620 hermit_crab 621 | n02165105 621 tiger_beetle 622 | n02165456 622 ladybug 623 | n02167151 623 ground_beetle 624 | n02168699 624 long-horned_beetle 625 | n02169497 625 leaf_beetle 626 | n02172182 626 dung_beetle 627 | n02174001 627 rhinoceros_beetle 628 | n02177972 628 weevil 629 | n02190166 629 fly 630 | n02206856 630 bee 631 | n02226429 631 grasshopper 632 | n02229544 632 cricket 633 | n02231487 633 walking_stick 634 | n02233338 634 cockroach 635 | n02236044 635 mantis 636 | n02256656 636 cicada 637 | n02259212 637 leafhopper 638 | n02264363 638 lacewing 639 | n02268443 639 dragonfly 640 | n02268853 640 damselfly 641 | n02276258 641 admiral 642 | n02277742 642 ringlet 643 | n02279972 643 monarch 644 | n02280649 644 cabbage_butterfly 645 | n02281406 645 sulphur_butterfly 646 | n02281787 646 lycaenid_butterfly 647 | n01910747 647 jellyfish 648 | n01914609 648 sea_anemone 649 | n01917289 649 brain_coral 650 | n01924916 650 flatworm 651 | n01930112 
651 nematode 652 | n01943899 652 conch 653 | n01944390 653 snail 654 | n01945685 654 slug 655 | n01950731 655 sea_slug 656 | n01955084 656 chiton 657 | n02319095 657 sea_urchin 658 | n02321529 658 sea_cucumber 659 | n03584829 659 iron 660 | n03297495 660 espresso_maker 661 | n03761084 661 microwave 662 | n03259280 662 Dutch_oven 663 | n04111531 663 rotisserie 664 | n04442312 664 toaster 665 | n04542943 665 waffle_iron 666 | n04517823 666 vacuum 667 | n03207941 667 dishwasher 668 | n04070727 668 refrigerator 669 | n04554684 669 washer 670 | n03133878 670 Crock_Pot 671 | n03400231 671 frying_pan 672 | n04596742 672 wok 673 | n02939185 673 caldron 674 | n03063689 674 coffeepot 675 | n04398044 675 teapot 676 | n04270147 676 spatula 677 | n02699494 677 altar 678 | n04486054 678 triumphal_arch 679 | n03899768 679 patio 680 | n04311004 680 steel_arch_bridge 681 | n04366367 681 suspension_bridge 682 | n04532670 682 viaduct 683 | n02793495 683 barn 684 | n03457902 684 greenhouse 685 | n03877845 685 palace 686 | n03781244 686 monastery 687 | n03661043 687 library 688 | n02727426 688 apiary 689 | n02859443 689 boathouse 690 | n03028079 690 church 691 | n03788195 691 mosque 692 | n04346328 692 stupa 693 | n03956157 693 planetarium 694 | n04081281 694 restaurant 695 | n03032252 695 cinema 696 | n03529860 696 home_theater 697 | n03697007 697 lumbermill 698 | n03065424 698 coil 699 | n03837869 699 obelisk 700 | n04458633 700 totem_pole 701 | n02980441 701 castle 702 | n04005630 702 prison 703 | n03461385 703 grocery_store 704 | n02776631 704 bakery 705 | n02791270 705 barbershop 706 | n02871525 706 bookshop 707 | n02927161 707 butcher_shop 708 | n03089624 708 confectionery 709 | n04200800 709 shoe_shop 710 | n04443257 710 tobacco_shop 711 | n04462240 711 toyshop 712 | n03388043 712 fountain 713 | n03042490 713 cliff_dwelling 714 | n04613696 714 yurt 715 | n03216828 715 dock 716 | n02892201 716 brass 717 | n03743016 717 megalith 718 | n02788148 718 bannister 719 | n02894605 719 
breakwater 720 | n03160309 720 dam 721 | n03000134 721 chainlink_fence 722 | n03930313 722 picket_fence 723 | n04604644 723 worm_fence 724 | n04326547 724 stone_wall 725 | n03459775 725 grille 726 | n04239074 726 sliding_door 727 | n04501370 727 turnstile 728 | n03792972 728 mountain_tent 729 | n04149813 729 scoreboard 730 | n03530642 730 honeycomb 731 | n03961711 731 plate_rack 732 | n03903868 732 pedestal 733 | n02814860 733 beacon 734 | n07711569 734 mashed_potato 735 | n07720875 735 bell_pepper 736 | n07714571 736 head_cabbage 737 | n07714990 737 broccoli 738 | n07715103 738 cauliflower 739 | n07716358 739 zucchini 740 | n07716906 740 spaghetti_squash 741 | n07717410 741 acorn_squash 742 | n07717556 742 butternut_squash 743 | n07718472 743 cucumber 744 | n07718747 744 artichoke 745 | n07730033 745 cardoon 746 | n07734744 746 mushroom 747 | n04209239 747 shower_curtain 748 | n03594734 748 jean 749 | n02971356 749 carton 750 | n03485794 750 handkerchief 751 | n04133789 751 sandal 752 | n02747177 752 ashcan 753 | n04125021 753 safe 754 | n07579787 754 plate 755 | n03814906 755 necklace 756 | n03134739 756 croquet_ball 757 | n03404251 757 fur_coat 758 | n04423845 758 thimble 759 | n03877472 759 pajama 760 | n04120489 760 running_shoe 761 | n03062245 761 cocktail_shaker 762 | n03014705 762 chest 763 | n03717622 763 manhole_cover 764 | n03777754 764 modem 765 | n04493381 765 tub 766 | n04476259 766 tray 767 | n02777292 767 balance_beam 768 | n07693725 768 bagel 769 | n03998194 769 prayer_rug 770 | n03617480 770 kimono 771 | n07590611 771 hot_pot 772 | n04579145 772 whiskey_jug 773 | n03623198 773 knee_pad 774 | n07248320 774 book_jacket 775 | n04277352 775 spindle 776 | n04229816 776 ski_mask 777 | n02823428 777 beer_bottle 778 | n03127747 778 crash_helmet 779 | n02877765 779 bottlecap 780 | n04435653 780 tile_roof 781 | n03724870 781 mask 782 | n03710637 782 maillot 783 | n03920288 783 Petri_dish 784 | n03379051 784 football_helmet 785 | n02807133 785 bathing_cap 
786 | n04399382 786 teddy 787 | n03527444 787 holster 788 | n03983396 788 pop_bottle 789 | n03924679 789 photocopier 790 | n04532106 790 vestment 791 | n06785654 791 crossword_puzzle 792 | n03445777 792 golf_ball 793 | n07613480 793 trifle 794 | n04350905 794 suit 795 | n04562935 795 water_tower 796 | n03325584 796 feather_boa 797 | n03045698 797 cloak 798 | n07892512 798 red_wine 799 | n03250847 799 drumstick 800 | n04192698 800 shield 801 | n03026506 801 Christmas_stocking 802 | n03534580 802 hoopskirt 803 | n07565083 803 menu 804 | n04296562 804 stage 805 | n02869837 805 bonnet 806 | n07871810 806 meat_loaf 807 | n02799071 807 baseball 808 | n03314780 808 face_powder 809 | n04141327 809 scabbard 810 | n04357314 810 sunscreen 811 | n02823750 811 beer_glass 812 | n13052670 812 hen_of_the_woods 813 | n07583066 813 guacamole 814 | n03637318 814 lampshade 815 | n04599235 815 wool 816 | n07802026 816 hay 817 | n02883205 817 bow_tie 818 | n03709823 818 mailbag 819 | n04560804 819 water_jug 820 | n02909870 820 bucket 821 | n03207743 821 dishrag 822 | n04263257 822 soup_bowl 823 | n07932039 823 eggnog 824 | n03786901 824 mortar 825 | n04479046 825 trench_coat 826 | n03873416 826 paddle 827 | n02999410 827 chain 828 | n04367480 828 swab 829 | n03775546 829 mixing_bowl 830 | n07875152 830 potpie 831 | n04591713 831 wine_bottle 832 | n04201297 832 shoji 833 | n02916936 833 bulletproof_vest 834 | n03240683 834 drilling_platform 835 | n02840245 835 binder 836 | n02963159 836 cardigan 837 | n04370456 837 sweatshirt 838 | n03991062 838 pot 839 | n02843684 839 birdhouse 840 | n03482405 840 hamper 841 | n03942813 841 ping-pong_ball 842 | n03908618 842 pencil_box 843 | n03902125 843 pay-phone 844 | n07584110 844 consomme 845 | n02730930 845 apron 846 | n04023962 846 punching_bag 847 | n02769748 847 backpack 848 | n10148035 848 groom 849 | n02817516 849 bearskin 850 | n03908714 850 pencil_sharpener 851 | n02906734 851 broom 852 | n03788365 852 mosquito_net 853 | n02667093 853 abaya 
854 | n03787032 854 mortarboard 855 | n03980874 855 poncho 856 | n03141823 856 crutch 857 | n03976467 857 Polaroid_camera 858 | n04264628 858 space_bar 859 | n07930864 859 cup 860 | n04039381 860 racket 861 | n06874185 861 traffic_light 862 | n04033901 862 quill 863 | n04041544 863 radio 864 | n07860988 864 dough 865 | n03146219 865 cuirass 866 | n03763968 866 military_uniform 867 | n03676483 867 lipstick 868 | n04209133 868 shower_cap 869 | n03782006 869 monitor 870 | n03857828 870 oscilloscope 871 | n03775071 871 mitten 872 | n02892767 872 brassiere 873 | n07684084 873 French_loaf 874 | n04522168 874 vase 875 | n03764736 875 milk_can 876 | n04118538 876 rugby_ball 877 | n03887697 877 paper_towel 878 | n13044778 878 earthstar 879 | n03291819 879 envelope 880 | n03770439 880 miniskirt 881 | n03124170 881 cowboy_hat 882 | n04487081 882 trolleybus 883 | n03916031 883 perfume 884 | n02808440 884 bathtub 885 | n07697537 885 hotdog 886 | n12985857 886 coral_fungus 887 | n02917067 887 bullet_train 888 | n03938244 888 pillow 889 | n15075141 889 toilet_tissue 890 | n02978881 890 cassette 891 | n02966687 891 carpenter's_kit 892 | n03633091 892 ladle 893 | n13040303 893 stinkhorn 894 | n03690938 894 lotion 895 | n03476991 895 hair_spray 896 | n02669723 896 academic_gown 897 | n03220513 897 dome 898 | n03127925 898 crate 899 | n04584207 899 wig 900 | n07880968 900 burrito 901 | n03937543 901 pill_bottle 902 | n03000247 902 chain_mail 903 | n04418357 903 theater_curtain 904 | n04590129 904 window_shade 905 | n02795169 905 barrel 906 | n04553703 906 washbasin 907 | n02783161 907 ballpoint 908 | n02802426 908 basketball 909 | n02808304 909 bath_towel 910 | n03124043 910 cowboy_boot 911 | n03450230 911 gown 912 | n04589890 912 window_screen 913 | n12998815 913 agaric 914 | n02992529 914 cellular_telephone 915 | n03825788 915 nipple 916 | n02790996 916 barbell 917 | n03710193 917 mailbox 918 | n03630383 918 lab_coat 919 | n03347037 919 fire_screen 920 | n03769881 920 minibus 921 | 
n03871628 921 packet 922 | n03733281 922 maze 923 | n03976657 923 pole 924 | n03535780 924 horizontal_bar 925 | n04259630 925 sombrero 926 | n03929855 926 pickelhaube 927 | n04049303 927 rain_barrel 928 | n04548362 928 wallet 929 | n02979186 929 cassette_player 930 | n06596364 930 comic_book 931 | n03935335 931 piggy_bank 932 | n06794110 932 street_sign 933 | n02825657 933 bell_cote 934 | n03388183 934 fountain_pen 935 | n04591157 935 Windsor_tie 936 | n04540053 936 volleyball 937 | n03866082 937 overskirt 938 | n04136333 938 sarong 939 | n04026417 939 purse 940 | n02865351 940 bolo_tie 941 | n02834397 941 bib 942 | n03888257 942 parachute 943 | n04235860 943 sleeping_bag 944 | n04404412 944 television 945 | n04371430 945 swimming_trunks 946 | n03733805 946 measuring_cup 947 | n07920052 947 espresso 948 | n07873807 948 pizza 949 | n02895154 949 breastplate 950 | n04204238 950 shopping_basket 951 | n04597913 951 wooden_spoon 952 | n04131690 952 saltshaker 953 | n07836838 953 chocolate_sauce 954 | n09835506 954 ballplayer 955 | n03443371 955 goblet 956 | n13037406 956 gyromitra 957 | n04336792 957 stretcher 958 | n04557648 958 water_bottle 959 | n03187595 959 dial_telephone 960 | n04254120 960 soap_dispenser 961 | n03595614 961 jersey 962 | n04146614 962 school_bus 963 | n03598930 963 jigsaw_puzzle 964 | n03958227 964 plastic_bag 965 | n04069434 965 reflex_camera 966 | n03188531 966 diaper 967 | n02786058 967 Band_Aid 968 | n07615774 968 ice_lolly 969 | n04525038 969 velvet 970 | n04409515 970 tennis_ball 971 | n03424325 971 gasmask 972 | n03223299 972 doormat 973 | n03680355 973 Loafer 974 | n07614500 974 ice_cream 975 | n07695742 975 pretzel 976 | n04033995 976 quilt 977 | n03710721 977 maillot 978 | n04392985 978 tape_player 979 | n03047690 979 clog 980 | n03584254 980 iPod 981 | n13054560 981 bolete 982 | n10565667 982 scuba_diver 983 | n03950228 983 pitcher 984 | n03729826 984 matchstick 985 | n02837789 985 bikini 986 | n04254777 986 sock 987 | n02988304 987 
CD_player 988 | n03657121 988 lens_cap 989 | n04417672 989 thatch 990 | n04523525 990 vault 991 | n02815834 991 beaker 992 | n09229709 992 bubble 993 | n07697313 993 cheeseburger 994 | n03888605 994 parallel_bars 995 | n03355925 995 flagpole 996 | n03063599 996 coffee_mug 997 | n04116512 997 rubber_eraser 998 | n04325704 998 stole 999 | n07831146 999 carbonara 1000 | n03255030 1000 dumbbell -------------------------------------------------------------------------------- /protonet++.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Training and evaluation entry point.""" 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import numpy as np 10 | import argparse 11 | import tensorflow as tf 12 | import tensorflow.contrib.slim as slim 13 | from tensorflow.python.ops import init_ops 14 | from tensorflow.python.ops import random_ops 15 | from tensorflow.python.framework import dtypes 16 | from scipy.spatial import KDTree 17 | from common.util import Dataset 18 | from common.util import ACTIVATION_MAP 19 | from tqdm import trange 20 | import pathlib 21 | import logging 22 | from common.util import summary_writer 23 | from common.gen_experiments import load_and_save_params 24 | import time 25 | 26 | tf.logging.set_verbosity(tf.logging.INFO) 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | 30 | def _load_mini_imagenet(data_dir, split): 31 | """Load mini-imagenet from numpy's npz file format.""" 32 | _split_tag = {'sources': 'train', 'target_val': 'val', 'target_tst': 'test'}[split] 33 | dataset_path = os.path.join(data_dir, 'few-shot-{}.npz'.format(_split_tag)) 34 | logging.info("Loading mini-imagenet...") 35 | data = np.load(dataset_path) 36 | fields = data['features'], data['targets'] 37 | logging.info("Done loading.") 38 | return fields 39 | 40 | def get_image_size(data_dir): 41 | if 'mini-imagenet' in 
data_dir: 42 | image_size = 84 43 | elif 'cifar' in data_dir: 44 | image_size = 32 45 | else: 46 | raise Exception('Unknown dataset: %s' % data_dir) 47 | return image_size 48 | 49 | 50 | class Namespace(object): 51 | def __init__(self, adict): 52 | self.__dict__.update(adict) 53 | 54 | 55 | def get_arguments(): 56 | parser = argparse.ArgumentParser() 57 | 58 | parser.add_argument('--mode', type=str, default='train', 59 | choices=['train', 'eval', 'test', 'train_classifier', 'create_embedding']) 60 | # Dataset parameters 61 | parser.add_argument('--data_dir', type=str, default=None, help='Path to the data.') 62 | parser.add_argument('--data_split', type=str, default='sources', choices=['sources', 'target_val', 'target_tst'], 63 | help='Split of the data to be used to perform operation.') 64 | 65 | # Training parameters 66 | parser.add_argument('--number_of_steps', type=int, default=int(30000), 67 | help="Number of training steps (number of Epochs in Hugo's paper)") 68 | parser.add_argument('--number_of_steps_to_early_stop', type=int, default=int(1000000), 69 | help="Number of training steps after half way to early stop the training") 70 | parser.add_argument('--log_dir', type=str, default='', help='Base log dir') 71 | parser.add_argument('--exp_dir', type=str, default=None, help='experiement directory for Borgy') 72 | parser.add_argument('--num_classes_train', type=int, default=5, 73 | help='Number of classes in the train phase, this is coming from the prototypical networks') 74 | parser.add_argument('--num_shots_train', type=int, default=5, 75 | help='Number of shots in a few shot meta-train scenario') 76 | parser.add_argument('--num_samples_train', type=int, default=38400, help='Number of train samples.') 77 | parser.add_argument('--train_batch_size', type=int, default=32, help='Training batch size.') 78 | parser.add_argument('--num_tasks_per_batch', type=int, default=2, 79 | help='Number of few shot tasks per batch, so the task encoding batch is 
num_tasks_per_batch x num_classes_test x num_shots_train .') 80 | parser.add_argument('--init_learning_rate', type=float, default=0.1, help='Initial learning rate.') 81 | parser.add_argument('--save_summaries_secs', type=int, default=60, help='Time between saving summaries') 82 | parser.add_argument('--save_interval_secs', type=int, default=60, help='Time between saving model?') 83 | parser.add_argument('--optimizer', type=str, default='sgd', choices=['sgd', 'adam']) 84 | parser.add_argument('--augment', type=bool, default=False) 85 | # Learning rate paramteres 86 | parser.add_argument('--lr_anneal', type=str, default='pwc', choices=['const', 'pwc', 'cos', 'exp']) 87 | parser.add_argument('--n_lr_decay', type=int, default=3) 88 | parser.add_argument('--lr_decay_rate', type=float, default=10.0) 89 | parser.add_argument('--num_steps_decay_pwc', type=int, default=2500, 90 | help='Decay learning rate every num_steps_decay_pwc') 91 | 92 | parser.add_argument('--clip_gradient_norm', type=float, default=1.0, help='gradient clip norm.') 93 | parser.add_argument('--weights_initializer_factor', type=float, default=0.1, 94 | help='multiplier in the variance of the initialization noise.') 95 | # Evaluation parameters 96 | parser.add_argument('--max_number_of_evaluations', type=float, default=float('inf')) 97 | parser.add_argument('--eval_interval_secs', type=int, default=120, help='Time between evaluating model?') 98 | parser.add_argument('--eval_interval_steps', type=int, default=1000, 99 | help='Number of train steps between evaluating model in the training loop') 100 | parser.add_argument('--eval_interval_fine_steps', type=int, default=250, 101 | help='Number of train steps between evaluating model in the training loop in the final phase') 102 | # parser.add_argument('--num_samples_eval', type=int, default=12000, help='Number of evaluation samples?') 103 | # parser.add_argument('--eval_batch_size', type=int, default=100, help='Evaluation batch size?') 104 | # 
parser.add_argument('--num_evals', type=int, default=100, help='Number of evaluations in the evaluation phase') 105 | # Test parameters 106 | parser.add_argument('--num_classes_test', type=int, default=5, help='Number of classes in the test phase') 107 | parser.add_argument('--num_shots_test', type=int, default=5, 108 | help='Number of shots in a few shot meta-test scenario') 109 | parser.add_argument('--num_cases_test', type=int, default=50000, 110 | help='Number of few-shot cases to compute test accuracy') 111 | # Architecture parameters 112 | parser.add_argument('--dropout', type=float, default=1.0) 113 | parser.add_argument('--fc_dropout', type=float, default=None, help='Dropout before the final fully connected layer') 114 | parser.add_argument('--weight_decay', type=float, default=0.0005) 115 | parser.add_argument('--num_filters', type=int, default=64) 116 | parser.add_argument('--num_units_in_block', type=int, default=3) 117 | parser.add_argument('--num_blocks', type=int, default=4) 118 | parser.add_argument('--num_max_pools', type=int, default=3) 119 | parser.add_argument('--block_size_growth', type=float, default=2.0) 120 | parser.add_argument('--activation', type=str, default='swish-1', choices=['relu', 'selu', 'swish-1']) 121 | 122 | parser.add_argument('--feature_dropout_p', type=float, default=None) 123 | parser.add_argument('--feature_expansion_size', type=int, default=None) 124 | parser.add_argument('--feature_bottleneck_size', type=int, default=None) 125 | 126 | parser.add_argument('--feature_extractor', type=str, default='simple_res_net', 127 | choices=['simple_conv_net', 'simple_res_net', 'res_net', 'dense_net', 'residense_net', 128 | 'res_net_34'], help='Which feature extractor to use') 129 | # Feature extractor pretraining parameters (auxiliary 64-classification task) 130 | parser.add_argument('--feat_extract_pretrain', type=str, default=None, 131 | choices=[None, 'finetune', 'freeze', 'multitask'], 132 | help='Whether or not pretrain the feature 
extractor') 133 | 134 | 135 | parser.add_argument('--encoder_sharing', type=str, default='shared', 136 | choices=['shared', 'siamese'], 137 | help='How to link fetaure extractors in task encoder and classifier') 138 | parser.add_argument('--encoder_classifier_link', type=str, default='prototypical', 139 | choices=['attention', 'cbn', 'prototypical', 'std_normalized_euc_head', 140 | 'self_attention_euclidian', 141 | 'cosine', 'polynomial', 'perceptron', 'cbn_cos'], 142 | help='How to link fetaure extractors in task encoder and classifier') 143 | parser.add_argument('--embedding_pooled', type=bool, default=True, 144 | help='Whether to use avg pooling to create embedding') 145 | parser.add_argument('--task_encoder', type=str, default='class_mean', 146 | choices=['talkthrough', 'class_mean', 'label_embed', 'self_attention']) 147 | 148 | parser.add_argument('--conv_dropout', type=float, default=None) 149 | # 150 | parser.add_argument('--num_batches_neg_mining', type=int, default=0) 151 | parser.add_argument('--eval_batch_size', type=int, default=100, help='Evaluation batch size?') 152 | 153 | args = parser.parse_args() 154 | 155 | print(args) 156 | return args 157 | 158 | 159 | def get_logdir_name(flags): 160 | """Generates the name of the log directory from the values of flags 161 | Parameters 162 | ---------- 163 | flags: neural net architecture generated by get_arguments() 164 | Outputs 165 | ------- 166 | the name of the directory to store the training and evaluation results 167 | """ 168 | epochs = (flags.number_of_steps * flags.train_batch_size) / flags.num_samples_train 169 | 170 | param_list = ['batch_size', str(flags.train_batch_size), 'num_tasks', str(flags.num_tasks_per_batch), 'lr', 171 | str(flags.init_learning_rate), 'lr_anneal', flags.lr_anneal, 172 | 'epochs', str(epochs), 'dropout', str(flags.dropout), 'opt', flags.optimizer, 173 | 'weight_decay', str(flags.weight_decay), 174 | 'nfilt', str(flags.num_filters), 'feature_extractor', 
str(flags.feature_extractor), 175 | 'task_encoder', str(flags.task_encoder), 176 | 'enc_cl_link', flags.encoder_classifier_link] 177 | 178 | if flags.log_dir == '': 179 | logdir = './logs1/' + '-'.join(param_list) 180 | else: 181 | logdir = os.path.join(flags.log_dir, '-'.join(param_list)) 182 | 183 | if flags.exp_dir is not None: 184 | # Running a Borgy experiment 185 | logdir = flags.exp_dir 186 | 187 | return logdir 188 | 189 | 190 | class ScaledVarianceRandomNormal(init_ops.Initializer): 191 | """Initializer that generates tensors with a normal distribution scaled as per https://arxiv.org/pdf/1502.01852.pdf. 192 | Args: 193 | mean: a python scalar or a scalar tensor. Mean of the random values 194 | to generate. 195 | stddev: a python scalar or a scalar tensor. Standard deviation of the 196 | random values to generate. 197 | seed: A Python integer. Used to create random seeds. See 198 | @{tf.set_random_seed} 199 | for behavior. 200 | dtype: The data type. Only floating point types are supported. 
201 | """ 202 | 203 | def __init__(self, mean=0.0, factor=1.0, seed=None, dtype=dtypes.float32): 204 | self.mean = mean 205 | self.factor = factor 206 | self.seed = seed 207 | self.dtype = dtypes.as_dtype(dtype) 208 | 209 | def __call__(self, shape, dtype=None, partition_info=None): 210 | if dtype is None: 211 | dtype = self.dtype 212 | 213 | if shape: 214 | n = float(shape[-1]) 215 | else: 216 | n = 1.0 217 | for dim in shape[:-2]: 218 | n *= float(dim) 219 | 220 | self.stddev = np.sqrt(self.factor * 2.0 / n) 221 | return random_ops.random_normal(shape, self.mean, self.stddev, 222 | dtype, seed=self.seed) 223 | 224 | 225 | def _get_scope(is_training, flags): 226 | normalizer_params = { 227 | 'epsilon': 0.001, 228 | 'momentum': .95, 229 | 'trainable': is_training, 230 | 'training': is_training, 231 | } 232 | conv2d_arg_scope = slim.arg_scope( 233 | [slim.conv2d, slim.fully_connected], 234 | activation_fn=ACTIVATION_MAP[flags.activation], 235 | normalizer_fn=tf.layers.batch_normalization, 236 | normalizer_params=normalizer_params, 237 | # padding='SAME', 238 | trainable=is_training, 239 | weights_regularizer=tf.contrib.layers.l2_regularizer(scale=flags.weight_decay), 240 | weights_initializer=ScaledVarianceRandomNormal(factor=flags.weights_initializer_factor), 241 | biases_initializer=tf.constant_initializer(0.0) 242 | ) 243 | dropout_arg_scope = slim.arg_scope( 244 | [slim.dropout], 245 | keep_prob=flags.dropout, 246 | is_training=is_training) 247 | return conv2d_arg_scope, dropout_arg_scope 248 | 249 | 250 | def build_simple_conv_net(images, flags, is_training, reuse=None, scope=None): 251 | conv2d_arg_scope, dropout_arg_scope = _get_scope(is_training, flags) 252 | with conv2d_arg_scope, dropout_arg_scope: 253 | with tf.variable_scope(scope or 'feature_extractor', reuse=reuse): 254 | h = images 255 | for i in range(4): 256 | h = slim.conv2d(h, num_outputs=flags.num_filters, kernel_size=3, stride=1, 257 | scope='conv' + str(i), padding='SAME', 258 | 
weights_initializer=ScaledVarianceRandomNormal(factor=flags.weights_initializer_factor)) 259 | h = slim.max_pool2d(h, kernel_size=2, stride=2, padding='VALID', scope='max_pool' + str(i)) 260 | 261 | if flags.embedding_pooled == True: 262 | kernel_size = h.shape.as_list()[-2] 263 | h = slim.avg_pool2d(h, kernel_size=kernel_size, scope='avg_pool') 264 | h = slim.flatten(h) 265 | return h 266 | 267 | 268 | def leaky_relu(x, alpha=0.1, name=None): 269 | return tf.maximum(x, alpha * x, name=name) 270 | 271 | 272 | 273 | 274 | def build_simple_res_net(images, flags, num_filters, beta=None, gamma=None, is_training=False, reuse=None, scope=None): 275 | conv2d_arg_scope, dropout_arg_scope = _get_scope(is_training, flags) 276 | activation_fn = ACTIVATION_MAP[flags.activation] 277 | with conv2d_arg_scope, dropout_arg_scope: 278 | with tf.variable_scope(scope or 'feature_extractor', reuse=reuse): 279 | # h = slim.conv2d(images, num_outputs=num_filters[0], kernel_size=6, stride=1, 280 | # scope='conv_input', padding='SAME') 281 | # h = slim.max_pool2d(h, kernel_size=2, stride=2, padding='SAME', scope='max_pool_input') 282 | h = images 283 | for i in range(len(num_filters)): 284 | # make shortcut 285 | shortcut = slim.conv2d(h, num_outputs=num_filters[i], kernel_size=1, stride=1, 286 | activation_fn=None, 287 | scope='shortcut' + str(i), padding='SAME') 288 | 289 | for j in range(flags.num_units_in_block): 290 | h = slim.conv2d(h, num_outputs=num_filters[i], kernel_size=3, stride=1, 291 | scope='conv' + str(i) + '_' + str(j), padding='SAME', activation_fn=None) 292 | if flags.conv_dropout: 293 | h = slim.dropout(h, keep_prob=1.0 - flags.conv_dropout) 294 | 295 | if j < (flags.num_units_in_block - 1): 296 | h = activation_fn(h, name='activation_' + str(i) + '_' + str(j)) 297 | h = h + shortcut 298 | 299 | h = activation_fn(h, name='activation_' + str(i) + '_' + str(flags.num_units_in_block - 1)) 300 | if i < flags.num_max_pools: 301 | h = slim.max_pool2d(h, kernel_size=2, 
stride=2, padding='SAME', scope='max_pool' + str(i)) 302 | 303 | if flags.feature_expansion_size: 304 | if flags.feature_dropout_p: 305 | h = slim.dropout(h, scope='feature_expansion_dropout', keep_prob=1.0 - flags.feature_dropout_p) 306 | h = slim.conv2d(slim.dropout(h), num_outputs=flags.feature_expansion_size, kernel_size=1, stride=1, 307 | scope='feature_expansion', padding='SAME') 308 | 309 | if flags.embedding_pooled == True: 310 | kernel_size = h.shape.as_list()[-2] 311 | h = slim.avg_pool2d(h, kernel_size=kernel_size, scope='avg_pool') 312 | h = slim.flatten(h) 313 | 314 | if flags.feature_dropout_p: 315 | h = slim.dropout(h, scope='feature_bottleneck_dropout', keep_prob=1.0 - flags.feature_dropout_p) 316 | # Bottleneck layer 317 | if flags.feature_bottleneck_size: 318 | h = slim.fully_connected(h, num_outputs=flags.feature_bottleneck_size, 319 | activation_fn=activation_fn, normalizer_fn=None, 320 | scope='feature_bottleneck') 321 | 322 | return h 323 | 324 | 325 | def get_res_net_block(h, flags, num_filters, num_units, pool=False, beta=None, gamma=None, is_training=False, 326 | reuse=None, scope=None): 327 | conv2d_arg_scope, dropout_arg_scope = _get_scope(is_training, flags) 328 | activation_fn = ACTIVATION_MAP[flags.activation] 329 | with conv2d_arg_scope, dropout_arg_scope: 330 | with tf.variable_scope(scope, reuse=reuse): 331 | # make shortcut 332 | shortcut = slim.conv2d(h, num_outputs=num_filters, kernel_size=1, stride=1, 333 | activation_fn=None, 334 | scope='shortcut', padding='SAME') 335 | 336 | for j in range(num_units): 337 | h = slim.conv2d(h, num_outputs=num_filters, kernel_size=3, stride=1, 338 | scope='conv_' + str(j), padding='SAME', activation_fn=None) 339 | if flags.conv_dropout: 340 | h = slim.dropout(h, keep_prob=1.0 - flags.conv_dropout) 341 | if j < (num_units - 1): 342 | h = activation_fn(h, name='activation_' + str(j)) 343 | h = h + shortcut 344 | h = activation_fn(h, name='activation_' + '_' + str(flags.num_units_in_block - 1)) 
345 | if pool: 346 | h = slim.max_pool2d(h, kernel_size=2, stride=2, padding='SAME', scope='max_pool') 347 | return h 348 | 349 | 350 | 351 | def build_feature_extractor_graph(images, flags, num_filters, beta=None, gamma=None, is_training=False, 352 | scope='feature_extractor_task_encoder', reuse=None, is_64way=False): 353 | if flags.feature_extractor == 'simple_conv_net': 354 | h = build_simple_conv_net(images, flags=flags, is_training=is_training, reuse=reuse, scope=scope) 355 | elif flags.feature_extractor == 'simple_res_net': 356 | h = build_simple_res_net(images, flags=flags, num_filters=num_filters, beta=beta, gamma=gamma, 357 | is_training=is_training, reuse=reuse, scope=scope) 358 | else: 359 | h = None 360 | 361 | embedding_shape = h.get_shape().as_list() 362 | if is_training and is_64way is False: 363 | h = tf.reshape(h, shape=(flags.num_tasks_per_batch, embedding_shape[0] // flags.num_tasks_per_batch, -1), 364 | name='reshape_to_separate_tasks_generic_features') 365 | else: 366 | h = tf.reshape(h, shape=(1, embedding_shape[0], -1), 367 | name='reshape_to_separate_tasks_generic_features') 368 | 369 | return h 370 | 371 | 372 | 373 | def build_task_encoder(embeddings, labels, flags, is_training, reuse=None, scope='class_encoder'): 374 | conv2d_arg_scope, dropout_arg_scope = _get_scope(is_training, flags) 375 | 376 | with conv2d_arg_scope, dropout_arg_scope: 377 | with tf.variable_scope(scope, reuse=reuse): 378 | 379 | if flags.task_encoder == 'talkthrough': 380 | task_encoding = embeddings 381 | elif flags.task_encoder == 'class_mean': 382 | task_encoding = embeddings 383 | 384 | if is_training: 385 | task_encoding = tf.reshape(task_encoding, shape=( 386 | flags.num_tasks_per_batch, flags.num_classes_train, flags.num_shots_train, -1), 387 | name='reshape_to_separate_tasks_task_encoding') 388 | else: 389 | task_encoding = tf.reshape(task_encoding, 390 | shape=(1, flags.num_classes_test, flags.num_shots_test, -1), 391 | 
name='reshape_to_separate_tasks_task_encoding') 392 | task_encoding = tf.reduce_mean(task_encoding, axis=2, keep_dims=False) 393 | else: 394 | task_encoding = None 395 | 396 | return task_encoding 397 | 398 | 399 | def build_prototypical_head(features_generic, task_encoding, flags, is_training, scope='prototypical_head'): 400 | """ 401 | Implements the prototypical networks few-shot head 402 | :param features_generic: 403 | :param task_encoding: 404 | :param flags: 405 | :param is_training: 406 | :param reuse: 407 | :param scope: 408 | :return: 409 | """ 410 | 411 | with tf.variable_scope(scope): 412 | 413 | if len(features_generic.get_shape().as_list()) == 2: 414 | features_generic = tf.expand_dims(features_generic, axis=0) 415 | if len(task_encoding.get_shape().as_list()) == 2: 416 | task_encoding = tf.expand_dims(task_encoding, axis=0) 417 | 418 | # i is the number of steps in the task_encoding sequence 419 | # j is the number of steps in the features_generic sequence 420 | j = task_encoding.get_shape().as_list()[1] 421 | i = features_generic.get_shape().as_list()[1] 422 | 423 | # tile to be able to produce weight matrix alpha in (i,j) space 424 | features_generic = tf.expand_dims(features_generic, axis=2) 425 | task_encoding = tf.expand_dims(task_encoding, axis=1) 426 | # features_generic changes over i and is constant over j 427 | # task_encoding changes over j and is constant over i 428 | task_encoding_tile = tf.tile(task_encoding, (1, i, 1, 1)) 429 | features_generic_tile = tf.tile(features_generic, (1, 1, j, 1)) 430 | # implement equation (4) 431 | euclidian = -tf.norm(task_encoding_tile - features_generic_tile, name='neg_euclidian_distance', axis=-1) 432 | 433 | if is_training: 434 | euclidian = tf.reshape(euclidian, shape=(flags.num_tasks_per_batch * flags.train_batch_size, -1)) 435 | else: 436 | euclidian_shape = euclidian.get_shape().as_list() 437 | euclidian = tf.reshape(euclidian, shape=(euclidian_shape[1], -1)) 438 | 439 | return euclidian 440 | 441 
| 442 | def placeholder_inputs(batch_size, image_size, scope): 443 | """ 444 | :param batch_size: 445 | :return: placeholders for images and 446 | """ 447 | with tf.variable_scope(scope): 448 | images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, image_size, image_size, 3), name='images') 449 | labels_placeholder = tf.placeholder(tf.int64, shape=(batch_size), name='labels') 450 | return images_placeholder, labels_placeholder 451 | 452 | 453 | def get_batch(data_set, images_placeholder, labels_placeholder, batch_size): 454 | """ 455 | :param data_set: 456 | :param images_placeholder: 457 | :param labels_placeholder: 458 | :return: 459 | """ 460 | images_feed, labels_feed = data_set.next_batch(batch_size) 461 | 462 | feed_dict = { 463 | images_placeholder: images_feed.astype(dtype=np.float32), 464 | labels_placeholder: labels_feed, 465 | } 466 | return feed_dict 467 | 468 | 469 | def preprocess(images): 470 | # mean = tf.constant(np.asarray([127.5, 127.5, 127.5]).reshape([1, 1, 3]), dtype=tf.float32, name='image_mean') 471 | # std = tf.constant(np.asarray([127.5, 127.5, 127.5]).reshape([1, 1, 3]), dtype=tf.float32, name='image_std') 472 | # return tf.div(tf.subtract(images, mean), std) 473 | 474 | std = tf.constant(np.asarray([0.5, 0.5, 0.5]).reshape([1, 1, 3]), dtype=tf.float32, name='image_std') 475 | return tf.div(images, std) 476 | 477 | 478 | def get_nearest_neighbour_acc(flags, embeddings, labels): 479 | num_correct = 0 480 | num_tot = 0 481 | for i in trange(flags.num_cases_test): 482 | test_classes = np.random.choice(np.unique(labels), size=flags.num_classes_test, replace=False) 483 | train_idxs, test_idxs = get_few_shot_idxs(labels=labels, classes=test_classes, num_shots=flags.num_shots_test) 484 | # TODO: this is to fix the OOM error, this can be removed when embed() supports batch processing 485 | test_idxs = np.random.choice(test_idxs, size=100, replace=False) 486 | 487 | np_embedding_train = embeddings[train_idxs] 488 | # Using the np.std 
instead of np.linalg.norm improves results by around 1-1.5% 489 | np_embedding_train = np_embedding_train / np.std(np_embedding_train, axis=1, keepdims=True) 490 | # np_embedding_train = np_embedding_train / np.linalg.norm(np_embedding_train, axis=1, keepdims=True) 491 | labels_train = labels[train_idxs] 492 | 493 | np_embedding_test = embeddings[test_idxs] 494 | np_embedding_test = np_embedding_test / np.std(np_embedding_test, axis=1, keepdims=True) 495 | # np_embedding_test = np_embedding_test / np.linalg.norm(np_embedding_test, axis=1, keepdims=True) 496 | labels_test = labels[test_idxs] 497 | 498 | kdtree = KDTree(np_embedding_train) 499 | nns, nn_idxs = kdtree.query(np_embedding_test, k=1) 500 | labels_predicted = labels_train[nn_idxs] 501 | 502 | num_matches = sum(labels_predicted == labels_test) 503 | 504 | num_correct += num_matches 505 | num_tot += len(labels_predicted) 506 | 507 | # print("Accuracy: ", (100.0 * num_correct) / num_tot) 508 | return (100.0 * num_correct) / num_tot 509 | 510 | 511 | 512 | def build_inference_graph(images_deploy_pl, images_task_encode_pl, labels_task_encode_pl, flags, is_training, 513 | is_primary): 514 | num_filters = [round(flags.num_filters * pow(flags.block_size_growth, i)) for i in range(flags.num_blocks)] 515 | reuse = not is_primary 516 | 517 | with tf.variable_scope('Model'): 518 | feature_extractor_encoding_scope = 'feature_extractor_encoder' 519 | 520 | features_task_encode = build_feature_extractor_graph(images=images_task_encode_pl, flags=flags, 521 | is_training=is_training, 522 | num_filters=num_filters, 523 | scope=feature_extractor_encoding_scope, 524 | reuse=False) 525 | if flags.encoder_sharing == 'shared': 526 | ecoder_reuse = True 527 | feature_extractor_classifier_scope = feature_extractor_encoding_scope 528 | elif flags.encoder_sharing == 'siamese': 529 | # TODO: in the case of pretrained feature extractor this is not good, 530 | # because the classfier part will be randomly initialized 531 | 
ecoder_reuse = False 532 | feature_extractor_classifier_scope = 'feature_extractor_classifier' 533 | else: 534 | raise Exception('Option not implemented') 535 | 536 | if flags.encoder_classifier_link == 'prototypical': 537 | flags.task_encoder = 'class_mean' 538 | task_encoding = build_task_encoder(embeddings=features_task_encode, labels=labels_task_encode_pl, 539 | flags=flags, is_training=is_training, reuse=reuse) 540 | features_generic = build_feature_extractor_graph(images=images_deploy_pl, flags=flags, 541 | is_training=is_training, 542 | scope=feature_extractor_classifier_scope, 543 | num_filters=num_filters, 544 | reuse=ecoder_reuse) 545 | logits = build_prototypical_head(features_generic, task_encoding, flags, is_training=is_training) 546 | else: 547 | raise Exception('Option not implemented') 548 | 549 | return logits, features_task_encode, features_generic 550 | 551 | 552 | 553 | 554 | def get_train_datasets(flags): 555 | mini_imagenet = _load_mini_imagenet(data_dir=flags.data_dir, split='sources') 556 | few_shot_data_train = Dataset(mini_imagenet) 557 | pretrain_data_train, pretrain_data_test = None, None 558 | return few_shot_data_train, pretrain_data_train, pretrain_data_test 559 | 560 | 561 | def get_pwc_learning_rate(global_step, flags): 562 | learning_rate = tf.train.piecewise_constant(global_step, [np.int64(flags.number_of_steps / 2), 563 | np.int64( 564 | flags.number_of_steps / 2 + flags.num_steps_decay_pwc), 565 | np.int64( 566 | flags.number_of_steps / 2 + 2 * flags.num_steps_decay_pwc)], 567 | [flags.init_learning_rate, flags.init_learning_rate * 0.1, 568 | flags.init_learning_rate * 0.01, 569 | flags.init_learning_rate * 0.001]) 570 | return learning_rate 571 | 572 | 573 | def create_hard_negative_batch(misclass, feed_dict, sess, few_shot_data_train, flags, 574 | images_deploy_pl, labels_deploy_pl, images_task_encode_pl, labels_task_encode_pl): 575 | """ 576 | 577 | :param logits: 578 | :param feed_dict: 579 | :param sess: 580 | :param 
few_shot_data_train: 581 | :param flags: 582 | :param images_deploy_pl: 583 | :param labels_deploy_pl: 584 | :param images_task_encode_pl: 585 | :param labels_task_encode_pl: 586 | :return: 587 | """ 588 | feed_dict_test = dict(feed_dict) 589 | misclass_test_final = 0.0 590 | misclass_history = np.zeros(flags.num_batches_neg_mining) 591 | for i in range(flags.num_batches_neg_mining): 592 | images_deploy, labels_deploy, images_task_encode, labels_task_encode = \ 593 | few_shot_data_train.next_few_shot_batch(deploy_batch_size=flags.train_batch_size, 594 | num_classes_test=flags.num_classes_train, 595 | num_shots=flags.num_shots_train, 596 | num_tasks=flags.num_tasks_per_batch) 597 | 598 | feed_dict_test[images_deploy_pl] = images_deploy.astype(dtype=np.float32) 599 | feed_dict_test[labels_deploy_pl] = labels_deploy 600 | feed_dict_test[images_task_encode_pl] = images_task_encode.astype(dtype=np.float32) 601 | feed_dict_test[labels_task_encode_pl] = labels_task_encode 602 | 603 | # logits 604 | misclass_test = sess.run(misclass, feed_dict=feed_dict_test) 605 | misclass_history[i] = misclass_test 606 | if misclass_test > misclass_test_final: 607 | misclass_test_final = misclass_test 608 | feed_dict = dict(feed_dict_test) 609 | 610 | return feed_dict 611 | 612 | 613 | def train(flags): 614 | log_dir = get_logdir_name(flags) 615 | flags.pretrained_model_dir = log_dir 616 | log_dir = os.path.join(log_dir, 'train') 617 | # This is setting to run evaluation loop only once 618 | flags.max_number_of_evaluations = 1 619 | flags.eval_interval_secs = 0 620 | image_size = get_image_size(flags.data_dir) 621 | 622 | with tf.Graph().as_default(): 623 | global_step = tf.Variable(0, trainable=False, name='global_step', dtype=tf.int64) 624 | global_step_pretrain = tf.Variable(0, trainable=False, name='global_step_pretrain', dtype=tf.int64) 625 | 626 | images_deploy_pl, labels_deploy_pl = placeholder_inputs( 627 | batch_size=flags.num_tasks_per_batch * flags.train_batch_size, 628 | 
image_size=image_size, scope='inputs/deploy') 629 | images_task_encode_pl, labels_task_encode_pl = placeholder_inputs( 630 | batch_size=flags.num_tasks_per_batch * flags.num_classes_train * flags.num_shots_train, 631 | image_size=image_size, scope='inputs/task_encode') 632 | 633 | 634 | # Primary task operations 635 | logits, _, _ = build_inference_graph(images_deploy_pl=images_deploy_pl, 636 | images_task_encode_pl=images_task_encode_pl, 637 | labels_task_encode_pl=labels_task_encode_pl, 638 | flags=flags, is_training=True, is_primary=True) 639 | loss = tf.reduce_mean( 640 | tf.nn.softmax_cross_entropy_with_logits(logits=logits, 641 | labels=tf.one_hot(labels_deploy_pl, flags.num_classes_train))) 642 | 643 | # Losses and optimizer 644 | regu_losses = slim.losses.get_regularization_losses() 645 | loss = tf.add_n([loss] + regu_losses) 646 | misclass = 1.0 - slim.metrics.accuracy(tf.argmax(logits, 1), labels_deploy_pl) 647 | 648 | # Learning rate 649 | if flags.lr_anneal == 'const': 650 | learning_rate = flags.init_learning_rate 651 | elif flags.lr_anneal == 'pwc': 652 | learning_rate = get_pwc_learning_rate(global_step, flags) 653 | elif flags.lr_anneal == 'exp': 654 | lr_decay_step = flags.number_of_steps // flags.n_lr_decay 655 | learning_rate = tf.train.exponential_decay(flags.init_learning_rate, global_step, lr_decay_step, 656 | 1.0 / flags.lr_decay_rate, staircase=True) 657 | else: 658 | raise Exception('Not implemented') 659 | 660 | # Optimizer 661 | if flags.optimizer == 'sgd': 662 | optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9) 663 | else: 664 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 665 | 666 | train_op = slim.learning.create_train_op(total_loss=loss, optimizer=optimizer, global_step=global_step, 667 | clip_gradient_norm=flags.clip_gradient_norm) 668 | 669 | tf.summary.scalar('loss', loss) 670 | tf.summary.scalar('misclassification', misclass) 671 | tf.summary.scalar('learning_rate', 
learning_rate) 672 | # Merge all summaries except for pretrain 673 | summary = tf.summary.merge(tf.get_collection('summaries', scope='(?!pretrain).*')) 674 | 675 | 676 | # Get datasets 677 | few_shot_data_train, pretrain_data_train, pretrain_data_test = get_train_datasets(flags) 678 | # Define session and logging 679 | summary_writer = tf.summary.FileWriter(log_dir, flush_secs=1) 680 | saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True) 681 | run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 682 | run_metadata = tf.RunMetadata() 683 | supervisor = tf.train.Supervisor(logdir=log_dir, init_feed_dict=None, 684 | summary_op=None, 685 | init_op=tf.global_variables_initializer(), 686 | summary_writer=summary_writer, 687 | saver=saver, 688 | global_step=global_step, save_summaries_secs=flags.save_summaries_secs, 689 | save_model_secs=0) # flags.save_interval_secs 690 | 691 | with supervisor.managed_session() as sess: 692 | checkpoint_step = sess.run(global_step) 693 | if checkpoint_step > 0: 694 | checkpoint_step += 1 695 | 696 | eval_interval_steps = flags.eval_interval_steps 697 | for step in range(checkpoint_step, flags.number_of_steps): 698 | # get batch of data to compute classification loss 699 | images_deploy, labels_deploy, images_task_encode, labels_task_encode = \ 700 | few_shot_data_train.next_few_shot_batch(deploy_batch_size=flags.train_batch_size, 701 | num_classes_test=flags.num_classes_train, 702 | num_shots=flags.num_shots_train, 703 | num_tasks=flags.num_tasks_per_batch) 704 | if flags.augment: 705 | images_deploy = image_augment(images_deploy) 706 | images_task_encode = image_augment(images_task_encode) 707 | 708 | feed_dict = {images_deploy_pl: images_deploy.astype(dtype=np.float32), labels_deploy_pl: labels_deploy, 709 | images_task_encode_pl: images_task_encode.astype(dtype=np.float32), 710 | labels_task_encode_pl: labels_task_encode} 711 | 712 | 713 | t_batch = time.time() 714 | feed_dict = 
create_hard_negative_batch(misclass, feed_dict, sess, few_shot_data_train, flags, 715 | images_deploy_pl, labels_deploy_pl, images_task_encode_pl, 716 | labels_task_encode_pl) 717 | dt_batch = time.time() - t_batch 718 | 719 | t_train = time.time() 720 | loss = sess.run(train_op, feed_dict=feed_dict) 721 | dt_train = time.time() - t_train 722 | 723 | if step % 100 == 0: 724 | summary_str = sess.run(summary, feed_dict=feed_dict) 725 | summary_writer.add_summary(summary_str, step) 726 | summary_writer.flush() 727 | logging.info("step %d, loss : %.4g, dt: %.3gs, dt_batch: %.3gs" % (step, loss, dt_train, dt_batch)) 728 | 729 | if float(step) / flags.number_of_steps > 0.5: 730 | eval_interval_steps = flags.eval_interval_fine_steps 731 | 732 | if eval_interval_steps > 0 and step % eval_interval_steps == 0: 733 | saver.save(sess, os.path.join(log_dir, 'model'), global_step=step) 734 | eval(flags=flags, is_primary=True) 735 | 736 | if float(step) > 0.5 * flags.number_of_steps + flags.number_of_steps_to_early_stop: 737 | break 738 | 739 | 740 | 741 | class ModelLoader: 742 | def __init__(self, model_path, batch_size, is_primary): 743 | self.batch_size = batch_size 744 | 745 | latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=os.path.join(model_path, 'train')) 746 | step = int(os.path.basename(latest_checkpoint).split('-')[1]) 747 | 748 | flags = Namespace(load_and_save_params(default_params=dict(), exp_dir=model_path)) 749 | image_size = get_image_size(flags.data_dir) 750 | 751 | with tf.Graph().as_default(): 752 | images_deploy_pl, labels_deploy_pl = placeholder_inputs(batch_size=batch_size, 753 | image_size=image_size, scope='inputs/deploy') 754 | if is_primary: 755 | task_encode_batch_size = flags.num_classes_test * flags.num_shots_test 756 | images_task_encode_pl, labels_task_encode_pl = placeholder_inputs(batch_size=task_encode_batch_size, 757 | image_size=image_size, 758 | scope='inputs/task_encode') 759 | 760 | self.tensor_images_deploy = images_deploy_pl 
761 | self.tensor_labels_task_encode = labels_task_encode_pl 762 | self.tensor_images_task_encode = images_task_encode_pl 763 | self.tensor_labels_deploy_pl = labels_deploy_pl 764 | # TODO: This is just used to create the variables of primary graph that will be reused in aux graph 765 | if not is_primary: 766 | primary_logits, _, _ = build_inference_graph(images_deploy_pl=images_deploy_pl, 767 | images_task_encode_pl=images_task_encode_pl, 768 | labels_task_encode_pl=labels_task_encode_pl, 769 | flags=flags, is_training=False, is_primary=True) 770 | 771 | logits, features_sample, features_query = build_inference_graph( 772 | images_deploy_pl=images_deploy_pl, images_task_encode_pl=images_task_encode_pl, 773 | labels_task_encode_pl=labels_task_encode_pl, flags=flags, is_training=False, 774 | is_primary=is_primary) 775 | 776 | loss = tf.reduce_mean( 777 | tf.nn.softmax_cross_entropy_with_logits(logits=logits, 778 | labels=tf.one_hot(labels_deploy_pl, flags.num_classes_test))) 779 | # Losses and optimizer 780 | regu_losses = slim.losses.get_regularization_losses() 781 | loss = tf.add_n([loss] + regu_losses) 782 | 783 | init_fn = slim.assign_from_checkpoint_fn( 784 | latest_checkpoint, 785 | slim.get_model_variables('Model')) 786 | 787 | config = tf.ConfigProto(allow_soft_placement=True) 788 | config.gpu_options.allow_growth = True 789 | self.sess = tf.Session(config=config) 790 | 791 | # Run init before loading the weights 792 | self.sess.run(tf.global_variables_initializer()) 793 | # Load weights 794 | init_fn(self.sess) 795 | 796 | self.flags = flags 797 | self.logits = logits 798 | self.loss = loss 799 | self.features_sample = features_sample 800 | self.features_query = features_query 801 | self.logits_size = self.logits.get_shape().as_list()[-1] 802 | self.step = step 803 | self.is_primary = is_primary 804 | 805 | log_dir = get_logdir_name(flags) 806 | graphpb_txt = str(tf.get_default_graph().as_graph_def()) 807 | pathlib.Path(os.path.join(log_dir, 
'eval')).mkdir(parents=True, exist_ok=True) 808 | with open(os.path.join(log_dir, 'eval', 'graph.pbtxt'), 'w') as f: 809 | f.write(graphpb_txt) 810 | 811 | def eval(self, data_dir, num_cases_test, split='target_val'): 812 | data_set = Dataset(_load_mini_imagenet(data_dir=data_dir, split=split)) 813 | 814 | num_batches = num_cases_test // self.batch_size 815 | num_correct = 0.0 816 | num_tot = 0.0 817 | loss_tot = 0.0 818 | for i in trange(num_batches): 819 | if self.is_primary: 820 | num_classes, num_shots = self.flags.num_classes_test, self.flags.num_shots_test 821 | 822 | 823 | images_deploy, labels_deploy, images_task_encode, labels_task_encode = \ 824 | data_set.next_few_shot_batch(deploy_batch_size=self.batch_size, 825 | num_classes_test=num_classes, num_shots=num_shots, 826 | num_tasks=1) 827 | 828 | feed_dict = {self.tensor_images_deploy: images_deploy.astype(dtype=np.float32), 829 | self.tensor_labels_deploy_pl: labels_deploy, 830 | self.tensor_labels_task_encode: labels_task_encode, 831 | self.tensor_images_task_encode: images_task_encode.astype(dtype=np.float32)} 832 | [logits, loss] = self.sess.run([self.logits, self.loss], feed_dict) 833 | labels_deploy_pred = np.argmax(logits, axis=-1) 834 | 835 | num_matches = sum(labels_deploy_pred == labels_deploy) 836 | num_correct += num_matches 837 | num_tot += len(labels_deploy_pred) 838 | loss_tot += loss 839 | 840 | return num_correct / num_tot, loss_tot / num_batches 841 | 842 | 843 | def get_few_shot_idxs(labels, classes, num_shots): 844 | train_idxs, test_idxs = [], [] 845 | idxs = np.arange(len(labels)) 846 | for cl in classes: 847 | class_idxs = idxs[labels == cl] 848 | class_idxs_train = np.random.choice(class_idxs, size=num_shots, replace=False) 849 | class_idxs_test = np.setxor1d(class_idxs, class_idxs_train) 850 | 851 | train_idxs.extend(class_idxs_train) 852 | test_idxs.extend(class_idxs_test) 853 | 854 | assert set(class_idxs_train).isdisjoint(test_idxs) 855 | 856 | return np.array(train_idxs), 
np.array(test_idxs) 857 | 858 | 859 | def test(flags): 860 | test_dataset = _load_mini_imagenet(data_dir=flags.data_dir, split='target_val') 861 | 862 | # test_dataset = _load_mini_imagenet(data_dir=flags.data_dir, split='sources') 863 | images = test_dataset[0] 864 | labels = test_dataset[1] 865 | 866 | embedding_model = ModelLoader(flags.pretrained_model_dir, batch_size=100) 867 | embeddings = embedding_model.embed(images=test_dataset[0]) 868 | embedding_model = None 869 | print("Accuracy test raw embedding: ", get_nearest_neighbour_acc(flags, embeddings, labels)) 870 | 871 | 872 | def get_agg_misclassification(logits_dict, labels_dict): 873 | summary_ops = [] 874 | update_ops = {} 875 | for key, logits in logits_dict.items(): 876 | accuracy, update = slim.metrics.streaming_accuracy(tf.argmax(logits, 1), labels_dict[key]) 877 | 878 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map( 879 | {'misclassification_' + key: (1.0 - accuracy, update)}) 880 | 881 | for metric_name, metric_value in names_to_values.items(): 882 | op = tf.summary.scalar(metric_name, metric_value) 883 | op = tf.Print(op, [metric_value], metric_name) 884 | summary_ops.append(op) 885 | 886 | for update_name, update_op in names_to_updates.items(): 887 | update_ops[update_name] = update_op 888 | return summary_ops, update_ops 889 | 890 | 891 | def eval(flags, is_primary): 892 | log_dir = get_logdir_name(flags) 893 | if is_primary: 894 | aux_prefix = '' 895 | else: 896 | aux_prefix = 'aux/' 897 | 898 | eval_writer = summary_writer(log_dir + '/eval') 899 | i = 0 900 | last_step = -1 901 | while i < flags.max_number_of_evaluations: 902 | latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=flags.pretrained_model_dir) 903 | model_step = int(os.path.basename(latest_checkpoint or '0-0').split('-')[1]) 904 | if last_step < model_step: 905 | results = {} 906 | model = ModelLoader(model_path=flags.pretrained_model_dir, batch_size=flags.eval_batch_size, 907 | 
is_primary=is_primary) 908 | 909 | acc_tst, loss_tst = model.eval(data_dir=flags.data_dir, num_cases_test=flags.num_cases_test, 910 | split='target_tst') 911 | acc_val, loss_val = model.eval(data_dir=flags.data_dir, num_cases_test=flags.num_cases_test, 912 | split='target_val') 913 | acc_trn, loss_trn = model.eval(data_dir=flags.data_dir, num_cases_test=flags.num_cases_test, 914 | split='sources') 915 | 916 | results[aux_prefix + "accuracy_target_tst"] = acc_tst 917 | results[aux_prefix + "accuracy_target_val"] = acc_val 918 | results[aux_prefix + "accuracy_sources"] = acc_trn 919 | 920 | results[aux_prefix + "loss_target_tst"] = loss_tst 921 | results[aux_prefix + "loss_target_val"] = loss_val 922 | results[aux_prefix + "loss_sources"] = loss_trn 923 | 924 | last_step = model.step 925 | eval_writer(model.step, **results) 926 | logging.info("accuracy_%s: %.3g, accuracy_%s: %.3g, accuracy_%s: %.3g, loss_%s: %.3g, loss_%s: %.3g, loss_%s: %.3g." 927 | % ( 928 | aux_prefix + "target_tst", acc_tst, aux_prefix + "target_val", acc_val, aux_prefix + "sources", 929 | acc_trn, aux_prefix + "target_tst", loss_tst, aux_prefix + "target_val", loss_val, aux_prefix + "sources", 930 | loss_trn)) 931 | if flags.eval_interval_secs > 0: 932 | time.sleep(flags.eval_interval_secs) 933 | i = i + 1 934 | 935 | 936 | 937 | 938 | 939 | def image_augment(images): 940 | """ 941 | 942 | :param images: 943 | :return: 944 | """ 945 | pad_percent = 0.125 946 | flip_proba = 0.5 947 | image_size = images.shape[1] 948 | pad_size = int(pad_percent * image_size) 949 | max_crop = 2 * pad_size 950 | 951 | images_aug = np.pad(images, ((0, 0), (pad_size, pad_size), (pad_size, pad_size), (0, 0)), mode='constant') 952 | output = [] 953 | for image in images_aug: 954 | if np.random.rand() < flip_proba: 955 | image = np.flip(image, axis=1) 956 | crop_val = np.random.randint(0, max_crop) 957 | image = image[crop_val:crop_val + image_size, crop_val:crop_val + image_size, :] 958 | output.append(image) 959 | 
return np.asarray(output) 960 | 961 | 962 | def main(argv=None): 963 | config = tf.ConfigProto(allow_soft_placement=True) 964 | config.gpu_options.per_process_gpu_memory_fraction = 1.0 965 | config.gpu_options.allow_growth = True 966 | sess = tf.Session(config=config) 967 | 968 | print(os.getcwd()) 969 | 970 | default_params = get_arguments() 971 | log_dir = get_logdir_name(flags=default_params) 972 | 973 | pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) 974 | # This makes sure that we can store a json and recove a namespace back 975 | flags = Namespace(load_and_save_params(vars(default_params), log_dir)) 976 | 977 | if flags.mode == 'train': 978 | train(flags=flags) 979 | elif flags.mode == 'eval': 980 | eval(flags=flags, is_primary=True) 981 | elif flags.mode == 'test': 982 | test(flags=flags) 983 | 984 | 985 | if __name__ == '__main__': 986 | tf.app.run() -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | *ServiceNow completed its acquisition of Element AI on January 8, 2021. All references to Element AI in the materials that are part of this project should refer to ServiceNow.* 2 | 3 | # ADAPTIVE CROSS-MODAL FEW-SHOT LEARNING (AW3) 4 | 5 | Code for paper Adaptive Cross-Modal Few-shot Learning. [[Arxiv]](http://www.dropwizard.io/1.0.2/docs/) 6 | 7 | ## Dependencies 8 | 9 | * cv2 10 | * numpy 11 | * python 3.5+ 12 | * tensorflow 1.3+ 13 | * tqdm 14 | * scipy 15 | 16 | ## Datasets 17 | 18 | First, designate a folder to be your data root: 19 | ``` 20 | export DATA_ROOT={DATA_ROOT} 21 | Then, set up the datasets following the instructions in the subsections. 22 | ``` 23 | ###miniImageNet 24 | 25 | [[Google Drive]](https://drive.google.com/file/d/1g4wOa0FpWalffXJMN2IZw0K2TM2uxzbk/view?usp=sharing)(1.05G) 26 | ``` 27 | # Download and place "mini-imagenet.zip" in "$DATA_ROOT/mini-imagenet". 
28 | mkdir -p $DATA_ROOT/mini-imagenet 29 | cd $DATA_ROOT/mini-imagenet 30 | mv ~/Downloads/mini-imagenet.zip . 31 | unzip mini-imagenet.zip 32 | rm -f mini-imagenet.zip 33 | ``` 34 | ### tieredImageNet 35 | [[Google Drive]](https://drive.google.com/file/d/1Letu5U_kAjQfqJjNPWS_rdjJ7Fd46LbX/view?usp=sharing)(14.33G) 36 | ``` 37 | # Download and place "tiered-imagenet.tar.gz" in "$DATA_ROOT/tiered-imagenet". 38 | mkdir -p $DATA_ROOT/tiered-imagenet 39 | cd $DATA_ROOT/tiered-imagenet 40 | mv ~/Downloads/tiered-imagenet.tar.gz . 41 | tar -xvf tiered-imagenet.tar.gz 42 | rm -f tiered-imagenet.tar.gz 43 | ``` 44 | ## AM3-ProtoNet 45 | ### 1-shot experiments 46 | For mini-ImageNet: 47 | ``` 48 | python AM3_protonet++.py --data_dir $DATA_ROOT/mini-imagenet/ \ 49 | --num_tasks_per_batch 5 --num_shots_train 1 --num_shots_test 1 --train_batch_size 24 \ 50 | --mlp_dropout 0.7 --att_input word --task_encoder self_att_mlp \ 51 | --mlp_type non-linear --mlp_weight_decay 0.001 \ 52 | --log_dir $EXP_DIR 53 | ``` 54 | 55 | For tiered-ImageNet: 56 | ``` 57 | python AM3_protonet++.py --data_dir $DATA_ROOT/tiered-imagenet/ \ 58 | --num_tasks_per_batch 5 --num_shots_train 1 --num_shots_test 1 --train_batch_size 24 \ 59 | --num_steps_decay_pwc 10000 --number_of_steps 80000 \ 60 | --mlp_dropout 0.7 --att_input word --task_encoder self_att_mlp \ 61 | --mlp_type non-linear --mlp_weight_decay 0.001 \ 62 | --log_dir $EXP_DIR 63 | 64 | ``` 65 | 66 | ### 5-shot experiments 67 | For mini-ImageNet: 68 | ``` 69 | python AM3_protonet++.py --data_dir $DATA_ROOT/mini-imagenet/ \ 70 | --mlp_dropout 0.7 --att_input word --task_encoder self_att_mlp \ 71 | --mlp_type non-linear --mlp_weight_decay 0.001 \ 72 | --log_dir $EXP_DIR 73 | ``` 74 | 75 | For tiered-ImageNet: 76 | ``` 77 | python AM3_protonet++.py --data_dir $DATA_ROOT/tiered-imagenet/ \ 78 | --num_steps_decay_pwc 10000 --number_of_steps 80000 \ 79 | --mlp_dropout 0.7 --att_input word --task_encoder self_att_mlp \ 80 | --mlp_type non-linear --mlp_weight_decay 0.001 \ 81 | --log_dir
$EXP_DIR 82 | 83 | ``` 84 | 85 | ## AM3-TADAM 86 | Note that you may need to tune "--metric_multiplier_init", a TADAM hyper-parameter, via cross-validation to achieve state-of-the-art results. Its value usually lies in the range (5, 10). 87 | ### 1-shot experiments 88 | For mini-ImageNet: 89 | ``` 90 | python AM3_TADAM.py --data_dir $DATA_ROOT/mini-imagenet/ \ 91 | --num_tasks_per_batch 5 --num_shots_train 1 --num_shots_test 1 --train_batch_size 24 --metric_multiplier_init 5 \ 92 | --feat_extract_pretrain multitask --encoder_classifier_link cbn --num_cases_test 100000 \ 93 | --activation_mlp relu --att_dropout 0.7 --att_type non-linear --att_weight_decay 0.001 \ 94 | --mlp_dropout 0.7 --mlp_type non-linear --mlp_weight_decay 0.001 --att_input word --task_encoder self_att_mlp \ 95 | --log_dir $EXP_DIR 96 | ``` 97 | For tiered-ImageNet: 98 | ``` 99 | python AM3_TADAM.py --data_dir $DATA_ROOT/tiered-imagenet/ \ 100 | --num_tasks_per_batch 5 --num_shots_train 1 --num_shots_test 1 --train_batch_size 24 --metric_multiplier_init 5 \ 101 | --feat_extract_pretrain multitask --encoder_classifier_link cbn --num_steps_decay_pwc 10000 \ 102 | --number_of_steps 80000 --num_cases_test 100000 --num_classes_pretrain 351 \ 103 | --att_dropout 0.9 --mlp_dropout 0.9 \ 104 | --log_dir $EXP_DIR 105 | 106 | ``` 107 | 108 | ### 5-shot experiments 109 | For mini-ImageNet: 110 | ``` 111 | python AM3_TADAM.py --data_dir $DATA_ROOT/mini-imagenet/ \ 112 | --metric_multiplier_init 7 \ 113 | --feat_extract_pretrain multitask --encoder_classifier_link cbn --num_cases_test 100000 \ 114 | --activation_mlp relu --att_dropout 0.7 --att_type non-linear --att_weight_decay 0.001 \ 115 | --mlp_dropout 0.7 --mlp_type non-linear --mlp_weight_decay 0.001 --att_input word --task_encoder self_att_mlp \ 116 | --log_dir $EXP_DIR 117 | ``` 118 | For tiered-ImageNet: 119 | ``` 120 | python AM3_TADAM.py --data_dir $DATA_ROOT/tiered-imagenet/ \ 121 | --metric_multiplier_init 7 \ 122 | --feat_extract_pretrain multitask
--encoder_classifier_link cbn --num_steps_decay_pwc 10000 \ 123 | --number_of_steps 80000 --num_cases_test 100000 --num_classes_pretrain 351 \ 124 | --att_dropout 0.9 --mlp_dropout 0.9 \ 125 | --log_dir $EXP_DIR 126 | 127 | ``` 128 | 129 | ## Citation 130 | 131 | If you use our code, please consider citing the following: 132 | 133 | * Chen Xing, Negar Rostamzadeh, Boris N. Oreshkin, Pedro O. Pinheiro. Adaptive Cross-Modal Few-Shot Learning. NeurIPS 2019. 134 | 135 | --------------------------------------------------------------------------------