├── .gitignore ├── Makefile ├── README.md ├── config ├── base.yaml ├── data │ ├── 1_shot.yaml │ ├── 5_shot.yaml │ └── variable.yaml └── method │ ├── bd_cspn.yaml │ ├── finetune.yaml │ ├── laplacian_shot.yaml │ ├── maml.yaml │ ├── protonet.yaml │ ├── simpleshot.yaml │ ├── tim_adm.yaml │ └── tim_gd.yaml ├── github_figures ├── episode.png ├── inference_metrics.png └── training_metric.png ├── requirements.txt ├── setup.py └── src ├── datasets ├── __init__.py ├── config.py ├── dataset_spec.py ├── imagenet_specification.py ├── imagenet_stats.py ├── loader.py ├── original_meta_dataset │ ├── data │ │ ├── __init__.py │ │ ├── config.py │ │ ├── dataset_spec.py │ │ ├── decoder.py │ │ ├── decoder_test.py │ │ ├── homemade_transforms.py │ │ ├── imagenet_specification.py │ │ ├── imagenet_specification_test.py │ │ ├── imagenet_stats.py │ │ ├── learning_spec.py │ │ ├── pipeline.py │ │ ├── pipeline_test.py │ │ ├── providers.py │ │ ├── reader.py │ │ ├── reader_test.py │ │ ├── sampling.py │ │ ├── sampling_test.py │ │ └── test_utils.py │ └── learn │ │ └── gin │ │ ├── best │ │ ├── baseline_all.gin │ │ ├── baseline_imagenet.gin │ │ ├── baselinefinetune_all.gin │ │ ├── baselinefinetune_cosine_mini_imagenet_fiveshot.gin │ │ ├── baselinefinetune_cosine_mini_imagenet_oneshot.gin │ │ ├── baselinefinetune_imagenet.gin │ │ ├── baselinefinetune_mini_imagenet_fiveshot.gin │ │ ├── baselinefinetune_mini_imagenet_oneshot.gin │ │ ├── maml_all.gin │ │ ├── maml_all_from_scratch.gin │ │ ├── maml_imagenet.gin │ │ ├── maml_imagenet_from_scratch.gin │ │ ├── maml_init_with_proto_all.gin │ │ ├── maml_init_with_proto_all_from_scratch.gin │ │ ├── maml_init_with_proto_imagenet.gin │ │ ├── maml_init_with_proto_imagenet_from_scratch.gin │ │ ├── maml_init_with_proto_inference_all.gin │ │ ├── maml_init_with_proto_inference_imagenet.gin │ │ ├── matching_all.gin │ │ ├── matching_all_from_scratch.gin │ │ ├── matching_imagenet.gin │ │ ├── matching_imagenet_from_scratch.gin │ │ ├── matching_inference_all.gin │ │ ├── 
matching_inference_imagenet.gin │ │ ├── pretrain_imagenet_convnet.gin │ │ ├── pretrain_imagenet_resnet.gin │ │ ├── pretrain_imagenet_wide_resnet.gin │ │ ├── pretrained_convnet.gin │ │ ├── pretrained_resnet.gin │ │ ├── pretrained_wide_resnet.gin │ │ ├── prototypical_all.gin │ │ ├── prototypical_all_from_scratch.gin │ │ ├── prototypical_imagenet.gin │ │ ├── prototypical_imagenet_from_scratch.gin │ │ ├── prototypical_inference_all.gin │ │ ├── prototypical_inference_imagenet.gin │ │ ├── relationnet_all.gin │ │ ├── relationnet_all_from_scratch.gin │ │ ├── relationnet_imagenet.gin │ │ ├── relationnet_imagenet_from_scratch.gin │ │ └── relationnet_mini_imagenet_fiveshot.gin │ │ ├── default │ │ ├── baseline_all.gin │ │ ├── baseline_cosine_imagenet.gin │ │ ├── baseline_imagenet.gin │ │ ├── baseline_mini_imagenet_fiveshot.gin │ │ ├── baseline_mini_imagenet_oneshot.gin │ │ ├── baselinefinetune_all.gin │ │ ├── baselinefinetune_cosine_imagenet.gin │ │ ├── baselinefinetune_cosine_mini_imagenet_fiveshot.gin │ │ ├── baselinefinetune_cosine_mini_imagenet_oneshot.gin │ │ ├── baselinefinetune_imagenet.gin │ │ ├── baselinefinetune_mini_imagenet_fiveshot.gin │ │ ├── baselinefinetune_mini_imagenet_oneshot.gin │ │ ├── crosstransformer_imagenet.gin │ │ ├── crosstransformer_simclreps_imagenet.gin │ │ ├── debug_proto_fungi.gin │ │ ├── debug_proto_mini_imagenet.gin │ │ ├── maml_all.gin │ │ ├── maml_imagenet.gin │ │ ├── maml_init_with_proto_all.gin │ │ ├── maml_init_with_proto_imagenet.gin │ │ ├── maml_init_with_proto_inference_all.gin │ │ ├── maml_init_with_proto_inference_imagenet.gin │ │ ├── maml_init_with_proto_mini_imagenet_fiveshot.gin │ │ ├── maml_init_with_proto_mini_imagenet_oneshot.gin │ │ ├── maml_mini_imagenet_fiveshot.gin │ │ ├── maml_mini_imagenet_oneshot.gin │ │ ├── maml_protonet_all.gin │ │ ├── matching_all.gin │ │ ├── matching_imagenet.gin │ │ ├── matching_inference_all.gin │ │ ├── matching_inference_imagenet.gin │ │ ├── matching_mini_imagenet_fiveshot.gin │ │ ├── 
matching_mini_imagenet_oneshot.gin │ │ ├── pretrained_resnet34_224.gin │ │ ├── prototypical_all.gin │ │ ├── prototypical_imagenet.gin │ │ ├── prototypical_inference_all.gin │ │ ├── prototypical_inference_imagenet.gin │ │ ├── prototypical_mini_imagenet_fiveshot.gin │ │ ├── prototypical_mini_imagenet_oneshot.gin │ │ ├── relationnet_all.gin │ │ ├── relationnet_imagenet.gin │ │ ├── relationnet_mini_imagenet_fiveshot.gin │ │ ├── relationnet_mini_imagenet_oneshot.gin │ │ └── resnet34_stride16.gin │ │ ├── learners │ │ ├── baseline_config.gin │ │ ├── baseline_cosine_config.gin │ │ ├── baselinefinetune_config.gin │ │ ├── baselinefinetune_cosine_config.gin │ │ ├── crosstransformer_config.gin │ │ ├── learner_config.gin │ │ ├── maml_config.gin │ │ ├── maml_init_with_proto_config.gin │ │ ├── matching_config.gin │ │ ├── prototypical_config.gin │ │ └── relationnet_config.gin │ │ └── setups │ │ ├── all.gin │ │ ├── all_datasets.gin │ │ ├── data_config.gin │ │ ├── data_config_common.gin │ │ ├── data_config_feature.gin │ │ ├── data_config_no_decoder.gin │ │ ├── data_config_string.gin │ │ ├── fixed_way_and_shot.gin │ │ ├── imagenet.gin │ │ ├── mini_imagenet.gin │ │ ├── trainer_config.gin │ │ ├── trainer_config_debug.gin │ │ └── variable_way_and_shot.gin ├── pipeline.py ├── reader.py ├── sampling.py ├── tfrecord │ ├── __init__.py │ ├── example_pb2.py │ ├── iterator_utils.py │ ├── reader.py │ ├── tools │ │ ├── __init__.py │ │ └── tfrecord2idx.py │ ├── torch │ │ ├── __init__.py │ │ └── dataset.py │ └── writer.py ├── transform.py └── utils.py ├── eval.py ├── losses.py ├── methods ├── __init__.py ├── bd_cspn.py ├── finetune.py ├── maml.py ├── method.py ├── protonet.py ├── simpleshot.py ├── tim.py └── utils.py ├── metrics.py ├── metrics ├── __init__.py └── metric.py ├── models ├── ingredient.py ├── meta │ ├── __init__.py │ ├── conv4.py │ ├── metamodules │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── batchnorm.py │ │ ├── container.py │ │ ├── conv.py │ │ ├── linear.py │ │ ├── module.py │ │ 
├── normalization.py │ │ ├── parallel.py │ │ └── sparse.py │ ├── resnet.py │ └── wideres.py └── standard │ ├── __init__.py │ ├── conv4.py │ ├── helpers.py │ ├── hub.py │ ├── resnet.py │ ├── resnet_v2.py │ ├── vit.py │ └── wideres.py ├── optim.py ├── plot.py ├── train.py ├── utils.py └── utils ├── list_files.py └── produce_result_tables.py /.gitignore: -------------------------------------------------------------------------------- 1 | conversion/ 2 | **/make_splits/ 3 | checkpoints/ 4 | **/beluga/ 5 | results/ 6 | deploy/ 7 | # Sublime files 8 | *.sublime-project 9 | *.sublime-workspace 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | -------------------------------------------------------------------------------- /config/base.yaml: -------------------------------------------------------------------------------- 1 | TRAINING: 2 | print_freq: 100 3 | label_smoothing: 0.1 4 | num_updates: 75000 5 | ckpt_path: 'checkpoints' 6 | seed: 2021 7 | visdom_port: 0 8 | eval_freq: 500 9 | loss: 'xent' 10 | focal_gamma: 3 11 | debug: False 12 | 13 | AUGMENTATIONS: 14 | beta: 1.0 15 | cutmix_prob: 1.0 16 | augmentation: 'none' 17 | 18 | MODEL: 19 | load_from_timm: False 20 | arch: 'resnet18' 21 | 22 | OPTIM: 23 | gamma: 0.1 24 | lr: 0.1 25 | lr_stepsize: 30 26 | nesterov: False 27 | weight_decay: 0.0001 28 | optimizer_name: 'SGD' 29 | scheduler: 'multi_step' 30 | 31 | DATA: 32 | # Data config 33 | image_size: 126 34 | base_source: 'ilsvrc_2012' 35 | val_source: 'ilsvrc_2012' 36 | test_source: 'ilsvrc_2012' 37 | batch_size: 256 38 | split: 'train' 39 | num_workers: 6 40 | train_transforms: ['resize', 'to_tensor', 'normalize', 'gaussian', 'jitter'] 41 | test_transforms: ['resize', 'center_crop', 'to_tensor', 'normalize'] 42 | gaussian_noise_std: 0. 
43 | jitter_amount: 0 44 | num_unique_descriptions: 0 45 | shuffle: True 46 | loader_version: 'pytorch' 47 | 48 | EVAL-GENERAL: 49 | eval_metrics: ['Acc'] 50 | plot_freq: 10 51 | model_tag: 'best' 52 | eval_mode: 'test' 53 | val_episodes: 100 54 | val_batch_size: 1 55 | iter: 1 56 | extract_batch_size: 10 # set to >0 to batch feature extraction (save memory) at inference 57 | center_features: False 58 | 59 | 60 | EVAL-VISU: 61 | res_path: 'results/' 62 | max_s_visu: 2 63 | max_q_visu: 3 64 | max_class_visu: 5 65 | visu: False 66 | visu_freq: 10 67 | 68 | EVAL-EPISODES: 69 | num_ways: 10 70 | num_support: 10 71 | num_query: 15 72 | min_ways: 5 73 | max_ways_upper_bound: 50 74 | max_num_query: 10 75 | max_support_set_size: 500 76 | max_support_size_contrib_per_class: 100 77 | min_log_weight: -0.69314718055994529 78 | max_log_weight: 0.69314718055994529 79 | ignore_dag_ontology: False 80 | ignore_bilevel_ontology: False 81 | ignore_hierarchy_probability: 0 82 | min_examples_in_class: 0 83 | 84 | METHOD: 85 | episodic_training: False 86 | 87 | DISTRIBUTED: 88 | gpus: [0] -------------------------------------------------------------------------------- /config/data/1_shot.yaml: -------------------------------------------------------------------------------- 1 | EVAL-EPISODES: 2 | num_ways: 5 3 | num_support: 1 4 | num_query: 15 -------------------------------------------------------------------------------- /config/data/5_shot.yaml: -------------------------------------------------------------------------------- 1 | EVAL-EPISODES: 2 | num_ways: 5 3 | num_support: 5 4 | num_query: 15 -------------------------------------------------------------------------------- /config/data/variable.yaml: -------------------------------------------------------------------------------- 1 | EVAL-EPISODES: 2 | num_ways: 0 3 | num_support: 0 4 | num_query: 0 -------------------------------------------------------------------------------- /config/method/bd_cspn.yaml: 
-------------------------------------------------------------------------------- 1 | METHOD: 2 | method: 'BDCSPN' 3 | episodic_training: False 4 | temp: 10 5 | iter: 2 6 | hyperparams: ['method', 'arch', 'base_source', 'test_source', 'temp'] 7 | eval_metrics: ['Acc'] -------------------------------------------------------------------------------- /config/method/finetune.yaml: -------------------------------------------------------------------------------- 1 | METHOD: 2 | method: 'Finetune' 3 | episodic_training: False 4 | iter: 100 5 | finetune_lr: 0.01 6 | temp: 10 7 | finetune_all_layers: False 8 | eval_metrics: ['Acc', 'XEnt'] 9 | finetune_cosine_head: True 10 | finetune_weight_norm: True 11 | hyperparams: ['method', 'arch', 'base_source', 'test_source', 'finetune_all_layers'] 12 | -------------------------------------------------------------------------------- /config/method/laplacian_shot.yaml: -------------------------------------------------------------------------------- 1 | METHOD: 2 | method: 'LaplacianShot' 3 | episodic_training: False 4 | knn: 2 5 | lmd: 0.01 6 | proto_rect: 0 7 | normalize_features: True 8 | center_features: False 9 | merge_support: False 10 | iter: 20 11 | hyperparams: ['method', 'base_source', 'test_source', 'arch', 'knn', 'lmd', 'proto_rect'] 12 | eval_metrics: ['Acc', 'LaplacianEnergy'] -------------------------------------------------------------------------------- /config/method/maml.yaml: -------------------------------------------------------------------------------- 1 | TRAINING: 2 | episodic_training: True 3 | num_ways: 5 4 | num_support: 5 5 | num_query: 20 6 | val_batch_size: 1 7 | 8 | METHOD: 9 | method: 'MAML' 10 | temp: 1 11 | train_iter: 1 12 | iter: 10 13 | step_size: 0.4 14 | first_order: True 15 | hyperparams: ['method', 'arch', 'base_source', 'test_source', 'step_size', 'iter'] 16 | eval_metrics: ['Acc'] -------------------------------------------------------------------------------- /config/method/protonet.yaml: 
-------------------------------------------------------------------------------- 1 | METHOD: 2 | method: 'ProtoNet' 3 | episodic_training: True 4 | hyperparams: ['method', 'arch', 'base_source', 'test_source'] 5 | eval_metrics: ['Acc'] 6 | 7 | DATA: 8 | batch_size: 1 9 | image_size: 126 10 | gaussian_noise_std: 0.1533 11 | jitter_amount: 5 12 | num_unique_descriptions: 0 13 | shuffle: True 14 | 15 | TRAINING: 16 | num_updates: 75000 17 | 18 | OPTIM: 19 | gamma: 0.885 20 | lr: 0.0003 21 | lr_stepsize: 5000 22 | weight_decay: 0.0001 23 | optimizer_name: 'Adam' 24 | scheduler: 'step' -------------------------------------------------------------------------------- /config/method/simpleshot.yaml: -------------------------------------------------------------------------------- 1 | METHOD: 2 | method: 'SimpleShot' 3 | episodic_training: False 4 | eval_metrics: ['Acc'] 5 | hyperparams: ['method', 'arch', 'base_source', 'test_source'] 6 | 7 | DATA: 8 | batch_size: 256 9 | image_size: 126 10 | gaussian_noise_std: 0.1533 11 | jitter_amount: 5 12 | num_unique_descriptions: 0 13 | shuffle: True 14 | 15 | TRAINING: 16 | num_updates: 100000 17 | 18 | OPTIM: 19 | gamma: 0.885 20 | lr: 0.0001 21 | lr_stepsize: 5000 22 | weight_decay: 0.0001 23 | optimizer_name: 'Adam' 24 | scheduler: 'step' -------------------------------------------------------------------------------- /config/method/tim_adm.yaml: -------------------------------------------------------------------------------- 1 | TRAINING: 2 | episodic_training: False 3 | 4 | METHOD: 5 | method: 'TIM_ADM' 6 | iter: 200 7 | alpha: 1.0 8 | temp: 15 9 | loss_weights: [0.5, 10., 0.] 
# [Xent, H(Y), H(Y|X)] 10 | eval_metrics: ['Acc', 'MI', 'CondEnt', 'MargEnt', 'XEnt'] 11 | hyperparams: ['method', 'arch', 'base_source', 'test_source', 'loss_weights'] -------------------------------------------------------------------------------- /config/method/tim_gd.yaml: -------------------------------------------------------------------------------- 1 | TRAINING: 2 | episodic_training: False 3 | 4 | METHOD: 5 | method: 'TIM_GD' 6 | iter: 100 7 | tim_lr: 0.001 8 | temp: 10 9 | loss_weights: [1.0, 1.0, 0.5] # [Xent, H(Y), H(Y|X)] 10 | eval_metrics: ['Acc', 'MI', 'CondEnt', 'MargEnt', 'XEnt'] 11 | hyperparams: ['method', 'arch', 'base_source', 'test_source', 'loss_weights'] -------------------------------------------------------------------------------- /github_figures/episode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mboudiaf/pytorch-meta-dataset/0822b481a14ff159fad75349adb9ad66f6e0698e/github_figures/episode.png -------------------------------------------------------------------------------- /github_figures/inference_metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mboudiaf/pytorch-meta-dataset/0822b481a14ff159fad75349adb9ad66f6e0698e/github_figures/inference_metrics.png -------------------------------------------------------------------------------- /github_figures/training_metric.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mboudiaf/pytorch-meta-dataset/0822b481a14ff159fad75349adb9ad66f6e0698e/github_figures/training_metric.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | argon2-cffi==20.1.0 3 | ase==3.21.1 4 | astunparse==1.6.3 5 | async-generator==1.10 6 | attrs==20.3.0 
7 | backcall==0.2.0 8 | bleach==3.2.1 9 | bokeh==2.3.1 10 | cachetools==4.2.0 11 | certifi==2020.12.5 12 | cffi==1.14.4 13 | chardet==3.0.4 14 | click==7.1.2 15 | cloudpickle==1.6.0 16 | colorama==0.4.4 17 | colorcet==2.0.6 18 | cox==0.1.post3 19 | cycler==0.10.0 20 | dask==2021.4.1 21 | datashader==0.12.1 22 | datashape==0.5.2 23 | decorator==4.4.2 24 | defusedxml==0.6.0 25 | dill==0.3.3 26 | distributed==2021.4.1 27 | docopt==0.6.2 28 | entrypoints==0.3 29 | flatbuffers==1.12 30 | fsspec==2021.4.0 31 | gast==0.3.3 32 | gin-config==0.4.0 33 | gitdb==4.0.5 34 | GitPython==3.1.11 35 | google-auth==1.24.0 36 | google-auth-oauthlib==0.4.2 37 | google-pasta==0.2.0 38 | googledrivedownloader==0.4 39 | grpcio==1.32.0 40 | h5py==2.10.0 41 | HeapDict==1.0.1 42 | holoviews==1.14.3 43 | idna==2.10 44 | imageio==2.9.0 45 | ipykernel==5.4.2 46 | ipython==7.19.0 47 | ipython-genutils==0.2.0 48 | ipywidgets==7.5.1 49 | isodate==0.6.0 50 | jedi==0.17.2 51 | Jinja2==2.11.2 52 | joblib==1.0.0 53 | jsonpatch==1.28 54 | jsonpickle==1.4.2 55 | jsonpointer==2.0 56 | jsonschema==3.2.0 57 | jupyter==1.0.0 58 | jupyter-client==6.1.7 59 | jupyter-console==6.2.0 60 | jupyter-core==4.7.0 61 | jupyterlab-pygments==0.1.2 62 | Keras-Preprocessing==1.1.2 63 | kiwisolver==1.3.1 64 | littleutils==0.2.2 65 | llvmlite==0.35.0 66 | locket==0.2.1 67 | Markdown==3.3.3 68 | MarkupSafe==1.1.1 69 | matplotlib==3.3.3 70 | mistune==0.8.4 71 | msgpack==1.0.2 72 | multipledispatch==0.6.0 73 | munch==2.5.0 74 | nbclient==0.5.1 75 | nbconvert==6.0.7 76 | nbformat==5.0.8 77 | nest-asyncio==1.4.3 78 | networkx==2.5 79 | notebook==6.1.5 80 | numba==0.52.0 81 | numexpr==2.7.3 82 | numpy==1.19.4 83 | oauthlib==3.1.0 84 | ogb==1.2.4 85 | opencv-python==4.4.0.46 86 | opt-einsum==3.3.0 87 | outdated==0.2.0 88 | packaging==20.8 89 | pandas==1.2.0 90 | pandocfilters==1.4.3 91 | panel==0.11.3 92 | param==1.10.1 93 | parso==0.7.1 94 | partd==1.2.0 95 | pexpect==4.8.0 96 | pickleshare==0.7.5 97 | Pillow==8.0.1 98 | 
prometheus-client==0.9.0 99 | prompt-toolkit==3.0.8 100 | protobuf==3.14.0 101 | psutil==5.8.0 102 | ptyprocess==0.6.0 103 | py-cpuinfo==7.0.0 104 | py3nvml==0.2.6 105 | pyasn1==0.4.8 106 | pyasn1-modules==0.2.8 107 | pycparser==2.20 108 | pyct==0.4.8 109 | Pygments==2.7.3 110 | pynndescent==0.5.2 111 | pyparsing==2.4.7 112 | pyrsistent==0.17.3 113 | python-dateutil==2.8.1 114 | python-louvain==0.15 115 | pytz==2020.5 116 | pyviz-comms==2.0.1 117 | PyWavelets==1.1.1 118 | PyYAML==5.3.1 119 | pyzmq==20.0.0 120 | qtconsole==5.0.1 121 | QtPy==1.9.0 122 | rdflib==5.0.0 123 | requests==2.25.0 124 | requests-oauthlib==1.3.0 125 | rsa==4.6 126 | sacred==0.8.2 127 | scikit-image==0.18.1 128 | scikit-learn==0.24.1 129 | scipy==1.5.4 130 | seaborn==0.11.1 131 | Send2Trash==1.5.0 132 | six==1.15.0 133 | smmap==3.0.4 134 | sortedcontainers==2.3.0 135 | tables==3.6.1 136 | tblib==1.7.0 137 | tensorboard==2.4.0 138 | tensorboard-plugin-wit==1.7.0 139 | tensorboardX==2.2 140 | tensorflow==2.4.0 141 | tensorflow-estimator==2.4.0 142 | termcolor==1.1.0 143 | terminado==0.9.1 144 | testpath==0.4.4 145 | tfrecord==1.12 146 | threadpoolctl==2.1.0 147 | tifffile==2021.4.8 148 | toolz==0.11.1 149 | torch==1.7.1+cu110 150 | torch-geometric==1.6.3 151 | torch-scatter==2.0.5 152 | torchaudio==0.7.2 153 | torchfile==0.1.0 154 | torchvision==0.8.2+cu110 155 | tornado==6.1 156 | tqdm==4.54.1 157 | traitlets==5.0.5 158 | typing-extensions==3.7.4.3 159 | umap-learn==0.5.1 160 | urllib3==1.26.2 161 | visdom==0.1.8.9 162 | visdom-logger @ git+https://github.com/luizgh/visdom_logger.git@3c0afb18321ef65382d0baf8866772c6a62f15a2 163 | wcwidth==0.2.5 164 | webencodings==0.5.1 165 | websocket-client==0.57.0 166 | Werkzeug==1.0.1 167 | widgetsnbextension==3.5.1 168 | wrapt==1.12.1 169 | xarray==0.17.0 170 | xmltodict==0.12.0 171 | zict==2.0.0 172 | -------------------------------------------------------------------------------- /setup.py: 
"""Packaging script for pytorch-meta-dataset.

refs:
    - setup tools: https://setuptools.pypa.io/en/latest/userguide/package_discovery.html#using-find-or-find-packages
    - https://stackoverflow.com/questions/70295885/how-does-one-install-pytorch-and-related-tools-from-within-the-setup-py-install
"""
import os

