├── scripts ├── train_backbone.sh └── train_transf.sh ├── CONTRIBUTING.md ├── .gitignore ├── README.md ├── CODE_OF_CONDUCT.md ├── environment.yml ├── noise_utils.py ├── mini_imagenet.py ├── data_utils.py ├── models.py ├── few_shot_utils.py ├── LICENSE.md ├── main.py └── create_miniimagenet_outliers.py /scripts/train_backbone.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | python main.py \ 7 | --output-dir ./output/backbones/ \ 8 | --dataset miniimagenet \ 9 | --data-path ./data/miniimagenet \ 10 | --train-shot 5 \ 11 | --test-shot 5 \ 12 | --name conv/conv_backbone_5W_5K \ 13 | --conv-model conv4 \ 14 | --agg-method mean \ 15 | --random-horizontal-flip \ 16 | --random-resized-crop \ 17 | --color-jitter \ 18 | --warm-up-epochs 100 \ 19 | --step-rate 100 \ 20 | --max-epoch 1000 \ 21 | --test-support-label-noise-list 0.0 0.2 0.4 0.6 0.8 22 | -------------------------------------------------------------------------------- /scripts/train_transf.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | python main.py \ 7 | --dataset miniimagenet \ 8 | --data-path ./data/miniimagenet \ 9 | --output-dir ./output/tranfs/ \ 10 | --name trans3_conv4_5w_5k \ 11 | --conv-model conv4 \ 12 | --transformer-layers 3 \ 13 | --trans-d-model 128 \ 14 | --ortho-proj \ 15 | --cls-type rand_const \ 16 | --output-from-cls \ 17 | --binary-outlier-loss-weight 0.5 \ 18 | --clean-proto-loss-weight 1.0 \ 19 | --agg-method mean \ 20 | --random-horizontal-flip \ 21 | --random-resized-crop \ 22 | --lr 0.0005 \ 23 | --warm-up-epochs 100 \ 24 | --step-rate 250 \ 25 | --step-gamma 0.7 \ 26 | --max-epoch 2000 \ 27 | --noise-type sym_swap \ 28 | --train-support-label-noise-choices 0.2 0.4 \ 29 | --test-support-label-noise-list 0.0 0.2 0.4 0.6 0.8 \ 30 | --load-checkpoint-path ./output/backbones/conv/conv_backbone_5W_5K/checkpoint.pth \ 31 | --freeze-conv 32 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to noisy_few_shot 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 
24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to noisy_few_shot, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | 141 | # Emacs 142 | *~ 143 | 144 | # Symbolic link to data 145 | data 146 | 147 | # Output 148 | output -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Few-shot Learning with Noisy Labels 3 | [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc/4.0/) 4 | 5 | Authors: [Kevin J Liang](https://github.com/kevinjliang), Samrudhdhi B. Rangrej, Vladan Petrovic, Tal Hassner 6 | 7 | This repository is the official PyTorch implementation of the [CVPR 2022](https://cvpr2022.thecvf.com/) paper [Few-shot Learning with Noisy Labels](https://arxiv.org/abs/2204.05494). 
8 | 9 | ### Citation 10 | If you find any part of our paper or this codebase useful, please consider citing our paper: 11 | 12 | ``` 13 | @inproceedings{liang2022few, 14 | title={Few-shot learning with noisy labels}, 15 | author={Liang, Kevin J and Rangrej, Samrudhdhi B and Petrovic, Vladan and Hassner, Tal}, 16 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 17 | pages={9089--9098}, 18 | year={2022} 19 | } 20 | ``` 21 | 22 | ### License 23 | Please see [LICENSE.md](https://github.com/facebookresearch/noisy_few_shot/blob/main/LICENSE.md) for more details. 24 | 25 | ### Acknowledgements 26 | This codebase was built starting from the [learn2learn](http://learn2learn.net/) library for meta-learning, in particular for setting up baseline non-noisy few-shot methods. 27 | 28 | ## Set-up 29 | 30 | ### Environment 31 | For requirements, please see [`environment.yml`](https://github.com/facebookresearch/noisy_few_shot/blob/main/environment.yml). Experiments in the paper were run with Python 3.9.6 and PyTorch 1.9.0; other versions of Python and PyTorch may work, though they are untested. 32 | 33 | ``` 34 | conda env create --file environment.yml 35 | ``` 36 | 37 | ### Datasets 38 | Place datasets in the directory `./data`. 39 | MiniImageNet and TieredImageNet datasets should download automatically when run for the first time if they do not exist already. 40 | If running experiments with ImageNet outlier noise, you will need to generate the outlier dataset from the ImageNet training set. See [`create_miniimagenet_outliers.py`](https://github.com/facebookresearch/noisy_few_shot/blob/main/create_miniimagenet_outliers.py) for how to do so. 41 | 42 | 43 | # Training and Evaluation 44 | The script [`main.py`](https://github.com/facebookresearch/noisy_few_shot/blob/main/main.py) trains a model and then runs evaluation at the specified noise levels. 
Running just evaluation with this script can be done by simply setting the number of train epochs to 0 (`--max-epoch 0`). 45 | 46 | This codebase uses [Visdom](https://github.com/fossasia/visdom) for visualizing training/eval curves. Although not required for the script to run, a Visdom server can be initialized with: 47 | 48 | ``` 49 | visdom 50 | ``` 51 | 52 | ## Training backbone 53 | For all models, we follow the common practice of first pre-training the backbone and then freezing it. An example command for training the conv4 backbone for MiniImageNet: 54 | 55 | ``` 56 | bash scripts/train_backbone.sh 57 | ``` 58 | 59 | This script will train the conv4 backbone and then evaluate it at the specified noise levels and noise type. This represents the performance of ProtoNet at various noise levels. Various other few-shot baselines can be run by loading this backbone and running evaluation on noisy support sets. 60 | 61 | ## Training/Evaluating TraNFS 62 | In our paper, we propose a Transformer for Noisy Few-Shot (TraNFS), which learns a dynamic noise filtering mechanism for noisy support sets. An example command for launching TraNFS training for symmetric label swap noise on MiniImageNet: 63 | 64 | ``` 65 | bash scripts/train_transf.sh 66 | ``` 67 | 68 | After training the model, this script will also evaluate the model at the specified noise levels and noise type. 
69 | 70 | 71 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: nfsl 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=4.5=1_gnu 8 | - argon2-cffi=20.1.0=py39h27cfd23_1 9 | - async_generator=1.10=pyhd3eb1b0_0 10 | - attrs=21.2.0=pyhd3eb1b0_0 11 | - backcall=0.2.0=pyhd3eb1b0_0 12 | - blas=1.0=mkl 13 | - bleach=4.0.0=pyhd3eb1b0_0 14 | - brotli=1.0.9=he6710b0_2 15 | - ca-certificates=2021.7.5=h06a4308_1 16 | - certifi=2021.5.30=py39h06a4308_0 17 | - cffi=1.14.6=py39h400218f_0 18 | - configargparse=1.4=pyhd3eb1b0_0 19 | - cudatoolkit=10.2.89=hfd86e86_1 20 | - cycler=0.10.0=py39h06a4308_0 21 | - dbus=1.13.18=hb2f20db_0 22 | - decorator=5.0.9=pyhd3eb1b0_0 23 | - defusedxml=0.7.1=pyhd3eb1b0_0 24 | - entrypoints=0.3=py39h06a4308_0 25 | - expat=2.4.1=h2531618_2 26 | - fontconfig=2.13.1=h6c09931_0 27 | - fonttools=4.25.0=pyhd3eb1b0_0 28 | - freetype=2.10.4=h5ab3b9f_0 29 | - glib=2.69.0=h5202010_0 30 | - gst-plugins-base=1.14.0=h8213a91_2 31 | - gstreamer=1.14.0=h28cd5cc_2 32 | - icu=58.2=he6710b0_3 33 | - importlib-metadata=3.10.0=py39h06a4308_0 34 | - importlib_metadata=3.10.0=hd3eb1b0_0 35 | - intel-openmp=2021.3.0=h06a4308_3350 36 | - ipykernel=5.3.4=py39hb070fc8_0 37 | - ipython=7.26.0=py39hb070fc8_0 38 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 39 | - ipywidgets=7.6.3=pyhd3eb1b0_1 40 | - jedi=0.18.0=py39h06a4308_1 41 | - jinja2=3.0.1=pyhd3eb1b0_0 42 | - jpeg=9b=h024ee3a_2 43 | - jsonschema=3.2.0=py_2 44 
| - jupyter=1.0.0=py39h06a4308_7 45 | - jupyter_client=6.1.12=pyhd3eb1b0_0 46 | - jupyter_console=6.4.0=pyhd3eb1b0_0 47 | - jupyter_core=4.7.1=py39h06a4308_0 48 | - jupyterlab_pygments=0.1.2=py_0 49 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 50 | - kiwisolver=1.3.1=py39h2531618_0 51 | - lcms2=2.12=h3be6417_0 52 | - ld_impl_linux-64=2.35.1=h7274673_9 53 | - libffi=3.3=he6710b0_2 54 | - libgcc-ng=9.3.0=h5101ec6_17 55 | - libgfortran-ng=7.5.0=ha8ba4b0_17 56 | - libgfortran4=7.5.0=ha8ba4b0_17 57 | - libgomp=9.3.0=h5101ec6_17 58 | - libpng=1.6.37=hbc83047_0 59 | - libsodium=1.0.18=h7b6447c_0 60 | - libstdcxx-ng=9.3.0=hd4cf53a_17 61 | - libtiff=4.2.0=h85742a9_0 62 | - libuuid=1.0.3=h1bed415_2 63 | - libuv=1.40.0=h7b6447c_0 64 | - libwebp-base=1.2.0=h27cfd23_0 65 | - libxcb=1.14=h7b6447c_0 66 | - libxml2=2.9.12=h03d6c58_0 67 | - lz4-c=1.9.3=h295c915_1 68 | - markupsafe=2.0.1=py39h27cfd23_0 69 | - matplotlib=3.4.2=py39h06a4308_0 70 | - matplotlib-base=3.4.2=py39hab158f2_0 71 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 72 | - mistune=0.8.4=py39h27cfd23_1000 73 | - mkl=2021.3.0=h06a4308_520 74 | - mkl-service=2.4.0=py39h7f8727e_0 75 | - mkl_fft=1.3.0=py39h42c9631_2 76 | - mkl_random=1.2.2=py39h51133e4_0 77 | - munkres=1.1.4=py_0 78 | - nbclient=0.5.3=pyhd3eb1b0_0 79 | - nbconvert=6.1.0=py39h06a4308_0 80 | - nbformat=5.1.3=pyhd3eb1b0_0 81 | - ncurses=6.2=he6710b0_1 82 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 83 | - ninja=1.10.2=hff7bd54_1 84 | - notebook=6.4.1=py39h06a4308_0 85 | - numpy=1.20.3=py39hf144106_0 86 | - numpy-base=1.20.3=py39h74d4b33_0 87 | - olefile=0.46=py_0 88 | - openjpeg=2.3.0=h05c96fa_1 89 | - openssl=1.1.1k=h27cfd23_0 90 | - packaging=21.0=pyhd3eb1b0_0 91 | - pandocfilters=1.4.3=py39h06a4308_1 92 | - parso=0.8.2=pyhd3eb1b0_0 93 | - path=16.0.0=py39h06a4308_0 94 | - pcre=8.45=h295c915_0 95 | - pexpect=4.8.0=pyhd3eb1b0_3 96 | - pickleshare=0.7.5=pyhd3eb1b0_1003 97 | - pip=21.2.2=py39h06a4308_0 98 | - prometheus_client=0.11.0=pyhd3eb1b0_0 99 | - 
prompt-toolkit=3.0.17=pyh06a4308_0 100 | - prompt_toolkit=3.0.17=hd3eb1b0_0 101 | - ptyprocess=0.7.0=pyhd3eb1b0_2 102 | - pycparser=2.20=py_2 103 | - pygments=2.9.0=pyhd3eb1b0_0 104 | - pyparsing=2.4.7=pyhd3eb1b0_0 105 | - pyqt=5.9.2=py39h2531618_6 106 | - pyrsistent=0.18.0=py39h7f8727e_0 107 | - python=3.9.6=h12debd9_0 108 | - python-dateutil=2.8.2=pyhd3eb1b0_0 109 | - pytorch=1.9.0=py3.9_cuda10.2_cudnn7.6.5_0 110 | - pyzmq=20.0.0=py39h2531618_1 111 | - qt=5.9.7=h5867ecd_1 112 | - qtconsole=5.1.0=pyhd3eb1b0_0 113 | - qtpy=1.9.0=py_0 114 | - readline=8.1=h27cfd23_0 115 | - scipy=1.6.2=py39had2a1c9_1 116 | - send2trash=1.5.0=pyhd3eb1b0_1 117 | - setuptools=52.0.0=py39h06a4308_0 118 | - sip=4.19.13=py39h2531618_0 119 | - six=1.16.0=pyhd3eb1b0_0 120 | - sqlite=3.36.0=hc218d9a_0 121 | - terminado=0.9.4=py39h06a4308_0 122 | - testpath=0.5.0=pyhd3eb1b0_0 123 | - tk=8.6.10=hbc83047_0 124 | - tornado=6.1=py39h27cfd23_0 125 | - tqdm=4.62.0=pyhd3eb1b0_1 126 | - traitlets=5.0.5=pyhd3eb1b0_0 127 | - typing_extensions=3.10.0.0=pyh06a4308_0 128 | - tzdata=2021a=h52ac0ba_0 129 | - wcwidth=0.2.5=py_0 130 | - webencodings=0.5.1=py39h06a4308_1 131 | - wheel=0.36.2=pyhd3eb1b0_0 132 | - widgetsnbextension=3.5.1=py39h06a4308_0 133 | - xz=5.2.5=h7b6447c_0 134 | - zeromq=4.3.4=h2531618_0 135 | - zipp=3.5.0=pyhd3eb1b0_0 136 | - zlib=1.2.11=h7b6447c_3 137 | - zstd=1.4.9=haebb681_0 138 | - pip: 139 | - argcomplete==1.12.3 140 | - boto==2.49.0 141 | - charset-normalizer==2.0.4 142 | - cloudpickle==1.6.0 143 | - crcmod==1.7 144 | - cryptography==3.4.7 145 | - fasteners==0.16.3 146 | - gcs-oauth2-boto-plugin==2.7 147 | - google-apitools==0.5.32 148 | - google-reauth==0.1.1 149 | - gsutil==4.66 150 | - gym==0.18.3 151 | - httplib2==0.19.1 152 | - idna==3.2 153 | - jsonpatch==1.32 154 | - jsonpointer==2.1 155 | - learn2learn==0.1.5 156 | - mock==2.0.0 157 | - monotonic==1.6 158 | - oauth2client==4.1.3 159 | - pandas==1.3.1 160 | - pbr==5.6.0 161 | - pillow==8.2.0 162 | - pyasn1==0.4.8 163 | - 
import numpy as np
import torch


# Label assigned to support samples that were replaced by outlier images.
OUTLIER = -1


def gen_derangement(n):
    """
    Generates a derangement (random permutation without any fixed points).

    Args:
        n: Number of elements to permute.

    Returns:
        1-D integer array holding a permutation of ``0..n-1`` in which no
        value ``i`` sits at position ``i``. ``n == 0`` yields an empty array.

    Raises:
        ValueError: If ``n == 1``; a single element cannot be deranged, and
            rejection sampling would otherwise never terminate.
    """
    if n == 1:
        raise ValueError("a derangement of a single element does not exist")

    in_order = np.arange(n)
    # Rejection sampling: redraw until no element stays in its own position
    while True:
        derangement_candidate = np.random.permutation(n)
        if not np.any(derangement_candidate == in_order):
            return derangement_candidate


def gen_swap_indices(ways, num_noise_samples, num_clean_shots, indices_to_change):
    """
    Makes one randomized attempt at choosing replacement samples for every class.

    The noise pool is treated as ``ways`` consecutive groups of
    ``num_noise_samples // ways`` samples, one group per class. Each class
    draws, without replacement across the whole pool, as many samples as it
    has entries in ``indices_to_change``; it never draws from its own group,
    and a foreign group is closed to further draws once it has contributed
    enough that one more sample would tie the clean samples.

    Args:
        ways: Number of classes.
        num_noise_samples: Total number of samples in the noise pool.
        num_clean_shots: Number of samples per class that stay clean.
        indices_to_change: Per-class collections of positions to be replaced;
            only the length of each row is used here.

    Returns:
        Integer array of shape ``(ways, noise_num)`` with indices into the
        noise pool (row ``c`` belongs to class ``c``), or ``None`` if some
        class ran out of candidates during this attempt.
    """
    shot = num_noise_samples // ways

    # Track unpicked options
    available = np.arange(num_noise_samples)
    available_KN = available.reshape(ways, shot)

    swap_indices = {}

    # permute for loop to avoid bias of class 0 picking first
    class_order = np.arange(ways)
    np.random.shuffle(class_order)
    for c in class_order:
        # Remove samples from class c from options
        available_c = [i for i in available if i not in available_KN[c]]

        # Track number of noise samples from each class
        noise_class_count = np.zeros(ways)

        swap_indices_c = []

        for _ in range(len(indices_to_change[c])):
            # Dead end: signal the caller to retry with a fresh random order
            if len(available_c) == 0:
                return None

            # Pick random sample idx from available choices; remove from available options
            noise_choice_i = np.random.choice(available_c)
            swap_indices_c.append(noise_choice_i)
            available_c.remove(noise_choice_i)

            # Find class of picked sample and increment class counter
            noise_choice_class_i = noise_choice_i // shot
            noise_class_count[noise_choice_class_i] += 1

            # Limit noise samples per class to ensure clean class has plurality
            # If class has 1 less than the number of clean class, remove other samples from choices
            if (noise_class_count[noise_choice_class_i] + 1) >= num_clean_shots:
                available_c = [
                    i
                    for i in available_c
                    if i not in available_KN[noise_choice_class_i]
                ]

        swap_indices[c] = swap_indices_c

        # Update unpicked options
        available = np.setdiff1d(available, swap_indices[c])

    # Put swap indices back in order; force an integer dtype so the result is
    # a valid index array even when no samples were requested (empty rows
    # would otherwise stack to float64)
    swap_indices = np.vstack([swap_indices[c] for c in range(ways)]).astype(np.int64)

    return swap_indices


def gen_valid_swap_indices(ways, shot, indices_to_change):
    """
    Repeats ``gen_swap_indices`` until an attempt succeeds.

    Args:
        ways: Number of classes.
        shot: Number of samples per class.
        indices_to_change: Per-class collections of positions to be replaced;
            the length of the first row is taken as the per-class noise count.

    Returns:
        Integer array of shape ``(ways, noise_num)`` with indices into the
        noise pool.

    Raises:
        ValueError: If the request can never be satisfied, which previously
            made this function loop forever (e.g. 100% symmetric noise).
    """
    num_noise_samples = ways * shot
    noise_num = len(indices_to_change[0])
    num_clean_shots = shot - noise_num

    # gen_swap_indices lets every foreign class contribute at least one sample
    # and at most (num_clean_shots - 1); a class can never draw from itself.
    # If even that upper bound is below the demand, no attempt can succeed.
    per_class_cap = min(shot, max(1, num_clean_shots - 1))
    if noise_num > (ways - 1) * per_class_cap:
        raise ValueError(
            f"cannot draw {noise_num} swap samples per class from {ways - 1} "
            f"other classes with {num_clean_shots} clean shots per class"
        )

    while True:
        swap_indices = gen_swap_indices(
            ways, num_noise_samples, num_clean_shots, indices_to_change
        )

        if swap_indices is not None:
            break

    return swap_indices


def add_noise(
    data,
    labels,
    mask_indices,
    ways,
    noise_fraction,
    noise_type="sym_swap",
    outlier_data=None,
):
    """
    Adds either label swap noise (symmetric or paired) or outlier noise

    Args:
        data: Tensor of samples, first dimension indexes samples.
        labels: Tensor of labels aligned with ``data``.
        mask_indices: Mapping with entries ``"support"`` and ``"noise"``, each
            an index (or mask) over the first dimension of ``data``. The code
            assumes both selections are grouped by class with an equal number
            of samples per class -- TODO confirm against the caller.
        ways: Number of classes in the episode.
        noise_fraction: Fraction of each class's support samples to corrupt;
            the per-class count is ``round(noise_fraction * shot)``.
        noise_type: ``"sym_swap"``, ``"pair_swap"`` or ``"outlier"``.
        outlier_data: Tensor of outlier samples; required for ``"outlier"``.

    Returns:
        Tuple ``(noised_data, noised_labels, noise_positions)``: clones of
        ``data`` and ``labels`` with the selected support positions
        overwritten, and a float array of length ``data.shape[0]`` holding 1
        at every overwritten position and 0 elsewhere.

    Raises:
        ValueError: If ``noise_type="outlier"`` but ``outlier_data`` is None,
            or if symmetric swapping is infeasible for the requested fraction.
        NotImplementedError: For an unknown ``noise_type``.
    """
    # Calculate number of noisy samples
    # (int64 rather than int16 so large batches cannot overflow the indices)
    num_idx = np.arange(data.shape[0], dtype=np.int64)[mask_indices["support"]]
    shot = int(len(num_idx) / ways)
    noise_num = int(round(noise_fraction * shot))

    # Select the indices of samples to noise
    indices_to_change = np.empty((0, noise_num), dtype=np.int64)
    for i in range(ways):
        class_i_idx = num_idx[i * shot : (i + 1) * shot]
        indices_to_change_i = np.random.choice(class_i_idx, noise_num, replace=False)
        indices_to_change = np.vstack((indices_to_change, indices_to_change_i))
    indices_to_change_flat = indices_to_change.flatten()

    # Make copy of data and labels
    noised_data = torch.clone(data)
    noised_labels = torch.clone(labels)

    # Replace clean data with noisy data at selected positions
    if noise_type == "sym_swap":
        # Swap data positions of selected samples, ensuring plurality of clean class
        swap_indices = gen_valid_swap_indices(ways, shot, indices_to_change)
        swap_indices_flat = swap_indices.flatten()
        noised_data[indices_to_change_flat] = data[mask_indices["noise"]][
            swap_indices_flat
        ]
        noised_labels[indices_to_change_flat] = labels[mask_indices["noise"]][
            swap_indices_flat
        ]

    elif noise_type == "pair_swap":
        # Randomly select class pairs for noise
        drng = gen_derangement(ways)

        # Reorganize data and labels by class (works for any per-sample shape)
        noisy_data = data[mask_indices["noise"]].reshape(ways, shot, *data.shape[1:])
        noisy_labels = labels[mask_indices["noise"]].reshape(ways, shot)

        # Swap in noisy data from noisy samples
        for c in range(ways):
            noised_data[indices_to_change[c]] = noisy_data[drng[c]][:noise_num]
            noised_labels[indices_to_change[c]] = noisy_labels[drng[c]][:noise_num]

    elif noise_type == "outlier":
        if outlier_data is None:
            raise ValueError("outlier_data is required when noise_type='outlier'")

        # Randomly select outlier indices
        # Draw new indices rather than reuse indices_to_change so outlier classes don't line up
        outlier_select_indices = np.random.choice(
            len(outlier_data), len(indices_to_change_flat), replace=False
        )
        noised_data[indices_to_change_flat] = outlier_data[outlier_select_indices]
        noised_labels[indices_to_change_flat] = OUTLIER

    else:
        raise NotImplementedError(f"unknown noise_type: {noise_type!r}")

    # Location of noised data
    noise_positions = np.zeros(data.shape[0])
    noise_positions[indices_to_change_flat] = 1

    return noised_data, noised_labels, noise_positions
51 | The dataset is divided in 3 splits of 64 training, 16 validation, and 20 testing classes each containing 600 examples. 52 | The classes are sampled from the ImageNet dataset, and we use the splits from Ravi & Larochelle, 2017. 53 | 54 | **References** 55 | 56 | 1. Vinyals et al. 2016. “Matching Networks for One Shot Learning.” NeurIPS. 57 | 2. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR. 58 | 59 | **Arguments** 60 | 61 | * **root** (str) - Path to download the data. 62 | * **mode** (str, *optional*, default='train') - Which split to use. 63 | Must be 'train', 'validation', or 'test'. 64 | * **transform** (Transform, *optional*, default=None) - Input pre-processing. 65 | * **target_transform** (Transform, *optional*, default=None) - Target pre-processing. 66 | 67 | **Example** 68 | 69 | ~~~python 70 | train_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='train') 71 | train_dataset = l2l.data.MetaDataset(train_dataset) 72 | train_generator = l2l.data.TaskGenerator(dataset=train_dataset, ways=ways) 73 | ~~~ 74 | 75 | """ 76 | 77 | def __init__( 78 | self, root, mode="train", transform=None, target_transform=None, download=False 79 | ): 80 | super(MiniImagenet, self).__init__() 81 | self.root = os.path.expanduser(root) 82 | if not os.path.exists(self.root): 83 | os.mkdir(self.root) 84 | self.transform = transform 85 | self.target_transform = target_transform 86 | self.mode = mode 87 | self._bookkeeping_path = os.path.join( 88 | self.root, "mini-imagenet-bookkeeping-" + mode + ".pkl" 89 | ) 90 | if self.mode == "test": 91 | google_drive_file_id = "1wpmY-hmiJUUlRBkO9ZDCXAcIpHEFdOhD" 92 | dropbox_file_link = "https://www.dropbox.com/s/ye9jeb5tyz0x01b/mini-imagenet-cache-test.pkl?dl=1" 93 | elif self.mode == "train": 94 | google_drive_file_id = "1I3itTXpXxGV68olxM5roceUMG8itH9Xj" 95 | dropbox_file_link = "https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1" 96 | elif self.mode == 
"validation": 97 | google_drive_file_id = "1KY5e491bkLFqJDp0-UWou3463Mo8AOco" 98 | dropbox_file_link = "https://www.dropbox.com/s/ip1b7se3gij3r1b/mini-imagenet-cache-validation.pkl?dl=1" 99 | else: 100 | raise ("ValueError", "Needs to be train, test or validation") 101 | 102 | pickle_file = os.path.join(self.root, "mini-imagenet-cache-" + mode + ".pkl") 103 | try: 104 | if not self._check_exists() and download: 105 | print("Downloading mini-ImageNet --", mode) 106 | download_pkl(google_drive_file_id, self.root, mode) 107 | with open(pickle_file, "rb") as f: 108 | self.data = pickle.load(f) 109 | except pickle.UnpicklingError: 110 | if not self._check_exists() and download: 111 | print("Download failed. Re-trying mini-ImageNet --", mode) 112 | download_file(dropbox_file_link, pickle_file) 113 | with open(pickle_file, "rb") as f: 114 | self.data = pickle.load(f) 115 | 116 | # self.x = torch.from_numpy(self.data["image_data"]).permute(0, 3, 1, 2).float() 117 | self.x = self.data["image_data"] 118 | self.y = np.ones(len(self.x)) 119 | 120 | # TODO Remove index_classes from here 121 | self.class_idx = index_classes(self.data["class_dict"].keys()) 122 | for class_name, idxs in self.data["class_dict"].items(): 123 | for idx in idxs: 124 | self.y[idx] = self.class_idx[class_name] 125 | 126 | def __getitem__(self, idx): 127 | data = self.x[idx] 128 | if self.transform: 129 | data = self.transform(data) 130 | return data, self.y[idx] 131 | 132 | def __len__(self): 133 | return len(self.x) 134 | 135 | def _check_exists(self): 136 | return os.path.exists( 137 | os.path.join(self.root, "mini-imagenet-cache-" + self.mode + ".pkl") 138 | ) 139 | 140 | 141 | class MiniImagenetOutlier(data.Dataset): 142 | def __init__(self, root, mode="train", transform=None): 143 | super(MiniImagenetOutlier, self).__init__() 144 | self.root = os.path.expanduser(root) 145 | self.transform = transform 146 | 147 | pickle_file = os.path.join( 148 | self.root, "mini-imagenet-outlier-cache-" + mode + 
".pkl" 149 | ) 150 | with open(pickle_file, "rb") as f: 151 | self.data = pickle.load(f) 152 | 153 | # self.x = torch.from_numpy(self.data["image_data"]).permute(0, 3, 1, 2).float() 154 | self.x = self.data["image_data"] 155 | self.y = np.ones(len(self.x)) * -1 156 | 157 | def __getitem__(self, idx): 158 | data = self.x[idx] 159 | if self.transform: 160 | data = self.transform(data) 161 | return data, self.y[idx] 162 | 163 | def __len__(self): 164 | return len(self.x) 165 | 166 | 167 | if __name__ == "__main__": 168 | mi = MiniImagenet(root="./data", download=True) 169 | __import__("pdb").set_trace() 170 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
def get_augmentation_transforms(args):
    """Assemble the augmentation pipeline selected by the flags on ``args``.

    Reads ``args.dataset``, ``args.random_horizontal_flip``,
    ``args.random_resized_crop``, ``args.color_jitter`` and
    ``args.random_erasing``. The steps are appended in a fixed order and
    wrapped in a single ``transforms.Compose``.

    Raises:
        NotImplementedError: if ``args.dataset`` does not contain "imagenet".
    """
    is_imagenet = "imagenet" in args.dataset
    steps = []

    # PIL-based augmentations follow, so miniimagenet samples are converted first.
    if args.dataset == "miniimagenet":
        steps += [transforms.ToPILImage()]

    if args.random_horizontal_flip:
        steps += [transforms.RandomHorizontalFlip(p=0.5)]

    if args.random_resized_crop:
        if not is_imagenet:
            raise NotImplementedError
        steps += [transforms.RandomResizedCrop(size=IMAGENET_SIZE, scale=(0.8, 1.0))]

    if args.color_jitter:
        jitter = transforms.ColorJitter(
            brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
        )
        steps += [jitter]

    steps += [transforms.ToTensor()]

    # Normalisation statistics are only defined for the ImageNet-derived datasets.
    if not is_imagenet:
        raise NotImplementedError
    steps += [transforms.Normalize(IMAGENET_MEANS, IMAGENET_STDS)]

    # Random erasing comes after ToTensor/Normalize in the pipeline.
    if args.random_erasing:
        steps += [transforms.RandomErasing(value="random")]

    return transforms.Compose(steps)
root=args.data_path, 78 | mode="validation", 79 | transform=transforms.Compose([transforms.ToTensor(), normalize_transform]), 80 | ) 81 | test_dataset = MiniImagenet( 82 | root=args.data_path, 83 | mode="test", 84 | transform=transforms.Compose([transforms.ToTensor(), normalize_transform]), 85 | ) 86 | elif args.dataset == "tieredimagenet": 87 | train_dataset = TieredImagenet( 88 | root=args.data_path, 89 | mode="train", 90 | transform=get_augmentation_transforms(args), 91 | ) 92 | valid_dataset = TieredImagenet( 93 | root=args.data_path, 94 | mode="validation", 95 | transform=transforms.Compose([transforms.ToTensor(), normalize_transform]), 96 | ) 97 | test_dataset = TieredImagenet( 98 | root=args.data_path, 99 | mode="test", 100 | transform=transforms.Compose([transforms.ToTensor(), normalize_transform]), 101 | ) 102 | else: 103 | raise NotImplementedError 104 | 105 | if args.noise_type == "outlier" and "imagenet" in args.dataset: 106 | outlier_train_dataset = MiniImagenetOutlier( 107 | root=os.path.join(args.data_path, "../miniimagenet_outlier"), 108 | mode="train", 109 | transform=transforms.Compose([transforms.ToTensor(), normalize_transform]), 110 | ) 111 | outlier_test_dataset = MiniImagenetOutlier( 112 | root=os.path.join(args.data_path, "../miniimagenet_outlier"), 113 | mode="test", 114 | transform=transforms.Compose([transforms.ToTensor(), normalize_transform]), 115 | ) 116 | 117 | train_dataset = l2l.data.MetaDataset(train_dataset) 118 | train_transforms = [ 119 | NWays(train_dataset, args.train_way), 120 | KShots(train_dataset, args.train_query + 2 * args.train_shot), 121 | LoadData(train_dataset), 122 | RemapLabels(train_dataset), 123 | ] 124 | train_tasks = l2l.data.TaskDataset(train_dataset, task_transforms=train_transforms) 125 | train_loader = DataLoader(train_tasks, pin_memory=True, shuffle=True) 126 | 127 | valid_dataset = l2l.data.MetaDataset(valid_dataset) 128 | valid_transforms = [ 129 | NWays(valid_dataset, args.test_way), 130 | 
KShots(valid_dataset, args.test_query + 2 * args.test_shot), 131 | LoadData(valid_dataset), 132 | RemapLabels(valid_dataset), 133 | ] 134 | valid_tasks = l2l.data.TaskDataset( 135 | valid_dataset, task_transforms=valid_transforms, num_tasks=200 136 | ) 137 | valid_loader = DataLoader(valid_tasks, pin_memory=True, shuffle=True) 138 | 139 | test_dataset = l2l.data.MetaDataset(test_dataset) 140 | test_transforms = [ 141 | NWays(test_dataset, args.test_way), 142 | KShots(test_dataset, args.test_query + 2 * args.test_shot), 143 | LoadData(test_dataset), 144 | RemapLabels(test_dataset), 145 | ] 146 | test_tasks = l2l.data.TaskDataset( 147 | test_dataset, task_transforms=test_transforms, num_tasks=args.test_tasks 148 | ) 149 | test_loader = DataLoader(test_tasks, pin_memory=True, shuffle=True) 150 | 151 | if args.noise_type == "outlier": 152 | outlier_train_loader = DataLoader( 153 | outlier_train_dataset, 154 | batch_size=(args.train_way * args.train_shot), 155 | pin_memory=True, 156 | shuffle=True, 157 | ) 158 | outlier_test_loader = DataLoader( 159 | outlier_test_dataset, 160 | batch_size=(args.test_way * args.test_shot), 161 | pin_memory=True, 162 | shuffle=True, 163 | ) 164 | else: 165 | outlier_train_loader = None 166 | outlier_test_loader = None 167 | 168 | return { 169 | "train": train_loader, 170 | "valid": valid_loader, 171 | "test": test_loader, 172 | "outlier_train": outlier_train_loader, 173 | "outlier_test": outlier_test_loader, 174 | } 175 | 176 | 177 | ##################### 178 | # Data Pre-processing 179 | ##################### 180 | 181 | 182 | def preprocess_data_labels(batch, device): 183 | data, labels = batch 184 | data = data.to(device) 185 | labels = labels.to(device) 186 | 187 | # Sort data samples by labels 188 | # TODO: Can this be replaced by ConsecutiveLabels ? 
def get_support_noise_query_indices(ways, shot, query_num):
    """Build boolean masks that split a class-sorted episode into three roles.

    Each of the ``ways`` classes occupies a contiguous run of
    ``2 * shot + query_num`` positions: the first ``shot`` are support
    samples, the next ``shot`` are the noise pool, and the rest are queries.

    Returns:
        dict with keys "support", "noise" and "query", each a boolean torch
        tensor of length ``ways * (2 * shot + query_num)``.
    """
    per_class = 2 * shot + query_num
    total = ways * per_class

    support_mask = np.zeros(total, dtype=bool)
    noise_mask = np.zeros(total, dtype=bool)

    # Flat position of every (class, offset) support slot, computed in one go.
    class_starts = np.arange(ways) * per_class
    support_slots = (class_starts[:, None] + np.arange(shot)[None, :]).ravel()

    support_mask[support_slots] = True
    # The noise pool sits immediately after the support slots of each class.
    noise_mask[support_slots + shot] = True

    # Everything not claimed by support or noise is a query sample.
    query_mask = ~(support_mask | noise_mask)

    masks = {}
    masks["support"] = torch.from_numpy(support_mask)
    masks["noise"] = torch.from_numpy(noise_mask)
    masks["query"] = torch.from_numpy(query_mask)
    return masks
class Transformer(nn.Module):
    """Transformer encoder over one episode's support embeddings.

    The input sequence is ``ways`` class tokens followed by the
    ``ways * shot`` support embeddings. A per-class position embedding is
    added to every element before the sequence is encoded.

    Args:
        ways: dict with "train" and "test" entries giving the number of classes.
        shot: dict with "train" and "test" entries giving support samples per class.
        num_layers: number of stacked encoder layers.
        nhead: number of attention heads per layer.
        d_model: width of the input embeddings and of the cls/pos tokens.
        dim_feedforward: hidden width of each layer's feed-forward network.
        device: stored as ``self.device``.
        cls_type: "cls_learn" (trainable tokens), "rand_const" (frozen random
            tokens) or "proto" (tokens are prototypes of the support embeddings).
        pos_type: "pos_learn" (trainable) or "rand_const" (frozen random).
        agg_method: aggregation passed to ``gen_prototypes`` when
            ``cls_type == "proto"``.
        transformer_metric: accepted for interface compatibility; not used here.
    """

    def __init__(
        self,
        ways,
        shot,
        num_layers,
        nhead,
        d_model,
        dim_feedforward,
        device,
        cls_type="cls_learn",
        pos_type="pos_learn",
        agg_method="mean",
        transformer_metric="dot_prod",
    ):
        super().__init__()
        self.ways = ways
        self.shot = shot

        self.cls_type = cls_type
        self.pos_type = pos_type
        self.agg_method = agg_method

        # One token per class id, sized for the larger of the two regimes.
        num_classes = max(ways["train"], ways["test"])

        # The tables must be d_model wide because they are concatenated with
        # and added to the d_model-wide inputs. They used to be sized with
        # dim_feedforward, which only worked when the two were equal.
        if self.cls_type == "cls_learn":
            self.cls_embeddings = nn.Embedding(num_classes, d_model)
        elif self.cls_type == "rand_const":
            self.cls_embeddings = nn.Embedding(num_classes, d_model).requires_grad_(
                False
            )

        if self.pos_type == "pos_learn":
            self.pos_embeddings = nn.Embedding(num_classes, d_model)
        elif self.pos_type == "rand_const":
            self.pos_embeddings = nn.Embedding(num_classes, d_model).requires_grad_(
                False
            )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout=0.1,
            activation="relu",
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.device = device

    def forward(self, x):
        """Encode the support embeddings ``x`` together with the class tokens.

        Args:
            x: support embeddings, ``(ways * shot, d_model)``, grouped by class.

        Returns:
            Encoder output of shape ``(ways * (shot + 1), 1, d_model)``; the
            first ``ways`` rows correspond to the class tokens.

        Raises:
            NotImplementedError: for an unsupported ``cls_type`` or ``pos_type``.
        """
        ways = self.ways["train"] if self.training else self.ways["test"]
        shot = self.shot["train"] if self.training else self.shot["test"]

        # Build indices on the input's device: the lookup tables and x must
        # share a device for the concatenation below to work anyway.
        n_arng = torch.arange(ways, device=x.device)

        # Concatenate cls tokens with support embeddings
        if self.cls_type in ["cls_learn", "rand_const"]:
            cls_tokens = self.cls_embeddings(n_arng)  # (ways, dim)
        elif self.cls_type == "proto":
            cls_tokens = gen_prototypes(x, ways, shot, self.agg_method)  # (ways, dim)
        else:
            raise NotImplementedError

        # Mirror the cls_type handling instead of failing with an
        # AttributeError on the missing pos_embeddings table.
        if self.pos_type not in ["pos_learn", "rand_const"]:
            raise NotImplementedError

        cls_sup_embeds = torch.cat((cls_tokens, x), dim=0)  # (ways*(shot+1), dim)
        cls_sup_embeds = torch.unsqueeze(
            cls_sup_embeds, dim=1
        )  # (ways*(shot+1), BS, dim)

        # Position embeddings based on class ID
        pos_idx = torch.cat((n_arng, torch.repeat_interleave(n_arng, shot)))
        pos_tokens = torch.unsqueeze(
            self.pos_embeddings(pos_idx), dim=1
        )  # (ways*(shot+1), BS, dim)

        # Inputs combined with position encoding
        transformer_input = cls_sup_embeds + pos_tokens

        return self.encoder(transformer_input)
self.convnet.conv_out_dims 151 | ) 152 | 153 | self.ways = ways 154 | self.shot = shot 155 | self.agg_method = agg_method 156 | 157 | def forward(self, x, support_indices): 158 | ways = self.ways["train"] if self.training else self.ways["test"] 159 | shot = self.shot["train"] if self.training else self.shot["test"] 160 | 161 | outputs = {} 162 | 163 | # Embed inputs with conv 164 | embeddings = self.convnet(x) 165 | outputs["conv_embeddings"] = embeddings 166 | 167 | # Calculate prototypes 168 | support_embeddings = embeddings[support_indices] 169 | outputs["support_prototypes"] = gen_prototypes( 170 | support_embeddings, ways, shot, self.agg_method 171 | ) 172 | 173 | if self.binary_outlier_detection: 174 | outputs["outlier_logits"] = self.binary_outlier_detector(support_embeddings) 175 | 176 | return outputs 177 | 178 | 179 | class ProtoConvTransformer(nn.Module): 180 | def __init__( 181 | self, 182 | ways, 183 | shot, 184 | device, 185 | conv_model="conv4", 186 | conv_proj_dim=None, 187 | trans_layers=1, 188 | trans_d_model=None, 189 | nhead=8, 190 | ortho_proj=False, 191 | ortho_proj_residual=False, 192 | cls_type="cls_learn", 193 | pos_type="pos_learn", 194 | agg_method="mean", 195 | transformer_metric="dot_prod", 196 | output_from_cls=True, 197 | binary_outlier_detection=False, 198 | ): 199 | super().__init__() 200 | 201 | self.convnet = Convnet(conv_model, conv_proj_dim) 202 | if trans_d_model is None: 203 | trans_d_model = self.convnet.conv_out_dims 204 | self.transformer = Transformer( 205 | ways, 206 | shot, 207 | trans_layers, 208 | nhead, 209 | trans_d_model, 210 | trans_d_model, 211 | device, 212 | cls_type, 213 | pos_type, 214 | transformer_metric=transformer_metric, 215 | ) 216 | 217 | self.ortho_proj = ortho_proj 218 | if self.ortho_proj: 219 | self.proj_trans_in = nn.Parameter( 220 | torch.nn.init.orthogonal_( 221 | torch.empty(self.convnet.conv_out_dims, trans_d_model) 222 | ) 223 | ) 224 | self.proj_trans_out = 
nn.Parameter(self.proj_trans_in.data.detach()) 225 | self.ortho_proj_residual = ortho_proj_residual 226 | 227 | self.out_dim = self.convnet.conv_out_dims if self.ortho_proj else trans_d_model 228 | 229 | self.binary_outlier_detection = binary_outlier_detection 230 | if self.binary_outlier_detection: 231 | self.binary_outlier_detector = BinaryOutlierDetector(self.out_dim) 232 | 233 | self.ways = ways 234 | self.shot = shot 235 | 236 | self.agg_method = agg_method 237 | self.output_from_cls = output_from_cls 238 | 239 | self.device = device 240 | 241 | def forward(self, x, support_indices): 242 | ways = self.ways["train"] if self.training else self.ways["test"] 243 | shot = self.shot["train"] if self.training else self.shot["test"] 244 | 245 | outputs = {} 246 | 247 | # Embed inputs with conv 248 | embeddings = self.convnet(x) 249 | outputs["conv_embeddings"] = embeddings 250 | 251 | # Process embeddings as a sequence input to the transformer, projecting before and after if necessary 252 | transformer_input = ( 253 | embeddings[support_indices] @ self.proj_trans_in 254 | if self.ortho_proj 255 | else embeddings[support_indices] 256 | ) 257 | transformer_output = self.transformer(transformer_input).squeeze( 258 | 1 259 | ) # squeeze out batch size dim 260 | if self.ortho_proj: 261 | transformer_output = transformer_output @ self.proj_trans_out.T 262 | 263 | if self.ortho_proj_residual: 264 | conv_support_prototypes = gen_prototypes( 265 | embeddings[support_indices], ways, shot, self.agg_method 266 | ) 267 | input_skip = torch.cat( 268 | [conv_support_prototypes, embeddings[support_indices]], dim=0 269 | ) 270 | transformer_output = input_skip + transformer_output 271 | outputs["trans_output"] = transformer_output 272 | 273 | if self.output_from_cls: 274 | # first #ways outputs in seq correspond to the cls tokens, and thus the prototypes 275 | outputs["support_prototypes"] = transformer_output[:ways] 276 | else: 277 | # Aggregate prototypes from support sample 
def gen_prototypes(embeddings, ways, shots, agg_method="mean"):
    """Aggregate per-class support embeddings into one prototype per class.

    Args:
        embeddings: tensor with ``ways * shots`` rows, grouped by class (all
            rows of class 0 first, then class 1, ...).
        ways: number of classes.
        shots: number of embeddings per class.
        agg_method: one of

            * ``"mean"`` - arithmetic mean of each class.
            * ``"median"`` - five smoothed fixed-point steps towards the
              geometric median, starting from the mean.
            * ``"cosine[_T]"``, ``"euclidean[_T]"``, ``"abs[_T]"`` - softmax
              weighted average. Each sample is scored by its average
              similarity (or negative distance) to the other samples of its
              class, and ``T`` is the softmax temperature (default 1).

    Returns:
        Tensor of shape ``(ways, dim)``.

    Raises:
        AssertionError: if the number of rows is not ``ways * shots``.
        NotImplementedError: for an unknown ``agg_method``.
    """
    assert (
        embeddings.size(0) == ways * shots
    ), "# of embeddings ({}) doesn't match ways ({}) and shots ({})".format(
        embeddings.size(0), ways, shots
    )

    embeddings = embeddings.reshape(ways, shots, -1)
    mean_embeddings = embeddings.mean(dim=1)

    if agg_method == "mean":
        return mean_embeddings

    if agg_method == "median":
        # Init median as mean
        median_embeddings = torch.unsqueeze(mean_embeddings, dim=1)
        c = 0.5  # smoothing term that keeps the denominator away from zero
        for _ in range(5):
            errors = median_embeddings - embeddings
            # Poor man's Newton's method
            denom = torch.sqrt(torch.sum(errors ** 2, dim=2, keepdim=True) + c ** 2)
            dw = -torch.sum(errors / denom, dim=1, keepdim=True) / torch.sum(
                1.0 / denom, dim=1, keepdim=True
            )
            # Out-of-place update: median_embeddings starts as a view of
            # mean_embeddings, so an in-place add would silently modify it.
            median_embeddings = median_embeddings + dw
        return torch.squeeze(median_embeddings, dim=1)

    if agg_method.startswith(("cosine", "euclidean", "abs")):
        epsilon = 1e-6
        # Each sample is compared with the other (shots - 1) samples of its
        # class. Clamp to 1 so a 1-shot episode gives weight 1 instead of
        # dividing by zero and returning NaN prototypes.
        num_others = max(shots - 1, 1)

        if agg_method.startswith("cosine"):
            # Normalize all embeddings to unit vectors
            norm_embeddings = embeddings / (
                torch.norm(embeddings, dim=2, keepdim=True) + epsilon
            )
            # Cosine similarity between all support samples in each class:
            # ways x shots x shots. Subtract 1 to drop the self-similarity.
            cos = torch.bmm(norm_embeddings, norm_embeddings.permute(0, 2, 1))
            attn = (torch.sum(cos, dim=1) - 1) / num_others
        elif agg_method.startswith("euclidean"):
            # dist: ways x shots x shots
            dist = (
                (embeddings.unsqueeze(dim=2) - embeddings.unsqueeze(dim=1)) ** 2
            ).sum(dim=-1)
            # Negative: a larger distance should get a smaller weight.
            attn = -torch.sum(dist, dim=1) / num_others
        else:  # "abs"
            # dist: ways x shots x shots
            dist = (
                torch.abs(embeddings.unsqueeze(dim=2) - embeddings.unsqueeze(dim=1))
            ).sum(dim=-1)
            attn = -torch.sum(dist, dim=1) / num_others

        # Parse softmax temperature (default=1)
        T = float(agg_method.split("_")[-1]) if "_" in agg_method else 1
        weights = F.softmax(attn / T, dim=1).unsqueeze(dim=2)
        weighted_embeddings = embeddings * weights
        return weighted_embeddings.sum(dim=1)

    raise NotImplementedError
class MetaTrainer:
    """Compute per-episode losses and accuracies for a few-shot model.

    Wraps a model that maps ``(data, support_mask)`` to a dict containing at
    least ``"conv_embeddings"``, plus ``"support_prototypes"`` or
    ``"outlier_logits"`` where the selected options need them. Behaviour is
    driven by ``args``: ``dist_metric``, ``comp_method``, ``noise_type``,
    ``logit_temperature``, the ``*_loss_weight`` options, and the
    train/test ``way``/``shot``/``query`` sizes.
    """

    def __init__(self, model, device, args):
        """Store collaborators and pick the similarity function.

        Raises:
            NotImplementedError: if ``args.dist_metric`` is neither
                "euclidean" nor "cosine".
        """
        self.model = model
        self.device = device
        self.args = args

        if args.dist_metric == "euclidean":
            self.metric = pairwise_distances_logits
        elif args.dist_metric == "cosine":
            self.metric = pairwise_cosine_logits
        else:
            raise NotImplementedError

    def _episode_config(self):
        """Return (ways, shot, query_num) for the model's current train/eval mode."""
        if self.model.training:
            return self.args.train_way, self.args.train_shot, self.args.train_query
        return self.args.test_way, self.args.test_shot, self.args.test_query

    def prepare_data(
        self,
        batch,
        ways,
        shot,
        query_num,
        support_label_noise_choices,
        outlier_batch=None,
    ):
        """Move an episode to the device, optionally corrupt it, drop the noise pool.

        Args:
            batch: ``(data, labels)`` pair for one episode.
            ways, shot, query_num: episode layout.
            support_label_noise_choices: candidate noise levels; one is drawn
                per call with ``np.random.choice``.
            outlier_batch: optional ``(data, labels)`` pair forwarded to
                ``add_noise`` as outlier samples.

        Returns:
            ``(data, labels, mask_indices, noise_positions)`` with the noise
            pool rows removed. ``mask_indices["support"]`` and
            ``mask_indices["query"]`` are re-indexed to match.
        """
        # Format data and place on device
        data, labels = preprocess_data_labels(batch, self.device)
        if outlier_batch:
            outlier_data, _ = preprocess_data_labels(outlier_batch, self.device)
        else:
            outlier_data = None

        # Separate query and support
        mask_indices = get_support_noise_query_indices(ways, shot, query_num)

        # Add noise
        support_label_noise = np.random.choice(support_label_noise_choices)
        noise_positions = np.zeros(data.shape[0])
        if support_label_noise > 0:
            data, labels, noise_positions = add_noise(
                data,
                labels,
                mask_indices,
                ways,
                support_label_noise,
                self.args.noise_type,
                outlier_data,
            )

        # Remove noise batch data for swapping, so conv doesn't need to process
        keep = ~mask_indices["noise"]
        data = data[keep]
        labels = labels[keep]
        noise_positions = noise_positions[keep]
        mask_indices["support"] = mask_indices["support"][keep]
        mask_indices["query"] = mask_indices["query"][keep]

        return data, labels, mask_indices, noise_positions

    def loss_acc(self, batch, outlier_batch=None, support_label_noise_choices=(0,)):
        """Run one episode and return ``(losses, accs)`` dictionaries.

        The main comparison is selected by ``args.comp_method`` ("proto",
        "match" or "lin_cls"). Auxiliary losses are added when their weight in
        ``args`` is positive.

        Raises:
            NotImplementedError: for an unknown ``comp_method`` or, in the
                clean-prototype loss, an unknown ``dist_metric``.
        """
        ways, shot, query_num = self._episode_config()

        # Format data, place on device, add noise
        data, labels, mask_indices, noise_positions = self.prepare_data(
            batch, ways, shot, query_num, support_label_noise_choices, outlier_batch
        )

        # Compute support/query embeddings and prototypes
        outputs = self.model(data, mask_indices["support"])
        support_embeddings = outputs["conv_embeddings"][mask_indices["support"]]
        query_embeddings = outputs["conv_embeddings"][mask_indices["query"]]

        # Note: support labels assumes class labels 0..(N-1), in order. labels has original labels before noise
        support_labels = torch.repeat_interleave(torch.arange(ways), shot).to(
            self.device
        )
        query_labels = labels[mask_indices["query"]].long()

        losses = {}
        accs = {}

        # Prototype-based comparison
        if self.args.comp_method == "proto":
            support_prototypes = outputs["support_prototypes"]

            # Compare query embeddings with prototypes
            logits = (
                self.metric(query_embeddings, support_prototypes)
                / self.args.logit_temperature
            )

            accs["accuracy"] = accuracy(logits, query_labels)
            losses["cross_entropy"] = F.cross_entropy(logits, query_labels)

        # Match query embedding with support embeddings
        elif self.args.comp_method == "match":
            # Create one-hot labels for support set
            one_hot_support_labels = torch.zeros(len(support_labels), ways).to(
                self.device
            )
            one_hot_support_labels[np.arange(len(support_labels)), support_labels] = 1

            # Normalize embeddings and compute cosine distance, then softmax
            support_embeddings_unit = support_embeddings / torch.norm(
                support_embeddings, dim=1, keepdim=True
            )
            query_embeddings_unit = query_embeddings / torch.norm(
                query_embeddings, dim=1, keepdim=True
            )
            similarities = F.softmax(
                query_embeddings_unit @ support_embeddings_unit.T, dim=1
            )

            # Output probability is the similarity scores multiplied with the labels
            p_y = similarities @ one_hot_support_labels

            accs["accuracy"] = accuracy(p_y, query_labels)
            losses["cross_entropy"] = F.nll_loss(p_y.log(), query_labels)

        # Train linear classifier
        elif self.args.comp_method == "lin_cls":
            # NOTE(review): this fits a fresh classifier with repeated
            # backward() calls on embeddings produced by self.model; whether
            # that works depends on the caller's grad context - confirm.
            # Instantiate linear classifier
            lin_cls = nn.Linear(self.model.convnet.conv_out_dims, ways).to(self.device)
            opt = torch.optim.AdamW(lin_cls.parameters(), lr=1e-3, weight_decay=0.01)

            for _ in range(101):
                opt.zero_grad()
                lin_y = lin_cls(support_embeddings)
                loss = F.cross_entropy(lin_y, support_labels)
                loss.backward()
                opt.step()

            lin_y_val = lin_cls(query_embeddings)
            accs["accuracy"] = accuracy(lin_y_val, query_labels)

        else:
            raise NotImplementedError

        # Auxiliary loss: protonet of conv embeddings
        if self.args.embed_proto_loss_weight > 0:
            embed_protos = gen_prototypes(support_embeddings, ways, shot)
            # Must go through self.metric: a bare `metric` name is undefined
            # here and raised NameError whenever this loss was enabled.
            embed_protos_logits = self.metric(query_embeddings, embed_protos)
            embed_protos_loss = self.args.embed_proto_loss_weight * F.cross_entropy(
                embed_protos_logits, query_labels
            )
            losses["conv_embed_proto"] = embed_protos_loss

        # Binary Outlier loss: identify transformer embeddings as noisy samples or not
        if self.args.binary_outlier_loss_weight > 0:
            noise_labels = (
                torch.tensor(noise_positions[mask_indices["support"]])
                .unsqueeze(dim=1)
                .to(self.device)
            )
            # outlier_logits are raw scores (the loss below is
            # BCE-with-logits), while accuracy(binary=True) thresholds at
            # 0.5; convert to probabilities so the cut-off is meaningful.
            accs["binary_outlier"] = accuracy(
                torch.sigmoid(outputs["outlier_logits"]), noise_labels, binary=True
            )
            losses["binary_outlier"] = (
                self.args.binary_outlier_loss_weight
                * F.binary_cross_entropy_with_logits(
                    outputs["outlier_logits"], noise_labels
                )
            )

        # Clean Prototype matching: loss to match predicted prototype with clean examples (reject noisy samples)
        if self.args.clean_proto_loss_weight > 0:
            support_noise = noise_positions[mask_indices["support"]]
            clean_support_embeddings = support_embeddings[~support_noise.astype(bool)]
            # Assumes every class has the same number of corrupted shots.
            num_clean_shots = shot - int(support_noise.sum() / ways)
            clean_protos = gen_prototypes(
                clean_support_embeddings, ways, num_clean_shots
            )

            # Distance of clean protos from predicted protos
            if self.args.dist_metric == "euclidean":
                losses["dist_to_clean_protos"] = (
                    self.args.clean_proto_loss_weight
                    * ((outputs["support_prototypes"] - clean_protos) ** 2).sum()
                    / ways
                )
            elif self.args.dist_metric == "cosine":
                epsilon = 1e-6
                norm_protos = outputs["support_prototypes"] / (
                    torch.norm(outputs["support_prototypes"], dim=1, keepdim=True)
                    + epsilon
                )
                norm_clean_protos = clean_protos / (
                    torch.norm(clean_protos, dim=1, keepdim=True) + epsilon
                )

                # Negative, because we want to maximize cosine angle
                losses["dist_to_clean_protos"] = (
                    -self.args.clean_proto_loss_weight
                    * (norm_protos.flatten() * norm_clean_protos.flatten()).sum()
                    / ways
                )
            else:
                raise NotImplementedError

        return losses, accs

    def nearest_neighbor(
        self, batch, outlier_batch=None, support_label_noise_choices=(0,)
    ):
        """Classify queries by majority vote of their ``k`` closest support samples.

        ``args.comp_method`` must have the form ``"nearest_[k]"``.

        Returns:
            ``({}, {"accuracy": acc})`` - there is no loss for this method.
        """
        ways, shot, query_num = self._episode_config()

        assert self.args.comp_method.startswith("nearest")
        k = int(self.args.comp_method[8:])  # comp_method specified as "nearest_[k]"

        # Format data, place on device, add noise
        data, labels, mask_indices, _ = self.prepare_data(
            batch, ways, shot, query_num, support_label_noise_choices, outlier_batch
        )

        # Compute support/query embeddings
        outputs = self.model(data, mask_indices["support"])
        support_embeddings = outputs["conv_embeddings"][mask_indices["support"]]
        query_embeddings = outputs["conv_embeddings"][mask_indices["query"]]

        # Note: support labels assumes class labels 0..(N-1), in order
        support_labels = torch.repeat_interleave(torch.arange(ways), shot).to(
            self.device
        )
        query_labels = labels[mask_indices["query"]].long()

        # Similarity between each query and all support samples (larger = closer)
        distances = self.metric(query_embeddings, support_embeddings)

        # Sort support samples by similarity and take labels of the k closest
        sorted_idx = torch.argsort(distances, dim=1, descending=True)
        topk = support_labels[sorted_idx][:, :k]

        # Convert to one-hot
        one_hot_labels = torch.zeros((ways * query_num * k, ways)).to(self.device)
        one_hot_labels[np.arange(ways * query_num * k), topk.flatten()] = 1
        one_hot_labels = one_hot_labels.reshape(ways * query_num, k, ways)

        # Tally votes, add small noise as random tiebreaker
        tallied = (
            one_hot_labels.mean(dim=1)
            + torch.randn((ways * query_num, ways)).to(self.device) / 1e4
        )
        predictions = torch.argmax(tallied, dim=1)

        acc = (predictions == query_labels).sum().float() / query_labels.size(0)

        return {}, {"accuracy": acc}
