├── src ├── __init__.py ├── configs │ ├── __init__.py │ ├── generic_omniglot_config.py │ ├── generic_mini_imagenet_config.py │ ├── configs.py │ ├── config_siamese_omniglot.py │ ├── config_siamese_mini_imagenet.py │ ├── config_mAP_omniglot.py │ └── config_mAP_mini_imagenet.py ├── data │ ├── __init__.py │ ├── data_handler.py │ ├── dataset.py │ ├── mini_imagenet.py │ └── omniglot.py ├── eval │ ├── __init__.py │ └── evaluator.py ├── models │ ├── __init__.py │ ├── model_handler.py │ ├── mAP_model.py │ ├── siamese.py │ └── model.py └── utils │ ├── __init__.py │ ├── loss_aug_AP.py │ └── dashboard_logger.py ├── setup.sh ├── LICENSE ├── README.md ├── run_eval.py ├── run_train.py └── data └── dataset_splits └── omniglot ├── test.txt ├── val.txt └── train.txt /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | 
OMNIGLOT_PATH="path_to_omniglot" 2 | MINI_IMAGENET_PATH="path_to_imagenet" 3 | MINI_IMAGENET_SPLITS_PATH="path_to_mini_imagenet_splits" 4 | 5 | ln -s $OMNIGLOT_PATH data/omniglot 6 | ln -s $MINI_IMAGENET_PATH data/mini_imagenet 7 | ln -s $MINI_IMAGENET_SPLITS_PATH data/dataset_splits/mini_imagenet 8 | -------------------------------------------------------------------------------- /src/models/model_handler.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, 2 | unicode_literals) 3 | 4 | from src.models.siamese import SiameseModel 5 | from src.models.mAP_model import MeanAveragePrecisionModel 6 | 7 | 8 | def get_model(config, reuse=False): 9 | if config.model_type == "siamese": 10 | return SiameseModel(config, reuse=reuse) 11 | elif config.model_type == "mAP": 12 | return MeanAveragePrecisionModel(config, reuse=reuse) 13 | else: 14 | raise ValueError("Unknown model type. \"{}\"".format(config.model_type)) 15 | -------------------------------------------------------------------------------- /src/configs/generic_omniglot_config.py: -------------------------------------------------------------------------------- 1 | class GenericConfigOmniglot(object): 2 | """Contains options that are probably common to different models 3 | for training on Omniglot.""" 4 | 5 | def __init__(self): 6 | self.dataset = "omniglot" 7 | self.height = 28 8 | self.width = 28 9 | self.channels = 1 10 | self.num_fewshot_samples = 100 11 | self.optimizer = "ADAM" 12 | self.niters = 20000 13 | self.compute_mAP = True 14 | self.compute_fewshot = True 15 | self.display_step = 20 16 | self.validation_freq = 500 17 | self.save_freq = 2000 18 | self.update_dashboard_freq = 100 19 | self.compute_mAP_freq = 500 20 | self.fewshot_test_freq = 1000 21 | self.save_model = True 22 | self.neval_batches = 100 23 | -------------------------------------------------------------------------------- 
/src/configs/generic_mini_imagenet_config.py: -------------------------------------------------------------------------------- 1 | class GenericConfigMiniImageNet(object): 2 | """Contains options that are probably common to different models 3 | for training on miniImageNet.""" 4 | 5 | def __init__(self): 6 | self.dataset = "mini_imagenet" 7 | self.height = 84 8 | self.width = 84 9 | self.channels = 3 10 | self.num_fewshot_samples = 100 11 | self.optimizer = "ADAM" 12 | self.niters = 25000 13 | self.compute_mAP = True 14 | self.compute_fewshot = True 15 | self.display_step = 20 16 | self.validation_freq = 500 17 | self.save_freq = 2000 18 | self.update_dashboard_freq = 100 19 | self.compute_mAP_freq = 500 20 | self.fewshot_test_freq = 1000 21 | self.save_model = True 22 | self.neval_batches = 100 23 | -------------------------------------------------------------------------------- /src/data/data_handler.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, 2 | unicode_literals) 3 | 4 | from src.data.omniglot import OmniglotDataset 5 | from src.data.mini_imagenet import MiniImageNetDataset 6 | 7 | 8 | def get_dataset(dataset_name, config, split): 9 | if dataset_name == "omniglot": 10 | return OmniglotDataset(dataset_name, config, config.nway, split, 11 | config.batch_size, "data/omniglot_cache.pklz") 12 | elif dataset_name == "mini_imagenet": 13 | return MiniImageNetDataset(dataset_name, config, config.nway, split, 14 | config.batch_size, 15 | "data/mini_imagenet_cache.pklz") 16 | else: 17 | raise ValueError("Unknown dataset \"{}\"".format(dataset_name)) 18 | -------------------------------------------------------------------------------- /src/configs/configs.py: -------------------------------------------------------------------------------- 1 | from src.configs.config_siamese_omniglot import SiameseConfigOmniglot 2 | from src.configs.config_mAP_omniglot import 
MeanAveragePrecisionConfigOmniglot 3 | from src.configs.config_siamese_mini_imagenet import SiameseConfigMiniImageNet 4 | from src.configs.config_mAP_mini_imagenet import MeanAveragePrecisionConfigMiniImageNet 5 | 6 | 7 | class Configs(object): 8 | 9 | def __init__(self): 10 | self._CONFIGS = {} 11 | self._CONFIGS["omniglot_siamese"] = SiameseConfigOmniglot() 12 | self._CONFIGS["omniglot_mAP"] = MeanAveragePrecisionConfigOmniglot() 13 | self._CONFIGS["mini_imagenet_siamese"] = SiameseConfigMiniImageNet() 14 | self._CONFIGS[ 15 | "mini_imagenet_mAP"] = MeanAveragePrecisionConfigMiniImageNet() 16 | 17 | def get_config(self, dataset_name, model_name): 18 | config_key = "{}_{}".format(dataset_name, model_name) 19 | if config_key in self.CONFIGS: 20 | return self.CONFIGS[config_key] 21 | else: 22 | raise ValueError("No matching config {}.".format(config_key)) 23 | 24 | @property 25 | def CONFIGS(self): 26 | return self._CONFIGS 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Eleni Triantafillou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/configs/config_siamese_omniglot.py: -------------------------------------------------------------------------------- 1 | import os 2 | from src.configs.generic_omniglot_config import GenericConfigOmniglot 3 | 4 | SAVE_LOC = "saved_models/" 5 | DASHBOARD_LOC = "/u/eleni/public_html/results/few_shot_mAP/" 6 | 7 | 8 | class SiameseConfigOmniglot(GenericConfigOmniglot): 9 | 10 | def __init__(self): 11 | # Copy the generic Omniglot options 12 | super(SiameseConfigOmniglot, self).__init__() 13 | 14 | # Siamese network 15 | self.model_type = "siamese" 16 | self.loss_function = "cross_entropy" 17 | self.join_branches = "abs_diff" 18 | 19 | # Learning rate value and schedule 20 | self.lr = 0.1 21 | self.ada_learning_rate = False 22 | self.start_decr_lr = 2000 23 | self.mult_lr_by = 0.5 24 | self.freq_decr_lr = 2000 25 | self.smallest_lr = 0.0001 26 | 27 | # Optimization 28 | self.optimizer = "ADAM" 29 | 30 | # Batch formation 31 | self.batch_size = 128 32 | self.nway = 16 # number of classes allowed in each batch 33 | 34 | self.reload = True 35 | 36 | self.name = "siamese_omniglot" 37 | 38 | # Metrics to plot throughout training 39 | self.few_shot_metrics = [{ 40 | "K": 1, 41 | "N": 5, 42 | "type": "classif" 43 | }, { 44 | "K": 1, 45 | "N": 5, 46 | "type": "retrieval" 47 | }] 48 | 49 | # deep dashboard location 50 | self.dashboard_path = os.path.join(DASHBOARD_LOC, self.name) 51 | 52 | # where 
to save checkpoints of training models 53 | self.saveloc = os.path.join(os.path.join(SAVE_LOC, self.dataset), self.name) 54 | -------------------------------------------------------------------------------- /src/configs/config_siamese_mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from src.configs.generic_mini_imagenet_config import GenericConfigMiniImageNet 3 | 4 | SAVE_LOC = "saved_models/" 5 | DASHBOARD_LOC = "/u/eleni/public_html/results/few_shot_mAP/" 6 | 7 | 8 | class SiameseConfigMiniImageNet(GenericConfigMiniImageNet): 9 | 10 | def __init__(self): 11 | # Copy the generic Omniglot options 12 | super(SiameseConfigMiniImageNet, self).__init__() 13 | 14 | # Siamese network 15 | self.model_type = "siamese" 16 | self.loss_function = "cross_entropy" 17 | self.join_branches = "abs_diff" 18 | 19 | # Learning rate value and schedule 20 | self.lr = 0.01 21 | self.ada_learning_rate = True 22 | self.start_decr_lr = 2000 23 | self.mult_lr_value = 0.5 24 | self.freq_decr_lr = 2000 25 | self.smallest_lr = 0.00001 26 | 27 | # Optimization 28 | self.optimizer = "ADAM" 29 | 30 | # Batch formation 31 | self.batch_size = 128 32 | self.nway = 8 # number of classes allowed in each batch 33 | 34 | self.reload = False 35 | 36 | self.name = "siamese_miniImageNet" 37 | 38 | # Metrics to plot throughout training 39 | self.few_shot_metrics = [{ 40 | "K": 1, 41 | "N": 5, 42 | "type": "classif" 43 | }, { 44 | "K": 1, 45 | "N": 5, 46 | "type": "retrieval" 47 | }] 48 | 49 | # deep dashboard location 50 | self.dashboard_path = os.path.join(DASHBOARD_LOC, self.name) 51 | 52 | # where to save checkpoints of training models 53 | self.saveloc = os.path.join(os.path.join(SAVE_LOC, self.dataset), self.name) 54 | -------------------------------------------------------------------------------- /src/configs/config_mAP_omniglot.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
from src.configs.generic_omniglot_config import GenericConfigOmniglot 3 | 4 | SAVE_LOC = "saved_models/" 5 | DASHBOARD_LOC = "/u/eleni/public_html/results/few_shot_mAP/" 6 | 7 | 8 | class MeanAveragePrecisionConfigOmniglot(GenericConfigOmniglot): 9 | 10 | def __init__(self): 11 | # Copy the generic Omniglot options 12 | super(MeanAveragePrecisionConfigOmniglot, self).__init__() 13 | 14 | self.model_type = "mAP" 15 | 16 | # Learning rate value and schedule 17 | self.lr = 0.001 18 | self.ada_learning_rate = False 19 | self.start_decr_lr = 2000 20 | self.mult_lr_value = 0.5 21 | self.freq_decr_lr = 2000 22 | self.smallest_lr = 0.0001 23 | 24 | # Optimization 25 | self.optimizer = "ADAM" 26 | self.epsilon = 1 27 | self.alpha = 10 28 | self.optimization_framework = "DLM" # Direct Loss Minimization 29 | # self.optimization_framework = "SSVM" 30 | self.positive_update = True 31 | 32 | # Batch formation 33 | self.batch_size = 128 34 | self.nway = 16 # number of classes allowed in each batch 35 | 36 | self.reload = False 37 | 38 | self.name = "mAP_DLM_omniglot" 39 | # self.name = "mAP_SSVM_omniglot" 40 | 41 | # Metrics to plot throughout training 42 | self.few_shot_metrics = [{ 43 | "K": 1, 44 | "N": 5, 45 | "type": "classif" 46 | }, { 47 | "K": 1, 48 | "N": 5, 49 | "type": "retrieval" 50 | }] 51 | 52 | # deep dashboard location 53 | self.dashboard_path = os.path.join(DASHBOARD_LOC, self.name) 54 | 55 | # where to save checkpoints of training models 56 | self.saveloc = os.path.join(os.path.join(SAVE_LOC, self.dataset), self.name) 57 | -------------------------------------------------------------------------------- /src/configs/config_mAP_mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from src.configs.generic_mini_imagenet_config import GenericConfigMiniImageNet 3 | 4 | SAVE_LOC = "saved_models/" 5 | DASHBOARD_LOC = "/u/eleni/public_html/results/few_shot_mAP/" 6 | 7 | 8 | class 
MeanAveragePrecisionConfigMiniImageNet(GenericConfigMiniImageNet): 9 | 10 | def __init__(self): 11 | # Copy the generic Omniglot options 12 | super(MeanAveragePrecisionConfigMiniImageNet, self).__init__() 13 | 14 | self.model_type = "mAP" 15 | 16 | # Learning rate value and schedule 17 | self.lr = 0.001 18 | self.ada_learning_rate = True 19 | self.start_decr_lr = 2000 20 | self.mult_lr_value = 0.75 21 | self.freq_decr_lr = 2000 22 | self.smallest_lr = 0.00001 23 | 24 | # Optimization 25 | self.optimizer = "ADAM" 26 | self.epsilon = 1 27 | self.alpha = 10 28 | self.optimization_framework = "DLM" # Direct Loss Minimization 29 | # self.optimization_framework = "SSVM" 30 | self.positive_update = True 31 | 32 | # Batch formation 33 | self.batch_size = 128 34 | self.nway = 8 # number of classes allowed in each batch 35 | 36 | self.reload = False 37 | 38 | self.name = "mAP_DLM_miniImageNet" 39 | # self.name = "mAP_SSVM_miniImageNet" 40 | 41 | # Metrics to plot throughout training 42 | self.few_shot_metrics = [{ 43 | "K": 1, 44 | "N": 5, 45 | "type": "classif" 46 | }, { 47 | "K": 1, 48 | "N": 5, 49 | "type": "retrieval" 50 | }] 51 | 52 | # deep dashboard location 53 | self.dashboard_path = os.path.join(DASHBOARD_LOC, self.name) 54 | 55 | # where to save checkpoints of training models 56 | self.saveloc = os.path.join(os.path.join(SAVE_LOC, self.dataset), self.name) 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Few-Shot Learning Through an Information Retrieval Lens 2 | This repository contains the code for the paper "Few-Shot Learning Through an Information Retrieval Lens". Eleni Triantafillou, Richard Zemel, Raquel Urtasun [arXiv preprint](https://arxiv.org/abs/1707.02610). 
3 | 4 | ### Prerequisites 5 | - Python 2 6 | - tensorflow version 1.3 7 | - NumPy 8 | - tqdm 9 | - cuda (if you want to run on GPU) 10 | 11 | 12 | ### Setting up 13 | The code assumes the existence of a directory named `data` within the `few_shot_mAP_public` directory. `data` contains 3 subdirectories: `omniglot` and `mini_imagenet`, containing the corresponding datasets, and another directory called `dataset_splits` also containing subdirectories for `omniglot` and `mini_imagenet` containing the data splits for these datasets (.csv files indicating which classes are meant to be used for training / validation / testing). 14 | 15 | This structure will be created by running the provided setup script. Please modify the first 4 lines of this script to add the paths to the Omniglot and mini-ImageNet datasets and their corresponding splits (the datasets and splits are not provided in this repository). 16 | ``` 17 | cd few_shot_mAP_public 18 | ./setup.sh 19 | ``` 20 | 21 | If you'd like to monitor the training progress via [Deep Dashboard](https://github.com/renmengye/deep-dashboard), please follow these instructions: 22 | - Setup Deep Dashboard as detailed here https://github.com/renmengye/deep-dashboard 23 | - In `few_shot_mAP_public/src/configs` there are a number of files, one for each example experiment (corresponding to some choice of dataset and model). Please modify the following line that is found on the top of each config file in order to point to the directory where Deep Dashboard should store its results. 24 | ``` 25 | DASHBOARD_LOC = "/u/eleni/public_html/results/few_shot_mAP/" 26 | ``` 27 | 28 | 29 | ## Reproducing our results 30 | 31 | The experiements in the paper can be reproduced by running 32 | ``` 33 | python run_train.py 34 | ``` 35 | 36 | with the appropriate tf.FLAGS set to point to the correct dataset and model. 
A config file will then be looked up (among the files in `few_shot_mAP_public/src/configs`) based on these two pieces of information and the settings in that file will be used for training. 37 | 38 | To evaluate a trained model on the benchmark tasks, you can run 39 | ``` 40 | python run_eval.py 41 | ``` 42 | Similarly as before, this requires setting the appropriate dataset and model so that the corresponding config file and model can be looked up. 43 | -------------------------------------------------------------------------------- /src/models/mAP_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, 2 | unicode_literals) 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | 8 | from src.models.model import Model 9 | from src.utils.loss_aug_AP import LossAugmentedInferenceAP 10 | 11 | 12 | class MeanAveragePrecisionModel(Model): 13 | """ 14 | Our model for optimizing Mean Average Precision. 
15 | """ 16 | 17 | def __init__(self, config, reuse=False): 18 | super(MeanAveragePrecisionModel, self).__init__(config, reuse) 19 | 20 | def compute_loss(self): 21 | if self.config.optimization_framework == "SSVM": 22 | loss = self.mAP_score_aug * self.config.alpha - self.mAP_score_GT 23 | elif self.config.optimization_framework == "DLM": # Direct Loss minimization 24 | loss = (1 / self.config.epsilon) * ( 25 | self.mAP_score_aug * self.config.alpha - self.mAP_score_std) 26 | if not self.config.positive_update: 27 | loss *= -1 28 | else: 29 | raise ValueError("Unknown optimization framework {}".format( 30 | self.config.optimization_framework)) 31 | return loss 32 | 33 | def perform_loss_augmented_inference(self, sess, batch): 34 | batch_size = len(batch["labels"]) 35 | num_pos, num_neg, pos_inds, neg_inds = self.get_positive_negative_splits( 36 | batch) 37 | _feed_dict = { 38 | self.x: batch["imgs"], 39 | self.n_queries_to_parse: self.config.batch_size, 40 | self.num_pos: num_pos, 41 | self.num_neg: num_neg, 42 | self.pos_inds: pos_inds, 43 | self.neg_inds: neg_inds 44 | } 45 | phi_pos, phi_neg, skipped_queries = sess.run( 46 | [self.phi_pos, self.phi_neg, self.skipped_queries], 47 | feed_dict=_feed_dict) 48 | Y_aug = np.zeros((batch_size, np.max(num_pos), np.max(num_neg))) 49 | for qq in range(batch_size): 50 | if skipped_queries[qq] == 1: 51 | print("Skipped {}".format(qq)) 52 | continue 53 | q_phi_pos = phi_pos[qq][:num_pos[qq]] 54 | q_phi_neg = phi_neg[qq][:num_neg[qq]] 55 | 56 | loss_aug_AP_algo = LossAugmentedInferenceAP(q_phi_pos, q_phi_neg, 57 | self.config.epsilon, 58 | self.config.positive_update) 59 | q_Y_aug = -1 * loss_aug_AP_algo.direction[1:, 1:] 60 | Y_aug[qq, :num_pos[qq], :num_neg[qq]] = q_Y_aug 61 | 62 | return Y_aug 63 | 64 | @property 65 | def config(self): 66 | return self._config 67 | -------------------------------------------------------------------------------- /src/models/siamese.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, 2 | unicode_literals) 3 | 4 | import tensorflow as tf 5 | import tensorflow.contrib.slim as slim 6 | 7 | from src.models.model import Model 8 | 9 | 10 | class SiameseModel(Model): 11 | """ 12 | Siamese model with cross-entropy loss function. 13 | """ 14 | 15 | def __init__(self, config, reuse=False): 16 | self._first_inds = tf.placeholder(tf.int32, [None]) 17 | self._second_inds = tf.placeholder(tf.int32, [None]) 18 | y_dim = 1 19 | if config.loss_function == "cross_entropy": 20 | y_dim = 2 21 | self._y = tf.placeholder(tf.float32, [None, y_dim]) 22 | super(SiameseModel, self).__init__(config, reuse) 23 | self._siamese_accuracy = self.compute_siamese_accuracy() 24 | 25 | def join_branches(self, feats_A, feats_B): 26 | feats_A = tf.truediv( 27 | feats_A, tf.sqrt(tf.reduce_sum(tf.square(feats_A), 1, keep_dims=True))) 28 | feats_B = tf.truediv( 29 | feats_B, tf.sqrt(tf.reduce_sum(tf.square(feats_B), 1, keep_dims=True))) 30 | if self.config.join_branches == "concat": 31 | pair_feats = tf.concat(1, [feats_A, feats_B]) 32 | elif self.config.join_branches == "abs_diff": 33 | pair_feats = tf.abs(feats_A - feats_B) 34 | return pair_feats 35 | 36 | def get_siamese_prediction(self): 37 | feats_A = tf.gather(self.feats, self.first_inds) 38 | feats_B = tf.gather(self.feats, self.second_inds) 39 | if self.config.loss_function == "cross_entropy": 40 | joined_feats = self.join_branches(feats_A, feats_B) 41 | pred = slim.fully_connected(joined_feats, 2, activation_fn=None) 42 | return pred 43 | 44 | def compute_loss(self): 45 | self._siamese_prediction = self.get_siamese_prediction() 46 | if self.config.loss_function == "cross_entropy": 47 | loss = tf.reduce_mean( 48 | tf.nn.softmax_cross_entropy_with_logits(logits=self.siamese_prediction, 49 | labels=self.y)) 50 | return loss 51 | 52 | def compute_siamese_accuracy(self): 53 | 
pred_softmax = tf.nn.softmax(self.siamese_prediction) 54 | correct_pred = tf.equal(tf.argmax(pred_softmax, 1), tf.argmax(self.y, 1)) 55 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) 56 | return accuracy 57 | 58 | @property 59 | def first_inds(self): 60 | return self._first_inds 61 | 62 | @property 63 | def second_inds(self): 64 | return self._second_inds 65 | 66 | @property 67 | def y(self): 68 | return self._y 69 | 70 | @property 71 | def siamese_accuracy(self): 72 | return self._siamese_accuracy 73 | 74 | @property 75 | def siamese_prediction(self): 76 | return self._siamese_prediction 77 | -------------------------------------------------------------------------------- /src/utils/loss_aug_AP.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, 2 | unicode_literals) 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | 8 | 9 | class LossAugmentedInferenceAP(object): 10 | """ 11 | Loss augmented inference algorithm of Song et al. 12 | for the task loss of Average Precision (AP). 
13 | 14 | """ 15 | 16 | def __init__(self, phi_pos, phi_neg, epsilon, positive_update=True): 17 | """ 18 | :param phi_pos: cosine similarities between the query and each positive point 19 | :param phi_neg: cosine similarities between the query and each negative point 20 | :param epsilon: float used by DLM (see the paper for details) 21 | :param positive_update: whether or not to perform positive update of DLM 22 | """ 23 | 24 | num_pos = phi_pos.shape[0] 25 | num_neg = phi_neg.shape[0] 26 | self._num_pos = num_pos 27 | self._num_neg = num_neg 28 | 29 | B, G = self.compute_B_and_G(phi_pos, phi_neg) 30 | self._B = B 31 | self._G = G 32 | 33 | if positive_update: 34 | self._negative_update = -1 35 | else: 36 | self._negative_update = 1 37 | self._epsilon = epsilon 38 | 39 | H, d = self.compute_H_and_d() 40 | self._direction = d 41 | self._H = H 42 | 43 | ranking = self.recover_ranking(d) 44 | self._ranking = ranking 45 | 46 | @property 47 | def num_pos(self): 48 | return self._num_pos 49 | 50 | @property 51 | def num_neg(self): 52 | return self._num_neg 53 | 54 | @property 55 | def B(self): 56 | return self._B 57 | 58 | @property 59 | def G(self): 60 | return self._G 61 | 62 | @property 63 | def negative_update(self): 64 | return self._negative_update 65 | 66 | @property 67 | def epsilon(self): 68 | return self._epsilon 69 | 70 | @property 71 | def direction(self): 72 | return self._direction 73 | 74 | @property 75 | def H(self): 76 | return self._H 77 | 78 | @property 79 | def ranking(self): 80 | return self._ranking 81 | 82 | def compute_B_and_G(self, phi_pos, phi_neg): 83 | B = np.zeros((self.num_pos + 1, self.num_neg + 1)) 84 | G = np.zeros((self.num_pos + 1, self.num_neg + 1)) 85 | 86 | for i in range(1, self.num_pos + 1): 87 | for j in range(1, self.num_neg + 1): 88 | B[i, j] = B[i, j - 1] - (phi_pos[i - 1] - phi_neg[j - 1]) / float( 89 | self.num_pos * self.num_neg) 90 | G[i, j] = G[i - 1, j] + (phi_pos[i - 1] - phi_neg[j - 1]) / float( 91 | self.num_pos * 
self.num_neg) 92 | 93 | return B, G 94 | 95 | def compute_H_and_d(self): 96 | H = np.zeros((self.num_pos + 1, self.num_neg + 1)) 97 | direction = np.zeros((self.num_pos + 1, self.num_neg + 1)) 98 | for i in range(self.num_pos + 1): 99 | for j in range(self.num_neg + 1): 100 | if i == 0 and j == 0: 101 | H[i, j] = 0 102 | direction[i, j] = 0 103 | continue 104 | if i == 1 and j == 0: 105 | H[i, j] = self.epsilon * self.negative_update / float(self.num_pos) 106 | direction[i, j] = 1 107 | continue 108 | if i == 0 and j == 1: 109 | H[i, j] = 0 110 | direction[i, j] = -1 111 | continue 112 | if i == 0: # but j > 1 113 | H[i, j] = H[i, j - 1] + self.G[i, j] 114 | direction[i, j] = -1 115 | continue 116 | 117 | _add_pos = self.epsilon * 1.0 / self.num_pos * i / float( 118 | i + j) * self.negative_update + self.B[i, j] 119 | if j == 0: 120 | H[i, j] = H[i - 1, j] + _add_pos 121 | direction[i, j] = 1 122 | continue 123 | if (H[i, j - 1] + self.G[i, j]) > (H[i - 1, j] + _add_pos): 124 | H[i, j] = H[i, j - 1] + self.G[i, j] 125 | direction[i, j] = -1 126 | else: 127 | H[i, j] = H[i - 1, j] + _add_pos 128 | direction[i, j] = 1 129 | return H, direction 130 | 131 | def recover_ranking(self, d): 132 | ranking = np.zeros((self.num_pos + self.num_neg)) 133 | i = self.num_pos 134 | j = self.num_neg 135 | while (i >= 0 and j >= 0 and not (i == 0 and j == 0)): 136 | if d[i, j] == 1: 137 | ranking[i - 1] = i + j - 1 138 | i -= 1 139 | else: 140 | ranking[j + self.num_pos - 1] = i + j - 1 141 | j -= 1 142 | return ranking 143 | -------------------------------------------------------------------------------- /run_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | from src.data.data_handler import get_dataset 5 | from src.models.model_handler import get_model 6 | from src.configs.configs import Configs 7 | from src.eval.evaluator import Evaluator 8 | 9 | flags = tf.flags 10 | # 
flags.DEFINE_string("model", "siamese", "Model name") 11 | flags.DEFINE_string("model", "mAP", "Model name") 12 | # flags.DEFINE_string("dataset", "omniglot", "Dataset name") 13 | flags.DEFINE_string("dataset", "mini_imagenet", "Dataset name") 14 | 15 | FLAGS = tf.flags.FLAGS 16 | 17 | OUTDIR = os.path.join("results", FLAGS.dataset) 18 | 19 | if __name__ == "__main__": 20 | configs = Configs() 21 | config = configs.get_config(FLAGS.dataset, FLAGS.model) 22 | 23 | test_dataset = get_dataset(FLAGS.dataset, config, "test") 24 | model = get_model(config) 25 | saver = tf.train.Saver(tf.global_variables()) 26 | 27 | with tf.Session() as sess: 28 | 29 | ckpt = tf.train.latest_checkpoint(config.saveloc) 30 | if ckpt: 31 | saver.restore(sess, ckpt) 32 | print("Restored weights from {}".format(config.saveloc)) 33 | 34 | # Find out the uidx that we are restoring from 35 | with open(os.path.join(config.saveloc, "checkpoint"), "r") as f: 36 | lines = f.readlines() 37 | model_checkpoint_line = lines[0].strip() 38 | dash_ind = model_checkpoint_line.rfind('-') 39 | uidx = int(model_checkpoint_line[dash_ind + 1:-1]) 40 | print("Restored from update {}".format(uidx)) 41 | else: 42 | raise ValueError( 43 | "No checkpoint to restore from in {}".format(config.saveloc)) 44 | 45 | # Create an Evaluator object 46 | test_evaluator = Evaluator(config, model, test_dataset, sess) 47 | 48 | # Perform evaluation 49 | if config.dataset == "omniglot": 50 | oneshot_5way_mAP, _ = test_evaluator.eval_oneshot_retrieval( 51 | 5, 10, num_samples=1000) 52 | oneshot_20way_mAP, _ = test_evaluator.eval_oneshot_retrieval( 53 | 20, 10, num_samples=1000) 54 | oneshot_5way_acc, _ = test_evaluator.eval_fewshot_classif( 55 | 1, 5, num_samples=1000) 56 | oneshot_20way_acc, _ = test_evaluator.eval_fewshot_classif( 57 | 1, 20, num_samples=1000) 58 | if not FLAGS.model == "siamese": 59 | fiveshot_5way_acc, _ = test_evaluator.eval_fewshot_classif( 60 | 5, 5, num_samples=1000) 61 | fiveshot_20way_acc, _ = 
test_evaluator.eval_fewshot_classif( 62 | 5, 20, num_samples=1000) 63 | 64 | elif config.dataset == "mini_imagenet": 65 | oneshot_5way_acc_mean, oneshot_5way_acc_pm = test_evaluator.eval_fewshot_classif( 66 | 1, 5, num_samples=600) 67 | if not FLAGS.model == "siamese": 68 | fiveshot_5way_acc_mean, fiveshot_5way_acc_pm = test_evaluator.eval_fewshot_classif( 69 | 5, 5, num_samples=600) 70 | oneshot_5way_mAP_mean, oneshot_5way_mAP_pm = test_evaluator.eval_oneshot_retrieval( 71 | 5, 10, num_samples=600) 72 | oneshot_20way_mAP_mean, oneshot_20way_mAP_pm = test_evaluator.eval_oneshot_retrieval( 73 | 20, 10, num_samples=600) 74 | 75 | # Save results to file 76 | if not os.path.isdir(OUTDIR): 77 | os.makedirs(OUTDIR) 78 | with open(os.path.join(OUTDIR, config.name + ".txt"), "a") as f: 79 | if config.dataset == "omniglot": 80 | f.write( 81 | "Results from model {} at update {}:\n".format(config.name, uidx)) 82 | f.write("1-shot 5-way acc {}\n".format(oneshot_5way_acc)) 83 | f.write("1-shot 20-way acc {}\n".format(oneshot_20way_acc)) 84 | if not FLAGS.model == "siamese": 85 | f.write("5-shot 5-way acc {}\n".format(fiveshot_5way_acc)) 86 | f.write("5-shot 20-way acc {}\n".format(fiveshot_20way_acc)) 87 | f.write("1-shot 5-way mAP {}\n".format(oneshot_5way_mAP)) 88 | f.write("1-shot 20-way mAP {}\n".format(oneshot_20way_mAP)) 89 | elif config.dataset == "mini_imagenet": 90 | f.write( 91 | "Results from model {} at update {}:\n".format(config.name, uidx)) 92 | f.write("1-shot 5-way acc {} plus/minus {}\n".format( 93 | oneshot_5way_acc_mean, oneshot_5way_acc_pm)) 94 | if not FLAGS.model == "siamese": 95 | f.write("5-shot 5-way acc {} plus/minus {}\n".format( 96 | fiveshot_5way_acc_mean, fiveshot_5way_acc_pm)) 97 | f.write("1-shot 5-way mAP {} plus/minus {}\n".format( 98 | oneshot_5way_mAP_mean, oneshot_5way_mAP_pm)) 99 | f.write("1-shot 20-way mAP {} plus/minus {}\n".format( 100 | oneshot_20way_mAP_mean, oneshot_20way_mAP_pm)) 101 | 
class DashboardLogger(object):
  """Creates and maintains the files deep dashboard reads for an experiment.

  One CSV (or plain-text) file is kept per tracked quantity, all under
  ``config.dashboard_path``; a "catalog" file lists every tracked file so
  the dashboard knows what to render.
  """

  def __init__(self, config):
    self._config = config
    self._name = config.name
    self._path = config.dashboard_path

    # Maps a short key for each tracked quantity to its filename
    # (relative to ``self.path``).
    # Fixed: "acc" and "F_score" were both ".csv", so the accuracy and
    # scoring-function logs collided in a single unnamed file; the
    # duplicate "1shot_20way_mAP" entry is also gone.
    self._filenames_dict = {
        "catalog": "catalog",
        "config": "config.txt",
        "notes": "notes.txt",
        "acc": "acc.csv",
        "F_score": "F_score.csv",
        "1shot_5way_acc": "1shot_5way_acc.csv",
        "1shot_20way_acc": "1shot_20way_acc.csv",
        "5shot_5way_acc": "5shot_5way_acc.csv",
        "1shot_5way_mAP": "1shot_5way_mAP.csv",
        "1shot_20way_mAP": "1shot_20way_mAP.csv",
        "mAP": "mAP.csv",
    }

    self._paths_dict = self.create_paths()

    self.print_experiment_path()
    if not config.reload:
      self.setup()

  def print_experiment_path(self):
    print("Creating dashboard logger for experiment at location {}".format(
        self.path))

  def create_paths(self):
    """Returns {key: full path} for every tracked file."""
    paths_dict = {}
    # items() instead of iteritems(): the latter is Python-2-only and the
    # file otherwise targets 2/3 compatibility via __future__ imports.
    for k, v in self.filenames_dict.items():
      paths_dict[k] = os.path.join(self.path, v)
    return paths_dict

  def add_to_catalog(self, _type, fname, catalog_entry_name):
    """Registers ``fname`` in the catalog file (idempotent)."""
    print("Experiment path: {}".format(self.path))
    if not os.path.exists(self.paths_dict["catalog"]):
      with open(self.paths_dict["catalog"], "w+") as f:
        f.write("filename,type,name\n")
    with open(self.paths_dict["catalog"], "r") as f:
      lines = f.readlines()
    found = 0
    for line in lines:
      this_line = line.split(",")
      if this_line[0] == fname:
        found = 1
        break
    if not found:
      with open(self.paths_dict["catalog"], "a") as f:
        f.write('%s,%s,%s\n' % (fname, _type, catalog_entry_name))
    return

  def setup(self):
    """Creates the experiment directory, catalog entries and CSV headers."""
    if not os.path.isdir(self.path):
      os.makedirs(self.path)
    if self.config.model_type == "siamese":
      self.add_to_catalog("csv", self.filenames_dict["acc"],
                          'Accuracy During Training')
      with open(self.paths_dict["acc"], "w") as f:
        f.write("step,time,train acc,val acc\n")
    elif self.config.model_type == "mAP":
      self.add_to_catalog("csv", self.filenames_dict["F_score"],
                          "Scoring function")
      with open(self.paths_dict["F_score"], "w") as f:
        f.write("step,time,standard,augmented\n")
    if self.config.model_type == "mAP" or self.config.compute_mAP:
      self.add_to_catalog("csv", self.filenames_dict["mAP"],
                          "Mean Average Precision")
      with open(self.paths_dict["mAP"], "w") as f:
        f.write("step,time,train mAP,val mAP\n")
    if self.config.compute_fewshot:
      self.add_to_catalog("csv", self.filenames_dict["1shot_5way_acc"],
                          "1-shot 5-way Classification")
      with open(self.paths_dict["1shot_5way_acc"], "w") as f:
        f.write("step,time,acc\n")
      if self.config.dataset == "omniglot":
        self.add_to_catalog("csv", self.filenames_dict["1shot_20way_acc"],
                            "1-shot 20-way Classification")
        with open(self.paths_dict["1shot_20way_acc"], "w") as f:
          f.write("step,time,acc\n")
        self.add_to_catalog("csv", self.filenames_dict["1shot_20way_mAP"],
                            "1-shot 20-way Retrieval")
        with open(self.paths_dict["1shot_20way_mAP"], "w") as f:
          f.write("step,time,mAP\n")
      elif self.config.dataset == "mini_imagenet":
        if not self.config.model_type == "siamese":
          self.add_to_catalog("csv", self.filenames_dict["5shot_5way_acc"],
                              "5-shot 5-way Classification")
          with open(self.paths_dict["5shot_5way_acc"], "w") as f:
            f.write("step,time,acc\n")
          self.add_to_catalog("csv", self.filenames_dict["1shot_5way_mAP"],
                              "1-shot 5-way Retrieval")
          with open(self.paths_dict["1shot_5way_mAP"], "w") as f:
            f.write("step,time,mAP\n")
          self.add_to_catalog("csv", self.filenames_dict["1shot_20way_mAP"],
                              "1-shot 20-way Retrieval")
          with open(self.paths_dict["1shot_20way_mAP"], "w") as f:
            f.write("step,time,mAP\n")
    self.add_to_catalog("plain", self.filenames_dict["config"], "Config")
    self.write_config()
    self.add_to_catalog("plain", self.filenames_dict["notes"], "Notes")
    with open(self.paths_dict["notes"], "w") as f:
      f.write("Notes:\n")

  def write_acc(self, uidx, acc_train, acc_val):
    """Appends a (step, time, train acc, val acc) row to acc.csv."""
    current_time = strftime("%Y-%m-%dT%H:%M:%S")
    with open(self.paths_dict["acc"], "a") as f:
      f.write("%d,%s,%f,%f\n" % (uidx, current_time, acc_train, acc_val))

  def write_F_score(self, uidx, standard, augmented):
    """Appends a (step, time, standard, augmented) row to F_score.csv."""
    current_time = strftime("%Y-%m-%dT%H:%M:%S")
    # Fixed: previously appended to paths_dict["acc"], so F-scores were
    # written into the accuracy log.
    with open(self.paths_dict["F_score"], "a") as f:
      f.write("%d,%s,%f,%f\n" % (uidx, current_time, standard, augmented))

  def write_Kshot_Nway_classif(self, uidx, K, N, acc):
    """Appends a classification-episode accuracy to the matching CSV."""
    current_time = strftime("%Y-%m-%dT%H:%M:%S")
    key = "{}shot_{}way_acc".format(K, N)
    with open(self.paths_dict[key], "a") as f:
      f.write("%d,%s,%f\n" % (uidx, current_time, acc))

  def write_oneshot_Nway_retrieval(self, uidx, N, mAP):
    """Appends a retrieval-episode mAP to the matching CSV."""
    key = "1shot_{}way_mAP".format(N)
    current_time = strftime("%Y-%m-%dT%H:%M:%S")
    with open(self.paths_dict[key], "a") as f:
      f.write("%d,%s,%f\n" % (uidx, current_time, mAP))

  def write_mAP(self, uidx, train_mAP, val_mAP):
    """Appends a (step, time, train mAP, val mAP) row to mAP.csv."""
    current_time = strftime("%Y-%m-%dT%H:%M:%S")
    with open(self.paths_dict["mAP"], "a") as f:
      f.write("%d,%s,%f,%f\n" % (uidx, current_time, train_mAP, val_mAP))

  def write_config(self):
    """Dumps every config attribute as "key: value" lines."""
    with open(self.paths_dict["config"], "w") as f:
      attr_dict = vars(self.config)
      for key, val in attr_dict.items():
        f.write("{}: {}\n".format(key, val))

  def take_note(self, note):
    with open(self.paths_dict["notes"], "a") as f:
      f.write("{}".format(note))

  @property
  def name(self):
    return self._name

  @property
  def config(self):
    return self._config

  @property
  def path(self):
    return self._path

  # Fixed: these three used to read self._path_notes / _path_acc /
  # _path_catalog, which were never assigned, so any access raised
  # AttributeError. They now resolve through paths_dict.
  @property
  def path_notes(self):
    return self._paths_dict["notes"]

  @property
  def path_acc(self):
    return self._paths_dict["acc"]

  @property
  def path_catalog(self):
    return self._paths_dict["catalog"]

  @property
  def paths_dict(self):
    return self._paths_dict

  @property
  def filenames_dict(self):
    return self._filenames_dict