from setuptools import find_packages
from setuptools import setup

# The README is published as the long description of the package.
here = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='pytorch-meta-dataset',  # project name
    version='0.0.1',
    description="Brando and Patrick's pytorch-meta-dataset",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url='https://github.com/brando90/pytorch-meta-dataset',
    author='Brando Miranda',
    author_email='brandojazz@gmail.com',
    python_requires='>=3.9.0',
    license='MIT',

    # NOTE(review): packages are discovered under ./pytorch_meta_dataset;
    # confirm that directory exists in this checkout (the sources may live
    # under src/ instead).
    package_dir={'': 'pytorch_meta_dataset'},
    # Picks up every folder that contains an __init__.py.
    packages=find_packages('pytorch_meta_dataset'),

    # For pytorch itself, see the second reference in the module docstring.
    # NOTE(review): these pins differ from requirements.txt (tfrecord==1.12,
    # torchvision==0.8.2+cu110), and the '+cu111' local version presumably
    # needs an extra wheel index at install time -- confirm.
    install_requires=[
        'absl-py==0.11.0',
        'tfrecord==1.11',
        'torchvision==0.10.1+cu111',
        'tqdm==4.54.1',
    ],
)

# ---------------------------------------------------------------------------
# Repository-dump markers that share this source line (kept as-is):
# /src/datasets/__init__.py:
#   https://raw.githubusercontent.com/mboudiaf/pytorch-meta-dataset/0822b481a14ff159fad75349adb9ad66f6e0698e/src/datasets/__init__.py
# /src/datasets/loader.py:
# ---------------------------------------------------------------------------
import argparse
from pathlib import Path
from functools import partial
from typing import Any, Callable, Iterable, Tuple, Union, cast

import torch
import numpy as np
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from loguru import logger

from .utils import Split
from .pipeline import worker_init_fn_
from . import config as config_lib
from . import pipeline as torch_pipeline
from . import dataset_spec as dataset_spec_lib


# A loader is either a PyTorch DataLoader or any iterable of tensor tuples
# (e.g. the TensorFlow-backed generator returned by `infinite_loader`).
DL = Union[DataLoader,
           Iterable[Tuple[Tensor, ...]]]


def get_dataspecs(args: argparse.Namespace,
                  source: str) -> Tuple[Any, Any, Any]:
    """Build the dataset specification and the sampling configs of a source.

    Args:
        args: Parsed options, forwarded to `DataConfig` and
            `EpisodeDescriptionConfig`.
        source: Name of the dataset; also the sub-directory of
            `data_config.path` from which the specification is loaded.

    Returns:
        The tuple `(dataset_spec, data_config, episod_config)`.
    """
    data_config = config_lib.DataConfig(args=args)
    episod_config = config_lib.EpisodeDescriptionConfig(args=args)

    # Enable ontology aware sampling for Omniglot and ImageNet.
    episod_config.use_bilevel_ontology = (source == 'omniglot')
    episod_config.use_dag_ontology = source in ('ilsvrc_2012', 'ilsvrc_2012_v2')

    dataset_records_path: Path = data_config.path / source
    # The original code handles paths as strings.
    dataset_spec = dataset_spec_lib.load_dataset_spec(str(dataset_records_path))

    return dataset_spec, data_config, episod_config


def get_dataloader(args: argparse.Namespace,
                   source: str,
                   batch_size: int,
                   split: Split,
                   world_size: int,
                   version: str,
                   episodic: bool):
    """Create a data loader over one split of one source.

    Args:
        args: Parsed options; `args.seed` seeds the PyTorch workers.
        source: Name of the dataset to load.
        batch_size: Global batch size; each process gets
            `int(batch_size / world_size)` samples.
        split: Split whose classes are served.
        world_size: Number of processes sharing the global batch.
        version: `'pytorch'` for the native pipeline, `'tf'` for the original
            TensorFlow pipeline wrapped in a generator.
        episodic: Whether to serve episodes instead of plain batches.

    Returns:
        The tuple `(data_loader, num_classes)`, where `num_classes` is the
        number of classes of `split` in the dataset specification.

    Raises:
        ValueError: If `version` is neither `'pytorch'` nor `'tf'`.
    """
    dataset_spec, data_config, episod_config = get_dataspecs(args, source)
    num_classes = len(dataset_spec.get_classes(split=split))

    pipeline_fn: Callable[..., Dataset]
    data_loader: DL
    if version == 'pytorch':
        pipeline_fn = cast(Callable[..., Dataset],
                           (torch_pipeline.make_episode_pipeline if episodic
                            else torch_pipeline.make_batch_pipeline))
        dataset: Dataset = pipeline_fn(dataset_spec=dataset_spec,
                                       data_config=data_config,
                                       split=split,
                                       episode_descr_config=episod_config)

        worker_init_fn = partial(worker_init_fn_, seed=args.seed)
        data_loader = DataLoader(dataset=dataset,
                                 batch_size=int(batch_size / world_size),
                                 num_workers=data_config.num_workers,
                                 worker_init_fn=worker_init_fn)
    elif version == 'tf':
        # Imported lazily so the PyTorch path does not require these packages.
        import gin
        import tensorflow as tf
        from .original_meta_dataset.data import pipeline as tf_pipeline

        tf.compat.v1.disable_eager_execution()
        tf_pipeline_fn: Callable[..., Any]
        tf_pipeline_fn = (tf_pipeline.make_one_source_episode_pipeline if episodic
                          else tf_pipeline.make_one_source_batch_pipeline)

        GIN_FILE_PATH = 'src/datasets/original_meta_dataset/learn/gin/setups/data_config.gin'
        gin.parse_config_file(GIN_FILE_PATH)

        # NOTE(review): the image size is fixed to 84 here, independently of
        # the configured DATA.image_size -- confirm this is intended.
        tf_dataset: Any = tf_pipeline_fn(
            dataset_spec=dataset_spec,
            use_dag_ontology=episod_config.use_dag_ontology,
            use_bilevel_ontology=episod_config.use_bilevel_ontology,
            episode_descr_config=episod_config,
            split=split,
            batch_size=int(batch_size / world_size),
            image_size=84,
            shuffle_buffer_size=300)

        iterator: Iterable = tf_dataset.make_one_shot_iterator().get_next()

        # Per-channel statistics, shaped to broadcast over (N, 3, H, W).
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

        session = tf.compat.v1.Session()

        data_loader = infinite_loader(session, iterator, episodic, mean, std)
    else:
        # Adjacent literals instead of a backslash continuation, which would
        # embed the source indentation in the message.
        raise ValueError(f"Wrong loader version, got {version}, "
                         "expected to be in ['pytorch', 'tf']")

    return data_loader, num_classes


def to_torch_imgs(img: np.ndarray, mean: Tensor, std: Tensor) -> Tensor:
    """Convert an image batch to a normalized tensor.

    Args:
        img: 4-D array; its last axis is moved to position 1.
        mean: Value subtracted from the transposed images (broadcast).
        std: Value the centered images are divided by (broadcast).

    Returns:
        A new tensor `(transposed_img - mean) / std`; `img` is left untouched.
    """
    t_img: Tensor = torch.from_numpy(np.transpose(img, (0, 3, 1, 2)))
    # `from_numpy` shares memory with `img`, so in-place arithmetic would
    # overwrite the caller's array; compute out of place instead.
    return (t_img - mean) / std


def to_torch_labels(a: np.ndarray) -> Tensor:
    """Convert an array of labels to an int64 (`long`) tensor."""
    return torch.from_numpy(a).long()


def infinite_loader(session: Any, iterator: Iterable, episodic: bool,
                    mean: Tensor, std: Tensor) -> Iterable[Tuple[Tensor, ...]]:
    """Endlessly yield tensors obtained by running `iterator` in `session`.

    Args:
        session: Object whose `run(iterator)` returns a pair `(e, source_id)`.
        iterator: Handle evaluated by `session.run` at every step.
        episodic: If True, yield four tensors per step, otherwise two.
        mean: Forwarded to `to_torch_imgs`.
        std: Forwarded to `to_torch_imgs`.

    Yields:
        Episodic: images from `e[0]` and `e[3]`, then labels from `e[1]` and
        `e[4]`, each with a leading dimension of size 1 added (presumably
        support/query images and labels -- confirm against the TF pipeline).
        Otherwise: `(images, labels)` from `e[0]` and `e[1]`.
    """
    while True:
        (e, source_id) = session.run(iterator)
        if episodic:
            yield (to_torch_imgs(e[0], mean, std).unsqueeze(0),
                   to_torch_imgs(e[3], mean, std).unsqueeze(0),
                   to_torch_labels(e[1]).unsqueeze(0),
                   to_torch_labels(e[4]).unsqueeze(0))
        else:
            yield (to_torch_imgs(e[0], mean, std),
                   to_torch_labels(e[1]))

# ---------------------------------------------------------------------------
# Repository-dump marker that shares these source lines (kept as-is):
# /src/datasets/original_meta_dataset/data/__init__.py:
# ---------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Sub-module for reading data and assembling episodes.""" 18 | # Whether datasets with example-level splits, or pools, are supported. 19 | # Currently, this is not implemented. 20 | POOL_SUPPORTED = False 21 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/data/decoder_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Meta-Dataset Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
# Lint as: python3
"""Tests for meta_dataset.data.decoder."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from meta_dataset.data import decoder
from meta_dataset.dataset_conversion import dataset_to_records
import numpy as np
import tensorflow.compat.v1 as tf


class DecoderTest(tf.test.TestCase):
    """Round-trip checks for the decoders of `meta_dataset.data.decoder`."""

    def _random_png_example(self, side):
        """Return a random `side` x `side` RGB image and its PNG example."""
        pixels = np.random.randint(
            low=0, high=255, size=[side, side, 3]).astype(np.ubyte)

        # Encode
        png_bytes = dataset_to_records.encode_image(pixels, image_format='PNG')
        zero_label = np.zeros(1).astype(np.int64)
        features = [
            ('image', 'bytes', [png_bytes]),
            ('label', 'int64', [zero_label]),
        ]
        return pixels, dataset_to_records.make_example(features)

    def test_string_decoder(self):
        """`StringDecoder` output decodes back to the original pixels."""
        pixels, example = self._random_png_example(32)

        # Decode
        to_string = decoder.StringDecoder()
        decode_op = tf.image.decode_image(to_string(example))

        # Assert perfect reconstruction.
        with self.session(use_gpu=False) as sess:
            restored = sess.run(decode_op)
        self.assertAllClose(pixels, restored)

    def test_image_decoder(self):
        """`ImageDecoder` returns the pixels rescaled to [-1, 1]."""
        side = 84
        pixels, example = self._random_png_example(side)

        # Decode
        to_image = decoder.ImageDecoder(image_size=side)
        decode_op = to_image(example)

        # Assert perfect reconstruction.
        with self.session(use_gpu=False) as sess:
            restored = sess.run(decode_op)
        expected = 2 * (pixels.astype(np.float32) / 255.0 - 0.5)
        self.assertAllClose(expected, restored)

    def test_feature_decoder(self):
        """`FeatureDecoder` returns the stored embedding unchanged."""
        # Make random feature.
        length = 64
        embedding = np.random.randn(length).astype(np.float32)
        zero_label = np.zeros(1).astype(np.int64)

        # Encode
        features = [
            ('image/embedding', 'float32', embedding),
            ('image/class/label', 'int64', [zero_label]),
        ]
        example = dataset_to_records.make_example(features)

        # Decode
        to_feature = decoder.FeatureDecoder(feat_len=length)
        decode_op = to_feature(example)

        # Assert perfect reconstruction.
        with self.session(use_gpu=False) as sess:
            restored = sess.run(decode_op)
        self.assertAllEqual(restored, embedding)


if __name__ == '__main__':
    tf.test.main()

# ---------------------------------------------------------------------------
# Repository-dump content that shares these source lines (kept as-is):
# /src/datasets/original_meta_dataset/data/learning_spec.py:
# ---------------------------------------------------------------------------
# coding=utf-8
# Copyright 2021 The Meta-Dataset Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Interfaces for learning specifications."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import enum


@enum.unique
class Split(enum.Enum):
  """The possible data splits."""
  TRAIN = 0
  VALID = 1
  TEST = 2


class BatchSpecification(
    collections.namedtuple('BatchSpecification', 'split, batch_size')):
  """The specification of a batch.

  Args:
    split: the Split from which to pick data.
    batch_size: an int, the number of (image, label) pairs in the batch.
  """


class EpisodeSpecification(
    collections.namedtuple(
        'EpisodeSpecification',
        'split, num_classes, num_train_examples, num_test_examples')):
  """The specification of an episode.

  Args:
    split: A Split from which to pick data.
    num_classes: The number of classes in the episode, or None for variable.
    num_train_examples: The number of examples to use per class in the train
      phase, or None for variable.
    num_test_examples: the number of examples to use per class in the test
      phase, or None for variable.
  """

# --------------------------------------------------------------------------------
# /src/datasets/original_meta_dataset/data/pipeline_test.py:
# --------------------------------------------------------------------------------
# 1 | # coding=utf-8
# 2 | # Copyright 2021 The Meta-Dataset Authors.
# 3 | #
# 4 | # Licensed under the Apache License, Version 2.0 (the "License");
# 5 | # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
r"""Tests for meta_dataset.data.pipeline.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import gin
from meta_dataset.data import config
from meta_dataset.data import learning_spec
from meta_dataset.data import pipeline
from meta_dataset.data import test_utils
from meta_dataset.data.dataset_spec import DatasetSpecification
import numpy as np
import tensorflow.compat.v1 as tf


class PipelineTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the multi-source episode pipeline on synthetic feature records."""

  def _assert_features_have_class(self, class_features, features, class_ids):
    """Checks every feature vector was written under its reported class.

    Args:
      class_features: A [num_classes, num_examples, feat_size] array of the
        features written to disk, indexed by class first.
      features: The feature vectors produced by the pipeline.
      class_ids: The class ids produced by the pipeline, aligned with
        `features`.
    """
    for feat, class_id in zip(list(features), list(class_ids)):
      # Sum the absolute differences. Taking abs() of the signed sum would let
      # positive and negative differences cancel, so a vector that differs
      # from `feat` could be reported as a match.
      abs_err = np.sum(np.abs(class_features - feat[None][None]), axis=-1)
      # Make sure the feature is present in the original data.
      self.assertEqual(abs_err.min(), 0.0)
      found_class_id = np.where(abs_err == 0.0)[0][0]
      self.assertEqual(found_class_id, class_id)

  def _assert_examples_have_class(self, class_examples, examples, class_ids):
    """Checks every serialized example was written under its reported class.

    Args:
      class_examples: A [num_classes, num_examples] array of the serialized
        examples written to disk, indexed by class first.
      examples: The serialized examples produced by the pipeline.
      class_ids: The class ids produced by the pipeline, aligned with
        `examples`.
    """
    for example, class_id in zip(list(examples), list(class_ids)):
      matching_classes = np.where(class_examples == example)[0]
      # Report a missing example as a test failure instead of an IndexError.
      self.assertGreater(
          len(matching_classes), 0,
          'Example returned by the pipeline was not found in the written '
          'records.')
      self.assertEqual(matching_classes[0], class_id)

  @parameterized.named_parameters(
      ('FeatureDecoder', 'feature',
       'third_party/py/meta_dataset/learn/gin/setups/data_config_feature.gin'),
      ('NoDecoder', 'none',
       'third_party/py/meta_dataset/learn/gin/setups/data_config_no_decoder.gin'
      ),
  )
  def test_make_multisource_episode_pipeline_feature(self, decoder_type,
                                                     config_file_path):
    """Writes per-class feature records and reads them back as an episode.

    Args:
      decoder_type: Either 'feature' or 'none'; selects which check is run on
        the episode contents.
      config_file_path: Path of the gin file parsed before the pipeline is
        built.
    """
    # Create some feature records and write them to a temp directory.
    feat_size = 64
    num_examples = 100
    num_classes = 10
    output_path = self.get_temp_dir()
    gin.parse_config_file(config_file_path)

    # 1-Write feature records to temp directory.
    self.rng = np.random.RandomState(0)
    class_features = []
    class_examples = []
    for class_id in range(num_classes):
      features = self.rng.randn(num_examples, feat_size).astype(np.float32)
      label = np.array(class_id).astype(np.int64)
      output_file = os.path.join(output_path, str(class_id) + '.tfrecords')
      examples = test_utils.write_feature_records(features, label, output_file)
      class_examples.append(examples)
      class_features.append(features)
    class_examples = np.stack(class_examples)
    class_features = np.stack(class_features)

    # 2-Read records back using multi-source pipeline.
    # DatasetSpecification to use in tests
    dataset_spec = DatasetSpecification(
        name=None,
        classes_per_split={
            learning_spec.Split.TRAIN: 5,
            learning_spec.Split.VALID: 2,
            learning_spec.Split.TEST: 3
        },
        images_per_class={i: num_examples for i in range(num_classes)},
        class_names=None,
        path=output_path,
        file_pattern='{}.tfrecords')

    # Duplicate the dataset to simulate reading from multiple datasets.
    use_bilevel_ontology_list = [False] * 2
    use_dag_ontology_list = [False] * 2
    all_dataset_specs = [dataset_spec] * 2

    fixed_ways_shots = config.EpisodeDescriptionConfig(
        num_query=5, num_support=5, num_ways=5)

    dataset_episodic = pipeline.make_multisource_episode_pipeline(
        dataset_spec_list=all_dataset_specs,
        use_dag_ontology_list=use_dag_ontology_list,
        use_bilevel_ontology_list=use_bilevel_ontology_list,
        episode_descr_config=fixed_ways_shots,
        split=learning_spec.Split.TRAIN,
        image_size=None)

    episode, _ = self.evaluate(
        dataset_episodic.make_one_shot_iterator().get_next())

    # Entries 0 and 3 of the episode are the support and query items; entries
    # 2 and 5 are the class ids they are checked against.
    support_items, support_class_ids = episode[0], episode[2]
    query_items, query_class_ids = episode[3], episode[5]

    if decoder_type == 'feature':
      # 3-Check that support and query features are in class_features and have
      # the correct corresponding label.
      self._assert_features_have_class(class_features, support_items,
                                       support_class_ids)
      self._assert_features_have_class(class_features, query_items,
                                       query_class_ids)

    elif decoder_type == 'none':
      # 3-Check that support and query examples are in class_examples and have
      # the correct corresponding label.
      self._assert_examples_have_class(class_examples, support_items,
                                       support_class_ids)
      self._assert_examples_have_class(class_examples, query_items,
                                       query_class_ids)


if __name__ == '__main__':
  tf.test.main()

# --------------------------------------------------------------------------------
# /src/datasets/original_meta_dataset/data/test_utils.py:
# --------------------------------------------------------------------------------
# 1 | # coding=utf-8
# 2 | # Copyright 2021 The Meta-Dataset Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Utility functions for input pipeline tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gin.tf
from meta_dataset.data.dataset_spec import DatasetSpecification
from meta_dataset.data.learning_spec import Split
from meta_dataset.dataset_conversion import dataset_to_records
import numpy as np
import tensorflow.compat.v1 as tf

# DatasetSpecification to use in tests
DATASET_SPEC = DatasetSpecification(
    name=None,
    classes_per_split={
        Split.TRAIN: 15,
        Split.VALID: 5,
        Split.TEST: 10
    },
    images_per_class=dict(enumerate([10, 20, 30] * 10)),
    class_names=['%d' % i for i in range(30)],
    path=None,
    file_pattern='{}.tfrecords')

# Define defaults for the input pipeline.
MIN_WAYS = 5
MAX_WAYS_UPPER_BOUND = 50
MAX_NUM_QUERY = 10
MAX_SUPPORT_SET_SIZE = 500
MAX_SUPPORT_SIZE_CONTRIB_PER_CLASS = 100
MIN_LOG_WEIGHT = np.log(0.5)
MAX_LOG_WEIGHT = np.log(2)


def set_episode_descr_config_defaults():
  """Sets default values for EpisodeDescriptionConfig using gin.

  In the default scope, `num_ways`, `num_support` and `num_query` are bound to
  None and the remaining bounds are bound to the module-level constants above.
  In the `none/` scope the same bounds are all bound to None.
  """
  gin.parse_config('import meta_dataset.data.config')

  gin.bind_parameter('EpisodeDescriptionConfig.num_ways', None)
  gin.bind_parameter('EpisodeDescriptionConfig.num_support', None)
  gin.bind_parameter('EpisodeDescriptionConfig.num_query', None)
  gin.bind_parameter('EpisodeDescriptionConfig.min_ways', MIN_WAYS)
  gin.bind_parameter('EpisodeDescriptionConfig.max_ways_upper_bound',
                     MAX_WAYS_UPPER_BOUND)
  gin.bind_parameter('EpisodeDescriptionConfig.max_num_query', MAX_NUM_QUERY)
  gin.bind_parameter('EpisodeDescriptionConfig.max_support_set_size',
                     MAX_SUPPORT_SET_SIZE)
  gin.bind_parameter(
      'EpisodeDescriptionConfig.max_support_size_contrib_per_class',
      MAX_SUPPORT_SIZE_CONTRIB_PER_CLASS)
  gin.bind_parameter('EpisodeDescriptionConfig.min_log_weight', MIN_LOG_WEIGHT)
  gin.bind_parameter('EpisodeDescriptionConfig.max_log_weight', MAX_LOG_WEIGHT)
  gin.bind_parameter('EpisodeDescriptionConfig.ignore_dag_ontology', False)
  gin.bind_parameter('EpisodeDescriptionConfig.ignore_bilevel_ontology', False)
  gin.bind_parameter('EpisodeDescriptionConfig.ignore_hierarchy_probability',
                     0.)
  gin.bind_parameter('EpisodeDescriptionConfig.simclr_episode_fraction', 0.)

  # Following is set in a different scope.
  gin.bind_parameter('none/EpisodeDescriptionConfig.min_ways', None)
  gin.bind_parameter('none/EpisodeDescriptionConfig.max_ways_upper_bound', None)
  gin.bind_parameter('none/EpisodeDescriptionConfig.max_num_query', None)
  gin.bind_parameter('none/EpisodeDescriptionConfig.max_support_set_size', None)
  gin.bind_parameter(
      'none/EpisodeDescriptionConfig.max_support_size_contrib_per_class', None)
  gin.bind_parameter('none/EpisodeDescriptionConfig.min_log_weight', None)
  gin.bind_parameter('none/EpisodeDescriptionConfig.max_log_weight', None)


def write_feature_records(features, label, output_path):
  """Creates a record file from features and labels.

  Args:
    features: An [n, m] numpy array of features.
    label: An integer, the label common to all records.
    output_path: A string specifying the location of the record.

  Returns:
    serialized_examples: list of the serialized examples written by the
      writer, one per row of `features`, in the order they were written.
  """
  serialized_examples = []
  # The context manager closes the writer even when building or writing an
  # example raises, so the file handle is never leaked.
  with tf.python_io.TFRecordWriter(output_path) as writer:
    for feat in features:
      # Write the example.
      serialized_example = dataset_to_records.make_example([
          ('image/embedding', 'float32', feat.tolist()),
          ('image/class/label', 'int64', [label])
      ])
      writer.write(serialized_example)
      serialized_examples.append(serialized_example)

  return serialized_examples

# --------------------------------------------------------------------------------
# /src/datasets/original_meta_dataset/learn/gin/best/baseline_all.gin:
# --------------------------------------------------------------------------------
# 1 | include 'meta_dataset/learn/gin/setups/all.gin'
# 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin'
# 3 | BatchSplitReaderGetReader.add_dataset_offset = True
# 4 |
# 5 | # Backbone hypers.
# 6 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin'
# 7 |
# 8 | # Model hypers.
9 | BaselineLearner.knn_distance = 'cosine' 10 | BaselineLearner.cosine_classifier = True 11 | BaselineLearner.cosine_logits_multiplier = 1 12 | BaselineLearner.use_weight_norm = True 13 | 14 | # Data hypers. 15 | DataConfig.image_height = 126 16 | 17 | # Training hypers (not needed for eval). 18 | Trainer.decay_every = 500 19 | Trainer.decay_rate = 0.8778059962506467 20 | Trainer.learning_rate = 0.000253906846867988 21 | weight_decay = 0.00002393929026012612 22 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/baseline_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @wide_resnet 6 | weight_decay = 0.000007138118976497546 7 | 8 | Trainer.checkpoint_to_restore = '' 9 | Trainer.pretrained_source = 'scratch' 10 | 11 | # Model hypers. 12 | BaselineLearner.knn_distance = 'cosine' 13 | BaselineLearner.cosine_classifier = False 14 | BaselineLearner.cosine_logits_multiplier = 10 15 | BaselineLearner.use_weight_norm = False 16 | 17 | # Data hypers. 18 | DataConfig.image_height = 126 19 | 20 | # Training hypers (not needed for eval). 21 | Trainer.decay_every = 10000 22 | Trainer.decay_rate = 0.7294597641152971 23 | Trainer.learning_rate = 0.007634189137886614 24 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/baselinefinetune_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 3 | BatchSplitReaderGetReader.add_dataset_offset = True 4 | 5 | # Backbone hypers. 
6 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin' 7 | 8 | # Model hypers. 9 | BaselineLearner.cosine_classifier = False 10 | BaselineLearner.use_weight_norm = True 11 | BaselineLearner.cosine_logits_multiplier = 1 12 | BaselineFinetuneLearner.num_finetune_steps = 200 13 | BaselineFinetuneLearner.finetune_lr = 0.01 14 | BaselineFinetuneLearner.finetune_all_layers = True 15 | BaselineFinetuneLearner.finetune_with_adam = True 16 | 17 | # Data hypers. 18 | DataConfig.image_height = 84 19 | 20 | # Training hypers (not needed for eval). 21 | Trainer.decay_every = 5000 22 | Trainer.decay_rate = 0.5559080744371039 23 | Trainer.learning_rate = 0.0027015533546616804 24 | weight_decay = 0.00002266979856832968 25 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/baselinefinetune_cosine_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin' 2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin' 3 | 4 | Learner.embedding_fn = @wide_resnet 5 | weight_decay = 0 6 | 7 | DataConfig.image_height = 84 8 | Trainer.pretrained_source = 'scratch' 9 | 10 | BaselineFinetuneLearner.num_finetune_steps = 200 11 | BaselineFinetuneLearner.finetune_lr = 0.01 12 | BaselineFinetuneLearner.finetune_with_adam = False 13 | BaselineLearner.use_weight_norm = False 14 | BaselineFinetuneLearner.finetune_all_layers = False 15 | BaselineLearner.cosine_logits_multiplier = 10 16 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/baselinefinetune_cosine_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin' 2 | 
EpisodeDescriptionConfig.num_support = 1 # Change 5-shot to 1-shot. 3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin' 4 | 5 | Learner.embedding_fn = @wide_resnet 6 | weight_decay = 0 7 | 8 | DataConfig.image_height = 84 9 | Trainer.pretrained_source = 'scratch' 10 | 11 | BaselineFinetuneLearner.num_finetune_steps = 200 12 | BaselineFinetuneLearner.finetune_lr = 0.01 13 | BaselineFinetuneLearner.finetune_with_adam = False 14 | BaselineLearner.use_weight_norm = False 15 | BaselineFinetuneLearner.finetune_all_layers = False 16 | BaselineLearner.cosine_logits_multiplier = 10 17 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/baselinefinetune_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin' 6 | 7 | # Model hypers. 8 | BaselineLearner.cosine_classifier = False 9 | BaselineLearner.use_weight_norm = True 10 | BaselineLearner.cosine_logits_multiplier = 1 11 | BaselineFinetuneLearner.num_finetune_steps = 200 12 | BaselineFinetuneLearner.finetune_lr = 0.01 13 | BaselineFinetuneLearner.finetune_all_layers = True 14 | BaselineFinetuneLearner.finetune_with_adam = True 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 84 18 | 19 | # Training hypers (not needed for eval). 
20 | Trainer.decay_every = 5000 21 | Trainer.decay_rate = 0.5559080744371039 22 | Trainer.learning_rate = 0.0027015533546616804 23 | weight_decay = 0.00002266979856832968 24 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/baselinefinetune_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 5 | 6 | Learner.embedding_fn = @wide_resnet 7 | weight_decay = 0 8 | 9 | DataConfig.image_height = 84 10 | Trainer.pretrained_source = 'scratch' 11 | 12 | BaselineFinetuneLearner.num_finetune_steps = 200 13 | BaselineFinetuneLearner.finetune_lr = 0.01 14 | BaselineFinetuneLearner.finetune_with_adam = False 15 | BaselineLearner.use_weight_norm = False 16 | BaselineFinetuneLearner.finetune_all_layers = False 17 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/baselinefinetune_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin' 2 | EpisodeDescriptionConfig.num_support = 1 # Change 5-shot to 1-shot. 
3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 4 | 5 | Learner.embedding_fn = @wide_resnet 6 | weight_decay = 0 7 | 8 | DataConfig.image_height = 84 9 | Trainer.pretrained_source = 'scratch' 10 | 11 | BaselineFinetuneLearner.num_finetune_steps = 200 12 | BaselineFinetuneLearner.finetune_lr = 0.01 13 | BaselineFinetuneLearner.finetune_with_adam = False 14 | BaselineLearner.use_weight_norm = False 15 | BaselineFinetuneLearner.finetune_all_layers = False 16 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/maml_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Model hypers. (new results) 8 | MAMLLearner.first_order = True 9 | MAMLLearner.alpha = 0.09915203576061224 10 | MAMLLearner.additional_evaluation_update_steps = 5 11 | MAMLLearner.num_update_steps = 10 12 | 13 | # Data hypers. 14 | DataConfig.image_height = 126 15 | 16 | # Training hypers (not needed for eval). 17 | Trainer.decay_learning_rate = True 18 | Trainer.decay_every = 10000 19 | Trainer.decay_rate = 0.8354967819706964 20 | Trainer.learning_rate = 0.00023702989294963628 21 | weight_decay = 0.00030572401935585053 22 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/maml_all_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Model hypers. 
11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.01406233807768908 13 | MAMLLearner.additional_evaluation_update_steps = 5 14 | MAMLLearner.num_update_steps = 6 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 126 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.decay_learning_rate = True 21 | Trainer.decay_every = 10000 22 | Trainer.decay_rate = 0.5575524279100452 23 | Trainer.learning_rate = 0.0004067024273450244 24 | weight_decay = 0.0000033242172150107816 25 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/maml_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Model hypers. 8 | MAMLLearner.first_order = True 9 | MAMLLearner.alpha = 0.09915203576061224 10 | MAMLLearner.additional_evaluation_update_steps = 5 11 | MAMLLearner.num_update_steps = 10 12 | 13 | # Data hypers. 14 | DataConfig.image_height = 126 15 | 16 | # Training hypers (not needed for eval). 17 | Trainer.decay_learning_rate = True 18 | Trainer.decay_every = 10000 19 | Trainer.decay_rate = 0.8354967819706964 20 | Trainer.learning_rate = 0.0003601303690709516 21 | weight_decay = 0.00030572401935585053 22 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/maml_imagenet_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 3 | 4 | # Backbone hypers. 
5 | Learner.embedding_fn = @four_layer_convnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Model hypers. 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.05602373320602529 13 | MAMLLearner.additional_evaluation_update_steps = 5 14 | MAMLLearner.num_update_steps = 10 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 84 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.decay_learning_rate = True 21 | Trainer.decay_every = 5000 22 | Trainer.decay_rate = 0.9153310958769965 23 | Trainer.learning_rate = 0.00012385665511953882 24 | weight_decay = 0.0006858703351826797 25 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/maml_init_with_proto_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Model hypers. 8 | MAMLLearner.first_order = True 9 | MAMLLearner.alpha = 0.01988024603751885 10 | MAMLLearner.additional_evaluation_update_steps = 5 11 | MAMLLearner.num_update_steps = 1 12 | 13 | # Data hypers. 14 | DataConfig.image_height = 126 15 | 16 | # Training hypers (not needed for eval). 
17 | Trainer.decay_learning_rate = True 18 | Trainer.decay_every = 500 19 | Trainer.decay_rate = 0.7659662681067452 20 | Trainer.learning_rate = 0.00029809808680211807 21 | weight_decay = 0.0007159940362690606 22 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/maml_init_with_proto_all_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Model hypers. 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.01139772733084185 13 | MAMLLearner.additional_evaluation_update_steps = 5 14 | MAMLLearner.num_update_steps = 1 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 126 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.decay_learning_rate = True 21 | Trainer.decay_every = 10000 22 | Trainer.decay_rate = 0.6731268939317062 23 | Trainer.learning_rate = 0.0007252750256444753 24 | weight_decay = 0.00005323094246436839 25 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/maml_init_with_proto_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Model hypers. 8 | MAMLLearner.first_order = True 9 | MAMLLearner.alpha = 0.005435022808033229 10 | MAMLLearner.additional_evaluation_update_steps = 0 11 | MAMLLearner.num_update_steps = 6 12 | 13 | # Data hypers. 
14 | DataConfig.image_height = 126 15 | 16 | # Training hypers (not needed for eval). 17 | Trainer.decay_learning_rate = True 18 | Trainer.decay_every = 2500 19 | Trainer.decay_rate = 0.6477898086638092 20 | Trainer.learning_rate = 0.00036339913514891586 21 | weight_decay = 0.000013656044331235537 22 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/maml_init_with_proto_imagenet_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | 6 | Learner.embedding_fn = @resnet 7 | 8 | Trainer.pretrained_source = 'scratch' 9 | Trainer.checkpoint_to_restore = '' 10 | 11 | # Model hypers. 12 | MAMLLearner.first_order = True 13 | MAMLLearner.alpha = 0.31106175977182243 14 | MAMLLearner.additional_evaluation_update_steps = 5 15 | MAMLLearner.num_update_steps = 6 16 | 17 | # Data hypers. 18 | DataConfig.image_height = 126 19 | 20 | # Training hypers (not needed for eval). 21 | Trainer.decay_learning_rate = True 22 | Trainer.decay_every = 5000 23 | Trainer.decay_rate = 0.6431136271727287 24 | Trainer.learning_rate = 0.0007181155997029211 25 | weight_decay = 0.00003630199690303937 26 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/maml_init_with_proto_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Model hypers. 
8 | MAMLLearner.first_order = True 9 | MAMLLearner.alpha = 0.061445396411177286 10 | MAMLLearner.additional_evaluation_update_steps = 0 11 | MAMLLearner.num_update_steps = 1 12 | 13 | # Data hypers. 14 | DataConfig.image_height = 126 15 | 16 | # Training hypers (not needed for eval). 17 | Trainer.decay_learning_rate = True 18 | Trainer.decay_every = 5000 19 | Trainer.decay_rate = 0.8449806376151254 20 | Trainer.learning_rate = 0.0008118174317224871 21 | weight_decay = 0.0000038819970639091496 22 | 23 | # Baseline hypers (just for the record). 24 | BaselineLearner.cosine_logits_multiplier = 1 25 | BaselineLearner.use_weight_norm = True 26 | BaselineLearner.knn_distance = 'l2' 27 | BaselineLearner.cosine_classifier = False 28 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/maml_init_with_proto_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Model hypers. 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.00577257710776957 13 | MAMLLearner.additional_evaluation_update_steps = 0 14 | MAMLLearner.num_update_steps = 1 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 126 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.decay_learning_rate = True 21 | Trainer.decay_every = 500 22 | Trainer.decay_rate = 0.885662482266546 23 | Trainer.learning_rate = 0.00025036275525430426 24 | weight_decay = 0.003952085241872012 25 | 26 | # Baseline hypers (just for the record). 
27 | BaselineLearner.cosine_logits_multiplier = 10 28 | BaselineLearner.use_weight_norm = True 29 | BaselineLearner.knn_distance = 'l2' 30 | BaselineLearner.cosine_classifier = True 31 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/matching_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_convnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 84 9 | 10 | # Training hypers (not needed for eval). 11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 500 13 | Trainer.decay_rate = 0.5937020776489796 14 | Trainer.learning_rate = 0.005052178216688174 15 | weight_decay = 0.00000392173790384195 16 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/matching_all_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @four_layer_convnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Data hypers. 11 | DataConfig.image_height = 126 12 | 13 | # Training hypers (not needed for eval). 
14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 2500 16 | Trainer.decay_rate = 0.9197652570309498 17 | Trainer.learning_rate = 0.002748105034397538 18 | weight_decay = 0.0000013254789476822292 19 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/matching_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 500 13 | Trainer.decay_rate = 0.915193802145601 14 | Trainer.learning_rate = 0.0012064626897259694 15 | weight_decay = 0.0000885787420909229 16 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/matching_imagenet_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Data hypers. 11 | DataConfig.image_height = 126 12 | 13 | # Training hypers (not needed for eval). 
14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 2500 16 | Trainer.decay_rate = 0.8333411536286996 17 | Trainer.learning_rate = 0.0006229660387662655 18 | weight_decay = 0.00018036259587809225 19 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/matching_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 84 9 | 10 | # Training hypers (not needed for eval). 11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 1000 13 | Trainer.decay_rate = 0.5568532448398684 14 | Trainer.learning_rate = 0.0009946451213635535 15 | weight_decay = 0.000006779214221539446 16 | 17 | # Baseline hypers (just for the record). 18 | BaselineLearner.cosine_logits_multiplier = 100 19 | BaselineLearner.use_weight_norm = False 20 | BaselineLearner.knn_distance = 'cosine' 21 | BaselineLearner.cosine_classifier = False 22 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/matching_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 84 9 | 10 | # Training hypers (not needed for eval). 
11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 10000 13 | Trainer.decay_rate = 0.9714872833376629 14 | Trainer.learning_rate = 0.002670412664551181 15 | weight_decay = 0.000013646011503835137 16 | 17 | # Baseline hypers (just for the record). 18 | BaselineLearner.cosine_logits_multiplier = 1 19 | BaselineLearner.use_weight_norm = True 20 | BaselineLearner.knn_distance = 'cosine' 21 | BaselineLearner.cosine_classifier = False 22 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/pretrain_imagenet_convnet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @four_layer_convnet 6 | 7 | Trainer.checkpoint_to_restore = '' 8 | Trainer.pretrained_source = 'scratch' 9 | 10 | # Model hypers. 11 | BaselineLearner.knn_distance = 'cosine' 12 | BaselineLearner.cosine_classifier = True 13 | BaselineLearner.cosine_logits_multiplier = 1 14 | BaselineLearner.use_weight_norm = True 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 84 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.num_updates = 50000 21 | Trainer.decay_every = 1000 22 | Trainer.decay_rate = 0.9105573818947892 23 | Trainer.learning_rate = 0.008644114436633987 24 | weight_decay = 0.000005171477829794739 25 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/pretrain_imagenet_resnet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | 4 | # Backbone hypers. 
5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.checkpoint_to_restore = '' 8 | Trainer.pretrained_source = 'scratch' 9 | 10 | # Model hypers. 11 | BaselineLearner.knn_distance = 'cosine' 12 | BaselineLearner.cosine_classifier = False 13 | BaselineLearner.cosine_logits_multiplier = 2 14 | BaselineLearner.use_weight_norm = False 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 126 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.num_updates = 50000 21 | Trainer.decay_every = 1000 22 | Trainer.decay_rate = 0.9967524905880909 23 | Trainer.learning_rate = 0.00375640851370052 24 | weight_decay = 0.00002628042826116842 25 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/pretrain_imagenet_wide_resnet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @wide_resnet 6 | 7 | Trainer.checkpoint_to_restore = '' 8 | Trainer.pretrained_source = 'scratch' 9 | 10 | # Model hypers. 11 | BaselineLearner.knn_distance = 'cosine' 12 | BaselineLearner.cosine_classifier = True 13 | BaselineLearner.cosine_logits_multiplier = 1 14 | BaselineLearner.use_weight_norm = True 15 | 16 | # Data hypers. 17 | DataConfig.image_height = 126 18 | 19 | # Training hypers (not needed for eval). 20 | Trainer.num_updates = 50000 21 | Trainer.decay_every = 100 22 | Trainer.decay_rate = 0.5082121576573064 23 | Trainer.learning_rate = 0.007084776688116927 24 | weight_decay = 0.000005078192976503067 25 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/pretrained_convnet.gin: -------------------------------------------------------------------------------- 1 | # Best pre-trained convnet. 
Update the path to the checkpoint. 2 | Learner.embedding_fn = @four_layer_convnet 3 | 4 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_convnet/model_42500.ckpt' 5 | Trainer.pretrained_source = 'imagenet' 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/pretrained_resnet.gin: -------------------------------------------------------------------------------- 1 | # Best pre-trained resnet. Update the path to the checkpoint. 2 | Learner.embedding_fn = @resnet 3 | 4 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_resnet/model_27500.ckpt' 5 | Trainer.pretrained_source = 'imagenet' 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/pretrained_wide_resnet.gin: -------------------------------------------------------------------------------- 1 | # Best pre-trained wide resnet. Update the path to the checkpoint. 2 | Learner.embedding_fn = @wide_resnet 3 | 4 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_wide_resnet/model_46000.ckpt' 5 | Trainer.pretrained_source = 'imagenet' 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/prototypical_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 
11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 100 13 | Trainer.decay_rate = 0.5475646651850319 14 | Trainer.learning_rate = 0.0018231563858704268 15 | weight_decay = 0.0074472606534577565 16 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/prototypical_all_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | 6 | Learner.embedding_fn = @resnet 7 | 8 | Trainer.pretrained_source = 'scratch' 9 | Trainer.checkpoint_to_restore = '' 10 | 11 | # Data hypers. 12 | DataConfig.image_height = 126 13 | 14 | # Training hypers (not needed for eval). 15 | Trainer.decay_learning_rate = True 16 | Trainer.decay_every = 2500 17 | Trainer.decay_rate = 0.8333411536286996 18 | Trainer.learning_rate = 0.0006229660387662655 19 | weight_decay = 0.00018036259587809225 20 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/prototypical_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 
11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 500 13 | Trainer.decay_rate = 0.915193802145601 14 | Trainer.learning_rate = 0.0012064626897259694 15 | weight_decay = 0.0000885787420909229 16 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/prototypical_imagenet_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @resnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Data hypers. 11 | DataConfig.image_height = 126 12 | 13 | # Training hypers (not needed for eval). 14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 2500 16 | Trainer.decay_rate = 0.8333411536286996 17 | Trainer.learning_rate = 0.0006229660387662655 18 | weight_decay = 0.00018036259587809225 19 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/prototypical_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 100 13 | Trainer.decay_rate = 0.7426809516243701 14 | Trainer.learning_rate = 0.0009325756201058525 15 | weight_decay = 0.00003386806355382518 16 | 17 | # Baseline hypers (just for the record). 
18 | BaselineLearner.cosine_logits_multiplier = 2 19 | BaselineLearner.use_weight_norm = True 20 | BaselineLearner.knn_distance = 'l2' 21 | BaselineLearner.cosine_classifier = False 22 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/prototypical_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_wide_resnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 100 13 | Trainer.decay_rate = 0.7426809516243701 14 | Trainer.learning_rate = 0.0009325756201058525 15 | weight_decay = 0.00003386806355382518 16 | 17 | # Baseline hypers (just for the record). 18 | BaselineLearner.cosine_logits_multiplier = 2 19 | BaselineLearner.use_weight_norm = True 20 | BaselineLearner.knn_distance = 'l2' 21 | BaselineLearner.cosine_classifier = False 22 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/relationnet_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_convnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 
11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 5000 13 | Trainer.decay_rate = 0.8707355191010226 14 | Trainer.learning_rate = 0.000906783323100297 15 | weight_decay = 0.0000617576791048944 16 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/relationnet_all_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @four_layer_convnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Data hypers. 11 | DataConfig.image_height = 126 12 | 13 | # Training hypers (not needed for eval). 14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 2500 16 | Trainer.decay_rate = 0.9197652570309498 17 | Trainer.learning_rate = 0.002748105034397538 18 | weight_decay = 0.0000013254789476822292 19 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/relationnet_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/best_v2/pretrained_convnet.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 126 9 | 10 | # Training hypers (not needed for eval). 
11 | Trainer.decay_learning_rate = True 12 | Trainer.decay_every = 5000 13 | Trainer.decay_rate = 0.8707355191010226 14 | Trainer.learning_rate = 0.000906783323100297 15 | weight_decay = 0.0000617576791048944 16 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/relationnet_imagenet_from_scratch.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | # Backbone hypers. 5 | Learner.embedding_fn = @four_layer_convnet 6 | 7 | Trainer.pretrained_source = 'scratch' 8 | Trainer.checkpoint_to_restore = '' 9 | 10 | # Data hypers. 11 | DataConfig.image_height = 126 12 | 13 | # Training hypers (not needed for eval). 14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 2500 16 | Trainer.decay_rate = 0.9197652570309498 17 | Trainer.learning_rate = 0.002748105034397538 18 | weight_decay = 0.0000013254789476822292 19 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/best/relationnet_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | ## Hyperparameters used to train the best model 5 | Trainer.checkpoint_to_restore = '' 6 | Trainer.pretrained_source = 'scratch' 7 | weight_decay = 1e-6 8 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baseline_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 4 | 
Trainer.learning_rate = 1e-5 5 | BatchSplitReaderGetReader.add_dataset_offset = True 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baseline_cosine_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baseline_cosine_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baseline_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baseline_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baseline_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | 
-------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baselinefinetune_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | BatchSplitReaderGetReader.add_dataset_offset = True 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baselinefinetune_cosine_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baselinefinetune_cosine_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baselinefinetune_cosine_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | 
-------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baselinefinetune_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 4 | Trainer.learning_rate = 1e-5 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baselinefinetune_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/baselinefinetune_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | Trainer.decay_learning_rate = True 7 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/crosstransformer_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/crosstransformer_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/default/resnet34_stride16.gin' 6 | 7 | # Data hypers. 
8 | DataConfig.image_height = 224 9 | train/EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.5 10 | resnet34.deeplab_alignment = False 11 | 12 | # Training hypers (not needed for eval). 13 | Trainer.num_updates = 100000 14 | Trainer.normalized_gradient_descent = True 15 | Trainer.decay_learning_rate = True 16 | Trainer.decay_every = 2000 17 | Trainer.decay_rate = 0.915193802145601 18 | Trainer.learning_rate = 0.0006 19 | Trainer.num_updates = 100000 20 | weight_decay = 0.0000885787420909229 21 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/crosstransformer_simclreps_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/crosstransformer_config.gin' 3 | 4 | # Backbone hypers. 5 | include 'meta_dataset/learn/gin/default/resnet34_stride16.gin' 6 | 7 | # Data hypers. 8 | DataConfig.image_height = 224 9 | train/EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.5 10 | resnet34.deeplab_alignment = False 11 | 12 | # Training hypers (not needed for eval). 
13 | Trainer.normalized_gradient_descent = True 14 | Trainer.decay_learning_rate = True 15 | Trainer.decay_every = 4000 16 | Trainer.decay_rate = 0.915193802145601 17 | Trainer.learning_rate = 0.0006 18 | Trainer.num_updates = 400000 19 | Trainer.enable_tf_optimizations = False 20 | weight_decay = 0.0000885787420909229 21 | 22 | train/EpisodeDescriptionConfig.simclr_episode_fraction = 0.5 23 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/debug_proto_fungi.gin: -------------------------------------------------------------------------------- 1 | # 'fungi' has many classes (> 1000, more than ImageNet) and a wide spread of 2 | # class sizes (from 6 to 442), while being overall not too big (13GB on disk), 3 | # which makes it a good candidate to profile imbalanced datasets with many 4 | # classes locally. 5 | benchmark.datasets = 'fungi' 6 | include 'meta_dataset/learn/gin/setups/data_config.gin' 7 | include 'meta_dataset/learn/gin/setups/trainer_config_debug.gin' 8 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin' 9 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 10 | Trainer.data_config = @DataConfig() 11 | 12 | # The total number of updates is 100, so intervals of 1000 ensure that no checkpointing or validation happens during profiling.
13 | Trainer.checkpoint_every = 1000 14 | Trainer.validate_every = 1000 15 | Trainer.log_every = 10 16 | 17 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 18 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 19 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/debug_proto_mini_imagenet.gin: -------------------------------------------------------------------------------- 1 | benchmark.datasets = 'mini_imagenet' 2 | include 'meta_dataset/learn/gin/setups/data_config.gin' 3 | include 'meta_dataset/learn/gin/setups/trainer_config_debug.gin' 4 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin' 5 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 6 | Trainer.data_config = @DataConfig() 7 | Trainer.learning_rate = 1e-4 8 | 9 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 10 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 11 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_init_with_proto_all.gin: 
-------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_init_with_proto_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_init_with_proto_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | BatchSplitReaderGetReader.add_dataset_offset = True 4 | Trainer.train_learner_class = @BaselineLearner 5 | Trainer.eval_learner_class = @MAMLLearner 6 | # The following line is what makes this proto-MAML. 
7 | MAMLLearner.proto_maml_fc_layer_init = True 8 | weight_decay = 1e-4 9 | BaselineLearner.knn_in_fc = False 10 | MAMLLearner.debug = False 11 | 12 | BaselineLearner.transductive_batch_norm = False 13 | BaselineLearner.backprop_through_moments = True 14 | 15 | MAMLLearner.transductive_batch_norm = False 16 | MAMLLearner.backprop_through_moments = True 17 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_init_with_proto_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | Trainer.train_learner_class = @BaselineLearner 4 | Trainer.eval_learner_class = @MAMLLearner 5 | # The following line is what makes this proto-MAML. 6 | MAMLLearner.proto_maml_fc_layer_init = True 7 | weight_decay = 1e-4 8 | BaselineLearner.knn_in_fc = False 9 | MAMLLearner.debug = False 10 | 11 | BaselineLearner.transductive_batch_norm = False 12 | BaselineLearner.backprop_through_moments = True 13 | 14 | MAMLLearner.transductive_batch_norm = False 15 | MAMLLearner.backprop_through_moments = True 16 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_init_with_proto_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_init_with_proto_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | 
include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/maml_protonet_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/maml_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/matching_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- 
/src/datasets/original_meta_dataset/learn/gin/default/matching_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/matching_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | BatchSplitReaderGetReader.add_dataset_offset = True 4 | Trainer.train_learner_class = @BaselineLearner 5 | Trainer.eval_learner_class = @MatchingNetworkLearner 6 | weight_decay = 1e-4 7 | BaselineLearner.knn_in_fc = False 8 | MatchingNetworkLearner.exact_cosine_distance = False 9 | 10 | BaselineLearner.transductive_batch_norm = False 11 | BaselineLearner.backprop_through_moments = True 12 | 13 | MatchingNetworkLearner.transductive_batch_norm = False 14 | MatchingNetworkLearner.backprop_through_moments = True 15 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/matching_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | Trainer.train_learner_class = @BaselineLearner 4 | Trainer.eval_learner_class = @MatchingNetworkLearner 5 | weight_decay = 1e-4 6 | BaselineLearner.knn_in_fc = False 7 | MatchingNetworkLearner.exact_cosine_distance = False 8 | 9 | BaselineLearner.transductive_batch_norm = False 10 | BaselineLearner.backprop_through_moments = True 11 | 12 | MatchingNetworkLearner.transductive_batch_norm = False 13 | 
MatchingNetworkLearner.backprop_through_moments = True 14 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/matching_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | 4 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/matching_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | 4 | include 'meta_dataset/learn/gin/learners/matching_config.gin' 5 | Trainer.learning_rate = 1e-4 6 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/pretrained_resnet34_224.gin: -------------------------------------------------------------------------------- 1 | # Best pre-trained resnet. Update the path to the checkpoint. 
2 | Learner.embedding_fn = @resnet34 3 | Trainer.checkpoint_to_restore = '/path/to/checkpoints/pretrain_imagenet_resnet/model_27500.ckpt' 4 | Trainer.pretrained_source = 'imagenet' 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/prototypical_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/prototypical_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 4 | Trainer.learning_rate = 1e-3 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/prototypical_inference_all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | BatchSplitReaderGetReader.add_dataset_offset = True 4 | Trainer.train_learner_class = @BaselineLearner 5 | Trainer.eval_learner_class = @PrototypicalNetworkLearner 6 | weight_decay = 1e-4 7 | BaselineLearner.knn_in_fc = False 8 | 9 | BaselineLearner.transductive_batch_norm = False 10 | BaselineLearner.backprop_through_moments = True 11 | 12 | PrototypicalNetworkLearner.transductive_batch_norm = False 13 | PrototypicalNetworkLearner.backprop_through_moments = True 14 | -------------------------------------------------------------------------------- 
/src/datasets/original_meta_dataset/learn/gin/default/prototypical_inference_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 3 | Trainer.train_learner_class = @BaselineLearner 4 | Trainer.eval_learner_class = @PrototypicalNetworkLearner 5 | weight_decay = 1e-4 6 | BaselineLearner.knn_in_fc = False 7 | 8 | BaselineLearner.transductive_batch_norm = False 9 | BaselineLearner.backprop_through_moments = True 10 | 11 | PrototypicalNetworkLearner.transductive_batch_norm = False 12 | PrototypicalNetworkLearner.backprop_through_moments = True 13 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/prototypical_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 5 3 | train/EpisodeDescriptionConfig.num_ways = 20 4 | 5 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 6 | Trainer.learning_rate = 1e-4 7 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/prototypical_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | train/EpisodeDescriptionConfig.num_ways = 30 4 | 5 | include 'meta_dataset/learn/gin/learners/prototypical_config.gin' 6 | Trainer.learning_rate = 1e-4 7 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/relationnet_all.gin: -------------------------------------------------------------------------------- 1 
| include 'meta_dataset/learn/gin/setups/all.gin' 2 | 3 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 4 | 5 | # There are more parameters than usual, so allow more updates. 6 | Trainer.num_updates = 100000 7 | 8 | Trainer.learning_rate = 1e-3 9 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/relationnet_imagenet.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | # There are more parameters than usual, so allow more updates. 5 | Trainer.num_updates = 100000 6 | 7 | Trainer.learning_rate = 1e-3 8 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/relationnet_mini_imagenet_fiveshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 3 | 4 | Trainer.learning_rate = 1e-3 5 | Trainer.decay_every = 100000 6 | Trainer.decay_rate = 0.5 7 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/relationnet_mini_imagenet_oneshot.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/mini_imagenet.gin' 2 | EpisodeDescriptionConfig.num_support = 1 3 | include 'meta_dataset/learn/gin/learners/relationnet_config.gin' 4 | 5 | Trainer.learning_rate = 1e-3 6 | Trainer.decay_every = 100000 7 | Trainer.decay_rate = 0.5 8 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/default/resnet34_stride16.gin: 
-------------------------------------------------------------------------------- 1 | Learner.embedding_fn = @resnet34 2 | Trainer.pretrained_source = 'imagenet' 3 | resnet34.max_stride = 16 4 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/baseline_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @BaselineLearner 3 | Trainer.eval_learner_class = @BaselineLearner 4 | 5 | Learner.embedding_fn = @four_layer_convnet 6 | weight_decay = 1e-4 7 | 8 | BaselineLearner.knn_in_fc = False 9 | BaselineLearner.knn_distance = 'l2' 10 | BaselineLearner.cosine_classifier = False 11 | BaselineLearner.cosine_logits_multiplier = None 12 | BaselineLearner.use_weight_norm = False 13 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/baseline_cosine_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 2 | 3 | BaselineLearner.cosine_classifier = True 4 | BaselineLearner.cosine_logits_multiplier = 1 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/baselinefinetune_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/baseline_config.gin' 2 | Trainer.train_learner_class = @BaselineFinetuneLearner 3 | Trainer.eval_learner_class = @BaselineFinetuneLearner 4 | weight_decay = 1e-4 5 | BaselineFinetuneLearner.num_finetune_steps = 5 6 | BaselineFinetuneLearner.finetune_lr = 1e-4 7 | BaselineFinetuneLearner.finetune_all_layers = False 8 | BaselineFinetuneLearner.finetune_with_adam = False 9 | 
-------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/baselinefinetune_cosine_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/baselinefinetune_config.gin' 2 | 3 | BaselineLearner.cosine_classifier = True 4 | BaselineLearner.cosine_logits_multiplier = 1 5 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/crosstransformer_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @CrossTransformerLearner 3 | Trainer.eval_learner_class = @CrossTransformerLearner 4 | Trainer.normalized_gradient_descent = True 5 | Learner.embedding_fn = 'four_layer_convnet' 6 | weight_decay = 1e-4 7 | CrossTransformerLearner.tformer_weight_decay = %weight_decay 8 | bn.use_ema=True 9 | Trainer.distribute = True 10 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/learner_config.gin: -------------------------------------------------------------------------------- 1 | # Backbone regularization settings. 2 | linear_classifier.weight_decay = %weight_decay 3 | fully_connected_network.weight_decay = %weight_decay 4 | four_layer_convnet.weight_decay = %weight_decay 5 | relationnet_convnet.weight_decay = %weight_decay 6 | relation_module.weight_decay = %weight_decay 7 | resnet.weight_decay = %weight_decay 8 | resnet34.weight_decay = %weight_decay 9 | wide_resnet.weight_decay = %weight_decay 10 | MAMLLearner.classifier_weight_decay = %weight_decay 11 | weight_decay = 0.001 12 | 13 | # Training settings. 
14 | Trainer.checkpoint_to_restore = '' 15 | Trainer.learning_rate = 1e-4 16 | Trainer.decay_learning_rate = False 17 | Trainer.decay_every = 5000 18 | Trainer.decay_rate = 0.5 19 | Trainer.normalized_gradient_descent = False 20 | Trainer.pretrained_source = '' 21 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/maml_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @MAMLLearner 3 | Trainer.eval_learner_class = @MAMLLearner 4 | Trainer.decay_learning_rate = True 5 | 6 | Learner.embedding_fn = @four_layer_convnet 7 | weight_decay = 1e-4 8 | 9 | MAMLLearner.num_update_steps = 5 10 | MAMLLearner.additional_evaluation_update_steps = 0 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.01 13 | MAMLLearner.adapt_batch_norm = False 14 | MAMLLearner.debug = False 15 | MAMLLearner.zero_fc_layer = True 16 | MAMLLearner.proto_maml_fc_layer_init = False 17 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/maml_init_with_proto_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @MAMLLearner 3 | Trainer.decay_learning_rate = True 4 | Trainer.eval_learner_class = @MAMLLearner 5 | 6 | Learner.embedding_fn = @four_layer_convnet 7 | weight_decay = 1e-4 8 | 9 | MAMLLearner.num_update_steps = 5 10 | MAMLLearner.additional_evaluation_update_steps = 0 11 | MAMLLearner.first_order = True 12 | MAMLLearner.alpha = 0.01 13 | MAMLLearner.adapt_batch_norm = False 14 | MAMLLearner.debug = False 15 | MAMLLearner.zero_fc_layer = True 16 | MAMLLearner.proto_maml_fc_layer_init = True 17 | 
-------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/matching_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @MatchingNetworkLearner 3 | Trainer.eval_learner_class = @MatchingNetworkLearner 4 | 5 | Trainer.decay_learning_rate = True 6 | Learner.embedding_fn = @four_layer_convnet 7 | weight_decay = 1e-4 8 | 9 | MatchingNetworkLearner.exact_cosine_distance = False 10 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/prototypical_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @PrototypicalNetworkLearner 3 | Trainer.eval_learner_class = @PrototypicalNetworkLearner 4 | Trainer.decay_learning_rate = True 5 | 6 | Learner.embedding_fn = @four_layer_convnet 7 | weight_decay = 1e-4 8 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/learners/relationnet_config.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/learners/learner_config.gin' 2 | Trainer.train_learner_class = @RelationNetworkLearner 3 | Trainer.eval_learner_class = @RelationNetworkLearner 4 | Trainer.decay_learning_rate = True 5 | 6 | Learner.embedding_fn = @relationnet_convnet 7 | Learner.transductive_batch_norm = True 8 | weight_decay = 1e-6 9 | 10 | # Allow transductive batch norm to be faithful with the original Relation Net 11 | # implementation, even though this gives it an advantage over the rest of the 12 | # models we used on Meta-Dataset. 
13 | Learner.transductive_batch_norm = True 14 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/all.gin: -------------------------------------------------------------------------------- 1 | include 'meta_dataset/learn/gin/setups/all_datasets.gin' 2 | include 'meta_dataset/learn/gin/setups/data_config.gin' 3 | include 'meta_dataset/learn/gin/setups/trainer_config.gin' 4 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin' 5 | Trainer.data_config = @DataConfig() 6 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 7 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 8 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/all_datasets.gin: -------------------------------------------------------------------------------- 1 | benchmark.train_datasets = 'ilsvrc_2012,aircraft,cu_birds,omniglot,quickdraw,vgg_flower,dtd,fungi' 2 | benchmark.eval_datasets = 'ilsvrc_2012' 3 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/data_config.gin: -------------------------------------------------------------------------------- 1 | import src.datasets.original_meta_dataset.data.config 2 | import src.datasets.original_meta_dataset.data.decoder 3 | include 'src/datasets/original_meta_dataset/learn/gin/setups/data_config_common.gin' 4 | 5 | # If we decode features then change the lines below to use FeatureDecoder. 
6 | 7 | process_dumped_episode.support_decoder = @support/ImageDecoder() 8 | process_episode.support_decoder = @support/ImageDecoder() 9 | support/ImageDecoder.data_augmentation = @support/DataAugmentation() 10 | 11 | support/DataAugmentation.enable_jitter = True 12 | support/DataAugmentation.jitter_amount = 0 13 | support/DataAugmentation.enable_gaussian_noise = True 14 | support/DataAugmentation.gaussian_noise_std = 0.0 15 | 16 | process_dumped_episode.query_decoder = @query/ImageDecoder() 17 | process_episode.query_decoder = @query/ImageDecoder() 18 | query/ImageDecoder.data_augmentation = @query/DataAugmentation() 19 | query/DataAugmentation.enable_jitter = False 20 | query/DataAugmentation.jitter_amount = 0 21 | query/DataAugmentation.enable_gaussian_noise = False 22 | query/DataAugmentation.gaussian_noise_std = 0.0 23 | 24 | # It is possible to override the support and query decoders for the meta-train 25 | # and / or meta-evaluation phases using the scopes 'train' or 'evaluation' 26 | # (thanks to trainer.py that will appropriately pick the scope for each phase). 27 | # The following commented-out lines show an example: 28 | # train/process_episode.support_decoder = @train/support/ImageDecoder() 29 | # train/support/ImageDecoder.image_size = ... 30 | 31 | process_batch.batch_decoder = @batch/ImageDecoder() 32 | batch/ImageDecoder.data_augmentation = @batch/DataAugmentation() 33 | batch/DataAugmentation.enable_jitter = True 34 | batch/DataAugmentation.jitter_amount = 0 35 | batch/DataAugmentation.enable_gaussian_noise = True 36 | batch/DataAugmentation.gaussian_noise_std = 0.0 37 | 38 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/data_config_common.gin: -------------------------------------------------------------------------------- 1 | # This gin file doesn't set process_episode.support_decoder and .query_decoder 2 | # therefore the image strings are not decoded. 
This gin file can be used directly 3 | # if the episodes are needed without decoding. 4 | import src.datasets.original_meta_dataset.data.config 5 | 6 | # Default values for sampling variable shots / ways. 7 | EpisodeDescriptionConfig.min_ways = 5 8 | EpisodeDescriptionConfig.max_ways_upper_bound = 50 9 | EpisodeDescriptionConfig.max_num_query = 10 10 | 11 | # For weak shot experiments where we have missing class data. 12 | # This should not affect any other experiments. 13 | 14 | EpisodeDescriptionConfig.min_examples_in_class = 0 15 | EpisodeDescriptionConfig.max_support_set_size = 500 16 | EpisodeDescriptionConfig.max_support_size_contrib_per_class = 100 17 | EpisodeDescriptionConfig.min_log_weight = -0.69314718055994529 # np.log(0.5) 18 | EpisodeDescriptionConfig.max_log_weight = 0.69314718055994529 # np.log(2) 19 | EpisodeDescriptionConfig.ignore_dag_ontology = False 20 | EpisodeDescriptionConfig.ignore_bilevel_ontology = False 21 | EpisodeDescriptionConfig.ignore_hierarchy_probability = 0.0 22 | 23 | # By default don't use SimCLR Episodes. 24 | EpisodeDescriptionConfig.simclr_episode_fraction = 0.0 25 | # It is possible to override some of the above defaults only for meta-training. 26 | # An example is shown in the following two commented-out lines. 27 | # train/EpisodeDescriptionConfig.min_ways = 5 28 | # train/EpisodeDescriptionConfig.max_ways_upper_bound = 50 29 | 30 | # Other default values for the data pipeline. 
31 | DataConfig.image_height = 84 32 | DataConfig.shuffle_buffer_size = 1000 33 | DataConfig.read_buffer_size_bytes = 1048576 # 1 MB (1024**2) 34 | DataConfig.num_prefetch = 64 35 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/data_config_feature.gin: -------------------------------------------------------------------------------- 1 | import meta_dataset.data.config 2 | import meta_dataset.data.decoder 3 | import meta_dataset.data.pipeline 4 | include 'meta_dataset/learn/gin/setups/data_config_common.gin' 5 | 6 | process_episode.support_decoder = @FeatureDecoder() 7 | process_episode.query_decoder = @FeatureDecoder() 8 | process_batch.batch_decoder = @FeatureDecoder() 9 | FeatureDecoder.feat_len = 64 10 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/data_config_no_decoder.gin: -------------------------------------------------------------------------------- 1 | import meta_dataset.data.config 2 | import meta_dataset.data.decoder 3 | import meta_dataset.data.pipeline 4 | include 'meta_dataset/learn/gin/setups/data_config_common.gin' 5 | 6 | process_episode.support_decoder = None 7 | process_episode.query_decoder = None 8 | process_batch.batch_decoder = None 9 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/data_config_string.gin: -------------------------------------------------------------------------------- 1 | import meta_dataset.data.config 2 | import meta_dataset.data.decoder 3 | import meta_dataset.data.pipeline 4 | include 'meta_dataset/learn/gin/setups/data_config_common.gin' 5 | 6 | process_episode.support_decoder = @StringDecoder() 7 | process_episode.query_decoder = @StringDecoder() 8 | process_batch.batch_decoder = @StringDecoder() 9 | 
-------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/fixed_way_and_shot.gin: -------------------------------------------------------------------------------- 1 | EpisodeDescriptionConfig.num_ways = 5 2 | EpisodeDescriptionConfig.num_support = 5 3 | EpisodeDescriptionConfig.num_query = 15 4 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/imagenet.gin: -------------------------------------------------------------------------------- 1 | benchmark.train_datasets = 'ilsvrc_2012' 2 | benchmark.eval_datasets = 'ilsvrc_2012' 3 | include 'meta_dataset/learn/gin/setups/data_config.gin' 4 | include 'meta_dataset/learn/gin/setups/trainer_config.gin' 5 | include 'meta_dataset/learn/gin/setups/variable_way_and_shot.gin' 6 | Trainer.data_config = @DataConfig() 7 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 8 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 9 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/mini_imagenet.gin: -------------------------------------------------------------------------------- 1 | benchmark.train_datasets = 'mini_imagenet' 2 | benchmark.eval_datasets = 'mini_imagenet' 3 | include 'meta_dataset/learn/gin/setups/data_config.gin' 4 | include 'meta_dataset/learn/gin/setups/trainer_config.gin' 5 | include 'meta_dataset/learn/gin/setups/fixed_way_and_shot.gin' 6 | Trainer.data_config = @DataConfig() 7 | Trainer.train_episode_config = @train/EpisodeDescriptionConfig() 8 | Trainer.eval_episode_config = @EpisodeDescriptionConfig() 9 | DataConfig.image_height = 84 10 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/trainer_config.gin: 
-------------------------------------------------------------------------------- 1 | Trainer.num_updates = 75000 2 | Trainer.batch_size = 256 # Only applicable to non-episodic models. 3 | Trainer.num_eval_episodes = 600 4 | Trainer.checkpoint_every = 500 5 | Trainer.validate_every = 500 6 | Trainer.log_every = 100 7 | Trainer.distribute = False 8 | # Enable TensorFlow optimizations. It can add a few minutes to the first 9 | # calls to session.run(), but decrease memory usage. 10 | Trainer.enable_tf_optimizations = True 11 | 12 | Learner.transductive_batch_norm = False 13 | Learner.backprop_through_moments = True 14 | 15 | 16 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/trainer_config_debug.gin: -------------------------------------------------------------------------------- 1 | Trainer.num_updates = 100 2 | Trainer.batch_size = 8 # Only applicable to non-episodic models. 3 | Trainer.num_eval_episodes = 10 4 | Trainer.checkpoint_every = 10 5 | Trainer.validate_every = 5 6 | Trainer.log_every = 1 7 | Learner.transductive_batch_norm = False 8 | Learner.backprop_through_moments = True 9 | -------------------------------------------------------------------------------- /src/datasets/original_meta_dataset/learn/gin/setups/variable_way_and_shot.gin: -------------------------------------------------------------------------------- 1 | EpisodeDescriptionConfig.num_ways = None 2 | EpisodeDescriptionConfig.num_support = None 3 | EpisodeDescriptionConfig.num_query = None 4 | -------------------------------------------------------------------------------- /src/datasets/reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | from typing import Union 4 | 5 | from .utils import Split 6 | from .tfrecord.torch.dataset import TFRecordDataset 7 | from .dataset_spec import DatasetSpecification as DS 8 | from 
class Reader(object):
    """Builds one record dataset per class for a single data source.

    A source is identified by a dataset specification together with a split.
    For every class of that split the reader creates a `TFRecordDataset`
    pointing at the class's record file and its companion index file.
    """

    def __init__(self,
                 dataset_spec: Union[HDS, BDS, DS],
                 split: Split,
                 shuffle: bool,
                 offset: int):
        """Initializes a Reader from a source.

        Args:
          dataset_spec: dataset specification; its `path`, `file_pattern` and
            `get_classes(split)` are what this class reads.
          split: the split whose classes are read from `dataset_spec`.
          shuffle: forwarded unchanged to every `TFRecordDataset` that
            `construct_class_datasets` creates.
          offset: stored on the instance; not used by the methods of this
            class.
        """
        self.dataset_spec = dataset_spec
        self.split = split
        self.shuffle = shuffle
        self.offset = offset

        # Derived once from the specification.
        self.base_path = dataset_spec.path
        self.class_set = dataset_spec.get_classes(split)
        self.num_classes = len(self.class_set)

    def construct_class_datasets(self):
        """Constructs the list of class datasets.

        Returns:
          A list of `TFRecordDataset` objects, one per class of the split, in
          the order of `self.class_set`.

        Raises:
          NotImplementedError: if the file pattern of the specification starts
            with '{}_{}' (sharded files).
          ValueError: if the file pattern does not start with '{}'.
        """
        record_pattern = self.dataset_spec.file_pattern
        index_pattern = '{}.index'

        datasets = []
        progress = tqdm(range(self.num_classes))
        progress.set_description("Constructing class dataset")
        for position in progress:
            class_id = self.class_set[position]

            # The sharded prefix must be tested first: '{}_{}' also starts
            # with '{}'.
            if record_pattern.startswith('{}_{}'):
                raise NotImplementedError('Sharded files are not supported yet. '
                                          'The code expects one dataset per class.')
            if not record_pattern.startswith('{}'):
                raise ValueError('Unsupported record_file_pattern in DatasetSpec: %s. '
                                 'Expected something starting with "{}" or "{}_{}".' %
                                 record_pattern)

            record_path = os.path.join(self.base_path,
                                       record_pattern.format(class_id))
            index_path = os.path.join(self.base_path,
                                      index_pattern.format(class_id))
            datasets.append(
                TFRecordDataset(data_path=record_path,
                                index_path=index_path,
                                description={"image": "byte", "label": "int"},
                                shuffle=self.shuffle))

        assert len(datasets) == self.num_classes
        return datasets
def cycle(iterator_fn: typing.Callable) -> typing.Iterable[typing.Any]:
    """Yield the elements of `iterator_fn()` over and over again.

    `iterator_fn` is called anew each time the previous iterator is exhausted,
    so the resulting stream never ends by itself.
    """
    while True:
        yield from iterator_fn()
def create_index(tfrecord_file: str, index_file: str) -> None:
    """Create index from the tfrecords file.

    Writes one line per serialized record to `index_file`, holding the
    starting byte offset of the record and its total length in bytes
    (header and checksums included), separated by a single space.

    Parsing is best-effort: if the file cannot be parsed further, a message
    is printed and the index written so far is kept.

    Params:
    -------
    tfrecord_file: str
        Path to the TFRecord file.

    index_file: str
        Path where to store the index file.
    """
    # `with` guarantees both handles are closed even if reading fails midway.
    with open(tfrecord_file, "rb") as infile, open(index_file, "w") as outfile:
        while True:
            current = infile.tell()
            try:
                header = infile.read(8)
                if not header:
                    # Clean end of file.
                    break
                if len(header) < 8:
                    # Truncated length field: there is nothing sensible to index.
                    print("Failed to parse TFRecord.")
                    break
                # The length field is an unsigned 64-bit little-endian integer,
                # whatever the byte order of the host.
                proto_len = int.from_bytes(header, "little")
                infile.read(4)          # checksum of the length field
                infile.read(proto_len)  # serialized record
                infile.read(4)          # checksum of the record
            except (OSError, OverflowError, MemoryError):
                # Unreadable file or an absurd length from a corrupted header.
                print("Failed to parse TFRecord.")
                break
            outfile.write(str(current) + " " + str(infile.tell() - current) + "\n")
23 | """ 24 | 25 | def __init__(self, data_path: str) -> None: 26 | self.file = io.open(data_path, "wb") 27 | 28 | def close(self) -> None: 29 | """Close the tfrecord file.""" 30 | self.file.close() 31 | 32 | def write(self, datum: typing.Dict[str, typing.Tuple[typing.Any, str]], 33 | sequence_datum: typing.Union[typing.Dict[str, typing.Tuple[typing.List[typing.Any], str]], None] = None, 34 | ) -> None: 35 | """Write an example into tfrecord file. Either as a Example 36 | SequenceExample depending on the presence of `sequence_datum`. 37 | If `sequence_datum` is None (by default), this writes a Example 38 | to file. Otherwise, it writes a SequenceExample to file, assuming 39 | `datum` to be the context and `sequence_datum` to be the sequential 40 | features. 41 | 42 | Params: 43 | ------- 44 | datum: dict 45 | Dictionary of tuples of form (value, dtype). dtype can be 46 | "byte", "float" or "int". 47 | sequence_datum: dict 48 | By default, it is set to None. If this value is present, then the 49 | Dictionary of tuples of the form (value, dtype). dtype can be 50 | "byte", "float" or "int". value should be the sequential features. 51 | """ 52 | if sequence_datum is None: 53 | record = TFRecordWriter.serialize_tf_example(datum) 54 | else: 55 | record = TFRecordWriter.serialize_tf_sequence_example(datum, sequence_datum) 56 | 57 | length = len(record) 58 | length_bytes = struct.pack(" bytes: 66 | """CRC checksum.""" 67 | mask = 0xa282ead8 68 | crc = crc32c.crc32(data) 69 | masked = ((crc >> 15) | (crc << 17)) + mask 70 | masked = np.uint32(masked) 71 | masked_bytes = struct.pack(" bytes: 76 | """Serialize example into tfrecord.Example proto. 77 | 78 | Params: 79 | ------- 80 | datum: dict 81 | Dictionary of tuples of form (value, dtype). dtype can be 82 | "byte", "float" or "int". 83 | 84 | Returns: 85 | -------- 86 | proto: bytes 87 | Serialized tfrecord.example to bytes. 
@staticmethod
def serialize_tf_sequence_example(context_datum: typing.Dict[str, typing.Tuple[typing.Any, str]],
                                  features_datum: typing.Dict[str, typing.Tuple[typing.List[typing.Any], str]],
                                  ) -> bytes:
    """Serialize sequence example into tfrecord.SequenceExample proto.

    Static method of `TFRecordWriter`.

    Params:
    -------
    context_datum: dict
        Dictionary of tuples of form (value, dtype). dtype can be
        "byte", "float" or "int". Becomes the context of the example.

    features_datum: dict
        Same form as `context_datum`, but each value is a sequence whose
        items become the successive entries of one feature list.

    Returns:
    --------
    proto: bytes
        Serialized tfrecord.SequenceExample to bytes.
    """
    def to_feature(value, dtype):
        # A lone scalar is wrapped so the proto list always gets a sequence.
        if not isinstance(value, (list, tuple, np.ndarray)):
            value = [value]
        if dtype == "byte":
            return example_pb2.Feature(bytes_list=example_pb2.BytesList(value=value))
        if dtype == "float":
            return example_pb2.Feature(float_list=example_pb2.FloatList(value=value))
        if dtype == "int":
            return example_pb2.Feature(int64_list=example_pb2.Int64List(value=value))
        raise KeyError(dtype)

    def to_feature_list(steps, dtype):
        feature_list = example_pb2.FeatureList()
        feature_list.feature.extend(to_feature(step, dtype) for step in steps)
        return feature_list

    context_features = {}
    for key, (value, dtype) in context_datum.items():
        context_features[key] = to_feature(value, dtype)

    sequence_features = {}
    for key, (value, dtype) in features_datum.items():
        sequence_features[key] = to_feature_list(value, dtype)

    sequence_example = example_pb2.SequenceExample(
        context=example_pb2.Features(feature=context_features),
        feature_lists=example_pb2.FeatureLists(feature_list=sequence_features))
    return sequence_example.SerializeToString()
class GaussianNoise(object):
    """Transform adding zero-mean Gaussian noise to a tensor image."""

    def __init__(self, noise_std):
        # Standard deviation of the added noise.
        self.noise_std = noise_std

    def __call__(self, img):
        """Return a noisy copy of `img`; the input tensor is left untouched.

        Args:
            img: floating-point tensor to perturb.

        Returns:
            A new tensor with the shape, dtype and device of `img`.
        """
        # randn_like draws the noise directly with img's dtype and device, so
        # the addition cannot fail on a device mismatch, and the out-of-place
        # sum avoids silently modifying a tensor the caller may still hold.
        return img + self.noise_std * torch.randn_like(img)
def cycle_(iterable):
    """Yield the items of `iterable` repeatedly, restarting after each pass.

    A custom cycle is used because itertools.cycle keeps a copy of every item
    it has produced in order to replay them, which is costly in memory for
    long streams. Here the iterable is simply iterated again on each pass.

    The generator stops if a pass produces no item (empty iterable, or a
    one-shot iterator that is already exhausted) instead of looping forever
    without ever yielding.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            # Nothing left to replay: restarting again would busy-loop.
            return
def cutmix(self,
           input_: Tensor,
           one_hot_targets: Tensor,
           model: nn.Module):
    """Compute the loss on a CutMix-augmented version of the batch.

    Method of `_Loss`. A box is pasted into every sample from another sample
    of the batch (chosen by a random permutation), the model is run on the
    result and the loss is the area-weighted mix of the losses against both
    sets of targets.

    Note: the patch is written into `input_` in place.

    Args:
        input_: batch of inputs; the last two dimensions are the ones the
            box is cut from.
        one_hot_targets: targets of the batch, indexed along dim 0.
        model: network called on the mixed batch.

    Returns:
        The mixed loss, as produced by `self.loss_fn`.
    """
    # Same device handling as `mixup`: the permutation is drawn on CPU and
    # moved next to the targets, instead of assuming a CUDA device exists.
    device = one_hot_targets.device

    # generate mixed sample
    lam = np.random.beta(self.beta, self.beta)
    rand_index = torch.randperm(input_.size()[0]).to(device)

    target_a = one_hot_targets
    target_b = one_hot_targets[rand_index]

    bbx1, bby1, bbx2, bby2 = rand_bbox(input_.size(), lam)
    input_[:, :, bbx1:bbx2, bby1:bby2] = input_[rand_index, :, bbx1:bbx2, bby1:bby2]

    # adjust lambda to exactly match pixel ratio
    box_area = (bbx2 - bbx1) * (bby2 - bby1)
    lam = 1 - (box_area / (input_.size()[-1] * input_.size()[-2]))
    output = model(input_)

    loss = self.loss_fn(output, target_a) * lam + self.loss_fn(output, target_b) * (1. - lam)

    return loss
class BDCSPN(FSmethod):

    """
    Implementation of BD-CSPN (ECCV 2020) https://arxiv.org/abs/1911.10713

    Class prototypes (`self.weights`) start as the support centroids and are
    rectified once with the pseudo-labelled query features, see `rectify`.
    """
    def __init__(self,
                 args: argparse.Namespace):
        # Fail fast, before any state is set up, on an unsupported config.
        if args.val_batch_size > 1:
            raise ValueError("For now, only val_batch_size=1 is supported for BD-CSPN")
        super().__init__(args)
        self.temp = args.temp
        self.iter = 1
        self.episodic_training = False
        self.extract_batch_size = args.extract_batch_size

    def record_info(self,
                    metrics: Optional[Dict],
                    task_ids: Optional[Tuple],
                    iteration: int,
                    new_time: float,
                    support: Tensor,
                    query: Tensor,
                    y_s: Tensor,
                    y_q: Tensor) -> None:
        """
        Push the current predictions to every metric; no-op without metrics.

        inputs:
            support : tensor of shape [n_task, s_shot, feature_dim]
            query : tensor of shape [n_task, q_shot, feature_dim]
            y_s : tensor of shape [n_task, s_shot]
            y_q : tensor of shape [n_task, q_shot]

        `new_time` is accepted but not used here.
        """
        if metrics:
            logits_s = self.get_logits(support).detach()
            probs_s = logits_s.softmax(-1)
            logits_q = self.get_logits(query).detach()
            preds_q = logits_q.argmax(2)
            probs_q = logits_q.softmax(2)
            kwargs = {'probs': probs_q, 'probs_s': probs_s, 'preds': preds_q,
                      'gt': y_q, 'z_s': support, 'z_q': query, 'gt_s': y_s,
                      'weights': self.weights}

            assert task_ids is not None
            for metric_name in metrics:
                metrics[metric_name].update(task_ids[0],
                                            task_ids[1],
                                            iteration,
                                            **kwargs)

    def rectify(self,
                support: Tensor,
                query: Tensor,
                y_s: Tensor,
                y_q: Tensor) -> None:
        """
        Recompute `self.weights` from both support and query features.

        The query features are first shifted by the difference between the
        support mean and the query mean. Each sample then contributes to the
        prototype of its class (true label for support, predicted label for
        query) with a weight proportional to exp(cosine similarity).

        `y_q` is accepted but not used: query labels are predicted.
        """
        num_classes = y_s.unique().size(0)
        one_hot_s = get_one_hot(y_s, num_classes)
        eta = support.mean(1, keepdim=True) - query.mean(1, keepdim=True)
        query = query + eta

        logits_s = self.get_logits(support).exp()  # [n_task, shot_s, num_class]
        logits_q = self.get_logits(query).exp()  # [n_task, shot_q, num_class]

        preds_q = logits_q.argmax(-1)
        one_hot_q = get_one_hot(preds_q, num_classes)

        # Per-class sum of the weights, so that they sum to one for each class.
        normalization = ((one_hot_s * logits_s).sum(1) + (one_hot_q * logits_q).sum(1)).unsqueeze(1)
        w_s = (one_hot_s * logits_s) / normalization  # [n_task, shot_s, num_class]
        w_q = (one_hot_q * logits_q) / normalization  # [n_task, shot_q, num_class]

        self.weights = ((w_s * one_hot_s).transpose(1, 2).matmul(support)
                        + (w_q * one_hot_q).transpose(1, 2).matmul(query))

    def get_logits(self, samples: Tensor):
        """
        Cosine similarity between every sample and every prototype.

        inputs:
            samples : tensor of shape [n_task, shot, feature_dim]

        returns :
            logits : tensor of shape [n_task, shot, num_class]
        """
        # NOTE(review): `self.temp` does not scale the logits here; it is only
        # used as a bound in the sanity check below - confirm this is intended.
        cosine = torch.nn.CosineSimilarity(dim=3, eps=1e-6)
        logits = cosine(samples[:, :, None, :], self.weights[:, None, :, :])
        assert logits.max() <= self.temp and logits.min() >= -self.temp, (logits.min(), logits.max())
        return logits

    def forward(self,
                model: torch.nn.Module,
                support: Tensor,
                query: Tensor,
                y_s: Tensor,
                y_q: Tensor,
                metrics: Dict[str, Metric] = None,
                task_ids: Tuple[int, int] = None) -> Tuple[Optional[Tensor], Tensor]:
        """
        Run BD-CSPN inference on a batch of tasks.

        inputs:
            model : feature extractor, put in eval mode
            support : support inputs, passed to `extract_features`
            query : query inputs, passed to `extract_features`
            y_s : tensor of shape [n_task, s_shot]
            y_q : tensor of shape [n_task, q_shot]
            metrics : optional metrics, updated before and after rectification
            task_ids : start and end task ids, required when metrics are given

        updates :
            self.weights : tensor of shape [n_task, num_class, feature_dim]

        returns :
            (None, logits) where logits are the cosine similarities of the
            query features to the rectified prototypes,
            of shape [n_task, q_shot, num_class]
        """
        model.eval()

        # Extract support and query features
        with torch.no_grad():
            feat_s, feat_q = extract_features(self.extract_batch_size,
                                              support,
                                              query,
                                              model)

        # Perform required normalizations
        feat_s = F.normalize(feat_s, dim=2)
        feat_q = F.normalize(feat_q, dim=2)

        # Initialize weights
        t0 = time.time()
        self.weights = compute_centroids(feat_s, y_s)  # [batch, num_class, d]
        self.record_info(iteration=0,
                         task_ids=task_ids,
                         metrics=metrics,
                         new_time=time.time() - t0,
                         support=feat_s,
                         query=feat_q,
                         y_s=y_s,
                         y_q=y_q)

        self.rectify(support=feat_s, y_s=y_s,
                     query=feat_q, y_q=y_q)

        self.record_info(iteration=1,
                         task_ids=task_ids,
                         metrics=metrics,
                         new_time=time.time() - t0,
                         support=feat_s,
                         query=feat_q,
                         y_s=y_s,
                         y_q=y_q)
        q_probs = self.get_logits(feat_q)

        return None, q_probs
class Finetune(FSmethod):
    """
    Implementation of Finetune (or Baseline method) (ICLR 2019) https://arxiv.org/abs/1904.04232

    A fresh linear classifier is fitted on the support features of a single
    task, optionally together with the feature extractor itself.
    """
    def __init__(self,
                 args: argparse.Namespace):
        super().__init__(args)
        self.temp = args.temp
        self.iter = args.iter
        self.extract_batch_size = args.extract_batch_size
        self.finetune_all_layers = args.finetune_all_layers
        self.cosine_head = args.finetune_cosine_head
        self.weight_norm = args.finetune_weight_norm
        self.lr = args.finetune_lr

        self.episodic_training = False

    def record_info(self,
                    metrics: Optional[Dict],
                    task_ids: Optional[Tuple],
                    iteration: int,
                    preds_q: Tensor,
                    probs_s: Tensor,
                    y_q: Tensor,
                    y_s: Tensor) -> None:
        """
        Push the current predictions to every metric; no-op without metrics.

        inputs:
            preds_q : predicted classes of the query samples
            probs_s : predicted probabilities of the support samples
            y_q : query labels
            y_s : support labels
        """
        if metrics:
            kwargs = {'preds': preds_q, 'gt': y_q, 'probs_s': probs_s,
                      'gt_s': y_s}

            assert task_ids is not None
            for metric_name in metrics:
                metrics[metric_name].update(task_ids[0],
                                            task_ids[1],
                                            iteration,
                                            **kwargs)

    def _do_data_dependent_init(self, classifier: nn.Module, feat_s: Tensor):
        """Data-dependent init of the weight-norm gain `weight_g`.

        NOTE(review): not called by `forward` at the moment; the shapes noted
        below are the original author's and have not been re-checked.
        """
        w_fc_normalized = F.normalize(classifier.weight_v, dim=1)  # [num_classes, d]
        output_init = feat_s @ w_fc_normalized.t()  # [n_s, num_classes]
        var_init = output_init.var(0, keepdim=True)  # [num_classes]
        # Data-dependent init values.
        classifier.weight_g.data = 1. / torch.sqrt(var_init + 1e-10)

    def forward(self,
                model: torch.nn.Module,
                support: Tensor,
                query: Tensor,
                y_s: Tensor,
                y_q: Tensor,
                metrics: Dict[str, Metric] = None,
                task_ids: Tuple[int, int] = None) -> Tuple[Optional[Tensor], Tensor]:
        """
        Fit a linear classifier on the support set of one task and predict
        the query set.

        inputs:
            model : feature extractor; also optimised if `finetune_all_layers`
            support : support inputs of a single task (first dim of size 1)
            query : query inputs of a single task (first dim of size 1)
            y_s : tensor of shape [1, s_shot]
            y_q : tensor of shape [1, q_shot]
            metrics : optional metrics, updated at every iteration
            task_ids : start and end task ids, required when metrics are given

        returns :
            loss : detached loss of the last adaptation step, or None when no
                   adaptation step was run (`iter` <= 1)
            probs_q : tensor of shape [1, q_shot, num_class]

        raises :
            ValueError if more than one task is given
        """
        n_tasks = support.size(0)
        if n_tasks > 1:
            raise ValueError('Finetune method can only deal with 1 task at a time. '
                             'Currently {} tasks.'.format(n_tasks))
        y_s = y_s[0]
        y_q = y_q[0]
        num_classes = y_s.unique().size(0)
        y_s_one_hot = get_one_hot(y_s, num_classes)

        # Initialize classifier
        with torch.no_grad():
            feat_s, feat_q = extract_features(self.extract_batch_size,
                                              support,
                                              query,
                                              model)

        # The classifier lives where the features are, which avoids requiring
        # an initialised process group just to find a device.
        classifier = nn.Linear(feat_s.size(-1), num_classes, bias=False).to(feat_s.device)
        if self.weight_norm:
            classifier = nn.utils.weight_norm(classifier, name='weight')

        with torch.no_grad():
            if self.cosine_head:
                feat_s = F.normalize(feat_s, dim=-1)
                feat_q = F.normalize(feat_q, dim=-1)

            logits_q = self.temp * classifier(feat_q[0])
            logits_s = self.temp * classifier(feat_s[0])
            preds_q = logits_q.argmax(-1)
            probs_s = logits_s.softmax(-1)
        self.record_info(iteration=0,
                         task_ids=task_ids,
                         metrics=metrics,
                         preds_q=preds_q,
                         probs_s=probs_s,
                         y_q=y_q,
                         y_s=y_s)

        # Define optimizer
        if self.finetune_all_layers:
            params = list(model.parameters()) + list(classifier.parameters())
        else:
            params = list(classifier.parameters())
        optimizer = torch.optim.Adam(params, lr=self.lr)

        # Run adaptation
        loss = None  # stays None when the loop below does not run
        for i in range(1, self.iter):
            if self.finetune_all_layers:
                # Features depend on the weights being trained: recompute them.
                model.train()
                feat_s, feat_q = extract_features(self.extract_batch_size,
                                                  support,
                                                  query,
                                                  model)
                if self.cosine_head:
                    feat_s = F.normalize(feat_s, dim=-1)
                    feat_q = F.normalize(feat_q, dim=-1)

            logits_s = self.temp * classifier(feat_s[0])
            probs_s = logits_s.softmax(-1)
            # log_softmax is the numerically stable form of softmax().log()
            loss = - (y_s_one_hot * logits_s.log_softmax(-1)).sum(-1).mean(-1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                preds_q = classifier(feat_q[0]).argmax(-1)

            self.record_info(iteration=i,
                             task_ids=task_ids,
                             metrics=metrics,
                             preds_q=preds_q,
                             probs_s=probs_s,
                             y_q=y_q,
                             y_s=y_s)

        with torch.no_grad():
            # Same temperature as during adaptation, so the returned
            # probabilities match the ones the classifier was trained with.
            probs_q = (self.temp * classifier(feat_q[0])).softmax(-1).unsqueeze(0)
        return (None if loss is None else loss.detach()), probs_q
def forward(self,
            model: torch.nn.Module,
            support: Tensor,
            query: Tensor,
            y_s: Tensor,
            y_q: Tensor,
            metrics: Dict[str, Metric] = None,
            task_ids: Tuple[int, int] = None) -> Tuple[Optional[Tensor], Tensor]:
    """Adapt the model on each task's support set and score its query set.

    Method of `MAML`. For every task, `iter_` gradient steps are taken on the
    support loss starting from the model's own parameters; the adapted
    parameters are then used to predict the query samples.

    Args:
        model: network accepting a `params` keyword in its forward call.
        support, y_s: per-task support inputs and labels.
        query, y_q: per-task query inputs and labels.
        metrics, task_ids: accepted for interface compatibility, unused here.

    Returns:
        outer_loss: sum over tasks of the query cross-entropy.
        soft_preds: softmax of the query logits, one slice per task.
    """
    iter_ = self.train_iter if self.training else self.iter

    device = torch.distributed.get_rank()

    outer_loss = torch.tensor(0., device=device)
    soft_preds = torch.zeros_like(get_one_hot(y_q, y_s.unique().size(0)))

    for task_idx, (x_s, y_support, x_q, y_query) in enumerate(zip(support,
                                                                  y_s,
                                                                  query,
                                                                  y_q)):
        # The inner loop runs in train mode for every task. During
        # evaluation the model is switched to eval() further down, so the
        # mode has to be restored here or every task after the first one
        # would be adapted in eval mode.
        model.train()
        params = None  # None means: start from the model's own parameters
        x_s, x_q = x_s.to(device), x_q.to(device)

        for _ in range(iter_):
            train_logit = model(x_s, params=params)
            inner_loss = F.cross_entropy(train_logit, y_support)

            model.zero_grad()
            params = self.gradient_update_parameters(model=model,
                                                     loss=inner_loss,
                                                     params=params)

        if not self.training:  # if doing evaluation, put back the model in eval()
            model.eval()

        with torch.set_grad_enabled(self.training):
            query_logit = model(x_q, params=params)

        outer_loss += F.cross_entropy(query_logit, y_query)
        soft_preds[task_idx] = query_logit.detach().softmax(-1)

    return outer_loss, soft_preds
def gradient_update_parameters(self,
                               model,
                               loss,
                               params=None) -> OrderedDict:
    """Return the parameters obtained after one gradient-descent step on ``loss``.

    Parameters
    ----------
    model : `MetaModule` instance
        The model being adapted.
    loss : `torch.Tensor` instance
        Inner-loop loss to differentiate.
    params : `collections.OrderedDict` instance, optional
        Current fast weights. If `None`, the values from
        `model.meta_named_parameters()` are used, which is what the first
        inner-loop step needs.

    The step size comes from `self.step_size` (a scalar/tensor, or a dict
    keyed like `params`). Second-order terms are kept only when
    `self.first_order` is false and the method is in training mode.

    Returns
    -------
    updated_params : `collections.OrderedDict` instance
        Same keys as `params`, each value moved one step against its gradient.

    Raises
    ------
    ValueError
        If `model` is not a `MetaModule`.
    """
    if not isinstance(model, MetaModule):
        raise ValueError('The model must be an instance of `torchmeta.modules.'
                         'MetaModule`, got `{0}`'.format(type(model)))

    current = OrderedDict(model.meta_named_parameters()) if params is None else params

    keep_graph = self.training and not self.first_order
    gradients = torch.autograd.grad(loss,
                                    current.values(),
                                    create_graph=keep_graph)

    per_parameter_step = isinstance(self.step_size, (dict, OrderedDict))
    updated_params = OrderedDict()
    for (name, value), gradient in zip(current.items(), gradients):
        step = self.step_size[name] if per_parameter_step else self.step_size
        updated_params[name] = value - step * gradient

    return updated_params
class FSmethod(nn.Module):
    '''
    Abstract base class shared by all few-shot methods.

    Subclasses override `forward` (and `record_info` when they log metrics).
    '''
    def __init__(self, args: argparse.Namespace):
        super().__init__()

    def forward(self,
                model: torch.nn.Module,
                support: Tensor,
                query: Tensor,
                y_s: Tensor,
                y_q: Tensor,
                metrics: Dict[str, Metric] = None,
                task_ids: Tuple[int, int] = None) -> Tuple[Optional[Tensor], Tensor]:
        '''
        args:
            model: Network to train/test with
            support: Support images, of shape [n_tasks, n_support, C, H, W],
                     where n_tasks is the batch dimension (only useful for
                     fixed-dimension tasks) and n_support the total number of
                     support samples
            query: Query images, of shape [n_tasks, n_query, C, H, W], where
                   n_query is the total number of query samples
            y_s: Support labels, of shape [n_tasks, n_support]
            y_q: Query labels, of shape [n_tasks, n_query]
            metrics: A dictionary of Metric objects to be filled during
                     inference (mostly useful if the method performs
                     test-time inference)
            task_ids: Start and end task ids. Only used to fill the metrics
                      dictionary.

        returns:
            loss: Tensor of shape [] representing the loss to be minimized
                  (for methods using episodic training)
            soft_preds: Tensor of shape [n_tasks, n_query, K], where K is the
                        number of classes in the task, holding the soft
                        predictions for the query samples.
        '''
        raise NotImplementedError

    def record_info(self,
                    metrics: Optional[Dict],
                    task_ids: Optional[Tuple],
                    iteration: int,
                    new_time: float,
                    support: Tensor,
                    query: Tensor,
                    y_s: Tensor,
                    y_q: Tensor) -> None:
        """
        inputs:
            support : Tensor of shape [n_task, s_shot, feature_dim]
            query : Tensor of shape [n_task, q_shot, feature_dim]
            y_s : Tensor of shape [n_task, s_shot]
            y_q : Tensor of shape [n_task, q_shot]
        """
        raise NotImplementedError
class ProtoNet(FSmethod):
    """
    Implementation of ProtoNet method https://arxiv.org/abs/1703.05175
    """

    def __init__(self, args: argparse.Namespace):
        self.extract_batch_size = args.extract_batch_size
        # NOTE(review): attribute name is misspelled; kept as-is because other
        # code may read it.
        self.normamlize = False

        super().__init__(args)

    def record_info(self,
                    metrics: Optional[Dict],
                    task_ids: Optional[Tuple],
                    iteration: int,
                    preds_q: Tensor,
                    y_q: Tensor):
        """
        Forward the query predictions and labels to every metric.

        inputs:
            metrics : mapping of metric name to metric object; nothing happens
                      when it is empty or None
            task_ids : (start, end) task ids, required whenever metrics is given
            iteration : iteration index handed to each metric
            preds_q : predictions for the query samples
            y_q : query labels
        """
        if not metrics:
            return

        assert task_ids is not None
        first_task, last_task = task_ids[0], task_ids[1]
        for metric in metrics.values():
            metric.update(first_task,
                          last_task,
                          iteration,
                          gt=y_q,
                          preds=preds_q)

    def forward(self,
                model: torch.nn.Module,
                support: Tensor,
                query: Tensor,
                y_s: Tensor,
                y_q: Tensor,
                metrics: Dict[str, Metric] = None,
                task_ids: Tuple[int, int] = None) -> Tuple[Optional[Tensor], Tensor]:
        """
        inputs:
            support : tensor of size [n_task, s_shot, c, h, w]
            query : tensor of size [n_task, q_shot, c, h, w]
            y_s : tensor of size [n_task, s_shot]
            y_q : tensor of size [n_task, q_shot]

        returns:
            ce : per-query cross-entropy, one value per query sample
            soft_preds_q : class probabilities for the query samples
        """
        num_classes = y_s.unique().size(0)

        # NOTE(review): feature extraction is always called with batch size 0
        # here, so self.extract_batch_size is not used by this method.
        with torch.set_grad_enabled(self.training):
            z_s, z_q = extract_features(0, support, query, model)

        prototypes = compute_centroids(z_s, y_s)  # [batch, num_class, d]
        squared_dist = torch.cdist(z_q, prototypes) ** 2  # [batch, q_shot, num_class]
        log_probas = (-squared_dist).log_softmax(-1)  # [batch, q_shot, num_class]

        targets = get_one_hot(y_q, num_classes)  # [batch, q_shot, num_class]
        ce = - (targets * log_probas).sum(-1)  # [batch, q_shot]

        soft_preds_q = log_probas.detach().exp()
        self.record_info(iteration=0,
                         metrics=metrics,
                         task_ids=task_ids,
                         preds_q=soft_preds_q.argmax(-1),
                         y_q=y_q)

        return ce, soft_preds_q
class SimpleShot(FSmethod):
    """
    Implementation of SimpleShot method https://arxiv.org/abs/1911.04623

    Class centroids of the L2-normalised support features are stored in
    `self.weights` by `forward`; `get_logits` / `get_preds` read them and
    therefore must only be called after `forward` has run.
    """
    def __init__(self, args: argparse.Namespace):
        self.iter = args.iter
        self.episodic_training = False
        self.extract_batch_size = args.extract_batch_size

        super().__init__(args)

    def get_logits(self, samples: Tensor) -> Tensor:
        """
        inputs:
            samples : tensor of shape [n_task, shot, feature_dim]

        returns :
            logits : tensor of shape [n_task, shot, num_class]
        """
        n_tasks = samples.size(0)
        # Expanded form of -1/2 * ||sample - centroid||^2.
        logits = (samples.matmul(self.weights.transpose(1, 2))
                  - 1 / 2 * (self.weights**2).sum(2).view(n_tasks, 1, -1)
                  - 1 / 2 * (samples**2).sum(2).view(n_tasks, -1, 1))

        return logits

    def get_preds(self, samples: Tensor) -> Tensor:
        """
        inputs:
            samples : tensor of shape [n_task, s_shot, feature_dim]

        returns :
            preds : tensor of shape [n_task, shot]
        """
        return self.get_logits(samples).argmax(2)

    def record_info(self,
                    metrics: Optional[Dict],
                    task_ids: Optional[Tuple],
                    iteration: int,
                    new_time: float,
                    support: Tensor,
                    query: Tensor,
                    y_s: Tensor,
                    y_q: Tensor) -> None:
        """
        Push the current predictions to every metric; returns nothing.

        inputs:
            metrics : mapping of metric name to metric object; nothing happens
                      when it is empty or None
            task_ids : (start, end) task ids, required whenever metrics is given
            iteration : iteration index handed to each metric
            new_time : accepted for interface compatibility, not used here
            support : tensor of shape [n_task, s_shot, feature_dim]
            query : tensor of shape [n_task, q_shot, feature_dim]
            y_s : tensor of shape [n_task, s_shot]
            y_q : tensor of shape [n_task, q_shot]
        """
        if not metrics:
            return

        logits_s = self.get_logits(support).detach()
        probs_s = logits_s.softmax(-1)

        logits_q = self.get_logits(query).detach()
        preds_q = logits_q.argmax(2)
        probs_q = logits_q.softmax(2)

        kwargs = {'probs': probs_q, 'probs_s': probs_s, 'preds': preds_q,
                  'gt': y_q, 'z_s': support, 'z_q': query, 'gt_s': y_s,
                  'weights': self.weights}

        assert task_ids is not None
        for metric in metrics.values():
            metric.update(task_ids[0],
                          task_ids[1],
                          iteration,
                          **kwargs)

    def forward(self,
                model: torch.nn.Module,
                support: Tensor,
                query: Tensor,
                y_s: Tensor,
                y_q: Tensor,
                metrics: Optional[Dict[str, Metric]] = None,
                task_ids: Optional[Tuple[int, int]] = None) -> Tuple[Optional[Tensor], Tensor]:
        """
        Classify the query samples by their nearest support centroid.

        The model is switched to eval mode and features are extracted without
        gradients. Returns `(None, P_q)`: there is no loss, and `P_q` holds the
        class probabilities of the query samples.
        """
        model.eval()
        with torch.no_grad():
            feat_s, feat_q = extract_features(self.extract_batch_size,
                                              support, query, model)

        # Perform required normalizations
        feat_s = F.normalize(feat_s, dim=-1)
        feat_q = F.normalize(feat_q, dim=-1)

        # Initialize weights
        t0 = time.time()
        self.weights = compute_centroids(feat_s, y_s)
        self.record_info(iteration=0,
                         metrics=metrics,
                         task_ids=task_ids,
                         new_time=time.time() - t0,
                         support=feat_s,
                         query=feat_q,
                         y_s=y_s,
                         y_q=y_q)

        P_q = self.get_logits(feat_q).softmax(2)

        return None, P_q
def compute_centroids(z_s: Tensor,
                      y_s: Tensor):
    """
    Average the support features of every class.

    inputs:
        z_s : torch.Tensor of size [*, s_shot, d]
        y_s : torch.Tensor of size [*, s_shot]
    returns :
        centroids : torch.Tensor of size [*, num_class, d], where num_class is
                    the number of distinct values found in y_s
    """
    num_classes = y_s.unique().size(0)
    one_hot = get_one_hot(y_s, num_classes=num_classes).transpose(-2, -1)  # [*, K, s_shot]
    samples_per_class = one_hot.sum(-1, keepdim=True)  # [*, K, 1]
    centroids = one_hot.matmul(z_s) / samples_per_class  # [*, K, d]

    return centroids


def get_one_hot(y_s: Tensor, num_classes: int):
    """
    args:
        y_s : torch.Tensor of shape [*] holding integer labels
        num_classes : size of the trailing one-hot dimension
    returns
        one_hot : torch.Tensor of shape [*, num_classes], on the device of y_s
    """
    one_hot_size = list(y_s.size()) + [num_classes]
    one_hot = torch.zeros(one_hot_size, device=y_s.device)
    # scatter_ only accepts int64 indices, so narrower integer labels are cast.
    one_hot.scatter_(-1, y_s.unsqueeze(-1).long(), 1)

    return one_hot


def extract_features(bs: int,
                     support: Tensor,
                     query: Tensor,
                     model: nn.Module):
    """
    Extract features from support and query set using the provided model
    args:
        bs : chunk size used to feed the model; a value <= 0 feeds every
             sample of every task in one call
        support : torch.Tensor of size [batch, s_shot, c, h, w]
        query : torch.Tensor of size [batch, q_shot, c, h, w]
        model : module called as model(x, feature=True)
    returns
        z_s : torch.Tensor of shape [batch, s_shot, d]
        z_q : torch.Tensor of shape [batch, q_shot, d]
    raises
        ValueError if bs > 0 and more than one task is given
    """
    # Extract support and query features
    n_tasks, shots_s, C, H, W = support.size()
    shots_q = query.size(1)

    # Only ask for the rank when a process group really exists; otherwise fall
    # back to the current CUDA device (or the CPU).
    if not torch.cuda.is_available():
        device = 'cpu'
    elif dist.is_available() and dist.is_initialized():
        device = dist.get_rank()
    else:
        device = torch.cuda.current_device()

    if bs > 0:
        if n_tasks > 1:
            raise ValueError("Multi task and feature batching not yet supported")
        feat_s = batch_feature_extract(model, support, bs, device)
        feat_q = batch_feature_extract(model, query, bs, device)
    else:
        support = support.to(device)
        query = query.to(device)
        feat_s = model(support.view(n_tasks * shots_s, C, H, W), feature=True)
        feat_q = model(query.view(n_tasks * shots_q, C, H, W), feature=True)
        feat_s = feat_s.view(n_tasks, shots_s, -1)
        feat_q = feat_q.view(n_tasks, shots_q, -1)

    return feat_s, feat_q


def batch_feature_extract(model: nn.Module,
                          t: Tensor,
                          bs: int,
                          device: torch.device) -> Tensor:
    """
    Run the model over the samples of the first task in chunks of `bs`.

    Only index 0 of the leading dimension of `t` is read. The per-chunk
    features are concatenated and returned with a leading dimension of 1.
    """
    shots = t.size(1)

    feats: List[Tensor] = []
    for start in range(0, shots, bs):
        end = min(shots, start + bs)

        x = t[0, start:end, ...].to(device)
        feats.append(model(x, feature=True))

    return torch.cat(feats, 0).unsqueeze(0)
def ECE_(soft_preds: Tensor,
         targets: Tensor,
         reduc_dims: List[int],
         n_bins: int = 10) -> Tensor:
    '''
    Expected calibration error of a batch of soft predictions.

    args:

        soft_preds: tensor of shape [*, K]
        targets: tensor of shape [*,]
        reduc_dims: dimensions of the per-sample statistics that are summed
                    to build the per-bin statistics
        n_bins: number of bin edges spread evenly over [0, 1]

    returns:
        ECE: per-bin |accuracy - confidence| weighted by bin population,
             summed over the bin dimension (kept) and divided by the total
             number of targets
    '''
    device = soft_preds.device

    highest_prob, hard_preds = soft_preds.max(-1)
    bins = torch.linspace(0, 1, n_bins).to(device)
    binned_indexes = vectorized_binning(highest_prob, bins)  # one-hot [*, M]

    accurate = (hard_preds == targets).float().unsqueeze(-1)

    B = binned_indexes.sum(reduc_dims) + 1e-10  # population of each bin
    A = (accurate * binned_indexes).sum(reduc_dims) / B  # accuracy of each bin
    C = (highest_prob.unsqueeze(-1) * binned_indexes).sum(reduc_dims) / B  # confidence of each bin
    N = targets.numel()
    ECE = ((B * torch.abs(A - C))).sum(-1, keepdim=True) / N

    return ECE


def vectorized_binning(probs: Tensor,
                       bins: Tensor) -> Tensor:
    '''
    One-hot encode the bin each probability falls into.

    args:

        probs: tensor of shape [*,]
        bins: tensor of shape [M] holding increasing bin edges

    returns:
        tensor of shape [*, M]; entry k is 1 when exactly k edges are <= the
        probability. A probability that reaches the last edge (e.g. exactly
        1.0) is counted in the last slot instead of indexing past it.
    '''
    batch_shape = probs.shape
    n_bins = bins.shape[-1]

    bins = bins.view(*([1] * len(batch_shape)), -1)  # [1, ..., 1, M]
    probs = probs.unsqueeze(-1)  # [*, 1]
    bins_indexes = (probs >= bins).long().sum(-1, keepdim=True)  # [*, 1]
    # A value >= every edge would yield index M, which is out of range.
    bins_indexes = bins_indexes.clamp(max=n_bins - 1)

    ones_hot_indexes = torch.zeros(*batch_shape, n_bins, device=probs.device)  # [*, M]
    ones_hot_indexes.scatter_(-1, bins_indexes, 1.)

    return ones_hot_indexes
def get_model(args: argparse.Namespace,
              num_classes: int):
    """
    Build the backbone named by `args.arch`.

    When the string 'MAML' occurs in `args.method`, the architecture is taken
    from the meta-model registry; otherwise it comes from the standard
    registry and additionally receives `pretrained=args.load_from_timm`.
    """
    if 'MAML' not in args.method:
        logger.info(f"Standard {args.arch} loaded")
        return standard_dict[args.arch](num_classes=num_classes,
                                        pretrained=args.load_from_timm)

    logger.info(f"Meta {args.arch} loaded")
    return meta_dict[args.arch](num_classes=num_classes)
class conv4(MetaModule):
    """
    Four conv3x3 stages followed by a linear classifier.

    The classifier expects `hidden_size * 5 * 5` input features. With the
    default `hidden_size=64` this is the previous fixed value of 1600.
    NOTE(review): the 5x5 spatial size presumably corresponds to 84x84 inputs
    halved by four poolings; confirm against the data pipeline.
    """
    def __init__(self, num_classes, in_channels=3, hidden_size=64, **kwargs):
        super(conv4, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.hidden_size = hidden_size

        self.features = MetaSequential(
            conv3x3(in_channels, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size)
        )

        # Follow hidden_size instead of hard-coding 1600 so that non-default
        # widths produce a classifier of matching shape.
        self.classifier = MetaLinear(hidden_size * 5 * 5, num_classes)

    def forward(self, inputs, params=None, features=False, feature=False):
        """
        Run the network, optionally with externally supplied parameters.

        args:
            inputs: image batch
            params: optional dictionary of parameters keyed with the
                    'features.' / 'classifier.' prefixes
            features: when true, return the flattened features instead of logits
            feature: alias of `features`, matching the keyword used by the
                     feature-extraction helpers (`model(x, feature=True)`)
        """
        x = self.features(inputs, params=self.get_subdict(params, 'features'))
        x = x.view((x.size(0), -1))
        if features or feature:
            return x
        logits = self.classifier(x, params=self.get_subdict(params, 'classifier'))
        return logits
class MetaMultiheadAttention(nn.MultiheadAttention, MetaModule):
    __doc__ = nn.MultiheadAttention.__doc__

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the output projection with its meta-aware counterpart.
        self.out_proj = MetaLinear(self.embed_dim, self.embed_dim, bias=True)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None, params=None):
        """Attention forward pass reading every weight from `params`.

        When `params` is None the module's own named parameters are used.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())

        # Separate q/k/v projections are only looked up when the embedding
        # dimensions differ.
        projection_kwargs = {}
        if not self._qkv_same_embed_dim:
            projection_kwargs = dict(
                use_separate_proj_weight=True,
                q_proj_weight=params['q_proj_weight'],
                k_proj_weight=params['k_proj_weight'],
                v_proj_weight=params['v_proj_weight'])

        return F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            params.get('in_proj_weight', None), params.get('in_proj_bias', None),
            params.get('bias_k', None), params.get('bias_v', None),
            self.add_zero_attn,
            self.dropout, params['out_proj.weight'], params['out_proj.bias'],
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask, **projection_kwargs)
class _MetaBatchNorm(_BatchNorm, MetaModule):
    def forward(self, input, params=None):
        """Batch-norm forward pass taking weight/bias from `params`.

        When `params` is None the module's own named parameters are used.
        """
        self._check_input_dim(input)
        if params is None:
            params = OrderedDict(self.named_parameters())

        # The averaging factor defaults to the momentum (0.0 when unset).
        exponential_average_factor = 0.0 if self.momentum is None else self.momentum

        tracking = self.training and self.track_running_stats
        if tracking and self.num_batches_tracked is not None:
            self.num_batches_tracked += 1
            if self.momentum is None:
                # cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)

        use_batch_stats = self.training or not self.track_running_stats
        return F.batch_norm(
            input, self.running_mean, self.running_var,
            params.get('weight', None), params.get('bias', None),
            use_batch_stats, exponential_average_factor, self.eps)