| 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. 
More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a.
Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. 
Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. 
Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. 
You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. 
retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. 
Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. 
#!/usr/bin/env python3

"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import argparse
import datetime
import numpy as np
import os
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from visdom import Visdom

import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels
from warmup_scheduler import GradualWarmupScheduler

from data_utils import get_data_loaders
from models import ProtoConvnet, ProtoConvTransformer
from few_shot_utils import MetaTrainer

# Entry point: parses the command line, builds the model, trains it with
# per-epoch validation (keeping the best checkpoint), then evaluates the best
# checkpoint on the test split for every requested support-label noise level.

# Number of training tasks drawn per epoch.
TRAIN_ITERS_PER_EPOCH = 100


def _as_float(value):
    """Return ``value`` as a Python float; accepts a one-element tensor or a number."""
    return value.item() if torch.is_tensor(value) else float(value)


def _ci95(values):
    """Return the 95% confidence half-width, 1.96 * std / sqrt(n), of ``values``."""
    if not values:
        return 0.0
    return 1.96 * np.std(values) / np.sqrt(len(values))


def _append_line(path, line):
    """Append ``line`` plus a newline to the text file at ``path``."""
    with open(path, "a") as f:
        f.write(line + "\n")


def _next_outlier_batch(data_loaders, split, noise_type):
    """Return one batch from ``data_loaders[split]`` for outlier noise, else None.

    A fresh iterator is built on every call, exactly as the training batch is
    drawn in the main loop.
    """
    if noise_type != "outlier":
        return None
    return next(iter(data_loaders[split]))