class Dataset(object):
  """Base class for image datasets used for episodic few-shot training.

  Loads (or reads from a gzipped-pickle cache) dictionaries mapping example
  indices to on-disk image paths, plus label arrays, for the
  train/val/test splits, then keeps only the split named by ``split``.
  Subclasses implement the loading/batch-construction/episode-sampling
  methods that raise NotImplementedError below.
  """

  def __init__(self, name, config, nway, split, batch_size, cache_path):
    self._name = name
    self._height = config.height
    self._width = config.width
    self._channels = config.channels
    self._config = config
    self._nway = nway
    self._split = split
    self._batch_size = batch_size
    self._cache_path = cache_path

    (tr_dict, tr_label, val_dict, val_label, test_dict,
     test_label) = self.read_dataset()
    if self._split == "train":
      self._images_dict = tr_dict
      self._labels = tr_label
    elif self._split == "val":
      self._images_dict = val_dict
      self._labels = val_label
    elif self._split == "test":
      self._images_dict = test_dict
      self._labels = test_label
    else:
      # Fail fast: an unknown split previously left the instance without
      # images/labels and crashed later with AttributeError.
      raise ValueError("Unknown split \"{}\"".format(self._split))

    self._num_examples = len(self._images_dict)
    self._img_inds = np.arange(self._num_examples)
    self._classes = list(set(self._labels))
    self._num_classes = len(self._classes)

  def get_size(self):
    """Returns the number of examples in this split."""
    return self.num_examples

  def next_batch(self):
    """Returns the next training batch, keyed for the configured model."""
    if self.config.model_type == "siamese":
      (batch_imgs, labels, inds_A_relative, inds_B_relative,
       pair_labels) = self.get_batch_pairs()
      batch = {
          "imgs": batch_imgs,
          "labels": labels,
          "inds_A": inds_A_relative,
          "inds_B": inds_B_relative,
          "pair_labels": pair_labels
      }
    elif self.config.model_type == "mAP":
      batch_imgs, selected_labels = self.get_batch_points()
      batch = {"imgs": batch_imgs, "labels": selected_labels}
    else:
      # Previously fell through and raised UnboundLocalError on `batch`.
      raise ValueError(
          "Unknown model type \"{}\"".format(self.config.model_type))
    return batch

  def read_dataset(self):
    """Returns all three splits, from cache when present (else loads+caches)."""
    data = self.read_cache()
    if not data:
      train_dict, l_train, val_dict, l_val, test_dict, l_test = (
          self.load_dataset())
      self.save_cache(train_dict, l_train, val_dict, l_val, test_dict, l_test)
    else:
      train_dict = data["imgs_train_dict"]
      val_dict = data["imgs_val_dict"]
      test_dict = data["imgs_test_dict"]
      l_train = data["labels_train"]
      l_val = data["labels_val"]
      l_test = data["labels_test"]
    return train_dict, l_train, val_dict, l_val, test_dict, l_test

  def read_cache(self):
    """Reads dataset from cached pklz file; returns False when absent."""
    print("Attempting to read from cache in {}...".format(self.cache_path))
    if os.path.exists(self.cache_path):
      with gzip.open(self.cache_path, "rb") as f:
        data = pkl.load(f)
      print("Successful.")
      return data
    else:
      print("Unsuccessful.")
      return False

  def save_cache(self, imgs_train_dict, labels_train, imgs_val_dict,
                 labels_val, images_test_dict, labels_test):
    """Saves pklz cache."""
    data = {
        "imgs_train_dict": imgs_train_dict,
        "labels_train": labels_train,
        "imgs_val_dict": imgs_val_dict,
        "labels_val": labels_val,
        "imgs_test_dict": images_test_dict,
        "labels_test": labels_test
    }
    with gzip.open(self.cache_path, "wb") as f:
      pkl.dump(data, f, protocol=pkl.HIGHEST_PROTOCOL)

  def load_img_as_array(self, img_loc):
    """Reads the image at img_loc, resized to (height, width, channels)."""
    # NOTE(review): scipy.misc.imread/imresize were removed in SciPy >= 1.2;
    # migrate to imageio/PIL if the environment is ever upgraded.
    if self.channels == 3:
      _img_array = imread(img_loc, mode='RGB')
    else:
      _img_array = imread(img_loc, mode='L')
    img_array = imresize(
        _img_array, (self.height, self.width, self.channels), interp='bicubic')
    img_array = img_array.reshape((self.height, self.width, self.channels))
    return img_array

  def dense_to_one_hot(self, labels_dense, num_classes):
    """Convert class labels from scalars to one-hot vectors.

    Note: intentionally returns the FLAT (num_labels * num_classes,)
    array; callers reshape it to their needs (see get_batch_pairs in
    subclasses).
    """
    num_labels = labels_dense.shape[0]
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes))
    flat_labels_one_hot = labels_one_hot.flatten()
    flat_dense = labels_dense.ravel()
    flat_labels_one_hot[index_offset + flat_dense.astype('int')] = 1
    return flat_labels_one_hot

  # Fixed throughout: `raise NotImplemented` (and `NotImplemented()`)
  # raises a TypeError in Python 3 because NotImplemented is a sentinel
  # value, not an exception; NotImplementedError is the intended signal.
  def load_batch_imgs(self, imgs_inds):
    raise NotImplementedError

  def load_dataset(self):
    """
    Loads the dataset.

    :return:
      imgs_train_dict: A dictionary mapping the index of each training image to its location on disk
      labels_train: The label for each training image
      imgs_val_dict: A dictionary mapping the index of each validation image to its location on disk
      labels_val: The label for each validation image
      imgs_test_dict: A dictionary mapping the index of each test image to its location on disk
      labels_test: The label for each test image
    """
    raise NotImplementedError

  def get_batch_points(self):
    """
    Construct a batch of self.batch_size many points
    belonging to self.nway different classes.
    """
    raise NotImplementedError

  def get_batch_pairs(self):
    """
    Create a batch for training the siamese network
    by selecting self.batch_size many points of self.nway classes
    and then forming all possible pairs from these points.
    """
    raise NotImplementedError

  def create_KshotNway_classification_episode(self, K, N, split):
    """
    Create an episode for K-shot N-way classification.

    :param K: Number of representatives of each class (to be used for classification of new points)
    :param N: Number of different classes
    :return:
      ref_paths: A list of N lists, each containing K paths
      query_paths: A list of N lists, each containing (20-K) paths
      labels: A list of the N class labels
    """
    raise NotImplementedError

  def create_oneshotNway_retrieval_episode(self, N, n_per_class, split):
    """
    Sample a "pool" of images for 1-shot N-way retrieval.

    :param n_per_class: Number of sampled examples of each class in the "pool"
    :param N: Number of different classes
    :return:
      paths: The paths of each example in the pool
      labels: The corresponding class labels
    """
    raise NotImplementedError

  @property
  def name(self):
    return self._name

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def img_inds(self):
    return self._img_inds

  @property
  def nway(self):
    return self._nway

  @property
  def config(self):
    return self._config

  @property
  def split(self):
    return self._split

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def height(self):
    return self._height

  @property
  def width(self):
    return self._width

  @property
  def channels(self):
    return self._channels

  @property
  def images_dict(self):
    return self._images_dict

  @property
  def labels(self):
    return self._labels

  @property
  def classes(self):
    return self._classes

  @property
  def num_classes(self):
    return self._num_classes

  @property
  def cache_path(self):
    return self._cache_path