class MetaBatchNorm1d(_MetaBatchNorm):
    __doc__ = nn.BatchNorm1d.__doc__

    def _check_input_dim(self, input):
        # Accepts 2D or 3D inputs only.
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))


class MetaBatchNorm2d(_MetaBatchNorm):
    __doc__ = nn.BatchNorm2d.__doc__

    def _check_input_dim(self, input):
        # Accepts 4D inputs only.
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


class MetaBatchNorm3d(_MetaBatchNorm):
    __doc__ = nn.BatchNorm3d.__doc__

    def _check_input_dim(self, input):
        # Accepts 5D inputs only.
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
class MetaSequential(nn.Sequential, MetaModule):
    __doc__ = nn.Sequential.__doc__

    def forward(self, input, params=None):
        """Apply the children in order.

        Meta-aware children receive the sub-dictionary of `params` filed under
        their own name; plain torch modules are called without it.
        """
        output = input
        for name, module in self._modules.items():
            if isinstance(module, MetaModule):
                output = module(output, params=self.get_subdict(params, name))
                continue
            if not isinstance(module, nn.Module):
                raise TypeError('The module must be either a torch module '
                                '(inheriting from `nn.Module`), or a `MetaModule`. '
                                'Got type: `{0}`'.format(type(module)))
            output = module(output)
        return output
