├── .bumpversion.cfg ├── .circleci └── config.yml ├── .gitattributes ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── question.md ├── .gitignore ├── CITATION.cff ├── LICENSE ├── Makefile ├── README.md ├── data ├── CUB │ ├── test.json │ ├── train.json │ └── val.json ├── fungi │ └── .gitignore ├── mini_imagenet │ ├── test.csv │ ├── train.csv │ └── val.csv ├── models │ └── .gitignore └── tiered_imagenet │ ├── test.json │ ├── train.json │ └── val.json ├── dev_requirements.txt ├── easyfsl ├── __init__.py ├── datasets │ ├── __init__.py │ ├── cub.py │ ├── danish_fungi.py │ ├── default_configs.py │ ├── easy_set.py │ ├── features_dataset.py │ ├── few_shot_dataset.py │ ├── mini_imagenet.py │ ├── support_set_folder.py │ ├── tiered_imagenet.py │ └── wrap_few_shot_dataset.py ├── methods │ ├── __init__.py │ ├── bd_cspn.py │ ├── feat.py │ ├── few_shot_classifier.py │ ├── finetune.py │ ├── laplacian_shot.py │ ├── matching_networks.py │ ├── prototypical_networks.py │ ├── pt_map.py │ ├── relation_networks.py │ ├── simple_shot.py │ ├── tim.py │ ├── transductive_finetuning.py │ └── utils.py ├── modules │ ├── __init__.py │ ├── attention_modules.py │ ├── build_from_checkpoint.py │ ├── feat_resnet12.py │ ├── predesigned_modules.py │ └── resnet.py ├── samplers │ ├── __init__.py │ └── task_sampler.py ├── tests │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── easy_set_test.py │ │ ├── features_dataset_test.py │ │ ├── resources │ │ │ ├── balanced_support_set │ │ │ │ ├── 160.Black_throated_Blue_Warbler │ │ │ │ │ ├── Black_throated_Blue_Warbler_0007_2916700989.jpg │ │ │ │ │ └── Black_throated_Blue_Warbler_0008_2966090836.jpg │ │ │ │ ├── 161.Blue_winged_Warbler │ │ │ │ │ ├── Blue_winged_Warbler_0011_2521539056.jpg │ │ │ │ │ └── Blue_winged_Warbler_0028_1988388399.jpg │ │ │ │ └── 162.Canada_Warbler │ │ │ │ │ ├── Canada_Warbler_0001_2495535649.jpg │ │ │ │ │ └── Canada_Warbler_0002_2529931098.jpg │ │ │ ├── empty_support_set │ │ │ │ └── class_with_no_image │ │ │ │ │ └── not_an_image.txt │ │ │ └── unbalanced_support_set │ │ │ │ ├── 160.Black_throated_Blue_Warbler │ │ │ │ ├── Black_throated_Blue_Warbler_0007_2916700989.jpg │ │ │ │ └── Black_throated_Blue_Warbler_0008_2966090836.jpg │ │ │ │ ├── 161.Blue_winged_Warbler │ │ │ │ └── Blue_winged_Warbler_0011_2521539056.jpg │ │ │ │ └── 162.Canada_Warbler │ │ │ │ ├── Canada_Warbler_0001_2495535649.jpg │ │ │ │ ├── Canada_Warbler_0002_2529931098.jpg │ │ │ │ ├── Canada_Warbler_0003_2509806963.jpg │ │ │ │ ├── Canada_Warbler_0004_2530218943.jpg │ │ │ │ └── Canada_Warbler_0005_887179386.jpg │ │ ├── support_set_folder_test.py │ │ └── wrap_few_shot_dataset_test.py │ ├── methods │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── feat_test.py │ │ ├── few_shot_classifier_test.py │ │ ├── finetuning_methods_test.py │ │ ├── matching_networks_test.py │ │ ├── prototypical_networks_test.py │ │ ├── relation_networks_test.py │ │ └── resources │ │ │ ├── Black_footed_Albatross_0001_2950163169.jpg │ │ │ ├── Black_footed_Albatross_0002_2293084168.jpg │ │ │ ├── Black_footed_Albatross_0004_2731401028.jpg │ │ │ ├── Least_Auklet_0001_2947317867.jpg │ │ │ └── Least_Auklet_0004_2685272855.jpg │ ├── modules │ │ ├── __init__.py │ │ ├── predesigned_modules_test.py │ │ └── resnet_test.py │ ├── samplers │ │ ├── __init__.py │ │ └── task_sampler_test.py │ └── utils_test.py └── utils.py ├── notebooks ├── classical_training.ipynb ├── episodic_training.ipynb ├── inference_with_extracted_embeddings.ipynb └── my_first_few_shot_classifier.ipynb ├── pyproject.toml ├── scripts ├── 
__init__.py ├── backbones_configs.json ├── benchmark_methods.py ├── grid_search.json ├── hyperparameter_search.py ├── methods_configs.json ├── predict_embeddings.py └── utils.py └── setup.py /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.5.0 3 | commit = True 4 | tag = False 5 | 6 | [bumpversion:file:setup.py] 7 | 8 | [bumpversion:file:easyfsl/__init__.py] 9 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | 5 | executors: 6 | base-python: 7 | working_directory: ~/project 8 | docker: 9 | - image: cimg/python:3.11.3 10 | 11 | commands: 12 | get_code_and_cached_dependencies: 13 | steps: 14 | - run: 15 | name: Set python 3.11.3 16 | command: | 17 | pyenv global 3.11.3 18 | python -m pip install --upgrade pip setuptools wheel 19 | - checkout 20 | - restore_cache: 21 | key: docker-3.11.3-{{ checksum "setup.py" }}-{{ checksum "dev_requirements.txt" }} 22 | 23 | 24 | jobs: 25 | install: 26 | parameters: 27 | python-version: 28 | type: string 29 | working_directory: ~/project 30 | docker: 31 | - image: cimg/python:<< parameters.python-version >> 32 | steps: 33 | - run: 34 | name: Set python << parameters.python-version >> 35 | command: | 36 | pyenv versions 37 | pyenv global << parameters.python-version >> 38 | python -m pip install --upgrade pip setuptools wheel 39 | - checkout 40 | # Download and cache dependencies 41 | - restore_cache: 42 | keys: 43 | - docker-<< parameters.python-version >>-{{ checksum "setup.py" }}-{{ checksum "dev_requirements.txt" }} 44 | - run: 45 | name: Install dependencies 46 | command: | 47 | python -m venv venv 48 | . venv/bin/activate 49 | pip install --upgrade pip 50 | pip install -r dev_requirements.txt 51 | - save_cache: 52 | paths: 53 | - ./venv 54 | key: docker-<< parameters.python-version >>-{{ checksum "setup.py" }}-{{ checksum "dev_requirements.txt" }} 55 | lint: 56 | executor: base-python 57 | steps: 58 | - get_code_and_cached_dependencies 59 | - run: 60 | name: run pylint 61 | command: | 62 | . venv/bin/activate 63 | pylint --version 64 | make lint 65 | black: 66 | executor: base-python 67 | steps: 68 | - get_code_and_cached_dependencies 69 | - run: 70 | name: run black 71 | command: | 72 | . venv/bin/activate 73 | make black-check 74 | isort: 75 | executor: base-python 76 | steps: 77 | - get_code_and_cached_dependencies 78 | - run: 79 | name: run isort 80 | command: | 81 | . venv/bin/activate 82 | make isort-check 83 | mypy: 84 | executor: base-python 85 | steps: 86 | - get_code_and_cached_dependencies 87 | - run: 88 | name: run mypy 89 | command: | 90 | . venv/bin/activate 91 | make mypy 92 | test: 93 | parameters: 94 | python-version: 95 | type: string 96 | working_directory: ~/project 97 | docker: 98 | - image: cimg/python:<< parameters.python-version >> 99 | steps: 100 | - run: 101 | name: Set python << parameters.python-version >> 102 | command: | 103 | pyenv global << parameters.python-version >> 104 | python -m pip install --upgrade pip setuptools wheel 105 | - checkout 106 | - restore_cache: 107 | key: docker-<< parameters.python-version >>-{{ checksum "setup.py" }}-{{ checksum "dev_requirements.txt" }} 108 | - run: 109 | name: run pytest 110 | command: | 111 | . 
venv/bin/activate 112 | make test 113 | 114 | workflows: 115 | main: 116 | jobs: 117 | - install: 118 | matrix: 119 | parameters: 120 | python-version: ["3.7.10", "3.8.8", "3.9.11", "3.10.10", "3.11.3"] 121 | - lint: 122 | requires: 123 | - install-3.11.3 124 | - black: 125 | requires: 126 | - install-3.11.3 127 | - isort: 128 | requires: 129 | - install-3.11.3 130 | - mypy: 131 | requires: 132 | - install-3.11.3 133 | - test: 134 | matrix: 135 | parameters: 136 | python-version: ["3.7.10", "3.8.8", "3.9.11", "3.10.10", "3.11.3"] 137 | requires: 138 | - install-<< matrix.python-version >> 139 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Code 16 | 2. Stack trace 17 | 18 | **Additional context** 19 | Add any other context about the problem here. 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Do you want something new? 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Any question about how to use EasyFSL? 4 | title: '' 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Problem** 11 | What do you want to do? What is blocking you? 12 | 13 | **Considered solutions** 14 | What have you tried that didn't work? 15 | 16 | **How can we help** 17 | Be as clear and concise as possible so we can help you in the most efficient way.
18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # PyCharm project settings 118 | .idea/ 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # data 135 | data/features 136 | data/**/images 137 | data/omniglot-py 138 | *.tgz 139 | *.tar 140 | 141 | # Tensorboard logs 142 | events.out.tfevents* -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: easyfsl 6 | message: >- 7 | If you use easyfsl in your research, please cite it 8 | using these metadata. 
9 | type: software 10 | authors: 11 | - given-names: Etienne 12 | family-names: Bennequin 13 | email: etienneb@sicara.com 14 | affiliation: Université Paris-Saclay 15 | repository-code: 'https://github.com/sicara/easy-few-shot-learning' 16 | abstract: >- 17 | Ready-to-use code and tutorial notebooks to boost 18 | your way into few-shot image classification. 19 | license: MIT 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Sicara 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Checks 2 | 3 | lint: 4 | pylint easyfsl scripts 5 | 6 | test: 7 | pytest easyfsl 8 | 9 | isort: 10 | isort easyfsl scripts 11 | 12 | isort-check: 13 | isort easyfsl scripts --check 14 | 15 | black: 16 | black easyfsl scripts 17 | 18 | black-check: 19 | black easyfsl scripts --check 20 | 21 | mypy: 22 | mypy easyfsl scripts 23 | 24 | # Install 25 | 26 | dev-install: 27 | pip install -r dev_requirements.txt 28 | 29 | # Download data 30 | 31 | # Google Drive sometimes blocks wget downloads. 
If this recipe doesn't work, download the archive manually from https://docs.google.com/uc?export=download&id=1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx 32 | download-cub: 33 | mkdir -p data/CUB 34 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx" -O data/CUB/images.tgz 35 | rm -rf /tmp/cookies.txt 36 | tar --exclude='._*' -zxvf data/CUB/images.tgz -C data/CUB/ 37 | 38 | # Benchmarks 39 | 40 | BATCH_SIZE=1024 41 | NUM_WORKERS=12 42 | MODEL_CHECKPOINTS_DIR=data/models 43 | DEVICE=cuda 44 | 45 | extract-mini-imagenet-features-with-resnet12: 46 | python -m scripts.predict_embeddings \ 47 | feat_resnet12 \ 48 | ${MODEL_CHECKPOINTS_DIR}/feat_resnet12_mini_imagenet.pth \ 49 | mini_imagenet \ 50 | --device=${DEVICE} \ 51 | --num-workers=${NUM_WORKERS} \ 52 | --batch-size=${BATCH_SIZE} 53 | 54 | extract-features-with-resnet12-trained-on-tiered-imagenet: 55 | for target_dataset in cub tiered_imagenet fungi; do \ 56 | python -m scripts.predict_embeddings \ 57 | feat_resnet12 \ 58 | ${MODEL_CHECKPOINTS_DIR}/feat_resnet12_tiered_imagenet.pth \ 59 | $${target_dataset} \ 60 | --device=${DEVICE} \ 61 | --num-workers=${NUM_WORKERS} \ 62 | --batch-size=${BATCH_SIZE}; \ 63 | done; \ 64 | 65 | extract-all-features-with-resnet12: 66 | make extract-mini-imagenet-features-with-resnet12 ; \ 67 | make extract-features-with-resnet12-trained-on-tiered-imagenet ; \ 68 | 69 | benchmark-mini-imagenet: 70 | for n_shot in 1 5; do \ 71 | for method in bd_cspn prototypical_networks simple_shot tim finetune laplacian_shot pt_map transductive_finetuning; do \ 72 | python -m scripts.benchmark_methods \ 73 | $${method} \ 74 | data/features/mini_imagenet/test/feat_resnet12_mini_imagenet.parquet.gzip \ 75 | --config="default" \ 76 | --n-shot=$${n_shot} \ 77 | --device=${DEVICE} \ 78 | --num-workers=${NUM_WORKERS}; \ 79 | done; \ 80 | python -m scripts.benchmark_methods \ 81 | feat \ 82 | data/features/mini_imagenet/test/feat_resnet12_mini_imagenet.parquet.gzip \ 83 | --config="resnet12_mini_imagenet" \ 84 | --n-shot=$${n_shot} \ 85 | --device=${DEVICE} \ 86 | --num-workers=${NUM_WORKERS}; \ 87 | done 88 | 89 | benchmark-tiered-imagenet: 90 | for n_shot in 1 5; do \ 91 | for method in bd_cspn prototypical_networks simple_shot tim finetune laplacian_shot pt_map transductive_finetuning; do \ 92 | python -m scripts.benchmark_methods \ 93 | $${method} \ 94 | data/features/tiered_imagenet/test/feat_resnet12_tiered_imagenet.parquet.gzip \ 95 | --config="default" \ 96 | --n-shot=$${n_shot} \ 97 | --device=${DEVICE} \ 98 | --num-workers=${NUM_WORKERS}; \ 99 | done; \ 100 | python -m scripts.benchmark_methods \ 101 | feat \ 102 | data/features/tiered_imagenet/test/feat_resnet12_tiered_imagenet.parquet.gzip \ 103 | --config="resnet12_tiered_imagenet" \ 104 | --n-shot=$${n_shot} \ 105 | --device=${DEVICE} \ 106 | --num-workers=${NUM_WORKERS}; \ 107 | done 108 | 109 | # Hyperparameter search 110 | extract-mini-imagenet-val-features-with-resnet12: 111 | python -m scripts.predict_embeddings \ 112 | feat_resnet12 \ 113 | ${MODEL_CHECKPOINTS_DIR}/feat_resnet12_mini_imagenet.pth \ 114 | mini_imagenet \ 115 | --split=val \ 116 | --device=${DEVICE} \ 117 | --num-workers=${NUM_WORKERS} \ 118 | --batch-size=${BATCH_SIZE} 119 | 120 | hyperparameter-search: 
121 | for method in tim finetune pt_map laplacian_shot transductive_finetuning; do \ 122 | python -m scripts.hyperparameter_search \ 123 | $${method} \ 124 | data/features/mini_imagenet/val/feat_resnet12_mini_imagenet.parquet.gzip \ 125 | --n-shot=5 \ 126 | --device=${DEVICE} \ 127 | --num-workers=${NUM_WORKERS}; \ 128 | done; 129 | -------------------------------------------------------------------------------- /data/CUB/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "class_names": [ 3 | "008.Rhinoceros_Auklet", 4 | "009.Brewer_Blackbird", 5 | "015.Lazuli_Bunting", 6 | "020.Yellow_breasted_Chat", 7 | "028.Brown_Creeper", 8 | "030.Fish_Crow", 9 | "035.Purple_Finch", 10 | "039.Least_Flycatcher", 11 | "045.Northern_Fulmar", 12 | "046.Gadwall", 13 | "082.Ringed_Kingfisher", 14 | "085.Horned_Lark", 15 | "094.White_breasted_Nuthatch", 16 | "101.White_Pelican", 17 | "103.Sayornis", 18 | "112.Great_Grey_Shrike", 19 | "118.House_Sparrow", 20 | "122.Harris_Sparrow", 21 | "128.Seaside_Sparrow", 22 | "133.White_throated_Sparrow", 23 | "134.Cape_Glossy_Starling", 24 | "137.Cliff_Swallow", 25 | "147.Least_Tern", 26 | "148.Green_tailed_Towhee", 27 | "163.Cape_May_Warbler", 28 | "168.Kentucky_Warbler", 29 | "169.Magnolia_Warbler", 30 | "170.Mourning_Warbler", 31 | "193.Bewick_Wren", 32 | "194.Cactus_Wren" 33 | ], 34 | "class_roots": [ 35 | "./data/CUB/images/008.Rhinoceros_Auklet", 36 | "./data/CUB/images/009.Brewer_Blackbird", 37 | "./data/CUB/images/015.Lazuli_Bunting", 38 | "./data/CUB/images/020.Yellow_breasted_Chat", 39 | "./data/CUB/images/028.Brown_Creeper", 40 | "./data/CUB/images/030.Fish_Crow", 41 | "./data/CUB/images/035.Purple_Finch", 42 | "./data/CUB/images/039.Least_Flycatcher", 43 | "./data/CUB/images/045.Northern_Fulmar", 44 | "./data/CUB/images/046.Gadwall", 45 | "./data/CUB/images/082.Ringed_Kingfisher", 46 | "./data/CUB/images/085.Horned_Lark", 47 | "./data/CUB/images/094.White_breasted_Nuthatch", 48 | "./data/CUB/images/101.White_Pelican", 49 | "./data/CUB/images/103.Sayornis", 50 | "./data/CUB/images/112.Great_Grey_Shrike", 51 | "./data/CUB/images/118.House_Sparrow", 52 | "./data/CUB/images/122.Harris_Sparrow", 53 | "./data/CUB/images/128.Seaside_Sparrow", 54 | "./data/CUB/images/133.White_throated_Sparrow", 55 | "./data/CUB/images/134.Cape_Glossy_Starling", 56 | "./data/CUB/images/137.Cliff_Swallow", 57 | "./data/CUB/images/147.Least_Tern", 58 | "./data/CUB/images/148.Green_tailed_Towhee", 59 | "./data/CUB/images/163.Cape_May_Warbler", 60 | "./data/CUB/images/168.Kentucky_Warbler", 61 | "./data/CUB/images/169.Magnolia_Warbler", 62 | "./data/CUB/images/170.Mourning_Warbler", 63 | "./data/CUB/images/193.Bewick_Wren", 64 | "./data/CUB/images/194.Cactus_Wren" 65 | ] 66 | } -------------------------------------------------------------------------------- /data/CUB/val.json: -------------------------------------------------------------------------------- 1 | { 2 | "class_names": [ 3 | "044.Frigatebird", 4 | "051.Horned_Grebe", 5 | "052.Pied_billed_Grebe", 6 | "055.Evening_Grosbeak", 7 | "057.Rose_breasted_Grosbeak", 8 | "061.Heermann_Gull", 9 | "064.Ring_billed_Gull", 10 | "066.Western_Gull", 11 | "068.Ruby_throated_Hummingbird", 12 | "072.Pomarine_Jaeger", 13 | "076.Dark_eyed_Junco", 14 | "080.Green_Kingfisher", 15 | "084.Red_legged_Kittiwake", 16 | "096.Hooded_Oriole", 17 | "113.Baird_Sparrow", 18 | "121.Grasshopper_Sparrow", 19 | "127.Savannah_Sparrow", 20 | "136.Barn_Swallow", 21 | "140.Summer_Tanager", 22 | "145.Elegant_Tern", 23 
| "150.Sage_Thrasher", 24 | "155.Warbling_Vireo", 25 | "159.Black_and_white_Warbler", 26 | "162.Canada_Warbler", 27 | "172.Nashville_Warbler", 28 | "176.Prairie_Warbler", 29 | "178.Swainson_Warbler", 30 | "198.Rock_Wren", 31 | "199.Winter_Wren", 32 | "200.Common_Yellowthroat" 33 | ], 34 | "class_roots": [ 35 | "./data/CUB/images/044.Frigatebird", 36 | "./data/CUB/images/051.Horned_Grebe", 37 | "./data/CUB/images/052.Pied_billed_Grebe", 38 | "./data/CUB/images/055.Evening_Grosbeak", 39 | "./data/CUB/images/057.Rose_breasted_Grosbeak", 40 | "./data/CUB/images/061.Heermann_Gull", 41 | "./data/CUB/images/064.Ring_billed_Gull", 42 | "./data/CUB/images/066.Western_Gull", 43 | "./data/CUB/images/068.Ruby_throated_Hummingbird", 44 | "./data/CUB/images/072.Pomarine_Jaeger", 45 | "./data/CUB/images/076.Dark_eyed_Junco", 46 | "./data/CUB/images/080.Green_Kingfisher", 47 | "./data/CUB/images/084.Red_legged_Kittiwake", 48 | "./data/CUB/images/096.Hooded_Oriole", 49 | "./data/CUB/images/113.Baird_Sparrow", 50 | "./data/CUB/images/121.Grasshopper_Sparrow", 51 | "./data/CUB/images/127.Savannah_Sparrow", 52 | "./data/CUB/images/136.Barn_Swallow", 53 | "./data/CUB/images/140.Summer_Tanager", 54 | "./data/CUB/images/145.Elegant_Tern", 55 | "./data/CUB/images/150.Sage_Thrasher", 56 | "./data/CUB/images/155.Warbling_Vireo", 57 | "./data/CUB/images/159.Black_and_white_Warbler", 58 | "./data/CUB/images/162.Canada_Warbler", 59 | "./data/CUB/images/172.Nashville_Warbler", 60 | "./data/CUB/images/176.Prairie_Warbler", 61 | "./data/CUB/images/178.Swainson_Warbler", 62 | "./data/CUB/images/198.Rock_Wren", 63 | "./data/CUB/images/199.Winter_Wren", 64 | "./data/CUB/images/200.Common_Yellowthroat" 65 | ] 66 | } -------------------------------------------------------------------------------- /data/fungi/.gitignore: -------------------------------------------------------------------------------- 1 | DF20_metadata.csv 2 | 3 | images/ 4 | -------------------------------------------------------------------------------- /data/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | **/* -------------------------------------------------------------------------------- /data/tiered_imagenet/val.json: -------------------------------------------------------------------------------- 1 | { 2 | "class_names": [ 3 | "n02099267", 4 | "n02099429", 5 | "n02099601", 6 | "n02099712", 7 | "n02099849", 8 | "n02100236", 9 | "n02100583", 10 | "n02100735", 11 | "n02100877", 12 | "n02101006", 13 | "n02101388", 14 | "n02101556", 15 | "n02102040", 16 | "n02102177", 17 | "n02102318", 18 | "n02102480", 19 | "n02102973", 20 | "n03207941", 21 | "n03259280", 22 | "n03297495", 23 | "n03483316", 24 | "n03584829", 25 | "n03761084", 26 | "n04070727", 27 | "n04111531", 28 | "n04442312", 29 | "n04517823", 30 | "n04542943", 31 | "n04554684", 32 | "n02791124", 33 | "n02804414", 34 | "n02870880", 35 | "n03016953", 36 | "n03018349", 37 | "n03125729", 38 | "n03131574", 39 | "n03179701", 40 | "n03201208", 41 | "n03290653", 42 | "n03337140", 43 | "n03376595", 44 | "n03388549", 45 | "n03742115", 46 | "n03891251", 47 | "n03998194", 48 | "n04099969", 49 | "n04344873", 50 | "n04380533", 51 | "n04429376", 52 | "n04447861", 53 | "n04550184", 54 | "n02666196", 55 | "n02977058", 56 | "n03180011", 57 | "n03485407", 58 | "n03496892", 59 | "n03642806", 60 | "n03832673", 61 | "n04238763", 62 | "n04243546", 63 | "n04428191", 64 | "n04525305", 65 | "n06359193", 66 | "n02966193", 67 | "n02974003", 68 | "n03425413", 69 | 
"n03532672", 70 | "n03874293", 71 | "n03944341", 72 | "n03992509", 73 | "n04019541", 74 | "n04040759", 75 | "n04067472", 76 | "n04371774", 77 | "n04372370", 78 | "n02701002", 79 | "n02704792", 80 | "n02814533", 81 | "n02930766", 82 | "n03100240", 83 | "n03345487", 84 | "n03417042", 85 | "n03444034", 86 | "n03445924", 87 | "n03594945", 88 | "n03670208", 89 | "n03770679", 90 | "n03777568", 91 | "n03785016", 92 | "n03796401", 93 | "n03930630", 94 | "n03977966", 95 | "n04037443", 96 | "n04252225", 97 | "n04285008", 98 | "n04461696", 99 | "n04467665" 100 | ], 101 | "class_roots": [ 102 | "./data/tiered_imagenet/images/n02099267", 103 | "./data/tiered_imagenet/images/n02099429", 104 | "./data/tiered_imagenet/images/n02099601", 105 | "./data/tiered_imagenet/images/n02099712", 106 | "./data/tiered_imagenet/images/n02099849", 107 | "./data/tiered_imagenet/images/n02100236", 108 | "./data/tiered_imagenet/images/n02100583", 109 | "./data/tiered_imagenet/images/n02100735", 110 | "./data/tiered_imagenet/images/n02100877", 111 | "./data/tiered_imagenet/images/n02101006", 112 | "./data/tiered_imagenet/images/n02101388", 113 | "./data/tiered_imagenet/images/n02101556", 114 | "./data/tiered_imagenet/images/n02102040", 115 | "./data/tiered_imagenet/images/n02102177", 116 | "./data/tiered_imagenet/images/n02102318", 117 | "./data/tiered_imagenet/images/n02102480", 118 | "./data/tiered_imagenet/images/n02102973", 119 | "./data/tiered_imagenet/images/n03207941", 120 | "./data/tiered_imagenet/images/n03259280", 121 | "./data/tiered_imagenet/images/n03297495", 122 | "./data/tiered_imagenet/images/n03483316", 123 | "./data/tiered_imagenet/images/n03584829", 124 | "./data/tiered_imagenet/images/n03761084", 125 | "./data/tiered_imagenet/images/n04070727", 126 | "./data/tiered_imagenet/images/n04111531", 127 | "./data/tiered_imagenet/images/n04442312", 128 | "./data/tiered_imagenet/images/n04517823", 129 | "./data/tiered_imagenet/images/n04542943", 130 | "./data/tiered_imagenet/images/n04554684", 131 | "./data/tiered_imagenet/images/n02791124", 132 | "./data/tiered_imagenet/images/n02804414", 133 | "./data/tiered_imagenet/images/n02870880", 134 | "./data/tiered_imagenet/images/n03016953", 135 | "./data/tiered_imagenet/images/n03018349", 136 | "./data/tiered_imagenet/images/n03125729", 137 | "./data/tiered_imagenet/images/n03131574", 138 | "./data/tiered_imagenet/images/n03179701", 139 | "./data/tiered_imagenet/images/n03201208", 140 | "./data/tiered_imagenet/images/n03290653", 141 | "./data/tiered_imagenet/images/n03337140", 142 | "./data/tiered_imagenet/images/n03376595", 143 | "./data/tiered_imagenet/images/n03388549", 144 | "./data/tiered_imagenet/images/n03742115", 145 | "./data/tiered_imagenet/images/n03891251", 146 | "./data/tiered_imagenet/images/n03998194", 147 | "./data/tiered_imagenet/images/n04099969", 148 | "./data/tiered_imagenet/images/n04344873", 149 | "./data/tiered_imagenet/images/n04380533", 150 | "./data/tiered_imagenet/images/n04429376", 151 | "./data/tiered_imagenet/images/n04447861", 152 | "./data/tiered_imagenet/images/n04550184", 153 | "./data/tiered_imagenet/images/n02666196", 154 | "./data/tiered_imagenet/images/n02977058", 155 | "./data/tiered_imagenet/images/n03180011", 156 | "./data/tiered_imagenet/images/n03485407", 157 | "./data/tiered_imagenet/images/n03496892", 158 | "./data/tiered_imagenet/images/n03642806", 159 | "./data/tiered_imagenet/images/n03832673", 160 | "./data/tiered_imagenet/images/n04238763", 161 | "./data/tiered_imagenet/images/n04243546", 162 | 
"./data/tiered_imagenet/images/n04428191", 163 | "./data/tiered_imagenet/images/n04525305", 164 | "./data/tiered_imagenet/images/n06359193", 165 | "./data/tiered_imagenet/images/n02966193", 166 | "./data/tiered_imagenet/images/n02974003", 167 | "./data/tiered_imagenet/images/n03425413", 168 | "./data/tiered_imagenet/images/n03532672", 169 | "./data/tiered_imagenet/images/n03874293", 170 | "./data/tiered_imagenet/images/n03944341", 171 | "./data/tiered_imagenet/images/n03992509", 172 | "./data/tiered_imagenet/images/n04019541", 173 | "./data/tiered_imagenet/images/n04040759", 174 | "./data/tiered_imagenet/images/n04067472", 175 | "./data/tiered_imagenet/images/n04371774", 176 | "./data/tiered_imagenet/images/n04372370", 177 | "./data/tiered_imagenet/images/n02701002", 178 | "./data/tiered_imagenet/images/n02704792", 179 | "./data/tiered_imagenet/images/n02814533", 180 | "./data/tiered_imagenet/images/n02930766", 181 | "./data/tiered_imagenet/images/n03100240", 182 | "./data/tiered_imagenet/images/n03345487", 183 | "./data/tiered_imagenet/images/n03417042", 184 | "./data/tiered_imagenet/images/n03444034", 185 | "./data/tiered_imagenet/images/n03445924", 186 | "./data/tiered_imagenet/images/n03594945", 187 | "./data/tiered_imagenet/images/n03670208", 188 | "./data/tiered_imagenet/images/n03770679", 189 | "./data/tiered_imagenet/images/n03777568", 190 | "./data/tiered_imagenet/images/n03785016", 191 | "./data/tiered_imagenet/images/n03796401", 192 | "./data/tiered_imagenet/images/n03930630", 193 | "./data/tiered_imagenet/images/n03977966", 194 | "./data/tiered_imagenet/images/n04037443", 195 | "./data/tiered_imagenet/images/n04252225", 196 | "./data/tiered_imagenet/images/n04285008", 197 | "./data/tiered_imagenet/images/n04461696", 198 | "./data/tiered_imagenet/images/n04467665" 199 | ] 200 | } -------------------------------------------------------------------------------- /dev_requirements.txt: -------------------------------------------------------------------------------- 1 | black>=20.8b1 2 | isort>=5.10.1 3 | jupyter>=1.0.0 4 | loguru>=0.5.3 5 | matplotlib>=3.3.4 6 | mypy>=0.971 7 | pandas>=1.2.1 8 | pyarrow>=12.0.0 9 | pylint==2.17.7 10 | pytest>=7.3.1 11 | pytest-mock>=3.10.0 12 | tensorboard>=2.8.0 13 | torch>=1.9.0 14 | torchvision>=0.10.0 15 | tqdm>=4.56.0 16 | typer>=0.9.0 17 | -------------------------------------------------------------------------------- /easyfsl/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The easyfsl library. 3 | 4 | This library implements few-shot learning methods, along with data loading tools 5 | for few-shot learning experiences. 
6 | """ 7 | 8 | __version__ = "1.5.0" 9 | -------------------------------------------------------------------------------- /easyfsl/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cub import CUB 2 | from .danish_fungi import DanishFungi 3 | from .easy_set import EasySet 4 | from .features_dataset import FeaturesDataset 5 | from .few_shot_dataset import FewShotDataset 6 | from .mini_imagenet import MiniImageNet 7 | from .support_set_folder import SupportSetFolder 8 | from .tiered_imagenet import TieredImageNet 9 | from .wrap_few_shot_dataset import WrapFewShotDataset 10 | -------------------------------------------------------------------------------- /easyfsl/datasets/cub.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from .easy_set import EasySet 4 | 5 | CUB_SPECS_DIR = Path("data/CUB") 6 | 7 | 8 | class CUB(EasySet): 9 | def __init__(self, split: str, **kwargs): 10 | """ 11 | Build the CUB dataset for the specific split. 12 | Args: 13 | split: one of the available split (typically train, val, test). 14 | Raises: 15 | ValueError: if the specified split cannot be associated with a JSON spec file 16 | from CUB's specs directory 17 | """ 18 | specs_file = CUB_SPECS_DIR / f"{split}.json" 19 | if not specs_file.is_file(): 20 | raise ValueError( 21 | f"Could not find specs file {specs_file.name} in {CUB_SPECS_DIR}" 22 | ) 23 | super().__init__(specs_file=specs_file, **kwargs) 24 | -------------------------------------------------------------------------------- /easyfsl/datasets/danish_fungi.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, List, Optional, Tuple, Union 3 | 4 | import pandas as pd 5 | from pandas import DataFrame 6 | from PIL import Image 7 | from torch import Tensor 8 | 9 | from .default_configs import default_transform 10 | from .few_shot_dataset import FewShotDataset 11 | 12 | WHOLE_DANISH_FUNGI_SPECS_FILE = Path("data/fungi") / "DF20_metadata.csv" 13 | 14 | 15 | class DanishFungi(FewShotDataset): 16 | def __init__( 17 | self, 18 | root: Union[Path, str], 19 | specs_file: Union[Path, str] = WHOLE_DANISH_FUNGI_SPECS_FILE, 20 | image_size: int = 84, 21 | transform: Optional[Callable] = None, 22 | training: bool = False, 23 | image_file_extension: str = ".JPG", 24 | ): 25 | """ 26 | Args: 27 | root: directory where all the images are 28 | specs_file: path to the CSV file 29 | image_size: images returned by the dataset will be square images of the given size 30 | transform: torchvision transforms to be applied to images. If none is provided, 31 | we use some standard transformations including ImageNet normalization. 32 | These default transformations depend on the "training" argument. 33 | training: preprocessing is slightly different for a training set, adding a random 34 | cropping and a random horizontal flip. Only used if transforms = None. 35 | image_file_extension: the metadata csv file and the complete dataset user ".JPG" image file extension, 36 | but the version of the dataset with 300px images uses ".jpg" extensions. If using the small dataset, 37 | set this to ".jpg". 
38 | """ 39 | self.root = Path(root) 40 | self.image_file_extension = image_file_extension 41 | self.data = self.load_specs(Path(specs_file)) 42 | 43 | self.class_names = list(self.data.drop_duplicates("label").scientific_name) 44 | 45 | self.transform = ( 46 | transform if transform else default_transform(image_size, training=training) 47 | ) 48 | 49 | def load_specs(self, specs_file: Path) -> DataFrame: 50 | """ 51 | Load specs from a CSV file. 52 | Args: 53 | specs_file: path to the CSV file 54 | Returns: 55 | curated data contained in the CSV file 56 | """ 57 | data = pd.read_csv(specs_file) 58 | 59 | class_names = list(data.scientific_name.unique()) 60 | label_mapping = {name: class_names.index(name) for name in class_names} 61 | 62 | if self.image_file_extension != ".JPG": 63 | data.image_path = data.image_path.str.replace( 64 | ".JPG", self.image_file_extension 65 | ) 66 | 67 | return data.assign(label=lambda df: df.scientific_name.map(label_mapping)) 68 | 69 | def __getitem__(self, item: int) -> Tuple[Tensor, int]: 70 | """ 71 | Get a data sample from its integer id. 72 | Args: 73 | item: sample's integer id 74 | Returns: 75 | data sample in the form of a tuple (image, label), where label is an integer. 76 | The type of the image object depends on the output type of self.transform. 77 | """ 78 | img = self.transform( 79 | Image.open(self.root / self.data.image_path[item]).convert("RGB") 80 | ) 81 | label = self.data.label[item] 82 | 83 | return img, label 84 | 85 | def __len__(self) -> int: 86 | return len(self.data) 87 | 88 | def get_labels(self) -> List[int]: 89 | return list(self.data.label) 90 | -------------------------------------------------------------------------------- /easyfsl/datasets/default_configs.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from torchvision import transforms 4 | 5 | IMAGENET_NORMALIZATION = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]} 6 | 7 | DEFAULT_IMAGE_FORMATS = {".bmp", ".png", ".jpeg", ".jpg"} 8 | 9 | 10 | def default_transform(image_size: int, training: bool) -> Callable: 11 | """ 12 | Create a composition of torchvision transformations, with some randomization if we are 13 | building a training set. 14 | Args: 15 | image_size: size of dataset images 16 | training: whether this is a training set or not 17 | 18 | Returns: 19 | compositions of torchvision transformations 20 | """ 21 | return ( 22 | transforms.Compose( 23 | [ 24 | transforms.RandomResizedCrop(image_size), 25 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 26 | transforms.RandomHorizontalFlip(), 27 | transforms.ToTensor(), 28 | transforms.Normalize(**IMAGENET_NORMALIZATION), 29 | ] 30 | ) 31 | if training 32 | else transforms.Compose( 33 | [ 34 | transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]), 35 | transforms.CenterCrop(image_size), 36 | transforms.ToTensor(), 37 | transforms.Normalize(**IMAGENET_NORMALIZATION), 38 | ] 39 | ) 40 | ) 41 | 42 | 43 | def default_mini_imagenet_loading_transform( 44 | image_size: int, 45 | ) -> Callable: 46 | """ 47 | Create a composition of torchvision transformations to perform when loading images, but before 48 | serving them (when data is loaded at instantiation, not on the fly). 
49 | Args: 50 | image_size: size of dataset images 51 | 52 | Returns: 53 | a composition of torchvision transformations 54 | """ 55 | return transforms.Compose( 56 | [ 57 | transforms.Resize([int(image_size * 2.0), int(image_size * 2.0)]), 58 | transforms.ToTensor(), 59 | ], 60 | ) 61 | 62 | 63 | def default_mini_imagenet_serving_transform( 64 | image_size: int, training: bool 65 | ) -> Callable: 66 | """ 67 | Create a composition of torchvision transformations to perform when serving images 68 | (when data is loaded at instantiation, not on the fly). 69 | Args: 70 | image_size: size of dataset images 71 | training: whether this is a training set or not 72 | 73 | Returns: 74 | a composition of torchvision transformations 75 | """ 76 | return ( 77 | transforms.Compose( 78 | [ 79 | transforms.RandomResizedCrop(image_size), 80 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 81 | transforms.RandomHorizontalFlip(), 82 | transforms.Normalize(**IMAGENET_NORMALIZATION), 83 | ] 84 | ) 85 | if training 86 | else transforms.Compose( 87 | [ 88 | transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]), 89 | transforms.CenterCrop(image_size), 90 | transforms.Normalize(**IMAGENET_NORMALIZATION), 91 | ] 92 | ) 93 | ) 94 | -------------------------------------------------------------------------------- /easyfsl/datasets/easy_set.py: -------------------------------------------------------------------------------- 1 | import json 2 | import warnings 3 | from pathlib import Path 4 | from typing import Callable, List, Optional, Set, Tuple, Union 5 | 6 | from PIL import Image 7 | 8 | from .default_configs import DEFAULT_IMAGE_FORMATS, default_transform 9 | from .few_shot_dataset import FewShotDataset 10 | 11 | 12 | class EasySet(FewShotDataset): 13 | """ 14 | A ready-to-use dataset. Will work for any dataset where the images are 15 | grouped in directories by class. It expects a JSON file defining the 16 | classes and where to find them. It must have the following shape: 17 | { 18 | "class_names": [ 19 | "class_1", 20 | "class_2" 21 | ], 22 | "class_roots": [ 23 | "path/to/class_1_folder", 24 | "path/to/class_2_folder" 25 | ] 26 | } 27 | """ 28 | 29 | def __init__( 30 | self, 31 | specs_file: Union[Path, str], 32 | image_size: int = 84, 33 | transform: Optional[Callable] = None, 34 | training: bool = False, 35 | supported_formats: Optional[Set[str]] = None, 36 | ): 37 | """ 38 | Args: 39 | specs_file: path to the JSON file 40 | image_size: images returned by the dataset will be square images of the given size 41 | transform: torchvision transforms to be applied to images. If none is provided, 42 | we use some standard transformations including ImageNet normalization. 43 | These default transformations depend on the "training" argument. 44 | training: preprocessing is slightly different for a training set, adding a random 45 | cropping and a random horizontal flip. Only used if transforms = None. 46 | supported_formats: set of allowed file formats. When listing data instances, EasySet 47 | will only consider these files. If none is provided, we use the default set of 48 | image formats.
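Example of use (illustrative: any specs file with the JSON shape documented above works, e.g. the shipped data/CUB/train.json):
    dataset = EasySet(specs_file="data/CUB/train.json", image_size=84, training=True)
    image, label = dataset[0]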
49 | """ 50 | specs = self.load_specs(Path(specs_file)) 51 | 52 | self.images, self.labels = self.list_data_instances( 53 | specs["class_roots"], supported_formats=supported_formats 54 | ) 55 | 56 | self.class_names = specs["class_names"] 57 | 58 | self.transform = ( 59 | transform if transform else default_transform(image_size, training) 60 | ) 61 | 62 | @staticmethod 63 | def load_specs(specs_file: Path) -> dict: 64 | """ 65 | Load specs from a JSON file. 66 | Args: 67 | specs_file: path to the JSON file 68 | 69 | Returns: 70 | dictionary contained in the JSON file 71 | 72 | Raises: 73 | ValueError: if specs_file is not a JSON, or if it is a JSON and the content is not 74 | of the expected shape. 75 | """ 76 | 77 | if specs_file.suffix != ".json": 78 | raise ValueError("EasySet requires specs in a JSON file.") 79 | 80 | with open(specs_file, "r", encoding="utf-8") as file: 81 | specs = json.load(file) 82 | 83 | if "class_names" not in specs.keys() or "class_roots" not in specs.keys(): 84 | raise ValueError( 85 | "EasySet requires specs in a JSON file with the keys class_names and class_roots." 86 | ) 87 | 88 | if len(specs["class_names"]) != len(specs["class_roots"]): 89 | raise ValueError( 90 | "Number of class names does not match the number of class root directories." 91 | ) 92 | 93 | return specs 94 | 95 | @staticmethod 96 | def list_data_instances( 97 | class_roots: List[str], supported_formats: Optional[Set[str]] = None 98 | ) -> Tuple[List[str], List[int]]: 99 | """ 100 | Explore the directories specified in class_roots to find all data instances. 101 | Args: 102 | class_roots: each element is the path to the directory containing the elements 103 | of one class 104 | supported_formats: set of allowed file format. When listing data instances, EasySet 105 | will only consider these files. If none is provided, we use the default set of 106 | image formats. 107 | 108 | Returns: 109 | list of paths to the images, and a list of same length containing the integer label 110 | of each image 111 | """ 112 | if supported_formats is None: 113 | supported_formats = DEFAULT_IMAGE_FORMATS 114 | 115 | images = [] 116 | labels = [] 117 | for class_id, class_root in enumerate(class_roots): 118 | class_images = [ 119 | str(image_path) 120 | for image_path in sorted(Path(class_root).glob("*")) 121 | if image_path.is_file() 122 | & (image_path.suffix.lower() in supported_formats) 123 | ] 124 | 125 | images += class_images 126 | labels += len(class_images) * [class_id] 127 | 128 | if len(images) == 0: 129 | warnings.warn( 130 | UserWarning( 131 | "No images found in the specified directories. The dataset will be empty." 132 | ) 133 | ) 134 | 135 | return images, labels 136 | 137 | def __getitem__(self, item: int): 138 | """ 139 | Get a data sample from its integer id. 140 | Args: 141 | item: sample's integer id 142 | 143 | Returns: 144 | data sample in the form of a tuple (image, label), where label is an integer. 145 | The type of the image object depends of the output type of self.transform. By default 146 | it's a torch.Tensor, however you are free to define any function as self.transform, and 147 | therefore any type for the output image. For instance, if self.transform = lambda x: x, 148 | then the output image will be of type PIL.Image.Image. 149 | """ 150 | # Some images of ILSVRC2015 are grayscale, so we convert everything to RGB for consistence. 151 | # If you want to work on grayscale images, use torch.transforms.Grayscale in your 152 | # transformation pipeline. 
153 | img = self.transform(Image.open(self.images[item]).convert("RGB")) 154 | label = self.labels[item] 155 | 156 | return img, label 157 | 158 | def __len__(self) -> int: 159 | return len(self.labels) 160 | 161 | def get_labels(self) -> List[int]: 162 | return self.labels 163 | 164 | def number_of_classes(self): 165 | return len(self.class_names) 166 | -------------------------------------------------------------------------------- /easyfsl/datasets/features_dataset.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Dict, List, Tuple, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from numpy import ndarray 8 | from torch import Tensor 9 | 10 | from .few_shot_dataset import FewShotDataset 11 | 12 | 13 | class FeaturesDataset(FewShotDataset): 14 | def __init__( 15 | self, 16 | labels: List[int], 17 | embeddings: Tensor, 18 | class_names: List[str], 19 | ): 20 | """ 21 | Initialize a FeaturesDataset from explicit labels, class_names and embeddings. 22 | You can also initialize a FeaturesDataset from: 23 | - a dataframe with from_dataframe(); 24 | - a dictionary with from_dict(). 25 | Args: 26 | labels: list of labels, one for each embedding 27 | embeddings: tensor of embeddings with shape (n_images, *embedding_dimension) 28 | class_names: the name of the class associated with each integer label 29 | (length is the number of unique integers in labels) 30 | """ 31 | self.labels = labels 32 | self.embeddings = embeddings 33 | self.class_names = class_names 34 | 35 | @classmethod 36 | def from_dataframe(cls, source_dataframe: pd.DataFrame): 37 | """ 38 | Instantiate a FeaturesDataset from a dataframe. 39 | embeddings and class_names are directly inferred from the dataframe's content, 40 | while labels are inferred from the class_names. 41 | Args: 42 | source_dataframe: must have the columns embedding and class_name. 43 | Embeddings must be tensors or numpy arrays. 44 | """ 45 | if not {"embedding", "class_name"}.issubset(source_dataframe.columns): 46 | raise ValueError( 47 | f"Source dataframe must have the columns embedding and class_name, " 48 | f"but has columns {source_dataframe.columns}" 49 | ) 50 | 51 | class_names = list(source_dataframe.class_name.unique()) 52 | labels = list( 53 | source_dataframe.class_name.map( 54 | { 55 | class_name: class_id 56 | for class_id, class_name in enumerate(class_names) 57 | } 58 | ) 59 | ) 60 | if len(source_dataframe) == 0: 61 | warnings.warn( 62 | UserWarning( 63 | "Empty source dataframe. Initializing an empty FeaturesDataset." 64 | ) 65 | ) 66 | embeddings = torch.empty(0) 67 | else: 68 | embeddings = torch.from_numpy(np.stack(list(source_dataframe.embedding))) 69 | 70 | return cls(labels, embeddings, class_names) 71 | 72 | @classmethod 73 | def from_dict(cls, source_dict: Dict[str, Union[ndarray, Tensor]]): 74 | """ 75 | Instantiate a FeaturesDataset from a dictionary.
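Example of use (illustrative class names and embedding shapes):
    dataset = FeaturesDataset.from_dict(
        {"dog": torch.randn(10, 512), "cat": torch.randn(8, 512)}
    )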
76 | Args: 77 | source_dict: each key is a class's name and each value is a numpy array or torch tensor 78 | with shape (n_images_for_this_class, *embedding_dimension) 79 | """ 80 | class_names = [] 81 | labels = [] 82 | embeddings_list = [] 83 | for class_id, (class_name, class_embeddings) in enumerate(source_dict.items()): 84 | class_names.append(class_name) 85 | if isinstance(class_embeddings, ndarray): 86 | embeddings_list.append(torch.from_numpy(class_embeddings)) 87 | elif isinstance(class_embeddings, Tensor): 88 | embeddings_list.append(class_embeddings) 89 | else: 90 | raise ValueError( 91 | f"Each value of the source_dict must be a ndarray or torch tensor, " 92 | f"but the value for class {class_name} is {class_embeddings}" 93 | ) 94 | labels += len(class_embeddings) * [class_id] 95 | return cls(labels, torch.cat(embeddings_list), class_names) 96 | 97 | def __getitem__(self, index: int) -> Tuple[Tensor, int]: 98 | return self.embeddings[index], self.labels[index] 99 | 100 | def __len__(self) -> int: 101 | return len(self.labels) 102 | 103 | def get_labels(self) -> List[int]: 104 | return self.labels 105 | 106 | def number_of_classes(self): 107 | return len(self.class_names) 108 | -------------------------------------------------------------------------------- /easyfsl/datasets/few_shot_dataset.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import List, Tuple 3 | 4 | from torch import Tensor 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FewShotDataset(Dataset): 9 | """ 10 | Abstract class for all datasets used in the context of Few-Shot Learning. 11 | The tools we use in few-shot learning, especially TaskSampler, expect an 12 | implementation of FewShotDataset. 13 | Compared to PyTorch's Dataset, FewShotDataset requires a get_labels method. 14 | This exposes the list of all items' labels and therefore makes it possible to sample 15 | items depending on their label. 16 | """ 17 | 18 | @abstractmethod 19 | def __getitem__(self, item: int) -> Tuple[Tensor, int]: 20 | raise NotImplementedError( 21 | "All PyTorch datasets, including few-shot datasets, need a __getitem__ method." 22 | ) 23 | 24 | @abstractmethod 25 | def __len__(self) -> int: 26 | raise NotImplementedError( 27 | "All PyTorch datasets, including few-shot datasets, need a __len__ method." 28 | ) 29 | 30 | @abstractmethod 31 | def get_labels(self) -> List[int]: 32 | raise NotImplementedError( 33 | "Implementations of FewShotDataset need a get_labels method."
34 | ) 35 | -------------------------------------------------------------------------------- /easyfsl/datasets/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, List, Optional, Union 3 | 4 | import pandas as pd 5 | import torch 6 | from pandas import DataFrame 7 | from PIL import Image 8 | from torch import Tensor 9 | from tqdm import tqdm 10 | 11 | from .default_configs import ( 12 | default_mini_imagenet_loading_transform, 13 | default_mini_imagenet_serving_transform, 14 | default_transform, 15 | ) 16 | from .few_shot_dataset import FewShotDataset 17 | 18 | MINI_IMAGENET_SPECS_DIR = Path("data/mini_imagenet") 19 | 20 | 21 | class MiniImageNet(FewShotDataset): 22 | def __init__( 23 | self, 24 | root: Union[Path, str], 25 | split: Optional[str] = None, 26 | specs_file: Optional[Union[Path, str]] = None, 27 | image_size: int = 84, 28 | load_on_ram: bool = False, 29 | loading_transform: Optional[Callable] = None, 30 | transform: Optional[Callable] = None, 31 | training: bool = False, 32 | ): 33 | """ 34 | Build the miniImageNet dataset from a specific specs file. Images can be preloaded 35 | in RAM at construction time (load_on_ram=True); by default they are loaded on the fly. 36 | Args: 37 | root: directory where all the images are 38 | split: if specs_file is not specified, will look for the CSV file corresponding 39 | to this split in miniImageNet's specs directory. If both are unspecified, 40 | raise an error. 41 | specs_file: path to the specs CSV file. Mutually exclusive with split but one of them 42 | must be specified. 43 | image_size: images returned by the dataset will be square images of the given size 44 | load_on_ram: if True, images are processed through loading_transform then stored in RAM. 45 | If False, images are loaded on the fly. Preloading requires available space in RAM 46 | and a few minutes at construction time, but will save a lot of time during training. 47 | loading_transform: only used if load_on_ram is True. Torchvision transforms to be 48 | applied to images during preloading. Must contain ToTensor. If none is provided, we 49 | use standard transformations (Resize to twice the target size, 50 | then ToTensor) 51 | transform: torchvision transforms to be applied to images. If none is provided, 52 | we use some standard transformations including ImageNet normalization. 53 | These default transformations depend on the "training" argument. 54 | If load_on_ram is False, default transformations include default loading 55 | transformations. 56 | training: preprocessing is slightly different for a training set, adding a random 57 | cropping and a random horizontal flip. Only used if transforms = None.
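Example of use (illustrative: assumes the images were downloaded to data/mini_imagenet/images):
    train_set = MiniImageNet(root="data/mini_imagenet/images", split="train", training=True)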
58 | """ 59 | self.root = Path(root) 60 | self.data_df = self.load_specs(split, specs_file) 61 | self.load_on_ram = load_on_ram 62 | 63 | if self.load_on_ram: 64 | # Transformation to do before loading the dataset in RAM 65 | self.loading_transform = ( 66 | loading_transform 67 | if loading_transform 68 | else default_mini_imagenet_loading_transform(image_size) 69 | ) 70 | 71 | # Transformation to operate on the fly 72 | self.transform = ( 73 | transform 74 | if transform 75 | else default_mini_imagenet_serving_transform(image_size, training) 76 | ) 77 | 78 | self.images = torch.stack( 79 | [ 80 | self.load_image_as_tensor(image_path) 81 | for image_path in tqdm( 82 | self.data_df.image_path, desc="Loading images" 83 | ) 84 | ] 85 | ) 86 | 87 | else: 88 | self.loading_transform = lambda x: x 89 | self.transform = ( 90 | transform if transform else default_transform(image_size, training) 91 | ) 92 | self.images = self.data_df.image_path.tolist() 93 | 94 | self.class_names = self.data_df.class_name.unique() 95 | self.class_to_label = {v: k for k, v in enumerate(self.class_names)} 96 | self.labels = self.get_labels() 97 | 98 | def __len__(self): 99 | return len(self.data_df) 100 | 101 | def __getitem__(self, item): 102 | img = ( 103 | self.transform(self.images[item]) 104 | if self.load_on_ram 105 | else self.transform( 106 | Image.open(self.data_df.image_path[item]).convert("RGB") 107 | ) 108 | ) 109 | 110 | return img, self.labels[item] 111 | 112 | def load_image_as_tensor(self, filename) -> Tensor: 113 | return self.loading_transform(Image.open(filename).convert("RGB")) 114 | 115 | def load_specs( 116 | self, 117 | split: Optional[str] = None, 118 | specs_file: Optional[Union[Path, str]] = None, 119 | ) -> DataFrame: 120 | """ 121 | Load the classes and paths of images from the CSV specs file. 122 | Args: 123 | split: if specs_file is not specified, will look for the CSV file corresponding 124 | to this split in miniImageNet's specs directory. If both are unspecified, 125 | raise an error. 126 | specs_file: path to the specs CSV file. Mutually exclusive with split but one of them 127 | must be specified. 128 | 129 | Returns: 130 | dataframe with 3 columns class_name, image_name and image_path 131 | 132 | Raises: 133 | ValueError: you need to specify a split or a specs_file, but not both. 134 | """ 135 | if (specs_file is None) & (split is None): 136 | raise ValueError("Please specify either a split or an explicit specs_file.") 137 | if (specs_file is not None) & (split is not None): 138 | raise ValueError("Conflict: you can't specify a split AND a specs file.") 139 | 140 | specs_file = ( 141 | specs_file if specs_file else MINI_IMAGENET_SPECS_DIR / f"{split}.csv" 142 | ) 143 | 144 | return pd.read_csv(specs_file).assign( 145 | image_path=lambda df: df.apply( 146 | lambda row: self.root / row["class_name"] / row["image_name"], axis=1 147 | ) 148 | ) 149 | 150 | def get_labels(self) -> List[int]: 151 | return list(self.data_df.class_name.map(self.class_to_label)) 152 | -------------------------------------------------------------------------------- /easyfsl/datasets/support_set_folder.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, Optional, Union 3 | 4 | import torch 5 | from torch import Tensor 6 | from torchvision.datasets import ImageFolder 7 | 8 | from .default_configs import default_transform 9 | 10 | NOT_A_TENSOR_ERROR_MESSAGE = ( 11 | "SupportSetFolder handles instances as tensors. 
" 12 | "Please ensure that the specific transform outputs a tensor." 13 | ) 14 | 15 | 16 | class SupportSetFolder(ImageFolder): 17 | """ 18 | Create a support set from images located in a specified folder 19 | with the following file structure: 20 | 21 | root: 22 | |_ subfolder_1: 23 | |_ image_1 24 | |_ … 25 | |_ image_n 26 | |_ subfolder_2: 27 | |_ image_1 28 | |_ … 29 | |_ image_n 30 | 31 | Following the ImageFolder logic, images of a same subfolder will share the same label, 32 | and the classes will be named after the subfolders. 33 | 34 | Example of use: 35 | 36 | predict_transformation = transforms.Compose([ 37 | transforms.Resize((224, 224)), 38 | transforms.ToTensor() 39 | ]) 40 | support_set = SupportSetFolder( 41 | root=path_to_support_images, 42 | transform=predict_transformation, 43 | device="cuda" 44 | ) 45 | with torch.no_grad(): 46 | few_shot_classifier.eval() 47 | few_shot_classifier.process_support_set(support_set.get_images(), support_set.get_labels()) 48 | class_names = support_set.classes 49 | predicted_labels = few_shot_classifier(query_images.to(device)).argmax(dim=1) 50 | predicted_classes = [ support_set.classes[label] for label in predicted_labels] 51 | """ 52 | 53 | def __init__( 54 | self, 55 | root: Union[str, Path], 56 | device="cpu", 57 | image_size: int = 84, 58 | transform: Optional[Callable] = None, 59 | **kwargs, 60 | ): 61 | """ 62 | Args: 63 | device: 64 | **kwargs: kwargs for the parent ImageFolder class 65 | """ 66 | transform = ( 67 | transform if transform else default_transform(image_size, training=False) 68 | ) 69 | 70 | super().__init__(str(root), transform=transform, **kwargs) 71 | 72 | self.device = device 73 | try: 74 | self.images = torch.stack([instance[0] for instance in self]).to( 75 | self.device 76 | ) 77 | except TypeError as type_error: 78 | raise TypeError(NOT_A_TENSOR_ERROR_MESSAGE) from type_error 79 | 80 | def get_images(self) -> Tensor: 81 | """ 82 | Returns: 83 | support set images as a (n_images, n_channels, width, height) tensor 84 | on the selected device 85 | """ 86 | return self.images 87 | 88 | def get_labels(self) -> Tensor: 89 | """ 90 | Returns: 91 | support set labels as a tensor on the selected device 92 | """ 93 | return torch.tensor(self.targets).to(self.device) 94 | -------------------------------------------------------------------------------- /easyfsl/datasets/tiered_imagenet.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from .easy_set import EasySet 4 | 5 | TIERED_IMAGENET_SPECS_DIR = Path("data/tiered_imagenet") 6 | 7 | 8 | class TieredImageNet(EasySet): 9 | def __init__(self, split: str, **kwargs): 10 | """ 11 | Build the tieredImageNet dataset for the specific split. 12 | Args: 13 | split: one of the available split (typically train, val, test). 
14 | Raises: 15 | ValueError: if the specified split cannot be associated with a JSON spec file 16 | from tieredImageNet's specs directory 17 | """ 18 | specs_file = TIERED_IMAGENET_SPECS_DIR / f"{split}.json" 19 | if not specs_file.is_file(): 20 | raise ValueError( 21 | f"Could not find specs file {specs_file.name} in {TIERED_IMAGENET_SPECS_DIR}" 22 | ) 23 | super().__init__(specs_file=specs_file, **kwargs) 24 | -------------------------------------------------------------------------------- /easyfsl/datasets/wrap_few_shot_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from torch import Tensor 4 | from torch.utils.data import Dataset 5 | from tqdm import tqdm 6 | 7 | from .few_shot_dataset import FewShotDataset 8 | 9 | 10 | class WrapFewShotDataset(FewShotDataset): 11 | """ 12 | Wrap a dataset in a FewShotDataset. This is useful if you have your own dataset 13 | and want to use it with the tools provided by EasyFSL such as TaskSampler. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | dataset: Dataset, 19 | image_position_in_get_item_output: int = 0, 20 | label_position_in_get_item_output: int = 1, 21 | ): 22 | """ 23 | Wrap a dataset in a FewShotDataset. 24 | Args: 25 | dataset: dataset to wrap 26 | image_position_in_get_item_output: position of the image in the tuple returned 27 | by dataset.__getitem__(). Default: 0 28 | label_position_in_get_item_output: position of the label in the tuple returned 29 | by dataset.__getitem__(). Default: 1 30 | """ 31 | if image_position_in_get_item_output == label_position_in_get_item_output: 32 | raise ValueError( 33 | "image_position_in_get_item_output and label_position_in_get_item_output must be different." 34 | ) 35 | if ( 36 | image_position_in_get_item_output < 0 37 | or label_position_in_get_item_output < 0 38 | ): 39 | raise ValueError( 40 | "image_position_in_get_item_output and label_position_in_get_item_output must be non-negative." 41 | ) 42 | item_length = len(dataset[0]) 43 | if ( 44 | image_position_in_get_item_output >= item_length 45 | or label_position_in_get_item_output >= item_length 46 | ): 47 | raise ValueError("Specified positions in output are out of range.") 48 | 49 | self.source_dataset = dataset 50 | self.labels = [ 51 | source_dataset_instance[label_position_in_get_item_output] 52 | for source_dataset_instance in tqdm( 53 | dataset, desc="Scanning dataset's labels..."
54 | ) 55 | ] 56 | self.image_position_in_get_item_output = image_position_in_get_item_output 57 | self.label_position_in_get_item_output = label_position_in_get_item_output 58 | 59 | def __getitem__(self, item: int) -> Tuple[Tensor, int]: 60 | return ( 61 | self.source_dataset[item][self.image_position_in_get_item_output], 62 | self.source_dataset[item][self.label_position_in_get_item_output], 63 | ) 64 | 65 | def __len__(self) -> int: 66 | return len(self.labels) 67 | 68 | def get_labels(self) -> List[int]: 69 | return self.labels 70 | -------------------------------------------------------------------------------- /easyfsl/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .bd_cspn import BDCSPN 2 | from .feat import FEAT 3 | from .few_shot_classifier import FewShotClassifier 4 | from .finetune import Finetune 5 | from .laplacian_shot import LaplacianShot 6 | from .matching_networks import MatchingNetworks 7 | from .prototypical_networks import PrototypicalNetworks 8 | from .pt_map import PTMAP 9 | from .relation_networks import RelationNetworks 10 | from .simple_shot import SimpleShot 11 | from .tim import TIM 12 | from .transductive_finetuning import TransductiveFinetuning 13 | -------------------------------------------------------------------------------- /easyfsl/methods/bd_cspn.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | 3 | from .few_shot_classifier import FewShotClassifier 4 | 5 | 6 | class BDCSPN(FewShotClassifier): 7 | """ 8 | Jinlu Liu, Liang Song, Yongqiang Qin 9 | "Prototype Rectification for Few-Shot Learning" (ECCV 2020) 10 | https://arxiv.org/abs/1911.10713 11 | 12 | Rectify prototypes with label propagation and feature shifting. 13 | Classify queries based on their cosine distance to prototypes. 14 | This is a transductive method. 15 | """ 16 | 17 | def rectify_prototypes( 18 | self, query_features: Tensor 19 | ): # pylint: disable=not-callable 20 | """ 21 | Updates prototypes with label propagation and feature shifting. 
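Concretely, query features are first shifted by the difference between the mean support feature and the mean query feature; new prototypes are then computed as averages of support features and pseudo-labeled query features, weighted by the exponential of their cosine similarity to the current prototypes.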
22 | Args: 23 | query_features: query features of shape (n_query, feature_dimension) 24 | """ 25 | n_classes = self.support_labels.unique().size(0) 26 | one_hot_support_labels = nn.functional.one_hot(self.support_labels, n_classes) 27 | 28 | average_support_query_shift = self.support_features.mean( 29 | 0, keepdim=True 30 | ) - query_features.mean(0, keepdim=True) 31 | query_features = query_features + average_support_query_shift 32 | 33 | support_logits = self.cosine_distance_to_prototypes(self.support_features).exp() 34 | query_logits = self.cosine_distance_to_prototypes(query_features).exp() 35 | 36 | one_hot_query_prediction = nn.functional.one_hot( 37 | query_logits.argmax(-1), n_classes 38 | ) 39 | 40 | normalization_vector = ( 41 | (one_hot_support_labels * support_logits).sum(0) 42 | + (one_hot_query_prediction * query_logits).sum(0) 43 | ).unsqueeze( 44 | 0 45 | ) # [1, n_classes] 46 | support_reweighting = ( 47 | one_hot_support_labels * support_logits 48 | ) / normalization_vector # [n_support, n_classes] 49 | query_reweighting = ( 50 | one_hot_query_prediction * query_logits 51 | ) / normalization_vector # [n_query, n_classes] 52 | 53 | self.prototypes = (support_reweighting * one_hot_support_labels).t().matmul( 54 | self.support_features 55 | ) + (query_reweighting * one_hot_query_prediction).t().matmul(query_features) 56 | 57 | def forward( 58 | self, 59 | query_images: Tensor, 60 | ) -> Tensor: 61 | """ 62 | Overrides forward method of FewShotClassifier. 63 | Update prototypes using query images, then classify query images based 64 | on their cosine distance to updated prototypes. 65 | """ 66 | query_features = self.compute_features(query_images) 67 | 68 | self.rectify_prototypes( 69 | query_features=query_features, 70 | ) 71 | return self.softmax_if_specified( 72 | self.cosine_distance_to_prototypes(query_features) 73 | ) 74 | 75 | @staticmethod 76 | def is_transductive() -> bool: 77 | return True 78 | -------------------------------------------------------------------------------- /easyfsl/methods/feat.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | 7 | from easyfsl.modules import MultiHeadAttention 8 | from easyfsl.modules.feat_resnet12 import feat_resnet12 9 | 10 | from .prototypical_networks import PrototypicalNetworks 11 | from .utils import strip_prefix 12 | 13 | 14 | class FEAT(PrototypicalNetworks): 15 | """ 16 | Han-Jia Ye, Hexiang Hu, De-Chuan Zhan, Fei Sha. 17 | "Few-Shot Learning via Embedding Adaptation With Set-to-Set Functions" (CVPR 2020) 18 | https://openaccess.thecvf.com/content_CVPR_2020/html/Ye_Few-Shot_Learning_via_Embedding_Adaptation_With_Set-to-Set_Functions_CVPR_2020_paper.html 19 | 20 | This method uses an episodically trained attention module to improve the prototypes. 21 | Queries are then classified based on their euclidean distance to the prototypes, 22 | as in Prototypical Networks. 23 | This is an inductive method. 24 | 25 | The attention module must follow specific constraints described in the docstring of FEAT.__init__(). 26 | We provide a default attention module following the one used in the original implementation. 27 | FEAT can be initialized in the default configuration from the authors by calling FEAT.from_resnet12_checkpoint(). 28 | """ 29 | 30 | def __init__(self, *args, attention_module: nn.Module, **kwargs): 31 | """ 32 | FEAT needs an additional attention module.
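A minimal construction sketch (illustrative, not from the original docstring; it assumes a `backbone` whose feature vectors are 640-dimensional, matching the MultiHeadAttention dimensions below):

            attention = MultiHeadAttention(1, 640, 640, 640)
            model = FEAT(backbone, attention_module=attention)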
33 | Args: 34 | *args: 35 | attention_module: the forward method must accept 3 Tensor arguments of shape 36 | (1, num_classes, feature_dimension) and return a pair of Tensors, with the first 37 | one of shape (1, num_classes, feature_dimension). 38 | This follows the original implementation of https://github.com/Sha-Lab/FEAT 39 | **kwargs: 40 | """ 41 | super().__init__(*args, **kwargs) 42 | self.attention_module = attention_module 43 | 44 | def process_support_set( 45 | self, 46 | support_images: Tensor, 47 | support_labels: Tensor, 48 | ): 49 | """ 50 | Extract prototypes from support set and rectify them with the attention module. 51 | Args: 52 | support_images: support images of shape (n_support, **image_shape) 53 | support_labels: support labels of shape (n_support,) 54 | """ 55 | super().process_support_set(support_images, support_labels) 56 | self.prototypes = self.attention_module( 57 | self.prototypes.unsqueeze(0), 58 | self.prototypes.unsqueeze(0), 59 | self.prototypes.unsqueeze(0), 60 | )[0][0] 61 | 62 | @classmethod 63 | def from_resnet12_checkpoint( 64 | cls, 65 | checkpoint_path: Union[Path, str], 66 | device: str = "cpu", 67 | feature_dimension: int = 640, 68 | use_backbone: bool = True, 69 | **kwargs, 70 | ): 71 | """ 72 | Load a FEAT model from a checkpoint of a resnet12 model as provided by the authors. 73 | We initialize the default ResNet12 backbone and attention module and load the weights. 74 | We solve some compatibility issues in the names of the parameters and ensure there 75 | are no missing keys. 76 | 77 | Compatible weights can be found here (availability verified 30/05/2023): 78 | - miniImageNet: https://drive.google.com/file/d/1ixqw1l9XVxl3lh1m5VXkctw6JssahGbQ/view 79 | - tieredImageNet: https://drive.google.com/file/d/1M93jdOjAn8IihICPKJg8Mb4B-eYDSZfE/view 80 | Args: 81 | checkpoint_path: path to the checkpoint 82 | device: device to load the model on 83 | feature_dimension: dimension of the features extracted by the backbone. 84 | Should be 640 with the default Resnet12 backbone.
85 | use_backbone: if False, we initialize the backbone to nn.Identity() (useful for 86 | working on pre-extracted features) 87 | Returns: 88 | a FEAT model with weights loaded from the checkpoint 89 | Raises: 90 | ValueError: if the checkpoint does not contain all the expected keys 91 | of the backbone or the attention module 92 | """ 93 | state_dict = torch.load(str(checkpoint_path), map_location=device)["params"] 94 | 95 | if use_backbone: 96 | backbone = feat_resnet12().to(device) 97 | backbone_missing_keys, _ = backbone.load_state_dict( 98 | strip_prefix(state_dict, "encoder."), strict=False 99 | ) 100 | if len(backbone_missing_keys) > 0: 101 | raise ValueError(f"Missing keys for backbone: {backbone_missing_keys}") 102 | else: 103 | backbone = nn.Identity() 104 | 105 | attention_module = MultiHeadAttention( 106 | 1, 107 | feature_dimension, 108 | feature_dimension, 109 | feature_dimension, 110 | ).to(device) 111 | attention_missing_keys, _ = attention_module.load_state_dict( 112 | strip_prefix(state_dict, "slf_attn."), strict=False 113 | ) 114 | if len(attention_missing_keys) > 0: 115 | raise ValueError( 116 | f"Missing keys for attention module: {attention_missing_keys}" 117 | ) 118 | 119 | return cls(backbone, attention_module=attention_module, **kwargs).to(device) 120 | -------------------------------------------------------------------------------- /easyfsl/methods/few_shot_classifier.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | 7 | from easyfsl.methods.utils import compute_prototypes 8 | 9 | 10 | class FewShotClassifier(nn.Module): 11 | """ 12 | Abstract class providing methods usable by all few-shot classification algorithms 13 | """ 14 | 15 | def __init__( 16 | self, 17 | backbone: Optional[nn.Module] = None, 18 | use_softmax: bool = False, 19 | feature_centering: Optional[Tensor] = None, 20 | feature_normalization: Optional[float] = None, 21 | ): 22 | """ 23 | Initialize the Few-Shot Classifier 24 | Args: 25 | backbone: the feature extractor used by the method. Must output a tensor of the 26 | appropriate shape (depending on the method). 27 | If None is passed, the backbone will be initialized as nn.Identity(). 28 | use_softmax: whether to return predictions as soft probabilities 29 | feature_centering: a features vector on which to center all computed features. 30 | If None is passed, no centering is performed. 31 | feature_normalization: a value by which to normalize all computed features after centering. 32 | It is used as the p argument in torch.nn.functional.normalize(). 33 | If None is passed, no normalization is performed. 34 | """ 35 | super().__init__() 36 | 37 | self.backbone = backbone if backbone is not None else nn.Identity() 38 | self.use_softmax = use_softmax 39 | 40 | self.prototypes = torch.tensor(()) 41 | self.support_features = torch.tensor(()) 42 | self.support_labels = torch.tensor(()) 43 | 44 | self.feature_centering = ( 45 | feature_centering if feature_centering is not None else torch.tensor(0) 46 | ) 47 | self.feature_normalization = feature_normalization 48 | 49 | @abstractmethod 50 | def forward( 51 | self, 52 | query_images: Tensor, 53 | ) -> Tensor: 54 | """ 55 | Predict classification labels. 
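This is an abstract method: concrete few-shot classifiers implement it, and it expects process_support_set() to have been called beforehand.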
56 | Args: 57 | query_images: images of the query set of shape (n_query, **image_shape) 58 | Returns: 59 | a prediction of classification scores for query images of shape (n_query, n_classes) 60 | """ 61 | raise NotImplementedError( 62 | "All few-shot algorithms must implement a forward method." 63 | ) 64 | 65 | def process_support_set( 66 | self, 67 | support_images: Tensor, 68 | support_labels: Tensor, 69 | ): 70 | """ 71 | Harness information from the support set, so that query labels can later be predicted using a forward call. 72 | The default behaviour shared by most few-shot classifiers is to compute prototypes and store the support set. 73 | Args: 74 | support_images: images of the support set of shape (n_support, **image_shape) 75 | support_labels: labels of support set images of shape (n_support, ) 76 | """ 77 | self.compute_prototypes_and_store_support_set(support_images, support_labels) 78 | 79 | @staticmethod 80 | def is_transductive() -> bool: 81 | raise NotImplementedError( 82 | "All few-shot algorithms must implement a is_transductive method." 83 | ) 84 | 85 | def compute_features(self, images: Tensor) -> Tensor: 86 | """ 87 | Compute features from images and perform centering and normalization. 88 | Args: 89 | images: images of shape (n_images, **image_shape) 90 | Returns: 91 | features of shape (n_images, feature_dimension) 92 | """ 93 | original_features = self.backbone(images) 94 | centered_features = original_features - self.feature_centering 95 | if self.feature_normalization is not None: 96 | return nn.functional.normalize( 97 | centered_features, p=self.feature_normalization, dim=1 98 | ) 99 | return centered_features 100 | 101 | def softmax_if_specified(self, output: Tensor, temperature: float = 1.0) -> Tensor: 102 | """ 103 | If the option is chosen when the classifier is initialized, we perform a softmax on the 104 | output in order to return soft probabilities. 105 | Args: 106 | output: output of the forward method of shape (n_query, n_classes) 107 | temperature: temperature of the softmax 108 | Returns: 109 | output as it was, or output as soft probabilities, of shape (n_query, n_classes) 110 | """ 111 | return (temperature * output).softmax(-1) if self.use_softmax else output 112 | 113 | def l2_distance_to_prototypes(self, samples: Tensor) -> Tensor: 114 | """ 115 | Compute prediction logits from their euclidean distance to support set prototypes. 116 | Args: 117 | samples: features of the items to classify of shape (n_samples, feature_dimension) 118 | Returns: 119 | prediction logits of shape (n_samples, n_classes) 120 | """ 121 | return -torch.cdist(samples, self.prototypes) 122 | 123 | def cosine_distance_to_prototypes(self, samples) -> Tensor: 124 | """ 125 | Compute prediction logits from their cosine distance to support set prototypes. 126 | Args: 127 | samples: features of the items to classify of shape (n_samples, feature_dimension) 128 | Returns: 129 | prediction logits of shape (n_samples, n_classes) 130 | """ 131 | return ( 132 | nn.functional.normalize(samples, dim=1) 133 | @ nn.functional.normalize(self.prototypes, dim=1).T 134 | ) 135 | 136 | def compute_prototypes_and_store_support_set( 137 | self, 138 | support_images: Tensor, 139 | support_labels: Tensor, 140 | ): 141 | """ 142 | Extract support features, compute prototypes, and store support labels, features, and prototypes. 
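Prototype k is the mean of the support feature vectors whose label is k (see compute_prototypes in easyfsl.methods.utils).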
143 | Args: 144 | support_images: images of the support set of shape (n_support, **image_shape) 145 | support_labels: labels of support set images of shape (n_support, ) 146 | """ 147 | self.support_labels = support_labels 148 | self.support_features = self.compute_features(support_images) 149 | self._raise_error_if_features_are_multi_dimensional(self.support_features) 150 | self.prototypes = compute_prototypes(self.support_features, support_labels) 151 | 152 | @staticmethod 153 | def _raise_error_if_features_are_multi_dimensional(features: Tensor): 154 | if len(features.shape) != 2: 155 | raise ValueError( 156 | "Illegal backbone or feature shape. " 157 | "Expected output for an image is a 1-dim tensor." 158 | ) 159 | -------------------------------------------------------------------------------- /easyfsl/methods/finetune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from .few_shot_classifier import FewShotClassifier 5 | 6 | 7 | class Finetune(FewShotClassifier): 8 | """ 9 | Wei-Yu Chen, Yen-Cheng Liu, Zsolt Kira, Yu-Chiang Frank Wang, Jia-Bin Huang 10 | A Closer Look at Few-shot Classification (ICLR 2019) 11 | https://arxiv.org/abs/1904.04232 12 | 13 | Fine-tune prototypes based on classification error on support images. 14 | Classify queries based on their cosine distances to updated prototypes. 15 | As is, it is incompatible with episodic training because we freeze the backbone to perform 16 | fine-tuning. 17 | 18 | This is an inductive method. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | *args, 24 | fine_tuning_steps: int = 200, 25 | fine_tuning_lr: float = 1e-4, 26 | temperature: float = 1.0, 27 | **kwargs, 28 | ): 29 | """ 30 | Args: 31 | fine_tuning_steps: number of fine-tuning steps 32 | fine_tuning_lr: learning rate for fine-tuning 33 | temperature: multiplicative factor applied to the logits before computing 34 | softmax or cross-entropy. Higher values yield sharper predictions, since the logits are multiplied by the temperature. 35 | """ 36 | super().__init__(*args, **kwargs) 37 | 38 | # Since we fine-tune the prototypes we need to make them leaf variables 39 | # i.e. we need to freeze the backbone. 40 | self.backbone.requires_grad_(False) 41 | 42 | self.fine_tuning_steps = fine_tuning_steps 43 | self.fine_tuning_lr = fine_tuning_lr 44 | self.temperature = temperature 45 | 46 | def forward( 47 | self, 48 | query_images: Tensor, 49 | ) -> Tensor: 50 | """ 51 | Overrides forward method of FewShotClassifier. 52 | Fine-tune prototypes based on support classification error. 53 | Then classify queries w.r.t. their cosine distance to the prototypes.
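Example (an illustrative sketch; `backbone`, `support_images`, `support_labels` and `query_images` are assumed to exist):

            model = Finetune(backbone)
            model.process_support_set(support_images, support_labels)
            scores = model(query_images)  # shape (n_query, n_classes)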
54 | """ 55 | query_features = self.compute_features(query_images) 56 | 57 | with torch.enable_grad(): 58 | self.prototypes.requires_grad_() 59 | optimizer = torch.optim.Adam([self.prototypes], lr=self.fine_tuning_lr) 60 | for _ in range(self.fine_tuning_steps): 61 | support_logits = self.cosine_distance_to_prototypes( 62 | self.support_features 63 | ) 64 | loss = nn.functional.cross_entropy( 65 | self.temperature * support_logits, self.support_labels 66 | ) 67 | optimizer.zero_grad() 68 | loss.backward() 69 | optimizer.step() 70 | 71 | return self.softmax_if_specified( 72 | self.cosine_distance_to_prototypes(query_features), 73 | temperature=self.temperature, 74 | ).detach() 75 | 76 | @staticmethod 77 | def is_transductive() -> bool: 78 | return False 79 | -------------------------------------------------------------------------------- /easyfsl/methods/laplacian_shot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from .bd_cspn import BDCSPN 5 | from .utils import k_nearest_neighbours 6 | 7 | 8 | class LaplacianShot(BDCSPN): 9 | """ 10 | Imtiaz Masud Ziko, Jose Dolz, Eric Granger, Ismail Ben Ayed. 11 | "Laplacian Regularized Few-Shot Learning" (ICML 2020) 12 | https://arxiv.org/abs/2006.15486 13 | 14 | LaplacianShot updates the soft-assignments using a Laplacian Regularization to 15 | improve consistency between the assignments of neighbouring query points. 16 | Default hyperparameters have been optimized for 5-way 5-shot classification on 17 | miniImageNet (see https://github.com/ebennequin/few-shot-open-set/blob/master/configs/classifiers.yaml). 18 | 19 | LaplianShot is a transductive method. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | *args, 25 | inference_steps: int = 20, 26 | knn: int = 3, 27 | lambda_regularization: float = 0.7, 28 | **kwargs, 29 | ): 30 | super().__init__(*args, **kwargs) 31 | self.knn = knn 32 | self.inference_steps = inference_steps 33 | self.lambda_regularization = lambda_regularization 34 | 35 | def forward( 36 | self, 37 | query_images: Tensor, 38 | ) -> Tensor: 39 | query_features = self.compute_features(query_images) 40 | self.rectify_prototypes(query_features=query_features) 41 | 42 | features_to_prototypes_distances = ( 43 | torch.cdist(query_features, self.prototypes) ** 2 44 | ) 45 | pairwise_affinities = self.compute_pairwise_affinities(query_features) 46 | predictions = self.bound_updates( 47 | initial_scores=features_to_prototypes_distances, kernel=pairwise_affinities 48 | ) 49 | 50 | return predictions 51 | 52 | def compute_pairwise_affinities(self, features: Tensor) -> Tensor: 53 | """ 54 | Build pairwise affinity matrix from features using k-nearest neighbours. 55 | Item (i, j) of the matrix is 1 if i is among the k-nearest neighbours of j, and vice versa, and 0 otherwise. 56 | Args: 57 | features: tensor of shape (n_features, feature_dimension) 58 | 59 | Returns: 60 | tensor of shape (n_features, n_features) corresponding to W in the paper. 61 | """ 62 | # Compute the k-nearest neighbours of each feature vector. 
63 | # Each row is the indices of the k nearest neighbours of the corresponding feature, not including itself 64 | nearest_neighbours = k_nearest_neighbours(features, self.knn) 65 | affinity_matrix = torch.zeros((len(features), len(features))).to( 66 | nearest_neighbours.device 67 | ) 68 | for vector_index, vector_nearest_neighbours in enumerate(nearest_neighbours): 69 | affinity_matrix[vector_index].index_fill_(0, vector_nearest_neighbours, 1) 70 | 71 | return affinity_matrix 72 | 73 | def compute_upper_bound( 74 | self, soft_assignments: Tensor, initial_scores: Tensor, kernel: Tensor 75 | ) -> float: 76 | """ 77 | Compute the upper bound objective for the soft assignments following Equation (7) of the paper. 78 | Args: 79 | soft_assignments: soft assignments of shape (n_query, n_classes), $$y_q$$ in the paper 80 | initial_scores: distances from each query to each prototype, 81 | of shape (n_query, n_classes), $$a_q$$ in the paper 82 | kernel: pairwise affinities between query feature vectors, 83 | of shape (n_features, n_features), $$W$$ in the paper 84 | Returns: 85 | upper bound objective 86 | """ 87 | pairwise = kernel.matmul(soft_assignments) 88 | temp = (initial_scores * soft_assignments) + ( 89 | -self.lambda_regularization * pairwise * soft_assignments 90 | ) 91 | upper_bound = (soft_assignments * (soft_assignments + 1e-12).log() + temp).sum() 92 | 93 | return upper_bound.item() 94 | 95 | def bound_updates(self, initial_scores: Tensor, kernel: Tensor) -> Tensor: 96 | """ 97 | Compute the soft assignments using the bound update algorithm described in the paper 98 | as Algorithm 1. 99 | Args: 100 | initial_scores: distances from each query to each prototype, of shape (n_query, n_classes) 101 | kernel: pairwise affinities between query feature vectors, of shape (n_features, n_features) 102 | Returns: 103 | soft_assignments: soft assignments of shape (n_query, n_classes) 104 | """ 105 | old_upper_bound = float("inf") 106 | soft_assignments = (-initial_scores).softmax(dim=1) 107 | for i in range(self.inference_steps): 108 | additive = -initial_scores 109 | mul_kernel = kernel.matmul(soft_assignments) 110 | soft_assignments = -self.lambda_regularization * mul_kernel 111 | additive = additive - soft_assignments 112 | soft_assignments = additive.softmax(dim=1) 113 | upper_bound = self.compute_upper_bound( 114 | soft_assignments, initial_scores, kernel 115 | ) 116 | 117 | if i > 1 and ( 118 | abs(upper_bound - old_upper_bound) <= 1e-6 * abs(old_upper_bound) 119 | ): 120 | break 121 | 122 | old_upper_bound = upper_bound 123 | 124 | return soft_assignments 125 | 126 | @staticmethod 127 | def is_transductive() -> bool: 128 | return True 129 | -------------------------------------------------------------------------------- /easyfsl/methods/prototypical_networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | See original implementation (quite far from this one) 3 | at https://github.com/jakesnell/prototypical-networks 4 | """ 5 | 6 | from torch import Tensor 7 | 8 | from .few_shot_classifier import FewShotClassifier 9 | 10 | 11 | class PrototypicalNetworks(FewShotClassifier): 12 | """ 13 | Jake Snell, Kevin Swersky, and Richard S. Zemel. 14 | "Prototypical networks for few-shot learning." (2017) 15 | https://arxiv.org/abs/1703.05175 16 | 17 | Prototypical networks extract feature vectors for both support and query images. 
Then it 18 | computes the mean of support features for each class (called prototypes), and predicts 19 | classification scores for query images based on their euclidean distance to the prototypes. 20 | """ 21 | 22 | def forward( 23 | self, 24 | query_images: Tensor, 25 | ) -> Tensor: 26 | """ 27 | Overrides forward method of FewShotClassifier. 28 | Predict query labels based on their distance to class prototypes in the feature space. 29 | Classification scores are the negatives of the euclidean distances. 30 | """ 31 | # Extract the features of query images 32 | query_features = self.compute_features(query_images) 33 | self._raise_error_if_features_are_multi_dimensional(query_features) 34 | 35 | # Compute the euclidean distance from queries to prototypes 36 | scores = self.l2_distance_to_prototypes(query_features) 37 | 38 | return self.softmax_if_specified(scores) 39 | 40 | @staticmethod 41 | def is_transductive() -> bool: 42 | return False 43 | -------------------------------------------------------------------------------- /easyfsl/methods/pt_map.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from easyfsl.methods.utils import power_transform 5 | 6 | from .few_shot_classifier import FewShotClassifier 7 | 8 | MAXIMUM_SINKHORN_ITERATIONS = 1000 9 | 10 | 11 | class PTMAP(FewShotClassifier): 12 | """ 13 | Yuqing Hu, Vincent Gripon, Stéphane Pateux. 14 | "Leveraging the Feature Distribution in Transfer-based Few-Shot Learning" (2020) 15 | https://arxiv.org/abs/2006.03806 16 | 17 | Query soft assignments are computed as the optimal transport plan to class prototypes. 18 | At each iteration, prototypes are fine-tuned based on the soft assignments. 19 | This is a transductive method. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | *args, 25 | fine_tuning_steps: int = 10, 26 | fine_tuning_lr: float = 0.2, 27 | lambda_regularization: float = 10.0, 28 | power_factor: float = 0.5, 29 | **kwargs, 30 | ): 31 | super().__init__(*args, **kwargs) 32 | self.fine_tuning_steps = fine_tuning_steps 33 | self.fine_tuning_lr = fine_tuning_lr 34 | self.lambda_regularization = lambda_regularization 35 | self.power_factor = power_factor 36 | 37 | def forward( 38 | self, 39 | query_images: Tensor, 40 | ) -> Tensor: 41 | """ 42 | Predict query soft assignments following Algorithm 1 of the paper. 43 | """ 44 | query_features = self.compute_features(query_images) 45 | 46 | support_assignments = nn.functional.one_hot( # pylint: disable=not-callable 47 | self.support_labels, len(self.prototypes) 48 | ) 49 | for _ in range(self.fine_tuning_steps): 50 | query_soft_assignments = self.compute_soft_assignments(query_features) 51 | all_features = torch.cat([self.support_features, query_features], 0) 52 | all_assignments = torch.cat( 53 | [support_assignments, query_soft_assignments], dim=0 54 | ) 55 | 56 | self.update_prototypes(all_features, all_assignments) 57 | 58 | return self.compute_soft_assignments(query_features) 59 | 60 | def compute_features(self, images: Tensor) -> Tensor: 61 | """ 62 | Apply power transform on features following Equation (1) in the paper. 63 | Args: 64 | images: images of shape (n_images, **image_shape) 65 | Returns: 66 | features of shape (n_images, feature_dimension) with power-transform.
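Concretely, features are mapped to (features.relu() + 1e-6) ** power_factor (see power_transform in easyfsl.methods.utils), with power_factor = 0.5 by default.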
67 | """ 68 | features = super().compute_features(images) 69 | return power_transform(features, self.power_factor) 70 | 71 | def compute_soft_assignments(self, query_features: Tensor) -> Tensor: 72 | """ 73 | Compute soft assignments from queries to prototypes, following Equation (3) of the paper. 74 | Args: 75 | query_features: query features, of shape (n_queries, feature_dim) 76 | 77 | Returns: 78 | soft assignments from queries to prototypes, of shape (n_queries, n_classes) 79 | """ 80 | 81 | distances_to_prototypes = ( 82 | torch.cdist(query_features, self.prototypes) ** 2 83 | ) # [Nq, K] 84 | 85 | soft_assignments = self.compute_optimal_transport( 86 | distances_to_prototypes, epsilon=1e-6 87 | ) 88 | 89 | return soft_assignments 90 | 91 | def compute_optimal_transport( 92 | self, cost_matrix: Tensor, epsilon: float = 1e-6 93 | ) -> Tensor: 94 | """ 95 | Compute the optimal transport plan from queries to prototypes using Sinkhorn-Knopp algorithm. 96 | Args: 97 | cost_matrix: euclidean distances from queries to prototypes, 98 | of shape (n_queries, n_classes) 99 | epsilon: convergence parameter. Stop when the update is smaller than epsilon. 100 | Returns: 101 | transport plan from queries to prototypes of shape (n_queries, n_classes) 102 | """ 103 | 104 | instance_multiplication_factor = cost_matrix.shape[0] // cost_matrix.shape[1] 105 | 106 | transport_plan = torch.exp(-self.lambda_regularization * cost_matrix) 107 | transport_plan /= transport_plan.sum(dim=(0, 1), keepdim=True) 108 | 109 | for _ in range(MAXIMUM_SINKHORN_ITERATIONS): 110 | per_class_sums = transport_plan.sum(1) 111 | transport_plan *= (1 / (per_class_sums + 1e-10)).unsqueeze(1) 112 | transport_plan *= ( 113 | instance_multiplication_factor / (transport_plan.sum(0) + 1e-10) 114 | ).unsqueeze(0) 115 | if torch.max(torch.abs(per_class_sums - transport_plan.sum(1))) < epsilon: 116 | break 117 | 118 | return transport_plan 119 | 120 | def update_prototypes(self, all_features, all_assignments) -> None: 121 | """ 122 | Update prototypes by weigh-averaging the features with their soft assignments, 123 | following Equation (6) of the paper. 124 | Args: 125 | all_features: concatenation of support and query features, 126 | of shape (n_support + n_query, feature_dim) 127 | all_assignments: concatenation of support and query soft assignments, 128 | of shape (n_support + n_query, n_classes)- 129 | """ 130 | new_prototypes = (all_assignments.T @ all_features) / all_assignments.sum( 131 | 0 132 | ).unsqueeze(1) 133 | delta = new_prototypes - self.prototypes 134 | self.prototypes += self.fine_tuning_lr * delta 135 | 136 | @staticmethod 137 | def is_transductive() -> bool: 138 | return True 139 | -------------------------------------------------------------------------------- /easyfsl/methods/relation_networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | See original implementation at 3 | https://github.com/floodsung/LearningToCompare_FSL 4 | """ 5 | 6 | from typing import Optional 7 | 8 | import torch 9 | from torch import Tensor, nn 10 | 11 | from easyfsl.modules.predesigned_modules import default_relation_module 12 | 13 | from .few_shot_classifier import FewShotClassifier 14 | from .utils import compute_prototypes 15 | 16 | 17 | class RelationNetworks(FewShotClassifier): 18 | """ 19 | Sung, Flood, Yongxin Yang, Li Zhang, Tao Xiang, Philip HS Torr, and Timothy M. Hospedales. 20 | "Learning to compare: Relation network for few-shot learning." 
(2018) 21 | https://openaccess.thecvf.com/content_cvpr_2018/papers/Sung_Learning_to_Compare_CVPR_2018_paper.pdf 22 | 23 | In the Relation Networks algorithm, we first extract feature maps for both support and query 24 | images. Then we compute the mean of support features for each class (called prototypes). 25 | To predict the label of a query image, its feature map is concatenated with each class prototype 26 | and fed into a relation module, i.e. a CNN that outputs a relation score. Finally, the 27 | classification vector of the query is its relation score to each class prototype. 28 | 29 | Note that for most other few-shot algorithms we talk about feature vectors, because for each 30 | input image, the backbone outputs a 1-dim feature vector. Here we talk about feature maps, 31 | because for each input image, the backbone outputs a "feature map" of shape 32 | (n_channels, width, height). This raises different constraints on the architecture of the 33 | backbone: while other algorithms require a "flatten" operation in the backbone, here "flatten" 34 | operations are forbidden. 35 | 36 | Relation Networks use Mean Square Error. This is unusual for a classification 37 | problem. The authors justify this choice by the fact that the output of the model is a relation 38 | score, which makes it a regression problem. See the article for more details. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | *args, 44 | feature_dimension: int, 45 | relation_module: Optional[nn.Module] = None, 46 | **kwargs, 47 | ): 48 | """ 49 | Build Relation Networks by calling the constructor of FewShotClassifier. 50 | Args: 51 | feature_dimension: first dimension of the feature maps extracted by the backbone. 52 | relation_module: module that will take the concatenation of a query feature map 53 | and a prototype to output a relation score. If none is specified, we use the default 54 | relation module from the original paper. 55 | """ 56 | super().__init__(*args, **kwargs) 57 | 58 | self.feature_dimension = feature_dimension 59 | 60 | # Here we build the relation module that will output the relation score for each 61 | # (query, prototype) pair. See the function docstring for more details. 62 | self.relation_module = ( 63 | relation_module 64 | if relation_module 65 | else default_relation_module(self.feature_dimension) 66 | ) 67 | 68 | def process_support_set( 69 | self, 70 | support_images: Tensor, 71 | support_labels: Tensor, 72 | ): 73 | """ 74 | Overrides process_support_set of FewShotClassifier. 75 | Extract feature maps from the support set and store class prototypes. 76 | """ 77 | 78 | support_features = self.compute_features(support_images) 79 | self._validate_features_shape(support_features) 80 | self.prototypes = compute_prototypes(support_features, support_labels) 81 | 82 | def forward(self, query_images: Tensor) -> Tensor: 83 | """ 84 | Overrides method forward in FewShotClassifier. 85 | Predict the label of a query image by concatenating its feature map with each class 86 | prototype and feeding the result into a relation module, i.e. a CNN that outputs a relation 87 | score. Finally, the classification vector of the query is its relation score to each class 88 | prototype.
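As an illustrative shape walk-through (the numbers are hypothetical): with 5 prototypes and 10 queries whose feature maps have shape (64, 5, 5), the concatenated pairs form a (50, 128, 5, 5) tensor, and the returned scores have shape (10, 5).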
89 | """ 90 | query_features = self.compute_features(query_images) 91 | self._validate_features_shape(query_features) 92 | 93 | # For each pair (query, prototype), we compute the concatenation of their feature maps 94 | # Given that query_features is of shape (n_queries, n_channels, width, height), the 95 | # constructed tensor is of shape (n_queries * n_prototypes, 2 * n_channels, width, height) 96 | # (2 * n_channels because prototypes and queries are concatenated) 97 | query_prototype_feature_pairs = torch.cat( 98 | ( 99 | self.prototypes.unsqueeze(dim=0).expand( 100 | query_features.shape[0], -1, -1, -1, -1 101 | ), 102 | query_features.unsqueeze(dim=1).expand( 103 | -1, self.prototypes.shape[0], -1, -1, -1 104 | ), 105 | ), 106 | dim=2, 107 | ).view(-1, 2 * self.feature_dimension, *query_features.shape[2:]) 108 | 109 | # Each pair (query, prototype) is assigned a relation scores in [0,1]. Then we reshape the 110 | # tensor so that relation_scores is of shape (n_queries, n_prototypes). 111 | relation_scores = self.relation_module(query_prototype_feature_pairs).view( 112 | -1, self.prototypes.shape[0] 113 | ) 114 | 115 | return self.softmax_if_specified(relation_scores) 116 | 117 | def _validate_features_shape(self, features): 118 | if len(features.shape) != 4: 119 | raise ValueError( 120 | "Illegal backbone for Relation Networks. " 121 | "Expected output for an image is a 3-dim tensor of shape (n_channels, width, height)." 122 | ) 123 | if features.shape[1] != self.feature_dimension: 124 | raise ValueError( 125 | f"Expected feature dimension is {self.feature_dimension}, but got {features.shape[1]}." 126 | ) 127 | 128 | @staticmethod 129 | def is_transductive() -> bool: 130 | return False 131 | -------------------------------------------------------------------------------- /easyfsl/methods/simple_shot.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from .few_shot_classifier import FewShotClassifier 4 | 5 | 6 | class SimpleShot(FewShotClassifier): 7 | """ 8 | Yan Wang, Wei-Lun Chao, Kilian Q. Weinberger, and Laurens van der Maaten. 9 | "SimpleShot: Revisiting Nearest-Neighbor Classification for Few-Shot Learning" (2019) 10 | https://arxiv.org/abs/1911.04623 11 | 12 | Almost exactly Prototypical Classification, but with cosine distance instead of euclidean distance. 13 | """ 14 | 15 | def forward( 16 | self, 17 | query_images: Tensor, 18 | ) -> Tensor: 19 | """ 20 | Predict classification labels. 21 | Args: 22 | query_images: images of the query set of shape (n_query, **image_shape) 23 | Returns: 24 | a prediction of classification scores for query images of shape (n_query, n_classes) 25 | """ 26 | query_features = self.compute_features(query_images) 27 | self._raise_error_if_features_are_multi_dimensional(query_features) 28 | 29 | scores = self.cosine_distance_to_prototypes(query_features) 30 | 31 | return self.softmax_if_specified(scores) 32 | 33 | @staticmethod 34 | def is_transductive() -> bool: 35 | return False 36 | -------------------------------------------------------------------------------- /easyfsl/methods/tim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from .few_shot_classifier import FewShotClassifier 5 | 6 | 7 | class TIM(FewShotClassifier): 8 | """ 9 | Malik Boudiaf, Ziko Imtiaz Masud, Jérôme Rony, José Dolz, Pablo Piantanida, Ismail Ben Ayed. 
10 | "Transductive Information Maximization For Few-Shot Learning" (NeurIPS 2020) 11 | https://arxiv.org/abs/2008.11297 12 | 13 | Fine-tune prototypes based on 14 | 1) classification error on support images 15 | 2) mutual information between query features and their label predictions 16 | Classify w.r.t. to euclidean distance to updated prototypes. 17 | As is, it is incompatible with episodic training because we freeze the backbone to perform 18 | fine-tuning. 19 | 20 | TIM is a transductive method. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | *args, 26 | fine_tuning_steps: int = 50, 27 | fine_tuning_lr: float = 1e-4, 28 | cross_entropy_weight: float = 1.0, 29 | marginal_entropy_weight: float = 1.0, 30 | conditional_entropy_weight: float = 0.1, 31 | temperature: float = 10.0, 32 | **kwargs, 33 | ): 34 | """ 35 | Args: 36 | fine_tuning_steps: number of fine-tuning steps 37 | fine_tuning_lr: learning rate for fine-tuning 38 | cross_entropy_weight: weight given to the cross-entropy term of the loss 39 | marginal_entropy_weight: weight given to the marginal entropy term of the loss 40 | conditional_entropy_weight: weight given to the conditional entropy term of the loss 41 | temperature: temperature applied to the logits before computing 42 | softmax or cross-entropy. Higher temperature means softer predictions. 43 | """ 44 | super().__init__(*args, **kwargs) 45 | 46 | # Since we fine-tune the prototypes we need to make them leaf variables 47 | # i.e. we need to freeze the backbone. 48 | self.backbone.requires_grad_(False) 49 | 50 | self.fine_tuning_steps = fine_tuning_steps 51 | self.fine_tuning_lr = fine_tuning_lr 52 | self.cross_entropy_weight = cross_entropy_weight 53 | self.marginal_entropy_weight = marginal_entropy_weight 54 | self.conditional_entropy_weight = conditional_entropy_weight 55 | self.temperature = temperature 56 | 57 | def forward( 58 | self, 59 | query_images: Tensor, 60 | ) -> Tensor: 61 | """ 62 | Overrides forward method of FewShotClassifier. 63 | Fine-tune prototypes based on support classification error and mutual information between 64 | query features and their label predictions. 65 | Then classify w.r.t. to euclidean distance to prototypes. 
66 | """ 67 | query_features = self.compute_features(query_images) 68 | 69 | num_classes = self.support_labels.unique().size(0) 70 | support_labels_one_hot = nn.functional.one_hot( # pylint: disable=not-callable 71 | self.support_labels, num_classes 72 | ) 73 | 74 | with torch.enable_grad(): 75 | self.prototypes.requires_grad_() 76 | optimizer = torch.optim.Adam([self.prototypes], lr=self.fine_tuning_lr) 77 | 78 | for _ in range(self.fine_tuning_steps): 79 | support_logits = self.temperature * self.cosine_distance_to_prototypes( 80 | self.support_features 81 | ) 82 | query_logits = self.temperature * self.cosine_distance_to_prototypes( 83 | query_features 84 | ) 85 | 86 | support_cross_entropy = ( 87 | -(support_labels_one_hot * support_logits.log_softmax(1)) 88 | .sum(1) 89 | .mean(0) 90 | ) 91 | 92 | query_soft_probs = query_logits.softmax(1) 93 | query_conditional_entropy = ( 94 | -(query_soft_probs * torch.log(query_soft_probs + 1e-12)) 95 | .sum(1) 96 | .mean(0) 97 | ) 98 | 99 | marginal_prediction = query_soft_probs.mean(0) 100 | marginal_entropy = -( 101 | marginal_prediction * torch.log(marginal_prediction) 102 | ).sum(0) 103 | 104 | loss = self.cross_entropy_weight * support_cross_entropy - ( 105 | self.marginal_entropy_weight * marginal_entropy 106 | - self.conditional_entropy_weight * query_conditional_entropy 107 | ) 108 | 109 | optimizer.zero_grad() 110 | loss.backward() 111 | optimizer.step() 112 | 113 | return self.softmax_if_specified( 114 | self.cosine_distance_to_prototypes(query_features), 115 | temperature=self.temperature, 116 | ).detach() 117 | 118 | @staticmethod 119 | def is_transductive() -> bool: 120 | return True 121 | -------------------------------------------------------------------------------- /easyfsl/methods/transductive_finetuning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from .finetune import Finetune 5 | from .utils import entropy 6 | 7 | 8 | class TransductiveFinetuning(Finetune): 9 | """ 10 | Guneet S. Dhillon, Pratik Chaudhari, Avinash Ravichandran, Stefano Soatto. 11 | "A Baseline for Few-Shot Image Classification" (ICLR 2020) 12 | https://arxiv.org/abs/1909.02729 13 | 14 | Fine-tune the parameters of the pre-trained model based on 15 | 1) classification error on support images 16 | 2) classification entropy for query images 17 | Classify queries based on their euclidean distance to prototypes. 18 | This is a transductive method. 19 | WARNING: this implementation only updates prototypes, not the whole set of model's 20 | parameters. Updating the model's parameters raises performance issues that we didn't 21 | have time to solve yet. We welcome contributions. 22 | As is, it is incompatible with episodic training because we freeze the backbone to perform 23 | fine-tuning. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | *args, 29 | fine_tuning_steps: int = 25, 30 | fine_tuning_lr: float = 5e-5, 31 | temperature: float = 1.0, 32 | **kwargs, 33 | ): 34 | """ 35 | TransductiveFinetuning is very similar to the inductive method Finetune. 36 | The difference only resides in the way we perform the fine-tuning step and in the 37 | distance we use. Therefore, we call the super constructor of Finetune 38 | (and same for preprocess_support_set()). 39 | Args: 40 | fine_tuning_steps: number of fine-tuning steps 41 | fine_tuning_lr: learning rate for fine-tuning 42 | temperature: temperature applied to the logits before computing 43 | softmax or cross-entropy. 
Higher values yield sharper predictions, since the logits are multiplied by the temperature. 44 | """ 45 | super().__init__( 46 | *args, 47 | fine_tuning_steps=fine_tuning_steps, 48 | fine_tuning_lr=fine_tuning_lr, 49 | temperature=temperature, 50 | **kwargs, 51 | ) 52 | 53 | def forward( 54 | self, 55 | query_images: Tensor, 56 | ) -> Tensor: 57 | """ 58 | Overrides forward method of FewShotClassifier. 59 | Fine-tune model's parameters based on support classification error and 60 | query classification entropy. 61 | """ 62 | query_features = self.compute_features(query_images) 63 | 64 | with torch.enable_grad(): 65 | self.prototypes.requires_grad_() 66 | optimizer = torch.optim.Adam([self.prototypes], lr=self.fine_tuning_lr) 67 | for _ in range(self.fine_tuning_steps): 68 | support_cross_entropy = nn.functional.cross_entropy( 69 | self.temperature 70 | * self.l2_distance_to_prototypes(self.support_features), 71 | self.support_labels, 72 | ) 73 | query_conditional_entropy = entropy( 74 | self.temperature * self.l2_distance_to_prototypes(query_features) 75 | ) 76 | loss = support_cross_entropy + query_conditional_entropy 77 | optimizer.zero_grad() 78 | loss.backward() 79 | optimizer.step() 80 | 81 | return self.softmax_if_specified( 82 | self.l2_distance_to_prototypes(query_features), temperature=self.temperature 83 | ).detach() 84 | 85 | @staticmethod 86 | def is_transductive() -> bool: 87 | return True 88 | -------------------------------------------------------------------------------- /easyfsl/methods/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | def compute_prototypes(support_features: Tensor, support_labels: Tensor) -> Tensor: 8 | """ 9 | Compute class prototypes from support features and labels 10 | Args: 11 | support_features: for each instance in the support set, its feature vector 12 | support_labels: for each instance in the support set, its label 13 | 14 | Returns: 15 | for each label of the support set, the average feature vector of instances with this label 16 | """ 17 | 18 | n_way = len(torch.unique(support_labels)) 19 | # Prototype i is the mean of all instances of features corresponding to labels == i 20 | return torch.cat( 21 | [ 22 | support_features[torch.nonzero(support_labels == label)].mean(0) 23 | for label in range(n_way) 24 | ] 25 | ) 26 | 27 | 28 | def entropy(logits: Tensor) -> Tensor: 29 | """ 30 | Compute entropy of prediction. 31 | WARNING: takes logits as input, not probabilities. 32 | Args: 33 | logits: shape (n_images, n_way) 34 | Returns: 35 | Tensor: shape (), mean entropy. 36 | """ 37 | probabilities = logits.softmax(dim=1) 38 | return (-(probabilities * (probabilities + 1e-12).log()).sum(dim=1)).mean() 39 | 40 | 41 | def k_nearest_neighbours(features: Tensor, k: int, p_norm: int = 2) -> Tensor: 42 | """ 43 | Compute the k nearest neighbours of each feature vector, not including itself. 44 | Args: 45 | features: input features of shape (n_features, feature_dimension) 46 | k: number of nearest neighbours to retain 47 | p_norm: use l_p distance. Defaults to 2. 48 | 49 | Returns: 50 | Tensor: shape (n_features, k), indices of k nearest neighbours of each feature vector. 51 | """ 52 | distances = torch.cdist(features, features, p_norm) 53 | # Take the k + 1 smallest distances and drop the closest match, which is the feature itself (at distance 0). Using topk(k) here would only return k - 1 neighbours, contradicting the documented output shape. 54 | return distances.topk(k + 1, largest=False).indices[:, 1:] 55 | 56 | 57 | def power_transform(features: Tensor, power_factor: float) -> Tensor: 58 | """ 59 | Apply power transform to features.
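Concretely: features are mapped to (features.relu() + 1e-6) ** power_factor, as in Equation (1) of the PT-MAP paper.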
60 | Args: 61 | features: input features of shape (n_features, feature_dimension) 62 | power_factor: power to apply to features 63 | 64 | Returns: 65 | Tensor: shape (n_features, feature_dimension), power transformed features. 66 | """ 67 | return (features.relu() + 1e-6).pow(power_factor) 68 | 69 | 70 | def strip_prefix(state_dict: OrderedDict, prefix: str): 71 | """ 72 | Strip a prefix from the keys of a state_dict. Can be used to address compatibility issues from 73 | a loaded state_dict to a model with slightly different parameter names. 74 | Example usage: 75 | state_dict = torch.load("model.pth") 76 | # state_dict contains keys like "module.encoder.0.weight" but the model expects keys like "encoder.0.weight" 77 | state_dict = strip_prefix(state_dict, "module.") 78 | model.load_state_dict(state_dict) 79 | Args: 80 | state_dict: pytorch state_dict, as returned by model.state_dict() or loaded via torch.load() 81 | Keys are the names of the parameters and values are the parameter tensors. 82 | prefix: prefix to strip from the keys of the state_dict. Usually ends with a dot. 83 | 84 | Returns: 85 | copy of the state_dict with the prefix stripped from the keys 86 | """ 87 | return OrderedDict( 88 | [ 89 | (k[len(prefix) :] if k.startswith(prefix) else k, v) 90 | for k, v in state_dict.items() 91 | ] 92 | ) 93 | -------------------------------------------------------------------------------- /easyfsl/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNet # isort:skip 2 | from .attention_modules import MultiHeadAttention 3 | from .feat_resnet12 import feat_resnet12 4 | from .predesigned_modules import * 5 | -------------------------------------------------------------------------------- /easyfsl/modules/attention_modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | # pylint: disable=invalid-name,too-many-instance-attributes 7 | 8 | 9 | class ScaledDotProductAttention(nn.Module): 10 | """Scaled Dot-Product Attention""" 11 | 12 | def __init__(self, temperature, dropout=0.1): 13 | super().__init__() 14 | self.temperature = temperature 15 | self.dropout = nn.Dropout(dropout) 16 | self.softmax = nn.Softmax(dim=2) 17 | 18 | def forward(self, q, k, v): 19 | attention = torch.bmm(q, k.transpose(1, 2)) 20 | attention = attention / self.temperature 21 | raw_attention = attention 22 | log_attention = nn.functional.log_softmax(attention, 2) 23 | attention = self.softmax(attention) 24 | attention = self.dropout(attention) 25 | output = torch.bmm(attention, v) 26 | return output, attention, log_attention, raw_attention 27 | 28 | 29 | class MultiHeadAttention(nn.Module): 30 | """Multi-Head Attention module""" 31 | 32 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, flag_norm=True): 33 | super().__init__() 34 | self.n_head = n_head 35 | self.d_k = d_k 36 | self.d_v = d_v 37 | 38 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 39 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 40 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 41 | nn.init.normal_(self.w_qs.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_k))) 42 | nn.init.normal_(self.w_ks.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_k))) 43 | nn.init.normal_(self.w_vs.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_v))) 44 | 45 | self.attention = ScaledDotProductAttention(temperature=math.sqrt(d_k)) 46 | self.layer_norm = 
nn.LayerNorm(d_model) 47 | 48 | self.fc = nn.Linear(n_head * d_v, d_model) 49 | nn.init.xavier_normal_(self.fc.weight) 50 | self.dropout = nn.Dropout(dropout) 51 | self.flag_norm = flag_norm 52 | 53 | def forward(self, q, k, v): 54 | """ 55 | Go through the multi-head attention module. 56 | """ 57 | sz_q, len_q, _ = q.size() 58 | sz_b, len_k, _ = k.size() 59 | sz_b, len_v, _ = v.size() 60 | 61 | residual = q 62 | q = self.w_qs(q).view(sz_q, len_q, self.n_head, self.d_k) 63 | k = self.w_ks(k).view(sz_b, len_k, self.n_head, self.d_k) 64 | v = self.w_vs(v).view(sz_b, len_v, self.n_head, self.d_v) 65 | 66 | q = ( 67 | q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, self.d_k) 68 | ) # (n*b) x lq x dk 69 | k = ( 70 | k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, self.d_k) 71 | ) # (n*b) x lk x dk 72 | v = ( 73 | v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, self.d_v) 74 | ) # (n*b) x lv x dv 75 | 76 | output, _, _, _ = self.attention(q, k, v) 77 | 78 | output = output.view(self.n_head, sz_q, len_q, self.d_v) 79 | output = ( 80 | output.permute(1, 2, 0, 3).contiguous().view(sz_q, len_q, -1) 81 | ) # b x lq x (n*dv) 82 | resout = self.fc(output) 83 | output = self.dropout(resout) 84 | if self.flag_norm: 85 | output = self.layer_norm(output + residual) 86 | 87 | return output, resout 88 | -------------------------------------------------------------------------------- /easyfsl/modules/build_from_checkpoint.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | 5 | from easyfsl.methods.utils import strip_prefix 6 | from easyfsl.modules.feat_resnet12 import FEATResNet12, feat_resnet12 7 | 8 | 9 | def feat_resnet12_from_checkpoint( 10 | checkpoint_path: Path, device: str, **kwargs 11 | ) -> FEATResNet12: 12 | model = feat_resnet12(**kwargs).to(device) 13 | 14 | state_dict = torch.load(str(checkpoint_path), map_location=device)["params"] 15 | 16 | backbone_missing_keys, _ = model.load_state_dict( 17 | strip_prefix(state_dict, "encoder."), strict=False 18 | ) 19 | 20 | if len(backbone_missing_keys) > 0: 21 | raise ValueError(f"Missing keys for backbone: {backbone_missing_keys}") 22 | 23 | return model 24 | -------------------------------------------------------------------------------- /easyfsl/modules/feat_resnet12.py: -------------------------------------------------------------------------------- 1 | """ 2 | This particular ResNet12 is simplified from the original implementation of FEAT (https://github.com/Sha-Lab/FEAT). 3 | We provide it to allow the reproduction of the FEAT method and the use of the checkpoints they made available. 4 | It contains some design choices that differ from the usual ResNet12. Use one or the other, 5 | but remember that it is important to use the same backbone for a fair comparison between methods. 6 | """ 7 | 8 | from torch import nn 9 | from torchvision.models.resnet import conv3x3 10 | 11 | 12 | class FEATBasicBlock(nn.Module): 13 | """ 14 | BasicBlock for FEAT. Uses 3 convolutions instead of 2, a LeakyReLU instead of ReLU, and a MaxPool2d.
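The forward pass applies three conv-BN stages with LeakyReLU activations, adds the (optionally downsampled) residual, then applies a final activation and max-pooling.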
15 | """ 16 | 17 | expansion = 1 18 | 19 | def __init__( 20 | self, 21 | inplanes, 22 | planes, 23 | stride=1, 24 | downsample=None, 25 | ): 26 | super().__init__() 27 | self.conv1 = conv3x3(inplanes, planes) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.LeakyReLU(0.1) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.conv3 = conv3x3(planes, planes) 33 | self.bn3 = nn.BatchNorm2d(planes) 34 | self.maxpool = nn.MaxPool2d(stride) 35 | self.downsample = downsample 36 | 37 | def forward(self, x): # pylint: disable=invalid-name 38 | """ 39 | Pass input through the block, including an activation and maxpooling at the end. 40 | """ 41 | 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | out = self.relu(out) 53 | 54 | out = self.conv3(out) 55 | out = self.bn3(out) 56 | 57 | if self.downsample is not None: 58 | residual = self.downsample(x) 59 | out += residual 60 | 61 | out = self.relu(out) 62 | out = self.maxpool(out) 63 | 64 | return out 65 | 66 | 67 | class FEATResNet12(nn.Module): 68 | """ 69 | ResNet12 for FEAT. See feat_resnet12 doc for more details. 70 | """ 71 | 72 | def __init__( 73 | self, 74 | block=FEATBasicBlock, 75 | ): 76 | self.inplanes = 3 77 | super().__init__() 78 | 79 | channels = [64, 160, 320, 640] 80 | self.layer_dims = [ 81 | channels[i] * block.expansion for i in range(4) for j in range(4) 82 | ] 83 | 84 | self.layer1 = self._make_layer( 85 | block, 86 | 64, 87 | stride=2, 88 | ) 89 | self.layer2 = self._make_layer( 90 | block, 91 | 160, 92 | stride=2, 93 | ) 94 | self.layer3 = self._make_layer( 95 | block, 96 | 320, 97 | stride=2, 98 | ) 99 | self.layer4 = self._make_layer( 100 | block, 101 | 640, 102 | stride=2, 103 | ) 104 | 105 | for module in self.modules(): 106 | if isinstance(module, nn.Conv2d): 107 | nn.init.kaiming_normal_( 108 | module.weight, mode="fan_out", nonlinearity="leaky_relu" 109 | ) 110 | elif isinstance(module, nn.BatchNorm2d): 111 | nn.init.constant_(module.weight, 1) 112 | nn.init.constant_(module.bias, 0) 113 | 114 | def _make_layer(self, block, planes, stride=1): 115 | downsample = None 116 | if stride != 1 or self.inplanes != planes * block.expansion: 117 | downsample = nn.Sequential( 118 | nn.Conv2d( 119 | self.inplanes, 120 | planes * block.expansion, 121 | kernel_size=1, 122 | stride=1, 123 | bias=False, 124 | ), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append( 130 | block( 131 | self.inplanes, 132 | planes, 133 | stride, 134 | downsample, 135 | ) 136 | ) 137 | self.inplanes = planes * block.expansion 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): # pylint: disable=invalid-name 142 | """ 143 | Iterate over the blocks and apply them sequentially. 144 | """ 145 | x = self.layer4(self.layer3(self.layer2(self.layer1(x)))) 146 | return x.mean((-2, -1)) 147 | 148 | 149 | def feat_resnet12(**kwargs): 150 | """ 151 | Build a ResNet12 model as used in the FEAT paper, following the implementation of 152 | https://github.com/Sha-Lab/FEAT. 153 | This ResNet network also follows the practice of the following papers: 154 | TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and 155 | A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018). 
149 | def feat_resnet12(**kwargs):
150 |     """
151 |     Build a ResNet12 model as used in the FEAT paper, following the implementation of
152 |     https://github.com/Sha-Lab/FEAT.
153 |     This ResNet network also follows the practice of the following papers:
154 |     TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and
155 |     A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018).
156 |
157 |     There are 4 main differences from the other ResNet models used in EasyFSL:
158 |         - There is no first convolutional layer (3x3, 64) before the first block.
159 |         - The stride of the first block is 2 instead of 1.
160 |         - The BasicBlock uses 3 convolutional layers, instead of 2 in the standard torch implementation.
161 |         - We don't include the last fully connected layer, since we never use it.
162 |
163 |     Note that we removed the Dropout logic from the original implementation, as it is not part of the paper.
164 |
165 |     Args:
166 |         **kwargs: Additional arguments to pass to the FEATResNet12 class.
167 |
168 |     Returns:
169 |         The ResNet12 model used in FEAT.
170 |     """
171 |     return FEATResNet12(FEATBasicBlock, **kwargs)
172 |
--------------------------------------------------------------------------------
/easyfsl/modules/predesigned_modules.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torchvision.models.resnet import BasicBlock, Bottleneck
3 |
4 | from .resnet import ResNet
5 |
6 | __all__ = [
7 |     "resnet10",
8 |     "resnet12",
9 |     "resnet18",
10 |     "resnet34",
11 |     "resnet50",
12 |     "resnet101",
13 |     "resnet152",
14 |     "default_matching_networks_support_encoder",
15 |     "default_matching_networks_query_encoder",
16 |     "default_relation_module",
17 | ]
18 |
19 |
20 | def resnet10(**kwargs) -> ResNet:
21 |     """Constructs a ResNet-10 model."""
22 |     return ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
23 |
24 |
25 | def resnet12(**kwargs) -> ResNet:
26 |     """Constructs a ResNet-12 model."""
27 |     return ResNet(BasicBlock, [1, 1, 2, 1], planes=[64, 160, 320, 640], **kwargs)
28 |
29 |
30 | def resnet18(**kwargs) -> ResNet:
31 |     """Constructs a ResNet-18 model."""
32 |     return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
33 |
34 |
35 | def resnet34(**kwargs) -> ResNet:
36 |     """Constructs a ResNet-34 model."""
37 |     return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
38 |
39 |
40 | def resnet50(**kwargs) -> ResNet:
41 |     """Constructs a ResNet-50 model."""
42 |     return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
43 |
44 |
45 | def resnet101(**kwargs) -> ResNet:
46 |     """Constructs a ResNet-101 model."""
47 |     return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
48 |
49 |
50 | def resnet152(**kwargs) -> ResNet:
51 |     """Constructs a ResNet-152 model."""
52 |     return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
53 |
54 |
55 | def default_matching_networks_support_encoder(feature_dimension: int) -> nn.Module:
56 |     return nn.LSTM(
57 |         input_size=feature_dimension,
58 |         hidden_size=feature_dimension,
59 |         num_layers=1,
60 |         batch_first=True,
61 |         bidirectional=True,
62 |     )
63 |
64 |
65 | def default_matching_networks_query_encoder(feature_dimension: int) -> nn.Module:
66 |     return nn.LSTMCell(feature_dimension * 2, feature_dimension)
67 |
68 |
69 | def default_relation_module(
70 |     feature_dimension: int, inner_channels: int = 8
71 | ) -> nn.Module:
72 |     """
73 |     Build the relation module that takes as input the concatenation of two feature maps, from
74 |     Sung et al.: "Learning to compare: Relation network for few-shot learning." (2018)
75 |     In order to make the network robust to any change in the dimensions of the input images,
76 |     we made some changes to the architecture defined in the original implementation
77 |     from Sung et al. (typically the use of adaptive pooling).
78 |     Args:
79 |         feature_dimension: the dimension of the feature space i.e. size of a feature vector
80 |         inner_channels: number of hidden channels between the linear layers of the relation module
81 |     Returns:
82 |         the constructed relation module
83 |     """
84 |     return nn.Sequential(
85 |         nn.Sequential(
86 |             nn.Conv2d(
87 |                 feature_dimension * 2,
88 |                 feature_dimension,
89 |                 kernel_size=3,
90 |                 padding=1,
91 |             ),
92 |             nn.BatchNorm2d(feature_dimension, momentum=1, affine=True),
93 |             nn.ReLU(),
94 |             nn.AdaptiveMaxPool2d((5, 5)),
95 |         ),
96 |         nn.Sequential(
97 |             nn.Conv2d(
98 |                 feature_dimension,
99 |                 feature_dimension,
100 |                 kernel_size=3,
101 |                 padding=0,
102 |             ),
103 |             nn.BatchNorm2d(feature_dimension, momentum=1, affine=True),
104 |             nn.ReLU(),
105 |             nn.AdaptiveMaxPool2d((1, 1)),
106 |         ),
107 |         nn.Flatten(),
108 |         nn.Linear(feature_dimension, inner_channels),
109 |         nn.ReLU(),
110 |         nn.Linear(inner_channels, 1),
111 |         nn.Sigmoid(),
112 |     )
113 |
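default_relation_module is the relation head that RelationNetworks uses when none is specified. A sketch of plugging it in explicitly; `backbone` is assumed to be any CNN that outputs feature maps of shape (batch, 640, height, width), e.g. a ResNet built with use_pooling=False:

    from easyfsl.methods import RelationNetworks
    from easyfsl.modules import default_relation_module

    model = RelationNetworks(
        backbone,  # assumed: outputs 3-dim feature maps, not flat vectors
        feature_dimension=640,
        relation_module=default_relation_module(feature_dimension=640),
    )
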
--------------------------------------------------------------------------------
/easyfsl/modules/resnet.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Type, Union
2 |
3 | import torch
4 | from torch import Tensor, nn
5 | from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
6 |
7 | # pylint: disable=invalid-name, too-many-instance-attributes, too-many-arguments
8 |
9 |
10 | class ResNet(nn.Module):
11 |     def __init__(
12 |         self,
13 |         block: Type[Union[BasicBlock, Bottleneck]],
14 |         layers: List[int],
15 |         planes: Optional[List[int]] = None,
16 |         use_fc: bool = False,
17 |         num_classes: int = 1000,
18 |         use_pooling: bool = True,
19 |         big_kernel: bool = False,
20 |         zero_init_residual: bool = False,
21 |     ):
22 |         """
23 |         Custom ResNet architecture, with some design differences compared to the built-in
24 |         PyTorch ResNet.
25 |         This implementation and its usage in predesigned_modules is derived from
26 |         https://github.com/fiveai/on-episodes-fsl/blob/master/src/models/ResNet.py
27 |         Args:
28 |             block: which core block to use (BasicBlock, Bottleneck, or any child of one of these)
29 |             layers: number of blocks in each of the 4 layers
30 |             planes: number of planes in each of the 4 layers
31 |             use_fc: whether to use one last linear layer on features
32 |             num_classes: output dimension of the last linear layer (only used if use_fc is True)
33 |             use_pooling: whether to average pool the features (must be True if use_fc is True)
34 |             big_kernel: whether to use the shape of the built-in PyTorch ResNet designed for
35 |                 ImageNet. If False, make the first convolutional layer less destructive.
36 |             zero_init_residual: zero-initialize the last BN in each residual branch, so that the
37 |                 residual branch starts with zeros, and each residual block behaves like an identity.
38 |                 This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
39 |         """
40 |         super().__init__()
41 |         if planes is None:
42 |             planes = [64, 128, 256, 512]
43 |
44 |         self.inplanes = 64
45 |
46 |         # Built-in PyTorch ResNet uses a first conv layer with a 7*7 kernel and a stride of 2,
47 |         # which is fine for ImageNet's 224x224 images, but too destructive for 84x84 images
48 |         # which are commonly used in few-shot settings.
49 |         self.conv1 = (
50 |             nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
51 |             if big_kernel
52 |             else nn.Conv2d(
53 |                 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
54 |             )
55 |         )
56 |         self.bn1 = nn.BatchNorm2d(self.inplanes)
57 |         self.relu = nn.ReLU(inplace=True)
58 |
59 |         self.layer1 = self._make_layer(block, planes[0], layers[0])
60 |         self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2)
61 |         self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2)
62 |         self.layer4 = self._make_layer(block, planes[3], layers[3], stride=2)
63 |
64 |         self.use_pooling = use_pooling
65 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
66 |
67 |         # Only used when self.use_fc is True
68 |         self.use_fc = use_fc
69 |         self.fc = nn.Linear(self.inplanes, num_classes)
70 |
71 |         for module in self.modules():
72 |             if isinstance(module, nn.Conv2d):
73 |                 nn.init.kaiming_normal_(
74 |                     module.weight, mode="fan_out", nonlinearity="relu"
75 |                 )
76 |             elif isinstance(module, nn.BatchNorm2d):
77 |                 nn.init.constant_(module.weight, 1)
78 |                 nn.init.constant_(module.bias, 0)
79 |
80 |         if zero_init_residual:
81 |             for module in self.modules():
82 |                 if isinstance(module, Bottleneck):
83 |                     nn.init.constant_(module.bn3.weight, 0)
84 |                 elif isinstance(module, BasicBlock):
85 |                     nn.init.constant_(module.bn2.weight, 0)
86 |
87 |     def _make_layer(
88 |         self,
89 |         block: Type[Union[BasicBlock, Bottleneck]],
90 |         planes: int,
91 |         blocks: int,
92 |         stride: int = 1,
93 |     ) -> nn.Module:
94 |         downsample = None
95 |         if stride != 1 or self.inplanes != planes * block.expansion:
96 |             downsample = nn.Sequential(
97 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
98 |                 nn.BatchNorm2d(planes * block.expansion),
99 |             )
100 |
101 |         layers = []
102 |         layers.append(block(self.inplanes, planes, stride, downsample))
103 |         self.inplanes = planes * block.expansion
104 |         for _ in range(1, blocks):
105 |             layers.append(block(self.inplanes, planes))
106 |
107 |         return nn.Sequential(*layers)
108 |
109 |     def forward(self, x: Tensor) -> Tensor:
110 |         """
111 |         Forward pass through the ResNet.
112 |         Args:
113 |             x: input tensor of shape (batch_size, **image_shape)
114 |         Returns:
115 |             x: output tensor of shape (batch_size, num_classes) if self.use_fc is True,
116 |                 otherwise of shape (batch_size, **feature_shape)
117 |         """
118 |         x = self.layer4(
119 |             self.layer3(self.layer2(self.layer1(self.relu(self.bn1(self.conv1(x))))))
120 |         )
121 |
122 |         if self.use_pooling:
123 |             x = torch.flatten(
124 |                 self.avgpool(x),
125 |                 1,
126 |             )
127 |
128 |             if self.use_fc:
129 |                 return self.fc(x)
130 |
131 |         else:
132 |             if self.use_fc:
133 |                 raise ValueError(
134 |                     "You can't use the fully connected layer without pooling features."
135 |                 )
136 |
137 |         return x
138 |
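A minimal sketch of how this ResNet is typically used in EasyFSL: keep the fully connected head for classical pre-training, then switch to pure feature extraction with set_use_fc (defined just below). Shapes assume the resnet12 from predesigned_modules; 64 classes is only an example (the size of the mini-ImageNet train split):

    import torch

    from easyfsl.modules import resnet12

    backbone = resnet12(use_fc=True, num_classes=64)  # with a classification head, for pre-training
    logits = backbone(torch.randn(8, 3, 84, 84))  # (8, 64)

    backbone.set_use_fc(False)  # drop the head to extract features for few-shot methods
    features = backbone(torch.randn(8, 3, 84, 84))  # (8, 640): pooled feature vectors
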
139 |     def set_use_fc(self, use_fc: bool):
140 |         """
141 |         Change the use_fc property. This allows deciding when and where the model should use its last
142 |         fully connected layer.
143 |         Args:
144 |             use_fc: whether to set self.use_fc to True or False
145 |         """
146 |         self.use_fc = use_fc
147 |
148 |
149 | # pylint: enable=invalid-name, too-many-instance-attributes, too-many-arguments
150 |
--------------------------------------------------------------------------------
/easyfsl/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .task_sampler import TaskSampler
2 |
--------------------------------------------------------------------------------
/easyfsl/samplers/task_sampler.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Dict, Iterator, List, Tuple, Union
3 |
4 | import torch
5 | from torch import Tensor
6 | from torch.utils.data import Sampler
7 |
8 | from easyfsl.datasets import FewShotDataset
9 |
10 | GENERIC_TYPING_ERROR_MESSAGE = (
11 |     "Check out the output's type of your dataset's __getitem__() method. "
12 |     "It must be a Tuple[Tensor, int] or Tuple[Tensor, 0-dim Tensor]."
13 | )
14 |
15 |
16 | class TaskSampler(Sampler):
17 |     """
18 |     Samples batches in the shape of few-shot classification tasks. At each iteration, it will sample
19 |     n_way classes, and then sample support and query images from these classes.
20 |     """
21 |
22 |     def __init__(
23 |         self,
24 |         dataset: FewShotDataset,
25 |         n_way: int,
26 |         n_shot: int,
27 |         n_query: int,
28 |         n_tasks: int,
29 |     ):
30 |         """
31 |         Args:
32 |             dataset: dataset from which to sample classification tasks. Must implement get_labels() from
33 |                 FewShotDataset.
34 |             n_way: number of classes in one task
35 |             n_shot: number of support images for each class in one task
36 |             n_query: number of query images for each class in one task
37 |             n_tasks: number of tasks to sample
38 |         """
39 |         super().__init__(data_source=None)
40 |         self.n_way = n_way
41 |         self.n_shot = n_shot
42 |         self.n_query = n_query
43 |         self.n_tasks = n_tasks
44 |
45 |         self.items_per_label: Dict[int, List[int]] = {}
46 |         for item, label in enumerate(dataset.get_labels()):
47 |             if label in self.items_per_label:
48 |                 self.items_per_label[label].append(item)
49 |             else:
50 |                 self.items_per_label[label] = [item]
51 |
52 |         self._check_dataset_size_fits_sampler_parameters()
53 |
54 |     def __len__(self) -> int:
55 |         return self.n_tasks
56 |
57 |     def __iter__(self) -> Iterator[List[int]]:
58 |         """
59 |         Sample n_way labels uniformly at random,
60 |         and then sample n_shot + n_query items for each label, also uniformly at random.
61 |         Yields:
62 |             a list of indices of length (n_way * (n_shot + n_query))
63 |         """
64 |         for _ in range(self.n_tasks):
65 |             yield torch.cat(
66 |                 [
67 |                     torch.tensor(
68 |                         random.sample(
69 |                             self.items_per_label[label], self.n_shot + self.n_query
70 |                         )
71 |                     )
72 |                     for label in random.sample(
73 |                         sorted(self.items_per_label.keys()), self.n_way
74 |                     )
75 |                 ]
76 |             ).tolist()
77 |
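A sketch of the typical wiring, assuming `dataset` is any FewShotDataset exposing get_labels(). The sampler yields flat lists of dataset indices (here 5 * (5 + 10) = 75 per episode, 15 consecutive indices per sampled class); passing it as batch_sampler together with episodic_collate_fn (defined just below) makes a PyTorch DataLoader produce ready-made episodes:

    from torch.utils.data import DataLoader

    from easyfsl.samplers import TaskSampler

    sampler = TaskSampler(dataset, n_way=5, n_shot=5, n_query=10, n_tasks=100)
    loader = DataLoader(
        dataset,
        batch_sampler=sampler,  # delegates index selection to the TaskSampler
        collate_fn=sampler.episodic_collate_fn,  # reshapes each batch into an episode
    )
    for support_images, support_labels, query_images, query_labels, class_ids in loader:
        ...  # feed the episode to a few-shot classifier
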
78 |     def episodic_collate_fn(
79 |         self, input_data: List[Tuple[Tensor, Union[Tensor, int]]]
80 |     ) -> Tuple[Tensor, Tensor, Tensor, Tensor, List[int]]:
81 |         """
82 |         Collate function to be used as argument for the collate_fn parameter of episodic
83 |         data loaders.
84 |         Args:
85 |             input_data: each element is a tuple containing:
86 |                 - an image as a torch Tensor of shape (n_channels, height, width)
87 |                 - the label of this image as an int or a 0-dim tensor
88 |         Returns:
89 |             tuple(Tensor, Tensor, Tensor, Tensor, list[int]): respectively:
90 |                 - support images of shape (n_way * n_shot, n_channels, height, width),
91 |                 - their labels of shape (n_way * n_shot),
92 |                 - query images of shape (n_way * n_query, n_channels, height, width)
93 |                 - their labels of shape (n_way * n_query),
94 |                 - the dataset class ids of the classes sampled in the episode
95 |         """
96 |         input_data_with_int_labels = self._cast_input_data_to_tensor_int_tuple(
97 |             input_data
98 |         )
99 |         true_class_ids = list({x[1] for x in input_data_with_int_labels})
100 |         all_images = torch.cat([x[0].unsqueeze(0) for x in input_data_with_int_labels])
101 |         all_images = all_images.reshape(
102 |             (self.n_way, self.n_shot + self.n_query, *all_images.shape[1:])
103 |         )
104 |         all_labels = torch.tensor(
105 |             [true_class_ids.index(x[1]) for x in input_data_with_int_labels]
106 |         ).reshape((self.n_way, self.n_shot + self.n_query))
107 |         support_images = all_images[:, : self.n_shot].reshape(
108 |             (-1, *all_images.shape[2:])
109 |         )
110 |         query_images = all_images[:, self.n_shot :].reshape((-1, *all_images.shape[2:]))
111 |         support_labels = all_labels[:, : self.n_shot].flatten()
112 |         query_labels = all_labels[:, self.n_shot :].flatten()
113 |         return (
114 |             support_images,
115 |             support_labels,
116 |             query_images,
117 |             query_labels,
118 |             true_class_ids,
119 |         )
120 |
121 |     @staticmethod
122 |     def _cast_input_data_to_tensor_int_tuple(
123 |         input_data: List[Tuple[Tensor, Union[Tensor, int]]]
124 |     ) -> List[Tuple[Tensor, int]]:
125 |         """
126 |         Check the type of the input for the episodic_collate_fn method, and cast it to the right type if possible.
127 |         Args:
128 |             input_data: each element is a tuple containing:
129 |                 - an image as a torch Tensor of shape (n_channels, height, width)
130 |                 - the label of this image as an int or a 0-dim tensor
131 |         Returns:
132 |             the input data with the labels cast to int
133 |         Raises:
134 |             TypeError: Wrong type of input images or labels
135 |             ValueError: Input label is not a 0-dim tensor
136 |         """
137 |         for image, label in input_data:
138 |             if not isinstance(image, Tensor):
139 |                 raise TypeError(
140 |                     f"Illegal type of input instance: {type(image)}. "
141 |                     + GENERIC_TYPING_ERROR_MESSAGE
142 |                 )
143 |             if not isinstance(label, int):
144 |                 if not isinstance(label, Tensor):
145 |                     raise TypeError(
146 |                         f"Illegal type of input label: {type(label)}. "
147 |                         + GENERIC_TYPING_ERROR_MESSAGE
148 |                     )
149 |                 if label.dtype not in {
150 |                     torch.uint8,
151 |                     torch.int8,
152 |                     torch.int16,
153 |                     torch.int32,
154 |                     torch.int64,
155 |                 }:
156 |                     raise TypeError(
157 |                         f"Illegal dtype of input label tensor: {label.dtype}. "
158 |                         + GENERIC_TYPING_ERROR_MESSAGE
159 |                     )
160 |                 if label.ndim != 0:
161 |                     raise ValueError(
162 |                         f"Illegal shape for input label tensor: {label.shape}. "
163 |                         + GENERIC_TYPING_ERROR_MESSAGE
164 |                     )
165 |
166 |         return [(image, int(label)) for (image, label) in input_data]
167 |
168 |     def _check_dataset_size_fits_sampler_parameters(self):
169 |         """
170 |         Check that the dataset size is compatible with the sampler parameters
171 |         """
172 |         self._check_dataset_has_enough_labels()
173 |         self._check_dataset_has_enough_items_per_label()
174 |
175 |     def _check_dataset_has_enough_labels(self):
176 |         if self.n_way > len(self.items_per_label):
177 |             raise ValueError(
178 |                 f"The number of labels in the dataset ({len(self.items_per_label)}) "
179 |                 f"must be greater than or equal to n_way ({self.n_way})."
180 |             )
181 |
182 |     def _check_dataset_has_enough_items_per_label(self):
183 |         number_of_samples_per_label = [
184 |             len(items_for_label) for items_for_label in self.items_per_label.values()
185 |         ]
186 |         minimum_number_of_samples_per_label = min(number_of_samples_per_label)
187 |         label_with_minimum_number_of_samples = number_of_samples_per_label.index(
188 |             minimum_number_of_samples_per_label
189 |         )
190 |         if self.n_shot + self.n_query > minimum_number_of_samples_per_label:
191 |             raise ValueError(
192 |                 f"Label {label_with_minimum_number_of_samples} has only {minimum_number_of_samples_per_label} samples "
193 |                 f"but all classes must have at least n_shot + n_query ({self.n_shot + self.n_query}) samples."
194 |             )
195 |
--------------------------------------------------------------------------------
/easyfsl/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/__init__.py
--------------------------------------------------------------------------------
/easyfsl/tests/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/__init__.py
--------------------------------------------------------------------------------
/easyfsl/tests/datasets/easy_set_test.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from unittest.mock import mock_open, patch
4 |
5 | import pytest
6 |
7 | from easyfsl.datasets import EasySet
8 | from easyfsl.datasets.tiered_imagenet import TieredImageNet
9 |
10 |
11 | def init_easy_set(specs):
12 |     buffer = json.dumps(specs)
13 |     with patch("builtins.open", mock_open(read_data=buffer)):
14 |         with patch("pathlib.Path.glob") as mock_glob:
15 |             mock_glob.return_value = [Path("a.jpeg"), Path("b.jpeg")]
16 |             EasySet(Path("dummy.json"))
17 |
18 |
19 | class TestEasySetInit:
20 |     @staticmethod
21 |     @pytest.mark.filterwarnings("ignore::UserWarning")
22 |     @pytest.mark.parametrize(
23 |         "specs",
24 |         [
25 |             {
26 |                 "class_names": ["class_1", "class_2"],
27 |                 "class_roots": ["path/to/class_1_folder", "path/to/class_2_folder"],
28 |                 "extra_key": [],
29 |             },
30 |         ],
31 |     )
32 |     def test_init_does_not_break_when_specs_are_ok(specs):
33 |         init_easy_set(specs)
34 |
35 |     @staticmethod
36 |     @pytest.mark.parametrize(
37 |         "specs_file_str",
38 |         [
39 |             "path/to/file.csv",
40 |             "file.png",
41 |         ],
42 |     )
43 |     def test_init_does_not_accept_non_json_specs(specs_file_str):
44 |         with pytest.raises(ValueError):
45 |             EasySet(Path(specs_file_str))
46 |
47 |     @staticmethod
48 |     @pytest.mark.parametrize(
49 |         "specs",
50 |         [
51 |             {"class_roots": 
["path/to/class_1_folder", "path/to/class_2_folder"]}, 52 | { 53 | "class_names": ["class_1", "class_2"], 54 | }, 55 | { 56 | "class_names": ["class_1", "class_2", "class_3"], 57 | "class_roots": ["path/to/class_1_folder", "path/to/class_2_folder"], 58 | }, 59 | { 60 | "class_names": ["class_1", "class_2"], 61 | "class_roots": [ 62 | "path/to/class_1_folder", 63 | "path/to/class_2_folder", 64 | "path/to/class_3_folder", 65 | ], 66 | }, 67 | ], 68 | ) 69 | def test_init_returns_error_when_specs_dont_match_template(specs): 70 | with pytest.raises(ValueError): 71 | init_easy_set(specs) 72 | 73 | 74 | class TestEasySetListDataInstances: 75 | @staticmethod 76 | @pytest.mark.parametrize( 77 | "class_roots,images,labels", 78 | [ 79 | ( 80 | [ 81 | "path/to/class_1_folder", 82 | "path/to/class_2_folder", 83 | "path/to/class_3_folder", 84 | ], 85 | [ 86 | "a.png", 87 | "b.png", 88 | "a.png", 89 | "b.png", 90 | "a.png", 91 | "b.png", 92 | ], 93 | [0, 0, 1, 1, 2, 2], 94 | ) 95 | ], 96 | ) 97 | def test_list_data_instances_returns_expected_values( 98 | class_roots, images, labels, mocker 99 | ): 100 | mocker.patch("pathlib.Path.glob", return_value=[Path("a.png"), Path("b.png")]) 101 | mocker.patch("pathlib.Path.is_file", return_value=True) 102 | 103 | assert (images, labels) == EasySet.list_data_instances(class_roots) 104 | 105 | @staticmethod 106 | @pytest.mark.parametrize( 107 | "images, all_files", 108 | [ 109 | ( 110 | [ 111 | # These must be sorted 112 | "a.bmp", 113 | "a.jpeg", 114 | "a.jpg", 115 | "a.png", 116 | ], 117 | [ 118 | "a.png", 119 | "a.jpg", 120 | "a.txt", 121 | "a.bmp", 122 | "a.jpeg", 123 | "a.tmp", 124 | ], 125 | ), 126 | ], 127 | ) 128 | def test_list_data_instances_lists_only_images(images, all_files, mocker): 129 | mocker.patch( 130 | "pathlib.Path.glob", 131 | return_value=[Path(file_name) for file_name in all_files], 132 | ) 133 | mocker.patch("pathlib.Path.is_file", return_value=True) 134 | 135 | assert images == EasySet.list_data_instances(["abc"])[0] 136 | 137 | 138 | class TestTieredImagenet: 139 | @staticmethod 140 | def test_tiered_imagenet_raises_error_if_wrong_split(): 141 | with pytest.raises(ValueError): 142 | TieredImageNet("nope") 143 | 144 | @staticmethod 145 | @pytest.mark.parametrize( 146 | "split", 147 | [ 148 | "train", 149 | "val", 150 | "test", 151 | ], 152 | ) 153 | @pytest.mark.filterwarnings("ignore::UserWarning") 154 | def test_tiered_imagenet_builds_easyset(split, mocker): 155 | mocker.patch( 156 | "pathlib.Path.glob", 157 | return_value=[Path("a.png"), Path("b.png")], 158 | ) 159 | dataset = TieredImageNet(split) 160 | assert isinstance(dataset, EasySet) 161 | -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/balanced_support_set/160.Black_throated_Blue_Warbler/Black_throated_Blue_Warbler_0007_2916700989.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/balanced_support_set/160.Black_throated_Blue_Warbler/Black_throated_Blue_Warbler_0007_2916700989.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/balanced_support_set/160.Black_throated_Blue_Warbler/Black_throated_Blue_Warbler_0008_2966090836.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/balanced_support_set/160.Black_throated_Blue_Warbler/Black_throated_Blue_Warbler_0008_2966090836.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/balanced_support_set/161.Blue_winged_Warbler/Blue_winged_Warbler_0011_2521539056.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/balanced_support_set/161.Blue_winged_Warbler/Blue_winged_Warbler_0011_2521539056.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/balanced_support_set/161.Blue_winged_Warbler/Blue_winged_Warbler_0028_1988388399.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/balanced_support_set/161.Blue_winged_Warbler/Blue_winged_Warbler_0028_1988388399.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/balanced_support_set/162.Canada_Warbler/Canada_Warbler_0001_2495535649.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/balanced_support_set/162.Canada_Warbler/Canada_Warbler_0001_2495535649.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/balanced_support_set/162.Canada_Warbler/Canada_Warbler_0002_2529931098.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/balanced_support_set/162.Canada_Warbler/Canada_Warbler_0002_2529931098.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/empty_support_set/class_with_no_image/not_an_image.txt: -------------------------------------------------------------------------------- 1 | this is not an image 2 | -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/unbalanced_support_set/160.Black_throated_Blue_Warbler/Black_throated_Blue_Warbler_0007_2916700989.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/unbalanced_support_set/160.Black_throated_Blue_Warbler/Black_throated_Blue_Warbler_0007_2916700989.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/unbalanced_support_set/160.Black_throated_Blue_Warbler/Black_throated_Blue_Warbler_0008_2966090836.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/unbalanced_support_set/160.Black_throated_Blue_Warbler/Black_throated_Blue_Warbler_0008_2966090836.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/unbalanced_support_set/161.Blue_winged_Warbler/Blue_winged_Warbler_0011_2521539056.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/unbalanced_support_set/161.Blue_winged_Warbler/Blue_winged_Warbler_0011_2521539056.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/unbalanced_support_set/162.Canada_Warbler/Canada_Warbler_0001_2495535649.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/unbalanced_support_set/162.Canada_Warbler/Canada_Warbler_0001_2495535649.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/unbalanced_support_set/162.Canada_Warbler/Canada_Warbler_0002_2529931098.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/unbalanced_support_set/162.Canada_Warbler/Canada_Warbler_0002_2529931098.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/unbalanced_support_set/162.Canada_Warbler/Canada_Warbler_0003_2509806963.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/unbalanced_support_set/162.Canada_Warbler/Canada_Warbler_0003_2509806963.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/unbalanced_support_set/162.Canada_Warbler/Canada_Warbler_0004_2530218943.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/unbalanced_support_set/162.Canada_Warbler/Canada_Warbler_0004_2530218943.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/resources/unbalanced_support_set/162.Canada_Warbler/Canada_Warbler_0005_887179386.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/datasets/resources/unbalanced_support_set/162.Canada_Warbler/Canada_Warbler_0005_887179386.jpg -------------------------------------------------------------------------------- /easyfsl/tests/datasets/support_set_folder_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | from torchvision import transforms 5 | 6 | from easyfsl.datasets import SupportSetFolder 7 | 
from easyfsl.datasets.support_set_folder import NOT_A_TENSOR_ERROR_MESSAGE 8 | 9 | 10 | class TestSupportSetFolderInit: 11 | @staticmethod 12 | @pytest.mark.parametrize( 13 | "root", 14 | [ 15 | Path("easyfsl/tests/datasets/resources/balanced_support_set"), 16 | Path("easyfsl/tests/datasets/resources/unbalanced_support_set"), 17 | ], 18 | ) 19 | def test_init_does_not_break_when_support_set_is_ok_and_not_custom_args(root): 20 | SupportSetFolder(root) 21 | 22 | @staticmethod 23 | @pytest.mark.parametrize( 24 | "root,transform", 25 | [ 26 | ( 27 | Path("easyfsl/tests/datasets/resources/balanced_support_set"), 28 | transforms.Resize((10, 10)), 29 | ), 30 | ], 31 | ) 32 | def test_init_raises_type_error_when_transform_does_not_input_tensor( 33 | root, transform 34 | ): 35 | with pytest.raises(TypeError) as exc_info: 36 | SupportSetFolder(root, transform=transform) 37 | assert exc_info.value.args[0] == NOT_A_TENSOR_ERROR_MESSAGE 38 | 39 | @staticmethod 40 | @pytest.mark.parametrize( 41 | "root", 42 | [ 43 | Path("easyfsl/tests/datasets/resources/empty_support_set"), 44 | ], 45 | ) 46 | def test_init_raises_error_when_support_set_is_empty(root): 47 | with pytest.raises(FileNotFoundError): 48 | SupportSetFolder(root) 49 | -------------------------------------------------------------------------------- /easyfsl/tests/datasets/wrap_few_shot_dataset_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from torchvision.datasets import ImageFolder 3 | 4 | from easyfsl.datasets import WrapFewShotDataset 5 | 6 | 7 | class FakeImageFolder(ImageFolder): 8 | def __init__( 9 | self, 10 | *args, 11 | image_position_in_get_item_output, 12 | label_position_in_get_item_output, 13 | **kwargs, 14 | ): 15 | super().__init__(*args, **kwargs) 16 | self.image_position_in_get_item_output = image_position_in_get_item_output 17 | self.label_position_in_get_item_output = label_position_in_get_item_output 18 | 19 | def __getitem__(self, item): 20 | image, label = super().__getitem__(item) 21 | output_as_list = [None] * ( 22 | max( 23 | self.image_position_in_get_item_output, 24 | self.label_position_in_get_item_output, 25 | ) 26 | + 1 27 | ) 28 | output_as_list[self.image_position_in_get_item_output] = image 29 | output_as_list[self.label_position_in_get_item_output] = label 30 | return tuple(output_as_list) 31 | 32 | 33 | class TestInit: 34 | @staticmethod 35 | @pytest.mark.parametrize( 36 | "source_dataset,expected_labels", 37 | [ 38 | ( 39 | ImageFolder("easyfsl/tests/datasets/resources/balanced_support_set"), 40 | [0, 0, 1, 1, 2, 2], 41 | ), 42 | ( 43 | ImageFolder("easyfsl/tests/datasets/resources/unbalanced_support_set"), 44 | [0, 0, 1, 2, 2, 2, 2, 2], 45 | ), 46 | ], 47 | ) 48 | def test_default_init_retrieves_correct_labels(source_dataset, expected_labels): 49 | wrapped_dataset = WrapFewShotDataset(source_dataset) 50 | assert wrapped_dataset.get_labels() == expected_labels 51 | 52 | @staticmethod 53 | @pytest.mark.parametrize( 54 | "image_position_in_get_item_output,label_position_in_get_item_output", 55 | [ 56 | (1, 0), 57 | (1, 2), 58 | (4, 5), 59 | (0, 10), 60 | (10, 0), 61 | ], 62 | ) 63 | def test_init_retrieves_correct_labels_from_special_positions( 64 | image_position_in_get_item_output, 65 | label_position_in_get_item_output, 66 | ): 67 | source_dataset = FakeImageFolder( 68 | "easyfsl/tests/datasets/resources/unbalanced_support_set", 69 | image_position_in_get_item_output=image_position_in_get_item_output, 70 | 
label_position_in_get_item_output=label_position_in_get_item_output, 71 | ) 72 | wrapped_dataset = WrapFewShotDataset( 73 | source_dataset, 74 | image_position_in_get_item_output, 75 | label_position_in_get_item_output, 76 | ) 77 | assert wrapped_dataset.get_labels() == [0, 0, 1, 2, 2, 2, 2, 2] 78 | 79 | @staticmethod 80 | @pytest.mark.parametrize( 81 | "image_position_in_get_item_output,label_position_in_get_item_output", 82 | [ 83 | (0, 2), 84 | (2, 0), 85 | (-1, 0), 86 | (0, -1), 87 | (10, 9), 88 | (1, 1), 89 | ], 90 | ) 91 | def test_raises_error_when_input_positions_are_out_of_item_range( 92 | image_position_in_get_item_output, label_position_in_get_item_output 93 | ): 94 | source_dataset = FakeImageFolder( 95 | "easyfsl/tests/datasets/resources/unbalanced_support_set", 96 | image_position_in_get_item_output=0, 97 | label_position_in_get_item_output=1, 98 | ) 99 | with pytest.raises(ValueError): 100 | WrapFewShotDataset( 101 | source_dataset, 102 | image_position_in_get_item_output, 103 | label_position_in_get_item_output, 104 | ) 105 | -------------------------------------------------------------------------------- /easyfsl/tests/methods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/methods/__init__.py -------------------------------------------------------------------------------- /easyfsl/tests/methods/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | from PIL import Image 6 | from torch import nn 7 | from torchvision import transforms 8 | 9 | 10 | @pytest.fixture 11 | def example_few_shot_classification_task(): 12 | """Dummy few-shot classification task fixture.""" 13 | images_dir = Path("easyfsl/tests/methods/resources") 14 | support_image_paths = [ 15 | "Black_footed_Albatross_0001_2950163169.jpg", 16 | "Black_footed_Albatross_0002_2293084168.jpg", 17 | "Least_Auklet_0001_2947317867.jpg", 18 | ] 19 | query_image_paths = [ 20 | "Black_footed_Albatross_0004_2731401028.jpg", 21 | "Least_Auklet_0004_2685272855.jpg", 22 | ] 23 | support_labels = torch.tensor([0, 0, 1]) 24 | 25 | to_tensor = transforms.ToTensor() 26 | support_images = torch.stack( 27 | [ 28 | to_tensor(Image.open(images_dir / img_name)) 29 | for img_name in support_image_paths 30 | ] 31 | ) 32 | query_images = torch.stack( 33 | [to_tensor(Image.open(images_dir / img_name)) for img_name in query_image_paths] 34 | ) 35 | 36 | return support_images, support_labels, query_images 37 | 38 | 39 | @pytest.fixture() 40 | def dummy_network(): 41 | return nn.Sequential( 42 | nn.Flatten(), 43 | nn.AdaptiveAvgPool1d(output_size=10), 44 | nn.Linear(10, 5), 45 | ) 46 | 47 | 48 | @pytest.fixture() 49 | def deterministic_dummy_network(): 50 | return nn.Sequential( 51 | nn.Flatten(), 52 | nn.AdaptiveAvgPool1d(output_size=1), 53 | ) 54 | -------------------------------------------------------------------------------- /easyfsl/tests/methods/feat_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from easyfsl.methods import PrototypicalNetworks 6 | from easyfsl.methods.feat import FEAT 7 | from easyfsl.modules.attention_modules import MultiHeadAttention 8 | 9 | FLATTENED_IMAGE_SIZE = 3072 10 | 11 | 12 | class TestFeat: 13 | @staticmethod 14 | def 
test_forward_runs_without_error(example_few_shot_classification_task): 15 | ( 16 | support_images, 17 | support_labels, 18 | query_images, 19 | ) = example_few_shot_classification_task 20 | 21 | model = FEAT( 22 | nn.Flatten(), 23 | attention_module=MultiHeadAttention( 24 | 1, FLATTENED_IMAGE_SIZE, FLATTENED_IMAGE_SIZE, FLATTENED_IMAGE_SIZE 25 | ), 26 | ) 27 | model.eval() 28 | model.process_support_set(support_images, support_labels) 29 | model(query_images) 30 | 31 | @staticmethod 32 | def test_returns_expected_output_for_example_images( 33 | example_few_shot_classification_task, 34 | ): 35 | ( 36 | support_images, 37 | support_labels, 38 | query_images, 39 | ) = example_few_shot_classification_task 40 | 41 | torch.manual_seed(1) 42 | torch.set_num_threads(1) 43 | 44 | model = FEAT( 45 | nn.Flatten(), 46 | attention_module=MultiHeadAttention( 47 | 1, FLATTENED_IMAGE_SIZE, FLATTENED_IMAGE_SIZE, FLATTENED_IMAGE_SIZE 48 | ), 49 | ) 50 | model.eval() 51 | 52 | model.process_support_set(support_images, support_labels) 53 | predictions = model(query_images) 54 | 55 | assert torch.all( 56 | torch.isclose( 57 | predictions, 58 | torch.tensor( 59 | [[-59.5840, -57.9814], [-71.3547, -70.1513]], 60 | ), 61 | atol=1, 62 | ) 63 | ) 64 | 65 | @staticmethod 66 | def test_raise_error_when_features_are_not_1_dim( 67 | example_few_shot_classification_task, 68 | ): 69 | ( 70 | support_images, 71 | support_labels, 72 | _, 73 | ) = example_few_shot_classification_task 74 | 75 | model = FEAT( 76 | nn.Identity(), 77 | attention_module=MultiHeadAttention( 78 | 1, FLATTENED_IMAGE_SIZE, FLATTENED_IMAGE_SIZE, FLATTENED_IMAGE_SIZE 79 | ), 80 | ) 81 | with pytest.raises(ValueError): 82 | model.process_support_set(support_images, support_labels) 83 | 84 | @staticmethod 85 | def test_attention_module_updates_prototypes(example_few_shot_classification_task): 86 | ( 87 | support_images, 88 | support_labels, 89 | _, 90 | ) = example_few_shot_classification_task 91 | model_feat = FEAT( 92 | nn.Flatten(), 93 | attention_module=MultiHeadAttention( 94 | 1, FLATTENED_IMAGE_SIZE, FLATTENED_IMAGE_SIZE, FLATTENED_IMAGE_SIZE 95 | ), 96 | ) 97 | model_protonet = PrototypicalNetworks(nn.Flatten()) 98 | 99 | model_feat.process_support_set(support_images, support_labels) 100 | model_protonet.process_support_set(support_images, support_labels) 101 | 102 | assert not model_feat.prototypes.isclose( 103 | model_protonet.prototypes, atol=1e-02 104 | ).all() 105 | -------------------------------------------------------------------------------- /easyfsl/tests/methods/few_shot_classifier_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from easyfsl.methods import FewShotClassifier 6 | from easyfsl.modules.predesigned_modules import resnet12 7 | 8 | 9 | class TestFSCAbstractMethods: 10 | @staticmethod 11 | def test_forward_raises_error_when_not_implemented(): 12 | with pytest.raises(NotImplementedError): 13 | model = FewShotClassifier(resnet12()) 14 | model(None) 15 | 16 | 17 | class TestFSCComputeFeatures: 18 | @staticmethod 19 | def test_compute_features_gives_unchanged_features_when_centering_is_none(): 20 | model = FewShotClassifier() 21 | features = torch.rand(10, 64) 22 | assert torch.allclose(model.compute_features(features), features) 23 | 24 | @staticmethod 25 | def test_compute_features_gives_centered_features_when_centering_is_not_none(): 26 | model = FewShotClassifier(feature_centering=torch.rand(64)) 27 | features = 
torch.rand(10, 64) 28 | assert torch.allclose( 29 | model.compute_features(features), features - model.feature_centering 30 | ) 31 | 32 | @staticmethod 33 | def test_compute_features_gives_l2_normalized_features_when_specified(): 34 | model = FewShotClassifier(feature_normalization=2) 35 | features = torch.ones((10, 2)) 36 | assert torch.allclose( 37 | model.compute_features(features), 38 | (np.sqrt(2) / 2) * torch.ones((10, 2)), 39 | ) 40 | 41 | @staticmethod 42 | def test_compute_features_gives_l1_normalized_features_when_specified(): 43 | model = FewShotClassifier(feature_normalization=1) 44 | features = torch.ones((10, 2)) 45 | assert torch.allclose( 46 | model.compute_features(features), 47 | 0.5 * torch.ones((10, 2)), 48 | ) 49 | 50 | @staticmethod 51 | def test_compute_features_gives_unnormalized_features_when_argument_is_none(): 52 | model = FewShotClassifier(feature_normalization=None) 53 | features = torch.ones((10, 2)) 54 | assert torch.allclose( 55 | model.compute_features(features), 56 | torch.ones((10, 2)), 57 | ) 58 | -------------------------------------------------------------------------------- /easyfsl/tests/methods/finetuning_methods_test.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import pytest 4 | import torch 5 | from torch import nn 6 | 7 | from easyfsl.datasets import SupportSetFolder 8 | from easyfsl.methods import PTMAP, TIM, Finetune, TransductiveFinetuning 9 | 10 | ALL_FINETUNING_METHODS = [ 11 | Finetune, 12 | TIM, 13 | TransductiveFinetuning, 14 | PTMAP, 15 | ] 16 | 17 | 18 | class TestFinetuningMethodsRun: 19 | @staticmethod 20 | @pytest.mark.parametrize("method", ALL_FINETUNING_METHODS) 21 | def test_methods_run_in_ordinary_context( 22 | method, example_few_shot_classification_task, dummy_network 23 | ): 24 | model = method(backbone=dummy_network, fine_tuning_steps=2) 25 | ( 26 | support_images, 27 | support_labels, 28 | query_images, 29 | ) = example_few_shot_classification_task 30 | 31 | model.process_support_set(support_images, support_labels) 32 | 33 | model(query_images) 34 | 35 | @staticmethod 36 | @pytest.mark.parametrize("method", ALL_FINETUNING_METHODS) 37 | def test_methods_run_in_no_grad_context( 38 | method, example_few_shot_classification_task, dummy_network 39 | ): 40 | model = method(backbone=dummy_network, fine_tuning_steps=2) 41 | ( 42 | support_images, 43 | support_labels, 44 | query_images, 45 | ) = example_few_shot_classification_task 46 | with torch.no_grad(): 47 | model.process_support_set(support_images, support_labels) 48 | 49 | model(query_images) 50 | 51 | @staticmethod 52 | @pytest.mark.parametrize("method", ALL_FINETUNING_METHODS) 53 | def test_prototypes_update_in_ordinary_context( 54 | method, example_few_shot_classification_task, dummy_network 55 | ): 56 | model = method(backbone=dummy_network, fine_tuning_steps=2, fine_tuning_lr=1.0) 57 | ( 58 | support_images, 59 | support_labels, 60 | query_images, 61 | ) = example_few_shot_classification_task 62 | model.process_support_set(support_images, support_labels) 63 | prototypes = model.prototypes.clone() 64 | 65 | model(query_images) 66 | assert not prototypes.isclose(model.prototypes, atol=1e-02).all() 67 | 68 | @staticmethod 69 | @pytest.mark.parametrize("method", ALL_FINETUNING_METHODS) 70 | def test_prototypes_update_in_no_grad_context( 71 | method, example_few_shot_classification_task, dummy_network 72 | ): 73 | model = method(backbone=dummy_network, fine_tuning_steps=2, fine_tuning_lr=1.0) 74 
| ( 75 | support_images, 76 | support_labels, 77 | query_images, 78 | ) = example_few_shot_classification_task 79 | with torch.no_grad(): 80 | model.process_support_set(support_images, support_labels) 81 | prototypes = model.prototypes.clone() 82 | 83 | model(query_images) 84 | assert not prototypes.isclose(model.prototypes, atol=1e-02).all() 85 | 86 | @staticmethod 87 | @pytest.mark.parametrize("method", ALL_FINETUNING_METHODS) 88 | def test_raise_value_error_for_not_1_dim_features( 89 | method, 90 | example_few_shot_classification_task, 91 | ): 92 | model = method(backbone=nn.Identity(), fine_tuning_steps=2, fine_tuning_lr=1.0) 93 | ( 94 | support_images, 95 | support_labels, 96 | _, 97 | ) = example_few_shot_classification_task 98 | with pytest.raises(ValueError): 99 | model.process_support_set(support_images, support_labels) 100 | 101 | 102 | class TestFinetuningMethodsCanProcessSupportSetFolder: 103 | @staticmethod 104 | @pytest.mark.parametrize("method", ALL_FINETUNING_METHODS) 105 | def test_finetuning_methods_can_process_support_set_from_balanced_folder( 106 | method, dummy_network 107 | ): 108 | support_set = SupportSetFolder( 109 | "easyfsl/tests/datasets/resources/balanced_support_set" 110 | ) 111 | support_images = support_set.get_images() 112 | support_labels = support_set.get_labels() 113 | 114 | model = method(backbone=dummy_network, fine_tuning_steps=2, fine_tuning_lr=1.0) 115 | model.process_support_set(support_images, support_labels) 116 | 117 | query_images = torch.randn((4, 3, 224, 224)) 118 | model(query_images) 119 | 120 | @staticmethod 121 | @pytest.mark.parametrize("method", ALL_FINETUNING_METHODS) 122 | def test_finetuning_methods_can_process_support_set_from_unbalanced_folder( 123 | method, dummy_network 124 | ): 125 | support_set = SupportSetFolder( 126 | "easyfsl/tests/datasets/resources/unbalanced_support_set" 127 | ) 128 | support_images = support_set.get_images() 129 | support_labels = support_set.get_labels() 130 | 131 | model = method(backbone=dummy_network, fine_tuning_steps=2, fine_tuning_lr=1.0) 132 | model.process_support_set(support_images, support_labels) 133 | 134 | query_images = torch.randn((4, 3, 224, 224)) 135 | model(query_images) 136 | 137 | @staticmethod 138 | @pytest.mark.parametrize( 139 | ("method", "support_set_path"), 140 | list( 141 | product( 142 | ALL_FINETUNING_METHODS, 143 | [ 144 | "easyfsl/tests/datasets/resources/unbalanced_support_set", 145 | "easyfsl/tests/datasets/resources/balanced_support_set", 146 | ], 147 | ) 148 | ), 149 | ) 150 | def test_finetuning_methods_store_correct_support_labels( 151 | method, support_set_path, dummy_network 152 | ): 153 | support_set = SupportSetFolder(support_set_path) 154 | support_images = support_set.get_images() 155 | support_labels = support_set.get_labels() 156 | 157 | model = method(backbone=dummy_network, fine_tuning_steps=2, fine_tuning_lr=1.0) 158 | model.process_support_set(support_images, support_labels) 159 | 160 | assert torch.equal(model.support_labels, support_labels) 161 | 162 | @staticmethod 163 | @pytest.mark.parametrize( 164 | ( 165 | "method", 166 | "support_set_path_and_expected_prototypes", 167 | ), 168 | list( 169 | product( 170 | ALL_FINETUNING_METHODS[:3], 171 | [ 172 | ( 173 | "easyfsl/tests/datasets/resources/unbalanced_support_set", 174 | [-0.0987, -0.0489, -0.3414], 175 | ), 176 | ( 177 | "easyfsl/tests/datasets/resources/balanced_support_set", 178 | [-0.0987, 0.2805, -0.3582], 179 | ), 180 | ], 181 | ) 182 | ) 183 | + list( 184 | product( 185 | [PTMAP], 186 | [ 187 
| ( 188 | "easyfsl/tests/datasets/resources/unbalanced_support_set", 189 | [0.1311, 0.0010, 0.1176], 190 | ), 191 | ( 192 | "easyfsl/tests/datasets/resources/balanced_support_set", 193 | [0.1311, 0.3910, 0.2925], 194 | ), 195 | ], 196 | ) 197 | ), 198 | ) 199 | def test_finetuning_methods_store_correct_prototypes( 200 | method, support_set_path_and_expected_prototypes, deterministic_dummy_network 201 | ): 202 | support_set_path, expected_prototypes = support_set_path_and_expected_prototypes 203 | support_set = SupportSetFolder(support_set_path) 204 | support_images = support_set.get_images() 205 | support_labels = support_set.get_labels() 206 | 207 | model = method( 208 | backbone=deterministic_dummy_network, 209 | fine_tuning_steps=2, 210 | fine_tuning_lr=1.0, 211 | ) 212 | model.process_support_set(support_images, support_labels) 213 | 214 | assert torch.all( 215 | torch.isclose( 216 | model.prototypes, 217 | torch.tensor(expected_prototypes).unsqueeze(1), 218 | atol=1e-04, 219 | ) 220 | ) 221 | -------------------------------------------------------------------------------- /easyfsl/tests/methods/matching_networks_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from easyfsl.datasets import SupportSetFolder 6 | from easyfsl.methods import MatchingNetworks 7 | 8 | 9 | class TestMatchingNetworksInit: 10 | @staticmethod 11 | def test_init(dummy_network): 12 | MatchingNetworks(dummy_network, feature_dimension=4) 13 | 14 | 15 | class TestMatchingNetworksPipeline: 16 | @staticmethod 17 | def test_matching_networks_returns_expected_output_for_example_images( 18 | example_few_shot_classification_task, 19 | ): 20 | ( 21 | support_images, 22 | support_labels, 23 | query_images, 24 | ) = example_few_shot_classification_task 25 | 26 | torch.manual_seed(1) 27 | torch.set_num_threads(1) 28 | 29 | model = MatchingNetworks(nn.Flatten(), feature_dimension=3072) 30 | 31 | model.process_support_set(support_images, support_labels) 32 | predictions = model(query_images) 33 | 34 | assert torch.all( 35 | torch.isclose( 36 | predictions, 37 | torch.tensor([[-1.3137, -0.3131], [-1.0779, -0.4160]]), 38 | atol=1e-01, 39 | ) 40 | ) 41 | 42 | @staticmethod 43 | def test_process_support_set_returns_value_error_for_not_1_dim_features( 44 | example_few_shot_classification_task, 45 | ): 46 | ( 47 | support_images, 48 | support_labels, 49 | _, 50 | ) = example_few_shot_classification_task 51 | 52 | torch.manual_seed(1) 53 | torch.set_num_threads(1) 54 | 55 | model = MatchingNetworks(nn.Identity(), feature_dimension=3072) 56 | with pytest.raises(ValueError): 57 | model.process_support_set(support_images, support_labels) 58 | 59 | @staticmethod 60 | def test_process_support_set_returns_value_error_for_wrong_dim_features( 61 | example_few_shot_classification_task, 62 | ): 63 | ( 64 | support_images, 65 | support_labels, 66 | _, 67 | ) = example_few_shot_classification_task 68 | 69 | torch.manual_seed(1) 70 | torch.set_num_threads(1) 71 | 72 | model = MatchingNetworks(nn.Identity(), feature_dimension=10) 73 | with pytest.raises(ValueError): 74 | model.process_support_set(support_images, support_labels) 75 | 76 | 77 | class TestMatchingNetsCanProcessSupportSetFolder: 78 | @staticmethod 79 | @pytest.mark.parametrize( 80 | "support_set_path", 81 | [ 82 | "easyfsl/tests/datasets/resources/balanced_support_set", 83 | "easyfsl/tests/datasets/resources/unbalanced_support_set", 84 | ], 85 | ) 86 | def 
test_matching_nets_can_process_support_set(support_set_path, dummy_network): 87 | support_set = SupportSetFolder(support_set_path) 88 | support_images = support_set.get_images() 89 | support_labels = support_set.get_labels() 90 | 91 | model = MatchingNetworks(backbone=dummy_network, feature_dimension=5) 92 | model.process_support_set(support_images, support_labels) 93 | 94 | query_images = torch.randn((4, 3, 224, 224)) 95 | model(query_images) 96 | -------------------------------------------------------------------------------- /easyfsl/tests/methods/prototypical_networks_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from easyfsl.datasets import SupportSetFolder 6 | from easyfsl.methods import PrototypicalNetworks 7 | 8 | 9 | class TestPrototypicalNetworksPipeline: 10 | @staticmethod 11 | def test_prototypical_networks_returns_expected_output_for_example_images( 12 | example_few_shot_classification_task, 13 | ): 14 | ( 15 | support_images, 16 | support_labels, 17 | query_images, 18 | ) = example_few_shot_classification_task 19 | 20 | torch.manual_seed(1) 21 | torch.set_num_threads(1) 22 | 23 | model = PrototypicalNetworks(nn.Flatten()) 24 | 25 | model.process_support_set(support_images, support_labels) 26 | predictions = model(query_images) 27 | 28 | assert torch.all( 29 | torch.isclose( 30 | predictions, 31 | torch.tensor( 32 | [[-15.5485, -22.0652], [-21.3081, -18.0292]], 33 | ), 34 | atol=1e-01, 35 | ) 36 | ) 37 | 38 | @staticmethod 39 | def test_prototypical_networks_raise_error_when_features_are_not_1_dim( 40 | example_few_shot_classification_task, 41 | ): 42 | ( 43 | support_images, 44 | support_labels, 45 | _, 46 | ) = example_few_shot_classification_task 47 | 48 | model = PrototypicalNetworks(nn.Identity()) 49 | with pytest.raises(ValueError): 50 | model.process_support_set(support_images, support_labels) 51 | 52 | 53 | class TestProtoNetsCanProcessSupportSetFolder: 54 | @staticmethod 55 | @pytest.mark.parametrize( 56 | "support_set_path", 57 | [ 58 | "easyfsl/tests/datasets/resources/balanced_support_set", 59 | "easyfsl/tests/datasets/resources/unbalanced_support_set", 60 | ], 61 | ) 62 | def test_proto_nets_can_process_support_set_from_balanced_folder( 63 | support_set_path, dummy_network 64 | ): 65 | support_set = SupportSetFolder(support_set_path) 66 | support_images = support_set.get_images() 67 | support_labels = support_set.get_labels() 68 | 69 | model = PrototypicalNetworks(backbone=dummy_network) 70 | model.process_support_set(support_images, support_labels) 71 | 72 | query_images = torch.randn((4, 3, 224, 224)) 73 | model(query_images) 74 | 75 | @staticmethod 76 | @pytest.mark.parametrize( 77 | ( 78 | "support_set_path", 79 | "expected_prototypes", 80 | ), 81 | [ 82 | ( 83 | "easyfsl/tests/datasets/resources/unbalanced_support_set", 84 | [-0.0987, -0.0489, -0.3414], 85 | ), 86 | ( 87 | "easyfsl/tests/datasets/resources/balanced_support_set", 88 | [-0.0987, 0.2805, -0.3582], 89 | ), 90 | ], 91 | ) 92 | def test_proto_nets_store_correct_prototypes( 93 | support_set_path, expected_prototypes, deterministic_dummy_network 94 | ): 95 | support_set = SupportSetFolder(support_set_path) 96 | support_images = support_set.get_images() 97 | support_labels = support_set.get_labels() 98 | 99 | model = PrototypicalNetworks(backbone=deterministic_dummy_network) 100 | model.process_support_set(support_images, support_labels) 101 | 102 | assert torch.all( 103 | torch.isclose( 104 
| model.prototypes, 105 | torch.tensor(expected_prototypes).unsqueeze(1), 106 | atol=1e-04, 107 | ) 108 | ) 109 | -------------------------------------------------------------------------------- /easyfsl/tests/methods/relation_networks_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from easyfsl.datasets import SupportSetFolder 6 | from easyfsl.methods import RelationNetworks 7 | 8 | 9 | class TestRelationNetworksInit: 10 | @staticmethod 11 | @pytest.mark.parametrize( 12 | "backbone", 13 | [ 14 | nn.Conv2d(3, 4, 4), 15 | ], 16 | ) 17 | def test_init(backbone): 18 | RelationNetworks(backbone, feature_dimension=4) 19 | 20 | 21 | class TestRelationNetworksPipeline: 22 | @staticmethod 23 | def test_relation_networks_returns_expected_output_for_example_images( 24 | example_few_shot_classification_task, 25 | ): 26 | ( 27 | support_images, 28 | support_labels, 29 | query_images, 30 | ) = example_few_shot_classification_task 31 | 32 | torch.manual_seed(1) 33 | torch.set_num_threads(1) 34 | 35 | model = RelationNetworks( 36 | nn.Identity(), 37 | relation_module=nn.Sequential( 38 | nn.AdaptiveAvgPool3d((1, 1, 1)), nn.Flatten() 39 | ), 40 | feature_dimension=3, 41 | ) 42 | 43 | model.process_support_set(support_images, support_labels) 44 | predictions = model(query_images) 45 | 46 | assert torch.all( 47 | torch.isclose( 48 | predictions, 49 | torch.tensor( 50 | [[0.4148, 0.4866], [0.6354, 0.7073]], 51 | ), 52 | rtol=1e-3, 53 | ), 54 | ) 55 | 56 | @staticmethod 57 | def test_process_support_set_raises_value_error_for_not_3_dim_features( 58 | example_few_shot_classification_task, 59 | ): 60 | ( 61 | support_images, 62 | support_labels, 63 | _, 64 | ) = example_few_shot_classification_task 65 | 66 | torch.manual_seed(1) 67 | torch.set_num_threads(1) 68 | 69 | model = RelationNetworks( 70 | nn.Flatten(), 71 | relation_module=nn.Sequential( 72 | nn.AdaptiveAvgPool3d((1, 1, 1)), nn.Flatten() 73 | ), 74 | feature_dimension=3, 75 | ) 76 | with pytest.raises(ValueError): 77 | model.process_support_set(support_images, support_labels) 78 | 79 | @staticmethod 80 | def test_process_support_set_raises_value_error_for_wrong_dim_features( 81 | example_few_shot_classification_task, 82 | ): 83 | ( 84 | support_images, 85 | support_labels, 86 | _, 87 | ) = example_few_shot_classification_task 88 | 89 | torch.manual_seed(1) 90 | torch.set_num_threads(1) 91 | 92 | model = RelationNetworks( 93 | nn.Identity(), 94 | relation_module=nn.Sequential( 95 | nn.AdaptiveAvgPool3d((1, 1, 1)), nn.Flatten() 96 | ), 97 | feature_dimension=2, 98 | ) 99 | with pytest.raises(ValueError): 100 | model.process_support_set(support_images, support_labels) 101 | 102 | 103 | class TestRelationNetsCanProcessSupportSetFolder: 104 | @staticmethod 105 | @pytest.mark.parametrize( 106 | "support_set_path", 107 | [ 108 | "easyfsl/tests/datasets/resources/balanced_support_set", 109 | "easyfsl/tests/datasets/resources/unbalanced_support_set", 110 | ], 111 | ) 112 | def test_relation_nets_can_process_support_set_from_folder( 113 | support_set_path, 114 | ): 115 | support_set = SupportSetFolder(support_set_path) 116 | support_images = support_set.get_images() 117 | support_labels = support_set.get_labels() 118 | 119 | model = RelationNetworks( 120 | nn.Identity(), 121 | relation_module=nn.Sequential( 122 | nn.AdaptiveAvgPool3d((1, 1, 1)), nn.Flatten() 123 | ), 124 | feature_dimension=3, 125 | ) 126 |
model.process_support_set(support_images, support_labels) 127 | 128 | query_images = torch.randn((4, 3, 84, 84)) 129 | model(query_images) 130 | -------------------------------------------------------------------------------- /easyfsl/tests/methods/resources/Black_footed_Albatross_0001_2950163169.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/methods/resources/Black_footed_Albatross_0001_2950163169.jpg -------------------------------------------------------------------------------- /easyfsl/tests/methods/resources/Black_footed_Albatross_0002_2293084168.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/methods/resources/Black_footed_Albatross_0002_2293084168.jpg -------------------------------------------------------------------------------- /easyfsl/tests/methods/resources/Black_footed_Albatross_0004_2731401028.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/methods/resources/Black_footed_Albatross_0004_2731401028.jpg -------------------------------------------------------------------------------- /easyfsl/tests/methods/resources/Least_Auklet_0001_2947317867.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/methods/resources/Least_Auklet_0001_2947317867.jpg -------------------------------------------------------------------------------- /easyfsl/tests/methods/resources/Least_Auklet_0004_2685272855.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/methods/resources/Least_Auklet_0004_2685272855.jpg -------------------------------------------------------------------------------- /easyfsl/tests/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/modules/__init__.py -------------------------------------------------------------------------------- /easyfsl/tests/modules/predesigned_modules_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from easyfsl.modules import resnet10, resnet12, resnet18, resnet34, resnet50 5 | 6 | 7 | class TestResNets: 8 | all_resnets = [ 9 | resnet10, 10 | resnet12, 11 | resnet18, 12 | resnet34, 13 | resnet50, 14 | ] 15 | 16 | @staticmethod 17 | @pytest.mark.parametrize("network", all_resnets) 18 | def test_resnets_instantiate_without_error(network): 19 | network() 20 | 21 | @staticmethod 22 | @pytest.mark.parametrize("network", all_resnets) 23 | def test_resnets_output_vector_of_size_num_classes_with_use_fc(network): 24 | num_classes = 10 25 | n_images = 5 26 | 27 | model = network(use_fc=True, num_classes=num_classes) 28 | 29 | input_images = torch.ones((n_images, 3, 84, 84)) 30 | 31 | assert model(input_images).shape == (n_images, num_classes) 32 | 
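# A minimal usage sketch (illustration only, not part of the test file above):
# the tests check the classification head (use_fc=True), but for few-shot
# pipelines the same predesigned ResNets are instantiated without it, so the
# forward pass yields flat embeddings. This assumes the factories accept
# use_fc=False, the counterpart of the use_fc=True case tested above:
#
# backbone = resnet12(use_fc=False)
# embeddings = backbone(torch.ones((5, 3, 84, 84)))
# # embeddings.shape == (5, feature_dimension), as resnet_test.py below
# # verifies at the ResNet level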
-------------------------------------------------------------------------------- /easyfsl/tests/modules/resnet_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torchvision.models.resnet import BasicBlock 4 | 5 | from easyfsl.modules import ResNet 6 | 7 | 8 | class TestResNetForward: 9 | @staticmethod 10 | @pytest.mark.parametrize( 11 | "layers,planes,output_size", 12 | [ 13 | ( 14 | [1, 1, 1, 1], 15 | [16, 32, 14, 8], 16 | 8, 17 | ), 18 | ( 19 | [1, 1, 3, 4], 20 | [16, 32, 14, 8], 21 | 8, 22 | ), 23 | ( 24 | [1, 1, 1, 1], 25 | [16, 32, 14, 1], 26 | 1, 27 | ), 28 | ( 29 | [1, 1, 1, 1], 30 | [16, 32, 14, 8], 31 | 8, 32 | ), 33 | ( 34 | [1, 1, 1, 1], 35 | [4, 4, 4, 4], 36 | 4, 37 | ), 38 | ], 39 | ) 40 | def test_basicblock_resnets_output_vector_of_correct_size_without_fc( 41 | layers, planes, output_size 42 | ): 43 | n_images = 5 44 | 45 | model = ResNet( 46 | block=BasicBlock, 47 | layers=layers, 48 | planes=planes, 49 | use_fc=False, 50 | use_pooling=True, 51 | ) 52 | 53 | input_images = torch.ones((n_images, 3, 84, 84)) 54 | 55 | assert model(input_images).shape == (n_images, output_size) 56 | -------------------------------------------------------------------------------- /easyfsl/tests/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/easyfsl/tests/samplers/__init__.py -------------------------------------------------------------------------------- /easyfsl/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | General utilities 3 | """ 4 | 5 | from typing import List, Optional, Tuple 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | import torchvision 11 | from matplotlib import pyplot as plt 12 | from torch import Tensor, nn 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | 16 | from easyfsl.methods import FewShotClassifier 17 | 18 | 19 | def plot_images(images: Tensor, title: str, images_per_row: int): 20 | """ 21 | Plot images in a grid. 22 | Args: 23 | images: 4D mini-batch Tensor of shape (B x C x H x W) 24 | title: title of the figure to plot 25 | images_per_row: number of images in each row of the grid 26 | """ 27 | plt.figure() 28 | plt.title(title) 29 | plt.imshow( 30 | torchvision.utils.make_grid(images, nrow=images_per_row).permute(1, 2, 0) 31 | ) 32 | 33 | 34 | def sliding_average(value_list: List[float], window: int) -> float: 35 | """ 36 | Computes the average of the latest instances in a list 37 | Args: 38 | value_list: input list of floats (can't be empty) 39 | window: number of instances to take into account. If value is 0 or greater than 40 | the length of value_list, all instances will be taken into account. 41 | 42 | Returns: 43 | average of the last window instances in value_list 44 | 45 | Raises: 46 | ValueError: if the input list is empty 47 | """ 48 | if len(value_list) == 0: 49 | raise ValueError("Cannot perform sliding average on an empty list.") 50 | return np.asarray(value_list[-window:]).mean() 51 | 52 | 53 | def predict_embeddings( 54 | dataloader: DataLoader, 55 | model: nn.Module, 56 | device: Optional[str] = None, 57 | ) -> pd.DataFrame: 58 | """ 59 | Predict embeddings for a dataloader. 60 | Args: 61 | dataloader: dataloader to predict embeddings for. 
| Must deliver tuples (images, class_names) 62 | model: model to use for prediction 63 | device: device to cast the images to. If None, no casting is performed. Must be the same as 64 | the device the model is on. 65 | Returns: 66 | dataframe with columns embedding and class_name 67 | """ 68 | all_embeddings = [] 69 | all_class_names = [] 70 | with torch.no_grad(): 71 | for images, class_names in tqdm( 72 | dataloader, unit="batch", desc="Predicting embeddings" 73 | ): 74 | if device is not None: 75 | images = images.to(device) 76 | all_embeddings.append(model(images).detach().cpu()) 77 | if isinstance(class_names, torch.Tensor): 78 | all_class_names += class_names.tolist() 79 | else: 80 | all_class_names += class_names 81 | 82 | concatenated_embeddings = torch.cat(all_embeddings) 83 | 84 | return pd.DataFrame( 85 | {"embedding": list(concatenated_embeddings), "class_name": all_class_names} 86 | ) 87 | 88 | 89 | def evaluate_on_one_task( 90 | model: FewShotClassifier, 91 | support_images: Tensor, 92 | support_labels: Tensor, 93 | query_images: Tensor, 94 | query_labels: Tensor, 95 | ) -> Tuple[int, int]: 96 | """ 97 | Returns the number of correct predictions of query labels, and the total number of 98 | predictions. 99 | """ 100 | model.process_support_set(support_images, support_labels) 101 | predictions = model(query_images).detach().data 102 | number_of_correct_predictions = int( 103 | (torch.max(predictions, 1)[1] == query_labels).sum().item() 104 | ) 105 | return number_of_correct_predictions, len(query_labels) 106 | 107 | 108 | def evaluate( 109 | model: FewShotClassifier, 110 | data_loader: DataLoader, 111 | device: str = "cuda", 112 | use_tqdm: bool = True, 113 | tqdm_prefix: Optional[str] = None, 114 | ) -> float: 115 | """ 116 | Evaluate the model on few-shot classification tasks 117 | Args: 118 | model: a few-shot classifier 119 | data_loader: loads data in the shape of few-shot classification tasks 120 | device: where to cast data tensors. 121 | Must be the same as the device hosting the model's parameters.
122 | use_tqdm: whether to display the evaluation's progress bar 123 | tqdm_prefix: prefix of the tqdm bar 124 | Returns: 125 | average classification accuracy 126 | """ 127 | # We'll count everything and compute the ratio at the end 128 | total_predictions = 0 129 | correct_predictions = 0 130 | 131 | # eval mode affects the behaviour of some layers (such as batch normalization or dropout) 132 | # no_grad() tells torch not to keep in memory the whole computational graph 133 | model.eval() 134 | with torch.no_grad(): 135 | # We use a tqdm context to show a progress bar in the logs 136 | with tqdm( 137 | enumerate(data_loader), 138 | total=len(data_loader), 139 | disable=not use_tqdm, 140 | desc=tqdm_prefix, 141 | ) as tqdm_eval: 142 | for _, ( 143 | support_images, 144 | support_labels, 145 | query_images, 146 | query_labels, 147 | _, 148 | ) in tqdm_eval: 149 | correct, total = evaluate_on_one_task( 150 | model, 151 | support_images.to(device), 152 | support_labels.to(device), 153 | query_images.to(device), 154 | query_labels.to(device), 155 | ) 156 | 157 | total_predictions += total 158 | correct_predictions += correct 159 | 160 | # Log accuracy in real time 161 | tqdm_eval.set_postfix(accuracy=correct_predictions / total_predictions) 162 | 163 | return correct_predictions / total_predictions 164 | 165 | 166 | def compute_average_features_from_images( 167 | dataloader: DataLoader, 168 | model: nn.Module, 169 | device: Optional[str] = None, 170 | ): 171 | """ 172 | Compute the average features vector from all images in a DataLoader. 173 | Assumes the images are always the first element of the batch. 174 | Returns: 175 | Tensor: shape (feature_dimension,) 176 | """ 177 | all_embeddings = torch.stack( 178 | predict_embeddings(dataloader, model, device)["embedding"].to_list() 179 | ) 180 | average_features = all_embeddings.mean(dim=0) 181 | if device is not None: 182 | average_features = average_features.to(device) 183 | return average_features 184 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | ## black 2 | 3 | [tool.black] 4 | target-version = ['py311'] 5 | 6 | ## pylint 7 | 8 | [tool.pylint.messages_control] 9 | disable = [ 10 | "no-member", 11 | "missing-module-docstring", 12 | "missing-class-docstring", 13 | "too-few-public-methods", 14 | # already managed by isort: 15 | "ungrouped-imports", 16 | "wrong-import-order", 17 | "wrong-import-position", 18 | ] 19 | 20 | [tool.pylint.miscellaneous] 21 | notes = ["FIXME"] 22 | 23 | [tool.pylint.similarities] 24 | ignore-signatures = "yes" 25 | ignore-imports = "yes" 26 | min-similarity-lines = 8 27 | 28 | [tool.pylint.basic] 29 | no-docstring-rgx = "^_|^test_|^Test[A-Z]" # no docstrings for tests 30 | max-line-length = 120 31 | docstring-min-length = 15 32 | max-args=10 33 | max-attributes=10 34 | 35 | ## isort 36 | 37 | [tool.isort] 38 | profile = "black" 39 | multi_line_output = 3 40 | py_version=311 41 | 42 | ## mypy 43 | 44 | [tool.mypy] 45 | python_version = "3.11" 46 | ignore_missing_imports = true -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sicara/easy-few-shot-learning/8023ff49a02a68830c10a21b8eb908cb33bdf1b9/scripts/__init__.py --------------------------------------------------------------------------------
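Before the script files themselves, a brief sketch of how they compose (illustration only; the Parquet path is a hypothetical example): scripts/predict_embeddings.py extracts embeddings from a dataset into a Parquet file, then scripts/benchmark_methods.py builds a classifier and evaluates it on tasks sampled from those embeddings, through the helpers defined in scripts/utils.py:

from pathlib import Path

from easyfsl.utils import evaluate
from scripts.utils import build_model, get_dataloader_from_features_path, set_random_seed

set_random_seed(0)

# Build a few-shot classifier over pre-extracted features (nn.Identity backbone inside);
# feature_normalization=2 is the default config from scripts/methods_configs.json
model = build_model("prototypical_networks", device="cuda", feature_normalization=2)

# Sample 5-way 5-shot tasks from the extracted embeddings (hypothetical path)
features_loader = get_dataloader_from_features_path(
    Path("data/features/mini_imagenet/test/feat_resnet12.parquet.gzip"),
    n_way=5,
    n_shot=5,
    n_query=15,
    n_tasks=1000,
    num_workers=0,
)

accuracy = evaluate(model, features_loader, device="cuda")

--------------------------------------------------------------------------------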
/scripts/backbones_configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "feat_resnet12": { 3 | "transform": { 4 | "image_size": 84, 5 | "crop_ratio": 1.15, 6 | "mean": [0.485, 0.456, 0.406], 7 | "std": [0.229, 0.224, 0.225], 8 | "interpolation": "bilinear" 9 | } 10 | } 11 | } -------------------------------------------------------------------------------- /scripts/benchmark_methods.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import typer 6 | from loguru import logger 7 | 8 | from easyfsl.utils import evaluate 9 | from scripts.utils import ( 10 | METHODS_CONFIGS_JSON, 11 | build_model, 12 | get_dataloader_from_features_path, 13 | set_random_seed, 14 | ) 15 | 16 | 17 | def main( 18 | method: str, 19 | features: Path, 20 | config: Optional[str] = None, 21 | n_way: int = 5, 22 | n_shot: int = 5, 23 | n_query: int = 15, 24 | n_tasks: int = 1000, 25 | device: str = "cuda", 26 | num_workers: int = 0, 27 | random_seed: int = 0, 28 | ) -> None: 29 | """ 30 | Evaluate a method on a dataset of features pre-extracted by a backbone. Print the average accuracy. 31 | Args: 32 | method: Few-Shot Classifier to use. 33 | features: path to a Parquet or Pickle file containing the features. 34 | config: existing configuration for the method available in scripts/methods_configs.json 35 | n_way: number of classes per task. 36 | n_shot: number of support examples per class. 37 | n_query: number of query instances per class. 38 | n_tasks: number of tasks to evaluate on. 39 | device: device to use. 40 | num_workers: The number of workers to use for the DataLoader. Defaults to 0 for no multiprocessing. 41 | random_seed: random seed to use for reproducibility. 42 | """ 43 | set_random_seed(random_seed) 44 | 45 | config_dict = read_config(method, config) 46 | model = build_model(method, device, **config_dict) 47 | logger.info(f"Loaded model {method} with {config} config.") 48 | 49 | features_loader = get_dataloader_from_features_path( 50 | features, n_way, n_shot, n_query, n_tasks, num_workers 51 | ) 52 | logger.info(f"Loaded features from {features}") 53 | 54 | accuracy = evaluate(model, features_loader, device) 55 | logger.info(f"Average accuracy : {(100 * accuracy):.2f} %") 56 | 57 | 58 | def read_config(method: str, config: Optional[str]) -> dict: 59 | if config is None: 60 | return {} 61 | with open(METHODS_CONFIGS_JSON, "r", encoding="utf-8") as file: 62 | all_configs = json.load(file) 63 | if method not in all_configs: 64 | raise ValueError( 65 | f"No available config for {method} in {str(METHODS_CONFIGS_JSON)}." 66 | ) 67 | configs = all_configs[method] 68 | if config not in configs: 69 | raise ValueError( 70 | f"No available config {config} for {method} in {str(METHODS_CONFIGS_JSON)}."
71 | ) 72 | return configs[config] 73 | 74 | 75 | if __name__ == "__main__": 76 | typer.run(main) 77 | -------------------------------------------------------------------------------- /scripts/grid_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "tim": { 3 | "feature_normalization": [2], 4 | "fine_tuning_steps": [50, 100, 200, 300], 5 | "fine_tuning_lr": [0.001, 0.0001], 6 | "cross_entropy_weight": [1.0], 7 | "marginal_entropy_weight": [1.0], 8 | "conditional_entropy_weight": [0.1, 0.5, 1.0], 9 | "temperature": [1, 5, 10] 10 | }, 11 | "finetune": { 12 | "feature_normalization": [2], 13 | "fine_tuning_steps": [50, 100, 200, 300], 14 | "fine_tuning_lr": [0.001, 0.0001], 15 | "temperature": [1, 5, 10] 16 | }, 17 | "pt_map": { 18 | "feature_normalization": [2], 19 | "fine_tuning_steps": [10, 20, 30], 20 | "fine_tuning_lr": [0.2, 0.3, 0.5], 21 | "lambda_regularization": [10, 20, 30] 22 | }, 23 | "laplacian_shot": { 24 | "feature_normalization": [2], 25 | "inference_steps": [10, 20, 30], 26 | "knn": [1, 3, 5, 7], 27 | "lambda_regularization": [0.1, 0.3, 0.5, 0.7, 0.8] 28 | }, 29 | "transductive_finetuning": { 30 | "feature_normalization": [2], 31 | "fine_tuning_steps": [10, 25, 40], 32 | "fine_tuning_lr": [0.001, 0.0001, 0.00001], 33 | "temperature": [1, 5, 10] 34 | } 35 | } -------------------------------------------------------------------------------- /scripts/hyperparameter_search.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import product 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import pandas as pd 7 | import typer 8 | from loguru import logger 9 | 10 | from easyfsl.utils import evaluate 11 | from scripts.utils import ( 12 | GRID_SEARCH_JSON, 13 | build_model, 14 | get_dataloader_from_features_path, 15 | set_random_seed, 16 | ) 17 | 18 | 19 | def main( # pylint: disable=too-many-locals 20 | method: str, 21 | features: Path, 22 | n_way: int = 5, 23 | n_shot: int = 5, 24 | n_query: int = 15, 25 | n_tasks: int = 500, 26 | device: str = "cuda", 27 | num_workers: int = 0, 28 | random_seed: int = 0, 29 | output_csv: Optional[Path] = None, 30 | ) -> None: 31 | """ 32 | Perform hyperparameter grid search for a method on a dataset of features pre-extracted by a backbone. 33 | Outputs the results in a csv file with all tested combinations of parameters and the corresponding accuracy. 34 | Args: 35 | method: Few-Shot Classifier to use. 36 | features: path to a Parquet or Pickle file containing the features. 37 | n_way: number of classes per task. 38 | n_shot: number of support examples per class. 39 | n_query: number of query instances per class. 40 | n_tasks: number of tasks to evaluate on. 41 | device: device to use. 42 | num_workers: The number of workers to use for the DataLoader. Defaults to 0 for no multiprocessing. 43 | random_seed: random seed to use for reproducibility. 44 | output_csv: path to the output csv file. 45 | """ 46 | set_random_seed(random_seed) 47 | hyperparameter_grid_df = read_hyperparameter_grid(method) 48 | logger.info( 49 | f"Loaded {len(hyperparameter_grid_df)} hyperparameter combinations for {method}."
50 | ) 51 | 52 | features_loader = get_dataloader_from_features_path( 53 | features, n_way, n_shot, n_query, n_tasks, num_workers 54 | ) 55 | logger.info(f"Loaded features from {features}") 56 | 57 | accuracies_record = [] 58 | for config_dict in iter(hyperparameter_grid_df.to_dict(orient="records")): 59 | model = build_model(method, device, **config_dict) 60 | logger.info(f"Loaded model {method} with following config:") 61 | logger.info(json.dumps(config_dict, indent=4)) 62 | accuracy = evaluate(model, features_loader, device) 63 | accuracies_record.append(accuracy) 64 | logger.info(f"Average accuracy : {(100 * accuracy):.2f} %") 65 | 66 | hyperparameter_grid_df = hyperparameter_grid_df.assign(accuracy=accuracies_record) 67 | logger.info(f"Hyperparameter search results for {method}:") 68 | logger.info("Best hyperparameters:") 69 | logger.info( 70 | json.dumps( 71 | hyperparameter_grid_df.sort_values("accuracy", ascending=False) 72 | .iloc[0] 73 | .to_dict(), 74 | indent=4, 75 | ) 76 | ) 77 | 78 | if output_csv is None: 79 | output_csv = Path(f"{method}_hyperparameter_search.csv") 80 | 81 | hyperparameter_grid_df.to_csv(output_csv, index=False) 82 | logger.info(f"Saved results in {output_csv}") 83 | 84 | 85 | def read_hyperparameter_grid(method: str) -> pd.DataFrame: 86 | with open(GRID_SEARCH_JSON, "r", encoding="utf-8") as file: 87 | all_grids = json.load(file) 88 | if method not in all_grids: 89 | raise ValueError( 90 | f"No available hyperparameter grid for {method} in {str(GRID_SEARCH_JSON)}." 91 | ) 92 | grid = all_grids[method] 93 | return pd.DataFrame(unroll_grid(grid)) 94 | 95 | 96 | def unroll_grid(input_dict: dict[str, list]) -> list[dict]: 97 | """ 98 | Unroll a grid of hyperparameters into a list of dicts. 99 | Args: 100 | input_dict: each key is a parameter name, each value is a list of values for this parameter. 101 | Returns: 102 | a list of dicts, each dict is a combination of parameters. 
103 | Examples: 104 | >>> unroll_grid({"a": [1, 2], "b": [3, 4]}) 105 | [{"a": 1, "b": 3}, {"a": 1, "b": 4}, {"a": 2, "b": 3}, {"a": 2, "b": 4}] 106 | """ 107 | return [ 108 | dict(zip(input_dict.keys(), values)) for values in product(*input_dict.values()) 109 | ] 110 | 111 | 112 | if __name__ == "__main__": 113 | typer.run(main) 114 | -------------------------------------------------------------------------------- /scripts/methods_configs.json: -------------------------------------------------------------------------------- 1 | { 2 | "tim": { 3 | "default": { 4 | "fine_tuning_steps": 50, 5 | "fine_tuning_lr": 0.001, 6 | "cross_entropy_weight": 1.0, 7 | "marginal_entropy_weight": 1.0, 8 | "conditional_entropy_weight": 0.5, 9 | "temperature": 10.0, 10 | "feature_normalization": 2 11 | } 12 | }, 13 | "feat": { 14 | "resnet12_mini_imagenet": { 15 | "checkpoint_path": "data/models/feat_resnet12_mini_imagenet.pth", 16 | "feature_dimension": 640 17 | }, 18 | "resnet12_tiered_imagenet": { 19 | "checkpoint_path": "data/models/feat_resnet12_mini_imagenet.pth", 20 | "feature_dimension": 640 21 | } 22 | }, 23 | "finetune": { 24 | "default": { 25 | "fine_tuning_steps": 50, 26 | "fine_tuning_lr": 0.0001, 27 | "temperature": 10.0, 28 | "feature_normalization": 2 29 | } 30 | }, 31 | "pt_map": { 32 | "default": { 33 | "fine_tuning_steps": 30, 34 | "fine_tuning_lr": 0.3, 35 | "lambda_regularization": 10, 36 | "power_factor": 0.5, 37 | "feature_normalization": 2 38 | } 39 | }, 40 | "transductive_finetuning": { 41 | "default": { 42 | "fine_tuning_steps": 10, 43 | "fine_tuning_lr": 0.0001, 44 | "temperature": 10.0, 45 | "feature_normalization": 2 46 | } 47 | }, 48 | "laplacian_shot": { 49 | "default": { 50 | "inference_steps": 10, 51 | "knn": 3, 52 | "lambda_regularization": 0.1, 53 | "feature_normalization": 2 54 | } 55 | }, 56 | "bd_cspn": { 57 | "default": { 58 | "feature_normalization": 2 59 | } 60 | }, 61 | "prototypical_networks": { 62 | "default": { 63 | "feature_normalization": 2 64 | } 65 | }, 66 | "simple_shot": { 67 | "default": { 68 | "feature_normalization": 2 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /scripts/predict_embeddings.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import pandas as pd 6 | import typer 7 | from loguru import logger 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | from torchvision.transforms import InterpolationMode 12 | 13 | from easyfsl.datasets import ( 14 | CUB, 15 | DanishFungi, 16 | FewShotDataset, 17 | MiniImageNet, 18 | TieredImageNet, 19 | ) 20 | from easyfsl.modules.build_from_checkpoint import feat_resnet12_from_checkpoint 21 | from easyfsl.utils import predict_embeddings 22 | 23 | BACKBONES_DICT = { 24 | "feat_resnet12": feat_resnet12_from_checkpoint, 25 | } 26 | BACKBONES_CONFIGS_JSON = Path("scripts/backbones_configs.json") 27 | INTERPOLATIONS = { 28 | "bilinear": InterpolationMode.BILINEAR, 29 | "bicubic": InterpolationMode.BICUBIC, 30 | } 31 | 32 | DATASETS_DICT = { 33 | "cub": CUB, 34 | "fungi": DanishFungi, 35 | "mini_imagenet": MiniImageNet, 36 | "tiered_imagenet": TieredImageNet, 37 | } 38 | DEFAULT_FUNGI_PATH = Path("data/fungi/images") 39 | DEFAULT_MINI_IMAGENET_PATH = Path("data/mini_imagenet/images") 40 | 41 | 42 | def main( 43 | backbone: str, 44 | checkpoint: Path, 45 | dataset: str, 46 | split: 
str = "test", 47 | device: str = "cuda", 48 | batch_size: int = 128, 49 | num_workers: int = 0, 50 | output_parquet: Optional[Path] = None, 51 | ) -> None: 52 | """ 53 | Use a pretrained backbone to extract embeddings from a dataset, and save them as Parquet. 54 | Args: 55 | backbone: The name of the backbone to use. 56 | checkpoint: The path to the checkpoint to use. 57 | dataset: The name of the dataset to use. 58 | split: Which split to use among train, val test. Some datasets only have a test split. 59 | device: The device to use. 60 | batch_size: The batch size to use. 61 | num_workers: The number of workers to use for the DataLoader. Defaults to 0 for no multiprocessing. 62 | output_parquet: Where to save the extracted embeddings. Defaults to 63 | {backbone}_{dataset}_{split}.parquet.gzip in the current directory. 64 | """ 65 | model = build_backbone(backbone, checkpoint, device) 66 | logger.info(f"Loaded backbone {backbone} from {checkpoint}") 67 | 68 | dataset_transform = get_dataset_transform(backbone) 69 | 70 | initialized_dataset = get_dataset(dataset, split, dataset_transform) 71 | dataloader = DataLoader( 72 | initialized_dataset, 73 | batch_size=batch_size, 74 | num_workers=num_workers, 75 | shuffle=False, 76 | ) 77 | logger.info(f"Loaded dataset {dataset} ({split} split)") 78 | 79 | embeddings_df = predict_embeddings(dataloader, model, device=device) 80 | cast_embeddings_to_numpy(embeddings_df) 81 | 82 | if output_parquet is None: 83 | output_parquet = ( 84 | Path("data/features") 85 | / dataset 86 | / split 87 | / checkpoint.with_suffix(".parquet.gzip").name 88 | ) 89 | output_parquet.parent.mkdir(parents=True, exist_ok=True) 90 | 91 | embeddings_df.to_parquet(output_parquet, index=False, compression="gzip") 92 | logger.info(f"Saved embeddings to {output_parquet}") 93 | 94 | 95 | def build_backbone( 96 | backbone: str, 97 | checkpoint: Path, 98 | device: str, 99 | ) -> nn.Module: 100 | """ 101 | Build a backbone from a checkpoint. 102 | Args: 103 | backbone: name of the backbone. Must be a key of BACKBONES_DICT. 104 | checkpoint: path to the checkpoint 105 | device: device on which to build the backbone 106 | Returns: 107 | The backbone, loaded from the checkpoint, and in eval mode. 108 | """ 109 | if backbone not in BACKBONES_DICT: 110 | raise ValueError( 111 | "Unknown backbone name. " f"Valid names are {BACKBONES_DICT.keys()}" 112 | ) 113 | model = BACKBONES_DICT[backbone](checkpoint, device) 114 | model.eval() 115 | 116 | return model 117 | 118 | 119 | def get_dataset_transform(backbone_name: str) -> transforms.Compose: 120 | """ 121 | Get the transform to apply to the images before feeding them to the backbone. 122 | Use the config defined for the specified backbone at scripts/backbones_configs.json. 123 | Args: 124 | backbone_name: must be a key in scripts/backbones_configs.json. 125 | Returns: 126 | A callable to apply to the images, with a resize, a center-crop, a conversion to tensor, and a normalization. 127 | """ 128 | with open(BACKBONES_CONFIGS_JSON, "r", encoding="utf-8") as file: 129 | all_configs = json.load(file) 130 | if backbone_name not in all_configs: 131 | raise ValueError( 132 | f"No available config for {backbone_name} in {str(BACKBONES_CONFIGS_JSON)}." 
133 | ) 134 | transform_config = all_configs[backbone_name]["transform"] 135 | return transforms.Compose( 136 | [ 137 | transforms.Resize( 138 | int(transform_config["image_size"] * transform_config["crop_ratio"]), 139 | interpolation=INTERPOLATIONS[transform_config["interpolation"]], 140 | ), 141 | transforms.CenterCrop(transform_config["image_size"]), 142 | transforms.ToTensor(), 143 | transforms.Normalize(transform_config["mean"], transform_config["std"]), 144 | ] 145 | ) 146 | 147 | 148 | def get_dataset( 149 | dataset_name: str, split: str, transform: transforms.Compose 150 | ) -> FewShotDataset: 151 | """ 152 | Get a dataset using the built-in constructors from EasyFSL. 153 | Args: 154 | dataset_name: must be one of "cub", "tiered_imagenet", "mini_imagenet", "fungi". 155 | split: train, val, or test 156 | transform: a callable to apply to the images. 157 | Returns: 158 | The requested dataset. 159 | """ 160 | if dataset_name not in DATASETS_DICT: 161 | raise ValueError( 162 | "Unknown dataset name. " f"Valid names are {DATASETS_DICT.keys()}" 163 | ) 164 | if dataset_name == "fungi": 165 | if split != "test": 166 | raise ValueError("Danish Fungi only has a test set.") 167 | return DanishFungi(DEFAULT_FUNGI_PATH, transform=transform) 168 | if dataset_name == "mini_imagenet": 169 | return MiniImageNet( 170 | root=DEFAULT_MINI_IMAGENET_PATH, 171 | split=split, 172 | training=False, 173 | transform=transform, 174 | ) 175 | return DATASETS_DICT[dataset_name](split=split, training=False, transform=transform) 176 | 177 | 178 | def cast_embeddings_to_numpy(embeddings_df: pd.DataFrame) -> None: 179 | """ 180 | Cast the tensor embeddings in a DataFrame to numpy arrays, in place. 181 | Args: 182 | embeddings_df: dataframe with an "embedding" column containing torch tensors.
183 | """ 184 | embeddings_df["embedding"] = embeddings_df["embedding"].apply( 185 | lambda embedding: embedding.detach().cpu().numpy() 186 | ) 187 | 188 | 189 | if __name__ == "__main__": 190 | typer.run(main) 191 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torch import nn 8 | from torch.utils.data import DataLoader 9 | 10 | from easyfsl.datasets import FeaturesDataset 11 | from easyfsl.methods import ( 12 | BDCSPN, 13 | FEAT, 14 | PTMAP, 15 | TIM, 16 | FewShotClassifier, 17 | Finetune, 18 | LaplacianShot, 19 | MatchingNetworks, 20 | PrototypicalNetworks, 21 | RelationNetworks, 22 | SimpleShot, 23 | TransductiveFinetuning, 24 | ) 25 | from easyfsl.samplers import TaskSampler 26 | 27 | METHODS_DICT = { 28 | "bd_cspn": BDCSPN, 29 | "feat": FEAT, 30 | "finetune": Finetune, 31 | "laplacian_shot": LaplacianShot, 32 | "matching_networks": MatchingNetworks, 33 | "prototypical_networks": PrototypicalNetworks, 34 | "pt_map": PTMAP, 35 | "relation_networks": RelationNetworks, 36 | "simple_shot": SimpleShot, 37 | "tim": TIM, 38 | "transductive_finetuning": TransductiveFinetuning, 39 | } 40 | METHODS_CONFIGS_JSON = Path("scripts/methods_configs.json") 41 | GRID_SEARCH_JSON = Path("scripts/grid_search.json") 42 | 43 | 44 | def set_random_seed(seed: int): 45 | """ 46 | Set random, numpy and torch random seed, for reproducibility of the training 47 | Args: 48 | seed: defined random seed 49 | """ 50 | np.random.seed(seed) 51 | torch.manual_seed(seed) 52 | random.seed(seed) 53 | torch.backends.cudnn.deterministic = True 54 | torch.backends.cudnn.benchmark = False 55 | 56 | 57 | def build_model( 58 | method: str, 59 | device: str, 60 | **kwargs, 61 | ) -> FewShotClassifier: 62 | """ 63 | Build a model from a method name and an optional config. 64 | Args: 65 | method: must be one of the keys of METHODS_DICT 66 | device: device to cast the model to (cpu or cuda) 67 | kwargs: optional hyperparameters to initialize the model 68 | Returns: 69 | the requested FewShotClassifier 70 | """ 71 | if method not in METHODS_DICT: 72 | raise ValueError( 73 | "Unknown method name. " f"Valid names are {METHODS_DICT.keys()}" 74 | ) 75 | 76 | if method == "feat": 77 | return FEAT.from_resnet12_checkpoint( 78 | **kwargs, device=device, use_backbone=False 79 | ) 80 | 81 | model = METHODS_DICT[method](nn.Identity(), **kwargs).to(device) 82 | return model 83 | 84 | 85 | def get_dataloader_from_features_path( 86 | features: Path, 87 | n_way: int, 88 | n_shot: int, 89 | n_query: int, 90 | n_tasks: int, 91 | num_workers: int, 92 | ): 93 | """ 94 | Build a dataloader from a path to a pickle file containing a dict mapping labels to all their embeddings. 95 | Args: 96 | features: path to a Parquet or Pickle file containing the features. 97 | n_way: number of classes per task. 98 | n_shot: number of support example per class. 99 | n_query: number of query instances per class. 100 | n_tasks: number of tasks to evaluate on. 101 | num_workers: The number of workers to use for the DataLoader. 
102 | Returns: 103 | a DataLoader that yields features in the shape of a task 104 | """ 105 | features_dataset = get_dataset(features) 106 | task_sampler = TaskSampler( 107 | features_dataset, 108 | n_way=n_way, 109 | n_shot=n_shot, 110 | n_query=n_query, 111 | n_tasks=n_tasks, 112 | ) 113 | features_loader = DataLoader( 114 | features_dataset, 115 | batch_sampler=task_sampler, 116 | num_workers=num_workers, 117 | pin_memory=True, 118 | collate_fn=task_sampler.episodic_collate_fn, 119 | ) 120 | return features_loader 121 | 122 | 123 | def get_dataset(features_path: Path) -> FeaturesDataset: 124 | """ 125 | Load a FeaturesDataset from a path to either a pickle file containing a dict mapping labels to all their embeddings, 126 | or a parquet file containing a dataframe with the columns "embedding" and "label". 127 | Args: 128 | features_path: path to a pickle or parquet file containing the features. 129 | Returns: 130 | a FeaturesDataset 131 | """ 132 | if features_path.suffix == ".pickle": 133 | embeddings_dict = pd.read_pickle(features_path) 134 | return FeaturesDataset.from_dict(embeddings_dict) 135 | embeddings_df = pd.read_parquet(features_path) 136 | return FeaturesDataset.from_dataframe(embeddings_df) 137 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | with open("README.md", "r") as f: 5 | long_description = f.read() 6 | 7 | 8 | setup( 9 | name="easyfsl", 10 | version="1.5.0", 11 | description="Ready-to-use PyTorch code to boost your way into few-shot image classification", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/sicara/easy-few-shot-learning", 15 | license="MIT", 16 | install_requires=[ 17 | "matplotlib>=3.0.0", 18 | "pandas>=1.5.0", 19 | "torch>=1.5.0", 20 | "torchvision>=0.7.0", 21 | "tqdm>=4.1.0", 22 | ], 23 | packages=find_packages(), 24 | python_requires=">=3.6", 25 | entry_points={}, 26 | classifiers=[ 27 | "Development Status :: 3 - Alpha", 28 | "License :: OSI Approved :: MIT License", 29 | "Programming Language :: Python", 30 | "Programming Language :: Python :: 3", 31 | "Programming Language :: Python :: 3.7", 32 | "Programming Language :: Python :: 3.8", 33 | "Programming Language :: Python :: 3.9", 34 | "Programming Language :: Python :: 3.10", 35 | "Programming Language :: Python :: 3.11", 36 | "Operating System :: OS Independent", 37 | ], 38 | ) 39 | --------------------------------------------------------------------------------
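Taken together, the pieces above compose into a very short inference loop. A closing sketch of the API exercised throughout the tests (illustration only; the support set directory is a hypothetical path organized like easyfsl/tests/datasets/resources/balanced_support_set, with one subfolder per class, and resnet12(use_fc=False) assumes the factory forwards use_fc to the ResNet class tested earlier):

import torch

from easyfsl.datasets import SupportSetFolder
from easyfsl.methods import PrototypicalNetworks
from easyfsl.modules import resnet12

# Load a support set from a folder of class subdirectories (hypothetical path)
support_set = SupportSetFolder("path/to/support_set")

# Any backbone that outputs flat features works; resnet12 without its fc head is one option
model = PrototypicalNetworks(resnet12(use_fc=False)).eval()
model.process_support_set(support_set.get_images(), support_set.get_labels())

# Classification scores for a batch of query images (random stand-ins here)
query_images = torch.randn((4, 3, 84, 84))
with torch.no_grad():
    scores = model(query_images)
predicted_labels = scores.argmax(dim=-1)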