class MiniImageNetDataset(Dataset):
  """mini-ImageNet episodic dataset (Ravi splits).

  Expects 600 images per class under MINI_IMAGENET_FOLDER with the class
  id embedded in each filename, and per-split CSVs under
  MINI_IMAGENET_SPLITS_FOLDER.
  """

  def __init__(self, name, config, nway, split, batch_size, cache_path):
    self._cache_path = cache_path
    super(MiniImageNetDataset, self).__init__(name, config, nway, split,
                                              batch_size, cache_path)

  def get_batch_points(self):
    """Returns (imgs, labels) for a batch spanning self.nway classes."""
    batch_imgs, selected_labels, _ = self.get_batch_points_()
    return batch_imgs, selected_labels

  def get_batch_points_(self):
    """Samples self.batch_size examples from self.nway random classes.

    Returns (batch_imgs, labels, absolute indices of selected examples).
    """
    class_inds = np.random.choice(self.num_classes, self.nway, replace=False)
    selected_classes = np.array(self.classes)[class_inds]

    # Collect the indices of every example of the chosen classes.
    includable_point_inds = []
    for selected_class in selected_classes:
      sat_inds = list(np.where(self.labels == selected_class)[0])
      includable_point_inds += sat_inds
    num_includable_points = len(includable_point_inds)
    includable_point_inds_array = np.array(includable_point_inds)

    # Now randomly select batch_size many of the "allowed" examples
    selected_inds_ = np.random.choice(
        num_includable_points, self.batch_size, replace=False)
    selected_inds = includable_point_inds_array[selected_inds_]

    selected_labels = self.labels[selected_inds]
    batch_imgs = self.load_batch_imgs(selected_inds)
    return batch_imgs, selected_labels, selected_inds

  def get_batch_pairs(self):
    """Builds every pair from a sampled batch for siamese training.

    Returns (imgs, labels, relative indices A, relative indices B,
    pair labels) where a pair label is 1 iff both points share a class.
    """
    batch_imgs, labels, selected_inds = self.get_batch_points_()
    relative_inds = np.arange(self.batch_size)
    pair_inds = [list(x) for x in combinations(relative_inds, 2)]
    label1s = [labels[pair[0]] for pair in pair_inds]
    label2s = [labels[pair[1]] for pair in pair_inds]
    _pair_labels = [
        int(label1s[i] == label2s[i]) for i in range(len(pair_inds))
    ]
    inds_A_relative = [pair[0] for pair in pair_inds]
    inds_B_relative = [pair[1] for pair in pair_inds]
    pair_labels = np.array(_pair_labels)
    pair_labels = pair_labels.reshape(pair_labels.shape[0], 1)

    if self.config.loss_function == "cross_entropy":
      # dense_to_one_hot returns a flat array; reshape to (npairs, 2).
      pair_labels = self.dense_to_one_hot(pair_labels, 2)
      pair_labels = pair_labels.reshape(int(pair_labels.shape[0] / 2), 2)
    return batch_imgs, labels, inds_A_relative, inds_B_relative, pair_labels

  def read_splits(self, split):
    """Reads the split CSV; returns (filenames, unique class ids)."""
    classes, files = [], []
    csv_path = os.path.join(MINI_IMAGENET_SPLITS_FOLDER, split + ".csv")
    with open(csv_path, "r") as csvfile:
      csvreader = csv.reader(csvfile)
      for i, row in enumerate(csvreader):
        if i == 0:  # skip the header row
          continue
        if len(row[0]) == 0:  # stop at trailing blank rows
          break
        files.append(row[0])
        classes.append(row[1])
    unique_classes = list(set(classes))
    return files, unique_classes

  def load_batch_imgs(self, img_inds):
    """Loads and stacks images into a (batch, H, W, C) array.

    Fixed two bugs in the grayscale branch: the channel check used to run
    *after* reshaping to 4-D (so shape[2] was the width, not channels),
    and the RGB channels were filled from a freshly zeroed array, turning
    grayscale inputs into black images. The channel is now replicated
    from the actual pixel data before reshaping.
    """
    imgs = []
    for i in img_inds:
      img_path = self.images_dict[i]
      img_array = self.load_img_as_array(img_path)
      if self._channels == 3 and img_array.shape[-1] == 1:
        # Replicate the single channel across R, G and B.
        gray = img_array.reshape((self._height, self._width))
        img_array = np.stack([gray, gray, gray], axis=-1)
      imgs.append(
          img_array.reshape((1, self._height, self._width, self._channels)))
    if not imgs:
      return np.array([])
    # Single concatenate instead of the previous per-image concatenate,
    # which was quadratic in the batch size.
    return np.concatenate(imgs, axis=0)

  def load_dataset(self):
    """Builds {index: path} dicts and label arrays for all three splits."""
    imgs_dict_train, imgs_dict_val, imgs_dict_test = {}, {}, {}
    labels_train, labels_val, labels_test = [], [], []
    train_files, train_classes = self.read_splits("train")
    val_files, val_classes = self.read_splits("val")
    test_files, test_classes = self.read_splits("test")
    classes = train_classes + val_classes + test_classes
    all_imgs = os.listdir(MINI_IMAGENET_FOLDER)

    example_ind_train, example_ind_val, example_ind_test = 0, 0, 0
    for idx in tqdm(range(len(classes)), desc='Loading mini-ImageNet'):
      c = classes[idx]
      # Class id is embedded in each filename.
      class_imgs = [f for f in all_imgs if c in f]
      assert len(class_imgs) == 600, "Expected 600 images but found {}".format(
          len(class_imgs))

      for i in range(len(class_imgs)):
        img = class_imgs[i]
        img_path = os.path.join(MINI_IMAGENET_FOLDER, img)
        if img in train_files and c in train_classes:
          imgs_dict_train[example_ind_train] = img_path
          labels_train.append(c)
          example_ind_train += 1
        elif img in val_files and c in val_classes:
          imgs_dict_val[example_ind_val] = img_path
          labels_val.append(c)
          example_ind_val += 1
        elif img in test_files and c in test_classes:
          imgs_dict_test[example_ind_test] = img_path
          labels_test.append(c)
          example_ind_test += 1
        else:
          raise ValueError("Found an image that does not belong to any split.")

    labels_train = np.array(labels_train)
    labels_val = np.array(labels_val)
    labels_test = np.array(labels_test)

    # Standard mini-ImageNet split sizes: 64/16/20 classes x 600 images.
    assert len(labels_train
              ) == 600 * 64, "Expected {} but found {} train examples.".format(
                  600 * 64, len(labels_train))
    assert len(
        labels_val
    ) == 600 * 16, "Expected {} but found {} validation examples.".format(
        600 * 16, len(labels_val))
    assert len(labels_test
              ) == 600 * 20, "Expected {} but found {} test examples.".format(
                  600 * 20, len(labels_test))
    return imgs_dict_train, labels_train, imgs_dict_val, labels_val, imgs_dict_test, labels_test

  def create_KshotNway_classification_episode(self, K, N):
    """Samples a K-shot N-way classification episode from this split.

    Returns (ref_paths, query_paths, labels): N lists of K reference
    paths, N lists of (20 - K) query paths, and the N episode labels.
    """
    files, all_classes = self.read_splits(self.split)
    num_classes = len(all_classes)
    perm = np.arange(num_classes)
    np.random.shuffle(perm)
    chosen_class_inds = list(perm[:N])
    chosen_classes = np.array(all_classes)[chosen_class_inds]

    ref_paths, query_paths, labels = [], [], []  # Lists of lists
    all_imgs = os.listdir(MINI_IMAGENET_FOLDER)
    for class_ind, c in enumerate(chosen_classes):
      class_imgs = [
          os.path.join(MINI_IMAGENET_FOLDER, f) for f in all_imgs if c in f
      ]
      num_examples = len(class_imgs)
      perm = np.arange(num_examples)
      np.random.shuffle(perm)
      chosen_examples_inds = list(perm[:20])
      chosen_examples = np.array(class_imgs)[chosen_examples_inds]

      class_ref_paths = []
      class_query_paths = []
      for i, img in enumerate(chosen_examples):
        if i < K:  # Let the first K be representatives
          class_ref_paths.append(img)
        else:
          class_query_paths.append(img)
      ref_paths.append(class_ref_paths)
      query_paths.append(class_query_paths)
      labels.append(class_ind)
    return ref_paths, query_paths, labels

  def create_oneshotNway_retrieval_episode(self, N, n_per_class):
    """Samples a retrieval pool of n_per_class examples from N classes.

    Returns (paths, labels), each of length n_per_class * N.
    """
    files, all_classes = self.read_splits(self.split)
    num_classes = len(all_classes)
    perm = np.arange(num_classes)
    np.random.shuffle(perm)
    chosen_class_inds = list(perm[:N])
    chosen_classes = np.array(all_classes)[chosen_class_inds]

    paths, labels = [], []  # lists of length n_per_class * N
    all_imgs = os.listdir(MINI_IMAGENET_FOLDER)
    for class_ind, c in enumerate(chosen_classes):
      class_imgs = [
          os.path.join(MINI_IMAGENET_FOLDER, f) for f in all_imgs if c in f
      ]
      num_examples = len(class_imgs)
      perm = np.arange(num_examples)
      np.random.shuffle(perm)
      chosen_examples_inds = list(perm[:n_per_class])
      chosen_examples = np.array(class_imgs)[chosen_examples_inds]
      for img in chosen_examples:
        paths.append(img)
        labels.append(class_ind)
    return paths, labels