class MetaConv1d(nn.Conv1d, MetaModule):
    __doc__ = nn.Conv1d.__doc__

    def forward(self, input, params=None):
        """1D convolution using `params['weight']` and the optional bias.

        Only the 'circular' padding mode is special-cased; every other mode
        goes through the plain call with `self.padding`.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode != 'circular':
            return F.conv1d(input, params['weight'], bias, self.stride,
                            self.padding, self.dilation, self.groups)

        pad = self.padding[0]
        wrapped = F.pad(input, ((pad + 1) // 2, pad // 2), mode='circular')
        return F.conv1d(wrapped, params['weight'], bias, self.stride,
                        _single(0), self.dilation, self.groups)


class MetaConv2d(nn.Conv2d, MetaModule):
    __doc__ = nn.Conv2d.__doc__

    def forward(self, input, params=None):
        """2D convolution using `params['weight']` and the optional bias.

        Only the 'circular' padding mode is special-cased; every other mode
        goes through the plain call with `self.padding`.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode != 'circular':
            return F.conv2d(input, params['weight'], bias, self.stride,
                            self.padding, self.dilation, self.groups)

        pad_h, pad_w = self.padding[0], self.padding[1]
        wrapped = F.pad(input,
                        ((pad_w + 1) // 2, pad_w // 2,
                         (pad_h + 1) // 2, pad_h // 2),
                        mode='circular')
        return F.conv2d(wrapped, params['weight'], bias, self.stride,
                        _pair(0), self.dilation, self.groups)


class MetaConv3d(nn.Conv3d, MetaModule):
    __doc__ = nn.Conv3d.__doc__

    def forward(self, input, params=None):
        """3D convolution using `params['weight']` and the optional bias.

        Only the 'circular' padding mode is special-cased; every other mode
        goes through the plain call with `self.padding`.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())
        bias = params.get('bias', None)

        if self.padding_mode != 'circular':
            return F.conv3d(input, params['weight'], bias, self.stride,
                            self.padding, self.dilation, self.groups)

        pad_0, pad_1, pad_2 = self.padding[0], self.padding[1], self.padding[2]
        wrapped = F.pad(input,
                        ((pad_2 + 1) // 2, pad_2 // 2,
                         (pad_1 + 1) // 2, pad_1 // 2,
                         (pad_0 + 1) // 2, pad_0 // 2),
                        mode='circular')
        return F.conv3d(wrapped, params['weight'], bias, self.stride,
                        _triple(0), self.dilation, self.groups)
class MetaLinear(nn.Linear, MetaModule):
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        """Linear map using `params['weight']` and the optional `params['bias']`.

        When `params` is None the module's own named parameters are used.
        """
        source = OrderedDict(self.named_parameters()) if params is None else params
        return F.linear(input, source['weight'], source.get('bias', None))


class MetaBilinear(nn.Bilinear, MetaModule):
    __doc__ = nn.Bilinear.__doc__

    def forward(self, input1, input2, params=None):
        """Bilinear map using `params['weight']` and the optional `params['bias']`.

        When `params` is None the module's own named parameters are used.
        """
        source = OrderedDict(self.named_parameters()) if params is None else params
        return F.bilinear(input1, input2, source['weight'], source.get('bias', None))
class MetaModule(nn.Module):
    """
    Base class for PyTorch meta-learning modules. These modules accept an
    additional argument `params` in their `forward` method.

    Notes
    -----
    Objects inherited from `MetaModule` are fully compatible with PyTorch
    modules from `torch.nn.Module`. The argument `params` is a dictionary of
    tensors, with full support of the computation graph (for differentiation).
    """
    def __init__(self):
        super(MetaModule, self).__init__()
        # Maps (key, tuple of parameter names) to the names found under `key`,
        # so the regex filtering in `get_subdict` runs once per distinct input.
        self._children_modules_parameters_cache = dict()

    def meta_named_parameters(self, prefix='', recurse=True):
        """Yield (name, parameter) pairs owned by `MetaModule` instances only."""
        yield from self._named_members(
            lambda module: module._parameters.items()
            if isinstance(module, MetaModule) else [],
            prefix=prefix, recurse=recurse)

    def meta_parameters(self, recurse=True):
        """Yield the parameters reported by `meta_named_parameters`."""
        for _, param in self.meta_named_parameters(recurse=recurse):
            yield param

    def get_subdict(self, params, key=None):
        """Select the entries of `params` that belong to the submodule `key`.

        Parameters
        ----------
        params : dict of tensors or None
            Parameters keyed by dotted names.
        key : str, optional
            Submodule name. Entries named `'<key>.<rest>'` are returned under
            `'<rest>'`. With `key=None` every entry is returned under its
            original name.

        Returns
        -------
        `collections.OrderedDict` or None
            None when `params` is None, or when nothing matches (a warning is
            emitted in the latter case).
        """
        if params is None:
            return None

        all_names = tuple(params.keys())
        cache = self._children_modules_parameters_cache
        cache_key = (key, all_names)
        if cache_key not in cache:
            if key is None:
                cache[cache_key] = all_names
            else:
                key_re = re.compile(r'^{0}\.(.+)'.format(re.escape(key)))
                cache[cache_key] = [
                    key_re.sub(r'\1', k) for k in all_names if key_re.match(k) is not None]

        names = cache[cache_key]
        if not names:
            warnings.warn('Module `{0}` has no parameter corresponding to the '
                          'submodule named `{1}` in the dictionary `params` '
                          'provided as an argument to `forward()`. Using the '
                          'default parameters for this submodule. The list of '
                          'the parameters in `params`: [{2}].'.format(
                              self.__class__.__name__, key, ', '.join(all_names)),
                          stacklevel=2)
            return None

        # Without a key the names are already complete; prefixing them with
        # 'None.' would make every lookup fail.
        prefix = '' if key is None else f'{key}.'
        return OrderedDict([(name, params[prefix + name]) for name in names])
class MetaLayerNorm(nn.LayerNorm, MetaModule):
    __doc__ = nn.LayerNorm.__doc__

    def forward(self, input, params=None):
        """Layer normalisation taking the optional weight/bias from `params`.

        When `params` is None the module's own named parameters are used.
        """
        source = OrderedDict(self.named_parameters()) if params is None else params
        return F.layer_norm(
            input, self.normalized_shape,
            source.get('weight', None), source.get('bias', None), self.eps)
class DataParallel(DataParallel_, MetaModule):
    # Data-parallel wrapper that also distributes the `params` keyword argument
    # of meta-modules across the target devices.
    __doc__ = DataParallel_.__doc__

    def scatter(self, inputs, kwargs, device_ids):
        # Split inputs/kwargs per device; for a wrapped MetaModule, `params` is
        # taken out, replicated per device and put back into each kwargs dict.
        if not isinstance(self.module, MetaModule):
            # Plain modules keep the stock scattering behaviour.
            return super(DataParallel, self).scatter(inputs, kwargs, device_ids)

        # NOTE: `pop` removes 'params' from the caller-supplied kwargs dict.
        params = kwargs.pop('params', None)
        inputs_, kwargs_ = scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
        # Add params argument unchanged back in kwargs
        # Replicas are detached whenever gradients are globally disabled.
        replicas = self._replicate_params(params, inputs_, device_ids,
                                          detach=not torch.is_grad_enabled())
        kwargs_ = tuple(dict(params=replica, **kwarg)
                        for (kwarg, replica) in zip(kwargs_, replicas))
        return inputs_, kwargs_

    def _replicate_params(self, params, inputs, device_ids, detach=False):
        # Return one OrderedDict of parameters per scattered input chunk,
        # keyed like the parameters of the wrapped module.
        if params is None:
            module_params = OrderedDict(self.module.named_parameters())
        else:
            # Temporarily disable the warning if no parameter with key prefix
            # `module` was found. In that case, the original params dictionary
            # is used.
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                module_params = self.get_subdict(params, key='module')
            if module_params is None:
                module_params = params

        # Only as many devices as there are input chunks receive a replica.
        replicas = _broadcast_coalesced_reshape(list(module_params.values()),
                                                device_ids[:len(inputs)],
                                                detach)
        replicas = tuple(OrderedDict(zip(module_params.keys(), replica))
                         for replica in replicas)
        return replicas
23 | return F.embedding_bag(input, params['weight'], offsets, 24 | self.max_norm, self.norm_type, 25 | self.scale_grad_by_freq, self.mode, self.sparse, 26 | per_sample_weights, self.include_last_offset) 27 | -------------------------------------------------------------------------------- /src/models/meta/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .metamodules import (MetaModule, MetaSequential, MetaConv2d, 3 | MetaBatchNorm2d, MetaLinear) 4 | __all__ = ['resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 5 | 'resnet152'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """3x3 convolution with padding""" 10 | return MetaConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | def conv1x1(in_planes, out_planes, stride=1): 15 | """1x1 convolution""" 16 | return MetaConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 17 | 18 | 19 | def BasicBlock(inplanes, planes, stride=1, downsample=None): 20 | block = MetaSequential( 21 | conv3x3(inplanes, planes, stride), 22 | MetaBatchNorm2d(planes), 23 | nn.ReLU(inplace=True), 24 | conv3x3(planes, planes), 25 | MetaBatchNorm2d(planes), 26 | nn.ReLU(inplace=True), 27 | ) 28 | block.expansion = 1 29 | return block 30 | 31 | 32 | class Bottleneck(MetaModule): 33 | expansion = 4 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None): 36 | super(Bottleneck, self).__init__() 37 | self.bottleneck = MetaSequential( 38 | conv1x1(inplanes, planes), 39 | MetaBatchNorm2d(planes), 40 | nn.ReLU(inplace=True), 41 | conv3x3(planes, planes, stride), 42 | MetaBatchNorm2d(planes), 43 | nn.ReLU(inplace=True), 44 | conv1x1(planes, planes * self.expansion), 45 | MetaBatchNorm2d(planes * self.expansion), 46 | nn.ReLU(inplace=True), 47 | ) 48 | 49 | def forward(self, x, params=None): 50 | return self.bottleneck(x, params=self.get_subdict(params, 'bottleneck')) 51 | 
class ResNet(MetaModule):
    """ResNet-style backbone built from meta-modules.

    ``features`` is a ``MetaSequential`` made of a 3x3 stem convolution,
    batch-norm, ReLU, four stages produced by :meth:`_make_layer` and a global
    average pooling; ``classifier`` is a ``MetaLinear`` mapping
    ``512 * expansion`` features to ``num_classes`` logits.

    Args:
        block: callable ``block(inplanes, planes, stride, downsample)``
            returning one stage block.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the classifier.
        zero_init_residual: see the NOTE(review) in ``__init__``.
        expansion: channel multiplier applied to the output of every block.
        **kwargs: accepted and ignored.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, expansion=1, **kwargs):
        super(ResNet, self).__init__()
        self.inplanes = 64

        self.features = MetaSequential(
            MetaConv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            MetaBatchNorm2d(64),
            nn.ReLU(inplace=True),
            self._make_layer(block, 64, layers[0], expansion=expansion),
            self._make_layer(block, 128, layers[1], stride=2, expansion=expansion),
            self._make_layer(block, 256, layers[2], stride=2, expansion=expansion),
            self._make_layer(block, 512, layers[3], stride=2, expansion=expansion),
            nn.AdaptiveAvgPool2d((1, 1)))
        self.classifier = MetaLinear(512 * expansion, num_classes)

        for m in self.modules():
            if isinstance(m, MetaConv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, MetaBatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        #
        # NOTE(review): this branch cannot succeed as written. `Bottleneck` in
        # this file only owns a `bottleneck` sequential (no `bn3` attribute),
        # and `BasicBlock` is a factory function, not a class, so it is not a
        # valid second argument to `isinstance`. The blocks also have no skip
        # connection, so zeroing their last BN would zero the block output.
        # Left untouched on purpose; decide on the intended behaviour before
        # enabling `zero_init_residual=True`.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, expansion=1):
        """Stack ``blocks`` blocks into one stage and update ``self.inplanes``.

        Only the first block receives ``stride`` and the ``downsample`` module.
        """
        # NOTE(review): the block factories visible in this file accept
        # `downsample` but never use it - confirm this is intended.
        downsample = None
        if stride != 1 or self.inplanes != planes * expansion:
            downsample = MetaSequential(
                conv1x1(self.inplanes, planes * expansion, stride),
                MetaBatchNorm2d(planes * expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return MetaSequential(*layers)

    def forward(self, inputs, params=None, features=False):
        """Return flattened features if ``features`` is true, else the logits.

        ``params`` is split with ``get_subdict`` into the ``'features'`` and
        ``'classifier'`` sub-dictionaries before being forwarded.
        """
        x = self.features(inputs, params=self.get_subdict(params, 'features'))
        x = x.view((x.size(0), -1))
        if features:
            return x
        logits = self.classifier(x, params=self.get_subdict(params, 'classifier'))
        return logits


def resnet10(**kwargs):
    """Constructs a ResNet-10 model.
    """
    model = ResNet(BasicBlock, [1, 1, 1, 1], expansion=1, **kwargs)
    return model


def resnet18(**kwargs):
    """Constructs a ResNet-18 model.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], expansion=1, **kwargs)
    return model


def resnet34(**kwargs):
    """Constructs a ResNet-34 model.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], expansion=1, **kwargs)
    return model


def resnet50(**kwargs):
    """Constructs a ResNet-50 model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], expansion=4, **kwargs)
    return model


def resnet101(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], expansion=4, **kwargs)
    return model


def resnet152(**kwargs):
    """Constructs a ResNet-152 model.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], expansion=4, **kwargs)
    return model

