├── fewshot ├── __init__.py ├── configs │ ├── __init__.py │ ├── config_factory.py │ ├── mini_imagenet_config.py │ ├── tiered_imagenet_config.py │ └── omniglot_config.py ├── data │ ├── __init__.py │ ├── tiered_imagenet_split │ │ ├── val.csv │ │ ├── test.csv │ │ ├── old_test.csv │ │ ├── train.csv │ │ ├── old_train.csv │ │ └── trainval.csv │ ├── compress_tiered_imagenet.py │ ├── data_factory.py │ ├── episode.py │ ├── omniglot_split │ │ ├── val.txt │ │ └── test.txt │ ├── batch_iter.py │ ├── concurrent_batch_iter.py │ ├── omniglot.py │ └── refinement_dataset.py ├── utils │ ├── __init__.py │ ├── debug.py │ ├── experiment_logger.py │ ├── lr_schedule.py │ ├── batch_iter.py │ └── logger.py └── models │ ├── __init__.py │ ├── model_factory.py │ ├── measure_tests.py │ ├── distractor_utils.py │ ├── kmeans_refine_model.py │ ├── refine_model.py │ ├── measure.py │ ├── kmeans_refine_radius_model.py │ ├── basic_model.py │ ├── kmeans_utils.py │ ├── kmeans_refine_mask_model.py │ └── model.py ├── .style.yapf ├── LICENSE ├── .gitignore ├── run_exp_tests.py ├── README.md └── run_multi_exp.py /fewshot/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fewshot/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fewshot/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fewshot/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fewshot/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

CONFIG_REGISTRY = {}


def _config_key(dataset_name, model_name):
  """Builds the registry key shared by `RegisterConfig` and `get_config`."""
  return "{}_{}".format(dataset_name, model_name)


def RegisterConfig(dataset_name, model_name):
  """Returns a decorator that records a config factory in CONFIG_REGISTRY.

  The decorated callable is stored under "<dataset_name>_<model_name>" and is
  returned unchanged, so it remains usable under its own name.

  Args:
    dataset_name: Dataset part of the registry key.
    model_name: Model part of the registry key.
  """

  def decorator(f):
    CONFIG_REGISTRY[_config_key(dataset_name, model_name)] = f
    return f

  return decorator


def get_config(dataset_name, model_name):
  """Instantiates the config registered for a dataset/model pair.

  Args:
    dataset_name: Dataset part of the registry key.
    model_name: Model part of the registry key.

  Returns:
    Whatever the registered factory returns when called with no arguments.

  Raises:
    ValueError: If nothing was registered for the pair.
  """
  key = _config_key(dataset_name, model_name)
  if key not in CONFIG_REGISTRY:
    raise ValueError("No registered config: \"{}\"".format(key))
  return CONFIG_REGISTRY[key]()
5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /fewshot/utils/debug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 
3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
# =============================================================================
import tensorflow as tf
import os


def debug_identity(x, name=None):
  """Optionally wraps `x` in a `tf.Print` that logs summary statistics.

  Printing is enabled only when the TF_DEBUG environment variable parses as
  the integer 1; otherwise `x` is returned untouched. The value goes through
  `int()`, so a non-integer TF_DEBUG raises ValueError.

  Args:
    x: Tensor to pass through.
    name: Optional label printed instead of `x.name`.

  Returns:
    `x` itself, or a `tf.Print` op that forwards `x` and prints the label
    followed by the mean, max and min of `x`.
  """
  if int(os.environ.get('TF_DEBUG', 0)) != 1:
    return x
  label = x.name if name is None else name
  summary = [label, tf.reduce_mean(x), tf.reduce_max(x), tf.reduce_min(x)]
  return tf.Print(x, summary)
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /run_exp_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# =============================================================================
"""Runs a set of small training programs to verify that the code runs.
Author: Mengye Ren (mren@cs.toronto.edu)
"""
import tensorflow as tf

from run_multi_exp import run_one


def main():
  """Runs every model variant once per dataset, each in its own graph.

  Each model name gets a "-test" suffix. The remaining positional arguments
  (0, None, 'abc') are passed through unchanged; their meaning is defined by
  `run_multi_exp.run_one`.
  """
  datasets = ('omniglot',)
  models = ('basic', 'kmeans-refine', 'kmeans-refine-radius',
            'kmeans-refine-mask')
  for dataset in datasets:
    for model in models:
      # A new default graph per run keeps the ops of different runs apart.
      with tf.Graph().as_default():
        run_one(dataset, '{}-test'.format(model), 0, None, 'abc')


if __name__ == '__main__':
  main()
# =============================================================================
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

from fewshot.utils import logger

log = logger.get()

MODEL_REGISTRY = {}


def RegisterModel(model_name):
  """Returns a decorator that records a model class in MODEL_REGISTRY.

  The decorated class (or factory) is stored under `model_name` and returned
  unchanged.

  Args:
    model_name: Registry key later looked up by `get_model`.
  """

  def decorator(f):
    MODEL_REGISTRY[model_name] = f
    return f

  return decorator


def get_model(model_name, *args, **kwargs):
  """Logs the requested model name, then builds the registered model.

  Args:
    model_name: Key previously passed to `RegisterModel`.
    *args: Forwarded positionally to the registered class.
    **kwargs: Forwarded as keyword arguments to the registered class.

  Returns:
    The object created by the registered class.

  Raises:
    ValueError: If `model_name` was never registered.
  """
  log.info("Model {}".format(model_name))
  if model_name not in MODEL_REGISTRY:
    raise ValueError("Model class does not exist {}".format(model_name))
  model_cls = MODEL_REGISTRY[model_name]
  return model_cls(*args, **kwargs)
# IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# =============================================================================
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import numpy as np
import unittest

from fewshot.models.measure import batch_apk, apk


def fake_batch_apk(logits, pos_mask, k):
  """Reference implementation of `batch_apk`: calls `apk` row by row.

  Args:
    logits: Array whose first axis indexes independent rows.
    pos_mask: Array indexed by the same rows as `logits`.
    k: Per-row cutoffs, one entry per row of `logits`.

  Returns:
    np.ndarray holding one `apk` result per row.
  """
  return np.array(
      [apk(logits[ii], pos_mask[ii], k[ii]) for ii in range(logits.shape[0])])


class MeasureTests(unittest.TestCase):

  def test_batch_apk(self):
    """batch_apk must agree with the per-row reference on random inputs."""
    rnd = np.random.RandomState(0)
    num_rows, num_items = 10, 12
    for _ in range(100):
      logits = rnd.uniform(0.0, 1.0, [num_rows, num_items])
      pos_mask = (rnd.uniform(0.0, 1.0, [num_rows, num_items]) >
                  0.5).astype(np.float32)
      k = rnd.uniform(5.0, 10.0, [num_rows]).astype(np.int32)
      np.testing.assert_allclose(
          batch_apk(logits, pos_mask, k), fake_batch_apk(logits, pos_mask, k))


if __name__ == "__main__":
  unittest.main()
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# =============================================================================
from __future__ import print_function, division

import tensorflow as tf


def _safe_div(numerator, denominator):
  """Divides, substituting 1 for an exactly-zero (float) denominator.

  This keeps the ratio finite: with a zero denominator the result equals the
  numerator, which is 0 for the 0/1 masks used below.
  """
  return numerator / (denominator + tf.to_float(tf.equal(denominator, 0.0)))


def eval_distractor(pred_non_distractor, gt_non_distractor):
  """Evaluates distractor prediction.

  The inputs are compared element-wise. They are assumed (not checked) to be
  float tensors of the same shape, holding 1.0 for "non-distractor" and 0.0
  for "distractor".

  Args:
    pred_non_distractor: Predicted non-distractor indicator.
    gt_non_distractor: Ground-truth non-distractor indicator.

  Returns:
    acc: Scalar. Fraction of elements where prediction equals ground truth.
    recall: Scalar. True-positive distractors over ground-truth distractors;
      0 (for 0/1 inputs) when there are no ground-truth distractors.
    precision: Scalar. True-positive distractors over predicted distractors;
      0 (for 0/1 inputs) when nothing is predicted as a distractor.
  """
  y = gt_non_distractor
  pred_distractor = 1.0 - pred_non_distractor
  gt_distractor = 1 - y
  non_distractor_correct = tf.to_float(tf.equal(pred_non_distractor, y))
  distractor_tp = tf.reduce_sum(pred_distractor * gt_distractor)
  # Both ratios are guarded. Previously only precision was, so an episode
  # without ground-truth distractors produced a NaN recall.
  distractor_recall = _safe_div(distractor_tp, tf.reduce_sum(gt_distractor))
  distractor_precision = _safe_div(distractor_tp,
                                   tf.reduce_sum(pred_distractor))
  acc = tf.reduce_mean(non_distractor_correct)
  # reduce_sum without an axis already gives scalars; reduce_mean is a no-op
  # kept for parity with the original graph.
  recall = tf.reduce_mean(distractor_recall)
  precision = tf.reduce_mean(distractor_precision)

  return acc, recall, precision
n03388549,n03405265 43 | n03742115,n03405265 44 | n03891251,n03405265 45 | n03998194,n03405265 46 | n04099969,n03405265 47 | n04344873,n03405265 48 | n04380533,n03405265 49 | n04429376,n03405265 50 | n04447861,n03405265 51 | n04550184,n03405265 52 | n02666196,n03699975 53 | n02977058,n03699975 54 | n03180011,n03699975 55 | n03485407,n03699975 56 | n03496892,n03699975 57 | n03642806,n03699975 58 | n03832673,n03699975 59 | n04238763,n03699975 60 | n04243546,n03699975 61 | n04428191,n03699975 62 | n04525305,n03699975 63 | n06359193,n03699975 64 | n02966193,n03738472 65 | n02974003,n03738472 66 | n03425413,n03738472 67 | n03532672,n03738472 68 | n03874293,n03738472 69 | n03944341,n03738472 70 | n03992509,n03738472 71 | n04019541,n03738472 72 | n04040759,n03738472 73 | n04067472,n03738472 74 | n04371774,n03738472 75 | n04372370,n03738472 76 | n02701002,n03791235 77 | n02704792,n03791235 78 | n02814533,n03791235 79 | n02930766,n03791235 80 | n03100240,n03791235 81 | n03345487,n03791235 82 | n03417042,n03791235 83 | n03444034,n03791235 84 | n03445924,n03791235 85 | n03594945,n03791235 86 | n03670208,n03791235 87 | n03770679,n03791235 88 | n03777568,n03791235 89 | n03785016,n03791235 90 | n03796401,n03791235 91 | n03930630,n03791235 92 | n03977966,n03791235 93 | n04037443,n03791235 94 | n04252225,n03791235 95 | n04285008,n03791235 96 | n04461696,n03791235 97 | n04467665,n03791235 98 | -------------------------------------------------------------------------------- /fewshot/data/compress_tiered_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 
3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
# =============================================================================
import cv2
import numpy as np
import six
import sys
import pickle as pkl

from tqdm import tqdm

# Output shape used when there is nothing to decode; it is also the shape the
# tool previously hard-coded for tiered-ImageNet images.
_DEFAULT_IMAGE_SHAPE = (84, 84, 3)

_USAGE = ('Usage: {0} compress <images.npz> <output.pkl>\n'
          '       {0} decompress <images.npz> <input.pkl>')


def compress(path, output):
  """PNG-encodes every image of an .npz archive and pickles the list.

  Args:
    path: .npz archive containing an "images" array indexed along axis 0.
    output: Destination pickle; receives a list of encoded buffers as
      returned by `cv2.imencode`.

  Raises:
    ValueError: If OpenCV fails to encode an image.
  """
  # `with` requires an .npz archive (np.load returns an NpzFile). The array
  # read from it is used after the archive is closed, as before.
  with np.load(path, mmap_mode="r") as data:
    images = data["images"]
  array = []
  for ii in tqdm(six.moves.xrange(images.shape[0]), desc='compress'):
    ok, im_str = cv2.imencode('.png', images[ii])
    if not ok:
      raise ValueError('Failed to PNG-encode image {}'.format(ii))
    array.append(im_str)
  with open(output, 'wb') as f:
    pkl.dump(array, f, protocol=pkl.HIGHEST_PROTOCOL)


def decompress(path, output):
  """Inverse of `compress`: decodes a pickled image list into an .npz archive.

  Note the argument order. Images are read from `output` (the pickle) and
  written to `path` (the .npz), matching `main`'s argv order.

  Args:
    path: Destination passed to `np.savez`; receives a uint8 "images" array
      shaped [num_images] + shape of the first decoded image.
    output: Pickle file produced by `compress`.

  Raises:
    ValueError: If an item cannot be decoded, or its shape differs from the
      first decoded image.
  """
  # SECURITY: pickle.load can execute arbitrary code; use trusted files only.
  with open(output, 'rb') as f:
    array = pkl.load(f)
  images = None
  for ii, item in tqdm(enumerate(array), total=len(array), desc='decompress'):
    im = cv2.imdecode(item, 1)
    if im is None:
      raise ValueError('Failed to decode image {}'.format(ii))
    if images is None:
      # Size the output from the data rather than assuming 84x84x3.
      images = np.zeros((len(array),) + im.shape, dtype=np.uint8)
    elif im.shape != images.shape[1:]:
      raise ValueError('Image {} has shape {}, expected {}'.format(
          ii, im.shape, images.shape[1:]))
    images[ii] = im
  if images is None:
    images = np.zeros((0,) + _DEFAULT_IMAGE_SHAPE, dtype=np.uint8)
  np.savez(path, images=images)


def main():
  """Command-line entry point: `compress|decompress <npz> <pkl>`.

  Prints usage to stderr and exits with status 1 on an unknown command or
  missing arguments, instead of silently doing nothing or raising IndexError.
  """
  commands = {'compress': compress, 'decompress': decompress}
  if len(sys.argv) < 4 or sys.argv[1] not in commands:
    sys.stderr.write(_USAGE.format(sys.argv[0]) + '\n')
    sys.exit(1)
  commands[sys.argv[1]](sys.argv[2], sys.argv[3])


if __name__ == '__main__':
  main()
3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
# =============================================================================
import os
import tensorflow as tf

from fewshot.data.concurrent_batch_iter import ConcurrentBatchIterator

flags = tf.flags
flags.DEFINE_string("data_root", "data", "Data root")
FLAGS = tf.flags.FLAGS

DATASET_REGISTRY = {}


def RegisterDataset(dataset_name):
  """Returns a decorator that records a dataset class in DATASET_REGISTRY.

  Args:
    dataset_name: Registry key later looked up by `get_dataset`.
  """

  def decorator(f):
    DATASET_REGISTRY[dataset_name] = f
    return f

  return decorator


def get_data_folder(dataset_name):
  """Returns the path `<--data_root>/<dataset_name>`."""
  return os.path.join(FLAGS.data_root, dataset_name)


def get_dataset(dataset_name, split, *args, **kwargs):
  """Builds a registered dataset rooted at its data folder.

  The registered class is called as `cls(folder, split, *args, **kwargs)`,
  where `folder` comes from `get_data_folder`.

  Raises:
    ValueError: If `dataset_name` was never registered.
  """
  if dataset_name not in DATASET_REGISTRY:
    raise ValueError("Unknown dataset \"{}\"".format(dataset_name))
  dataset_cls = DATASET_REGISTRY[dataset_name]
  folder = get_data_folder(dataset_name)
  return dataset_cls(folder, split, *args, **kwargs)


def get_concurrent_iterator(dataset, max_queue_size=100, num_threads=10):
  """Wraps `dataset` in a ConcurrentBatchIterator.

  `log_queue=-1` is passed as before; presumably it disables queue-size
  logging (see ConcurrentBatchIterator to confirm).
  """
  iterator_kwargs = dict(
      max_queue_size=max_queue_size, num_threads=num_threads, log_queue=-1)
  return ConcurrentBatchIterator(dataset, **iterator_kwargs)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# =============================================================================
class Episode(object):
  """Read-only container for the data of one few-shot episode.

  Holds a labeled training split, a labeled testing split and, optionally, an
  unlabeled split plus string versions of the train/test labels. The values
  are stored as given (no copying or validation).
  """

  def __init__(self,
               x_train,
               y_train,
               x_test,
               y_test,
               x_unlabel=None,
               y_unlabel=None,
               y_train_str=None,
               y_test_str=None):
    """Stores the episode's data.

    Args:
      x_train: [N, ...]. Training data.
      y_train: [N]. Training label.
      x_test: [N, ...]. Testing data.
      y_test: [N]. Testing label.
      x_unlabel: Optional unlabeled data.
      y_unlabel: Optional labels for the unlabeled data (presumably kept only
        for evaluation; confirm with callers).
      y_train_str: Optional string form of the training labels.
      y_test_str: Optional string form of the testing labels.
    """
    self._x_train, self._y_train = x_train, y_train
    self._x_test, self._y_test = x_test, y_test
    self._x_unlabel, self._y_unlabel = x_unlabel, y_unlabel
    self._y_train_str, self._y_test_str = y_train_str, y_test_str

  def next_batch(self):
    """Returns this episode itself.

    Presumably lets an Episode stand in wherever an object with `next_batch()`
    is expected.
    """
    return self

  @property
  def x_train(self):
    """Training data."""
    return self._x_train

  @property
  def y_train(self):
    """Training labels."""
    return self._y_train

  @property
  def x_test(self):
    """Testing data."""
    return self._x_test

  @property
  def y_test(self):
    """Testing labels."""
    return self._y_test

  @property
  def x_unlabel(self):
    """Unlabeled data, or None."""
    return self._x_unlabel

  @property
  def y_unlabel(self):
    """Labels of the unlabeled data, or None."""
    return self._y_unlabel

  @property
  def y_train_str(self):
    """String training labels, or None."""
    return self._y_train_str

  @property
  def y_test_str(self):
    """String testing labels, or None."""
    return self._y_test_str
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# =============================================================================
"""
A prototypical network with K-means to refine unlabeled examples.

Author: Mengye Ren (mren@cs.toronto.edu)

In a single episode, the model computes the mean representation of the
positive reference images as prototypes and then refines the representation
by running a few steps of soft K-means iterations.
"""
import tensorflow as tf

from fewshot.models.kmeans_utils import assign_cluster, update_cluster, compute_logits
from fewshot.models.model_factory import RegisterModel
from fewshot.models.nnlib import concat
from fewshot.models.refine_model import RefineModel
from fewshot.utils import logger

log = logger.get()


@RegisterModel("kmeans-refine")
class KMeansRefineModel(RefineModel):
  """Prototypes refined by K-means steps over labeled + unlabeled embeddings."""

  def predict(self):
    """See `model.py` for documentation.

    Returns:
      List of `num_cluster_steps + 1` logits tensors for the test inputs. The
      first uses the initial prototypes; each later entry uses the prototypes
      after one more clustering step.
    """
    nclasses = self.nway
    num_cluster_steps = self.config.num_cluster_steps
    h_train, h_unlabel, h_test = self.get_encoded_inputs(
        self.x_train, self.x_unlabel, self.x_test)
    y_train = self.y_train
    protos = self._compute_protos(nclasses, h_train, y_train)

    # Hard (one-hot) assignment for training images: one [B, N, 1] column per
    # class, concatenated along axis 2.
    prob_train = concat([
        tf.expand_dims(tf.cast(tf.equal(y_train, kk), h_train.dtype), 2)
        for kk in range(nclasses)
    ], 2)

    h_all = concat([h_train, h_unlabel], 1)

    # Logits from the initial prototypes (previously computed twice, with one
    # copy unused).
    logits_list = [compute_logits(protos, h_test)]

    # Run clustering.
    for _ in range(num_cluster_steps):
      # Label assignment for the unlabeled examples.
      prob_unlabel = assign_cluster(protos, h_unlabel)
      prob_all = concat([prob_train, prob_unlabel], 1)
      # Assignments are treated as constants when back-propagating.
      prob_all = tf.stop_gradient(prob_all)
      # NOTE(review): update_cluster runs even when x_unlabel holds no
      # examples (a tf.cond guard on that case was once considered); confirm
      # update_cluster handles it.
      protos = update_cluster(h_all, prob_all)
      logits_list.append(compute_logits(protos, h_test))

    return logits_list
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# =============================================================================
"""
A few-shot classification model implementation that refines on unlabeled
refinement images.

Author: Mengye Ren (mren@cs.toronto.edu)

A single episode is divided into three parts:
  1) Labeled reference images (self.x_train).
  2) Unlabeled refinement images (self.x_unlabel).
  3) Labeled query images (from validation) (self.x_test).
"""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import numpy as np
import tensorflow as tf

from fewshot.models.nnlib import cnn, weight_variable, concat
from fewshot.models.basic_model import BasicModel
from fewshot.utils import logger
log = logger.get()

# Load up the LSTM cell implementation.
# NOTE(review): these aliases (like several imports above) are not used in
# this module; presumably kept so other modules can import them from here.
if tf.__version__.startswith("0"):
  BasicLSTMCell = tf.nn.rnn_cell.BasicLSTMCell
  LSTMStateTuple = tf.nn.rnn_cell.LSTMStateTuple
else:
  BasicLSTMCell = tf.contrib.rnn.BasicLSTMCell
  LSTMStateTuple = tf.contrib.rnn.LSTMStateTuple


class RefineModel(BasicModel):
  """A retrieval model with an additional refinement stage.

  Extends `BasicModel` with placeholders for the unlabeled refinement images
  (`x_unlabel`) and their labels (`y_unlabel`).
  """

  def __init__(self,
               config,
               nway=1,
               nshot=1,
               num_unlabel=10,
               candidate_size=10,
               is_training=True,
               dtype=tf.float32):
    """Initializer.

    Args:
      config: Model configuration object. Must provide `height`, `width` and
        `num_channel` (used for the unlabeled-image placeholder).
      nway: Int. Number of classes in the reference images.
      nshot: Int. Number of labeled reference images.
      num_unlabel: Int. Number of unlabeled refinement images.
      candidate_size: Int. Number of candidates in the query stage; passed to
        `BasicModel` as `num_test`.
      is_training: Bool. Whether is in training mode.
      dtype: TensorFlow data type.
    """
    self._num_unlabel = num_unlabel
    # The placeholders are created before BasicModel.__init__ runs --
    # presumably because the base initializer builds the graph and needs
    # them; TODO confirm against basic_model.py.
    self._x_unlabel = tf.placeholder(
        dtype, [None, None, config.height, config.width, config.num_channel],
        name="x_unlabel")
    self._y_unlabel = tf.placeholder(dtype, [None, None], name="y_unlabel")
    super(RefineModel, self).__init__(
        config,
        nway=nway,
        nshot=nshot,
        num_test=candidate_size,
        is_training=is_training,
        dtype=dtype)

  @property
  def x_unlabel(self):
    """Placeholder for the unlabeled refinement images."""
    return self._x_unlabel

  @property
  def y_unlabel(self):
    """Placeholder for the labels of the unlabeled images."""
    return self._y_unlabel


if __name__ == "__main__":
  from fewshot.configs.omniglot_config import OmniglotRefineConfig
  model = RefineModel(OmniglotRefineConfig())