/run_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, 2 | unicode_literals) 3 | 4 | import os 5 | from tqdm import tqdm 6 | import tensorflow as tf 7 | 8 | from src.data.data_handler import get_dataset 9 | from src.models.model_handler import get_model 10 | from src.configs.configs import Configs 11 | from src.utils.dashboard_logger import DashboardLogger 12 | from src.eval.evaluator import Evaluator 13 | 14 | flags = tf.flags 15 | flags.DEFINE_string("model", "siamese", "Model name") 16 | # flags.DEFINE_string("model", "mAP", "Model name") 17 | flags.DEFINE_string("dataset", "omniglot", "Dataset name") 18 | # flags.DEFINE_string("dataset", "mini_imagenet", "Dataset name") 19 | 20 | FLAGS = tf.flags.FLAGS 21 | 22 | 23 | def _get_model(config): 24 | model = get_model(config) 25 | return model 26 | 27 | 28 | def _get_datasets(dataset_name, config): 29 | train_dataset = get_dataset(dataset_name, config, "train") 30 | val_dataset = get_dataset(dataset_name, config, "val") 31 | test_dataset = get_dataset(dataset_name, config, "test") 32 | return train_dataset, val_dataset, test_dataset 33 | 34 | 35 | def evaluate_siamese(sess, model, dataset, num_batches=100): 36 | summed_accs = 0 37 | for neval in tqdm( 38 | range(num_batches), desc="Computing validation siamese accuracy"): 39 | batch = dataset.next_batch() 40 | this_loss, this_acc = sess.run( 41 | [model.loss, model.siamese_accuracy], 42 | feed_dict={ 43 | model.x: batch["imgs"], 44 | model.y: batch["pair_labels"], 45 | model.first_inds: batch["inds_A"], 46 | model.second_inds: batch["inds_B"] 47 | }) 48 | summed_accs += this_acc 49 | 50 | return summed_accs / float(num_batches) 51 | 52 | 53 | def train(sess, config, start_uidx, model, train_dataset, val_dataset, 54 | train_evaluator, val_evaluator, dashboard_logger, saver): 55 | 56 | best_val_acc = 0 57 | best_val_mAP = 0 58 | for uidx in tqdm( 59 | 
range(start_uidx, config.niters), 60 | desc="Training model {}".format(config.name)): 61 | 62 | just_reloaded = False 63 | if config.reload and uidx == start_uidx: 64 | just_reloaded = True 65 | train_batch = train_dataset.next_batch() 66 | 67 | # Save a model checkpoint 68 | if config.save_model and uidx % config.save_freq == 0 and uidx > start_uidx: 69 | if not os.path.isdir(config.saveloc): 70 | os.makedirs(config.saveloc) 71 | print("Saving model at {}".format(config.saveloc)) 72 | saver.save( 73 | sess, os.path.join(config.saveloc, config.name), global_step=uidx) 74 | 75 | # Compute the validation performance 76 | if uidx % config.validation_freq == 0: 77 | if config.model_type == "siamese": 78 | val_acc = evaluate_siamese(sess, model, val_dataset, 79 | config.neval_batches) 80 | if val_acc > best_val_acc: 81 | best_val_acc = val_acc 82 | print( 83 | "Update {}, validation accuracy: {}, best validation accuracy so far {}". 84 | format(uidx, val_acc, best_val_acc)) 85 | elif config.model_type == "mAP": 86 | mAPs = [] 87 | for neval in tqdm( 88 | range(config.neval_batches), desc="Computing validation mAP"): 89 | val_batch = val_dataset.next_batch() 90 | mAPs.append(val_evaluator.eval_mAP(val_batch)) 91 | val_mAP = sum(mAPs) / float(config.neval_batches) 92 | if val_mAP > best_val_mAP: 93 | best_val_mAP = val_mAP 94 | print("Update {}, validation mAP: {}, best validation mAP so far {}". 
95 | format(uidx, val_mAP, best_val_mAP)) 96 | 97 | # Compute mAP performance of siamese on train/validation sets 98 | if config.model_type == "siamese" and config.compute_mAP and uidx % config.compute_mAP_freq == 0: 99 | mAPs = [] 100 | train_mAP = train_evaluator.eval_mAP(train_batch) 101 | for neval in tqdm( 102 | range(config.neval_batches), desc="Computing validation mAP"): 103 | val_batch = val_dataset.next_batch() 104 | mAPs.append(val_evaluator.eval_mAP(val_batch)) 105 | val_mAP = sum(mAPs) / float(config.neval_batches) 106 | print("Update {}, train mAP: {}, validation mAP: {}".format( 107 | uidx, train_mAP, val_mAP)) 108 | 109 | # Compute few-shot learning performance 110 | if config.compute_fewshot and uidx % config.fewshot_test_freq == 0: 111 | results = [] 112 | for metric in config.few_shot_metrics: 113 | if metric["type"] == "classif": 114 | result, _ = val_evaluator.eval_fewshot_classif( 115 | metric["K"], metric["N"]) 116 | results.append(result) 117 | elif metric["type"] == "retrieval": 118 | if not metric["K"] == 1: 119 | raise ValueError("Only 1-shot retrieval supported currently.") 120 | result, _ = val_evaluator.eval_oneshot_retrieval(metric["N"], 10) 121 | results.append(result) 122 | 123 | # Potentially adapt learning rate according to specified schedule 124 | if config.ada_learning_rate and uidx >= config.start_decr_lr and uidx % config.freq_decr_lr == 0 and not just_reloaded: 125 | current_lr = sess.run(model.lr) 126 | new_lr = current_lr * config.mult_lr_value 127 | if new_lr >= config.smallest_lr: 128 | sess.run(model.assign_lr, feed_dict={model.new_lr: new_lr}) 129 | updated_lr = sess.run(model.lr) 130 | note = "Updated lr from {} to {} in uidx {}\n".format( 131 | current_lr, updated_lr, uidx) 132 | else: 133 | note = "Reached smallest lr value {}, omitting learning rate decrease.\n".format( 134 | config.smallest_lr) 135 | print(note) 136 | dashboard_logger.take_note(note) 137 | 138 | # Perform a training step 139 | if config.model_type 
== "siamese": 140 | train_loss, train_acc, _ = sess.run( 141 | [model.loss, model.siamese_accuracy, model.train_step], 142 | feed_dict={ 143 | model.x: train_batch["imgs"], 144 | model.y: train_batch["pair_labels"], 145 | model.first_inds: train_batch["inds_A"], 146 | model.second_inds: train_batch["inds_B"] 147 | }) 148 | elif config.model_type == "mAP": 149 | num_pos, num_neg, pos_inds, neg_inds = model.get_positive_negative_splits( 150 | train_batch) 151 | Y_aug = model.perform_loss_augmented_inference(sess, train_batch) 152 | _feed_dict = { 153 | model.x: train_batch["imgs"], 154 | model.n_queries_to_parse: model.config.batch_size, 155 | model.num_pos: num_pos, 156 | model.num_neg: num_neg, 157 | model.pos_inds: pos_inds, 158 | model.neg_inds: neg_inds, 159 | model.Y_aug: Y_aug 160 | } 161 | train_loss, _, score_std, score_aug = sess.run( 162 | [ 163 | model.loss, model.train_step, model.mAP_score_std, 164 | model.mAP_score_aug 165 | ], 166 | feed_dict=_feed_dict) 167 | train_mAP = train_evaluator.eval_mAP(train_batch) 168 | 169 | # Update deep dashboard 170 | if uidx % config.update_dashboard_freq == 0 and not just_reloaded: 171 | if config.model_type == "siamese": 172 | dashboard_logger.write_acc(uidx, train_acc, val_acc) 173 | elif config.model_type == "mAP": 174 | dashboard_logger.write_F_score(uidx, score_std, score_aug) 175 | if config.model_type == "mAP" or config.compute_mAP: 176 | dashboard_logger.write_mAP(uidx, train_mAP, val_mAP) 177 | if config.compute_fewshot: 178 | for result, metric in zip(results, config.few_shot_metrics): 179 | if metric["type"] == "classif": 180 | dashboard_logger.write_Kshot_Nway_classif(uidx, metric["K"], 181 | metric["N"], result) 182 | elif metric["type"] == "retrieval": 183 | dashboard_logger.write_oneshot_Nway_retrieval( 184 | uidx, metric["N"], result) 185 | 186 | # Display training progress 187 | if uidx % config.display_step == 0: 188 | if config.model_type == "siamese": 189 | print( 190 | "Update {}, train accuracy: 
{}, best validation accuracy so far: {}". 191 | format(uidx, train_acc, best_val_acc)) 192 | elif config.model_type == "mAP": 193 | print("Update {}, train mAP: {}, best validation mAP so far: {}".format( 194 | uidx, train_mAP, best_val_mAP)) 195 | print("Train loss {}".format(train_loss)) 196 | 197 | 198 | def main(): 199 | configs = Configs() 200 | config = configs.get_config(FLAGS.dataset, FLAGS.model) 201 | 202 | train_dataset, val_dataset, test_dataset = _get_datasets( 203 | FLAGS.dataset, config) 204 | 205 | model = _get_model(config) 206 | dashboard_logger = DashboardLogger(config) 207 | saver = tf.train.Saver(tf.global_variables()) 208 | 209 | with tf.Session() as sess: 210 | if config.reload: 211 | ckpt = tf.train.latest_checkpoint(config.saveloc) 212 | if ckpt: 213 | saver.restore(sess, ckpt) 214 | print("Restored weights from {}".format(config.saveloc)) 215 | 216 | # Find out the uidx that we are restoring from 217 | with open(os.path.join(config.saveloc, "checkpoint"), "r") as f: 218 | lines = f.readlines() 219 | model_checkpoint_line = lines[0].strip() 220 | dash_ind = model_checkpoint_line.rfind('-') 221 | uidx = int(model_checkpoint_line[dash_ind + 1:-1]) 222 | print("Continuing from update uidx: " + str(uidx)) 223 | else: 224 | raise ValueError( 225 | "No checkpoint to restore from in {}".format(config.saveloc)) 226 | 227 | # If using an adaptive learning rate schedule, 228 | # resume from the appropriate point 229 | if config.ada_learning_rate: 230 | current_lr = config.lr 231 | for uidx_ in range(uidx + 1): 232 | if uidx_ >= config.start_decr_lr and uidx_ % config.freq_decr_lr == 0: 233 | new_lr = current_lr * config.mult_lr_value 234 | if new_lr >= config.smallest_lr: 235 | current_lr = new_lr 236 | config.lr = current_lr 237 | note = "Reloaded from uidx {} and using lr {}\n".format(uidx, config.lr) 238 | print(note) 239 | dashboard_logger.take_note(note) 240 | else: 241 | uidx = 0 242 | sess.run(tf.global_variables_initializer()) 243 | 244 | # 
Create Evaluator objects 245 | train_evaluator = Evaluator(config, model, train_dataset, sess) 246 | val_evaluator = Evaluator(config, model, val_dataset, sess) 247 | 248 | train(sess, config, uidx, model, train_dataset, val_dataset, 249 | train_evaluator, val_evaluator, dashboard_logger, saver) 250 | 251 | 252 | if __name__ == "__main__": 253 | main() 254 | -------------------------------------------------------------------------------- /data/dataset_splits/omniglot/test.txt: -------------------------------------------------------------------------------- 1 | Gurmukhi/character42 2 | Gurmukhi/character43 3 | Gurmukhi/character44 4 | Gurmukhi/character45 5 | Kannada/character01 6 | Kannada/character02 7 | Kannada/character03 8 | Kannada/character04 9 | Kannada/character05 10 | Kannada/character06 11 | Kannada/character07 12 | Kannada/character08 13 | Kannada/character09 14 | Kannada/character10 15 | Kannada/character11 16 | Kannada/character12 17 | Kannada/character13 18 | Kannada/character14 19 | Kannada/character15 20 | Kannada/character16 21 | Kannada/character17 22 | Kannada/character18 23 | Kannada/character19 24 | Kannada/character20 25 | Kannada/character21 26 | Kannada/character22 27 | Kannada/character23 28 | Kannada/character24 29 | Kannada/character25 30 | Kannada/character26 31 | Kannada/character27 32 | Kannada/character28 33 | Kannada/character29 34 | Kannada/character30 35 | Kannada/character31 36 | Kannada/character32 37 | Kannada/character33 38 | Kannada/character34 39 | Kannada/character35 40 | Kannada/character36 41 | Kannada/character37 42 | Kannada/character38 43 | Kannada/character39 44 | Kannada/character40 45 | Kannada/character41 46 | Keble/character01 47 | Keble/character02 48 | Keble/character03 49 | Keble/character04 50 | Keble/character05 51 | Keble/character06 52 | Keble/character07 53 | Keble/character08 54 | Keble/character09 55 | Keble/character10 56 | Keble/character11 57 | Keble/character12 58 | Keble/character13 59 | 
Keble/character14 60 | Keble/character15 61 | Keble/character16 62 | Keble/character17 63 | Keble/character18 64 | Keble/character19 65 | Keble/character20 66 | Keble/character21 67 | Keble/character22 68 | Keble/character23 69 | Keble/character24 70 | Keble/character25 71 | Keble/character26 72 | Malayalam/character01 73 | Malayalam/character02 74 | Malayalam/character03 75 | Malayalam/character04 76 | Malayalam/character05 77 | Malayalam/character06 78 | Malayalam/character07 79 | Malayalam/character08 80 | Malayalam/character09 81 | Malayalam/character10 82 | Malayalam/character11 83 | Malayalam/character12 84 | Malayalam/character13 85 | Malayalam/character14 86 | Malayalam/character15 87 | Malayalam/character16 88 | Malayalam/character17 89 | Malayalam/character18 90 | Malayalam/character19 91 | Malayalam/character20 92 | Malayalam/character21 93 | Malayalam/character22 94 | Malayalam/character23 95 | Malayalam/character24 96 | Malayalam/character25 97 | Malayalam/character26 98 | Malayalam/character27 99 | Malayalam/character28 100 | Malayalam/character29 101 | Malayalam/character30 102 | Malayalam/character31 103 | Malayalam/character32 104 | Malayalam/character33 105 | Malayalam/character34 106 | Malayalam/character35 107 | Malayalam/character36 108 | Malayalam/character37 109 | Malayalam/character38 110 | Malayalam/character39 111 | Malayalam/character40 112 | Malayalam/character41 113 | Malayalam/character42 114 | Malayalam/character43 115 | Malayalam/character44 116 | Malayalam/character45 117 | Malayalam/character46 118 | Malayalam/character47 119 | Manipuri/character01 120 | Manipuri/character02 121 | Manipuri/character03 122 | Manipuri/character04 123 | Manipuri/character05 124 | Manipuri/character06 125 | Manipuri/character07 126 | Manipuri/character08 127 | Manipuri/character09 128 | Manipuri/character10 129 | Manipuri/character11 130 | Manipuri/character12 131 | Manipuri/character13 132 | Manipuri/character14 133 | Manipuri/character15 134 | 
Manipuri/character16 135 | Manipuri/character17 136 | Manipuri/character18 137 | Manipuri/character19 138 | Manipuri/character20 139 | Manipuri/character21 140 | Manipuri/character22 141 | Manipuri/character23 142 | Manipuri/character24 143 | Manipuri/character25 144 | Manipuri/character26 145 | Manipuri/character27 146 | Manipuri/character28 147 | Manipuri/character29 148 | Manipuri/character30 149 | Manipuri/character31 150 | Manipuri/character32 151 | Manipuri/character33 152 | Manipuri/character34 153 | Manipuri/character35 154 | Manipuri/character36 155 | Manipuri/character37 156 | Manipuri/character38 157 | Manipuri/character39 158 | Manipuri/character40 159 | Mongolian/character01 160 | Mongolian/character02 161 | Mongolian/character03 162 | Mongolian/character04 163 | Mongolian/character05 164 | Mongolian/character06 165 | Mongolian/character07 166 | Mongolian/character08 167 | Mongolian/character09 168 | Mongolian/character10 169 | Mongolian/character11 170 | Mongolian/character12 171 | Mongolian/character13 172 | Mongolian/character14 173 | Mongolian/character15 174 | Mongolian/character16 175 | Mongolian/character17 176 | Mongolian/character18 177 | Mongolian/character19 178 | Mongolian/character20 179 | Mongolian/character21 180 | Mongolian/character22 181 | Mongolian/character23 182 | Mongolian/character24 183 | Mongolian/character25 184 | Mongolian/character26 185 | Mongolian/character27 186 | Mongolian/character28 187 | Mongolian/character29 188 | Mongolian/character30 189 | Old_Church_Slavonic_(Cyrillic)/character01 190 | Old_Church_Slavonic_(Cyrillic)/character02 191 | Old_Church_Slavonic_(Cyrillic)/character03 192 | Old_Church_Slavonic_(Cyrillic)/character04 193 | Old_Church_Slavonic_(Cyrillic)/character05 194 | Old_Church_Slavonic_(Cyrillic)/character06 195 | Old_Church_Slavonic_(Cyrillic)/character07 196 | Old_Church_Slavonic_(Cyrillic)/character08 197 | Old_Church_Slavonic_(Cyrillic)/character09 198 | Old_Church_Slavonic_(Cyrillic)/character10 
199 | Old_Church_Slavonic_(Cyrillic)/character11 200 | Old_Church_Slavonic_(Cyrillic)/character12 201 | Old_Church_Slavonic_(Cyrillic)/character13 202 | Old_Church_Slavonic_(Cyrillic)/character14 203 | Old_Church_Slavonic_(Cyrillic)/character15 204 | Old_Church_Slavonic_(Cyrillic)/character16 205 | Old_Church_Slavonic_(Cyrillic)/character17 206 | Old_Church_Slavonic_(Cyrillic)/character18 207 | Old_Church_Slavonic_(Cyrillic)/character19 208 | Old_Church_Slavonic_(Cyrillic)/character20 209 | Old_Church_Slavonic_(Cyrillic)/character21 210 | Old_Church_Slavonic_(Cyrillic)/character22 211 | Old_Church_Slavonic_(Cyrillic)/character23 212 | Old_Church_Slavonic_(Cyrillic)/character24 213 | Old_Church_Slavonic_(Cyrillic)/character25 214 | Old_Church_Slavonic_(Cyrillic)/character26 215 | Old_Church_Slavonic_(Cyrillic)/character27 216 | Old_Church_Slavonic_(Cyrillic)/character28 217 | Old_Church_Slavonic_(Cyrillic)/character29 218 | Old_Church_Slavonic_(Cyrillic)/character30 219 | Old_Church_Slavonic_(Cyrillic)/character31 220 | Old_Church_Slavonic_(Cyrillic)/character32 221 | Old_Church_Slavonic_(Cyrillic)/character33 222 | Old_Church_Slavonic_(Cyrillic)/character34 223 | Old_Church_Slavonic_(Cyrillic)/character35 224 | Old_Church_Slavonic_(Cyrillic)/character36 225 | Old_Church_Slavonic_(Cyrillic)/character37 226 | Old_Church_Slavonic_(Cyrillic)/character38 227 | Old_Church_Slavonic_(Cyrillic)/character39 228 | Old_Church_Slavonic_(Cyrillic)/character40 229 | Old_Church_Slavonic_(Cyrillic)/character41 230 | Old_Church_Slavonic_(Cyrillic)/character42 231 | Old_Church_Slavonic_(Cyrillic)/character43 232 | Old_Church_Slavonic_(Cyrillic)/character44 233 | Old_Church_Slavonic_(Cyrillic)/character45 234 | Oriya/character01 235 | Oriya/character02 236 | Oriya/character03 237 | Oriya/character04 238 | Oriya/character05 239 | Oriya/character06 240 | Oriya/character07 241 | Oriya/character08 242 | Oriya/character09 243 | Oriya/character10 244 | Oriya/character11 245 | 
Oriya/character12 246 | Oriya/character13 247 | Oriya/character14 248 | Oriya/character15 249 | Oriya/character16 250 | Oriya/character17 251 | Oriya/character18 252 | Oriya/character19 253 | Oriya/character20 254 | Oriya/character21 255 | Oriya/character22 256 | Oriya/character23 257 | Oriya/character24 258 | Oriya/character25 259 | Oriya/character26 260 | Oriya/character27 261 | Oriya/character28 262 | Oriya/character29 263 | Oriya/character30 264 | Oriya/character31 265 | Oriya/character32 266 | Oriya/character33 267 | Oriya/character34 268 | Oriya/character35 269 | Oriya/character36 270 | Oriya/character37 271 | Oriya/character38 272 | Oriya/character39 273 | Oriya/character40 274 | Oriya/character41 275 | Oriya/character42 276 | Oriya/character43 277 | Oriya/character44 278 | Oriya/character45 279 | Oriya/character46 280 | Syriac_(Serto)/character01 281 | Syriac_(Serto)/character02 282 | Syriac_(Serto)/character03 283 | Syriac_(Serto)/character04 284 | Syriac_(Serto)/character05 285 | Syriac_(Serto)/character06 286 | Syriac_(Serto)/character07 287 | Syriac_(Serto)/character08 288 | Syriac_(Serto)/character09 289 | Syriac_(Serto)/character10 290 | Syriac_(Serto)/character11 291 | Syriac_(Serto)/character12 292 | Syriac_(Serto)/character13 293 | Syriac_(Serto)/character14 294 | Syriac_(Serto)/character15 295 | Syriac_(Serto)/character16 296 | Syriac_(Serto)/character17 297 | Syriac_(Serto)/character18 298 | Syriac_(Serto)/character19 299 | Syriac_(Serto)/character20 300 | Syriac_(Serto)/character21 301 | Syriac_(Serto)/character22 302 | Syriac_(Serto)/character23 303 | Sylheti/character01 304 | Sylheti/character02 305 | Sylheti/character03 306 | Sylheti/character04 307 | Sylheti/character05 308 | Sylheti/character06 309 | Sylheti/character07 310 | Sylheti/character08 311 | Sylheti/character09 312 | Sylheti/character10 313 | Sylheti/character11 314 | Sylheti/character12 315 | Sylheti/character13 316 | Sylheti/character14 317 | Sylheti/character15 318 | 
Sylheti/character16 319 | Sylheti/character17 320 | Sylheti/character18 321 | Sylheti/character19 322 | Sylheti/character20 323 | Sylheti/character21 324 | Sylheti/character22 325 | Sylheti/character23 326 | Sylheti/character24 327 | Sylheti/character25 328 | Sylheti/character26 329 | Sylheti/character27 330 | Sylheti/character28 331 | Tengwar/character01 332 | Tengwar/character02 333 | Tengwar/character03 334 | Tengwar/character04 335 | Tengwar/character05 336 | Tengwar/character06 337 | Tengwar/character07 338 | Tengwar/character08 339 | Tengwar/character09 340 | Tengwar/character10 341 | Tengwar/character11 342 | Tengwar/character12 343 | Tengwar/character13 344 | Tengwar/character14 345 | Tengwar/character15 346 | Tengwar/character16 347 | Tengwar/character17 348 | Tengwar/character18 349 | Tengwar/character19 350 | Tengwar/character20 351 | Tengwar/character21 352 | Tengwar/character22 353 | Tengwar/character23 354 | Tengwar/character24 355 | Tengwar/character25 356 | Tibetan/character01 357 | Tibetan/character02 358 | Tibetan/character03 359 | Tibetan/character04 360 | Tibetan/character05 361 | Tibetan/character06 362 | Tibetan/character07 363 | Tibetan/character08 364 | Tibetan/character09 365 | Tibetan/character10 366 | Tibetan/character11 367 | Tibetan/character12 368 | Tibetan/character13 369 | Tibetan/character14 370 | Tibetan/character15 371 | Tibetan/character16 372 | Tibetan/character17 373 | Tibetan/character18 374 | Tibetan/character19 375 | Tibetan/character20 376 | Tibetan/character21 377 | Tibetan/character22 378 | Tibetan/character23 379 | Tibetan/character24 380 | Tibetan/character25 381 | Tibetan/character26 382 | Tibetan/character27 383 | Tibetan/character28 384 | Tibetan/character29 385 | Tibetan/character30 386 | Tibetan/character31 387 | Tibetan/character32 388 | Tibetan/character33 389 | Tibetan/character34 390 | Tibetan/character35 391 | Tibetan/character36 392 | Tibetan/character37 393 | Tibetan/character38 394 | Tibetan/character39 395 
| Tibetan/character40 396 | Tibetan/character41 397 | Tibetan/character42 398 | ULOG/character01 399 | ULOG/character02 400 | ULOG/character03 401 | ULOG/character04 402 | ULOG/character05 403 | ULOG/character06 404 | ULOG/character07 405 | ULOG/character08 406 | ULOG/character09 407 | ULOG/character10 408 | ULOG/character11 409 | ULOG/character12 410 | ULOG/character13 411 | ULOG/character14 412 | ULOG/character15 413 | ULOG/character16 414 | ULOG/character17 415 | ULOG/character18 416 | ULOG/character19 417 | ULOG/character20 418 | ULOG/character21 419 | ULOG/character22 420 | ULOG/character23 421 | ULOG/character24 422 | ULOG/character25 423 | ULOG/character26 424 | -------------------------------------------------------------------------------- /data/dataset_splits/omniglot/val.txt: -------------------------------------------------------------------------------- 1 | Arcadian/character05 2 | Glagolitic/character21 3 | Balinese/character02 4 | Gurmukhi/character30 5 | Asomtavruli_(Georgian)/character25 6 | Bengali/character14 7 | Sanskrit/character34 8 | Burmese_(Myanmar)/character09 9 | Hebrew/character19 10 | Atlantean/character14 11 | Bengali/character18 12 | Cyrillic/character31 13 | Gujarati/character05 14 | Greek/character11 15 | Tifinagh/character01 16 | Atemayar_Qelisayer/character09 17 | Tifinagh/character28 18 | Japanese_(hiragana)/character23 19 | Avesta/character06 20 | Latin/character26 21 | Armenian/character15 22 | Japanese_(hiragana)/character09 23 | Gurmukhi/character38 24 | N_Ko/character23 25 | Anglo-Saxon_Futhorc/character29 26 | Avesta/character11 27 | Early_Aramaic/character17 28 | Ge_ez/character21 29 | Gurmukhi/character14 30 | Armenian/character05 31 | Atemayar_Qelisayer/character24 32 | Glagolitic/character26 33 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character01 34 | Grantha/character17 35 | Tifinagh/character39 36 | Syriac_(Estrangelo)/character17 37 | Japanese_(katakana)/character03 38 | Balinese/character17 39 | 
Japanese_(hiragana)/character46 40 | Japanese_(hiragana)/character48 41 | Malay_(Jawi_-_Arabic)/character27 42 | Arcadian/character12 43 | Latin/character07 44 | Malay_(Jawi_-_Arabic)/character01 45 | Tifinagh/character41 46 | Tifinagh/character18 47 | Sanskrit/character09 48 | Japanese_(hiragana)/character30 49 | Greek/character15 50 | Armenian/character36 51 | Futurama/character08 52 | Japanese_(katakana)/character10 53 | Gurmukhi/character01 54 | Avesta/character17 55 | Futurama/character11 56 | Gujarati/character08 57 | Japanese_(hiragana)/character43 58 | Asomtavruli_(Georgian)/character14 59 | Braille/character05 60 | Armenian/character08 61 | Syriac_(Estrangelo)/character18 62 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character06 63 | Korean/character14 64 | Balinese/character19 65 | Balinese/character07 66 | Asomtavruli_(Georgian)/character19 67 | Gujarati/character23 68 | Mkhedruli_(Georgian)/character35 69 | Braille/character08 70 | Armenian/character41 71 | Malay_(Jawi_-_Arabic)/character04 72 | Glagolitic/character06 73 | Gujarati/character15 74 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character12 75 | Hebrew/character21 76 | Latin/character19 77 | N_Ko/character10 78 | N_Ko/character21 79 | Korean/character15 80 | Balinese/character16 81 | Tifinagh/character30 82 | Sanskrit/character27 83 | Bengali/character26 84 | Burmese_(Myanmar)/character27 85 | Cyrillic/character13 86 | Ge_ez/character25 87 | Gurmukhi/character17 88 | Arcadian/character15 89 | Greek/character02 90 | Burmese_(Myanmar)/character21 91 | Early_Aramaic/character13 92 | Cyrillic/character24 93 | Korean/character04 94 | Futurama/character17 95 | Bengali/character17 96 | Bengali/character08 97 | Gurmukhi/character23 98 | Greek/character24 99 | Gujarati/character41 100 | Sanskrit/character07 101 | Ge_ez/character08 102 | Mkhedruli_(Georgian)/character21 103 | Arcadian/character11 104 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character01 105 | Anglo-Saxon_Futhorc/character14 106 | 
Korean/character35 107 | Hebrew/character09 108 | Futurama/character02 109 | Bengali/character32 110 | Gujarati/character10 111 | Latin/character08 112 | Mkhedruli_(Georgian)/character17 113 | Futurama/character04 114 | Japanese_(katakana)/character06 115 | Balinese/character08 116 | N_Ko/character05 117 | Atemayar_Qelisayer/character18 118 | Syriac_(Estrangelo)/character08 119 | Japanese_(katakana)/character26 120 | Ge_ez/character23 121 | Atlantean/character18 122 | Armenian/character33 123 | Japanese_(katakana)/character11 124 | Ge_ez/character11 125 | Asomtavruli_(Georgian)/character40 126 | Angelic/character15 127 | Mkhedruli_(Georgian)/character27 128 | Japanese_(hiragana)/character36 129 | Atemayar_Qelisayer/character21 130 | Angelic/character03 131 | Burmese_(Myanmar)/character07 132 | Avesta/character01 133 | Latin/character22 134 | Alphabet_of_the_Magi/character07 135 | Latin/character01 136 | Tagalog/character04 137 | Japanese_(hiragana)/character47 138 | Gurmukhi/character16 139 | Malay_(Jawi_-_Arabic)/character25 140 | Korean/character05 141 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character09 142 | Cyrillic/character11 143 | Sanskrit/character02 144 | Korean/character26 145 | Japanese_(hiragana)/character27 146 | Cyrillic/character08 147 | Cyrillic/character29 148 | Gurmukhi/character19 149 | Bengali/character23 150 | Latin/character12 151 | Grantha/character20 152 | Greek/character12 153 | Syriac_(Estrangelo)/character22 154 | Cyrillic/character12 155 | Asomtavruli_(Georgian)/character22 156 | Malay_(Jawi_-_Arabic)/character12 157 | Avesta/character02 158 | Gurmukhi/character36 159 | Burmese_(Myanmar)/character12 160 | Asomtavruli_(Georgian)/character10 161 | Anglo-Saxon_Futhorc/character18 162 | Gurmukhi/character04 163 | Gujarati/character25 164 | Bengali/character15 165 | Glagolitic/character29 166 | Asomtavruli_(Georgian)/character30 167 | Japanese_(katakana)/character23 168 | Malay_(Jawi_-_Arabic)/character15 169 | Korean/character29 170 | 
Asomtavruli_(Georgian)/character05 171 | Arcadian/character04 172 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character11 173 | Korean/character18 174 | Gurmukhi/character09 175 | Angelic/character12 176 | Alphabet_of_the_Magi/character20 177 | Hebrew/character14 178 | Malay_(Jawi_-_Arabic)/character16 179 | Sanskrit/character05 180 | Aurek-Besh/character13 181 | Balinese/character24 182 | Angelic/character14 183 | Japanese_(hiragana)/character35 184 | Grantha/character26 185 | Tifinagh/character22 186 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character12 187 | Syriac_(Estrangelo)/character20 188 | Ge_ez/character20 189 | Early_Aramaic/character10 190 | Anglo-Saxon_Futhorc/character10 191 | Gujarati/character17 192 | Atlantean/character06 193 | Alphabet_of_the_Magi/character02 194 | N_Ko/character13 195 | Braille/character09 196 | Aurek-Besh/character05 197 | Aurek-Besh/character22 198 | Latin/character15 199 | Syriac_(Estrangelo)/character13 200 | Greek/character21 201 | Glagolitic/character44 202 | Angelic/character09 203 | Japanese_(hiragana)/character51 204 | Japanese_(katakana)/character05 205 | Arcadian/character26 206 | Burmese_(Myanmar)/character20 207 | Grantha/character04 208 | Gujarati/character27 209 | Braille/character03 210 | Cyrillic/character25 211 | Bengali/character42 212 | Bengali/character41 213 | Mkhedruli_(Georgian)/character25 214 | Armenian/character31 215 | Burmese_(Myanmar)/character32 216 | Atemayar_Qelisayer/character03 217 | Malay_(Jawi_-_Arabic)/character38 218 | Alphabet_of_the_Magi/character04 219 | Korean/character34 220 | Atlantean/character25 221 | Atlantean/character09 222 | Bengali/character33 223 | Anglo-Saxon_Futhorc/character22 224 | N_Ko/character24 225 | Ge_ez/character13 226 | Malay_(Jawi_-_Arabic)/character20 227 | Sanskrit/character04 228 | Tagalog/character06 229 | Grantha/character14 230 | Burmese_(Myanmar)/character02 231 | Cyrillic/character28 232 | Japanese_(hiragana)/character03 233 | 
Ojibwe_(Canadian_Aboriginal_Syllabics)/character08 234 | Sanskrit/character37 235 | Hebrew/character07 236 | Mkhedruli_(Georgian)/character19 237 | Aurek-Besh/character14 238 | N_Ko/character15 239 | Japanese_(katakana)/character31 240 | Cyrillic/character19 241 | Early_Aramaic/character15 242 | Grantha/character41 243 | Grantha/character12 244 | Asomtavruli_(Georgian)/character23 245 | Japanese_(katakana)/character25 246 | Sanskrit/character15 247 | Braille/character14 248 | Asomtavruli_(Georgian)/character37 249 | Burmese_(Myanmar)/character14 250 | Sanskrit/character10 251 | Japanese_(hiragana)/character13 252 | Aurek-Besh/character16 253 | Glagolitic/character39 254 | Latin/character24 255 | Grantha/character30 256 | Early_Aramaic/character04 257 | Angelic/character04 258 | Cyrillic/character26 259 | Japanese_(katakana)/character22 260 | Armenian/character07 261 | Tifinagh/character20 262 | Early_Aramaic/character06 263 | Alphabet_of_the_Magi/character03 264 | Balinese/character13 265 | Sanskrit/character21 266 | Japanese_(hiragana)/character06 267 | Arcadian/character02 268 | Bengali/character19 269 | Japanese_(hiragana)/character04 270 | Grantha/character21 271 | Mkhedruli_(Georgian)/character33 272 | Asomtavruli_(Georgian)/character02 273 | Greek/character10 274 | Aurek-Besh/character06 275 | Japanese_(katakana)/character36 276 | Sanskrit/character35 277 | N_Ko/character29 278 | Asomtavruli_(Georgian)/character29 279 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character07 280 | Latin/character05 281 | Sanskrit/character01 282 | Grantha/character35 283 | Sanskrit/character31 284 | Japanese_(hiragana)/character24 285 | Ge_ez/character06 286 | Avesta/character18 287 | Asomtavruli_(Georgian)/character28 288 | Gujarati/character38 289 | Hebrew/character04 290 | Alphabet_of_the_Magi/character18 291 | Japanese_(hiragana)/character08 292 | Japanese_(katakana)/character37 293 | Grantha/character15 294 | Aurek-Besh/character01 295 | Braille/character19 296 | 
Sanskrit/character13 297 | Burmese_(Myanmar)/character25 298 | Bengali/character38 299 | Ge_ez/character24 300 | Mkhedruli_(Georgian)/character06 301 | Atlantean/character22 302 | Burmese_(Myanmar)/character18 303 | Grantha/character36 304 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character09 305 | Asomtavruli_(Georgian)/character12 306 | Japanese_(hiragana)/character41 307 | Balinese/character20 308 | Cyrillic/character06 309 | Braille/character18 310 | Tifinagh/character52 311 | Ge_ez/character12 312 | Japanese_(katakana)/character44 313 | Japanese_(hiragana)/character45 314 | Japanese_(hiragana)/character29 315 | Early_Aramaic/character01 316 | Atemayar_Qelisayer/character10 317 | Tifinagh/character13 318 | Anglo-Saxon_Futhorc/character06 319 | Japanese_(hiragana)/character32 320 | Gurmukhi/character18 321 | Tifinagh/character05 322 | Tifinagh/character08 323 | Grantha/character18 324 | Armenian/character39 325 | Balinese/character18 326 | Hebrew/character13 327 | Gujarati/character21 328 | Glagolitic/character12 329 | Mkhedruli_(Georgian)/character36 330 | Aurek-Besh/character24 331 | Atemayar_Qelisayer/character05 332 | Armenian/character29 333 | Tifinagh/character32 334 | Sanskrit/character14 335 | Atemayar_Qelisayer/character23 336 | Arcadian/character17 337 | Arcadian/character08 338 | Greek/character18 339 | Mkhedruli_(Georgian)/character01 340 | Hebrew/character05 341 | Japanese_(katakana)/character12 342 | Tifinagh/character45 343 | Korean/character12 344 | Atlantean/character24 345 | Tifinagh/character06 346 | Braille/character01 347 | Burmese_(Myanmar)/character11 348 | N_Ko/character04 349 | Glagolitic/character34 350 | Gurmukhi/character05 351 | Braille/character10 352 | N_Ko/character31 353 | Sanskrit/character40 354 | Gurmukhi/character37 355 | Braille/character22 356 | Japanese_(katakana)/character47 357 | Anglo-Saxon_Futhorc/character01 358 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character10 359 | Avesta/character07 360 | Ge_ez/character26 
361 | Arcadian/character21 362 | Armenian/character02 363 | Burmese_(Myanmar)/character17 364 | Cyrillic/character30 365 | Gujarati/character07 366 | Aurek-Besh/character21 367 | Greek/character03 368 | Armenian/character27 369 | Anglo-Saxon_Futhorc/character25 370 | Early_Aramaic/character16 371 | Cyrillic/character01 372 | Gujarati/character32 373 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character15 374 | Gurmukhi/character02 375 | Japanese_(hiragana)/character39 376 | Alphabet_of_the_Magi/character05 377 | Mkhedruli_(Georgian)/character07 378 | Burmese_(Myanmar)/character24 379 | Tagalog/character07 380 | Hebrew/character06 381 | Sanskrit/character39 382 | Armenian/character04 383 | Cyrillic/character02 384 | Angelic/character16 385 | Armenian/character03 386 | Gujarati/character06 387 | Glagolitic/character19 388 | Avesta/character04 389 | Burmese_(Myanmar)/character33 390 | Asomtavruli_(Georgian)/character21 391 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character07 392 | Braille/character26 393 | N_Ko/character11 394 | Anglo-Saxon_Futhorc/character03 395 | Anglo-Saxon_Futhorc/character28 396 | Armenian/character06 397 | Malay_(Jawi_-_Arabic)/character08 398 | Mkhedruli_(Georgian)/character26 399 | Syriac_(Estrangelo)/character15 400 | Mkhedruli_(Georgian)/character32 401 | Glagolitic/character09 402 | Cyrillic/character16 403 | Glagolitic/character32 404 | Anglo-Saxon_Futhorc/character08 405 | Arcadian/character19 406 | Armenian/character40 407 | Ge_ez/character19 408 | Futurama/character18 409 | Tagalog/character17 410 | Aurek-Besh/character15 411 | Angelic/character13 412 | Early_Aramaic/character21 413 | N_Ko/character01 414 | Bengali/character34 415 | Korean/character23 416 | Tifinagh/character49 417 | Malay_(Jawi_-_Arabic)/character06 418 | Balinese/character15 419 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character05 420 | Sanskrit/character26 421 | Bengali/character28 422 | Armenian/character37 423 | Ge_ez/character05 424 | Futurama/character03 
425 | Korean/character10 426 | Arcadian/character07 427 | Early_Aramaic/character12 428 | Greek/character16 429 | Gujarati/character29 430 | Armenian/character24 431 | Ge_ez/character04 432 | Glagolitic/character07 433 | Tifinagh/character29 434 | Avesta/character24 435 | Sanskrit/character22 436 | Braille/character25 437 | Gurmukhi/character27 438 | Bengali/character40 439 | Atemayar_Qelisayer/character01 440 | Arcadian/character16 441 | Burmese_(Myanmar)/character04 442 | Futurama/character09 443 | Japanese_(hiragana)/character26 444 | Armenian/character22 445 | Gujarati/character26 446 | Asomtavruli_(Georgian)/character38 447 | Japanese_(hiragana)/character17 448 | Tifinagh/character23 449 | Burmese_(Myanmar)/character28 450 | Gujarati/character44 451 | Latin/character18 452 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character01 453 | Japanese_(hiragana)/character38 454 | Gurmukhi/character24 455 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character10 456 | Tifinagh/character31 457 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character13 458 | Malay_(Jawi_-_Arabic)/character30 459 | Gujarati/character12 460 | Japanese_(katakana)/character38 461 | Braille/character21 462 | Sanskrit/character16 463 | Japanese_(katakana)/character09 464 | Japanese_(katakana)/character07 465 | Armenian/character11 466 | Glagolitic/character22 467 | Sanskrit/character28 468 | Braille/character07 469 | Aurek-Besh/character18 470 | Atemayar_Qelisayer/character06 471 | Burmese_(Myanmar)/character29 472 | Bengali/character31 473 | Atemayar_Qelisayer/character19 474 | Armenian/character38 475 | Glagolitic/character33 476 | Japanese_(hiragana)/character02 477 | Sanskrit/character33 478 | Asomtavruli_(Georgian)/character08 479 | Tifinagh/character40 480 | Armenian/character13 481 | -------------------------------------------------------------------------------- /src/data/omniglot.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
(absolute_import, division, print_function, 2 | unicode_literals) 3 | 4 | import os 5 | import numpy as np 6 | from itertools import combinations 7 | 8 | from src.data.dataset import Dataset 9 | 10 | OMNIGLOT_FOLDER = "data/omniglot/omniglot-master/python/" 11 | OMNIGLOT_SPLITS_TRAINVAL = "data/dataset_splits/omniglot/trainval.txt" 12 | OMNIGLOT_SPLITS_TRAIN = "data/dataset_splits/omniglot/train.txt" 13 | OMNIGLOT_SPLITS_VAL = "data/dataset_splits/omniglot/val.txt" 14 | OMNIGLOT_SPLITS_TEST = "data/dataset_splits/omniglot/test.txt" 15 | OMNIGLOT_IMGS_BACKGROUND_ROT = OMNIGLOT_FOLDER + "images_background_resized_rot" 16 | OMNIGLOT_IMGS_EVAL_ROT = OMNIGLOT_FOLDER + "images_evaluation_resized_rot" 17 | 18 | 19 | class OmniglotDataset(Dataset): 20 | 21 | def __init__(self, name, config, nway, split, batch_size, cache_path): 22 | super(OmniglotDataset, self).__init__(name, config, nway, split, batch_size, 23 | cache_path) 24 | 25 | def get_batch_points_(self): 26 | char_classes = self.get_classes_list(self.split, augmented=False) 27 | n_classes_no_rot = len(char_classes) 28 | char_inds = np.random.choice(n_classes_no_rot, self.nway, replace=False) 29 | selected_chars = np.array(char_classes)[char_inds] 30 | rot_inds = np.random.choice(4, self.nway) 31 | angles = ['000', '090', '180', '270'] 32 | selected_angles = [angles[i] for i in rot_inds] 33 | selected_classes = np.array([ 34 | selected_chars[i] + '_rot_' + selected_angles[i] 35 | for i in range(self.nway) 36 | ]) 37 | labels_no_rot = np.array([label[:-8] for label in list(self.labels)]) 38 | rot_angles = np.array([label[-3:] for label in list(self.labels)]) 39 | includable_point_inds = [] 40 | for selected_char, selected_angle in zip(selected_chars, selected_angles): 41 | sat_char_inds = list(np.where(labels_no_rot == selected_char)[0]) 42 | sat_inds = [ 43 | ind for ind in sat_char_inds if rot_angles[ind] == selected_angle 44 | ] 45 | includable_point_inds += sat_inds 46 | 47 | num_includable_points = 
len(includable_point_inds) 48 | includable_point_inds_array = np.array(includable_point_inds) 49 | 50 | # Randomly select batch_size many of the 'allowed' examples 51 | selected_inds_inds = np.random.choice( 52 | num_includable_points, self.batch_size, replace=False) 53 | selected_inds = includable_point_inds_array[selected_inds_inds] 54 | 55 | selected_img_inds = self.img_inds[selected_inds] 56 | selected_labels = self.labels[selected_inds] 57 | batch_imgs = self.load_batch_imgs(selected_img_inds) 58 | return batch_imgs, selected_labels, selected_inds 59 | 60 | def get_batch_points(self): 61 | batch_imgs, selected_labels, selected_inds = self.get_batch_points_() 62 | return batch_imgs, selected_labels 63 | 64 | def get_batch_pairs(self): 65 | ''' 66 | Forms all pairs from a given set of images 67 | :return: a batch for training the siamese network 68 | ''' 69 | batch_imgs, labels, selected_inds = self.get_batch_points_() 70 | 71 | relative_inds = np.arange(self.batch_size) 72 | pair_inds = [list(x) for x in combinations(relative_inds, 2)] 73 | label1s = [labels[pair[0]] for pair in pair_inds] 74 | label2s = [labels[pair[1]] for pair in pair_inds] 75 | _pair_labels = [ 76 | int(label1s[i] == label2s[i]) for i in range(len(pair_inds)) 77 | ] 78 | 79 | inds_A_relative = [pair[0] for pair in pair_inds] 80 | inds_B_relative = [pair[1] for pair in pair_inds] 81 | 82 | pair_labels = np.array(_pair_labels) 83 | pair_labels = pair_labels.reshape(pair_labels.shape[0], 1) 84 | if self.config.loss_function == 'cross_entropy': 85 | pair_labels = self.dense_to_one_hot(pair_labels, 2) 86 | pair_labels = pair_labels.reshape(int(pair_labels.shape[0] / 2), 87 | 2) # for 2-way softmax 88 | return batch_imgs, labels, inds_A_relative, inds_B_relative, pair_labels 89 | 90 | def get_classes_list(self, split, augmented=True): 91 | 92 | def read_splits(fpath, augmented): 93 | classes = [] 94 | with open(fpath, "r") as f: 95 | for line in f: 96 | if len(line) > 0: 97 | _class = 
line.strip() 98 | if augmented: 99 | for s in ['_rot_000', '_rot_090', '_rot_180', '_rot_270']: 100 | _class_rot = _class + s 101 | classes.append(_class_rot) 102 | else: 103 | classes.append(_class) 104 | return classes 105 | 106 | if split == "test": 107 | classes = read_splits(OMNIGLOT_SPLITS_TEST, augmented) 108 | elif split == "train" or split == "val": 109 | if split == "train": 110 | fpath = OMNIGLOT_SPLITS_TRAIN 111 | elif split == "val": 112 | fpath = OMNIGLOT_SPLITS_VAL 113 | 114 | # Check if trainval has already been split into train and val 115 | if os.path.exists(fpath): 116 | classes = read_splits(fpath, augmented) 117 | else: 118 | # Split trainval into train and val and write to disk 119 | trainval_classes_norot = read_splits( 120 | OMNIGLOT_SPLITS_TRAINVAL, augmented=False) 121 | num_trainval_classes = len(trainval_classes_norot) 122 | perm = np.arange(num_trainval_classes) 123 | np.random.shuffle(perm) 124 | num_training_classes = int( 125 | 0.6 * 126 | num_trainval_classes) # 60/40 split of trainval into train/val 127 | train_classes_inds = list(perm[:num_training_classes]) 128 | val_classes_inds = list(perm[num_training_classes:]) 129 | train_chars = np.array(trainval_classes_norot)[train_classes_inds] 130 | val_chars = np.array(trainval_classes_norot)[val_classes_inds] 131 | 132 | assert len(train_chars) + len(val_chars) == len(trainval_classes_norot), \ 133 | "Num train chars {} + num val chars {} should be {}.".format(len(train_chars), len(val_chars), 134 | len(trainval_classes_norot)) 135 | # Write to disk 136 | with open(OMNIGLOT_SPLITS_TRAIN, "a") as f: 137 | for c in train_chars: 138 | f.write("{}\n".format(c)) 139 | with open(OMNIGLOT_SPLITS_VAL, "a") as f: 140 | for c in val_chars: 141 | f.write("{}\n".format(c)) 142 | classes = read_splits(fpath, augmented) 143 | else: 144 | raise ValueError( 145 | "Unknown split. 
Please choose one of 'train', 'val' and 'test'.") 146 | return classes 147 | 148 | def load_batch_imgs(self, img_inds): 149 | batch_imgs = np.array([]) 150 | for i in img_inds: 151 | img_path = self.images_dict[i] 152 | img_array = self.load_img_as_array(img_path) 153 | img_array = img_array.reshape((1, self._height, self._width, 154 | self._channels)) 155 | if batch_imgs.shape[0] == 0: # first image 156 | batch_imgs = img_array 157 | else: 158 | batch_imgs = np.concatenate((batch_imgs, img_array), axis=0) 159 | return batch_imgs 160 | 161 | def load_dataset(self): 162 | # Note: The trainval classes are split into training and validation classes 163 | # by randomly selecting 60 percent of the overall trainval *characters* 164 | # to be used for training (along with all 4 rotations of each) 165 | # and the remaining characters (with all their rotations) to be validation classes. 166 | root_folder_1 = OMNIGLOT_IMGS_BACKGROUND_ROT 167 | root_folder_2 = OMNIGLOT_IMGS_EVAL_ROT 168 | 169 | # dictionaries mapping index of image into the dataset 170 | # to the location of the image on disk 171 | imgs_train_dict = {} 172 | imgs_val_dict = {} 173 | imgs_test_dict = {} 174 | labels_train = [] 175 | labels_val = [] 176 | labels_test = [] 177 | 178 | # For example, a class here is: Grantha/character08_rot180 if augmented is true 179 | train_classes = self.get_classes_list("train", augmented=True) 180 | val_classes = self.get_classes_list("val", augmented=True) 181 | test_classes = self.get_classes_list("test", augmented=True) 182 | 183 | example_ind_train = 0 184 | example_ind_val = 0 185 | example_ind_test = 0 186 | num_classes_loaded = -1 187 | for c in train_classes + val_classes + test_classes: 188 | num_classes_loaded += 1 189 | slash_ind = c.find('/') 190 | alphabet = c[:slash_ind] 191 | char = c[slash_ind + 1:] 192 | 193 | # Determine which folder this alphabet belongs to 194 | path1 = os.path.join(root_folder_1, alphabet) 195 | path2 = os.path.join(root_folder_2, 
def load_dataset(self):
  """Build index->path dicts and label arrays for the train/val/test splits.

  Note: The trainval classes are split into training and validation classes
  by randomly selecting 60 percent of the overall trainval *characters*
  to be used for training (along with all 4 rotations of each)
  and the remaining characters (with all their rotations) to be validation
  classes.

  :return: (imgs_train_dict, labels_train, imgs_val_dict, labels_val,
      imgs_test_dict, labels_test) where each dict maps the index of an
      image within its split to the image's location on disk, and each
      labels array holds the class name of the corresponding image.
  """
  root_folder_1 = OMNIGLOT_IMGS_BACKGROUND_ROT
  root_folder_2 = OMNIGLOT_IMGS_EVAL_ROT

  # dictionaries mapping index of image into the dataset
  # to the location of the image on disk
  imgs_train_dict = {}
  imgs_val_dict = {}
  imgs_test_dict = {}
  labels_train = []
  labels_val = []
  labels_test = []

  # For example, a class here is: Grantha/character08_rot180 if augmented is true
  train_classes = self.get_classes_list("train", augmented=True)
  val_classes = self.get_classes_list("val", augmented=True)
  test_classes = self.get_classes_list("test", augmented=True)

  # Sets for O(1) membership tests in the per-image loop below; the previous
  # `c in train_classes` list scans made the loop quadratic in class count.
  train_set = set(train_classes)
  val_set = set(val_classes)
  test_set = set(test_classes)

  example_ind_train = 0
  example_ind_val = 0
  example_ind_test = 0
  for c in train_classes + val_classes + test_classes:
    slash_ind = c.find('/')
    alphabet = c[:slash_ind]
    char = c[slash_ind + 1:]

    # Determine which folder this alphabet belongs to
    path1 = os.path.join(root_folder_1, alphabet)
    path2 = os.path.join(root_folder_2, alphabet)
    if os.path.isdir(path1):
      alphabet_folder = path1
    else:
      alphabet_folder = path2

    # char is something like character08_rot_180
    underscore_ind = char.find('_')
    img_folder = os.path.join(alphabet_folder, char[:underscore_ind])
    rot_angle = char[-3:]  # one of '000', '090', '180', '270'
    img_files = [
        img_f for img_f in os.listdir(img_folder) if img_f[-7:-4] == rot_angle
    ]
    for img in img_files:
      img_loc = os.path.join(img_folder, img)
      if c in train_set:
        imgs_train_dict[example_ind_train] = img_loc
        example_ind_train += 1
        labels_train.append(c)
      elif c in val_set:
        imgs_val_dict[example_ind_val] = img_loc
        example_ind_val += 1
        labels_val.append(c)
      elif c in test_set:
        imgs_test_dict[example_ind_test] = img_loc
        example_ind_test += 1
        labels_test.append(c)
      else:
        raise ValueError("Found a class that does not belong to any split.")
  labels_train = np.array(labels_train)
  labels_val = np.array(labels_val)
  labels_test = np.array(labels_test)
  return imgs_train_dict, labels_train, imgs_val_dict, labels_val, imgs_test_dict, labels_test