# --------------------------------------------------------------------------------
# /src/models/meta/wideres.py:
# --------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from .metamodules import (MetaModule, MetaSequential, MetaConv2d,
                          MetaBatchNorm2d, MetaLinear)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 meta-convolution with padding 1 and a bias term."""
    return MetaConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)


def conv_init(m):
    """Initialise ``m`` in place according to its class name.

    Classes whose name contains ``'Conv'`` get Xavier-uniform weights with
    gain ``sqrt(2)``; classes whose name contains ``'BatchNorm'`` get unit
    weights. Biases are zeroed when the module has one.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # In-place (underscore) initialisers: the non-underscore names are deprecated.
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        init.constant_(m.weight, 1)
        if m.bias is not None:
            init.constant_(m.bias, 0)


def wide_basic(in_planes, planes, dropout_rate, stride=1):
    """Build a BN-ReLU-conv-dropout-BN-ReLU-conv ``MetaSequential``.

    ``stride`` is applied by the second convolution. Unlike the
    ``nn.Module`` version of this block, no shortcut branch is created.
    """
    block = MetaSequential(
        MetaBatchNorm2d(in_planes),
        nn.ReLU(inplace=True),
        MetaConv2d(in_planes, planes, kernel_size=3, padding=1, bias=True),
        nn.Dropout(p=dropout_rate),
        MetaBatchNorm2d(planes),
        nn.ReLU(inplace=True),
        MetaConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
    )
    return block