# -----------------------------------------------------------------------------
# /fewshot/data/tiered_imagenet_split/test.csv:
# -----------------------------------------------------------------------------
# 1 | n03314780,n00020090 2 | n07565083,n00020090 3 | n07579787,n00020090 4 | n07583066,n00020090 5 | n07584110,n00020090 6 | n07590611,n00020090 7 | n07613480,n00020090 8 | n07614500,n00020090 9 | n07615774,n00020090 10 | n07697313,n00020090 11 | n07697537,n00020090 12 | n07802026,n00020090 13 | n07831146,n00020090 14 | n07836838,n00020090 15 | n07860988,n00020090 16 | n07873807,n00020090 17 | n07875152,n00020090 18 | n07880968,n00020090 19 | n07892512,n00020090 20 | n07920052,n00020090 21 | n07930864,n00020090 22 | n07932039,n00020090 23 | n01440764,n01473806 24 | n01443537,n01473806 25 | n01484850,n01473806 26 | n01491361,n01473806 27 | n01494475,n01473806 28 | n01496331,n01473806 29 | n01498041,n01473806 30 |
n02514041,n01473806 31 | n02526121,n01473806 32 | n02536864,n01473806 33 | n02606052,n01473806 34 | n02607072,n01473806 35 | n02640242,n01473806 36 | n02641379,n01473806 37 | n02643566,n01473806 38 | n02655020,n01473806 39 | n02104029,n02103406 40 | n02104365,n02103406 41 | n02105056,n02103406 42 | n02105162,n02103406 43 | n02105251,n02103406 44 | n02105412,n02103406 45 | n02105505,n02103406 46 | n02105641,n02103406 47 | n02105855,n02103406 48 | n02106030,n02103406 49 | n02106166,n02103406 50 | n02106382,n02103406 51 | n02106550,n02103406 52 | n02106662,n02103406 53 | n02107142,n02103406 54 | n02107312,n02103406 55 | n02107574,n02103406 56 | n02107683,n02103406 57 | n02107908,n02103406 58 | n02108000,n02103406 59 | n02108089,n02103406 60 | n02108422,n02103406 61 | n02108551,n02103406 62 | n02108915,n02103406 63 | n02109047,n02103406 64 | n02109525,n02103406 65 | n02109961,n02103406 66 | n02110063,n02103406 67 | n02110185,n02103406 68 | n02110627,n02103406 69 | n02165105,n02159955 70 | n02165456,n02159955 71 | n02167151,n02159955 72 | n02168699,n02159955 73 | n02169497,n02159955 74 | n02172182,n02159955 75 | n02174001,n02159955 76 | n02177972,n02159955 77 | n02190166,n02159955 78 | n02206856,n02159955 79 | n02219486,n02159955 80 | n02226429,n02159955 81 | n02229544,n02159955 82 | n02231487,n02159955 83 | n02233338,n02159955 84 | n02236044,n02159955 85 | n02256656,n02159955 86 | n02259212,n02159955 87 | n02264363,n02159955 88 | n02268443,n02159955 89 | n02268853,n02159955 90 | n02276258,n02159955 91 | n02277742,n02159955 92 | n02279972,n02159955 93 | n02280649,n02159955 94 | n02281406,n02159955 95 | n02281787,n02159955 96 | n02788148,n03839993 97 | n02894605,n03839993 98 | n03000134,n03839993 99 | n03160309,n03839993 100 | n03459775,n03839993 101 | n03930313,n03839993 102 | n04239074,n03839993 103 | n04326547,n03839993 104 | n04501370,n03839993 105 | n04604644,n03839993 106 | n02795169,n04531098 107 | n02808440,n04531098 108 | n02815834,n04531098 109 | 
n02823428,n04531098 110 | n02909870,n04531098 111 | n02939185,n04531098 112 | n03063599,n04531098 113 | n03063689,n04531098 114 | n03633091,n04531098 115 | n03786901,n04531098 116 | n03937543,n04531098 117 | n03950228,n04531098 118 | n03983396,n04531098 119 | n04049303,n04531098 120 | n04398044,n04531098 121 | n04493381,n04531098 122 | n04522168,n04531098 123 | n04553703,n04531098 124 | n04557648,n04531098 125 | n04560804,n04531098 126 | n04562935,n04531098 127 | n04579145,n04531098 128 | n04591713,n04531098 129 | n09193705,n09287968 130 | n09246464,n09287968 131 | n09256479,n09287968 132 | n09288635,n09287968 133 | n09332890,n09287968 134 | n09399592,n09287968 135 | n09421951,n09287968 136 | n09428293,n09287968 137 | n09468604,n09287968 138 | n09472597,n09287968 139 | n07714571,n15046900 140 | n07714990,n15046900 141 | n07715103,n15046900 142 | n07716358,n15046900 143 | n07716906,n15046900 144 | n07717410,n15046900 145 | n07717556,n15046900 146 | n07718472,n15046900 147 | n07718747,n15046900 148 | n07720875,n15046900 149 | n07730033,n15046900 150 | n07734744,n15046900 151 | n07742313,n15046900 152 | n07745940,n15046900 153 | n07747607,n15046900 154 | n07749582,n15046900 155 | n07753113,n15046900 156 | n07753275,n15046900 157 | n07753592,n15046900 158 | n07754684,n15046900 159 | n07760859,n15046900 160 | n07768694,n15046900 161 | -------------------------------------------------------------------------------- /fewshot/data/tiered_imagenet_split/old_test.csv: -------------------------------------------------------------------------------- 1 | n03314780,n00020090 2 | n07565083,n00020090 3 | n07579787,n00020090 4 | n07583066,n00020090 5 | n07584110,n00020090 6 | n07590611,n00020090 7 | n07613480,n00020090 8 | n07614500,n00020090 9 | n07615774,n00020090 10 | n07697313,n00020090 11 | n07697537,n00020090 12 | n07802026,n00020090 13 | n07831146,n00020090 14 | n07836838,n00020090 15 | n07860988,n00020090 16 | n07873807,n00020090 17 | n07875152,n00020090 18 | 
n07880968,n00020090 19 | n07892512,n00020090 20 | n07920052,n00020090 21 | n07930864,n00020090 22 | n07932039,n00020090 23 | n01440764,n01473806 24 | n01443537,n01473806 25 | n01484850,n01473806 26 | n01491361,n01473806 27 | n01494475,n01473806 28 | n01496331,n01473806 29 | n01498041,n01473806 30 | n02514041,n01473806 31 | n02526121,n01473806 32 | n02536864,n01473806 33 | n02606052,n01473806 34 | n02607072,n01473806 35 | n02640242,n01473806 36 | n02641379,n01473806 37 | n02643566,n01473806 38 | n02655020,n01473806 39 | n02104029,n02103406 40 | n02104365,n02103406 41 | n02105056,n02103406 42 | n02105162,n02103406 43 | n02105251,n02103406 44 | n02105412,n02103406 45 | n02105505,n02103406 46 | n02105641,n02103406 47 | n02105855,n02103406 48 | n02106030,n02103406 49 | n02106166,n02103406 50 | n02106382,n02103406 51 | n02106550,n02103406 52 | n02106662,n02103406 53 | n02107142,n02103406 54 | n02107312,n02103406 55 | n02107574,n02103406 56 | n02107683,n02103406 57 | n02107908,n02103406 58 | n02108000,n02103406 59 | n02108089,n02103406 60 | n02108422,n02103406 61 | n02108551,n02103406 62 | n02108915,n02103406 63 | n02109047,n02103406 64 | n02109525,n02103406 65 | n02109961,n02103406 66 | n02110063,n02103406 67 | n02110185,n02103406 68 | n02110627,n02103406 69 | n02165105,n02159955 70 | n02165456,n02159955 71 | n02167151,n02159955 72 | n02168699,n02159955 73 | n02169497,n02159955 74 | n02172182,n02159955 75 | n02174001,n02159955 76 | n02177972,n02159955 77 | n02190166,n02159955 78 | n02206856,n02159955 79 | n02219486,n02159955 80 | n02226429,n02159955 81 | n02229544,n02159955 82 | n02231487,n02159955 83 | n02233338,n02159955 84 | n02236044,n02159955 85 | n02256656,n02159955 86 | n02259212,n02159955 87 | n02264363,n02159955 88 | n02268443,n02159955 89 | n02268853,n02159955 90 | n02276258,n02159955 91 | n02277742,n02159955 92 | n02279972,n02159955 93 | n02280649,n02159955 94 | n02281406,n02159955 95 | n02281787,n02159955 96 | n02788148,n03839993 97 | n02894605,n03839993 98 | 
n03000134,n03839993 99 | n03160309,n03839993 100 | n03459775,n03839993 101 | n03930313,n03839993 102 | n04239074,n03839993 103 | n04326547,n03839993 104 | n04501370,n03839993 105 | n04604644,n03839993 106 | n02795169,n04531098 107 | n02808440,n04531098 108 | n02815834,n04531098 109 | n02823428,n04531098 110 | n02909870,n04531098 111 | n02939185,n04531098 112 | n03063599,n04531098 113 | n03063689,n04531098 114 | n03633091,n04531098 115 | n03786901,n04531098 116 | n03937543,n04531098 117 | n03950228,n04531098 118 | n03983396,n04531098 119 | n04049303,n04531098 120 | n04398044,n04531098 121 | n04493381,n04531098 122 | n04522168,n04531098 123 | n04553703,n04531098 124 | n04557648,n04531098 125 | n04560804,n04531098 126 | n04562935,n04531098 127 | n04579145,n04531098 128 | n04591713,n04531098 129 | n09193705,n09287968 130 | n09246464,n09287968 131 | n09256479,n09287968 132 | n09288635,n09287968 133 | n09332890,n09287968 134 | n09399592,n09287968 135 | n09421951,n09287968 136 | n09428293,n09287968 137 | n09468604,n09287968 138 | n09472597,n09287968 139 | n07714571,n15046900 140 | n07714990,n15046900 141 | n07715103,n15046900 142 | n07716358,n15046900 143 | n07716906,n15046900 144 | n07717410,n15046900 145 | n07717556,n15046900 146 | n07718472,n15046900 147 | n07718747,n15046900 148 | n07720875,n15046900 149 | n07730033,n15046900 150 | n07734744,n15046900 151 | n07742313,n15046900 152 | n07745940,n15046900 153 | n07747607,n15046900 154 | n07749582,n15046900 155 | n07753113,n15046900 156 | n07753275,n15046900 157 | n07753592,n15046900 158 | n07754684,n15046900 159 | n07760859,n15046900 160 | n07768694,n15046900 161 | -------------------------------------------------------------------------------- /fewshot/configs/mini_imagenet_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. 
Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
# =============================================================================
from fewshot.configs.config_factory import RegisterConfig


@RegisterConfig("mini-imagenet", "basic")
class BasicConfig(object):
  """Standard CNN with prototypical layer."""

  def __init__(self):
    # Identity.
    self.name = "mini-imagenet_basic"
    self.model_class = "basic"

    # Input image geometry.
    self.height = 84
    self.width = 84
    self.num_channel = 3

    # Validation / logging / checkpoint cadence, in training steps.
    self.steps_per_valid = 2000
    self.steps_per_log = 100
    self.steps_per_save = 2000

    # Backbone: four 3x3 conv layers of 64 filters, each max-pooled by 2.
    num_layers = 4
    self.filter_size = [[3, 3, 3, 64]] + [[3, 3, 64, 64]] * (num_layers - 1)
    self.strides = [[1, 1, 1, 1]] * num_layers
    self.pool_fn = ["max_pool"] * num_layers
    self.pool_size = [[1, 2, 2, 1]] * num_layers
    self.pool_strides = [[1, 2, 2, 1]] * num_layers
    self.conv_act_fn = ["relu"] * num_layers
    self.conv_init_method = None
    self.conv_init_std = [1.0e-2] * num_layers

    # Optimization.
    self.wd = 5e-5
    self.learn_rate = 1e-3
    self.normalization = "batch_norm"
    self.lr_scheduler = "fixed"
    self.max_train_steps = 200000
    # Decay points every 25k steps: 25000, 50000, ..., 175000.
    self.lr_decay_steps = list(range(25000, self.max_train_steps, 25000))
    # NOTE(review): the first entry is learn_rate * 0.5**0, so the first
    # decay step presumably leaves the rate unchanged (if the scheduler's
    # base_lr is learn_rate); the pretrain config starts at 0.5**1 -- confirm
    # which is intended.
    self.lr_list = [
        self.learn_rate * 0.5**ii for ii in range(len(self.lr_decay_steps))
    ]
    self.similarity = "euclidean"


@RegisterConfig("mini-imagenet", "basic-pretrain")
class BasicPretrainConfig(BasicConfig):
  """Short 4000-step schedule that halves the rate at every decay step."""

  def __init__(self):
    super(BasicPretrainConfig, self).__init__()
    self.max_train_steps = 4000
    self.lr_decay_steps = [2000, 2500, 3000, 3500]
    self.lr_list = [
        self.learn_rate * 0.5**ii
        for ii in range(1, len(self.lr_decay_steps) + 1)
    ]
    self.similarity = "euclidean"


@RegisterConfig("mini-imagenet", "kmeans-refine")
class KMeansRefineConfig(BasicConfig):
  """Basic config for the "kmeans-refine" model (one clustering step)."""

  def __init__(self):
    super(KMeansRefineConfig, self).__init__()
    self.name = "mini-imagenet_kmeans-refine"
    self.model_class = "kmeans-refine"
    self.num_cluster_steps = 1


@RegisterConfig("mini-imagenet", "kmeans-refine-radius")
class KMeansRefineDistractorConfig(BasicConfig):
  """Basic config for the "kmeans-refine-radius" model (one step)."""

  def __init__(self):
    super(KMeansRefineDistractorConfig, self).__init__()
    self.name = "mini-imagenet_kmeans-refine-radius"
    self.model_class = "kmeans-refine-radius"
    self.num_cluster_steps = 1


@RegisterConfig("mini-imagenet", "kmeans-refine-mask")
class KMeansRefineDistractorMSV3Config(BasicConfig):
  """Basic config for the "kmeans-refine-mask" model (one step)."""

  def __init__(self):
    super(KMeansRefineDistractorMSV3Config, self).__init__()
    self.name = "mini-imagenet_kmeans-refine-mask"
    self.model_class = "kmeans-refine-mask"
    self.num_cluster_steps = 1
# -----------------------------------------------------------------------------
# /fewshot/configs/tiered_imagenet_config.py:
# -----------------------------------------------------------------------------
# Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell,
# Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# =============================================================================
from fewshot.configs.config_factory import RegisterConfig


@RegisterConfig("tiered-imagenet", "basic")
class BasicConfig(object):
  """Standard CNN with prototypical layer."""

  def __init__(self):
    # Identity.
    self.name = "tiered-imagenet_basic"
    self.model_class = "basic"

    # Input image geometry.
    self.height = 84
    self.width = 84
    self.num_channel = 3

    # Validation / logging / checkpoint cadence, in training steps.
    self.steps_per_valid = 2000
    self.steps_per_log = 100
    self.steps_per_save = 2000

    # Backbone: four 3x3 conv layers of 64 filters, each max-pooled by 2.
    num_layers = 4
    self.filter_size = [[3, 3, 3, 64]] + [[3, 3, 64, 64]] * (num_layers - 1)
    self.strides = [[1, 1, 1, 1]] * num_layers
    self.pool_fn = ["max_pool"] * num_layers
    self.pool_size = [[1, 2, 2, 1]] * num_layers
    self.pool_strides = [[1, 2, 2, 1]] * num_layers
    self.conv_act_fn = ["relu"] * num_layers
    self.conv_init_method = None
    self.conv_init_std = [1.0e-2] * num_layers

    # Optimization.
    self.wd = 5e-5
    self.learn_rate = 1e-3
    self.normalization = "batch_norm"
    self.lr_scheduler = "fixed"
    self.max_train_steps = 200000
    # Decay points every 25k steps: 25000, 50000, ..., 175000.
    self.lr_decay_steps = list(range(25000, self.max_train_steps, 25000))
    # NOTE(review): the first entry is learn_rate * 0.5**0, so the first
    # decay step presumably leaves the rate unchanged (if the scheduler's
    # base_lr is learn_rate); the pretrain config starts at 0.5**1 -- confirm
    # which is intended.
    self.lr_list = [
        self.learn_rate * 0.5**ii for ii in range(len(self.lr_decay_steps))
    ]
    self.similarity = "euclidean"


@RegisterConfig("tiered-imagenet", "basic-pretrain")
class BasicPretrainConfig(BasicConfig):
  """Short 4000-step schedule that halves the rate at every decay step."""

  def __init__(self):
    super(BasicPretrainConfig, self).__init__()
    self.max_train_steps = 4000
    self.lr_decay_steps = [2000, 2500, 3000, 3500]
    self.lr_list = [
        self.learn_rate * 0.5**ii
        for ii in range(1, len(self.lr_decay_steps) + 1)
    ]
    self.similarity = "euclidean"


@RegisterConfig("tiered-imagenet", "kmeans-refine")
class KMeansRefineConfig(BasicConfig):
  """Basic config for the "kmeans-refine" model (one clustering step)."""

  def __init__(self):
    super(KMeansRefineConfig, self).__init__()
    self.name = "tiered-imagenet_kmeans-refine"
    self.model_class = "kmeans-refine"
    self.num_cluster_steps = 1


@RegisterConfig("tiered-imagenet", "kmeans-refine-radius")
class KMeansRefineDistractorConfig(BasicConfig):
  """Basic config for the "kmeans-refine-radius" model (one step)."""

  def __init__(self):
    super(KMeansRefineDistractorConfig, self).__init__()
    self.name = "tiered-imagenet_kmeans-refine-radius"
    self.model_class = "kmeans-refine-radius"
    self.num_cluster_steps = 1


@RegisterConfig("tiered-imagenet", "kmeans-refine-mask")
class KMeansRefineDistractorMSV3Config(BasicConfig):
  """Basic config for the "kmeans-refine-mask" model (one step)."""

  def __init__(self):
    super(KMeansRefineDistractorMSV3Config, self).__init__()
    self.name = "tiered-imagenet_kmeans-refine-mask"
    self.model_class = "kmeans-refine-mask"
    self.num_cluster_steps = 1
# -----------------------------------------------------------------------------
# /fewshot/models/measure.py:
# -----------------------------------------------------------------------------
# Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell,
# Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# =============================================================================
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import numpy as np


def batch_apk(logits, pos_mask, k):
  """Computes Average Precision at cut-off k for a batch of rankings.

  Args:
    logits: [B, M]. Predicted relevance of each candidate.
    pos_mask: [B, M]. Ground truth binary relevance of each candidate.
    k: Int cut-off, or a [B] array of per-example cut-offs.
  Returns:
    ap_score: [B]. Average precision of the induced ranking.
  """
  logits = np.asarray(logits)
  pos_mask = np.asarray(pos_mask)
  num_cand = logits.shape[1]
  # Candidate indices by descending score, for example [1, 0, 3, 2].
  ranks = np.argsort(logits, axis=1)[:, ::-1]
  rows = np.arange(logits.shape[0])[:, None]
  # Relevance of each candidate in ranked order, in the dtype of `logits`.
  actual = pos_mask[rows, ranks].astype(logits.dtype)
  # Number of top positions that count, as a column ([1, 1] for scalar k,
  # [B, 1] for per-example k) so it broadcasts against the position index.
  # (np.expand_dims(scalar, 1) raises AxisError on NumPy >= 1.18.)
  num_items = np.reshape(np.minimum(k, num_cand), (-1, 1))
  mask = (np.arange(num_cand)[None, :] < num_items).astype(np.float32)
  # Precision at every relevant position within the cut-off.
  hits = np.cumsum(actual, axis=1)
  hits *= mask
  hits *= actual
  denom = np.arange(num_cand) + 1.0
  score = (hits / denom).sum(axis=1)
  num_relevant_at_k = np.maximum(np.minimum(k, actual.sum(axis=1)), 1.0)
  ap_score = score / num_relevant_at_k
  return ap_score


def apk(logits, pos_mask, k):
  """Computes Average Precision at cut-off k.

  Args:
    logits: [M]. Predicted relevance of each candidate.
    pos_mask: [M]. Ground truth binary relevance of each candidate.
    k: Int. Only the top-k ranked candidates contribute to the score.
  Returns:
    ap_score: []. Average precision of the induced ranking.
  """
  ranks = np.argsort(logits)[::-1]  # for example [1, 0, 3, 2]
  actual = np.array(pos_mask)[ranks]
  score = 0.0
  num_hits = 0.0
  for ii in range(min(k, len(actual))):
    if actual[ii]:
      num_hits += 1.0
      score += num_hits / (ii + 1.0)
  # At least 1 so that a ranking with no relevant items scores 0, not NaN.
  num_relevant_at_k = max(min(k, np.count_nonzero(actual == 1)), 1.0)
  ap_score = score / num_relevant_at_k
  return ap_score


def ap(logits, pos_mask):
  """Computes Average Precision.

  Equivalent to `apk` with the cut-off set to the number of candidates.

  Args:
    logits: [M]. Predicted relevance of each candidate.
    pos_mask: [M]. Ground truth binary relevance of each candidate.
  Returns:
    ap_score: []. Average precision of the induced ranking.
  """
  return apk(logits, pos_mask, len(logits))
# -----------------------------------------------------------------------------
# /fewshot/utils/experiment_logger.py:
# -----------------------------------------------------------------------------
# Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell,
# Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel.
3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
# =============================================================================
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import datetime
import os
import sys

from fewshot.utils import logger

log = logger.get()


class ExperimentLogger(object):
  """Writes experimental logs to CSV file.

  When constructed with `logs_folder=None` nothing is written and every
  `log_*` method is a no-op.
  """

  def __init__(self, logs_folder):
    """Initialize files.

    Creates `logs_folder` if needed, rewrites its `catalog` index and
    `cmd.txt` (the command line of this run), and creates each metric CSV with
    its header row only if that CSV does not exist yet, so existing metric
    logs are appended to rather than overwritten.

    Args:
      logs_folder: Output directory, or None to disable logging.
    """
    self._write_to_csv = logs_folder is not None
    if not self._write_to_csv:
      return

    if not os.path.isdir(logs_folder):
      os.makedirs(logs_folder)

    with open(os.path.join(logs_folder, "catalog"), "w") as f:
      f.write("filename,type,name\n")
      f.write("{},plain,{}\n".format("cmd.txt", "Commands"))
      f.write("train_ce.csv,csv,Train Loss (Cross Entropy)\n")
      f.write("train_acc.csv,csv,Train Accuracy\n")
      f.write("valid_acc.csv,csv,Validation Accuracy\n")
      f.write("learn_rate.csv,csv,Learning Rate\n")

    with open(os.path.join(logs_folder, "cmd.txt"), "w") as f:
      f.write(" ".join(sys.argv))

    self.train_file_name = os.path.join(logs_folder, "train_ce.csv")
    self._init_csv(self.train_file_name, "step,time,ce\n")

    self.trainval_file_name = os.path.join(logs_folder, "train_acc.csv")
    self._init_csv(self.trainval_file_name, "step,time,acc\n")

    self.val_file_name = os.path.join(logs_folder, "valid_acc.csv")
    self._init_csv(self.val_file_name, "step,time,acc\n")

    self.lr_file_name = os.path.join(logs_folder, "learn_rate.csv")
    self._init_csv(self.lr_file_name, "step,time,lr\n")

  @staticmethod
  def _init_csv(file_name, header):
    """Creates `file_name` containing `header` unless it already exists."""
    if not os.path.exists(file_name):
      with open(file_name, "w") as f:
        f.write(header)

  @staticmethod
  def _append_row(file_name, niter, value):
    """Appends a `step,time,value` row; `step` is reported as `niter + 1`."""
    with open(file_name, "a") as f:
      f.write("{:d},{:s},{:e}\n".format(
          niter + 1, datetime.datetime.now().isoformat(), value))

  def log_train_ce(self, niter, ce):
    """Writes training CE."""
    if self._write_to_csv:
      self._append_row(self.train_file_name, niter, ce)

  def log_train_acc(self, niter, acc):
    """Writes training accuracy."""
    if self._write_to_csv:
      self._append_row(self.trainval_file_name, niter, acc)

  def log_valid_acc(self, niter, acc):
    """Writes validation accuracy."""
    if self._write_to_csv:
      self._append_row(self.val_file_name, niter, acc)

  def log_learn_rate(self, niter, lr):
    """Writes learning rate."""
    if self._write_to_csv:
      self._append_row(self.lr_file_name, niter, lr)