def create_KshotNway_classification_episode(self, K, N):
  """Sample a K-shot N-way classification episode.

  Picks N random classes from the current split; for each, randomly assigns
  K of its 20 drawings as references and the remaining 20 - K as queries.

  :param K: Number of reference images per class.
  :param N: Number of classes in the episode.
  :return: (ref_paths, query_paths, labels) — two lists of N per-class path
      lists and the list of episode class ids 0..N-1.
  """
  all_classes = self.get_classes_list(self.split, augmented=True)
  num_classes = len(all_classes)
  perm = np.arange(num_classes)
  np.random.shuffle(perm)
  chosen_class_inds = list(perm[:N])

  ref_paths, query_paths, labels = [], [], []
  for n in range(N):
    c = all_classes[chosen_class_inds[n]]
    root_folder_1 = OMNIGLOT_IMGS_BACKGROUND_ROT
    root_folder_2 = OMNIGLOT_IMGS_EVAL_ROT
    slash_ind = c.find('/')
    alphabet = c[:slash_ind]
    char = c[slash_ind + 1:]

    # Determine which folder this alphabet belongs to
    # (since the new splits may have mixed which alphabets are background/evaluation)
    # with respect to these folders corresponding to the old splits.
    path1 = os.path.join(root_folder_1, alphabet)
    path2 = os.path.join(root_folder_2, alphabet)
    if os.path.isdir(path1):
      alphabet_folder = path1
    else:
      alphabet_folder = path2

    # char is something like character08_rot_180
    underscore_ind = char.find('_')
    img_folder = os.path.join(alphabet_folder, char[:underscore_ind])
    rot_angle = char[-3:]  # one of '000', '090', '180', '270'
    img_files = [
        img_f for img_f in os.listdir(img_folder) if img_f[-7:-4] == rot_angle
    ]

    # get an example of this character class
    img_example = img_files[0]  # for example 1040_06_rot_090.png
    char_baselabel = img_example[:img_example.find('_')]  # for example 1040

    # choose K images (drawers) to act as the representatives for this class
    # NOTE(review): assumes exactly 20 drawings per character — verify against
    # the dataset layout.
    drawer_inds = np.arange(20)
    np.random.shuffle(drawer_inds)
    ref_draw_inds = drawer_inds[:K]
    query_draw_inds = drawer_inds[K:]

    class_ref_paths = []
    class_query_paths = []
    for i in range(20):
      # Drawer numbers in filenames are 1-based and zero-padded to 2 digits.
      if len(str(i + 1)) < 2:
        str_ind = '0' + str(i + 1)
      else:
        str_ind = str(i + 1)
      img_name = char_baselabel + '_' + str_ind + '_rot_' + rot_angle + '.png'
      class_path = os.path.join(alphabet_folder, char[:underscore_ind])
      img_path = os.path.join(class_path, img_name)
      if i in ref_draw_inds:  # reference
        class_ref_paths.append(img_path)
      elif i in query_draw_inds:  # query
        class_query_paths.append(img_path)

    ref_paths.append(class_ref_paths)
    query_paths.append(class_query_paths)
    labels.append(n)
  return ref_paths, query_paths, labels