class Wide_ResNet(MetaModule):
    """Wide-ResNet backbone built from meta-modules.

    Args:
        depth: network depth, must be of the form ``6n + 4``.
        widen_factor: channel multiplier ``k`` of the three stages.
        dropout_rate: dropout probability used inside every block.
        num_classes: output size of the classifier.
        use_fc: when false no ``classifier`` attribute is created, so only
            ``forward(..., features=True)`` can be used.
    """

    def __init__(self, depth, widen_factor, dropout_rate, num_classes, use_fc=True):
        super(Wide_ResNet, self).__init__()
        self.in_planes = 16

        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6
        k = widen_factor

        print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        self.features = MetaSequential(
            conv3x3(3, nStages[0]),
            self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1),
            self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2),
            self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2),
            MetaBatchNorm2d(nStages[3], momentum=0.9),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)))
        if use_fc:
            self.classifier = MetaLinear(nStages[3], num_classes)
        for m in self.modules():
            if isinstance(m, MetaConv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, MetaBatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes

        return MetaSequential(*layers)

    def forward(self, inputs, params=None, features=False):
        """Return flattened features if ``features`` is true, else the logits."""
        x = self.features(inputs, params=self.get_subdict(params, 'features'))
        x = x.view((x.size(0), -1))
        if features:
            return x
        logits = self.classifier(x, params=self.get_subdict(params, 'classifier'))
        return logits


def wideres2810(num_classes):
    """Constructs a wideres-28-10 model without dropout.
    """
    return Wide_ResNet(28, 10, 0, num_classes)


if __name__ == '__main__':
    net = Wide_ResNet(28, 10, 0.3, 10)
    y = net(Variable(torch.randn(1, 3, 32, 32)))
# --------------------------------------------------------------------------------
# /src/models/standard/__init__.py:
# --------------------------------------------------------------------------------
from .resnet import *
from .resnet_v2 import *
from .vit import *

from .resnet_v2 import default_cfgs as resnet_v2_config
from .resnet import default_cfgs as resnet_config
from .vit import default_cfgs as vit_config

# Merge into a fresh dict so that `resnet_v2.default_cfgs` is not mutated.
model_configs = dict(resnet_v2_config)
model_configs.update(resnet_config)
model_configs.update(vit_config)
# --------------------------------------------------------------------------------
# /src/models/standard/conv4.py:
# --------------------------------------------------------------------------------
from torch import nn


def conv_block(in_channels: int, out_channels: int) -> nn.Module:
    """Return a 3x3 conv -> batch-norm -> ReLU -> 2x2 max-pool block."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )


class conv4(nn.Module):
    """Four stacked ``conv_block`` layers followed by an optional linear head.

    The head expects 1600 flattened features; with ``remove_linear=True`` no
    head is created and ``self.logits`` is ``None``.
    """

    def __init__(self, num_classes, remove_linear=False):
        super(conv4, self).__init__()
        self.conv1 = conv_block(3, 64)
        self.conv2 = conv_block(64, 64)
        self.conv3 = conv_block(64, 64)
        self.conv4 = conv_block(64, 64)
        if remove_linear:
            self.logits = None
        else:
            self.logits = nn.Linear(1600, num_classes)

    def forward(self, x, feature=False):
        """Run the network.

        Returns ``(features, logits)`` when ``feature`` is true (``logits`` is
        ``None`` without a head); otherwise the logits, or the flattened
        features when there is no head.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        x = x.view(x.size(0), -1)
        if self.logits is None:
            if feature:
                return x, None
            else:
                return x
        if feature:
            x1 = self.logits(x)
            return x, x1

        return self.logits(x)
# --------------------------------------------------------------------------------
# /src/models/standard/hub.py:
# --------------------------------------------------------------------------------
import json
from loguru import logger
import os
from functools import partial
from pathlib import Path
from typing import Union

import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
try:
    from torch.hub import get_dir
except ImportError:
    from torch.hub import _get_torch_home as get_dir

from timm import __version__
try:
    from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
    cached_download = partial(cached_download, library_name="timm", library_version=__version__)
    _has_hf_hub = True
except ImportError:
    cached_download = None
    _has_hf_hub = False


def get_cache_dir(child_dir=''):
    """
    Returns the location of the directory where models are cached (and creates it if necessary).
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    child_dir = () if not child_dir else (child_dir,)
    model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def download_cached_file(url, check_hash=True, progress=False):
    """Download ``url`` into the cache directory unless already present.

    With ``check_hash`` the hash prefix embedded in the file name (if any) is
    passed to ``download_url_to_file``. Returns the path of the cached file.
    """
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if not os.path.exists(cached_file):
        logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file


def has_hf_hub(necessary=False):
    """Tell whether ``huggingface_hub`` was imported; raise if it is ``necessary`` and missing."""
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub


def hf_split(hf_id):
    """Split ``'model_id@revision'`` into ``(model_id, revision)``; revision is ``None`` when absent."""
    rev_split = hf_id.split('@')
    assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
    hf_model_id = rev_split[0]
    hf_revision = rev_split[-1] if len(rev_split) > 1 else None
    return hf_model_id, hf_revision


def load_cfg_from_json(json_file: Union[str, os.PathLike]):
    """Read ``json_file`` as UTF-8 and return the parsed JSON content."""
    with open(json_file, "r", encoding="utf-8") as reader:
        text = reader.read()
    return json.loads(text)


def _download_from_hf(model_id: str, filename: str):
    """Fetch ``filename`` of ``model_id`` (optionally ``id@revision``) into the ``hf`` cache sub-directory."""
    hf_model_id, hf_revision = hf_split(model_id)
    url = hf_hub_url(hf_model_id, filename, revision=hf_revision)
    return cached_download(url, cache_dir=get_cache_dir('hf'))


def load_model_config_from_hf(model_id: str):
    """Download ``config.json`` for ``model_id``; return ``(config, architecture name)``."""
    assert has_hf_hub(True)
    cached_file = _download_from_hf(model_id, 'config.json')
    default_cfg = load_cfg_from_json(cached_file)
    default_cfg['hf_hub'] = model_id  # insert hf_hub id for pretrained weight load during model creation
    model_name = default_cfg.get('architecture')
    return default_cfg, model_name


def load_state_dict_from_hf(model_id: str):
    """Download ``pytorch_model.bin`` for ``model_id`` and load it on CPU."""
    assert has_hf_hub(True)
    cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
    # NOTE(review): `torch.load` unpickles the downloaded file; only use it
    # with repositories you trust.
    state_dict = torch.load(cached_file, map_location='cpu')
    return state_dict


def save_for_hf(model, save_directory, model_config=None):
    """Write ``pytorch_model.bin`` and ``config.json`` for ``model`` into ``save_directory``.

    The config starts from ``model.default_cfg`` and is completed or
    overridden by ``model_config``. Both inputs are copied first, so neither
    the model's config nor the caller's dictionary is modified.
    """
    assert has_hf_hub(True)
    model_config = dict(model_config) if model_config else {}
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    weights_path = save_directory / 'pytorch_model.bin'
    torch.save(model.state_dict(), weights_path)

    config_path = save_directory / 'config.json'
    hf_config = dict(model.default_cfg)
    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
    hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
    hf_config.update(model_config)

    with config_path.open('w') as f:
        json.dump(hf_config, f, indent=2)


def push_to_hf_hub(
    model,
    local_dir,
    repo_namespace_or_url=None,
    commit_message='Add model',
    use_auth_token=True,
    git_email=None,
    git_user=None,
    revision=None,
    model_config=None,
):
    """Commit the model weights, config and a default README to a Hugging Face repository.

    The repository is ``repo_namespace_or_url`` when given; otherwise it is
    named after ``local_dir`` under the account owning the auth token.
    Returns the git remote URL of the repository.

    Raises:
        ValueError: if no repository is given and no token is available.
    """
    if repo_namespace_or_url:
        repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
    else:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            token = HfFolder.get_token()

        if token is None:
            raise ValueError(
                "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
                "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
                "token as the `use_auth_token` argument."
            )

        repo_owner = HfApi().whoami(token)['name']
        repo_name = Path(local_dir).name

    repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'

    repo = Repository(
        local_dir,
        clone_from=repo_url,
        use_auth_token=use_auth_token,
        git_user=git_user,
        git_email=git_email,
        revision=revision,
    )

    # Prepare a default model card that includes the necessary tags to enable inference.
    readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
    with repo.commit(commit_message):
        # Save model weights and config.
        save_for_hf(model, repo.local_dir, model_config=model_config)

        # Save a model card if it doesn't exist.
        readme_path = Path(repo.local_dir) / 'README.md'
        if not readme_path.exists():
            readme_path.write_text(readme_text)

    return repo.git_remote_url()
# --------------------------------------------------------------------------------
# /src/models/standard/wideres.py:
# --------------------------------------------------------------------------------
import numpy as np
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and a bias term."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)


def conv_init(m):
    """Initialise ``m`` in place according to its class name.

    Classes whose name contains ``'Conv'`` get Xavier-uniform weights with
    gain ``sqrt(2)``; classes whose name contains ``'BatchNorm'`` get unit
    weights. Biases are zeroed when the module has one.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # In-place (underscore) initialisers: the non-underscore names are deprecated.
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        init.constant_(m.weight, 1)
        if m.bias is not None:
            init.constant_(m.bias, 0)


class wide_basic(nn.Module):
    """Pre-activation block: BN-ReLU-conv-dropout-BN-ReLU-conv plus a shortcut.

    The shortcut is the identity unless ``stride != 1`` or the channel count
    changes, in which case it is a 1x1 convolution.
    """

    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )

    def forward(self, x):
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.shortcut(x)

        return out