def _paired_window(viz, title, ylabel, **extra_opts):
    """Create a two-series (train, valid) Visdom line window seeded at (0, 0)."""
    opts = {
        "title": title,
        "xlabel": "Epoch",
        "ylabel": ylabel,
        "show_legend": True,
    }
    opts.update(extra_opts)
    return viz.line(
        X=np.column_stack((0, 0)),
        Y=np.column_stack((0, 0)),
        opts=opts,
    )


def _append_pair(viz, win, epoch, train_value, valid_value):
    """Append one (train, valid) point for ``epoch`` to the window ``win``."""
    viz.line(
        X=np.column_stack((epoch, epoch)),
        Y=np.column_stack((train_value, valid_value)),
        win=win,
        update="append",
    )


if __name__ == "__main__":
    train_start = time.time()

    parser = argparse.ArgumentParser()
    # Few-shot setting
    parser.add_argument("--train-shot", type=int, default=5)
    parser.add_argument("--train-way", type=int, default=5)
    parser.add_argument("--train-query", type=int, default=15)
    parser.add_argument("--test-shot", type=int, default=5)
    parser.add_argument("--test-way", type=int, default=5)
    parser.add_argument("--test-query", type=int, default=15)
    parser.add_argument("--test-tasks", type=int, default=10000)
    # Method
    parser.add_argument(
        "--agg-method",
        type=str,
        default="mean",
        help="Aggregation method. Choices: mean, median, cosine_[T], euclidean_[T], abs_[T]",
    )
    parser.add_argument(
        "--comp-method",
        type=str,
        default="proto",
        help="Comparison method. Choices: proto, nearest_[k], match, lin_cls",
    )
    parser.add_argument(
        "--dist-metric",
        type=str,
        default="euclidean",
        help="Distance metric. Choices: euclidean, cosine",
    )
    parser.add_argument("--logit-temperature", type=float, default=1.0)
    # Dataset
    parser.add_argument("--dataset", type=str, default="miniimagenet")
    parser.add_argument("--data-path", type=str, default="data/miniimagenet")
    parser.add_argument("--random-horizontal-flip", action="store_true")
    parser.add_argument("--random-resized-crop", action="store_true")
    parser.add_argument("--color-jitter", action="store_true")
    parser.add_argument("--random-erasing", action="store_true")
    # Noise args
    parser.add_argument("--noise-type", type=str, default="sym_swap")
    parser.add_argument(
        "--train-support-label-noise-choices", type=float, nargs="+", default=[0.0]
    )
    parser.add_argument("--train-query-label-noise", type=float, default=0.0)
    parser.add_argument(
        "--test-support-label-noise-list", type=float, nargs="+", default=[0.0]
    )
    parser.add_argument("--binary-outlier-loss-weight", type=float, default=0.0)
    parser.add_argument("--clean-proto-loss-weight", type=float, default=0.0)
    # Conv/Transformer args
    parser.add_argument("--conv-model", type=str, default="conv4")
    parser.add_argument("--freeze-conv", action="store_true")
    parser.add_argument("--conv-proj-dim", type=int, default=None)
    parser.add_argument("--transformer-layers", type=int, default=0)
    parser.add_argument("--trans-d-model", type=int, default=128)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--transformer-metric", type=str, default="dot_prod")
    parser.add_argument("--ortho-proj", action="store_true")
    parser.add_argument("--ortho-proj-residual", action="store_true")
    parser.add_argument("--cls-type", type=str, default="cls_learn")
    parser.add_argument("--pos-type", type=str, default="pos_learn")
    parser.add_argument("--output-from-cls", action="store_true")
    parser.add_argument("--embed-proto-loss-weight", type=float, default=0.0)
    parser.add_argument("--embed-proto-loss-weight-gamma", type=float, default=0.8)
    parser.add_argument("--embed-proto-loss-weight-step", type=int, default=100)
    # Learning rate args
    parser.add_argument("--optimizer", type=str, default="adamw")
    parser.add_argument("--scheduler", type=str, default="steplr")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--max-epoch", type=int, default=250)
    parser.add_argument("--step-rate", type=int, default=25)
    parser.add_argument("--step-gamma", type=float, default=0.5)
    parser.add_argument("--warm-up-epochs", type=int, default=0)
    # Administrative
    parser.add_argument("--name", type=str, required=True)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--load-checkpoint-path", type=str, default="")
    parser.add_argument("--output-dir", type=str, default="")
    parser.add_argument("--seed", type=int, default=1234)
    args = parser.parse_args()
    print(args)

    # Random seeds and devices
    np.random.seed(args.seed)
    device = torch.device("cpu")
    torch.manual_seed(args.seed)
    if torch.cuda.device_count():
        print("Using gpu")
        torch.cuda.manual_seed(args.seed)
        device = torch.device("cuda:" + str(args.gpu))

    # Visdom ("/" in the run name is replaced before it is used as the env name)
    viz = Visdom(env=args.name.replace("/", "__"))

    # Model Initialization
    ways = {"train": args.train_way, "test": args.test_way}
    shot = {"train": args.train_shot, "test": args.test_shot}
    detect_outliers = args.binary_outlier_loss_weight > 0.0
    if args.transformer_layers > 0:
        model = ProtoConvTransformer(
            ways,
            shot,
            device,
            conv_model=args.conv_model,
            conv_proj_dim=args.conv_proj_dim,
            trans_layers=args.transformer_layers,
            trans_d_model=args.trans_d_model,
            nhead=args.nhead,
            ortho_proj=args.ortho_proj,
            ortho_proj_residual=args.ortho_proj_residual,
            cls_type=args.cls_type,
            pos_type=args.pos_type,
            agg_method=args.agg_method,
            transformer_metric=args.transformer_metric,
            output_from_cls=args.output_from_cls,
            binary_outlier_detection=detect_outliers,
        )
    else:
        model = ProtoConvnet(
            ways,
            shot,
            model=args.conv_model,
            conv_proj_dim=args.conv_proj_dim,
            agg_method=args.agg_method,
            binary_outlier_detection=detect_outliers,
        )
    model.to(device)
    print(model)

    # Freeze convolutional parameters. Loading a state dict later only copies
    # values into the existing parameters, so this does not need repeating.
    if args.freeze_conv:
        for p in model.convnet.parameters():
            p.requires_grad = False

    # Set up MetaTrainer
    meta_trainer = MetaTrainer(model, device, args)

    # Optimizer and learning schedule
    if args.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    elif args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    elif args.optimizer == "nesterov":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=0.9,
            nesterov=True,
            weight_decay=0.0005,
        )
    else:
        raise NotImplementedError("unknown optimizer: {}".format(args.optimizer))

    # The cyclic schedules differ only in their mode (and gamma for exp_range).
    cyclic_variants = {
        "cyclic_tri": {"mode": "triangular"},
        "cyclic_tri2": {"mode": "triangular2"},
        "cyclic_exp": {"mode": "exp_range", "gamma": 0.998},
    }
    if args.scheduler == "steplr":
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.step_rate, gamma=args.step_gamma
        )
    elif args.scheduler in cyclic_variants:
        lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=0.000001,
            max_lr=0.0005,
            step_size_up=150,
            cycle_momentum=False,
            **cyclic_variants[args.scheduler],
        )
    else:
        raise NotImplementedError("unknown scheduler: {}".format(args.scheduler))
    if args.warm_up_epochs > 0:
        lr_scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=1,
            total_epoch=args.warm_up_epochs,
            after_scheduler=lr_scheduler,
        )

    # Data loading
    data_loaders = get_data_loaders(args)

    # Output directory/logging
    output_dirname = os.path.join(args.output_dir, args.name)
    os.makedirs(output_dirname, exist_ok=True)
    train_val_log_path = os.path.join(output_dirname, "train_val_logs.txt")
    test_log_path = os.path.join(output_dirname, "test_logs.txt")
    # Bound up front so the post-training reload can never hit an undefined name.
    checkpoint_path = os.path.join(output_dirname, "checkpoint.pth")
    with open(os.path.join(output_dirname, "args.txt"), "w") as f:
        f.write(str(args) + "\n\n")
        f.write(str(model) + "\n")

    # Initialize visdom windows
    acc_win = _paired_window(viz, "Accuracy", "Accuracy")
    loss_win = _paired_window(viz, "Loss", "Loss", ytickmax=2)
    lr_win = viz.line(
        X=np.array([0]),
        Y=np.array([0]),
        opts={
            "title": "Learning Rate",
            "xlabel": "Epoch",
            "ylabel": "Learning Rate",
        },
    )
    if args.embed_proto_loss_weight > 0:
        embed_proto_loss_win = _paired_window(
            viz, "Embed Proto Loss", "Loss", ytickmax=2
        )
    if detect_outliers:
        outlier_loss_win = _paired_window(viz, "Outlier Loss", "Loss", ytickmax=2)
        outlier_acc_win = _paired_window(viz, "Outlier Accuracy", "Accuracy")
    if args.clean_proto_loss_weight > 0.0:
        clean_proto_loss_win = _paired_window(
            viz, "Clean Proto Loss", "Loss", ytickmax=2
        )

    # Load model checkpoint if specified; can be partial (strict=False).
    # map_location follows the selected device so CPU-only hosts work too.
    if args.load_checkpoint_path != "":
        model.load_state_dict(
            torch.load(args.load_checkpoint_path, map_location=device),
            strict=False,
        )

    # Training
    best_valid_acc = 0
    best_valid_acc_epoch = 0

    for epoch in range(1, args.max_epoch + 1):
        # Train step
        model.train()

        loss_sum = 0.0
        acc_sum = 0.0
        for _ in range(TRAIN_ITERS_PER_EPOCH):
            batch = next(iter(data_loaders["train"]))
            outlier_batch = _next_outlier_batch(
                data_loaders, "outlier_train", args.noise_type
            )

            train_losses, train_accs = meta_trainer.loss_acc(
                batch, outlier_batch, args.train_support_label_noise_choices
            )

            # The optimised objective is the sum of all returned loss terms.
            loss = torch.sum(torch.stack(list(train_losses.values())))

            loss_sum += loss.item()
            acc_sum += _as_float(train_accs["accuracy"])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # The schedule advances once per epoch, not once per task.
        lr_scheduler.step()

        # Train Logging
        train_acc = acc_sum / TRAIN_ITERS_PER_EPOCH
        train_loss = loss_sum / TRAIN_ITERS_PER_EPOCH
        train_log = "epoch {}, train, loss={:.4f} acc={:.4f}".format(
            epoch, train_loss, train_acc
        )
        print(train_log)
        _append_line(train_val_log_path, train_log)

        # Validation step (always evaluated with clean support labels)
        model.eval()

        num_valid = 0
        loss_sum = 0.0
        acc_sum = 0.0
        for batch in data_loaders["valid"]:
            outlier_batch = _next_outlier_batch(
                data_loaders, "outlier_test", args.noise_type
            )

            valid_losses, valid_accs = meta_trainer.loss_acc(
                batch, outlier_batch, support_label_noise_choices=[0.0]
            )

            loss = torch.sum(torch.stack(list(valid_losses.values())))

            num_valid += 1
            loss_sum += loss.item()
            acc_sum += _as_float(valid_accs["accuracy"])

        if num_valid == 0:
            raise RuntimeError("validation data loader produced no batches")

        # Validation logging
        valid_acc = acc_sum / num_valid
        valid_loss = loss_sum / num_valid
        val_log = "epoch {}, val, loss={:.4f} acc={:.4f}".format(
            epoch, valid_loss, valid_acc
        )
        print(val_log)
        _append_line(train_val_log_path, val_log)

        # Save model if best so far
        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            best_valid_acc_epoch = epoch
            torch.save(model.state_dict(), checkpoint_path)

        # Visdom logging of train and validation. Epoch averages are plotted
        # for accuracy/loss; the auxiliary terms come from the last task seen.
        _append_pair(viz, acc_win, epoch, train_acc, valid_acc)
        _append_pair(viz, loss_win, epoch, train_loss, valid_loss)
        viz.line(
            X=np.array([epoch]),
            Y=np.array([lr_scheduler.get_last_lr()]),
            win=lr_win,
            update="append",
        )
        if args.embed_proto_loss_weight > 0:
            _append_pair(
                viz,
                embed_proto_loss_win,
                epoch,
                train_losses["conv_embed_proto"].detach().cpu(),
                valid_losses["conv_embed_proto"].detach().cpu(),
            )
        if detect_outliers:
            _append_pair(
                viz,
                outlier_loss_win,
                epoch,
                train_losses["binary_outlier"].detach().cpu(),
                valid_losses["binary_outlier"].detach().cpu(),
            )
            _append_pair(
                viz,
                outlier_acc_win,
                epoch,
                train_accs["binary_outlier"].detach().cpu(),
                valid_accs["binary_outlier"].detach().cpu(),
            )
        if args.clean_proto_loss_weight > 0.0:
            _append_pair(
                viz,
                clean_proto_loss_win,
                epoch,
                train_losses["dist_to_clean_protos"].detach().cpu(),
                valid_losses["dist_to_clean_protos"].detach().cpu(),
            )

    # Write total training time
    train_end = time.time()
    train_time_log = "Training time: {}".format(
        str(datetime.timedelta(seconds=train_end - train_start))
    )
    print(train_time_log)
    _append_line(train_val_log_path, train_time_log)

    ## Test
    # Load best model
    test_start = time.time()
    if args.max_epoch > 0:
        # Only reload when this run actually wrote a checkpoint; a file left by
        # an earlier run in the same directory must not be picked up silently.
        if best_valid_acc_epoch > 0:
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        else:
            print("No checkpoint was saved during training; testing final weights")
        with open(test_log_path, "w") as f:
            f.write(
                "Best epoch: {} ({})\n".format(best_valid_acc_epoch, best_valid_acc)
            )

    # Fix random seeds again
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.device_count():
        torch.cuda.manual_seed(args.seed)

    model.eval()
    # NOTE(review): evaluation is not wrapped in torch.no_grad() because some
    # comparison methods (e.g. lin_cls) may rely on autograd -- confirm.
    if args.comp_method.startswith("nearest"):
        evaluate_task = meta_trainer.nearest_neighbor
    else:
        evaluate_task = meta_trainer.loss_acc

    for label_noise in args.test_support_label_noise_list:
        test_accs = []
        for i, batch in enumerate(data_loaders["test"], 1):
            outlier_batch = _next_outlier_batch(
                data_loaders, "outlier_test", args.noise_type
            )

            _, accs = evaluate_task(
                batch, outlier_batch, support_label_noise_choices=[label_noise]
            )
            task_acc = _as_float(accs["accuracy"])
            test_accs.append(task_acc)
            # +/- 95% CI (1.96 * STD/sqrt(n)), n = tasks evaluated so far
            print(
                "batch {}: {:.2f}+/-{:.2f} ({:.2f})".format(
                    i,
                    np.mean(test_accs) * 100,
                    _ci95(test_accs) * 100,
                    task_acc * 100,
                )
            )

        test_log = "Noise {} {}: {:.2f}+/-{:.2f}".format(
            args.noise_type,
            label_noise,
            np.mean(test_accs) * 100,
            _ci95(test_accs) * 100,
        )
        _append_line(test_log_path, test_log)

    # Write total test time
    test_end = time.time()
    test_time_log = "Test time: {}".format(
        str(datetime.timedelta(seconds=test_end - test_start))
    )
    print(test_time_log)
    _append_line(test_log_path, test_time_log)
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | All rights reserved. 6 | This source code is licensed under the license found in the 7 | LICENSE file in the root directory of this source tree. 8 | """ 9 | 10 | 11 | import argparse 12 | import cv2 13 | import os 14 | import pickle 15 | 16 | import numpy as np 17 | from tqdm import tqdm 18 | 19 | 20 | MINI_KEYS = [ 21 | "n02110341", 22 | "n01930112", 23 | "n04509417", 24 | "n04067472", 25 | "n04515003", 26 | "n02120079", 27 | "n03924679", 28 | "n02687172", 29 | "n03075370", 30 | "n07747607", 31 | "n09246464", 32 | "n02457408", 33 | "n04418357", 34 | "n03535780", 35 | "n04435653", 36 | "n03207743", 37 | "n04251144", 38 | "n03062245", 39 | "n02174001", 40 | "n07613480", 41 | "n03998194", 42 | "n02074367", 43 | "n04146614", 44 | "n04243546", 45 | "n03854065", 46 | "n03838899", 47 | "n02871525", 48 | "n03544143", 49 | "n02108089", 50 | "n13133613", 51 | "n03676483", 52 | "n03337140", 53 | "n03272010", 54 | "n01770081", 55 | "n09256479", 56 | "n02091244", 57 | "n02116738", 58 | "n04275548", 59 | "n03773504", 60 | "n02606052", 61 | "n03146219", 62 | "n04149813", 63 | "n07697537", 64 | "n02823428", 65 | "n02089867", 66 | "n03017168", 67 | "n01704323", 68 | "n01532829", 69 | "n03047690", 70 | "n03775546", 71 | "n01843383", 72 | "n02971356", 73 | "n13054560", 74 | "n02108551", 75 | "n02101006", 76 | "n03417042", 77 | "n04612504", 78 | "n01558993", 79 | "n04522168", 80 | "n02795169", 81 | "n06794110", 82 | "n01855672", 83 | "n04258138", 84 | "n02110063", 85 | "n07584110", 86 | "n02091831", 87 | "n03584254", 88 | "n03888605", 89 | "n02113712", 90 | "n03980874", 91 | "n02219486", 92 | "n02138441", 93 | "n02165456", 94 | "n02108915", 95 | "n03770439", 96 | "n01981276", 97 | "n03220513", 98 | "n02099601", 99 | "n02747177", 100 | "n01749939", 101 | "n03476684", 102 | "n02105505", 103 | "n02950826", 104 
| "n04389033", 105 | "n03347037", 106 | "n02966193", 107 | "n03127925", 108 | "n03400231", 109 | "n04296562", 110 | "n03527444", 111 | "n04443257", 112 | "n02443484", 113 | "n02114548", 114 | "n04604644", 115 | "n01910747", 116 | "n04596742", 117 | "n02111277", 118 | "n03908618", 119 | "n02129165", 120 | "n02981792", 121 | ] 122 | 123 | TIERED_KEYS_TRAIN_VAL = [ 124 | "n01530575", 125 | "n01531178", 126 | "n01532829", 127 | "n01534433", 128 | "n01537544", 129 | "n01558993", 130 | "n01560419", 131 | "n01580077", 132 | "n01582220", 133 | "n01592084", 134 | "n01601694", 135 | "n01675722", 136 | "n01677366", 137 | "n01682714", 138 | "n01685808", 139 | "n01687978", 140 | "n01688243", 141 | "n01689811", 142 | "n01692333", 143 | "n01693334", 144 | "n01694178", 145 | "n01695060", 146 | "n01728572", 147 | "n01728920", 148 | "n01729322", 149 | "n01729977", 150 | "n01734418", 151 | "n01735189", 152 | "n01737021", 153 | "n01739381", 154 | "n01740131", 155 | "n01742172", 156 | "n01744401", 157 | "n01748264", 158 | "n01749939", 159 | "n01751748", 160 | "n01753488", 161 | "n01755581", 162 | "n01756291", 163 | "n01847000", 164 | "n01855032", 165 | "n01855672", 166 | "n01860187", 167 | "n02002556", 168 | "n02002724", 169 | "n02006656", 170 | "n02007558", 171 | "n02009229", 172 | "n02009912", 173 | "n02011460", 174 | "n02012849", 175 | "n02013706", 176 | "n02017213", 177 | "n02018207", 178 | "n02018795", 179 | "n02025239", 180 | "n02027492", 181 | "n02028035", 182 | "n02033041", 183 | "n02037110", 184 | "n02051845", 185 | "n02056570", 186 | "n02058221", 187 | "n02088094", 188 | "n02088238", 189 | "n02088364", 190 | "n02088466", 191 | "n02088632", 192 | "n02089078", 193 | "n02089867", 194 | "n02089973", 195 | "n02090379", 196 | "n02090622", 197 | "n02090721", 198 | "n02091032", 199 | "n02091134", 200 | "n02091244", 201 | "n02091467", 202 | "n02091635", 203 | "n02091831", 204 | "n02092002", 205 | "n02092339", 206 | "n02093256", 207 | "n02093428", 208 | "n02093647", 209 | "n02093754", 
210 | "n02093859", 211 | "n02093991", 212 | "n02094114", 213 | "n02094258", 214 | "n02094433", 215 | "n02095314", 216 | "n02095570", 217 | "n02095889", 218 | "n02096051", 219 | "n02096177", 220 | "n02096294", 221 | "n02096437", 222 | "n02096585", 223 | "n02097047", 224 | "n02097130", 225 | "n02097209", 226 | "n02097298", 227 | "n02097474", 228 | "n02097658", 229 | "n02098105", 230 | "n02098286", 231 | "n02098413", 232 | "n02099267", 233 | "n02099429", 234 | "n02099601", 235 | "n02099712", 236 | "n02099849", 237 | "n02100236", 238 | "n02100583", 239 | "n02100735", 240 | "n02100877", 241 | "n02101006", 242 | "n02101388", 243 | "n02101556", 244 | "n02102040", 245 | "n02102177", 246 | "n02102318", 247 | "n02102480", 248 | "n02102973", 249 | "n02123045", 250 | "n02123159", 251 | "n02123394", 252 | "n02123597", 253 | "n02124075", 254 | "n02125311", 255 | "n02127052", 256 | "n02128385", 257 | "n02128757", 258 | "n02128925", 259 | "n02129165", 260 | "n02129604", 261 | "n02130308", 262 | "n02389026", 263 | "n02391049", 264 | "n02395406", 265 | "n02396427", 266 | "n02397096", 267 | "n02398521", 268 | "n02403003", 269 | "n02408429", 270 | "n02410509", 271 | "n02412080", 272 | "n02415577", 273 | "n02417914", 274 | "n02422106", 275 | "n02422699", 276 | "n02423022", 277 | "n02437312", 278 | "n02437616", 279 | "n02480495", 280 | "n02480855", 281 | "n02481823", 282 | "n02483362", 283 | "n02483708", 284 | "n02484975", 285 | "n02486261", 286 | "n02486410", 287 | "n02487347", 288 | "n02488291", 289 | "n02488702", 290 | "n02489166", 291 | "n02490219", 292 | "n02492035", 293 | "n02492660", 294 | "n02493509", 295 | "n02493793", 296 | "n02494079", 297 | "n02497673", 298 | "n02500267", 299 | "n02727426", 300 | "n02793495", 301 | "n02859443", 302 | "n03028079", 303 | "n03032252", 304 | "n03457902", 305 | "n03529860", 306 | "n03661043", 307 | "n03781244", 308 | "n03788195", 309 | "n03877845", 310 | "n03956157", 311 | "n04081281", 312 | "n04346328", 313 | "n02687172", 314 | "n02690373", 315 
| "n02692877", 316 | "n02782093", 317 | "n02951358", 318 | "n02981792", 319 | "n03095699", 320 | "n03344393", 321 | "n03447447", 322 | "n03662601", 323 | "n03673027", 324 | "n03947888", 325 | "n04147183", 326 | "n04266014", 327 | "n04273569", 328 | "n04347754", 329 | "n04483307", 330 | "n04552348", 331 | "n04606251", 332 | "n04612504", 333 | "n03207941", 334 | "n03259280", 335 | "n03297495", 336 | "n03483316", 337 | "n03584829", 338 | "n03761084", 339 | "n04070727", 340 | "n04111531", 341 | "n04442312", 342 | "n04517823", 343 | "n04542943", 344 | "n04554684", 345 | "n02979186", 346 | "n02988304", 347 | "n02992529", 348 | "n03085013", 349 | "n03187595", 350 | "n03584254", 351 | "n03777754", 352 | "n03782006", 353 | "n03857828", 354 | "n03902125", 355 | "n04392985", 356 | "n02776631", 357 | "n02791270", 358 | "n02871525", 359 | "n02927161", 360 | "n03089624", 361 | "n03461385", 362 | "n04005630", 363 | "n04200800", 364 | "n04443257", 365 | "n04462240", 366 | "n02791124", 367 | "n02804414", 368 | "n02870880", 369 | "n03016953", 370 | "n03018349", 371 | "n03125729", 372 | "n03131574", 373 | "n03179701", 374 | "n03201208", 375 | "n03290653", 376 | "n03337140", 377 | "n03376595", 378 | "n03388549", 379 | "n03742115", 380 | "n03891251", 381 | "n03998194", 382 | "n04099969", 383 | "n04344873", 384 | "n04380533", 385 | "n04429376", 386 | "n04447861", 387 | "n04550184", 388 | "n02799071", 389 | "n02802426", 390 | "n03134739", 391 | "n03445777", 392 | "n03598930", 393 | "n03942813", 394 | "n04023962", 395 | "n04118538", 396 | "n04254680", 397 | "n04409515", 398 | "n04540053", 399 | "n06785654", 400 | "n02667093", 401 | "n02837789", 402 | "n02865351", 403 | "n02883205", 404 | "n02892767", 405 | "n02963159", 406 | "n03188531", 407 | "n03325584", 408 | "n03404251", 409 | "n03534580", 410 | "n03594734", 411 | "n03595614", 412 | "n03617480", 413 | "n03630383", 414 | "n03710721", 415 | "n03770439", 416 | "n03866082", 417 | "n03980874", 418 | "n04136333", 419 | "n04325704", 420 | 
"n04350905", 421 | "n04370456", 422 | "n04371430", 423 | "n04479046", 424 | "n04591157", 425 | "n02708093", 426 | "n02749479", 427 | "n02794156", 428 | "n02841315", 429 | "n02879718", 430 | "n02950826", 431 | "n03196217", 432 | "n03197337", 433 | "n03467068", 434 | "n03544143", 435 | "n03692522", 436 | "n03706229", 437 | "n03773504", 438 | "n03841143", 439 | "n03891332", 440 | "n04008634", 441 | "n04009552", 442 | "n04044716", 443 | "n04086273", 444 | "n04090263", 445 | "n04118776", 446 | "n04141975", 447 | "n04317175", 448 | "n04328186", 449 | "n04355338", 450 | "n04356056", 451 | "n04376876", 452 | "n04548280", 453 | "n02666196", 454 | "n02977058", 455 | "n03180011", 456 | "n03485407", 457 | "n03496892", 458 | "n03642806", 459 | "n03832673", 460 | "n04238763", 461 | "n04243546", 462 | "n04428191", 463 | "n04525305", 464 | "n06359193", 465 | "n02966193", 466 | "n02974003", 467 | "n03425413", 468 | "n03532672", 469 | "n03874293", 470 | "n03944341", 471 | "n03992509", 472 | "n04019541", 473 | "n04040759", 474 | "n04067472", 475 | "n04371774", 476 | "n04372370", 477 | "n02701002", 478 | "n02704792", 479 | "n02814533", 480 | "n02930766", 481 | "n03100240", 482 | "n03345487", 483 | "n03417042", 484 | "n03444034", 485 | "n03445924", 486 | "n03594945", 487 | "n03670208", 488 | "n03770679", 489 | "n03777568", 490 | "n03785016", 491 | "n03796401", 492 | "n03930630", 493 | "n03977966", 494 | "n04037443", 495 | "n04252225", 496 | "n04285008", 497 | "n04461696", 498 | "n04467665", 499 | "n02672831", 500 | "n02676566", 501 | "n02787622", 502 | "n02804610", 503 | "n02992211", 504 | "n03017168", 505 | "n03110669", 506 | "n03249569", 507 | "n03272010", 508 | "n03372029", 509 | "n03394916", 510 | "n03447721", 511 | "n03452741", 512 | "n03494278", 513 | "n03495258", 514 | "n03720891", 515 | "n03721384", 516 | "n03838899", 517 | "n03840681", 518 | "n03854065", 519 | "n03884397", 520 | "n04141076", 521 | "n04311174", 522 | "n04487394", 523 | "n04515003", 524 | "n04536866", 525 | 
"n02825657", 526 | "n02840245", 527 | "n02843684", 528 | "n02895154", 529 | "n03000247", 530 | "n03146219", 531 | "n03220513", 532 | "n03347037", 533 | "n03424325", 534 | "n03527444", 535 | "n03637318", 536 | "n03657121", 537 | "n03788365", 538 | "n03929855", 539 | "n04141327", 540 | "n04192698", 541 | "n04229816", 542 | "n04417672", 543 | "n04423845", 544 | "n04435653", 545 | "n04507155", 546 | "n04523525", 547 | "n04589890", 548 | "n04590129", 549 | "n02910353", 550 | "n03075370", 551 | "n03208938", 552 | "n03476684", 553 | "n03627232", 554 | "n03803284", 555 | "n03804744", 556 | "n03874599", 557 | "n04127249", 558 | "n04153751", 559 | "n04162706", 560 | "n02951585", 561 | "n03041632", 562 | "n03109150", 563 | "n03481172", 564 | "n03498962", 565 | "n03649909", 566 | "n03658185", 567 | "n03954731", 568 | "n03967562", 569 | "n03970156", 570 | "n04154565", 571 | "n04208210", 572 | ] 573 | 574 | TIERED_KEYS_TEST = [ 575 | "n03314780", 576 | "n07565083", 577 | "n07579787", 578 | "n07583066", 579 | "n07584110", 580 | "n07590611", 581 | "n07613480", 582 | "n07614500", 583 | "n07615774", 584 | "n07697313", 585 | "n07697537", 586 | "n07802026", 587 | "n07831146", 588 | "n07836838", 589 | "n07860988", 590 | "n07873807", 591 | "n07875152", 592 | "n07880968", 593 | "n07892512", 594 | "n07920052", 595 | "n07930864", 596 | "n07932039", 597 | "n01440764", 598 | "n01443537", 599 | "n01484850", 600 | "n01491361", 601 | "n01494475", 602 | "n01496331", 603 | "n01498041", 604 | "n02514041", 605 | "n02526121", 606 | "n02536864", 607 | "n02606052", 608 | "n02607072", 609 | "n02640242", 610 | "n02641379", 611 | "n02643566", 612 | "n02655020", 613 | "n02104029", 614 | "n02104365", 615 | "n02105056", 616 | "n02105162", 617 | "n02105251", 618 | "n02105412", 619 | "n02105505", 620 | "n02105641", 621 | "n02105855", 622 | "n02106030", 623 | "n02106166", 624 | "n02106382", 625 | "n02106550", 626 | "n02106662", 627 | "n02107142", 628 | "n02107312", 629 | "n02107574", 630 | "n02107683", 631 | 
"n02107908", 632 | "n02108000", 633 | "n02108089", 634 | "n02108422", 635 | "n02108551", 636 | "n02108915", 637 | "n02109047", 638 | "n02109525", 639 | "n02109961", 640 | "n02110063", 641 | "n02110185", 642 | "n02110627", 643 | "n02165105", 644 | "n02165456", 645 | "n02167151", 646 | "n02168699", 647 | "n02169497", 648 | "n02172182", 649 | "n02174001", 650 | "n02177972", 651 | "n02190166", 652 | "n02206856", 653 | "n02219486", 654 | "n02226429", 655 | "n02229544", 656 | "n02231487", 657 | "n02233338", 658 | "n02236044", 659 | "n02256656", 660 | "n02259212", 661 | "n02264363", 662 | "n02268443", 663 | "n02268853", 664 | "n02276258", 665 | "n02277742", 666 | "n02279972", 667 | "n02280649", 668 | "n02281406", 669 | "n02281787", 670 | "n02788148", 671 | "n02894605", 672 | "n03000134", 673 | "n03160309", 674 | "n03459775", 675 | "n03930313", 676 | "n04239074", 677 | "n04326547", 678 | "n04501370", 679 | "n04604644", 680 | "n02795169", 681 | "n02808440", 682 | "n02815834", 683 | "n02823428", 684 | "n02909870", 685 | "n02939185", 686 | "n03063599", 687 | "n03063689", 688 | "n03633091", 689 | "n03786901", 690 | "n03937543", 691 | "n03950228", 692 | "n03983396", 693 | "n04049303", 694 | "n04398044", 695 | "n04493381", 696 | "n04522168", 697 | "n04553703", 698 | "n04557648", 699 | "n04560804", 700 | "n04562935", 701 | "n04579145", 702 | "n04591713", 703 | "n09193705", 704 | "n09246464", 705 | "n09256479", 706 | "n09288635", 707 | "n09332890", 708 | "n09399592", 709 | "n09421951", 710 | "n09428293", 711 | "n09468604", 712 | "n09472597", 713 | "n07714571", 714 | "n07714990", 715 | "n07715103", 716 | "n07716358", 717 | "n07716906", 718 | "n07717410", 719 | "n07717556", 720 | "n07718472", 721 | "n07718747", 722 | "n07720875", 723 | "n07730033", 724 | "n07734744", 725 | "n07742313", 726 | "n07745940", 727 | "n07747607", 728 | "n07749582", 729 | "n07753113", 730 | "n07753275", 731 | "n07753592", 732 | "n07754684", 733 | "n07760859", 734 | "n07768694", 735 | ] 736 | 737 | ALL_KEYS 
= [ 738 | "n01440764", 739 | "n01443537", 740 | "n01484850", 741 | "n01491361", 742 | "n01494475", 743 | "n01496331", 744 | "n01498041", 745 | "n01514668", 746 | "n01514859", 747 | "n01518878", 748 | "n01530575", 749 | "n01531178", 750 | "n01532829", 751 | "n01534433", 752 | "n01537544", 753 | "n01558993", 754 | "n01560419", 755 | "n01580077", 756 | "n01582220", 757 | "n01592084", 758 | "n01601694", 759 | "n01608432", 760 | "n01614925", 761 | "n01616318", 762 | "n01622779", 763 | "n01629819", 764 | "n01630670", 765 | "n01631663", 766 | "n01632458", 767 | "n01632777", 768 | "n01641577", 769 | "n01644373", 770 | "n01644900", 771 | "n01664065", 772 | "n01665541", 773 | "n01667114", 774 | "n01667778", 775 | "n01669191", 776 | "n01675722", 777 | "n01677366", 778 | "n01682714", 779 | "n01685808", 780 | "n01687978", 781 | "n01688243", 782 | "n01689811", 783 | "n01692333", 784 | "n01693334", 785 | "n01694178", 786 | "n01695060", 787 | "n01697457", 788 | "n01698640", 789 | "n01704323", 790 | "n01728572", 791 | "n01728920", 792 | "n01729322", 793 | "n01729977", 794 | "n01734418", 795 | "n01735189", 796 | "n01737021", 797 | "n01739381", 798 | "n01740131", 799 | "n01742172", 800 | "n01744401", 801 | "n01748264", 802 | "n01749939", 803 | "n01751748", 804 | "n01753488", 805 | "n01755581", 806 | "n01756291", 807 | "n01768244", 808 | "n01770081", 809 | "n01770393", 810 | "n01773157", 811 | "n01773549", 812 | "n01773797", 813 | "n01774384", 814 | "n01774750", 815 | "n01775062", 816 | "n01776313", 817 | "n01784675", 818 | "n01795545", 819 | "n01796340", 820 | "n01797886", 821 | "n01798484", 822 | "n01806143", 823 | "n01806567", 824 | "n01807496", 825 | "n01817953", 826 | "n01818515", 827 | "n01819313", 828 | "n01820546", 829 | "n01824575", 830 | "n01828970", 831 | "n01829413", 832 | "n01833805", 833 | "n01843065", 834 | "n01843383", 835 | "n01847000", 836 | "n01855032", 837 | "n01855672", 838 | "n01860187", 839 | "n01871265", 840 | "n01872401", 841 | "n01873310", 842 | "n01877812", 
843 | "n01882714", 844 | "n01883070", 845 | "n01910747", 846 | "n01914609", 847 | "n01917289", 848 | "n01924916", 849 | "n01930112", 850 | "n01943899", 851 | "n01944390", 852 | "n01945685", 853 | "n01950731", 854 | "n01955084", 855 | "n01968897", 856 | "n01978287", 857 | "n01978455", 858 | "n01980166", 859 | "n01981276", 860 | "n01983481", 861 | "n01984695", 862 | "n01985128", 863 | "n01986214", 864 | "n01990800", 865 | "n02002556", 866 | "n02002724", 867 | "n02006656", 868 | "n02007558", 869 | "n02009229", 870 | "n02009912", 871 | "n02011460", 872 | "n02012849", 873 | "n02013706", 874 | "n02017213", 875 | "n02018207", 876 | "n02018795", 877 | "n02025239", 878 | "n02027492", 879 | "n02028035", 880 | "n02033041", 881 | "n02037110", 882 | "n02051845", 883 | "n02056570", 884 | "n02058221", 885 | "n02066245", 886 | "n02071294", 887 | "n02074367", 888 | "n02077923", 889 | "n02085620", 890 | "n02085782", 891 | "n02085936", 892 | "n02086079", 893 | "n02086240", 894 | "n02086646", 895 | "n02086910", 896 | "n02087046", 897 | "n02087394", 898 | "n02088094", 899 | "n02088238", 900 | "n02088364", 901 | "n02088466", 902 | "n02088632", 903 | "n02089078", 904 | "n02089867", 905 | "n02089973", 906 | "n02090379", 907 | "n02090622", 908 | "n02090721", 909 | "n02091032", 910 | "n02091134", 911 | "n02091244", 912 | "n02091467", 913 | "n02091635", 914 | "n02091831", 915 | "n02092002", 916 | "n02092339", 917 | "n02093256", 918 | "n02093428", 919 | "n02093647", 920 | "n02093754", 921 | "n02093859", 922 | "n02093991", 923 | "n02094114", 924 | "n02094258", 925 | "n02094433", 926 | "n02095314", 927 | "n02095570", 928 | "n02095889", 929 | "n02096051", 930 | "n02096177", 931 | "n02096294", 932 | "n02096437", 933 | "n02096585", 934 | "n02097047", 935 | "n02097130", 936 | "n02097209", 937 | "n02097298", 938 | "n02097474", 939 | "n02097658", 940 | "n02098105", 941 | "n02098286", 942 | "n02098413", 943 | "n02099267", 944 | "n02099429", 945 | "n02099601", 946 | "n02099712", 947 | "n02099849", 948 
| "n02100236", 949 | "n02100583", 950 | "n02100735", 951 | "n02100877", 952 | "n02101006", 953 | "n02101388", 954 | "n02101556", 955 | "n02102040", 956 | "n02102177", 957 | "n02102318", 958 | "n02102480", 959 | "n02102973", 960 | "n02104029", 961 | "n02104365", 962 | "n02105056", 963 | "n02105162", 964 | "n02105251", 965 | "n02105412", 966 | "n02105505", 967 | "n02105641", 968 | "n02105855", 969 | "n02106030", 970 | "n02106166", 971 | "n02106382", 972 | "n02106550", 973 | "n02106662", 974 | "n02107142", 975 | "n02107312", 976 | "n02107574", 977 | "n02107683", 978 | "n02107908", 979 | "n02108000", 980 | "n02108089", 981 | "n02108422", 982 | "n02108551", 983 | "n02108915", 984 | "n02109047", 985 | "n02109525", 986 | "n02109961", 987 | "n02110063", 988 | "n02110185", 989 | "n02110341", 990 | "n02110627", 991 | "n02110806", 992 | "n02110958", 993 | "n02111129", 994 | "n02111277", 995 | "n02111500", 996 | "n02111889", 997 | "n02112018", 998 | "n02112137", 999 | "n02112350", 1000 | "n02112706", 1001 | "n02113023", 1002 | "n02113186", 1003 | "n02113624", 1004 | "n02113712", 1005 | "n02113799", 1006 | "n02113978", 1007 | "n02114367", 1008 | "n02114548", 1009 | "n02114712", 1010 | "n02114855", 1011 | "n02115641", 1012 | "n02115913", 1013 | "n02116738", 1014 | "n02117135", 1015 | "n02119022", 1016 | "n02119789", 1017 | "n02120079", 1018 | "n02120505", 1019 | "n02123045", 1020 | "n02123159", 1021 | "n02123394", 1022 | "n02123597", 1023 | "n02124075", 1024 | "n02125311", 1025 | "n02127052", 1026 | "n02128385", 1027 | "n02128757", 1028 | "n02128925", 1029 | "n02129165", 1030 | "n02129604", 1031 | "n02130308", 1032 | "n02132136", 1033 | "n02133161", 1034 | "n02134084", 1035 | "n02134418", 1036 | "n02137549", 1037 | "n02138441", 1038 | "n02165105", 1039 | "n02165456", 1040 | "n02167151", 1041 | "n02168699", 1042 | "n02169497", 1043 | "n02172182", 1044 | "n02174001", 1045 | "n02177972", 1046 | "n02190166", 1047 | "n02206856", 1048 | "n02219486", 1049 | "n02226429", 1050 | 
"n02229544", 1051 | "n02231487", 1052 | "n02233338", 1053 | "n02236044", 1054 | "n02256656", 1055 | "n02259212", 1056 | "n02264363", 1057 | "n02268443", 1058 | "n02268853", 1059 | "n02276258", 1060 | "n02277742", 1061 | "n02279972", 1062 | "n02280649", 1063 | "n02281406", 1064 | "n02281787", 1065 | "n02317335", 1066 | "n02319095", 1067 | "n02321529", 1068 | "n02325366", 1069 | "n02326432", 1070 | "n02328150", 1071 | "n02342885", 1072 | "n02346627", 1073 | "n02356798", 1074 | "n02361337", 1075 | "n02363005", 1076 | "n02364673", 1077 | "n02389026", 1078 | "n02391049", 1079 | "n02395406", 1080 | "n02396427", 1081 | "n02397096", 1082 | "n02398521", 1083 | "n02403003", 1084 | "n02408429", 1085 | "n02410509", 1086 | "n02412080", 1087 | "n02415577", 1088 | "n02417914", 1089 | "n02422106", 1090 | "n02422699", 1091 | "n02423022", 1092 | "n02437312", 1093 | "n02437616", 1094 | "n02441942", 1095 | "n02442845", 1096 | "n02443114", 1097 | "n02443484", 1098 | "n02444819", 1099 | "n02445715", 1100 | "n02447366", 1101 | "n02454379", 1102 | "n02457408", 1103 | "n02480495", 1104 | "n02480855", 1105 | "n02481823", 1106 | "n02483362", 1107 | "n02483708", 1108 | "n02484975", 1109 | "n02486261", 1110 | "n02486410", 1111 | "n02487347", 1112 | "n02488291", 1113 | "n02488702", 1114 | "n02489166", 1115 | "n02490219", 1116 | "n02492035", 1117 | "n02492660", 1118 | "n02493509", 1119 | "n02493793", 1120 | "n02494079", 1121 | "n02497673", 1122 | "n02500267", 1123 | "n02504013", 1124 | "n02504458", 1125 | "n02509815", 1126 | "n02510455", 1127 | "n02514041", 1128 | "n02526121", 1129 | "n02536864", 1130 | "n02606052", 1131 | "n02607072", 1132 | "n02640242", 1133 | "n02641379", 1134 | "n02643566", 1135 | "n02655020", 1136 | "n02666196", 1137 | "n02667093", 1138 | "n02669723", 1139 | "n02672831", 1140 | "n02676566", 1141 | "n02687172", 1142 | "n02690373", 1143 | "n02692877", 1144 | "n02699494", 1145 | "n02701002", 1146 | "n02704792", 1147 | "n02708093", 1148 | "n02727426", 1149 | "n02730930", 1150 | 
"n02747177", 1151 | "n02749479", 1152 | "n02769748", 1153 | "n02776631", 1154 | "n02777292", 1155 | "n02782093", 1156 | "n02783161", 1157 | "n02786058", 1158 | "n02787622", 1159 | "n02788148", 1160 | "n02790996", 1161 | "n02791124", 1162 | "n02791270", 1163 | "n02793495", 1164 | "n02794156", 1165 | "n02795169", 1166 | "n02797295", 1167 | "n02799071", 1168 | "n02802426", 1169 | "n02804414", 1170 | "n02804610", 1171 | "n02807133", 1172 | "n02808304", 1173 | "n02808440", 1174 | "n02814533", 1175 | "n02814860", 1176 | "n02815834", 1177 | "n02817516", 1178 | "n02823428", 1179 | "n02823750", 1180 | "n02825657", 1181 | "n02834397", 1182 | "n02835271", 1183 | "n02837789", 1184 | "n02840245", 1185 | "n02841315", 1186 | "n02843684", 1187 | "n02859443", 1188 | "n02860847", 1189 | "n02865351", 1190 | "n02869837", 1191 | "n02870880", 1192 | "n02871525", 1193 | "n02877765", 1194 | "n02879718", 1195 | "n02883205", 1196 | "n02892201", 1197 | "n02892767", 1198 | "n02894605", 1199 | "n02895154", 1200 | "n02906734", 1201 | "n02909870", 1202 | "n02910353", 1203 | "n02916936", 1204 | "n02917067", 1205 | "n02927161", 1206 | "n02930766", 1207 | "n02939185", 1208 | "n02948072", 1209 | "n02950826", 1210 | "n02951358", 1211 | "n02951585", 1212 | "n02963159", 1213 | "n02965783", 1214 | "n02966193", 1215 | "n02966687", 1216 | "n02971356", 1217 | "n02974003", 1218 | "n02977058", 1219 | "n02978881", 1220 | "n02979186", 1221 | "n02980441", 1222 | "n02981792", 1223 | "n02988304", 1224 | "n02992211", 1225 | "n02992529", 1226 | "n02999410", 1227 | "n03000134", 1228 | "n03000247", 1229 | "n03000684", 1230 | "n03014705", 1231 | "n03016953", 1232 | "n03017168", 1233 | "n03018349", 1234 | "n03026506", 1235 | "n03028079", 1236 | "n03032252", 1237 | "n03041632", 1238 | "n03042490", 1239 | "n03045698", 1240 | "n03047690", 1241 | "n03062245", 1242 | "n03063599", 1243 | "n03063689", 1244 | "n03065424", 1245 | "n03075370", 1246 | "n03085013", 1247 | "n03089624", 1248 | "n03095699", 1249 | "n03100240", 1250 | 
"n03109150", 1251 | "n03110669", 1252 | "n03124043", 1253 | "n03124170", 1254 | "n03125729", 1255 | "n03126707", 1256 | "n03127747", 1257 | "n03127925", 1258 | "n03131574", 1259 | "n03133878", 1260 | "n03134739", 1261 | "n03141823", 1262 | "n03146219", 1263 | "n03160309", 1264 | "n03179701", 1265 | "n03180011", 1266 | "n03187595", 1267 | "n03188531", 1268 | "n03196217", 1269 | "n03197337", 1270 | "n03201208", 1271 | "n03207743", 1272 | "n03207941", 1273 | "n03208938", 1274 | "n03216828", 1275 | "n03218198", 1276 | "n03220513", 1277 | "n03223299", 1278 | "n03240683", 1279 | "n03249569", 1280 | "n03250847", 1281 | "n03255030", 1282 | "n03259280", 1283 | "n03271574", 1284 | "n03272010", 1285 | "n03272562", 1286 | "n03290653", 1287 | "n03291819", 1288 | "n03297495", 1289 | "n03314780", 1290 | "n03325584", 1291 | "n03337140", 1292 | "n03344393", 1293 | "n03345487", 1294 | "n03347037", 1295 | "n03355925", 1296 | "n03372029", 1297 | "n03376595", 1298 | "n03379051", 1299 | "n03384352", 1300 | "n03388043", 1301 | "n03388183", 1302 | "n03388549", 1303 | "n03393912", 1304 | "n03394916", 1305 | "n03400231", 1306 | "n03404251", 1307 | "n03417042", 1308 | "n03424325", 1309 | "n03425413", 1310 | "n03443371", 1311 | "n03444034", 1312 | "n03445777", 1313 | "n03445924", 1314 | "n03447447", 1315 | "n03447721", 1316 | "n03450230", 1317 | "n03452741", 1318 | "n03457902", 1319 | "n03459775", 1320 | "n03461385", 1321 | "n03467068", 1322 | "n03476684", 1323 | "n03476991", 1324 | "n03478589", 1325 | "n03481172", 1326 | "n03482405", 1327 | "n03483316", 1328 | "n03485407", 1329 | "n03485794", 1330 | "n03492542", 1331 | "n03494278", 1332 | "n03495258", 1333 | "n03496892", 1334 | "n03498962", 1335 | "n03527444", 1336 | "n03529860", 1337 | "n03530642", 1338 | "n03532672", 1339 | "n03534580", 1340 | "n03535780", 1341 | "n03538406", 1342 | "n03544143", 1343 | "n03584254", 1344 | "n03584829", 1345 | "n03590841", 1346 | "n03594734", 1347 | "n03594945", 1348 | "n03595614", 1349 | "n03598930", 1350 | 
"n03599486", 1351 | "n03602883", 1352 | "n03617480", 1353 | "n03623198", 1354 | "n03627232", 1355 | "n03630383", 1356 | "n03633091", 1357 | "n03637318", 1358 | "n03642806", 1359 | "n03649909", 1360 | "n03657121", 1361 | "n03658185", 1362 | "n03661043", 1363 | "n03662601", 1364 | "n03666591", 1365 | "n03670208", 1366 | "n03673027", 1367 | "n03676483", 1368 | "n03680355", 1369 | "n03690938", 1370 | "n03691459", 1371 | "n03692522", 1372 | "n03697007", 1373 | "n03706229", 1374 | "n03709823", 1375 | "n03710193", 1376 | "n03710637", 1377 | "n03710721", 1378 | "n03717622", 1379 | "n03720891", 1380 | "n03721384", 1381 | "n03724870", 1382 | "n03729826", 1383 | "n03733131", 1384 | "n03733281", 1385 | "n03733805", 1386 | "n03742115", 1387 | "n03743016", 1388 | "n03759954", 1389 | "n03761084", 1390 | "n03763968", 1391 | "n03764736", 1392 | "n03769881", 1393 | "n03770439", 1394 | "n03770679", 1395 | "n03773504", 1396 | "n03775071", 1397 | "n03775546", 1398 | "n03776460", 1399 | "n03777568", 1400 | "n03777754", 1401 | "n03781244", 1402 | "n03782006", 1403 | "n03785016", 1404 | "n03786901", 1405 | "n03787032", 1406 | "n03788195", 1407 | "n03788365", 1408 | "n03791053", 1409 | "n03792782", 1410 | "n03792972", 1411 | "n03793489", 1412 | "n03794056", 1413 | "n03796401", 1414 | "n03803284", 1415 | "n03804744", 1416 | "n03814639", 1417 | "n03814906", 1418 | "n03825788", 1419 | "n03832673", 1420 | "n03837869", 1421 | "n03838899", 1422 | "n03840681", 1423 | "n03841143", 1424 | "n03843555", 1425 | "n03854065", 1426 | "n03857828", 1427 | "n03866082", 1428 | "n03868242", 1429 | "n03868863", 1430 | "n03871628", 1431 | "n03873416", 1432 | "n03874293", 1433 | "n03874599", 1434 | "n03876231", 1435 | "n03877472", 1436 | "n03877845", 1437 | "n03884397", 1438 | "n03887697", 1439 | "n03888257", 1440 | "n03888605", 1441 | "n03891251", 1442 | "n03891332", 1443 | "n03895866", 1444 | "n03899768", 1445 | "n03902125", 1446 | "n03903868", 1447 | "n03908618", 1448 | "n03908714", 1449 | "n03916031", 1450 | 
"n03920288", 1451 | "n03924679", 1452 | "n03929660", 1453 | "n03929855", 1454 | "n03930313", 1455 | "n03930630", 1456 | "n03933933", 1457 | "n03935335", 1458 | "n03937543", 1459 | "n03938244", 1460 | "n03942813", 1461 | "n03944341", 1462 | "n03947888", 1463 | "n03950228", 1464 | "n03954731", 1465 | "n03956157", 1466 | "n03958227", 1467 | "n03961711", 1468 | "n03967562", 1469 | "n03970156", 1470 | "n03976467", 1471 | "n03976657", 1472 | "n03977966", 1473 | "n03980874", 1474 | "n03982430", 1475 | "n03983396", 1476 | "n03991062", 1477 | "n03992509", 1478 | "n03995372", 1479 | "n03998194", 1480 | "n04004767", 1481 | "n04005630", 1482 | "n04008634", 1483 | "n04009552", 1484 | "n04019541", 1485 | "n04023962", 1486 | "n04026417", 1487 | "n04033901", 1488 | "n04033995", 1489 | "n04037443", 1490 | "n04039381", 1491 | "n04040759", 1492 | "n04041544", 1493 | "n04044716", 1494 | "n04049303", 1495 | "n04065272", 1496 | "n04067472", 1497 | "n04069434", 1498 | "n04070727", 1499 | "n04074963", 1500 | "n04081281", 1501 | "n04086273", 1502 | "n04090263", 1503 | "n04099969", 1504 | "n04111531", 1505 | "n04116512", 1506 | "n04118538", 1507 | "n04118776", 1508 | "n04120489", 1509 | "n04125021", 1510 | "n04127249", 1511 | "n04131690", 1512 | "n04133789", 1513 | "n04136333", 1514 | "n04141076", 1515 | "n04141327", 1516 | "n04141975", 1517 | "n04146614", 1518 | "n04147183", 1519 | "n04149813", 1520 | "n04152593", 1521 | "n04153751", 1522 | "n04154565", 1523 | "n04162706", 1524 | "n04179913", 1525 | "n04192698", 1526 | "n04200800", 1527 | "n04201297", 1528 | "n04204238", 1529 | "n04204347", 1530 | "n04208210", 1531 | "n04209133", 1532 | "n04209239", 1533 | "n04228054", 1534 | "n04229816", 1535 | "n04235860", 1536 | "n04238763", 1537 | "n04239074", 1538 | "n04243546", 1539 | "n04251144", 1540 | "n04252077", 1541 | "n04252225", 1542 | "n04254120", 1543 | "n04254680", 1544 | "n04254777", 1545 | "n04258138", 1546 | "n04259630", 1547 | "n04263257", 1548 | "n04264628", 1549 | "n04265275", 1550 | 
"n04266014", 1551 | "n04270147", 1552 | "n04273569", 1553 | "n04275548", 1554 | "n04277352", 1555 | "n04285008", 1556 | "n04286575", 1557 | "n04296562", 1558 | "n04310018", 1559 | "n04311004", 1560 | "n04311174", 1561 | "n04317175", 1562 | "n04325704", 1563 | "n04326547", 1564 | "n04328186", 1565 | "n04330267", 1566 | "n04332243", 1567 | "n04335435", 1568 | "n04336792", 1569 | "n04344873", 1570 | "n04346328", 1571 | "n04347754", 1572 | "n04350905", 1573 | "n04355338", 1574 | "n04355933", 1575 | "n04356056", 1576 | "n04357314", 1577 | "n04366367", 1578 | "n04367480", 1579 | "n04370456", 1580 | "n04371430", 1581 | "n04371774", 1582 | "n04372370", 1583 | "n04376876", 1584 | "n04380533", 1585 | "n04389033", 1586 | "n04392985", 1587 | "n04398044", 1588 | "n04399382", 1589 | "n04404412", 1590 | "n04409515", 1591 | "n04417672", 1592 | "n04418357", 1593 | "n04423845", 1594 | "n04428191", 1595 | "n04429376", 1596 | "n04435653", 1597 | "n04442312", 1598 | "n04443257", 1599 | "n04447861", 1600 | "n04456115", 1601 | "n04458633", 1602 | "n04461696", 1603 | "n04462240", 1604 | "n04465501", 1605 | "n04467665", 1606 | "n04476259", 1607 | "n04479046", 1608 | "n04482393", 1609 | "n04483307", 1610 | "n04485082", 1611 | "n04486054", 1612 | "n04487081", 1613 | "n04487394", 1614 | "n04493381", 1615 | "n04501370", 1616 | "n04505470", 1617 | "n04507155", 1618 | "n04509417", 1619 | "n04515003", 1620 | "n04517823", 1621 | "n04522168", 1622 | "n04523525", 1623 | "n04525038", 1624 | "n04525305", 1625 | "n04532106", 1626 | "n04532670", 1627 | "n04536866", 1628 | "n04540053", 1629 | "n04542943", 1630 | "n04548280", 1631 | "n04548362", 1632 | "n04550184", 1633 | "n04552348", 1634 | "n04553703", 1635 | "n04554684", 1636 | "n04557648", 1637 | "n04560804", 1638 | "n04562935", 1639 | "n04579145", 1640 | "n04579432", 1641 | "n04584207", 1642 | "n04589890", 1643 | "n04590129", 1644 | "n04591157", 1645 | "n04591713", 1646 | "n04592741", 1647 | "n04596742", 1648 | "n04597913", 1649 | "n04599235", 1650 | 
def _load_class_images(imagenet_dir, class_keys, images_per_class):
    """Load and resize the first ``images_per_class`` images of each class.

    Args:
        imagenet_dir: Directory containing one sub-directory per class key.
        class_keys: Class directory names to read, in label order.
        images_per_class: Number of images taken from every class.

    Returns:
        A ``uint8`` array of shape
        ``(images_per_class * len(class_keys), 84, 84, 3)`` holding the
        resized images with the channel axis reversed relative to what
        ``cv2.imread`` returns.

    Raises:
        ValueError: If a class directory holds fewer than
            ``images_per_class`` files.
        OSError: If an image file cannot be decoded.
    """
    images = np.empty(
        (images_per_class * len(class_keys), 84, 84, 3), dtype=np.uint8
    )

    # Wrap the key list (not the enumerate object) so tqdm knows the total.
    for class_idx, class_key in enumerate(tqdm(class_keys)):
        src_class_dir = os.path.join(imagenet_dir, class_key)
        # os.listdir order is arbitrary; sort so the same files are selected
        # on every run and every filesystem.
        image_filenames = sorted(os.listdir(src_class_dir))
        if len(image_filenames) < images_per_class:
            raise ValueError(
                f"Class directory {src_class_dir} contains "
                f"{len(image_filenames)} files, but {images_per_class} "
                "images per class were requested"
            )

        for j in range(images_per_class):
            image_path = os.path.join(src_class_dir, image_filenames[j])
            im = cv2.imread(image_path)
            # cv2.imread signals failure by returning None instead of raising.
            if im is None:
                raise OSError(f"Could not read image: {image_path}")
            images[class_idx * images_per_class + j] = cv2.resize(
                im, (84, 84), interpolation=cv2.INTER_AREA
            )

    # OpenCV decodes to BGR channel order; reverse the last axis so the cache
    # stores RGB, and copy to obtain a contiguous array for pickling.
    return np.copy(images[:, :, :, ::-1])


def create_miniimagenet_outlier(
    imagenet_dir,
    images_per_class=600,
    output_dir="./data/miniimagenet_outlier",
    num_train_classes=175,
    seed=0,
):
    """Build the train/test outlier caches from unused ImageNet classes.

    Classes that appear in ``ALL_KEYS`` but in none of ``MINI_KEYS``,
    ``TIERED_KEYS_TRAIN_VAL`` or ``TIERED_KEYS_TEST`` are split into a train
    and a test set. For each split, the images are resized to 84x84 and
    pickled together with per-class integer labels.

    Args:
        imagenet_dir: Directory containing one sub-directory per class key.
        images_per_class: Number of images taken from every outlier class.
        output_dir: Directory the two pickle files are written to. It is
            created if missing and reused if it already exists.
        num_train_classes: Number of outlier classes assigned to the train
            split; the remaining classes form the test split.
        seed: Seed for the class shuffle, making the split reproducible.

    Raises:
        ValueError: If a class directory holds too few images.
        OSError: If an image cannot be decoded or the output cannot be
            written.
    """
    # Find class keys appearing in neither MiniImageNet nor TieredImageNet.
    mini_tiered_keys = set(MINI_KEYS + TIERED_KEYS_TRAIN_VAL + TIERED_KEYS_TEST)

    # Set iteration order over strings varies between interpreter runs (hash
    # randomization), so sort first and then apply a seeded shuffle: the split
    # stays mixed across classes but is identical on every run.
    outlier_keys = sorted(set(ALL_KEYS) - mini_tiered_keys)
    np.random.RandomState(seed).shuffle(outlier_keys)

    splits = (
        ("train", outlier_keys[:num_train_classes]),
        ("test", outlier_keys[num_train_classes:]),
    )

    # Create the output directory before the expensive image loading so that
    # a re-run (or an unwritable location) does not fail at the very end.
    os.makedirs(output_dir, exist_ok=True)

    # Handle one split at a time so only one image array is held in memory.
    for split_name, split_keys in splits:
        split_data = {
            "image_data": _load_class_images(
                imagenet_dir, split_keys, images_per_class
            ),
            "labels": np.repeat(np.arange(len(split_keys)), images_per_class),
        }
        cache_path = os.path.join(
            output_dir, f"mini-imagenet-outlier-cache-{split_name}.pkl"
        )
        with open(cache_path, "wb") as f:
            pickle.dump(split_data, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--imagenet_dir", type=str, required=True)
    parser.add_argument("--images_per_class", type=int, default=600)
    args = parser.parse_args()

    create_miniimagenet_outlier(args.imagenet_dir, args.images_per_class)