def create_oneshotNway_retrieval_episode(self, N, n_per_class):
  """Sample a retrieval episode: N random classes, n_per_class images each.

  :param N: Number of classes in the episode.
  :param n_per_class: Number of images (drawers) sampled per class.
  :return: (paths, labels) — parallel lists of length n_per_class * N with
      the image paths and their episode class ids (0..N-1).
  """
  all_classes = self.get_classes_list(self.split, augmented=True)
  num_classes = len(all_classes)
  perm = np.arange(num_classes)
  np.random.shuffle(perm)
  chosen_class_inds = list(perm[:N])

  paths, labels = [], []  # lists of length n_per_class * N
  for n in range(N):
    c = all_classes[chosen_class_inds[n]]
    root_folder_1 = OMNIGLOT_IMGS_BACKGROUND_ROT
    root_folder_2 = OMNIGLOT_IMGS_EVAL_ROT
    slash_ind = c.find('/')
    alphabet = c[:slash_ind]
    char = c[slash_ind + 1:]

    # Determine which folder this alphabet belongs to
    # (since the new splits may have mixed which alphabets are background/evaluation)
    # with respect to these folders corresponding to the old splits.
    path1 = os.path.join(root_folder_1, alphabet)
    path2 = os.path.join(root_folder_2, alphabet)
    if os.path.isdir(path1):
      alphabet_folder = path1
    else:
      alphabet_folder = path2

    # char is something like character08_rot_180
    underscore_ind = char.find('_')
    img_folder = os.path.join(alphabet_folder, char[:underscore_ind])
    rot_angle = char[-3:]  # one of '000', '090', '180', '270'
    img_files = [
        img_f for img_f in os.listdir(img_folder) if img_f[-7:-4] == rot_angle
    ]

    img_example = img_files[0]  # for example 1040_06_rot_090.png
    char_baselabel = img_example[:img_example.find('_')]  # for example 1040

    # Randomly choose which n_per_class of the 20 drawers to include.
    perm = np.arange(20)
    np.random.shuffle(perm)
    chosen_drawer_inds = list(perm[:n_per_class])
    # NOTE(review): draw_ind (position in os.listdir order) is used as the
    # drawer number when building the filename below — this assumes there are
    # exactly 20 matching files; confirm ordering is irrelevant given only
    # the index, not `img`, is used.
    for draw_ind, img in enumerate(img_files):
      if not draw_ind in chosen_drawer_inds:
        continue

      # Drawer numbers in filenames are 1-based and zero-padded to 2 digits.
      if len(str(draw_ind + 1)) < 2:
        str_ind = '0' + str(draw_ind + 1)
      else:
        str_ind = str(draw_ind + 1)

      img_name = char_baselabel + '_' + str_ind + '_rot_' + rot_angle + '.png'
      class_path = os.path.join(alphabet_folder, char[:underscore_ind])
      img_path = os.path.join(class_path, img_name)
      paths.append(img_path)
      labels.append(n)
  return paths, labels
def eval_mAP(self, batch):
  """Compute mean Average Precision over the queries of one batch.

  Every image in the batch acts as a query ranking the other images; queries
  the model flagged as skipped (skipped_queries == 1) are excluded.

  :param batch: Dict with "imgs" (images fed to the model) and "labels".
  :return: Mean AP over the non-skipped queries, or None if all queries
      were skipped.
  """
  imgs = batch["imgs"]
  labels = batch["labels"]
  batch_size = len(labels)

  num_pos, num_neg, pos_inds, neg_inds = self.model.get_positive_negative_splits(
      batch)
  # NOTE(review): feeds self.config.batch_size rather than len(labels) as the
  # number of queries to parse — confirm these always match for eval batches.
  _feed_dict = {
      self.model.x: imgs,
      self.model.n_queries_to_parse: self.config.batch_size,
      self.model.num_pos: num_pos,
      self.model.num_neg: num_neg,
      self.model.pos_inds: pos_inds,
      self.model.neg_inds: neg_inds
  }
  eval_skipped_queries, eval_phi_pos, eval_phi_neg = self.sess.run(
      [self.model.skipped_queries, self.model.phi_pos, self.model.phi_neg],
      feed_dict=_feed_dict)
  skip = list(np.where(eval_skipped_queries == 1)[0])

  query_APs = []
  for q in range(batch_size):
    if q in skip:
      continue
    # phi rows are padded: only the first num_pos[q] / num_neg[q] entries
    # are valid scores for query q.
    q_phi_pos = eval_phi_pos[q][:num_pos[q]]
    q_phi_neg = eval_phi_neg[q][:num_neg[q]]

    y_true = np.concatenate(
        (np.ones((q_phi_pos.shape[0])), np.zeros((q_phi_neg.shape[0]))),
        axis=0)
    y_scores = np.concatenate((q_phi_pos, q_phi_neg), axis=0)
    AP = self.apk(y_true, y_scores)
    query_APs.append(AP)

  mAP = None  # If skipped all queries
  if len(query_APs) > 0:
    mAP = sum(query_APs) / float(len(query_APs))
  return mAP
def apk(self, relevant, scores, k=10000):
  """
  Computes the average precision at k between two lists of items.

  :param relevant: A 1-D array of truth values (1/0) for each point
  :param scores: A 1-D array of predicted scores for each point
  :param k: (optional) The maximum number of ranked predictions to score
  :return: AP@k, or None when there are no relevant points.
  """
  assert len(relevant) == len(scores)
  assert relevant.ndim == 1 and scores.ndim == 1
  ranks = np.argsort(scores)[::-1]  # for example [1, 0, 3, 2]
  # sort relevance judgements in descending order of predicted score
  actual = np.array(relevant)[ranks]

  # Only the top-k ranked items contribute precision terms. The previous
  # implementation summed precision over the *entire* ranking while still
  # dividing by min(num_relevant, k), so AP could exceed 1 when k < len.
  cutoff = min(k, len(actual))
  score = 0.0
  num_hits = 0.0
  for i in range(cutoff):
    if actual[i]:
      num_hits += 1.0
      score += num_hits / (i + 1.0)
  num_relevant = len(np.where(actual == 1)[0])
  if min(num_relevant, k) > 0:
    AP = score / min(num_relevant, k)
  else:
    AP = None
  return AP
def eval_fewshot_classif(self, K, N, num_samples=None):
  """Sample K-shot N-way classification episodes and report mean accuracy.

  :param K: Number of reference examples per class (must be >= 1).
  :param N: Number of classes per episode.
  :param num_samples: How many episodes to sample; defaults to
      ``self.config.num_fewshot_samples``.
  :return: (mean, plus_minus) — mean accuracy and its 95% confidence
      half-width across the sampled episodes.
  :raises ValueError: If K < 1.
  """
  if num_samples is None:
    num_samples = self.config.num_fewshot_samples

  episode_accs = []
  for _ in range(num_samples):
    refs, queries, episode_labels = (
        self.dataset.create_KshotNway_classification_episode(K, N))
    if K == 1:
      acc = self.run_1shotNway_classif(refs, queries, episode_labels)
    elif K > 1:
      acc = self.run_KshotNway_classif(refs, queries, episode_labels)
    else:
      raise ValueError("Expecting K >= 1.")
    episode_accs.append(acc)

  mean, plus_minus = self.compute_confidence_interval(np.array(episode_accs))
  print("{}-shot-{}-way classif accuracy: {} plus/minus {}%".format(
      K, N, mean, plus_minus))
  return mean, plus_minus
def run_1shotNway_classif(self, ref_paths, query_paths, labels):
  """
  Perform 1-shot N-way classification.

  :param ref_paths: A list of N lists, each containing the path of the single class reference
  :param query_paths: A list of N lists, each containing the paths of the 19 class queries
  :param labels: A list of N labels, one for each class
  :return: mean_accuracy: Mean (across the queries) 1-shot N-way classification accuracy.
  """

  def add_to_images(imgs_array, new_img_array):
    # Reshape to rank 4 (1, h, w, c) if needed and append along axis 0.
    if not len(new_img_array.shape) == 4:
      _h = new_img_array.shape[0]
      _w = new_img_array.shape[1]
      _c = new_img_array.shape[2]
      new_img_array = new_img_array.reshape((1, _h, _w, _c))
    if imgs_array.shape[0] == 0:
      return new_img_array
    return np.concatenate((imgs_array, new_img_array), axis=0)

  def read_data():
    # Load reference/query images from disk and build their label lists.
    ref_labels, query_labels = [], []
    K = len(ref_paths[0])  # number of representatives of each class
    for class_label in labels:
      ref_labels += [class_label] * K
      query_labels += [class_label] * (20 - K)

    imgs_queries, imgs_refs = np.array([]), np.array([])
    for class_rpaths, class_qpaths in zip(ref_paths, query_paths):
      for rpath in class_rpaths:
        img_array = self.dataset.load_img_as_array(rpath)
        imgs_refs = add_to_images(imgs_refs, img_array)
      for qpath in class_qpaths:
        img_array = self.dataset.load_img_as_array(qpath)
        imgs_queries = add_to_images(imgs_queries, img_array)
    return imgs_refs, imgs_queries, ref_labels, query_labels

  imgs_refs, imgs_queries, ref_labels, query_labels = read_data()

  num_queries = imgs_queries.shape[0]

  # Embed queries and references with the model, then compute the pairwise
  # similarity matrix directly in numpy. The previous version created fresh
  # tf.placeholder/tf.matmul ops on *every call*, growing the default graph
  # unboundedly over repeated evaluations; a plain dot product suffices.
  feats_queries_ = self.sess.run(
      self.model.feats, feed_dict={self.model.x: imgs_queries})
  feats_references_ = self.sess.run(
      self.model.feats, feed_dict={self.model.x: imgs_refs})
  S_ = np.matmul(feats_queries_, feats_references_.T)  # (num_queries, num_refs)

  # Predict, for each query, the label of its most similar reference.
  pred_ind = np.argmax(S_, axis=1)
  correct = 0.0
  for i in range(num_queries):
    pred_label = ref_labels[pred_ind[i]]
    true_label = query_labels[i]
    if pred_label == true_label:
      correct += 1.0
  mean_accuracy = 100 * correct / float(num_queries)
  return mean_accuracy
def run_KshotNway_classif(self, ref_paths, query_paths, labels):
  """
  Perform K-shot N-way classification.

  The algorithm for exploiting all K examples of each of the N new classes
  in order to make a classification decision is the following.
  Given a query point to be classified:
  For each class n in N:
  Compute AP for the ranking of all K*N candidate points with the reference as query
  assuming the correct class is the nth
  (i.e. treat as relevant the points of the nth class and irrelevant everything else)
  At the end, classify the reference point into the class for which the average precision
  of the reference's ranking is highest when this class is treated as the groundtruth.

  :param ref_paths: A list of N lists, each containing the paths of the K references
  :param query_paths: A list of N lists, each containing the paths of the (20 - K) queries
  :param labels: A list of N labels, one for each class
  :return: mean_accuracy: Mean (across the queries) K-shot N-way classification accuracy.
  """

  N = len(labels)

  def add_to_images(imgs_array, new_img_array):
    # Reshape to rank 4 (1, h, w, c) if needed and append along axis 0.
    if not len(new_img_array.shape) == 4:
      _h = new_img_array.shape[0]
      _w = new_img_array.shape[1]
      _c = new_img_array.shape[2]
      new_img_array = new_img_array.reshape((1, _h, _w, _c))
    if imgs_array.shape[0] == 0:
      imgs_array = new_img_array
    else:
      imgs_array = np.concatenate((imgs_array, new_img_array), axis=0)
    return imgs_array

  def read_data():
    # Load reference/query images from disk and build their label lists.
    ref_labels, query_labels = [], []
    K = len(ref_paths[0])  # number of representatives of each class
    for class_label in labels:
      ref_labels += [class_label] * K
      query_labels += [class_label] * (20 - K)

    imgs_queries, imgs_refs = np.array([]), np.array([])
    for class_rpaths, class_qpaths in zip(ref_paths, query_paths):
      for rpath in class_rpaths:
        img_array = self.dataset.load_img_as_array(rpath)
        imgs_refs = add_to_images(imgs_refs, img_array)
      for qpath in class_qpaths:
        img_array = self.dataset.load_img_as_array(qpath)
        imgs_queries = add_to_images(imgs_queries, img_array)
    return imgs_refs, imgs_queries, ref_labels, query_labels

  def get_splits(candidates_labels, query_class):
    """
    Separate the candidates into positive/negative for the
    query under the assumption that the query's class is
    query_class.

    :param candidates_labels: The labels of all K * N candidates
    :param query_class: The (assumed) class of the query
    :return: The splits for positive/negative set of the query
    under the assumption that it belongs to class query_class.
    """

    pos_inds, neg_inds = [], []
    for i in range(len(candidates_labels)):
      if candidates_labels[i] == query_class:
        pos_inds.append(i + 1)  # +1 because position 0 holds the query itself
      else:
        neg_inds.append(i + 1)
    num_pos, num_neg = np.array([len(pos_inds)]), np.array([len(neg_inds)])
    pos_inds = np.array(pos_inds).reshape((1, -1))
    neg_inds = np.array(neg_inds).reshape((1, -1))
    return num_pos, num_neg, pos_inds, neg_inds

  imgs_refs, imgs_queries, ref_labels, query_labels = read_data()
  num_queries = len(imgs_queries)

  predicted_classes = []  # becomes length num_queries
  for q in range(num_queries):
    query_img = imgs_queries[q].reshape(
        (1, self.config.height, self.config.width, self.config.channels))
    # The query is placed first, followed by all K * N references.
    q_and_refs = np.concatenate((query_img, imgs_refs), axis=0)

    AP_for_different_ns = []  # list of length N eventually
    for n in range(N):
      num_pos, num_neg, pos_inds, neg_inds = get_splits(ref_labels, labels[n])
      _feed_dict = {
          self.model.x: q_and_refs,
          self.model.n_queries_to_parse:
              1,  # we are only interested in computing AP for query q's ranking
          self.model.num_pos: num_pos,
          self.model.num_neg: num_neg,
          self.model.pos_inds: pos_inds,
          self.model.neg_inds: neg_inds
      }
      phi_pos = self.sess.run(self.model.phi_pos, feed_dict=_feed_dict)
      phi_neg = self.sess.run(self.model.phi_neg, feed_dict=_feed_dict)
      # Only the first num_pos/num_neg entries of row 0 are valid scores.
      q_phi_pos = phi_pos[0, :][:num_pos[0]]
      q_phi_neg = phi_neg[0, :][:num_neg[0]]
      y_true = np.concatenate(
          (np.ones((q_phi_pos.shape[0])), np.zeros((q_phi_neg.shape[0]))),
          axis=0)
      y_scores = np.concatenate((q_phi_pos, q_phi_neg), axis=0)
      y_scores = y_scores.reshape((y_scores.shape[0],))
      AP = self.apk(y_true, y_scores)
      AP_for_different_ns.append(AP)

    # Predict the class whose assumed-relevance ranking scored highest.
    pred = np.argmax(AP_for_different_ns)
    predicted_classes.append(labels[pred])

  correct = 0.0
  for i in range(num_queries):
    pred_label = predicted_classes[i]
    true_label = query_labels[i]
    if pred_label == true_label:
      correct += 1.0
  mean_accuracy = 100 * correct / float(num_queries)
  return mean_accuracy
def run_1shotNway_retrieval(self, paths, labels):
  """
  Perform 1-shot N-way retrieval.
  Given a "pool" of points, treat each one as a query and compute
  the AP of its ranking of the remaining points based on its
  predicted relevance to them. Report mAP across all queries.

  :param paths: A list of length n_per_class * N of paths of each point in the "pool"
  :param labels: A list of the corresponding labels for each point in the "pool"
  :return: mAP: The mean Average Precision over all scorable queries in the
      "pool", or None if no query had any relevant candidates.
  """

  def add_to_images(imgs_array, new_img_array):
    # Reshape to rank 4 (1, h, w, c) if needed and append along axis 0.
    if not len(new_img_array.shape) == 4:
      _h = new_img_array.shape[0]
      _w = new_img_array.shape[1]
      _c = new_img_array.shape[2]
      new_img_array = new_img_array.reshape((1, _h, _w, _c))
    if imgs_array.shape[0] == 0:
      imgs_array = new_img_array
    else:
      imgs_array = np.concatenate((imgs_array, new_img_array), axis=0)
    return imgs_array

  n_queries = len(paths)
  imgs = np.array([])
  for path in paths:
    img_array = self.dataset.load_img_as_array(path)
    imgs = add_to_images(imgs, img_array)

  batch = {}
  batch["labels"] = labels
  num_pos, num_neg, pos_inds, neg_inds = self.model.get_positive_negative_splits(
      batch)
  _feed_dict = {
      self.model.x: imgs,
      self.model.n_queries_to_parse: len(imgs),
      self.model.num_pos: num_pos,
      self.model.num_neg: num_neg,
      self.model.pos_inds: pos_inds,
      self.model.neg_inds: neg_inds
  }
  phi_pos = self.sess.run(self.model.phi_pos, feed_dict=_feed_dict)
  phi_neg = self.sess.run(self.model.phi_neg, feed_dict=_feed_dict)

  query_AP = []
  for q in range(n_queries):
    # Only the first num_pos[q] / num_neg[q] entries of row q are valid.
    this_phi_pos = phi_pos[q][:num_pos[q]]
    this_phi_neg = phi_neg[q][:num_neg[q]]
    y_true = np.concatenate(
        (np.ones((this_phi_pos.shape[0])), np.zeros((this_phi_neg.shape[0]))),
        axis=0)
    y_scores = np.concatenate((this_phi_pos, this_phi_neg), axis=0)
    AP = self.apk(y_true, y_scores)
    # apk returns None for a query with no relevant candidates (e.g. when
    # n_per_class == 1); previously that None crashed the sum() below, so
    # such queries are now skipped instead.
    if AP is not None:
      query_AP.append(AP)
  if not query_AP:
    return None  # no scorable queries, consistent with eval_mAP
  mAP = sum(query_AP) / float(len(query_AP))
  return mAP