# -----------------------------------------------------------------------------
# /fewshot/utils/lr_schedule.py:
# -----------------------------------------------------------------------------
# Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell,
# Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# =============================================================================
"""Learning rate scheduler utilities."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import numpy as np

from fewshot.utils import logger

log = logger.get()


class FixedLearnRateScheduler(object):
  """Adjusts learning rate according to a fixed schedule."""

  def __init__(self, sess, model, base_lr, lr_decay_steps, lr_list=None):
    """
    Args:
      sess: TensorFlow session object.
      model: Model object.
      base_lr: Base learning rate; assigned to the model immediately.
      lr_decay_steps: A list of (1-based) step numbers at which the learning
        rate decays.
      lr_list: A list of learning rates, one per decay step, each replacing
        the current rate. If None, the rate is multiplied by 0.1 at every
        decay step.
    """
    self.model = model
    self.sess = sess
    self.lr = base_lr
    # Copies: the schedule is consumed with pop(), and the caller's lists
    # (e.g. config attributes) must stay intact for later schedulers.
    self.lr_list = None if lr_list is None else list(lr_list)
    self.lr_decay_steps = list(lr_decay_steps)
    self.model.assign_lr(self.sess, self.lr)

  def step(self, niter):
    """Adds to counter. Adjusts learning rate if necessary.

    Every decay step that has been reached (`niter + 1 >= step`) is consumed,
    applying one decay per step, so the rate also catches up correctly when
    several decay steps were skipped (e.g. after resuming from a checkpoint).

    Args:
      niter: Current number of iterations.
    """
    cur_step = niter + 1
    num_decays = 0
    while self.lr_decay_steps and cur_step >= self.lr_decay_steps[0]:
      self.lr_decay_steps.pop(0)
      if self.lr_list is not None:
        self.lr = self.lr_list.pop(0)
      else:
        self.lr *= 0.1  # Divide by 10 by default.
      num_decays += 1
    if num_decays > 0:
      self.model.assign_lr(self.sess, self.lr)
      log.warning("LR decay steps {}".format(self.lr_decay_steps))


class ExponentialLearnRateScheduler(object):
  """Adjusts learning rate according to an exponential decay schedule."""

  def __init__(self, sess, model, base_lr, offset_steps, total_steps, final_lr,
               interval):
    """
    Args:
      sess: TensorFlow session object.
      model: Model object.
      base_lr: Base learning rate.
      offset_steps: Initial non-decay steps.
      total_steps: Total number of steps.
      final_lr: Final learning rate by the end of training.
      interval: Number of steps in between learning rate updates (staircase).
    """
    self.model = model
    self.sess = sess
    self.base_lr = base_lr
    self.lr = base_lr
    self.offset_steps = offset_steps
    self.total_steps = total_steps
    # Chosen so base_lr * exp(-(total_steps - offset_steps) / tau) == final_lr.
    self.time_constant = (total_steps - offset_steps) / np.log(
        base_lr / final_lr)
    self.final_lr = final_lr
    self.interval = interval
    self.model.assign_lr(self.sess, self.lr)

  def step(self, niter):
    """Adds to counter. Adjusts learning rate if necessary.

    After `offset_steps`, every `interval` steps the rate is set to
    `base_lr * exp(-(niter - offset_steps) / time_constant)`.

    Args:
      niter: Current number of iterations.
    """
    if niter <= self.offset_steps:
      return
    steps2 = niter - self.offset_steps
    if steps2 % self.interval == 0:
      self.lr = self.base_lr * np.exp(-steps2 / self.time_constant)
      self.model.assign_lr(self.sess, self.lr)
# -----------------------------------------------------------------------------
# /fewshot/data/omniglot_split/val.txt:
# -----------------------------------------------------------------------------
# 1 | Hebrew/character01 2 | Hebrew/character02 3 | Hebrew/character03 4 | Hebrew/character04 5 | Hebrew/character05 6 | Hebrew/character06 7 | Hebrew/character07 8 | Hebrew/character08 9 | Hebrew/character09 10 | Hebrew/character10 11 | Hebrew/character11 12 | Hebrew/character12 13 | Hebrew/character13 14 | Hebrew/character14 15 | Hebrew/character15 16 | Hebrew/character16 17 | Hebrew/character17 18 | Hebrew/character18 19 | Hebrew/character19 20 | Hebrew/character20 21 | Hebrew/character21 22 | Hebrew/character22 23 | Mkhedruli_(Georgian)/character01 24 | Mkhedruli_(Georgian)/character02 25 | Mkhedruli_(Georgian)/character03 26 | Mkhedruli_(Georgian)/character04 27 | Mkhedruli_(Georgian)/character05 28 | Mkhedruli_(Georgian)/character06 29 | Mkhedruli_(Georgian)/character07 30 | Mkhedruli_(Georgian)/character08 31 | Mkhedruli_(Georgian)/character09 32 | Mkhedruli_(Georgian)/character10 33 | Mkhedruli_(Georgian)/character11 34 | Mkhedruli_(Georgian)/character12 35 | Mkhedruli_(Georgian)/character13 36 |
Mkhedruli_(Georgian)/character14 37 | Mkhedruli_(Georgian)/character15 38 | Mkhedruli_(Georgian)/character16 39 | Mkhedruli_(Georgian)/character17 40 | Mkhedruli_(Georgian)/character18 41 | Mkhedruli_(Georgian)/character19 42 | Mkhedruli_(Georgian)/character20 43 | Mkhedruli_(Georgian)/character21 44 | Mkhedruli_(Georgian)/character22 45 | Mkhedruli_(Georgian)/character23 46 | Mkhedruli_(Georgian)/character24 47 | Mkhedruli_(Georgian)/character25 48 | Mkhedruli_(Georgian)/character26 49 | Mkhedruli_(Georgian)/character27 50 | Mkhedruli_(Georgian)/character28 51 | Mkhedruli_(Georgian)/character29 52 | Mkhedruli_(Georgian)/character30 53 | Mkhedruli_(Georgian)/character31 54 | Mkhedruli_(Georgian)/character32 55 | Mkhedruli_(Georgian)/character33 56 | Mkhedruli_(Georgian)/character34 57 | Mkhedruli_(Georgian)/character35 58 | Mkhedruli_(Georgian)/character36 59 | Mkhedruli_(Georgian)/character37 60 | Mkhedruli_(Georgian)/character38 61 | Mkhedruli_(Georgian)/character39 62 | Mkhedruli_(Georgian)/character40 63 | Mkhedruli_(Georgian)/character41 64 | Armenian/character01 65 | Armenian/character02 66 | Armenian/character03 67 | Armenian/character04 68 | Armenian/character05 69 | Armenian/character06 70 | Armenian/character07 71 | Armenian/character08 72 | Armenian/character09 73 | Armenian/character10 74 | Armenian/character11 75 | Armenian/character12 76 | Armenian/character13 77 | Armenian/character14 78 | Armenian/character15 79 | Armenian/character16 80 | Armenian/character17 81 | Armenian/character18 82 | Armenian/character19 83 | Armenian/character20 84 | Armenian/character21 85 | Armenian/character22 86 | Armenian/character23 87 | Armenian/character24 88 | Armenian/character25 89 | Armenian/character26 90 | Armenian/character27 91 | Armenian/character28 92 | Armenian/character29 93 | Armenian/character30 94 | Armenian/character31 95 | Armenian/character32 96 | Armenian/character33 97 | Armenian/character34 98 | Armenian/character35 99 | Armenian/character36 100 
| Armenian/character37 101 | Armenian/character38 102 | Armenian/character39 103 | Armenian/character40 104 | Armenian/character41 105 | Early_Aramaic/character01 106 | Early_Aramaic/character02 107 | Early_Aramaic/character03 108 | Early_Aramaic/character04 109 | Early_Aramaic/character05 110 | Early_Aramaic/character06 111 | Early_Aramaic/character07 112 | Early_Aramaic/character08 113 | Early_Aramaic/character09 114 | Early_Aramaic/character10 115 | Early_Aramaic/character11 116 | Early_Aramaic/character12 117 | Early_Aramaic/character13 118 | Early_Aramaic/character14 119 | Early_Aramaic/character15 120 | Early_Aramaic/character16 121 | Early_Aramaic/character17 122 | Early_Aramaic/character18 123 | Early_Aramaic/character19 124 | Early_Aramaic/character20 125 | Early_Aramaic/character21 126 | Early_Aramaic/character22 127 | Bengali/character01 128 | Bengali/character02 129 | Bengali/character03 130 | Bengali/character04 131 | Bengali/character05 132 | Bengali/character06 133 | Bengali/character07 134 | Bengali/character08 135 | Bengali/character09 136 | Bengali/character10 137 | Bengali/character11 138 | Bengali/character12 139 | Bengali/character13 140 | Bengali/character14 141 | Bengali/character15 142 | Bengali/character16 143 | Bengali/character17 144 | Bengali/character18 145 | Bengali/character19 146 | Bengali/character20 147 | Bengali/character21 148 | Bengali/character22 149 | Bengali/character23 150 | Bengali/character24 151 | Bengali/character25 152 | Bengali/character26 153 | Bengali/character27 154 | Bengali/character28 155 | Bengali/character29 156 | Bengali/character30 157 | Bengali/character31 158 | Bengali/character32 159 | Bengali/character33 160 | Bengali/character34 161 | Bengali/character35 162 | Bengali/character36 163 | Bengali/character37 164 | Bengali/character38 165 | Bengali/character39 166 | Bengali/character40 167 | Bengali/character41 168 | Bengali/character42 169 | Bengali/character43 170 | Bengali/character44 171 | 
Bengali/character45 172 | Bengali/character46 173 | -------------------------------------------------------------------------------- /fewshot/models/kmeans_refine_radius_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | # ============================================================================= 22 | """Another KMeans based semisupervised model. Adds another class for distractors. The distractor 23 | has a zero vector for the mean representation, and a learned radius to capture the remainders. 
24 | 25 | Author: Mengye Ren (mren@cs.toronto.edu) 26 | """ 27 | from __future__ import (absolute_import, division, print_function, 28 | unicode_literals) 29 | 30 | import numpy as np 31 | import tensorflow as tf 32 | 33 | from fewshot.models.distractor_utils import eval_distractor 34 | from fewshot.models.kmeans_refine_model import KMeansRefineModel 35 | from fewshot.models.kmeans_utils import assign_cluster_radii 36 | from fewshot.models.kmeans_utils import compute_logits_radii 37 | from fewshot.models.kmeans_utils import update_cluster 38 | from fewshot.models.model_factory import RegisterModel 39 | from fewshot.models.nnlib import concat 40 | from fewshot.utils import logger 41 | 42 | log = logger.get() 43 | 44 | flags = tf.flags 45 | flags.DEFINE_bool("learn_radius", True, 46 | "Whether or not to learn distractor radius.") 47 | flags.DEFINE_float("init_radius", 100.0, "Initial radius for the distractors.") 48 | FLAGS = tf.flags.FLAGS 49 | 50 | 51 | @RegisterModel("kmeans-refine-radius") 52 | class KMeansRefineRadiusModel(KMeansRefineModel): 53 | 54 | def predict(self): 55 | """See `model.py` for documentation.""" 56 | nclasses = self.nway 57 | num_cluster_steps = self.config.num_cluster_steps 58 | h_train, h_unlabel, h_test = self.get_encoded_inputs( 59 | self.x_train, self.x_unlabel, self.x_test) 60 | y_train = self.y_train 61 | protos = self._compute_protos(nclasses, h_train, y_train) 62 | 63 | # Distractor class has a zero vector as prototype. 64 | protos = concat([protos, tf.zeros_like(protos[:, 0:1, :])], 1) 65 | 66 | # Hard assignment for training images. 67 | prob_train = [None] * (nclasses + 1) 68 | for kk in range(nclasses): 69 | # [B, N, 1] 70 | prob_train[kk] = tf.expand_dims( 71 | tf.cast(tf.equal(y_train, kk), h_train.dtype), 2) 72 | prob_train[-1] = tf.zeros_like(prob_train[0]) 73 | prob_train = concat(prob_train, 2) 74 | 75 | # Initialize cluster radii. 
76 | radii = [None] * (nclasses + 1) 77 | y_train_shape = tf.shape(y_train) 78 | bsize = y_train_shape[0] 79 | for kk in range(nclasses): 80 | radii[kk] = tf.ones([bsize, 1]) * 1.0 81 | 82 | # Distractor class has a larger radius. 83 | if FLAGS.learn_radius: 84 | log_distractor_radius = tf.get_variable( 85 | "log_distractor_radius", 86 | shape=[], 87 | dtype=tf.float32, 88 | initializer=tf.constant_initializer(np.log(FLAGS.init_radius))) 89 | distractor_radius = tf.exp(log_distractor_radius) 90 | else: 91 | distractor_radius = FLAGS.init_radius 92 | distractor_radius = tf.cond( 93 | tf.shape(self._x_unlabel)[1] > 0, lambda: distractor_radius, 94 | lambda: 100000.0) 95 | # distractor_radius = tf.Print(distractor_radius, [distractor_radius]) 96 | radii[-1] = tf.ones([bsize, 1]) * distractor_radius 97 | radii = concat(radii, 1) # [B, K] 98 | 99 | h_all = concat([h_train, h_unlabel], 1) 100 | logits_list = [] 101 | logits_list.append(compute_logits_radii(protos, h_test, radii)) 102 | 103 | # Run clustering. 104 | for tt in range(num_cluster_steps): 105 | # Label assignment. 106 | prob_unlabel = assign_cluster_radii(protos, h_unlabel, radii) 107 | prob_all = concat([prob_train, prob_unlabel], 1) 108 | protos = update_cluster(h_all, prob_all) 109 | logits_list.append(compute_logits_radii(protos, h_test, radii)) 110 | 111 | # Distractor evaluation. 
112 | is_distractor = tf.equal(tf.argmax(prob_unlabel, axis=-1), nclasses) 113 | pred_non_distractor = 1.0 - tf.to_float(is_distractor) 114 | acc, recall, precision = eval_distractor(pred_non_distractor, 115 | self.y_unlabel) 116 | self._non_distractor_acc = acc 117 | self._distractor_recall = recall 118 | self._distractor_precision = precision 119 | self._distractor_pred = 1.0 - tf.exp(prob_unlabel[:, :, -1]) 120 | return logits_list 121 | -------------------------------------------------------------------------------- /fewshot/data/batch_iter.py: -------------------------------------------------------------------------------- 1 | """ 2 | A batch iterator. 3 | Usage: 4 | for idx in BatchIterator(num=1000, batch_size=25): 5 | inp_batch = inp_all[idx] 6 | labels_batch = labels_all[idx] 7 | train(inp_batch, labels_batch) 8 | """ 9 | from __future__ import division, print_function 10 | 11 | import numpy as np 12 | import threading 13 | 14 | from fewshot.utils.logger import get as get_logger 15 | 16 | 17 | class IBatchIterator(object): 18 | 19 | def __iter__(self): 20 | """Get iterable.""" 21 | return self 22 | 23 | def next(self): 24 | raise Exception("Not implemented") 25 | 26 | def reset(self): 27 | raise Exception("Not implemented") 28 | 29 | 30 | class BatchIterator(IBatchIterator): 31 | 32 | def __init__(self, 33 | num, 34 | batch_size=1, 35 | progress_bar=False, 36 | log_epoch=10, 37 | get_fn=None, 38 | cycle=False, 39 | shuffle=True, 40 | stagnant=False, 41 | seed=2, 42 | num_batches=-1): 43 | """Construct a batch iterator. 44 | Args: 45 | data: numpy.ndarray, (N, D), N is the number of examples, D is the 46 | feature dimension. 47 | labels: numpy.ndarray, (N), N is the number of examples. 48 | batch_size: int, batch size. 
49 | """ 50 | 51 | self._num = num 52 | self._batch_size = batch_size 53 | self._step = 0 54 | self._num_steps = int(np.ceil(self._num / float(batch_size))) 55 | if num_batches > 0: 56 | self._num_steps = min(self._num_steps, num_batches) 57 | self._pb = None 58 | self._variables = None 59 | self._get_fn = get_fn 60 | self.get_fn = get_fn 61 | self._cycle = cycle 62 | self._shuffle_idx = np.arange(self._num) 63 | self._shuffle = shuffle 64 | self._random = np.random.RandomState(seed) 65 | if shuffle: 66 | self._random.shuffle(self._shuffle_idx) 67 | self._shuffle_flag = False 68 | self._stagnant = stagnant 69 | self._log_epoch = log_epoch 70 | self._log = get_logger() 71 | self._epoch = 0 72 | if progress_bar: 73 | self._pb = pb.get(self._num_steps) 74 | pass 75 | self._mutex = threading.Lock() 76 | pass 77 | 78 | def __iter__(self): 79 | """Get iterable.""" 80 | return self 81 | 82 | def __len__(self): 83 | """Get iterable length.""" 84 | return self._num_steps 85 | 86 | @property 87 | def variables(self): 88 | return self._variables 89 | 90 | def set_variables(self, variables): 91 | self._variables = variables 92 | 93 | def get_fn(idx): 94 | return self._get_fn(idx, variables=variables) 95 | 96 | self.get_fn = get_fn 97 | return self 98 | 99 | def reset(self): 100 | self._step = 0 101 | 102 | def print_progress(self): 103 | e = self._epoch 104 | a = (self._step * self._batch_size) % self._num 105 | b = self._num 106 | p = a / b * 100 107 | digit = int(np.ceil(np.log10(b))) 108 | progress_str = "{:" + str(digit) + "d}" 109 | progress_str = (progress_str + "/" + progress_str).format(int(a), int(b)) 110 | self._log.info("Epoch {:3d} Progress {} ({:5.2f}%)".format( 111 | e, progress_str, p)) 112 | pass 113 | 114 | def next(self): 115 | """Iterate next element.""" 116 | self._mutex.acquire() 117 | try: 118 | # Shuffle data. 
119 | if self._shuffle_flag: 120 | self._random.shuffle(self._shuffle_idx) 121 | self._shuffle_flag = False 122 | 123 | # Read/write of self._step stay in a thread-safe block. 124 | if not self._cycle: 125 | if self._step >= self._num_steps: 126 | raise StopIteration() 127 | 128 | # Calc start/end based on current step. 129 | start = self._batch_size * self._step 130 | end = self._batch_size * (self._step + 1) 131 | 132 | # Progress bar. 133 | if self._pb is not None: 134 | self._pb.increment() 135 | 136 | # Epoch record. 137 | if self._cycle: 138 | if int(end / self._num) > int(start / self._num): 139 | self._epoch += 1 140 | 141 | # Increment step. 142 | if not self._stagnant: 143 | self._step += 1 144 | 145 | # Print progress 146 | if self._log_epoch > 0 and self._step % self._log_epoch == 0: 147 | self.print_progress() 148 | finally: 149 | self._mutex.release() 150 | 151 | if not self._cycle: 152 | end = min(self._num, end) 153 | idx = np.arange(start, end) 154 | idx = idx.astype("int") 155 | if self.get_fn is not None: 156 | return self.get_fn(idx) 157 | else: 158 | return idx 159 | else: 160 | start = start % self._num 161 | end = end % self._num 162 | if end > start: 163 | idx = np.arange(start, end) 164 | idx = idx.astype("int") 165 | idx = self._shuffle_idx[idx] 166 | else: 167 | idx = np.array(range(start, self._num) + range(0, end)) 168 | idx = idx.astype("int") 169 | idx = self._shuffle_idx[idx] 170 | # Shuffle every cycle. 
171 | if self._shuffle: 172 | self._shuffle_flag = True 173 | if self.get_fn is not None: 174 | return self.get_fn(idx) 175 | else: 176 | return idx 177 | 178 | 179 | if __name__ == "__main__": 180 | b = BatchIterator( 181 | 400, 182 | batch_size=32, 183 | progress_bar=False, 184 | get_fn=lambda x: x, 185 | cycle=False, 186 | shuffle=False) 187 | for ii in b: 188 | print(ii) 189 | b.reset() 190 | for ii in b: 191 | print(ii) 192 | -------------------------------------------------------------------------------- /fewshot/models/basic_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
21 | # ============================================================================= 22 | """ 23 | A prototypical network for few-shot classification task. 24 | 25 | Author: Mengye Ren (mren@cs.toronto.edu) 26 | 27 | In a single episode, the model computes the mean representation of the positive 28 | reference images as prototypes, and then calculates pairwise similarity in the 29 | retrieval set. The similarity score runs through a sigmoid to give [0, 1] 30 | prediction on whether a candidate belongs to the same class or not. The 31 | candidates are used to backpropagate into the feature extraction CNN model phi. 32 | """ 33 | 34 | from __future__ import (absolute_import, division, print_function, 35 | unicode_literals) 36 | 37 | import numpy as np 38 | import tensorflow as tf 39 | 40 | from fewshot.models.kmeans_utils import compute_logits 41 | from fewshot.models.model import Model 42 | from fewshot.models.model_factory import RegisterModel 43 | from fewshot.models.nnlib import (concat, weight_variable) 44 | from fewshot.utils import logger 45 | from fewshot.utils.debug import debug_identity 46 | 47 | FLAGS = tf.flags.FLAGS 48 | log = logger.get() 49 | 50 | 51 | @RegisterModel("basic") 52 | class BasicModel(Model): 53 | """A basic retrieval model that runs the images through a CNN and compute 54 | basic similarity scores.""" 55 | 56 | def get_encoded_inputs(self, *x_list, **kwargs): 57 | """Runs the reference and candidate images through the feature model phi. 
58 | Returns: 59 | h_train: [B, N, D] 60 | h_unlabel: [B, P, D] 61 | h_test: [B, M, D] 62 | """ 63 | config = self.config 64 | bsize = tf.shape(self.x_train)[0] 65 | bsize = tf.shape(x_list[0])[0] 66 | num = [tf.shape(xx)[1] for xx in x_list] 67 | x_all = concat(x_list, 1) 68 | if 'ext_wts' in kwargs: 69 | ext_wts = kwargs['ext_wts'] 70 | else: 71 | ext_wts = None 72 | x_all = tf.reshape(x_all, 73 | [-1, config.height, config.width, config.num_channel]) 74 | h_all = self.phi(x_all, ext_wts=ext_wts) 75 | h_all = tf.reshape(h_all, [bsize, sum(num), -1]) 76 | h_list = tf.split(h_all, num, axis=1) 77 | return h_list 78 | 79 | def _compute_protos(self, nclasses, h_train, y_train): 80 | """Computes the prototypes, cluster centers. 81 | Args: 82 | nclasses: Int. Number of classes. 83 | h_train: [B, N, D], Train features. 84 | y_train: [B, N], Train class labels. 85 | Returns: 86 | protos: [B, K, D], Test prediction. 87 | """ 88 | protos = [None] * nclasses 89 | for kk in range(nclasses): 90 | # [B, N, 1] 91 | ksel = tf.expand_dims(tf.cast(tf.equal(y_train, kk), h_train.dtype), 2) 92 | # [B, N, D] 93 | protos[kk] = tf.reduce_sum(h_train * ksel, [1], keep_dims=True) 94 | protos[kk] /= tf.reduce_sum(ksel, [1, 2], keep_dims=True) 95 | protos[kk] = debug_identity(protos[kk], "proto") 96 | protos = concat(protos, 1) # [B, K, D] 97 | return protos 98 | 99 | def predict(self): 100 | """See `model.py` for documentation.""" 101 | h_train, h_test = self.get_encoded_inputs(self.x_train, self.x_test) 102 | y_train = self.y_train 103 | nclasses = self.nway 104 | protos = self._compute_protos(nclasses, h_train, y_train) 105 | logits = compute_logits(protos, h_test) 106 | return [logits] 107 | 108 | def get_train_op(self, logits, y_test): 109 | """See `model.py` for documentation.""" 110 | if FLAGS.allstep: 111 | log.info("Compute average loss for all timestep.") 112 | if self.nway > 1: 113 | loss = tf.add_n([ 114 | tf.nn.sparse_softmax_cross_entropy_with_logits( 115 | logits=ll, 
labels=y_test) for ll in logits 116 | ]) / float(len(logits)) 117 | else: 118 | loss = tf.add_n([ 119 | tf.nn.sigmoid_cross_entropy_with_logits(logits=ll, labels=y_test) 120 | for ll in logits 121 | ]) / float(len(logits)) 122 | else: 123 | log.info("Compute loss for the final timestep.") 124 | if self.nway > 1: 125 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 126 | logits=logits[-1], labels=y_test) 127 | else: 128 | loss = tf.nn.sigmoid_cross_entropy_with_logits( 129 | logits=logits[-1], labels=y_test) 130 | loss = tf.reduce_mean(loss) 131 | wd_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 132 | log.info("Weight decay variables: {}".format(wd_losses)) 133 | if len(wd_losses) > 0: 134 | loss += tf.add_n(wd_losses) 135 | opt = tf.train.AdamOptimizer(self.learn_rate) 136 | grads_and_vars = opt.compute_gradients(loss) 137 | train_op = opt.apply_gradients(grads_and_vars) 138 | return loss, train_op 139 | -------------------------------------------------------------------------------- /fewshot/models/kmeans_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 
13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | # ============================================================================= 22 | from __future__ import division, print_function 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | from fewshot.models.nnlib import concat, round_st 28 | from fewshot.utils import logger 29 | 30 | log = logger.get() 31 | 32 | 33 | def compute_logits(cluster_centers, data): 34 | """Computes the logits of being in one cluster, squared Euclidean. 35 | Args: 36 | cluster_centers: [B, K, D] Cluster center representation. 37 | data: [B, N, D] Data representation. 38 | Returns: 39 | log_prob: [B, N, K] logits. 40 | """ 41 | cluster_centers = tf.expand_dims(cluster_centers, 1) # [B, 1, K, D] 42 | data = tf.expand_dims(data, 2) # [B, N, 1, D] 43 | # [B, N, K] 44 | neg_dist = -tf.reduce_sum(tf.square(data - cluster_centers), [-1]) 45 | return neg_dist 46 | 47 | 48 | def assign_cluster(cluster_centers, data): 49 | """Assigns data to cluster center, using K-Means. 50 | Args: 51 | cluster_centers: [B, K, D] Cluster center representation. 52 | data: [B, N, D] Data representation. 53 | Returns: 54 | prob: [B, N, K] Soft assignment. 55 | """ 56 | logits = compute_logits(cluster_centers, data) 57 | logits_shape = tf.shape(logits) 58 | bsize = logits_shape[0] 59 | ndata = logits_shape[1] 60 | ncluster = logits_shape[2] 61 | logits = tf.reshape(logits, [-1, ncluster]) 62 | prob = tf.nn.softmax(logits) # Use softmax distance. 
63 | prob = tf.reshape(prob, [bsize, ndata, ncluster]) 64 | return prob 65 | 66 | 67 | def update_cluster(data, prob, fix_last_row=False): 68 | """Updates cluster center based on assignment, standard K-Means. 69 | Args: 70 | data: [B, N, D]. Data representation. 71 | prob: [B, N, K]. Cluster assignment soft probability. 72 | fix_last_row: Bool. Whether or not to fix the last row to 0. 73 | Returns: 74 | cluster_centers: [B, K, D]. Cluster center representation. 75 | """ 76 | # Normalize accross N. 77 | if fix_last_row: 78 | prob_ = prob[:, :, :-1] 79 | else: 80 | prob_ = prob 81 | prob_sum = tf.reduce_sum(prob_, [1], keep_dims=True) 82 | prob_sum += tf.to_float(tf.equal(prob_sum, 0.0)) 83 | prob2 = prob_ / prob_sum 84 | cluster_centers = tf.reduce_sum( 85 | tf.expand_dims(data, 2) * tf.expand_dims(prob2, 3), [1]) 86 | if fix_last_row: 87 | cluster_centers = concat( 88 | [cluster_centers, 89 | tf.zeros_like(cluster_centers[:, 0:1, :])], 1) 90 | return cluster_centers 91 | 92 | 93 | def assign_cluster_radii(cluster_centers, data, radii): 94 | """Assigns data to cluster center, using K-Means. 95 | 96 | Args: 97 | cluster_centers: [B, K, D] Cluster center representation. 98 | data: [B, N, D] Data representation. 99 | radii: [B, K] Cluster radii. 100 | Returns: 101 | prob: [B, N, K] Soft assignment. 102 | """ 103 | logits = compute_logits_radii(cluster_centers, data, radii) 104 | logits_shape = tf.shape(logits) 105 | bsize = logits_shape[0] 106 | ndata = logits_shape[1] 107 | ncluster = logits_shape[2] 108 | logits = tf.reshape(logits, [-1, ncluster]) 109 | prob = tf.nn.softmax(logits) 110 | prob = tf.reshape(prob, [bsize, ndata, ncluster]) 111 | return prob 112 | 113 | 114 | def compute_logits_radii(cluster_centers, data, radii): 115 | """Computes the logits of being in one cluster, squared Euclidean. 116 | 117 | Args: 118 | cluster_centers: [B, K, D] Cluster center representation. 119 | data: [B, N, D] Data representation. 120 | radii: [B, K] Cluster radii. 
121 | Returns: 122 | log_prob: [B, N, K] logits. 123 | """ 124 | cluster_centers = tf.expand_dims(cluster_centers, 1) # [B, 1, K, D] 125 | data = tf.expand_dims(data, 2) # [B, N, 1, D] 126 | radii = tf.expand_dims(radii, 1) # [B, 1, K] 127 | # [B, N, K] 128 | neg_dist = -tf.reduce_sum(tf.square(data - cluster_centers), [-1]) 129 | logits = neg_dist / 2.0 / (radii**2) 130 | norm_constant = 0.5 * tf.log(2 * np.pi) + tf.log(radii) 131 | logits -= norm_constant 132 | return logits 133 | 134 | 135 | def assign_cluster_soft_mask(cluster_centers, data, mask): 136 | """Assigns data to cluster center, using K-Means. 137 | Args: 138 | cluster_centers: [B, K, D] Cluster center representation. 139 | data: [B, N, D] Data representation. 140 | mask: [B, N, K] Mask for each cluster. 141 | Returns: 142 | prob: [B, N, K] Soft assignment. 143 | """ 144 | logits = compute_logits(cluster_centers, data) 145 | logits_shape = tf.shape(logits) 146 | bsize = logits_shape[0] 147 | ndata = logits_shape[1] 148 | ncluster = logits_shape[2] 149 | logits = tf.reshape(logits, [-1, ncluster]) 150 | prob = tf.nn.softmax(logits) # Use softmax distance. 
151 | prob = tf.reshape(prob, [bsize, ndata, ncluster]) * mask 152 | return prob, mask 153 | -------------------------------------------------------------------------------- /fewshot/data/concurrent_batch_iter.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, 2 | unicode_literals) 3 | import six 4 | import sys 5 | 6 | is_py2 = sys.version[0] == "2" 7 | if is_py2: 8 | import Queue as queue 9 | else: 10 | import queue as queue 11 | import threading 12 | 13 | from fewshot.data.batch_iter import IBatchIterator, BatchIterator 14 | from fewshot.utils import logger 15 | 16 | 17 | class BatchProducer(threading.Thread): 18 | 19 | def __init__(self, q, batch_iter): 20 | super(BatchProducer, self).__init__() 21 | threading.Thread.__init__(self) 22 | self.q = q 23 | self.batch_iter = batch_iter 24 | self.log = logger.get() 25 | self._stoper = threading.Event() 26 | self.daemon = True 27 | 28 | def stop(self): 29 | self._stoper.set() 30 | 31 | def stopped(self): 32 | return self._stoper.isSet() 33 | 34 | def run(self): 35 | while not self.stopped(): 36 | try: 37 | b = self.batch_iter.next() 38 | self.q.put(b) 39 | except StopIteration: 40 | self.q.put(None) 41 | break 42 | 43 | 44 | class BatchConsumer(threading.Thread): 45 | 46 | def __init__(self, q): 47 | super(BatchConsumer, self).__init__() 48 | self.q = q 49 | self.daemon = True 50 | self._stoper = threading.Event() 51 | 52 | def stop(self): 53 | self._stoper.set() 54 | 55 | def stopped(self): 56 | return self._stoper.isSet() 57 | 58 | def run(self): 59 | while not self.stopped(): 60 | try: 61 | self.q.get(False) 62 | self.q.task_done() 63 | except queue.Empty: 64 | pass 65 | 66 | 67 | class ConcurrentBatchIterator(IBatchIterator): 68 | 69 | def __init__(self, 70 | batch_iter, 71 | max_queue_size=10, 72 | num_threads=5, 73 | log_queue=20, 74 | name=None): 75 | """ 76 | Data provider wrapper that supports concurrent 
data fetching. 77 | """ 78 | super(ConcurrentBatchIterator, self).__init__() 79 | self.max_queue_size = max_queue_size 80 | self.num_threads = num_threads 81 | self.q = queue.Queue(maxsize=max_queue_size) 82 | self.log = logger.get() 83 | self.batch_iter = batch_iter 84 | self.fetchers = [] 85 | self.init_fetchers() 86 | self.counter = 0 87 | self.relaunch = True 88 | self._stopped = False 89 | self.log_queue = log_queue 90 | self.name = name 91 | 92 | def __len__(self): 93 | return len(self.batch_iter) 94 | 95 | def init_fetchers(self): 96 | for ii in six.moves.xrange(self.num_threads): 97 | f = BatchProducer(self.q, self.batch_iter) 98 | f.start() 99 | self.fetchers.append(f) 100 | 101 | def get_name(self): 102 | if self.name is not None: 103 | return "Queue \"{}\":".format(self.name) 104 | else: 105 | return "" 106 | 107 | def info(self, message): 108 | self.log.info("{} {}".format(self.get_name(), message), verbose=2) 109 | 110 | def warning(self, message): 111 | self.log.warning("{} {}".format(self.get_name(), message)) 112 | 113 | def scan(self, do_print=False): 114 | dead = [] 115 | num_alive = 0 116 | for ff in self.fetchers: 117 | if not ff.is_alive(): 118 | dead.append(ff) 119 | self.info("Found one dead thread.") 120 | if self.relaunch: 121 | self.info("Relaunch") 122 | fnew = BatchProducer(self.q, self.batch_iter) 123 | fnew.start() 124 | self.fetchers.append(fnew) 125 | else: 126 | num_alive += 1 127 | if do_print: 128 | self.info("Number of alive threads: {}".format(num_alive)) 129 | s = self.q.qsize() 130 | if s > self.max_queue_size / 3: 131 | self.info("Data queue size: {}".format(s)) 132 | else: 133 | self.warning("Data queue size: {}".format(s)) 134 | for dd in dead: 135 | self.fetchers.remove(dd) 136 | 137 | def next(self): 138 | if self._stopped: 139 | raise StopIteration 140 | self.scan(do_print=(self.counter % self.log_queue == 0)) 141 | if self.counter % self.log_queue == 0: 142 | self.counter = 0 143 | batch = self.q.get() 144 | 
self.q.task_done() 145 | self.counter += 1 146 | while batch is None: 147 | self.info("Got an empty batch. Ending iteration.") 148 | self.relaunch = False 149 | try: 150 | batch = self.q.get(False) 151 | self.q.task_done() 152 | qempty = False 153 | except queue.Empty: 154 | qempty = True 155 | pass 156 | 157 | if qempty: 158 | self.info("Queue empty. Scanning for alive thread.") 159 | # Scan for alive thread. 160 | found_alive = False 161 | for ff in self.fetchers: 162 | if ff.is_alive(): 163 | found_alive = True 164 | break 165 | 166 | self.info("No alive thread found. Joining.") 167 | # If no alive thread, join all. 168 | if not found_alive: 169 | for ff in self.fetchers: 170 | ff.join() 171 | self._stopped = True 172 | raise StopIteration 173 | else: 174 | self.info("Got another batch from the queue.") 175 | return batch 176 | 177 | def reset(self): 178 | self.info("Resetting concurrent batch iter") 179 | self.info("Stopping all workers") 180 | for f in self.fetchers: 181 | f.stop() 182 | self.info("Cleaning queue") 183 | cleaner = BatchConsumer(self.q) 184 | cleaner.start() 185 | for f in self.fetchers: 186 | f.join() 187 | self.q.join() 188 | cleaner.stop() 189 | self.info("Resetting index") 190 | self.batch_iter.reset() 191 | self.info("Restarting workers") 192 | self.fetchers = [] 193 | self.init_fetchers() 194 | self.relaunch = True 195 | self._stopped = False 196 | 197 | 198 | if __name__ == "__main__": 199 | from batch_iter import BatchIterator 200 | b = BatchIterator(100, batch_size=6, get_fn=None) 201 | cb = ConcurrentBatchIterator(b, max_queue_size=5, num_threads=3) 202 | for _batch in cb: 203 | log = logger.get() 204 | log.info(("Final out", _batch)) 205 | -------------------------------------------------------------------------------- /fewshot/configs/omniglot_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin 
Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
21 | # ============================================================================= 22 | from fewshot.configs.config_factory import RegisterConfig 23 | 24 | 25 | @RegisterConfig("omniglot", "basic") 26 | class BasicConfig(object): 27 | """Standard CNN on Omniglot with prototypical layer.""" 28 | 29 | def __init__(self): 30 | self.name = "omniglot_basic" 31 | self.model_class = "basic" 32 | self.height = 28 33 | self.width = 28 34 | self.num_channel = 1 35 | self.steps_per_valid = 1000 36 | self.steps_per_log = 100 37 | self.steps_per_save = 1000 38 | self.filter_size = [[3, 3, 1, 64]] + [[3, 3, 64, 64]] * 3 39 | self.strides = [[1, 1, 1, 1]] * 4 40 | self.pool_fn = ["max_pool"] * 4 41 | self.pool_size = [[1, 2, 2, 1]] * 4 42 | self.pool_strides = [[1, 2, 2, 1]] * 4 43 | self.conv_act_fn = ["relu"] * 4 44 | self.conv_init_method = None 45 | self.conv_init_std = [1.0e-2] * 4 46 | self.wd = 5e-5 47 | self.learn_rate = 1e-3 48 | self.normalization = "batch_norm" 49 | self.lr_scheduler = "fixed" 50 | self.lr_decay_steps = [4000, 6000, 8000, 12000, 14000, 16000, 18000, 20000] 51 | self.max_train_steps = 20000 52 | self.lr_list = list( 53 | map(lambda x: self.learn_rate * (0.5)**x, 54 | range(len(self.lr_decay_steps)))) # NOTE(review): range starts at 0, so lr_list[0] == learn_rate, unlike the range(1, ...) in the configs below; lr_scheduler is "fixed" here, so confirm whether lr_list is read at all. 55 | self.similarity = "euclidean" 56 | 57 | 58 | @RegisterConfig("omniglot", "basic-pretrain") # NOTE: despite the "Test" class name, this is the registered pretraining config (4000 steps). 59 | class BasicTestConfig(BasicConfig): 60 | 61 | def __init__(self): 62 | super(BasicTestConfig, self).__init__() 63 | self.lr_decay_steps = [2000, 2500, 3000, 3500] 64 | self.lr_list = list( 65 | map(lambda x: self.learn_rate * (0.5)**x, 66 | range(1, len(self.lr_decay_steps) + 1))) 67 | self.max_train_steps = 4000 68 | 69 | 70 | @RegisterConfig("omniglot", "basic-test") # NOTE: despite the "Pretrain" class name, this is the registered short test config (100 steps). 71 | class BasicPretrainConfig(BasicConfig): 72 | 73 | def __init__(self): 74 | super(BasicPretrainConfig, self).__init__() 75 | self.lr_decay_steps = [30, 60, 90] 76 | self.lr_list = list( 77 | map(lambda x: self.learn_rate * (0.5)**x, 78 | range(1, len(self.lr_decay_steps) + 1))) 79 |
self.max_train_steps = 100 80 | self.steps_per_valid = 10 81 | self.steps_per_log = 10 82 | self.steps_per_save = 10 83 | 84 | 85 | @RegisterConfig("omniglot", "kmeans-refine") 86 | class KMeansRefineConfig(BasicConfig): 87 | 88 | def __init__(self): 89 | super(KMeansRefineConfig, self).__init__() 90 | self.name = "omniglot_kmeans-refine" 91 | self.model_class = "kmeans-refine" 92 | self.num_cluster_steps = 1 93 | 94 | 95 | @RegisterConfig("omniglot", "kmeans-refine-test") 96 | class KMeansRefineTestConfig(KMeansRefineConfig): 97 | 98 | def __init__(self): 99 | super(KMeansRefineTestConfig, self).__init__() 100 | self.lr_decay_steps = [30, 60, 90] 101 | self.lr_list = list( 102 | map(lambda x: self.learn_rate * (0.5)**x, 103 | range(1, len(self.lr_decay_steps) + 1))) 104 | self.max_train_steps = 100 105 | self.steps_per_valid = 10 106 | self.steps_per_log = 10 107 | self.steps_per_save = 10 108 | 109 | 110 | @RegisterConfig("omniglot", "kmeans-refine-radius") 111 | class KMeansRefineRadiusConfig(BasicConfig): 112 | 113 | def __init__(self): 114 | super(KMeansRefineRadiusConfig, self).__init__() 115 | self.name = "omniglot_kmeans-refine-radius" 116 | self.model_class = "kmeans-refine-radius" 117 | self.num_cluster_steps = 1 118 | 119 | 120 | @RegisterConfig("omniglot", "kmeans-refine-radius-test") 121 | class KMeansRefineRadiusTestConfig(KMeansRefineRadiusConfig): 122 | 123 | def __init__(self): 124 | super(KMeansRefineRadiusTestConfig, self).__init__() 125 | self.lr_decay_steps = [30, 60, 90] 126 | self.lr_list = list( 127 | map(lambda x: self.learn_rate * (0.5)**x, 128 | range(1, len(self.lr_decay_steps) + 1))) 129 | self.max_train_steps = 100 130 | self.steps_per_valid = 10 131 | self.steps_per_log = 10 132 | self.steps_per_save = 10 133 | 134 | 135 | @RegisterConfig("omniglot", "kmeans-refine-mask") 136 | class KMeansRefineMaskConfig(BasicConfig): 137 | 138 | def __init__(self): 139 | super(KMeansRefineMaskConfig, self).__init__() 140 | self.name =
"omniglot_kmeans-refine-mask" 141 | self.model_class = "kmeans-refine-mask" 142 | self.num_cluster_steps = 1 143 | 144 | 145 | @RegisterConfig("omniglot", "kmeans-refine-mask-test") 146 | class KMeansRefineMaskTestConfig(KMeansRefineMaskConfig): 147 | 148 | def __init__(self): 149 | super(KMeansRefineMaskTestConfig, self).__init__() 150 | self.lr_decay_steps = [30, 60, 90] 151 | self.lr_list = list( 152 | map(lambda x: self.learn_rate * (0.5)**x, 153 | range(1, len(self.lr_decay_steps) + 1))) 154 | self.max_train_steps = 100 155 | self.steps_per_valid = 10 156 | self.steps_per_log = 10 157 | self.steps_per_save = 10 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # few-shot-ssl-public 2 | Code for paper 3 | *Meta-Learning for Semi-Supervised Few-Shot Classification.* [[arxiv](https://arxiv.org/abs/1803.00676)] 4 | 5 | ## Dependencies 6 | * cv2 7 | * numpy 8 | * pandas 9 | * python 2.7 / 3.5+ 10 | * tensorflow 1.3+ 11 | * tqdm 12 | 13 | Our code is tested on Ubuntu 14.04 and 16.04. 14 | 15 | ## Setup 16 | First, designate a folder to be your data root: 17 | ``` 18 | export DATA_ROOT={DATA_ROOT} 19 | ``` 20 | 21 | Then, set up the datasets following the instructions in the subsections. 22 | 23 | ### Omniglot 24 | [[Google Drive](https://drive.google.com/open?id=1INlOTyPtnCJgm0hBVvtRLu5a0itk8bjs)] (9.3 MB) 25 | ``` 26 | # Download and place "omniglot.tar.gz" in "$DATA_ROOT/omniglot". 27 | mkdir -p $DATA_ROOT/omniglot 28 | cd $DATA_ROOT/omniglot 29 | mv ~/Downloads/omniglot.tar.gz .
30 | tar -xzvf omniglot.tar.gz 31 | rm -f omniglot.tar.gz 32 | ``` 33 | 34 | ### miniImageNet 35 | [[Google Drive](https://drive.google.com/open?id=16V_ZlkW4SsnNDtnGmaBRq2OoPmUOc5mY)] (1.1 GB) 36 | 37 | Update: Python 2 and 3 compatible version: 38 | [[train](https://drive.google.com/file/d/1I3itTXpXxGV68olxM5roceUMG8itH9Xj)] 39 | [[val](https://drive.google.com/file/d/1KY5e491bkLFqJDp0-UWou3463Mo8AOco)] 40 | [[test](https://drive.google.com/file/d/1wpmY-hmiJUUlRBkO9ZDCXAcIpHEFdOhD)] 41 | ``` 42 | # Download and place "mini-imagenet.tar.gz" in "$DATA_ROOT/mini-imagenet". 43 | mkdir -p $DATA_ROOT/mini-imagenet 44 | cd $DATA_ROOT/mini-imagenet 45 | mv ~/Downloads/mini-imagenet.tar.gz . 46 | tar -xzvf mini-imagenet.tar.gz 47 | rm -f mini-imagenet.tar.gz 48 | ``` 49 | 50 | ### tieredImageNet 51 | [[Google Drive](https://drive.google.com/open?id=1g1aIDy2Ar_MViF2gDXFYDBTR-HYecV07)] (12.9 GB) 52 | ``` 53 | # Download and place "tiered-imagenet.tar" in "$DATA_ROOT/tiered-imagenet". 54 | mkdir -p $DATA_ROOT/tiered-imagenet 55 | cd $DATA_ROOT/tiered-imagenet 56 | mv ~/Downloads/tiered-imagenet.tar . 57 | tar -xvf tiered-imagenet.tar 58 | rm -f tiered-imagenet.tar 59 | ``` 60 | Note: Please make sure that the following hardware requirements are met before running 61 | tieredImageNet experiments. 62 | * Disk: **30 GB** 63 | * RAM: **32 GB** 64 | 65 | 66 | ## Core Experiments 67 | Please run the following scripts to reproduce the core experiments. 68 | ``` 69 | # Clone the repository. 70 | git clone https://github.com/renmengye/few-shot-ssl-public.git 71 | cd few-shot-ssl-public 72 | 73 | # To train a model. 74 | python run_exp.py --data_root $DATA_ROOT \ 75 | --dataset {DATASET} \ 76 | --label_ratio {LABEL_RATIO} \ 77 | --model {MODEL} \ 78 | --results {SAVE_CKPT_FOLDER} \ 79 | [--disable_distractor] 80 | 81 | # To test a model. 
82 | python run_exp.py --data_root $DATA_ROOT \ 83 | --dataset {DATASET} \ 84 | --label_ratio {LABEL_RATIO} \ 85 | --model {MODEL} \ 86 | --results {SAVE_CKPT_FOLDER} \ 87 | --eval --pretrain {MODEL_ID} \ 88 | [--num_unlabel {NUM_UNLABEL}] \ 89 | [--num_test {NUM_TEST}] \ 90 | [--disable_distractor] \ 91 | [--use_test] 92 | ``` 93 | * Possible `{MODEL}` options are `basic`, `kmeans-refine`, `kmeans-refine-radius`, and `kmeans-refine-mask`. 94 | * Possible `{DATASET}` options are `omniglot`, `mini-imagenet`, `tiered-imagenet`. 95 | * Use `{LABEL_RATIO}` 0.1 for `omniglot` and `tiered-imagenet`, and 0.4 for `mini-imagenet`. 96 | * Replace `{MODEL_ID}` with the model ID obtained from the training program. 97 | * Replace `{SAVE_CKPT_FOLDER}` with the folder where you save your checkpoints. 98 | * Add additional flags `--num_unlabel 20 --num_test 20` for testing `mini-imagenet` and `tiered-imagenet` models, so that each episode contains 20 unlabeled images per class and 20 query images per class. 99 | * Add an additional flag `--disable_distractor` to remove all distractor classes in the unlabeled images. 100 | * Add an additional flag `--use_test` to evaluate on the test set instead of the validation set. 101 | * For more command-line details, see `run_exp.py`. 102 | 103 | ## Simple Baselines for Few-Shot Classification 104 | Please run the following script to reproduce a suite of baseline results. 105 | ``` 106 | python run_baseline_exp.py --data_root $DATA_ROOT \ 107 | --dataset {DATASET} 108 | ``` 109 | * Possible `DATASET` options are `omniglot`, `mini-imagenet`, `tiered-imagenet`. 110 | 111 | ## Run over Multiple Random Splits 112 | Please run the following script to reproduce results over 10 random label/unlabel splits, and test 113 | the model with different numbers of unlabeled items per episode. The default seeds are 0, 1001, ..., 114 | 9009.
115 | ``` 116 | python run_multi_exp.py --data_root $DATA_ROOT \ 117 | --dataset {DATASET} \ 118 | --label_ratio {LABEL_RATIO} \ 119 | --model {MODEL} \ 120 | [--disable_distractor] \ 121 | [--use_test] 122 | ``` 123 | * Possible `MODEL` options are `basic`, `kmeans-refine`, `kmeans-refine-radius`, and `kmeans-refine-mask`. 124 | * Possible `DATASET` options are `omniglot`, `mini_imagenet`, `tiered_imagenet`. 125 | * Use `{LABEL_RATIO}` 0.1 for `omniglot` and `tiered-imagenet`, and 0.4 for `mini-imagenet`. 126 | * Add an additional flag `--disable_distractor` to remove all distractor classes in the unlabeled images. 127 | * Add an additional flag `--use_test` to evaluate on the test set instead of the validation set. 128 | 129 | ## Citation 130 | If you use our code, please consider citing the following: 131 | * Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle and Richard S. Zemel. 132 | Meta-Learning for Semi-Supervised Few-Shot Classification. 133 | In *Proceedings of 6th International Conference on Learning Representations (ICLR)*, 2018. 134 | 135 | ``` 136 | @inproceedings{ren18fewshotssl, 137 | author = {Mengye Ren and 138 | Eleni Triantafillou and 139 | Sachin Ravi and 140 | Jake Snell and 141 | Kevin Swersky and 142 | Joshua B. Tenenbaum and 143 | Hugo Larochelle and 144 | Richard S. Zemel}, 145 | title = {Meta-Learning for Semi-Supervised Few-Shot Classification}, 146 | booktitle= {Proceedings of 6th International Conference on Learning Representations {ICLR}}, 147 | year = {2018}, 148 | } 149 | ``` 150 | -------------------------------------------------------------------------------- /fewshot/utils/batch_iter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel.
3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | # ============================================================================= 22 | """ 23 | A batch iterator. 
24 | 25 | Usage: 26 | for idx in BatchIterator(num=1000, batch_size=25): 27 | inp_batch = inp_all[idx] 28 | labels_batch = labels_all[idx] 29 | train(inp_batch, labels_batch) 30 | """ 31 | from __future__ import division, print_function 32 | 33 | import numpy as np 34 | import threading 35 | from fewshot.utils import logger 36 | 37 | 38 | class IBatchIterator(object): 39 | 40 | def __iter__(self): 41 | """Get iterable.""" 42 | return self 43 | 44 | def next(self): 45 | raise Exception("Not implemented") 46 | 47 | def reset(self): 48 | raise Exception("Not implemented") 49 | 50 | pass 51 | 52 | 53 | class BatchIterator(IBatchIterator): 54 | 55 | def __init__(self, 56 | num, 57 | batch_size=1, 58 | progress_bar=False, 59 | log_epoch=10, 60 | get_fn=None, 61 | cycle=False, 62 | shuffle=True, 63 | stagnant=False, 64 | seed=2, 65 | num_batches=-1): 66 | """Construct a batch iterator. 67 | 68 | Args: 69 | data: numpy.ndarray, (N, D), N is the number of examples, D is the 70 | feature dimension. 71 | labels: numpy.ndarray, (N), N is the number of examples. 72 | batch_size: int, batch size. 
73 | """ 74 | 75 | self._num = num 76 | self._batch_size = batch_size 77 | self._step = 0 78 | self._num_steps = int(np.ceil(self._num / float(batch_size))) 79 | if num_batches > 0: 80 | self._num_steps = min(self._num_steps, num_batches) 81 | self._pb = None 82 | self._variables = None 83 | self._get_fn = get_fn 84 | self.get_fn = get_fn 85 | self._cycle = cycle 86 | self._shuffle_idx = np.arange(self._num) 87 | self._shuffle = shuffle 88 | self._random = np.random.RandomState(seed) 89 | if shuffle: 90 | self._random.shuffle(self._shuffle_idx) 91 | self._shuffle_flag = False 92 | self._stagnant = stagnant 93 | self._log_epoch = log_epoch 94 | self._log = logger.get() 95 | self._epoch = 0 96 | if progress_bar: 97 | self._pb = pb.get(self._num_steps) 98 | pass 99 | self._mutex = threading.Lock() 100 | pass 101 | 102 | def __iter__(self): 103 | """Get iterable.""" 104 | return self 105 | 106 | def __len__(self): 107 | """Get iterable length.""" 108 | return self._num_steps 109 | 110 | @property 111 | def variables(self): 112 | return self._variables 113 | 114 | def set_variables(self, variables): 115 | self._variables = variables 116 | 117 | def get_fn(idx): 118 | return self._get_fn(idx, variables=variables) 119 | 120 | self.get_fn = get_fn 121 | return self 122 | 123 | def reset(self): 124 | self._step = 0 125 | 126 | def print_progress(self): 127 | e = self._epoch 128 | a = (self._step * self._batch_size) % self._num 129 | b = self._num 130 | p = a / b * 100 131 | digit = int(np.ceil(np.log10(b))) 132 | progress_str = "{:" + str(digit) + "d}" 133 | progress_str = (progress_str + "/" + progress_str).format(int(a), int(b)) 134 | self._log.info( 135 | "Epoch {:3d} Progress {} ({:5.2f}%)".format(e, progress_str, p)) 136 | pass 137 | 138 | def next(self): 139 | """Iterate next element.""" 140 | self._mutex.acquire() 141 | try: 142 | # Shuffle data. 
143 | if self._shuffle_flag: 144 | self._random.shuffle(self._shuffle_idx) 145 | self._shuffle_flag = False 146 | 147 | # Read/write of self._step stay in a thread-safe block. 148 | if not self._cycle: 149 | if self._step >= self._num_steps: 150 | raise StopIteration() 151 | 152 | # Calc start/end based on current step. 153 | start = self._batch_size * self._step 154 | end = self._batch_size * (self._step + 1) 155 | 156 | # Progress bar. 157 | if self._pb is not None: 158 | self._pb.increment() 159 | 160 | # Epoch record. 161 | if self._cycle: 162 | if int(end / self._num) > int(start / self._num): 163 | self._epoch += 1 164 | 165 | # Increment step. 166 | if not self._stagnant: 167 | self._step += 1 168 | 169 | # Print progress 170 | if self._log_epoch > 0 and self._step % self._log_epoch == 0: 171 | self.print_progress() 172 | finally: 173 | self._mutex.release() 174 | 175 | if not self._cycle: 176 | end = min(self._num, end) 177 | idx = np.arange(start, end) 178 | idx = idx.astype("int") 179 | if self.get_fn is not None: 180 | return self.get_fn(idx) 181 | else: 182 | return idx 183 | else: 184 | start = start % self._num 185 | end = end % self._num 186 | if end > start: 187 | idx = np.arange(start, end) 188 | idx = idx.astype("int") 189 | idx = self._shuffle_idx[idx] 190 | else: 191 | idx = np.concatenate([np.arange(start, self._num), np.arange(0, end)]) 192 | idx = idx.astype("int") 193 | idx = self._shuffle_idx[idx] 194 | # Shuffle every cycle. 
195 | if self._shuffle: 196 | self._shuffle_flag = True 197 | if self.get_fn is not None: 198 | return self.get_fn(idx) 199 | else: 200 | return idx 201 | pass 202 | 203 | 204 | if __name__ == "__main__": 205 | b = BatchIterator( 206 | 400, 207 | batch_size=32, 208 | progress_bar=False, 209 | get_fn=lambda x: x, 210 | cycle=False, 211 | shuffle=False) 212 | for ii in b: 213 | print(ii) 214 | b.reset() 215 | for ii in b: 216 | print(ii) 217 | -------------------------------------------------------------------------------- /fewshot/models/kmeans_refine_mask_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
21 | # ============================================================================= 22 | """ Another KMeans based semisupervised model. Predict a mask based on the neighbor distance 23 | distribution. 24 | 25 | Author: Mengye Ren (mren@cs.toronto.edu) 26 | """ 27 | from __future__ import (absolute_import, division, print_function, 28 | unicode_literals) 29 | 30 | import numpy as np 31 | import tensorflow as tf 32 | 33 | from fewshot.models.distractor_utils import eval_distractor 34 | from fewshot.models.kmeans_refine_model import KMeansRefineModel 35 | from fewshot.models.kmeans_utils import assign_cluster_soft_mask 36 | from fewshot.models.kmeans_utils import compute_logits 37 | from fewshot.models.kmeans_utils import update_cluster 38 | from fewshot.models.model_factory import RegisterModel 39 | from fewshot.models.nnlib import concat, mlp 40 | from fewshot.utils import logger 41 | 42 | log = logger.get() 43 | 44 | flags = tf.flags 45 | FLAGS = tf.flags.FLAGS 46 | 47 | 48 | @RegisterModel("kmeans-refine-mask") 49 | class KMeansRefineMaskModel(KMeansRefineModel): 50 | 51 | def predict(self): 52 | """Return a list of test logits: one from the initial prototypes, then one after each refinement step. See `model.py` for documentation.""" 53 | nclasses = self.nway 54 | num_cluster_steps = self.config.num_cluster_steps 55 | h_train, h_unlabel, h_test = self.get_encoded_inputs( 56 | self.x_train, self.x_unlabel, self.x_test) 57 | y_train = self.y_train 58 | protos = self._compute_protos(nclasses, h_train, y_train) 59 | logits_list = [] 60 | logits_list.append(compute_logits(protos, h_test)) 61 | 62 | # Hard assignment for training images. 63 | prob_train = [None] * (nclasses) 64 | for kk in range(nclasses): 65 | # [B, N, 1] 66 | prob_train[kk] = tf.expand_dims( 67 | tf.cast(tf.equal(y_train, kk), h_train.dtype), 2) 68 | prob_train = concat(prob_train, 2) 69 | 70 | y_train_shape = tf.shape(y_train) 71 | bsize = y_train_shape[0] 72 | 73 | h_all = concat([h_train, h_unlabel], 1) 74 | mask = None 75 | 76 | # Calculate pairwise distances.
77 | protos_1 = tf.expand_dims(protos, 2) 78 | protos_2 = tf.expand_dims(h_unlabel, 1) 79 | pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3]) # [B, K, N] 80 | mean_dist = tf.reduce_mean(pair_dist, [2], keep_dims=True) 81 | pair_dist_normalize = pair_dist / mean_dist # NOTE(review): this divisor has no zero guard (the guards below only touch the moments() outputs); confirm this path is safe when there are no unlabeled items. 82 | min_dist = tf.reduce_min( 83 | pair_dist_normalize, [2], keep_dims=True) # [B, K, 1] 84 | max_dist = tf.reduce_max(pair_dist_normalize, [2], keep_dims=True) 85 | mean_dist, var_dist = tf.nn.moments( 86 | pair_dist_normalize, [2], keep_dims=True) 87 | mean_dist += tf.to_float(tf.equal(mean_dist, 0.0)) # NOTE(review): this mean is only subtracted below, never divided by, so this guard looks unnecessary. 88 | var_dist += tf.to_float(tf.equal(var_dist, 0.0)) # Avoid division by zero in the skew/kurt denominators below. 89 | skew = tf.reduce_mean( 90 | ((pair_dist_normalize - mean_dist)**3) / (tf.sqrt(var_dist)**3), [2], 91 | keep_dims=True) 92 | kurt = tf.reduce_mean( 93 | ((pair_dist_normalize - mean_dist)**4) / (var_dist**2) - 3, [2], 94 | keep_dims=True) 95 | 96 | n_features = 5 97 | n_out = 3 98 | 99 | dist_features = tf.reshape( 100 | concat([min_dist, max_dist, var_dist, skew, kurt], 2), 101 | [-1, n_features]) # [BK, 5] 102 | dist_features = tf.stop_gradient(dist_features) 103 | 104 | hdim = [n_features, 20, n_out] 105 | act_fn = [tf.nn.tanh, None] 106 | thresh = mlp( 107 | dist_features, 108 | hdim, 109 | is_training=True, 110 | act_fn=act_fn, 111 | dtype=tf.float32, 112 | add_bias=True, 113 | wd=None, 114 | init_std=[0.01, 0.01], 115 | init_method=None, 116 | scope="dist_mlp", 117 | dropout=None, 118 | trainable=True) 119 | scale = tf.exp(thresh[:, 2]) # NOTE(review): scale stays [B*K] and is not reshaped to [B, 1, K] like bias_start/bias_add; against the (presumably [B, N, K]) logits it only broadcasts correctly when B == 1. Confirm. 120 | bias_start = tf.exp(thresh[:, 0]) # MLP output columns: 0 -> bias_start, 1 -> bias_add, 2 -> scale. 121 | bias_add = thresh[:, 1] 122 | bias_start = tf.reshape(bias_start, [bsize, 1, -1]) #[B, 1, K] 123 | bias_add = tf.reshape(bias_add, [bsize, 1, -1]) 124 | 125 | self._scale = scale 126 | self._bias_start = bias_start 127 | self._bias_add = bias_add 128 | 129 | # Run clustering.
130 | for tt in range(num_cluster_steps): 131 | protos_1 = tf.expand_dims(protos, 2) 132 | protos_2 = tf.expand_dims(h_unlabel, 1) 133 | pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3]) # [B, K, N] 134 | m_dist = tf.reduce_mean(pair_dist, [2]) # [B, K] 135 | m_dist_1 = tf.expand_dims(m_dist, 1) # [B, 1, K] 136 | m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0)) 137 | # Label assignment. 138 | if num_cluster_steps > 1: 139 | bias_tt = bias_start + (tt / float(num_cluster_steps - 1)) * bias_add # Linear schedule: bias_start at tt = 0, bias_start + bias_add at the last step. 140 | else: 141 | bias_tt = bias_start 142 | 143 | negdist = compute_logits(protos, h_unlabel) 144 | mask = tf.sigmoid((negdist / m_dist_1 + bias_tt) * scale) 145 | prob_unlabel, mask = assign_cluster_soft_mask(protos, h_unlabel, mask) 146 | prob_all = concat([prob_train, prob_unlabel * mask], 1) # NOTE(review): if assign_cluster_soft_mask already multiplies its prob output by mask, the mask is applied twice here; confirm against kmeans_utils. 147 | # No update if 0 unlabel. 148 | protos = tf.cond( 149 | tf.shape(self._x_unlabel)[1] > 0, 150 | lambda: update_cluster(h_all, prob_all), lambda: protos) 151 | logits_list.append(compute_logits(protos, h_test)) 152 | 153 | # Distractor evaluation.
154 | if mask is not None: 155 | max_mask = tf.reduce_max(mask, [2]) 156 | mean_mask = tf.reduce_mean(max_mask) 157 | pred_non_distractor = tf.to_float(max_mask > mean_mask) # Items whose max mask exceeds the mean over all items are predicted non-distractors. 158 | acc, recall, precision = eval_distractor(pred_non_distractor, 159 | self.y_unlabel) 160 | self._non_distractor_acc = acc 161 | self._distractor_recall = recall 162 | self._distractor_precision = precision 163 | self._distractor_pred = max_mask 164 | return logits_list 165 | -------------------------------------------------------------------------------- /fewshot/data/tiered_imagenet_split/train.csv: -------------------------------------------------------------------------------- 1 | n01530575,n01524359 2 | n01531178,n01524359 3 | n01532829,n01524359 4 | n01534433,n01524359 5 | n01537544,n01524359 6 | n01558993,n01524359 7 | n01560419,n01524359 8 | n01580077,n01524359 9 | n01582220,n01524359 10 | n01592084,n01524359 11 | n01601694,n01524359 12 | n01675722,n01674216 13 | n01677366,n01674216 14 | n01682714,n01674216 15 | n01685808,n01674216 16 | n01687978,n01674216 17 | n01688243,n01674216 18 | n01689811,n01674216 19 | n01692333,n01674216 20 | n01693334,n01674216 21 | n01694178,n01674216 22 | n01695060,n01674216 23 | n01728572,n01726692 24 | n01728920,n01726692 25 | n01729322,n01726692 26 | n01729977,n01726692 27 | n01734418,n01726692 28 | n01735189,n01726692 29 | n01737021,n01726692 30 | n01739381,n01726692 31 | n01740131,n01726692 32 | n01742172,n01726692 33 | n01744401,n01726692 34 | n01748264,n01726692 35 | n01749939,n01726692 36 | n01751748,n01726692 37 | n01753488,n01726692 38 | n01755581,n01726692 39 | n01756291,n01726692 40 | n01847000,n01844917 41 | n01855032,n01844917 42 | n01855672,n01844917 43 | n01860187,n01844917 44 | n02002556,n01844917 45 | n02002724,n01844917 46 | n02006656,n01844917 47 | n02007558,n01844917 48 | n02009229,n01844917 49 | n02009912,n01844917 50 | n02011460,n01844917 51 | n02012849,n01844917 52 | n02013706,n01844917 53 | n02017213,n01844917 54 |
n02018207,n01844917 55 | n02018795,n01844917 56 | n02025239,n01844917 57 | n02027492,n01844917 58 | n02028035,n01844917 59 | n02033041,n01844917 60 | n02037110,n01844917 61 | n02051845,n01844917 62 | n02056570,n01844917 63 | n02058221,n01844917 64 | n02088094,n02087551 65 | n02088238,n02087551 66 | n02088364,n02087551 67 | n02088466,n02087551 68 | n02088632,n02087551 69 | n02089078,n02087551 70 | n02089867,n02087551 71 | n02089973,n02087551 72 | n02090379,n02087551 73 | n02090622,n02087551 74 | n02090721,n02087551 75 | n02091032,n02087551 76 | n02091134,n02087551 77 | n02091244,n02087551 78 | n02091467,n02087551 79 | n02091635,n02087551 80 | n02091831,n02087551 81 | n02092002,n02087551 82 | n02092339,n02087551 83 | n02093256,n02092468 84 | n02093428,n02092468 85 | n02093647,n02092468 86 | n02093754,n02092468 87 | n02093859,n02092468 88 | n02093991,n02092468 89 | n02094114,n02092468 90 | n02094258,n02092468 91 | n02094433,n02092468 92 | n02095314,n02092468 93 | n02095570,n02092468 94 | n02095889,n02092468 95 | n02096051,n02092468 96 | n02096177,n02092468 97 | n02096294,n02092468 98 | n02096437,n02092468 99 | n02096585,n02092468 100 | n02097047,n02092468 101 | n02097130,n02092468 102 | n02097209,n02092468 103 | n02097298,n02092468 104 | n02097474,n02092468 105 | n02097658,n02092468 106 | n02098105,n02092468 107 | n02098286,n02092468 108 | n02098413,n02092468 109 | n02123045,n02120997 110 | n02123159,n02120997 111 | n02123394,n02120997 112 | n02123597,n02120997 113 | n02124075,n02120997 114 | n02125311,n02120997 115 | n02127052,n02120997 116 | n02128385,n02120997 117 | n02128757,n02120997 118 | n02128925,n02120997 119 | n02129165,n02120997 120 | n02129604,n02120997 121 | n02130308,n02120997 122 | n02389026,n02370806 123 | n02391049,n02370806 124 | n02395406,n02370806 125 | n02396427,n02370806 126 | n02397096,n02370806 127 | n02398521,n02370806 128 | n02403003,n02370806 129 | n02408429,n02370806 130 | n02410509,n02370806 131 | n02412080,n02370806 132 | 
n02415577,n02370806 133 | n02417914,n02370806 134 | n02422106,n02370806 135 | n02422699,n02370806 136 | n02423022,n02370806 137 | n02437312,n02370806 138 | n02437616,n02370806 139 | n02480495,n02469914 140 | n02480855,n02469914 141 | n02481823,n02469914 142 | n02483362,n02469914 143 | n02483708,n02469914 144 | n02484975,n02469914 145 | n02486261,n02469914 146 | n02486410,n02469914 147 | n02487347,n02469914 148 | n02488291,n02469914 149 | n02488702,n02469914 150 | n02489166,n02469914 151 | n02490219,n02469914 152 | n02492035,n02469914 153 | n02492660,n02469914 154 | n02493509,n02469914 155 | n02493793,n02469914 156 | n02494079,n02469914 157 | n02497673,n02469914 158 | n02500267,n02469914 159 | n02727426,n02913152 160 | n02793495,n02913152 161 | n02859443,n02913152 162 | n03028079,n02913152 163 | n03032252,n02913152 164 | n03457902,n02913152 165 | n03529860,n02913152 166 | n03661043,n02913152 167 | n03781244,n02913152 168 | n03788195,n02913152 169 | n03877845,n02913152 170 | n03956157,n02913152 171 | n04081281,n02913152 172 | n04346328,n02913152 173 | n02687172,n03125870 174 | n02690373,n03125870 175 | n02692877,n03125870 176 | n02782093,n03125870 177 | n02951358,n03125870 178 | n02981792,n03125870 179 | n03095699,n03125870 180 | n03344393,n03125870 181 | n03447447,n03125870 182 | n03662601,n03125870 183 | n03673027,n03125870 184 | n03947888,n03125870 185 | n04147183,n03125870 186 | n04266014,n03125870 187 | n04273569,n03125870 188 | n04347754,n03125870 189 | n04483307,n03125870 190 | n04552348,n03125870 191 | n04606251,n03125870 192 | n04612504,n03125870 193 | n02979186,n03278248 194 | n02988304,n03278248 195 | n02992529,n03278248 196 | n03085013,n03278248 197 | n03187595,n03278248 198 | n03584254,n03278248 199 | n03777754,n03278248 200 | n03782006,n03278248 201 | n03857828,n03278248 202 | n03902125,n03278248 203 | n04392985,n03278248 204 | n02776631,n03297735 205 | n02791270,n03297735 206 | n02871525,n03297735 207 | n02927161,n03297735 208 | n03089624,n03297735 209 
| n03461385,n03297735 210 | n04005630,n03297735 211 | n04200800,n03297735 212 | n04443257,n03297735 213 | n04462240,n03297735 214 | n02799071,n03414162 215 | n02802426,n03414162 216 | n03134739,n03414162 217 | n03445777,n03414162 218 | n03598930,n03414162 219 | n03942813,n03414162 220 | n04023962,n03414162 221 | n04118538,n03414162 222 | n04254680,n03414162 223 | n04409515,n03414162 224 | n04540053,n03414162 225 | n06785654,n03414162 226 | n02667093,n03419014 227 | n02837789,n03419014 228 | n02865351,n03419014 229 | n02883205,n03419014 230 | n02892767,n03419014 231 | n02963159,n03419014 232 | n03188531,n03419014 233 | n03325584,n03419014 234 | n03404251,n03419014 235 | n03534580,n03419014 236 | n03594734,n03419014 237 | n03595614,n03419014 238 | n03617480,n03419014 239 | n03630383,n03419014 240 | n03710721,n03419014 241 | n03770439,n03419014 242 | n03866082,n03419014 243 | n03980874,n03419014 244 | n04136333,n03419014 245 | n04325704,n03419014 246 | n04350905,n03419014 247 | n04370456,n03419014 248 | n04371430,n03419014 249 | n04479046,n03419014 250 | n04591157,n03419014 251 | n02708093,n03574816 252 | n02749479,n03574816 253 | n02794156,n03574816 254 | n02841315,n03574816 255 | n02879718,n03574816 256 | n02950826,n03574816 257 | n03196217,n03574816 258 | n03197337,n03574816 259 | n03467068,n03574816 260 | n03544143,n03574816 261 | n03692522,n03574816 262 | n03706229,n03574816 263 | n03773504,n03574816 264 | n03841143,n03574816 265 | n03891332,n03574816 266 | n04008634,n03574816 267 | n04009552,n03574816 268 | n04044716,n03574816 269 | n04086273,n03574816 270 | n04090263,n03574816 271 | n04118776,n03574816 272 | n04141975,n03574816 273 | n04317175,n03574816 274 | n04328186,n03574816 275 | n04355338,n03574816 276 | n04356056,n03574816 277 | n04376876,n03574816 278 | n04548280,n03574816 279 | n02672831,n03800933 280 | n02676566,n03800933 281 | n02787622,n03800933 282 | n02804610,n03800933 283 | n02992211,n03800933 284 | n03017168,n03800933 285 | n03110669,n03800933 
286 | n03249569,n03800933 287 | n03272010,n03800933 288 | n03372029,n03800933 289 | n03394916,n03800933 290 | n03447721,n03800933 291 | n03452741,n03800933 292 | n03494278,n03800933 293 | n03495258,n03800933 294 | n03720891,n03800933 295 | n03721384,n03800933 296 | n03838899,n03800933 297 | n03840681,n03800933 298 | n03854065,n03800933 299 | n03884397,n03800933 300 | n04141076,n03800933 301 | n04311174,n03800933 302 | n04487394,n03800933 303 | n04515003,n03800933 304 | n04536866,n03800933 305 | n02825657,n04014297 306 | n02840245,n04014297 307 | n02843684,n04014297 308 | n02895154,n04014297 309 | n03000247,n04014297 310 | n03146219,n04014297 311 | n03220513,n04014297 312 | n03347037,n04014297 313 | n03424325,n04014297 314 | n03527444,n04014297 315 | n03637318,n04014297 316 | n03657121,n04014297 317 | n03788365,n04014297 318 | n03929855,n04014297 319 | n04141327,n04014297 320 | n04192698,n04014297 321 | n04229816,n04014297 322 | n04417672,n04014297 323 | n04423845,n04014297 324 | n04435653,n04014297 325 | n04507155,n04014297 326 | n04523525,n04014297 327 | n04589890,n04014297 328 | n04590129,n04014297 329 | n02910353,n04081844 330 | n03075370,n04081844 331 | n03208938,n04081844 332 | n03476684,n04081844 333 | n03627232,n04081844 334 | n03803284,n04081844 335 | n03804744,n04081844 336 | n03874599,n04081844 337 | n04127249,n04081844 338 | n04153751,n04081844 339 | n04162706,n04081844 340 | n02951585,n04451818 341 | n03041632,n04451818 342 | n03109150,n04451818 343 | n03481172,n04451818 344 | n03498962,n04451818 345 | n03649909,n04451818 346 | n03658185,n04451818 347 | n03954731,n04451818 348 | n03967562,n04451818 349 | n03970156,n04451818 350 | n04154565,n04451818 351 | n04208210,n04451818 352 | -------------------------------------------------------------------------------- /run_multi_exp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin 
Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | # ============================================================================= 22 | """Runs multiple experiments over random labeled/unlabeled splits of the dataset. 23 | Author: Mengye Ren (mren@cs.toronto.edu) 24 | 25 | Usage: 26 | 27 | Example: 28 | python run_multi_exp.py --data_root /data/ \ 29 | --dataset omniglot \ 30 | --label_ratio 0.1 \ 31 | --model basic 32 | 33 | 34 | Flags: Same set of flags as `run_exp.py`. 
35 | 36 | """ 37 | 38 | from __future__ import division, print_function 39 | 40 | import numpy as np 41 | import os 42 | import six 43 | import tensorflow as tf 44 | 45 | from collections import namedtuple 46 | 47 | flags = tf.flags 48 | flags.DEFINE_bool("eval_train", False, "Whether to evaluate training set") 49 | flags.DEFINE_string("num_unlabel_list", None, "Number of unlabel items") 50 | 51 | from fewshot.utils import logger 52 | from run_exp import _get_model 53 | from run_exp import evaluate 54 | from run_exp import get_config 55 | from run_exp import get_dataset 56 | from run_exp import train 57 | 58 | FLAGS = tf.flags.FLAGS 59 | 60 | if 'imagenet' in FLAGS.dataset: 61 | if FLAGS.num_unlabel_list is None: 62 | NUM_UNLABEL_LIST = '0,1,2,5,10,15,20,25' 63 | else: 64 | NUM_UNLABEL_LIST = FLAGS.num_unlabel_list 65 | NUM_RUN = 10 66 | else: 67 | if FLAGS.num_unlabel_list is None: 68 | NUM_UNLABEL_LIST = '0,1,2,5,10' 69 | else: 70 | NUM_UNLABEL_LIST = FLAGS.num_unlabel_list 71 | NUM_RUN = 10 72 | log = logger.get() 73 | 74 | 75 | def gen_id(config): 76 | import datetime 77 | dtstr = datetime.datetime.now().isoformat(chr(ord("-"))).replace(":", 78 | "-").replace( 79 | ".", "-") 80 | return "{}_{}".format(config.name, dtstr) 81 | 82 | 83 | def run_one(dataset, model, seed, pretrain_id, exp_id): 84 | log.info("Random seed = {}".format(seed)) 85 | config = get_config(dataset, model) 86 | nclasses_train = FLAGS.nclasses_train 87 | nclasses_eval = FLAGS.nclasses_eval 88 | train_split_name = 'train' 89 | 90 | if FLAGS.use_test: 91 | log.info('Using the test set') 92 | test_split_name = 'test' 93 | else: 94 | log.info('Not using the test set, using val') 95 | test_split_name = 'val' 96 | 97 | if dataset in ['mini-imagenet', 'tiered-imagenet']: 98 | _aug_90 = False 99 | num_test_test = 20 100 | num_train_test = 5 101 | else: 102 | _aug_90 = True 103 | num_test_test = -1 104 | num_train_test = -1 105 | 106 | meta_train_dataset = get_dataset( 107 | dataset, 108 | 
train_split_name, 109 | nclasses_train, 110 | FLAGS.nshot, 111 | num_test=num_train_test, 112 | aug_90=_aug_90, 113 | num_unlabel=FLAGS.num_unlabel, 114 | shuffle_episode=False, 115 | seed=seed) 116 | 117 | meta_test_dataset = get_dataset( 118 | dataset, 119 | test_split_name, 120 | nclasses_eval, 121 | FLAGS.nshot, 122 | num_test=num_test_test, 123 | aug_90=_aug_90, 124 | num_unlabel=5, 125 | shuffle_episode=False, 126 | label_ratio=1, 127 | seed=seed) 128 | 129 | sconfig = tf.ConfigProto() 130 | sconfig.gpu_options.allow_growth = True 131 | with tf.Session(config=sconfig) as sess: 132 | tf.set_random_seed(seed) 133 | with log.verbose_level(2): 134 | m, mvalid = _get_model(config, nclasses_train, nclasses_eval) 135 | if pretrain_id is not None: 136 | ckpt = tf.train.latest_checkpoint( 137 | os.path.join(FLAGS.results, pretrain_id)) 138 | saver = tf.train.Saver() 139 | saver.restore(sess, ckpt) 140 | else: 141 | sess.run(tf.global_variables_initializer()) 142 | 143 | if not FLAGS.eval: 144 | exp_id_ = exp_id + "-{:05d}".format(seed) 145 | train( 146 | sess, 147 | config, 148 | m, 149 | meta_train_dataset, 150 | mvalid, 151 | meta_test_dataset, 152 | log_results=False, 153 | run_eval=False, 154 | exp_id=exp_id_) 155 | else: 156 | exp_id_ = None 157 | 158 | if FLAGS.eval_train: 159 | train_results = evaluate(sess, mvalid, meta_train_dataset) 160 | log.info("Final train acc {:.3f}% ({:.3f}%)".format( 161 | train_results['acc'] * 100.0, train_results['acc_ci'] * 100.0)) 162 | else: 163 | train_results = None 164 | 165 | num_unlabel_list = [int(nn) for nn in NUM_UNLABEL_LIST.split(',')] 166 | test_results_list = [] 167 | for nn in num_unlabel_list: 168 | 169 | if dataset == 'mini-imagenet': 170 | AL_Instance = namedtuple( 171 | 'AL_Instance', 'n_class, n_distractor, k_train, k_test, k_unlbl') 172 | new_al_instance = AL_Instance( 173 | n_class=meta_test_dataset.al_instance.n_class, 174 | n_distractor=meta_test_dataset.al_instance.n_distractor, 175 | 
k_train=meta_test_dataset.al_instance.k_train, 176 | k_test=meta_test_dataset.al_instance.k_test, 177 | k_unlbl=nn) 178 | meta_test_dataset.al_instance = new_al_instance 179 | else: 180 | meta_test_dataset._num_unlabel = nn 181 | 182 | meta_test_dataset.reset() 183 | _test_results = evaluate(sess, mvalid, meta_test_dataset) 184 | test_results_list.append(_test_results) 185 | log.info("Final test acc {:.3f}% ({:.3f}%)".format( 186 | _test_results['acc'] * 100.0, _test_results['acc_ci'] * 100.0)) 187 | 188 | return train_results, test_results_list, exp_id_, num_unlabel_list 189 | 190 | 191 | def calc_avg(number): 192 | number_ = np.array(number) 193 | return np.mean(number_), np.std(number_) 194 | 195 | 196 | def collect(results): 197 | acc = [rr['acc'] for rr in results] 198 | return calc_avg(acc) 199 | 200 | 201 | def main(): 202 | rnd = np.random.RandomState(0) 203 | 204 | # Set up pretrain ID list. 205 | if FLAGS.pretrain is not None: 206 | num_runs = NUM_RUN 207 | pretrain_ids = [ 208 | FLAGS.pretrain + '-{:05d}'.format(1001 * ii) 209 | for ii in six.moves.xrange(num_runs) 210 | ] 211 | else: 212 | pretrain_ids = [None] * NUM_RUN 213 | num_runs = NUM_RUN 214 | 215 | all_train_results = [] 216 | all_test_results = [] 217 | exp_ids = [] 218 | seed_list = [] 219 | config = get_config(FLAGS.dataset, FLAGS.model) 220 | exp_id_root = gen_id(config) 221 | 222 | for ii, pid in enumerate(pretrain_ids): 223 | log.info("Run {} out of {}".format(ii + 1, NUM_RUN)) 224 | with tf.Graph().as_default(): 225 | _seed = 1001 * ii 226 | train_results, test_results_list, exp_id, num_unlabel_list = run_one( 227 | FLAGS.dataset, FLAGS.model, _seed, pid, exp_id_root) 228 | all_train_results.append(train_results) 229 | all_test_results.append(test_results_list) 230 | exp_ids.append(exp_id) 231 | seed_list.append(_seed) 232 | 233 | if FLAGS.eval_train: 234 | trn_acc = collect(all_train_results) 235 | log.info('Train Acc = {:.3f} ({:.3f})'.format(trn_acc[0] * 100.0, 236 | trn_acc[1] * 
100.0)) 237 | for ii in range(len(num_unlabel_list)): 238 | _all_test_results = [] 239 | for vr in all_test_results: 240 | _all_test_results.append(vr[ii]) 241 | _test_acc = collect(_all_test_results) 242 | log.info('Num Unlabel {}'.format(num_unlabel_list[ii])) 243 | log.info('Test Acc = {:.3f} ({:.3f})'.format(_test_acc[0] * 100.0, 244 | _test_acc[1] * 100.0)) 245 | log.info('Experiment ID:') 246 | for ee, seed in zip(exp_ids, seed_list): 247 | print(ee, seed) 248 | 249 | 250 | if __name__ == "__main__": 251 | main() 252 | -------------------------------------------------------------------------------- /fewshot/models/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | # ============================================================================= 22 | """Model abstract class. 23 | Author: Mengye Ren (mren@cs.toronto.edu) 24 | """ 25 | 26 | #TODO: complete this section. 27 | from __future__ import (absolute_import, division, print_function, 28 | unicode_literals) 29 | 30 | import numpy as np 31 | import tensorflow as tf 32 | from fewshot.models.nnlib import cnn, concat 33 | from fewshot.models.measure import batch_apk, apk 34 | from fewshot.utils import logger 35 | 36 | flags = tf.flags 37 | flags.DEFINE_bool("allstep", False, 38 | "Whether or not to average loss for all steps") 39 | FLAGS = tf.flags.FLAGS 40 | log = logger.get() 41 | 42 | 43 | class Model(object): 44 | """Sub-classes need to implement the following 45 | two methods: 46 | 1) get_inference: Given reference and candidate images, compute the logits 47 | of whether the candidate images are relevant. 48 | 2) get_train_op: Given the computed logits and groundtruth mask, compute 49 | the loss and op to optimize the network. 50 | """ 51 | 52 | def __init__(self, 53 | config, 54 | nway=2, 55 | nshot=1, 56 | num_test=30, 57 | is_training=True, 58 | dtype=tf.float32): 59 | """Builds model.""" 60 | self._config = config 61 | self._dtype = dtype 62 | self._nshot = nshot 63 | self._nway = nway 64 | self._num_test = num_test 65 | self._is_training = is_training 66 | 67 | height = config.height 68 | width = config.width 69 | channels = config.num_channel 70 | 71 | # Train images. 72 | self._x_train = tf.placeholder( 73 | dtype, [None, None, height, width, channels], name="x_train") 74 | 75 | # Test images. 
76 | self._x_test = tf.placeholder( 77 | dtype, [None, None, height, width, channels], name="x_test") 78 | 79 | self._y_train = tf.placeholder(tf.int64, [None, None], name="y_train") 80 | 81 | # Whether the candidate is relevant. 82 | self._y_test = tf.placeholder(tf.int64, [None, None], name="y_test") 83 | 84 | if self._nway > 1: 85 | self._y_train_one_hot = tf.one_hot(self._y_train, self._nway) 86 | self._y_test_one_hot = tf.one_hot(self._y_test, self._nway) 87 | 88 | # Learning rate. 89 | self._learn_rate = tf.get_variable( 90 | "learn_rate", shape=[], initializer=tf.constant_initializer(0.0)) 91 | self._new_lr = tf.placeholder(dtype, [], name="new_lr") 92 | self._assign_lr = tf.assign(self._learn_rate, self._new_lr) 93 | self._embedding_weights = None 94 | 95 | # Predition. 96 | self._logits = self.predict() 97 | 98 | # Output. 99 | self.compute_output() 100 | 101 | if is_training: 102 | self._loss, self._train_op = self.get_train_op(self.logits, self.y_test) 103 | 104 | def predict(self): 105 | """Build inference graph. To be implemented by sub models. 106 | Returns: 107 | logits: [B, M]. Logits on whether each candidate image belongs to the 108 | reference class. 109 | """ 110 | raise NotImplemented() 111 | 112 | def compute_output(self): 113 | # Evaluation. 114 | logits = self.logits[-1] 115 | if self.nway > 1: 116 | self._prediction = tf.nn.softmax(logits) 117 | self._correct = tf.equal(tf.argmax(self.prediction, axis=2), self.y_test) 118 | else: 119 | self._prediction = tf.sigmoid(logits) 120 | self._correct = tf.equal( 121 | tf.cast(self.prediction > 0.5, self.dtype), self.y_test) 122 | self._acc = tf.reduce_mean(tf.cast(self._correct, self.dtype)) 123 | 124 | def get_train_op(self, logits, y_test): 125 | """Builds optimization operation. To be implemented by sub models. 126 | Args: 127 | logits: [B, M]. Logits on whether each candidate image belongs to the 128 | reference class. 129 | y_test: [B, M, K]. Test image labels. 
130 | Returns: 131 | loss: Scalar. Loss function to be optimized. 132 | train_op: TensorFlow operation. 133 | """ 134 | raise NotImplemented() 135 | 136 | def phi(self, x, ext_wts=None, reuse=None): 137 | """Feature extraction function. 138 | Args: 139 | x: [N, H, W, C]. Input. 140 | reuse: Whether to reuse variables here. 141 | """ 142 | config = self.config 143 | is_training = self.is_training 144 | dtype = self.dtype 145 | with tf.variable_scope("phi", reuse=reuse): 146 | h, wts = cnn( 147 | x, 148 | config.filter_size, 149 | strides=config.strides, 150 | pool_fn=[tf.nn.max_pool] * len(config.pool_fn), 151 | pool_size=config.pool_size, 152 | pool_strides=config.pool_strides, 153 | act_fn=[tf.nn.relu for aa in config.conv_act_fn], 154 | add_bias=True, 155 | init_std=config.conv_init_std, 156 | init_method=config.conv_init_method, 157 | wd=config.wd, 158 | dtype=dtype, 159 | batch_norm=True, 160 | is_training=is_training, 161 | ext_wts=ext_wts) 162 | if self._embedding_weights is None: 163 | self._embedding_weights = wts 164 | h_shape = h.get_shape() 165 | h_size = 1 166 | for ss in h_shape[1:]: 167 | h_size *= int(ss) 168 | h = tf.reshape(h, [-1, h_size]) 169 | return h 170 | 171 | def assign_lr(self, sess, value): 172 | """Assign new learning rate value.""" 173 | sess.run(self._assign_lr, feed_dict={self._new_lr: value}) 174 | 175 | def assign_pretrained_weights(self, sess, ext_wts): 176 | """Load pretrained weights. 177 | Args: 178 | sess: TensorFlow session object. 179 | ext_wts: External weights dictionary. 
180 | """ 181 | assign_ops = [] 182 | with tf.variable_scope("Model/phi/cnn", reuse=True): 183 | for layer in range(len(self.config.filter_size)): 184 | with tf.variable_scope("layer_{}".format(layer)): 185 | for wname1, wname2 in zip( 186 | ["w", "b", "ema_mean", "ema_var", "beta", "gamma"], 187 | ["w", "b", "emean", "evar", "beta", "gamma"]): 188 | assign_ops.append( 189 | tf.assign( 190 | tf.get_variable(wname1), ext_wts["{}_{}".format( 191 | wname2, layer)])) 192 | sess.run(assign_ops) 193 | 194 | @property 195 | def y_test(self): 196 | return self._y_test 197 | 198 | @property 199 | def logits(self): 200 | return self._logits 201 | 202 | @property 203 | def prediction(self): 204 | return self._prediction 205 | 206 | @property 207 | def correct(self): 208 | return self._correct 209 | 210 | @property 211 | def x_train(self): 212 | return self._x_train 213 | 214 | @property 215 | def y_train(self): 216 | return self._y_train 217 | 218 | @property 219 | def y_train_one_hot(self): 220 | return self._y_train_one_hot 221 | 222 | @property 223 | def x_test(self): 224 | return self._x_test 225 | 226 | @property 227 | def y_test_one_hot(self): 228 | return self._y_test_one_hot 229 | 230 | @property 231 | def learn_rate(self): 232 | return self._learn_rate 233 | 234 | @property 235 | def config(self): 236 | return self._config 237 | 238 | @property 239 | def dtype(self): 240 | return self._dtype 241 | 242 | @property 243 | def loss(self): 244 | return self._loss 245 | 246 | @property 247 | def train_op(self): 248 | return self._train_op 249 | 250 | @property 251 | def is_training(self): 252 | return self._is_training 253 | 254 | @property 255 | def nshot(self): 256 | return self._nshot 257 | 258 | @property 259 | def nway(self): 260 | return self._nway 261 | 262 | @property 263 | def candidate_size(self): 264 | return self._candidate_size 265 | 266 | @property 267 | def embedding_weights(self): 268 | return self._embedding_weights 269 | 
-------------------------------------------------------------------------------- /fewshot/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 21 | # ============================================================================= 22 | """ 23 | A python logger 24 | 25 | Usage: 26 | # Set logger verbose level. 
27 | import os 28 | os.environ["VERBOSE"] = 1 29 | 30 | import logger 31 | log = logger.get("../logs/sample_log") 32 | 33 | log.info("Hello world!") 34 | log.info("Hello again!", verbose=2) 35 | log.warning("Something might be wrong.") 36 | log.error("Something is wrong.") 37 | log.fatal("Failed.") 38 | """ 39 | from __future__ import (absolute_import, division, print_function, 40 | unicode_literals) 41 | 42 | import datetime 43 | import inspect 44 | import os 45 | import sys 46 | import threading 47 | import traceback 48 | 49 | TERM_COLOR = { 50 | "normal": "\033[0m", 51 | "bright": "\033[1m", 52 | "invert": "\033[7m", 53 | "black": "\033[30m", 54 | "red": "\033[31m", 55 | "green": "\033[32m", 56 | "yellow": "\033[33m", 57 | "blue": "\033[34m", 58 | "magenta": "\033[35m", 59 | "cyan": "\033[36m", 60 | "white": "\033[37m", 61 | "default": "\033[39m" 62 | } 63 | 64 | log = None 65 | log_lock = threading.Lock() 66 | 67 | 68 | def get(fname=None): 69 | """ 70 | Returns a logger instance, with optional log file output. 71 | """ 72 | global log 73 | if log is not None and fname is None: 74 | return log 75 | 76 | # fname = os.environ.get("LOGTO", None) 77 | # if fname is None: 78 | # fname = default_fname 79 | else: 80 | log = Logger(fname) 81 | 82 | return log 83 | 84 | 85 | class Logger(object): 86 | 87 | def __init__(self, filename=None, default_verbose=0): 88 | """ 89 | Constructs a logger with optional log file output. 90 | 91 | Args: 92 | filename: optional log file output. 
If None, nothing will be 93 | written to file 94 | """ 95 | now = datetime.datetime.now() 96 | self.verbose_thresh = int(os.environ.get("VERBOSE", 0)) 97 | self.default_verbose = default_verbose 98 | if filename is not None: 99 | self.filename = filename 100 | dirname = os.path.dirname(self.filename) 101 | if not os.path.exists(dirname): 102 | os.makedirs(dirname) 103 | open(self.filename, "w").close() 104 | self.info("Log written to {}".format(os.path.abspath(self.filename))) 105 | else: 106 | self.filename = None 107 | 108 | pass 109 | 110 | @staticmethod 111 | def get_time_str(t=None): 112 | """ 113 | Returns a formatted time string. 114 | 115 | Args: 116 | t: datetime, default now. 117 | """ 118 | if t is None: 119 | t = datetime.datetime.now() 120 | timestr = t.isoformat(chr(32)) 121 | return timestr 122 | 123 | def log(self, message, typ="info", verbose=None): 124 | """ 125 | Writes a message. 126 | 127 | Args: 128 | message: string, message content. 129 | typ: string, type of the message. info, warning, error, or fatal. 130 | verbose: number, verbose level of the message. If lower than the 131 | environment variable, then the message will be logged to standard 132 | output and log output file (if set). 
133 | """ 134 | threadstr = "{}".format(threading.current_thread().ident)[-4:] 135 | if typ == "info": 136 | typstr_print = "{}I{}{}".format(TERM_COLOR["green"], threadstr, 137 | TERM_COLOR["default"]) 138 | typstr_log = "I{}".format(threadstr) 139 | elif typ == "warning": 140 | typstr_print = "{}W{}{}".format(TERM_COLOR["yellow"], threadstr, 141 | TERM_COLOR["default"]) 142 | typstr_log = "W{}".format(threadstr) 143 | elif typ == "debug": 144 | typstr_print = "{}D{}{}".format(TERM_COLOR["yellow"], threadstr, 145 | TERM_COLOR["default"]) 146 | typstr_log = "D{}".format(threadstr) 147 | elif typ == "error": 148 | typstr_print = "{}E{}{}".format(TERM_COLOR["red"], threadstr, 149 | TERM_COLOR["default"]) 150 | typstr_log = "E{}".format(threadstr) 151 | elif typ == "fatal": 152 | typstr_print = "{}F{}{}".format(TERM_COLOR["red"], threadstr, 153 | TERM_COLOR["default"]) 154 | typstr_log = "F{}".format(threadstr) 155 | else: 156 | raise Exception("Unknown log type: {0}".format(typ)) 157 | timestr = self.get_time_str() 158 | for (frame, filename, line_number, function_name, lines, index) in \ 159 | inspect.getouterframes(inspect.currentframe()): 160 | fn = os.path.basename(filename) 161 | if fn != "logger.py": 162 | break 163 | cwd = os.getcwd() 164 | if filename.startswith(cwd): 165 | filename = filename[len(cwd):] 166 | filename = filename.lstrip("/") 167 | 168 | callerstr = "{}:{}".format(filename, line_number) 169 | if len(callerstr) > 20: 170 | callerstr = "...{}".format(callerstr[-17:]) 171 | printstr = "{} {} {} {}".format(typstr_print, timestr, callerstr, message) 172 | logstr = "{} {} {} {}".format(typstr_log, timestr, callerstr, message) 173 | 174 | print(printstr) 175 | pass 176 | 177 | def log_wrapper(self, message, typ="info", verbose=None): 178 | if verbose is None: 179 | verbose = self.default_verbose 180 | 181 | if type(verbose) != int: 182 | raise Exception("Unknown verbose value: {}".format(verbose)) 183 | 184 | log_lock.acquire() 185 | try: 186 | if 
self.verbose_thresh >= verbose: 187 | self.log(message, typ=typ, verbose=verbose) 188 | 189 | if self.filename is not None: 190 | with open(self.filename, "a") as f: 191 | f.write(logstr) 192 | f.write("\n") 193 | except e: 194 | print("Error occurred!!") 195 | print(str(e)) 196 | finally: 197 | log_lock.release() 198 | 199 | def info(self, message, verbose=None): 200 | """ 201 | Writes an info message. 202 | 203 | Args: 204 | message: string, message content. 205 | verbose: number, verbose level. 206 | """ 207 | self.log_wrapper(message, typ="info", verbose=verbose) 208 | pass 209 | 210 | def warning(self, message, verbose=1): 211 | """ 212 | Writes a warning message. 213 | 214 | Args: 215 | message: string, message content. 216 | verbose: number, verbose level. 217 | """ 218 | self.log_wrapper(message, typ="warning", verbose=verbose) 219 | pass 220 | 221 | def error(self, message, verbose=0): 222 | """ 223 | Writes an info message. 224 | 225 | Args: 226 | message: string, message content. 227 | verbose: number, verbose level. 228 | """ 229 | self.log_wrapper(message, typ="error", verbose=verbose) 230 | pass 231 | 232 | def debug(self, message, verbose=None): 233 | self.log_wrapper(message, typ="debug", verbose=verbose) 234 | pass 235 | 236 | def fatal(self, message, verbose=0): 237 | """ 238 | Writes a fatal message, and exits the program. 239 | 240 | Args: 241 | message: string, message content. 242 | verbose: number, verbose level. 
243 | """ 244 | self.log_wrapper(message, typ="fatal", verbose=verbose) 245 | sys.exit(0) 246 | pass 247 | 248 | def log_args(self, verbose=None): 249 | self.info("Command: {}".format(" ".join(sys.argv))) 250 | pass 251 | 252 | def log_exception(self, exception): 253 | tb_str = traceback.format_exc(exception) 254 | self.error(tb_str) 255 | pass 256 | 257 | def verbose_level(self, level): 258 | 259 | class VerboseScope(): 260 | 261 | def __init__(self, logger, new_level): 262 | self._new_level = new_level 263 | self._logger = logger 264 | pass 265 | 266 | def __enter__(self): 267 | self._restore = self._logger.default_verbose 268 | self._logger.default_verbose = self._new_level 269 | pass 270 | 271 | def __exit__(self, type, value, traceback): 272 | self._logger.default_verbose = self._restore 273 | pass 274 | 275 | return VerboseScope(self, level) 276 | -------------------------------------------------------------------------------- /fewshot/data/tiered_imagenet_split/old_train.csv: -------------------------------------------------------------------------------- 1 | n01530575,n01524359 2 | n01531178,n01524359 3 | n01532829,n01524359 4 | n01534433,n01524359 5 | n01537544,n01524359 6 | n01558993,n01524359 7 | n01560419,n01524359 8 | n01580077,n01524359 9 | n01582220,n01524359 10 | n01592084,n01524359 11 | n01601694,n01524359 12 | n01675722,n01674216 13 | n01677366,n01674216 14 | n01682714,n01674216 15 | n01685808,n01674216 16 | n01687978,n01674216 17 | n01688243,n01674216 18 | n01689811,n01674216 19 | n01692333,n01674216 20 | n01693334,n01674216 21 | n01694178,n01674216 22 | n01695060,n01674216 23 | n01728572,n01726692 24 | n01728920,n01726692 25 | n01729322,n01726692 26 | n01729977,n01726692 27 | n01734418,n01726692 28 | n01735189,n01726692 29 | n01737021,n01726692 30 | n01739381,n01726692 31 | n01740131,n01726692 32 | n01742172,n01726692 33 | n01744401,n01726692 34 | n01748264,n01726692 35 | n01749939,n01726692 36 | n01751748,n01726692 37 | n01753488,n01726692 38 | 
n01755581,n01726692 39 | n01756291,n01726692 40 | n01847000,n01844917 41 | n01855032,n01844917 42 | n01855672,n01844917 43 | n01860187,n01844917 44 | n02002556,n01844917 45 | n02002724,n01844917 46 | n02006656,n01844917 47 | n02007558,n01844917 48 | n02009229,n01844917 49 | n02009912,n01844917 50 | n02011460,n01844917 51 | n02012849,n01844917 52 | n02013706,n01844917 53 | n02017213,n01844917 54 | n02018207,n01844917 55 | n02018795,n01844917 56 | n02025239,n01844917 57 | n02027492,n01844917 58 | n02028035,n01844917 59 | n02033041,n01844917 60 | n02037110,n01844917 61 | n02051845,n01844917 62 | n02056570,n01844917 63 | n02058221,n01844917 64 | n02088094,n02087551 65 | n02088238,n02087551 66 | n02088364,n02087551 67 | n02088466,n02087551 68 | n02088632,n02087551 69 | n02089078,n02087551 70 | n02089867,n02087551 71 | n02089973,n02087551 72 | n02090379,n02087551 73 | n02090622,n02087551 74 | n02090721,n02087551 75 | n02091032,n02087551 76 | n02091134,n02087551 77 | n02091244,n02087551 78 | n02091467,n02087551 79 | n02091635,n02087551 80 | n02091831,n02087551 81 | n02092002,n02087551 82 | n02092339,n02087551 83 | n02093256,n02092468 84 | n02093428,n02092468 85 | n02093647,n02092468 86 | n02093754,n02092468 87 | n02093859,n02092468 88 | n02093991,n02092468 89 | n02094114,n02092468 90 | n02094258,n02092468 91 | n02094433,n02092468 92 | n02095314,n02092468 93 | n02095570,n02092468 94 | n02095889,n02092468 95 | n02096051,n02092468 96 | n02096177,n02092468 97 | n02096294,n02092468 98 | n02096437,n02092468 99 | n02096585,n02092468 100 | n02097047,n02092468 101 | n02097130,n02092468 102 | n02097209,n02092468 103 | n02097298,n02092468 104 | n02097474,n02092468 105 | n02097658,n02092468 106 | n02098105,n02092468 107 | n02098286,n02092468 108 | n02098413,n02092468 109 | n02099267,n02098550 110 | n02099429,n02098550 111 | n02099601,n02098550 112 | n02099712,n02098550 113 | n02099849,n02098550 114 | n02100236,n02098550 115 | n02100583,n02098550 116 | n02100735,n02098550 117 | 
n02100877,n02098550 118 | n02101006,n02098550 119 | n02101388,n02098550 120 | n02101556,n02098550 121 | n02102040,n02098550 122 | n02102177,n02098550 123 | n02102318,n02098550 124 | n02102480,n02098550 125 | n02102973,n02098550 126 | n02123045,n02120997 127 | n02123159,n02120997 128 | n02123394,n02120997 129 | n02123597,n02120997 130 | n02124075,n02120997 131 | n02125311,n02120997 132 | n02127052,n02120997 133 | n02128385,n02120997 134 | n02128757,n02120997 135 | n02128925,n02120997 136 | n02129165,n02120997 137 | n02129604,n02120997 138 | n02130308,n02120997 139 | n02389026,n02370806 140 | n02391049,n02370806 141 | n02395406,n02370806 142 | n02396427,n02370806 143 | n02397096,n02370806 144 | n02398521,n02370806 145 | n02403003,n02370806 146 | n02408429,n02370806 147 | n02410509,n02370806 148 | n02412080,n02370806 149 | n02415577,n02370806 150 | n02417914,n02370806 151 | n02422106,n02370806 152 | n02422699,n02370806 153 | n02423022,n02370806 154 | n02437312,n02370806 155 | n02437616,n02370806 156 | n02480495,n02469914 157 | n02480855,n02469914 158 | n02481823,n02469914 159 | n02483362,n02469914 160 | n02483708,n02469914 161 | n02484975,n02469914 162 | n02486261,n02469914 163 | n02486410,n02469914 164 | n02487347,n02469914 165 | n02488291,n02469914 166 | n02488702,n02469914 167 | n02489166,n02469914 168 | n02490219,n02469914 169 | n02492035,n02469914 170 | n02492660,n02469914 171 | n02493509,n02469914 172 | n02493793,n02469914 173 | n02494079,n02469914 174 | n02497673,n02469914 175 | n02500267,n02469914 176 | n02727426,n02913152 177 | n02793495,n02913152 178 | n02859443,n02913152 179 | n03028079,n02913152 180 | n03032252,n02913152 181 | n03457902,n02913152 182 | n03529860,n02913152 183 | n03661043,n02913152 184 | n03781244,n02913152 185 | n03788195,n02913152 186 | n03877845,n02913152 187 | n03956157,n02913152 188 | n04081281,n02913152 189 | n04346328,n02913152 190 | n02687172,n03125870 191 | n02690373,n03125870 192 | n02692877,n03125870 193 | n02782093,n03125870 194 
| n02951358,n03125870 195 | n02981792,n03125870 196 | n03095699,n03125870 197 | n03344393,n03125870 198 | n03447447,n03125870 199 | n03662601,n03125870 200 | n03673027,n03125870 201 | n03947888,n03125870 202 | n04147183,n03125870 203 | n04266014,n03125870 204 | n04273569,n03125870 205 | n04347754,n03125870 206 | n04483307,n03125870 207 | n04552348,n03125870 208 | n04606251,n03125870 209 | n04612504,n03125870 210 | n03207941,n03257877 211 | n03259280,n03257877 212 | n03297495,n03257877 213 | n03483316,n03257877 214 | n03584829,n03257877 215 | n03761084,n03257877 216 | n04070727,n03257877 217 | n04111531,n03257877 218 | n04442312,n03257877 219 | n04517823,n03257877 220 | n04542943,n03257877 221 | n04554684,n03257877 222 | n02979186,n03278248 223 | n02988304,n03278248 224 | n02992529,n03278248 225 | n03085013,n03278248 226 | n03187595,n03278248 227 | n03584254,n03278248 228 | n03777754,n03278248 229 | n03782006,n03278248 230 | n03857828,n03278248 231 | n03902125,n03278248 232 | n04392985,n03278248 233 | n02776631,n03297735 234 | n02791270,n03297735 235 | n02871525,n03297735 236 | n02927161,n03297735 237 | n03089624,n03297735 238 | n03461385,n03297735 239 | n04005630,n03297735 240 | n04200800,n03297735 241 | n04443257,n03297735 242 | n04462240,n03297735 243 | n02791124,n03405265 244 | n02804414,n03405265 245 | n02870880,n03405265 246 | n03016953,n03405265 247 | n03018349,n03405265 248 | n03125729,n03405265 249 | n03131574,n03405265 250 | n03179701,n03405265 251 | n03201208,n03405265 252 | n03290653,n03405265 253 | n03337140,n03405265 254 | n03376595,n03405265 255 | n03388549,n03405265 256 | n03742115,n03405265 257 | n03891251,n03405265 258 | n03998194,n03405265 259 | n04099969,n03405265 260 | n04344873,n03405265 261 | n04380533,n03405265 262 | n04429376,n03405265 263 | n04447861,n03405265 264 | n04550184,n03405265 265 | n02799071,n03414162 266 | n02802426,n03414162 267 | n03134739,n03414162 268 | n03445777,n03414162 269 | n03598930,n03414162 270 | n03942813,n03414162 
271 | n04023962,n03414162 272 | n04118538,n03414162 273 | n04254680,n03414162 274 | n04409515,n03414162 275 | n04540053,n03414162 276 | n06785654,n03414162 277 | n02667093,n03419014 278 | n02837789,n03419014 279 | n02865351,n03419014 280 | n02883205,n03419014 281 | n02892767,n03419014 282 | n02963159,n03419014 283 | n03188531,n03419014 284 | n03325584,n03419014 285 | n03404251,n03419014 286 | n03534580,n03419014 287 | n03594734,n03419014 288 | n03595614,n03419014 289 | n03617480,n03419014 290 | n03630383,n03419014 291 | n03710721,n03419014 292 | n03770439,n03419014 293 | n03866082,n03419014 294 | n03980874,n03419014 295 | n04136333,n03419014 296 | n04325704,n03419014 297 | n04350905,n03419014 298 | n04370456,n03419014 299 | n04371430,n03419014 300 | n04479046,n03419014 301 | n04591157,n03419014 302 | n02708093,n03574816 303 | n02749479,n03574816 304 | n02794156,n03574816 305 | n02841315,n03574816 306 | n02879718,n03574816 307 | n02950826,n03574816 308 | n03196217,n03574816 309 | n03197337,n03574816 310 | n03467068,n03574816 311 | n03544143,n03574816 312 | n03692522,n03574816 313 | n03706229,n03574816 314 | n03773504,n03574816 315 | n03841143,n03574816 316 | n03891332,n03574816 317 | n04008634,n03574816 318 | n04009552,n03574816 319 | n04044716,n03574816 320 | n04086273,n03574816 321 | n04090263,n03574816 322 | n04118776,n03574816 323 | n04141975,n03574816 324 | n04317175,n03574816 325 | n04328186,n03574816 326 | n04355338,n03574816 327 | n04356056,n03574816 328 | n04376876,n03574816 329 | n04548280,n03574816 330 | n02666196,n03699975 331 | n02977058,n03699975 332 | n03180011,n03699975 333 | n03485407,n03699975 334 | n03496892,n03699975 335 | n03642806,n03699975 336 | n03832673,n03699975 337 | n04238763,n03699975 338 | n04243546,n03699975 339 | n04428191,n03699975 340 | n04525305,n03699975 341 | n06359193,n03699975 342 | n02966193,n03738472 343 | n02974003,n03738472 344 | n03425413,n03738472 345 | n03532672,n03738472 346 | n03874293,n03738472 347 | 
n03944341,n03738472 348 | n03992509,n03738472 349 | n04019541,n03738472 350 | n04040759,n03738472 351 | n04067472,n03738472 352 | n04371774,n03738472 353 | n04372370,n03738472 354 | n02701002,n03791235 355 | n02704792,n03791235 356 | n02814533,n03791235 357 | n02930766,n03791235 358 | n03100240,n03791235 359 | n03345487,n03791235 360 | n03417042,n03791235 361 | n03444034,n03791235 362 | n03445924,n03791235 363 | n03594945,n03791235 364 | n03670208,n03791235 365 | n03770679,n03791235 366 | n03777568,n03791235 367 | n03785016,n03791235 368 | n03796401,n03791235 369 | n03930630,n03791235 370 | n03977966,n03791235 371 | n04037443,n03791235 372 | n04252225,n03791235 373 | n04285008,n03791235 374 | n04461696,n03791235 375 | n04467665,n03791235 376 | n02672831,n03800933 377 | n02676566,n03800933 378 | n02787622,n03800933 379 | n02804610,n03800933 380 | n02992211,n03800933 381 | n03017168,n03800933 382 | n03110669,n03800933 383 | n03249569,n03800933 384 | n03272010,n03800933 385 | n03372029,n03800933 386 | n03394916,n03800933 387 | n03447721,n03800933 388 | n03452741,n03800933 389 | n03494278,n03800933 390 | n03495258,n03800933 391 | n03720891,n03800933 392 | n03721384,n03800933 393 | n03838899,n03800933 394 | n03840681,n03800933 395 | n03854065,n03800933 396 | n03884397,n03800933 397 | n04141076,n03800933 398 | n04311174,n03800933 399 | n04487394,n03800933 400 | n04515003,n03800933 401 | n04536866,n03800933 402 | n02825657,n04014297 403 | n02840245,n04014297 404 | n02843684,n04014297 405 | n02895154,n04014297 406 | n03000247,n04014297 407 | n03146219,n04014297 408 | n03220513,n04014297 409 | n03347037,n04014297 410 | n03424325,n04014297 411 | n03527444,n04014297 412 | n03637318,n04014297 413 | n03657121,n04014297 414 | n03788365,n04014297 415 | n03929855,n04014297 416 | n04141327,n04014297 417 | n04192698,n04014297 418 | n04229816,n04014297 419 | n04417672,n04014297 420 | n04423845,n04014297 421 | n04435653,n04014297 422 | n04507155,n04014297 423 | n04523525,n04014297 424 
| n04589890,n04014297 425 | n04590129,n04014297 426 | n02910353,n04081844 427 | n03075370,n04081844 428 | n03208938,n04081844 429 | n03476684,n04081844 430 | n03627232,n04081844 431 | n03803284,n04081844 432 | n03804744,n04081844 433 | n03874599,n04081844 434 | n04127249,n04081844 435 | n04153751,n04081844 436 | n04162706,n04081844 437 | n02951585,n04451818 438 | n03041632,n04451818 439 | n03109150,n04451818 440 | n03481172,n04451818 441 | n03498962,n04451818 442 | n03649909,n04451818 443 | n03658185,n04451818 444 | n03954731,n04451818 445 | n03967562,n04451818 446 | n03970156,n04451818 447 | n04154565,n04451818 448 | n04208210,n04451818 449 | -------------------------------------------------------------------------------- /fewshot/data/tiered_imagenet_split/trainval.csv: -------------------------------------------------------------------------------- 1 | n01530575,n01524359 2 | n01531178,n01524359 3 | n01532829,n01524359 4 | n01534433,n01524359 5 | n01537544,n01524359 6 | n01558993,n01524359 7 | n01560419,n01524359 8 | n01580077,n01524359 9 | n01582220,n01524359 10 | n01592084,n01524359 11 | n01601694,n01524359 12 | n01675722,n01674216 13 | n01677366,n01674216 14 | n01682714,n01674216 15 | n01685808,n01674216 16 | n01687978,n01674216 17 | n01688243,n01674216 18 | n01689811,n01674216 19 | n01692333,n01674216 20 | n01693334,n01674216 21 | n01694178,n01674216 22 | n01695060,n01674216 23 | n01728572,n01726692 24 | n01728920,n01726692 25 | n01729322,n01726692 26 | n01729977,n01726692 27 | n01734418,n01726692 28 | n01735189,n01726692 29 | n01737021,n01726692 30 | n01739381,n01726692 31 | n01740131,n01726692 32 | n01742172,n01726692 33 | n01744401,n01726692 34 | n01748264,n01726692 35 | n01749939,n01726692 36 | n01751748,n01726692 37 | n01753488,n01726692 38 | n01755581,n01726692 39 | n01756291,n01726692 40 | n01847000,n01844917 41 | n01855032,n01844917 42 | n01855672,n01844917 43 | n01860187,n01844917 44 | n02002556,n01844917 45 | n02002724,n01844917 46 | 
n02006656,n01844917 47 | n02007558,n01844917 48 | n02009229,n01844917 49 | n02009912,n01844917 50 | n02011460,n01844917 51 | n02012849,n01844917 52 | n02013706,n01844917 53 | n02017213,n01844917 54 | n02018207,n01844917 55 | n02018795,n01844917 56 | n02025239,n01844917 57 | n02027492,n01844917 58 | n02028035,n01844917 59 | n02033041,n01844917 60 | n02037110,n01844917 61 | n02051845,n01844917 62 | n02056570,n01844917 63 | n02058221,n01844917 64 | n02088094,n02087551 65 | n02088238,n02087551 66 | n02088364,n02087551 67 | n02088466,n02087551 68 | n02088632,n02087551 69 | n02089078,n02087551 70 | n02089867,n02087551 71 | n02089973,n02087551 72 | n02090379,n02087551 73 | n02090622,n02087551 74 | n02090721,n02087551 75 | n02091032,n02087551 76 | n02091134,n02087551 77 | n02091244,n02087551 78 | n02091467,n02087551 79 | n02091635,n02087551 80 | n02091831,n02087551 81 | n02092002,n02087551 82 | n02092339,n02087551 83 | n02093256,n02092468 84 | n02093428,n02092468 85 | n02093647,n02092468 86 | n02093754,n02092468 87 | n02093859,n02092468 88 | n02093991,n02092468 89 | n02094114,n02092468 90 | n02094258,n02092468 91 | n02094433,n02092468 92 | n02095314,n02092468 93 | n02095570,n02092468 94 | n02095889,n02092468 95 | n02096051,n02092468 96 | n02096177,n02092468 97 | n02096294,n02092468 98 | n02096437,n02092468 99 | n02096585,n02092468 100 | n02097047,n02092468 101 | n02097130,n02092468 102 | n02097209,n02092468 103 | n02097298,n02092468 104 | n02097474,n02092468 105 | n02097658,n02092468 106 | n02098105,n02092468 107 | n02098286,n02092468 108 | n02098413,n02092468 109 | n02099267,n02098550 110 | n02099429,n02098550 111 | n02099601,n02098550 112 | n02099712,n02098550 113 | n02099849,n02098550 114 | n02100236,n02098550 115 | n02100583,n02098550 116 | n02100735,n02098550 117 | n02100877,n02098550 118 | n02101006,n02098550 119 | n02101388,n02098550 120 | n02101556,n02098550 121 | n02102040,n02098550 122 | n02102177,n02098550 123 | n02102318,n02098550 124 | n02102480,n02098550 125 
| n02102973,n02098550 126 | n02123045,n02120997 127 | n02123159,n02120997 128 | n02123394,n02120997 129 | n02123597,n02120997 130 | n02124075,n02120997 131 | n02125311,n02120997 132 | n02127052,n02120997 133 | n02128385,n02120997 134 | n02128757,n02120997 135 | n02128925,n02120997 136 | n02129165,n02120997 137 | n02129604,n02120997 138 | n02130308,n02120997 139 | n02389026,n02370806 140 | n02391049,n02370806 141 | n02395406,n02370806 142 | n02396427,n02370806 143 | n02397096,n02370806 144 | n02398521,n02370806 145 | n02403003,n02370806 146 | n02408429,n02370806 147 | n02410509,n02370806 148 | n02412080,n02370806 149 | n02415577,n02370806 150 | n02417914,n02370806 151 | n02422106,n02370806 152 | n02422699,n02370806 153 | n02423022,n02370806 154 | n02437312,n02370806 155 | n02437616,n02370806 156 | n02480495,n02469914 157 | n02480855,n02469914 158 | n02481823,n02469914 159 | n02483362,n02469914 160 | n02483708,n02469914 161 | n02484975,n02469914 162 | n02486261,n02469914 163 | n02486410,n02469914 164 | n02487347,n02469914 165 | n02488291,n02469914 166 | n02488702,n02469914 167 | n02489166,n02469914 168 | n02490219,n02469914 169 | n02492035,n02469914 170 | n02492660,n02469914 171 | n02493509,n02469914 172 | n02493793,n02469914 173 | n02494079,n02469914 174 | n02497673,n02469914 175 | n02500267,n02469914 176 | n02727426,n02913152 177 | n02793495,n02913152 178 | n02859443,n02913152 179 | n03028079,n02913152 180 | n03032252,n02913152 181 | n03457902,n02913152 182 | n03529860,n02913152 183 | n03661043,n02913152 184 | n03781244,n02913152 185 | n03788195,n02913152 186 | n03877845,n02913152 187 | n03956157,n02913152 188 | n04081281,n02913152 189 | n04346328,n02913152 190 | n02687172,n03125870 191 | n02690373,n03125870 192 | n02692877,n03125870 193 | n02782093,n03125870 194 | n02951358,n03125870 195 | n02981792,n03125870 196 | n03095699,n03125870 197 | n03344393,n03125870 198 | n03447447,n03125870 199 | n03662601,n03125870 200 | n03673027,n03125870 201 | n03947888,n03125870 
202 | n04147183,n03125870 203 | n04266014,n03125870 204 | n04273569,n03125870 205 | n04347754,n03125870 206 | n04483307,n03125870 207 | n04552348,n03125870 208 | n04606251,n03125870 209 | n04612504,n03125870 210 | n03207941,n03257877 211 | n03259280,n03257877 212 | n03297495,n03257877 213 | n03483316,n03257877 214 | n03584829,n03257877 215 | n03761084,n03257877 216 | n04070727,n03257877 217 | n04111531,n03257877 218 | n04442312,n03257877 219 | n04517823,n03257877 220 | n04542943,n03257877 221 | n04554684,n03257877 222 | n02979186,n03278248 223 | n02988304,n03278248 224 | n02992529,n03278248 225 | n03085013,n03278248 226 | n03187595,n03278248 227 | n03584254,n03278248 228 | n03777754,n03278248 229 | n03782006,n03278248 230 | n03857828,n03278248 231 | n03902125,n03278248 232 | n04392985,n03278248 233 | n02776631,n03297735 234 | n02791270,n03297735 235 | n02871525,n03297735 236 | n02927161,n03297735 237 | n03089624,n03297735 238 | n03461385,n03297735 239 | n04005630,n03297735 240 | n04200800,n03297735 241 | n04443257,n03297735 242 | n04462240,n03297735 243 | n02791124,n03405265 244 | n02804414,n03405265 245 | n02870880,n03405265 246 | n03016953,n03405265 247 | n03018349,n03405265 248 | n03125729,n03405265 249 | n03131574,n03405265 250 | n03179701,n03405265 251 | n03201208,n03405265 252 | n03290653,n03405265 253 | n03337140,n03405265 254 | n03376595,n03405265 255 | n03388549,n03405265 256 | n03742115,n03405265 257 | n03891251,n03405265 258 | n03998194,n03405265 259 | n04099969,n03405265 260 | n04344873,n03405265 261 | n04380533,n03405265 262 | n04429376,n03405265 263 | n04447861,n03405265 264 | n04550184,n03405265 265 | n02799071,n03414162 266 | n02802426,n03414162 267 | n03134739,n03414162 268 | n03445777,n03414162 269 | n03598930,n03414162 270 | n03942813,n03414162 271 | n04023962,n03414162 272 | n04118538,n03414162 273 | n04254680,n03414162 274 | n04409515,n03414162 275 | n04540053,n03414162 276 | n06785654,n03414162 277 | n02667093,n03419014 278 | 
n02837789,n03419014 279 | n02865351,n03419014 280 | n02883205,n03419014 281 | n02892767,n03419014 282 | n02963159,n03419014 283 | n03188531,n03419014 284 | n03325584,n03419014 285 | n03404251,n03419014 286 | n03534580,n03419014 287 | n03594734,n03419014 288 | n03595614,n03419014 289 | n03617480,n03419014 290 | n03630383,n03419014 291 | n03710721,n03419014 292 | n03770439,n03419014 293 | n03866082,n03419014 294 | n03980874,n03419014 295 | n04136333,n03419014 296 | n04325704,n03419014 297 | n04350905,n03419014 298 | n04370456,n03419014 299 | n04371430,n03419014 300 | n04479046,n03419014 301 | n04591157,n03419014 302 | n02708093,n03574816 303 | n02749479,n03574816 304 | n02794156,n03574816 305 | n02841315,n03574816 306 | n02879718,n03574816 307 | n02950826,n03574816 308 | n03196217,n03574816 309 | n03197337,n03574816 310 | n03467068,n03574816 311 | n03544143,n03574816 312 | n03692522,n03574816 313 | n03706229,n03574816 314 | n03773504,n03574816 315 | n03841143,n03574816 316 | n03891332,n03574816 317 | n04008634,n03574816 318 | n04009552,n03574816 319 | n04044716,n03574816 320 | n04086273,n03574816 321 | n04090263,n03574816 322 | n04118776,n03574816 323 | n04141975,n03574816 324 | n04317175,n03574816 325 | n04328186,n03574816 326 | n04355338,n03574816 327 | n04356056,n03574816 328 | n04376876,n03574816 329 | n04548280,n03574816 330 | n02666196,n03699975 331 | n02977058,n03699975 332 | n03180011,n03699975 333 | n03485407,n03699975 334 | n03496892,n03699975 335 | n03642806,n03699975 336 | n03832673,n03699975 337 | n04238763,n03699975 338 | n04243546,n03699975 339 | n04428191,n03699975 340 | n04525305,n03699975 341 | n06359193,n03699975 342 | n02966193,n03738472 343 | n02974003,n03738472 344 | n03425413,n03738472 345 | n03532672,n03738472 346 | n03874293,n03738472 347 | n03944341,n03738472 348 | n03992509,n03738472 349 | n04019541,n03738472 350 | n04040759,n03738472 351 | n04067472,n03738472 352 | n04371774,n03738472 353 | n04372370,n03738472 354 | n02701002,n03791235 355 
| n02704792,n03791235 356 | n02814533,n03791235 357 | n02930766,n03791235 358 | n03100240,n03791235 359 | n03345487,n03791235 360 | n03417042,n03791235 361 | n03444034,n03791235 362 | n03445924,n03791235 363 | n03594945,n03791235 364 | n03670208,n03791235 365 | n03770679,n03791235 366 | n03777568,n03791235 367 | n03785016,n03791235 368 | n03796401,n03791235 369 | n03930630,n03791235 370 | n03977966,n03791235 371 | n04037443,n03791235 372 | n04252225,n03791235 373 | n04285008,n03791235 374 | n04461696,n03791235 375 | n04467665,n03791235 376 | n02672831,n03800933 377 | n02676566,n03800933 378 | n02787622,n03800933 379 | n02804610,n03800933 380 | n02992211,n03800933 381 | n03017168,n03800933 382 | n03110669,n03800933 383 | n03249569,n03800933 384 | n03272010,n03800933 385 | n03372029,n03800933 386 | n03394916,n03800933 387 | n03447721,n03800933 388 | n03452741,n03800933 389 | n03494278,n03800933 390 | n03495258,n03800933 391 | n03720891,n03800933 392 | n03721384,n03800933 393 | n03838899,n03800933 394 | n03840681,n03800933 395 | n03854065,n03800933 396 | n03884397,n03800933 397 | n04141076,n03800933 398 | n04311174,n03800933 399 | n04487394,n03800933 400 | n04515003,n03800933 401 | n04536866,n03800933 402 | n02825657,n04014297 403 | n02840245,n04014297 404 | n02843684,n04014297 405 | n02895154,n04014297 406 | n03000247,n04014297 407 | n03146219,n04014297 408 | n03220513,n04014297 409 | n03347037,n04014297 410 | n03424325,n04014297 411 | n03527444,n04014297 412 | n03637318,n04014297 413 | n03657121,n04014297 414 | n03788365,n04014297 415 | n03929855,n04014297 416 | n04141327,n04014297 417 | n04192698,n04014297 418 | n04229816,n04014297 419 | n04417672,n04014297 420 | n04423845,n04014297 421 | n04435653,n04014297 422 | n04507155,n04014297 423 | n04523525,n04014297 424 | n04589890,n04014297 425 | n04590129,n04014297 426 | n02910353,n04081844 427 | n03075370,n04081844 428 | n03208938,n04081844 429 | n03476684,n04081844 430 | n03627232,n04081844 431 | n03803284,n04081844 
432 | n03804744,n04081844 433 | n03874599,n04081844 434 | n04127249,n04081844 435 | n04153751,n04081844 436 | n04162706,n04081844 437 | n02951585,n04451818 438 | n03041632,n04451818 439 | n03109150,n04451818 440 | n03481172,n04451818 441 | n03498962,n04451818 442 | n03649909,n04451818 443 | n03658185,n04451818 444 | n03954731,n04451818 445 | n03967562,n04451818 446 | n03970156,n04451818 447 | n04154565,n04451818 448 | n04208210,n04451818 449 | -------------------------------------------------------------------------------- /fewshot/data/omniglot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richars S. Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the 'Software'), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE. 
21 | # ============================================================================= 22 | from __future__ import (absolute_import, division, print_function, 23 | unicode_literals) 24 | import cv2 25 | import numpy as np 26 | import os 27 | import pickle as pkl 28 | 29 | import tensorflow as tf 30 | 31 | from fewshot.data.episode import Episode 32 | from fewshot.data.data_factory import RegisterDataset 33 | from fewshot.data.refinement_dataset import RefinementMetaDataset 34 | from fewshot.utils import logger 35 | 36 | log = logger.get() 37 | flags = tf.flags 38 | FLAGS = tf.flags.FLAGS 39 | 40 | 41 | def get_image_folder(folder, split_def, split): 42 | if split_def == 'lake': 43 | if split == 'train': 44 | folder_ = os.path.join(folder, 'images_background') 45 | else: 46 | folder_ = os.path.join(folder, 'images_evaluation') 47 | elif split_def == 'vinyals': 48 | folder_ = os.path.join(folder, 'images_all') 49 | return folder_ 50 | 51 | 52 | def get_vinyals_split_file(split): 53 | curdir = os.path.dirname(os.path.realpath(__file__)) 54 | split_file = os.path.join(curdir, 'omniglot_split', '{}.txt'.format(split)) 55 | return split_file 56 | 57 | 58 | def read_lake_split(folder, aug_90=False): 59 | """Reads dataset from folder.""" 60 | subfolders = os.listdir(folder) 61 | label_idx = [] 62 | label_str = [] 63 | data = [] 64 | for sf in subfolders: 65 | sf_ = os.path.join(folder, sf) 66 | img_fnames = os.listdir(sf_) 67 | for character in img_fnames: 68 | char_folder = os.path.join(sf_, character) 69 | img_list = os.listdir(char_folder) 70 | for img_fname in img_list: 71 | fname_ = os.path.join(char_folder, img_fname) 72 | img = cv2.imread(fname_) 73 | # Shrink images. 
74 | img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA) 75 | img = np.minimum(255, np.maximum(0, img)) 76 | img = 255 - img[:, :, 0:1] 77 | if aug_90: 78 | M = cv2.getRotationMatrix2D((14, 14), 90, 1) 79 | dst = img 80 | for ii in range(4): 81 | dst = cv2.warpAffine(dst, M, (28, 28)) 82 | data.append(np.expand_dims(np.expand_dims(dst, 0), 3)) 83 | label_idx.append(len(label_str) + ii) 84 | else: 85 | img = np.expand_dims(img, 0) 86 | data.append(img) 87 | label_idx.append(len(label_str)) 88 | 89 | if aug_90: 90 | for ii in range(4): 91 | label_str.append(sf + '_' + character + '_' + str(ii)) 92 | else: 93 | label_str.append(sf + '_' + character) 94 | print('Number of classes {}'.format(len(label_str))) 95 | print('Number of images {}'.format(len(data))) 96 | images = np.concatenate(data, axis=0) 97 | labels = np.array(label_idx, dtype=np.int32) 98 | label_str = label_str 99 | return images, labels, label_str 100 | 101 | 102 | def read_vinyals_split(folder, split_file, aug_90=False): 103 | """Reads dataset from a folder with a split file.""" 104 | lines = open(split_file, 'r').readlines() 105 | lines = map(lambda x: x.strip('\n\r'), lines) 106 | label_idx = [] 107 | label_str = [] 108 | data = [] 109 | for ff in lines: 110 | char_folder = os.path.join(folder, ff) 111 | img_list = os.listdir(char_folder) 112 | for img_fname in img_list: 113 | fname_ = os.path.join(char_folder, img_fname) 114 | img = cv2.imread(fname_) 115 | img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC) 116 | img = np.minimum(255, np.maximum(0, img)) 117 | img = 255 - img[:, :, 0:1] 118 | if aug_90: 119 | M = cv2.getRotationMatrix2D((14, 14), 90, 1) 120 | dst = img 121 | for ii in range(4): 122 | dst = cv2.warpAffine(dst, M, (28, 28)) 123 | data.append(np.expand_dims(np.expand_dims(dst, 0), 3)) 124 | label_idx.append(len(label_str) + ii) 125 | else: 126 | img = np.expand_dims(img, 0) 127 | data.append(img) 128 | label_idx.append(len(label_str)) 129 | if aug_90: 130 | for ii 
in range(4): 131 | label_str.append(ff + '_' + str(ii)) 132 | else: 133 | label_str.append(ff) 134 | print('Number of classes {}'.format(len(label_str))) 135 | print('Number of images {}'.format(len(data))) 136 | images = np.concatenate(data, axis=0) 137 | labels = np.array(label_idx, dtype=np.int32) 138 | return images, labels, label_str 139 | 140 | 141 | @RegisterDataset('omniglot') 142 | class OmniglotDataset(RefinementMetaDataset): 143 | """A few-shot learning dataset with refinement (unlabeled) training. images. 144 | """ 145 | 146 | def __init__(self, 147 | folder, 148 | split, 149 | nway=5, 150 | nshot=1, 151 | num_unlabel=5, 152 | num_distractor=5, 153 | num_test=-1, 154 | split_def='vinyals', 155 | label_ratio=None, 156 | aug_90=True, 157 | shuffle_episode=False, 158 | seed=0): 159 | """Creates a meta dataset. 160 | Args: 161 | folder: String. Path to the Omniglot dataset. 162 | split: String. 'train' or 'test' for Lake's split, 'train', 'trainval', 163 | 'val', test' for Vinyals' split. 164 | nway: Int. N way classification problem, default 5. 165 | nshot: Int. N-shot classification problem, default 1. 166 | num_unlabel: Int. Number of unlabeled examples per class, default 2. 167 | num_distractor: Int. Number of distractor classes, default 0. 168 | num_test: Int. Number of query images, default 10. 169 | split_def: String. 'vinyals' or 'lake', using different split definitions. 170 | aug_90: Bool. Whether to augment the training data by rotating 90 degrees. 171 | seed: Int. Random seed. 
172 | """ 173 | self._folder = folder 174 | self._aug_90 = aug_90 175 | self._split_def = split_def 176 | self._split = split 177 | if FLAGS.disable_distractor: 178 | num_distractor = 0 179 | super(OmniglotDataset, 180 | self).__init__(split, nway, nshot, num_unlabel, num_distractor, 181 | num_test, label_ratio, shuffle_episode, seed) 182 | 183 | def get_images(self, inds): 184 | return self._images[inds] 185 | 186 | def read_cache(self): 187 | """Reads dataset from cached pklz file.""" 188 | cache_path = self.get_cache_path() 189 | print(cache_path) 190 | if os.path.exists(cache_path): 191 | try: 192 | with open(cache_path, 'rb') as f: 193 | data = pkl.load(f, encoding='bytes') 194 | self._images = data[b'images'] 195 | self._labels = data[b'labels'] 196 | self._label_str = data[b'label_str'] 197 | except: 198 | with open(cache_path, 'rb') as f: 199 | data = pkl.load(f) 200 | self._images = data['images'] 201 | self._labels = data['labels'] 202 | self._label_str = data['label_str'] 203 | self.read_label_split() 204 | return True 205 | else: 206 | return False 207 | 208 | def read_label_split(self): 209 | cache_path_labelsplit = self.get_label_split_path() 210 | if os.path.exists(cache_path_labelsplit): 211 | self._label_split_idx = np.loadtxt(cache_path_labelsplit, dtype=np.int64) 212 | else: 213 | if self._split in ['train', 'trainval']: 214 | log.info('Use {}% image for labeled split.'.format( 215 | int(self._label_ratio * 100))) 216 | self._label_split_idx = self.label_split() 217 | elif self._split in ['val', 'test']: 218 | log.info('Use all image in labeled split, since we are in val/test') 219 | self._label_split_idx = np.arange(self._images.shape[0]) 220 | else: 221 | raise ValueError('Unknown split {}'.format(self._split)) 222 | self._label_split_idx = np.array(self.label_split(), dtype=np.int64) 223 | self.save_label_split() 224 | 225 | def save_cache(self): 226 | """Saves pklz cache.""" 227 | data = { 228 | 'images': self._images, 229 | 'labels': 
self._labels, 230 | 'label_str': self._label_str, 231 | } 232 | with open(self.get_cache_path(), 'wb') as f: 233 | pkl.dump(data, f, protocol=pkl.HIGHEST_PROTOCOL) 234 | 235 | def save_label_split(self): 236 | np.savetxt(self.get_label_split_path(), self._label_split_idx, fmt='%d') 237 | 238 | def read_dataset(self): 239 | # Read data from folder or cache. 240 | if not self.read_cache(): 241 | folder, split_def, split = self._folder, self._split_def, self._split 242 | folder = get_image_folder(folder, split_def, split) 243 | if split_def == 'lake': 244 | self._images, self._labels, self._label_str = read_lake_split( 245 | folder, aug_90=self._aug_90) 246 | elif split_def == 'vinyals': 247 | split_file = get_vinyals_split_file(self._split) 248 | self._images, self._labels, self._label_str = read_vinyals_split( 249 | folder, split_file, aug_90=self._aug_90) 250 | self.read_label_split() 251 | self.save_cache() 252 | 253 | def get_label_split_path(self): 254 | aug_str = '_aug90' if self._aug_90 else '' 255 | split_def_str = '_' + self._split_def 256 | label_ratio_str = '_' + str(int(self._label_ratio * 100)) 257 | seed_id_str = '_' + str(self._seed) 258 | if self._split in ['train', 'trainval']: 259 | cache_path = os.path.join( 260 | self._folder, self._split + split_def_str + aug_str + '_labelsplit' + 261 | label_ratio_str + seed_id_str + '.txt') 262 | elif self._split in ['val', 'test']: 263 | cache_path = os.path.join( 264 | self._folder, 265 | self._split + split_def_str + aug_str + '_labelsplit' + '.txt') 266 | return cache_path 267 | 268 | def get_cache_path(self): 269 | """Gets cache file name.""" 270 | aug_str = '_aug90' if self._aug_90 else '' 271 | split_def_str = '_' + self._split_def 272 | cache_path = os.path.join(self._folder, 273 | self._split + split_def_str + aug_str + '.pkl') 274 | 275 | return cache_path 276 | -------------------------------------------------------------------------------- /fewshot/data/omniglot_split/test.txt: 
-------------------------------------------------------------------------------- 1 | Gurmukhi/character42 2 | Gurmukhi/character43 3 | Gurmukhi/character44 4 | Gurmukhi/character45 5 | Kannada/character01 6 | Kannada/character02 7 | Kannada/character03 8 | Kannada/character04 9 | Kannada/character05 10 | Kannada/character06 11 | Kannada/character07 12 | Kannada/character08 13 | Kannada/character09 14 | Kannada/character10 15 | Kannada/character11 16 | Kannada/character12 17 | Kannada/character13 18 | Kannada/character14 19 | Kannada/character15 20 | Kannada/character16 21 | Kannada/character17 22 | Kannada/character18 23 | Kannada/character19 24 | Kannada/character20 25 | Kannada/character21 26 | Kannada/character22 27 | Kannada/character23 28 | Kannada/character24 29 | Kannada/character25 30 | Kannada/character26 31 | Kannada/character27 32 | Kannada/character28 33 | Kannada/character29 34 | Kannada/character30 35 | Kannada/character31 36 | Kannada/character32 37 | Kannada/character33 38 | Kannada/character34 39 | Kannada/character35 40 | Kannada/character36 41 | Kannada/character37 42 | Kannada/character38 43 | Kannada/character39 44 | Kannada/character40 45 | Kannada/character41 46 | Keble/character01 47 | Keble/character02 48 | Keble/character03 49 | Keble/character04 50 | Keble/character05 51 | Keble/character06 52 | Keble/character07 53 | Keble/character08 54 | Keble/character09 55 | Keble/character10 56 | Keble/character11 57 | Keble/character12 58 | Keble/character13 59 | Keble/character14 60 | Keble/character15 61 | Keble/character16 62 | Keble/character17 63 | Keble/character18 64 | Keble/character19 65 | Keble/character20 66 | Keble/character21 67 | Keble/character22 68 | Keble/character23 69 | Keble/character24 70 | Keble/character25 71 | Keble/character26 72 | Malayalam/character01 73 | Malayalam/character02 74 | Malayalam/character03 75 | Malayalam/character04 76 | Malayalam/character05 77 | Malayalam/character06 78 | Malayalam/character07 79 | 
Malayalam/character08 80 | Malayalam/character09 81 | Malayalam/character10 82 | Malayalam/character11 83 | Malayalam/character12 84 | Malayalam/character13 85 | Malayalam/character14 86 | Malayalam/character15 87 | Malayalam/character16 88 | Malayalam/character17 89 | Malayalam/character18 90 | Malayalam/character19 91 | Malayalam/character20 92 | Malayalam/character21 93 | Malayalam/character22 94 | Malayalam/character23 95 | Malayalam/character24 96 | Malayalam/character25 97 | Malayalam/character26 98 | Malayalam/character27 99 | Malayalam/character28 100 | Malayalam/character29 101 | Malayalam/character30 102 | Malayalam/character31 103 | Malayalam/character32 104 | Malayalam/character33 105 | Malayalam/character34 106 | Malayalam/character35 107 | Malayalam/character36 108 | Malayalam/character37 109 | Malayalam/character38 110 | Malayalam/character39 111 | Malayalam/character40 112 | Malayalam/character41 113 | Malayalam/character42 114 | Malayalam/character43 115 | Malayalam/character44 116 | Malayalam/character45 117 | Malayalam/character46 118 | Malayalam/character47 119 | Manipuri/character01 120 | Manipuri/character02 121 | Manipuri/character03 122 | Manipuri/character04 123 | Manipuri/character05 124 | Manipuri/character06 125 | Manipuri/character07 126 | Manipuri/character08 127 | Manipuri/character09 128 | Manipuri/character10 129 | Manipuri/character11 130 | Manipuri/character12 131 | Manipuri/character13 132 | Manipuri/character14 133 | Manipuri/character15 134 | Manipuri/character16 135 | Manipuri/character17 136 | Manipuri/character18 137 | Manipuri/character19 138 | Manipuri/character20 139 | Manipuri/character21 140 | Manipuri/character22 141 | Manipuri/character23 142 | Manipuri/character24 143 | Manipuri/character25 144 | Manipuri/character26 145 | Manipuri/character27 146 | Manipuri/character28 147 | Manipuri/character29 148 | Manipuri/character30 149 | Manipuri/character31 150 | Manipuri/character32 151 | Manipuri/character33 152 | 
Manipuri/character34 153 | Manipuri/character35 154 | Manipuri/character36 155 | Manipuri/character37 156 | Manipuri/character38 157 | Manipuri/character39 158 | Manipuri/character40 159 | Mongolian/character01 160 | Mongolian/character02 161 | Mongolian/character03 162 | Mongolian/character04 163 | Mongolian/character05 164 | Mongolian/character06 165 | Mongolian/character07 166 | Mongolian/character08 167 | Mongolian/character09 168 | Mongolian/character10 169 | Mongolian/character11 170 | Mongolian/character12 171 | Mongolian/character13 172 | Mongolian/character14 173 | Mongolian/character15 174 | Mongolian/character16 175 | Mongolian/character17 176 | Mongolian/character18 177 | Mongolian/character19 178 | Mongolian/character20 179 | Mongolian/character21 180 | Mongolian/character22 181 | Mongolian/character23 182 | Mongolian/character24 183 | Mongolian/character25 184 | Mongolian/character26 185 | Mongolian/character27 186 | Mongolian/character28 187 | Mongolian/character29 188 | Mongolian/character30 189 | Old_Church_Slavonic_(Cyrillic)/character01 190 | Old_Church_Slavonic_(Cyrillic)/character02 191 | Old_Church_Slavonic_(Cyrillic)/character03 192 | Old_Church_Slavonic_(Cyrillic)/character04 193 | Old_Church_Slavonic_(Cyrillic)/character05 194 | Old_Church_Slavonic_(Cyrillic)/character06 195 | Old_Church_Slavonic_(Cyrillic)/character07 196 | Old_Church_Slavonic_(Cyrillic)/character08 197 | Old_Church_Slavonic_(Cyrillic)/character09 198 | Old_Church_Slavonic_(Cyrillic)/character10 199 | Old_Church_Slavonic_(Cyrillic)/character11 200 | Old_Church_Slavonic_(Cyrillic)/character12 201 | Old_Church_Slavonic_(Cyrillic)/character13 202 | Old_Church_Slavonic_(Cyrillic)/character14 203 | Old_Church_Slavonic_(Cyrillic)/character15 204 | Old_Church_Slavonic_(Cyrillic)/character16 205 | Old_Church_Slavonic_(Cyrillic)/character17 206 | Old_Church_Slavonic_(Cyrillic)/character18 207 | Old_Church_Slavonic_(Cyrillic)/character19 208 | 
Old_Church_Slavonic_(Cyrillic)/character20 209 | Old_Church_Slavonic_(Cyrillic)/character21 210 | Old_Church_Slavonic_(Cyrillic)/character22 211 | Old_Church_Slavonic_(Cyrillic)/character23 212 | Old_Church_Slavonic_(Cyrillic)/character24 213 | Old_Church_Slavonic_(Cyrillic)/character25 214 | Old_Church_Slavonic_(Cyrillic)/character26 215 | Old_Church_Slavonic_(Cyrillic)/character27 216 | Old_Church_Slavonic_(Cyrillic)/character28 217 | Old_Church_Slavonic_(Cyrillic)/character29 218 | Old_Church_Slavonic_(Cyrillic)/character30 219 | Old_Church_Slavonic_(Cyrillic)/character31 220 | Old_Church_Slavonic_(Cyrillic)/character32 221 | Old_Church_Slavonic_(Cyrillic)/character33 222 | Old_Church_Slavonic_(Cyrillic)/character34 223 | Old_Church_Slavonic_(Cyrillic)/character35 224 | Old_Church_Slavonic_(Cyrillic)/character36 225 | Old_Church_Slavonic_(Cyrillic)/character37 226 | Old_Church_Slavonic_(Cyrillic)/character38 227 | Old_Church_Slavonic_(Cyrillic)/character39 228 | Old_Church_Slavonic_(Cyrillic)/character40 229 | Old_Church_Slavonic_(Cyrillic)/character41 230 | Old_Church_Slavonic_(Cyrillic)/character42 231 | Old_Church_Slavonic_(Cyrillic)/character43 232 | Old_Church_Slavonic_(Cyrillic)/character44 233 | Old_Church_Slavonic_(Cyrillic)/character45 234 | Oriya/character01 235 | Oriya/character02 236 | Oriya/character03 237 | Oriya/character04 238 | Oriya/character05 239 | Oriya/character06 240 | Oriya/character07 241 | Oriya/character08 242 | Oriya/character09 243 | Oriya/character10 244 | Oriya/character11 245 | Oriya/character12 246 | Oriya/character13 247 | Oriya/character14 248 | Oriya/character15 249 | Oriya/character16 250 | Oriya/character17 251 | Oriya/character18 252 | Oriya/character19 253 | Oriya/character20 254 | Oriya/character21 255 | Oriya/character22 256 | Oriya/character23 257 | Oriya/character24 258 | Oriya/character25 259 | Oriya/character26 260 | Oriya/character27 261 | Oriya/character28 262 | Oriya/character29 263 | Oriya/character30 264 | 
Oriya/character31 265 | Oriya/character32 266 | Oriya/character33 267 | Oriya/character34 268 | Oriya/character35 269 | Oriya/character36 270 | Oriya/character37 271 | Oriya/character38 272 | Oriya/character39 273 | Oriya/character40 274 | Oriya/character41 275 | Oriya/character42 276 | Oriya/character43 277 | Oriya/character44 278 | Oriya/character45 279 | Oriya/character46 280 | Syriac_(Serto)/character01 281 | Syriac_(Serto)/character02 282 | Syriac_(Serto)/character03 283 | Syriac_(Serto)/character04 284 | Syriac_(Serto)/character05 285 | Syriac_(Serto)/character06 286 | Syriac_(Serto)/character07 287 | Syriac_(Serto)/character08 288 | Syriac_(Serto)/character09 289 | Syriac_(Serto)/character10 290 | Syriac_(Serto)/character11 291 | Syriac_(Serto)/character12 292 | Syriac_(Serto)/character13 293 | Syriac_(Serto)/character14 294 | Syriac_(Serto)/character15 295 | Syriac_(Serto)/character16 296 | Syriac_(Serto)/character17 297 | Syriac_(Serto)/character18 298 | Syriac_(Serto)/character19 299 | Syriac_(Serto)/character20 300 | Syriac_(Serto)/character21 301 | Syriac_(Serto)/character22 302 | Syriac_(Serto)/character23 303 | Sylheti/character01 304 | Sylheti/character02 305 | Sylheti/character03 306 | Sylheti/character04 307 | Sylheti/character05 308 | Sylheti/character06 309 | Sylheti/character07 310 | Sylheti/character08 311 | Sylheti/character09 312 | Sylheti/character10 313 | Sylheti/character11 314 | Sylheti/character12 315 | Sylheti/character13 316 | Sylheti/character14 317 | Sylheti/character15 318 | Sylheti/character16 319 | Sylheti/character17 320 | Sylheti/character18 321 | Sylheti/character19 322 | Sylheti/character20 323 | Sylheti/character21 324 | Sylheti/character22 325 | Sylheti/character23 326 | Sylheti/character24 327 | Sylheti/character25 328 | Sylheti/character26 329 | Sylheti/character27 330 | Sylheti/character28 331 | Tengwar/character01 332 | Tengwar/character02 333 | Tengwar/character03 334 | Tengwar/character04 335 | Tengwar/character05 336 
| Tengwar/character06 337 | Tengwar/character07 338 | Tengwar/character08 339 | Tengwar/character09 340 | Tengwar/character10 341 | Tengwar/character11 342 | Tengwar/character12 343 | Tengwar/character13 344 | Tengwar/character14 345 | Tengwar/character15 346 | Tengwar/character16 347 | Tengwar/character17 348 | Tengwar/character18 349 | Tengwar/character19 350 | Tengwar/character20 351 | Tengwar/character21 352 | Tengwar/character22 353 | Tengwar/character23 354 | Tengwar/character24 355 | Tengwar/character25 356 | Tibetan/character01 357 | Tibetan/character02 358 | Tibetan/character03 359 | Tibetan/character04 360 | Tibetan/character05 361 | Tibetan/character06 362 | Tibetan/character07 363 | Tibetan/character08 364 | Tibetan/character09 365 | Tibetan/character10 366 | Tibetan/character11 367 | Tibetan/character12 368 | Tibetan/character13 369 | Tibetan/character14 370 | Tibetan/character15 371 | Tibetan/character16 372 | Tibetan/character17 373 | Tibetan/character18 374 | Tibetan/character19 375 | Tibetan/character20 376 | Tibetan/character21 377 | Tibetan/character22 378 | Tibetan/character23 379 | Tibetan/character24 380 | Tibetan/character25 381 | Tibetan/character26 382 | Tibetan/character27 383 | Tibetan/character28 384 | Tibetan/character29 385 | Tibetan/character30 386 | Tibetan/character31 387 | Tibetan/character32 388 | Tibetan/character33 389 | Tibetan/character34 390 | Tibetan/character35 391 | Tibetan/character36 392 | Tibetan/character37 393 | Tibetan/character38 394 | Tibetan/character39 395 | Tibetan/character40 396 | Tibetan/character41 397 | Tibetan/character42 398 | ULOG/character01 399 | ULOG/character02 400 | ULOG/character03 401 | ULOG/character04 402 | ULOG/character05 403 | ULOG/character06 404 | ULOG/character07 405 | ULOG/character08 406 | ULOG/character09 407 | ULOG/character10 408 | ULOG/character11 409 | ULOG/character12 410 | ULOG/character13 411 | ULOG/character14 412 | ULOG/character15 413 | ULOG/character16 414 | ULOG/character17 
415 | ULOG/character18 416 | ULOG/character19 417 | ULOG/character20 418 | ULOG/character21 419 | ULOG/character22 420 | ULOG/character23 421 | ULOG/character24 422 | ULOG/character25 423 | ULOG/character26 424 | -------------------------------------------------------------------------------- /fewshot/data/refinement_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, 2 | # Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, Richard S. Zemel. 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in all 12 | # copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | # SOFTWARE.
# =============================================================================
import cv2
import numpy as np
import os
import gzip
import pickle as pkl
import tensorflow as tf

from fewshot.data.episode import Episode
from fewshot.data.data_factory import RegisterDataset
from fewshot.utils import logger

log = logger.get()

flags = tf.flags
flags.DEFINE_bool("disable_distractor", False,
                  "Whether or not to disable distractors")
flags.DEFINE_float("label_ratio", 0.1,
                   "Portion of labeled images in the training set.")
FLAGS = tf.flags.FLAGS


class MetaDataset(object):
  """Minimal interface for an episodic (meta-learning) dataset."""

  def next(self):
    """Gets a new training episode. Not implemented in this base class."""
    pass


class RefinementMetaDataset(object):
  """A few-shot learning dataset with unlabeled images for refinement.

  Subclasses must implement `read_dataset`, which is expected to populate
  `self._labels`, `self._label_str` and `self._label_split_idx` (all read by
  `__init__` right after it returns) and `self._images` (read by
  `get_batch_idx`). Subclasses must also provide `get_images(ids)`, which
  `next` calls to fetch image arrays for a list of example indices.
  """

  def __init__(self, split, nway, nshot, num_unlabel, num_distractor, num_test,
               label_ratio, shuffle_episode, seed):
    """Creates a meta dataset.

    Args:
      split: String. One of "train", "trainval", "val" or "test".
      nway: Int. Number of classes per episode (N-way classification).
      nshot: Int. Number of labeled support examples per class.
      num_unlabel: Int. Number of unlabeled examples per class.
      num_distractor: Int. Number of distractor classes per episode. Forced to
        0 when the `--disable_distractor` flag is set.
      num_test: Int. Number of query images per class; -1 uses all remaining
        labeled images of the class as queries.
      label_ratio: Float or None. Fraction of examples that are labeled; None
        falls back to the `--label_ratio` flag.
      shuffle_episode: Bool. Whether to shuffle example order within an
        episode.
      seed: Int. Random seed for episode sampling.
    """
    self._split = split
    self._nway = nway
    self._nshot = nshot
    self._num_unlabel = num_unlabel
    self._rnd = np.random.RandomState(seed)
    self._seed = seed
    self._num_distractor = 0 if FLAGS.disable_distractor else num_distractor
    log.warning("Number of distractors in each episode: {}".format(
        self._num_distractor))
    self._num_test = num_test
    self._label_ratio = (
        FLAGS.label_ratio if label_ratio is None else label_ratio)
    log.info('Label ratio {}'.format(self._label_ratio))
    self._shuffle_episode = shuffle_episode

    # Subclass hook: loads labels, label strings and the labeled split.
    self.read_dataset()

    # Sets for O(1) membership queries of labeled / unlabeled example indices.
    self._label_split_idx = np.array(self._label_split_idx)
    self._label_split_idx_set = set(self._label_split_idx)
    self._unlabel_split_idx = np.array([
        idx for idx in range(self._labels.shape[0])
        if idx not in self._label_split_idx_set
    ])
    self._unlabel_split_idx_set = set(self._unlabel_split_idx)

    # Map each class ID to the indices of all of its examples.
    num_label_cls = len(self._label_str)
    self._num_classes = num_label_cls
    ex_ids = np.arange(self._labels.shape[0])
    self._label_idict = {
        cc: ex_ids[self._labels == cc] for cc in range(num_label_cls)
    }

  def read_dataset(self):
    """Reads data from folder or cache. Must be implemented by subclasses."""
    raise NotImplementedError()

  def label_split(self):
    """Gets label/unlabel image splits.

    For every class, a seeded random `label_ratio` fraction (rounded down) of
    its examples is selected as labeled.

    Returns:
      labeled_split: Sorted list of int example indices in the labeled split.
    """
    log.info('Label split using seed {:d}'.format(self._seed))
    rnd = np.random.RandomState(self._seed)
    num_label_cls = len(self._label_str)
    num_ex = self._labels.shape[0]
    ex_ids = np.arange(num_ex)

    labeled_split = []
    for cc in range(num_label_cls):
      cids = ex_ids[self._labels == cc]
      rnd.shuffle(cids)
      labeled_split.extend(cids[:int(len(cids) * self._label_ratio)])
    log.info("Total number of classes {}".format(num_label_cls))
    log.info("Labeled split {}".format(len(labeled_split)))
    log.info("Total image {}".format(num_ex))
    return sorted(labeled_split)

  def next(self, within_category=False, catcode=None):
    """Gets a new episode.

    Args:
      within_category: Bool. Whether to restrict the episode's classes to a
        single general category. Requires the subclass to define
        `_category_labels`, `_category_label_str`, `_catcode_to_syncode` and
        `_syncode_to_str`. If True and `catcode` is None, a random category
        is chosen.
      catcode: String (e.g. 'n02795169'). If not None, the episode's classes
        are restricted to synsets of the general category with this code,
        even when `within_category` is False.

    Returns:
      An `Episode` holding support images/labels, query images/labels,
      unlabeled images (`x_unlabel`) and one flag per unlabeled image
      (`y_unlabel`): 1 if it comes from one of the N episode classes, 0 if it
      comes from a distractor class. Images are divided by 255 (presumably
      uint8 pixels scaled to [0, 1] -- depends on `get_images`).
    """
    if within_category or catcode is not None:
      assert hasattr(self, "_category_labels")
      assert hasattr(self, "_category_label_str")
      if catcode is None:
        # Choose a category for this episode's classes. Use the dataset's
        # seeded RNG (not the global one) so that `reset` makes episode
        # sampling reproducible.
        cat_idx = self._rnd.randint(len(self._category_label_str))
        # dict.keys() is not indexable in Python 3.
        catcode = list(self._catcode_to_syncode.keys())[cat_idx]
      cat_synsets = self._catcode_to_syncode[catcode]
      cat_synsets_str = [self._syncode_to_str[code] for code in cat_synsets]
      label_str_arr = np.array(self._label_str)
      allowable_inds = [
          np.where(label_str_arr == syn_str)[0] for syn_str in cat_synsets_str
      ]
      # Concatenate rather than stack: a synset matching zero or several class
      # names would otherwise produce a ragged (object) array.
      class_seq = (np.concatenate(allowable_inds)
                   if allowable_inds else np.zeros([0], dtype=np.int64))
    else:
      class_seq = np.arange(len(self._label_str))

    self._rnd.shuffle(class_seq)

    train_img_ids = []
    train_labels = []
    test_img_ids = []
    test_labels = []

    train_unlabel_img_ids = []
    non_distractor = []

    train_labels_str = []
    test_labels_str = []

    is_training = self._split in ["train", "trainval"]
    assert is_training or self._split in ["val", "test"]

    for ii in range(self._nway + self._num_distractor):
      cc = class_seq[ii]
      _ids = self._label_idict[cc]

      # Split the image IDs into labeled and unlabeled.
      _label_ids = [_id for _id in _ids if _id in self._label_split_idx_set]
      _unlabel_ids = [
          _id for _id in _ids if _id not in self._label_split_idx_set
      ]
      self._rnd.shuffle(_label_ids)
      self._rnd.shuffle(_unlabel_ids)

      # Add support set and query set (not for distractors).
      if ii < self._nway:
        train_img_ids.extend(_label_ids[:self._nshot])

        # Use the rest of the labeled image as queries, if num_test = -1.
        # The message is only formatted when the assertion fails.
        assert self._nshot + self._num_test <= len(_label_ids), (
            "Query + reference should be less than labeled examples." +
            "Num labeled {} Num test {} Num shot {}".format(
                len(_label_ids), self._num_test, self._nshot))

        if self._num_test == -1:
          if is_training:
            num_test = len(_label_ids) - self._nshot
          else:
            num_test = len(_label_ids) - self._nshot - self._num_unlabel
        else:
          num_test = self._num_test
          if is_training:
            assert num_test <= len(_label_ids) - self._nshot
          else:
            assert num_test <= len(_label_ids) - self._num_unlabel - self._nshot

        test_img_ids.extend(_label_ids[self._nshot:self._nshot + num_test])
        train_labels.extend([ii] * self._nshot)
        train_labels_str.extend([self._label_str[cc]] * self._nshot)
        test_labels.extend([ii] * num_test)
        test_labels_str.extend([self._label_str[cc]] * num_test)

      # Add unlabeled images here (distractor classes only contribute these).
      if is_training:
        # Use labeled, unlabeled split here for refinement.
        unlabel_ids = _unlabel_ids[:self._num_unlabel]
      else:
        # Copy test set for refinement.
        # This will only work if the test procedure is rolled out in a
        # sequence. NOTE: for distractor classes `num_test` is the value left
        # over from the last non-distractor class.
        start = self._nshot + num_test
        unlabel_ids = _label_ids[start:start + self._num_unlabel]
      train_unlabel_img_ids.extend(unlabel_ids)
      # One flag per unlabeled image actually taken (1 = episode class,
      # 0 = distractor) so `y_unlabel` stays aligned with `x_unlabel` even if
      # a class has fewer than `num_unlabel` candidates.
      non_distractor.extend([1 if ii < self._nway else 0] * len(unlabel_ids))

    train_img = self.get_images(train_img_ids) / 255.0
    train_unlabel_img = self.get_images(train_unlabel_img_ids) / 255.0
    test_img = self.get_images(test_img_ids) / 255.0
    train_labels = np.array(train_labels)
    test_labels = np.array(test_labels)
    train_labels_str = np.array(train_labels_str)
    test_labels_str = np.array(test_labels_str)
    non_distractor = np.array(non_distractor)

    # Unlabeled images must never include query images.
    test_ids_set = set(test_img_ids)
    for _id in train_unlabel_img_ids:
      assert _id not in test_ids_set

    if self._shuffle_episode:
      # Shuffle the sequence order in an episode. Very important for RNN based
      # meta learners. Each per-example array is permuted together with its
      # images so labels, label strings and distractor flags stay aligned.
      train_idx = np.arange(train_img.shape[0])
      self._rnd.shuffle(train_idx)
      train_img = train_img[train_idx]
      train_labels = train_labels[train_idx]
      train_labels_str = train_labels_str[train_idx]

      train_unlabel_idx = np.arange(train_unlabel_img.shape[0])
      self._rnd.shuffle(train_unlabel_idx)
      train_unlabel_img = train_unlabel_img[train_unlabel_idx]
      non_distractor = non_distractor[train_unlabel_idx]

      test_idx = np.arange(test_img.shape[0])
      self._rnd.shuffle(test_idx)
      test_img = test_img[test_idx]
      test_labels = test_labels[test_idx]
      test_labels_str = test_labels_str[test_idx]

    return Episode(
        train_img,
        train_labels,
        test_img,
        test_labels,
        x_unlabel=train_unlabel_img,
        y_unlabel=non_distractor,
        y_train_str=train_labels_str,
        y_test_str=test_labels_str)

  def reset(self):
    """Re-seeds the episode sampler so the episode sequence restarts."""
    self._rnd = np.random.RandomState(self._seed)

  def get_size(self):
    """Gets the size of the supervised portion."""
    return len(self._label_split_idx)

  def get_batch_idx(self, idx):
    """Gets a fully supervised training batch for classification.

    Args:
      idx: Int or int array. Positions within the labeled split.

    Returns: A tuple of
      x: Input image batch [N, H, W, C].
      y: Label class integer ID [N].
    """
    return self._images[self._label_split_idx[idx]], self._labels[
        self._label_split_idx[idx]]

  def get_batch_idx_test(self, idx):
    """Gets the test set (unlabeled set) for the fully supervised training.

    Args:
      idx: Int or int array. Positions within the unlabeled split.
    """
    return self._images[self._unlabel_split_idx[idx]], self._labels[
        self._unlabel_split_idx[idx]]

  @property
  def num_classes(self):
    """Number of classes in this split."""
    return self._num_classes