class Wide_ResNet(nn.Module):
    """Wide-ResNet made of three stages of ``wide_basic`` blocks.

    Args:
        depth: network depth, must be of the form ``6n + 4``.
        widen_factor: channel multiplier ``k`` of the three stages.
        dropout_rate: dropout probability used inside every block.
        num_classes: output size of the linear head.
        use_fc: when false no ``linear`` attribute is created, so only
            ``forward(..., feature=True)`` can be used.
    """

    def __init__(self, depth, widen_factor, dropout_rate, num_classes, use_fc=True):
        super(Wide_ResNet, self).__init__()
        self.in_planes = 16

        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6
        k = widen_factor

        # print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        self.conv1 = conv3x3(3, nStages[0])
        self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        if use_fc:
            self.linear = nn.Linear(nStages[3], num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x, feature=False):
        """Return the pooled, flattened features if ``feature`` is true, else the logits."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)

        if feature:
            return out
        out = self.linear(out)
        return out


def wideres2810(num_classes, use_fc):
    """Constructs a wideres-28-10 model without dropout.
    """
    return Wide_ResNet(28, 10, 0, num_classes, use_fc)
# --------------------------------------------------------------------------------
# /src/optim.py:
# --------------------------------------------------------------------------------
from argparse import Namespace
from typing import Optional

import torch.optim
from torch.optim.lr_scheduler import MultiStepLR, StepLR, CosineAnnealingLR


def get_scheduler(args: Namespace,
                  optimizer: torch.optim.Optimizer) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
    """Build the learning-rate scheduler named by ``args.scheduler``.

    Supported names are ``'step'``, ``'multi_step'``, ``'cosine'`` and
    ``None`` (no scheduler). Only the requested scheduler is instantiated, so
    ``optimizer`` is never attached to schedulers that are then discarded and
    only the options of the selected scheduler are read from ``args``.
    """
    # Factories instead of instances: construction is deferred until selection.
    factories = {
        'step': lambda: StepLR(optimizer, args.lr_stepsize, args.gamma),
        'multi_step': lambda: MultiStepLR(optimizer,
                                          milestones=[int(.5 * args.num_updates),
                                                      int(.75 * args.num_updates)],
                                          gamma=args.gamma),
        'cosine': lambda: CosineAnnealingLR(optimizer, args.num_updates, eta_min=1e-9),
        None: lambda: None,
    }

    assert args.scheduler in factories, f"Unknown scheduler: {args.scheduler!r}"

    return factories[args.scheduler]()


def get_optimizer(args: Namespace,
                  model: torch.nn.Module) -> torch.optim.Optimizer:
    """Build the optimizer named by ``args.optimizer_name`` (``'SGD'`` or ``'Adam'``).

    Only the requested optimizer is instantiated.
    """
    factories = {
        'SGD': lambda: torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                                       weight_decay=args.weight_decay, nesterov=args.nesterov),
        'Adam': lambda: torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay),
    }

    assert args.optimizer_name in factories, f"Unknown optimizer: {args.optimizer_name!r}"

    return factories[args.optimizer_name]()
# --------------------------------------------------------------------------------
# /src/plot.py:
# --------------------------------------------------------------------------------
from .utils import compute_confidence_interval
from pathlib import Path
from itertools import cycle
from collections import defaultdict
from functools import partial
from typing import Any

import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')

colors = ["g", "b", "m", 'y', 'k', 'chartreuse', 'coral', 'gold', 'lavender',
          'silver', 'tan', 'teal', 'wheat', 'orchid', 'orange', 'tomato']

styles = ['--', '-.', ':', '-']


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the plotting script."""
    parser = argparse.ArgumentParser(description='Plot training metrics')
    parser.add_argument('--folder', type=str,
                        help='Folder to search')
    parser.add_argument('--fontsize', type=int, default=12)
    # float so that the declared type agrees with the default value
    parser.add_argument('--linewidth', type=float, default=2.)
    parser.add_argument('--fontfamily', type=str, default='sans-serif')
    parser.add_argument('--fontweight', type=str, default='normal')
    parser.add_argument('--figsize', type=int, nargs=2, default=[10, 10])
    parser.add_argument('--dpi', type=int, default=200,
                        help='Dots per inch when saving the fig')
    parser.add_argument('--max_col', type=int, default=1,
                        help='Maximum number of columns for legend')

    parser.add_argument('--reduce_by', type=str, default='seed')
    parser.add_argument('--simu_params', type=str, nargs='+',
                        default=['method', 'base_source', 'arch'])

    args = parser.parse_args()
    return args


def main(args: argparse.Namespace) -> None:
    """Plot every ``.npy`` metric found under ``args.folder``.

    Each ``.npy`` file is described by the ``config.json`` next to it. Runs
    sharing the same ``args.simu_params`` values are reduced over
    ``args.reduce_by`` with ``compute_confidence_interval`` and drawn as one
    curve. One PNG per metric is written to ``<folder>/plots``.
    """
    plt.rc('font',
           size=args.fontsize,
           family=args.fontfamily,
           weight=args.fontweight)

    # Recover all files that match .npy pattern in folder/
    p = Path(args.folder)
    all_files = list(p.glob('**/*.npy'))

    if not all_files:
        print("No .npy files found in this subtree. Cancelling plotting operations.")
        return

    # Group files by metric name
    filenames_dic = nested_default_dict(4, str)
    for path in all_files:
        root = path.parent
        metric = path.stem
        with open(root / 'config.json') as f:
            config = json.load(f)
        fixed_key = tuple([config[key] for key in args.simu_params])
        reduce_key = config[args.reduce_by]
        filenames_dic[metric][fixed_key][reduce_key]['num_updates'] = config['num_updates']
        filenames_dic[metric][fixed_key][reduce_key]['log_freq'] = config['print_freq'] if 'train' in metric else config['eval_freq']
        filenames_dic[metric][fixed_key][reduce_key]['path'] = path

    # Do one plot per metric
    for metric in filenames_dic:
        fig = plt.Figure(args.figsize)
        ax = fig.gca()
        for style, color, simu_args in zip(cycle(styles), cycle(colors), filenames_dic[metric]):
            values = []
            for _, res_dic in filenames_dic[metric][simu_args].items():
                values.append(np.load(res_dic['path'])[None, :])
            y = np.concatenate(values, axis=0)
            mean, std = compute_confidence_interval(y, axis=0)
            n_epochs = mean.shape[0]
            # The logging frequency of the last run read is used for the x axis.
            x = np.linspace(0, n_epochs - 1, (n_epochs)) * res_dic['log_freq']
            # Entries whose mean is exactly zero are not drawn.
            valid_iter = (mean != 0)

            label = "/".join([f"{k}={v}" for k, v in zip(args.simu_params, simu_args)])
            ax.plot(x[valid_iter], mean[valid_iter], label=label, color=color, linestyle=style)
            ax.fill_between(x[valid_iter], mean[valid_iter] - std[valid_iter],
                            mean[valid_iter] + std[valid_iter], color=color, alpha=0.3)

        n_cols = min(args.max_col, len(filenames_dic[metric]))
        ax.legend(bbox_to_anchor=(0.5, 1.05), loc='center', ncol=n_cols, shadow=True)
        ax.set_xlabel("Training iterations")
        ax.grid(True)
        fig.tight_layout()
        save_path = p / 'plots'
        save_path.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path / f'{metric}.png', dpi=args.dpi)
    print(f"Plots saved at {args.folder}")


def nested_default_dict(depth: int, final_type: Any, i: int = 1):
    """Return a ``defaultdict`` nested ``depth`` levels deep.

    The innermost level uses ``final_type`` as its default factory; ``i`` is
    the current recursion level and is not meant to be passed by callers.
    """
    if i == depth:
        return defaultdict(final_type)
    fn = partial(nested_default_dict, depth=depth, final_type=final_type, i=i+1)
    return defaultdict(fn)


if __name__ == "__main__":
    args = parse_args()
    main(args=args)
# --------------------------------------------------------------------------------
# /src/utils/list_files.py:
# --------------------------------------------------------------------------------
from pathlib import Path
import sys
from typing import List, Union
from os.path import join


def menu_selection(paths: Union[Path, List[Path]]) -> List[Path]:
    """
    Ask the user to select directories from the current folder.

    ``paths`` is either a directory, whose entries are listed, or an explicit
    list of paths. The sorted entries are printed with an index, followed by
    an extra "All" entry. The user answers on stdin with space-separated
    indices. Returns every entry when "All" is chosen, otherwise the selected
    entries in the order typed.
    """
    if not isinstance(paths, list):
        paths = list(paths.iterdir())
    # Sort a copy so that a list supplied by the caller is left untouched.
    paths = sorted(paths)
    print()
    for i, dir_ in enumerate(paths):
        printable = join(*dir_.parts[-2:])
        print(f"{i} - {printable} \n")

    # Index of the "All" entry; taken from the length rather than the loop
    # variable so that an empty listing does not raise NameError.
    all_index = len(paths)
    print(f"{all_index} - All \n")

    print("Please enter corresponding digits (space separated)? \n")
    selected_digits = [int(x) for x in input().split()]
    if all_index in selected_digits:
        return paths
    return [paths[i] for i in selected_digits]


if __name__ == '__main__':

    in_folder = sys.argv[1]
    out_folder = sys.argv[2]
    out_file = sys.argv[3]

    in_path = Path(in_folder)
    if "results" in out_folder:
        print("From which experiment to restore? \n")
        selected_dir = menu_selection(in_path)
        assert len(selected_dir) == 1, "Please select only 1 experiment"
        in_path = Path(selected_dir[0])

    print("Which folders would you like to select? \n")
    # First line of the output file: the selected paths, space separated.
    src_line = ' '.join(str(x) for x in menu_selection(in_path))
    with open(out_file, 'w') as f:
        f.write(src_line + "\n")

    if "archive" in out_folder:
        print("What name for the experiment? \n")
        exp_name = str(input())
        archive_path = Path(out_folder) / exp_name
        archive_path.mkdir(parents=True, exist_ok=True)

        # Second line of the output file: the archive destination.
        with open(out_file, 'a') as f:
            f.write(str(archive_path))
# --------------------------------------------------------------------------------
# /src/utils/produce_result_tables.py:
# --------------------------------------------------------------------------------
import pandas as pd
from markdownTable import markdownTable
from loguru import logger
import argparse
from pathlib import Path
from collections import defaultdict
import numpy as np
from functools import partial
from typing import Dict, List, Any
# file = 'sample.csv'
# df = pd.read_csv(file)
# markdownTable(data).setParams(row_sep = 'markdown', quote = False).getMarkdown()


def main() -> None:
    """Log a markdown table of accuracies (in %) per method and test source.

    Every ``.csv`` file found under ``results/`` is read; each row must
    provide the ``method``, ``test_source`` and ``acc`` columns. When the same
    (method, test source) pair appears several times, the last value read is
    kept.
    """

    # ===== Recover all csv result files =====
    p = Path('results')
    csv_files = list(p.glob('**/*.csv'))  # every csv file represents the summary of an experiment
    csv_files.sort()
    assert len(csv_files)

    # ===== Collect accuracies as result_dict[method][test_source] =====
    result_dict = nested_default_dict(2, list)

    for file in csv_files:
        df = pd.read_csv(file)

        # Read every row and add it to result_dict
        for _, row in df.iterrows():
            result_dict[row['method']][row['test_source']] = np.round(100 * row['acc'], 2)

    # ===== One table record per method, one column per test source =====
    records: List[Dict[str, Any]] = []
    for method in result_dict:
        rec = {'Method': method}
        for dataset, acc in result_dict[method].items():
            rec[dataset] = acc
        records.append(rec)

    mt = markdownTable(records).setParams(row_sep='markdown', quote=False).getMarkdown()
    logger.info(mt)


def nested_default_dict(depth: int, final_type: Any, i: int = 1):
    """Return a ``defaultdict`` nested ``depth`` levels deep.

    The innermost level uses ``final_type`` as its default factory; ``i`` is
    the current recursion level and is not meant to be passed by callers.
    """
    if i == depth:
        return defaultdict(final_type)
    fn = partial(nested_default_dict, depth=depth, final_type=final_type, i=i+1)
    return defaultdict(fn)


if __name__ == "__main__":
    main()
# --------------------------------------------------------------------------------