| Gurmukhi/character11 46 | Early_Aramaic/character11 47 | Sanskrit/character17 48 | Glagolitic/character35 49 | Tagalog/character11 50 | Tifinagh/character09 51 | Bengali/character02 52 | Japanese_(katakana)/character35 53 | Tifinagh/character53 54 | Asomtavruli_(Georgian)/character35 55 | N_Ko/character14 56 | Burmese_(Myanmar)/character10 57 | Armenian/character18 58 | Ge_ez/character14 59 | Gujarati/character18 60 | Syriac_(Estrangelo)/character03 61 | N_Ko/character28 62 | Burmese_(Myanmar)/character08 63 | Gurmukhi/character41 64 | Futurama/character12 65 | Mkhedruli_(Georgian)/character09 66 | Avesta/character26 67 | Avesta/character05 68 | Bengali/character21 69 | Anglo-Saxon_Futhorc/character21 70 | Angelic/character02 71 | Cyrillic/character10 72 | Korean/character28 73 | Cyrillic/character03 74 | Angelic/character17 75 | Anglo-Saxon_Futhorc/character12 76 | Anglo-Saxon_Futhorc/character23 77 | Asomtavruli_(Georgian)/character06 78 | Japanese_(katakana)/character16 79 | Gujarati/character01 80 | Anglo-Saxon_Futhorc/character24 81 | Gujarati/character13 82 | N_Ko/character26 83 | Syriac_(Estrangelo)/character01 84 | Gurmukhi/character40 85 | Mkhedruli_(Georgian)/character08 86 | Aurek-Besh/character26 87 | Gurmukhi/character07 88 | Cyrillic/character15 89 | Japanese_(katakana)/character21 90 | Glagolitic/character20 91 | N_Ko/character33 92 | Gujarati/character33 93 | Bengali/character06 94 | Tagalog/character02 95 | Anglo-Saxon_Futhorc/character13 96 | Malay_(Jawi_-_Arabic)/character39 97 | Japanese_(katakana)/character43 98 | Asomtavruli_(Georgian)/character16 99 | Sanskrit/character30 100 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character02 101 | Balinese/character05 102 | Greek/character20 103 | Korean/character38 104 | Greek/character06 105 | Early_Aramaic/character08 106 | Braille/character17 107 | Braille/character24 108 | Korean/character20 109 | Braille/character13 110 | N_Ko/character17 111 | Tifinagh/character47 112 | 
Malay_(Jawi_-_Arabic)/character02 113 | Grantha/character43 114 | Sanskrit/character11 115 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character14 116 | Atemayar_Qelisayer/character26 117 | Gujarati/character48 118 | Aurek-Besh/character09 119 | Glagolitic/character11 120 | Grantha/character10 121 | Atemayar_Qelisayer/character12 122 | Korean/character09 123 | Malay_(Jawi_-_Arabic)/character32 124 | Futurama/character19 125 | Arcadian/character13 126 | Korean/character27 127 | Tifinagh/character38 128 | Glagolitic/character17 129 | Avesta/character13 130 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character14 131 | Angelic/character07 132 | Angelic/character08 133 | Early_Aramaic/character18 134 | Mkhedruli_(Georgian)/character02 135 | Latin/character20 136 | Mkhedruli_(Georgian)/character04 137 | Malay_(Jawi_-_Arabic)/character29 138 | Malay_(Jawi_-_Arabic)/character21 139 | Armenian/character25 140 | Japanese_(katakana)/character34 141 | Bengali/character22 142 | Gurmukhi/character25 143 | Bengali/character05 144 | Greek/character09 145 | Japanese_(katakana)/character42 146 | Japanese_(katakana)/character08 147 | Ge_ez/character10 148 | Cyrillic/character07 149 | Futurama/character05 150 | Glagolitic/character28 151 | Ge_ez/character02 152 | Korean/character22 153 | Angelic/character20 154 | Avesta/character03 155 | Tifinagh/character35 156 | Tifinagh/character24 157 | Glagolitic/character10 158 | Angelic/character19 159 | Futurama/character22 160 | Gurmukhi/character26 161 | Burmese_(Myanmar)/character34 162 | Arcadian/character06 163 | Latin/character09 164 | Atlantean/character21 165 | Burmese_(Myanmar)/character03 166 | Cyrillic/character33 167 | Bengali/character10 168 | Anglo-Saxon_Futhorc/character15 169 | Arcadian/character18 170 | Atemayar_Qelisayer/character07 171 | Japanese_(katakana)/character17 172 | Gurmukhi/character12 173 | Avesta/character21 174 | Alphabet_of_the_Magi/character11 175 | Futurama/character21 176 | Grantha/character09 177 | 
Latin/character17 178 | Alphabet_of_the_Magi/character06 179 | Gurmukhi/character13 180 | Burmese_(Myanmar)/character19 181 | N_Ko/character09 182 | Grantha/character19 183 | Anglo-Saxon_Futhorc/character19 184 | Cyrillic/character22 185 | Tifinagh/character02 186 | Futurama/character14 187 | Tifinagh/character04 188 | Syriac_(Estrangelo)/character16 189 | Balinese/character21 190 | Grantha/character08 191 | Gurmukhi/character39 192 | Glagolitic/character43 193 | Glagolitic/character14 194 | Sanskrit/character25 195 | Futurama/character23 196 | Malay_(Jawi_-_Arabic)/character26 197 | Gurmukhi/character32 198 | Armenian/character12 199 | Japanese_(katakana)/character14 200 | Grantha/character33 201 | Glagolitic/character23 202 | Armenian/character35 203 | Asomtavruli_(Georgian)/character01 204 | Alphabet_of_the_Magi/character10 205 | Hebrew/character08 206 | Asomtavruli_(Georgian)/character20 207 | Korean/character08 208 | Japanese_(hiragana)/character20 209 | Avesta/character25 210 | Glagolitic/character40 211 | Gurmukhi/character34 212 | Hebrew/character16 213 | Korean/character01 214 | Avesta/character15 215 | Grantha/character38 216 | Mkhedruli_(Georgian)/character37 217 | Armenian/character21 218 | Asomtavruli_(Georgian)/character39 219 | Japanese_(hiragana)/character07 220 | Bengali/character12 221 | Early_Aramaic/character03 222 | Malay_(Jawi_-_Arabic)/character23 223 | Gujarati/character31 224 | Bengali/character35 225 | Asomtavruli_(Georgian)/character31 226 | Futurama/character15 227 | Hebrew/character01 228 | Atlantean/character26 229 | Japanese_(hiragana)/character21 230 | Futurama/character26 231 | Tifinagh/character42 232 | Gujarati/character34 233 | N_Ko/character32 234 | Balinese/character22 235 | Futurama/character20 236 | Early_Aramaic/character19 237 | Japanese_(hiragana)/character52 238 | Tifinagh/character16 239 | Avesta/character23 240 | Grantha/character39 241 | Early_Aramaic/character14 242 | Mkhedruli_(Georgian)/character16 243 | 
Gujarati/character04 244 | Balinese/character11 245 | Cyrillic/character18 246 | Burmese_(Myanmar)/character22 247 | Gurmukhi/character20 248 | Japanese_(hiragana)/character22 249 | Sanskrit/character06 250 | Syriac_(Estrangelo)/character10 251 | Syriac_(Estrangelo)/character04 252 | Latin/character13 253 | Korean/character03 254 | Syriac_(Estrangelo)/character14 255 | Malay_(Jawi_-_Arabic)/character17 256 | Tagalog/character01 257 | Japanese_(hiragana)/character05 258 | Tagalog/character16 259 | Atemayar_Qelisayer/character17 260 | Sanskrit/character32 261 | N_Ko/character20 262 | Gurmukhi/character29 263 | Japanese_(katakana)/character33 264 | Futurama/character13 265 | Asomtavruli_(Georgian)/character36 266 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character04 267 | Japanese_(katakana)/character40 268 | Japanese_(hiragana)/character49 269 | Bengali/character20 270 | Atemayar_Qelisayer/character13 271 | Ge_ez/character18 272 | Japanese_(hiragana)/character33 273 | Mkhedruli_(Georgian)/character30 274 | Braille/character12 275 | Tifinagh/character34 276 | Cyrillic/character23 277 | Balinese/character23 278 | Balinese/character10 279 | Atemayar_Qelisayer/character14 280 | Armenian/character17 281 | Alphabet_of_the_Magi/character01 282 | Japanese_(hiragana)/character37 283 | Japanese_(katakana)/character30 284 | Gujarati/character22 285 | Glagolitic/character25 286 | Tifinagh/character54 287 | Malay_(Jawi_-_Arabic)/character35 288 | Korean/character30 289 | Latin/character06 290 | Glagolitic/character41 291 | Asomtavruli_(Georgian)/character03 292 | Atlantean/character23 293 | Avesta/character16 294 | Gujarati/character09 295 | Angelic/character05 296 | Gurmukhi/character22 297 | N_Ko/character22 298 | Malay_(Jawi_-_Arabic)/character22 299 | Latin/character23 300 | Avesta/character12 301 | Japanese_(hiragana)/character01 302 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character02 303 | Tifinagh/character50 304 | Hebrew/character17 305 | Glagolitic/character42 306 | 
Blackfoot_(Canadian_Aboriginal_Syllabics)/character08 307 | Japanese_(katakana)/character28 308 | Bengali/character45 309 | Atlantean/character04 310 | Balinese/character03 311 | N_Ko/character06 312 | N_Ko/character19 313 | Tifinagh/character26 314 | Korean/character39 315 | Grantha/character32 316 | Korean/character24 317 | Gurmukhi/character08 318 | Atlantean/character10 319 | Avesta/character19 320 | Glagolitic/character03 321 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character05 322 | Grantha/character29 323 | Bengali/character27 324 | N_Ko/character02 325 | N_Ko/character18 326 | Burmese_(Myanmar)/character05 327 | Latin/character03 328 | Tifinagh/character33 329 | Atemayar_Qelisayer/character15 330 | Gujarati/character35 331 | Japanese_(hiragana)/character18 332 | Arcadian/character22 333 | Greek/character22 334 | Gujarati/character36 335 | Greek/character13 336 | Futurama/character07 337 | N_Ko/character27 338 | Asomtavruli_(Georgian)/character32 339 | Alphabet_of_the_Magi/character15 340 | Atlantean/character07 341 | Glagolitic/character15 342 | Grantha/character07 343 | Korean/character07 344 | Avesta/character08 345 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character16 346 | Anglo-Saxon_Futhorc/character04 347 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character09 348 | Hebrew/character10 349 | Japanese_(katakana)/character29 350 | Korean/character06 351 | Grantha/character34 352 | Hebrew/character18 353 | Cyrillic/character17 354 | Anglo-Saxon_Futhorc/character11 355 | Greek/character23 356 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character06 357 | Glagolitic/character13 358 | Malay_(Jawi_-_Arabic)/character13 359 | Anglo-Saxon_Futhorc/character20 360 | Latin/character25 361 | Japanese_(katakana)/character46 362 | Mkhedruli_(Georgian)/character41 363 | Asomtavruli_(Georgian)/character15 364 | Japanese_(hiragana)/character15 365 | Futurama/character10 366 | Japanese_(katakana)/character01 367 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character04 368 
| Mkhedruli_(Georgian)/character03 369 | Syriac_(Estrangelo)/character02 370 | Malay_(Jawi_-_Arabic)/character34 371 | Korean/character02 372 | Tifinagh/character48 373 | Asomtavruli_(Georgian)/character26 374 | Mkhedruli_(Georgian)/character40 375 | Gujarati/character20 376 | Grantha/character02 377 | Tagalog/character14 378 | Greek/character07 379 | N_Ko/character07 380 | Armenian/character26 381 | Atemayar_Qelisayer/character20 382 | Arcadian/character14 383 | Grantha/character25 384 | Japanese_(katakana)/character18 385 | Malay_(Jawi_-_Arabic)/character37 386 | Tifinagh/character03 387 | Cyrillic/character27 388 | Bengali/character13 389 | Tifinagh/character51 390 | Braille/character11 391 | Angelic/character18 392 | Bengali/character37 393 | Japanese_(katakana)/character04 394 | Syriac_(Estrangelo)/character11 395 | Atlantean/character17 396 | Mkhedruli_(Georgian)/character14 397 | Ge_ez/character07 398 | Glagolitic/character45 399 | Latin/character11 400 | Tifinagh/character55 401 | Japanese_(katakana)/character39 402 | Asomtavruli_(Georgian)/character18 403 | Tifinagh/character46 404 | Balinese/character06 405 | Bengali/character36 406 | Asomtavruli_(Georgian)/character13 407 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character12 408 | Alphabet_of_the_Magi/character14 409 | Gujarati/character45 410 | Armenian/character01 411 | Atlantean/character20 412 | Mkhedruli_(Georgian)/character15 413 | Japanese_(hiragana)/character44 414 | Gujarati/character24 415 | Bengali/character01 416 | Braille/character04 417 | Syriac_(Estrangelo)/character12 418 | Tagalog/character15 419 | Tifinagh/character14 420 | Korean/character36 421 | Bengali/character16 422 | Mkhedruli_(Georgian)/character31 423 | Glagolitic/character30 424 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character03 425 | Mkhedruli_(Georgian)/character24 426 | Asomtavruli_(Georgian)/character24 427 | Greek/character17 428 | Atlantean/character13 429 | Grantha/character22 430 | 
Alphabet_of_the_Magi/character17 431 | Armenian/character14 432 | Gujarati/character43 433 | Armenian/character32 434 | Ge_ez/character09 435 | Malay_(Jawi_-_Arabic)/character19 436 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character13 437 | Greek/character01 438 | Japanese_(hiragana)/character34 439 | Balinese/character12 440 | Burmese_(Myanmar)/character23 441 | Bengali/character04 442 | Glagolitic/character24 443 | Japanese_(hiragana)/character42 444 | Asomtavruli_(Georgian)/character11 445 | N_Ko/character16 446 | Tifinagh/character25 447 | Burmese_(Myanmar)/character06 448 | Japanese_(katakana)/character45 449 | Aurek-Besh/character19 450 | Braille/character16 451 | Gujarati/character28 452 | Bengali/character24 453 | Armenian/character28 454 | Aurek-Besh/character02 455 | N_Ko/character25 456 | Gurmukhi/character35 457 | Sanskrit/character24 458 | Alphabet_of_the_Magi/character12 459 | Glagolitic/character16 460 | Armenian/character34 461 | Glagolitic/character05 462 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character04 463 | Japanese_(hiragana)/character16 464 | Latin/character21 465 | Ge_ez/character16 466 | Gujarati/character37 467 | Tifinagh/character37 468 | Malay_(Jawi_-_Arabic)/character31 469 | Syriac_(Estrangelo)/character09 470 | Aurek-Besh/character03 471 | Gurmukhi/character03 472 | Bengali/character09 473 | Alphabet_of_the_Magi/character13 474 | Grantha/character05 475 | Bengali/character29 476 | Korean/character13 477 | Arcadian/character10 478 | Asomtavruli_(Georgian)/character34 479 | Tifinagh/character19 480 | Gurmukhi/character31 481 | Syriac_(Estrangelo)/character21 482 | Atlantean/character05 483 | Japanese_(katakana)/character27 484 | Avesta/character20 485 | Aurek-Besh/character04 486 | Gujarati/character40 487 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character10 488 | Latin/character04 489 | Aurek-Besh/character07 490 | Hebrew/character02 491 | Aurek-Besh/character23 492 | Glagolitic/character02 493 | Gujarati/character14 494 | 
Tagalog/character05 495 | Hebrew/character03 496 | Asomtavruli_(Georgian)/character04 497 | Bengali/character30 498 | Atlantean/character19 499 | Burmese_(Myanmar)/character16 500 | Bengali/character07 501 | Asomtavruli_(Georgian)/character09 502 | Gujarati/character30 503 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character14 504 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character08 505 | Malay_(Jawi_-_Arabic)/character09 506 | Tifinagh/character07 507 | Gurmukhi/character21 508 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character13 509 | Gurmukhi/character06 510 | Braille/character02 511 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character06 512 | Angelic/character01 513 | Hebrew/character11 514 | Tagalog/character13 515 | Braille/character15 516 | Mkhedruli_(Georgian)/character23 517 | Malay_(Jawi_-_Arabic)/character11 518 | Glagolitic/character37 519 | Atlantean/character11 520 | Atlantean/character15 521 | Tifinagh/character27 522 | Malay_(Jawi_-_Arabic)/character40 523 | Malay_(Jawi_-_Arabic)/character33 524 | Early_Aramaic/character09 525 | Bengali/character25 526 | Sanskrit/character20 527 | Cyrillic/character09 528 | Alphabet_of_the_Magi/character16 529 | Tifinagh/character44 530 | Korean/character21 531 | Atemayar_Qelisayer/character25 532 | Gurmukhi/character33 533 | Mkhedruli_(Georgian)/character11 534 | Grantha/character40 535 | Bengali/character43 536 | Futurama/character01 537 | Japanese_(hiragana)/character40 538 | Burmese_(Myanmar)/character15 539 | Arcadian/character01 540 | Gujarati/character46 541 | Gurmukhi/character10 542 | Mkhedruli_(Georgian)/character28 543 | Grantha/character37 544 | Cyrillic/character05 545 | Bengali/character46 546 | Gujarati/character19 547 | Aurek-Besh/character17 548 | Armenian/character30 549 | Asomtavruli_(Georgian)/character17 550 | Greek/character19 551 | Mkhedruli_(Georgian)/character34 552 | Japanese_(katakana)/character13 553 | Sanskrit/character08 554 | Malay_(Jawi_-_Arabic)/character36 555 | 
Glagolitic/character04 556 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character05 557 | Malay_(Jawi_-_Arabic)/character07 558 | Anglo-Saxon_Futhorc/character27 559 | Braille/character06 560 | Korean/character40 561 | Korean/character25 562 | Angelic/character06 563 | Greek/character05 564 | Atemayar_Qelisayer/character16 565 | Cyrillic/character20 566 | Avesta/character10 567 | Anglo-Saxon_Futhorc/character09 568 | Futurama/character16 569 | Greek/character04 570 | Hebrew/character20 571 | Malay_(Jawi_-_Arabic)/character28 572 | Aurek-Besh/character11 573 | Gujarati/character02 574 | Futurama/character24 575 | Arcadian/character09 576 | Sanskrit/character03 577 | Burmese_(Myanmar)/character26 578 | Anglo-Saxon_Futhorc/character05 579 | Hebrew/character12 580 | Arcadian/character25 581 | Glagolitic/character18 582 | Avesta/character09 583 | Ge_ez/character22 584 | Gurmukhi/character15 585 | Inuktitut_(Canadian_Aboriginal_Syllabics)/character11 586 | Korean/character32 587 | Arcadian/character03 588 | Aurek-Besh/character25 589 | Atlantean/character12 590 | Early_Aramaic/character05 591 | Sanskrit/character29 592 | Sanskrit/character41 593 | Armenian/character10 594 | Tifinagh/character43 595 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character02 596 | Sanskrit/character12 597 | Arcadian/character23 598 | Japanese_(hiragana)/character50 599 | Atemayar_Qelisayer/character22 600 | Balinese/character01 601 | Atlantean/character02 602 | Grantha/character24 603 | Glagolitic/character01 604 | Syriac_(Estrangelo)/character07 605 | Alphabet_of_the_Magi/character08 606 | Avesta/character22 607 | Atlantean/character03 608 | Japanese_(hiragana)/character12 609 | Balinese/character09 610 | Alphabet_of_the_Magi/character19 611 | Grantha/character31 612 | Tagalog/character09 613 | N_Ko/character08 614 | Malay_(Jawi_-_Arabic)/character24 615 | Sanskrit/character42 616 | Japanese_(katakana)/character41 617 | Latin/character14 618 | Braille/character20 619 | Braille/character23 620 
| Asomtavruli_(Georgian)/character33 621 | Greek/character14 622 | Armenian/character23 623 | Gujarati/character42 624 | Arcadian/character24 625 | Gujarati/character03 626 | Gurmukhi/character28 627 | Ge_ez/character15 628 | Asomtavruli_(Georgian)/character07 629 | Cyrillic/character04 630 | Early_Aramaic/character07 631 | Mkhedruli_(Georgian)/character38 632 | Avesta/character14 633 | Ge_ez/character03 634 | N_Ko/character03 635 | Japanese_(katakana)/character24 636 | Atemayar_Qelisayer/character04 637 | Malay_(Jawi_-_Arabic)/character03 638 | Burmese_(Myanmar)/character01 639 | Mkhedruli_(Georgian)/character10 640 | Japanese_(hiragana)/character25 641 | Early_Aramaic/character20 642 | Japanese_(hiragana)/character14 643 | Aurek-Besh/character10 644 | Ge_ez/character17 645 | Angelic/character10 646 | Hebrew/character22 647 | Japanese_(katakana)/character20 648 | Korean/character11 649 | Burmese_(Myanmar)/character30 650 | Mkhedruli_(Georgian)/character22 651 | Malay_(Jawi_-_Arabic)/character10 652 | Grantha/character13 653 | Syriac_(Estrangelo)/character23 654 | Anglo-Saxon_Futhorc/character16 655 | Aurek-Besh/character08 656 | Tagalog/character12 657 | Atlantean/character01 658 | N_Ko/character30 659 | Bengali/character03 660 | Atlantean/character08 661 | Bengali/character39 662 | Mkhedruli_(Georgian)/character12 663 | Cyrillic/character21 664 | Grantha/character28 665 | Latin/character02 666 | Syriac_(Estrangelo)/character19 667 | N_Ko/character12 668 | Ge_ez/character01 669 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character03 670 | Futurama/character25 671 | Glagolitic/character36 672 | Glagolitic/character08 673 | Japanese_(hiragana)/character31 674 | Sanskrit/character19 675 | Japanese_(katakana)/character32 676 | Burmese_(Myanmar)/character13 677 | Anglo-Saxon_Futhorc/character17 678 | Blackfoot_(Canadian_Aboriginal_Syllabics)/character03 679 | Latin/character10 680 | Hebrew/character15 681 | Angelic/character11 682 | Tifinagh/character17 683 | 
Atemayar_Qelisayer/character02 684 | Asomtavruli_(Georgian)/character27 685 | Gujarati/character16 686 | Grantha/character01 687 | Glagolitic/character38 688 | Korean/character17 689 | Bengali/character11 690 | Korean/character19 691 | Japanese_(katakana)/character02 692 | Gujarati/character39 693 | Futurama/character06 694 | Glagolitic/character31 695 | Sanskrit/character38 696 | Korean/character33 697 | Ojibwe_(Canadian_Aboriginal_Syllabics)/character07 698 | Grantha/character03 699 | Grantha/character06 700 | Syriac_(Estrangelo)/character05 701 | Korean/character16 702 | Grantha/character27 703 | Mkhedruli_(Georgian)/character05 704 | Japanese_(hiragana)/character11 705 | Tifinagh/character10 706 | Japanese_(katakana)/character15 707 | Tagalog/character08 708 | Grantha/character42 709 | Tifinagh/character21 710 | Tagalog/character03 711 | Mkhedruli_(Georgian)/character18 712 | Early_Aramaic/character22 713 | Balinese/character14 714 | Japanese_(hiragana)/character19 715 | Armenian/character09 716 | Japanese_(hiragana)/character28 717 | Balinese/character04 718 | Mkhedruli_(Georgian)/character13 719 | Greek/character08 720 | Malay_(Jawi_-_Arabic)/character05 721 | -------------------------------------------------------------------------------- /src/models/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, 2 | unicode_literals) 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | 8 | 9 | class Model(object): 10 | """ 11 | Model class. 
12 | """ 13 | 14 | def __init__(self, config, reuse=False): 15 | self._config = config 16 | self._x = tf.placeholder( 17 | tf.float32, [None, config.height, config.width, config.channels], 18 | name="x") 19 | 20 | embedding = self.forward_pass(reuse) 21 | self._feats = tf.truediv( 22 | embedding, 23 | tf.sqrt(tf.reduce_sum(tf.square(embedding), 1, keep_dims=True))) 24 | 25 | # Number of relevant points for each query 26 | self._num_pos = tf.placeholder(tf.int32, [None], name="num_pos") 27 | self._num_neg = tf.placeholder(tf.int32, [None], name="num_neg") 28 | self._batch_size = tf.shape(self._x)[0] 29 | 30 | # The inds belonging to the positive and negative sets for each query 31 | self._pos_inds = tf.placeholder(tf.int32, [None, None], name="pos_inds") 32 | self._neg_inds = tf.placeholder(tf.int32, [None, None], name="neg_inds") 33 | 34 | self._n_queries_to_parse = tf.placeholder( 35 | tf.int32, [], name="n_queries_to_parse") 36 | 37 | # The solution of loss-augmented inference for each query 38 | self._Y_aug = tf.placeholder( 39 | tf.float32, [None, None, None], 40 | name="Y_aug") # (num queries, num_pos, num_neg) 41 | 42 | self._phi_pos, self._phi_neg, self._mAP_score_std, \ 43 | self._mAP_score_aug, self._mAP_score_GT, self._skipped_queries = self.perform_inference_mAP() 44 | self._loss = self.compute_loss() 45 | self._train_step = self.get_train_step() 46 | 47 | def forward_pass(self, reuse, _print=True): 48 | """ 49 | Perform a forward pass through the network 50 | 51 | :param reuse: Whether to re-use the network's parameters. 52 | :param _print: Whether to print the shapes of activations. 
53 | :return: 54 | """ 55 | 56 | def print_activations(t, _print=True): 57 | if _print: 58 | print(t.op.name, ' ', t.get_shape().as_list()) 59 | 60 | conv1 = slim.convolution2d( 61 | self.x, 62 | 64, 63 | kernel_size=[3, 3], 64 | activation_fn=slim.nn.relu, 65 | normalizer_fn=slim.batch_norm, 66 | scope='conv1', 67 | reuse=reuse) 68 | print_activations(conv1, _print=_print) 69 | pool1 = slim.max_pool2d(conv1, [2, 2], padding='VALID', scope='pool1') 70 | print_activations(pool1, _print=_print) 71 | 72 | conv2 = slim.convolution2d( 73 | pool1, 74 | 64, 75 | kernel_size=[3, 3], 76 | activation_fn=slim.nn.relu, 77 | normalizer_fn=slim.batch_norm, 78 | scope='conv2', 79 | reuse=reuse) 80 | print_activations(conv2, _print=_print) 81 | pool2 = slim.max_pool2d(conv2, [2, 2], padding='VALID', scope='pool2') 82 | print_activations(pool2, _print=_print) 83 | 84 | conv3 = slim.convolution2d( 85 | pool2, 86 | 64, 87 | kernel_size=[3, 3], 88 | activation_fn=slim.nn.relu, 89 | normalizer_fn=slim.batch_norm, 90 | scope='conv3', 91 | reuse=reuse) 92 | print_activations(conv3, _print=_print) 93 | pool3 = slim.max_pool2d(conv3, [2, 2], padding='VALID', scope='pool3') 94 | print_activations(pool3, _print=_print) 95 | 96 | conv4 = slim.convolution2d( 97 | pool3, 98 | 64, 99 | kernel_size=[3, 3], 100 | activation_fn=slim.nn.relu, 101 | normalizer_fn=slim.batch_norm, 102 | scope='conv4', 103 | reuse=reuse) 104 | print_activations(conv4, _print=_print) 105 | pool4 = slim.max_pool2d(conv4, [2, 2], padding='VALID', scope='pool4') 106 | print_activations(pool4, _print=_print) 107 | 108 | all_but_first_dims = pool4.get_shape().as_list()[1:] 109 | mult_dims = 1 110 | for dim in all_but_first_dims: 111 | mult_dims = mult_dims * dim 112 | embedding = tf.reshape(pool4, [-1, mult_dims]) 113 | print_activations(embedding, _print=_print) 114 | 115 | return embedding 116 | 117 | def get_train_step(self): 118 | lr = tf.get_variable( 119 | "learning_rate", 120 | shape=[], 121 | 
initializer=tf.constant_initializer(self.config.lr), 122 | dtype=tf.float32, 123 | trainable=False) 124 | if self.config.optimizer == 'ADAM': 125 | optimizer = tf.train.AdamOptimizer(learning_rate=lr) 126 | elif self.config.optimizer == 'SGD': 127 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=lr) 128 | elif self.config.optimizer == 'SGD_momentum': 129 | optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9) 130 | train_step = optimizer.minimize(self.loss) 131 | self._lr = lr 132 | self._new_lr = tf.placeholder(tf.float32, [], name="new_lr") 133 | self._assign_lr = tf.assign(self._lr, self._new_lr) 134 | return train_step 135 | 136 | def compute_siamese_accuracy(self): 137 | raise NotImplemented 138 | 139 | def get_positive_negative_splits(self, batch): 140 | """ 141 | Compute the positive/negative sets for each query in the batch. 142 | 143 | :param batch: A batch of points 144 | :return: num_pos: An array of the number of positive points of each query 145 | :return: num_neg: An array of the number of negative points of each query 146 | :return: pos_inds: An array of the inds of positive points of each query 147 | :return: neg_inds: An array of the inds of negative points of each query 148 | """ 149 | 150 | pos_inds, neg_inds = [], [] # Lists of lists 151 | num_pos, num_neg = [], [] 152 | batch_labels = batch["labels"] 153 | batch_size = len(batch_labels) 154 | 155 | def get_query_splits(query_ind): 156 | """ 157 | Create the splits for a single query. 
158 | 159 | :param query_ind: Index of the query in the batch 160 | :return: q_pos_inds: The inds of the query's positive points 161 | :return: q_neg_inds: The inds of the query's negative points 162 | """ 163 | 164 | query_label = batch_labels[query_ind] 165 | q_pos_inds, q_neg_inds = [], [] 166 | for i in range(batch_size): 167 | if i == query_ind: 168 | continue 169 | this_label = batch_labels[i] 170 | if this_label == query_label: 171 | q_pos_inds.append(i) 172 | else: # negative point 173 | q_neg_inds.append(i) 174 | return q_pos_inds, q_neg_inds 175 | 176 | for i in range(batch_size): 177 | q_pos_inds, q_neg_inds = get_query_splits(i) 178 | pos_inds.append(q_pos_inds) 179 | neg_inds.append(q_neg_inds) 180 | num_pos.append(len(q_pos_inds)) 181 | num_neg.append(len(q_neg_inds)) 182 | 183 | # Pad all lists to be max length and make np arrays 184 | max_pos, max_neg = max(num_pos), max(num_neg) 185 | pos_inds_np = np.zeros((batch_size, max_pos)) 186 | neg_inds_np = np.zeros((batch_size, max_neg)) 187 | for i in range(batch_size): 188 | this_len_pos = len(pos_inds[i]) 189 | pos_inds_np[i, :this_len_pos] = pos_inds[i] 190 | this_len_neg = len(neg_inds[i]) 191 | neg_inds_np[i, :this_len_neg] = neg_inds[i] 192 | num_pos, num_neg = np.array(num_pos), np.array(num_neg) 193 | pos_inds, neg_inds = pos_inds_np, neg_inds_np 194 | return [num_pos, num_neg, pos_inds, neg_inds] 195 | 196 | def perform_query_inference(self, q_feats, q_pos_feats, q_neg_feats, 197 | q_num_pos, q_num_neg, q_Y_aug): 198 | """ 199 | Inference for a specific query. 
200 | 201 | :param q_feats: the features for the query 202 | :param q_pos_feats: the features of the query's positive points 203 | :param q_neg_feats: the features of the query's negative points 204 | :param q_num_pos: the number of positive points for the query 205 | :param q_num_neg: the number of negative points for the query 206 | :param q_Y_aug: the solution of loss-augmented inference for this query 207 | 208 | :return: phi_pos: the similarity between the query and each positive point 209 | :return: phi_neg: the similarity between the query and each negative point 210 | :return: AP_score_std: the score of the standard inference solution for AP of this query 211 | :return: AP_score_aug: the score of the loss-augmented inference solution for AP of this query 212 | :return: AP_score_GT: the score of the ground truth solution for AP of this query 213 | """ 214 | 215 | S_pos = tf.matmul(q_feats, q_pos_feats, transpose_b=True) # (1, num_pos) 216 | S_neg = tf.matmul(q_feats, q_neg_feats, transpose_b=True) # (1, num_neg) 217 | phi_pos, sorted_inds_pos = tf.nn.top_k(S_pos, k=q_num_pos) 218 | phi_neg, sorted_inds_neg = tf.nn.top_k(S_neg, k=q_num_neg) 219 | 220 | phi_pos = tf.transpose(phi_pos) 221 | phi_neg = tf.transpose(phi_neg) 222 | 223 | # Score of standard inference 224 | phi_pos_expanded = tf.tile(phi_pos, [1, q_num_neg]) # (num_pos, num_neg) 225 | phi_neg_expanded = tf.tile(tf.transpose(phi_neg), [q_num_pos, 226 | 1]) # (num_pos, num_neg) 227 | temp1_Y = tf.greater(phi_pos_expanded, 228 | phi_neg_expanded) # (num_pos, num_neg) of True/False's 229 | temp2_Y = 2. 
* tf.to_float(temp1_Y) # (num_pos, num_neg) of 2/0's 230 | Y_std = temp2_Y - tf.ones_like(temp2_Y) # (num_pos, num_neg) of 1/-1's 231 | F_std = Y_std * (phi_pos_expanded - phi_neg_expanded) # (num_pos, num_neg) 232 | AP_score_std = tf.truediv( 233 | tf.reduce_sum(F_std), tf.to_float(q_num_pos * q_num_neg)) 234 | 235 | # Score of loss-augmented inferred ranking 236 | F_aug = q_Y_aug * (phi_pos_expanded - phi_neg_expanded) 237 | AP_score_aug = tf.truediv( 238 | tf.reduce_sum(F_aug), tf.to_float(q_num_pos * q_num_neg)) 239 | 240 | # Score of the groundtruth 241 | q_Y_GT = tf.ones_like(Y_std) 242 | F_GT = q_Y_GT * (phi_pos_expanded - phi_neg_expanded) 243 | AP_score_GT = tf.truediv( 244 | tf.reduce_sum(F_GT), tf.to_float(q_num_pos * q_num_neg)) 245 | 246 | AP_score_std = tf.reshape(AP_score_std, [1, 1]) 247 | AP_score_aug = tf.reshape(AP_score_aug, [1, 1]) 248 | AP_score_GT = tf.reshape(AP_score_GT, [1, 1]) 249 | return phi_pos, phi_neg, AP_score_std, AP_score_aug, AP_score_GT 250 | 251 | def perform_inference_mAP(self): 252 | """ 253 | Perform inference for the task loss of mAP. 254 | This involves looping over the different queries 255 | in the batch to compute the AP of each and then 256 | composing these scores. 
257 | 258 | """ 259 | 260 | def body(i, skipped_q_prev, score_std_prev, score_aug_prev, score_GT_prev, 261 | phi_pos_prev, phi_neg_prev): 262 | """ 263 | 264 | :param i: The index of the query currently being considered 265 | :param skipped_q_prev: Binary array indicating whether queries were skipped, up till the ith 266 | :param score_std_prev: Score of standard inference solution for queries up till the ith 267 | :param score_aug_prev: Score of loss-augmented inference solution for queries up till the ith 268 | :param score_GT_prev: Score of ground truth solution for queries up till the ith 269 | :param phi_pos_prev: Cosine similarities between each of the first (i-1) queries and their positives 270 | :param phi_neg_prev: Cosine similarities between each of the first (i-1) queries and their negatives 271 | 272 | :return: The same quantities but after having incorporated the ith query as well. 273 | 274 | """ 275 | 276 | query_feats = tf.reshape(tf.gather(self.feats, i), [1, -1]) 277 | pos_feats_inds = tf.gather(self.pos_inds, i) 278 | pos_feats = tf.gather(self.feats, pos_feats_inds) 279 | neg_feats_inds = tf.gather(self.neg_inds, i) 280 | neg_feats = tf.gather(self.feats, neg_feats_inds) 281 | q_num_pos = tf.gather(self.num_pos, i) 282 | q_num_neg = tf.gather(self.num_neg, i) 283 | _q_Y_aug = tf.gather(self.Y_aug, i) 284 | 285 | max_pos = tf.reduce_max(self.num_pos) 286 | max_neg = tf.reduce_max(self.num_neg) 287 | 288 | q_Y_aug = tf.slice(_q_Y_aug, [0, 0], [q_num_pos, q_num_neg]) 289 | 290 | # Case where the ith point in the batch forms a non-empty positive set 291 | def _use_query(): 292 | q_phi_pos, q_phi_neg, q_score_std, q_score_aug, q_score_GT = self.perform_query_inference( 293 | query_feats, pos_feats, neg_feats, q_num_pos, q_num_neg, q_Y_aug) 294 | 295 | # In what follows, we update score_std_next, score_aug_next 296 | # Note that due to the requirement of quantities score_std_next, score_aug_next, etc to have 297 | # the same shape in each iteration, 
we zero-pad them to be batch_size-long. 298 | def _first_score(): 299 | this_score_std_padded = tf.concat( 300 | [q_score_std, tf.zeros([self.batch_size - i - 1, 1])], 0) 301 | this_score_aug_padded = tf.concat( 302 | [q_score_aug, tf.zeros([self.batch_size - i - 1, 1])], 0) 303 | this_score_GT_padded = tf.concat( 304 | [q_score_GT, tf.zeros([self.batch_size - i - 1, 1])], 0) 305 | score_std_next = tf.add(score_std_prev, this_score_std_padded) 306 | score_aug_next = tf.add(score_aug_prev, this_score_aug_padded) 307 | score_GT_next = tf.add(score_GT_prev, this_score_GT_padded) 308 | return score_std_next, score_aug_next, score_GT_next 309 | 310 | def _else_score(): 311 | temp = tf.concat([tf.zeros([i, 1]), q_score_std], 0) 312 | this_score_std_padded = tf.concat( 313 | [temp, tf.zeros([self.batch_size - i - 1, 1])], 0) 314 | temp = tf.concat([tf.zeros([i, 1]), q_score_aug], 0) 315 | this_score_aug_padded = tf.concat( 316 | [temp, tf.zeros([self.batch_size - i - 1, 1])], 0) 317 | temp = tf.concat([tf.zeros([i, 1]), q_score_GT], 0) 318 | this_score_GT_padded = tf.concat( 319 | [temp, tf.zeros([self.batch_size - i - 1, 1])], 0) 320 | score_std_next = tf.add(score_std_prev, this_score_std_padded) 321 | score_aug_next = tf.add(score_aug_prev, this_score_aug_padded) 322 | score_GT_next = tf.add(score_GT_prev, this_score_GT_padded) 323 | return score_std_next, score_aug_next, score_GT_next 324 | 325 | score_std_next, score_aug_next, score_GT_next = tf.cond( 326 | tf.equal(i, tf.constant(0)), _first_score, _else_score) 327 | 328 | q_phi_pos = tf.transpose(q_phi_pos) 329 | q_phi_neg = tf.transpose(q_phi_neg) 330 | 331 | # Concatenate with appropriate amount of zeros, to pad along each row 332 | # This padding is different from above: we're not padding along the dimension 333 | # of queries in order to get to batch_size, but we're padding the phi_pos of 334 | # just the single query to get to max_pos 335 | # (since not all queries will have the same number of positive points) 
336 | padding_phi_pos = tf.zeros([1, max_pos - q_num_pos]) 337 | padding_phi_neg = tf.zeros([1, max_neg - q_num_neg]) 338 | this_phi_pos = tf.concat([q_phi_pos, padding_phi_pos], 1) 339 | this_phi_neg = tf.concat([q_phi_neg, padding_phi_neg], 1) 340 | 341 | # Update phi_pos_next, phi_neg_next 342 | def _first_phi(): 343 | this_phi_pos_padded = tf.concat( 344 | [this_phi_pos, tf.zeros([self.batch_size - i - 1, max_pos])], 0) 345 | this_phi_neg_padded = tf.concat( 346 | [this_phi_neg, tf.zeros([self.batch_size - i - 1, max_neg])], 0) 347 | phi_pos_next = tf.add(phi_pos_prev, this_phi_pos_padded) 348 | phi_neg_next = tf.add(phi_neg_prev, this_phi_neg_padded) 349 | return phi_pos_next, phi_neg_next 350 | 351 | def _else_phi(): 352 | temp = tf.concat([tf.zeros([i, max_pos]), this_phi_pos], 0) 353 | this_phi_pos_padded = tf.concat( 354 | [temp, tf.zeros([self.batch_size - i - 1, max_pos])], 0) 355 | temp = tf.concat([tf.zeros([i, max_neg]), this_phi_neg], 0) 356 | this_phi_neg_padded = tf.concat( 357 | [temp, tf.zeros([self.batch_size - i - 1, max_neg])], 0) 358 | phi_pos_next = tf.add(phi_pos_prev, this_phi_pos_padded) 359 | phi_neg_next = tf.add(phi_neg_prev, this_phi_neg_padded) 360 | return phi_pos_next, phi_neg_next 361 | 362 | phi_pos_next, phi_neg_next = tf.cond( 363 | tf.equal(i, tf.constant(0)), _first_phi, _else_phi) 364 | 365 | # The appropriate entry is already 0 which we keep since we didn't skip this query 366 | skipped_q_next = skipped_q_prev 367 | return score_std_next, score_aug_next, score_GT_next, phi_pos_next, phi_neg_next, skipped_q_next 368 | 369 | def _dont_use_query(): 370 | 371 | q_score_std = tf.zeros([1, 1]) 372 | q_score_aug = tf.zeros([1, 1]) 373 | q_score_GT = tf.zeros([1, 1]) 374 | 375 | # Update score_std_next, score_aug_next 376 | def _first_score(): 377 | this_score_std_padded = tf.concat( 378 | [q_score_std, tf.zeros([self.batch_size - i - 1, 1])], 0) 379 | this_score_aug_padded = tf.concat( 380 | [q_score_aug, tf.zeros([self.batch_size 
- i - 1, 1])], 0) 381 | this_score_GT_padded = tf.concat( 382 | [q_score_GT, tf.zeros([self.batch_size - i - 1, 1])], 0) 383 | score_std_next = tf.add(score_std_prev, this_score_std_padded) 384 | score_aug_next = tf.add(score_aug_prev, this_score_aug_padded) 385 | score_GT_next = tf.add(score_GT_prev, this_score_GT_padded) 386 | return score_std_next, score_aug_next, score_GT_next 387 | 388 | def _else_score(): 389 | temp = tf.concat([tf.zeros([i, 1]), q_score_std], 0) 390 | this_score_std_padded = tf.concat( 391 | [temp, tf.zeros([self.batch_size - i - 1, 1])], 0) 392 | temp = tf.concat([tf.zeros([i, 1]), q_score_aug], 0) 393 | this_score_aug_padded = tf.concat( 394 | [temp, tf.zeros([self.batch_size - i - 1, 1])], 0) 395 | temp = tf.concat([tf.zeros([i, 1]), q_score_GT], 0) 396 | this_score_GT_padded = tf.concat( 397 | [temp, tf.zeros([self.batch_size - i - 1, 1])], 0) 398 | score_std_next = tf.add(score_std_prev, this_score_std_padded) 399 | score_aug_next = tf.add(score_aug_prev, this_score_aug_padded) 400 | score_GT_next = tf.add(score_GT_prev, this_score_GT_padded) 401 | return score_std_next, score_aug_next, score_GT_next 402 | 403 | score_std_next, score_aug_next, score_GT_next = tf.cond( 404 | tf.equal(i, tf.constant(0)), _first_score, _else_score) 405 | 406 | q_phi_pos = tf.zeros([1, max_pos]) 407 | q_phi_neg = tf.zeros([1, max_neg]) 408 | 409 | # Update phi_pos_next, phi_neg_next 410 | def _first_phi(): 411 | this_phi_pos_padded = tf.concat( 412 | [q_phi_pos, tf.zeros([self.batch_size - i - 1, max_pos])], 0) 413 | this_phi_neg_padded = tf.concat( 414 | [q_phi_neg, tf.zeros([self.batch_size - i - 1, max_neg])], 0) 415 | phi_pos_next = tf.add(phi_pos_prev, this_phi_pos_padded) 416 | phi_neg_next = tf.add(phi_neg_prev, this_phi_neg_padded) 417 | return phi_pos_next, phi_neg_next 418 | 419 | def _else_phi(): 420 | temp = tf.concat([tf.zeros([i, max_pos]), q_phi_pos], 0) 421 | this_phi_pos_padded = tf.concat( 422 | [temp, tf.zeros([self.batch_size - i - 1, 
max_pos])], 0) 423 | temp = tf.concat([tf.zeros([i, max_neg]), q_phi_neg], 0) 424 | this_phi_neg_padded = tf.concat( 425 | [temp, tf.zeros([self.batch_size - i - 1, max_neg])], 0) 426 | phi_pos_next = tf.add(phi_pos_prev, this_phi_pos_padded) 427 | phi_neg_next = tf.add(phi_neg_prev, this_phi_neg_padded) 428 | return phi_pos_next, phi_neg_next 429 | 430 | phi_pos_next, phi_neg_next = tf.cond( 431 | tf.equal(i, tf.constant(0)), _first_phi, _else_phi) 432 | 433 | # Update skipped_q_next 434 | # make ith position a 1 in the one-hot-encoded vector 435 | def _first_skip(): 436 | skipped_this = tf.concat( 437 | [tf.ones([1]), tf.zeros([self.batch_size - i - 1])], 0) 438 | skipped_q_next = tf.add(skipped_q_prev, skipped_this) 439 | return skipped_q_next 440 | 441 | def _else_skip(): 442 | temp = tf.concat([tf.zeros([i]), tf.ones([1])], 0) 443 | skipped_this = tf.concat([temp, tf.zeros([self.batch_size - i - 1])], 0) 444 | skipped_q_next = tf.add(skipped_q_prev, skipped_this) 445 | return skipped_q_next 446 | 447 | skipped_q_next = tf.cond( 448 | tf.equal(i, tf.constant(0)), _first_skip, _else_skip) 449 | return score_std_next, score_aug_next, score_GT_next, phi_pos_next, phi_neg_next, skipped_q_next 450 | 451 | _use_query_cond = tf.greater(q_num_pos, 0) 452 | score_std_next, score_aug_next, score_GT_next, phi_pos_next, phi_neg_next, skipped_q_next = tf.cond( 453 | _use_query_cond, _use_query, _dont_use_query) 454 | i = tf.add(i, 1) 455 | return i, skipped_q_next, score_std_next, score_aug_next, score_GT_next, phi_pos_next, phi_neg_next 456 | 457 | i = tf.constant(0) 458 | 459 | def condition(i, _1, _2, _3, _4, _5, _6): 460 | return tf.less(i, self.n_queries_to_parse) 461 | 462 | # Initialize the loop variables - their size will remain unchanged throughout the loop. 
463 | phi_pos = tf.zeros([self.batch_size, tf.reduce_max(self.num_pos)]) 464 | phi_neg = tf.zeros([self.batch_size, tf.reduce_max(self.num_neg)]) 465 | score_std = tf.zeros([self.batch_size, 1]) 466 | score_aug = tf.zeros([self.batch_size, 1]) 467 | score_GT = tf.zeros([self.batch_size, 1]) 468 | # One-hot-encoded vector indicating which queries were skipped 469 | # (a query is skipped if it can't create a positive set) 470 | skipped_queries = tf.zeros([self.batch_size]) 471 | 472 | _i, skipped_queries, score_std, score_aug, score_GT, phi_pos, phi_neg = tf.while_loop( 473 | condition, 474 | body, 475 | loop_vars=[ 476 | i, skipped_queries, score_std, score_aug, score_GT, phi_pos, phi_neg 477 | ]) 478 | mAP_score_std = tf.reduce_mean(score_std) 479 | mAP_score_aug = tf.reduce_mean(score_aug) 480 | mAP_score_GT = tf.reduce_mean(score_GT) 481 | return phi_pos, phi_neg, mAP_score_std, mAP_score_aug, mAP_score_GT, skipped_queries 482 | 483 | @property 484 | def train_step(self): 485 | return self._train_step 486 | 487 | @property 488 | def x(self): 489 | return self._x 490 | 491 | @property 492 | def config(self): 493 | return self._config 494 | 495 | @property 496 | def feats(self): 497 | return self._feats 498 | 499 | @property 500 | def lr(self): 501 | return self._lr 502 | 503 | @property 504 | def new_lr(self): 505 | return self._new_lr 506 | 507 | @property 508 | def assign_lr(self): 509 | return self._assign_lr 510 | 511 | @property 512 | def loss(self): 513 | return self._loss 514 | 515 | @property 516 | def batch_size(self): 517 | return self._batch_size 518 | 519 | @property 520 | def num_pos(self): 521 | return self._num_pos 522 | 523 | @property 524 | def num_neg(self): 525 | return self._num_neg 526 | 527 | @property 528 | def pos_inds(self): 529 | return self._pos_inds 530 | 531 | @property 532 | def neg_inds(self): 533 | return self._neg_inds 534 | 535 | @property 536 | def Y_aug(self): 537 | return self._Y_aug 538 | 539 | @property 540 | def 
phi_pos(self): 541 | return self._phi_pos 542 | 543 | @property 544 | def phi_neg(self): 545 | return self._phi_neg 546 | 547 | @property 548 | def mAP_score_std(self): 549 | return self._mAP_score_std 550 | 551 | @property 552 | def mAP_score_aug(self): 553 | return self._mAP_score_aug 554 | 555 | @property 556 | def mAP_score_GT(self): 557 | return self._mAP_score_GT 558 | 559 | @property 560 | def skipped_queries(self): 561 | return self._skipped_queries 562 | 563 | @property 564 | def n_queries_to_parse(self): 565 | return self._n_queries_to_parse 566 | --------------------------------------------------------------------------------